├── .gitignore ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── config.yaml ├── dataset │ ├── canopus.yaml │ ├── fp2mol.yaml │ └── msg.yaml ├── general │ └── general_default.yaml ├── model │ └── model_default.yaml └── train │ └── train_default.yaml ├── data_processing ├── 00_download_fp2mol_data.sh ├── 01_download_canopus_data.sh ├── 02_download_msg_data.sh ├── 03_preprocess_fp2mol.sh └── build_fp2mol_datasets.py ├── figs └── diffms-animation.gif ├── notebooks ├── build_fp2mol_datasets.ipynb └── compute_metrics.ipynb ├── pyproject.toml └── src ├── __init__.py ├── analysis ├── __init__.py ├── rdkit_functions.py └── visualization.py ├── datasets ├── __init__.py ├── abstract_dataset.py ├── fp2mol_dataset.py ├── spec2mol_dataset.py └── spectra_utils.py ├── diffusion ├── __init__.py ├── diffusion_utils.py ├── distributions.py ├── extra_features.py ├── extra_features_molecular.py ├── layers.py └── noise_schedule.py ├── diffusion_model_fp2mol.py ├── diffusion_model_spec2mol.py ├── fp2mol_main.py ├── metrics ├── __init__.py ├── abstract_metrics.py ├── diffms_metrics.py ├── molecular_metrics.py ├── molecular_metrics_discrete.py └── train_metrics.py ├── mist ├── data │ ├── data.py │ ├── datasets.py │ ├── featurizers.py │ └── splitter.py ├── models │ ├── form_embedders.py │ ├── modules.py │ ├── spectra_encoder.py │ └── transformer_layer.py └── utils │ ├── __init__.py │ ├── chem_utils.py │ ├── misc_utils.py │ ├── parse_utils.py │ └── spectra_utils.py ├── models ├── __init__.py ├── layers.py └── transformer_model.py ├── spec2mol_main.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | .DS_Store 133 | .idea/ 134 | __pycache__/ 135 | outputs/ 136 | archives/* 137 | .env 138 | results/* 139 | logs/* 140 | 141 | checkpoints/* 142 | checkpoints_internal/* 143 | checkpoints_save/* 144 | notebooks/data/* 145 | data/* 146 | 147 | __pycache__/ 148 | **/__pycache__/ 149 | *.pyc 150 | train.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025 Montgomery Bohde, Mrunali Manjrekar, Runzhong Wang, Shuiwang Ji, Connor W. Coley 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffMS: Diffusion Generation of Molecules Conditioned on Mass Spectra 2 | 3 | ![teaser](./figs/diffms-animation.gif) 4 | 5 | This is the codebase for our preprint [DiffMS: Diffusion Generation of Molecules Conditioned on Mass Spectra](https://arxiv.org/abs/2502.09571). 6 | 7 | The DiffMS codebase is adapted from [DiGress](https://github.com/cvignac/DiGress). 
8 | 9 | ## Environment installation 10 | This code was tested with PyTorch 2.3.1, CUDA 11.8, and torch_geometric 2.3.1. 11 | 12 | - Download Anaconda/Miniconda if needed 13 | - Create a conda environment with rdkit: 14 | 15 | ``` 16 | conda create -y -c conda-forge -n diffms rdkit=2024.09.4 python=3.9 17 | conda activate diffms 18 | ``` 19 | 20 | - OR for a faster installation, you can use mamba: 21 | 22 | ``` 23 | mamba create -y -n diffms rdkit=2024.09.4 python=3.9 24 | mamba activate diffms 25 | ``` 26 | 27 | - Install a corresponding version of PyTorch, for example: 28 | 29 | ```pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cu118``` 30 | 31 | - Run: 32 | 33 | ```pip install -e .``` 34 | 35 | 36 | ## Dataset Download/Processing 37 | 38 | We provide a series of scripts to download and process the pretraining and finetuning datasets. To download and set up the datasets, run the scripts in the data_processing/ folder in order: 39 | 40 | ``` 41 | bash data_processing/00_download_fp2mol_data.sh 42 | bash data_processing/01_download_canopus_data.sh 43 | bash data_processing/02_download_msg_data.sh 44 | bash data_processing/03_preprocess_fp2mol.sh 45 | ``` 46 | 47 | These scripts use unzip, which can be installed with ```sudo apt-get install unzip``` on Linux. If you are on a different OS, you may need to edit these scripts or run the commands manually. 48 | 49 | ## Run the code 50 | 51 | For fingerprint-molecule pretraining, run [fp2mol_main.py](src/fp2mol_main.py). You will need to set the dataset in [config.yaml](configs/config.yaml) to 'fp2mol'. The primary pretraining dataset in our paper is referred to as 'combined' in the [fp2mol.yaml](configs/dataset/fp2mol.yaml) config. 52 | 53 | To finetune the end-to-end model on spectra-molecule generation, run [spec2mol_main.py](src/spec2mol_main.py). You will also need to set the dataset in [config.yaml](configs/config.yaml) to 'msg' for MassSpecGym or 'canopus' for NPLIB1. Example invocations are sketched in the "Example usage" section below. 54 | 55 | ## Pretrained Checkpoints 56 | 57 | We provide checkpoints for the end-to-end finetuned DiffMS model as well as the pretrained encoder/decoder weights [here](https://zenodo.org/records/15122968). 58 | 59 | To load the pretrained DiffMS weights, set the load_weights argument in [general_default.yaml](configs/general/general_default.yaml) to the corresponding path. To use the pretrained encoder/decoder, set the corresponding arguments in [general_default.yaml](configs/general/general_default.yaml). 60 | 61 | ## License 62 | 63 | DiffMS is released under the [MIT](LICENSE) license.
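
## Example usage

For reference, here is a minimal sketch of how the two entry points might be invoked from the repository root. The Hydra command-line overrides (`dataset=...`, `general.load_weights=...`) and the checkpoint paths are illustrative assumptions based on standard hydra-core override syntax, not commands taken from the original documentation; editing [config.yaml](configs/config.yaml) and [general_default.yaml](configs/general/general_default.yaml) directly, as described above, is the documented route.

```
# Hypothetical commands -- adjust the dataset name and paths to your setup.

# Fingerprint-molecule pretraining (dataset 'fp2mol'; the 'combined' corpus is its default):
python src/fp2mol_main.py dataset=fp2mol

# End-to-end finetuning on MassSpecGym ('msg') or NPLIB1 ('canopus'):
python src/spec2mol_main.py dataset=msg

# Start from the released DiffMS weights (placeholder path):
python src/spec2mol_main.py dataset=msg general.load_weights=/path/to/diffms_checkpoint.ckpt
```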
64 | 65 | ## Contact 66 | 67 | If you have any questions, please reach out to mbohde@tamu.edu 68 | 69 | ## Reference 70 | If you find this codebase useful in your research, please kindly cite the following manuscript 71 | ``` 72 | @article{bohde2025diffms, 73 | title={DiffMS: Diffusion Generation of Molecules Conditioned on Mass Spectra}, 74 | author={Bohde, Montgomery and Manjrekar, Mrunali and Wang, Runzhong and Ji, Shuiwang and Coley, Connor W}, 75 | journal={arXiv preprint arXiv:2502.09571}, 76 | year={2025} 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/DiffMS/9d9f4fd497162eec045f7db2787da30ce69a9622/configs/__init__.py -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - general : general_default 4 | - model : model_default 5 | - train : train_default 6 | - dataset : canopus 7 | 8 | hydra: 9 | job: 10 | chdir: True 11 | run: 12 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}-${general.name} -------------------------------------------------------------------------------- /configs/dataset/canopus.yaml: -------------------------------------------------------------------------------- 1 | name: canopus 2 | remove_h: null 3 | stats_dir: null 4 | datadir: '../../../data/canopus' 5 | filter: False 6 | denoise_nodes: False 7 | merge: 'downproject_4096' # 'none' | 'mist_fp' | 'merge-encoder_output-linear' | 'merge-encoder_output-mlp' | 'downproject_4096' 8 | morgan_nbits: 2048 9 | morgan_r: 2 10 | split_file: '../../../data/canopus/splits/canopus_hplus_100_0.tsv' 11 | spec_features: 'peakformula' 12 | mol_features: 'fingerprint' 13 | subform_folder: '../../../data/canopus/subformulae/subformulae_default' 14 | augment_data: False 15 | remove_prob: 0.1 16 | remove_weights: 'exp' 17 | inten_prob: 0.1 18 | inten_transform: 'float' 19 | cls_type: 'ms1' 20 | magma_aux_loss: False 21 | labels_file: '../../../data/canopus/labels.tsv' 22 | spec_folder: '../../../data/canopus/spec_files' 23 | cache_featurizers: True 24 | set_pooling: 'cls' 25 | max_count: null 26 | -------------------------------------------------------------------------------- /configs/dataset/fp2mol.yaml: -------------------------------------------------------------------------------- 1 | name: fp2mol 2 | dataset: combined 3 | stats_dir: null 4 | remove_h: null 5 | datadir: '../../../data/fp2mol/' 6 | filter: False 7 | denoise_nodes: False 8 | morgan_nbits: 2048 9 | morgan_r: 2 -------------------------------------------------------------------------------- /configs/dataset/msg.yaml: -------------------------------------------------------------------------------- 1 | name: msg 2 | remove_h: null 3 | stats_dir: null 4 | datadir: '../../../data/msg' 5 | filter: False 6 | denoise_nodes: False 7 | merge: 'downproject_4096' # 'none' | 'mist_fp' | 'merge-encoder_output-linear' | 'merge-encoder_output-mlp' | 'downproject_4096' 8 | morgan_nbits: 2048 9 | morgan_r: 2 10 | split_file: '../../../data/msg/split.tsv' 11 | spec_features: 'peakformula' 12 | mol_features: 'fingerprint' 13 | subform_folder: '../../../data/msg/subformulae/default_subformulae' 14 | augment_data: False 15 | remove_prob: 0.1 16 | remove_weights: 'exp' 17 | inten_prob: 0.1 18 | inten_transform: 'float' 19 | 
cls_type: 'ms1' 20 | magma_aux_loss: False 21 | labels_file: '../../../data/msg/labels.tsv' 22 | spec_folder: '../../../data/msg/spec_files' 23 | cache_featurizers: True 24 | set_pooling: 'cls' 25 | max_count: null 26 | -------------------------------------------------------------------------------- /configs/general/general_default.yaml: -------------------------------------------------------------------------------- 1 | # General settings 2 | name: 'dev' # Warning: 'debug' and 'test' are reserved name that have a special behavior 3 | 4 | parent_dir: '.' 5 | 6 | wandb: 'online' # online | offline | disabled 7 | wandb_name: 'mass_spec_exp' 8 | gpus: 1 # Multi-gpu is not implemented on this branch 9 | 10 | decoder: null # path to pretrained decoder 11 | encoder: null # path to pretrained encoder 12 | 13 | resume: null 14 | test_only: null 15 | load_weights: null 16 | 17 | encoder_finetune_strategy: null # null | freeze | ft-unfold | freeze-unfold | freeze-transformer | ft-transformer 18 | decoder_finetune_strategy: null # null | freeze | ft-input | freeze-input | ft-transformer | freeze-transformer | ft-output 19 | 20 | check_val_every_n_epochs: 1 21 | sample_every_val: 1000 22 | val_samples_to_generate: 100 23 | test_samples_to_generate: 100 24 | log_every_steps: 50 25 | 26 | evaluate_all_checkpoints: False 27 | checkpoint_strategy: 'last' -------------------------------------------------------------------------------- /configs/model/model_default.yaml: -------------------------------------------------------------------------------- 1 | # Model settings 2 | transition: 'marginal' 3 | model: 'graph_tf' # 'graph_tf', 'graph_tf_v2', 4 | diffusion_steps: 500 5 | diffusion_noise_schedule: 'cosine' 6 | n_layers: 5 7 | 8 | extra_features: 'all' # 'all', 'cycles', 'eigenvalues' or null 9 | 10 | hidden_mlp_dims: {'X': 256, 'E': 128, 'y': 2048} 11 | 12 | hidden_dims : {'dx': 256, 'de': 64, 'dy': 1024, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 1024} 13 | 14 | encoder_hidden_dim: 256 # Small Model Default (CANOPUS) 15 | #encoder_hidden_dim: 512 # Large Model Default (MSG) 16 | 17 | encoder_magma_modulo: 512 # Small Model Default (CANOPUS) 18 | #encoder_magma_modulo: 2048 # Large Model Default (MSG) 19 | 20 | lambda_train: [0, 1, 0] 21 | -------------------------------------------------------------------------------- /configs/train/train_default.yaml: -------------------------------------------------------------------------------- 1 | # Training settings 2 | n_epochs: 75 3 | batch_size: 96 4 | eval_batch_size: 128 5 | lr: 0.0015 # 0.0015 for training, 0.0002 for fine-tuning 6 | clip_grad: null # float, null to disable 7 | save_model: True 8 | num_workers: 1 9 | pin_memory: True 10 | ema_decay: 0 # EMA decay current not implemented 11 | progress_bar: false 12 | weight_decay: 1e-12 13 | optimizer: adamw # adamw | nadamw | nadam 14 | scheduler: 'one_cycle' # 'const' | 'one_cycle' 15 | pct_start: 0.3 16 | seed: 123 17 | -------------------------------------------------------------------------------- /data_processing/00_download_fp2mol_data.sh: -------------------------------------------------------------------------------- 1 | # build datadir file structure 2 | mkdir data/ 3 | mkdir data/fp2mol/ 4 | mkdir data/fp2mol/raw/ 5 | 6 | cd data/fp2mol/raw/ 7 | 8 | # download raw data 9 | wget https://hmdb.ca/system/downloads/current/structures.zip 10 | unzip structures.zip 11 | 12 | wget https://clowder.edap-cluster.com/api/files/6616d8d7e4b063812d70fc95/blob 13 | unzip blob 14 | 15 | wget 
https://coconut.s3.uni-jena.de/prod/downloads/2025-03/coconut_csv-03-2025.zip 16 | unzip coconut_csv-03-2025.zip 17 | 18 | wget https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv 19 | mv dataset_v1.csv moses.csv -------------------------------------------------------------------------------- /data_processing/01_download_canopus_data.sh: -------------------------------------------------------------------------------- 1 | # script adapted from MIST repo: https://github.com/samgoldman97/mist/blob/main_v2/data_processing/canopus_train/00_download_canopus_data.sh 2 | 3 | # Original data link 4 | #SVM_URL="https://bio.informatik.uni-jena.de/wp/wp-content/uploads/2020/08/svm_training_data.zip" 5 | 6 | export_link="https://zenodo.org/record/8316682/files/canopus_train_export_v2.tar" 7 | 8 | mkdir data/ 9 | 10 | cd data/ 11 | wget -O canopus_train_export.tar $export_link 12 | 13 | tar -xvf canopus_train_export.tar 14 | mv canopus_train_export canopus 15 | rm -f canopus_train_export.tar -------------------------------------------------------------------------------- /data_processing/02_download_msg_data.sh: -------------------------------------------------------------------------------- 1 | # script adapted from MIST repo: https://github.com/samgoldman97/mist/blob/main_v2/data_processing/canopus_train/00_download_canopus_data.sh 2 | # This script downloads preprocessed data from the MassSpecGym project 3 | # Original MassSpecGym code/data: https://github.com/pluskal-lab/MassSpecGym 4 | 5 | export_link="https://zenodo.org/records/15008938/files/msg_preprocessed.tar.gz" 6 | 7 | mkdir data/ 8 | cd data/ 9 | 10 | wget $export_link 11 | 12 | tar -xvzf msg_preprocessed.tar.gz 13 | 14 | rm -f msg_preprocessed.tar.gz -------------------------------------------------------------------------------- /data_processing/03_preprocess_fp2mol.sh: -------------------------------------------------------------------------------- 1 | for dataset in hmdb dss coconut moses canopus msg combined 2 | do 3 | mkdir data/fp2mol/$dataset/ 4 | mkdir data/fp2mol/$dataset/preprocessed/ 5 | mkdir data/fp2mol/$dataset/processed/ 6 | mkdir data/fp2mol/$dataset/stats/ 7 | done 8 | 9 | cd data_processing/ 10 | python build_fp2mol_datasets.py -------------------------------------------------------------------------------- /data_processing/build_fp2mol_datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import Counter 3 | 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | from rdkit import Chem 8 | from rdkit import RDLogger 9 | from rdkit.Chem import Descriptors 10 | 11 | random.seed(42) 12 | 13 | lg = RDLogger.logger() 14 | lg.setLevel(RDLogger.CRITICAL) 15 | 16 | def read_from_sdf(path): 17 | res = [] 18 | app = False 19 | with open(path, 'r') as f: 20 | for line in tqdm(f.readlines(), desc='Loading SDF structures', leave=False): 21 | if app: 22 | res.append(line.strip()) 23 | app = False 24 | if line.startswith('> '): 25 | app = True 26 | 27 | return res 28 | 29 | def filter(mol): 30 | try: 31 | smi = Chem.MolToSmiles(mol, isomericSmiles=False) # remove stereochemistry information 32 | mol = Chem.MolFromSmiles(smi) 33 | 34 | if "." 
in smi: 35 | return False 36 | 37 | if Descriptors.MolWt(mol) >= 1500: 38 | return False 39 | 40 | for atom in mol.GetAtoms(): 41 | if atom.GetFormalCharge() != 0: 42 | return False 43 | except: 44 | return False 45 | 46 | return True 47 | 48 | FILTER_ATOMS = {'C', 'N', 'S', 'O', 'F', 'Cl', 'H', 'P'} 49 | 50 | def filter_with_atom_types(mol): 51 | try: 52 | smi = Chem.MolToSmiles(mol, isomericSmiles=False) # remove stereochemistry information 53 | mol = Chem.MolFromSmiles(smi) 54 | 55 | if "." in smi: 56 | return False 57 | 58 | if Descriptors.MolWt(mol) >= 1500: 59 | return False 60 | 61 | for atom in mol.GetAtoms(): 62 | if atom.GetFormalCharge() != 0: 63 | return False 64 | if atom.GetSymbol() not in FILTER_ATOMS: 65 | return False 66 | except: 67 | return False 68 | 69 | return True 70 | 71 | ########## CANOPUS DATASET ########## 72 | 73 | canopus_split = pd.read_csv('../data/canopus/splits/canopus_hplus_100_0.tsv', sep='\t') 74 | 75 | canopus_labels = pd.read_csv('../data/canopus/labels.tsv', sep='\t') 76 | canopus_labels["name"] = canopus_labels["spec"] 77 | canopus_labels = canopus_labels[["name", "smiles"]].reset_index(drop=True) 78 | 79 | canopus_labels = canopus_labels.merge(canopus_split, on="name") 80 | 81 | canopus_train_inchis = [] 82 | canopus_test_inchis = [] 83 | canopus_val_inchis = [] 84 | 85 | for i in tqdm(range(len(canopus_labels)), desc="Converting CANOPUS SMILES to InChI", leave=False): 86 | 87 | mol = Chem.MolFromSmiles(canopus_labels.loc[i, "smiles"]) 88 | smi = Chem.MolToSmiles(mol, isomericSmiles=False) # remove stereochemistry information 89 | mol = Chem.MolFromSmiles(smi) 90 | inchi = Chem.MolToInchi(mol) 91 | 92 | if canopus_labels.loc[i, "split"] == "train": 93 | if filter(mol): 94 | canopus_train_inchis.append(inchi) 95 | elif canopus_labels.loc[i, "split"] == "test": 96 | canopus_test_inchis.append(inchi) 97 | elif canopus_labels.loc[i, "split"] == "val": 98 | canopus_val_inchis.append(inchi) 99 | 100 | canopus_train_df = pd.DataFrame(set(canopus_train_inchis), columns=["inchi"]) 101 | canopus_train_df.to_csv("../data/fp2mol/canopus/preprocessed/canopus_train.csv", index=False) 102 | 103 | canopus_test_df = pd.DataFrame(canopus_test_inchis, columns=["inchi"]) 104 | canopus_test_df.to_csv("../data/fp2mol/canopus/preprocessed/canopus_test.csv", index=False) 105 | 106 | canopus_val_df = pd.DataFrame(canopus_val_inchis, columns=["inchi"]) 107 | canopus_val_df.to_csv("../data/fp2mol/canopus/preprocessed/canopus_val.csv", index=False) 108 | 109 | excluded_inchis = set(canopus_test_inchis + canopus_val_inchis) 110 | 111 | ########## MSG DATASET ########## 112 | 113 | msg_split = pd.read_csv('../data/msg/split.tsv', sep='\t') 114 | 115 | msg_labels = pd.read_csv('../data/msg/labels.tsv', sep='\t') 116 | msg_labels["name"] = msg_labels["spec"] 117 | msg_labels = msg_labels[["name", "smiles"]].reset_index(drop=True) 118 | 119 | msg_labels = msg_labels.merge(msg_split, on="name") 120 | 121 | msg_train_inchis = [] 122 | msg_test_inchis = [] 123 | msg_val_inchis = [] 124 | 125 | for i in tqdm(range(len(msg_labels)), desc="Converting MSG SMILES to InChI", leave=False): 126 | 127 | mol = Chem.MolFromSmiles(msg_labels.loc[i, "smiles"]) 128 | smi = Chem.MolToSmiles(mol, isomericSmiles=False) # remove stereochemistry information 129 | mol = Chem.MolFromSmiles(smi) 130 | inchi = Chem.MolToInchi(mol) 131 | 132 | if msg_labels.loc[i, "split"] == "train": 133 | if filter(mol): 134 | msg_train_inchis.append(inchi) 135 | elif msg_labels.loc[i, "split"] == "test": 136 | 
msg_test_inchis.append(inchi) 137 | elif msg_labels.loc[i, "split"] == "val": 138 | msg_val_inchis.append(inchi) 139 | 140 | msg_train_df = pd.DataFrame(set(msg_train_inchis), columns=["inchi"]) 141 | msg_train_df.to_csv("../data/fp2mol/msg/preprocessed/msg_train.csv", index=False) 142 | 143 | msg_test_df = pd.DataFrame(msg_test_inchis, columns=["inchi"]) 144 | msg_test_df.to_csv("../data/fp2mol/msg/preprocessed/msg_test.csv", index=False) 145 | 146 | msg_val_df = pd.DataFrame(msg_val_inchis, columns=["inchi"]) 147 | msg_val_df.to_csv("../data/fp2mol/msg/preprocessed/msg_val.csv", index=False) 148 | 149 | excluded_inchis.update(msg_test_inchis + msg_val_inchis) 150 | 151 | ########## HMDB DATASET ########## 152 | 153 | hmdb_set = set() 154 | raw_smiles = read_from_sdf('../data/fp2mol/raw/structures.sdf') 155 | for smi in tqdm(raw_smiles, desc='Cleaning HMDB structures', leave=False): 156 | try: 157 | mol = Chem.MolFromSmiles(smi) 158 | smi = Chem.MolToSmiles(mol, isomericSmiles=False) # remove stereochemistry information 159 | mol = Chem.MolFromSmiles(smi) 160 | if filter_with_atom_types(mol): 161 | hmdb_set.add(Chem.MolToInchi(mol)) 162 | except: 163 | pass 164 | 165 | hmdb_inchis = list(hmdb_set) 166 | random.shuffle(hmdb_inchis) 167 | 168 | hmdb_train_inchis = hmdb_inchis[:int(0.95 * len(hmdb_inchis))] 169 | hmdb_val_inchis = hmdb_inchis[int(0.95 * len(hmdb_inchis)):] 170 | 171 | hmdb_train_inchis = [inchi for inchi in hmdb_train_inchis if inchi not in excluded_inchis] 172 | 173 | hmdb_train_df = pd.DataFrame(hmdb_train_inchis, columns=["inchi"]) 174 | hmdb_train_df.to_csv("../data/fp2mol/hmdb/preprocessed/hmdb_train.csv", index=False) 175 | 176 | hmdb_val_df = pd.DataFrame(hmdb_val_inchis, columns=["inchi"]) 177 | hmdb_val_df.to_csv("../data/fp2mol/hmdb/preprocessed/hmdb_val.csv", index=False) 178 | 179 | ########## DSSTox DATASET ########## 180 | 181 | dss_set_raw = set() 182 | for i in tqdm(range(1, 14), desc='Loading DSSTox structures', leave=False): 183 | df = pd.read_excel(f'../data/fp2mol/raw/DSSToxDump{i}.xlsx') 184 | dss_set_raw.update(df[df['SMILES'].notnull()]['SMILES']) 185 | 186 | dss_set = set() 187 | for smi in tqdm(dss_set_raw, desc='Cleaning DSSTox structures', leave=False): 188 | try: 189 | mol = Chem.MolFromSmiles(smi) 190 | smi = Chem.MolToSmiles(mol, isomericSmiles=False) # remove stereochemistry information 191 | mol = Chem.MolFromSmiles(smi) 192 | if filter_with_atom_types(mol): 193 | dss_set.add(Chem.MolToInchi(mol)) 194 | except: 195 | pass 196 | 197 | dss_inchis = list(dss_set) 198 | random.shuffle(dss_inchis) 199 | 200 | dss_train_inchis = dss_inchis[:int(0.95 * len(dss_inchis))] 201 | dss_val_inchis = dss_inchis[int(0.95 * len(dss_inchis)):] 202 | 203 | dss_train_inchis = [inchi for inchi in dss_train_inchis if inchi not in excluded_inchis] 204 | 205 | dss_train_df = pd.DataFrame(dss_train_inchis, columns=["inchi"]) 206 | dss_train_df.to_csv("../data/fp2mol/dss/preprocessed/dss_train.csv", index=False) 207 | 208 | dss_val_df = pd.DataFrame(dss_val_inchis, columns=["inchi"]) 209 | dss_val_df.to_csv("../data/fp2mol/dss/preprocessed/dss_val.csv", index=False) 210 | 211 | ########## COCONUT DATASET ########## 212 | 213 | coconut_df = pd.read_csv('../data/fp2mol/raw/coconut_csv-03-2025.csv') 214 | 215 | coconut_set_raw = set(coconut_df["canonical_smiles"]) 216 | 217 | coconut_set = set() 218 | for smi in tqdm(coconut_set_raw, desc='Cleaning COCONUT structures', leave=False): 219 | try: 220 | mol = Chem.MolFromSmiles(smi) 221 | smi = Chem.MolToSmiles(mol, 
isomericSmiles=False) # remove stereochemistry information 222 | mol = Chem.MolFromSmiles(smi) 223 | if filter_with_atom_types(mol): 224 | coconut_set.add(Chem.MolToInchi(mol)) 225 | except: 226 | pass 227 | 228 | coconut_inchis = list(coconut_set) 229 | random.shuffle(coconut_inchis) 230 | 231 | coconut_train_inchis = coconut_inchis[:int(0.95 * len(coconut_inchis))] 232 | coconut_val_inchis = coconut_inchis[int(0.95 * len(coconut_inchis)):] 233 | 234 | coconut_train_inchis = [inchi for inchi in coconut_train_inchis if inchi not in excluded_inchis] 235 | 236 | coconut_train_df = pd.DataFrame(coconut_train_inchis, columns=["inchi"]) 237 | coconut_train_df.to_csv("../data/fp2mol/coconut/preprocessed/coconut_train.csv", index=False) 238 | 239 | coconut_val_df = pd.DataFrame(coconut_val_inchis, columns=["inchi"]) 240 | coconut_val_df.to_csv("../data/fp2mol/coconut/preprocessed/coconut_val.csv", index=False) 241 | 242 | 243 | ########## MOSES DATASET ########## 244 | 245 | moses_df = pd.read_csv('../data/fp2mol/raw/moses.csv') 246 | 247 | moses_set_raw = set(moses_df["SMILES"]) 248 | 249 | moses_set = set() 250 | for smi in tqdm(moses_set_raw, desc='Cleaning MOSES structures', leave=False): 251 | try: 252 | mol = Chem.MolFromSmiles(smi) 253 | smi = Chem.MolToSmiles(mol, isomericSmiles=False) # remove stereochemistry information 254 | mol = Chem.MolFromSmiles(smi) 255 | if filter_with_atom_types(mol): 256 | moses_set.add(Chem.MolToInchi(mol)) 257 | except: 258 | pass 259 | 260 | moses_inchis = list(moses_set) 261 | random.shuffle(moses_inchis) 262 | 263 | moses_train_inchis = moses_inchis[:int(0.95 * len(moses_inchis))] 264 | moses_val_inchis = moses_inchis[int(0.95 * len(moses_inchis)):] 265 | 266 | moses_train_inchis = [inchi for inchi in moses_train_inchis if inchi not in excluded_inchis] 267 | 268 | moses_train_df = pd.DataFrame(moses_train_inchis, columns=["inchi"]) 269 | moses_train_df.to_csv("../data/fp2mol/moses/preprocessed/moses_train.csv", index=False) 270 | 271 | moses_val_df = pd.DataFrame(moses_val_inchis, columns=["inchi"]) 272 | moses_val_df.to_csv("../data/fp2mol/moses/preprocessed/moses_val.csv", index=False) 273 | 274 | ########## COMBINED DATASET ########## 275 | 276 | combined_inchis = hmdb_inchis + dss_inchis + coconut_inchis + moses_inchis 277 | combined_inchis = list(set(combined_inchis)) 278 | random.shuffle(combined_inchis) 279 | 280 | combined_train_inchis = combined_inchis[:int(0.95 * len(combined_inchis))] 281 | combined_val_inchis = combined_inchis[int(0.95 * len(combined_inchis)):] 282 | combined_train_inchis = [inchi for inchi in combined_train_inchis if inchi not in excluded_inchis] 283 | 284 | combined_train_df = pd.DataFrame(combined_train_inchis, columns=["inchi"]) 285 | combined_train_df.to_csv("../data/fp2mol/combined/preprocessed/combined_train.csv", index=False) 286 | 287 | combined_val_df = pd.DataFrame(combined_val_inchis, columns=["inchi"]) 288 | combined_val_df.to_csv("../data/fp2mol/combined/preprocessed/combined_val.csv", index=False) -------------------------------------------------------------------------------- /figs/diffms-animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/DiffMS/9d9f4fd497162eec045f7db2787da30ce69a9622/figs/diffms-animation.gif -------------------------------------------------------------------------------- /notebooks/compute_metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 
3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pickle\n", 10 | "import os\n", 11 | "from collections import Counter, defaultdict\n", 12 | "\n", 13 | "import pulp\n", 14 | "from rdkit import Chem\n", 15 | "from rdkit.Chem import AllChem\n", 16 | "from rdkit.Chem import DataStructs\n", 17 | "import pandas as pd\n", 18 | "from myopic_mces import MCES\n", 19 | "from joblib import Parallel, delayed\n", 20 | "from tqdm import tqdm\n", 21 | "from tqdm_joblib import tqdm_joblib\n", 22 | "\n", 23 | "\n", 24 | "from rdkit import RDLogger\n", 25 | "RDLogger.DisableLog('rdApp.*')" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 7, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "try:\n", 35 | " from rdkit.Chem.MolStandardize.tautomer import TautomerCanonicalizer, TautomerTransform\n", 36 | " _RD_TAUTOMER_CANONICALIZER = 'v1'\n", 37 | " _TAUTOMER_TRANSFORMS = (\n", 38 | " TautomerTransform('1,3 heteroatom H shift',\n", 39 | " '[#7,S,O,Se,Te;!H0]-[#7X2,#6,#15]=[#7,#16,#8,Se,Te]'),\n", 40 | " TautomerTransform('1,3 (thio)keto/enol r', '[O,S,Se,Te;X2!H0]-[C]=[C]'),\n", 41 | " )\n", 42 | "except ModuleNotFoundError:\n", 43 | " from rdkit.Chem.MolStandardize.rdMolStandardize import TautomerEnumerator # newer rdkit\n", 44 | " _RD_TAUTOMER_CANONICALIZER = 'v2'\n", 45 | "\n", 46 | "def canonical_mol_from_inchi(inchi):\n", 47 | " \"\"\"Canonicalize mol after Chem.MolFromInchi\n", 48 | " Note that this function may be 50 times slower than Chem.MolFromInchi\"\"\"\n", 49 | " mol = Chem.MolFromInchi(inchi)\n", 50 | " if mol is None:\n", 51 | " return None\n", 52 | " if _RD_TAUTOMER_CANONICALIZER == 'v1':\n", 53 | " _molvs_t = TautomerCanonicalizer(transforms=_TAUTOMER_TRANSFORMS)\n", 54 | " mol = _molvs_t.canonicalize(mol)\n", 55 | " else:\n", 56 | " _te = TautomerEnumerator()\n", 57 | " mol = _te.Canonicalize(mol)\n", 58 | " return mol\n", 59 | "\n", 60 | "def mol2smiles(mol):\n", 61 | " try:\n", 62 | " Chem.SanitizeMol(mol)\n", 63 | " except ValueError:\n", 64 | " return None\n", 65 | " return Chem.MolToSmiles(mol)\n", 66 | "\n", 67 | "def is_valid(mol):\n", 68 | " smiles = mol2smiles(mol)\n", 69 | " if smiles is None:\n", 70 | " return False\n", 71 | "\n", 72 | " try:\n", 73 | " mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)\n", 74 | " except:\n", 75 | " return False\n", 76 | " if len(mol_frags) > 1:\n", 77 | " return False\n", 78 | " \n", 79 | " return True\n", 80 | "\n", 81 | "def compute_metrics_for_one(t_inchi, p_inchi, solver, doMCES=False, doFull=False):\n", 82 | " RDLogger.DisableLog('rdApp.*')\n", 83 | "\n", 84 | " true_mol = canonical_mol_from_inchi(t_inchi)\n", 85 | " true_smi = Chem.MolToSmiles(true_mol)\n", 86 | " true_fp = AllChem.GetMorganFingerprintAsBitVect(true_mol, 2, nBits=2048)\n", 87 | " true_num_bonds = true_mol.GetNumBonds()\n", 88 | "\n", 89 | " # Precompute metrics for each predicted molecule\n", 90 | " p_mces = []\n", 91 | " p_tanimoto = []\n", 92 | " p_cosine = []\n", 93 | " for pi in p_inchi:\n", 94 | " pmol = canonical_mol_from_inchi(pi)\n", 95 | "\n", 96 | " try:\n", 97 | " pmol_smi = Chem.MolToSmiles(pmol)\n", 98 | " if doMCES:\n", 99 | " p_mces.append(MCES(true_smi, pmol_smi, solver=solver, threshold=100, always_stronger_bound=False, solver_options=dict(msg=0))[1])\n", 100 | " else:\n", 101 | " p_mces.append(true_num_bonds + pmol.GetNumBonds())\n", 102 | " except:\n", 103 | " p_mces.append(true_num_bonds + pmol.GetNumBonds())\n", 104 | 
"\n", 105 | " try:\n", 106 | " pmol_fp = AllChem.GetMorganFingerprintAsBitVect(pmol, 2, nBits=2048)\n", 107 | " p_tanimoto.append(DataStructs.TanimotoSimilarity(true_fp, pmol_fp))\n", 108 | " p_cosine.append(DataStructs.CosineSimilarity(true_fp, pmol_fp))\n", 109 | " except:\n", 110 | " p_tanimoto.append(0.0)\n", 111 | " p_cosine.append(0.0)\n", 112 | "\n", 113 | " # Build prefix arrays for best (min) MCES, best (max) Tanimoto, best (max) Cosine\n", 114 | " prefix_min_mces = [100]\n", 115 | " prefix_max_tanimoto = [0.0]\n", 116 | " prefix_max_cosine = [0.0]\n", 117 | " for j in range(len(p_inchi)):\n", 118 | " prefix_min_mces.append(min(prefix_min_mces[-1], p_mces[j]))\n", 119 | " prefix_max_tanimoto.append(max(prefix_max_tanimoto[-1], p_tanimoto[j]))\n", 120 | " prefix_max_cosine.append(max(prefix_max_cosine[-1], p_cosine[j]))\n", 121 | "\n", 122 | " # Earliest index of true InChI, if present\n", 123 | " try:\n", 124 | " earliest_idx = p_inchi.index(t_inchi)\n", 125 | " except ValueError:\n", 126 | " earliest_idx = -1\n", 127 | "\n", 128 | " if doFull:\n", 129 | " # Compute metrics using prefix arrays\n", 130 | " m_local = defaultdict(float)\n", 131 | " for k in range(1, 101):\n", 132 | " m_local[f'acc@{k}'] = 1.0 if (earliest_idx != -1 and earliest_idx < k) else 0.0\n", 133 | " idx = min(k, len(p_inchi))\n", 134 | " m_local[f'mces@{k}'] = prefix_min_mces[idx]\n", 135 | " m_local[f'tanimoto@{k}'] = prefix_max_tanimoto[idx]\n", 136 | " m_local[f'cosine@{k}'] = prefix_max_cosine[idx]\n", 137 | " m_local[f'close_match@{k}'] = 1.0 if (prefix_max_tanimoto[idx] >= 0.675) else 0.0\n", 138 | " m_local[f'meaningful_match@{k}'] = 1.0 if (prefix_max_tanimoto[idx] >= 0.4) else 0.0\n", 139 | " else:\n", 140 | " m_local = defaultdict(float)\n", 141 | " for k in range(1, 11):\n", 142 | " m_local[f'acc@{k}'] = 1.0 if (earliest_idx != -1 and earliest_idx < k) else 0.0\n", 143 | " idx = min(k, len(p_inchi))\n", 144 | " m_local[f'mces@{k}'] = prefix_min_mces[idx]\n", 145 | " m_local[f'tanimoto@{k}'] = prefix_max_tanimoto[idx]\n", 146 | " m_local[f'cosine@{k}'] = prefix_max_cosine[idx]\n", 147 | " m_local[f'close_match@{k}'] = 1.0 if (prefix_max_tanimoto[idx] >= 0.675) else 0.0\n", 148 | " m_local[f'meaningful_match@{k}'] = 1.0 if (prefix_max_tanimoto[idx] >= 0.4) else 0.0\n", 149 | "\n", 150 | " return m_local\n", 151 | "\n", 152 | "def compute_metrics(true, pred, csv_path, doMCES=False, doFull=False):\n", 153 | " true_inchi = []\n", 154 | " pred_inchi = []\n", 155 | " for i in range(len(true)):\n", 156 | " local_pred_inchi = []\n", 157 | " for j in range(len(pred[i])):\n", 158 | " if is_valid(pred[i][j]):\n", 159 | " local_pred_inchi.append(Chem.MolToInchi(pred[i][j]))\n", 160 | "\n", 161 | " # sort local_pred_inchi by frequency\n", 162 | " inchi_counts = Counter(local_pred_inchi)\n", 163 | " local_pred_inchi = [item for item, count in inchi_counts.most_common()]\n", 164 | "\n", 165 | " if not doFull:\n", 166 | " local_pred_inchi = local_pred_inchi[:11]\n", 167 | "\n", 168 | " pred_inchi.append(local_pred_inchi)\n", 169 | " true_inchi.append(Chem.MolToInchi(true[i]))\n", 170 | "\n", 171 | " solver = pulp.listSolvers(onlyAvailable=True)[0]\n", 172 | "\n", 173 | " with tqdm_joblib(tqdm(total=len(true_inchi))) as progress_bar:\n", 174 | " results = Parallel(n_jobs=-1)(\n", 175 | " delayed(compute_metrics_for_one)(\n", 176 | " true_inchi[i],\n", 177 | " pred_inchi[i],\n", 178 | " solver,\n", 179 | " doMCES=doMCES,\n", 180 | " doFull=doFull\n", 181 | " )\n", 182 | " for i in range(len(true_inchi))\n", 183 | " 
)\n", 184 | "\n", 185 | " # aggregate results\n", 186 | " final_metrics = defaultdict(float)\n", 187 | " for r in results:\n", 188 | " for key, val in r.items():\n", 189 | " final_metrics[key] += val\n", 190 | "\n", 191 | " if doFull:\n", 192 | " for k in range(1, 101):\n", 193 | " final_metrics[f'acc@{k}'] /= len(true_inchi)\n", 194 | " final_metrics[f'mces@{k}'] /= len(true_inchi)\n", 195 | " final_metrics[f'tanimoto@{k}'] /= len(true_inchi)\n", 196 | " final_metrics[f'cosine@{k}'] /= len(true_inchi)\n", 197 | " final_metrics[f'close_match@{k}'] /= len(true_inchi)\n", 198 | " final_metrics[f'meaningful_match@{k}'] /= len(true_inchi)\n", 199 | " else:\n", 200 | " for k in range(1, 11):\n", 201 | " final_metrics[f'acc@{k}'] /= len(true_inchi)\n", 202 | " final_metrics[f'mces@{k}'] /= len(true_inchi)\n", 203 | " final_metrics[f'tanimoto@{k}'] /= len(true_inchi)\n", 204 | " final_metrics[f'cosine@{k}'] /= len(true_inchi)\n", 205 | " final_metrics[f'close_match@{k}'] /= len(true_inchi)\n", 206 | " final_metrics[f'meaningful_match@{k}'] /= len(true_inchi)\n", 207 | "\n", 208 | " df = pd.DataFrame(final_metrics, index=[0])\n", 209 | " df.to_csv(csv_path, index=False)\n" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "# example code of loading model predictions as saved in the diffusion_model_spec2mol.py test step\n", 219 | "# paths/loading will be different\n", 220 | "\n", 221 | "canopus_true = []\n", 222 | "canopus_pred = []\n", 223 | "for idx in range(1, 5):\n", 224 | " i = idx-1\n", 225 | " while os.path.exists(f\"../final_results/canopus/spec2mol-canopus-eval-{idx}_resume_pred_{i}.pkl\"):\n", 226 | " with open(f\"../final_results/canopus/spec2mol-canopus-eval-{idx}_resume_true_{i}.pkl\", 'rb') as f:\n", 227 | " canopus_true.extend(pickle.load(f))\n", 228 | " with open(f\"../final_results/canopus/spec2mol-canopus-eval-{idx}_resume_pred_{i}.pkl\", 'rb') as f:\n", 229 | " canopus_pred.extend(pickle.load(f))\n", 230 | " i += 4" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "compute_metrics(canopus_true, canopus_pred, \"canopus_metrics.csv\", doMCES=False, doFull=True)" 240 | ] 241 | } 242 | ], 243 | "metadata": { 244 | "kernelspec": { 245 | "display_name": "diffms", 246 | "language": "python", 247 | "name": "python3" 248 | }, 249 | "language_info": { 250 | "codemirror_mode": { 251 | "name": "ipython", 252 | "version": 3 253 | }, 254 | "file_extension": ".py", 255 | "mimetype": "text/x-python", 256 | "name": "python", 257 | "nbconvert_exporter": "python", 258 | "pygments_lexer": "ipython3", 259 | "version": "3.9.21" 260 | } 261 | }, 262 | "nbformat": 4, 263 | "nbformat_minor": 2 264 | } 265 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "diffms" 7 | version = "1.0.0" 8 | description = "DiffMS: Diffusion Generation of Molecules Conditioned on Mass Spectra" 9 | requires-python = "==3.9.*" 10 | dependencies = [ 11 | "hydra-core==1.3.2", 12 | "matplotlib==3.7.1", 13 | "numpy==1.23", 14 | "omegaconf==2.3.0", 15 | "overrides==7.3.1", 16 | "pandas==1.4", 17 | "pytorch_lightning==2.0.4", 18 | "setuptools==68.0.0", 19 | "torch_geometric==2.3.1", 20 | 
"torchmetrics==0.11.4", 21 | "tqdm", 22 | "wandb", 23 | "h5py", 24 | "seaborn", 25 | "myopic-mces", 26 | "tqdm-joblib", 27 | ] 28 | authors = [ 29 | {name = "Montgomery Bohde", email = "mbohde@tamu.edu"}, 30 | {name = "Mrunali Manjrekar"}, 31 | {name = "Runzhong Wang"}, 32 | {name = "Shuiwang Ji"}, 33 | {name = "Connor W. Coley"}, 34 | ] 35 | maintainers = [ 36 | {name = "Montgomery Bohde", email = "mbohde@tamu.edu"} 37 | ] 38 | readme = "README.md" 39 | 40 | [tool.setuptools.packages.find] 41 | where = ["."] 42 | include = ["src"] 43 | exclude = [] 44 | namespaces = false -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/DiffMS/9d9f4fd497162eec045f7db2787da30ce69a9622/src/__init__.py -------------------------------------------------------------------------------- /src/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/DiffMS/9d9f4fd497162eec045f7db2787da30ce69a9622/src/analysis/__init__.py -------------------------------------------------------------------------------- /src/analysis/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from rdkit import Chem 4 | from rdkit.Chem import Draw 5 | import numpy as np 6 | import rdkit.Chem 7 | import wandb 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class MolecularVisualization: 12 | def __init__(self, remove_h, dataset_infos): 13 | self.remove_h = remove_h 14 | self.dataset_infos = dataset_infos 15 | 16 | def mol_from_graphs(self, node_list, adjacency_matrix): 17 | """ 18 | Convert graphs to rdkit molecules 19 | node_list: the nodes of a batch of nodes (bs x n) 20 | adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n) 21 | """ 22 | # dictionary to map integer value to the char of atom 23 | atom_decoder = self.dataset_infos.atom_decoder 24 | 25 | # create empty editable mol object 26 | mol = Chem.RWMol() 27 | 28 | # add atoms to mol and keep track of index 29 | node_to_idx = {} 30 | for i in range(len(node_list)): 31 | if node_list[i] == -1: 32 | continue 33 | a = Chem.Atom(atom_decoder[int(node_list[i])]) 34 | molIdx = mol.AddAtom(a) 35 | node_to_idx[i] = molIdx 36 | 37 | for ix, row in enumerate(adjacency_matrix): 38 | for iy, bond in enumerate(row): 39 | # only traverse half the symmetric matrix 40 | if iy <= ix: 41 | continue 42 | if bond == 1: 43 | bond_type = Chem.rdchem.BondType.SINGLE 44 | elif bond == 2: 45 | bond_type = Chem.rdchem.BondType.DOUBLE 46 | elif bond == 3: 47 | bond_type = Chem.rdchem.BondType.TRIPLE 48 | elif bond == 4: 49 | bond_type = Chem.rdchem.BondType.AROMATIC 50 | else: 51 | continue 52 | mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type) 53 | 54 | try: 55 | mol = mol.GetMol() 56 | except rdkit.Chem.KekulizeException: 57 | print("Can't kekulize molecule") 58 | mol = None 59 | return mol 60 | 61 | def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'): 62 | # define path to save figures 63 | if not os.path.exists(path): 64 | os.makedirs(path) 65 | 66 | # visualize the final molecules 67 | print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}") 68 | if num_molecules_to_visualize > len(molecules): 69 | print(f"Shortening to {len(molecules)}") 70 | num_molecules_to_visualize = len(molecules) 71 | 72 | for i 
in range(num_molecules_to_visualize): 73 | file_path = os.path.join(path, 'molecule_{}.png'.format(i)) 74 | mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy()) 75 | try: 76 | Draw.MolToFile(mol, file_path) 77 | if wandb.run and log is not None: 78 | print(f"Saving {file_path} to wandb") 79 | wandb.log({log: wandb.Image(file_path)}, commit=True) 80 | except rdkit.Chem.KekulizeException: 81 | print("Can't kekulize molecule") 82 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/DiffMS/9d9f4fd497162eec045f7db2787da30ce69a9622/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/datasets/abstract_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import pytorch_lightning as pl 5 | from torch_geometric.data import Dataset 6 | from torch_geometric.loader import DataLoader 7 | from torch_geometric.data.lightning import LightningDataset 8 | 9 | import src.utils as utils 10 | from src.diffusion.distributions import DistributionNodes 11 | 12 | def kwargs_repr(**kwargs) -> str: 13 | return ', '.join([f'{k}={v}' for k, v in kwargs.items() if v is not None]) 14 | 15 | class CustomLightningDataset(LightningDataset): 16 | def __init__(self, cfg, datasets, **kwargs): 17 | kwargs.pop('batch_size', None) 18 | self.kwargs = kwargs 19 | 20 | self.batch_size = cfg.train.batch_size if 'debug' not in cfg.general.name else 2 21 | self.eval_batch_size = cfg.train.eval_batch_size if 'debug' not in cfg.general.name else 1 22 | 23 | super().__init__(train_dataset=datasets['train'], val_dataset=datasets['val'], test_dataset=datasets['test'],) 24 | for k, v in kwargs.items(): # overwrite default kwargs from LightningDataset 25 | self.kwargs[k] = v 26 | self.kwargs.pop('batch_size', None) 27 | 28 | def dataloader(self, dataset: Dataset, **kwargs) -> DataLoader: 29 | return DataLoader(dataset, **kwargs) 30 | 31 | def train_dataloader(self) -> DataLoader: 32 | from torch.utils.data import IterableDataset 33 | 34 | shuffle = not isinstance(self.train_dataset, IterableDataset) 35 | shuffle &= self.kwargs.get('sampler', None) is None 36 | shuffle &= self.kwargs.get('batch_sampler', None) is None 37 | return self.dataloader(self.train_dataset, shuffle=shuffle, batch_size=self.batch_size, **self.kwargs) 38 | 39 | def val_dataloader(self) -> DataLoader: 40 | kwargs = copy.copy(self.kwargs) 41 | kwargs.pop('sampler', None) 42 | kwargs.pop('batch_sampler', None) 43 | 44 | return self.dataloader(self.val_dataset, shuffle=True, batch_size=self.eval_batch_size, **kwargs) 45 | 46 | def test_dataloader(self) -> DataLoader: 47 | kwargs = copy.copy(self.kwargs) 48 | kwargs.pop('sampler', None) 49 | kwargs.pop('batch_sampler', None) 50 | 51 | return self.dataloader(self.test_dataset, shuffle=False, batch_size=self.eval_batch_size, **kwargs) 52 | 53 | def predict_dataloader(self) -> DataLoader: 54 | kwargs = copy.copy(self.kwargs) 55 | kwargs.pop('sampler', None) 56 | kwargs.pop('batch_sampler', None) 57 | 58 | return self.dataloader(self.pred_dataset, shuffle=False, batch_size=self.eval_batch_size, **kwargs) 59 | 60 | def __repr__(self) -> str: 61 | kwargs = kwargs_repr( 62 | train_dataset=self.train_dataset, 63 | val_dataset=self.val_dataset, 64 | test_dataset=self.test_dataset, 65 | 
pred_dataset=self.pred_dataset, 66 | batch_size=self.batch_size, 67 | eval_batch_size=self.eval_batch_size, 68 | **self.kwargs 69 | ) 70 | 71 | return f'{self.__class__.__name__}({kwargs})' 72 | 73 | class AbstractDataModule(CustomLightningDataset): 74 | def __init__(self, cfg, datasets): 75 | super().__init__(cfg, datasets, num_workers=cfg.train.num_workers, pin_memory=getattr(cfg.train, "pin_memory", True)) 76 | self.cfg = cfg 77 | self.input_dims = None 78 | self.output_dims = None 79 | 80 | def __getitem__(self, idx): 81 | return self.train_dataset[idx] 82 | 83 | def node_counts(self, max_nodes_possible=150): 84 | all_counts = torch.zeros(max_nodes_possible) 85 | for loader in [self.train_dataloader(), self.val_dataloader()]: 86 | for data in loader: 87 | unique, counts = torch.unique(data.batch, return_counts=True) 88 | for count in counts: 89 | all_counts[count] += 1 90 | max_index = max(all_counts.nonzero()) 91 | all_counts = all_counts[:max_index + 1] 92 | all_counts = all_counts / all_counts.sum() 93 | return all_counts 94 | 95 | def node_types(self): 96 | num_classes = None 97 | for data in self.train_dataloader(): 98 | num_classes = data.x.shape[1] 99 | break 100 | 101 | counts = torch.zeros(num_classes) 102 | 103 | for i, data in enumerate(self.train_dataloader()): 104 | counts += data.x.sum(dim=0) 105 | 106 | counts = counts / counts.sum() 107 | return counts 108 | 109 | def edge_counts(self): 110 | num_classes = None 111 | for data in self.train_dataloader(): 112 | num_classes = data.edge_attr.shape[1] 113 | break 114 | 115 | d = torch.zeros(num_classes, dtype=torch.float) 116 | 117 | for i, data in enumerate(self.train_dataloader()): 118 | unique, counts = torch.unique(data.batch, return_counts=True) 119 | 120 | all_pairs = 0 121 | for count in counts: 122 | all_pairs += count * (count - 1) 123 | 124 | num_edges = data.edge_index.shape[1] 125 | num_non_edges = all_pairs - num_edges 126 | 127 | edge_types = data.edge_attr.sum(dim=0) 128 | assert num_non_edges >= 0 129 | d[0] += num_non_edges 130 | d[1:] += edge_types[1:] 131 | 132 | d = d / d.sum() 133 | return d 134 | 135 | 136 | class MolecularDataModule(AbstractDataModule): 137 | def valency_count(self, max_n_nodes): 138 | valencies = torch.zeros(3 * max_n_nodes - 2) # Max valency possible if everything is connected 139 | 140 | # No bond, single bond, double bond, triple bond, aromatic bond 141 | multiplier = torch.tensor([0, 1, 2, 3, 1.5]) 142 | 143 | for data in self.train_dataloader(): 144 | n = data.x.shape[0] 145 | 146 | for atom in range(n): 147 | edges = data.edge_attr[data.edge_index[0] == atom] 148 | edges_total = edges.sum(dim=0) 149 | valency = (edges_total * multiplier).sum() 150 | valencies[valency.long().item()] += 1 151 | valencies = valencies / valencies.sum() 152 | return valencies 153 | 154 | 155 | class AbstractDatasetInfos: 156 | def complete_infos(self, n_nodes, node_types): 157 | self.input_dims = None 158 | self.output_dims = None 159 | self.num_classes = len(node_types) 160 | self.max_n_nodes = len(n_nodes) - 1 161 | self.nodes_dist = DistributionNodes(n_nodes) 162 | 163 | def compute_input_output_dims(self, datamodule, extra_features, domain_features): 164 | example_batch = next(iter(datamodule.train_dataloader())) 165 | ex_dense, node_mask = utils.to_dense(example_batch.x, example_batch.edge_index, example_batch.edge_attr, 166 | example_batch.batch) 167 | example_data = {'X_t': ex_dense.X, 'E_t': ex_dense.E, 'y_t': example_batch['y'], 'node_mask': node_mask} 168 | 169 | self.input_dims = {'X': 
example_batch['x'].size(1), 170 | 'E': example_batch['edge_attr'].size(1), 171 | 'y': example_batch['y'].size(1) + 1} # + 1 due to time conditioning 172 | 173 | ex_extra_feat = extra_features(example_data) 174 | self.input_dims['X'] += ex_extra_feat.X.size(-1) 175 | self.input_dims['E'] += ex_extra_feat.E.size(-1) 176 | self.input_dims['y'] += ex_extra_feat.y.size(-1) 177 | 178 | ex_extra_molecular_feat = domain_features(example_data) 179 | self.input_dims['X'] += ex_extra_molecular_feat.X.size(-1) 180 | self.input_dims['E'] += ex_extra_molecular_feat.E.size(-1) 181 | self.input_dims['y'] += ex_extra_molecular_feat.y.size(-1) 182 | 183 | self.output_dims = {'X': example_batch['x'].size(1), 184 | 'E': example_batch['edge_attr'].size(1), 185 | 'y': example_batch['y'].size(1)} 186 | 187 | 188 | ATOM_TO_VALENCY = { 189 | 'H': 1, 190 | 'He': 0, 191 | 'Li': 1, 192 | 'Be': 2, 193 | 'B': 3, 194 | 'C': 4, 195 | 'N': 3, 196 | 'O': 2, 197 | 'F': 1, 198 | 'Ne': 0, 199 | 'Na': 1, 200 | 'Mg': 2, 201 | 'Al': 3, 202 | 'Si': 4, 203 | 'P': 3, 204 | 'S': 2, 205 | 'Cl': 1, 206 | 'Ar': 0, 207 | 'K': 1, 208 | 'Ca': 2, 209 | 'Sc': 3, 210 | 'Ti': 4, 211 | 'V': 5, 212 | 'Cr': 2, 213 | 'Mn': 7, 214 | 'Fe': 2, 215 | 'Co': 3, 216 | 'Ni': 2, 217 | 'Cu': 2, 218 | 'Zn': 2, 219 | 'Ga': 3, 220 | 'Ge': 4, 221 | 'As': 3, 222 | 'Se': 2, 223 | 'Br': 1, 224 | 'Kr': 0, 225 | 'Rb': 1, 226 | 'Sr': 2, 227 | 'Y': 3, 228 | 'Zr': 2, 229 | 'Nb': 2, 230 | 'Mo': 2, 231 | 'Tc': 6, 232 | 'Ru': 2, 233 | 'Rh': 3, 234 | 'Pd': 2, 235 | 'Ag': 1, 236 | 'Cd': 1, 237 | 'In': 1, 238 | 'Sn': 2, 239 | 'Sb': 3, 240 | 'Te': 2, 241 | 'I': 1, 242 | 'Xe': 0, 243 | 'Cs': 1, 244 | 'Ba': 2, 245 | 'La': 3, 246 | 'Ce': 3, 247 | 'Pr': 3, 248 | 'Nd': 3, 249 | 'Pm': 3, 250 | 'Sm': 2, 251 | 'Eu': 2, 252 | 'Gd': 3, 253 | 'Tb': 3, 254 | 'Dy': 3, 255 | 'Ho': 3, 256 | 'Er': 3, 257 | 'Tm': 2, 258 | 'Yb': 2, 259 | 'Lu': 3, 260 | 'Hf': 4, 261 | 'Ta': 3, 262 | 'W': 2, 263 | 'Re': 1, 264 | 'Os': 2, 265 | 'Ir': 1, 266 | 'Pt': 1, 267 | 'Au': 1, 268 | 'Hg': 1, 269 | 'Tl': 1, 270 | 'Pb': 2, 271 | 'Bi': 3, 272 | 'Po': 2, 273 | 'At': 1, 274 | 'Rn': 0, 275 | 'Fr': 1, 276 | 'Ra': 2, 277 | 'Ac': 3, 278 | 'Th': 4, 279 | 'Pa': 5, 280 | 'U': 2, 281 | } 282 | 283 | ATOM_TO_WEIGHT = { 284 | 'H': 1, 285 | 'He': 4, 286 | 'Li': 7, 287 | 'Be': 9, 288 | 'B': 11, 289 | 'C': 12, 290 | 'N': 14, 291 | 'O': 16, 292 | 'F': 19, 293 | 'Ne': 20, 294 | 'Na': 23, 295 | 'Mg': 24, 296 | 'Al': 27, 297 | 'Si': 28, 298 | 'P': 31, 299 | 'S': 32, 300 | 'Cl': 35, 301 | 'Ar': 40, 302 | 'K': 39, 303 | 'Ca': 40, 304 | 'Sc': 45, 305 | 'Ti': 48, 306 | 'V': 51, 307 | 'Cr': 52, 308 | 'Mn': 55, 309 | 'Fe': 56, 310 | 'Co': 59, 311 | 'Ni': 59, 312 | 'Cu': 64, 313 | 'Zn': 65, 314 | 'Ga': 70, 315 | 'Ge': 73, 316 | 'As': 75, 317 | 'Se': 79, 318 | 'Br': 80, 319 | 'Kr': 84, 320 | 'Rb': 85, 321 | 'Sr': 88, 322 | 'Y': 89, 323 | 'Zr': 91, 324 | 'Nb': 93, 325 | 'Mo': 96, 326 | 'Tc': 98, 327 | 'Ru': 101, 328 | 'Rh': 103, 329 | 'Pd': 106, 330 | 'Ag': 108, 331 | 'Cd': 112, 332 | 'In': 115, 333 | 'Sn': 119, 334 | 'Sb': 122, 335 | 'Te': 128, 336 | 'I': 127, 337 | 'Xe': 131, 338 | 'Cs': 133, 339 | 'Ba': 137, 340 | 'La': 139, 341 | 'Ce': 140, 342 | 'Pr': 141, 343 | 'Nd': 144, 344 | 'Pm': 145, 345 | 'Sm': 150, 346 | 'Eu': 152, 347 | 'Gd': 157, 348 | 'Tb': 159, 349 | 'Dy': 163, 350 | 'Ho': 165, 351 | 'Er': 167, 352 | 'Tm': 169, 353 | 'Yb': 173, 354 | 'Lu': 175, 355 | 'Hf': 178, 356 | 'Ta': 181, 357 | 'W': 184, 358 | 'Re': 186, 359 | 'Os': 190, 360 | 'Ir': 192, 361 | 'Pt': 195, 362 | 'Au': 197, 363 | 'Hg': 201, 364 | 'Tl': 204, 365 | 
'Pb': 207, 366 | 'Bi': 209, 367 | 'Po': 209, 368 | 'At': 210, 369 | 'Rn': 222, 370 | 'Fr': 223, 371 | 'Ra': 226, 372 | 'Ac': 227, 373 | 'Th': 232, 374 | 'Pa': 231, 375 | 'U': 238, 376 | 'Np': 237, 377 | 'Pu': 244, 378 | 'Am': 243, 379 | 'Cm': 247, 380 | 'Bk': 247, 381 | 'Cf': 251, 382 | 'Es': 252, 383 | 'Fm': 257, 384 | 'Md': 258, 385 | 'No': 259, 386 | 'Lr': 262, 387 | 'Rf': 267, 388 | 'Db': 270, 389 | 'Sg': 269, 390 | 'Bh': 264, 391 | 'Hs': 269, 392 | 'Mt': 278, 393 | 'Ds': 281, 394 | 'Rg': 282, 395 | 'Cn': 285, 396 | 'Nh': 286, 397 | 'Fl': 289, 398 | 'Mc': 290, 399 | 'Lv': 293, 400 | 'Ts': 294, 401 | 'Og': 294, 402 | } -------------------------------------------------------------------------------- /src/datasets/fp2mol_dataset.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem, RDLogger 2 | from rdkit.Chem.rdchem import BondType as BT 3 | 4 | import os 5 | import pathlib 6 | from typing import Any, Sequence 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from tqdm import tqdm 11 | import numpy as np 12 | from torch_geometric.data import Data, InMemoryDataset 13 | import pandas as pd 14 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect 15 | from joblib import Parallel, delayed 16 | from tqdm_joblib import tqdm_joblib 17 | 18 | from src import utils 19 | from src.analysis.rdkit_functions import mol2smiles, build_molecule_with_partial_charges, compute_molecular_metrics 20 | from src.datasets.abstract_dataset import AbstractDatasetInfos, MolecularDataModule 21 | from src.datasets.abstract_dataset import ATOM_TO_VALENCY, ATOM_TO_WEIGHT 22 | 23 | def to_list(value: Any) -> Sequence: 24 | if isinstance(value, Sequence) and not isinstance(value, str): 25 | return value 26 | else: 27 | return [value] 28 | 29 | def process_single_inchi(args): 30 | """ 31 | Process a single inchi string. 32 | 33 | Parameters: 34 | args: tuple of (i, inchi, types, bonds, morgan_r, morgan_nbits, 35 | filter_dataset, pre_filter, pre_transform, atom_decoder) 36 | Returns: 37 | If filter_dataset is True: a tuple (data, smiles) if the molecule passes filtering, 38 | or None otherwise. 39 | Otherwise: the processed Data object (or None if it fails). 
40 | """ 41 | RDLogger.DisableLog('rdApp.*') 42 | 43 | #unpack args 44 | (i, inchi, types, bonds, morgan_r, morgan_nbits, 45 | filter_dataset, pre_filter, pre_transform, atom_decoder) = args 46 | 47 | try: 48 | mol = Chem.MolFromInchi(inchi) 49 | if mol is None: 50 | return None 51 | # Remove stereochemistry information 52 | smi = Chem.MolToSmiles(mol, isomericSmiles=False) 53 | mol = Chem.MolFromSmiles(smi) 54 | if mol is None: 55 | return None 56 | N = mol.GetNumAtoms() 57 | type_idx = [] 58 | for atom in mol.GetAtoms(): 59 | symbol = atom.GetSymbol() 60 | if symbol not in types: 61 | return None # Skip if unknown atom is encountered 62 | type_idx.append(types[symbol]) 63 | row, col, edge_type = [], [], [] 64 | for bond in mol.GetBonds(): 65 | start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() 66 | row += [start, end] 67 | col += [end, start] 68 | edge_type += 2 * [bonds[bond.GetBondType()] + 1] 69 | if len(row) == 0: 70 | return None 71 | edge_index = torch.tensor([row, col], dtype=torch.long) 72 | edge_type = torch.tensor(edge_type, dtype=torch.long) 73 | edge_attr = F.one_hot(edge_type, num_classes=len(bonds) + 1).to(torch.float) 74 | perm = (edge_index[0] * N + edge_index[1]).argsort() 75 | edge_index = edge_index[:, perm] 76 | edge_attr = edge_attr[perm] 77 | x = F.one_hot(torch.tensor(type_idx), num_classes=len(types)).float() 78 | fp = GetMorganFingerprintAsBitVect(mol, morgan_r, nBits=morgan_nbits) 79 | y = torch.tensor(np.asarray(fp, dtype=np.int8)).unsqueeze(0) 80 | inchi_canonical = Chem.MolToInchi(mol) 81 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i, inchi=inchi_canonical) 82 | 83 | if filter_dataset: 84 | # Filtering: rebuild the molecule from the graph 85 | batch = getattr(data, 'batch', torch.zeros(data.x.size(0), dtype=torch.long)) 86 | dense_data, node_mask = utils.to_dense(data.x, data.edge_index, data.edge_attr, batch) 87 | dense_data = dense_data.mask(node_mask, collapse=True) 88 | X, E = dense_data.X, dense_data.E 89 | if X.size(0) != 1: 90 | return None 91 | atom_types = X[0] 92 | edge_types = E[0] 93 | mol_reconstructed = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder) 94 | smiles = mol2smiles(mol_reconstructed) 95 | if smiles is not None: 96 | try: 97 | mol_frags = Chem.rdmolops.GetMolFrags(mol_reconstructed, asMols=True, sanitizeFrags=True) 98 | if len(mol_frags) == 1: 99 | return (data, smiles) 100 | except Chem.rdchem.AtomValenceException: 101 | print("Valence error in GetMolFrags") 102 | except Chem.rdchem.KekulizeException: 103 | print("Can't kekulize molecule") 104 | return None 105 | else: 106 | if pre_filter is not None and not pre_filter(data): 107 | return None 108 | if pre_transform is not None: 109 | data = pre_transform(data) 110 | return data 111 | except Exception as e: 112 | print(e) 113 | return None 114 | 115 | atom_decoder = ['C', 'O', 'P', 'N', 'S', 'Cl', 'F', 'H'] 116 | valency = [ATOM_TO_VALENCY.get(atom, 0) for atom in atom_decoder] 117 | 118 | # Data sources: 119 | # HMDB: https://hmdb.ca/downloads 120 | # DSSTox: https://clowder.edap-cluster.com/datasets/61147fefe4b0856fdc65639b#folderId=6616d85ce4b063812d70fc8f 121 | # COCONUT: https://zenodo.org/records/13692394 122 | 123 | class FP2MolDataset(InMemoryDataset): 124 | def __init__(self, stage, root, filter_dataset: bool, transform=None, pre_transform=None, pre_filter=None, morgan_r=2, morgan_nBits=2048, dataset='hmdb'): 125 | self.stage = stage 126 | self.atom_decoder = atom_decoder 127 | self.filter_dataset = filter_dataset 
128 | 129 | self.morgan_r = morgan_r 130 | self.morgan_nbits = morgan_nBits 131 | self.dataset = dataset 132 | 133 | self._processed_dir = os.path.join(root, 'processed', f'morgan_r-{self.morgan_r}__morgan_nbits-{self.morgan_nbits}') 134 | self._raw_dir = os.path.join(root, 'preprocessed') 135 | 136 | if self.stage == 'train': self.file_idx = 0 137 | elif self.stage == 'val': self.file_idx = 1 138 | elif self.stage == 'test': self.file_idx = 1 139 | else: raise ValueError(f"Invalid stage {self.stage}") 140 | 141 | super().__init__(root, None, pre_transform, pre_filter) 142 | self.data, self.slices = torch.load(self.processed_paths[self.file_idx]) 143 | 144 | @property 145 | def processed_dir(self): 146 | return self._processed_dir 147 | 148 | @property 149 | def raw_file_names(self): 150 | return [f"{self.dataset}_train.csv", f"{self.dataset}_val.csv"] 151 | 152 | @property 153 | def split_file_name(self): 154 | return [f"{self.dataset}_train.csv", f"{self.dataset}_val.csv"] 155 | 156 | 157 | @property 158 | def split_paths(self): 159 | r"""The absolute filepaths that must be present in order to skip 160 | splitting.""" 161 | files = to_list(self.split_file_name) 162 | return [os.path.join(self._raw_dir, f) for f in files] 163 | 164 | @property 165 | def processed_file_names(self): 166 | return ['train.pt', 'val.pt', 'test.pt'] 167 | 168 | def process(self): 169 | RDLogger.DisableLog('rdApp.*') 170 | types = {atom: i for i, atom in enumerate(self.atom_decoder)} 171 | 172 | bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} 173 | 174 | path = self.split_paths[self.file_idx] 175 | inchi_list = pd.read_csv(path)['inchi'].values 176 | 177 | if not os.path.exists(self.processed_paths[self.file_idx]): 178 | data_list = [] 179 | smiles_kept = [] 180 | 181 | # Build the argument list for parallel processing. 182 | args_list = [ 183 | (i, inchi, types, bonds, self.morgan_r, self.morgan_nbits, 184 | self.filter_dataset, self.pre_filter, self.pre_transform, self.atom_decoder) 185 | for i, inchi in enumerate(inchi_list) 186 | ] 187 | 188 | # Use joblib's Parallel with tqdm_joblib to show a progress bar. 
189 | with tqdm_joblib(tqdm(desc="Processing inchi.....", total=len(args_list), leave=False)) as progress_bar: 190 | 191 | results = Parallel(n_jobs=-1)(delayed(process_single_inchi)(arg) for arg in args_list) 192 | 193 | # Process results: if filter_dataset is enabled, result is a tuple (data, smiles) 194 | for result in tqdm(results, desc="Filtering graphs.....", total=len(results), leave=False): 195 | if result is not None: 196 | if self.filter_dataset: 197 | data, smiles = result 198 | data_list.append(data) 199 | smiles_kept.append(smiles) 200 | else: 201 | data_list.append(result) 202 | 203 | torch.save(self.collate(data_list), self.processed_paths[self.file_idx]) 204 | 205 | class FP2MolDataModule(MolecularDataModule): 206 | def __init__(self, cfg): 207 | self.remove_h = False 208 | self.datadir = cfg.dataset.datadir 209 | self.filter_dataset = cfg.dataset.filter 210 | self.train_smiles = [] 211 | self.dataset_name = cfg.dataset.dataset 212 | self._root_path = os.path.join(cfg.general.parent_dir, self.datadir, self.dataset_name) 213 | datasets = {'train': FP2MolDataset(stage='train', root=self._root_path, filter_dataset=self.filter_dataset, morgan_r=cfg.dataset.morgan_r, morgan_nBits=cfg.dataset.morgan_nbits, dataset=cfg.dataset.dataset), 214 | 'val': FP2MolDataset(stage='val', root=self._root_path, filter_dataset=self.filter_dataset, morgan_r=cfg.dataset.morgan_r, morgan_nBits=cfg.dataset.morgan_nbits, dataset=cfg.dataset.dataset), 215 | 'test': FP2MolDataset(stage='val', root=self._root_path, filter_dataset=self.filter_dataset, morgan_r=cfg.dataset.morgan_r, morgan_nBits=cfg.dataset.morgan_nbits, dataset=cfg.dataset.dataset)} 216 | super().__init__(cfg, datasets) 217 | 218 | 219 | class FP2Mol_infos(AbstractDatasetInfos): 220 | def __init__(self, datamodule, cfg, recompute_statistics=False, meta=None): 221 | self.name = datamodule.dataset_name 222 | self.input_dims = None 223 | self.output_dims = None 224 | self.remove_h = False 225 | 226 | self.atom_decoder = atom_decoder 227 | self.atom_encoder = {atom: i for i, atom in enumerate(self.atom_decoder)} 228 | self.atom_weights = {i: ATOM_TO_WEIGHT.get(atom, 0) for i, atom in enumerate(self.atom_decoder)} 229 | self.valencies = valency 230 | self.num_atom_types = len(self.atom_decoder) 231 | self.max_weight = max(self.atom_weights.values()) 232 | 233 | meta_files = dict(n_nodes=f'{datamodule._root_path}/stats/n_counts.txt', 234 | node_types=f'{datamodule._root_path}/stats/atom_types.txt', 235 | edge_types=f'{datamodule._root_path}/stats/edge_types.txt', 236 | valency_distribution=f'{datamodule._root_path}/stats/valencies.txt') 237 | 238 | # n_nodes and valency_distribution are not transferrable between datatsets because of shape mismatches 239 | if cfg.dataset.stats_dir: 240 | meta_read = dict(n_nodes=f'{datamodule._root_path}/stats/n_counts.txt', 241 | node_types=f'{cfg.dataset.stats_dir}/atom_types.txt', 242 | edge_types=f'{cfg.dataset.stats_dir}/edge_types.txt', 243 | valency_distribution=f'{datamodule._root_path}/stats/valencies.txt') 244 | else: 245 | meta_read = dict(n_nodes=f'{datamodule._root_path}/stats/n_counts.txt', 246 | node_types=f'{datamodule._root_path}/stats/atom_types.txt', 247 | edge_types=f'{datamodule._root_path}/stats/edge_types.txt', 248 | valency_distribution=f'{datamodule._root_path}/stats/valencies.txt') 249 | 250 | 251 | self.n_nodes = None 252 | self.node_types = None 253 | self.edge_types = None 254 | self.valency_distribution = None 255 | 256 | if meta is None: 257 | meta = dict(n_nodes=None, 
node_types=None, edge_types=None, valency_distribution=None) 258 | assert set(meta.keys()) == set(meta_files.keys()) 259 | 260 | for k, v in meta_read.items(): 261 | if (k not in meta or meta[k] is None) and os.path.exists(v): 262 | meta[k] = torch.tensor(np.loadtxt(v)) 263 | setattr(self, k, meta[k]) 264 | 265 | self.max_n_nodes = len(self.n_nodes) - 1 if self.n_nodes is not None else None 266 | 267 | if recompute_statistics or self.n_nodes is None: 268 | self.n_nodes = datamodule.node_counts() 269 | print("Distribution of number of nodes", self.n_nodes) 270 | np.savetxt(meta_files["n_nodes"], self.n_nodes.numpy()) 271 | self.max_n_nodes = len(self.n_nodes) - 1 272 | if recompute_statistics or self.node_types is None: 273 | self.node_types = datamodule.node_types() # There are no node types 274 | print("Distribution of node types", self.node_types) 275 | np.savetxt(meta_files["node_types"], self.node_types.numpy()) 276 | 277 | if recompute_statistics or self.edge_types is None: 278 | self.edge_types = datamodule.edge_counts() 279 | print("Distribution of edge types", self.edge_types) 280 | np.savetxt(meta_files["edge_types"], self.edge_types.numpy()) 281 | if recompute_statistics or self.valency_distribution is None: 282 | valencies = datamodule.valency_count(self.max_n_nodes) 283 | print("Distribution of the valencies", valencies) 284 | np.savetxt(meta_files["valency_distribution"], valencies.numpy()) 285 | self.valency_distribution = valencies 286 | 287 | self.complete_infos(n_nodes=self.n_nodes, node_types=self.node_types) -------------------------------------------------------------------------------- /src/datasets/spec2mol_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from typing import Any, Sequence 4 | import random 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from torch_geometric.loader import DataLoader 9 | from tqdm import tqdm 10 | from rdkit import Chem, RDLogger 11 | from rdkit.Chem.rdchem import BondType as BT 12 | from rdkit.Chem import Descriptors 13 | import torch 14 | import torch.nn.functional as F 15 | from torch_geometric.data import Data, InMemoryDataset, download_url 16 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect 17 | 18 | from src.mist.data import datasets, splitter, featurizers 19 | from src.datasets.abstract_dataset import AbstractDatasetInfos, MolecularDataModule 20 | import src.utils as utils 21 | from src.mist.data.datasets import get_paired_loader_graph 22 | from src.datasets.abstract_dataset import ATOM_TO_VALENCY, ATOM_TO_WEIGHT 23 | 24 | def to_list(value: Any) -> Sequence: 25 | if isinstance(value, Sequence) and not isinstance(value, str): 26 | return value 27 | else: 28 | return [value] 29 | 30 | atom_decoder = ['C', 'O', 'P', 'N', 'S', 'Cl', 'F', 'H'] 31 | valency = [ATOM_TO_VALENCY.get(atom, 0) for atom in atom_decoder] 32 | 33 | 34 | class Spec2MolDataModule(MolecularDataModule): 35 | def __init__(self, cfg): 36 | self.remove_h = False 37 | self.datadir = cfg.dataset.datadir 38 | self.filter_dataset = cfg.dataset.filter 39 | self.train_smiles = [] 40 | 41 | data_splitter = splitter.PresetSpectraSplitter(split_file=cfg.dataset.split_file) 42 | 43 | paired_featurizer = featurizers.PairedFeaturizer( 44 | spec_featurizer=featurizers.PeakFormula(**cfg.dataset), 45 | mol_featurizer=featurizers.FingerprintFeaturizer(fp_names=['morgan4096'], **cfg.dataset), 46 | graph_featurizer=featurizers.GraphFeaturizer(**cfg.dataset), 47 | ) 48 | 49 | 
spectra_mol_pairs = datasets.get_paired_spectra(**cfg.dataset) 50 | spectra_mol_pairs = list(zip(*spectra_mol_pairs)) 51 | 52 | # Redefine splitter s.t. this splits three times and remove subsetting 53 | split_name, (train, val, test) = data_splitter.get_splits(spectra_mol_pairs) 54 | 55 | # randomly shuffle test set with fixed seed 56 | random.seed(42) 57 | random.shuffle(test) 58 | 59 | ms_datasets = {'train': datasets.SpectraMolDataset(spectra_mol_list=train, featurizer=paired_featurizer, **cfg.dataset), 60 | 'val': datasets.SpectraMolDataset(spectra_mol_list=val, featurizer=paired_featurizer, **cfg.dataset), 61 | 'test': datasets.SpectraMolDataset(spectra_mol_list=test, featurizer=paired_featurizer, **cfg.dataset)} 62 | super().__init__(cfg, ms_datasets) 63 | 64 | def train_dataloader(self) -> DataLoader: 65 | return get_paired_loader_graph(self.train_dataset, shuffle=True, batch_size=self.batch_size, **self.kwargs) 66 | 67 | def val_dataloader(self) -> DataLoader: 68 | return get_paired_loader_graph(self.val_dataset, shuffle=False, batch_size=self.eval_batch_size, **self.kwargs) 69 | 70 | def test_dataloader(self) -> DataLoader: 71 | return get_paired_loader_graph(self.test_dataset, shuffle=False, batch_size=self.eval_batch_size, **self.kwargs) 72 | 73 | def valency_count(self, max_n_nodes): 74 | valencies = torch.zeros(3 * max_n_nodes - 2) # Max valency possible if everything is connected 75 | 76 | # No bond, single bond, double bond, triple bond, aromatic bond 77 | multiplier = torch.tensor([0, 1, 2, 3, 1.5]) 78 | 79 | for batch in self.train_dataloader(): 80 | data = batch['graph'] 81 | n = data.x.shape[0] 82 | 83 | for atom in range(n): 84 | edges = data.edge_attr[data.edge_index[0] == atom] 85 | edges_total = edges.sum(dim=0) 86 | valency = (edges_total * multiplier).sum() 87 | valencies[valency.long().item()] += 1 88 | valencies = valencies / valencies.sum() 89 | return valencies 90 | 91 | def node_counts(self, max_nodes_possible=150): 92 | all_counts = torch.zeros(max_nodes_possible) 93 | for loader in [self.train_dataloader(), self.val_dataloader()]: 94 | for batch in loader: 95 | data = batch['graph'] 96 | unique, counts = torch.unique(data.batch, return_counts=True) 97 | for count in counts: 98 | all_counts[count] += 1 99 | max_index = max(all_counts.nonzero()) 100 | all_counts = all_counts[:max_index + 1] 101 | all_counts = all_counts / all_counts.sum() 102 | return all_counts 103 | 104 | def node_types(self): 105 | num_classes = None 106 | for batch in self.train_dataloader(): 107 | data = batch['graph'] 108 | num_classes = data.x.shape[1] 109 | break 110 | 111 | counts = torch.zeros(num_classes) 112 | 113 | for i, batch in enumerate(self.train_dataloader()): 114 | data = batch['graph'] 115 | counts += data.x.sum(dim=0) 116 | 117 | counts = counts / counts.sum() 118 | return counts 119 | 120 | def edge_counts(self): 121 | num_classes = None 122 | for batch in self.train_dataloader(): 123 | data = batch['graph'] 124 | num_classes = data.edge_attr.shape[1] 125 | break 126 | 127 | d = torch.zeros(num_classes, dtype=torch.float) 128 | 129 | for i, batch in enumerate(self.train_dataloader()): 130 | data = batch['graph'] 131 | unique, counts = torch.unique(data.batch, return_counts=True) 132 | 133 | all_pairs = 0 134 | for count in counts: 135 | all_pairs += count * (count - 1) 136 | 137 | num_edges = data.edge_index.shape[1] 138 | num_non_edges = all_pairs - num_edges 139 | 140 | edge_types = data.edge_attr.sum(dim=0) 141 | assert num_non_edges >= 0 142 | d[0] += num_non_edges 
143 | d[1:] += edge_types[1:] 144 | 145 | d = d / d.sum() 146 | return d 147 | 148 | 149 | class Spec2MolDatasetInfos(AbstractDatasetInfos): 150 | def __init__(self, datamodule, cfg, recompute_statistics=False, meta=None): 151 | self.name = 'canopus' 152 | self.input_dims = None 153 | self.output_dims = None 154 | self.remove_h = False 155 | 156 | self.atom_decoder = atom_decoder 157 | self.atom_encoder = {atom: i for i, atom in enumerate(self.atom_decoder)} 158 | self.atom_weights = {i: ATOM_TO_WEIGHT.get(atom, 0) for i, atom in enumerate(self.atom_decoder)} 159 | self.valencies = valency 160 | self.num_atom_types = len(self.atom_decoder) 161 | self.max_weight = max(self.atom_weights.values()) 162 | self._root_path = os.path.join(cfg.general.parent_dir, cfg.dataset.datadir) 163 | 164 | meta_files = dict(n_nodes=f'{self._root_path}/n_counts.txt', 165 | node_types=f'{self._root_path}/atom_types.txt', 166 | edge_types=f'{self._root_path}/edge_types.txt', 167 | valency_distribution=f'{self._root_path}/valencies.txt') 168 | 169 | if cfg.dataset.stats_dir: 170 | meta_read = dict(n_nodes=f'{self._root_path}/n_counts.txt', 171 | node_types=f'{cfg.dataset.stats_dir}/atom_types.txt', 172 | edge_types=f'{cfg.dataset.stats_dir}/edge_types.txt', 173 | valency_distribution=f'{self._root_path}/valencies.txt') 174 | else: 175 | meta_read = dict(n_nodes=f'{self._root_path}/n_counts.txt', 176 | node_types=f'{self._root_path}/atom_types.txt', 177 | edge_types=f'{self._root_path}/edge_types.txt', 178 | valency_distribution=f'{self._root_path}/valencies.txt') 179 | 180 | self.n_nodes = None 181 | self.node_types = None 182 | self.edge_types = None 183 | self.valency_distribution = None 184 | 185 | if meta is None: 186 | meta = dict(n_nodes=None, node_types=None, edge_types=None, valency_distribution=None) 187 | assert set(meta.keys()) == set(meta_files.keys()) 188 | 189 | for k, v in meta_read.items(): 190 | if (k not in meta or meta[k] is None) and os.path.exists(v): 191 | meta[k] = torch.tensor(np.loadtxt(v)) 192 | setattr(self, k, meta[k]) 193 | 194 | self.max_n_nodes = len(self.n_nodes) - 1 if self.n_nodes is not None else None 195 | 196 | if recompute_statistics or self.n_nodes is None: 197 | self.n_nodes = datamodule.node_counts() 198 | print("Distribution of number of nodes", self.n_nodes) 199 | np.savetxt(meta_files["n_nodes"], self.n_nodes.numpy()) 200 | self.max_n_nodes = len(self.n_nodes) - 1 201 | if recompute_statistics or self.node_types is None: 202 | self.node_types = datamodule.node_types() # There are no node types 203 | print("Distribution of node types", self.node_types) 204 | np.savetxt(meta_files["node_types"], self.node_types.numpy()) 205 | 206 | if recompute_statistics or self.edge_types is None: 207 | self.edge_types = datamodule.edge_counts() 208 | print("Distribution of edge types", self.edge_types) 209 | np.savetxt(meta_files["edge_types"], self.edge_types.numpy()) 210 | if recompute_statistics or self.valency_distribution is None: 211 | valencies = datamodule.valency_count(self.max_n_nodes) 212 | print("Distribution of the valencies", valencies) 213 | np.savetxt(meta_files["valency_distribution"], valencies.numpy()) 214 | self.valency_distribution = valencies 215 | 216 | self.complete_infos(n_nodes=self.n_nodes, node_types=self.node_types) 217 | 218 | def compute_input_output_dims(self, datamodule, extra_features, domain_features): 219 | example_batch = next(iter(datamodule.train_dataloader()))['graph'] 220 | ex_dense, node_mask = utils.to_dense(example_batch.x, 
example_batch.edge_index, example_batch.edge_attr, 221 | example_batch.batch) 222 | example_data = {'X_t': ex_dense.X, 'E_t': ex_dense.E, 'y_t': example_batch['y'], 'node_mask': node_mask} 223 | 224 | self.input_dims = {'X': example_batch['x'].size(1), 225 | 'E': example_batch['edge_attr'].size(1), 226 | 'y': example_batch['y'].size(1) + 1} # + 1 due to time conditioning 227 | 228 | ex_extra_feat = extra_features(example_data) 229 | self.input_dims['X'] += ex_extra_feat.X.size(-1) 230 | self.input_dims['E'] += ex_extra_feat.E.size(-1) 231 | self.input_dims['y'] += ex_extra_feat.y.size(-1) 232 | 233 | ex_extra_molecular_feat = domain_features(example_data) 234 | self.input_dims['X'] += ex_extra_molecular_feat.X.size(-1) 235 | self.input_dims['E'] += ex_extra_molecular_feat.E.size(-1) 236 | self.input_dims['y'] += ex_extra_molecular_feat.y.size(-1) 237 | 238 | self.output_dims = {'X': example_batch['x'].size(1), 239 | 'E': example_batch['edge_attr'].size(1), 240 | 'y': example_batch['y'].size(1)} -------------------------------------------------------------------------------- /src/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/DiffMS/9d9f4fd497162eec045f7db2787da30ce69a9622/src/diffusion/__init__.py -------------------------------------------------------------------------------- /src/diffusion/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DistributionNodes: 5 | def __init__(self, histogram): 6 | """ Compute the distribution of the number of nodes in the dataset, and sample from this distribution. 7 | historgram: dict. The keys are num_nodes, the values are counts 8 | """ 9 | 10 | if type(histogram) == dict: 11 | max_n_nodes = max(histogram.keys()) 12 | prob = torch.zeros(max_n_nodes + 1) 13 | for num_nodes, count in histogram.items(): 14 | prob[num_nodes] = count 15 | else: 16 | prob = histogram 17 | 18 | self.prob = prob / prob.sum() 19 | self.m = torch.distributions.Categorical(prob) 20 | 21 | def sample_n(self, n_samples, device): 22 | idx = self.m.sample((n_samples,)) 23 | return idx.to(device) 24 | 25 | def log_prob(self, batch_n_nodes): 26 | assert len(batch_n_nodes.size()) == 1 27 | p = self.prob.to(batch_n_nodes.device) 28 | 29 | probas = p[batch_n_nodes] 30 | log_p = torch.log(probas + 1e-6) 31 | return log_p 32 | -------------------------------------------------------------------------------- /src/diffusion/extra_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src import utils 3 | 4 | 5 | class DummyExtraFeatures: 6 | def __init__(self): 7 | """ This class does not compute anything, just returns empty tensors.""" 8 | 9 | def __call__(self, noisy_data): 10 | X = noisy_data['X_t'] 11 | E = noisy_data['E_t'] 12 | y = noisy_data['y_t'] 13 | empty_x = X.new_zeros((*X.shape[:-1], 0)) 14 | empty_e = E.new_zeros((*E.shape[:-1], 0)) 15 | empty_y = y.new_zeros((y.shape[0], 0)) 16 | return utils.PlaceHolder(X=empty_x, E=empty_e, y=empty_y) 17 | 18 | 19 | class ExtraFeatures: 20 | def __init__(self, extra_features_type, dataset_info): 21 | self.max_n_nodes = dataset_info.max_n_nodes 22 | self.ncycles = NodeCycleFeatures() 23 | self.features_type = extra_features_type 24 | if extra_features_type in ['eigenvalues', 'all']: 25 | self.eigenfeatures = EigenFeatures(mode=extra_features_type) 26 | 27 | def __call__(self, noisy_data): 28 | n = 
noisy_data['node_mask'].sum(dim=1).unsqueeze(1) / self.max_n_nodes 29 | x_cycles, y_cycles = self.ncycles(noisy_data) # (bs, n_cycles) 30 | 31 | if self.features_type == 'cycles': 32 | E = noisy_data['E_t'] 33 | extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E) 34 | return utils.PlaceHolder(X=x_cycles, E=extra_edge_attr, y=torch.hstack((n, y_cycles))) 35 | 36 | elif self.features_type == 'eigenvalues': 37 | eigenfeatures = self.eigenfeatures(noisy_data) 38 | E = noisy_data['E_t'] 39 | extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E) 40 | n_components, batched_eigenvalues = eigenfeatures # (bs, 1), (bs, 10) 41 | return utils.PlaceHolder(X=x_cycles, E=extra_edge_attr, y=torch.hstack((n, y_cycles, n_components, 42 | batched_eigenvalues))) 43 | elif self.features_type == 'all': 44 | eigenfeatures = self.eigenfeatures(noisy_data) 45 | E = noisy_data['E_t'] 46 | extra_edge_attr = torch.zeros((*E.shape[:-1], 0)).type_as(E) 47 | n_components, batched_eigenvalues, nonlcc_indicator, k_lowest_eigvec = eigenfeatures # (bs, 1), (bs, 10), 48 | # (bs, n, 1), (bs, n, 2) 49 | 50 | return utils.PlaceHolder(X=torch.cat((x_cycles, nonlcc_indicator, k_lowest_eigvec), dim=-1), 51 | E=extra_edge_attr, 52 | y=torch.hstack((n, y_cycles, n_components, batched_eigenvalues))) 53 | else: 54 | raise ValueError(f"Features type {self.features_type} not implemented") 55 | 56 | 57 | class NodeCycleFeatures: 58 | def __init__(self): 59 | self.kcycles = KNodeCycles() 60 | 61 | def __call__(self, noisy_data): 62 | adj_matrix = noisy_data['E_t'][..., 1:].sum(dim=-1).float() 63 | 64 | x_cycles, y_cycles = self.kcycles.k_cycles(adj_matrix=adj_matrix) # (bs, n_cycles) 65 | x_cycles = x_cycles.type_as(adj_matrix) * noisy_data['node_mask'].unsqueeze(-1) 66 | # Avoid large values when the graph is dense 67 | x_cycles = x_cycles / 10 68 | y_cycles = y_cycles / 10 69 | x_cycles[x_cycles > 1] = 1 70 | y_cycles[y_cycles > 1] = 1 71 | return x_cycles, y_cycles 72 | 73 | 74 | class EigenFeatures: 75 | """ 76 | Code taken from : https://github.com/Saro00/DGN/blob/master/models/pytorch/eigen_agg.py 77 | """ 78 | def __init__(self, mode): 79 | """ mode: 'eigenvalues' or 'all' """ 80 | self.mode = mode 81 | 82 | def __call__(self, noisy_data): 83 | E_t = noisy_data['E_t'] 84 | mask = noisy_data['node_mask'] 85 | A = E_t[..., 1:].sum(dim=-1).float() * mask.unsqueeze(1) * mask.unsqueeze(2) 86 | L = compute_laplacian(A, normalize=False) 87 | mask_diag = 2 * L.shape[-1] * torch.eye(A.shape[-1]).type_as(L).unsqueeze(0) 88 | mask_diag = mask_diag * (~mask.unsqueeze(1)) * (~mask.unsqueeze(2)) 89 | L = L * mask.unsqueeze(1) * mask.unsqueeze(2) + mask_diag 90 | 91 | if self.mode == 'eigenvalues': 92 | eigvals = torch.linalg.eigvalsh(L) # bs, n 93 | eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True) 94 | 95 | n_connected_comp, batch_eigenvalues = get_eigenvalues_features(eigenvalues=eigvals) 96 | return n_connected_comp.type_as(A), batch_eigenvalues.type_as(A) 97 | 98 | elif self.mode == 'all': 99 | eigvals, eigvectors = torch.linalg.eigh(L) 100 | eigvals = eigvals.type_as(A) / torch.sum(mask, dim=1, keepdim=True) 101 | eigvectors = eigvectors * mask.unsqueeze(2) * mask.unsqueeze(1) 102 | # Retrieve eigenvalues features 103 | n_connected_comp, batch_eigenvalues = get_eigenvalues_features(eigenvalues=eigvals) 104 | 105 | # Retrieve eigenvectors features 106 | nonlcc_indicator, k_lowest_eigenvector = get_eigenvectors_features(vectors=eigvectors, 107 | node_mask=noisy_data['node_mask'], 108 | 
n_connected=n_connected_comp) 109 | return n_connected_comp, batch_eigenvalues, nonlcc_indicator, k_lowest_eigenvector 110 | else: 111 | raise NotImplementedError(f"Mode {self.mode} is not implemented") 112 | 113 | 114 | def compute_laplacian(adjacency, normalize: bool): 115 | """ 116 | adjacency : batched adjacency matrix (bs, n, n) 117 | normalize: can be None, 'sym' or 'rw' for the combinatorial, symmetric normalized or random walk Laplacians 118 | Return: 119 | L (n x n ndarray): combinatorial or symmetric normalized Laplacian. 120 | """ 121 | diag = torch.sum(adjacency, dim=-1) # (bs, n) 122 | n = diag.shape[-1] 123 | D = torch.diag_embed(diag) # Degree matrix # (bs, n, n) 124 | combinatorial = D - adjacency # (bs, n, n) 125 | 126 | if not normalize: 127 | return (combinatorial + combinatorial.transpose(1, 2)) / 2 128 | 129 | diag0 = diag.clone() 130 | diag[diag == 0] = 1e-12 131 | 132 | diag_norm = 1 / torch.sqrt(diag) # (bs, n) 133 | D_norm = torch.diag_embed(diag_norm) # (bs, n, n) 134 | L = torch.eye(n).unsqueeze(0) - D_norm @ adjacency @ D_norm 135 | L[diag0 == 0] = 0 136 | return (L + L.transpose(1, 2)) / 2 137 | 138 | 139 | def get_eigenvalues_features(eigenvalues, k=5): 140 | """ 141 | values : eigenvalues -- (bs, n) 142 | node_mask: (bs, n) 143 | k: num of non zero eigenvalues to keep 144 | """ 145 | ev = eigenvalues 146 | bs, n = ev.shape 147 | n_connected_components = (ev < 1e-5).sum(dim=-1) 148 | assert (n_connected_components > 0).all(), (n_connected_components, ev) 149 | 150 | to_extend = max(n_connected_components) + k - n 151 | if to_extend > 0: 152 | eigenvalues = torch.hstack((eigenvalues, 2 * torch.ones(bs, to_extend).type_as(eigenvalues))) 153 | indices = torch.arange(k).type_as(eigenvalues).long().unsqueeze(0) + n_connected_components.unsqueeze(1) 154 | first_k_ev = torch.gather(eigenvalues, dim=1, index=indices) 155 | return n_connected_components.unsqueeze(-1), first_k_ev 156 | 157 | 158 | def get_eigenvectors_features(vectors, node_mask, n_connected, k=2): 159 | """ 160 | vectors (bs, n, n) : eigenvectors of Laplacian IN COLUMNS 161 | returns: 162 | not_lcc_indicator : indicator vectors of largest connected component (lcc) for each graph -- (bs, n, 1) 163 | k_lowest_eigvec : k first eigenvectors for the largest connected component -- (bs, n, k) 164 | """ 165 | bs, n = vectors.size(0), vectors.size(1) 166 | 167 | # Create an indicator for the nodes outside the largest connected components 168 | first_ev = torch.round(vectors[:, :, 0], decimals=3) * node_mask # bs, n 169 | # Add random value to the mask to prevent 0 from becoming the mode 170 | random = torch.randn(bs, n, device=node_mask.device) * (~node_mask) # bs, n 171 | first_ev = first_ev + random 172 | most_common = torch.mode(first_ev, dim=1).values # values: bs -- indices: bs 173 | mask = ~ (first_ev == most_common.unsqueeze(1)) 174 | not_lcc_indicator = (mask * node_mask).unsqueeze(-1).float() 175 | 176 | # Get the eigenvectors corresponding to the first nonzero eigenvalues 177 | to_extend = max(n_connected) + k - n 178 | if to_extend > 0: 179 | vectors = torch.cat((vectors, torch.zeros(bs, n, to_extend).type_as(vectors)), dim=2) # bs, n , n + to_extend 180 | indices = torch.arange(k).type_as(vectors).long().unsqueeze(0).unsqueeze(0) + n_connected.unsqueeze(2) # bs, 1, k 181 | indices = indices.expand(-1, n, -1) # bs, n, k 182 | first_k_ev = torch.gather(vectors, dim=2, index=indices) # bs, n, k 183 | first_k_ev = first_k_ev * node_mask.unsqueeze(2) 184 | 185 | return not_lcc_indicator, first_k_ev 186 | 
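A quick check of the eigenvalue helpers above: a minimal sketch (the toy graph is illustrative) in which the number of near-zero Laplacian eigenvalues equals the number of connected components, which is exactly the first output of get_eigenvalues_features().

import torch

# Two disjoint triangles -> two connected components.
tri = torch.tensor([[0., 1., 1.],
                    [1., 0., 1.],
                    [1., 1., 0.]])
A = torch.block_diag(tri, tri).unsqueeze(0)        # (1, 6, 6) batched adjacency
L = compute_laplacian(A, normalize=False)          # combinatorial Laplacian D - A
eigvals = torch.linalg.eigvalsh(L)                 # (1, 6): [0, 0, 3, 3, 3, 3]
n_components, first_k_ev = get_eigenvalues_features(eigvals)
print(n_components)                                # tensor([[2]]) -> two components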
187 | def batch_trace(X): 188 | """ 189 | Expect a matrix of shape B N N, returns the trace in shape B 190 | :param X: 191 | :return: 192 | """ 193 | diag = torch.diagonal(X, dim1=-2, dim2=-1) 194 | trace = diag.sum(dim=-1) 195 | return trace 196 | 197 | 198 | def batch_diagonal(X): 199 | """ 200 | Extracts the diagonal from the last two dims of a tensor 201 | :param X: 202 | :return: 203 | """ 204 | return torch.diagonal(X, dim1=-2, dim2=-1) 205 | 206 | 207 | class KNodeCycles: 208 | """ Builds cycle counts for each node in a graph. 209 | """ 210 | 211 | def __init__(self): 212 | super().__init__() 213 | 214 | def calculate_kpowers(self): 215 | self.k1_matrix = self.adj_matrix.float() 216 | self.d = self.adj_matrix.sum(dim=-1) 217 | self.k2_matrix = self.k1_matrix @ self.adj_matrix.float() 218 | self.k3_matrix = self.k2_matrix @ self.adj_matrix.float() 219 | self.k4_matrix = self.k3_matrix @ self.adj_matrix.float() 220 | self.k5_matrix = self.k4_matrix @ self.adj_matrix.float() 221 | self.k6_matrix = self.k5_matrix @ self.adj_matrix.float() 222 | 223 | def k3_cycle(self): 224 | """ tr(A ** 3). """ 225 | c3 = batch_diagonal(self.k3_matrix) 226 | return (c3 / 2).unsqueeze(-1).float(), (torch.sum(c3, dim=-1) / 6).unsqueeze(-1).float() 227 | 228 | def k4_cycle(self): 229 | diag_a4 = batch_diagonal(self.k4_matrix) 230 | c4 = diag_a4 - self.d * (self.d - 1) - (self.adj_matrix @ self.d.unsqueeze(-1)).sum(dim=-1) 231 | return (c4 / 2).unsqueeze(-1).float(), (torch.sum(c4, dim=-1) / 8).unsqueeze(-1).float() 232 | 233 | def k5_cycle(self): 234 | diag_a5 = batch_diagonal(self.k5_matrix) 235 | triangles = batch_diagonal(self.k3_matrix) / 2 236 | 237 | # Triangle count matrix (indicates for each node i how many triangles it shares with node j) 238 | joint_cycles = self.k2_matrix * self.adj_matrix 239 | # c5 = diag_a5 - 2 * triangles * self.d - (self.adj_matrix @ triangles.unsqueeze(-1)).sum(dim=-1) + triangles 240 | prod = 2 * (joint_cycles @ self.d.unsqueeze(-1)).squeeze(-1) 241 | prod2 = 2 * (self.adj_matrix @ triangles.unsqueeze(-1)).squeeze(-1) 242 | c5 = diag_a5 - prod - 4 * self.d * triangles - prod2 + 10 * triangles 243 | return (c5 / 2).unsqueeze(-1).float(), (c5.sum(dim=-1) / 10).unsqueeze(-1).float() 244 | 245 | def k6_cycle(self): 246 | term_1_t = batch_trace(self.k6_matrix) 247 | term_2_t = batch_trace(self.k3_matrix ** 2) 248 | term3_t = torch.sum(self.adj_matrix * self.k2_matrix.pow(2), dim=[-2, -1]) 249 | d_t4 = batch_diagonal(self.k2_matrix) 250 | a_4_t = batch_diagonal(self.k4_matrix) 251 | term_4_t = (d_t4 * a_4_t).sum(dim=-1) 252 | term_5_t = batch_trace(self.k4_matrix) 253 | term_6_t = batch_trace(self.k3_matrix) 254 | term_7_t = batch_diagonal(self.k2_matrix).pow(3).sum(-1) 255 | term8_t = torch.sum(self.k3_matrix, dim=[-2, -1]) 256 | term9_t = batch_diagonal(self.k2_matrix).pow(2).sum(-1) 257 | term10_t = batch_trace(self.k2_matrix) 258 | 259 | c6_t = (term_1_t - 3 * term_2_t + 9 * term3_t - 6 * term_4_t + 6 * term_5_t - 4 * term_6_t + 4 * term_7_t + 260 | 3 * term8_t - 12 * term9_t + 4 * term10_t) 261 | return None, (c6_t / 12).unsqueeze(-1).float() 262 | 263 | def k_cycles(self, adj_matrix, verbose=False): 264 | self.adj_matrix = adj_matrix 265 | self.calculate_kpowers() 266 | 267 | k3x, k3y = self.k3_cycle() 268 | assert (k3x >= -0.1).all() 269 | 270 | k4x, k4y = self.k4_cycle() 271 | assert (k4x >= -0.1).all() 272 | 273 | k5x, k5y = self.k5_cycle() 274 | assert (k5x >= -0.1).all(), k5x 275 | 276 | _, k6y = self.k6_cycle() 277 | assert (k6y >= -0.1).all() 278 | 279 | kcyclesx 
= torch.cat([k3x, k4x, k5x], dim=-1) 280 | kcyclesy = torch.cat([k3y, k4y, k5y, k6y], dim=-1) 281 | return kcyclesx, kcyclesy -------------------------------------------------------------------------------- /src/diffusion/extra_features_molecular.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src import utils 3 | 4 | 5 | class ExtraMolecularFeatures: 6 | def __init__(self, dataset_infos): 7 | self.charge = ChargeFeature(remove_h=dataset_infos.remove_h, valencies=dataset_infos.valencies) 8 | self.valency = ValencyFeature() 9 | self.weight = WeightFeature(max_weight=dataset_infos.max_weight, atom_weights=dataset_infos.atom_weights) 10 | 11 | def __call__(self, noisy_data): 12 | charge = self.charge(noisy_data).unsqueeze(-1) # (bs, n, 1) 13 | valency = self.valency(noisy_data).unsqueeze(-1) # (bs, n, 1) 14 | weight = self.weight(noisy_data) # (bs, 1) 15 | 16 | extra_edge_attr = torch.zeros((*noisy_data['E_t'].shape[:-1], 0)).type_as(noisy_data['E_t']) 17 | 18 | return utils.PlaceHolder(X=torch.cat((charge, valency), dim=-1), E=extra_edge_attr, y=weight) 19 | 20 | 21 | class ChargeFeature: 22 | def __init__(self, remove_h, valencies): 23 | self.remove_h = remove_h 24 | self.valencies = valencies 25 | 26 | def __call__(self, noisy_data): 27 | bond_orders = torch.tensor([0, 1, 2, 3, 1.5], device=noisy_data['E_t'].device).reshape(1, 1, 1, -1) 28 | weighted_E = noisy_data['E_t'] * bond_orders # (bs, n, n, de) 29 | current_valencies = weighted_E.argmax(dim=-1).sum(dim=-1) # (bs, n) 30 | 31 | valencies = torch.tensor(self.valencies, device=noisy_data['X_t'].device).reshape(1, 1, -1) 32 | X = noisy_data['X_t'] * valencies # (bs, n, dx) 33 | normal_valencies = torch.argmax(X, dim=-1) # (bs, n) 34 | 35 | return (normal_valencies - current_valencies).type_as(noisy_data['X_t']) 36 | 37 | 38 | class ValencyFeature: 39 | def __init__(self): 40 | pass 41 | 42 | def __call__(self, noisy_data): 43 | orders = torch.tensor([0, 1, 2, 3, 1.5], device=noisy_data['E_t'].device).reshape(1, 1, 1, -1) 44 | E = noisy_data['E_t'] * orders # (bs, n, n, de) 45 | valencies = E.argmax(dim=-1).sum(dim=-1) # (bs, n) 46 | return valencies.type_as(noisy_data['X_t']) 47 | 48 | 49 | class WeightFeature: 50 | def __init__(self, max_weight, atom_weights): 51 | self.max_weight = max_weight 52 | self.atom_weight_list = torch.tensor(list(atom_weights.values())) 53 | 54 | def __call__(self, noisy_data): 55 | X = torch.argmax(noisy_data['X_t'], dim=-1) # (bs, n) 56 | X_weights = self.atom_weight_list.to(X.device)[X] # (bs, n) 57 | return X_weights.sum(dim=-1).unsqueeze(-1).type_as(noisy_data['X_t']) / self.max_weight # (bs, 1) 58 | -------------------------------------------------------------------------------- /src/diffusion/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | class SinusoidalPosEmb(torch.nn.Module): 6 | def __init__(self, dim): 7 | super().__init__() 8 | self.dim = dim 9 | 10 | def forward(self, x): 11 | x = x.squeeze() * 1000 12 | assert len(x.shape) == 1 13 | half_dim = self.dim // 2 14 | emb = math.log(10000) / (half_dim - 1) 15 | emb = torch.exp(torch.arange(half_dim) * -emb) 16 | emb = emb.type_as(x) 17 | emb = x[:, None] * emb[None, :] 18 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 19 | return emb 20 | -------------------------------------------------------------------------------- /src/diffusion/noise_schedule.py: 
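Before the noise schedules in noise_schedule.py below, a brief usage sketch for the SinusoidalPosEmb module defined in layers.py above (the batch size and embedding dimension are illustrative):

import torch

emb_layer = SinusoidalPosEmb(dim=128)
t = torch.rand(16, 1)      # one normalized diffusion timestep in [0, 1] per graph
emb = emb_layer(t)         # (16, 128): concatenated sin/cos features of t * 1000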
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from src import utils 4 | from src.diffusion import diffusion_utils 5 | 6 | 7 | class PredefinedNoiseSchedule(torch.nn.Module): 8 | """ 9 | Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules. 10 | """ 11 | 12 | def __init__(self, noise_schedule, timesteps): 13 | super(PredefinedNoiseSchedule, self).__init__() 14 | self.timesteps = timesteps 15 | 16 | if noise_schedule == 'cosine': 17 | alphas2 = diffusion_utils.cosine_beta_schedule(timesteps) 18 | elif noise_schedule == 'custom': 19 | raise NotImplementedError() 20 | else: 21 | raise ValueError(noise_schedule) 22 | 23 | # print('alphas2', alphas2) 24 | 25 | sigmas2 = 1 - alphas2 26 | 27 | log_alphas2 = np.log(alphas2) 28 | log_sigmas2 = np.log(sigmas2) 29 | 30 | log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2 # (timesteps + 1, ) 31 | 32 | # print('gamma', -log_alphas2_to_sigmas2) 33 | 34 | self.gamma = torch.nn.Parameter( 35 | torch.from_numpy(-log_alphas2_to_sigmas2).float(), 36 | requires_grad=False) 37 | 38 | def forward(self, t): 39 | t_int = torch.round(t * self.timesteps).long() 40 | return self.gamma[t_int] 41 | 42 | 43 | 44 | class PredefinedNoiseScheduleDiscrete(torch.nn.Module): 45 | """ 46 | Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules. 47 | """ 48 | 49 | def __init__(self, noise_schedule, timesteps): 50 | super(PredefinedNoiseScheduleDiscrete, self).__init__() 51 | self.timesteps = timesteps 52 | 53 | if noise_schedule == 'cosine': 54 | betas = diffusion_utils.cosine_beta_schedule_discrete(timesteps) 55 | elif noise_schedule == 'custom': 56 | betas = diffusion_utils.custom_beta_schedule_discrete(timesteps) 57 | else: 58 | raise NotImplementedError(noise_schedule) 59 | 60 | self.register_buffer('betas', torch.from_numpy(betas).float()) 61 | 62 | self.alphas = 1 - torch.clamp(self.betas, min=0, max=0.9999) 63 | 64 | log_alpha = torch.log(self.alphas) 65 | log_alpha_bar = torch.cumsum(log_alpha, dim=0) 66 | self.alphas_bar = torch.exp(log_alpha_bar) 67 | # print(f"[Noise schedule: {noise_schedule}] alpha_bar:", self.alphas_bar) 68 | 69 | def forward(self, t_normalized=None, t_int=None): 70 | assert int(t_normalized is None) + int(t_int is None) == 1 71 | if t_int is None: 72 | t_int = torch.round(t_normalized * self.timesteps) 73 | return self.betas[t_int.long()] 74 | 75 | def get_alpha_bar(self, t_normalized=None, t_int=None): 76 | assert int(t_normalized is None) + int(t_int is None) == 1 77 | if t_int is None: 78 | t_int = torch.round(t_normalized * self.timesteps) 79 | return self.alphas_bar.to(t_int.device)[t_int.long()] 80 | 81 | 82 | class DiscreteUniformTransition: 83 | def __init__(self, x_classes: int, e_classes: int, y_classes: int): 84 | self.X_classes = x_classes 85 | self.E_classes = e_classes 86 | self.y_classes = y_classes 87 | self.u_x = torch.ones(1, self.X_classes, self.X_classes) 88 | if self.X_classes > 0: 89 | self.u_x = self.u_x / self.X_classes 90 | 91 | self.u_e = torch.ones(1, self.E_classes, self.E_classes) 92 | if self.E_classes > 0: 93 | self.u_e = self.u_e / self.E_classes 94 | 95 | self.u_y = torch.ones(1, self.y_classes, self.y_classes) 96 | if self.y_classes > 0: 97 | self.u_y = self.u_y / self.y_classes 98 | 99 | def get_Qt(self, beta_t, device): 100 | """ Returns one-step transition matrices for X and E, from step t - 1 to step t. 
101 | Qt = (1 - beta_t) * I + beta_t / K 102 | 103 | beta_t: (bs) noise level between 0 and 1 104 | returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). 105 | """ 106 | beta_t = beta_t.unsqueeze(1) 107 | beta_t = beta_t.to(device) 108 | self.u_x = self.u_x.to(device) 109 | self.u_e = self.u_e.to(device) 110 | self.u_y = self.u_y.to(device) 111 | 112 | q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye(self.X_classes, device=device).unsqueeze(0) 113 | q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye(self.E_classes, device=device).unsqueeze(0) 114 | q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye(self.y_classes, device=device).unsqueeze(0) 115 | 116 | return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) 117 | 118 | def get_Qt_bar(self, alpha_bar_t, device): 119 | """ Returns t-step transition matrices for X and E, from step 0 to step t. 120 | Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) / K 121 | 122 | alpha_bar_t: (bs) Product of the (1 - beta_t) for each time step from 0 to t. 123 | returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). 124 | """ 125 | alpha_bar_t = alpha_bar_t.unsqueeze(1) 126 | alpha_bar_t = alpha_bar_t.to(device) 127 | self.u_x = self.u_x.to(device) 128 | self.u_e = self.u_e.to(device) 129 | self.u_y = self.u_y.to(device) 130 | 131 | q_x = alpha_bar_t * torch.eye(self.X_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_x 132 | q_e = alpha_bar_t * torch.eye(self.E_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_e 133 | q_y = alpha_bar_t * torch.eye(self.y_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_y 134 | 135 | return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) 136 | 137 | 138 | class MarginalUniformTransition: 139 | def __init__(self, x_marginals, e_marginals, y_classes): 140 | self.X_classes = len(x_marginals) 141 | self.E_classes = len(e_marginals) 142 | self.y_classes = y_classes 143 | self.x_marginals = x_marginals 144 | self.e_marginals = e_marginals 145 | 146 | self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) 147 | self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) 148 | self.u_y = torch.ones(1, self.y_classes, self.y_classes) 149 | if self.y_classes > 0: 150 | self.u_y = self.u_y / self.y_classes 151 | 152 | def get_Qt(self, beta_t, device): 153 | """ Returns one-step transition matrices for X and E, from step t - 1 to step t. 154 | Qt = (1 - beta_t) * I + beta_t / K 155 | 156 | beta_t: (bs) noise level between 0 and 1 157 | returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). """ 158 | beta_t = beta_t.unsqueeze(1) 159 | beta_t = beta_t.to(device) 160 | self.u_x = self.u_x.to(device) 161 | self.u_e = self.u_e.to(device) 162 | self.u_y = self.u_y.to(device) 163 | 164 | q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye(self.X_classes, device=device).unsqueeze(0) 165 | q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye(self.E_classes, device=device).unsqueeze(0) 166 | q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye(self.y_classes, device=device).unsqueeze(0) 167 | 168 | return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) 169 | 170 | def get_Qt_bar(self, alpha_bar_t, device): 171 | """ Returns t-step transition matrices for X and E, from step 0 to step t. 172 | Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) * K 173 | 174 | alpha_bar_t: (bs) Product of the (1 - beta_t) for each time step from 0 to t. 175 | returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). 
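        A minimal usage sketch (the marginals and noise level are illustrative):

            transition = MarginalUniformTransition(x_marginals=torch.tensor([0.6, 0.4]),
                                                   e_marginals=torch.tensor([0.9, 0.1]),
                                                   y_classes=1)
            Qt_bar = transition.get_Qt_bar(torch.tensor([0.75]), device='cpu')
            # Qt_bar.E[0] = 0.75 * I + 0.25 * [[0.9, 0.1], [0.9, 0.1]]
            #             = [[0.975, 0.025], [0.225, 0.775]]   (rows sum to 1)
            # As alpha_bar_t -> 0, every row collapses to the edge marginal distribution.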
176 | """ 177 | alpha_bar_t = alpha_bar_t.unsqueeze(1) 178 | alpha_bar_t = alpha_bar_t.to(device) 179 | self.u_x = self.u_x.to(device) 180 | self.u_e = self.u_e.to(device) 181 | self.u_y = self.u_y.to(device) 182 | 183 | q_x = alpha_bar_t * torch.eye(self.X_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_x 184 | q_e = alpha_bar_t * torch.eye(self.E_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_e 185 | q_y = alpha_bar_t * torch.eye(self.y_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_y 186 | 187 | return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) 188 | 189 | 190 | class AbsorbingStateTransition: 191 | def __init__(self, abs_state: int, x_classes: int, e_classes: int, y_classes: int): 192 | self.X_classes = x_classes 193 | self.E_classes = e_classes 194 | self.y_classes = y_classes 195 | 196 | self.u_x = torch.zeros(1, self.X_classes, self.X_classes) 197 | self.u_x[:, :, abs_state] = 1 198 | 199 | self.u_e = torch.zeros(1, self.E_classes, self.E_classes) 200 | self.u_e[:, :, abs_state] = 1 201 | 202 | self.u_y = torch.zeros(1, self.y_classes, self.y_classes) 203 | self.u_e[:, :, abs_state] = 1 204 | 205 | def get_Qt(self, beta_t): 206 | """ Returns two transition matrix for X and E""" 207 | beta_t = beta_t.unsqueeze(1) 208 | q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye(self.X_classes).unsqueeze(0) 209 | q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye(self.E_classes).unsqueeze(0) 210 | q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye(self.y_classes).unsqueeze(0) 211 | return q_x, q_e, q_y 212 | 213 | def get_Qt_bar(self, alpha_bar_t): 214 | """ beta_t: (bs) 215 | Returns transition matrices for X and E""" 216 | 217 | alpha_bar_t = alpha_bar_t.unsqueeze(1) 218 | 219 | q_x = alpha_bar_t * torch.eye(self.X_classes).unsqueeze(0) + (1 - alpha_bar_t) * self.u_x 220 | q_e = alpha_bar_t * torch.eye(self.E_classes).unsqueeze(0) + (1 - alpha_bar_t) * self.u_e 221 | q_y = alpha_bar_t * torch.eye(self.y_classes).unsqueeze(0) + (1 - alpha_bar_t) * self.u_y 222 | 223 | return q_x, q_e, q_y 224 | -------------------------------------------------------------------------------- /src/fp2mol_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import warnings 5 | import logging 6 | 7 | import torch 8 | torch.cuda.empty_cache() 9 | try: 10 | torch.set_float32_matmul_precision('medium') 11 | logging.info("Enabled float32 matmul precision - medium") 12 | except: 13 | logging.info("Could not enable float32 matmul precision - medium") 14 | import hydra 15 | from omegaconf import DictConfig 16 | from pytorch_lightning import Trainer 17 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 18 | from pytorch_lightning.loggers import CSVLogger, WandbLogger 19 | from pytorch_lightning.utilities.warnings import PossibleUserWarning 20 | 21 | from src import utils 22 | from src.diffusion_model_fp2mol import FP2MolDenoisingDiffusion 23 | from src.diffusion.extra_features import DummyExtraFeatures, ExtraFeatures 24 | 25 | 26 | warnings.filterwarnings("ignore", category=PossibleUserWarning) 27 | 28 | 29 | def get_resume(cfg, model_kwargs): 30 | """ Resumes a run. It loads previous config without allowing to update keys (used for testing). 
""" 31 | saved_cfg = cfg.copy() 32 | name = cfg.general.name + '_resume' 33 | resume = cfg.general.test_only 34 | val_samples_to_generate = cfg.general.val_samples_to_generate 35 | test_samples_to_generate = cfg.general.test_samples_to_generate 36 | if cfg.model.type == 'discrete': 37 | model = FP2MolDenoisingDiffusion.load_from_checkpoint(resume, **model_kwargs) 38 | else: 39 | raise NotImplementedError("Only discrete diffusion models are supported for FP2Mol dataset currently") 40 | cfg = model.cfg 41 | cfg.general.test_only = resume 42 | cfg.general.name = name 43 | cfg.general.val_samples_to_generate = val_samples_to_generate 44 | cfg.general.test_samples_to_generate = test_samples_to_generate 45 | cfg = utils.update_config_with_new_keys(cfg, saved_cfg) 46 | return cfg, model 47 | 48 | 49 | def get_resume_adaptive(cfg, model_kwargs): 50 | """ Resumes a run. It loads previous config but allows to make some changes (used for resuming training).""" 51 | saved_cfg = cfg.copy() 52 | # Fetch path to this file to get base path 53 | current_path = os.path.dirname(os.path.realpath(__file__)) 54 | root_dir = current_path.split('outputs')[0] 55 | 56 | resume_path = os.path.join(root_dir, cfg.general.resume) 57 | 58 | model = FP2MolDenoisingDiffusion.load_from_checkpoint(resume_path, **model_kwargs) 59 | new_cfg = model.cfg 60 | 61 | for category in cfg: 62 | for arg in cfg[category]: 63 | new_cfg[category][arg] = cfg[category][arg] 64 | 65 | new_cfg.general.resume = resume_path 66 | new_cfg.general.name = new_cfg.general.name + '_resume' 67 | 68 | new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg) 69 | return new_cfg, model 70 | 71 | def load_decoder_from_lightning_ckpt(model, ckpt_path): 72 | """ Load a model from a PyTorch Lightning checkpoint. 
""" 73 | state_dict = torch.load(ckpt_path, map_location='cpu')["state_dict"] 74 | cleaned_state_dict = {} 75 | for k, v in state_dict.items(): 76 | if k.startswith('model.'): 77 | k = k[6:] 78 | cleaned_state_dict[k] = v 79 | 80 | model.model.load_state_dict(cleaned_state_dict, strict=True) 81 | logging.info(f"Loaded model from: '{ckpt_path}'") 82 | 83 | 84 | @hydra.main(version_base='1.3', config_path='../configs', config_name='config') 85 | def main(cfg: DictConfig): 86 | from rdkit import RDLogger 87 | RDLogger.DisableLog('rdApp.*') 88 | 89 | logger = logging.getLogger("msms_main") 90 | logger.setLevel(logging.INFO) 91 | 92 | formatter = logging.Formatter( 93 | "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", 94 | datefmt="%Y-%m-%d %H:%M:%S", 95 | ) 96 | 97 | ch = logging.StreamHandler(stream=sys.stdout) 98 | ch.setFormatter(formatter) 99 | logger.addHandler(ch) 100 | 101 | path = os.path.join("msms_main.log") 102 | fh = logging.FileHandler(path) 103 | fh.setFormatter(formatter) 104 | 105 | logger.addHandler(fh) 106 | 107 | logging.info(cfg) 108 | 109 | dataset_config = cfg["dataset"] 110 | 111 | if dataset_config["name"] != "fp2mol": 112 | raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"])) 113 | 114 | from metrics.molecular_metrics import TrainMolecularMetrics, SamplingMolecularMetrics 115 | from metrics.molecular_metrics_discrete import TrainMolecularMetricsDiscrete 116 | from diffusion.extra_features_molecular import ExtraMolecularFeatures 117 | from analysis.visualization import MolecularVisualization 118 | 119 | from datasets import fp2mol_dataset 120 | 121 | datamodule = fp2mol_dataset.FP2MolDataModule(cfg) 122 | logging.info("Dataset loaded") 123 | logging.info(f"Train Size: {len(datamodule.train_dataloader())}, Val Size: {len(datamodule.val_dataloader())}, Test Size: {len(datamodule.test_dataloader())}") 124 | dataset_infos = fp2mol_dataset.FP2Mol_infos(datamodule, cfg, recompute_statistics=False) 125 | 126 | domain_features = ExtraMolecularFeatures(dataset_infos=dataset_infos) 127 | if cfg.model.extra_features is not None: 128 | extra_features = ExtraFeatures(cfg.model.extra_features, dataset_info=dataset_infos) 129 | else: 130 | extra_features = DummyExtraFeatures() 131 | 132 | dataset_infos.compute_input_output_dims(datamodule=datamodule, extra_features=extra_features, domain_features=domain_features) 133 | 134 | logging.info("Dataset infos:", dataset_infos.output_dims) 135 | train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) 136 | 137 | visualization_tools = MolecularVisualization(cfg.dataset.remove_h, dataset_infos=dataset_infos) 138 | 139 | model_kwargs = {'dataset_infos': dataset_infos, 'train_metrics': train_metrics, 140 | 'visualization_tools': visualization_tools, 'extra_features': extra_features, 'domain_features': domain_features} 141 | 142 | if cfg.general.test_only: 143 | # When testing, previous configuration is fully loaded 144 | cfg, _ = get_resume(cfg, model_kwargs) 145 | os.chdir(cfg.general.test_only.split('checkpoints')[0]) 146 | elif cfg.general.resume is not None: 147 | # When resuming, we can override some parts of previous configuration 148 | cfg, _ = get_resume_adaptive(cfg, model_kwargs) 149 | try: 150 | os.chdir(cfg.general.resume.split('checkpoints')[0]) 151 | except: 152 | logging.info("Could not change directory to resume path. 
Using current directory.") 153 | 154 | os.makedirs('preds/', exist_ok=True) 155 | os.makedirs('models/', exist_ok=True) 156 | os.makedirs('logs/', exist_ok=True) 157 | os.makedirs('logs/' + cfg.general.name, exist_ok=True) 158 | 159 | model = FP2MolDenoisingDiffusion(cfg=cfg, **model_kwargs) 160 | 161 | callbacks = [] 162 | callbacks.append(LearningRateMonitor(logging_interval='step')) 163 | if cfg.train.save_model: 164 | checkpoint_callback = ModelCheckpoint(dirpath=f"checkpoints/{cfg.general.name}", # best (top-5) checkpoints 165 | filename='{epoch}', 166 | monitor='val/NLL', 167 | save_top_k=5, 168 | mode='min', 169 | every_n_epochs=1) 170 | last_ckpt_save = ModelCheckpoint(dirpath=f"checkpoints/{cfg.general.name}", filename='last', every_n_epochs=1) # most recent checkpoint 171 | callbacks.append(last_ckpt_save) 172 | callbacks.append(checkpoint_callback) 173 | 174 | if cfg.train.ema_decay > 0: # TODO: Implement EMA for FP2Mol 175 | ema_callback = utils.EMA(decay=cfg.train.ema_decay) 176 | callbacks.append(ema_callback) 177 | 178 | name = cfg.general.name 179 | if name == 'debug': 180 | logging.warning("Run is called 'debug' -- it will run with fast_dev_run. ") 181 | 182 | loggers = [ 183 | CSVLogger(save_dir=f"logs/{name}", name=name), 184 | WandbLogger(name=name, save_dir=f"logs/{name}", project=cfg.general.wandb_name, log_model=False, config=utils.cfg_to_dict(cfg)) 185 | ] 186 | 187 | use_gpu = cfg.general.gpus > 0 and torch.cuda.is_available() 188 | trainer = Trainer(gradient_clip_val=cfg.train.clip_grad, 189 | strategy="ddp", 190 | accelerator='gpu' if use_gpu else 'cpu', 191 | devices=cfg.general.gpus if use_gpu else 1, 192 | max_epochs=cfg.train.n_epochs, 193 | check_val_every_n_epoch=cfg.general.check_val_every_n_epochs, 194 | fast_dev_run=cfg.general.name == 'debug', 195 | callbacks=callbacks, 196 | log_every_n_steps=50 if name != 'debug' else 1, 197 | logger=loggers) 198 | 199 | if not cfg.general.test_only: 200 | trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) 201 | if cfg.general.name not in ['debug', 'test']: 202 | trainer.test(model, datamodule=datamodule) 203 | else: 204 | # Start by evaluating test_only_path 205 | trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) 206 | if cfg.general.evaluate_all_checkpoints: 207 | directory = pathlib.Path(cfg.general.test_only).parents[0] 208 | logging.info("Directory:", directory) 209 | files_list = os.listdir(directory) 210 | for file in files_list: 211 | if '.ckpt' in file: 212 | ckpt_path = os.path.join(directory, file) 213 | if ckpt_path == cfg.general.test_only: 214 | continue 215 | logging.info("Loading checkpoint", ckpt_path) 216 | trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path) 217 | 218 | 219 | if __name__ == '__main__': 220 | main() 221 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/DiffMS/9d9f4fd497162eec045f7db2787da30ce69a9622/src/metrics/__init__.py -------------------------------------------------------------------------------- /src/metrics/abstract_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import functional as F 4 | from torchmetrics import Metric, MeanSquaredError 5 | 6 | 7 | class TrainAbstractMetricsDiscrete(torch.nn.Module): 8 | def __init__(self): 9 | 
super().__init__() 10 | 11 | def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): 12 | pass 13 | 14 | def reset(self): 15 | pass 16 | 17 | def log_epoch_metrics(self): 18 | return None, None 19 | 20 | 21 | class TrainAbstractMetrics(torch.nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log): 26 | pass 27 | 28 | def reset(self): 29 | pass 30 | 31 | def log_epoch_metrics(self): 32 | return None, None 33 | 34 | 35 | class SumExceptBatchMetric(Metric): 36 | def __init__(self): 37 | super().__init__() 38 | self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum") 39 | self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") 40 | 41 | def update(self, values) -> None: 42 | self.total_value += torch.sum(values) 43 | self.total_samples += values.shape[0] 44 | 45 | def compute(self): 46 | return self.total_value / self.total_samples 47 | 48 | 49 | class SumExceptBatchMSE(MeanSquaredError): 50 | def update(self, preds: Tensor, target: Tensor) -> None: 51 | """Update state with predictions and targets. 52 | 53 | Args: 54 | preds: Predictions from model 55 | target: Ground truth values 56 | """ 57 | assert preds.shape == target.shape 58 | sum_squared_error, n_obs = self._mean_squared_error_update(preds, target) 59 | 60 | self.sum_squared_error += sum_squared_error 61 | self.total += n_obs 62 | 63 | def _mean_squared_error_update(self, preds: Tensor, target: Tensor): 64 | """ Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input 65 | tensors. 66 | preds: Predicted tensor 67 | target: Ground truth tensor 68 | """ 69 | diff = preds - target 70 | sum_squared_error = torch.sum(diff * diff) 71 | n_obs = preds.shape[0] 72 | return sum_squared_error, n_obs 73 | 74 | 75 | class SumExceptBatchKL(Metric): 76 | def __init__(self): 77 | super().__init__() 78 | self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum") 79 | self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") 80 | 81 | def update(self, p, q) -> None: 82 | self.total_value += F.kl_div(q, p, reduction='sum') 83 | self.total_samples += p.size(0) 84 | 85 | def compute(self): 86 | return self.total_value / self.total_samples 87 | 88 | 89 | class CrossEntropyMetric(Metric): 90 | def __init__(self): 91 | super().__init__() 92 | self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") 93 | self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") 94 | 95 | def update(self, preds: Tensor, target: Tensor) -> None: 96 | """ Update state with predictions and targets. 97 | preds: Predictions from model (bs * n, d) or (bs * n * n, d) 98 | target: Ground truth values (bs * n, d) or (bs * n * n, d). """ 99 | target = torch.argmax(target, dim=-1) 100 | output = F.cross_entropy(preds, target, reduction='sum') 101 | self.total_ce += output 102 | self.total_samples += preds.size(0) 103 | 104 | def compute(self): 105 | return self.total_ce / self.total_samples 106 | 107 | 108 | class ProbabilityMetric(Metric): 109 | def __init__(self): 110 | """ This metric is used to track the marginal predicted probability of a class during training. 
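        A small sketch (illustrative values): after

            metric = ProbabilityMetric()
            metric.update(torch.tensor([0.2, 0.8, 0.5]))

        metric.compute() returns 0.5, the mean predicted probability over all entries seen so far.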
""" 111 | super().__init__() 112 | self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum") 113 | self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum") 114 | 115 | def update(self, preds: Tensor) -> None: 116 | self.prob += preds.sum() 117 | self.total += preds.numel() 118 | 119 | def compute(self): 120 | return self.prob / self.total 121 | 122 | 123 | class NLL(Metric): 124 | def __init__(self): 125 | super().__init__() 126 | self.add_state('total_nll', default=torch.tensor(0.), dist_reduce_fx="sum") 127 | self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") 128 | 129 | def update(self, batch_nll) -> None: 130 | self.total_nll += torch.sum(batch_nll) 131 | self.total_samples += batch_nll.numel() 132 | 133 | def compute(self): 134 | return self.total_nll / self.total_samples -------------------------------------------------------------------------------- /src/metrics/diffms_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchmetrics import Metric 4 | from collections import Counter 5 | from typing import List 6 | from rdkit import Chem 7 | from rdkit.Chem import AllChem 8 | from rdkit.Chem import DataStructs 9 | 10 | from src.utils import is_valid, canonical_mol_from_inchi 11 | 12 | 13 | class K_ACC(Metric): 14 | def __init__(self, k: int, dist_sync_on_step: bool = False): 15 | super().__init__(dist_sync_on_step=dist_sync_on_step) 16 | self.k = k 17 | self.add_state("correct", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 18 | self.add_state("total", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 19 | 20 | def update(self, generated_inchis: List[str], true_inchi: str): 21 | if true_inchi in generated_inchis[: self.k]: 22 | self.correct += 1 23 | self.total += 1 24 | 25 | def compute(self) -> torch.Tensor: 26 | """Compute final top-k accuracy.""" 27 | if self.total == 0: 28 | return torch.tensor(0.0, device=self.correct.device) 29 | return self.correct.float() / self.total.float() 30 | 31 | 32 | class K_ACC_Collection(Metric): 33 | """ 34 | A collection of K_ACC metrics for multiple values of k. 
35 | """ 36 | def __init__(self, k_list: List[int], dist_sync_on_step: bool = False): 37 | super().__init__(dist_sync_on_step=dist_sync_on_step) 38 | self.metrics = nn.ModuleDict() 39 | for k in k_list: 40 | self.metrics[f"acc_at_{k}"] = K_ACC(k, dist_sync_on_step=dist_sync_on_step) 41 | 42 | def update(self, generated_mols: List[Chem.Mol], true_mol: Chem.Mol): 43 | # Filter out invalid molecules, and select unique InChIs by frequency 44 | inchis = [Chem.MolToInchi(mol) for mol in generated_mols if is_valid(mol)] 45 | inchi_counter = Counter(inchis) 46 | # Sort by frequency, keep unique 47 | inchis = [item for item, _count in inchi_counter.most_common()] 48 | true_inchi = Chem.MolToInchi(true_mol) 49 | 50 | # Update each K_ACC submetric 51 | for metric in self.metrics.values(): 52 | metric.update(inchis, true_inchi) 53 | 54 | def compute(self): 55 | return {name: m.compute() for name, m in self.metrics.items()} 56 | 57 | class K_TanimotoSimilarity(Metric): 58 | def __init__(self, k: int, dist_sync_on_step: bool = False): 59 | super().__init__(dist_sync_on_step=dist_sync_on_step) 60 | self.k = k 61 | self.add_state("similarity_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 62 | self.add_state("total", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 63 | 64 | def update(self, generated_mols: List[Chem.Mol], true_mol: Chem.Mol): 65 | true_fp = AllChem.GetMorganFingerprintAsBitVect(true_mol, 2, nBits=2048) 66 | max_sim = 0.0 67 | for mol in generated_mols[: self.k]: 68 | try: 69 | gen_fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) 70 | sim = DataStructs.TanimotoSimilarity(gen_fp, true_fp) 71 | max_sim = max(max_sim, sim) 72 | except Exception: 73 | pass 74 | self.similarity_sum += max_sim 75 | self.total += 1 76 | 77 | def compute(self) -> torch.Tensor: 78 | """Compute the average max Tanimoto similarity.""" 79 | if self.total == 0: 80 | return torch.tensor(0.0, device=self.similarity_sum.device) 81 | return self.similarity_sum / self.total.float() 82 | 83 | 84 | class K_CosineSimilarity(Metric): 85 | def __init__(self, k: int, dist_sync_on_step: bool = False): 86 | super().__init__(dist_sync_on_step=dist_sync_on_step) 87 | self.k = k 88 | self.add_state("similarity_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 89 | self.add_state("total", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 90 | 91 | def update(self, generated_mols: List[Chem.Mol], true_mol: Chem.Mol): 92 | true_fp = AllChem.GetMorganFingerprintAsBitVect(true_mol, 2, nBits=2048) 93 | max_sim = 0.0 94 | for mol in generated_mols[: self.k]: 95 | try: 96 | gen_fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) 97 | sim = DataStructs.CosineSimilarity(gen_fp, true_fp) 98 | max_sim = max(max_sim, sim) 99 | except Exception: 100 | pass 101 | self.similarity_sum += max_sim 102 | self.total += 1 103 | 104 | def compute(self) -> torch.Tensor: 105 | if self.total == 0: 106 | return torch.tensor(0.0, device=self.similarity_sum.device) 107 | return self.similarity_sum / self.total.float() 108 | 109 | 110 | class K_SimilarityCollection(Metric): 111 | def __init__(self, k_list: List[int], dist_sync_on_step: bool = False): 112 | super().__init__(dist_sync_on_step=dist_sync_on_step) 113 | self.metrics = nn.ModuleDict() 114 | for k in k_list: 115 | self.metrics[f"tanimoto_at_{k}"] = K_TanimotoSimilarity(k, dist_sync_on_step=dist_sync_on_step) 116 | self.metrics[f"cosine_at_{k}"] = K_CosineSimilarity(k, dist_sync_on_step=dist_sync_on_step) 117 | 118 | def update(self, 
generated_mols: List[Chem.Mol], true_mol: Chem.Mol): 119 | inchis = [Chem.MolToInchi(mol) for mol in generated_mols if is_valid(mol)] 120 | inchi_counter = Counter(inchis) 121 | inchis = [item for item, _count in inchi_counter.most_common()] 122 | 123 | processed_mols = [canonical_mol_from_inchi(inchi) for inchi in inchis] 124 | 125 | for metric in self.metrics.values(): 126 | metric.update(processed_mols, true_mol) 127 | 128 | def compute(self): 129 | return {name: m.compute() for name, m in self.metrics.items()} 130 | 131 | 132 | class Validity(Metric): 133 | def __init__(self, dist_sync_on_step: bool = False): 134 | super().__init__(dist_sync_on_step=dist_sync_on_step) 135 | self.add_state("valid", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 136 | self.add_state("total", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") 137 | 138 | def update(self, generated_mols: List[Chem.Mol]): 139 | for mol in generated_mols: 140 | if is_valid(mol): 141 | self.valid += 1 142 | self.total += 1 143 | 144 | def compute(self) -> torch.Tensor: 145 | if self.total == 0: 146 | return torch.tensor(0.0, device=self.valid.device) 147 | return self.valid.float() / self.total.float() 148 | -------------------------------------------------------------------------------- /src/metrics/molecular_metrics_discrete.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric, MetricCollection 3 | from torch import Tensor 4 | import wandb 5 | import torch.nn as nn 6 | 7 | class CEPerClass(Metric): 8 | full_state_update = False 9 | def __init__(self, class_id): 10 | super().__init__() 11 | self.class_id = class_id 12 | self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") 13 | self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") 14 | self.softmax = torch.nn.Softmax(dim=-1) 15 | self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum') 16 | 17 | def update(self, preds: Tensor, target: Tensor) -> None: 18 | """Update state with predictions and targets. 
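A standalone sketch of the fingerprint comparison used by `K_TanimotoSimilarity` and `K_CosineSimilarity` above (Morgan radius 2, 2048 bits); the SMILES strings are arbitrary examples, not DiffMS outputs:

```python
# Standalone RDKit sketch mirroring the similarity computation above; not repo code.
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

true_mol = Chem.MolFromSmiles("CCO")  # ethanol, an arbitrary "true" structure
gen_mol = Chem.MolFromSmiles("CCN")   # ethylamine, an arbitrary generated structure

true_fp = AllChem.GetMorganFingerprintAsBitVect(true_mol, 2, nBits=2048)
gen_fp = AllChem.GetMorganFingerprintAsBitVect(gen_mol, 2, nBits=2048)

print(DataStructs.TanimotoSimilarity(gen_fp, true_fp))  # bit-overlap score in [0, 1]
print(DataStructs.CosineSimilarity(gen_fp, true_fp))    # cosine variant of the same comparison
```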
19 | Args: 20 | preds: Predictions from model (bs, n, d) or (bs, n, n, d) 21 | target: Ground truth values (bs, n, d) or (bs, n, n, d) 22 | """ 23 | target = target.reshape(-1, target.shape[-1]) 24 | mask = (target != 0.).any(dim=-1) 25 | 26 | prob = self.softmax(preds)[..., self.class_id] 27 | prob = prob.flatten()[mask] 28 | 29 | target = target[:, self.class_id] 30 | target = target[mask] 31 | 32 | output = self.binary_cross_entropy(prob, target) 33 | self.total_ce += output 34 | self.total_samples += prob.numel() 35 | 36 | def compute(self): 37 | return self.total_ce / self.total_samples 38 | 39 | 40 | class HydrogenCE(CEPerClass): 41 | def __init__(self, i): 42 | super().__init__(i) 43 | 44 | 45 | class CarbonCE(CEPerClass): 46 | def __init__(self, i): 47 | super().__init__(i) 48 | 49 | 50 | class NitroCE(CEPerClass): 51 | def __init__(self, i): 52 | super().__init__(i) 53 | 54 | 55 | class OxyCE(CEPerClass): 56 | def __init__(self, i): 57 | super().__init__(i) 58 | 59 | 60 | class FluorCE(CEPerClass): 61 | def __init__(self, i): 62 | super().__init__(i) 63 | 64 | 65 | class BoronCE(CEPerClass): 66 | def __init__(self, i): 67 | super().__init__(i) 68 | 69 | 70 | class BrCE(CEPerClass): 71 | def __init__(self, i): 72 | super().__init__(i) 73 | 74 | 75 | class ClCE(CEPerClass): 76 | def __init__(self, i): 77 | super().__init__(i) 78 | 79 | 80 | class IodineCE(CEPerClass): 81 | def __init__(self, i): 82 | super().__init__(i) 83 | 84 | 85 | class PhosphorusCE(CEPerClass): 86 | def __init__(self, i): 87 | super().__init__(i) 88 | 89 | 90 | class SulfurCE(CEPerClass): 91 | def __init__(self, i): 92 | super().__init__(i) 93 | 94 | 95 | class SeCE(CEPerClass): 96 | def __init__(self, i): 97 | super().__init__(i) 98 | 99 | 100 | class SiCE(CEPerClass): 101 | def __init__(self, i): 102 | super().__init__(i) 103 | 104 | 105 | class NoBondCE(CEPerClass): 106 | def __init__(self, i): 107 | super().__init__(i) 108 | 109 | 110 | class SingleCE(CEPerClass): 111 | def __init__(self, i): 112 | super().__init__(i) 113 | 114 | 115 | class DoubleCE(CEPerClass): 116 | def __init__(self, i): 117 | super().__init__(i) 118 | 119 | 120 | class TripleCE(CEPerClass): 121 | def __init__(self, i): 122 | super().__init__(i) 123 | 124 | 125 | class AromaticCE(CEPerClass): 126 | def __init__(self, i): 127 | super().__init__(i) 128 | 129 | 130 | class AtomMetricsCE(MetricCollection): 131 | def __init__(self, dataset_infos): 132 | atom_decoder = dataset_infos.atom_decoder 133 | 134 | class_dict = {'H': HydrogenCE, 'C': CarbonCE, 'N': NitroCE, 'O': OxyCE, 'F': FluorCE, 'B': BoronCE, 135 | 'Br': BrCE, 'Cl': ClCE, 'I': IodineCE, 'P': PhosphorusCE, 'S': SulfurCE, 'Se': SeCE, 136 | 'Si': SiCE} 137 | 138 | metrics_list = [] 139 | for i, atom_type in enumerate(atom_decoder): 140 | try: 141 | metrics_list.append(class_dict[atom_type](i)) 142 | except KeyError: 143 | pass 144 | super().__init__(metrics_list) 145 | 146 | 147 | class BondMetricsCE(MetricCollection): 148 | def __init__(self): 149 | ce_no_bond = NoBondCE(0) 150 | ce_SI = SingleCE(1) 151 | ce_DO = DoubleCE(2) 152 | ce_TR = TripleCE(3) 153 | ce_AR = AromaticCE(4) 154 | super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR, ce_AR]) 155 | 156 | 157 | class TrainMolecularMetricsDiscrete(nn.Module): 158 | def __init__(self, dataset_infos): 159 | super().__init__() 160 | self.train_atom_metrics = AtomMetricsCE(dataset_infos=dataset_infos) 161 | self.train_bond_metrics = BondMetricsCE() 162 | 163 | def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): 164 
| self.train_atom_metrics(masked_pred_X, true_X) 165 | self.train_bond_metrics(masked_pred_E, true_E) 166 | if log: 167 | to_log = {} 168 | for key, val in self.train_atom_metrics.compute().items(): 169 | to_log['train/' + key] = val.item() 170 | for key, val in self.train_bond_metrics.compute().items(): 171 | to_log['train/' + key] = val.item() 172 | if wandb.run: 173 | wandb.log(to_log, commit=False) 174 | 175 | self.reset() 176 | 177 | def reset(self): 178 | for metric in [self.train_atom_metrics, self.train_bond_metrics]: 179 | metric.reset() 180 | 181 | def log_epoch_metrics(self): 182 | epoch_atom_metrics = self.train_atom_metrics.compute() 183 | epoch_bond_metrics = self.train_bond_metrics.compute() 184 | 185 | to_log = {} 186 | for key, val in epoch_atom_metrics.items(): 187 | to_log['train_epoch/' + key] = val.item() 188 | for key, val in epoch_bond_metrics.items(): 189 | to_log['train_epoch/' + key] = val.item() 190 | if wandb.run: 191 | wandb.log(to_log, commit=False) 192 | 193 | for key, val in epoch_atom_metrics.items(): 194 | epoch_atom_metrics[key] = val.item() 195 | for key, val in epoch_bond_metrics.items(): 196 | epoch_bond_metrics[key] = val.item() 197 | 198 | return epoch_atom_metrics, epoch_bond_metrics 199 | 200 | -------------------------------------------------------------------------------- /src/metrics/train_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchmetrics import MeanSquaredError 4 | import wandb 5 | from src.metrics.abstract_metrics import CrossEntropyMetric 6 | 7 | class TrainLossDiscrete(nn.Module): 8 | """ Train with Cross entropy""" 9 | def __init__(self, lambda_train): 10 | super().__init__() 11 | self.node_loss = CrossEntropyMetric() 12 | self.edge_loss = CrossEntropyMetric() 13 | self.y_loss = CrossEntropyMetric() 14 | self.lambda_train = lambda_train 15 | 16 | def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, log: bool): 17 | """ Compute train metrics 18 | masked_pred_X : tensor -- (bs, n, dx) 19 | masked_pred_E : tensor -- (bs, n, n, de) 20 | pred_y : tensor -- (bs, ) 21 | true_X : tensor -- (bs, n, dx) 22 | true_E : tensor -- (bs, n, n, de) 23 | true_y : tensor -- (bs, ) 24 | log : boolean. 
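A small sketch of the masking and per-class BCE tracked by `CEPerClass` / `AtomMetricsCE` above. The class index and tensor values are hypothetical; in the real collection the index comes from the dataset's `atom_decoder`:

```python
# Hypothetical shapes/values, illustrating the padding mask and per-class BCE in CEPerClass.
import torch
from src.metrics.molecular_metrics_discrete import CarbonCE

metric = CarbonCE(1)  # track atom class index 1; the real index comes from atom_decoder

# (bs=1, n=3, d=3) one-hot atom targets; the all-zero third row is padding and is masked out.
true_X = torch.tensor([[[0., 1., 0.],
                        [1., 0., 0.],
                        [0., 0., 0.]]])
pred_X = torch.randn(1, 3, 3)  # raw logits from the denoiser

metric.update(pred_X, true_X)
print(metric.compute())  # mean BCE between softmax(pred)[..., 1] and the class-1 indicator
```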
""" 25 | true_X = torch.reshape(true_X, (-1, true_X.size(-1))) # (bs * n, dx) 26 | true_E = torch.reshape(true_E, (-1, true_E.size(-1))) # (bs * n * n, de) 27 | masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1))) # (bs * n, dx) 28 | masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1))) # (bs * n * n, de) 29 | 30 | # Remove masked rows 31 | mask_X = (true_X != 0.).any(dim=-1) 32 | mask_E = (true_E != 0.).any(dim=-1) 33 | 34 | flat_true_X = true_X[mask_X, :] 35 | flat_pred_X = masked_pred_X[mask_X, :] 36 | 37 | flat_true_E = true_E[mask_E, :] 38 | flat_pred_E = masked_pred_E[mask_E, :] 39 | 40 | loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0 41 | loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0 42 | loss_y = self.y_loss(pred_y, true_y) if true_y.numel() > 0 else 0.0 43 | 44 | if log: 45 | to_log = {"train_loss/batch_CE": (loss_X + loss_E + loss_y).detach(), 46 | "train_loss/X_CE": self.node_loss.compute() if true_X.numel() > 0 else -1, 47 | "train_loss/E_CE": self.edge_loss.compute() if true_E.numel() > 0 else -1, 48 | "train_loss/y_CE": self.y_loss.compute() if true_y.numel() > 0 else -1} 49 | if wandb.run: 50 | wandb.log(to_log, commit=True) 51 | 52 | self.reset() 53 | 54 | return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + self.lambda_train[2] * loss_y 55 | 56 | def reset(self): 57 | for metric in [self.node_loss, self.edge_loss, self.y_loss]: 58 | metric.reset() 59 | 60 | def log_epoch_metrics(self): 61 | epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1 62 | epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1 63 | epoch_y_loss = self.y_loss.compute() if self.y_loss.total_samples > 0 else -1 64 | 65 | to_log = {"train_epoch/x_CE": epoch_node_loss, 66 | "train_epoch/E_CE": epoch_edge_loss, 67 | "train_epoch/y_CE": epoch_y_loss} 68 | if wandb.run: 69 | wandb.log(to_log, commit=False) 70 | 71 | return to_log 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /src/mist/data/data.py: -------------------------------------------------------------------------------- 1 | """ data.py """ 2 | import logging 3 | from typing import Optional 4 | import re 5 | 6 | from rdkit import Chem 7 | from rdkit.Chem import Descriptors 8 | 9 | from mist import utils 10 | 11 | 12 | class Spectra(object): 13 | def __init__( 14 | self, 15 | spectra_name: str = "", 16 | spectra_file: str = "", 17 | spectra_formula: str = "", 18 | instrument: str = "", 19 | **kwargs, 20 | ): 21 | """_summary_ 22 | 23 | Args: 24 | spectra_name (str, optional): _description_. Defaults to "". 25 | spectra_file (str, optional): _description_. Defaults to "". 26 | spectra_formula (str, optional): _description_. Defaults to "". 27 | instrument (str, optional): _description_. Defaults to "". 
28 | """ 29 | self.spectra_name = spectra_name 30 | self.spectra_file = spectra_file 31 | self.formula = spectra_formula 32 | self.instrument = instrument 33 | 34 | ## 35 | self._is_loaded = False 36 | self.parentmass = None 37 | self.num_spectra = None 38 | self.meta = None 39 | self.spectrum_names = None 40 | self.spectra = None 41 | 42 | def get_instrument(self): 43 | return self.instrument 44 | 45 | def _load_spectra(self): 46 | """Load the spectra from files""" 47 | meta, spectrum_tuples = utils.parse_spectra(self.spectra_file) 48 | 49 | self.meta = meta 50 | self.parentmass = None 51 | for parent_kw in ["parentmass", "PEPMASS"]: 52 | self.parentmass = self.meta.get(parent_kw, None) 53 | if self.parentmass is not None: 54 | break 55 | 56 | if self.parentmass is None: 57 | logging.info(f"Unable to find precursor mass for {self.spectrum_name}") 58 | self.parentmass = 0 59 | else: 60 | self.parentmass = float(self.parentmass) 61 | 62 | # Store all the spectrum names (e.g., e.v.) and spectra arrays 63 | self.spectrum_names, self.spectra = zip(*spectrum_tuples) 64 | self.num_spectra = len(self.spectra) 65 | self._is_loaded = True 66 | 67 | def get_spec_name(self, **kwargs): 68 | """get_spec_name.""" 69 | return self.spectra_name 70 | 71 | def get_spec(self, **kwargs): 72 | """get_spec.""" 73 | if not self._is_loaded: 74 | self._load_spectra() 75 | 76 | return self.spectra 77 | 78 | def get_meta(self, **kwargs): 79 | """get_meta.""" 80 | if not self._is_loaded: 81 | self._load_spectra() 82 | return self.meta 83 | 84 | def get_spectra_formula(self): 85 | """Get spectra formula.""" 86 | return self.formula 87 | 88 | 89 | class Mol(object): 90 | """ 91 | Object to store a compound, including possibly multiple mass spectrometry 92 | spectra. 93 | """ 94 | 95 | def __init__( 96 | self, 97 | mol: Chem.Mol, 98 | smiles: Optional[str] = None, 99 | inchikey: Optional[str] = None, 100 | mol_formula: Optional[str] = None, 101 | ): 102 | """_summary_ 103 | 104 | Args: 105 | mol (Chem.Mol): _description_ 106 | smiles (Optional[str], optional): _description_. Defaults to None. 107 | inchikey (Optional[str], optional): _description_. Defaults to None. 108 | mol_formula (Optional[str], optional): _description_. Defaults to None. 
109 | """ 110 | self.mol = mol 111 | 112 | self.smiles = smiles 113 | if self.smiles is None: 114 | # Isomeric smiles handled in preprocessing 115 | self.smiles = Chem.MolToSmiles(mol) 116 | 117 | self.inchikey = inchikey 118 | if self.inchikey is None and self.smiles != "": 119 | self.inchikey = Chem.MolToInchiKey(mol) 120 | 121 | self.mol_formula = mol_formula 122 | if self.mol_formula is None: 123 | self.mol_formula = utils.uncharged_formula(self.mol, mol_type="mol") 124 | self.num_hs = None 125 | 126 | @classmethod 127 | def MolFromInchi(cls, inchi: str, **kwargs): 128 | """_summary_ 129 | 130 | Args: 131 | inchi (str): _description_ 132 | 133 | Returns: 134 | _type_: _description_ 135 | """ 136 | mol = Chem.MolFromInchi(inchi) 137 | 138 | # Catch exception 139 | if mol is None: 140 | return None 141 | 142 | return cls(mol=mol, smiles=None, **kwargs) 143 | 144 | @classmethod 145 | def MolFromSmiles(cls, smiles: str, **kwargs): 146 | """_summary_ 147 | 148 | Args: 149 | smiles (str): _description_ 150 | 151 | Returns: 152 | _type_: _description_ 153 | """ 154 | if not smiles or isinstance(smiles, float): 155 | smiles = "" 156 | 157 | mol = Chem.MolFromSmiles(smiles) 158 | # Catch exception 159 | if mol is None: 160 | return None 161 | 162 | return cls(mol=mol, smiles=smiles, **kwargs) 163 | 164 | @classmethod 165 | def MolFromFormula(cls, formula: str, **kwargs): 166 | """ 167 | Create a Mol object from a chemical formula. 168 | This creates a molecule with atoms but no bonds. 169 | 170 | Args: 171 | formula (str): Chemical formula (e.g., "C6H12O6") 172 | inchikey (str, optional): InChIKey for the molecule. Defaults to None. 173 | 174 | Returns: 175 | Mol: Molecule object with atoms but no bonds 176 | """ 177 | # Regular expression to extract element symbols and counts from the formula 178 | pattern = r'([A-Z][a-z]*)(\d*)' 179 | matches = re.findall(pattern, formula) 180 | 181 | # Create an empty molecule 182 | mol = Chem.RWMol() 183 | 184 | # Add atoms to the molecule 185 | for element, count in matches: 186 | # If no count is specified, default to 1 187 | count = int(count) if count else 1 188 | 189 | # Get atomic number for the element 190 | atomic_num = Chem.GetPeriodicTable().GetAtomicNumber(element) 191 | 192 | # Add the atoms to the molecule 193 | for _ in range(count): 194 | atom = Chem.Atom(atomic_num) 195 | mol.AddAtom(atom) 196 | 197 | return cls(mol=mol, mol_formula=formula, **kwargs) 198 | 199 | def get_smiles(self) -> str: 200 | """_summary_ 201 | 202 | Returns: 203 | str: _description_ 204 | """ 205 | return self.smiles 206 | 207 | def get_inchikey(self) -> str: 208 | """_summary_ 209 | 210 | Returns: 211 | str: _description_ 212 | """ 213 | return self.inchikey 214 | 215 | def get_molform(self) -> str: 216 | """_summary_ 217 | 218 | Returns: 219 | str: _description_ 220 | """ 221 | return self.mol_formula 222 | 223 | def get_num_hs(self): 224 | """_summary_ 225 | 226 | Raises: 227 | ValueError: _description_ 228 | 229 | Returns: 230 | _type_: _description_ 231 | """ 232 | """get_num_hs.""" 233 | if self.num_hs is None: 234 | num = re.findall("H([0-9]*)", self.mol_formula) 235 | if num is None: 236 | out_num_hs = 0 237 | else: 238 | if len(num) == 0: 239 | out_num_hs = 0 240 | elif len(num) == 1: 241 | num = num[0] 242 | out_num_hs = 1 if num == "" else int(num) 243 | else: 244 | raise ValueError() 245 | self.num_hs = out_num_hs 246 | else: 247 | out_num_hs = self.num_hs 248 | 249 | return out_num_hs 250 | 251 | def get_mol_mass(self): 252 | """_summary_ 253 | 254 | 
Returns: 255 | _type_: _description_ 256 | """ 257 | return Descriptors.MolWt(self.mol) 258 | 259 | def get_rdkit_mol(self) -> Chem.Mol: 260 | """_summary_ 261 | 262 | Returns: 263 | Chem.Mol: _description_ 264 | """ 265 | return self.mol -------------------------------------------------------------------------------- /src/mist/data/splitter.py: -------------------------------------------------------------------------------- 1 | """splitter.py""" 2 | 3 | from pathlib import Path 4 | from typing import List, Tuple, Iterator 5 | import pandas as pd 6 | import numpy as np 7 | 8 | from .data import Spectra, Mol 9 | 10 | DATASET = List[Tuple[Spectra, Mol]] 11 | 12 | 13 | def get_splitter(**kwargs): 14 | """_summary_ 15 | 16 | Returns: 17 | _type_: _description_ 18 | """ 19 | return {"preset": PresetSpectraSplitter,}[ 20 | "preset" 21 | ](**kwargs) 22 | 23 | 24 | class SpectraSplitter(object): 25 | """SpectraSplitter.""" 26 | 27 | def __init__( 28 | self, 29 | **kwargs, 30 | ): 31 | """_summary_ 32 | 33 | Returns: 34 | _type_: _description_ 35 | """ 36 | 37 | pass 38 | 39 | def split_from_indices( 40 | self, 41 | full_dataset: DATASET, 42 | train_inds: np.ndarray, 43 | val_inds: np.ndarray, 44 | test_inds: np.ndarray, 45 | ) -> Tuple[DATASET]: 46 | """_summary_ 47 | 48 | Args: 49 | full_dataset (DATASET): _description_ 50 | train_inds (np.ndarray): _description_ 51 | val_inds (np.ndarray): _description_ 52 | test_inds (np.ndarray): _description_ 53 | 54 | Returns: 55 | Tuple[DATASET]: _description_ 56 | """ 57 | full_dataset = np.array(full_dataset) 58 | train_sub = full_dataset[train_inds].tolist() 59 | val_sub = full_dataset[val_inds].tolist() 60 | test_sub = full_dataset[test_inds].tolist() 61 | return (train_sub, val_sub, test_sub) 62 | 63 | 64 | class PresetSpectraSplitter(SpectraSplitter): 65 | """PresetSpectraSplitter.""" 66 | 67 | def __init__(self, split_file: str = None, **kwargs): 68 | """_summary_ 69 | 70 | Args: 71 | split_file (str, optional): _description_. Defaults to None. 
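A usage sketch of the `Mol` convenience constructors above. The SMILES and formula are arbitrary, and the import path assumes `src/` is on `PYTHONPATH`, matching the repo's own `from mist import utils` imports:

```python
# Arbitrary example molecules; illustrates the alternative constructors of Mol above.
from mist.data.data import Mol

mol = Mol.MolFromSmiles("c1ccccc1O")  # phenol
print(mol.get_smiles(), mol.get_molform(), mol.get_inchikey())

# MolFromFormula builds an atom-only RDKit molecule (no bonds), for the case where only
# the precursor formula is known.
formula_mol = Mol.MolFromFormula("C6H12O6")
print(formula_mol.get_molform())  # C6H12O6
print(formula_mol.get_num_hs())   # 12, parsed from the "H12" token in the formula
```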
72 | 73 | Raises: 74 | ValueError: _description_ 75 | """ 76 | super().__init__(**kwargs) 77 | if split_file is None: 78 | raise ValueError("Preset splitter requires split_file arg.") 79 | 80 | self.split_file = split_file 81 | self.split_name = Path(split_file).stem 82 | self.split_df = pd.read_csv(self.split_file, sep="\t") 83 | self.name_to_fold = dict(zip(self.split_df["name"], self.split_df["split"])) 84 | 85 | def get_splits(self, full_dataset: DATASET) -> Iterator[Tuple[str, Tuple[DATASET]]]: 86 | """_summary_ 87 | 88 | Args: 89 | full_dataset (DATASET): _description_ 90 | 91 | Returns: 92 | _type_: _description_ 93 | 94 | Yields: 95 | Iterator[Tuple[str, Tuple[DATASET]]]: _description_ 96 | """ 97 | # Map name to index 98 | spec_names = [i.get_spec_name() for i, j in full_dataset] 99 | train_inds = [ 100 | i for i, j in enumerate(spec_names) if self.name_to_fold.get(j) == "train" 101 | ] 102 | val_inds = [ 103 | i for i, j in enumerate(spec_names) if self.name_to_fold.get(j) == "val" 104 | ] 105 | test_inds = [ 106 | i for i, j in enumerate(spec_names) if self.name_to_fold.get(j) == "test" 107 | ] 108 | new_split = self.split_from_indices( 109 | full_dataset, train_inds, val_inds, test_inds 110 | ) 111 | return (self.split_name, new_split) -------------------------------------------------------------------------------- /src/mist/models/form_embedders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from ..utils.chem_utils import NORM_VEC 6 | 7 | 8 | class IntFeaturizer(nn.Module): 9 | """ 10 | Base class for mapping integers to a vector representation (primarily to be used as a "richer" embedding for NNs 11 | processing integers). 12 | 13 | Subclasses should define `self.int_to_feat_matrix`, a matrix where each row is the vector representation for that 14 | integer, i.e. to get a vector representation for `5`, one could call `self.int_to_feat_matrix[5]`. 15 | 16 | Note that this class takes care of creating a fixed number (`self.NUM_EXTRA_EMBEDDINGS` to be precise) of extra 17 | "learned" embeddings these will be concatenated after the integer embeddings in the forward pass, 18 | be learned, and be used for extra non-integer tokens such as the "to be confirmed token" (i.e., pad) token. 19 | They are indexed starting from `self.MAX_COUNT_INT`. 20 | """ 21 | 22 | MAX_COUNT_INT = 255 # the maximum number of integers that we are going to see as a "count", i.e. 0 to MAX_COUNT_INT-1 23 | NUM_EXTRA_EMBEDDINGS = 1 # Number of extra embeddings to learn -- one for the "to be confirmed" embedding. 24 | 25 | def __init__(self, embedding_dim): 26 | super().__init__() 27 | weights = torch.zeros(self.NUM_EXTRA_EMBEDDINGS, embedding_dim) 28 | self._extra_embeddings = nn.Parameter(weights, requires_grad=True) 29 | nn.init.normal_(self._extra_embeddings, 0.0, 1.0) 30 | self.embedding_dim = embedding_dim 31 | 32 | def forward(self, tensor): 33 | """ 34 | Convert the integer `tensor` into its new representation -- note that it gets stacked along final dimension. 
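A sketch of the preset split file `PresetSpectraSplitter` above consumes: a TSV with `name` and `split` columns whose values are train/val/test. The file name and spectrum names here are made up:

```python
# Writes a toy split file and shows the name -> fold mapping the splitter builds from it.
import pandas as pd
from mist.data.splitter import PresetSpectraSplitter

pd.DataFrame({
    "name": ["spec_001", "spec_002", "spec_003"],
    "split": ["train", "val", "test"],
}).to_csv("example_split.tsv", sep="\t", index=False)

splitter = PresetSpectraSplitter(split_file="example_split.tsv")
print(splitter.name_to_fold)  # {'spec_001': 'train', 'spec_002': 'val', 'spec_003': 'test'}
# splitter.get_splits(dataset) then partitions a list of (Spectra, Mol) pairs into these
# folds, matching each pair via Spectra.get_spec_name().
```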
35 | """ 36 | orig_shape = tensor.shape 37 | out_tensor = torch.empty( 38 | (*orig_shape, self.embedding_dim), device=tensor.device 39 | ) 40 | extra_embed = tensor >= self.MAX_COUNT_INT 41 | 42 | tensor = tensor.long() 43 | norm_embeds = self.int_to_feat_matrix[tensor[~extra_embed]] 44 | extra_embeds = self._extra_embeddings[tensor[extra_embed] - self.MAX_COUNT_INT] 45 | 46 | out_tensor[~extra_embed] = norm_embeds 47 | out_tensor[extra_embed] = extra_embeds 48 | 49 | temp_out = out_tensor.reshape(*orig_shape[:-1], -1) 50 | return temp_out 51 | 52 | @property 53 | def num_dim(self): 54 | return self.int_to_feat_matrix.shape[1] 55 | 56 | @property 57 | def full_dim(self): 58 | return self.num_dim * NORM_VEC.shape[0] 59 | 60 | 61 | class FourierFeaturizer(IntFeaturizer): 62 | """ 63 | Inspired by: 64 | Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., 65 | Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional 66 | Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739. 67 | 68 | Some notes: 69 | * we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the 70 | Binarizer quite closely but be a bit smoother. 71 | """ 72 | 73 | def __init__(self): 74 | 75 | num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2 76 | # ^ need at least this many to ensure that the whole input range can be represented on the half circle. 77 | 78 | freqs = 0.5 ** torch.arange(num_freqs, dtype=torch.float32) 79 | freqs_time_2pi = 2 * np.pi * freqs 80 | 81 | super().__init__( 82 | embedding_dim=2 * freqs_time_2pi.shape[0] 83 | ) # 2 for cosine and sine 84 | 85 | # we will define the features at this frequency up front (as we only will ever see a fixed number of counts): 86 | combo_of_sinusoid_args = ( 87 | torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None] 88 | * freqs_time_2pi[None, :] 89 | ) 90 | all_features = torch.cat( 91 | [torch.cos(combo_of_sinusoid_args), torch.sin(combo_of_sinusoid_args)], 92 | dim=1, 93 | ) 94 | 95 | # ^ shape: MAX_COUNT_INT x 2 * num_freqs 96 | self.int_to_feat_matrix = nn.Parameter(all_features.float()) 97 | self.int_to_feat_matrix.requires_grad = False 98 | 99 | 100 | class FourierFeaturizerSines(IntFeaturizer): 101 | """ 102 | Like other fourier feats but sines only 103 | 104 | Inspired by: 105 | Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., 106 | Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional 107 | Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739. 108 | 109 | Some notes: 110 | * we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the 111 | Binarizer quite closely but be a bit smoother. 112 | """ 113 | 114 | def __init__(self): 115 | 116 | num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2 117 | # ^ need at least this many to ensure that the whole input range can be represented on the half circle. 
118 | 119 | freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:] 120 | freqs_time_2pi = 2 * np.pi * freqs 121 | 122 | super().__init__(embedding_dim=freqs_time_2pi.shape[0]) 123 | 124 | # we will define the features at this frequency up front (as we only will ever see a fixed number of counts): 125 | combo_of_sinusoid_args = ( 126 | torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None] 127 | * freqs_time_2pi[None, :] 128 | ) 129 | # ^ shape: MAX_COUNT_INT x 2 * num_freqs 130 | self.int_to_feat_matrix = nn.Parameter( 131 | torch.sin(combo_of_sinusoid_args).float() 132 | ) 133 | self.int_to_feat_matrix.requires_grad = False 134 | 135 | 136 | class FourierFeaturizerAbsoluteSines(IntFeaturizer): 137 | """ 138 | Like other fourier feats but sines only and absoluted. 139 | 140 | Inspired by: 141 | Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., 142 | Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional 143 | Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739. 144 | 145 | Some notes: 146 | * we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the 147 | Binarizer quite closely but be a bit smoother. 148 | """ 149 | 150 | def __init__(self): 151 | 152 | num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2 153 | 154 | freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:] 155 | freqs_time_2pi = 2 * np.pi * freqs 156 | 157 | super().__init__(embedding_dim=freqs_time_2pi.shape[0]) 158 | 159 | # we will define the features at this frequency up front (as we only will ever see a fixed number of counts): 160 | combo_of_sinusoid_args = ( 161 | torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None] 162 | * freqs_time_2pi[None, :] 163 | ) 164 | # ^ shape: MAX_COUNT_INT x 2 * num_freqs 165 | self.int_to_feat_matrix = nn.Parameter( 166 | torch.abs(torch.sin(combo_of_sinusoid_args)).float() 167 | ) 168 | self.int_to_feat_matrix.requires_grad = False 169 | 170 | 171 | class FourierFeaturizerPosCos(IntFeaturizer): 172 | """ 173 | Like other fourier feats but sines only and absoluted. 174 | 175 | Inspired by: 176 | Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., 177 | Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional 178 | Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739. 179 | 180 | Some notes: 181 | * we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the 182 | Binarizer quite closely but be a bit smoother. 
183 | """ 184 | 185 | def __init__(self, num_funcs=9): 186 | 187 | # Variable 188 | self.num_funcs = num_funcs 189 | 190 | # Define a frequency that will be smoothly increasing from 0 to max 191 | max_freq = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 1 192 | freqs = 0.5 ** np.linspace(1, max_freq, num_funcs) 193 | freqs_time_2pi = 2 * np.pi * freqs 194 | freqs_time_2pi = torch.from_numpy(freqs_time_2pi).float() 195 | super().__init__(embedding_dim=freqs_time_2pi.shape[0]) 196 | 197 | # we will define the features at this frequency up front (as we only will ever see a fixed number of counts): 198 | combo_of_sinusoid_args = ( 199 | torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None] 200 | * freqs_time_2pi[None, :] 201 | ) 202 | # ^ shape: MAX_COUNT_INT x 2 * num_freqs 203 | self.int_to_feat_matrix = nn.Parameter( 204 | (-torch.cos(combo_of_sinusoid_args) + 1).float() 205 | ) 206 | self.int_to_feat_matrix.requires_grad = False 207 | 208 | 209 | class RBFFeaturizer(IntFeaturizer): 210 | """ 211 | A featurizer that puts radial basis functions evenly between 0 and max_count-1. These will have a width of 212 | (max_count-1) / (num_funcs) to decay to about 0.6 of its original height at reaching the next func. 213 | 214 | """ 215 | 216 | def __init__(self, num_funcs=32): 217 | """ 218 | :param num_funcs: number of radial basis functions to use: their width will automatically be chosen -- see class 219 | docstring. 220 | """ 221 | super().__init__(embedding_dim=num_funcs) 222 | width = (self.MAX_COUNT_INT - 1) / num_funcs 223 | centers = torch.linspace(0, self.MAX_COUNT_INT - 1, num_funcs) 224 | 225 | pre_exponential_terms = ( 226 | -0.5 227 | * ((torch.arange(self.MAX_COUNT_INT)[:, None] - centers[None, :]) / width) 228 | ** 2 229 | ) 230 | # ^ shape: MAX_COUNT_INT x num_funcs 231 | feats = torch.exp(pre_exponential_terms) 232 | 233 | self.int_to_feat_matrix = nn.Parameter(feats.float()) 234 | self.int_to_feat_matrix.requires_grad = False 235 | 236 | 237 | class OneHotFeaturizer(IntFeaturizer): 238 | """ 239 | A featurizer that turns integers into their one hot encoding. 240 | 241 | Represents: 242 | - 0 as 1000000000... 243 | - 1 as 0100000000... 244 | - 2 as 0010000000... 245 | and so on. 246 | """ 247 | 248 | def __init__(self): 249 | super().__init__(embedding_dim=self.MAX_COUNT_INT) 250 | feats = torch.eye(self.MAX_COUNT_INT) 251 | self.int_to_feat_matrix = nn.Parameter(feats.float()) 252 | self.int_to_feat_matrix.requires_grad = False 253 | 254 | 255 | class LearnedFeaturizer(IntFeaturizer): 256 | """ 257 | Learns the features for the different integers. 258 | 259 | Pretty much `nn.Embedding` but we get to use the forward of the superclass which behaves a bit differently. 260 | """ 261 | 262 | def __init__(self, feature_dim=32): 263 | super().__init__(embedding_dim=feature_dim) 264 | self.nn_embedder = nn.Embedding( 265 | self.MAX_COUNT_INT + self.NUM_EXTRA_EMBEDDINGS, feature_dim 266 | ) 267 | self.int_to_feat_matrix = list(self.nn_embedder.parameters())[0] 268 | 269 | def forward(self, tensor): 270 | """ 271 | Convert the integer `tensor` into its new representation -- note that it gets stacked along final dimension. 272 | """ 273 | orig_shape = tensor.shape 274 | out_tensor = self.nn_embedder(tensor.long()) 275 | temp_out = out_tensor.reshape(*orig_shape[:-1], -1) 276 | return temp_out 277 | 278 | 279 | class FloatFeaturizer(IntFeaturizer): 280 | """ 281 | Norms the features 282 | """ 283 | 284 | def __init__(self): 285 | # Norm vec 286 | # Placeholder.. 
287 | super().__init__(embedding_dim=1) 288 | self.norm_vec = torch.from_numpy(NORM_VEC).float() 289 | self.norm_vec = nn.Parameter(self.norm_vec) 290 | self.norm_vec.requires_grad = False 291 | 292 | def forward(self, tensor): 293 | """ 294 | Convert the integer `tensor` into its new representation -- note that it gets stacked along final dimension. 295 | """ 296 | tens_shape = tensor.shape 297 | out_shape = [1] * (len(tens_shape) - 1) + [-1] 298 | return tensor / self.norm_vec.reshape(*out_shape) 299 | 300 | @property 301 | def num_dim(self): 302 | return 1 303 | 304 | 305 | def get_embedder(embedder): 306 | if embedder == "fourier": 307 | embedder = FourierFeaturizer() 308 | elif embedder == "rbf": 309 | embedder = RBFFeaturizer() 310 | elif embedder == "one-hot": 311 | embedder = OneHotFeaturizer() 312 | elif embedder == "learnt": 313 | embedder = LearnedFeaturizer() 314 | elif embedder == "float": 315 | embedder = FloatFeaturizer() 316 | elif embedder == "fourier-sines": 317 | embedder = FourierFeaturizerSines() 318 | elif embedder == "abs-sines": 319 | embedder = FourierFeaturizerAbsoluteSines() 320 | elif embedder == "pos-cos": 321 | embedder = FourierFeaturizerPosCos() 322 | else: 323 | raise NotImplementedError 324 | return embedder -------------------------------------------------------------------------------- /src/mist/models/spectra_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from . import modules 7 | 8 | class SpectraEncoder(nn.Module): 9 | """SpectraEncoder.""" 10 | def __init__( 11 | self, 12 | form_embedder: str = "float", 13 | output_size: int = 4096, 14 | hidden_size: int = 50, 15 | spectra_dropout: float = 0.0, 16 | top_layers: int = 1, 17 | refine_layers: int = 0, 18 | magma_modulo: int = 2048, 19 | **kwargs, 20 | ): 21 | super(SpectraEncoder, self).__init__() 22 | 23 | spectra_encoder_main = modules.FormulaTransformer( 24 | hidden_size=hidden_size, 25 | spectra_dropout=spectra_dropout, 26 | form_embedder=form_embedder, 27 | **kwargs, 28 | ) 29 | 30 | fragment_pred_parts = [] 31 | for _ in range(top_layers - 1): 32 | fragment_pred_parts.append(nn.Linear(hidden_size, hidden_size)) 33 | fragment_pred_parts.append(nn.ReLU()) 34 | fragment_pred_parts.append(nn.Dropout(spectra_dropout)) 35 | 36 | fragment_pred_parts.append(nn.Linear(hidden_size, magma_modulo)) 37 | 38 | fragment_predictor = nn.Sequential(*fragment_pred_parts) 39 | 40 | top_layer_parts = [] 41 | for _ in range(top_layers - 1): 42 | top_layer_parts.append(nn.Linear(hidden_size, hidden_size)) 43 | top_layer_parts.append(nn.ReLU()) 44 | top_layer_parts.append(nn.Dropout(spectra_dropout)) 45 | top_layer_parts.append(nn.Linear(hidden_size, output_size)) 46 | top_layer_parts.append(nn.Sigmoid()) 47 | spectra_predictor = nn.Sequential(*top_layer_parts) 48 | 49 | self.spectra_encoder = nn.ModuleList([spectra_encoder_main, fragment_predictor, spectra_predictor]) 50 | 51 | 52 | def forward(self, batch: dict) -> Tuple[torch.Tensor, dict]: 53 | """Forward pass.""" 54 | encoder_output, aux_out = self.spectra_encoder[0](batch, return_aux=True) 55 | 56 | pred_frag_fps = self.spectra_encoder[1](aux_out["peak_tensor"]) 57 | aux_outputs = {"pred_frag_fps": pred_frag_fps} 58 | 59 | output = self.spectra_encoder[2](encoder_output) 60 | aux_outputs["h0"] = encoder_output 61 | 62 | return output, aux_outputs 63 | 64 | 65 | class SpectraEncoderGrowing(nn.Module): 66 | """SpectraEncoder.""" 67 | def 
__init__( 68 | self, 69 | form_embedder: str = "float", 70 | output_size: int = 4096, 71 | hidden_size: int = 50, 72 | spectra_dropout: float = 0.0, 73 | top_layers: int = 1, 74 | refine_layers: int = 0, 75 | magma_modulo: int = 2048, 76 | **kwargs, 77 | ): 78 | super(SpectraEncoderGrowing, self).__init__() 79 | 80 | spectra_encoder_main = modules.FormulaTransformer( 81 | hidden_size=hidden_size, 82 | spectra_dropout=spectra_dropout, 83 | form_embedder=form_embedder, 84 | **kwargs, 85 | ) 86 | 87 | fragment_pred_parts = [] 88 | for _ in range(top_layers - 1): 89 | fragment_pred_parts.append(nn.Linear(hidden_size, hidden_size)) 90 | fragment_pred_parts.append(nn.ReLU()) 91 | fragment_pred_parts.append(nn.Dropout(spectra_dropout)) 92 | 93 | fragment_pred_parts.append(nn.Linear(hidden_size, magma_modulo)) 94 | 95 | fragment_predictor = nn.Sequential(*fragment_pred_parts) 96 | 97 | spectra_predictor = modules.FPGrowingModule( 98 | hidden_input_dim=hidden_size, 99 | final_target_dim=output_size, 100 | num_splits=refine_layers, 101 | reduce_factor=2, 102 | ) 103 | 104 | self.spectra_encoder = nn.ModuleList([spectra_encoder_main, fragment_predictor, spectra_predictor]) 105 | 106 | def forward(self, batch: dict) -> Tuple[torch.Tensor, dict]: 107 | """Forward pass.""" 108 | encoder_output, aux_out = self.spectra_encoder[0](batch, return_aux=True) 109 | pred_frag_fps = self.spectra_encoder[1](aux_out["peak_tensor"]) 110 | aux_outputs = {"pred_frag_fps": pred_frag_fps} 111 | 112 | output = self.spectra_encoder[2](encoder_output) 113 | intermediates = output[:-1] 114 | final_output = output[-1] 115 | aux_outputs["int_preds"] = intermediates 116 | output = final_output 117 | aux_outputs["h0"] = encoder_output 118 | 119 | return output, aux_outputs # aux_outputs["int_preds"][-1] 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /src/mist/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc_utils import * 2 | from .parse_utils import * 3 | from .chem_utils import * 4 | from .spectra_utils import * -------------------------------------------------------------------------------- /src/mist/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | """misc_utils.py""" 2 | from typing import List, Iterable, Iterator 3 | from itertools import islice 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def unravel_index(index, shape): 10 | out = [] 11 | for dim in reversed(shape): 12 | out.append(index % dim) 13 | index = torch.div(index, dim, rounding_mode="trunc") 14 | return tuple(reversed(out)) 15 | 16 | 17 | def np_clamp(x, _min=-100): 18 | x = np.ones_like(x) * x 19 | x[x <= _min] = _min 20 | return x 21 | 22 | 23 | def clamped_log_np(x, _min=-100): 24 | res = np.log(x) 25 | return np_clamp(res, _min=_min) 26 | 27 | 28 | def batches(it: Iterable, chunk_size: int) -> Iterator[List]: 29 | """Consume an iterable in batches of size chunk_size""" "" 30 | it = iter(it) 31 | return iter(lambda: list(islice(it, chunk_size)), []) 32 | 33 | 34 | def pad_packed_tensor(input, lengths, value): 35 | """pad_packed_tensor""" 36 | old_shape = input.shape 37 | device = input.device 38 | if not isinstance(lengths, torch.Tensor): 39 | lengths = torch.tensor(lengths, dtype=torch.int64, device=device) 40 | else: 41 | lengths = lengths.to(device) 42 | max_len = (lengths.max()).item() 43 | 44 | batch_size = len(lengths) 45 | x = input.new(batch_size * max_len, 
*old_shape[1:]) 46 | x.fill_(value) 47 | 48 | # Initialize a tensor with an index for every value in the array 49 | index = torch.ones(len(input), dtype=torch.int64, device=device) 50 | 51 | # Row shifts 52 | row_shifts = torch.cumsum(max_len - lengths, 0) 53 | 54 | # Calculate shifts for second row, third row... nth row (not the n+1th row) 55 | # Expand this out to match the shape of all entries after the first row 56 | row_shifts_expanded = row_shifts[:-1].repeat_interleave(lengths[1:]) 57 | 58 | # Add this to the list of inds _after_ the first row 59 | cumsum_inds = torch.cumsum(index, 0) - 1 60 | cumsum_inds[lengths[0] :] += row_shifts_expanded 61 | x[cumsum_inds] = input 62 | return x.view(batch_size, max_len, *old_shape[1:]) 63 | 64 | 65 | def reverse_packed_tensor(packed_tensor, lengths): 66 | """reverse_packed_tensor. 67 | 68 | Args: 69 | packed tensor: Batch x length x feat_dim 70 | lengths : Batch 71 | Return: 72 | [batch,length] x feat_dim 73 | """ 74 | device = packed_tensor.device 75 | batch_size, batch_len, feat_dim = packed_tensor.shape 76 | max_length = torch.arange(batch_len).to(device) 77 | indices = max_length.unsqueeze(0).expand(batch_size, batch_len) 78 | bool_mask = indices < lengths.unsqueeze(1) 79 | output = packed_tensor[bool_mask] 80 | return output 81 | 82 | 83 | def unpack_bits(vec, num_bits): 84 | return np.unpackbits(vec, axis=-1)[..., -num_bits:] -------------------------------------------------------------------------------- /src/mist/utils/parse_utils.py: -------------------------------------------------------------------------------- 1 | """ parse_utils.py """ 2 | from pathlib import Path 3 | from typing import Tuple, List, Optional 4 | from itertools import groupby 5 | 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | 10 | def parse_spectra(spectra_file: str) -> Tuple[dict, List[Tuple[str, np.ndarray]]]: 11 | """parse_spectra. 
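A worked sketch of `pad_packed_tensor` / `reverse_packed_tensor` above: rows from variable-length items are concatenated along dim 0, re-packed into a padded `(batch, max_len, ...)` tensor filled with `value`, and the padding can then be dropped again:

```python
# Two items of length 2 and 3, concatenated into 5 rows of a 1-dim feature.
import torch
from mist.utils.misc_utils import pad_packed_tensor, reverse_packed_tensor

feats = torch.arange(5, dtype=torch.float32).unsqueeze(-1)  # rows 0..4
lengths = torch.tensor([2, 3])

padded = pad_packed_tensor(feats, lengths, value=0.0)
print(padded.shape)        # torch.Size([2, 3, 1])
print(padded.squeeze(-1))  # tensor([[0., 1., 0.], [2., 3., 4.]]) -- first item padded with 0

recovered = reverse_packed_tensor(padded, lengths)
print(torch.equal(recovered, feats))  # True -- padding removed, original rows restored
```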
12 | 13 | Parses spectra in the SIRIUS format and returns 14 | 15 | Args: 16 | spectra_file (str): Name of spectra file to parse 17 | Return: 18 | Tuple[dict, List[Tuple[str, np.ndarray]]]: metadata and list of spectra 19 | tuples containing name and array 20 | """ 21 | lines = [i.strip() for i in open(spectra_file, "r").readlines()] 22 | 23 | group_num = 0 24 | metadata = {} 25 | spectras = [] 26 | my_iterator = groupby( 27 | lines, lambda line: line.startswith(">") or line.startswith("#") 28 | ) 29 | 30 | for index, (start_line, lines) in enumerate(my_iterator): 31 | group_lines = list(lines) 32 | subject_lines = list(next(my_iterator)[1]) 33 | # Get spectra 34 | if group_num > 0: 35 | spectra_header = group_lines[0].split(">")[1] 36 | peak_data = [ 37 | [float(x) for x in peak.split()[:2]] 38 | for peak in subject_lines 39 | if peak.strip() 40 | ] 41 | # Check if spectra is empty 42 | if len(peak_data): 43 | peak_data = np.vstack(peak_data) 44 | # Add new tuple 45 | spectras.append((spectra_header, peak_data)) 46 | # Get meta data 47 | else: 48 | entries = {} 49 | for i in group_lines: 50 | if " " not in i: 51 | continue 52 | elif i.startswith("#INSTRUMENT TYPE"): 53 | key = "#INSTRUMENT TYPE" 54 | val = i.split(key)[1].strip() 55 | entries[key[1:]] = val 56 | else: 57 | start, end = i.split(" ", 1) 58 | start = start[1:] 59 | while start in entries: 60 | start = f"{start}'" 61 | entries[start] = end 62 | 63 | metadata.update(entries) 64 | group_num += 1 65 | 66 | metadata["_FILE_PATH"] = spectra_file 67 | metadata["_FILE"] = Path(spectra_file).stem 68 | return metadata, spectras 69 | 70 | 71 | def spec_to_ms_str( 72 | spec: List[Tuple[str, np.ndarray]], essential_keys: dict, comments: dict = {} 73 | ) -> str: 74 | """spec_to_ms_str. 75 | 76 | Turn spec ars and info dicts into str for output file 77 | 78 | 79 | Args: 80 | spec (List[Tuple[str, np.ndarray]]): spec 81 | essential_keys (dict): essential_keys 82 | comments (dict): comments 83 | 84 | Returns: 85 | str: 86 | """ 87 | 88 | def pair_rows(rows): 89 | return "\n".join([f"{i} {j}" for i, j in rows]) 90 | 91 | header = "\n".join(f">{k} {v}" for k, v in essential_keys.items()) 92 | comments = "\n".join(f"#{k} {v}" for k, v in essential_keys.items()) 93 | spec_strs = [f">{name}\n{pair_rows(ar)}" for name, ar in spec] 94 | spec_str = "\n\n".join(spec_strs) 95 | output = f"{header}\n{comments}\n\n{spec_str}" 96 | return output 97 | 98 | 99 | def build_mgf_str( 100 | meta_spec_list: List[Tuple[dict, List[Tuple[str, np.ndarray]]]], 101 | merge_charges=True, 102 | parent_mass_keys=["PEPMASS", "parentmass", "PRECURSOR_MZ"], 103 | ) -> str: 104 | """build_mgf_str. 
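A hedged sketch of the SIRIUS-style `.ms` layout `parse_spectra` above expects: `>key value` metadata lines and `#`-prefixed comments, a blank line, then one `>name` block per spectrum followed by `m/z intensity` rows. The compound, file name, and peak values are invented:

```python
# Writes a toy .ms file and parses it back; field names follow the SIRIUS convention.
from mist.utils.parse_utils import parse_spectra

ms_text = """>compound toy_compound
>formula C6H12O6
>parentmass 181.0707
#INSTRUMENT TYPE Orbitrap

>ms2peaks
85.0290 12.5
127.0390 30.1
"""

with open("toy_spectrum.ms", "w") as fh:
    fh.write(ms_text)

meta, spectra = parse_spectra("toy_spectrum.ms")
print(meta["parentmass"], meta["INSTRUMENT TYPE"])  # 181.0707 Orbitrap (both kept as strings)
print(spectra[0][0], spectra[0][1].shape)           # ms2peaks (2, 2)
```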
105 | 106 | Args: 107 | meta_spec_list (List[Tuple[dict, List[Tuple[str, np.ndarray]]]]): meta_spec_list 108 | 109 | Returns: 110 | str: 111 | """ 112 | entries = [] 113 | for meta, spec in tqdm(meta_spec_list): 114 | str_rows = ["BEGIN IONS"] 115 | 116 | # Try to add precusor mass 117 | for i in parent_mass_keys: 118 | if i in meta: 119 | pep_mass = float(meta.get(i, -100)) 120 | str_rows.append(f"PEPMASS={pep_mass}") 121 | break 122 | 123 | for k, v in meta.items(): 124 | str_rows.append(f"{k.upper().replace(' ', '_')}={v}") 125 | 126 | if merge_charges: 127 | spec_ar = np.vstack([i[1] for i in spec]) 128 | spec_ar = np.vstack([i for i in sorted(spec_ar, key=lambda x: x[0])]) 129 | else: 130 | raise NotImplementedError() 131 | str_rows.extend([f"{i} {j}" for i, j in spec_ar]) 132 | str_rows.append("END IONS") 133 | 134 | str_out = "\n".join(str_rows) 135 | entries.append(str_out) 136 | 137 | full_out = "\n\n".join(entries) 138 | return full_out 139 | 140 | 141 | def parse_spectra_msp( 142 | mgf_file: str, max_num: Optional[int] = None 143 | ) -> List[Tuple[dict, List[Tuple[str, np.ndarray]]]]: 144 | """parse_spectr_msp. 145 | 146 | Parses spectra in the MSP file format 147 | 148 | Args: 149 | mgf_file (str) : str 150 | max_num (Optional[int]): If set, only parse this many 151 | Return: 152 | List[Tuple[dict, List[Tuple[str, np.ndarray]]]]: metadata and list of spectra 153 | tuples containing name and array 154 | """ 155 | 156 | key = lambda x: x.strip().startswith("PEPMASS") 157 | parsed_spectra = [] 158 | with open(mgf_file, "r", encoding="utf-8") as fp: 159 | for (is_header, group) in tqdm(groupby(fp, key)): 160 | 161 | if is_header: 162 | continue 163 | meta = dict() 164 | spectra = [] 165 | # Note: Sometimes we have multiple scans 166 | # This mgf has them collapsed 167 | cur_spectra_name = "spec" 168 | cur_spectra = [] 169 | group = list(group) 170 | for line in group: 171 | line = line.strip() 172 | if not line: 173 | pass 174 | elif ":" in line: 175 | k, v = [i.strip() for i in line.split(":", 1)] 176 | meta[k] = v 177 | else: 178 | mz, intens = line.split() 179 | cur_spectra.append((float(mz), float(intens))) 180 | 181 | if len(cur_spectra) > 0: 182 | cur_spectra = np.vstack(cur_spectra) 183 | spectra.append((cur_spectra_name, cur_spectra)) 184 | parsed_spectra.append((meta, spectra)) 185 | else: 186 | pass 187 | # print("no spectra found for group: ", "".join(group)) 188 | 189 | if max_num is not None and len(parsed_spectra) > max_num: 190 | # print("Breaking") 191 | break 192 | return parsed_spectra 193 | 194 | 195 | def parse_spectra_mgf( 196 | mgf_file: str, max_num: Optional[int] = None 197 | ) -> List[Tuple[dict, List[Tuple[str, np.ndarray]]]]: 198 | """parse_spectr_mgf. 
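A sketch of the MGF text `build_mgf_str` above emits for one `(metadata, spectra)` pair; the metadata keys and peak values are invented. Note that the precursor key is echoed twice, once by the parent-mass lookup and once by the generic metadata loop:

```python
# Minimal (meta, spec) pair rendered to MGF; peaks are merged and sorted by m/z.
import numpy as np
from mist.utils.parse_utils import build_mgf_str

meta = {"PEPMASS": "181.0707", "FEATURE_ID": "toy_1"}
spec = [("spec", np.array([[127.039, 30.1], [85.029, 12.5]]))]

print(build_mgf_str([(meta, spec)]))
# BEGIN IONS
# PEPMASS=181.0707
# PEPMASS=181.0707
# FEATURE_ID=toy_1
# 85.029 12.5
# 127.039 30.1
# END IONS
```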
199 | 200 | Parses spectra in the MGF file formate, with 201 | 202 | Args: 203 | mgf_file (str) : str 204 | max_num (Optional[int]): If set, only parse this many 205 | Return: 206 | List[Tuple[dict, List[Tuple[str, np.ndarray]]]]: metadata and list of spectra 207 | tuples containing name and array 208 | """ 209 | 210 | key = lambda x: x.strip() == "BEGIN IONS" 211 | parsed_spectra = [] 212 | with open(mgf_file, "r") as fp: 213 | 214 | for (is_header, group) in tqdm(groupby(fp, key)): 215 | 216 | if is_header: 217 | continue 218 | 219 | meta = dict() 220 | spectra = [] 221 | # Note: Sometimes we have multiple scans 222 | # This mgf has them collapsed 223 | cur_spectra_name = "spec" 224 | cur_spectra = [] 225 | group = list(group) 226 | for line in group: 227 | line = line.strip() 228 | if not line: 229 | pass 230 | elif line == "END IONS" or line == "BEGIN IONS": 231 | pass 232 | elif "=" in line: 233 | k, v = [i.strip() for i in line.split("=", 1)] 234 | meta[k] = v 235 | else: 236 | mz, intens = line.split() 237 | cur_spectra.append((float(mz), float(intens))) 238 | 239 | if len(cur_spectra) > 0: 240 | cur_spectra = np.vstack(cur_spectra) 241 | spectra.append((cur_spectra_name, cur_spectra)) 242 | parsed_spectra.append((meta, spectra)) 243 | else: 244 | pass 245 | # print("no spectra found for group: ", "".join(group)) 246 | 247 | if max_num is not None and len(parsed_spectra) > max_num: 248 | # print("Breaking") 249 | break 250 | return parsed_spectra 251 | 252 | 253 | def parse_tsv_spectra(spectra_file: str) -> List[Tuple[str, np.ndarray]]: 254 | """parse_tsv_spectra. 255 | 256 | Parses spectra returned from sirius fragmentation tree 257 | 258 | Args: 259 | spectra_file (str): Name of spectra tsv file to parse 260 | Return: 261 | List[Tuple[str, np.ndarray]]]: list of spectra 262 | tuples containing name and array. This is used to maintain 263 | consistency with the parse_spectra output 264 | """ 265 | output_spec = [] 266 | with open(spectra_file, "r") as fp: 267 | for index, line in enumerate(fp): 268 | if index == 0: 269 | continue 270 | line = line.strip().split("\t") 271 | intensity = float(line[1]) 272 | exact_mass = float(line[3]) 273 | output_spec.append([exact_mass, intensity]) 274 | 275 | output_spec = np.array(output_spec) 276 | return_obj = [("sirius_spec", output_spec)] 277 | return return_obj -------------------------------------------------------------------------------- /src/mist/utils/spectra_utils.py: -------------------------------------------------------------------------------- 1 | """ spectra_utils.py""" 2 | import logging 3 | import numpy as np 4 | from typing import List 5 | 6 | 7 | from .chem_utils import ( 8 | vec_to_formula, 9 | get_all_subsets, 10 | ion_to_mass, 11 | ION_LST, 12 | clipped_ppm, 13 | ) 14 | 15 | 16 | def bin_spectra( 17 | spectras: List[np.ndarray], num_bins: int = 2000, upper_limit: int = 1000 18 | ) -> np.ndarray: 19 | """bin_spectra. 
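A sketch of the SIRIUS fragmentation-tree TSV that `parse_tsv_spectra` above reads: a header row, then tab-separated rows where the second column is intensity and the fourth is the exact mass. The header names and values here are invented:

```python
# Toy TSV matching the column positions parse_tsv_spectra indexes (1 = intensity, 3 = exact mass).
from mist.utils.parse_utils import parse_tsv_spectra

tsv_text = (
    "mz\tintensity\trel.intensity\texactmass\n"
    "85.0284\t12.5\t0.41\t85.0290\n"
    "127.0385\t30.1\t1.00\t127.0390\n"
)
with open("toy_tree.tsv", "w") as fh:
    fh.write(tsv_text)

(name, spec), = parse_tsv_spectra("toy_tree.tsv")
print(name)  # sirius_spec
print(spec)  # [[ 85.029  12.5 ] [127.039  30.1 ]] -- rows of (exact_mass, intensity)
```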
20 | 21 | Args: 22 | spectras (List[np.ndarray]): Input list of spectra tuples 23 | [(header, spec array)] 24 | num_bins (int): Number of discrete bins from [0, upper_limit) 25 | upper_limit (int): Max m/z to consider featurizing 26 | 27 | Return: 28 | np.ndarray of shape [channels, num_bins] 29 | """ 30 | bins = np.linspace(0, upper_limit, num=num_bins) 31 | binned_spec = np.zeros((len(spectras), len(bins))) 32 | for spec_index, spec in enumerate(spectras): 33 | 34 | # Convert to digitized spectra 35 | digitized_mz = np.digitize(spec[:, 0], bins=bins) 36 | 37 | # Remove all spectral peaks out of range 38 | in_range = digitized_mz < len(bins) 39 | digitized_mz, spec = digitized_mz[in_range], spec[in_range, :] 40 | 41 | # Add the current peaks to the spectra 42 | # Use a loop rather than vectorize because certain bins have conflicts 43 | # based upon resolution 44 | for bin_index, spec_val in zip(digitized_mz, spec[:, 1]): 45 | binned_spec[spec_index, bin_index] += spec_val 46 | 47 | return binned_spec 48 | 49 | 50 | def merge_norm_spectra(spec_tuples, precision=4) -> np.ndarray: 51 | """merge_norm_spectra. 52 | 53 | Take a list of mz, inten tuple arrays and merge them by 4 digit precision 54 | 55 | Note this uses _max_ merging 56 | 57 | """ 58 | mz_to_inten_pair = {} 59 | for i in spec_tuples: 60 | for tup in i: 61 | mz, inten = tup 62 | mz_ind = np.round(mz, precision) 63 | cur_pair = mz_to_inten_pair.get(mz_ind) 64 | if cur_pair is None: 65 | mz_to_inten_pair[mz_ind] = tup 66 | elif inten > cur_pair[1]: 67 | mz_to_inten_pair[mz_ind] = (mz_ind, inten) 68 | else: 69 | pass 70 | 71 | merged_spec = np.vstack([v for k, v in mz_to_inten_pair.items()]) 72 | merged_spec[:, 1] = merged_spec[:, 1] / merged_spec[:, 1].max() 73 | return merged_spec 74 | 75 | 76 | def norm_spectrum(binned_spec: np.ndarray) -> np.ndarray: 77 | """norm_spectrum. 
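A tiny numeric sketch of `bin_spectra` above: peaks are digitized into `num_bins` bins over `[0, upper_limit)` and intensities that land in the same bin are summed. The peak values are invented:

```python
# One toy spectrum; the first two peaks fall into the same ~0.5 Da bin and are summed.
import numpy as np
from mist.utils.spectra_utils import bin_spectra

spec = np.array([[100.2, 0.50],
                 [100.3, 0.25],
                 [500.0, 1.00]])

binned = bin_spectra([spec], num_bins=2000, upper_limit=1000)
print(binned.shape)     # (1, 2000) -- one channel per input spectrum
print(binned[0].sum())  # 1.75 -- intensities are summed, not renormalized
```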
78 | 79 | Normalizes each spectral channel to have norm 1 80 | This change is made in place 81 | 82 | Args: 83 | binned_spec (np.ndarray) : Vector of spectras 84 | 85 | Return: 86 | np.ndarray where each channel has max(1) 87 | """ 88 | 89 | spec_maxes = binned_spec.max(1) 90 | 91 | non_zero_max = spec_maxes > 0 92 | 93 | spec_maxes = spec_maxes[non_zero_max] 94 | binned_spec[non_zero_max] = binned_spec[non_zero_max] / spec_maxes.reshape(-1, 1) 95 | 96 | return binned_spec 97 | 98 | 99 | def process_spec_file(meta, tuples, precision=4, max_inten=0.001, max_peaks=60): 100 | """process_spec_file.""" 101 | 102 | if "parentmass" in meta: 103 | parentmass = meta.get("parentmass", None) 104 | elif "PARENTMASS" in meta: 105 | parentmass = meta.get("PARENTMASS", None) 106 | elif "PEPMASS" in meta: 107 | parentmass = meta.get("PEPMASS", None) 108 | else: 109 | logging.debug(f"missing parentmass for spec") 110 | parentmass = 1000000 111 | 112 | parentmass = float(parentmass) 113 | 114 | # First norm spectra 115 | fused_tuples = [x for _, x in tuples if x.size > 0] 116 | 117 | if len(fused_tuples) == 0: 118 | return 119 | 120 | mz_to_inten_pair = {} 121 | new_tuples = [] 122 | for i in fused_tuples: 123 | for tup in i: 124 | mz, inten = tup 125 | mz_ind = np.round(mz, precision) 126 | cur_pair = mz_to_inten_pair.get(mz_ind) 127 | if cur_pair is None: 128 | mz_to_inten_pair[mz_ind] = tup 129 | new_tuples.append(tup) 130 | elif inten > cur_pair[1]: 131 | cur_pair[1] = inten 132 | else: 133 | pass 134 | 135 | merged_spec = np.vstack(new_tuples) 136 | merged_spec = merged_spec[merged_spec[:, 0] <= (parentmass + 1)] 137 | merged_spec[:, 1] = merged_spec[:, 1] / merged_spec[:, 1].max() 138 | 139 | # Sqrt intensities here 140 | merged_spec[:, 1] = np.sqrt(merged_spec[:, 1]) 141 | 142 | merged_spec = max_inten_spec( 143 | merged_spec, max_num_inten=max_peaks, inten_thresh=max_inten 144 | ) 145 | return merged_spec 146 | 147 | 148 | def max_inten_spec(spec, max_num_inten: int = 60, inten_thresh: float = 0): 149 | """max_inten_spec. 150 | 151 | Args: 152 | spec: 2D spectra array 153 | max_num_inten: Max number of peaks 154 | inten_thresh: Min intensity to alloow in returned peak 155 | 156 | Return: 157 | Spec filtered down 158 | 159 | 160 | """ 161 | spec_masses, spec_intens = spec[:, 0], spec[:, 1] 162 | 163 | # Make sure to only take max of each formula 164 | # Sort by intensity and select top subpeaks 165 | new_sort_order = np.argsort(spec_intens)[::-1] 166 | if max_num_inten is not None: 167 | new_sort_order = new_sort_order[:max_num_inten] 168 | 169 | spec_masses = spec_masses[new_sort_order] 170 | spec_intens = spec_intens[new_sort_order] 171 | 172 | spec_mask = spec_intens > inten_thresh 173 | spec_masses = spec_masses[spec_mask] 174 | spec_intens = spec_intens[spec_mask] 175 | spec = np.vstack([spec_masses, spec_intens]).transpose(1, 0) 176 | return spec 177 | 178 | 179 | def max_thresh_spec(spec: np.ndarray, max_peaks=100, inten_thresh=0.003): 180 | """max_thresh_spec. 
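A short sketch of `max_inten_spec` above: keep at most `max_num_inten` peaks ranked by intensity and drop anything at or below `inten_thresh`. The peak values are invented:

```python
# Four toy peaks reduced to the two most intense, returned most-intense-first.
import numpy as np
from mist.utils.spectra_utils import max_inten_spec

spec = np.array([[ 85.03, 0.90],
                 [127.04, 0.40],
                 [181.07, 1.00],
                 [ 60.02, 0.05]])

filtered = max_inten_spec(spec, max_num_inten=2, inten_thresh=0.1)
print(filtered)
# [[181.07   1.  ]
#  [ 85.03   0.9 ]]
```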
181 | 182 | Args: 183 | spec (np.ndarray): spec 184 | max_peaks: Max num peaks to keep 185 | inten_thresh: Min inten to keep 186 | """ 187 | 188 | spec_masses, spec_intens = spec[:, 0], spec[:, 1] 189 | 190 | # Make sure to only take max of each formula 191 | # Sort by intensity and select top subpeaks 192 | new_sort_order = np.argsort(spec_intens)[::-1] 193 | new_sort_order = new_sort_order[:max_peaks] 194 | 195 | spec_masses = spec_masses[new_sort_order] 196 | spec_intens = spec_intens[new_sort_order] 197 | 198 | spec_mask = spec_intens > inten_thresh 199 | spec_masses = spec_masses[spec_mask] 200 | spec_intens = spec_intens[spec_mask] 201 | out_ar = np.vstack([spec_masses, spec_intens]).transpose(1, 0) 202 | return out_ar 203 | 204 | 205 | def assign_subforms(form, spec, ion_type, mass_diff_thresh=15): 206 | """Assign subformulas of a candidate formula to the peaks of a spectrum. 207 | 208 | Args: 209 | form (str): Candidate molecular formula for the precursor. 210 | spec (np.ndarray): 2D array of [m/z, intensity] peaks. 211 | ion_type (str): Ionization/adduct type used to shift subformula masses. 212 | mass_diff_thresh (int, optional): Max allowed mass error in ppm. Defaults to 15. 213 | 214 | Returns: 215 | dict: Candidate formula, ion type, and a table of per-peak subformula assignments. 216 | """ 217 | cross_prod, masses = get_all_subsets(form) 218 | spec_masses, spec_intens = spec[:, 0], spec[:, 1] 219 | 220 | ion_masses = ion_to_mass[ion_type] 221 | masses_with_ion = masses + ion_masses 222 | ion_types = np.array([ion_type] * len(masses_with_ion)) 223 | 224 | mass_diffs = np.abs(spec_masses[:, None] - masses_with_ion[None, :]) 225 | 226 | formula_inds = mass_diffs.argmin(-1) 227 | min_mass_diff = mass_diffs[np.arange(len(mass_diffs)), formula_inds] 228 | rel_mass_diff = clipped_ppm(min_mass_diff, spec_masses) 229 | 230 | # Filter by mass diff threshold (ppm) 231 | valid_mask = rel_mass_diff < mass_diff_thresh 232 | spec_masses = spec_masses[valid_mask] 233 | spec_intens = spec_intens[valid_mask] 234 | min_mass_diff = min_mass_diff[valid_mask] 235 | rel_mass_diff = rel_mass_diff[valid_mask] 236 | formula_inds = formula_inds[valid_mask] 237 | 238 | formulas = np.array([vec_to_formula(j) for j in cross_prod[formula_inds]]) 239 | formula_masses = masses_with_ion[formula_inds] 240 | ion_types = ion_types[formula_inds] 241 | 242 | # Build mask for uniqueness on formula and ionization 243 | # note that the ionization is the same for all subformula assignments 244 | # hence we only need to consider the uniqueness of the formula 245 | formula_idx_dict = {} 246 | uniq_mask = [] 247 | for idx, formula in enumerate(formulas): 248 | uniq_mask.append(formula not in formula_idx_dict) 249 | gather_ind = formula_idx_dict.get(formula, None) 250 | if gather_ind is None: 251 | formula_idx_dict[formula] = idx 252 | continue 253 | spec_intens[gather_ind] += spec_intens[idx] 254 | 255 | spec_masses = spec_masses[uniq_mask] 256 | spec_intens = spec_intens[uniq_mask] 257 | min_mass_diff = min_mass_diff[uniq_mask] 258 | rel_mass_diff = rel_mass_diff[uniq_mask] 259 | formula_masses = formula_masses[uniq_mask] 260 | formulas = formulas[uniq_mask] 261 | ion_types = ion_types[uniq_mask] 262 | 263 | # To calculate explained intensity, preserve the original normalized 264 | # intensity 265 | if spec_intens.size == 0: 266 | output_tbl = None 267 | else: 268 | output_tbl = { 269 | "mz": list(spec_masses), 270 | "ms2_inten": list(spec_intens), 271 | "mono_mass": list(formula_masses), 272 | "abs_mass_diff": list(min_mass_diff), 273 | "mass_diff": list(rel_mass_diff), 274 | "formula": list(formulas), 275 | "ions": list(ion_types), 276 | } 277 | output_dict = { 278 | "cand_form": form, 279 | "cand_ion": ion_type, 280 | "output_tbl": output_tbl, 
281 | } 282 | return output_dict 283 | 284 | 285 | def get_output_dict( 286 | spec_name: str, 287 | spec: np.ndarray, 288 | form: str, 289 | mass_diff_type: str, 290 | mass_diff_thresh: float, 291 | ion_type: str, 292 | ) -> dict: 293 | """Assign formula subsets to subpeaks of a spectrum. 294 | 295 | This function attempts to take an array of mass-intensity values and assign 296 | subsets of the candidate formula to each subpeak. 297 | 298 | Args: 299 | spec_name (str): Name of the spectrum. 300 | spec (np.ndarray): 2D array of [m/z, intensity] peaks, or None for erroneous spectra. 301 | form (str): Candidate molecular formula. 302 | mass_diff_type (str): Type of mass error; only "ppm" is supported. 303 | mass_diff_thresh (float): Max allowed mass error in ppm. 304 | ion_type (str): Ionization/adduct type. 305 | 306 | Returns: 307 | dict: Candidate formula, ion type, and subformula assignment table (None if unassignable). 308 | """ 309 | assert mass_diff_type == "ppm" 310 | # This is the case for some erroneous MS2 files for which process_spec_file returns None 311 | # All the MS2 subpeaks in these erroneous MS2 files have m/z larger than the parentmass 312 | output_dict = {"cand_form": form, "cand_ion": ion_type, "output_tbl": None} 313 | if spec is not None and ion_type in ION_LST: 314 | output_dict = assign_subforms( 315 | form, spec, ion_type, mass_diff_thresh=mass_diff_thresh 316 | ) 317 | return output_dict -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/DiffMS/9d9f4fd497162eec045f7db2787da30ce69a9622/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Xtoy(nn.Module): 6 | def __init__(self, dx, dy): 7 | """ Map node features to global features """ 8 | super().__init__() 9 | self.lin = nn.Linear(4 * dx, dy) 10 | 11 | def forward(self, X, x_mask): 12 | """ X: bs, n, dx. """ 13 | x_mask = x_mask.expand(-1, -1, X.shape[-1]) 14 | float_imask = 1 - x_mask.float() 15 | m = X.sum(dim=1) / torch.sum(x_mask, dim=1) 16 | mi = (X + 1e5 * float_imask).min(dim=1)[0] 17 | ma = (X - 1e5 * float_imask).max(dim=1)[0] 18 | std = torch.sum(((X - m[:, None, :]) ** 2) * x_mask, dim=1) / torch.sum(x_mask, dim=1) 19 | z = torch.hstack((m, mi, ma, std)) 20 | out = self.lin(z) 21 | return out 22 | 23 | 24 | class Etoy(nn.Module): 25 | def __init__(self, d, dy): 26 | """ Map edge features to global features. """ 27 | super().__init__() 28 | self.lin = nn.Linear(4 * d, dy) 29 | 30 | def forward(self, E, e_mask1, e_mask2): 31 | """ E: bs, n, n, de 32 | Features relative to the diagonal of E could potentially be added. 
33 | """ 34 | mask = (e_mask1 * e_mask2).expand(-1, -1, -1, E.shape[-1]) 35 | float_imask = 1 - mask.float() 36 | divide = torch.sum(mask, dim=(1, 2)) 37 | m = E.sum(dim=(1, 2)) / divide 38 | mi = (E + 1e5 * float_imask).min(dim=2)[0].min(dim=1)[0] 39 | ma = (E - 1e5 * float_imask).max(dim=2)[0].max(dim=1)[0] 40 | std = torch.sum(((E - m[:, None, None, :]) ** 2) * mask, dim=(1, 2)) / divide 41 | z = torch.hstack((m, mi, ma, std)) 42 | out = self.lin(z) 43 | return out 44 | 45 | 46 | def masked_softmax(x, mask, **kwargs): 47 | if mask.sum() == 0: 48 | return x 49 | x_masked = x.clone() 50 | x_masked[mask == 0] = -float("inf") 51 | return torch.softmax(x_masked, **kwargs) -------------------------------------------------------------------------------- /src/spec2mol_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import warnings 5 | import logging 6 | 7 | import torch 8 | torch.cuda.empty_cache() 9 | try: 10 | torch.set_float32_matmul_precision('medium') 11 | logging.info("Enabled float32 matmul precision - medium") 12 | except: 13 | logging.info("Could not enable float32 matmul precision - medium") 14 | import hydra 15 | from omegaconf import DictConfig 16 | from pytorch_lightning import Trainer 17 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 18 | from pytorch_lightning.loggers import CSVLogger, WandbLogger 19 | from pytorch_lightning.utilities.warnings import PossibleUserWarning 20 | 21 | from src import utils 22 | from src.diffusion_model_spec2mol import Spec2MolDenoisingDiffusion 23 | from src.diffusion.extra_features import DummyExtraFeatures, ExtraFeatures 24 | from src.metrics.molecular_metrics_discrete import TrainMolecularMetricsDiscrete 25 | from src.diffusion.extra_features_molecular import ExtraMolecularFeatures 26 | from src.analysis.visualization import MolecularVisualization 27 | from src.datasets import spec2mol_dataset 28 | 29 | 30 | warnings.filterwarnings("ignore", category=PossibleUserWarning) 31 | 32 | # TODO: refactor how configs are resumed (need old cfg.model and cfg.train but probably not general) 33 | def get_resume(cfg, model_kwargs): 34 | """ Resumes a run. It loads previous config without allowing to update keys (used for testing). """ 35 | saved_cfg = cfg.copy() 36 | name = cfg.general.name + '_resume' 37 | resume = cfg.general.test_only 38 | val_samples_to_generate = cfg.general.val_samples_to_generate 39 | test_samples_to_generate = cfg.general.test_samples_to_generate 40 | gpus = cfg.general.gpus 41 | 42 | model = Spec2MolDenoisingDiffusion.load_from_checkpoint(resume, **model_kwargs) 43 | 44 | cfg = model.cfg 45 | cfg.general.test_only = resume 46 | cfg.general.name = name 47 | cfg.general.val_samples_to_generate = val_samples_to_generate 48 | cfg.general.test_samples_to_generate = test_samples_to_generate 49 | cfg.general.gpus = gpus 50 | cfg = utils.update_config_with_new_keys(cfg, saved_cfg) 51 | return cfg, model 52 | 53 | 54 | def get_resume_adaptive(cfg, model_kwargs): 55 | """ Resumes a run. 
It loads previous config but allows to make some changes (used for resuming training).""" 56 | saved_cfg = cfg.copy() 57 | # Fetch path to this file to get base path 58 | current_path = os.path.dirname(os.path.realpath(__file__)) 59 | root_dir = current_path.split('outputs')[0] 60 | 61 | resume_path = os.path.join(root_dir, cfg.general.resume) 62 | 63 | model = Spec2MolDenoisingDiffusion.load_from_checkpoint(resume_path, **model_kwargs) 64 | 65 | new_cfg = model.cfg 66 | 67 | for category in cfg: 68 | for arg in cfg[category]: 69 | new_cfg[category][arg] = cfg[category][arg] 70 | 71 | new_cfg.general.resume = resume_path 72 | new_cfg.general.name = new_cfg.general.name + '_resume' 73 | 74 | new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg) 75 | return new_cfg, model 76 | 77 | def apply_encoder_finetuning(model, strategy): 78 | if strategy is None: 79 | pass 80 | elif strategy == 'freeze': 81 | for param in model.encoder.parameters(): 82 | param.requires_grad = False 83 | elif strategy == 'ft-unfold': 84 | for param in model.encoder.named_parameters(): 85 | layer = param[0].split('.')[1] 86 | if layer != '2': 87 | param[1].requires_grad = False 88 | elif strategy == 'freeze-unfold': 89 | for param in model.encoder.named_parameters(): 90 | layer = param[0].split('.')[1] 91 | if layer == '2': 92 | param[1].requires_grad = False 93 | elif strategy == 'ft-transformer': 94 | for param in model.encoder.named_parameters(): 95 | layer = param[0].split('.')[1] 96 | if layer != '0': 97 | param[1].requires_grad = False 98 | elif strategy == 'freeze-transformer': 99 | for param in model.encoder.named_parameters(): 100 | layer = param[0].split('.')[1] 101 | if layer == '0': 102 | param[1].requires_grad = False 103 | else: 104 | raise NotImplementedError(f'Unknown Finetune Strategy: {strategy}') 105 | 106 | def apply_decoder_finetuning(model, strategy): 107 | if strategy is None: 108 | pass 109 | elif strategy == 'freeze': 110 | for param in model.decoder.parameters(): 111 | param.requires_grad = False 112 | elif strategy == 'ft-input': 113 | for p in model.decoder.named_parameters(): 114 | layer_name = p[0].split('.')[0] 115 | if layer_name not in ['mlp_in_X', 'mlp_in_E', 'mlp_in_y']: 116 | p[1].requires_grad = False 117 | elif strategy == 'freeze-input': 118 | for p in model.decoder.named_parameters(): 119 | layer_name = p[0].split('.')[0] 120 | if layer_name in ['mlp_in_X', 'mlp_in_E', 'mlp_in_y']: 121 | p[1].requires_grad = False 122 | elif strategy == 'ft-transformer': 123 | for param in model.decoder.parameters(): 124 | param.requires_grad = False 125 | for param in model.decoder.tf_layers.parameters(): 126 | param.requires_grad = True 127 | elif strategy == 'freeze-transformer': 128 | for param in model.decoder.tf_layers.parameters(): 129 | param.requires_grad = False 130 | elif strategy == 'ft-output': 131 | for p in model.decoder.named_parameters(): 132 | layer_name = p[0].split('.')[0] 133 | if layer_name not in ['mlp_out_X', 'mlp_out_E', 'mlp_out_y']: 134 | p[1].requires_grad = False 135 | else: 136 | raise NotImplementedError(f'Unknown Finetune Strategy: {strategy}') 137 | 138 | def load_weights(model, path): 139 | """ 140 | Loads only the weights from a checkpoint file into the model without loading the full Lightning module. 
141 | 142 | Args: 143 | model: The model to load weights into 144 | path: Path to the checkpoint file 145 | 146 | Returns: 147 | The model with loaded weights 148 | """ 149 | checkpoint = torch.load(path, map_location='cpu') 150 | state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint 151 | 152 | # Filter out keys that don't match the model (for partial loading) 153 | model_state_dict = model.state_dict() 154 | filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict} 155 | 156 | # Load the weights 157 | missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False) 158 | logging.info(f"Loaded weights from {path}") 159 | logging.info(f"Missing keys: {missing_keys}") 160 | logging.info(f"Unexpected keys: {unexpected_keys}") 161 | 162 | return model 163 | 164 | @hydra.main(version_base='1.3', config_path='../configs', config_name='config') 165 | def main(cfg: DictConfig): 166 | from rdkit import RDLogger 167 | RDLogger.DisableLog('rdApp.*') 168 | 169 | logger = logging.getLogger("msms_main") 170 | logger.setLevel(logging.INFO) 171 | 172 | formatter = logging.Formatter( 173 | "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", 174 | datefmt="%Y-%m-%d %H:%M:%S", 175 | ) 176 | 177 | ch = logging.StreamHandler(stream=sys.stdout) 178 | ch.setFormatter(formatter) 179 | logger.addHandler(ch) 180 | 181 | path = os.path.join("msms_main.log") 182 | fh = logging.FileHandler(path) 183 | fh.setFormatter(formatter) 184 | 185 | logger.addHandler(fh) 186 | 187 | logging.info(cfg) 188 | 189 | dataset_config = cfg["dataset"] 190 | 191 | if dataset_config["name"] not in ("canopus", "msg"): 192 | raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"])) 193 | 194 | datamodule = spec2mol_dataset.Spec2MolDataModule(cfg) # TODO: Add hyper for n_bits 195 | dataset_infos = spec2mol_dataset.Spec2MolDatasetInfos(datamodule, cfg) 196 | 197 | domain_features = ExtraMolecularFeatures(dataset_infos=dataset_infos) 198 | if cfg.model.extra_features is not None: 199 | extra_features = ExtraFeatures(cfg.model.extra_features, dataset_info=dataset_infos) 200 | else: 201 | extra_features = DummyExtraFeatures() 202 | 203 | dataset_infos.compute_input_output_dims(datamodule=datamodule, extra_features=extra_features, domain_features=domain_features) 204 | 205 | logging.info("Dataset infos: %s", dataset_infos.output_dims) 206 | train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) 207 | 208 | # We do not evaluate novelty during training 209 | visualization_tools = MolecularVisualization(cfg.dataset.remove_h, dataset_infos=dataset_infos) 210 | 211 | model_kwargs = {'dataset_infos': dataset_infos, 'train_metrics': train_metrics, 'visualization_tools': visualization_tools, 212 | 'extra_features': extra_features, 'domain_features': domain_features} 213 | 214 | if cfg.general.test_only: 215 | # When testing, previous configuration is fully loaded 216 | cfg, _ = get_resume(cfg, model_kwargs) 217 | #os.chdir(cfg.general.test_only.split('checkpoints')[0]) 218 | elif cfg.general.resume is not None: 219 | # When resuming, we can override some parts of previous configuration 220 | cfg, _ = get_resume_adaptive(cfg, model_kwargs) 221 | #os.chdir(cfg.general.resume.split('checkpoints')[0]) 222 | 223 | os.makedirs('preds/', exist_ok=True) 224 | os.makedirs('logs/', exist_ok=True) 225 | os.makedirs('logs/' + cfg.general.name, exist_ok=True) 226 | 227 | model = Spec2MolDenoisingDiffusion(cfg=cfg, **model_kwargs) 228 | 229 | callbacks = [] 230 | 
callbacks.append(LearningRateMonitor(logging_interval='step')) 231 | if cfg.train.save_model: # TODO: More advanced checkpointing 232 | checkpoint_callback = ModelCheckpoint(dirpath=f"checkpoints/{cfg.general.name}", # best (top-5) checkpoints 233 | filename='{epoch}', 234 | monitor='val/NLL', 235 | save_top_k=5, 236 | mode='min', 237 | every_n_epochs=1) 238 | last_ckpt_save = ModelCheckpoint(dirpath=f"checkpoints/{cfg.general.name}", filename='last', every_n_epochs=1) # most recent checkpoint 239 | callbacks.append(last_ckpt_save) 240 | callbacks.append(checkpoint_callback) 241 | 242 | if cfg.train.ema_decay > 0: 243 | ema_callback = utils.EMA(decay=cfg.train.ema_decay) 244 | callbacks.append(ema_callback) 245 | 246 | name = cfg.general.name 247 | if name == 'debug': 248 | logging.warning("Run is called 'debug' -- it will run with fast_dev_run. ") 249 | 250 | loggers = [ 251 | CSVLogger(save_dir=f"logs/{name}", name=name), 252 | WandbLogger(name=name, save_dir=f"logs/{name}", project=cfg.general.wandb_name, log_model=False, config=utils.cfg_to_dict(cfg)) 253 | ] 254 | 255 | use_gpu = cfg.general.gpus > 0 and torch.cuda.is_available() 256 | trainer = Trainer(gradient_clip_val=cfg.train.clip_grad, 257 | strategy="ddp_find_unused_parameters_true", # Needed to load old checkpoints 258 | accelerator='gpu' if use_gpu else 'cpu', 259 | devices=cfg.general.gpus if use_gpu else 1, 260 | max_epochs=cfg.train.n_epochs, 261 | check_val_every_n_epoch=cfg.general.check_val_every_n_epochs, 262 | fast_dev_run=cfg.general.name == 'debug', 263 | callbacks=callbacks, 264 | log_every_n_steps=50 if name != 'debug' else 1, 265 | logger=loggers) 266 | 267 | apply_encoder_finetuning(model, cfg.general.encoder_finetune_strategy) 268 | apply_decoder_finetuning(model, cfg.general.decoder_finetune_strategy) 269 | 270 | if cfg.general.load_weights is not None: 271 | logging.info(f"Loading weights from {cfg.general.load_weights}") 272 | model = load_weights(model, cfg.general.load_weights) 273 | 274 | if not cfg.general.test_only: 275 | trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) 276 | if cfg.general.name not in ['debug', 'test']: 277 | trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.checkpoint_strategy) 278 | else: 279 | # Start by evaluating test_only_path 280 | trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) 281 | if cfg.general.evaluate_all_checkpoints: 282 | directory = pathlib.Path(cfg.general.test_only).parents[0] 283 | logging.info("Directory: %s", directory) 284 | files_list = os.listdir(directory) 285 | for file in files_list: 286 | if '.ckpt' in file: 287 | ckpt_path = os.path.join(directory, file) 288 | if ckpt_path == cfg.general.test_only: 289 | continue 290 | logging.info("Loading checkpoint %s", ckpt_path) 291 | trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path) 292 | 293 | 294 | if __name__ == '__main__': 295 | main() 296 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch_geometric.utils 6 | from omegaconf import OmegaConf, open_dict 7 | from torch_geometric.utils import to_dense_adj, to_dense_batch 8 | import torch 9 | import omegaconf 10 | import wandb 11 | from rdkit import Chem 12 | from rdkit.Chem import AllChem 13 | from rdkit.Chem import DataStructs 14 | 15 | def cfg_to_dict(cfg): 16 | return 
omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) 17 | 18 | def normalize(X, E, y, norm_values, norm_biases, node_mask): 19 | X = (X - norm_biases[0]) / norm_values[0] 20 | E = (E - norm_biases[1]) / norm_values[1] 21 | y = (y - norm_biases[2]) / norm_values[2] 22 | 23 | diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1) 24 | E[diag] = 0 25 | 26 | return PlaceHolder(X=X, E=E, y=y).mask(node_mask) 27 | 28 | 29 | def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False): 30 | """ 31 | X : node features 32 | E : edge features 33 | y : global features 34 | norm_values : [norm value X, norm value E, norm value y] 35 | norm_biases : same order 36 | node_mask 37 | """ 38 | X = (X * norm_values[0] + norm_biases[0]) 39 | E = (E * norm_values[1] + norm_biases[1]) 40 | y = y * norm_values[2] + norm_biases[2] 41 | 42 | return PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse) 43 | 44 | 45 | def to_dense(x, edge_index, edge_attr, batch): 46 | X, node_mask = to_dense_batch(x=x, batch=batch) 47 | # node_mask = node_mask.float() 48 | edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr) 49 | # TODO: carefully check if setting node_mask as a bool breaks the continuous case 50 | max_num_nodes = X.size(1) 51 | E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes) 52 | E = encode_no_edge(E) 53 | 54 | return PlaceHolder(X=X, E=E, y=None), node_mask 55 | 56 | 57 | def encode_no_edge(E): 58 | assert len(E.shape) == 4 59 | if E.shape[-1] == 0: 60 | return E 61 | no_edge = torch.sum(E, dim=3) == 0 62 | first_elt = E[:, :, :, 0] 63 | first_elt[no_edge] = 1 64 | E[:, :, :, 0] = first_elt 65 | diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1) 66 | E[diag] = 0 67 | return E 68 | 69 | 70 | def update_config_with_new_keys(cfg, saved_cfg): 71 | saved_general = saved_cfg.general 72 | saved_train = saved_cfg.train 73 | saved_model = saved_cfg.model 74 | 75 | for key, val in saved_general.items(): 76 | OmegaConf.set_struct(cfg.general, True) 77 | with open_dict(cfg.general): 78 | if key not in cfg.general.keys(): 79 | setattr(cfg.general, key, val) 80 | 81 | OmegaConf.set_struct(cfg.train, True) 82 | with open_dict(cfg.train): 83 | for key, val in saved_train.items(): 84 | if key not in cfg.train.keys(): 85 | setattr(cfg.train, key, val) 86 | 87 | OmegaConf.set_struct(cfg.model, True) 88 | with open_dict(cfg.model): 89 | for key, val in saved_model.items(): 90 | if key not in cfg.model.keys(): 91 | setattr(cfg.model, key, val) 92 | return cfg 93 | 94 | 95 | class PlaceHolder: 96 | def __init__(self, X, E, y): 97 | self.X = X 98 | self.E = E 99 | self.y = y 100 | 101 | def type_as(self, x: torch.Tensor): 102 | """ Changes the device and dtype of X, E, y. 
""" 103 | self.X = self.X.type_as(x) 104 | self.E = self.E.type_as(x) 105 | self.y = self.y.type_as(x) 106 | return self 107 | 108 | def mask(self, node_mask, collapse=False): 109 | x_mask = node_mask.unsqueeze(-1) # bs, n, 1 110 | e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1 111 | e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1 112 | 113 | if collapse: 114 | self.X = torch.argmax(self.X, dim=-1) 115 | self.E = torch.argmax(self.E, dim=-1) 116 | 117 | self.X[node_mask == 0] = - 1 118 | self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1 119 | else: 120 | self.X = self.X * x_mask 121 | self.E = self.E * e_mask1 * e_mask2 122 | assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) 123 | return self 124 | 125 | 126 | def setup_wandb(cfg): 127 | config_dict = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) 128 | kwargs = {'name': cfg.general.name, 'project': f'graph_ddm_{cfg.dataset.name}', 'config': config_dict, 129 | 'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': cfg.general.wandb} 130 | wandb.init(**kwargs) 131 | wandb.save('*.txt') 132 | 133 | def mol2smiles(mol): 134 | try: 135 | Chem.SanitizeMol(mol) 136 | except ValueError: 137 | return None 138 | return Chem.MolToSmiles(mol) 139 | 140 | def is_valid(mol): 141 | smiles = mol2smiles(mol) 142 | if smiles is None: 143 | return False 144 | 145 | try: 146 | mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True) 147 | except: 148 | return False 149 | if len(mol_frags) > 1: 150 | return False 151 | 152 | return True 153 | 154 | def inchi_to_fingerprint(inchi: str, nbits: int = 2048, radius=3) -> np.ndarray: 155 | """Compute a Morgan bit-vector fingerprint from an InChI string.""" 156 | 157 | mol = Chem.MolFromInchi(inchi) 158 | 159 | curr_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits) 160 | 161 | fingerprint = np.zeros((0,), dtype=np.uint8) 162 | DataStructs.ConvertToNumpyArray(curr_fp, fingerprint) 163 | return fingerprint 164 | 165 | def tanimoto_sim(x: np.ndarray, y: np.ndarray) -> List[float]: 166 | # Calculate tanimoto similarity with binary fingerprints 167 | intersect_mat = x & y 168 | union_mat = x | y 169 | 170 | intersection = intersect_mat.sum(-1) 171 | union = union_mat.sum(-1) 172 | 173 | # Tanimoto similarity of binary fingerprints: 174 | # number of shared on-bits divided by the number of on-bits in the union 
175 | 176 | output = intersection / union 177 | return output 178 | 179 | def cosine_sim(x: np.ndarray, y: np.ndarray) -> List[float]: 180 | # Calculate cosine similarity with binary fingerprint 181 | dot_product = np.dot(x, y) 182 | 183 | norm_x = np.linalg.norm(x) 184 | norm_y = np.linalg.norm(y) 185 | 186 | output = dot_product / (norm_x * norm_y) 187 | return output 188 | 189 | try: 190 | from rdkit.Chem.MolStandardize.tautomer import TautomerCanonicalizer, TautomerTransform 191 | _RD_TAUTOMER_CANONICALIZER = 'v1' 192 | _TAUTOMER_TRANSFORMS = ( 193 | TautomerTransform('1,3 heteroatom H shift', 194 | '[#7,S,O,Se,Te;!H0]-[#7X2,#6,#15]=[#7,#16,#8,Se,Te]'), 195 | TautomerTransform('1,3 (thio)keto/enol r', '[O,S,Se,Te;X2!H0]-[C]=[C]'), 196 | ) 197 | except ModuleNotFoundError: 198 | from rdkit.Chem.MolStandardize.rdMolStandardize import TautomerEnumerator # newer rdkit 199 | _RD_TAUTOMER_CANONICALIZER = 'v2' 200 | 201 | def canonical_mol_from_inchi(inchi): 202 | """Canonicalize mol after Chem.MolFromInchi 203 | Note that this function may be 50 times slower than Chem.MolFromInchi""" 204 | mol = Chem.MolFromInchi(inchi) 205 | if mol is None: 206 | return None 207 | if _RD_TAUTOMER_CANONICALIZER == 'v1': 208 | _molvs_t = TautomerCanonicalizer(transforms=_TAUTOMER_TRANSFORMS) 209 | mol = _molvs_t.canonicalize(mol) 210 | else: 211 | _te = TautomerEnumerator() 212 | mol = _te.Canonicalize(mol) 213 | return mol 214 | 215 | --------------------------------------------------------------------------------
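Usage illustration (editorial sketch, not a file in the repository): the fingerprint helpers in src/utils.py can be combined to compare two structures by fingerprint similarity. The two InChI strings below (ethanol and methanol) are placeholders chosen only for this example.

from src.utils import inchi_to_fingerprint, tanimoto_sim, cosine_sim

# Illustrative InChIs (ethanol vs. methanol); any valid InChI strings would work.
true_inchi = "InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3"
pred_inchi = "InChI=1S/CH4O/c1-2/h2H,1H3"

# 2048-bit Morgan fingerprints (radius 3), returned as uint8 numpy arrays.
fp_true = inchi_to_fingerprint(true_inchi, nbits=2048, radius=3)
fp_pred = inchi_to_fingerprint(pred_inchi, nbits=2048, radius=3)

# tanimoto_sim works on the binary arrays directly; cast to float for cosine_sim
# to avoid integer overflow in the dot product.
print("Tanimoto similarity:", tanimoto_sim(fp_true, fp_pred))
print("Cosine similarity:", cosine_sim(fp_true.astype(float), fp_pred.astype(float)))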