├── src └── shepherd │ ├── model │ ├── utils │ │ ├── __init__.py │ │ ├── positional_encoding.py │ │ └── add_virtual_edges_to_edge_index.py │ ├── equiformer_v2 │ │ ├── __init__.py │ │ ├── datasets │ │ │ └── __init__.py │ │ ├── ocpmodels │ │ │ ├── common │ │ │ │ ├── relaxation │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── optimizers │ │ │ │ │ │ └── __init__.py │ │ │ │ │ └── ml_relaxation.py │ │ │ │ ├── __init__.py │ │ │ │ ├── hpo_utils.py │ │ │ │ ├── transforms.py │ │ │ │ └── logger.py │ │ │ ├── modules │ │ │ │ ├── scaling │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── util.py │ │ │ │ │ └── compat.py │ │ │ │ ├── __init__.py │ │ │ │ ├── normalizer.py │ │ │ │ ├── scheduler.py │ │ │ │ └── loss.py │ │ │ ├── models │ │ │ │ ├── scn │ │ │ │ │ ├── Jd.pt │ │ │ │ │ ├── README.md │ │ │ │ │ ├── sampling.py │ │ │ │ │ └── smearing.py │ │ │ │ ├── escn │ │ │ │ │ └── Jd.pt │ │ │ │ ├── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── activations.py │ │ │ │ ├── __init__.py │ │ │ │ ├── gemnet_gp │ │ │ │ │ ├── README.md │ │ │ │ │ ├── initializers.py │ │ │ │ │ └── layers │ │ │ │ │ │ ├── embedding_block.py │ │ │ │ │ │ ├── base_layers.py │ │ │ │ │ │ └── spherical_basis.py │ │ │ │ ├── gemnet │ │ │ │ │ ├── initializers.py │ │ │ │ │ └── layers │ │ │ │ │ │ ├── embedding_block.py │ │ │ │ │ │ ├── base_layers.py │ │ │ │ │ │ └── spherical_basis.py │ │ │ │ ├── painn │ │ │ │ │ └── README.md │ │ │ │ ├── gemnet_oc │ │ │ │ │ ├── initializers.py │ │ │ │ │ ├── layers │ │ │ │ │ │ ├── embedding_block.py │ │ │ │ │ │ ├── base_layers.py │ │ │ │ │ │ ├── force_scaler.py │ │ │ │ │ │ └── spherical_basis.py │ │ │ │ │ └── README.md │ │ │ │ └── base.py │ │ │ ├── __init__.py │ │ │ ├── preprocessing │ │ │ │ └── __init__.py │ │ │ ├── tasks │ │ │ │ ├── __init__.py │ │ │ │ └── task.py │ │ │ ├── datasets │ │ │ │ ├── embeddings │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── atomic_radii.py │ │ │ │ └── __init__.py │ │ │ └── trainers │ │ │ │ └── __init__.py │ │ ├── nets │ │ │ ├── __init__.py │ │ │ └── equiformer_v2 │ │ │ │ ├── Jd.pt │ │ │ │ ├── module_list.py │ │ │ │ ├── radial_function.py │ │ │ │ ├── wigner.py │ │ │ │ ├── gaussian_rbf.py │ │ │ │ └── edge_rot_mat.py │ │ ├── fig │ │ │ ├── equiformer_v2.png │ │ │ ├── equiformer_v2_oc20_results.png │ │ │ └── equiformer_v2_adsorbml_results.png │ │ ├── oc20 │ │ │ ├── trainer │ │ │ │ ├── task_compute_stats.py │ │ │ │ ├── task_relaxation.py │ │ │ │ ├── __init__.py │ │ │ │ ├── logger.py │ │ │ │ ├── make_lmdb_sizes.py │ │ │ │ └── dist_setup.py │ │ │ └── configs │ │ │ │ └── s2ef │ │ │ │ ├── 2M │ │ │ │ ├── base.yml │ │ │ │ └── equiformer_v2 │ │ │ │ │ ├── equiformer_v2_N@12_L@6_M@2.yml │ │ │ │ │ └── equiformer_v2_N@12_L@6_M@2_epochs@30.yml │ │ │ │ └── all_md │ │ │ │ └── equiformer_v2 │ │ │ │ ├── equiformer_v2_N@8_L@4_M@2_31M.yml │ │ │ │ └── equiformer_v2_N@20_L@6_M@3_153M.yml │ │ ├── docs │ │ │ ├── changelog.md │ │ │ └── env_setup.md │ │ ├── scripts │ │ │ └── train │ │ │ │ └── oc20 │ │ │ │ └── s2ef │ │ │ │ └── equiformer_v2 │ │ │ │ ├── equiformer_v2_N@12_L@6_M@2_splits@2M_g@8.sh │ │ │ │ ├── equiformer_v2_N@12_L@6_M@2_splits@2M_g@multi-nodes.sh │ │ │ │ ├── equiformer_v2_N@8_L@4_M@2_splits@all+md_g@multi-nodes.sh │ │ │ │ └── equiformer_v2_N@20_L@6_M@3_splits@all+md_g@multi-nodes.sh │ │ ├── LICENSE │ │ ├── utils.py │ │ └── logger.py │ └── __init__.py │ ├── shepherd_score_utils │ ├── pharm_utils │ │ └── __init__.py │ └── __init__.py │ ├── inference │ └── __init__.py │ ├── __init__.py │ └── model_loader.py ├── examples ├── paper_experiments │ ├── jobs │ │ └── job_logs │ │ │ └── README_empty.md │ ├── samples │ │ ├── README_empty.md │ │ ├── 
NP_analogues │ │ │ └── README_empty.md │ │ ├── GDB_conditional │ │ │ ├── x2 │ │ │ │ └── README_empty.md │ │ │ ├── x3 │ │ │ │ └── README_empty.md │ │ │ └── x4 │ │ │ │ └── README_empty.md │ │ ├── PDB_analogues_pose │ │ │ └── README_empty.md │ │ ├── GDB_unconditional │ │ │ ├── x1x2 │ │ │ │ └── README_empty.md │ │ │ ├── x1x3 │ │ │ │ └── README_empty.md │ │ │ └── x1x4 │ │ │ │ └── README_empty.md │ │ ├── PDB_analogues_lowestenergy │ │ │ └── README_empty.md │ │ ├── fragment_merging_samples │ │ │ └── README_empty.md │ │ ├── PDB_analogues_bestdocked_pose │ │ │ └── README_empty.md │ │ ├── mosesaq_unconditional │ │ │ └── x1x3x4 │ │ │ │ └── README_empty.md │ │ └── PDB_analogues_bestdocked_lowestenergy │ │ │ └── README_empty.md │ ├── run_inference_gdb_unconditional_x1x2.py │ └── run_inference_gdb_unconditional_x1x3.py └── data │ └── WX7.sdf ├── data ├── conformers │ ├── np │ │ └── molblock_charges_NPs.pkl │ ├── gdb │ │ ├── example_molblock_charges.pkl │ │ └── molblock_charges_9_test100.pkl │ ├── pdb │ │ ├── molblock_charges_pdb_pose.pkl │ │ ├── molblock_charges_bestdocked_pose.pkl │ │ ├── molblock_charges_pdb_lowestenergy.pkl │ │ ├── molblock_charges_bestdocked_lowestenergy.pkl │ │ └── README.md │ ├── distributions │ │ └── atom_pharm_count.npz │ ├── moses_aq │ │ └── example_molblock_charges.pkl │ └── fragment_merging │ │ ├── fragment_merge_condition.pickle │ │ └── fragments │ │ ├── mol_6.mol │ │ ├── mol_1.mol │ │ ├── mol_3.mol │ │ ├── mol_5.mol │ │ ├── mol_0.mol │ │ ├── mol_9.mol │ │ ├── mol_7.mol │ │ ├── mol_2.mol │ │ ├── mol_8.mol │ │ ├── mol_10.mol │ │ ├── mol_12.mol │ │ ├── mol_4.mol │ │ └── mol_11.mol └── shepherd_chkpts │ └── README.md ├── docker └── original_publication_env │ └── Dockerfile ├── app └── README.md ├── LICENSE ├── pyproject.toml ├── CHANGELOG.md └── .gitignore /src/shepherd/model/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /examples/paper_experiments/jobs/job_logs/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/common/relaxation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/NP_analogues/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/common/relaxation/optimizers/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/GDB_conditional/x2/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/GDB_conditional/x3/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/GDB_conditional/x4/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/PDB_analogues_pose/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/GDB_unconditional/x1x2/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/GDB_unconditional/x1x3/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/GDB_unconditional/x1x4/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/PDB_analogues_lowestenergy/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/fragment_merging_samples/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/PDB_analogues_bestdocked_pose/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/mosesaq_unconditional/x1x3x4/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/paper_experiments/samples/PDB_analogues_bestdocked_lowestenergy/README_empty.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/nets/__init__.py: -------------------------------------------------------------------------------- 1 | # from .equiformer_v2.equiformer_v2_oc20 import EquiformerV2_OC20 -------------------------------------------------------------------------------- 
/data/conformers/np/molblock_charges_NPs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/np/molblock_charges_NPs.pkl -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/modules/scaling/__init__.py: -------------------------------------------------------------------------------- 1 | from .scale_factor import ScaleFactor 2 | 3 | __all__ = ["ScaleFactor"] 4 | -------------------------------------------------------------------------------- /data/conformers/gdb/example_molblock_charges.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/gdb/example_molblock_charges.pkl -------------------------------------------------------------------------------- /data/conformers/pdb/molblock_charges_pdb_pose.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/pdb/molblock_charges_pdb_pose.pkl -------------------------------------------------------------------------------- /data/conformers/distributions/atom_pharm_count.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/distributions/atom_pharm_count.npz -------------------------------------------------------------------------------- /data/conformers/gdb/molblock_charges_9_test100.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/gdb/molblock_charges_9_test100.pkl -------------------------------------------------------------------------------- /data/conformers/moses_aq/example_molblock_charges.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/moses_aq/example_molblock_charges.pkl -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/fig/equiformer_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/src/shepherd/model/equiformer_v2/fig/equiformer_v2.png -------------------------------------------------------------------------------- /data/conformers/pdb/molblock_charges_bestdocked_pose.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/pdb/molblock_charges_bestdocked_pose.pkl -------------------------------------------------------------------------------- /data/conformers/pdb/molblock_charges_pdb_lowestenergy.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/pdb/molblock_charges_pdb_lowestenergy.pkl -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/nets/equiformer_v2/Jd.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/src/shepherd/model/equiformer_v2/nets/equiformer_v2/Jd.pt 
-------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/scn/Jd.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/src/shepherd/model/equiformer_v2/ocpmodels/models/scn/Jd.pt -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/escn/Jd.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/src/shepherd/model/equiformer_v2/ocpmodels/models/escn/Jd.pt -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragment_merge_condition.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/fragment_merging/fragment_merge_condition.pickle -------------------------------------------------------------------------------- /data/conformers/pdb/molblock_charges_bestdocked_lowestenergy.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/data/conformers/pdb/molblock_charges_bestdocked_lowestenergy.pkl -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/fig/equiformer_v2_oc20_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/src/shepherd/model/equiformer_v2/fig/equiformer_v2_oc20_results.png -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/fig/equiformer_v2_adsorbml_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleygroup/shepherd/HEAD/src/shepherd/model/equiformer_v2/fig/equiformer_v2_adsorbml_results.png -------------------------------------------------------------------------------- /src/shepherd/shepherd_score_utils/pharm_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from shepherd.shepherd_score_utils.pharm_utils.pharmacophore import get_pharmacophores 2 | 3 | __all__ = [ 4 | "get_pharmacophores", 5 | ] -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/common/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /src/shepherd/model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ShEPhERD model components. 3 | """ 4 | 5 | from shepherd.model.model import Model 6 | from shepherd.model.equiformer_v2_encoder import EquiformerV2 7 | 8 | __all__ = [ 9 | "Model", 10 | "EquiformerV2", 11 | ] 12 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/trainer/task_compute_stats.py: -------------------------------------------------------------------------------- 1 | from ocpmodels.tasks.task import BaseTask 2 | from ocpmodels.common.registry import registry 3 | 4 | 5 | @registry.register_task("compute_stats") 6 | class ComputeStatsTask(BaseTask): 7 | def run(self): 8 | self.trainer.compute_stats() -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from .atoms_to_graphs import AtomsToGraphs 9 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/nets/equiformer_v2/module_list.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ModuleListInfo(torch.nn.ModuleList): 5 | def __init__(self, info_str, modules=None): 6 | super().__init__(modules) 7 | self.info_str = str(info_str) 8 | 9 | 10 | def __repr__(self): 11 | return self.info_str -------------------------------------------------------------------------------- /src/shepherd/inference/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shepherd Inference submodule. 3 | 4 | Provides the main inference sampling function. 
5 | """ 6 | 7 | from shepherd.inference.inference_original import inference_sample 8 | from shepherd.inference.sampler import generate, generate_from_intermediate_time 9 | 10 | __all__ = ['inference_sample', 'generate', 'generate_from_intermediate_time'] -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/docs/changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | 4 | ## July 28 2023 5 | - EquiformerV2 has been incorporated into OCP repository and used in Open Catalyst demo. 6 | 7 | 8 | ## June 28 2023 9 | - [Commit](https://github.com/atomicarchitects/equiformer_v2/commit/8fe8cbaf8f3c27865b6e28c21db7867e75a107f7) 10 | - Rename `v2s` to `v2`. 11 | - Remove unused parts in `nets`. 12 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | __all__ = ["TrainTask", "PredictTask", "ValidateTask", "RelxationTask"] 7 | 8 | from .task import PredictTask, RelxationTask, TrainTask, ValidateTask 9 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/datasets/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "ATOMIC_RADII", 3 | "KHOT_EMBEDDINGS", 4 | "CONTINUOUS_EMBEDDINGS", 5 | "QMOF_KHOT_EMBEDDINGS", 6 | ] 7 | 8 | from .atomic_radii import ATOMIC_RADII 9 | from .continuous_embeddings import CONTINUOUS_EMBEDDINGS 10 | from .khot_embeddings import KHOT_EMBEDDINGS 11 | from .qmof_khot_embeddings import QMOF_KHOT_EMBEDDINGS 12 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .lmdb_dataset import ( 7 | LmdbDataset, 8 | SinglePointLmdbDataset, 9 | TrajectoryLmdbDataset, 10 | data_list_collater, 11 | ) 12 | from .oc22_lmdb_dataset import OC22LmdbDataset 13 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | __all__ = [ 7 | "BaseTrainer", 8 | "ForcesTrainer", 9 | "EnergyTrainer", 10 | ] 11 | 12 | from .base_trainer import BaseTrainer 13 | from .energy_trainer import EnergyTrainer 14 | from .forces_trainer import ForcesTrainer 15 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/scripts/train/oc20/s2ef/equiformer_v2/equiformer_v2_N@12_L@6_M@2_splits@2M_g@8.sh: -------------------------------------------------------------------------------- 1 | python -u -m torch.distributed.launch --nproc_per_node=8 main_oc20.py \ 2 | --distributed \ 3 | --num-gpus 8 \ 4 | --mode train \ 5 | --config-yml 'oc20/configs/s2ef/2M/equiformer_v2/equiformer_v2_N@12_L@6_M@2.yml' \ 6 | --run-dir 'models/oc20/s2ef/2M/equiformer_v2/N@12_L@6_M@2/bs@32_lr@2e-4_wd@1e-3_epochs@12_warmup-epochs@0.1_g@8' \ 7 | --print-every 200 \ 8 | --amp 9 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/trainer/task_relaxation.py: -------------------------------------------------------------------------------- 1 | from ocpmodels.tasks.task import BaseTask 2 | from ocpmodels.common.registry import registry 3 | 4 | 5 | @registry.register_task("run-relaxations") 6 | class MyRelaxationTask(BaseTask): 7 | def run(self): 8 | assert ( 9 | self.trainer.relax_dataset is not None 10 | ), "Relax dataset is required for making predictions" 11 | assert self.config["checkpoint"] 12 | self.trainer.run_relaxations() -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/scripts/train/oc20/s2ef/equiformer_v2/equiformer_v2_N@12_L@6_M@2_splits@2M_g@multi-nodes.sh: -------------------------------------------------------------------------------- 1 | python main_oc20.py \ 2 | --distributed \ 3 | --num-gpus 8 \ 4 | --num-nodes 2 \ 5 | --mode train \ 6 | --config-yml 'oc20/configs/s2ef/2M/equiformer_v2/equiformer_v2_N@12_L@6_M@2.yml' \ 7 | --run-dir 'models/oc20/s2ef/2M/equiformer_v2/N@12_L@6_M@2/bs@64_lr@2e-4_wd@1e-3_epochs@12_warmup-epochs@0.1_g@8x2' \ 8 | --print-every 200 \ 9 | --amp \ 10 | --submit 11 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/scripts/train/oc20/s2ef/equiformer_v2/equiformer_v2_N@8_L@4_M@2_splits@all+md_g@multi-nodes.sh: -------------------------------------------------------------------------------- 1 | python main_oc20.py \ 2 | --distributed \ 3 | --num-gpus 8 \ 4 | --num-nodes 8 \ 5 | --mode train \ 6 | --config-yml 'oc20/configs/s2ef/all_md/equiformer_v2/equiformer_v2_N@8_L@4_M@2_31M.yml' \ 7 | --run-dir 'models/oc20/s2ef/all_md/equiformer_v2/N@8_L@4_M@2_31M/bs@512_lr@4e-4_wd@1e-3_epochs@3_warmup-epochs@0.01_g@8x8' \ 8 | --print-every 200 \ 9 | --amp \ 10 | --submit 11 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/scripts/train/oc20/s2ef/equiformer_v2/equiformer_v2_N@20_L@6_M@3_splits@all+md_g@multi-nodes.sh: -------------------------------------------------------------------------------- 1 | python main_oc20.py \ 2 | --distributed \ 3 | --num-gpus 8 \ 4 | --num-nodes 16 \ 5 | --mode train \ 6 | --config-yml 'oc20/configs/s2ef/all_md/equiformer_v2/equiformer_v2_N@20_L@6_M@3_153M.yml' \ 7 | --run-dir 'models/oc20/s2ef/all_md/equiformer_v2/N@20_L@6_M@3_153M/bs@512_lr@4e-4_wd@1e-3_epochs@1_warmup-epochs@0.01_g@8x16' \ 8 | --print-every 200 \ 9 | --amp \ 10 | --submit 11 | 
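# Note on naming (inferred from the config filenames and run-dir strings above, not
# documented in these scripts themselves): N@20_L@6_M@3 denotes 20 Transformer blocks
# with maximum spherical-harmonic degree 6 and order 3, and 153M is the approximate
# parameter count; g@8x16 means 8 GPUs per node across 16 nodes, while bs@, lr@, and
# wd@ record batch size, learning rate, and weight decay. The --submit flag hands the
# multi-node job to the cluster scheduler rather than launching processes locally.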
-------------------------------------------------------------------------------- /src/shepherd/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ShEPhERD: Diffusing Shape, Electrostatics, and Pharmacophores for Drug Design. 3 | 4 | A generative diffusion model (DDPM) framework. 5 | """ 6 | 7 | from importlib.metadata import PackageNotFoundError, version 8 | 9 | try: # noqa: SIM105 10 | __version__ = version("shepherd") 11 | except PackageNotFoundError: 12 | pass 13 | 14 | from .model_loader import load_model, get_model_info, clear_model_cache 15 | 16 | __all__ = [ 17 | "load_model", 18 | "get_model_info", 19 | "clear_model_cache", 20 | ] 21 | -------------------------------------------------------------------------------- /src/shepherd/shepherd_score_utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for ShEPhERD score calculation. 3 | """ 4 | 5 | from shepherd.shepherd_score_utils.generate_point_cloud import ( 6 | get_atom_coords, 7 | get_atomic_vdw_radii, 8 | get_molecular_surface, 9 | get_electrostatics, 10 | get_electrostatics_given_point_charges, 11 | ) 12 | from shepherd.shepherd_score_utils.conformer_generation import update_mol_coordinates 13 | 14 | __all__ = [ 15 | "get_atom_coords", 16 | "get_atomic_vdw_radii", 17 | "get_molecular_surface", 18 | "get_electrostatics", 19 | "get_electrostatics_given_point_charges", 20 | "update_mol_coordinates", 21 | ] -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/configs/s2ef/2M/base.yml: -------------------------------------------------------------------------------- 1 | trainer: forces 2 | 3 | dataset: 4 | - src: data/s2ef/2M/train/ 5 | normalize_labels: True 6 | target_mean: -0.7554450631141663 7 | target_std: 2.887317180633545 8 | grad_target_mean: 0.0 9 | grad_target_std: 2.887317180633545 10 | - src: data/s2ef/all/val_id/ 11 | 12 | logger: tensorboard 13 | 14 | task: 15 | dataset: trajectory_lmdb 16 | description: "Regressing to energies and forces for DFT trajectories from OCP" 17 | type: regression 18 | metric: mae 19 | labels: 20 | - potential energy 21 | grad_input: atomic forces 22 | train_on_free_atoms: True 23 | eval_on_free_atoms: True 24 | -------------------------------------------------------------------------------- /src/shepherd/model/utils/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | def positional_encoding(position, dim, device): 5 | if len(position.shape) == 1: 6 | position = position[None,:].T 7 | assert len(position.shape) == 2 8 | assert position.shape[1] == 1 9 | assert dim % 2 == 0 10 | 11 | # position has shape (B, 1) 12 | # dim is scalar 13 | # returns position embeddings of shape (B, dim) 14 | 15 | pe = torch.zeros(position.shape[0], dim, device = device) 16 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float, device = device) * -(math.log(10000.0) / dim))) 17 | pe[:, 0::2] = torch.sin(position * div_term) 18 | pe[:, 1::2] = torch.cos(position * div_term) 19 | return pe -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .base import BaseModel 7 | #from .cgcnn import CGCNN 8 | #from .dimenet import DimeNetWrap as DimeNet 9 | #from .dimenet_plus_plus import DimeNetPlusPlusWrap as DimeNetPlusPlus 10 | #from .forcenet import ForceNet 11 | #from .gemnet.gemnet import GemNetT 12 | #from .gemnet_gp.gemnet import GraphParallelGemNetT as GraphParallelGemNetT 13 | #from .gemnet_oc.gemnet_oc import GemNetOC 14 | #from .painn.painn import PaiNN 15 | #from .schnet import SchNetWrap as SchNet 16 | #from .scn.scn import SphericalChannelNetwork 17 | #from .spinconv import spinconv 18 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | __all__ = [ 7 | # "BaseTrainerV2", # not currently imported (see commented import below); excluded so star-imports do not break 8 | "ForcesTrainerV2", 9 | "EnergyTrainerV2", 10 | "ComputeStatsTask", 11 | "MyRelaxationTask", 12 | "LmdbDatasetV2" 13 | ] 14 | 15 | #from .base_trainer import BaseTrainerV2 16 | #from .energy_trainer import EnergyTrainerV2 17 | #from .forces_trainer import ForcesTrainerV2 18 | 19 | from .energy_trainer_v2 import EnergyTrainerV2 20 | from .forces_trainer_v2 import ForcesTrainerV2 21 | from .task_compute_stats import ComputeStatsTask 22 | from .task_relaxation import MyRelaxationTask 23 | 24 | from .lmdb_dataset import LmdbDatasetV2 25 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/modules/scaling/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from .scale_factor import ScaleFactor 6 | 7 | 8 | def ensure_fitted(module: nn.Module, warn: bool = False): 9 | for name, child in module.named_modules(): 10 | if not isinstance(child, ScaleFactor) or child.fitted: 11 | continue 12 | if child.name is not None: 13 | name = f"{child.name} ({name})" 14 | msg = ( 15 | f"Scale factor {name} is not fitted. " 16 | "Please make sure that you either (1) load a checkpoint with fitted scale factors, " 17 | "(2) explicitly load scale factors using the `model.scale_file` attribute, or " 18 | "(3) fit the scale factors using the `fit.py` script. "
19 | ) 20 | if warn: 21 | logging.warning(msg) 22 | else: 23 | raise ValueError(msg) 24 | -------------------------------------------------------------------------------- /docker/original_publication_env/Dockerfile: -------------------------------------------------------------------------------- 1 | # Use CUDA runtime with cuDNN for PyTorch 1.12.1 + CUDA 11.3 2 | FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu20.04 3 | 4 | ENV DEBIAN_FRONTEND=noninteractive \ 5 | TZ=Etc/UTC \ 6 | CONDA_DIR=/opt/conda \ 7 | PATH=/opt/conda/bin:$PATH 8 | 9 | # Install minimal dependencies & Miniconda 10 | RUN apt-get update && apt-get install -y --no-install-recommends \ 11 | wget bzip2 ca-certificates \ 12 | && rm -rf /var/lib/apt/lists/* \ 13 | && wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/conda.sh \ 14 | && bash /tmp/conda.sh -b -p $CONDA_DIR \ 15 | && rm /tmp/conda.sh \ 16 | && conda clean -afy 17 | 18 | # Create shepherd env 19 | COPY shepherd_env.yml /workspace/shepherd_env.yml 20 | WORKDIR /workspace 21 | RUN conda env create -f shepherd_env.yml \ 22 | && conda clean -afy 23 | 24 | # Activate shepherd env 25 | SHELL ["conda", "run", "-n", "shepherd", "/bin/bash", "-c"] 26 | 27 | CMD ["bash"] -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_6.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1537_0A 2 | RDKit 3D 3 | 4 | 10 10 0 0 0 0 0 0 0 0999 V2000 5 | -8.2600 -4.8770 -7.2440 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -9.5880 -2.8150 -7.6680 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -8.4300 -3.7320 -7.8620 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | -7.1220 -5.4200 -7.7580 N 0 0 0 0 0 0 0 0 0 0 0 0 9 | -6.6020 -4.6430 -8.6520 N 0 0 0 0 0 0 0 0 0 0 0 0 10 | -7.4070 -3.5690 -8.7290 N 0 0 0 0 0 0 0 0 0 0 0 0 11 | -9.7671 -2.6732 -6.5920 H 0 0 0 0 0 0 0 0 0 0 0 0 12 | -9.3684 -1.8433 -8.1344 H 0 0 0 0 0 0 0 0 0 0 0 0 13 | -10.4835 -3.2515 -8.1344 H 0 0 0 0 0 0 0 0 0 0 0 0 14 | -7.2670 -2.7567 -9.3466 H 0 0 0 0 0 0 0 0 0 0 0 0 15 | 3 2 1 0 16 | 3 1 2 0 17 | 4 1 1 0 18 | 5 4 2 0 19 | 6 5 1 0 20 | 6 3 1 0 21 | 2 7 1 0 22 | 2 8 1 0 23 | 2 9 1 0 24 | 6 10 1 0 25 | M END 26 | -------------------------------------------------------------------------------- /app/README.md: -------------------------------------------------------------------------------- 1 | # ShEPhERD Streamlit App 2 | 3 | Interactive web interface for **ShEPhERD** (Shape, Electrostatics, and Pharmacophores for Enhanced Rational Design) - a diffusion model for bioisosteric drug design. 
4 | 5 | ## Features 6 | - **Molecule Input**: SMILES, MolBlock, XYZ, or test molecules from MOSES dataset 7 | - **Conditional Generation**: Generate molecules matching reference shape, electrostatics, and/or pharmacophores 8 | - **Multiple Models**: MOSES-aq (full conditioning) and GDB variants (shape/electrostatics only) 9 | - **Interactive Visualization**: 2D/3D molecular structures with interaction profiles 10 | - **Atom Inpainting**: Advanced control for modifying specific molecular regions to enable scaffold decoration 11 | - **Evaluation & Export**: Similarity scoring and SDF/XYZ download of generated molecules 12 | 13 | ### Installation 14 | ``` 15 | uv pip install streamlit stmol "shepherd-score>=1.1.3" py3Dmol seaborn ipython_genutils 16 | ``` 17 | NOTE: requires shepherd-score >= 1.1.3 for visualizations 18 | 19 | ### How to use 20 | ``` 21 | streamlit run app.py 22 | ``` -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/nets/equiformer_v2/radial_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RadialFunction(nn.Module): 6 | ''' 7 | Construct a radial function (linear layers + layer normalization + SiLU) given a list of channels 8 | ''' 9 | def __init__(self, channels_list): 10 | super().__init__() 11 | modules = [] 12 | input_channels = channels_list[0] 13 | for i in range(len(channels_list)): 14 | if i == 0: 15 | continue 16 | 17 | modules.append(nn.Linear(input_channels, channels_list[i], bias=True)) 18 | input_channels = channels_list[i] 19 | 20 | if i == len(channels_list) - 1: 21 | break 22 | 23 | modules.append(nn.LayerNorm(channels_list[i])) 24 | modules.append(torch.nn.SiLU()) 25 | 26 | self.net = nn.Sequential(*modules) 27 | 28 | 29 | def forward(self, inputs): 30 | return self.net(inputs) 31 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/scn/README.md: -------------------------------------------------------------------------------- 1 | # Spherical Channels for Modeling Atomic Interactions 2 | 3 | C. Lawrence Zitnick, Abhishek Das, Adeesh Kolluru, Janice Lan, Muhammed Shuaibi, Anuroop Sriram, Zachary Ulissi, Brandon Wood 4 | 5 | [[`arXiv:2206.14331`](https://arxiv.org/abs/2206.14331)] 6 | 7 | To run the Spherical Channel Network (SCN), install [e3nn](https://github.com/e3nn/e3nn/) with `pip install e3nn==0.2.6`. 8 | 9 | SCN was developed with e3nn v0.2.6, and might run slower with later versions [[1](https://github.com/Open-Catalyst-Project/ocp/issues/397), [2](https://github.com/Open-Catalyst-Project/ocp/pull/402)]. 10 | 11 | ## Citing 12 | 13 | If you use SCN in your work, please consider citing: 14 | 15 | ```bibtex 16 | @inproceedings{zitnick_scn_2022, 17 | title = {{Spherical Channels for Modeling Atomic Interactions}}, 18 | author = {Zitnick, C. 
Lawrence and Das, Abhishek and Kolluru, Adeesh and Lan, Janice and Shuaibi, Muhammed and Sriram, Anuroop and Ulissi, Zachary and Wood, Brandon}, 19 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 20 | year = {2022}, 21 | } 22 | ``` 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Coley Research Group @ MIT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yi-Lun Liao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_1.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1498_0A 2 | RDKit 3D 3 | 4 | 12 12 0 0 0 0 0 0 0 0999 V2000 5 | -8.6550 -3.7860 -8.1420 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -7.5860 -4.5140 -8.5660 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -6.8210 -3.9810 -9.4850 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -7.2610 -5.7690 -8.0470 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -8.0670 -6.2770 -7.0430 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -9.1560 -5.5540 -6.5950 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -9.4090 -4.3230 -7.1730 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -6.0329 -4.4767 -9.8216 H 0 0 0 0 0 0 0 0 0 0 0 0 13 | -6.3951 -6.3338 -8.4228 H 0 0 0 0 0 0 0 0 0 0 0 0 14 | -7.8402 -7.2583 -6.6008 H 0 0 0 0 0 0 0 0 0 0 0 0 15 | -9.8053 -5.9474 -5.7989 H 0 0 0 0 0 0 0 0 0 0 0 0 16 | -10.2791 -3.7521 -6.8168 H 0 0 0 0 0 0 0 0 0 0 0 0 17 | 2 1 2 0 18 | 3 2 1 0 19 | 4 2 1 0 20 | 5 4 2 0 21 | 6 5 1 0 22 | 7 6 2 0 23 | 7 1 1 0 24 | 3 8 1 0 25 | 4 9 1 0 26 | 5 10 1 0 27 | 6 11 1 0 28 | 7 12 1 0 29 | M END 30 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_3.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x0147_0A 2 | RDKit 3D 3 | 4 | 13 14 0 0 0 0 0 0 0 0999 V2000 5 | -6.6970 -4.8840 -8.8560 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -7.3500 -3.7580 -8.8970 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -8.4850 -5.1170 -7.6100 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | -9.5900 -5.4640 -6.7880 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -10.5010 -3.3770 -7.0450 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -9.5040 -2.9460 -7.8410 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -7.4060 -5.7610 -8.0350 N 0 0 0 0 0 0 0 0 0 0 0 0 12 | -10.5710 -4.6390 -6.5110 N 0 0 0 0 0 0 0 0 0 0 0 0 13 | -8.4770 -3.8300 -8.1380 N 0 0 0 0 0 0 0 0 0 0 0 0 14 | -7.0342 -2.8714 -9.4664 H 0 0 0 0 0 0 0 0 0 0 0 0 15 | -9.6248 -6.4774 -6.3616 H 0 0 0 0 0 0 0 0 0 0 0 0 16 | -11.3114 -2.6720 -6.8076 H 0 0 0 0 0 0 0 0 0 0 0 0 17 | -9.5040 -1.9225 -8.2441 H 0 0 0 0 0 0 0 0 0 0 0 0 18 | 2 1 2 0 19 | 4 3 1 0 20 | 6 5 2 0 21 | 7 3 2 0 22 | 7 1 1 0 23 | 8 4 2 0 24 | 8 5 1 0 25 | 9 6 1 0 26 | 9 3 1 0 27 | 9 2 1 0 28 | 2 10 1 0 29 | 4 11 1 0 30 | 5 12 1 0 31 | 6 13 1 0 32 | M END 33 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/docs/env_setup.md: -------------------------------------------------------------------------------- 1 | # Environment Setup 2 | 3 | 4 | - We use conda to install required packages: 5 | ``` 6 | conda env create -f env/env_equiformer_v2.yml 7 | ``` 8 | 9 | - This will create a new environment called `equiformer_v2`. 10 | 11 | - We activate the environment: 12 | ``` 13 | export PYTHONNOUSERSITE=True # prevent using packages from base 14 | conda activate equiformer_v2 15 | ``` 16 | 17 | - In addition, [`env/env_equiformer_v2.yml`](../env/env_equiformer_v2.yml) specifies the versions of all packages. 18 | 19 | - After setting up the environment, clone the OC20 GitHub repository: 20 | ``` 21 | git clone https://github.com/Open-Catalyst-Project/ocp 22 | cd ocp 23 | git checkout 5a7738f 24 | ``` 25 | 26 | - The corresponding version of the OC20 GitHub repository is [here](https://github.com/Open-Catalyst-Project/ocp/tree/5a7738f9aa80b1a9a7e0ca15e33938b4d2557edd).
27 | 28 | - We need to modify `ocp/ocpmodels/common/utils.py` and add the following two lines after [Line 329](https://github.com/Open-Catalyst-Project/ocp/blob/5a7738f9aa80b1a9a7e0ca15e33938b4d2557edd/ocpmodels/common/utils.py#L329) as shown below: 29 | ```diff 30 | finally: 31 | + import nets 32 | + import oc20.trainer 33 | registry.register("imports_setup", True) 34 | ``` 35 | 36 | - Finally, we install `ocpmodels` by running: 37 | ``` 38 | # After activating the environment and under ocp/ 39 | pip install -e . 40 | ``` 41 | 42 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_gp/README.md: -------------------------------------------------------------------------------- 1 | # Towards Training Billion Parameter Graph Neural Networks for Atomic Simulations 2 | 3 | Anuroop Sriram, Abhishek Das, Brandon M. Wood, Siddharth Goyal, C. Lawrence Zitnick 4 | 5 | [[`arXiv:2203.09697`](https://arxiv.org/abs/2203.09697)] 6 | 7 | 8 | To use graph parallel training, add `--gp-gpus N` to your command line, where N = number of GPUs to split the model over. This flag works for all tasks (`train`, `predict`, `validate` & `run-relaxations`). 9 | 10 | As an example, the Gemnet-XL model can be trained using: 11 | ```bash 12 | python main.py --mode train --config-yml configs/s2ef/all/gp_gemnet/gp-gemnet-xl.yml \ 13 | --distributed --num-nodes 32 --num-gpus 8 --gp-gpus 4 14 | ``` 15 | This trains the model on 256 GPUs (32 nodes x 8 GPUs each) with 4-way graph parallelism (i.e. the graph is distributed over 4 GPUs) and 64-way data parallelism (64 == 256 / 4). 16 | 17 | The Gemnet-XL model was trained without AMP as it led to unstable training. 18 | 19 | ## Citing 20 | 21 | If you use Graph Parallelism in your work, please consider citing: 22 | 23 | ```bibtex 24 | @inproceedings{sriram_graphparallel_2022, 25 | title={{Towards Training Billion Parameter Graph Neural Networks for Atomic Simulations}}, 26 | author={Sriram, Anuroop and Das, Abhishek and Wood, Brandon M. and Goyal, Siddharth and Zitnick, C. Lawrence}, 27 | booktitle={International Conference on Learning Representations (ICLR)}, 28 | year={2022} 29 | } 30 | ``` 31 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/modules/normalizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import torch 9 | 10 | 11 | class Normalizer(object): 12 | """Normalize a Tensor and restore it later.""" 13 | 14 | def __init__(self, tensor=None, mean=None, std=None, device=None): 15 | """tensor is taken as a sample to calculate the mean and std""" 16 | if tensor is None and mean is None: 17 | return 18 | 19 | if device is None: 20 | device = "cpu" 21 | 22 | if tensor is not None: 23 | self.mean = torch.mean(tensor, dim=0).to(device) 24 | self.std = torch.std(tensor, dim=0).to(device) 25 | return 26 | 27 | if mean is not None and std is not None: 28 | self.mean = torch.tensor(mean).to(device) 29 | self.std = torch.tensor(std).to(device) 30 | 31 | def to(self, device): 32 | self.mean = self.mean.to(device) 33 | self.std = self.std.to(device) 34 | 35 | def norm(self, tensor): 36 | return (tensor - self.mean) / self.std 37 | 38 | def denorm(self, normed_tensor): 39 | return normed_tensor * self.std + self.mean 40 | 41 | def state_dict(self): 42 | return {"mean": self.mean, "std": self.std} 43 | 44 | def load_state_dict(self, state_dict): 45 | self.mean = state_dict["mean"].to(self.mean.device) 46 | self.std = state_dict["std"].to(self.mean.device) 47 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_5.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1594_0A 2 | RDKit 3D 3 | 4 | 16 17 0 0 0 0 0 0 0 0999 V2000 5 | -7.1290 -6.7810 -7.5670 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -7.4260 -5.5390 -7.9640 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -8.4810 -3.6240 -8.1510 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | -9.5180 -2.6910 -7.9110 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -10.5390 -3.0250 -7.0760 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -10.5780 -4.2850 -6.4400 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -9.5860 -5.1790 -6.6460 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -6.7530 -4.7950 -8.8300 N 0 0 0 0 0 0 0 0 0 0 0 0 13 | -7.4220 -3.5740 -8.9490 N 0 0 0 0 0 0 0 0 0 0 0 0 14 | -8.5440 -4.8510 -7.4960 N 0 0 0 0 0 0 0 0 0 0 0 0 15 | -7.7318 -7.2612 -6.8837 H 0 0 0 0 0 0 0 0 0 0 0 0 16 | -6.2975 -7.2573 -7.9445 H 0 0 0 0 0 0 0 0 0 0 0 0 17 | -9.4986 -1.7037 -8.3957 H 0 0 0 0 0 0 0 0 0 0 0 0 18 | -11.3473 -2.3016 -6.8934 H 0 0 0 0 0 0 0 0 0 0 0 0 19 | -11.4172 -4.5392 -5.7758 H 0 0 0 0 0 0 0 0 0 0 0 0 20 | -9.6063 -6.1573 -6.1434 H 0 0 0 0 0 0 0 0 0 0 0 0 21 | 2 1 1 0 22 | 4 3 1 0 23 | 5 4 2 0 24 | 6 5 1 0 25 | 7 6 2 0 26 | 8 2 2 0 27 | 9 8 1 0 28 | 9 3 2 0 29 | 10 2 1 0 30 | 10 7 1 0 31 | 10 3 1 0 32 | 1 11 1 0 33 | 1 12 1 0 34 | 4 13 1 0 35 | 5 14 1 0 36 | 6 15 1 0 37 | 7 16 1 0 38 | M END 39 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet/initializers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import torch 9 | 10 | 11 | def _standardize(kernel): 12 | """ 13 | Makes sure that N*Var(W) = 1 and E[W] = 0 14 | """ 15 | eps = 1e-6 16 | 17 | if len(kernel.shape) == 3: 18 | axis = [0, 1] # last dimension is output dimension 19 | else: 20 | axis = 1 21 | 22 | var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True) 23 | kernel = (kernel - mean) / (var + eps) ** 0.5 24 | return kernel 25 | 26 | 27 | def he_orthogonal_init(tensor): 28 | """ 29 | Generate a weight matrix with variance according to He (Kaiming) initialization. 30 | Based on a random (semi-)orthogonal matrix neural networks 31 | are expected to learn better when features are decorrelated 32 | (stated by eg. "Reducing overfitting in deep networks by decorrelating representations", 33 | "Dropout: a simple way to prevent neural networks from overfitting", 34 | "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks") 35 | """ 36 | tensor = torch.nn.init.orthogonal_(tensor) 37 | 38 | if len(tensor.shape) == 3: 39 | fan_in = tensor.shape[:-1].numel() 40 | else: 41 | fan_in = tensor.shape[1] 42 | 43 | with torch.no_grad(): 44 | tensor.data = _standardize(tensor.data) 45 | tensor.data *= (1 / fan_in) ** 0.5 46 | 47 | return tensor 48 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_gp/initializers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | 10 | 11 | def _standardize(kernel): 12 | """ 13 | Makes sure that N*Var(W) = 1 and E[W] = 0 14 | """ 15 | eps = 1e-6 16 | 17 | if len(kernel.shape) == 3: 18 | axis = [0, 1] # last dimension is output dimension 19 | else: 20 | axis = 1 21 | 22 | var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True) 23 | kernel = (kernel - mean) / (var + eps) ** 0.5 24 | return kernel 25 | 26 | 27 | def he_orthogonal_init(tensor): 28 | """ 29 | Generate a weight matrix with variance according to He (Kaiming) initialization. 30 | Based on a random (semi-)orthogonal matrix neural networks 31 | are expected to learn better when features are decorrelated 32 | (stated by eg. 
"Reducing overfitting in deep networks by decorrelating representations", 33 | "Dropout: a simple way to prevent neural networks from overfitting", 34 | "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks") 35 | """ 36 | tensor = torch.nn.init.orthogonal_(tensor) 37 | 38 | if len(tensor.shape) == 3: 39 | fan_in = tensor.shape[:-1].numel() 40 | else: 41 | fan_in = tensor.shape[1] 42 | 43 | with torch.no_grad(): 44 | tensor.data = _standardize(tensor.data) 45 | tensor.data *= (1 / fan_in) ** 0.5 46 | 47 | return tensor 48 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/nets/equiformer_v2/wigner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | # Borrowed from e3nn @ 0.4.0: 6 | # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10 7 | # _Jd is a list of tensors of shape (2l+1, 2l+1) 8 | _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=True) 9 | 10 | 11 | # Borrowed from e3nn @ 0.4.0: 12 | # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L37 13 | # 14 | # In 0.5.0, e3nn shifted to torch.matrix_exp which is significantly slower: 15 | # https://github.com/e3nn/e3nn/blob/0.5.0/e3nn/o3/_wigner.py#L92 16 | def wigner_D(l, alpha, beta, gamma): 17 | if not l < len(_Jd): 18 | raise NotImplementedError( 19 | f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more" 20 | ) 21 | 22 | alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma) 23 | J = _Jd[l].to(dtype=alpha.dtype, device=alpha.device) 24 | Xa = _z_rot_mat(alpha, l) 25 | Xb = _z_rot_mat(beta, l) 26 | Xc = _z_rot_mat(gamma, l) 27 | return Xa @ J @ Xb @ J @ Xc 28 | 29 | 30 | def _z_rot_mat(angle, l): 31 | shape, device, dtype = angle.shape, angle.device, angle.dtype 32 | M = angle.new_zeros((*shape, 2 * l + 1, 2 * l + 1)) 33 | inds = torch.arange(0, 2 * l + 1, 1, device=device) 34 | reversed_inds = torch.arange(2 * l, -1, -1, device=device) 35 | frequencies = torch.arange(l, -l - 1, -1, dtype=dtype, device=device) 36 | M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None]) 37 | M[..., inds, inds] = torch.cos(frequencies * angle[..., None]) 38 | return M -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/scn/sampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | import math 8 | 9 | import torch 10 | 11 | ### Methods for sample points on a sphere 12 | 13 | 14 | def CalcSpherePoints(num_points, device): 15 | goldenRatio = (1 + 5**0.5) / 2 16 | i = torch.arange(num_points, device=device).view(-1, 1) 17 | theta = 2 * math.pi * i / goldenRatio 18 | phi = torch.arccos(1 - 2 * (i + 0.5) / num_points) 19 | points = torch.cat( 20 | [ 21 | torch.cos(theta) * torch.sin(phi), 22 | torch.sin(theta) * torch.sin(phi), 23 | torch.cos(phi), 24 | ], 25 | dim=1, 26 | ) 27 | 28 | # weight the points by their density 29 | pt_cross = points.view(1, -1, 3) - points.view(-1, 1, 3) 30 | pt_cross = torch.sum(pt_cross**2, dim=2) 31 | pt_cross = torch.exp(-pt_cross / (0.5 * 0.3)) 32 | scalar = 1.0 / torch.sum(pt_cross, dim=1) 33 | scalar = num_points * scalar / torch.sum(scalar) 34 | return points * (scalar.view(-1, 1)) 35 | 36 | 37 | def CalcSpherePointsRandom(num_points, device): 38 | pts = 2.0 * (torch.rand(num_points, 3, device=device) - 0.5) 39 | radius = torch.sum(pts**2, dim=1) 40 | while torch.max(radius) > 1.0: 41 | replace_pts = 2.0 * (torch.rand(num_points, 3, device=device) - 0.5) 42 | replace_mask = radius.gt(0.99) 43 | pts.masked_scatter_(replace_mask.view(-1, 1).repeat(1, 3), replace_pts) 44 | radius = torch.sum(pts**2, dim=1) 45 | 46 | return pts / radius.view(-1, 1) 47 | -------------------------------------------------------------------------------- /src/shepherd/model/utils/add_virtual_edges_to_edge_index.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import radius_graph 3 | import numpy as np 4 | 5 | def add_virtual_edges_to_edge_index(edge_index, virtual_node_mask, batch): 6 | """ 7 | Adds edges to edge_index that connect all (real) nodes to the virtual node(s) 8 | 9 | Arguments: 10 | edge_index -- torch.LongTensor with shape (2, N_edges) 11 | virtual_node_mask -- torch.BoolTensor with shape (N_nodes,) 12 | batch -- torch.LongTensor with shape (N_nodes,) 13 | 14 | Returns: 15 | new_edge_index -- updated edge_index with additional virtual edges 16 | """ 17 | # edge_index (2, N_edges) 18 | # virtual_node_mask (N_nodes,) -- boolean tensor where True indicates a virtual node 19 | # batch (N_nodes,) 20 | 21 | # remove existing edges to/from virtual nodes, to avoid duplicating edges 22 | edge_mask = virtual_node_mask[edge_index[1]] | virtual_node_mask[edge_index[0]] 23 | edge_index_without_VN = edge_index[:, ~edge_mask] 24 | 25 | # create edge_index_VN that has edges between all real nodes and each VN 26 | edge_index_fully_connected = radius_graph( 27 | torch.zeros((virtual_node_mask.shape[0],3), device = batch.device), 28 | r = np.inf, 29 | batch = batch, 30 | max_num_neighbors = 1000000, 31 | ) # this excludes self-loops, by default 32 | edge_mask_fully_connected = virtual_node_mask[edge_index_fully_connected[1]] | virtual_node_mask[edge_index_fully_connected[0]] 33 | edge_index_VN = edge_index_fully_connected[:, edge_mask_fully_connected] 34 | 35 | # combine and return 36 | new_edge_index = torch.cat([edge_index_without_VN, edge_index_VN], dim = 1) 37 | return new_edge_index -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/nets/equiformer_v2/gaussian_rbf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.jit.script 5 | def gaussian(x, mean, std): 6 | pi = 3.14159 7 | a = (2*pi) ** 0.5 8 | return torch.exp(-0.5 * (((x - mean) / 
std) ** 2)) / (a * std) 9 | 10 | 11 | # From Graphormer 12 | class GaussianRadialBasisLayer(torch.nn.Module): 13 | def __init__(self, num_basis, cutoff): 14 | super().__init__() 15 | self.num_basis = num_basis 16 | self.cutoff = cutoff + 0.0 17 | self.mean = torch.nn.Parameter(torch.zeros(1, self.num_basis)) 18 | self.std = torch.nn.Parameter(torch.zeros(1, self.num_basis)) 19 | self.weight = torch.nn.Parameter(torch.ones(1, 1)) 20 | self.bias = torch.nn.Parameter(torch.zeros(1, 1)) 21 | 22 | self.std_init_max = 1.0 23 | self.std_init_min = 1.0 / self.num_basis 24 | self.mean_init_max = 1.0 25 | self.mean_init_min = 0 26 | torch.nn.init.uniform_(self.mean, self.mean_init_min, self.mean_init_max) 27 | torch.nn.init.uniform_(self.std, self.std_init_min, self.std_init_max) 28 | torch.nn.init.constant_(self.weight, 1) 29 | torch.nn.init.constant_(self.bias, 0) 30 | 31 | 32 | def forward(self, dist, node_atom=None, edge_src=None, edge_dst=None): 33 | x = dist / self.cutoff 34 | x = x.unsqueeze(-1) 35 | x = self.weight * x + self.bias 36 | x = x.expand(-1, self.num_basis) 37 | mean = self.mean 38 | std = self.std.abs() + 1e-5 39 | x = gaussian(x, mean, std) 40 | return x 41 | 42 | 43 | def extra_repr(self): 44 | return 'mean_init_max={}, mean_init_min={}, std_init_max={}, std_init_min={}'.format( 45 | self.mean_init_max, self.mean_init_min, self.std_init_max, self.std_init_min) 46 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "shepherd" 7 | version = "0.2.6" 8 | description = "ShEPhERD: Diffusing Shape, Electrostatics, and Pharmacophores for Drug Design" 9 | authors = [ 10 | {name = "Keir Adams"}, 11 | {name = "Kento Abeywardane", email = "kento@mit.edu"} 12 | ] 13 | requires-python = ">=3.9,<3.12" 14 | 15 | dependencies = [ 16 | "torch", 17 | "torch-geometric", 18 | "pytorch-lightning", 19 | "rdkit==2024.09.6", 20 | "e3nn", 21 | "open3d==0.18", 22 | "numpy", 23 | "matplotlib", 24 | "tqdm", 25 | "huggingface-hub", 26 | "pandas==2.2.3", 27 | "jupyterlab", 28 | ] 29 | 30 | [tool.setuptools] 31 | package-dir = {"" = "src"} 32 | 33 | [project.optional-dependencies] 34 | dev = [ 35 | "pytest>=7.0.0", 36 | "pytest-cov>=4.0.0", 37 | "black>=23.0.0", 38 | "isort>=5.12.0", 39 | "mypy>=1.0.0", 40 | "flake8>=6.0.0", 41 | ] 42 | 43 | [project.urls] 44 | "Homepage" = "https://github.com/coleygroup/shepherd" 45 | "Bug Tracker" = "https://github.com/coleygroup/shepherd/issues" 46 | 47 | [tool.setuptools.packages.find] 48 | where = ["src"] 49 | 50 | [tool.setuptools.exclude-package-data] 51 | "*" = ["*.so", "*.dylib", "*.dll"] 52 | 53 | [tool.black] 54 | line-length = 88 55 | target-version = ["py39", "py310", "py311"] 56 | include = '\.pyi?$' 57 | 58 | [tool.isort] 59 | profile = "black" 60 | line_length = 88 61 | multi_line_output = 3 62 | 63 | [tool.mypy] 64 | python_version = "3.9" 65 | warn_return_any = true 66 | warn_unused_configs = true 67 | disallow_untyped_defs = false 68 | disallow_incomplete_defs = false 69 | 70 | [tool.pytest.ini_options] 71 | testpaths = ["tests"] 72 | python_files = "test_*.py" 73 | python_functions = "test_*" 74 | addopts = "--cov=shepherd" -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/utils/activations.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | class Act(torch.nn.Module): 13 | def __init__(self, act, slope=0.05): 14 | super(Act, self).__init__() 15 | self.act = act 16 | self.slope = slope 17 | self.shift = torch.log(torch.tensor(2.0)).item() 18 | 19 | def forward(self, input): 20 | if self.act == "relu": 21 | return F.relu(input) 22 | elif self.act == "leaky_relu": 23 | return F.leaky_relu(input) 24 | elif self.act == "sp": 25 | return F.softplus(input, beta=1) 26 | elif self.act == "leaky_sp": 27 | return F.softplus(input, beta=1) - self.slope * F.relu(-input) 28 | elif self.act == "elu": 29 | return F.elu(input, alpha=1) 30 | elif self.act == "leaky_elu": 31 | return F.elu(input, alpha=1) - self.slope * F.relu(-input) 32 | elif self.act == "ssp": 33 | return F.softplus(input, beta=1) - self.shift 34 | elif self.act == "leaky_ssp": 35 | return ( 36 | F.softplus(input, beta=1) 37 | - self.slope * F.relu(-input) 38 | - self.shift 39 | ) 40 | elif self.act == "tanh": 41 | return torch.tanh(input) 42 | elif self.act == "leaky_tanh": 43 | return torch.tanh(input) + self.slope * input 44 | elif self.act == "swish": 45 | return torch.sigmoid(input) * input 46 | else: 47 | raise RuntimeError(f"Undefined activation called {self.act}") 48 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_0.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1163_0A 2 | RDKit 3D 3 | 4 | 19 20 0 0 0 0 0 0 0 0999 V2000 5 | -9.4350 -2.6960 -8.2680 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -8.5100 -3.6940 -8.2430 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -7.1530 -5.5870 -7.8950 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | -10.5300 -2.4380 -7.4050 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -11.7870 -2.2370 -7.9710 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -12.8690 -1.9250 -7.1620 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -12.7100 -1.8090 -5.7960 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -11.4670 -2.0140 -5.2300 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -10.3780 -2.3340 -6.0250 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | -7.6090 -3.8330 -9.1780 N 0 0 0 0 0 0 0 0 0 0 0 0 15 | -6.8180 -4.9520 -8.9650 N 0 0 0 0 0 0 0 0 0 0 0 0 16 | -8.4680 -4.9100 -7.0090 S 0 0 0 0 0 0 0 0 0 0 0 0 17 | -9.3214 -2.0200 -9.0368 H 0 0 0 0 0 0 0 0 0 0 0 0 18 | -6.6344 -6.5005 -7.5685 H 0 0 0 0 0 0 0 0 0 0 0 0 19 | -11.9211 -2.3258 -9.0592 H 0 0 0 0 0 0 0 0 0 0 0 0 20 | -13.8606 -1.7689 -7.6119 H 0 0 0 0 0 0 0 0 0 0 0 0 21 | -13.5706 -1.5541 -5.1600 H 0 0 0 0 0 0 0 0 0 0 0 0 22 | -11.3406 -1.9227 -4.1411 H 0 0 0 0 0 0 0 0 0 0 0 0 23 | -9.3934 -2.5059 -5.5657 H 0 0 0 0 0 0 0 0 0 0 0 0 24 | 2 1 1 0 25 | 4 1 1 0 26 | 5 4 2 0 27 | 6 5 1 0 28 | 7 6 2 0 29 | 8 7 1 0 30 | 9 8 2 0 31 | 9 4 1 0 32 | 10 2 2 0 33 | 11 3 2 0 34 | 11 10 1 0 35 | 12 3 1 0 36 | 12 2 1 0 37 | 1 13 1 0 38 | 3 14 1 0 39 | 5 15 1 0 40 | 6 16 1 0 41 | 7 17 1 0 42 | 8 18 1 0 43 | 9 19 1 0 44 | M END 45 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/common/hpo_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | from ray import tune 11 | 12 | 13 | def tune_reporter( 14 | iters, 15 | train_metrics, 16 | val_metrics, 17 | test_metrics=None, 18 | metric_to_opt="val_loss", 19 | min_max="min", 20 | ): 21 | """ 22 | Wrapper function for tune.report() 23 | 24 | Args: 25 | iters(dict): dict with training iteration info (e.g. steps, epochs) 26 | train_metrics(dict): train metrics dict 27 | val_metrics(dict): val metrics dict 28 | test_metrics(dict, optional): test metrics dict, default is None 29 | metric_to_opt(str, optional): str for val metric to optimize, default is val_loss 30 | min_max(str, optional): either "min" or "max", determines whether metric_to_opt is to be minimized or maximized, default is min 31 | 32 | """ 33 | # label the metric dicts 34 | train = label_metric_dict(train_metrics, "train") 35 | val = label_metric_dict(val_metrics, "val") 36 | # this enables tolerance for NaNs; assumes the val set is used for optimization 37 | if math.isnan(val[metric_to_opt]): 38 | if min_max == "min": 39 | val[metric_to_opt] = 100000.0 40 | if min_max == "max": 41 | val[metric_to_opt] = 0.0 42 | if test_metrics: 43 | test = label_metric_dict(test_metrics, "test") 44 | else: 45 | test = {} 46 | # report results to Ray Tune 47 | tune.report(**iters, **train, **val, **test) 48 | 49 | 50 | def label_metric_dict(metric_dict, split): 51 | new_dict = {} 52 | for key in metric_dict: 53 | new_dict["{}_{}".format(split, key)] = metric_dict[key] 54 | metric_dict = new_dict 55 | return metric_dict 56 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Distributed training related functions. 3 | 4 | From DeiT.
5 | ''' 6 | 7 | import io 8 | import os 9 | import time 10 | from collections import defaultdict, deque 11 | import datetime 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | 17 | def is_dist_avail_and_initialized(): 18 | if not dist.is_available(): 19 | return False 20 | if not dist.is_initialized(): 21 | return False 22 | return True 23 | 24 | 25 | def get_world_size(): 26 | if not is_dist_avail_and_initialized(): 27 | return 1 28 | return dist.get_world_size() 29 | 30 | 31 | def get_rank(): 32 | if not is_dist_avail_and_initialized(): 33 | return 0 34 | return dist.get_rank() 35 | 36 | 37 | def is_main_process(): 38 | return get_rank() == 0 39 | 40 | 41 | def save_on_master(*args, **kwargs): 42 | if is_main_process(): 43 | torch.save(*args, **kwargs) 44 | 45 | 46 | def init_distributed_mode(args): 47 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 48 | args.rank = int(os.environ["RANK"]) 49 | args.world_size = int(os.environ['WORLD_SIZE']) 50 | args.local_rank = int(os.environ['LOCAL_RANK']) 51 | elif 'SLURM_PROCID' in os.environ: 52 | args.rank = int(os.environ['SLURM_PROCID']) 53 | args.local_rank = args.rank % torch.cuda.device_count() 54 | else: 55 | print('Not using distributed mode') 56 | args.distributed = False 57 | args.rank = 0 58 | args.local_rank = 0 59 | return 60 | 61 | args.distributed = True 62 | 63 | torch.cuda.set_device(args.local_rank) 64 | args.dist_backend = 'nccl' 65 | #print('| distributed init (rank {}): {}'.format( 66 | # args.rank, args.dist_url), flush=True) 67 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 68 | world_size=args.world_size, rank=args.rank) 69 | torch.distributed.barrier() 70 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class FileLogger: 5 | def __init__(self, is_master=False, is_rank0=False, output_dir=None, logger_name='training'): 6 | # only called by the master 7 | # checked outside the class 8 | self.output_dir = output_dir 9 | if is_rank0: 10 | self.logger_name = logger_name 11 | self.logger = self.get_logger(output_dir, log_to_file=is_master) 12 | else: 13 | self.logger_name = None 14 | self.logger = NoOp() 15 | 16 | 17 | def get_logger(self, output_dir, log_to_file): 18 | logger = logging.getLogger(self.logger_name) 19 | logger.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter('%(message)s') 21 | 22 | if output_dir and log_to_file: 23 | 24 | time_formatter = logging.Formatter('%(asctime)s - %(filename)s:%(lineno)d - %(message)s') 25 | debuglog = logging.FileHandler(output_dir+'/debug.log') 26 | debuglog.setLevel(logging.DEBUG) 27 | debuglog.setFormatter(time_formatter) 28 | logger.addHandler(debuglog) 29 | 30 | console = logging.StreamHandler() 31 | console.setFormatter(formatter) 32 | console.setLevel(logging.DEBUG) 33 | logger.addHandler(console) 34 | 35 | # Reference: https://stackoverflow.com/questions/21127360/python-2-7-log-displayed-twice-when-logging-module-is-used-in-two-python-scri 36 | logger.propagate = False 37 | 38 | return logger 39 | 40 | def console(self, *args): 41 | self.logger.debug(*args) 42 | 43 | def event(self, *args): 44 | self.logger.warning(*args) 45 | 46 | def verbose(self, *args): 47 | self.logger.info(*args) 48 | 49 | def info(self, *args): 50 | self.logger.info(*args) 51 | 52 | 53 | # no-op method/object that accepts every signature 54 |
class NoOp: 55 | def __getattr__(self, *args): 56 | def no_op(*args, **kwargs): pass 57 | return no_op -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/trainer/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class FileLogger: 5 | def __init__(self, is_master=False, is_rank0=False, output_dir=None, logger_name='training'): 6 | # only called by the master 7 | # checked outside the class 8 | self.output_dir = output_dir 9 | if is_rank0: 10 | self.logger_name = logger_name 11 | self.logger = self.get_logger(output_dir, log_to_file=is_master) 12 | else: 13 | self.logger_name = None 14 | self.logger = NoOp() 15 | 16 | 17 | def get_logger(self, output_dir, log_to_file): 18 | logger = logging.getLogger(self.logger_name) 19 | logger.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter('%(message)s') 21 | 22 | if output_dir and log_to_file: 23 | 24 | time_formatter = logging.Formatter('%(asctime)s - %(filename)s:%(lineno)d - %(message)s') 25 | debuglog = logging.FileHandler(output_dir+'/debug.log') 26 | debuglog.setLevel(logging.DEBUG) 27 | debuglog.setFormatter(time_formatter) 28 | logger.addHandler(debuglog) 29 | 30 | console = logging.StreamHandler() 31 | console.setFormatter(formatter) 32 | console.setLevel(logging.DEBUG) 33 | logger.addHandler(console) 34 | 35 | # Reference: https://stackoverflow.com/questions/21127360/python-2-7-log-displayed-twice-when-logging-module-is-used-in-two-python-scri 36 | logger.propagate = False 37 | 38 | return logger 39 | 40 | def console(self, *args): 41 | self.logger.debug(*args) 42 | 43 | def event(self, *args): 44 | self.logger.warning(*args) 45 | 46 | def verbose(self, *args): 47 | self.logger.info(*args) 48 | 49 | def info(self, *args): 50 | self.logger.info(*args) 51 | 52 | 53 | # no-op method/object that accepts every signature 54 | class NoOp: 55 | def __getattr__(self, *args): 56 | def no_op(*args, **kwargs): pass 57 | return no_op -------------------------------------------------------------------------------- /data/shepherd_chkpts/README.md: -------------------------------------------------------------------------------- 1 | # *ShEPhERD* checkpoints for PyTorch 2.5.1 2 | Checkpoints have been moved to HuggingFace to reduce repo size. They are downloaded automatically, or can be fetched manually from here: [https://huggingface.co/kabeywar/shepherd](https://huggingface.co/kabeywar/shepherd). 3 | 4 | These checkpoints were converted from the original model weights (trained with PyTorch Lightning v1.2) to v2.5.1 using `python -m pytorch_lightning.utilities.upgrade_checkpoint <path/to/checkpoint.ckpt>`. The original model weights can be found at: 5 | [https://www.dropbox.com/scl/fo/rgn33g9kwthnjt27bsc3m/ADGt-CplyEXSU7u5MKc0aTo?rlkey=fhi74vkktpoj1irl84ehnw95h&e=1&st=wn46d6o2&dl=0](https://www.dropbox.com/scl/fo/rgn33g9kwthnjt27bsc3m/ADGt-CplyEXSU7u5MKc0aTo?rlkey=fhi74vkktpoj1irl84ehnw95h&e=1&st=wn46d6o2&dl=0).
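For reference, a single checkpoint can be upgraded in place with the command below (the checkpoint path is a placeholder):

```bash
python -m pytorch_lightning.utilities.upgrade_checkpoint path/to/checkpoint.ckpt
```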
6 | 7 | ## Available Models 8 | 9 | | Model Type | Description | Training Dataset | 10 | |------------|-------------|------------------| 11 | | `mosesaq` | Shape, electrostatics, and pharmacophores | MOSES-aq | 12 | | `gdb_x2` | Shape conditioning only | GDB17 | 13 | | `gdb_x3` | Shape and electrostatics | GDB17 | 14 | | `gdb_x4` | Pharmacophores only | GDB17 | 15 | 16 | 17 | ### Basic Usage 18 | 19 | ```python 20 | from shepherd import load_shepherd_model 21 | 22 | # Load the default MOSES-aq model (downloads automatically if needed) 23 | model = load_shepherd_model() 24 | 25 | # Load a specific model type 26 | model = load_shepherd_model('gdb_x3') 27 | ``` 28 | 29 | ### Advanced Usage 30 | ```python 31 | from shepherd import load_model, clear_model_cache 32 | 33 | # Use custom cache directory 34 | model = load_model(cache_dir='./data/shepherd_chkpts') 35 | 36 | # Check for local checkpoints first 37 | model = load_model(local_data_dir='./data/shepherd_chkpts') 38 | 39 | # Clear cached models 40 | clear_model_cache('mosesaq') # Clear specific model 41 | clear_model_cache() # Clear all models 42 | ``` 43 | 44 | ### Get Model Information 45 | ```python 46 | from shepherd import get_model_info 47 | 48 | models = get_model_info() 49 | for model_type, description in models.items(): 50 | print(f"{model_type}: {description}") 51 | ``` -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/datasets/embeddings/atomic_radii.py: -------------------------------------------------------------------------------- 1 | """ 2 | Atomic radii in picometers 3 | 4 | NaN stored for unavailable parameters. 5 | """ 6 | ATOMIC_RADII = { 7 | 0: float("NaN"), 8 | 1: 25.0, 9 | 2: 120.0, 10 | 3: 145.0, 11 | 4: 105.0, 12 | 5: 85.0, 13 | 6: 70.0, 14 | 7: 65.0, 15 | 8: 60.0, 16 | 9: 50.0, 17 | 10: 160.0, 18 | 11: 180.0, 19 | 12: 150.0, 20 | 13: 125.0, 21 | 14: 110.0, 22 | 15: 100.0, 23 | 16: 100.0, 24 | 17: 100.0, 25 | 18: 71.0, 26 | 19: 220.0, 27 | 20: 180.0, 28 | 21: 160.0, 29 | 22: 140.0, 30 | 23: 135.0, 31 | 24: 140.0, 32 | 25: 140.0, 33 | 26: 140.0, 34 | 27: 135.0, 35 | 28: 135.0, 36 | 29: 135.0, 37 | 30: 135.0, 38 | 31: 130.0, 39 | 32: 125.0, 40 | 33: 115.0, 41 | 34: 115.0, 42 | 35: 115.0, 43 | 36: float("NaN"), 44 | 37: 235.0, 45 | 38: 200.0, 46 | 39: 180.0, 47 | 40: 155.0, 48 | 41: 145.0, 49 | 42: 145.0, 50 | 43: 135.0, 51 | 44: 130.0, 52 | 45: 135.0, 53 | 46: 140.0, 54 | 47: 160.0, 55 | 48: 155.0, 56 | 49: 155.0, 57 | 50: 145.0, 58 | 51: 145.0, 59 | 52: 140.0, 60 | 53: 140.0, 61 | 54: float("NaN"), 62 | 55: 260.0, 63 | 56: 215.0, 64 | 57: 195.0, 65 | 58: 185.0, 66 | 59: 185.0, 67 | 60: 185.0, 68 | 61: 185.0, 69 | 62: 185.0, 70 | 63: 185.0, 71 | 64: 180.0, 72 | 65: 175.0, 73 | 66: 175.0, 74 | 67: 175.0, 75 | 68: 175.0, 76 | 69: 175.0, 77 | 70: 175.0, 78 | 71: 175.0, 79 | 72: 155.0, 80 | 73: 145.0, 81 | 74: 135.0, 82 | 75: 135.0, 83 | 76: 130.0, 84 | 77: 135.0, 85 | 78: 135.0, 86 | 79: 135.0, 87 | 80: 150.0, 88 | 81: 190.0, 89 | 82: 180.0, 90 | 83: 160.0, 91 | 84: 190.0, 92 | 85: float("NaN"), 93 | 86: float("NaN"), 94 | 87: float("NaN"), 95 | 88: 215.0, 96 | 89: 195.0, 97 | 90: 180.0, 98 | 91: 180.0, 99 | 92: 175.0, 100 | 93: 175.0, 101 | 94: 175.0, 102 | 95: 175.0, 103 | 96: float("NaN"), 104 | 97: float("NaN"), 105 | 98: float("NaN"), 106 | 99: float("NaN"), 107 | 100: float("NaN"), 108 | } 109 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_9.mol: 
-------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1083_0A 2 | RDKit 3D 3 | 4 | 24 24 0 0 0 0 0 0 0 0999 V2000 5 | -4.7250 -7.9920 0.3010 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -6.2190 -8.7450 -1.8530 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -5.0980 -6.4300 -1.5740 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -5.2400 -8.1870 1.6740 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -4.2390 -8.9280 2.5540 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -2.9060 -8.2190 2.5580 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -2.3900 -8.1410 1.1460 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -3.3760 -7.3900 0.2660 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -6.9580 -7.0070 -0.0800 O 0 0 0 0 0 0 0 0 0 0 0 0 14 | -1.9580 -8.8820 3.3910 O 0 0 0 0 0 0 0 0 0 0 0 0 15 | -5.7910 -7.4090 -0.7980 S 0 0 0 0 0 0 0 0 0 0 0 0 16 | -6.4986 -9.6177 -1.2445 H 0 0 0 0 0 0 0 0 0 0 0 0 17 | -7.0686 -8.4533 -2.4878 H 0 0 0 0 0 0 0 0 0 0 0 0 18 | -5.3580 -9.0013 -2.4878 H 0 0 0 0 0 0 0 0 0 0 0 0 19 | -5.4462 -7.2029 2.1201 H 0 0 0 0 0 0 0 0 0 0 0 0 20 | -6.1560 -8.7930 1.6129 H 0 0 0 0 0 0 0 0 0 0 0 0 21 | -4.6268 -8.9740 3.5824 H 0 0 0 0 0 0 0 0 0 0 0 0 22 | -4.0990 -9.9429 2.1536 H 0 0 0 0 0 0 0 0 0 0 0 0 23 | -3.0491 -7.2076 2.9662 H 0 0 0 0 0 0 0 0 0 0 0 0 24 | -1.4244 -7.6142 1.1409 H 0 0 0 0 0 0 0 0 0 0 0 0 25 | -2.2642 -9.1609 0.7536 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | -3.4430 -6.3496 0.6168 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -3.0103 -7.4399 -0.7702 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -2.1855 -8.7229 4.3413 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | 4 1 1 0 30 | 5 4 1 0 31 | 6 5 1 0 32 | 7 6 1 0 33 | 8 7 1 0 34 | 8 1 1 0 35 | 6 10 1 0 36 | 11 9 2 0 37 | 11 2 1 0 38 | 11 3 2 0 39 | 11 1 1 0 40 | 2 12 1 0 41 | 2 13 1 0 42 | 2 14 1 0 43 | 4 15 1 0 44 | 4 16 1 0 45 | 5 17 1 0 46 | 5 18 1 0 47 | 6 19 1 0 48 | 7 20 1 0 49 | 7 21 1 0 50 | 8 22 1 0 51 | 8 23 1 0 52 | 10 24 1 0 53 | M END 54 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_7.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x2135_0A 2 | RDKit 3D 3 | 4 | 24 26 0 0 0 0 0 0 0 0999 V2000 5 | -11.0050 -2.2540 -5.2650 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -11.1310 -3.1790 -6.2060 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -12.1350 -3.2810 -6.9140 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -9.9510 -4.1010 -6.3890 C 0 0 2 0 0 0 0 0 0 0 0 0 9 | -8.7670 -3.4540 -7.0880 C 0 0 1 0 0 0 0 0 0 0 0 0 10 | -7.3640 -4.0720 -6.9720 C 0 0 2 0 0 0 0 0 0 0 0 0 11 | -6.7470 -4.0100 -8.3730 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -7.5260 -5.0880 -9.1340 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -8.5070 -5.6070 -8.0750 C 0 0 1 0 0 0 0 0 0 0 0 0 14 | -7.6210 -5.5780 -6.8260 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -9.5550 -4.5010 -7.8030 C 0 0 1 0 0 0 0 0 0 0 0 0 16 | -10.1458 -2.2091 -4.6987 H 0 0 0 0 0 0 0 0 0 0 0 0 17 | -11.7663 -1.5801 -5.1001 H 0 0 0 0 0 0 0 0 0 0 0 0 18 | -10.1387 -4.6658 -5.4639 H 0 0 0 0 0 0 0 0 0 0 0 0 19 | -8.4924 -2.3896 -7.1273 H 0 0 0 0 0 0 0 0 0 0 0 0 20 | -6.7655 -3.5958 -6.1814 H 0 0 0 0 0 0 0 0 0 0 0 0 21 | -6.8944 -3.0192 -8.8275 H 0 0 0 0 0 0 0 0 0 0 0 0 22 | -5.6592 -4.1732 -8.3778 H 0 0 0 0 0 0 0 0 0 0 0 0 23 | -8.0612 -4.6581 -9.9935 H 0 0 0 0 0 0 0 0 0 0 0 0 24 | -6.8833 -5.8746 -9.5560 H 0 0 0 0 0 0 0 0 0 0 0 0 25 | -8.9877 -6.5596 -8.3424 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | -8.1537 -5.8464 -5.9018 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -6.7647 -6.2651 -6.7587 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -10.2158 -4.6756 -8.6649 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | 2 1 1 0 30 | 3 2 2 0 31 | 4 2 1 0 32 | 5 4 1 0 33 | 5 6 1 0 34 | 6 7 1 0 35 | 8 7 1 0 36 | 9 8 1 0 37 | 10 9 1 0 38 | 10 6 1 0 39 
| 11 9 1 0 40 | 11 5 1 0 41 | 11 4 1 0 42 | 1 12 1 0 43 | 1 13 1 0 44 | 4 14 1 1 45 | 5 15 1 6 46 | 6 16 1 1 47 | 7 17 1 0 48 | 7 18 1 0 49 | 8 19 1 0 50 | 8 20 1 0 51 | 9 21 1 6 52 | 10 22 1 0 53 | 10 23 1 0 54 | 11 24 1 6 55 | M END 56 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_2.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1084_0A 2 | RDKit 3D 3 | 4 | 25 25 0 0 0 0 0 0 0 0999 V2000 5 | -5.1360 -6.6140 -3.4210 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -6.7990 -4.5530 -3.9980 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -6.9450 -6.9160 -5.0510 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -4.6800 -7.3750 -2.3030 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -3.3580 -7.2310 -1.9000 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -2.9620 -7.6590 -0.6440 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -3.8640 -8.2490 0.2280 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -3.4810 -8.5200 1.6360 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -5.1600 -8.4660 -0.2150 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | -5.5680 -8.0430 -1.4660 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -3.2780 -7.2170 2.2860 N 0 0 0 0 0 0 0 0 0 0 0 0 16 | -7.4920 -6.6520 -2.6630 O 0 0 0 0 0 0 0 0 0 0 0 0 17 | -6.6900 -6.2900 -3.7910 S 0 0 0 0 0 0 0 0 0 0 0 0 18 | -4.4164 -6.2343 -4.0526 H 0 0 0 0 0 0 0 0 0 0 0 0 19 | -6.8295 -4.0674 -3.0115 H 0 0 0 0 0 0 0 0 0 0 0 0 20 | -5.9216 -4.1945 -4.5562 H 0 0 0 0 0 0 0 0 0 0 0 0 21 | -7.7143 -4.3070 -4.5562 H 0 0 0 0 0 0 0 0 0 0 0 0 22 | -2.6238 -6.7756 -2.5808 H 0 0 0 0 0 0 0 0 0 0 0 0 23 | -1.9149 -7.5288 -0.3331 H 0 0 0 0 0 0 0 0 0 0 0 0 24 | -4.2826 -9.0751 2.1452 H 0 0 0 0 0 0 0 0 0 0 0 0 25 | -2.5658 -9.1284 1.6831 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | -5.8757 -8.9837 0.4405 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -6.5980 -8.2351 -1.8008 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -3.2967 -7.3369 3.3088 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | -2.3666 -6.8301 2.0023 H 0 0 0 0 0 0 0 0 0 0 0 0 30 | 4 1 1 0 31 | 5 4 2 0 32 | 6 5 1 0 33 | 7 6 2 0 34 | 8 7 1 0 35 | 9 7 1 0 36 | 10 9 2 0 37 | 10 4 1 0 38 | 11 8 1 0 39 | 13 12 2 0 40 | 13 3 2 0 41 | 13 2 1 0 42 | 13 1 1 0 43 | 1 14 1 0 44 | 2 15 1 0 45 | 2 16 1 0 46 | 2 17 1 0 47 | 5 18 1 0 48 | 6 19 1 0 49 | 8 20 1 0 50 | 8 21 1 0 51 | 9 22 1 0 52 | 10 23 1 0 53 | 11 24 1 0 54 | 11 25 1 0 55 | M END 56 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/painn/README.md: -------------------------------------------------------------------------------- 1 | # Polarizable Atom Interaction Neural Network (PaiNN) 2 | 3 | Kristof T. Schütt, Oliver T. Unke, Michael Gastegger 4 | 5 | [[`arXiv:2102.03150`](https://arxiv.org/abs/2102.03150)] 6 | 7 | This is our independent reimplementation of the original PaiNN architecture 8 | with the difference that forces are predicted directly from vectorial features 9 | via a gated equivariant block instead of gradients of the energy output. 10 | This breaks energy conservation but is essential for good performance on OC20. 11 | 12 | All PaiNN models were trained without AMP, as using AMP led to unstable training. 13 | 14 | ## IS2RE 15 | 16 | Trained only using IS2RE data, no auxiliary losses and/or S2EF data. 
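For intuition, the direct force head described at the top of this README can be pictured as a gated equivariant output block: invariant scalar features gate a linear combination of equivariant vector features, which is read out as a per-atom force. The following is a minimal PyTorch sketch (our own illustrative simplification; the module name and feature sizes are assumptions, not the OCP implementation):

```python
import torch
import torch.nn as nn


class GatedEquivariantForceHead(nn.Module):
    """Illustrative sketch: map vector features directly to per-atom forces."""

    def __init__(self, hidden: int):
        super().__init__()
        # mix vector channels without bias to preserve rotational equivariance
        self.vec_proj = nn.Linear(hidden, 2 * hidden, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden, hidden), nn.SiLU(), nn.Linear(hidden, hidden)
        )

    def forward(self, s: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # s: (n_atoms, hidden) invariant scalars; v: (n_atoms, 3, hidden) vectors
        v1, v2 = torch.split(self.vec_proj(v), v.shape[-1], dim=-1)
        invariants = torch.linalg.norm(v2, dim=-2)           # (n_atoms, hidden)
        gate = self.mlp(torch.cat([s, invariants], dim=-1))  # invariant gate
        return (v1 * gate.unsqueeze(-2)).sum(dim=-1)         # (n_atoms, 3) forces
```

Because the force is a gated readout of the vector features rather than the negative gradient of a predicted energy, no backward pass is needed at inference, which is exactly the energy-conservation trade-off noted above.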
17 | 18 | | Model | Val ID Energy MAE | Test metrics | Download | 19 | | ----- | ----------------- | ------------ | -------- | 20 | | painn_h1024_bs4x8 | 0.5728 | [IS2RE](https://evalai.s3.amazonaws.com/media/submission_files/submission_200972/45d289fc-8de9-45cc-aed4-6cd1753cb56d.json) | [config](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/is2re/all/painn/painn_h1024_bs8x4.yml) \| [checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_05/is2re/painn_h1024_bs4x8_is2re_all.pt) | 21 | 22 | ## S2EF 23 | 24 | | Model | Val ID 30k Force MAE | Val ID 30k Energy MAE | Val ID 30k Force cos | Test metrics | Download | 25 | | ----- | -------------------- | --------------------- | -------------------- | ------------ | -------- | 26 | | painn_h512 | 0.02945 | 0.2459 | 0.5143 | [S2EF](https://evalai.s3.amazonaws.com/media/submission_files/submission_200711/2f487981-051d-445e-a7cd-6eb00ebe0735.json) \| [IS2RE](https://evalai.s3.amazonaws.com/media/submission_files/submission_200710/7fe29c4c-c203-434d-a6d4-9ea992d3bb5c.json) \| [IS2RS](https://evalai.s3.amazonaws.com/media/submission_files/submission_200700/8fd419e6-bab3-49be-a936-ae31979b4866.json) | [config](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/s2ef/all/painn/painn_h512.yml) \| [checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_05/s2ef/painn_h512_s2ef_all.pt) | 27 | 28 | ## Citing 29 | 30 | If you use PaiNN in your work, please consider citing the original paper: 31 | 32 | ```bibtex 33 | @inproceedings{schutt_painn_2021, 34 | title = {Equivariant message passing for the prediction of tensorial properties and molecular spectra}, 35 | author = {Sch{\"u}tt, Kristof and Unke, Oliver and Gastegger, Michael}, 36 | booktitle = {Proceedings of the International Conference on Machine Learning (ICML)}, 37 | year = {2021}, 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_8.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x2099_0A 2 | RDKit 3D 3 | 4 | 27 29 0 0 0 0 0 0 0 0999 V2000 5 | -12.5050 -1.5520 -7.1690 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -7.4410 -4.1620 -8.3630 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -6.6950 -3.8860 -9.3130 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -8.5890 -3.3360 -8.0400 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -6.6220 -7.3450 -5.6390 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -5.8100 -7.1560 -6.7350 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -6.0730 -6.1390 -7.6420 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -7.1600 -5.2860 -7.4440 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -9.3560 -3.5650 -6.9360 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | -10.4980 -2.7220 -6.5450 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -10.6670 -2.3050 -5.2310 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | -11.7460 -1.5100 -4.8990 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | -12.6350 -1.1610 -5.8950 C 0 0 0 0 0 0 0 0 0 0 0 0 18 | -11.4460 -2.3140 -7.4750 C 0 0 0 0 0 0 0 0 0 0 0 0 19 | -7.9850 -5.4840 -6.3220 C 0 0 0 0 0 0 0 0 0 0 0 0 20 | -7.7120 -6.5180 -5.4260 C 0 0 0 0 0 0 0 0 0 0 0 0 21 | -9.0520 -4.6230 -6.1030 N 0 0 0 0 0 0 0 0 0 0 0 0 22 | -8.8486 -2.4997 -8.7057 H 0 0 0 0 0 0 0 0 0 0 0 0 23 | -6.4037 -8.1567 -4.9294 H 0 0 0 0 0 0 0 0 0 0 0 0 24 | -4.9451 -7.8171 -6.8930 H 0 0 0 0 0 0 0 0 0 0 0 0 25 | -5.4231 -6.0048 -8.5193 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | -9.9445 -2.6067 -4.4583 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -11.8934 -1.1632 -3.8655 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -13.4949 -0.5289 -5.6286 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | -11.3204 -2.6340 -8.5199 H 0 0 0 0 0 0 0 
0 0 0 0 0 30 | -8.3613 -6.6759 -4.5522 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | -9.6468 -4.7817 -5.2772 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | 3 2 2 0 33 | 4 2 1 0 34 | 6 5 2 0 35 | 7 6 1 0 36 | 8 7 2 0 37 | 8 2 1 0 38 | 9 4 2 0 39 | 10 9 1 0 40 | 11 10 1 0 41 | 12 11 2 0 42 | 13 1 2 0 43 | 13 12 1 0 44 | 14 1 1 0 45 | 14 10 2 0 46 | 15 8 1 0 47 | 16 15 2 0 48 | 16 5 1 0 49 | 17 15 1 0 50 | 17 9 1 0 51 | 4 18 1 0 52 | 5 19 1 0 53 | 6 20 1 0 54 | 7 21 1 0 55 | 11 22 1 0 56 | 12 23 1 0 57 | 13 24 1 0 58 | 14 25 1 0 59 | 16 26 1 0 60 | 17 27 1 0 61 | M END 62 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/modules/scheduler.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import torch.optim.lr_scheduler as lr_scheduler 4 | 5 | from ocpmodels.common.utils import warmup_lr_lambda 6 | 7 | 8 | class LRScheduler: 9 | """ 10 | Learning rate scheduler class for torch.optim learning rate schedulers 11 | 12 | Notes: 13 | If no learning rate scheduler is specified in the config the default 14 | scheduler is warmup_lr_lambda (ocpmodels.common.utils) not no scheduler, 15 | this is for backward-compatibility reasons. To run without a lr scheduler 16 | specify scheduler: "Null" in the optim section of the config. 17 | 18 | Args: 19 | config (dict): Optim dict from the input config 20 | optimizer (obj): torch optim object 21 | """ 22 | 23 | def __init__(self, optimizer, config): 24 | self.optimizer = optimizer 25 | self.config = config.copy() 26 | if "scheduler" in self.config: 27 | self.scheduler_type = self.config["scheduler"] 28 | else: 29 | self.scheduler_type = "LambdaLR" 30 | scheduler_lambda_fn = lambda x: warmup_lr_lambda(x, self.config) 31 | self.config["lr_lambda"] = scheduler_lambda_fn 32 | 33 | if self.scheduler_type != "Null": 34 | self.scheduler = getattr(lr_scheduler, self.scheduler_type) 35 | scheduler_args = self.filter_kwargs(config) 36 | self.scheduler = self.scheduler(optimizer, **scheduler_args) 37 | 38 | def step(self, metrics=None, epoch=None): 39 | if self.scheduler_type == "Null": 40 | return 41 | if self.scheduler_type == "ReduceLROnPlateau": 42 | if metrics is None: 43 | raise Exception( 44 | "Validation set required for ReduceLROnPlateau." 
45 | ) 46 | self.scheduler.step(metrics) 47 | else: 48 | self.scheduler.step() 49 | 50 | def filter_kwargs(self, config): 51 | # adapted from https://stackoverflow.com/questions/26515595/ 52 | sig = inspect.signature(self.scheduler) 53 | filter_keys = [ 54 | param.name 55 | for param in sig.parameters.values() 56 | if param.kind == param.POSITIONAL_OR_KEYWORD 57 | ] 58 | filter_keys.remove("optimizer") 59 | scheduler_args = { 60 | arg: self.config[arg] for arg in self.config if arg in filter_keys 61 | } 62 | return scheduler_args 63 | 64 | def get_lr(self): 65 | for group in self.optimizer.param_groups: 66 | return group["lr"] 67 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/modules/scaling/compat.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from pathlib import Path 4 | from typing import Dict, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .scale_factor import ScaleFactor 10 | 11 | ScaleDict = Union[Dict[str, float], Dict[str, torch.Tensor]] 12 | 13 | 14 | def _load_scale_dict(scale_file: Optional[Union[str, ScaleDict]]): 15 | """ 16 | Loads scale factors from either: 17 | - a JSON file mapping scale factor names to scale values 18 | - a pickled python dictionary (loaded using `torch.load`) mapping scale factor names to scale values 19 | - a dictionary mapping scale factor names to scale values 20 | """ 21 | if not scale_file: 22 | return None 23 | 24 | if isinstance(scale_file, dict): 25 | if not scale_file: 26 | logging.warning("Empty scale dictionary provided to model.") 27 | return scale_file 28 | 29 | path = Path(scale_file) 30 | if not path.exists(): 31 | raise ValueError(f"Scale file {path} does not exist.") 32 | 33 | scale_dict: Optional[ScaleDict] = None 34 | if path.suffix == ".pt": 35 | scale_dict = torch.load(path) 36 | elif path.suffix == ".json": 37 | with open(path, "r") as f: 38 | scale_dict = json.load(f) 39 | 40 | if isinstance(scale_dict, dict): 41 | # old json scale factors have a comment field that has the model name 42 | scale_dict.pop("comment", None) 43 | else: 44 | raise ValueError(f"Unsupported scale file extension: {path.suffix}") 45 | 46 | if not scale_dict: 47 | return None 48 | 49 | return scale_dict 50 | 51 | 52 | def load_scales_compat( 53 | module: nn.Module, scale_file: Optional[Union[str, ScaleDict]] 54 | ): 55 | scale_dict = _load_scale_dict(scale_file) 56 | if not scale_dict: 57 | return 58 | 59 | scale_factors = { 60 | module.name or name: (module, name) 61 | for name, module in module.named_modules() 62 | if isinstance(module, ScaleFactor) 63 | } 64 | logging.debug( 65 | f"Found the following scale factors: {[(k, name) for k, (_, name) in scale_factors.items()]}" 66 | ) 67 | for name, scale in scale_dict.items(): 68 | if name not in scale_factors: 69 | logging.warning(f"Scale factor {name} not found in model") 70 | continue 71 | 72 | scale_module, module_name = scale_factors[name] 73 | logging.debug( 74 | f"Loading scale factor {scale} for ({name} => {module_name})" 75 | ) 76 | scale_module.set_(scale) 77 | -------------------------------------------------------------------------------- /data/conformers/pdb/README.md: -------------------------------------------------------------------------------- 1 | # Co-factors from PDB structures 2 | 3 | Best docking scores using AutoDock Vina 1.2.5 (exhaustiveness = 1000).
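Steps 1-2 of the conformer pipeline listed just below can be sketched with RDKit; the example SMILES and the Butina distance threshold are illustrative assumptions, and the xTB relaxation and energy ranking (steps 3-5) are external steps not shown:

```python
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.ML.Cluster import Butina

mol = Chem.AddHs(Chem.MolFromSmiles("Cc1ccc(O)cc1"))  # placeholder molecule

# 1. embed 1000 conformers with ETKDG
cids = AllChem.EmbedMultipleConfs(mol, numConfs=1000, params=AllChem.ETKDGv3())

# 2. Butina clustering on the lower-triangular conformer RMS matrix
dists = [
    AllChem.GetConformerRMS(mol, cids[i], cids[j], prealigned=False)
    for i in range(len(cids))
    for j in range(i)
]
clusters = Butina.ClusterData(dists, len(cids), 1.0, isDistData=True)
representatives = [c[0] for c in clusters]  # first entry of each cluster is its centroid
```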
4 | 5 | Generated the lowest energy conformers by: 6 | 1. Embed 1000 molecules using RDKit v2024.03.3 ETKDG 7 | 2. Cluster with RDKit's Butina clustering 8 | 3. Relax remaining structures with xTB v6.6.1 in implicit water as solvent 9 | 4. Cluster those structures with RDKit's Butina clustering 10 | 5. Select the conformer with the lowest energy 11 | 12 | Order of structures: 13 | 0. Cc1ccc(OC[C@@H](O)[C@H](C)NC(C)C)c2c1CCC2 14 | - JRZ for 3ny8: -8.615 kcal/mol 15 | 1. CCCOc1cc(Cl)cc(-c2cc(-c3ccccc3C#N)cn(-c3cccnc3)c2=O)c1 16 | - XF1 for 7l11: -9.132 kcal/mol 17 | 2. Nc1nc(NCCc2ccc(O)cc2)nc2nc(-c3ccco3)nn12 18 | - ZMA for 3eml: -9.158 kcal/mol 19 | 3. COc1cccc(-c2cc(-c3ccc(C(=O)O)cc3)c(C#N)c(=O)[nH]2)c1 20 | - QZZ for 4unn: -9.251 kcal/mol 21 | 4. O=C(C=Cc1ccc(O)cc1)c1ccc(O)cc1O 22 | - HCC for 4rlu: -8.697 kcal/mol 23 | 5. Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1 24 | - STI / Imatinib for 1iep: -9.521 kcal/mol 25 | 6. O=C(Nc1ccc(OC(F)(F)Cl)cc1)c1cnc(N2CC[C@@H](O)C2)c(-c2ccn[nH]2)c1 26 | - AY7 / asciminib for 5mo4: -9.938 kcal/mol 27 | 28 | # Top 1 from docking screen of 10k random molecules selected from the test set of MOSES 29 | `molblock_charges_docking_screen.pkl` 30 | 31 | Followed the same procedure as above to find the lowest energy conformers. 32 | 33 | Best docking scores using AutoDock Vina 1.2.5 (exhaustiveness = 32). 34 | 35 | 0. Cn1cc(NC(=O)c2ccc3c(c2)CC(c2ccccc2)OC3=O)cn1 36 | - 1iep: -12.11 kcal/mol 37 | 1. Cc1ccc2nc(-c3ccccn3)cc(C(=O)Nc3cccnc3)c2c1 38 | - 3eml: -11.341 kcal/mol 39 | 2. Cn1nc(C(=O)Nc2n[nH]c3ccccc23)c2ccccc2c1=O 40 | - 3ny8: -11.857 kcal/mol 41 | 3. Cc1cccc(C(=O)Nc2cccc(-c3nn4c(C)nnc4s3)c2)c1 42 | - 4rlu: -11.904 kcal/mol 43 | 4. Cc1cc(-c2cccc(-c3ccn(Cc4nnc(C)o4)n3)c2)cc(C)n1 44 | - 4unn: -10.636 kcal/mol 45 | 5. O=C(c1n[nH]c2ccccc12)N1CCc2c([nH]c3ccnn3c2=O)C1 46 | - 5mo4: -10.461 kcal/mol 47 | 6. 
O=C1CC(C(=O)N2CCc3cc(F)ccc3C2)c2ccccc2N1 48 | - 7l11: -8.815 kcal/mol 49 | 50 | 51 | # PDB co-crystal ligands -- pose and lowest energies 52 | - `molblock_charges_pdb_pose.pkl` 53 | - Added hydrogens with rdkit, not xtb optimized 54 | - `molblock_charges_pdb_lowestenergy.pkl` 55 | - Lowest energy conformer of PDB co-crystal 56 | 57 | Order (but check) [4unn, 7l11, 5mo4, 1iep, 3ny8, 4rlu, 3eml] 58 | 59 | # Top 1 docked from docking screen -- pose and lowest energies 60 | - `molblock_charges_bestdocked_pose.pkl` 61 | - Added hydrogens with rdkit, not xtb optimized 62 | - `molblock_charges_bestdocked_lowestenergy.pkl 63 | - lowest energy conformers 64 | 65 | In order of ['1iep', '3eml', '3ny8', '4rlu', '4unn', '5mo4', '7l11'] 66 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_10.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1140_0A 2 | RDKit 3D 3 | 4 | 29 30 0 0 0 0 0 0 0 0999 V2000 5 | -5.9620 -6.5250 -3.8110 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -6.4520 -4.6580 -1.0260 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -8.3660 -6.0800 -4.1360 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -7.4830 -5.4800 -1.7580 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -6.0180 -7.9390 -3.4130 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -5.3790 -8.1280 -2.0630 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -6.0350 -8.8460 -1.0650 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -5.4880 -9.0040 0.1980 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -4.2510 -8.4410 0.4370 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | -2.3180 -7.7170 1.2970 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -3.5790 -7.7610 -0.5440 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | -4.1290 -7.5770 -1.7930 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | -6.7760 -4.1990 -3.9200 O 0 0 0 0 0 0 0 0 0 0 0 0 18 | -3.5290 -8.4120 1.6090 O 0 0 0 0 0 0 0 0 0 0 0 0 19 | -2.3300 -7.3970 -0.0970 O 0 0 0 0 0 0 0 0 0 0 0 0 20 | -7.2080 -5.5050 -3.5180 S 0 0 0 0 0 0 0 0 0 0 0 0 21 | -5.1397 -6.1225 -3.3391 H 0 0 0 0 0 0 0 0 0 0 0 0 22 | -6.5949 -4.7720 0.0587 H 0 0 0 0 0 0 0 0 0 0 0 0 23 | -5.4446 -5.0035 -1.3014 H 0 0 0 0 0 0 0 0 0 0 0 0 24 | -6.5644 -3.5990 -1.3014 H 0 0 0 0 0 0 0 0 0 0 0 0 25 | -7.4450 -6.5123 -1.3799 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | -8.4664 -5.0208 -1.5790 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -7.0687 -8.2611 -3.3653 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -5.4723 -8.5411 -4.1544 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | -7.0127 -9.2992 -1.2858 H 0 0 0 0 0 0 0 0 0 0 0 0 30 | -6.0215 -9.5593 0.9835 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | -2.2562 -6.7927 1.8902 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | -1.4479 -8.3469 1.5337 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | -3.5907 -7.0052 -2.5633 H 0 0 0 0 0 0 0 0 0 0 0 0 34 | 4 2 1 0 35 | 5 1 1 0 36 | 6 5 1 0 37 | 7 6 2 0 38 | 8 7 1 0 39 | 9 8 2 0 40 | 11 9 1 0 41 | 12 11 2 0 42 | 12 6 1 0 43 | 14 10 1 0 44 | 14 9 1 0 45 | 15 11 1 0 46 | 15 10 1 0 47 | 16 4 1 0 48 | 16 3 2 0 49 | 16 1 1 0 50 | 16 13 2 0 51 | 1 17 1 0 52 | 2 18 1 0 53 | 2 19 1 0 54 | 2 20 1 0 55 | 4 21 1 0 56 | 4 22 1 0 57 | 5 23 1 0 58 | 5 24 1 0 59 | 7 25 1 0 60 | 8 26 1 0 61 | 10 27 1 0 62 | 10 28 1 0 63 | 12 29 1 0 64 | M END 65 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/scn/smearing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | # Different encodings for the atom distance embeddings 13 | class GaussianSmearing(torch.nn.Module): 14 | def __init__( 15 | self, start=-5.0, stop=5.0, num_gaussians=50, basis_width_scalar=1.0 16 | ): 17 | super(GaussianSmearing, self).__init__() 18 | self.num_output = num_gaussians 19 | offset = torch.linspace(start, stop, num_gaussians) 20 | self.coeff = ( 21 | -0.5 / (basis_width_scalar * (offset[1] - offset[0])).item() ** 2 22 | ) 23 | self.register_buffer("offset", offset) 24 | 25 | def forward(self, dist): 26 | dist = dist.view(-1, 1) - self.offset.view(1, -1) 27 | return torch.exp(self.coeff * torch.pow(dist, 2)) 28 | 29 | 30 | class SigmoidSmearing(torch.nn.Module): 31 | def __init__( 32 | self, start=-5.0, stop=5.0, num_sigmoid=50, basis_width_scalar=1.0 33 | ): 34 | super(SigmoidSmearing, self).__init__() 35 | self.num_output = num_sigmoid 36 | offset = torch.linspace(start, stop, num_sigmoid) 37 | self.coeff = (basis_width_scalar / (offset[1] - offset[0])).item() 38 | self.register_buffer("offset", offset) 39 | 40 | def forward(self, dist): 41 | exp_dist = self.coeff * (dist.view(-1, 1) - self.offset.view(1, -1)) 42 | return torch.sigmoid(exp_dist) 43 | 44 | 45 | class LinearSigmoidSmearing(torch.nn.Module): 46 | def __init__( 47 | self, start=-5.0, stop=5.0, num_sigmoid=50, basis_width_scalar=1.0 48 | ): 49 | super(LinearSigmoidSmearing, self).__init__() 50 | self.num_output = num_sigmoid 51 | offset = torch.linspace(start, stop, num_sigmoid) 52 | self.coeff = (basis_width_scalar / (offset[1] - offset[0])).item() 53 | self.register_buffer("offset", offset) 54 | 55 | def forward(self, dist): 56 | exp_dist = self.coeff * (dist.view(-1, 1) - self.offset.view(1, -1)) 57 | x_dist = torch.sigmoid(exp_dist) + 0.001 * exp_dist 58 | return x_dist 59 | 60 | 61 | class SiLUSmearing(torch.nn.Module): 62 | def __init__( 63 | self, start=-5.0, stop=5.0, num_output=50, basis_width_scalar=1.0 64 | ): 65 | super(SiLUSmearing, self).__init__() 66 | self.num_output = num_output 67 | self.fc1 = nn.Linear(2, num_output) 68 | self.act = nn.SiLU() 69 | 70 | def forward(self, dist): 71 | x_dist = dist.view(-1, 1) 72 | x_dist = torch.cat([x_dist, torch.ones_like(x_dist)], dim=1) 73 | x_dist = self.act(self.fc1(x_dist)) 74 | return x_dist 75 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_12.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1071_0A 2 | RDKit 3D 3 | 4 | 30 30 0 0 0 0 0 0 0 0999 V2000 5 | -6.9480 -2.0740 -1.9750 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -7.9010 0.1820 -0.7890 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -5.8830 -1.1980 0.0490 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -6.0210 -3.1740 -2.2340 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -6.6440 -4.3900 -2.9870 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -6.1460 -6.8090 -3.0610 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -4.9510 -7.4070 -2.3400 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -5.1050 -7.3000 -0.8330 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -6.5600 -7.1310 -0.4510 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | -7.1340 -5.8510 -1.0260 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -6.6290 -5.5990 -2.3830 N 0 0 0 0 0 0 0 0 0 0 0 0 16 | -5.6690 0.0200 -2.0780 O 0 0 0 0 0 0 0 0 0 0 0 0 17 | -7.1020 -4.2370 -4.1080 O 0 0 0 0 0 0 0 0 0 0 0 0 18 | -6.4670 -0.7510 -1.1750 S 0 0 0 0 0 0 0 0 0 0 0 0 19 | -7.2872 -1.7541 -2.8934 H 0 0 0 0 0 0 0 0 0 0 0 0 20 | -8.0095 0.2526 0.3034 H 0 0 0 0 0 0 0 0 0 0 0 0 21 | -7.8067 1.1921 -1.2141 H 0 0 0 0 0 0 0 0 0 
0 0 0 22 | -8.7863 -0.3135 -1.2141 H 0 0 0 0 0 0 0 0 0 0 0 0 23 | -5.1895 -2.7854 -2.8403 H 0 0 0 0 0 0 0 0 0 0 0 0 24 | -5.7022 -3.5489 -1.2502 H 0 0 0 0 0 0 0 0 0 0 0 0 25 | -5.8514 -6.5502 -4.0887 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | -6.9570 -7.5522 -3.0634 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -4.8599 -8.4678 -2.6164 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -4.0512 -6.8485 -2.6374 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | -4.7133 -8.2151 -0.3648 H 0 0 0 0 0 0 0 0 0 0 0 0 30 | -4.5446 -6.4204 -0.4833 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | -7.1338 -7.9864 -0.8371 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | -6.6280 -7.0808 0.6457 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | -8.2302 -5.9348 -1.0619 H 0 0 0 0 0 0 0 0 0 0 0 0 34 | -6.8274 -5.0140 -0.3815 H 0 0 0 0 0 0 0 0 0 0 0 0 35 | 4 1 1 0 36 | 5 4 1 0 37 | 7 6 1 0 38 | 8 7 1 0 39 | 9 8 1 0 40 | 10 9 1 0 41 | 11 10 1 0 42 | 11 6 1 0 43 | 11 5 1 0 44 | 13 5 2 0 45 | 14 12 2 0 46 | 14 3 2 0 47 | 14 2 1 0 48 | 14 1 1 0 49 | 1 15 1 0 50 | 2 16 1 0 51 | 2 17 1 0 52 | 2 18 1 0 53 | 4 19 1 0 54 | 4 20 1 0 55 | 6 21 1 0 56 | 6 22 1 0 57 | 7 23 1 0 58 | 7 24 1 0 59 | 8 25 1 0 60 | 8 26 1 0 61 | 9 27 1 0 62 | 9 28 1 0 63 | 10 29 1 0 64 | 10 30 1 0 65 | M END 66 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/nets/equiformer_v2/edge_rot_mat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def init_edge_rot_mat(edge_distance_vec): 5 | edge_vec_0 = edge_distance_vec 6 | edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) 7 | 8 | # Make sure the atoms are far enough apart 9 | #assert torch.min(edge_vec_0_distance) < 0.0001 10 | 11 | if torch.min(edge_vec_0_distance) < 0.0001: 12 | print( 13 | "Error edge_vec_0_distance: {}".format( 14 | torch.min(edge_vec_0_distance) 15 | ) 16 | ) 17 | 18 | norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) 19 | 20 | edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5 21 | edge_vec_2 = edge_vec_2 / ( 22 | torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1) 23 | ) 24 | # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x 25 | # With two 90 degree rotated vectors, at least one should not be aligned with norm_x 26 | edge_vec_2b = edge_vec_2.clone() 27 | edge_vec_2b[:, 0] = -edge_vec_2[:, 1] 28 | edge_vec_2b[:, 1] = edge_vec_2[:, 0] 29 | edge_vec_2c = edge_vec_2.clone() 30 | edge_vec_2c[:, 1] = -edge_vec_2[:, 2] 31 | edge_vec_2c[:, 2] = edge_vec_2[:, 1] 32 | vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view( 33 | -1, 1 34 | ) 35 | vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view( 36 | -1, 1 37 | ) 38 | 39 | vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) 40 | edge_vec_2 = torch.where( 41 | torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2 42 | ) 43 | vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) 44 | edge_vec_2 = torch.where( 45 | torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2 46 | ) 47 | 48 | vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)) 49 | 50 | # Check the vectors aren't aligned 51 | assert torch.max(vec_dot) < 0.99 52 | 53 | norm_z = torch.cross(norm_x, edge_vec_2, dim=1) 54 | norm_z = norm_z / ( 55 | torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True)) 56 | ) 57 | norm_z = norm_z / ( 58 | torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1) 59 | ) 60 | norm_y = torch.cross(norm_x, norm_z, dim=1) 61 | norm_y = norm_y / ( 62 | torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True)) 63 | ) 
64 | 65 | # Construct the 3D rotation matrix 66 | norm_x = norm_x.view(-1, 3, 1) 67 | norm_y = -norm_y.view(-1, 3, 1) 68 | norm_z = norm_z.view(-1, 3, 1) 69 | 70 | edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) 71 | edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) 72 | 73 | return edge_rot_mat.detach() -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/trainer/make_lmdb_sizes.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script provides the functionality to generate metadata.npz files necessary 3 | for load_balancing the DataLoader. 4 | 5 | 1. Copy from: https://github.com/Open-Catalyst-Project/ocp/blob/09f0d3bdf4c9154c2f11105eb5d61e1cc0d8c638/scripts/make_lmdb_sizes.py 6 | 2. Used for generating sizes of data for S2EF task. Since the data is generated with 7 | PyG 2+, the code for generating sizes of data should be modified. 8 | 9 | """ 10 | 11 | 12 | import argparse 13 | import multiprocessing as mp 14 | import os 15 | import warnings 16 | 17 | import numpy as np 18 | from tqdm import tqdm 19 | 20 | #from ocpmodels.datasets import SinglePointLmdbDataset, TrajectoryLmdbDataset 21 | from lmdb_dataset import SinglePointLmdbDatasetV2, TrajectoryLmdbDatasetV2 22 | 23 | 24 | def get_data(index): 25 | data = dataset[index] 26 | natoms = data.natoms 27 | neighbors = None 28 | if hasattr(data, "edge_index") and data.edge_index is not None: 29 | neighbors = data.edge_index.shape[1] 30 | 31 | return index, natoms, neighbors 32 | 33 | 34 | def main(args): 35 | path = args.data_path 36 | global dataset 37 | if os.path.isdir(path): 38 | dataset = TrajectoryLmdbDatasetV2({"src": path}) 39 | outpath = os.path.join(path, "metadata.npz") 40 | elif os.path.isfile(path): 41 | dataset = SinglePointLmdbDatasetV2({"src": path}) 42 | outpath = os.path.join(os.path.dirname(path), "metadata.npz") 43 | 44 | indices = range(len(dataset)) 45 | 46 | pool = mp.Pool(args.num_workers) 47 | outputs = list(tqdm(pool.imap(get_data, indices), total=len(indices))) 48 | 49 | indices = [] 50 | natoms = [] 51 | neighbors = [] 52 | for i in outputs: 53 | indices.append(i[0]) 54 | natoms.append(i[1]) 55 | neighbors.append(i[2]) 56 | 57 | _sort = np.argsort(indices) 58 | sorted_natoms = np.array(natoms, dtype=np.int32)[_sort] 59 | if None in neighbors: 60 | warnings.warn( 61 | f"edge_index information not found, {outpath} only supports atom-wise load balancing." 
62 | ) 63 | np.savez(outpath, natoms=sorted_natoms) 64 | else: 65 | sorted_neighbors = np.array(neighbors, dtype=np.int32)[_sort] 66 | np.savez(outpath, natoms=sorted_natoms, neighbors=sorted_neighbors) 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument( 72 | "--data-path", 73 | required=True, 74 | type=str, 75 | help="Path to S2EF directory or IS2R* .lmdb file", 76 | ) 77 | parser.add_argument( 78 | "--num-workers", 79 | default=1, 80 | type=int, 81 | help="Num of workers to parallelize across", 82 | ) 83 | args = parser.parse_args() 84 | main(args) -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/modules/loss.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from ocpmodels.common import distutils 7 | 8 | 9 | class L2MAELoss(nn.Module): 10 | def __init__(self, reduction="mean"): 11 | super().__init__() 12 | self.reduction = reduction 13 | assert reduction in ["mean", "sum"] 14 | 15 | def forward(self, input: torch.Tensor, target: torch.Tensor): 16 | dists = torch.norm(input - target, p=2, dim=-1) 17 | if self.reduction == "mean": 18 | return torch.mean(dists) 19 | elif self.reduction == "sum": 20 | return torch.sum(dists) 21 | 22 | 23 | class AtomwiseL2Loss(nn.Module): 24 | def __init__(self, reduction="mean"): 25 | super().__init__() 26 | self.reduction = reduction 27 | assert reduction in ["mean", "sum"] 28 | 29 | def forward( 30 | self, 31 | input: torch.Tensor, 32 | target: torch.Tensor, 33 | natoms: torch.Tensor, 34 | ): 35 | assert natoms.shape[0] == input.shape[0] == target.shape[0] 36 | assert len(natoms.shape) == 1 # (nAtoms, ) 37 | 38 | dists = torch.norm(input - target, p=2, dim=-1) 39 | loss = natoms * dists 40 | 41 | if self.reduction == "mean": 42 | return torch.mean(loss) 43 | elif self.reduction == "sum": 44 | return torch.sum(loss) 45 | 46 | 47 | class DDPLoss(nn.Module): 48 | def __init__(self, loss_fn, reduction="mean"): 49 | super().__init__() 50 | self.loss_fn = loss_fn 51 | self.loss_fn.reduction = "sum" 52 | self.reduction = reduction 53 | assert reduction in ["mean", "sum"] 54 | 55 | def forward( 56 | self, 57 | input: torch.Tensor, 58 | target: torch.Tensor, 59 | natoms: torch.Tensor = None, 60 | batch_size: int = None, 61 | ): 62 | # zero out nans, if any 63 | found_nans_or_infs = not torch.all(input.isfinite()) 64 | if found_nans_or_infs is True: 65 | logging.warning("Found nans while computing loss") 66 | input = torch.nan_to_num(input, nan=0.0) 67 | 68 | if natoms is None: 69 | loss = self.loss_fn(input, target) 70 | else: # atom-wise loss 71 | loss = self.loss_fn(input, target, natoms) 72 | if self.reduction == "mean": 73 | num_samples = ( 74 | batch_size if batch_size is not None else input.shape[0] 75 | ) 76 | num_samples = distutils.all_reduce( 77 | num_samples, device=input.device 78 | ) 79 | # Multiply by world size since gradients are averaged 80 | # across DDP replicas 81 | return loss * distutils.get_world_size() / num_samples 82 | else: 83 | return loss 84 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet/layers/embedding_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from .base_layers import Dense 12 | 13 | 14 | class AtomEmbedding(torch.nn.Module): 15 | """ 16 | Initial atom embeddings based on the atom type 17 | 18 | Parameters 19 | ---------- 20 | emb_size: int 21 | Atom embeddings size 22 | """ 23 | 24 | def __init__(self, emb_size, num_elements): 25 | super().__init__() 26 | self.emb_size = emb_size 27 | 28 | self.embeddings = torch.nn.Embedding(num_elements, emb_size) 29 | # init by uniform distribution 30 | torch.nn.init.uniform_( 31 | self.embeddings.weight, a=-np.sqrt(3), b=np.sqrt(3) 32 | ) 33 | 34 | def forward(self, Z): 35 | """ 36 | Returns 37 | ------- 38 | h: torch.Tensor, shape=(nAtoms, emb_size) 39 | Atom embeddings. 40 | """ 41 | h = self.embeddings(Z - 1) # -1 because Z.min()=1 (==Hydrogen) 42 | return h 43 | 44 | 45 | class EdgeEmbedding(torch.nn.Module): 46 | """ 47 | Edge embedding based on the concatenation of atom embeddings and subsequent dense layer. 48 | 49 | Parameters 50 | ---------- 51 | emb_size: int 52 | Embedding size after the dense layer. 53 | activation: str 54 | Activation function used in the dense layer. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | atom_features, 60 | edge_features, 61 | out_features, 62 | activation=None, 63 | ): 64 | super().__init__() 65 | in_features = 2 * atom_features + edge_features 66 | self.dense = Dense( 67 | in_features, out_features, activation=activation, bias=False 68 | ) 69 | 70 | def forward( 71 | self, 72 | h, 73 | m_rbf, 74 | idx_s, 75 | idx_t, 76 | ): 77 | """ 78 | 79 | Arguments 80 | --------- 81 | h 82 | m_rbf: shape (nEdges, nFeatures) 83 | in embedding block: m_rbf = rbf ; In interaction block: m_rbf = m_st 84 | idx_s 85 | idx_t 86 | 87 | Returns 88 | ------- 89 | m_st: torch.Tensor, shape=(nEdges, emb_size) 90 | Edge embeddings. 91 | """ 92 | h_s = h[idx_s] # shape=(nEdges, emb_size) 93 | h_t = h[idx_t] # shape=(nEdges, emb_size) 94 | 95 | m_st = torch.cat( 96 | [h_s, h_t, m_rbf], dim=-1 97 | ) # (nEdges, 2*emb_size+nFeatures) 98 | m_st = self.dense(m_st) # (nEdges, emb_size) 99 | return m_st 100 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_gp/layers/embedding_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from ocpmodels.common import gp_utils 12 | 13 | from .base_layers import Dense 14 | 15 | 16 | class AtomEmbedding(torch.nn.Module): 17 | """ 18 | Initial atom embeddings based on the atom type 19 | 20 | Parameters 21 | ---------- 22 | emb_size: int 23 | Atom embeddings size 24 | """ 25 | 26 | def __init__(self, emb_size): 27 | super().__init__() 28 | self.emb_size = emb_size 29 | 30 | # Atom embeddings: We go up to Bi (83). 31 | self.embeddings = torch.nn.Embedding(83, emb_size) 32 | # init by uniform distribution 33 | torch.nn.init.uniform_( 34 | self.embeddings.weight, a=-np.sqrt(3), b=np.sqrt(3) 35 | ) 36 | 37 | def forward(self, Z): 38 | """ 39 | Returns 40 | ------- 41 | h: torch.Tensor, shape=(nAtoms, emb_size) 42 | Atom embeddings. 
43 | """ 44 | h = self.embeddings(Z - 1) # -1 because Z.min()=1 (==Hydrogen) 45 | return h 46 | 47 | 48 | class EdgeEmbedding(torch.nn.Module): 49 | """ 50 | Edge embedding based on the concatenation of atom embeddings and subsequent dense layer. 51 | 52 | Parameters 53 | ---------- 54 | emb_size: int 55 | Embedding size after the dense layer. 56 | activation: str 57 | Activation function used in the dense layer. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | atom_features, 63 | edge_features, 64 | out_features, 65 | activation=None, 66 | ): 67 | super().__init__() 68 | in_features = 2 * atom_features + edge_features 69 | self.dense = Dense( 70 | in_features, out_features, activation=activation, bias=False 71 | ) 72 | 73 | def forward( 74 | self, 75 | h, 76 | m_rbf, 77 | idx_s, 78 | idx_t, 79 | ): 80 | """ 81 | 82 | Arguments 83 | --------- 84 | h 85 | m_rbf: shape (nEdges, nFeatures) 86 | in embedding block: m_rbf = rbf ; In interaction block: m_rbf = m_st 87 | idx_s 88 | idx_t 89 | 90 | Returns 91 | ------- 92 | m_st: torch.Tensor, shape=(nEdges, emb_size) 93 | Edge embeddings. 94 | """ 95 | h = gp_utils.gather_from_model_parallel_region(h, dim=0) 96 | h_s = h[idx_s] # shape=(nEdges, emb_size) 97 | h_t = h[idx_t] # shape=(nEdges, emb_size) 98 | 99 | m_st = torch.cat( 100 | [h_s, h_t, m_rbf], dim=-1 101 | ) # (nEdges, 2*emb_size+nFeatures) 102 | m_st = self.dense(m_st) # (nEdges, emb_size) 103 | return m_st 104 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # September 3, 2025 (v0.2.4) 2 | ### Model loading and repository optimization 3 | - Added automatic model downloading from HuggingFace Hub with `load_model()`, `get_model_info()`, and `clear_model_cache()` functions 4 | - Removed model weights from git history to reduce repository size - **users should re-clone the repository** 5 | - Added ability for interrupting inference with improved UI to Streamlit app 6 | 7 | # August 29, 2025 (v0.2.3) 8 | ### Add Streamlit app for demonstrations 9 | - Added an easy-to-use app for demonstration purposes 10 | 11 | # August 29, 2025 (v0.2.2) 12 | ### Refresh inference code 13 | - Refactored ShEPhERD inference code to be more modular (backwards compatible). 14 | - The original inference code can still be imported: `from shepherd.inference import inference_sample` 15 | - New inference functions can be imported with: 16 | - `from shepherd.inference import generate` 17 | - `from shepherd.inference import generate_from_intermediate_time` 18 | - Inference now supports atom and bond inpainting 19 | - `generate` is updated to allow atom inpainting from t=T 20 | - `generate_from_intermediate_time` is specialized to allow atom inpainting from an intermediate time (T ≤ t < 0) 21 | - Inference can store full diffusion trajectories by setting `return_trajectories=True` during sampling. 22 | - `shepherd.extract` 23 | - Added `remove_side_groups_with_geometry` 24 | - Added `remove_overlaps` to quickly filter sampled molecules that use atom-inpainting. 25 | 26 | 27 | # June 5, 2025 (v0.2.0) 28 | ### Refactoring and upgrades for PyTorch >= v2.5.1 29 | 30 | - Refactored ShEPhERD (aided by Matthew Cox's fork: https://github.com/mcox3406/shepherd/) 31 | - Updated import statements: throughout repo to import directly from `shepherd` assuming local install. 
32 |     - Fixed deprecation warnings:
33 |         - `torch.load()` -> `torch.load(weights_only=True)`
34 |         - `@torch.cuda.amp.autocast(enabled=False)` -> `@torch.amp.autocast('cuda', enabled=False)`
35 | - Training scripts
36 |     - Updated `src/shepherd/datasets.py` for higher versions of PyG. Required changes to the batching functionality for edges (still backwards compatible).
37 |     - Slight changes to `training/train.py` for upgraded versions of PyTorch Lightning.
38 | - Model checkpoints have been UPDATED for PyTorch Lightning v2.5.1
39 |     - The original checkpoints for PyTorch Lightning v1.2 can be found in previous commits (`c3d5ec0` or earlier), in the original publication's Release, or at the Dropbox data link: https://www.dropbox.com/scl/fo/rgn33g9kwthnjt27bsc3m/ADGt-CplyEXSU7u5MKc0aTo?rlkey=fhi74vkktpoj1irl84ehnw95h&e=1&st=wn46d6o2&dl=0
40 | - Created a basic unconditional generation test script
41 | - Updated the environment and relevant files to be compatible with PyTorch >= v2.5.1
42 | - Bug fix for `shepherd.datasets.HeteroDataset.__getitem__` where x3 point extraction should use `get_x2_data`
43 | 
44 | #### Additional notes
45 | Thank you to Matthew Cox for his contributions to the updated code.
46 | 
47 | 
48 | # January 13, 2025
49 | - Added the ability to do partial inpainting for pharmacophores at inference.
--------------------------------------------------------------------------------
/src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_oc/initializers.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | This source code is licensed under the MIT license found in the
4 | LICENSE file in the root directory of this source tree.
5 | """
6 | 
7 | from functools import partial
8 | 
9 | import torch
10 | 
11 | 
12 | def _standardize(kernel):
13 |     """
14 |     Makes sure that N*Var(W) = 1 and E[W] = 0
15 |     """
16 |     eps = 1e-6
17 | 
18 |     if len(kernel.shape) == 3:
19 |         axis = [0, 1]  # last dimension is output dimension
20 |     else:
21 |         axis = 1
22 | 
23 |     var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True)
24 |     kernel = (kernel - mean) / (var + eps) ** 0.5
25 |     return kernel
26 | 
27 | 
28 | def he_orthogonal_init(tensor):
29 |     """
30 |     Generate a weight matrix with variance according to He (Kaiming) initialization.
31 |     Based on a random (semi-)orthogonal matrix. Neural networks
32 |     are expected to learn better when their features are decorrelated,
33 |     as stated in e.g. "Reducing overfitting in deep networks by decorrelating representations",
34 |     "Dropout: a simple way to prevent neural networks from overfitting", and
35 |     "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks".
36 |     """
37 |     tensor = torch.nn.init.orthogonal_(tensor)
38 | 
39 |     if len(tensor.shape) == 3:
40 |         fan_in = tensor.shape[:-1].numel()
41 |     else:
42 |         fan_in = tensor.shape[1]
43 | 
44 |     with torch.no_grad():
45 |         tensor.data = _standardize(tensor.data)
46 |         tensor.data *= (1 / fan_in) ** 0.5
47 | 
48 |     return tensor
49 | 
50 | 
51 | def grid_init(tensor, start=-1, end=1):
52 |     """
53 |     Generate a weight matrix so that each input value corresponds to one value on a regular grid between start and end.
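A quick sanity check of `he_orthogonal_init` above: after the orthogonal draw and `_standardize`, each output row should have mean ~0 and variance ~1/fan_in, which is the He condition `fan_in * Var(W) = 1`. A sketch (the import path follows this vendored layout and may differ in other installs):

```python
import torch
from shepherd.model.equiformer_v2.ocpmodels.models.gemnet_oc.initializers import (
    he_orthogonal_init,
)

w = torch.empty(64, 256)           # (out_features, fan_in), as in torch.nn.Linear
he_orthogonal_init(w)
print(w.mean().item())             # ~0
print(w.var(dim=1).mean().item())  # ~1/256, so fan_in * Var(W) ~ 1
```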
54 | """ 55 | fan_in = tensor.shape[1] 56 | 57 | with torch.no_grad(): 58 | data = torch.linspace( 59 | start, end, fan_in, device=tensor.device, dtype=tensor.dtype 60 | ).expand_as(tensor) 61 | tensor.copy_(data) 62 | 63 | return tensor 64 | 65 | 66 | def log_grid_init(tensor, start=-4, end=0): 67 | """ 68 | Generate a weight matrix so that each input value corresponds to one value on a regular logarithmic grid between 10^start and 10^end. 69 | """ 70 | fan_in = tensor.shape[1] 71 | 72 | with torch.no_grad(): 73 | data = torch.logspace( 74 | start, end, fan_in, device=tensor.device, dtype=tensor.dtype 75 | ).expand_as(tensor) 76 | tensor.copy_(data) 77 | 78 | return tensor 79 | 80 | 81 | def get_initializer(name, **init_kwargs): 82 | name = name.lower() 83 | if name == "heorthogonal": 84 | initializer = he_orthogonal_init 85 | elif name == "zeros": 86 | initializer = torch.nn.init.zeros_ 87 | elif name == "grid": 88 | initializer = grid_init 89 | elif name == "loggrid": 90 | initializer = log_grid_init 91 | else: 92 | raise UserWarning(f"Unknown initializer: {name}") 93 | 94 | initializer = partial(initializer, **init_kwargs) 95 | return initializer 96 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_oc/layers/embedding_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | This source code is licensed under the MIT license found in the 4 | LICENSE file in the root directory of this source tree. 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .base_layers import Dense 11 | 12 | 13 | class AtomEmbedding(torch.nn.Module): 14 | """ 15 | Initial atom embeddings based on the atom type 16 | 17 | Arguments 18 | --------- 19 | emb_size: int 20 | Atom embeddings size 21 | """ 22 | 23 | def __init__(self, emb_size, num_elements): 24 | super().__init__() 25 | self.emb_size = emb_size 26 | 27 | self.embeddings = torch.nn.Embedding(num_elements, emb_size) 28 | # init by uniform distribution 29 | torch.nn.init.uniform_( 30 | self.embeddings.weight, a=-np.sqrt(3), b=np.sqrt(3) 31 | ) 32 | 33 | def forward(self, Z): 34 | """ 35 | Returns 36 | ------- 37 | h: torch.Tensor, shape=(nAtoms, emb_size) 38 | Atom embeddings. 39 | """ 40 | h = self.embeddings(Z - 1) # -1 because Z.min()=1 (==Hydrogen) 41 | return h 42 | 43 | 44 | class EdgeEmbedding(torch.nn.Module): 45 | """ 46 | Edge embedding based on the concatenation of atom embeddings 47 | and a subsequent dense layer. 48 | 49 | Arguments 50 | --------- 51 | atom_features: int 52 | Embedding size of the atom embedding. 53 | edge_features: int 54 | Embedding size of the input edge embedding. 55 | out_features: int 56 | Embedding size after the dense layer. 57 | activation: str 58 | Activation function used in the dense layer. 59 | """ 60 | 61 | def __init__( 62 | self, 63 | atom_features, 64 | edge_features, 65 | out_features, 66 | activation=None, 67 | ): 68 | super().__init__() 69 | in_features = 2 * atom_features + edge_features 70 | self.dense = Dense( 71 | in_features, out_features, activation=activation, bias=False 72 | ) 73 | 74 | def forward( 75 | self, 76 | h, 77 | m, 78 | edge_index, 79 | ): 80 | """ 81 | Arguments 82 | --------- 83 | h: torch.Tensor, shape (num_atoms, atom_features) 84 | Atom embeddings. 85 | m: torch.Tensor, shape (num_edges, edge_features) 86 | Radial basis in embedding block, 87 | edge embedding in interaction block. 
88 | 89 | Returns 90 | ------- 91 | m_st: torch.Tensor, shape=(nEdges, emb_size) 92 | Edge embeddings. 93 | """ 94 | h_s = h[edge_index[0]] # shape=(nEdges, emb_size) 95 | h_t = h[edge_index[1]] # shape=(nEdges, emb_size) 96 | 97 | m_st = torch.cat( 98 | [h_s, h_t, m], dim=-1 99 | ) # (nEdges, 2*emb_size+nFeatures) 100 | m_st = self.dense(m_st) # (nEdges, emb_size) 101 | return m_st 102 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/common/relaxation/ml_relaxation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import logging 9 | from collections import deque 10 | from pathlib import Path 11 | 12 | import torch 13 | from torch_geometric.data import Batch 14 | 15 | from ocpmodels.common.registry import registry 16 | from ocpmodels.datasets.lmdb_dataset import data_list_collater 17 | 18 | from .optimizers.lbfgs_torch import LBFGS, TorchCalc 19 | 20 | 21 | def ml_relax( 22 | batch, 23 | model, 24 | steps, 25 | fmax, 26 | relax_opt, 27 | save_full_traj, 28 | device="cuda:0", 29 | transform=None, 30 | early_stop_batch=False, 31 | ): 32 | """ 33 | Runs ML-based relaxations. 34 | Args: 35 | batch: object 36 | model: object 37 | steps: int 38 | Max number of steps in the structure relaxation. 39 | fmax: float 40 | Structure relaxation terminates when the max force 41 | of the system is no bigger than fmax. 42 | relax_opt: str 43 | Optimizer and corresponding parameters to be used for structure relaxations. 44 | save_full_traj: bool 45 | Whether to save out the full ASE trajectory. If False, only save out initial and final frames. 46 | """ 47 | batches = deque([batch[0]]) 48 | relaxed_batches = [] 49 | while batches: 50 | batch = batches.popleft() 51 | oom = False 52 | ids = batch.sid 53 | calc = TorchCalc(model, transform) 54 | 55 | # Run ML-based relaxation 56 | traj_dir = relax_opt.get("traj_dir", None) 57 | optimizer = LBFGS( 58 | batch, 59 | calc, 60 | maxstep=relax_opt.get("maxstep", 0.04), 61 | memory=relax_opt["memory"], 62 | damping=relax_opt.get("damping", 1.0), 63 | alpha=relax_opt.get("alpha", 70.0), 64 | device=device, 65 | save_full_traj=save_full_traj, 66 | traj_dir=Path(traj_dir) if traj_dir is not None else None, 67 | traj_names=ids, 68 | early_stop_batch=early_stop_batch, 69 | ) 70 | try: 71 | relaxed_batch = optimizer.run(fmax=fmax, steps=steps) 72 | relaxed_batches.append(relaxed_batch) 73 | except RuntimeError as e: 74 | oom = True 75 | torch.cuda.empty_cache() 76 | 77 | if oom: 78 | # move OOM recovery code outside of except clause to allow tensors to be freed. 79 | data_list = batch.to_data_list() 80 | if len(data_list) == 1: 81 | raise e 82 | logging.info( 83 | f"Failed to relax batch with size: {len(data_list)}, splitting into two..." 84 | ) 85 | mid = len(data_list) // 2 86 | batches.appendleft(data_list_collater(data_list[:mid])) 87 | batches.appendleft(data_list_collater(data_list[mid:])) 88 | 89 | relaxed_batch = Batch.from_data_list(relaxed_batches) 90 | return relaxed_batch 91 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/common/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. 
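One caveat in `ml_relax` above: Python 3 unbinds the name from `except RuntimeError as e:` when the except block exits, so the later `raise e` under `if oom:` would itself fail with a `NameError`. A minimal restructuring that keeps the recovery logic outside the except clause (so tensors can still be freed) while holding on to the exception:

```python
oom_error = None
try:
    relaxed_batch = optimizer.run(fmax=fmax, steps=steps)
    relaxed_batches.append(relaxed_batch)
except RuntimeError as e:
    oom_error = e  # keep a reference; `e` is unbound after this block
    torch.cuda.empty_cache()

if oom_error is not None:
    data_list = batch.to_data_list()
    if len(data_list) == 1:
        raise oom_error
    logging.info(
        f"Failed to relax batch with size: {len(data_list)}, splitting into two..."
    )
    mid = len(data_list) // 2
    batches.appendleft(data_list_collater(data_list[:mid]))
    batches.appendleft(data_list_collater(data_list[mid:]))
```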
and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | # Borrowed from https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/transforms/random_rotate.py 9 | # with changes to keep track of the rotation / inverse rotation matrices. 10 | 11 | import math 12 | import numbers 13 | import random 14 | 15 | import torch 16 | import torch_geometric 17 | from torch_geometric.transforms import LinearTransformation 18 | 19 | 20 | class RandomRotate(object): 21 | r"""Rotates node positions around a specific axis by a randomly sampled 22 | factor within a given interval. 23 | 24 | Args: 25 | degrees (tuple or float): Rotation interval from which the rotation 26 | angle is sampled. If `degrees` is a number instead of a 27 | tuple, the interval is given by :math:`[-\mathrm{degrees}, 28 | \mathrm{degrees}]`. 29 | axes (int, optional): The rotation axes. (default: `[0, 1, 2]`) 30 | """ 31 | 32 | def __init__(self, degrees, axes=[0, 1, 2]): 33 | if isinstance(degrees, numbers.Number): 34 | degrees = (-abs(degrees), abs(degrees)) 35 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 36 | self.degrees = degrees 37 | self.axes = axes 38 | 39 | def __call__(self, data): 40 | if data.pos.size(-1) == 2: 41 | degree = math.pi * random.uniform(*self.degrees) / 180.0 42 | sin, cos = math.sin(degree), math.cos(degree) 43 | matrix = [[cos, sin], [-sin, cos]] 44 | else: 45 | m1, m2, m3 = torch.eye(3), torch.eye(3), torch.eye(3) 46 | if 0 in self.axes: 47 | degree = math.pi * random.uniform(*self.degrees) / 180.0 48 | sin, cos = math.sin(degree), math.cos(degree) 49 | m1 = torch.tensor([[1, 0, 0], [0, cos, sin], [0, -sin, cos]]) 50 | if 1 in self.axes: 51 | degree = math.pi * random.uniform(*self.degrees) / 180.0 52 | sin, cos = math.sin(degree), math.cos(degree) 53 | m2 = torch.tensor([[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]) 54 | if 2 in self.axes: 55 | degree = math.pi * random.uniform(*self.degrees) / 180.0 56 | sin, cos = math.sin(degree), math.cos(degree) 57 | m3 = torch.tensor([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]) 58 | 59 | matrix = torch.mm(torch.mm(m1, m2), m3) 60 | 61 | data_rotated = LinearTransformation(matrix)(data) 62 | if torch_geometric.__version__.startswith("2."): 63 | matrix = matrix.T 64 | 65 | # LinearTransformation only rotates `.pos`; need to rotate `.cell` too. 66 | if hasattr(data_rotated, "cell"): 67 | data_rotated.cell = torch.matmul(data_rotated.cell, matrix) 68 | 69 | return ( 70 | data_rotated, 71 | matrix, 72 | torch.inverse(matrix), 73 | ) 74 | 75 | def __repr__(self): 76 | return "{}({}, axis={})".format( 77 | self.__class__.__name__, self.degrees, self.axis 78 | ) 79 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_oc/layers/base_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | This source code is licensed under the MIT license found in the 4 | LICENSE file in the root directory of this source tree. 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from ..initializers import he_orthogonal_init 12 | 13 | 14 | class Dense(torch.nn.Module): 15 | """ 16 | Combines dense layer with scaling for silu activation. 17 | 18 | Arguments 19 | --------- 20 | in_features: int 21 | Input embedding size. 
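Two notes on `RandomRotate` above. First, `__repr__` formats `self.axis`, but the attribute is named `self.axes`, so printing the transform raises `AttributeError`; the format call should read `axes={}` with `self.axes`. Second, thanks to the version-dependent transpose, the returned matrices satisfy `rotated.pos == pos @ matrix` in both PyG 1.x and 2.x, which makes the round-trip easy to verify (a sketch, assuming the class above is in scope and a recent PyG):

```python
import torch
from torch_geometric.data import Data

transform = RandomRotate(degrees=180, axes=[0, 1, 2])
pos = torch.randn(10, 3)
rotated, rot, inv_rot = transform(Data(pos=pos.clone()))  # clone: transform mutates pos

# The inverse matrix restores the original coordinates (up to float error).
print(torch.allclose(rotated.pos @ inv_rot, pos, atol=1e-5))  # True
```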
22 | out_features: int 23 | Output embedding size. 24 | bias: bool 25 | True if use bias. 26 | activation: str 27 | Name of the activation function to use. 28 | """ 29 | 30 | def __init__(self, in_features, out_features, bias=False, activation=None): 31 | super().__init__() 32 | 33 | self.linear = torch.nn.Linear(in_features, out_features, bias=bias) 34 | self.reset_parameters() 35 | 36 | if isinstance(activation, str): 37 | activation = activation.lower() 38 | if activation in ["silu", "swish"]: 39 | self._activation = ScaledSiLU() 40 | elif activation is None: 41 | self._activation = torch.nn.Identity() 42 | else: 43 | raise NotImplementedError( 44 | "Activation function not implemented for GemNet (yet)." 45 | ) 46 | 47 | def reset_parameters(self, initializer=he_orthogonal_init): 48 | initializer(self.linear.weight) 49 | if self.linear.bias is not None: 50 | self.linear.bias.data.fill_(0) 51 | 52 | def forward(self, x): 53 | x = self.linear(x) 54 | x = self._activation(x) 55 | return x 56 | 57 | 58 | class ScaledSiLU(torch.nn.Module): 59 | def __init__(self): 60 | super().__init__() 61 | self.scale_factor = 1 / 0.6 62 | self._activation = torch.nn.SiLU() 63 | 64 | def forward(self, x): 65 | return self._activation(x) * self.scale_factor 66 | 67 | 68 | class ResidualLayer(torch.nn.Module): 69 | """ 70 | Residual block with output scaled by 1/sqrt(2). 71 | 72 | Arguments 73 | --------- 74 | units: int 75 | Input and output embedding size. 76 | nLayers: int 77 | Number of dense layers. 78 | layer: torch.nn.Module 79 | Class for the layers inside the residual block. 80 | layer_kwargs: str 81 | Keyword arguments for initializing the layers. 82 | """ 83 | 84 | def __init__( 85 | self, units: int, nLayers: int = 2, layer=Dense, **layer_kwargs 86 | ): 87 | super().__init__() 88 | self.dense_mlp = torch.nn.Sequential( 89 | *[ 90 | layer( 91 | in_features=units, 92 | out_features=units, 93 | bias=False, 94 | **layer_kwargs 95 | ) 96 | for _ in range(nLayers) 97 | ] 98 | ) 99 | self.inv_sqrt_2 = 1 / math.sqrt(2) 100 | 101 | def forward(self, input): 102 | x = self.dense_mlp(input) 103 | x = input + x 104 | x = x * self.inv_sqrt_2 105 | return x 106 | -------------------------------------------------------------------------------- /examples/data/WX7.sdf: -------------------------------------------------------------------------------- 1 | 2 | RDKit 3D 3 | 4 | 35 38 0 0 0 0 0 0 0 0999 V2000 5 | -43.6900 -28.8990 2.2890 C 0 0 0 0 0 0 0 0 0 0 0 0 6 | -45.1650 -29.3540 4.3410 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -44.1700 -29.5510 6.4100 C 0 0 0 0 0 0 0 0 0 0 0 0 8 | -46.4750 -29.1960 6.3370 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -47.7650 -28.9520 4.5790 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -46.4280 -29.1750 4.9730 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -44.3210 -33.2900 1.1720 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -45.0900 -31.2530 0.4810 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -45.6130 -29.7520 0.4620 C 0 0 2 0 0 0 0 0 0 0 0 0 14 | -44.5300 -28.9420 -0.0470 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -43.8670 -28.1440 0.9670 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | -46.1150 -29.3030 1.8580 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | -48.5420 -28.8630 5.7300 C 0 0 0 0 0 0 0 0 0 0 0 0 18 | -44.4050 -33.3150 -0.2720 N 0 0 0 0 0 0 0 0 0 0 0 0 19 | -44.8720 -32.0710 -0.6930 N 0 0 0 0 0 0 0 0 0 0 0 0 20 | -45.0070 -29.3270 2.9010 N 0 0 0 0 0 0 0 0 0 0 0 0 21 | -44.0800 -29.5310 5.0880 N 0 0 0 0 0 0 0 0 0 0 0 0 22 | -45.3230 -29.3950 7.0110 N 0 0 0 0 0 0 0 0 0 0 0 0 23 | -47.7510 -29.0100 6.7810 N 0 0 0 0 0 0 0 0 0 0 0 0 24 | -43.9620 -34.1830 1.8620 O 0 0 0 0 0 0 0 0 0 0 0 0 25 | -44.7420 
-32.0170 1.6170 O 0 0 0 0 0 0 0 0 0 0 0 0 26 | -43.1330 -29.8090 2.0760 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -43.1390 -28.2930 2.9950 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -43.2860 -29.6930 7.0180 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | -48.1150 -28.8690 3.5620 H 0 0 0 0 0 0 0 0 0 0 0 0 30 | -46.4370 -29.6620 -0.2330 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | -43.7900 -29.5950 -0.5120 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | -44.9270 -28.2730 -0.7960 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | -42.8850 -27.8980 0.6020 H 0 0 0 0 0 0 0 0 0 0 0 0 34 | -44.4280 -27.2240 1.1340 H 0 0 0 0 0 0 0 0 0 0 0 0 35 | -46.9230 -29.9500 2.1790 H 0 0 0 0 0 0 0 0 0 0 0 0 36 | -46.4960 -28.2870 1.7710 H 0 0 0 0 0 0 0 0 0 0 0 0 37 | -49.6040 -28.7040 5.7610 H 0 0 0 0 0 0 0 0 0 0 0 0 38 | -44.1710 -34.0880 -0.8660 H 0 0 0 0 0 0 0 0 0 0 0 0 39 | -48.0410 -28.9910 7.7420 H 0 0 0 0 0 0 0 0 0 0 0 0 40 | 6 5 1 0 41 | 6 2 2 0 42 | 6 4 1 0 43 | 9 8 1 0 44 | 10 9 1 0 45 | 11 10 1 0 46 | 11 1 1 0 47 | 12 9 1 0 48 | 13 5 2 0 49 | 14 7 1 0 50 | 15 14 1 0 51 | 15 8 2 0 52 | 16 12 1 0 53 | 16 2 1 0 54 | 16 1 1 0 55 | 17 2 1 0 56 | 17 3 2 0 57 | 18 4 2 0 58 | 18 3 1 0 59 | 19 13 1 0 60 | 19 4 1 0 61 | 20 7 2 0 62 | 21 7 1 0 63 | 21 8 1 0 64 | 22 1 1 0 65 | 23 1 1 0 66 | 24 3 1 0 67 | 25 5 1 0 68 | 9 26 1 6 69 | 27 10 1 0 70 | 28 10 1 0 71 | 29 11 1 0 72 | 30 11 1 0 73 | 31 12 1 0 74 | 32 12 1 0 75 | 33 13 1 0 76 | 34 14 1 0 77 | 35 19 1 0 78 | M END 79 | > <_FileComments> (1) 80 | WX7 from mmCIF 81 | 82 | $$$$ 83 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_oc/layers/force_scaler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | This source code is licensed under the MIT license found in the 4 | LICENSE file in the root directory of this source tree. 5 | """ 6 | 7 | import logging 8 | 9 | import torch 10 | 11 | 12 | class ForceScaler: 13 | """ 14 | Scales up the energy and then scales down the forces 15 | to prevent NaNs and infs in calculations using AMP. 16 | Inspired by torch.cuda.amp.GradScaler. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | init_scale=2.0**8, 22 | growth_factor=2.0, 23 | backoff_factor=0.5, 24 | growth_interval=2000, 25 | max_force_iters=50, 26 | enabled=True, 27 | ): 28 | self.scale_factor = init_scale 29 | self.growth_factor = growth_factor 30 | self.backoff_factor = backoff_factor 31 | self.growth_interval = growth_interval 32 | self.max_force_iters = max_force_iters 33 | self.enabled = enabled 34 | self.finite_force_results = 0 35 | 36 | def scale(self, energy): 37 | return energy * self.scale_factor if self.enabled else energy 38 | 39 | def unscale(self, forces): 40 | return forces / self.scale_factor if self.enabled else forces 41 | 42 | def calc_forces(self, energy, pos): 43 | energy_scaled = self.scale(energy) 44 | forces_scaled = -torch.autograd.grad( 45 | energy_scaled, 46 | pos, 47 | grad_outputs=torch.ones_like(energy_scaled), 48 | create_graph=True, 49 | )[0] 50 | # (nAtoms, 3) 51 | forces = self.unscale(forces_scaled) 52 | return forces 53 | 54 | def calc_forces_and_update(self, energy, pos): 55 | if self.enabled: 56 | found_nans_or_infs = True 57 | force_iters = 0 58 | 59 | # Re-calculate forces until everything is nice and finite. 
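Why the scaling in `ForceScaler` is safe: autograd is linear, so unscaling the gradient of the scaled energy recovers the gradient of the unscaled energy exactly (bit-exactly for power-of-two scale factors); the scaling exists only to keep fp16 intermediates finite under AMP. A toy check of the identity in fp32 (the quadratic energy is purely illustrative):

```python
import torch

pos = torch.randn(5, 3, requires_grad=True)
energy = (pos ** 2).sum()
scale = 2.0 ** 8

forces_scaled = -torch.autograd.grad(energy * scale, pos, retain_graph=True)[0] / scale
forces = -torch.autograd.grad(energy, pos)[0]
print(torch.allclose(forces_scaled, forces))  # True: gradients scale linearly
```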
60 | while found_nans_or_infs: 61 | forces = self.calc_forces(energy, pos) 62 | 63 | found_nans_or_infs = not torch.all(forces.isfinite()) 64 | if found_nans_or_infs: 65 | self.finite_force_results = 0 66 | 67 | # Prevent infinite loop 68 | force_iters += 1 69 | if force_iters == self.max_force_iters: 70 | logging.warning( 71 | "Too many non-finite force results in a batch. " 72 | "Breaking scaling loop." 73 | ) 74 | break 75 | else: 76 | # Delete graph to save memory 77 | del forces 78 | else: 79 | self.finite_force_results += 1 80 | self.update() 81 | else: 82 | forces = self.calc_forces(energy, pos) 83 | return forces 84 | 85 | def update(self): 86 | if self.finite_force_results == 0: 87 | self.scale_factor *= self.backoff_factor 88 | 89 | if self.finite_force_results == self.growth_interval: 90 | self.scale_factor *= self.growth_factor 91 | self.finite_force_results = 0 92 | 93 | logging.info(f"finite force step count: {self.finite_force_results}") 94 | logging.info(f"scaling factor: {self.scale_factor}") 95 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/tasks/task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import logging 9 | import os 10 | 11 | from ocpmodels.common.registry import registry 12 | from ocpmodels.trainers.forces_trainer import ForcesTrainer 13 | 14 | 15 | class BaseTask: 16 | def __init__(self, config): 17 | self.config = config 18 | 19 | def setup(self, trainer): 20 | self.trainer = trainer 21 | if self.config["checkpoint"] is not None: 22 | self.trainer.load_checkpoint(self.config["checkpoint"]) 23 | 24 | # save checkpoint path to runner state for slurm resubmissions 25 | self.chkpt_path = os.path.join( 26 | self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt" 27 | ) 28 | 29 | def run(self): 30 | raise NotImplementedError 31 | 32 | 33 | @registry.register_task("train") 34 | class TrainTask(BaseTask): 35 | def _process_error(self, e: RuntimeError): 36 | e_str = str(e) 37 | if ( 38 | "find_unused_parameters" in e_str 39 | and "torch.nn.parallel.DistributedDataParallel" in e_str 40 | ): 41 | for name, parameter in self.trainer.model.named_parameters(): 42 | if parameter.requires_grad and parameter.grad is None: 43 | logging.warning( 44 | f"Parameter {name} has no gradient. Consider removing it from the model." 
45 | ) 46 | 47 | def run(self): 48 | try: 49 | self.trainer.train( 50 | disable_eval_tqdm=self.config.get( 51 | "hide_eval_progressbar", False 52 | ) 53 | ) 54 | except RuntimeError as e: 55 | self._process_error(e) 56 | raise e 57 | 58 | 59 | @registry.register_task("predict") 60 | class PredictTask(BaseTask): 61 | def run(self): 62 | assert ( 63 | self.trainer.test_loader is not None 64 | ), "Test dataset is required for making predictions" 65 | assert self.config["checkpoint"] 66 | results_file = "predictions" 67 | self.trainer.predict( 68 | self.trainer.test_loader, 69 | results_file=results_file, 70 | disable_tqdm=self.config.get("hide_eval_progressbar", False), 71 | ) 72 | 73 | 74 | @registry.register_task("validate") 75 | class ValidateTask(BaseTask): 76 | def run(self): 77 | # Note that the results won't be precise on multi GPUs due to padding of extra images (although the difference should be minor) 78 | assert ( 79 | self.trainer.val_loader is not None 80 | ), "Val dataset is required for making predictions" 81 | assert self.config["checkpoint"] 82 | self.trainer.validate( 83 | split="val", 84 | disable_tqdm=self.config.get("hide_eval_progressbar", False), 85 | ) 86 | 87 | 88 | @registry.register_task("run-relaxations") 89 | class RelxationTask(BaseTask): 90 | def run(self): 91 | assert isinstance( 92 | self.trainer, ForcesTrainer 93 | ), "Relaxations are only possible for ForcesTrainer" 94 | assert ( 95 | self.trainer.relax_dataset is not None 96 | ), "Relax dataset is required for making predictions" 97 | assert self.config["checkpoint"] 98 | self.trainer.run_relaxations() 99 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet/layers/base_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | 12 | from ..initializers import he_orthogonal_init 13 | 14 | 15 | class Dense(torch.nn.Module): 16 | """ 17 | Combines dense layer with scaling for swish activation. 18 | 19 | Parameters 20 | ---------- 21 | units: int 22 | Output embedding size. 23 | activation: str 24 | Name of the activation function to use. 25 | bias: bool 26 | True if use bias. 27 | """ 28 | 29 | def __init__(self, in_features, out_features, bias=False, activation=None): 30 | super().__init__() 31 | 32 | self.linear = torch.nn.Linear(in_features, out_features, bias=bias) 33 | self.reset_parameters() 34 | 35 | if isinstance(activation, str): 36 | activation = activation.lower() 37 | if activation in ["swish", "silu"]: 38 | self._activation = ScaledSiLU() 39 | elif activation == "siqu": 40 | self._activation = SiQU() 41 | elif activation is None: 42 | self._activation = torch.nn.Identity() 43 | else: 44 | raise NotImplementedError( 45 | "Activation function not implemented for GemNet (yet)." 
46 | ) 47 | 48 | def reset_parameters(self, initializer=he_orthogonal_init): 49 | initializer(self.linear.weight) 50 | if self.linear.bias is not None: 51 | self.linear.bias.data.fill_(0) 52 | 53 | def forward(self, x): 54 | x = self.linear(x) 55 | x = self._activation(x) 56 | return x 57 | 58 | 59 | class ScaledSiLU(torch.nn.Module): 60 | def __init__(self): 61 | super().__init__() 62 | self.scale_factor = 1 / 0.6 63 | self._activation = torch.nn.SiLU() 64 | 65 | def forward(self, x): 66 | return self._activation(x) * self.scale_factor 67 | 68 | 69 | class SiQU(torch.nn.Module): 70 | def __init__(self): 71 | super().__init__() 72 | self._activation = torch.nn.SiLU() 73 | 74 | def forward(self, x): 75 | return x * self._activation(x) 76 | 77 | 78 | class ResidualLayer(torch.nn.Module): 79 | """ 80 | Residual block with output scaled by 1/sqrt(2). 81 | 82 | Parameters 83 | ---------- 84 | units: int 85 | Output embedding size. 86 | nLayers: int 87 | Number of dense layers. 88 | layer_kwargs: str 89 | Keyword arguments for initializing the layers. 90 | """ 91 | 92 | def __init__( 93 | self, units: int, nLayers: int = 2, layer=Dense, **layer_kwargs 94 | ): 95 | super().__init__() 96 | self.dense_mlp = torch.nn.Sequential( 97 | *[ 98 | layer( 99 | in_features=units, 100 | out_features=units, 101 | bias=False, 102 | **layer_kwargs 103 | ) 104 | for _ in range(nLayers) 105 | ] 106 | ) 107 | self.inv_sqrt_2 = 1 / math.sqrt(2) 108 | 109 | def forward(self, input): 110 | x = self.dense_mlp(input) 111 | x = input + x 112 | x = x * self.inv_sqrt_2 113 | return x 114 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet/layers/spherical_basis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import sympy as sym 9 | import torch 10 | from torch_geometric.nn.models.schnet import GaussianSmearing 11 | 12 | from .basis_utils import real_sph_harm 13 | from .radial_basis import RadialBasis 14 | 15 | 16 | class CircularBasisLayer(torch.nn.Module): 17 | """ 18 | 2D Fourier Bessel Basis 19 | 20 | Parameters 21 | ---------- 22 | num_spherical: int 23 | Controls maximum frequency. 
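On the constants in `base_layers.py` above: for standard-normal inputs the empirical standard deviation of SiLU is roughly 0.56, which the code rounds to 0.6, so `ScaledSiLU` approximately restores unit variance; likewise `ResidualLayer` divides the sum of two roughly unit-variance branches by sqrt(2). A quick empirical check:

```python
import torch

x = torch.randn(1_000_000)
y = torch.nn.functional.silu(x)
print(y.std())          # ~0.56 for standard-normal input; the code rounds this to 0.6
print((y / 0.6).std())  # ~0.94, i.e. approximately unit variance after ScaledSiLU
```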
24 | radial_basis: RadialBasis 25 | Radial basis functions 26 | cbf: dict 27 | Name and hyperparameters of the cosine basis function 28 | efficient: bool 29 | Whether to use the "efficient" summation order 30 | """ 31 | 32 | def __init__( 33 | self, 34 | num_spherical: int, 35 | radial_basis: RadialBasis, 36 | cbf: str, 37 | efficient: bool = False, 38 | ): 39 | super().__init__() 40 | 41 | self.radial_basis = radial_basis 42 | self.efficient = efficient 43 | 44 | cbf_name = cbf["name"].lower() 45 | cbf_hparams = cbf.copy() 46 | del cbf_hparams["name"] 47 | 48 | if cbf_name == "gaussian": 49 | self.cosφ_basis = GaussianSmearing( 50 | start=-1, stop=1, num_gaussians=num_spherical, **cbf_hparams 51 | ) 52 | elif cbf_name == "spherical_harmonics": 53 | Y_lm = real_sph_harm( 54 | num_spherical, use_theta=False, zero_m_only=True 55 | ) 56 | sph_funcs = [] # (num_spherical,) 57 | 58 | # convert to tensorflow functions 59 | z = sym.symbols("z") 60 | modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt} 61 | m_order = 0 # only single angle 62 | for l_degree in range(len(Y_lm)): # num_spherical 63 | if ( 64 | l_degree == 0 65 | ): # Y_00 is only a constant -> function returns value and not tensor 66 | first_sph = sym.lambdify( 67 | [z], Y_lm[l_degree][m_order], modules 68 | ) 69 | sph_funcs.append( 70 | lambda z: torch.zeros_like(z) + first_sph(z) 71 | ) 72 | else: 73 | sph_funcs.append( 74 | sym.lambdify([z], Y_lm[l_degree][m_order], modules) 75 | ) 76 | self.cosφ_basis = lambda cosφ: torch.stack( 77 | [f(cosφ) for f in sph_funcs], dim=1 78 | ) 79 | else: 80 | raise ValueError(f"Unknown cosine basis function '{cbf_name}'.") 81 | 82 | def forward(self, D_ca, cosφ_cab, id3_ca): 83 | rbf = self.radial_basis(D_ca) # (num_edges, num_radial) 84 | cbf = self.cosφ_basis(cosφ_cab) # (num_triplets, num_spherical) 85 | 86 | if not self.efficient: 87 | rbf = rbf[id3_ca] # (num_triplets, num_radial) 88 | out = (rbf[:, None, :] * cbf[:, :, None]).view( 89 | -1, rbf.shape[-1] * cbf.shape[-1] 90 | ) 91 | return (out,) 92 | # (num_triplets, num_radial * num_spherical) 93 | else: 94 | return (rbf[None, :, :], cbf) 95 | # (1, num_edges, num_radial), (num_edges, num_spherical) 96 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_gp/layers/base_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | 12 | from ..initializers import he_orthogonal_init 13 | 14 | 15 | class Dense(torch.nn.Module): 16 | """ 17 | Combines dense layer with scaling for swish activation. 18 | 19 | Parameters 20 | ---------- 21 | units: int 22 | Output embedding size. 23 | activation: str 24 | Name of the activation function to use. 25 | bias: bool 26 | True if use bias. 
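The non-efficient branch of `CircularBasisLayer.forward` above materializes a per-triplet outer product of the radial and circular bases; a shape-only sketch of that contraction with arbitrary sizes:

```python
import torch

num_edges, num_triplets = 20, 50
num_radial, num_spherical = 6, 7

rbf = torch.randn(num_edges, num_radial)
cbf = torch.randn(num_triplets, num_spherical)
id3_ca = torch.randint(0, num_edges, (num_triplets,))  # triplet -> edge index

rbf_t = rbf[id3_ca]  # (50, 6): one radial row per triplet
out = (rbf_t[:, None, :] * cbf[:, :, None]).view(-1, num_radial * num_spherical)
print(out.shape)     # torch.Size([50, 42]) == (num_triplets, num_radial * num_spherical)
```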
27 | """ 28 | 29 | def __init__(self, in_features, out_features, bias=False, activation=None): 30 | super().__init__() 31 | 32 | self.linear = torch.nn.Linear(in_features, out_features, bias=bias) 33 | self.reset_parameters() 34 | 35 | if isinstance(activation, str): 36 | activation = activation.lower() 37 | if activation in ["swish", "silu"]: 38 | self._activation = ScaledSiLU() 39 | elif activation == "siqu": 40 | self._activation = SiQU() 41 | elif activation is None: 42 | self._activation = torch.nn.Identity() 43 | else: 44 | raise NotImplementedError( 45 | "Activation function not implemented for GemNet (yet)." 46 | ) 47 | 48 | def reset_parameters(self, initializer=he_orthogonal_init): 49 | initializer(self.linear.weight) 50 | if self.linear.bias is not None: 51 | self.linear.bias.data.fill_(0) 52 | 53 | def forward(self, x): 54 | x = self.linear(x) 55 | x = self._activation(x) 56 | return x 57 | 58 | 59 | class ScaledSiLU(torch.nn.Module): 60 | def __init__(self): 61 | super().__init__() 62 | self.scale_factor = 1 / 0.6 63 | self._activation = torch.nn.SiLU() 64 | 65 | def forward(self, x): 66 | return self._activation(x) * self.scale_factor 67 | 68 | 69 | class SiQU(torch.nn.Module): 70 | def __init__(self): 71 | super().__init__() 72 | self._activation = torch.nn.SiLU() 73 | 74 | def forward(self, x): 75 | return x * self._activation(x) 76 | 77 | 78 | class ResidualLayer(torch.nn.Module): 79 | """ 80 | Residual block with output scaled by 1/sqrt(2). 81 | 82 | Parameters 83 | ---------- 84 | units: int 85 | Output embedding size. 86 | nLayers: int 87 | Number of dense layers. 88 | layer_kwargs: str 89 | Keyword arguments for initializing the layers. 90 | """ 91 | 92 | def __init__( 93 | self, units: int, nLayers: int = 2, layer=Dense, **layer_kwargs 94 | ): 95 | super().__init__() 96 | self.dense_mlp = torch.nn.Sequential( 97 | *[ 98 | layer( 99 | in_features=units, 100 | out_features=units, 101 | bias=False, 102 | **layer_kwargs 103 | ) 104 | for _ in range(nLayers) 105 | ] 106 | ) 107 | self.inv_sqrt_2 = 1 / math.sqrt(2) 108 | 109 | def forward(self, input): 110 | x = self.dense_mlp(input) 111 | x = input + x 112 | x = x * self.inv_sqrt_2 113 | return x 114 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_gp/layers/spherical_basis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import sympy as sym 9 | import torch 10 | from torch_geometric.nn.models.schnet import GaussianSmearing 11 | 12 | from .basis_utils import real_sph_harm 13 | from .radial_basis import RadialBasis 14 | 15 | 16 | class CircularBasisLayer(torch.nn.Module): 17 | """ 18 | 2D Fourier Bessel Basis 19 | 20 | Parameters 21 | ---------- 22 | num_spherical: int 23 | Controls maximum frequency. 
24 | radial_basis: RadialBasis 25 | Radial basis functions 26 | cbf: dict 27 | Name and hyperparameters of the cosine basis function 28 | efficient: bool 29 | Whether to use the "efficient" summation order 30 | """ 31 | 32 | def __init__( 33 | self, 34 | num_spherical: int, 35 | radial_basis: RadialBasis, 36 | cbf: str, 37 | efficient: bool = False, 38 | ): 39 | super().__init__() 40 | 41 | self.radial_basis = radial_basis 42 | self.efficient = efficient 43 | 44 | cbf_name = cbf["name"].lower() 45 | cbf_hparams = cbf.copy() 46 | del cbf_hparams["name"] 47 | 48 | if cbf_name == "gaussian": 49 | self.cosφ_basis = GaussianSmearing( 50 | start=-1, stop=1, num_gaussians=num_spherical, **cbf_hparams 51 | ) 52 | elif cbf_name == "spherical_harmonics": 53 | Y_lm = real_sph_harm( 54 | num_spherical, use_theta=False, zero_m_only=True 55 | ) 56 | sph_funcs = [] # (num_spherical,) 57 | 58 | # convert to tensorflow functions 59 | z = sym.symbols("z") 60 | modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt} 61 | m_order = 0 # only single angle 62 | for l_degree in range(len(Y_lm)): # num_spherical 63 | if ( 64 | l_degree == 0 65 | ): # Y_00 is only a constant -> function returns value and not tensor 66 | first_sph = sym.lambdify( 67 | [z], Y_lm[l_degree][m_order], modules 68 | ) 69 | sph_funcs.append( 70 | lambda z: torch.zeros_like(z) + first_sph(z) 71 | ) 72 | else: 73 | sph_funcs.append( 74 | sym.lambdify([z], Y_lm[l_degree][m_order], modules) 75 | ) 76 | self.cosφ_basis = lambda cosφ: torch.stack( 77 | [f(cosφ) for f in sph_funcs], dim=1 78 | ) 79 | else: 80 | raise ValueError(f"Unknown cosine basis function '{cbf_name}'.") 81 | 82 | def forward(self, D_ca, cosφ_cab, id3_ca): 83 | rbf = self.radial_basis(D_ca) # (num_edges, num_radial) 84 | cbf = self.cosφ_basis(cosφ_cab) # (num_triplets, num_spherical) 85 | 86 | if not self.efficient: 87 | rbf = rbf[id3_ca] # (num_triplets, num_radial) 88 | out = (rbf[:, None, :] * cbf[:, :, None]).view( 89 | -1, rbf.shape[-1] * cbf.shape[-1] 90 | ) 91 | return (out,) 92 | # (num_triplets, num_radial * num_spherical) 93 | else: 94 | return (rbf[None, :, :], cbf) 95 | # (1, num_edges, num_radial), (num_edges, num_spherical) 96 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_4.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x2021_0A 2 | RDKit 3D 3 | 4 | 38 39 0 0 0 0 0 0 0 0999 V2000 5 | -6.6950 -5.2370 -2.3310 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -5.8410 -0.9950 -0.3720 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -5.7270 -2.3950 -0.6220 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -6.2940 -2.8160 -1.8230 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -3.9310 -7.0950 -2.0850 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -2.6370 -7.6330 1.2030 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -5.7130 -4.1650 -2.1190 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -7.1500 -5.9540 -1.1370 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -5.4160 -7.3860 -4.0970 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | -4.9730 -7.7760 -2.7090 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -5.5990 -8.8210 -2.0310 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | -5.2220 -9.1820 -0.7480 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | -4.1910 -8.4790 -0.1580 C 0 0 0 0 0 0 0 0 0 0 0 0 18 | -3.5490 -7.4780 -0.8230 C 0 0 0 0 0 0 0 0 0 0 0 0 19 | -6.8480 -7.0700 -4.2150 N 0 0 0 0 0 0 0 0 0 0 0 0 20 | -6.6940 -4.5980 -4.6050 O 0 0 0 0 0 0 0 0 0 0 0 0 21 | -8.7920 -5.5550 -3.7650 O 0 0 0 0 0 0 0 0 0 0 0 0 22 | -2.4700 -7.0480 -0.0910 O 0 0 0 0 0 0 0 0 0 0 0 0 23 | -3.7000 -8.5800 1.1260 O 0 0 0 0 0 0 0 0 0 0 0 0 24 | -7.3710 
-5.5830 -3.8260 S 0 0 0 0 0 0 0 0 0 0 0 0 25 | -5.8555 -0.8163 0.7133 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | -4.9825 -0.4717 -0.8183 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -6.7729 -0.6175 -0.8183 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -7.3871 -2.8869 -1.7227 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | -6.0803 -2.1046 -2.6343 H 0 0 0 0 0 0 0 0 0 0 0 0 30 | -3.4221 -6.2637 -2.5948 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | -1.7084 -8.1385 1.5066 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | -2.8712 -6.8560 1.9456 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | -5.1009 -4.0795 -3.0290 H 0 0 0 0 0 0 0 0 0 0 0 0 34 | -5.1369 -4.4513 -1.2267 H 0 0 0 0 0 0 0 0 0 0 0 0 35 | -6.8110 -5.4198 -0.2371 H 0 0 0 0 0 0 0 0 0 0 0 0 36 | -8.2486 -6.0090 -1.1387 H 0 0 0 0 0 0 0 0 0 0 0 0 37 | -6.7319 -6.9715 -1.1387 H 0 0 0 0 0 0 0 0 0 0 0 0 38 | -4.8422 -6.4992 -4.4041 H 0 0 0 0 0 0 0 0 0 0 0 0 39 | -5.2349 -8.2600 -4.7399 H 0 0 0 0 0 0 0 0 0 0 0 0 40 | -6.4118 -9.3720 -2.5268 H 0 0 0 0 0 0 0 0 0 0 0 0 41 | -5.7276 -10.0019 -0.2168 H 0 0 0 0 0 0 0 0 0 0 0 0 42 | -7.1113 -7.2331 -5.1973 H 0 0 0 0 0 0 0 0 0 0 0 0 43 | 3 2 1 0 44 | 4 3 1 0 45 | 7 4 1 0 46 | 7 1 1 0 47 | 8 1 1 0 48 | 10 9 1 0 49 | 10 5 2 0 50 | 11 10 1 0 51 | 12 11 2 0 52 | 13 12 1 0 53 | 14 13 2 0 54 | 14 5 1 0 55 | 15 9 1 0 56 | 18 14 1 0 57 | 18 6 1 0 58 | 19 13 1 0 59 | 19 6 1 0 60 | 20 17 2 0 61 | 20 15 1 0 62 | 20 16 2 0 63 | 20 1 1 0 64 | 2 21 1 0 65 | 2 22 1 0 66 | 2 23 1 0 67 | 4 24 1 0 68 | 4 25 1 0 69 | 5 26 1 0 70 | 6 27 1 0 71 | 6 28 1 0 72 | 7 29 1 0 73 | 7 30 1 0 74 | 8 31 1 0 75 | 8 32 1 0 76 | 8 33 1 0 77 | 9 34 1 0 78 | 9 35 1 0 79 | 11 36 1 0 80 | 12 37 1 0 81 | 15 38 1 0 82 | M END 83 | -------------------------------------------------------------------------------- /data/conformers/fragment_merging/fragments/mol_11.mol: -------------------------------------------------------------------------------- 1 | D68EV3CPROA-x1919_0A 2 | RDKit 3D 3 | 4 | 39 40 0 0 0 0 0 0 0 0999 V2000 5 | -7.0000 -6.4740 -3.4420 N 0 0 0 0 0 0 0 0 0 0 0 0 6 | -5.3620 0.1160 -1.9960 C 0 0 0 0 0 0 0 0 0 0 0 0 7 | -6.3240 -2.1950 -0.8290 O 0 0 0 0 0 0 0 0 0 0 0 0 8 | -5.9710 -0.0750 -0.6630 C 0 0 0 0 0 0 0 0 0 0 0 0 9 | -3.5000 -7.7410 -0.5220 C 0 0 0 0 0 0 0 0 0 0 0 0 10 | -3.9630 -7.2910 -1.7400 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | -2.4170 -7.9870 1.4030 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | -6.9570 -0.9470 -0.7320 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | -7.2640 -3.2590 -0.9000 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | -6.6200 -4.5170 -1.5010 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | -5.7510 -7.2460 -3.5070 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | -5.2420 -7.6640 -2.1500 C 0 0 0 0 0 0 0 0 0 0 0 0 17 | -6.0190 -8.4650 -1.3130 C 0 0 0 0 0 0 0 0 0 0 0 0 18 | -5.5510 -8.9040 -0.0840 C 0 0 0 0 0 0 0 0 0 0 0 0 19 | -4.2780 -8.5190 0.2940 C 0 0 0 0 0 0 0 0 0 0 0 0 20 | -8.4550 -4.4530 -3.3500 O 0 0 0 0 0 0 0 0 0 0 0 0 21 | -6.0890 -4.2120 -4.0230 O 0 0 0 0 0 0 0 0 0 0 0 0 22 | -2.2650 -7.5480 0.0520 O 0 0 0 0 0 0 0 0 0 0 0 0 23 | -3.6120 -8.7680 1.4750 O 0 0 0 0 0 0 0 0 0 0 0 0 24 | -7.0780 -4.8240 -3.1890 S 0 0 0 0 0 0 0 0 0 0 0 0 25 | -7.4755 -6.6396 -4.3406 H 0 0 0 0 0 0 0 0 0 0 0 0 26 | -4.3184 0.4433 -1.8788 H 0 0 0 0 0 0 0 0 0 0 0 0 27 | -5.3884 -0.8336 -2.5507 H 0 0 0 0 0 0 0 0 0 0 0 0 28 | -5.9259 0.8804 -2.5507 H 0 0 0 0 0 0 0 0 0 0 0 0 29 | -5.2066 -0.4459 0.0356 H 0 0 0 0 0 0 0 0 0 0 0 0 30 | -6.3685 0.8878 -0.3095 H 0 0 0 0 0 0 0 0 0 0 0 0 31 | -3.3352 -6.6498 -2.3762 H 0 0 0 0 0 0 0 0 0 0 0 0 32 | -1.5515 -8.5986 1.6977 H 0 0 0 0 0 0 0 0 0 0 0 0 33 | -2.4790 -7.1248 2.0833 H 0 0 0 0 0 0 0 0 0 0 0 0 34 | -7.5816 -0.8983 0.1722 H 0 0 0 0 0 0 0 0 0 0 0 0 35 | -7.6252 -0.7524 
-1.5839 H 0 0 0 0 0 0 0 0 0 0 0 0 36 | -7.6263 -3.4883 0.1130 H 0 0 0 0 0 0 0 0 0 0 0 0 37 | -8.0996 -2.9486 -1.5446 H 0 0 0 0 0 0 0 0 0 0 0 0 38 | -5.5270 -4.4026 -1.4541 H 0 0 0 0 0 0 0 0 0 0 0 0 39 | -6.9809 -5.3729 -0.9117 H 0 0 0 0 0 0 0 0 0 0 0 0 40 | -5.9292 -8.1501 -4.1077 H 0 0 0 0 0 0 0 0 0 0 0 0 41 | -4.9868 -6.5958 -3.9578 H 0 0 0 0 0 0 0 0 0 0 0 0 42 | -7.0293 -8.7559 -1.6365 H 0 0 0 0 0 0 0 0 0 0 0 0 43 | -6.1711 -9.5368 0.5679 H 0 0 0 0 0 0 0 0 0 0 0 0 44 | 4 2 1 0 45 | 6 5 2 0 46 | 8 4 1 0 47 | 8 3 1 0 48 | 9 3 1 0 49 | 10 9 1 0 50 | 11 1 1 0 51 | 12 11 1 0 52 | 12 6 1 0 53 | 13 12 2 0 54 | 14 13 1 0 55 | 15 14 2 0 56 | 15 5 1 0 57 | 18 7 1 0 58 | 18 5 1 0 59 | 19 15 1 0 60 | 19 7 1 0 61 | 20 1 1 0 62 | 20 17 2 0 63 | 20 16 2 0 64 | 20 10 1 0 65 | 1 21 1 0 66 | 2 22 1 0 67 | 2 23 1 0 68 | 2 24 1 0 69 | 4 25 1 0 70 | 4 26 1 0 71 | 6 27 1 0 72 | 7 28 1 0 73 | 7 29 1 0 74 | 8 30 1 0 75 | 8 31 1 0 76 | 9 32 1 0 77 | 9 33 1 0 78 | 10 34 1 0 79 | 10 35 1 0 80 | 11 36 1 0 81 | 11 37 1 0 82 | 13 38 1 0 83 | 14 39 1 0 84 | M END 85 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import logging 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch_geometric.nn import radius_graph 13 | 14 | from shepherd.model.equiformer_v2.ocpmodels.common.utils import ( 15 | compute_neighbors, 16 | conditional_grad, 17 | get_pbc_distances, 18 | radius_graph_pbc, 19 | ) 20 | 21 | 22 | class BaseModel(nn.Module): 23 | def __init__(self, num_atoms=None, bond_feat_dim=None, num_targets=None): 24 | super(BaseModel, self).__init__() 25 | self.num_atoms = num_atoms 26 | self.bond_feat_dim = bond_feat_dim 27 | self.num_targets = num_targets 28 | 29 | def forward(self, data): 30 | raise NotImplementedError 31 | 32 | def generate_graph( 33 | self, 34 | data, 35 | cutoff=None, 36 | max_neighbors=None, 37 | use_pbc=None, 38 | otf_graph=None, 39 | ): 40 | cutoff = cutoff or self.cutoff 41 | max_neighbors = max_neighbors or self.max_neighbors 42 | use_pbc = use_pbc or self.use_pbc 43 | otf_graph = otf_graph or self.otf_graph 44 | 45 | if not otf_graph: 46 | try: 47 | edge_index = data.edge_index 48 | 49 | if use_pbc: 50 | cell_offsets = data.cell_offsets 51 | neighbors = data.neighbors 52 | 53 | except AttributeError: 54 | logging.warning( 55 | "Turning otf_graph=True as required attributes not present in data object" 56 | ) 57 | otf_graph = True 58 | 59 | if use_pbc: 60 | if otf_graph: 61 | edge_index, cell_offsets, neighbors = radius_graph_pbc( 62 | data, cutoff, max_neighbors 63 | ) 64 | 65 | out = get_pbc_distances( 66 | data.pos, 67 | edge_index, 68 | data.cell, 69 | cell_offsets, 70 | neighbors, 71 | return_offsets=True, 72 | return_distance_vec=True, 73 | ) 74 | 75 | edge_index = out["edge_index"] 76 | edge_dist = out["distances"] 77 | cell_offset_distances = out["offsets"] 78 | distance_vec = out["distance_vec"] 79 | else: 80 | if otf_graph: 81 | edge_index = radius_graph( 82 | data.pos, 83 | r=cutoff, 84 | batch=data.batch, 85 | max_num_neighbors=max_neighbors, 86 | ) 87 | 88 | j, i = edge_index 89 | distance_vec = data.pos[j] - data.pos[i] 90 | 91 | edge_dist = distance_vec.norm(dim=-1) 92 | cell_offsets = 
torch.zeros( 93 | edge_index.shape[1], 3, device=data.pos.device 94 | ) 95 | cell_offset_distances = torch.zeros_like( 96 | cell_offsets, device=data.pos.device 97 | ) 98 | neighbors = compute_neighbors(data, edge_index) 99 | 100 | return ( 101 | edge_index, 102 | edge_dist, 103 | distance_vec, 104 | cell_offsets, 105 | cell_offset_distances, 106 | neighbors, 107 | ) 108 | 109 | @property 110 | def num_params(self): 111 | return sum(p.numel() for p in self.parameters()) 112 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/common/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | import logging 8 | from abc import ABC, abstractmethod 9 | 10 | import torch 11 | import wandb 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | from ocpmodels.common.registry import registry 15 | 16 | 17 | class Logger(ABC): 18 | """Generic class to interface with various logging modules, e.g. wandb, 19 | tensorboard, etc. 20 | """ 21 | 22 | def __init__(self, config): 23 | self.config = config 24 | 25 | @abstractmethod 26 | def watch(self, model): 27 | """ 28 | Monitor parameters and gradients. 29 | """ 30 | pass 31 | 32 | def log(self, update_dict, step=None, split=""): 33 | """ 34 | Log some values. 35 | """ 36 | assert step is not None 37 | if split != "": 38 | new_dict = {} 39 | for key in update_dict: 40 | new_dict["{}/{}".format(split, key)] = update_dict[key] 41 | update_dict = new_dict 42 | return update_dict 43 | 44 | @abstractmethod 45 | def log_plots(self, plots): 46 | pass 47 | 48 | @abstractmethod 49 | def mark_preempting(self): 50 | pass 51 | 52 | 53 | @registry.register_logger("wandb") 54 | class WandBLogger(Logger): 55 | def __init__(self, config): 56 | super().__init__(config) 57 | project = ( 58 | self.config["logger"].get("project", None) 59 | if isinstance(self.config["logger"], dict) 60 | else None 61 | ) 62 | 63 | wandb.init( 64 | config=self.config, 65 | id=self.config["cmd"]["timestamp_id"], 66 | name=self.config["cmd"]["identifier"], 67 | dir=self.config["cmd"]["logs_dir"], 68 | project=project, 69 | resume="allow", 70 | ) 71 | 72 | def watch(self, model): 73 | wandb.watch(model) 74 | 75 | def log(self, update_dict, step=None, split=""): 76 | update_dict = super().log(update_dict, step, split) 77 | wandb.log(update_dict, step=int(step)) 78 | 79 | def log_plots(self, plots, caption=""): 80 | assert isinstance(plots, list) 81 | plots = [wandb.Image(x, caption=caption) for x in plots] 82 | wandb.log({"data": plots}) 83 | 84 | def mark_preempting(self): 85 | wandb.mark_preempting() 86 | 87 | 88 | @registry.register_logger("tensorboard") 89 | class TensorboardLogger(Logger): 90 | def __init__(self, config): 91 | super().__init__(config) 92 | self.writer = SummaryWriter(self.config["cmd"]["logs_dir"]) 93 | 94 | # TODO: add a model hook for watching gradients. 95 | def watch(self, model): 96 | logging.warning( 97 | "Model gradient logging to tensorboard not yet supported." 
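For the non-PBC, on-the-fly branch of `generate_graph` above, the construction reduces to a radius graph plus explicit distance vectors; a stand-alone sketch of that path (toy positions, single structure in the batch):

```python
import torch
from torch_geometric.nn import radius_graph

pos = torch.randn(8, 3)
batch = torch.zeros(8, dtype=torch.long)

edge_index = radius_graph(pos, r=6.0, batch=batch, max_num_neighbors=20)
j, i = edge_index                    # source, target, matching the convention above
distance_vec = pos[j] - pos[i]
edge_dist = distance_vec.norm(dim=-1)
print(edge_index.shape, edge_dist.shape)  # torch.Size([2, E]) torch.Size([E])
```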
98 | ) 99 | return False 100 | 101 | def log(self, update_dict, step=None, split=""): 102 | update_dict = super().log(update_dict, step, split) 103 | for key in update_dict: 104 | if torch.is_tensor(update_dict[key]): 105 | self.writer.add_scalar(key, update_dict[key].item(), step) 106 | else: 107 | assert isinstance(update_dict[key], int) or isinstance( 108 | update_dict[key], float 109 | ) 110 | self.writer.add_scalar(key, update_dict[key], step) 111 | 112 | def mark_preempting(self): 113 | pass 114 | 115 | def log_plots(self, plots): 116 | pass 117 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_oc/README.md: -------------------------------------------------------------------------------- 1 | # GemNet-OC: Developing Graph Neural Networks for Large and Diverse Molecular Simulation Datasets 2 | 3 | Johannes Gasteiger, Muhammed Shuaibi, Anuroop Sriram, Stephan Günnemann, Zachary Ulissi, C. Lawrence Zitnick, Abhishek Das 4 | 5 | [[`arXiv:2204.02782`](https://arxiv.org/abs/2204.02782)] 6 | 7 | When running inference with a pretrained GemNet-OC model, make sure that the 8 | `scale_file` path is correct in the config, otherwise predictions will be inaccurate. 9 | 10 | | Model | Val ID 30k Force MAE | Val ID 30k Energy MAE | Val ID 30k Force cos | Test metrics | Download | 11 | | ----- | -------------------- | --------------------- | -------------------- | ------------ | -------- | 12 | | gemnet_oc_2M | 0.0225 | 0.2299 | 0.6174 | [S2EF](https://evalai.s3.amazonaws.com/media/submission_files/submission_179229/062c037e-4f1f-49c2-9eeb-8e14681a70ee.json) \| [IS2RE](https://evalai.s3.amazonaws.com/media/submission_files/submission_179296/6688f44f-9d5a-4020-beca-8b804e0212fb.json) \| [IS2RS](https://evalai.s3.amazonaws.com/media/submission_files/submission_179257/0d02a349-0abe-44c0-a65c-29a9df75c886.json) | [config](https://github.com/Open-Catalyst-Project/ocp/blob/main/configs/s2ef/2M/gemnet/gemnet-oc.yml) \| [checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_base_s2ef_2M.pt) \| [scale file](https://github.com/Open-Catalyst-Project/ocp/blob/481f3a5a92dc787384ddae9fe3f50f5d932712fd/configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt) | 13 | | gemnet_oc_all | 0.0179 | 0.1668 | 0.6879 | [S2EF](https://evalai.s3.amazonaws.com/media/submission_files/submission_179008/6e731f20-17cf-417e-b0ad-97352be8cc37.json) \| [IS2RE]() \| [IS2RS](https://evalai.s3.amazonaws.com/media/submission_files/submission_160550/72a65a42-1fa9-44c5-8546-9eb691df8d2e.json) | [config](https://github.com/Open-Catalyst-Project/ocp/blob/main/configs/s2ef/all/gemnet/gemnet-oc.yml) \| [checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_base_s2ef_all.pt) \| [scale file](https://github.com/Open-Catalyst-Project/ocp/blob/481f3a5a92dc787384ddae9fe3f50f5d932712fd/configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt) | 14 | | gemnet_oc_large_all_md_energy | 0.0178 | 0.1504 | 0.6906 | [S2EF](https://evalai.s3.amazonaws.com/media/submission_files/submission_179143/40940149-6a4a-49a4-a2ce-38486215990f.json) | - | 15 | | gemnet_oc_large_all_md_force | 0.0164 | 0.1665 | 0.7139 | [S2EF](https://evalai.s3.amazonaws.com/media/submission_files/submission_179042/ba160459-0de3-4583-a98b-12102138c61e.json) \| [IS2RS](https://evalai.s3.amazonaws.com/media/submission_files/submission_169243/10bc7c8d-5124-4338-aaf3-04a7d015c4a0.json) | 
[config](https://github.com/Open-Catalyst-Project/ocp/blob/main/configs/s2ef/all/gemnet/gemnet-oc-large.yml) \| [checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_large_s2ef_all_md.pt) \| [scale file](https://github.com/Open-Catalyst-Project/ocp/blob/481f3a5a92dc787384ddae9fe3f50f5d932712fd/configs/s2ef/all/gemnet/scaling_factors/gemnet-oc-large.pt) | 16 | | gemnet_oc_large_all_md_energy + gemnet_oc_large_all_md_force | - | - | - | [IS2RE](https://evalai.s3.amazonaws.com/media/submission_files/submission_212962/6acc7cf7-e18b-4d6a-9082-b4a114110dbf.json) | - | 17 | 18 | ## Citing 19 | 20 | If you use GemNet-OC in your work, please consider citing: 21 | 22 | ```bibtex 23 | @article{gasteiger_gemnet_oc_2022, 24 | title = {{GemNet-OC: Developing Graph Neural Networks for Large and Diverse Molecular Simulation Datasets}}, 25 | author = {Gasteiger, Johannes and Shuaibi, Muhammed and Sriram, Anuroop and G{\"u}nnemann, Stephan and Ulissi, Zachary and Zitnick, C Lawrence and Das, Abhishek}, 26 | journal = {Transactions on Machine Learning Research (TMLR)}, 27 | year = {2022}, 28 | } 29 | ``` 30 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/configs/s2ef/2M/equiformer_v2/equiformer_v2_N@12_L@6_M@2.yml: -------------------------------------------------------------------------------- 1 | trainer: forces_v2 2 | 3 | 4 | dataset: 5 | - src: datasets/oc20/s2ef/2M/train/ 6 | normalize_labels: True 7 | target_mean: -0.7554450631141663 8 | target_std: 2.887317180633545 9 | grad_target_mean: 0.0 10 | grad_target_std: 2.887317180633545 11 | - src: datasets/oc20/s2ef/all/val_id/ 12 | 13 | 14 | logger: wandb 15 | 16 | 17 | task: 18 | dataset: trajectory_lmdb_v2 19 | description: "Regressing to energies and forces for DFT trajectories from OCP" 20 | type: regression 21 | metric: force_mae 22 | labels: 23 | - potential energy 24 | grad_input: atomic forces 25 | train_on_free_atoms: True 26 | eval_on_free_atoms: True 27 | 28 | 29 | hide_eval_progressbar: False 30 | 31 | 32 | model: 33 | name: equiformer_v2 34 | 35 | use_pbc: True 36 | regress_forces: True 37 | otf_graph: True 38 | max_neighbors: 20 39 | max_radius: 12.0 40 | max_num_elements: 90 41 | 42 | num_layers: 12 43 | sphere_channels: 128 44 | attn_hidden_channels: 64 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96. 45 | num_heads: 8 46 | attn_alpha_channels: 64 # Not used when `use_s2_act_attn` is True. 47 | attn_value_channels: 16 48 | ffn_hidden_channels: 128 49 | norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh'] 50 | 51 | lmax_list: [6] 52 | mmax_list: [2] 53 | grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line. 54 | 55 | num_sphere_samples: 128 56 | 57 | edge_channels: 128 58 | use_atom_edge_embedding: True 59 | share_atom_edge_embedding: False # If `True`, `use_atom_edge_embedding` must be `True` and the atom edge embedding will be shared across all blocks. 60 | distance_function: 'gaussian' 61 | num_distance_basis: 512 # not used 62 | 63 | attn_activation: 'silu' 64 | use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention. 65 | use_attn_renorm: True # Attention re-normalization. Used for ablation study. 
66 | ffn_activation: 'silu' # ['silu', 'swiglu'] 67 | use_gate_act: False # [True, False] Switch between gate activation and S2 activation 68 | use_grid_mlp: True # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs. 69 | use_sep_s2_act: True # Separable S2 activation. Used for ablation study. 70 | 71 | alpha_drop: 0.1 # [0.0, 0.1] 72 | drop_path_rate: 0.05 # [0.0, 0.05] 73 | proj_drop: 0.0 74 | 75 | weight_init: 'uniform' # ['uniform', 'normal'] 76 | 77 | 78 | optim: 79 | batch_size: 4 # 6 80 | eval_batch_size: 4 # 6 81 | grad_accumulation_steps: 1 # gradient accumulation: effective batch size = `grad_accumulation_steps` * `batch_size` * (num of GPUs) 82 | load_balancing: atoms 83 | num_workers: 8 84 | lr_initial: 0.0002 # [0.0002, 0.0004], eSCN uses 0.0008 for batch size 96 85 | 86 | optimizer: AdamW 87 | optimizer_params: 88 | weight_decay: 0.001 89 | scheduler: LambdaLR 90 | scheduler_params: 91 | lambda_type: cosine 92 | warmup_factor: 0.2 93 | warmup_epochs: 0.1 94 | lr_min_factor: 0.01 95 | 96 | max_epochs: 12 97 | force_coefficient: 100 98 | energy_coefficient: 2 99 | clip_grad_norm: 100 100 | ema_decay: 0.999 101 | loss_energy: mae 102 | loss_force: l2mae 103 | 104 | eval_every: 5000 105 | 106 | 107 | #slurm: 108 | # constraint: "volta32gb" -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/configs/s2ef/2M/equiformer_v2/equiformer_v2_N@12_L@6_M@2_epochs@30.yml: -------------------------------------------------------------------------------- 1 | trainer: forces_v2 2 | 3 | 4 | dataset: 5 | - src: datasets/oc20/s2ef/2M/train/ 6 | normalize_labels: True 7 | target_mean: -0.7554450631141663 8 | target_std: 2.887317180633545 9 | grad_target_mean: 0.0 10 | grad_target_std: 2.887317180633545 11 | - src: datasets/oc20/s2ef/all/val_id/ 12 | 13 | 14 | logger: wandb 15 | 16 | 17 | task: 18 | dataset: trajectory_lmdb_v2 19 | description: "Regressing to energies and forces for DFT trajectories from OCP" 20 | type: regression 21 | metric: force_mae 22 | labels: 23 | - potential energy 24 | grad_input: atomic forces 25 | train_on_free_atoms: True 26 | eval_on_free_atoms: True 27 | 28 | 29 | hide_eval_progressbar: False 30 | 31 | 32 | model: 33 | name: equiformer_v2 34 | 35 | use_pbc: True 36 | regress_forces: True 37 | otf_graph: True 38 | max_neighbors: 20 39 | max_radius: 12.0 40 | max_num_elements: 90 41 | 42 | num_layers: 12 43 | sphere_channels: 128 44 | attn_hidden_channels: 64 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96. 45 | num_heads: 8 46 | attn_alpha_channels: 64 # Not used when `use_s2_act_attn` is True. 47 | attn_value_channels: 16 48 | ffn_hidden_channels: 128 49 | norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh'] 50 | 51 | lmax_list: [6] 52 | mmax_list: [2] 53 | grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line. 54 | 55 | num_sphere_samples: 128 56 | 57 | edge_channels: 128 58 | use_atom_edge_embedding: True 59 | share_atom_edge_embedding: False # If `True`, `use_atom_edge_embedding` must be `True` and the atom edge embedding will be shared across all blocks. 60 | distance_function: 'gaussian' 61 | num_distance_basis: 512 # not used 62 | 63 | attn_activation: 'silu' 64 | use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention. 65 | use_attn_renorm: True # Attention re-normalization. Used for ablation study. 
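    # Hedged illustration of the alternative listed on the next line
    # (this run keeps 'silu'):
    #   ffn_activation: 'swiglu'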
66 | ffn_activation: 'silu' # ['silu', 'swiglu'] 67 | use_gate_act: False # [True, False] Switch between gate activation and S2 activation 68 | use_grid_mlp: True # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs. 69 | use_sep_s2_act: True # Separable S2 activation. Used for ablation study. 70 | 71 | alpha_drop: 0.1 # [0.0, 0.1] 72 | drop_path_rate: 0.05 # [0.0, 0.05] 73 | proj_drop: 0.0 74 | 75 | weight_init: 'uniform' # ['uniform', 'normal'] 76 | 77 | 78 | optim: 79 | batch_size: 4 # 6 80 | eval_batch_size: 4 # 6 81 | grad_accumulation_steps: 1 # gradient accumulation: effective batch size = `grad_accumulation_steps` * `batch_size` * (num of GPUs) 82 | load_balancing: atoms 83 | num_workers: 8 84 | lr_initial: 0.0004 # [0.0002, 0.0004], eSCN uses 0.0008 for batch size 96 85 | 86 | optimizer: AdamW 87 | optimizer_params: 88 | weight_decay: 0.001 89 | scheduler: LambdaLR 90 | scheduler_params: 91 | lambda_type: cosine 92 | warmup_factor: 0.2 93 | warmup_epochs: 0.1 94 | lr_min_factor: 0.01 95 | 96 | max_epochs: 30 97 | force_coefficient: 100 98 | energy_coefficient: 2 99 | clip_grad_norm: 100 100 | ema_decay: 0.999 101 | loss_energy: mae 102 | loss_force: l2mae 103 | 104 | eval_every: 5000 105 | 106 | 107 | #slurm: 108 | # constraint: "volta32gb" -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/configs/s2ef/all_md/equiformer_v2/equiformer_v2_N@8_L@4_M@2_31M.yml: -------------------------------------------------------------------------------- 1 | trainer: forces_v2 2 | 3 | 4 | dataset: 5 | train: 6 | src: datasets/oc20/s2ef/all_md/train/ 7 | normalize_labels: True 8 | target_mean: -0.7554450631141663 9 | target_std: 2.887317180633545 10 | grad_target_mean: 0.0 11 | grad_target_std: 2.887317180633545 12 | val: 13 | src: datasets/oc20/s2ef/all/val_id/ 14 | 15 | 16 | logger: wandb 17 | 18 | 19 | task: 20 | dataset: trajectory_lmdb_v2 21 | primary_metric: forces_mae 22 | labels: 23 | - potential energy 24 | grad_input: atomic forces 25 | train_on_free_atoms: True 26 | eval_on_free_atoms: True 27 | # relaxation_steps: 300 28 | # write_pos: True 29 | # # num_relaxation_batches: 100 30 | # relax_dataset: 31 | # src: 32 | # relax_opt: 33 | # name: lbfgs 34 | # maxstep: 0.04 35 | # memory: 50 36 | # damping: 1.0 37 | # alpha: 70.0 38 | # traj_dir: path/to/traj/dir 39 | 40 | 41 | hide_eval_progressbar: False 42 | 43 | 44 | model: 45 | name: equiformer_v2 46 | 47 | use_pbc: True 48 | regress_forces: True 49 | otf_graph: True 50 | max_neighbors: 20 51 | max_radius: 12.0 52 | max_num_elements: 90 53 | 54 | num_layers: 8 55 | sphere_channels: 128 56 | attn_hidden_channels: 64 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96. 57 | num_heads: 8 58 | attn_alpha_channels: 64 # Not used when `use_s2_act_attn` is True. 59 | attn_value_channels: 16 60 | ffn_hidden_channels: 128 61 | norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh'] 62 | 63 | lmax_list: [4] 64 | mmax_list: [2] 65 | grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line. 66 | 67 | num_sphere_samples: 128 68 | 69 | edge_channels: 128 70 | use_atom_edge_embedding: True 71 | share_atom_edge_embedding: False # If `True`, `use_atom_edge_embedding` must be `True` and the atom edge embedding will be shared across all blocks. 
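    # Hedged illustration of the constraint stated above (not used in this run):
    # sharing the atom edge embedding requires enabling it as well.
    #   use_atom_edge_embedding: True
    #   share_atom_edge_embedding: True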
72 | distance_function: 'gaussian' 73 | num_distance_basis: 512 # not used 74 | 75 | attn_activation: 'silu' 76 | use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention. 77 | use_attn_renorm: True # Attention re-normalization. Used for ablation study. 78 | ffn_activation: 'silu' # ['silu', 'swiglu'] 79 | use_gate_act: False # [True, False] Switch between gate activation and S2 activation 80 | use_grid_mlp: True # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs. 81 | use_sep_s2_act: True # Separable S2 activation. Used for ablation study. 82 | 83 | alpha_drop: 0.1 # [0.0, 0.1] 84 | drop_path_rate: 0.1 # [0.0, 0.05] 85 | proj_drop: 0.0 86 | 87 | weight_init: 'uniform' # ['uniform', 'normal'] 88 | 89 | 90 | optim: 91 | batch_size: 8 # 6 92 | eval_batch_size: 12 # 6 93 | grad_accumulation_steps: 1 # gradient accumulation: effective batch size = `grad_accumulation_steps` * `batch_size` * (num of GPUs) 94 | load_balancing: atoms 95 | num_workers: 8 96 | lr_initial: 0.0004 # [0.0002, 0.0004], eSCN uses 0.0008 for batch size 96 97 | 98 | optimizer: AdamW 99 | optimizer_params: 100 | weight_decay: 0.001 101 | scheduler: LambdaLR 102 | scheduler_params: 103 | lambda_type: cosine 104 | warmup_factor: 0.2 105 | warmup_epochs: 0.01 106 | lr_min_factor: 0.01 107 | 108 | max_epochs: 3 109 | force_coefficient: 100 110 | energy_coefficient: 4 111 | clip_grad_norm: 100 112 | ema_decay: 0.999 113 | loss_energy: mae 114 | loss_force: l2mae 115 | 116 | eval_every: 10000 117 | 118 | 119 | #slurm: 120 | # constraint: "volta32gb" 121 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/configs/s2ef/all_md/equiformer_v2/equiformer_v2_N@20_L@6_M@3_153M.yml: -------------------------------------------------------------------------------- 1 | trainer: forces_v2 2 | 3 | 4 | dataset: 5 | train: 6 | src: datasets/oc20/s2ef/all_md/train/ 7 | normalize_labels: True 8 | target_mean: -0.7554450631141663 9 | target_std: 2.887317180633545 10 | grad_target_mean: 0.0 11 | grad_target_std: 2.887317180633545 12 | val: 13 | src: datasets/oc20/s2ef/all/val_id/ 14 | 15 | 16 | logger: wandb 17 | 18 | 19 | task: 20 | dataset: trajectory_lmdb_v2 21 | primary_metric: forces_mae 22 | labels: 23 | - potential energy 24 | grad_input: atomic forces 25 | train_on_free_atoms: True 26 | eval_on_free_atoms: True 27 | # relaxation_steps: 300 28 | # write_pos: True 29 | # # num_relaxation_batches: 100 30 | # relax_dataset: 31 | # src: 32 | # relax_opt: 33 | # name: lbfgs 34 | # maxstep: 0.04 35 | # memory: 50 36 | # damping: 1.0 37 | # alpha: 70.0 38 | # traj_dir: path/to/traj/dir 39 | 40 | 41 | hide_eval_progressbar: False 42 | 43 | 44 | model: 45 | name: equiformer_v2 46 | 47 | use_pbc: True 48 | regress_forces: True 49 | otf_graph: True 50 | max_neighbors: 20 51 | max_radius: 12.0 52 | max_num_elements: 90 53 | 54 | num_layers: 20 55 | sphere_channels: 128 56 | attn_hidden_channels: 64 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96. 57 | num_heads: 8 58 | attn_alpha_channels: 64 # Not used when `use_s2_act_attn` is True. 59 | attn_value_channels: 16 60 | ffn_hidden_channels: 128 61 | norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh'] 62 | 63 | lmax_list: [6] 64 | mmax_list: [3] 65 | grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line. 
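    # Hedged illustration of the `None` case mentioned above: comment the key
    # out entirely, i.e.
    #   # grid_resolution: 18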
66 | 67 | num_sphere_samples: 128 68 | 69 | edge_channels: 128 70 | use_atom_edge_embedding: True 71 | share_atom_edge_embedding: False # If `True`, `use_atom_edge_embedding` must be `True` and the atom edge embedding will be shared across all blocks. 72 | distance_function: 'gaussian' 73 | num_distance_basis: 512 # not used 74 | 75 | attn_activation: 'silu' 76 | use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention. 77 | use_attn_renorm: True # Attention re-normalization. Used for ablation study. 78 | ffn_activation: 'silu' # ['silu', 'swiglu'] 79 | use_gate_act: False # [True, False] Switch between gate activation and S2 activation 80 | use_grid_mlp: True # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs. 81 | use_sep_s2_act: True # Separable S2 activation. Used for ablation study. 82 | 83 | alpha_drop: 0.1 # [0.0, 0.1] 84 | drop_path_rate: 0.1 # [0.0, 0.05] 85 | proj_drop: 0.0 86 | 87 | weight_init: 'uniform' # ['uniform', 'normal'] 88 | 89 | 90 | optim: 91 | batch_size: 4 # 6 92 | eval_batch_size: 4 # 6 93 | grad_accumulation_steps: 1 # gradient accumulation: effective batch size = `grad_accumulation_steps` * `batch_size` * (num of GPUs) 94 | load_balancing: atoms 95 | num_workers: 8 96 | lr_initial: 0.0004 # [0.0002, 0.0004], eSCN uses 0.0008 for batch size 96 97 | 98 | optimizer: AdamW 99 | optimizer_params: 100 | weight_decay: 0.001 101 | scheduler: LambdaLR 102 | scheduler_params: 103 | lambda_type: cosine 104 | warmup_factor: 0.2 105 | warmup_epochs: 0.01 106 | lr_min_factor: 0.01 107 | 108 | max_epochs: 1 109 | force_coefficient: 100 110 | energy_coefficient: 4 111 | clip_grad_norm: 100 112 | ema_decay: 0.999 113 | loss_energy: mae 114 | loss_force: l2mae 115 | 116 | eval_every: 10000 117 | 118 | 119 | #slurm: 120 | # constraint: "volta32gb" 121 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/oc20/trainer/dist_setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 1. Copy distutils.setup from https://github.com/Open-Catalyst-Project/ocp/blob/89948582edfb8debb736406d54db9813a5f2c88d/ocpmodels/common/distutils.py#L16 3 | 2. Add OpenMPI multi-node training as Submitit is not supported. 
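A minimal usage sketch (the values are illustrative; the keys are the ones
read by `setup` below):

    config = {
        "submit": False,               # True when launched via SLURM
        "summit": False,               # True on OLCF Summit (OpenMPI)
        "distributed_backend": "nccl",
        "world_size": 8,
        "local_rank": 0,               # typically provided by the launcher
    }
    setup(config)
    # (env:// initialization additionally expects MASTER_ADDR/MASTER_PORT
    # to be set in the environment)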
4 | ''' 5 | 6 | import logging 7 | import os 8 | import subprocess 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | 14 | def setup(config): 15 | if config["submit"]: 16 | node_list = os.environ.get("SLURM_STEP_NODELIST") 17 | if node_list is None: 18 | node_list = os.environ.get("SLURM_JOB_NODELIST") 19 | if node_list is not None: 20 | try: 21 | hostnames = subprocess.check_output( 22 | ["scontrol", "show", "hostnames", node_list] 23 | ) 24 | config["init_method"] = "tcp://{host}:{port}".format( 25 | host=hostnames.split()[0].decode("utf-8"), 26 | port=config["distributed_port"], 27 | ) 28 | nnodes = int(os.environ.get("SLURM_NNODES")) 29 | ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") 30 | if ntasks_per_node is not None: 31 | ntasks_per_node = int(ntasks_per_node) 32 | else: 33 | ntasks = int(os.environ.get("SLURM_NTASKS")) 34 | nnodes = int(os.environ.get("SLURM_NNODES")) 35 | assert ntasks % nnodes == 0 36 | ntasks_per_node = int(ntasks / nnodes) 37 | if ntasks_per_node == 1: 38 | assert config["world_size"] % nnodes == 0 39 | gpus_per_node = config["world_size"] // nnodes 40 | node_id = int(os.environ.get("SLURM_NODEID")) 41 | config["rank"] = node_id * gpus_per_node 42 | config["local_rank"] = 0 43 | else: 44 | assert ntasks_per_node == config["world_size"] // nnodes 45 | config["rank"] = int(os.environ.get("SLURM_PROCID")) 46 | config["local_rank"] = int(os.environ.get("SLURM_LOCALID")) 47 | 48 | logging.info( 49 | f"Init: {config['init_method']}, {config['world_size']}, {config['rank']}" 50 | ) 51 | 52 | # ensures GPU0 does not have extra context/higher peak memory 53 | torch.cuda.set_device(config["local_rank"]) 54 | 55 | dist.init_process_group( 56 | backend=config["distributed_backend"], 57 | init_method=config["init_method"], 58 | world_size=config["world_size"], 59 | rank=config["rank"], 60 | ) 61 | except subprocess.CalledProcessError as e: # scontrol failed 62 | raise e 63 | except FileNotFoundError: # Slurm is not installed 64 | pass 65 | elif config["summit"]: 66 | world_size = int(os.getenv('OMPI_COMM_WORLD_SIZE')) 67 | world_rank = int(os.getenv('OMPI_COMM_WORLD_RANK')) 68 | 69 | # Should be set already 70 | #get_master = ( 71 | # "echo $(cat {} | sort | uniq | grep -v batch | grep -v login | head -1)" 72 | #).format(os.environ["LSB_DJOB_HOSTFILE"]) 73 | #os.environ["MASTER_ADDR"] = str( 74 | # subprocess.check_output(get_master, shell=True) 75 | #)[2:-3] 76 | #os.environ["MASTER_PORT"] = "23456" 77 | 78 | os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] 79 | os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] 80 | 81 | config["local_rank"] = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')) 82 | 83 | # NCCL and MPI initialization 84 | dist.init_process_group( 85 | backend="nccl", 86 | rank=world_rank, 87 | world_size=world_size, 88 | init_method="env://", 89 | ) 90 | else: 91 | dist.init_process_group( 92 | backend=config["distributed_backend"], init_method="env://", 93 | rank=config['local_rank'], 94 | world_size=config['world_size'] 95 | ) 96 | torch.cuda.set_device(config["local_rank"]) 97 | # TODO: SLURM 98 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # checkpoints 2 | *.ckpt 3 | *.bak 4 | 5 | # macOS 6 | *.DS_Store 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | 
.Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # UV 105 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | #uv.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 119 | #pdm.lock 120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 121 | # in version control. 122 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 123 | .pdm.toml 124 | .pdm-python 125 | .pdm-build/ 126 | 127 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | .dmypy.json 159 | dmypy.json 160 | 161 | # Pyre type checker 162 | .pyre/ 163 | 164 | # pytype static type analyzer 165 | .pytype/ 166 | 167 | # Cython debug symbols 168 | cython_debug/ 169 | 170 | # PyCharm 171 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 172 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 173 | # and can be added to the global gitignore or merged into this file. For a more nuclear 174 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 175 | #.idea/ 176 | 177 | # Ruff stuff: 178 | .ruff_cache/ 179 | 180 | # PyPI configuration file 181 | .pypirc -------------------------------------------------------------------------------- /examples/paper_experiments/run_inference_gdb_unconditional_x1x2.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | from shepherd.shepherd_score_utils.generate_point_cloud import ( 3 | get_atom_coords, 4 | get_atomic_vdw_radii, 5 | get_molecular_surface, 6 | get_electrostatics, 7 | get_electrostatics_given_point_charges, 8 | ) 9 | from shepherd.shepherd_score_utils.pharm_utils.pharmacophore import get_pharmacophores 10 | from shepherd.shepherd_score_utils.conformer_generation import update_mol_coordinates 11 | 12 | print('importing rdkit') 13 | import rdkit 14 | from rdkit.Chem import rdDetermineBonds 15 | 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | 19 | print('importing torch') 20 | import torch 21 | import torch_geometric 22 | from torch_geometric.nn import radius_graph 23 | import torch_scatter 24 | 25 | import pickle 26 | from copy import deepcopy 27 | import os 28 | import multiprocessing 29 | from tqdm import tqdm 30 | 31 | print('importing lightning') 32 | import pytorch_lightning as pl 33 | from pytorch_lightning.callbacks import ModelCheckpoint 34 | from pytorch_lightning.loggers import CSVLogger 35 | 36 | from shepherd.lightning_module import LightningModule 37 | from shepherd.datasets import HeteroDataset 38 | 39 | import importlib 40 | 41 | from shepherd.inference import * 42 | 43 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 44 | 45 | chkpt = '../../data/shepherd_chkpts/x1x2_diffusion_gdb17_20240824_submission.ckpt' 46 | 47 | model_pl = LightningModule.load_from_checkpoint(chkpt) 48 | params = model_pl.params 49 | model_pl.to(device) 50 | model_pl.model.device = device 51 | 52 | import argparse 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("batch_size", type=int) # should be set to 1 if using cpus 55 | parser.add_argument("num_batches", type=int) 56 | parser.add_argument("n_atoms", type=int) 57 | parser.add_argument("N_x4", type=int) 58 | parser.add_argument("file_index", type=int) 59 | args = parser.parse_args() 60 | 61 | batch_size = args.batch_size 62 | n_atoms = args.n_atoms 63 | num_batches = args.num_batches 64 | file_index = 
args.file_index 65 | N_x4 = args.N_x4 66 | 67 | 68 | generated_samples = [] 69 | for n in range(num_batches): 70 | 71 | # only use to break symmetry during unconditional generation 72 | T = params['noise_schedules']['x1']['T'] 73 | inject_noise_at_ts = list(np.arange(130, 80, -1)) # [150] 74 | inject_noise_scales = [1.0] * len(inject_noise_at_ts) 75 | harmonize = True 76 | harmonize_ts = [80] 77 | harmonize_jumps = [20] 78 | 79 | 80 | generated_samples_batch = inference_sample( 81 | model_pl, 82 | batch_size = batch_size, 83 | 84 | N_x1 = n_atoms, 85 | N_x4 = N_x4, # must equal len(pharm_types) if inpainting 86 | 87 | unconditional = True, 88 | 89 | prior_noise_scale = 1.0, 90 | denoising_noise_scale = 1.0, 91 | 92 | # only use to break symmetry during unconditional generation 93 | inject_noise_at_ts = inject_noise_at_ts, #[], 94 | inject_noise_scales = inject_noise_scales, #[], 95 | harmonize = harmonize, # False 96 | harmonize_ts = harmonize_ts, #[], 97 | harmonize_jumps = harmonize_jumps, #[], 98 | 99 | 100 | # all the below options are only relevant if unconditional is False 101 | 102 | inpaint_x2_pos = False, 103 | 104 | inpaint_x3_pos = False, 105 | inpaint_x3_x = False, 106 | 107 | inpaint_x4_pos = False, 108 | inpaint_x4_direction = False, 109 | inpaint_x4_type = False, 110 | 111 | stop_inpainting_at_time_x2 = 0.0, # range from 0.0 to 1.0 (fraction of T) 112 | add_noise_to_inpainted_x2_pos = 0.0, 113 | 114 | stop_inpainting_at_time_x3 = 0.0, # range from 0.0 to 1.0 (fraction of T) 115 | add_noise_to_inpainted_x3_pos = 0.0, 116 | add_noise_to_inpainted_x3_x = 0.0, 117 | 118 | stop_inpainting_at_time_x4 = 0.0, # range from 0.0 to 1.0 (fraction of T) 119 | add_noise_to_inpainted_x4_pos = 0.0, 120 | add_noise_to_inpainted_x4_direction = 0.0, 121 | add_noise_to_inpainted_x4_type = 0.0, 122 | 123 | # these are the inpainting targets 124 | center_of_mass = np.zeros(3), # center of mass of x1; already centered to zero above 125 | surface = np.zeros((75,3)), 126 | electrostatics = np.zeros(75), 127 | pharm_types = np.zeros(5, dtype = int), 128 | pharm_pos = np.zeros((5, 3)), 129 | pharm_direction = np.zeros((5, 3)), 130 | 131 | ) 132 | generated_samples = generated_samples + generated_samples_batch 133 | 134 | with open(f"samples/GDB_unconditional/x1x2/samples_{file_index}.pickle", "wb") as f: 135 | pickle.dump(generated_samples, f) 136 | -------------------------------------------------------------------------------- /examples/paper_experiments/run_inference_gdb_unconditional_x1x3.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | from shepherd.shepherd_score_utils.generate_point_cloud import ( 3 | get_atom_coords, 4 | get_atomic_vdw_radii, 5 | get_molecular_surface, 6 | get_electrostatics, 7 | get_electrostatics_given_point_charges, 8 | ) 9 | from shepherd.shepherd_score_utils.pharm_utils.pharmacophore import get_pharmacophores 10 | from shepherd.shepherd_score_utils.conformer_generation import update_mol_coordinates 11 | 12 | print('importing rdkit') 13 | import rdkit 14 | from rdkit.Chem import rdDetermineBonds 15 | 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | 19 | print('importing torch') 20 | import torch 21 | import torch_geometric 22 | from torch_geometric.nn import radius_graph 23 | import torch_scatter 24 | 25 | import pickle 26 | from copy import deepcopy 27 | import os 28 | import multiprocessing 29 | from tqdm import tqdm 30 | 31 | print('importing lightning') 32 | import pytorch_lightning as pl 33 | 
from pytorch_lightning.callbacks import ModelCheckpoint 34 | from pytorch_lightning.loggers import CSVLogger 35 | 36 | from shepherd.lightning_module import LightningModule 37 | from shepherd.datasets import HeteroDataset 38 | 39 | import importlib 40 | 41 | from shepherd.inference import * 42 | 43 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 44 | 45 | chkpt = '../../data/shepherd_chkpts/x1x3_diffusion_gdb17_20240824_submission.ckpt' 46 | 47 | model_pl = LightningModule.load_from_checkpoint(chkpt) 48 | params = model_pl.params 49 | model_pl.to(device) 50 | model_pl.model.device = device 51 | 52 | import argparse 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("batch_size", type=int) # should be set to 1 if using cpus 55 | parser.add_argument("num_batches", type=int) 56 | parser.add_argument("n_atoms", type=int) 57 | parser.add_argument("N_x4", type=int) 58 | parser.add_argument("file_index", type=int) 59 | args = parser.parse_args() 60 | 61 | batch_size = args.batch_size 62 | n_atoms = args.n_atoms 63 | num_batches = args.num_batches 64 | file_index = args.file_index 65 | N_x4 = args.N_x4 66 | 67 | 68 | 69 | generated_samples = [] 70 | for n in range(num_batches): 71 | 72 | # only use to break symmetry during unconditional generation 73 | T = params['noise_schedules']['x1']['T'] 74 | inject_noise_at_ts = list(np.arange(130, 80, -1)) # [150] 75 | inject_noise_scales = [1.0] * len(inject_noise_at_ts) 76 | harmonize = True 77 | harmonize_ts = [80] 78 | harmonize_jumps = [20] 79 | 80 | 81 | generated_samples_batch = inference_sample( 82 | model_pl, 83 | batch_size = batch_size, 84 | 85 | N_x1 = n_atoms, 86 | N_x4 = N_x4, # must equal len(pharm_types) if inpainting 87 | 88 | unconditional = True, 89 | 90 | prior_noise_scale = 1.0, 91 | denoising_noise_scale = 1.0, 92 | 93 | # only use to break symmetry during unconditional generation 94 | inject_noise_at_ts = inject_noise_at_ts, #[], 95 | inject_noise_scales = inject_noise_scales, #[], 96 | harmonize = harmonize, # False 97 | harmonize_ts = harmonize_ts, #[], 98 | harmonize_jumps = harmonize_jumps, #[], 99 | 100 | 101 | # all the below options are only relevant if unconditional is False 102 | 103 | inpaint_x2_pos = False, 104 | 105 | inpaint_x3_pos = False, 106 | inpaint_x3_x = False, 107 | 108 | inpaint_x4_pos = False, 109 | inpaint_x4_direction = False, 110 | inpaint_x4_type = False, 111 | 112 | stop_inpainting_at_time_x2 = 0.0, # range from 0.0 to 1.0 (fraction of T) 113 | add_noise_to_inpainted_x2_pos = 0.0, 114 | 115 | stop_inpainting_at_time_x3 = 0.0, # range from 0.0 to 1.0 (fraction of T) 116 | add_noise_to_inpainted_x3_pos = 0.0, 117 | add_noise_to_inpainted_x3_x = 0.0, 118 | 119 | stop_inpainting_at_time_x4 = 0.0, # range from 0.0 to 1.0 (fraction of T) 120 | add_noise_to_inpainted_x4_pos = 0.0, 121 | add_noise_to_inpainted_x4_direction = 0.0, 122 | add_noise_to_inpainted_x4_type = 0.0, 123 | 124 | # these are the inpainting targets 125 | center_of_mass = np.zeros(3), # center of mass of x1; already centered to zero above 126 | surface = np.zeros((75,3)), 127 | electrostatics = np.zeros(75), 128 | pharm_types = np.zeros(5, dtype = int), 129 | pharm_pos = np.zeros((5, 3)), 130 | pharm_direction = np.zeros((5, 3)), 131 | 132 | ) 133 | generated_samples = generated_samples + generated_samples_batch 134 | 135 | with open(f"samples/GDB_unconditional/x1x3/samples_{file_index}.pickle", "wb") as f: 136 | pickle.dump(generated_samples, f) 137 | 
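# A hedged usage sketch (argument values are illustrative). The positional CLI
# arguments are batch_size, num_batches, n_atoms, N_x4, and file_index:
#
#     python run_inference_gdb_unconditional_x1x3.py 8 4 25 5 0
#
# Assuming each call to `inference_sample` yields `batch_size` samples, this
# writes 8 * 4 = 32 samples, which can be read back with:
#
#     import pickle
#     with open("samples/GDB_unconditional/x1x3/samples_0.pickle", "rb") as f:
#         samples = pickle.load(f)
#     print(len(samples))  # expected: 32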
-------------------------------------------------------------------------------- /src/shepherd/model_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standalone model loader for ShEPhERD that works outside of Streamlit. 3 | """ 4 | import torch 5 | from typing import Literal, Optional 6 | from pathlib import Path 7 | 8 | from shepherd.lightning_module import LightningModule 9 | from shepherd.checkpoint_manager import get_checkpoint_path 10 | 11 | 12 | def load_model( 13 | model_type: Literal['mosesaq', 'gdb_x2', 'gdb_x3', 'gdb_x4'] = 'mosesaq', 14 | device: Optional[str] = None, 15 | local_data_dir: Optional[str] = None, 16 | cache_dir: Optional[str] = None, 17 | force_download: bool = False, 18 | local_checkpoint_path: Optional[str] = None, 19 | ) -> LightningModule: 20 | """ 21 | Load a ShEPhERD model with automatic checkpoint downloading. 22 | 23 | This function provides a clean interface for loading ShEPhERD models 24 | that works both in research environments and production deployments. 25 | 26 | Arguments 27 | --------- 28 | model_type: Type of model to load 29 | - 'mosesaq': MOSES-aq trained with shape, electrostatics, and pharmacophores 30 | - 'gdb_x2': GDB17 trained with shape conditioning 31 | - 'gdb_x3': GDB17 trained with shape and electrostatics 32 | - 'gdb_x4': GDB17 trained with pharmacophores 33 | device: Device to load model on ('cuda', 'cpu', or None for auto-detection) 34 | local_data_dir: Directory containing local checkpoints (for backward compatibility) 35 | cache_dir: Directory to cache downloaded checkpoints (None uses default HF cache) 36 | force_download: Whether to force download even if local checkpoint exists 37 | local_checkpoint_path: Path to local checkpoint 38 | If this is provided, it will override the model type and download logic. 
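        For example (hypothetical path):
        ``load_model(local_checkpoint_path='data/shepherd_chkpts/my_run.ckpt')``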
39 | 40 | Returns 41 | ------- 42 | Loaded and initialized ShEPhERD model ready for inference 43 | 44 | Example 45 | ------- 46 | >>> # Load default MOSES-aq model 47 | >>> model = load_model() 48 | 49 | >>> # Load specific model type on GPU 50 | >>> model = load_model('gdb_x3', device='cuda') 51 | 52 | >>> # Force download latest version 53 | >>> model = load_model('mosesaq', force_download=True) 54 | """ 55 | if device is None: 56 | device = "cuda" if torch.cuda.is_available() else "cpu" 57 | 58 | if local_checkpoint_path is not None: 59 | try: 60 | device_obj = torch.device(device) 61 | model_pl = LightningModule.load_from_checkpoint( 62 | local_checkpoint_path, 63 | weights_only=True, 64 | map_location=device_obj 65 | ) 66 | 67 | model_pl.eval() 68 | model_pl.model.device = device_obj 69 | 70 | print(f"Successfully loaded {model_type} model from local checkpoint.") 71 | return model_pl 72 | except Exception as e: 73 | raise RuntimeError(f"Failed to load model from local checkpoint: {str(e)}") from e 74 | 75 | try: 76 | # Get checkpoint path with automatic downloading 77 | model_path = get_checkpoint_path( 78 | model_type=model_type, 79 | local_data_dir=local_data_dir, 80 | cache_dir=cache_dir, 81 | force_download=force_download 82 | ) 83 | 84 | print(f"Loading {model_type} model from: {model_path}") 85 | print(f"Using device: {device}") 86 | 87 | device_obj = torch.device(device) 88 | model_pl = LightningModule.load_from_checkpoint( 89 | model_path, 90 | weights_only=True, 91 | map_location=device_obj 92 | ) 93 | 94 | model_pl.eval() 95 | model_pl.model.device = device_obj 96 | 97 | print(f"Successfully loaded {model_type} model") 98 | return model_pl 99 | 100 | except Exception as e: 101 | raise RuntimeError(f"Failed to load {model_type} model: {str(e)}") from e 102 | 103 | 104 | def get_model_info() -> dict: 105 | """ 106 | Get information about available ShEPhERD models. 107 | 108 | Returns 109 | ------- 110 | Dictionary mapping model types to their descriptions 111 | """ 112 | from shepherd.checkpoint_manager import CheckpointManager 113 | 114 | manager = CheckpointManager() 115 | return manager.get_available_models() 116 | 117 | 118 | def clear_model_cache(model_type: Optional[str] = None, cache_dir: Optional[str] = None): 119 | """ 120 | Clear cached model checkpoints. 121 | 122 | Arguments 123 | --------- 124 | model_type: Specific model type to clear, or None to clear all 125 | cache_dir: Cache directory to clear from (None uses default HF cache) 126 | """ 127 | from shepherd.checkpoint_manager import CheckpointManager 128 | 129 | manager = CheckpointManager(cache_dir=cache_dir) 130 | manager.clear_cache(model_type=model_type) 131 | -------------------------------------------------------------------------------- /src/shepherd/model/equiformer_v2/ocpmodels/models/gemnet_oc/layers/spherical_basis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | This source code is licensed under the MIT license found in the 4 | LICENSE file in the root directory of this source tree. 5 | """ 6 | 7 | import torch 8 | 9 | from ocpmodels.modules.scaling import ScaleFactor 10 | 11 | from .basis_utils import get_sph_harm_basis 12 | from .radial_basis import GaussianBasis, RadialBasis 13 | 14 | 15 | class CircularBasisLayer(torch.nn.Module): 16 | """ 17 | 2D Fourier Bessel Basis 18 | 19 | Arguments 20 | --------- 21 | num_spherical: int 22 | Number of basis functions. Controls the maximum frequency. 
23 | radial_basis: RadialBasis 24 | Radial basis function. 25 | cbf: dict 26 | Name and hyperparameters of the circular basis function. 27 | scale_basis: bool 28 | Whether to scale the basis values for better numerical stability. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | num_spherical: int, 34 | radial_basis: RadialBasis, 35 | cbf: dict, 36 | scale_basis: bool = False, 37 | ): 38 | super().__init__() 39 | 40 | self.radial_basis = radial_basis 41 | 42 | self.scale_basis = scale_basis 43 | if self.scale_basis: 44 | self.scale_cbf = ScaleFactor() 45 | 46 | cbf_name = cbf["name"].lower() 47 | cbf_hparams = cbf.copy() 48 | del cbf_hparams["name"] 49 | 50 | if cbf_name == "gaussian": 51 | self.cosφ_basis = GaussianBasis( 52 | start=-1, stop=1, num_gaussians=num_spherical, **cbf_hparams 53 | ) 54 | elif cbf_name == "spherical_harmonics": 55 | self.cosφ_basis = get_sph_harm_basis( 56 | num_spherical, zero_m_only=True 57 | ) 58 | else: 59 | raise ValueError(f"Unknown cosine basis function '{cbf_name}'.") 60 | 61 | def forward(self, D_ca, cosφ_cab): 62 | rad_basis = self.radial_basis(D_ca) # (num_edges, num_radial) 63 | cir_basis = self.cosφ_basis(cosφ_cab) # (num_triplets, num_spherical) 64 | 65 | if self.scale_basis: 66 | cir_basis = self.scale_cbf(cir_basis) 67 | 68 | return rad_basis, cir_basis 69 | # (num_edges, num_radial), (num_triplets, num_spherical) 70 | 71 | 72 | class SphericalBasisLayer(torch.nn.Module): 73 | """ 74 | 3D Fourier Bessel Basis 75 | 76 | Arguments 77 | --------- 78 | num_spherical: int 79 | Number of basis functions. Controls the maximum frequency. 80 | radial_basis: RadialBasis 81 | Radial basis functions. 82 | sbf: dict 83 | Name and hyperparameters of the spherical basis function. 84 | scale_basis: bool 85 | Whether to scale the basis values for better numerical stability. 
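    Example
    -------
    A minimal usage sketch (assumes `rbf` is an already-constructed
    RadialBasis and that `D_ca`, `cosφ_cab`, and `θ_cabd` come from the
    quadruplet geometry; the values are illustrative):

        sbf_layer = SphericalBasisLayer(
            num_spherical=7,
            radial_basis=rbf,
            sbf={"name": "spherical_harmonics"},
        )
        rad_basis, sph_basis = sbf_layer(D_ca, cosφ_cab, θ_cabd)
        # rad_basis: (num_edges, num_radial)
        # sph_basis: (num_quadruplets, num_spherical**2)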
86 | """ 87 | 88 | def __init__( 89 | self, 90 | num_spherical: int, 91 | radial_basis: RadialBasis, 92 | sbf: dict, 93 | scale_basis: bool = False, 94 | ): 95 | super().__init__() 96 | 97 | self.num_spherical = num_spherical 98 | self.radial_basis = radial_basis 99 | 100 | self.scale_basis = scale_basis 101 | if self.scale_basis: 102 | self.scale_sbf = ScaleFactor() 103 | 104 | sbf_name = sbf["name"].lower() 105 | sbf_hparams = sbf.copy() 106 | del sbf_hparams["name"] 107 | 108 | if sbf_name == "spherical_harmonics": 109 | self.spherical_basis = get_sph_harm_basis( 110 | num_spherical, zero_m_only=False 111 | ) 112 | 113 | elif sbf_name == "legendre_outer": 114 | circular_basis = get_sph_harm_basis( 115 | num_spherical, zero_m_only=True 116 | ) 117 | self.spherical_basis = lambda cosφ, ϑ: ( 118 | circular_basis(cosφ)[:, :, None] 119 | * circular_basis(torch.cos(ϑ))[:, None, :] 120 | ).reshape(cosφ.shape[0], -1) 121 | 122 | elif sbf_name == "gaussian_outer": 123 | self.circular_basis = GaussianBasis( 124 | start=-1, stop=1, num_gaussians=num_spherical, **sbf_hparams 125 | ) 126 | self.spherical_basis = lambda cosφ, ϑ: ( 127 | self.circular_basis(cosφ)[:, :, None] 128 | * self.circular_basis(torch.cos(ϑ))[:, None, :] 129 | ).reshape(cosφ.shape[0], -1) 130 | 131 | else: 132 | raise ValueError(f"Unknown spherical basis function '{sbf_name}'.") 133 | 134 | def forward(self, D_ca, cosφ_cab, θ_cabd): 135 | rad_basis = self.radial_basis(D_ca) 136 | sph_basis = self.spherical_basis(cosφ_cab, θ_cabd) 137 | # (num_quadruplets, num_spherical**2) 138 | 139 | if self.scale_basis: 140 | sph_basis = self.scale_sbf(sph_basis) 141 | 142 | return rad_basis, sph_basis 143 | # (num_edges, num_radial), (num_quadruplets, num_spherical**2) 144 | --------------------------------------------------------------------------------