├── src ├── __init__.py ├── metrics │ ├── __init__.py │ ├── train_metrics.py │ ├── molecular_metrics_discrete.py │ ├── abstract_metrics.py │ └── tls_metrics.py ├── models │ ├── __init__.py │ ├── layers.py │ ├── extra_features_molecular.py │ ├── transformer_model.py │ └── extra_features.py ├── analysis │ ├── __init__.py │ ├── dist_helper.py │ └── visualization.py ├── datasets │ ├── __init__.py │ ├── dataset_utils.py │ ├── abstract_dataset.py │ ├── tu_dataset_origin.py │ ├── tu_dataset.py │ ├── spectre_dataset.py │ └── moses_dataset.py ├── flow_matching │ ├── init.py │ ├── utils.py │ ├── flow_matching_utils.py │ ├── noise_distribution.py │ ├── time_distorter.py │ └── rate_matrix.py ├── utils.py └── main.py ├── configs ├── __init__.py ├── dataset │ ├── tls.yaml │ ├── sbm.yaml │ ├── tree.yaml │ ├── comm20.yaml │ ├── planar.yaml │ ├── moses.yaml │ ├── zinc.yaml │ ├── qm9.yaml │ └── guacamol.yaml ├── config.yaml ├── experiment │ ├── test.yaml │ ├── qm9_with_h.yaml │ ├── qm9_no_h.yaml │ ├── debug.yaml │ ├── comm20.yaml │ ├── planar.yaml │ ├── tree.yaml │ ├── sbm.yaml │ ├── tls.yaml │ ├── guacamol.yaml │ ├── zinc.yaml │ └── moses.yaml ├── sample │ └── sample_default.yaml ├── train │ └── train_default.yaml ├── model │ └── discrete.yaml └── general │ └── general_default.yaml ├── images ├── defog.pdf ├── defog.png ├── motivation.pdf ├── qm9_molecule_4.gif └── sbm_molecule_14.gif ├── docker ├── entrypoint.sh └── dependencies │ ├── pip_no_deps.sh │ ├── apt-runtime.txt │ └── environment.yaml ├── setup.py ├── requirements.txt ├── environment.yaml ├── LICENSE ├── Dockerfile ├── .gitignore └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flow_matching/init.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/dataset/tls.yaml: -------------------------------------------------------------------------------- 1 | name: tls 2 | datadir: 'data/tls' -------------------------------------------------------------------------------- /configs/dataset/sbm.yaml: -------------------------------------------------------------------------------- 1 | name: sbm 2 | remove_h: null 3 | datadir: 'data/sbm/' -------------------------------------------------------------------------------- /configs/dataset/tree.yaml: -------------------------------------------------------------------------------- 1 | name: tree 2 | remove_h: null 3 | datadir: 'data/tree/' -------------------------------------------------------------------------------- /configs/dataset/comm20.yaml: -------------------------------------------------------------------------------- 1 | name: comm20 2 | remove_h: null 3 | datadir: 'data/comm20/' -------------------------------------------------------------------------------- /configs/dataset/planar.yaml: 
#!/bin/bash
# Container entrypoint: run the requested command inside the "defog" conda env.

# Abort immediately if any of the setup steps below fails.
set -o errexit

# Load conda's shell functions, then switch to the project environment.
. /opt/conda/etc/profile.d/conda.sh
conda activate defog

# Replace this shell with the command handed to the container.
exec "$@"
#!/bin/bash
# Install pip packages that must bypass dependency resolution (currently only
# fcd) into the "defog" conda environment.

# Stop at the first failing step, like docker/entrypoint.sh does, so a broken
# environment fails the image build instead of silently installing with
# whatever pip happens to be first on PATH.
set -e

# Activate the conda environment so that the installation happens within the environment.
# Conda is installed under CONDA_INSTALL_PATH by the Dockerfile (/opt/conda);
# the previously hard-coded /conda prefix does not match that location.
source "${CONDA_INSTALL_PATH:-/opt/conda}/etc/profile.d/conda.sh"
conda activate defog

# Install fcd within the activated environment
pip install fcd==1.2 --no-deps
"""Packaging script for the DeFoG project."""
from setuptools import find_packages, setup

# Distribution metadata, gathered in one mapping and handed to setuptools below.
metadata = {
    "name": "defog",
    "version": "1.0.0",
    "url": None,
    "author": "Yiming Qin, Manuel Madeira et al.",
    "author_email": "",
    "description": "DeFoG: Discrete flow matching for graph generation",
    # Let setuptools discover the packages instead of listing them by hand.
    "packages": find_packages(),
}

setup(**metadata)
18 | seaborn -------------------------------------------------------------------------------- /docker/dependencies/apt-runtime.txt: -------------------------------------------------------------------------------- 1 | build-essential 2 | ca-certificates 3 | pkg-config 4 | tzdata 5 | libsm6 6 | libxext-dev 7 | libxrender1 8 | libcurl3-dev 9 | libfreetype6-dev 10 | libzmq3-dev 11 | libcupti-dev 12 | libjpeg-dev 13 | libpng-dev 14 | zlib1g-dev 15 | locales 16 | rsync 17 | cmake 18 | g++ 19 | swig 20 | vim 21 | nano 22 | curl 23 | wget 24 | unzip 25 | zsh 26 | git 27 | tmux 28 | htop 29 | tree -------------------------------------------------------------------------------- /configs/experiment/qm9_with_h.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'qm9h' 4 | gpus : 1 5 | wandb: 'online' 6 | evaluate_all_checkpoints: False 7 | test_only: null 8 | train: 9 | n_epochs: 2000 10 | batch_size: 512 11 | save_model: True 12 | sample: 13 | sample_steps: 500 14 | time_distortion: 'polydec' 15 | omega: 0.05 16 | eta: 0 17 | model: 18 | n_layers: 7 19 | dataset: 20 | remove_h: False 21 | -------------------------------------------------------------------------------- /configs/sample/sample_default.yaml: -------------------------------------------------------------------------------- 1 | # balanced rate matrix settings 2 | eta: 0. 3 | omega: 0. 
4 | 5 | # generation settings 6 | sample_steps: 1000 7 | time_distortion: "identity" # 'identity', 'cosine', polyinc, polydec 8 | search: False # 'all' | 'target_guidance' | 'distortion' | 'stochasticity' | False 9 | 10 | # fixed 11 | rdb: 'general' # general | column | entry 12 | rdb_crit: dummy # max_marginal | x_t | p_x1_g_xt | x_1 | p_xt_g_x1 | p_xtdt_g_xt | x_0 | xhat_t 13 | # abs_state 14 | -------------------------------------------------------------------------------- /configs/train/train_default.yaml: -------------------------------------------------------------------------------- 1 | # Training settings 2 | n_epochs: 1000 3 | batch_size: 512 4 | lr: 0.0002 5 | clip_grad: null # float, null to disable 6 | save_model: True 7 | num_workers: 0 8 | ema_decay: 0 # 'Amount of EMA decay, 0 means off. A reasonable value is 0.999.' 9 | progress_bar: false 10 | weight_decay: 1e-12 11 | optimizer: adamw # adamw,nadamw,nadam => nadamw for large batches, see http://arxiv.org/abs/2102.06356 for the use of nesterov momentum with large batches 12 | seed: 0 13 | 14 | time_distortion: "identity" # 'identity', 'cosine', polyinc, polydec 15 | -------------------------------------------------------------------------------- /configs/experiment/qm9_no_h.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'qm9_no_h' 4 | gpus : 1 5 | wandb: 'online' 6 | test_only: null 7 | evaluate_all_checkpoints: False 8 | # debug 9 | check_val_every_n_epochs: 50 10 | sample_every_val: 1 11 | train: 12 | n_epochs: 1000 13 | batch_size: 1024 14 | save_model: True 15 | sample: 16 | sample_steps: 500 17 | time_distortion: 'polydec' 18 | omega: 0 19 | eta: 0 20 | model: 21 | n_layers: 9 22 | transition: marginal 23 | dataset: 24 | remove_h: True 25 | pin_memory: True 26 | num_workers: 16 27 | -------------------------------------------------------------------------------- /configs/experiment/debug.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'debug' 4 | gpus : 0 5 | wandb: 'disabled' 6 | sample_every_val: 1 7 | samples_to_generate: 4 8 | samples_to_save: 2 9 | chains_to_save: 1 10 | remove_h: True 11 | number_chain_steps: 10 # Number of frames in each gif 12 | check_val_every_n_epochs: 1 13 | train: 14 | batch_size: 4 15 | save_model: False 16 | n_epochs: 2 17 | model: 18 | n_layers: 2 19 | hidden_mlp_dims: {'X': 17, 'E': 18, 'y': 19 } 20 | hidden_dims: {'dx': 20, 'de': 21, 'dy': 22, 'n_head': 5, 'dim_ffX': 23, 'dim_ffE': 24, 'dim_ffy': 25} 21 | sample: 22 | sample_steps: 10 23 | -------------------------------------------------------------------------------- /configs/model/discrete.yaml: -------------------------------------------------------------------------------- 1 | # Model settings 2 | transition: 'marginal' # uniform, marginal, argmax, absorbfirst, absorbing 3 | model: 'graph_tf' 4 | n_layers: 5 5 | 6 | extra_features: 'rrwp' # 'all', 'cycles', 'eigenvalues', 'rrwp', 'rrwp_comp' or null 7 | rrwp_steps: 12 8 | 9 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly 10 | hidden_mlp_dims: {'X': 256, 'E': 128, 'y': 128} 11 | 12 | # The dimensions should satisfy dx % n_head == 0 13 | hidden_dims : {'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128} 14 | 15 | # training weight for edges, y, and nodes 16 | lambda_train: [5, 0] # X=1, E = lambda[0], y = lambda[1] 17 | -------------------------------------------------------------------------------- /configs/experiment/comm20.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'comm20' 4 | gpus : 1 5 | wandb: 'online' 6 | resume: null # If resume, path to ckpt file from outputs directory in main directory 7 | test_only: null 8 | # test_only: 
/home/yqin/coding/graph_dfm/outputs/2024-04-25/18-10-09-max_64Temb/checkpoints/max_64Temb/epoch=287999.ckpt 9 | check_val_every_n_epochs: 1000 10 | sample_every_val: 10 11 | samples_to_generate: 20 12 | samples_to_save: 20 13 | chains_to_save: 1 14 | log_every_steps: 50 15 | number_chain_steps: 50 # Number of frames in each gif 16 | final_model_samples_to_generate: 20 17 | final_model_samples_to_save: 10 18 | final_model_chains_to_save: 10 19 | train: 20 | n_epochs: 1000000 21 | batch_size: 256 22 | save_model: True 23 | model: 24 | n_layers: 8 -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: defog 2 | channels: 3 | - conda-forge 4 | - nvidia 5 | - pytorch 6 | dependencies: 7 | - python=3.9 8 | - rdkit=2023.03.2 9 | - graph-tool=2.45 10 | - pytorch==2.4.0 11 | - pytorch-cuda=12.1 12 | - psi4=1.9.1 13 | - pip=24.2 14 | - pip: 15 | - ipykernel==6.29.5 16 | - hydra-core==1.3.2 17 | - imageio==2.31.1 18 | - matplotlib==3.7.1 19 | - networkx==2.8.7 20 | - numpy==1.23 21 | - omegaconf==2.3.0 22 | - overrides==7.3.1 23 | - pandas==1.4 24 | - pyemd==1.0.0 25 | - PyGSP==0.5.1 26 | - pytorch_lightning==2.0.4 27 | - scipy==1.11.0 28 | - setuptools==68.0.0 29 | - torch_geometric==2.3.1 30 | - torchmetrics==0.11.4 31 | - tqdm==4.65.0 32 | - wandb==0.15.4 33 | - seaborn==0.13.2 34 | - gpustat==0.6.0 35 | - black==24.3.0 36 | - -e . 
37 | -------------------------------------------------------------------------------- /docker/dependencies/environment.yaml: -------------------------------------------------------------------------------- 1 | name: defog 2 | channels: 3 | - conda-forge 4 | - nvidia 5 | - pytorch 6 | dependencies: 7 | - python=3.9 8 | - rdkit=2023.03.2 9 | - graph-tool=2.45 10 | - pytorch==2.4.0 11 | - pytorch-cuda=12.1 12 | - psi4=1.9.1 13 | - pip=24.2 14 | - pip: 15 | - ipykernel==6.29.5 16 | - hydra-core==1.3.2 17 | - imageio==2.31.1 18 | - matplotlib==3.7.1 19 | - networkx==2.8.7 20 | - numpy==1.23 21 | - omegaconf==2.3.0 22 | - overrides==7.3.1 23 | - pandas==1.4 24 | - pyemd==1.0.0 25 | - PyGSP==0.5.1 26 | - pytorch_lightning==2.0.4 27 | - scipy==1.11.0 28 | - setuptools==68.0.0 29 | - torch_geometric==2.3.1 30 | - torchmetrics==0.11.4 31 | - tqdm==4.65.0 32 | - wandb==0.15.4 33 | - seaborn==0.13.2 34 | - gpustat==0.6.0 35 | - black==24.3.0 36 | prefix: /opt/conda/envs/defog -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2025 Manuel Madeira, Yiming Qin 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 
13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /configs/experiment/planar.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'planar' 4 | gpus : 1 5 | wandb: 'online' 6 | resume: null # If resume, path to ckpt file from outputs directory in main directory 7 | test_only: null 8 | check_val_every_n_epochs: 2000 9 | sample_every_val: 1 10 | samples_to_generate: 40 11 | samples_to_save: 9 12 | chains_to_save: 1 13 | final_model_samples_to_generate: 40 14 | final_model_samples_to_save: 30 15 | final_model_chains_to_save: 20 16 | sample_steps: 1000 17 | train: 18 | n_epochs: 100000 19 | batch_size: 64 20 | save_model: True 21 | sample: 22 | time_distortion: 'polydec' 23 | omega: 0.05 24 | eta: 50 25 | model: 26 | n_layers: 10 27 | 28 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly 29 | # At the moment (03/08), y contains quite little information 30 | hidden_mlp_dims: { 'X': 128, 'E': 64, 'y': 128 } 31 | 32 | # The dimensions should satisfy dx % n_head == 0 33 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 64, 'dim_ffy': 256 } 34 | -------------------------------------------------------------------------------- /configs/experiment/tree.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'tree' 
4 | gpus : 1 5 | wandb: 'online' 6 | resume: null # If resume, path to ckpt file from outputs directory in main directory 7 | test_only: null 8 | check_val_every_n_epochs: 2000 9 | sample_every_val: 1 10 | samples_to_generate: 40 11 | samples_to_save: 9 12 | chains_to_save: 1 13 | final_model_samples_to_generate: 40 14 | final_model_samples_to_save: 30 15 | final_model_chains_to_save: 20 16 | sample_steps: 1000 17 | train: 18 | n_epochs: 100000 19 | batch_size: 64 20 | save_model: True 21 | time_distortion: 'polydec' 22 | sample: 23 | time_distortion: 'polydec' 24 | omega: 0 25 | eta: 0 26 | model: 27 | n_layers: 10 28 | 29 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly 30 | # At the moment (03/08), y contains quite little information 31 | hidden_mlp_dims: { 'X': 128, 'E': 64, 'y': 128 } 32 | 33 | # The dimensions should satisfy dx % n_head == 0 34 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 64, 'dim_ffy': 256 } 35 | -------------------------------------------------------------------------------- /configs/experiment/sbm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'sbm' 4 | gpus : 1 5 | wandb: 'online' 6 | resume: null # If resume, path to ckpt file from outputs directory in main directory 7 | test_only: null 8 | check_val_every_n_epochs: 2000 9 | sample_every_val: 1 10 | samples_to_generate: 40 11 | samples_to_save: 9 12 | chains_to_save: 1 13 | final_model_samples_to_generate: 40 14 | final_model_samples_to_save: 30 15 | final_model_chains_to_save: 20 16 | sample_steps: 1000 17 | train: 18 | n_epochs: 50000 19 | batch_size: 32 20 | save_model: True 21 | sample: 22 | time_distortion: 'identity' 23 | omega: 0 24 | eta: 0 25 | model: 26 | transition: 'absorbfirst' 27 | n_layers: 8 28 | rrwp_steps: 20 29 | 30 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the 
edges is costly 31 | # At the moment (03/08), y contains quite little information 32 | hidden_mlp_dims: { 'X': 128, 'E': 64, 'y': 128 } 33 | 34 | # The dimensions should satisfy dx % n_head == 0 35 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 64, 'dim_ffy': 256 } -------------------------------------------------------------------------------- /configs/general/general_default.yaml: -------------------------------------------------------------------------------- 1 | # General settings 2 | name: 'graph-tf-model' # Warning: 'debug' and 'test' are reserved names that have special behavior 3 | 4 | wandb: 'online' # online | offline | disabled 5 | gpus: 1 # Multi-gpu is not implemented on this branch 6 | 7 | resume: null # If resume, path to ckpt file from outputs directory in main directory 8 | test_only: null # Use absolute path 9 | 10 | check_val_every_n_epochs: 5 11 | sample_every_val: 4 12 | val_check_interval: null 13 | samples_to_generate: 512 # We advise to set it to 2 x batch_size maximum 14 | samples_to_save: 20 15 | chains_to_save: 1 16 | log_every_steps: 50 17 | number_chain_steps: 50 # Number of frames in each gif 18 | 19 | # Test 20 | generated_path: null 21 | final_model_samples_to_generate: 10000 22 | final_model_samples_to_save: 30 23 | final_model_chains_to_save: 20 24 | num_sample_fold: 1 25 | evaluate_all_checkpoints: False 26 | save_samples: True # Save samples at the final test step or not, normally only used at the last epoch or during inference 27 | 28 | # Conditional Generation 29 | conditional: False 30 | target: 'k2' 31 | guidance_weight: 2.0 32 | -------------------------------------------------------------------------------- /configs/experiment/tls.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'tls' 4 | gpus : 1 5 | wandb: 'online' 6 | resume: null # If resume, path to ckpt file from outputs directory in main 
directory 7 | test_only: null 8 | check_val_every_n_epochs: 2000 9 | sample_every_val: 1 10 | samples_to_generate: 40 11 | samples_to_save: 9 12 | chains_to_save: 1 13 | final_model_samples_to_generate: 80 # same as test set 14 | final_model_samples_to_save: 30 15 | final_model_chains_to_save: 20 16 | sample_steps: 1000 17 | conditional: True 18 | target: 'k2' 19 | guidance_weight: 2.0 20 | train: 21 | n_epochs: 100000 22 | batch_size: 64 23 | save_model: True 24 | sample: 25 | time_distortion: 'polydec' 26 | omega: 0.05 27 | eta: 0 28 | model: 29 | n_layers: 10 30 | rrwp_steps: 20 31 | 32 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly 33 | # At the moment (03/08), y contains quite little information 34 | hidden_mlp_dims: { 'X': 128, 'E': 64, 'y': 128 } 35 | 36 | # The dimensions should satisfy dx % n_head == 0 37 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 64, 'dim_ffy': 256 } 38 | -------------------------------------------------------------------------------- /configs/experiment/guacamol.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'guacamol' 4 | gpus : 1 5 | wandb: 'online' 6 | resume: null 7 | test_only: null 8 | check_val_every_n_epochs: 2 9 | val_check_interval: null 10 | sample_every_val: 2 11 | samples_to_generate: 500 12 | samples_to_save: 20 13 | chains_to_save: 5 14 | log_every_steps: 50 15 | final_model_samples_to_generate: 18000 16 | final_model_samples_to_save: 10 17 | final_model_chains_to_save: 5 18 | train: 19 | optimizer: adam 20 | n_epochs: 1000 21 | batch_size: 64 22 | save_model: True 23 | lr: 2e-4 24 | time_distortion: 'polydec' 25 | sample: 26 | time_distortion: 'polydec' 27 | omega: 0.1 28 | eta: 300 29 | model: 30 | n_layers: 12 31 | type: 'discrete' 32 | transition: 'marginal' # uniform or marginal 33 | model: 'graph_tf' 34 | 35 | # Do not set hidden_mlp_E, 
dim_ffE too high, computing large tensors on the edges is costly 36 | # At the moment (03/08), y contains quite little information 37 | hidden_mlp_dims: {'X': 256, 'E': 128, 'y': 256} 38 | rrwp_steps: 20 39 | 40 | # The dimensions should satisfy dx % n_head == 0 41 | hidden_dims: {'dx': 256, 'de': 64, 'dy': 128, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 256} 42 | -------------------------------------------------------------------------------- /configs/experiment/zinc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'zinc' 4 | gpus : 1 5 | wandb: 'online' 6 | remove_h: True 7 | resume: null 8 | test_only: null 9 | check_val_every_n_epochs: 4 10 | val_check_interval: null 11 | sample_every_val: 2 12 | samples_to_generate: 256 13 | samples_to_save: 20 14 | chains_to_save: 5 15 | log_every_steps: 50 16 | 17 | final_model_samples_to_generate: 10000 18 | final_model_samples_to_save: 50 19 | final_model_chains_to_save: 20 20 | train: 21 | optimizer: adamw 22 | n_epochs: 300 23 | batch_size: 256 24 | save_model: True 25 | lr: 2e-4 26 | num_workers: 4 27 | time_distortion: 'polydec' 28 | sample: 29 | time_distortion: 'polydec' 30 | omega: 0. 31 | eta: 0. 
32 | model: 33 | n_layers: 12 34 | type: 'discrete' 35 | transition: 'marginal' # uniform or marginal 36 | model: 'graph_tf' 37 | 38 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly 39 | # At the moment (03/08), y contains quite little information 40 | hidden_mlp_dims: { 'X': 256, 'E': 128, 'y': 256} 41 | rrwp_steps: 20 42 | 43 | # The dimensions should satisfy dx % n_head == 0 44 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 128, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 256} 45 | -------------------------------------------------------------------------------- /configs/experiment/moses.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | general: 3 | name : 'moses' 4 | gpus : 1 5 | wandb: 'online' 6 | remove_h: True 7 | resume: null 8 | test_only: null 9 | check_val_every_n_epochs: 1 10 | val_check_interval: null 11 | sample_every_val: 4 12 | samples_to_generate: 256 13 | samples_to_save: 20 14 | chains_to_save: 5 15 | log_every_steps: 50 16 | 17 | final_model_samples_to_generate: 25000 18 | final_model_samples_to_save: 50 19 | final_model_chains_to_save: 20 20 | 21 | train: 22 | optimizer: adamw 23 | n_epochs: 300 24 | batch_size: 256 25 | save_model: True 26 | lr: 2e-4 27 | num_workers: 4 28 | time_distortion: 'polydec' 29 | sample: 30 | time_distortion: 'polydec' 31 | omega: 0.5 32 | eta: 200 33 | model: 34 | n_layers: 12 35 | type: 'discrete' 36 | transition: 'marginal' # uniform or marginal 37 | model: 'graph_tf' 38 | 39 | # Do not set hidden_mlp_E, dim_ffE too high, computing large tensors on the edges is costly 40 | # At the moment (03/08), y contains quite little information 41 | hidden_mlp_dims: { 'X': 256, 'E': 128, 'y': 256} 42 | rrwp_steps: 20 43 | 44 | # The dimensions should satisfy dx % n_head == 0 45 | hidden_dims: { 'dx': 256, 'de': 64, 'dy': 128, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 256} 46 | 
import torch
import torch.nn as nn


class Xtoy(nn.Module):
    """Map node features to global features.

    Pools the node axis with mean / min / max / std and projects the
    concatenated statistics with a single linear layer.
    """

    def __init__(self, dx, dy):
        """dx: node feature size, dy: size of the produced global feature."""
        super().__init__()
        # Four pooled statistics are concatenated, hence 4 * dx inputs.
        self.lin = nn.Linear(4 * dx, dy)

    def forward(self, X):
        """X: bs, n, dx. Returns a tensor of shape bs, dy."""
        m = X.mean(dim=1)
        mi = X.min(dim=1)[0]
        ma = X.max(dim=1)[0]
        # torch.std applies Bessel's correction by default, so reducing a
        # single node would give NaN and poison the linear layer. A single
        # node has no spread: use zeros in that case.
        std = X.std(dim=1) if X.size(1) > 1 else torch.zeros_like(m)
        z = torch.hstack((m, mi, ma, std))
        out = self.lin(z)
        return out


class Etoy(nn.Module):
    """Map edge features to global features.

    Pools both node axes of the edge tensor with mean / min / max / std and
    projects the concatenated statistics with a single linear layer.
    """

    def __init__(self, d, dy):
        """d: edge feature size, dy: size of the produced global feature."""
        super().__init__()
        # Four pooled statistics are concatenated, hence 4 * d inputs.
        self.lin = nn.Linear(4 * d, dy)

    def forward(self, E):
        """E: bs, n, n, de. Returns a tensor of shape bs, dy.

        Features relative to the diagonal of E could potentially be added.
        """
        m = E.mean(dim=(1, 2))
        mi = E.min(dim=2)[0].min(dim=1)[0]
        ma = E.max(dim=2)[0].max(dim=1)[0]
        # Same guard as in Xtoy: with a single entry to reduce (n == 1) the
        # Bessel-corrected std is NaN, so fall back to zeros.
        if E.size(1) * E.size(2) > 1:
            std = torch.std(E, dim=(1, 2))
        else:
            std = torch.zeros_like(m)
        z = torch.hstack((m, mi, ma, std))
        out = self.lin(z)
        return out


def masked_softmax(x, mask, **kwargs):
    """Softmax over ``x`` that ignores the positions where ``mask == 0``.

    ``kwargs`` are forwarded to ``torch.softmax`` (e.g. ``dim``). When the
    mask selects nothing at all, ``x`` is returned unchanged. Note that a
    softmax slice whose entries are all masked evaluates to NaN.
    """
    if mask.sum() == 0:
        return x
    # Work on a copy so the caller's tensor is not modified in place.
    x_masked = x.clone()
    x_masked[mask == 0] = -float("inf")
    return torch.softmax(x_masked, **kwargs)
