├── src
├── __init__.py
├── metrics
│ ├── __init__.py
│ ├── train_metrics.py
│ ├── molecular_metrics_discrete.py
│ ├── abstract_metrics.py
│ └── tls_metrics.py
├── models
│ ├── __init__.py
│ ├── layers.py
│ ├── extra_features_molecular.py
│ ├── transformer_model.py
│ └── extra_features.py
├── analysis
│ ├── __init__.py
│ ├── dist_helper.py
│ └── visualization.py
├── datasets
│ ├── __init__.py
│ ├── dataset_utils.py
│ ├── abstract_dataset.py
│ ├── tu_dataset_origin.py
│ ├── tu_dataset.py
│ ├── spectre_dataset.py
│ └── moses_dataset.py
├── flow_matching
│ ├── init.py
│ ├── utils.py
│ ├── flow_matching_utils.py
│ ├── noise_distribution.py
│ ├── time_distorter.py
│ └── rate_matrix.py
├── utils.py
└── main.py
├── configs
├── __init__.py
├── dataset
│ ├── tls.yaml
│ ├── sbm.yaml
│ ├── tree.yaml
│ ├── comm20.yaml
│ ├── planar.yaml
│ ├── moses.yaml
│ ├── zinc.yaml
│ ├── qm9.yaml
│ └── guacamol.yaml
├── config.yaml
├── experiment
│ ├── test.yaml
│ ├── qm9_with_h.yaml
│ ├── qm9_no_h.yaml
│ ├── debug.yaml
│ ├── comm20.yaml
│ ├── planar.yaml
│ ├── tree.yaml
│ ├── sbm.yaml
│ ├── tls.yaml
│ ├── guacamol.yaml
│ ├── zinc.yaml
│ └── moses.yaml
├── sample
│ └── sample_default.yaml
├── train
│ └── train_default.yaml
├── model
│ └── discrete.yaml
└── general
│ └── general_default.yaml
├── images
├── defog.pdf
├── defog.png
├── motivation.pdf
├── qm9_molecule_4.gif
└── sbm_molecule_14.gif
├── docker
├── entrypoint.sh
└── dependencies
│ ├── pip_no_deps.sh
│ ├── apt-runtime.txt
│ └── environment.yaml
├── setup.py
├── requirements.txt
├── environment.yaml
├── LICENSE
├── Dockerfile
├── .gitignore
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/analysis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/flow_matching/init.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configs/dataset/tls.yaml:
--------------------------------------------------------------------------------
1 | name: tls
2 | datadir: 'data/tls'
--------------------------------------------------------------------------------
/configs/dataset/sbm.yaml:
--------------------------------------------------------------------------------
1 | name: sbm
2 | remove_h: null
3 | datadir: 'data/sbm/'
--------------------------------------------------------------------------------
/configs/dataset/tree.yaml:
--------------------------------------------------------------------------------
1 | name: tree
2 | remove_h: null
3 | datadir: 'data/tree/'
--------------------------------------------------------------------------------
/configs/dataset/comm20.yaml:
--------------------------------------------------------------------------------
1 | name: comm20
2 | remove_h: null
3 | datadir: 'data/comm20/'
--------------------------------------------------------------------------------
/configs/dataset/planar.yaml:
--------------------------------------------------------------------------------
1 | name: planar
2 | remove_h: null
3 | datadir: 'data/planar/'
--------------------------------------------------------------------------------
/images/defog.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manuelmlmadeira/DeFoG/HEAD/images/defog.pdf
--------------------------------------------------------------------------------
/images/defog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manuelmlmadeira/DeFoG/HEAD/images/defog.png
--------------------------------------------------------------------------------
/images/motivation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manuelmlmadeira/DeFoG/HEAD/images/motivation.pdf
--------------------------------------------------------------------------------
/images/qm9_molecule_4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manuelmlmadeira/DeFoG/HEAD/images/qm9_molecule_4.gif
--------------------------------------------------------------------------------
/images/sbm_molecule_14.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manuelmlmadeira/DeFoG/HEAD/images/sbm_molecule_14.gif
--------------------------------------------------------------------------------
/docker/entrypoint.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Container entrypoint: run the requested command inside the "defog" conda
# environment.

# Abort immediately if any setup step below fails.
set -e

# Load conda's shell integration so that `conda activate` works in this
# non-interactive shell, then switch to the project environment.
. /opt/conda/etc/profile.d/conda.sh
conda activate defog

# Replace this shell with the command passed to the container, forwarding
# every argument unchanged.
exec "$@"
--------------------------------------------------------------------------------
/configs/dataset/moses.yaml:
--------------------------------------------------------------------------------
1 | name: 'moses'
2 | datadir: 'data/moses/moses_pyg/' # Relative to the moses_dataset.py file
3 | remove_h: null
4 | filter: False
5 | compute_fcd: False # Compute FCD scores
--------------------------------------------------------------------------------
/configs/dataset/zinc.yaml:
--------------------------------------------------------------------------------
1 | name: 'zinc' # qm9, qm9_positional
2 | datadir: 'data/zinc/'
3 | remove_h: True
4 | random_subset: null
5 | pin_memory: False
6 | compute_fcd: True # Compute FCD scores
7 | aromatic: False
--------------------------------------------------------------------------------
/configs/dataset/qm9.yaml:
--------------------------------------------------------------------------------
1 | name: 'qm9' # qm9, qm9_positional
2 | datadir: 'data/qm9/qm9_pyg/'
3 | remove_h: True
4 | random_subset: null
5 | pin_memory: False
6 | compute_fcd: True # Compute FCD scores
7 | aromatic: True
--------------------------------------------------------------------------------
/docker/dependencies/pip_no_deps.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Install pip packages that must be installed WITHOUT their dependencies into
# the "defog" conda environment.

# Stop at the first failure: if the environment cannot be activated we must not
# fall through and install into whichever Python happens to be on PATH.
set -e

# Activate the conda environment so that the installation happens within the
# environment. The Dockerfile installs conda under CONDA_INSTALL_PATH
# (/opt/conda), which is also the location used by docker/entrypoint.sh and by
# the `prefix` of docker/dependencies/environment.yaml; fall back to that path
# when the variable is not set.
source "${CONDA_INSTALL_PATH:-/opt/conda}/etc/profile.d/conda.sh"
conda activate defog

# Install fcd within the activated environment
pip install fcd==1.2 --no-deps
--------------------------------------------------------------------------------
/configs/dataset/guacamol.yaml:
--------------------------------------------------------------------------------
1 | name: 'guacamol'
2 | datadir: 'data/guacamol/guacamol_pyg/' # Relative to the guacamol_dataset.py file
3 | remove_h: null
4 | filter: True # Use the filtered version or the raw guacamol file
5 | compute_fcd: False # Compute FCD scores
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Distribution metadata for the DeFoG package, collected in one mapping and
# handed to setuptools in a single call.
_PACKAGE_METADATA = {
    "name": "defog",
    "version": "1.0.0",
    "url": None,
    "author": "Yiming Qin, Manuel Madeira et al.",
    "author_email": "",
    "description": "DeFoG: Discrete flow matching for graph generation",
    # Discover packages automatically from the directory holding this file.
    "packages": find_packages(),
}

setup(**_PACKAGE_METADATA)
12 |
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - general : general_default
4 | - model : discrete
5 | - train : train_default
6 | - dataset : qm9
7 | - sample : sample_default
8 |
9 | hydra:
10 | job:
11 | chdir: True
12 | run:
13 | dir: ../outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}-${dataset.name}-${general.name}
14 |
15 |
--------------------------------------------------------------------------------
/configs/experiment/test.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'test'
4 | gpus : 1
5 | wandb: 'disabled'
6 | sample_every_val: 1
7 | samples_to_generate: 10
8 | samples_to_save: 5
9 | chains_to_save: 1
10 | test_only: False
11 | train:
12 | ema_decay : 0
13 | batch_size: 3
14 | save_model: False
15 | n_epochs : 3
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core==1.3.2
2 | imageio==2.31.1
3 | matplotlib==3.7.1
4 | networkx==2.8.7
5 | numpy==1.23
6 | omegaconf==2.3.0
7 | overrides==7.3.1
8 | pandas==1.4
9 | pyemd==1.0.0
10 | PyGSP==0.5.1
11 | pytorch_lightning==2.0.4
12 | scipy==1.11.0
13 | setuptools==68.0.0
14 | torch_geometric==2.3.1
15 | torchmetrics==0.11.4
16 | tqdm==4.65.0
17 | wandb==0.15.4
18 | seaborn
--------------------------------------------------------------------------------
/docker/dependencies/apt-runtime.txt:
--------------------------------------------------------------------------------
1 | build-essential
2 | ca-certificates
3 | pkg-config
4 | tzdata
5 | libsm6
6 | libxext-dev
7 | libxrender1
8 | libcurl3-dev
9 | libfreetype6-dev
10 | libzmq3-dev
11 | libcupti-dev
12 | libjpeg-dev
13 | libpng-dev
14 | zlib1g-dev
15 | locales
16 | rsync
17 | cmake
18 | g++
19 | swig
20 | vim
21 | nano
22 | curl
23 | wget
24 | unzip
25 | zsh
26 | git
27 | tmux
28 | htop
29 | tree
--------------------------------------------------------------------------------
/configs/experiment/qm9_with_h.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'qm9h'
4 | gpus : 1
5 | wandb: 'online'
6 | evaluate_all_checkpoints: False
7 | test_only: null
8 | train:
9 | n_epochs: 2000
10 | batch_size: 512
11 | save_model: True
12 | sample:
13 | sample_steps: 500
14 | time_distortion: 'polydec'
15 | omega: 0.05
16 | eta: 0
17 | model:
18 | n_layers: 7
19 | dataset:
20 | remove_h: False
21 |
--------------------------------------------------------------------------------
/configs/sample/sample_default.yaml:
--------------------------------------------------------------------------------
1 | # balanced rate matrix settings
2 | eta: 0.
3 | omega: 0.
4 |
5 | # generation settings
6 | sample_steps: 1000
7 | time_distortion: "identity" # 'identity', 'cosine', polyinc, polydec
8 | search: False # 'all' | 'target_guidance' | 'distortion' | 'stochasticity' | False
9 |
10 | # fixed
11 | rdb: 'general' # general | column | entry
12 | rdb_crit: dummy # max_marginal | x_t | p_x1_g_xt | x_1 | p_xt_g_x1 | p_xtdt_g_xt | x_0 | xhat_t
13 | # abs_state
14 |
--------------------------------------------------------------------------------
/configs/train/train_default.yaml:
--------------------------------------------------------------------------------
1 | # Training settings
2 | n_epochs: 1000
3 | batch_size: 512
4 | lr: 0.0002
5 | clip_grad: null # float, null to disable
6 | save_model: True
7 | num_workers: 0
8 | ema_decay: 0 # 'Amount of EMA decay, 0 means off. A reasonable value is 0.999.'
9 | progress_bar: false
10 | weight_decay: 1e-12
11 | optimizer: adamw # adamw,nadamw,nadam => nadamw for large batches, see http://arxiv.org/abs/2102.06356 for the use of nesterov momentum with large batches
12 | seed: 0
13 |
14 | time_distortion: "identity" # 'identity', 'cosine', polyinc, polydec
15 |
--------------------------------------------------------------------------------
/configs/experiment/qm9_no_h.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'qm9_no_h'
4 | gpus : 1
5 | wandb: 'online'
6 | test_only: null
7 | evaluate_all_checkpoints: False
8 | # debug
9 | check_val_every_n_epochs: 50
10 | sample_every_val: 1
11 | train:
12 | n_epochs: 1000
13 | batch_size: 1024
14 | save_model: True
15 | sample:
16 | sample_steps: 500
17 | time_distortion: 'polydec'
18 | omega: 0
19 | eta: 0
20 | model:
21 | n_layers: 9
22 | transition: marginal
23 | dataset:
24 | remove_h: True
25 | pin_memory: True
26 | num_workers: 16
27 |
--------------------------------------------------------------------------------
/configs/experiment/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'debug'
4 | gpus : 0
5 | wandb: 'disabled'
6 | sample_every_val: 1
7 | samples_to_generate: 4
8 | samples_to_save: 2
9 | chains_to_save: 1
10 | remove_h: True
11 | number_chain_steps: 10 # Number of frames in each gif
12 | check_val_every_n_epochs: 1
13 | train:
14 | batch_size: 4
15 | save_model: False
16 | n_epochs: 2
17 | model:
18 | n_layers: 2
19 | hidden_mlp_dims: {'X': 17, 'E': 18, 'y': 19 }
20 | hidden_dims: {'dx': 20, 'de': 21, 'dy': 22, 'n_head': 5, 'dim_ffX': 23, 'dim_ffE': 24, 'dim_ffy': 25}
21 | sample:
22 | sample_steps: 10
23 |
--------------------------------------------------------------------------------
/configs/model/discrete.yaml:
--------------------------------------------------------------------------------
1 | # Model settings
2 | transition: 'marginal' # uniform, marginal, argmax, absorbfirst, absorbing
3 | model: 'graph_tf'
4 | n_layers: 5
5 |
6 | extra_features: 'rrwp' # 'all', 'cycles', 'eigenvalues', 'rrwp', 'rrwp_comp' or null
7 | rrwp_steps: 12
8 |
9 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly
10 | hidden_mlp_dims: {'X': 256, 'E': 128, 'y': 128}
11 |
12 | # The dimensions should satisfy dx % n_head == 0
13 | hidden_dims : {'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}
14 |
15 | # training weight for edges, y, and nodes
16 | lambda_train: [5, 0] # X=1, E = lambda[0], y = lambda[1]
17 |
--------------------------------------------------------------------------------
/configs/experiment/comm20.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'comm20'
4 | gpus : 1
5 | wandb: 'online'
6 | resume: null # If resume, path to ckpt file from outputs directory in main directory
7 | test_only: null
8 | # test_only: /home/yqin/coding/graph_dfm/outputs/2024-04-25/18-10-09-max_64Temb/checkpoints/max_64Temb/epoch=287999.ckpt
9 | check_val_every_n_epochs: 1000
10 | sample_every_val: 10
11 | samples_to_generate: 20
12 | samples_to_save: 20
13 | chains_to_save: 1
14 | log_every_steps: 50
15 | number_chain_steps: 50 # Number of frames in each gif
16 | final_model_samples_to_generate: 20
17 | final_model_samples_to_save: 10
18 | final_model_chains_to_save: 10
19 | train:
20 | n_epochs: 1000000
21 | batch_size: 256
22 | save_model: True
23 | model:
24 | n_layers: 8
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: defog
2 | channels:
3 | - conda-forge
4 | - nvidia
5 | - pytorch
6 | dependencies:
7 | - python=3.9
8 | - rdkit=2023.03.2
9 | - graph-tool=2.45
10 | - pytorch==2.4.0
11 | - pytorch-cuda=12.1
12 | - psi4=1.9.1
13 | - pip=24.2
14 | - pip:
15 | - ipykernel==6.29.5
16 | - hydra-core==1.3.2
17 | - imageio==2.31.1
18 | - matplotlib==3.7.1
19 | - networkx==2.8.7
20 | - numpy==1.23
21 | - omegaconf==2.3.0
22 | - overrides==7.3.1
23 | - pandas==1.4
24 | - pyemd==1.0.0
25 | - PyGSP==0.5.1
26 | - pytorch_lightning==2.0.4
27 | - scipy==1.11.0
28 | - setuptools==68.0.0
29 | - torch_geometric==2.3.1
30 | - torchmetrics==0.11.4
31 | - tqdm==4.65.0
32 | - wandb==0.15.4
33 | - seaborn==0.13.2
34 | - gpustat==0.6.0
35 | - black==24.3.0
36 | - -e .
37 |
--------------------------------------------------------------------------------
/docker/dependencies/environment.yaml:
--------------------------------------------------------------------------------
1 | name: defog
2 | channels:
3 | - conda-forge
4 | - nvidia
5 | - pytorch
6 | dependencies:
7 | - python=3.9
8 | - rdkit=2023.03.2
9 | - graph-tool=2.45
10 | - pytorch==2.4.0
11 | - pytorch-cuda=12.1
12 | - psi4=1.9.1
13 | - pip=24.2
14 | - pip:
15 | - ipykernel==6.29.5
16 | - hydra-core==1.3.2
17 | - imageio==2.31.1
18 | - matplotlib==3.7.1
19 | - networkx==2.8.7
20 | - numpy==1.23
21 | - omegaconf==2.3.0
22 | - overrides==7.3.1
23 | - pandas==1.4
24 | - pyemd==1.0.0
25 | - PyGSP==0.5.1
26 | - pytorch_lightning==2.0.4
27 | - scipy==1.11.0
28 | - setuptools==68.0.0
29 | - torch_geometric==2.3.1
30 | - torchmetrics==0.11.4
31 | - tqdm==4.65.0
32 | - wandb==0.15.4
33 | - seaborn==0.13.2
34 | - gpustat==0.6.0
35 | - black==24.3.0
36 | prefix: /opt/conda/envs/defog
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2012-2025 Manuel Madeira, Yiming Qin
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining
4 | a copy of this software and associated documentation files (the
5 | "Software"), to deal in the Software without restriction, including
6 | without limitation the rights to use, copy, modify, merge, publish,
7 | distribute, sublicense, and/or sell copies of the Software, and to
8 | permit persons to whom the Software is furnished to do so, subject to
9 | the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be
12 | included in all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/configs/experiment/planar.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'planar'
4 | gpus : 1
5 | wandb: 'online'
6 | resume: null # If resume, path to ckpt file from outputs directory in main directory
7 | test_only: null
8 | check_val_every_n_epochs: 2000
9 | sample_every_val: 1
10 | samples_to_generate: 40
11 | samples_to_save: 9
12 | chains_to_save: 1
13 | final_model_samples_to_generate: 40
14 | final_model_samples_to_save: 30
15 | final_model_chains_to_save: 20
16 | sample_steps: 1000
17 | train:
18 | n_epochs: 100000
19 | batch_size: 64
20 | save_model: True
21 | sample:
22 | time_distortion: 'polydec'
23 | omega: 0.05
24 | eta: 50
25 | model:
26 | n_layers: 10
27 |
28 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly
29 | # At the moment (03/08), y contains quite little information
30 | hidden_mlp_dims: { 'X': 128, 'E': 64, 'y': 128 }
31 |
32 | # The dimensions should satisfy dx % n_head == 0
33 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 64, 'dim_ffy': 256 }
34 |
--------------------------------------------------------------------------------
/configs/experiment/tree.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'tree'
4 | gpus : 1
5 | wandb: 'online'
6 | resume: null # If resume, path to ckpt file from outputs directory in main directory
7 | test_only: null
8 | check_val_every_n_epochs: 2000
9 | sample_every_val: 1
10 | samples_to_generate: 40
11 | samples_to_save: 9
12 | chains_to_save: 1
13 | final_model_samples_to_generate: 40
14 | final_model_samples_to_save: 30
15 | final_model_chains_to_save: 20
16 | sample_steps: 1000
17 | train:
18 | n_epochs: 100000
19 | batch_size: 64
20 | save_model: True
21 | time_distortion: 'polydec'
22 | sample:
23 | time_distortion: 'polydec'
24 | omega: 0
25 | eta: 0
26 | model:
27 | n_layers: 10
28 |
29 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly
30 | # At the moment (03/08), y contains quite little information
31 | hidden_mlp_dims: { 'X': 128, 'E': 64, 'y': 128 }
32 |
33 | # The dimensions should satisfy dx % n_head == 0
34 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 64, 'dim_ffy': 256 }
35 |
--------------------------------------------------------------------------------
/configs/experiment/sbm.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'sbm'
4 | gpus : 1
5 | wandb: 'online'
6 | resume: null # If resume, path to ckpt file from outputs directory in main directory
7 | test_only: null
8 | check_val_every_n_epochs: 2000
9 | sample_every_val: 1
10 | samples_to_generate: 40
11 | samples_to_save: 9
12 | chains_to_save: 1
13 | final_model_samples_to_generate: 40
14 | final_model_samples_to_save: 30
15 | final_model_chains_to_save: 20
16 | sample_steps: 1000
17 | train:
18 | n_epochs: 50000
19 | batch_size: 32
20 | save_model: True
21 | sample:
22 | time_distortion: 'identity'
23 | omega: 0
24 | eta: 0
25 | model:
26 | transition: 'absorbfirst'
27 | n_layers: 8
28 | rrwp_steps: 20
29 |
30 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly
31 | # At the moment (03/08), y contains quite little information
32 | hidden_mlp_dims: { 'X': 128, 'E': 64, 'y': 128 }
33 |
34 | # The dimensions should satisfy dx % n_head == 0
35 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 64, 'dim_ffy': 256 }
--------------------------------------------------------------------------------
/configs/general/general_default.yaml:
--------------------------------------------------------------------------------
1 | # General settings
2 | name: 'graph-tf-model' # Warning: 'debug' and 'test' are reserved names that have special behavior
3 |
4 | wandb: 'online' # online | offline | disabled
5 | gpus: 1 # Multi-gpu is not implemented on this branch
6 |
7 | resume: null # If resume, path to ckpt file from outputs directory in main directory
8 | test_only: null # Use absolute path
9 |
10 | check_val_every_n_epochs: 5
11 | sample_every_val: 4
12 | val_check_interval: null
13 | samples_to_generate: 512 # We advise to set it to 2 x batch_size maximum
14 | samples_to_save: 20
15 | chains_to_save: 1
16 | log_every_steps: 50
17 | number_chain_steps: 50 # Number of frames in each gif
18 |
19 | # Test
20 | generated_path: null
21 | final_model_samples_to_generate: 10000
22 | final_model_samples_to_save: 30
23 | final_model_chains_to_save: 20
24 | num_sample_fold: 1
25 | evaluate_all_checkpoints: False
26 | save_samples: True # Save samples at the final test step or not, normally only used at the last epoch or during inference
27 |
28 | # Conditional Generation
29 | conditional: False
30 | target: 'k2'
31 | guidance_weight: 2.0
32 |
--------------------------------------------------------------------------------
/configs/experiment/tls.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'tls'
4 | gpus : 1
5 | wandb: 'online'
6 | resume: null # If resume, path to ckpt file from outputs directory in main directory
7 | test_only: null
8 | check_val_every_n_epochs: 2000
9 | sample_every_val: 1
10 | samples_to_generate: 40
11 | samples_to_save: 9
12 | chains_to_save: 1
13 | final_model_samples_to_generate: 80 # same as test set
14 | final_model_samples_to_save: 30
15 | final_model_chains_to_save: 20
16 | sample_steps: 1000
17 | conditional: True
18 | target: 'k2'
19 | guidance_weight: 2.0
20 | train:
21 | n_epochs: 100000
22 | batch_size: 64
23 | save_model: True
24 | sample:
25 | time_distortion: 'polydec'
26 | omega: 0.05
27 | eta: 0
28 | model:
29 | n_layers: 10
30 | rrwp_steps: 20
31 |
32 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly
33 | # At the moment (03/08), y contains quite little information
34 | hidden_mlp_dims: { 'X': 128, 'E': 64, 'y': 128 }
35 |
36 | # The dimensions should satisfy dx % n_head == 0
37 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 64, 'dim_ffy': 256 }
38 |
--------------------------------------------------------------------------------
/configs/experiment/guacamol.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'guacamol'
4 | gpus : 1
5 | wandb: 'online'
6 | resume: null
7 | test_only: null
8 | check_val_every_n_epochs: 2
9 | val_check_interval: null
10 | sample_every_val: 2
11 | samples_to_generate: 500
12 | samples_to_save: 20
13 | chains_to_save: 5
14 | log_every_steps: 50
15 | final_model_samples_to_generate: 18000
16 | final_model_samples_to_save: 10
17 | final_model_chains_to_save: 5
18 | train:
19 | optimizer: adam
20 | n_epochs: 1000
21 | batch_size: 64
22 | save_model: True
23 | lr: 2e-4
24 | time_distortion: 'polydec'
25 | sample:
26 | time_distortion: 'polydec'
27 | omega: 0.1
28 | eta: 300
29 | model:
30 | n_layers: 12
31 | type: 'discrete'
32 | transition: 'marginal' # uniform or marginal
33 | model: 'graph_tf'
34 |
35 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly
36 | # At the moment (03/08), y contains quite little information
37 | hidden_mlp_dims: {'X': 256, 'E': 128, 'y': 256}
38 | rrwp_steps: 20
39 |
40 | # The dimensions should satisfy dx % n_head == 0
41 | hidden_dims: {'dx': 256, 'de': 64, 'dy': 128, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 256}
42 |
--------------------------------------------------------------------------------
/configs/experiment/zinc.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'zinc'
4 | gpus : 1
5 | wandb: 'online'
6 | remove_h: True
7 | resume: null
8 | test_only: null
9 | check_val_every_n_epochs: 4
10 | val_check_interval: null
11 | sample_every_val: 2
12 | samples_to_generate: 256
13 | samples_to_save: 20
14 | chains_to_save: 5
15 | log_every_steps: 50
16 |
17 | final_model_samples_to_generate: 10000
18 | final_model_samples_to_save: 50
19 | final_model_chains_to_save: 20
20 | train:
21 | optimizer: adamw
22 | n_epochs: 300
23 | batch_size: 256
24 | save_model: True
25 | lr: 2e-4
26 | num_workers: 4
27 | time_distortion: 'polydec'
28 | sample:
29 | time_distortion: 'polydec'
30 | omega: 0.
31 | eta: 0.
32 | model:
33 | n_layers: 12
34 | type: 'discrete'
35 | transition: 'marginal' # uniform or marginal
36 | model: 'graph_tf'
37 |
38 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly
39 | # At the moment (03/08), y contains quite little information
40 | hidden_mlp_dims: { 'X': 256, 'E': 128, 'y': 256}
41 | rrwp_steps: 20
42 |
43 | # The dimensions should satisfy dx % n_head == 0
44 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 128, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 256}
45 |
--------------------------------------------------------------------------------
/configs/experiment/moses.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | general:
3 | name : 'moses'
4 | gpus : 1
5 | wandb: 'online'
6 | remove_h: True
7 | resume: null
8 | test_only: null
9 | check_val_every_n_epochs: 1
10 | val_check_interval: null
11 | sample_every_val: 4
12 | samples_to_generate: 256
13 | samples_to_save: 20
14 | chains_to_save: 5
15 | log_every_steps: 50
16 |
17 | final_model_samples_to_generate: 25000
18 | final_model_samples_to_save: 50
19 | final_model_chains_to_save: 20
20 |
21 | train:
22 | optimizer: adamw
23 | n_epochs: 300
24 | batch_size: 256
25 | save_model: True
26 | lr: 2e-4
27 | num_workers: 4
28 | time_distortion: 'polydec'
29 | sample:
30 | time_distortion: 'polydec'
31 | omega: 0.5
32 | eta: 200
33 | model:
34 | n_layers: 12
35 | type: 'discrete'
36 | transition: 'marginal' # uniform or marginal
37 | model: 'graph_tf'
38 |
39 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly
40 | # At the moment (03/08), y contains quite little information
41 | hidden_mlp_dims: { 'X': 256, 'E': 128, 'y': 256}
42 | rrwp_steps: 20
43 |
44 | # The dimensions should satisfy dx % n_head == 0
45 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 128, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 256}
46 |
--------------------------------------------------------------------------------
/src/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
class Xtoy(nn.Module):
    """Pool node features into a single global feature vector."""

    def __init__(self, dx, dy):
        """Build the linear projection applied to the pooled node statistics.

        dx: size of each node feature vector.
        dy: size of the global feature vector that is produced.
        """
        super().__init__()
        # Four statistics (mean, min, max, std) are stacked side by side,
        # hence the 4 * dx input width.
        self.lin = nn.Linear(4 * dx, dy)

    def forward(self, X):
        """Reduce X over dim 1 and project the result.

        X: bs, n, dx.
        Returns the output of the linear layer applied to the horizontally
        stacked mean / min / max / std of X along dim 1.
        """
        node_stats = (
            X.mean(dim=1),
            X.min(dim=1).values,
            X.max(dim=1).values,
            X.std(dim=1),
        )
        return self.lin(torch.hstack(node_stats))
20 |
21 |
class Etoy(nn.Module):
    """Pool edge features into a single global feature vector."""

    def __init__(self, d, dy):
        """Build the linear projection applied to the pooled edge statistics.

        d: size of each edge feature vector.
        dy: size of the global feature vector that is produced.
        """
        super().__init__()
        # Four statistics (mean, min, max, std) are stacked side by side,
        # hence the 4 * d input width.
        self.lin = nn.Linear(4 * d, dy)

    def forward(self, E):
        """Reduce E over dims 1 and 2 and project the result.

        E: bs, n, n, de
        Features relative to the diagonal of E could potentially be added.
        """
        edge_mean = E.mean(dim=(1, 2))
        # min / max are reduced one dimension at a time: dim 2 first, then dim 1.
        edge_min = E.min(dim=2).values.min(dim=1).values
        edge_max = E.max(dim=2).values.max(dim=1).values
        edge_std = E.std(dim=(1, 2))
        pooled = torch.hstack((edge_mean, edge_min, edge_max, edge_std))
        return self.lin(pooled)
39 |
40 |
def masked_softmax(x, mask, **kwargs):
    """Softmax over x that ignores the positions where mask equals 0.

    x: tensor of scores; it is not modified (a clone is masked instead).
    mask: tensor used to index x; entries equal to 0 are excluded.
    kwargs: forwarded to the softmax call (e.g. ``dim``).

    When the mask sums to 0, x is returned as is, without any normalisation.
    """
    if mask.sum() == 0:
        return x
    scores = x.clone()
    # Excluded positions get -inf so that they receive zero probability.
    scores[mask == 0] = float("-inf")
    return scores.softmax(**kwargs)
47 |
--------------------------------------------------------------------------------
/src/flow_matching/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
3 |
def p_xt_g_x1(X1, E1, t, limit_dist):
    """Interpolate between the limit distribution and the one-hot clean data.

    X1: integer class indices of the nodes (one-hot encoded with
        ``len(limit_dist.X)`` classes).
    E1: integer class indices of the edges (one-hot encoded with
        ``len(limit_dist.E)`` classes).
    t: time tensor; its last dimension is squeezed and the result is
        broadcast over the remaining dimensions
        (presumably shape (B, 1) -- TODO confirm against callers).
    limit_dist: object exposing the tensors ``X`` and ``E``. Both attributes
        are reassigned in place to their copy on ``X1.device``.

    Returns the pair (Xt, Et), each clamped to [0, 1], with one extra trailing
    class dimension compared to X1 / E1.
    """
    target_device = X1.device
    # Side effect kept on purpose: the caller's limit_dist ends up on the
    # same device as the data.
    limit_dist.X = limit_dist.X.to(target_device)
    limit_dist.E = limit_dist.E.to(target_device)
    prior_X = limit_dist.X
    prior_E = limit_dist.E

    # Time factors shaped to broadcast against node and edge tensors.
    node_t = t.squeeze(-1)[:, None, None]
    edge_t = node_t[:, None]

    X1_onehot = F.one_hot(X1, num_classes=len(prior_X)).float()
    E1_onehot = F.one_hot(E1, num_classes=len(prior_E)).float()

    Xt = node_t * X1_onehot + (1 - node_t) * prior_X[None, None, :]
    Et = edge_t * E1_onehot + (1 - edge_t) * prior_E[None, None, None, :]

    # Both results must sum to 1 over the class dimension.
    nodes_normalised = ((Xt.sum(-1) - 1).abs() < 1e-4).all()
    edges_normalised = ((Et.sum(-1) - 1).abs() < 1e-4).all()
    assert nodes_normalised and edges_normalised

    Xt = Xt.clamp(min=0.0, max=1.0)
    Et = Et.clamp(min=0.0, max=1.0)
    return Xt, Et
