├── LICENSE
├── Makefile
├── README.md
├── data
├── figures
├── models
├── notebooks
│   ├── discrete_diffusion_graph.ipynb
│   ├── discrete_diffusion_graph_controlled_cliques.ipynb
│   ├── discrete_diffusion_graph_controlled_community.ipynb
│   ├── discrete_diffusion_graph_controlled_molecule_like.ipynb
│   ├── discrete_diffusion_graph_controlled_zinc.ipynb
│   ├── discrete_diffusion_graph_degree.ipynb
│   ├── discrete_diffusion_graph_sanity_checks.ipynb
│   ├── discrete_diffusion_mnist.ipynb
│   ├── figures
│   │   ├── controlled_cliques.ipynb
│   │   ├── controlled_community.ipynb
│   │   ├── controlled_molecule_like.ipynb
│   │   ├── discrete_diffusion_summary.ipynb
│   │   ├── mmd_performance.ipynb
│   │   └── mmd_performance_nocache.ipynb
│   ├── generative_performance_controlled_cliques.ipynb
│   ├── generative_performance_controlled_community.ipynb
│   ├── generative_performance_controlled_molecule_like.ipynb
│   ├── generative_performance_kernel_comparison_controlled_cliques.ipynb
│   ├── generative_performance_kernel_comparison_controlled_community.ipynb
│   └── zinc_sandbox.ipynb
├── references
│   ├── bibtex.bib
│   └── thumbnail.png
├── results
└── src
    ├── analysis
    │   ├── graph_metrics.py
    │   ├── mmd.py
    │   └── orca.cpp
    ├── feature
    │   ├── __init__.py
    │   ├── graph_conversions.py
    │   ├── molecule_dataset.py
    │   └── random_graph_dataset.py
    ├── model
    │   ├── __init__.py
    │   ├── digress_gnn.py
    │   ├── discrete_diffusers.py
    │   ├── generate.py
    │   ├── gnn.py
    │   ├── image_unet.py
    │   ├── train_model.py
    │   └── util.py
    └── plot
        ├── __init__.py
        └── plot.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright 2022, Genentech, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
install-dependencies:
	# Must be running Python 3.8
	conda install -y -c pytorch pytorch=1.11.0
	# Each line of a Make recipe runs in its own shell, so the version
	# lookups and the pip installs that use them must run as one shell
	# command; `$$` escapes the `$` so the command substitution happens
	# in the shell rather than in Make
	TORCH=$$(python -c "import torch; print(torch.__version__)"); \
	CUDA=$$(python -c "import torch; print(torch.version.cuda)"); \
	pip install torch-scatter -f https://data.pyg.org/whl/torch-$${TORCH}+$${CUDA}.html; \
	pip install torch-sparse -f https://data.pyg.org/whl/torch-$${TORCH}+$${CUDA}.html  # This installs NumPy and SciPy

	pip install torch-geometric  # This installs tqdm and scikit-learn

	pip install networkx==2.6
	conda install -y -c anaconda click pymongo jupyter pandas
	conda install -y -c conda-forge matplotlib
	pip install sacred tables vdom
	conda install -y h5py

# Note about torch-scatter:
# It is possible that this installation pipeline will not install torch-scatter
# successfully. In particular, torch-scatter might import fine, but it may not
# be usable on GPU. Using certain PyTorch Geometric operations might cause a
# "Not compiled with CUDA support" error to be thrown on GPU. To fix this,
# reinstall torch-scatter as follows:
# 1. Navigate to https://data.pyg.org/whl/, which has all the pip wheel files
#    for PyTorch Geometric
# 2. Go to https://data.pyg.org/whl/torch-1.11.0+cu113.html
#    Replace the versions of PyTorch and CUDA appropriately, using the same
#    versions as $TORCH and $CUDA above; for example, cu113 is for CUDA 11.3
# 3. Download the wheel file `torch_scatter-2.0.9-cp38-cp38-linux_x86_64.whl`,
#    which is for Linux; the `cp38` refers to Python 3.8, which is what we're
#    using
# 4. Uninstall torch-scatter using `pip uninstall torch-scatter` and manually
#    install the wheel file using `pip install torch_scatter*.whl`

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GraphGUIDE: interpretable and controllable conditional graph generation with discrete Bernoulli diffusion

### Introduction

Diffusion models achieve state-of-the-art performance in generating realistic objects, and much of the most useful and impactful work in diffusion has been on _conditional_ generation, where objects with specific desired properties are generated.

Unfortunately, it has remained difficult to perform conditional generation on graphs, particularly due to the discrete nature of graphs and the large number of possible structural properties which could be desired in any given generation task.

This repository implements GraphGUIDE, a framework for graph generation using diffusion models, where edges in the graph are flipped or set at each discrete time step. GraphGUIDE enables full control over the conditional generation of arbitrary structural properties without relying on predefined labels. This framework for graph diffusion allows for the interpretable conditional generation of graphs, including the generation of drug-like molecules with desired properties in a way which is informed by experimental evidence.

See the [corresponding paper](https://arxiv.org/abs/2302.03790) for more information.

This repository houses all of the code used to generate the results for the paper, including code that processes data, trains models, implements GraphGUIDE, and generates all figures in the paper.

### Citing this work

If you found GraphGUIDE to be helpful for your work, please cite the following:

Tseng, A.M., Diamant, N., Biancalani, T., Scalia, G. GraphGUIDE: interpretable and controllable conditional graph generation with discrete Bernoulli diffusion. arXiv (2023) [Link](https://arxiv.org/abs/2302.03790)

[\[BibTeX\]](references/bibtex.bib)

### Description of files and directories

```
├── Makefile           <- Installation of dependencies
├── data               <- Contains data for training and downstream analysis
│   ├── raw            <- Raw data, directly downloaded from the source
│   ├── interim        <- Intermediate data mid-processing
│   ├── processed      <- Processed data ready for training or analysis
│   └── README.md      <- Description of data
├── models
│   └── trained_models <- Trained models
├── notebooks          <- Jupyter notebooks that explore data, plot results, and analyze results
│   └── figures        <- Jupyter notebooks that create figures
├── results            <- Saved results
├── README.md          <- This file
└── src                <- Source code
    ├── feature        <- Code for data loading and featurization
    ├── model          <- Code for model architectures and training
    ├── analysis       <- Code for analyzing results
    └── plot           <- Code for plotting and visualization
```

--------------------------------------------------------------------------------
/data:
--------------------------------------------------------------------------------
/gstore/data/resbioai/tsenga5/discrete_graph_diffusion/data

--------------------------------------------------------------------------------
/figures:
--------------------------------------------------------------------------------
/gstore/data/resbioai/tsenga5/discrete_graph_diffusion/figures/

--------------------------------------------------------------------------------
/models:
--------------------------------------------------------------------------------
/gstore/data/resbioai/tsenga5/discrete_graph_diffusion/models

--------------------------------------------------------------------------------
/notebooks/figures/mmd_performance.ipynb:
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "5b894b89", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import json\n", 12 | "import numpy as np" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "6f3afe9a", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# Define MMD values from other papers\n", 23 | "\n", 24 | "metric_list = [\"Degree\", \"Cluster\", \"Spectrum\", \"Orbit\"]\n", 25 | "\n", 26 | "benchmark_mmds = { # From SPECTRE paper\n", 27 | " \"Community (small)\": {\n", 28 | " \"GraphRNN\": [0.08, 0.12, -1, 0.04],\n", 29 | " \"GRAN\": [0.06, 0.11, -1, 0.01],\n", 30 | " \"MolGAN\": [0.06, 0.13, -1, 0.01],\n", 31 | " \"SPECTRE\": [0.02, 0.21, -1, 0.01]\n", 32 | " \n", 33 | " },\n", 34 | " \"Stochastic block models\": {\n", 35 | " \"GraphRNN\": [0.0055, 0.0584, 0.0065, 0.0785],\n", 36 | " \"GRAN\": [0.0113, 0.0553, 0.0054, 0.0540],\n", 37 | " \"MolGAN\": [0.0235, 0.1161, 0.0117, 0.0712],\n", 38 | " \"SPECTRE\": [0.0079, 0.0528, 0.0643, 0.0074]\n", 39 | " }\n", 40 | "}\n", 41 | "\n", 42 | "benchmark_mmd_ratios = { # From DiGress paper\n", 43 | " \"Community (small)\": {\n", 44 | " \"GraphRNN\": [4.0, 1.7, -1, 4.0],\n", 45 | " \"GRAN\": [3.0, 1.6, -1, 1.0],\n", 46 | " \"SPECTRE\": [0.5, 2.7, -1, 2.0],\n", 47 | " \"DiGress\": [1.0, 0.9, -1, 1.0],\n", 48 | " \n", 49 | " },\n", 50 | " \"Stochastic block models\": {\n", 51 | " \"GraphRNN\": [6.9, 1.7, -1, 3.1],\n", 52 | " \"GRAN\": [14.1, 1.7, -1, 2.1],\n", 53 | " \"SPECTRE\": [1.9, 1.6, -1, 1.6],\n", 54 | " \"DiGress\": [1.6, 1.5, -1, 1.7]\n", 55 | " }\n", 56 | "}\n", 57 | "\n", 58 | "benchmark_baselines = { # From SPECTRE paper\n", 59 | " \"Community (small)\": [0.02, 0.07, 1, 0.01],\n", 60 | " \"Stochastic block models\": [0.0008, 0.0332, 0.0063, 0.0255]\n", 61 | "}" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "id": "05ea7b2d", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "def get_best_mmds(run_dir):\n", 72 | " # First, get the best run based on last loss\n", 73 | " best_loss, best_metrics = float(\"inf\"), None\n", 74 | " for run_num in os.listdir(run_dir):\n", 75 | " if run_num == \"_sources\":\n", 76 | " continue\n", 77 | " metrics_path = os.path.join(run_dir, run_num, \"metrics.json\")\n", 78 | " with open(metrics_path, \"r\") as f:\n", 79 | " metrics = json.load(f)\n", 80 | " last_loss = metrics[\"train_epoch_loss\"][\"values\"][-1]\n", 81 | " if last_loss < best_loss:\n", 82 | " best_loss, best_metrics = last_loss, metrics\n", 83 | " \n", 84 | " # Now return the MMDs and baselines\n", 85 | " return (\n", 86 | " [\n", 87 | " best_metrics[\"degree_mmd\"][\"values\"][0],\n", 88 | " best_metrics[\"cluster_coef_mmd\"][\"values\"][0],\n", 89 | " best_metrics[\"spectra_mmd\"][\"values\"][0],\n", 90 | " best_metrics[\"orbit_mmd\"][\"values\"][0]\n", 91 | " ],\n", 92 | " [\n", 93 | " best_metrics[\"degree_mmd_baseline\"][\"values\"][0],\n", 94 | " best_metrics[\"cluster_coef_mmd_baseline\"][\"values\"][0],\n", 95 | " best_metrics[\"spectra_mmd_baseline\"][\"values\"][0],\n", 96 | " best_metrics[\"orbit_mmd_baseline\"][\"values\"][0]\n", 97 | " ]\n", 98 | " )" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "id": "aa392c06", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "# Import MMD and baseline values from training runs\n", 
109 | "\n", 110 | "base_path = \"/gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/\"\n", 111 | "my_mmds_and_baselines = {\n", 112 | " \"Community (small)\": {\n", 113 | " \"Edge-flip\": get_best_mmds(os.path.join(base_path, \"benchmark_community-small_edge-flip\")),\n", 114 | " \"Edge-one\": get_best_mmds(os.path.join(base_path, \"benchmark_community-small_edge-addition\")),\n", 115 | " \"Edge-zero\": get_best_mmds(os.path.join(base_path, \"benchmark_community-small_edge-deletion\"))\n", 116 | " },\n", 117 | " \"Stochastic block models\": {\n", 118 | " \"Edge-flip\": get_best_mmds(os.path.join(base_path, \"benchmark_sbm_edge-flip\")),\n", 119 | " \"Edge-one\": get_best_mmds(os.path.join(base_path, \"benchmark_sbm_edge-addition\")),\n", 120 | " \"Edge-zero\": get_best_mmds(os.path.join(base_path, \"benchmark_sbm_edge-deletion\"))\n", 121 | " }\n", 122 | "}\n", 123 | "\n", 124 | "my_mmds = {d_key : {k_key : vals[0] for k_key, vals in d_dict.items()} for d_key, d_dict in my_mmds_and_baselines.items()}\n", 125 | "my_baselines = {d_key : {k_key : vals[1] for k_key, vals in d_dict.items()} for d_key, d_dict in my_mmds_and_baselines.items()}" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 5, 131 | "id": "90b38078", 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "Community (small)\n", 139 | "GraphRNN & 2.00 & 1.31 & 2.00\n", 140 | "GRAN & 1.73 & 1.25 & 1.00\n", 141 | "MolGAN & 1.73 & 1.36 & 1.00\n", 142 | "SPECTRE & 1.00 & 1.73 & 1.00\n", 143 | "DiGress & 1.00 & 0.95 & 1.00\n", 144 | "Edge-flip & 0.99 & 0.58 & 2.55\n", 145 | "Edge-one & 1.21 & 0.62 & 1.83\n", 146 | "Edge-zero & 1.87 & 1.02 & 4.69\n", 147 | "=========================\n", 148 | "Stochastic block models\n", 149 | "GraphRNN & 2.62 & 1.33 & 1.75\n", 150 | "GRAN & 3.76 & 1.29 & 1.46\n", 151 | "MolGAN & 5.42 & 1.87 & 1.67\n", 152 | "SPECTRE & 3.14 & 1.26 & 0.54\n", 153 | "DiGress & 1.26 & 1.22 & 1.30\n", 154 | "Edge-flip & 2.73 & 1.23 & 0.94\n", 155 | "Edge-one & 1.00 & 1.21 & 0.81\n", 156 | "Edge-zero & 1.31 & 1.19 & 0.80\n", 157 | "=========================\n" 158 | ] 159 | }, 160 | { 161 | "name": "stderr", 162 | "output_type": "stream", 163 | "text": [ 164 | "/local/60676799/ipykernel_28989/435977596.py:4: RuntimeWarning: invalid value encountered in sqrt\n", 165 | " vals = np.sqrt(vals)\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "# Print out results\n", 171 | "\n", 172 | "def print_vals(key, vals):\n", 173 | " vals = np.sqrt(vals)\n", 174 | " print(\"%s & %.2f & %.2f & %.2f\" % (key, vals[0], vals[1], vals[3]))\n", 175 | "\n", 176 | "for d_key in my_mmds.keys():\n", 177 | " print(d_key)\n", 178 | " \n", 179 | " for bm_key, bm_vals in benchmark_mmds[d_key].items():\n", 180 | " print_vals(bm_key, np.array(bm_vals) / np.array(benchmark_baselines[d_key]))\n", 181 | " # print(bm_key, np.array(bm_vals) / np.array(my_baselines[d_key][\"Edge-flip\"]))\n", 182 | "# for bm_key, bm_vals in benchmark_mmd_ratios[d_key].items():\n", 183 | "# print(bm_key, np.array(bm_vals))\n", 184 | " print_vals(\"DiGress\", np.array(benchmark_mmd_ratios[d_key][\"DiGress\"]))\n", 185 | " \n", 186 | " for my_key, my_vals in my_mmds[d_key].items():\n", 187 | " print_vals(my_key, np.array(my_vals) / np.array(benchmark_baselines[d_key]))\n", 188 | "# print(my_key, np.array(my_vals) / np.array(my_baselines[d_key][my_key]))\n", 189 | " print(\"=========================\")" 190 | ] 191 | } 192 | ], 193 | "metadata": { 194 | 
"kernelspec": { 195 | "display_name": "Python 3 (ipykernel)", 196 | "language": "python", 197 | "name": "python3" 198 | }, 199 | "language_info": { 200 | "codemirror_mode": { 201 | "name": "ipython", 202 | "version": 3 203 | }, 204 | "file_extension": ".py", 205 | "mimetype": "text/x-python", 206 | "name": "python", 207 | "nbconvert_exporter": "python", 208 | "pygments_lexer": "ipython3", 209 | "version": "3.8.13" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 5 214 | } 215 | -------------------------------------------------------------------------------- /notebooks/figures/mmd_performance_nocache.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "5b894b89", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import json\n", 12 | "import numpy as np" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 8, 18 | "id": "05ea7b2d", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "def get_best_mmds(run_dir):\n", 23 | " # First, get the best run based on last loss\n", 24 | " best_loss, best_metrics = float(\"inf\"), None\n", 25 | " for run_num in os.listdir(run_dir):\n", 26 | " if run_num == \"_sources\":\n", 27 | " continue\n", 28 | " metrics_path = os.path.join(run_dir, run_num, \"metrics.json\")\n", 29 | " with open(metrics_path, \"r\") as f:\n", 30 | " metrics = json.load(f)\n", 31 | " try:\n", 32 | " last_loss = metrics[\"train_epoch_loss\"][\"values\"][-1]\n", 33 | " \n", 34 | " # Try and get a metric out\n", 35 | " _ = metrics[\"orbit_mmd\"][\"values\"]\n", 36 | " except KeyError:\n", 37 | " print(\"Warning: Did not find finished run in %s\" % os.path.join(run_dir, run_num))\n", 38 | " last_loss = float(\"inf\")\n", 39 | " if last_loss < best_loss:\n", 40 | " best_loss, best_metrics = last_loss, metrics\n", 41 | " \n", 42 | " # Now return the MMDs and baselines\n", 43 | " return (\n", 44 | " [\n", 45 | " best_metrics[\"degree_mmd\"][\"values\"][0],\n", 46 | " best_metrics[\"cluster_coef_mmd\"][\"values\"][0],\n", 47 | "# best_metrics[\"spectra_mmd\"][\"values\"][0],\n", 48 | " best_metrics[\"orbit_mmd\"][\"values\"][0]\n", 49 | " ],\n", 50 | " [\n", 51 | " best_metrics[\"degree_mmd_baseline\"][\"values\"][0],\n", 52 | " best_metrics[\"cluster_coef_mmd_baseline\"][\"values\"][0],\n", 53 | "# best_metrics[\"spectra_mmd_baseline\"][\"values\"][0],\n", 54 | " best_metrics[\"orbit_mmd_baseline\"][\"values\"][0]\n", 55 | " ]\n", 56 | " )" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 9, 62 | "id": "aa392c06", 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Warning: Did not find finished run in /gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/benchmark-nocache_sbm_edge-addition/3\n", 70 | "Warning: Did not find finished run in /gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/benchmark-nocache_sbm_edge-deletion/4\n", 71 | "Warning: Did not find finished run in /gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/benchmark-nocache_sbm_edge-deletion/1\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "# Import MMD and baseline values from training runs\n", 77 | "\n", 78 | "base_path = \"/gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/\"\n", 79 | "my_mmds_and_baselines = {\n", 80 | " \"Community (small)\": {\n", 81 | " \"Edge-flip\": 
get_best_mmds(os.path.join(base_path, \"benchmark-nocache_community-small_edge-flip\")),\n", 82 | " \"Edge-one\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_community-small_edge-addition\")),\n", 83 | " \"Edge-zero\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_community-small_edge-deletion\"))\n", 84 | " },\n", 85 | " \"Stochastic block models\": {\n", 86 | " \"Edge-flip\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_sbm_edge-flip\")),\n", 87 | " \"Edge-one\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_sbm_edge-addition\")),\n", 88 | " \"Edge-zero\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_sbm_edge-deletion\"))\n", 89 | " }\n", 90 | "}\n", 91 | "\n", 92 | "my_mmds = {d_key : {k_key : vals[0] for k_key, vals in d_dict.items()} for d_key, d_dict in my_mmds_and_baselines.items()}\n", 93 | "my_baselines = {d_key : {k_key : vals[1] for k_key, vals in d_dict.items()} for d_key, d_dict in my_mmds_and_baselines.items()}" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 13, 99 | "id": "667229b5", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "benchmark_baselines = { # From SPECTRE paper\n", 104 | " \"Community (small)\": [0.02, 0.07, 0.01],\n", 105 | " \"Stochastic block models\": [0.0008, 0.0332, 0.0255]\n", 106 | "}" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 15, 112 | "id": "90b38078", 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "Community (small)\n", 120 | "Edge-flip & 0.0072 & 0.0319 & 0.0033 & 0.36 & 0.46 & 0.33 \\\\\n", 121 | "Edge-one & 0.1913 & 0.1332 & 0.2977 & 9.56 & 1.90 & 29.77 \\\\\n", 122 | "Edge-zero & 0.0171 & 0.0274 & 0.0046 & 0.85 & 0.39 & 0.46 \\\\\n", 123 | "=========================\n", 124 | "Stochastic block models\n", 125 | "Edge-flip & 0.0409 & 0.0359 & 0.0147 & 51.17 & 1.08 & 0.58 \\\\\n", 126 | "Edge-one & 0.0313 & 0.0328 & 0.0160 & 39.16 & 0.99 & 0.63 \\\\\n", 127 | "Edge-zero & 0.0013 & 0.0337 & 0.0168 & 1.59 & 1.02 & 0.66 \\\\\n", 128 | "=========================\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "# Print out results\n", 134 | "\n", 135 | "def print_vals(key, vals):\n", 136 | " vals = np.sqrt(vals)\n", 137 | " print(\"%s & %.2f & %.2f & %.2f\" % (key, vals[0], vals[1], vals[3]))\n", 138 | "\n", 139 | "for d_key in my_mmds.keys():\n", 140 | " print(d_key) \n", 141 | " for my_key, my_vals in my_mmds[d_key].items():\n", 142 | " v = np.array(my_vals)\n", 143 | " b = np.array(my_baselines[d_key][my_key])\n", 144 | " b = np.array(benchmark_baselines[d_key])\n", 145 | " s = my_key + \" & \"\n", 146 | " s += \" & \".join([\"%.4f\" % x for x in v]) + \" & \"\n", 147 | "# s += \" & \".join([\"%.4f\" % x for x in b]) + \" & \"\n", 148 | " s += \" & \".join([\"%.2f\" % x for x in (v / b)]) + \" \\\\\\\\\"\n", 149 | " print(s)\n", 150 | " print(\"=========================\")" 151 | ] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Python 3 (ipykernel)", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.8.13" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 5 175 | } 176 | 
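
Both benchmarking notebooks above use `-1` as a placeholder for MMD values that a benchmark paper did not report; taking `np.sqrt` of those entries is what produces the `RuntimeWarning: invalid value encountered in sqrt` captured in the saved output of `mmd_performance.ipynb`. A minimal sketch (not part of the repository) of a placeholder-aware square root:

```python
import numpy as np

def sqrt_with_placeholders(vals, placeholder=-1):
    # Hypothetical helper: take element-wise square roots, but map the
    # "not reported" placeholder entries to NaN instead of warning
    vals = np.asarray(vals, dtype=float)
    out = np.full_like(vals, np.nan)
    mask = vals != placeholder
    out[mask] = np.sqrt(vals[mask])
    return out

print(sqrt_with_placeholders([4.0, 1.7, -1, 4.0]))
# [2.         1.30384048        nan 2.        ]
```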
--------------------------------------------------------------------------------
/references/bibtex.bib:
--------------------------------------------------------------------------------
@misc{https://doi.org/10.48550/arxiv.2302.03790,
  doi = {10.48550/ARXIV.2302.03790},
  url = {https://arxiv.org/abs/2302.03790},
  author = {Tseng, Alex M. and Diamant, Nathaniel and Biancalani, Tommaso and Scalia, Gabriele},
  keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences},
  title = {GraphGUIDE: interpretable and controllable conditional graph generation with discrete Bernoulli diffusion},
  publisher = {arXiv},
  year = {2023},
  copyright = {Creative Commons Attribution 4.0 International}
}

--------------------------------------------------------------------------------
/references/thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Genentech/GraphGUIDE/dad0dd371268684a5203839441febe0484d8a3e4/references/thumbnail.png

--------------------------------------------------------------------------------
/results:
--------------------------------------------------------------------------------
/gstore/data/resbioai/tsenga5/discrete_graph_diffusion/results

--------------------------------------------------------------------------------
/src/analysis/graph_metrics.py:
--------------------------------------------------------------------------------
import networkx as nx
import numpy as np
import tempfile
import os
import subprocess

ORCA_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "orca")

def get_degrees(graphs):
    """
    Computes the degrees of nodes in graphs.
    Arguments:
        `graphs`: a list of N NetworkX undirected graphs
    Returns a list of N NumPy arrays (each of size V) of degrees, where each
    array is ordered by the nodes in the graph. Note that V can be different for
    each graph.
    """
    result = []
    for g in graphs:
        degrees = nx.degree(g)
        result.append(np.array([degrees[i] for i in g.nodes]))
    return result


def get_clustering_coefficients(graphs):
    """
    Computes the clustering coefficients of nodes in graphs.
    Arguments:
        `graphs`: a list of N NetworkX undirected graphs
    Returns a list of N NumPy arrays (each of size V) of clustering
    coefficients, where each array is ordered by the nodes in the graph. Note
    that V can be different for each graph.
    """
    result = []
    for g in graphs:
        coefs = nx.clustering(g)
        result.append(np.array([coefs[i] for i in g.nodes]))
    return result


def get_spectra(graphs):
    """
    Computes the spectrum of graphs as the eigenvalues of the normalized
    Laplacian.
    Arguments:
        `graphs`: a list of N NetworkX undirected graphs
    Returns a list of N NumPy arrays (each of size V) of eigenvalues. Note that
    V can be different for each graph.
    """
    return [nx.normalized_laplacian_spectrum(g) for g in graphs]


def run_orca(graph, max_graphlet_size=4):
    """
    Runs Orca on a given graph to count the number of orbits of each type each
    node belongs in.
    Arguments:
        `graph`: a NetworkX undirected graph
        `max_graphlet_size`: maximum size of graphlets whose orbits to count;
            must be either 4 or 5
    Returns a V x O NumPy array of orbit counts for each of the V nodes.
    """
    # Create temporary directory to do work in
    temp_dir_obj = tempfile.TemporaryDirectory()
    temp_dir = temp_dir_obj.name
    in_path = os.path.join(temp_dir, "graph.in")
    out_path = os.path.join(temp_dir, "orca.out")

    # Create input file
    with open(in_path, "w") as f:
        edges = graph.edges
        f.write("%d %d\n" % (max(graph.nodes) + 1, len(edges)))
        for edge in edges:
            f.write("%d %d\n" % edge)

    # Run Orca
    with open(os.devnull, "w") as f:
        subprocess.check_call([
            ORCA_PATH, "node", str(max_graphlet_size), in_path, out_path
        ], stdout=f)

    # Read in result
    with open(out_path, "r") as f:
        result = np.stack([
            np.array(list(map(int, line.strip().split()))) for line in f
        ])

    temp_dir_obj.cleanup()

    return result


def get_orbit_counts(graphs, max_graphlet_size=4):
    """
    Computes the orbit counts of nodes in graphs. The orbit counts of a node are
    the number of times it appears in each orbit of the possible graphlets of
    size up to `max_graphlet_size`. For example, for `max_graphlet_size` of 4,
    there are 15 orbits possible for a node.
    Orbits are computed using Orca:
    https://academic.oup.com/bioinformatics/article/30/4/559/205331
    Arguments:
        `graphs`: a list of N NetworkX undirected graphs
        `max_graphlet_size`: maximum size of graphlets whose orbits to count;
            must be either 4 or 5
    Returns a list of N NumPy arrays (of size V x O) of orbit counts, where each
    array is ordered by the nodes in the graph. Note that V can be different for
    each graph, but O is the same for the same `max_graphlet_size`.
    """
    assert max_graphlet_size in (4, 5)
    return [run_orca(graph, max_graphlet_size) for graph in graphs]


if __name__ == "__main__":
    graphs = [
        nx.erdos_renyi_graph(
            np.random.choice(np.arange(10, 20)),
            np.random.choice(np.linspace(0, 1, 10))
        )
        for _ in range(100)
    ]

    degrees = get_degrees(graphs)
    cluster_coefs = get_clustering_coefficients(graphs)
    spectra = get_spectra(graphs)
    orbit_counts = get_orbit_counts(graphs)

--------------------------------------------------------------------------------
/src/analysis/mmd.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy.stats

def make_histograms(
    value_arrs, num_bins=None, bin_width=None, bin_array=None, frequency=False,
    epsilon=1e-6
):
    """
    Given a set of value arrays, converts them into histograms of counts or
    frequencies. The bins may be specified as a number of bins, a bin width, or
    a pre-defined array of bin edges. This function creates histograms such that
    all value arrays given are transformed into the same histogram space.
    Arguments:
        `value_arrs`: an iterable of N 1D NumPy arrays to make histograms of
        `num_bins`: if given, make the histograms have this number of bins total
        `bin_width`: if given, make the histograms have bins of this width,
            aligned starting at the minimum value
        `bin_array`: if given, make the histograms according to this NumPy array
            of bin edges
        `frequency`: if True, normalize each histogram into frequencies
        `epsilon`: small number for stability of last endpoint if `bin_width` is
            specified
    Returns an N x B array of counts or frequencies (N is parallel to the input
    `value_arrs`), where B is the number of bins in the histograms.
    """
    # Compute bins if needed
    if num_bins is not None:
        assert bin_width is None and bin_array is None
        min_val = min(np.nanmin(arr) for arr in value_arrs)
        max_val = max(np.nanmax(arr) for arr in value_arrs)
        bin_array = np.linspace(min_val, max_val, num_bins + 1)
    elif bin_width is not None:
        assert num_bins is None and bin_array is None
        min_val = min(np.nanmin(arr) for arr in value_arrs)
        max_val = max(np.nanmax(arr) for arr in value_arrs) + bin_width + \
            epsilon
        bin_array = np.arange(min_val, max_val, bin_width)
    elif bin_array is not None:
        assert num_bins is None and bin_width is None
    else:
        raise ValueError(
            "Must specify one of `num_bins`, `bin_width`, or `bin_array`"
        )

    # Compute histograms
    hists = np.empty((len(value_arrs), len(bin_array) - 1))
    for i, arr in enumerate(value_arrs):
        hist = np.histogram(arr, bins=bin_array)[0]
        if frequency:
            hist = hist / len(arr)
        hists[i] = hist

    return hists
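
# Illustrative example (added annotation, not in the original file): binning
# two value arrays into a shared histogram space with unit-width bins.
#   >>> make_histograms(
#   ...     [np.array([0, 1, 1, 2]), np.array([2, 3])],
#   ...     bin_width=1, frequency=True
#   ... )
#   array([[0.25, 0.5 , 0.25, 0.  ],
#          [0.  , 0.  , 0.5 , 0.5 ]])
# Because both arrays are binned with the same bin edges, the two rows are
# directly comparable, and each row sums to 1 when `frequency=True`.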
85 | """ 86 | assert vec_1.shape == vec_2.shape 87 | # Since the vectors are supposed to be PMFs, if everything is 0 then just 88 | # turn it into an (unnormalized) uniform distribution 89 | if np.all(vec_1 == 0): 90 | vec_1 = np.ones_like(vec_1) 91 | if np.all(vec_2 == 0): 92 | vec_2 = np.ones_like(vec_2) 93 | # The SciPy Wasserstein distance function takes in empirical observations 94 | # instead of histograms/distributions as an input, but we can get the same 95 | # result by specifying weights which are the PMF probabilities 96 | w_dist = scipy.stats.wasserstein_distance( 97 | np.arange(len(vec_1)), np.arange(len(vec_1)), vec_1, vec_2 98 | ) 99 | return np.exp(-(w_dist * w_dist) / (2 * sigma * sigma)) 100 | 101 | 102 | def gaussian_total_variation_kernel(vec_1, vec_2, sigma=1): 103 | """ 104 | Computes the Gaussian kernel function on two vectors, where the similarity 105 | metric within the Gaussian is the total variation between the two vectors. 106 | Arguments: 107 | `vec_1`: a NumPy array of values 108 | `vec_2`: a NumPy array of values; the underlying vector space must be 109 | the same as `vec_1` 110 | `sigma`: standard deviation for the Gaussian kernel 111 | Returns a scalar similarity value between 0 and 1. 112 | """ 113 | tv_dist = np.sum(np.abs(vec_1 - vec_2)) / 2 114 | return np.exp(-(tv_dist * tv_dist) / (2 * sigma * sigma)) 115 | 116 | 117 | def compute_inner_prod_feature_mean(dist_1, dist_2, kernel_type, **kwargs): 118 | """ 119 | Given two empirical distributions of vectors, computes the inner product of 120 | the feature means using the specified kernel. This is equivalent to the 121 | expected/average kernel function on all pairs of vectors between the two 122 | distributions. 123 | Arguments: 124 | `dist_1`: an M x D NumPy array of M vectors, each of size D; all vectors 125 | must share the same underlying vector space (or probability space if 126 | the vectors represent a probability distribution) with each other 127 | and with `dist_2` 128 | `dist_2`: an M x D NumPy array of M vectors, each of size D; all vectors 129 | must share the same underlying vector space (or probability space if 130 | the vectors represent a probability distribution) with each other 131 | and with `dist_1` 132 | `kernel_type`: type of kernel to apply for computing the kernelized 133 | inner product; can be "gaussian", "gaussian_wasserstein", or 134 | "gaussian_total_variation" 135 | `kwargs`: extra keyword arguments to be passed to the kernel function 136 | Returns a scalar which is the average kernelized inner product between all 137 | pairs of vectors across the two distributions. 138 | """ 139 | if kernel_type == "gaussian": 140 | kernel_func = gaussian_kernel 141 | elif kernel_type == "gaussian_wasserstein": 142 | kernel_func = gaussian_wasserstein_kernel 143 | elif kernel_type == "gaussian_total_variation": 144 | kernel_func = gaussian_total_variation_kernel 145 | else: 146 | raise ValueError("Unknown kernel type: %s" % kernel_type) 147 | 148 | inner_prods = [] 149 | for vec_1 in dist_1: 150 | for vec_2 in dist_2: 151 | inner_prods.append(kernel_func(vec_1, vec_2, **kwargs)) 152 | 153 | return np.mean(inner_prods) 154 | 155 | 156 | def compute_maximum_mean_discrepancy( 157 | dist_1, dist_2, kernel_type, normalize=True, **kwargs 158 | ): 159 | """ 160 | Given two empirical distributions of vectors, computes the maximum mean 161 | discrepancy (MMD) between the two distributions. 

def compute_maximum_mean_discrepancy(
    dist_1, dist_2, kernel_type, normalize=True, **kwargs
):
    """
    Given two empirical distributions of vectors, computes the maximum mean
    discrepancy (MMD) between the two distributions.
    Arguments:
        `dist_1`: an M x D NumPy array of M vectors, each of size D; all vectors
            must share the same underlying vector space (or probability space if
            the vectors represent a probability distribution) with each other
            and with `dist_2`
        `dist_2`: an M x D NumPy array of M vectors, each of size D; all vectors
            must share the same underlying vector space (or probability space if
            the vectors represent a probability distribution) with each other
            and with `dist_1`
        `kernel_type`: type of kernel to apply for computing the kernelized
            inner product; can be "gaussian", "gaussian_wasserstein", or
            "gaussian_total_variation"
        `normalize`: if True, normalize each D-vector to sum to 1
        `kwargs`: extra keyword arguments to be passed to the kernel function
    Returns the scalar MMD value.
    """
    if normalize:
        dist_1 = dist_1 / np.sum(dist_1, axis=1, keepdims=True)
        dist_2 = dist_2 / np.sum(dist_2, axis=1, keepdims=True)

    term_1 = compute_inner_prod_feature_mean(
        dist_1, dist_1, kernel_type, **kwargs
    )
    term_2 = compute_inner_prod_feature_mean(
        dist_2, dist_2, kernel_type, **kwargs
    )
    term_3 = compute_inner_prod_feature_mean(
        dist_1, dist_2, kernel_type, **kwargs
    )
    return np.sqrt(term_1 + term_2 - (2 * term_3))


if __name__ == "__main__":
    import networkx as nx
    import graph_metrics

    graphs = [
        nx.erdos_renyi_graph(
            np.random.choice(np.arange(10, 20)),
            np.random.choice(np.linspace(0, 1, 10))
        ) for _ in range(50)
    ] + [
        nx.erdos_renyi_graph(
            np.random.choice(np.arange(10, 20)),
            np.random.choice(np.linspace(0, 1, 10))
        ) for _ in range(50)
    ]

    degrees = graph_metrics.get_degrees(graphs)
    cluster_coefs = graph_metrics.get_clustering_coefficients(graphs)
    spectra = graph_metrics.get_spectra(graphs)
    orbit_counts = graph_metrics.get_orbit_counts(graphs)
    orbit_counts = np.stack([
        np.mean(counts, axis=0) for counts in orbit_counts
    ])

    kernel_type = "gaussian_total_variation"

    degree_hists = make_histograms(degrees, bin_width=1)
    degree_mmd = compute_maximum_mean_discrepancy(
        degree_hists[:50], degree_hists[50:], kernel_type, sigma=1
    )
    cluster_coef_hists = make_histograms(cluster_coefs, num_bins=100)
    cluster_coef_mmd = compute_maximum_mean_discrepancy(
        cluster_coef_hists[:50], cluster_coef_hists[50:], kernel_type, sigma=0.1
    )
    spectra_hists = make_histograms(
        spectra, bin_array=np.linspace(-1e-5, 2, 200 + 1)
    )
    spectra_mmd = compute_maximum_mean_discrepancy(
        spectra_hists[:50], spectra_hists[50:], kernel_type, sigma=1
    )
    orbit_mmd = compute_maximum_mean_discrepancy(
        orbit_counts[:50], orbit_counts[50:], kernel_type, normalize=False,
        sigma=30
    )

    print("MMD values")
    print("Degree: %.15f" % np.square(degree_mmd))
    print("Clustering coefficient: %.15f" % np.square(cluster_coef_mmd))
    print("Spectrum: %.15f" % np.square(spectra_mmd))
    print("Orbit: %.15f" % np.square(orbit_mmd))

--------------------------------------------------------------------------------
/src/analysis/orca.cpp:
--------------------------------------------------------------------------------
// Downloaded from https://github.com/thocevar/orca
// Compile with `g++ -O2 -std=c++11 -o orca 
orca.cpp` 3 | 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | using namespace std; 16 | 17 | 18 | typedef long long int64; 19 | typedef pair PII; 20 | typedef struct { int first, second, third; } TIII; 21 | 22 | struct PAIR { 23 | int a, b; 24 | PAIR(int a0, int b0) { a=min(a0,b0); b=max(a0,b0); } 25 | }; 26 | bool operator<(const PAIR &x, const PAIR &y) { 27 | if (x.a==y.a) return x.bb) swap(a,b); 44 | if (b>c) swap(b,c); 45 | if (a>b) swap(a,b); 46 | } 47 | }; 48 | bool operator<(const TRIPLE &x, const TRIPLE &y) { 49 | if (x.a==y.a) { 50 | if (x.b==y.b) return x.c common2; 64 | unordered_map common3; 65 | unordered_map::iterator common2_it; 66 | unordered_map::iterator common3_it; 67 | 68 | #define common3_get(x) (((common3_it=common3.find(x))!=common3.end())?(common3_it->second):0) 69 | #define common2_get(x) (((common2_it=common2.find(x))!=common2.end())?(common2_it->second):0) 70 | 71 | int n,m; // n = number of nodes, m = number of edges 72 | int *deg; // degrees of individual nodes 73 | PAIR *edges; // list of edges 74 | 75 | int **adj; // adj[x] - adjacency list of node x 76 | PII **inc; // inc[x] - incidence list of node x: (y, edge id) 77 | bool adjacent_list(int x, int y) { return binary_search(adj[x],adj[x]+deg[x],y); } 78 | int *adj_matrix; // compressed adjacency matrix 79 | const int adj_chunk = 8*sizeof(int); 80 | bool adjacent_matrix(int x, int y) { return adj_matrix[(x*n+y)/adj_chunk]&(1<<((x*n+y)%adj_chunk)); } 81 | bool (*adjacent)(int,int); 82 | int getEdgeId(int x, int y) { return inc[x][lower_bound(adj[x],adj[x]+deg[x],y)-adj[x]].second; } 83 | 84 | int64 **orbit; // orbit[x][o] - how many times does node x participate in orbit o 85 | int64 **eorbit; // eorbit[x][o] - how many times does node x participate in edge orbit o 86 | 87 | 88 | /** count graphlets on max 4 nodes */ 89 | void count4() { 90 | clock_t startTime, endTime; 91 | startTime = clock(); 92 | clock_t startTime_all, endTime_all; 93 | startTime_all = startTime; 94 | int frac,frac_prev; 95 | 96 | // precompute triangles that span over edges 97 | printf("stage 1 - precomputing common nodes\n"); 98 | int *tri = (int*)calloc(m,sizeof(int)); 99 | frac_prev=-1; 100 | for (int i=0;i= x) break; 131 | nn=0; 132 | for (int ny=0;ny= y) break; 135 | if (adjacent(x,z)==0) continue; 136 | neigh[nn++]=z; 137 | } 138 | for (int i=0;i= x) break; 292 | nn=0; 293 | for (int ny=0;ny= y) break; 296 | if (neighx[z]==-1) continue; 297 | int xz=neighx[z]; 298 | neigh[nn]=z; 299 | neigh_edges[nn]={xz, yz}; 300 | nn++; 301 | } 302 | for (int i=0;i=0;nx--) { 333 | int y=inc[x][nx].first, xy=inc[x][nx].second; 334 | if (y <= x) break; 335 | nn=0; 336 | for (int ny=deg[y]-1;ny>=0;ny--) { 337 | int z=adj[y][ny]; 338 | if (z <= y) break; 339 | if (adjacent(x,z)==0) continue; 340 | neigh[nn++]=z; 341 | } 342 | for (int i=0;i= x) break; 497 | nn=0; 498 | for (int ny=0;ny= y) break; 501 | if (adjacent(x,z)) { 502 | neigh[nn++]=z; 503 | } 504 | } 505 | for (int i=0;i2 && tri[xb]>2)?(common3_get(TRIPLE(x,a,b))-1):0; 604 | f_71 += (tri[xa]>2 && tri[xc]>2)?(common3_get(TRIPLE(x,a,c))-1):0; 605 | f_71 += (tri[xb]>2 && tri[xc]>2)?(common3_get(TRIPLE(x,b,c))-1):0; 606 | f_67 += tri[xa]-2+tri[xb]-2+tri[xc]-2; 607 | f_66 += common2_get(PAIR(a,b))-2; 608 | f_66 += common2_get(PAIR(a,c))-2; 609 | f_66 += common2_get(PAIR(b,c))-2; 610 | f_58 += deg[x]-3; 611 | f_57 += deg[a]-3+deg[b]-3+deg[c]-3; 612 | } 613 | } 614 | 615 | // x = orbit-13 (diamond) 616 | 
for (int nx2=0;nx21 && tri[xc]>1)?(common3_get(TRIPLE(x,b,c))-1):0; 624 | f_68 += common3_get(TRIPLE(a,b,c))-1; 625 | f_64 += common2_get(PAIR(b,c))-2; 626 | f_61 += tri[xb]-1+tri[xc]-1; 627 | f_60 += common2_get(PAIR(a,b))-1; 628 | f_60 += common2_get(PAIR(a,c))-1; 629 | f_55 += tri[xa]-2; 630 | f_48 += deg[b]-2+deg[c]-2; 631 | f_42 += deg[x]-3; 632 | f_41 += deg[a]-3; 633 | } 634 | } 635 | 636 | // x = orbit-12 (diamond) 637 | for (int nx2=nx1+1;nx21)?common3_get(TRIPLE(a,b,c)):0; 645 | f_63 += common_x[c]-2; 646 | f_59 += tri[ac]-1+common2_get(PAIR(b,c))-1; 647 | f_54 += common2_get(PAIR(a,b))-2; 648 | f_47 += deg[x]-2; 649 | f_46 += deg[c]-2; 650 | f_40 += deg[a]-3+deg[b]-3; 651 | } 652 | } 653 | 654 | // x = orbit-8 (cycle) 655 | for (int nx2=nx1+1;nx20)?common3_get(TRIPLE(a,b,c)):0; 663 | f_53 += tri[xa]+tri[xb]; 664 | f_51 += tri[ac]+common2_get(PAIR(c,b)); 665 | f_50 += common_x[c]-2; 666 | f_49 += common_a[b]-2; 667 | f_38 += deg[x]-2; 668 | f_37 += deg[a]-2+deg[b]-2; 669 | f_36 += deg[c]-2; 670 | } 671 | } 672 | 673 | // x = orbit-11 (paw) 674 | for (int nx2=nx1+1;nx21 && tri[ac]>1)?common3_get(TRIPLE(a,b,c)):0; 713 | f_45 += common2_get(PAIR(b,c))-1; 714 | f_39 += tri[ab]-1+tri[ac]-1; 715 | f_31 += deg[a]-3; 716 | f_28 += deg[x]-1; 717 | f_24 += deg[b]-2+deg[c]-2; 718 | } 719 | } 720 | 721 | // x = orbit-4 (path) 722 | for (int na=0;na= x) break; 917 | nn=0; 918 | for (int ny=0;ny= y) break; 921 | if (neighx[z]==-1) continue; 922 | int xz=neighx[z]; 923 | neigh[nn]=z; 924 | neigh_edges[nn]={xz, yz}; 925 | nn++; 926 | } 927 | for (int i=0;i=x) break; 997 | 998 | // common nodes of y and some other node 999 | for (int i=0;i> n >> m; 1378 | int d_max=0; 1379 | edges = (PAIR*)malloc(m*sizeof(PAIR)); 1380 | deg = (int*)calloc(n,sizeof(int)); 1381 | for (int i=0;i> a >> b; 1384 | if (!(0<=a && a(edges,edges+m).size())!=m) { 1401 | cerr << "Input file contains duplicate undirected edges." 
<< endl; 1402 | return 0; 1403 | } 1404 | // set up adjacency matrix if it's smaller than 100MB 1405 | if ((int64)n*n < 100LL*1024*1024*8) { 1406 | adjacent = adjacent_matrix; 1407 | adj_matrix = (int*)calloc((n*n)/adj_chunk+1,sizeof(int)); 1408 | for (int i=0;i= len(graph_sizes): 82 | # There are more indices in `batch` than there are pointers, so cut off 83 | # the excess 84 | adj_matrix = adj_matrix[:len(graph_sizes)] 85 | 86 | # Create boolean mask of only the top upper triangle of each B x V x V 87 | # matrix, ignoring the diagonal 88 | triu_mask = torch.triu(torch.ones_like(adj_matrix), diagonal=1) == 1 89 | for i, size in enumerate(torch.diff(data.ptr)): 90 | # TODO: make this more efficient 91 | # For each individual graph, limit the upper-triangle mask to 92 | # only the size of that graph 93 | triu_mask[i, :, size:] = False 94 | 95 | # Edge vector is all entries in the graph-specific upper triangle of each 96 | # individual adjacency matrix 97 | edge_vec = adj_matrix[triu_mask] 98 | 99 | if return_batch_inds: 100 | # Number of edges in each graph: 101 | edge_counts = ((graph_sizes * (graph_sizes - 1)) / 2).int() 102 | edge_counts_cumsum = torch.cumsum(edge_counts, dim=0) 103 | 104 | # Create binary marker array, which is all 0s except with 1s 105 | # wherever we switch to a new graph 106 | markers = torch.zeros( 107 | edge_counts_cumsum[-1], dtype=torch.int, device=DEVICE 108 | ) 109 | markers[edge_counts_cumsum[:-1]] = 1 110 | 111 | batch_inds = torch.cumsum(markers, dim=0) 112 | return edge_vec, batch_inds 113 | return edge_vec 114 | 115 | 116 | def edge_vector_to_pyg_data(data, edges, reflect=True): 117 | """ 118 | Returns the edge-index tensor which would be associated with the 119 | torch-geometric Data object `data`, if the edges in the edge index attribute 120 | were set according to the given edges. Note that self-edges are not allowed. 121 | If `edges` is a scalar 1, then this returns the set of all edges as an 122 | edge-index tensor. 123 | Arguments: 124 | `data`: a torch-geometric Data object 125 | `edges`: a binary E-tensor of edges in canonical order, or the scalar 1 126 | `reflect`: by default, each edge will be represented twice in the 127 | edge-index tensor (no self-edges are allowed); if False, only the 128 | upper-triangular indices will be present (and thus the tensor's size 129 | will be halved) 130 | Returns a 2 x E' edge-index tensor (type long). 
131 | """ 132 | graph_sizes = torch.diff(data.ptr) 133 | max_size = torch.max(graph_sizes) 134 | 135 | # Create filler adjacency matrix that starts out as all 0s 136 | adj_matrix = torch.zeros( 137 | graph_sizes.shape[0], max_size, max_size, device=DEVICE 138 | ) 139 | 140 | # Create boolean mask of only the top upper triangle of each B x V x V 141 | # matrix, ignoring the diagonal 142 | triu_mask = torch.triu(torch.ones_like(adj_matrix), diagonal=1) == 1 143 | for i, size in enumerate(torch.diff(data.ptr)): 144 | # TODO: make this more efficient 145 | # For each individual graph, limit the upper-triangle mask to 146 | # only the size of that graph 147 | triu_mask[i, :, size:] = False 148 | 149 | # Set the upper triangle of each graph-specific adjacency matrix to 150 | # the edges given, 151 | adj_matrix[triu_mask] = edges 152 | 153 | # Symmetrize the matrix 154 | if reflect: 155 | adj_matrix = adj_matrix + torch.transpose(adj_matrix, 1, 2) 156 | 157 | # Get indices where the adjacency matrix is nonzero (an E x 3 matrix) 158 | nonzero_inds = adj_matrix.nonzero() 159 | 160 | # The indices are for each individual graph, so add the graph-size 161 | # pointers to each set of indices so later graphs have higher 162 | # indices based on the sizes of earlier graphs 163 | edges_to_set = nonzero_inds[:, 1:] + data.ptr[nonzero_inds[:, 0]][:, None] 164 | 165 | # Convert to a 2 x E matrix 166 | edges_to_set = edges_to_set.t().contiguous() 167 | return torch_geometric.utils.sort_edge_index(edges_to_set).long() 168 | 169 | 170 | def split_pyg_data_to_nx_graphs(data): 171 | """ 172 | Given a torch-geometric Data object, splits the objects into an ordered list 173 | of NetworkX graphs. The NetworkX graphs will be undirected (no self-edges 174 | allowed), and node features will be under the attribute "feats". 175 | Arguments: 176 | `data`: a batched torch-geometric Data object 177 | Returns an ordered list of NetworkX graph objects. 178 | """ 179 | graphs = [] 180 | pointers = data.ptr.cpu().numpy() 181 | graph_sizes = np.diff(pointers) 182 | num_graphs = len(graph_sizes) 183 | 184 | # First, get (padded) adjacency matrix of size B x V x V, where V is 185 | # the maximum number of nodes in any individual graph 186 | adj_matrix = torch_geometric.utils.to_dense_adj(data.edge_index, data.batch) 187 | 188 | for i in range(num_graphs): 189 | graph = nx.empty_graph(graph_sizes[i]) 190 | 191 | # Get the indices of adjacency matrix, upper triangle only 192 | edge_indices = torch.triu( 193 | adj_matrix[i], diagonal=1 194 | ).nonzero().cpu().numpy() 195 | for u, v in edge_indices: 196 | graph.add_edge(u, v) 197 | 198 | node_feats = data.x[pointers[i] : pointers[i + 1]].cpu().numpy() 199 | feat_dict = {i : node_feats[i] for i in range(graph.number_of_nodes())} 200 | nx.set_node_attributes(graph, feat_dict, "feats") 201 | 202 | graphs.append(graph) 203 | return graphs 204 | 205 | 206 | def add_virtual_nodes(data): 207 | """ 208 | Given a torch-geometric Data object, adds a virtual node to each individual 209 | graph in the batch. This modifies `data` in place. 

def add_virtual_nodes(data):
    """
    Given a torch-geometric Data object, adds a virtual node to each individual
    graph in the batch. This modifies `data` in place. The attributes are
    modified as follows:
        `x`: new row of all 0s added for each virtual node
        `edge_index`: edges from each virtual node to all other nodes in its
            graph
        `batch`: new entry with unique index for each virtual node
        `edge_type`: introduces new edge type for each edge added
        `ptr`: unmodified
    Arguments:
        `data`: a batched torch-geometric Data object
    """
    # Code adapted from `torch_geometric.transforms.virtual_node.VirtualNode`
    # We write our own function to make sure we add a distinct virtual node for
    # each graph in the batch, and in light of the issue here:
    # https://github.com/pyg-team/pytorch_geometric/issues/5818
    updates = {
        "x": [data.x],
        "edge_index": [data.edge_index],
        "batch": [data.batch],
        "edge_type": [
            data.get("edge_type", torch.zeros_like(data.edge_index[0]))
        ]
    }
    if "num_nodes" in data:
        num_nodes = data.num_nodes

    graph_sizes = torch.diff(data.ptr)
    num_graphs = len(graph_sizes)
    device = data.x.device

    if not updates["edge_type"][0].numel():
        new_edge_type = 1
    else:
        new_edge_type = int(torch.max(updates["edge_type"][0])) + 1
    next_node_i = len(data.x)
    next_batch_i = num_graphs

    for i in range(num_graphs):
        start_node_i, end_node_i = data.ptr[i], data.ptr[i + 1]
        graph_size = graph_sizes[i]

        # New edges
        new_edges_1 = torch.full((graph_size,), next_node_i, device=device)
        new_edges_2 = torch.arange(start_node_i, end_node_i, device=device)
        new_edges = torch.cat([
            torch.stack([new_edges_1, new_edges_2], dim=0),
            torch.stack([new_edges_2, new_edges_1], dim=0),
        ], dim=1)
        updates["edge_index"].append(new_edges)

        new_edge_types_1 = torch.full(
            new_edges_1.shape, new_edge_type, device=device
        )
        new_edge_types_2 = new_edge_types_1 + 1
        new_edge_types = torch.cat([new_edge_types_1, new_edge_types_2], dim=0)
        updates["edge_type"].append(new_edge_types)

        # New virtual node
        new_x = torch.zeros_like(data.x[0])[None]
        updates["x"].append(new_x)
        next_node_i = next_node_i + 1

        # New batch
        new_batch = torch.tensor([next_batch_i], device=device)
        updates["batch"].append(new_batch)
        next_batch_i = next_batch_i + 1

    data.edge_index = torch.cat(updates["edge_index"], dim=1)
    data.edge_type = torch.cat(updates["edge_type"], dim=0)
    data.x = torch.cat(updates["x"], dim=0)
    data.batch = torch.cat(updates["batch"], dim=0)

    if "num_nodes" in data:
        data.num_nodes = num_nodes + num_graphs
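
# A minimal usage sketch (illustrative only, not part of the original file;
# assumes this module's torch/torch_geometric imports): batch two small
# graphs, then add one virtual node per graph.
if __name__ == "__main__":
    data_list = [
        torch_geometric.data.Data(
            x=torch.zeros(3, 2), edge_index=torch.tensor([[0, 1], [1, 0]])
        ),
        torch_geometric.data.Data(
            x=torch.zeros(4, 2), edge_index=torch.tensor([[0, 1], [1, 0]])
        ),
    ]
    batch = torch_geometric.data.Batch.from_data_list(data_list)
    add_virtual_nodes(batch)
    # One virtual node is appended per graph: 3 + 4 + 2 = 9 rows in `x`
    print(batch.x.shape)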

--------------------------------------------------------------------------------
/src/feature/molecule_dataset.py:
--------------------------------------------------------------------------------
import pandas as pd
import rdkit.Chem
import torch
import torch_geometric
import networkx as nx

# Define device
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


ZINC250K_PATH = "/gstore/home/tsenga5/discrete_graph_diffusion/data/250k_rndm_zinc_drugs_clean_3.csv"


def smiles_to_networkx(smiles):
    """
    Converts a SMILES string to a NetworkX graph. The graph will retain the
    atomic number and bond type for nodes and edges (respectively), under the
    keys `atomic_num` and `bond_type` (respectively).
    Arguments:
        `smiles`: a SMILES string
    Returns a NetworkX graph.
    """
    mol = rdkit.Chem.MolFromSmiles(smiles)
    g = nx.Graph()
    for atom in mol.GetAtoms():
        g.add_node(
            atom.GetIdx(),
            atomic_num=atom.GetAtomicNum()
        )
    for bond in mol.GetBonds():
        g.add_edge(
            bond.GetBeginAtomIdx(),
            bond.GetEndAtomIdx(),
            bond_type=bond.GetBondType()
        )
    return g


ATOM_MAP = torch.tensor([6, 7, 8, 9, 16, 17, 35, 53, 15])
BOND_MAP = torch.tensor([1, 2, 3, 12])

def smiles_to_pyg_data(smiles, ignore_edge_attr=False):
    """
    Converts a SMILES string to a torch-geometric Data object. The data object
    will have node attributes and edge attributes under `x` and `edge_attr`,
    respectively.
    Arguments:
        `smiles`: a SMILES string
        `ignore_edge_attr`: if True, no edge attributes will be included
    Returns a torch-geometric Data object.
    """
    g = smiles_to_networkx(smiles)
    data = torch_geometric.utils.from_networkx(g)

    # Set atom features
    atom_inds = torch.argmax(
        (data.atomic_num.view(-1, 1) == ATOM_MAP).int(), dim=1
    )
    data.x = torch.nn.functional.one_hot(atom_inds, num_classes=len(ATOM_MAP))

    if not ignore_edge_attr:
        # Set bond features
        # For aromatic bonds, set them to be both single and double
        aromatic_mask = data.bond_type == BOND_MAP[-1]
        bond_inds = torch.argmax(
            (data.bond_type.view(-1, 1) == BOND_MAP).int(), dim=1
        )
        bond_inds[aromatic_mask] = 0
        data.edge_attr = torch.nn.functional.one_hot(
            bond_inds, num_classes=(len(BOND_MAP) - 1)
        )
        data.edge_attr[aromatic_mask, 1] = 1

    del data.atomic_num
    del data.bond_type

    return data
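
# Illustrative check (added annotation, not in the original file): benzene is
# six aromatic carbons, so (with RDKit installed):
#   data = smiles_to_pyg_data("c1ccccc1")
#   data.x.shape               # (6, 9): one-hot over the 9 atom types in ATOM_MAP
#   data.edge_attr.sum(dim=1)  # aromatic bonds are one-hot as both single and
#                              # double, so each row sums to 2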
104 | """ 105 | data = smiles_to_pyg_data(self.all_smiles[index]) 106 | data.edge_index = torch_geometric.utils.sort_edge_index(data.edge_index) 107 | return data.to(DEVICE) 108 | 109 | def __len__(self): 110 | return len(self.all_smiles) 111 | 112 | 113 | if __name__ == "__main__": 114 | dataset = ZINCDataset(connectivity_only=True) 115 | data_loader = torch_geometric.loader.DataLoader( 116 | dataset, batch_size=32, shuffle=False 117 | ) 118 | batch = next(iter(data_loader)) 119 | -------------------------------------------------------------------------------- /src/feature/random_graph_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric 3 | import numpy as np 4 | import networkx as nx 5 | import scipy.spatial 6 | 7 | # Define device 8 | if torch.cuda.is_available(): 9 | DEVICE = "cuda" 10 | else: 11 | DEVICE = "cpu" 12 | 13 | 14 | def create_tree(node_dim, num_nodes=10, noise_level=1): 15 | """ 16 | Creates a random connected tree. The node attributes will be initialized by 17 | a random vector plus the distance from a randomly selected source node, plus 18 | noise. 19 | Arguments: 20 | `node_dim`: size of node feature vector 21 | `num_nodes`: number of nodes in the graph, or an array to sample from 22 | `noise_level`: standard deviation of Gaussian noise to add to distances 23 | Returns a NetworkX Graph with NumPy arrays as node attributes. 24 | """ 25 | if type(num_nodes) is not int: 26 | num_nodes = np.random.choice(num_nodes) 27 | 28 | g = nx.random_tree(num_nodes) 29 | 30 | node_features = np.empty((num_nodes, node_dim)) 31 | 32 | # Pick a random source node 33 | source = np.random.choice(num_nodes) 34 | 35 | # Set source node's feature to random vector 36 | source_feat = np.random.randn(node_dim) * 2 * np.sqrt(num_nodes) 37 | node_features[source] = source_feat 38 | 39 | # Run BFS starting from source node; for each layer, set features 40 | # to be the source vector plus the distance plus noise 41 | current_layer = [source] 42 | distance = 1 43 | visited = set(current_layer) 44 | while current_layer: 45 | next_layer = [] 46 | for node in current_layer: 47 | for child in g[node]: 48 | if child not in visited: 49 | visited.add(child) 50 | next_layer.append(child) 51 | node_features[child] = source_feat + distance + ( 52 | np.random.randn(node_dim) * noise_level 53 | ) 54 | current_layer = next_layer 55 | distance += 1 56 | 57 | nx.set_node_attributes( 58 | g, {i : node_features[i] for i in range(num_nodes)}, "feats" 59 | ) 60 | return g 61 | 62 | 63 | def create_uniform_cliques( 64 | node_dim, num_nodes=10, clique_size=6, noise_level=1 65 | ): 66 | """ 67 | Creates a random graph of disconnected cliques. The node attributes will be 68 | initialized by a constant vector for each clique, plus some noise. If 69 | `clique_size` does not divide `num_nodes`, there will be a smaller clique. 70 | Arguments: 71 | `node_dim`: size of node feature vector 72 | `num_nodes`: number of nodes in the graph, or an array to sample from 73 | `clique_size`: size of cliques 74 | `noise_level`: standard deviation of Gaussian noise to add to node 75 | features 76 | Returns a NetworkX Graph with NumPy arrays as node attributes. 
 77 |     """
 78 |     if type(num_nodes) is not int:
 79 |         num_nodes = np.random.choice(num_nodes)
 80 | 
 81 |     g = nx.empty_graph()
 82 | 
 83 |     clique_count = 0
 84 |     while g.number_of_nodes() < num_nodes:
 85 |         size = min(clique_size, num_nodes - g.number_of_nodes())
 86 | 
 87 |         # Create clique
 88 |         clique = nx.complete_graph(size)
 89 | 
 90 |         # Create the core feature vector for the clique
 91 |         core = np.ones((size, 1)) * clique_count * \
 92 |             (2 * num_nodes / clique_size)
 93 | 
 94 |         # Add a small bit of noise for each node in the clique
 95 |         node_features = core + (np.random.randn(size, node_dim) * noise_level)
 96 |         nx.set_node_attributes(
 97 |             clique, {i : node_features[i] for i in range(size)}, "feats"
 98 |         )
 99 | 
100 |         # Add the clique to the graph
101 |         g = nx.disjoint_union(g, clique)
102 |         clique_count += 1
103 | 
104 |     return g
105 | 
106 | 
107 | def create_diverse_cliques(
108 |     node_dim, num_nodes=10, clique_sizes=[3, 4, 5], repeat=False, noise_level=1,
109 |     unity_features=False
110 | ):
111 |     """
112 |     Creates a random graph of disconnected cliques. The node attributes will be
113 |     initialized by a constant vector for each clique (which is the size of the
114 |     clique), plus some noise. Leftover nodes will be singleton nodes.
115 |     Arguments:
116 |         `node_dim`: size of node feature vector
117 |         `num_nodes`: number of nodes in the graph, or an array to sample from
118 |         `clique_sizes`: iterable of clique sizes to use
119 |         `repeat`: if False, all clique sizes will be unique
120 |         `noise_level`: standard deviation of Gaussian noise to add to node
121 |             features
122 |         `unity_features`: if True, use 1 for all features instead of clique size
123 |     Returns a NetworkX Graph with NumPy arrays as node attributes.
124 |     """
125 |     if type(num_nodes) is not int:
126 |         num_nodes = np.random.choice(num_nodes)
127 | 
128 |     g = nx.empty_graph()
129 | 
130 |     clique_sizes = np.unique(clique_sizes)
131 | 
132 |     sizes_to_make, size_left = [], num_nodes
133 |     if repeat:
134 |         while clique_sizes.size:
135 |             size = np.random.choice(clique_sizes)
136 |             sizes_to_make.append(size)
137 |             size_left -= size
138 |             clique_sizes = clique_sizes[clique_sizes <= size_left]
139 |     else:
140 |         clique_sizes = np.random.permutation(clique_sizes)
141 |         for size in clique_sizes:
142 |             if size <= size_left:
143 |                 sizes_to_make.append(size)
144 |                 size_left -= size
145 |     sizes_to_make.extend([1] * size_left)
146 | 
147 |     for size in sizes_to_make:
148 |         # Create clique
149 |         clique = nx.complete_graph(size)
150 | 
151 |         # Create the core feature vector for the clique
152 |         core = np.ones((size, 1)) * (1 if unity_features else size)
153 | 
154 |         # Add a small bit of noise for each node in the clique
155 |         node_features = core + (np.random.randn(size, node_dim) * noise_level)
156 |         nx.set_node_attributes(
157 |             clique, {i : node_features[i] for i in range(size)}, "feats"
158 |         )
159 | 
160 |         # Add the clique to the graph
161 |         g = nx.disjoint_union(g, clique)
162 | 
163 |     return g
164 | 
165 | 
166 | def create_degree_graph(
167 |     node_dim, num_nodes=10, edge_prob=0.2, noise_level=1
168 | ):
169 |     """
170 |     Creates a random Erdos-Renyi graph where the node attributes are a constant
171 |     vector of the degree of the node, plus some noise.
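    (For example, with noise_level=0, a node of degree 3 gets the constant
    feature vector [3, 3, ..., 3].)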
172 | Arguments: 173 | `node_dim`: size of node feature vector 174 | `num_nodes`: number of nodes in the graph, or an array to sample from 175 | `edge_prob`: probability of edges in Erdos-Renyi graph 176 | `noise_level`: standard deviation of Gaussian noise to add to node 177 | features 178 | Returns a NetworkX Graph with NumPy arrays as node attributes. 179 | """ 180 | if type(num_nodes) is not int: 181 | num_nodes = np.random.choice(num_nodes) 182 | 183 | g = nx.erdos_renyi_graph(num_nodes, edge_prob) 184 | 185 | degrees = dict(g.degree()) 186 | node_features = np.tile( 187 | np.array([degrees[n] for n in range(len(g))])[:, None], 188 | (1, node_dim) 189 | ) 190 | node_features = node_features + \ 191 | (np.random.randn(num_nodes, node_dim) * noise_level) 192 | 193 | nx.set_node_attributes( 194 | g, {i : node_features[i] for i in range(num_nodes)}, "feats" 195 | ) 196 | 197 | return g 198 | 199 | 200 | def create_planar_graph(node_dim, num_nodes=64): 201 | """ 202 | Creates a planar graph using the Delaunay triangulation algorithm. 203 | All nodes will be given a feature vector of all 1s. 204 | Arguments: 205 | `node_dim`: size of node feature vector 206 | `num_nodes`: number of nodes in the graph, or an array to sample from 207 | Returns a NetworkX Graph with NumPy arrays as node attributes. 208 | """ 209 | if type(num_nodes) is not int: 210 | num_nodes = np.random.choice(num_nodes) 211 | 212 | # Sample points uniformly at random from unit square 213 | points = np.random.rand(num_nodes, 2) 214 | 215 | # Perform Delaunay triangulation 216 | tri = scipy.spatial.Delaunay(points) 217 | 218 | # Create graph and add edges from triangulation result 219 | g = nx.empty_graph(num_nodes) 220 | indptr, indices = tri.vertex_neighbor_vertices 221 | for i in range(num_nodes): 222 | for j in indices[indptr[i]:indptr[i + 1]]: 223 | g.add_edge(i, j) 224 | 225 | nx.set_node_attributes( 226 | g, {i : np.ones(node_dim) for i in range(num_nodes)}, "feats" 227 | ) 228 | 229 | return g 230 | 231 | 232 | def create_community_graph( 233 | node_dim, num_nodes=np.arange(12, 21), num_comms=2, 234 | intra_comm_edge_prob=0.3, inter_comm_edge_frac=0.05 235 | ): 236 | """ 237 | Creates a community graph following this paper: 238 | https://arxiv.org/abs/1802.08773 239 | The default values give the definition of a "community-small" graph in the 240 | above paper. Each community is a Erdos-Renyi graph, with a certain set 241 | number of edges connecting the communities sparsely (drawn uniformly). 242 | All nodes will be given a feature vector of all 1s. 243 | Arguments: 244 | `node_dim`: size of node feature vector 245 | `num_nodes`: number of nodes in the graph, or an array to sample from 246 | `num_comms`: number of communities to create 247 | `intra_comm_edge_prob`: probability of edge in Erdos-Renyi graph for 248 | each community 249 | `inter_comm_edge_frac`: number of edges to draw between each pair of 250 | communities, as a fraction of `num_nodes`; edges are drawn uniformly 251 | at random between communities 252 | Returns a NetworkX Graph with NumPy arrays as node attributes. 
253 | """ 254 | if type(num_nodes) is not int: 255 | num_nodes = np.random.choice(num_nodes) 256 | 257 | # Create communities 258 | exp_size = int(num_nodes / num_comms) 259 | comm_sizes = [] 260 | total_size = 0 261 | g = nx.empty_graph() 262 | while total_size < num_nodes: 263 | size = min(exp_size, num_nodes - total_size) 264 | g = nx.disjoint_union( 265 | g, nx.erdos_renyi_graph(size, intra_comm_edge_prob) 266 | ) 267 | comm_sizes.append(size) 268 | total_size += size 269 | 270 | # Link together communities 271 | node_inds = np.cumsum(comm_sizes) 272 | num_inter_edges = int(num_nodes * inter_comm_edge_frac) 273 | for i in range(num_comms): 274 | for j in range(i): 275 | i_nodes = np.arange(node_inds[i - 1] if i else 0, node_inds[i]) 276 | j_nodes = np.arange(node_inds[j - 1] if j else 0, node_inds[j]) 277 | for _ in range(num_inter_edges): 278 | g.add_edge( 279 | np.random.choice(i_nodes), np.random.choice(j_nodes) 280 | ) 281 | 282 | nx.set_node_attributes( 283 | g, {i : np.ones(node_dim) for i in range(num_nodes)}, "feats" 284 | ) 285 | 286 | return g 287 | 288 | 289 | def create_sbm_graph( 290 | node_dim, num_blocks_arr=np.arange(2, 6), block_size_arr=np.arange(20, 41), 291 | intra_block_edge_prob=0.3, inter_block_edge_prob=0.05 292 | ): 293 | """ 294 | Creates a stochastic-block-model graph, where the number of blocks and size 295 | of blocks is drawn randomly. 296 | All nodes will be given a feature vector of all 1s. 297 | Arguments: 298 | `node_dim`: size of node feature vector 299 | `num_blocks_arr`: iterable containing possible numbers of blocks to have 300 | (selected uniformly) 301 | `block_size_arr`: iterable containing possible block sizes for each 302 | block (selected uniformly per block) 303 | `intra_block_edge_prob`: probability of edge within blocks 304 | `inter_block_edge_prob`: probability of edge between blocks 305 | Returns a NetworkX Graph with NumPy arrays as node attributes. 306 | """ 307 | num_blocks = np.random.choice(num_blocks_arr) 308 | block_sizes = np.random.choice(block_size_arr, num_blocks, replace=True) 309 | num_nodes = np.sum(block_sizes) 310 | 311 | # Create matrix of edge probabilities between blocks 312 | p = np.full((len(block_sizes), len(block_sizes)), inter_block_edge_prob) 313 | np.fill_diagonal(p, intra_block_edge_prob) 314 | 315 | # Create SBM graph 316 | g = nx.stochastic_block_model(block_sizes, p) 317 | 318 | nx.set_node_attributes( 319 | g, {i : np.ones(node_dim) for i in range(num_nodes)}, "feats" 320 | ) 321 | 322 | # Delete these two attributes, or else conversion to PyTorch Geometric Data 323 | # object will fail 324 | del g.graph["partition"] 325 | del g.graph["name"] 326 | 327 | return g 328 | 329 | 330 | def create_molecule_like_graph( 331 | node_dim, num_nodes=10, backbone_size_range=(4, 6), 332 | ornament_size_range=(0, 4), ring_prob=0.5 333 | ): 334 | """ 335 | Creates a random graph that looks a bit like a molecule. There are two types 336 | of nodes: backbone nodes (with node features of 0s) and ornamental nodes 337 | (with node features of 1s). Leftover nodes will be either backbone or 338 | ornamental nodes (selected uniformly at random). The backbone may be cyclic. 339 | Each backbone node can be connected to at most 4 other nodes (other backbone 340 | nodes or ornamental nodes). 
341 |     Arguments:
342 |         `node_dim`: size of node feature vector
343 |         `num_nodes`: number of nodes in the graph, or an array to sample from
344 |         `backbone_size_range`: pair of minimum and maximum sizes of backbones to
345 |             choose from
346 |         `ornament_size_range`: pair of minimum and maximum number of ornaments
347 |             to choose from; the actual number of ornaments added may be smaller
348 |             if the backbone selected is too small
349 |         `ring_prob`: probability of generating a ring for the backbone
350 |     Returns a NetworkX Graph with NumPy arrays as node attributes.
351 |     """
352 |     if type(num_nodes) is not int:
353 |         assert np.max(num_nodes) >= \
354 |             backbone_size_range[1] + ornament_size_range[1]
355 |         num_nodes = np.random.choice(num_nodes)
356 |     else:
357 |         assert num_nodes >= backbone_size_range[1] + ornament_size_range[1]
358 | 
359 |     backbone_size = np.random.randint(
360 |         backbone_size_range[0], backbone_size_range[1] + 1
361 |     )
362 |     ornament_size = np.random.randint(
363 |         ornament_size_range[0],
364 |         min(ornament_size_range[1], (backbone_size * 2) + 2) + 1
365 |     )
366 |     ring = np.random.random() <= ring_prob
367 | 
368 |     # Create backbone
369 |     if ring:
370 |         g = nx.circulant_graph(backbone_size, [1])
371 |     else:
372 |         g = nx.empty_graph()
373 |         g.add_node(0)
374 |         for _ in range(backbone_size - 1):
375 |             # Add leaf node to randomly selected node with degree < 4
376 |             degrees = g.degree()
377 |             connector = np.random.choice(
378 |                 [i for i in range(len(g)) if degrees[i] < 4]
379 |             )
380 |             new_node = len(g)
381 |             g.add_node(new_node)
382 |             g.add_edge(new_node, connector)
383 | 
384 |     # Add ornaments
385 |     capacities = [(n, 4 - m) for n, m in g.degree() if m < 4]
386 |     inds, counts = zip(*capacities)
387 |     node_arr = np.repeat(inds, counts)
388 |     connectors = np.random.choice(node_arr, size=ornament_size, replace=False)
389 |     for connector in connectors:
390 |         new_node = len(g)
391 |         g.add_node(new_node)
392 |         g.add_edge(new_node, connector)
393 | 
394 |     # Add extra singleton nodes
395 |     extra_nodes = num_nodes - len(g)
396 |     extra_backbone_size = extra_nodes // 2
397 |     extra_ornament_size = extra_nodes - extra_backbone_size
398 |     for _ in range(extra_backbone_size):
399 |         g.add_node(len(g))
400 |     for _ in range(extra_ornament_size):
401 |         g.add_node(len(g))
402 | 
403 |     node_features = np.concatenate([
404 |         np.zeros((backbone_size, node_dim)),
405 |         np.ones((ornament_size, node_dim)),
406 |         np.zeros((extra_backbone_size, node_dim)),
407 |         np.ones((extra_ornament_size, node_dim))
408 |     ], axis=0)
409 | 
410 |     nx.set_node_attributes(
411 |         g, {i : node_features[i] for i in range(len(g))}, "feats"
412 |     )
413 |     return g
414 | 
415 | 
416 | class RandomGraphDataset(torch.utils.data.Dataset):
417 |     def __init__(
418 |         self, node_dim, graph_type="tree", num_items=1000, static=False,
419 |         **kwargs
420 |     ):
421 |         """
422 |         Create a PyTorch Dataset which yields random graphs.
423 |         Arguments:
424 |             `node_dim`: size of node feature vector
425 |             `graph_type`: type of graph to generate; can be "tree",
426 |                 "uniform_cliques", "diverse_cliques", "degree", "planar",
427 |                 "community", "sbm", or "molecule_like"
428 |             `num_items`: number of items to yield in an epoch
429 |             `static`: if True, generate `num_items` graphs initially and only
430 |                 yield pregenerated graphs
431 |             `kwargs`: extra keyword arguments for the graph generator
432 |         """
433 |         super().__init__()
434 | 
435 |         self.node_dim = node_dim
436 |         self.num_items = num_items
437 |         self.static = static
438 |         self.kwargs = kwargs
439 | 
440 |         if graph_type == "tree":
441 |             self.graph_creater = create_tree
442 |         elif graph_type == "uniform_cliques":
443 |             self.graph_creater = create_uniform_cliques
444 |         elif graph_type == "diverse_cliques":
445 |             self.graph_creater = create_diverse_cliques
446 |         elif graph_type == "degree":
447 |             self.graph_creater = create_degree_graph
448 |         elif graph_type == "planar":
449 |             self.graph_creater = create_planar_graph
450 |         elif graph_type == "community":
451 |             self.graph_creater = create_community_graph
452 |         elif graph_type == "sbm":
453 |             self.graph_creater = create_sbm_graph
454 |         elif graph_type == "molecule_like":
455 |             self.graph_creater = create_molecule_like_graph
456 |         else:
457 |             raise ValueError("Unrecognized random graph type: %s" % graph_type)
458 | 
459 |         if static:
460 |             self.graph_cache = [
461 |                 self.graph_creater(node_dim, **kwargs) for _ in range(num_items)
462 |             ]
463 | 
464 | 
465 |     def __getitem__(self, index):
466 |         """
467 |         Returns a single data point as a torch-geometric Data object: a fresh
468 |         random graph if not static, else the cached graph at index `index`.
469 |         """
470 |         if self.static:
471 |             graph = self.graph_cache[index]
472 |         else:
473 |             graph = self.graph_creater(self.node_dim, **self.kwargs)
474 |         data = torch_geometric.utils.from_networkx(
475 |             graph, group_node_attrs=["feats"]
476 |         )
477 |         data.edge_index = torch_geometric.utils.sort_edge_index(data.edge_index)
478 |         return data.to(DEVICE)
479 | 
480 |     def __len__(self):
481 |         return self.num_items
482 | 
483 | 
484 | if __name__ == "__main__":
485 |     node_dim = 5
486 | 
487 |     dataset = RandomGraphDataset(
488 |         node_dim, num_items=100,
489 |         graph_type="diverse_cliques", num_nodes=np.arange(10, 20),
490 |         clique_sizes=[3, 4, 5, 6], noise_level=0
491 |     )
492 |     data_loader = torch_geometric.loader.DataLoader(
493 |         dataset, batch_size=32, shuffle=False
494 |     )
495 |     batch = next(iter(data_loader))
496 | 
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Genentech/GraphGUIDE/dad0dd371268684a5203839441febe0484d8a3e4/src/model/__init__.py
--------------------------------------------------------------------------------
/src/model/digress_gnn.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import torch
  3 | import torch.nn as nn
  4 | from torch.nn.modules.dropout import Dropout
  5 | from torch.nn.modules.linear import Linear
  6 | from torch.nn.modules.normalization import LayerNorm
  7 | from torch.nn import functional as F
  8 | from torch import Tensor
  9 | import torch_geometric
 10 | 
 11 | def assert_correctly_masked(variable, node_mask):
 12 |     assert (variable * (1 - node_mask.long())).abs().max().item() < 1e-4, \
 13 |         'Variables not masked properly.'
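# Shape conventions in this file: node features X are (bs, n, dx), edge
# features E are (bs, n, n, de), and node_mask is (bs, n); masked (padding)
# entries must be exactly zero. Illustrative usage of the check above:
#     node_mask = torch.tensor([[1., 1., 1., 0., 0.]])
#     X = torch.randn(1, 5, 8) * node_mask.unsqueeze(-1)
#     assert_correctly_masked(X, node_mask.unsqueeze(-1))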
14 | 15 | class PlaceHolder: 16 | def __init__(self, X, E, y): 17 | self.X = X 18 | self.E = E 19 | self.y = y 20 | 21 | def type_as(self, x: torch.Tensor): 22 | """ Changes the device and dtype of X, E, y. """ 23 | self.X = self.X.type_as(x) 24 | self.E = self.E.type_as(x) 25 | self.y = self.y.type_as(x) 26 | return self 27 | 28 | def mask(self, node_mask, collapse=False): 29 | x_mask = node_mask.unsqueeze(-1) # bs, n, 1 30 | e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1 31 | e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1 32 | 33 | if collapse: 34 | self.X = torch.argmax(self.X, dim=-1) 35 | self.E = torch.argmax(self.E, dim=-1) 36 | 37 | self.X[node_mask == 0] = - 1 38 | self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1 39 | else: 40 | self.X = self.X * x_mask 41 | self.E = self.E * e_mask1 * e_mask2 42 | assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) 43 | return self 44 | 45 | 46 | def encode_no_edge(E): 47 | assert len(E.shape) == 4 48 | if E.shape[-1] == 0: 49 | return E 50 | no_edge = torch.sum(E, dim=3) == 0 51 | first_elt = E[:, :, :, 0] 52 | first_elt[no_edge] = 1 53 | E[:, :, :, 0] = first_elt 54 | diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1) 55 | E[diag] = 0 56 | return E 57 | 58 | def to_dense(x, edge_index, edge_attr, batch): 59 | X, node_mask = torch_geometric.utils.to_dense_batch(x=x, batch=batch) 60 | # node_mask = node_mask.float() 61 | edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr) 62 | # TODO: carefully check if setting node_mask as a bool breaks the continuous case 63 | max_num_nodes = X.size(1) 64 | if edge_index.numel() == 0: 65 | # We have to do this check otherwise things fail 66 | E = torch_geometric.utils.to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr) 67 | else: 68 | E = torch_geometric.utils.to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes) 69 | E = encode_no_edge(E) 70 | 71 | return PlaceHolder(X=X, E=E, y=None), node_mask 72 | 73 | 74 | class Xtoy(nn.Module): 75 | def __init__(self, dx, dy): 76 | """ Map node features to global features """ 77 | super().__init__() 78 | self.lin = nn.Linear(4 * dx, dy) 79 | 80 | def forward(self, X): 81 | """ X: bs, n, dx. """ 82 | m = X.mean(dim=1) 83 | mi = X.min(dim=1)[0] 84 | ma = X.max(dim=1)[0] 85 | std = X.std(dim=1) 86 | z = torch.hstack((m, mi, ma, std)) 87 | out = self.lin(z) 88 | return out 89 | 90 | 91 | class Etoy(nn.Module): 92 | def __init__(self, d, dy): 93 | """ Map edge features to global features. """ 94 | super().__init__() 95 | self.lin = nn.Linear(4 * d, dy) 96 | 97 | def forward(self, E): 98 | """ E: bs, n, n, de 99 | Features relative to the diagonal of E could potentially be added. 100 | """ 101 | m = E.mean(dim=(1, 2)) 102 | mi = E.min(dim=2)[0].min(dim=1)[0] 103 | ma = E.max(dim=2)[0].max(dim=1)[0] 104 | std = torch.std(E, dim=(1, 2)) 105 | z = torch.hstack((m, mi, ma, std)) 106 | out = self.lin(z) 107 | return out 108 | 109 | class XEyTransformerLayer(nn.Module): 110 | """ Transformer that updates node, edge and global features 111 | d_x: node features 112 | d_e: edge features 113 | dz : global features 114 | n_head: the number of heads in the multi_head_attention 115 | dim_feedforward: the dimension of the feedforward network model after self-attention 116 | dropout: dropout probablility. 0 to disable 117 | layer_norm_eps: eps value in layer normalizations. 
118 | """ 119 | def __init__(self, dx: int, de: int, dy: int, n_head: int, dim_ffX: int = 2048, 120 | dim_ffE: int = 128, dim_ffy: int = 2048, dropout: float = 0.1, 121 | layer_norm_eps: float = 1e-5, device=None, dtype=None) -> None: 122 | kw = {'device': device, 'dtype': dtype} 123 | super().__init__() 124 | 125 | self.self_attn = NodeEdgeBlock(dx, de, dy, n_head, **kw) 126 | 127 | self.linX1 = Linear(dx, dim_ffX, **kw) 128 | self.linX2 = Linear(dim_ffX, dx, **kw) 129 | self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **kw) 130 | self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **kw) 131 | self.dropoutX1 = Dropout(dropout) 132 | self.dropoutX2 = Dropout(dropout) 133 | self.dropoutX3 = Dropout(dropout) 134 | 135 | self.linE1 = Linear(de, dim_ffE, **kw) 136 | self.linE2 = Linear(dim_ffE, de, **kw) 137 | self.normE1 = LayerNorm(de, eps=layer_norm_eps, **kw) 138 | self.normE2 = LayerNorm(de, eps=layer_norm_eps, **kw) 139 | self.dropoutE1 = Dropout(dropout) 140 | self.dropoutE2 = Dropout(dropout) 141 | self.dropoutE3 = Dropout(dropout) 142 | 143 | self.lin_y1 = Linear(dy, dim_ffy, **kw) 144 | self.lin_y2 = Linear(dim_ffy, dy, **kw) 145 | self.norm_y1 = LayerNorm(dy, eps=layer_norm_eps, **kw) 146 | self.norm_y2 = LayerNorm(dy, eps=layer_norm_eps, **kw) 147 | self.dropout_y1 = Dropout(dropout) 148 | self.dropout_y2 = Dropout(dropout) 149 | self.dropout_y3 = Dropout(dropout) 150 | 151 | self.activation = F.relu 152 | 153 | def forward(self, X: Tensor, E: Tensor, y, node_mask: Tensor): 154 | """ Pass the input through the encoder layer. 155 | X: (bs, n, d) 156 | E: (bs, n, n, d) 157 | y: (bs, dy) 158 | node_mask: (bs, n) Mask for the src keys per batch (optional) 159 | Output: newX, newE, new_y with the same shape. 160 | """ 161 | 162 | newX, newE, new_y = self.self_attn(X, E, y, node_mask=node_mask) 163 | 164 | newX_d = self.dropoutX1(newX) 165 | X = self.normX1(X + newX_d) 166 | 167 | newE_d = self.dropoutE1(newE) 168 | E = self.normE1(E + newE_d) 169 | 170 | new_y_d = self.dropout_y1(new_y) 171 | y = self.norm_y1(y + new_y_d) 172 | 173 | ff_outputX = self.linX2(self.dropoutX2(self.activation(self.linX1(X)))) 174 | ff_outputX = self.dropoutX3(ff_outputX) 175 | X = self.normX2(X + ff_outputX) 176 | 177 | ff_outputE = self.linE2(self.dropoutE2(self.activation(self.linE1(E)))) 178 | ff_outputE = self.dropoutE3(ff_outputE) 179 | E = self.normE2(E + ff_outputE) 180 | 181 | ff_output_y = self.lin_y2(self.dropout_y2(self.activation(self.lin_y1(y)))) 182 | ff_output_y = self.dropout_y3(ff_output_y) 183 | y = self.norm_y2(y + ff_output_y) 184 | 185 | return X, E, y 186 | 187 | 188 | class NodeEdgeBlock(nn.Module): 189 | """ Self attention layer that also updates the representations on the edges. 
""" 190 | def __init__(self, dx, de, dy, n_head, **kwargs): 191 | super().__init__() 192 | assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}" 193 | self.dx = dx 194 | self.de = de 195 | self.dy = dy 196 | self.df = int(dx / n_head) 197 | self.n_head = n_head 198 | 199 | # Attention 200 | self.q = Linear(dx, dx) 201 | self.k = Linear(dx, dx) 202 | self.v = Linear(dx, dx) 203 | 204 | # FiLM E to X 205 | self.e_add = Linear(de, dx) 206 | self.e_mul = Linear(de, dx) 207 | 208 | # FiLM y to E 209 | self.y_e_mul = Linear(dy, dx) # Warning: here it's dx and not de 210 | self.y_e_add = Linear(dy, dx) 211 | 212 | # FiLM y to X 213 | self.y_x_mul = Linear(dy, dx) 214 | self.y_x_add = Linear(dy, dx) 215 | 216 | # Process y 217 | self.y_y = Linear(dy, dy) 218 | self.x_y = Xtoy(dx, dy) 219 | self.e_y = Etoy(de, dy) 220 | 221 | # Output layers 222 | self.x_out = Linear(dx, dx) 223 | self.e_out = Linear(dx, de) 224 | self.y_out = nn.Sequential(nn.Linear(dy, dy), nn.ReLU(), nn.Linear(dy, dy)) 225 | 226 | def forward(self, X, E, y, node_mask): 227 | """ 228 | :param X: bs, n, d node features 229 | :param E: bs, n, n, d edge features 230 | :param y: bs, dz global features 231 | :param node_mask: bs, n 232 | :return: newX, newE, new_y with the same shape. 233 | """ 234 | x_mask = node_mask.unsqueeze(-1) # bs, n, 1 235 | e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1 236 | e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1 237 | 238 | # 1. Map X to keys and queries 239 | Q = self.q(X) * x_mask # (bs, n, dx) 240 | K = self.k(X) * x_mask # (bs, n, dx) 241 | assert_correctly_masked(Q, x_mask) 242 | # 2. Reshape to (bs, n, n_head, df) with dx = n_head * df 243 | 244 | Q = Q.reshape((Q.size(0), Q.size(1), self.n_head, self.df)) 245 | K = K.reshape((K.size(0), K.size(1), self.n_head, self.df)) 246 | 247 | Q = Q.unsqueeze(2) # (bs, 1, n, n_head, df) 248 | K = K.unsqueeze(1) # (bs, n, 1, n head, df) 249 | 250 | # Compute unnormalized attentions. Y is (bs, n, n, n_head, df) 251 | Y = Q * K 252 | Y = Y / math.sqrt(Y.size(-1)) 253 | assert_correctly_masked(Y, (e_mask1 * e_mask2).unsqueeze(-1)) 254 | 255 | E1 = self.e_mul(E) * e_mask1 * e_mask2 # bs, n, n, dx 256 | E1 = E1.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df)) 257 | 258 | E2 = self.e_add(E) * e_mask1 * e_mask2 # bs, n, n, dx 259 | E2 = E2.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df)) 260 | 261 | # Incorporate edge features to the self attention scores. 262 | Y = Y * (E1 + 1) + E2 # (bs, n, n, n_head, df) 263 | 264 | # Incorporate y to E 265 | newE = Y.flatten(start_dim=3) # bs, n, n, dx 266 | ye1 = self.y_e_add(y).unsqueeze(1).unsqueeze(1) # bs, 1, 1, de 267 | ye2 = self.y_e_mul(y).unsqueeze(1).unsqueeze(1) 268 | newE = ye1 + (ye2 + 1) * newE 269 | 270 | # Output E 271 | newE = self.e_out(newE) * e_mask1 * e_mask2 # bs, n, n, de 272 | assert_correctly_masked(newE, e_mask1 * e_mask2) 273 | 274 | # Compute attentions. 
attn is still (bs, n, n, n_head, df)
275 |         attn = F.softmax(Y, dim=2)
276 | 
277 |         V = self.v(X) * x_mask                  # bs, n, dx
278 |         V = V.reshape((V.size(0), V.size(1), self.n_head, self.df))
279 |         V = V.unsqueeze(1)                      # (bs, 1, n, n_head, df)
280 | 
281 |         # Compute weighted values
282 |         weighted_V = attn * V
283 |         weighted_V = weighted_V.sum(dim=2)
284 | 
285 |         # Send output to input dim
286 |         weighted_V = weighted_V.flatten(start_dim=2)            # bs, n, dx
287 | 
288 |         # Incorporate y to X
289 |         yx1 = self.y_x_add(y).unsqueeze(1)
290 |         yx2 = self.y_x_mul(y).unsqueeze(1)
291 |         newX = yx1 + (yx2 + 1) * weighted_V
292 | 
293 |         # Output X
294 |         newX = self.x_out(newX) * x_mask
295 |         assert_correctly_masked(newX, x_mask)
296 | 
297 |         # Process y based on X and E
298 |         y = self.y_y(y)
299 |         e_y = self.e_y(E)
300 |         x_y = self.x_y(X)
301 |         new_y = y + x_y + e_y
302 |         new_y = self.y_out(new_y)               # bs, dy
303 | 
304 |         return newX, newE, new_y
305 | 
306 | 
307 | class GraphTransformer(nn.Module):
308 |     """
309 |     n_layers : int -- number of layers
310 |     dims : dict -- contains dimensions for each feature type
311 |     """
312 |     def __init__(self, n_layers: int, input_dims: dict, hidden_mlp_dims: dict, hidden_dims: dict,
313 |                  output_dims: dict, act_fn_in: nn.Module = nn.ReLU(), act_fn_out: nn.Module = nn.ReLU()):
314 |         super().__init__()
315 |         self.n_layers = n_layers
316 |         self.out_dim_X = output_dims['X']
317 |         self.out_dim_E = output_dims['E']
318 |         self.out_dim_y = output_dims['y']
319 | 
320 |         self.mlp_in_X = nn.Sequential(nn.Linear(input_dims['X'], hidden_mlp_dims['X']), act_fn_in,
321 |                                       nn.Linear(hidden_mlp_dims['X'], hidden_dims['dx']), act_fn_in)
322 | 
323 |         self.mlp_in_E = nn.Sequential(nn.Linear(input_dims['E'], hidden_mlp_dims['E']), act_fn_in,
324 |                                       nn.Linear(hidden_mlp_dims['E'], hidden_dims['de']), act_fn_in)
325 | 
326 |         self.mlp_in_y = nn.Sequential(nn.Linear(input_dims['y'], hidden_mlp_dims['y']), act_fn_in,
327 |                                       nn.Linear(hidden_mlp_dims['y'], hidden_dims['dy']), act_fn_in)
328 | 
329 |         self.tf_layers = nn.ModuleList([XEyTransformerLayer(dx=hidden_dims['dx'],
330 |                                                             de=hidden_dims['de'],
331 |                                                             dy=hidden_dims['dy'],
332 |                                                             n_head=hidden_dims['n_head'],
333 |                                                             dim_ffX=hidden_dims['dim_ffX'],
334 |                                                             dim_ffE=hidden_dims['dim_ffE'])
335 |                                         for i in range(n_layers)])
336 | 
337 |         self.mlp_out_X = nn.Sequential(nn.Linear(hidden_dims['dx'], hidden_mlp_dims['X']), act_fn_out,
338 |                                        nn.Linear(hidden_mlp_dims['X'], output_dims['X']))
339 | 
340 |         # Note: we change the activation function here to sigmoid!
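        # (The X and y output heads use act_fn_out between their linear
        # layers; only this edge head swaps in a sigmoid.)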
341 | self.mlp_out_E = nn.Sequential(nn.Linear(hidden_dims['de'], hidden_mlp_dims['E']), nn.Sigmoid(), 342 | nn.Linear(hidden_mlp_dims['E'], output_dims['E'])) 343 | 344 | self.mlp_out_y = nn.Sequential(nn.Linear(hidden_dims['dy'], hidden_mlp_dims['y']), act_fn_out, 345 | nn.Linear(hidden_mlp_dims['y'], output_dims['y'])) 346 | 347 | def forward(self, X, E, y, node_mask): 348 | bs, n = X.shape[0], X.shape[1] 349 | 350 | diag_mask = torch.eye(n) 351 | diag_mask = ~diag_mask.type_as(E).bool() 352 | diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1) 353 | 354 | X_to_out = X[..., :self.out_dim_X] 355 | E_to_out = E[..., :self.out_dim_E] 356 | y_to_out = y[..., :self.out_dim_y] 357 | 358 | new_E = self.mlp_in_E(E) 359 | new_E = (new_E + new_E.transpose(1, 2)) / 2 360 | after_in = PlaceHolder(X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y)).mask(node_mask) 361 | X, E, y = after_in.X, after_in.E, after_in.y 362 | 363 | for layer in self.tf_layers: 364 | X, E, y = layer(X, E, y, node_mask) 365 | 366 | X = self.mlp_out_X(X) 367 | E = self.mlp_out_E(E) 368 | y = self.mlp_out_y(y) 369 | 370 | X = (X + X_to_out) 371 | E = (E + E_to_out) * diag_mask 372 | y = y + y_to_out 373 | 374 | E = 1/2 * (E + torch.transpose(E, 1, 2)) 375 | 376 | return PlaceHolder(X=X, E=E, y=y).mask(node_mask) 377 | 378 | 379 | class DiGressGNN(nn.Module): 380 | 381 | def __init__(self, input_dim, t_limit): 382 | super().__init__() 383 | self.creation_args = {} 384 | self.t_limit = t_limit 385 | self.model = GraphTransformer( 386 | 5, 387 | {'X': input_dim, 'E': 1, 'y': 1}, 388 | {'X': 256, 'E': 128, 'y': 128}, 389 | {'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}, 390 | {'X': input_dim, 'E': 1, 'y': 1}, 391 | nn.ReLU(), nn.ReLU() 392 | ) 393 | self.sigmoid = torch.nn.Sigmoid() 394 | self.bce_loss = torch.nn.BCELoss() 395 | 396 | def forward(self, data, t): 397 | # Add extra attribute 398 | data.edge_attr = torch.ones(data.edge_index.shape[1], 1, device=data.x.device) 399 | 400 | # Convert to proper format 401 | dense_data, node_mask = to_dense(data.x, data.edge_index, data.edge_attr, data.batch) 402 | dense_data = dense_data.mask(node_mask) 403 | X, E = dense_data.X.float(), dense_data.E.float() 404 | 405 | # Encode time as y 406 | y = (t[data.ptr[:-1]] / self.t_limit)[:, None] 407 | 408 | # Extract out edge predictions 409 | pred_object = self.model(X, E, y, node_mask) 410 | edge_preds = pred_object.E[:, :, :, 0] # Shape: B x V_max x V_max 411 | 412 | # Convert edge predictions into canonical tensor; do it the same way as 413 | # pyg_data_to_edge_vector 414 | graph_sizes = torch.diff(data.ptr) 415 | 416 | if torch.max(data.batch) >= len(graph_sizes): 417 | edge_preds = edge_preds[:len(graph_sizes)] 418 | 419 | # Create boolean mask of only the top upper triangle of each B x V x V 420 | # matrix, ignoring the diagonal 421 | triu_mask = torch.triu(torch.ones_like(edge_preds), diagonal=1) == 1 422 | for i, size in enumerate(torch.diff(data.ptr)): 423 | # For each individual graph, limit the upper-triangle mask to 424 | # only the size of that graph 425 | triu_mask[i, :, size:] = False 426 | 427 | # Edge vector is all entries in the graph-specific upper triangle of each 428 | # individual adjacency matrix 429 | edge_vec = edge_preds[triu_mask] 430 | 431 | # Finally, pass through a sigmoid 432 | return self.sigmoid(edge_vec) 433 | 434 | def loss(self, pred_probs, true_probs): 435 | return self.bce_loss(pred_probs, true_probs) 436 | 437 | if __name__ == "__main__": 
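    # Smoke test: build a batch of random Erdos-Renyi graphs in
    # torch-geometric's flattened format and run one forward pass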
438 | import networkx as nx 439 | 440 | if torch.cuda.is_available(): 441 | DEVICE = "cuda" 442 | else: 443 | DEVICE = "cpu" 444 | 445 | batch_size = 32 446 | num_nodes = 10 447 | node_dim = 5 448 | t_limit = 1 449 | 450 | # Prepare a data and time object just as the model would see 451 | edge_index = [] 452 | for i in range(batch_size): 453 | edges = torch.tensor(list(nx.erdos_renyi_graph(num_nodes, 0.5).edges)) + (i * num_nodes) 454 | edge_index.append(edges) 455 | edge_index.append(torch.flip(edges, (1,))) 456 | edge_index = torch.concat(edge_index, dim=0) 457 | batch = torch.repeat_interleave(torch.arange(batch_size), num_nodes, 0) 458 | ptr = torch.concat([ 459 | torch.zeros(1, device=batch.device), 460 | torch.where(torch.diff(batch) != 0)[0] + 1, 461 | torch.tensor([len(batch)], device=batch.device) 462 | ]).long() 463 | 464 | data = torch_geometric.data.Data( 465 | x=torch.rand((batch_size * num_nodes, node_dim)), 466 | edge_index=edge_index, 467 | batch=batch, 468 | ptr=ptr 469 | ).to(DEVICE) 470 | t = torch.rand(len(data.x)).to(DEVICE) 471 | 472 | # Run the model 473 | model = DiGressGNN(node_dim, t_limit).to(DEVICE) 474 | edge_pred = model(data, t) 475 | -------------------------------------------------------------------------------- /src/model/discrete_diffusers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Define device 4 | if torch.cuda.is_available(): 5 | DEVICE = "cuda" 6 | else: 7 | DEVICE = "cpu" 8 | 9 | 10 | class DiscreteDiffuser: 11 | # Base class for discrete diffusers 12 | def __init__(self, input_shape, seed=None): 13 | """ 14 | Arguments: 15 | `input_shape`: a tuple of ints which is the shape of input tensors 16 | x; does not include batch dimension 17 | `seed`: random seed for sampling and running the diffusion process 18 | """ 19 | self.input_shape = input_shape 20 | self.rng = torch.Generator(device=DEVICE) 21 | if seed: 22 | self.rng.manual_seed(seed) 23 | 24 | def _inflate_dims(self, v): 25 | """ 26 | Given a tensor vector `v`, appends dimensions of size 1 so that it has 27 | the same number of dimensions as `self.input_shape`. For example, if 28 | `self.input_shape` is (3, 50, 50), then this function turns `v` from a 29 | B-tensor to a B x 1 x 1 x 1 tensor. This is useful for combining the 30 | tensor with things shaped like the input later. 31 | Arguments: 32 | `v`: a B-tensor 33 | Returns a B x `self.input_shape` tensor. 34 | """ 35 | return v[(slice(None),) + ((None,) * len(self.input_shape))] 36 | 37 | def forward(self, x0, t, return_posterior=True): 38 | """ 39 | Runs diffusion process forward given starting point `x0` and a time `t`. 40 | Optionally also returns a tensor which represents the posterior of 41 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.) 42 | Arguments: 43 | `x0`: a B x `self.input_shape` tensor containing the data at some 44 | time points 45 | `t`: a B-tensor containing the time in the diffusion process for 46 | each input 47 | Returns a B x `self.input_shape` tensor to represent xt. If 48 | `return_posterior` is True, then also returns a B x `self.input_shape` 49 | tensor which is a parameter of the posterior. 50 | """ 51 | if return_posterior: 52 | return torch.zeros_like(x0), torch.zeros_like(x0) 53 | else: 54 | return torch.zeros_like(x0) 55 | 56 | def reverse_step(self, xt, t, post): 57 | """ 58 | Performs a reverse sampling step to compute x_{t-1} given xt and the 59 | posterior quantity (or an estimate of it) defined in `posterior`. 
 60 |         Arguments:
 61 |             `xt`: a B x `self.input_shape` tensor containing the data at time t
 62 |             `t`: a B-tensor containing the time in the diffusion process for
 63 |                 each input
 64 |             `post`: a B x `self.input_shape` tensor containing the posterior
 65 |                 quantity (or a model-predicted estimate of it) as defined in
 66 |                 `posterior`
 67 |         Returns a B x `self.input_shape` tensor for x_{t-1}.
 68 |         """
 69 |         return torch.zeros_like(xt)
 70 | 
 71 |     def sample_prior(self, num_samples, t):
 72 |         """
 73 |         Samples from the prior distribution specified by the diffusion process
 74 |         at time `t`.
 75 |         Arguments:
 76 |             `num_samples`: B, the number of samples to return
 77 |             `t`: a B-tensor containing the time in the diffusion process for
 78 |                 each input
 79 |         Returns a B x `self.input_shape` tensor for the `xt` values that are
 80 |             sampled.
 81 |         """
 82 |         return torch.zeros(torch.Size([num_samples] + list(self.input_shape)))
 83 | 
 84 |     def __str__(self):
 85 |         return "Base Discrete Diffuser"
 86 | 
 87 | 
 88 | class GaussianDiffuser(DiscreteDiffuser):
 89 |     # Diffuser which adds Gaussian noise to the inputs
 90 |     def __init__(self, beta_1, delta_beta, input_shape, seed=None):
 91 |         """
 92 |         Arguments:
 93 |             `beta_1`: beta(1), the first value of beta at t = 1
 94 |             `delta_beta`: beta(t) will be linear with this slope
 95 |             `input_shape`: a tuple of ints which is the shape of input tensors
 96 |                 x; does not include batch dimension
 97 |             `seed`: random seed for sampling and running the diffusion process
 98 |         """
 99 |         super().__init__(input_shape, seed)
100 | 
101 |         self.beta_1 = torch.tensor(beta_1).to(DEVICE)
102 |         self.delta_beta = torch.tensor(delta_beta).to(DEVICE)
103 |         self.string = "Gaussian Diffuser (beta(t) = %.2f + %.2ft)" % (
104 |             beta_1, delta_beta
105 |         )
106 | 
107 |     def _beta(self, t):
108 |         """
109 |         Computes beta(t).
110 |         Arguments:
111 |             `t`: a B-tensor of times
112 |         Returns a B-tensor of beta values.
113 |         """
114 |         # Subtract 1 from t: when t = 1, beta(t) = beta_1
115 |         return self.beta_1 + (self.delta_beta * (t - 1))
116 | 
117 |     def _alpha(self, t):
118 |         """
119 |         Computes alpha(t).
120 |         Arguments:
121 |             `t`: a B-tensor of times
122 |         Returns a B-tensor of alpha values.
123 |         """
124 |         return 1 - self._beta(t)
125 | 
126 |     def _alpha_bar(self, t):
127 |         """
128 |         Computes alpha-bar(t).
129 |         Arguments:
130 |             `t`: a B-tensor of times
131 |         Returns a B-tensor of alpha-bar values.
132 |         """
133 |         max_t = torch.max(t)
134 |         t_range = torch.arange(max_t.int() + 1, device=DEVICE)
135 |         alphas = self._alpha(t_range)
136 |         alphas_prod = torch.cumprod(alphas, dim=0)
137 |         return alphas_prod[t.long()]
138 | 
139 |     def forward(self, x0, t, return_posterior=True):
140 |         """
141 |         Runs diffusion process forward given starting point `x0` and a time `t`.
142 |         Optionally also returns a tensor which represents the posterior of
143 |         x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
144 |         Arguments:
145 |             `x0`: a B x `self.input_shape` tensor containing the data at some
146 |                 time points
147 |             `t`: a B-tensor containing the time in the diffusion process for
148 |                 each input
149 |         Returns a B x `self.input_shape` tensor to represent xt. If
150 |         `return_posterior` is True, then also returns a B x `self.input_shape`
151 |         tensor which is a parameter of the posterior.
152 |         """
153 |         z = torch.normal(
154 |             torch.zeros_like(x0), torch.ones_like(x0), generator=self.rng
155 |         ) # Shape: B x ...
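        # Closed-form forward sample: xt = sqrt(alpha_bar_t) * x0 +
        # sqrt(1 - alpha_bar_t) * z, i.e., xt ~ q(xt | x0) in a single step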
156 | 157 | alpha_bar = self._inflate_dims(self._alpha_bar(t)) 158 | xt = (torch.sqrt(alpha_bar) * x0) + (torch.sqrt(1 - alpha_bar) * z) 159 | 160 | if return_posterior: 161 | return xt, z 162 | return xt 163 | 164 | def reverse_step(self, xt, t, post): 165 | """ 166 | Performs a reverse sampling step to compute x_{t-1} given xt and the 167 | posterior quantity (or an estimate of it) defined in `posterior`. 168 | Arguments: 169 | `xt`: a B x `self.input_shape` tensor containing the data at time t 170 | `t`: a B-tensor containing the time in the diffusion process for 171 | each input 172 | `post`: a B x `self.input_shape` tensor containing the posterior 173 | quantity (or a model-predicted estimate of it) as defined in 174 | `posterior` 175 | Returns a B x `self.input_shape` tensor for x_{t-1}. 176 | """ 177 | beta = self._beta(t) 178 | alpha = self._alpha(t) 179 | alpha_bar = self._alpha_bar(t) 180 | alpha_bar_1 = self._alpha_bar(t - 1) 181 | beta_tilde = beta * (1 - alpha_bar_1) / (1 - alpha_bar) 182 | 183 | z = torch.normal( 184 | torch.zeros_like(xt), torch.ones_like(xt), generator=self.rng 185 | ) 186 | 187 | d = xt - (post * self._inflate_dims(beta / torch.sqrt(1 - alpha_bar))) 188 | d = d / self._inflate_dims(torch.sqrt(alpha)) 189 | std = torch.sqrt(beta_tilde) 190 | std[t == 1] = 0 # No noise for the last step 191 | return d + (z * self._inflate_dims(std)) 192 | 193 | def sample_prior(self, num_samples, t): 194 | """ 195 | Samples from the prior distribution specified by the diffusion process 196 | at time `t`. 197 | Arguments: 198 | `num_samples`: B, the number of samples to return 199 | `t`: a B-tensor containing the time in the diffusion process for 200 | each input 201 | Returns a B x `self.input_shape` tensor for the `xt` values that are 202 | sampled. 203 | """ 204 | # Ignore t 205 | size = torch.Size([num_samples] + list(self.input_shape)) 206 | return torch.normal( 207 | torch.zeros(size, device=DEVICE), torch.ones(size, device=DEVICE), 208 | generator=self.rng 209 | ) 210 | 211 | def __str__(self): 212 | return self.string 213 | 214 | 215 | class BernoulliDiffuser(DiscreteDiffuser): 216 | # Diffuser which flips bits in the inputs 217 | def __init__(self, a, b, input_shape, seed=None): 218 | """ 219 | Arguments: 220 | `a`: stretch factor of logistic beta function 221 | `b`: shift factor of logistic beta function 222 | `input_shape`: a tuple of ints which is the shape of input tensors 223 | x; does not include batch dimension 224 | `seed`: random seed for sampling and running the diffusion process 225 | """ 226 | super().__init__(input_shape, seed) 227 | 228 | self.a = torch.tensor(a).to(DEVICE) 229 | self.b = torch.tensor(b).to(DEVICE) 230 | self.string = "Bernoulli Diffuser (beta(t) = s((t / %.2f) - %.2f)" % ( 231 | a, b 232 | ) 233 | 234 | epsilon = 1e-6 # For numerical stability 235 | self.half = torch.tensor(0.5, device=DEVICE) 236 | self.half_eps = self.half - 1e-6 237 | self.log2 = torch.log(torch.tensor(2, device=DEVICE)) 238 | 239 | def _beta(self, t): 240 | """ 241 | Computes beta(t), which is the probability of a flip at time `t`. 242 | Arguments: 243 | `t`: a B-tensor of times 244 | Returns a B-tensor of beta values. 245 | """ 246 | # Subtract 1 from t: when t = 1, beta(t) = beta_1 247 | return torch.minimum( 248 | torch.sigmoid((t / self.a) - self.b), self.half_eps 249 | ) 250 | 251 | def _beta_bar(self, t): 252 | """ 253 | Computes beta-bar(t), which is the probability of a flip from time 0 to 254 | time `t` (flip is relative to time 0). 
255 | Arguments: 256 | `t`: a B-tensor of times 257 | Returns a B-tensor of beta-bar values. 258 | """ 259 | max_range = torch.arange(1, torch.max(t) + 1, device=DEVICE) 260 | betas = self._beta(max_range) # Shape: maxT 261 | 262 | betas_tiled = torch.tile(betas, (t.shape[0], 1)) # Shape: B x maxT 263 | biases = self.half - betas_tiled 264 | 265 | # Anything that ran over a time t, set to 1 266 | mask = max_range[None] > t[:, None] 267 | biases[mask] = 1 268 | 269 | # Do the product in log-space and transform back for stability 270 | coef = (t - 1) * self.log2 271 | log_prod = torch.sum(torch.log(biases), dim=1) 272 | prod = torch.exp(coef + log_prod) 273 | 274 | return self.half - prod 275 | 276 | def forward(self, x0, t, return_posterior=True): 277 | """ 278 | Runs diffusion process forward given starting point `x0` and a time `t`. 279 | Optionally also returns a tensor which represents the posterior of 280 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.) 281 | Arguments: 282 | `x0`: a B x `self.input_shape` tensor containing the data at some 283 | time points 284 | `t`: a B-tensor containing the time in the diffusion process for 285 | each input 286 | Returns a B x `self.input_shape` tensor to represent xt. If 287 | `return_posterior` is True, then also returns a B x `self.input_shape` 288 | tensor which is a parameter of the posterior. 289 | """ 290 | beta_t = self._inflate_dims(self._beta(t)) 291 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 292 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1)) 293 | 294 | prob_flip = torch.tile(beta_bar_t, (1, *x0.shape[1:])) 295 | indicators = torch.bernoulli(prob_flip) 296 | 297 | # Perform flips 298 | xt = x0.clone() 299 | mask = indicators == 1 300 | xt[mask] = 1 - x0[mask] 301 | 302 | if return_posterior: 303 | term_1 = ((1 - xt) * beta_t) + (xt * (1 - beta_t)) 304 | term_2 = ((1 - x0) * beta_bar_t_1) + (x0 * (1 - beta_bar_t_1)) 305 | x0_xor_xt = torch.square(x0 - xt) 306 | term_3 = (x0_xor_xt * beta_bar_t) + \ 307 | ((1 - x0_xor_xt) * (1 - beta_bar_t)) 308 | 309 | post = term_1 * term_2 / term_3 310 | # Due to small numerical instabilities (particularly at t = 0), clip 311 | # the probabilities 312 | post = torch.clamp(post, 0, 1) 313 | return xt, post 314 | return xt 315 | 316 | def reverse_step(self, xt, t, post): 317 | """ 318 | Performs a reverse sampling step to compute x_{t-1} given xt and the 319 | posterior quantity (or an estimate of it) defined in `posterior`. 320 | Arguments: 321 | `xt`: a B x `self.input_shape` tensor containing the data at time t 322 | `t`: a B-tensor containing the time in the diffusion process for 323 | each input 324 | `post`: a B x `self.input_shape` tensor containing the posterior 325 | quantity (or a model-predicted estimate of it) as defined in 326 | `posterior` 327 | Returns a B x `self.input_shape` tensor for x_{t-1}. 328 | """ 329 | # Ignore xt and t 330 | return torch.bernoulli(post) 331 | 332 | def sample_prior(self, num_samples, t): 333 | """ 334 | Samples from the prior distribution specified by the diffusion process 335 | at time `t`. 336 | Arguments: 337 | `num_samples`: B, the number of samples to return 338 | `t`: a B-tensor containing the time in the diffusion process for 339 | each input 340 | Returns a B x `self.input_shape` tensor for the `xt` values that are 341 | sampled. 
342 | """ 343 | # Ignore t 344 | size = torch.Size([num_samples] + list(self.input_shape)) 345 | probs = torch.tile(self.half, size) 346 | return torch.bernoulli(probs) 347 | 348 | def __str__(self): 349 | return self.string 350 | 351 | 352 | class BernoulliOneDiffuser(DiscreteDiffuser): 353 | # Diffuser which sets bits to 1 in the inputs 354 | def __init__(self, a, b, input_shape, seed=None, reverse_reflect=False): 355 | """ 356 | Arguments: 357 | `a`: stretch factor of logistic beta function 358 | `b`: shift factor of logistic beta function 359 | `input_shape`: a tuple of ints which is the shape of input tensors 360 | x; does not include batch dimension 361 | `seed`: random seed for sampling and running the diffusion process 362 | `reverse_reflect`: if True, enforce that the reverse-diffusion 363 | process only sets entries to 0; if False, the reverse-diffusion 364 | process samples from the Bernoulli posterior as usual 365 | """ 366 | super().__init__(input_shape, seed) 367 | 368 | self.a = torch.tensor(a).to(DEVICE) 369 | self.b = torch.tensor(b).to(DEVICE) 370 | self.string = "Bernoulli One Diffuser (beta(t) = s((t / %.2f) - %.2f)" \ 371 | % (a, b) 372 | self.reverse_reflect = reverse_reflect 373 | 374 | def _beta(self, t): 375 | """ 376 | Computes beta(t), which is the probability of a flip to 1 at time `t`. 377 | Arguments: 378 | `t`: a B-tensor of times 379 | Returns a B-tensor of beta values. 380 | """ 381 | # Subtract 1 from t: when t = 1, beta(t) = beta_1 382 | return torch.sigmoid((t / self.a) - self.b) 383 | 384 | def _beta_bar(self, t): 385 | """ 386 | Computes beta-bar(t), which is the probability of a flip to 1 from time 387 | 0 to time `t` (flip is assuming the bit at time 0 is 0). 388 | Arguments: 389 | `t`: a B-tensor of times 390 | Returns a B-tensor of beta-bar values. 391 | """ 392 | max_range = torch.arange(1, torch.max(t) + 1, device=DEVICE) 393 | betas = self._beta(max_range) # Shape: maxT 394 | 395 | betas_tiled = torch.tile(betas, (t.shape[0], 1)) # Shape: B x maxT 396 | comps = 1 - betas_tiled 397 | 398 | # Anything that ran over a time t, set to 1 399 | mask = max_range[None] > t[:, None] 400 | comps[mask] = 1 401 | 402 | # Do the product in log-space and transform back for stability 403 | log_prod = torch.sum(torch.log(comps), dim=1) 404 | prod = torch.exp(log_prod) 405 | 406 | return 1 - prod 407 | 408 | def forward(self, x0, t, return_posterior=True): 409 | """ 410 | Runs diffusion process forward given starting point `x0` and a time `t`. 411 | Optionally also returns a tensor which represents the posterior of 412 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.) 413 | Arguments: 414 | `x0`: a B x `self.input_shape` tensor containing the data at some 415 | time points 416 | `t`: a B-tensor containing the time in the diffusion process for 417 | each input 418 | Returns a B x `self.input_shape` tensor to represent xt. If 419 | `return_posterior` is True, then also returns a B x `self.input_shape` 420 | tensor which is a parameter of the posterior. 
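        (For this diffuser, the returned posterior parameter is the
        probability that x_{t-1} = 1, suitable for `torch.bernoulli`.)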
421 | """ 422 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 423 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1)) 424 | 425 | prob_one = x0 + ((1 - x0) * beta_bar_t) 426 | indicators = torch.bernoulli(prob_one) 427 | 428 | # Perform sampling 429 | xt = x0.clone() 430 | mask = indicators == 1 431 | xt[mask] = 1 # Set to 1 432 | 433 | if return_posterior: 434 | term_1 = xt 435 | term_2 = x0 + ((1 - x0) * beta_bar_t_1) 436 | term_3 = (x0 * xt) + ((1 - x0) * xt * beta_bar_t) + \ 437 | ((1 - x0) * (1 - xt) * (1 - beta_bar_t)) 438 | 439 | post = term_1 * term_2 / term_3 440 | # Due to small numerical instabilities (particularly at t = 0), clip 441 | # the probabilities 442 | post = torch.clamp(post, 0, 1) 443 | return xt, post 444 | return xt 445 | 446 | def reverse_step(self, xt, t, post): 447 | """ 448 | Performs a reverse sampling step to compute x_{t-1} given xt and the 449 | posterior quantity (or an estimate of it) defined in `posterior`. 450 | Arguments: 451 | `xt`: a B x `self.input_shape` tensor containing the data at time t 452 | `t`: a B-tensor containing the time in the diffusion process for 453 | each input 454 | `post`: a B x `self.input_shape` tensor containing the posterior 455 | quantity (or a model-predicted estimate of it) as defined in 456 | `posterior` 457 | Returns a B x `self.input_shape` tensor for x_{t-1}. 458 | """ 459 | # Ignore xt and t 460 | if self.reverse_reflect: 461 | xt_1 = xt.clone() 462 | indicators = torch.bernoulli(post) 463 | xt_1[indicators == 0] = 0 # Only set to 0 when a 0 is sampled 464 | return xt_1 465 | return torch.bernoulli(post) 466 | 467 | def sample_prior(self, num_samples, t): 468 | """ 469 | Samples from the prior distribution specified by the diffusion process 470 | at time `t`. 471 | Arguments: 472 | `num_samples`: B, the number of samples to return 473 | `t`: a B-tensor containing the time in the diffusion process for 474 | each input 475 | Returns a B x `self.input_shape` tensor for the `xt` values that are 476 | sampled. 477 | """ 478 | # Ignore t 479 | size = torch.Size([num_samples] + list(self.input_shape)) 480 | return torch.ones(size, device=DEVICE) 481 | 482 | def __str__(self): 483 | return self.string 484 | 485 | 486 | class BernoulliZeroDiffuser(DiscreteDiffuser): 487 | # Diffuser which sets bits to 0 in the inputs 488 | def __init__(self, a, b, input_shape, seed=None, reverse_reflect=False): 489 | """ 490 | Arguments: 491 | `a`: stretch factor of logistic beta function 492 | `b`: shift factor of logistic beta function 493 | `input_shape`: a tuple of ints which is the shape of input tensors 494 | x; does not include batch dimension 495 | `seed`: random seed for sampling and running the diffusion process 496 | `reverse_reflect`: if True, enforce that the reverse-diffusion 497 | process only sets entries to 1; if False, the reverse-diffusion 498 | process samples from the Bernoulli posterior as usual 499 | """ 500 | super().__init__(input_shape, seed) 501 | 502 | self.a = torch.tensor(a).to(DEVICE) 503 | self.b = torch.tensor(b).to(DEVICE) 504 | self.string = "Bernoulli Zero Diffuser (beta(t) = s((t / %.2f) - %.2f)" \ 505 | % (a, b) 506 | self.reverse_reflect = reverse_reflect 507 | 508 | def _beta(self, t): 509 | """ 510 | Computes beta(t), which is the probability of a flip to 0 at time `t`. 511 | Arguments: 512 | `t`: a B-tensor of times 513 | Returns a B-tensor of beta values. 
514 | """ 515 | # Subtract 1 from t: when t = 1, beta(t) = beta_1 516 | return torch.sigmoid((t / self.a) - self.b) 517 | 518 | def _beta_bar(self, t): 519 | """ 520 | Computes beta-bar(t), which is the probability of a flip to 0 from time 521 | 0 to time `t` (flip is assuming the bit at time 0 is 1). 522 | Arguments: 523 | `t`: a B-tensor of times 524 | Returns a B-tensor of beta-bar values. 525 | """ 526 | max_range = torch.arange(1, torch.max(t) + 1, device=DEVICE) 527 | betas = self._beta(max_range) # Shape: maxT 528 | 529 | betas_tiled = torch.tile(betas, (t.shape[0], 1)) # Shape: B x maxT 530 | comps = 1 - betas_tiled 531 | 532 | # Anything that ran over a time t, set to 1 533 | mask = max_range[None] > t[:, None] 534 | comps[mask] = 1 535 | 536 | # Do the product in log-space and transform back for stability 537 | log_prod = torch.sum(torch.log(comps), dim=1) 538 | prod = torch.exp(log_prod) 539 | 540 | return 1 - prod 541 | 542 | def forward(self, x0, t, return_posterior=True): 543 | """ 544 | Runs diffusion process forward given starting point `x0` and a time `t`. 545 | Optionally also returns a tensor which represents the posterior of 546 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.) 547 | Arguments: 548 | `x0`: a B x `self.input_shape` tensor containing the data at some 549 | time points 550 | `t`: a B-tensor containing the time in the diffusion process for 551 | each input 552 | Returns a B x `self.input_shape` tensor to represent xt. If 553 | `return_posterior` is True, then also returns a B x `self.input_shape` 554 | tensor which is a parameter of the posterior. 555 | """ 556 | beta_t = self._inflate_dims(self._beta(t)) 557 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 558 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1)) 559 | 560 | prob_one = x0 * (1 - beta_bar_t) 561 | indicators = torch.bernoulli(prob_one) 562 | 563 | # Perform sampling 564 | xt = x0.clone() 565 | mask = indicators == 0 566 | xt[mask] = 0 # Set to 0 567 | 568 | if return_posterior: 569 | term_1 = xt + beta_t - (2 * xt * beta_t) 570 | term_2 = x0 * (1 - beta_bar_t_1) 571 | term_3 = ((1 - x0) * (1 - xt)) + (x0 * (1 - xt) * beta_bar_t) + \ 572 | (x0 * xt * (1 - beta_bar_t)) 573 | 574 | post = term_1 * term_2 / term_3 575 | # Due to small numerical instabilities (particularly at t = 0), clip 576 | # the probabilities 577 | post = torch.clamp(post, 0, 1) 578 | return xt, post 579 | return xt 580 | 581 | def reverse_step(self, xt, t, post): 582 | """ 583 | Performs a reverse sampling step to compute x_{t-1} given xt and the 584 | posterior quantity (or an estimate of it) defined in `posterior`. 585 | Arguments: 586 | `xt`: a B x `self.input_shape` tensor containing the data at time t 587 | `t`: a B-tensor containing the time in the diffusion process for 588 | each input 589 | `post`: a B x `self.input_shape` tensor containing the posterior 590 | quantity (or a model-predicted estimate of it) as defined in 591 | `posterior` 592 | Returns a B x `self.input_shape` tensor for x_{t-1}. 593 | """ 594 | # Ignore xt and t 595 | if self.reverse_reflect: 596 | xt_1 = xt.clone() 597 | indicators = torch.bernoulli(post) 598 | xt_1[indicators == 1] = 1 # Only set to 1 when a 1 is sampled 599 | return xt_1 600 | return torch.bernoulli(post) 601 | 602 | def sample_prior(self, num_samples, t): 603 | """ 604 | Samples from the prior distribution specified by the diffusion process 605 | at time `t`. 
606 | Arguments: 607 | `num_samples`: B, the number of samples to return 608 | `t`: a B-tensor containing the time in the diffusion process for 609 | each input 610 | Returns a B x `self.input_shape` tensor for the `xt` values that are 611 | sampled. 612 | """ 613 | # Ignore t 614 | size = torch.Size([num_samples] + list(self.input_shape)) 615 | return torch.zeros(size, device=DEVICE) 616 | 617 | def __str__(self): 618 | return self.string 619 | 620 | 621 | class BernoulliSkipDiffuser(BernoulliDiffuser): 622 | # Diffuser which flips bits in the inputs, but the posterior and reverse 623 | # step are directly just probabilities of the original 624 | def forward(self, x0, t, return_posterior=True): 625 | """ 626 | Runs diffusion process forward given starting point `x0` and a time `t`. 627 | Optionally also returns a tensor which represents the posterior of 628 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.) 629 | Arguments: 630 | `x0`: a B x `self.input_shape` tensor containing the data at some 631 | time points 632 | `t`: a B-tensor containing the time in the diffusion process for 633 | each input 634 | Returns a B x `self.input_shape` tensor to represent xt. If 635 | `return_posterior` is True, then also returns a B x `self.input_shape` 636 | tensor which is a parameter of the posterior. 637 | """ 638 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 639 | prob_flip = torch.tile(beta_bar_t, (1, *x0.shape[1:])) 640 | indicators = torch.bernoulli(prob_flip) 641 | 642 | # Perform flips 643 | xt = x0.clone() 644 | mask = indicators == 1 645 | xt[mask] = 1 - x0[mask] 646 | 647 | if return_posterior: 648 | return xt, x0 # Here, our "posterior" is just x0 649 | return xt 650 | 651 | def reverse_step(self, xt, t, post): 652 | """ 653 | Performs a reverse sampling step to compute x_{t-1} given xt and the 654 | posterior quantity (or an estimate of it) defined in `posterior`. 655 | Arguments: 656 | `xt`: a B x `self.input_shape` tensor containing the data at time t 657 | `t`: a B-tensor containing the time in the diffusion process for 658 | each input 659 | `post`: a B x `self.input_shape` tensor containing the posterior 660 | quantity (or a model-predicted estimate of it) as defined in 661 | `posterior` 662 | Returns a B x `self.input_shape` tensor for x_{t-1}. 
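        (Note: for the skip diffusers, `post` is a predicted probability for
        x0 itself; this method samples x0 from it and then samples x_{t-1}
        from the analytic posterior p(x_{t-1} | xt, x0).)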
663 | """ 664 | last_step_mask = None 665 | if not torch.all(t > 1): 666 | last_step_mask = t == 1 667 | 668 | x0 = torch.bernoulli(post) 669 | # In this case, what we're calling the "posterior" is just x0 670 | 671 | beta_t = self._inflate_dims(self._beta(t)) 672 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 673 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1)) 674 | 675 | term_1 = ((1 - xt) * beta_t) + (xt * (1 - beta_t)) 676 | term_2 = ((1 - x0) * beta_bar_t_1) + (x0 * (1 - beta_bar_t_1)) 677 | x0_xor_xt = torch.square(x0 - xt) 678 | term_3 = (x0_xor_xt * beta_bar_t) + ((1 - x0_xor_xt) * (1 - beta_bar_t)) 679 | 680 | posterior = term_1 * term_2 / term_3 # p(x_{t-1} = 1 | xt, x0) 681 | # Due to small numerical instabilities (particularly at t = 0), clip 682 | # the probabilities 683 | posterior = torch.clamp(posterior, 0, 1) 684 | xt_1 = torch.bernoulli(posterior) 685 | 686 | if last_step_mask is not None: 687 | # For the last step, don't do this posterior calculation; just use 688 | # the x0 we are given 689 | xt_1[last_step_mask] = x0[last_step_mask] 690 | return xt_1 691 | 692 | 693 | class BernoulliOneSkipDiffuser(BernoulliOneDiffuser): 694 | # Diffuser which sets bits to 1 in the inputs, but the posterior and reverse 695 | # step are directly just probabilities of the original 696 | def forward(self, x0, t, return_posterior=True): 697 | """ 698 | Runs diffusion process forward given starting point `x0` and a time `t`. 699 | Optionally also returns a tensor which represents the posterior of 700 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.) 701 | Arguments: 702 | `x0`: a B x `self.input_shape` tensor containing the data at some 703 | time points 704 | `t`: a B-tensor containing the time in the diffusion process for 705 | each input 706 | Returns a B x `self.input_shape` tensor to represent xt. If 707 | `return_posterior` is True, then also returns a B x `self.input_shape` 708 | tensor which is a parameter of the posterior. 709 | """ 710 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 711 | 712 | prob_one = x0 + ((1 - x0) * beta_bar_t) 713 | indicators = torch.bernoulli(prob_one) 714 | 715 | # Perform sampling 716 | xt = x0.clone() 717 | mask = indicators == 1 718 | xt[mask] = 1 # Set to 1 719 | 720 | if return_posterior: 721 | return xt, x0 # Here, our "posterior" is just x0 722 | return xt 723 | 724 | def reverse_step(self, xt, t, post): 725 | """ 726 | Performs a reverse sampling step to compute x_{t-1} given xt and the 727 | posterior quantity (or an estimate of it) defined in `posterior`. 728 | Arguments: 729 | `xt`: a B x `self.input_shape` tensor containing the data at time t 730 | `t`: a B-tensor containing the time in the diffusion process for 731 | each input 732 | `post`: a B x `self.input_shape` tensor containing the posterior 733 | quantity (or a model-predicted estimate of it) as defined in 734 | `posterior` 735 | Returns a B x `self.input_shape` tensor for x_{t-1}. 
736 | """ 737 | last_step_mask = None 738 | if not torch.all(t > 1): 739 | last_step_mask = t == 1 740 | 741 | x0 = torch.bernoulli(post) 742 | # In this case, what we're calling the "posterior" is just x0 743 | 744 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 745 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1)) 746 | 747 | term_1 = xt 748 | term_2 = x0 + ((1 - x0) * beta_bar_t_1) 749 | term_3 = (x0 * xt) + ((1 - x0) * xt * beta_bar_t) + \ 750 | ((1 - x0) * (1 - xt) * (1 - beta_bar_t)) 751 | 752 | posterior = term_1 * term_2 / term_3 # p(x_{t-1} = 1 | xt, x0) 753 | # Due to small numerical instabilities (particularly at t = 0), clip 754 | # the probabilities 755 | posterior = torch.nan_to_num(posterior, 1) 756 | posterior = torch.clamp(posterior, 0, 1) 757 | xt_1 = torch.bernoulli(posterior) 758 | 759 | if last_step_mask is not None: 760 | # For the last step, don't do this posterior calculation; just use 761 | # the x0 we are given 762 | xt_1[last_step_mask] = x0[last_step_mask] 763 | 764 | if self.reverse_reflect: 765 | mask = (xt == 0) & (xt_1 == 1) 766 | xt_1[mask] = 0 # Disallow flipping to 1 from t to t-1 767 | 768 | return xt_1 769 | 770 | 771 | class BernoulliZeroSkipDiffuser(BernoulliZeroDiffuser): 772 | # Diffuser which sets bits to 0 in the inputs, but the posterior and reverse 773 | # step are directly just probabilities of the original 774 | def forward(self, x0, t, return_posterior=True): 775 | """ 776 | Runs diffusion process forward given starting point `x0` and a time `t`. 777 | Optionally also returns a tensor which represents the posterior of 778 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.) 779 | Arguments: 780 | `x0`: a B x `self.input_shape` tensor containing the data at some 781 | time points 782 | `t`: a B-tensor containing the time in the diffusion process for 783 | each input 784 | Returns a B x `self.input_shape` tensor to represent xt. If 785 | `return_posterior` is True, then also returns a B x `self.input_shape` 786 | tensor which is a parameter of the posterior. 787 | """ 788 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 789 | 790 | prob_one = x0 * (1 - beta_bar_t) 791 | indicators = torch.bernoulli(prob_one) 792 | 793 | # Perform sampling 794 | xt = x0.clone() 795 | mask = indicators == 0 796 | xt[mask] = 0 # Set to 0 797 | 798 | if return_posterior: 799 | return xt, x0 # Here, our "posterior" is just x0 800 | return xt 801 | 802 | def reverse_step(self, xt, t, post): 803 | """ 804 | Performs a reverse sampling step to compute x_{t-1} given xt and the 805 | posterior quantity (or an estimate of it) defined in `posterior`. 806 | Arguments: 807 | `xt`: a B x `self.input_shape` tensor containing the data at time t 808 | `t`: a B-tensor containing the time in the diffusion process for 809 | each input 810 | `post`: a B x `self.input_shape` tensor containing the posterior 811 | quantity (or a model-predicted estimate of it) as defined in 812 | `posterior` 813 | Returns a B x `self.input_shape` tensor for x_{t-1}. 
814 | """ 815 | last_step_mask = None 816 | if not torch.all(t > 1): 817 | last_step_mask = t == 1 818 | 819 | x0 = torch.bernoulli(post) 820 | # In this case, what we're calling the "posterior" is just x0 821 | 822 | beta_t = self._inflate_dims(self._beta(t)) 823 | beta_bar_t = self._inflate_dims(self._beta_bar(t)) 824 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1)) 825 | 826 | term_1 = xt + beta_t - (2 * xt * beta_t) 827 | term_2 = x0 * (1 - beta_bar_t_1) 828 | term_3 = ((1 - x0) * (1 - xt)) + (x0 * (1 - xt) * beta_bar_t) + \ 829 | (x0 * xt * (1 - beta_bar_t)) 830 | 831 | posterior = term_1 * term_2 / term_3 # p(x_{t-1} = 1 | xt, x0) 832 | # Due to small numerical instabilities (particularly at t = 0), clip 833 | # the probabilities 834 | posterior = torch.nan_to_num(posterior, 1) 835 | posterior = torch.clamp(posterior, 0, 1) 836 | xt_1 = torch.bernoulli(posterior) 837 | 838 | if last_step_mask is not None: 839 | # For the last step, don't do this posterior calculation; just use 840 | # the x0 we are given 841 | xt_1[last_step_mask] = x0[last_step_mask] 842 | 843 | if self.reverse_reflect: 844 | mask = (xt == 1) & (xt_1 == 0) 845 | xt_1[mask] = 1 # Disallow flipping to 0 from t to t-1 846 | 847 | return xt_1 848 | 849 | 850 | if __name__ == "__main__": 851 | input_shape = (1, 28, 28) 852 | diffuser = GaussianDiffuser(1e-4, 1.2e-4, input_shape) 853 | x0 = torch.empty((32,) + input_shape, device=DEVICE) 854 | t = torch.randint(1, 1001, (32,), device=DEVICE) 855 | xt, z = diffuser.forward(x0, t) 856 | xt1 = diffuser.reverse_step(xt, t, z) 857 | xT = diffuser.sample_prior(32, None) 858 | -------------------------------------------------------------------------------- /src/model/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import feature.graph_conversions as graph_conversions 4 | 5 | # Define device 6 | if torch.cuda.is_available(): 7 | DEVICE = "cuda" 8 | else: 9 | DEVICE = "cpu" 10 | 11 | 12 | def generate_samples( 13 | model, diffuser, num_samples=64, t_start=0, t_limit=1000, 14 | initial_samples=None, return_all_times=False, verbose=False 15 | ): 16 | """ 17 | Generates samples from a trained posterior model and discrete diffuser. This 18 | first generates a sample from the prior distribution a `t_limit`, then steps 19 | backward through time to generate new data points. 20 | Arguments: 21 | `model`: a trained model which takes in x, t and predicts a posterior 22 | `diffuser`: a DiscreteDiffuser object 23 | `num_samples`: number of objects to return 24 | `t_start`: last time step to stop at (a smaller positive integer) than 25 | `t_limit` 26 | `t_limit`: the time step to start generating at (a larger positive 27 | integer than `t_start`) 28 | `initial_samples`: if given, this is a tensor which contains the samples 29 | to start from initially, to be used instead of sampling from the 30 | diffuser's defined prior 31 | `return_all_times`: if True, instead of returning a tensor at `t_start`, 32 | return a larger tensor where the first dimension is 33 | `t_limit - t_start + 1`, and a tensor of times; each tensor is the 34 | reconstruction of the object for that time; the first entry will be 35 | the object at `t_limit`, and the last entry will be the object at 36 | `t_start` 37 | `verbose`: if True, print out progress bar 38 | Returns a tensor of size `num_samples` x .... If `return_all_times` is True, 39 | returns a tensor of size T x `num_samples` x ... 
41 |     """
42 |     # First, sample from the prior distribution at some late time t
43 |     if initial_samples is not None:
44 |         xt = initial_samples
45 |         assert len(xt) == num_samples
46 |     else:
47 |         t = (torch.ones(num_samples) * t_limit).to(DEVICE)
48 |         xt = diffuser.sample_prior(num_samples, t)
49 | 
50 |     if return_all_times:
51 |         all_xt = torch.empty(
52 |             (t_limit - t_start + 1,) + xt.shape,
53 |             dtype=xt.dtype, device=xt.device
54 |         )
55 |         all_xt[0] = xt
56 |         all_t = torch.arange(t_limit, t_start - 1, step=-1).to(DEVICE)
57 | 
58 |     # Disable gradient computation in model
59 |     model.eval()
60 |     torch.set_grad_enabled(False)
61 | 
62 |     time_steps = torch.arange(t_limit, t_start, step=-1).to(DEVICE)
63 |     # (descending order)
64 | 
65 |     # Step backward through time starting at xt
66 |     x = xt
67 |     t_iter = tqdm.tqdm(enumerate(time_steps), total=len(time_steps)) if verbose \
68 |         else enumerate(time_steps)
69 |     for t_i, time_step in t_iter:
70 |         t = torch.ones(num_samples).to(DEVICE) * time_step
71 |         post = model(xt, t)
72 |         xt = diffuser.reverse_step(xt, t, post)
73 | 
74 |         if return_all_times:
75 |             all_xt[t_i + 1] = xt
76 | 
77 |     if return_all_times:
78 |         return all_xt, all_t
79 |     return xt
80 | 
81 | 
82 | def generate_graph_samples(
83 |     model, diffuser, initial_samples, t_start=0, t_limit=1000,
84 |     return_all_times=False, verbose=False
85 | ):
86 |     """
87 |     Generates samples from a trained posterior model and discrete diffuser. This
88 |     starts from the given `initial_samples` at `t_limit`, then steps backward
89 |     through time to generate new data points.
90 |     Arguments:
91 |         `model`: a trained model which takes in x, t and predicts a posterior on
92 |             edges in canonical order
93 |         `diffuser`: a DiscreteDiffuser object
94 |         `initial_samples`: a torch-geometric Data object which contains the
95 |             samples to start from initially, at `t_limit`
96 |         `t_start`: last time step to stop at (a smaller positive integer than
97 |             `t_limit`)
98 |         `t_limit`: the time step to start generating at (a larger positive
99 |             integer than `t_start`)
100 |         `return_all_times`: if True, instead of returning a single
101 |             torch-geometric Data object at `t_start`, return a list of Data
102 |             objects of length `t_limit - t_start + 1`, and a parallel tensor of
103 |             times; each Data object is the reconstruction of the object for that
104 |             time; the first entry will be the object at `t_limit`, and the last
105 |             entry will be the object at `t_start`
106 |         `verbose`: if True, print out progress bar
107 |     Returns a torch-geometric Data object. If `return_all_times` is True,
108 |     returns a list of T torch-geometric Data objects and a T-tensor of times.
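    A usage sketch (hypothetical names; `init_graphs` would be a batched
    torch-geometric Data object whose edges have been noised to `t_limit`):
        >>> graphs = generate_graph_samples(model, diffuser, init_graphs)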
109 | """ 110 | xt = initial_samples 111 | 112 | if return_all_times: 113 | all_xt = [xt] 114 | all_t = torch.arange(t_limit, t_start - 1, step=-1).to(DEVICE) 115 | 116 | # Disable gradient computation in model 117 | model.eval() 118 | torch.set_grad_enabled(False) 119 | 120 | time_steps = torch.arange(t_limit, t_start, step=-1).to(DEVICE) 121 | # (descending order) 122 | 123 | # Step backward through time starting at xt 124 | x = xt 125 | t_iter = tqdm.tqdm(enumerate(time_steps), total=len(time_steps)) if verbose \ 126 | else enumerate(time_steps) 127 | for t_i, time_step in t_iter: 128 | edges = graph_conversions.pyg_data_to_edge_vector(xt) 129 | 130 | t_v = torch.tile( 131 | torch.tensor([time_step], device=DEVICE), (xt.x.shape[0],) 132 | ) # Shape: V 133 | t_e = torch.tile( 134 | torch.tensor([time_step], device=DEVICE), edges.shape 135 | ) # Shape: E 136 | 137 | post = model(xt, t_v) 138 | post_edges = diffuser.reverse_step( 139 | edges[:, None], t_e, post[:, None] 140 | )[:, 0] # Do everything on E x 1 tensors, and then squeeze back 141 | 142 | # Make copy of Data object 143 | xt = xt.clone() 144 | 145 | xt.edge_index = graph_conversions.edge_vector_to_pyg_data( 146 | xt, post_edges 147 | ) 148 | 149 | if return_all_times: 150 | all_xt.append(xt) 151 | 152 | if return_all_times: 153 | return all_xt, all_t 154 | return xt 155 | -------------------------------------------------------------------------------- /src/model/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric 3 | import numpy as np 4 | import networkx as nx 5 | from model.util import sanitize_sacred_arguments 6 | import feature.graph_conversions as graph_conversions 7 | 8 | class GraphLinkGIN(torch.nn.Module): 9 | 10 | def __init__( 11 | self, input_dim, t_limit, num_gnn_layers=4, hidden_dim=10, 12 | time_embed_size=256, virtual_node=False 13 | ): 14 | """ 15 | Initialize a time-dependent GNN which predicts bit probabilities for 16 | each edge. 
17 | Arguments: 18 | `input_dim`: the dimension of the input node features 19 | `t_limit`: maximum time horizon 20 | `num_gnn_layers`: number of GNN layers to have 21 | `hidden_dim`: the dimension of the hidden node embeddings 22 | `time_embed_size`: size of the time embeddings 23 | `virtual_node`: if True, include a virtual node in the architecture 24 | """ 25 | super().__init__() 26 | 27 | self.creation_args = locals() 28 | del self.creation_args["self"] 29 | del self.creation_args["__class__"] 30 | self.creation_args = sanitize_sacred_arguments(self.creation_args) 31 | 32 | self.t_limit = t_limit 33 | self.num_gnn_layers = num_gnn_layers 34 | self.virtual_node = virtual_node 35 | 36 | self.time_embed_dense = torch.nn.Linear(3, time_embed_size) 37 | 38 | self.swish = lambda x: x * torch.sigmoid(x) 39 | self.relu = torch.nn.ReLU() 40 | 41 | # GNN layers 42 | self.gnn_layers = torch.nn.ModuleList() 43 | self.gnn_batch_norms = torch.nn.ModuleList() 44 | for i in range(num_gnn_layers): 45 | gnn_nn = torch.nn.Sequential( 46 | torch.nn.Linear( 47 | input_dim + time_embed_size if i == 0 else hidden_dim, 48 | hidden_dim * 2 49 | ), 50 | self.relu, 51 | torch.nn.Linear(hidden_dim * 2, hidden_dim) 52 | ) 53 | gnn_layer = torch_geometric.nn.GINConv(gnn_nn, train_eps=True) 54 | gnn_batch_norm = torch_geometric.nn.LayerNorm(hidden_dim) 55 | 56 | self.gnn_layers.append(gnn_layer) 57 | self.gnn_batch_norms.append(gnn_batch_norm) 58 | 59 | # Link prediction 60 | self.link_dense = torch.nn.Linear(hidden_dim, 1) 61 | 62 | # Loss 63 | self.bce_loss = torch.nn.BCELoss() 64 | 65 | def forward(self, data, t): 66 | """ 67 | Forward pass of the network. 68 | Arguments: 69 | `data`: a (batched) torch-geometric Data object 70 | `t`: a V-tensor containing the time to train on for each node; note 71 | that the time should be the same for nodes belonging to the same 72 | individual graph 73 | Returns an E-tensor of probabilities of each edge at time t - 1, where E 74 | is the total possible number of edges, and is in canonical ordering. 75 | """ 76 | if self.virtual_node: 77 | # Add a virtual node 78 | data_copy = data.clone() # Don't modify the original 79 | graph_conversions.add_virtual_nodes(data_copy) # Modify `data_copy` 80 | # Also extend `t` by the number of virtual nodes added (use 0, 81 | # since the node is fake) 82 | # TODO: it would be better to use the same time as its host graph 83 | t = torch.cat([ 84 | t, torch.zeros(len(data.ptr) - 1, device=t.device) 85 | ]) 86 | data = data_copy # Use our modified Data object 87 | 88 | # Get the time embeddings for `t` 89 | time_embed_args = t[:, None] / self.t_limit # Shape: V x 1 90 | time_embed = self.swish(self.time_embed_dense( 91 | torch.cat([ 92 | torch.sin(time_embed_args * (np.pi / 2)), 93 | torch.cos(time_embed_args * (np.pi / 2)), 94 | time_embed_args 95 | ], dim=1) 96 | )) # Shape: V x Dt 97 | 98 | # Concatenate initial node features and time embedding 99 | node_embed = torch.cat([data.x.float(), time_embed], dim=1) 100 | # Shape: V x D 101 | 102 | # GNN layers 103 | for i in range(self.num_gnn_layers): 104 | node_embed = self.gnn_batch_norms[i]( 105 | self.gnn_layers[i](node_embed, data.edge_index), 106 | data.batch 107 | ) 108 | 109 | # For all possible edges (i.e. 
node pairs), compute probability
110 |         edge_inds = graph_conversions.edge_vector_to_pyg_data(
111 |             data, 1, reflect=False
112 |         )  # Shape: 2 x E
113 |         node_embed_1 = node_embed[edge_inds[0]]  # Shape: E x D'
114 |         node_embed_2 = node_embed[edge_inds[1]]  # Shape: E x D'
115 |         node_prod = node_embed_1 * node_embed_2
116 | 
117 |         edge_probs = torch.sigmoid(self.link_dense(node_prod))[:, 0]
118 | 
119 |         return edge_probs
120 | 
121 |     def loss(self, pred_probs, true_probs):
122 |         """
123 |         Computes the loss of a batch.
124 |         Arguments:
125 |             `pred_probs`: an E-tensor of predicted probabilities
126 |             `true_probs`: an E-tensor of true probabilities
127 |         Returns a scalar loss value.
128 |         """
129 |         return self.bce_loss(pred_probs, true_probs)
130 | 
131 | 
132 | class GraphAttentionLayer(torch.nn.Module):
133 | 
134 |     def __init__(self, input_dim, num_heads=8, att_hidden_dim=32, dropout=0.1):
135 |         """
136 |         Initialize a graph attention layer which computes attention between all
137 |         nodes.
138 |         Arguments:
139 |             `input_dim`: the dimension of the input node features, D
140 |             `num_heads`: number of attention heads
141 |             `att_hidden_dim`: the dimension of the hidden node embeddings in the
142 |                 GAT
143 |             `dropout`: dropout rate for post-GAT layers
144 |         """
145 | 
146 |         super().__init__()
147 | 
148 |         self.input_dim = input_dim
149 |         self.num_heads = num_heads
150 |         self.att_hidden_dim = att_hidden_dim
151 | 
152 |         self.spectral_dense = torch.nn.Linear(input_dim, input_dim)
153 | 
154 |         self.gat = torch_geometric.nn.GATv2Conv(
155 |             input_dim * 2, att_hidden_dim, heads=num_heads, edge_dim=1
156 |             # edge_dim is 1 for presence/absence of edge
157 |         )
158 | 
159 |         self.dropout_1 = torch.nn.Dropout(dropout)
160 |         self.norm_1 = torch_geometric.nn.LayerNorm(
161 |             att_hidden_dim * num_heads
162 |         )
163 |         self.dense_2 = torch.nn.Linear(
164 |             att_hidden_dim * num_heads, input_dim
165 |         )
166 |         self.dropout_2 = torch.nn.Dropout(dropout)
167 |         self.norm_2 = torch_geometric.nn.LayerNorm(input_dim)
168 | 
169 |         self.dense_3 = torch.nn.Linear(input_dim, input_dim)
170 |         self.dropout_3 = torch.nn.Dropout(dropout)
171 |         self.norm_3 = torch_geometric.nn.LayerNorm(input_dim)
172 | 
173 |         self.relu = torch.nn.ReLU()
174 | 
175 |     def forward(
176 |         self, x, full_edge_index, edge_indicators, batch, spectrum_mats
177 |     ):
178 |         """
179 |         Forward pass of the attention layer.
180 |         Arguments:
181 |             `x`: V x D tensor of node features
182 |             `full_edge_index`: 2 x E tensor of edge indices for the attention
183 |                 mechanism, denoting all possible edges within each subgraph in
184 |                 the batch
185 |             `edge_indicators`: E x 1 tensor of edge features, denoting which
186 |                 edges actually exist in each subgraph (e.g. 1 if the edge
187 |                 exists, 0 otherwise)
188 |             `batch`: V-tensor of which nodes belong to which batch
189 |             `spectrum_mats`: list of B matrices, where each matrix is n x m,
190 |                 where n is the number of nodes in each graph and m is the
191 |                 number of eigenvectors (columns); this matrix transforms _out_
192 |                 of the spectral domain when left-multiplying node features; B is
193 |                 the number of graphs in the batch
194 |         Returns a V x D tensor of updated node features.
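        Schematically, with U = spectrum_mats[i] for graph i, the spectral
        convolution in the body below computes (a sketch of the math, not new
        API):
            x_spec = U^T @ x                     # into eigenbasis, m x d
            x_out = U @ spectral_dense(x_spec)   # mix channels, back to n x d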
195 | """ 196 | # Perform spectral convolution 197 | specconv_out = x.clone() 198 | for i in range(len(spectrum_mats)): 199 | batch_mask = batch == i 200 | x_in = x[batch_mask] # Shape: n x d 201 | mat = spectrum_mats[i] # Shape: n x m 202 | # Transform into spectral domain 203 | x_spec = torch.matmul(torch.transpose(mat, 0, 1), x_in) 204 | # Shape: m x d 205 | # Perform convolution by combining channels across each "frequency" 206 | x_spec_out = self.spectral_dense(x_spec) # Shape: m x d 207 | # Transform back into feature domain 208 | x_out = torch.matmul(mat, x_spec_out) # Shape: n x d 209 | specconv_out[batch_mask] = x_out 210 | 211 | # Attention on x and output of spectral convolution 212 | gat_out = self.gat( 213 | torch.cat([x, specconv_out], dim=1), full_edge_index, 214 | edge_attr=edge_indicators 215 | ) 216 | x_out_1 = self.norm_1(x + self.relu(self.dropout_1(gat_out))) 217 | 218 | # Post-attention dense layers 219 | x_out_2 = self.norm_2(x_out_1 + self.dropout_2(self.dense_2(x_out_1))) 220 | x_out_3 = self.norm_3(x_out_2 + self.dropout_3(self.dense_3(x_out_2))) 221 | return x_out_3 222 | 223 | 224 | class GraphLinkGAT(torch.nn.Module): 225 | 226 | def __init__( 227 | self, input_dim, t_limit, num_gnn_layers=4, gat_num_heads=8, 228 | gat_hidden_dim=32, hidden_dim=256, time_embed_size=256, spectrum_dim=5, 229 | epsilon=1e-6 230 | ): 231 | """ 232 | Initialize a time-dependent GNN which predicts bit probabilities for 233 | each edge. 234 | Arguments: 235 | `input_dim`: the dimension of the input node features 236 | `t_limit`: maximum time horizon 237 | `num_gnn_layers`: number of GNN layers to have 238 | `gat_num_heads`: number of attention heads 239 | `gat_hidden_dim`: the dimension of the hidden node embeddings in the 240 | GAT 241 | `hidden_dim`: size of hidden dimension before and after attention 242 | layers 243 | `time_embed_size`: size of the time embeddings 244 | `spectrum_dim`: number of spectral features to use (i.e. 
the number
245 |                 of eigenvectors to use)
246 |             `epsilon`: small number for numerical stability when computing graph
247 |                 Laplacian
248 |         """
249 |         super().__init__()
250 | 
251 |         self.creation_args = locals()
252 |         del self.creation_args["self"]
253 |         del self.creation_args["__class__"]
254 |         self.creation_args = sanitize_sacred_arguments(self.creation_args)
255 | 
256 |         self.t_limit = t_limit
257 |         self.num_gnn_layers = num_gnn_layers
258 |         self.spectrum_dim = spectrum_dim
259 |         self.epsilon = epsilon
260 | 
261 |         self.time_embed_dense = torch.nn.Linear(3, time_embed_size)
262 | 
263 |         self.swish = lambda x: x * torch.sigmoid(x)
264 |         self.relu = torch.nn.ReLU()
265 | 
266 |         # Pre-GNN linear layers
267 |         self.pregnn_dense_1 = torch.nn.Linear(
268 |             input_dim + time_embed_size, hidden_dim
269 |         )
270 |         self.pregnn_dense_2 = torch.nn.Linear(hidden_dim, hidden_dim)
271 | 
272 |         # GNN layers
273 |         self.gnn_layers = torch.nn.ModuleList()
274 |         for i in range(num_gnn_layers):
275 |             self.gnn_layers.append(GraphAttentionLayer(
276 |                 hidden_dim if i == 0 else gat_num_heads * gat_hidden_dim,
277 |                 num_heads=gat_num_heads, att_hidden_dim=gat_hidden_dim
278 |             ))
279 | 
280 |         # Post-GNN linear layers
281 |         self.postgnn_dense_1 = torch.nn.Linear(
282 |             gat_hidden_dim * gat_num_heads, hidden_dim
283 |         )
284 |         self.postgnn_dense_2 = torch.nn.Linear(hidden_dim, hidden_dim)
285 | 
286 |         # Link prediction
287 |         self.link_dense = torch.nn.Linear(hidden_dim, 1)
288 | 
289 |         # Loss
290 |         self.bce_loss = torch.nn.BCELoss()
291 | 
292 |     def forward(self, data, t):
293 |         """
294 |         Forward pass of the network.
295 |         Arguments:
296 |             `data`: a (batched) torch-geometric Data object
297 |             `t`: a V-tensor containing the time to train on for each node; note
298 |                 that the time should be the same for nodes belonging to the same
299 |                 individual graph
300 |         Returns an E-tensor of probabilities of each edge at time t - 1, where E
301 |         is the total possible number of edges, and is in canonical ordering.
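        The spectral features come from the symmetrically normalized graph
        Laplacian, computed per graph in the body below:
            L = I - D^{-1/2} A D^{-1/2}
        with `epsilon` added to the degrees before the inverse square root for
        numerical stability.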
302 | """ 303 | # Get the time embeddings for `t` 304 | time_embed_args = t[:, None] / self.t_limit # Shape: V x 1 305 | time_embed = self.swish(self.time_embed_dense( 306 | torch.cat([ 307 | torch.sin(time_embed_args * (np.pi / 2)), 308 | torch.cos(time_embed_args * (np.pi / 2)), 309 | time_embed_args 310 | ], dim=1) 311 | )) # Shape: V x Dt 312 | 313 | # Concatenate initial node features and time embedding 314 | node_embed = torch.cat([data.x.float(), time_embed], dim=1) 315 | # Shape: V x D 316 | 317 | # Pre-GNN dense layers on node features 318 | node_embed = self.relu(self.pregnn_dense_1(node_embed)) 319 | node_embed = self.relu(self.pregnn_dense_2(node_embed)) 320 | 321 | # Create edge_index specifying the full dense subgraphs 322 | full_edge_index = graph_conversions.edge_vector_to_pyg_data( 323 | data, 1, reflect=False 324 | ) # Shape: 2 x E 325 | 326 | # Create edge features, which denotes both which edges are real 327 | edge_indicators = \ 328 | graph_conversions.pyg_data_to_edge_vector(data)[:, None] 329 | # Shape: E x 1 330 | 331 | # Compute the Laplacian for each graph in the batch 332 | adj_mat = torch_geometric.utils.to_dense_adj( 333 | data.edge_index, data.batch 334 | ) 335 | deg = torch.sum(adj_mat, dim=2) 336 | sqrt_deg = 1 / torch.sqrt(deg + self.epsilon) 337 | sqrt_deg_mat = torch.diag_embed(sqrt_deg) 338 | 339 | identity = torch.eye(adj_mat.size(1), device=adj_mat.device)[None] 340 | laplacian = identity - \ 341 | torch.matmul(torch.matmul(sqrt_deg_mat, adj_mat), sqrt_deg_mat) 342 | 343 | # Compute spectrum transformation matrix (i.e. eigenvalues/eigenvectors) 344 | # This is done separately for each graph, as each graph may have a 345 | # different size 346 | spectrum_mats = [] 347 | for i, graph_size in enumerate(torch.diff(data.ptr)): 348 | # We only compute the eigendecomposition on the graph-size-limited 349 | # Laplacian, since this function always sorts eigenvalues 350 | evals, evecs = torch.linalg.eigh( 351 | laplacian[i, :graph_size, :graph_size] 352 | ) 353 | # Limit the eigenvectors to the smallest eigenvalues if needed 354 | if self.spectrum_dim < graph_size: 355 | evecs = evecs[:, :self.spectrum_dim] 356 | spectrum_mats.append(evecs) 357 | 358 | # GNN layers 359 | for i in range(self.num_gnn_layers): 360 | node_embed = self.gnn_layers[i]( 361 | node_embed, full_edge_index, edge_indicators, data.batch, 362 | spectrum_mats 363 | ) 364 | 365 | # Post-GNN dense layers on node features 366 | node_embed = self.relu(self.postgnn_dense_1(node_embed)) 367 | node_embed = self.relu(self.postgnn_dense_2(node_embed)) 368 | 369 | # For all possible edges (i.e. node pairs), compute probability 370 | node_embed_1 = node_embed[full_edge_index[0]] # Shape: E x D' 371 | node_embed_2 = node_embed[full_edge_index[1]] # Shape: E x D' 372 | node_prod = node_embed_1 * node_embed_2 373 | 374 | edge_probs = torch.sigmoid(self.link_dense(node_prod))[:, 0] 375 | 376 | return edge_probs 377 | 378 | def loss(self, pred_probs, true_probs): 379 | """ 380 | Computes the loss of a batch. 381 | Arguments: 382 | `pred_probs`: an E-tensor of predicted probabilities 383 | `true_probs`: an E-tensor of true probabilities 384 | Returns a scalar loss value. 
385 | """ 386 | return self.bce_loss(pred_probs, true_probs) 387 | -------------------------------------------------------------------------------- /src/model/image_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from model.util import sanitize_sacred_arguments 4 | 5 | class MNISTProbUNetTimeConcat(torch.nn.Module): 6 | 7 | def __init__( 8 | self, t_limit, enc_dims=[32, 64, 128, 256], dec_dims=[32, 64, 128], 9 | time_embed_size=32, data_channels=1 10 | ): 11 | """ 12 | Initialize a time-dependent U-net for MNIST, where time embeddings are 13 | concatenated to image representations. 14 | Arguments: 15 | `t_limit`: maximum time horizon 16 | `enc_dims`: the number of channels in each encoding layer 17 | `dec_dims`: the number of channels in each decoding layer (given in 18 | reverse order of usage) 19 | `time_embed_size`: size of the time embeddings 20 | `data_channels`: number of channels in input image 21 | """ 22 | super().__init__() 23 | 24 | assert len(enc_dims) == 4 25 | assert len(dec_dims) == 3 26 | 27 | self.creation_args = locals() 28 | del self.creation_args["self"] 29 | del self.creation_args["__class__"] 30 | self.creation_args = sanitize_sacred_arguments(self.creation_args) 31 | 32 | self.t_limit = t_limit 33 | 34 | # Encoders: receptive field increases and depth increases 35 | self.conv_e1 = torch.nn.Conv2d( 36 | data_channels + time_embed_size, enc_dims[0], kernel_size=3, 37 | stride=1, bias=False 38 | ) 39 | self.time_dense_e1 = torch.nn.Linear(2, time_embed_size) 40 | self.norm_e1 = torch.nn.GroupNorm(4, num_channels=enc_dims[0]) 41 | self.conv_e2 = torch.nn.Conv2d( 42 | enc_dims[0] + time_embed_size, enc_dims[1], kernel_size=3, stride=2, 43 | bias=False 44 | ) 45 | self.time_dense_e2 = torch.nn.Linear(2, time_embed_size) 46 | self.norm_e2 = torch.nn.GroupNorm(32, num_channels=enc_dims[1]) 47 | self.conv_e3 = torch.nn.Conv2d( 48 | enc_dims[1] + time_embed_size, enc_dims[2], kernel_size=3, stride=2, 49 | bias=False 50 | ) 51 | self.time_dense_e3 = torch.nn.Linear(2, time_embed_size) 52 | self.norm_e3 = torch.nn.GroupNorm(32, num_channels=enc_dims[2]) 53 | self.conv_e4 = torch.nn.Conv2d( 54 | enc_dims[2] + time_embed_size, enc_dims[3], kernel_size=3, stride=2, 55 | bias=False 56 | ) 57 | self.time_dense_e4 = torch.nn.Linear(2, time_embed_size) 58 | self.norm_e4 = torch.nn.GroupNorm(32, num_channels=enc_dims[3]) 59 | 60 | # Decoders: depth decreases 61 | self.conv_d4 = torch.nn.ConvTranspose2d( 62 | enc_dims[3] + time_embed_size, dec_dims[2], 3, stride=2, bias=False 63 | ) 64 | self.time_dense_d4 = torch.nn.Linear(2, time_embed_size) 65 | self.norm_d4 = torch.nn.GroupNorm(32, num_channels=dec_dims[2]) 66 | self.conv_d3 = torch.nn.ConvTranspose2d( 67 | dec_dims[2] + enc_dims[2] + time_embed_size, dec_dims[1], 3, 68 | stride=2, output_padding=1, bias=False 69 | ) 70 | self.time_dense_d3 = torch.nn.Linear(2, time_embed_size) 71 | self.norm_d3 = torch.nn.GroupNorm(32, num_channels=dec_dims[1]) 72 | self.conv_d2 = torch.nn.ConvTranspose2d( 73 | dec_dims[1] + enc_dims[1] + time_embed_size, dec_dims[0], 3, 74 | stride=2, output_padding=1, bias=False 75 | ) 76 | self.time_dense_d2 = torch.nn.Linear(2, time_embed_size) 77 | self.norm_d2 = torch.nn.GroupNorm(32, num_channels=dec_dims[0]) 78 | self.conv_d1 = torch.nn.ConvTranspose2d( 79 | dec_dims[0] + enc_dims[0], data_channels, 3, stride=1, bias=True 80 | ) 81 | 82 | # Activation functions 83 | self.swish = lambda x: x * torch.sigmoid(x) 84 | 85 | # 
Loss 86 | self.bce_loss = torch.nn.BCELoss() 87 | 88 | def forward(self, xt, t): 89 | """ 90 | Forward pass of the network. 91 | Arguments: 92 | `xt`: B x 1 x H x W tensor containing the images to train on 93 | `t`: B-tensor containing the times to train the network for each 94 | image 95 | Returns a B x 1 x H x W tensor which consists of the prediction. 96 | """ 97 | # Get the time embeddings for `t` 98 | # We embed the time as cos((t/T) * (2pi)) and sin((t/T) * (2pi)) 99 | time_embed_args = (t[:, None] / self.t_limit) * (2 * np.pi) 100 | # Shape: B x 1 101 | time_embed = self.swish( 102 | torch.cat([ 103 | torch.sin(time_embed_args), torch.cos(time_embed_args) 104 | ], dim=1) 105 | ) 106 | # Shape: B x 2 107 | 108 | # Encoding 109 | enc_1_out = self.swish(self.norm_e1(self.conv_e1( 110 | torch.cat([ 111 | xt, 112 | torch.tile( 113 | self.time_dense_e1(time_embed)[:, :, None, None], 114 | (1, 1) + xt.shape[2:] 115 | ) 116 | ], dim=1) 117 | ))) 118 | enc_2_out = self.swish(self.norm_e2(self.conv_e2( 119 | torch.cat([ 120 | enc_1_out, 121 | torch.tile( 122 | self.time_dense_e2(time_embed)[:, :, None, None], 123 | (1, 1) + enc_1_out.shape[2:] 124 | ) 125 | ], dim=1) 126 | ))) 127 | enc_3_out = self.swish(self.norm_e3(self.conv_e3( 128 | torch.cat([ 129 | enc_2_out, 130 | torch.tile( 131 | self.time_dense_e3(time_embed)[:, :, None, None], 132 | (1, 1) + enc_2_out.shape[2:] 133 | ) 134 | ], dim=1) 135 | ))) 136 | enc_4_out = self.swish(self.norm_e4(self.conv_e4( 137 | torch.cat([ 138 | enc_3_out, 139 | torch.tile( 140 | self.time_dense_e4(time_embed)[:, :, None, None], 141 | (1, 1) + enc_3_out.shape[2:] 142 | ) 143 | ], dim=1) 144 | ))) 145 | 146 | # Decoding 147 | dec_4_out = self.swish(self.norm_d4(self.conv_d4( 148 | torch.cat([ 149 | enc_4_out, 150 | torch.tile( 151 | self.time_dense_d4(time_embed)[:, :, None, None], 152 | (1, 1) + enc_4_out.shape[2:] 153 | ) 154 | ], dim=1) 155 | ))) 156 | dec_3_out = self.swish(self.norm_d3(self.conv_d3( 157 | torch.cat([ 158 | dec_4_out, enc_3_out, 159 | torch.tile( 160 | self.time_dense_d3(time_embed)[:, :, None, None], 161 | (1, 1) + dec_4_out.shape[2:] 162 | ) 163 | ], dim=1) 164 | ))) 165 | dec_2_out = self.swish(self.norm_d2(self.conv_d2( 166 | torch.cat([ 167 | dec_3_out, enc_2_out, 168 | torch.tile( 169 | self.time_dense_d2(time_embed)[:, :, None, None], 170 | (1, 1) + dec_3_out.shape[2:] 171 | ) 172 | ], dim=1) 173 | ))) 174 | dec_1_out = self.conv_d1(torch.cat([dec_2_out, enc_1_out], dim=1)) 175 | return torch.sigmoid(dec_1_out) 176 | 177 | def loss(self, pred_probs, true_probs): 178 | """ 179 | Computes the loss of the neural network. 180 | Arguments: 181 | `pred_probs`: a B x 1 x H x W tensor of predicted probabilities 182 | `true_probs`: a B x 1 x H x W tensor of true probabilities 183 | Returns a scalar loss of binary cross-entropy values, averaged across 184 | all dimensions. 185 | """ 186 | return self.bce_loss(pred_probs, true_probs) 187 | 188 | 189 | class MNISTProbUNetTimeAdd(torch.nn.Module): 190 | 191 | def __init__( 192 | self, t_limit, enc_dims=[32, 64, 128, 256], dec_dims=[32, 64, 128], 193 | time_embed_size=32, time_embed_std=30, use_time_embed_dense=False, 194 | data_channels=1 195 | ): 196 | """ 197 | Initialize a time-dependent U-net for MNIST, where time embeddings are 198 | added to image representations. 
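        Times are encoded with fixed random Fourier features: a frequency
        vector z is drawn once from N(0, `time_embed_std`^2) and frozen, and t
        is embedded as [sin(2*pi*(t/T)*z), cos(2*pi*(t/T)*z)] (see
        `time_embed_rand_weights` and `forward` below).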
199 | Arguments: 200 | `t_limit`: maximum time horizon 201 | `enc_dims`: the number of channels in each encoding layer 202 | `dec_dims`: the number of channels in each decoding layer (given in 203 | reverse order of usage) 204 | `time_embed_size`: size of the time embeddings 205 | `time_embed_std`: standard deviation of random weights to sample for 206 | time embeddings 207 | `use_time_embed_dense`: if True, have a dense layer on top of time 208 | embeddings initially 209 | `data_channels`: number of channels in input image 210 | """ 211 | super().__init__() 212 | 213 | assert len(enc_dims) == 4 214 | assert len(dec_dims) == 3 215 | assert time_embed_size % 2 == 0 216 | 217 | self.creation_args = locals() 218 | del self.creation_args["self"] 219 | del self.creation_args["__class__"] 220 | self.creation_args = sanitize_sacred_arguments(self.creation_args) 221 | 222 | self.t_limit = t_limit 223 | self.use_time_embed_dense = use_time_embed_dense 224 | 225 | # Random embedding layer for time; the random weights are set at the 226 | # start and are not trainable 227 | self.time_embed_rand_weights = torch.nn.Parameter( 228 | torch.randn(time_embed_size // 2) * time_embed_std, 229 | requires_grad=False 230 | ) 231 | if use_time_embed_dense: 232 | self.time_embed_dense = torch.nn.Linear( 233 | time_embed_size, time_embed_size 234 | ) 235 | 236 | # Encoders: receptive field increases and depth increases 237 | self.conv_e1 = torch.nn.Conv2d( 238 | data_channels, enc_dims[0], kernel_size=3, stride=1, bias=False 239 | ) 240 | self.time_dense_e1 = torch.nn.Linear(time_embed_size, enc_dims[0]) 241 | self.norm_e1 = torch.nn.GroupNorm(4, num_channels=enc_dims[0]) 242 | self.conv_e2 = torch.nn.Conv2d( 243 | enc_dims[0], enc_dims[1], kernel_size=3, stride=2, bias=False 244 | ) 245 | self.time_dense_e2 = torch.nn.Linear(time_embed_size, enc_dims[1]) 246 | self.norm_e2 = torch.nn.GroupNorm(32, num_channels=enc_dims[1]) 247 | self.conv_e3 = torch.nn.Conv2d( 248 | enc_dims[1], enc_dims[2], kernel_size=3, stride=2, bias=False 249 | ) 250 | self.time_dense_e3 = torch.nn.Linear(time_embed_size, enc_dims[2]) 251 | self.norm_e3 = torch.nn.GroupNorm(32, num_channels=enc_dims[2]) 252 | self.conv_e4 = torch.nn.Conv2d( 253 | enc_dims[2], enc_dims[3], kernel_size=3, stride=2, bias=False 254 | ) 255 | self.time_dense_e4 = torch.nn.Linear(time_embed_size, enc_dims[3]) 256 | self.norm_e4 = torch.nn.GroupNorm(32, num_channels=enc_dims[3]) 257 | 258 | # Decoders: depth decreases 259 | self.conv_d4 = torch.nn.ConvTranspose2d( 260 | enc_dims[3], dec_dims[2], 3, stride=2, bias=False 261 | ) 262 | self.time_dense_d4 = torch.nn.Linear(time_embed_size, dec_dims[2]) 263 | self.norm_d4 = torch.nn.GroupNorm(32, num_channels=dec_dims[2]) 264 | self.conv_d3 = torch.nn.ConvTranspose2d( 265 | dec_dims[2] + enc_dims[2], dec_dims[1], 3, stride=2, 266 | output_padding=1, bias=False 267 | ) 268 | self.time_dense_d3 = torch.nn.Linear(time_embed_size, dec_dims[1]) 269 | self.norm_d3 = torch.nn.GroupNorm(32, num_channels=dec_dims[1]) 270 | self.conv_d2 = torch.nn.ConvTranspose2d( 271 | dec_dims[1] + enc_dims[1], dec_dims[0], 3, stride=2, 272 | output_padding=1, bias=False 273 | ) 274 | self.time_dense_d2 = torch.nn.Linear(time_embed_size, dec_dims[0]) 275 | self.norm_d2 = torch.nn.GroupNorm(32, num_channels=dec_dims[0]) 276 | self.conv_d1 = torch.nn.ConvTranspose2d( 277 | dec_dims[0] + enc_dims[0], data_channels, 3, stride=1, bias=True 278 | ) 279 | 280 | # Activation functions 281 | self.swish = lambda x: x * torch.sigmoid(x) 282 | 283 | # Loss 
284 | self.bce_loss = torch.nn.BCELoss() 285 | 286 | def forward(self, xt, t): 287 | """ 288 | Forward pass of the network. 289 | Arguments: 290 | `xt`: B x 1 x H x W tensor containing the images to train on 291 | `t`: B-tensor containing the times to train the network for each 292 | image 293 | Returns a B x 1 x H x W tensor which consists of the prediction. 294 | """ 295 | # Get the time embeddings for `t` 296 | # We had sampled a vector z from a zero-mean Gaussian of fixed variance 297 | # We embed the time as cos((t/T) * (2pi) * z) and sin((t/T) * (2pi) * z) 298 | time_embed_args = (t[:, None] / self.t_limit) * \ 299 | self.time_embed_rand_weights[None, :] * (2 * np.pi) 300 | # Shape: B x (E / 2) 301 | 302 | time_embed = torch.cat([ 303 | torch.sin(time_embed_args), torch.cos(time_embed_args) 304 | ], dim=1) 305 | if self.use_time_embed_dense: 306 | time_embed = self.swish(self.time_embed_dense(time_embed)) 307 | else: 308 | time_embed = self.swish(time_embed) 309 | # Shape: B x E 310 | 311 | # Encoding 312 | enc_1_out = self.swish(self.norm_e1( 313 | self.conv_e1(xt) + 314 | self.time_dense_e1(time_embed)[:, :, None, None] 315 | )) 316 | enc_2_out = self.swish(self.norm_e2( 317 | self.conv_e2(enc_1_out) + 318 | self.time_dense_e2(time_embed)[:, :, None, None] 319 | )) 320 | enc_3_out = self.swish(self.norm_e3( 321 | self.conv_e3(enc_2_out) + 322 | self.time_dense_e3(time_embed)[:, :, None, None] 323 | )) 324 | enc_4_out = self.swish(self.norm_e4( 325 | self.conv_e4(enc_3_out) + 326 | self.time_dense_e4(time_embed)[:, :, None, None] 327 | )) 328 | 329 | # Decoding 330 | dec_4_out = self.swish(self.norm_d4( 331 | self.conv_d4(enc_4_out) + 332 | self.time_dense_d4(time_embed)[:, :, None, None] 333 | )) 334 | dec_3_out = self.swish(self.norm_d3( 335 | self.conv_d3(torch.cat([dec_4_out, enc_3_out], dim=1)) + 336 | self.time_dense_d3(time_embed)[:, :, None, None] 337 | )) 338 | dec_2_out = self.swish(self.norm_d2( 339 | self.conv_d2(torch.cat([dec_3_out, enc_2_out], dim=1)) + 340 | self.time_dense_d2(time_embed)[:, :, None, None] 341 | )) 342 | dec_1_out = self.conv_d1(torch.cat([dec_2_out, enc_1_out], dim=1)) 343 | return torch.sigmoid(dec_1_out) 344 | 345 | def loss(self, pred_probs, true_probs): 346 | """ 347 | Computes the loss of the neural network. 348 | Arguments: 349 | `pred_probs`: a B x 1 x H x W tensor of predicted probabilities 350 | `true_probs`: a B x 1 x H x W tensor of true probabilities 351 | Returns a scalar loss of binary cross-entropy values, averaged across 352 | all dimensions. 
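        Note that both arguments are soft probability maps: the target here is
        the diffusion posterior itself rather than a hard 0/1 label, which
        `torch.nn.BCELoss` supports directly.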
353 | """ 354 | return self.bce_loss(pred_probs, true_probs) 355 | -------------------------------------------------------------------------------- /src/model/train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import tqdm 4 | import os 5 | import sacred 6 | import model.util as util 7 | import feature.graph_conversions as graph_conversions 8 | import model.generate as generate 9 | import analysis.graph_metrics as graph_metrics 10 | import analysis.mmd as mmd 11 | 12 | 13 | MODEL_DIR = os.environ.get( 14 | "MODEL_DIR", 15 | "/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/misc" 16 | ) 17 | 18 | train_ex = sacred.Experiment("train") 19 | 20 | train_ex.observers.append( 21 | sacred.observers.FileStorageObserver.create(MODEL_DIR) 22 | ) 23 | 24 | # Define device 25 | if torch.cuda.is_available(): 26 | DEVICE = "cuda" 27 | else: 28 | DEVICE = "cpu" 29 | 30 | 31 | @train_ex.config 32 | def config(): 33 | # Number of training epochs 34 | num_epochs = 30 35 | 36 | # Learning rate 37 | learning_rate = 0.001 38 | 39 | 40 | @train_ex.command 41 | def train_model( 42 | model, diffuser, data_loader, num_epochs, learning_rate, _run, t_limit=1000 43 | ): 44 | """ 45 | Trains a diffusion model using the given instantiated model and discrete 46 | diffuser object. 47 | Arguments: 48 | `model`: an instantiated model which takes in x, t and predicts a 49 | posterior 50 | `diffuser`: a DiscreteDiffuser object 51 | `data_loader`: a DataLoader object that yields batches of data as 52 | tensors in pairs: x, y 53 | `class_time_to_branch_index`: function that takes in B-tensors of class 54 | and time and maps to a B-tensor of branch indices 55 | `num_epochs`: number of epochs to train for 56 | `learning_rate`: learning rate to use for training 57 | `t_limit`: training will occur between time 1 and `t_limit` 58 | """ 59 | run_num = _run._id 60 | output_dir = os.path.join(MODEL_DIR, str(run_num)) 61 | 62 | model.train() 63 | torch.set_grad_enabled(True) 64 | optim = torch.optim.Adam(model.parameters(), lr=learning_rate) 65 | 66 | for epoch_num in range(num_epochs): 67 | batch_losses = [] 68 | t_iter = tqdm.tqdm(data_loader) 69 | for x0, y in t_iter: 70 | x0 = x0.to(DEVICE).float() 71 | 72 | # Sample random times between 1 and t_limit (inclusive) 73 | t = torch.randint( 74 | t_limit, size=(x0.shape[0],), device=DEVICE 75 | ) + 1 76 | 77 | # Run diffusion forward to get xt and the posterior parameter to 78 | # predict 79 | xt, true_post = diffuser.forward(x0, t) 80 | 81 | # Get model-predicted posterior parameter 82 | pred_post = model(xt, t) 83 | 84 | # Compute loss 85 | loss = model.loss(pred_post, true_post) 86 | loss_val = loss.item() 87 | t_iter.set_description("Loss: %.4f" % loss_val) 88 | 89 | if np.isnan(loss_val): 90 | continue 91 | 92 | optim.zero_grad() 93 | loss.backward() 94 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 95 | optim.step() 96 | 97 | batch_losses.append(loss_val) 98 | 99 | epoch_loss = np.mean(batch_losses) 100 | print("Epoch %d average Loss: %.4f" % (epoch_num + 1, epoch_loss)) 101 | 102 | _run.log_scalar("train_epoch_loss", epoch_loss) 103 | _run.log_scalar("train_batch_losses", batch_losses) 104 | 105 | model_path = os.path.join( 106 | output_dir, "epoch_%d_ckpt.pth" % (epoch_num + 1) 107 | ) 108 | link_path = os.path.join(output_dir, "last_ckpt.pth") 109 | 110 | # Save model 111 | util.save_model(model, model_path) 112 | 113 | # Create symlink to last epoch 114 | if 
115 |             os.remove(link_path)
116 |         os.symlink(os.path.basename(model_path), link_path)
117 | 
118 | 
119 | @train_ex.command
120 | def train_graph_model(
121 |     model, diffuser, data_loader, num_epochs, learning_rate, _run, t_limit=1000,
122 |     compute_mmd=False, val_data_loader=None, mmd_sample_size=200
123 | ):
124 |     """
125 |     Trains a diffusion model on graphs using the given instantiated model and
126 |     discrete diffuser object.
127 |     Arguments:
128 |         `model`: an instantiated model which takes in x, t and predicts a
129 |             posterior on edges in canonical order
130 |         `diffuser`: a DiscreteDiffuser object
131 |         `data_loader`: a DataLoader object that yields torch-geometric Data
132 |             objects
133 |         `num_epochs`: number of epochs to train for
134 |         `learning_rate`: learning rate to use for training
135 |         `_run`: the Sacred run object, injected automatically by the
136 |             experiment framework
137 |         `t_limit`: training will occur between time 1 and `t_limit`
138 |         `compute_mmd`: if True, compute some performance metrics at the end of
139 |             training
140 |         `val_data_loader`: if `compute_mmd` is True, this must be another data
141 |             loader (like `data_loader`) which yields validation-set objects
142 |         `mmd_sample_size`: number of graphs to compute MMD over
143 |     """
144 |     run_num = _run._id
145 |     output_dir = os.path.join(MODEL_DIR, str(run_num))
146 | 
147 |     model.train()
148 |     torch.set_grad_enabled(True)
149 |     optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
150 | 
151 |     for epoch_num in range(num_epochs):
152 |         batch_losses = []
153 |         t_iter = tqdm.tqdm(data_loader)
154 |         for data in t_iter:
155 | 
156 |             e0, edge_batch_inds = graph_conversions.pyg_data_to_edge_vector(
157 |                 data, return_batch_inds=True
158 |             )  # Shape: E
159 | 
160 |             # Pick some random times t between 1 and t_limit (inclusive), one
161 |             # value for each individual graph
162 |             graph_sizes = torch.diff(data.ptr)
163 |             graph_times = torch.randint(
164 |                 t_limit, size=(graph_sizes.shape[0],), device=DEVICE
165 |             ) + 1
166 | 
167 |             # Tile the graph times to the size of all nodes
168 |             t_v = graph_times[data.batch].float()  # Shape: V
169 |             t_e = graph_times[edge_batch_inds].float()  # Shape: E
170 | 
171 |             # Add noise to edges from time 0 to time t
172 |             et, true_post = diffuser.forward(e0[:, None], t_e)
173 |             # Do the noising on E x 1 tensors
174 |             et, true_post = et[:, 0], true_post[:, 0]
175 |             data.edge_index = graph_conversions.edge_vector_to_pyg_data(
176 |                 data, et
177 |             )
178 |             # Note: this modifies `data`
179 | 
180 |             # Get model-predicted posterior parameter
181 |             pred_post = model(data, t_v)
182 | 
183 |             # Compute loss
184 |             loss = model.loss(pred_post, true_post)
185 |             loss_val = loss.item()
186 |             t_iter.set_description("Loss: %.4f" % loss_val)
187 | 
188 |             if np.isnan(loss_val):
189 |                 continue
190 | 
191 |             optim.zero_grad()
192 |             loss.backward()
193 |             torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
194 |             optim.step()
195 | 
196 |             batch_losses.append(loss_val)
197 | 
198 |         epoch_loss = np.mean(batch_losses)
199 |         print("Epoch %d average Loss: %.4f" % (epoch_num + 1, epoch_loss))
200 | 
201 |         _run.log_scalar("train_epoch_loss", epoch_loss)
202 |         _run.log_scalar("train_batch_losses", batch_losses)
203 | 
204 |         model_path = os.path.join(
205 |             output_dir, "epoch_%d_ckpt.pth" % (epoch_num + 1)
206 |         )
207 |         link_path = os.path.join(output_dir, "last_ckpt.pth")
208 | 
209 |         # Save model
210 |         util.save_model(model, model_path)
211 | 
212 |         # Create symlink to
last epoch 213 | if os.path.islink(link_path): 214 | os.remove(link_path) 215 | os.symlink(os.path.basename(model_path), link_path) 216 | 217 | # If required, compute MMD metrics 218 | if compute_mmd: 219 | num_batches = int(np.ceil(mmd_sample_size / data_loader.batch_size)) 220 | train_graphs_1, train_graphs_2 = [], [] 221 | gen_graphs = [] 222 | data_iter_1, data_iter_2 = iter(data_loader), iter(val_data_loader) 223 | print("Generating %d graphs over %d batches" % ( 224 | mmd_sample_size, num_batches) 225 | ) 226 | for i in range(num_batches): 227 | print("Batch %d/%d" % (i + 1, num_batches)) 228 | data = next(data_iter_1) 229 | train_graphs_1.extend( 230 | graph_conversions.split_pyg_data_to_nx_graphs(data) 231 | ) 232 | data = next(data_iter_2) 233 | train_graphs_2.extend( 234 | graph_conversions.split_pyg_data_to_nx_graphs(data) 235 | ) 236 | edges = graph_conversions.pyg_data_to_edge_vector(data) 237 | sampled_edges = diffuser.sample_prior( 238 | edges.shape[0], # Samples will be E x 1 239 | torch.tile(torch.tensor([t_limit], device=DEVICE), edges.shape) 240 | )[:, 0] # Shape: E 241 | data.edge_index = graph_conversions.edge_vector_to_pyg_data( 242 | data, sampled_edges 243 | ) 244 | 245 | samples = generate.generate_graph_samples( 246 | model, diffuser, data, t_limit=t_limit, verbose=True 247 | ) 248 | gen_graphs.extend( 249 | graph_conversions.split_pyg_data_to_nx_graphs(samples) 250 | ) 251 | 252 | train_graphs_1 = train_graphs_1[:mmd_sample_size] 253 | train_graphs_2 = train_graphs_2[:mmd_sample_size] 254 | gen_graphs = gen_graphs[:mmd_sample_size] 255 | assert len(train_graphs_1) == mmd_sample_size 256 | assert len(train_graphs_2) == mmd_sample_size 257 | assert len(gen_graphs) == mmd_sample_size 258 | all_graphs = train_graphs_1 + train_graphs_2 + gen_graphs 259 | 260 | # Compute MMD values 261 | print("MMD (squared) values:") 262 | square_func = np.square 263 | kernel_type = "gaussian_total_variation" 264 | 265 | degree_hists = mmd.make_histograms( 266 | graph_metrics.get_degrees(all_graphs), bin_width=1 267 | ) 268 | degree_mmd_1 = square_func(mmd.compute_maximum_mean_discrepancy( 269 | degree_hists[:mmd_sample_size], degree_hists[-mmd_sample_size:], 270 | kernel_type, sigma=1 271 | )) 272 | degree_mmd_2 = square_func(mmd.compute_maximum_mean_discrepancy( 273 | degree_hists[:mmd_sample_size], 274 | degree_hists[mmd_sample_size:-mmd_sample_size], 275 | kernel_type, sigma=1 276 | )) 277 | _run.log_scalar("degree_mmd", degree_mmd_1) 278 | _run.log_scalar("degree_mmd_baseline", degree_mmd_2) 279 | print("Degree MMD ratio: %.8f/%.8f = %.8f" % ( 280 | degree_mmd_1, degree_mmd_2, degree_mmd_1 / degree_mmd_2 281 | )) 282 | 283 | cluster_coef_hists = mmd.make_histograms( 284 | graph_metrics.get_clustering_coefficients(all_graphs), num_bins=100 285 | ) 286 | cluster_coef_mmd_1 = square_func(mmd.compute_maximum_mean_discrepancy( 287 | cluster_coef_hists[:mmd_sample_size], 288 | cluster_coef_hists[-mmd_sample_size:], 289 | kernel_type, sigma=0.1 290 | )) 291 | cluster_coef_mmd_2 = square_func(mmd.compute_maximum_mean_discrepancy( 292 | cluster_coef_hists[:mmd_sample_size], 293 | cluster_coef_hists[mmd_sample_size:-mmd_sample_size], 294 | kernel_type, sigma=0.1 295 | )) 296 | _run.log_scalar("cluster_coef_mmd", cluster_coef_mmd_1) 297 | _run.log_scalar("cluster_coef_mmd_baseline", cluster_coef_mmd_2) 298 | print("Clustering coefficient MMD ratio: %.8f/%.8f = %.8f" % ( 299 | cluster_coef_mmd_1, cluster_coef_mmd_2, 300 | cluster_coef_mmd_1 / cluster_coef_mmd_2 301 | )) 302 | 303 | 
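        # For each metric here, the first MMD compares training graphs against
        # generated graphs, and the second ("baseline") MMD compares training
        # graphs against validation graphs; a ratio near 1 means the generated
        # graphs match the training distribution about as well as a held-out
        # split does.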
spectra_hists = mmd.make_histograms( 304 | graph_metrics.get_spectra(all_graphs), 305 | bin_array=np.linspace(-1e-5, 2, 200 + 1) 306 | ) 307 | spectra_mmd_1 = square_func(mmd.compute_maximum_mean_discrepancy( 308 | spectra_hists[:mmd_sample_size], spectra_hists[-mmd_sample_size:], 309 | kernel_type, sigma=1 310 | )) 311 | spectra_mmd_2 = square_func(mmd.compute_maximum_mean_discrepancy( 312 | spectra_hists[:mmd_sample_size], 313 | spectra_hists[mmd_sample_size:-mmd_sample_size], 314 | kernel_type, sigma=1 315 | )) 316 | _run.log_scalar("spectra_mmd", spectra_mmd_1) 317 | _run.log_scalar("spectra_mmd_baseline", spectra_mmd_2) 318 | print("Spectrum MMD ratio: %.8f/%.8f = %.8f" % ( 319 | spectra_mmd_1, spectra_mmd_2, spectra_mmd_1 / spectra_mmd_2 320 | )) 321 | 322 | orbit_counts = graph_metrics.get_orbit_counts(all_graphs) 323 | orbit_counts = np.stack([ 324 | np.mean(counts, axis=0) for counts in orbit_counts 325 | ]) 326 | orbit_mmd_1 = square_func(mmd.compute_maximum_mean_discrepancy( 327 | orbit_counts[:mmd_sample_size], orbit_counts[-mmd_sample_size:], 328 | kernel_type, normalize=False, sigma=30 329 | )) 330 | orbit_mmd_2 = square_func(mmd.compute_maximum_mean_discrepancy( 331 | orbit_counts[:mmd_sample_size], 332 | orbit_counts[mmd_sample_size:-mmd_sample_size], 333 | kernel_type, normalize=False, sigma=30 334 | )) 335 | _run.log_scalar("orbit_mmd", orbit_mmd_1) 336 | _run.log_scalar("orbit_mmd_baseline", orbit_mmd_2) 337 | print("Orbit MMD ratio: %.8f/%.8f = %.8f" % ( 338 | orbit_mmd_1, orbit_mmd_2, orbit_mmd_1 / orbit_mmd_2 339 | )) 340 | -------------------------------------------------------------------------------- /src/model/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def sanitize_sacred_arguments(args): 4 | """ 5 | This function goes through and sanitizes the arguments to native types. 6 | Lists and dictionaries passed through Sacred automatically become 7 | ReadOnlyLists and ReadOnlyDicts. This function will go through and 8 | recursively change them to native lists and dicts. 9 | `args` can be a single token, a list of items, or a dictionary of items. 10 | The return type will be a native token, list, or dictionary. 11 | """ 12 | if isinstance(args, list): # Captures ReadOnlyLists 13 | return [ 14 | sanitize_sacred_arguments(item) for item in args 15 | ] 16 | elif isinstance(args, dict): # Captures ReadOnlyDicts 17 | return { 18 | str(key) : sanitize_sacred_arguments(val) \ 19 | for key, val in args.items() 20 | } 21 | else: # Return the single token as-is 22 | return args 23 | 24 | 25 | def save_model(model, save_path): 26 | """ 27 | Saves the given model at the given path. This saves the state of the model 28 | (i.e. trained layers and parameters), and the arguments used to create the 29 | model (i.e. a dictionary of the original arguments). 30 | """ 31 | save_dict = { 32 | "model_state": model.state_dict(), 33 | "model_creation_args": model.creation_args 34 | } 35 | torch.save(save_dict, save_path) 36 | 37 | 38 | def load_model(model_class, load_path): 39 | """ 40 | Restores a model from the given path. `model_class` must be the class for 41 | which the saved model was created from. This will create a model of this 42 | class, using the loaded creation arguments. It will then restore the learned 43 | parameters to the model. 
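    A usage sketch (hypothetical checkpoint path; any model class saved with
    `save_model` works the same way):
        >>> model = load_model(GraphLinkGIN, "models/last_ckpt.pth")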
44 | """ 45 | load_dict = torch.load(load_path) 46 | model_state = load_dict["model_state"] 47 | model_creation_args = load_dict["model_creation_args"] 48 | model = model_class(**model_creation_args) 49 | model.load_state_dict(model_state) 50 | return model 51 | -------------------------------------------------------------------------------- /src/plot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/GraphGUIDE/dad0dd371268684a5203839441febe0484d8a3e4/src/plot/__init__.py -------------------------------------------------------------------------------- /src/plot/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def plot_mnist_digits( 5 | digits, grid_size=(10, 1), scale=1, clip=True, title=None 6 | ): 7 | """ 8 | Plots MNIST digits. No normalization will be done. 9 | Arguments: 10 | `digits`: a B x 1 x 28 x 28 NumPy array of numbers between 11 | 0 and 1 12 | `grid_size`: a pair of integers denoting the number of digits 13 | to plot horizontally and vertically (in that order); if 14 | more digits are provided than spaces in the grid, leftover 15 | digits will not be plotted; if fewer digits are provided 16 | than spaces in the grid, there will be at most one 17 | unfinished row 18 | `scale`: amount to scale figure size by 19 | `clip`: if True, clip values to between 0 and 1 20 | `title`: if given, title for the plot 21 | """ 22 | digits = np.transpose(digits, (0, 2, 3, 1)) 23 | if clip: 24 | digits = np.clip(digits, 0, 1) 25 | 26 | width, height = grid_size 27 | num_digits = len(digits) 28 | height = min(height, num_digits // width) 29 | 30 | figsize = (width * scale, (height * scale) + 0.5) 31 | 32 | fig, ax = plt.subplots( 33 | ncols=width, nrows=height, 34 | figsize=figsize, gridspec_kw={"wspace": 0, "hspace": 0} 35 | ) 36 | if height == 1: 37 | ax = [ax] 38 | if width == 1: 39 | ax = [[a] for a in ax] 40 | 41 | for j in range(height): 42 | for i in range(width): 43 | index = i + (width * j) 44 | ax[j][i].imshow(digits[index], cmap="gray", aspect="auto") 45 | ax[j][i].axis("off") 46 | if title is not None: 47 | ax[0][0].set_title(title) 48 | plt.subplots_adjust(bottom=0.25) 49 | plt.show() 50 | --------------------------------------------------------------------------------