import torch.nn.functional as F


def p_xt_g_x1(X1, E1, t, limit_dist):
    """Return p(x_t | x_1) for nodes and edges as a pair ``(Xt, Et)``.

    Each distribution is the mix ``t * onehot(x_1) + (1 - t) * limit`` over
    the class axis. ``X1`` / ``E1`` hold integer class indices, ``t`` is a
    tensor indexed per batch element (its trailing singleton axis is
    squeezed), and ``limit_dist`` exposes the limit distributions as ``.X``
    and ``.E``. Both attributes are moved to ``X1``'s device in place.
    """
    target_device = X1.device
    limit_dist.X = limit_dist.X.to(target_device)
    limit_dist.E = limit_dist.E.to(target_device)

    node_onehot = F.one_hot(X1, num_classes=len(limit_dist.X)).float()
    edge_onehot = F.one_hot(E1, num_classes=len(limit_dist.E)).float()

    # Give t two trailing axes for the node term, and one more for the edge term.
    t_node = t.squeeze(-1)[:, None, None]
    t_edge = t_node[:, None]

    Xt = t_node * node_onehot + (1 - t_node) * limit_dist.X[None, None]
    Et = t_edge * edge_onehot + (1 - t_edge) * limit_dist.E[None, None, None]

    # Sanity check: every mixed distribution must still sum to one.
    assert ((Xt.sum(-1) - 1).abs() < 1e-4).all() and (
        (Et.sum(-1) - 1).abs() < 1e-4
    ).all()

    Xt = Xt.clamp(min=0.0, max=1.0)
    Et = Et.clamp(min=0.0, max=1.0)
    return Xt, Et


def dt_p_xt_g_x1(X1, E1, limit_dist):
    """Return the time derivative of p(x_t | x_1) as a pair ``(dX, dE)``.

    The derivative of the linear mix is ``onehot(x_1) - limit`` and does not
    depend on ``t``. ``limit_dist.X`` / ``limit_dist.E`` are moved to
    ``X1``'s device in place.
    """
    target_device = X1.device
    limit_dist.X = limit_dist.X.to(target_device)
    limit_dist.E = limit_dist.E.to(target_device)

    dX = F.one_hot(X1, num_classes=len(limit_dist.X)).float()
    dX = dX - limit_dist.X[None, None]
    dE = F.one_hot(E1, num_classes=len(limit_dist.E)).float()
    dE = dE - limit_dist.E[None, None, None]

    # Sanity check: a difference of two distributions sums to zero.
    assert (dX.sum(-1).abs() < 1e-4).all() and (dE.sum(-1).abs() < 1e-4).all()

    return dX, dE
locale-gen en_US.UTF-8 24 | ENV LANG=en_US.UTF-8 25 | ENV LANGUAGE=en_US:en 26 | ENV LC_ALL=en_US.UTF-8 27 | 28 | # Install Miniforge (Conda) 29 | RUN mkdir -p /tmp/conda && \ 30 | curl -fvL -o /tmp/conda/miniconda.sh ${CONDA_URL} && \ 31 | bash /tmp/conda/miniconda.sh -b -p ${CONDA_INSTALL_PATH} -u && \ 32 | rm -rf /tmp/conda 33 | # make mamba visible (using mamba for efficiency) 34 | ENV PATH=${CONDA_INSTALL_PATH}/condabin:${CONDA_INSTALL_PATH}/bin:${PATH} 35 | 36 | # Install Mamba for faster Conda operations 37 | RUN conda install mamba -n base -c conda-forge 38 | 39 | # Create Conda environment 40 | RUN mamba env create --file ${DEPENDENCIES_DIR}/environment.yaml 41 | 42 | # Activate Conda environment 43 | ENV PATH=${CONDA_INSTALL_PATH}/envs/defog/bin:${PATH} 44 | 45 | # Install additional pip packages without dependencies if needed 46 | RUN chmod +x ${DEPENDENCIES_DIR}/pip_no_deps.sh && \ 47 | ${DEPENDENCIES_DIR}/pip_no_deps.sh 48 | 49 | # Clean up 50 | RUN mamba clean -a -y && \ 51 | find ${CONDA_INSTALL_PATH}/envs/defog -name '__pycache__' -type d -exec rm -rf {} + 52 | 53 | # Initialize Conda for shells 54 | RUN mamba init --system bash && \ 55 | echo "mamba activate defog" >> /etc/profile.d/conda.sh 56 | 57 | # Delete temporary files 58 | RUN rm -rf ${DEPENDENCIES_DIR} 59 | 60 | # Set working directory 61 | WORKDIR /workspace 62 | 63 | # Entrypoint script 64 | COPY ./docker/entrypoint.sh /entrypoint.sh 65 | RUN chmod +x /entrypoint.sh 66 | ENTRYPOINT ["/entrypoint.sh"] 67 | 68 | # Default command 69 | CMD ["bash"] 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Docker configurarion 2 | checkpoints/ 3 | defog.egg-info 4 | final_checkpoints/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | 
build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
def assert_correctly_masked(variable, node_mask):
    """Assert that ``variable`` is numerically zero wherever ``node_mask`` is off.

    :param variable: tensor to check.
    :param node_mask: mask broadcastable against ``variable``; it is cast
        to integers, and positions where it is 0 must hold (near-)zero values.
    :raises AssertionError: if any masked-out entry has magnitude >= 1e-4.
    """
    outside = 1 - node_mask.long()  # 1 exactly where the mask is off
    largest_leak = (variable * outside).abs().max().item()
    assert largest_leak < 1e-4, "Variables not masked properly."
def sample_discrete_features(probX, probE, node_mask, mask=False):
    """Sample node and edge classes from per-position categorical distributions.

    :param probX: (bs, n, dx_out) node class probabilities.
    :param probE: (bs, n, n, de_out) edge class probabilities.
    :param node_mask: (bs, n) boolean mask; False marks padded nodes.
    :param mask: if True, the sampled classes of padded nodes, and of edges
        touching a padded node, are zeroed in the output.
    :return: PlaceHolder with ``X`` (bs, n) and ``E`` (bs, n, n) holding
        integer class indices (``E`` is symmetric with a zero diagonal) and
        an empty ``y`` of shape (bs, 0).
    """
    bs, n, _ = probX.shape

    # Work on copies: the rows overwritten below must not leak back into the
    # caller's probability tensors.
    probX = probX.clone()
    probE = probE.clone()

    # Padded rows must still be valid distributions for multinomial, so they
    # are filled with the uniform distribution.
    probX[~node_mask] = 1 / probX.shape[-1]

    # Flatten to (bs * n, dx_out), draw one class per row, restore (bs, n).
    X_t = probX.reshape(bs * n, -1).multinomial(1, replacement=True).reshape(bs, n)

    # Edges touching a padded node and the diagonal (self-loops) also get the
    # uniform distribution; the diagonal samples are discarded by triu below.
    inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))
    # Built directly as a boolean mask on probE's device.
    diag_mask = (
        torch.eye(n, dtype=torch.bool, device=probE.device)
        .unsqueeze(0)
        .expand(bs, -1, -1)
    )
    probE[inverse_edge_mask] = 1 / probE.shape[-1]
    probE[diag_mask] = 1 / probE.shape[-1]

    # Flatten to (bs * n * n, de_out), draw one class per edge slot.
    E_t = (
        probE.reshape(bs * n * n, -1)
        .multinomial(1, replacement=True)
        .reshape(bs, n, n)
    )
    # Keep the strict upper triangle and mirror it so the graph is undirected.
    E_t = torch.triu(E_t, diagonal=1)
    E_t = E_t + torch.transpose(E_t, 1, 2)

    if mask:
        X_t = X_t * node_mask
        E_t = E_t * node_mask.unsqueeze(1) * node_mask.unsqueeze(2)

    return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t))
class ChargeFeature:
    """Per-node feature: a node-type-derived value minus a bond-derived value."""

    def __init__(self, remove_h, valencies):
        # remove_h is stored but not read by __call__.
        self.remove_h = remove_h
        # One entry per node class (last axis of X_t).
        self.valencies = valencies

    def __call__(self, noisy_data):
        """Compute the feature from ``noisy_data["X_t"]`` and ``noisy_data["E_t"]``.

        :param noisy_data: mapping with ``"X_t"`` (bs, n, dx) and ``"E_t"``
            (bs, n, n, de).
        :return: (bs, n) tensor with the dtype/device of ``X_t``.
        """
        E_t = noisy_data["E_t"]
        X_t = noisy_data["X_t"]

        # Five edge classes get an extra 1.5 entry (presumably the aromatic
        # bond); any other size falls back to the four-class table.
        order_values = [0, 1, 2, 3, 1.5] if E_t.shape[-1] == 5 else [0, 1, 2, 3]
        bond_orders = torch.tensor(order_values, device=E_t.device).reshape(
            1, 1, 1, -1
        )

        # NOTE(review): argmax returns the *index* of the largest weighted
        # entry, not its value, so the 1.5 entry contributes 4 here -- confirm
        # this is intended.
        current_valencies = (E_t * bond_orders).argmax(dim=-1).sum(dim=-1)  # (bs, n)

        valency_table = torch.tensor(self.valencies, device=X_t.device).reshape(
            1, 1, -1
        )
        # NOTE(review): likewise this is the index of the largest weighted
        # node entry, not the valency value itself.
        normal_valencies = (X_t * valency_table).argmax(dim=-1)  # (bs, n)

        return (normal_valencies - current_valencies).type_as(X_t)
class WeightFeature:
    """Graph-level feature: summed per-node weights divided by ``max_weight``."""

    def __init__(self, max_weight, atom_weights):
        """
        :param max_weight: normalisation constant the summed weight is divided by.
        :param atom_weights: mapping whose values, in iteration order, are the
            weights of node classes 0, 1, 2, ...
        """
        self.max_weight = max_weight
        self.atom_weight_list = torch.tensor(list(atom_weights.values()))

    def __call__(self, noisy_data):
        """Return a (bs, 1) tensor with the dtype/device of ``noisy_data["X_t"]``."""
        X_t = noisy_data["X_t"]
        # Most likely class per node. An all-zero row yields index 0, so such
        # rows are counted with the weight of class 0.
        atom_idx = X_t.argmax(dim=-1)  # (bs, n)
        weight_table = self.atom_weight_list.to(atom_idx.device)
        per_atom = weight_table[atom_idx]  # (bs, n)
        total = per_atom.sum(dim=-1, keepdim=True).type_as(X_t)  # (bs, 1)
        return total / self.max_weight
class DistributionNodes:
    """Empirical distribution over the number of nodes per graph."""

    def __init__(self, histogram):
        """Build the distribution from node-count statistics.

        :param histogram: either a mapping ``{num_nodes: count}`` (any ``dict``
            subclass such as ``collections.Counter`` is accepted) or a 1-D
            tensor whose index ``i`` holds the count/weight for ``i`` nodes.
        """
        # isinstance (not an exact type comparison) so dict subclasses take
        # the mapping path instead of being treated as a tensor.
        if isinstance(histogram, dict):
            max_n_nodes = max(histogram.keys())
            prob = torch.zeros(max_n_nodes + 1)
            for num_nodes, count in histogram.items():
                prob[num_nodes] = count
        else:
            prob = histogram

        # Normalised probabilities, used by log_prob.
        self.prob = prob / prob.sum()
        # Categorical normalises its ``probs`` argument itself, so the raw
        # counts are passed unchanged.
        self.m = torch.distributions.Categorical(prob)

    def sample_n(self, n_samples, device):
        """Draw ``n_samples`` node counts and return them on ``device``."""
        idx = self.m.sample((n_samples,))
        return idx.to(device)

    def log_prob(self, batch_n_nodes):
        """Return the log-probability of each node count in a 1-D integer tensor."""
        assert len(batch_n_nodes.size()) == 1
        p = self.prob.to(batch_n_nodes.device)

        probas = p[batch_n_nodes]
        # The small constant keeps log finite for counts with zero probability.
        log_p = torch.log(probas + 1e-30)
        return log_p
48 | ): 49 | """Compute train metrics 50 | masked_pred_X : tensor -- (bs, n, dx) 51 | masked_pred_E : tensor -- (bs, n, n, de) 52 | pred_y : tensor -- (bs, ) 53 | true_X : tensor -- (bs, n, dx) 54 | true_E : tensor -- (bs, n, n, de) 55 | true_y : tensor -- (bs, ) 56 | log : boolean.""" 57 | true_X = torch.reshape(true_X, (-1, true_X.size(-1))) # (bs * n, dx) 58 | true_E = torch.reshape(true_E, (-1, true_E.size(-1))) # (bs * n * n, de) 59 | masked_pred_X = torch.reshape( 60 | masked_pred_X, (-1, masked_pred_X.size(-1)) 61 | ) # (bs * n, dx) 62 | masked_pred_E = torch.reshape( 63 | masked_pred_E, (-1, masked_pred_E.size(-1)) 64 | ) # (bs * n * n, de) 65 | 66 | # Remove masked rows 67 | mask_X = (true_X != 0.0).any(dim=-1) 68 | mask_E = (true_E != 0.0).any(dim=-1) 69 | 70 | flat_true_X = true_X[mask_X, :] 71 | flat_pred_X = masked_pred_X[mask_X, :] 72 | 73 | flat_true_E = true_E[mask_E, :] 74 | flat_pred_E = masked_pred_E[mask_E, :] 75 | 76 | loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0 77 | loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0 78 | loss_y = self.y_loss(pred_y, true_y) if pred_y.numel() > 0 else 0.0 79 | 80 | if log: 81 | to_log = { 82 | "train_loss/batch_CE": (loss_X + loss_E + loss_y).detach(), 83 | "train_loss/X_CE": ( 84 | self.node_loss.compute() if true_X.numel() > 0 else -1 85 | ), 86 | "train_loss/E_CE": ( 87 | self.edge_loss.compute() if true_E.numel() > 0 else -1 88 | ), 89 | "train_loss/y_CE": self.y_loss.compute() if true_y.numel() > 0 else -1, 90 | } 91 | if wandb.run: 92 | wandb.log(to_log, commit=True) 93 | return loss_X + self.lambda_train[0] * loss_E + self.lambda_train[1] * loss_y 94 | 95 | def reset(self): 96 | for metric in [self.node_loss, self.edge_loss, self.y_loss]: 97 | metric.reset() 98 | 99 | def log_epoch_metrics(self): 100 | epoch_node_loss = ( 101 | self.node_loss.compute() if self.node_loss.total_samples > 0 else -1 102 | ) 103 | epoch_edge_loss = ( 104 | 
self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1 105 | ) 106 | epoch_y_loss = ( 107 | self.y_loss.compute() if self.y_loss.total_samples > 0 else -1 108 | ) 109 | 110 | to_log = { 111 | "train_epoch/x_CE": epoch_node_loss, 112 | "train_epoch/E_CE": epoch_edge_loss, 113 | "train_epoch/y_CE": epoch_y_loss, 114 | } 115 | if wandb.run: 116 | wandb.log(to_log, commit=False) 117 | 118 | return to_log 119 | -------------------------------------------------------------------------------- /src/flow_matching/noise_distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src import utils 4 | 5 | 6 | class NoiseDistribution: 7 | 8 | def __init__(self, model_transition, dataset_infos): 9 | 10 | self.x_num_classes = dataset_infos.output_dims["X"] 11 | self.e_num_classes = dataset_infos.output_dims["E"] 12 | self.y_num_classes = dataset_infos.output_dims["y"] 13 | self.x_added_classes = 0 14 | self.e_added_classes = 0 15 | self.y_added_classes = 0 16 | self.transition = model_transition 17 | 18 | if model_transition == "uniform": 19 | x_limit = torch.ones(self.x_num_classes) / self.x_num_classes 20 | e_limit = torch.ones(self.e_num_classes) / self.e_num_classes 21 | 22 | elif model_transition == "absorbfirst": 23 | x_limit = torch.zeros(self.x_num_classes) 24 | x_limit[0] = 1 25 | e_limit = torch.zeros(self.e_num_classes) 26 | e_limit[0] = 1 27 | 28 | elif model_transition == "argmax": 29 | node_types = dataset_infos.node_types.float() 30 | x_marginals = node_types / torch.sum(node_types) 31 | 32 | edge_types = dataset_infos.edge_types.float() 33 | e_marginals = edge_types / torch.sum(edge_types) 34 | 35 | x_max_dim = torch.argmax(x_marginals) 36 | e_max_dim = torch.argmax(e_marginals) 37 | x_limit = torch.zeros(self.x_num_classes) 38 | x_limit[x_max_dim] = 1 39 | e_limit = torch.zeros(self.e_num_classes) 40 | e_limit[e_max_dim] = 1 41 | 42 | elif model_transition == "absorbing": 43 | # 
only add virtual classes when there are several 44 | if self.x_num_classes > 1: 45 | # if self.x_num_classes >= 1: 46 | self.x_num_classes += 1 47 | self.x_added_classes = 1 48 | if self.e_num_classes > 1: 49 | self.e_num_classes += 1 50 | self.e_added_classes = 1 51 | 52 | x_limit = torch.zeros(self.x_num_classes) 53 | x_limit[-1] = 1 54 | e_limit = torch.zeros(self.e_num_classes) 55 | e_limit[-1] = 1 56 | 57 | elif model_transition == "marginal": 58 | 59 | node_types = dataset_infos.node_types.float() 60 | x_limit = node_types / torch.sum(node_types) 61 | 62 | edge_types = dataset_infos.edge_types.float() 63 | e_limit = edge_types / torch.sum(edge_types) 64 | 65 | elif model_transition == "edge_marginal": 66 | x_limit = torch.ones(self.x_num_classes) / self.x_num_classes 67 | 68 | edge_types = dataset_infos.edge_types.float() 69 | e_limit = edge_types / torch.sum(edge_types) 70 | 71 | elif model_transition == "node_marginal": 72 | e_limit = torch.ones(self.e_num_classes) / self.e_num_classes 73 | 74 | node_types = dataset_infos.node_types.float() 75 | x_limit = node_types / torch.sum(node_types) 76 | 77 | else: 78 | raise ValueError(f"Unknown transition model: {model_transition}") 79 | 80 | y_limit = torch.ones(self.y_num_classes) / self.y_num_classes # typically dummy 81 | print( 82 | f"Limit distribution of the classes | Nodes: {x_limit} | Edges: {e_limit}" 83 | ) 84 | self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit) 85 | 86 | def update_input_output_dims(self, input_dims): 87 | input_dims["X"] += self.x_added_classes 88 | input_dims["E"] += self.e_added_classes 89 | input_dims["y"] += self.y_added_classes 90 | 91 | def update_dataset_infos(self, dataset_infos): 92 | if hasattr(dataset_infos, "atom_decoder"): 93 | dataset_infos.atom_decoder = ( 94 | dataset_infos.atom_decoder + ["Y"] * self.x_added_classes 95 | ) 96 | 97 | def get_limit_dist(self): 98 | return self.limit_dist 99 | 100 | def get_noise_dims(self): 101 | return { 102 | "X": 
def ignore_virtual_classes(self, X, E, y=None):
    """Drop the trailing virtual classes added for the "absorbing" transition.

    For any other transition the inputs are returned untouched.

    :param X: node tensor whose last axis indexes classes.
    :param E: edge tensor whose last axis indexes classes.
    :param y: optional global tensor whose last axis indexes classes.
    :return: ``(X, E, y)`` with the last ``*_added_classes`` entries of the
        class axis removed; ``y`` stays ``None`` when it was not given.
    """
    if self.transition != "absorbing":
        return X, E, y

    def _strip(tensor, n_added):
        # ``tensor[..., :-n_added]`` is wrong for n_added == 0: ``:-0`` means
        # ``:0`` and would return an EMPTY tensor. The end index is computed
        # explicitly so that stripping zero classes is a no-op.
        return tensor[..., : tensor.shape[-1] - n_added]

    new_X = _strip(X, self.x_added_classes)
    new_E = _strip(E, self.e_added_classes)
    new_y = _strip(y, self.y_added_classes) if y is not None else None
    return new_X, new_E, new_y
24 | 25 | class TimeDistorter: 26 | 27 | def __init__( 28 | self, 29 | train_distortion, 30 | sample_distortion, 31 | mu=0, 32 | sigma=1, 33 | alpha=1, 34 | beta=1, 35 | ): 36 | self.train_distortion = train_distortion # used for sample_ft 37 | self.sample_distortion = sample_distortion # used for get_ft 38 | self.alpha = alpha 39 | self.beta = beta 40 | print( 41 | f"TimeDistorter: train_distortion={train_distortion}, sample_distortion={sample_distortion}" 42 | ) 43 | self.f_inv = None 44 | 45 | def train_ft(self, batch_size, device): 46 | t_uniform = torch.rand((batch_size, 1), device=device) 47 | t_distort = self.apply_distortion(t_uniform, self.train_distortion) 48 | 49 | return t_distort 50 | 51 | def sample_ft(self, t, sample_distortion): 52 | t_distort = self.apply_distortion(t, sample_distortion) 53 | return t_distort 54 | 55 | def fit(self, difficulty, t_array, learning_rate=0.01, iterations=1000): 56 | """Fit a beta distribution to data using the method of moments.""" 57 | alpha, beta = self.alpha, self.beta 58 | t_array = t_array + 1e-6 # Avoid division by zero 59 | 60 | for _ in range(iterations): 61 | y_pred = beta_pdf(t_array, alpha, beta) 62 | 63 | # Numerical approximation of the gradients 64 | epsilon = 1e-5 65 | grad_alpha = ( 66 | objective_function(alpha + epsilon, beta, difficulty, t_array) 67 | - objective_function(alpha - epsilon, beta, difficulty, t_array) 68 | ) / (2 * epsilon) 69 | grad_beta = ( 70 | objective_function(alpha, beta + epsilon, difficulty, t_array) 71 | - objective_function(alpha, beta - epsilon, difficulty, t_array) 72 | ) / (2 * epsilon) 73 | 74 | # # Add regularization gradient components 75 | # grad_alpha += learning_rate * (1 - 1 / alpha**2) 76 | # grad_beta += learning_rate * (1 + 1 / beta**2) 77 | 78 | # Update parameters 79 | alpha -= learning_rate * grad_alpha 80 | beta -= learning_rate * grad_beta 81 | 82 | alpha = min(max(0.3, alpha), 3) 83 | beta = min(max(0.3, beta), 3) 84 | 85 | y_pred = beta_pdf(t_array, alpha, 
def apply_distortion(self, t, distortion_type):
    """Map times ``t`` in [0, 1] through the named distortion function.

    :param t: tensor of times; every entry must satisfy ``0 <= t <= 1``.
    :param distortion_type: one of ``"identity"``, ``"cos"``, ``"revcos"``,
        ``"polyinc"``, ``"polydec"``.
    :return: tensor of distorted times (``t`` itself for ``"identity"``).
    :raises AssertionError: if any entry of ``t`` lies outside [0, 1].
    :raises ValueError: for ``"beta"`` / ``"logitnormal"`` (recognised but not
        supported here) and for any unknown name.
    """
    assert torch.all((t >= 0) & (t <= 1)), "t must be in the range (0, 1)"

    if distortion_type == "identity":
        return t
    if distortion_type == "cos":
        return (1 - torch.cos(t * torch.pi)) / 2
    if distortion_type == "revcos":
        return 2 * t - (1 - torch.cos(t * torch.pi)) / 2
    if distortion_type == "polyinc":
        return t**2
    if distortion_type == "polydec":
        return 2 * t - t**2
    if distortion_type in ("beta", "logitnormal"):
        raise ValueError(f"Unsupported for now: {distortion_type}")

    raise ValueError(f"Unknown distortion type: {distortion_type}")
Adapted from https://github.com/lrjconan/GRAN/ which in turn is adapted from https://github.com/JiaxuanYou/graph-generation 4 | # 5 | ############################################################################### 6 | import pyemd 7 | import numpy as np 8 | import concurrent.futures 9 | from functools import partial 10 | from scipy.linalg import toeplitz 11 | 12 | 13 | def emd(x, y, distance_scaling=1.0): 14 | support_size = max(len(x), len(y)) 15 | d_mat = toeplitz(range(support_size)).astype(float) 16 | distance_mat = d_mat / distance_scaling 17 | 18 | # convert histogram values x and y to float, and make them equal len 19 | x = x.astype(float) 20 | y = y.astype(float) 21 | if len(x) < len(y): 22 | x = np.hstack((x, [0.0] * (support_size - len(x)))) 23 | elif len(y) < len(x): 24 | y = np.hstack((y, [0.0] * (support_size - len(y)))) 25 | 26 | emd = pyemd.emd(x, y, distance_mat) 27 | return emd 28 | 29 | 30 | def l2(x, y): 31 | dist = np.linalg.norm(x - y, 2) 32 | return dist 33 | 34 | 35 | def emd(x, y, sigma=1.0, distance_scaling=1.0): 36 | """EMD 37 | Args: 38 | x, y: 1D pmf of two distributions with the same support 39 | sigma: standard deviation 40 | """ 41 | support_size = max(len(x), len(y)) 42 | d_mat = toeplitz(range(support_size)).astype(float) 43 | distance_mat = d_mat / distance_scaling 44 | 45 | # convert histogram values x and y to float, and make them equal len 46 | x = x.astype(float) 47 | y = y.astype(float) 48 | if len(x) < len(y): 49 | x = np.hstack((x, [0.0] * (support_size - len(x)))) 50 | elif len(y) < len(x): 51 | y = np.hstack((y, [0.0] * (support_size - len(y)))) 52 | 53 | return np.abs(pyemd.emd(x, y, distance_mat)) 54 | 55 | 56 | def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0): 57 | """Gaussian kernel with squared distance in exponential term replaced by EMD 58 | Args: 59 | x, y: 1D pmf of two distributions with the same support 60 | sigma: standard deviation 61 | """ 62 | support_size = max(len(x), len(y)) 63 | d_mat = 
def gaussian_tv(x, y, sigma=1.0):
    """Gaussian kernel on the total-variation distance between two histograms.

    Args:
        x, y: 1D arrays (pmfs/histograms); the shorter one is zero-padded at
            the end so both cover the same support.
        sigma: standard deviation of the Gaussian kernel.

    Returns:
        ``exp(-tv**2 / (2 * sigma**2))`` where ``tv = sum(|x - y|) / 2``.
    """
    x = x.astype(float)
    y = y.astype(float)

    # Zero-pad whichever histogram is shorter up to the common support size.
    support_size = max(len(x), len(y))
    missing_x = support_size - len(x)
    missing_y = support_size - len(y)
    if missing_x > 0:
        x = np.concatenate((x, np.zeros(missing_x)))
    if missing_y > 0:
        y = np.concatenate((y, np.zeros(missing_y)))

    tv = np.abs(x - y).sum() / 2.0
    return np.exp(-tv * tv / (2 * sigma * sigma))