27 |
28 |
def dt_p_xt_g_x1(X1, E1, limit_dist):
    """Return ``onehot(label) - limit_dist`` for nodes and edges.

    X1 / E1 are integer class labels (they are passed to ``F.one_hot``).
    Original note: x1 is (B, D), the result is (B, D, S).

    Side effect: ``limit_dist.X`` and ``limit_dist.E`` are reassigned to
    tensors on ``X1``'s device.
    """
    target_device = X1.device
    limit_dist.X = limit_dist.X.to(target_device)
    limit_dist.E = limit_dist.E.to(target_device)

    node_onehot = F.one_hot(X1, num_classes=len(limit_dist.X)).float()
    edge_onehot = F.one_hot(E1, num_classes=len(limit_dist.E)).float()

    dX = node_onehot - limit_dist.X[None, None, :]
    dE = edge_onehot - limit_dist.E[None, None, None, :]

    # A difference of two distributions sums to zero along the class axis.
    nodes_balanced = (dX.sum(-1).abs() < 1e-4).all()
    edges_balanced = (dE.sum(-1).abs() < 1e-4).all()
    assert nodes_balanced and edges_balanced

    return dX, dE
45 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Base image
2 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
3 |
4 | # Avoid user interaction
5 | ENV DEBIAN_FRONTEND=noninteractive
6 |
7 | # Environment Variables
8 | ENV CONDA_URL=https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-23.3.1-1-Linux-x86_64.sh
9 | ENV CONDA_INSTALL_PATH=/opt/conda
10 | ENV DEPENDENCIES_DIR=/tmp/dependencies
11 |
12 | # Copy needed files
13 | COPY ./docker/dependencies ${DEPENDENCIES_DIR}
14 |
15 | # Install apt dependencies
16 | RUN apt-get update && \
17 | xargs -a ${DEPENDENCIES_DIR}/apt-runtime.txt apt-get install -y --no-install-recommends && \
18 | rm -rf /var/lib/apt/lists/*
19 |
20 | # Set timezone and locale
21 | RUN echo "Etc/UTC" > /etc/timezone && \
22 | ln -sf /usr/share/zoneinfo/Etc/UTC /etc/localtime
23 | RUN locale-gen en_US.UTF-8
24 | ENV LANG=en_US.UTF-8
25 | ENV LANGUAGE=en_US:en
26 | ENV LC_ALL=en_US.UTF-8
27 |
28 | # Install Miniforge (Conda)
29 | RUN mkdir -p /tmp/conda && \
30 | curl -fvL -o /tmp/conda/miniconda.sh ${CONDA_URL} && \
31 | bash /tmp/conda/miniconda.sh -b -p ${CONDA_INSTALL_PATH} -u && \
32 | rm -rf /tmp/conda
33 | # Put conda and mamba on PATH (mamba is used for faster Conda operations)
34 | ENV PATH=${CONDA_INSTALL_PATH}/condabin:${CONDA_INSTALL_PATH}/bin:${PATH}
35 |
36 | # Install Mamba for faster Conda operations
37 | RUN conda install mamba -n base -c conda-forge
38 |
39 | # Create Conda environment
40 | RUN mamba env create --file ${DEPENDENCIES_DIR}/environment.yaml
41 |
42 | # Activate Conda environment
43 | ENV PATH=${CONDA_INSTALL_PATH}/envs/defog/bin:${PATH}
44 |
45 | # Install additional pip packages without dependencies if needed
46 | RUN chmod +x ${DEPENDENCIES_DIR}/pip_no_deps.sh && \
47 | ${DEPENDENCIES_DIR}/pip_no_deps.sh
48 |
49 | # Clean up
50 | RUN mamba clean -a -y && \
51 | find ${CONDA_INSTALL_PATH}/envs/defog -name '__pycache__' -type d -exec rm -rf {} +
52 |
53 | # Initialize Conda for shells
54 | RUN mamba init --system bash && \
55 | echo "mamba activate defog" >> /etc/profile.d/conda.sh
56 |
57 | # Delete temporary files
58 | RUN rm -rf ${DEPENDENCIES_DIR}
59 |
60 | # Set working directory
61 | WORKDIR /workspace
62 |
63 | # Entrypoint script
64 | COPY ./docker/entrypoint.sh /entrypoint.sh
65 | RUN chmod +x /entrypoint.sh
66 | ENTRYPOINT ["/entrypoint.sh"]
67 |
68 | # Default command
69 | CMD ["bash"]
70 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Docker configuration
2 | checkpoints/
3 | defog.egg-info
4 | final_checkpoints/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 |
137 | .DS_Store
138 | .idea/
139 | __pycache__/
140 | dgd/configs/__pycache__/
141 | data/
142 | egnn/__pycache__/
143 | equivariant_diffusion/__pycache__/
144 | outputs/
145 | archives/qm9/__pycache__/
146 | archives/qm9/data_utils/__pycache__/
147 | archives/qm9/data_utils/prepare/__pycache__/
148 | archives/qm9/property_prediction/__pycache__/
149 | archives/*
150 | .env
151 | dgd/analysis/orca/orca
152 | ggg_data/
153 | ggg_utils/
154 | saved_models
155 | src/analysis/orca/orca
156 | src/analysis/orca/tmp*
157 | src/timer.dat
158 | checkpoints
159 | wandb/
160 |
--------------------------------------------------------------------------------
/src/flow_matching/flow_matching_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from torch.distributions.categorical import Categorical
4 | import numpy as np
5 | import math
6 |
7 | from utils import PlaceHolder
8 |
9 |
def assert_correctly_masked(variable, node_mask):
    """Assert that ``variable`` is (near) zero wherever ``node_mask`` is 0.

    Raises AssertionError when the largest absolute value outside the mask
    is 1e-4 or more.
    """
    outside_mask = 1 - node_mask.long()
    largest_leak = (variable * outside_mask).abs().max().item()
    assert largest_leak < 1e-4, "Variables not masked properly."
14 |
15 |
def sample_discrete_feature_noise(limit_dist, node_mask):
    """Sample from the limit distribution of the diffusion process.

    :param limit_dist: object exposing 1-D probability tensors ``X`` (node
        classes) and ``E`` (edge classes). ``limit_dist.y`` is not read.
    :param node_mask: (bs, n_max) mask; it is passed to ``PlaceHolder.mask``.
    :return: the value of ``PlaceHolder(X, E, y).mask(node_mask)`` where X is
        a one-hot (bs, n_max, dx) tensor, E is a one-hot (bs, n_max, n_max, de)
        tensor that is symmetric with an all-zero diagonal, and y has shape
        (bs, 0).
    """
    bs, n_max = node_mask.shape
    x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1)
    e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1)

    # Nodes are drawn before edges; keep this order so that seeded runs
    # consume the random stream exactly as before.
    U_X = (
        x_limit.flatten(end_dim=-2).multinomial(1, replacement=True).reshape(bs, n_max)
    )
    U_E = (
        e_limit.flatten(end_dim=-2)
        .multinomial(1, replacement=True)
        .reshape(bs, n_max, n_max)
    )
    U_y = torch.empty((bs, 0))

    # type_as moves the samples to the mask's device and integer dtype.
    long_mask = node_mask.long()
    U_X = U_X.type_as(long_mask)
    U_E = U_E.type_as(long_mask)
    U_y = U_y.type_as(long_mask)

    U_X = F.one_hot(U_X, num_classes=x_limit.shape[-1]).float()
    U_E = F.one_hot(U_E, num_classes=e_limit.shape[-1]).float()

    # Keep the strict upper triangle of the edge noise (no main diagonal),
    # then mirror it so the edge tensor is symmetric.
    strict_upper = torch.triu(
        torch.ones(n_max, n_max, dtype=U_E.dtype, device=U_E.device), diagonal=1
    )
    U_E = U_E * strict_upper[None, :, :, None]
    U_E = U_E + torch.transpose(U_E, 1, 2)

    assert (U_E == torch.transpose(U_E, 1, 2)).all()

    return PlaceHolder(X=U_X, E=U_E, y=U_y).mask(node_mask)
51 |
52 |
def sample_discrete_features(probX, probE, node_mask, mask=False):
    """Sample node and edge classes from the given categorical probabilities.

    :param probX: bs, n, dx_out        node features
    :param probE: bs, n, n, de_out     edge features
    :param node_mask: (bs, n) boolean mask (it is inverted with ``~``)
    :param mask: when True, the sampled indices are multiplied by the node
        mask (and its outer product for edges)
    :return: PlaceHolder with X of shape (bs, n), a symmetric E of shape
        (bs, n, n) whose diagonal is 0, and an empty y of shape (bs, 0)

    Note: ``probX`` and ``probE`` are modified in place. Rows for masked
    nodes, masked node pairs and the diagonal are overwritten with a uniform
    distribution so that every row is a valid distribution for
    ``multinomial``.
    """
    bs, n, _ = probX.shape

    # ---- nodes (sampled first, to keep the RNG order) ----
    probX[~node_mask] = 1 / probX.shape[-1]
    flat_node_probs = probX.reshape(bs * n, -1)  # (bs * n, dx_out)
    X_t = flat_node_probs.multinomial(1, replacement=True).reshape(bs, n)  # (bs, n)

    # ---- edges ----
    inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))
    diag_mask = torch.eye(n).unsqueeze(0).expand(bs, -1, -1)

    probE[inverse_edge_mask] = 1 / probE.shape[-1]
    probE[diag_mask.bool()] = 1 / probE.shape[-1]

    flat_edge_probs = probE.reshape(bs * n * n, -1)  # (bs * n * n, de_out)
    E_t = flat_edge_probs.multinomial(1, replacement=True).reshape(bs, n, n)

    # Only the strict upper triangle is kept and then mirrored.
    upper = torch.triu(E_t, diagonal=1)
    E_t = upper + torch.transpose(upper, 1, 2)

    if mask:
        X_t = X_t * node_mask
        E_t = E_t * node_mask.unsqueeze(1) * node_mask.unsqueeze(2)

    empty_y = torch.zeros(bs, 0).type_as(X_t)
    return PlaceHolder(X=X_t, E=E_t, y=empty_y)
93 |
94 |
95 |
--------------------------------------------------------------------------------
/src/models/extra_features_molecular.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src import utils
3 |
4 |
class ExtraMolecularFeatures:
    """Bundle the charge, valency and weight features into one PlaceHolder."""

    def __init__(self, dataset_infos):
        """Build the three feature extractors from ``dataset_infos``.

        Reads ``remove_h``, ``valencies``, ``max_weight`` and ``atom_weights``
        from ``dataset_infos``.
        """
        self.charge = ChargeFeature(
            remove_h=dataset_infos.remove_h,
            valencies=dataset_infos.valencies,
        )
        self.valency = ValencyFeature()
        self.weight = WeightFeature(
            max_weight=dataset_infos.max_weight,
            atom_weights=dataset_infos.atom_weights,
        )

    def __call__(self, noisy_data):
        """Compute the extra features for ``noisy_data``.

        Returns a PlaceHolder whose X stacks charge and valency on the last
        dim, whose E has a trailing dim of size 0, and whose y is the weight
        feature.
        """
        per_node = torch.cat(
            (
                self.charge(noisy_data).unsqueeze(-1),  # (bs, n, 1)
                self.valency(noisy_data).unsqueeze(-1),  # (bs, n, 1)
            ),
            dim=-1,
        )
        per_graph = self.weight(noisy_data)  # (bs, 1)

        E_t = noisy_data["E_t"]
        no_edge_features = torch.zeros((*E_t.shape[:-1], 0)).type_as(E_t)

        return utils.PlaceHolder(X=per_node, E=no_edge_features, y=per_graph)
27 |
28 |
class ChargeFeature:
    """Per-node difference between two argmax-derived valency quantities."""

    def __init__(self, remove_h, valencies):
        self.remove_h = remove_h
        self.valencies = valencies

    def __call__(self, noisy_data):
        """Return ``normal_valencies - current_valencies`` as (bs, n).

        Reads ``noisy_data["E_t"]`` (bs, n, n, de) and ``noisy_data["X_t"]``
        (bs, n, dx). Both quantities are obtained with ``argmax`` over the
        class dim after weighting, i.e. they are class indices.
        """
        E_t = noisy_data["E_t"]
        X_t = noisy_data["X_t"]

        # Five edge classes get the extra 1.5 entry; otherwise four orders.
        order_values = [0, 1, 2, 3, 1.5] if E_t.shape[-1] == 5 else [0, 1, 2, 3]
        bond_orders = torch.tensor(order_values, device=E_t.device).reshape(
            1, 1, 1, -1
        )
        current_valencies = (E_t * bond_orders).argmax(dim=-1).sum(dim=-1)  # (bs, n)

        valency_table = torch.tensor(self.valencies, device=X_t.device).reshape(
            1, 1, -1
        )
        normal_valencies = torch.argmax(X_t * valency_table, dim=-1)  # (bs, n)

        return (normal_valencies - current_valencies).type_as(X_t)
57 |
58 |
class ValencyFeature:
    """Per-node sum of argmax edge classes after bond-order weighting."""

    def __init__(self):
        pass

    def __call__(self, noisy_data):
        """Return a (bs, n) tensor typed like ``noisy_data["X_t"]``.

        ``noisy_data["E_t"]`` (bs, n, n, de) is multiplied by the bond-order
        values, reduced with ``argmax`` over the class dim and summed over
        the last remaining dim.
        """
        E_t = noisy_data["E_t"]

        # Five edge classes get the extra 1.5 entry; otherwise four orders.
        order_values = [0, 1, 2, 3, 1.5] if E_t.shape[-1] == 5 else [0, 1, 2, 3]
        bond_orders = torch.tensor(order_values, device=E_t.device).reshape(
            1, 1, 1, -1
        )

        weighted = E_t * bond_orders  # (bs, n, n, de)
        per_node = weighted.argmax(dim=-1).sum(dim=-1)  # (bs, n)
        return per_node.type_as(noisy_data["X_t"])
80 |
81 |
class WeightFeature:
    """Total atom weight per graph, divided by ``max_weight``."""

    def __init__(self, max_weight, atom_weights):
        """Store the normaliser and the weights in ``atom_weights`` value order."""
        self.max_weight = max_weight
        self.atom_weight_list = torch.tensor(list(atom_weights.values()))

    def __call__(self, noisy_data):
        """Return a (bs, 1) tensor typed like ``noisy_data["X_t"]``."""
        X_t = noisy_data["X_t"]
        atom_index = torch.argmax(X_t, dim=-1)  # (bs, n)
        per_atom = self.atom_weight_list.to(atom_index.device)[atom_index]  # (bs, n)
        total = per_atom.sum(dim=-1).unsqueeze(-1).type_as(X_t)  # (bs, 1)
        return total / self.max_weight
94 |
--------------------------------------------------------------------------------
/src/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 | from typing import Any, Sequence
4 |
5 | from rdkit import Chem
6 | import torch
7 | from torch_geometric.data import Data
8 | from torch_geometric.utils import subgraph
9 |
10 |
def mol_to_torch_geometric(mol, atom_encoder, smiles):
    """Convert an RDKit molecule into a torch_geometric ``Data`` object.

    :param mol: RDKit molecule (``GetAtoms`` and the bond-order adjacency
        matrix are read from it)
    :param atom_encoder: mapping from atom symbol to integer node type
    :param smiles: stored unchanged on the returned object
    :return: Data with ``x`` (node types), ``edge_index``, ``edge_attr``
        (bond orders as long, with 1.5 re-encoded as 4), ``charge`` (formal
        charges) and ``smiles``
    """
    adjacency = torch.from_numpy(Chem.rdmolops.GetAdjacencyMatrix(mol, useBO=True))
    edge_index = adjacency.nonzero().contiguous().T

    bond_orders = adjacency[edge_index[0], edge_index[1]]
    bond_orders[bond_orders == 1.5] = 4  # 1.5 becomes class 4

    symbols_encoded = [atom_encoder[atom.GetSymbol()] for atom in mol.GetAtoms()]
    formal_charges = [atom.GetFormalCharge() for atom in mol.GetAtoms()]

    return Data(
        x=torch.Tensor(symbols_encoded).long(),
        edge_index=edge_index,
        edge_attr=bond_orders.long(),
        charge=torch.Tensor(formal_charges).long(),
        smiles=smiles,
    )
35 |
36 |
def remove_hydrogens(data: Data):
    """Return a new ``Data`` keeping only nodes with ``x > 0``.

    Edges are restricted to the kept nodes and relabelled via ``subgraph``.
    Node types are shifted down by one to match the atom decoder. Only
    ``x``, ``charge``, ``edge_index`` and ``edge_attr`` are carried over.
    """
    keep = data.x > 0
    edge_index, edge_attr = subgraph(
        keep,
        data.edge_index,
        data.edge_attr,
        relabel_nodes=True,
        num_nodes=len(keep),
    )
    shifted_types = data.x[keep] - 1  # Shift onehot encoding to match atom decoder
    return Data(
        x=shifted_types,
        charge=data.charge[keep],
        edge_index=edge_index,
        edge_attr=edge_attr,
    )
52 |
53 |
def save_pickle(array, path):
    """Pickle ``array`` to the file at ``path`` (opened in binary write mode)."""
    with open(path, "wb") as sink:
        pickle.dump(array, sink)
57 |
58 |
def load_pickle(path):
    """Load and return the pickled object stored at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on
    files from a trusted source.
    """
    with open(path, "rb") as source:
        return pickle.load(source)
62 |
63 |
def files_exist(files) -> bool:
    """Return True only if ``files`` is non-empty and every path exists."""
    if len(files) == 0:
        return False
    return all([osp.exists(path) for path in files])
66 |
67 |
def to_list(value: Any) -> Sequence:
    """Return ``value`` unchanged if it is a non-string Sequence, else ``[value]``."""
    is_plain_sequence = isinstance(value, Sequence) and not isinstance(value, str)
    return value if is_plain_sequence else [value]
73 |
74 |
class Statistics:
    """Plain container for dataset statistics.

    The constructor stores its arguments unchanged as attributes of the same
    name; ``charge_types`` and ``valencies`` default to None.
    """

    def __init__(
        self,
        num_nodes,
        node_types,
        bond_types,
        charge_types=None,
        valencies=None,
    ):
        self.num_nodes, self.node_types, self.bond_types = (
            num_nodes,
            node_types,
            bond_types,
        )
        self.charge_types, self.valencies = charge_types, valencies
84 |
85 |
class RemoveYTransform:
    """Transform that replaces ``data.y`` with an empty (1, 0) float tensor."""

    def __call__(self, data):
        """Overwrite ``data.y`` in place and return the same ``data`` object."""
        empty_target = torch.zeros((1, 0), dtype=torch.float)
        data.y = empty_target
        return data
90 |
91 |
class DistributionNodes:
    """Distribution over the number of nodes per graph."""

    def __init__(self, histogram):
        """Compute the distribution of the number of nodes in the dataset.

        :param histogram: either a dict (or any dict subclass such as
            ``Counter``) whose keys are num_nodes and whose values are
            counts, or a 1-D tensor of counts indexed by num_nodes.
        """
        # isinstance (not ``type(...) == dict``) so that dict subclasses are
        # also converted instead of being treated as tensors.
        if isinstance(histogram, dict):
            max_n_nodes = max(histogram.keys())
            prob = torch.zeros(max_n_nodes + 1)
            for num_nodes, count in histogram.items():
                prob[num_nodes] = count
        else:
            prob = histogram

        self.prob = prob / prob.sum()
        # Categorical is built from the raw counts, as in the original code.
        self.m = torch.distributions.Categorical(prob)

    def sample_n(self, n_samples, device):
        """Draw ``n_samples`` node counts and move them to ``device``."""
        idx = self.m.sample((n_samples,))
        return idx.to(device)

    def log_prob(self, batch_n_nodes):
        """Return the log-probability of each node count in a 1-D tensor.

        A 1e-30 offset is added before the log so that counts with zero
        probability give a finite value.
        """
        assert len(batch_n_nodes.size()) == 1
        p = self.prob.to(batch_n_nodes.device)

        probas = p[batch_n_nodes]
        log_p = torch.log(probas + 1e-30)
        return log_p
120 |
--------------------------------------------------------------------------------
/src/metrics/train_metrics.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import wandb
4 | import torch
5 | from torch import Tensor
6 | import torch.nn as nn
7 | from torchmetrics import Metric, MeanSquaredError, MetricCollection
8 |
9 | from metrics.abstract_metrics import (
10 | CrossEntropyMetric,
11 | KLDMetric,
12 | )
13 |
14 |
class NodeMSE(MeanSquaredError):
    """``MeanSquaredError`` under a node-specific class name.

    This subclass adds no behavior of its own: positional arguments are
    forwarded unchanged to the parent constructor.
    """

    def __init__(self, *args):
        super().__init__(*args)
18 |
19 |
class EdgeMSE(MeanSquaredError):
    """``MeanSquaredError`` under an edge-specific class name.

    This subclass adds no behavior of its own: positional arguments are
    forwarded unchanged to the parent constructor.
    """

    def __init__(self, *args):
        super().__init__(*args)
23 |
24 |
class TrainLossDiscrete(nn.Module):
    """Train with Cross entropy

    Combines a node loss, an edge loss and a global (y) loss into a single
    training loss. With ``kld=False`` the node and edge losses are
    ``CrossEntropyMetric`` instances; with ``kld=True`` they are ``KLDMetric``
    instances. The y loss is always a ``CrossEntropyMetric``.
    """

    def __init__(self, lambda_train, kld=False):
        """Store the loss weights and create the three metric objects.

        lambda_train : indexable -- ``[0]`` weights the edge loss and ``[1]``
                       weights the y loss in ``forward``
        kld : boolean -- selects KLDMetric instead of CrossEntropyMetric for
              nodes and edges
        """
        super().__init__()
        self.lambda_train = lambda_train
        if not kld:
            self.node_loss = CrossEntropyMetric()
            self.edge_loss = CrossEntropyMetric()
        else:
            self.node_loss = KLDMetric()
            self.edge_loss = KLDMetric()
        self.y_loss = CrossEntropyMetric()

    def forward(
        self,
        masked_pred_X,
        masked_pred_E,
        pred_y,
        true_X,
        true_E,
        true_y,
        log: bool,
    ):
        """Compute train metrics
        masked_pred_X : tensor -- (bs, n, dx)
        masked_pred_E : tensor -- (bs, n, n, de)
        pred_y : tensor -- (bs, )
        true_X : tensor -- (bs, n, dx)
        true_E : tensor -- (bs, n, n, de)
        true_y : tensor -- (bs, )
        log : boolean.

        Returns ``loss_X + lambda_train[0] * loss_E + lambda_train[1] * loss_y``.
        """
        # Flatten everything to (rows, classes).
        true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx)
        true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de)
        masked_pred_X = torch.reshape(
            masked_pred_X, (-1, masked_pred_X.size(-1))
        )  # (bs * n, dx)
        masked_pred_E = torch.reshape(
            masked_pred_E, (-1, masked_pred_E.size(-1))
        )  # (bs * n * n, de)

        # Remove masked rows
        # A row counts as masked when its true vector is all zeros.
        mask_X = (true_X != 0.0).any(dim=-1)
        mask_E = (true_E != 0.0).any(dim=-1)

        flat_true_X = true_X[mask_X, :]
        flat_pred_X = masked_pred_X[mask_X, :]

        flat_true_E = true_E[mask_E, :]
        flat_pred_E = masked_pred_E[mask_E, :]

        # Calling a metric object returns this batch's loss. The emptiness
        # checks use the unfiltered tensors; an empty input yields the float 0.0.
        loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0
        loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0
        loss_y = self.y_loss(pred_y, true_y) if pred_y.numel() > 0 else 0.0

        if log:
            # batch_CE is the unweighted sum of the three batch losses. The
            # other entries come from each metric's compute(), or -1 when
            # the corresponding true tensor is empty.
            # NOTE(review): .detach() assumes at least one loss is a tensor;
            # if all three are the float 0.0 this line raises -- confirm that
            # case cannot occur.
            to_log = {
                "train_loss/batch_CE": (loss_X + loss_E + loss_y).detach(),
                "train_loss/X_CE": (
                    self.node_loss.compute() if true_X.numel() > 0 else -1
                ),
                "train_loss/E_CE": (
                    self.edge_loss.compute() if true_E.numel() > 0 else -1
                ),
                "train_loss/y_CE": self.y_loss.compute() if true_y.numel() > 0 else -1,
            }
            # Only log when a wandb run is active.
            if wandb.run:
                wandb.log(to_log, commit=True)
        return loss_X + self.lambda_train[0] * loss_E + self.lambda_train[1] * loss_y

    def reset(self):
        """Reset the node, edge and y metric objects."""
        for metric in [self.node_loss, self.edge_loss, self.y_loss]:
            metric.reset()

    def log_epoch_metrics(self):
        """Compute the epoch-level losses, log them to wandb and return them.

        A metric whose ``total_samples`` is not positive is reported as -1.
        The dict is logged with ``commit=False`` when a wandb run is active.
        """
        epoch_node_loss = (
            self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
        )
        epoch_edge_loss = (
            self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
        )
        epoch_y_loss = (
            self.y_loss.compute() if self.y_loss.total_samples > 0 else -1
        )

        to_log = {
            "train_epoch/x_CE": epoch_node_loss,
            "train_epoch/E_CE": epoch_edge_loss,
            "train_epoch/y_CE": epoch_y_loss,
        }
        if wandb.run:
            wandb.log(to_log, commit=False)

        return to_log
119 |
--------------------------------------------------------------------------------
/src/flow_matching/noise_distribution.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src import utils
4 |
5 |
class NoiseDistribution:
    """Limit (noise) distribution over node, edge and global classes.

    The constructor builds 1-D probability tensors for X and E according to
    ``model_transition`` and stores them, together with a uniform ``y``
    tensor, in ``self.limit_dist``. The "absorbing" transition may append one
    virtual class to X and/or E; the number of appended classes is kept in
    ``x_added_classes`` / ``e_added_classes`` / ``y_added_classes``.
    """

    def __init__(self, model_transition, dataset_infos):
        """Build the limit distribution.

        :param model_transition: one of "uniform", "absorbfirst", "argmax",
            "absorbing", "marginal", "edge_marginal", "node_marginal"
        :param dataset_infos: object providing ``output_dims`` (keys "X",
            "E", "y") and, for the marginal-based transitions,
            ``node_types`` / ``edge_types`` count tensors
        :raises ValueError: for an unknown ``model_transition``
        """
        self.x_num_classes = dataset_infos.output_dims["X"]
        self.e_num_classes = dataset_infos.output_dims["E"]
        self.y_num_classes = dataset_infos.output_dims["y"]
        self.x_added_classes = 0
        self.e_added_classes = 0
        self.y_added_classes = 0
        self.transition = model_transition

        if model_transition == "uniform":
            x_limit = torch.ones(self.x_num_classes) / self.x_num_classes
            e_limit = torch.ones(self.e_num_classes) / self.e_num_classes

        elif model_transition == "absorbfirst":
            # All mass on class 0.
            x_limit = torch.zeros(self.x_num_classes)
            x_limit[0] = 1
            e_limit = torch.zeros(self.e_num_classes)
            e_limit[0] = 1

        elif model_transition == "argmax":
            # All mass on the most frequent class of the dataset.
            node_types = dataset_infos.node_types.float()
            x_marginals = node_types / torch.sum(node_types)

            edge_types = dataset_infos.edge_types.float()
            e_marginals = edge_types / torch.sum(edge_types)

            x_max_dim = torch.argmax(x_marginals)
            e_max_dim = torch.argmax(e_marginals)
            x_limit = torch.zeros(self.x_num_classes)
            x_limit[x_max_dim] = 1
            e_limit = torch.zeros(self.e_num_classes)
            e_limit[e_max_dim] = 1

        elif model_transition == "absorbing":
            # only add virtual classes when there are several
            if self.x_num_classes > 1:
                self.x_num_classes += 1
                self.x_added_classes = 1
            if self.e_num_classes > 1:
                self.e_num_classes += 1
                self.e_added_classes = 1

            # All mass on the last class.
            x_limit = torch.zeros(self.x_num_classes)
            x_limit[-1] = 1
            e_limit = torch.zeros(self.e_num_classes)
            e_limit[-1] = 1

        elif model_transition == "marginal":
            node_types = dataset_infos.node_types.float()
            x_limit = node_types / torch.sum(node_types)

            edge_types = dataset_infos.edge_types.float()
            e_limit = edge_types / torch.sum(edge_types)

        elif model_transition == "edge_marginal":
            x_limit = torch.ones(self.x_num_classes) / self.x_num_classes

            edge_types = dataset_infos.edge_types.float()
            e_limit = edge_types / torch.sum(edge_types)

        elif model_transition == "node_marginal":
            e_limit = torch.ones(self.e_num_classes) / self.e_num_classes

            node_types = dataset_infos.node_types.float()
            x_limit = node_types / torch.sum(node_types)

        else:
            raise ValueError(f"Unknown transition model: {model_transition}")

        y_limit = torch.ones(self.y_num_classes) / self.y_num_classes  # typically dummy
        print(
            f"Limit distribution of the classes | Nodes: {x_limit} | Edges: {e_limit}"
        )
        self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)

    def update_input_output_dims(self, input_dims):
        """Add the number of virtual classes to ``input_dims`` in place."""
        input_dims["X"] += self.x_added_classes
        input_dims["E"] += self.e_added_classes
        input_dims["y"] += self.y_added_classes

    def update_dataset_infos(self, dataset_infos):
        """Append one "Y" entry per added node class to ``atom_decoder``, if present."""
        if hasattr(dataset_infos, "atom_decoder"):
            dataset_infos.atom_decoder = (
                dataset_infos.atom_decoder + ["Y"] * self.x_added_classes
            )

    def get_limit_dist(self):
        """Return the stored limit distribution."""
        return self.limit_dist

    def get_noise_dims(self):
        """Return the number of classes of each limit distribution.

        The "y" entry is the length of ``limit_dist.y`` (it previously
        reported the length of ``limit_dist.E``).
        """
        return {
            "X": len(self.limit_dist.X),
            "E": len(self.limit_dist.E),
            "y": len(self.limit_dist.y),
        }

    def ignore_virtual_classes(self, X, E, y=None):
        """Drop the trailing virtual classes for the "absorbing" transition.

        For any other transition the inputs are returned unchanged. The slice
        end is computed from the tensor size, so an added-class count of 0
        keeps every class (a ``[: -0]`` slice would return an empty tensor).
        """
        if self.transition != "absorbing":
            return X, E, y

        new_X = X[..., : X.shape[-1] - self.x_added_classes]
        new_E = E[..., : E.shape[-1] - self.e_added_classes]
        if y is not None:
            new_y = y[..., : y.shape[-1] - self.y_added_classes]
        else:
            new_y = None
        return new_X, new_E, new_y

    def add_virtual_classes(self, X, E, y=None):
        """Append zero-filled virtual classes on the last dim of X, E and y.

        The ``repeat`` calls expect X with 3 dims, E with 4 dims and y with
        2 dims. ``y`` may be None, in which case None is returned for it.
        """
        x_virtual = torch.zeros_like(X[..., :1]).repeat(1, 1, self.x_added_classes)
        new_X = torch.cat([X, x_virtual], dim=-1)

        e_virtual = torch.zeros_like(E[..., :1]).repeat(1, 1, 1, self.e_added_classes)
        new_E = torch.cat([E, e_virtual], dim=-1)

        if y is not None:
            y_virtual = torch.zeros_like(y[..., :1]).repeat(1, self.y_added_classes)
            new_y = torch.cat([y, y_virtual], dim=-1)
        else:
            new_y = None

        return new_X, new_E, new_y
130 |
--------------------------------------------------------------------------------
/src/flow_matching/time_distorter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.stats import norm
4 | from scipy.special import beta as beta_func, betaln
5 | from scipy.stats import beta as sp_beta
6 | from scipy.interpolate import interp1d
7 |
8 |
def beta_pdf(x, alpha, beta):
    """Beta distribution PDF evaluated at ``x`` for shapes ``alpha`` and ``beta``."""
    unnormalised = x ** (alpha - 1) * (1 - x) ** (beta - 1)
    return unnormalised / beta_func(alpha, beta)
14 |
15 |
def objective_function(alpha, beta, y, t):
    """Objective to minimize: mean squared error plus a small penalty.

    The error is the mean of ``(y - beta_pdf(t, alpha, beta)) ** 2``. The
    penalty ``(alpha + beta) + (1 / alpha + 1 / beta)`` is added with weight
    0.0001.
    """
    predicted = beta_pdf(t, alpha, beta)
    penalty = (alpha + beta) + (1 / alpha + 1 / beta)
    mse = np.mean((y - predicted) ** 2)
    return mse + 0.0001 * penalty
23 |
24 |
class TimeDistorter:
    """Apply a distortion function to times in [0, 1]."""

    def __init__(
        self,
        train_distortion,
        sample_distortion,
        mu=0,
        sigma=1,
        alpha=1,
        beta=1,
    ):
        """Store the distortion names and the initial beta parameters.

        ``mu`` and ``sigma`` are accepted but not used by this class.
        """
        self.train_distortion = train_distortion  # used by train_ft
        # Stored only: sample_ft receives its distortion type as an argument.
        self.sample_distortion = sample_distortion
        self.alpha = alpha
        self.beta = beta
        print(
            f"TimeDistorter: train_distortion={train_distortion}, sample_distortion={sample_distortion}"
        )
        self.f_inv = None  # set by approximate_f_inverse

    def train_ft(self, batch_size, device):
        """Sample uniform times of shape (batch_size, 1) and distort them
        with ``self.train_distortion``."""
        t_uniform = torch.rand((batch_size, 1), device=device)
        t_distort = self.apply_distortion(t_uniform, self.train_distortion)

        return t_distort

    def sample_ft(self, t, sample_distortion):
        """Distort ``t`` with the given ``sample_distortion`` type."""
        t_distort = self.apply_distortion(t, sample_distortion)
        return t_distort

    def fit(self, difficulty, t_array, learning_rate=0.01, iterations=1000):
        """Fit a beta PDF to ``difficulty`` by gradient descent.

        Gradients of ``objective_function`` are approximated with central
        finite differences. After every step alpha and beta are clipped to
        [0.3, 3]. The fitted parameters are used to build ``self.f_inv``;
        ``self.alpha`` and ``self.beta`` are only the starting point and are
        not updated.

        :return: (fitted PDF evaluated on the shifted ``t_array``, alpha, beta)
        """
        alpha, beta = self.alpha, self.beta
        t_array = t_array + 1e-6  # Avoid division by zero
        epsilon = 1e-5  # finite-difference step

        for _ in range(iterations):
            # Numerical approximation of the gradients
            grad_alpha = (
                objective_function(alpha + epsilon, beta, difficulty, t_array)
                - objective_function(alpha - epsilon, beta, difficulty, t_array)
            ) / (2 * epsilon)
            grad_beta = (
                objective_function(alpha, beta + epsilon, difficulty, t_array)
                - objective_function(alpha, beta - epsilon, difficulty, t_array)
            ) / (2 * epsilon)

            # Update parameters
            alpha -= learning_rate * grad_alpha
            beta -= learning_rate * grad_beta

            alpha = min(max(0.3, alpha), 3)
            beta = min(max(0.3, beta), 3)

        y_pred = beta_pdf(t_array, alpha, beta)
        self.approximate_f_inverse(alpha, beta)

        return y_pred, alpha, beta

    def approximate_f_inverse(self, alpha, beta):
        """Store in ``self.f_inv`` an interpolated inverse of the Beta CDF.

        The CDF is tabulated on 100000 points in [0, 1]; duplicate CDF values
        are removed before building the interpolator, which extrapolates
        outside the tabulated range.
        """
        # Generate data points
        t_values = np.linspace(0, 1, 100000)
        f_values = sp_beta.cdf(t_values, alpha, beta)

        # Sort by CDF value
        sorted_indices = np.argsort(f_values)
        f_values_sorted = f_values[sorted_indices]
        t_values_sorted = t_values[sorted_indices]

        # Remove duplicates
        _, unique_indices = np.unique(f_values_sorted, return_index=True)
        f_values_unique = f_values_sorted[unique_indices]
        t_values_unique = t_values_sorted[unique_indices]

        # Create the interpolation function for the inverse
        self.f_inv = interp1d(
            f_values_unique,
            t_values_unique,
            bounds_error=False,
            fill_value="extrapolate",
        )

    def apply_distortion(self, t, distortion_type):
        """Return the distorted times for ``t``, a tensor with values in [0, 1].

        Supported types: "identity", "cos", "revcos", "polyinc", "polydec".

        :raises ValueError: for "beta" and "logitnormal" (unsupported for
            now) and for unknown types
        """
        assert torch.all((t >= 0) & (t <= 1)), "t must be in the range (0, 1)"

        if distortion_type == "identity":
            ft = t
        elif distortion_type == "cos":
            ft = (1 - torch.cos(t * torch.pi)) / 2
        elif distortion_type == "revcos":
            ft = 2 * t - (1 - torch.cos(t * torch.pi)) / 2
        elif distortion_type == "polyinc":
            ft = t**2
        elif distortion_type == "polydec":
            ft = 2 * t - t**2
        elif distortion_type == "beta":
            raise ValueError(f"Unsupported for now: {distortion_type}")
        elif distortion_type == "logitnormal":
            raise ValueError(f"Unsupported for now: {distortion_type}")
        else:
            raise ValueError(f"Unknown distortion type: {distortion_type}")

        return ft
136 |
--------------------------------------------------------------------------------
/src/analysis/dist_helper.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | #
3 | # Adapted from https://github.com/lrjconan/GRAN/ which in turn is adapted from https://github.com/JiaxuanYou/graph-generation
4 | #
5 | ###############################################################################
6 | import pyemd
7 | import numpy as np
8 | import concurrent.futures
9 | from functools import partial
10 | from scipy.linalg import toeplitz
11 |
12 |
def emd(x, y, distance_scaling=1.0):
    """Earth mover's distance between two 1-D histograms.

    The shorter histogram is zero-padded to the length of the longer one.
    The ground distance is ``|i - j| / distance_scaling``.

    NOTE: a second ``emd(x, y, sigma=1.0, distance_scaling=1.0)`` is defined
    later in this module and replaces this definition at import time.
    """
    support_size = max(len(x), len(y))
    distance_mat = toeplitz(range(support_size)).astype(float) / distance_scaling

    # convert histogram values x and y to float, and make them equal len
    x = x.astype(float)
    y = y.astype(float)
    missing_in_x = len(y) - len(x)
    if missing_in_x > 0:
        x = np.hstack((x, [0.0] * missing_in_x))
    elif missing_in_x < 0:
        y = np.hstack((y, [0.0] * (-missing_in_x)))

    return pyemd.emd(x, y, distance_mat)
28 |
29 |
def l2(x, y):
    """Return the 2-norm of ``x - y``."""
    difference = x - y
    return np.linalg.norm(difference, 2)
33 |
34 |
def emd(x, y, sigma=1.0, distance_scaling=1.0):
    """EMD
    Args:
      x, y: 1D pmf of two distributions; the shorter one is zero-padded
      sigma: accepted but not used by this function
      distance_scaling: divisor applied to the ``|i - j|`` ground distance
    Returns:
      absolute value of the earth mover's distance
    """
    support_size = max(len(x), len(y))
    distance_mat = toeplitz(range(support_size)).astype(float) / distance_scaling

    # convert histogram values x and y to float, and make them equal len
    x = x.astype(float)
    y = y.astype(float)
    missing_in_x = len(y) - len(x)
    if missing_in_x > 0:
        x = np.hstack((x, [0.0] * missing_in_x))
    elif missing_in_x < 0:
        y = np.hstack((y, [0.0] * (-missing_in_x)))

    distance = pyemd.emd(x, y, distance_mat)
    return np.abs(distance)
54 |
55 |
def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0):
    """Gaussian kernel whose squared distance term is replaced by EMD.

    Args:
        x, y: 1D pmf of two distributions; the shorter one is zero-padded
            on the right so both share the same support.
        sigma: standard deviation of the Gaussian.
        distance_scaling: divisor applied to the ground-distance matrix.
    """
    n_bins = max(len(x), len(y))
    # Ground distance between bins i and j is |i - j|, then rescaled.
    ground_dist = toeplitz(range(n_bins)).astype(float)
    ground_dist = ground_dist / distance_scaling

    # Cast both histograms to float and bring them to the same length.
    hist_x = x.astype(float)
    hist_y = y.astype(float)
    if len(hist_x) < len(hist_y):
        hist_x = np.hstack((hist_x, [0.0] * (n_bins - len(hist_x))))
    elif len(hist_y) < len(hist_x):
        hist_y = np.hstack((hist_y, [0.0] * (n_bins - len(hist_y))))

    dist = pyemd.emd(hist_x, hist_y, ground_dist)
    return np.exp(-dist * dist / (2 * sigma * sigma))
76 |
77 |
def gaussian(x, y, sigma=1.0):
    """Gaussian kernel on the Euclidean distance between two histograms.

    Args:
        x, y: 1D numpy arrays; the shorter one is zero-padded on the right.
        sigma: standard deviation of the Gaussian.
    """
    n_bins = max(len(x), len(y))
    # Cast both histograms to float and bring them to the same length.
    hist_x = x.astype(float)
    hist_y = y.astype(float)
    if len(hist_x) < n_bins:
        hist_x = np.hstack((hist_x, [0.0] * (n_bins - len(hist_x))))
    elif len(hist_y) < n_bins:
        hist_y = np.hstack((hist_y, [0.0] * (n_bins - len(hist_y))))

    euclid = np.linalg.norm(hist_x - hist_y, 2)
    return np.exp(-euclid * euclid / (2 * sigma * sigma))
90 |
91 |
def gaussian_tv(x, y, sigma=1.0):
    """Gaussian kernel on the total-variation distance between two histograms.

    Args:
        x, y: 1D numpy arrays; the shorter one is zero-padded on the right.
        sigma: standard deviation of the Gaussian.
    """
    n_bins = max(len(x), len(y))
    # Cast both histograms to float and bring them to the same length.
    hist_x = x.astype(float)
    hist_y = y.astype(float)
    if len(hist_x) < n_bins:
        hist_x = np.hstack((hist_x, [0.0] * (n_bins - len(hist_x))))
    elif len(hist_y) < n_bins:
        hist_y = np.hstack((hist_y, [0.0] * (n_bins - len(hist_y))))

    # Total variation: half of the L1 distance.
    tv = np.abs(hist_x - hist_y).sum() / 2.0
    return np.exp(-tv * tv / (2 * sigma * sigma))
104 |
105 |
def kernel_parallel_unpacked(x, samples2, kernel):
    """Return the sum of ``kernel(x, s2)`` over every ``s2`` in ``samples2``."""
    return sum(kernel(x, other) for other in samples2)
111 |
112 |
def kernel_parallel_worker(t):
    """Adapter for ``executor.map``: unpack the tuple ``t`` into the
    positional arguments of ``kernel_parallel_unpacked`` and return its result.
    """
    return kernel_parallel_unpacked(*t)
115 |
116 |
def disc(samples1, samples2, kernel, is_parallel=True, *args, **kwargs):
    """Average kernel value over all pairs drawn from the two sample sets.

    Extra ``*args``/``**kwargs`` are forwarded to ``kernel``. When either
    sample set is empty the sentinel value ``1e6`` is returned.
    """
    if is_parallel:
        # One task per element of samples1; each task sums over samples2.
        bound_kernel = partial(kernel, *args, **kwargs)
        tasks = [(s1, samples2, bound_kernel) for s1 in samples1]
        with concurrent.futures.ThreadPoolExecutor() as executor:
            total = sum(executor.map(kernel_parallel_worker, tasks))
    else:
        total = sum(
            kernel(s1, s2, *args, **kwargs) for s1 in samples1 for s2 in samples2
        )

    n_pairs = len(samples1) * len(samples2)
    if n_pairs > 0:
        total /= n_pairs
    else:
        total = 1e6
    return total
137 |
138 |
def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
    """MMD between two samples.

    Args:
        samples1, samples2: sequences of samples (histograms when
            ``is_hist`` is true).
        kernel: callable ``kernel(a, b, **kwargs)`` returning a similarity.
        is_hist: when true, each sample is divided by its sum (plus 1e-6 to
            avoid division by zero) before the kernel is evaluated.
        *args, **kwargs: forwarded unchanged to ``disc``.

    Returns:
        The absolute value of
        ``disc(s1, s1) + disc(s2, s2) - 2 * disc(s1, s2)``.
    """
    # normalize histograms into pmf
    if is_hist:
        samples1 = [s1 / (np.sum(s1) + 1e-6) for s1 in samples1]
        samples2 = [s2 / (np.sum(s2) + 1e-6) for s2 in samples2]
    mmd = (
        disc(samples1, samples1, kernel, *args, **kwargs)
        + disc(samples2, samples2, kernel, *args, **kwargs)
        - 2 * disc(samples1, samples2, kernel, *args, **kwargs)
    )

    # The estimate can come out slightly negative; report its magnitude.
    # (The former `if mmd < 0: pdb.set_trace()` check ran after this abs()
    # and could never trigger, so it was dropped.)
    return np.abs(mmd)
159 |
160 |
def compute_emd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
    """EMD between the averages of two samples.

    When ``is_hist`` is true each sample set is replaced by a one-element
    list holding ``np.mean`` of the whole set before ``disc`` is evaluated.

    Returns:
        ``(discrepancy, [first element of samples1, first element of samples2])``.
    """
    if is_hist:
        samples1, samples2 = [np.mean(samples1)], [np.mean(samples2)]
    discrepancy = disc(samples1, samples2, kernel, *args, **kwargs)
    return discrepancy, [samples1[0], samples2[0]]
168 |
--------------------------------------------------------------------------------
/src/metrics/molecular_metrics_discrete.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics import Metric, MetricCollection
3 | from torch import Tensor
4 | import wandb
5 | import torch.nn as nn
6 |
7 |
class CEPerClass(Metric):
    """Binary cross-entropy of a single class, accumulated across updates."""

    full_state_update = False

    def __init__(self, class_id):
        super().__init__()
        # Index of the tracked class along the last dimension.
        self.class_id = class_id
        self.add_state("total_ce", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.softmax = torch.nn.Softmax(dim=-1)
        self.binary_cross_entropy = torch.nn.BCELoss(reduction="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate the summed BCE of the tracked class.

        Args:
            preds: Predictions from model (bs, n, d) or (bs, n, n, d)
            target: Ground truth values (bs, n, d) or (bs, n, n, d)
        """
        flat_target = target.reshape(-1, target.shape[-1])
        # Rows whose target is entirely zero are excluded from the metric.
        keep = (flat_target != 0.0).any(dim=-1)

        class_prob = self.softmax(preds)[..., self.class_id].flatten()[keep]
        class_target = flat_target[:, self.class_id][keep]

        self.total_ce += self.binary_cross_entropy(class_prob, class_target)
        self.total_samples += class_prob.numel()

    def compute(self):
        """Return the mean BCE per retained element."""
        return self.total_ce / self.total_samples
40 |
41 |
class HydrogenCE(CEPerClass):
    """CEPerClass variant mapped to the "H" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
45 |
46 |
class CarbonCE(CEPerClass):
    """CEPerClass variant mapped to the "C" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
50 |
51 |
class NitroCE(CEPerClass):
    """CEPerClass variant mapped to the "N" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
55 |
56 |
class OxyCE(CEPerClass):
    """CEPerClass variant mapped to the "O" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
60 |
61 |
class FluorCE(CEPerClass):
    """CEPerClass variant mapped to the "F" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
65 |
66 |
class BoronCE(CEPerClass):
    """CEPerClass variant mapped to the "B" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
70 |
71 |
class BrCE(CEPerClass):
    """CEPerClass variant mapped to the "Br" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
75 |
76 |
class ClCE(CEPerClass):
    """CEPerClass variant mapped to the "Cl" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
80 |
81 |
class IodineCE(CEPerClass):
    """CEPerClass variant mapped to the "I" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
85 |
86 |
class PhosphorusCE(CEPerClass):
    """CEPerClass variant mapped to the "P" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
90 |
91 |
class SulfurCE(CEPerClass):
    """CEPerClass variant mapped to the "S" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
95 |
96 |
class SeCE(CEPerClass):
    """CEPerClass variant mapped to the "Se" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
100 |
101 |
class SiCE(CEPerClass):
    """CEPerClass variant mapped to the "Si" atom type in AtomMetricsCE."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
105 |
106 |
class NoBondCE(CEPerClass):
    """CEPerClass variant that BondMetricsCE instantiates with class index 0."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
110 |
111 |
class SingleCE(CEPerClass):
    """CEPerClass variant that BondMetricsCE instantiates with class index 1."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
115 |
116 |
class DoubleCE(CEPerClass):
    """CEPerClass variant that BondMetricsCE instantiates with class index 2."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
120 |
121 |
class TripleCE(CEPerClass):
    """CEPerClass variant that BondMetricsCE instantiates with class index 3."""

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
125 |
126 |
class AromaticCE(CEPerClass):
    """CEPerClass variant for aromatic bonds.

    NOTE(review): in this module it is only referenced from commented-out
    code in BondMetricsCE, so it appears to be unused here.
    """

    def __init__(self, i):
        # i: class index forwarded to CEPerClass as class_id.
        super().__init__(i)
130 |
131 |
class AtomMetricsCE(MetricCollection):
    """Collection holding one per-class CE metric for each atom type."""

    def __init__(self, dataset_infos):
        # Atom symbol -> metric class tracking that symbol.
        symbol_to_metric = {
            "H": HydrogenCE,
            "C": CarbonCE,
            "N": NitroCE,
            "O": OxyCE,
            "F": FluorCE,
            "B": BoronCE,
            "Br": BrCE,
            "Cl": ClCE,
            "I": IodineCE,
            "P": PhosphorusCE,
            "S": SulfurCE,
            "Se": SeCE,
            "Si": SiCE,
        }

        # The position of a symbol in atom_decoder is used as its class index.
        per_atom_metrics = [
            symbol_to_metric[symbol](index)
            for index, symbol in enumerate(dataset_infos.atom_decoder)
        ]
        super().__init__(per_atom_metrics)
156 |
157 |
class BondMetricsCE(MetricCollection):
    """Collection of per-class CE metrics for bond classes 0 to 3."""

    def __init__(self):
        # AromaticCE(4) is intentionally left out of the collection.
        bond_metrics = [
            NoBondCE(0),
            SingleCE(1),
            DoubleCE(2),
            TripleCE(3),
        ]
        super().__init__(bond_metrics)
167 |
168 |
class TrainMolecularMetricsDiscrete(nn.Module):
    """Tracks per-atom-type and per-bond-type CE metrics during training."""

    def __init__(self, dataset_infos):
        super().__init__()
        self.train_atom_metrics = AtomMetricsCE(dataset_infos=dataset_infos)
        self.train_bond_metrics = BondMetricsCE()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Update both collections; when ``log`` is set, also send the
        current values to wandb under the ``train/`` prefix."""
        self.train_atom_metrics(masked_pred_X, true_X)
        self.train_bond_metrics(masked_pred_E, true_E)
        if not log:
            return

        to_log = {}
        for collection in (self.train_atom_metrics, self.train_bond_metrics):
            for key, val in collection.compute().items():
                to_log["train/" + key] = val.item()
        if wandb.run:
            wandb.log(to_log, commit=False)

    def reset(self):
        """Reset the accumulated state of both collections."""
        self.train_atom_metrics.reset()
        self.train_bond_metrics.reset()

    def log_epoch_metrics(self):
        """Log epoch-level values to wandb (``train_epoch/`` prefix) and
        return ``(atom_metrics, bond_metrics)`` as dicts of Python numbers."""
        epoch_atom_metrics = self.train_atom_metrics.compute()
        epoch_bond_metrics = self.train_bond_metrics.compute()
        both = (epoch_atom_metrics, epoch_bond_metrics)

        to_log = {}
        for results in both:
            for key, val in results.items():
                to_log["train_epoch/" + key] = val.item()
        if wandb.run:
            wandb.log(to_log, commit=False)

        # Convert the values in place so callers receive plain numbers.
        for results in both:
            for key in list(results):
                results[key] = results[key].item()

        return epoch_atom_metrics, epoch_bond_metrics
209 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch_geometric.utils
3 | from omegaconf import OmegaConf, open_dict
4 | from torch_geometric.utils import to_dense_adj, to_dense_batch
5 | import torch
6 | import omegaconf
7 | import wandb
8 |
9 |
def create_folders(args):
    """Create the output folders used for generated graphs and chains.

    Creates ``graphs``, ``chains`` and, inside each of them, a sub-folder
    named ``args.general.name`` (all relative to the current working
    directory). Directories that already exist are left untouched.

    Args:
        args: config object exposing ``args.general.name``.

    Raises:
        OSError: if a directory cannot be created for a reason other than
            already existing (e.g. missing permissions).
    """
    run_name = args.general.name
    for root in ("graphs", "chains"):
        # exist_ok keeps every directory independent: previously a single
        # shared try/except meant that an already-existing "graphs" folder
        # aborted the creation of the "chains" one.
        os.makedirs(root, exist_ok=True)
        os.makedirs(root + "/" + run_name, exist_ok=True)
24 |
25 |
def normalize(X, E, y, norm_values, norm_biases, node_mask):
    """Shift and scale X, E, y, zero the edge diagonal and apply the node mask.

    ``norm_values`` and ``norm_biases`` are indexed as [X, E, y].
    Returns a masked ``PlaceHolder``.
    """
    bias_X, scale_X = norm_biases[0], norm_values[0]
    X = (X - bias_X) / scale_X
    bias_E, scale_E = norm_biases[1], norm_values[1]
    E = (E - bias_E) / scale_E
    bias_y, scale_y = norm_biases[2], norm_values[2]
    y = (y - bias_y) / scale_y

    # No self-loops: clear the (i, i) entries of every graph in the batch.
    batch_size, n_nodes = E.shape[0], E.shape[1]
    self_loops = torch.eye(n_nodes, dtype=torch.bool)
    self_loops = self_loops.unsqueeze(0).expand(batch_size, -1, -1)
    E[self_loops] = 0

    normalized = PlaceHolder(X=X, E=E, y=y)
    return normalized.mask(node_mask)
37 |
38 |
def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
    """Invert ``normalize``: rescale, re-add the biases and mask the result.

    X : node features
    E : edge features
    y : global features
    norm_values : [norm value X, norm value E, norm value y]
    norm_biases : same order
    node_mask : forwarded to ``PlaceHolder.mask``
    collapse : forwarded to ``PlaceHolder.mask``
    """
    scale_X, bias_X = norm_values[0], norm_biases[0]
    X = X * scale_X + bias_X
    scale_E, bias_E = norm_values[1], norm_biases[1]
    E = E * scale_E + bias_E
    scale_y, bias_y = norm_values[2], norm_biases[2]
    y = y * scale_y + bias_y

    restored = PlaceHolder(X=X, E=E, y=y)
    return restored.mask(node_mask, collapse)
53 |
54 |
def symmetrize_and_mask_diag(E):
    """Mirror the strict upper triangle of ``E`` onto the lower one and
    zero the diagonal.

    ``E`` is indexed as (batch, i, j) or (batch, i, j, channel). A new
    tensor is returned; the argument is not modified.
    """
    rows, cols = torch.triu_indices(row=E.size(1), col=E.size(2), offset=1)

    # 1 on the strict upper triangle, 0 elsewhere.
    keep_upper = torch.zeros_like(E)
    if E.dim() == 4:
        keep_upper[:, rows, cols, :] = 1
    else:
        keep_upper[:, rows, cols] = 1

    upper = E * keep_upper
    symmetric = upper + upper.transpose(1, 2)

    # Clear the (i, i) entries of every element of the batch.
    batch_diag = torch.eye(symmetric.shape[1], dtype=torch.bool)
    batch_diag = batch_diag.unsqueeze(0).expand(symmetric.shape[0], -1, -1)
    symmetric[batch_diag] = 0

    return symmetric
72 |
73 |
def to_dense(x, edge_index, edge_attr, batch):
    """Convert a sparse batch into dense tensors.

    Returns:
        ``(PlaceHolder(X, E, y=None), node_mask)`` where ``X`` and
        ``node_mask`` come from ``to_dense_batch`` and ``E`` is the dense
        adjacency (self-loops removed) passed through ``encode_no_edge``.
    """
    X, node_mask = to_dense_batch(x=x, batch=batch)

    loop_free_index, loop_free_attr = torch_geometric.utils.remove_self_loops(
        edge_index, edge_attr
    )
    dense_E = to_dense_adj(
        edge_index=loop_free_index,
        batch=batch,
        edge_attr=loop_free_attr,
        max_num_nodes=X.size(1),
    )

    return PlaceHolder(X=X, E=encode_no_edge(dense_E), y=None), node_mask
90 |
91 |
def encode_no_edge(E):
    """Mark absent edges in channel 0 and clear the diagonal.

    ``E`` must be 4-dimensional. It is modified in place and returned.
    Positions whose channels sum to zero get a 1 in channel 0; afterwards
    every (i, i) entry is zeroed across all channels.
    """
    assert len(E.shape) == 4
    if E.shape[-1] == 0:
        return E

    missing = E.sum(dim=3) == 0
    # E[:, :, :, 0] is a view, so this writes straight into E.
    E[:, :, :, 0][missing] = 1

    self_loops = torch.eye(E.shape[1], dtype=torch.bool)
    self_loops = self_loops.unsqueeze(0).expand(E.shape[0], -1, -1)
    E[self_loops] = 0
    return E
105 |
106 |
def update_config_with_new_keys(cfg, saved_cfg):
    """Copy keys that exist only in ``saved_cfg`` into ``cfg``.

    For each of the ``general``, ``train`` and ``model`` sections, every key
    of the saved section that is missing from the current one is added with
    the saved value. Keys already present in ``cfg`` keep their value.

    Args:
        cfg: current config; modified in place.
        saved_cfg: config providing the fallback values.

    Returns:
        ``cfg``, for convenience.
    """
    sections = (
        (cfg.general, saved_cfg.general),
        (cfg.train, saved_cfg.train),
        (cfg.model, saved_cfg.model),
    )

    for current, saved in sections:
        # All three sections are handled identically; the struct flag and
        # open_dict context are set up once per section, not once per key.
        OmegaConf.set_struct(current, True)
        with open_dict(current):
            for key, val in saved.items():
                if key not in current.keys():
                    setattr(current, key, val)
    return cfg
130 |
131 |
class PlaceHolder:
    """Mutable container bundling node (X), edge (E) and global (y) data.

    ``type_as``, ``to_device`` and ``mask`` modify the instance in place and
    return ``self`` so that calls can be chained. ``y`` may be ``None``.
    """

    def __init__(self, X, E, y):
        self.X = X
        self.E = E
        self.y = y

    def type_as(self, x: torch.Tensor):
        """Cast X, E and (when present) y with ``Tensor.type_as(x)``."""
        self.X = self.X.type_as(x)
        self.E = self.E.type_as(x)
        # y is optional (to_dense builds placeholders with y=None), so it is
        # guarded here exactly as in to_device.
        self.y = self.y.type_as(x) if self.y is not None else None
        return self

    def to_device(self, device):
        """Move X, E and (when present) y to ``device``."""
        self.X = self.X.to(device)
        self.E = self.E.to(device)
        self.y = self.y.to(device) if self.y is not None else None
        return self

    def mask(self, node_mask, collapse=False):
        """Apply ``node_mask`` to X and E.

        Args:
            node_mask: (bs, n) mask; entries equal to 0 mark padding nodes.
            collapse: if true, X and E are first replaced by their argmax
                over the last dimension and masked positions are set to -1;
                otherwise masked positions are multiplied by zero and E is
                asserted to be symmetric.
        """
        x_mask = node_mask.unsqueeze(-1)  # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            self.X[node_mask == 0] = -1
            # An edge is kept only when both of its endpoints are kept.
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
        return self

    def __repr__(self):
        def describe(value):
            # Tensors are summarised by their shape, anything else verbatim.
            return value.shape if isinstance(value, torch.Tensor) else value

        return (
            f"X: {describe(self.X)} -- "
            + f"E: {describe(self.E)} -- "
            + f"y: {describe(self.y)}"
        )

    def split(self, node_mask):
        """Split a PlaceHolder representing a batch into a list of placeholders representing individual graphs."""
        graph_list = []
        batch_size = self.X.shape[0]
        for i in range(batch_size):
            # Number of nodes kept for graph i; used to cut X and E.
            n = torch.sum(node_mask[i], dim=0)
            x = self.X[i, :n]
            e = self.E[i, :n, :n]
            y = self.y[i] if self.y is not None else None
            graph_list.append(PlaceHolder(X=x, E=e, y=y))
        return graph_list
187 |
188 |
def setup_wandb(cfg):
    """Initialise a wandb run from the config and save ``*.txt`` files.

    The run name is ``cfg.general.name`` when ``cfg.general.test_only`` is
    None; otherwise it is suffixed with either the search mode or the
    sampling parameters (eta, rdb, time distortion).
    """
    config_dict = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )

    if cfg.general.test_only is None:
        run_name = f"{cfg.general.name}"
    elif cfg.sample.search:
        run_name = f"{cfg.general.name}_search_{cfg.sample.search}"
    else:
        run_name = f"{cfg.general.name}_eta{cfg.sample.eta}_{cfg.sample.rdb}_{cfg.sample.time_distortion}"

    # Record where the run was launched from alongside the config.
    config_dict["general"]["local_dir"] = os.getcwd()

    wandb.init(
        name=run_name,
        project=f"graph_dfm_{cfg.dataset.name}",
        config=config_dict,
        settings=wandb.Settings(_disable_stats=True),
        reinit=True,
        mode=cfg.general.wandb,
    )
    wandb.save("*.txt")
211 |
--------------------------------------------------------------------------------
/src/metrics/abstract_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import functional as F
4 | from torch.nn import KLDivLoss
5 | from torchmetrics import Metric, MeanSquaredError
6 |
7 |
class TrainAbstractMetricsDiscrete(torch.nn.Module):
    """No-op training metrics: every method accepts the standard arguments
    and does nothing."""

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Ignore the inputs; nothing is tracked."""
        return None

    def reset(self):
        """Nothing to reset."""
        return None

    def log_epoch_metrics(self):
        """Return ``(None, None)`` since no metrics are tracked."""
        return None, None
20 |
21 |
class SumExceptBatchMetric(Metric):
    """Sum of all values divided by the accumulated size of dimension 0."""

    def __init__(self):
        super().__init__()
        self.add_state("total_value", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, values) -> None:
        """Add the sum of ``values`` and count ``values.shape[0]`` samples."""
        n_in_batch = values.shape[0]
        self.total_value += values.sum()
        self.total_samples += n_in_batch

    def compute(self):
        """Return accumulated value per accumulated sample."""
        return self.total_value / self.total_samples
34 |
35 |
class SumExceptBatchMSE(MeanSquaredError):
    """MeanSquaredError variant whose observation count is the size of
    dimension 0 rather than the number of elements."""

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate the squared error of one batch.

        Args:
            preds: Predictions from model
            target: Ground truth values (must have the same shape as preds)
        """
        assert preds.shape == target.shape
        batch_sse, batch_count = self._mean_squared_error_update(preds, target)

        self.sum_squared_error += batch_sse
        self.total += batch_count

    def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
        """Return ``(sum of squared differences, preds.shape[0])``.

        preds: Predicted tensor
        target: Ground truth tensor
        """
        residual = preds - target
        return torch.sum(residual * residual), preds.shape[0]
60 |
61 |
class SumExceptBatchKL(Metric):
    """Summed ``F.kl_div(q, p)`` divided by the accumulated size of
    dimension 0 of ``p``."""

    def __init__(self):
        super().__init__()
        self.add_state("total_value", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, p, q) -> None:
        """Accumulate the summed KL term for one batch.

        Note the argument order: ``q`` is passed as the first argument of
        ``F.kl_div`` and ``p`` as the second.
        """
        batch_kl = F.kl_div(q, p, reduction="sum")
        self.total_value += batch_kl
        self.total_samples += p.size(0)

    def compute(self):
        """Return accumulated KL per accumulated sample."""
        return self.total_value / self.total_samples
74 |
75 |
class CrossEntropyMetric(Metric):
    """Cross-entropy summed over rows and averaged over accumulated rows."""

    def __init__(self):
        super().__init__()
        self.add_state("total_ce", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor, weight: Tensor = None) -> None:
        """Update state with predictions and targets.

        preds: Predictions from model (bs * n, d) or (bs * n * n, d)
        target: Ground truth values (bs * n, d) or (bs * n * n, d);
            converted to class indices with argmax over the last dimension.
        weight: optional per-row factor applied to the unreduced loss.
        """
        labels = torch.argmax(target, dim=-1)
        if weight is None:
            batch_ce = F.cross_entropy(preds, labels, reduction="sum")
        else:
            per_row = F.cross_entropy(preds, labels, reduction="none")
            batch_ce = (per_row * weight).sum()

        self.total_ce += batch_ce
        self.total_samples += preds.size(0)

    def compute(self):
        """Return accumulated cross-entropy per accumulated row."""
        return self.total_ce / self.total_samples
108 |
109 |
class KLDMetric(Metric):
    """Accumulates a summed ``KLDivLoss`` and averages it over the
    accumulated size of dimension 0 of the predictions."""

    def __init__(self):
        super().__init__()
        # State names mirror CrossEntropyMetric; total_ce holds the KL sum.
        self.add_state("total_ce", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor, weight: Tensor = None) -> None:
        """Update state with predictions and targets.
        preds: Predictions from model (bs * n, d) or (bs * n * n, d)
        target: Ground truth values (bs * n, d) or (bs * n * n, d)."""
        # NOTE(review): KLDivLoss is called as (preds, target); presumably
        # preds are expected to be log-probabilities -- confirm with callers.
        # target = torch.argmax(target, dim=-1)
        if weight is not None:
            output = KLDivLoss(reduction="none")(
                preds,
                target,
            )
            # Element-wise weighting, then reduce to a scalar.
            output = (output * weight).sum()
        else:
            output = KLDivLoss(reduction="none")(
                preds,
                target,
            )

            # NOTE(review): the NaN zeroing and the final sum are taken to
            # belong to this unweighted branch; confirm whether the weighted
            # branch should also zero NaNs.
            output[output.isnan()] = 0  # zero-out masked places

            output = output.sum()
        # output = F.cross_entropy(preds, target, reduction="sum")
        self.total_ce += output
        self.total_samples += preds.size(0)

    def compute(self):
        """Return the accumulated loss per accumulated row."""
        return self.total_ce / self.total_samples
142 |
143 |
class ProbabilityMetric(Metric):
    """Running mean of all values passed to ``update``."""

    def __init__(self):
        """This metric is used to track the marginal predicted probability of a class during training."""
        super().__init__()
        self.add_state("prob", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Add the sum of ``preds`` and count its elements."""
        n_elements = preds.numel()
        self.prob += preds.sum()
        self.total += n_elements

    def compute(self):
        """Return the mean over every element seen so far."""
        return self.prob / self.total
157 |
158 |
class NLL(Metric):
    """Running mean of the negative log-likelihood values passed in."""

    def __init__(self):
        super().__init__()
        self.add_state("total_nll", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, batch_nll) -> None:
        """Add the sum of ``batch_nll`` and count its elements."""
        n_elements = batch_nll.numel()
        self.total_nll += batch_nll.sum()
        self.total_samples += n_elements

    def compute(self):
        """Return the mean NLL over every element seen so far."""
        return self.total_nll / self.total_samples
171 |
172 |
def compute_ratios(gen_metrics, ref_metrics, metrics_keys):
    """Compute generated/reference ratios for the requested metrics.

    Args:
        gen_metrics: mapping from metric name to the generated value.
        ref_metrics: mapping from metric name to the reference value, or
            None when no reference is available.
        metrics_keys: names of the metrics to compare.

    Returns:
        A dict with one ``"<key>_ratio"`` entry per usable key plus
        ``"average_ratio"`` (their mean, or -1 when no ratio could be
        computed). An empty dict is returned when ``ref_metrics`` is None
        or ``metrics_keys`` is empty.
    """
    print("Computing ratios of metrics: ", metrics_keys)
    if ref_metrics is None or len(metrics_keys) == 0:
        print("WARNING: No reference metrics for ratio computation.")
        return {}

    ratios = {}
    for key in metrics_keys:
        try:
            # Reference values are rounded so that tiny values count as 0.
            ref_metric = round(ref_metrics[key], 4)
        except (KeyError, TypeError):
            # Only "missing / unusable reference" is skipped; unlike a bare
            # except, KeyboardInterrupt and SystemExit still propagate.
            print(key, "not found")
            continue
        if ref_metric != 0.0:
            ratios[key + "_ratio"] = gen_metrics[key] / ref_metric
        else:
            print(f"WARNING: Reference {key} is 0. Skipping its ratio.")

    if len(ratios) > 0:
        ratios["average_ratio"] = sum(ratios.values()) / len(ratios)
    else:
        ratios["average_ratio"] = -1
        print("WARNING: no ratio being saved.")

    return ratios
197 |
--------------------------------------------------------------------------------
/src/metrics/tls_metrics.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import Counter
3 | from typing import List
4 |
5 | import networkx as nx
6 | import numpy as np
7 | import scipy.sparse as sp
8 | import wandb
9 | import torch
10 | from torch import Tensor
11 | from torchmetrics import MeanMetric
12 |
13 | from datasets.tls_dataset import CellGraph
14 | from analysis.spectre_utils import PlanarSamplingMetrics
15 | from analysis.spectre_utils import is_planar_graph
16 |
17 |
class TLSSamplingMetrics(PlanarSamplingMetrics):
    """Sampling metrics for TLS cell graphs.

    Extends PlanarSamplingMetrics: the parent's metrics are computed first,
    then cell-graph specific ones (label accuracy, uniqueness, novelty,
    validity) are added under the ``tls_metrics/`` prefix.
    """

    def __init__(self, datamodule):
        super().__init__(datamodule)
        # Cache every split as CellGraph objects; train is used for novelty.
        self.train_cell_graphs = self.loader_to_cell_graphs(
            datamodule.train_dataloader()
        )
        self.val_cell_graphs = self.loader_to_cell_graphs(datamodule.val_dataloader())
        self.test_cell_graphs = self.loader_to_cell_graphs(datamodule.test_dataloader())

    def loader_to_cell_graphs(self, loader):
        """Convert every graph yielded by ``loader`` into a CellGraph list."""
        cell_graphs = []
        for batch in loader:
            for tg_graph in batch.to_data_list():
                cell_graph = CellGraph.from_torch_geometric(tg_graph)
                cell_graphs.append(cell_graph)

        return cell_graphs

    def is_cell_graph_valid(self, cg: CellGraph):
        """Validity predicate for a cell graph; delegates to is_planar_graph."""
        # connected and planar
        # NOTE(review): whether is_planar_graph also checks connectivity is
        # not visible from this module -- confirm in analysis.spectre_utils.
        return is_planar_graph(cg)

    def forward(
        self,
        generated_graphs: list,
        ref_metrics,
        name,
        current_epoch,
        val_counter,
        local_rank,
        test=False,
        labels=None,
    ):
        """Compute planar metrics plus TLS-specific metrics.

        Args:
            generated_graphs: generated graphs; each is converted with
                ``CellGraph.from_dense_graph``.
            ref_metrics, name, current_epoch, val_counter: forwarded to the
                parent ``forward``.
            local_rank: progress messages are printed only on rank 0.
            test: forwarded to the parent ``forward``.
            labels: optional true labels aligned with ``generated_graphs``;
                when None the accuracy entries are reported as -1.

        Returns:
            The parent's metrics dict updated with the ``tls_metrics/*``
            entries.
        """

        # Unattributed graphs specific
        to_log = super().forward(
            generated_graphs,
            ref_metrics,
            name,
            current_epoch,
            val_counter,
            local_rank,
            test,
            labels,
        )

        # Cell graph specific
        # NOTE(review): reference_cgs is not used further down in this method.
        reference_cgs = self.test_cell_graphs if test else self.val_cell_graphs
        generated_cgs = []
        if local_rank == 0:
            print("Building generated cell graphs...")
        for graph in generated_graphs:
            generated_cgs.append(CellGraph.from_dense_graph(graph))

        # TODO: Implement these metrics with torchmetrics for parallelization
        generated_labels = torch.tensor([cg.to_label() for cg in generated_cgs])
        # Graphs whose label is -1 are counted as ambiguous.
        ambiguous_gen_cgs = sum(
            [(cg_label == -1).int() for cg_label in generated_labels]
        ).item()
        if labels is not None:
            true_labels = torch.tensor(labels)
            # Going by the variable names: label 1 = high TLS, 0 = low TLS.
            high_tls_idxs = true_labels == 1
            low_tls_idxs = true_labels == 0
            total_tls_acc = (generated_labels == true_labels).float().mean().item()
            high_tls_acc = (
                (generated_labels[high_tls_idxs] == true_labels[high_tls_idxs])
                .float()
                .mean()
                .item()
            )
            low_tls_acc = (
                (generated_labels[low_tls_idxs] == true_labels[low_tls_idxs])
                .float()
                .mean()
                .item()
            )
        else:
            # No labels provided: report -1 as a placeholder.
            total_tls_acc = -1
            high_tls_acc = -1
            low_tls_acc = -1

        # Compute novelty and uniqueness
        if local_rank == 0:
            print("Computing uniqueness, novelty and validity for cell graphs...")
        frac_novel = eval_fraction_novel_cell_graphs(
            generated_cell_graphs=generated_cgs,
            train_cell_graphs=self.train_cell_graphs,
        )
        (
            frac_unique,
            frac_unique_and_novel,
            frac_unique_and_novel_valid,
        ) = eval_fraction_unique_novel_valid_cell_graphs(
            generated_cell_graphs=generated_cgs,
            train_cell_graphs=self.train_cell_graphs,
            valid_cg_fn=self.is_cell_graph_valid,
        )

        tls_to_log = {
            "tls_metrics/total_tls_acc": total_tls_acc,
            "tls_metrics/high_tls_acc": high_tls_acc,
            "tls_metrics/low_tls_acc": low_tls_acc,
            "tls_metrics/num_ambiguous_tls": ambiguous_gen_cgs,
            "tls_metrics/frac_novel": frac_novel,
            "tls_metrics/frac_unique": frac_unique,
            "tls_metrics/frac_unique_and_novel": frac_unique_and_novel,
            "tls_metrics/frac_unique_and_novel_valid": frac_unique_and_novel_valid,
        }

        print(f"TLS sampling metrics: {tls_to_log}")
        if wandb.run:
            # only log TLS sampling metrics because others are already logged by planar sampling metrics
            wandb.log(tls_to_log, commit=False)

        to_log.update(tls_to_log)

        return to_log
135 |
136 |
137 | # specific for cell graphs (isomorphism function is of cell graphs)
def eval_fraction_novel_cell_graphs(generated_cell_graphs, train_cell_graphs):
    """Return the fraction of generated cell graphs that match no training
    cell graph.

    A generated graph counts as known when some training graph passes both
    the cheap ``nx.faster_could_be_isomorphic`` pre-check and the cell
    graph's own ``is_isomorphic`` test.
    """

    def seen_in_training(gen_cg):
        return any(
            nx.faster_could_be_isomorphic(train_cg, gen_cg)
            and gen_cg.is_isomorphic(train_cg)
            for train_cg in train_cell_graphs
        )

    n_known = sum(1 for gen_cg in generated_cell_graphs if seen_in_training(gen_cg))
    return 1 - n_known / len(generated_cell_graphs)
147 |
148 |
149 | # specific for cell graphs (isomorphism function is of cell graphs)
def eval_fraction_unique_novel_valid_cell_graphs(
    generated_cell_graphs,
    train_cell_graphs,
    valid_cg_fn,
):
    """Return ``(frac_unique, frac_unique_novel, frac_unique_novel_valid)``.

    Each generated graph is first compared with the graphs generated before
    it (uniqueness); unique graphs are compared with the training graphs
    (novelty); unique and novel graphs are finally passed to ``valid_cg_fn``.
    All three fractions are relative to the number of generated graphs.
    """

    def same_cell_graph(candidate, other):
        # Cheap necessary condition first; the cell graph's own isomorphism
        # test (which also considers node phenotypes) runs only if it passes.
        return nx.faster_could_be_isomorphic(
            other, candidate
        ) and candidate.is_isomorphic(other)

    n_duplicates = 0
    n_known = 0
    n_invalid = 0
    for position, gen_cg in enumerate(generated_cell_graphs):
        earlier = generated_cell_graphs[:position]
        if any(same_cell_graph(gen_cg, previous) for previous in earlier):
            n_duplicates += 1
            continue
        if any(same_cell_graph(gen_cg, train_cg) for train_cg in train_cell_graphs):
            n_known += 1
            continue
        if not valid_cg_fn(gen_cg):
            n_invalid += 1

    n_generated = len(generated_cell_graphs)
    frac_unique = 1 - n_duplicates / n_generated
    frac_unique_non_isomorphic = frac_unique - n_known / n_generated
    frac_unique_non_isomorphic_valid = (
        frac_unique_non_isomorphic - n_invalid / n_generated
    )

    return (
        frac_unique,
        frac_unique_non_isomorphic,
        frac_unique_non_isomorphic_valid,
    )
195 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeFoG: Discrete Flow Matching for Graph Generation
2 |
3 | > A PyTorch implementation of the DeFoG model for training and sampling discrete graph flows. (Please update to the latest commit. Recent fixes have been applied.)
4 |
5 | > Paper: https://arxiv.org/pdf/2410.04263
6 |
7 | > Poster: https://icml.cc/virtual/2025/poster/45644
8 |
9 | > Oral Presentation: https://icml.cc/virtual/2025/oral/47238
10 |
11 |
12 | 
13 |
14 |
15 |
16 |
17 |
18 |
19 | ---
20 |
21 | ## 📝 Updates
22 |
23 | > Working with directed graphs? Consider using [DIRECTO](https://github.com/acarballocastro/DIRECTO), a discrete flow matching framework for directed graph generation.
24 |
25 | > For an updated development environment with modernized dependencies, see the `updated_env` branch. The `main` branch remains the reference implementation, based on Python 3.9 and older package versions.
26 |
27 |
28 |
29 | ## 🚀 Installation
30 |
31 | We provide two alternative installation methods: Docker and Conda.
32 |
33 | ### 🐳 Docker
34 |
35 | We provide the Dockerfile to run DeFoG in a container.
36 | 1. Build the Docker image:
37 | ```bash
38 | docker build --platform=linux/amd64 -t defog-image .
39 | ```
40 |
41 | ⚠️ Once you clone DeFoG's git repository to your workspace, you may need to run `pip install -e .` to make all the repository modules visible.
42 |
43 |
44 | ### 🐍 Conda
45 |
46 | 1. Install Conda (we used version 25.1.1) and create DeFoG's environment:
47 | ```bash
48 | conda env create -f environment.yaml
49 | conda activate defog
50 | ```
51 | 2. Run the following commands to check if the installation of the main packages was successful:
52 | ```bash
53 | python -c "import sys; print('Python version:', sys.version)"
54 | python -c "import rdkit; print('RDKit version:', rdkit.__version__)"
55 | python -c "import graph_tool as gt; print('Graph-Tool version:', gt.__version__)"
56 | python -c "import torch; print(f'PyTorch version: {torch.__version__}, CUDA version (via PyTorch): {torch.version.cuda}')"
57 | python -c "import torch_geometric as tg; print('PyTorch Geometric version:', tg.__version__)"
58 | ```
59 | If you see no errors, the installation was successful and you can proceed to the next step.
60 | 3. Compile the ORCA evaluator:
61 | ```bash
62 | cd src/analysis/orca
63 | g++ -O2 -std=c++11 -o orca orca.cpp
64 | ```
65 |
66 | ⚠️ Tested on Ubuntu.
67 |
68 | ---
69 |
70 | ## ⚙️ Usage
71 |
72 | All commands use `python main.py` with [Hydra](https://hydra.cc/) overrides. Note that `main.py` is inside the `src` directory.
73 |
74 | ### Quick start
75 |
76 | Use this script to quickly test the code.
77 |
78 | ```bash
79 | python main.py +experiment=debug
80 | ```
81 |
82 | ### Full training
83 |
84 | ```bash
85 | python main.py +experiment=<experiment_name> dataset=<dataset_name>
86 | ```
87 |
88 | - **QM9 (no H):** `+experiment=qm9_no_h dataset=qm9`
89 | - **Planar:** `+experiment=planar dataset=planar`
90 | - **SBM:** `+experiment=sbm dataset=sbm`
91 | - **Tree:** `+experiment=tree dataset=tree`
92 | - **Comm20:** `+experiment=comm20 dataset=comm20`
93 | - **Guacamol:** `+experiment=guacamol dataset=guacamol`
94 | - **MOSES:** `+experiment=moses dataset=moses`
95 | - **QM9 (with H):** `+experiment=qm9_with_h dataset=qm9`
96 | - **TLS (conditional):** `+experiment=tls dataset=tls`
97 | - **ZINC:** `+experiment=zinc dataset=zinc`
98 |
99 | ---
100 |
101 | ## 📊 Evaluation
102 |
103 | Sampling from DeFoG is typically done in two steps:
104 |
105 | 1. **Sampling Optimization** → find best sampling configuration
106 | 2. **Final Sampling** → sample and measure performance under the best configuration
107 |
108 | To perform 5 runs (mean ± std), set `general.num_sample_fold=5`.
109 |
110 | For the rest of this section, we take the Planar dataset as an example:
111 |
112 | ### Default sampling
113 | ```bash
114 | python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.eta=0 sample.omega=0 sample.time_distortion=identity
115 | ```
116 |
117 | Note that if you run:
118 | ```bash
119 | python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint>
120 | ```
121 | it will run with the sampling parameters (η, ω, sample distortion) that we obtained after sampling optimization (see next section) and are reported in the paper.
122 |
123 | ### Sampling optimization
124 | To search for the optimal inference hyperparameters (η, ω, distortion), use the `sample.search` flag, which saves a CSV file with the results.
125 | - **Non-grid search** (independent search for each component):
126 | ```bash
127 | python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.search=all
128 | ```
129 | - **Component-wise**: set `sample.search=target_guidance | distortion | stochasticity` above.
130 |
131 | ⚠️ We set the default search intervals for each sampling parameter as we used in our experiments. You may want to adjust these intervals according to your needs.
132 |
133 | ### Final sampling
134 | Use optimal η, ω, time distortion resulting from the search:
135 | ```bash
136 | python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.eta=<η> sample.omega=<ω> sample.time_distortion=<distortion>
137 | ```
138 |
139 | ---
140 |
141 | ## 🌐 Extend DeFoG to new datasets
142 |
143 | Start by creating a new file in the `src/datasets` directory. You can refer to the following scripts as examples:
144 | - `spectre_dataset.py`, if you are using unattributed graphs;
145 | - `tls_dataset.py`, if you are using graphs with attributed nodes;
146 | - `qm9_dataset.py` or `guacamol_dataset.py`, if you are using graphs with attributed nodes and edges (e.g., molecular data).
147 |
148 | This new file should define a `Dataset` class to handle data processing (refer to the [PyG documentation](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html) for guidance), as well as a `DatasetInfos` class to specify relevant dataset properties (e.g., number of nodes, edges, etc.).
149 |
150 | Once your dataset file is ready, update `main.py` to incorporate the new dataset. Additionally, you can add a corresponding file in the `configs/dataset` directory.
151 |
152 | Finally, if you plan to introduce custom metrics, you can create a new file under the `src/metrics` directory.
153 |
154 | ---
155 |
156 | ## Checkpoints
157 |
158 | Checkpoints, along with their corresponding results and generated samples, are shared [here](https://drive.switch.ch/index.php/s/MG7y2EZoithAywE).
159 |
160 | To run sampling and evaluate generation with a given checkpoint, set the `general.test_only` flag to the path of the checkpoint file (`.ckpt` file). To skip sampling and directly evaluate previously generated samples, set the flag `general.generated_path` to the path of the generated samples (`.pkl` file).
161 |
162 | (*Note*: The released checkpoints are retrained models from the public repository. Their performance is consistent with the paper’s findings, with minor variations attributable to training/sampling stochasticity.)
163 |
164 | ---
165 |
166 | ## 📌 Upon request
167 |
168 | - protein / EGO datasets
169 | - FCD score for molecules
170 | - W&B sweeps for sampling optimization
171 |
172 |
173 |
174 | ---
175 | ## 🙏 Acknowledgements
176 |
177 | - DiGress: https://github.com/cvignac/DiGress
178 | - Discrete Flow Models: https://github.com/andrew-cr/discrete_flow_models
179 |
180 | ---
181 |
182 | ## 📚 Citation
183 |
184 | ```bibtex
185 | @inproceedings{qinmadeira2024defog,
186 | title = {DeFoG: Discrete Flow Matching for Graph Generation},
187 | author = {Qin, Yiming and Madeira, Manuel and Thanou, Dorina and Frossard, Pascal},
188 | booktitle = {International Conference on Machine Learning (ICML)},
189 | year = {2025},
190 | }
191 | ```
192 |
--------------------------------------------------------------------------------
/src/datasets/abstract_dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import torch
5 | import wandb
6 | import pytorch_lightning as pl
7 | from torch_geometric.loader import DataLoader
8 | from torch_geometric.data.lightning import LightningDataset
9 | from tqdm import tqdm
10 |
11 | import utils as utils
12 | from datasets.dataset_utils import DistributionNodes, load_pickle, save_pickle
13 |
14 |
class AbstractDataModule(LightningDataset):
    """Data module bundling the train/val/test graph datasets.

    On top of the dataloaders provided by ``LightningDataset``, it exposes
    helpers that iterate over the loaders to estimate empirical marginals
    (graph sizes, node types, edge types).
    """

    def __init__(self, cfg, datasets):
        """Build the loaders from ``datasets`` (keys: "train", "val", "test").

        The batch size comes from ``cfg.train.batch_size`` unless the run name
        (``cfg.general.name``) contains "debug", in which case it is forced to 2.
        """
        super().__init__(
            train_dataset=datasets["train"],
            val_dataset=datasets["val"],
            test_dataset=datasets["test"],
            batch_size=2 if "debug" in cfg.general.name else cfg.train.batch_size,
            num_workers=cfg.train.num_workers,
            pin_memory=getattr(cfg.dataset, "pin_memory", False),
        )
        self.cfg = cfg
        self.input_dims = None
        self.output_dims = None
        print(
            f'This dataset contains {len(datasets["train"])} training graphs, {len(datasets["val"])} validation graphs, {len(datasets["test"])} test graphs.'
        )

    def __getitem__(self, idx):
        """Index directly into the training dataset."""
        return self.train_dataset[idx]

    def node_counts(self, max_nodes_possible=1000):
        """Return the normalized histogram of graph sizes over train and val.

        Entry ``k`` of the result is the fraction of graphs with ``k`` nodes;
        the histogram is truncated after the largest size observed.
        """
        histogram = torch.zeros(max_nodes_possible)
        loaders = (self.train_dataloader(), self.val_dataloader())
        for loader in loaders:
            for batch in loader:
                # `batch.batch` assigns each node to its graph, so the counts
                # of its unique values are the per-graph node counts.
                _, graph_sizes = torch.unique(batch.batch, return_counts=True)
                for size in graph_sizes:
                    histogram[size] += 1
        largest = max(histogram.nonzero())
        histogram = histogram[: largest + 1]
        return histogram / histogram.sum()

    def node_types(self):
        """Return the normalized column sums of ``x`` over the training set."""
        feature_dim = None
        # Peek at the first batch only to learn the width of the node features.
        for batch in self.train_dataloader():
            feature_dim = batch.x.shape[1]
            break

        totals = torch.zeros(feature_dim)
        for batch in self.train_dataloader():
            totals += batch.x.sum(dim=0)

        return totals / totals.sum()

    def edge_counts(self):
        """Return the normalized edge-type distribution over the training set.

        Index 0 accumulates the ordered node pairs that carry no edge; the
        remaining indices accumulate the column sums of ``edge_attr``.
        """
        n_edge_classes = None
        # Peek at the first batch only to learn the width of the edge features.
        for batch in self.train_dataloader():
            n_edge_classes = batch.edge_attr.shape[1]
            break

        totals = torch.zeros(n_edge_classes, dtype=torch.float)

        for batch in self.train_dataloader():
            _, graph_sizes = torch.unique(batch.batch, return_counts=True)

            # Number of ordered (i, j), i != j, node pairs inside each graph.
            ordered_pairs = sum(size * (size - 1) for size in graph_sizes)
            missing_edges = ordered_pairs - batch.edge_index.shape[1]

            per_type = batch.edge_attr.sum(dim=0)
            assert missing_edges >= 0
            totals[0] += missing_edges
            totals[1:] += per_type[1:]

        return totals / totals.sum()
87 |
88 |
class MolecularDataModule(AbstractDataModule):
    """Data module for molecular graphs; adds valency statistics."""

    def valency_count(self, max_n_nodes, zinc=False):
        """Return the empirical distribution of atom valencies in the training set.

        The valency of an atom is the sum, over the edges leaving it, of the
        bond order associated with each edge-type column of ``edge_attr``,
        truncated to an integer.

        Args:
            max_n_nodes: Largest graph size; sizes the histogram
                (``3 * max_n_nodes - 2`` bins).
            zinc: If True, use the 4-column bond-order table (no aromatic
                column) instead of the default 5-column one.

        Returns:
            1-D tensor whose entry ``v`` is the fraction of atoms with valency ``v``.
        """
        # Max valency possible if everything is connected
        valencies = torch.zeros(3 * max_n_nodes - 2)

        # No bond, single bond, double bond, triple bond, aromatic bond
        multiplier = torch.tensor([0, 1, 2, 3, 1.5])
        if zinc:
            multiplier = torch.tensor([0, 1, 2, 3])  # zinc250

        for data in self.train_dataloader():
            n = data.x.shape[0]

            # Bond order contributed by each edge, then summed per source atom
            # in one scatter-add instead of masking the edge list once per atom.
            edge_orders = (data.edge_attr * multiplier).sum(dim=-1)
            atom_valency = torch.zeros(
                n, dtype=edge_orders.dtype, device=edge_orders.device
            )
            atom_valency.index_add_(0, data.edge_index[0], edge_orders)

            # `.long()` truncates exactly like the former per-atom
            # `valency.long().item()` lookup did.
            valencies += torch.bincount(
                atom_valency.long(), minlength=valencies.numel()
            )

        valencies = valencies / valencies.sum()
        return valencies
110 |
111 |
class AbstractDatasetInfos:
    """Base container for dataset statistics and model input/output sizes."""

    def complete_infos(self, n_nodes, node_types):
        """Derive the generic fields from the size and node-type marginals.

        Resets ``input_dims``/``output_dims`` and sets ``num_classes``,
        ``max_n_nodes`` and ``nodes_dist`` from the two distributions.
        """
        self.input_dims = None
        self.output_dims = None
        self.num_classes = len(node_types)
        # `n_nodes` is indexed by graph size starting at 0, hence the -1.
        self.max_n_nodes = len(n_nodes) - 1
        self.nodes_dist = DistributionNodes(n_nodes)

    def compute_input_output_dims(self, datamodule, extra_features, domain_features):
        """Infer ``input_dims`` and ``output_dims`` from one training batch.

        The raw feature widths are read from the batch, then the widths of the
        features produced by ``extra_features`` and ``domain_features`` on that
        batch are added to the input dimensions.
        """
        example_batch = next(iter(datamodule.train_dataloader()))
        ex_dense, node_mask = utils.to_dense(
            example_batch.x,
            example_batch.edge_index,
            example_batch.edge_attr,
            example_batch.batch,
        )

        example_data = {
            "X_t": ex_dense.X,
            "E_t": ex_dense.E,
            "y_t": example_batch["y"],
            "node_mask": node_mask,
        }

        node_dim = example_batch["x"].size(1)
        edge_dim = example_batch["edge_attr"].size(1)
        # + 1 on y due to time conditioning
        global_dim = example_batch["y"].size(1) + 1

        # Assigned before the feature extractors run, as in the original flow.
        self.input_dims = {"X": node_dim, "E": edge_dim, "y": global_dim}

        # Generic extra features first, then the domain-specific ones.
        for feature_fn in (extra_features, domain_features):
            produced = feature_fn(example_data)
            self.input_dims["X"] += produced.X.size(-1)
            self.input_dims["E"] += produced.E.size(-1)
            self.input_dims["y"] += produced.y.size(-1)

        self.output_dims = {"X": node_dim, "E": edge_dim, "y": 0}

    def compute_reference_metrics(self, datamodule, sampling_metrics):
        """Load (computing and caching on first use) the reference metrics.

        The reference metrics are obtained by running ``sampling_metrics`` on
        the training graphs, once with ``test=False`` and once with
        ``test=True``. They are pickled next to the dataset root and stored in
        ``self.ref_metrics`` under the keys "val" and "test".
        """
        dataset_root = datamodule.train_dataloader().dataset.root
        ref_metrics_path = os.path.join(dataset_root, "ref_metrics.pkl")
        if hasattr(datamodule, "remove_h"):
            # Keep separate caches for the with-H and without-H variants.
            suffix = "_no_h.pkl" if datamodule.remove_h else "_h.pkl"
            ref_metrics_path = ref_metrics_path.replace(".pkl", suffix)

        # Only compute the reference metrics if they haven't been computed already
        if not os.path.exists(ref_metrics_path):
            print("Reference metrics not found. Computing them now.")

            # Transform the training dataset into a list of graphs in the appropriate format
            training_graphs = []
            print("Converting training dataset to format required by sampling metrics.")
            for data_batch in tqdm(datamodule.train_dataloader()):
                dense_batch, node_mask = utils.to_dense(
                    data_batch.x,
                    data_batch.edge_index,
                    data_batch.edge_attr,
                    data_batch.batch,
                )
                per_graph = dense_batch.mask(node_mask, collapse=True).split(node_mask)
                training_graphs.extend([graph.X, graph.E] for graph in per_graph)

            # defining dummy arguments
            dummy_kwargs = {
                "name": "ref_metrics",
                "current_epoch": 0,
                "val_counter": 0,
                "local_rank": 0,
                "ref_metrics": {"val": None, "test": None},
            }

            computed = {}
            for split, label, is_test in (
                ("val", "validation", False),
                ("test", "test", True),
            ):
                print(f"Computing {label} reference metrics.")
                # do not have to worry about wandb because it was not init yet
                split_metrics = copy.deepcopy(sampling_metrics)
                computed[split] = split_metrics.forward(
                    training_graphs,
                    test=is_test,
                    **dummy_kwargs,
                )

            print("Saving reference metrics.")
            save_pickle(
                {"val": computed["val"], "test": computed["test"]}, ref_metrics_path
            )

        print("Loading reference metrics.")
        self.ref_metrics = load_pickle(ref_metrics_path)
        print("Validation reference metrics:", self.ref_metrics["val"])
        print("Test reference metrics:", self.ref_metrics["test"])
226 |
--------------------------------------------------------------------------------
/src/datasets/tu_dataset_origin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import os.path as osp
4 | from typing import Callable, List, Optional
5 |
6 | import numpy as np
7 | import torch.nn.functional as F
8 | import torch
9 | from torch.utils.data import random_split
10 | import torch_geometric.utils
11 | from torch_geometric.io import fs, read_tu_data
12 | from torch_geometric.utils import remove_self_loops
13 | from torch_geometric.data import Data, InMemoryDataset, download_url
14 | from hydra.utils import get_original_cwd
15 |
16 | from utils import PlaceHolder
17 | from datasets.abstract_dataset import (
18 | AbstractDataModule,
19 | AbstractDatasetInfos,
20 | )
21 | from datasets.dataset_utils import (
22 | load_pickle,
23 | save_pickle,
24 | Statistics,
25 | to_list,
26 | RemoveYTransform,
27 | )
28 |
29 |
class TUDataset(InMemoryDataset):
    r"""A variety of graph kernel benchmark datasets, *e.g.*,
    :obj:`"IMDB-BINARY"`, :obj:`"REDDIT-BINARY"` or :obj:`"PROTEINS"`,
    collected from the `TU Dortmund University
    <https://chrsmrrs.github.io/datasets>`_.
    In addition, this dataset wrapper provides `cleaned dataset versions
    <https://github.com/nd7141/graph_datasets>`_ as motivated by the
    `"Understanding Isomorphism Bias in Graph Data Sets"
    <https://arxiv.org/abs/1910.12091>`_ paper, containing only non-isomorphic
    graphs.

    .. note::
        Some datasets may not come with any node labels.
        You can then either make use of the argument :obj:`use_node_attr`
        to load additional continuous node attributes (if present) or provide
        synthetic node features using transforms such as
        :class:`torch_geometric.transforms.Constant` or
        :class:`torch_geometric.transforms.OneHotDegree`.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The `name
            <https://chrsmrrs.github.io/datasets/docs/datasets/>`_ of the
            dataset.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        use_node_attr (bool, optional): If :obj:`True`, the dataset will
            contain additional continuous node attributes (if present).
            (default: :obj:`False`)
        use_edge_attr (bool, optional): If :obj:`True`, the dataset will
            contain additional continuous edge attributes (if present).
            (default: :obj:`False`)
        cleaned (bool, optional): If :obj:`True`, the dataset will
            contain only non-isomorphic graphs. (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 20 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - MUTAG
          - 188
          - ~17.9
          - ~39.6
          - 7
          - 2
        * - ENZYMES
          - 600
          - ~32.6
          - ~124.3
          - 3
          - 6
        * - PROTEINS
          - 1,113
          - ~39.1
          - ~145.6
          - 3
          - 2
        * - COLLAB
          - 5,000
          - ~74.5
          - ~4914.4
          - 0
          - 3
        * - IMDB-BINARY
          - 1,000
          - ~19.8
          - ~193.1
          - 0
          - 2
        * - REDDIT-BINARY
          - 2,000
          - ~429.6
          - ~995.5
          - 0
          - 2
        * - ...
          -
          -
          -
          -
          -
    """

    # Base URL of the regular dataset archives (`<url>/<name>.zip`).
    url = "https://www.chrsmrrs.com/graphkerneldatasets"
    # Base URL of the "cleaned" archives, used when `cleaned=True`.
    cleaned_url = (
        "https://raw.githubusercontent.com/nd7141/" "graph_datasets/master/datasets"
    )

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        use_node_attr: bool = False,
        use_edge_attr: bool = False,
        cleaned: bool = False,
    ) -> None:
        # `raw_dir` / `processed_dir` below read `name` and `cleaned`, so both
        # are set before the base constructor runs.
        self.name = name
        self.cleaned = cleaned
        super().__init__(
            root, transform, pre_transform, pre_filter, force_reload=force_reload
        )

        # The processed file holds the tuple written by `process()`:
        # (data, slices, sizes[, data class]).
        out = fs.torch_load(self.processed_paths[0])
        if not isinstance(out, tuple) or len(out) < 3:
            raise RuntimeError(
                "The 'data' object was created by an older version of PyG. "
                "If this error occurred while loading an already existing "
                "dataset, remove the 'processed/' directory in the dataset's "
                "root folder and try again."
            )
        assert len(out) == 3 or len(out) == 4

        if len(out) == 3:  # Backward compatibility.
            data, self.slices, self.sizes = out
            data_cls = Data
        else:
            data, self.slices, self.sizes, data_cls = out

        if not isinstance(data, dict):  # Backward compatibility.
            self.data = data
        else:
            self.data = data_cls.from_dict(data)

        assert isinstance(self._data, Data)
        # Unless explicitly requested, drop the first `num_node_attributes`
        # columns of `x` (and likewise the first `num_edge_attributes` columns
        # of `edge_attr`), keeping only the remaining columns.
        if self._data.x is not None and not use_node_attr:
            num_node_attributes = self.num_node_attributes
            self._data.x = self._data.x[:, num_node_attributes:]
        if self._data.edge_attr is not None and not use_edge_attr:
            num_edge_attrs = self.num_edge_attributes
            self._data.edge_attr = self._data.edge_attr[:, num_edge_attrs:]

    @property
    def raw_dir(self) -> str:
        """Directory of the raw files: ``<root>/<name>/raw[_cleaned]``."""
        name = f'raw{"_cleaned" if self.cleaned else ""}'
        return osp.join(self.root, self.name, name)

    @property
    def processed_dir(self) -> str:
        """Directory of the processed file: ``<root>/<name>/processed[_cleaned]``."""
        name = f'processed{"_cleaned" if self.cleaned else ""}'
        return osp.join(self.root, self.name, name)

    @property
    def num_node_labels(self) -> int:
        """Value of ``sizes["num_node_labels"]`` stored with the processed data."""
        return self.sizes["num_node_labels"]

    @property
    def num_node_attributes(self) -> int:
        """Value of ``sizes["num_node_attributes"]`` stored with the processed data."""
        return self.sizes["num_node_attributes"]

    @property
    def num_edge_labels(self) -> int:
        """Value of ``sizes["num_edge_labels"]`` stored with the processed data."""
        return self.sizes["num_edge_labels"]

    @property
    def num_edge_attributes(self) -> int:
        """Value of ``sizes["num_edge_attributes"]`` stored with the processed data."""
        return self.sizes["num_edge_attributes"]

    @property
    def raw_file_names(self) -> List[str]:
        """Raw files expected in ``raw_dir``: ``<name>_A.txt`` and ``<name>_graph_indicator.txt``."""
        names = ["A", "graph_indicator"]
        return [f"{self.name}_{name}.txt" for name in names]

    @property
    def processed_file_names(self) -> str:
        """Name of the single processed file."""
        return "data.pt"

    def download(self) -> None:
        """Fetch and extract ``<name>.zip`` into ``raw_dir``, then flatten it."""
        url = self.cleaned_url if self.cleaned else self.url
        fs.cp(f"{url}/{self.name}.zip", self.raw_dir, extract=True)
        # Move the files out of the `<raw_dir>/<name>/` sub-folder into
        # `raw_dir` itself, then remove that sub-folder.
        for filename in fs.ls(osp.join(self.raw_dir, self.name)):
            fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))
        fs.rm(osp.join(self.raw_dir, self.name))

    def process(self) -> None:
        """Read the raw TU files, apply the optional pre-filter/transform and save."""
        self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)

        if self.pre_filter is not None or self.pre_transform is not None:
            # Materialize the individual graphs so they can be filtered and
            # transformed one by one, then collate them back.
            data_list = [self.get(idx) for idx in range(len(self))]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.data, self.slices = self.collate(data_list)
            self._data_list = None  # Reset cache.

        assert isinstance(self._data, Data)
        # Saved as (data dict, slices, sizes, data class) -- the 4-tuple layout
        # that `__init__` unpacks.
        fs.torch_save(
            (self._data.to_dict(), self.slices, sizes, self._data.__class__),
            self.processed_paths[0],
        )

    def __repr__(self) -> str:
        return f"{self.name}({len(self)})"
250 |
--------------------------------------------------------------------------------
/src/datasets/tu_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import os.path as osp
4 | from typing import Callable, List, Optional
5 |
6 | import numpy as np
7 | import torch.nn.functional as F
8 | import torch
9 | from torch.utils.data import random_split
10 | import torch_geometric.utils
11 | from torch_geometric.io import fs, read_tu_data
12 | from torch_geometric.utils import remove_self_loops
13 | from torch_geometric.data import Data, InMemoryDataset, download_url
14 | from hydra.utils import get_original_cwd
15 |
16 | from utils import PlaceHolder
17 | from datasets.abstract_dataset import (
18 | AbstractDataModule,
19 | AbstractDatasetInfos,
20 | )
21 | from datasets.dataset_utils import (
22 | load_pickle,
23 | save_pickle,
24 | Statistics,
25 | to_list,
26 | RemoveYTransform,
27 | )
28 |
29 |
class ProteinDataset(InMemoryDataset):
    """
    One split ("train", "val" or "test") of the protein graphs built from the
    DD raw files.

    Implementation based on https://github.com/KarolisMart/SPECTRE/blob/main/data.py
    """

    def __init__(
        self,
        split,
        root,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        self.dataset_name = "protein"

        # Any value other than "train" / "val" is treated as the test split.
        self.split = split
        if self.split == "train":
            self.file_idx = 0
        elif self.split == "val":
            self.file_idx = 1
        else:
            self.file_idx = 2

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Split-index files written by `download()`."""
        return ["train_indices.pt", "val_indices.pt", "test_indices.pt"]

    @property
    def split_file_name(self):
        return ["train.pt", "val.pt", "test.pt"]

    @property
    def split_paths(self):
        r"""The absolute filepaths that must be present in order to skip
        splitting."""
        files = to_list(self.split_file_name)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_file_names(self):
        """Processed file of this split.

        Only the file that `process()` actually writes is listed. The previous
        list also named `*_n.pickle`, `*_node_types.npy` and
        `*_bond_types.npy`, which this class never creates, so the processed
        files were never all present and the split was re-processed on every
        instantiation.
        """
        prefix = self.split if self.split in ("train", "val") else "test"
        return [f"{prefix}.pt"]

    def _load_raw_txt(self, filename):
        """Load a comma-delimited raw text file from `raw_dir` as a long tensor."""
        path = os.path.join(self.raw_dir, filename)
        return torch.Tensor(np.loadtxt(path, delimiter=",")).long()

    def download(self):
        """
        Download raw files and create the train/val/test split indices.

        Graphs with 100 to 500 nodes (inclusive) are kept. 20% of them form
        the test split; the remainder is divided 80/20 into train and val.
        The split uses a generator seeded with 1234.
        """
        raw_url = "https://raw.githubusercontent.com/KarolisMart/SPECTRE/main/data/DD"
        for name in [
            "DD_A.txt",
            "DD_graph_indicator.txt",
            "DD_graph_labels.txt",
            "DD_node_labels.txt",
        ]:
            download_url(f"{raw_url}/{name}", self.raw_dir)

        # read (from the same directory the files were just downloaded to)
        data_graph_indicator = np.loadtxt(
            os.path.join(self.raw_dir, "DD_graph_indicator.txt"), delimiter=","
        ).astype(int)

        # keep only graphs whose size lies in [min_num_nodes, max_num_nodes]
        min_num_nodes = 100
        max_num_nodes = 500
        available_graphs = []
        for idx in np.arange(1, data_graph_indicator.max() + 1):
            num_nodes = (data_graph_indicator == idx).sum()
            if min_num_nodes <= num_nodes <= max_num_nodes:
                available_graphs.append(idx)
        available_graphs = torch.Tensor(available_graphs)

        # split data
        self.num_graphs = len(available_graphs)
        test_len = int(round(self.num_graphs * 0.2))
        train_len = int(round((self.num_graphs - test_len) * 0.8))
        val_len = self.num_graphs - train_len - test_len

        train_indices, val_indices, test_indices = random_split(
            available_graphs,
            [train_len, val_len, test_len],
            generator=torch.Generator().manual_seed(1234),
        )
        print(f"Dataset sizes: train {train_len}, val {val_len}, test {test_len}")

        print(f"Train indices: {train_indices}")
        print(f"Val indices: {val_indices}")
        print(f"Test indices: {test_indices}")

        torch.save(train_indices, self.raw_paths[0])
        torch.save(val_indices, self.raw_paths[1])
        torch.save(test_indices, self.raw_paths[2])

    def process(self):
        """Build the graphs of this split and save the collated result.

        Every graph gets a single constant node feature (all ones) and a
        constant edge attribute (all ones); the node labels are only used to
        report the number of node types. Self-loops are removed.
        """
        indices = torch.load(
            os.path.join(self.raw_dir, "{}_indices.pt".format(self.split))
        )
        # Shifted by -1 (the graph indicator is kept as read).
        data_adj = self._load_raw_txt("DD_A.txt") - 1
        data_node_label = self._load_raw_txt("DD_node_labels.txt") - 1
        data_graph_indicator = self._load_raw_txt("DD_graph_indicator.txt")
        data_graph_types = self._load_raw_txt("DD_graph_labels.txt") - 1
        data_list = []

        # get information
        self.num_node_type = data_node_label.max() + 1
        self.num_edge_type = 2
        self.num_graph_type = data_graph_types.max() + 1
        print(f"Number of node types: {self.num_node_type}")
        print(f"Number of edge types: {self.num_edge_type}")
        print(f"Number of graph types: {self.num_graph_type}")

        for idx in indices:
            node_idx = data_graph_indicator == idx
            # Global index of the graph's first node, used to make the edge
            # endpoints local to the graph.
            offset = torch.where(node_idx)[0].min()
            nodes = torch.ones(len(data_node_label[node_idx]), 1).float()
            # An edge belongs to the graph when its source node does.
            edge_idx = node_idx[data_adj[:, 0]]
            edge_index = data_adj[edge_idx] - offset
            edge_attr = torch.ones_like(edge_index[:, 0]).float()
            edge_index, edge_attr = remove_self_loops(edge_index.T, edge_attr)
            data = torch_geometric.data.Data(
                x=nodes,
                edge_index=edge_index,
                edge_attr=edge_attr.unsqueeze(-1),
                n_nodes=nodes.shape[0],
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])
220 |
221 |
class ProteinDataModule(AbstractDataModule):
    """Data module that instantiates the three `ProteinDataset` splits."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.dataset_name = self.cfg.dataset.name
        self.datadir = cfg.dataset.datadir

        # The data directory is resolved relative to the parent of the
        # original (pre-Hydra) working directory.
        base_path = pathlib.Path(get_original_cwd()).parent
        root_path = os.path.join(base_path, cfg.dataset.datadir)

        # A single transform instance is shared by all three splits.
        transform = RemoveYTransform()
        datasets = {
            split: ProteinDataset(root=root_path, transform=transform, split=split)
            for split in ("train", "val", "test")
        }

        super().__init__(cfg, datasets)
        self.inner = self.train_dataset
251 |
252 |
class ProteinInfos(AbstractDatasetInfos):
    """Dataset statistics for the protein graphs, estimated from the datamodule."""

    def __init__(self, datamodule):
        self.is_molecular = False
        self.spectre = True
        self.use_charge = False
        self.dataset_name = datamodule.inner.dataset_name
        # Empirical marginals, computed by iterating over the dataloaders.
        self.n_nodes = datamodule.node_counts()
        self.node_types = datamodule.node_types()
        self.edge_types = datamodule.edge_counts()
        # `to_one_hot` reads these two attributes; they were previously never
        # assigned, so calling it raised AttributeError.
        self.num_node_types = len(self.node_types)
        self.num_edge_types = len(self.edge_types)
        super().complete_infos(self.n_nodes, self.node_types)

    def to_one_hot(self, data):
        """
        One-hot encode the node and edge labels of `data` in place.

        `data.x` and `data.edge_attr` are replaced by float one-hot encodings
        with `num_node_types` and `num_edge_types` classes respectively, and
        `data.charge` is set to an empty (zero-width) tensor.

        NOTE(review): `F.one_hot` expects integer class indices -- confirm
        that callers pass index tensors rather than feature tensors.
        """
        data.charge = data.x.new_zeros((*data.x.shape[:-1], 0))
        data.x = F.one_hot(data.x, num_classes=self.num_node_types).float()
        data.edge_attr = F.one_hot(
            data.edge_attr, num_classes=self.num_edge_types
        ).float()

        return data
277 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import warnings
4 |
5 | import graph_tool
6 | import torch
7 |
8 | torch.cuda.empty_cache()
9 | import hydra
10 | from omegaconf import DictConfig
11 | import pytorch_lightning as pl
12 | from pytorch_lightning import Trainer
13 | from pytorch_lightning.callbacks import ModelCheckpoint
14 | from pytorch_lightning.utilities.warnings import PossibleUserWarning
15 |
16 | from src import utils
17 | from metrics.abstract_metrics import TrainAbstractMetricsDiscrete
18 | from graph_discrete_flow_model import GraphDiscreteFlowModel
19 | from models.extra_features import DummyExtraFeatures, ExtraFeatures
20 |
21 |
22 | warnings.filterwarnings("ignore", category=PossibleUserWarning)
23 |
24 |
@hydra.main(version_base="1.3", config_path="../configs", config_name="config")
def main(cfg: DictConfig):
    """Build the dataset-specific components, then train or test the flow model.

    ``cfg.dataset.name`` selects the datamodule, dataset infos, metrics,
    visualization tools and feature extractors. They are handed to
    ``GraphDiscreteFlowModel`` and a Lightning ``Trainer``.

    The model is fitted when ``cfg.general.test_only`` is falsy and
    ``cfg.general.generated_path`` is ``None``. Otherwise it is tested with the
    checkpoint ``cfg.general.test_only`` and, if
    ``cfg.general.evaluate_all_checkpoints`` is set, with every other ``.ckpt``
    file located in the same directory.

    Raises:
        NotImplementedError: if ``cfg.dataset.name`` is not a supported dataset.
    """
    pl.seed_everything(cfg.train.seed)
    dataset_config = cfg["dataset"]
    dataset_name = dataset_config["name"]

    if dataset_name in ["sbm", "comm20", "planar", "tree"]:
        from analysis.visualization import NonMolecularVisualization
        from datasets.spectre_dataset import (
            SpectreGraphDataModule,
            SpectreDatasetInfos,
        )
        from analysis.spectre_utils import (
            PlanarSamplingMetrics,
            SBMSamplingMetrics,
            Comm20SamplingMetrics,
            TreeSamplingMetrics,
        )

        # One sampling-metrics class per synthetic dataset.
        sampling_metrics_by_dataset = {
            "sbm": SBMSamplingMetrics,
            "comm20": Comm20SamplingMetrics,
            "planar": PlanarSamplingMetrics,
            "tree": TreeSamplingMetrics,
        }

        datamodule = SpectreGraphDataModule(cfg)
        sampling_metrics = sampling_metrics_by_dataset[dataset_name](datamodule)

        dataset_infos = SpectreDatasetInfos(datamodule, dataset_config)

        train_metrics = TrainAbstractMetricsDiscrete()
        visualization_tools = NonMolecularVisualization(dataset_name=cfg.dataset.name)

        extra_features = ExtraFeatures(
            cfg.model.extra_features,
            cfg.model.rrwp_steps,
            dataset_info=dataset_infos,
        )
        domain_features = DummyExtraFeatures()

        dataset_infos.compute_input_output_dims(
            datamodule=datamodule,
            extra_features=extra_features,
            domain_features=domain_features,
        )

    elif dataset_name in ["qm9", "guacamol", "moses", "zinc"]:
        from metrics.molecular_metrics import SamplingMolecularMetrics
        from metrics.molecular_metrics_discrete import TrainMolecularMetricsDiscrete
        from models.extra_features_molecular import ExtraMolecularFeatures
        from analysis.visualization import MolecularVisualization

        if "qm9" in dataset_name:
            from datasets import qm9_dataset

            datamodule = qm9_dataset.QM9DataModule(cfg)
            dataset_infos = qm9_dataset.QM9infos(datamodule=datamodule, cfg=cfg)
            dataset_smiles = qm9_dataset.get_smiles(
                cfg=cfg,
                datamodule=datamodule,
                dataset_infos=dataset_infos,
                evaluate_datasets=False,
            )
        elif dataset_name == "guacamol":
            from datasets import guacamol_dataset

            datamodule = guacamol_dataset.GuacamolDataModule(cfg)
            dataset_infos = guacamol_dataset.Guacamolinfos(datamodule, cfg)
            dataset_smiles = guacamol_dataset.get_smiles(
                raw_dir=datamodule.train_dataset.raw_dir,
                filter_dataset=cfg.dataset.filter,
            )
        elif dataset_name == "moses":
            from datasets import moses_dataset

            datamodule = moses_dataset.MosesDataModule(cfg)
            dataset_infos = moses_dataset.MOSESinfos(datamodule, cfg)
            dataset_smiles = moses_dataset.get_smiles(
                raw_dir=datamodule.train_dataset.raw_dir,
                filter_dataset=cfg.dataset.filter,
            )
        elif "zinc" in dataset_name:
            from datasets import zinc_dataset

            datamodule = zinc_dataset.ZINCDataModule(cfg)
            dataset_infos = zinc_dataset.ZINCinfos(datamodule=datamodule, cfg=cfg)
            dataset_smiles = zinc_dataset.get_smiles(
                cfg=cfg,
                datamodule=datamodule,
                dataset_infos=dataset_infos,
                evaluate_datasets=False,
            )
        else:
            raise ValueError("Dataset not implemented")

        extra_features = ExtraFeatures(
            cfg.model.extra_features,
            cfg.model.rrwp_steps,
            dataset_info=dataset_infos,
        )
        domain_features = ExtraMolecularFeatures(dataset_infos=dataset_infos)

        dataset_infos.compute_input_output_dims(
            datamodule=datamodule,
            extra_features=extra_features,
            domain_features=domain_features,
        )

        train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)

        # We do not evaluate novelty during training
        add_virtual_states = "absorbing" == cfg.model.transition
        sampling_metrics = SamplingMolecularMetrics(
            dataset_infos, dataset_smiles, cfg, add_virtual_states=add_virtual_states
        )
        visualization_tools = MolecularVisualization(
            cfg.dataset.remove_h, dataset_infos=dataset_infos
        )

    elif dataset_name == "tls":
        from datasets import tls_dataset
        from metrics.tls_metrics import TLSSamplingMetrics
        from analysis.visualization import NonMolecularVisualization

        datamodule = tls_dataset.TLSDataModule(cfg)
        dataset_infos = tls_dataset.TLSInfos(datamodule=datamodule)

        train_metrics = TrainAbstractMetricsDiscrete()
        # Extra features are optional for this dataset only.
        extra_features = (
            ExtraFeatures(
                cfg.model.extra_features,
                cfg.model.rrwp_steps,
                dataset_info=dataset_infos,
            )
            if cfg.model.extra_features is not None
            else DummyExtraFeatures()
        )
        domain_features = DummyExtraFeatures()

        sampling_metrics = TLSSamplingMetrics(datamodule)

        visualization_tools = NonMolecularVisualization(dataset_name=cfg.dataset.name)

        dataset_infos.compute_input_output_dims(
            datamodule=datamodule,
            extra_features=extra_features,
            domain_features=domain_features,
        )

    else:
        raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"]))

    dataset_infos.compute_reference_metrics(
        datamodule=datamodule,
        sampling_metrics=sampling_metrics,
    )

    model_kwargs = {
        "dataset_infos": dataset_infos,
        "train_metrics": train_metrics,
        "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
        "extra_features": extra_features,
        "domain_features": domain_features,
        # Test labels are only forwarded for conditional QM9 runs.
        "test_labels": (
            datamodule.test_labels
            if ("qm9" in cfg.dataset.name and cfg.general.conditional)
            else None
        ),
    }

    utils.create_folders(cfg)
    model = GraphDiscreteFlowModel(cfg=cfg, **model_kwargs)

    callbacks = []
    if cfg.train.save_model:
        checkpoint_callback = ModelCheckpoint(
            dirpath=f"checkpoints/{cfg.general.name}",
            filename="{epoch}",
            save_top_k=-1,
            every_n_epochs=cfg.general.sample_every_val
            * cfg.general.check_val_every_n_epochs,
        )
        callbacks.append(checkpoint_callback)

    if cfg.train.ema_decay > 0:
        ema_callback = utils.EMA(decay=cfg.train.ema_decay)
        callbacks.append(ema_callback)

    name = cfg.general.name
    if name == "debug":
        print("[WARNING]: Run is called 'debug' -- it will run with fast_dev_run. ")

    use_gpu = cfg.general.gpus > 0 and torch.cuda.is_available()
    trainer = Trainer(
        gradient_clip_val=cfg.train.clip_grad,
        strategy="ddp_find_unused_parameters_true",  # Needed to load old checkpoints
        accelerator="gpu" if use_gpu else "cpu",
        devices=cfg.general.gpus if use_gpu else 1,
        max_epochs=cfg.train.n_epochs,
        check_val_every_n_epoch=cfg.general.check_val_every_n_epochs,
        fast_dev_run=name == "debug",
        enable_progress_bar=False,
        callbacks=callbacks,
        log_every_n_steps=50 if name != "debug" else 1,
        logger=[],
    )

    if not cfg.general.test_only and cfg.general.generated_path is None:
        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
    else:
        # Start by evaluating test_only_path
        trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
        if cfg.general.evaluate_all_checkpoints:
            if not cfg.general.test_only:
                # Without a reference checkpoint there is no directory to scan.
                print(
                    "[WARNING]: evaluate_all_checkpoints requires general.test_only "
                    "to locate the checkpoint directory -- skipping."
                )
            else:
                tested_ckpt = pathlib.Path(cfg.general.test_only)
                directory = tested_ckpt.parent
                print("Directory:", directory)
                # Sorted so the evaluation order is reproducible across runs.
                for file in sorted(os.listdir(directory)):
                    if not file.endswith(".ckpt"):
                        continue
                    ckpt_path = os.path.join(directory, file)
                    # Compare as paths so "./a.ckpt" and "a.ckpt" count as the
                    # same file and the checkpoint is not tested twice.
                    if pathlib.Path(ckpt_path) == tested_ckpt:
                        continue
                    print("Loading checkpoint", ckpt_path)
                    trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)
263 |
264 |
265 | if __name__ == "__main__":
266 | main()
267 |
--------------------------------------------------------------------------------
/src/datasets/spectre_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import pickle as pkl
4 | import zipfile
5 |
6 | from networkx import to_numpy_array
7 |
8 | import torch
9 | from torch.utils.data import random_split
10 | import torch_geometric.utils
11 | from torch_geometric.data import InMemoryDataset, download_url
12 |
13 | from datasets.abstract_dataset import AbstractDataModule, AbstractDatasetInfos
14 |
15 |
class SpectreGraphDataset(InMemoryDataset):
    """One split (train, val or test) of a graph dataset as a PyG in-memory dataset.

    ``download`` fetches the raw file for ``dataset_name`` and stores three
    lists of dense adjacency matrices (``train.pt``, ``val.pt``, ``test.pt``).
    ``process`` turns the list belonging to ``split`` into PyG ``Data`` objects.
    """

    # Location of the raw file for every dataset name accepted by ``download``.
    _RAW_URLS = {
        "sbm": "https://raw.githubusercontent.com/AndreasBergmeister/graph-generation/main/data/sbm.pkl",
        "planar": "https://raw.githubusercontent.com/AndreasBergmeister/graph-generation/main/data/planar.pkl",
        "tree": "https://raw.githubusercontent.com/AndreasBergmeister/graph-generation/main/data/tree.pkl",
        "comm20": "https://raw.githubusercontent.com/KarolisMart/SPECTRE/main/data/community_12_21_100.pt",
        "ego": "https://raw.githubusercontent.com/tufts-ml/graph-generation-EDGE/main/graphs/Ego.pkl",
        "imdb": "https://www.chrsmrrs.com/graphkerneldatasets/IMDB-BINARY.zip",
    }

    def __init__(
        self,
        dataset_name,
        split,
        root,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        """Load (downloading/processing if needed) the ``split`` of ``dataset_name``.

        Args:
            dataset_name: key selecting the raw source (see ``_RAW_URLS``).
            split: one of ``"train"``, ``"val"``, ``"test"``.
            root: dataset root directory passed to ``InMemoryDataset``.
            transform, pre_transform, pre_filter: forwarded to ``InMemoryDataset``.
        """
        self.sbm_file = "sbm_200.pt"
        self.planar_file = "planar_64_200.pt"
        self.comm20_file = "community_12_21_100.pt"
        # Both attributes are read by the properties / hooks that the parent
        # constructor triggers, so they must be set before calling it.
        self.dataset_name = dataset_name
        self.split = split
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
        self.num_graphs = len(self.data.n_nodes)

    @property
    def raw_file_names(self):
        return ["train.pt", "val.pt", "test.pt"]

    @property
    def processed_file_names(self):
        return [self.split + ".pt"]

    @staticmethod
    def _to_adjacency(graph):
        """Return the dense adjacency tensor of a networkx graph with a zero diagonal."""
        return torch.Tensor(to_numpy_array(graph)).fill_diagonal_(0)

    @staticmethod
    def _load_imdb_adjacencies(file_path):
        """Extract the IMDB-BINARY archive and return one adjacency matrix per graph."""
        extract_dir = os.path.dirname(file_path)
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        # Step 1: read the global edge list (1-based in the file, hence the -1)
        index_path = os.path.join(extract_dir, "IMDB-BINARY", "IMDB-BINARY_A.txt")
        edge_index = []
        with open(index_path, "r") as file:
            for line in file:
                int1, int2 = map(int, line.strip().split(","))
                edge_index.append([int1, int2])
        edge_index = torch.tensor(edge_index).t().contiguous() - 1

        # Step 2: read the graph id of every node (also 1-based in the file)
        index_path = os.path.join(
            extract_dir, "IMDB-BINARY", "IMDB-BINARY_graph_indicator.txt"
        )
        graph_indicator = []
        with open(index_path, "r") as file:
            for line in file:
                graph_indicator.append(int(line.strip()))
        graph_indicator = torch.tensor(graph_indicator) - 1

        # Step 3: cut the global edge list into one dense adjacency per graph
        num_graphs = graph_indicator.max().item() + 1
        adjs = []
        for i in range(num_graphs):
            node_mask = graph_indicator == i
            n_node = node_mask.sum().item()
            edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
            edges = edge_index[:, edge_mask]  # boolean indexing returns a copy
            # Shift global node ids so this graph's first node becomes 0.
            ptr = torch.where(node_mask)[0][0]
            edges -= ptr
            adj = torch.zeros(n_node, n_node)
            adj[edges[0], edges[1]] = 1
            adj[edges[1], edges[0]] = 1
            adjs.append(adj)
        return adjs

    def _split_adjacencies(self, adjs):
        """Split ``adjs`` into train/val/test lists with a fixed-seed permutation."""
        g_cpu = torch.Generator().manual_seed(1234)
        # NOTE(review): 200 is hard-coded for every dataset not listed below
        # (this includes "comm20", whose file name suggests 100 graphs); indices
        # beyond len(adjs) simply never match. Confirm before changing, since it
        # alters the existing splits.
        self.num_graphs = 200
        if self.dataset_name in ["ego", "protein", "imdb"]:
            self.num_graphs = len(adjs)

        is_ego = self.dataset_name == "ego"
        test_len = int(round(self.num_graphs * 0.2))
        if is_ego:
            train_len = int(round(self.num_graphs * 0.8))
            val_len = int(round(self.num_graphs * 0.2))
        else:
            train_len = int(round((self.num_graphs - test_len) * 0.8))
            val_len = self.num_graphs - train_len - test_len
        indices = torch.randperm(self.num_graphs, generator=g_cpu)
        print(f"Dataset sizes: train {train_len}, val {val_len}, test {test_len}")

        train_indices = indices[:train_len]
        if is_ego:
            # For "ego" the validation indices are the first ``val_len``
            # training indices, i.e. validation graphs are also training graphs.
            val_indices = indices[:val_len]
            test_indices = indices[train_len:]
        else:
            val_indices = indices[train_len : train_len + val_len]
            test_indices = indices[train_len + val_len :]

        # Plain sets give O(1) membership instead of scanning a tensor per graph.
        train_set = set(train_indices.tolist())
        val_set = set(val_indices.tolist())
        test_set = set(test_indices.tolist())

        train_data = [adj for i, adj in enumerate(adjs) if i in train_set]
        val_data = [adj for i, adj in enumerate(adjs) if i in val_set]
        test_data = [adj for i, adj in enumerate(adjs) if i in test_set]
        return train_data, val_data, test_data

    def download(self):
        """Download the raw file of ``self.dataset_name`` and save the three splits.

        The lists of adjacency matrices are written to ``self.raw_paths`` in the
        order train, val, test.

        Raises:
            ValueError: if ``self.dataset_name`` has no known download URL.
        """
        try:
            raw_url = self._RAW_URLS[self.dataset_name]
        except KeyError:
            raise ValueError(f"Unknown dataset {self.dataset_name}") from None
        file_path = download_url(raw_url, self.raw_dir)

        # NOTE: pickle / torch.load execute code embedded in the file; they are
        # only applied here to files fetched from the fixed URLs above.
        if self.dataset_name in ["tree", "sbm", "planar"]:
            # These pickles already contain "train" / "val" / "test" graph lists.
            with open(file_path, "rb") as f:
                dataset = pkl.load(f)
            train_data = [self._to_adjacency(graph) for graph in dataset["train"]]
            val_data = [self._to_adjacency(graph) for graph in dataset["val"]]
            test_data = [self._to_adjacency(graph) for graph in dataset["test"]]
        else:
            if self.dataset_name == "ego":
                with open(file_path, "rb") as f:
                    networks = pkl.load(f)
                adjs = [self._to_adjacency(network) for network in networks]
            elif self.dataset_name == "imdb":
                adjs = self._load_imdb_adjacencies(file_path)
            else:
                # Remaining files hold an 8-tuple; only the adjacencies are used.
                (
                    adjs,
                    eigvals,
                    eigvecs,
                    n_nodes,
                    max_eigval,
                    min_eigval,
                    same_sample,
                    n_max,
                ) = torch.load(file_path)

            train_data, val_data, test_data = self._split_adjacencies(adjs)

        torch.save(train_data, self.raw_paths[0])
        torch.save(val_data, self.raw_paths[1])
        torch.save(test_data, self.raw_paths[2])

    def process(self):
        """Convert the raw adjacency list of ``self.split`` into collated PyG data."""
        file_idx = {"train": 0, "val": 1, "test": 2}
        raw_dataset = torch.load(self.raw_paths[file_idx[self.split]])

        data_list = []
        for adj in raw_dataset:
            n = adj.shape[-1]
            # Every node gets the same single feature; y is an empty label.
            X = torch.ones(n, 1, dtype=torch.float)
            y = torch.zeros([1, 0]).float()
            edge_index, _ = torch_geometric.utils.dense_to_sparse(adj)
            # Two edge channels; every listed edge is marked in channel 1.
            edge_attr = torch.zeros(edge_index.shape[-1], 2, dtype=torch.float)
            edge_attr[:, 1] = 1
            num_nodes = n * torch.ones(1, dtype=torch.long)
            data = torch_geometric.data.Data(
                x=X, edge_index=edge_index, edge_attr=edge_attr, y=y, n_nodes=num_nodes
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])
214 |
215 |
class SpectreGraphDataModule(AbstractDataModule):
    """Data module bundling the train/val/test ``SpectreGraphDataset`` splits."""

    def __init__(self, cfg, n_graphs=200):
        """Create the three splits under ``cfg.dataset.datadir`` and register them.

        The data directory is resolved relative to the directory two levels
        above the one containing this file. ``n_graphs`` is accepted but not
        used here.
        """
        self.cfg = cfg
        self.datadir = cfg.dataset.datadir
        repo_root = pathlib.Path(os.path.realpath(__file__)).parents[2]
        data_root = os.path.join(repo_root, self.datadir)

        split_names = ("train", "val", "test")
        datasets = {
            split: SpectreGraphDataset(
                dataset_name=self.cfg.dataset.name, split=split, root=data_root
            )
            for split in split_names
        }

        train_len, val_len, test_len = (
            len(datasets[split].data.n_nodes) for split in split_names
        )
        print(f"Dataset sizes: train {train_len}, val {val_len}, test {test_len}")

        super().__init__(cfg, datasets)
        # Indexing the data module indexes the training split.
        self.inner = self.train_dataset

    def __getitem__(self, item):
        return self.inner[item]
245 |
246 |
class SpectreDatasetInfos(AbstractDatasetInfos):
    """Dataset statistics collected from a ``SpectreGraphDataModule``."""

    def __init__(self, datamodule, dataset_config):
        """Read name and node/edge statistics from ``datamodule``.

        ``dataset_config`` is accepted but not used here.
        """
        training_split = datamodule.inner
        self.dataset_name = training_split.dataset_name
        self.n_nodes = datamodule.node_counts()
        self.node_types = datamodule.node_types()
        self.edge_types = datamodule.edge_counts()
        # Let the base class fill in the remaining fields from these statistics.
        super().complete_infos(self.n_nodes, self.node_types)
255 |
--------------------------------------------------------------------------------
/src/analysis/visualization.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from rdkit import Chem
4 | from rdkit.Chem import Draw, AllChem
5 | from rdkit.Geometry import Point3D
6 | from rdkit import RDLogger
7 | import imageio
8 | import networkx as nx
9 | import numpy as np
10 | import rdkit.Chem
11 | import wandb
12 | import matplotlib.pyplot as plt
13 |
14 |
15 | from datasets.tls_dataset import CellGraph
16 |
17 |
class MolecularVisualization:
    """Draw generated molecular graphs as images, GIF chains and grid images."""

    def __init__(self, remove_h, dataset_infos):
        self.remove_h = remove_h
        self.dataset_infos = dataset_infos

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Build an rdkit molecule from one graph.

        node_list: per-node atom-type indices; entries equal to -1 are skipped
        adjacency_matrix: square matrix of bond codes
            (1 single, 2 double, 3 triple, 4 aromatic, anything else: no bond)
        """
        # maps an integer atom type to its element symbol
        atom_decoder = self.dataset_infos.atom_decoder

        # editable molecule that atoms and bonds are added to
        mol = Chem.RWMol()

        # graph node index -> atom index inside the molecule
        node_to_idx = {}
        for node, atom_type in enumerate(node_list):
            if atom_type == -1:
                continue
            atom = Chem.Atom(atom_decoder[int(atom_type)])
            node_to_idx[node] = mol.AddAtom(atom)

        for ix, row in enumerate(adjacency_matrix):
            # the matrix is symmetric: only visit entries above the diagonal
            for iy in range(ix + 1, len(row)):
                bond = row[iy]
                if bond == 1:
                    bond_type = Chem.rdchem.BondType.SINGLE
                elif bond == 2:
                    bond_type = Chem.rdchem.BondType.DOUBLE
                elif bond == 3:
                    bond_type = Chem.rdchem.BondType.TRIPLE
                elif bond == 4:
                    bond_type = Chem.rdchem.BondType.AROMATIC
                else:
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        try:
            mol = mol.GetMol()
        except rdkit.Chem.KekulizeException:
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize(
        self, path: str, molecules: list, num_molecules_to_visualize: int, log="graph"
    ):
        """Save the first molecules as PNG files in ``path`` and log them to wandb.

        Each molecule is a pair whose two items expose ``.numpy()`` and are
        passed to ``mol_from_graphs``. Images are logged under the key ``log``
        only when a wandb run is active and ``log`` is not None.
        """
        # make sure the output folder exists
        if not os.path.exists(path):
            os.makedirs(path)

        # never try to draw more molecules than were provided
        print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
        if num_molecules_to_visualize > len(molecules):
            print(f"Shortening to {len(molecules)}")
            num_molecules_to_visualize = len(molecules)

        for i in range(num_molecules_to_visualize):
            file_path = os.path.join(path, "molecule_{}.png".format(i))
            nodes, adjacency = molecules[i][0].numpy(), molecules[i][1].numpy()
            mol = self.mol_from_graphs(nodes, adjacency)
            try:
                Draw.MolToFile(mol, file_path)
                if wandb.run and log is not None:
                    print(f"Saving {file_path} to wandb")
                    wandb.log({log: wandb.Image(file_path)}, commit=True)
            except rdkit.Chem.KekulizeException:
                print("Can't kekulize molecule")

    def visualize_chain(self, path, nodes_list, adjacency_matrix, times, trainer=None):
        """Render a sequence of graphs as PNG frames, a GIF and a grid image.

        All frames reuse the 2D atom coordinates computed for the last graph.
        Returns the list of rdkit molecules, one per frame. ``trainer`` is
        accepted but not used here.
        """
        RDLogger.DisableLog("rdApp.*")
        num_frams = nodes_list.shape[0]

        # one rdkit molecule per frame
        mols = [
            self.mol_from_graphs(nodes_list[k], adjacency_matrix[k])
            for k in range(num_frams)
        ]

        # 2D layout of the final molecule, used as reference for every frame
        final_molecule = mols[-1]
        AllChem.Compute2DCoords(final_molecule)

        coords = []
        for i, _atom in enumerate(final_molecule.GetAtoms()):
            positions = final_molecule.GetConformer().GetAtomPosition(i)
            coords.append((positions.x, positions.y, positions.z))

        # move the atoms of every frame onto the reference coordinates
        for mol in mols:
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            for j, _atom in enumerate(mol.GetAtoms()):
                x, y, z = coords[j]
                conf.SetAtomPosition(j, Point3D(x, y, z))

        # draw the individual frames
        save_paths = []
        for frame in range(num_frams):
            file_name = os.path.join(path, "fram_{}.png".format(frame))
            Draw.MolToFile(
                mols[frame],
                file_name,
                size=(300, 300),
                legend=f"t = {times[frame]:.2f}",
            )
            save_paths.append(file_name)

        # assemble the gif; the last frame is repeated ten more times
        imgs = [imageio.imread(fn) for fn in save_paths]
        gif_name = "{}.gif".format(path.split("/")[-1])
        gif_path = os.path.join(os.path.dirname(path), gif_name)
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=200)

        if wandb.run:
            print(f"Saving {gif_path} to wandb")
            wandb.log(
                {"chain": wandb.Video(gif_path, fps=5, format="gif")}, commit=True
            )

        # draw grid image
        try:
            img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
            grid_name = "{}_grid_image.png".format(path.split("/")[-1])
            img.save(os.path.join(path, grid_name))
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
        return mols
153 |
154 |
class NonMolecularVisualization:
    """Draw generated non-molecular graphs as images and GIF chains.

    Datasets whose name contains ``"tls"`` are drawn through ``CellGraph``;
    all others are drawn with networkx/matplotlib.
    """

    def __init__(self, dataset_name):
        self.is_tls = "tls" in dataset_name

    def to_networkx(self, node_list, adjacency_matrix):
        """Convert one graph to a networkx graph.

        node_list: per-node labels; entries equal to -1 are skipped
        adjacency_matrix: numpy matrix; every entry >= 1 becomes an edge whose
            ``color`` is the entry and whose ``weight`` is three times the entry
        """
        graph = nx.Graph()

        for i in range(len(node_list)):
            if node_list[i] == -1:
                continue
            graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])

        rows, cols = np.where(adjacency_matrix >= 1)
        edges = zip(rows.tolist(), cols.tolist())
        for edge in edges:
            edge_type = adjacency_matrix[edge[0]][edge[1]]
            graph.add_edge(
                edge[0], edge[1], color=float(edge_type), weight=3 * edge_type
            )

        return graph

    def visualize_non_molecule(
        self,
        graph,
        pos,
        path,
        iterations=100,
        node_size=100,
        largest_component=False,
        time=None,
    ):
        """Plot ``graph`` to ``path``, colouring nodes by a Laplacian eigenvector.

        pos: node positions; a spring layout is computed when None
        iterations: spring-layout iterations (only used when ``pos`` is None)
        largest_component: if True, only the largest connected component is drawn
        time: if given, written below the graph as ``t = <time>``
        """
        if largest_component:
            CGs = [graph.subgraph(c) for c in nx.connected_components(graph)]
            CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
            graph = CGs[0]

        # Plot the graph structure with colors
        if pos is None:
            pos = nx.spring_layout(graph, iterations=iterations)

        # Node colours come from the second eigenvector column of the
        # normalized Laplacian, so the graph needs at least two nodes.
        w, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
        vmin, vmax = np.min(U[:, 1]), np.max(U[:, 1])
        # Symmetric colour range around zero
        m = max(np.abs(vmin), vmax)
        vmin, vmax = -m, m

        plt.figure()
        nx.draw(
            graph,
            pos,
            font_size=5,
            node_size=node_size,
            with_labels=False,
            node_color=U[:, 1],
            cmap=plt.cm.coolwarm,
            vmin=vmin,
            vmax=vmax,
            edge_color="grey",
        )
        if time is not None:
            plt.text(
                0.5,
                0.05,  # place below the graph
                f"t = {time:.2f}",
                ha="center",
                va="center",
                transform=plt.gcf().transFigure,
                fontsize=16,
            )

        plt.tight_layout()
        plt.savefig(path)
        plt.close("all")

    def visualize(
        self, path: str, graphs: list, num_graphs_to_visualize: int, log="graph"
    ):
        """Save the first graphs as PNG files in ``path`` and log them to wandb.

        At most ``len(graphs)`` graphs are drawn. Images are logged under the
        key ``log`` only when a wandb run is active and ``log`` is not None.
        """
        # make sure the output folder exists
        os.makedirs(path, exist_ok=True)

        # never index past the graphs that were actually provided
        if num_graphs_to_visualize > len(graphs):
            print(f"Shortening to {len(graphs)}")
            num_graphs_to_visualize = len(graphs)

        for i in range(num_graphs_to_visualize):
            file_path = os.path.join(path, "graph_{}.png".format(i))

            if self.is_tls:
                cg = CellGraph.from_dense_graph(graphs[i])
                cg.plot_graph(save_path=file_path, has_legend=True)
            else:
                graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
                self.visualize_non_molecule(graph=graph, pos=None, path=file_path)

            # Only read the image back when it is going to be logged.
            if wandb.run and log is not None:
                im = plt.imread(file_path)
                wandb.log({log: [wandb.Image(im, caption=file_path)]})

    def visualize_chain(self, path, nodes_list, adjacency_matrix, times):
        """Render a sequence of graphs as PNG frames in ``path`` plus a GIF.

        Every frame is drawn with the spring layout of the last graph. The GIF
        is written next to ``path`` and logged to wandb when a run is active.
        """
        graphs = []
        for i in range(nodes_list.shape[0]):
            if self.is_tls:
                graphs.append(
                    CellGraph.from_dense_graph((nodes_list[i], adjacency_matrix[i]))
                )
            else:
                graphs.append(self.to_networkx(nodes_list[i], adjacency_matrix[i]))

        # node positions of the final graph, shared by all frames
        final_graph = graphs[-1]
        final_pos = nx.spring_layout(final_graph, seed=0)

        # draw gif
        save_paths = []
        num_frams = nodes_list.shape[0]

        for frame in range(num_frams):
            file_name = os.path.join(path, "fram_{}.png".format(frame))
            if self.is_tls:
                if not graphs[frame].get_pos():  # The last one already has a pos
                    graphs[frame].set_pos(pos=final_pos)
                graphs[frame].plot_graph(
                    save_path=file_name,
                    has_legend=False,
                    verbose=False,
                    time=times[frame],
                )
            else:
                self.visualize_non_molecule(
                    graph=graphs[frame],
                    pos=final_pos,
                    path=file_name,
                    time=times[frame],
                )
            save_paths.append(file_name)

        imgs = [imageio.imread(fn) for fn in save_paths]
        gif_path = os.path.join(
            os.path.dirname(path), "{}.gif".format(path.split("/")[-1])
        )
        # hold the last frame by repeating it ten more times
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=200)
        if wandb.run:
            wandb.log(
                {"chain": [wandb.Video(gif_path, caption=gif_path, format="gif")]}
            )
307 |
--------------------------------------------------------------------------------
/src/flow_matching/rate_matrix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from flow_matching import flow_matching_utils
5 | from flow_matching.utils import dt_p_xt_g_x1, p_xt_g_x1
6 |
7 |
8 | class RateMatrixDesigner:
9 |
10 | def __init__(self, rdb, rdb_crit, eta, omega, limit_dist):
11 |
12 | self.omega = omega # target guidance
13 | self.eta = eta # stochasticity
14 | # Different designs of R^db
15 | self.rdb = rdb
16 | self.rdb_crit = rdb_crit
17 | self.limit_dist = limit_dist
18 | self.num_classes_X = len(self.limit_dist.X)
19 | self.num_classes_E = len(self.limit_dist.E)
20 |
21 | print(
22 | f"RateMatrixDesigner: rdb={rdb}, rdb_crit={rdb_crit}, eta={eta}, omega={omega}"
23 | )
24 |
def compute_graph_rate_matrix(self, t, node_mask, G_t, G_1_pred):
    """Return the node and edge rate matrices ``(R_t_X, R_t_E)`` at time ``t``.

    G_t: pair ``(X_t, E_t)`` of current states; the class of each entry is
        taken as the argmax over the last dimension
    G_1_pred: pair ``(X_1_pred, E_1_pred)`` from which a clean graph is sampled

    The result is the sum of the R^*, R^db and R^tg terms, passed through
    ``stabilize_rate_matrix``.
    """
    X_t, E_t = G_t
    X_1_pred, E_1_pred = G_1_pred

    # Current states as class indices (trailing singleton dimension kept)
    X_t_label = X_t.argmax(-1, keepdim=True)
    E_t_label = E_t.argmax(-1, keepdim=True)

    # Draw one clean graph from the prediction
    sampled_G_1 = flow_matching_utils.sample_discrete_features(
        X_1_pred,
        E_1_pred,
        node_mask=node_mask,
    )
    X_1_sampled, E_1_sampled = sampled_G_1.X, sampled_G_1.E

    dfm_variables = self.compute_dfm_variables(
        t, X_t_label, E_t_label, X_1_sampled, E_1_sampled
    )

    # The three contributions, computed in this order: R^*, R^db, R^tg
    star_X, star_E = self.compute_Rstar(dfm_variables)
    db_X, db_E = self.compute_RDB(
        X_t_label,
        E_t_label,
        X_1_pred,
        E_1_pred,
        X_1_sampled,
        E_1_sampled,
        node_mask,
        t,
        dfm_variables,
    )
    tg_X, tg_E = self.compute_R_tg(
        X_1_sampled,
        E_1_sampled,
        X_t_label,
        E_t_label,
        dfm_variables,
    )

    # Sum the contributions
    R_t_X = star_X + db_X + tg_X
    R_t_E = star_E + db_E + tg_E

    # Stabilize rate matrices
    R_t_X, R_t_E = self.stabilize_rate_matrix(R_t_X, R_t_E, dfm_variables)

    return R_t_X, R_t_E
74 |
def compute_dfm_variables(self, t, X_t_label, E_t_label, X_1_sampled, E_1_sampled):
    """Collect the quantities shared by the rate-matrix terms in one dict.

    The dict holds, for nodes (``X``) and edges (``E``): the values returned
    by ``p_xt_g_x1`` and ``dt_p_xt_g_x1``, those values picked at the current
    labels, and the count of non-zero ``p_xt_g_x1`` entries along the last
    dimension (``Z_t_X`` / ``Z_t_E``).
    """

    def pick_at(table, labels):
        # Value of ``table`` at the class given by ``labels`` (last dim dropped)
        return table.gather(-1, labels).squeeze(-1)

    dt_p_vals_X, dt_p_vals_E = dt_p_xt_g_x1(
        X_1_sampled,
        E_1_sampled,
        self.limit_dist,
    )  # (bs, n, dx), (bs, n, n, de)

    pt_vals_X, pt_vals_E = p_xt_g_x1(
        X_1_sampled,
        E_1_sampled,
        t,
        self.limit_dist,
    )  # (bs, n, dx), (bs, n, n, de)

    return {
        "pt_vals_X": pt_vals_X,
        "pt_vals_E": pt_vals_E,
        "pt_vals_at_Xt": pick_at(pt_vals_X, X_t_label),  # (bs, n, )
        "pt_vals_at_Et": pick_at(pt_vals_E, E_t_label),  # (bs, n, n, )
        "dt_p_vals_X": dt_p_vals_X,
        "dt_p_vals_E": dt_p_vals_E,
        "dt_p_vals_at_Xt": pick_at(dt_p_vals_X, X_t_label),  # (bs, n, )
        "dt_p_vals_at_Et": pick_at(dt_p_vals_E, E_t_label),  # (bs, n, n, )
        "Z_t_X": torch.count_nonzero(pt_vals_X, dim=-1),  # (bs, n)
        "Z_t_E": torch.count_nonzero(pt_vals_E, dim=-1),  # (bs, n, n)
    }
113 |
def compute_Rstar(self, dfm_variables):
    """Compute the R^* term for nodes and edges from ``dfm_variables``.

    For each class: ``relu(dt_p - dt_p at current state)`` divided by
    ``Z_t * p_t at current state``. Returns ``(Rstar_t_X, Rstar_t_E)`` with
    the shapes of ``dt_p_vals_X`` and ``dt_p_vals_E``.
    """
    v = dfm_variables

    # Nodes: numerator (bs, n, dx) over denominator (bs, n) broadcast on classes
    numer_X = F.relu(v["dt_p_vals_X"] - v["dt_p_vals_at_Xt"][:, :, None])
    denom_X = v["Z_t_X"] * v["pt_vals_at_Xt"]
    Rstar_t_X = numer_X / denom_X[:, :, None]  # (bs, n, dx)

    # Edges: numerator (bs, n, n, de) over denominator (bs, n, n)
    numer_E = F.relu(v["dt_p_vals_E"] - v["dt_p_vals_at_Et"][:, :, :, None])
    denom_E = v["Z_t_E"] * v["pt_vals_at_Et"]
    Rstar_t_E = numer_E / denom_E[:, :, :, None]  # (bs, n, n, de)

    return Rstar_t_X, Rstar_t_E
141 |
142 | def compute_RDB(
143 | self,
144 | X_t_label,
145 | E_t_label,
146 | X_1_pred,
147 | E_1_pred,
148 | X_1_sampled,
149 | E_1_sampled,
150 | node_mask,
151 | t,
152 | dfm_variables,
153 | ):
154 | # unpack needed variables
155 | pt_vals_X = dfm_variables["pt_vals_X"]
156 | pt_vals_E = dfm_variables["pt_vals_E"]
157 |
158 | # dimensions
159 | dx = pt_vals_X.shape[-1]
160 | de = pt_vals_E.shape[-1]
161 |
162 | # build mask for Rdb
163 | if self.rdb == "general":
164 | x_mask = torch.ones_like(pt_vals_X)
165 | e_mask = torch.ones_like(pt_vals_E)
166 |
167 | elif self.rdb == "marginal":
168 | x_limit = self.limit_dist.X
169 | e_limit = self.limit_dist.E
170 |
171 | Xt_marginal = x_limit[X_t_label]
172 | Et_marginal = e_limit[E_t_label]
173 |
174 | x_mask = x_limit.repeat(X_t_label.shape[0], X_t_label.shape[1], 1)
175 | e_mask = e_limit.repeat(
176 | E_t_label.shape[0], E_t_label.shape[1], E_t_label.shape[2], 1
177 | )
178 |
179 | x_mask = x_mask > Xt_marginal
180 | e_mask = e_mask > Et_marginal
181 |
182 | elif self.rdb == "column":
183 | # Get column idx to pick
184 | if self.rdb_crit == "max_marginal":
185 | x_column_idxs = self.limit_dist.X.argmax(keepdim=True).expand(
186 | X_t_label.shape
187 | )
188 | e_column_idxs = self.limit_dist.E.argmax(keepdim=True).expand(
189 | E_t_label.shape
190 | )
191 | elif self.rdb_crit == "x_t":
192 | x_column_idxs = X_t_label
193 | e_column_idxs = E_t_label
194 | elif self.rdb_crit == "abs_state":
195 | x_column_idxs = torch.ones_like(X_t_label) * (dx - 1)
196 | e_column_idxs = torch.ones_like(E_t_label) * (de - 1)
197 | elif self.rdb_crit == "p_x1_g_xt":
198 | x_column_idxs = X_1_pred.argmax(dim=-1, keepdim=True)
199 | e_column_idxs = E_1_pred.argmax(dim=-1, keepdim=True)
200 | elif self.rdb_crit == "x_1": # as in paper, uniform
201 | x_column_idxs = X_1_sampled.unsqueeze(-1)
202 | e_column_idxs = E_1_sampled.unsqueeze(-1)
203 | elif self.rdb_crit == "p_xt_g_x1":
204 | x_column_idxs = pt_vals_X.argmax(dim=-1, keepdim=True)
205 | e_column_idxs = pt_vals_E.argmax(dim=-1, keepdim=True)
206 | elif self.rdb_crit == "xhat_t":
207 | sampled_1_hat = flow_matching_utils.sample_discrete_features(
208 | pt_vals_X,
209 | pt_vals_E,
210 | node_mask=node_mask,
211 | )
212 | x_column_idxs = sampled_1_hat.X.unsqueeze(-1)
213 | e_column_idxs = sampled_1_hat.E.unsqueeze(-1)
214 | else:
215 | raise NotImplementedError
216 |
217 | # create mask based on columns picked
218 | x_mask = F.one_hot(x_column_idxs.squeeze(-1), num_classes=dx)
219 | x_mask[(x_column_idxs == X_t_label).squeeze(-1)] = 1.0
220 | e_mask = F.one_hot(e_column_idxs.squeeze(-1), num_classes=de)
221 | e_mask[(e_column_idxs == E_t_label).squeeze(-1)] = 1.0
222 |
223 | elif self.rdb == "entry":
224 | if self.rdb_crit == "abs_state":
225 | # select last index
226 | x_masked_idx = torch.tensor(
227 | dx
228 | - 1 # delete -1 for the last index
229 | # dx - 1
230 | ).to(
231 | self.device
232 | ) # leaving this for now, can change later if we want to explore it a bit more
233 | e_masked_idx = torch.tensor(de - 1).to(self.device)
234 |
235 | x1_idxs = X_1_sampled.unsqueeze(-1) # (bs, n, 1)
236 | e1_idxs = E_1_sampled.unsqueeze(-1) # (bs, n, n, 1)
237 | if self.rdb_crit == "first": # here in all datasets it's the argmax
238 | # select last index
239 | x_masked_idx = torch.tensor(0).to(
240 | self.device
241 | ) # leaving this for now, can change later if we want to explore it a bit more
242 | e_masked_idx = torch.tensor(0).to(self.device)
243 |
244 | x1_idxs = X_1_sampled.unsqueeze(-1) # (bs, n, 1)
245 | e1_idxs = E_1_sampled.unsqueeze(-1) # (bs, n, n, 1)
246 | else:
247 | raise NotImplementedError
248 |
249 | # create mask based on columns picked
250 | # bs, n, _ = X_t_label.shape
251 | # x_mask = torch.zeros((bs, n, dx), device=self.device) # (bs, n, dx)
252 | x_mask = torch.zeros_like(pt_vals_X) # (bs, n, dx)
253 | xt_in_x1 = (X_t_label == x1_idxs).squeeze(-1) # (bs, n, 1)
254 | x_mask[xt_in_x1] = F.one_hot(x_masked_idx, num_classes=dx).float()
255 | xt_in_masked = (X_t_label == x_masked_idx).squeeze(-1)
256 | x_mask[xt_in_masked] = F.one_hot(
257 | x1_idxs.squeeze(-1), num_classes=dx
258 | ).float()[xt_in_masked]
259 |
260 | # e_mask = torch.zeros((bs, n, n, de), device=self.device) # (bs, n, dx)
261 | e_mask = torch.zeros_like(pt_vals_E)
262 | et_in_e1 = (E_t_label == e1_idxs).squeeze(-1)
263 | e_mask[et_in_e1] = F.one_hot(e_masked_idx, num_classes=de).float()
264 | et_in_masked = (E_t_label == e_masked_idx).squeeze(-1)
265 | e_mask[et_in_masked] = F.one_hot(
266 | e1_idxs.squeeze(-1), num_classes=de
267 | ).float()[et_in_masked]
268 |
269 | else:
270 | raise NotImplementedError(f"Not implemented rdb type: {self.rdb}")
271 |
272 | # stochastic rate matrix
273 | Rdb_t_X = pt_vals_X * x_mask * self.eta
274 | Rdb_t_E = pt_vals_E * e_mask * self.eta
275 |
276 | return Rdb_t_X, Rdb_t_E
277 |
def compute_R_tg(
    self,
    X_1_sampled,
    E_1_sampled,
    X_t_label,
    E_t_label,
    dfm_variables,
):
    """Compute the target-guidance rate matrices for nodes and edges.

    For each position, the sampled clean state is one-hot encoded, scaled
    by ``self.omega``, zeroed wherever the sampled clean state equals the
    current state, and divided by ``Z_t * pt_vals_at_t``.

    Args:
        X_1_sampled: integer class indices of the sampled clean node states.
        E_1_sampled: integer class indices of the sampled clean edge states.
        X_t_label: current node states; compared against
            ``X_1_sampled.unsqueeze(-1)``, so presumably it carries a
            trailing singleton dimension (confirm against caller).
        E_t_label: current edge states, same convention as ``X_t_label``.
        dfm_variables: mapping providing ``"pt_vals_at_Xt"``,
            ``"pt_vals_at_Et"``, ``"Z_t_X"`` and ``"Z_t_E"``.

    Returns:
        Tuple ``(Rtg_t_X, Rtg_t_E)``. No guard against zero denominators is
        applied here, so the result may contain inf/NaN entries.
    """

    def _guided_rate(sampled_1, t_label, num_classes, denominator):
        # One-hot of the target state, kept only where target != current.
        target_onehot = F.one_hot(sampled_1, num_classes=num_classes).float()
        differs = sampled_1.unsqueeze(-1) != t_label
        numerator = target_onehot * self.omega * differs
        return numerator / denominator

    # Per-position normalisers, given a trailing class axis for broadcasting.
    node_denominator = (dfm_variables["Z_t_X"] * dfm_variables["pt_vals_at_Xt"])[
        :, :, None
    ]
    edge_denominator = (dfm_variables["Z_t_E"] * dfm_variables["pt_vals_at_Et"])[
        :, :, :, None
    ]

    Rtg_t_X = _guided_rate(
        X_1_sampled, X_t_label, self.num_classes_X, node_denominator
    )
    Rtg_t_E = _guided_rate(
        E_1_sampled, E_t_label, self.num_classes_E, edge_denominator
    )
    return Rtg_t_X, Rtg_t_E
312 |
def stabilize_rate_matrix(self, R_t_X, R_t_E, dfm_variables):
    """Sanitise node and edge rate matrices.

    NaN and +/-inf entries become 0, then an entry is also zeroed when any
    of the following holds:

    * its value exceeds ``1e5``;
    * ``pt_vals_at_Xt`` / ``pt_vals_at_Et`` is exactly 0 at that position
      (the whole class row is cleared);
    * ``pt_vals_X`` / ``pt_vals_E`` is exactly 0 for that entry.

    Args:
        R_t_X: node rate matrix, indexed here as ``(bs, n, dx)``.
        R_t_E: edge rate matrix, indexed here as ``(bs, n, n, de)``.
        dfm_variables: mapping providing ``"pt_vals_X"``, ``"pt_vals_E"``,
            ``"pt_vals_at_Xt"`` and ``"pt_vals_at_Et"``.

    Returns:
        Tuple ``(R_t_X, R_t_E)`` of new tensors; the inputs are not modified.
    """

    def _sanitise(rates, zero_prob_position, zero_prob_entry):
        finite_rates = torch.nan_to_num(rates, nan=0.0, posinf=0.0, neginf=0.0)
        # Everything that has to be cleared, gathered into one boolean mask.
        to_clear = (finite_rates > 1e5) | zero_prob_position | zero_prob_entry
        return finite_rates.masked_fill(to_clear, 0.0)

    # Positions whose current state has zero probability: clear all classes.
    node_zero_at_t = (dfm_variables["pt_vals_at_Xt"] == 0.0)[:, :, None]
    edge_zero_at_t = (dfm_variables["pt_vals_at_Et"] == 0.0)[:, :, :, None]

    # Individual (position, class) entries with zero probability.
    node_zero_entry = dfm_variables["pt_vals_X"] == 0.0
    edge_zero_entry = dfm_variables["pt_vals_E"] == 0.0

    stable_X = _sanitise(R_t_X, node_zero_at_t, node_zero_entry)
    stable_E = _sanitise(R_t_E, edge_zero_at_t, edge_zero_entry)
    return stable_X, stable_E
339 |
--------------------------------------------------------------------------------
/src/models/transformer_model.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules.dropout import Dropout
6 | from torch.nn.modules.linear import Linear
7 | from torch.nn.modules.normalization import LayerNorm
8 | from torch.nn import functional as F
9 | from torch import Tensor
10 |
11 | from src import utils
12 | from flow_matching import flow_matching_utils
13 | from models.layers import Xtoy, Etoy, masked_softmax
14 |
15 |
class XEyTransformerLayer(nn.Module):
    """Transformer layer that jointly updates node, edge and global features.

    Every stream (X, E, y) first receives the output of the shared
    ``NodeEdgeBlock`` attention through a residual connection followed by a
    LayerNorm, and then goes through a two-layer feed-forward network with a
    second residual connection and LayerNorm.

    dx: node feature dimension
    de: edge feature dimension
    dy: global feature dimension
    n_head: the number of heads in the multi-head attention
    dim_ffX, dim_ffE, dim_ffy: hidden widths of the node / edge / global
        feed-forward networks
    dropout: dropout probability. 0 to disable
    layer_norm_eps: eps value in layer normalizations.
    device, dtype: forwarded to the sub-modules constructed here
    """

    def __init__(
        self,
        dx: int,
        de: int,
        dy: int,
        n_head: int,
        dim_ffX: int = 2048,
        dim_ffE: int = 128,
        dim_ffy: int = 2048,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)

        # Joint node/edge/global attention block.
        self.self_attn = NodeEdgeBlock(dx, de, dy, n_head, **factory_kwargs)

        # Node stream: feed-forward, norms and dropouts.
        self.linX1 = Linear(dx, dim_ffX, **factory_kwargs)
        self.linX2 = Linear(dim_ffX, dx, **factory_kwargs)
        self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **factory_kwargs)
        self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **factory_kwargs)
        self.dropoutX1 = Dropout(dropout)
        self.dropoutX2 = Dropout(dropout)
        self.dropoutX3 = Dropout(dropout)

        # Edge stream.
        self.linE1 = Linear(de, dim_ffE, **factory_kwargs)
        self.linE2 = Linear(dim_ffE, de, **factory_kwargs)
        self.normE1 = LayerNorm(de, eps=layer_norm_eps, **factory_kwargs)
        self.normE2 = LayerNorm(de, eps=layer_norm_eps, **factory_kwargs)
        self.dropoutE1 = Dropout(dropout)
        self.dropoutE2 = Dropout(dropout)
        self.dropoutE3 = Dropout(dropout)

        # Global stream.
        self.lin_y1 = Linear(dy, dim_ffy, **factory_kwargs)
        self.lin_y2 = Linear(dim_ffy, dy, **factory_kwargs)
        self.norm_y1 = LayerNorm(dy, eps=layer_norm_eps, **factory_kwargs)
        self.norm_y2 = LayerNorm(dy, eps=layer_norm_eps, **factory_kwargs)
        self.dropout_y1 = Dropout(dropout)
        self.dropout_y2 = Dropout(dropout)
        self.dropout_y3 = Dropout(dropout)

        self.activation = F.relu

    def forward(self, X: Tensor, E: Tensor, y, node_mask: Tensor):
        """Run one layer over the three feature streams.

        X: (bs, n, d)
        E: (bs, n, n, d)
        y: (bs, dy)
        node_mask: (bs, n) mask handed to the attention block
        Output: updated X, E, y with the same shapes as the inputs.
        """
        attn_X, attn_E, attn_y = self.self_attn(X, E, y, node_mask=node_mask)

        # Attention sub-layer: dropout -> residual -> LayerNorm, per stream.
        X = self.normX1(X + self.dropoutX1(attn_X))
        E = self.normE1(E + self.dropoutE1(attn_E))
        y = self.norm_y1(y + self.dropout_y1(attn_y))

        # Feed-forward sub-layer for the nodes.
        hidden_X = self.dropoutX2(self.activation(self.linX1(X)))
        X = self.normX2(X + self.dropoutX3(self.linX2(hidden_X)))

        # Feed-forward sub-layer for the edges.
        hidden_E = self.dropoutE2(self.activation(self.linE1(E)))
        E = self.normE2(E + self.dropoutE3(self.linE2(hidden_E)))

        # Feed-forward sub-layer for the global features.
        hidden_y = self.dropout_y2(self.activation(self.lin_y1(y)))
        y = self.norm_y2(y + self.dropout_y3(self.lin_y2(hidden_y)))

        return X, E, y
108 |
109 |
class NodeEdgeBlock(nn.Module):
    """Self attention layer that also updates the representations on the edges.

    Node features produce per-head, per-channel attention scores which are
    modulated (FiLM-style: ``a + (m + 1) * x``) by the edge features and the
    global features. The modulated scores are used both to produce the new
    edge features and, after a masked softmax, to aggregate the node values.
    The global features are updated from themselves plus pooled node and
    edge information (``Xtoy`` / ``Etoy``).
    """

    def __init__(self, dx, de, dy, n_head, **kwargs):
        # NOTE: **kwargs is accepted but not used anywhere in this constructor.
        super().__init__()
        assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
        # NOTE: these three dropout modules are created but never applied in
        # forward() (the only uses there are commented out).
        self.dropout_attn = Dropout(0.1)
        self.dropout_X = Dropout(0.1)
        self.dropout_E = Dropout(0.1)

        self.dx = dx
        self.de = de
        self.dy = dy
        self.df = int(dx / n_head)  # per-head feature width
        self.n_head = n_head

        # Attention
        self.q = Linear(dx, dx)
        self.k = Linear(dx, dx)
        self.v = Linear(dx, dx)

        # FiLM E to X
        self.e_add = Linear(de, dx)
        self.e_mul = Linear(de, dx)

        # FiLM y to E
        self.y_e_mul = Linear(dy, dx)  # Warning: here it's dx and not de
        self.y_e_add = Linear(dy, dx)

        # FiLM y to X
        self.y_x_mul = Linear(dy, dx)
        self.y_x_add = Linear(dy, dx)

        # Process y
        self.y_y = Linear(dy, dy)
        self.x_y = Xtoy(dx, dy)
        self.e_y = Etoy(de, dy)

        # Output layers
        self.x_out = Linear(dx, dx)
        self.e_out = Linear(dx, de)
        self.y_out = nn.Sequential(nn.Linear(dy, dy), nn.ReLU(), nn.Linear(dy, dy))

    def forward(self, X, E, y, node_mask):
        """
        :param X: bs, n, d node features
        :param E: bs, n, n, d edge features
        :param y: bs, dz global features
        :param node_mask: bs, n
        :return: newX, newE, new_y with the same shape.
        """
        bs, n, _ = X.shape
        x_mask = node_mask.unsqueeze(-1)  # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1

        # 1. Map X to keys and queries
        Q = self.q(X) * x_mask  # (bs, n, dx)
        K = self.k(X) * x_mask  # (bs, n, dx)
        flow_matching_utils.assert_correctly_masked(Q, x_mask)
        # 2. Reshape to (bs, n, n_head, df) with dx = n_head * df

        Q = Q.reshape((Q.size(0), Q.size(1), self.n_head, self.df))
        K = K.reshape((K.size(0), K.size(1), self.n_head, self.df))

        Q = Q.unsqueeze(2)  # (bs, n, 1, n_head, df)
        K = K.unsqueeze(1)  # (bs, 1, n, n_head, df)

        # Compute unnormalized attentions. Y is (bs, n, n, n_head, df)
        # The product is element-wise (one score per head AND per channel),
        # scaled by sqrt(df).
        Y = Q * K
        Y = Y / math.sqrt(Y.size(-1))

        # # Compute the distance based on the Gaussian kernel
        # Y_exp = torch.exp(
        #     -(Q - K) * (Q - K) / math.sqrt(Y.size(-1))
        # )  # bs, n, n, n_head, df
        # Y = Y + Y_exp
        flow_matching_utils.assert_correctly_masked(
            Y, (e_mask1 * e_mask2).unsqueeze(-1)
        )

        # Incorporate edge features to the self attention scores.
        E1 = self.e_mul(E) * e_mask1 * e_mask2  # bs, n, n, dx
        E1 = E1.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))
        E2 = self.e_add(E) * e_mask1 * e_mask2  # bs, n, n, dx
        E2 = E2.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))
        Y = Y * (E1 + 1) + E2  # (bs, n, n, n_head, df)

        # Incorporate y to E
        newE = Y.flatten(start_dim=3)  # bs, n, n, dx
        ye1 = self.y_e_add(y).unsqueeze(1).unsqueeze(1)  # bs, 1, 1, dx
        ye2 = self.y_e_mul(y).unsqueeze(1).unsqueeze(1)  # bs, 1, 1, dx
        newE = ye1 + (ye2 + 1) * newE

        # Output E
        newE = self.e_out(newE) * e_mask1 * e_mask2  # bs, n, n, de
        flow_matching_utils.assert_correctly_masked(newE, e_mask1 * e_mask2)

        # Compute attentions. attn is still (bs, n, n, n_head, df)
        # The mask marks which nodes along dim 2 take part in the softmax.
        softmax_mask = e_mask2.expand(-1, n, -1, self.n_head)  # bs, n, n, n_head
        # presumably attn keeps the shape of Y (softmax over dim 2) -- confirm
        # against models.layers.masked_softmax
        attn = masked_softmax(Y, softmax_mask, dim=2)

        V = self.v(X) * x_mask  # bs, n, dx
        V = V.reshape((V.size(0), V.size(1), self.n_head, self.df))
        V = V.unsqueeze(1)  # (bs, 1, n, n_head, df)

        # Compute weighted values
        weighted_V = attn * V
        weighted_V = weighted_V.sum(dim=2)

        # Send output to input dim
        weighted_V = weighted_V.flatten(start_dim=2)  # bs, n, dx

        # Incorporate y to X
        yx1 = self.y_x_add(y).unsqueeze(1)
        yx2 = self.y_x_mul(y).unsqueeze(1)
        newX = yx1 + (yx2 + 1) * weighted_V

        # Output X
        newX = self.x_out(newX) * x_mask
        flow_matching_utils.assert_correctly_masked(newX, x_mask)

        # Process y based on X and E
        # Note that the *input* X and E are pooled here, not newX / newE.
        y = self.y_y(y)
        e_y = self.e_y(E)
        x_y = self.x_y(X)
        new_y = y + x_y + e_y
        new_y = self.y_out(new_y)  # bs, dy

        # newX = self.dropout_X(newX)
        # newE = self.dropout_E(newE)

        return newX, newE, new_y
243 |
244 |
class GraphTransformer(nn.Module):
    """Graph transformer over node (X), edge (E) and global (y) features.

    Input MLPs lift each stream to its hidden width, ``n_layers``
    ``XEyTransformerLayer`` blocks process them jointly, and output MLPs
    project back to the output widths. The first ``output_dims[...]``
    channels of each input are added to the corresponding output as a skip
    connection.

    n_layers : int -- number of transformer layers
    input_dims : dict -- input widths, keys "X", "E", "y"
    hidden_mlp_dims : dict -- hidden widths of the in/out MLPs, keys "X", "E", "y"
    hidden_dims : dict -- transformer widths, keys "dx", "de", "dy",
        "n_head", "dim_ffX", "dim_ffE"
    output_dims : dict -- output widths, keys "X", "E", "y"
    act_fn_in : nn.Module -- activation used inside the input MLPs
    act_fn_out : nn.Module -- activation used inside the output MLPs
    """

    # Width of the sinusoidal time embedding appended to y in forward();
    # mlp_in_y is sized accordingly.
    TIME_EMB_DIM = 64

    def __init__(
        self,
        n_layers: int,
        input_dims: dict,
        hidden_mlp_dims: dict,
        hidden_dims: dict,
        output_dims: dict,
        act_fn_in: nn.Module,
        act_fn_out: nn.Module,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.out_dim_X = output_dims["X"]
        self.out_dim_E = output_dims["E"]
        self.out_dim_y = output_dims["y"]

        self.mlp_in_X = nn.Sequential(
            nn.Linear(input_dims["X"], hidden_mlp_dims["X"]),
            act_fn_in,
            nn.Linear(hidden_mlp_dims["X"], hidden_dims["dx"]),
            act_fn_in,
        )

        self.mlp_in_E = nn.Sequential(
            nn.Linear(input_dims["E"], hidden_mlp_dims["E"]),
            act_fn_in,
            nn.Linear(hidden_mlp_dims["E"], hidden_dims["de"]),
            act_fn_in,
        )

        # y is concatenated with the time embedding before entering this MLP.
        self.mlp_in_y = nn.Sequential(
            nn.Linear(input_dims["y"] + self.TIME_EMB_DIM, hidden_mlp_dims["y"]),
            act_fn_in,
            nn.Linear(hidden_mlp_dims["y"], hidden_dims["dy"]),
            act_fn_in,
        )

        self.tf_layers = nn.ModuleList(
            [
                XEyTransformerLayer(
                    dx=hidden_dims["dx"],
                    de=hidden_dims["de"],
                    dy=hidden_dims["dy"],
                    n_head=hidden_dims["n_head"],
                    dim_ffX=hidden_dims["dim_ffX"],
                    dim_ffE=hidden_dims["dim_ffE"],
                )
                for _ in range(n_layers)
            ]
        )

        self.mlp_out_X = nn.Sequential(
            nn.Linear(hidden_dims["dx"], hidden_mlp_dims["X"]),
            act_fn_out,
            nn.Linear(hidden_mlp_dims["X"], output_dims["X"]),
        )

        self.mlp_out_E = nn.Sequential(
            nn.Linear(hidden_dims["de"], hidden_mlp_dims["E"]),
            act_fn_out,
            nn.Linear(hidden_mlp_dims["E"], output_dims["E"]),
        )

        self.mlp_out_y = nn.Sequential(
            nn.Linear(hidden_dims["dy"], hidden_mlp_dims["y"]),
            act_fn_out,
            nn.Linear(hidden_mlp_dims["y"], output_dims["y"]),
        )

    def forward(self, X, E, y, node_mask):
        """Run the network on a dense batch of graphs.

        X is indexed as (bs, n, ...), E as (bs, n, n, ...) and y as (bs, ...).
        The last column of ``y`` is the value that gets the sinusoidal time
        embedding. Returns a ``utils.PlaceHolder`` holding the new X, E and
        y, masked with ``node_mask``.
        """
        bs, n = X.shape[0], X.shape[1]

        # Boolean mask that is False on the diagonal (no self-loops), built
        # directly on E's device instead of on CPU followed by a copy.
        diag_mask = ~torch.eye(n, device=E.device).bool()
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)

        # Leading channels of the raw inputs, re-added to the outputs below.
        X_to_out = X[..., : self.out_dim_X]
        E_to_out = E[..., : self.out_dim_E]
        y_to_out = y[..., : self.out_dim_y]

        # Embed the edges and make the embedding symmetric in (i, j).
        new_E = self.mlp_in_E(E)
        new_E = (new_E + new_E.transpose(1, 2)) / 2

        # encode time steps additionally
        time_emb = timestep_embedding(y[:, -1].unsqueeze(-1), self.TIME_EMB_DIM)
        y = torch.hstack([y, time_emb])

        after_in = utils.PlaceHolder(
            X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y)
        ).mask(node_mask)
        X, E, y = after_in.X, after_in.E, after_in.y

        for layer in self.tf_layers:
            X, E, y = layer(X, E, y, node_mask)

        X = self.mlp_out_X(X)
        E = self.mlp_out_E(E)
        y = self.mlp_out_y(y)

        # Skip connections; the edge diagonal is zeroed out.
        X = X + X_to_out
        E = (E + E_to_out) * diag_mask
        y = y + y_to_out

        # Symmetrise the edge output.
        E = 1 / 2 * (E + torch.transpose(E, 1, 2))

        return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)
357 |
358 |
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a Tensor of N time values, one per batch element,
                      given either as a 1-D tensor of shape (N,) or as a
                      column of shape (N, 1). These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings: ``dim // 2``
             cosines, then ``dim // 2`` sines, then one zero column when
             ``dim`` is odd.
    """
    if timesteps.dim() == 1:
        # Turn (N,) into (N, 1) so the product with the (1, half) frequency
        # row broadcasts to (N, half) instead of failing or mis-broadcasting.
        timesteps = timesteps.unsqueeze(-1)

    half = dim // 2
    # Geometric progression of frequencies from 1 down to ~1 / max_period.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps.float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with a zero column so the output width is exactly ``dim``.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)

    return embedding
381 |
--------------------------------------------------------------------------------
/src/datasets/moses_dataset.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem, RDLogger
2 | from rdkit.Chem.rdchem import BondType as BT
3 |
4 | import os
5 | import os.path as osp
6 | import pathlib
7 | from typing import Any, Sequence
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from tqdm import tqdm
12 | import numpy as np
13 | from torch_geometric.data import Data, InMemoryDataset, download_url
14 | import pandas as pd
15 |
16 | from src import utils
17 | from analysis.rdkit_functions import (
18 | mol2smiles,
19 | build_molecule_with_partial_charges,
20 | compute_molecular_metrics,
21 | )
22 | from datasets.abstract_dataset import AbstractDatasetInfos, MolecularDataModule
23 |
24 |
def to_list(value: Any) -> Sequence:
    """Return *value* unchanged if it is a non-string sequence, else ``[value]``."""
    is_plain_sequence = isinstance(value, Sequence) and not isinstance(value, str)
    return value if is_plain_sequence else [value]
30 |
31 |
# Atom symbols of the MOSES vocabulary. The position of a symbol in this list
# is the integer atom-type index used for the one-hot node features.
atom_decoder = ["C", "N", "S", "O", "F", "Cl", "Br", "H"]
33 |
34 |
class MOSESDataset(InMemoryDataset):
    """MOSES molecules as an in-memory PyTorch Geometric dataset.

    One instance loads a single split selected by ``stage``:

    * ``"train"``         -> raw ``train_moses.csv`` (index 0)
    * ``"val"``           -> raw ``val_moses.csv``   (index 1)
    * any other value     -> raw ``test_moses.csv``  (index 2)

    NOTE: ``download()`` stores the file fetched from ``test_url``
    (``test_scaffolds.csv``) as ``val_moses.csv`` and the file fetched from
    ``val_url`` (``test.csv``) as ``test_moses.csv``. This mapping is kept
    as-is so previously processed data stays consistent.

    Args:
        stage: split name, see above.
        root: dataset root directory handed to ``InMemoryDataset``.
        filter_dataset: if True, keep only molecules that can be rebuilt
            from their graph into a single fragment, and also write the kept
            SMILES to ``new_<stage>.smiles`` in the raw directory.
        transform, pre_transform, pre_filter: forwarded to
            ``InMemoryDataset``. ``pre_filter`` / ``pre_transform`` are only
            applied in ``process()`` when ``filter_dataset`` is False.
    """

    train_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/train.csv"
    val_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/test.csv"
    test_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/test_scaffolds.csv"

    def __init__(
        self,
        stage,
        root,
        filter_dataset: bool,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        self.stage = stage
        self.atom_decoder = atom_decoder
        self.filter_dataset = filter_dataset
        # Index of this split in the raw / split / processed file lists.
        if self.stage == "train":
            self.file_idx = 0
        elif self.stage == "val":
            self.file_idx = 1
        else:
            self.file_idx = 2
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[self.file_idx])

    @property
    def raw_file_names(self):
        """Names of the raw CSV files, ordered train / val / test."""
        return ["train_moses.csv", "val_moses.csv", "test_moses.csv"]

    @property
    def split_file_name(self):
        """Names of the per-split CSV files (same as the raw files)."""
        return ["train_moses.csv", "val_moses.csv", "test_moses.csv"]

    @property
    def split_paths(self):
        r"""The absolute filepaths that must be present in order to skip
        splitting."""
        files = to_list(self.split_file_name)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_file_names(self):
        """Names of the processed files, indexed by ``file_idx``."""
        if self.filter_dataset:
            return [
                "train_filtered.pt",
                "test_filtered.pt",
                "test_scaffold_filtered.pt",
            ]
        else:
            return ["train.pt", "test.pt", "test_scaffold.pt"]

    def download(self):
        """Download the three MOSES CSV files into ``raw_dir``."""
        import rdkit  # noqa

        train_path = download_url(self.train_url, self.raw_dir)
        os.rename(train_path, osp.join(self.raw_dir, "train_moses.csv"))

        # test_scaffolds.csv is the file used for the "val" stage.
        scaffolds_path = download_url(self.test_url, self.raw_dir)
        os.rename(scaffolds_path, osp.join(self.raw_dir, "val_moses.csv"))

        # test.csv is the file used for the "test" stage.
        test_csv_path = download_url(self.val_url, self.raw_dir)
        os.rename(test_csv_path, osp.join(self.raw_dir, "test_moses.csv"))

    def process(self):
        """Convert the SMILES of this split into graphs and save them.

        Molecules whose SMILES cannot be parsed by RDKit, and molecules
        without any bond, are skipped.
        """
        RDLogger.DisableLog("rdApp.*")
        types = {atom: i for i, atom in enumerate(self.atom_decoder)}

        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        path = self.split_paths[self.file_idx]
        smiles_list = pd.read_csv(path)["SMILES"].values

        data_list = []
        smiles_kept = []

        for i, smile in enumerate(tqdm(smiles_list)):
            mol = Chem.MolFromSmiles(smile)
            if mol is None:
                # RDKit returns None for a SMILES it cannot parse; skip the
                # entry instead of failing on ``None.GetNumAtoms()``.
                continue
            N = mol.GetNumAtoms()

            type_idx = [types[atom.GetSymbol()] for atom in mol.GetAtoms()]

            # Every bond is stored in both directions; edge class 0 is left
            # free (bond classes start at 1).
            row, col, edge_type = [], [], []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                row += [start, end]
                col += [end, start]
                edge_type += 2 * [bonds[bond.GetBondType()] + 1]

            if len(row) == 0:
                continue

            edge_index = torch.tensor([row, col], dtype=torch.long)
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = F.one_hot(edge_type, num_classes=len(bonds) + 1).to(torch.float)

            # Sort edges by (source, target).
            perm = (edge_index[0] * N + edge_index[1]).argsort()
            edge_index = edge_index[:, perm]
            edge_attr = edge_attr[perm]

            x = F.one_hot(torch.tensor(type_idx), num_classes=len(types)).float()
            y = torch.zeros(size=(1, 0), dtype=torch.float)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)

            if self.filter_dataset:
                # Try to build the molecule again from the graph. If it fails, do not add it to the training set
                dense_data, node_mask = utils.to_dense(
                    data.x, data.edge_index, data.edge_attr, data.batch
                )
                dense_data = dense_data.mask(node_mask, collapse=True)
                X, E = dense_data.X, dense_data.E

                assert X.size(0) == 1
                atom_types = X[0]
                edge_types = E[0]
                mol = build_molecule_with_partial_charges(
                    atom_types, edge_types, atom_decoder
                )
                smiles = mol2smiles(mol)
                if smiles is not None:
                    try:
                        mol_frags = Chem.rdmolops.GetMolFrags(
                            mol, asMols=True, sanitizeFrags=True
                        )
                        # Keep only molecules made of a single fragment.
                        if len(mol_frags) == 1:
                            data_list.append(data)
                            smiles_kept.append(smiles)

                    except Chem.rdchem.AtomValenceException:
                        print("Valence error in GetmolFrags")
                    except Chem.rdchem.KekulizeException:
                        print("Can't kekulize molecule")
            else:
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[self.file_idx])

        if self.filter_dataset:
            smiles_save_path = osp.join(
                pathlib.Path(self.raw_paths[0]).parent, f"new_{self.stage}.smiles"
            )
            print(smiles_save_path)
            with open(smiles_save_path, "w") as f:
                f.writelines(f"{s}\n" for s in smiles_kept)
            print(f"Number of molecules kept: {len(smiles_kept)} / {len(smiles_list)}")
187 |
188 |
class MosesDataModule(MolecularDataModule):
    """Data module that builds the train / val / test ``MOSESDataset`` splits.

    The dataset root is ``cfg.dataset.datadir`` resolved against the
    directory two levels above the one containing this file;
    ``cfg.dataset.filter`` is passed to every split as ``filter_dataset``.
    """

    def __init__(self, cfg):
        self.remove_h = False
        self.datadir = cfg.dataset.datadir
        self.filter_dataset = cfg.dataset.filter
        self.train_smiles = []

        repo_root = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(repo_root, self.datadir)

        # Splits are built in the order train, val, test.
        datasets = {
            stage: MOSESDataset(
                stage=stage, root=root_path, filter_dataset=self.filter_dataset
            )
            for stage in ("train", "val", "test")
        }
        super().__init__(cfg, datasets)
209 |
210 |
class MOSESinfos(AbstractDatasetInfos):
    """Static information and statistics about the MOSES dataset.

    Holds the atom vocabulary, per-atom weights and valencies, plus four
    statistics (``n_nodes``, ``node_types``, ``edge_types``,
    ``valency_distribution``). The statistics start from hard-coded values,
    may be replaced by values loaded from text files, and are recomputed
    from ``datamodule`` when ``recompute_statistics`` is True or when the
    attribute is None.
    """

    def __init__(self, datamodule, cfg, recompute_statistics=False, meta=None):
        self.name = "MOSES"
        self.input_dims = None
        self.output_dims = None
        self.remove_h = False
        self.compute_fcd = cfg.dataset.compute_fcd

        self.atom_decoder = atom_decoder
        self.atom_encoder = {atom: i for i, atom in enumerate(self.atom_decoder)}
        # Keyed by atom-type index (position in atom_decoder).
        self.atom_weights = {0: 12, 1: 14, 2: 32, 3: 16, 4: 19, 5: 35.4, 6: 79.9, 7: 1}
        # One entry per atom type, in atom_decoder order.
        self.valencies = [4, 3, 4, 2, 1, 1, 1, 1]
        self.num_atom_types = len(self.atom_decoder)
        self.max_weight = 350

        # Text files (relative to the current working directory) used to
        # load / store each statistic.
        meta_files = dict(
            n_nodes=f"{self.name}_n_counts.txt",
            node_types=f"{self.name}_atom_types.txt",
            edge_types=f"{self.name}_edge_types.txt",
            valency_distribution=f"{self.name}_valencies.txt",
        )

        # Hard-coded statistic; presumably entry i is the frequency of
        # graphs with i nodes (max_n_nodes below is len - 1) -- confirm.
        self.n_nodes = torch.tensor(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                3.097634362347889692e-06,
                1.858580617408733815e-05,
                5.007842264603823423e-05,
                5.678996240021660924e-05,
                1.244216400664299726e-04,
                4.486406978685408831e-04,
                2.253012731671333313e-03,
                3.231865121051669121e-03,
                6.709992419928312302e-03,
                2.289564721286296844e-02,
                5.411050841212272644e-02,
                1.099515631794929504e-01,
                1.223291903734207153e-01,
                1.280680745840072632e-01,
                1.445975750684738159e-01,
                1.505961418151855469e-01,
                1.436946094036102295e-01,
                9.265746921300888062e-02,
                1.820066757500171661e-02,
                2.065089574898593128e-06,
            ]
        )
        self.max_n_nodes = len(self.n_nodes) - 1 if self.n_nodes is not None else None
        # Hard-coded statistic with one entry per atom type.
        self.node_types = torch.tensor(
            [0.722338, 0.13661, 0.163655, 0.103549, 0.1421803, 0.005411, 0.00150, 0.0]
        )
        # Hard-coded statistic with one entry per edge class.
        self.edge_types = torch.tensor(
            [0.89740, 0.0472947, 0.062670, 0.0003524, 0.0486]
        )
        # Only the first 7 entries are non-zero in the hard-coded version.
        self.valency_distribution = torch.zeros(3 * self.max_n_nodes - 2)
        self.valency_distribution[:7] = torch.tensor(
            [0.0, 0.1055, 0.2728, 0.3613, 0.2499, 0.00544, 0.00485]
        )

        if meta is None:
            meta = dict(
                n_nodes=None,
                node_types=None,
                edge_types=None,
                valency_distribution=None,
            )
        # ``meta`` must contain exactly the same keys as ``meta_files``.
        assert set(meta.keys()) == set(meta_files.keys())
        # For every statistic without a value in ``meta``, load it from its
        # text file when that file exists and overwrite the attribute.
        for k, v in meta_files.items():
            if (k not in meta or meta[k] is None) and os.path.exists(v):
                meta[k] = np.loadtxt(v)
                setattr(self, k, meta[k])
        # Recompute (and save to the text files) whatever is requested or missing.
        if recompute_statistics or self.n_nodes is None:
            self.n_nodes = datamodule.node_counts()
            print("Distribution of number of nodes", self.n_nodes)
            np.savetxt(meta_files["n_nodes"], self.n_nodes.numpy())
            self.max_n_nodes = len(self.n_nodes) - 1
        if recompute_statistics or self.node_types is None:
            self.node_types = datamodule.node_types()
            print("Distribution of node types", self.node_types)
            np.savetxt(meta_files["node_types"], self.node_types.numpy())

        if recompute_statistics or self.edge_types is None:
            self.edge_types = datamodule.edge_counts()
            print("Distribution of edge types", self.edge_types)
            np.savetxt(meta_files["edge_types"], self.edge_types.numpy())
        if recompute_statistics or self.valency_distribution is None:
            valencies = datamodule.valency_count(self.max_n_nodes)
            print("Distribution of the valencies", valencies)
            np.savetxt(meta_files["valency_distribution"], valencies.numpy())
            self.valency_distribution = valencies
        # after we can be sure we have the data, complete infos
        self.complete_infos(n_nodes=self.n_nodes, node_types=self.node_types)
310 |
311 |
def get_smiles(raw_dir, filter_dataset):
    """Load the SMILES strings of the train / val / test splits.

    Args:
        raw_dir: directory that contains the SMILES / CSV files.
        filter_dataset: if True, read ``new_<stage>.smiles`` (one SMILES per
            line); otherwise read the ``SMILES`` column of
            ``<stage>_moses.csv``.

    Returns:
        Dict with keys ``"train"``, ``"val"`` and ``"test"`` mapping to lists
        of strings. In the filtered case the strings are raw lines, so each
        one keeps its trailing newline.
    """
    stages = ("train", "val", "test")

    if filter_dataset:
        smiles = {}
        for stage in stages:
            # A context manager closes every file promptly instead of
            # leaving the handle to the garbage collector.
            with open(osp.join(raw_dir, f"new_{stage}.smiles")) as f:
                smiles[stage] = f.readlines()
        return smiles

    return {
        stage: extract_smiles_from_csv(osp.join(raw_dir, f"{stage}_moses.csv"))
        for stage in stages
    }
339 |
340 |
def extract_smiles_from_csv(csv_path):
    """Return the ``SMILES`` column of the CSV file at *csv_path* as a list."""
    table = pd.read_csv(csv_path)
    smiles_column = table["SMILES"]
    return smiles_column.to_list()
343 |
344 |
if __name__ == "__main__":
    # Build (download and process on first use) every MOSES split.
    # The data directory is resolved from the directory containing this file
    # rather than by appending "../" segments to the file path itself.
    moses_root = os.path.normpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/moses")
    )
    # MOSESDataset has no ``preprocess`` parameter and requires
    # ``filter_dataset``; the previous call raised TypeError.
    # NOTE(review): False (no filtering) is assumed here -- confirm.
    ds = [
        MOSESDataset(s, moses_root, filter_dataset=False)
        for s in ["train", "val", "test"]
    ]
354 |
--------------------------------------------------------------------------------
/src/models/extra_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src import utils
3 |
4 |
class DummyExtraFeatures:
    """Extra-feature provider that contributes no features at all."""

    def __init__(self):
        """This class does not compute anything, just returns empty tensors."""

    def __call__(self, noisy_data):
        """Return zero-width X, E and y tensors matching the noisy inputs.

        Reads ``noisy_data["X_t"]``, ``noisy_data["E_t"]`` and
        ``noisy_data["y_t"]``. X and E keep every dimension but the last,
        which becomes 0; y becomes ``(y.shape[0], 0)``. Dtype and device are
        inherited from the corresponding input.
        """

        def _zero_width(reference, leading_shape):
            # new_zeros keeps the dtype and device of ``reference``.
            return reference.new_zeros(tuple(leading_shape) + (0,))

        X_t = noisy_data["X_t"]
        E_t = noisy_data["E_t"]
        y_t = noisy_data["y_t"]

        return utils.PlaceHolder(
            X=_zero_width(X_t, X_t.shape[:-1]),
            E=_zero_width(E_t, E_t.shape[:-1]),
            y=_zero_width(y_t, (y_t.shape[0],)),
        )
17 |
18 |
class ExtraFeatures:
    """Compute additional node (X), edge (E) and graph-level (y) features.

    The feature family is selected by ``extra_features_type``:

    * ``"cycles"``      -- per-node cycle counts in X, graph cycle counts in y.
    * ``"eigenvalues"`` -- same outputs as ``"cycles"`` (the spectral features
      are computed but currently not concatenated to the output).
    * ``"rrwp"``        -- random-walk features on edges, their diagonal on
      nodes, cycle counts in y.
    * ``"rrwp_double"`` -- normalised and un-normalised random-walk features
      concatenated along the last dimension.
    * ``"rrwp_only"``   -- like ``"rrwp"`` but y only holds the size ratio.
    * ``"rrwp_comp"``   -- random-walk features of the graph and of
      ``1 - adjacency``, each with half of ``rrwp_steps``.
    * ``"all"``         -- cycle counts plus spectral node / graph features.

    In every mode the first column of ``y`` is the number of nodes of the
    graph divided by ``dataset_info.max_n_nodes``.
    """

    def __init__(self, extra_features_type, rrwp_steps, dataset_info):
        self.max_n_nodes = dataset_info.max_n_nodes
        self.ncycles = NodeCycleFeatures()
        self.features_type = extra_features_type
        self.rrwp_steps = rrwp_steps
        self.RRWP = RRWPFeatures()
        self.RWP = RRWPFeatures(normalize=False)
        if extra_features_type in ["eigenvalues", "all"]:
            self.eigenfeatures = EigenFeatures(mode=extra_features_type)

    @staticmethod
    def _empty_edge_features(E):
        """Return a zero-width edge feature tensor matching E's dtype/device."""
        return E.new_zeros((*E.shape[:-1], 0))

    @staticmethod
    def _adjacency(noisy_data):
        """Sum all edge channels except channel 0 into one matrix (bs, n, n)."""
        return noisy_data["E_t"].float()[..., 1:].sum(-1)

    @staticmethod
    def _diagonal_features(edge_attr):
        """Extract the node features ``edge_attr[b, i, i, :]``."""
        # Build the index on the same device as the features so that no
        # host-to-device copy is needed when indexing.
        diag_index = torch.arange(edge_attr.shape[1], device=edge_attr.device)
        return edge_attr[:, diag_index, diag_index, :]

    def __call__(self, noisy_data):
        """Build the extra features for a batch.

        ``noisy_data`` must provide ``"E_t"`` (edge tensor whose last
        dimension indexes edge types, channel 0 being excluded from the
        adjacency) and ``"node_mask"``.

        Returns a ``utils.PlaceHolder`` with fields ``X``, ``E`` and ``y``.
        Raises ``ValueError`` for an unknown feature type.
        """
        n = noisy_data["node_mask"].sum(dim=1).unsqueeze(1) / self.max_n_nodes
        x_cycles, y_cycles = self.ncycles(noisy_data)  # (bs, n_cycles)

        if self.features_type == "cycles":
            return utils.PlaceHolder(
                X=x_cycles,
                E=self._empty_edge_features(noisy_data["E_t"]),
                y=torch.hstack((n, y_cycles)),
            )

        if self.features_type == "eigenvalues":
            # NOTE(review): the spectral features are evaluated but not added
            # to the output; the call is kept so the behaviour (including its
            # validation of the spectrum) is unchanged -- confirm intent.
            self.eigenfeatures(noisy_data)
            return utils.PlaceHolder(
                X=x_cycles,
                E=self._empty_edge_features(noisy_data["E_t"]),
                y=torch.hstack((n, y_cycles)),
            )

        if self.features_type == "rrwp":
            rrwp_edge_attr = self.RRWP(self._adjacency(noisy_data), k=self.rrwp_steps)
            return utils.PlaceHolder(
                X=self._diagonal_features(rrwp_edge_attr),
                E=rrwp_edge_attr,
                y=torch.hstack((n, y_cycles)),
            )

        if self.features_type == "rrwp_double":
            E = self._adjacency(noisy_data)  # bs, n, n
            rrwp_edge_attr = self.RRWP(E, k=self.rrwp_steps)
            rrwp_edge_attr_wo_norm = self.RWP(E, k=self.rrwp_steps)

            # Scale every walk length by its largest entry over (n, n).
            max_value = rrwp_edge_attr_wo_norm.max(dim=1, keepdim=True).values
            max_value = max_value.max(dim=2, keepdim=True).values
            # A graph without edges has an all-zero slice for every step >= 1;
            # dividing by that zero maximum would fill the features with NaN,
            # so such slices are divided by one and stay zero.
            max_value = torch.where(
                max_value == 0, torch.ones_like(max_value), max_value
            )
            rrwp_edge_attr_wo_norm = rrwp_edge_attr_wo_norm / max_value

            rrwp_edge_attr = torch.cat(
                (rrwp_edge_attr, rrwp_edge_attr_wo_norm), dim=-1
            )
            return utils.PlaceHolder(
                X=self._diagonal_features(rrwp_edge_attr),
                E=rrwp_edge_attr,
                y=torch.hstack((n, y_cycles)),
            )

        if self.features_type == "rrwp_only":
            rrwp_edge_attr = self.RRWP(self._adjacency(noisy_data), k=self.rrwp_steps)
            return utils.PlaceHolder(
                X=self._diagonal_features(rrwp_edge_attr),
                E=rrwp_edge_attr,
                y=n,
            )

        if self.features_type == "rrwp_comp":
            half_steps = int(self.rrwp_steps / 2)
            E = self._adjacency(noisy_data)  # bs, n, n
            rrwp_edge_attr = self.RRWP(E, k=half_steps)
            rrwp_node_attr = self._diagonal_features(rrwp_edge_attr)

            # Same features on the complement ``1 - adjacency``.
            comp_rrwp_edge_attr = self.RRWP(1 - E, k=half_steps)
            comp_rrwp_node_attr = self._diagonal_features(comp_rrwp_edge_attr)

            return utils.PlaceHolder(
                X=torch.cat((rrwp_node_attr, comp_rrwp_node_attr), dim=-1),
                E=torch.cat((rrwp_edge_attr, comp_rrwp_edge_attr), dim=-1),
                y=torch.hstack((n, y_cycles)),
            )

        if self.features_type == "all":
            (
                n_components,
                batched_eigenvalues,
                nonlcc_indicator,
                k_lowest_eigvec,
            ) = self.eigenfeatures(noisy_data)
            return utils.PlaceHolder(
                X=torch.cat(
                    (x_cycles, nonlcc_indicator, k_lowest_eigvec),
                    dim=-1,
                ),
                E=self._empty_edge_features(noisy_data["E_t"]),
                y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues)),
            )

        raise ValueError(f"Features type {self.features_type} not implemented")
133 |
134 |
class RRWPFeatures:
    """Stack successive powers of a (optionally row-normalised) adjacency.

    ``k`` is the default number of stacked matrices (the identity counts as
    the first one); ``normalize`` selects whether every row of the input is
    divided by its sum before taking powers.
    """

    def __init__(self, k=10, normalize=True):
        self.k = k
        self.normalize = normalize

    def __call__(self, E, k=None):
        """Return a tensor of shape (bs, n, n, k) for an input of shape (bs, n, n).

        Slice ``[..., 0]`` is the identity and slice ``[..., i]`` is the
        i-th power of the (normalised) input. A falsy ``k`` falls back to
        the value given at construction time.
        """
        num_steps = k if k else self.k
        batch_size, num_nodes, _ = E.shape

        if self.normalize:
            row_sums = E.sum(dim=-1).float()
            inv_degree = 1 / row_sums
            # Rows without any edge would give 1 / 0; keep them at zero.
            inv_degree[row_sums == 0] = 0
            E = torch.diag_embed(inv_degree) @ E

        identity = torch.eye(num_nodes, device=E.device)
        powers = [identity.unsqueeze(0).repeat(batch_size, 1, 1)]
        for _ in range(num_steps - 1):
            powers.append(powers[-1] @ E)

        return torch.stack(powers, -1)
163 |
164 |
class NodeCycleFeatures:
    """Scaled and capped cycle counts obtained from ``KNodeCycles``."""

    def __init__(self):
        self.kcycles = KNodeCycles()

    def __call__(self, noisy_data):
        """Return ``(x_cycles, y_cycles)`` for the batch in ``noisy_data``.

        The adjacency is the sum of all channels of ``noisy_data["E_t"]``
        except channel 0. Node-level counts are multiplied by
        ``noisy_data["node_mask"]``; both outputs are divided by 10 and
        capped at 1.
        """
        adjacency = noisy_data["E_t"][..., 1:].sum(dim=-1).float()
        node_counts, graph_counts = self.kcycles.k_cycles(adj_matrix=adjacency)

        node_mask = noisy_data["node_mask"].unsqueeze(-1)
        node_counts = node_counts.type_as(adjacency) * node_mask

        # Dense graphs yield very large counts: rescale, then cap at one.
        node_counts = (node_counts / 10).clamp(max=1)
        graph_counts = (graph_counts / 10).clamp(max=1)
        return node_counts, graph_counts
182 |
183 |
class EigenFeatures:
    """
    Spectral features of the combinatorial graph Laplacian.

    Code taken from : https://github.com/Saro00/DGN/blob/master/models/pytorch/eigen_agg.py
    """

    def __init__(self, mode):
        """mode: 'eigenvalues' or 'all'"""
        self.mode = mode

    def __call__(self, noisy_data):
        """Compute the spectral features selected by ``self.mode``.

        Returns ``(n_connected_comp, batch_eigenvalues)`` in mode
        ``'eigenvalues'`` and ``(n_connected_comp, batch_eigenvalues,
        nonlcc_indicator, k_lowest_eigenvector)`` in mode ``'all'``.
        Raises ``NotImplementedError`` for any other mode.
        """
        node_mask = noisy_data["node_mask"]
        mask_1n = node_mask.unsqueeze(1)
        mask_n1 = node_mask.unsqueeze(2)

        # Adjacency restricted to the real (unmasked) nodes.
        A = noisy_data["E_t"][..., 1:].sum(dim=-1).float() * mask_1n * mask_n1
        L = compute_laplacian(A, normalize=False)

        # Put the value 2 * n on the diagonal entries of masked nodes so their
        # eigenvalues are large and sort after those of the real nodes.
        padding = 2 * L.shape[-1] * torch.eye(A.shape[-1]).type_as(L).unsqueeze(0)
        padding = padding * (~mask_1n) * (~mask_n1)
        L = L * mask_1n * mask_n1 + padding

        if self.mode == "eigenvalues":
            eigvals = torch.linalg.eigvalsh(L)  # bs, n
            eigvectors = None
        elif self.mode == "all":
            eigvals, eigvectors = torch.linalg.eigh(L)
        else:
            raise NotImplementedError(f"Mode {self.mode} is not implemented")

        # Divide the eigenvalues by the number of real nodes of each graph.
        eigvals = eigvals.type_as(A) / torch.sum(node_mask, dim=1, keepdim=True)
        n_connected_comp, batch_eigenvalues = get_eigenvalues_features(
            eigenvalues=eigvals
        )

        if eigvectors is None:
            return n_connected_comp.type_as(A), batch_eigenvalues.type_as(A)

        eigvectors = eigvectors * mask_n1 * mask_1n
        nonlcc_indicator, k_lowest_eigenvector = get_eigenvectors_features(
            vectors=eigvectors,
            node_mask=noisy_data["node_mask"],
            n_connected=n_connected_comp,
        )
        return (
            n_connected_comp,
            batch_eigenvalues,
            nonlcc_indicator,
            k_lowest_eigenvector,
        )
236 |
237 |
def compute_laplacian(adjacency, normalize: bool):
    """
    Build the symmetrised Laplacian of a batch of graphs.

    adjacency : batched adjacency matrix (bs, n, n)
    normalize: only its truthiness is used. A falsy value (False, None)
        gives the combinatorial Laplacian D - A; any truthy value (e.g.
        True or 'sym') gives the symmetric normalized Laplacian
        I - D^-1/2 A D^-1/2. No separate random-walk variant is implemented.
    Return:
        L (bs, n, n): the selected Laplacian averaged with its transpose.
        In the normalized case the rows of zero-degree nodes are set to 0.
    """
    degrees = adjacency.sum(dim=-1)  # (bs, n)

    if not normalize:
        laplacian = torch.diag_embed(degrees) - adjacency  # (bs, n, n)
        return (laplacian + laplacian.transpose(1, 2)) / 2

    isolated = degrees == 0
    # Avoid a division by zero for nodes without any edge.
    safe_degrees = degrees.clone()
    safe_degrees[isolated] = 1e-12

    inv_sqrt_degree = torch.diag_embed(1 / torch.sqrt(safe_degrees))  # (bs, n, n)
    identity = torch.eye(degrees.shape[-1], device=adjacency.device).unsqueeze(0)
    laplacian = identity - inv_sqrt_degree @ adjacency @ inv_sqrt_degree
    laplacian[isolated] = 0
    return (laplacian + laplacian.transpose(1, 2)) / 2
261 |
262 |
def get_eigenvalues_features(eigenvalues, k=5):
    """
    Count the (near-)zero eigenvalues and return the k eigenvalues after them.

    eigenvalues : eigenvalues sorted in ascending order -- (bs, n)
    k: num of non zero eigenvalues to keep

    Returns:
        n_connected_components -- (bs, 1): number of eigenvalues below 1e-5
        first_k_ev -- (bs, k): the k eigenvalues that follow them; positions
            beyond the available spectrum are filled with the value 2.

    Raises:
        ValueError: if some graph has no eigenvalue below 1e-5.
    """
    bs, n = eigenvalues.shape
    n_connected_components = (eigenvalues < 1e-5).sum(dim=-1)

    # Fail loudly instead of opening an interactive debugger, which would
    # block (or abort) any non-interactive training or sampling job.
    if not (n_connected_components > 0).all():
        raise ValueError(
            "Expected at least one eigenvalue below 1e-5 for every graph, got "
            f"counts {n_connected_components} for eigenvalues {eigenvalues}"
        )

    # A single tensor reduction rather than Python's max() over the tensor.
    max_components = int(n_connected_components.max().item()) if bs > 0 else 0
    to_extend = max_components + k - n
    if to_extend > 0:
        padding = eigenvalues.new_full((bs, to_extend), 2.0)
        eigenvalues = torch.hstack((eigenvalues, padding))

    # For each graph gather the k entries right after its zero eigenvalues.
    offsets = torch.arange(k, device=eigenvalues.device).unsqueeze(0)
    indices = offsets + n_connected_components.unsqueeze(1)
    first_k_ev = torch.gather(eigenvalues, dim=1, index=indices)
    return n_connected_components.unsqueeze(-1), first_k_ev
289 |
290 |
def get_eigenvectors_features(vectors, node_mask, n_connected, k=2):
    """
    vectors (bs, n, n) : eigenvectors of Laplacian IN COLUMNS
    node_mask (bs, n) : boolean mask of the real nodes
    n_connected (bs, 1) : integer number of leading eigenvectors to skip
    k : number of eigenvectors to keep after the skipped ones
    returns:
        not_lcc_indicator : 1 for real nodes whose rounded entry in the first
            eigenvector differs from the most frequent entry of their graph,
            0 otherwise -- (bs, n, 1)
        k_lowest_eigvec : the k eigenvectors following the first n_connected
            ones, zero-padded when fewer are available and zeroed on masked
            nodes -- (bs, n, k)
    """
    bs, n = vectors.size(0), vectors.size(1)

    # Create an indicator for the nodes outside the largest connected components
    first_ev = torch.round(vectors[:, :, 0], decimals=3) * node_mask  # bs, n
    # Add random value to the mask to prevent 0 from becoming the mode
    random = torch.randn(bs, n, device=node_mask.device) * (~node_mask)  # bs, n
    first_ev = first_ev + random
    most_common = torch.mode(first_ev, dim=1).values  # bs
    differs_from_mode = ~(first_ev == most_common.unsqueeze(1))
    not_lcc_indicator = (differs_from_mode * node_mask).unsqueeze(-1).float()

    # Get the eigenvectors corresponding to the first nonzero eigenvalues.
    # The maximum is reduced on the tensor and converted to a plain int so it
    # can be used safely as a size.
    max_connected = int(n_connected.max().item()) if n_connected.numel() > 0 else 0
    to_extend = max_connected + k - n
    if to_extend > 0:
        vectors = torch.cat(
            (vectors, vectors.new_zeros((bs, n, to_extend))), dim=2
        )  # bs, n, n + to_extend

    offsets = torch.arange(k, device=vectors.device).view(1, 1, k)
    indices = offsets + n_connected.unsqueeze(2)  # bs, 1, k
    indices = indices.expand(-1, n, -1)  # bs, n, k
    first_k_ev = torch.gather(vectors, dim=2, index=indices)  # bs, n, k
    first_k_ev = first_k_ev * node_mask.unsqueeze(2)

    return not_lcc_indicator, first_k_ev
325 |
326 |
def batch_trace(X):
    """
    Sum the diagonal taken over the last two dimensions of X.

    :param X: tensor of shape B N N
    :return: the traces, in shape B
    """
    return X.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
336 |
337 |
def batch_diagonal(X):
    """
    Return the diagonal taken over the last two dimensions of X.

    :param X: tensor with at least two dimensions
    :return: tensor whose last dimension holds the diagonal entries
    """
    return X.diagonal(dim1=-2, dim2=-1)
345 |
346 |
class KNodeCycles:
    """Builds cycle counts for each node in a graph.

    The counts are derived from powers of the batched adjacency matrix that
    ``k_cycles`` stores on the instance before calling the helpers.
    """

    def __init__(self):
        super().__init__()

    def calculate_kpowers(self):
        """Store the degree vector and the adjacency powers 1 to 6 on ``self``."""
        adj = self.adj_matrix.float()
        self.k1_matrix = adj
        self.d = self.adj_matrix.sum(dim=-1)
        self.k2_matrix = self.k1_matrix @ adj
        self.k3_matrix = self.k2_matrix @ adj
        self.k4_matrix = self.k3_matrix @ adj
        self.k5_matrix = self.k4_matrix @ adj
        self.k6_matrix = self.k5_matrix @ adj

    def k3_cycle(self):
        """tr(A ** 3)."""
        diag_a3 = batch_diagonal(self.k3_matrix)
        per_node = (diag_a3 / 2).unsqueeze(-1).float()
        per_graph = (torch.sum(diag_a3, dim=-1) / 6).unsqueeze(-1).float()
        return per_node, per_graph

    def k4_cycle(self):
        """Return the per-node and per-graph 4-cycle terms."""
        degree = self.d
        diag_a4 = batch_diagonal(self.k4_matrix)
        neighbour_degrees = (self.adj_matrix @ degree.unsqueeze(-1)).sum(dim=-1)
        c4 = diag_a4 - degree * (degree - 1) - neighbour_degrees
        per_node = (c4 / 2).unsqueeze(-1).float()
        per_graph = (torch.sum(c4, dim=-1) / 8).unsqueeze(-1).float()
        return per_node, per_graph

    def k5_cycle(self):
        """Return the per-node and per-graph 5-cycle terms."""
        diag_a5 = batch_diagonal(self.k5_matrix)
        triangles = batch_diagonal(self.k3_matrix)
        neighbour_triangles = (self.adj_matrix @ triangles.unsqueeze(-1)).sum(dim=-1)
        c5 = diag_a5 - 2 * triangles * self.d - neighbour_triangles + triangles
        per_node = (c5 / 2).unsqueeze(-1).float()
        per_graph = (c5.sum(dim=-1) / 10).unsqueeze(-1).float()
        return per_node, per_graph

    def k6_cycle(self):
        """Return ``(None, per_graph)``: no node-level 6-cycle term is produced."""
        diag_a2 = batch_diagonal(self.k2_matrix)
        diag_a4 = batch_diagonal(self.k4_matrix)

        trace_a6 = batch_trace(self.k6_matrix)
        trace_a3_squared = batch_trace(self.k3_matrix**2)
        weighted_a2_squared = torch.sum(
            self.adj_matrix * self.k2_matrix.pow(2), dim=[-2, -1]
        )
        diag_a2_a4 = (diag_a2 * diag_a4).sum(dim=-1)
        trace_a4 = batch_trace(self.k4_matrix)
        trace_a3 = batch_trace(self.k3_matrix)
        diag_a2_cubed = diag_a2.pow(3).sum(-1)
        total_a3 = torch.sum(self.k3_matrix, dim=[-2, -1])
        diag_a2_squared = diag_a2.pow(2).sum(-1)
        trace_a2 = batch_trace(self.k2_matrix)

        c6 = (
            trace_a6
            - 3 * trace_a3_squared
            + 9 * weighted_a2_squared
            - 6 * diag_a2_a4
            + 6 * trace_a4
            - 4 * trace_a3
            + 4 * diag_a2_cubed
            + 3 * total_a3
            - 12 * diag_a2_squared
            + 4 * trace_a2
        )
        return None, (c6 / 12).unsqueeze(-1).float()

    def k_cycles(self, adj_matrix, verbose=False):
        """Return node-level (3-, 4-, 5-cycle) and graph-level (3- to 6-cycle) terms.

        ``adj_matrix`` is stored on the instance and its powers are
        recomputed on every call. ``verbose`` is accepted but not used.
        Every term is asserted to be at least -0.1.
        """
        self.adj_matrix = adj_matrix
        self.calculate_kpowers()

        node_k3, graph_k3 = self.k3_cycle()
        assert (node_k3 >= -0.1).all()

        node_k4, graph_k4 = self.k4_cycle()
        assert (node_k4 >= -0.1).all()

        node_k5, graph_k5 = self.k5_cycle()
        assert (node_k5 >= -0.1).all(), node_k5

        _, graph_k6 = self.k6_cycle()
        assert (graph_k6 >= -0.1).all()

        node_terms = torch.cat([node_k3, node_k4, node_k5], dim=-1)
        graph_terms = torch.cat([graph_k3, graph_k4, graph_k5, graph_k6], dim=-1)
        return node_terms, graph_terms
440 |
--------------------------------------------------------------------------------