├── .gitignore
├── README.md
├── assets
├── conditional-ligand-pocket-sampling-trajectory.gif
├── conditional-sampling-maximized-polarizability.gif
├── conditional-sampling-minimized-polarizability.gif
├── unconditional-sampling-trajectory.gif
└── unconditional-sampling-trajectory.mp4
├── data
├── README.md
├── geom.zip
├── geom
│   └── .gitkeep
├── qm9.zip
└── qm9
│   └── .gitkeep
├── eqgat_diff
├── .gitignore
├── .ipynb_checkpoints
│   └── Untitled-checkpoint.ipynb
├── README.md
├── callbacks
│   └── ema.py
├── e3moldiffusion
│   ├── README.md
│   ├── __init__.py
│   ├── chem.py
│   ├── convs.py
│   ├── coordsatomsbonds.py
│   ├── gnn.py
│   ├── modules.py
│   └── molfeat.py
├── environment.yml
├── experiments
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── abstract_dataset.py
│   │   ├── adaptive_loader.py
│   │   ├── calculate_dipoles.py
│   │   ├── calculate_energies.py
│   │   ├── config_file.py
│   │   ├── data_info.py
│   │   ├── distributions.py
│   │   ├── geom
│   │   │   ├── calculate_energies.py
│   │   │   ├── geom_dataset_adaptive.py
│   │   │   ├── geom_dataset_adaptive_qm.py
│   │   │   ├── geom_dataset_energy.py
│   │   │   └── geom_dataset_nonadaptive.py
│   │   ├── ligand
│   │   │   ├── __init__.py
│   │   │   ├── constants.py
│   │   │   ├── geometry_utils.py
│   │   │   ├── ligand_dataset_adaptive.py
│   │   │   ├── ligand_dataset_nonadaptive.py
│   │   │   ├── molecule_builder.py
│   │   │   ├── process_bindingmoad.py
│   │   │   ├── process_crossdocked.py
│   │   │   └── utils.py
│   │   ├── metrics.py
│   │   ├── pubchem
│   │   │   ├── download_pubchem.py
│   │   │   ├── preprocess_pubchem.py
│   │   │   ├── process_pubchem.py
│   │   │   ├── pubchem_dataset_adaptive.py
│   │   │   └── pubchem_dataset_nonadaptive.py
│   │   ├── qm9
│   │   │   └── qm9_dataset.py
│   │   └── utils.py
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── categorical.py
│   │   ├── continuous.py
│   │   └── utils.py
│   ├── diffusion_continuous.py
│   ├── diffusion_continuous_pocket.py
│   ├── diffusion_discrete.py
│   ├── diffusion_discrete_addfeats.py
│   ├── diffusion_discrete_pocket.py
│   ├── diffusion_discrete_pocket_addfeats.py
│   ├── diffusion_pretrain_discrete.py
│   ├── diffusion_pretrain_discrete_addfeats.py
│   ├── docking.py
│   ├── docking_mgl.py
│   ├── generate_ligands.py
│   ├── hparams.py
│   ├── losses.py
│   ├── molecule_utils.py
│   ├── run_evaluation.py
│   ├── run_evaluation_ligand.py
│   ├── run_train.py
│   ├── sampling
│   │   ├── __init__.py
│   │   ├── analyze.py
│   │   ├── analyze_strict.py
│   │   ├── fpscores.pkl.gz
│   │   └── utils.py
│   ├── utils.py
│   ├── xtb_energy.py
│   ├── xtb_relaxation.py
│   └── xtb_wrapper.py
├── setup.cfg
└── setup.py
├── inference
├── README.md
├── run_eval_geom.sh
├── run_eval_qm9.sh
├── sampling_geom.ipynb
├── sampling_qm9.ipynb
└── tmp
│   ├── .gitkeep
│   └── geom
│   └── .gitkeep
└── weights
├── README.md
├── geom
└── .gitkeep
└── qm9
└── .gitkeep

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | *.egg-info
 2 | .eggs
 3 | __pycache__
 4 | build
 5 | dist
 6 | .idea/*
 7 | 
 8 | *.DS_Store*
 9 | *.ipynb_checkpoints/*
10 | 
11 | .vscode/
12 | 
13 | data/geom/*
14 | !data/geom/.gitkeep
15 | 
16 | data/qm9/*
17 | !data/qm9/.gitkeep
18 | 
19 | inference/tmp/geom/*
20 | !inference/tmp/geom/.gitkeep
21 | inference/tmp/qm9/*
22 | !inference/tmp/qm9/.gitkeep
23 | 
24 | weights/geom/*
25 | !weights/geom/.gitkeep
26 | 
27 | weights/qm9/*
28 | !weights/qm9/.gitkeep

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # eqgat-diff
2 | 
3 | This repository serves as a placeholder.
4 | The codebase is stored in the `eqgat_diff/` directory.
 5 | 
 6 | Below, we show some sampling trajectories from our models. The corresponding .gif files can be found in the `assets/` directory.
 7 | 
 8 | 
 9 | ### Unconditional Sampling Trajectory
10 | 
11 | 
12 | 
13 | 
14 | ### Conditional Sampling Trajectory with maximized polarizability
15 | 
16 | 
17 | 
18 | 
19 | ### Conditional Sampling Trajectory with minimized polarizability
20 | 
21 | 
22 | 
23 | 
24 | 
25 | ### Conditional Sampling Trajectory on a fixed protein pocket
26 | 
27 | 
28 | 

--------------------------------------------------------------------------------
/assets/conditional-ligand-pocket-sampling-trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/conditional-ligand-pocket-sampling-trajectory.gif

--------------------------------------------------------------------------------
/assets/conditional-sampling-maximized-polarizability.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/conditional-sampling-maximized-polarizability.gif

--------------------------------------------------------------------------------
/assets/conditional-sampling-minimized-polarizability.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/conditional-sampling-minimized-polarizability.gif

--------------------------------------------------------------------------------
/assets/unconditional-sampling-trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/unconditional-sampling-trajectory.gif

--------------------------------------------------------------------------------
/assets/unconditional-sampling-trajectory.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/unconditional-sampling-trajectory.mp4

--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Dataset
2 | 
3 | We provide the reduced datasets for QM9 and Geom-Drugs to run inference.
4 | 
5 | Please extract `qm9.zip` and `geom.zip`.
6 | Note that the training/validation and test sets are not provided because of their larger file sizes. For now, we provide only the empirical prior distributions required to make inference work.
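For example, a minimal way to unpack both archives from the repository root (assuming, as the `.gitignore` whitelisting of the `.gitkeep` files suggests, that the contents belong in `data/qm9/` and `data/geom/`):

```bash
cd data
unzip qm9.zip -d qm9
unzip geom.zip -d geom
```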
--------------------------------------------------------------------------------
/data/geom.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/data/geom.zip

--------------------------------------------------------------------------------
/data/geom/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/data/geom/.gitkeep

--------------------------------------------------------------------------------
/data/qm9.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/data/qm9.zip

--------------------------------------------------------------------------------
/data/qm9/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/data/qm9/.gitkeep

--------------------------------------------------------------------------------
/eqgat_diff/.gitignore:
--------------------------------------------------------------------------------
 1 | *.egg-info
 2 | .eggs
 3 | __pycache__
 4 | build
 5 | dist
 6 | .idea/*
 7 | 
 8 | *.DS_Store*
 9 | 
10 | # setuptools_scm
11 | e3moldiffusion/_version.py
12 | figs/
13 | e3moldiffusion/potential.py
14 | geom/train_potential.py
15 | geom/data/*
16 | !geom/data/.gitkeep
17 | geom/slurm_outs/*
18 | 
19 | e3moldiffusion/sample_data.py
20 | sample_data/*
21 | 
22 | .vscode/
23 | 
24 | configs/
25 | scripts/
26 | generate_ligands.sh
27 | # scripts/experiment*.sl
28 | 
29 | geom/.ipynb_checkpoints/
30 | geom/trajs/*
31 | geom/my_movie.gif
32 | geom/vis.ipynb
33 | 
34 | geom/logs
35 | geom/evaluation/
36 | geom/train_gan.py
37 | 
38 | # fullerene/
39 | fullerene/data/
40 | 
41 | notebooks/
42 | _old/
43 | 
44 | *xlsx

--------------------------------------------------------------------------------
/eqgat_diff/.ipynb_checkpoints/Untitled-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [],
3 |  "metadata": {},
4 |  "nbformat": 4,
5 |  "nbformat_minor": 5
6 | }
7 | 

--------------------------------------------------------------------------------
/eqgat_diff/README.md:
--------------------------------------------------------------------------------
 1 | # E(3) Equivariant Diffusion for Molecules
 2 | 
 3 | Research repository exploring the capabilities of (continuous and discrete) denoising diffusion probabilistic models applied to molecular data.
 4 | 
 5 | ## Installation
 6 | The environment is best installed using mamba:
 7 | ```bash
 8 | mamba env create -f environment.yml
 9 | ```
10 | 
11 | ## Experiments
12 | We will update the repository with the corresponding datasets needed to run the experiments.
13 | 
14 | Generally, the main training script is executed via `experiments/run_train.py`, while the configuration files for each dataset are stored in `configs/*.yaml`.
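For illustration, here is a purely hypothetical minimal config sketch — none of these values ship with the repository; the authoritative keys and defaults live in `experiments/hparams.py`, and the keys below are simply ones the data modules read from the config object (e.g. `cfg.batch_size`, `cfg.inference_batch_size`, `cfg.num_workers`, `cfg.remove_hs`):

```yaml
# Hypothetical sketch only -- consult experiments/hparams.py for the real
# hyperparameter names and defaults before writing a config.
batch_size: 32
inference_batch_size: 64
num_workers: 4
remove_hs: False
num_bond_classes: 5
num_charge_classes: 6
```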
15 | 
16 | An example training run can be executed with the following command:
17 | 
18 | ```bash
19 | mamba activate eqgatdiff
20 | export PYTHONPATH="YOUR_ABSOLUTE_PATH/eqgat-diff/eqgat_diff"
21 | python experiments/run_train.py --conf configs/my_config_file.yaml
22 | ```
23 | 
24 | Note that we currently do not provide the datasets; for now, the code is intended for reviewing how the model and the training runs are implemented.
25 | 
26 | For example, an EQGAT-diff model that leverages Gaussian diffusion on atomic coordinates but discrete diffusion for atom and bond types is implemented in `experiments/diffusion_discrete.py`.
27 | 
28 | The same model that leverages Gaussian diffusion for atomic coordinates, atom types, and bond types is implemented in `experiments/diffusion_continuous.py`.
29 | 
30 | All configurable hyperparameters are listed in `experiments/hparams.py`.
31 | 
32 | ## Inference and Weights
33 | 
34 | Currently, we are still in the process of publishing the code. Upon request, we provide model weights for the QM9 and Geom-Drugs models.
35 | Please look into the `inference/` and `weights/` subdirectories for more details.

--------------------------------------------------------------------------------
/eqgat_diff/callbacks/ema.py:
--------------------------------------------------------------------------------
 1 | from pytorch_lightning.callbacks import Callback
 2 | from pytorch_lightning import LightningModule, Trainer
 3 | 
 4 | from torch_ema import ExponentialMovingAverage as EMA
 5 | 
 6 | 
 7 | class ExponentialMovingAverage(Callback):
 8 |     """
 9 |     Callback for using an exponential moving average over model weights.
10 |     The most recent weights are only accessed during the training steps;
11 |     otherwise the smoothed weights are used.
12 |     """
13 | 
14 |     def __init__(self, decay: float, *args, **kwargs):
15 |         """
16 |         Args:
17 |             decay (float): decay of the exponential moving average
18 |         """
19 |         self.decay = decay
20 |         self.ema = None
21 |         self._to_load = None
22 | 
23 |     def on_fit_start(self, trainer, pl_module: LightningModule):
24 |         if self.ema is None:
25 |             self.ema = EMA(pl_module.parameters(), decay=self.decay)
26 |         if self._to_load is not None:
27 |             self.ema.load_state_dict(self._to_load)
28 |             self._to_load = None
29 | 
30 |         self.ema.store()
31 |         self.ema.copy_to()
32 | 
33 |     def on_train_epoch_start(
34 |         self, trainer: Trainer, pl_module: LightningModule
35 |     ) -> None:
36 |         self.ema.restore()
37 | 
38 |     def on_train_batch_end(self, trainer, pl_module: LightningModule, *args, **kwargs):
39 |         self.ema.update()
40 | 
41 |     def on_validation_epoch_start(
42 |         self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
43 |     ):
44 |         self.ema.store()
45 |         self.ema.copy_to()
46 | 
47 |     def load_state_dict(self, state_dict):
48 |         if "exponential_moving_average" in state_dict:
49 |             if self.ema is None:
50 |                 self._to_load = state_dict["exponential_moving_average"]
51 |             else:
52 |                 self.ema.load_state_dict(state_dict["exponential_moving_average"])
53 | 
54 |     def state_dict(self):
55 |         return {"exponential_moving_average": self.ema.state_dict()}

--------------------------------------------------------------------------------
/eqgat_diff/e3moldiffusion/README.md:
--------------------------------------------------------------------------------
1 | # e3moldiffusion
2 | This directory implements the package used to create the diffusion/score model.
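As a small usage sketch of the coordinate helpers in `chem.py` (assuming RDKit is available; the molecule and perturbation below are made up purely for illustration):

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

from e3moldiffusion.chem import mol_to_smiles, set_rdmol_positions

# build a small molecule with explicit hydrogens and one 3D conformer
mol = Chem.AddHs(Chem.MolFromSmiles("CCO"))
AllChem.EmbedMolecule(mol, randomSeed=0)

# overwrite the conformer coordinates, e.g. with model-denoised positions;
# set_rdmol_positions deep-copies the molecule before editing conformer 0
pos = np.asarray(mol.GetConformer(0).GetPositions()) + 0.05
mol = set_rdmol_positions(mol, pos)

print(mol_to_smiles(mol))  # SMILES with explicit hydrogens
```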
3 | -------------------------------------------------------------------------------- /eqgat_diff/e3moldiffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/e3moldiffusion/__init__.py -------------------------------------------------------------------------------- /eqgat_diff/e3moldiffusion/chem.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List, Tuple 3 | 4 | import torch 5 | from rdkit import Chem 6 | from rdkit.Chem import PeriodicTable as PT, rdDepictor as DP, rdMolAlign as MA 7 | from rdkit.Chem.Draw import rdMolDraw2D as MD2 8 | from rdkit.Chem.rdchem import GetPeriodicTable, Mol 9 | from rdkit.Chem.rdmolops import RemoveHs 10 | 11 | 12 | def set_conformer_positions(conf, pos): 13 | for i in range(pos.shape[0]): 14 | conf.SetAtomPosition(i, pos[i].tolist()) 15 | return conf 16 | 17 | 18 | def update_data_rdmol_positions(data): 19 | for i in range(data.pos.size(0)): 20 | data.rdmol.GetConformer(0).SetAtomPosition(i, data.pos[i].tolist()) 21 | return data 22 | 23 | 24 | def update_data_pos_from_rdmol(data): 25 | new_pos = torch.FloatTensor(data.rdmol.GetConformer(0).GetPositions()).to(data.pos) 26 | data.pos = new_pos 27 | return data 28 | 29 | 30 | def set_rdmol_positions(rdkit_mol, pos): 31 | """ 32 | Args: 33 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object. 34 | pos: (N_atoms, 3) 35 | """ 36 | mol = copy.deepcopy(rdkit_mol) 37 | mol = set_rdmol_positions_(mol, pos) 38 | return mol 39 | 40 | 41 | def set_rdmol_positions_(mol, pos): 42 | """ 43 | Args: 44 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object. 45 | pos: (N_atoms, 3) 46 | """ 47 | for i in range(pos.shape[0]): 48 | mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist()) 49 | return mol 50 | 51 | 52 | def get_atom_symbol(atomic_number): 53 | return PT.GetElementSymbol(GetPeriodicTable(), atomic_number) 54 | 55 | 56 | def mol_to_smiles(mol: Mol) -> str: 57 | return Chem.MolToSmiles(mol, allHsExplicit=True) 58 | 59 | 60 | def mol_to_smiles_without_Hs(mol: Mol) -> str: 61 | return Chem.MolToSmiles(Chem.RemoveHs(mol)) 62 | 63 | 64 | def remove_duplicate_mols(molecules: List[Mol]) -> List[Mol]: 65 | unique_tuples: List[Tuple[str, Mol]] = [] 66 | 67 | for molecule in molecules: 68 | duplicate = False 69 | smiles = mol_to_smiles(molecule) 70 | for unique_smiles, _ in unique_tuples: 71 | if smiles == unique_smiles: 72 | duplicate = True 73 | break 74 | 75 | if not duplicate: 76 | unique_tuples.append((smiles, molecule)) 77 | 78 | return [mol for smiles, mol in unique_tuples] 79 | 80 | 81 | def get_atoms_in_ring(mol): 82 | atoms = set() 83 | for ring in mol.GetRingInfo().AtomRings(): 84 | for a in ring: 85 | atoms.add(a) 86 | return atoms 87 | 88 | 89 | def get_2D_mol(mol): 90 | mol = copy.deepcopy(mol) 91 | DP.Compute2DCoords(mol) 92 | return mol 93 | 94 | 95 | def draw_mol_svg(mol, molSize=(450, 150), kekulize=False): 96 | mc = Chem.Mol(mol.ToBinary()) 97 | if kekulize: 98 | try: 99 | Chem.Kekulize(mc) 100 | except: 101 | mc = Chem.Mol(mol.ToBinary()) 102 | if not mc.GetNumConformers(): 103 | DP.Compute2DCoords(mc) 104 | drawer = MD2.MolDraw2DSVG(molSize[0], molSize[1]) 105 | drawer.DrawMolecule(mc) 106 | drawer.FinishDrawing() 107 | svg = drawer.GetDrawingText() 108 | # It seems that the svg renderer used doesn't quite hit the spec. 
109 | # Here are some fixes to make it work in the notebook, although I think 110 | # the underlying issue needs to be resolved at the generation step 111 | # return svg.replace('svg:','') 112 | return svg 113 | 114 | 115 | def GetBestRMSD(probe, ref): 116 | probe = RemoveHs(probe) 117 | ref = RemoveHs(ref) 118 | rmsd = MA.GetBestRMS(probe, ref) 119 | return rmsd 120 | -------------------------------------------------------------------------------- /eqgat_diff/e3moldiffusion/molfeat.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import rdkit 5 | from rdkit import Chem 6 | from rdkit.Chem import GetPeriodicTable 7 | from torch import nn 8 | from torch_geometric.data import Data 9 | from torch_geometric.data import Data, Batch 10 | 11 | 12 | PERIODIC_TABLE = GetPeriodicTable() 13 | 14 | # allowable multiple choice node and edge features 15 | allowable_features = { 16 | 'possible_atomic_num_list' : list(range(1, 119)) + ['misc'], 17 | 'possible_chirality_list' : [ 18 | 'CHI_UNSPECIFIED', 19 | 'CHI_TETRAHEDRAL_CW', 20 | 'CHI_TETRAHEDRAL_CCW', 21 | 'CHI_OTHER' 22 | ], 23 | 'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], 24 | 'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], 25 | 'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], 26 | 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], 27 | 'possible_hybridization_list' : [ 28 | 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc' 29 | ], 30 | 'possible_is_aromatic_list': [False, True], 31 | 'possible_is_in_ring_list': [False, True], 32 | 'possible_bond_type_list' : [ 33 | 'SINGLE', 34 | 'DOUBLE', 35 | 'TRIPLE', 36 | 'AROMATIC'#, 37 | # 'misc' 38 | ], 39 | 'possible_bond_stereo_list': [ 40 | 'STEREONONE', 41 | 'STEREOZ', 42 | 'STEREOE', 43 | 'STEREOCIS', 44 | 'STEREOTRANS', 45 | 'STEREOANY', 46 | ], 47 | 'possible_is_conjugated_list': [False, True], 48 | } 49 | 50 | def atom_type_config(dataset: str = "qm9"): 51 | if dataset == "qm9": 52 | mapping = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4} 53 | elif dataset == "drugs": 54 | mapping = { 55 | "H": 0, 56 | "B": 1, 57 | "C": 2, 58 | "N": 3, 59 | "O": 4, 60 | "F": 5, 61 | "Al": 6, 62 | "Si": 7, 63 | "P": 8, 64 | "S": 9, 65 | "Cl": 10, 66 | "As": 11, 67 | "Br": 12, 68 | "I": 13, 69 | "Hg": 14, 70 | "Bi": 15, 71 | } 72 | return mapping 73 | 74 | def safe_index(l, e): 75 | """ 76 | Return index of element e in list l. 
If e is not present, return the last index 77 | """ 78 | try: 79 | return l.index(e) 80 | except: 81 | return len(l) - 1 82 | 83 | 84 | def atom_to_feature_vector(atom): 85 | """ 86 | Converts rdkit atom object to feature list of indices 87 | :param mol: rdkit atom object 88 | :return: list 89 | """ 90 | atom_feature = [ 91 | safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()), 92 | safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()), 93 | safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())), 94 | allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()), 95 | allowable_features['possible_is_in_ring_list'].index(atom.IsInRing()) 96 | ] 97 | return atom_feature 98 | 99 | 100 | def get_atom_feature_dims(): 101 | return list(map(len, [ 102 | allowable_features['possible_atomic_num_list'], 103 | allowable_features['possible_degree_list'], 104 | allowable_features['possible_hybridization_list'], 105 | allowable_features['possible_is_aromatic_list'], 106 | allowable_features['possible_is_in_ring_list'] 107 | ])) 108 | 109 | def bond_to_feature_vector(bond): 110 | """ 111 | Converts rdkit bond object to feature list of indices 112 | :param mol: rdkit bond object 113 | :return: list 114 | """ 115 | bond_feature = [ 116 | safe_index(allowable_features['possible_bond_type_list'], str(bond.GetBondType())) 117 | ] 118 | return bond_feature 119 | 120 | 121 | def get_bond_feature_dims(): 122 | return list(map(len, [ 123 | allowable_features['possible_bond_type_list'], 124 | ])) 125 | 126 | 127 | def atom_feature_vector_to_dict(atom_feature): 128 | [atomic_num_idx, 129 | degree_idx, 130 | hybridization_idx, 131 | is_aromatic_idx, 132 | is_in_ring_idx] = atom_feature 133 | 134 | feature_dict = { 135 | 'atomic_num': allowable_features['possible_atomic_num_list'][atomic_num_idx], 136 | 'degree': allowable_features['possible_degree_list'][degree_idx], 137 | 'hybridization': allowable_features['possible_hybridization_list'][hybridization_idx], 138 | 'is_aromatic': allowable_features['possible_is_aromatic_list'][is_aromatic_idx], 139 | 'is_in_ring': allowable_features['possible_is_in_ring_list'][is_in_ring_idx] 140 | } 141 | return feature_dict 142 | 143 | 144 | def bond_feature_vector_to_dict(bond_feature): 145 | [bond_type_idx] = bond_feature 146 | 147 | feature_dict = { 148 | 'bond_type': allowable_features['possible_bond_type_list'][bond_type_idx] 149 | } 150 | 151 | return feature_dict 152 | 153 | 154 | def smiles_or_mol_to_graph(smol: Union[str, Chem.Mol], create_bond_graph: bool = True): 155 | if isinstance(smol, str): 156 | mol = Chem.MolFromSmiles(smol) 157 | else: 158 | mol = smol 159 | 160 | # atoms 161 | atom_features_list = [] 162 | atom_element_name_list = [] 163 | for atom in mol.GetAtoms(): 164 | atom_features_list.append(atom_to_feature_vector(atom)) 165 | atom_element_name_list.append(PERIODIC_TABLE.GetElementSymbol(atom.GetAtomicNum())) 166 | 167 | 168 | x = torch.tensor(atom_features_list, dtype=torch.int64) 169 | assert x.size(-1) == 5 170 | # only take atom element 171 | # x = x[:, 0].view(-1, 1) 172 | 173 | if create_bond_graph: 174 | # bonds 175 | edges_list = [] 176 | edge_features_list = [] 177 | for bond in mol.GetBonds(): 178 | i = bond.GetBeginAtomIdx() 179 | j = bond.GetEndAtomIdx() 180 | edge_feature = bond_to_feature_vector(bond) 181 | # add edges in both directions 182 | edges_list.append((i, j)) 183 | edge_features_list.append(edge_feature[0]) 184 | 
edges_list.append((j, i))
185 |             edge_features_list.append(edge_feature[0])
186 | 
187 |         # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
188 |         edge_index = torch.tensor(edges_list, dtype=torch.int64).T
189 |         # data.edge_attr: Edge feature matrix with shape [num_edges]
190 |         edge_attr = torch.tensor(edge_features_list, dtype=torch.int64)
191 | 
192 |         if edge_index.numel() > 0:  # Sort indices.
193 |             perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
194 |             edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]
195 |     else:
196 |         edge_index = edge_attr = None
197 | 
198 |     data = Data(x=x, atom_elements=atom_element_name_list, edge_index=edge_index, edge_attr=edge_attr)
199 |     return data
200 | 
201 | 
202 | class AtomEncoder(nn.Module):
203 |     def __init__(self, emb_dim, max_norm: float = 10.0,
204 |                  use_all_atom_features: bool = False):
205 |         super(AtomEncoder, self).__init__()
206 |         # before: richer input featurization that also contains information about the topology of the graph, like degree etc.
207 |         FULL_ATOM_FEATURE_DIMS = get_atom_feature_dims()
208 |         if not use_all_atom_features:
209 |             # now: only atom type
210 |             FULL_ATOM_FEATURE_DIMS = [FULL_ATOM_FEATURE_DIMS[0]]
211 |         self.atom_embedding_list = nn.ModuleList()
212 |         for dim in FULL_ATOM_FEATURE_DIMS:
213 |             emb = nn.Embedding(dim, emb_dim, max_norm=max_norm)
214 |             self.atom_embedding_list.append(emb)
215 |         self.reset_parameters()
216 | 
217 |     def reset_parameters(self):
218 |         for emb in self.atom_embedding_list:
219 |             nn.init.xavier_uniform_(emb.weight.data)
220 | 
221 |     def forward(self, x):
222 |         x_embedding = 0
223 |         for i in range(len(self.atom_embedding_list)):
224 |             x_embedding += self.atom_embedding_list[i](x[:, i])
225 |         return x_embedding
226 | 
227 | class BondEncoderOHE(nn.Module):
228 |     def __init__(self, emb_dim, max_num_classes: int):
229 |         super(BondEncoderOHE, self).__init__()
230 |         self.linear = nn.Linear(max_num_classes, emb_dim)
231 |         self.reset_parameters()
232 |     def reset_parameters(self):
233 |         self.linear.reset_parameters()
234 |     def forward(self, edge_attr):
235 |         bond_embedding = self.linear(edge_attr)
236 |         return bond_embedding
237 | 
238 | 
239 | class BondEncoder(nn.Module):
240 |     def __init__(self, emb_dim, max_norm: float = 10.0):
241 |         super(BondEncoder, self).__init__()
242 |         FULL_BOND_FEATURE_DIMS = get_bond_feature_dims()
243 |         self.bond_embedding = nn.Embedding(FULL_BOND_FEATURE_DIMS[0] + 3, emb_dim, max_norm=max_norm)
244 |         self.reset_parameters()
245 |     def reset_parameters(self):
246 |         nn.init.xavier_uniform_(self.bond_embedding.weight.data)
247 |     def forward(self, edge_attr):
248 |         bond_embedding = self.bond_embedding(edge_attr)
249 |         return bond_embedding
250 | 
251 | 
252 | if __name__ == '__main__':
253 |     FULL_BOND_FEATURE_DIMS = get_bond_feature_dims()
254 |     bond_attrs = FULL_BOND_FEATURE_DIMS[0] + 3
255 |     from torch_sparse import coalesce
256 |     from torch_cluster import radius_graph
257 |     smol = "O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5"
258 |     data = smiles_or_mol_to_graph(smol)
259 |     print(data)
260 |     print(data.x)
261 |     print(data.atom_elements)
262 | 
263 |     smol1 = "CC(=O)Oc1ccccc1C(=O)O"
264 |     smol2 = "CCCc1nn(C)c2c(=O)[nH]c(-c3cc(S(=O)(=O)N4CCN(C)CC4)ccc3OCC)nc12"
265 | 
266 |     datalist = [smiles_or_mol_to_graph(s) for s in [smol, smol1, smol2]]
267 |     data = Batch.from_data_list(datalist)
268 |     bond_edge_index, bond_edge_attr, batch = data.edge_index, data.edge_attr, data.batch
269 | 
270 | 
271 |     atomencoder = AtomEncoder(emb_dim=16)
272 |     bondencoder = 
BondEncoder(emb_dim=16)
273 | 
274 |     x = atomencoder(data.x)
275 |     edge_attr = bondencoder(data.edge_attr)
276 | 
277 |     pos = torch.randn(data.x.size(0), 3)
278 |     bond_edge_index, bond_edge_attr = data.edge_index, data.edge_attr
279 |     radius_edge_index = radius_graph(pos, r=5.0, max_num_neighbors=64, flow="source_to_target")
280 |     # from radius_edge_index, remove edges that are already present in bond_edge_index
281 | 
282 |     radius_feat = FULL_BOND_FEATURE_DIMS[0] + 1
283 |     radius_edge_attr = torch.full((radius_edge_index.size(1), ), fill_value=radius_feat, device=pos.device, dtype=torch.long)
284 |     # need to combine radius-edge-index with graph-edge-index
285 | 
286 |     nbonds = bond_edge_index.size(1)
287 |     nradius = radius_edge_index.size(1)
288 | 
289 |     combined_edge_index = torch.cat([bond_edge_index, radius_edge_index], dim=-1)
290 |     combined_edge_attr = torch.cat([bond_edge_attr, radius_edge_attr], dim=0)
291 | 
292 |     nbefore = combined_edge_index.size(1)
293 |     # coalesce duplicate edges, keeping the smaller (i.e. bond) feature via op="min"
294 |     combined_edge_index, combined_edge_attr = coalesce(index=combined_edge_index, value=combined_edge_attr, m=pos.size(0), n=pos.size(0), op="min")
295 |     print(combined_edge_index[:, :30])
296 |     print()
297 |     print(combined_edge_attr[:30])

--------------------------------------------------------------------------------
/eqgat_diff/environment.yml:
--------------------------------------------------------------------------------
 1 | name: eqgatdiff
 2 | channels:
 3 |   - anaconda
 4 |   - conda-forge
 5 |   - defaults
 6 |   - pytorch
 7 |   - pyg
 8 |   - nvidia
 9 | dependencies:
10 |   - python=3.10.0
11 |   - pip
12 |   - notebook
13 |   - matplotlib
14 |   - pytorch::pytorch=2.0.1
15 |   - pyg::pyg=2.3.0
16 |   - pyg::pytorch-scatter
17 |   - pyg::pytorch-sparse
18 |   - pyg::pytorch-cluster
19 |   - conda-forge::rdkit=2023.03.2
20 |   - conda-forge::pytorch-lightning=2.0.6
21 |   - conda-forge::jupyter
22 |   - conda-forge::openbabel
23 |   - xtb
24 |   - pip:
25 |     - xtb
26 |     - ase
27 |     - py3Dmol
28 |     - lmdb
29 |     - nglview
30 |     - rmsd
31 |     - torch_ema
32 |     - tensorboard
33 |     - biopython

--------------------------------------------------------------------------------
/eqgat_diff/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/__init__.py

--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/data/__init__.py

--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/abstract_dataset.py:
--------------------------------------------------------------------------------
 1 | from typing import Optional
 2 | 
 3 | import torch
 4 | from torch.utils.data import Subset
 5 | 
 6 | from experiments.data.adaptive_loader import AdaptiveLightningDataset
 7 | 
 8 | try:
 9 |     from torch_geometric.data import LightningDataset
10 | except ImportError:
11 |     from torch_geometric.data.lightning import LightningDataset
12 | from experiments.data.distributions import DistributionNodes
13 | 
14 | 
15 | def maybe_subset(
16 |     ds, random_subset: Optional[float] = None, split=None
17 | ) -> torch.utils.data.Dataset:
18 |     if random_subset is None or split in {"test", "val"}:
19 |         return ds
20 |     else:
21 |         idx = 
torch.randperm(len(ds))[: int(random_subset * len(ds))] 22 | return Subset(ds, idx) 23 | 24 | 25 | class Mixin: 26 | def __getitem__(self, idx): 27 | return self.dataloaders["train"][idx] 28 | 29 | def node_counts(self, max_nodes_possible=300): 30 | all_counts = torch.zeros(max_nodes_possible) 31 | for split in ["train", "val", "test"]: 32 | for i, batch in enumerate(self.dataloaders[split]): 33 | for data in batch: 34 | if data is None: 35 | continue 36 | unique, counts = torch.unique(data.batch, return_counts=True) 37 | for count in counts: 38 | all_counts[count] += 1 39 | max_index = max(all_counts.nonzero()) 40 | all_counts = all_counts[: max_index + 1] 41 | all_counts = all_counts / all_counts.sum() 42 | return all_counts 43 | 44 | def node_types(self): 45 | num_classes = None 46 | for batch in self.dataloaders["train"]: 47 | for data in batch: 48 | num_classes = data.x.shape[1] 49 | break 50 | break 51 | 52 | counts = torch.zeros(num_classes) 53 | 54 | for split in ["train", "val", "test"]: 55 | for i, batch in enumerate(self.dataloaders[split]): 56 | for data in batch: 57 | if data is None: 58 | continue 59 | counts += data.x.sum(dim=0) 60 | 61 | counts = counts / counts.sum() 62 | return counts 63 | 64 | def edge_counts(self): 65 | num_classes = 5 66 | 67 | d = torch.zeros(num_classes) 68 | 69 | for split in ["train", "val", "test"]: 70 | for i, batch in enumerate(self.dataloaders[split]): 71 | for data in batch: 72 | if data is None: 73 | continue 74 | unique, counts = torch.unique(data.batch, return_counts=True) 75 | 76 | all_pairs = 0 77 | for count in counts: 78 | all_pairs += count * (count - 1) 79 | 80 | num_edges = data.edge_index.shape[1] 81 | num_non_edges = all_pairs - num_edges 82 | edge_types = data.edge_attr.sum(dim=0) 83 | assert num_non_edges >= 0 84 | d[0] += num_non_edges 85 | d[1:] += edge_types[1:] 86 | 87 | d = d / d.sum() 88 | return d 89 | 90 | def valency_count(self, max_n_nodes): 91 | valencies = torch.zeros( 92 | 3 * max_n_nodes - 2 93 | ) # Max valency possible if everything is connected 94 | 95 | multiplier = torch.Tensor([0, 1, 2, 3, 1.5]) 96 | 97 | for split in ["train", "val", "test"]: 98 | for i, batch in enumerate(self.dataloaders[split]): 99 | for data in batch: 100 | if data is None: 101 | continue 102 | 103 | n = data.x.shape[0] 104 | 105 | for atom in range(n): 106 | edges = data.edge_attr[data.edge_index[0] == atom] 107 | edges_total = edges.sum(dim=0) 108 | valency = (edges_total * multiplier).sum() 109 | valencies[valency.long().item()] += 1 110 | valencies = valencies / valencies.sum() 111 | return valencies 112 | 113 | 114 | class AbstractDataModule(Mixin, LightningDataset): 115 | def __init__(self, cfg, train_dataset, val_dataset, test_dataset): 116 | super().__init__( 117 | train_dataset, 118 | val_dataset, 119 | test_dataset, 120 | batch_size=cfg.batch_size, 121 | num_workers=cfg.num_workers, 122 | shuffle=True, 123 | pin_memory=getattr(cfg.dataset, "pin_memory", False), 124 | ) 125 | self.cfg = cfg 126 | 127 | 128 | class AbstractDataModuleLigand(Mixin, LightningDataset): 129 | def __init__(self, cfg, train_dataset, val_dataset, test_dataset): 130 | super().__init__( 131 | train_dataset, 132 | val_dataset, 133 | test_dataset, 134 | batch_size=cfg.batch_size, 135 | follow_batch=["pos", "pos_pocket"], 136 | num_workers=cfg.num_workers, 137 | shuffle=True, 138 | pin_memory=getattr(cfg.dataset, "pin_memory", False), 139 | ) 140 | self.cfg = cfg 141 | 142 | 143 | class AbstractAdaptiveDataModule(Mixin, AdaptiveLightningDataset): 144 | def 
__init__(self, cfg, train_dataset, val_dataset, test_dataset): 145 | super().__init__( 146 | train_dataset, 147 | val_dataset, 148 | test_dataset, 149 | batch_size=cfg.batch_size, 150 | reference_batch_size=cfg.inference_batch_size, 151 | num_workers=cfg.num_workers, 152 | shuffle=True, 153 | pin_memory=getattr(cfg.dataset, "pin_memory", False), 154 | ) 155 | self.cfg = cfg 156 | 157 | 158 | class AbstractAdaptiveDataModuleLigand(Mixin, AdaptiveLightningDataset): 159 | def __init__(self, cfg, train_dataset, val_dataset, test_dataset): 160 | super().__init__( 161 | train_dataset, 162 | val_dataset, 163 | test_dataset, 164 | batch_size=cfg.batch_size, 165 | follow_batch=["pos", "pos_pocket"], 166 | reference_batch_size=cfg.inference_batch_size, 167 | num_workers=cfg.num_workers, 168 | shuffle=True, 169 | pin_memory=getattr(cfg.dataset, "pin_memory", False), 170 | ) 171 | self.cfg = cfg 172 | 173 | 174 | class AbstractDatasetInfos: 175 | def complete_infos(self, statistics, atom_encoder): 176 | self.atom_decoder = [key for key in atom_encoder.keys()] 177 | self.num_atom_types = len(self.atom_decoder) 178 | 179 | # Train + val + test for n_nodes 180 | train_n_nodes = statistics["train"].num_nodes 181 | val_n_nodes = statistics["val"].num_nodes 182 | test_n_nodes = statistics["test"].num_nodes 183 | max_n_nodes = max( 184 | max(train_n_nodes.keys()), max(val_n_nodes.keys()), max(test_n_nodes.keys()) 185 | ) 186 | n_nodes = torch.zeros(max_n_nodes + 1, dtype=torch.long) 187 | for c in [train_n_nodes, val_n_nodes, test_n_nodes]: 188 | for key, value in c.items(): 189 | n_nodes[key] += value 190 | 191 | self.n_nodes = n_nodes / n_nodes.sum() 192 | self.atom_types = statistics["train"].atom_types 193 | self.edge_types = statistics["train"].bond_types 194 | self.charges_types = statistics["train"].charge_types 195 | self.charges_marginals = (self.charges_types * self.atom_types[:, None]).sum( 196 | dim=0 197 | ) 198 | self.valency_distribution = statistics["train"].valencies 199 | self.max_n_nodes = len(n_nodes) - 1 200 | self.nodes_dist = DistributionNodes(n_nodes) 201 | 202 | if hasattr(statistics["train"], "is_aromatic"): 203 | self.is_aromatic = statistics["train"].is_aromatic 204 | if hasattr(statistics["train"], "is_in_ring"): 205 | self.is_in_ring = statistics["train"].is_in_ring 206 | if hasattr(statistics["train"], "hybridization"): 207 | self.hybridization = statistics["train"].hybridization 208 | if hasattr(statistics["train"], "numHs"): 209 | self.numHs = statistics["train"].numHs 210 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/adaptive_loader.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Mapping, Sequence 3 | from typing import List, Optional, Union 4 | 5 | import torch 6 | import torch.utils.data 7 | from torch.utils.data.dataloader import default_collate 8 | from torch_geometric.data import Batch, Dataset 9 | from torch_geometric.data.data import BaseData 10 | 11 | try: 12 | from torch_geometric.data import LightningDataset 13 | except ImportError: 14 | from torch_geometric.data.lightning import LightningDataset 15 | 16 | 17 | def effective_batch_size( 18 | max_size, reference_batch_size, reference_size=20, sampling=False 19 | ): 20 | x = reference_batch_size * (reference_size / max_size) ** 2 21 | return math.floor(1.8 * x) if sampling else math.floor(x) 22 | 23 | 24 | class AdaptiveCollater: 25 | def __init__(self, follow_batch, 
exclude_keys, reference_batch_size): 26 | """Copypaste from pyg.loader.Collater + small changes""" 27 | self.follow_batch = follow_batch 28 | self.exclude_keys = exclude_keys 29 | self.reference_bs = reference_batch_size 30 | 31 | def __call__(self, batch): 32 | # checks the number of node for basedata graphs and slots into appropriate buckets, 33 | # errors on other options 34 | elem = batch[0] 35 | if isinstance(elem, BaseData): 36 | to_keep = [] 37 | graph_sizes = [] 38 | 39 | for e in batch: 40 | e: BaseData 41 | graph_sizes.append(e.num_nodes) 42 | 43 | m = len(graph_sizes) 44 | graph_sizes = torch.Tensor(graph_sizes) 45 | srted, argsort = torch.sort(graph_sizes) 46 | random = torch.randint(0, m, size=(1, 1)).item() 47 | max_size = min(srted.max().item(), srted[random].item() + 5) 48 | max_size = max( 49 | max_size, 9 50 | ) # The batch sizes may be huge if the graphs happen to be tiny 51 | 52 | ebs = effective_batch_size(max_size, self.reference_bs) 53 | 54 | max_index = torch.nonzero(srted <= max_size).max().item() 55 | min_index = max(0, max_index - ebs) 56 | indices_to_keep = set(argsort[min_index : max_index + 1].tolist()) 57 | if max_index < ebs: 58 | for index in range(max_index + 1, m): 59 | # Check if we could add the graph to the list 60 | size = srted[index].item() 61 | potential_ebs = effective_batch_size(size, self.reference_bs) 62 | if len(indices_to_keep) < potential_ebs: 63 | indices_to_keep.add(argsort[index].item()) 64 | 65 | for i, e in enumerate(batch): 66 | e: BaseData 67 | if i in indices_to_keep: 68 | to_keep.append(e) 69 | 70 | new_batch = Batch.from_data_list( 71 | to_keep, self.follow_batch, self.exclude_keys 72 | ) 73 | return new_batch 74 | 75 | elif True: 76 | # early exit 77 | raise NotImplementedError("Only supporting BaseData for now") 78 | elif isinstance(elem, torch.Tensor): 79 | return default_collate(batch) 80 | elif isinstance(elem, float): 81 | return torch.tensor(batch, dtype=torch.float) 82 | elif isinstance(elem, int): 83 | return torch.tensor(batch) 84 | elif isinstance(elem, str): 85 | return batch 86 | elif isinstance(elem, Mapping): 87 | return {key: self([data[key] for data in batch]) for key in elem} 88 | elif isinstance(elem, tuple) and hasattr(elem, "_fields"): 89 | return type(elem)(*(self(s) for s in zip(*batch))) 90 | elif isinstance(elem, Sequence) and not isinstance(elem, str): 91 | return [self(s) for s in zip(*batch)] 92 | 93 | raise TypeError(f"DataLoader found invalid type: {type(elem)}") 94 | 95 | def collate(self, batch): # Deprecated... 96 | return self(batch) 97 | 98 | 99 | class AdaptiveDataLoader(torch.utils.data.DataLoader): 100 | r"""A data loader which merges data objects from a 101 | :class:`torch_geometric.data.Dataset` into mini-batches, each minibatch being a bucket with num_nodes < some threshold, 102 | except the last which holds the overflow-graphs. Apart from the bucketing, identical to torch_geometric.loader.DataLoader 103 | Default bucket_thresholds is [30,50,90], yielding 4 buckets 104 | Data objects can be either of type :class:`~torch_geometric.data.Data` or 105 | :class:`~torch_geometric.data.HeteroData`. 106 | 107 | Args: 108 | dataset (Dataset): The dataset from which to load the data. 109 | batch_size (int, optional): How many samples per batch to load. 110 | (default: :obj:`1`) 111 | shuffle (bool, optional): If set to :obj:`True`, the data will be 112 | reshuffled at every epoch. 
(default: :obj:`False`) 113 | follow_batch (List[str], optional): Creates assignment batch 114 | vectors for each key in the list. (default: :obj:`None`) 115 | exclude_keys (List[str], optional): Will exclude each key in the 116 | list. (default: :obj:`None`) 117 | **kwargs (optional): Additional arguments of 118 | :class:`torch.utils.data.DataLoader`. 119 | """ 120 | 121 | def __init__( 122 | self, 123 | dataset: Union[Dataset, List[BaseData]], 124 | batch_size: int = 1, 125 | reference_batch_size: int = 1, 126 | shuffle: bool = False, 127 | follow_batch: Optional[List[str]] = None, 128 | exclude_keys: Optional[List[str]] = None, 129 | **kwargs, 130 | ): 131 | if "collate_fn" in kwargs: 132 | del kwargs["collate_fn"] 133 | 134 | # Save for PyTorch Lightning: 135 | self.follow_batch = follow_batch 136 | self.exclude_keys = exclude_keys 137 | 138 | super().__init__( 139 | dataset, 140 | batch_size, 141 | shuffle, 142 | collate_fn=AdaptiveCollater( 143 | follow_batch, exclude_keys, reference_batch_size=reference_batch_size 144 | ), 145 | **kwargs, 146 | ) 147 | 148 | 149 | class AdaptiveLightningDataset(LightningDataset): 150 | r"""Converts a set of :class:`~torch_geometric.data.Dataset` objects into a 151 | :class:`pytorch_lightning.LightningDataModule` variant, which can be 152 | automatically used as a :obj:`datamodule` for multi-GPU graph-level 153 | training via `PyTorch Lightning `__. 154 | :class:`LightningDataset` will take care of providing mini-batches via 155 | :class:`~torch_geometric.loader.DataLoader`. 156 | 157 | .. note:: 158 | 159 | Currently only the 160 | :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and 161 | :class:`pytorch_lightning.strategies.DDPSpawnStrategy` training 162 | strategies of `PyTorch Lightning 163 | `__ are supported in order to correctly share data across 165 | all devices/processes: 166 | 167 | .. code-block:: 168 | 169 | import pytorch_lightning as pl 170 | trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", 171 | devices=4) 172 | trainer.fit(model, datamodule) 173 | 174 | Args: 175 | train_dataset (Dataset): The training dataset. 176 | val_dataset (Dataset, optional): The validation dataset. 177 | (default: :obj:`None`) 178 | test_dataset (Dataset, optional): The test dataset. 179 | (default: :obj:`None`) 180 | batch_size (int, optional): How many samples per batch to load. 181 | (default: :obj:`1`) 182 | num_workers: How many subprocesses to use for data loading. 183 | :obj:`0` means that the data will be loaded in the main process. 184 | (default: :obj:`0`) 185 | **kwargs (optional): Additional arguments of 186 | :class:`torch_geometric.loader.DataLoader`. 
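        reference_batch_size (int, optional): Reference batch size used by
            :class:`AdaptiveCollater` to scale the effective batch size with
            the largest graph in a batch: a batch whose largest graph has
            ``n`` nodes holds roughly ``reference_batch_size * (20 / n) ** 2``
            graphs, e.g. a reference batch size of 256 yields about 64 graphs
            of size 40. (default: :obj:`1`)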
187 | """ 188 | 189 | def __init__( 190 | self, 191 | train_dataset: Dataset, 192 | val_dataset: Optional[Dataset] = None, 193 | test_dataset: Optional[Dataset] = None, 194 | batch_size: int = 1, 195 | reference_batch_size: int = 1, 196 | num_workers: int = 0, 197 | **kwargs, 198 | ): 199 | self.reference_batch_size = reference_batch_size 200 | super().__init__( 201 | train_dataset=train_dataset, 202 | val_dataset=val_dataset, 203 | test_dataset=test_dataset, 204 | # has_val=val_dataset is not None, 205 | # has_test=test_dataset is not None, 206 | batch_size=batch_size, 207 | num_workers=num_workers, 208 | **kwargs, 209 | ) 210 | 211 | self.train_dataset = train_dataset 212 | self.val_dataset = val_dataset 213 | self.test_dataset = test_dataset 214 | 215 | def dataloader( 216 | self, dataset: Dataset, shuffle: bool = False, **kwargs 217 | ) -> AdaptiveDataLoader: 218 | return AdaptiveDataLoader( 219 | dataset, 220 | reference_batch_size=self.reference_batch_size, 221 | shuffle=shuffle, 222 | **self.kwargs, 223 | ) 224 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/calculate_dipoles.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Subset 7 | from torch_geometric.data.collate import collate 8 | from tqdm import tqdm 9 | 10 | from experiments.xtb_energy import calculate_dipole 11 | 12 | 13 | def get_args(): 14 | # fmt: off 15 | parser = argparse.ArgumentParser(description='Energy calculation') 16 | parser.add_argument('--dataset', type=str, help='Which dataset') 17 | parser.add_argument('--split', type=str, help='Which data split train/val/test') 18 | parser.add_argument('--idx', type=int, default=None, help='Which part of the dataset (pubchem only)') 19 | 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | atom_encoder = { 25 | "H": 0, 26 | "B": 1, 27 | "C": 2, 28 | "N": 3, 29 | "O": 4, 30 | "F": 5, 31 | "Al": 6, 32 | "Si": 7, 33 | "P": 8, 34 | "S": 9, 35 | "Cl": 10, 36 | "As": 11, 37 | "Br": 12, 38 | "I": 13, 39 | "Hg": 14, 40 | "Bi": 15, 41 | } 42 | atom_decoder = {v: k for k, v in atom_encoder.items()} 43 | atom_reference = { 44 | "H": -0.393482763936, 45 | "B": -0.952436614164, 46 | "C": -1.795110518041, 47 | "N": -2.60945245463, 48 | "O": -3.769421097051, 49 | "F": -4.619339964238, 50 | "Al": -0.905328611479, 51 | "Si": -1.571424085131, 52 | "P": -2.377807088084, 53 | "S": -3.148271017078, 54 | "Cl": -4.482525134961, 55 | "As": -2.239425948594, 56 | "Br": -4.048339371234, 57 | "I": -3.77963026339, 58 | "Hg": -0.848032246708, 59 | "Bi": -2.26665341636, 60 | } 61 | 62 | 63 | def process(dataset, split, idx): 64 | if dataset == "drugs": 65 | from experiments.data.geom.geom_dataset_adaptive import ( 66 | GeomDrugsDataset as DataModule, 67 | ) 68 | 69 | root_path = "----" 70 | elif dataset == "qm9": 71 | from experiments.data.qm9.qm9_dataset import QM9Dataset as DataModule 72 | 73 | root_path = "----" 74 | elif dataset == "aqm": 75 | from experiments.data.aqm.aqm_dataset_nonadaptive import ( 76 | AQMDataset as DataModule, 77 | ) 78 | 79 | root_path = "----" 80 | elif dataset == "pubchem": 81 | from experiments.data.pubchem.pubchem_dataset_nonadaptive import ( 82 | PubChemLMDBDataset as DataModule, 83 | ) 84 | 85 | root_path = "----" 86 | else: 87 | raise ValueError("Dataset not found") 88 | 89 | remove_hs = False 90 | 91 | datamodule = DataModule(split=split, 
root=root_path, remove_h=remove_hs) 92 | 93 | if dataset == "pubchem": 94 | split_len = len(datamodule) // 500 95 | rng = np.arange(0, len(datamodule)) 96 | rng = rng[idx * split_len : (idx + 1) * split_len] 97 | datamodule = Subset(datamodule, rng) 98 | 99 | # elif dataset == "drugs": 100 | # split_len = len(datamodule) // 50 101 | # rng = np.arange(0, len(datamodule)) 102 | # rng = rng[idx * split_len : (idx + 1) * split_len] 103 | # datamodule = Subset(datamodule, rng) 104 | 105 | mols = [] 106 | for i, mol in tqdm(enumerate(datamodule), total=len(datamodule)): 107 | atom_types = [atom_decoder[int(a)] for a in mol.x] 108 | try: 109 | d = calculate_dipole(mol.pos, atom_types) 110 | mol.dipole_classic = torch.tensor(d, dtype=torch.float32).unsqueeze(0) 111 | mols.append(mol) 112 | except: 113 | print(f"Molecule with id {i} failed...") 114 | continue 115 | 116 | print(f"Collate the data...") 117 | data, slices = _collate(mols) 118 | 119 | print(f"Saving the data...") 120 | torch.save( 121 | (data, slices), 122 | (os.path.join(root_path, f"processed/{split}_{idx}_data_energy.pt")), 123 | ) 124 | 125 | 126 | def _collate(data_list): 127 | r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects 128 | to the internal storage format of 129 | :class:`~torch_geometric.data.InMemoryDataset`.""" 130 | if len(data_list) == 1: 131 | return data_list[0], None 132 | 133 | data, slices, _ = collate( 134 | data_list[0].__class__, 135 | data_list=data_list, 136 | increment=False, 137 | add_batch=False, 138 | ) 139 | 140 | return data, slices 141 | 142 | 143 | if __name__ == "__main__": 144 | args = get_args() 145 | process(args.dataset, args.split, args.idx) 146 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/calculate_energies.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Subset 7 | from torch_geometric.data.collate import collate 8 | from tqdm import tqdm 9 | 10 | from experiments.xtb_energy import calculate_xtb_energy 11 | 12 | 13 | def get_args(): 14 | # fmt: off 15 | parser = argparse.ArgumentParser(description='Energy calculation') 16 | parser.add_argument('--dataset', type=str, help='Which dataset') 17 | parser.add_argument('--split', type=str, help='Which data split train/val/test') 18 | parser.add_argument('--idx', type=int, default=None, help='Which part of the dataset (pubchem only)') 19 | 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | atom_encoder = { 25 | "H": 0, 26 | "B": 1, 27 | "C": 2, 28 | "N": 3, 29 | "O": 4, 30 | "F": 5, 31 | "Al": 6, 32 | "Si": 7, 33 | "P": 8, 34 | "S": 9, 35 | "Cl": 10, 36 | "As": 11, 37 | "Br": 12, 38 | "I": 13, 39 | "Hg": 14, 40 | "Bi": 15, 41 | } 42 | atom_decoder = {v: k for k, v in atom_encoder.items()} 43 | atom_reference = { 44 | "H": -0.393482763936, 45 | "B": -0.952436614164, 46 | "C": -1.795110518041, 47 | "N": -2.60945245463, 48 | "O": -3.769421097051, 49 | "F": -4.619339964238, 50 | "Al": -0.905328611479, 51 | "Si": -1.571424085131, 52 | "P": -2.377807088084, 53 | "S": -3.148271017078, 54 | "Cl": -4.482525134961, 55 | "As": -2.239425948594, 56 | "Br": -4.048339371234, 57 | "I": -3.77963026339, 58 | "Hg": -0.848032246708, 59 | "Bi": -2.26665341636, 60 | } 61 | 62 | 63 | def process(dataset, split, idx): 64 | if dataset == "drugs": 65 | from experiments.data.geom.geom_dataset_adaptive import ( 66 | GeomDrugsDataset 
as DataModule, 67 | ) 68 | 69 | root_path = "----" 70 | elif dataset == "qm9": 71 | from experiments.data.qm9.qm9_dataset import QM9Dataset as DataModule 72 | 73 | root_path = "----" 74 | elif dataset == "aqm": 75 | from experiments.data.aqm.aqm_dataset_nonadaptive import ( 76 | AQMDataset as DataModule, 77 | ) 78 | 79 | root_path = "----" 80 | elif dataset == "pubchem": 81 | from experiments.data.pubchem.pubchem_dataset_nonadaptive import ( 82 | PubChemLMDBDataset as DataModule, 83 | ) 84 | 85 | root_path = "----" 86 | else: 87 | raise ValueError("Dataset not found") 88 | 89 | remove_hs = False 90 | 91 | datamodule = DataModule(split=split, root=root_path, remove_h=remove_hs) 92 | 93 | if dataset == "pubchem": 94 | split_len = len(datamodule) // 500 95 | rng = np.arange(0, len(datamodule)) 96 | rng = rng[idx * split_len : (idx + 1) * split_len] 97 | datamodule = Subset(datamodule, rng) 98 | 99 | # elif dataset == "drugs": 100 | # split_len = len(datamodule) // 50 101 | # rng = np.arange(0, len(datamodule)) 102 | # rng = rng[idx * split_len : (idx + 1) * split_len] 103 | # datamodule = Subset(datamodule, rng) 104 | 105 | mols = [] 106 | for i, mol in tqdm(enumerate(datamodule), total=len(datamodule)): 107 | atom_types = [atom_decoder[int(a)] for a in mol.x] 108 | try: 109 | e_ref = np.sum( 110 | [atom_reference[a] for a in atom_types] 111 | ) # * 27.2114 #Hartree to eV 112 | e, _ = calculate_xtb_energy(mol.pos, atom_types) 113 | e *= 0.0367493 # eV to Hartree 114 | mol.energy = torch.tensor(e - e_ref, dtype=torch.float32).unsqueeze(0) 115 | mols.append(mol) 116 | except: 117 | print(f"Molecule with id {i} failed...") 118 | continue 119 | 120 | print(f"Collate the data...") 121 | data, slices = _collate(mols) 122 | 123 | print(f"Saving the data...") 124 | torch.save( 125 | (data, slices), 126 | (os.path.join(root_path, f"processed/{split}_{idx}_data_energy.pt")), 127 | ) 128 | 129 | 130 | def _collate(data_list): 131 | r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects 132 | to the internal storage format of 133 | :class:`~torch_geometric.data.InMemoryDataset`.""" 134 | if len(data_list) == 1: 135 | return data_list[0], None 136 | 137 | data, slices, _ = collate( 138 | data_list[0].__class__, 139 | data_list=data_list, 140 | increment=False, 141 | add_batch=False, 142 | ) 143 | 144 | return data, slices 145 | 146 | 147 | if __name__ == "__main__": 148 | args = get_args() 149 | process(args.dataset, args.split, args.idx) 150 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/data_info.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from experiments.data.abstract_dataset import ( 5 | AbstractDatasetInfos, 6 | ) 7 | from experiments.molecule_utils import PlaceHolder 8 | 9 | full_atom_encoder = { 10 | "H": 0, 11 | "B": 1, 12 | "C": 2, 13 | "N": 3, 14 | "O": 4, 15 | "F": 5, 16 | "Al": 6, 17 | "Si": 7, 18 | "P": 8, 19 | "S": 9, 20 | "Cl": 10, 21 | "As": 11, 22 | "Br": 12, 23 | "I": 13, 24 | "Hg": 14, 25 | "Bi": 15, 26 | } 27 | 28 | 29 | class GeneralInfos(AbstractDatasetInfos): 30 | def __init__(self, datamodule, cfg): 31 | self.remove_h = cfg.remove_hs 32 | self.need_to_strip = ( 33 | False # to indicate whether we need to ignore one output from the model 34 | ) 35 | self.statistics = datamodule.statistics 36 | self.name = "drugs" 37 | self.atom_encoder = full_atom_encoder 38 | self.num_bond_classes = 
cfg.num_bond_classes 39 | self.num_charge_classes = cfg.num_charge_classes 40 | self.charge_offset = 2 41 | self.collapse_charges = torch.Tensor([-2, -1, 0, 1, 2, 3]).int() 42 | # if self.remove_h: 43 | # self.atom_encoder = { 44 | # k: v - 1 for k, v in self.atom_encoder.items() if k != "H" 45 | # } 46 | 47 | super().complete_infos(datamodule.statistics, self.atom_encoder) 48 | 49 | self.input_dims = PlaceHolder( 50 | X=self.num_atom_types, 51 | C=self.num_charge_classes, 52 | E=self.num_bond_classes, 53 | y=1, 54 | pos=3, 55 | ) 56 | self.output_dims = PlaceHolder( 57 | X=self.num_atom_types, 58 | C=self.num_charge_classes, 59 | E=self.num_bond_classes, 60 | y=0, 61 | pos=3, 62 | ) 63 | 64 | def to_one_hot(self, X, C, E, node_mask): 65 | X = F.one_hot(X, num_classes=self.num_atom_types).float() 66 | E = F.one_hot(E, num_classes=self.num_bond_classes).float() 67 | C = F.one_hot( 68 | C + self.charge_offset, num_classes=self.num_charge_classes 69 | ).float() 70 | placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None) 71 | pl = placeholder.mask(node_mask) 72 | return pl.X, pl.C, pl.E 73 | 74 | def one_hot_charges(self, C): 75 | return F.one_hot( 76 | (C + self.charge_offset).long(), num_classes=self.num_charge_classes 77 | ).float() 78 | 79 | 80 | full_atom_encoder_drugs = { 81 | "H": 0, 82 | "B": 1, 83 | "C": 2, 84 | "N": 3, 85 | "O": 4, 86 | "F": 5, 87 | "Al": 6, 88 | "Si": 7, 89 | "P": 8, 90 | "S": 9, 91 | "Cl": 10, 92 | "As": 11, 93 | "Br": 12, 94 | "I": 13, 95 | "Hg": 14, 96 | "Bi": 15, 97 | } 98 | 99 | 100 | class GEOMInfos(AbstractDatasetInfos): 101 | def __init__(self, datamodule, cfg): 102 | self.remove_h = cfg.remove_hs 103 | self.need_to_strip = ( 104 | False # to indicate whether we need to ignore one output from the model 105 | ) 106 | self.statistics = datamodule.statistics 107 | self.name = "drugs" 108 | self.atom_encoder = full_atom_encoder_drugs 109 | self.charge_offset = 2 110 | self.collapse_charges = torch.Tensor([-2, -1, 0, 1, 2, 3]).int() 111 | if self.remove_h: 112 | self.atom_encoder = { 113 | k: v - 1 for k, v in self.atom_encoder.items() if k != "H" 114 | } 115 | 116 | super().complete_infos(datamodule.statistics, self.atom_encoder) 117 | 118 | self.input_dims = PlaceHolder(X=self.num_atom_types, C=6, E=5, y=1, pos=3) 119 | self.output_dims = PlaceHolder(X=self.num_atom_types, C=6, E=5, y=0, pos=3) 120 | 121 | def to_one_hot(self, X, C, E, node_mask): 122 | X = F.one_hot(X, num_classes=self.num_atom_types).float() 123 | E = F.one_hot(E, num_classes=5).float() 124 | C = F.one_hot(C + self.charge_offset, num_classes=6).float() 125 | placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None) 126 | pl = placeholder.mask(node_mask) 127 | return pl.X, pl.C, pl.E 128 | 129 | def one_hot_charges(self, C): 130 | return F.one_hot((C + self.charge_offset).long(), num_classes=6).float() 131 | 132 | 133 | full_atom_encoder_pubchem = { 134 | "H": 0, 135 | "B": 1, 136 | "C": 2, 137 | "N": 3, 138 | "O": 4, 139 | "F": 5, 140 | "Al": 6, 141 | "Si": 7, 142 | "P": 8, 143 | "S": 9, 144 | "Cl": 10, 145 | "As": 11, 146 | "Br": 12, 147 | "I": 13, 148 | "Hg": 14, 149 | "Bi": 15, 150 | } 151 | 152 | 153 | class PubChemInfos(AbstractDatasetInfos): 154 | def __init__(self, datamodule, cfg): 155 | self.remove_h = cfg.remove_hs 156 | self.need_to_strip = ( 157 | False # to indicate whether we need to ignore one output from the model 158 | ) 159 | self.statistics = datamodule.statistics 160 | self.name = "pubchem" 161 | self.atom_encoder = full_atom_encoder_pubchem 162 | 
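        # Presumably maps the 11 contiguous atom-type indices used by the
        # PubChem data onto their positions in the 16-element full encoder
        # above (H, C, N, O, F, Si, P, S, Cl, Br, I).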
self.atom_idx_mapping = { 163 | 0: 0, 164 | 1: 2, 165 | 2: 3, 166 | 3: 4, 167 | 4: 5, 168 | 5: 7, 169 | 6: 8, 170 | 7: 9, 171 | 8: 10, 172 | 9: 12, 173 | 10: 13, 174 | } 175 | self.charge_offset = 2 176 | self.collapse_charges = torch.Tensor([-2, -1, 0, 1, 2, 3]).int() 177 | if self.remove_h: 178 | self.atom_encoder = { 179 | k: v - 1 for k, v in self.atom_encoder.items() if k != "H" 180 | } 181 | 182 | super().complete_infos(datamodule.statistics, self.atom_encoder) 183 | 184 | self.input_dims = PlaceHolder(X=len(self.atom_encoder), C=6, E=5, y=1, pos=3) 185 | self.output_dims = PlaceHolder(X=len(self.atom_encoder), C=6, E=5, y=0, pos=3) 186 | 187 | def to_one_hot(self, X, C, E, node_mask): 188 | X = F.one_hot(X, num_classes=len(self.atom_encoder)).float() 189 | E = F.one_hot(E, num_classes=5).float() 190 | C = F.one_hot(C + self.charge_offset, num_classes=6).float() 191 | placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None) 192 | pl = placeholder.mask(node_mask) 193 | return pl.X, pl.C, pl.E 194 | 195 | def one_hot_charges(self, C): 196 | return F.one_hot((C + self.charge_offset).long(), num_classes=6).float() 197 | 198 | 199 | full_atom_encoder_qm9 = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4} 200 | 201 | 202 | class QM9Infos(AbstractDatasetInfos): 203 | def __init__(self, datamodule, cfg): 204 | self.remove_h = cfg.remove_hs 205 | self.statistics = datamodule.statistics 206 | self.name = "qm9" 207 | self.atom_encoder = full_atom_encoder_qm9 208 | self.charge_offset = 1 209 | self.collapse_charges = torch.Tensor([-1, 0, 1]).int() 210 | if self.remove_h: 211 | self.atom_encoder = { 212 | k: v - 1 for k, v in self.atom_encoder.items() if k != "H" 213 | } 214 | super().complete_infos(datamodule.statistics, self.atom_encoder) 215 | self.input_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=1, pos=3) 216 | self.output_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=0, pos=3) 217 | 218 | def to_one_hot(self, X, C, E, node_mask): 219 | X = F.one_hot(X, num_classes=self.num_atom_types).float() 220 | E = F.one_hot(E, num_classes=5).float() 221 | C = F.one_hot(C + self.charge_offset, num_classes=3).float() 222 | placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None) 223 | pl = placeholder.mask(node_mask) 224 | return pl.X, pl.C, pl.E 225 | 226 | def one_hot_charges(self, charges): 227 | return F.one_hot((charges + self.charge_offset).long(), num_classes=3).float() 228 | 229 | 230 | mol_properties = [ 231 | "DIP", 232 | "HLgap", 233 | "eAT", 234 | "eC", 235 | "eEE", 236 | "eH", 237 | "eKIN", 238 | "eKSE", 239 | "eL", 240 | "eNE", 241 | "eNN", 242 | "eMBD", 243 | "eTS", 244 | "eX", 245 | "eXC", 246 | "eXX", 247 | "mPOL", 248 | ] 249 | 250 | atomic_energies_dict = { 251 | 1: -13.643321054, 252 | 6: -1027.610746263, 253 | 7: -1484.276217092, 254 | 8: -2039.751675679, 255 | 9: -3139.751675679, 256 | 15: -9283.015861995, 257 | 16: -10828.726222083, 258 | 17: -12516.462339357, 259 | } 260 | atomic_numbers = [1, 6, 7, 8, 9, 15, 16, 17] 261 | full_atom_encoder_aqm = { 262 | "H": 0, 263 | "C": 1, 264 | "N": 2, 265 | "O": 3, 266 | "F": 4, 267 | "P": 5, 268 | "S": 6, 269 | "Cl": 7, 270 | } 271 | 272 | 273 | class AQMInfos(AbstractDatasetInfos): 274 | def __init__(self, datamodule, cfg): 275 | self.remove_h = cfg.remove_hs 276 | self.statistics = datamodule.statistics 277 | self.name = "aqm" 278 | self.atom_encoder = full_atom_encoder_aqm 279 | self.charge_offset = 1 280 | self.collapse_charges = torch.Tensor([-1, 0, 1]).int() 281 | if self.remove_h: 282 | self.atom_encoder = { 283 | k: v - 
1 for k, v in self.atom_encoder.items() if k != "H" 284 | } 285 | 286 | super().complete_infos(datamodule.statistics, self.atom_encoder) 287 | 288 | self.input_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=1, pos=3) 289 | self.output_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=0, pos=3) 290 | 291 | def to_one_hot(self, X, C, E, node_mask): 292 | X = F.one_hot(X, num_classes=self.num_atom_types).float() 293 | E = F.one_hot(E, num_classes=5).float() 294 | C = F.one_hot(C + 1, num_classes=3).float() 295 | placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None) 296 | pl = placeholder.mask(node_mask) 297 | return pl.X, pl.C, pl.E 298 | 299 | def one_hot_charges(self, C): 300 | return F.one_hot((C + self.charge_offset).long(), num_classes=3).float() 301 | 302 | 303 | class AQMQM7XInfos(AbstractDatasetInfos): 304 | def __init__(self, datamodule, cfg): 305 | self.remove_h = cfg.remove_hs 306 | self.statistics = datamodule.statistics 307 | self.name = "aqm_qm7x" 308 | self.atom_encoder = full_atom_encoder_aqm 309 | self.charge_offset = 1 310 | self.collapse_charges = torch.Tensor([-1, 0, 1]).int() 311 | if self.remove_h: 312 | self.atom_encoder = { 313 | k: v - 1 for k, v in self.atom_encoder.items() if k != "H" 314 | } 315 | 316 | super().complete_infos(datamodule.statistics, self.atom_encoder) 317 | 318 | self.input_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=1, pos=3) 319 | self.output_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=0, pos=3) 320 | 321 | def to_one_hot(self, X, C, E, node_mask): 322 | X = F.one_hot(X, num_classes=self.num_atom_types).float() 323 | E = F.one_hot(E, num_classes=5).float() 324 | C = F.one_hot(C + 1, num_classes=3).float() 325 | placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None) 326 | pl = placeholder.mask(node_mask) 327 | return pl.X, pl.C, pl.E 328 | 329 | def one_hot_charges(self, C): 330 | return F.one_hot((C + self.charge_offset).long(), num_classes=3).float() 331 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/geom/calculate_energies.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | import torch 7 | from torch_geometric.data.collate import collate 8 | from tqdm import tqdm 9 | 10 | from experiments.xtb_energy import calculate_xtb_energy 11 | 12 | 13 | def get_args(): 14 | # fmt: off 15 | parser = argparse.ArgumentParser(description='Energy calculation') 16 | parser.add_argument('--dataset', type=str, help='Which dataset') 17 | parser.add_argument('--split', type=str, help='Which data split train/val/test') 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | atom_encoder = { 23 | "H": 0, 24 | "B": 1, 25 | "C": 2, 26 | "N": 3, 27 | "O": 4, 28 | "F": 5, 29 | "Al": 6, 30 | "Si": 7, 31 | "P": 8, 32 | "S": 9, 33 | "Cl": 10, 34 | "As": 11, 35 | "Br": 12, 36 | "I": 13, 37 | "Hg": 14, 38 | "Bi": 15, 39 | } 40 | atom_decoder = {v: k for k, v in atom_encoder.items()} 41 | atom_reference = { 42 | "H": -0.393482763936, 43 | "B": -0.952436614164, 44 | "C": -1.795110518041, 45 | "N": -2.60945245463, 46 | "O": -3.769421097051, 47 | "F": -4.619339964238, 48 | "Al": -0.905328611479, 49 | "Si": -1.571424085131, 50 | "P": -2.377807088084, 51 | "S": -3.148271017078, 52 | "Cl": -4.482525134961, 53 | "As": -2.239425948594, 54 | "Br": -4.048339371234, 55 | "I": -3.77963026339, 56 | "Hg": -0.848032246708, 57 | "Bi": -2.26665341636, 58 | } 59 | 60 | 
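# The process() function below attaches an xTB *atomization* energy label to each
# molecule: calculate_xtb_energy returns the total energy in eV, which is converted
# to Hartree (1 eV = 0.0367493 Eh) and reduced by the sum of the per-atom reference
# energies from atom_reference above (already given in Hartree).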
61 | def process(dataset, split):
62 |     if dataset == "drugs":
63 |         from experiments.data.geom.geom_dataset_adaptive import (
64 |             GeomDrugsDataset as DataModule,
65 |         )
66 | 
67 |         root_path = "/scratch1/cremej01/data/geom"
68 |     elif dataset == "qm9":
69 |         from experiments.data.qm9.qm9_dataset import QM9Dataset as DataModule
70 | 
71 |         root_path = "/scratch1/cremej01/data/qm9"
72 |     else:
73 |         raise ValueError("Dataset not found")
74 | 
75 |     remove_hs = False
76 | 
77 |     dataset = DataModule(split=split, root=root_path, remove_h=remove_hs)
78 | 
79 |     failed_ids = []
80 |     mols = []
81 |     for i, mol in tqdm(enumerate(dataset)):
82 |         atom_types = [atom_decoder[int(a)] for a in mol.x]
83 |         try:
84 |             e_ref = np.sum(
85 |                 [atom_reference[a] for a in atom_types]
86 |             )  # * 27.2114 #Hartree to eV
87 |             e, _ = calculate_xtb_energy(mol.pos, atom_types)
88 |             e *= 0.0367493  # eV to Hartree
89 |             mol.energy = torch.tensor(e - e_ref, dtype=torch.float32).unsqueeze(0)
90 |         except Exception:
91 |             print(f"Molecule with id {i} failed...")
92 |             failed_ids.append(i)
93 |             continue
94 |         mols.append(mol)
95 | 
96 |     print("Collate the data...")
97 |     data, slices = _collate(mols)
98 | 
99 |     print("Saving the data...")
100 |     torch.save(
101 |         (data, slices), (os.path.join(root_path, f"processed/{split}_data_energy.pt"))
102 |     )
103 | 
104 |     with open(os.path.join(root_path, f"failed_ids_{split}.pickle"), "wb") as f:
105 |         pickle.dump(failed_ids, f)
106 | 
107 | 
108 | def _collate(data_list):
109 |     r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects
110 |     to the internal storage format of
111 |     :class:`~torch_geometric.data.InMemoryDataset`."""
112 |     if len(data_list) == 1:
113 |         return data_list[0], None
114 | 
115 |     data, slices, _ = collate(
116 |         data_list[0].__class__,
117 |         data_list=data_list,
118 |         increment=False,
119 |         add_batch=False,
120 |     )
121 | 
122 |     return data, slices
123 | 
124 | 
125 | if __name__ == "__main__":
126 |     args = get_args()
127 |     process(args.dataset, args.split)
128 | 
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/geom/geom_dataset_energy.py:
--------------------------------------------------------------------------------
1 | from rdkit import RDLogger
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torch
5 | from os.path import join
6 | from experiments.data.utils import train_subset
7 | from torch_geometric.data import InMemoryDataset, DataLoader
8 | import experiments.data.utils as dataset_utils
9 | from experiments.data.utils import load_pickle, save_pickle
10 | from experiments.data.abstract_dataset import (
11 |     AbstractAdaptiveDataModule,
12 | )
13 | from experiments.data.metrics import compute_all_statistics
14 | from torch.utils.data import Subset
15 | 
16 | 
17 | full_atom_encoder = {
18 |     "H": 0,
19 |     "B": 1,
20 |     "C": 2,
21 |     "N": 3,
22 |     "O": 4,
23 |     "F": 5,
24 |     "Al": 6,
25 |     "Si": 7,
26 |     "P": 8,
27 |     "S": 9,
28 |     "Cl": 10,
29 |     "As": 11,
30 |     "Br": 12,
31 |     "I": 13,
32 |     "Hg": 14,
33 |     "Bi": 15,
34 | }
35 | 
36 | 
37 | class GeomDrugsDataset(InMemoryDataset):
38 |     def __init__(
39 |         self, split, root, remove_h, transform=None, pre_transform=None, pre_filter=None
40 |     ):
41 |         assert split in ["train", "val", "test"]
42 |         self.split = split
43 |         self.remove_h = remove_h
44 | 
45 |         self.compute_bond_distance_angles = True
46 | 
47 |         self.atom_encoder = full_atom_encoder
48 | 
49 |         if remove_h:
50 |             self.atom_encoder = {
51 |                 k: v - 1 for k, v in self.atom_encoder.items() if k != "H"
52 |             }
53 | 
54 |         super().__init__(root, transform, pre_transform, pre_filter)
55 |         self.data, self.slices = torch.load(self.processed_paths[0])
56 |         self.statistics = dataset_utils.Statistics(
57 |             num_nodes=load_pickle(self.processed_paths[1]),
58 |             atom_types=torch.from_numpy(np.load(self.processed_paths[2])),
59 |             bond_types=torch.from_numpy(np.load(self.processed_paths[3])),
60 |             charge_types=torch.from_numpy(np.load(self.processed_paths[4])),
61 |             valencies=load_pickle(self.processed_paths[5]),
62 |             bond_lengths=load_pickle(self.processed_paths[6]),
63 |             bond_angles=torch.from_numpy(np.load(self.processed_paths[7])),
64 |             is_aromatic=torch.from_numpy(np.load(self.processed_paths[9])).float(),
65 |             is_in_ring=torch.from_numpy(np.load(self.processed_paths[10])).float(),
66 |             hybridization=torch.from_numpy(np.load(self.processed_paths[11])).float(),
67 |         )
68 |         self.smiles = load_pickle(self.processed_paths[8])
69 | 
70 |     @property
71 |     def raw_file_names(self):
72 |         if self.split == "train":
73 |             return ["train_data.pickle"]
74 |         elif self.split == "val":
75 |             return ["val_data.pickle"]
76 |         else:
77 |             return ["test_data.pickle"]
78 |     @property
79 |     def processed_file_names(self):
80 |         h = "noh" if self.remove_h else "h"
81 |         if self.split == "train":
82 |             return [
83 |                 f"train_{h}_energy.pt",
84 |                 f"train_n_{h}.pickle",
85 |                 f"train_atom_types_{h}.npy",
86 |                 f"train_bond_types_{h}.npy",
87 |                 f"train_charges_{h}.npy",
88 |                 f"train_valency_{h}.pickle",
89 |                 f"train_bond_lengths_{h}.pickle",
90 |                 f"train_angles_{h}.npy",
91 |                 "train_smiles.pickle",
92 |                 f"train_is_aromatic_{h}.npy",
93 |                 f"train_is_in_ring_{h}.npy",
94 |                 f"train_hybridization_{h}.npy",
95 |             ]
96 |         elif self.split == "val":
97 |             return [
98 |                 f"val_{h}_energy.pt",
99 |                 f"val_n_{h}.pickle",
100 |                 f"val_atom_types_{h}.npy",
101 |                 f"val_bond_types_{h}.npy",
102 |                 f"val_charges_{h}.npy",
103 |                 f"val_valency_{h}.pickle",
104 |                 f"val_bond_lengths_{h}.pickle",
105 |                 f"val_angles_{h}.npy",
106 |                 "val_smiles.pickle",
107 |                 f"val_is_aromatic_{h}.npy",
108 |                 f"val_is_in_ring_{h}.npy",
109 |                 f"val_hybridization_{h}.npy",
110 |             ]
111 |         else:
112 |             return [
113 |                 f"test_{h}_energy.pt",
114 |                 f"test_n_{h}.pickle",
115 |                 f"test_atom_types_{h}.npy",
116 |                 f"test_bond_types_{h}.npy",
117 |                 f"test_charges_{h}.npy",
118 |                 f"test_valency_{h}.pickle",
119 |                 f"test_bond_lengths_{h}.pickle",
120 |                 f"test_angles_{h}.npy",
121 |                 "test_smiles.pickle",
122 |                 f"test_is_aromatic_{h}.npy",
123 |                 f"test_is_in_ring_{h}.npy",
124 |                 f"test_hybridization_{h}.npy",
125 |             ]
126 | 
127 |     def download(self):
128 |         raise ValueError(
129 |             "Download and preprocessing is manual. If the data is already downloaded, "
130 |             f"check that the paths are correct. 
Root dir = {self.root} -- raw files {self.raw_paths}" 131 | ) 132 | 133 | def process(self): 134 | RDLogger.DisableLog("rdApp.*") 135 | all_data = load_pickle(self.raw_paths[0]) 136 | 137 | data_list = [] 138 | all_smiles = [] 139 | for i, data in enumerate(tqdm(all_data)): 140 | smiles, all_conformers = data 141 | all_smiles.append(smiles) 142 | for j, conformer in enumerate(all_conformers): 143 | if j >= 5: 144 | break 145 | data = dataset_utils.mol_to_torch_geometric( 146 | conformer, 147 | full_atom_encoder, 148 | smiles, 149 | remove_hydrogens=self.remove_h, # need to give full atom encoder since hydrogens might still be available if Chem.RemoveHs is called 150 | ) 151 | # even when calling Chem.RemoveHs, hydrogens might be present 152 | if self.remove_h: 153 | data = dataset_utils.remove_hydrogens( 154 | data 155 | ) # remove through masking 156 | 157 | if self.pre_filter is not None and not self.pre_filter(data): 158 | continue 159 | if self.pre_transform is not None: 160 | data = self.pre_transform(data) 161 | 162 | data_list.append(data) 163 | 164 | torch.save(self.collate(data_list), self.processed_paths[0]) 165 | 166 | statistics = compute_all_statistics( 167 | data_list, 168 | self.atom_encoder, 169 | charges_dic={-2: 0, -1: 1, 0: 2, 1: 3, 2: 4, 3: 5}, 170 | additional_feats=True, 171 | # do not compute bond distance and bond angle statistics due to time and we do not use it anyways currently 172 | ) 173 | save_pickle(statistics.num_nodes, self.processed_paths[1]) 174 | np.save(self.processed_paths[2], statistics.atom_types) 175 | np.save(self.processed_paths[3], statistics.bond_types) 176 | np.save(self.processed_paths[4], statistics.charge_types) 177 | save_pickle(statistics.valencies, self.processed_paths[5]) 178 | save_pickle(statistics.bond_lengths, self.processed_paths[6]) 179 | np.save(self.processed_paths[7], statistics.bond_angles) 180 | save_pickle(set(all_smiles), self.processed_paths[8]) 181 | 182 | np.save(self.processed_paths[9], statistics.is_aromatic) 183 | np.save(self.processed_paths[10], statistics.is_in_ring) 184 | np.save(self.processed_paths[11], statistics.hybridization) 185 | 186 | 187 | class GeomDataModule(AbstractAdaptiveDataModule): 188 | def __init__(self, cfg): 189 | self.datadir = cfg.dataset_root 190 | root_path = cfg.dataset_root 191 | self.pin_memory = True 192 | 193 | train_dataset = GeomDrugsDataset( 194 | split="train", root=root_path, remove_h=cfg.remove_hs 195 | ) 196 | val_dataset = GeomDrugsDataset( 197 | split="val", root=root_path, remove_h=cfg.remove_hs 198 | ) 199 | test_dataset = GeomDrugsDataset( 200 | split="test", root=root_path, remove_h=cfg.remove_hs 201 | ) 202 | 203 | self.statistics = { 204 | "train": train_dataset.statistics, 205 | "val": val_dataset.statistics, 206 | "test": test_dataset.statistics, 207 | } 208 | 209 | if cfg.select_train_subset: 210 | self.idx_train = train_subset( 211 | dset_len=len(train_dataset), 212 | train_size=cfg.train_size, 213 | seed=cfg.seed, 214 | filename=join(cfg.save_dir, "splits.npz"), 215 | ) 216 | self.train_smiles = train_dataset.smiles 217 | train_dataset = Subset(train_dataset, self.idx_train) 218 | 219 | self.remove_h = cfg.remove_hs 220 | 221 | super().__init__(cfg, train_dataset, val_dataset, test_dataset) 222 | 223 | def _train_dataloader(self, shuffle=True): 224 | dataloader = DataLoader( 225 | dataset=self.train_dataset, 226 | batch_size=self.cfg.batch_size, 227 | num_workers=self.cfg.num_workers, 228 | pin_memory=self.pin_memory, 229 | shuffle=shuffle, 230 | 
persistent_workers=False, 231 | ) 232 | return dataloader 233 | 234 | def _val_dataloader(self, shuffle=False): 235 | dataloader = DataLoader( 236 | dataset=self.val_dataset, 237 | batch_size=self.cfg.batch_size, 238 | num_workers=self.cfg.num_workers, 239 | pin_memory=self.pin_memory, 240 | shuffle=shuffle, 241 | persistent_workers=False, 242 | ) 243 | return dataloader 244 | 245 | def _test_dataloader(self, shuffle=False): 246 | dataloader = DataLoader( 247 | dataset=self.test_dataset, 248 | batch_size=self.cfg.batch_size, 249 | num_workers=self.cfg.num_workers, 250 | pin_memory=self.pin_memory, 251 | shuffle=shuffle, 252 | persistent_workers=False, 253 | ) 254 | return dataloader 255 | 256 | def compute_mean_mad(self, properties_list): 257 | if self.cfg.dataset == "qm9" or self.cfg.dataset == "drugs": 258 | dataloader = self.get_dataloader(self.train_dataset, "val") 259 | return self.compute_mean_mad_from_dataloader(dataloader, properties_list) 260 | elif self.cfg.dataset == "qm9_1half" or self.cfg.dataset == "qm9_2half": 261 | dataloader = self.get_dataloader(self.val_dataset, "val") 262 | return self.compute_mean_mad_from_dataloader(dataloader, properties_list) 263 | else: 264 | raise Exception("Wrong dataset name") 265 | 266 | def compute_mean_mad_from_dataloader(self, dataloader, properties_list): 267 | property_norms = {} 268 | for property_key in properties_list: 269 | try: 270 | property_name = property_key + "_mm" 271 | values = getattr(dataloader.dataset[:], property_name) 272 | except: 273 | property_name = property_key 274 | idx = dataloader.dataset[:].label2idx[property_name] 275 | values = torch.tensor( 276 | [data.y[:, idx] for data in dataloader.dataset[:]] 277 | ) 278 | 279 | mean = torch.mean(values) 280 | ma = torch.abs(values - mean) 281 | mad = torch.mean(ma) 282 | property_norms[property_key] = {} 283 | property_norms[property_key]["mean"] = mean 284 | property_norms[property_key]["mad"] = mad 285 | del values 286 | return property_norms 287 | 288 | def get_dataloader(self, dataset, stage): 289 | if stage == "train": 290 | batch_size = self.cfg.batch_size 291 | shuffle = True 292 | elif stage in ["val", "test"]: 293 | batch_size = self.cfg.inference_batch_size 294 | shuffle = False 295 | 296 | dl = DataLoader( 297 | dataset=dataset, 298 | batch_size=batch_size, 299 | num_workers=self.cfg.num_workers, 300 | pin_memory=True, 301 | shuffle=shuffle, 302 | ) 303 | 304 | return dl 305 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/geom/geom_dataset_nonadaptive.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from typing import Optional 3 | 4 | from torch.utils.data import Subset 5 | from torch_geometric.data import DataLoader 6 | 7 | from experiments.data.abstract_dataset import ( 8 | AbstractDataModule, 9 | ) 10 | from experiments.data.geom.geom_dataset_adaptive import GeomDrugsDataset 11 | from experiments.data.utils import train_subset 12 | 13 | full_atom_encoder = { 14 | "H": 0, 15 | "B": 1, 16 | "C": 2, 17 | "N": 3, 18 | "O": 4, 19 | "F": 5, 20 | "Al": 6, 21 | "Si": 7, 22 | "P": 8, 23 | "S": 9, 24 | "Cl": 10, 25 | "As": 11, 26 | "Br": 12, 27 | "I": 13, 28 | "Hg": 14, 29 | "Bi": 15, 30 | } 31 | 32 | 33 | class GeomDataModule(AbstractDataModule): 34 | def __init__(self, cfg, only_stats: bool = False): 35 | self.datadir = cfg.dataset_root 36 | root_path = cfg.dataset_root 37 | self.cfg = cfg 38 | self.pin_memory = True 39 | 
self.persistent_workers = False
40 | 
41 |         train_dataset = GeomDrugsDataset(
42 |             split="train", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
43 |         )
44 |         val_dataset = GeomDrugsDataset(
45 |             split="val", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
46 |         )
47 |         test_dataset = GeomDrugsDataset(
48 |             split="test", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
49 |         )
50 | 
51 |         self.remove_h = cfg.remove_hs
52 |         self.statistics = {
53 |             "train": train_dataset.statistics,
54 |             "val": val_dataset.statistics,
55 |             "test": test_dataset.statistics,
56 |         }
57 |         # statistics are read before the train set is wrapped, since torch.utils.data.Subset has no .statistics attribute
58 |         if not only_stats:
59 |             if cfg.select_train_subset:
60 |                 self.idx_train = train_subset(
61 |                     dset_len=len(train_dataset),
62 |                     train_size=cfg.train_size,
63 |                     seed=cfg.seed,
64 |                     filename=join(cfg.save_dir, "splits.npz"),
65 |                 )
66 |                 train_dataset = Subset(train_dataset, self.idx_train)
67 |         super().__init__(cfg, train_dataset, val_dataset, test_dataset)
68 | 
69 |     def setup(self, stage: Optional[str] = None) -> None:
70 |         train_dataset = GeomDrugsDataset(
71 |             root=self.cfg.dataset_root, split="train", remove_h=self.cfg.remove_hs
72 |         )
73 |         val_dataset = GeomDrugsDataset(
74 |             root=self.cfg.dataset_root, split="val", remove_h=self.cfg.remove_hs
75 |         )
76 |         test_dataset = GeomDrugsDataset(
77 |             root=self.cfg.dataset_root, split="test", remove_h=self.cfg.remove_hs
78 |         )
79 | 
80 |         if stage == "fit" or stage is None:
81 |             self.train_dataset = train_dataset
82 |             self.val_dataset = val_dataset
83 |             self.test_dataset = test_dataset
84 | 
85 |     def train_dataloader(self, shuffle=True):
86 |         dataloader = DataLoader(
87 |             dataset=self.train_dataset,
88 |             batch_size=self.cfg.batch_size,
89 |             num_workers=self.cfg.num_workers,
90 |             pin_memory=self.pin_memory,
91 |             shuffle=shuffle,
92 |             persistent_workers=self.persistent_workers,
93 |         )
94 |         return dataloader
95 | 
96 |     def val_dataloader(self, shuffle=False):
97 |         dataloader = DataLoader(
98 |             dataset=self.val_dataset,
99 |             batch_size=self.cfg.batch_size,
100 |             num_workers=self.cfg.num_workers,
101 |             pin_memory=self.pin_memory,
102 |             shuffle=shuffle,
103 |             persistent_workers=self.persistent_workers,
104 |         )
105 |         return dataloader
106 | 
107 |     def test_dataloader(self, shuffle=False):
108 |         dataloader = DataLoader(
109 |             dataset=self.test_dataset,
110 |             batch_size=self.cfg.batch_size,
111 |             num_workers=self.cfg.num_workers,
112 |             pin_memory=self.pin_memory,
113 |             shuffle=shuffle,
114 |             persistent_workers=self.persistent_workers,
115 |         )
116 |         return dataloader
117 | 
118 |     def get_dataloader(self, dataset, stage):
119 |         if stage == "train":
120 |             batch_size = self.cfg.batch_size
121 |             shuffle = True
122 |         elif stage in ["val", "test"]:
123 |             batch_size = self.cfg.inference_batch_size
124 |             shuffle = False
125 | 
126 |         dl = DataLoader(
127 |             dataset=dataset,
128 |             batch_size=batch_size,
129 |             num_workers=self.cfg.num_workers,
130 |             pin_memory=True,
131 |             shuffle=shuffle,
132 |         )
133 | 
134 |         return dl
135 | 
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/ligand/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/data/ligand/__init__.py
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/ligand/geometry_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from experiments.data.ligand.constants import CA_C_DIST, N_CA_C_ANGLE, N_CA_DIST
3 | 
4 | 
5 | def rotation_matrix(angle, axis):
6 |     """
7 |     Args:
8 |         angle: (n,)
9 |         axis: 0=x, 1=y, 2=z
10 |     Returns:
11 |         (n, 3, 3)
12 |     """
13 |     n = len(angle)
14 |     R = np.eye(3)[None, :, :].repeat(n, axis=0)
15 | 
16 |     axis = 2 - axis
17 |     start = axis // 2
18 |     step = axis % 2 + 1
19 |     s = slice(start, start + step + 1, step)
20 | 
21 |     R[:, s, s] = np.array(
22 |         [
23 |             [np.cos(angle), (-1) ** (axis + 1) * np.sin(angle)],
24 |             [(-1) ** axis * np.sin(angle), np.cos(angle)],
25 |         ]
26 |     ).transpose(2, 0, 1)
27 |     return R
28 | 
29 | 
30 | def get_bb_transform(n_xyz, ca_xyz, c_xyz):
31 |     """
32 |     Compute translation and rotation of the canonical backbone frame (triangle N-Ca-C) from a position with
33 |     Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame
34 | 
35 |     Args:
36 |         n_xyz: (n, 3)
37 |         ca_xyz: (n, 3)
38 |         c_xyz: (n, 3)
39 | 
40 |     Returns:
41 |         quaternion represented as array of shape (n, 4)
42 |         translation vector which is an array of shape (n, 3)
43 |     """
44 | 
45 |     translation = ca_xyz
46 |     n_xyz = n_xyz - translation
47 |     c_xyz = c_xyz - translation
48 | 
49 |     # Find rotation matrix that aligns the coordinate systems
50 |     # rotate around y-axis to move N into the xy-plane
51 |     theta_y = np.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
52 |     Ry = rotation_matrix(theta_y, 1)
53 |     n_xyz = np.einsum("noi,ni->no", Ry.transpose(0, 2, 1), n_xyz)
54 | 
55 |     # rotate around z-axis to move N onto the x-axis
56 |     theta_z = np.arctan2(n_xyz[:, 1], n_xyz[:, 0])
57 |     Rz = rotation_matrix(theta_z, 2)
58 |     # n_xyz = np.einsum('noi,ni->no', Rz.transpose(0, 2, 1), n_xyz)
59 | 
60 |     # rotate around x-axis to move C into the xy-plane
61 |     c_xyz = np.einsum(
62 |         "noj,nji,ni->no", Rz.transpose(0, 2, 1), Ry.transpose(0, 2, 1), c_xyz
63 |     )
64 |     theta_x = np.arctan2(c_xyz[:, 2], c_xyz[:, 1])
65 |     Rx = rotation_matrix(theta_x, 0)
66 | 
67 |     # Final rotation matrix
68 |     R = np.einsum("nok,nkj,nji->noi", Ry, Rz, Rx)
69 | 
70 |     # Convert to quaternion
71 |     # q = w + i*u_x + j*u_y + k * u_z
72 |     quaternion = rotation_matrix_to_quaternion(R)
73 | 
74 |     return quaternion, translation
75 | 
76 | 
77 | def get_bb_coords_from_transform(ca_coords, quaternion):
78 |     """
79 |     Args:
80 |         ca_coords: (n, 3)
81 |         quaternion: (n, 4)
82 |     Returns:
83 |         backbone coords (n*3, 3), order is [N, CA, C]
84 |         backbone atom types as a list of length n*3
85 |     """
86 |     R = quaternion_to_rotation_matrix(quaternion)
87 |     bb_coords = np.tile(
88 |         np.array(
89 |             [
90 |                 [N_CA_DIST, 0, 0],
91 |                 [0, 0, 0],
92 |                 [CA_C_DIST * np.cos(N_CA_C_ANGLE), CA_C_DIST * np.sin(N_CA_C_ANGLE), 0],
93 |             ]
94 |         ),
95 |         [len(ca_coords), 1],
96 |     )
97 |     bb_coords = np.einsum(
98 |         "noi,ni->no", R.repeat(3, axis=0), bb_coords
99 |     ) + ca_coords.repeat(3, axis=0)
100 |     bb_atom_types = [t for _ in range(len(ca_coords)) for t in ["N", "C", "C"]]
101 | 
102 |     return bb_coords, bb_atom_types
103 | 
104 | 
105 | def quaternion_to_rotation_matrix(q):
106 |     """
107 |     x_rot = R x
108 | 
109 |     Args:
110 |         q: (n, 4)
111 |     Returns:
112 |         R: (n, 3, 3)
113 |     """
114 |     # Normalize
115 |     q = q / (q**2).sum(1, keepdims=True) ** 0.5
116 | 
117 |     # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
118 |     w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
119 |     R = np.stack(
120 |         [
121 |             np.stack(
122 |                 [
123 |                     1 - 2 * y**2 - 2 * z**2,
124 |                     2 * x * y - 2 * z * w,
125 |                     2 * x * z + 2 * y * 
w, 126 | ], 127 | axis=1, 128 | ), 129 | np.stack( 130 | [ 131 | 2 * x * y + 2 * z * w, 132 | 1 - 2 * x**2 - 2 * z**2, 133 | 2 * y * z - 2 * x * w, 134 | ], 135 | axis=1, 136 | ), 137 | np.stack( 138 | [ 139 | 2 * x * z - 2 * y * w, 140 | 2 * y * z + 2 * x * w, 141 | 1 - 2 * x**2 - 2 * y**2, 142 | ], 143 | axis=1, 144 | ), 145 | ], 146 | axis=1, 147 | ) 148 | 149 | return R 150 | 151 | 152 | def rotation_matrix_to_quaternion(R): 153 | """ 154 | https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion 155 | Args: 156 | R: (n, 3, 3) 157 | Returns: 158 | q: (n, 4) 159 | """ 160 | 161 | t = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] 162 | r = np.sqrt(1 + t) 163 | w = 0.5 * r 164 | x = np.sign(R[:, 2, 1] - R[:, 1, 2]) * np.abs( 165 | 0.5 * np.sqrt(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2]) 166 | ) 167 | y = np.sign(R[:, 0, 2] - R[:, 2, 0]) * np.abs( 168 | 0.5 * np.sqrt(1 - R[:, 0, 0] + R[:, 1, 1] - R[:, 2, 2]) 169 | ) 170 | z = np.sign(R[:, 1, 0] - R[:, 0, 1]) * np.abs( 171 | 0.5 * np.sqrt(1 - R[:, 0, 0] - R[:, 1, 1] + R[:, 2, 2]) 172 | ) 173 | 174 | return np.stack((w, x, y, z), axis=1) 175 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/ligand/molecule_builder.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import warnings 3 | 4 | import numpy as np 5 | import openbabel 6 | import torch 7 | from rdkit import Chem 8 | from rdkit.Chem.rdForceFieldHelpers import UFFHasAllMoleculeParams, UFFOptimizeMolecule 9 | 10 | from experiments.data.ligand import utils 11 | from experiments.data.ligand.constants import ( 12 | bond_dict, 13 | bonds1, 14 | bonds2, 15 | bonds3, 16 | margin1, 17 | margin2, 18 | margin3, 19 | ) 20 | 21 | 22 | def get_bond_order(atom1, atom2, distance): 23 | distance = 100 * distance # We change the metric 24 | 25 | if ( 26 | atom1 in bonds3 27 | and atom2 in bonds3[atom1] 28 | and distance < bonds3[atom1][atom2] + margin3 29 | ): 30 | return 3 # Triple 31 | 32 | if ( 33 | atom1 in bonds2 34 | and atom2 in bonds2[atom1] 35 | and distance < bonds2[atom1][atom2] + margin2 36 | ): 37 | return 2 # Double 38 | 39 | if ( 40 | atom1 in bonds1 41 | and atom2 in bonds1[atom1] 42 | and distance < bonds1[atom1][atom2] + margin1 43 | ): 44 | return 1 # Single 45 | 46 | return 0 # No bond 47 | 48 | 49 | def get_bond_order_batch(atoms1, atoms2, distances, dataset_info): 50 | if isinstance(atoms1, np.ndarray): 51 | atoms1 = torch.from_numpy(atoms1) 52 | if isinstance(atoms2, np.ndarray): 53 | atoms2 = torch.from_numpy(atoms2) 54 | if isinstance(distances, np.ndarray): 55 | distances = torch.from_numpy(distances) 56 | 57 | distances = 100 * distances # We change the metric 58 | 59 | bonds1 = torch.tensor(dataset_info["bonds1"], device=atoms1.device) 60 | bonds2 = torch.tensor(dataset_info["bonds2"], device=atoms1.device) 61 | bonds3 = torch.tensor(dataset_info["bonds3"], device=atoms1.device) 62 | 63 | bond_types = torch.zeros_like(atoms1) # 0: No bond 64 | 65 | # Single 66 | bond_types[distances < bonds1[atoms1, atoms2] + margin1] = 1 67 | 68 | # Double (note that already assigned single bonds will be overwritten) 69 | bond_types[distances < bonds2[atoms1, atoms2] + margin2] = 2 70 | 71 | # Triple 72 | bond_types[distances < bonds3[atoms1, atoms2] + margin3] = 3 73 | 74 | return bond_types 75 | 76 | 77 | def make_mol_openbabel(positions, atom_types, atom_decoder): 78 | """ 79 | Build an RDKit molecule using openbabel for creating bonds 80 | Args: 81 | positions: N x 3 82 | 
atom_types: N 83 | atom_decoder: maps indices to atom types 84 | Returns: 85 | rdkit molecule 86 | """ 87 | 88 | with tempfile.NamedTemporaryFile() as tmp: 89 | tmp_file = tmp.name 90 | 91 | # Write xyz file 92 | utils.write_xyz_file(positions, atom_types, tmp_file) 93 | 94 | # Convert to sdf file with openbabel 95 | # openbabel will add bonds 96 | obConversion = openbabel.OBConversion() 97 | obConversion.SetInAndOutFormats("xyz", "sdf") 98 | ob_mol = openbabel.OBMol() 99 | obConversion.ReadFile(ob_mol, tmp_file) 100 | 101 | obConversion.WriteFile(ob_mol, tmp_file) 102 | 103 | # Read sdf file with RDKit 104 | mol = Chem.SDMolSupplier(tmp_file, sanitize=False)[0] 105 | 106 | return mol 107 | 108 | 109 | def make_mol_edm(positions, atom_types, dataset_info, add_coords): 110 | """ 111 | Equivalent to EDM's way of building RDKit molecules 112 | """ 113 | n = len(positions) 114 | 115 | # (X, A, E): atom_types, adjacency matrix, edge_types 116 | # X: N (int) 117 | # A: N x N (bool) -> (binary adjacency matrix) 118 | # E: N x N (int) -> (bond type, 0 if no bond) 119 | pos = positions.unsqueeze(0) # add batch dim 120 | dists = torch.cdist(pos, pos, p=2).squeeze(0).view(-1) # remove batch dim & flatten 121 | atoms1, atoms2 = torch.cartesian_prod(atom_types, atom_types).T 122 | E_full = get_bond_order_batch(atoms1, atoms2, dists, dataset_info).view(n, n) 123 | E = torch.tril(E_full, diagonal=-1) # Warning: the graph should be DIRECTED 124 | A = E.bool() 125 | X = atom_types 126 | 127 | mol = Chem.RWMol() 128 | for atom in X: 129 | a = Chem.Atom(dataset_info["atom_decoder"][atom.item()]) 130 | mol.AddAtom(a) 131 | 132 | all_bonds = torch.nonzero(A) 133 | for bond in all_bonds: 134 | mol.AddBond( 135 | bond[0].item(), bond[1].item(), bond_dict[E[bond[0], bond[1]].item()] 136 | ) 137 | 138 | if add_coords: 139 | conf = Chem.Conformer(mol.GetNumAtoms()) 140 | for i in range(mol.GetNumAtoms()): 141 | conf.SetAtomPosition( 142 | i, 143 | ( 144 | positions[i, 0].item(), 145 | positions[i, 1].item(), 146 | positions[i, 2].item(), 147 | ), 148 | ) 149 | mol.AddConformer(conf) 150 | 151 | return mol 152 | 153 | 154 | def build_molecule( 155 | positions, atom_types, dataset_info, add_coords=False, use_openbabel=True 156 | ): 157 | """ 158 | Build RDKit molecule 159 | Args: 160 | positions: N x 3 161 | atom_types: N 162 | dataset_info: dict 163 | add_coords: Add conformer to mol (always added if use_openbabel=True) 164 | use_openbabel: use OpenBabel to create bonds 165 | Returns: 166 | RDKit molecule 167 | """ 168 | if use_openbabel: 169 | mol = make_mol_openbabel(positions, atom_types, dataset_info["atom_decoder"]) 170 | else: 171 | mol = make_mol_edm(positions, atom_types, dataset_info, add_coords) 172 | 173 | return mol 174 | 175 | 176 | def process_molecule( 177 | rdmol, add_hydrogens=False, sanitize=False, relax_iter=0, largest_frag=False 178 | ): 179 | """ 180 | Apply filters to an RDKit molecule. Makes a copy first. 181 | Args: 182 | rdmol: rdkit molecule 183 | add_hydrogens 184 | sanitize 185 | relax_iter: maximum number of UFF optimization iterations 186 | largest_frag: filter out the largest fragment in a set of disjoint 187 | molecules 188 | Returns: 189 | RDKit molecule or None if it does not pass the filters 190 | """ 191 | 192 | # Create a copy 193 | mol = Chem.Mol(rdmol) 194 | 195 | if sanitize: 196 | try: 197 | Chem.SanitizeMol(mol) 198 | except ValueError: 199 | warnings.warn("Sanitization failed. 
Returning None.") 200 | return None 201 | 202 | if add_hydrogens: 203 | mol = Chem.AddHs(mol, addCoords=(len(mol.GetConformers()) > 0)) 204 | 205 | if largest_frag: 206 | mol_frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) 207 | mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms()) 208 | if sanitize: 209 | # sanitize the updated molecule 210 | try: 211 | Chem.SanitizeMol(mol) 212 | except ValueError: 213 | return None 214 | 215 | if relax_iter > 0: 216 | if not UFFHasAllMoleculeParams(mol): 217 | warnings.warn( 218 | "UFF parameters not available for all atoms. " "Returning None." 219 | ) 220 | return None 221 | 222 | try: 223 | uff_relax(mol, relax_iter) 224 | if sanitize: 225 | # sanitize the updated molecule 226 | Chem.SanitizeMol(mol) 227 | except (RuntimeError, ValueError) as e: 228 | return None 229 | 230 | return mol 231 | 232 | 233 | def uff_relax(mol, max_iter=200): 234 | """ 235 | Uses RDKit's universal force field (UFF) implementation to optimize a 236 | molecule. 237 | """ 238 | more_iterations_required = UFFOptimizeMolecule(mol, maxIters=max_iter) 239 | if more_iterations_required: 240 | warnings.warn( 241 | f"Maximum number of FF iterations reached. " 242 | f"Returning molecule after {max_iter} relaxation steps." 243 | ) 244 | return more_iterations_required 245 | 246 | 247 | def filter_rd_mol(rdmol): 248 | """ 249 | Filter out RDMols if they have a 3-3 ring intersection 250 | adapted from: 251 | https://github.com/luost26/3D-Generative-SBDD/blob/main/utils/chem.py 252 | """ 253 | ring_info = rdmol.GetRingInfo() 254 | ring_info.AtomRings() 255 | rings = [set(r) for r in ring_info.AtomRings()] 256 | 257 | # 3-3 ring intersection 258 | for i, ring_a in enumerate(rings): 259 | if len(ring_a) != 3: 260 | continue 261 | for j, ring_b in enumerate(rings): 262 | if i <= j: 263 | continue 264 | inter = ring_a.intersection(ring_b) 265 | if (len(ring_b) == 3) and (len(inter) > 0): 266 | return False 267 | 268 | return True 269 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/ligand/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from Bio.PDB.Polypeptide import is_aa 8 | from networkx.algorithms import isomorphism 9 | from rdkit import Chem 10 | 11 | 12 | class Queue: 13 | def __init__(self, max_len=50): 14 | self.items = [] 15 | self.max_len = max_len 16 | 17 | def __len__(self): 18 | return len(self.items) 19 | 20 | def add(self, item): 21 | self.items.insert(0, item) 22 | if len(self) > self.max_len: 23 | self.items.pop() 24 | 25 | def mean(self): 26 | return np.mean(self.items) 27 | 28 | def std(self): 29 | return np.std(self.items) 30 | 31 | 32 | def reverse_tensor(x): 33 | return x[torch.arange(x.size(0) - 1, -1, -1)] 34 | 35 | 36 | ##### 37 | 38 | 39 | def get_grad_norm( 40 | parameters: Union[torch.Tensor, Iterable[torch.Tensor]], norm_type: float = 2.0 41 | ) -> torch.Tensor: 42 | """ 43 | Adapted from: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_ 44 | """ 45 | 46 | if isinstance(parameters, torch.Tensor): 47 | parameters = [parameters] 48 | parameters = [p for p in parameters if p.grad is not None] 49 | 50 | norm_type = float(norm_type) 51 | 52 | if len(parameters) == 0: 53 | return torch.tensor(0.0) 54 | 55 | device = parameters[0].grad.device 56 | 57 | 
total_norm = torch.norm( 58 | torch.stack( 59 | [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] 60 | ), 61 | norm_type, 62 | ) 63 | 64 | return total_norm 65 | 66 | 67 | def write_xyz_file(coords, atom_types, filename): 68 | out = f"{len(coords)}\n\n" 69 | assert len(coords) == len(atom_types) 70 | for i in range(len(coords)): 71 | out += f"{atom_types[i]} {coords[i, 0]:.3f} {coords[i, 1]:.3f} {coords[i, 2]:.3f}\n" 72 | with open(filename, "w") as f: 73 | f.write(out) 74 | 75 | 76 | def write_sdf_file(sdf_path, molecules): 77 | # NOTE Changed to be compatitble with more versions of rdkit 78 | # with Chem.SDWriter(str(sdf_path)) as w: 79 | # for mol in molecules: 80 | # w.write(mol) 81 | 82 | w = Chem.SDWriter(str(sdf_path)) 83 | for m in molecules: 84 | if m is not None: 85 | w.write(m) 86 | 87 | print(f"Wrote SDF file to {sdf_path}") 88 | 89 | 90 | def residues_to_atoms(x_ca, dataset_info): 91 | x = x_ca 92 | one_hot = F.one_hot( 93 | torch.tensor(dataset_info["atom_encoder"]["C"], device=x_ca.device), 94 | num_classes=len(dataset_info["atom_encoder"]), 95 | ).repeat(*x_ca.shape[:-1], 1) 96 | return x, one_hot 97 | 98 | 99 | def get_residue_with_resi(pdb_chain, resi): 100 | res = [x for x in pdb_chain.get_residues() if x.id[1] == resi] 101 | assert len(res) == 1 102 | return res[0] 103 | 104 | 105 | def get_pocket_from_ligand(pdb_model, ligand_id, dist_cutoff=8.0): 106 | chain, resi = ligand_id.split(":") 107 | ligand = get_residue_with_resi(pdb_model[chain], int(resi)) 108 | ligand_coords = torch.from_numpy( 109 | np.array([a.get_coord() for a in ligand.get_atoms()]) 110 | ) 111 | 112 | pocket_residues = [] 113 | for residue in pdb_model.get_residues(): 114 | if residue.id[1] == resi: 115 | continue # skip ligand itself 116 | 117 | res_coords = torch.from_numpy( 118 | np.array([a.get_coord() for a in residue.get_atoms()]) 119 | ) 120 | if ( 121 | is_aa(residue.get_resname(), standard=True) 122 | and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff 123 | ): 124 | pocket_residues.append(residue) 125 | 126 | return pocket_residues 127 | 128 | 129 | def batch_to_list(data, batch_mask): 130 | # data_list = [] 131 | # for i in torch.unique(batch_mask): 132 | # data_list.append(data[batch_mask == i]) 133 | # return data_list 134 | 135 | # make sure batch_mask is increasing 136 | idx = torch.argsort(batch_mask) 137 | batch_mask = batch_mask[idx] 138 | data = data[idx] 139 | 140 | chunk_sizes = torch.unique(batch_mask, return_counts=True)[1].tolist() 141 | return torch.split(data, chunk_sizes) 142 | 143 | 144 | def num_nodes_to_batch_mask(n_samples, num_nodes, device): 145 | assert isinstance(num_nodes, int) or len(num_nodes) == n_samples 146 | 147 | if isinstance(num_nodes, torch.Tensor): 148 | num_nodes = num_nodes.to(device) 149 | 150 | sample_inds = torch.arange(n_samples, device=device) 151 | 152 | return torch.repeat_interleave(sample_inds, num_nodes) 153 | 154 | 155 | def rdmol_to_nxgraph(rdmol): 156 | graph = nx.Graph() 157 | for atom in rdmol.GetAtoms(): 158 | # Add the atoms as nodes 159 | graph.add_node(atom.GetIdx(), atom_type=atom.GetAtomicNum()) 160 | 161 | # Add the bonds as edges 162 | for bond in rdmol.GetBonds(): 163 | graph.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) 164 | 165 | return graph 166 | 167 | 168 | def calc_rmsd(mol_a, mol_b): 169 | """Calculate RMSD of two molecules with unknown atom correspondence.""" 170 | graph_a = rdmol_to_nxgraph(mol_a) 171 | graph_b = rdmol_to_nxgraph(mol_b) 172 | 173 | gm = 
isomorphism.GraphMatcher( 174 | graph_a, graph_b, node_match=lambda na, nb: na["atom_type"] == nb["atom_type"] 175 | ) 176 | 177 | isomorphisms = list(gm.isomorphisms_iter()) 178 | if len(isomorphisms) < 1: 179 | return None 180 | 181 | all_rmsds = [] 182 | for mapping in isomorphisms: 183 | atom_types_a = [atom.GetAtomicNum() for atom in mol_a.GetAtoms()] 184 | atom_types_b = [ 185 | mol_b.GetAtomWithIdx(mapping[i]).GetAtomicNum() 186 | for i in range(mol_b.GetNumAtoms()) 187 | ] 188 | assert atom_types_a == atom_types_b 189 | 190 | conf_a = mol_a.GetConformer() 191 | coords_a = np.array( 192 | [conf_a.GetAtomPosition(i) for i in range(mol_a.GetNumAtoms())] 193 | ) 194 | conf_b = mol_b.GetConformer() 195 | coords_b = np.array( 196 | [conf_b.GetAtomPosition(mapping[i]) for i in range(mol_b.GetNumAtoms())] 197 | ) 198 | 199 | diff = coords_a - coords_b 200 | rmsd = np.sqrt(np.mean(np.sum(diff * diff, axis=1))) 201 | all_rmsds.append(rmsd) 202 | 203 | if len(isomorphisms) > 1: 204 | print("More than one isomorphism found. Returning minimum RMSD.") 205 | 206 | return min(all_rmsds) 207 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/pubchem/download_pubchem.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from bs4 import BeautifulSoup 4 | from tqdm import tqdm 5 | 6 | 7 | def download_sdf_files(url, folder_path): 8 | try: 9 | response = requests.get(url) 10 | response.raise_for_status() # If the response was unsuccessful, this will raise a HTTPError 11 | except requests.exceptions.HTTPError as errh: 12 | print("HTTP Error:", errh) 13 | return 14 | except requests.exceptions.ConnectionError as errc: 15 | print("Error Connecting:", errc) 16 | return 17 | except requests.exceptions.Timeout as errt: 18 | print("Timeout Error:", errt) 19 | return 20 | except requests.exceptions.RequestException as err: 21 | print("Something went wrong with the request:", err) 22 | return 23 | 24 | soup = BeautifulSoup(response.text, "html.parser") 25 | 26 | for link in tqdm(soup.find_all("a")): 27 | file_link = link.get("href") 28 | if file_link.endswith(".sdf.gz"): 29 | file_url = url + file_link 30 | file_path = os.path.join(folder_path, file_link) 31 | 32 | try: 33 | with requests.get(file_url, stream=True) as r: 34 | r.raise_for_status() 35 | with open(file_path, "wb") as f: 36 | for chunk in r.iter_content(chunk_size=8192): 37 | f.write(chunk) 38 | except requests.exceptions.HTTPError as errh: 39 | print("HTTP Error:", errh) 40 | except requests.exceptions.ConnectionError as errc: 41 | print("Error Connecting:", errc) 42 | except requests.exceptions.Timeout as errt: 43 | print("Timeout Error:", errt) 44 | except requests.exceptions.RequestException as err: 45 | print("Something went wrong with the request:", err) 46 | 47 | 48 | url = "https://ftp.ncbi.nlm.nih.gov/pubchem/Compound_3D/01_conf_per_cmpd/SDF/" 49 | folder_path = ( 50 | "----" # replace with your folder path 51 | ) 52 | download_sdf_files(url, folder_path) 53 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/pubchem/preprocess_pubchem.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | import gzip 3 | from glob import glob 4 | import os 5 | from multiprocessing import cpu_count, Pool 6 | import pubchem.data.dataset_utils as dataset_utils 7 | from tqdm import tqdm 8 | import pickle 9 | 
import argparse 10 | 11 | 12 | def get_args(): 13 | # fmt: off 14 | parser = argparse.ArgumentParser(description='Data generation') 15 | parser.add_argument('--index', default=0, type=int, help='Which part of the splitted dataset') 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | def chunks(lst, n): 21 | """Yield successive n-sized chunks from lst.""" 22 | for i in range(0, len(lst), n): 23 | yield lst[i : i + n] 24 | 25 | 26 | full_atom_encoder = { 27 | "H": 0, 28 | "C": 1, 29 | "N": 2, 30 | "O": 3, 31 | "F": 4, 32 | "Si": 5, 33 | "P": 6, 34 | "S": 7, 35 | "Cl": 8, 36 | "Br": 9, 37 | "I": 10, 38 | } 39 | 40 | 41 | def process(file): 42 | inf = gzip.open(file) 43 | with Chem.ForwardSDMolSupplier(inf) as gzsuppl: 44 | molecules = [x for x in gzsuppl if x is not None] 45 | for mol in molecules: 46 | try: 47 | smiles = Chem.MolToSmiles(mol) 48 | smiles_list.append(smiles) 49 | data = dataset_utils.mol_to_torch_geometric(mol, full_atom_encoder, smiles) 50 | if data.pos.shape[0] != data.x.shape[0]: 51 | print(f"Molecule {smiles} does not have 3D information!") 52 | continue 53 | if data.pos.ndim != 2: 54 | print(f"Molecule {smiles} does not have 3D information!") 55 | continue 56 | if len(data.pos) < 2: 57 | print(f"Molecule {smiles} does not have 3D information!") 58 | continue 59 | data_list.append(data) 60 | except: 61 | continue 62 | # pbar.update(1) 63 | 64 | return 65 | 66 | 67 | if __name__ == "__main__": 68 | args = get_args() 69 | 70 | data_list = [] 71 | smiles_list = [] 72 | 73 | files = f"----files_{args.index}.pickle" 74 | with open(files, "rb") as f: 75 | files = pickle.load(f) 76 | 77 | pbar = tqdm(total=len(files)) 78 | 79 | for file in tqdm(files): 80 | process(file) 81 | 82 | with open(f"data_list_{args.index}.pickle", "wb") as f: 83 | pickle.dump(data_list, f) 84 | with open(f"smiles_list_{args.index}.pickle", "wb") as f: 85 | pickle.dump(smiles_list, f) 86 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/pubchem/process_pubchem.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | from tqdm import tqdm 4 | import lmdb 5 | from rdkit import Chem, RDLogger 6 | import experiments.data.utils as dataset_utils 7 | import gzip 8 | import io 9 | import multiprocessing as mp 10 | import pickle 11 | 12 | RDLogger.DisableLog("rdApp.*") 13 | 14 | 15 | def chunks(lst, n): 16 | """Yield successive n-sized chunks from lst.""" 17 | for i in range(0, len(lst), n): 18 | yield lst[i : i + n] 19 | 20 | 21 | DATA_PATH = "---" 22 | FULL_ATOM_ENCODER = { 23 | "H": 0, 24 | "B": 1, 25 | "C": 2, 26 | "N": 3, 27 | "O": 4, 28 | "F": 5, 29 | "Al": 6, 30 | "Si": 7, 31 | "P": 8, 32 | "S": 9, 33 | "Cl": 10, 34 | "As": 11, 35 | "Br": 12, 36 | "I": 13, 37 | "Hg": 14, 38 | "Bi": 15, 39 | } 40 | 41 | 42 | def process_files( 43 | processes: int = 36, chunk_size: int = 1024, subchunk: int = 128, removeHs=False 44 | ): 45 | """ 46 | :param dataset: 47 | :param max_conformers: 48 | :param processes: 49 | :param chunk_size: 50 | :param subchunk: 51 | :return: 52 | """ 53 | 54 | data_list = glob(os.path.join(DATA_PATH, "raw/*.gz")) 55 | h = "noh" if removeHs else "h" 56 | print(f"Process without hydrogens: {removeHs}") 57 | 58 | save_path = os.path.join(DATA_PATH, f"database_{h}") 59 | if os.path.exists(save_path): 60 | print("FYI: Output directory has been created already.") 61 | chunked_list = list(chunks(data_list, chunk_size)) 62 | chunked_list = [list(chunks(l, 
subchunk)) for l in chunked_list] 63 | 64 | print(f"Total number of molecules {len(data_list)}.") 65 | print(f"Processing {len(chunked_list)} chunks each of size {chunk_size}.") 66 | 67 | env = lmdb.open(str(save_path), map_size=int(1e13)) 68 | global_id = 0 69 | with env.begin(write=True) as txn: 70 | for chunklist in tqdm(chunked_list, total=len(chunked_list), desc="Chunks"): 71 | chunkresult = [] 72 | for datachunk in tqdm(chunklist, total=len(chunklist), desc="Datachunks"): 73 | removeHs_list = [removeHs] * len(datachunk) 74 | with mp.Pool(processes=processes) as pool: 75 | res = pool.starmap( 76 | func=db_sample_helper, iterable=zip(datachunk, removeHs_list) 77 | ) 78 | res = [r for r in res if r is not None] 79 | chunkresult.append(res) 80 | 81 | confs_sub = [] 82 | smiles_sub = [] 83 | for cr in chunkresult: 84 | subconfs = [a["confs"] for a in cr] 85 | subconfs = [item for sublist in subconfs for item in sublist] 86 | subsmiles = [a["smiles"] for a in cr] 87 | subsmiles = [item for sublist in subsmiles for item in sublist] 88 | confs_sub.append(subconfs) 89 | smiles_sub.append(subsmiles) 90 | 91 | confs_sub = [item for sublist in confs_sub for item in sublist] 92 | smiles_sub = [item for sublist in smiles_sub for item in sublist] 93 | 94 | assert len(confs_sub) == len(smiles_sub) 95 | # save 96 | for conf in confs_sub: 97 | result = txn.put(str(global_id).encode(), conf, overwrite=False) 98 | if not result: 99 | raise RuntimeError( 100 | f"LMDB entry {global_id} in {str(save_path)} " "already exists" 101 | ) 102 | global_id += 1 103 | 104 | print(f"{global_id} molecules have been processed!") 105 | print("Finished!") 106 | 107 | 108 | def db_sample_helper(file, removeHs=False): 109 | saved_confs_list = [] 110 | smiles_list = [] 111 | 112 | inf = gzip.open(file) 113 | with Chem.ForwardSDMolSupplier(inf, removeHs=removeHs) as gzsuppl: 114 | molecules = [x for x in gzsuppl if x is not None] 115 | for mol in molecules: 116 | try: 117 | smiles = Chem.MolToSmiles(mol) 118 | data = dataset_utils.mol_to_torch_geometric( 119 | mol, FULL_ATOM_ENCODER, smiles, remove_hydrogens=removeHs 120 | ) 121 | if data.pos.shape[0] != data.x.shape[0]: 122 | continue 123 | if data.pos.ndim != 2: 124 | continue 125 | if len(data.pos) < 2: 126 | continue 127 | except: 128 | continue 129 | # create binary object to be saved 130 | buf = io.BytesIO() 131 | saves = { 132 | "mol": mol, 133 | "data": data, 134 | } 135 | with gzip.GzipFile(fileobj=buf, mode="wb", compresslevel=6) as f: 136 | f.write(pickle.dumps(saves)) 137 | compressed = buf.getvalue() 138 | saved_confs_list.append(compressed) 139 | smiles_list.append(smiles) 140 | return { 141 | "confs": saved_confs_list, 142 | "smiles": smiles_list, 143 | } 144 | 145 | 146 | if __name__ == "__main__": 147 | process_files(removeHs=False) 148 | # process_files(removeHs=True) 149 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/pubchem/pubchem_dataset_adaptive.py: -------------------------------------------------------------------------------- 1 | from rdkit import RDLogger 2 | import torch 3 | import numpy as np 4 | from os.path import join 5 | from torch_geometric.data import Dataset, DataLoader 6 | from experiments.data.utils import load_pickle, make_splits 7 | from torch.utils.data import Subset 8 | import lmdb 9 | import pickle 10 | import gzip 11 | import io 12 | from pytorch_lightning import LightningDataModule 13 | import os 14 | import experiments.data.utils as dataset_utils 15 | from 
experiments.data.abstract_dataset import ( 16 | AbstractAdaptiveDataModule, 17 | ) 18 | 19 | GEOM_DATADIR = "----" 20 | 21 | full_atom_encoder = { 22 | "H": 0, 23 | "B": 1, 24 | "C": 2, 25 | "N": 3, 26 | "O": 4, 27 | "F": 5, 28 | "Al": 6, 29 | "Si": 7, 30 | "P": 8, 31 | "S": 9, 32 | "Cl": 10, 33 | "As": 11, 34 | "Br": 12, 35 | "I": 13, 36 | "Hg": 14, 37 | "Bi": 15, 38 | } 39 | 40 | 41 | class PubChemLMDBDataset(Dataset): 42 | def __init__( 43 | self, 44 | root: str, 45 | remove_hs: bool, 46 | evaluation: bool = False, 47 | ): 48 | """ 49 | Constructor 50 | """ 51 | self.data_file = root 52 | self._num_graphs = 95173200 if remove_hs else 95173300 # 94980241 for noH!! 53 | 54 | if remove_hs: 55 | assert "_noh" in root 56 | self.stats_dir = ( 57 | "/scratch1/cremej01/data/pubchem/dataset_noh/processed" 58 | if not evaluation 59 | else GEOM_DATADIR 60 | ) 61 | else: 62 | assert "_h" in root 63 | self.stats_dir = ( 64 | "/scratch1/cremej01/data/pubchem/dataset_h/processed" 65 | if not evaluation 66 | else GEOM_DATADIR 67 | ) 68 | 69 | super().__init__(root) 70 | 71 | self.remove_hs = remove_hs 72 | self.statistics = dataset_utils.Statistics( 73 | num_nodes=load_pickle( 74 | os.path.join(self.stats_dir, self.processed_files[0]) 75 | ), 76 | atom_types=torch.from_numpy( 77 | np.load(os.path.join(self.stats_dir, self.processed_files[1])) 78 | ), 79 | bond_types=torch.from_numpy( 80 | np.load(os.path.join(self.stats_dir, self.processed_files[2])) 81 | ), 82 | charge_types=torch.from_numpy( 83 | np.load(os.path.join(self.stats_dir, self.processed_files[3])) 84 | ), 85 | valencies=load_pickle( 86 | os.path.join(self.stats_dir, self.processed_files[4]) 87 | ), 88 | bond_lengths=load_pickle( 89 | os.path.join(self.stats_dir, self.processed_files[5]) 90 | ), 91 | bond_angles=torch.from_numpy( 92 | np.load(os.path.join(self.stats_dir, self.processed_files[6])) 93 | ), 94 | is_aromatic=torch.from_numpy( 95 | np.load(os.path.join(self.stats_dir, self.processed_files[7])) 96 | ).float(), 97 | is_in_ring=torch.from_numpy( 98 | np.load(os.path.join(self.stats_dir, self.processed_files[8])) 99 | ).float(), 100 | hybridization=torch.from_numpy( 101 | np.load(os.path.join(self.stats_dir, self.processed_files[9])) 102 | ).float(), 103 | ) 104 | self.smiles = load_pickle( 105 | os.path.join(self.stats_dir, self.processed_files[10]) 106 | ) 107 | 108 | def _init_db(self): 109 | self._env = lmdb.open( 110 | str(self.data_file), 111 | readonly=True, 112 | lock=False, 113 | readahead=False, 114 | meminit=False, 115 | create=False, 116 | ) 117 | 118 | def get(self, index: int): 119 | self._init_db() 120 | 121 | with self._env.begin(write=False) as txn: 122 | compressed = txn.get(str(index).encode()) 123 | buf = io.BytesIO(compressed) 124 | with gzip.GzipFile(fileobj=buf, mode="rb") as f: 125 | serialized = f.read() 126 | try: 127 | item = pickle.loads(serialized)["data"] 128 | except: 129 | return None 130 | 131 | return item 132 | 133 | def len(self) -> int: 134 | r"""Returns the number of graphs stored in the dataset.""" 135 | return self._num_graphs 136 | 137 | def __len__(self) -> int: 138 | return self._num_graphs 139 | 140 | @property 141 | def processed_files(self): 142 | h = "noh" if self.remove_hs else "h" 143 | return [ 144 | f"train_n_{h}.pickle", 145 | f"train_atom_types_{h}.npy", 146 | f"train_bond_types_{h}.npy", 147 | f"train_charges_{h}.npy", 148 | f"train_valency_{h}.pickle", 149 | f"train_bond_lengths_{h}.pickle", 150 | f"train_angles_{h}.npy", 151 | f"train_is_aromatic_{h}.npy", 152 | 
f"train_is_in_ring_{h}.npy", 153 | f"train_hybridization_{h}.npy", 154 | "train_smiles.pickle", 155 | ] 156 | 157 | 158 | class PubChemDataModule(AbstractAdaptiveDataModule): 159 | def __init__(self, hparams, evaluation=False): 160 | self.save_hyperparameters(hparams) 161 | self.datadir = hparams.dataset_root 162 | self.pin_memory = True 163 | 164 | self.remove_hs = hparams.remove_hs 165 | if self.remove_hs: 166 | print("Pre-Training on dataset with implicit hydrogens") 167 | self.dataset = PubChemLMDBDataset( 168 | root=self.datadir, remove_hs=self.remove_hs, evaluation=evaluation 169 | ) 170 | 171 | self.train_smiles = self.dataset.smiles 172 | 173 | self.idx_train, self.idx_val, self.idx_test = make_splits( 174 | len(self.dataset), 175 | train_size=hparams.train_size, 176 | val_size=hparams.val_size, 177 | test_size=hparams.test_size, 178 | seed=hparams.seed, 179 | filename=join(self.hparams["save_dir"], "splits.npz"), 180 | splits=None, 181 | ) 182 | print( 183 | f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}" 184 | ) 185 | train_dataset = Subset(self.dataset, self.idx_train) 186 | val_dataset = Subset(self.dataset, self.idx_val) 187 | test_dataset = Subset(self.dataset, self.idx_test) 188 | 189 | self.statistics = { 190 | "train": self.dataset.statistics, 191 | "val": self.dataset.statistics, 192 | "test": self.dataset.statistics, 193 | } 194 | 195 | super().__init__(hparams, train_dataset, val_dataset, test_dataset) 196 | 197 | def _train_dataloader(self, shuffle=True): 198 | dataloader = DataLoader( 199 | dataset=self.train_dataset, 200 | batch_size=self.cfg.batch_size, 201 | num_workers=self.cfg.num_workers, 202 | pin_memory=self.pin_memory, 203 | shuffle=shuffle, 204 | persistent_workers=False, 205 | ) 206 | return dataloader 207 | 208 | def _val_dataloader(self, shuffle=False): 209 | dataloader = DataLoader( 210 | dataset=self.val_dataset, 211 | batch_size=self.cfg.batch_size, 212 | num_workers=self.cfg.num_workers, 213 | pin_memory=self.pin_memory, 214 | shuffle=shuffle, 215 | persistent_workers=False, 216 | ) 217 | return dataloader 218 | 219 | def _test_dataloader(self, shuffle=False): 220 | dataloader = DataLoader( 221 | dataset=self.test_dataset, 222 | batch_size=self.cfg.batch_size, 223 | num_workers=self.cfg.num_workers, 224 | pin_memory=self.pin_memory, 225 | shuffle=shuffle, 226 | persistent_workers=False, 227 | ) 228 | return dataloader 229 | 230 | def compute_mean_mad(self, properties_list): 231 | if self.cfg.dataset == "qm9" or self.cfg.dataset == "drugs": 232 | dataloader = self.get_dataloader(self.train_dataset, "val") 233 | return self.compute_mean_mad_from_dataloader(dataloader, properties_list) 234 | elif self.cfg.dataset == "qm9_1half" or self.cfg.dataset == "qm9_2half": 235 | dataloader = self.get_dataloader(self.val_dataset, "val") 236 | return self.compute_mean_mad_from_dataloader(dataloader, properties_list) 237 | else: 238 | raise Exception("Wrong dataset name") 239 | 240 | def compute_mean_mad_from_dataloader(self, dataloader, properties_list): 241 | property_norms = {} 242 | for property_key in properties_list: 243 | try: 244 | property_name = property_key + "_mm" 245 | values = getattr(dataloader.dataset[:], property_name) 246 | except: 247 | property_name = property_key 248 | idx = dataloader.dataset[:].label2idx[property_name] 249 | values = torch.tensor( 250 | [data.y[:, idx] for data in dataloader.dataset[:]] 251 | ) 252 | 253 | mean = torch.mean(values) 254 | ma = torch.abs(values - mean) 255 | mad = torch.mean(ma) 
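            # mad: mean absolute deviation of the property values around their mean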
256 | property_norms[property_key] = {} 257 | property_norms[property_key]["mean"] = mean 258 | property_norms[property_key]["mad"] = mad 259 | del values 260 | return property_norms 261 | 262 | def get_dataloader(self, dataset, stage): 263 | if stage == "train": 264 | batch_size = self.cfg.batch_size 265 | shuffle = True 266 | elif stage in ["val", "test"]: 267 | batch_size = self.cfg.inference_batch_size 268 | shuffle = False 269 | 270 | dl = DataLoader( 271 | dataset=dataset, 272 | batch_size=batch_size, 273 | num_workers=self.cfg.num_workers, 274 | pin_memory=True, 275 | shuffle=shuffle, 276 | ) 277 | 278 | return dl 279 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/data/pubchem/pubchem_dataset_nonadaptive.py: -------------------------------------------------------------------------------- 1 | from rdkit import RDLogger 2 | import torch 3 | import numpy as np 4 | from os.path import join 5 | from torch_geometric.data import Dataset, DataLoader 6 | from experiments.data.utils import load_pickle, make_splits 7 | from torch.utils.data import Subset 8 | import lmdb 9 | import pickle 10 | import gzip 11 | import io 12 | from pytorch_lightning import LightningDataModule 13 | import os 14 | import experiments.data.utils as dataset_utils 15 | 16 | full_atom_encoder = { 17 | "H": 0, 18 | "B": 1, 19 | "C": 2, 20 | "N": 3, 21 | "O": 4, 22 | "F": 5, 23 | "Al": 6, 24 | "Si": 7, 25 | "P": 8, 26 | "S": 9, 27 | "Cl": 10, 28 | "As": 11, 29 | "Br": 12, 30 | "I": 13, 31 | "Hg": 14, 32 | "Bi": 15, 33 | } 34 | GEOM_DATADIR = "----" 35 | 36 | 37 | class PubChemLMDBDataset(Dataset): 38 | def __init__( 39 | self, 40 | root: str, 41 | remove_hs: bool, 42 | evaluation: bool = False, 43 | ): 44 | """ 45 | Constructor 46 | """ 47 | self.data_file = root 48 | self._num_graphs = 95173200 if remove_hs else 95173300 # 94980241 for noH!! 
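Because LMDB stores opaque byte records, the dataset length is hard-coded above rather than queried from the database. The `get()` method further down expects each record to be a gzip-compressed pickle of `{"data": <payload>}` keyed by the stringified integer index. A hedged round-trip sketch of that layout; the path, map size and payload are placeholders, and the actual preprocessing in `experiments/data/pubchem/process_pubchem.py` may write records differently:

```python
import gzip
import io
import pickle

import lmdb

env = lmdb.open("example_lmdb", map_size=2**30)  # hypothetical database, ~1 GiB map
with env.begin(write=True) as txn:
    buf = io.BytesIO()
    with gzip.GzipFile(fileobj=buf, mode="wb") as f:
        f.write(pickle.dumps({"data": {"smiles": "CCO"}}))  # placeholder payload
    txn.put(str(0).encode(), buf.getvalue())  # key convention matching get()

with env.begin(write=False) as txn:
    compressed = txn.get(str(0).encode())
    with gzip.GzipFile(fileobj=io.BytesIO(compressed), mode="rb") as f:
        item = pickle.loads(f.read())["data"]
assert item["smiles"] == "CCO"
```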
49 | if remove_hs: 50 | assert "_noh" in root 51 | self.stats_dir = "----" 52 | else: 53 | assert "_h" in root 54 | self.stats_dir = "----" 55 | super().__init__(root) 56 | 57 | self.remove_hs = remove_hs 58 | self.statistics = dataset_utils.Statistics( 59 | num_nodes=load_pickle( 60 | os.path.join(self.stats_dir, self.processed_files[0]) 61 | ), 62 | atom_types=torch.from_numpy( 63 | np.load(os.path.join(self.stats_dir, self.processed_files[1])) 64 | ), 65 | bond_types=torch.from_numpy( 66 | np.load(os.path.join(self.stats_dir, self.processed_files[2])) 67 | ), 68 | charge_types=torch.from_numpy( 69 | np.load(os.path.join(self.stats_dir, self.processed_files[3])) 70 | ), 71 | valencies=None #load_pickle( 72 | #os.path.join(self.stats_dir, self.processed_files[4]) 73 | #), 74 | , 75 | bond_lengths=load_pickle( 76 | os.path.join(self.stats_dir, self.processed_files[5]) 77 | ), 78 | bond_angles=torch.from_numpy( 79 | np.load(os.path.join(self.stats_dir, self.processed_files[6])) 80 | ), 81 | is_aromatic=torch.from_numpy( 82 | np.load(os.path.join(self.stats_dir, self.processed_files[7])) 83 | ).float(), 84 | is_in_ring=torch.from_numpy( 85 | np.load(os.path.join(self.stats_dir, self.processed_files[8])) 86 | ).float(), 87 | hybridization=torch.from_numpy( 88 | np.load(os.path.join(self.stats_dir, self.processed_files[9])) 89 | ).float(), 90 | ) 91 | self.smiles = None # load_pickle( 92 | #os.path.join(self.stats_dir, self.processed_files[10]) 93 | #) 94 | 95 | def _init_db(self): 96 | self._env = lmdb.open( 97 | str(self.data_file), 98 | readonly=True, 99 | lock=False, 100 | readahead=False, 101 | meminit=False, 102 | create=False, 103 | ) 104 | 105 | def get(self, index: int): 106 | self._init_db() 107 | 108 | with self._env.begin(write=False) as txn: 109 | compressed = txn.get(str(index).encode()) 110 | buf = io.BytesIO(compressed) 111 | with gzip.GzipFile(fileobj=buf, mode="rb") as f: 112 | serialized = f.read() 113 | try: 114 | item = pickle.loads(serialized)["data"] 115 | except: 116 | return None 117 | 118 | return item 119 | 120 | def len(self) -> int: 121 | r"""Returns the number of graphs stored in the dataset.""" 122 | return self._num_graphs 123 | 124 | def __len__(self) -> int: 125 | return self._num_graphs 126 | 127 | @property 128 | def processed_files(self): 129 | h = "noh" if self.remove_hs else "h" 130 | return [ 131 | f"train_n_{h}.pickle", 132 | f"train_atom_types_{h}.npy", 133 | f"train_bond_types_{h}.npy", 134 | f"train_charges_{h}.npy", 135 | f"train_valency_{h}.pickle", 136 | f"train_bond_lengths_{h}.pickle", 137 | f"train_angles_{h}.npy", 138 | f"train_is_aromatic_{h}.npy", 139 | f"train_is_in_ring_{h}.npy", 140 | f"train_hybridization_{h}.npy", 141 | "train_smiles.pickle", 142 | ] 143 | 144 | 145 | class PubChemDataModule(LightningDataModule): 146 | def __init__(self, hparams, evaluation=False): 147 | super(PubChemDataModule, self).__init__() 148 | self.save_hyperparameters(hparams) 149 | self.datadir = hparams.dataset_root 150 | self.pin_memory = True 151 | 152 | self.remove_hs = hparams.remove_hs 153 | if self.remove_hs: 154 | print("Pre-Training on dataset with implicit hydrogens") 155 | self.dataset = PubChemLMDBDataset( 156 | root=self.datadir, remove_hs=self.remove_hs, evaluation=evaluation 157 | ) 158 | 159 | self.train_smiles = self.dataset.smiles 160 | 161 | self.idx_train, self.idx_val, self.idx_test = make_splits( 162 | len(self.dataset), 163 | train_size=hparams.train_size, 164 | val_size=hparams.val_size, 165 | test_size=hparams.test_size, 166 | 
seed=hparams.seed,
167 |             filename=join(self.hparams["save_dir"], "splits.npz"),
168 |             splits=None,
169 |         )
170 |         print(
171 |             f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}"
172 |         )
173 |         self.train_dataset = Subset(self.dataset, self.idx_train)
174 |         self.val_dataset = Subset(self.dataset, self.idx_val)
175 |         self.test_dataset = Subset(self.dataset, self.idx_test)
176 | 
177 |         self.statistics = {
178 |             "train": self.dataset.statistics,
179 |             "val": self.dataset.statistics,
180 |             "test": self.dataset.statistics,
181 |         }
182 | 
183 |     def train_dataloader(self, shuffle=True):
184 |         dataloader = DataLoader(
185 |             dataset=self.train_dataset,
186 |             batch_size=self.hparams.batch_size,
187 |             num_workers=self.hparams.num_workers,
188 |             pin_memory=self.pin_memory,
189 |             shuffle=shuffle,
190 |             persistent_workers=False,
191 |         )
192 |         return dataloader
193 | 
194 |     def val_dataloader(self, shuffle=False):
195 |         dataloader = DataLoader(
196 |             dataset=self.val_dataset,
197 |             batch_size=self.hparams.batch_size,
198 |             num_workers=self.hparams.num_workers,
199 |             pin_memory=self.pin_memory,
200 |             shuffle=shuffle,
201 |             persistent_workers=False,
202 |         )
203 |         return dataloader
204 | 
205 |     def test_dataloader(self, shuffle=False):
206 |         dataloader = DataLoader(
207 |             dataset=self.test_dataset,
208 |             batch_size=self.hparams.batch_size,
209 |             num_workers=self.hparams.num_workers,
210 |             pin_memory=self.pin_memory,
211 |             shuffle=shuffle,
212 |             persistent_workers=False,
213 |         )
214 |         return dataloader
215 | 
216 |     def compute_mean_mad(self, properties_list):
217 |         if self.hparams.dataset == "qm9" or self.hparams.dataset == "drugs":
218 |             dataloader = self.get_dataloader(self.train_dataset, "val")
219 |             return self.compute_mean_mad_from_dataloader(dataloader, properties_list)
220 |         elif self.hparams.dataset == "qm9_1half" or self.hparams.dataset == "qm9_2half":
221 |             dataloader = self.get_dataloader(self.val_dataset, "val")
222 |             return self.compute_mean_mad_from_dataloader(dataloader, properties_list)
223 |         else:
224 |             raise ValueError(f"Unknown dataset {self.hparams.dataset}")
225 | 
226 |     def compute_mean_mad_from_dataloader(self, dataloader, properties_list):
227 |         property_norms = {}
228 |         for property_key in properties_list:
229 |             try:
230 |                 property_name = property_key + "_mm"
231 |                 values = getattr(dataloader.dataset[:], property_name)
232 |             except AttributeError:  # property not stored as "<key>_mm"; read it from data.y
233 |                 property_name = property_key
234 |                 idx = dataloader.dataset[:].label2idx[property_name]
235 |                 values = torch.tensor(
236 |                     [data.y[:, idx] for data in dataloader.dataset[:]]
237 |                 )
238 | 
239 |             mean = torch.mean(values)
240 |             ma = torch.abs(values - mean)
241 |             mad = torch.mean(ma)
242 |             property_norms[property_key] = {}
243 |             property_norms[property_key]["mean"] = mean
244 |             property_norms[property_key]["mad"] = mad
245 |             del values
246 |         return property_norms
247 | 
248 |     def get_dataloader(self, dataset, stage):
249 |         if stage == "train":
250 |             batch_size = self.hparams.batch_size
251 |             shuffle = True
252 |         elif stage in ["val", "test"]:
253 |             batch_size = self.hparams.inference_batch_size
254 |             shuffle = False
255 | 
256 |         dl = DataLoader(
257 |             dataset=dataset,
258 |             batch_size=batch_size,
259 |             num_workers=self.hparams.num_workers,
260 |             pin_memory=True,
261 |             shuffle=shuffle,
262 |         )
263 | 
264 |         return dl
265 | 
-------------------------------------------------------------------------------- /eqgat_diff/experiments/data/qm9/qm9_dataset.py: --------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from typing
import Any, Sequence 4 | from torch.utils.data import Subset 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from experiments.data.abstract_dataset import AbstractDataModule 9 | from experiments.data.metrics import compute_all_statistics 10 | from experiments.data.utils import ( 11 | Statistics, 12 | load_pickle, 13 | mol_to_torch_geometric, 14 | remove_hydrogens, 15 | save_pickle, 16 | train_subset, 17 | ) 18 | from os.path import join 19 | 20 | from rdkit import Chem, RDLogger 21 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 22 | from tqdm import tqdm 23 | from torch_geometric.data import DataLoader 24 | 25 | 26 | def files_exist(files) -> bool: 27 | # NOTE: We return `False` in case `files` is empty, leading to a 28 | # re-processing of files on every instantiation. 29 | return len(files) != 0 and all([osp.exists(f) for f in files]) 30 | 31 | 32 | def to_list(value: Any) -> Sequence: 33 | if isinstance(value, Sequence) and not isinstance(value, str): 34 | return value 35 | else: 36 | return [value] 37 | 38 | 39 | full_atom_encoder = { 40 | "H": 0, 41 | "B": 1, 42 | "C": 2, 43 | "N": 3, 44 | "O": 4, 45 | "F": 5, 46 | "Al": 6, 47 | "Si": 7, 48 | "P": 8, 49 | "S": 9, 50 | "Cl": 10, 51 | "As": 11, 52 | "Br": 12, 53 | "I": 13, 54 | "Hg": 14, 55 | "Bi": 15, 56 | } 57 | 58 | 59 | class QM9Dataset(InMemoryDataset): 60 | raw_url = ( 61 | "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/" 62 | "molnet_publish/qm9.zip" 63 | ) 64 | raw_url2 = "https://ndownloader.figshare.com/files/3195404" 65 | processed_url = "https://data.pyg.org/datasets/qm9_v3.zip" 66 | 67 | def __init__( 68 | self, 69 | split, 70 | root, 71 | remove_h: bool, 72 | transform=None, 73 | pre_transform=None, 74 | pre_filter=None, 75 | only_stats=False 76 | ): 77 | self.split = split 78 | if self.split == "train": 79 | self.file_idx = 0 80 | elif self.split == "val": 81 | self.file_idx = 1 82 | else: 83 | self.file_idx = 2 84 | self.remove_h = remove_h 85 | 86 | self.atom_encoder = full_atom_encoder 87 | if remove_h: 88 | self.atom_encoder = { 89 | k: v - 1 for k, v in self.atom_encoder.items() if k != "H" 90 | } 91 | 92 | super().__init__(root, transform, pre_transform, pre_filter) 93 | if not only_stats: 94 | self.data, self.slices = torch.load(self.processed_paths[0]) 95 | else: 96 | self.data, self.slices = None, None 97 | 98 | self.statistics = Statistics( 99 | num_nodes=load_pickle(self.processed_paths[1]), 100 | atom_types=torch.from_numpy(np.load(self.processed_paths[2])).float(), 101 | bond_types=torch.from_numpy(np.load(self.processed_paths[3])).float(), 102 | charge_types=torch.from_numpy(np.load(self.processed_paths[4])).float(), 103 | valencies=load_pickle(self.processed_paths[5]), 104 | bond_lengths=load_pickle(self.processed_paths[6]), 105 | bond_angles=torch.from_numpy(np.load(self.processed_paths[7])).float(), 106 | is_aromatic=torch.from_numpy(np.load(self.processed_paths[9])).float(), 107 | is_in_ring=torch.from_numpy(np.load(self.processed_paths[10])).float(), 108 | hybridization=torch.from_numpy(np.load(self.processed_paths[11])).float(), 109 | ) 110 | self.smiles = load_pickle(self.processed_paths[8]) 111 | 112 | @property 113 | def raw_file_names(self): 114 | return ["gdb9.sdf", "gdb9.sdf.csv", "uncharacterized.txt"] 115 | 116 | @property 117 | def split_file_name(self): 118 | return ["train.csv", "val.csv", "test.csv"] 119 | 120 | @property 121 | def split_paths(self): 122 | r"""The absolute filepaths that must be present in order to skip 123 | 
splitting.""" 124 | files = to_list(self.split_file_name) 125 | return [osp.join(self.raw_dir, f) for f in files] 126 | 127 | @property 128 | def processed_file_names(self): 129 | h = "noh" if self.remove_h else "h" 130 | if self.split == "train": 131 | return [ 132 | f"train_{h}.pt", 133 | f"train_n_{h}.pickle", 134 | f"train_atom_types_{h}.npy", 135 | f"train_bond_types_{h}.npy", 136 | f"train_charges_{h}.npy", 137 | f"train_valency_{h}.pickle", 138 | f"train_bond_lengths_{h}.pickle", 139 | f"train_angles_{h}.npy", 140 | "train_smiles.pickle", 141 | f"train_is_aromatic_{h}.npy", 142 | f"train_is_in_ring_{h}.npy", 143 | f"train_hybridization_{h}.npy", 144 | ] 145 | elif self.split == "val": 146 | return [ 147 | f"val_{h}.pt", 148 | f"val_n_{h}.pickle", 149 | f"val_atom_types_{h}.npy", 150 | f"val_bond_types_{h}.npy", 151 | f"val_charges_{h}.npy", 152 | f"val_valency_{h}.pickle", 153 | f"val_bond_lengths_{h}.pickle", 154 | f"val_angles_{h}.npy", 155 | "val_smiles.pickle", 156 | f"val_is_aromatic_{h}.npy", 157 | f"val_is_in_ring_{h}.npy", 158 | f"val_hybridization_{h}.npy", 159 | ] 160 | else: 161 | return [ 162 | f"test_{h}.pt", 163 | f"test_n_{h}.pickle", 164 | f"test_atom_types_{h}.npy", 165 | f"test_bond_types_{h}.npy", 166 | f"test_charges_{h}.npy", 167 | f"test_valency_{h}.pickle", 168 | f"test_bond_lengths_{h}.pickle", 169 | f"test_angles_{h}.npy", 170 | "test_smiles.pickle", 171 | f"test_is_aromatic_{h}.npy", 172 | f"test_is_in_ring_{h}.npy", 173 | f"test_hybridization_{h}.npy", 174 | ] 175 | 176 | def download(self): 177 | """ 178 | Download raw qm9 files. Taken from PyG QM9 class 179 | """ 180 | try: 181 | import rdkit # noqa 182 | 183 | file_path = download_url(self.raw_url, self.raw_dir) 184 | extract_zip(file_path, self.raw_dir) 185 | os.unlink(file_path) 186 | 187 | file_path = download_url(self.raw_url2, self.raw_dir) 188 | os.rename( 189 | osp.join(self.raw_dir, "3195404"), 190 | osp.join(self.raw_dir, "uncharacterized.txt"), 191 | ) 192 | except ImportError: 193 | path = download_url(self.processed_url, self.raw_dir) 194 | extract_zip(path, self.raw_dir) 195 | os.unlink(path) 196 | 197 | if files_exist(self.split_paths): 198 | return 199 | 200 | dataset = pd.read_csv(self.raw_paths[1]) 201 | 202 | n_samples = len(dataset) 203 | n_train = 100000 204 | n_test = int(0.1 * n_samples) 205 | n_val = n_samples - (n_train + n_test) 206 | 207 | # Shuffle dataset with df.sample, then split 208 | train, val, test = np.split( 209 | dataset.sample(frac=1, random_state=42), [n_train, n_val + n_train] 210 | ) 211 | 212 | train.to_csv(os.path.join(self.raw_dir, "train.csv")) 213 | val.to_csv(os.path.join(self.raw_dir, "val.csv")) 214 | test.to_csv(os.path.join(self.raw_dir, "test.csv")) 215 | 216 | def process(self): 217 | RDLogger.DisableLog("rdApp.*") 218 | 219 | target_df = pd.read_csv(self.split_paths[self.file_idx], index_col=0) 220 | target_df.drop(columns=["mol_id"], inplace=True) 221 | 222 | with open(self.raw_paths[-1]) as f: 223 | skip = [int(x.split()[0]) - 1 for x in f.read().split("\n")[9:-2]] 224 | 225 | suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False) 226 | data_list = [] 227 | all_smiles = [] 228 | num_errors = 0 229 | for i, mol in enumerate(tqdm(suppl)): 230 | if i in skip or i not in target_df.index: 231 | continue 232 | smiles = Chem.MolToSmiles(mol, isomericSmiles=False) 233 | if smiles is None: 234 | num_errors += 1 235 | else: 236 | all_smiles.append(smiles) 237 | 238 | data = mol_to_torch_geometric(mol, full_atom_encoder, smiles) 239 | 
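For intuition, this is roughly the atom-type lookup that `mol_to_torch_geometric` performs with the module-level `full_atom_encoder` dict defined above (a sketch of just that one step; the real helper in `experiments/data/utils.py` presumably also builds coordinate, bond and charge tensors):

```python
import torch
from rdkit import Chem

mol = Chem.AddHs(Chem.MolFromSmiles("CO"))  # methanol, hydrogens made explicit
atom_types = torch.tensor(
    [full_atom_encoder[atom.GetSymbol()] for atom in mol.GetAtoms()],
    dtype=torch.long,
)
assert atom_types.tolist() == [2, 4, 0, 0, 0, 0]  # C, O, then four H atoms
```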
if self.remove_h: 240 | data = remove_hydrogens(data) 241 | 242 | if self.pre_filter is not None and not self.pre_filter(data): 243 | continue 244 | if self.pre_transform is not None: 245 | data = self.pre_transform(data) 246 | 247 | data_list.append(data) 248 | torch.save(self.collate(data_list), self.processed_paths[self.file_idx]) 249 | 250 | statistics = compute_all_statistics( 251 | data_list, 252 | self.atom_encoder, 253 | charges_dic={-2: 0, -1: 1, 0: 2, 1: 3, 2: 4, 3: 5}, 254 | additional_feats=True, 255 | ) 256 | 257 | save_pickle(statistics.num_nodes, self.processed_paths[1]) 258 | np.save(self.processed_paths[2], statistics.atom_types) 259 | np.save(self.processed_paths[3], statistics.bond_types) 260 | np.save(self.processed_paths[4], statistics.charge_types) 261 | save_pickle(statistics.valencies, self.processed_paths[5]) 262 | save_pickle(statistics.bond_lengths, self.processed_paths[6]) 263 | np.save(self.processed_paths[7], statistics.bond_angles) 264 | 265 | np.save(self.processed_paths[9], statistics.is_aromatic) 266 | np.save(self.processed_paths[10], statistics.is_in_ring) 267 | np.save(self.processed_paths[11], statistics.hybridization) 268 | 269 | print("Number of molecules that could not be mapped to smiles: ", num_errors) 270 | save_pickle(set(all_smiles), self.processed_paths[8]) 271 | torch.save(self.collate(data_list), self.processed_paths[0]) 272 | 273 | 274 | class QM9DataModule(AbstractDataModule): 275 | def __init__(self, cfg, only_stats: bool = False): 276 | self.datadir = cfg.dataset_root 277 | root_path = self.datadir 278 | 279 | train_dataset = QM9Dataset( 280 | split="train", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats 281 | ) 282 | val_dataset = QM9Dataset(split="val", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats) 283 | test_dataset = QM9Dataset(split="test", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats) 284 | 285 | self.statistics = { 286 | "train": train_dataset.statistics, 287 | "val": val_dataset.statistics, 288 | "test": test_dataset.statistics, 289 | } 290 | if not only_stats: 291 | if cfg.select_train_subset: 292 | self.idx_train = train_subset( 293 | dset_len=len(train_dataset), 294 | train_size=cfg.train_size, 295 | seed=cfg.seed, 296 | filename=join(cfg.save_dir, "splits.npz"), 297 | ) 298 | self.train_smiles = train_dataset.smiles 299 | train_dataset = Subset(train_dataset, self.idx_train) 300 | 301 | self.remove_h = cfg.remove_hs 302 | super().__init__( 303 | cfg, 304 | train_dataset=train_dataset, 305 | val_dataset=val_dataset, 306 | test_dataset=test_dataset, 307 | ) 308 | 309 | def get_dataloader(self, dataset, stage): 310 | if stage == "train": 311 | batch_size = self.cfg.batch_size 312 | shuffle = True 313 | elif stage in ["val", "test"]: 314 | batch_size = self.cfg.inference_batch_size 315 | shuffle = False 316 | 317 | dl = DataLoader( 318 | dataset=dataset, 319 | batch_size=batch_size, 320 | num_workers=self.cfg.num_workers, 321 | pin_memory=True, 322 | shuffle=shuffle, 323 | ) 324 | 325 | return dl -------------------------------------------------------------------------------- /eqgat_diff/experiments/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/diffusion/__init__.py -------------------------------------------------------------------------------- /eqgat_diff/experiments/diffusion/utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List, Optional 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from rdkit.Chem import RDConfig 8 | from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol 9 | from torch_geometric.utils import remove_self_loops, sort_edge_index 10 | from torch_scatter import scatter_mean 11 | 12 | from experiments.utils import get_edges, zero_mean 13 | 14 | sys.path.append(os.path.join(RDConfig.RDContribDir, "IFG")) 15 | from ifg import identify_functional_groups 16 | 17 | 18 | def initialize_edge_attrs_reverse( 19 | edge_index_global, n, bonds_prior, num_bond_classes, device 20 | ): 21 | # edge types for FC graph 22 | j, i = edge_index_global 23 | mask = j < i 24 | mask_i = i[mask] 25 | mask_j = j[mask] 26 | nE = len(mask_i) 27 | edge_attr_triu = torch.multinomial(bonds_prior, num_samples=nE, replacement=True) 28 | 29 | j = torch.concat([mask_j, mask_i]) 30 | i = torch.concat([mask_i, mask_j]) 31 | edge_index_global = torch.stack([j, i], dim=0) 32 | edge_attr_global = torch.concat([edge_attr_triu, edge_attr_triu], dim=0) 33 | edge_index_global, edge_attr_global = sort_edge_index( 34 | edge_index=edge_index_global, edge_attr=edge_attr_global, sort_by_row=False 35 | ) 36 | j, i = edge_index_global 37 | mask = j < i 38 | mask_i = i[mask] 39 | mask_j = j[mask] 40 | 41 | # some assert 42 | 43 | edge_attr_global_dense = torch.zeros(size=(n, n), device=device, dtype=torch.long) 44 | edge_attr_global_dense[ 45 | edge_index_global[0], edge_index_global[1] 46 | ] = edge_attr_global 47 | assert (edge_attr_global_dense - edge_attr_global_dense.T).sum().float() == 0.0 48 | 49 | edge_attr_global = F.one_hot(edge_attr_global, num_bond_classes).float() 50 | 51 | return edge_attr_global, edge_index_global, mask, mask_i 52 | 53 | 54 | def get_joint_edge_attrs( 55 | pos, 56 | pos_pocket, 57 | batch, 58 | batch_pocket, 59 | edge_attr_global_lig, 60 | num_bond_classes, 61 | device, 62 | ): 63 | edge_index_global = get_edges( 64 | batch, batch_pocket, pos, pos_pocket, cutoff_p=5, cutoff_lp=5 65 | ) 66 | edge_index_global = sort_edge_index(edge_index=edge_index_global, sort_by_row=False) 67 | edge_index_global, _ = remove_self_loops(edge_index_global) 68 | edge_attr_global = torch.zeros( 69 | (edge_index_global.size(1), num_bond_classes), 70 | dtype=torch.float32, 71 | device=device, 72 | ) 73 | edge_mask = (edge_index_global[0] < len(batch)) & ( 74 | edge_index_global[1] < len(batch) 75 | ) 76 | edge_mask_pocket = (edge_index_global[0] >= len(batch)) & ( 77 | edge_index_global[1] >= len(batch) 78 | ) 79 | edge_attr_global[edge_mask] = edge_attr_global_lig 80 | 81 | if num_bond_classes == 7: 82 | edge_mask_ligand_pocket = (edge_index_global[0] < len(batch)) & ( 83 | edge_index_global[1] >= len(batch) 84 | ) 85 | edge_mask_pocket_ligand = (edge_index_global[0] >= len(batch)) & ( 86 | edge_index_global[1] < len(batch) 87 | ) 88 | edge_attr_global[edge_mask_pocket] = ( 89 | torch.tensor([0, 0, 0, 0, 0, 0, 1]).float().to(edge_attr_global.device) 90 | ) 91 | edge_attr_global[edge_mask_ligand_pocket] = ( 92 | torch.tensor([0, 0, 0, 0, 0, 1, 0]).float().to(edge_attr_global.device) 93 | ) 94 | edge_attr_global[edge_mask_pocket_ligand] = ( 95 | torch.tensor([0, 0, 0, 0, 0, 1, 0]).float().to(edge_attr_global.device) 96 | ) 97 | else: 98 | edge_attr_global[edge_mask_pocket] = ( 99 | torch.tensor([0, 0, 0, 0, 1]).float().to(edge_attr_global.device) 100 | ) 101 | # 
edge_attr_global[edge_mask_pocket] = 0.0 102 | 103 | batch_full = torch.cat([batch, batch_pocket]) 104 | batch_edge_global = batch_full[edge_index_global[0]] # 105 | 106 | return ( 107 | edge_index_global, 108 | edge_attr_global, 109 | batch_edge_global, 110 | edge_mask, 111 | ) 112 | 113 | 114 | def bond_guidance( 115 | pos, 116 | node_feats_in, 117 | temb, 118 | bond_model, 119 | batch, 120 | batch_edge_global, 121 | edge_attr_global, 122 | edge_index_local, 123 | edge_index_global, 124 | ): 125 | guidance_type = "logsum" 126 | guidance_scale = 1.0e-4 127 | 128 | bs = len(batch.bincount()) 129 | with torch.enable_grad(): 130 | node_feats_in = node_feats_in.detach() 131 | pos = pos.detach().requires_grad_(True) 132 | bond_prediction = bond_model( 133 | x=node_feats_in, 134 | t=temb, 135 | pos=pos, 136 | edge_index_local=edge_index_local, 137 | edge_index_global=edge_index_global, 138 | edge_attr_global=edge_attr_global, 139 | batch=batch, 140 | batch_edge_global=batch_edge_global, 141 | ) 142 | if guidance_type == "ensemble": 143 | # TO-DO 144 | raise NotImplementedError 145 | elif guidance_type == "logsum": 146 | uncertainty = torch.sigmoid(-torch.logsumexp(bond_prediction, dim=-1)) 147 | uncertainty = ( 148 | 0.5 149 | * scatter_mean( 150 | uncertainty, 151 | index=edge_index_global[1], 152 | dim=0, 153 | dim_size=pos.size(0), 154 | ).log() 155 | ) 156 | uncertainty = scatter_mean(uncertainty, index=batch, dim=0, dim_size=bs) 157 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(uncertainty)] 158 | dist_shift = -torch.autograd.grad( 159 | [uncertainty], 160 | [pos], 161 | grad_outputs=grad_outputs, 162 | create_graph=False, 163 | retain_graph=False, 164 | )[0] 165 | 166 | return pos + guidance_scale * dist_shift 167 | 168 | 169 | def energy_guidance( 170 | pos, 171 | node_feats_in, 172 | temb, 173 | energy_model, 174 | batch, 175 | batch_size, 176 | signal=1.0e-3, 177 | guidance_scale=100, 178 | optimization="minimize", 179 | ): 180 | with torch.enable_grad(): 181 | node_feats_in = node_feats_in.detach() 182 | pos = pos.detach().requires_grad_(True) 183 | out = energy_model( 184 | x=node_feats_in, 185 | t=temb, 186 | pos=pos, 187 | batch=batch, 188 | ) 189 | if optimization == "minimize": 190 | sign = -1.0 191 | elif optimization == "maximize": 192 | sign = 1.0 193 | else: 194 | raise Exception("Optimization arg needs to be 'minimize' or 'maximize'!") 195 | energy_prediction = sign * guidance_scale * out["property_pred"] 196 | 197 | grad_outputs: List[Optional[torch.Tensor]] = [ 198 | torch.ones_like(energy_prediction) 199 | ] 200 | pos_shift = torch.autograd.grad( 201 | [energy_prediction], 202 | [pos], 203 | grad_outputs=grad_outputs, 204 | create_graph=False, 205 | retain_graph=False, 206 | )[0] 207 | 208 | pos_shift = zero_mean(pos_shift, batch=batch, dim_size=batch_size, dim=0) 209 | 210 | pos = pos + signal * pos_shift 211 | pos = zero_mean(pos, batch=batch, dim_size=batch_size, dim=0) 212 | 213 | return pos.detach() 214 | 215 | 216 | def extract_scaffolds_(batch_data): 217 | def scaffold_per_mol(mol): 218 | for a in mol.GetAtoms(): 219 | a.SetIntProp("org_idx", a.GetIdx()) 220 | 221 | scaffold = GetScaffoldForMol(mol) 222 | scaffold_atoms = [a.GetIntProp("org_idx") for a in scaffold.GetAtoms()] 223 | mask = torch.zeros(mol.GetNumAtoms(), dtype=bool) 224 | mask[torch.tensor(scaffold_atoms)] = 1 225 | return mask 226 | 227 | batch_data.scaffold_mask = torch.hstack( 228 | [scaffold_per_mol(mol) for mol in batch_data.mol] 229 | ) 230 | 231 | 232 | def 
extract_func_groups_(batch_data, includeHs=True): 233 | def func_groups_per_mol(mol, includeHs=True): 234 | fgroups = identify_functional_groups(mol) 235 | findices = [] 236 | for f in fgroups: 237 | findices.extend(list(f.atomIds)) 238 | if includeHs: # include neighboring H atoms in functional groups 239 | findices_incl_h = [] 240 | for fi in findices: 241 | hidx = [ 242 | n.GetIdx() 243 | for n in mol.GetAtomWithIdx(fi).GetNeighbors() 244 | if n.GetSymbol() == "H" 245 | ] 246 | findices_incl_h.extend([fi] + hidx) 247 | findices = findices_incl_h 248 | mask = torch.zeros(mol.GetNumAtoms(), dtype=bool) 249 | mask[torch.tensor(findices)] = 1 250 | return mask 251 | 252 | batch_data.func_group_mask = torch.hstack( 253 | [func_groups_per_mol(mol, includeHs) for mol in batch_data.mol] 254 | ) 255 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/docking.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import tempfile 4 | import numpy as np 5 | import torch 6 | from pathlib import Path 7 | import argparse 8 | import pandas as pd 9 | from rdkit import Chem 10 | from tqdm import tqdm 11 | from experiments.utils import write_sdf_file 12 | 13 | 14 | def calculate_smina_score(pdb_file, sdf_file): 15 | # add '-o _smina.sdf' if you want to see the output 16 | out = os.popen(f"smina.static -l {sdf_file} -r {pdb_file} " f"--score_only").read() 17 | matches = re.findall(r"Affinity:[ ]+([+-]?[0-9]*[.]?[0-9]+)[ ]+\(kcal/mol\)", out) 18 | return [float(x) for x in matches] 19 | 20 | 21 | def smina_score(rdmols, receptor_file): 22 | """ 23 | Calculate smina score 24 | :param rdmols: List of RDKit molecules 25 | :param receptor_file: Receptor pdb/pdbqt file or list of receptor files 26 | :return: Smina score for each input molecule (list) 27 | """ 28 | 29 | if isinstance(receptor_file, list): 30 | scores = [] 31 | for mol, rec_file in zip(rdmols, receptor_file): 32 | with tempfile.NamedTemporaryFile(suffix=".sdf") as tmp: 33 | tmp_file = tmp.name 34 | write_sdf_file(tmp_file, [mol]) 35 | scores.extend(calculate_smina_score(rec_file, tmp_file)) 36 | 37 | # Use same receptor file for all molecules 38 | else: 39 | with tempfile.NamedTemporaryFile(suffix=".sdf") as tmp: 40 | tmp_file = tmp.name 41 | write_sdf_file(tmp_file, rdmols) 42 | scores = calculate_smina_score(receptor_file, tmp_file) 43 | 44 | return scores 45 | 46 | 47 | def sdf_to_pdbqt(sdf_file, pdbqt_outfile, mol_id): 48 | os.popen( 49 | f"obabel {sdf_file} -O {pdbqt_outfile} " f"-f {mol_id + 1} -l {mol_id + 1}" 50 | ).read() 51 | return pdbqt_outfile 52 | 53 | 54 | def calculate_qvina2_score( 55 | receptor_file, sdf_file, out_dir, size=20, exhaustiveness=16, return_rdmol=False 56 | ): 57 | receptor_file = Path(receptor_file) 58 | sdf_file = Path(sdf_file) 59 | 60 | if receptor_file.suffix == ".pdb": 61 | # prepare receptor, requires Python 2.7 62 | receptor_pdbqt_file = Path(out_dir, receptor_file.stem + ".pdbqt") 63 | os.popen(f"prepare_receptor4.py -r {receptor_file} -O {receptor_pdbqt_file}") 64 | else: 65 | receptor_pdbqt_file = receptor_file 66 | 67 | scores = [] 68 | rdmols = [] # for if return rdmols 69 | suppl = Chem.SDMolSupplier(str(sdf_file), sanitize=False) 70 | for i, mol in enumerate(suppl): # sdf file may contain several ligands 71 | ligand_name = f"{sdf_file.stem}_{i}" 72 | # prepare ligand 73 | ligand_pdbqt_file = Path(out_dir, ligand_name + ".pdbqt") 74 | out_sdf_file = Path(out_dir, ligand_name + "_out.sdf") 
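The affinity-parsing regex in `calculate_smina_score` above is easiest to verify on a concrete string. The output below is fabricated to mimic the `Affinity: ... (kcal/mol)` lines that smina prints with `--score_only`:

```python
import re

fake_out = """## Name ligand_0
Affinity: -4.92 (kcal/mol)
Affinity: -5.13 (kcal/mol)
"""
pattern = r"Affinity:[ ]+([+-]?[0-9]*[.]?[0-9]+)[ ]+\(kcal/mol\)"
scores = [float(x) for x in re.findall(pattern, fake_out)]
assert scores == [-4.92, -5.13]
```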
75 | 76 | if out_sdf_file.exists(): 77 | with open(out_sdf_file, "r") as f: 78 | scores.append( 79 | min( 80 | [ 81 | float(x.split()[2]) 82 | for x in f.readlines() 83 | if x.startswith(" VINA RESULT:") 84 | ] 85 | ) 86 | ) 87 | 88 | else: 89 | sdf_to_pdbqt(sdf_file, ligand_pdbqt_file, i) 90 | 91 | # center box at ligand's center of mass 92 | cx, cy, cz = mol.GetConformer().GetPositions().mean(0) 93 | 94 | # run QuickVina 2 95 | try: 96 | os.stat("----") 97 | PATH = "----" 98 | except PermissionError: 99 | PATH = "----" 100 | 101 | out = os.popen( 102 | f"/{PATH} --receptor {receptor_pdbqt_file} " 103 | f"--ligand {ligand_pdbqt_file} " 104 | f"--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} " 105 | f"--size_x {size} --size_y {size} --size_z {size} " 106 | f"--exhaustiveness {exhaustiveness}", 107 | ).read() 108 | # clean up 109 | ligand_pdbqt_file.unlink() 110 | 111 | if "-----+------------+----------+----------" not in out: 112 | scores.append(np.nan) 113 | continue 114 | 115 | out_split = out.splitlines() 116 | best_idx = out_split.index("-----+------------+----------+----------") + 1 117 | best_line = out_split[best_idx].split() 118 | assert best_line[0] == "1" 119 | scores.append(float(best_line[1])) 120 | 121 | out_pdbqt_file = Path(out_dir, ligand_name + "_out.pdbqt") 122 | if out_pdbqt_file.exists(): 123 | os.popen(f"obabel {out_pdbqt_file} -O {out_sdf_file}").read() 124 | 125 | # clean up 126 | out_pdbqt_file.unlink() 127 | 128 | if return_rdmol: 129 | rdmol = Chem.SDMolSupplier(str(out_sdf_file))[0] 130 | rdmols.append(rdmol) 131 | 132 | if return_rdmol: 133 | return scores, rdmols 134 | else: 135 | return scores 136 | 137 | 138 | if __name__ == "__main__": 139 | parser = argparse.ArgumentParser("QuickVina evaluation") 140 | parser.add_argument("--pdbqt_dir", type=Path, help="Receptor files in pdbqt format") 141 | parser.add_argument( 142 | "--sdf_dir", type=Path, default=None, help="Ligand files in sdf format" 143 | ) 144 | parser.add_argument("--sdf_files", type=Path, nargs="+", default=None) 145 | parser.add_argument("--out_dir", type=Path) 146 | parser.add_argument("--write_csv", action="store_true") 147 | parser.add_argument("--write_dict", action="store_true") 148 | parser.add_argument("--dataset", type=str, default="moad") 149 | args = parser.parse_args() 150 | 151 | assert (args.sdf_dir is not None) ^ (args.sdf_files is not None) 152 | 153 | args.out_dir.mkdir(exist_ok=True) 154 | 155 | results = {"receptor": [], "ligand": [], "scores": []} 156 | results_dict = {} 157 | sdf_files = ( 158 | list(args.sdf_dir.glob("[!.]*.sdf")) 159 | if args.sdf_dir is not None 160 | else args.sdf_files 161 | ) 162 | pbar = tqdm(sdf_files) 163 | for sdf_file in pbar: 164 | pbar.set_description(f"Processing {sdf_file.name}") 165 | 166 | if args.dataset == "moad": 167 | """ 168 | Ligand file names should be of the following form: 169 | __.sdf 170 | where and cannot contain any 171 | underscores, e.g.: 1abc-bio1_pocket0_gen.sdf 172 | """ 173 | ligand_name = sdf_file.stem 174 | receptor_name, pocket_id, *suffix = ligand_name.split("_") 175 | suffix = "_".join(suffix) 176 | receptor_file = Path(args.pdbqt_dir, receptor_name + ".pdbqt") 177 | elif args.dataset == "crossdocked": 178 | ligand_name = sdf_file.stem 179 | receptor_name = ligand_name[:-4] 180 | receptor_file = Path(args.pdbqt_dir, receptor_name + ".pdbqt") 181 | 182 | # try: 183 | scores, rdmols = calculate_qvina2_score( 184 | receptor_file, sdf_file, args.out_dir, return_rdmol=True 185 | ) 186 | # except AttributeError as 
e:
187 |         #     print(e)
188 |         #     continue
189 |         results["receptor"].append(str(receptor_file))
190 |         results["ligand"].append(str(sdf_file))
191 |         results["scores"].append(scores)
192 | 
193 |         if args.write_dict:
194 |             results_dict[ligand_name] = {
195 |                 "receptor": str(receptor_file),
196 |                 "ligand": str(sdf_file),
197 |                 "scores": scores,
198 |                 "rdmols": rdmols,
199 |             }
200 | 
201 |     if args.write_csv:
202 |         df = pd.DataFrame.from_dict(results)
203 |         df.to_csv(Path(args.out_dir, "qvina2_scores.csv"))
204 | 
205 |     if args.write_dict:
206 |         torch.save(results_dict, Path(args.out_dir, "qvina2_scores.pt"))
207 | 
-------------------------------------------------------------------------------- /eqgat_diff/experiments/docking_mgl.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | 
5 | 
6 | def pdbs_to_pdbqts(pdb_dir, pdbqt_dir, dataset):
7 |     for file in glob.glob(os.path.join(pdb_dir, "*.pdb")):
8 |         name = os.path.splitext(os.path.basename(file))[0]
9 |         outfile = os.path.join(pdbqt_dir, name + ".pdbqt")
10 |         pdb_to_pdbqt(file, outfile, dataset)
11 |         print("Wrote converted file to {}".format(outfile))
12 | 
13 | 
14 | def pdb_to_pdbqt(pdb_file, pdbqt_file, dataset):
15 |     if os.path.exists(pdbqt_file):
16 |         return pdbqt_file
17 |     if dataset == "crossdocked":
18 |         os.system("prepare_receptor4.py -r {} -o {}".format(pdb_file, pdbqt_file))
19 |     elif dataset == "bindingmoad":
20 |         os.system(
21 |             "prepare_receptor4.py -r {} -o {} -A checkhydrogens -e".format(
22 |                 pdb_file, pdbqt_file
23 |             )
24 |         )
25 |     else:
26 |         raise NotImplementedError
27 |     return pdbqt_file
28 | 
29 | 
30 | if __name__ == "__main__":
31 |     pdbs_to_pdbqts(sys.argv[1], sys.argv[2], sys.argv[3])
32 | 
-------------------------------------------------------------------------------- /eqgat_diff/experiments/hparams.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | from experiments.utils import LoadFromFile
4 | 
5 | DEFAULT_SAVE_DIR = os.path.join(os.getcwd(), "3DcoordsAtomsBonds_0")
6 | 
7 | if not os.path.exists(DEFAULT_SAVE_DIR):
8 |     os.makedirs(DEFAULT_SAVE_DIR)
9 | 
10 | 
11 | def add_arguments(parser):
12 |     """Helper function to fill the parser object.
13 | 14 | Args: 15 | parser: Parser object 16 | Returns: 17 | parser: Updated parser object 18 | """ 19 | 20 | # Load yaml file 21 | parser.add_argument( 22 | "--conf", "-c", type=open, action=LoadFromFile, help="Configuration yaml file" 23 | ) # keep first 24 | 25 | # Load from checkpoint 26 | parser.add_argument("--load-ckpt", default=None, type=str) 27 | parser.add_argument("--load-ckpt-from-pretrained", default=None, type=str) 28 | 29 | # DATA and FILES 30 | parser.add_argument("-s", "--save-dir", default=DEFAULT_SAVE_DIR, type=str) 31 | parser.add_argument("--test-save-dir", default=DEFAULT_SAVE_DIR, type=str) 32 | 33 | parser.add_argument( 34 | "--dataset", 35 | default="drugs", 36 | choices=[ 37 | "qm9", 38 | "drugs", 39 | "aqm", 40 | "aqm_qm7x", 41 | "pcqm4mv2", 42 | "pepconf", 43 | "crossdocked", 44 | ], 45 | ) 46 | parser.add_argument( 47 | "--dataset-root", default="----" 48 | ) 49 | parser.add_argument("--use-adaptive-loader", default=True, action="store_true") 50 | parser.add_argument("--remove-hs", default=False, action="store_true") 51 | parser.add_argument("--select-train-subset", default=False, action="store_true") 52 | parser.add_argument("--train-size", default=0.8, type=float) 53 | parser.add_argument("--val-size", default=0.1, type=float) 54 | parser.add_argument("--test-size", default=0.1, type=float) 55 | 56 | parser.add_argument("--dropout-prob", default=0.3, type=float) 57 | 58 | # LEARNING 59 | parser.add_argument("-b", "--batch-size", default=32, type=int) 60 | parser.add_argument("-ib", "--inference-batch-size", default=32, type=int) 61 | parser.add_argument("--gamma", default=0.975, type=float) 62 | parser.add_argument("--grad-clip-val", default=10.0, type=float) 63 | parser.add_argument( 64 | "--lr-scheduler", 65 | default="reduce_on_plateau", 66 | choices=["reduce_on_plateau", "cosine_annealing", "cyclic"], 67 | ) 68 | parser.add_argument( 69 | "--optimizer", 70 | default="adam", 71 | choices=["adam", "sgd"], 72 | ) 73 | parser.add_argument("--lr", default=5e-4, type=float) 74 | parser.add_argument("--lr-min", default=5e-5, type=float) 75 | parser.add_argument("--lr-step-size", default=10000, type=int) 76 | parser.add_argument("--lr-frequency", default=5, type=int) 77 | parser.add_argument("--lr-patience", default=20, type=int) 78 | parser.add_argument("--lr-cooldown", default=5, type=int) 79 | parser.add_argument("--lr-factor", default=0.75, type=float) 80 | 81 | # MODEL 82 | parser.add_argument("--sdim", default=256, type=int) 83 | parser.add_argument("--vdim", default=64, type=int) 84 | parser.add_argument("--latent_dim", default=None, type=int) 85 | parser.add_argument("--rbf-dim", default=32, type=int) 86 | parser.add_argument("--edim", default=32, type=int) 87 | parser.add_argument("--edge-mp", default=False, action="store_true") 88 | parser.add_argument("--vector-aggr", default="mean", type=str) 89 | parser.add_argument("--num-layers", default=7, type=int) 90 | parser.add_argument("--fully-connected", default=True, action="store_true") 91 | parser.add_argument("--local-global-model", default=False, action="store_true") 92 | parser.add_argument("--local-edge-attrs", default=False, action="store_true") 93 | parser.add_argument("--use-cross-product", default=False, action="store_true") 94 | parser.add_argument("--cutoff-local", default=7.0, type=float) 95 | parser.add_argument("--cutoff-global", default=10.0, type=float) 96 | parser.add_argument("--energy-training", default=False, action="store_true") 97 | parser.add_argument("--property-training", 
default=False, action="store_true") 98 | parser.add_argument( 99 | "--regression-property", 100 | default="polarizability", 101 | type=str, 102 | choices=[ 103 | "dipole_norm", 104 | "total_energy", 105 | "HOMO-LUMO_gap", 106 | "dispersion", 107 | "atomisation_energy", 108 | "polarizability", 109 | ], 110 | ) 111 | parser.add_argument("--energy-loss", default="l2", type=str, choices=["l2", "l1"]) 112 | parser.add_argument("--use-pos-norm", default=False, action="store_true") 113 | 114 | # For Discrete: Include more features: (is_aromatic, is_in_ring, hybridization) 115 | parser.add_argument("--additional-feats", default=False, action="store_true") 116 | parser.add_argument("--use-qm-props", default=False, action="store_true") 117 | parser.add_argument("--build-mol-with-addfeats", default=False, action="store_true") 118 | 119 | # DIFFUSION 120 | parser.add_argument( 121 | "--continuous", 122 | default=False, 123 | action="store_true", 124 | help="If the diffusion process is applied on continuous time variable. Defaults to False", 125 | ) 126 | parser.add_argument( 127 | "--noise-scheduler", 128 | default="cosine", 129 | choices=["linear", "cosine", "quad", "sigmoid", "adaptive", "linear-time"], 130 | ) 131 | parser.add_argument("--eps-min", default=1e-3, type=float) 132 | parser.add_argument("--beta-min", default=1e-4, type=float) 133 | parser.add_argument("--beta-max", default=2e-2, type=float) 134 | parser.add_argument("--timesteps", default=500, type=int) 135 | parser.add_argument("--max-time", type=str, default=None) 136 | parser.add_argument("--lc-coords", default=3.0, type=float) 137 | parser.add_argument("--lc-atoms", default=0.4, type=float) 138 | parser.add_argument("--lc-bonds", default=2.0, type=float) 139 | parser.add_argument("--lc-charges", default=1.0, type=float) 140 | parser.add_argument("--lc-mulliken", default=1.5, type=float) 141 | parser.add_argument("--lc-wbo", default=2.0, type=float) 142 | 143 | parser.add_argument("--pocket-noise-std", default=0.1, type=float) 144 | parser.add_argument( 145 | "--use-ligand-dataset-sizes", default=False, action="store_true" 146 | ) 147 | 148 | parser.add_argument( 149 | "--loss-weighting", 150 | default="snr_t", 151 | choices=["snr_s_t", "snr_t", "exp_t", "expt_t_half", "uniform"], 152 | ) 153 | parser.add_argument("--snr-clamp-min", default=0.05, type=float) 154 | parser.add_argument("--snr-clamp-max", default=1.50, type=float) 155 | 156 | parser.add_argument( 157 | "--ligand-pocket-interaction", default=False, action="store_true" 158 | ) 159 | parser.add_argument("--diffusion-pretraining", default=False, action="store_true") 160 | parser.add_argument( 161 | "--continuous-param", default="data", type=str, choices=["data", "noise"] 162 | ) 163 | parser.add_argument("--atoms-categorical", default=False, action="store_true") 164 | parser.add_argument("--bonds-categorical", default=False, action="store_true") 165 | 166 | parser.add_argument("--atom-type-masking", default=False, action="store_true") 167 | parser.add_argument("--use-absorbing-state", default=False, action="store_true") 168 | 169 | parser.add_argument("--num-bond-classes", default=5, type=int) 170 | parser.add_argument("--num-charge-classes", default=6, type=int) 171 | 172 | # BOND PREDICTION AND GUIDANCE: 173 | parser.add_argument("--bond-guidance-model", default=False, action="store_true") 174 | parser.add_argument("--bond-prediction", default=False, action="store_true") 175 | parser.add_argument("--bond-model-guidance", default=False, action="store_true") 176 | 
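For reference, the boolean guidance switches above follow standard argparse semantics: `store_true` flags default to `False`, and dashes in option names become underscores on the parsed namespace. A tiny stand-alone parser mirroring two of the options (a sketch, not the full `add_arguments` parser):

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--bond-model-guidance", default=False, action="store_true")
p.add_argument("--guidance-scale", default=1.0e-4, type=float)

args = p.parse_args(["--bond-model-guidance", "--guidance-scale", "1e-3"])
assert args.bond_model_guidance is True  # dash becomes underscore
assert args.guidance_scale == 1e-3
```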
parser.add_argument("--energy-model-guidance", default=False, action="store_true") 177 | parser.add_argument( 178 | "--polarizabilty-model-guidance", default=False, action="store_true" 179 | ) 180 | parser.add_argument("--ckpt-bond-model", default=None, type=str) 181 | parser.add_argument("--ckpt-energy-model", default=None, type=str) 182 | parser.add_argument("--ckpt-polarizabilty-model", default=None, type=str) 183 | parser.add_argument("--guidance-scale", default=1.0e-4, type=float) 184 | 185 | # CONTEXT 186 | parser.add_argument("--context-mapping", default=False, action="store_true") 187 | parser.add_argument("--num-context-features", default=0, type=int) 188 | parser.add_argument("--properties-list", default=[], nargs="+", type=str) 189 | 190 | # PROPERTY PREDICTION 191 | parser.add_argument("--property-prediction", default=False, action="store_true") 192 | 193 | # LATENT 194 | parser.add_argument("--prior-beta", default=1.0, type=float) 195 | parser.add_argument("--sdim-latent", default=256, type=int) 196 | parser.add_argument("--vdim-latent", default=64, type=int) 197 | parser.add_argument("--latent-dim", default=None, type=int) 198 | parser.add_argument("--edim-latent", default=32, type=int) 199 | parser.add_argument("--num-layers-latent", default=7, type=int) 200 | parser.add_argument("--latent-layers", default=7, type=int) 201 | parser.add_argument("--latentmodel", default="diffusion", type=str) 202 | parser.add_argument("--latent-detach", default=False, action="store_true") 203 | 204 | # GENERAL 205 | parser.add_argument("-i", "--id", type=int, default=0) 206 | parser.add_argument("-g", "--gpus", default=1, type=int) 207 | parser.add_argument("-e", "--num-epochs", default=300, type=int) 208 | parser.add_argument("--eval-freq", default=1.0, type=float) 209 | parser.add_argument("--test-interval", default=5, type=int) 210 | parser.add_argument("-nh", "--no_h", default=False, action="store_true") 211 | parser.add_argument("--precision", default=32, type=int) 212 | parser.add_argument("--detect-anomaly", default=False, action="store_true") 213 | parser.add_argument("--num-workers", default=4, type=int) 214 | parser.add_argument( 215 | "--max-num-conformers", 216 | default=5, 217 | type=int, 218 | help="Maximum number of conformers per molecule. \ 219 | Defaults to 30. 
Set to -1 for all conformers available in database", 220 | ) 221 | parser.add_argument("--accum-batch", default=1, type=int) 222 | parser.add_argument("--max-num-neighbors", default=128, type=int) 223 | parser.add_argument("--ema-decay", default=0.9999, type=float) 224 | parser.add_argument("--weight-decay", default=0.9999, type=float) 225 | parser.add_argument("--seed", default=42, type=int) 226 | parser.add_argument("--backprop-local", default=False, action="store_true") 227 | 228 | # SAMPLING 229 | parser.add_argument("--num-test-graphs", default=10000, type=int) 230 | parser.add_argument("--calculate-energy", default=False, action="store_true") 231 | parser.add_argument("--save-xyz", default=False, action="store_true") 232 | 233 | return parser 234 | -------------------------------------------------------------------------------- /eqgat_diff/experiments/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor, nn 6 | from torch_scatter import scatter_mean 7 | 8 | 9 | class DiffusionLoss(nn.Module): 10 | def __init__( 11 | self, 12 | modalities: List = ["coords", "atoms", "charges", "bonds"], 13 | param: List = ["data", "data", "data", "data"], 14 | ) -> None: 15 | super().__init__() 16 | assert len(modalities) == len(param) 17 | self.modalities = modalities 18 | self.param = param 19 | 20 | if "coords" in modalities: 21 | self.regression_key = "coords" 22 | elif "latents" in modalities: 23 | self.regression_key = "latents" 24 | else: 25 | raise ValueError 26 | 27 | def loss_non_nans(self, loss: Tensor, modality: str) -> Tensor: 28 | m = loss.isnan() 29 | if torch.any(m): 30 | print(f"Recovered NaNs in {modality}. 
Selecting NoN-Nans") 31 | return loss[~m], m 32 | 33 | def forward( 34 | self, 35 | true_data: Dict, 36 | pred_data: Dict, 37 | batch: Tensor, 38 | bond_aggregation_index: Tensor, 39 | weights: Optional[Tensor] = None, 40 | ) -> Dict: 41 | batch_size = len(batch.unique()) 42 | mulliken_loss = None 43 | wbo_loss = None 44 | 45 | if weights is not None: 46 | assert len(weights) == batch_size 47 | 48 | regr_loss = F.mse_loss( 49 | pred_data[self.regression_key], 50 | true_data[self.regression_key], 51 | reduction="none", 52 | ).mean(-1) 53 | regr_loss = scatter_mean(regr_loss, index=batch, dim=0, dim_size=batch_size) 54 | regr_loss, m = self.loss_non_nans(regr_loss, self.regression_key) 55 | 56 | regr_loss *= weights[~m] 57 | regr_loss = torch.sum(regr_loss, dim=0) 58 | 59 | if "mulliken" in true_data: 60 | mulliken_loss = F.mse_loss( 61 | pred_data["mulliken"], 62 | true_data["mulliken"], 63 | reduction="none", 64 | ).mean(-1) 65 | mulliken_loss = scatter_mean( 66 | mulliken_loss, index=batch, dim=0, dim_size=batch_size 67 | ) 68 | mulliken_loss *= weights 69 | mulliken_loss = torch.sum(mulliken_loss, dim=0) 70 | 71 | if "wbo" in true_data: 72 | wbo_loss = F.mse_loss( 73 | pred_data["wbo"], 74 | true_data["wbo"], 75 | reduction="none", 76 | ).mean(-1) 77 | wbo_loss = 0.5 * scatter_mean( 78 | wbo_loss, 79 | index=bond_aggregation_index, 80 | dim=0, 81 | dim_size=true_data["atoms"].size(0), 82 | ) 83 | wbo_loss = scatter_mean( 84 | wbo_loss, index=batch, dim=0, dim_size=batch_size 85 | ) 86 | wbo_loss *= weights 87 | wbo_loss = torch.sum(wbo_loss, dim=0) 88 | 89 | if self.param[self.modalities.index("atoms")] == "data": 90 | fnc = F.cross_entropy 91 | take_mean = False 92 | else: 93 | fnc = F.mse_loss 94 | take_mean = True 95 | 96 | atoms_loss = fnc(pred_data["atoms"], true_data["atoms"], reduction="none") 97 | if take_mean: 98 | atoms_loss = atoms_loss.mean(dim=1) 99 | atoms_loss = scatter_mean( 100 | atoms_loss, index=batch, dim=0, dim_size=batch_size 101 | ) 102 | atoms_loss, m = self.loss_non_nans(atoms_loss, "atoms") 103 | atoms_loss *= weights[~m] 104 | atoms_loss = torch.sum(atoms_loss, dim=0) 105 | 106 | if self.param[self.modalities.index("charges")] == "data": 107 | fnc = F.cross_entropy 108 | take_mean = False 109 | else: 110 | fnc = F.mse_loss 111 | take_mean = True 112 | 113 | charges_loss = fnc( 114 | pred_data["charges"], true_data["charges"], reduction="none" 115 | ) 116 | if take_mean: 117 | charges_loss = charges_loss.mean(dim=1) 118 | charges_loss = scatter_mean( 119 | charges_loss, index=batch, dim=0, dim_size=batch_size 120 | ) 121 | charges_loss, m = self.loss_non_nans(charges_loss, "charges") 122 | charges_loss *= weights[~m] 123 | charges_loss = torch.sum(charges_loss, dim=0) 124 | 125 | if self.param[self.modalities.index("bonds")] == "data": 126 | fnc = F.cross_entropy 127 | take_mean = False 128 | else: 129 | fnc = F.mse_loss 130 | take_mean = True 131 | 132 | bonds_loss = fnc(pred_data["bonds"], true_data["bonds"], reduction="none") 133 | if take_mean: 134 | bonds_loss = bonds_loss.mean(dim=1) 135 | bonds_loss = 0.5 * scatter_mean( 136 | bonds_loss, 137 | index=bond_aggregation_index, 138 | dim=0, 139 | dim_size=true_data["atoms"].size(0), 140 | ) 141 | bonds_loss = scatter_mean( 142 | bonds_loss, index=batch, dim=0, dim_size=batch_size 143 | ) 144 | bonds_loss, m = self.loss_non_nans(bonds_loss, "bonds") 145 | bonds_loss *= weights[~m] 146 | bonds_loss = bonds_loss.sum(dim=0) 147 | 148 | if "ring" in self.modalities: 149 | ring_loss = F.cross_entropy( 150 | 
pred_data["ring"], true_data["ring"], reduction="none" 151 | ) 152 | ring_loss = scatter_mean( 153 | ring_loss, index=batch, dim=0, dim_size=batch_size 154 | ) 155 | ring_loss, m = self.loss_non_nans(ring_loss, "ring") 156 | ring_loss *= weights[~m] 157 | ring_loss = torch.sum(ring_loss, dim=0) 158 | else: 159 | ring_loss = None 160 | 161 | if "aromatic" in self.modalities: 162 | aromatic_loss = F.cross_entropy( 163 | pred_data["aromatic"], true_data["aromatic"], reduction="none" 164 | ) 165 | aromatic_loss = scatter_mean( 166 | aromatic_loss, index=batch, dim=0, dim_size=batch_size 167 | ) 168 | aromatic_loss, m = self.loss_non_nans(aromatic_loss, "aromatic") 169 | aromatic_loss *= weights[~m] 170 | aromatic_loss = torch.sum(aromatic_loss, dim=0) 171 | else: 172 | aromatic_loss = None 173 | 174 | if "hybridization" in self.modalities: 175 | hybridization_loss = F.cross_entropy( 176 | pred_data["hybridization"], 177 | true_data["hybridization"], 178 | reduction="none", 179 | ) 180 | hybridization_loss = scatter_mean( 181 | hybridization_loss, index=batch, dim=0, dim_size=batch_size 182 | ) 183 | hybridization_loss, m = self.loss_non_nans( 184 | hybridization_loss, "hybridization" 185 | ) 186 | hybridization_loss *= weights[~m] 187 | hybridization_loss = torch.sum(hybridization_loss, dim=0) 188 | else: 189 | hybridization_loss = None 190 | 191 | else: 192 | regr_loss = F.mse_loss( 193 | pred_data[self.regression_key], 194 | true_data[self.regression_key], 195 | reduction="mean", 196 | ).mean(-1) 197 | if self.param[self.modalities.index("atoms")] == "data": 198 | fnc = F.cross_entropy 199 | else: 200 | fnc = F.mse_loss 201 | atoms_loss = fnc(pred_data["atoms"], true_data["atoms"], reduction="mean") 202 | if self.param[self.modalities.index("charges")] == "data": 203 | fnc = F.cross_entropy 204 | else: 205 | fnc = F.mse_loss 206 | charges_loss = fnc( 207 | pred_data["charges"], true_data["charges"], reduction="mean" 208 | ) 209 | if self.param[self.modalities.index("bonds")] == "data": 210 | fnc = F.cross_entropy 211 | else: 212 | fnc = F.mse_loss 213 | bonds_loss = fnc(pred_data["bonds"], true_data["bonds"], reduction="mean") 214 | 215 | if "ring" in self.modalities: 216 | ring_loss = F.cross_entropy( 217 | pred_data["ring"], true_data["ring"], reduction="mean" 218 | ) 219 | else: 220 | ring_loss = None 221 | 222 | if "aromatic" in self.modalities: 223 | aromatic_loss = F.cross_entropy( 224 | pred_data["aromatic"], true_data["aromatic"], reduction="mean" 225 | ) 226 | else: 227 | aromatic_loss = None 228 | 229 | if "hybridization" in self.modalities: 230 | hybridization_loss = F.cross_entropy( 231 | pred_data["hybridization"], 232 | true_data["hybridization"], 233 | reduction="mean", 234 | ) 235 | else: 236 | hybridization_loss = None 237 | 238 | loss = { 239 | self.regression_key: regr_loss, 240 | "atoms": atoms_loss, 241 | "charges": charges_loss, 242 | "bonds": bonds_loss, 243 | "ring": ring_loss, 244 | "aromatic": aromatic_loss, 245 | "hybridization": hybridization_loss, 246 | "mulliken": mulliken_loss, 247 | "wbo": wbo_loss, 248 | } 249 | 250 | return loss 251 | 252 | 253 | class EdgePredictionLoss(nn.Module): 254 | def __init__( 255 | self, 256 | ) -> None: 257 | super().__init__() 258 | 259 | def forward( 260 | self, 261 | true_data: Dict, 262 | pred_data: Dict, 263 | ) -> Dict: 264 | bonds_loss = F.cross_entropy(pred_data, true_data, reduction="mean") 265 | 266 | return bonds_loss 267 | -------------------------------------------------------------------------------- 
/eqgat_diff/experiments/run_evaluation_ligand.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | 9 | from experiments.data.distributions import DistributionProperty 10 | 11 | warnings.filterwarnings( 12 | "ignore", category=UserWarning, message="TypedStorage is deprecated" 13 | ) 14 | warnings.filterwarnings( 15 | "ignore", category=UserWarning, message="TypedStorage is deprecated" 16 | ) 17 | 18 | 19 | class dotdict(dict): 20 | """dot.notation access to dictionary attributes""" 21 | 22 | __getattr__ = dict.get 23 | __setattr__ = dict.__setitem__ 24 | __delattr__ = dict.__delitem__ 25 | 26 | 27 | def evaluate( 28 | model_path, 29 | save_dir, 30 | save_xyz=True, 31 | calculate_energy=False, 32 | batch_size=2, 33 | use_ligand_dataset_sizes=False, 34 | build_obabel_mol=False, 35 | save_traj=False, 36 | use_energy_guidance=False, 37 | ckpt_energy_model=None, 38 | guidance_scale=1.0e-4, 39 | ddpm=True, 40 | eta_ddim=1.0, 41 | ): 42 | print("Loading from checkpoint; adapting hyperparameters to specified args") 43 | 44 | # load model 45 | ckpt = torch.load(model_path) 46 | ckpt["hyper_parameters"]["load_ckpt"] = None 47 | ckpt["hyper_parameters"]["load_ckpt_from_pretrained"] = None 48 | ckpt["hyper_parameters"]["test_save_dir"] = save_dir 49 | ckpt["hyper_parameters"]["calculate_energy"] = calculate_energy 50 | ckpt["hyper_parameters"]["save_xyz"] = save_xyz 51 | ckpt["hyper_parameters"]["batch_size"] = batch_size 52 | ckpt["hyper_parameters"]["select_train_subset"] = False 53 | ckpt["hyper_parameters"]["diffusion_pretraining"] = False 54 | ckpt["hyper_parameters"]["gpus"] = 1 55 | ckpt["hyper_parameters"]["use_ligand_dataset_sizes"] = use_ligand_dataset_sizes 56 | ckpt["hyper_parameters"]["build_obabel_mol"] = build_obabel_mol 57 | ckpt["hyper_parameters"]["save_traj"] = save_traj 58 | ckpt["hyper_parameters"]["num_charge_classes"] = 6 59 | 60 | ckpt_path = os.path.join(save_dir, f"test_model.ckpt") 61 | if not os.path.exists(ckpt_path): 62 | torch.save(ckpt, ckpt_path) 63 | 64 | hparams = ckpt["hyper_parameters"] 65 | hparams = dotdict(hparams) 66 | 67 | print(f"Loading {hparams.dataset} Datamodule.") 68 | dataset = "crossdocked" 69 | if hparams.use_adaptive_loader: 70 | print("Using adaptive dataloader") 71 | from experiments.data.ligand.ligand_dataset_adaptive import ( 72 | LigandPocketDataModule as DataModule, 73 | ) 74 | else: 75 | print("Using non-adaptive dataloader") 76 | from experiments.data.ligand.ligand_dataset_nonadaptive import ( 77 | LigandPocketDataModule as DataModule, 78 | ) 79 | datamodule = DataModule(hparams) 80 | 81 | from experiments.data.data_info import GeneralInfos as DataInfos 82 | 83 | dataset_info = DataInfos(datamodule, hparams) 84 | 85 | train_smiles = ( 86 | list(datamodule.train_dataset.smiles) 87 | if hparams.dataset != "pubchem" 88 | else datamodule.train_smiles 89 | ) 90 | prop_norm, prop_dist = None, None 91 | if len(hparams.properties_list) > 0 and hparams.context_mapping: 92 | prop_norm = datamodule.compute_mean_mad(hparams.properties_list) 93 | prop_dist = DistributionProperty(datamodule, hparams.properties_list) 94 | prop_dist.set_normalizer(prop_norm) 95 | 96 | if hparams.continuous: 97 | print("Using continuous diffusion") 98 | from experiments.diffusion_continuous import Trainer 99 | else: 100 | print("Using discrete diffusion") 101 | if hparams.diffusion_pretraining: 102 | 
print("Starting pre-training") 103 | if hparams.additional_feats: 104 | from experiments.diffusion_pretrain_discrete_addfeats import ( 105 | Trainer, 106 | ) 107 | else: 108 | from experiments.diffusion_pretrain_discrete import Trainer 109 | elif hparams.additional_feats: 110 | if dataset == "crossdocked": 111 | print("Ligand-pocket testing using additional features") 112 | from experiments.diffusion_discrete_moreFeats_ligand import Trainer 113 | else: 114 | print("Using additional features") 115 | from experiments.diffusion_discrete_moreFeats import Trainer 116 | else: 117 | if dataset == "crossdocked": 118 | print("Ligand-pocket testing") 119 | histogram = os.path.join(hparams.dataset_root, "size_distribution.npy") 120 | histogram = np.load(histogram).tolist() 121 | from experiments.diffusion_discrete_pocket import Trainer 122 | else: 123 | from experiments.diffusion_discrete import Trainer 124 | 125 | if build_obabel_mol: 126 | print( 127 | "Sampled molecules will be built with OpenBabel (without bond information)!" 128 | ) 129 | 130 | trainer = pl.Trainer( 131 | accelerator="gpu" if hparams.gpus else "cpu", 132 | devices=1, 133 | strategy="auto", 134 | num_nodes=1, 135 | precision=hparams.precision, 136 | ) 137 | pl.seed_everything(seed=hparams.seed, workers=hparams.gpus > 1) 138 | 139 | model = Trainer( 140 | hparams=hparams, 141 | dataset_info=dataset_info, 142 | smiles_list=train_smiles, 143 | prop_dist=prop_dist, 144 | prop_norm=prop_norm, 145 | histogram=histogram, 146 | ) 147 | 148 | trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path) 149 | 150 | 151 | def get_args(): 152 | # fmt: off 153 | parser = argparse.ArgumentParser(description='Data generation') 154 | parser.add_argument('--model-path', default="----", type=str, 155 | help='Path to trained model') 156 | parser.add_argument("--use-energy-guidance", default=False, action="store_true") 157 | parser.add_argument("--use-ligand-dataset-sizes", default=False, action="store_true") 158 | parser.add_argument("--build-obabel-mol", default=False, action="store_true") 159 | parser.add_argument("--save-traj", default=False, action="store_true") 160 | parser.add_argument("--ckpt-energy-model", default=None, type=str) 161 | parser.add_argument('--guidance-scale', default=1.0e-4, type=float, 162 | help='How to scale the guidance shift') 163 | parser.add_argument('--save-dir', default="----", type=str, 164 | help='Path to test output') 165 | parser.add_argument('--save-xyz', default=False, action="store_true", 166 | help='Whether or not to store generated molecules in xyz files') 167 | parser.add_argument('--calculate-energy', default=False, action="store_true", 168 | help='Whether or not to calculate xTB energies and forces') 169 | parser.add_argument('--batch-size', default=80, type=int, 170 | help='Batch-size to generate the selected ngraphs. Defaults to 80.') 171 | parser.add_argument('--ddim', default=False, action="store_true", 172 | help='If DDIM sampling should be used. Defaults to False') 173 | parser.add_argument('--eta-ddim', default=1.0, type=float, 174 | help='How to scale the std of noise in the reverse posterior. \ 175 | Can also be used for DDPM to track a deterministic trajectory. 
/eqgat_diff/experiments/run_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from argparse import ArgumentParser
4 | 
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | import torch.nn.functional as F
8 | from pytorch_lightning.callbacks import (
9 |     LearningRateMonitor,
10 |     ModelCheckpoint,
11 |     ModelSummary,
12 |     TQDMProgressBar,
13 | )
14 | from pytorch_lightning.loggers import TensorBoardLogger
15 | 
16 | from callbacks.ema import ExponentialMovingAverage
17 | 
18 | warnings.filterwarnings(
19 |     "ignore", category=UserWarning, message="TypedStorage is deprecated"
20 | )
21 | 
22 | from experiments.data.distributions import DistributionProperty
23 | from experiments.hparams import add_arguments
24 | 
25 | if __name__ == "__main__":
26 |     parser = ArgumentParser()
27 |     parser = add_arguments(parser)
28 |     hparams = parser.parse_args()
29 | 
30 |     if not os.path.exists(hparams.save_dir):
31 |         os.makedirs(hparams.save_dir)
32 | 
33 |     if not os.path.isdir(hparams.save_dir + f"/run{hparams.id}/"):
34 |         print("Creating directory")
35 |         os.mkdir(hparams.save_dir + f"/run{hparams.id}/")
36 |     print(f"Starting Run {hparams.id}")
37 |     ema_callback = ExponentialMovingAverage(decay=hparams.ema_decay)
38 |     checkpoint_callback = ModelCheckpoint(
39 |         dirpath=hparams.save_dir + f"/run{hparams.id}/",
40 |         save_top_k=3,
41 |         monitor="val/loss",
42 |         save_last=True,
43 |     )
44 |     lr_logger = LearningRateMonitor()
45 |     tb_logger = TensorBoardLogger(
46 |         hparams.save_dir + f"/run{hparams.id}/", default_hp_metric=False
47 |     )
48 | 
49 |     print(f"Loading {hparams.dataset} Datamodule.")
50 |     non_adaptive = True
51 |     if hparams.dataset == "drugs":
52 |         dataset = "drugs"
53 |         if hparams.use_adaptive_loader:
54 |             print("Using adaptive dataloader")
55 |             non_adaptive = False
56 |             from experiments.data.geom.geom_dataset_adaptive import (
57 |                 GeomDataModule as DataModule,
58 |             )
59 |         else:
60 |             print("Using non-adaptive dataloader")
61 |             from experiments.data.geom.geom_dataset_nonadaptive import (
62 |                 GeomDataModule as DataModule,
63 |             )
64 |     elif hparams.dataset == "qm9":
65 |         dataset = "qm9"
66 |         from experiments.data.qm9.qm9_dataset import QM9DataModule as DataModule
67 | 
68 |     elif hparams.dataset == "pubchem":
69 |         dataset = "pubchem"  # take dataset infos from GEOM for simplicity
70 |         if hparams.use_adaptive_loader:
71 |             print("Using adaptive dataloader")
72 |             non_adaptive = False
73 |             from experiments.data.pubchem.pubchem_dataset_adaptive import (
74 |                 PubChemDataModule as DataModule,
75 |             )
76 |         else:
77 |             print("Using non-adaptive dataloader")
78 |             from experiments.data.pubchem.pubchem_dataset_nonadaptive import (
79 |                 PubChemDataModule as DataModule,
80 |             )
81 |     elif hparams.dataset == "crossdocked":
82 |         dataset = "crossdocked"
83 |         if hparams.use_adaptive_loader:
84 |             print("Using adaptive dataloader")
85 |             from experiments.data.ligand.ligand_dataset_adaptive import (
86 |                 LigandPocketDataModule as DataModule,
87 |             )
88 |         else:
89 |             print("Using non-adaptive dataloader")
90 |             from experiments.data.ligand.ligand_dataset_nonadaptive import (
91 |                 LigandPocketDataModule as DataModule,
92 |             )
93 |     else:
94 |         raise ValueError(f"Unknown dataset: {hparams.dataset}")
95 | 
96 |     datamodule = DataModule(hparams)
97 | 
98 |     from experiments.data.data_info import GeneralInfos as DataInfos
99 | 
100 |     dataset_info = DataInfos(datamodule, hparams)
101 | 
102 |     train_smiles = (
103 |         (
104 |             list(datamodule.train_dataset.smiles)
105 |             if hparams.dataset != "pubchem"
106 |             else None
107 |         )
108 |         if not hparams.select_train_subset
109 |         else datamodule.train_smiles
110 |     )
111 |     prop_norm, prop_dist = None, None
112 |     if (
113 |         len(hparams.properties_list) > 0 and hparams.context_mapping
114 |     ) or hparams.property_training:
115 |         prop_norm = datamodule.compute_mean_mad(hparams.properties_list)
116 |         prop_dist = DistributionProperty(datamodule, hparams.properties_list)
117 |         prop_dist.set_normalizer(prop_norm)
118 | 
119 |     histogram = None
120 | 
121 |     if hparams.continuous and dataset != "crossdocked":
122 |         print("Using continuous diffusion")
123 |         if hparams.diffusion_pretraining:
124 |             print("Starting pre-training")
125 |             from experiments.diffusion_pretrain_continuous import Trainer
126 |         else:
127 |             from experiments.diffusion_continuous import Trainer
128 |     else:
129 |         print("Using discrete diffusion")
130 |         if hparams.diffusion_pretraining:
131 |             if hparams.additional_feats:
132 |                 print("Starting pre-training on PubChem3D with additional features")
133 |                 from experiments.diffusion_pretrain_discrete_addfeats import (
134 |                     Trainer,
135 |                 )
136 |             else:
137 |                 print("Starting pre-training on PubChem3D")
138 |                 from experiments.diffusion_pretrain_discrete import Trainer
139 |         elif (
140 |             dataset == "crossdocked"
141 |             and hparams.additional_feats
142 |             and not hparams.use_qm_props
143 |         ):
144 |             histogram = os.path.join(hparams.dataset_root, "size_distribution.npy")
145 |             histogram = np.load(histogram).tolist()
146 |             print("Ligand-pocket training using additional features")
147 |             from experiments.diffusion_discrete_pocket_addfeats import (
148 |                 Trainer,
149 |             )
150 |         else:
151 |             if dataset == "crossdocked":
152 |                 histogram = os.path.join(hparams.dataset_root, "size_distribution.npy")
153 |                 histogram = np.load(histogram).tolist()
154 |                 if hparams.continuous:
155 |                     print("Continuous ligand-pocket training")
156 |                     from experiments.diffusion_continuous_pocket import Trainer
157 |                 else:
158 |                     print("Discrete ligand-pocket training")
159 |                     from experiments.diffusion_discrete_pocket import (
160 |                         Trainer,
161 |                     )
162 |             else:
163 |                 if hparams.additional_feats:
164 |                     from experiments.diffusion_discrete_addfeats import Trainer
165 |                 else:
166 |                     from experiments.diffusion_discrete import Trainer
167 | 
168 |     model = Trainer(
169 |         hparams=hparams.__dict__,
170 |         dataset_info=dataset_info,
171 |         smiles_list=train_smiles,
172 |         histogram=histogram,
173 |         prop_dist=prop_dist,
174 |         prop_norm=prop_norm,
175 |     )
176 | 
177 |     from pytorch_lightning.plugins.environments import LightningEnvironment
178 | 
179 |     strategy = "ddp" if hparams.gpus > 1 else "auto"
180 |     # strategy = 'ddp_find_unused_parameters_true'
181 |     callbacks = [
182 |         ema_callback,
183 |         lr_logger,
184 |         checkpoint_callback,
185 |         TQDMProgressBar(refresh_rate=5),
186 |         ModelSummary(max_depth=2),
187 |     ]
188 | 
189 |     if hparams.ema_decay == 1.0:
190 |         callbacks = callbacks[1:]  # drop the EMA callback when decay == 1.0
191 | 
192 |     trainer = pl.Trainer(
193 |         accelerator="gpu" if hparams.gpus else "cpu",
194 |         devices=hparams.gpus if hparams.gpus else 1,
195 |         strategy=strategy,
196 |         plugins=LightningEnvironment(),
197 |         num_nodes=1,
198 |         logger=tb_logger,
199 |         enable_checkpointing=True,
200 |         accumulate_grad_batches=hparams.accum_batch,
201 |         val_check_interval=hparams.eval_freq,
202 |         gradient_clip_val=hparams.grad_clip_val,
203 |         callbacks=callbacks,
204 |         precision=hparams.precision,
205 |         num_sanity_val_steps=2,
206 |         max_epochs=hparams.num_epochs,
207 |         detect_anomaly=hparams.detect_anomaly,
208 |     )
209 | 
210 |     pl.seed_everything(seed=hparams.seed, workers=hparams.gpus > 1)
211 | 
212 |     ckpt_path = None
213 |     if hparams.load_ckpt is not None:
214 |         print("Loading from checkpoint ...")
215 |         import torch
216 | 
217 |         ckpt_path = hparams.load_ckpt
218 |         ckpt = torch.load(ckpt_path)
219 |         if ckpt["optimizer_states"][0]["param_groups"][0]["lr"] != hparams.lr:
220 |             print("Changing learning rate ...")
221 |             ckpt["optimizer_states"][0]["param_groups"][0]["lr"] = hparams.lr
222 |             ckpt["optimizer_states"][0]["param_groups"][0]["initial_lr"] = hparams.lr
223 |             ckpt_path = os.path.join(
224 |                 os.path.dirname(hparams.load_ckpt),
225 |                 f"retraining_with_lr{hparams.lr}.ckpt",
226 |             )
227 |             if not os.path.exists(ckpt_path):
228 |                 torch.save(ckpt, ckpt_path)
229 |     trainer.fit(
230 |         model=model,
231 |         datamodule=datamodule,
232 |         ckpt_path=ckpt_path if hparams.load_ckpt is not None else None,
233 |     )
234 | 
--------------------------------------------------------------------------------
/eqgat_diff/experiments/sampling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/sampling/__init__.py
--------------------------------------------------------------------------------
/eqgat_diff/experiments/sampling/fpscores.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/sampling/fpscores.pkl.gz
--------------------------------------------------------------------------------
/eqgat_diff/experiments/xtb_energy.py:
--------------------------------------------------------------------------------
1 | from ase import Atoms
2 | from xtb.ase.calculator import XTB
3 | import argparse
4 | import json
5 | import numpy as np
6 | 
7 | 
8 | def get_args():
9 |     # fmt: off
10 |     parser = argparse.ArgumentParser(description='Data generation')
11 |     parser.add_argument('--output-path', default=None, type=str,
12 |                         help='If set, saves the energy to a text file to the given path.')
13 |     parser.add_argument('--data', type=str,
14 |                         help='Input data as a JSON string with "positions" and "atom_types" keys')
15 |     args = parser.parse_args()
16 |     return args
17 | 
18 | 
19 | def calculate_xtb_energy(positions, atom_types):
20 |     atoms = Atoms(positions=positions, symbols=atom_types)
21 |     atoms.calc = XTB(method="GFN2-xTB")
22 |     pot_e = atoms.get_potential_energy()
23 |     forces = atoms.get_forces()
24 |     forces_norm = np.linalg.norm(forces)
25 | 
26 |     return pot_e, forces_norm
27 | 
28 | 
29 | if __name__ == "__main__":
30 |     args = get_args()
31 | 
32 |     # --data is read as a JSON string, e.g. '{"positions": [[0.0, 0.0, 0.0], ...], "atom_types": ["C", ...]}'
33 |     data = json.loads(args.data)
34 |     pot_e, forces_norm = calculate_xtb_energy(data["positions"], data["atom_types"])
35 |     if args.output_path is not None:
36 |         with open(args.output_path, "a") as f:
37 |             f.write(f"{pot_e}\n")
38 | 
--------------------------------------------------------------------------------
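A quick sanity-check sketch for `calculate_xtb_energy` (assumes `ase` and the xTB ASE calculator from the environment file are installed; the carbon-monoxide geometry is illustrative):

```python
from experiments.xtb_energy import calculate_xtb_energy

# a single carbon monoxide molecule, coordinates in Angstrom
positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 1.13]]
atom_types = ["C", "O"]

pot_e, forces_norm = calculate_xtb_energy(positions, atom_types)
print(f"E_pot = {pot_e:.4f} eV, |F| = {forces_norm:.4f} eV/A")
```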
/eqgat_diff/experiments/xtb_relaxation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import ase.units
3 | import ase
4 | from ase import Atoms
5 | from ase.io import write, read
6 | import ase.units as units
7 | import logging
8 | import os
9 | import subprocess
10 | from rdkit import rdBase
11 | import argparse
12 | import pickle
13 | import shutil
14 | from tqdm import tqdm
15 | 
16 | rdBase.DisableLog("rdApp.error")
17 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
18 | 
19 | 
20 | def make_dir(path):
21 |     isExist = os.path.exists(path)
22 |     if not isExist:
23 |         os.makedirs(path)
24 |         print(f"\nDirectory {path} has been created!")
25 |     else:
26 |         print(f"\nDirectory {path} exists already!")
27 | 
28 | 
29 | def parse_xtb_xyz(filename):
30 |     num_atoms = 0
31 |     energy = 0
32 |     pos = []
33 |     with open(filename, "r") as f:
34 |         for line_num, line in enumerate(f):
35 |             if line_num == 0:
36 |                 num_atoms = int(line)
37 |             elif line_num == 1:
38 |                 # xTB outputs energy in Hartrees: Hartree to eV
39 |                 energy = np.array(
40 |                     float(line.split(" ")[2]) * units.Hartree, dtype=np.float32
41 |                 )
42 |             elif line_num >= 2:
43 |                 _, x, y, z = line.split()
44 |                 pos.append([parse_float(x), parse_float(y), parse_float(z)])
45 | 
46 |     result = {
47 |         "num_atoms": num_atoms,
48 |         "energy": energy,
49 |         "pos": np.array(pos, dtype=np.float32),
50 |     }
51 |     return result
52 | 
53 | 
54 | def parse_float(s: str) -> float:
55 |     try:
56 |         return float(s)
57 |     except ValueError:
58 |         base, power = s.split("*^")
59 |         return float(base) * 10 ** float(power)
60 | 
61 | 
62 | def xtb_optimization(data, file_path):
63 |     zs = []
64 |     positions = []
65 |     energies = []
66 |     failed_mol_ids = []
67 | 
68 |     for i, mol in tqdm(enumerate(data)):
69 |         xtb_temp_dir = os.path.join(file_path, "xtb_tmp")
70 |         make_dir(xtb_temp_dir)
71 | 
72 |         mol_size = mol.GetNumAtoms()
73 |         opt = "lax" if mol_size > 60 else "normal"
74 | 
75 |         pos = np.array(mol.GetConformer().GetPositions(), dtype=np.float32)
76 |         atomic_number = []
77 |         for atom in mol.GetAtoms():
78 |             atomic_number.append(atom.GetAtomicNum())
79 |         z = np.array(atomic_number, dtype=np.int64)
80 |         mol = Atoms(numbers=z, positions=pos)
81 | 
82 |         mol_path = os.path.join(xtb_temp_dir, "xtb_conformer.xyz")
83 |         write(mol_path, images=mol)
84 | 
85 |         os.chdir(xtb_temp_dir)
86 |         try:
87 |             subprocess.call(
88 |                 # ["xtb", mol_path, "--opt", opt, "--cycles", "2000", "--gbsa", "water"],
89 |                 ["xtb", mol_path, "--opt", opt],
90 |                 stdout=subprocess.DEVNULL,
91 |                 stderr=subprocess.STDOUT,
92 |             )
93 |         except Exception:
94 |             print(f"Molecule with id {i} failed!")
95 |             failed_mol_ids.append(i)
96 |             continue
97 |         result = parse_xtb_xyz(os.path.join(xtb_temp_dir, "xtbopt.xyz"))
98 |         atom = read(filename=os.path.join(xtb_temp_dir, "xtbopt.xyz"))
99 | 
100 |         os.chdir(file_path)
101 |         shutil.rmtree(xtb_temp_dir)
102 | 
103 |         z = atom.get_atomic_numbers()
104 |         energy = result["energy"]
105 |         pos = atom.get_positions()
106 | 
107 |         zs.append(z)
108 |         energies.append(energy)
109 |         positions.append(pos)
110 | 
111 |     with open(os.path.join(file_path, "energies.pickle"), "wb") as f:
112 |         pickle.dump(energies, f)
113 |     with open(os.path.join(file_path, "failed_mol_ids.pickle"), "wb") as f:
114 |         pickle.dump(failed_mol_ids, f)
115 | 
116 |     return zs, positions, energies, failed_mol_ids
117 | 
118 | 
119 | def get_args():
120 |     # fmt: off
121 |     parser = argparse.ArgumentParser(description='Data generation')
122 |     parser.add_argument('--output-path', default=None, type=str,
123 |                         help='If set, saves the energy to a text file to the given path.')
124 |     parser.add_argument('--data', type=str,
125 |                         help='Path to a pickle file containing a list of RDKit molecules')
126 |     args = parser.parse_args()
127 |     return args
128 | 
129 | 
130 | if __name__ == "__main__":
131 |     args = get_args()
132 | 
133 |     # --data is read as a pickle of RDKit molecules with 3D conformers
134 |     with open(args.data, "rb") as f:
135 |         mols = pickle.load(f)
136 |     _ = xtb_optimization(mols, args.output_path)
137 | 
--------------------------------------------------------------------------------
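A small driver sketch for `xtb_optimization` (assumes RDKit is available and the `xtb` binary is on `PATH`; the SMILES and working directory are placeholders):

```python
import os

from rdkit import Chem
from rdkit.Chem import AllChem

from experiments.xtb_relaxation import xtb_optimization

# build two RDKit molecules with embedded 3D conformers
mols = []
for smi in ["CCO", "c1ccccc1"]:
    mol = Chem.AddHs(Chem.MolFromSmiles(smi))
    AllChem.EmbedMolecule(mol, randomSeed=42)  # ETKDG 3D embedding
    mols.append(mol)

# use an absolute path: xtb_optimization changes the working directory internally
workdir = os.path.abspath("xtb_runs")
os.makedirs(workdir, exist_ok=True)
zs, positions, energies, failed = xtb_optimization(mols, workdir)
print(len(energies), "molecules relaxed,", len(failed), "failed")
```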
/eqgat_diff/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = e3moldiffusion
3 | author = Tuan Le
4 | author_email = tuanle@hotmail.de
5 | description = A research module for 3D molecular generation using diffusion methods
6 | long_description = file: README.md
7 | long_description_content_type = text/markdown
8 | classifiers =
9 |     Development Status :: 3 - Alpha
10 |     Intended Audience :: Science/Research
11 |     License :: OSI Approved :: MIT License
12 |     Programming Language :: Python
13 |     Programming Language :: Python :: 3 :: Only
14 |     Programming Language :: Python :: 3.9
15 |     Topic :: Scientific/Engineering
16 |     Topic :: Utilities
17 | 
18 | [options]
19 | packages = find:
20 | python_requires = >=3.9.0
21 | setup_requires =
22 |     setuptools_scm
23 | install_requires =
24 | 
25 | [mypy]
26 | python_version = 3.9
27 | ignore_missing_imports = True
--------------------------------------------------------------------------------
/eqgat_diff/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(
4 |     use_scm_version={
5 |         "version_scheme": "post-release",
6 |         "write_to": "e3moldiffusion/_version.py",
7 |     },
8 |     setup_requires=["setuptools_scm"],
9 | )
--------------------------------------------------------------------------------
/inference/README.md:
--------------------------------------------------------------------------------
1 | ## Inference
2 | 
3 | We provide example Jupyter notebooks as well as bash scripts to generate molecules for QM9 and Geom-Drugs.
4 | Notice that the model weights for each dataset are not provided.
5 | As of now, we share the model weights upon request.
6 | 
7 | To run the bash scripts, execute the following:
8 | 
9 | ```bash
10 | mamba activate eqgatdiff
11 | bash run_eval_geom.sh
12 | bash run_eval_qm9.sh
13 | ```
14 | 
15 | In each bash script, the paths to the dataset root (`dataset_root`), the model checkpoint (`model_path`), and the output directory (`save_dir`) have to be included.
16 | E.g. in case of Geom-Drugs:
17 | 
18 | ```bash
19 | export PYTHONPATH="YOUR_ABSOLUTE_PATH/eqgat-diff/eqgat_diff"
20 | dataset_root="YOUR_ABSOLUTE_PATH/eqgat-diff/data/geom"
21 | model_path="YOUR_ABSOLUTE_PATH/eqgat-diff/weights/geom/best_mol_stab.ckpt"
22 | save_dir="YOUR_ABSOLUTE_PATH/eqgat-diff/inference/tmp/geom/gen_samples"
23 | 
24 | python YOUR_ABSOLUTE_PATH/eqgat_diff/experiments/run_evaluation.py \
25 |     --dataset drugs \
26 |     --dataset-root $dataset_root \
27 |     --model-path $model_path \
28 |     --save-dir $save_dir \
29 |     --batch-size 50 \
30 |     --ngraphs 100
31 | ```
32 | 
33 | Depending on your GPU, you can increase the batch size to generate more molecules per batch.
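34 | 
35 | For example, assuming your GPU memory fits the larger batch, 1,000 molecules can be sampled in batches of 250 by adjusting the two flags:
36 | 
37 | ```bash
38 | python YOUR_ABSOLUTE_PATH/eqgat_diff/experiments/run_evaluation.py \
39 |     --dataset drugs \
40 |     --dataset-root $dataset_root \
41 |     --model-path $model_path \
42 |     --save-dir $save_dir \
43 |     --batch-size 250 \
44 |     --ngraphs 1000
45 | ```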
--------------------------------------------------------------------------------
/inference/run_eval_geom.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | export PYTHONPATH="YOUR_ABSOLUTE_PATH/eqgat-diff/eqgat_diff"
4 | dataset_root="YOUR_ABSOLUTE_PATH/eqgat-diff/data/geom"
5 | model_path="YOUR_ABSOLUTE_PATH/eqgat-diff/weights/geom/best_mol_stab.ckpt"
6 | save_dir="YOUR_ABSOLUTE_PATH/eqgat-diff/inference/tmp/geom/gen_samples"
7 | 
8 | python YOUR_ABSOLUTE_PATH/eqgat_diff/experiments/run_evaluation.py \
9 |     --dataset drugs \
10 |     --dataset-root $dataset_root \
11 |     --model-path $model_path \
12 |     --save-dir $save_dir \
13 |     --batch-size 50 \
14 |     --ngraphs 100
--------------------------------------------------------------------------------
/inference/run_eval_qm9.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | export PYTHONPATH="YOUR_ABSOLUTE_PATH/eqgat-diff/eqgat_diff"
4 | dataset_root="YOUR_ABSOLUTE_PATH/eqgat-diff/data/qm9"
5 | model_path="YOUR_ABSOLUTE_PATH/eqgat-diff/weights/qm9/best_mol_stab.ckpt"
6 | save_dir="YOUR_ABSOLUTE_PATH/eqgat-diff/inference/tmp/qm9/gen_samples"
7 | 
8 | python YOUR_ABSOLUTE_PATH/eqgat_diff/experiments/run_evaluation.py \
9 |     --dataset qm9 \
10 |     --dataset-root $dataset_root \
11 |     --model-path $model_path \
12 |     --save-dir $save_dir \
13 |     --batch-size 50 \
14 |     --ngraphs 100
--------------------------------------------------------------------------------
/inference/tmp/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/inference/tmp/.gitkeep
--------------------------------------------------------------------------------
/inference/tmp/geom/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/inference/tmp/geom/.gitkeep
--------------------------------------------------------------------------------
/weights/README.md:
--------------------------------------------------------------------------------
1 | ## Model Weights
2 | 
3 | The model weights for `qm9` and `geom` should be stored in the respective folders.
4 | Currently, we provide trained model weights only upon request.
5 | 
6 | Please reach out to tuan.le@pfizer.com or julian.cremer@pfizer.com if you are interested.
7 | 
--------------------------------------------------------------------------------
/weights/geom/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/weights/geom/.gitkeep
--------------------------------------------------------------------------------
/weights/qm9/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/weights/qm9/.gitkeep
--------------------------------------------------------------------------------
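Note: the inference scripts in `inference/` assume the requested checkpoints are stored under the following names:

```
weights
├── geom
│   └── best_mol_stab.ckpt
└── qm9
    └── best_mol_stab.ckpt
```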