concurrent.futures.ThreadPoolExecutor() as executor: 127 | for dist in executor.map( 128 | kernel_parallel_worker, 129 | [(s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1], 130 | ): 131 | d += dist 132 | if len(samples1) * len(samples2) > 0: 133 | d /= len(samples1) * len(samples2) 134 | else: 135 | d = 1e6 136 | return d 137 | 138 | 139 | def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs): 140 | """MMD between two samples""" 141 | # normalize histograms into pmf 142 | if is_hist: 143 | samples1 = [s1 / (np.sum(s1) + 1e-6) for s1 in samples1] 144 | samples2 = [s2 / (np.sum(s2) + 1e-6) for s2 in samples2] 145 | mmd = ( 146 | disc(samples1, samples1, kernel, *args, **kwargs) 147 | + disc(samples2, samples2, kernel, *args, **kwargs) 148 | - 2 * disc(samples1, samples2, kernel, *args, **kwargs) 149 | ) 150 | 151 | mmd = np.abs(mmd) 152 | 153 | if mmd < 0: 154 | import pdb 155 | 156 | pdb.set_trace() 157 | 158 | return mmd 159 | 160 | 161 | def compute_emd(samples1, samples2, kernel, is_hist=True, *args, **kwargs): 162 | """EMD between average of two samples""" 163 | # normalize histograms into pmf 164 | if is_hist: 165 | samples1 = [np.mean(samples1)] 166 | samples2 = [np.mean(samples2)] 167 | return disc(samples1, samples2, kernel, *args, **kwargs), [samples1[0], samples2[0]] 168 | -------------------------------------------------------------------------------- /src/metrics/molecular_metrics_discrete.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric, MetricCollection 3 | from torch import Tensor 4 | import wandb 5 | import torch.nn as nn 6 | 7 | 8 | class CEPerClass(Metric): 9 | full_state_update = False 10 | 11 | def __init__(self, class_id): 12 | super().__init__() 13 | self.class_id = class_id 14 | self.add_state("total_ce", default=torch.tensor(0.0), dist_reduce_fx="sum") 15 | self.add_state("total_samples", default=torch.tensor(0.0), 
dist_reduce_fx="sum") 16 | self.softmax = torch.nn.Softmax(dim=-1) 17 | self.binary_cross_entropy = torch.nn.BCELoss(reduction="sum") 18 | 19 | def update(self, preds: Tensor, target: Tensor) -> None: 20 | """Update state with predictions and targets. 21 | Args: 22 | preds: Predictions from model (bs, n, d) or (bs, n, n, d) 23 | target: Ground truth values (bs, n, d) or (bs, n, n, d) 24 | """ 25 | target = target.reshape(-1, target.shape[-1]) 26 | mask = (target != 0.0).any(dim=-1) 27 | 28 | prob = self.softmax(preds)[..., self.class_id] 29 | prob = prob.flatten()[mask] 30 | 31 | target = target[:, self.class_id] 32 | target = target[mask] 33 | 34 | output = self.binary_cross_entropy(prob, target) 35 | self.total_ce += output 36 | self.total_samples += prob.numel() 37 | 38 | def compute(self): 39 | return self.total_ce / self.total_samples 40 | 41 | 42 | class HydrogenCE(CEPerClass): 43 | def __init__(self, i): 44 | super().__init__(i) 45 | 46 | 47 | class CarbonCE(CEPerClass): 48 | def __init__(self, i): 49 | super().__init__(i) 50 | 51 | 52 | class NitroCE(CEPerClass): 53 | def __init__(self, i): 54 | super().__init__(i) 55 | 56 | 57 | class OxyCE(CEPerClass): 58 | def __init__(self, i): 59 | super().__init__(i) 60 | 61 | 62 | class FluorCE(CEPerClass): 63 | def __init__(self, i): 64 | super().__init__(i) 65 | 66 | 67 | class BoronCE(CEPerClass): 68 | def __init__(self, i): 69 | super().__init__(i) 70 | 71 | 72 | class BrCE(CEPerClass): 73 | def __init__(self, i): 74 | super().__init__(i) 75 | 76 | 77 | class ClCE(CEPerClass): 78 | def __init__(self, i): 79 | super().__init__(i) 80 | 81 | 82 | class IodineCE(CEPerClass): 83 | def __init__(self, i): 84 | super().__init__(i) 85 | 86 | 87 | class PhosphorusCE(CEPerClass): 88 | def __init__(self, i): 89 | super().__init__(i) 90 | 91 | 92 | class SulfurCE(CEPerClass): 93 | def __init__(self, i): 94 | super().__init__(i) 95 | 96 | 97 | class SeCE(CEPerClass): 98 | def __init__(self, i): 99 | super().__init__(i) 100 | 
101 | 102 | class SiCE(CEPerClass): 103 | def __init__(self, i): 104 | super().__init__(i) 105 | 106 | 107 | class NoBondCE(CEPerClass): 108 | def __init__(self, i): 109 | super().__init__(i) 110 | 111 | 112 | class SingleCE(CEPerClass): 113 | def __init__(self, i): 114 | super().__init__(i) 115 | 116 | 117 | class DoubleCE(CEPerClass): 118 | def __init__(self, i): 119 | super().__init__(i) 120 | 121 | 122 | class TripleCE(CEPerClass): 123 | def __init__(self, i): 124 | super().__init__(i) 125 | 126 | 127 | class AromaticCE(CEPerClass): 128 | def __init__(self, i): 129 | super().__init__(i) 130 | 131 | 132 | class AtomMetricsCE(MetricCollection): 133 | def __init__(self, dataset_infos): 134 | atom_decoder = dataset_infos.atom_decoder 135 | 136 | class_dict = { 137 | "H": HydrogenCE, 138 | "C": CarbonCE, 139 | "N": NitroCE, 140 | "O": OxyCE, 141 | "F": FluorCE, 142 | "B": BoronCE, 143 | "Br": BrCE, 144 | "Cl": ClCE, 145 | "I": IodineCE, 146 | "P": PhosphorusCE, 147 | "S": SulfurCE, 148 | "Se": SeCE, 149 | "Si": SiCE, 150 | } 151 | 152 | metrics_list = [] 153 | for i, atom_type in enumerate(atom_decoder): 154 | metrics_list.append(class_dict[atom_type](i)) 155 | super().__init__(metrics_list) 156 | 157 | 158 | class BondMetricsCE(MetricCollection): 159 | def __init__(self): 160 | ce_no_bond = NoBondCE(0) 161 | ce_SI = SingleCE(1) 162 | ce_DO = DoubleCE(2) 163 | ce_TR = TripleCE(3) 164 | # ce_AR = AromaticCE(4) 165 | # super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR, ce_AR]) 166 | super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR]) 167 | 168 | 169 | class TrainMolecularMetricsDiscrete(nn.Module): 170 | def __init__(self, dataset_infos): 171 | super().__init__() 172 | self.train_atom_metrics = AtomMetricsCE(dataset_infos=dataset_infos) 173 | self.train_bond_metrics = BondMetricsCE() 174 | 175 | def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): 176 | self.train_atom_metrics(masked_pred_X, true_X) 177 | self.train_bond_metrics(masked_pred_E, 
def create_folders(args):
    """Create the ``graphs`` and ``chains`` output folders for the current run.

    Four directories are ensured, relative to the current working directory:
    ``graphs``, ``chains``, ``graphs/<name>`` and ``chains/<name>``, where
    ``<name>`` is ``args.general.name``.

    Each directory is created independently. Previously both directories of a
    pair shared one ``try`` block, so an already-existing ``graphs`` folder
    raised ``OSError`` and silently skipped the creation of ``chains``.

    Args:
        args: configuration object exposing ``args.general.name`` (a string).
    """
    run_name = args.general.name
    for root in ("graphs", "chains"):
        for path in (root, root + "/" + run_name):
            try:
                os.makedirs(path, exist_ok=True)
            except OSError as err:
                # Stay best-effort like the original, but report the failure
                # instead of hiding it.
                print(f"WARNING: could not create directory '{path}': {err}")
def symmetrize_and_mask_diag(E):
    """Mirror the strict upper triangle of ``E`` and zero its diagonal.

    ``E`` is indexed as ``(batch, row, col)`` or ``(batch, row, col, feature)``.
    Entries on or below the diagonal are discarded, the remaining strict upper
    triangle is added to its transpose over dims 1 and 2, and the diagonal is
    set to zero. A new tensor is returned; the argument is not modified.
    """
    rows, cols = torch.triu_indices(row=E.size(1), col=E.size(2), offset=1)

    # 0/1 mask selecting the strictly-upper-triangular entries
    keep = torch.zeros_like(E)
    if E.dim() == 4:
        keep[:, rows, cols, :] = 1
    else:
        keep[:, rows, cols] = 1

    upper = E * keep
    symmetric = upper + torch.transpose(upper, 1, 2)

    # mask the diagonal
    eye = torch.eye(symmetric.shape[1], dtype=torch.bool)
    diag = eye.unsqueeze(0).expand(symmetric.shape[0], -1, -1)
    symmetric[diag] = 0
    return symmetric
class PlaceHolder:
    """Mutable container bundling node (X), edge (E) and global (y) features.

    ``y`` may be ``None``. Methods that transform the features update the
    instance in place and return ``self`` so that calls can be chained.
    """

    def __init__(self, X, E, y):
        self.X = X
        self.E = E
        self.y = y

    def type_as(self, x: torch.Tensor):
        """Cast X, E and y with ``Tensor.type_as(x)``; a ``None`` y is kept."""
        self.X = self.X.type_as(x)
        self.E = self.E.type_as(x)
        # y is optional elsewhere in this class (see to_device / split), so
        # do not crash on it here either.
        self.y = self.y.type_as(x) if self.y is not None else None
        return self

    def to_device(self, device):
        """Move X, E and y to ``device``; a ``None`` y is kept."""
        self.X = self.X.to(device)
        self.E = self.E.to(device)
        self.y = self.y.to(device) if self.y is not None else None
        return self

    def mask(self, node_mask, collapse=False):
        """Mask out entries belonging to padded nodes.

        Args:
            node_mask: per-node mask, non-zero for real nodes.
            collapse: if True, X and E are first reduced with ``argmax`` over
                the last dimension and masked entries are set to ``-1``;
                otherwise masked entries are multiplied by zero and E is
                asserted to be symmetric over dims 1 and 2.
        """
        x_mask = node_mask.unsqueeze(-1)  # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            self.X[node_mask == 0] = -1
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
        return self

    def __repr__(self):
        def describe(value):
            # Show the shape for tensors, the raw value otherwise (e.g. None).
            return value.shape if isinstance(value, torch.Tensor) else value

        return (
            f"X: {describe(self.X)} -- "
            + f"E: {describe(self.E)} -- "
            + f"y: {describe(self.y)}"
        )

    def split(self, node_mask):
        """Split a PlaceHolder representing a batch into a list of placeholders representing individual graphs."""
        graph_list = []
        batch_size = self.X.shape[0]
        for i in range(batch_size):
            # number of real nodes of graph i; used to strip the padding
            n = torch.sum(node_mask[i], dim=0)
            x = self.X[i, :n]
            e = self.E[i, :n, :n]
            y = self.y[i] if self.y is not None else None
            graph_list.append(PlaceHolder(X=x, E=e, y=y))
        return graph_list
2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from torch.nn import KLDivLoss 5 | from torchmetrics import Metric, MeanSquaredError 6 | 7 | 8 | class TrainAbstractMetricsDiscrete(torch.nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): 13 | pass 14 | 15 | def reset(self): 16 | pass 17 | 18 | def log_epoch_metrics(self): 19 | return None, None 20 | 21 | 22 | class SumExceptBatchMetric(Metric): 23 | def __init__(self): 24 | super().__init__() 25 | self.add_state("total_value", default=torch.tensor(0.0), dist_reduce_fx="sum") 26 | self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum") 27 | 28 | def update(self, values) -> None: 29 | self.total_value += torch.sum(values) 30 | self.total_samples += values.shape[0] 31 | 32 | def compute(self): 33 | return self.total_value / self.total_samples 34 | 35 | 36 | class SumExceptBatchMSE(MeanSquaredError): 37 | def update(self, preds: Tensor, target: Tensor) -> None: 38 | """Update state with predictions and targets. 39 | 40 | Args: 41 | preds: Predictions from model 42 | target: Ground truth values 43 | """ 44 | assert preds.shape == target.shape 45 | sum_squared_error, n_obs = self._mean_squared_error_update(preds, target) 46 | 47 | self.sum_squared_error += sum_squared_error 48 | self.total += n_obs 49 | 50 | def _mean_squared_error_update(self, preds: Tensor, target: Tensor): 51 | """Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input 52 | tensors. 
53 | preds: Predicted tensor 54 | target: Ground truth tensor 55 | """ 56 | diff = preds - target 57 | sum_squared_error = torch.sum(diff * diff) 58 | n_obs = preds.shape[0] 59 | return sum_squared_error, n_obs 60 | 61 | 62 | class SumExceptBatchKL(Metric): 63 | def __init__(self): 64 | super().__init__() 65 | self.add_state("total_value", default=torch.tensor(0.0), dist_reduce_fx="sum") 66 | self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum") 67 | 68 | def update(self, p, q) -> None: 69 | self.total_value += F.kl_div(q, p, reduction="sum") 70 | self.total_samples += p.size(0) 71 | 72 | def compute(self): 73 | return self.total_value / self.total_samples 74 | 75 | 76 | class CrossEntropyMetric(Metric): 77 | def __init__(self): 78 | super().__init__() 79 | self.add_state("total_ce", default=torch.tensor(0.0), dist_reduce_fx="sum") 80 | self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum") 81 | 82 | def update(self, preds: Tensor, target: Tensor, weight: Tensor = None) -> None: 83 | """Update state with predictions and targets. 
84 | preds: Predictions from model (bs * n, d) or (bs * n * n, d) 85 | target: Ground truth values (bs * n, d) or (bs * n * n, d).""" 86 | target = torch.argmax(target, dim=-1) 87 | if weight is not None: 88 | output = F.cross_entropy( 89 | preds, 90 | target, 91 | reduction="none", 92 | weight=None, 93 | ) 94 | output = (output * weight).sum() 95 | else: 96 | output = F.cross_entropy( 97 | preds, 98 | target, 99 | reduction="sum", 100 | weight=None, 101 | ) 102 | # output = F.cross_entropy(preds, target, reduction="sum") 103 | self.total_ce += output 104 | self.total_samples += preds.size(0) 105 | 106 | def compute(self): 107 | return self.total_ce / self.total_samples 108 | 109 | 110 | class KLDMetric(Metric): 111 | def __init__(self): 112 | super().__init__() 113 | self.add_state("total_ce", default=torch.tensor(0.0), dist_reduce_fx="sum") 114 | self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum") 115 | 116 | def update(self, preds: Tensor, target: Tensor, weight: Tensor = None) -> None: 117 | """Update state with predictions and targets. 
def compute_ratios(gen_metrics, ref_metrics, metrics_keys):
    """Compute generated/reference ratios for the requested metrics.

    Args:
        gen_metrics: mapping from metric name to the generated-sample value.
        ref_metrics: mapping from metric name to the reference value, or None.
        metrics_keys: names of the metrics to compare.

    Returns:
        Dict with one ``"<key>_ratio"`` entry per usable metric plus
        ``"average_ratio"`` (``-1`` when no ratio could be computed). An empty
        dict is returned when ``ref_metrics`` is None or no key is requested.
        Keys missing from ``ref_metrics`` and references that round to zero at
        four decimals are skipped with a printed message.
    """
    print("Computing ratios of metrics: ", metrics_keys)
    if ref_metrics is None or len(metrics_keys) == 0:
        print("WARNING: No reference metrics for ratio computation.")
        return {}

    ratios = {}
    for key in metrics_keys:
        try:
            ref_metric = round(ref_metrics[key], 4)
        except (KeyError, IndexError, TypeError):
            # Only lookup/rounding failures are tolerated; the previous bare
            # ``except`` also swallowed KeyboardInterrupt and SystemExit.
            print(key, "not found")
            continue
        if ref_metric == 0.0:
            print(f"WARNING: Reference {key} is 0. Skipping its ratio.")
            continue
        ratios[key + "_ratio"] = gen_metrics[key] / ref_metric

    if ratios:
        # average over the per-metric ratios collected so far
        ratios["average_ratio"] = sum(ratios.values()) / len(ratios)
    else:
        ratios["average_ratio"] = -1
        print("WARNING: no ratio being saved.")
    return ratios
38 | return is_planar_graph(cg) 39 | 40 | def forward( 41 | self, 42 | generated_graphs: list, 43 | ref_metrics, 44 | name, 45 | current_epoch, 46 | val_counter, 47 | local_rank, 48 | test=False, 49 | labels=None, 50 | ): 51 | 52 | # Unattributed graphs specific 53 | to_log = super().forward( 54 | generated_graphs, 55 | ref_metrics, 56 | name, 57 | current_epoch, 58 | val_counter, 59 | local_rank, 60 | test, 61 | labels, 62 | ) 63 | 64 | # Cell graph specific 65 | reference_cgs = self.test_cell_graphs if test else self.val_cell_graphs 66 | generated_cgs = [] 67 | if local_rank == 0: 68 | print("Building generated cell graphs...") 69 | for graph in generated_graphs: 70 | generated_cgs.append(CellGraph.from_dense_graph(graph)) 71 | 72 | # TODO: Implement these metrics with torchmetrics for parallelization 73 | generated_labels = torch.tensor([cg.to_label() for cg in generated_cgs]) 74 | ambiguous_gen_cgs = sum( 75 | [(cg_label == -1).int() for cg_label in generated_labels] 76 | ).item() 77 | if labels is not None: 78 | true_labels = torch.tensor(labels) 79 | high_tls_idxs = true_labels == 1 80 | low_tls_idxs = true_labels == 0 81 | total_tls_acc = (generated_labels == true_labels).float().mean().item() 82 | high_tls_acc = ( 83 | (generated_labels[high_tls_idxs] == true_labels[high_tls_idxs]) 84 | .float() 85 | .mean() 86 | .item() 87 | ) 88 | low_tls_acc = ( 89 | (generated_labels[low_tls_idxs] == true_labels[low_tls_idxs]) 90 | .float() 91 | .mean() 92 | .item() 93 | ) 94 | else: 95 | total_tls_acc = -1 96 | high_tls_acc = -1 97 | low_tls_acc = -1 98 | 99 | # Compute novelty and uniqueness 100 | if local_rank == 0: 101 | print("Computing uniqueness, novelty and validity for cell graphs...") 102 | frac_novel = eval_fraction_novel_cell_graphs( 103 | generated_cell_graphs=generated_cgs, 104 | train_cell_graphs=self.train_cell_graphs, 105 | ) 106 | ( 107 | frac_unique, 108 | frac_unique_and_novel, 109 | frac_unique_and_novel_valid, 110 | ) = 
# specific for cell graphs (isomorphism function is of cell graphs)
def eval_fraction_novel_cell_graphs(generated_cell_graphs, train_cell_graphs):
    """Return the fraction of generated cell graphs not found in the training set.

    A generated graph counts as known when some training graph passes the
    ``nx.faster_could_be_isomorphic`` pre-check and is then confirmed by the
    generated graph's own ``is_isomorphic`` method.
    """

    def seen_in_training(candidate):
        # ``any`` stops at the first match, like the original ``break``.
        return any(
            nx.faster_could_be_isomorphic(known, candidate)
            and candidate.is_isomorphic(known)
            for known in train_cell_graphs
        )

    n_known = sum(1 for candidate in generated_cell_graphs if seen_in_training(candidate))
    return 1 - n_known / len(generated_cell_graphs)
nx.faster_could_be_isomorphic(gen_cg_seen, gen_cg): 165 | # we also need to consider phenotypes of nodes 166 | if gen_cg.is_isomorphic(gen_cg_seen): 167 | count_non_unique += 1 168 | is_unique = False 169 | break 170 | if is_unique: 171 | is_novel = True 172 | for train_cg in train_cell_graphs: 173 | if nx.faster_could_be_isomorphic(train_cg, gen_cg): 174 | if gen_cg.is_isomorphic(train_cg): 175 | count_not_novel += 1 176 | is_novel = False 177 | break 178 | if is_novel: 179 | if not valid_cg_fn(gen_cg): 180 | count_not_valid += 1 181 | 182 | frac_unique = 1 - count_non_unique / len(generated_cell_graphs) 183 | frac_unique_non_isomorphic = frac_unique - count_not_novel / len( 184 | generated_cell_graphs 185 | ) 186 | frac_unique_non_isomorphic_valid = ( 187 | frac_unique_non_isomorphic - count_not_valid / len(generated_cell_graphs) 188 | ) 189 | 190 | return ( 191 | frac_unique, 192 | frac_unique_non_isomorphic, 193 | frac_unique_non_isomorphic_valid, 194 | ) 195 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeFoG: Discrete Flow Matching for Graph Generation 2 | 3 | > A PyTorch implementation of the DeFoG model for training and sampling discrete graph flows. (Please update to the latest commit. Recent fixes have been applied.) 4 | 5 | > Paper: https://arxiv.org/pdf/2410.04263 6 | 7 | > Poster: https://icml.cc/virtual/2025/poster/45644 8 | 9 | > Oral Presentation: https://icml.cc/virtual/2025/oral/47238 10 | 11 | 12 | ![DeFoG: Visualization](images/defog.png) 13 | 14 |

15 | 16 | 17 |

