├── .gitignore
├── README.md
├── assets
├── conditional-ligand-pocket-sampling-trajectory.gif
├── conditional-sampling-maximized-polarizability.gif
├── conditional-sampling-minimized-polarizability.gif
├── unconditional-sampling-trajectory.gif
└── unconditional-sampling-trajectory.mp4
├── data
├── README.md
├── geom.zip
├── geom
│ └── .gitkeep
├── qm9.zip
└── qm9
│ └── .gitkeep
├── eqgat_diff
├── .gitignore
├── .ipynb_checkpoints
│ └── Untitled-checkpoint.ipynb
├── README.md
├── callbacks
│ └── ema.py
├── e3moldiffusion
│ ├── README.md
│ ├── __init__.py
│ ├── chem.py
│ ├── convs.py
│ ├── coordsatomsbonds.py
│ ├── gnn.py
│ ├── modules.py
│ └── molfeat.py
├── environment.yml
├── experiments
│ ├── __init__.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── abstract_dataset.py
│ │ ├── adaptive_loader.py
│ │ ├── calculate_dipoles.py
│ │ ├── calculate_energies.py
│ │ ├── config_file.py
│ │ ├── data_info.py
│ │ ├── distributions.py
│ │ ├── geom
│ │ │ ├── calculate_energies.py
│ │ │ ├── geom_dataset_adaptive.py
│ │ │ ├── geom_dataset_adaptive_qm.py
│ │ │ ├── geom_dataset_energy.py
│ │ │ └── geom_dataset_nonadaptive.py
│ │ ├── ligand
│ │ │ ├── __init__.py
│ │ │ ├── constants.py
│ │ │ ├── geometry_utils.py
│ │ │ ├── ligand_dataset_adaptive.py
│ │ │ ├── ligand_dataset_nonadaptive.py
│ │ │ ├── molecule_builder.py
│ │ │ ├── process_bindingmoad.py
│ │ │ ├── process_crossdocked.py
│ │ │ └── utils.py
│ │ ├── metrics.py
│ │ ├── pubchem
│ │ │ ├── download_pubchem.py
│ │ │ ├── preprocess_pubchem.py
│ │ │ ├── process_pubchem.py
│ │ │ ├── pubchem_dataset_adaptive.py
│ │ │ └── pubchem_dataset_nonadaptive.py
│ │ ├── qm9
│ │ │ └── qm9_dataset.py
│ │ └── utils.py
│ ├── diffusion
│ │ ├── __init__.py
│ │ ├── categorical.py
│ │ ├── continuous.py
│ │ └── utils.py
│ ├── diffusion_continuous.py
│ ├── diffusion_continuous_pocket.py
│ ├── diffusion_discrete.py
│ ├── diffusion_discrete_addfeats.py
│ ├── diffusion_discrete_pocket.py
│ ├── diffusion_discrete_pocket_addfeats.py
│ ├── diffusion_pretrain_discrete.py
│ ├── diffusion_pretrain_discrete_addfeats.py
│ ├── docking.py
│ ├── docking_mgl.py
│ ├── generate_ligands.py
│ ├── hparams.py
│ ├── losses.py
│ ├── molecule_utils.py
│ ├── run_evaluation.py
│ ├── run_evaluation_ligand.py
│ ├── run_train.py
│ ├── sampling
│ │ ├── __init__.py
│ │ ├── analyze.py
│ │ ├── analyze_strict.py
│ │ ├── fpscores.pkl.gz
│ │ └── utils.py
│ ├── utils.py
│ ├── xtb_energy.py
│ ├── xtb_relaxation.py
│ └── xtb_wrapper.py
├── setup.cfg
└── setup.py
├── inference
├── README.md
├── run_eval_geom.sh
├── run_eval_qm9.sh
├── sampling_geom.ipynb
├── sampling_qm9.ipynb
└── tmp
│ ├── .gitkeep
│ └── geom
│ └── .gitkeep
└── weights
├── README.md
├── geom
└── .gitkeep
└── qm9
└── .gitkeep
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | .eggs
3 | __pycache__
4 | build
5 | dist
6 | .idea/*
7 |
8 | *.DS_Store*
9 | *.ipynb_checkpoints/*
10 |
11 | .vscode/
12 |
13 | data/geom/*
14 | !data/geom/.gitkeep
15 |
16 | data/qm9/*
17 | !data/qm9/.gitkeep
18 |
19 | inference/tmp/geom/*
20 | !inference/tmp/geom/.gitkeep
21 | inference/tmp/qm9/*
22 | !inference/tmp/qm9/.gitkeep
23 |
24 | weights/geom/*
25 | !weights/geom/.gitkeep
26 |
27 | weights/qm9/*
28 | !weights/qm9/.gitkeep
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # eqgat-diff
2 |
3 | This repository serves as a placeholder.
4 | The codebase is stored in `eqgat_diff/` directory.
5 |
6 | Below, we show some sampling trajectories from our models. The corresponding .gif files can be found in the `assets/` directory.
7 |
8 |
9 | ### Unconditional Sampling Trajectory
10 |
11 |
12 |
13 |
14 | ### Conditional Sampling Trajectory on maximized polarizability
15 |
16 |
17 |
18 |
19 | ### Conditional Sampling Trajectory on minimized polarizability
20 |
21 |
22 |
23 |
24 |
25 | ### Conditional Sampling Trajectory on a fixed protein pocket
26 |
27 |
28 |
--------------------------------------------------------------------------------
/assets/conditional-ligand-pocket-sampling-trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/conditional-ligand-pocket-sampling-trajectory.gif
--------------------------------------------------------------------------------
/assets/conditional-sampling-maximized-polarizability.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/conditional-sampling-maximized-polarizability.gif
--------------------------------------------------------------------------------
/assets/conditional-sampling-minimized-polarizability.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/conditional-sampling-minimized-polarizability.gif
--------------------------------------------------------------------------------
/assets/unconditional-sampling-trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/unconditional-sampling-trajectory.gif
--------------------------------------------------------------------------------
/assets/unconditional-sampling-trajectory.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/assets/unconditional-sampling-trajectory.mp4
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Dataset
2 |
3 | We provide the reduced datasets for QM9 and Geom-Drugs to run the inference.
4 |
5 | Please extract `qm9.zip` and `geom.zip`.
6 | Notice that the training/validation and test sets are not provided since they require larger sizes. As of now, we upload only the required empirical prior distributions to make inference work.
--------------------------------------------------------------------------------
/data/geom.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/data/geom.zip
--------------------------------------------------------------------------------
/data/geom/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/data/geom/.gitkeep
--------------------------------------------------------------------------------
/data/qm9.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/data/qm9.zip
--------------------------------------------------------------------------------
/data/qm9/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/data/qm9/.gitkeep
--------------------------------------------------------------------------------
/eqgat_diff/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | .eggs
3 | __pycache__
4 | build
5 | dist
6 | .idea/*
7 |
8 | *.DS_Store*
9 |
10 | # setuptools_scm
11 | e3moldiffusion/_version.py
12 | figs/
13 | e3moldiffusion/potential.py
14 | geom/train_potential.py
15 | geom/data/*
16 | !geom/data/.gitkeep
17 | geom/slurm_outs/*
18 |
19 | e3moldiffusion/sample_data.py
20 | sample_data/*
21 |
22 | .vscode/
23 |
24 | configs/
25 | scripts/
26 | generate_ligands.sh
27 | # scripts/experiment*.sl
28 |
29 | geom/.ipynb_checkpoints/
30 | geom/trajs/*
31 | geom/my_movie.gif
32 | geom/vis.ipynb
33 |
34 | geom/logs
35 | geom/evaluation/
36 | geom/train_gan.py
37 |
38 | # fullerene/
39 | fullerene/data/
40 |
41 | notebooks/
42 | _old/
43 |
44 | *xlsx
45 |
--------------------------------------------------------------------------------
/eqgat_diff/.ipynb_checkpoints/Untitled-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 5
6 | }
7 |
--------------------------------------------------------------------------------
/eqgat_diff/README.md:
--------------------------------------------------------------------------------
1 | # E(3) Equivariant Diffusion for Molecules
2 |
3 | Research repository exploring the capabilities of (continuous and discrete) denoising diffusion probabilistic models applied on molecular data.
4 |
5 | ## Installation
6 | Best installed using mamba.
7 | ```bash
8 | mamba env create -f environment.yml
9 | ```
10 |
11 | ## Experiments
12 | We will update the repository with the corresponding datasets to run the experiments.
13 |
14 | Generally, in `experiments/run_train.py` the main training script is executed, while in `configs/*.yaml` the configuration files for each dataset are stored.
15 |
16 | An example training run can be executed with the following command
17 |
18 | ```bash
19 | mamba activate eqgatdiff
20 | export PYTHONPATH="YOUR_ABSOLUTE_PATH/eqgat-diff/eqgat_diff"
21 | python experiments/run_train.py --conf configs/my_config_file.yaml
22 | ```
23 |
24 | Note that currently we do not provide the datasets; hence the code currently serves for reviewing how the model and the training runs are implemented.
25 |
26 | For example, an EQGAT-diff model that leverages Gaussian diffusion on atomic coordinates, but discrete diffusion for atom- and bond-types is implemented in `experiments/diffusion_discrete.py`.
27 |
28 | The same model, that leverages Gaussian diffusion for atomic coordinates, atom- and bond-types is implemented in `experiments/diffusion_continuous.py`.
29 |
30 | All configurable hyperparameters are listed in `experiments/hparams.py`
31 |
32 | ## Inference and Weights
33 |
34 | Currently we are still in the progress of publishing the code. Upon request, we provide model weights of the QM9 and Geom-Drugs models.
35 | Please look into `inference/` and `weights/` subdirectory for more details.
--------------------------------------------------------------------------------
/eqgat_diff/callbacks/ema.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.callbacks import Callback
2 | from pytorch_lightning import LightningModule, Trainer
3 |
4 | from torch_ema import ExponentialMovingAverage as EMA
5 |
6 |
class ExponentialMovingAverage(Callback):
    """
    Callback maintaining an exponential moving average (EMA) over model weights.

    The most recent (raw) weights are only used during the training steps;
    otherwise (e.g. during validation) the smoothed EMA weights are copied
    into the module.
    """

    def __init__(self, decay: float, *args, **kwargs):
        """
        Args:
            decay (float): decay of the exponential moving average
        """
        super().__init__()  # was missing: initialize the Callback base class
        self.decay = decay
        # Created lazily in on_fit_start once the module's parameters exist.
        self.ema = None
        # Buffers a state dict restored before the EMA object was created.
        self._to_load = None

    def on_fit_start(self, trainer, pl_module: LightningModule):
        if self.ema is None:
            self.ema = EMA(pl_module.parameters(), decay=self.decay)
        if self._to_load is not None:
            # Apply a checkpoint state that arrived before EMA construction.
            self.ema.load_state_dict(self._to_load)
            self._to_load = None

        # Back up raw weights and switch to the smoothed weights
        # (e.g. for the validation sanity check that runs before training).
        self.ema.store()
        self.ema.copy_to()

    def on_train_epoch_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        # Training must run on the raw (most recent) weights again.
        self.ema.restore()

    def on_train_batch_end(self, trainer, pl_module: LightningModule, *args, **kwargs):
        self.ema.update()

    def on_validation_epoch_start(
        self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
    ):
        # Back up raw weights, then validate with the smoothed EMA weights.
        self.ema.store()
        self.ema.copy_to()

    def load_state_dict(self, state_dict):
        if "exponential_moving_average" in state_dict:
            if self.ema is None:
                # EMA not built yet; defer loading until on_fit_start.
                self._to_load = state_dict["exponential_moving_average"]
            else:
                self.ema.load_state_dict(state_dict["exponential_moving_average"])

    def state_dict(self):
        if self.ema is None:
            # Bug fix: checkpointing before on_fit_start used to crash with
            # AttributeError; round-trip any pending state instead.
            return {"exponential_moving_average": self._to_load}
        return {"exponential_moving_average": self.ema.state_dict()}
--------------------------------------------------------------------------------
/eqgat_diff/e3moldiffusion/README.md:
--------------------------------------------------------------------------------
1 | # e3moldiffusion
2 | Directory implements the package to create the diffusion/score model.
3 |
--------------------------------------------------------------------------------
/eqgat_diff/e3moldiffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/e3moldiffusion/__init__.py
--------------------------------------------------------------------------------
/eqgat_diff/e3moldiffusion/chem.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import List, Tuple
3 |
4 | import torch
5 | from rdkit import Chem
6 | from rdkit.Chem import PeriodicTable as PT, rdDepictor as DP, rdMolAlign as MA
7 | from rdkit.Chem.Draw import rdMolDraw2D as MD2
8 | from rdkit.Chem.rdchem import GetPeriodicTable, Mol
9 | from rdkit.Chem.rdmolops import RemoveHs
10 |
11 |
12 | def set_conformer_positions(conf, pos):
13 | for i in range(pos.shape[0]):
14 | conf.SetAtomPosition(i, pos[i].tolist())
15 | return conf
16 |
17 |
18 | def update_data_rdmol_positions(data):
19 | for i in range(data.pos.size(0)):
20 | data.rdmol.GetConformer(0).SetAtomPosition(i, data.pos[i].tolist())
21 | return data
22 |
23 |
24 | def update_data_pos_from_rdmol(data):
25 | new_pos = torch.FloatTensor(data.rdmol.GetConformer(0).GetPositions()).to(data.pos)
26 | data.pos = new_pos
27 | return data
28 |
29 |
30 | def set_rdmol_positions(rdkit_mol, pos):
31 | """
32 | Args:
33 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
34 | pos: (N_atoms, 3)
35 | """
36 | mol = copy.deepcopy(rdkit_mol)
37 | mol = set_rdmol_positions_(mol, pos)
38 | return mol
39 |
40 |
41 | def set_rdmol_positions_(mol, pos):
42 | """
43 | Args:
44 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
45 | pos: (N_atoms, 3)
46 | """
47 | for i in range(pos.shape[0]):
48 | mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())
49 | return mol
50 |
51 |
52 | def get_atom_symbol(atomic_number):
53 | return PT.GetElementSymbol(GetPeriodicTable(), atomic_number)
54 |
55 |
56 | def mol_to_smiles(mol: Mol) -> str:
57 | return Chem.MolToSmiles(mol, allHsExplicit=True)
58 |
59 |
60 | def mol_to_smiles_without_Hs(mol: Mol) -> str:
61 | return Chem.MolToSmiles(Chem.RemoveHs(mol))
62 |
63 |
def remove_duplicate_mols(molecules: List[Mol]) -> List[Mol]:
    """Return the input molecules with SMILES-level duplicates removed.

    The canonical SMILES (with explicit Hs, via ``mol_to_smiles``) is used as
    the identity key. The first occurrence of each SMILES is kept and input
    order is preserved. Uses a set for O(n) total work instead of the
    original quadratic pairwise comparison.
    """
    seen_smiles = set()
    unique_mols: List[Mol] = []
    for molecule in molecules:
        smiles = mol_to_smiles(molecule)
        if smiles not in seen_smiles:
            seen_smiles.add(smiles)
            unique_mols.append(molecule)
    return unique_mols
79 |
80 |
81 | def get_atoms_in_ring(mol):
82 | atoms = set()
83 | for ring in mol.GetRingInfo().AtomRings():
84 | for a in ring:
85 | atoms.add(a)
86 | return atoms
87 |
88 |
89 | def get_2D_mol(mol):
90 | mol = copy.deepcopy(mol)
91 | DP.Compute2DCoords(mol)
92 | return mol
93 |
94 |
95 | def draw_mol_svg(mol, molSize=(450, 150), kekulize=False):
96 | mc = Chem.Mol(mol.ToBinary())
97 | if kekulize:
98 | try:
99 | Chem.Kekulize(mc)
100 | except:
101 | mc = Chem.Mol(mol.ToBinary())
102 | if not mc.GetNumConformers():
103 | DP.Compute2DCoords(mc)
104 | drawer = MD2.MolDraw2DSVG(molSize[0], molSize[1])
105 | drawer.DrawMolecule(mc)
106 | drawer.FinishDrawing()
107 | svg = drawer.GetDrawingText()
108 | # It seems that the svg renderer used doesn't quite hit the spec.
109 | # Here are some fixes to make it work in the notebook, although I think
110 | # the underlying issue needs to be resolved at the generation step
111 | # return svg.replace('svg:','')
112 | return svg
113 |
114 |
115 | def GetBestRMSD(probe, ref):
116 | probe = RemoveHs(probe)
117 | ref = RemoveHs(ref)
118 | rmsd = MA.GetBestRMS(probe, ref)
119 | return rmsd
120 |
--------------------------------------------------------------------------------
/eqgat_diff/e3moldiffusion/molfeat.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import torch
4 | import rdkit
5 | from rdkit import Chem
6 | from rdkit.Chem import GetPeriodicTable
7 | from torch import nn
8 | from torch_geometric.data import Data
9 | from torch_geometric.data import Data, Batch
10 |
11 |
12 | PERIODIC_TABLE = GetPeriodicTable()
13 |
14 | # allowable multiple choice node and edge features
15 | allowable_features = {
16 | 'possible_atomic_num_list' : list(range(1, 119)) + ['misc'],
17 | 'possible_chirality_list' : [
18 | 'CHI_UNSPECIFIED',
19 | 'CHI_TETRAHEDRAL_CW',
20 | 'CHI_TETRAHEDRAL_CCW',
21 | 'CHI_OTHER'
22 | ],
23 | 'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
24 | 'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
25 | 'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
26 | 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
27 | 'possible_hybridization_list' : [
28 | 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'
29 | ],
30 | 'possible_is_aromatic_list': [False, True],
31 | 'possible_is_in_ring_list': [False, True],
32 | 'possible_bond_type_list' : [
33 | 'SINGLE',
34 | 'DOUBLE',
35 | 'TRIPLE',
36 | 'AROMATIC'#,
37 | # 'misc'
38 | ],
39 | 'possible_bond_stereo_list': [
40 | 'STEREONONE',
41 | 'STEREOZ',
42 | 'STEREOE',
43 | 'STEREOCIS',
44 | 'STEREOTRANS',
45 | 'STEREOANY',
46 | ],
47 | 'possible_is_conjugated_list': [False, True],
48 | }
49 |
def atom_type_config(dataset: str = "qm9"):
    """Return the element-symbol -> class-index mapping for a dataset.

    Args:
        dataset: either "qm9" or "drugs".

    Returns:
        dict mapping element symbols to contiguous integer class labels.

    Raises:
        ValueError: if ``dataset`` is not a known dataset name (previously
            this crashed with an opaque UnboundLocalError on ``mapping``).
    """
    if dataset == "qm9":
        return {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4}
    if dataset == "drugs":
        return {
            "H": 0,
            "B": 1,
            "C": 2,
            "N": 3,
            "O": 4,
            "F": 5,
            "Al": 6,
            "Si": 7,
            "P": 8,
            "S": 9,
            "Cl": 10,
            "As": 11,
            "Br": 12,
            "I": 13,
            "Hg": 14,
            "Bi": 15,
        }
    raise ValueError(f"Unknown dataset: {dataset!r}; expected 'qm9' or 'drugs'")
73 |
def safe_index(l, e):
    """
    Return the index of element e in list l. If e is not present, return the
    last index (the conventional trailing 'misc' bucket of each
    allowable-features list).
    """
    try:
        return l.index(e)
    except ValueError:
        # Only "not found" is expected here; the original bare `except:`
        # would also have swallowed e.g. KeyboardInterrupt.
        return len(l) - 1
82 |
83 |
def atom_to_feature_vector(atom):
    """
    Converts rdkit atom object to feature list of indices
    :param atom: rdkit atom object
    :return: list of 5 integer indices (atomic num, degree, hybridization,
        is_aromatic, is_in_ring) into the corresponding allowable_features lists
    """
    # safe_index maps unseen values to the trailing 'misc' bucket; the two
    # boolean features use .index directly since False/True always match.
    atom_feature = [
        safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()),
        safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()),
        safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())),
        allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()),
        allowable_features['possible_is_in_ring_list'].index(atom.IsInRing())
    ]
    return atom_feature
98 |
99 |
100 | def get_atom_feature_dims():
101 | return list(map(len, [
102 | allowable_features['possible_atomic_num_list'],
103 | allowable_features['possible_degree_list'],
104 | allowable_features['possible_hybridization_list'],
105 | allowable_features['possible_is_aromatic_list'],
106 | allowable_features['possible_is_in_ring_list']
107 | ]))
108 |
109 | def bond_to_feature_vector(bond):
110 | """
111 | Converts rdkit bond object to feature list of indices
112 | :param mol: rdkit bond object
113 | :return: list
114 | """
115 | bond_feature = [
116 | safe_index(allowable_features['possible_bond_type_list'], str(bond.GetBondType()))
117 | ]
118 | return bond_feature
119 |
120 |
121 | def get_bond_feature_dims():
122 | return list(map(len, [
123 | allowable_features['possible_bond_type_list'],
124 | ]))
125 |
126 |
127 | def atom_feature_vector_to_dict(atom_feature):
128 | [atomic_num_idx,
129 | degree_idx,
130 | hybridization_idx,
131 | is_aromatic_idx,
132 | is_in_ring_idx] = atom_feature
133 |
134 | feature_dict = {
135 | 'atomic_num': allowable_features['possible_atomic_num_list'][atomic_num_idx],
136 | 'degree': allowable_features['possible_degree_list'][degree_idx],
137 | 'hybridization': allowable_features['possible_hybridization_list'][hybridization_idx],
138 | 'is_aromatic': allowable_features['possible_is_aromatic_list'][is_aromatic_idx],
139 | 'is_in_ring': allowable_features['possible_is_in_ring_list'][is_in_ring_idx]
140 | }
141 | return feature_dict
142 |
143 |
144 | def bond_feature_vector_to_dict(bond_feature):
145 | [bond_type_idx] = bond_feature
146 |
147 | feature_dict = {
148 | 'bond_type': allowable_features['possible_bond_type_list'][bond_type_idx]
149 | }
150 |
151 | return feature_dict
152 |
153 |
154 | def smiles_or_mol_to_graph(smol: Union[str, Chem.Mol], create_bond_graph: bool = True):
155 | if isinstance(smol, str):
156 | mol = Chem.MolFromSmiles(smol)
157 | else:
158 | mol = smol
159 |
160 | # atoms
161 | atom_features_list = []
162 | atom_element_name_list = []
163 | for atom in mol.GetAtoms():
164 | atom_features_list.append(atom_to_feature_vector(atom))
165 | atom_element_name_list.append(PERIODIC_TABLE.GetElementSymbol(atom.GetAtomicNum()))
166 |
167 |
168 | x = torch.tensor(atom_features_list, dtype=torch.int64)
169 | assert x.size(-1) == 5
170 | # only take atom element
171 | # x = x[:, 0].view(-1, 1)
172 |
173 | if create_bond_graph:
174 | # bonds
175 | edges_list = []
176 | edge_features_list = []
177 | for bond in mol.GetBonds():
178 | i = bond.GetBeginAtomIdx()
179 | j = bond.GetEndAtomIdx()
180 | edge_feature = bond_to_feature_vector(bond)
181 | # add edges in both directions
182 | edges_list.append((i, j))
183 | edge_features_list.append(edge_feature[0])
184 | edges_list.append((j, i))
185 | edge_features_list.append(edge_feature[0])
186 |
187 | # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
188 | edge_index = torch.tensor(edges_list, dtype=torch.int64).T
189 | # data.edge_attr: Edge feature matrix with shape [num_edges]
190 | edge_attr = torch.tensor(edge_features_list, dtype=torch.int64)
191 |
192 | if edge_index.numel() > 0: # Sort indices.
193 | perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
194 | edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]
195 | else:
196 | edge_index = edge_attr = None
197 |
198 | data = Data(x=x, atom_elements = atom_element_name_list, edge_index=edge_index, edge_attr=edge_attr)
199 | return data
200 |
201 |
class AtomEncoder(nn.Module):
    """Sums one embedding table per categorical atom feature.

    By default only the first feature (atomic number) is embedded; pass
    ``use_all_atom_features=True`` to also embed degree, hybridization,
    aromaticity and ring membership.
    """

    def __init__(self, emb_dim, max_norm: float = 10.0,
                 use_all_atom_features: bool = False):
        super(AtomEncoder, self).__init__()
        feature_dims = get_atom_feature_dims()
        if not use_all_atom_features:
            # Restrict to the atomic-number feature only.
            feature_dims = feature_dims[:1]
        self.atom_embedding_list = nn.ModuleList(
            nn.Embedding(dim, emb_dim, max_norm=max_norm) for dim in feature_dims
        )
        self.reset_parameters()

    def reset_parameters(self):
        for emb in self.atom_embedding_list:
            nn.init.xavier_uniform_(emb.weight.data)

    def forward(self, x):
        # x holds one integer column per embedding table; the per-feature
        # embeddings are summed into a single vector per atom.
        x_embedding = 0
        for i, emb in enumerate(self.atom_embedding_list):
            x_embedding = x_embedding + emb(x[:, i])
        return x_embedding
226 |
class BondEncoderOHE(nn.Module):
    """Embeds one-hot encoded bond features with a single linear layer.

    Args:
        emb_dim: output embedding dimension.
        max_num_classes: number of bond classes, i.e. the one-hot width
            expected in ``edge_attr``.
    """

    def __init__(self, emb_dim, max_num_classes: int):
        super(BondEncoderOHE, self).__init__()
        self.linear = nn.Linear(max_num_classes, emb_dim)
        self.reset_parameters()

    def reset_parameters(self):
        # Bug fix: the original accessed `self.linear.reset_parameters`
        # without calling it, so this method was a no-op.
        self.linear.reset_parameters()

    def forward(self, edge_attr):
        """Project one-hot edge attributes (..., max_num_classes) to (..., emb_dim)."""
        bond_embedding = self.linear(edge_attr)
        return bond_embedding
237 |
238 |
class BondEncoder(nn.Module):
    """Embedding lookup for integer bond-type labels.

    The table is sized with three extra slots beyond the base bond-type
    vocabulary (room for auxiliary edge labels such as radius-graph edges).
    """

    def __init__(self, emb_dim, max_norm: float = 10.0):
        super(BondEncoder, self).__init__()
        num_bond_types = get_bond_feature_dims()[0]
        self.bond_embedding = nn.Embedding(
            num_bond_types + 3, emb_dim, max_norm=max_norm
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.bond_embedding.weight.data)

    def forward(self, edge_attr):
        return self.bond_embedding(edge_attr)
250 |
251 |
252 | if __name__ == '__main__':
253 | FULL_BOND_FEATURE_DIMS = get_bond_feature_dims()
254 | bond_attrs = FULL_BOND_FEATURE_DIMS[0] + 3
255 | from torch_sparse import coalesce
256 | from torch_cluster import radius_graph
257 | smol = "O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5"
258 | data = smiles_or_mol_to_graph(smol)
259 | print(data)
260 | print(data.x)
261 | print(data.atom_elements)
262 |
263 | smol1 = "CC(=O)Oc1ccccc1C(=O)O"
264 | smol2 = "CCCc1nn(C)c2c(=O)[nH]c(-c3cc(S(=O)(=O)N4CCN(C)CC4)ccc3OCC)nc12"
265 |
266 | datalist = [smiles_or_mol_to_graph(s) for s in [smol, smol1, smol2]]
267 | data = Batch.from_data_list(datalist)
268 | bond_edge_index, bond_edge_attr, batch = data.edge_index, data.edge_attr, data.batch
269 |
270 |
271 | atomencoder = AtomEncoder(emb_dim=16)
272 | bondencoder = BondEncoder(emb_dim=16)
273 |
274 | x = atomencoder(data.x)
275 | edge_attr = bondencoder(data.edge_attr)
276 |
277 | pos = torch.randn(data.x.size(0), 3)
278 | bond_edge_index, bond_edge_attr = data.edge_index, data.edge_attr
279 | radius_edge_index = radius_graph(pos, r=5.0, max_num_neighbors=64, flow="source_to_target")
280 | # fromr radius_edge_index remove that are in bond_edge_index
281 |
282 | radius_feat = FULL_BOND_FEATURE_DIMS[0] + 1
283 | radius_edge_attr = torch.full((radius_edge_index.size(1), ), fill_value=radius_feat, device=pos.device, dtype=torch.long)
284 | # need to combine radius-edge-index with graph-edge-index
285 |
286 | nbonds = bond_edge_index.size(1)
287 | nradius = radius_edge_index.size(1)
288 |
289 | combined_edge_index = torch.cat([bond_edge_index, radius_edge_index], dim=-1)
290 | combined_edge_attr = torch.cat([bond_edge_attr, radius_edge_attr], dim=0)
291 |
292 | nbefore = combined_edge_index.size(1)
293 | # coalesce
294 | combined_edge_index, combined_edge_attr = coalesce(index=combined_edge_index, value=combined_edge_attr, m=pos.size(0), n=pos.size(0), op="min")
295 | print(combined_edge_index[:, :30])
296 | print()
297 | print(combined_edge_attr[:30])
--------------------------------------------------------------------------------
/eqgat_diff/environment.yml:
--------------------------------------------------------------------------------
1 | name: eqgatdiff
2 | channels:
3 | - anaconda
4 | - conda-forge
5 | - defaults
6 | - pytorch
7 | - pyg
8 | - nvidia
9 | dependencies:
10 | - python=3.10.0
11 | - pip
12 | - notebook
13 | - matplotlib
14 | - pytorch::pytorch=2.0.1
15 | - pyg::pyg=2.3.0
16 | - pyg::pytorch-scatter
17 | - pyg::pytorch-sparse
18 | - pyg::pytorch-cluster
19 | - conda-forge::rdkit=2023.03.2
20 | - conda-forge::pytorch-lightning=2.0.6
21 | - conda-forge::jupyter
22 | - conda-forge::openbabel
23 | - xtb
24 | - pip:
25 | - xtb
26 | - ase
27 | - py3Dmol
28 | - lmdb
29 | - nglview
30 | - rmsd
31 | - torch_ema
32 | - tensorboard
33 | - biopython
34 | - xtb
--------------------------------------------------------------------------------
/eqgat_diff/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/__init__.py
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/data/__init__.py
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/abstract_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch.utils.data import Subset
5 |
6 | from experiments.data.adaptive_loader import AdaptiveLightningDataset
7 |
8 | try:
9 | from torch_geometric.data import LightningDataset
10 | except ImportError:
11 | from torch_geometric.data.lightning import LightningDataset
12 | from experiments.data.distributions import DistributionNodes
13 |
14 |
def maybe_subset(
    ds, random_subset: Optional[float] = None, split=None
) -> torch.utils.data.Dataset:
    """Optionally return a random fractional subset of a dataset.

    Validation/test splits — and the case where no fraction is given — are
    returned untouched; otherwise a random ``random_subset`` fraction of the
    indices is drawn.
    """
    skip_subsampling = random_subset is None or split in {"test", "val"}
    if skip_subsampling:
        return ds
    keep = int(random_subset * len(ds))
    chosen = torch.randperm(len(ds))[:keep]
    return Subset(ds, chosen)
23 |
24 |
class Mixin:
    """Dataset-statistics helpers shared by the datamodules below.

    Expects the inheriting datamodule to expose ``self.dataloaders``, a dict
    with "train"/"val"/"test" loaders yielding torch_geometric batches —
    TODO confirm against the concrete datamodule implementations.
    """

    def __getitem__(self, idx):
        return self.dataloaders["train"][idx]

    def node_counts(self, max_nodes_possible=300):
        """Empirical distribution of graph sizes (number of nodes per graph)."""
        all_counts = torch.zeros(max_nodes_possible)
        for split in ["train", "val", "test"]:
            for i, batch in enumerate(self.dataloaders[split]):
                for data in batch:
                    if data is None:
                        continue
                    # counts[k] = number of nodes in the k-th graph of the batch
                    unique, counts = torch.unique(data.batch, return_counts=True)
                    for count in counts:
                        all_counts[count] += 1
        # Trim trailing zeros beyond the largest observed size, then normalize.
        max_index = max(all_counts.nonzero())
        all_counts = all_counts[: max_index + 1]
        all_counts = all_counts / all_counts.sum()
        return all_counts

    def node_types(self):
        """Empirical distribution over node (atom) types.

        NOTE(review): summing data.x over dim 0 only yields type counts if
        data.x is one-hot encoded — confirm against the dataset classes.
        """
        num_classes = None
        # Peek at one batch to determine the feature width.
        for batch in self.dataloaders["train"]:
            for data in batch:
                num_classes = data.x.shape[1]
                break
            break

        counts = torch.zeros(num_classes)

        for split in ["train", "val", "test"]:
            for i, batch in enumerate(self.dataloaders[split]):
                for data in batch:
                    if data is None:
                        continue
                    counts += data.x.sum(dim=0)

        counts = counts / counts.sum()
        return counts

    def edge_counts(self):
        """Empirical distribution over edge types; class 0 is "no edge"."""
        num_classes = 5

        d = torch.zeros(num_classes)

        for split in ["train", "val", "test"]:
            for i, batch in enumerate(self.dataloaders[split]):
                for data in batch:
                    if data is None:
                        continue
                    unique, counts = torch.unique(data.batch, return_counts=True)

                    # Number of ordered node pairs per graph is n * (n - 1).
                    all_pairs = 0
                    for count in counts:
                        all_pairs += count * (count - 1)

                    num_edges = data.edge_index.shape[1]
                    num_non_edges = all_pairs - num_edges
                    # NOTE(review): assumes data.edge_attr is one-hot — confirm.
                    edge_types = data.edge_attr.sum(dim=0)
                    assert num_non_edges >= 0
                    d[0] += num_non_edges
                    d[1:] += edge_types[1:]

        d = d / d.sum()
        return d

    def valency_count(self, max_n_nodes):
        """Empirical distribution of atom valencies (aromatic bonds count 1.5)."""
        valencies = torch.zeros(
            3 * max_n_nodes - 2
        )  # Max valency possible if everything is connected

        # Bond-order weight per edge class: [none, single, double, triple, aromatic].
        multiplier = torch.Tensor([0, 1, 2, 3, 1.5])

        for split in ["train", "val", "test"]:
            for i, batch in enumerate(self.dataloaders[split]):
                for data in batch:
                    if data is None:
                        continue

                    n = data.x.shape[0]

                    for atom in range(n):
                        # All edges leaving this atom; sum their one-hot rows
                        # and weight by bond order to get the valency.
                        edges = data.edge_attr[data.edge_index[0] == atom]
                        edges_total = edges.sum(dim=0)
                        valency = (edges_total * multiplier).sum()
                        valencies[valency.long().item()] += 1
        valencies = valencies / valencies.sum()
        return valencies
112 |
113 |
class AbstractDataModule(Mixin, LightningDataset):
    """Plain (fixed batch size) datamodule: statistics Mixin on top of
    PyG's LightningDataset."""

    def __init__(self, cfg, train_dataset, val_dataset, test_dataset):
        # pin_memory is optional in the dataset config; default to False.
        pin_memory = getattr(cfg.dataset, "pin_memory", False)
        super().__init__(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=True,
            pin_memory=pin_memory,
        )
        self.cfg = cfg
126 |
127 |
class AbstractDataModuleLigand(Mixin, LightningDataset):
    """Ligand/pocket datamodule: like AbstractDataModule but tracks batch
    assignment vectors for the "pos" and "pos_pocket" keys."""

    def __init__(self, cfg, train_dataset, val_dataset, test_dataset):
        pin_memory = getattr(cfg.dataset, "pin_memory", False)
        super().__init__(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            batch_size=cfg.batch_size,
            follow_batch=["pos", "pos_pocket"],
            num_workers=cfg.num_workers,
            shuffle=True,
            pin_memory=pin_memory,
        )
        self.cfg = cfg
141 |
142 |
class AbstractAdaptiveDataModule(Mixin, AdaptiveLightningDataset):
    """Datamodule with size-adaptive batching (AdaptiveLightningDataset);
    cfg.inference_batch_size is used as the reference batch size."""

    def __init__(self, cfg, train_dataset, val_dataset, test_dataset):
        pin_memory = getattr(cfg.dataset, "pin_memory", False)
        super().__init__(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            batch_size=cfg.batch_size,
            reference_batch_size=cfg.inference_batch_size,
            num_workers=cfg.num_workers,
            shuffle=True,
            pin_memory=pin_memory,
        )
        self.cfg = cfg
156 |
157 |
class AbstractAdaptiveDataModuleLigand(Mixin, AdaptiveLightningDataset):
    """Size-adaptive ligand/pocket datamodule: combines adaptive batching
    with batch assignment vectors for "pos" and "pos_pocket"."""

    def __init__(self, cfg, train_dataset, val_dataset, test_dataset):
        pin_memory = getattr(cfg.dataset, "pin_memory", False)
        super().__init__(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            batch_size=cfg.batch_size,
            follow_batch=["pos", "pos_pocket"],
            reference_batch_size=cfg.inference_batch_size,
            num_workers=cfg.num_workers,
            shuffle=True,
            pin_memory=pin_memory,
        )
        self.cfg = cfg
172 |
173 |
class AbstractDatasetInfos:
    """Aggregates precomputed dataset statistics (node counts, atom/bond/
    charge marginals, valencies) into the attributes downstream models use."""

    def complete_infos(self, statistics, atom_encoder):
        self.atom_decoder = list(atom_encoder.keys())
        self.num_atom_types = len(self.atom_decoder)

        # Histogram of molecule sizes pooled over train + val + test.
        split_counters = [statistics[s].num_nodes for s in ("train", "val", "test")]
        max_n_nodes = max(max(counter.keys()) for counter in split_counters)
        n_nodes = torch.zeros(max_n_nodes + 1, dtype=torch.long)
        for counter in split_counters:
            for size, count in counter.items():
                n_nodes[size] += count

        self.n_nodes = n_nodes / n_nodes.sum()

        train_stats = statistics["train"]
        self.atom_types = train_stats.atom_types
        self.edge_types = train_stats.bond_types
        self.charges_types = train_stats.charge_types
        # Marginal charge distribution, weighted by atom-type frequency.
        self.charges_marginals = (self.charges_types * self.atom_types[:, None]).sum(
            dim=0
        )
        self.valency_distribution = train_stats.valencies
        self.max_n_nodes = len(n_nodes) - 1
        self.nodes_dist = DistributionNodes(n_nodes)

        # Optional extra-feature statistics, present only for some datasets.
        for attr in ("is_aromatic", "is_in_ring", "hybridization", "numHs"):
            if hasattr(train_stats, attr):
                setattr(self, attr, getattr(train_stats, attr))
210 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/adaptive_loader.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections.abc import Mapping, Sequence
3 | from typing import List, Optional, Union
4 |
5 | import torch
6 | import torch.utils.data
7 | from torch.utils.data.dataloader import default_collate
8 | from torch_geometric.data import Batch, Dataset
9 | from torch_geometric.data.data import BaseData
10 |
11 | try:
12 | from torch_geometric.data import LightningDataset
13 | except ImportError:
14 | from torch_geometric.data.lightning import LightningDataset
15 |
16 |
def effective_batch_size(
    max_size, reference_batch_size, reference_size=20, sampling=False
):
    """Scale the batch size inversely with the squared graph size.

    Graphs of ``reference_size`` nodes get ``reference_batch_size`` samples;
    larger graphs get quadratically fewer. During sampling the budget is
    enlarged by a factor of 1.8.
    """
    budget = reference_batch_size * (reference_size / max_size) ** 2
    if sampling:
        budget *= 1.8
    return math.floor(budget)
22 |
23 |
class AdaptiveCollater:
    """Collater that sub-samples a size-homogeneous mini-batch from an
    oversized candidate batch (copypaste from pyg.loader.Collater + small
    changes).

    A random pivot graph picks a size threshold; graphs near/below that size
    are kept so the kept count roughly matches the effective batch size for
    that graph size (large graphs -> small batches, and vice versa).

    Fix vs. original: the generic type dispatch (Tensor/float/int/Mapping/...)
    was unreachable dead code behind an ``elif True`` early exit; it has been
    removed. Behavior is unchanged: only BaseData batches are supported and
    every other element type raises NotImplementedError.
    """

    def __init__(self, follow_batch, exclude_keys, reference_batch_size):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.reference_bs = reference_batch_size

    def __call__(self, batch):
        # Checks the number of nodes for BaseData graphs and slots them into
        # an appropriately sized bucket; errors on any other element type.
        elem = batch[0]
        if not isinstance(elem, BaseData):
            raise NotImplementedError("Only supporting BaseData for now")

        to_keep = []
        graph_sizes = []

        for e in batch:
            e: BaseData
            graph_sizes.append(e.num_nodes)

        m = len(graph_sizes)
        graph_sizes = torch.Tensor(graph_sizes)
        srted, argsort = torch.sort(graph_sizes)

        # Random pivot; cap the batch's max size slightly above the pivot so
        # batches are roughly size-homogeneous.
        random = torch.randint(0, m, size=(1, 1)).item()
        max_size = min(srted.max().item(), srted[random].item() + 5)
        max_size = max(
            max_size, 9
        )  # The batch sizes may be huge if the graphs happen to be tiny

        ebs = effective_batch_size(max_size, self.reference_bs)

        max_index = torch.nonzero(srted <= max_size).max().item()
        min_index = max(0, max_index - ebs)
        indices_to_keep = set(argsort[min_index : max_index + 1].tolist())
        if max_index < ebs:
            # Top up with larger graphs while the effective batch size for
            # their own size still leaves room.
            for index in range(max_index + 1, m):
                size = srted[index].item()
                potential_ebs = effective_batch_size(size, self.reference_bs)
                if len(indices_to_keep) < potential_ebs:
                    indices_to_keep.add(argsort[index].item())

        for i, e in enumerate(batch):
            e: BaseData
            if i in indices_to_keep:
                to_keep.append(e)

        new_batch = Batch.from_data_list(
            to_keep, self.follow_batch, self.exclude_keys
        )
        return new_batch

    def collate(self, batch):  # Deprecated...
        return self(batch)
97 |
98 |
class AdaptiveDataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges data objects from a
    :class:`torch_geometric.data.Dataset` into adaptively sized mini-batches.

    Batching is delegated to :class:`AdaptiveCollater`: each mini-batch keeps
    graphs of similar node count, and the number of graphs per batch shrinks
    quadratically with graph size (see :func:`effective_batch_size`). Apart
    from this bucketing it is identical to
    :class:`torch_geometric.loader.DataLoader`.
    Data objects can be either of type :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        reference_batch_size (int, optional): Batch size granted to graphs of
            the reference size by :func:`effective_batch_size`.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`None`)
        exclude_keys (List[str], optional): Will exclude each key in the
            list. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """

    def __init__(
        self,
        dataset: Union[Dataset, List[BaseData]],
        batch_size: int = 1,
        reference_batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # The collate_fn is always AdaptiveCollater; silently drop any
        # user-supplied one.
        kwargs.pop("collate_fn", None)

        # Save for PyTorch Lightning:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=AdaptiveCollater(
                follow_batch, exclude_keys, reference_batch_size=reference_batch_size
            ),
            **kwargs,
        )
147 |
148 |
class AdaptiveLightningDataset(LightningDataset):
    r"""Converts a set of :class:`~torch_geometric.data.Dataset` objects into a
    :class:`pytorch_lightning.LightningDataModule` variant, which can be
    automatically used as a :obj:`datamodule` for multi-GPU graph-level
    training via PyTorch Lightning. Mini-batches are provided by
    :class:`AdaptiveDataLoader` (size-adaptive bucketing) instead of the
    plain :class:`~torch_geometric.loader.DataLoader`.

    .. note::

        Currently only the
        :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and
        :class:`pytorch_lightning.strategies.DDPSpawnStrategy` training
        strategies of PyTorch Lightning are supported in order to correctly
        share data across all devices/processes:

    .. code-block::

        import pytorch_lightning as pl
        trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                             devices=4)
        trainer.fit(model, datamodule)

    Args:
        train_dataset (Dataset): The training dataset.
        val_dataset (Dataset, optional): The validation dataset.
            (default: :obj:`None`)
        test_dataset (Dataset, optional): The test dataset.
            (default: :obj:`None`)
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        reference_batch_size (int, optional): Batch size granted to graphs of
            the reference size by :func:`effective_batch_size`.
            (default: :obj:`1`)
        num_workers: How many subprocesses to use for data loading.
            :obj:`0` means that the data will be loaded in the main process.
            (default: :obj:`0`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.loader.DataLoader`.
    """

    def __init__(
        self,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
        batch_size: int = 1,
        reference_batch_size: int = 1,
        num_workers: int = 0,
        **kwargs,
    ):
        # Stored before super().__init__ so that self.dataloader() (invoked by
        # the base class) can forward it to AdaptiveDataLoader.
        self.reference_batch_size = reference_batch_size
        super().__init__(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            # has_val=val_dataset is not None,
            # has_test=test_dataset is not None,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )

        # NOTE(review): the base class already stores these in super().__init__;
        # the re-assignment is kept as-is (harmless).
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

    def dataloader(
        self, dataset: Dataset, shuffle: bool = False, **kwargs
    ) -> AdaptiveDataLoader:
        # NOTE(review): the per-call **kwargs are ignored; the loader always
        # receives self.kwargs captured at construction — confirm intended.
        return AdaptiveDataLoader(
            dataset,
            reference_batch_size=self.reference_batch_size,
            shuffle=shuffle,
            **self.kwargs,
        )
224 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/calculate_dipoles.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Subset
7 | from torch_geometric.data.collate import collate
8 | from tqdm import tqdm
9 |
10 | from experiments.xtb_energy import calculate_dipole
11 |
12 |
def get_args():
    """Parse command-line arguments for the dipole-calculation script."""
    # fmt: off
    parser = argparse.ArgumentParser(description='Energy calculation')
    parser.add_argument('--dataset', type=str, help='Which dataset')
    parser.add_argument('--split', type=str, help='Which data split train/val/test')
    parser.add_argument('--idx', type=int, default=None, help='Which part of the dataset (pubchem only)')
    return parser.parse_args()
22 |
23 |
# Element symbol -> class index used by the datasets (16 heavy-atom vocabulary).
atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
# Inverse mapping: class index -> element symbol.
atom_decoder = {v: k for k, v in atom_encoder.items()}
# Per-element reference energies (in Hartree — the sibling energy script
# converts xtb output from eV to Hartree before subtracting these).
# NOTE(review): atom_reference is unused in this dipole script; it appears to
# be copied from calculate_energies.py.
atom_reference = {
    "H": -0.393482763936,
    "B": -0.952436614164,
    "C": -1.795110518041,
    "N": -2.60945245463,
    "O": -3.769421097051,
    "F": -4.619339964238,
    "Al": -0.905328611479,
    "Si": -1.571424085131,
    "P": -2.377807088084,
    "S": -3.148271017078,
    "Cl": -4.482525134961,
    "As": -2.239425948594,
    "Br": -4.048339371234,
    "I": -3.77963026339,
    "Hg": -0.848032246708,
    "Bi": -2.26665341636,
}
61 |
62 |
def process(dataset, split, idx):
    """Compute classic dipole moments for every molecule of a dataset split
    and save the augmented data in InMemoryDataset (data, slices) format.

    Args:
        dataset: One of "drugs", "qm9", "aqm", "pubchem".
        split: Data split name ("train"/"val"/"test").
        idx: Shard index — used for pubchem only, which is processed in 500
            shards; may be None for the other datasets.

    Raises:
        ValueError: If ``dataset`` is not recognized.
    """
    if dataset == "drugs":
        from experiments.data.geom.geom_dataset_adaptive import (
            GeomDrugsDataset as DataModule,
        )

        root_path = "----"
    elif dataset == "qm9":
        from experiments.data.qm9.qm9_dataset import QM9Dataset as DataModule

        root_path = "----"
    elif dataset == "aqm":
        from experiments.data.aqm.aqm_dataset_nonadaptive import (
            AQMDataset as DataModule,
        )

        root_path = "----"
    elif dataset == "pubchem":
        from experiments.data.pubchem.pubchem_dataset_nonadaptive import (
            PubChemLMDBDataset as DataModule,
        )

        root_path = "----"
    else:
        raise ValueError("Dataset not found")

    remove_hs = False

    datamodule = DataModule(split=split, root=root_path, remove_h=remove_hs)

    if dataset == "pubchem":
        # PubChem is too large for one pass; take shard `idx` out of 500.
        split_len = len(datamodule) // 500
        rng = np.arange(0, len(datamodule))
        rng = rng[idx * split_len : (idx + 1) * split_len]
        datamodule = Subset(datamodule, rng)

    mols = []
    for i, mol in tqdm(enumerate(datamodule), total=len(datamodule)):
        atom_types = [atom_decoder[int(a)] for a in mol.x]
        try:
            d = calculate_dipole(mol.pos, atom_types)
            mol.dipole_classic = torch.tensor(d, dtype=torch.float32).unsqueeze(0)
            mols.append(mol)
        # Best-effort: skip molecules xtb cannot handle. `except Exception`
        # (not bare `except:`) so Ctrl-C / SystemExit still abort the run.
        except Exception:
            print(f"Molecule with id {i} failed...")
            continue

    print("Collate the data...")
    data, slices = _collate(mols)

    print("Saving the data...")
    # NOTE(review): the filename embeds the raw `idx` (literal "None" for
    # non-pubchem datasets) and says "energy" although it stores dipoles;
    # kept unchanged for compatibility with downstream consumers.
    torch.save(
        (data, slices),
        (os.path.join(root_path, f"processed/{split}_{idx}_data_energy.pt")),
    )
124 |
125 |
126 | def _collate(data_list):
127 | r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects
128 | to the internal storage format of
129 | :class:`~torch_geometric.data.InMemoryDataset`."""
130 | if len(data_list) == 1:
131 | return data_list[0], None
132 |
133 | data, slices, _ = collate(
134 | data_list[0].__class__,
135 | data_list=data_list,
136 | increment=False,
137 | add_batch=False,
138 | )
139 |
140 | return data, slices
141 |
142 |
# CLI entry point, e.g.:
#   python calculate_dipoles.py --dataset qm9 --split train
if __name__ == "__main__":
    args = get_args()
    process(args.dataset, args.split, args.idx)
146 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/calculate_energies.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Subset
7 | from torch_geometric.data.collate import collate
8 | from tqdm import tqdm
9 |
10 | from experiments.xtb_energy import calculate_xtb_energy
11 |
12 |
def get_args():
    """Parse command-line arguments for the energy-calculation script."""
    # fmt: off
    parser = argparse.ArgumentParser(description='Energy calculation')
    parser.add_argument('--dataset', type=str, help='Which dataset')
    parser.add_argument('--split', type=str, help='Which data split train/val/test')
    parser.add_argument('--idx', type=int, default=None, help='Which part of the dataset (pubchem only)')
    return parser.parse_args()
22 |
23 |
# Element symbol -> class index used by the datasets (16 heavy-atom vocabulary).
atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
# Inverse mapping: class index -> element symbol.
atom_decoder = {v: k for k, v in atom_encoder.items()}
# Per-element reference energies in Hartree; subtracted from the xtb total
# energy (after eV -> Hartree conversion in process()) to get a formation-like
# energy.
atom_reference = {
    "H": -0.393482763936,
    "B": -0.952436614164,
    "C": -1.795110518041,
    "N": -2.60945245463,
    "O": -3.769421097051,
    "F": -4.619339964238,
    "Al": -0.905328611479,
    "Si": -1.571424085131,
    "P": -2.377807088084,
    "S": -3.148271017078,
    "Cl": -4.482525134961,
    "As": -2.239425948594,
    "Br": -4.048339371234,
    "I": -3.77963026339,
    "Hg": -0.848032246708,
    "Bi": -2.26665341636,
}
61 |
62 |
def process(dataset, split, idx):
    """Compute xTB energies (relative to per-atom references, in Hartree) for
    every molecule of a dataset split and save the augmented data in
    InMemoryDataset (data, slices) format.

    Args:
        dataset: One of "drugs", "qm9", "aqm", "pubchem".
        split: Data split name ("train"/"val"/"test").
        idx: Shard index — used for pubchem only, which is processed in 500
            shards; may be None for the other datasets.

    Raises:
        ValueError: If ``dataset`` is not recognized.
    """
    if dataset == "drugs":
        from experiments.data.geom.geom_dataset_adaptive import (
            GeomDrugsDataset as DataModule,
        )

        root_path = "----"
    elif dataset == "qm9":
        from experiments.data.qm9.qm9_dataset import QM9Dataset as DataModule

        root_path = "----"
    elif dataset == "aqm":
        from experiments.data.aqm.aqm_dataset_nonadaptive import (
            AQMDataset as DataModule,
        )

        root_path = "----"
    elif dataset == "pubchem":
        from experiments.data.pubchem.pubchem_dataset_nonadaptive import (
            PubChemLMDBDataset as DataModule,
        )

        root_path = "----"
    else:
        raise ValueError("Dataset not found")

    remove_hs = False

    datamodule = DataModule(split=split, root=root_path, remove_h=remove_hs)

    if dataset == "pubchem":
        # PubChem is too large for one pass; take shard `idx` out of 500.
        split_len = len(datamodule) // 500
        rng = np.arange(0, len(datamodule))
        rng = rng[idx * split_len : (idx + 1) * split_len]
        datamodule = Subset(datamodule, rng)

    mols = []
    for i, mol in tqdm(enumerate(datamodule), total=len(datamodule)):
        atom_types = [atom_decoder[int(a)] for a in mol.x]
        try:
            e_ref = np.sum(
                [atom_reference[a] for a in atom_types]
            )  # * 27.2114 #Hartree to eV
            e, _ = calculate_xtb_energy(mol.pos, atom_types)
            e *= 0.0367493  # eV to Hartree
            mol.energy = torch.tensor(e - e_ref, dtype=torch.float32).unsqueeze(0)
            mols.append(mol)
        # Best-effort: skip molecules xtb cannot handle. `except Exception`
        # (not bare `except:`) so Ctrl-C / SystemExit still abort the run.
        except Exception:
            print(f"Molecule with id {i} failed...")
            continue

    print("Collate the data...")
    data, slices = _collate(mols)

    print("Saving the data...")
    # NOTE(review): the filename embeds the raw `idx` (literal "None" for
    # non-pubchem datasets); kept unchanged for compatibility.
    torch.save(
        (data, slices),
        (os.path.join(root_path, f"processed/{split}_{idx}_data_energy.pt")),
    )
128 |
129 |
130 | def _collate(data_list):
131 | r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects
132 | to the internal storage format of
133 | :class:`~torch_geometric.data.InMemoryDataset`."""
134 | if len(data_list) == 1:
135 | return data_list[0], None
136 |
137 | data, slices, _ = collate(
138 | data_list[0].__class__,
139 | data_list=data_list,
140 | increment=False,
141 | add_batch=False,
142 | )
143 |
144 | return data, slices
145 |
146 |
# CLI entry point, e.g.:
#   python calculate_energies.py --dataset pubchem --split train --idx 0
if __name__ == "__main__":
    args = get_args()
    process(args.dataset, args.split, args.idx)
150 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/data_info.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from experiments.data.abstract_dataset import (
5 | AbstractDatasetInfos,
6 | )
7 | from experiments.molecule_utils import PlaceHolder
8 |
# Element symbol -> class index (16-element vocabulary).
# NOTE(review): identical to full_atom_encoder_drugs / full_atom_encoder_pubchem
# below — candidates for de-duplication.
full_atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
27 |
28 |
class GeneralInfos(AbstractDatasetInfos):
    """Dataset meta-information with configurable bond/charge class counts
    (the only Infos class here that reads them from cfg instead of
    hard-coding 5/6)."""

    def __init__(self, datamodule, cfg):
        self.remove_h = cfg.remove_hs
        self.need_to_strip = (
            False  # to indicate whether we need to ignore one output from the model
        )
        self.statistics = datamodule.statistics
        self.name = "drugs"
        self.atom_encoder = full_atom_encoder
        self.num_bond_classes = cfg.num_bond_classes
        self.num_charge_classes = cfg.num_charge_classes
        # Charges are shifted by +2 so {-2..3} map to one-hot classes 0..5.
        self.charge_offset = 2
        self.collapse_charges = torch.Tensor([-2, -1, 0, 1, 2, 3]).int()
        # NOTE(review): unlike the sibling Infos classes, remove_h does NOT
        # shrink the encoder here (the block below is commented out) — confirm
        # whether that is intentional.
        # if self.remove_h:
        #     self.atom_encoder = {
        #         k: v - 1 for k, v in self.atom_encoder.items() if k != "H"
        #     }

        # Must run before input/output dims: it sets self.num_atom_types.
        super().complete_infos(datamodule.statistics, self.atom_encoder)

        self.input_dims = PlaceHolder(
            X=self.num_atom_types,
            C=self.num_charge_classes,
            E=self.num_bond_classes,
            y=1,
            pos=3,
        )
        self.output_dims = PlaceHolder(
            X=self.num_atom_types,
            C=self.num_charge_classes,
            E=self.num_bond_classes,
            y=0,
            pos=3,
        )

    def to_one_hot(self, X, C, E, node_mask):
        """One-hot encode atom types, charges and bond types, then apply the
        node mask; returns the masked (X, C, E) tensors."""
        X = F.one_hot(X, num_classes=self.num_atom_types).float()
        E = F.one_hot(E, num_classes=self.num_bond_classes).float()
        C = F.one_hot(
            C + self.charge_offset, num_classes=self.num_charge_classes
        ).float()
        placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None)
        pl = placeholder.mask(node_mask)
        return pl.X, pl.C, pl.E

    def one_hot_charges(self, C):
        """One-hot encode charges shifted by charge_offset."""
        return F.one_hot(
            (C + self.charge_offset).long(), num_classes=self.num_charge_classes
        ).float()
78 |
79 |
# GEOM-drugs element vocabulary (identical to full_atom_encoder above).
full_atom_encoder_drugs = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
98 |
99 |
class GEOMInfos(AbstractDatasetInfos):
    """Dataset meta-information for GEOM-drugs: 16 atom types, 6 charge
    classes ({-2..3}, offset +2) and 5 bond classes (hard-coded)."""

    def __init__(self, datamodule, cfg):
        self.remove_h = cfg.remove_hs
        self.need_to_strip = (
            False  # to indicate whether we need to ignore one output from the model
        )
        self.statistics = datamodule.statistics
        self.name = "drugs"
        self.atom_encoder = full_atom_encoder_drugs
        # Charges are shifted by +2 so {-2..3} map to one-hot classes 0..5.
        self.charge_offset = 2
        self.collapse_charges = torch.Tensor([-2, -1, 0, 1, 2, 3]).int()
        if self.remove_h:
            # Drop hydrogen and shift remaining indices down by one.
            self.atom_encoder = {
                k: v - 1 for k, v in self.atom_encoder.items() if k != "H"
            }

        # Must run before input/output dims: it sets self.num_atom_types.
        super().complete_infos(datamodule.statistics, self.atom_encoder)

        self.input_dims = PlaceHolder(X=self.num_atom_types, C=6, E=5, y=1, pos=3)
        self.output_dims = PlaceHolder(X=self.num_atom_types, C=6, E=5, y=0, pos=3)

    def to_one_hot(self, X, C, E, node_mask):
        """One-hot encode atom types, charges and bond types, then apply the
        node mask; returns the masked (X, C, E) tensors."""
        X = F.one_hot(X, num_classes=self.num_atom_types).float()
        E = F.one_hot(E, num_classes=5).float()
        C = F.one_hot(C + self.charge_offset, num_classes=6).float()
        placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None)
        pl = placeholder.mask(node_mask)
        return pl.X, pl.C, pl.E

    def one_hot_charges(self, C):
        """One-hot encode charges shifted by charge_offset into 6 classes."""
        return F.one_hot((C + self.charge_offset).long(), num_classes=6).float()
131 |
132 |
# PubChem element vocabulary (identical to full_atom_encoder above).
full_atom_encoder_pubchem = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
151 |
152 |
class PubChemInfos(AbstractDatasetInfos):
    """Dataset meta-information for PubChem: 16 atom types, 6 charge classes
    ({-2..3}, offset +2) and 5 bond classes (hard-coded)."""

    def __init__(self, datamodule, cfg):
        self.remove_h = cfg.remove_hs
        self.need_to_strip = (
            False  # to indicate whether we need to ignore one output from the model
        )
        self.statistics = datamodule.statistics
        self.name = "pubchem"
        self.atom_encoder = full_atom_encoder_pubchem
        # Maps a compact PubChem-local atom index to the index in the full
        # 16-element encoder above — presumably only 11 elements occur in
        # PubChem; TODO confirm against the dataset preprocessing.
        self.atom_idx_mapping = {
            0: 0,
            1: 2,
            2: 3,
            3: 4,
            4: 5,
            5: 7,
            6: 8,
            7: 9,
            8: 10,
            9: 12,
            10: 13,
        }
        # Charges are shifted by +2 so {-2..3} map to one-hot classes 0..5.
        self.charge_offset = 2
        self.collapse_charges = torch.Tensor([-2, -1, 0, 1, 2, 3]).int()
        if self.remove_h:
            # Drop hydrogen and shift remaining indices down by one.
            self.atom_encoder = {
                k: v - 1 for k, v in self.atom_encoder.items() if k != "H"
            }

        # Must run before input/output dims (sets derived statistics).
        super().complete_infos(datamodule.statistics, self.atom_encoder)

        self.input_dims = PlaceHolder(X=len(self.atom_encoder), C=6, E=5, y=1, pos=3)
        self.output_dims = PlaceHolder(X=len(self.atom_encoder), C=6, E=5, y=0, pos=3)

    def to_one_hot(self, X, C, E, node_mask):
        """One-hot encode atom types, charges and bond types, then apply the
        node mask; returns the masked (X, C, E) tensors."""
        X = F.one_hot(X, num_classes=len(self.atom_encoder)).float()
        E = F.one_hot(E, num_classes=5).float()
        C = F.one_hot(C + self.charge_offset, num_classes=6).float()
        placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None)
        pl = placeholder.mask(node_mask)
        return pl.X, pl.C, pl.E

    def one_hot_charges(self, C):
        """One-hot encode charges shifted by charge_offset into 6 classes."""
        return F.one_hot((C + self.charge_offset).long(), num_classes=6).float()
197 |
198 |
# QM9 element vocabulary: only H, C, N, O, F occur.
full_atom_encoder_qm9 = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4}
200 |
201 |
class QM9Infos(AbstractDatasetInfos):
    """Dataset meta-information for QM9: 5 atom types, 3 charge classes
    ({-1, 0, 1}, offset +1) and 5 bond classes."""

    def __init__(self, datamodule, cfg):
        self.remove_h = cfg.remove_hs
        self.statistics = datamodule.statistics
        self.name = "qm9"
        self.atom_encoder = full_atom_encoder_qm9
        # Charges are shifted by +1 so {-1, 0, 1} map to one-hot classes 0..2.
        self.charge_offset = 1
        self.collapse_charges = torch.Tensor([-1, 0, 1]).int()
        if self.remove_h:
            # Drop hydrogen and shift the remaining indices down by one.
            self.atom_encoder = {
                key: idx - 1 for key, idx in self.atom_encoder.items() if key != "H"
            }
        # Must run before input/output dims: it sets self.num_atom_types.
        super().complete_infos(datamodule.statistics, self.atom_encoder)
        self.input_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=1, pos=3)
        self.output_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=0, pos=3)

    def to_one_hot(self, X, C, E, node_mask):
        """One-hot encode atom types, charges and bond types, then apply the
        node mask; returns the masked (X, C, E) tensors."""
        X = F.one_hot(X, num_classes=self.num_atom_types).float()
        E = F.one_hot(E, num_classes=5).float()
        C = F.one_hot(C + self.charge_offset, num_classes=3).float()
        pl = PlaceHolder(X=X, C=C, E=E, y=None, pos=None).mask(node_mask)
        return pl.X, pl.C, pl.E

    def one_hot_charges(self, charges):
        """One-hot encode charges shifted by charge_offset into 3 classes."""
        return F.one_hot((charges + self.charge_offset).long(), num_classes=3).float()
228 |
229 |
# AQM molecular property names (regression targets).
mol_properties = [
    "DIP",
    "HLgap",
    "eAT",
    "eC",
    "eEE",
    "eH",
    "eKIN",
    "eKSE",
    "eL",
    "eNE",
    "eNN",
    "eMBD",
    "eTS",
    "eX",
    "eXC",
    "eXX",
    "mPOL",
]

# Per-element reference energies keyed by atomic number.
# NOTE(review): units are not stated here — presumably eV given the
# magnitudes; confirm against the AQM preprocessing.
atomic_energies_dict = {
    1: -13.643321054,
    6: -1027.610746263,
    7: -1484.276217092,
    8: -2039.751675679,
    9: -3139.751675679,
    15: -9283.015861995,
    16: -10828.726222083,
    17: -12516.462339357,
}
# Atomic numbers present in AQM (H, C, N, O, F, P, S, Cl).
atomic_numbers = [1, 6, 7, 8, 9, 15, 16, 17]
# AQM element vocabulary: symbol -> class index.
full_atom_encoder_aqm = {
    "H": 0,
    "C": 1,
    "N": 2,
    "O": 3,
    "F": 4,
    "P": 5,
    "S": 6,
    "Cl": 7,
}
271 |
272 |
class AQMInfos(AbstractDatasetInfos):
    """Dataset meta-information for AQM: 8 atom types, 3 charge classes
    ({-1, 0, 1}, offset +1) and 5 bond classes."""

    def __init__(self, datamodule, cfg):
        self.remove_h = cfg.remove_hs
        self.statistics = datamodule.statistics
        self.name = "aqm"
        self.atom_encoder = full_atom_encoder_aqm
        # Charges are shifted by +1 so {-1, 0, 1} map to one-hot classes 0..2.
        self.charge_offset = 1
        self.collapse_charges = torch.Tensor([-1, 0, 1]).int()
        if self.remove_h:
            # Drop hydrogen and shift the remaining indices down by one.
            self.atom_encoder = {
                k: v - 1 for k, v in self.atom_encoder.items() if k != "H"
            }

        # Must run before input/output dims: it sets self.num_atom_types.
        super().complete_infos(datamodule.statistics, self.atom_encoder)

        self.input_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=1, pos=3)
        self.output_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=0, pos=3)

    def to_one_hot(self, X, C, E, node_mask):
        """One-hot encode atom types, charges and bond types, then apply the
        node mask; returns the masked (X, C, E) tensors."""
        X = F.one_hot(X, num_classes=self.num_atom_types).float()
        E = F.one_hot(E, num_classes=5).float()
        # Consistency fix: use self.charge_offset (== 1) instead of the
        # hard-coded literal, matching the sibling Infos classes.
        C = F.one_hot(C + self.charge_offset, num_classes=3).float()
        placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None)
        pl = placeholder.mask(node_mask)
        return pl.X, pl.C, pl.E

    def one_hot_charges(self, C):
        """One-hot encode charges shifted by charge_offset into 3 classes."""
        return F.one_hot((C + self.charge_offset).long(), num_classes=3).float()
301 |
302 |
class AQMQM7XInfos(AbstractDatasetInfos):
    """Dataset meta-information for the combined AQM + QM7-X data; identical
    encoding setup to AQMInfos apart from the name."""

    def __init__(self, datamodule, cfg):
        self.remove_h = cfg.remove_hs
        self.statistics = datamodule.statistics
        self.name = "aqm_qm7x"
        self.atom_encoder = full_atom_encoder_aqm
        # Charges are shifted by +1 so {-1, 0, 1} map to one-hot classes 0..2.
        self.charge_offset = 1
        self.collapse_charges = torch.Tensor([-1, 0, 1]).int()
        if self.remove_h:
            # Drop hydrogen and shift the remaining indices down by one.
            self.atom_encoder = {
                k: v - 1 for k, v in self.atom_encoder.items() if k != "H"
            }

        # Must run before input/output dims: it sets self.num_atom_types.
        super().complete_infos(datamodule.statistics, self.atom_encoder)

        self.input_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=1, pos=3)
        self.output_dims = PlaceHolder(X=self.num_atom_types, C=3, E=5, y=0, pos=3)

    def to_one_hot(self, X, C, E, node_mask):
        """One-hot encode atom types, charges and bond types, then apply the
        node mask; returns the masked (X, C, E) tensors."""
        X = F.one_hot(X, num_classes=self.num_atom_types).float()
        E = F.one_hot(E, num_classes=5).float()
        # Consistency fix: use self.charge_offset (== 1) instead of the
        # hard-coded literal, matching the sibling Infos classes.
        C = F.one_hot(C + self.charge_offset, num_classes=3).float()
        placeholder = PlaceHolder(X=X, C=C, E=E, y=None, pos=None)
        pl = placeholder.mask(node_mask)
        return pl.X, pl.C, pl.E

    def one_hot_charges(self, C):
        """One-hot encode charges shifted by charge_offset into 3 classes."""
        return F.one_hot((C + self.charge_offset).long(), num_classes=3).float()
331 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/geom/calculate_energies.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 |
5 | import numpy as np
6 | import torch
7 | from torch_geometric.data.collate import collate
8 | from tqdm import tqdm
9 |
10 | from experiments.xtb_energy import calculate_xtb_energy
11 |
12 |
def get_args():
    """Parse CLI arguments (--dataset, --split) for the GEOM energy script."""
    # fmt: off
    parser = argparse.ArgumentParser(description='Energy calculation')
    parser.add_argument('--dataset', type=str, help='Which dataset')
    parser.add_argument('--split', type=str, help='Which data split train/val/test')
    return parser.parse_args()
20 |
21 |
# Element symbol -> class index used by the datasets (16 heavy-atom vocabulary).
atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
# Inverse mapping: class index -> element symbol.
atom_decoder = {v: k for k, v in atom_encoder.items()}
# Per-element reference energies in Hartree; subtracted from the xtb total
# energy (after eV -> Hartree conversion in process()).
atom_reference = {
    "H": -0.393482763936,
    "B": -0.952436614164,
    "C": -1.795110518041,
    "N": -2.60945245463,
    "O": -3.769421097051,
    "F": -4.619339964238,
    "Al": -0.905328611479,
    "Si": -1.571424085131,
    "P": -2.377807088084,
    "S": -3.148271017078,
    "Cl": -4.482525134961,
    "As": -2.239425948594,
    "Br": -4.048339371234,
    "I": -3.77963026339,
    "Hg": -0.848032246708,
    "Bi": -2.26665341636,
}
59 |
60 |
def process(dataset, split):
    """Compute xTB energies (relative to per-atom references, in Hartree) for
    a GEOM-drugs or QM9 split; save the augmented data and pickle the indices
    of molecules that failed.

    Args:
        dataset: "drugs" or "qm9".
        split: Data split name ("train"/"val"/"test").

    Raises:
        ValueError: If ``dataset`` is not recognized.
    """
    if dataset == "drugs":
        from experiments.data.geom.geom_dataset_adaptive import (
            GeomDrugsDataset as DataModule,
        )

        root_path = "/scratch1/cremej01/data/geom"
    elif dataset == "qm9":
        # Bug fix: qm9_dataset exposes QM9Dataset, not GeomDrugsDataset
        # (the original import raised ImportError for the qm9 branch).
        from experiments.data.qm9.qm9_dataset import QM9Dataset as DataModule

        root_path = "/scratch1/cremej01/data/qm9"
    else:
        raise ValueError("Dataset not found")

    remove_hs = False

    dataset = DataModule(split=split, root=root_path, remove_h=remove_hs)

    failed_ids = []
    mols = []
    for i, mol in tqdm(enumerate(dataset)):
        atom_types = [atom_decoder[int(a)] for a in mol.x]
        try:
            e_ref = np.sum(
                [atom_reference[a] for a in atom_types]
            )  # * 27.2114 #Hartree to eV
            e, _ = calculate_xtb_energy(mol.pos, atom_types)
            e *= 0.0367493  # eV to Hartree
            mol.energy = torch.tensor(e - e_ref, dtype=torch.float32).unsqueeze(0)
        # Best-effort: record and skip molecules xtb cannot handle.
        # `except Exception` (not bare `except:`) so Ctrl-C still aborts.
        except Exception:
            print(f"Molecule with id {i} failed...")
            failed_ids.append(i)
            continue
        mols.append(mol)

    print("Collate the data...")
    data, slices = _collate(mols)

    print("Saving the data...")
    torch.save(
        (data, slices), (os.path.join(root_path, f"processed/{split}_data_energy.pt"))
    )

    # Persist failed indices so downstream processing can exclude them.
    with open(os.path.join(root_path, f"failed_ids_{split}.pickle"), "wb") as f:
        pickle.dump(failed_ids, f)
106 |
107 |
def _collate(data_list):
    r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects
    to the internal storage format of
    :class:`~torch_geometric.data.InMemoryDataset`."""
    # A single element needs no slicing metadata.
    if len(data_list) == 1:
        return data_list[0], None

    batched, slices, _ = collate(
        data_list[0].__class__,
        data_list=data_list,
        increment=False,
        add_batch=False,
    )
    return batched, slices
123 |
124 |
if __name__ == "__main__":
    # CLI entry point: compute xTB energies for the requested dataset/split.
    args = get_args()
    process(args.dataset, args.split)
128 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/geom/geom_dataset_energy.py:
--------------------------------------------------------------------------------
1 | from rdkit import RDLogger
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torch
5 | from os.path import join
6 | from experiments.data.utils import train_subset
7 | from torch_geometric.data import InMemoryDataset, DataLoader
8 | import experiments.data.utils as dataset_utils
9 | from experiments.data.utils import load_pickle, save_pickle
10 | from experiments.data.abstract_dataset import (
11 | AbstractAdaptiveDataModule,
12 | )
13 | from experiments.data.metrics import compute_all_statistics
14 | from torch.utils.data import Subset
15 |
16 |
# Element symbol -> atom-type index for the full (hydrogen-including)
# GEOM vocabulary; indices are shifted down by one when remove_h is set.
full_atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
35 |
36 |
class GeomDrugsDataset(InMemoryDataset):
    """GEOM-Drugs conformer dataset (energy variant).

    Loads pre-collated PyG tensors and per-split statistics (atom/bond/charge
    distributions, valencies, bond lengths/angles, aromaticity, ring
    membership, hybridization) from the processed directory.
    """

    def __init__(
        self, split, root, remove_h, transform=None, pre_transform=None, pre_filter=None
    ):
        assert split in ["train", "val", "test"]
        self.split = split
        self.remove_h = remove_h

        self.compute_bond_distance_angles = True

        self.atom_encoder = full_atom_encoder

        if remove_h:
            # Drop H and shift remaining atom-type indices down by one.
            self.atom_encoder = {
                k: v - 1 for k, v in self.atom_encoder.items() if k != "H"
            }

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
        # Index layout of processed_paths matches processed_file_names below.
        self.statistics = dataset_utils.Statistics(
            num_nodes=load_pickle(self.processed_paths[1]),
            atom_types=torch.from_numpy(np.load(self.processed_paths[2])),
            bond_types=torch.from_numpy(np.load(self.processed_paths[3])),
            charge_types=torch.from_numpy(np.load(self.processed_paths[4])),
            valencies=load_pickle(self.processed_paths[5]),
            bond_lengths=load_pickle(self.processed_paths[6]),
            bond_angles=torch.from_numpy(np.load(self.processed_paths[7])),
            is_aromatic=torch.from_numpy(np.load(self.processed_paths[9])).float(),
            is_in_ring=torch.from_numpy(np.load(self.processed_paths[10])).float(),
            hybridization=torch.from_numpy(np.load(self.processed_paths[11])).float(),
        )
        self.smiles = load_pickle(self.processed_paths[8])

    @property
    def raw_file_names(self):
        if self.split == "train":
            return ["train_data.pickle"]
        elif self.split == "val":
            return ["val_data.pickle"]
        else:
            return ["test_data.pickle"]

    # BUG FIX: the @property decorator was missing. torch_geometric's Dataset
    # declares processed_file_names as a property and joins its *value* with
    # the processed dir; without the decorator the bound method object itself
    # was used as a file name, breaking processed_paths.
    @property
    def processed_file_names(self):
        h = "noh" if self.remove_h else "h"
        if self.split == "train":
            return [
                f"train_{h}_energy.pt",
                f"train_n_{h}.pickle",
                f"train_atom_types_{h}.npy",
                f"train_bond_types_{h}.npy",
                f"train_charges_{h}.npy",
                f"train_valency_{h}.pickle",
                f"train_bond_lengths_{h}.pickle",
                f"train_angles_{h}.npy",
                "train_smiles.pickle",
                f"train_is_aromatic_{h}.npy",
                f"train_is_in_ring_{h}.npy",
                f"train_hybridization_{h}.npy",
            ]
        elif self.split == "val":
            return [
                f"val_{h}_energy.pt",
                f"val_n_{h}.pickle",
                f"val_atom_types_{h}.npy",
                f"val_bond_types_{h}.npy",
                f"val_charges_{h}.npy",
                f"val_valency_{h}.pickle",
                f"val_bond_lengths_{h}.pickle",
                f"val_angles_{h}.npy",
                "val_smiles.pickle",
                f"val_is_aromatic_{h}.npy",
                f"val_is_in_ring_{h}.npy",
                f"val_hybridization_{h}.npy",
            ]
        else:
            return [
                f"test_{h}_energy.pt",
                f"test_n_{h}.pickle",
                f"test_atom_types_{h}.npy",
                f"test_bond_types_{h}.npy",
                f"test_charges_{h}.npy",
                f"test_valency_{h}.pickle",
                f"test_bond_lengths_{h}.pickle",
                f"test_angles_{h}.npy",
                "test_smiles.pickle",
                f"test_is_aromatic_{h}.npy",
                f"test_is_in_ring_{h}.npy",
                f"test_hybridization_{h}.npy",
            ]

    def download(self):
        raise ValueError(
            "Download and preprocessing is manual. If the data is already downloaded, "
            f"check that the paths are correct. Root dir = {self.root} -- raw files {self.raw_paths}"
        )

    def process(self):
        """Build PyG Data objects (max 5 conformers per molecule) from the raw
        pickle, save the collated tensors, then compute and save statistics."""
        RDLogger.DisableLog("rdApp.*")
        all_data = load_pickle(self.raw_paths[0])

        data_list = []
        all_smiles = []
        for i, data in enumerate(tqdm(all_data)):
            smiles, all_conformers = data
            all_smiles.append(smiles)
            for j, conformer in enumerate(all_conformers):
                # Cap at 5 conformers per molecule.
                if j >= 5:
                    break
                data = dataset_utils.mol_to_torch_geometric(
                    conformer,
                    full_atom_encoder,
                    smiles,
                    remove_hydrogens=self.remove_h,  # need to give full atom encoder since hydrogens might still be available if Chem.RemoveHs is called
                )
                # even when calling Chem.RemoveHs, hydrogens might be present
                if self.remove_h:
                    data = dataset_utils.remove_hydrogens(
                        data
                    )  # remove through masking

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])

        statistics = compute_all_statistics(
            data_list,
            self.atom_encoder,
            charges_dic={-2: 0, -1: 1, 0: 2, 1: 3, 2: 4, 3: 5},
            additional_feats=True,
            # do not compute bond distance and bond angle statistics due to time and we do not use it anyways currently
        )
        save_pickle(statistics.num_nodes, self.processed_paths[1])
        np.save(self.processed_paths[2], statistics.atom_types)
        np.save(self.processed_paths[3], statistics.bond_types)
        np.save(self.processed_paths[4], statistics.charge_types)
        save_pickle(statistics.valencies, self.processed_paths[5])
        save_pickle(statistics.bond_lengths, self.processed_paths[6])
        np.save(self.processed_paths[7], statistics.bond_angles)
        save_pickle(set(all_smiles), self.processed_paths[8])

        np.save(self.processed_paths[9], statistics.is_aromatic)
        np.save(self.processed_paths[10], statistics.is_in_ring)
        np.save(self.processed_paths[11], statistics.hybridization)
185 |
186 |
class GeomDataModule(AbstractAdaptiveDataModule):
    """Adaptive data module for the GEOM energy datasets.

    Builds train/val/test GeomDrugsDataset instances, exposes their
    statistics, and provides dataloaders plus property-normalization helpers.
    """

    def __init__(self, cfg):
        self.datadir = cfg.dataset_root
        root_path = cfg.dataset_root
        self.pin_memory = True

        train_dataset = GeomDrugsDataset(
            split="train", root=root_path, remove_h=cfg.remove_hs
        )
        val_dataset = GeomDrugsDataset(
            split="val", root=root_path, remove_h=cfg.remove_hs
        )
        test_dataset = GeomDrugsDataset(
            split="test", root=root_path, remove_h=cfg.remove_hs
        )

        # Capture statistics before the train set is possibly wrapped in a
        # Subset (torch.utils.data.Subset does not expose `.statistics`).
        self.statistics = {
            "train": train_dataset.statistics,
            "val": val_dataset.statistics,
            "test": test_dataset.statistics,
        }

        if cfg.select_train_subset:
            self.idx_train = train_subset(
                dset_len=len(train_dataset),
                train_size=cfg.train_size,
                seed=cfg.seed,
                filename=join(cfg.save_dir, "splits.npz"),
            )
            self.train_smiles = train_dataset.smiles
            train_dataset = Subset(train_dataset, self.idx_train)

        self.remove_h = cfg.remove_hs

        super().__init__(cfg, train_dataset, val_dataset, test_dataset)

    def _train_dataloader(self, shuffle=True):
        """Training dataloader (shuffled by default)."""
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )

    def _val_dataloader(self, shuffle=False):
        """Validation dataloader (unshuffled by default)."""
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )

    def _test_dataloader(self, shuffle=False):
        """Test dataloader (unshuffled by default)."""
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )

    def compute_mean_mad(self, properties_list):
        """Mean and mean-absolute-deviation for each requested property,
        computed over the dataset split this config implies."""
        if self.cfg.dataset == "qm9" or self.cfg.dataset == "drugs":
            dataloader = self.get_dataloader(self.train_dataset, "val")
            return self.compute_mean_mad_from_dataloader(dataloader, properties_list)
        elif self.cfg.dataset == "qm9_1half" or self.cfg.dataset == "qm9_2half":
            dataloader = self.get_dataloader(self.val_dataset, "val")
            return self.compute_mean_mad_from_dataloader(dataloader, properties_list)
        else:
            raise Exception("Wrong dataset name")

    def compute_mean_mad_from_dataloader(self, dataloader, properties_list):
        """Return {property: {"mean": ..., "mad": ...}} over the dataloader's
        dataset, using a precomputed "<prop>_mm" attribute when available."""
        property_norms = {}
        for property_key in properties_list:
            try:
                property_name = property_key + "_mm"
                values = getattr(dataloader.dataset[:], property_name)
            except AttributeError:
                # BUG FIX: was a bare `except:`. Only the missing "<prop>_mm"
                # attribute should trigger the fallback of reading from y.
                property_name = property_key
                idx = dataloader.dataset[:].label2idx[property_name]
                values = torch.tensor(
                    [data.y[:, idx] for data in dataloader.dataset[:]]
                )

            mean = torch.mean(values)
            ma = torch.abs(values - mean)
            mad = torch.mean(ma)
            property_norms[property_key] = {}
            property_norms[property_key]["mean"] = mean
            property_norms[property_key]["mad"] = mad
            del values
        return property_norms

    def get_dataloader(self, dataset, stage):
        """Stage-specific dataloader ("train" shuffles, "val"/"test" don't)."""
        if stage == "train":
            batch_size = self.cfg.batch_size
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.cfg.inference_batch_size
            shuffle = False
        else:
            # BUG FIX: unknown stages previously crashed later with
            # UnboundLocalError; fail fast with a clear message instead.
            raise ValueError(f"Unknown stage: {stage!r}")

        dl = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

        return dl
305 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/geom/geom_dataset_nonadaptive.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | from typing import Optional
3 |
4 | from torch.utils.data import Subset
5 | from torch_geometric.data import DataLoader
6 |
7 | from experiments.data.abstract_dataset import (
8 | AbstractDataModule,
9 | )
10 | from experiments.data.geom.geom_dataset_adaptive import GeomDrugsDataset
11 | from experiments.data.utils import train_subset
12 |
# Element symbol -> atom-type index for the full (hydrogen-including)
# GEOM vocabulary used by the non-adaptive data module.
full_atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
31 |
32 |
class GeomDataModule(AbstractDataModule):
    """Non-adaptive GEOM data module with plain fixed-batch-size loaders."""

    def __init__(self, cfg, only_stats: bool = False):
        self.datadir = cfg.dataset_root
        root_path = cfg.dataset_root
        self.cfg = cfg
        self.pin_memory = True
        self.persistent_workers = False

        train_dataset = GeomDrugsDataset(
            split="train", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
        )
        val_dataset = GeomDrugsDataset(
            split="val", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
        )
        test_dataset = GeomDrugsDataset(
            split="test", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
        )

        # BUG FIX: statistics must be captured before the train set may be
        # wrapped in a Subset -- torch.utils.data.Subset has no `.statistics`
        # attribute, so reading it after wrapping raised AttributeError.
        self.statistics = {
            "train": train_dataset.statistics,
            "val": val_dataset.statistics,
            "test": test_dataset.statistics,
        }

        if not only_stats and cfg.select_train_subset:
            self.idx_train = train_subset(
                dset_len=len(train_dataset),
                train_size=cfg.train_size,
                seed=cfg.seed,
                filename=join(cfg.save_dir, "splits.npz"),
            )
            train_dataset = Subset(train_dataset, self.idx_train)

        self.remove_h = cfg.remove_hs
        super().__init__(cfg, train_dataset, val_dataset, test_dataset)

    def setup(self, stage: Optional[str] = None) -> None:
        """Re-instantiate the datasets for the Lightning `fit` stage."""
        train_dataset = GeomDrugsDataset(
            root=self.cfg.dataset_root, split="train", remove_h=self.cfg.remove_hs
        )
        val_dataset = GeomDrugsDataset(
            root=self.cfg.dataset_root, split="val", remove_h=self.cfg.remove_hs
        )
        test_dataset = GeomDrugsDataset(
            root=self.cfg.dataset_root, split="test", remove_h=self.cfg.remove_hs
        )

        if stage == "fit" or stage is None:
            self.train_dataset = train_dataset
            self.val_dataset = val_dataset
            self.test_dataset = test_dataset

    def train_dataloader(self, shuffle=True):
        """Training dataloader (shuffled by default)."""
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=self.persistent_workers,
        )

    def val_dataloader(self, shuffle=False):
        """Validation dataloader (unshuffled by default)."""
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=self.persistent_workers,
        )

    def test_dataloader(self, shuffle=False):
        """Test dataloader (unshuffled by default)."""
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=self.persistent_workers,
        )

    def get_dataloader(self, dataset, stage):
        """Stage-specific dataloader ("train" shuffles, "val"/"test" don't)."""
        if stage == "train":
            batch_size = self.cfg.batch_size
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.cfg.inference_batch_size
            shuffle = False
        else:
            # BUG FIX: unknown stages previously crashed later with
            # UnboundLocalError; fail fast with a clear message instead.
            raise ValueError(f"Unknown stage: {stage!r}")

        dl = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

        return dl
135 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/ligand/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/data/ligand/__init__.py
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/ligand/geometry_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from constants import CA_C_DIST, N_CA_C_ANGLE, N_CA_DIST
3 |
4 |
def rotation_matrix(angle, axis):
    """
    Batch of rotation matrices about one coordinate axis.

    Args:
        angle: (n,)
        axis: 0=x, 1=y, 2=z
    Returns:
        (n, 3, 3)
    """
    c, s = np.cos(angle), np.sin(angle)
    R = np.eye(3)[None, :, :].repeat(len(angle), axis=0)

    if axis == 0:
        # Rotate in the (y, z) plane.
        R[:, 1, 1], R[:, 1, 2] = c, -s
        R[:, 2, 1], R[:, 2, 2] = s, c
    elif axis == 1:
        # Rotate in the (z, x) plane.
        R[:, 0, 0], R[:, 0, 2] = c, s
        R[:, 2, 0], R[:, 2, 2] = -s, c
    else:
        # Rotate in the (x, y) plane.
        R[:, 0, 0], R[:, 0, 1] = c, -s
        R[:, 1, 0], R[:, 1, 1] = s, c

    return R
28 |
29 |
def get_bb_transform(n_xyz, ca_xyz, c_xyz):
    """
    Compute translation and rotation of the canonical backbone frame (triangle N-Ca-C) from a position with
    Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame

    Args:
        n_xyz: (n, 3)
        ca_xyz: (n, 3)
        c_xyz: (n, 3)

    Returns:
        quaternion represented as array of shape (n, 4)
        translation vector which is an array of shape (n, 3)
    """

    # Ca defines the origin of each local frame.
    translation = ca_xyz
    n_xyz = n_xyz - translation
    c_xyz = c_xyz - translation

    # Find rotation matrix that aligns the coordinate systems
    # rotate around y-axis to move N into the xy-plane
    theta_y = np.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
    Ry = rotation_matrix(theta_y, 1)
    n_xyz = np.einsum("noi,ni->no", Ry.transpose(0, 2, 1), n_xyz)

    # rotate around z-axis to move N onto the x-axis
    theta_z = np.arctan2(n_xyz[:, 1], n_xyz[:, 0])
    Rz = rotation_matrix(theta_z, 2)
    # n_xyz itself is not needed after this point, so the rotation below is
    # intentionally left commented out.
    # n_xyz = np.einsum('noi,ni->no', Rz.transpose(0, 2, 1), n_xyz)

    # rotate around x-axis to move C into the xy-plane
    # (apply both inverse rotations to C in one einsum)
    c_xyz = np.einsum(
        "noj,nji,ni->no", Rz.transpose(0, 2, 1), Ry.transpose(0, 2, 1), c_xyz
    )
    theta_x = np.arctan2(c_xyz[:, 2], c_xyz[:, 1])
    Rx = rotation_matrix(theta_x, 0)

    # Final rotation matrix
    R = np.einsum("nok,nkj,nji->noi", Ry, Rz, Rx)

    # Convert to quaternion
    # q = w + i*u_x + j*u_y + k * u_z
    quaternion = rotation_matrix_to_quaternion(R)

    return quaternion, translation
75 |
76 |
def get_bb_coords_from_transform(ca_coords, quaternion):
    """
    Reconstruct backbone (N, CA, C) coordinates from CA positions and frame
    rotations; inverse of `get_bb_transform` up to the canonical geometry.

    Args:
        ca_coords: (n, 3)
        quaternion: (n, 4)
    Returns:
        backbone coords (n*3, 3), order is [N, CA, C]
        backbone atom types as a list of length n*3
    """
    R = quaternion_to_rotation_matrix(quaternion)
    # Canonical frame: N on the x-axis, CA at the origin, C in the xy-plane,
    # using the ideal bond lengths/angle from constants.
    bb_coords = np.tile(
        np.array(
            [
                [N_CA_DIST, 0, 0],
                [0, 0, 0],
                [CA_C_DIST * np.cos(N_CA_C_ANGLE), CA_C_DIST * np.sin(N_CA_C_ANGLE), 0],
            ]
        ),
        [len(ca_coords), 1],
    )
    # Rotate each canonical triplet into its frame, then translate to CA.
    bb_coords = np.einsum(
        "noi,ni->no", R.repeat(3, axis=0), bb_coords
    ) + ca_coords.repeat(3, axis=0)
    # Element symbols per backbone atom: CA is a carbon, hence ["N", "C", "C"].
    bb_atom_types = [t for _ in range(len(ca_coords)) for t in ["N", "C", "C"]]

    return bb_coords, bb_atom_types
103 |
104 |
def quaternion_to_rotation_matrix(q):
    """
    Convert (possibly unnormalized) quaternions to rotation matrices,
    x_rot = R x.

    Args:
        q: (n, 4) as (w, x, y, z)
    Returns:
        R: (n, 3, 3)
    """
    # Normalize so non-unit quaternions are handled.
    q = q / (q**2).sum(1, keepdims=True) ** 0.5

    # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    R = np.empty((q.shape[0], 3, 3), dtype=q.dtype)
    R[:, 0, 0] = 1 - 2 * y**2 - 2 * z**2
    R[:, 0, 1] = 2 * x * y - 2 * z * w
    R[:, 0, 2] = 2 * x * z + 2 * y * w
    R[:, 1, 0] = 2 * x * y + 2 * z * w
    R[:, 1, 1] = 1 - 2 * x**2 - 2 * z**2
    R[:, 1, 2] = 2 * y * z - 2 * x * w
    R[:, 2, 0] = 2 * x * z - 2 * y * w
    R[:, 2, 1] = 2 * y * z + 2 * x * w
    R[:, 2, 2] = 1 - 2 * x**2 - 2 * y**2
    return R
150 |
151 |
def rotation_matrix_to_quaternion(R):
    """
    Convert rotation matrices to quaternions (w, x, y, z).
    https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
    Args:
        R: (n, 3, 3)
    Returns:
        q: (n, 4)
    """
    d0, d1, d2 = R[:, 0, 0], R[:, 1, 1], R[:, 2, 2]

    # Scalar part from the trace.
    w = 0.5 * np.sqrt(1 + d0 + d1 + d2)

    # Vector-part magnitudes come from the diagonal; signs come from the
    # antisymmetric off-diagonal differences.
    x = np.sign(R[:, 2, 1] - R[:, 1, 2]) * np.abs(0.5 * np.sqrt(1 + d0 - d1 - d2))
    y = np.sign(R[:, 0, 2] - R[:, 2, 0]) * np.abs(0.5 * np.sqrt(1 - d0 + d1 - d2))
    z = np.sign(R[:, 1, 0] - R[:, 0, 1]) * np.abs(0.5 * np.sqrt(1 - d0 - d1 + d2))

    return np.stack((w, x, y, z), axis=1)
175 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/ligand/molecule_builder.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import warnings
3 |
4 | import numpy as np
5 | import openbabel
6 | import torch
7 | from rdkit import Chem
8 | from rdkit.Chem.rdForceFieldHelpers import UFFHasAllMoleculeParams, UFFOptimizeMolecule
9 |
10 | from experiments.data.ligand import utils
11 | from experiments.data.ligand.constants import (
12 | bond_dict,
13 | bonds1,
14 | bonds2,
15 | bonds3,
16 | margin1,
17 | margin2,
18 | margin3,
19 | )
20 |
21 |
def get_bond_order(atom1, atom2, distance):
    """Classify the bond between two atoms from their distance.

    Returns 3/2/1 for triple/double/single bonds, 0 for no bond, checking
    the tightest (triple) threshold first.
    """
    distance = 100 * distance  # We change the metric

    # Check tiers from tightest to loosest threshold.
    for table, margin, order in (
        (bonds3, margin3, 3),
        (bonds2, margin2, 2),
        (bonds1, margin1, 1),
    ):
        if atom1 in table and atom2 in table[atom1]:
            if distance < table[atom1][atom2] + margin:
                return order

    return 0  # No bond
47 |
48 |
def get_bond_order_batch(atoms1, atoms2, distances, dataset_info):
    """Vectorized bond-order assignment for pairs of atom types.

    Assignments are applied single -> double -> triple, so a pair within a
    tighter threshold deliberately overwrites the looser classification.
    """
    if isinstance(atoms1, np.ndarray):
        atoms1 = torch.from_numpy(atoms1)
    if isinstance(atoms2, np.ndarray):
        atoms2 = torch.from_numpy(atoms2)
    if isinstance(distances, np.ndarray):
        distances = torch.from_numpy(distances)

    distances = 100 * distances  # We change the metric

    bond_types = torch.zeros_like(atoms1)  # 0: No bond

    # Order matters: later (higher) orders overwrite earlier ones.
    for key, margin, order in (
        ("bonds1", margin1, 1),
        ("bonds2", margin2, 2),
        ("bonds3", margin3, 3),
    ):
        limits = torch.tensor(dataset_info[key], device=atoms1.device)
        bond_types[distances < limits[atoms1, atoms2] + margin] = order

    return bond_types
75 |
76 |
def make_mol_openbabel(positions, atom_types, atom_decoder):
    """
    Build an RDKit molecule using openbabel for creating bonds
    Args:
        positions: N x 3
        atom_types: N
        atom_decoder: maps indices to atom types
            NOTE(review): atom_decoder is currently unused in this function.
    Returns:
        rdkit molecule (read with sanitize=False)
    """

    with tempfile.NamedTemporaryFile() as tmp:
        tmp_file = tmp.name

        # Write xyz file
        utils.write_xyz_file(positions, atom_types, tmp_file)

        # Convert to sdf file with openbabel
        # openbabel will add bonds
        obConversion = openbabel.OBConversion()
        obConversion.SetInAndOutFormats("xyz", "sdf")
        ob_mol = openbabel.OBMol()
        obConversion.ReadFile(ob_mol, tmp_file)

        # Overwrite the same temp file with the bond-annotated SDF.
        obConversion.WriteFile(ob_mol, tmp_file)

        # Read sdf file with RDKit
        mol = Chem.SDMolSupplier(tmp_file, sanitize=False)[0]

    return mol
107 |
108 |
def make_mol_edm(positions, atom_types, dataset_info, add_coords):
    """
    Equivalent to EDM's way of building RDKit molecules: bonds are inferred
    purely from pairwise distances via `get_bond_order_batch`.

    Args:
        positions: (N, 3) tensor of coordinates
        atom_types: (N,) tensor of atom-type indices
        dataset_info: dict providing "bonds1/2/3" tables and "atom_decoder"
        add_coords: if True, attach the positions as an RDKit conformer
    Returns:
        rdkit RWMol (not sanitized)
    """
    n = len(positions)

    # (X, A, E): atom_types, adjacency matrix, edge_types
    # X: N (int)
    # A: N x N (bool) -> (binary adjacency matrix)
    # E: N x N (int) -> (bond type, 0 if no bond)
    pos = positions.unsqueeze(0)  # add batch dim
    dists = torch.cdist(pos, pos, p=2).squeeze(0).view(-1)  # remove batch dim & flatten
    # All ordered pairs (i, j) of atom types, matching the flattened dists.
    atoms1, atoms2 = torch.cartesian_prod(atom_types, atom_types).T
    E_full = get_bond_order_batch(atoms1, atoms2, dists, dataset_info).view(n, n)
    # Keep the strictly-lower triangle so each bond is added only once.
    E = torch.tril(E_full, diagonal=-1)  # Warning: the graph should be DIRECTED
    A = E.bool()
    X = atom_types

    mol = Chem.RWMol()
    for atom in X:
        a = Chem.Atom(dataset_info["atom_decoder"][atom.item()])
        mol.AddAtom(a)

    all_bonds = torch.nonzero(A)
    for bond in all_bonds:
        mol.AddBond(
            bond[0].item(), bond[1].item(), bond_dict[E[bond[0], bond[1]].item()]
        )

    if add_coords:
        # Copy the input coordinates into a conformer, atom by atom.
        conf = Chem.Conformer(mol.GetNumAtoms())
        for i in range(mol.GetNumAtoms()):
            conf.SetAtomPosition(
                i,
                (
                    positions[i, 0].item(),
                    positions[i, 1].item(),
                    positions[i, 2].item(),
                ),
            )
        mol.AddConformer(conf)

    return mol
152 |
153 |
def build_molecule(
    positions, atom_types, dataset_info, add_coords=False, use_openbabel=True
):
    """
    Build RDKit molecule
    Args:
        positions: N x 3
        atom_types: N
        dataset_info: dict
        add_coords: Add conformer to mol (always added if use_openbabel=True)
        use_openbabel: use OpenBabel to create bonds
    Returns:
        RDKit molecule
    """
    if use_openbabel:
        return make_mol_openbabel(positions, atom_types, dataset_info["atom_decoder"])
    return make_mol_edm(positions, atom_types, dataset_info, add_coords)
174 |
175 |
def process_molecule(
    rdmol, add_hydrogens=False, sanitize=False, relax_iter=0, largest_frag=False
):
    """
    Apply filters to an RDKit molecule. Makes a copy first.
    Args:
        rdmol: rdkit molecule
        add_hydrogens: add explicit hydrogens (with coordinates if present)
        sanitize: run RDKit sanitization; failures return None
        relax_iter: maximum number of UFF optimization iterations
        largest_frag: filter out the largest fragment in a set of disjoint
            molecules
    Returns:
        RDKit molecule or None if it does not pass the filters
    """

    # Create a copy so the caller's molecule is never mutated.
    mol = Chem.Mol(rdmol)

    if sanitize:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            warnings.warn("Sanitization failed. Returning None.")
            return None

    if add_hydrogens:
        # Only place H coordinates if the molecule already has a conformer.
        mol = Chem.AddHs(mol, addCoords=(len(mol.GetConformers()) > 0))

    if largest_frag:
        mol_frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
        mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
        if sanitize:
            # sanitize the updated molecule
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                return None

    if relax_iter > 0:
        if not UFFHasAllMoleculeParams(mol):
            warnings.warn(
                "UFF parameters not available for all atoms. Returning None."
            )
            return None

        try:
            uff_relax(mol, relax_iter)
            if sanitize:
                # sanitize the updated molecule
                Chem.SanitizeMol(mol)
        except (RuntimeError, ValueError):
            # Removed the unused `as e` binding; relaxation or sanitization
            # failure simply rejects the molecule.
            return None

    return mol
231 |
232 |
def uff_relax(mol, max_iter=200):
    """
    Uses RDKit's universal force field (UFF) implementation to optimize a
    molecule in place. Returns a truthy value if the iteration budget was
    exhausted before convergence.
    """
    not_converged = UFFOptimizeMolecule(mol, maxIters=max_iter)
    if not_converged:
        warnings.warn(
            f"Maximum number of FF iterations reached. "
            f"Returning molecule after {max_iter} relaxation steps."
        )
    return not_converged
245 |
246 |
def filter_rd_mol(rdmol):
    """
    Filter out RDMols if they have a 3-3 ring intersection
    (two fused 3-membered rings sharing at least one atom).
    adapted from:
    https://github.com/luost26/3D-Generative-SBDD/blob/main/utils/chem.py
    Returns:
        True if the molecule passes the filter, False otherwise.
    """
    # Removed a dead `ring_info.AtomRings()` call whose result was discarded.
    rings = [set(r) for r in rdmol.GetRingInfo().AtomRings()]

    # 3-3 ring intersection: compare each 3-ring against all earlier rings.
    for i, ring_a in enumerate(rings):
        if len(ring_a) != 3:
            continue
        for ring_b in rings[:i]:
            if len(ring_b) == 3 and ring_a & ring_b:
                return False

    return True
269 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/ligand/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, Union
2 |
3 | import networkx as nx
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from Bio.PDB.Polypeptide import is_aa
8 | from networkx.algorithms import isomorphism
9 | from rdkit import Chem
10 |
11 |
class Queue:
    """Bounded buffer of recent scalar values, newest first.

    Once `max_len` items are stored, adding another drops the oldest.
    """

    def __init__(self, max_len=50):
        self.items = []
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def add(self, item):
        # Newest value goes to the front; trim the oldest off the back.
        self.items.insert(0, item)
        while len(self.items) > self.max_len:
            self.items.pop()

    def mean(self):
        """Arithmetic mean of the buffered values."""
        return np.mean(self.items)

    def std(self):
        """Population standard deviation of the buffered values."""
        return np.std(self.items)
30 |
31 |
def reverse_tensor(x):
    """Return a copy of `x` with its first dimension reversed."""
    return torch.flip(x, dims=(0,))
34 |
35 |
36 | #####
37 |
38 |
def get_grad_norm(
    parameters: Union[torch.Tensor, Iterable[torch.Tensor]], norm_type: float = 2.0
) -> torch.Tensor:
    """
    Total `norm_type`-norm over the gradients of all parameters that have one.
    Adapted from: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)

    # No gradients at all -> norm is zero by convention.
    if not grads:
        return torch.tensor(0.0)

    device = grads[0].device
    per_param_norms = torch.stack(
        [torch.norm(g, norm_type).to(device) for g in grads]
    )
    return torch.norm(per_param_norms, norm_type)
65 |
66 |
def write_xyz_file(coords, atom_types, filename):
    """Write coordinates and atom types as a minimal XYZ file
    (atom count, blank comment line, then one atom per line, 3 decimals)."""
    assert len(coords) == len(atom_types)
    lines = [f"{len(coords)}", ""]
    for symbol, xyz in zip(atom_types, coords):
        lines.append(f"{symbol} {xyz[0]:.3f} {xyz[1]:.3f} {xyz[2]:.3f}")
    with open(filename, "w") as fh:
        fh.write("\n".join(lines) + "\n")
74 |
75 |
def write_sdf_file(sdf_path, molecules):
    """Write all non-None molecules from `molecules` to an SDF file.

    NOTE Changed to be compatible with more versions of rdkit
    (avoids using SDWriter as a context manager).
    """
    w = Chem.SDWriter(str(sdf_path))
    for m in molecules:
        if m is not None:
            w.write(m)
    # BUG FIX: the writer was never closed, so buffered records could be
    # left unflushed on disk.
    w.close()

    print(f"Wrote SDF file to {sdf_path}")
88 |
89 |
def residues_to_atoms(x_ca, dataset_info):
    """Treat CA positions as carbon atoms.

    Returns the coordinates unchanged together with a matching one-hot
    atom-type tensor where every entry encodes carbon.
    """
    encoder = dataset_info["atom_encoder"]
    carbon = torch.tensor(encoder["C"], device=x_ca.device)
    one_hot = F.one_hot(carbon, num_classes=len(encoder)).repeat(
        *x_ca.shape[:-1], 1
    )
    return x_ca, one_hot
97 |
98 |
def get_residue_with_resi(pdb_chain, resi):
    """Return the single residue in pdb_chain whose residue number equals resi.

    Asserts that exactly one residue matches.
    """
    matches = [r for r in pdb_chain.get_residues() if r.id[1] == resi]
    assert len(matches) == 1
    return matches[0]
103 |
104 |
def get_pocket_from_ligand(pdb_model, ligand_id, dist_cutoff=8.0):
    """
    Collect the standard amino-acid residues within `dist_cutoff` of any
    ligand atom (minimum atom-atom distance, in the model's coordinate units).

    :param pdb_model: Bio.PDB model object.
    :param ligand_id: "<chain>:<residue number>" identifying the ligand.
    :param dist_cutoff: distance threshold for pocket membership.
    :return: list of Bio.PDB residue objects forming the pocket.
    """
    chain, resi = ligand_id.split(":")
    # Bug fix: resi was left as a str, but residue.id[1] is an int, so the
    # "skip ligand itself" comparison below never matched. Convert once here.
    resi = int(resi)
    ligand = get_residue_with_resi(pdb_model[chain], resi)
    ligand_coords = torch.from_numpy(
        np.array([a.get_coord() for a in ligand.get_atoms()])
    )

    pocket_residues = []
    for residue in pdb_model.get_residues():
        if residue.id[1] == resi:
            continue  # skip ligand itself

        res_coords = torch.from_numpy(
            np.array([a.get_coord() for a in residue.get_atoms()])
        )
        if (
            is_aa(residue.get_resname(), standard=True)
            and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff
        ):
            pocket_residues.append(residue)

    return pocket_residues
127 |
128 |
def batch_to_list(data, batch_mask):
    """Split a flat batch tensor into per-sample tensors.

    :param data: tensor whose first dimension is indexed by batch_mask.
    :param batch_mask: per-row sample index (not required to be sorted).
    :return: tuple of tensors, one per distinct sample index (ascending).
    """
    # Sort so that rows of the same sample become contiguous.
    order = torch.argsort(batch_mask)
    sorted_mask = batch_mask[order]
    sorted_data = data[order]

    # Per-sample counts determine the split boundaries.
    counts = torch.unique(sorted_mask, return_counts=True)[1]
    return torch.split(sorted_data, counts.tolist())
142 |
143 |
def num_nodes_to_batch_mask(n_samples, num_nodes, device):
    """Build a batch-index vector: sample i appears num_nodes[i] times
    (or num_nodes times for every sample when an int is given).

    :param n_samples: number of samples in the batch.
    :param num_nodes: int or length-n_samples tensor of node counts.
    :param device: device for the resulting index tensor.
    """
    assert isinstance(num_nodes, int) or len(num_nodes) == n_samples

    if isinstance(num_nodes, torch.Tensor):
        num_nodes = num_nodes.to(device)

    sample_ids = torch.arange(n_samples, device=device)
    return sample_ids.repeat_interleave(num_nodes)
153 |
154 |
def rdmol_to_nxgraph(rdmol):
    """Convert an RDKit molecule into a NetworkX graph.

    Nodes are atom indices carrying the atomic number as 'atom_type';
    edges mirror the molecule's bonds (bond order is not stored).
    """
    graph = nx.Graph()
    graph.add_nodes_from(
        (atom.GetIdx(), {"atom_type": atom.GetAtomicNum()})
        for atom in rdmol.GetAtoms()
    )
    graph.add_edges_from(
        (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in rdmol.GetBonds()
    )
    return graph
166 |
167 |
def calc_rmsd(mol_a, mol_b):
    """Calculate RMSD of two molecules with unknown atom correspondence.

    Enumerates all graph isomorphisms between the molecular graphs (atoms are
    matched by atomic number) and returns the minimum RMSD over the induced
    atom mappings, or None if the graphs are not isomorphic.

    :param mol_a: RDKit Mol with a conformer.
    :param mol_b: RDKit Mol with a conformer.
    :return: float minimum RMSD, or None when no isomorphism exists.
    """
    graph_a = rdmol_to_nxgraph(mol_a)
    graph_b = rdmol_to_nxgraph(mol_b)

    gm = isomorphism.GraphMatcher(
        graph_a, graph_b, node_match=lambda na, nb: na["atom_type"] == nb["atom_type"]
    )

    isomorphisms = list(gm.isomorphisms_iter())
    if len(isomorphisms) < 1:
        return None

    # Perf fix: these values do not depend on the mapping, so they are hoisted
    # out of the per-isomorphism loop (previously recomputed every iteration).
    atom_types_a = [atom.GetAtomicNum() for atom in mol_a.GetAtoms()]
    conf_a = mol_a.GetConformer()
    coords_a = np.array(
        [conf_a.GetAtomPosition(i) for i in range(mol_a.GetNumAtoms())]
    )
    conf_b = mol_b.GetConformer()

    all_rmsds = []
    for mapping in isomorphisms:
        atom_types_b = [
            mol_b.GetAtomWithIdx(mapping[i]).GetAtomicNum()
            for i in range(mol_b.GetNumAtoms())
        ]
        # Sanity check: the isomorphism must preserve atom types.
        assert atom_types_a == atom_types_b

        coords_b = np.array(
            [conf_b.GetAtomPosition(mapping[i]) for i in range(mol_b.GetNumAtoms())]
        )

        diff = coords_a - coords_b
        rmsd = np.sqrt(np.mean(np.sum(diff * diff, axis=1)))
        all_rmsds.append(rmsd)

    if len(isomorphisms) > 1:
        print("More than one isomorphism found. Returning minimum RMSD.")

    return min(all_rmsds)
207 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/pubchem/download_pubchem.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | from bs4 import BeautifulSoup
4 | from tqdm import tqdm
5 |
6 |
def download_sdf_files(url, folder_path, timeout=60):
    """
    Download every .sdf.gz file linked from a directory-listing page.

    :param url: listing URL; each <a href> ending in ".sdf.gz" is fetched by
        appending the href to this URL.
    :param folder_path: local directory the files are written into.
    :param timeout: per-request timeout in seconds (new parameter; the old code
        had no timeout and could block forever on a hung connection).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # If the response was unsuccessful, this will raise a HTTPError
    except requests.exceptions.HTTPError as errh:
        print("HTTP Error:", errh)
        return
    except requests.exceptions.ConnectionError as errc:
        print("Error Connecting:", errc)
        return
    except requests.exceptions.Timeout as errt:
        print("Timeout Error:", errt)
        return
    except requests.exceptions.RequestException as err:
        print("Something went wrong with the request:", err)
        return

    soup = BeautifulSoup(response.text, "html.parser")

    for link in tqdm(soup.find_all("a")):
        file_link = link.get("href")
        # Bug fix: anchors without an href return None, which crashed on
        # .endswith(); skip them explicitly.
        if file_link is None or not file_link.endswith(".sdf.gz"):
            continue

        file_url = url + file_link
        file_path = os.path.join(folder_path, file_link)

        try:
            # Stream to disk in 8 KiB chunks to keep memory bounded.
            with requests.get(file_url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                with open(file_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
        except requests.exceptions.HTTPError as errh:
            print("HTTP Error:", errh)
        except requests.exceptions.ConnectionError as errc:
            print("Error Connecting:", errc)
        except requests.exceptions.Timeout as errt:
            print("Timeout Error:", errt)
        except requests.exceptions.RequestException as err:
            print("Something went wrong with the request:", err)
46 |
47 |
# PubChem 3D conformer archive (one conformer per compound), served as a
# directory listing of .sdf.gz files.
url = "https://ftp.ncbi.nlm.nih.gov/pubchem/Compound_3D/01_conf_per_cmpd/SDF/"
# NOTE(review): "----" is a placeholder; importing this module triggers the
# download immediately (there is no __main__ guard).
folder_path = (
    "----" # replace with your folder path
)
download_sdf_files(url, folder_path)
53 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/pubchem/preprocess_pubchem.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 | import gzip
3 | from glob import glob
4 | import os
5 | from multiprocessing import cpu_count, Pool
6 | import pubchem.data.dataset_utils as dataset_utils
7 | from tqdm import tqdm
8 | import pickle
9 | import argparse
10 |
11 |
def get_args():
    """Parse command-line arguments for the data-generation script."""
    # fmt: off
    parser = argparse.ArgumentParser(description='Data generation')
    parser.add_argument('--index', default=0, type=int, help='Which part of the splitted dataset')
    # fmt: on
    return parser.parse_args()
18 |
19 |
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    start = 0
    while start < len(lst):
        yield lst[start : start + n]
        start += n
24 |
25 |
# Atom-type vocabulary for featurization: element symbol -> class index.
# 11 elements; note this is smaller than the 16-element encoder used in
# process_pubchem.py.
full_atom_encoder = {
    "H": 0,
    "C": 1,
    "N": 2,
    "O": 3,
    "F": 4,
    "Si": 5,
    "P": 6,
    "S": 7,
    "Cl": 8,
    "Br": 9,
    "I": 10,
}
39 |
40 |
def process(file):
    """
    Read all molecules from one gzipped SDF file and append the results to the
    module-level ``smiles_list`` / ``data_list`` accumulation buffers.

    Molecules without usable 3D coordinates are reported and skipped; any
    molecule whose conversion raises is skipped silently.

    :param file: path to a .sdf.gz file.
    """
    # Bug fix: the gzip handle was opened but never closed; use a context
    # manager. The supplier is fully consumed before the file is closed.
    with gzip.open(file) as inf:
        with Chem.ForwardSDMolSupplier(inf) as gzsuppl:
            molecules = [x for x in gzsuppl if x is not None]
    for mol in molecules:
        try:
            smiles = Chem.MolToSmiles(mol)
            smiles_list.append(smiles)
            data = dataset_utils.mol_to_torch_geometric(mol, full_atom_encoder, smiles)
            if data.pos.shape[0] != data.x.shape[0]:
                print(f"Molecule {smiles} does not have 3D information!")
                continue
            if data.pos.ndim != 2:
                print(f"Molecule {smiles} does not have 3D information!")
                continue
            if len(data.pos) < 2:
                print(f"Molecule {smiles} does not have 3D information!")
                continue
            data_list.append(data)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; conversion failures are skipped.
            continue
    # pbar.update(1)

    return
65 |
66 |
if __name__ == "__main__":
    args = get_args()

    # Module-level accumulation buffers mutated by process().
    data_list = []
    smiles_list = []

    # NOTE(review): "----" is a placeholder path prefix for the pickled list of
    # input files for this shard.
    files = f"----files_{args.index}.pickle"
    with open(files, "rb") as f:
        files = pickle.load(f)

    # NOTE(review): pbar is never advanced (the update call inside process() is
    # commented out); the tqdm wrapper on the loop below shows the progress.
    pbar = tqdm(total=len(files))

    for file in tqdm(files):
        process(file)

    # Persist this shard's results.
    with open(f"data_list_{args.index}.pickle", "wb") as f:
        pickle.dump(data_list, f)
    with open(f"smiles_list_{args.index}.pickle", "wb") as f:
        pickle.dump(smiles_list, f)
86 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/pubchem/process_pubchem.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import os
3 | from tqdm import tqdm
4 | import lmdb
5 | from rdkit import Chem, RDLogger
6 | import experiments.data.utils as dataset_utils
7 | import gzip
8 | import io
9 | import multiprocessing as mp
10 | import pickle
11 |
12 | RDLogger.DisableLog("rdApp.*")
13 |
14 |
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for offset in range(0, len(lst), n):
        yield lst[offset : offset + n]
19 |
20 |
# NOTE(review): "---" is a placeholder; set to the PubChem dataset root that
# contains the raw/*.gz SDF files before running.
DATA_PATH = "---"
# Atom-type vocabulary for featurization: element symbol -> class index
# (16 elements; a superset of the 11-element encoder in preprocess_pubchem.py).
FULL_ATOM_ENCODER = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
40 |
41 |
def process_files(
    processes: int = 36, chunk_size: int = 1024, subchunk: int = 128, removeHs=False
):
    """
    Convert all raw/*.gz SDF files under DATA_PATH into one LMDB database.

    Files are grouped into chunks of ``chunk_size``; each chunk is split into
    sub-chunks of ``subchunk`` files that are parsed in parallel by a pool of
    ``processes`` workers (db_sample_helper). The gzip-compressed per-molecule
    records are written to LMDB under consecutive integer string keys.

    :param processes: worker processes per multiprocessing pool.
    :param chunk_size: number of files per outer chunk.
    :param subchunk: number of files handed to one pool invocation.
    :param removeHs: strip hydrogens while parsing; also selects the
        ``database_noh``/``database_h`` output directory.
    """

    data_list = glob(os.path.join(DATA_PATH, "raw/*.gz"))
    h = "noh" if removeHs else "h"
    print(f"Process without hydrogens: {removeHs}")

    save_path = os.path.join(DATA_PATH, f"database_{h}")
    if os.path.exists(save_path):
        print("FYI: Output directory has been created already.")
    chunked_list = list(chunks(data_list, chunk_size))
    chunked_list = [list(chunks(l, subchunk)) for l in chunked_list]

    print(f"Total number of molecules {len(data_list)}.")
    print(f"Processing {len(chunked_list)} chunks each of size {chunk_size}.")

    # map_size is the maximum size of the memory-mapped database (1e13 bytes).
    env = lmdb.open(str(save_path), map_size=int(1e13))
    global_id = 0
    with env.begin(write=True) as txn:
        for chunklist in tqdm(chunked_list, total=len(chunked_list), desc="Chunks"):
            chunkresult = []
            for datachunk in tqdm(chunklist, total=len(chunklist), desc="Datachunks"):
                removeHs_list = [removeHs] * len(datachunk)
                with mp.Pool(processes=processes) as pool:
                    res = pool.starmap(
                        func=db_sample_helper, iterable=zip(datachunk, removeHs_list)
                    )
                    res = [r for r in res if r is not None]
                chunkresult.append(res)

            # Flatten the nested per-subchunk lists of conformer blobs and
            # SMILES into two parallel flat lists.
            confs_sub = []
            smiles_sub = []
            for cr in chunkresult:
                subconfs = [a["confs"] for a in cr]
                subconfs = [item for sublist in subconfs for item in sublist]
                subsmiles = [a["smiles"] for a in cr]
                subsmiles = [item for sublist in subsmiles for item in sublist]
                confs_sub.append(subconfs)
                smiles_sub.append(subsmiles)

            confs_sub = [item for sublist in confs_sub for item in sublist]
            smiles_sub = [item for sublist in smiles_sub for item in sublist]

            assert len(confs_sub) == len(smiles_sub)
            # save
            for conf in confs_sub:
                # overwrite=False makes duplicate keys an error rather than a
                # silent overwrite.
                result = txn.put(str(global_id).encode(), conf, overwrite=False)
                if not result:
                    raise RuntimeError(
                        f"LMDB entry {global_id} in {str(save_path)} " "already exists"
                    )
                global_id += 1

    print(f"{global_id} molecules have been processed!")
    print("Finished!")
106 |
107 |
def db_sample_helper(file, removeHs=False):
    """
    Parse one gzipped SDF file into serialized molecule records.

    Each valid molecule is converted to a torch-geometric data object and
    pickled together with its RDKit mol, gzip-compressed in memory. Molecules
    without usable 3D coordinates, or whose conversion raises, are skipped.

    :param file: path to a .sdf.gz file.
    :param removeHs: forwarded to the SDF supplier and the conversion helper.
    :return: dict with parallel lists "confs" (compressed bytes) and "smiles".
    """
    saved_confs_list = []
    smiles_list = []

    # Bug fix: the gzip handle was opened but never closed; use a context
    # manager. The supplier is fully consumed before the file is closed.
    with gzip.open(file) as inf:
        with Chem.ForwardSDMolSupplier(inf, removeHs=removeHs) as gzsuppl:
            molecules = [x for x in gzsuppl if x is not None]
    for mol in molecules:
        try:
            smiles = Chem.MolToSmiles(mol)
            data = dataset_utils.mol_to_torch_geometric(
                mol, FULL_ATOM_ENCODER, smiles, remove_hydrogens=removeHs
            )
            if data.pos.shape[0] != data.x.shape[0]:
                continue
            if data.pos.ndim != 2:
                continue
            if len(data.pos) < 2:
                continue
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; conversion failures are skipped.
            continue
        # create binary object to be saved
        buf = io.BytesIO()
        saves = {
            "mol": mol,
            "data": data,
        }
        with gzip.GzipFile(fileobj=buf, mode="wb", compresslevel=6) as f:
            f.write(pickle.dumps(saves))
        compressed = buf.getvalue()
        saved_confs_list.append(compressed)
        smiles_list.append(smiles)
    return {
        "confs": saved_confs_list,
        "smiles": smiles_list,
    }
144 |
145 |
if __name__ == "__main__":
    # Build the explicit-hydrogen database; uncomment the second call (or flip
    # the flag) to build the hydrogen-free variant instead.
    process_files(removeHs=False)
    # process_files(removeHs=True)
149 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/pubchem/pubchem_dataset_adaptive.py:
--------------------------------------------------------------------------------
1 | from rdkit import RDLogger
2 | import torch
3 | import numpy as np
4 | from os.path import join
5 | from torch_geometric.data import Dataset, DataLoader
6 | from experiments.data.utils import load_pickle, make_splits
7 | from torch.utils.data import Subset
8 | import lmdb
9 | import pickle
10 | import gzip
11 | import io
12 | from pytorch_lightning import LightningDataModule
13 | import os
14 | import experiments.data.utils as dataset_utils
15 | from experiments.data.abstract_dataset import (
16 | AbstractAdaptiveDataModule,
17 | )
18 |
# NOTE(review): "----" is a placeholder for the processed GEOM dataset
# directory (used as the statistics source in evaluation mode) — set it
# before use.
GEOM_DATADIR = "----"

# Atom-type vocabulary for featurization: element symbol -> class index
# (16 elements, matching FULL_ATOM_ENCODER in process_pubchem.py).
full_atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
39 |
40 |
class PubChemLMDBDataset(Dataset):
    """
    Read-only PyG dataset backed by an LMDB database of gzip-pickled PubChem
    molecules (as written by process_pubchem.py). Precomputed training
    statistics and SMILES are loaded from a separate stats directory.
    """

    def __init__(
        self,
        root: str,
        remove_hs: bool,
        evaluation: bool = False,
    ):
        """
        Constructor

        :param root: path of the LMDB database; its name must contain
            "_noh"/"_h" consistent with ``remove_hs``.
        :param remove_hs: whether the database stores hydrogen-free molecules.
        :param evaluation: if True, statistics are read from GEOM_DATADIR
            instead of the PubChem processed directory.
        """
        self.data_file = root
        # Hard-coded dataset size (LMDB is not queried for a count).
        # NOTE(review): the trailing comment hints the true noH count is
        # 94980241 — confirm these numbers against the actual database.
        self._num_graphs = 95173200 if remove_hs else 95173300  # 94980241 for noH!!

        if remove_hs:
            assert "_noh" in root
            self.stats_dir = (
                "/scratch1/cremej01/data/pubchem/dataset_noh/processed"
                if not evaluation
                else GEOM_DATADIR
            )
        else:
            assert "_h" in root
            self.stats_dir = (
                "/scratch1/cremej01/data/pubchem/dataset_h/processed"
                if not evaluation
                else GEOM_DATADIR
            )

        super().__init__(root)

        self.remove_hs = remove_hs
        # Precomputed distribution statistics; index order must match the
        # `processed_files` list below.
        self.statistics = dataset_utils.Statistics(
            num_nodes=load_pickle(
                os.path.join(self.stats_dir, self.processed_files[0])
            ),
            atom_types=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[1]))
            ),
            bond_types=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[2]))
            ),
            charge_types=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[3]))
            ),
            valencies=load_pickle(
                os.path.join(self.stats_dir, self.processed_files[4])
            ),
            bond_lengths=load_pickle(
                os.path.join(self.stats_dir, self.processed_files[5])
            ),
            bond_angles=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[6]))
            ),
            is_aromatic=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[7]))
            ).float(),
            is_in_ring=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[8]))
            ).float(),
            hybridization=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[9]))
            ).float(),
        )
        self.smiles = load_pickle(
            os.path.join(self.stats_dir, self.processed_files[10])
        )

    def _init_db(self):
        # Open the LMDB environment read-only, without locking or readahead.
        # NOTE(review): this is called from every get(), so a fresh environment
        # is opened per lookup and never closed — presumably for fork/worker
        # safety, but caching one env per process would avoid the repeated
        # open; confirm intent before changing.
        self._env = lmdb.open(
            str(self.data_file),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            create=False,
        )

    def get(self, index: int):
        # Fetch the gzip-compressed pickle stored under the integer string key
        # and return its "data" payload; returns None if unpickling fails.
        self._init_db()

        with self._env.begin(write=False) as txn:
            compressed = txn.get(str(index).encode())
            buf = io.BytesIO(compressed)
            with gzip.GzipFile(fileobj=buf, mode="rb") as f:
                serialized = f.read()
            try:
                item = pickle.loads(serialized)["data"]
            except:
                return None

        return item

    def len(self) -> int:
        r"""Returns the number of graphs stored in the dataset."""
        return self._num_graphs

    def __len__(self) -> int:
        return self._num_graphs

    @property
    def processed_files(self):
        # Statistics / SMILES filenames, keyed by the hydrogen mode.
        h = "noh" if self.remove_hs else "h"
        return [
            f"train_n_{h}.pickle",
            f"train_atom_types_{h}.npy",
            f"train_bond_types_{h}.npy",
            f"train_charges_{h}.npy",
            f"train_valency_{h}.pickle",
            f"train_bond_lengths_{h}.pickle",
            f"train_angles_{h}.npy",
            f"train_is_aromatic_{h}.npy",
            f"train_is_in_ring_{h}.npy",
            f"train_hybridization_{h}.npy",
            "train_smiles.pickle",
        ]
156 |
157 |
class PubChemDataModule(AbstractAdaptiveDataModule):
    """
    Adaptive data module around PubChemLMDBDataset: builds random
    train/val/test splits and serves PyG dataloaders. All three splits share
    the same (train-set) statistics object.
    """

    def __init__(self, hparams, evaluation=False):
        # NOTE(review): `self.cfg` used by the loader methods below is
        # presumably set from `hparams` by AbstractAdaptiveDataModule —
        # confirm in the base class.
        self.save_hyperparameters(hparams)
        self.datadir = hparams.dataset_root
        self.pin_memory = True

        self.remove_hs = hparams.remove_hs
        if self.remove_hs:
            print("Pre-Training on dataset with implicit hydrogens")
        self.dataset = PubChemLMDBDataset(
            root=self.datadir, remove_hs=self.remove_hs, evaluation=evaluation
        )

        self.train_smiles = self.dataset.smiles

        # Random split; indices are persisted to splits.npz under save_dir.
        self.idx_train, self.idx_val, self.idx_test = make_splits(
            len(self.dataset),
            train_size=hparams.train_size,
            val_size=hparams.val_size,
            test_size=hparams.test_size,
            seed=hparams.seed,
            filename=join(self.hparams["save_dir"], "splits.npz"),
            splits=None,
        )
        print(
            f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}"
        )
        train_dataset = Subset(self.dataset, self.idx_train)
        val_dataset = Subset(self.dataset, self.idx_val)
        test_dataset = Subset(self.dataset, self.idx_test)

        # The train statistics are reused for val/test.
        self.statistics = {
            "train": self.dataset.statistics,
            "val": self.dataset.statistics,
            "test": self.dataset.statistics,
        }

        super().__init__(hparams, train_dataset, val_dataset, test_dataset)

    def _train_dataloader(self, shuffle=True):
        dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )
        return dataloader

    def _val_dataloader(self, shuffle=False):
        dataloader = DataLoader(
            dataset=self.val_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )
        return dataloader

    def _test_dataloader(self, shuffle=False):
        dataloader = DataLoader(
            dataset=self.test_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )
        return dataloader

    def compute_mean_mad(self, properties_list):
        # Mean / mean-absolute-deviation normalization statistics for target
        # properties; only defined for the QM9/GEOM dataset configurations.
        if self.cfg.dataset == "qm9" or self.cfg.dataset == "drugs":
            dataloader = self.get_dataloader(self.train_dataset, "val")
            return self.compute_mean_mad_from_dataloader(dataloader, properties_list)
        elif self.cfg.dataset == "qm9_1half" or self.cfg.dataset == "qm9_2half":
            dataloader = self.get_dataloader(self.val_dataset, "val")
            return self.compute_mean_mad_from_dataloader(dataloader, properties_list)
        else:
            raise Exception("Wrong dataset name")

    def compute_mean_mad_from_dataloader(self, dataloader, properties_list):
        # For each property, first try the "<key>_mm" attribute on the dataset;
        # fall back to indexing data.y via the dataset's label2idx mapping.
        property_norms = {}
        for property_key in properties_list:
            try:
                property_name = property_key + "_mm"
                values = getattr(dataloader.dataset[:], property_name)
            except:
                property_name = property_key
                idx = dataloader.dataset[:].label2idx[property_name]
                values = torch.tensor(
                    [data.y[:, idx] for data in dataloader.dataset[:]]
                )

            mean = torch.mean(values)
            ma = torch.abs(values - mean)
            mad = torch.mean(ma)
            property_norms[property_key] = {}
            property_norms[property_key]["mean"] = mean
            property_norms[property_key]["mad"] = mad
            del values
        return property_norms

    def get_dataloader(self, dataset, stage):
        # Stage-specific loader: "train" shuffles with the train batch size,
        # "val"/"test" use the inference batch size without shuffling.
        if stage == "train":
            batch_size = self.cfg.batch_size
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.cfg.inference_batch_size
            shuffle = False

        dl = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

        return dl
279 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/pubchem/pubchem_dataset_nonadaptive.py:
--------------------------------------------------------------------------------
1 | from rdkit import RDLogger
2 | import torch
3 | import numpy as np
4 | from os.path import join
5 | from torch_geometric.data import Dataset, DataLoader
6 | from experiments.data.utils import load_pickle, make_splits
7 | from torch.utils.data import Subset
8 | import lmdb
9 | import pickle
10 | import gzip
11 | import io
12 | from pytorch_lightning import LightningDataModule
13 | import os
14 | import experiments.data.utils as dataset_utils
15 |
# Atom-type vocabulary for featurization: element symbol -> class index
# (16 elements, matching FULL_ATOM_ENCODER in process_pubchem.py).
full_atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
# NOTE(review): "----" is a placeholder for the processed GEOM dataset
# directory (evaluation-time statistics source) — set it before use.
GEOM_DATADIR = "----"
35 |
36 |
class PubChemLMDBDataset(Dataset):
    """
    Read-only PyG dataset backed by an LMDB database of gzip-pickled PubChem
    molecules. Statistics are loaded from `stats_dir`; the valencies and the
    SMILES list are currently disabled (set to None, loaders commented out).
    """

    def __init__(
        self,
        root: str,
        remove_hs: bool,
        evaluation: bool = False,
    ):
        """
        Constructor

        :param root: path of the LMDB database; its name must contain
            "_noh"/"_h" consistent with ``remove_hs``.
        :param remove_hs: whether the database stores hydrogen-free molecules.
        :param evaluation: accepted for interface parity; not used here since
            both stats_dir branches are placeholders.
        """
        self.data_file = root
        # Hard-coded dataset size. NOTE(review): the trailing comment hints the
        # true noH count is 94980241 — confirm against the actual database.
        self._num_graphs = 95173200 if remove_hs else 95173300  # 94980241 for noH!!
        if remove_hs:
            assert "_noh" in root
            self.stats_dir = "----"  # NOTE(review): placeholder path — set before use
        else:
            assert "_h" in root
            self.stats_dir = "----"  # NOTE(review): placeholder path — set before use
        super().__init__(root)

        self.remove_hs = remove_hs
        # Precomputed distribution statistics; index order must match the
        # `processed_files` list below.
        self.statistics = dataset_utils.Statistics(
            num_nodes=load_pickle(
                os.path.join(self.stats_dir, self.processed_files[0])
            ),
            atom_types=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[1]))
            ),
            bond_types=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[2]))
            ),
            charge_types=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[3]))
            ),
            valencies=None #load_pickle(
            #os.path.join(self.stats_dir, self.processed_files[4])
            #),
            ,
            bond_lengths=load_pickle(
                os.path.join(self.stats_dir, self.processed_files[5])
            ),
            bond_angles=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[6]))
            ),
            is_aromatic=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[7]))
            ).float(),
            is_in_ring=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[8]))
            ).float(),
            hybridization=torch.from_numpy(
                np.load(os.path.join(self.stats_dir, self.processed_files[9]))
            ).float(),
        )
        self.smiles = None # load_pickle(
        #os.path.join(self.stats_dir, self.processed_files[10])
        #)

    def _init_db(self):
        # Open the LMDB environment read-only, without locking or readahead.
        # NOTE(review): called from every get(), so a fresh environment is
        # opened per lookup and never closed — confirm intent before changing.
        self._env = lmdb.open(
            str(self.data_file),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            create=False,
        )

    def get(self, index: int):
        # Fetch the gzip-compressed pickle stored under the integer string key
        # and return its "data" payload; returns None if unpickling fails.
        self._init_db()

        with self._env.begin(write=False) as txn:
            compressed = txn.get(str(index).encode())
            buf = io.BytesIO(compressed)
            with gzip.GzipFile(fileobj=buf, mode="rb") as f:
                serialized = f.read()
            try:
                item = pickle.loads(serialized)["data"]
            except:
                return None

        return item

    def len(self) -> int:
        r"""Returns the number of graphs stored in the dataset."""
        return self._num_graphs

    def __len__(self) -> int:
        return self._num_graphs

    @property
    def processed_files(self):
        # Statistics / SMILES filenames, keyed by the hydrogen mode.
        h = "noh" if self.remove_hs else "h"
        return [
            f"train_n_{h}.pickle",
            f"train_atom_types_{h}.npy",
            f"train_bond_types_{h}.npy",
            f"train_charges_{h}.npy",
            f"train_valency_{h}.pickle",
            f"train_bond_lengths_{h}.pickle",
            f"train_angles_{h}.npy",
            f"train_is_aromatic_{h}.npy",
            f"train_is_in_ring_{h}.npy",
            f"train_hybridization_{h}.npy",
            "train_smiles.pickle",
        ]
143 |
144 |
class PubChemDataModule(LightningDataModule):
    """
    Lightning data module serving PubChemLMDBDataset with plain (non-adaptive)
    dataloaders and random train/val/test splits. All three splits share the
    same (train-set) statistics object.
    """

    def __init__(self, hparams, evaluation=False):
        """
        :param hparams: hyperparameters providing dataset_root, remove_hs,
            train_size, val_size, test_size, seed, save_dir, batch_size,
            num_workers (and dataset / inference_batch_size for the
            normalization helpers).
        :param evaluation: forwarded to the dataset constructor.
        """
        super(PubChemDataModule, self).__init__()
        self.save_hyperparameters(hparams)
        self.datadir = hparams.dataset_root
        self.pin_memory = True

        self.remove_hs = hparams.remove_hs
        if self.remove_hs:
            print("Pre-Training on dataset with implicit hydrogens")
        self.dataset = PubChemLMDBDataset(
            root=self.datadir, remove_hs=self.remove_hs, evaluation=evaluation
        )

        self.train_smiles = self.dataset.smiles

        # Random split; indices are persisted to splits.npz under save_dir.
        self.idx_train, self.idx_val, self.idx_test = make_splits(
            len(self.dataset),
            train_size=hparams.train_size,
            val_size=hparams.val_size,
            test_size=hparams.test_size,
            seed=hparams.seed,
            filename=join(self.hparams["save_dir"], "splits.npz"),
            splits=None,
        )
        print(
            f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}"
        )
        self.train_dataset = Subset(self.dataset, self.idx_train)
        self.val_dataset = Subset(self.dataset, self.idx_val)
        self.test_dataset = Subset(self.dataset, self.idx_test)

        self.statistics = {
            "train": self.dataset.statistics,
            "val": self.dataset.statistics,
            "test": self.dataset.statistics,
        }

    def train_dataloader(self, shuffle=True):
        # Bug fix: `shuffle` was accepted but ignored (shuffle was hard-coded
        # to True while the default claimed False). The default now matches the
        # old effective behavior and the argument is honored.
        dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )
        return dataloader

    def val_dataloader(self, shuffle=False):
        dataloader = DataLoader(
            dataset=self.val_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )
        return dataloader

    def test_dataloader(self, shuffle=False):
        dataloader = DataLoader(
            dataset=self.test_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            persistent_workers=False,
        )
        return dataloader

    def compute_mean_mad(self, properties_list):
        # Mean / mean-absolute-deviation normalization statistics for target
        # properties; only defined for the QM9/GEOM dataset configurations.
        # Bug fix: this class has no `self.cfg` attribute (that belongs to the
        # adaptive data module's base class); use the saved hyperparameters.
        if self.hparams.dataset == "qm9" or self.hparams.dataset == "drugs":
            dataloader = self.get_dataloader(self.train_dataset, "val")
            return self.compute_mean_mad_from_dataloader(dataloader, properties_list)
        elif self.hparams.dataset == "qm9_1half" or self.hparams.dataset == "qm9_2half":
            dataloader = self.get_dataloader(self.val_dataset, "val")
            return self.compute_mean_mad_from_dataloader(dataloader, properties_list)
        else:
            raise Exception("Wrong dataset name")

    def compute_mean_mad_from_dataloader(self, dataloader, properties_list):
        # For each property, first try the "<key>_mm" attribute on the dataset;
        # fall back to indexing data.y via the dataset's label2idx mapping.
        property_norms = {}
        for property_key in properties_list:
            try:
                property_name = property_key + "_mm"
                values = getattr(dataloader.dataset[:], property_name)
            except Exception:
                property_name = property_key
                idx = dataloader.dataset[:].label2idx[property_name]
                values = torch.tensor(
                    [data.y[:, idx] for data in dataloader.dataset[:]]
                )

            mean = torch.mean(values)
            mad = torch.mean(torch.abs(values - mean))
            property_norms[property_key] = {"mean": mean, "mad": mad}
            del values
        return property_norms

    def get_dataloader(self, dataset, stage):
        # Stage-specific loader: "train" shuffles with the train batch size,
        # "val"/"test" use the inference batch size without shuffling.
        # Bug fix: `self.cfg` replaced by `self.hparams` (see compute_mean_mad),
        # and an unknown stage now raises explicitly instead of hitting an
        # UnboundLocalError.
        if stage == "train":
            batch_size = self.hparams.batch_size
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.hparams.inference_batch_size
            shuffle = False
        else:
            raise ValueError(f"Unknown stage: {stage}")

        dl = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

        return dl
265 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/data/qm9/qm9_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from typing import Any, Sequence
4 | from torch.utils.data import Subset
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from experiments.data.abstract_dataset import AbstractDataModule
9 | from experiments.data.metrics import compute_all_statistics
10 | from experiments.data.utils import (
11 | Statistics,
12 | load_pickle,
13 | mol_to_torch_geometric,
14 | remove_hydrogens,
15 | save_pickle,
16 | train_subset,
17 | )
18 | from os.path import join
19 |
20 | from rdkit import Chem, RDLogger
21 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip
22 | from tqdm import tqdm
23 | from torch_geometric.data import DataLoader
24 |
25 |
def files_exist(files) -> bool:
    """Return True iff `files` is non-empty and every path in it exists."""
    # NOTE: We return `False` in case `files` is empty, leading to a
    # re-processing of files on every instantiation.
    if len(files) == 0:
        return False
    return all(osp.exists(f) for f in files)
30 |
31 |
def to_list(value: Any) -> Sequence:
    """Wrap value in a single-element list unless it is already a
    (non-string) sequence."""
    if isinstance(value, str) or not isinstance(value, Sequence):
        return [value]
    return value
37 |
38 |
# Atom-type vocabulary shared with the PubChem/GEOM datasets: element symbol
# -> class index (16 elements). When remove_h is set, QM9Dataset re-indexes
# this encoder without "H".
full_atom_encoder = {
    "H": 0,
    "B": 1,
    "C": 2,
    "N": 3,
    "O": 4,
    "F": 5,
    "Al": 6,
    "Si": 7,
    "P": 8,
    "S": 9,
    "Cl": 10,
    "As": 11,
    "Br": 12,
    "I": 13,
    "Hg": 14,
    "Bi": 15,
}
57 |
58 |
class QM9Dataset(InMemoryDataset):
    """QM9 molecules as a PyG ``InMemoryDataset`` with cached per-split statistics.

    Each split ("train"/"val"/"test") is processed into one collated ``.pt``
    file plus eleven statistics files (atom/bond/charge histograms, valencies,
    bond lengths/angles, SMILES, aromaticity, ring membership, hybridization)
    whose positional order is fixed by ``processed_file_names``.
    """

    # Mirrors of the raw QM9 release used by download(); processed_url is the
    # RDKit-free fallback archive from PyG.
    raw_url = (
        "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/"
        "molnet_publish/qm9.zip"
    )
    raw_url2 = "https://ndownloader.figshare.com/files/3195404"
    processed_url = "https://data.pyg.org/datasets/qm9_v3.zip"

    def __init__(
        self,
        split,
        root,
        remove_h: bool,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        only_stats=False
    ):
        """Load (or trigger download/processing of) one QM9 split.

        :param split: one of "train", "val", "test".
        :param root: dataset root directory (raw/ and processed/ live below it).
        :param remove_h: strip hydrogens and re-index atom classes accordingly.
        :param transform: standard PyG per-access transform hook.
        :param pre_transform: standard PyG processing-time transform hook.
        :param pre_filter: standard PyG processing-time filter hook.
        :param only_stats: if True, skip loading the collated tensors and only
            expose ``self.statistics`` and ``self.smiles``.
        """
        self.split = split
        # Index of this split within the raw split files (train/val/test.csv).
        if self.split == "train":
            self.file_idx = 0
        elif self.split == "val":
            self.file_idx = 1
        else:
            self.file_idx = 2
        self.remove_h = remove_h

        self.atom_encoder = full_atom_encoder
        if remove_h:
            # Drop hydrogen and shift the remaining class indices down by one.
            self.atom_encoder = {
                k: v - 1 for k, v in self.atom_encoder.items() if k != "H"
            }

        # May invoke download()/process() if raw/processed files are missing.
        super().__init__(root, transform, pre_transform, pre_filter)
        if not only_stats:
            self.data, self.slices = torch.load(self.processed_paths[0])
        else:
            self.data, self.slices = None, None

        # processed_paths layout (must match processed_file_names):
        #   0: collated data   1: num_nodes     2: atom_types    3: bond_types
        #   4: charge_types    5: valencies     6: bond_lengths  7: bond_angles
        #   8: smiles          9: is_aromatic  10: is_in_ring   11: hybridization
        self.statistics = Statistics(
            num_nodes=load_pickle(self.processed_paths[1]),
            atom_types=torch.from_numpy(np.load(self.processed_paths[2])).float(),
            bond_types=torch.from_numpy(np.load(self.processed_paths[3])).float(),
            charge_types=torch.from_numpy(np.load(self.processed_paths[4])).float(),
            valencies=load_pickle(self.processed_paths[5]),
            bond_lengths=load_pickle(self.processed_paths[6]),
            bond_angles=torch.from_numpy(np.load(self.processed_paths[7])).float(),
            is_aromatic=torch.from_numpy(np.load(self.processed_paths[9])).float(),
            is_in_ring=torch.from_numpy(np.load(self.processed_paths[10])).float(),
            hybridization=torch.from_numpy(np.load(self.processed_paths[11])).float(),
        )
        self.smiles = load_pickle(self.processed_paths[8])

    @property
    def raw_file_names(self):
        # Raw files expected to exist after download().
        return ["gdb9.sdf", "gdb9.sdf.csv", "uncharacterized.txt"]

    @property
    def split_file_name(self):
        # CSV files produced by the shuffle-and-split step in download().
        return ["train.csv", "val.csv", "test.csv"]

    @property
    def split_paths(self):
        r"""The absolute filepaths that must be present in order to skip
        splitting."""
        files = to_list(self.split_file_name)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_file_names(self):
        # Fixed 12-entry list per split; __init__ and process() index into it
        # positionally, so the ordering must not change.
        h = "noh" if self.remove_h else "h"
        if self.split == "train":
            return [
                f"train_{h}.pt",
                f"train_n_{h}.pickle",
                f"train_atom_types_{h}.npy",
                f"train_bond_types_{h}.npy",
                f"train_charges_{h}.npy",
                f"train_valency_{h}.pickle",
                f"train_bond_lengths_{h}.pickle",
                f"train_angles_{h}.npy",
                "train_smiles.pickle",
                f"train_is_aromatic_{h}.npy",
                f"train_is_in_ring_{h}.npy",
                f"train_hybridization_{h}.npy",
            ]
        elif self.split == "val":
            return [
                f"val_{h}.pt",
                f"val_n_{h}.pickle",
                f"val_atom_types_{h}.npy",
                f"val_bond_types_{h}.npy",
                f"val_charges_{h}.npy",
                f"val_valency_{h}.pickle",
                f"val_bond_lengths_{h}.pickle",
                f"val_angles_{h}.npy",
                "val_smiles.pickle",
                f"val_is_aromatic_{h}.npy",
                f"val_is_in_ring_{h}.npy",
                f"val_hybridization_{h}.npy",
            ]
        else:
            return [
                f"test_{h}.pt",
                f"test_n_{h}.pickle",
                f"test_atom_types_{h}.npy",
                f"test_bond_types_{h}.npy",
                f"test_charges_{h}.npy",
                f"test_valency_{h}.pickle",
                f"test_bond_lengths_{h}.pickle",
                f"test_angles_{h}.npy",
                "test_smiles.pickle",
                f"test_is_aromatic_{h}.npy",
                f"test_is_in_ring_{h}.npy",
                f"test_hybridization_{h}.npy",
            ]

    def download(self):
        """
        Download raw qm9 files. Taken from PyG QM9 class
        """
        try:
            import rdkit  # noqa

            file_path = download_url(self.raw_url, self.raw_dir)
            extract_zip(file_path, self.raw_dir)
            os.unlink(file_path)

            # figshare serves the "uncharacterized" list under its numeric id.
            file_path = download_url(self.raw_url2, self.raw_dir)
            os.rename(
                osp.join(self.raw_dir, "3195404"),
                osp.join(self.raw_dir, "uncharacterized.txt"),
            )
        except ImportError:
            # Without RDKit fall back to PyG's preprocessed archive.
            path = download_url(self.processed_url, self.raw_dir)
            extract_zip(path, self.raw_dir)
            os.unlink(path)

        if files_exist(self.split_paths):
            return

        dataset = pd.read_csv(self.raw_paths[1])

        # Fixed 100k training molecules, 10% test, remainder validation.
        n_samples = len(dataset)
        n_train = 100000
        n_test = int(0.1 * n_samples)
        n_val = n_samples - (n_train + n_test)

        # Shuffle dataset with df.sample, then split
        train, val, test = np.split(
            dataset.sample(frac=1, random_state=42), [n_train, n_val + n_train]
        )

        train.to_csv(os.path.join(self.raw_dir, "train.csv"))
        val.to_csv(os.path.join(self.raw_dir, "val.csv"))
        test.to_csv(os.path.join(self.raw_dir, "test.csv"))

    def process(self):
        """Convert raw SDF molecules of this split into PyG data + statistics."""
        RDLogger.DisableLog("rdApp.*")

        target_df = pd.read_csv(self.split_paths[self.file_idx], index_col=0)
        target_df.drop(columns=["mol_id"], inplace=True)

        # Indices of "uncharacterized" molecules to exclude; the slice skips
        # the file's header and footer lines.
        with open(self.raw_paths[-1]) as f:
            skip = [int(x.split()[0]) - 1 for x in f.read().split("\n")[9:-2]]

        suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False)
        data_list = []
        all_smiles = []
        num_errors = 0
        for i, mol in enumerate(tqdm(suppl)):
            # Keep only molecules of this split that are not flagged as
            # uncharacterized.
            if i in skip or i not in target_df.index:
                continue
            smiles = Chem.MolToSmiles(mol, isomericSmiles=False)
            if smiles is None:
                num_errors += 1
            else:
                all_smiles.append(smiles)

            # Note: always encodes with the full (H-containing) encoder first;
            # remove_hydrogens() re-indexes afterwards if requested.
            data = mol_to_torch_geometric(mol, full_atom_encoder, smiles)
            if self.remove_h:
                data = remove_hydrogens(data)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)
        # NOTE(review): this save looks redundant/suspect — for val/test,
        # processed_paths[self.file_idx] is a *statistics* path (overwritten
        # again below), and the authoritative save to processed_paths[0]
        # happens at the end of this method. Confirm before removing.
        torch.save(self.collate(data_list), self.processed_paths[self.file_idx])

        statistics = compute_all_statistics(
            data_list,
            self.atom_encoder,
            charges_dic={-2: 0, -1: 1, 0: 2, 1: 3, 2: 4, 3: 5},
            additional_feats=True,
        )

        save_pickle(statistics.num_nodes, self.processed_paths[1])
        np.save(self.processed_paths[2], statistics.atom_types)
        np.save(self.processed_paths[3], statistics.bond_types)
        np.save(self.processed_paths[4], statistics.charge_types)
        save_pickle(statistics.valencies, self.processed_paths[5])
        save_pickle(statistics.bond_lengths, self.processed_paths[6])
        np.save(self.processed_paths[7], statistics.bond_angles)

        np.save(self.processed_paths[9], statistics.is_aromatic)
        np.save(self.processed_paths[10], statistics.is_in_ring)
        np.save(self.processed_paths[11], statistics.hybridization)

        print("Number of molecules that could not be mapped to smiles: ", num_errors)
        save_pickle(set(all_smiles), self.processed_paths[8])
        torch.save(self.collate(data_list), self.processed_paths[0])
272 |
273 |
class QM9DataModule(AbstractDataModule):
    """DataModule bundling the three QM9 splits and their statistics.

    Builds train/val/test ``QM9Dataset`` instances from ``cfg.dataset_root``,
    exposes their precomputed statistics, optionally restricts training to a
    fixed subset, and hands out per-stage dataloaders.
    """

    def __init__(self, cfg, only_stats: bool = False):
        """
        :param cfg: experiment config; must provide dataset_root, remove_hs,
            select_train_subset (plus train_size/seed/save_dir when enabled),
            batch_size, inference_batch_size and num_workers.
        :param only_stats: if True, the datasets expose only statistics
            (no collated tensors are loaded).
        """
        self.datadir = cfg.dataset_root
        root_path = self.datadir

        train_dataset = QM9Dataset(
            split="train", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
        )
        val_dataset = QM9Dataset(
            split="val", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
        )
        test_dataset = QM9Dataset(
            split="test", root=root_path, remove_h=cfg.remove_hs, only_stats=only_stats
        )

        self.statistics = {
            "train": train_dataset.statistics,
            "val": val_dataset.statistics,
            "test": test_dataset.statistics,
        }
        if not only_stats:
            # Fix: train_smiles was previously only assigned inside the
            # select_train_subset branch, leaving the attribute undefined
            # (AttributeError on access) in the default configuration.
            self.train_smiles = train_dataset.smiles
            if cfg.select_train_subset:
                self.idx_train = train_subset(
                    dset_len=len(train_dataset),
                    train_size=cfg.train_size,
                    seed=cfg.seed,
                    filename=join(cfg.save_dir, "splits.npz"),
                )
                train_dataset = Subset(train_dataset, self.idx_train)

        self.remove_h = cfg.remove_hs
        super().__init__(
            cfg,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
        )

    def get_dataloader(self, dataset, stage):
        """Return a DataLoader for ``stage`` ("train", "val" or "test").

        Training uses cfg.batch_size with shuffling; val/test use
        cfg.inference_batch_size without shuffling.
        """
        if stage == "train":
            batch_size = self.cfg.batch_size
            shuffle = True
        elif stage in ["val", "test"]:
            batch_size = self.cfg.inference_batch_size
            shuffle = False
        else:
            # Fix: an unknown stage previously fell through and crashed with
            # a confusing UnboundLocalError on batch_size.
            raise ValueError(
                f"Unknown stage {stage!r}; expected 'train', 'val' or 'test'."
            )

        dl = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

        return dl
--------------------------------------------------------------------------------
/eqgat_diff/experiments/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/diffusion/__init__.py
--------------------------------------------------------------------------------
/eqgat_diff/experiments/diffusion/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import List, Optional
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from rdkit.Chem import RDConfig
8 | from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol
9 | from torch_geometric.utils import remove_self_loops, sort_edge_index
10 | from torch_scatter import scatter_mean
11 |
12 | from experiments.utils import get_edges, zero_mean
13 |
14 | sys.path.append(os.path.join(RDConfig.RDContribDir, "IFG"))
15 | from ifg import identify_functional_groups
16 |
17 |
def initialize_edge_attrs_reverse(
    edge_index_global, n, bonds_prior, num_bond_classes, device
):
    """Sample symmetric initial bond types for a fully-connected graph.

    Draws one bond class per upper-triangular edge from ``bonds_prior``,
    mirrors the samples onto the lower triangle, sorts the edge index, and
    returns one-hot edge attributes together with the upper-triangle mask.
    """
    src, dst = edge_index_global
    upper = src < dst
    dst_u = dst[upper]
    src_u = src[upper]
    n_upper = len(dst_u)
    # One categorical draw from the bond prior per unique (unordered) edge.
    sampled = torch.multinomial(bonds_prior, num_samples=n_upper, replacement=True)

    # Mirror: every sampled edge appears once per direction.
    src_sym = torch.cat([src_u, dst_u])
    dst_sym = torch.cat([dst_u, src_u])
    edge_index_global = torch.stack([src_sym, dst_sym], dim=0)
    edge_attr_global = torch.cat([sampled, sampled], dim=0)
    edge_index_global, edge_attr_global = sort_edge_index(
        edge_index=edge_index_global, edge_attr=edge_attr_global, sort_by_row=False
    )
    # Recompute the upper-triangle mask on the sorted index.
    src, dst = edge_index_global
    upper = src < dst
    upper_dst = dst[upper]

    # Sanity check: the dense adjacency of sampled types must be symmetric.
    dense = torch.zeros(size=(n, n), device=device, dtype=torch.long)
    dense[edge_index_global[0], edge_index_global[1]] = edge_attr_global
    assert (dense - dense.T).sum().float() == 0.0

    edge_attr_global = F.one_hot(edge_attr_global, num_bond_classes).float()

    return edge_attr_global, edge_index_global, upper, upper_dst
52 |
53 |
def get_joint_edge_attrs(
    pos,
    pos_pocket,
    batch,
    batch_pocket,
    edge_attr_global_lig,
    num_bond_classes,
    device,
):
    """Assemble edge index and attributes for the joint ligand+pocket graph.

    Edges come from ``get_edges`` within a 5 Angstrom cutoff. Ligand-ligand
    edges keep the supplied ligand bond attributes; pocket-pocket (and, in the
    7-class setting, ligand<->pocket) edges receive dedicated one-hot classes.
    Returns (edge_index, edge_attr, per-edge batch vector, ligand-edge mask).
    """
    n_lig = len(batch)
    edge_index_global = get_edges(
        batch, batch_pocket, pos, pos_pocket, cutoff_p=5, cutoff_lp=5
    )
    edge_index_global = sort_edge_index(edge_index=edge_index_global, sort_by_row=False)
    edge_index_global, _ = remove_self_loops(edge_index_global)

    edge_attr_global = torch.zeros(
        (edge_index_global.size(1), num_bond_classes),
        dtype=torch.float32,
        device=device,
    )
    row, col = edge_index_global[0], edge_index_global[1]
    # Node ids below n_lig are ligand atoms; the rest are pocket atoms.
    edge_mask = (row < n_lig) & (col < n_lig)
    edge_mask_pocket = (row >= n_lig) & (col >= n_lig)
    edge_attr_global[edge_mask] = edge_attr_global_lig

    if num_bond_classes == 7:
        # Class 5 marks ligand<->pocket edges (both directions),
        # class 6 marks pocket-pocket edges.
        mask_lig_pocket = (row < n_lig) & (col >= n_lig)
        mask_pocket_lig = (row >= n_lig) & (col < n_lig)
        edge_attr_global[edge_mask_pocket] = (
            torch.tensor([0, 0, 0, 0, 0, 0, 1]).float().to(edge_attr_global.device)
        )
        edge_attr_global[mask_lig_pocket] = (
            torch.tensor([0, 0, 0, 0, 0, 1, 0]).float().to(edge_attr_global.device)
        )
        edge_attr_global[mask_pocket_lig] = (
            torch.tensor([0, 0, 0, 0, 0, 1, 0]).float().to(edge_attr_global.device)
        )
    else:
        # Pocket-pocket edges use the last of the 5 bond classes.
        edge_attr_global[edge_mask_pocket] = (
            torch.tensor([0, 0, 0, 0, 1]).float().to(edge_attr_global.device)
        )

    # Per-edge graph assignment taken from each edge's source node.
    batch_edge_global = torch.cat([batch, batch_pocket])[edge_index_global[0]]

    return (
        edge_index_global,
        edge_attr_global,
        batch_edge_global,
        edge_mask,
    )
112 |
113 |
def bond_guidance(
    pos,
    node_feats_in,
    temb,
    bond_model,
    batch,
    batch_edge_global,
    edge_attr_global,
    edge_index_local,
    edge_index_global,
):
    """Nudge coordinates against the gradient of bond-model uncertainty.

    Runs the bond model on the current coordinates, converts the per-edge
    logits into an uncertainty score, pools it per node and per graph, and
    shifts ``pos`` by a small step opposite the gradient of that score.
    """
    guidance_type = "logsum"
    guidance_scale = 1.0e-4

    num_graphs = len(batch.bincount())
    with torch.enable_grad():
        node_feats_in = node_feats_in.detach()
        pos = pos.detach().requires_grad_(True)
        bond_prediction = bond_model(
            x=node_feats_in,
            t=temb,
            pos=pos,
            edge_index_local=edge_index_local,
            edge_index_global=edge_index_global,
            edge_attr_global=edge_attr_global,
            batch=batch,
            batch_edge_global=batch_edge_global,
        )
        if guidance_type == "ensemble":
            # TO-DO
            raise NotImplementedError
        elif guidance_type == "logsum":
            # Per-edge uncertainty from the logits, pooled onto target nodes.
            edge_uncertainty = torch.sigmoid(-torch.logsumexp(bond_prediction, dim=-1))
            node_uncertainty = (
                0.5
                * scatter_mean(
                    edge_uncertainty,
                    index=edge_index_global[1],
                    dim=0,
                    dim_size=pos.size(0),
                ).log()
            )
            graph_uncertainty = scatter_mean(
                node_uncertainty, index=batch, dim=0, dim_size=num_graphs
            )
            grad_outputs: List[Optional[torch.Tensor]] = [
                torch.ones_like(graph_uncertainty)
            ]
            # Move against the uncertainty gradient.
            dist_shift = -torch.autograd.grad(
                [graph_uncertainty],
                [pos],
                grad_outputs=grad_outputs,
                create_graph=False,
                retain_graph=False,
            )[0]

    return pos + guidance_scale * dist_shift
167 |
168 |
def energy_guidance(
    pos,
    node_feats_in,
    temb,
    energy_model,
    batch,
    batch_size,
    signal=1.0e-3,
    guidance_scale=100,
    optimization="minimize",
):
    """Shift coordinates along the gradient of a predicted scalar property.

    ``optimization`` selects whether the property is minimized or maximized;
    the resulting shift is made translation-free via ``zero_mean`` before and
    after it is applied. Returns the detached, updated coordinates.
    """
    with torch.enable_grad():
        node_feats_in = node_feats_in.detach()
        pos = pos.detach().requires_grad_(True)
        model_out = energy_model(
            x=node_feats_in,
            t=temb,
            pos=pos,
            batch=batch,
        )
        # Direction is resolved after the forward pass (matches call order).
        if optimization == "minimize":
            direction = -1.0
        elif optimization == "maximize":
            direction = 1.0
        else:
            raise Exception("Optimization arg needs to be 'minimize' or 'maximize'!")
        energy_prediction = direction * guidance_scale * model_out["property_pred"]

        grad_outputs: List[Optional[torch.Tensor]] = [
            torch.ones_like(energy_prediction)
        ]
        pos_shift = torch.autograd.grad(
            [energy_prediction],
            [pos],
            grad_outputs=grad_outputs,
            create_graph=False,
            retain_graph=False,
        )[0]

        # Remove any net translation from the shift, apply it, re-center.
        pos_shift = zero_mean(pos_shift, batch=batch, dim_size=batch_size, dim=0)
        pos = pos + signal * pos_shift
        pos = zero_mean(pos, batch=batch, dim_size=batch_size, dim=0)

    return pos.detach()
214 |
215 |
def extract_scaffolds_(batch_data):
    """Attach a boolean Murcko-scaffold mask to ``batch_data`` in place.

    Sets ``batch_data.scaffold_mask``: one entry per atom over all molecules
    in ``batch_data.mol``, True where the atom belongs to the molecule's
    Murcko scaffold.
    """

    def scaffold_per_mol(mol):
        # Boolean mask of scaffold atoms for a single molecule.
        # Tag atoms with their original indices so they survive extraction.
        for a in mol.GetAtoms():
            a.SetIntProp("org_idx", a.GetIdx())

        scaffold = GetScaffoldForMol(mol)
        scaffold_atoms = [a.GetIntProp("org_idx") for a in scaffold.GetAtoms()]
        mask = torch.zeros(mol.GetNumAtoms(), dtype=bool)
        # Fix: dtype=torch.long guards acyclic molecules whose scaffold is
        # empty — torch.tensor([]) defaults to float and is invalid as an
        # index, which crashed this function before.
        mask[torch.tensor(scaffold_atoms, dtype=torch.long)] = 1
        return mask

    batch_data.scaffold_mask = torch.hstack(
        [scaffold_per_mol(mol) for mol in batch_data.mol]
    )
230 |
231 |
def extract_func_groups_(batch_data, includeHs=True):
    """Attach a boolean functional-group mask to ``batch_data`` in place.

    Sets ``batch_data.func_group_mask``: one entry per atom over all molecules
    in ``batch_data.mol``, True where the atom belongs to a functional group
    identified by IFG. With ``includeHs`` the groups are extended by their
    neighboring hydrogen atoms.
    """

    def func_groups_per_mol(mol, includeHs=True):
        # Boolean mask of functional-group atoms for a single molecule.
        fgroups = identify_functional_groups(mol)
        findices = []
        for f in fgroups:
            findices.extend(list(f.atomIds))
        if includeHs:  # include neighboring H atoms in functional groups
            findices_incl_h = []
            for fi in findices:
                hidx = [
                    n.GetIdx()
                    for n in mol.GetAtomWithIdx(fi).GetNeighbors()
                    if n.GetSymbol() == "H"
                ]
                findices_incl_h.extend([fi] + hidx)
            findices = findices_incl_h
        mask = torch.zeros(mol.GetNumAtoms(), dtype=bool)
        # Fix: dtype=torch.long guards molecules without any functional
        # group — torch.tensor([]) defaults to float and is invalid as an
        # index, which crashed this function before.
        mask[torch.tensor(findices, dtype=torch.long)] = 1
        return mask

    batch_data.func_group_mask = torch.hstack(
        [func_groups_per_mol(mol, includeHs) for mol in batch_data.mol]
    )
255 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/docking.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import tempfile
4 | import numpy as np
5 | import torch
6 | from pathlib import Path
7 | import argparse
8 | import pandas as pd
9 | from rdkit import Chem
10 | from tqdm import tqdm
11 | from experiments.utils import write_sdf_file
12 |
13 |
def calculate_smina_score(pdb_file, sdf_file):
    """Score ligands against a receptor with smina (--score_only).

    Returns one affinity (kcal/mol) per ligand parsed from smina's stdout.
    """
    # add '-o _smina.sdf' if you want to see the output
    command = f"smina.static -l {sdf_file} -r {pdb_file} --score_only"
    output = os.popen(command).read()
    affinities = re.findall(
        r"Affinity:[ ]+([+-]?[0-9]*[.]?[0-9]+)[ ]+\(kcal/mol\)", output
    )
    return [float(a) for a in affinities]
19 |
20 |
def smina_score(rdmols, receptor_file):
    """
    Calculate smina score
    :param rdmols: List of RDKit molecules
    :param receptor_file: Receptor pdb/pdbqt file or list of receptor files
    :return: Smina score for each input molecule (list)
    """
    if not isinstance(receptor_file, list):
        # Single receptor shared by all molecules: score them in one batch.
        with tempfile.NamedTemporaryFile(suffix=".sdf") as tmp:
            write_sdf_file(tmp.name, rdmols)
            return calculate_smina_score(receptor_file, tmp.name)

    # One receptor per molecule: score each pair separately.
    scores = []
    for mol, rec_file in zip(rdmols, receptor_file):
        with tempfile.NamedTemporaryFile(suffix=".sdf") as tmp:
            write_sdf_file(tmp.name, [mol])
            scores.extend(calculate_smina_score(rec_file, tmp.name))
    return scores
45 |
46 |
def sdf_to_pdbqt(sdf_file, pdbqt_outfile, mol_id):
    """Convert the (mol_id+1)-th molecule of an SDF file to PDBQT via obabel."""
    command = f"obabel {sdf_file} -O {pdbqt_outfile} -f {mol_id + 1} -l {mol_id + 1}"
    # .read() blocks until obabel finishes.
    os.popen(command).read()
    return pdbqt_outfile
52 |
53 |
def calculate_qvina2_score(
    receptor_file, sdf_file, out_dir, size=20, exhaustiveness=16, return_rdmol=False
):
    """Dock every ligand in an SDF file with QuickVina 2 and return best scores.

    :param receptor_file: receptor as .pdb (converted to .pdbqt on the fly)
        or an already-prepared .pdbqt file.
    :param sdf_file: SDF file, possibly containing several ligands.
    :param out_dir: directory for intermediate pdbqt files and docking output.
    :param size: edge length of the cubic search box centered on the ligand
        (presumably Angstrom — TODO confirm against QuickVina docs).
    :param exhaustiveness: QuickVina exhaustiveness setting.
    :param return_rdmol: if True, also return the docked poses as RDKit mols.
    :return: list of best scores (np.nan where docking produced no result),
        plus the list of docked RDKit molecules when return_rdmol is True.
    """
    receptor_file = Path(receptor_file)
    sdf_file = Path(sdf_file)

    if receptor_file.suffix == ".pdb":
        # prepare receptor, requires Python 2.7
        receptor_pdbqt_file = Path(out_dir, receptor_file.stem + ".pdbqt")
        os.popen(f"prepare_receptor4.py -r {receptor_file} -O {receptor_pdbqt_file}")
    else:
        receptor_pdbqt_file = receptor_file

    scores = []
    rdmols = []  # for if return rdmols
    suppl = Chem.SDMolSupplier(str(sdf_file), sanitize=False)
    for i, mol in enumerate(suppl):  # sdf file may contain several ligands
        ligand_name = f"{sdf_file.stem}_{i}"
        # prepare ligand
        ligand_pdbqt_file = Path(out_dir, ligand_name + ".pdbqt")
        out_sdf_file = Path(out_dir, ligand_name + "_out.sdf")

        if out_sdf_file.exists():
            # Cached docking output exists: reuse it and keep the best
            # (lowest) VINA score found in the file.
            with open(out_sdf_file, "r") as f:
                scores.append(
                    min(
                        [
                            float(x.split()[2])
                            for x in f.readlines()
                            if x.startswith(" VINA RESULT:")
                        ]
                    )
                )

        else:
            sdf_to_pdbqt(sdf_file, ligand_pdbqt_file, i)

            # center box at ligand's center of mass
            cx, cy, cz = mol.GetConformer().GetPositions().mean(0)

            # run QuickVina 2
            # NOTE(review): the "----" literals below are redacted binary
            # paths (the stat/PermissionError dance picks between two install
            # locations); restore real paths before use.
            try:
                os.stat("----")
                PATH = "----"
            except PermissionError:
                PATH = "----"

            out = os.popen(
                f"/{PATH} --receptor {receptor_pdbqt_file} "
                f"--ligand {ligand_pdbqt_file} "
                f"--center_x {cx:.4f} --center_y {cy:.4f} --center_z {cz:.4f} "
                f"--size_x {size} --size_y {size} --size_z {size} "
                f"--exhaustiveness {exhaustiveness}",
            ).read()
            # clean up
            ligand_pdbqt_file.unlink()

            # The result-table separator is absent when docking failed.
            if "-----+------------+----------+----------" not in out:
                scores.append(np.nan)
                continue

            # First row after the separator is the best-ranked pose.
            out_split = out.splitlines()
            best_idx = out_split.index("-----+------------+----------+----------") + 1
            best_line = out_split[best_idx].split()
            assert best_line[0] == "1"
            scores.append(float(best_line[1]))

            out_pdbqt_file = Path(out_dir, ligand_name + "_out.pdbqt")
            if out_pdbqt_file.exists():
                os.popen(f"obabel {out_pdbqt_file} -O {out_sdf_file}").read()

                # clean up
                out_pdbqt_file.unlink()

        if return_rdmol:
            rdmol = Chem.SDMolSupplier(str(out_sdf_file))[0]
            rdmols.append(rdmol)

    if return_rdmol:
        return scores, rdmols
    else:
        return scores
136 |
137 |
138 | if __name__ == "__main__":
139 | parser = argparse.ArgumentParser("QuickVina evaluation")
140 | parser.add_argument("--pdbqt_dir", type=Path, help="Receptor files in pdbqt format")
141 | parser.add_argument(
142 | "--sdf_dir", type=Path, default=None, help="Ligand files in sdf format"
143 | )
144 | parser.add_argument("--sdf_files", type=Path, nargs="+", default=None)
145 | parser.add_argument("--out_dir", type=Path)
146 | parser.add_argument("--write_csv", action="store_true")
147 | parser.add_argument("--write_dict", action="store_true")
148 | parser.add_argument("--dataset", type=str, default="moad")
149 | args = parser.parse_args()
150 |
151 | assert (args.sdf_dir is not None) ^ (args.sdf_files is not None)
152 |
153 | args.out_dir.mkdir(exist_ok=True)
154 |
155 | results = {"receptor": [], "ligand": [], "scores": []}
156 | results_dict = {}
157 | sdf_files = (
158 | list(args.sdf_dir.glob("[!.]*.sdf"))
159 | if args.sdf_dir is not None
160 | else args.sdf_files
161 | )
162 | pbar = tqdm(sdf_files)
163 | for sdf_file in pbar:
164 | pbar.set_description(f"Processing {sdf_file.name}")
165 |
166 | if args.dataset == "moad":
167 | """
168 | Ligand file names should be of the following form:
169 | __.sdf
170 | where and cannot contain any
171 | underscores, e.g.: 1abc-bio1_pocket0_gen.sdf
172 | """
173 | ligand_name = sdf_file.stem
174 | receptor_name, pocket_id, *suffix = ligand_name.split("_")
175 | suffix = "_".join(suffix)
176 | receptor_file = Path(args.pdbqt_dir, receptor_name + ".pdbqt")
177 | elif args.dataset == "crossdocked":
178 | ligand_name = sdf_file.stem
179 | receptor_name = ligand_name[:-4]
180 | receptor_file = Path(args.pdbqt_dir, receptor_name + ".pdbqt")
181 |
182 | # try:
183 | scores, rdmols = calculate_qvina2_score(
184 | receptor_file, sdf_file, args.out_dir, return_rdmol=True
185 | )
186 | # except AttributeError as e:
187 | # print(e)
188 | # continue
189 | results["receptor"].append(str(receptor_file))
190 | results["ligand"].append(str(sdf_file))
191 | results["scores"].append(scores)
192 |
193 | if args.write_dict:
194 | results_dict[ligand_name] = {
195 | "receptor": str(receptor_file),
196 | "ligand": str(sdf_file),
197 | "scores": scores,
198 | "rmdols": rdmols,
199 | }
200 |
201 | if args.write_csv:
202 | df = pd.DataFrame.from_dict(results)
203 | df.to_csv(Path(args.out_dir, "qvina2_scores.csv"))
204 |
205 | if args.write_dict:
206 | torch.save(results_dict, Path(args.out_dir, "qvina2_scores.pt"))
207 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/docking_mgl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 |
5 |
def pdbs_to_pdbqts(pdb_dir, pdbqt_dir, dataset):
    """Convert every .pdb file in pdb_dir into a .pdbqt file in pdbqt_dir."""
    pattern = os.path.join(pdb_dir, "*.pdb")
    for pdb_path in glob.glob(pattern):
        stem = os.path.splitext(os.path.basename(pdb_path))[0]
        target = os.path.join(pdbqt_dir, stem + ".pdbqt")
        pdb_to_pdbqt(pdb_path, target, dataset)
        print("Wrote converted file to {}".format(target))
12 |
13 |
def pdb_to_pdbqt(pdb_file, pdbqt_file, dataset):
    """Convert one .pdb to .pdbqt via prepare_receptor4.py; no-op if it exists.

    Raises NotImplementedError for datasets other than "crossdocked" and
    "bindingmoad".
    """
    # Skip work if the converted file is already present.
    if os.path.exists(pdbqt_file):
        return pdbqt_file

    if dataset == "crossdocked":
        command = "prepare_receptor4.py -r {} -o {}".format(pdb_file, pdbqt_file)
    elif dataset == "bindingmoad":
        # bindingmoad additionally checks hydrogens and deletes non-standard
        # residues (-e).
        command = "prepare_receptor4.py -r {} -o {} -A checkhydrogens -e".format(
            pdb_file, pdbqt_file
        )
    else:
        raise NotImplementedError

    os.system(command)
    return pdbqt_file
28 |
29 |
30 | if __name__ == "__main__":
31 | pdbs_to_pdbqts(sys.argv[1], sys.argv[2], sys.argv[3])
32 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/hparams.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from experiments.utils import LoadFromFile
4 |
# Default output/checkpoint directory, created eagerly at import time.
DEFAULT_SAVE_DIR = os.path.join(os.getcwd(), "3DcoordsAtomsBonds_0")

# Fix: exist_ok=True avoids the check-then-create race of the previous
# `if not os.path.exists(...): os.makedirs(...)` pattern when several
# processes import this module concurrently.
os.makedirs(DEFAULT_SAVE_DIR, exist_ok=True)
9 |
10 |
11 | def add_arguments(parser):
12 | """Helper function to fill the parser object.
13 |
14 | Args:
15 | parser: Parser object
16 | Returns:
17 | parser: Updated parser object
18 | """
19 |
20 | # Load yaml file
21 | parser.add_argument(
22 | "--conf", "-c", type=open, action=LoadFromFile, help="Configuration yaml file"
23 | ) # keep first
24 |
25 | # Load from checkpoint
26 | parser.add_argument("--load-ckpt", default=None, type=str)
27 | parser.add_argument("--load-ckpt-from-pretrained", default=None, type=str)
28 |
29 | # DATA and FILES
30 | parser.add_argument("-s", "--save-dir", default=DEFAULT_SAVE_DIR, type=str)
31 | parser.add_argument("--test-save-dir", default=DEFAULT_SAVE_DIR, type=str)
32 |
33 | parser.add_argument(
34 | "--dataset",
35 | default="drugs",
36 | choices=[
37 | "qm9",
38 | "drugs",
39 | "aqm",
40 | "aqm_qm7x",
41 | "pcqm4mv2",
42 | "pepconf",
43 | "crossdocked",
44 | ],
45 | )
46 | parser.add_argument(
47 | "--dataset-root", default="----"
48 | )
49 | parser.add_argument("--use-adaptive-loader", default=True, action="store_true")
50 | parser.add_argument("--remove-hs", default=False, action="store_true")
51 | parser.add_argument("--select-train-subset", default=False, action="store_true")
52 | parser.add_argument("--train-size", default=0.8, type=float)
53 | parser.add_argument("--val-size", default=0.1, type=float)
54 | parser.add_argument("--test-size", default=0.1, type=float)
55 |
56 | parser.add_argument("--dropout-prob", default=0.3, type=float)
57 |
58 | # LEARNING
59 | parser.add_argument("-b", "--batch-size", default=32, type=int)
60 | parser.add_argument("-ib", "--inference-batch-size", default=32, type=int)
61 | parser.add_argument("--gamma", default=0.975, type=float)
62 | parser.add_argument("--grad-clip-val", default=10.0, type=float)
63 | parser.add_argument(
64 | "--lr-scheduler",
65 | default="reduce_on_plateau",
66 | choices=["reduce_on_plateau", "cosine_annealing", "cyclic"],
67 | )
68 | parser.add_argument(
69 | "--optimizer",
70 | default="adam",
71 | choices=["adam", "sgd"],
72 | )
73 | parser.add_argument("--lr", default=5e-4, type=float)
74 | parser.add_argument("--lr-min", default=5e-5, type=float)
75 | parser.add_argument("--lr-step-size", default=10000, type=int)
76 | parser.add_argument("--lr-frequency", default=5, type=int)
77 | parser.add_argument("--lr-patience", default=20, type=int)
78 | parser.add_argument("--lr-cooldown", default=5, type=int)
79 | parser.add_argument("--lr-factor", default=0.75, type=float)
80 |
81 | # MODEL
82 | parser.add_argument("--sdim", default=256, type=int)
83 | parser.add_argument("--vdim", default=64, type=int)
84 | parser.add_argument("--latent_dim", default=None, type=int)
85 | parser.add_argument("--rbf-dim", default=32, type=int)
86 | parser.add_argument("--edim", default=32, type=int)
87 | parser.add_argument("--edge-mp", default=False, action="store_true")
88 | parser.add_argument("--vector-aggr", default="mean", type=str)
89 | parser.add_argument("--num-layers", default=7, type=int)
90 | parser.add_argument("--fully-connected", default=True, action="store_true")
91 | parser.add_argument("--local-global-model", default=False, action="store_true")
92 | parser.add_argument("--local-edge-attrs", default=False, action="store_true")
93 | parser.add_argument("--use-cross-product", default=False, action="store_true")
94 | parser.add_argument("--cutoff-local", default=7.0, type=float)
95 | parser.add_argument("--cutoff-global", default=10.0, type=float)
96 | parser.add_argument("--energy-training", default=False, action="store_true")
97 | parser.add_argument("--property-training", default=False, action="store_true")
98 | parser.add_argument(
99 | "--regression-property",
100 | default="polarizability",
101 | type=str,
102 | choices=[
103 | "dipole_norm",
104 | "total_energy",
105 | "HOMO-LUMO_gap",
106 | "dispersion",
107 | "atomisation_energy",
108 | "polarizability",
109 | ],
110 | )
111 | parser.add_argument("--energy-loss", default="l2", type=str, choices=["l2", "l1"])
112 | parser.add_argument("--use-pos-norm", default=False, action="store_true")
113 |
114 | # For Discrete: Include more features: (is_aromatic, is_in_ring, hybridization)
115 | parser.add_argument("--additional-feats", default=False, action="store_true")
116 | parser.add_argument("--use-qm-props", default=False, action="store_true")
117 | parser.add_argument("--build-mol-with-addfeats", default=False, action="store_true")
118 |
119 | # DIFFUSION
120 | parser.add_argument(
121 | "--continuous",
122 | default=False,
123 | action="store_true",
124 | help="If the diffusion process is applied on continuous time variable. Defaults to False",
125 | )
126 | parser.add_argument(
127 | "--noise-scheduler",
128 | default="cosine",
129 | choices=["linear", "cosine", "quad", "sigmoid", "adaptive", "linear-time"],
130 | )
131 | parser.add_argument("--eps-min", default=1e-3, type=float)
132 | parser.add_argument("--beta-min", default=1e-4, type=float)
133 | parser.add_argument("--beta-max", default=2e-2, type=float)
134 | parser.add_argument("--timesteps", default=500, type=int)
135 | parser.add_argument("--max-time", type=str, default=None)
136 | parser.add_argument("--lc-coords", default=3.0, type=float)
137 | parser.add_argument("--lc-atoms", default=0.4, type=float)
138 | parser.add_argument("--lc-bonds", default=2.0, type=float)
139 | parser.add_argument("--lc-charges", default=1.0, type=float)
140 | parser.add_argument("--lc-mulliken", default=1.5, type=float)
141 | parser.add_argument("--lc-wbo", default=2.0, type=float)
142 |
143 | parser.add_argument("--pocket-noise-std", default=0.1, type=float)
144 | parser.add_argument(
145 | "--use-ligand-dataset-sizes", default=False, action="store_true"
146 | )
147 |
148 | parser.add_argument(
149 | "--loss-weighting",
150 | default="snr_t",
151 | choices=["snr_s_t", "snr_t", "exp_t", "expt_t_half", "uniform"],
152 | )
153 | parser.add_argument("--snr-clamp-min", default=0.05, type=float)
154 | parser.add_argument("--snr-clamp-max", default=1.50, type=float)
155 |
156 | parser.add_argument(
157 | "--ligand-pocket-interaction", default=False, action="store_true"
158 | )
159 | parser.add_argument("--diffusion-pretraining", default=False, action="store_true")
160 | parser.add_argument(
161 | "--continuous-param", default="data", type=str, choices=["data", "noise"]
162 | )
163 | parser.add_argument("--atoms-categorical", default=False, action="store_true")
164 | parser.add_argument("--bonds-categorical", default=False, action="store_true")
165 |
166 | parser.add_argument("--atom-type-masking", default=False, action="store_true")
167 | parser.add_argument("--use-absorbing-state", default=False, action="store_true")
168 |
169 | parser.add_argument("--num-bond-classes", default=5, type=int)
170 | parser.add_argument("--num-charge-classes", default=6, type=int)
171 |
172 | # BOND PREDICTION AND GUIDANCE:
173 | parser.add_argument("--bond-guidance-model", default=False, action="store_true")
174 | parser.add_argument("--bond-prediction", default=False, action="store_true")
175 | parser.add_argument("--bond-model-guidance", default=False, action="store_true")
176 | parser.add_argument("--energy-model-guidance", default=False, action="store_true")
177 | parser.add_argument(
178 | "--polarizabilty-model-guidance", default=False, action="store_true"
179 | )
180 | parser.add_argument("--ckpt-bond-model", default=None, type=str)
181 | parser.add_argument("--ckpt-energy-model", default=None, type=str)
182 | parser.add_argument("--ckpt-polarizabilty-model", default=None, type=str)
183 | parser.add_argument("--guidance-scale", default=1.0e-4, type=float)
184 |
185 | # CONTEXT
186 | parser.add_argument("--context-mapping", default=False, action="store_true")
187 | parser.add_argument("--num-context-features", default=0, type=int)
188 | parser.add_argument("--properties-list", default=[], nargs="+", type=str)
189 |
190 | # PROPERTY PREDICTION
191 | parser.add_argument("--property-prediction", default=False, action="store_true")
192 |
193 | # LATENT
194 | parser.add_argument("--prior-beta", default=1.0, type=float)
195 | parser.add_argument("--sdim-latent", default=256, type=int)
196 | parser.add_argument("--vdim-latent", default=64, type=int)
197 | parser.add_argument("--latent-dim", default=None, type=int)
198 | parser.add_argument("--edim-latent", default=32, type=int)
199 | parser.add_argument("--num-layers-latent", default=7, type=int)
200 | parser.add_argument("--latent-layers", default=7, type=int)
201 | parser.add_argument("--latentmodel", default="diffusion", type=str)
202 | parser.add_argument("--latent-detach", default=False, action="store_true")
203 |
204 | # GENERAL
205 | parser.add_argument("-i", "--id", type=int, default=0)
206 | parser.add_argument("-g", "--gpus", default=1, type=int)
207 | parser.add_argument("-e", "--num-epochs", default=300, type=int)
208 | parser.add_argument("--eval-freq", default=1.0, type=float)
209 | parser.add_argument("--test-interval", default=5, type=int)
210 | parser.add_argument("-nh", "--no_h", default=False, action="store_true")
211 | parser.add_argument("--precision", default=32, type=int)
212 | parser.add_argument("--detect-anomaly", default=False, action="store_true")
213 | parser.add_argument("--num-workers", default=4, type=int)
214 | parser.add_argument(
215 | "--max-num-conformers",
216 | default=5,
217 | type=int,
218 | help="Maximum number of conformers per molecule. \
219 | Defaults to 30. Set to -1 for all conformers available in database",
220 | )
221 | parser.add_argument("--accum-batch", default=1, type=int)
222 | parser.add_argument("--max-num-neighbors", default=128, type=int)
223 | parser.add_argument("--ema-decay", default=0.9999, type=float)
224 | parser.add_argument("--weight-decay", default=0.9999, type=float)
225 | parser.add_argument("--seed", default=42, type=int)
226 | parser.add_argument("--backprop-local", default=False, action="store_true")
227 |
228 | # SAMPLING
229 | parser.add_argument("--num-test-graphs", default=10000, type=int)
230 | parser.add_argument("--calculate-energy", default=False, action="store_true")
231 | parser.add_argument("--save-xyz", default=False, action="store_true")
232 |
233 | return parser
234 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/losses.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import Tensor, nn
6 | from torch_scatter import scatter_mean
7 |
8 |
class DiffusionLoss(nn.Module):
    """Combined diffusion training loss over several molecular modalities.

    Computes per-modality losses (coordinates or latents via MSE; atoms,
    charges and bonds via cross-entropy or MSE depending on ``param``;
    optional ring/aromatic/hybridization cross-entropies; optional
    mulliken/wbo regression targets) and either averages them over the batch
    or applies caller-supplied per-graph ``weights``.
    """

    def __init__(
        self,
        # NOTE(review): mutable list defaults are shared across instances;
        # they are never mutated here, so this is safe but worth confirming.
        modalities: List = ["coords", "atoms", "charges", "bonds"],
        param: List = ["data", "data", "data", "data"],
    ) -> None:
        """Configure which modalities are trained and their parameterization.

        Args:
            modalities: names of the loss terms to compute; must contain
                either "coords" or "latents" (the regression target).
            param: per-modality parameterization, aligned with ``modalities``;
                "data" selects cross-entropy for categorical modalities,
                anything else (e.g. "noise") selects MSE.
        """
        super().__init__()
        assert len(modalities) == len(param)
        self.modalities = modalities
        self.param = param

        # The continuous regression target is either raw coordinates or a
        # latent representation — exactly one must be present.
        if "coords" in modalities:
            self.regression_key = "coords"
        elif "latents" in modalities:
            self.regression_key = "latents"
        else:
            raise ValueError

    def loss_non_nans(self, loss: Tensor, modality: str) -> "tuple[Tensor, Tensor]":
        """Drop NaN entries from a per-graph loss vector.

        Args:
            loss: per-graph loss values, shape (batch_size,).
            modality: name used only for the diagnostic print.

        Returns:
            Tuple of (loss with NaN entries removed, boolean NaN mask of the
            original shape). The mask lets callers filter aligned tensors
            such as per-graph weights.
        """
        m = loss.isnan()
        if torch.any(m):
            print(f"Recovered NaNs in {modality}. Selecting NoN-Nans")
        return loss[~m], m

    def forward(
        self,
        true_data: Dict,
        pred_data: Dict,
        batch: Tensor,
        bond_aggregation_index: Tensor,
        weights: Optional[Tensor] = None,
    ) -> Dict:
        """Compute all configured loss terms.

        Args:
            true_data: ground-truth tensors keyed by modality.
            pred_data: model predictions keyed by modality, aligned with
                ``true_data``.
            batch: per-node graph assignment vector (as in PyG batching).
            bond_aggregation_index: per-edge node index used to aggregate
                edge-level losses (bonds, wbo) onto nodes first.
            weights: optional per-graph weights, shape (batch_size,). When
                given, each per-graph loss is weighted and summed; when None,
                plain mean reduction is used instead.

        Returns:
            Dict with one (possibly None) scalar loss tensor per modality.
        """
        batch_size = len(batch.unique())
        mulliken_loss = None
        wbo_loss = None

        if weights is not None:
            assert len(weights) == batch_size

            # Per-node squared error, averaged over feature dims, then
            # averaged per graph before weighting.
            regr_loss = F.mse_loss(
                pred_data[self.regression_key],
                true_data[self.regression_key],
                reduction="none",
            ).mean(-1)
            regr_loss = scatter_mean(regr_loss, index=batch, dim=0, dim_size=batch_size)
            regr_loss, m = self.loss_non_nans(regr_loss, self.regression_key)

            # Weighted sum over graphs; NaN graphs are excluded via mask m.
            regr_loss *= weights[~m]
            regr_loss = torch.sum(regr_loss, dim=0)

            # Optional per-node Mulliken charge regression (no NaN filtering).
            if "mulliken" in true_data:
                mulliken_loss = F.mse_loss(
                    pred_data["mulliken"],
                    true_data["mulliken"],
                    reduction="none",
                ).mean(-1)
                mulliken_loss = scatter_mean(
                    mulliken_loss, index=batch, dim=0, dim_size=batch_size
                )
                mulliken_loss *= weights
                mulliken_loss = torch.sum(mulliken_loss, dim=0)

            # Optional Wiberg bond order regression: edge losses are first
            # averaged onto nodes (factor 0.5 since each bond appears for
            # both endpoints), then per graph.
            if "wbo" in true_data:
                wbo_loss = F.mse_loss(
                    pred_data["wbo"],
                    true_data["wbo"],
                    reduction="none",
                ).mean(-1)
                wbo_loss = 0.5 * scatter_mean(
                    wbo_loss,
                    index=bond_aggregation_index,
                    dim=0,
                    dim_size=true_data["atoms"].size(0),
                )
                wbo_loss = scatter_mean(
                    wbo_loss, index=batch, dim=0, dim_size=batch_size
                )
                wbo_loss *= weights
                wbo_loss = torch.sum(wbo_loss, dim=0)

            # Atom types: cross-entropy for "data" parameterization, MSE
            # (with explicit feature-dim mean) otherwise.
            if self.param[self.modalities.index("atoms")] == "data":
                fnc = F.cross_entropy
                take_mean = False
            else:
                fnc = F.mse_loss
                take_mean = True

            atoms_loss = fnc(pred_data["atoms"], true_data["atoms"], reduction="none")
            if take_mean:
                atoms_loss = atoms_loss.mean(dim=1)
            atoms_loss = scatter_mean(
                atoms_loss, index=batch, dim=0, dim_size=batch_size
            )
            atoms_loss, m = self.loss_non_nans(atoms_loss, "atoms")
            atoms_loss *= weights[~m]
            atoms_loss = torch.sum(atoms_loss, dim=0)

            # Formal charges: same scheme as atoms.
            if self.param[self.modalities.index("charges")] == "data":
                fnc = F.cross_entropy
                take_mean = False
            else:
                fnc = F.mse_loss
                take_mean = True

            charges_loss = fnc(
                pred_data["charges"], true_data["charges"], reduction="none"
            )
            if take_mean:
                charges_loss = charges_loss.mean(dim=1)
            charges_loss = scatter_mean(
                charges_loss, index=batch, dim=0, dim_size=batch_size
            )
            charges_loss, m = self.loss_non_nans(charges_loss, "charges")
            charges_loss *= weights[~m]
            charges_loss = torch.sum(charges_loss, dim=0)

            # Bonds: edge-level loss aggregated onto nodes first (0.5 since
            # each edge contributes at both endpoints), then per graph.
            if self.param[self.modalities.index("bonds")] == "data":
                fnc = F.cross_entropy
                take_mean = False
            else:
                fnc = F.mse_loss
                take_mean = True

            bonds_loss = fnc(pred_data["bonds"], true_data["bonds"], reduction="none")
            if take_mean:
                bonds_loss = bonds_loss.mean(dim=1)
            bonds_loss = 0.5 * scatter_mean(
                bonds_loss,
                index=bond_aggregation_index,
                dim=0,
                dim_size=true_data["atoms"].size(0),
            )
            bonds_loss = scatter_mean(
                bonds_loss, index=batch, dim=0, dim_size=batch_size
            )
            bonds_loss, m = self.loss_non_nans(bonds_loss, "bonds")
            bonds_loss *= weights[~m]
            bonds_loss = bonds_loss.sum(dim=0)

            # Optional categorical node features — always cross-entropy.
            if "ring" in self.modalities:
                ring_loss = F.cross_entropy(
                    pred_data["ring"], true_data["ring"], reduction="none"
                )
                ring_loss = scatter_mean(
                    ring_loss, index=batch, dim=0, dim_size=batch_size
                )
                ring_loss, m = self.loss_non_nans(ring_loss, "ring")
                ring_loss *= weights[~m]
                ring_loss = torch.sum(ring_loss, dim=0)
            else:
                ring_loss = None

            if "aromatic" in self.modalities:
                aromatic_loss = F.cross_entropy(
                    pred_data["aromatic"], true_data["aromatic"], reduction="none"
                )
                aromatic_loss = scatter_mean(
                    aromatic_loss, index=batch, dim=0, dim_size=batch_size
                )
                aromatic_loss, m = self.loss_non_nans(aromatic_loss, "aromatic")
                aromatic_loss *= weights[~m]
                aromatic_loss = torch.sum(aromatic_loss, dim=0)
            else:
                aromatic_loss = None

            if "hybridization" in self.modalities:
                hybridization_loss = F.cross_entropy(
                    pred_data["hybridization"],
                    true_data["hybridization"],
                    reduction="none",
                )
                hybridization_loss = scatter_mean(
                    hybridization_loss, index=batch, dim=0, dim_size=batch_size
                )
                hybridization_loss, m = self.loss_non_nans(
                    hybridization_loss, "hybridization"
                )
                hybridization_loss *= weights[~m]
                hybridization_loss = torch.sum(hybridization_loss, dim=0)
            else:
                hybridization_loss = None

        else:
            # Unweighted path: plain mean reduction per modality.
            # NOTE(review): mulliken/wbo losses are not computed here even if
            # present in true_data — confirm this asymmetry is intended.
            regr_loss = F.mse_loss(
                pred_data[self.regression_key],
                true_data[self.regression_key],
                reduction="mean",
            ).mean(-1)
            if self.param[self.modalities.index("atoms")] == "data":
                fnc = F.cross_entropy
            else:
                fnc = F.mse_loss
            atoms_loss = fnc(pred_data["atoms"], true_data["atoms"], reduction="mean")
            if self.param[self.modalities.index("charges")] == "data":
                fnc = F.cross_entropy
            else:
                fnc = F.mse_loss
            charges_loss = fnc(
                pred_data["charges"], true_data["charges"], reduction="mean"
            )
            if self.param[self.modalities.index("bonds")] == "data":
                fnc = F.cross_entropy
            else:
                fnc = F.mse_loss
            bonds_loss = fnc(pred_data["bonds"], true_data["bonds"], reduction="mean")

            if "ring" in self.modalities:
                ring_loss = F.cross_entropy(
                    pred_data["ring"], true_data["ring"], reduction="mean"
                )
            else:
                ring_loss = None

            if "aromatic" in self.modalities:
                aromatic_loss = F.cross_entropy(
                    pred_data["aromatic"], true_data["aromatic"], reduction="mean"
                )
            else:
                aromatic_loss = None

            if "hybridization" in self.modalities:
                hybridization_loss = F.cross_entropy(
                    pred_data["hybridization"],
                    true_data["hybridization"],
                    reduction="mean",
                )
            else:
                hybridization_loss = None

        loss = {
            self.regression_key: regr_loss,
            "atoms": atoms_loss,
            "charges": charges_loss,
            "bonds": bonds_loss,
            "ring": ring_loss,
            "aromatic": aromatic_loss,
            "hybridization": hybridization_loss,
            "mulliken": mulliken_loss,
            "wbo": wbo_loss,
        }

        return loss
251 |
252 |
class EdgePredictionLoss(nn.Module):
    """Plain cross-entropy loss for edge (bond-type) classification.

    Thin ``nn.Module`` wrapper around ``F.cross_entropy`` with mean reduction
    so it composes with the other loss modules in this file.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()

    def forward(
        self,
        true_data: Tensor,
        pred_data: Tensor,
    ) -> Tensor:
        """Mean cross-entropy between edge logits and edge targets.

        Args:
            true_data: integer class targets (or class probabilities), as
                accepted by ``F.cross_entropy``'s ``target`` argument.
            pred_data: unnormalized logits, one row per edge.

        Returns:
            Scalar tensor with the mean cross-entropy.
        """
        # NOTE: argument order is (true, pred) while F.cross_entropy expects
        # (input=pred, target=true); kept as-is for caller compatibility.
        bonds_loss = F.cross_entropy(pred_data, true_data, reduction="mean")

        return bonds_loss
267 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/run_evaluation_ligand.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import warnings
4 |
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | import torch
8 |
9 | from experiments.data.distributions import DistributionProperty
10 |
11 | warnings.filterwarnings(
12 | "ignore", category=UserWarning, message="TypedStorage is deprecated"
13 | )
14 | warnings.filterwarnings(
15 | "ignore", category=UserWarning, message="TypedStorage is deprecated"
16 | )
17 |
18 |
class dotdict(dict):
    """dict subclass that exposes its keys as attributes (dot access)."""

    def __getattr__(self, key):
        # Mirrors dict.get: a missing key yields None rather than raising
        # AttributeError, matching the original attribute-access semantics.
        return self.get(key)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]
25 |
26 |
def evaluate(
    model_path,
    save_dir,
    save_xyz=True,
    calculate_energy=False,
    batch_size=2,
    use_ligand_dataset_sizes=False,
    build_obabel_mol=False,
    save_traj=False,
    use_energy_guidance=False,
    ckpt_energy_model=None,
    guidance_scale=1.0e-4,
    ddpm=True,
    eta_ddim=1.0,
):
    """Load a trained checkpoint and run ligand-pocket test-set sampling.

    Args:
        model_path: path to the Lightning checkpoint to evaluate.
        save_dir: directory for the patched checkpoint and test outputs.
        save_xyz: whether sampled molecules are also stored as xyz files.
        calculate_energy: whether to compute xTB energies/forces for samples.
        batch_size: sampling batch size written into the checkpoint hparams.
        use_ligand_dataset_sizes: sample with ground-truth ligand sizes.
        build_obabel_mol: build molecules with OpenBabel (no bond information).
        save_traj: whether to store the sampling trajectory.
        use_energy_guidance, ckpt_energy_model, guidance_scale, ddpm, eta_ddim:
            accepted for CLI compatibility but not used in this function —
            TODO(review): confirm whether they should be written into the
            checkpoint hyperparameters as well.
    """
    print("Loading from checkpoint; adapting hyperparameters to specified args")

    # Load the checkpoint and override the hyperparameters that control
    # test-time behavior before re-saving it for trainer.test().
    ckpt = torch.load(model_path)
    ckpt["hyper_parameters"]["load_ckpt"] = None
    ckpt["hyper_parameters"]["load_ckpt_from_pretrained"] = None
    ckpt["hyper_parameters"]["test_save_dir"] = save_dir
    ckpt["hyper_parameters"]["calculate_energy"] = calculate_energy
    ckpt["hyper_parameters"]["save_xyz"] = save_xyz
    ckpt["hyper_parameters"]["batch_size"] = batch_size
    ckpt["hyper_parameters"]["select_train_subset"] = False
    ckpt["hyper_parameters"]["diffusion_pretraining"] = False
    ckpt["hyper_parameters"]["gpus"] = 1
    ckpt["hyper_parameters"]["use_ligand_dataset_sizes"] = use_ligand_dataset_sizes
    ckpt["hyper_parameters"]["build_obabel_mol"] = build_obabel_mol
    ckpt["hyper_parameters"]["save_traj"] = save_traj
    ckpt["hyper_parameters"]["num_charge_classes"] = 6

    ckpt_path = os.path.join(save_dir, "test_model.ckpt")
    if not os.path.exists(ckpt_path):
        torch.save(ckpt, ckpt_path)

    # dotdict gives attribute-style access to the hyperparameter dict.
    hparams = ckpt["hyper_parameters"]
    hparams = dotdict(hparams)

    print(f"Loading {hparams.dataset} Datamodule.")
    dataset = "crossdocked"
    if hparams.use_adaptive_loader:
        print("Using adaptive dataloader")
        from experiments.data.ligand.ligand_dataset_adaptive import (
            LigandPocketDataModule as DataModule,
        )
    else:
        print("Using non-adaptive dataloader")
        from experiments.data.ligand.ligand_dataset_nonadaptive import (
            LigandPocketDataModule as DataModule,
        )
    datamodule = DataModule(hparams)

    from experiments.data.data_info import GeneralInfos as DataInfos

    dataset_info = DataInfos(datamodule, hparams)

    train_smiles = (
        list(datamodule.train_dataset.smiles)
        if hparams.dataset != "pubchem"
        else datamodule.train_smiles
    )
    prop_norm, prop_dist = None, None
    if len(hparams.properties_list) > 0 and hparams.context_mapping:
        prop_norm = datamodule.compute_mean_mad(hparams.properties_list)
        prop_dist = DistributionProperty(datamodule, hparams.properties_list)
        prop_dist.set_normalizer(prop_norm)

    # BUG FIX: `histogram` was previously assigned only on the discrete
    # crossdocked branch below but is always passed to Trainer(...), which
    # raised NameError on every other import path. Initialize it up front.
    histogram = None

    if hparams.continuous:
        print("Using continuous diffusion")
        from experiments.diffusion_continuous import Trainer
    else:
        print("Using discrete diffusion")
        if hparams.diffusion_pretraining:
            print("Starting pre-training")
            if hparams.additional_feats:
                from experiments.diffusion_pretrain_discrete_addfeats import (
                    Trainer,
                )
            else:
                from experiments.diffusion_pretrain_discrete import Trainer
        elif hparams.additional_feats:
            if dataset == "crossdocked":
                print("Ligand-pocket testing using additional features")
                from experiments.diffusion_discrete_moreFeats_ligand import Trainer
            else:
                print("Using additional features")
                from experiments.diffusion_discrete_moreFeats import Trainer
        else:
            if dataset == "crossdocked":
                print("Ligand-pocket testing")
                histogram = os.path.join(hparams.dataset_root, "size_distribution.npy")
                histogram = np.load(histogram).tolist()
                from experiments.diffusion_discrete_pocket import Trainer
            else:
                from experiments.diffusion_discrete import Trainer

    if build_obabel_mol:
        print(
            "Sampled molecules will be built with OpenBabel (without bond information)!"
        )

    trainer = pl.Trainer(
        accelerator="gpu" if hparams.gpus else "cpu",
        devices=1,
        strategy="auto",
        num_nodes=1,
        precision=hparams.precision,
    )
    pl.seed_everything(seed=hparams.seed, workers=hparams.gpus > 1)

    model = Trainer(
        hparams=hparams,
        dataset_info=dataset_info,
        smiles_list=train_smiles,
        prop_dist=prop_dist,
        prop_norm=prop_norm,
        histogram=histogram,
    )

    trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)
149 |
150 |
def get_args():
    """Build and parse the command-line options for evaluation/sampling."""
    # fmt: off
    parser = argparse.ArgumentParser(description="Data generation")
    parser.add_argument("--model-path", default="----", type=str,
                        help="Path to trained model")
    parser.add_argument("--use-energy-guidance", default=False, action="store_true")
    parser.add_argument("--use-ligand-dataset-sizes", default=False, action="store_true")
    parser.add_argument("--build-obabel-mol", default=False, action="store_true")
    parser.add_argument("--save-traj", default=False, action="store_true")
    parser.add_argument("--ckpt-energy-model", default=None, type=str)
    parser.add_argument("--guidance-scale", default=1.0e-4, type=float,
                        help="How to scale the guidance shift")
    parser.add_argument("--save-dir", default="----", type=str,
                        help="Path to test output")
    parser.add_argument("--save-xyz", default=False, action="store_true",
                        help="Whether or not to store generated molecules in xyz files")
    parser.add_argument("--calculate-energy", default=False, action="store_true",
                        help="Whether or not to calculate xTB energies and forces")
    parser.add_argument("--batch-size", default=80, type=int,
                        help="Batch-size to generate the selected ngraphs. Defaults to 80.")
    parser.add_argument("--ddim", default=False, action="store_true",
                        help="If DDIM sampling should be used. Defaults to False")
    parser.add_argument("--eta-ddim", default=1.0, type=float,
                        help="How to scale the std of noise in the reverse posterior. "
                             "Can also be used for DDPM to track a deterministic trajectory. "
                             "Defaults to 1.0")
    # fmt: on
    return parser.parse_args()
179 |
180 |
if __name__ == "__main__":
    # Parse CLI options and run test-set sampling for the given checkpoint.
    args = get_args()
    # Evaluate negative log-likelihood for the test partitions
    evaluate(
        model_path=args.model_path,
        save_dir=args.save_dir,
        batch_size=args.batch_size,
        ddpm=not args.ddim,  # --ddim selects DDIM sampling; default is DDPM
        eta_ddim=args.eta_ddim,
        save_xyz=args.save_xyz,
        save_traj=args.save_traj,
        calculate_energy=args.calculate_energy,
        use_energy_guidance=args.use_energy_guidance,
        use_ligand_dataset_sizes=args.use_ligand_dataset_sizes,
        build_obabel_mol=args.build_obabel_mol,
        ckpt_energy_model=args.ckpt_energy_model,
        guidance_scale=args.guidance_scale,
    )
199 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/run_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from argparse import ArgumentParser
4 |
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | import torch.nn.functional as F
8 | from pytorch_lightning.callbacks import (
9 | LearningRateMonitor,
10 | ModelCheckpoint,
11 | ModelSummary,
12 | TQDMProgressBar,
13 | )
14 | from pytorch_lightning.loggers import TensorBoardLogger
15 |
16 | from callbacks.ema import ExponentialMovingAverage
17 |
18 | warnings.filterwarnings(
19 | "ignore", category=UserWarning, message="TypedStorage is deprecated"
20 | )
21 |
22 | from experiments.data.distributions import DistributionProperty
23 | from experiments.hparams import add_arguments
24 |
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # CLI parsing and output-directory setup (one sub-folder per --id).
    # ------------------------------------------------------------------
    parser = ArgumentParser()
    parser = add_arguments(parser)
    hparams = parser.parse_args()

    if not os.path.exists(hparams.save_dir):
        os.makedirs(hparams.save_dir)

    if not os.path.isdir(hparams.save_dir + f"/run{hparams.id}/"):
        print("Creating directory")
        os.mkdir(hparams.save_dir + f"/run{hparams.id}/")
    print(f"Starting Run {hparams.id}")
    # EMA of model weights, top-3 checkpointing on validation loss, LR
    # logging and TensorBoard logging for this run.
    ema_callback = ExponentialMovingAverage(decay=hparams.ema_decay)
    checkpoint_callback = ModelCheckpoint(
        dirpath=hparams.save_dir + f"/run{hparams.id}/",
        save_top_k=3,
        monitor="val/loss",
        save_last=True,
    )
    lr_logger = LearningRateMonitor()
    tb_logger = TensorBoardLogger(
        hparams.save_dir + f"/run{hparams.id}/", default_hp_metric=False
    )

    # ------------------------------------------------------------------
    # Dataset selection: the DataModule is imported lazily per dataset so
    # only the required dependencies are loaded.
    # ------------------------------------------------------------------
    print(f"Loading {hparams.dataset} Datamodule.")
    non_adaptive = True
    if hparams.dataset == "drugs":
        dataset = "drugs"
        if hparams.use_adaptive_loader:
            print("Using adaptive dataloader")
            non_adaptive = False
            from experiments.data.geom.geom_dataset_adaptive import (
                GeomDataModule as DataModule,
            )
        else:
            print("Using non-adaptive dataloader")
            from experiments.data.geom.geom_dataset_nonadaptive import (
                GeomDataModule as DataModule,
            )
    elif hparams.dataset == "qm9":
        dataset = "qm9"
        from experiments.data.qm9.qm9_dataset import QM9DataModule as DataModule

    elif hparams.dataset == "pubchem":
        dataset = "pubchem"  # take dataset infos from GEOM for simplicity
        if hparams.use_adaptive_loader:
            print("Using adaptive dataloader")
            non_adaptive = False
            from experiments.data.pubchem.pubchem_dataset_adaptive import (
                PubChemDataModule as DataModule,
            )
        else:
            print("Using non-adaptive dataloader")
            from experiments.data.pubchem.pubchem_dataset_nonadaptive import (
                PubChemDataModule as DataModule,
            )
    elif hparams.dataset == "crossdocked":
        dataset = "crossdocked"
        if hparams.use_adaptive_loader:
            print("Using adaptive dataloader")
            from experiments.data.ligand.ligand_dataset_adaptive import (
                LigandPocketDataModule as DataModule,
            )
        else:
            print("Using non-adaptive dataloader")
            from experiments.data.ligand.ligand_dataset_nonadaptive import (
                LigandPocketDataModule as DataModule,
            )
    else:
        raise ValueError(f"Unknown dataset: {hparams.dataset}")

    datamodule = DataModule(hparams)

    from experiments.data.data_info import GeneralInfos as DataInfos

    dataset_info = DataInfos(datamodule, hparams)

    # PubChem has no precomputed SMILES list unless a train subset is used.
    train_smiles = (
        (
            list(datamodule.train_dataset.smiles)
            if hparams.dataset != "pubchem"
            else None
        )
        if not hparams.select_train_subset
        else datamodule.train_smiles
    )
    # Property normalization/distribution for context-conditioned training.
    prop_norm, prop_dist = None, None
    if (
        len(hparams.properties_list) > 0 and hparams.context_mapping
    ) or hparams.property_training:
        prop_norm = datamodule.compute_mean_mad(hparams.properties_list)
        prop_dist = DistributionProperty(datamodule, hparams.properties_list)
        prop_dist.set_normalizer(prop_norm)

    # Ligand-size histogram; only populated for crossdocked branches below.
    histogram = None

    # ------------------------------------------------------------------
    # Trainer (model) selection: continuous vs discrete diffusion,
    # pre-training vs regular, with/without additional features, and
    # ligand-pocket variants for crossdocked.
    # ------------------------------------------------------------------
    if hparams.continuous and dataset != "crossdocked":
        print("Using continuous diffusion")
        if hparams.diffusion_pretraining:
            print("Starting pre-training")
            from experiments.diffusion_pretrain_continuous import Trainer
        else:
            from experiments.diffusion_continuous import Trainer
    else:
        print("Using discrete diffusion")
        if hparams.diffusion_pretraining:
            if hparams.additional_feats:
                print("Starting pre-training on PubChem3D with additional features")
                from experiments.diffusion_pretrain_discrete_addfeats import (
                    Trainer,
                )
            else:
                print("Starting pre-training on PubChem3D")
                from experiments.diffusion_pretrain_discrete import Trainer
        elif (
            dataset == "crossdocked"
            and hparams.additional_feats
            and not hparams.use_qm_props
        ):
            histogram = os.path.join(hparams.dataset_root, "size_distribution.npy")
            histogram = np.load(histogram).tolist()
            print("Ligand-pocket training using additional features")
            from experiments.diffusion_discrete_pocket_addfeats import (
                Trainer,
            )
        else:
            if dataset == "crossdocked":
                histogram = os.path.join(hparams.dataset_root, "size_distribution.npy")
                histogram = np.load(histogram).tolist()
                if hparams.continuous:
                    print("Continuous ligand-pocket training")
                    from experiments.diffusion_continuous_pocket import Trainer
                else:
                    print("Discrete ligand-pocket training")
                    from experiments.diffusion_discrete_pocket import (
                        Trainer,
                    )
            else:
                if hparams.additional_feats:
                    from experiments.diffusion_discrete_addfeats import Trainer
                else:
                    from experiments.diffusion_discrete import Trainer

    model = Trainer(
        hparams=hparams.__dict__,
        dataset_info=dataset_info,
        smiles_list=train_smiles,
        histogram=histogram,
        prop_dist=prop_dist,
        prop_norm=prop_norm,
    )

    from pytorch_lightning.plugins.environments import LightningEnvironment

    strategy = "ddp" if hparams.gpus > 1 else "auto"
    # strategy = 'ddp_find_unused_parameters_true'
    callbacks = [
        ema_callback,
        lr_logger,
        checkpoint_callback,
        TQDMProgressBar(refresh_rate=5),
        ModelSummary(max_depth=2),
    ]

    # A decay of exactly 1.0 disables EMA by dropping its callback (first
    # element of the list above).
    if hparams.ema_decay == 1.0:
        callbacks = callbacks[1:]

    trainer = pl.Trainer(
        accelerator="gpu" if hparams.gpus else "cpu",
        devices=hparams.gpus if hparams.gpus else 1,
        strategy=strategy,
        plugins=LightningEnvironment(),
        num_nodes=1,
        logger=tb_logger,
        enable_checkpointing=True,
        accumulate_grad_batches=hparams.accum_batch,
        val_check_interval=hparams.eval_freq,
        gradient_clip_val=hparams.grad_clip_val,
        callbacks=callbacks,
        precision=hparams.precision,
        num_sanity_val_steps=2,
        max_epochs=hparams.num_epochs,
        detect_anomaly=hparams.detect_anomaly,
    )

    pl.seed_everything(seed=hparams.seed, workers=hparams.gpus > 1)

    # Optionally resume from a checkpoint; when the requested LR differs from
    # the stored optimizer state, patch the LR and save a modified copy.
    ckpt_path = None
    if hparams.load_ckpt is not None:
        print("Loading from checkpoint ...")
        import torch

        ckpt_path = hparams.load_ckpt
        ckpt = torch.load(ckpt_path)
        if ckpt["optimizer_states"][0]["param_groups"][0]["lr"] != hparams.lr:
            print("Changing learning rate ...")
            ckpt["optimizer_states"][0]["param_groups"][0]["lr"] = hparams.lr
            ckpt["optimizer_states"][0]["param_groups"][0]["initial_lr"] = hparams.lr
            # NOTE(review): this assignment is dead — it is overwritten by the
            # os.path.join(...) assignment immediately below.
            ckpt_path = (
                "lr" + "_" + str(hparams.lr) + "_" + os.path.basename(hparams.load_ckpt)
            )
            ckpt_path = os.path.join(
                os.path.dirname(hparams.load_ckpt),
                f"retraining_with_lr{hparams.lr}.ckpt",
            )
            if not os.path.exists(ckpt_path):
                torch.save(ckpt, ckpt_path)
    trainer.fit(
        model=model,
        datamodule=datamodule,
        ckpt_path=ckpt_path if hparams.load_ckpt is not None else None,
    )
237 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/sampling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/sampling/__init__.py
--------------------------------------------------------------------------------
/eqgat_diff/experiments/sampling/fpscores.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/eqgat_diff/experiments/sampling/fpscores.pkl.gz
--------------------------------------------------------------------------------
/eqgat_diff/experiments/xtb_energy.py:
--------------------------------------------------------------------------------
1 | from ase import Atoms
2 | from xtb.ase.calculator import XTB
3 | import argparse
4 | import numpy as np
5 |
6 |
def get_args():
    """Parse CLI options for the single-point xTB energy script."""
    # fmt: off
    parser = argparse.ArgumentParser(description="Data generation")
    parser.add_argument("--output-path", default=None, type=str,
                        help="If set, saves the energy to a text file to the given path.")
    # NOTE(review): argparse's type=dict does not parse dict literals from a
    # command-line string — confirm how this argument is meant to be supplied.
    parser.add_argument("--data", type=dict,
                        help="Input data as a dictionary containing atom types and coordinates")
    # fmt: on
    return parser.parse_args()
15 |
16 |
def calculate_xtb_energy(positions, atom_types):
    """Single-point GFN2-xTB calculation for one conformer.

    Args:
        positions: Cartesian coordinates of the atoms (Angstrom assumed by
            ASE — TODO confirm at the call site).
        atom_types: chemical symbols aligned with ``positions``.

    Returns:
        Tuple of (potential energy, Frobenius norm of the force array).
    """
    system = Atoms(positions=positions, symbols=atom_types)
    system.calc = XTB(method="GFN2-xTB")
    energy = system.get_potential_energy()
    force_norm = np.linalg.norm(system.get_forces())
    return energy, force_norm
25 |
26 |
if __name__ == "__main__":
    args = get_args()

    # BUG FIX: calculate_xtb_energy requires (positions, atom_types); the
    # original call passed a single dict and then tried to write the returned
    # (energy, force_norm) tuple to the file, raising TypeError.
    # NOTE(review): assumes args.data carries "positions" and "atom_types"
    # keys, per the --data help text — TODO confirm against the caller.
    pot_e, forces_norm = calculate_xtb_energy(
        args.data["positions"], args.data["atom_types"]
    )
    if args.output_path is not None:
        # Context manager guarantees the file is closed even on write errors.
        with open(args.output_path, "a") as f:
            f.write(f"{pot_e} {forces_norm}\n")
35 |
--------------------------------------------------------------------------------
/eqgat_diff/experiments/xtb_relaxation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import ase.units
3 | import ase
4 | from ase import Atoms
5 | from ase.io import write, read
6 | import ase.units as units
7 | import logging
8 | import os
9 | import subprocess
10 | from rdkit import rdBase
11 | import argparse
12 | import pickle
13 | import shutil
14 | from tqdm import tqdm
15 |
16 | rdBase.DisableLog("rdApp.error")
17 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
18 |
19 |
def make_dir(path):
    """Create `path` (including parents) if it does not exist; idempotent.

    Prints whether the directory was newly created or already present.

    Args:
        path: directory path to create.
    """
    existed = os.path.exists(path)
    # exist_ok avoids the TOCTOU race between the existence check and the
    # mkdir call (another process could create the directory in between).
    os.makedirs(path, exist_ok=True)
    if not existed:
        print(f"\nDirectory {path} has been created!")
    else:
        print(f"\nDirectory {path} exists already!")
27 |
28 |
def parse_xtb_xyz(filename):
    """Parse an xTB-produced .xyz file into atom count, energy and positions.

    Expects the xtbopt.xyz layout: line 0 = atom count, line 1 = comment line
    whose third space-separated token is the total energy in Hartree, and the
    remaining lines = "<symbol> <x> <y> <z>".

    Returns:
        Dict with keys "num_atoms" (int), "energy" (float32 scalar array,
        converted to eV) and "pos" (float32 array of shape (num_atoms, 3)).
    """
    num_atoms = 0
    energy = 0
    pos = []
    with open(filename, "r") as f:
        for line_num, line in enumerate(f):
            if line_num == 0:
                num_atoms = int(line)
            elif line_num == 1:
                # xTB outputs energy in Hartrees: Hartree to eV
                # NOTE(review): split(" ")[2] relies on the exact spacing of
                # the comment line — confirm it holds across xTB versions.
                energy = np.array(
                    float(line.split(" ")[2]) * units.Hartree, dtype=np.float32
                )
            elif line_num >= 2:
                _, x, y, z = line.split()
                # parse_float handles Fortran-style "*^" exponent notation.
                pos.append([parse_float(x), parse_float(y), parse_float(z)])

    result = {
        "num_atoms": num_atoms,
        "energy": energy,
        "pos": np.array(pos, dtype=np.float32),
    }
    return result
52 |
53 |
def parse_float(s: str) -> float:
    """Parse a float, supporting Fortran-style "mantissa*^exponent" notation.

    Raises:
        ValueError: if the string is neither a plain float nor of the
            "<mantissa>*^<exponent>" form.
    """
    if "*^" not in s:
        return float(s)
    mantissa, exponent = s.split("*^")
    return float(mantissa) * 10 ** float(exponent)
60 |
61 |
def xtb_optimization(data, file_path):
    """Geometry-optimize each RDKit molecule in `data` with the xtb binary.

    For every molecule, writes its conformer to a temporary xyz file, runs
    ``xtb --opt``, parses the optimized geometry/energy back, and pickles the
    collected energies and failed indices into `file_path`.

    Args:
        data: iterable of RDKit molecules, each with a 3D conformer.
        file_path: working directory; a temporary "xtb_tmp" subfolder is
            created per molecule and the result pickles are written here.

    Returns:
        Tuple (zs, positions, energies, failed_mol_ids) for the successfully
        optimized molecules.
    """
    zs = []
    positions = []
    energies = []
    failed_mol_ids = []

    for i, mol in tqdm(enumerate(data)):
        xtb_temp_dir = os.path.join(file_path, "xtb_tmp")
        make_dir(xtb_temp_dir)

        # Looser convergence for large molecules to keep runtime bounded.
        mol_size = mol.GetNumAtoms()
        opt = "lax" if mol_size > 60 else "normal"

        pos = np.array(mol.GetConformer().GetPositions(), dtype=np.float32)
        atomic_number = []
        for atom in mol.GetAtoms():
            atomic_number.append(atom.GetAtomicNum())
        z = np.array(atomic_number, dtype=np.int64)
        # Rebinds `mol` from an RDKit molecule to an ASE Atoms object.
        mol = Atoms(numbers=z, positions=pos)

        mol_path = os.path.join(xtb_temp_dir, f"xtb_conformer.xyz")
        write(mol_path, images=mol)

        # xtb writes its output files into the current working directory.
        os.chdir(xtb_temp_dir)
        try:
            subprocess.call(
                # ["xtb", mol_path, "--opt", opt, "--cycles", "2000", "--gbsa", "water"],
                ["xtb", mol_path, "--opt", opt],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
            )
        except:
            # NOTE(review): subprocess.call does not raise on a non-zero exit
            # status, so xtb convergence failures are NOT caught here — they
            # surface later when xtbopt.xyz is missing. Only OS-level errors
            # (e.g. xtb not installed) land in this branch.
            print(f"Molecule with id {i} failed!")
            failed_mol_ids.append(i)
            continue
        result = parse_xtb_xyz(os.path.join(xtb_temp_dir, f"xtbopt.xyz"))
        atom = read(filename=os.path.join(xtb_temp_dir, f"xtbopt.xyz"))

        # Return to the base directory and remove the per-molecule temp dir.
        os.chdir(file_path)
        shutil.rmtree(xtb_temp_dir)

        z = atom.get_atomic_numbers()
        energy = result["energy"]
        pos = atom.get_positions()

        zs.append(z)
        energies.append(energy)
        positions.append(pos)

    # Persist results so they survive any downstream failure.
    with open(os.path.join(file_path, "energies.pickle"), "wb") as f:
        pickle.dump(energies, f)
    with open(os.path.join(file_path, "failed_mol_ids.pickle"), "wb") as f:
        pickle.dump(failed_mol_ids, f)

    return zs, positions, energies, failed_mol_ids
117 |
118 |
def get_args():
    """Parse CLI options for the xTB relaxation script."""
    # fmt: off
    parser = argparse.ArgumentParser(description="Data generation")
    parser.add_argument("--output-path", default=None, type=str,
                        help="If set, saves the energy to a text file to the given path.")
    # NOTE(review): type=list splits a CLI string into characters; confirm
    # how RDKit molecules are actually supplied to this script.
    parser.add_argument("--data", type=list,
                        help="Input data as a list of RDkit molecules")
    # fmt: on
    return parser.parse_args()
127 |
128 |
if __name__ == "__main__":
    # Entry point: parse the CLI options and run the xTB optimization
    # pipeline on the provided molecules, writing results to output_path.
    cli_args = get_args()
    xtb_optimization(cli_args.data, cli_args.output_path)
133 |
--------------------------------------------------------------------------------
/eqgat_diff/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = e3moldiffusion
3 | author = Tuan Le
4 | author_email = tuanle@hotmail.de
5 | description = A research module for 3D molecular generation using diffusion methods
6 | long_description = file: README.md
7 | long_description_content_type = text/markdown
8 | classifiers =
9 | Development Status :: 3 - Alpha
10 | Intended Audience :: Science/Research
11 | License :: OSI Approved :: MIT
12 | Programming Language :: Python
13 | Programming Language :: Python :: 3 :: Only
14 | Programming Language :: Python :: 3.9
15 | Topic :: Scientific/Engineering
16 | Topic :: Utilities
17 |
18 | [options]
19 | packages = find:
20 | python_requires = >=3.9.0
21 | setup_requires =
22 | setuptools_scm
23 | install_requires =
24 |
25 | [mypy]
26 | python_version = 3.9
27 | ignore_missing_imports = True
--------------------------------------------------------------------------------
/eqgat_diff/setup.py:
--------------------------------------------------------------------------------
# Packaging entry point. All static metadata lives in setup.cfg; this file
# only wires up setuptools_scm so the version is derived from git tags.
from setuptools import setup

setup(
    use_scm_version={
        # Version bumps are computed as post-releases after the latest tag.
        "version_scheme": "post-release",
        # Write the resolved version into the package so it is importable
        # at runtime as e3moldiffusion._version.
        "write_to": "e3moldiffusion/_version.py",
    },
    setup_requires=["setuptools_scm"],
)
--------------------------------------------------------------------------------
/inference/README.md:
--------------------------------------------------------------------------------
1 | ## Inference
2 |
3 | We provide an exemplary jupyter notebook as well as a bash script to generate molecules for QM9 and Geom-Drugs.
4 | Notice that the model weights for each dataset are not provided.
5 | As of now, we share the model weights upon request.
6 |
7 | To run the bash script, execute the following:
8 |
9 | ```bash
10 | mamba activate eqgatdiff
11 | bash run_eval_geom.sh
12 | bash run_eval_qm9.sh
13 | ```
14 |
15 | In each bash script, the paths to dataset-root, model-path and save-dir have to be included.
16 | E.g. in case of geom-drugs:
17 |
18 | ````
19 | export PYTHONPATH="YOUR_ABSOLUTE_PATH/eqgat-diff/eqgat_diff"
20 | dataset_root="YOUR_ABSOLUTE_PATH/eqgat-diff/data/geom"
21 | model_path="YOUR_ABSOLUTE_PATH/eqgat-diff/weights/geom/best_mol_stab.ckpt"
22 | save_dir="YOUR_ABSOLUTE_PATH/eqgat-diff/inference/tmp/geom/gen_samples"
23 |
24 | python YOUR_ABSOLUTE_PATH/eqgat_diff/experiments/run_evaluation.py \
25 | --dataset drugs \
26 | --dataset-root $dataset_root \
27 | --model-path $model_path \
28 | --save-dir $save_dir \
29 | --batch-size 50 \
30 | --ngraphs 100 \
31 | ````
32 |
33 | Depending on your GPU, you can increase the batch-size to have more molecules in a generated batch.
--------------------------------------------------------------------------------
/inference/run_eval_geom.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Run unconditional sampling/evaluation for the GEOM-Drugs model.
# Replace YOUR_ABSOLUTE_PATH with the repository checkout location.

export PYTHONPATH="YOUR_ABSOLUTE_PATH/eqgat-diff/eqgat_diff"
dataset_root="YOUR_ABSOLUTE_PATH/eqgat-diff/data/geom"
model_path="YOUR_ABSOLUTE_PATH/eqgat-diff/weights/geom/best_mol_stab.ckpt"
save_dir="YOUR_ABSOLUTE_PATH/eqgat-diff/inference/tmp/geom/gen_samples"

# Quote all expansions so paths containing spaces do not word-split; the
# original also left a dangling trailing backslash after the last argument.
python YOUR_ABSOLUTE_PATH/eqgat_diff/experiments/run_evaluation.py \
    --dataset drugs \
    --dataset-root "$dataset_root" \
    --model-path "$model_path" \
    --save-dir "$save_dir" \
    --batch-size 50 \
    --ngraphs 100
--------------------------------------------------------------------------------
/inference/run_eval_qm9.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Run unconditional sampling/evaluation for the QM9 model.
# Replace YOUR_ABSOLUTE_PATH with the repository checkout location.

export PYTHONPATH="YOUR_ABSOLUTE_PATH/eqgat-diff/eqgat_diff"
dataset_root="YOUR_ABSOLUTE_PATH/eqgat-diff/data/qm9"
model_path="YOUR_ABSOLUTE_PATH/eqgat-diff/weights/qm9/best_mol_stab.ckpt"
save_dir="YOUR_ABSOLUTE_PATH/eqgat-diff/inference/tmp/qm9/gen_samples"

# Quote all expansions so paths containing spaces do not word-split; the
# original also left a dangling trailing backslash after the last argument.
python YOUR_ABSOLUTE_PATH/eqgat_diff/experiments/run_evaluation.py \
    --dataset qm9 \
    --dataset-root "$dataset_root" \
    --model-path "$model_path" \
    --save-dir "$save_dir" \
    --batch-size 50 \
    --ngraphs 100
--------------------------------------------------------------------------------
/inference/tmp/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/inference/tmp/.gitkeep
--------------------------------------------------------------------------------
/inference/tmp/geom/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/inference/tmp/geom/.gitkeep
--------------------------------------------------------------------------------
/weights/README.md:
--------------------------------------------------------------------------------
1 | ## Model Weights
2 |
3 | The model weights for `qm9` and `geom` should be stored in the respective folders.
4 | Currently, we provide trained model weights only upon request.
5 |
6 | Please reach out to tuan.le@pfizer.com or julian.cremer@pfizer.com if you are interested.
7 |
--------------------------------------------------------------------------------
/weights/geom/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/weights/geom/.gitkeep
--------------------------------------------------------------------------------
/weights/qm9/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanle618/eqgat-diff/68aea80691a8ba82e00816c82875347cbda2c2e5/weights/qm9/.gitkeep
--------------------------------------------------------------------------------