├── .gitignore ├── LICENSE ├── README.md ├── alanine_dipeptide_files ├── ala2.pdb └── fes.dat ├── assets ├── ala2.gif └── systems.png ├── configs ├── ala │ ├── base.yml │ ├── dpp.yml │ ├── forcenet.yml │ ├── gemnet-T-scale.json │ ├── gemnet-T.yml │ ├── gemnet-dT-scale.json │ ├── gemnet-dT.yml │ ├── nequip │ │ └── 40k.yml │ └── schnet.yml ├── lips │ ├── base.yml │ ├── dpp.yml │ ├── forcenet.yml │ ├── gemnet-T-scale.json │ ├── gemnet-T.yml │ ├── gemnet-dT-scale.json │ ├── gemnet-dT.yml │ ├── nequip │ │ └── lips20k.yaml │ └── schnet.yml ├── md17 │ ├── base.yml │ ├── dimenet.yml │ ├── forcenet.yml │ ├── gemnet-T-scale.json │ ├── gemnet-T.yml │ ├── gemnet-dT-scale.json │ ├── gemnet-dT.yml │ ├── nequip │ │ ├── aspirin.yml │ │ ├── ethanol.yml │ │ ├── naphthalene.yml │ │ └── salicylic_acid.yml │ └── schnet.yml ├── simulate │ ├── ala.yml │ ├── aspirin.yml │ ├── ethanol.yml │ ├── lips.yml │ ├── naphthalene.yml │ ├── nequip_data_cfg │ │ ├── ala.yml │ │ ├── aspirin.yml │ │ ├── ethanol.yml │ │ ├── lips.yml │ │ ├── naphthalene.yml │ │ ├── salicylic_acid.yml │ │ └── water.yml │ ├── salicylic_acid.yml │ └── water.yml └── water │ ├── base.yml │ ├── dpp.yml │ ├── forcenet.yml │ ├── gemnet-T-scale.json │ ├── gemnet-T.yml │ ├── gemnet-dT-scale.json │ ├── gemnet-dT.yml │ ├── nequip │ ├── nequip_10k.yml │ ├── nequip_1k.yml │ └── nequip_90k.yml │ └── schnet.yml ├── deeppot_se ├── ala40k │ └── input.json ├── lips20k │ └── input.json ├── md17_aspirin_10k │ └── input.json ├── md17_ehtanol_10k │ └── input.json ├── md17_naphthalene_10k │ └── input.json ├── md17_salicylic_acid_10k │ └── input.json ├── water10k │ └── input.json ├── water1k │ └── input.json └── water90k │ └── input.json ├── example_model └── water_1k_schnet │ └── checkpoints │ ├── best_checkpoint.pt │ ├── checkpoint.pt │ └── config.yml ├── example_sim ├── ala_nequip │ ├── COLVAR │ ├── HILLS │ ├── atoms.traj │ ├── fes.dat │ ├── test_metric.json │ ├── test_metric.log │ └── thermo.log ├── aspirin_dimenet │ ├── atoms.traj │ ├── test_metric.json │ └── thermo.log ├── lips_gemnet-t │ ├── atoms.traj │ ├── test_metric.json │ └── thermo.log ├── water-10k_gemnet-t │ ├── atoms.traj │ ├── test_metric.json │ └── thermo.log └── water-1k_schnet │ ├── atoms.traj │ ├── test_metric.json │ └── thermo.log ├── fit_scaling.py ├── main.py ├── mdsim ├── __init__.py ├── common │ ├── __init__.py │ ├── const.py │ ├── data_parallel.py │ ├── deepmd_utils.py │ ├── distutils.py │ ├── flags.py │ ├── hpo_utils.py │ ├── logger.py │ ├── registry.py │ ├── transforms.py │ └── utils.py ├── datasets │ ├── __init__.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── atomic_radii.py │ │ ├── continuous_embeddings.py │ │ ├── khot_embeddings.py │ │ └── qmof_khot_embeddings.py │ └── lmdb_dataset.py ├── md │ ├── __init__.py │ ├── ase_utils.py │ └── integrator.py ├── models │ ├── __init__.py │ ├── base.py │ ├── cgcnn.py │ ├── dimenet.py │ ├── dimenet_plus_plus.py │ ├── forcenet.py │ ├── gemnet │ │ ├── fit_scaling.py │ │ ├── gemnet.py │ │ ├── initializers.py │ │ ├── layers │ │ │ ├── atom_update_block.py │ │ │ ├── base_layers.py │ │ │ ├── basis_utils.py │ │ │ ├── efficient.py │ │ │ ├── embedding_block.py │ │ │ ├── interaction_block.py │ │ │ ├── radial_basis.py │ │ │ ├── scaling.py │ │ │ └── spherical_basis.py │ │ └── utils.py │ ├── schnet.py │ ├── spinconv.py │ └── utils │ │ ├── __init__.py │ │ ├── activations.py │ │ └── basis.py ├── modules │ ├── __init__.py │ ├── evaluator.py │ ├── exponential_moving_average.py │ ├── loss.py │ ├── normalizer.py │ └── scheduler.py ├── tasks │ ├── __init__.py │ └── task.py 
└── trainers │ ├── __init__.py │ └── trainer.py ├── observable.ipynb ├── preprocessing ├── __init__.py ├── alanine_dipeptide.py ├── arrays_to_graphs.py ├── lips.py ├── md17.py └── water.py ├── setup.py └── simulate.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | results 3 | logs 4 | experimental 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *,cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | docs/source/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # IPython Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | 96 | # User directories 97 | Local 98 | 99 | # .DS_Store 100 | .DS_Store 101 | 102 | # VIM swap files 103 | *.swp 104 | 105 | # PyCharm 106 | .idea/ 107 | 108 | # VS Code 109 | .vscode/ 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Xiang Fu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /alanine_dipeptide_files/ala2.pdb: -------------------------------------------------------------------------------- 1 | TITLE Gromacs Runs One Microsecond At Cannonball Speeds 2 | MODEL 1 3 | ATOM 1 1HH3 ACE A 1 0.844 0.586 3.765 1.00 0.00 H 4 | ATOM 2 CH3 ACE A 1 -0.219 0.365 3.675 1.00 0.00 C 5 | ATOM 3 2HH3 ACE A 1 -0.540 -0.233 4.528 1.00 0.00 H 6 | ATOM 4 3HH3 ACE A 1 -0.785 1.295 3.650 1.00 0.00 H 7 | ATOM 5 C ACE A 1 -0.468 -0.410 2.396 1.00 0.00 C 8 | ATOM 6 O ACE A 1 -0.957 -1.533 2.440 1.00 0.00 O 9 | ATOM 7 N ALA A 2 -0.119 0.186 1.255 1.00 0.00 N 10 | ATOM 8 H ALA A 2 0.311 1.099 1.292 1.00 0.00 H 11 | ATOM 9 CA ALA A 2 -0.298 -0.382 -0.082 1.00 0.00 CA 12 | ATOM 10 HA ALA A 2 -0.029 -1.440 -0.049 1.00 0.00 H 13 | ATOM 11 CB ALA A 2 -1.777 -0.261 -0.486 1.00 0.00 C 14 | ATOM 12 HB1 ALA A 2 -2.404 -0.790 0.234 1.00 0.00 H 15 | ATOM 13 HB2 ALA A 2 -2.070 0.790 -0.509 1.00 0.00 H 16 | ATOM 14 HB3 ALA A 2 -1.935 -0.697 -1.472 1.00 0.00 H 17 | ATOM 15 C ALA A 2 0.628 0.315 -1.095 1.00 0.00 C 18 | ATOM 16 O ALA A 2 1.124 1.410 -0.831 1.00 0.00 O 19 | ATOM 17 N NME A 3 0.851 -0.319 -2.250 1.00 0.00 N 20 | ATOM 18 H NME A 3 0.402 -1.208 -2.397 1.00 0.00 H 21 | ATOM 19 CH3 NME A 3 1.698 0.205 -3.315 1.00 0.00 C 22 | ATOM 20 1HH3 NME A 3 1.307 1.165 -3.657 1.00 0.00 H 23 | ATOM 21 2HH3 NME A 3 1.721 -0.493 -4.152 1.00 0.00 H 24 | ATOM 22 3HH3 NME A 3 2.713 0.351 -2.940 1.00 0.00 H 25 | TER 26 | ENDMDL 27 | -------------------------------------------------------------------------------- /assets/ala2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/assets/ala2.gif -------------------------------------------------------------------------------- /assets/systems.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/assets/systems.png -------------------------------------------------------------------------------- /configs/ala/base.yml: -------------------------------------------------------------------------------- 1 | trainer: trainer 2 | no_energy: True 3 | 4 | dataset: 5 | src: DATAPATH/ala 6 | name: alanine-dipeptide 7 | size: 40k 8 | normalize_labels: True 9 | 10 | logger: 11 | name: wandb 12 | project: mdbench 13 | 14 | task: 15 | dataset: lmdb 16 | description: "Regressing to energies and forces" 17 | type: regression 18 | metric: mae 19 | labels: 20 | - potential energy 21 | grad_input: atomic forces 22 | train_on_free_atoms: True 23 | eval_on_free_atoms: True 24 | 25 | optim: 26 | batch_size: 5 27 | eval_batch_size: 5 28 | num_workers: 4 29 | lr_initial: 0.001 30 | optimizer: Adam 31 | optimizer_params: {"amsgrad": True} 32 | 33 | scheduler: ReduceLROnPlateau 34 | patience: 5 35 | factor: 0.8 36 | min_lr: 0.000001 37 | 38 | max_epochs: 2000 39 | force_coefficient: 100 40 | energy_coefficient: 1 41 | ema_decay: 0.999 42 | clip_grad_norm: 10 43 | 44 | early_stopping_time: 604800 45 | early_stopping_lr: 0.000001 -------------------------------------------------------------------------------- /configs/ala/dpp.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/ala/base.yml 3 | 4 | model: 5 | name: dimenetplusplus 6 | hidden_channels: 128 7 | out_emb_channels: 256 8 | int_emb_size: 64 9 | 
basis_emb_size: 8 10 | num_blocks: 4 11 | cutoff: 5.0 12 | envelope_exponent: 5 13 | num_radial: 6 14 | num_spherical: 7 15 | num_before_skip: 1 16 | num_after_skip: 2 17 | num_output_layers: 3 18 | regress_forces: True 19 | use_pbc: True 20 | otf_graph: True 21 | 22 | optim: 23 | loss_energy: mae 24 | loss_force: mae -------------------------------------------------------------------------------- /configs/ala/forcenet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/ala/base.yml 3 | 4 | model: 5 | name: forcenet 6 | num_interactions: 5 7 | cutoff: 6 8 | basis: "sphallmul" 9 | ablation: "none" 10 | depth_mlp_edge: 2 11 | depth_mlp_node: 1 12 | activation_str: "swish" 13 | decoder_activation_str: "swish" 14 | feat: "full" 15 | hidden_channels: 512 16 | decoder_hidden_channels: 512 17 | max_n: 3 18 | otf_graph: True -------------------------------------------------------------------------------- /configs/ala/gemnet-T-scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": null, 3 | "TripInteraction_1_had_rbf": 3.9750466346740723, 4 | "TripInteraction_1_sum_cbf": 2.1334712505340576, 5 | "AtomUpdate_1_sum": 0.6158230304718018, 6 | "TripInteraction_2_had_rbf": 3.596966028213501, 7 | "TripInteraction_2_sum_cbf": 2.0770750045776367, 8 | "AtomUpdate_2_sum": 0.49761834740638733, 9 | "TripInteraction_3_had_rbf": 3.83621883392334, 10 | "TripInteraction_3_sum_cbf": 2.0392439365386963, 11 | "AtomUpdate_3_sum": 0.39251139760017395, 12 | "TripInteraction_4_had_rbf": 3.904399871826172, 13 | "TripInteraction_4_sum_cbf": 2.099656343460083, 14 | "AtomUpdate_4_sum": 0.4656703472137451, 15 | "OutBlock_0_sum": 0.6115648150444031, 16 | "OutBlock_1_sum": 0.5632280111312866, 17 | "OutBlock_2_sum": 0.45566001534461975, 18 | "OutBlock_3_sum": 0.4641615152359009, 19 | "OutBlock_4_sum": 0.4313839077949524 20 | } -------------------------------------------------------------------------------- /configs/ala/gemnet-T.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/ala/base.yml 3 | 4 | model: 5 | name: gemnet_t 6 | num_spherical: 7 7 | num_radial: 6 8 | num_blocks: 4 9 | emb_size_atom: 128 10 | emb_size_edge: 128 11 | emb_size_trip: 64 12 | emb_size_rbf: 16 13 | emb_size_cbf: 16 14 | emb_size_bil_trip: 64 15 | num_before_skip: 1 16 | num_after_skip: 1 17 | num_concat: 1 18 | num_atom: 2 19 | cutoff: 5.0 20 | max_neighbors: 50 21 | rbf: 22 | name: gaussian 23 | envelope: 24 | name: polynomial 25 | exponent: 5 26 | cbf: 27 | name: spherical_harmonics 28 | output_init: HeOrthogonal 29 | activation: silu 30 | scale_file: configs/ala/gemnet-T-scale.json 31 | extensive: True 32 | otf_graph: True 33 | regress_forces: True 34 | direct_forces: False 35 | 36 | optim: 37 | batch_size: 5 38 | eval_batch_size: 5 39 | num_workers: 4 40 | lr_initial: 0.001 41 | optimizer: AdamW 42 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 43 | 44 | scheduler: ReduceLROnPlateau 45 | patience: 5 46 | factor: 0.8 47 | min_lr: 0.000001 48 | 49 | max_epochs: 2000 50 | force_coefficient: 0.999 51 | energy_coefficient: 0.001 52 | ema_decay: 0.999 53 | clip_grad_norm: 10 54 | loss_energy: mae 55 | loss_force: l2mae -------------------------------------------------------------------------------- /configs/ala/gemnet-dT-scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": null, 3 | 
"TripInteraction_1_had_rbf": 3.9720029830932617, 4 | "TripInteraction_1_sum_cbf": 2.133086919784546, 5 | "AtomUpdate_1_sum": 0.6158733367919922, 6 | "TripInteraction_2_had_rbf": 3.597200632095337, 7 | "TripInteraction_2_sum_cbf": 2.0771214962005615, 8 | "AtomUpdate_2_sum": 0.4976757764816284, 9 | "TripInteraction_3_had_rbf": 3.8365020751953125, 10 | "TripInteraction_3_sum_cbf": 2.0394160747528076, 11 | "AtomUpdate_3_sum": 0.39257684350013733, 12 | "TripInteraction_4_had_rbf": 3.9046010971069336, 13 | "TripInteraction_4_sum_cbf": 2.099813222885132, 14 | "AtomUpdate_4_sum": 0.46566471457481384, 15 | "OutBlock_0_sum": 0.6115648150444031, 16 | "OutBlock_0_had": 3.732174873352051, 17 | "OutBlock_1_sum": 0.5139416456222534, 18 | "OutBlock_1_had": 3.6697442531585693, 19 | "OutBlock_2_sum": 0.4596453011035919, 20 | "OutBlock_2_had": 3.6231987476348877, 21 | "OutBlock_3_sum": 0.49535492062568665, 22 | "OutBlock_3_had": 3.5756115913391113, 23 | "OutBlock_4_sum": 0.4333173930644989, 24 | "OutBlock_4_had": 3.768071413040161 25 | } -------------------------------------------------------------------------------- /configs/ala/gemnet-dT.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/ala/base.yml 3 | 4 | identifier: dT 5 | 6 | model: 7 | name: gemnet_t 8 | num_spherical: 7 9 | num_radial: 6 10 | num_blocks: 4 11 | emb_size_atom: 128 12 | emb_size_edge: 128 13 | emb_size_trip: 64 14 | emb_size_rbf: 16 15 | emb_size_cbf: 16 16 | emb_size_bil_trip: 64 17 | num_before_skip: 1 18 | num_after_skip: 1 19 | num_concat: 1 20 | num_atom: 2 21 | cutoff: 5.0 22 | max_neighbors: 50 23 | rbf: 24 | name: gaussian 25 | envelope: 26 | name: polynomial 27 | exponent: 5 28 | cbf: 29 | name: spherical_harmonics 30 | output_init: HeOrthogonal 31 | activation: silu 32 | scale_file: configs/ala/gemnet-dT-scale.json 33 | extensive: True 34 | otf_graph: True 35 | regress_forces: True 36 | direct_forces: True 37 | 38 | optim: 39 | batch_size: 5 40 | eval_batch_size: 5 41 | num_workers: 4 42 | lr_initial: 0.001 43 | optimizer: AdamW 44 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 45 | 46 | scheduler: ReduceLROnPlateau 47 | patience: 5 48 | factor: 0.8 49 | min_lr: 0.000001 50 | 51 | max_epochs: 2000 52 | force_coefficient: 0.999 53 | energy_coefficient: 0.001 54 | ema_decay: 0.999 55 | clip_grad_norm: 10 56 | loss_energy: mae 57 | loss_force: l2mae -------------------------------------------------------------------------------- /configs/ala/schnet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/ala/base.yml 3 | 4 | model: 5 | name: schnet 6 | hidden_channels: 64 7 | num_filters: 64 8 | num_interactions: 6 9 | num_gaussians: 25 10 | cutoff: 6.0 11 | use_pbc: True 12 | otf_graph: True -------------------------------------------------------------------------------- /configs/lips/base.yml: -------------------------------------------------------------------------------- 1 | trainer: trainer 2 | 3 | dataset: 4 | src: DATAPATH/lips 5 | name: lips 6 | size: 20k 7 | unit_kcal: False 8 | normalize_labels: True 9 | 10 | logger: 11 | name: wandb 12 | project: mdbench 13 | 14 | task: 15 | dataset: lmdb 16 | description: "Regressing to energies and forces" 17 | type: regression 18 | metric: mae 19 | labels: 20 | - potential energy 21 | grad_input: atomic forces 22 | train_on_free_atoms: True 23 | eval_on_free_atoms: True 24 | 25 | optim: 26 | batch_size: 1 27 | eval_batch_size: 1 28 
| num_workers: 4 29 | lr_initial: 0.001 30 | optimizer: Adam 31 | optimizer_params: {"amsgrad": True} 32 | 33 | scheduler: ReduceLROnPlateau 34 | patience: 5 35 | factor: 0.8 36 | min_lr: 0.000001 37 | 38 | max_epochs: 2000 39 | force_coefficient: 1000 40 | energy_coefficient: 1 41 | ema_decay: 0.999 42 | clip_grad_norm: 10 43 | 44 | early_stopping_time: 604800 45 | early_stopping_lr: 0.000001 -------------------------------------------------------------------------------- /configs/lips/dpp.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/lips/base.yml 3 | 4 | model: 5 | name: dimenetplusplus 6 | hidden_channels: 128 7 | out_emb_channels: 256 8 | int_emb_size: 64 9 | basis_emb_size: 8 10 | num_blocks: 4 11 | cutoff: 5.0 12 | envelope_exponent: 5 13 | num_radial: 6 14 | num_spherical: 7 15 | num_before_skip: 1 16 | num_after_skip: 2 17 | num_output_layers: 3 18 | regress_forces: True 19 | use_pbc: True 20 | otf_graph: True 21 | 22 | optim: 23 | loss_energy: mae 24 | loss_force: mae 25 | -------------------------------------------------------------------------------- /configs/lips/forcenet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/lips/base.yml 3 | 4 | identifier: size512 5 | 6 | model: 7 | name: forcenet 8 | num_interactions: 5 9 | cutoff: 6.0 10 | basis: "sphallmul" 11 | ablation: "none" 12 | depth_mlp_edge: 2 13 | depth_mlp_node: 1 14 | activation_str: "swish" 15 | decoder_activation_str: "swish" 16 | feat: "full" 17 | hidden_channels: 512 18 | decoder_hidden_channels: 512 19 | max_n: 3 20 | otf_graph: True 21 | 22 | -------------------------------------------------------------------------------- /configs/lips/gemnet-T-scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": null, 3 | "TripInteraction_1_had_rbf": 6.074735164642334, 4 | "TripInteraction_1_sum_cbf": 2.643386125564575, 5 | "AtomUpdate_1_sum": 0.6294099688529968, 6 | "TripInteraction_2_had_rbf": 5.459280014038086, 7 | "TripInteraction_2_sum_cbf": 2.8160929679870605, 8 | "AtomUpdate_2_sum": 0.6374367475509644, 9 | "TripInteraction_3_had_rbf": 5.262460231781006, 10 | "TripInteraction_3_sum_cbf": 2.5424845218658447, 11 | "AtomUpdate_3_sum": 0.561196506023407, 12 | "TripInteraction_4_had_rbf": 5.5935187339782715, 13 | "TripInteraction_4_sum_cbf": 2.9127252101898193, 14 | "AtomUpdate_4_sum": 0.6214905977249146, 15 | "OutBlock_0_sum": 0.8549189567565918, 16 | "OutBlock_1_sum": 0.7073656916618347, 17 | "OutBlock_2_sum": 0.6550492644309998, 18 | "OutBlock_3_sum": 0.6138262152671814, 19 | "OutBlock_4_sum": 0.6996177434921265 20 | } -------------------------------------------------------------------------------- /configs/lips/gemnet-T.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/lips/base.yml 3 | 4 | model: 5 | name: gemnet_t 6 | num_spherical: 7 7 | num_radial: 6 8 | num_blocks: 4 9 | emb_size_atom: 128 10 | emb_size_edge: 128 11 | emb_size_trip: 64 12 | emb_size_rbf: 16 13 | emb_size_cbf: 16 14 | emb_size_bil_trip: 64 15 | num_before_skip: 1 16 | num_after_skip: 1 17 | num_concat: 1 18 | num_atom: 2 19 | cutoff: 5.0 20 | max_neighbors: 50 21 | rbf: 22 | name: gaussian 23 | envelope: 24 | name: polynomial 25 | exponent: 5 26 | cbf: 27 | name: spherical_harmonics 28 | output_init: HeOrthogonal 29 | activation: silu 30 | scale_file: configs/lips/gemnet-T-scale.json 31 | 
extensive: True 32 | otf_graph: True 33 | regress_forces: True 34 | direct_forces: False 35 | use_pbc: True 36 | 37 | optim: 38 | batch_size: 1 39 | eval_batch_size: 1 40 | num_workers: 4 41 | lr_initial: 0.001 42 | optimizer: AdamW 43 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 44 | 45 | scheduler: ReduceLROnPlateau 46 | patience: 5 47 | factor: 0.8 48 | min_lr: 0.000001 49 | 50 | max_epochs: 2000 51 | force_coefficient: 0.999 52 | energy_coefficient: 0.001 53 | ema_decay: 0.999 54 | clip_grad_norm: 10 55 | loss_energy: mae 56 | loss_force: l2mae 57 | 58 | early_stopping_time: 604800 59 | early_stopping_lr: 0.000001 -------------------------------------------------------------------------------- /configs/lips/gemnet-dT-scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": null, 3 | "TripInteraction_1_had_rbf": 6.074735164642334, 4 | "TripInteraction_1_sum_cbf": 2.643386125564575, 5 | "AtomUpdate_1_sum": 0.6294099688529968, 6 | "TripInteraction_2_had_rbf": 5.459280014038086, 7 | "TripInteraction_2_sum_cbf": 2.8160929679870605, 8 | "AtomUpdate_2_sum": 0.6374367475509644, 9 | "TripInteraction_3_had_rbf": 5.262460231781006, 10 | "TripInteraction_3_sum_cbf": 2.5424845218658447, 11 | "AtomUpdate_3_sum": 0.561196506023407, 12 | "TripInteraction_4_had_rbf": 5.5935187339782715, 13 | "TripInteraction_4_sum_cbf": 2.9127252101898193, 14 | "AtomUpdate_4_sum": 0.6214906573295593, 15 | "OutBlock_0_sum": 0.8549189567565918, 16 | "OutBlock_0_had": 6.044666767120361, 17 | "OutBlock_1_sum": 0.6930992007255554, 18 | "OutBlock_1_had": 5.088397979736328, 19 | "OutBlock_2_sum": 0.7165416479110718, 20 | "OutBlock_2_had": 4.458869934082031, 21 | "OutBlock_3_sum": 0.7415679097175598, 22 | "OutBlock_3_had": 4.5923991203308105, 23 | "OutBlock_4_sum": 0.684806764125824, 24 | "OutBlock_4_had": 5.620660781860352 25 | } -------------------------------------------------------------------------------- /configs/lips/gemnet-dT.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/lips/base.yml 3 | 4 | identifier: dT 5 | 6 | model: 7 | name: gemnet_t 8 | num_spherical: 7 9 | num_radial: 6 10 | num_blocks: 4 11 | emb_size_atom: 128 12 | emb_size_edge: 128 13 | emb_size_trip: 64 14 | emb_size_rbf: 16 15 | emb_size_cbf: 16 16 | emb_size_bil_trip: 64 17 | num_before_skip: 1 18 | num_after_skip: 1 19 | num_concat: 1 20 | num_atom: 2 21 | cutoff: 5.0 22 | max_neighbors: 50 23 | rbf: 24 | name: gaussian 25 | envelope: 26 | name: polynomial 27 | exponent: 5 28 | cbf: 29 | name: spherical_harmonics 30 | output_init: HeOrthogonal 31 | activation: silu 32 | scale_file: configs/lips/gemnet-dT-scale.json 33 | extensive: True 34 | otf_graph: True 35 | regress_forces: True 36 | direct_forces: True 37 | use_pbc: True 38 | 39 | optim: 40 | batch_size: 1 41 | eval_batch_size: 1 42 | num_workers: 4 43 | lr_initial: 0.001 44 | optimizer: AdamW 45 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 46 | 47 | scheduler: ReduceLROnPlateau 48 | patience: 5 49 | factor: 0.8 50 | min_lr: 0.000001 51 | 52 | max_epochs: 2000 53 | force_coefficient: 0.999 54 | energy_coefficient: 0.001 55 | ema_decay: 0.999 56 | clip_grad_norm: 10 57 | loss_energy: mae 58 | loss_force: l2mae 59 | 60 | early_stopping_time: 604800 61 | early_stopping_lr: 0.000001 -------------------------------------------------------------------------------- /configs/lips/schnet.yml: 
-------------------------------------------------------------------------------- 1 | includes: 2 | - configs/lips/base.yml 3 | 4 | model: 5 | name: schnet 6 | hidden_channels: 64 7 | num_filters: 64 8 | num_interactions: 6 9 | num_gaussians: 25 10 | cutoff: 6.0 11 | use_pbc: True 12 | otf_graph: True -------------------------------------------------------------------------------- /configs/md17/base.yml: -------------------------------------------------------------------------------- 1 | trainer: trainer 2 | 3 | dataset: 4 | src: DATAPATH/md17 5 | name: md17 6 | size: 10k 7 | molecule: aspirin 8 | normalize_labels: False 9 | 10 | logger: 11 | name: wandb 12 | project: mdbench 13 | 14 | task: 15 | dataset: lmdb 16 | description: "Regressing to energies and forces" 17 | type: regression 18 | metric: mae 19 | labels: 20 | - potential energy 21 | grad_input: atomic forces 22 | train_on_free_atoms: True 23 | eval_on_free_atoms: True 24 | 25 | optim: 26 | early_stopping_time: 604800 27 | early_stopping_lr: 0.000001 -------------------------------------------------------------------------------- /configs/md17/dimenet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/md17/base.yml 3 | 4 | model: 5 | name: dimenet 6 | hidden_channels: 128 7 | num_blocks: 6 8 | cutoff: 5.0 9 | envelope_exponent: 5 10 | num_bilinear: 8 11 | num_spherical: 7 12 | num_radial: 6 13 | num_before_skip: 1 14 | num_after_skip: 2 15 | num_output_layers: 3 16 | regress_forces: True 17 | use_pbc: False 18 | otf_graph: True 19 | 20 | optim: 21 | batch_size: 32 22 | eval_batch_size: 32 23 | num_workers: 4 24 | lr_initial: 0.001 25 | optimizer: AdamW 26 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 27 | 28 | scheduler: ReduceLROnPlateau 29 | patience: 5 30 | factor: 0.8 31 | min_lr: 0.000001 32 | 33 | max_epochs: 10000 34 | force_coefficient: 1000 35 | energy_coefficient: 1 36 | ema_decay: 0.999 37 | clip_grad_norm: 10 38 | loss_energy: mae 39 | loss_force: mae 40 | 41 | -------------------------------------------------------------------------------- /configs/md17/forcenet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/md17/base.yml 3 | 4 | identifier: size512 5 | 6 | model: 7 | name: forcenet 8 | num_interactions: 5 9 | cutoff: 6 10 | basis: "sphallmul" 11 | ablation: "none" 12 | depth_mlp_edge: 2 13 | depth_mlp_node: 1 14 | activation_str: "swish" 15 | decoder_activation_str: "swish" 16 | feat: "full" 17 | hidden_channels: 512 18 | decoder_hidden_channels: 512 19 | max_n: 3 20 | otf_graph: True 21 | 22 | optim: 23 | batch_size: 8 24 | eval_batch_size: 8 25 | num_workers: 8 26 | lr_initial: 0.001 27 | max_epochs: 10000 28 | energy_coefficient: 0 29 | scheduler: ReduceLROnPlateau 30 | patience: 5 31 | factor: 0.8 32 | min_lr: 0.000001 -------------------------------------------------------------------------------- /configs/md17/gemnet-T-scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": null, 3 | "TripInteraction_1_had_rbf": 3.8454270362854004, 4 | "TripInteraction_1_sum_cbf": 2.3950836658477783, 5 | "AtomUpdate_1_sum": 0.692237138748169, 6 | "TripInteraction_2_had_rbf": 3.8384854793548584, 7 | "TripInteraction_2_sum_cbf": 2.339738368988037, 8 | "AtomUpdate_2_sum": 0.5826907157897949, 9 | "TripInteraction_3_had_rbf": 3.925180673599243, 10 | "TripInteraction_3_sum_cbf": 2.176547050476074, 11 | 
"AtomUpdate_3_sum": 0.4893060326576233, 12 | "TripInteraction_4_had_rbf": 3.838670492172241, 13 | "TripInteraction_4_sum_cbf": 2.2910943031311035, 14 | "AtomUpdate_4_sum": 0.5455359220504761, 15 | "OutBlock_0_sum": 0.6474325656890869, 16 | "OutBlock_1_sum": 0.6421718001365662, 17 | "OutBlock_2_sum": 0.5285606980323792, 18 | "OutBlock_3_sum": 0.561263382434845, 19 | "OutBlock_4_sum": 0.5111677050590515 20 | } -------------------------------------------------------------------------------- /configs/md17/gemnet-T.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/md17/base.yml 3 | 4 | model: 5 | name: gemnet_t 6 | num_spherical: 7 7 | num_radial: 6 8 | num_blocks: 4 9 | emb_size_atom: 128 10 | emb_size_edge: 128 11 | emb_size_trip: 64 12 | emb_size_rbf: 16 13 | emb_size_cbf: 16 14 | emb_size_bil_trip: 64 15 | num_before_skip: 1 16 | num_after_skip: 1 17 | num_concat: 1 18 | num_atom: 2 19 | cutoff: 5.0 20 | max_neighbors: 50 21 | rbf: 22 | name: gaussian 23 | envelope: 24 | name: polynomial 25 | exponent: 5 26 | cbf: 27 | name: spherical_harmonics 28 | output_init: HeOrthogonal 29 | activation: silu 30 | scale_file: configs/md17/gemnet-T-scale.json 31 | extensive: True 32 | otf_graph: True 33 | regress_forces: True 34 | direct_forces: False 35 | 36 | optim: 37 | batch_size: 1 38 | eval_batch_size: 1 39 | num_workers: 4 40 | lr_initial: 0.001 41 | optimizer: AdamW 42 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 43 | 44 | scheduler: ReduceLROnPlateau 45 | patience: 5 46 | factor: 0.8 47 | min_lr: 0.000001 48 | 49 | max_epochs: 10000 50 | force_coefficient: 1000 51 | energy_coefficient: 1 52 | ema_decay: 0.999 53 | clip_grad_norm: 10 54 | loss_energy: mae 55 | loss_force: l2mae 56 | -------------------------------------------------------------------------------- /configs/md17/gemnet-dT-scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": null, 3 | "TripInteraction_1_had_rbf": 3.8454270362854004, 4 | "TripInteraction_1_sum_cbf": 2.3950836658477783, 5 | "AtomUpdate_1_sum": 0.692237138748169, 6 | "TripInteraction_2_had_rbf": 3.8384854793548584, 7 | "TripInteraction_2_sum_cbf": 2.339738368988037, 8 | "AtomUpdate_2_sum": 0.5826907157897949, 9 | "TripInteraction_3_had_rbf": 3.9251813888549805, 10 | "TripInteraction_3_sum_cbf": 2.1765472888946533, 11 | "AtomUpdate_3_sum": 0.4893060028553009, 12 | "TripInteraction_4_had_rbf": 3.8386707305908203, 13 | "TripInteraction_4_sum_cbf": 2.2910943031311035, 14 | "AtomUpdate_4_sum": 0.5455359220504761, 15 | "OutBlock_0_sum": 0.6474325656890869, 16 | "OutBlock_0_had": 3.910574197769165, 17 | "OutBlock_1_sum": 0.5782961845397949, 18 | "OutBlock_1_had": 3.810222625732422, 19 | "OutBlock_2_sum": 0.5464082956314087, 20 | "OutBlock_2_had": 3.691774368286133, 21 | "OutBlock_3_sum": 0.6221889853477478, 22 | "OutBlock_3_had": 3.926927089691162, 23 | "OutBlock_4_sum": 0.49617061018943787, 24 | "OutBlock_4_had": 3.650491714477539 25 | } -------------------------------------------------------------------------------- /configs/md17/gemnet-dT.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/md17/base.yml 3 | 4 | identifier: dT 5 | 6 | model: 7 | name: gemnet_t 8 | num_spherical: 7 9 | num_radial: 6 10 | num_blocks: 4 11 | emb_size_atom: 128 12 | emb_size_edge: 128 13 | emb_size_trip: 64 14 | emb_size_rbf: 16 15 | emb_size_cbf: 16 16 | emb_size_bil_trip: 64 17 
| num_before_skip: 1 18 | num_after_skip: 1 19 | num_concat: 1 20 | num_atom: 2 21 | cutoff: 5.0 22 | max_neighbors: 50 23 | rbf: 24 | name: gaussian 25 | envelope: 26 | name: polynomial 27 | exponent: 5 28 | cbf: 29 | name: spherical_harmonics 30 | output_init: HeOrthogonal 31 | activation: silu 32 | scale_file: configs/md17/gemnet-dT-scale.json 33 | extensive: True 34 | otf_graph: True 35 | regress_forces: True 36 | direct_forces: True 37 | 38 | optim: 39 | batch_size: 1 40 | eval_batch_size: 1 41 | num_workers: 4 42 | lr_initial: 0.001 43 | optimizer: AdamW 44 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 45 | 46 | scheduler: ReduceLROnPlateau 47 | patience: 5 48 | factor: 0.8 49 | min_lr: 0.000001 50 | 51 | max_epochs: 10000 52 | force_coefficient: 1000 53 | energy_coefficient: 1 54 | ema_decay: 0.999 55 | clip_grad_norm: 10 56 | loss_energy: mae 57 | loss_force: l2mae 58 | -------------------------------------------------------------------------------- /configs/md17/schnet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/md17/base.yml 3 | 4 | model: 5 | name: schnet 6 | hidden_channels: 64 7 | num_filters: 64 8 | num_interactions: 6 9 | num_gaussians: 25 10 | cutoff: 5.0 11 | use_pbc: False 12 | otf_graph: True 13 | 14 | optim: 15 | batch_size: 100 16 | eval_batch_size: 100 17 | num_workers: 8 18 | lr_initial: 0.001 19 | lr_gamma: 0.1 20 | optimizer: Adam 21 | warmup_steps: 5000 22 | warmup_factor: 0.2 23 | max_epochs: 10000 24 | energy_coefficient: 0.05 25 | force_coefficient: 0.95 26 | scheduler: ReduceLROnPlateau 27 | patience: 5 28 | factor: 0.8 29 | min_lr: 0.000001 -------------------------------------------------------------------------------- /configs/simulate/ala.yml: -------------------------------------------------------------------------------- 1 | model_dir: MODELPATH 2 | test_dataset_src: DATAPATH/ala/40k/test 3 | dataset_src: DATAPATH/ala/local_minimas 4 | nequip_data_config: configs/simulate/nequip_data_cfg/ala.yml 5 | dp_data_path: DATAPATH/ala_dp/40k/test 6 | sim_type: ocp 7 | seed: 0 8 | init_idx: 0 9 | identifier: 5ns 10 | save_freq: 500 11 | steps: 2500000 12 | max_test_points: 10000 13 | plumed: True 14 | 15 | kcal: False # legacy. want to remove soon. 16 | T_init: 300. 17 | integrator: Langevin 18 | integrator_config: {"timestep": 2., "temperature_K": 300., "friction": 0.5} 19 | -------------------------------------------------------------------------------- /configs/simulate/aspirin.yml: -------------------------------------------------------------------------------- 1 | model_dir: MODELPATH 2 | dataset_src: DATAPATH/md17/aspirin/10k/test 3 | dp_data_path: DATAPATH/md17/aspirin/10k/DP/test 4 | nequip_data_config: configs/simulate/nequip_data_cfg/aspirin.yml 5 | identifier: 300ps 6 | sim_type: ocp 7 | seed: 123 8 | save_freq: 100 9 | steps: 600000 10 | max_test_points: 10000 11 | 12 | kcal: False 13 | T_init: 500. 
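# The `integrator` / `integrator_config` pair below is consumed by the
# simulation driver: `integrator` names a class and `integrator_config`
# supplies its keyword arguments. A minimal Python sketch of that wiring,
# assuming an ASE-style Atoms object (the actual driver is simulate.py and
# the integrator lives in mdsim/md/integrator.py; exact signatures may differ):
#
#     from mdsim.md.integrator import NoseHoover   # module path taken from the tree above
#     dyn = NoseHoover(atoms, **integrator_config) # timestep in fs, temperature in K
#     dyn.run(steps)                               # 600000 steps x 0.5 fs = 300 ps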
14 | integrator: NoseHoover 15 | integrator_config: {"timestep": 0.5, "temperature": 500., "ttime": 20.} 16 | -------------------------------------------------------------------------------- /configs/simulate/ethanol.yml: -------------------------------------------------------------------------------- 1 | model_dir: MODELPATH 2 | dataset_src: DATAPATH/md17/ethanol/10k/test 3 | dp_data_path: DATAPATH/md17/ethanol/10k/DP/test 4 | nequip_data_config: configs/simulate/nequip_data_cfg/ethanol.yml 5 | identifier: 300ps 6 | sim_type: ocp 7 | seed: 123 8 | save_freq: 100 9 | steps: 600000 10 | max_test_points: 10000 11 | 12 | kcal: False 13 | T_init: 500. 14 | integrator: NoseHoover 15 | integrator_config: {"timestep": 0.5, "temperature": 500., "ttime": 20.} 16 | -------------------------------------------------------------------------------- /configs/simulate/lips.yml: -------------------------------------------------------------------------------- 1 | model_dir: MODELPATH 2 | dataset_src: DATAPATH/lips/20k/test 3 | dp_data_path: DATAPATH/lips/20k/DP/test 4 | nequip_data_config: configs/simulate/nequip_data_cfg/lips.yml 5 | sim_type: ocp 6 | seed: 123 7 | identifier: 50ps 8 | save_freq: 50 9 | steps: 200000 10 | max_test_points: 5000 11 | 12 | kcal: False 13 | T_init: 520. 14 | integrator: NoseHoover 15 | integrator_config: {"timestep": 0.25, "temperature": 520., "ttime": 20.} 16 | -------------------------------------------------------------------------------- /configs/simulate/naphthalene.yml: -------------------------------------------------------------------------------- 1 | model_dir: MODELPATH 2 | dataset_src: DATAPATH/md17/naphthalene/10k/test 3 | dp_data_path: DATAPATH/md17/naphthalene/10k/DP/test 4 | nequip_data_config: configs/simulate/nequip_data_cfg/naphthalene.yml 5 | identifier: 300ps 6 | sim_type: ocp 7 | seed: 123 8 | save_freq: 100 9 | steps: 600000 10 | max_test_points: 10000 11 | 12 | kcal: False 13 | T_init: 500. 
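# Schedule sanity check for the values above: 600000 steps x 0.5 fs/step
# = 300000 fs = 300 ps of simulated time (hence identifier: 300ps), with a
# frame saved every save_freq = 100 steps, i.e. 600000 / 100 = 6000 frames.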
14 | integrator: NoseHoover 15 | integrator_config: {"timestep": 0.5, "temperature": 500., "ttime": 20.} -------------------------------------------------------------------------------- /configs/simulate/nequip_data_cfg/ala.yml: -------------------------------------------------------------------------------- 1 | root: MODELPATH/nequip 2 | dataset: npz # type of data set, can be npz or ase 3 | dataset_file_name: DATAPATH/ala/1k/test/nequip_npz.npz 4 | n_train: 10000 5 | 6 | key_mapping: 7 | atomic_number: atomic_numbers # atomic species, integers 8 | energy: total_energy # total potential energies to train to 9 | force: forces # atomic forces to train to 10 | pos: pos 11 | lattices: cell 12 | pbc: pbc # periodic boundary conditions 13 | npz_fixed_field_keys: # fields that are repeated across different examples 14 | - atomic_numbers 15 | - pbc 16 | 17 | chemical_symbols: 18 | - H 19 | - C 20 | - N 21 | - O 22 | 23 | global_rescale_scale_trainable: false 24 | per_species_rescale_trainable: true 25 | per_species_rescale_shifts: dataset_per_atom_total_energy_mean 26 | per_species_rescale_scales: dataset_forces_rms 27 | -------------------------------------------------------------------------------- /configs/simulate/nequip_data_cfg/aspirin.yml: -------------------------------------------------------------------------------- 1 | root: MODELPATH/nequip 2 | dataset: npz 3 | dataset_file_name: DATAPATH/md17/aspirin/10k/test/nequip_npz.npz 4 | n_train: 10000 5 | 6 | key_mapping: 7 | z: atomic_numbers 8 | E: total_energy 9 | F: forces 10 | R: pos 11 | 12 | npz_fixed_field_keys: 13 | - atomic_numbers 14 | 15 | chemical_symbols: 16 | - H 17 | - C 18 | - O 19 | 20 | per_species_rescale_shifts_trainable: false 21 | per_species_rescale_scales_trainable: false 22 | per_species_rescale_shifts: dataset_per_atom_total_energy_mean 23 | per_species_rescale_scales: dataset_forces_rms -------------------------------------------------------------------------------- /configs/simulate/nequip_data_cfg/ethanol.yml: -------------------------------------------------------------------------------- 1 | root: MODELPATH/nequip 2 | dataset: npz 3 | dataset_file_name: DATAPATH/md17/ethanol/10k/test/nequip_npz.npz 4 | n_train: 10000 5 | 6 | key_mapping: 7 | z: atomic_numbers 8 | E: total_energy 9 | F: forces 10 | R: pos 11 | 12 | npz_fixed_field_keys: 13 | - atomic_numbers 14 | 15 | chemical_symbols: 16 | - H 17 | - C 18 | - O 19 | 20 | per_species_rescale_shifts_trainable: false 21 | per_species_rescale_scales_trainable: false 22 | per_species_rescale_shifts: dataset_per_atom_total_energy_mean 23 | per_species_rescale_scales: dataset_forces_rms -------------------------------------------------------------------------------- /configs/simulate/nequip_data_cfg/lips.yml: -------------------------------------------------------------------------------- 1 | root: MODELPATH/nequip 2 | dataset: npz # type of data set, can be npz or ase 3 | dataset_file_name: DATAPATH/lips/20k/val/nequip_npz.npz 4 | n_train: 1000 5 | 6 | key_mapping: 7 | atomic_numbers: atomic_numbers # atomic species, integers 8 | energy: total_energy # total potential energies to train to 9 | forces: forces # atomic forces to train to 10 | pos: pos 11 | lattices: cell 12 | pbc: pbc # periodic boundary conditions 13 | 14 | npz_fixed_field_keys: # fields that are repeated across different examples 15 | - atomic_numbers 16 | - pbc 17 | - cell 18 | 19 | chemical_symbols: 20 | - Li 21 | - P 22 | - S 23 | 24 | global_rescale_scale_trainable: false 25 | per_species_rescale_trainable: true
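# key_mapping above renames the arrays stored in nequip_npz.npz to NequIP's
# canonical field names (left: key in the .npz file, right: NequIP name).
# A minimal inspection sketch, assuming only numpy (the array names follow
# the left-hand keys of key_mapping):
#
#     import numpy as np
#     data = np.load('DATAPATH/lips/20k/val/nequip_npz.npz')
#     print(data['energy'].shape, data['forces'].shape, data['pos'].shape)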
26 | per_species_rescale_shifts: dataset_per_atom_total_energy_mean 27 | per_species_rescale_scales: dataset_forces_rms -------------------------------------------------------------------------------- /configs/simulate/nequip_data_cfg/naphthalene.yml: -------------------------------------------------------------------------------- 1 | root: MODELPATH/nequip 2 | dataset: npz 3 | dataset_file_name: DATAPATH/md17/naphthalene/10k/test/nequip_npz.npz 4 | n_train: 10000 5 | 6 | key_mapping: 7 | z: atomic_numbers 8 | E: total_energy 9 | F: forces 10 | R: pos 11 | 12 | npz_fixed_field_keys: 13 | - atomic_numbers 14 | 15 | chemical_symbols: 16 | - H 17 | - C 18 | 19 | per_species_rescale_shifts_trainable: false 20 | per_species_rescale_scales_trainable: false 21 | per_species_rescale_shifts: dataset_per_atom_total_energy_mean 22 | per_species_rescale_scales: dataset_forces_rms -------------------------------------------------------------------------------- /configs/simulate/nequip_data_cfg/salicylic_acid.yml: -------------------------------------------------------------------------------- 1 | root: MODELPATH/nequip 2 | dataset: npz 3 | dataset_file_name: DATAPATH/md17/salicylic_acid/10k/test/nequip_npz.npz 4 | n_train: 10000 5 | 6 | key_mapping: 7 | z: atomic_numbers 8 | E: total_energy 9 | F: forces 10 | R: pos 11 | 12 | npz_fixed_field_keys: 13 | - atomic_numbers 14 | 15 | chemical_symbols: 16 | - H 17 | - C 18 | - O 19 | 20 | per_species_rescale_shifts_trainable: false 21 | per_species_rescale_scales_trainable: false 22 | per_species_rescale_shifts: dataset_per_atom_total_energy_mean 23 | per_species_rescale_scales: dataset_forces_rms -------------------------------------------------------------------------------- /configs/simulate/nequip_data_cfg/water.yml: -------------------------------------------------------------------------------- 1 | root: MODELPATH/nequip 2 | dataset: npz # type of data set, can be npz or ase 3 | dataset_file_name: DATAPATH/water/90k/test/nequip_npz.npz 4 | n_train: 10000 5 | 6 | key_mapping: 7 | atom_types: atomic_numbers # atomic species, integers 8 | energy: total_energy # total potential energies to train to 9 | forces: forces # atomic forces to train to 10 | wrapped_coords: pos 11 | lattices: cell 12 | pbc: pbc # periodic boundary conditions 13 | npz_fixed_field_keys: # fields that are repeated across different examples 14 | - atomic_numbers 15 | - pbc 16 | 17 | chemical_symbols: 18 | - H 19 | - O 20 | 21 | global_rescale_scale_trainable: false 22 | per_species_rescale_trainable: true 23 | per_species_rescale_shifts: dataset_per_atom_total_energy_mean 24 | per_species_rescale_scales: dataset_forces_rms 25 | -------------------------------------------------------------------------------- /configs/simulate/salicylic_acid.yml: -------------------------------------------------------------------------------- 1 | model_dir: MODELPATH 2 | dataset_src: DATAPATH/md17/salicylic_acid/10k/test 3 | dp_data_path: DATAPATH/md17/salicylic_acid/10k/DP/test 4 | nequip_data_config: configs/simulate/nequip_data_cfg/salicylic_acid.yml 5 | identifier: 300ps 6 | sim_type: ocp 7 | seed: 123 8 | save_freq: 100 9 | steps: 600000 10 | max_test_points: 10000 11 | 12 | kcal: False 13 | T_init: 500.
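# Runs driven by configs like this one write their frames to an atoms.traj
# file (see the bundled example_sim directories). A minimal analysis sketch,
# assuming ASE is installed (the path is just one of the shipped examples):
#
#     from ase.io import Trajectory
#     traj = Trajectory('example_sim/water-1k_schnet/atoms.traj')
#     print(len(traj), traj[0].get_temperature())  # frame count, instantaneous T in K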
14 | integrator: NoseHoover 15 | integrator_config: {"timestep": 0.5, "temperature": 500., "ttime": 20.} -------------------------------------------------------------------------------- /configs/simulate/water.yml: -------------------------------------------------------------------------------- 1 | model_dir: MODELPATH 2 | dataset_src: DATAPATH/water/10k/test 3 | nequip_data_config: configs/simulate/nequip_data_cfg/water.yml 4 | dp_data_path: DATAPATH/water_dp/test 5 | identifier: 500ps 6 | sim_type: ocp 7 | seed: 0 8 | save_freq: 100 9 | steps: 500000 10 | max_test_points: 10000 11 | 12 | kcal: False 13 | T_init: 300. 14 | integrator: NoseHoover 15 | integrator_config: {"timestep": 1., "temperature": 300., "ttime": 20.} 16 | -------------------------------------------------------------------------------- /configs/water/base.yml: -------------------------------------------------------------------------------- 1 | trainer: trainer 2 | 3 | dataset: 4 | src: DATAPATH/water 5 | name: water 6 | size: 10k 7 | normalize_labels: True 8 | 9 | logger: 10 | name: wandb 11 | project: mdbench 12 | 13 | task: 14 | dataset: lmdb 15 | description: "Regressing to energies and forces" 16 | type: regression 17 | metric: mae 18 | labels: 19 | - potential energy 20 | grad_input: atomic forces 21 | train_on_free_atoms: True 22 | eval_on_free_atoms: True 23 | 24 | optim: 25 | batch_size: 1 26 | eval_batch_size: 1 27 | num_workers: 4 28 | lr_initial: 0.001 29 | optimizer: Adam 30 | optimizer_params: {"amsgrad": True} 31 | 32 | scheduler: ReduceLROnPlateau 33 | patience: 5 34 | factor: 0.8 35 | min_lr: 0.000001 36 | 37 | max_epochs: 2000 38 | force_coefficient: 100 39 | energy_coefficient: 1 40 | ema_decay: 0.999 41 | clip_grad_norm: 10 42 | 43 | early_stopping_time: 604800 44 | early_stopping_lr: 0.000001 -------------------------------------------------------------------------------- /configs/water/dpp.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/water/base.yml 3 | 4 | model: 5 | name: dimenetplusplus 6 | hidden_channels: 128 7 | out_emb_channels: 256 8 | int_emb_size: 64 9 | basis_emb_size: 8 10 | num_blocks: 4 11 | cutoff: 5.0 12 | envelope_exponent: 5 13 | num_radial: 6 14 | num_spherical: 7 15 | num_before_skip: 1 16 | num_after_skip: 2 17 | num_output_layers: 3 18 | regress_forces: True 19 | use_pbc: True 20 | otf_graph: True 21 | 22 | optim: 23 | loss_force: l2mae 24 | -------------------------------------------------------------------------------- /configs/water/forcenet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/water/base.yml 3 | 4 | model: 5 | name: forcenet 6 | num_interactions: 5 7 | cutoff: 6 8 | basis: "sphallmul" 9 | ablation: "none" 10 | depth_mlp_edge: 2 11 | depth_mlp_node: 1 12 | activation_str: "swish" 13 | decoder_activation_str: "swish" 14 | feat: "full" 15 | hidden_channels: 128 16 | decoder_hidden_channels: 128 17 | max_n: 3 18 | otf_graph: True 19 | 20 | optim: 21 | energy_coefficient: 0 22 | lr_initial: 0.001 -------------------------------------------------------------------------------- /configs/water/gemnet-T-scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": null, 3 | "TripInteraction_1_had_rbf": 5.5737762451171875, 4 | "TripInteraction_1_sum_cbf": 1.7056342363357544, 5 | "AtomUpdate_1_sum": 0.4618416130542755, 6 | "TripInteraction_2_had_rbf": 4.484927177429199, 7 | 
"TripInteraction_2_sum_cbf": 1.6062310934066772, 8 | "AtomUpdate_2_sum": 0.4112415015697479, 9 | "TripInteraction_3_had_rbf": 4.911035060882568, 10 | "TripInteraction_3_sum_cbf": 1.6301891803741455, 11 | "AtomUpdate_3_sum": 0.3454769253730774, 12 | "TripInteraction_4_had_rbf": 5.337745666503906, 13 | "TripInteraction_4_sum_cbf": 1.6658426523208618, 14 | "AtomUpdate_4_sum": 0.32520246505737305, 15 | "OutBlock_0_sum": 0.4573551118373871, 16 | "OutBlock_1_sum": 0.514782190322876, 17 | "OutBlock_2_sum": 0.4109618067741394, 18 | "OutBlock_3_sum": 0.37208542227745056, 19 | "OutBlock_4_sum": 0.3400648534297943 20 | } -------------------------------------------------------------------------------- /configs/water/gemnet-T.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/water/base.yml 3 | 4 | model: 5 | name: gemnet_t 6 | num_spherical: 7 7 | num_radial: 6 8 | num_blocks: 4 9 | emb_size_atom: 128 10 | emb_size_edge: 128 11 | emb_size_trip: 64 12 | emb_size_rbf: 16 13 | emb_size_cbf: 16 14 | emb_size_bil_trip: 64 15 | num_before_skip: 1 16 | num_after_skip: 1 17 | num_concat: 1 18 | num_atom: 2 19 | cutoff: 5.0 20 | max_neighbors: 50 21 | rbf: 22 | name: gaussian 23 | envelope: 24 | name: polynomial 25 | exponent: 5 26 | cbf: 27 | name: spherical_harmonics 28 | output_init: HeOrthogonal 29 | activation: silu 30 | scale_file: configs/water/gemnet-T-scale.json 31 | extensive: True 32 | otf_graph: True 33 | regress_forces: True 34 | direct_forces: False 35 | 36 | optim: 37 | batch_size: 1 38 | eval_batch_size: 1 39 | num_workers: 4 40 | lr_initial: 0.001 41 | optimizer: AdamW 42 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 43 | 44 | scheduler: ReduceLROnPlateau 45 | patience: 5 46 | factor: 0.8 47 | min_lr: 0.000001 48 | 49 | max_epochs: 2000 50 | force_coefficient: 0.999 51 | energy_coefficient: 0.001 52 | ema_decay: 0.999 53 | clip_grad_norm: 10 54 | loss_force: l2mae 55 | 56 | early_stopping_time: 604800 57 | early_stopping_lr: 0.000001 -------------------------------------------------------------------------------- /configs/water/gemnet-dT-scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": null, 3 | "TripInteraction_1_had_rbf": 5.5737762451171875, 4 | "TripInteraction_1_sum_cbf": 1.7056342363357544, 5 | "AtomUpdate_1_sum": 0.4618416130542755, 6 | "TripInteraction_2_had_rbf": 4.484927177429199, 7 | "TripInteraction_2_sum_cbf": 1.6062310934066772, 8 | "AtomUpdate_2_sum": 0.4112415015697479, 9 | "TripInteraction_3_had_rbf": 4.911035060882568, 10 | "TripInteraction_3_sum_cbf": 1.630189299583435, 11 | "AtomUpdate_3_sum": 0.3454769253730774, 12 | "TripInteraction_4_had_rbf": 5.3377461433410645, 13 | "TripInteraction_4_sum_cbf": 1.6658427715301514, 14 | "AtomUpdate_4_sum": 0.32520249485969543, 15 | "OutBlock_0_sum": 0.4573551118373871, 16 | "OutBlock_0_had": 4.61002254486084, 17 | "OutBlock_1_sum": 0.45900964736938477, 18 | "OutBlock_1_had": 4.356435298919678, 19 | "OutBlock_2_sum": 0.3811725676059723, 20 | "OutBlock_2_had": 4.245356559753418, 21 | "OutBlock_3_sum": 0.37914782762527466, 22 | "OutBlock_3_had": 4.587994575500488, 23 | "OutBlock_4_sum": 0.3253774642944336, 24 | "OutBlock_4_had": 5.189902305603027 25 | } -------------------------------------------------------------------------------- /configs/water/gemnet-dT.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/water/base.yml 
3 | 4 | identifier: dT 5 | 6 | model: 7 | name: gemnet_t 8 | num_spherical: 7 9 | num_radial: 6 10 | num_blocks: 4 11 | emb_size_atom: 128 12 | emb_size_edge: 128 13 | emb_size_trip: 64 14 | emb_size_rbf: 16 15 | emb_size_cbf: 16 16 | emb_size_bil_trip: 64 17 | num_before_skip: 1 18 | num_after_skip: 1 19 | num_concat: 1 20 | num_atom: 2 21 | cutoff: 5.0 22 | max_neighbors: 50 23 | rbf: 24 | name: gaussian 25 | envelope: 26 | name: polynomial 27 | exponent: 5 28 | cbf: 29 | name: spherical_harmonics 30 | output_init: HeOrthogonal 31 | activation: silu 32 | scale_file: configs/water/gemnet-dT-scale.json 33 | extensive: True 34 | otf_graph: True 35 | regress_forces: True 36 | direct_forces: True 37 | 38 | optim: 39 | batch_size: 1 40 | eval_batch_size: 1 41 | num_workers: 4 42 | lr_initial: 0.001 43 | optimizer: AdamW 44 | optimizer_params: {"eps": 1.e-7, "weight_decay": 0.000002, "amsgrad": True} 45 | 46 | scheduler: ReduceLROnPlateau 47 | patience: 5 48 | factor: 0.8 49 | min_lr: 0.000001 50 | 51 | max_epochs: 2000 52 | force_coefficient: 0.999 53 | energy_coefficient: 0.001 54 | ema_decay: 0.999 55 | clip_grad_norm: 10 56 | loss_force: l2mae 57 | 58 | early_stopping_time: 604800 59 | early_stopping_lr: 0.000001 -------------------------------------------------------------------------------- /configs/water/schnet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/water/base.yml 3 | 4 | model: 5 | name: schnet 6 | hidden_channels: 64 7 | num_filters: 64 8 | num_interactions: 6 9 | num_gaussians: 25 10 | cutoff: 6.0 11 | use_pbc: True 12 | otf_graph: True -------------------------------------------------------------------------------- /deeppot_se/ala40k/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment": " model parameters", 3 | "model": { 4 | "type_map": [ 5 | "H", 6 | "C", 7 | "N", 8 | "O" 9 | ], 10 | "descriptor": { 11 | "type": "se_e2_a", 12 | "sel": [ 13 | 40, 14 | 40, 15 | 40, 16 | 40 17 | ], 18 | "rcut_smth": 1.00, 19 | "rcut": 6.00, 20 | "neuron": [ 21 | 25, 22 | 50, 23 | 100 24 | ], 25 | "resnet_dt": false, 26 | "axis_neuron": 12, 27 | "seed": 1, 28 | "_comment": " that's all" 29 | }, 30 | "fitting_net": { 31 | "neuron": [ 32 | 240, 33 | 120, 34 | 60, 35 | 30, 36 | 10 37 | ], 38 | "resnet_dt": true, 39 | "seed": 1, 40 | "_comment": " that's all" 41 | }, 42 | "_comment": " that's all" 43 | }, 44 | "learning_rate": { 45 | "type": "exp", 46 | "decay_steps": 20000, 47 | "start_lr": 0.001, 48 | "stop_lr": 3.51e-8, 49 | "_comment": "that's all" 50 | }, 51 | "loss": { 52 | "type": "ener", 53 | "start_pref_e": 0, 54 | "limit_pref_e": 0, 55 | "start_pref_f": 1000, 56 | "limit_pref_f": 1, 57 | "start_pref_v": 0, 58 | "limit_pref_v": 0, 59 | "_comment": " that's all" 60 | }, 61 | "training": { 62 | "training_data": { 63 | "systems": "../../DATAPATH/ala/40k/DP/train", 64 | "batch_size": 4, 65 | "_comment": "that's all" 66 | }, 67 | "validation_data": { 68 | "systems": [ 69 | "../../DATAPATH/ala/40k/DP/val" 70 | ], 71 | "batch_size": 4, 72 | "numb_btch": 5, 73 | "_comment": "that's all" 74 | }, 75 | "numb_steps": 4000000, 76 | "seed": 10, 77 | "disp_file": "lcurve.out", 78 | "disp_freq": 2000, 79 | "save_freq": 1000000, 80 | "_comment": "that's all" 81 | }, 82 | "_comment": "that's all" 83 | } -------------------------------------------------------------------------------- /deeppot_se/lips20k/input.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "_comment": " model parameters", 3 | "model": { 4 | "type_map": [ 5 | "Li", 6 | "P", 7 | "S" 8 | ], 9 | "descriptor": { 10 | "type": "se_e2_a", 11 | "sel": [ 12 | 40, 13 | 40, 14 | 40 15 | ], 16 | "rcut_smth": 1.00, 17 | "rcut": 6.00, 18 | "neuron": [ 19 | 25, 20 | 50, 21 | 100 22 | ], 23 | "resnet_dt": false, 24 | "axis_neuron": 12, 25 | "seed": 1, 26 | "_comment": " that's all" 27 | }, 28 | "fitting_net": { 29 | "neuron": [ 30 | 240, 31 | 120, 32 | 60, 33 | 30, 34 | 10 35 | ], 36 | "resnet_dt": true, 37 | "seed": 1, 38 | "_comment": " that's all" 39 | }, 40 | "_comment": " that's all" 41 | }, 42 | "learning_rate": { 43 | "type": "exp", 44 | "decay_steps": 20000, 45 | "start_lr": 0.001, 46 | "stop_lr": 3.51e-8, 47 | "_comment": "that's all" 48 | }, 49 | "loss": { 50 | "type": "ener", 51 | "start_pref_e": 0, 52 | "limit_pref_e": 0, 53 | "start_pref_f": 1000, 54 | "limit_pref_f": 1, 55 | "start_pref_v": 0, 56 | "limit_pref_v": 0, 57 | "_comment": " that's all" 58 | }, 59 | "training": { 60 | "training_data": { 61 | "systems": "../../DATAPATH/lips/2k/DP/train", 62 | "batch_size": 4, 63 | "_comment": "that's all" 64 | }, 65 | "validation_data": { 66 | "systems": [ 67 | "../../DATAPATH/lips/2k/DP/test" 68 | ], 69 | "batch_size": 4, 70 | "numb_btch": 5, 71 | "_comment": "that's all" 72 | }, 73 | "numb_steps": 4000000, 74 | "seed": 10, 75 | "disp_file": "lcurve.out", 76 | "disp_freq": 2000, 77 | "save_freq": 1000000, 78 | "_comment": "that's all" 79 | }, 80 | "_comment": "that's all" 81 | } -------------------------------------------------------------------------------- /deeppot_se/md17_aspirin_10k/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment": " model parameters", 3 | "model": { 4 | "type_map": [ 5 | "H", 6 | "C", 7 | "O" 8 | ], 9 | "descriptor": { 10 | "type": "se_e2_a", 11 | "sel": [ 12 | 40, 13 | 40, 14 | 40 15 | ], 16 | "rcut_smth": 1.00, 17 | "rcut": 6.00, 18 | "neuron": [ 19 | 25, 20 | 50, 21 | 100 22 | ], 23 | "resnet_dt": false, 24 | "axis_neuron": 12, 25 | "seed": 1, 26 | "_comment": " that's all" 27 | }, 28 | "fitting_net": { 29 | "neuron": [ 30 | 240, 31 | 120, 32 | 60, 33 | 30, 34 | 10 35 | ], 36 | "resnet_dt": true, 37 | "seed": 1, 38 | "_comment": " that's all" 39 | }, 40 | "_comment": " that's all" 41 | }, 42 | "learning_rate": { 43 | "type": "exp", 44 | "decay_steps": 20000, 45 | "start_lr": 0.001, 46 | "stop_lr": 3.51e-8, 47 | "_comment": "that's all" 48 | }, 49 | "loss": { 50 | "type": "ener", 51 | "start_pref_e": 1, 52 | "limit_pref_e": 400, 53 | "start_pref_f": 1000, 54 | "limit_pref_f": 1, 55 | "start_pref_v": 0, 56 | "limit_pref_v": 0, 57 | "_comment": " that's all" 58 | }, 59 | "training": { 60 | "training_data": { 61 | "systems": "../../DATAPATH/md17/aspirin/10k/DP/train", 62 | "batch_size": 4, 63 | "_comment": "that's all" 64 | }, 65 | "validation_data": { 66 | "systems": [ 67 | "../../DATAPATH/md17/aspirin/10k/DP/val" 68 | ], 69 | "batch_size": 4, 70 | "numb_btch": 12, 71 | "_comment": "that's all" 72 | }, 73 | "numb_steps": 4000000, 74 | "seed": 10, 75 | "disp_file": "lcurve.out", 76 | "disp_freq": 2000, 77 | "save_freq": 1000000, 78 | "_comment": "that's all" 79 | }, 80 | "_comment": "that's all" 81 | } 82 | -------------------------------------------------------------------------------- /deeppot_se/md17_ehtanol_10k/input.json: -------------------------------------------------------------------------------- 1 | 
{ 2 | "_comment": " model parameters", 3 | "model": { 4 | "type_map": [ 5 | "H", 6 | "C", 7 | "O" 8 | ], 9 | "descriptor": { 10 | "type": "se_e2_a", 11 | "sel": [ 12 | 40, 13 | 40, 14 | 40 15 | ], 16 | "rcut_smth": 1.00, 17 | "rcut": 6.00, 18 | "neuron": [ 19 | 25, 20 | 50, 21 | 100 22 | ], 23 | "resnet_dt": false, 24 | "axis_neuron": 12, 25 | "seed": 1, 26 | "_comment": " that's all" 27 | }, 28 | "fitting_net": { 29 | "neuron": [ 30 | 240, 31 | 120, 32 | 60, 33 | 30, 34 | 10 35 | ], 36 | "resnet_dt": true, 37 | "seed": 1, 38 | "_comment": " that's all" 39 | }, 40 | "_comment": " that's all" 41 | }, 42 | "learning_rate": { 43 | "type": "exp", 44 | "decay_steps": 20000, 45 | "start_lr": 0.001, 46 | "stop_lr": 3.51e-8, 47 | "_comment": "that's all" 48 | }, 49 | "loss": { 50 | "type": "ener", 51 | "start_pref_e": 1, 52 | "limit_pref_e": 400, 53 | "start_pref_f": 1000, 54 | "limit_pref_f": 1, 55 | "start_pref_v": 0, 56 | "limit_pref_v": 0, 57 | "_comment": " that's all" 58 | }, 59 | "training": { 60 | "training_data": { 61 | "systems": "../../DATAPATH/md17/ethanol/10k/DP/train", 62 | "batch_size": 4, 63 | "_comment": "that's all" 64 | }, 65 | "validation_data": { 66 | "systems": [ 67 | "../../DATAPATH/md17/ethanol/10k/DP/val" 68 | ], 69 | "batch_size": 4, 70 | "numb_btch": 12, 71 | "_comment": "that's all" 72 | }, 73 | "numb_steps": 4000000, 74 | "seed": 10, 75 | "disp_file": "lcurve.out", 76 | "disp_freq": 2000, 77 | "save_freq": 1000000, 78 | "_comment": "that's all" 79 | }, 80 | "_comment": "that's all" 81 | } 82 | -------------------------------------------------------------------------------- /deeppot_se/md17_naphthalene_10k/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment": " model parameters", 3 | "model": { 4 | "type_map": [ 5 | "H", 6 | "C" 7 | ], 8 | "descriptor": { 9 | "type": "se_e2_a", 10 | "sel": [ 11 | 40, 12 | 40 13 | ], 14 | "rcut_smth": 1.00, 15 | "rcut": 6.00, 16 | "neuron": [ 17 | 25, 18 | 50, 19 | 100 20 | ], 21 | "resnet_dt": false, 22 | "axis_neuron": 12, 23 | "seed": 1, 24 | "_comment": " that's all" 25 | }, 26 | "fitting_net": { 27 | "neuron": [ 28 | 240, 29 | 120, 30 | 60, 31 | 30, 32 | 10 33 | ], 34 | "resnet_dt": true, 35 | "seed": 1, 36 | "_comment": " that's all" 37 | }, 38 | "_comment": " that's all" 39 | }, 40 | "learning_rate": { 41 | "type": "exp", 42 | "decay_steps": 20000, 43 | "start_lr": 0.001, 44 | "stop_lr": 3.51e-8, 45 | "_comment": "that's all" 46 | }, 47 | "loss": { 48 | "type": "ener", 49 | "start_pref_e": 1, 50 | "limit_pref_e": 400, 51 | "start_pref_f": 1000, 52 | "limit_pref_f": 1, 53 | "start_pref_v": 0, 54 | "limit_pref_v": 0, 55 | "_comment": " that's all" 56 | }, 57 | "training": { 58 | "training_data": { 59 | "systems": "../../DATAPATH/md17/naphthalene/10k/DP/train", 60 | "batch_size": 4, 61 | "_comment": "that's all" 62 | }, 63 | "validation_data": { 64 | "systems": [ 65 | "../../DATAPATH/md17/naphthalene/10k/DP/val" 66 | ], 67 | "batch_size": 4, 68 | "numb_btch": 12, 69 | "_comment": "that's all" 70 | }, 71 | "numb_steps": 4000000, 72 | "seed": 10, 73 | "disp_file": "lcurve.out", 74 | "disp_freq": 2000, 75 | "save_freq": 1000000, 76 | "_comment": "that's all" 77 | }, 78 | "_comment": "that's all" 79 | } 80 | -------------------------------------------------------------------------------- /deeppot_se/md17_salicylic_acid_10k/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment": " model parameters", 3 | "model": { 4 | 
"type_map": [ 5 | "H", 6 | "C", 7 | "O" 8 | ], 9 | "descriptor": { 10 | "type": "se_e2_a", 11 | "sel": [ 12 | 40, 13 | 40, 14 | 40 15 | ], 16 | "rcut_smth": 1.00, 17 | "rcut": 6.00, 18 | "neuron": [ 19 | 25, 20 | 50, 21 | 100 22 | ], 23 | "resnet_dt": false, 24 | "axis_neuron": 12, 25 | "seed": 1, 26 | "_comment": " that's all" 27 | }, 28 | "fitting_net": { 29 | "neuron": [ 30 | 240, 31 | 120, 32 | 60, 33 | 30, 34 | 10 35 | ], 36 | "resnet_dt": true, 37 | "seed": 1, 38 | "_comment": " that's all" 39 | }, 40 | "_comment": " that's all" 41 | }, 42 | "learning_rate": { 43 | "type": "exp", 44 | "decay_steps": 20000, 45 | "start_lr": 0.001, 46 | "stop_lr": 3.51e-8, 47 | "_comment": "that's all" 48 | }, 49 | "loss": { 50 | "type": "ener", 51 | "start_pref_e": 1, 52 | "limit_pref_e": 400, 53 | "start_pref_f": 1000, 54 | "limit_pref_f": 1, 55 | "start_pref_v": 0, 56 | "limit_pref_v": 0, 57 | "_comment": " that's all" 58 | }, 59 | "training": { 60 | "training_data": { 61 | "systems": "../../DATAPATH/md17/salicylic_acid/10k/DP/train", 62 | "batch_size": 4, 63 | "_comment": "that's all" 64 | }, 65 | "validation_data": { 66 | "systems": [ 67 | "../../DATAPATH/md17/salicylic_acid/10k/DP/val" 68 | ], 69 | "batch_size": 4, 70 | "numb_btch": 12, 71 | "_comment": "that's all" 72 | }, 73 | "numb_steps": 4000000, 74 | "seed": 10, 75 | "disp_file": "lcurve.out", 76 | "disp_freq": 2000, 77 | "save_freq": 1000000, 78 | "_comment": "that's all" 79 | }, 80 | "_comment": "that's all" 81 | } 82 | -------------------------------------------------------------------------------- /deeppot_se/water10k/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment": " model parameters", 3 | "model": { 4 | "type_map": [ 5 | "O", 6 | "H" 7 | ], 8 | "descriptor": { 9 | "type": "se_e2_a", 10 | "sel": [ 11 | 46, 12 | 92 13 | ], 14 | "rcut_smth": 0.50, 15 | "rcut": 6.00, 16 | "neuron": [ 17 | 25, 18 | 50, 19 | 100 20 | ], 21 | "resnet_dt": false, 22 | "axis_neuron": 16, 23 | "seed": 1, 24 | "_comment": " that's all" 25 | }, 26 | "fitting_net": { 27 | "neuron": [ 28 | 240, 29 | 120, 30 | 60, 31 | 30, 32 | 10 33 | ], 34 | "resnet_dt": true, 35 | "seed": 1, 36 | "_comment": " that's all" 37 | }, 38 | "_comment": " that's all" 39 | }, 40 | "learning_rate": { 41 | "type": "exp", 42 | "decay_steps": 20000, 43 | "start_lr": 0.001, 44 | "stop_lr": 3.51e-8, 45 | "_comment": "that's all" 46 | }, 47 | "loss": { 48 | "type": "ener", 49 | "start_pref_e": 1, 50 | "limit_pref_e": 400, 51 | "start_pref_f": 1000, 52 | "limit_pref_f": 1, 53 | "start_pref_v": 0, 54 | "limit_pref_v": 0, 55 | "_comment": " that's all" 56 | }, 57 | "training": { 58 | "training_data": { 59 | "systems": [ 60 | "../../DATAPATH/water/10k/DP/train" 61 | ], 62 | "batch_size": 4, 63 | "_comment": "that's all" 64 | }, 65 | "validation_data": { 66 | "systems": [ 67 | "../../DATAPATH/water/10k/DP/val" 68 | ], 69 | "batch_size": 4, 70 | "numb_btch": 5, 71 | "_comment": "that's all" 72 | }, 73 | "numb_steps": 4000000, 74 | "seed": 10, 75 | "disp_file": "lcurve.out", 76 | "disp_freq": 2000, 77 | "save_freq": 1000000, 78 | "_comment": "that's all" 79 | }, 80 | "_comment": "that's all" 81 | } -------------------------------------------------------------------------------- /deeppot_se/water1k/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment": " model parameters", 3 | "model": { 4 | "type_map": [ 5 | "O", 6 | "H" 7 | ], 8 | "descriptor": { 9 | "type": "se_e2_a", 10 | 
"sel": [ 11 | 46, 12 | 92 13 | ], 14 | "rcut_smth": 0.50, 15 | "rcut": 6.00, 16 | "neuron": [ 17 | 25, 18 | 50, 19 | 100 20 | ], 21 | "resnet_dt": false, 22 | "axis_neuron": 16, 23 | "seed": 1, 24 | "_comment": " that's all" 25 | }, 26 | "fitting_net": { 27 | "neuron": [ 28 | 240, 29 | 120, 30 | 60, 31 | 30, 32 | 10 33 | ], 34 | "resnet_dt": true, 35 | "seed": 1, 36 | "_comment": " that's all" 37 | }, 38 | "_comment": " that's all" 39 | }, 40 | "learning_rate": { 41 | "type": "exp", 42 | "decay_steps": 20000, 43 | "start_lr": 0.001, 44 | "stop_lr": 3.51e-8, 45 | "_comment": "that's all" 46 | }, 47 | "loss": { 48 | "type": "ener", 49 | "start_pref_e": 1, 50 | "limit_pref_e": 400, 51 | "start_pref_f": 1000, 52 | "limit_pref_f": 1, 53 | "start_pref_v": 0, 54 | "limit_pref_v": 0, 55 | "_comment": " that's all" 56 | }, 57 | "training": { 58 | "training_data": { 59 | "systems": [ 60 | "../../DATAPATH/water/1k/DP/train" 61 | ], 62 | "batch_size": 4, 63 | "_comment": "that's all" 64 | }, 65 | "validation_data": { 66 | "systems": [ 67 | "../../DATAPATH/water/1k/DP/val" 68 | ], 69 | "batch_size": 4, 70 | "numb_btch": 5, 71 | "_comment": "that's all" 72 | }, 73 | "numb_steps": 4000000, 74 | "seed": 10, 75 | "disp_file": "lcurve.out", 76 | "disp_freq": 2000, 77 | "save_freq": 1000000, 78 | "_comment": "that's all" 79 | }, 80 | "_comment": "that's all" 81 | } -------------------------------------------------------------------------------- /deeppot_se/water90k/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment": " model parameters", 3 | "model": { 4 | "type_map": [ 5 | "O", 6 | "H" 7 | ], 8 | "descriptor": { 9 | "type": "se_e2_a", 10 | "sel": [ 11 | 46, 12 | 92 13 | ], 14 | "rcut_smth": 0.50, 15 | "rcut": 6.00, 16 | "neuron": [ 17 | 25, 18 | 50, 19 | 100 20 | ], 21 | "resnet_dt": false, 22 | "axis_neuron": 16, 23 | "seed": 1, 24 | "_comment": " that's all" 25 | }, 26 | "fitting_net": { 27 | "neuron": [ 28 | 240, 29 | 120, 30 | 60, 31 | 30, 32 | 10 33 | ], 34 | "resnet_dt": true, 35 | "seed": 1, 36 | "_comment": " that's all" 37 | }, 38 | "_comment": " that's all" 39 | }, 40 | "learning_rate": { 41 | "type": "exp", 42 | "decay_steps": 20000, 43 | "start_lr": 0.001, 44 | "stop_lr": 3.51e-8, 45 | "_comment": "that's all" 46 | }, 47 | "loss": { 48 | "type": "ener", 49 | "start_pref_e": 1, 50 | "limit_pref_e": 400, 51 | "start_pref_f": 1000, 52 | "limit_pref_f": 1, 53 | "start_pref_v": 0, 54 | "limit_pref_v": 0, 55 | "_comment": " that's all" 56 | }, 57 | "training": { 58 | "training_data": { 59 | "systems": [ 60 | "../../DATAPATH/water/90k/DP/train" 61 | ], 62 | "batch_size": 4, 63 | "_comment": "that's all" 64 | }, 65 | "validation_data": { 66 | "systems": [ 67 | "../../DATAPATH/water/90k/DP/val" 68 | ], 69 | "batch_size": 4, 70 | "numb_btch": 5, 71 | "_comment": "that's all" 72 | }, 73 | "numb_steps": 4000000, 74 | "seed": 10, 75 | "disp_file": "lcurve.out", 76 | "disp_freq": 2000, 77 | "save_freq": 1000000, 78 | "_comment": "that's all" 79 | }, 80 | "_comment": "that's all" 81 | } -------------------------------------------------------------------------------- /example_model/water_1k_schnet/checkpoints/best_checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/example_model/water_1k_schnet/checkpoints/best_checkpoint.pt -------------------------------------------------------------------------------- 
/example_model/water_1k_schnet/checkpoints/checkpoint.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/example_model/water_1k_schnet/checkpoints/checkpoint.pt -------------------------------------------------------------------------------- /example_model/water_1k_schnet/checkpoints/config.yml: -------------------------------------------------------------------------------- 1 | amp: false 2 | checkpoint: null 3 | cpu: false 4 | dataset: 5 | - grad_target_mean: -3.5600233069761502e-09 6 | grad_target_std: 1.25220260734245 7 | name: water 8 | normalize_labels: true 9 | size: 1k 10 | src: DATAPATH/water/1k/train 11 | target_mean: -26.147620446227215 12 | target_std: 0.6255959083458403 13 | - src: DATAPATH/water/1k/val 14 | distributed: false 15 | distributed_backend: nccl 16 | distributed_port: 13356 17 | identifier: '' 18 | is_debug: false 19 | local_rank: 0 20 | logger: 21 | name: wandb 22 | project: mdbench 23 | mode: train 24 | model: 25 | cutoff: 6.0 26 | hidden_channels: 64 27 | name: schnet 28 | num_filters: 64 29 | num_gaussians: 25 30 | num_interactions: 6 31 | otf_graph: true 32 | use_pbc: true 33 | noddp: false 34 | optim: 35 | batch_size: 1 36 | clip_grad_norm: 10 37 | early_stopping_lr: 1.0e-06 38 | early_stopping_time: 604800 39 | ema_decay: 0.999 40 | energy_coefficient: 1 41 | eval_batch_size: 1 42 | factor: 0.8 43 | force_coefficient: 100 44 | lr_initial: 0.001 45 | max_epochs: 10000 46 | min_lr: 1.0e-06 47 | num_workers: 4 48 | optimizer: Adam 49 | optimizer_params: 50 | amsgrad: true 51 | patience: 50 52 | scheduler: ReduceLROnPlateau 53 | print_every: 200 54 | run_dir: example_model/water_1k_schnet 55 | seed: 0 56 | submit: false 57 | summit: false 58 | task: 59 | dataset: lmdb 60 | description: Regressing to energies and forces 61 | eval_on_free_atoms: true 62 | grad_input: atomic forces 63 | labels: 64 | - potential energy 65 | metric: mae 66 | train_on_free_atoms: true 67 | type: regression 68 | timestamp_id: null 69 | trainer: trainer 70 | world_size: 1 71 | -------------------------------------------------------------------------------- /example_sim/ala_nequip/atoms.traj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/example_sim/ala_nequip/atoms.traj -------------------------------------------------------------------------------- /example_sim/ala_nequip/test_metric.json: -------------------------------------------------------------------------------- 1 | {"f_mae": 0.215261, "f_rmse": 0.344438, "e_mae": 13.130361, "e/N_mae": 0.596834, "num_params": 1110168, "running_time": 269723.19397234917} -------------------------------------------------------------------------------- /example_sim/ala_nequip/test_metric.log: -------------------------------------------------------------------------------- 1 | Using device: cuda 2 | WARNING: please note that models running on CUDA are usually nondeterministc and that this manifests in the final test errors; for a _more_ deterministic result, please use `--device cpu` 3 | Loading model... 4 | loaded model from training session 5 | Loading dataset... 6 | Loaded dataset specified in ala.yml. 7 | Using all frames from the specified test dataset, yielding a test set size of 2000 frames. 8 | Starting... 
9 | 10 | --- Final result: --- 11 | f_mae = 0.215261 12 | f_rmse = 0.344438 13 | e_mae = 13.130361 14 | e/N_mae = 0.596834 15 | -------------------------------------------------------------------------------- /example_sim/aspirin_dimenet/atoms.traj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/example_sim/aspirin_dimenet/atoms.traj -------------------------------------------------------------------------------- /example_sim/aspirin_dimenet/test_metric.json: -------------------------------------------------------------------------------- 1 | {"forces_mae": {"total": 6269.489235877991, "numel": 629937, "metric": 0.009952565472226574}, "forces_rmse": {"total": 6269.489235877991, "numel": 629937, "metric": 0.009952565472226574}, "forces_cos": {"total": 209944.96466064453, "numel": 209979, "metric": 0.9998379107465248}, "forces_magnitude": {"total": 2272.355224132538, "numel": 209979, "metric": 0.010821821344670362}, "energy_mae": {"total": 495.93359375, "numel": 9999, "metric": 0.04959831920692069}, "energy_rmse": {"total": 495.93359375, "numel": 9999, "metric": 0.04959831920692069}, "loss": {"total": 6231.116358757019, "numel": 313, "metric": 19.907719996028813}, "num_params": 2100070, "running_time": 5378.366251945496, "early_stop": true} -------------------------------------------------------------------------------- /example_sim/lips_gemnet-t/atoms.traj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/example_sim/lips_gemnet-t/atoms.traj -------------------------------------------------------------------------------- /example_sim/lips_gemnet-t/test_metric.json: -------------------------------------------------------------------------------- 1 | {"forces_mae": {"total": 1647.8261918127537, "numel": 1245000, "metric": 0.00132355517414679}, "forces_rmse": {"total": 1647.8261918127537, "numel": 1245000, "metric": 0.00132355517414679}, "forces_cos": {"total": 414992.70877075195, "numel": 415000, "metric": 0.9999824307728963}, "forces_magnitude": {"total": 559.1305168792605, "numel": 415000, "metric": 0.0013473024503114711}, "energy_mae": {"total": 212.92987060546875, "numel": 5000, "metric": 0.04258597412109375}, "energy_rmse": {"total": 212.92987060546875, "numel": 5000, "metric": 0.04258597412109375}, "loss": {"total": 21.893823865801096, "numel": 5000, "metric": 0.004378764773160219}, "num_params": 1891025, "running_time": 8920.006316900253, "early_stop": false, "simulated_frames": 200000} -------------------------------------------------------------------------------- /example_sim/water-10k_gemnet-t/atoms.traj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/example_sim/water-10k_gemnet-t/atoms.traj -------------------------------------------------------------------------------- /example_sim/water-10k_gemnet-t/test_metric.json: -------------------------------------------------------------------------------- 1 | {"forces_mae": {"total": 4198.861646324396, "numel": 5760000, "metric": 0.0007289690358202077}, "forces_rmse": {"total": 4198.861646324396, "numel": 5760000, "metric": 0.0007289690358202077}, "forces_cos": {"total": 1919997.7127685547, "numel": 1920000, "metric": 0.9999988087336222}, "forces_magnitude": {"total": 
1666.4475756287575, "numel": 1920000, "metric": 0.0008679414456399778}, "energy_mae": {"total": 2499.787738800049, "numel": 10000, "metric": 0.2499787738800049}, "energy_rmse": {"total": 2499.787738800049, "numel": 10000, "metric": 0.2499787738800049}, "loss": {"total": 14.167603051348124, "numel": 10000, "metric": 0.0014167603051348124}, "num_params": 1891025, "running_time": 2074.7184019088745, "early_stop": false} -------------------------------------------------------------------------------- /example_sim/water-1k_schnet/atoms.traj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/example_sim/water-1k_schnet/atoms.traj -------------------------------------------------------------------------------- /example_sim/water-1k_schnet/test_metric.json: -------------------------------------------------------------------------------- 1 | {"forces_mae": {"total": 75394.92027902603, "numel": 5760000, "metric": 0.013089395881775353}, "forces_rmse": {"total": 75394.92027902603, "numel": 5760000, "metric": 0.013089395881775353}, "forces_cos": {"total": 1919179.0262298584, "numel": 1920000, "metric": 0.9995724094947179}, "forces_magnitude": {"total": 28081.919793128967, "numel": 1920000, "metric": 0.01462599989225467}, "energy_mae": {"total": 2221.2859592437744, "numel": 10000, "metric": 0.22212859592437745}, "energy_rmse": {"total": 2221.2859592437744, "numel": 10000, "metric": 0.22212859592437745}, "loss": {"total": 2191.8285777997226, "numel": 10000, "metric": 0.21918285777997226}, "num_params": 117953, "running_time": 1451.8271114826202} -------------------------------------------------------------------------------- /fit_scaling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for calculating the scaling factors used to even out GemNet activation 3 | scales. This generates the `scale_file` specified in the config, which is then 4 | read in at model initialization. 5 | This only needs to be run if the hyperparameters or model change 6 | in places were it would affect the activation scales. 7 | """ 8 | 9 | import logging 10 | import os 11 | import sys 12 | import numpy as np 13 | import torch 14 | from tqdm import trange 15 | 16 | from mdsim.common.flags import flags 17 | from mdsim.common.registry import registry 18 | from mdsim.common.utils import build_config, setup_imports, setup_logging 19 | from mdsim.models.gemnet.layers.scaling import AutomaticFit 20 | from mdsim.models.gemnet.utils import write_json 21 | 22 | if __name__ == "__main__": 23 | setup_logging() 24 | 25 | num_batches = 25 # number of batches to use to fit a single variable 26 | 27 | parser = flags.get_parser() 28 | args, override_args = parser.parse_known_args() 29 | config = build_config(args, override_args) 30 | assert config["model"]["name"].startswith("gemnet") 31 | config["logger"] = "tensorboard" 32 | 33 | if args.distributed: 34 | raise ValueError( 35 | "I don't think this works with DDP (race conditions)." 
36 | ) 37 | 38 | setup_imports() 39 | 40 | scale_file = config["model"]["scale_file"] 41 | 42 | logging.info(f"Run fitting for model: {args.identifier}") 43 | logging.info(f"Target scale file: {scale_file}") 44 | 45 | def initialize_scale_file(scale_file): 46 | # initialize file 47 | preset = {"comment": args.identifier} 48 | write_json(scale_file, preset) 49 | 50 | if os.path.exists(scale_file): 51 | logging.warning(f"Already found existing file: {scale_file}") 52 | flag = input( 53 | "Do you want to continue and overwrite the file (1), " 54 | "only fit the variables not fitted yet (2), or exit (3)? " 55 | ) 56 | if str(flag) == "1": 57 | logging.info("Overwriting the current file.") 58 | initialize_scale_file(scale_file) 59 | elif str(flag) == "2": 60 | logging.info("Only fitting unfitted variables.") 61 | else: 62 | print(flag) 63 | logging.info("Exiting script") 64 | sys.exit() 65 | else: 66 | initialize_scale_file(scale_file) 67 | 68 | AutomaticFit.set2fitmode() 69 | 70 | # compose dataset configs. 71 | train_data_cfg = config['dataset'] 72 | dataset_name = train_data_cfg['name'] 73 | if dataset_name == 'md17': 74 | train_data_cfg['src'] = os.path.join(train_data_cfg['src'], train_data_cfg['molecule']) 75 | train_data_cfg['name'] = 'md17-' + train_data_cfg['molecule'] 76 | src = os.path.join(train_data_cfg['src'], train_data_cfg['size']) 77 | train_data_cfg['src'] = os.path.join(src, 'train') 78 | 79 | norm_stats = np.load(os.path.join(src, 'metadata.npy'), allow_pickle=True).item() 80 | if not train_data_cfg['normalize_labels']: 81 | # the mean of the energy is arbitrary; it should always be subtracted. 82 | # this is done in . 83 | train_data_cfg['target_mean'] = norm_stats['e_mean'] 84 | train_data_cfg['target_std'] = 1. 85 | train_data_cfg['grad_target_mean'] = 0. 86 | train_data_cfg['grad_target_std'] = 1. 87 | train_data_cfg['normalize_labels'] = True 88 | else: 89 | train_data_cfg['target_mean'] = float(norm_stats['e_mean']) 90 | train_data_cfg['target_std'] = float(norm_stats['e_std']) 91 | train_data_cfg['grad_target_mean'] = float(norm_stats['f_mean']) 92 | train_data_cfg['grad_target_std'] = float(norm_stats['f_std']) 93 | # train, val, test 94 | config['dataset'] = [train_data_cfg, 95 | {'src': os.path.join(src, 'val')}, ] 96 | 97 | # initialize trainer.
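    # The full trainer below is constructed only to obtain a model and a
    # val_loader consistent with the training config; no optimizer step is
    # taken here. The fitting loop at the bottom runs forward passes over
    # `num_batches` validation batches so AutomaticFit can fit one scaling
    # variable per pass.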
98 | trainer = registry.get_trainer_class( 99 | config.get("trainer", "energy") 100 | )( 101 | task=config["task"], 102 | model=config["model"], 103 | dataset=config["dataset"], 104 | optimizer=config["optim"], 105 | identifier=config["identifier"], 106 | timestamp_id=config.get("timestamp_id", None), 107 | run_dir=config.get("run_dir", None), 108 | is_debug=config.get("is_debug", False), 109 | print_every=config.get("print_every", 100), 110 | seed=config.get("seed", 0), 111 | logger=config.get("logger", "wandb"), 112 | local_rank=config["local_rank"], 113 | amp=config.get("amp", False), 114 | cpu=config.get("cpu", False), 115 | slurm=config.get("slurm", {}), 116 | no_energy=config.get("no_energy", False) 117 | ) 118 | 119 | # Fitting loop 120 | logging.info("Start fitting") 121 | 122 | if not AutomaticFit.fitting_completed(): 123 | with torch.no_grad(): 124 | trainer.model.eval() 125 | for _ in trange(len(AutomaticFit.queue) + 1): 126 | assert ( 127 | trainer.val_loader is not None 128 | ), "Val dataset is required for making predictions" 129 | 130 | for i, batch in enumerate(trainer.val_loader): 131 | with torch.cuda.amp.autocast( 132 | enabled=trainer.scaler is not None 133 | ): 134 | out = trainer._forward(batch) 135 | loss = trainer._compute_loss(out, batch) 136 | del out, loss 137 | if i == num_batches: 138 | break 139 | 140 | current_var = AutomaticFit.activeVar 141 | if current_var is not None: 142 | current_var.fit() # fit current variable 143 | else: 144 | print("Found no variable to fit. Something went wrong!") 145 | 146 | assert AutomaticFit.fitting_completed() 147 | logging.info(f"Fitting done. Results saved to: {scale_file}") 148 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import yaml 5 | import time 6 | 7 | import submitit 8 | 9 | from mdsim.common import distutils 10 | from mdsim.common.flags import flags 11 | from mdsim.common.registry import registry 12 | from mdsim.common.utils import ( 13 | build_config, 14 | create_grid, 15 | save_experiment_log, 16 | setup_imports, 17 | setup_logging, 18 | compose_data_cfg 19 | ) 20 | 21 | 22 | class Runner(submitit.helpers.Checkpointable): 23 | def __init__(self): 24 | self.config = None 25 | 26 | def __call__(self, config): 27 | setup_logging() 28 | self.config = copy.deepcopy(config) 29 | 30 | if config['distributed']: 31 | distutils.setup(config) 32 | 33 | try: 34 | setup_imports() 35 | 36 | # compose dataset configs. 37 | train_data_cfg = config['dataset'] 38 | train_data_cfg = compose_data_cfg(train_data_cfg) 39 | config['dataset'] = [ 40 | train_data_cfg, 41 | {'src': os.path.join(os.path.dirname(train_data_cfg['src']), 'val')} 42 | ] 43 | 44 | self.config = copy.deepcopy(config) 45 | 46 | # initialize trainer. 
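            # The trainer class registered under config["trainer"] (default
            # "energy") is looked up in the registry and instantiated with the
            # task/model/dataset/optim sections of the composed config.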
47 | self.trainer = registry.get_trainer_class( 48 | config.get("trainer", "energy") 49 | )( 50 | task=config["task"], 51 | model=config["model"], 52 | dataset=config["dataset"], 53 | optimizer=config["optim"], 54 | identifier=config["identifier"], 55 | timestamp_id=config.get("timestamp_id", None), 56 | run_dir=config.get("run_dir", None), 57 | is_debug=config.get("is_debug", False), 58 | print_every=config.get("print_every", 100), 59 | seed=config.get("seed", 0), 60 | logger=config.get("logger", "wandb"), 61 | local_rank=config["local_rank"], 62 | amp=config.get("amp", False), 63 | cpu=config.get("cpu", False), 64 | slurm=config.get("slurm", {}), 65 | no_energy=config.get("no_energy", False) 66 | ) 67 | 68 | # save config. 69 | with open(os.path.join(self.trainer.config["cmd"]["checkpoint_dir"], 'config.yml'), 'w') as yf: 70 | yaml.dump(self.config, yf, default_flow_style=False) 71 | 72 | self.task = registry.get_task_class(config["mode"])(self.config) 73 | self.task.setup(self.trainer) 74 | start_time = time.time() 75 | self.task.run() 76 | distutils.synchronize() 77 | if distutils.is_master(): 78 | logging.info(f"Total time taken: {time.time() - start_time}") 79 | finally: 80 | if config['distributed']: 81 | distutils.cleanup() 82 | 83 | def checkpoint(self, *args, **kwargs): 84 | new_runner = Runner() 85 | self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True) 86 | self.config["checkpoint"] = self.task.chkpt_path 87 | self.config["timestamp_id"] = self.trainer.timestamp_id 88 | if self.trainer.logger is not None: 89 | self.trainer.logger.mark_preempting() 90 | return submitit.helpers.DelayedSubmission(new_runner, self.config) 91 | 92 | 93 | if __name__ == "__main__": 94 | setup_logging() 95 | parser = flags.get_parser() 96 | args, override_args = parser.parse_known_args() 97 | if args.nequip: 98 | os.system(f'nequip-train {args.config_yml}') 99 | else: 100 | config = build_config(args, override_args) 101 | if args.submit: # Run on cluster 102 | slurm_add_params = config.get( 103 | "slurm", None 104 | ) # additional slurm arguments 105 | if args.sweep_yml: # Run grid search 106 | configs = create_grid(config, args.sweep_yml) 107 | else: 108 | configs = [config] 109 | 110 | logging.info(f"Submitting {len(configs)} jobs") 111 | executor = submitit.AutoExecutor( 112 | folder=args.logdir / "%j", slurm_max_num_timeout=3 113 | ) 114 | executor.update_parameters( 115 | name=args.identifier, 116 | mem_gb=args.slurm_mem, 117 | timeout_min=args.slurm_timeout * 60, 118 | slurm_partition=args.slurm_partition, 119 | gpus_per_node=args.num_gpus, 120 | cpus_per_task=(config["optim"]["num_workers"] + 1), 121 | tasks_per_node=(args.num_gpus if args.distributed else 1), 122 | nodes=args.num_nodes, 123 | slurm_additional_parameters=slurm_add_params, 124 | ) 125 | for config in configs: 126 | config["slurm"] = copy.deepcopy(executor.parameters) 127 | config["slurm"]["folder"] = str(executor.folder) 128 | jobs = executor.map_array(Runner(), configs) 129 | logging.info( 130 | f"Submitted jobs: {', '.join([job.job_id for job in jobs])}" 131 | ) 132 | log_file = save_experiment_log(args, jobs, configs) 133 | logging.info(f"Experiment log saved to: {log_file}") 134 | 135 | else: # Run locally 136 | Runner()(config) 137 | -------------------------------------------------------------------------------- /mdsim/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /mdsim/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/mdsim/common/__init__.py -------------------------------------------------------------------------------- /mdsim/common/const.py: -------------------------------------------------------------------------------- 1 | """ 2 | adopted from 3 | https://github.com/learningmatter-mit/NeuralForceField/blob/master/nff/utils/constants.py 4 | """ 5 | import torch 6 | import math 7 | from rdkit import Chem 8 | import copy 9 | 10 | PERIODICTABLE = Chem.GetPeriodicTable() 11 | 12 | HARTREE_TO_KCAL_MOL = 627.509 13 | EV_TO_KCAL_MOL = 23.06052 14 | 15 | # Distances 16 | BOHR_RADIUS = 0.529177 17 | 18 | # Masses 19 | ATOMIC_MASS = { 20 | 1: 1.008, 21 | 3: 6.941, 22 | 6: 12.01, 23 | 7: 14.0067, 24 | 8: 15.999, 25 | 9: 18.998403, 26 | 14: 28.0855, 27 | 16: 32.06, 28 | } 29 | 30 | AU_TO_KCAL = { 31 | 'energy': HARTREE_TO_KCAL_MOL, 32 | '_grad': 1.0 / BOHR_RADIUS, 33 | } 34 | 35 | KCAL_TO_AU = { 36 | 'energy': 1.0 / HARTREE_TO_KCAL_MOL, 37 | '_grad': BOHR_RADIUS, 38 | } 39 | 40 | KB_EV = 0.0000861731 41 | KB_AU = 3.166815e-6 42 | EV_TO_AU = 1 / 27.2114 43 | 44 | # Coulomb's constant, in (kcal/mol) * (A / e^2), 45 | # where A is Angstroms and e is the electron charge 46 | KE_KCAL = 332.07 47 | 48 | # Hardness used in xtb, in eV. Source: Ghosh, D.C. and Islam, N., 2010. 49 | # Semiempirical evaluation of the global hardness of the atoms 50 | # of 103 elements of the periodic table using the most probable 51 | # radii as their size descriptors. International Journal of 52 | # Quantum Chemistry, 110(6), pp.1206-1213. 
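# Per-element global hardness in eV; converted to atomic units (HARDNESS_AU)
# and to an atomic-number-indexed lookup tensor (HARDNESS_AU_MAT) below.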
53 | 54 | 55 | HARDNESS_EV = {"H": 6.4299, 56 | "He": 12.5449, 57 | "Li": 2.3746, 58 | "Be": 3.4968, 59 | "B": 4.6190, 60 | "C": 5.7410, 61 | "N": 6.6824, 62 | "O": 7.9854, 63 | "F": 9.1065, 64 | "Ne": 10.2303, 65 | "Na": 2.4441, 66 | "Mg": 3.0146, 67 | "Al": 3.5849, 68 | "Si": 4.1551, 69 | "P": 4.7258, 70 | "S": 5.2960, 71 | "Cl": 5.8662, 72 | "Ar": 6.4366, 73 | "K": 2.3273, 74 | "Ca": 2.7587, 75 | "Br": 5.9111, 76 | "I": 5.5839} 77 | 78 | # Hardness in AU 79 | HARDNESS_AU = {key: val * EV_TO_AU for key, val in 80 | HARDNESS_EV.items()} 81 | 82 | # Hardness in AU as a matrix 83 | HARDNESS_AU_MAT = torch.zeros(200) 84 | for key, val in HARDNESS_AU.items(): 85 | at_num = int(PERIODICTABLE.GetAtomicNumber(key)) 86 | HARDNESS_AU_MAT[at_num] = val 87 | 88 | # Times 89 | 90 | FS_TO_AU = 41.341374575751 91 | FS_TO_ASE = 0.098 92 | ASE_TO_FS = 1 / FS_TO_ASE 93 | 94 | # Masses 95 | AMU_TO_AU = 1.67262192369e-27 / (9.1093837015e-31) 96 | 97 | # Weird units used by Gaussian 98 | CM_TO_J = 1.98630e-23 99 | DYN_TO_J_PER_M = 0.00001 100 | ANGS_TO_M = 1e-10 101 | MDYN_PER_A_TO_J_PER_M = DYN_TO_J_PER_M / 1000 / ANGS_TO_M 102 | KG_TO_AMU = 1 / (1.66e-27) 103 | HBAR_SI = 6.626e-34 / (2 * math.pi) 104 | 105 | 106 | # Times 107 | 108 | FS_TO_AU = 41.341374575751 109 | FS_TO_ASE = 0.098 110 | ASE_TO_FS = 1/FS_TO_ASE 111 | 112 | # Masses 113 | AMU_TO_AU = 1.66e-27/(9.1093837015e-31) 114 | 115 | # Weird units used by Gaussian 116 | CM_TO_J = 1.98630e-23 117 | DYN_TO_J_PER_M = 0.00001 118 | ANGS_TO_M = 1e-10 119 | MDYN_PER_A_TO_J_PER_M = DYN_TO_J_PER_M / 1000 / ANGS_TO_M 120 | KG_TO_AMU = 1 / (1.66e-27) 121 | HBAR_SI = 6.626e-34 / (2 * math.pi) 122 | 123 | AU_TO_KCAL = { 124 | 'energy': HARTREE_TO_KCAL_MOL, 125 | '_grad': 1.0 / BOHR_RADIUS, 126 | } 127 | 128 | KCAL_TO_AU = { 129 | 'energy': 1.0 / HARTREE_TO_KCAL_MOL, 130 | '_grad': BOHR_RADIUS, 131 | } 132 | 133 | 134 | 135 | ELEC_CONFIG = {"1": [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 136 | "6": [2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 137 | "7": [2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 138 | "8": [2, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 139 | "9": [2, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 140 | "11": [2, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 141 | "14": [2, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 142 | "16": [2, 2, 6, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 143 | "17": [2, 2, 5, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 144 | "86": [2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 6, 10, 4] 145 | } 146 | 147 | ELEC_CONFIG = {int(key): val for key, val in ELEC_CONFIG.items()} 148 | 149 | 150 | # with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 151 | # "elec_config.json"), "r") as f: 152 | # ELEC_CONFIG = json.load(f) 153 | # ELEC_CONFIG = {int(key): val for key, val in ELEC_CONFIG.items()} 154 | 155 | 156 | 157 | def convert_units(props, conversion_dict): 158 | """Converts dictionary of properties to the desired units. 159 | Args: 160 | props (dict): dictionary containing the properties of interest. 161 | conversion_dict (dict): constants to convert. 162 | Returns: 163 | props (dict): dictionary with properties converted. 
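    Example::

        convert_units({'energy': [1.0]}, AU_TO_KCAL)
        # -> {'energy': [627.509]}, since 'energy' matches a key in AU_TO_KCAL
        # and each element is scaled by HARTREE_TO_KCAL_MOL.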
164 | """ 165 | 166 | props = props.copy() 167 | for prop_key in props.keys(): 168 | for conv_key, conv_const in conversion_dict.items(): 169 | if conv_key in prop_key: 170 | props[prop_key] = [ 171 | x * conv_const 172 | for x in props[prop_key] 173 | ] 174 | 175 | return props 176 | 177 | 178 | def exc_ev_to_hartree(props, 179 | add_ground_energy=False): 180 | """ Note: only converts excited state energies from ev to hartree, 181 | not gradients. 182 | """ 183 | 184 | assert "energy_0" in props.keys() 185 | exc_keys = [key for key in props.keys() if 186 | key.startswith('energy') and 'grad' not in key 187 | and key != 'energy_0'] 188 | energy_0 = props['energy_0'] 189 | new_props = copy.deepcopy(props) 190 | 191 | for key in exc_keys: 192 | new_props[key] *= EV_TO_AU 193 | if add_ground_energy: 194 | new_props[key] += energy_0 195 | 196 | return new_props -------------------------------------------------------------------------------- /mdsim/common/distutils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import subprocess 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def setup(config): 10 | if config["submit"]: 11 | node_list = os.environ.get("SLURM_STEP_NODELIST") 12 | if node_list is None: 13 | node_list = os.environ.get("SLURM_JOB_NODELIST") 14 | if node_list is not None: 15 | try: 16 | hostnames = subprocess.check_output( 17 | ["scontrol", "show", "hostnames", node_list] 18 | ) 19 | config["init_method"] = "tcp://{host}:{port}".format( 20 | host=hostnames.split()[0].decode("utf-8"), 21 | port=config["distributed_port"], 22 | ) 23 | nnodes = int(os.environ.get("SLURM_NNODES")) 24 | ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") 25 | if ntasks_per_node is not None: 26 | ntasks_per_node = int(ntasks_per_node) 27 | else: 28 | ntasks = int(os.environ.get("SLURM_NTASKS")) 29 | nnodes = int(os.environ.get("SLURM_NNODES")) 30 | assert ntasks % nnodes == 0 31 | ntasks_per_node = int(ntasks / nnodes) 32 | if ntasks_per_node == 1: 33 | assert config["world_size"] % nnodes == 0 34 | gpus_per_node = config["world_size"] // nnodes 35 | node_id = int(os.environ.get("SLURM_NODEID")) 36 | config["rank"] = node_id * gpus_per_node 37 | config["local_rank"] = 0 38 | else: 39 | assert ntasks_per_node == config["world_size"] // nnodes 40 | config["rank"] = int(os.environ.get("SLURM_PROCID")) 41 | config["local_rank"] = int(os.environ.get("SLURM_LOCALID")) 42 | 43 | logging.info( 44 | f"Init: {config['init_method']}, {config['world_size']}, {config['rank']}" 45 | ) 46 | dist.init_process_group( 47 | backend=config["distributed_backend"], 48 | init_method=config["init_method"], 49 | world_size=config["world_size"], 50 | rank=config["rank"], 51 | ) 52 | except subprocess.CalledProcessError as e: # scontrol failed 53 | raise e 54 | except FileNotFoundError: # Slurm is not installed 55 | pass 56 | elif config["summit"]: 57 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 58 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 59 | get_master = ( 60 | "echo $(cat {} | sort | uniq | grep -v batch | grep -v login | head -1)" 61 | ).format(os.environ["LSB_DJOB_HOSTFILE"]) 62 | os.environ["MASTER_ADDR"] = str( 63 | subprocess.check_output(get_master, shell=True) 64 | )[2:-3] 65 | os.environ["MASTER_PORT"] = "23456" 66 | os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] 67 | os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] 68 | # NCCL and MPI initialization 69 | 
dist.init_process_group( 70 | backend="nccl", 71 | rank=world_rank, 72 | world_size=world_size, 73 | init_method="env://", 74 | ) 75 | else: 76 | dist.init_process_group( 77 | backend=config["distributed_backend"], init_method="env://" 78 | ) 79 | 80 | 81 | def cleanup(): 82 | dist.destroy_process_group() 83 | 84 | 85 | def initialized(): 86 | return dist.is_available() and dist.is_initialized() 87 | 88 | 89 | def get_rank(): 90 | return dist.get_rank() if initialized() else 0 91 | 92 | 93 | def get_world_size(): 94 | return dist.get_world_size() if initialized() else 1 95 | 96 | 97 | def is_master(): 98 | return get_rank() == 0 99 | 100 | 101 | def synchronize(): 102 | if get_world_size() == 1: 103 | return 104 | dist.barrier() 105 | 106 | 107 | def broadcast(tensor, src, group=dist.group.WORLD, async_op=False): 108 | if get_world_size() == 1: 109 | return 110 | dist.broadcast(tensor, src, group, async_op) 111 | 112 | 113 | def all_reduce(data, group=dist.group.WORLD, average=False, device=None): 114 | if get_world_size() == 1: 115 | return data 116 | tensor = data 117 | if not isinstance(data, torch.Tensor): 118 | tensor = torch.tensor(data) 119 | if device is not None: 120 | tensor = tensor.cuda(device) 121 | dist.all_reduce(tensor, group=group) 122 | if average: 123 | tensor /= get_world_size() 124 | if not isinstance(data, torch.Tensor): 125 | result = tensor.cpu().numpy() if tensor.numel() > 1 else tensor.item() 126 | else: 127 | result = tensor 128 | return result 129 | 130 | 131 | def all_gather(data, group=dist.group.WORLD, device=None): 132 | if get_world_size() == 1: 133 | return data 134 | tensor = data 135 | if not isinstance(data, torch.Tensor): 136 | tensor = torch.tensor(data) 137 | if device is not None: 138 | tensor = tensor.cuda(device) 139 | tensor_list = [ 140 | tensor.new_zeros(tensor.shape) for _ in range(get_world_size()) 141 | ] 142 | dist.all_gather(tensor_list, tensor, group=group) 143 | if not isinstance(data, torch.Tensor): 144 | result = [tensor.cpu().numpy() for tensor in tensor_list] 145 | else: 146 | result = tensor_list 147 | return result 148 | -------------------------------------------------------------------------------- /mdsim/common/flags.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | 5 | class Flags: 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser( 8 | description="Graph Networks for Electrocatalyst Design" 9 | ) 10 | self.add_core_args() 11 | 12 | def get_parser(self): 13 | return self.parser 14 | 15 | def add_core_args(self): 16 | self.parser.add_argument_group("Core Arguments") 17 | self.parser.add_argument( 18 | "--mode", 19 | default='train', 20 | choices=["train", "predict", "run-relaxations", "validate"], 21 | help="Whether to train the model, make predictions, or to run relaxations", 22 | ) 23 | self.parser.add_argument( 24 | "--config-yml", 25 | type=Path, 26 | help="Path to a config file listing data, model, optim parameters.", 27 | ) 28 | self.parser.add_argument( 29 | "--identifier", 30 | default=None, 31 | type=str, 32 | help="Experiment identifier to append to checkpoint/log/result directory", 33 | ) 34 | self.parser.add_argument( 35 | "--debug", 36 | action="store_true", 37 | help="Whether this is a debugging run or not", 38 | ) 39 | self.parser.add_argument( 40 | "--run-dir", 41 | default="MODELPATH/", 42 | type=str, 43 | help="Directory to store checkpoint/log/result directory", 44 | ) 45 | self.parser.add_argument( 46 | 
"--print-every", 47 | default=200, 48 | type=int, 49 | help="Log every N iterations (default: 200)", 50 | ) 51 | self.parser.add_argument( 52 | "--seed", default=0, type=int, help="Seed for torch, cuda, numpy" 53 | ) 54 | self.parser.add_argument( 55 | "--amp", action="store_true", help="Use mixed-precision training" 56 | ) 57 | self.parser.add_argument( 58 | "--checkpoint", type=str, help="Model checkpoint to load" 59 | ) 60 | self.parser.add_argument( 61 | "--timestamp-id", 62 | default=None, 63 | type=str, 64 | help="Override time stamp ID. " 65 | "Useful for seamlessly continuing model training in logger.", 66 | ) 67 | # Cluster args 68 | self.parser.add_argument( 69 | "--sweep-yml", 70 | default=None, 71 | type=Path, 72 | help="Path to a config file with parameter sweeps", 73 | ) 74 | self.parser.add_argument( 75 | "--submit", action="store_true", help="Submit job to cluster" 76 | ) 77 | self.parser.add_argument( 78 | "--summit", action="store_true", help="Running on Summit cluster" 79 | ) 80 | self.parser.add_argument( 81 | "--logdir", default="logs", type=Path, help="Where to store logs" 82 | ) 83 | self.parser.add_argument( 84 | "--slurm-partition", 85 | default="ocp", 86 | type=str, 87 | help="Name of partition", 88 | ) 89 | self.parser.add_argument( 90 | "--slurm-mem", default=80, type=int, help="Memory (in gigabytes)" 91 | ) 92 | self.parser.add_argument( 93 | "--slurm-timeout", default=72, type=int, help="Time (in hours)" 94 | ) 95 | self.parser.add_argument( 96 | "--num-gpus", default=1, type=int, help="Number of GPUs to request" 97 | ) 98 | self.parser.add_argument( 99 | "--distributed", action="store_true", help="Run with DDP" 100 | ) 101 | self.parser.add_argument( 102 | "--cpu", action="store_true", help="Run CPU only training" 103 | ) 104 | self.parser.add_argument( 105 | "--num-nodes", 106 | default=1, 107 | type=int, 108 | help="Number of Nodes to request", 109 | ) 110 | self.parser.add_argument( 111 | "--distributed-port", 112 | type=int, 113 | default=13356, 114 | help="Port on master for DDP", 115 | ) 116 | self.parser.add_argument( 117 | "--distributed-backend", 118 | type=str, 119 | default="nccl", 120 | help="Backend for DDP", 121 | ) 122 | self.parser.add_argument( 123 | "--local_rank", default=0, type=int, help="Local rank" 124 | ) 125 | self.parser.add_argument( 126 | "--no-ddp", action="store_true", help="Do not use DDP" 127 | ) 128 | 129 | # added args from mdsim. 
130 | self.parser.add_argument( 131 | "--no_energy", 132 | action="store_true" 133 | ) 134 | self.parser.add_argument( 135 | "--molecule", type=str, help="md17 molecule" 136 | ) 137 | self.parser.add_argument( 138 | "--size", type=str, help="dataset size" 139 | ) 140 | self.parser.add_argument( 141 | "--cutoff", type=float, help="reset radius cutoff" 142 | ) 143 | self.parser.add_argument( 144 | "--lr_patience", type=int, help="patience for lr scheduler" 145 | ) 146 | self.parser.add_argument( 147 | "--max_epochs", type=int, help="maximum number of training epochs" 148 | ) 149 | self.parser.add_argument( 150 | "--nequip", action="store_true", help="train with nequip" 151 | ) 152 | 153 | 154 | flags = Flags() 155 | -------------------------------------------------------------------------------- /mdsim/common/hpo_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from ray import tune 4 | 5 | 6 | def tune_reporter( 7 | iters, 8 | train_metrics, 9 | val_metrics, 10 | test_metrics=None, 11 | metric_to_opt="val_loss", 12 | min_max="min", 13 | ): 14 | """ 15 | Wrapper function for tune.report() 16 | 17 | Args: 18 | iters(dict): dict with training iteration info (e.g. steps, epochs) 19 | train_metrics(dict): train metrics dict 20 | val_metrics(dict): val metrics dict 21 | test_metrics(dict, optional): test metrics dict, default is None 22 | metric_to_opt(str, optional): str for val metric to optimize, default is val_loss 23 | min_max(str, optional): either "min" or "max", determines whether metric_to_opt is to be minimized or maximized, default is min 24 | 25 | """ 26 | # labels metric dicts 27 | train = label_metric_dict(train_metrics, "train") 28 | val = label_metric_dict(val_metrics, "val") 29 | # this enables tolerance for NaNs assumes val set is used for optimization 30 | if math.isnan(val[metric_to_opt]): 31 | if min_max == "min": 32 | val[metric_to_opt] = 100000.0 33 | if min_max == "max": 34 | val[metric_to_opt] = 0.0 35 | if test_metrics: 36 | test = label_metric_dict(test_metrics, "test") 37 | else: 38 | test = {} 39 | # report results to Ray Tune 40 | tune.report(**iters, **train, **val, **test) 41 | 42 | 43 | def label_metric_dict(metric_dict, split): 44 | new_dict = {} 45 | for key in metric_dict: 46 | new_dict["{}_{}".format(split, key)] = metric_dict[key] 47 | metric_dict = new_dict 48 | return metric_dict 49 | -------------------------------------------------------------------------------- /mdsim/common/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | 4 | import torch 5 | import wandb 6 | try: 7 | from torch.utils.tensorboard import SummaryWriter 8 | except: 9 | pass 10 | 11 | from mdsim.common.registry import registry 12 | 13 | 14 | class Logger(ABC): 15 | """Generic class to interface with various logging modules, e.g. wandb, 16 | tensorboard, etc. 17 | """ 18 | 19 | def __init__(self, config): 20 | self.config = config 21 | 22 | @abstractmethod 23 | def watch(self, model): 24 | """ 25 | Monitor parameters and gradients. 26 | """ 27 | pass 28 | 29 | def log(self, update_dict, step=None, split=""): 30 | """ 31 | Log some values. 
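        Prefixes each key with the split name (e.g. "train/loss") before the
        subclass writes the values to its logging backend.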
32 | """ 33 | assert step is not None 34 | if split != "": 35 | new_dict = {} 36 | for key in update_dict: 37 | new_dict["{}/{}".format(split, key)] = update_dict[key] 38 | update_dict = new_dict 39 | return update_dict 40 | 41 | @abstractmethod 42 | def log_plots(self, plots): 43 | pass 44 | 45 | @abstractmethod 46 | def mark_preempting(self): 47 | pass 48 | 49 | 50 | @registry.register_logger("wandb") 51 | class WandBLogger(Logger): 52 | def __init__(self, config): 53 | super().__init__(config) 54 | project = ( 55 | self.config["logger"].get("project", None) 56 | if isinstance(self.config["logger"], dict) 57 | else None 58 | ) 59 | 60 | wandb.init( 61 | config=self.config, 62 | name=self.config["cmd"]["expname"], 63 | dir=self.config["cmd"]["logs_dir"], 64 | project=project, 65 | resume="allow", 66 | ) 67 | 68 | def watch(self, model): 69 | wandb.watch(model) 70 | 71 | def log(self, update_dict, step=None, split=""): 72 | update_dict = super().log(update_dict, step, split) 73 | wandb.log(update_dict, step=int(step)) 74 | 75 | def log_plots(self, plots, caption=""): 76 | assert isinstance(plots, list) 77 | plots = [wandb.Image(x, caption=caption) for x in plots] 78 | wandb.log({"data": plots}) 79 | 80 | def mark_preempting(self): 81 | wandb.mark_preempting() 82 | 83 | 84 | @registry.register_logger("tensorboard") 85 | class TensorboardLogger(Logger): 86 | def __init__(self, config): 87 | super().__init__(config) 88 | self.writer = SummaryWriter(self.config["cmd"]["logs_dir"]) 89 | 90 | def watch(self, model): 91 | logging.warning( 92 | "Model gradient logging to tensorboard not yet supported." 93 | ) 94 | return False 95 | 96 | def log(self, update_dict, step=None, split=""): 97 | update_dict = super().log(update_dict, step, split) 98 | for key in update_dict: 99 | if torch.is_tensor(update_dict[key]): 100 | self.writer.add_scalar(key, update_dict[key].item(), step) 101 | else: 102 | assert isinstance(update_dict[key], int) or isinstance( 103 | update_dict[key], float 104 | ) 105 | self.writer.add_scalar(key, update_dict[key], step) 106 | 107 | def mark_preempting(self): 108 | pass 109 | 110 | def log_plots(self, plots): 111 | pass 112 | -------------------------------------------------------------------------------- /mdsim/common/registry.py: -------------------------------------------------------------------------------- 1 | # Borrowed from https://github.com/facebookresearch/pythia/blob/master/pythia/common/registry.py. 2 | """ 3 | Registry is the central source of truth. Inspired by Redux's concept of a 4 | global store, Registry maintains mappings of various information to unique 5 | keys. Special functions in the registry can be used as decorators to register 6 | different kinds of classes. 7 | 8 | Import the global registry object using 9 | 10 | ``from mdsim.common.registry import registry`` 11 | 12 | Various decorators for registering different kinds of classes with unique keys 13 | 14 | - Register a model: ``@registry.register_model`` 15 | """ 16 | 17 | 18 | class Registry: 19 | r"""Class for registry object which acts as central source of truth.""" 20 | mapping = { 21 | # Mappings to respective classes.
22 | "task_name_mapping": {}, 23 | "dataset_name_mapping": {}, 24 | "model_name_mapping": {}, 25 | "logger_name_mapping": {}, 26 | "trainer_name_mapping": {}, 27 | "integrator_name_mapping": {}, 28 | "state": {}, 29 | } 30 | 31 | @classmethod 32 | def register_task(cls, name): 33 | r"""Register a new task to registry with key 'name' 34 | Args: 35 | name: Key with which the task will be registered. 36 | Usage:: 37 | from mdsim.common.registry import registry 38 | from mdsim.tasks import BaseTask 39 | @registry.register_task("train") 40 | class TrainTask(BaseTask): 41 | ... 42 | """ 43 | 44 | def wrap(func): 45 | cls.mapping["task_name_mapping"][name] = func 46 | return func 47 | 48 | return wrap 49 | 50 | @classmethod 51 | def register_dataset(cls, name): 52 | r"""Register a dataset to registry with key 'name' 53 | 54 | Args: 55 | name: Key with which the dataset will be registered. 56 | 57 | Usage:: 58 | 59 | from mdsim.common.registry import registry 60 | from mdsim.datasets import BaseDataset 61 | 62 | @registry.register_dataset("qm9") 63 | class QM9(BaseDataset): 64 | ... 65 | """ 66 | 67 | def wrap(func): 68 | cls.mapping["dataset_name_mapping"][name] = func 69 | return func 70 | 71 | return wrap 72 | 73 | @classmethod 74 | def register_model(cls, name): 75 | r"""Register a model to registry with key 'name' 76 | 77 | Args: 78 | name: Key with which the model will be registered. 79 | 80 | Usage:: 81 | 82 | from mdsim.common.registry import registry 83 | from mdsim.modules.layers import CGCNNConv 84 | 85 | @registry.register_model("cgcnn") 86 | class CGCNN(): 87 | ... 88 | """ 89 | 90 | def wrap(func): 91 | cls.mapping["model_name_mapping"][name] = func 92 | return func 93 | 94 | return wrap 95 | 96 | @classmethod 97 | def register_logger(cls, name): 98 | r"""Register a logger to registry with key 'name' 99 | 100 | Args: 101 | name: Key with which the logger will be registered. 102 | 103 | Usage:: 104 | 105 | from mdsim.common.registry import registry 106 | 107 | @registry.register_logger("tensorboard") 108 | class WandB(): 109 | ... 110 | """ 111 | 112 | def wrap(func): 113 | from mdsim.common.logger import Logger 114 | 115 | assert issubclass( 116 | func, Logger 117 | ), "All loggers must inherit Logger class" 118 | cls.mapping["logger_name_mapping"][name] = func 119 | return func 120 | 121 | return wrap 122 | 123 | @classmethod 124 | def register_trainer(cls, name): 125 | r"""Register a trainer to registry with key 'name' 126 | 127 | Args: 128 | name: Key with which the trainer will be registered. 129 | 130 | Usage:: 131 | 132 | from mdsim.common.registry import registry 133 | 134 | @registry.register_trainer("active_discovery") 135 | class ActiveDiscoveryTrainer(): 136 | ... 137 | """ 138 | 139 | def wrap(func): 140 | cls.mapping["trainer_name_mapping"][name] = func 141 | return func 142 | 143 | return wrap 144 | 145 | @classmethod 146 | def register_integrator(cls, name): 147 | def wrap(func): 148 | cls.mapping["integrator_name_mapping"][name] = func 149 | return func 150 | 151 | return wrap 152 | 153 | @classmethod 154 | def register(cls, name, obj): 155 | r"""Register an item to registry with key 'name' 156 | 157 | Args: 158 | name: Key with which the item will be registered. 
159 | 160 | Usage:: 161 | 162 | from mdsim.common.registry import registry 163 | 164 | registry.register("config", {}) 165 | """ 166 | path = name.split(".") 167 | current = cls.mapping["state"] 168 | 169 | for part in path[:-1]: 170 | if part not in current: 171 | current[part] = {} 172 | current = current[part] 173 | 174 | current[path[-1]] = obj 175 | 176 | @classmethod 177 | def get_task_class(cls, name): 178 | return cls.mapping["task_name_mapping"].get(name, None) 179 | 180 | @classmethod 181 | def get_dataset_class(cls, name): 182 | return cls.mapping["dataset_name_mapping"].get(name, None) 183 | 184 | @classmethod 185 | def get_model_class(cls, name): 186 | return cls.mapping["model_name_mapping"].get(name, None) 187 | 188 | @classmethod 189 | def get_logger_class(cls, name): 190 | return cls.mapping["logger_name_mapping"].get(name, None) 191 | 192 | @classmethod 193 | def get_trainer_class(cls, name): 194 | return cls.mapping["trainer_name_mapping"].get(name, None) 195 | 196 | @classmethod 197 | def get_integrator_class(cls, name): 198 | return cls.mapping["integrator_name_mapping"].get(name, None) 199 | 200 | @classmethod 201 | def get(cls, name, default=None, no_warning=False): 202 | r"""Get an item from registry with key 'name' 203 | 204 | Args: 205 | name (string): Key whose value needs to be retrieved. 206 | default: If passed and key is not in registry, default value will 207 | be returned with a warning. Default: None 208 | no_warning (bool): If passed as True, warning when key doesn't exist 209 | will not be generated. Useful for cgcnn's 210 | internal operations. Default: False 211 | Usage:: 212 | 213 | from mdsim.common.registry import registry 214 | 215 | config = registry.get("config") 216 | """ 217 | original_name = name 218 | name = name.split(".") 219 | value = cls.mapping["state"] 220 | for subname in name: 221 | value = value.get(subname, default) 222 | if value is default: 223 | break 224 | 225 | if ( 226 | "writer" in cls.mapping["state"] 227 | and value == default 228 | and no_warning is False 229 | ): 230 | cls.mapping["state"]["writer"].write( 231 | "Key {} is not present in registry, returning default value " 232 | "of {}".format(original_name, default) 233 | ) 234 | return value 235 | 236 | @classmethod 237 | def unregister(cls, name): 238 | r"""Remove an item from registry with key 'name' 239 | 240 | Args: 241 | name: Key which needs to be removed. 242 | Usage:: 243 | 244 | from mdsim.common.registry import registry 245 | 246 | config = registry.unregister("config") 247 | """ 248 | return cls.mapping["state"].pop(name, None) 249 | 250 | 251 | registry = Registry() 252 | -------------------------------------------------------------------------------- /mdsim/common/transforms.py: -------------------------------------------------------------------------------- 1 | # Borrowed from https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/transforms/random_rotate.py 2 | # with changes to keep track of the rotation / inverse rotation matrices. 3 | 4 | import math 5 | import numbers 6 | import random 7 | 8 | import torch 9 | import torch_geometric 10 | from torch_geometric.transforms import LinearTransformation 11 | 12 | 13 | class RandomRotate(object): 14 | r"""Rotates node positions around a specific axis by a randomly sampled 15 | factor within a given interval. 16 | 17 | Args: 18 | degrees (tuple or float): Rotation interval from which the rotation 19 | angle is sampled.
If `degrees` is a number instead of a 20 | tuple, the interval is given by :math:`[-\mathrm{degrees}, 21 | \mathrm{degrees}]`. 22 | axes (int, optional): The rotation axes. (default: `[0, 1, 2]`) 23 | """ 24 | 25 | def __init__(self, degrees, axes=[0, 1, 2]): 26 | if isinstance(degrees, numbers.Number): 27 | degrees = (-abs(degrees), abs(degrees)) 28 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 29 | self.degrees = degrees 30 | self.axes = axes 31 | 32 | def __call__(self, data): 33 | if data.pos.size(-1) == 2: 34 | degree = math.pi * random.uniform(*self.degrees) / 180.0 35 | sin, cos = math.sin(degree), math.cos(degree) 36 | matrix = [[cos, sin], [-sin, cos]] 37 | else: 38 | m1, m2, m3 = torch.eye(3), torch.eye(3), torch.eye(3) 39 | if 0 in self.axes: 40 | degree = math.pi * random.uniform(*self.degrees) / 180.0 41 | sin, cos = math.sin(degree), math.cos(degree) 42 | m1 = torch.tensor([[1, 0, 0], [0, cos, sin], [0, -sin, cos]]) 43 | if 1 in self.axes: 44 | degree = math.pi * random.uniform(*self.degrees) / 180.0 45 | sin, cos = math.sin(degree), math.cos(degree) 46 | m2 = torch.tensor([[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]) 47 | if 2 in self.axes: 48 | degree = math.pi * random.uniform(*self.degrees) / 180.0 49 | sin, cos = math.sin(degree), math.cos(degree) 50 | m3 = torch.tensor([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]) 51 | 52 | matrix = torch.mm(torch.mm(m1, m2), m3) 53 | 54 | data_rotated = LinearTransformation(matrix)(data) 55 | if torch_geometric.__version__.startswith("2."): 56 | matrix = matrix.T 57 | 58 | # LinearTransformation only rotates `.pos`; need to rotate `.cell` too. 59 | if hasattr(data_rotated, "cell"): 60 | data_rotated.cell = torch.matmul(data_rotated.cell, matrix) 61 | 62 | return ( 63 | data_rotated, 64 | matrix, 65 | torch.inverse(matrix), 66 | ) 67 | 68 | def __repr__(self): 69 | return "{}({}, axis={})".format( 70 | self.__class__.__name__, self.degrees, self.axis 71 | ) 72 | -------------------------------------------------------------------------------- /mdsim/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .lmdb_dataset import ( 7 | LmdbDataset, 8 | data_list_collater, 9 | ) 10 | -------------------------------------------------------------------------------- /mdsim/datasets/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "ATOMIC_RADII", 3 | "KHOT_EMBEDDINGS", 4 | "CONTINUOUS_EMBEDDINGS", 5 | "QMOF_KHOT_EMBEDDINGS", 6 | ] 7 | 8 | from .atomic_radii import ATOMIC_RADII 9 | from .continuous_embeddings import CONTINUOUS_EMBEDDINGS 10 | from .khot_embeddings import KHOT_EMBEDDINGS 11 | from .qmof_khot_embeddings import QMOF_KHOT_EMBEDDINGS 12 | -------------------------------------------------------------------------------- /mdsim/datasets/embeddings/atomic_radii.py: -------------------------------------------------------------------------------- 1 | """ 2 | Atomic radii in picometers 3 | 4 | NaN stored for unavailable parameters. 
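# Illustrative lookup: keys are atomic numbers, values are radii in
# picometers, with NaN placeholders for elements lacking data.
import math
from mdsim.datasets.embeddings import ATOMIC_RADII

assert ATOMIC_RADII[6] == 70.0        # carbon
assert math.isnan(ATOMIC_RADII[0])    # placeholder "no element" entry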
5 | """ 6 | ATOMIC_RADII = { 7 | 0: float("NaN"), 8 | 1: 25.0, 9 | 2: 120.0, 10 | 3: 145.0, 11 | 4: 105.0, 12 | 5: 85.0, 13 | 6: 70.0, 14 | 7: 65.0, 15 | 8: 60.0, 16 | 9: 50.0, 17 | 10: 160.0, 18 | 11: 180.0, 19 | 12: 150.0, 20 | 13: 125.0, 21 | 14: 110.0, 22 | 15: 100.0, 23 | 16: 100.0, 24 | 17: 100.0, 25 | 18: 71.0, 26 | 19: 220.0, 27 | 20: 180.0, 28 | 21: 160.0, 29 | 22: 140.0, 30 | 23: 135.0, 31 | 24: 140.0, 32 | 25: 140.0, 33 | 26: 140.0, 34 | 27: 135.0, 35 | 28: 135.0, 36 | 29: 135.0, 37 | 30: 135.0, 38 | 31: 130.0, 39 | 32: 125.0, 40 | 33: 115.0, 41 | 34: 115.0, 42 | 35: 115.0, 43 | 36: float("NaN"), 44 | 37: 235.0, 45 | 38: 200.0, 46 | 39: 180.0, 47 | 40: 155.0, 48 | 41: 145.0, 49 | 42: 145.0, 50 | 43: 135.0, 51 | 44: 130.0, 52 | 45: 135.0, 53 | 46: 140.0, 54 | 47: 160.0, 55 | 48: 155.0, 56 | 49: 155.0, 57 | 50: 145.0, 58 | 51: 145.0, 59 | 52: 140.0, 60 | 53: 140.0, 61 | 54: float("NaN"), 62 | 55: 260.0, 63 | 56: 215.0, 64 | 57: 195.0, 65 | 58: 185.0, 66 | 59: 185.0, 67 | 60: 185.0, 68 | 61: 185.0, 69 | 62: 185.0, 70 | 63: 185.0, 71 | 64: 180.0, 72 | 65: 175.0, 73 | 66: 175.0, 74 | 67: 175.0, 75 | 68: 175.0, 76 | 69: 175.0, 77 | 70: 175.0, 78 | 71: 175.0, 79 | 72: 155.0, 80 | 73: 145.0, 81 | 74: 135.0, 82 | 75: 135.0, 83 | 76: 130.0, 84 | 77: 135.0, 85 | 78: 135.0, 86 | 79: 135.0, 87 | 80: 150.0, 88 | 81: 190.0, 89 | 82: 180.0, 90 | 83: 160.0, 91 | 84: 190.0, 92 | 85: float("NaN"), 93 | 86: float("NaN"), 94 | 87: float("NaN"), 95 | 88: 215.0, 96 | 89: 195.0, 97 | 90: 180.0, 98 | 91: 180.0, 99 | 92: 175.0, 100 | 93: 175.0, 101 | 94: 175.0, 102 | 95: 175.0, 103 | 96: float("NaN"), 104 | 97: float("NaN"), 105 | 98: float("NaN"), 106 | 99: float("NaN"), 107 | 100: float("NaN"), 108 | } 109 | -------------------------------------------------------------------------------- /mdsim/datasets/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import logging 3 | import pickle 4 | import warnings 5 | from pathlib import Path 6 | 7 | import lmdb 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torch_geometric.data import Batch 12 | 13 | from mdsim.common.registry import registry 14 | 15 | @registry.register_dataset("lmdb") 16 | @registry.register_dataset("single_point_lmdb") 17 | @registry.register_dataset("trajectory_lmdb") 18 | class LmdbDataset(Dataset): 19 | r"""Dataset class to load from LMDB files containing relaxation 20 | trajectories or single point computations. 21 | 22 | Args: 23 | config (dict): Dataset configuration 24 | transform (callable, optional): Data transform function. 
25 | (default: :obj:`None`) 26 | """ 27 | 28 | def __init__(self, config, transform=None): 29 | super(LmdbDataset, self).__init__() 30 | self.config = config 31 | 32 | self.path = Path(self.config["src"]) 33 | if not self.path.is_file(): 34 | db_paths = sorted(self.path.glob("*.lmdb")) 35 | assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" 36 | 37 | self.metadata_path = self.path / "metadata.npz" 38 | 39 | self._keys, self.envs = [], [] 40 | for db_path in db_paths: 41 | self.envs.append(self.connect_db(db_path)) 42 | length = pickle.loads( 43 | self.envs[-1].begin().get("length".encode("ascii")) 44 | ) 45 | self._keys.append(list(range(length))) 46 | 47 | keylens = [len(k) for k in self._keys] 48 | self._keylen_cumulative = np.cumsum(keylens).tolist() 49 | self.num_samples = sum(keylens) 50 | else: 51 | self.metadata_path = self.path.parent / "metadata.npz" 52 | self.env = self.connect_db(self.path) 53 | self._keys = [ 54 | f"{j}".encode("ascii") 55 | for j in range(self.env.stat()["entries"]) 56 | ] 57 | self.num_samples = len(self._keys) 58 | 59 | self.transform = transform 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def __getitem__(self, idx): 65 | if not self.path.is_file(): 66 | # Figure out which db this should be indexed from. 67 | db_idx = bisect.bisect(self._keylen_cumulative, idx) 68 | # Extract index of element within that db. 69 | el_idx = idx 70 | if db_idx != 0: 71 | el_idx = idx - self._keylen_cumulative[db_idx - 1] 72 | assert el_idx >= 0 73 | 74 | # Return features. 75 | datapoint_pickled = ( 76 | self.envs[db_idx] 77 | .begin() 78 | .get(f"{self._keys[db_idx][el_idx]}".encode("ascii")) 79 | ) 80 | data_object = pickle.loads(datapoint_pickled) 81 | data_object.id = f"{db_idx}_{el_idx}" 82 | else: 83 | datapoint_pickled = self.env.begin().get(self._keys[idx]) 84 | data_object = pickle.loads(datapoint_pickled) 85 | 86 | if self.transform is not None: 87 | data_object = self.transform(data_object) 88 | 89 | return data_object 90 | 91 | def connect_db(self, lmdb_path=None): 92 | env = lmdb.open( 93 | str(lmdb_path), 94 | subdir=False, 95 | readonly=True, 96 | lock=False, 97 | readahead=False, 98 | meminit=False, 99 | max_readers=1, 100 | ) 101 | return env 102 | 103 | def close_db(self): 104 | if not self.path.is_file(): 105 | for env in self.envs: 106 | env.close() 107 | else: 108 | self.env.close() 109 | 110 | def data_list_collater(data_list, otf_graph=False): 111 | batch = Batch.from_data_list(data_list) 112 | 113 | if not otf_graph: 114 | try: 115 | n_neighbors = [] 116 | for i, data in enumerate(data_list): 117 | n_index = data.edge_index[1, :] 118 | n_neighbors.append(n_index.shape[0]) 119 | batch.neighbors = torch.tensor(n_neighbors) 120 | except NotImplementedError: 121 | logging.warning( 122 | "LMDB does not contain edge index information, set otf_graph=True" 123 | ) 124 | 125 | return batch 126 | -------------------------------------------------------------------------------- /mdsim/md/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/mdsim/md/__init__.py -------------------------------------------------------------------------------- /mdsim/md/integrator.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from 3 | https://github.com/torchmd/mdgrad/tree/master/nff/md 4 | """ 5 | import numpy as np 6 | from ase.md.md import MolecularDynamics 7 | 
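# --- Hedged usage sketch (commented out; not part of this module) for the
# --- NoseHoover integrator defined below. `temperature` enters the target
# --- kinetic energy directly, so pass it as kT in eV (e.g. 300 * units.kB);
# --- `ttime` is a unitless multiple of the timestep. The calculator and
# --- ttime=25 are assumptions for illustration.
#
#   from ase import units
#   from ase.build import molecule
#   from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
#   from mdsim.md.integrator import NoseHoover
#
#   atoms = molecule("H2O")
#   atoms.calc = my_trained_mlff_calculator  # hypothetical ASE calculator
#   MaxwellBoltzmannDistribution(atoms, temperature_K=300)
#   dyn = NoseHoover(atoms, timestep=0.5 * units.fs,
#                    temperature=300 * units.kB, ttime=25)
#   dyn.run(1000)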
from ase.md.verlet import VelocityVerlet 8 | from ase.md.langevin import Langevin 9 | 10 | class NoseHoover(MolecularDynamics): 11 | def __init__(self, 12 | atoms, 13 | timestep, 14 | temperature, 15 | ttime, 16 | trajectory=None, 17 | logfile=None, 18 | loginterval=1, 19 | **kwargs): 20 | 21 | super().__init__( 22 | atoms, 23 | timestep, 24 | trajectory, 25 | logfile, 26 | loginterval) 27 | 28 | # Initialize simulation parameters 29 | 30 | # Q is chosen to be 6 N kT 31 | self.dt = timestep 32 | self.Natom = atoms.get_number_of_atoms() 33 | self.T = temperature 34 | self.targeEkin = 0.5 * (3.0 * self.Natom) * self.T 35 | self.ttime = ttime # * units.fs 36 | self.Q = 3.0 * self.Natom * self.T * (self.ttime * self.dt)**2 37 | self.zeta = 0.0 38 | 39 | def step(self): 40 | 41 | # get current acceleration and velocity: 42 | accel = self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1) 43 | vel = self.atoms.get_velocities() 44 | 45 | # make full step in position 46 | x = self.atoms.get_positions() + vel * self.dt + \ 47 | (accel - self.zeta * vel) * (0.5 * self.dt ** 2) 48 | self.atoms.set_positions(x) 49 | 50 | # record current velocities 51 | KE_0 = self.atoms.get_kinetic_energy() 52 | 53 | # make half a step in velocity 54 | vel_half = vel + 0.5 * self.dt * (accel - self.zeta * vel) 55 | self.atoms.set_velocities(vel_half) 56 | 57 | # make a full step in accelerations 58 | f = self.atoms.get_forces() 59 | accel = f / self.atoms.get_masses().reshape(-1, 1) 60 | 61 | # make a half step in self.zeta 62 | self.zeta = self.zeta + 0.5 * self.dt * \ 63 | (1/self.Q) * (KE_0 - self.targeEkin) 64 | 65 | # make another halfstep in self.zeta 66 | self.zeta = self.zeta + 0.5 * self.dt * \ 67 | (1/self.Q) * (self.atoms.get_kinetic_energy() - self.targeEkin) 68 | 69 | # make another half step in velocity 70 | vel = (self.atoms.get_velocities() + 0.5 * self.dt * accel) / \ 71 | (1 + 0.5 * self.dt * self.zeta) 72 | self.atoms.set_velocities(vel) 73 | 74 | return f 75 | 76 | 77 | class NoseHooverChain(MolecularDynamics): 78 | def __init__(self, 79 | atoms, 80 | timestep, 81 | temperature, 82 | ttime, 83 | num_chains, 84 | trajectory=None, 85 | logfile=None, 86 | loginterval=1, 87 | **kwargs): 88 | 89 | super().__init__( 90 | atoms, 91 | timestep, 92 | trajectory, 93 | logfile, 94 | loginterval) 95 | 96 | # Initialize simulation parameters 97 | 98 | self.dt = timestep 99 | 100 | self.N_dof = 3*atoms.get_number_of_atoms() 101 | self.T = temperature 102 | 103 | # in units of fs: 104 | self.ttime = ttime 105 | self.Q = 2 * np.array([self.N_dof * self.T * (self.ttime * self.dt)**2, 106 | *[self.T * (self.ttime * self.dt)**2]*(num_chains-1)]) 107 | self.targeEkin = 1/2 * self.N_dof * self.T 108 | 109 | # self.zeta = np.array([0.0]*num_chains) 110 | self.p_zeta = np.array([0.0]*num_chains) 111 | 112 | def get_zeta_accel(self): 113 | 114 | p0_dot = 2 * (self.atoms.get_kinetic_energy() - self.targeEkin)- \ 115 | self.p_zeta[0]*self.p_zeta[1] / self.Q[1] 116 | p_middle_dot = self.p_zeta[:-2]**2 / self.Q[:-2] - \ 117 | self.T - self.p_zeta[1:-1] * self.p_zeta[2:]/self.Q[2:] 118 | p_last_dot = self.p_zeta[-2]**2 / self.Q[-2] - self.T 119 | p_dot = np.array([p0_dot, *p_middle_dot, p_last_dot]) 120 | 121 | return p_dot / self.Q 122 | 123 | def half_step_v_zeta(self): 124 | 125 | v = self.p_zeta / self.Q 126 | accel = self.get_zeta_accel() 127 | v_half = v + 1/2 * accel * self.dt 128 | return v_half 129 | 130 | def half_step_v_system(self): 131 | 132 | v = self.atoms.get_velocities() 133 | accel = 
self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1) 134 | accel -= v * self.p_zeta[0] / self.Q[0] 135 | v_half = v + 1/2 * accel * self.dt 136 | return v_half 137 | 138 | def full_step_positions(self): 139 | 140 | accel = self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1) 141 | new_positions = self.atoms.get_positions() + self.atoms.get_velocities() * self.dt + \ 142 | (accel - self.p_zeta[0] / self.Q[0])*(self.dt)**2 143 | return new_positions 144 | 145 | def step(self): 146 | 147 | new_positions = self.full_step_positions() 148 | self.atoms.set_positions(new_positions) 149 | 150 | v_half_system = self.half_step_v_system() 151 | v_half_zeta = self.half_step_v_zeta() 152 | 153 | self.atoms.set_velocities(v_half_system) 154 | self.p_zeta = v_half_zeta * self.Q 155 | 156 | v_full_zeta = self.half_step_v_zeta() 157 | accel = self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1) 158 | v_full_system = (v_half_system + 1/2 * accel * self.dt) / \ 159 | (1 + 0.5 * self.dt * v_full_zeta[0]) 160 | 161 | self.atoms.set_velocities(v_full_system) 162 | self.p_zeta = v_full_zeta * self.Q -------------------------------------------------------------------------------- /mdsim/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .base import BaseModel 7 | from .cgcnn import CGCNN 8 | from .dimenet import DimeNetWrap as DimeNet 9 | from .dimenet_plus_plus import DimeNetPlusPlusWrap as DimeNetPlusPlus 10 | from .forcenet import ForceNet 11 | from .gemnet.gemnet import GemNetT 12 | from .schnet import SchNetWrap as SchNet 13 | from .spinconv import spinconv 14 | -------------------------------------------------------------------------------- /mdsim/models/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class BaseModel(nn.Module): 12 | def __init__(self, num_atoms=None, bond_feat_dim=None, num_targets=None): 13 | super(BaseModel, self).__init__() 14 | self.num_atoms = num_atoms 15 | self.bond_feat_dim = bond_feat_dim 16 | self.num_targets = num_targets 17 | 18 | def forward(self, data): 19 | raise NotImplementedError 20 | 21 | @property 22 | def num_params(self): 23 | return sum(p.numel() for p in self.parameters()) 24 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/fit_scaling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for calculating the scaling factors used to even out GemNet activation 3 | scales. This generates the `scale_file` specified in the config, which is then 4 | read in at model initialization. 5 | This only needs to be run if the hyperparameters or model change 6 | in places where it would affect the activation scales.
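Hedged invocation sketch (the flags come from mdsim.common.flags and are
assumed here, as is the exact config path):

    python mdsim/models/gemnet/fit_scaling.py \
        --config-yml configs/water/gemnet-T.yml --identifier gemnet-T-fit

The fitted factors are written to the `scale_file` path given in the model
config (e.g. configs/water/gemnet-T-scale.json).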
7 | """ 8 | 9 | import logging 10 | import os 11 | import sys 12 | import numpy as np 13 | import torch 14 | from tqdm import trange 15 | 16 | from mdsim.common.flags import flags 17 | from mdsim.common.registry import registry 18 | from mdsim.common.utils import build_config, setup_imports, setup_logging 19 | from mdsim.models.gemnet.layers.scaling import AutomaticFit 20 | from mdsim.models.gemnet.utils import write_json 21 | 22 | if __name__ == "__main__": 23 | setup_logging() 24 | 25 | num_batches = 25 # number of batches to use to fit a single variable 26 | 27 | parser = flags.get_parser() 28 | args, override_args = parser.parse_known_args() 29 | config = build_config(args, override_args) 30 | assert config["model"]["name"].startswith("gemnet") 31 | config["logger"] = "tensorboard" 32 | 33 | if args.distributed: 34 | raise ValueError( 35 | "I don't think this works with DDP (race conditions)." 36 | ) 37 | 38 | setup_imports() 39 | 40 | scale_file = config["model"]["scale_file"] 41 | 42 | logging.info(f"Run fitting for model: {args.identifier}") 43 | logging.info(f"Target scale file: {scale_file}") 44 | 45 | def initialize_scale_file(scale_file): 46 | # initialize file 47 | preset = {"comment": args.identifier} 48 | write_json(scale_file, preset) 49 | 50 | if os.path.exists(scale_file): 51 | logging.warning(f"Already found existing file: {scale_file}") 52 | flag = input( 53 | "Do you want to continue and overwrite the file (1), " 54 | "only fit the variables not fitted yet (2), or exit (3)? " 55 | ) 56 | if str(flag) == "1": 57 | logging.info("Overwriting the current file.") 58 | initialize_scale_file(scale_file) 59 | elif str(flag) == "2": 60 | logging.info("Only fitting unfitted variables.") 61 | else: 62 | print(flag) 63 | logging.info("Exiting script") 64 | sys.exit() 65 | else: 66 | initialize_scale_file(scale_file) 67 | 68 | AutomaticFit.set2fitmode() 69 | 70 | # compose dataset configs. 71 | train_data_cfg = config['dataset'] 72 | dataset_name = train_data_cfg['name'] 73 | if dataset_name == 'md17': 74 | train_data_cfg['src'] = os.path.join(train_data_cfg['src'], train_data_cfg['molecule']) 75 | train_data_cfg['name'] = 'md17-' + train_data_cfg['molecule'] 76 | src = os.path.join(train_data_cfg['src'], train_data_cfg['size']) 77 | train_data_cfg['src'] = os.path.join(src, 'train') 78 | 79 | norm_stats = np.load(os.path.join(src, 'metadata.npy'), allow_pickle=True).item() 80 | if not train_data_cfg['normalize_labels']: 81 | # mean of energy is arbitrary. should always substract. 82 | # this is done in . 83 | train_data_cfg['target_mean'] = norm_stats['e_mean'] 84 | train_data_cfg['target_std'] = 1. 85 | train_data_cfg['grad_target_mean'] = 0. 86 | train_data_cfg['grad_target_std'] = 1. 87 | train_data_cfg['normalize_labels'] = True 88 | else: 89 | train_data_cfg['target_mean'] = float(norm_stats['e_mean']) 90 | train_data_cfg['target_std'] = float(norm_stats['e_std']) 91 | train_data_cfg['grad_target_mean'] = float(norm_stats['f_mean']) 92 | train_data_cfg['grad_target_std'] = float(norm_stats['f_std']) 93 | # train, val, test 94 | config['dataset'] = [train_data_cfg, 95 | {'src': os.path.join(src, 'val')}, ] 96 | 97 | # initialize trainer. 
98 | trainer = registry.get_trainer_class( 99 | config.get("trainer", "energy") 100 | )( 101 | task=config["task"], 102 | model=config["model"], 103 | dataset=config["dataset"], 104 | optimizer=config["optim"], 105 | identifier=config["identifier"], 106 | timestamp_id=config.get("timestamp_id", None), 107 | run_dir=config.get("run_dir", None), 108 | is_debug=config.get("is_debug", False), 109 | print_every=config.get("print_every", 100), 110 | seed=config.get("seed", 0), 111 | logger=config.get("logger", "wandb"), 112 | local_rank=config["local_rank"], 113 | amp=config.get("amp", False), 114 | cpu=config.get("cpu", False), 115 | slurm=config.get("slurm", {}), 116 | no_energy=config.get("no_energy", False) 117 | ) 118 | 119 | # Fitting loop 120 | logging.info("Start fitting") 121 | 122 | if not AutomaticFit.fitting_completed(): 123 | with torch.no_grad(): 124 | trainer.model.eval() 125 | for _ in trange(len(AutomaticFit.queue) + 1): 126 | assert ( 127 | trainer.val_loader is not None 128 | ), "Val dataset is required for making predictions" 129 | 130 | for i, batch in enumerate(trainer.val_loader): 131 | with torch.cuda.amp.autocast( 132 | enabled=trainer.scaler is not None 133 | ): 134 | out = trainer._forward(batch) 135 | loss = trainer._compute_loss(out, batch) 136 | del out, loss 137 | if i == num_batches: 138 | break 139 | 140 | current_var = AutomaticFit.activeVar 141 | if current_var is not None: 142 | current_var.fit() # fit current variable 143 | else: 144 | print("Found no variable to fit. Something went wrong!") 145 | 146 | assert AutomaticFit.fitting_completed() 147 | logging.info(f"Fitting done. Results saved to: {scale_file}") 148 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/initializers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | 10 | 11 | def _standardize(kernel): 12 | """ 13 | Makes sure that N*Var(W) = 1 and E[W] = 0 14 | """ 15 | eps = 1e-6 16 | 17 | if len(kernel.shape) == 3: 18 | axis = [0, 1] # last dimension is output dimension 19 | else: 20 | axis = 1 21 | 22 | var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True) 23 | kernel = (kernel - mean) / (var + eps) ** 0.5 24 | return kernel 25 | 26 | 27 | def he_orthogonal_init(tensor): 28 | """ 29 | Generate a weight matrix with variance according to He (Kaiming) initialization. 30 | Based on a random (semi-)orthogonal matrix neural networks 31 | are expected to learn better when features are decorrelated 32 | (stated by eg. 
"Reducing overfitting in deep networks by decorrelating representations", 33 | "Dropout: a simple way to prevent neural networks from overfitting", 34 | "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks") 35 | """ 36 | tensor = torch.nn.init.orthogonal_(tensor) 37 | 38 | if len(tensor.shape) == 3: 39 | fan_in = tensor.shape[:-1].numel() 40 | else: 41 | fan_in = tensor.shape[1] 42 | 43 | with torch.no_grad(): 44 | tensor.data = _standardize(tensor.data) 45 | tensor.data *= (1 / fan_in) ** 0.5 46 | 47 | return tensor 48 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/layers/atom_update_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | from torch_scatter import scatter 10 | 11 | from ..initializers import he_orthogonal_init 12 | from .base_layers import Dense, ResidualLayer 13 | from .scaling import ScalingFactor 14 | 15 | 16 | class AtomUpdateBlock(torch.nn.Module): 17 | """ 18 | Aggregate the message embeddings of the atoms 19 | 20 | Parameters 21 | ---------- 22 | emb_size_atom: int 23 | Embedding size of the atoms. 24 | emb_size_atom: int 25 | Embedding size of the edges. 26 | nHidden: int 27 | Number of residual blocks. 28 | activation: callable/str 29 | Name of the activation function to use in the dense layers. 30 | scale_file: str 31 | Path to the json file containing the scaling factors. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | emb_size_atom: int, 37 | emb_size_edge: int, 38 | emb_size_rbf: int, 39 | nHidden: int, 40 | activation=None, 41 | scale_file=None, 42 | name: str = "atom_update", 43 | ): 44 | super().__init__() 45 | self.name = name 46 | 47 | self.dense_rbf = Dense( 48 | emb_size_rbf, emb_size_edge, activation=None, bias=False 49 | ) 50 | self.scale_sum = ScalingFactor( 51 | scale_file=scale_file, name=name + "_sum" 52 | ) 53 | 54 | self.layers = self.get_mlp( 55 | emb_size_edge, emb_size_atom, nHidden, activation 56 | ) 57 | 58 | def get_mlp(self, units_in, units, nHidden, activation): 59 | dense1 = Dense(units_in, units, activation=activation, bias=False) 60 | mlp = [dense1] 61 | res = [ 62 | ResidualLayer(units, nLayers=2, activation=activation) 63 | for i in range(nHidden) 64 | ] 65 | mlp += res 66 | return torch.nn.ModuleList(mlp) 67 | 68 | def forward(self, h, m, rbf, id_j): 69 | """ 70 | Returns 71 | ------- 72 | h: torch.Tensor, shape=(nAtoms, emb_size_atom) 73 | Atom embedding. 74 | """ 75 | nAtoms = h.shape[0] 76 | 77 | mlp_rbf = self.dense_rbf(rbf) # (nEdges, emb_size_edge) 78 | x = m * mlp_rbf 79 | 80 | x2 = scatter(x, id_j, dim=0, dim_size=nAtoms, reduce="sum") 81 | # (nAtoms, emb_size_edge) 82 | x = self.scale_sum(m, x2) 83 | 84 | for layer in self.layers: 85 | x = layer(x) # (nAtoms, emb_size_atom) 86 | 87 | return x 88 | 89 | 90 | class OutputBlock(AtomUpdateBlock): 91 | """ 92 | Combines the atom update block and subsequent final dense layer. 93 | 94 | Parameters 95 | ---------- 96 | emb_size_atom: int 97 | Embedding size of the atoms. 98 | emb_size_atom: int 99 | Embedding size of the edges. 100 | nHidden: int 101 | Number of residual blocks. 102 | num_targets: int 103 | Number of targets. 
104 | activation: str 105 | Name of the activation function to use in the dense layers except for the final dense layer. 106 | direct_forces: bool 107 | If true directly predict forces without taking the gradient of the energy potential. 108 | output_init: int 109 | Kernel initializer of the final dense layer. 110 | scale_file: str 111 | Path to the json file containing the scaling factors. 112 | """ 113 | 114 | def __init__( 115 | self, 116 | emb_size_atom: int, 117 | emb_size_edge: int, 118 | emb_size_rbf: int, 119 | nHidden: int, 120 | num_targets: int, 121 | activation=None, 122 | direct_forces=True, 123 | output_init="HeOrthogonal", 124 | scale_file=None, 125 | name: str = "output", 126 | **kwargs, 127 | ): 128 | 129 | super().__init__( 130 | name=name, 131 | emb_size_atom=emb_size_atom, 132 | emb_size_edge=emb_size_edge, 133 | emb_size_rbf=emb_size_rbf, 134 | nHidden=nHidden, 135 | activation=activation, 136 | scale_file=scale_file, 137 | **kwargs, 138 | ) 139 | 140 | assert isinstance(output_init, str) 141 | self.output_init = output_init.lower() 142 | self.direct_forces = direct_forces 143 | 144 | self.seq_energy = self.layers # inherited from parent class 145 | self.out_energy = Dense( 146 | emb_size_atom, num_targets, bias=False, activation=None 147 | ) 148 | 149 | if self.direct_forces: 150 | self.scale_rbf_F = ScalingFactor( 151 | scale_file=scale_file, name=name + "_had" 152 | ) 153 | self.seq_forces = self.get_mlp( 154 | emb_size_edge, emb_size_edge, nHidden, activation 155 | ) 156 | self.out_forces = Dense( 157 | emb_size_edge, num_targets, bias=False, activation=None 158 | ) 159 | self.dense_rbf_F = Dense( 160 | emb_size_rbf, emb_size_edge, activation=None, bias=False 161 | ) 162 | 163 | self.reset_parameters() 164 | 165 | def reset_parameters(self): 166 | if self.output_init == "heorthogonal": 167 | self.out_energy.reset_parameters(he_orthogonal_init) 168 | if self.direct_forces: 169 | self.out_forces.reset_parameters(he_orthogonal_init) 170 | elif self.output_init == "zeros": 171 | self.out_energy.reset_parameters(torch.nn.init.zeros_) 172 | if self.direct_forces: 173 | self.out_forces.reset_parameters(torch.nn.init.zeros_) 174 | else: 175 | raise UserWarning(f"Unknown output_init: {self.output_init}") 176 | 177 | def forward(self, h, m, rbf, id_j): 178 | """ 179 | Returns 180 | ------- 181 | (E, F): tuple 182 | - E: torch.Tensor, shape=(nAtoms, num_targets) 183 | - F: torch.Tensor, shape=(nEdges, num_targets) 184 | Energy and force prediction 185 | """ 186 | nAtoms = h.shape[0] 187 | 188 | # -------------------------------------- Energy Prediction -------------------------------------- # 189 | rbf_emb_E = self.dense_rbf(rbf) # (nEdges, emb_size_edge) 190 | x = m * rbf_emb_E 191 | 192 | x_E = scatter(x, id_j, dim=0, dim_size=nAtoms, reduce="sum") 193 | # (nAtoms, emb_size_edge) 194 | x_E = self.scale_sum(m, x_E) 195 | 196 | for layer in self.seq_energy: 197 | x_E = layer(x_E) # (nAtoms, emb_size_atom) 198 | 199 | x_E = self.out_energy(x_E) # (nAtoms, num_targets) 200 | 201 | # --------------------------------------- Force Prediction -------------------------------------- # 202 | if self.direct_forces: 203 | x_F = m 204 | for i, layer in enumerate(self.seq_forces): 205 | x_F = layer(x_F) # (nEdges, emb_size_edge) 206 | 207 | rbf_emb_F = self.dense_rbf_F(rbf) # (nEdges, emb_size_edge) 208 | x_F_rbf = x_F * rbf_emb_F 209 | x_F = self.scale_rbf_F(x_F, x_F_rbf) 210 | 211 | x_F = self.out_forces(x_F) # (nEdges, num_targets) 212 | else: 213 | x_F = 0 214 | # 
----------------------------------------------------------------------------------------------- # 215 | 216 | return x_E, x_F 217 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/layers/base_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | 12 | from ..initializers import he_orthogonal_init 13 | 14 | 15 | class Dense(torch.nn.Module): 16 | """ 17 | Combines dense layer with scaling for swish activation. 18 | 19 | Parameters 20 | ---------- 21 | units: int 22 | Output embedding size. 23 | activation: str 24 | Name of the activation function to use. 25 | bias: bool 26 | True if use bias. 27 | """ 28 | 29 | def __init__(self, in_features, out_features, bias=False, activation=None): 30 | super().__init__() 31 | 32 | self.linear = torch.nn.Linear(in_features, out_features, bias=bias) 33 | self.reset_parameters() 34 | 35 | if isinstance(activation, str): 36 | activation = activation.lower() 37 | if activation in ["swish", "silu"]: 38 | self._activation = ScaledSiLU() 39 | elif activation == "siqu": 40 | self._activation = SiQU() 41 | elif activation is None: 42 | self._activation = torch.nn.Identity() 43 | else: 44 | raise NotImplementedError( 45 | "Activation function not implemented for GemNet (yet)." 46 | ) 47 | 48 | def reset_parameters(self, initializer=he_orthogonal_init): 49 | initializer(self.linear.weight) 50 | if self.linear.bias is not None: 51 | self.linear.bias.data.fill_(0) 52 | 53 | def forward(self, x): 54 | x = self.linear(x) 55 | x = self._activation(x) 56 | return x 57 | 58 | 59 | class ScaledSiLU(torch.nn.Module): 60 | def __init__(self): 61 | super().__init__() 62 | self.scale_factor = 1 / 0.6 63 | self._activation = torch.nn.SiLU() 64 | 65 | def forward(self, x): 66 | return self._activation(x) * self.scale_factor 67 | 68 | 69 | class SiQU(torch.nn.Module): 70 | def __init__(self): 71 | super().__init__() 72 | self._activation = torch.nn.SiLU() 73 | 74 | def forward(self, x): 75 | return x * self._activation(x) 76 | 77 | 78 | class ResidualLayer(torch.nn.Module): 79 | """ 80 | Residual block with output scaled by 1/sqrt(2). 81 | 82 | Parameters 83 | ---------- 84 | units: int 85 | Output embedding size. 86 | nLayers: int 87 | Number of dense layers. 88 | layer_kwargs: str 89 | Keyword arguments for initializing the layers. 90 | """ 91 | 92 | def __init__( 93 | self, units: int, nLayers: int = 2, layer=Dense, **layer_kwargs 94 | ): 95 | super().__init__() 96 | self.dense_mlp = torch.nn.Sequential( 97 | *[ 98 | layer( 99 | in_features=units, 100 | out_features=units, 101 | bias=False, 102 | **layer_kwargs 103 | ) 104 | for _ in range(nLayers) 105 | ] 106 | ) 107 | self.inv_sqrt_2 = 1 / math.sqrt(2) 108 | 109 | def forward(self, input): 110 | x = self.dense_mlp(input) 111 | x = input + x 112 | x = x * self.inv_sqrt_2 113 | return x 114 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/layers/efficient.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
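# Quick sketch (sizes assumed) of the layers above: Dense fuses a bias-free
# linear map with the rescaled SiLU, and ResidualLayer divides the skip sum
# by sqrt(2) to keep activation variance roughly constant.
import torch
from mdsim.models.gemnet.layers.base_layers import Dense, ResidualLayer

x = torch.randn(32, 128)
block = torch.nn.Sequential(
    Dense(128, 128, activation="swish"),
    ResidualLayer(128, nLayers=2, activation="swish"),
)
print(block(x).shape)  # torch.Size([32, 128])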
6 | """ 7 | 8 | import torch 9 | 10 | from ..initializers import he_orthogonal_init 11 | 12 | 13 | class EfficientInteractionDownProjection(torch.nn.Module): 14 | """ 15 | Down projection in the efficient reformulation. 16 | 17 | Parameters 18 | ---------- 19 | emb_size_interm: int 20 | Intermediate embedding size (down-projection size). 21 | kernel_initializer: callable 22 | Initializer of the weight matrix. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | num_spherical: int, 28 | num_radial: int, 29 | emb_size_interm: int, 30 | ): 31 | super().__init__() 32 | 33 | self.num_spherical = num_spherical 34 | self.num_radial = num_radial 35 | self.emb_size_interm = emb_size_interm 36 | 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self): 40 | self.weight = torch.nn.Parameter( 41 | torch.empty( 42 | (self.num_spherical, self.num_radial, self.emb_size_interm) 43 | ), 44 | requires_grad=True, 45 | ) 46 | he_orthogonal_init(self.weight) 47 | 48 | def forward(self, rbf, sph, id_ca, id_ragged_idx): 49 | """ 50 | 51 | Arguments 52 | --------- 53 | rbf: torch.Tensor, shape=(1, nEdges, num_radial) 54 | sph: torch.Tensor, shape=(nEdges, Kmax, num_spherical) 55 | id_ca 56 | id_ragged_idx 57 | 58 | Returns 59 | ------- 60 | rbf_W1: torch.Tensor, shape=(nEdges, emb_size_interm, num_spherical) 61 | sph: torch.Tensor, shape=(nEdges, Kmax, num_spherical) 62 | Kmax = maximum number of neighbors of the edges 63 | """ 64 | num_edges = rbf.shape[1] 65 | 66 | # MatMul: mul + sum over num_radial 67 | rbf_W1 = torch.matmul(rbf, self.weight) 68 | # (num_spherical, nEdges , emb_size_interm) 69 | rbf_W1 = rbf_W1.permute(1, 2, 0) 70 | # (nEdges, emb_size_interm, num_spherical) 71 | 72 | # Zero padded dense matrix 73 | # maximum number of neighbors, catch empty id_ca with maximum 74 | if sph.shape[0] == 0: 75 | Kmax = 0 76 | else: 77 | Kmax = torch.max( 78 | torch.max(id_ragged_idx + 1), 79 | torch.tensor(0).to(id_ragged_idx.device), 80 | ) 81 | 82 | sph2 = sph.new_zeros(num_edges, Kmax, self.num_spherical) 83 | sph2[id_ca, id_ragged_idx] = sph 84 | 85 | sph2 = torch.transpose(sph2, 1, 2) 86 | # (nEdges, num_spherical/emb_size_interm, Kmax) 87 | 88 | return rbf_W1, sph2 89 | 90 | 91 | class EfficientInteractionBilinear(torch.nn.Module): 92 | """ 93 | Efficient reformulation of the bilinear layer and subsequent summation. 94 | 95 | Parameters 96 | ---------- 97 | units_out: int 98 | Embedding output size of the bilinear layer. 99 | kernel_initializer: callable 100 | Initializer of the weight matrix. 101 | """ 102 | 103 | def __init__( 104 | self, 105 | emb_size: int, 106 | emb_size_interm: int, 107 | units_out: int, 108 | ): 109 | super().__init__() 110 | self.emb_size = emb_size 111 | self.emb_size_interm = emb_size_interm 112 | self.units_out = units_out 113 | 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | self.weight = torch.nn.Parameter( 118 | torch.empty( 119 | (self.emb_size, self.emb_size_interm, self.units_out), 120 | requires_grad=True, 121 | ) 122 | ) 123 | he_orthogonal_init(self.weight) 124 | 125 | def forward( 126 | self, 127 | basis, 128 | m, 129 | id_reduce, 130 | id_ragged_idx, 131 | ): 132 | """ 133 | 134 | Arguments 135 | --------- 136 | basis 137 | m: quadruplets: m = m_db , triplets: m = m_ba 138 | id_reduce 139 | id_ragged_idx 140 | 141 | Returns 142 | ------- 143 | m_ca: torch.Tensor, shape=(nEdges, units_out) 144 | Edge embeddings. 
145 | """ 146 | # num_spherical is actually num_spherical**2 for quadruplets 147 | (rbf_W1, sph) = basis 148 | # (nEdges, emb_size_interm, num_spherical), (nEdges, num_spherical, Kmax) 149 | nEdges = rbf_W1.shape[0] 150 | 151 | # Create (zero-padded) dense matrix of the neighboring edge embeddings. 152 | Kmax = torch.max( 153 | torch.max(id_ragged_idx) + 1, 154 | torch.tensor(0).to(id_ragged_idx.device), 155 | ) 156 | # maximum number of neighbors, catch empty id_reduce_ji with maximum 157 | m2 = m.new_zeros(nEdges, Kmax, self.emb_size) 158 | m2[id_reduce, id_ragged_idx] = m 159 | # (num_quadruplets or num_triplets, emb_size) -> (nEdges, Kmax, emb_size) 160 | 161 | sum_k = torch.matmul(sph, m2) # (nEdges, num_spherical, emb_size) 162 | 163 | # MatMul: mul + sum over num_spherical 164 | rbf_W1_sum_k = torch.matmul(rbf_W1, sum_k) 165 | # (nEdges, emb_size_interm, emb_size) 166 | 167 | # Bilinear: Sum over emb_size_interm and emb_size 168 | m_ca = torch.matmul(rbf_W1_sum_k.permute(2, 0, 1), self.weight) 169 | # (emb_size, nEdges, units_out) 170 | m_ca = torch.sum(m_ca, dim=0) 171 | # (nEdges, units_out) 172 | 173 | return m_ca 174 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/layers/embedding_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from .base_layers import Dense 12 | 13 | 14 | class AtomEmbedding(torch.nn.Module): 15 | """ 16 | Initial atom embeddings based on the atom type 17 | 18 | Parameters 19 | ---------- 20 | emb_size: int 21 | Atom embeddings size 22 | """ 23 | 24 | def __init__(self, emb_size): 25 | super().__init__() 26 | self.emb_size = emb_size 27 | 28 | # Atom embeddings: We go up to Bi (83). 29 | self.embeddings = torch.nn.Embedding(83, emb_size) 30 | # init by uniform distribution 31 | torch.nn.init.uniform_( 32 | self.embeddings.weight, a=-np.sqrt(3), b=np.sqrt(3) 33 | ) 34 | 35 | def forward(self, Z): 36 | """ 37 | Returns 38 | ------- 39 | h: torch.Tensor, shape=(nAtoms, emb_size) 40 | Atom embeddings. 41 | """ 42 | h = self.embeddings(Z - 1) # -1 because Z.min()=1 (==Hydrogen) 43 | return h 44 | 45 | 46 | class EdgeEmbedding(torch.nn.Module): 47 | """ 48 | Edge embedding based on the concatenation of atom embeddings and subsequent dense layer. 49 | 50 | Parameters 51 | ---------- 52 | emb_size: int 53 | Embedding size after the dense layer. 54 | activation: str 55 | Activation function used in the dense layer. 56 | """ 57 | 58 | def __init__( 59 | self, 60 | atom_features, 61 | edge_features, 62 | out_features, 63 | activation=None, 64 | ): 65 | super().__init__() 66 | in_features = 2 * atom_features + edge_features 67 | self.dense = Dense( 68 | in_features, out_features, activation=activation, bias=False 69 | ) 70 | 71 | def forward( 72 | self, 73 | h, 74 | m_rbf, 75 | idx_s, 76 | idx_t, 77 | ): 78 | """ 79 | 80 | Arguments 81 | --------- 82 | h 83 | m_rbf: shape (nEdges, nFeatures) 84 | in embedding block: m_rbf = rbf ; In interaction block: m_rbf = m_st 85 | idx_s 86 | idx_t 87 | 88 | Returns 89 | ------- 90 | m_st: torch.Tensor, shape=(nEdges, emb_size) 91 | Edge embeddings. 
92 | """ 93 | h_s = h[idx_s] # shape=(nEdges, emb_size) 94 | h_t = h[idx_t] # shape=(nEdges, emb_size) 95 | 96 | m_st = torch.cat( 97 | [h_s, h_t, m_rbf], dim=-1 98 | ) # (nEdges, 2*emb_size+nFeatures) 99 | m_st = self.dense(m_st) # (nEdges, emb_size) 100 | return m_st 101 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/layers/radial_basis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | import numpy as np 11 | import torch 12 | from scipy.special import binom 13 | from torch_geometric.nn.models.schnet import GaussianSmearing 14 | 15 | 16 | class PolynomialEnvelope(torch.nn.Module): 17 | """ 18 | Polynomial envelope function that ensures a smooth cutoff. 19 | 20 | Parameters 21 | ---------- 22 | exponent: int 23 | Exponent of the envelope function. 24 | """ 25 | 26 | def __init__(self, exponent): 27 | super().__init__() 28 | assert exponent > 0 29 | self.p = exponent 30 | self.a = -(self.p + 1) * (self.p + 2) / 2 31 | self.b = self.p * (self.p + 2) 32 | self.c = -self.p * (self.p + 1) / 2 33 | 34 | def forward(self, d_scaled): 35 | env_val = ( 36 | 1 37 | + self.a * d_scaled ** self.p 38 | + self.b * d_scaled ** (self.p + 1) 39 | + self.c * d_scaled ** (self.p + 2) 40 | ) 41 | return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) 42 | 43 | 44 | class ExponentialEnvelope(torch.nn.Module): 45 | """ 46 | Exponential envelope function that ensures a smooth cutoff, 47 | as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. 48 | SpookyNet: Learning Force Fields with Electronic Degrees of Freedom 49 | and Nonlocal Effects 50 | """ 51 | 52 | def __init__(self): 53 | super().__init__() 54 | 55 | def forward(self, d_scaled): 56 | env_val = torch.exp( 57 | -(d_scaled ** 2) / ((1 - d_scaled) * (1 + d_scaled)) 58 | ) 59 | return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) 60 | 61 | 62 | class SphericalBesselBasis(torch.nn.Module): 63 | """ 64 | 1D spherical Bessel basis 65 | 66 | Parameters 67 | ---------- 68 | num_radial: int 69 | Controls maximum frequency. 70 | cutoff: float 71 | Cutoff distance in Angstrom. 72 | """ 73 | 74 | def __init__( 75 | self, 76 | num_radial: int, 77 | cutoff: float, 78 | ): 79 | super().__init__() 80 | self.norm_const = math.sqrt(2 / (cutoff ** 3)) 81 | # cutoff ** 3 to counteract dividing by d_scaled = d / cutoff 82 | 83 | # Initialize frequencies at canonical positions 84 | self.frequencies = torch.nn.Parameter( 85 | data=torch.tensor( 86 | np.pi * np.arange(1, num_radial + 1, dtype=np.float32) 87 | ), 88 | requires_grad=True, 89 | ) 90 | 91 | def forward(self, d_scaled): 92 | return ( 93 | self.norm_const 94 | / d_scaled[:, None] 95 | * torch.sin(self.frequencies * d_scaled[:, None]) 96 | ) # (num_edges, num_radial) 97 | 98 | 99 | class BernsteinBasis(torch.nn.Module): 100 | """ 101 | Bernstein polynomial basis, 102 | as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. 103 | SpookyNet: Learning Force Fields with Electronic Degrees of Freedom 104 | and Nonlocal Effects 105 | 106 | Parameters 107 | ---------- 108 | num_radial: int 109 | Controls maximum frequency. 110 | pregamma_initial: float 111 | Initial value of exponential coefficient gamma. 
112 | Default: gamma = 0.5 * a_0**-1 = 0.94486, 113 | inverse softplus -> pregamma = log e**gamma - 1 = 0.45264 114 | """ 115 | 116 | def __init__( 117 | self, 118 | num_radial: int, 119 | pregamma_initial: float = 0.45264, 120 | ): 121 | super().__init__() 122 | prefactor = binom(num_radial - 1, np.arange(num_radial)) 123 | self.register_buffer( 124 | "prefactor", 125 | torch.tensor(prefactor, dtype=torch.float), 126 | persistent=False, 127 | ) 128 | 129 | self.pregamma = torch.nn.Parameter( 130 | data=torch.tensor(pregamma_initial, dtype=torch.float), 131 | requires_grad=True, 132 | ) 133 | self.softplus = torch.nn.Softplus() 134 | 135 | exp1 = torch.arange(num_radial) 136 | self.register_buffer("exp1", exp1[None, :], persistent=False) 137 | exp2 = num_radial - 1 - exp1 138 | self.register_buffer("exp2", exp2[None, :], persistent=False) 139 | 140 | def forward(self, d_scaled): 141 | gamma = self.softplus(self.pregamma) # constrain to positive 142 | exp_d = torch.exp(-gamma * d_scaled)[:, None] 143 | return ( 144 | self.prefactor * (exp_d ** self.exp1) * ((1 - exp_d) ** self.exp2) 145 | ) 146 | 147 | 148 | class RadialBasis(torch.nn.Module): 149 | """ 150 | 151 | Parameters 152 | ---------- 153 | num_radial: int 154 | Controls maximum frequency. 155 | cutoff: float 156 | Cutoff distance in Angstrom. 157 | rbf: dict = {"name": "gaussian"} 158 | Basis function and its hyperparameters. 159 | envelope: dict = {"name": "polynomial", "exponent": 5} 160 | Envelope function and its hyperparameters. 161 | """ 162 | 163 | def __init__( 164 | self, 165 | num_radial: int, 166 | cutoff: float, 167 | rbf: dict = {"name": "gaussian"}, 168 | envelope: dict = {"name": "polynomial", "exponent": 5}, 169 | ): 170 | super().__init__() 171 | self.inv_cutoff = 1 / cutoff 172 | 173 | env_name = envelope["name"].lower() 174 | env_hparams = envelope.copy() 175 | del env_hparams["name"] 176 | 177 | if env_name == "polynomial": 178 | self.envelope = PolynomialEnvelope(**env_hparams) 179 | elif env_name == "exponential": 180 | self.envelope = ExponentialEnvelope(**env_hparams) 181 | else: 182 | raise ValueError(f"Unknown envelope function '{env_name}'.") 183 | 184 | rbf_name = rbf["name"].lower() 185 | rbf_hparams = rbf.copy() 186 | del rbf_hparams["name"] 187 | 188 | # RBFs get distances scaled to be in [0, 1] 189 | if rbf_name == "gaussian": 190 | self.rbf = GaussianSmearing( 191 | start=0, stop=1, num_gaussians=num_radial, **rbf_hparams 192 | ) 193 | elif rbf_name == "spherical_bessel": 194 | self.rbf = SphericalBesselBasis( 195 | num_radial=num_radial, cutoff=cutoff, **rbf_hparams 196 | ) 197 | elif rbf_name == "bernstein": 198 | self.rbf = BernsteinBasis(num_radial=num_radial, **rbf_hparams) 199 | else: 200 | raise ValueError(f"Unknown radial basis function '{rbf_name}'.") 201 | 202 | def forward(self, d): 203 | d_scaled = d * self.inv_cutoff 204 | 205 | env = self.envelope(d_scaled) 206 | return env[:, None] * self.rbf(d_scaled) # (nEdges, num_radial) 207 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/layers/scaling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
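# Hedged sketch of the radial basis defined above (hyperparameter values
# assumed): the envelope zeroes every distance at or beyond the cutoff.
import torch
from mdsim.models.gemnet.layers.radial_basis import RadialBasis

rb = RadialBasis(num_radial=6, cutoff=5.0,
                 rbf={"name": "spherical_bessel"},
                 envelope={"name": "polynomial", "exponent": 5})
d = torch.tensor([0.5, 2.5, 4.9, 5.1])
out = rb(d)                                # (4, 6); last row is all zeros
print(out.shape, out[-1].abs().max().item())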
6 | """ 7 | 8 | import logging 9 | 10 | import torch 11 | 12 | from ..utils import read_value_json, update_json 13 | 14 | 15 | class AutomaticFit: 16 | """ 17 | All added variables are processed in the order of creation. 18 | """ 19 | 20 | activeVar = None 21 | queue = None 22 | fitting_mode = False 23 | 24 | def __init__(self, variable, scale_file, name): 25 | self.variable = variable # variable to find value for 26 | self.scale_file = scale_file 27 | self._name = name 28 | 29 | self._fitted = False 30 | self.load_maybe() 31 | 32 | # first instance created 33 | if AutomaticFit.fitting_mode and not self._fitted: 34 | 35 | # if first layer set to active 36 | if AutomaticFit.activeVar is None: 37 | AutomaticFit.activeVar = self 38 | AutomaticFit.queue = [] # initialize 39 | # else add to queue 40 | else: 41 | self._add2queue() # adding variables to list fill fail in graph mode 42 | 43 | def reset(): 44 | AutomaticFit.activeVar = None 45 | AutomaticFit.all_processed = False 46 | 47 | def fitting_completed(): 48 | return AutomaticFit.queue is None 49 | 50 | def set2fitmode(): 51 | AutomaticFit.reset() 52 | AutomaticFit.fitting_mode = True 53 | 54 | def _add2queue(self): 55 | logging.debug(f"Add {self._name} to queue.") 56 | # check that same variable is not added twice 57 | for var in AutomaticFit.queue: 58 | if self._name == var._name: 59 | raise ValueError( 60 | f"Variable with the same name ({self._name}) was already added to queue!" 61 | ) 62 | AutomaticFit.queue += [self] 63 | 64 | def set_next_active(self): 65 | """ 66 | Set the next variable in the queue that should be fitted. 67 | """ 68 | queue = AutomaticFit.queue 69 | if len(queue) == 0: 70 | logging.debug("Processed all variables.") 71 | AutomaticFit.queue = None 72 | AutomaticFit.activeVar = None # reset to None 73 | return 74 | AutomaticFit.activeVar = queue.pop(0) 75 | 76 | def load_maybe(self): 77 | """ 78 | Load variable from file or set to initial value of the variable. 79 | """ 80 | value = read_value_json(self.scale_file, self._name) 81 | if value is None: 82 | logging.debug( 83 | f"Initialize variable {self._name}' to {self.variable.numpy():.3f}" 84 | ) 85 | else: 86 | self._fitted = True 87 | logging.debug(f"Set scale factor {self._name} : {value}") 88 | with torch.no_grad(): 89 | self.variable.copy_(torch.tensor(value)) 90 | 91 | 92 | class AutoScaleFit(AutomaticFit): 93 | """ 94 | Class to automatically fit the scaling factors depending on the observed variances. 95 | 96 | Parameters 97 | ---------- 98 | variable: torch.Tensor 99 | Variable to fit. 100 | scale_file: str 101 | Path to the json file where to store/load from the scaling factors. 102 | """ 103 | 104 | def __init__(self, variable, scale_file, name): 105 | super().__init__(variable, scale_file, name) 106 | 107 | if not self._fitted: 108 | self._init_stats() 109 | 110 | def _init_stats(self): 111 | self.variance_in = 0 112 | self.variance_out = 0 113 | self.nSamples = 0 114 | 115 | @torch.no_grad() 116 | def observe(self, x, y): 117 | """ 118 | Observe variances for input x and output y. 119 | The scaling factor alpha is calculated s.t. 
Var(alpha * y) ~ Var(x) 120 | """ 121 | if self._fitted: 122 | return 123 | 124 | # only track stats for current variable 125 | if AutomaticFit.activeVar == self: 126 | nSamples = y.shape[0] 127 | self.variance_in += ( 128 | torch.mean(torch.var(x, dim=0)).to(dtype=torch.float32) 129 | * nSamples 130 | ) 131 | self.variance_out += ( 132 | torch.mean(torch.var(y, dim=0)).to(dtype=torch.float32) 133 | * nSamples 134 | ) 135 | self.nSamples += nSamples 136 | 137 | @torch.no_grad() 138 | def fit(self): 139 | """ 140 | Fit the scaling factor based on the observed variances. 141 | """ 142 | if AutomaticFit.activeVar == self: 143 | if self.variance_in == 0: 144 | raise ValueError( 145 | f"Did not track the variable {self._name}. Add observe calls to track the variance before and after." 146 | ) 147 | 148 | # calculate variance preserving scaling factor 149 | self.variance_in = self.variance_in / self.nSamples 150 | self.variance_out = self.variance_out / self.nSamples 151 | 152 | ratio = self.variance_out / self.variance_in 153 | value = torch.sqrt(1 / ratio) 154 | logging.info( 155 | f"Variable: {self._name}, " 156 | f"Var_in: {self.variance_in.item():.3f}, " 157 | f"Var_out: {self.variance_out.item():.3f}, " 158 | f"Ratio: {ratio:.3f} => Scaling factor: {value:.3f}" 159 | ) 160 | 161 | # set variable to calculated value 162 | self.variable.copy_(self.variable * value) 163 | update_json( 164 | self.scale_file, {self._name: float(self.variable.item())} 165 | ) 166 | self.set_next_active() # set next variable in queue to active 167 | 168 | 169 | class ScalingFactor(torch.nn.Module): 170 | """ 171 | Scale the output y of the layer s.t. the (mean) variance wrt. to the reference input x_ref is preserved. 172 | 173 | Parameters 174 | ---------- 175 | scale_file: str 176 | Path to the json file where to store/load from the scaling factors. 177 | name: str 178 | Name of the scaling factor 179 | """ 180 | 181 | def __init__(self, scale_file, name, device=None): 182 | super().__init__() 183 | 184 | self.scale_factor = torch.nn.Parameter( 185 | torch.tensor(1.0, device=device), requires_grad=False 186 | ) 187 | self.autofit = AutoScaleFit(self.scale_factor, scale_file, name) 188 | 189 | def forward(self, x_ref, y): 190 | y = y * self.scale_factor 191 | self.autofit.observe(x_ref, y) 192 | 193 | return y 194 | -------------------------------------------------------------------------------- /mdsim/models/gemnet/layers/spherical_basis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import sympy as sym 9 | import torch 10 | from torch_geometric.nn.models.schnet import GaussianSmearing 11 | 12 | from .basis_utils import real_sph_harm 13 | from .radial_basis import RadialBasis 14 | 15 | 16 | class CircularBasisLayer(torch.nn.Module): 17 | """ 18 | 2D Fourier Bessel Basis 19 | 20 | Parameters 21 | ---------- 22 | num_spherical: int 23 | Controls maximum frequency. 
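# The fit-time protocol implemented by AutomaticFit/AutoScaleFit above, as
# driven by mdsim/models/gemnet/fit_scaling.py:
#   1. AutomaticFit.set2fitmode() before the model is built, so each
#      ScalingFactor enqueues itself instead of loading from scale_file;
#   2. forward passes let the active factor accumulate Var(x_ref) and Var(y);
#   3. fit() rescales the factor by sqrt(Var_in / Var_out), writes it to the
#      json scale file, and activates the next variable in the queue.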
24 | radial_basis: RadialBasis 25 | Radial basis functions 26 | cbf: dict 27 | Name and hyperparameters of the cosine basis function 28 | efficient: bool 29 | Whether to use the "efficient" summation order 30 | """ 31 | 32 | def __init__( 33 | self, 34 | num_spherical: int, 35 | radial_basis: RadialBasis, 36 | cbf: str, 37 | efficient: bool = False, 38 | ): 39 | super().__init__() 40 | 41 | self.radial_basis = radial_basis 42 | self.efficient = efficient 43 | 44 | cbf_name = cbf["name"].lower() 45 | cbf_hparams = cbf.copy() 46 | del cbf_hparams["name"] 47 | 48 | if cbf_name == "gaussian": 49 | self.cosφ_basis = GaussianSmearing( 50 | start=-1, stop=1, num_gaussians=num_spherical, **cbf_hparams 51 | ) 52 | elif cbf_name == "spherical_harmonics": 53 | Y_lm = real_sph_harm( 54 | num_spherical, use_theta=False, zero_m_only=True 55 | ) 56 | sph_funcs = [] # (num_spherical,) 57 | 58 | # convert to tensorflow functions 59 | z = sym.symbols("z") 60 | modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt} 61 | m_order = 0 # only single angle 62 | for l_degree in range(len(Y_lm)): # num_spherical 63 | if ( 64 | l_degree == 0 65 | ): # Y_00 is only a constant -> function returns value and not tensor 66 | first_sph = sym.lambdify( 67 | [z], Y_lm[l_degree][m_order], modules 68 | ) 69 | sph_funcs.append( 70 | lambda z: torch.zeros_like(z) + first_sph(z) 71 | ) 72 | else: 73 | sph_funcs.append( 74 | sym.lambdify([z], Y_lm[l_degree][m_order], modules) 75 | ) 76 | self.cosφ_basis = lambda cosφ: torch.stack( 77 | [f(cosφ) for f in sph_funcs], dim=1 78 | ) 79 | else: 80 | raise ValueError(f"Unknown cosine basis function '{cbf_name}'.") 81 | 82 | def forward(self, D_ca, cosφ_cab, id3_ca): 83 | rbf = self.radial_basis(D_ca) # (num_edges, num_radial) 84 | cbf = self.cosφ_basis(cosφ_cab) # (num_triplets, num_spherical) 85 | 86 | if not self.efficient: 87 | rbf = rbf[id3_ca] # (num_triplets, num_radial) 88 | out = (rbf[:, None, :] * cbf[:, :, None]).view( 89 | -1, rbf.shape[-1] * cbf.shape[-1] 90 | ) 91 | return (out,) 92 | # (num_triplets, num_radial * num_spherical) 93 | else: 94 | return (rbf[None, :, :], cbf) 95 | # (1, num_edges, num_radial), (num_edges, num_spherical) 96 | -------------------------------------------------------------------------------- /mdsim/models/schnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | from torch_geometric.nn import SchNet, radius_graph 10 | from torch_scatter import scatter 11 | 12 | from mdsim.common.registry import registry 13 | from mdsim.common.utils import ( 14 | conditional_grad, 15 | get_pbc_distances, 16 | radius_graph_pbc, 17 | ) 18 | 19 | 20 | @registry.register_model("schnet") 21 | class SchNetWrap(SchNet): 22 | r"""Wrapper around the continuous-filter convolutional neural network SchNet from the 23 | `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling 24 | Quantum Interactions" `_. Each layer uses interaction 25 | block of the form: 26 | 27 | .. math:: 28 | \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot 29 | h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))), 30 | 31 | Args: 32 | num_atoms (int): Unused argument 33 | bond_feat_dim (int): Unused argument 34 | num_targets (int): Number of targets to predict. 
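# Hedged sketch (sizes assumed) of the circular basis above: per-edge radial
# features are combined with a Gaussian basis over the triplet angle cosine.
import torch
from mdsim.models.gemnet.layers.radial_basis import RadialBasis
from mdsim.models.gemnet.layers.spherical_basis import CircularBasisLayer

cbl = CircularBasisLayer(num_spherical=7,
                         radial_basis=RadialBasis(num_radial=6, cutoff=5.0),
                         cbf={"name": "gaussian"}, efficient=False)
D_ca = torch.rand(10) * 5.0                 # 10 edge lengths
cos_cab = torch.rand(4) * 2 - 1             # 4 triplet angle cosines
(out,) = cbl(D_ca, cos_cab, torch.tensor([0, 1, 1, 3]))
print(out.shape)                            # (4, 6 * 7)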
35 | use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions. 36 | (default: :obj:`True`) 37 | regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating 38 | energy with respect to positions. 39 | (default: :obj:`True`) 40 | otf_graph (bool, optional): If set to :obj:`True`, compute graph edges on the fly. 41 | (default: :obj:`False`) 42 | hidden_channels (int, optional): Number of hidden channels. 43 | (default: :obj:`128`) 44 | num_filters (int, optional): Number of filters to use. 45 | (default: :obj:`128`) 46 | num_interactions (int, optional): Number of interaction blocks 47 | (default: :obj:`6`) 48 | num_gaussians (int, optional): The number of gaussians :math:`\mu`. 49 | (default: :obj:`50`) 50 | cutoff (float, optional): Cutoff distance for interatomic interactions. 51 | (default: :obj:`10.0`) 52 | readout (string, optional): Whether to apply :obj:`"add"` or 53 | :obj:`"mean"` global aggregation. (default: :obj:`"add"`) 54 | """ 55 | 56 | def __init__( 57 | self, 58 | num_atoms, # not used 59 | bond_feat_dim, # not used 60 | num_targets, 61 | use_pbc=True, 62 | regress_forces=True, 63 | otf_graph=False, 64 | hidden_channels=128, 65 | num_filters=128, 66 | num_interactions=6, 67 | num_gaussians=50, 68 | cutoff=10.0, 69 | readout="add", 70 | ): 71 | self.num_targets = num_targets 72 | self.regress_forces = regress_forces 73 | self.use_pbc = use_pbc 74 | self.cutoff = cutoff 75 | self.otf_graph = otf_graph 76 | 77 | super(SchNetWrap, self).__init__( 78 | hidden_channels=hidden_channels, 79 | num_filters=num_filters, 80 | num_interactions=num_interactions, 81 | num_gaussians=num_gaussians, 82 | cutoff=cutoff, 83 | readout=readout, 84 | ) 85 | 86 | @conditional_grad(torch.enable_grad()) 87 | def _forward(self, data): 88 | z = data.atomic_numbers.long() 89 | pos = data.pos 90 | batch = data.batch 91 | 92 | if self.otf_graph: 93 | edge_index, cell_offsets, _, neighbors = radius_graph_pbc( 94 | data, self.cutoff, 500 95 | ) 96 | data.edge_index = edge_index 97 | data.cell_offsets = cell_offsets 98 | data.neighbors = neighbors 99 | 100 | if self.use_pbc: 101 | assert z.dim() == 1 and z.dtype == torch.long 102 | 103 | out = get_pbc_distances( 104 | data.pos, 105 | data.edge_index, 106 | data.cell, 107 | data.cell_offsets, 108 | data.natoms, 109 | ) 110 | 111 | edge_index = out["edge_index"] 112 | edge_weight = out["distances"] 113 | edge_attr = self.distance_expansion(edge_weight) 114 | 115 | h = self.embedding(z) 116 | for interaction in self.interactions: 117 | h = h + interaction(h, edge_index, edge_weight, edge_attr) 118 | 119 | h = self.lin1(h) 120 | h = self.act(h) 121 | h = self.lin2(h) 122 | 123 | batch = torch.zeros_like(z) if batch is None else batch 124 | energy = scatter(h, batch, dim=0, reduce=self.readout) 125 | else: 126 | energy = super(SchNetWrap, self).forward(z, pos, batch) 127 | return energy 128 | 129 | def forward(self, data): 130 | if self.regress_forces: 131 | data.pos.requires_grad_(True) 132 | energy = self._forward(data) 133 | 134 | if self.regress_forces: 135 | forces = -1 * ( 136 | torch.autograd.grad( 137 | energy, 138 | data.pos, 139 | grad_outputs=torch.ones_like(energy), 140 | create_graph=True, 141 | )[0] 142 | ) 143 | return energy, forces 144 | else: 145 | return energy 146 | 147 | @property 148 | def num_params(self): 149 | return sum(p.numel() for p in self.parameters()) 150 | -------------------------------------------------------------------------------- /mdsim/models/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /mdsim/models/utils/activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | class Act(torch.nn.Module): 13 | def __init__(self, act, slope=0.05): 14 | super(Act, self).__init__() 15 | self.act = act 16 | self.slope = slope 17 | self.shift = torch.log(torch.tensor(2.0)).item() 18 | 19 | def forward(self, input): 20 | if self.act == "relu": 21 | return F.relu(input) 22 | elif self.act == "leaky_relu": 23 | return F.leaky_relu(input) 24 | elif self.act == "sp": 25 | return F.softplus(input, beta=1) 26 | elif self.act == "leaky_sp": 27 | return F.softplus(input, beta=1) - self.slope * F.relu(-input) 28 | elif self.act == "elu": 29 | return F.elu(input, alpha=1) 30 | elif self.act == "leaky_elu": 31 | return F.elu(input, alpha=1) - self.slope * F.relu(-input) 32 | elif self.act == "ssp": 33 | return F.softplus(input, beta=1) - self.shift 34 | elif self.act == "leaky_ssp": 35 | return ( 36 | F.softplus(input, beta=1) 37 | - self.slope * F.relu(-input) 38 | - self.shift 39 | ) 40 | elif self.act == "tanh": 41 | return torch.tanh(input) 42 | elif self.act == "leaky_tanh": 43 | return torch.tanh(input) + self.slope * input 44 | elif self.act == "swish": 45 | return torch.sigmoid(input) * input 46 | else: 47 | raise RuntimeError(f"Undefined activation called {self.act}") 48 | -------------------------------------------------------------------------------- /mdsim/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | -------------------------------------------------------------------------------- /mdsim/modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from mdsim.common import distutils 5 | 6 | 7 | class L2MAELoss(nn.Module): 8 | def __init__(self, reduction="mean"): 9 | super().__init__() 10 | self.reduction = reduction 11 | assert reduction in ["mean", "sum"] 12 | 13 | def forward(self, input: torch.Tensor, target: torch.Tensor): 14 | dists = torch.norm(input - target, p=2, dim=-1) 15 | if self.reduction == "mean": 16 | return torch.mean(dists) 17 | elif self.reduction == "sum": 18 | return torch.sum(dists) 19 | 20 | 21 | class DDPLoss(nn.Module): 22 | def __init__(self, loss_fn, reduction="mean"): 23 | super().__init__() 24 | self.loss_fn = loss_fn 25 | self.loss_fn.reduction = "sum" 26 | self.reduction = reduction 27 | assert reduction in ["mean", "sum"] 28 | 29 | def forward(self, input: torch.Tensor, target: torch.Tensor): 30 | loss = self.loss_fn(input, target) 31 | if self.reduction == "mean": 32 | num_samples = input.shape[0] 33 | num_samples = distutils.all_reduce( 34 | num_samples, device=input.device 35 | ) 36 | # Multiply by world size since gradients are averaged 37 | # across DDP replicas 38 | return loss * distutils.get_world_size() / num_samples 39 | else: 40 | return loss 41 | -------------------------------------------------------------------------------- /mdsim/modules/normalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Normalizer(object): 5 | """Normalize a Tensor and restore it later.""" 6 | 7 | def __init__(self, tensor=None, mean=None, std=None, device=None): 8 | """tensor is taken as a sample to calculate the mean and std""" 9 | if tensor is None and mean is None: 10 | return 11 | 12 | if device is None: 13 | device = "cpu" 14 | 15 | if tensor is not None: 16 | self.mean = torch.mean(tensor, dim=0).to(device) 17 | self.std = torch.std(tensor, dim=0).to(device) 18 | return 19 | 20 | if mean is not None and std is not None: 21 | self.mean = torch.tensor(mean).to(device) 22 | self.std = torch.tensor(std).to(device) 23 | 24 | def to(self, device): 25 | self.mean = self.mean.to(device) 26 | self.std = self.std.to(device) 27 | 28 | def norm(self, tensor): 29 | return (tensor - self.mean) / self.std 30 | 31 | def denorm(self, normed_tensor): 32 | return normed_tensor * self.std + self.mean 33 | 34 | def state_dict(self): 35 | return {"mean": self.mean, "std": self.std} 36 | 37 | def load_state_dict(self, state_dict): 38 | self.mean = state_dict["mean"].to(self.mean.device) 39 | self.std = state_dict["std"].to(self.mean.device) 40 | -------------------------------------------------------------------------------- /mdsim/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["TrainTask", "PredictTask", "ValidateTask"] 2 | 3 | from .task import PredictTask, TrainTask, ValidateTask 4 | -------------------------------------------------------------------------------- /mdsim/tasks/task.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import logging 3 | import os 4 | 5 | from mdsim.common.registry import registry 6 | from mdsim.trainers.trainer import Trainer 7 | 8 | 9 | class BaseTask: 10 | def __init__(self, config): 11 | self.config = config 12 | 13 | def setup(self, trainer): 
14 |         self.trainer = trainer
15 |         if self.config["checkpoint"] is not None:
16 |             self.trainer.load_checkpoint(self.config["checkpoint"])
17 |         else:
18 |             ckpt_file = Path(self.trainer.config["cmd"]["checkpoint_dir"]) / 'checkpoint.pt'
19 |             if ckpt_file.exists():
20 |                 self.trainer.load_checkpoint(ckpt_file)
21 | 
22 |         # save checkpoint path to runner state for slurm resubmissions
23 |         self.chkpt_path = os.path.join(
24 |             self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
25 |         )
26 | 
27 |     def run(self):
28 |         raise NotImplementedError
29 | 
30 | 
31 | @registry.register_task("train")
32 | class TrainTask(BaseTask):
33 |     def _process_error(self, e: RuntimeError):
34 |         e_str = str(e)
35 |         if (
36 |             "find_unused_parameters" in e_str
37 |             and "torch.nn.parallel.DistributedDataParallel" in e_str
38 |         ):
39 |             for name, parameter in self.trainer.model.named_parameters():
40 |                 if parameter.requires_grad and parameter.grad is None:
41 |                     logging.warning(
42 |                         f"Parameter {name} has no gradient. Consider removing it from the model."
43 |                     )
44 | 
45 |     def run(self):
46 |         try:
47 |             self.trainer.train(
48 |                 disable_eval_tqdm=self.config.get(
49 |                     "hide_eval_progressbar", False
50 |                 )
51 |             )
52 |         except RuntimeError as e:
53 |             self._process_error(e)
54 |             raise e
55 | 
56 | 
57 | @registry.register_task("predict")
58 | class PredictTask(BaseTask):
59 |     def run(self):
60 |         assert (
61 |             self.trainer.test_loader is not None
62 |         ), "Test dataset is required for making predictions"
63 |         assert self.config["checkpoint"]
64 |         results_file = "predictions"
65 |         self.trainer.predict(
66 |             self.trainer.test_loader,
67 |             results_file=results_file,
68 |             disable_tqdm=self.config.get("hide_eval_progressbar", False),
69 |         )
70 | 
71 | 
72 | @registry.register_task("validate")
73 | class ValidateTask(BaseTask):
74 |     def run(self):
75 |         # Note: on multiple GPUs the results may be slightly imprecise due to padding with extra samples (the difference should be minor)
76 |         assert (
77 |             self.trainer.val_loader is not None
78 |         ), "Val dataset is required for validation"
79 |         assert self.config["checkpoint"]
80 |         self.trainer.validate(
81 |             split="val",
82 |             disable_tqdm=self.config.get("hide_eval_progressbar", False),
83 |         )
-------------------------------------------------------------------------------- /mdsim/trainers/__init__.py: --------------------------------------------------------------------------------
1 | __all__ = [
2 |     "Trainer"
3 | ]
4 | from .trainer import Trainer
-------------------------------------------------------------------------------- /preprocessing/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyonofx/MDsim/61ca2cfba373fd374343592556789fa30fbe3b56/preprocessing/__init__.py
-------------------------------------------------------------------------------- /preprocessing/alanine_dipeptide.py: --------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from pathlib import Path
4 | import pickle
5 | 
6 | import lmdb
7 | import numpy as np
8 | from tqdm import tqdm
9 | from urllib import request as request
10 | 
11 | from arrays_to_graphs import AtomsToGraphs
12 | from sklearn.model_selection import train_test_split
13 | 
14 | from mdsim.common.utils import EV_TO_KCAL_MOL
15 | 
16 | def download(data_path):
17 |     url = 'https://zenodo.org/record/7196767/files/alanine_dipeptide.npy?download=1'
18 |     request.urlretrieve(url, os.path.join(data_path,
'alanine_dipeptide.npy')) 19 | 20 | def write_to_lmdb(data_path, db_path, time_split): 21 | a2g = AtomsToGraphs( 22 | max_neigh=1000, 23 | radius=6, 24 | r_energy=False, 25 | r_forces=True, 26 | r_distances=False, 27 | r_edges=False, 28 | device='cpu' 29 | ) 30 | 31 | data_file = (Path(data_path) / 'alanine_dipeptide.npy') 32 | Path(data_path).mkdir(parents=True, exist_ok=True) 33 | if not data_file.is_file(): 34 | download(data_path) 35 | 36 | n_points = 50000 37 | all_data = np.load(data_file, allow_pickle=True).item() 38 | all_data['force'] = all_data['force'] / EV_TO_KCAL_MOL 39 | force = all_data['force'] 40 | 41 | if time_split: 42 | test = np.arange(n_points-10000, n_points) 43 | else: 44 | train_val_pool, test = train_test_split(np.arange(n_points), train_size=n_points-10000, 45 | test_size=10000, random_state=123) 46 | for dataset_size, train_size, val_size in zip(['40k'], [38000], [2000]): 47 | print(f'processing dataset with size {dataset_size}.') 48 | if time_split: 49 | train = np.arange(train_size) 50 | val = np.arange(train_size, train_size+val_size) 51 | dataset_size = dataset_size + '_time_split' 52 | else: 53 | size = train_size + val_size 54 | train_val = train_val_pool[:size] 55 | train, val = train_test_split(train_val, train_size=train_size, test_size=val_size, random_state=123) 56 | ranges = [train, val, test] 57 | 58 | norm_stats = { 59 | 'e_mean': 0, 60 | 'e_std': 1, 61 | 'f_mean': force[train].mean(), 62 | 'f_std': force[train].std(), 63 | } 64 | save_path = Path(db_path) / dataset_size 65 | save_path.mkdir(parents=True, exist_ok=True) 66 | np.save(save_path / 'metadata', norm_stats) 67 | 68 | for spidx, split in enumerate(['train', 'val', 'test']): 69 | print(f'processing split {split}.') 70 | # for OCP 71 | save_path = Path(db_path) / dataset_size / split 72 | save_path.mkdir(parents=True, exist_ok=True) 73 | db = lmdb.open( 74 | str(save_path / 'data.lmdb'), 75 | map_size=1099511627776 * 2, 76 | subdir=False, 77 | meminit=False, 78 | map_async=True, 79 | ) 80 | 81 | for i, idx in enumerate(tqdm((ranges[spidx]))): 82 | data = {k: v[idx] if (v.shape[0] == 50001) else v for k, v in all_data.items()} 83 | natoms = np.array([data['pos'].shape[0]] * 1, dtype=np.int64) 84 | data = a2g.convert(natoms, data['pos'], data['atomic_number'], 85 | data['lengths'][None, :], data['angles'][None, :], forces=data['force']) 86 | txn = db.begin(write=True) 87 | txn.put(f"{i}".encode("ascii"), pickle.dumps(data, protocol=-1)) 88 | txn.commit() 89 | 90 | # Save count of objects in lmdb. 91 | txn = db.begin(write=True) 92 | txn.put("length".encode("ascii"), pickle.dumps(i, protocol=-1)) 93 | txn.commit() 94 | 95 | db.sync() 96 | db.close() 97 | 98 | # for nequip. turn energy loss == 0. 
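            # (The alanine dipeptide dataset provides forces but no reference
            # energies, so zero energies are exported below; a NequIP training
            # config is then expected to set the energy loss weight to 0 so that
            # only forces are fit.)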
99 |             data = all_data
100 |             data['pbc'] = np.array([True]*3)
101 |             data = {k: v[ranges[spidx]] if v.shape[0] == 50001 else v for k, v in data.items()}
102 |             data['energy'] = np.zeros(len(ranges[spidx]))[:, None]
103 |             # forces were already converted to eV right after loading; no further scaling is needed
104 |             data['lattices'] = data['lengths'][:, None] * np.eye(3)
105 |             np.savez(save_path / 'nequip_npz', **data)
106 | 
107 | if __name__ == '__main__':
108 |     parser = argparse.ArgumentParser()
109 |     parser.add_argument("--data_path", type=str, default='./DATAPATH/ala')
110 |     parser.add_argument("--db_path", type=str, default='./DATAPATH/ala')
111 |     parser.add_argument("--time_split", action="store_true", help='split data by time order')
112 |     args = parser.parse_args()
113 |     write_to_lmdb(args.data_path, args.db_path, args.time_split)
114 | 
-------------------------------------------------------------------------------- /preprocessing/arrays_to_graphs.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 | import mdsim.common.utils as utils
4 | 
5 | 
6 | class AtomsToGraphs:
7 |     """A class to help convert periodic atomic structures to graphs.
8 | 
9 |     The AtomsToGraphs class takes in periodic atomic structures (as arrays of positions, atomic numbers,
10 |     and lattice parameters) and converts them into graph representations for use in PyTorch. The primary
11 |     purpose of this class is to determine the nearest neighbors within some radius around each individual
12 |     atom, taking into account PBC, and set the pair index and distance between atom pairs appropriately.
13 |     Lastly, atomic properties and the graph information are put into a PyTorch geometric data object.
14 | 
15 |     Args:
16 |         max_neigh (int): Maximum number of neighbors to consider.
17 |         radius (int or float): Cutoff radius in Angstroms to search for neighbors.
18 |         r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
19 |         r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
20 |         r_distances (bool): Return the distances with other properties.
21 |             Default is False, so the distances will not be returned.
22 | 
23 |     Attributes:
24 |         max_neigh (int): Maximum number of neighbors to consider.
25 |         radius (int or float): Cutoff radius in Angstroms to search for neighbors.
26 |         r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
27 |         r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
28 |         r_distances (bool): Return the distances with other properties.
29 |             Default is False, so the distances will not be returned.
30 | 
31 |     """
32 | 
33 |     def __init__(
34 |         self,
35 |         max_neigh=200,
36 |         radius=6,
37 |         r_energy=False,
38 |         r_forces=False,
39 |         r_distances=False,
40 |         r_edges=True,
41 |         device='cpu'
42 |     ):
43 |         self.max_neigh = max_neigh
44 |         self.radius = radius
45 |         self.r_energy = r_energy
46 |         self.r_forces = r_forces
47 |         self.r_distances = r_distances
48 |         self.r_edges = r_edges
49 |         self.device = device
50 | 
51 |     def convert(
52 |         self,
53 |         natoms,
54 |         positions,
55 |         atomic_numbers,
56 |         lengths=None,
57 |         angles=None,
58 |         energy=None,
59 |         forces=None,
60 |         cell=None,
61 |     ):
62 |         """Convert a batch of atomic structures to a batch of graphs.
63 | 
64 |         Args:
65 |             natoms: (B), sum(natoms) == N
66 |             positions: (B*N, 3)
67 |             atomic_numbers: (B*N)
68 |             lengths: (B, 3) lattice lengths [lx, ly, lz]
69 |             angles: (B, 3) lattice angles [ax, ay, az]
70 |             energy: (B)
71 |             forces: (B*N, 3)
72 | 
73 |         Returns:
74 |             data (torch_geometric.data.Data): A torch geometric data object with edge_index, positions, atomic_numbers,
75 |             and optionally, energy, forces, and distances.
76 |             Optional properties can be included by setting r_property=True when constructing the class.
77 |         """
78 | 
79 |         natoms = torch.from_numpy(natoms).to(self.device).long()
80 |         positions = torch.from_numpy(positions).to(self.device).float()
81 |         atomic_numbers = torch.from_numpy(atomic_numbers).to(self.device).long()
82 |         if cell is None:
83 |             lengths = torch.from_numpy(lengths).to(self.device).float()
84 |             angles = torch.from_numpy(angles).to(self.device).float()
85 |             cells = utils.lattice_params_to_matrix_torch(lengths, angles).float()
86 |         else:
87 |             cells = torch.from_numpy(cell).to(self.device).float()
88 | 
89 |         data = Data(
90 |             cell=cells,
91 |             pos=positions,
92 |             atomic_numbers=atomic_numbers,
93 |             natoms=natoms,
94 |         )
95 | 
96 |         # optionally include other properties
97 |         if self.r_edges:
98 |             edge_index, cell_offsets, edge_distances, _ = utils.radius_graph_pbc(
99 |                 data, self.radius, self.max_neigh)
100 |             data.edge_index = edge_index
101 |             data.cell_offsets = cell_offsets
102 |         if energy is not None:
103 |             energy = torch.from_numpy(energy).to(self.device).float()
104 |             data.y = energy
105 |         if forces is not None:
106 |             forces = torch.from_numpy(forces).to(self.device).float()
107 |             data.force = forces
108 |         if self.r_distances and self.r_edges:
109 |             data.distances = edge_distances
110 | 
111 |         fixed_idx = torch.zeros(natoms).float()
112 |         data.fixed = fixed_idx
113 | 
114 |         return data.cpu()
-------------------------------------------------------------------------------- /preprocessing/lips.py: --------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from pathlib import Path
4 | import pickle
5 | 
6 | import lmdb
7 | import numpy as np
8 | from tqdm import tqdm
9 | from urllib import request as request
10 | 
11 | from arrays_to_graphs import AtomsToGraphs
12 | from sklearn.model_selection import train_test_split
13 | from ase.io import read
14 | 
15 | def download(data_path):
16 |     url = 'https://archive.materialscloud.org/record/file?filename=lips.xyz&record_id=1302'
17 |     request.urlretrieve(url, os.path.join(data_path, 'lips.xyz'))
18 | 
19 | def write_to_lmdb(data_path, db_path):
20 |     a2g = AtomsToGraphs(
21 |         max_neigh=1000,
22 |         radius=4.,
23 |         r_energy=True,
24 |         r_forces=True,
25 |         r_distances=False,
26 |         r_edges=False,
27 |         device='cpu'
28 |     )
29 | 
30 |     data_file = (Path(data_path) / 'lips.xyz')
31 |     Path(data_path).mkdir(parents=True, exist_ok=True)
32 |     if not data_file.is_file():
33 |         download(data_path)
34 | 
35 |     atoms = read(data_file, index=':', format='extxyz')
36 |     n_points = len(atoms)
37 |     positions, cell, atomic_numbers, energy, forces = [], [], [], [], []
38 |     for i in range(n_points):
39 |         positions.append(atoms[i].get_positions())
40 |         cell.append(atoms[i].get_cell())
41 |         atomic_numbers.append(atoms[i].get_atomic_numbers())
42 |         energy.append(atoms[i].get_potential_energy())
43 |         forces.append(atoms[i].get_forces())
44 |     positions = np.array(positions)
45 |     cell = np.array(cell)[0]
46 |     atomic_numbers = np.array(atomic_numbers)[0]
47 |     energy = np.array(energy)[:, None]
48 |     forces =
np.array(forces) 49 | 50 | for dataset_size, train_size, val_size in zip(['20k'], [19000], [1000]): 51 | print(f'processing dataset with size {dataset_size}.') 52 | size = train_size + val_size 53 | train, test = train_test_split(np.arange(n_points), train_size=size, test_size=n_points-size, random_state=123) 54 | train, val = train_test_split(train, train_size=train_size, test_size=val_size, random_state=123) 55 | ranges = [train, val, test] 56 | 57 | norm_stats = { 58 | 'e_mean': energy[train].mean(), 59 | 'e_std': energy[train].std(), 60 | 'f_mean': forces[train].mean(), 61 | 'f_std': forces[train].std(), 62 | } 63 | save_path = Path(db_path) / dataset_size 64 | save_path.mkdir(parents=True, exist_ok=True) 65 | np.save(save_path / 'metadata', norm_stats) 66 | 67 | for spidx, split in enumerate(['train', 'val', 'test']): 68 | print(f'processing split {split}.') 69 | # for OCP 70 | save_path = Path(db_path) / dataset_size / split 71 | save_path.mkdir(parents=True, exist_ok=True) 72 | db = lmdb.open( 73 | str(save_path / 'data.lmdb'), 74 | map_size=1099511627776 * 2, 75 | subdir=False, 76 | meminit=False, 77 | map_async=True, 78 | ) 79 | 80 | for i, idx in enumerate(tqdm(ranges[spidx])): 81 | 82 | natoms = np.array([atomic_numbers.shape[0]] * 1, dtype=np.int64) 83 | data = a2g.convert(natoms, positions[idx], atomic_numbers, 84 | energy=energy[idx], forces=forces[idx], cell=cell[None, :]) 85 | data.sid = 0 86 | data.fid = idx 87 | txn = db.begin(write=True) 88 | txn.put(f"{i}".encode("ascii"), pickle.dumps(data, protocol=-1)) 89 | txn.commit() 90 | 91 | # Save count of objects in lmdb. 92 | txn = db.begin(write=True) 93 | txn.put("length".encode("ascii"), pickle.dumps(i, protocol=-1)) 94 | txn.commit() 95 | 96 | db.sync() 97 | db.close() 98 | 99 | # for nequip. 
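            # Export the same split as one npz with per-frame pos/energy/forces
            # plus the shared cell, atomic_numbers, and pbc flags for the NequIP
            # data pipeline.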
100 | data = {} 101 | data['pbc'] = np.array([True]*3) 102 | data['pos'] = positions[ranges[spidx]] 103 | data['energy'] = energy[ranges[spidx]] 104 | data['forces'] = forces[ranges[spidx]] 105 | data['cell'] = cell 106 | data['atomic_numbers'] = atomic_numbers 107 | np.savez(save_path / 'nequip_npz', **data) 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument("--data_path", type=str, default='./DATAPATH/lips') 112 | parser.add_argument("--db_path", type=str, default='./DATAPATH/lips') 113 | args = parser.parse_args() 114 | write_to_lmdb(args.data_path, args.db_path) -------------------------------------------------------------------------------- /preprocessing/md17.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | import pickle 5 | 6 | import lmdb 7 | import numpy as np 8 | from tqdm import tqdm 9 | from urllib import request as request 10 | from sklearn.model_selection import train_test_split 11 | 12 | from arrays_to_graphs import AtomsToGraphs 13 | from mdsim.common.utils import EV_TO_KCAL_MOL 14 | 15 | MD17_mols = ['aspirin', 'benzene', 'ethanol', 'malonaldehyde', 16 | 'naphthalene', 'salicylic_acid', 'toluene', 'uracil'] 17 | 18 | datasets_dict = dict( 19 | aspirin="aspirin_dft.npz", 20 | azobenzene="azobenzene_dft.npz", 21 | benzene="benzene2017_dft.npz", 22 | ethanol="ethanol_dft.npz", 23 | malonaldehyde="malonaldehyde_dft.npz", 24 | naphthalene="naphthalene_dft.npz", 25 | paracetamol="paracetamol_dft.npz", 26 | salicylic_acid="salicylic_dft.npz", 27 | toluene="toluene_dft.npz", 28 | uracil="uracil_dft.npz") 29 | 30 | def download(molecule, data_path): 31 | url = ( 32 | "http://www.quantum-machine.org/gdml/data/npz/" 33 | + datasets_dict[molecule] 34 | ) 35 | request.urlretrieve(url, os.path.join(data_path, datasets_dict[molecule])) 36 | print(f'{molecule} downloaded.') 37 | 38 | def write_to_lmdb(molecule, data_path, db_path): 39 | print(f'process MD17 molecule: {molecule}.') 40 | a2g = AtomsToGraphs( 41 | max_neigh=1000, 42 | radius=6, 43 | r_energy=True, 44 | r_forces=True, 45 | r_distances=False, 46 | r_edges=False, 47 | device='cpu' 48 | ) 49 | 50 | npzname = datasets_dict[molecule] 51 | data_file = Path(data_path) / npzname 52 | Path(data_path).mkdir(parents=True, exist_ok=True) 53 | if not data_file.is_file(): 54 | download(molecule, data_path) 55 | all_data = np.load(data_file) 56 | 57 | n_points = all_data.f.R.shape[0] 58 | atomic_numbers = all_data.f.z 59 | atomic_numbers = atomic_numbers.astype(np.int64) 60 | positions = all_data.f.R 61 | force = all_data.f.F / EV_TO_KCAL_MOL 62 | energy = all_data.f.E / EV_TO_KCAL_MOL 63 | lengths = np.ones(3)[None, :] * 30. 64 | angles = np.ones(3)[None, :] * 90. 
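    # MD17 molecules are gas-phase, so there is no physical cell; the 30 Å cubic
    # box with 90° angles is a dummy lattice, presumably chosen large enough that
    # periodic images never fall within the 6 Å graph radius used above.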
65 | 
66 |     train_val_pool, test = train_test_split(np.arange(n_points), train_size=n_points-10000,
67 |                                             test_size=10000, random_state=123)
68 | 
69 |     for dataset_size, train_size, val_size in zip(['10k'], [9500], [500]):
70 |         print(f'processing dataset with size {dataset_size}.')
71 |         size = train_size + val_size
72 |         train_val = train_val_pool[:size]
73 |         train, val = train_test_split(train_val, train_size=train_size, test_size=val_size, random_state=123)
74 |         ranges = [train, val, test]
75 | 
76 |         norm_stats = {
77 |             'e_mean': energy[train].mean(),
78 |             'e_std': energy[train].std(),
79 |             'f_mean': force[train].mean(),
80 |             'f_std': force[train].std(),
81 |         }
82 |         save_path = Path(db_path) / molecule / dataset_size
83 |         save_path.mkdir(parents=True, exist_ok=True)
84 |         np.save(save_path / 'metadata', norm_stats)
85 | 
86 |         for spidx, split in enumerate(['train', 'val', 'test']):
87 |             print(f'processing split {split}.')
88 |             save_path = Path(db_path) / molecule / dataset_size / split
89 |             save_path.mkdir(parents=True, exist_ok=True)
90 |             db = lmdb.open(
91 |                 str(save_path / 'data.lmdb'),
92 |                 map_size=1099511627776 * 2,
93 |                 subdir=False,
94 |                 meminit=False,
95 |                 map_async=True,
96 |             )
97 |             for i, idx in enumerate(tqdm(ranges[spidx])):
98 |                 natoms = np.array([positions.shape[1]] * 1, dtype=np.int64)
99 |                 data = a2g.convert(natoms, positions[idx], atomic_numbers,
100 |                                    lengths, angles, energy[idx], force[idx])
101 |                 data.sid = 0
102 |                 data.fid = idx
103 |                 txn = db.begin(write=True)
104 |                 txn.put(f"{i}".encode("ascii"), pickle.dumps(data, protocol=-1))
105 |                 txn.commit()
106 | 
107 |             # Save count of objects in lmdb.
108 |             txn = db.begin(write=True)
109 |             txn.put("length".encode("ascii"), pickle.dumps(i, protocol=-1))
110 |             txn.commit()
111 | 
112 |             db.sync()
113 |             db.close()
114 | 
115 |             # nequip
116 |             data = {
117 |                 'z': atomic_numbers,
118 |                 'E': energy[ranges[spidx]],
119 |                 'F': force[ranges[spidx]],
120 |                 'R': all_data.f.R[ranges[spidx]]
121 |             }
122 |             np.savez(save_path / 'nequip_npz', **data)
123 | 
124 | if __name__ == '__main__':
125 |     parser = argparse.ArgumentParser()
126 |     parser.add_argument("--molecule", type=str, default='ethanol')
127 |     parser.add_argument("--data_path", type=str, default='./DATAPATH/md17')
128 |     parser.add_argument("--db_path", type=str, default='./DATAPATH/md17')
129 |     args = parser.parse_args()
130 |     assert args.molecule in MD17_mols, 'molecule must be one of the 8 molecules in MD17.'
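    # Example invocation (the paths shown are just the argparse defaults):
    #   python preprocessing/md17.py --molecule aspirin --data_path ./DATAPATH/md17 --db_path ./DATAPATH/md17
    #
    # Minimal sketch of reading one record back from the resulting LMDB, mirroring
    # the write path above (the .lmdb path is hypothetical):
    #   env = lmdb.open('./DATAPATH/md17/aspirin/10k/train/data.lmdb',
    #                   subdir=False, readonly=True, lock=False)
    #   with env.begin() as txn:
    #       length = pickle.loads(txn.get('length'.encode('ascii')))
    #       sample = pickle.loads(txn.get('0'.encode('ascii')))  # a torch_geometric Data object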
131 | write_to_lmdb(args.molecule, args.data_path, args.db_path) 132 | -------------------------------------------------------------------------------- /preprocessing/water.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | import pickle 5 | 6 | import lmdb 7 | import numpy as np 8 | from tqdm import tqdm 9 | from urllib import request as request 10 | 11 | from arrays_to_graphs import AtomsToGraphs 12 | from sklearn.model_selection import train_test_split 13 | 14 | from mdsim.common.utils import EV_TO_KCAL_MOL 15 | 16 | def download(data_path): 17 | url = 'https://zenodo.org/record/7196767/files/water.npy?download=1' 18 | request.urlretrieve(url, os.path.join(data_path, 'water.npy')) 19 | 20 | def write_to_lmdb(data_path, db_path, time_split): 21 | a2g = AtomsToGraphs( 22 | max_neigh=1000, 23 | radius=4., 24 | r_energy=True, 25 | r_forces=True, 26 | r_distances=False, 27 | r_edges=False, 28 | device='cpu' 29 | ) 30 | 31 | data_file = (Path(data_path) / 'water.npy') 32 | Path(data_path).mkdir(parents=True, exist_ok=True) 33 | if not data_file.is_file(): 34 | download(data_path) 35 | 36 | n_points = 100001 37 | all_data = np.load(data_file, allow_pickle=True).item() 38 | all_data['energy'] = all_data['energy'] / EV_TO_KCAL_MOL 39 | all_data['forces'] = all_data['forces'] / EV_TO_KCAL_MOL 40 | energy = all_data['energy'] 41 | force = all_data['forces'] 42 | 43 | if time_split: 44 | test = np.arange(90000, n_points) 45 | else: 46 | train_val_pool, test = train_test_split(np.arange(n_points), train_size=n_points-10000, 47 | test_size=10000, random_state=123) 48 | 49 | for dataset_size, train_size, val_size in zip(['1k', '10k', '90k'], [950, 9500, 85500], [50, 500, 4500]): 50 | print(f'processing dataset with size {dataset_size}.') 51 | # time split 52 | if time_split: 53 | train = np.arange(train_size) 54 | val = np.arange(train_size, train_size+val_size) 55 | dataset_size = dataset_size + '_time_split' 56 | else: 57 | size = train_size + val_size 58 | train_val = train_val_pool[:size] 59 | train, val = train_test_split(train_val, train_size=train_size, test_size=val_size, random_state=123) 60 | ranges = [train, val, test] 61 | 62 | norm_stats = { 63 | 'e_mean': energy[train].mean(), 64 | 'e_std': energy[train].std(), 65 | 'f_mean': force[train].mean(), 66 | 'f_std': force[train].std(), 67 | } 68 | save_path = Path(db_path) / dataset_size 69 | save_path.mkdir(parents=True, exist_ok=True) 70 | np.save(save_path / 'metadata', norm_stats) 71 | 72 | for spidx, split in enumerate(['train', 'val', 'test']): 73 | print(f'processing split {split}.') 74 | # for OCP 75 | save_path = Path(db_path) / dataset_size / split 76 | save_path.mkdir(parents=True, exist_ok=True) 77 | db = lmdb.open( 78 | str(save_path / 'data.lmdb'), 79 | map_size=1099511627776 * 2, 80 | subdir=False, 81 | meminit=False, 82 | map_async=True, 83 | ) 84 | 85 | for i, idx in enumerate(tqdm(ranges[spidx])): 86 | 87 | data = {k: v[idx] if (v.shape[0] == 100001) else v for k, v in all_data.items()} 88 | natoms = np.array([data['wrapped_coords'].shape[0]] * 1, dtype=np.int64) 89 | data = a2g.convert(natoms, data['wrapped_coords'], data['atom_types'], 90 | data['lengths'][None, :], data['angles'][None, :], 91 | np.array([data['energy']]), data['forces']) 92 | data.sid = 0 93 | data.fid = idx 94 | txn = db.begin(write=True) 95 | txn.put(f"{i}".encode("ascii"), pickle.dumps(data, protocol=-1)) 96 | txn.commit() 97 | 98 | # Save count of 
objects in lmdb.
99 |             txn = db.begin(write=True)
100 |             txn.put("length".encode("ascii"), pickle.dumps(i, protocol=-1))
101 |             txn.commit()
102 | 
103 |             db.sync()
104 |             db.close()
105 | 
106 |             # for nequip.
107 |             data = all_data
108 |             data['pbc'] = np.array([True]*3)
109 |             data = {k: v[ranges[spidx]] if v.shape[0] == 100001 else v for k, v in data.items()}
110 |             data['energy'] = energy[ranges[spidx]][:, None]  # already converted to eV above; a second division by EV_TO_KCAL_MOL would be wrong
111 |             data['force'] = force[ranges[spidx]]  # likewise already in eV/Å
112 |             data['lattices'] = data['lengths'][:, None] * np.eye(3)
113 |             np.savez(save_path / 'nequip_npz', **data)
114 | 
115 | if __name__ == '__main__':
116 |     parser = argparse.ArgumentParser()
117 |     parser.add_argument("--data_path", type=str, default='./DATAPATH/water')
118 |     parser.add_argument("--db_path", type=str, default='./DATAPATH/water')
119 |     parser.add_argument("--time_split", action="store_true", help='split data by time order')
120 |     args = parser.parse_args()
121 |     write_to_lmdb(args.data_path, args.db_path, args.time_split)
122 | 
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | 
3 | setup(
4 |     name="mdsim",
5 |     version="0.0.1",
6 |     description="Machine learning models for MD simulation",
7 |     packages=find_packages(),
8 |     include_package_data=True,
9 | )
10 | 
--------------------------------------------------------------------------------