18 | 19 | --- 20 | 21 | ## 📝 Updates 22 | 23 | > Working with directed graphs? Consider using [DIRECTO](https://github.com/acarballocastro/DIRECTO), a discrete flow matching framework for directed graph generation. 24 | 25 | > For an updated development environment with modernized dependencies, see the `updated_env` branch. The `main` branch remains the reference implementation, based on Python 3.9 and older package versions. 26 | 27 | 28 | 29 | ## 🚀 Installation 30 | 31 | We provide two alternative installation methods: Docker and Conda. 32 | 33 | ### 🐳 Docker 34 | 35 | We provide the Dockerfile to run DeFoG in a container. 36 | 1. Build the Docker image: 37 | ```bash 38 | docker build --platform=linux/amd64 -t defog-image . 39 | ``` 40 | 41 | ⚠️ Once you clone DeFoG's git repository to your workspace, you may need to run `pip install -e .` to make all the repository modules visible. 42 | 43 | 44 | ### 🐍 Conda 45 | 46 | 1. Install Conda (we used version 25.1.1) and create DeFoG's environment: 47 | ```bash 48 | conda env create -f environment.yaml 49 | conda activate defog 50 | ``` 51 | 2. Run the following commands to check if the installation of the main packages was successful: 52 | ```bash 53 | python -c "import sys; print('Python version:', sys.version)" 54 | python -c "import rdkit; print('RDKit version:', rdkit.__version__)" 55 | python -c "import graph_tool as gt; print('Graph-Tool version:', gt.__version__)" 56 | python -c "import torch; print(f'PyTorch version: {torch.__version__}, CUDA version (via PyTorch): {torch.version.cuda}')" 57 | python -c "import torch_geometric as tg; print('PyTorch Geometric version:', tg.__version__)" 58 | ``` 59 | If you see no errors, the installation was successful and you can proceed to the next step. 60 | 3. Compile the ORCA evaluator: 61 | ```bash 62 | cd src/analysis/orca 63 | g++ -O2 -std=c++11 -o orca orca.cpp 64 | ``` 65 | 66 | ⚠️ Tested on Ubuntu. 
67 | 68 | --- 69 | 70 | ## ⚙️ Usage 71 | 72 | All commands use `python main.py` with [Hydra](https://hydra.cc/) overrides. Note that `main.py` is inside the `src` directory. 73 | 74 | ### Quick start 75 | 76 | Use this script to quickly test the code. 77 | 78 | ```bash 79 | python main.py +experiment=debug 80 | ``` 81 | 82 | ### Full training 83 | 84 | ```bash 85 | python main.py +experiment=<experiment> dataset=<dataset> 86 | ``` 87 | 88 | - **QM9 (no H):** `+experiment=qm9_no_h dataset=qm9` 89 | - **Planar:** `+experiment=planar dataset=planar` 90 | - **SBM:** `+experiment=sbm dataset=sbm` 91 | - **Tree:** `+experiment=tree dataset=tree` 92 | - **Comm20:** `+experiment=comm20 dataset=comm20` 93 | - **Guacamol:** `+experiment=guacamol dataset=guacamol` 94 | - **MOSES:** `+experiment=moses dataset=moses` 95 | - **QM9 (with H):** `+experiment=qm9_with_h dataset=qm9` 96 | - **TLS (conditional):** `+experiment=tls dataset=tls` 97 | - **ZINC:** `+experiment=zinc dataset=zinc` 98 | 99 | --- 100 | 101 | ## 📊 Evaluation 102 | 103 | Sampling from DeFoG is typically done in two steps: 104 | 105 | 1. **Sampling Optimization** → find best sampling configuration 106 | 2. **Final Sampling** → sample and measure performance under the best configuration 107 | 108 | To perform 5 runs (mean ± std), set `general.num_sample_fold=5`. 109 | 110 | For the rest of this section, we take the Planar dataset as an example: 111 | 112 | ### Default sampling 113 | ```bash 114 | python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.eta=0 sample.omega=0 sample.time_distortion=identity 115 | ``` 116 | 117 | Note that if you run: 118 | ```bash 119 | python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> 120 | ``` 121 | it will run with the sampling parameters (η, ω, time distortion) that we obtained after sampling optimization (see next section) and are reported in the paper. 
122 | 123 | ### Sampling optimization 124 | To search over the optimal inference hyperparameters (η, ω, distortion), use the `sample.search` flag, which will save a CSV file with the results. 125 | - **Non-grid search** (independent search for each component): 126 | ```bash 127 | python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.search=all 128 | ``` 129 | - **Component-wise**: set `sample.search=target_guidance | distortion | stochasticity` above. 130 | 131 | ⚠️ We set the default search intervals for each sampling parameter as we used in our experiments. You may want to adjust these intervals according to your needs. 132 | 133 | ### Final sampling 134 | Use the optimal η, ω, and time distortion resulting from the search: 135 | ```bash 136 | python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.eta=<η> sample.omega=<ω> sample.time_distortion=<distortion> 137 | ``` 138 | 139 | --- 140 | 141 | ## 🌐 Extend DeFoG to new datasets 142 | 143 | Start by creating a new file in the `src/datasets` directory. You can refer to the following scripts as examples: 144 | - `spectre_dataset.py`, if you are using unattributed graphs; 145 | - `tls_dataset.py`, if you are using graphs with attributed nodes; 146 | - `qm9_dataset.py` or `guacamol_dataset.py`, if you are using graphs with attributed nodes and edges (e.g., molecular data). 147 | 148 | This new file should define a `Dataset` class to handle data processing (refer to the [PyG documentation](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html) for guidance), as well as a `DatasetInfos` class to specify relevant dataset properties (e.g., number of nodes, edges, etc.). 149 | 150 | Once your dataset file is ready, update `main.py` to incorporate the new dataset. Additionally, you can add a corresponding file in the `configs/dataset` directory. 151 | 152 | Finally, if you are planning to introduce custom metrics, you can create a new file under the `metrics` directory. 
153 | 154 | --- 155 | 156 | ## Checkpoints 157 | 158 | Checkpoints, along with their corresponding results and generated samples, are shared [here](https://drive.switch.ch/index.php/s/MG7y2EZoithAywE). 159 | 160 | To run sampling and evaluate generation with a given checkpoint, set the `general.test_only` flag to the path of the checkpoint file (`.ckpt` file). To skip sampling and directly evaluate previously generated samples, set the flag `general.generated_path` to the path of the generated samples (`.pkl` file). 161 | 162 | (*Note*: The released checkpoints are retrained models from the public repository. Their performance is consistent with the paper’s findings, with minor variations attributable to training/sampling stochasticity.) 163 | 164 | --- 165 | 166 | ## 📌 Upon request 167 | 168 | - protein / EGO datasets 169 | - FCD score for molecules 170 | - W&B sweeps for sampling optimization 171 | 172 | 173 | 174 | --- 175 | ## 🙏 Acknowledgements 176 | 177 | - DiGress: https://github.com/cvignac/DiGress 178 | - Discrete Flow Models: https://github.com/andrew-cr/discrete_flow_models 179 | 180 | --- 181 | 182 | ## 📚 Citation 183 | 184 | ```bibtex 185 | @inproceedings{qinmadeira2024defog, 186 | title = {DeFoG: Discrete Flow Matching for Graph Generation}, 187 | author = {Qin, Yiming and Madeira, Manuel and Thanou, Dorina and Frossard, Pascal}, 188 | booktitle = {International Conference on Machine Learning (ICML)}, 189 | year = {2025}, 190 | } 191 | ``` 192 | -------------------------------------------------------------------------------- /src/datasets/abstract_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import torch 5 | import wandb 6 | import pytorch_lightning as pl 7 | from torch_geometric.loader import DataLoader 8 | from torch_geometric.data.lightning import LightningDataset 9 | from tqdm import tqdm 10 | 11 | import utils as utils 12 | from datasets.dataset_utils import 
DistributionNodes, load_pickle, save_pickle


class AbstractDataModule(LightningDataset):
    """Lightning data module wrapping the train/val/test graph datasets.

    On top of the loaders provided by ``LightningDataset`` it offers helpers
    that scan those loaders to estimate dataset statistics: the graph-size
    distribution and the node-type / edge-type marginals.
    """

    def __init__(self, cfg, datasets):
        """Create the module.

        Args:
            cfg: Run configuration; reads ``cfg.train.batch_size``,
                ``cfg.train.num_workers``, ``cfg.general.name`` and the
                optional ``cfg.dataset.pin_memory``.
            datasets: Mapping with the keys ``"train"``, ``"val"``, ``"test"``.
        """
        super().__init__(
            train_dataset=datasets["train"],
            val_dataset=datasets["val"],
            test_dataset=datasets["test"],
            # Any run whose name contains "debug" uses a batch size of 2.
            batch_size=cfg.train.batch_size if "debug" not in cfg.general.name else 2,
            num_workers=cfg.train.num_workers,
            pin_memory=getattr(cfg.dataset, "pin_memory", False),
        )
        self.cfg = cfg
        self.input_dims = None
        self.output_dims = None
        print(
            f'This dataset contains {len(datasets["train"])} training graphs, {len(datasets["val"])} validation graphs, {len(datasets["test"])} test graphs.'
        )

    def __getitem__(self, idx):
        """Return the ``idx``-th element of the training dataset."""
        return self.train_dataset[idx]

    def node_counts(self, max_nodes_possible=1000):
        """Return the empirical distribution of graph sizes over train + val.

        Args:
            max_nodes_possible: Initial size of the histogram. The histogram
                grows automatically when a larger graph is met, so this is a
                hint rather than a hard limit.

        Returns:
            1-D float tensor ``p`` where ``p[n]`` is the fraction of graphs
            with exactly ``n`` nodes; its length is the largest size seen + 1.

        Raises:
            ValueError: If the loaders yield no graph at all.
        """
        all_counts = torch.zeros(max_nodes_possible)
        for loader in (self.train_dataloader(), self.val_dataloader()):
            for data in loader:
                # ``data.batch`` maps every node to its graph, so the
                # multiplicity of each id is that graph's node count.
                _, counts = torch.unique(data.batch, return_counts=True)
                histogram = torch.bincount(counts, minlength=all_counts.numel())
                histogram = histogram.to(all_counts)
                if histogram.numel() > all_counts.numel():
                    # A graph exceeded the current capacity: enlarge the
                    # accumulator instead of failing with an IndexError.
                    grown = torch.zeros(histogram.numel())
                    grown[: all_counts.numel()] = all_counts
                    all_counts = grown
                all_counts += histogram

        observed_sizes = all_counts.nonzero()
        if observed_sizes.numel() == 0:
            raise ValueError(
                "Cannot compute node counts: the train/val loaders yielded no graphs."
            )
        max_index = int(observed_sizes.max())
        all_counts = all_counts[: max_index + 1]
        return all_counts / all_counts.sum()

    def node_types(self):
        """Return the normalized column sums of ``x`` over the training set."""
        # Peek at one batch to learn the feature width.
        num_classes = None
        for data in self.train_dataloader():
            num_classes = data.x.shape[1]
            break

        counts = torch.zeros(num_classes)
        for data in self.train_dataloader():
            counts += data.x.sum(dim=0)

        return counts / counts.sum()

    def edge_counts(self):
        """Return the normalized edge-type histogram over the training set.

        Entry 0 counts ordered node pairs of the same graph that are not
        listed in ``edge_index``; entries ``1:`` are the column sums of
        ``edge_attr`` (column 0 of ``edge_attr`` is ignored).
        """
        num_classes = None
        for data in self.train_dataloader():
            num_classes = data.edge_attr.shape[1]
            break

        d = torch.zeros(num_classes, dtype=torch.float)

        for data in self.train_dataloader():
            _, counts = torch.unique(data.batch, return_counts=True)

            # Ordered pairs (i, j), i != j, inside each graph of the batch.
            all_pairs = (counts * (counts - 1)).sum()

            num_edges = data.edge_index.shape[1]
            num_non_edges = all_pairs - num_edges

            edge_types = data.edge_attr.sum(dim=0)
            assert num_non_edges >= 0
            d[0] += num_non_edges
            d[1:] += edge_types[1:]

        return d / d.sum()


class MolecularDataModule(AbstractDataModule):
    """Data module with an extra valency statistic for molecular graphs."""

    def valency_count(self, max_n_nodes, zinc=False):
        """Return the normalized histogram of atom valencies in the train set.

        Args:
            max_n_nodes: Largest number of atoms per molecule; sizes the
                histogram (``3 * max_n_nodes - 2`` bins).
            zinc: If true, use the 4-entry bond multiplier (no aromatic
                entry); otherwise the 5-entry one.
        """
        # Max valency possible if everything is connected
        valencies = torch.zeros(3 * max_n_nodes - 2)

        # No bond, single bond, double bond, triple bond, aromatic bond
        if zinc:
            multiplier = torch.tensor([0, 1, 2, 3])  # zinc250
        else:
            multiplier = torch.tensor([0, 1, 2, 3, 1.5])

        for data in self.train_dataloader():
            n = data.x.shape[0]

            for atom in range(n):
                # Rows of edge_attr for the edges whose source is ``atom``.
                edges = data.edge_attr[data.edge_index[0] == atom]
                edges_total = edges.sum(dim=0)
                valency = (edges_total * multiplier).sum()
                # ``.long()`` truncates fractional valencies.
                valencies[valency.long().item()] += 1

        return valencies / valencies.sum()


class AbstractDatasetInfos:
    """Base class holding dataset-level information derived from a datamodule."""

    def complete_infos(self, n_nodes, node_types):
        """Fill in the attributes derived from the size / node-type marginals.

        Args:
            n_nodes: Graph-size distribution (index = number of nodes).
            node_types: Node-type distribution; its length gives ``num_classes``.
        """
        self.input_dims = None
        self.output_dims = None
        self.num_classes = len(node_types)
        self.max_n_nodes = len(n_nodes) - 1
        self.nodes_dist = DistributionNodes(n_nodes)

    def compute_input_output_dims(self, datamodule, extra_features, domain_features):
        """Infer the model input/output widths from one training batch.

        The raw feature widths are read from the batch, then the widths
        produced by ``extra_features`` and ``domain_features`` on that batch
        are added to the input dimensions.
        """
        example_batch = next(iter(datamodule.train_dataloader()))
        ex_dense, node_mask = utils.to_dense(
            example_batch.x,
            example_batch.edge_index,
            example_batch.edge_attr,
            example_batch.batch,
        )

        example_data = {
            "X_t": ex_dense.X,
            "E_t": ex_dense.E,
            "y_t": example_batch["y"],
            "node_mask": node_mask,
        }
        self.input_dims = {
            "X": example_batch["x"].size(1),
            "E": example_batch["edge_attr"].size(1),
            # y carries the conditioning; + 1 due to time conditioning.
            "y": example_batch["y"].size(1) + 1,
        }

        # Generic extra features first, then domain-specific ones.
        for feature_fn in (extra_features, domain_features):
            ex_feat = feature_fn(example_data)
            self.input_dims["X"] += ex_feat.X.size(-1)
            self.input_dims["E"] += ex_feat.E.size(-1)
            self.input_dims["y"] += ex_feat.y.size(-1)

        self.output_dims = {
            "X": example_batch["x"].size(1),
            "E": example_batch["edge_attr"].size(1),
            "y": 0,
        }

    def compute_reference_metrics(self, datamodule, sampling_metrics):
        """Load (computing and caching on first use) the reference metrics.

        The metrics are obtained by running ``sampling_metrics`` on the
        training graphs and are cached as a pickle in the training dataset's
        root directory. The result is stored in ``self.ref_metrics`` as a
        dict with the keys ``"val"`` and ``"test"``.
        """
        # Build the file name first so that only the file name (never a
        # directory component of the root) receives the suffix.
        suffix = ""
        if hasattr(datamodule, "remove_h"):
            suffix = "_no_h" if datamodule.remove_h else "_h"
        ref_metrics_path = os.path.join(
            datamodule.train_dataloader().dataset.root, f"ref_metrics{suffix}.pkl"
        )

        # Only compute the reference metrics if they haven't been computed already
        if not os.path.exists(ref_metrics_path):
            print("Reference metrics not found. Computing them now.")
            # Transform the training dataset into a list of graphs in the appropriate format
            training_graphs = []
            print("Converting training dataset to format required by sampling metrics.")
            for data_batch in tqdm(datamodule.train_dataloader()):
                dense_data, node_mask = utils.to_dense(
                    data_batch.x,
                    data_batch.edge_index,
                    data_batch.edge_attr,
                    data_batch.batch,
                )
                dense_data = dense_data.mask(node_mask, collapse=True).split(node_mask)
                training_graphs.extend([graph.X, graph.E] for graph in dense_data)

            # defining dummy arguments
            dummy_kwargs = {
                "name": "ref_metrics",
                "current_epoch": 0,
                "val_counter": 0,
                "local_rank": 0,
                "ref_metrics": {"val": None, "test": None},
            }

            print("Computing validation reference metrics.")
            # do not have to worry about wandb because it was not init yet
            # Work on a copy so the caller's metric object is left untouched.
            val_sampling_metrics = copy.deepcopy(sampling_metrics)
            val_ref_metrics = val_sampling_metrics.forward(
                training_graphs,
                test=False,
                **dummy_kwargs,
            )

            print("Computing test reference metrics.")
            test_sampling_metrics = copy.deepcopy(sampling_metrics)
            test_ref_metrics = test_sampling_metrics.forward(
                training_graphs,
                test=True,
                **dummy_kwargs,
            )

            print("Saving reference metrics.")
            save_pickle(
                {"val": val_ref_metrics, "test": test_ref_metrics}, ref_metrics_path
            )

        print("Loading reference metrics.")
        self.ref_metrics = load_pickle(ref_metrics_path)
        print("Validation reference metrics:", self.ref_metrics["val"])
        print("Test reference metrics:", self.ref_metrics["test"])

-------------------------------------------------------------------------------- /src/datasets/tu_dataset_origin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import os.path as osp 4 | from typing import Callable, List, Optional 5 | 6 | import numpy as np 7 | import torch.nn.functional as F 8 | import torch 9 | from torch.utils.data import random_split 10 | import torch_geometric.utils 11 | from torch_geometric.io import fs, read_tu_data 12 | from torch_geometric.utils import remove_self_loops 13 | from torch_geometric.data import Data, InMemoryDataset, download_url 14 | from hydra.utils import get_original_cwd 15 | 16 | from utils import PlaceHolder 17 | from datasets.abstract_dataset import ( 18 | AbstractDataModule, 19 | AbstractDatasetInfos, 20 | ) 21 | from datasets.dataset_utils import ( 22 | load_pickle, 23 | save_pickle, 24 | Statistics, 25 | to_list, 26 | RemoveYTransform, 27 | ) 28 | 29 | 30 | class TUDataset(InMemoryDataset): 31 | r"""A variety of graph kernel benchmark datasets, *.e.g.*, 32 | :obj:`"IMDB-BINARY"`, :obj:`"REDDIT-BINARY"` or :obj:`"PROTEINS"`, 33 | collected from the `TU Dortmund University 34 | `_. 35 | In addition, this dataset wrapper provides `cleaned dataset versions 36 | `_ as motivated by the 37 | `"Understanding Isomorphism Bias in Graph Data Sets" 38 | `_ paper, containing only non-isomorphic 39 | graphs. 40 | 41 | .. note:: 42 | Some datasets may not come with any node labels. 43 | You can then either make use of the argument :obj:`use_node_attr` 44 | to load additional continuous node attributes (if present) or provide 45 | synthetic node features using transforms such as 46 | :class:`torch_geometric.transforms.Constant` or 47 | :class:`torch_geometric.transforms.OneHotDegree`. 48 | 49 | Args: 50 | root (str): Root directory where the dataset should be saved. 51 | name (str): The `name 52 | `_ of the 53 | dataset. 
54 | transform (callable, optional): A function/transform that takes in an 55 | :obj:`torch_geometric.data.Data` object and returns a transformed 56 | version. The data object will be transformed before every access. 57 | (default: :obj:`None`) 58 | pre_transform (callable, optional): A function/transform that takes in 59 | an :obj:`torch_geometric.data.Data` object and returns a 60 | transformed version. The data object will be transformed before 61 | being saved to disk. (default: :obj:`None`) 62 | pre_filter (callable, optional): A function that takes in an 63 | :obj:`torch_geometric.data.Data` object and returns a boolean 64 | value, indicating whether the data object should be included in the 65 | final dataset. (default: :obj:`None`) 66 | force_reload (bool, optional): Whether to re-process the dataset. 67 | (default: :obj:`False`) 68 | use_node_attr (bool, optional): If :obj:`True`, the dataset will 69 | contain additional continuous node attributes (if present). 70 | (default: :obj:`False`) 71 | use_edge_attr (bool, optional): If :obj:`True`, the dataset will 72 | contain additional continuous edge attributes (if present). 73 | (default: :obj:`False`) 74 | cleaned (bool, optional): If :obj:`True`, the dataset will 75 | contain only non-isomorphic graphs. (default: :obj:`False`) 76 | 77 | **STATS:** 78 | 79 | .. list-table:: 80 | :widths: 20 10 10 10 10 10 81 | :header-rows: 1 82 | 83 | * - Name 84 | - #graphs 85 | - #nodes 86 | - #edges 87 | - #features 88 | - #classes 89 | * - MUTAG 90 | - 188 91 | - ~17.9 92 | - ~39.6 93 | - 7 94 | - 2 95 | * - ENZYMES 96 | - 600 97 | - ~32.6 98 | - ~124.3 99 | - 3 100 | - 6 101 | * - PROTEINS 102 | - 1,113 103 | - ~39.1 104 | - ~145.6 105 | - 3 106 | - 2 107 | * - COLLAB 108 | - 5,000 109 | - ~74.5 110 | - ~4914.4 111 | - 0 112 | - 3 113 | * - IMDB-BINARY 114 | - 1,000 115 | - ~19.8 116 | - ~193.1 117 | - 0 118 | - 2 119 | * - REDDIT-BINARY 120 | - 2,000 121 | - ~429.6 122 | - ~995.5 123 | - 0 124 | - 2 125 | * - ... 
126 | - 127 | - 128 | - 129 | - 130 | - 131 | """ 132 | 133 | url = "https://www.chrsmrrs.com/graphkerneldatasets" 134 | cleaned_url = ( 135 | "https://raw.githubusercontent.com/nd7141/" "graph_datasets/master/datasets" 136 | ) 137 | 138 | def __init__( 139 | self, 140 | root: str, 141 | name: str, 142 | transform: Optional[Callable] = None, 143 | pre_transform: Optional[Callable] = None, 144 | pre_filter: Optional[Callable] = None, 145 | force_reload: bool = False, 146 | use_node_attr: bool = False, 147 | use_edge_attr: bool = False, 148 | cleaned: bool = False, 149 | ) -> None: 150 | self.name = name 151 | self.cleaned = cleaned 152 | super().__init__( 153 | root, transform, pre_transform, pre_filter, force_reload=force_reload 154 | ) 155 | 156 | out = fs.torch_load(self.processed_paths[0]) 157 | if not isinstance(out, tuple) or len(out) < 3: 158 | raise RuntimeError( 159 | "The 'data' object was created by an older version of PyG. " 160 | "If this error occurred while loading an already existing " 161 | "dataset, remove the 'processed/' directory in the dataset's " 162 | "root folder and try again." 163 | ) 164 | assert len(out) == 3 or len(out) == 4 165 | 166 | if len(out) == 3: # Backward compatibility. 167 | data, self.slices, self.sizes = out 168 | data_cls = Data 169 | else: 170 | data, self.slices, self.sizes, data_cls = out 171 | 172 | if not isinstance(data, dict): # Backward compatibility. 
173 | self.data = data 174 | else: 175 | self.data = data_cls.from_dict(data) 176 | 177 | assert isinstance(self._data, Data) 178 | if self._data.x is not None and not use_node_attr: 179 | num_node_attributes = self.num_node_attributes 180 | self._data.x = self._data.x[:, num_node_attributes:] 181 | if self._data.edge_attr is not None and not use_edge_attr: 182 | num_edge_attrs = self.num_edge_attributes 183 | self._data.edge_attr = self._data.edge_attr[:, num_edge_attrs:] 184 | 185 | @property 186 | def raw_dir(self) -> str: 187 | name = f'raw{"_cleaned" if self.cleaned else ""}' 188 | return osp.join(self.root, self.name, name) 189 | 190 | @property 191 | def processed_dir(self) -> str: 192 | name = f'processed{"_cleaned" if self.cleaned else ""}' 193 | return osp.join(self.root, self.name, name) 194 | 195 | @property 196 | def num_node_labels(self) -> int: 197 | return self.sizes["num_node_labels"] 198 | 199 | @property 200 | def num_node_attributes(self) -> int: 201 | return self.sizes["num_node_attributes"] 202 | 203 | @property 204 | def num_edge_labels(self) -> int: 205 | return self.sizes["num_edge_labels"] 206 | 207 | @property 208 | def num_edge_attributes(self) -> int: 209 | return self.sizes["num_edge_attributes"] 210 | 211 | @property 212 | def raw_file_names(self) -> List[str]: 213 | names = ["A", "graph_indicator"] 214 | return [f"{self.name}_{name}.txt" for name in names] 215 | 216 | @property 217 | def processed_file_names(self) -> str: 218 | return "data.pt" 219 | 220 | def download(self) -> None: 221 | url = self.cleaned_url if self.cleaned else self.url 222 | fs.cp(f"{url}/{self.name}.zip", self.raw_dir, extract=True) 223 | for filename in fs.ls(osp.join(self.raw_dir, self.name)): 224 | fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename))) 225 | fs.rm(osp.join(self.raw_dir, self.name)) 226 | 227 | def process(self) -> None: 228 | self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name) 229 | 230 | if self.pre_filter is not 
None or self.pre_transform is not None: 231 | data_list = [self.get(idx) for idx in range(len(self))] 232 | 233 | if self.pre_filter is not None: 234 | data_list = [d for d in data_list if self.pre_filter(d)] 235 | 236 | if self.pre_transform is not None: 237 | data_list = [self.pre_transform(d) for d in data_list] 238 | 239 | self.data, self.slices = self.collate(data_list) 240 | self._data_list = None # Reset cache. 241 | 242 | assert isinstance(self._data, Data) 243 | fs.torch_save( 244 | (self._data.to_dict(), self.slices, sizes, self._data.__class__), 245 | self.processed_paths[0], 246 | ) 247 | 248 | def __repr__(self) -> str: 249 | return f"{self.name}({len(self)})" 250 | -------------------------------------------------------------------------------- /src/datasets/tu_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import os.path as osp 4 | from typing import Callable, List, Optional 5 | 6 | import numpy as np 7 | import torch.nn.functional as F 8 | import torch 9 | from torch.utils.data import random_split 10 | import torch_geometric.utils 11 | from torch_geometric.io import fs, read_tu_data 12 | from torch_geometric.utils import remove_self_loops 13 | from torch_geometric.data import Data, InMemoryDataset, download_url 14 | from hydra.utils import get_original_cwd 15 | 16 | from utils import PlaceHolder 17 | from datasets.abstract_dataset import ( 18 | AbstractDataModule, 19 | AbstractDatasetInfos, 20 | ) 21 | from datasets.dataset_utils import ( 22 | load_pickle, 23 | save_pickle, 24 | Statistics, 25 | to_list, 26 | RemoveYTransform, 27 | ) 28 | 29 | 30 | class ProteinDataset(InMemoryDataset): 31 | """ 32 | Implementation based on https://github.com/KarolisMart/SPECTRE/blob/main/data.py 33 | """ 34 | 35 | def __init__( 36 | self, 37 | split, 38 | root, 39 | transform=None, 40 | pre_transform=None, 41 | pre_filter=None, 42 | ): 43 | self.dataset_name = "protein" 44 | root = 
root 45 | 46 | self.split = split 47 | if self.split == "train": 48 | self.file_idx = 0 49 | elif self.split == "val": 50 | self.file_idx = 1 51 | else: 52 | self.file_idx = 2 53 | 54 | super().__init__(root, transform, pre_transform, pre_filter) 55 | self.data, self.slices = torch.load(self.processed_paths[0]) 56 | 57 | @property 58 | def raw_file_names(self): 59 | return ["train_indices.pt", "val_indices.pt", "test_indices.pt"] 60 | 61 | @property 62 | def split_file_name(self): 63 | return ["train.pt", "val.pt", "test.pt"] 64 | 65 | @property 66 | def split_paths(self): 67 | r"""The absolute filepaths that must be present in order to skip 68 | splitting.""" 69 | files = to_list(self.split_file_name) 70 | return [osp.join(self.raw_dir, f) for f in files] 71 | 72 | @property 73 | def processed_file_names(self): 74 | if self.split == "train": 75 | return [ 76 | f"train.pt", 77 | f"train_n.pickle", 78 | f"train_node_types.npy", 79 | f"train_bond_types.npy", 80 | ] 81 | elif self.split == "val": 82 | return [ 83 | f"val.pt", 84 | f"val_n.pickle", 85 | f"val_node_types.npy", 86 | f"val_bond_types.npy", 87 | ] 88 | else: 89 | return [ 90 | f"test.pt", 91 | f"test_n.pickle", 92 | f"test_node_types.npy", 93 | f"test_bond_types.npy", 94 | ] 95 | 96 | def download(self): 97 | """ 98 | Download raw files. 
99 | """ 100 | raw_url = "https://raw.githubusercontent.com/KarolisMart/SPECTRE/main/data/DD" 101 | for name in [ 102 | "DD_A.txt", 103 | "DD_graph_indicator.txt", 104 | "DD_graph_labels.txt", 105 | "DD_node_labels.txt", 106 | ]: 107 | download_url(f"{raw_url}/{name}", self.raw_dir) 108 | 109 | # read 110 | path = os.path.join(self.root, "raw") 111 | data_graph_indicator = np.loadtxt( 112 | os.path.join(path, "DD_graph_indicator.txt"), delimiter="," 113 | ).astype(int) 114 | 115 | # split data 116 | g_cpu = torch.Generator() 117 | g_cpu.manual_seed(1234) 118 | 119 | min_num_nodes = 100 120 | max_num_nodes = 500 121 | available_graphs = [] 122 | for idx in np.arange(1, data_graph_indicator.max() + 1): 123 | node_idx = data_graph_indicator == idx 124 | if node_idx.sum() >= min_num_nodes and node_idx.sum() <= max_num_nodes: 125 | available_graphs.append(idx) 126 | available_graphs = torch.Tensor(available_graphs) 127 | 128 | self.num_graphs = len(available_graphs) 129 | test_len = int(round(self.num_graphs * 0.2)) 130 | train_len = int(round((self.num_graphs - test_len) * 0.8)) 131 | val_len = self.num_graphs - train_len - test_len 132 | 133 | train_indices, val_indices, test_indices = random_split( 134 | available_graphs, 135 | [train_len, val_len, test_len], 136 | generator=torch.Generator().manual_seed(1234), 137 | ) 138 | print(f"Dataset sizes: train {train_len}, val {val_len}, test {test_len}") 139 | 140 | print(f"Train indices: {train_indices}") 141 | print(f"Val indices: {val_indices}") 142 | print(f"Test indices: {test_indices}") 143 | 144 | torch.save(train_indices, self.raw_paths[0]) 145 | torch.save(val_indices, self.raw_paths[1]) 146 | torch.save(test_indices, self.raw_paths[2]) 147 | 148 | def process(self): 149 | indices = torch.load( 150 | os.path.join(self.raw_dir, "{}_indices.pt".format(self.split)) 151 | ) 152 | data_adj = ( 153 | torch.Tensor( 154 | np.loadtxt(os.path.join(self.raw_dir, "DD_A.txt"), delimiter=",") 155 | ).long() 156 | - 1 157 | ) 
158 | data_node_label = ( 159 | torch.Tensor( 160 | np.loadtxt( 161 | os.path.join(self.raw_dir, "DD_node_labels.txt"), delimiter="," 162 | ) 163 | ).long() 164 | - 1 165 | ) 166 | data_graph_indicator = torch.Tensor( 167 | np.loadtxt( 168 | os.path.join(self.raw_dir, "DD_graph_indicator.txt"), delimiter="," 169 | ) 170 | ).long() 171 | data_graph_types = ( 172 | torch.Tensor( 173 | np.loadtxt( 174 | os.path.join(self.raw_dir, "DD_graph_labels.txt"), delimiter="," 175 | ) 176 | ).long() 177 | - 1 178 | ) 179 | data_list = [] 180 | 181 | # get information 182 | self.num_node_type = data_node_label.max() + 1 183 | self.num_edge_type = 2 184 | self.num_graph_type = data_graph_types.max() + 1 185 | print(f"Number of node types: {self.num_node_type}") 186 | print(f"Number of edge types: {self.num_edge_type}") 187 | print(f"Number of graph types: {self.num_graph_type}") 188 | 189 | for idx in indices: 190 | offset = torch.where(data_graph_indicator == idx)[0].min() 191 | node_idx = data_graph_indicator == idx 192 | # perm = torch.randperm(node_idx.sum()).long() 193 | # reverse_perm = torch.sort(perm)[1] 194 | # nodes = data_node_label[node_idx][perm].long() 195 | # edge_idx = node_idx[data_adj[:, 0]] 196 | # edge_index = data_adj[edge_idx] - offset 197 | # edge_index[:, 0] = reverse_perm[edge_index[:, 0]] 198 | # edge_index[:, 1] = reverse_perm[edge_index[:, 1]] 199 | # nodes = data_node_label[node_idx].float() 200 | nodes = torch.ones(len(data_node_label[node_idx]), 1).float() 201 | edge_idx = node_idx[data_adj[:, 0]] 202 | edge_index = data_adj[edge_idx] - offset 203 | edge_attr = torch.ones_like(edge_index[:, 0]).float() 204 | edge_index, edge_attr = remove_self_loops(edge_index.T, edge_attr) 205 | data = torch_geometric.data.Data( 206 | x=nodes, 207 | edge_index=edge_index, 208 | edge_attr=edge_attr.unsqueeze(-1), 209 | n_nodes=nodes.shape[0], 210 | ) 211 | 212 | if self.pre_filter is not None and not self.pre_filter(data): 213 | continue 214 | if self.pre_transform 
is not None: 215 | data = self.pre_transform(data) 216 | 217 | data_list.append(data) 218 | 219 | torch.save(self.collate(data_list), self.processed_paths[0]) 220 | 221 | 222 | class ProteinDataModule(AbstractDataModule): 223 | def __init__(self, cfg): 224 | self.cfg = cfg 225 | self.dataset_name = self.cfg.dataset.name 226 | self.datadir = cfg.dataset.datadir 227 | base_path = pathlib.Path(get_original_cwd()).parents[0] 228 | root_path = os.path.join(base_path, cfg.dataset.datadir) 229 | transform = RemoveYTransform() 230 | 231 | datasets = { 232 | "train": ProteinDataset( 233 | root=root_path, 234 | transform=transform, 235 | split="train", 236 | ), 237 | "val": ProteinDataset( 238 | root=root_path, 239 | transform=transform, 240 | split="val", 241 | ), 242 | "test": ProteinDataset( 243 | root=root_path, 244 | transform=transform, 245 | split="test", 246 | ), 247 | } 248 | 249 | super().__init__(cfg, datasets) 250 | self.inner = self.train_dataset 251 | 252 | 253 | class ProteinInfos(AbstractDatasetInfos): 254 | def __init__(self, datamodule): 255 | self.is_molecular = False 256 | self.spectre = True 257 | self.use_charge = False 258 | # self.datamodule = datamodule 259 | self.dataset_name = datamodule.inner.dataset_name 260 | self.n_nodes = datamodule.node_counts() 261 | self.node_types = datamodule.node_types() 262 | self.edge_types = datamodule.edge_counts() 263 | super().complete_infos(self.n_nodes, self.node_types) 264 | 265 | def to_one_hot(self, data): 266 | """ 267 | call in the beginning of data 268 | get the one_hot encoding for a charge beginning from -1 269 | """ 270 | data.charge = data.x.new_zeros((*data.x.shape[:-1], 0)) 271 | data.x = F.one_hot(data.x, num_classes=self.num_node_types).float() 272 | data.edge_attr = F.one_hot( 273 | data.edge_attr, num_classes=self.num_edge_types 274 | ).float() 275 | 276 | return data 277 | -------------------------------------------------------------------------------- /src/main.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import warnings 4 | 5 | import graph_tool 6 | import torch 7 | 8 | torch.cuda.empty_cache() 9 | import hydra 10 | from omegaconf import DictConfig 11 | import pytorch_lightning as pl 12 | from pytorch_lightning import Trainer 13 | from pytorch_lightning.callbacks import ModelCheckpoint 14 | from pytorch_lightning.utilities.warnings import PossibleUserWarning 15 | 16 | from src import utils 17 | from metrics.abstract_metrics import TrainAbstractMetricsDiscrete 18 | from graph_discrete_flow_model import GraphDiscreteFlowModel 19 | from models.extra_features import DummyExtraFeatures, ExtraFeatures 20 | 21 | 22 | warnings.filterwarnings("ignore", category=PossibleUserWarning) 23 | 24 | 25 | @hydra.main(version_base="1.3", config_path="../configs", config_name="config") 26 | def main(cfg: DictConfig): 27 | pl.seed_everything(cfg.train.seed) 28 | dataset_config = cfg["dataset"] 29 | 30 | if dataset_config["name"] in [ 31 | "sbm", 32 | "comm20", 33 | "planar", 34 | "tree", 35 | ]: 36 | from analysis.visualization import NonMolecularVisualization 37 | from datasets.spectre_dataset import ( 38 | SpectreGraphDataModule, 39 | SpectreDatasetInfos, 40 | ) 41 | from analysis.spectre_utils import ( 42 | PlanarSamplingMetrics, 43 | SBMSamplingMetrics, 44 | Comm20SamplingMetrics, 45 | TreeSamplingMetrics, 46 | ) 47 | 48 | datamodule = SpectreGraphDataModule(cfg) 49 | if dataset_config["name"] == "sbm": 50 | sampling_metrics = SBMSamplingMetrics(datamodule) 51 | elif dataset_config["name"] == "comm20": 52 | sampling_metrics = Comm20SamplingMetrics(datamodule) 53 | elif dataset_config["name"] == "planar": 54 | sampling_metrics = PlanarSamplingMetrics(datamodule) 55 | elif dataset_config["name"] == "tree": 56 | sampling_metrics = TreeSamplingMetrics(datamodule) 57 | else: 58 | raise NotImplementedError( 59 | f"Dataset {dataset_config['name']} not implemented" 60 | ) 
61 | 62 | dataset_infos = SpectreDatasetInfos(datamodule, dataset_config) 63 | 64 | train_metrics = TrainAbstractMetricsDiscrete() 65 | visualization_tools = NonMolecularVisualization(dataset_name=cfg.dataset.name) 66 | 67 | extra_features = ExtraFeatures( 68 | cfg.model.extra_features, 69 | cfg.model.rrwp_steps, 70 | dataset_info=dataset_infos, 71 | ) 72 | domain_features = DummyExtraFeatures() 73 | 74 | dataset_infos.compute_input_output_dims( 75 | datamodule=datamodule, 76 | extra_features=extra_features, 77 | domain_features=domain_features, 78 | ) 79 | 80 | 81 | elif dataset_config["name"] in ["qm9", "guacamol", "moses", "zinc"]: 82 | from metrics.molecular_metrics import ( 83 | TrainMolecularMetrics, 84 | SamplingMolecularMetrics, 85 | ) 86 | from metrics.molecular_metrics_discrete import TrainMolecularMetricsDiscrete 87 | from models.extra_features_molecular import ExtraMolecularFeatures 88 | from analysis.visualization import MolecularVisualization 89 | 90 | if "qm9" in dataset_config["name"]: 91 | from datasets import qm9_dataset 92 | 93 | datamodule = qm9_dataset.QM9DataModule(cfg) 94 | dataset_infos = qm9_dataset.QM9infos(datamodule=datamodule, cfg=cfg) 95 | dataset_smiles = qm9_dataset.get_smiles( 96 | cfg=cfg, 97 | datamodule=datamodule, 98 | dataset_infos=dataset_infos, 99 | evaluate_datasets=False, 100 | ) 101 | elif dataset_config["name"] == "guacamol": 102 | from datasets import guacamol_dataset 103 | 104 | datamodule = guacamol_dataset.GuacamolDataModule(cfg) 105 | dataset_infos = guacamol_dataset.Guacamolinfos(datamodule, cfg) 106 | dataset_smiles = guacamol_dataset.get_smiles( 107 | raw_dir=datamodule.train_dataset.raw_dir, 108 | filter_dataset=cfg.dataset.filter, 109 | ) 110 | 111 | elif dataset_config.name == "moses": 112 | from datasets import moses_dataset 113 | 114 | datamodule = moses_dataset.MosesDataModule(cfg) 115 | dataset_infos = moses_dataset.MOSESinfos(datamodule, cfg) 116 | dataset_smiles = moses_dataset.get_smiles( 117 | 
raw_dir=datamodule.train_dataset.raw_dir, 118 | filter_dataset=cfg.dataset.filter, 119 | ) 120 | elif "zinc" in dataset_config["name"]: 121 | from datasets import zinc_dataset 122 | 123 | datamodule = zinc_dataset.ZINCDataModule(cfg) 124 | dataset_infos = zinc_dataset.ZINCinfos(datamodule=datamodule, cfg=cfg) 125 | dataset_smiles = zinc_dataset.get_smiles( 126 | cfg=cfg, 127 | datamodule=datamodule, 128 | dataset_infos=dataset_infos, 129 | evaluate_datasets=False, 130 | ) 131 | else: 132 | raise ValueError("Dataset not implemented") 133 | 134 | extra_features = ExtraFeatures( 135 | cfg.model.extra_features, 136 | cfg.model.rrwp_steps, 137 | dataset_info=dataset_infos, 138 | ) 139 | domain_features = ExtraMolecularFeatures(dataset_infos=dataset_infos) 140 | 141 | dataset_infos.compute_input_output_dims( 142 | datamodule=datamodule, 143 | extra_features=extra_features, 144 | domain_features=domain_features, 145 | ) 146 | 147 | train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) 148 | 149 | # We do not evaluate novelty during training 150 | add_virtual_states = "absorbing" == cfg.model.transition 151 | sampling_metrics = SamplingMolecularMetrics( 152 | dataset_infos, dataset_smiles, cfg, add_virtual_states=add_virtual_states 153 | ) 154 | visualization_tools = MolecularVisualization( 155 | cfg.dataset.remove_h, dataset_infos=dataset_infos 156 | ) 157 | 158 | elif dataset_config["name"] == "tls": 159 | from datasets import tls_dataset 160 | from metrics.tls_metrics import TLSSamplingMetrics 161 | from analysis.visualization import NonMolecularVisualization 162 | 163 | datamodule = tls_dataset.TLSDataModule(cfg) 164 | dataset_infos = tls_dataset.TLSInfos(datamodule=datamodule) 165 | 166 | train_metrics = TrainAbstractMetricsDiscrete() 167 | extra_features = ( 168 | ExtraFeatures( 169 | cfg.model.extra_features, 170 | cfg.model.rrwp_steps, 171 | dataset_info=dataset_infos, 172 | ) 173 | if cfg.model.extra_features is not None 174 | else DummyExtraFeatures() 175 
| ) 176 | domain_features = DummyExtraFeatures() 177 | 178 | sampling_metrics = TLSSamplingMetrics(datamodule) 179 | 180 | visualization_tools = NonMolecularVisualization(dataset_name=cfg.dataset.name) 181 | 182 | dataset_infos.compute_input_output_dims( 183 | datamodule=datamodule, 184 | extra_features=extra_features, 185 | domain_features=domain_features, 186 | ) 187 | 188 | else: 189 | raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"])) 190 | 191 | dataset_infos.compute_reference_metrics( 192 | datamodule=datamodule, 193 | sampling_metrics=sampling_metrics, 194 | ) 195 | 196 | model_kwargs = { 197 | "dataset_infos": dataset_infos, 198 | "train_metrics": train_metrics, 199 | "sampling_metrics": sampling_metrics, 200 | "visualization_tools": visualization_tools, 201 | "extra_features": extra_features, 202 | "domain_features": domain_features, 203 | "test_labels": ( 204 | datamodule.test_labels 205 | if ("qm9" in cfg.dataset.name and cfg.general.conditional) 206 | else None 207 | ), 208 | } 209 | 210 | utils.create_folders(cfg) 211 | model = GraphDiscreteFlowModel(cfg=cfg, **model_kwargs) 212 | 213 | callbacks = [] 214 | if cfg.train.save_model: 215 | checkpoint_callback = ModelCheckpoint( 216 | dirpath=f"checkpoints/{cfg.general.name}", 217 | filename="{epoch}", 218 | save_top_k=-1, 219 | every_n_epochs=cfg.general.sample_every_val 220 | * cfg.general.check_val_every_n_epochs, 221 | ) 222 | callbacks.append(checkpoint_callback) 223 | 224 | if cfg.train.ema_decay > 0: 225 | ema_callback = utils.EMA(decay=cfg.train.ema_decay) 226 | callbacks.append(ema_callback) 227 | 228 | name = cfg.general.name 229 | if name == "debug": 230 | print("[WARNING]: Run is called 'debug' -- it will run with fast_dev_run. 
") 231 | 232 | use_gpu = cfg.general.gpus > 0 and torch.cuda.is_available() 233 | trainer = Trainer( 234 | gradient_clip_val=cfg.train.clip_grad, 235 | strategy="ddp_find_unused_parameters_true", # Needed to load old checkpoints 236 | accelerator="gpu" if use_gpu else "cpu", 237 | devices=cfg.general.gpus if use_gpu else 1, 238 | max_epochs=cfg.train.n_epochs, 239 | check_val_every_n_epoch=cfg.general.check_val_every_n_epochs, 240 | fast_dev_run=name == "debug", 241 | enable_progress_bar=False, 242 | callbacks=callbacks, 243 | log_every_n_steps=50 if name != "debug" else 1, 244 | logger=[], 245 | ) 246 | 247 | if not cfg.general.test_only and cfg.general.generated_path is None: 248 | trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) 249 | else: 250 | # Start by evaluating test_only_path 251 | trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) 252 | if cfg.general.evaluate_all_checkpoints: 253 | directory = pathlib.Path(cfg.general.test_only).parents[0] 254 | print("Directory:", directory) 255 | files_list = os.listdir(directory) 256 | for file in files_list: 257 | if ".ckpt" in file: 258 | ckpt_path = os.path.join(directory, file) 259 | if ckpt_path == cfg.general.test_only: 260 | continue 261 | print("Loading checkpoint", ckpt_path) 262 | trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path) 263 | 264 | 265 | if __name__ == "__main__": 266 | main() 267 | -------------------------------------------------------------------------------- /src/datasets/spectre_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import pickle as pkl 4 | import zipfile 5 | 6 | from networkx import to_numpy_array 7 | 8 | import torch 9 | from torch.utils.data import random_split 10 | import torch_geometric.utils 11 | from torch_geometric.data import InMemoryDataset, download_url 12 | 13 | from datasets.abstract_dataset import AbstractDataModule, 
class SpectreGraphDataset(InMemoryDataset):
    """One split of an unattributed graph benchmark, stored as a PyG in-memory dataset.

    ``download`` fetches the raw file for ``dataset_name``, converts every graph
    to a dense adjacency tensor and writes ``train.pt`` / ``val.pt`` / ``test.pt``
    into ``raw_dir``. ``process`` converts the raw file of the requested ``split``
    into collated ``torch_geometric`` ``Data`` objects.
    """

    # Raw file location for every supported ``dataset_name``.
    _RAW_URLS = {
        "sbm": "https://raw.githubusercontent.com/AndreasBergmeister/graph-generation/main/data/sbm.pkl",
        "planar": "https://raw.githubusercontent.com/AndreasBergmeister/graph-generation/main/data/planar.pkl",
        "tree": "https://raw.githubusercontent.com/AndreasBergmeister/graph-generation/main/data/tree.pkl",
        "comm20": "https://raw.githubusercontent.com/KarolisMart/SPECTRE/main/data/community_12_21_100.pt",
        "ego": "https://raw.githubusercontent.com/tufts-ml/graph-generation-EDGE/main/graphs/Ego.pkl",
        "imdb": "https://www.chrsmrrs.com/graphkerneldatasets/IMDB-BINARY.zip",
    }
    # Datasets whose pickle already contains "train" / "val" / "test" lists.
    _PRESPLIT_DATASETS = ("tree", "sbm", "planar")

    def __init__(
        self,
        dataset_name,
        split,
        root,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        """Build (downloading / processing if needed) the ``split`` of ``dataset_name``.

        Args:
            dataset_name: key selecting the raw source (see ``_RAW_URLS``).
            split: one of ``"train"``, ``"val"``, ``"test"``.
            root: dataset root directory handed to ``InMemoryDataset``.
            transform, pre_transform, pre_filter: forwarded to ``InMemoryDataset``.
        """
        self.sbm_file = "sbm_200.pt"
        self.planar_file = "planar_64_200.pt"
        self.comm20_file = "community_12_21_100.pt"
        # These attributes must exist before super().__init__(), which may
        # trigger download() / process().
        self.dataset_name = dataset_name
        self.split = split
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
        self.num_graphs = len(self.data.n_nodes)

    @property
    def raw_file_names(self):
        """Names of the raw split files written by ``download``."""
        return ["train.pt", "val.pt", "test.pt"]

    @property
    def processed_file_names(self):
        """Name of the processed file for this instance's split."""
        return [self.split + ".pt"]

    @staticmethod
    def _to_adjacency(graph):
        """Return the dense adjacency of a networkx graph with a zeroed diagonal."""
        return torch.Tensor(to_numpy_array(graph)).fill_diagonal_(0)

    @staticmethod
    def _read_imdb_adjacencies(file_path):
        """Extract the IMDB-BINARY archive and return one dense adjacency per graph."""
        extract_dir = os.path.dirname(file_path)
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        # Step 1: read the global edge list (1-based node ids in the file).
        index_path = os.path.join(extract_dir, "IMDB-BINARY", "IMDB-BINARY_A.txt")
        edge_index = []
        with open(index_path, "r") as file:
            for line in file:
                int1, int2 = map(int, line.strip().split(","))
                edge_index.append([int1, int2])
        edge_index = torch.tensor(edge_index).t().contiguous() - 1

        # Step 2: read the graph id of every node (1-based in the file).
        index_path = os.path.join(
            extract_dir,
            "IMDB-BINARY",
            "IMDB-BINARY_graph_indicator.txt",
        )
        graph_indicator = []
        with open(index_path, "r") as file:
            for line in file:
                graph_indicator.append(int(line.strip()))
        graph_indicator = torch.tensor(graph_indicator) - 1

        # Step 3: cut the global edge list into one adjacency matrix per graph.
        num_graphs = graph_indicator.max().item() + 1
        adjs = []
        for i in range(num_graphs):
            node_mask = graph_indicator == i
            n_node = node_mask.sum().item()
            edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
            # Shift node ids so the first node of the graph becomes index 0.
            ptr = torch.where(node_mask)[0][0]
            edges = edge_index[:, edge_mask] - ptr
            adj = torch.zeros(n_node, n_node)
            adj[edges[0], edges[1]] = 1
            adj[edges[1], edges[0]] = 1
            adjs.append(adj)
        return adjs

    def download(self):
        """Download the raw graphs and save them as three lists of dense adjacencies.

        Raises:
            ValueError: if ``dataset_name`` has no known raw source.
        """
        try:
            raw_url = self._RAW_URLS[self.dataset_name]
        except KeyError:
            raise ValueError(f"Unknown dataset {self.dataset_name}") from None
        file_path = download_url(raw_url, self.raw_dir)

        if self.dataset_name in self._PRESPLIT_DATASETS:
            # NOTE(security): unpickling executes arbitrary code; these files come
            # from third-party repositories and are trusted as-is.
            with open(file_path, "rb") as f:
                dataset = pkl.load(f)
            train_data = [self._to_adjacency(graph) for graph in dataset["train"]]
            val_data = [self._to_adjacency(graph) for graph in dataset["val"]]
            test_data = [self._to_adjacency(graph) for graph in dataset["test"]]
        else:
            if self.dataset_name == "ego":
                # NOTE(security): same pickle caveat as above.
                with open(file_path, "rb") as f:
                    networks = pkl.load(f)
                adjs = [self._to_adjacency(network) for network in networks]
            elif self.dataset_name == "imdb":
                adjs = self._read_imdb_adjacencies(file_path)
            else:
                (
                    adjs,
                    eigvals,
                    eigvecs,
                    n_nodes,
                    max_eigval,
                    min_eigval,
                    same_sample,
                    n_max,
                ) = torch.load(file_path)

            g_cpu = torch.Generator().manual_seed(1234)
            if self.dataset_name in ("ego", "imdb"):
                self.num_graphs = len(adjs)
            else:
                # NOTE(review): hard-coded size. If the file holds fewer graphs,
                # permutation indices >= len(adjs) simply select nothing. Left
                # unchanged so existing splits stay reproducible — confirm intent.
                self.num_graphs = 200

            if self.dataset_name == "ego":
                test_len = int(round(self.num_graphs * 0.2))
                train_len = int(round(self.num_graphs * 0.8))
                val_len = int(round(self.num_graphs * 0.2))
                indices = torch.randperm(self.num_graphs, generator=g_cpu)
                print(
                    f"Dataset sizes: train {train_len}, val {val_len}, test {test_len}"
                )
                train_indices = indices[:train_len]
                # NOTE(review): validation is drawn from the head of the training
                # indices, i.e. it overlaps train — presumably intentional; confirm.
                val_indices = indices[:val_len]
                test_indices = indices[train_len:]
            else:
                test_len = int(round(self.num_graphs * 0.2))
                train_len = int(round((self.num_graphs - test_len) * 0.8))
                val_len = self.num_graphs - train_len - test_len
                indices = torch.randperm(self.num_graphs, generator=g_cpu)
                print(
                    f"Dataset sizes: train {train_len}, val {val_len}, test {test_len}"
                )
                train_indices = indices[:train_len]
                val_indices = indices[train_len : train_len + val_len]
                test_indices = indices[train_len + val_len :]

            # Python sets give O(1) membership; `i in tensor` scans the whole
            # tensor for every graph.
            train_set = set(train_indices.tolist())
            val_set = set(val_indices.tolist())
            test_set = set(test_indices.tolist())
            train_data = [adj for i, adj in enumerate(adjs) if i in train_set]
            val_data = [adj for i, adj in enumerate(adjs) if i in val_set]
            test_data = [adj for i, adj in enumerate(adjs) if i in test_set]

        torch.save(train_data, self.raw_paths[0])
        torch.save(val_data, self.raw_paths[1])
        torch.save(test_data, self.raw_paths[2])

    def process(self):
        """Convert the raw adjacencies of this split into a collated PyG dataset."""
        file_idx = {"train": 0, "val": 1, "test": 2}
        raw_dataset = torch.load(self.raw_paths[file_idx[self.split]])

        data_list = []
        for adj in raw_dataset:
            n = adj.shape[-1]
            # Every node gets the same single feature; y is an empty label.
            X = torch.ones(n, 1, dtype=torch.float)
            y = torch.zeros([1, 0]).float()
            edge_index, _ = torch_geometric.utils.dense_to_sparse(adj)
            # Two edge channels; every stored edge is marked in channel 1.
            edge_attr = torch.zeros(edge_index.shape[-1], 2, dtype=torch.float)
            edge_attr[:, 1] = 1
            num_nodes = n * torch.ones(1, dtype=torch.long)
            data = torch_geometric.data.Data(
                x=X, edge_index=edge_index, edge_attr=edge_attr, y=y, n_nodes=num_nodes
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])


class SpectreGraphDataModule(AbstractDataModule):
    """Data module bundling the train / val / test ``SpectreGraphDataset`` splits."""

    def __init__(self, cfg, n_graphs=200):
        """Create the three splits under ``cfg.dataset.datadir``.

        Args:
            cfg: configuration; ``cfg.dataset.name`` and ``cfg.dataset.datadir`` are read.
            n_graphs: unused here; kept for interface compatibility.
        """
        self.cfg = cfg
        self.datadir = cfg.dataset.datadir
        # The data directory is resolved relative to two levels above this file.
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, self.datadir)

        datasets = {
            split: SpectreGraphDataset(
                dataset_name=self.cfg.dataset.name, split=split, root=root_path
            )
            for split in ("train", "val", "test")
        }

        train_len = len(datasets["train"].data.n_nodes)
        val_len = len(datasets["val"].data.n_nodes)
        test_len = len(datasets["test"].data.n_nodes)
        print(f"Dataset sizes: train {train_len}, val {val_len}, test {test_len}")

        super().__init__(cfg, datasets)
        self.inner = self.train_dataset

    def __getitem__(self, item):
        """Index into the training dataset."""
        return self.inner[item]
class MolecularVisualization:
    """Render sampled molecular graphs (single molecules and sampling chains) with RDKit."""

    def __init__(self, remove_h, dataset_infos):
        self.remove_h = remove_h
        self.dataset_infos = dataset_infos

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """
        Build an RDKit molecule from one graph.
        node_list: atom type index per node; entries equal to -1 are skipped
        adjacency_matrix: bond type per node pair (1 single, 2 double, 3 triple, 4 aromatic)
        """
        # maps the integer atom type to its element symbol
        decoder = self.dataset_infos.atom_decoder

        # bond codes in the order they are tested
        bond_kinds = (
            (1, Chem.rdchem.BondType.SINGLE),
            (2, Chem.rdchem.BondType.DOUBLE),
            (3, Chem.rdchem.BondType.TRIPLE),
            (4, Chem.rdchem.BondType.AROMATIC),
        )

        editable = Chem.RWMol()

        # graph node index -> atom index inside the RDKit molecule
        atom_index = {}
        for node, atom_type in enumerate(node_list):
            if atom_type == -1:
                continue
            atom_index[node] = editable.AddAtom(Chem.Atom(decoder[int(atom_type)]))

        for ix, row in enumerate(adjacency_matrix):
            for iy, bond in enumerate(row):
                # the matrix is symmetric: only the strict upper triangle is read
                if iy <= ix:
                    continue
                bond_type = None
                for code, kind in bond_kinds:
                    if bond == code:
                        bond_type = kind
                        break
                if bond_type is None:
                    continue
                editable.AddBond(atom_index[ix], atom_index[iy], bond_type)

        try:
            result = editable.GetMol()
        except rdkit.Chem.KekulizeException:
            print("Can't kekulize molecule")
            result = None
        return result

    def visualize(
        self, path: str, molecules: list, num_molecules_to_visualize: int, log="graph"
    ):
        """Draw the first molecules of ``molecules`` as PNG files under ``path``.

        Each entry of ``molecules`` is indexed as ``(nodes, adjacency)`` and both
        parts are converted with ``.numpy()``. Images are also logged to wandb
        when a run is active and ``log`` is not None.
        """
        # make sure the output folder exists
        if not os.path.exists(path):
            os.makedirs(path)

        available = len(molecules)
        print(f"Visualizing {num_molecules_to_visualize} of {available}")
        if num_molecules_to_visualize > available:
            print(f"Shortening to {available}")
            num_molecules_to_visualize = available

        for idx in range(num_molecules_to_visualize):
            nodes, adjacency = molecules[idx][0], molecules[idx][1]
            file_path = os.path.join(path, "molecule_{}.png".format(idx))
            mol = self.mol_from_graphs(nodes.numpy(), adjacency.numpy())
            try:
                Draw.MolToFile(mol, file_path)
                if wandb.run and log is not None:
                    print(f"Saving {file_path} to wandb")
                    wandb.log({log: wandb.Image(file_path)}, commit=True)
            except rdkit.Chem.KekulizeException:
                print("Can't kekulize molecule")

    def visualize_chain(self, path, nodes_list, adjacency_matrix, times, trainer=None):
        """Draw every step of a sampling chain, then assemble a GIF and a grid image.

        All frames reuse the 2D atom coordinates computed for the last molecule
        of the chain. Returns the list of RDKit molecules, one per frame.
        """
        RDLogger.DisableLog("rdApp.*")
        n_frames = nodes_list.shape[0]

        # one RDKit molecule per chain step
        mols = []
        for step in range(n_frames):
            mols.append(self.mol_from_graphs(nodes_list[step], adjacency_matrix[step]))

        # 2D layout of the final molecule, used as reference for every frame
        last_mol = mols[-1]
        AllChem.Compute2DCoords(last_mol)
        coords = []
        for atom_idx, _ in enumerate(last_mol.GetAtoms()):
            p = last_mol.GetConformer().GetAtomPosition(atom_idx)
            coords.append((p.x, p.y, p.z))

        # move the atoms of every frame onto the reference layout
        for mol in mols:
            AllChem.Compute2DCoords(mol)
            conformer = mol.GetConformer()
            for atom_idx, _ in enumerate(mol.GetAtoms()):
                conformer.SetAtomPosition(atom_idx, Point3D(*coords[atom_idx]))

        # one PNG per frame
        frame_files = []
        for frame in range(n_frames):
            frame_file = os.path.join(path, "fram_{}.png".format(frame))
            Draw.MolToFile(
                mols[frame],
                frame_file,
                size=(300, 300),
                legend=f"t = {times[frame]:.2f}",
            )
            frame_files.append(frame_file)

        # GIF next to `path`, holding the final frame for 10 extra frames
        chain_name = path.split("/")[-1]
        frames = [imageio.imread(frame_file) for frame_file in frame_files]
        frames.extend([frames[-1]] * 10)
        gif_path = os.path.join(os.path.dirname(path), "{}.gif".format(chain_name))
        imageio.mimsave(gif_path, frames, subrectangles=True, duration=200)

        if wandb.run:
            print(f"Saving {gif_path} to wandb")
            wandb.log(
                {"chain": wandb.Video(gif_path, fps=5, format="gif")}, commit=True
            )

        # overview grid of all frames
        try:
            grid = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
            grid.save(os.path.join(path, "{}_grid_image.png".format(chain_name)))
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
        return mols
def visualize_non_molecule(
    self,
    graph,
    pos,
    path,
    iterations=100,
    node_size=100,
    largest_component=False,
    time=None,
):
    """Draw ``graph`` to ``path``, colouring nodes by a Laplacian eigenvector.

    Args:
        graph: networkx graph to draw.
        pos: node positions; when None a spring layout is computed.
        path: output image file, passed to ``plt.savefig``.
        iterations: iterations of the spring layout (only used when ``pos`` is None).
        node_size: node marker size forwarded to ``nx.draw``.
        largest_component: if True, draw only the largest connected component.
        time: optional value rendered as ``t = ...`` below the graph.
    """
    if largest_component:
        CGs = [graph.subgraph(c) for c in nx.connected_components(graph)]
        CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
        graph = CGs[0]

    # Plot the graph structure with colors
    if pos is None:
        pos = nx.spring_layout(graph, iterations=iterations)

    # Node colours come from the second eigenvector column of the normalized
    # Laplacian (eigh returns eigenvalues in ascending order).
    _, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
    if U.shape[1] > 1:
        node_colors = U[:, 1]
    else:
        # A single-node graph (e.g. the largest component of an edgeless
        # sample) has no second eigenvector; use a constant colour instead of
        # failing with IndexError.
        node_colors = np.zeros(U.shape[0])
    vmin, vmax = np.min(node_colors), np.max(node_colors)
    # Symmetric colour range so that zero maps to the middle of the colormap.
    m = max(np.abs(vmin), vmax)
    vmin, vmax = -m, m

    plt.figure()
    nx.draw(
        graph,
        pos,
        font_size=5,
        node_size=node_size,
        with_labels=False,
        node_color=node_colors,
        cmap=plt.cm.coolwarm,
        vmin=vmin,
        vmax=vmax,
        edge_color="grey",
    )
    if time is not None:
        plt.text(
            0.5,
            0.05,  # place below the graph
            f"t = {time:.2f}",
            ha="center",
            va="center",
            transform=plt.gcf().transFigure,
            fontsize=16,
        )

    plt.tight_layout()
    plt.savefig(path)
    plt.close("all")
file_path = os.path.join(path, "graph_{}.png".format(i)) 246 | 247 | if self.is_tls: 248 | cg = CellGraph.from_dense_graph(graphs[i]) 249 | cg.plot_graph(save_path=file_path, has_legend=True) 250 | else: 251 | graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy()) 252 | self.visualize_non_molecule(graph=graph, pos=None, path=file_path) 253 | 254 | im = plt.imread(file_path) 255 | if wandb.run and log is not None: 256 | wandb.log({log: [wandb.Image(im, caption=file_path)]}) 257 | 258 | def visualize_chain(self, path, nodes_list, adjacency_matrix, times): 259 | 260 | graphs = [] 261 | for i in range(nodes_list.shape[0]): 262 | if self.is_tls: 263 | graphs.append( 264 | CellGraph.from_dense_graph((nodes_list[i], adjacency_matrix[i])) 265 | ) 266 | else: 267 | graphs.append(self.to_networkx(nodes_list[i], adjacency_matrix[i])) 268 | 269 | # find the coordinates of atoms in the final molecule 270 | final_graph = graphs[-1] 271 | final_pos = nx.spring_layout(final_graph, seed=0) 272 | 273 | # draw gif 274 | save_paths = [] 275 | num_frams = nodes_list.shape[0] 276 | 277 | for frame in range(num_frams): 278 | file_name = os.path.join(path, "fram_{}.png".format(frame)) 279 | if self.is_tls: 280 | if not graphs[frame].get_pos(): # The last one already has a pos 281 | graphs[frame].set_pos(pos=final_pos) 282 | graphs[frame].plot_graph( 283 | save_path=file_name, 284 | has_legend=False, 285 | verbose=False, 286 | time=times[frame], 287 | ) 288 | else: 289 | self.visualize_non_molecule( 290 | graph=graphs[frame], 291 | pos=final_pos, 292 | path=file_name, 293 | time=times[frame], 294 | ) 295 | save_paths.append(file_name) 296 | 297 | imgs = [imageio.imread(fn) for fn in save_paths] 298 | gif_path = os.path.join( 299 | os.path.dirname(path), "{}.gif".format(path.split("/")[-1]) 300 | ) 301 | imgs.extend([imgs[-1]] * 10) 302 | imageio.mimsave(gif_path, imgs, subrectangles=True, duration=200) 303 | if wandb.run: 304 | wandb.log( 305 | {"chain": 
class RateMatrixDesigner:
    """Assemble the rate matrices used to sample node and edge transitions.

    The final rate is the sum of three terms, R^* + R^db + R^tg, computed
    separately for node features (X) and edge features (E) and then cleaned
    by ``stabilize_rate_matrix``.
    """

    def __init__(self, rdb, rdb_crit, eta, omega, limit_dist):
        """
        Args:
            rdb: design of R^db: "general", "marginal", "column" or "entry".
            rdb_crit: criterion used by the "column" and "entry" designs.
            eta: stochasticity, scales R^db.
            omega: target guidance, scales R^tg.
            limit_dist: object exposing ``X`` and ``E``; their lengths give the
                number of node and edge classes.
        """
        self.omega = omega  # target guidance
        self.eta = eta  # stochasticity
        # Different designs of R^db
        self.rdb = rdb
        self.rdb_crit = rdb_crit
        self.limit_dist = limit_dist
        self.num_classes_X = len(self.limit_dist.X)
        self.num_classes_E = len(self.limit_dist.E)

        print(
            f"RateMatrixDesigner: rdb={rdb}, rdb_crit={rdb_crit}, eta={eta}, omega={omega}"
        )

    def compute_graph_rate_matrix(self, t, node_mask, G_t, G_1_pred):
        """Return the stabilized rate matrices ``(R_t_X, R_t_E)`` for the current graph.

        Args:
            t: current time, forwarded to ``p_xt_g_x1``.
            node_mask: node mask forwarded to the discrete sampler.
            G_t: pair ``(X_t, E_t)``; the current labels are their argmax over
                the last dimension.
            G_1_pred: pair ``(X_1_pred, E_1_pred)`` of predictions from which a
                clean graph is sampled.
        """
        X_t, E_t = G_t
        X_1_pred, E_1_pred = G_1_pred

        X_t_label = X_t.argmax(-1, keepdim=True)
        E_t_label = E_t.argmax(-1, keepdim=True)
        sampled_G_1 = flow_matching_utils.sample_discrete_features(
            X_1_pred,
            E_1_pred,
            node_mask=node_mask,
        )
        X_1_sampled = sampled_G_1.X
        E_1_sampled = sampled_G_1.E

        dfm_variables = self.compute_dfm_variables(
            t, X_t_label, E_t_label, X_1_sampled, E_1_sampled
        )

        Rstar_t_X, Rstar_t_E = self.compute_Rstar(dfm_variables)

        Rdb_t_X, Rdb_t_E = self.compute_RDB(
            X_t_label,
            E_t_label,
            X_1_pred,
            E_1_pred,
            X_1_sampled,
            E_1_sampled,
            node_mask,
            t,
            dfm_variables,
        )

        Rtg_t_X, Rtg_t_E = self.compute_R_tg(
            X_1_sampled,
            E_1_sampled,
            X_t_label,
            E_t_label,
            dfm_variables,
        )

        # sum to get the final R_t_X and R_t_E
        R_t_X = Rstar_t_X + Rdb_t_X + Rtg_t_X
        R_t_E = Rstar_t_E + Rdb_t_E + Rtg_t_E

        # Stabilize rate matrices
        R_t_X, R_t_E = self.stabilize_rate_matrix(R_t_X, R_t_E, dfm_variables)

        return R_t_X, R_t_E

    def compute_dfm_variables(self, t, X_t_label, E_t_label, X_1_sampled, E_1_sampled):
        """Gather p(x_t | x_1), its time derivative and support sizes into a dict.

        Returns a dict with the full distributions (``pt_vals_*``,
        ``dt_p_vals_*``), their values at the current labels (``*_at_Xt``,
        ``*_at_Et``) and ``Z_t_*``, the number of non-zero entries of
        ``pt_vals_*`` along the class dimension.
        """
        dt_p_vals_X, dt_p_vals_E = dt_p_xt_g_x1(
            X_1_sampled,
            E_1_sampled,
            self.limit_dist,
        )  # (bs, n, dx), (bs, n, n, de)

        dt_p_vals_at_Xt = dt_p_vals_X.gather(-1, X_t_label).squeeze(-1)  # (bs, n, )
        dt_p_vals_at_Et = dt_p_vals_E.gather(-1, E_t_label).squeeze(-1)  # (bs, n, n, )

        pt_vals_X, pt_vals_E = p_xt_g_x1(
            X_1_sampled,
            E_1_sampled,
            t,
            self.limit_dist,
        )  # (bs, n, dx), (bs, n, n, de)

        pt_vals_at_Xt = pt_vals_X.gather(-1, X_t_label).squeeze(-1)  # (bs, n, )
        pt_vals_at_Et = pt_vals_E.gather(-1, E_t_label).squeeze(-1)  # (bs, n, n, )

        Z_t_X = torch.count_nonzero(pt_vals_X, dim=-1)  # (bs, n)
        Z_t_E = torch.count_nonzero(pt_vals_E, dim=-1)  # (bs, n, n)

        return {
            "pt_vals_X": pt_vals_X,
            "pt_vals_E": pt_vals_E,
            "pt_vals_at_Xt": pt_vals_at_Xt,
            "pt_vals_at_Et": pt_vals_at_Et,
            "dt_p_vals_X": dt_p_vals_X,
            "dt_p_vals_E": dt_p_vals_E,
            "dt_p_vals_at_Xt": dt_p_vals_at_Xt,
            "dt_p_vals_at_Et": dt_p_vals_at_Et,
            "Z_t_X": Z_t_X,
            "Z_t_E": Z_t_E,
        }

    def compute_Rstar(self, dfm_variables):
        """Return R^*: relu(dt_p(j) - dt_p(x_t)) / (Z_t * p(x_t)) for nodes and edges.

        Divisions by zero are not guarded here; the resulting NaN / inf values
        are removed later by ``stabilize_rate_matrix``.
        """
        # Unpack needed variables
        dt_p_vals_X = dfm_variables["dt_p_vals_X"]
        dt_p_vals_E = dfm_variables["dt_p_vals_E"]
        dt_p_vals_at_Xt = dfm_variables["dt_p_vals_at_Xt"]
        dt_p_vals_at_Et = dfm_variables["dt_p_vals_at_Et"]
        pt_vals_at_Xt = dfm_variables["pt_vals_at_Xt"]
        pt_vals_at_Et = dfm_variables["pt_vals_at_Et"]
        Z_t_X = dfm_variables["Z_t_X"]
        Z_t_E = dfm_variables["Z_t_E"]

        # Numerator of R_t^*
        inner_X = dt_p_vals_X - dt_p_vals_at_Xt[:, :, None]
        inner_E = dt_p_vals_E - dt_p_vals_at_Et[:, :, :, None]
        Rstar_t_numer_X = F.relu(inner_X)  # (bs, n, dx)
        Rstar_t_numer_E = F.relu(inner_E)  # (bs, n, n, de)

        # Denominator
        Rstar_t_denom_X = Z_t_X * pt_vals_at_Xt  # (bs, n)
        Rstar_t_denom_E = Z_t_E * pt_vals_at_Et  # (bs, n, n)

        # Final R^\star
        Rstar_t_X = Rstar_t_numer_X / Rstar_t_denom_X[:, :, None]  # (bs, n, dx)
        Rstar_t_E = Rstar_t_numer_E / Rstar_t_denom_E[:, :, :, None]  # (bs, n, n, de)

        return Rstar_t_X, Rstar_t_E

    def compute_RDB(
        self,
        X_t_label,
        E_t_label,
        X_1_pred,
        E_1_pred,
        X_1_sampled,
        E_1_sampled,
        node_mask,
        t,
        dfm_variables,
    ):
        """Return the stochastic term R^db = p(. | x_1) * mask * eta.

        The mask depends on ``self.rdb``:
          * "general": all ones;
          * "marginal": classes whose limit probability exceeds that of x_t;
          * "column": one column chosen by ``self.rdb_crit`` (all columns where
            the chosen column equals x_t);
          * "entry": exchange between x_1 and one fixed state chosen by
            ``self.rdb_crit`` ("abs_state": last class, "first": class 0).

        Raises:
            NotImplementedError: for an unknown ``rdb`` / ``rdb_crit``.
        """
        # unpack needed variables
        pt_vals_X = dfm_variables["pt_vals_X"]
        pt_vals_E = dfm_variables["pt_vals_E"]

        # dimensions
        dx = pt_vals_X.shape[-1]
        de = pt_vals_E.shape[-1]

        # build mask for Rdb
        if self.rdb == "general":
            x_mask = torch.ones_like(pt_vals_X)
            e_mask = torch.ones_like(pt_vals_E)

        elif self.rdb == "marginal":
            x_limit = self.limit_dist.X
            e_limit = self.limit_dist.E

            Xt_marginal = x_limit[X_t_label]
            Et_marginal = e_limit[E_t_label]

            x_mask = x_limit.repeat(X_t_label.shape[0], X_t_label.shape[1], 1)
            e_mask = e_limit.repeat(
                E_t_label.shape[0], E_t_label.shape[1], E_t_label.shape[2], 1
            )

            x_mask = x_mask > Xt_marginal
            e_mask = e_mask > Et_marginal

        elif self.rdb == "column":
            # Get column idx to pick
            if self.rdb_crit == "max_marginal":
                x_column_idxs = self.limit_dist.X.argmax(keepdim=True).expand(
                    X_t_label.shape
                )
                e_column_idxs = self.limit_dist.E.argmax(keepdim=True).expand(
                    E_t_label.shape
                )
            elif self.rdb_crit == "x_t":
                x_column_idxs = X_t_label
                e_column_idxs = E_t_label
            elif self.rdb_crit == "abs_state":
                x_column_idxs = torch.ones_like(X_t_label) * (dx - 1)
                e_column_idxs = torch.ones_like(E_t_label) * (de - 1)
            elif self.rdb_crit == "p_x1_g_xt":
                x_column_idxs = X_1_pred.argmax(dim=-1, keepdim=True)
                e_column_idxs = E_1_pred.argmax(dim=-1, keepdim=True)
            elif self.rdb_crit == "x_1":  # as in paper, uniform
                x_column_idxs = X_1_sampled.unsqueeze(-1)
                e_column_idxs = E_1_sampled.unsqueeze(-1)
            elif self.rdb_crit == "p_xt_g_x1":
                x_column_idxs = pt_vals_X.argmax(dim=-1, keepdim=True)
                e_column_idxs = pt_vals_E.argmax(dim=-1, keepdim=True)
            elif self.rdb_crit == "xhat_t":
                sampled_1_hat = flow_matching_utils.sample_discrete_features(
                    pt_vals_X,
                    pt_vals_E,
                    node_mask=node_mask,
                )
                x_column_idxs = sampled_1_hat.X.unsqueeze(-1)
                e_column_idxs = sampled_1_hat.E.unsqueeze(-1)
            else:
                raise NotImplementedError

            # create mask based on columns picked; where the picked column is
            # the current state, the whole row is enabled instead
            x_mask = F.one_hot(x_column_idxs.squeeze(-1), num_classes=dx)
            x_mask[(x_column_idxs == X_t_label).squeeze(-1)] = 1.0
            e_mask = F.one_hot(e_column_idxs.squeeze(-1), num_classes=de)
            e_mask[(e_column_idxs == E_t_label).squeeze(-1)] = 1.0

        elif self.rdb == "entry":
            # Index tensors are created on the device of the probabilities they
            # will index into (the class itself stores no device attribute).
            if self.rdb_crit == "abs_state":
                # select last index
                x_masked_idx = torch.tensor(dx - 1, device=pt_vals_X.device)
                e_masked_idx = torch.tensor(de - 1, device=pt_vals_E.device)
            elif self.rdb_crit == "first":  # here in all datasets it's the argmax
                # select first index
                x_masked_idx = torch.tensor(0, device=pt_vals_X.device)
                e_masked_idx = torch.tensor(0, device=pt_vals_E.device)
            else:
                raise NotImplementedError(
                    f"Not implemented rdb_crit for entry: {self.rdb_crit}"
                )

            x1_idxs = X_1_sampled.unsqueeze(-1)  # (bs, n, 1)
            e1_idxs = E_1_sampled.unsqueeze(-1)  # (bs, n, n, 1)

            # x_t == x_1: allow the jump to the fixed state;
            # x_t == fixed state: allow the jump to x_1 (applied last, so it
            # wins when both conditions hold).
            x_mask = torch.zeros_like(pt_vals_X)  # (bs, n, dx)
            xt_in_x1 = (X_t_label == x1_idxs).squeeze(-1)  # (bs, n)
            x_mask[xt_in_x1] = F.one_hot(x_masked_idx, num_classes=dx).float()
            xt_in_masked = (X_t_label == x_masked_idx).squeeze(-1)
            x_mask[xt_in_masked] = F.one_hot(
                x1_idxs.squeeze(-1), num_classes=dx
            ).float()[xt_in_masked]

            e_mask = torch.zeros_like(pt_vals_E)  # (bs, n, n, de)
            et_in_e1 = (E_t_label == e1_idxs).squeeze(-1)
            e_mask[et_in_e1] = F.one_hot(e_masked_idx, num_classes=de).float()
            et_in_masked = (E_t_label == e_masked_idx).squeeze(-1)
            e_mask[et_in_masked] = F.one_hot(
                e1_idxs.squeeze(-1), num_classes=de
            ).float()[et_in_masked]

        else:
            raise NotImplementedError(f"Not implemented rdb type: {self.rdb}")

        # stochastic rate matrix
        Rdb_t_X = pt_vals_X * x_mask * self.eta
        Rdb_t_E = pt_vals_E * e_mask * self.eta

        return Rdb_t_X, Rdb_t_E

    def compute_R_tg(
        self,
        X_1_sampled,
        E_1_sampled,
        X_t_label,
        E_t_label,
        dfm_variables,
    ):
        """Target guidance rate matrix.

        Puts rate ``omega / (Z_t * p(x_t))`` on the sampled clean class x_1
        wherever it differs from the current label, and zero elsewhere.
        """
        # Unpack needed variables
        pt_vals_at_Xt = dfm_variables["pt_vals_at_Xt"]
        pt_vals_at_Et = dfm_variables["pt_vals_at_Et"]
        Z_t_X = dfm_variables["Z_t_X"]
        Z_t_E = dfm_variables["Z_t_E"]

        # Numerator
        X1_onehot = F.one_hot(X_1_sampled, num_classes=self.num_classes_X).float()
        E1_onehot = F.one_hot(E_1_sampled, num_classes=self.num_classes_E).float()
        mask_X = X_1_sampled.unsqueeze(-1) != X_t_label
        mask_E = E_1_sampled.unsqueeze(-1) != E_t_label

        Rtg_t_numer_X = X1_onehot * self.omega * mask_X
        Rtg_t_numer_E = E1_onehot * self.omega * mask_E

        # Denominator
        denom_X = Z_t_X * pt_vals_at_Xt  # (bs, n)
        denom_E = Z_t_E * pt_vals_at_Et  # (bs, n, n)

        # Final R^TG
        Rtg_t_X = Rtg_t_numer_X / denom_X[:, :, None]
        Rtg_t_E = Rtg_t_numer_E / denom_E[:, :, :, None]

        return Rtg_t_X, Rtg_t_E

    def stabilize_rate_matrix(self, R_t_X, R_t_E, dfm_variables):
        """Zero out invalid, huge and unsupported entries of the rate matrices.

        Entries are set to 0 when they are NaN / infinite, larger than 1e5,
        when p(x_t | x_1) is 0 for the current state, or when p(j | x_1) is 0
        for the target class j.
        """
        # Unpack needed variables
        pt_vals_X = dfm_variables["pt_vals_X"]
        pt_vals_E = dfm_variables["pt_vals_E"]
        pt_vals_at_Xt = dfm_variables["pt_vals_at_Xt"]
        pt_vals_at_Et = dfm_variables["pt_vals_at_Et"]

        # protect to avoid NaN and too large values
        # (nan_to_num returns new tensors, so the in-place writes below do not
        # touch the caller's inputs)
        R_t_X = torch.nan_to_num(R_t_X, nan=0.0, posinf=0.0, neginf=0.0)
        R_t_E = torch.nan_to_num(R_t_E, nan=0.0, posinf=0.0, neginf=0.0)
        R_t_X[R_t_X > 1e5] = 0.0
        R_t_E[R_t_E > 1e5] = 0.0

        # Set p(x_t | x_1) = 0 cases to zero, which needs to be applied to Rdb
        # too. A boolean mask over the leading dimensions clears the whole
        # class row of each selected position.
        R_t_X[pt_vals_at_Xt == 0.0] = 0.0
        R_t_E[pt_vals_at_Et == 0.0] = 0.0

        # zero-out certain columns of R, which is implied in the computation of Rdb
        # if the probability of a place is 0, then we should not consider it in the R computation
        R_t_X[pt_vals_X == 0.0] = 0.0
        R_t_E[pt_vals_E == 0.0] = 0.0

        return R_t_X, R_t_E
class XEyTransformerLayer(nn.Module):
    """Transformer layer that jointly updates node, edge and global features.

    dx: node feature dimension
    de: edge feature dimension
    dy: global feature dimension
    n_head: number of attention heads
    dim_ffX / dim_ffE / dim_ffy: hidden sizes of the three feed-forward networks
    dropout: dropout probability, 0 to disable
    layer_norm_eps: eps value used by the layer normalizations
    """

    def __init__(
        self,
        dx: int,
        de: int,
        dy: int,
        n_head: int,
        dim_ffX: int = 2048,
        dim_ffE: int = 128,
        dim_ffy: int = 2048,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        kw = {"device": device, "dtype": dtype}
        super().__init__()

        self.self_attn = NodeEdgeBlock(dx, de, dy, n_head, **kw)

        # node branch
        self.linX1 = Linear(dx, dim_ffX, **kw)
        self.linX2 = Linear(dim_ffX, dx, **kw)
        self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.dropoutX1 = Dropout(dropout)
        self.dropoutX2 = Dropout(dropout)
        self.dropoutX3 = Dropout(dropout)

        # edge branch
        self.linE1 = Linear(de, dim_ffE, **kw)
        self.linE2 = Linear(dim_ffE, de, **kw)
        self.normE1 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.normE2 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.dropoutE1 = Dropout(dropout)
        self.dropoutE2 = Dropout(dropout)
        self.dropoutE3 = Dropout(dropout)

        # global branch
        self.lin_y1 = Linear(dy, dim_ffy, **kw)
        self.lin_y2 = Linear(dim_ffy, dy, **kw)
        self.norm_y1 = LayerNorm(dy, eps=layer_norm_eps, **kw)
        self.norm_y2 = LayerNorm(dy, eps=layer_norm_eps, **kw)
        self.dropout_y1 = Dropout(dropout)
        self.dropout_y2 = Dropout(dropout)
        self.dropout_y3 = Dropout(dropout)

        self.activation = F.relu

    def _feed_forward(self, h, expand, inner_drop, project, outer_drop):
        """Two-layer feed-forward network with dropout after the activation and the output."""
        hidden = self.activation(expand(h))
        return outer_drop(project(inner_drop(hidden)))

    def forward(self, X: Tensor, E: Tensor, y, node_mask: Tensor):
        """Run attention and the feed-forward networks, each with residual + layer norm.

        X: (bs, n, d)
        E: (bs, n, n, d)
        y: (bs, dy)
        node_mask: (bs, n), forwarded to the attention block
        Returns the updated X, E, y, with the same shapes.
        """
        attn_X, attn_E, attn_y = self.self_attn(X, E, y, node_mask=node_mask)

        # residual connection around the attention block, then layer norm
        X = self.normX1(X + self.dropoutX1(attn_X))
        E = self.normE1(E + self.dropoutE1(attn_E))
        y = self.norm_y1(y + self.dropout_y1(attn_y))

        # residual connection around each feed-forward network, then layer norm
        X = self.normX2(
            X
            + self._feed_forward(
                X, self.linX1, self.dropoutX2, self.linX2, self.dropoutX3
            )
        )
        E = self.normE2(
            E
            + self._feed_forward(
                E, self.linE1, self.dropoutE2, self.linE2, self.dropoutE3
            )
        )
        y = self.norm_y2(
            y
            + self._feed_forward(
                y, self.lin_y1, self.dropout_y2, self.lin_y2, self.dropout_y3
            )
        )

        return X, E, y
123 | self.df = int(dx / n_head) 124 | self.n_head = n_head 125 | 126 | # Attention 127 | self.q = Linear(dx, dx) 128 | self.k = Linear(dx, dx) 129 | self.v = Linear(dx, dx) 130 | 131 | # FiLM E to X 132 | self.e_add = Linear(de, dx) 133 | self.e_mul = Linear(de, dx) 134 | 135 | # FiLM y to E 136 | self.y_e_mul = Linear(dy, dx) # Warning: here it's dx and not de 137 | self.y_e_add = Linear(dy, dx) 138 | 139 | # FiLM y to X 140 | self.y_x_mul = Linear(dy, dx) 141 | self.y_x_add = Linear(dy, dx) 142 | 143 | # Process y 144 | self.y_y = Linear(dy, dy) 145 | self.x_y = Xtoy(dx, dy) 146 | self.e_y = Etoy(de, dy) 147 | 148 | # Output layers 149 | self.x_out = Linear(dx, dx) 150 | self.e_out = Linear(dx, de) 151 | self.y_out = nn.Sequential(nn.Linear(dy, dy), nn.ReLU(), nn.Linear(dy, dy)) 152 | 153 | def forward(self, X, E, y, node_mask): 154 | """ 155 | :param X: bs, n, d node features 156 | :param E: bs, n, n, d edge features 157 | :param y: bs, dy global features 158 | :param node_mask: bs, n 159 | :return: newX, newE, new_y with the same shape. 160 | """ 161 | bs, n, _ = X.shape 162 | x_mask = node_mask.unsqueeze(-1) # bs, n, 1 163 | e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1 164 | e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1 165 | 166 | # 1. Map X to keys and queries 167 | Q = self.q(X) * x_mask # (bs, n, dx) 168 | K = self.k(X) * x_mask # (bs, n, dx) 169 | flow_matching_utils.assert_correctly_masked(Q, x_mask) 170 | # 2. Reshape to (bs, n, n_head, df) with dx = n_head * df 171 | 172 | Q = Q.reshape((Q.size(0), Q.size(1), self.n_head, self.df)) 173 | K = K.reshape((K.size(0), K.size(1), self.n_head, self.df)) 174 | 175 | Q = Q.unsqueeze(2) # (bs, n, 1, n_head, df) 176 | K = K.unsqueeze(1) # (bs, 1, n, n_head, df) 177 | 178 | # Compute unnormalized attentions.
Y is (bs, n, n, n_head, df) 179 | Y = Q * K 180 | Y = Y / math.sqrt(Y.size(-1)) 181 | 182 | # # Compute the distance based on the Gaussian kernel 183 | # Y_exp = torch.exp( 184 | # -(Q - K) * (Q - K) / math.sqrt(Y.size(-1)) 185 | # ) # bs, n, n, n_head, df 186 | # Y = Y + Y_exp 187 | flow_matching_utils.assert_correctly_masked( 188 | Y, (e_mask1 * e_mask2).unsqueeze(-1) 189 | ) 190 | 191 | # Incorporate edge features to the self attention scores. 192 | E1 = self.e_mul(E) * e_mask1 * e_mask2 # bs, n, n, dx 193 | E1 = E1.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df)) 194 | E2 = self.e_add(E) * e_mask1 * e_mask2 # bs, n, n, dx 195 | E2 = E2.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df)) 196 | Y = Y * (E1 + 1) + E2 # (bs, n, n, n_head, df) 197 | 198 | # Incorporate y to E 199 | newE = Y.flatten(start_dim=3) # bs, n, n, dx 200 | ye1 = self.y_e_add(y).unsqueeze(1).unsqueeze(1) # bs, 1, 1, dx 201 | ye2 = self.y_e_mul(y).unsqueeze(1).unsqueeze(1) 202 | newE = ye1 + (ye2 + 1) * newE 203 | 204 | # Output E 205 | newE = self.e_out(newE) * e_mask1 * e_mask2 # bs, n, n, de 206 | flow_matching_utils.assert_correctly_masked(newE, e_mask1 * e_mask2) 207 | 208 | # Compute attentions.
attn is still (bs, n, n, n_head, df) 209 | softmax_mask = e_mask2.expand(-1, n, -1, self.n_head) # bs, n, n, n_head 210 | attn = masked_softmax(Y, softmax_mask, dim=2) # bs, n, n, n_head 211 | 212 | V = self.v(X) * x_mask # bs, n, dx 213 | V = V.reshape((V.size(0), V.size(1), self.n_head, self.df)) 214 | V = V.unsqueeze(1) # (bs, 1, n, n_head, df) 215 | 216 | # Compute weighted values 217 | weighted_V = attn * V 218 | weighted_V = weighted_V.sum(dim=2) 219 | 220 | # Send output to input dim 221 | weighted_V = weighted_V.flatten(start_dim=2) # bs, n, dx 222 | 223 | # Incorporate y to X 224 | yx1 = self.y_x_add(y).unsqueeze(1) 225 | yx2 = self.y_x_mul(y).unsqueeze(1) 226 | newX = yx1 + (yx2 + 1) * weighted_V 227 | 228 | # Output X 229 | newX = self.x_out(newX) * x_mask 230 | flow_matching_utils.assert_correctly_masked(newX, x_mask) 231 | 232 | # Process y based on X and E 233 | y = self.y_y(y) 234 | e_y = self.e_y(E) 235 | x_y = self.x_y(X) 236 | new_y = y + x_y + e_y 237 | new_y = self.y_out(new_y) # bs, dy 238 | 239 | # newX = self.dropout_X(newX) 240 | # newE = self.dropout_E(newE) 241 | 242 | return newX, newE, new_y 243 | 244 | 245 | class GraphTransformer(nn.Module): 246 | """ 247 | n_layers : int -- number of layers 248 | dims : dict -- contains dimensions for each feature type 249 | """ 250 | 251 | def __init__( 252 | self, 253 | n_layers: int, 254 | input_dims: dict, 255 | hidden_mlp_dims: dict, 256 | hidden_dims: dict, 257 | output_dims: dict, 258 | act_fn_in: nn.ReLU(), 259 | act_fn_out: nn.ReLU(), 260 | ): 261 | super().__init__() 262 | self.n_layers = n_layers 263 | self.out_dim_X = output_dims["X"] 264 | self.out_dim_E = output_dims["E"] 265 | self.out_dim_y = output_dims["y"] 266 | 267 | self.mlp_in_X = nn.Sequential( 268 | nn.Linear(input_dims["X"], hidden_mlp_dims["X"]), 269 | act_fn_in, 270 | nn.Linear(hidden_mlp_dims["X"], hidden_dims["dx"]), 271 | act_fn_in, 272 | ) 273 | 274 | self.mlp_in_E = nn.Sequential( 275 | nn.Linear(input_dims["E"],
hidden_mlp_dims["E"]), 276 | act_fn_in, 277 | nn.Linear(hidden_mlp_dims["E"], hidden_dims["de"]), 278 | act_fn_in, 279 | ) 280 | 281 | self.mlp_in_y = nn.Sequential( 282 | nn.Linear(input_dims["y"] + 64, hidden_mlp_dims["y"]), 283 | act_fn_in, 284 | nn.Linear(hidden_mlp_dims["y"], hidden_dims["dy"]), 285 | act_fn_in, 286 | ) 287 | 288 | self.tf_layers = nn.ModuleList( 289 | [ 290 | XEyTransformerLayer( 291 | dx=hidden_dims["dx"], 292 | de=hidden_dims["de"], 293 | dy=hidden_dims["dy"], 294 | n_head=hidden_dims["n_head"], 295 | dim_ffX=hidden_dims["dim_ffX"], 296 | dim_ffE=hidden_dims["dim_ffE"], 297 | ) 298 | for i in range(n_layers) 299 | ] 300 | ) 301 | 302 | self.mlp_out_X = nn.Sequential( 303 | nn.Linear(hidden_dims["dx"], hidden_mlp_dims["X"]), 304 | act_fn_out, 305 | nn.Linear(hidden_mlp_dims["X"], output_dims["X"]), 306 | ) 307 | 308 | self.mlp_out_E = nn.Sequential( 309 | nn.Linear(hidden_dims["de"], hidden_mlp_dims["E"]), 310 | act_fn_out, 311 | nn.Linear(hidden_mlp_dims["E"], output_dims["E"]), 312 | ) 313 | 314 | self.mlp_out_y = nn.Sequential( 315 | nn.Linear(hidden_dims["dy"], hidden_mlp_dims["y"]), 316 | act_fn_out, 317 | nn.Linear(hidden_mlp_dims["y"], output_dims["y"]), 318 | ) 319 | 320 | def forward(self, X, E, y, node_mask): 321 | bs, n = X.shape[0], X.shape[1] 322 | 323 | diag_mask = torch.eye(n) 324 | diag_mask = ~diag_mask.type_as(E).bool() 325 | diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1) 326 | 327 | X_to_out = X[..., : self.out_dim_X] 328 | E_to_out = E[..., : self.out_dim_E] 329 | y_to_out = y[..., : self.out_dim_y] 330 | 331 | new_E = self.mlp_in_E(E) 332 | new_E = (new_E + new_E.transpose(1, 2)) / 2 333 | 334 | # encode time steps additionally 335 | time_emb = timestep_embedding(y[:, -1].unsqueeze(-1), 64) 336 | y = torch.hstack([y, time_emb]) 337 | 338 | after_in = utils.PlaceHolder( 339 | X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y) 340 | ).mask(node_mask) 341 | X, E, y = after_in.X, after_in.E, after_in.y 
342 | 343 | for layer in self.tf_layers: 344 | X, E, y = layer(X, E, y, node_mask) 345 | 346 | X = self.mlp_out_X(X) 347 | E = self.mlp_out_E(E) 348 | y = self.mlp_out_y(y) 349 | 350 | X = X + X_to_out 351 | E = (E + E_to_out) * diag_mask 352 | y = y + y_to_out 353 | 354 | E = 1 / 2 * (E + torch.transpose(E, 1, 2)) 355 | 356 | return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask) 357 | 358 | 359 | def timestep_embedding(timesteps, dim, max_period=10000): 360 | """ 361 | Create sinusoidal timestep embeddings. 362 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 363 | These may be fractional. 364 | :param dim: the dimension of the output. 365 | :param max_period: controls the minimum frequency of the embeddings. 366 | :return: an [N x dim] Tensor of positional embeddings. 367 | """ 368 | half = dim // 2 369 | freqs = torch.exp( 370 | -math.log(max_period) 371 | * torch.arange(start=0, end=half, dtype=torch.float32) 372 | / half 373 | ).to(device=timesteps.device) 374 | args = timesteps.float() * freqs[None] 375 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 376 | if dim % 2: 377 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 378 | 379 | # import pdb; pdb.set_trace() 380 | return embedding 381 | -------------------------------------------------------------------------------- /src/datasets/moses_dataset.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem, RDLogger 2 | from rdkit.Chem.rdchem import BondType as BT 3 | 4 | import os 5 | import os.path as osp 6 | import pathlib 7 | from typing import Any, Sequence 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | import numpy as np 13 | from torch_geometric.data import Data, InMemoryDataset, download_url 14 | import pandas as pd 15 | 16 | from src import utils 17 | from analysis.rdkit_functions import ( 18 | mol2smiles, 19 | 
build_molecule_with_partial_charges, 20 | compute_molecular_metrics, 21 | ) 22 | from datasets.abstract_dataset import AbstractDatasetInfos, MolecularDataModule 23 | 24 | 25 | def to_list(value: Any) -> Sequence: 26 | if isinstance(value, Sequence) and not isinstance(value, str): 27 | return value 28 | else: 29 | return [value] 30 | 31 | 32 | atom_decoder = ["C", "N", "S", "O", "F", "Cl", "Br", "H"] 33 | 34 | 35 | class MOSESDataset(InMemoryDataset): 36 | train_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/train.csv" 37 | val_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/test.csv" 38 | test_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/test_scaffolds.csv" 39 | 40 | def __init__( 41 | self, 42 | stage, 43 | root, 44 | filter_dataset: bool, 45 | transform=None, 46 | pre_transform=None, 47 | pre_filter=None, 48 | ): 49 | self.stage = stage 50 | self.atom_decoder = atom_decoder 51 | self.filter_dataset = filter_dataset 52 | if self.stage == "train": 53 | self.file_idx = 0 54 | elif self.stage == "val": 55 | self.file_idx = 1 56 | else: 57 | self.file_idx = 2 58 | super().__init__(root, transform, pre_transform, pre_filter) 59 | self.data, self.slices = torch.load(self.processed_paths[self.file_idx]) 60 | 61 | @property 62 | def raw_file_names(self): 63 | return ["train_moses.csv", "val_moses.csv", "test_moses.csv"] 64 | 65 | @property 66 | def split_file_name(self): 67 | return ["train_moses.csv", "val_moses.csv", "test_moses.csv"] 68 | 69 | @property 70 | def split_paths(self): 71 | r"""The absolute filepaths that must be present in order to skip 72 | splitting.""" 73 | files = to_list(self.split_file_name) 74 | return [osp.join(self.raw_dir, f) for f in files] 75 | 76 | @property 77 | def processed_file_names(self): 78 | if self.filter_dataset: 79 | return [ 80 | "train_filtered.pt", 81 | "test_filtered.pt", 82 | "test_scaffold_filtered.pt", 83 | ] 84 | else: 85 | 
return ["train.pt", "test.pt", "test_scaffold.pt"] 86 | 87 | def download(self): 88 | import rdkit # noqa 89 | 90 | train_path = download_url(self.train_url, self.raw_dir) 91 | os.rename(train_path, osp.join(self.raw_dir, "train_moses.csv")) 92 | 93 | test_path = download_url(self.test_url, self.raw_dir) 94 | os.rename(test_path, osp.join(self.raw_dir, "val_moses.csv")) 95 | 96 | valid_path = download_url(self.val_url, self.raw_dir) 97 | os.rename(valid_path, osp.join(self.raw_dir, "test_moses.csv")) 98 | 99 | def process(self): 100 | RDLogger.DisableLog("rdApp.*") 101 | types = {atom: i for i, atom in enumerate(self.atom_decoder)} 102 | 103 | bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} 104 | 105 | path = self.split_paths[self.file_idx] 106 | smiles_list = pd.read_csv(path)["SMILES"].values 107 | 108 | data_list = [] 109 | smiles_kept = [] 110 | 111 | for i, smile in enumerate(tqdm(smiles_list)): 112 | mol = Chem.MolFromSmiles(smile) 113 | N = mol.GetNumAtoms() 114 | 115 | type_idx = [] 116 | for atom in mol.GetAtoms(): 117 | type_idx.append(types[atom.GetSymbol()]) 118 | 119 | row, col, edge_type = [], [], [] 120 | for bond in mol.GetBonds(): 121 | start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() 122 | row += [start, end] 123 | col += [end, start] 124 | edge_type += 2 * [bonds[bond.GetBondType()] + 1] 125 | 126 | if len(row) == 0: 127 | continue 128 | 129 | edge_index = torch.tensor([row, col], dtype=torch.long) 130 | edge_type = torch.tensor(edge_type, dtype=torch.long) 131 | edge_attr = F.one_hot(edge_type, num_classes=len(bonds) + 1).to(torch.float) 132 | 133 | perm = (edge_index[0] * N + edge_index[1]).argsort() 134 | edge_index = edge_index[:, perm] 135 | edge_attr = edge_attr[perm] 136 | 137 | x = F.one_hot(torch.tensor(type_idx), num_classes=len(types)).float() 138 | y = torch.zeros(size=(1, 0), dtype=torch.float) 139 | 140 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) 141 | 142 | if 
self.filter_dataset: 143 | # Try to build the molecule again from the graph. If it fails, do not add it to the training set 144 | dense_data, node_mask = utils.to_dense( 145 | data.x, data.edge_index, data.edge_attr, data.batch 146 | ) 147 | dense_data = dense_data.mask(node_mask, collapse=True) 148 | X, E = dense_data.X, dense_data.E 149 | 150 | assert X.size(0) == 1 151 | atom_types = X[0] 152 | edge_types = E[0] 153 | mol = build_molecule_with_partial_charges( 154 | atom_types, edge_types, atom_decoder 155 | ) 156 | smiles = mol2smiles(mol) 157 | if smiles is not None: 158 | try: 159 | mol_frags = Chem.rdmolops.GetMolFrags( 160 | mol, asMols=True, sanitizeFrags=True 161 | ) 162 | if len(mol_frags) == 1: 163 | data_list.append(data) 164 | smiles_kept.append(smiles) 165 | 166 | except Chem.rdchem.AtomValenceException: 167 | print("Valence error in GetmolFrags") 168 | except Chem.rdchem.KekulizeException: 169 | print("Can't kekulize molecule") 170 | else: 171 | if self.pre_filter is not None and not self.pre_filter(data): 172 | continue 173 | if self.pre_transform is not None: 174 | data = self.pre_transform(data) 175 | data_list.append(data) 176 | 177 | torch.save(self.collate(data_list), self.processed_paths[self.file_idx]) 178 | 179 | if self.filter_dataset: 180 | smiles_save_path = osp.join( 181 | pathlib.Path(self.raw_paths[0]).parent, f"new_{self.stage}.smiles" 182 | ) 183 | print(smiles_save_path) 184 | with open(smiles_save_path, "w") as f: 185 | f.writelines("%s\n" % s for s in smiles_kept) 186 | print(f"Number of molecules kept: {len(smiles_kept)} / {len(smiles_list)}") 187 | 188 | 189 | class MosesDataModule(MolecularDataModule): 190 | def __init__(self, cfg): 191 | self.remove_h = False 192 | self.datadir = cfg.dataset.datadir 193 | self.filter_dataset = cfg.dataset.filter 194 | self.train_smiles = [] 195 | base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] 196 | root_path = os.path.join(base_path, self.datadir) 197 | datasets = { 198 | 
"train": MOSESDataset( 199 | stage="train", root=root_path, filter_dataset=self.filter_dataset 200 | ), 201 | "val": MOSESDataset( 202 | stage="val", root=root_path, filter_dataset=self.filter_dataset 203 | ), 204 | "test": MOSESDataset( 205 | stage="test", root=root_path, filter_dataset=self.filter_dataset 206 | ), 207 | } 208 | super().__init__(cfg, datasets) 209 | 210 | 211 | class MOSESinfos(AbstractDatasetInfos): 212 | def __init__(self, datamodule, cfg, recompute_statistics=False, meta=None): 213 | self.name = "MOSES" 214 | self.input_dims = None 215 | self.output_dims = None 216 | self.remove_h = False 217 | self.compute_fcd = cfg.dataset.compute_fcd 218 | 219 | self.atom_decoder = atom_decoder 220 | self.atom_encoder = {atom: i for i, atom in enumerate(self.atom_decoder)} 221 | self.atom_weights = {0: 12, 1: 14, 2: 32, 3: 16, 4: 19, 5: 35.4, 6: 79.9, 7: 1} 222 | self.valencies = [4, 3, 4, 2, 1, 1, 1, 1] 223 | self.num_atom_types = len(self.atom_decoder) 224 | self.max_weight = 350 225 | 226 | meta_files = dict( 227 | n_nodes=f"{self.name}_n_counts.txt", 228 | node_types=f"{self.name}_atom_types.txt", 229 | edge_types=f"{self.name}_edge_types.txt", 230 | valency_distribution=f"{self.name}_valencies.txt", 231 | ) 232 | 233 | self.n_nodes = torch.tensor( 234 | [ 235 | 0.0, 236 | 0.0, 237 | 0.0, 238 | 0.0, 239 | 0.0, 240 | 0.0, 241 | 0.0, 242 | 0.0, 243 | 3.097634362347889692e-06, 244 | 1.858580617408733815e-05, 245 | 5.007842264603823423e-05, 246 | 5.678996240021660924e-05, 247 | 1.244216400664299726e-04, 248 | 4.486406978685408831e-04, 249 | 2.253012731671333313e-03, 250 | 3.231865121051669121e-03, 251 | 6.709992419928312302e-03, 252 | 2.289564721286296844e-02, 253 | 5.411050841212272644e-02, 254 | 1.099515631794929504e-01, 255 | 1.223291903734207153e-01, 256 | 1.280680745840072632e-01, 257 | 1.445975750684738159e-01, 258 | 1.505961418151855469e-01, 259 | 1.436946094036102295e-01, 260 | 9.265746921300888062e-02, 261 | 1.820066757500171661e-02, 262 | 
2.065089574898593128e-06, 263 | ] 264 | ) 265 | self.max_n_nodes = len(self.n_nodes) - 1 if self.n_nodes is not None else None 266 | self.node_types = torch.tensor( 267 | [0.722338, 0.13661, 0.163655, 0.103549, 0.1421803, 0.005411, 0.00150, 0.0] 268 | ) 269 | self.edge_types = torch.tensor( 270 | [0.89740, 0.0472947, 0.062670, 0.0003524, 0.0486] 271 | ) 272 | self.valency_distribution = torch.zeros(3 * self.max_n_nodes - 2) 273 | self.valency_distribution[:7] = torch.tensor( 274 | [0.0, 0.1055, 0.2728, 0.3613, 0.2499, 0.00544, 0.00485] 275 | ) 276 | 277 | if meta is None: 278 | meta = dict( 279 | n_nodes=None, 280 | node_types=None, 281 | edge_types=None, 282 | valency_distribution=None, 283 | ) 284 | assert set(meta.keys()) == set(meta_files.keys()) 285 | for k, v in meta_files.items(): 286 | if (k not in meta or meta[k] is None) and os.path.exists(v): 287 | meta[k] = np.loadtxt(v) 288 | setattr(self, k, meta[k]) 289 | if recompute_statistics or self.n_nodes is None: 290 | self.n_nodes = datamodule.node_counts() 291 | print("Distribution of number of nodes", self.n_nodes) 292 | np.savetxt(meta_files["n_nodes"], self.n_nodes.numpy()) 293 | self.max_n_nodes = len(self.n_nodes) - 1 294 | if recompute_statistics or self.node_types is None: 295 | self.node_types = datamodule.node_types() # There are no node types 296 | print("Distribution of node types", self.node_types) 297 | np.savetxt(meta_files["node_types"], self.node_types.numpy()) 298 | 299 | if recompute_statistics or self.edge_types is None: 300 | self.edge_types = datamodule.edge_counts() 301 | print("Distribution of edge types", self.edge_types) 302 | np.savetxt(meta_files["edge_types"], self.edge_types.numpy()) 303 | if recompute_statistics or self.valency_distribution is None: 304 | valencies = datamodule.valency_count(self.max_n_nodes) 305 | print("Distribution of the valencies", valencies) 306 | np.savetxt(meta_files["valency_distribution"], valencies.numpy()) 307 | self.valency_distribution = valencies 
308 | # after we can be sure we have the data, complete infos 309 | self.complete_infos(n_nodes=self.n_nodes, node_types=self.node_types) 310 | 311 | 312 | def get_smiles(raw_dir, filter_dataset): 313 | 314 | if filter_dataset: 315 | smiles_save_paths = { 316 | "train": osp.join(raw_dir, "new_train.smiles"), 317 | "val": osp.join(raw_dir, "new_val.smiles"), 318 | "test": osp.join(raw_dir, "new_test.smiles"), 319 | } 320 | train_smiles = open(smiles_save_paths["train"]).readlines() 321 | val_smiles = open(smiles_save_paths["val"]).readlines() 322 | test_smiles = open(smiles_save_paths["test"]).readlines() 323 | 324 | else: 325 | smiles_save_paths = { 326 | "train": osp.join(raw_dir, "train_moses.csv"), 327 | "val": osp.join(raw_dir, "val_moses.csv"), 328 | "test": osp.join(raw_dir, "test_moses.csv"), 329 | } 330 | train_smiles = extract_smiles_from_csv(smiles_save_paths["train"]) 331 | val_smiles = extract_smiles_from_csv(smiles_save_paths["val"]) 332 | test_smiles = extract_smiles_from_csv(smiles_save_paths["test"]) 333 | 334 | return { 335 | "train": train_smiles, 336 | "val": val_smiles, 337 | "test": test_smiles, 338 | } 339 | 340 | 341 | def extract_smiles_from_csv(csv_path): 342 | return pd.read_csv(csv_path)["SMILES"].to_list() 343 | 344 | 345 | if __name__ == "__main__": 346 | ds = [ 347 | MOSESDataset( 348 | s, 349 | os.path.join(os.path.abspath(__file__), "../../../data/moses"), 350 | preprocess=True, 351 | ) 352 | for s in ["train", "val", "test"] 353 | ] 354 | -------------------------------------------------------------------------------- /src/models/extra_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src import utils 3 | 4 | 5 | class DummyExtraFeatures: 6 | def __init__(self): 7 | """This class does not compute anything, just returns empty tensors.""" 8 | 9 | def __call__(self, noisy_data): 10 | X = noisy_data["X_t"] 11 | E = noisy_data["E_t"] 12 | y = noisy_data["y_t"] 13 | 
empty_x = X.new_zeros((*X.shape[:-1], 0)) 14 | empty_e = E.new_zeros((*E.shape[:-1], 0)) 15 | empty_y = y.new_zeros((y.shape[0], 0)) 16 | return utils.PlaceHolder(X=empty_x, E=empty_e, y=empty_y) 17 | 18 | 19 | class ExtraFeatures: 20 | def __init__(self, extra_features_type, rrwp_steps, dataset_info): 21 | self.max_n_nodes = dataset_info.max_n_nodes 22 | self.ncycles = NodeCycleFeatures() 23 | self.features_type = extra_features_type 24 | self.rrwp_steps = rrwp_steps 25 | self.RRWP = RRWPFeatures() 26 | self.RWP = RRWPFeatures(normalize=False) 27 | if extra_features_type in ["eigenvalues", "all"]: 28 | self.eigenfeatures = EigenFeatures(mode=extra_features_type) 29 | 30 | def __call__(self, noisy_data): 31 | n = noisy_data["node_mask"].sum(dim=1).unsqueeze(1) / self.max_n_nodes 32 | x_cycles, y_cycles = self.ncycles(noisy_data) # (bs, n_cycles) 33 | 34 | if self.features_type == "cycles": 35 | E = noisy_data["E_t"] 36 | extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E) 37 | return utils.PlaceHolder( 38 | X=x_cycles, E=extra_edge_attr, y=torch.hstack((n, y_cycles)) 39 | ) 40 | 41 | elif self.features_type == "eigenvalues": 42 | eigenfeatures = self.eigenfeatures(noisy_data) 43 | E = noisy_data["E_t"] 44 | extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E) 45 | n_components, batched_eigenvalues = eigenfeatures # (bs, 1), (bs, 10) 46 | return utils.PlaceHolder( 47 | X=x_cycles, 48 | E=extra_edge_attr, 49 | y=torch.hstack((n, y_cycles)), 50 | ) 51 | 52 | elif self.features_type == "rrwp": 53 | E = noisy_data["E_t"].float()[..., 1:].sum(-1) # bs, n, n 54 | rrwp_edge_attr = self.RRWP(E, k=self.rrwp_steps) 55 | diag_index = torch.arange(rrwp_edge_attr.shape[1]) 56 | rrwp_node_attr = rrwp_edge_attr[:, diag_index, diag_index, :] 57 | self.eigenfeatures = EigenFeatures(mode="all") 58 | 59 | return utils.PlaceHolder( 60 | X=rrwp_node_attr, 61 | E=rrwp_edge_attr, 62 | y=torch.hstack((n, y_cycles)), 63 | ) 64 | 65 | elif self.features_type == 
"rrwp_double": 66 | E = noisy_data["E_t"].float()[..., 1:].sum(-1) # bs, n, n 67 | rrwp_edge_attr = self.RRWP(E, k=self.rrwp_steps) 68 | rrwp_edge_attr_wo_norm = self.RWP(E, k=self.rrwp_steps) 69 | 70 | # Normalize the rrwp_edge_attr_wo_norm 71 | max_value = rrwp_edge_attr_wo_norm.max(dim=1, keepdim=True).values 72 | max_value = max_value.max(dim=2, keepdim=True).values 73 | rrwp_edge_attr_wo_norm = rrwp_edge_attr_wo_norm / max_value 74 | 75 | rrwp_edge_attr = torch.cat((rrwp_edge_attr, rrwp_edge_attr_wo_norm), dim=-1) 76 | diag_index = torch.arange(rrwp_edge_attr.shape[1]) 77 | rrwp_node_attr = rrwp_edge_attr[:, diag_index, diag_index, :] 78 | # self.eigenfeatures = EigenFeatures(mode='all') 79 | 80 | return utils.PlaceHolder( 81 | X=rrwp_node_attr, 82 | E=rrwp_edge_attr, 83 | y=torch.hstack((n, y_cycles)), 84 | ) 85 | 86 | elif self.features_type == "rrwp_only": 87 | E = noisy_data["E_t"].float()[..., 1:].sum(-1) # bs, n, n 88 | rrwp_edge_attr = self.RRWP(E, k=self.rrwp_steps) 89 | diag_index = torch.arange(rrwp_edge_attr.shape[1]) 90 | rrwp_node_attr = rrwp_edge_attr[:, diag_index, diag_index, :] 91 | 92 | return utils.PlaceHolder( 93 | X=rrwp_node_attr, 94 | E=rrwp_edge_attr, 95 | y=n, 96 | ) 97 | 98 | elif self.features_type == "rrwp_comp": 99 | E = noisy_data["E_t"].float()[..., 1:].sum(-1) # bs, n, n 100 | rrwp_edge_attr = self.RRWP(E, k=int(self.rrwp_steps / 2)) 101 | diag_index = torch.arange(rrwp_edge_attr.shape[1]) 102 | rrwp_node_attr = rrwp_edge_attr[:, diag_index, diag_index, :] 103 | 104 | comp_E = 1 - noisy_data["E_t"].float()[..., 1:].sum(-1) # bs, n, n 105 | comp_rrwp_edge_attr = self.RRWP(comp_E, k=int(self.rrwp_steps / 2)) 106 | comp_rrwp_node_attr = comp_rrwp_edge_attr[:, diag_index, diag_index, :] 107 | 108 | return utils.PlaceHolder( 109 | X=torch.cat((rrwp_node_attr, comp_rrwp_node_attr), dim=-1), 110 | E=torch.cat((rrwp_edge_attr, comp_rrwp_edge_attr), dim=-1), 111 | y=torch.hstack((n, y_cycles)), 112 | ) 113 | 114 | elif self.features_type 
== "all": 115 | eigenfeatures = self.eigenfeatures(noisy_data) 116 | E = noisy_data["E_t"] 117 | extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E) 118 | n_components, batched_eigenvalues, nonlcc_indicator, k_lowest_eigvec = ( 119 | eigenfeatures # (bs, 1), (bs, 10), 120 | ) 121 | 122 | return utils.PlaceHolder( 123 | X=torch.cat( 124 | (x_cycles, nonlcc_indicator, k_lowest_eigvec), 125 | dim=-1, 126 | ), 127 | E=extra_edge_attr, 128 | y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues)), 129 | ) 130 | 131 | else: 132 | raise ValueError(f"Features type {self.features_type} not implemented") 133 | 134 | 135 | class RRWPFeatures: 136 | def __init__(self, k=10, normalize=True): 137 | self.k = k 138 | self.normalize = normalize 139 | 140 | def __call__(self, E, k=None): 141 | k = k or self.k 142 | 143 | ( 144 | bs, 145 | n, 146 | _, 147 | ) = E.shape 148 | if self.normalize: 149 | degree = torch.zeros(bs, n, n, device=E.device) 150 | to_fill = 1 / (E.sum(dim=-1).float()) 151 | to_fill[E.sum(dim=-1).float() == 0] = 0 152 | degree = torch.diagonal_scatter(degree, to_fill, dim1=1, dim2=2) 153 | E = degree @ E 154 | 155 | id = torch.eye(n, device=E.device).unsqueeze(0).repeat(bs, 1, 1) 156 | rrwp_list = [id] 157 | 158 | for i in range(k - 1): 159 | cur_rrwp = rrwp_list[-1] @ E 160 | rrwp_list.append(cur_rrwp) 161 | 162 | return torch.stack(rrwp_list, -1) 163 | 164 | 165 | class NodeCycleFeatures: 166 | def __init__(self): 167 | self.kcycles = KNodeCycles() 168 | 169 | def __call__(self, noisy_data): 170 | adj_matrix = noisy_data["E_t"][..., 1:].sum(dim=-1).float() 171 | 172 | x_cycles, y_cycles = self.kcycles.k_cycles( 173 | adj_matrix=adj_matrix 174 | ) # (bs, n_cycles) 175 | x_cycles = x_cycles.type_as(adj_matrix) * noisy_data["node_mask"].unsqueeze(-1) 176 | # Avoid large values when the graph is dense 177 | x_cycles = x_cycles / 10 178 | y_cycles = y_cycles / 10 179 | x_cycles[x_cycles > 1] = 1 180 | y_cycles[y_cycles > 1] = 1 181 | return 
class EigenFeatures:
    """Spectral features of the (masked) graph Laplacian of a batch of noisy graphs.

    Code taken from : https://github.com/Saro00/DGN/blob/master/models/pytorch/eigen_agg.py
    """

    def __init__(self, mode):
        """mode: 'eigenvalues' or 'all'"""
        self.mode = mode

    def __call__(self, noisy_data):
        """Compute the spectral features selected by ``self.mode``.

        noisy_data: mapping that provides
            "E_t"       -- edge tensor whose last axis indexes edge classes; every
                           class except index 0 is counted as "an edge is present"
            "node_mask" -- boolean mask (bs, n) of the real (non padded) nodes
        Returns:
            mode 'eigenvalues': (n_connected_comp, batch_eigenvalues)
            mode 'all': (n_connected_comp, batch_eigenvalues,
                         nonlcc_indicator, k_lowest_eigenvector)
        Raises:
            NotImplementedError: if ``self.mode`` is neither 'eigenvalues' nor 'all'.
        """
        E_t = noisy_data["E_t"]
        mask = noisy_data["node_mask"]
        # Adjacency restricted to pairs of real nodes.
        A = E_t[..., 1:].sum(dim=-1).float() * mask.unsqueeze(1) * mask.unsqueeze(2)
        L = compute_laplacian(A, normalize=False)
        # Padded nodes receive a large diagonal entry (2 * n) so that they do not
        # show up as additional zero eigenvalues of the Laplacian.
        mask_diag = 2 * L.shape[-1] * torch.eye(A.shape[-1]).type_as(L).unsqueeze(0)
        mask_diag = mask_diag * (~mask.unsqueeze(1)) * (~mask.unsqueeze(2))
        L = L * mask.unsqueeze(1) * mask.unsqueeze(2) + mask_diag

        if self.mode == "eigenvalues":
            eigvals = torch.linalg.eigvalsh(L)  # bs, n
            # Normalise the spectrum by the number of real nodes of each graph.
            eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True)

            n_connected_comp, batch_eigenvalues = get_eigenvalues_features(
                eigenvalues=eigvals
            )
            return n_connected_comp.type_as(A), batch_eigenvalues.type_as(A)

        elif self.mode == "all":
            eigvals, eigvectors = torch.linalg.eigh(L)
            eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True)
            eigvectors = eigvectors * mask.unsqueeze(2) * mask.unsqueeze(1)
            # Retrieve eigenvalues features
            n_connected_comp, batch_eigenvalues = get_eigenvalues_features(
                eigenvalues=eigvals
            )

            # Retrieve eigenvectors features
            nonlcc_indicator, k_lowest_eigenvector = get_eigenvectors_features(
                vectors=eigvectors,
                node_mask=noisy_data["node_mask"],
                n_connected=n_connected_comp,
            )
            return (
                n_connected_comp,
                batch_eigenvalues,
                nonlcc_indicator,
                k_lowest_eigenvector,
            )
        else:
            raise NotImplementedError(f"Mode {self.mode} is not implemented")


def compute_laplacian(adjacency, normalize: bool):
    """Compute the Laplacian of a batch of graphs.

    adjacency : batched adjacency matrix (bs, n, n)
    normalize: falsy (False / None) for the combinatorial Laplacian D - A; any
        truthy value selects the symmetric normalized Laplacian
        I - D^{-1/2} A D^{-1/2}. No other normalisation is implemented.
    Return:
        L (bs, n, n): combinatorial or symmetric normalized Laplacian, averaged
        with its transpose.
    """
    diag = torch.sum(adjacency, dim=-1)  # (bs, n)

    if not normalize:
        D = torch.diag_embed(diag)  # Degree matrix      # (bs, n, n)
        combinatorial = D - adjacency  # (bs, n, n)
        return (combinatorial + combinatorial.transpose(1, 2)) / 2

    n = diag.shape[-1]
    # Remember the zero-degree nodes before patching their degree, so that the
    # inverse square root below is finite.
    isolated = diag == 0  # (bs, n)
    diag[isolated] = 1e-12

    diag_norm = 1 / torch.sqrt(diag)  # (bs, n)
    D_norm = torch.diag_embed(diag_norm)  # (bs, n, n)
    L = torch.eye(n, device=adjacency.device).unsqueeze(0) - D_norm @ adjacency @ D_norm
    # Rows of zero-degree nodes are zeroed out.
    L[isolated] = 0
    return (L + L.transpose(1, 2)) / 2


def get_eigenvalues_features(eigenvalues, k=5):
    """Count the (near) zero eigenvalues and extract the k that follow them.

    eigenvalues : eigenvalues -- (bs, n); the indexing below relies on them
        being sorted in ascending order
    k: num of non zero eigenvalues to keep
    Returns:
        n_connected_components -- (bs, 1): number of eigenvalues below 1e-5
        first_k_ev -- (bs, k): the k eigenvalues that follow them, padded with
            the value 2 when fewer than k are available
    Raises:
        ValueError: if some graph has no eigenvalue below 1e-5.
    """
    bs, n = eigenvalues.shape
    n_connected_components = (eigenvalues < 1e-5).sum(dim=-1)
    if not (n_connected_components > 0).all():
        # Fail loudly instead of dropping into an interactive debugger.
        raise ValueError(
            "Expected at least one (near) zero eigenvalue per graph, got "
            f"n_connected_components={n_connected_components}, "
            f"eigenvalues={eigenvalues}"
        )

    to_extend = int(n_connected_components.max()) + k - n
    if to_extend > 0:
        padding = 2 * torch.ones(bs, to_extend).type_as(eigenvalues)
        eigenvalues = torch.hstack((eigenvalues, padding))
    # For every graph: positions [n_cc, n_cc + k) of its spectrum.
    offsets = torch.arange(k).type_as(eigenvalues).long().unsqueeze(0)
    indices = offsets + n_connected_components.unsqueeze(1)
    first_k_ev = torch.gather(eigenvalues, dim=1, index=indices)
    return n_connected_components.unsqueeze(-1), first_k_ev


def get_eigenvectors_features(vectors, node_mask, n_connected, k=2):
    """Extract node-level features from the Laplacian eigenvectors.

    vectors (bs, n, n) : eigenvectors of Laplacian IN COLUMNS
    node_mask (bs, n) : boolean mask of the real nodes
    n_connected (bs, 1) : integer number of (near) zero eigenvalues per graph
    k : number of eigenvectors to keep
    returns:
        not_lcc_indicator : indicator of the nodes OUTSIDE the largest connected
            component (lcc) for each graph -- (bs, n, 1)
        k_lowest_eigvec : k eigenvectors following the first n_connected ones,
            zero padded if needed -- (bs, n, k)
    """
    bs, n = vectors.size(0), vectors.size(1)

    # Create an indicator for the nodes outside the largest connected components:
    # nodes whose (rounded) entry in the first eigenvector differs from the most
    # frequent entry are flagged.
    first_ev = torch.round(vectors[:, :, 0], decimals=3) * node_mask  # bs, n
    # Add random value to the mask to prevent 0 from becoming the mode
    random = torch.randn(bs, n, device=node_mask.device) * (~node_mask)  # bs, n
    first_ev = first_ev + random
    most_common = torch.mode(first_ev, dim=1).values  # values: bs -- indices: bs
    differs = ~(first_ev == most_common.unsqueeze(1))
    not_lcc_indicator = (differs * node_mask).unsqueeze(-1).float()

    # Get the eigenvectors corresponding to the first nonzero eigenvalues
    to_extend = int(n_connected.max()) + k - n
    if to_extend > 0:
        vectors = torch.cat(
            (vectors, torch.zeros(bs, n, to_extend).type_as(vectors)), dim=2
        )  # bs, n , n + to_extend
    offsets = torch.arange(k).type_as(vectors).long().unsqueeze(0).unsqueeze(0)
    indices = offsets + n_connected.unsqueeze(2)  # bs, 1, k
    indices = indices.expand(-1, n, -1)  # bs, n, k
    first_k_ev = torch.gather(vectors, dim=2, index=indices)  # bs, n, k
    first_k_ev = first_k_ev * node_mask.unsqueeze(2)

    return not_lcc_indicator, first_k_ev


def batch_trace(X):
    """
    Expect a matrix of shape B N N, returns the trace in shape B
    :param X:
    :return:
    """
    return batch_diagonal(X).sum(dim=-1)


def batch_diagonal(X):
    """
    Extracts the diagonal from the last two dims of a tensor
    :param X:
    :return:
    """
    return torch.diagonal(X, dim1=-2, dim2=-1)


class KNodeCycles:
    """Builds cycle counts for each node in a graph."""

    def __init__(self):
        super().__init__()

    def calculate_kpowers(self):
        """Cache the node degrees and the powers A^1 .. A^6 of ``self.adj_matrix``."""
        adj = self.adj_matrix.float()
        self.k1_matrix = adj
        self.d = self.adj_matrix.sum(dim=-1)
        self.k2_matrix = self.k1_matrix @ adj
        self.k3_matrix = self.k2_matrix @ adj
        self.k4_matrix = self.k3_matrix @ adj
        self.k5_matrix = self.k4_matrix @ adj
        self.k6_matrix = self.k5_matrix @ adj

    def k3_cycle(self):
        """tr(A ** 3). Returns (per-node, per-graph) 3-cycle counts."""
        c3 = batch_diagonal(self.k3_matrix)
        per_node = (c3 / 2).unsqueeze(-1).float()
        per_graph = (torch.sum(c3, dim=-1) / 6).unsqueeze(-1).float()
        return per_node, per_graph

    def k4_cycle(self):
        """Returns (per-node, per-graph) 4-cycle counts."""
        diag_a4 = batch_diagonal(self.k4_matrix)
        # Subtract the closed walks of length 4 accounted for by the degree terms.
        c4 = (
            diag_a4
            - self.d * (self.d - 1)
            - (self.adj_matrix @ self.d.unsqueeze(-1)).sum(dim=-1)
        )
        per_node = (c4 / 2).unsqueeze(-1).float()
        per_graph = (torch.sum(c4, dim=-1) / 8).unsqueeze(-1).float()
        return per_node, per_graph

    def k5_cycle(self):
        """Returns (per-node, per-graph) 5-cycle counts."""
        diag_a5 = batch_diagonal(self.k5_matrix)
        triangles = batch_diagonal(self.k3_matrix)
        c5 = (
            diag_a5
            - 2 * triangles * self.d
            - (self.adj_matrix @ triangles.unsqueeze(-1)).sum(dim=-1)
            + triangles
        )
        per_node = (c5 / 2).unsqueeze(-1).float()
        per_graph = (c5.sum(dim=-1) / 10).unsqueeze(-1).float()
        return per_node, per_graph

    def k6_cycle(self):
        """Returns (None, per-graph) 6-cycle counts; no per-node value is computed."""
        diag_a2 = batch_diagonal(self.k2_matrix)
        term_1_t = batch_trace(self.k6_matrix)
        term_2_t = batch_trace(self.k3_matrix**2)
        term3_t = torch.sum(self.adj_matrix * self.k2_matrix.pow(2), dim=[-2, -1])
        a_4_t = batch_diagonal(self.k4_matrix)
        term_4_t = (diag_a2 * a_4_t).sum(dim=-1)
        term_5_t = batch_trace(self.k4_matrix)
        term_6_t = batch_trace(self.k3_matrix)
        term_7_t = diag_a2.pow(3).sum(-1)
        term8_t = torch.sum(self.k3_matrix, dim=[-2, -1])
        term9_t = diag_a2.pow(2).sum(-1)
        term10_t = batch_trace(self.k2_matrix)

        c6_t = (
            term_1_t
            - 3 * term_2_t
            + 9 * term3_t
            - 6 * term_4_t
            + 6 * term_5_t
            - 4 * term_6_t
            + 4 * term_7_t
            + 3 * term8_t
            - 12 * term9_t
            + 4 * term10_t
        )
        return None, (c6_t / 12).unsqueeze(-1).float()

    def k_cycles(self, adj_matrix, verbose=False):
        """Compute the cycle features of a batch of adjacency matrices.

        adj_matrix: batched adjacency matrix (bs, n, n)
        verbose: unused, kept for backward compatibility
        Returns:
            kcyclesx -- (bs, n, 3): per-node 3-, 4- and 5-cycle features
            kcyclesy -- (bs, 4): per-graph 3-, 4-, 5- and 6-cycle counts
        """
        self.adj_matrix = adj_matrix
        self.calculate_kpowers()

        # The asserts are sanity checks: counts must be non-negative up to
        # numerical error.
        k3x, k3y = self.k3_cycle()
        assert (k3x >= -0.1).all()

        k4x, k4y = self.k4_cycle()
        assert (k4x >= -0.1).all()

        k5x, k5y = self.k5_cycle()
        assert (k5x >= -0.1).all(), k5x

        _, k6y = self.k6_cycle()
        assert (k6y >= -0.1).all()

        kcyclesx = torch.cat([k3x, k4x, k5x], dim=-1)
        kcyclesy = torch.cat([k3y, k4y, k5y, k6y], dim=-1)
        return kcyclesx, kcyclesy