├── .gitignore ├── LICENSE ├── README.md ├── ase_calculator.py ├── ase_example.ipynb ├── config.yaml ├── config_seml.yaml ├── data ├── coll_v1.2_test.npz ├── coll_v1.2_train.npz └── coll_v1.2_val.npz ├── env.yml ├── fit_scaling.py ├── gemnet ├── model │ ├── gemnet.py │ ├── initializers.py │ ├── layers │ │ ├── atom_update_block.py │ │ ├── base_layers.py │ │ ├── basis_layers.py │ │ ├── basis_utils.py │ │ ├── efficient.py │ │ ├── embedding_block.py │ │ ├── envelope.py │ │ ├── interaction_block.py │ │ └── scaling.py │ └── utils.py └── training │ ├── data_container.py │ ├── data_provider.py │ ├── ema_decay.py │ ├── metrics.py │ ├── schedules.py │ └── trainer.py ├── predict.ipynb ├── pretrained ├── GemNet-Q │ ├── model.pth │ └── model_kwargs.json ├── GemNet-T │ ├── model.pth │ └── model_kwargs.json └── scaling_factors.json ├── requirements.txt ├── scaling_factors.json ├── setup.py ├── train.ipynb └── train_seml.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Jupyter 2 | .ipynb_checkpoints 3 | 4 | # VS Code 5 | .vscode 6 | 7 | # Python 8 | __pycache__ 9 | *.pyc 10 | 11 | # pytest 12 | .pytest_cache 13 | 14 | *.traj -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Johannes Gasteiger, Florian Becker 2 | 3 | Hippocratic License Version 2.0 4 | 5 | Licensor hereby grants permission by this license ("License"), free 6 | of charge, to any person or entity (the "Licensee") obtaining a copy 7 | of this software and associated documentation files (the "Software"), 8 | to deal in the Software without restriction, including without 9 | limitation the rights to use, copy, modify, merge, publish, 10 | distribute, sublicense, and/or sell copies of the Software, and to 11 | permit persons to whom the Software is furnished to do so, subject to 12 | the following conditions: 13 | 14 | The above copyright notice and this License or a subsequent version 15 | published on the Hippocratic License Website 16 | (https://firstdonoharm.dev/) shall be included in all copies or 17 | substantial portions of the Software. Licensee has the option of 18 | following the terms and conditions either of the above numbered 19 | version of this License or of any subsequent version published on the 20 | Hippocratic License Website. 21 | 22 | Compliance with Human Rights Laws and Human Rights Principles: 23 | 24 | 1. Human Rights Laws. The Software shall not be used by any person or 25 | entity for any systems, activities, or other uses that violate any 26 | applicable laws, regulations, or rules that protect human, civil, 27 | labor, privacy, political, environmental, security, economic, due 28 | process, or similar rights (the "Human Rights Laws"). Where the Human 29 | Rights Laws of more than one jurisdiction are applicable to the use 30 | of the Software, the Human Rights Laws that are most protective of 31 | the individuals or groups harmed shall apply. 32 | 33 | 2. Human Rights Principles. Licensee is advised to consult the 34 | articles of the United Nations Universal Declaration of Human Rights 35 | (https://www.un.org/en/universal-declaration-human-rights/) and the 36 | United Nations Global Compact 37 | (https://www.unglobalcompact.org/what-is-gc/mission/principles) that 38 | define recognized principles of international human rights (the 39 | "Human Rights Principles"). 
It is Licensor's express intent that all 40 | use of the Software be consistent with Human Rights Principles. If 41 | Licensor receives notification or otherwise learns of an alleged 42 | violation of any Human Rights Principles relating to Licensee's use 43 | of the Software, Licensor may in its discretion and without 44 | obligation (i) (a) notify Licensee of such allegation and (b) allow 45 | Licensee 90 days from notification under (i)(a) to investigate and 46 | respond to Licensor regarding the allegation and (ii) (a) after the 47 | earlier of 90 days from notification under (i)(a), or Licensee's 48 | response under (i)(b), notify Licensee of License termination and (b) 49 | allow Licensee an additional 90 days from notification under (ii)(a) 50 | to cease use of the Software. 51 | 52 | 3. Indemnity. Licensee shall hold harmless and indemnify Licensor 53 | against all losses, damages, liabilities, deficiencies, claims, 54 | actions, judgments, settlements, interest, awards, penalties, fines, 55 | costs, or expenses of whatever kind, including Licensor's reasonable 56 | attorneys' fees, arising out of or relating to Licensee's 57 | non-compliance with this License or use of the Software in violation 58 | of Human Rights Laws or Human Rights Principles. 59 | 60 | Enforceability: If any portion or provision of this License is 61 | determined to be invalid, illegal, or unenforceable by a court of 62 | competent jurisdiction, then such invalidity, illegality, or 63 | unenforceability shall not affect any other term or provision of this 64 | License or invalidate or render unenforceable such term or provision 65 | in any other jurisdiction. Upon a determination that any term or 66 | provision is invalid, illegal, or unenforceable, to the extent 67 | permitted by applicable law, the court may modify this License to 68 | affect the original intent of the parties as closely as possible. The 69 | section headings are for convenience only and are not intended to 70 | affect the construction or interpretation of this License. Any rule 71 | of construction to the effect that ambiguities are to be resolved 72 | against the drafting party shall not apply in interpreting this 73 | License. The language in this License shall be interpreted as to its 74 | fair meaning and not strictly for or against any party. 75 | 76 | 77 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 78 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 79 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 80 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 81 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 82 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 83 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 84 | 85 | This Hippocratic License is an Ethical Source license 86 | (https://ethicalsource.dev). 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GemNet: Universal Directional Graph Neural Networks for Molecules 2 | 3 | Reference implementation in PyTorch of the geometric message passing neural network (GemNet). You can find its original [TensorFlow 2 implementation in another repository](https://github.com/TUM-DAML/gemnet_tf). GemNet is a model for predicting the overall energy and the forces acting on the atoms of a molecule. 
It was proposed in the paper: 4 | 5 | **[GemNet: Universal Directional Graph Neural Networks for Molecules](https://www.cs.cit.tum.de/daml/gemnet/)** 6 | by Johannes Gasteiger, Florian Becker, Stephan Günnemann 7 | Published at NeurIPS 2021 8 | 9 | and further analyzed in 10 | 11 | **[How robust are modern graph neural network potentials in long and hot molecular dynamics simulations?](https://www.cs.cit.tum.de/daml/gemnet/)** 12 | by Sina Stocker\*, Johannes Gasteiger\*, Florian Becker, Stephan Günnemann and Johannes T. Margraf 13 | Published in Machine Learning: Science and Technology, 2022 14 | 15 | \*Both authors contributed equally to this research. Note that the author's name has changed from Johannes Klicpera to Johannes Gasteiger. 16 | 17 | ## Run the code 18 | Adjust config.yaml (or config_seml.yaml) to your needs. 19 | This repository contains notebooks for training the model (`train.ipynb`) and for generating predictions on a molecule loaded from [ASE](https://wiki.fysik.dtu.dk/ase/) (`predict.ipynb`). It also contains a script for training the model on a cluster with Sacred and [SEML](https://github.com/TUM-DAML/seml) (`train_seml.py`). Further, a notebook is provided to show how GemNet can be used for MD simulations (`ase_example.ipynb`). 20 | 21 | ## Compute scaling factors 22 | You can either use the precomputed scaling_factors (in scaling_factors.json) or compute them yourself by running fit_scaling.py. Scaling factors are used to ensure a consistent scale of activations at initialization. They are the same for all GemNet variants. 23 | 24 | ## Contact 25 | Please contact j.gasteiger@in.tum.de if you have any questions. 26 | 27 | ## Cite 28 | Please cite our papers if you use the model or this code in your own work: 29 | 30 | ``` 31 | @inproceedings{gasteiger_gemnet_2021, 32 | title = {GemNet: Universal Directional Graph Neural Networks for Molecules}, 33 | author = {Gasteiger, Johannes and Becker, Florian and G{\"u}nnemann, Stephan}, 34 | booktitle={Conference on Neural Information Processing Systems (NeurIPS)}, 35 | year = {2021} 36 | } 37 | ``` 38 | 39 | ``` 40 | @article{stocker_robust_2022, 41 | title = {How robust are modern graph neural network potentials in long and hot molecular dynamics simulations?}, 42 | author = {Stocker, Sina and Gasteiger, Johannes and Becker, Florian and G{\"u}nnemann, Stephan and Margraf, Johannes T.}, 43 | volume = {3}, 44 | doi = {10.1088/2632-2153/ac9955}, 45 | number = {4}, 46 | journal = {Machine Learning: Science and Technology}, 47 | year = {2022}, 48 | pages = {045010}, 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /ase_calculator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | # GemNet imports 5 | from gemnet.model.gemnet import GemNet 6 | from gemnet.training.data_container import DataContainer 7 | 8 | # ASE imports 9 | from ase.md import MDLogger 10 | from ase.md.velocitydistribution import ( 11 | MaxwellBoltzmannDistribution, 12 | Stationary, 13 | ) 14 | from ase.md.verlet import VelocityVerlet 15 | from ase.md.langevin import Langevin 16 | from ase import units, Atoms 17 | 18 | from ase.io.trajectory import Trajectory 19 | 20 | from ase.calculators.calculator import Calculator, all_changes 21 | 22 | 23 | class Molecule(DataContainer): 24 | """ 25 | Implements the DataContainer but for a single molecule. Requires custom init method. 
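    Example
    -------
    A minimal usage sketch (values are illustrative; ``model`` is assumed to be an
    already constructed and loaded GemNet instance, e.g. as set up in predict.ipynb):

        >>> import numpy as np
        >>> R = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]], dtype=np.float32)  # approx. H2 geometry
        >>> Z = np.array([1, 1])
        >>> molecule = Molecule(R, Z, cutoff=5.0, int_cutoff=10.0, triplets_only=False)
        >>> inputs = molecule.get()  # index and feature tensors in the format GemNet expects
        >>> # energy, forces = model.predict(inputs)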
26 | """ 27 | def __init__(self, R, Z, cutoff, int_cutoff, triplets_only=False): 28 | self.index_keys = [ 29 | "batch_seg", 30 | "id_undir", 31 | "id_swap", 32 | "id_c", 33 | "id_a", 34 | "id3_expand_ba", 35 | "id3_reduce_ca", 36 | "Kidx3", 37 | ] 38 | if not triplets_only: 39 | self.index_keys += [ 40 | "id4_int_b", 41 | "id4_int_a", 42 | "id4_reduce_ca", 43 | "id4_expand_db", 44 | "id4_reduce_cab", 45 | "id4_expand_abd", 46 | "Kidx4", 47 | "id4_reduce_intm_ca", 48 | "id4_expand_intm_db", 49 | "id4_reduce_intm_ab", 50 | "id4_expand_intm_ab", 51 | ] 52 | self.triplets_only = triplets_only 53 | self.cutoff = cutoff 54 | self.int_cutoff = int_cutoff 55 | self.keys = ["N", "Z", "R", "F", "E"] 56 | 57 | assert R.shape == (len(Z), 3) 58 | self.R = R 59 | self.Z = Z 60 | self.N = np.array([len(Z)], dtype=np.int32) 61 | self.E = np.zeros(1, dtype=np.float32).reshape(1, 1) 62 | self.F = np.zeros((len(Z), 3), dtype=np.float32) 63 | 64 | self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)]) 65 | self.addID = False 66 | self.dtypes, dtypes2 = self.get_dtypes() 67 | self.dtypes.update(dtypes2) # merge all dtypes in single dict 68 | self.device = "cpu" 69 | 70 | def get(self): 71 | """ 72 | Get the molecule representation in the expected format for the GemNet model. 73 | """ 74 | data = self.__getitem__(0) 75 | for var in ["E", "F"]: 76 | data.pop(var) # not needed i.e.e not kown -> want to calculate this 77 | # push to the selected device 78 | for key in data: 79 | data[key] = data[key].to(self.device) 80 | return data 81 | 82 | def update(self, R): 83 | """ 84 | Update the position of the atoms. 85 | Graph representation of the molecule might change if the atom positions are updated. 86 | 87 | Parameters 88 | ---------- 89 | R: torch.Tensor (nAtoms, 3) 90 | Positions of the atoms in A°. 91 | """ 92 | assert self.R.shape == R.shape 93 | self.R = R 94 | 95 | def to(self, device): 96 | """ 97 | Changes the device of the returned tensors in the .get() method. 98 | """ 99 | self.device = device 100 | 101 | 102 | class GNNCalculator(Calculator): 103 | """ 104 | A custom ase calculator that computes energy and forces acting on atoms of a molecule using GNNs, 105 | e.g. GemNet. 106 | 107 | Parameters 108 | ---------- 109 | molecule 110 | Captures data of all atoms. Contains indices etc. 111 | model 112 | The trained GemNet model. 113 | atoms: ase.Atoms 114 | ASE atoms instance. 115 | 116 | restart: str 117 | Prefix for restart file. May contain a directory. Default is None: don't restart. 118 | label: str 119 | Name used for all files. 
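    Example
    -------
    A rough sketch of how the calculator is attached to an ASE Atoms object
    (``molecule`` and ``model`` are assumed to be set up as in ase_example.ipynb);
    the standard ASE getters then trigger GemNet predictions:

        >>> from ase import Atoms
        >>> atoms = Atoms(positions=molecule.R, numbers=molecule.Z)
        >>> atoms.calc = GNNCalculator(molecule, model=model, atoms=atoms)
        >>> # energy = atoms.get_potential_energy()  # scalar energy
        >>> # forces = atoms.get_forces()            # array of shape (nAtoms, 3)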
120 | """ 121 | 122 | implemented_properties = ["energy", "forces"] 123 | 124 | def __init__( 125 | self, 126 | molecule, 127 | model, 128 | atoms=None, 129 | restart=None, 130 | add_atom_energies=False, 131 | label="gemnet_calc", # ase settings 132 | **kwargs, 133 | ): 134 | super().__init__(restart=restart, label=label, atoms=atoms, **kwargs) 135 | self.molecule = molecule 136 | self.model = model 137 | # atom energies: EPBE0_atom (in eV) from QM7-X 138 | self.add_atom_energies = add_atom_energies 139 | self.atom_energies = { 140 | 1: -13.641404161, 141 | 6: -1027.592489146, 142 | 7: -1484.274819088, 143 | 8: -2039.734879322, 144 | 16: -10828.707468187, 145 | 17: -12516.444619523, 146 | } 147 | 148 | def calculate( 149 | self, atoms=None, properties=["energy", "forces"], system_changes=all_changes 150 | ): 151 | super().calculate(atoms, properties, system_changes) 152 | 153 | # atoms.positions changes in each time step 154 | # -> need to recompute indices 155 | self.molecule.update(R=atoms.positions) 156 | 157 | # get new indices etc. 158 | inputs = self.molecule.get() 159 | 160 | # predict the energy and forces 161 | energy, forces = self.model.predict(inputs) 162 | 163 | # uncomment to add atomic reference energies 164 | energy = float(energy) # to scalar 165 | if self.add_atom_energies: 166 | energy += np.sum([self.atom_energies[z] for z in atoms.numbers]) 167 | 168 | # store energy and forces in the calculator dictionary 169 | self.results["energy"] = energy 170 | self.results["forces"] = forces.numpy() 171 | 172 | 173 | class MDSimulator: 174 | """ 175 | Runs a MD simulation on the Atoms object created from data and perform MD simulation for max_steps 176 | 177 | Parameters 178 | ---------- 179 | molecule 180 | Captures data of all atoms. 181 | model 182 | The trained GemNet model. 183 | dynamics: str 184 | Name of the MD integrator. Implemented: 'langevin' or 'verlet'. 185 | max_steps: int 186 | Maximum number of simulation steps. 187 | time: float 188 | Integration time step for Newton's law in femtoseconds. 189 | temperature: float 190 | The temperature in Kelvin. 191 | langevin_friction: float 192 | Only used when dynamics are 'langevin'. A friction coefficient, typically 1e-4 to 1e-2. 193 | interval: int 194 | Write only every time step to trajectory file. 195 | traj_path: str 196 | Path of the file where to save the calculated trajectory. 197 | vel: N-array, default=None 198 | If set, then atoms have been initialized with these velocties. 199 | logfile: str 200 | File name or open file, where to log md simulation. “-” refers to standard output. 201 | """ 202 | 203 | def __init__( 204 | self, 205 | molecule, 206 | model, 207 | dynamics: str = "langevin", 208 | max_steps: int = 100, # max_steps * time is total time length of trajectory 209 | time: float = 0.5, # in fs 210 | temperature: float = 300, # in K 211 | langevin_friction: float = 0.002, 212 | interval: int = 10, 213 | traj_path="md_sim.traj", 214 | vel=None, 215 | logfile="-", 216 | ): 217 | 218 | self.max_steps = max_steps 219 | 220 | atoms = Atoms( 221 | positions=molecule.R, numbers=molecule.Z 222 | ) # positions in A, numbers in integers (1=H, etc.) 
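        # Note: the calculator below keeps a reference to `molecule`; at every MD
        # step its calculate() method calls molecule.update(R=atoms.positions) and
        # rebuilds the graph indices before predicting energy and forces.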
223 | 224 | atoms.calc = GNNCalculator(molecule, model=model, atoms=atoms) 225 | 226 | # Initializes velocities 227 | #TODO: Implement a check for that switch 228 | if vel is not None: 229 | atoms.set_velocities(vel) 230 | else: 231 | # Set the momenta to a Maxwell-Boltzmann distribution 232 | MaxwellBoltzmannDistribution( 233 | atoms, 234 | temp=temperature * units.kB, # kB: Boltzmann constant, eV/K 235 | # temperature_K = temperature # only works in newer ase versions 236 | ) 237 | # Set the center-of-mass momentum to zero 238 | Stationary(atoms) 239 | 240 | self.dyn = None 241 | # Select MD simulation 242 | if dynamics.lower() == "verlet": 243 | logging.info("Selected MD integrator: Verlet") 244 | # total energy will always be constant 245 | self.dyn = VelocityVerlet(atoms, timestep=time * units.fs) 246 | elif dynamics.lower() == "langevin": 247 | logging.info("Selected MD integrator: Langevin") 248 | # each atom is coupled to a heat bath through a fluctuating force and a friction term 249 | self.dyn = Langevin( 250 | atoms, 251 | timestep=time * units.fs, 252 | temperature=temperature * units.kB, # kB: Boltzmann constant, eV/K 253 | # temperature_K = temperature, # only works in newer ase versions 254 | friction=langevin_friction, 255 | ) 256 | else: 257 | raise UserWarning( 258 | f"Unkown MD integrator. I only know 'verlet' and 'langevin' but {dynamics} was given." 259 | ) 260 | 261 | logging.info(f"Save trajectory to {traj_path}") 262 | self.traj = Trajectory(traj_path, "w", atoms) 263 | self.dyn.attach(self.traj.write, interval=interval) 264 | self.dyn.attach( 265 | MDLogger(self.dyn, atoms, logfile, peratom=False, mode="a"), 266 | interval=interval, 267 | ) 268 | 269 | def run(self): 270 | self.dyn.run(self.max_steps) 271 | self.traj.close() 272 | 273 | -------------------------------------------------------------------------------- /ase_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import logging\n", 10 | "# Set up logger\n", 11 | "logger = logging.getLogger()\n", 12 | "logger.handlers = []\n", 13 | "ch = logging.StreamHandler()\n", 14 | "formatter = logging.Formatter(\n", 15 | " fmt=\"%(asctime)s (%(levelname)s): %(message)s\", datefmt=\"%Y-%m-%d %H:%M:%S\"\n", 16 | ")\n", 17 | "ch.setFormatter(formatter)\n", 18 | "logger.addHandler(ch)\n", 19 | "logger.setLevel(\"INFO\")" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "application/vnd.jupyter.widget-view+json": { 30 | "model_id": "23182efc371d4de1865c0b18e3979e7f", 31 | "version_major": 2, 32 | "version_minor": 0 33 | }, 34 | "text/plain": [] 35 | }, 36 | "metadata": {}, 37 | "output_type": "display_data" 38 | } 39 | ], 40 | "source": [ 41 | "from gemnet.model.gemnet import GemNet\n", 42 | "from gemnet.model.utils import read_json\n", 43 | "\n", 44 | "from ase_calculator import Molecule, MDSimulator\n", 45 | "from ase.build import molecule as ase_molecule_db\n", 46 | "\n", 47 | "# for visualization\n", 48 | "from ase.io.trajectory import Trajectory\n", 49 | "import nglview" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "# Model settings" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "model_name = \"GemNet-Q\"\n", 
66 | "# model_name = \"GemNet-T\"\n", 67 | "\n", 68 | "pretrained_models_path = \"./pretrained\"\n", 69 | "weights_file = f\"{pretrained_models_path}/{model_name}/model.pth" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# Load the model" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "num_spherical: 7\n", 89 | "num_radial: 6\n", 90 | "num_blocks: 4\n", 91 | "emb_size_atom: 128\n", 92 | "emb_size_edge: 128\n", 93 | "emb_size_trip: 64\n", 94 | "emb_size_quad: 32\n", 95 | "emb_size_rbf: 16\n", 96 | "emb_size_cbf: 16\n", 97 | "emb_size_sbf: 32\n", 98 | "emb_size_bil_trip: 64\n", 99 | "emb_size_bil_quad: 32\n", 100 | "num_before_skip: 1\n", 101 | "num_after_skip: 1\n", 102 | "num_concat: 1\n", 103 | "num_atom: 2\n", 104 | "triplets_only: False\n", 105 | "num_targets: 1\n", 106 | "direct_forces: False\n", 107 | "cutoff: 5.0\n", 108 | "int_cutoff: 10.0\n", 109 | "envelope_exponent: 5\n", 110 | "extensive: True\n", 111 | "forces_coupled: False\n", 112 | "output_init: HeOrthogonal\n", 113 | "activation: swish\n", 114 | "scale_file: ./pretrained/scaling_factors.json\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "model_kwargs = read_json(f\"{pretrained_models_path}/{model_name}/model_kwargs.json\")\n", 120 | "model_kwargs[\"scale_file\"] = f\"{pretrained_models_path}/\" + model_kwargs[\"scale_file\"]\n", 121 | "\n", 122 | "for key, value in model_kwargs.items():\n", 123 | " print(f\"{key}: {value}\")" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "model = GemNet(**model_kwargs)\n", 133 | "model.load_weights(weights_file)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "# Molecule setup\n", 141 | "Load from database or build your own by specifying R and Z" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 6, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "mol = ase_molecule_db('C7NH5')\n", 151 | "R = mol.get_positions()\n", 152 | "Z = mol.get_atomic_numbers()" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "# MD simulation settings" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 7, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "traj_path = \"./md_sim.traj\"\n", 169 | "logfile = \"-\" # “-” refers to standard output.\n", 170 | "dynamics = \"langevin\" # Name of the MD integrator. 
Implemented: 'langevin' or 'verlet'.\n", 171 | "max_steps = 10 # Maximum number of simulation steps.\n", 172 | "time = 0.5 # Integration time step for Newton's law in femtoseconds.\n", 173 | "interval = 2 # Write only every time step to trajectory file.\n", 174 | "temperature = 1500 # The temperature in Kelvin.\n", 175 | "langevin_friction = 0.002 # Friction coefficient (only used when dynamics is langevin)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "# Setup and run the simulation" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 8, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "cutoff = model_kwargs[\"cutoff\"]\n", 192 | "int_cutoff = model_kwargs[\"int_cutoff\"]\n", 193 | "triplets_only = model_kwargs[\"triplets_only\"]\n", 194 | "molecule = Molecule(\n", 195 | " R, Z, cutoff=cutoff, int_cutoff=int_cutoff, triplets_only=triplets_only\n", 196 | ")" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 9, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stderr", 206 | "output_type": "stream", 207 | "text": [ 208 | "/nfs/homedirs/beckerf/anaconda3/envs/torch/lib/python3.8/site-packages/ase/md/md.py:48: FutureWarning: Specify the temperature in K using the 'temperature_K' argument\n", 209 | " warnings.warn(FutureWarning(w))\n", 210 | "2022-03-15 20:28:47 (INFO): Selected MD integrator: Langevin\n", 211 | "2022-03-15 20:28:47 (INFO): Save trajectory to ./md_sim.traj\n" 212 | ] 213 | }, 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "Time[ps] Etot[eV] Epot[eV] Ekin[eV] T[K]\n", 219 | "0.0000 -75.7794 -77.4343 1.6548 984.8\n", 220 | "0.0010 -75.7737 -77.3192 1.5455 919.7\n", 221 | "0.0020 -75.7691 -77.0242 1.2551 746.9\n", 222 | "0.0030 -75.7583 -76.8102 1.0519 626.0\n", 223 | "0.0040 -75.7583 -76.7260 0.9677 575.9\n", 224 | "0.0050 -75.7566 -76.5981 0.8415 500.8\n" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "simulation = MDSimulator(\n", 230 | " molecule, model, \n", 231 | " dynamics=dynamics, max_steps=max_steps, time=time, temperature=temperature, langevin_friction=langevin_friction,\n", 232 | " interval=interval, traj_path=traj_path, logfile=logfile\n", 233 | ")\n", 234 | "simulation.run()" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "# Visualize simulated trajectory" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 10, 247 | "metadata": {}, 248 | "outputs": [ 249 | { 250 | "data": { 251 | "application/vnd.jupyter.widget-view+json": { 252 | "model_id": "b621ccebd84648228596cfdf0c679c62", 253 | "version_major": 2, 254 | "version_minor": 0 255 | }, 256 | "text/plain": [ 257 | "NGLWidget(max_frame=5)" 258 | ] 259 | }, 260 | "metadata": {}, 261 | "output_type": "display_data" 262 | } 263 | ], 264 | "source": [ 265 | "traj = Trajectory(traj_path)\n", 266 | "nglview.show_asetraj(traj)" 267 | ] 268 | } 269 | ], 270 | "metadata": { 271 | "interpreter": { 272 | "hash": "73d4dc6ffc134dc5e05ee963c4039b14792ec4f63c8d27e3dd67b524fa7b1d65" 273 | }, 274 | "kernelspec": { 275 | "display_name": "Python 3.8.0 ('tf')", 276 | "language": "python", 277 | "name": "python3" 278 | }, 279 | "language_info": { 280 | "codemirror_mode": { 281 | "name": "ipython", 282 | "version": 3 283 | }, 284 | "file_extension": ".py", 285 | "mimetype": "text/x-python", 286 | "name": "python", 287 | "nbconvert_exporter": "python", 288 | "pygments_lexer": 
"ipython3", 289 | "version": "3.8.0" 290 | }, 291 | "orig_nbformat": 4 292 | }, 293 | "nbformat": 4, 294 | "nbformat_minor": 2 295 | } 296 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | 2 | num_spherical: 7 3 | num_radial: 6 4 | num_blocks: 4 5 | 6 | emb_size_atom: 128 7 | emb_size_edge: 128 8 | emb_size_trip: 64 9 | emb_size_quad: 32 10 | emb_size_rbf: 16 11 | emb_size_cbf: 16 12 | emb_size_sbf: 32 13 | emb_size_bil_trip: 64 14 | emb_size_bil_quad: 32 15 | 16 | num_before_skip: 1 17 | num_after_skip: 1 18 | num_concat: 1 19 | num_atom: 2 20 | 21 | cutoff: 5.0 22 | int_cutoff: 10.0 23 | triplets_only: False 24 | direct_forces: False 25 | 26 | mve: False 27 | loss: "rmse" 28 | forces_coupled: False 29 | envelope_exponent: 5 30 | extensive: True 31 | 32 | rho_force: 0.999 33 | ema_decay: 0.999 34 | weight_decay: 0.000002 35 | 36 | learning_rate: 0.001 37 | decay_steps: 4500000 38 | decay_rate: 0.01 39 | staircase: False 40 | decay_patience: 5 41 | decay_factor: 0.5 42 | decay_cooldown: 5 43 | agc: False 44 | grad_clip_max: 10.0 45 | 46 | restart: null 47 | tfseed: 1234 48 | data_seed: 42 49 | scale_file: "scaling_factors.json" 50 | comment: "GemNet" 51 | output_init: "HeOrthogonal" 52 | 53 | logdir: "logs" 54 | dataset: "data/coll_v1.2_train.npz" 55 | val_dataset: "data/coll_v1.2_val.npz" 56 | num_train: 0 # derived from dataset 57 | num_val: 0 # derived from dataset 58 | 59 | patience: 5 60 | evaluation_interval: 7500 61 | save_interval: 7500 62 | warmup_steps: 3750 63 | batch_size: 32 64 | num_steps: 1500000 65 | -------------------------------------------------------------------------------- /config_seml.yaml: -------------------------------------------------------------------------------- 1 | 2 | seml: 3 | executable: 'train_seml.py' 4 | name: "gemnet" 5 | output_dir: "slurm_logs" 6 | project_root_dir: "." 
7 | 8 | slurm: 9 | experiments_per_job: 1 10 | sbatch_options: 11 | gres: 'gpu:1' 12 | mem: 40G 13 | cpus-per-task: 2 14 | time: 07-00:00 15 | partition: gpu_all 16 | 17 | fixed: 18 | num_spherical: 7 19 | num_radial: 6 20 | num_blocks: 4 21 | 22 | emb_size_atom: 128 23 | emb_size_edge: 128 24 | emb_size_trip: 64 25 | emb_size_quad: 32 26 | emb_size_rbf: 16 27 | emb_size_cbf: 16 28 | emb_size_sbf: 32 29 | emb_size_bil_trip: 64 30 | emb_size_bil_quad: 32 31 | 32 | num_before_skip: 1 33 | num_after_skip: 1 34 | num_concat: 1 35 | num_atom: 2 36 | 37 | cutoff: 5.0 38 | int_cutoff: 10.0 39 | 40 | mve: False 41 | loss: "rmse" 42 | forces_coupled: False 43 | envelope_exponent: 5 44 | extensive: True 45 | 46 | rho_force: 0.999 47 | ema_decay: 0.999 48 | weight_decay: 0.000002 49 | 50 | learning_rate: 0.001 51 | decay_steps: 4500000 52 | decay_rate: 0.01 53 | staircase: False 54 | decay_patience: 5 55 | decay_factor: 0.5 56 | decay_cooldown: 5 57 | agc: False 58 | grad_clip_max: 10.0 59 | 60 | restart: null 61 | tfseed: 1234 62 | data_seed: 42 63 | scale_file: "scaling_factors.json" 64 | comment: "GemNet" 65 | output_init: "HeOrthogonal" 66 | 67 | logdir: "logs" 68 | dataset: "data/coll_v1.2_train.npz" 69 | val_dataset: "data/coll_v1.2_val.npz" 70 | num_train: 0 # derived from dataset 71 | num_val: 0 # derived from dataset 72 | 73 | patience: 5 74 | evaluation_interval: 7500 75 | save_interval: 7500 76 | warmup_steps: 3750 77 | batch_size: 32 78 | num_steps: 1500000 79 | 80 | grid: 81 | triplets_only: 82 | type: choice 83 | options: 84 | - True 85 | - False 86 | 87 | direct_forces: 88 | type: choice 89 | options: 90 | - True 91 | - False 92 | -------------------------------------------------------------------------------- /data/coll_v1.2_test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TUM-DAML/gemnet_pytorch/a0164f74217155232d39c35f0bb2c016bd3f44da/data/coll_v1.2_test.npz -------------------------------------------------------------------------------- /data/coll_v1.2_train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TUM-DAML/gemnet_pytorch/a0164f74217155232d39c35f0bb2c016bd3f44da/data/coll_v1.2_train.npz -------------------------------------------------------------------------------- /data/coll_v1.2_val.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TUM-DAML/gemnet_pytorch/a0164f74217155232d39c35f0bb2c016bd3f44da/data/coll_v1.2_val.npz -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: torch 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - pytorch 6 | - pyg 7 | dependencies: 8 | - python==3.8 9 | - cudatoolkit=11.3 10 | - pytorch==1.10 11 | - pytorch-scatter 12 | - jupyterlab 13 | - numpy 14 | - numba 15 | - scipy>=1.3 16 | - sympy>=1.5 17 | - tqdm 18 | - ase 19 | - nglview -------------------------------------------------------------------------------- /fit_scaling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" 5 | os.environ["AUTOGRAPH_VERBOSITY"] = "1" 6 | import logging 7 | 8 | # Set up logger 9 | logger = logging.getLogger() 10 | logger.handlers = [] 11 | ch = logging.StreamHandler() 12 | formatter = 
logging.Formatter( 13 | fmt="%(asctime)s (%(levelname)s): %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 14 | ) 15 | ch.setFormatter(formatter) 16 | logger.addHandler(ch) 17 | logger.setLevel("INFO") 18 | 19 | import torch 20 | 21 | from gemnet.model.gemnet import GemNet 22 | from gemnet.training.trainer import Trainer 23 | from gemnet.training.metrics import Metrics 24 | from gemnet.training.data_container import DataContainer 25 | from gemnet.training.data_provider import DataProvider 26 | 27 | import yaml 28 | import ast 29 | from tqdm import trange 30 | 31 | from gemnet.model.utils import write_json 32 | from gemnet.model.layers.scaling import AutomaticFit 33 | 34 | def run( 35 | nBatches, 36 | num_spherical, 37 | num_radial, 38 | num_blocks, 39 | emb_size_atom, 40 | emb_size_edge, 41 | emb_size_trip, 42 | emb_size_quad, 43 | emb_size_rbf, 44 | emb_size_cbf, 45 | emb_size_sbf, 46 | num_before_skip, 47 | num_after_skip, 48 | num_concat, 49 | num_atom, 50 | emb_size_bil_quad, 51 | emb_size_bil_trip, 52 | triplets_only, 53 | forces_coupled, 54 | direct_forces, 55 | mve, 56 | cutoff, 57 | int_cutoff, 58 | envelope_exponent, 59 | extensive, 60 | output_init, 61 | scale_file, 62 | data_seed, 63 | val_dataset, 64 | tfseed, 65 | batch_size, 66 | comment, 67 | overwrite_mode=1, 68 | **kwargs, 69 | ): 70 | """ 71 | Run this function to automatically fit all scaling factors in the network. 72 | """ 73 | torch.manual_seed(tfseed) 74 | 75 | def init(scale_file): 76 | # initialize file 77 | # same for all models 78 | preset = {"comment": comment} 79 | write_json(scale_file, preset) 80 | 81 | if os.path.exists(scale_file): 82 | print(f"Already found existing file: {scale_file}") 83 | if str(overwrite_mode) == "1": 84 | print("Selected: Overwrite the current file.") 85 | init(scale_file) 86 | elif str(overwrite_mode) == "2": 87 | print("Selected: Only fit unfitted variables.") 88 | else: 89 | print("Selected: Exit script") 90 | return 91 | else: 92 | init(scale_file) 93 | 94 | AutomaticFit.set2fitmode() 95 | 96 | logging.info("Initialize model") 97 | model = GemNet( 98 | num_spherical=num_spherical, 99 | num_radial=num_radial, 100 | num_blocks=num_blocks, 101 | emb_size_atom=emb_size_atom, 102 | emb_size_edge=emb_size_edge, 103 | emb_size_trip=emb_size_trip, 104 | emb_size_quad=emb_size_quad, 105 | emb_size_rbf=emb_size_rbf, 106 | emb_size_cbf=emb_size_cbf, 107 | emb_size_sbf=emb_size_sbf, 108 | num_before_skip=num_before_skip, 109 | num_after_skip=num_after_skip, 110 | num_concat=num_concat, 111 | num_atom=num_atom, 112 | emb_size_bil_quad=emb_size_bil_quad, 113 | emb_size_bil_trip=emb_size_bil_trip, 114 | num_targets=2 if mve else 1, 115 | cutoff=cutoff, 116 | int_cutoff=int_cutoff, 117 | envelope_exponent=envelope_exponent, 118 | forces_coupled=forces_coupled, 119 | direct_forces=True, # evaluates faster 120 | triplets_only=triplets_only, 121 | activation="swish", 122 | extensive=extensive, 123 | output_init=output_init, 124 | scale_file=scale_file, 125 | ) 126 | 127 | logging.info("Load dataset") 128 | # Initialize validation datasets 129 | val_data_container = DataContainer( 130 | val_dataset, cutoff=cutoff, int_cutoff=int_cutoff, triplets_only=triplets_only 131 | ) 132 | val_data_provider = DataProvider( 133 | val_data_container, 134 | 0, 135 | nBatches * batch_size, 136 | batch_size, 137 | seed=data_seed, 138 | shuffle=True, 139 | random_split=True, 140 | ) 141 | 142 | # Initialize datasets 143 | dataset_iter = val_data_provider.get_dataset("val") 144 | logging.info("Prepare training") 145 | 146 | # 
Initialize trainer 147 | trainer = Trainer(model, mve=mve) 148 | metrics = Metrics("train", trainer.tracked_metrics, None) 149 | 150 | # Training loop 151 | logging.info("Start training") 152 | 153 | while not AutomaticFit.fitting_completed(): 154 | for step in trange(0, nBatches, desc="Training..."): 155 | trainer.test_on_batch(dataset_iter, metrics) 156 | 157 | current_var = AutomaticFit.activeVar 158 | if current_var is not None: 159 | current_var.fit() # fit current variable 160 | else: 161 | print("Found no variable to fit. Something went wrong!") 162 | 163 | logging.info(f"\n Fitting done. Results saved to: {scale_file}") 164 | 165 | 166 | if __name__ == "__main__": 167 | 168 | config_path = "config.yaml" 169 | 170 | with open('config.yaml', 'r') as c: 171 | config = yaml.safe_load(c) 172 | 173 | # For strings that yaml doesn't parse (e.g. None) 174 | for key, val in config.items(): 175 | if type(val) is str: 176 | try: 177 | config[key] = ast.literal_eval(val) 178 | except (ValueError, SyntaxError): 179 | pass 180 | 181 | nBatches = 25 ## number of batches to use to fit a single variable 182 | 183 | config["scale_file"] = "scaling_factors.json" 184 | config["batch_size"] = 32 185 | config["direct_forces"] = True 186 | config["triplets_only"] = False 187 | run(nBatches, **config) 188 | -------------------------------------------------------------------------------- /gemnet/model/initializers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _standardize(kernel): 5 | """ 6 | Makes sure that Var(W) = 1 and E[W] = 0 7 | """ 8 | eps = 1e-6 9 | 10 | if len(kernel.shape) == 3: 11 | axis = [0, 1] # last dimension is output dimension 12 | else: 13 | axis = 1 14 | 15 | var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True) 16 | kernel = (kernel - mean) / (var + eps) ** 0.5 17 | return kernel 18 | 19 | 20 | def he_orthogonal_init(tensor): 21 | """ 22 | Generate a weight matrix with variance according to He initialization. 23 | Based on a random (semi-)orthogonal matrix neural networks 24 | are expected to learn better when features are decorrelated 25 | (stated by eg. "Reducing overfitting in deep networks by decorrelating representations", 26 | "Dropout: a simple way to prevent neural networks from overfitting", 27 | "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks") 28 | """ 29 | tensor = torch.nn.init.orthogonal_(tensor) 30 | 31 | if len(tensor.shape) == 3: 32 | fan_in = tensor.shape[:-1].numel() 33 | else: 34 | fan_in = tensor.shape[1] 35 | 36 | with torch.no_grad(): 37 | tensor.data = _standardize(tensor.data) 38 | tensor.data *= (1 / fan_in) ** 0.5 39 | 40 | return tensor 41 | -------------------------------------------------------------------------------- /gemnet/model/layers/atom_update_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter 3 | from .base_layers import ResidualLayer, Dense 4 | from ..initializers import he_orthogonal_init 5 | from .scaling import ScalingFactor 6 | from .embedding_block import EdgeEmbedding 7 | 8 | 9 | class AtomUpdateBlock(torch.nn.Module): 10 | """ 11 | Aggregate the message embeddings of the atoms 12 | 13 | Parameters 14 | ---------- 15 | emb_size_atom: int 16 | Embedding size of the atoms. 17 | emb_size_edge: int 18 | Embedding size of the edge embeddings. 19 | nHidden: int 20 | Number of residual blocks. 
21 | activation: callable/str 22 | Activation function to use in the dense layers. 23 | scale_file: str 24 | Path to the json file containing the scaling factors. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | emb_size_atom: int, 30 | emb_size_edge: int, 31 | emb_size_rbf: int, 32 | nHidden: int, 33 | activation=None, 34 | scale_file=None, 35 | name: str = "atom_update", 36 | ): 37 | super().__init__() 38 | self.name = name 39 | self.emb_size_edge = emb_size_edge 40 | 41 | self.dense_rbf = Dense(emb_size_rbf, emb_size_edge, activation=None, bias=False) 42 | self.scale_sum = ScalingFactor(scale_file=scale_file, name=name + "_sum") 43 | 44 | self.layers = self.get_mlp(emb_size_atom, nHidden, activation) 45 | 46 | def get_mlp(self, units, nHidden, activation): 47 | dense1 = Dense(self.emb_size_edge, units, activation=activation, bias=False) 48 | res = [ 49 | ResidualLayer(units, nLayers=2, activation=activation) 50 | for i in range(nHidden) 51 | ] 52 | mlp = [dense1] + res 53 | return torch.nn.ModuleList(mlp) 54 | 55 | def forward(self, h, m, rbf, id_j): 56 | """ 57 | Returns 58 | ------- 59 | h: Tensor, shape=(nAtoms, emb_size_atom) 60 | Atom embedding. 61 | """ 62 | nAtoms = h.shape[0] 63 | 64 | mlp_rbf = self.dense_rbf(rbf) # (nEdges, emb_size_edge) 65 | x = m * mlp_rbf 66 | 67 | x2 = scatter(x, id_j, dim=0, dim_size=nAtoms, reduce="add") 68 | x = self.scale_sum(m, x2) # (nAtoms, emb_size_edge) 69 | 70 | for i, layer in enumerate(self.layers): 71 | x = layer(x) # (nAtoms, emb_size_atom) 72 | return x 73 | 74 | 75 | class OutputBlock(AtomUpdateBlock): 76 | """ 77 | Combines the atom update block and subsequent final dense layer. 78 | 79 | Parameters 80 | ---------- 81 | emb_size_atom: int 82 | Embedding size of the atoms. 83 | emb_size_edge: int 84 | Embedding size of the edge embeddings. 85 | nHidden: int 86 | Number of residual blocks. 87 | num_targets: int 88 | Number of targets. 89 | activation: str 90 | Activation function to use in the dense layers (except for the final dense layer). 91 | direct_forces: bool 92 | If true directly predict forces without taking the gradient of the energy potential. 93 | output_init: str 94 | Kernel initializer of the final dense layer. 95 | scale_file: str 96 | Path to the json file containing the scaling factors. 
97 | """ 98 | 99 | def __init__( 100 | self, 101 | emb_size_atom: int, 102 | emb_size_edge: int, 103 | emb_size_rbf: int, 104 | nHidden: int, 105 | num_targets: int, 106 | activation=None, 107 | direct_forces=True, 108 | output_init="HeOrthogonal", 109 | scale_file=None, 110 | name: str = "output", 111 | **kwargs, 112 | ): 113 | 114 | super().__init__( 115 | name=name, 116 | emb_size_atom=emb_size_atom, 117 | emb_size_edge=emb_size_edge, 118 | emb_size_rbf=emb_size_rbf, 119 | nHidden=nHidden, 120 | activation=activation, 121 | scale_file=scale_file, 122 | **kwargs, 123 | ) 124 | 125 | assert isinstance(output_init, str) 126 | self.output_init = output_init 127 | self.direct_forces = direct_forces 128 | self.dense_rbf = Dense(emb_size_rbf, emb_size_edge, activation=None, bias=False) 129 | 130 | self.seq_energy = self.layers # inherited from parent class 131 | # do not add bias to final layer to enforce that prediction for an atom 132 | # without any edge embeddings is zero 133 | self.out_energy = Dense(emb_size_atom, num_targets, bias=False, activation=None) 134 | 135 | if self.direct_forces: 136 | self.scale_rbf = ScalingFactor(scale_file=scale_file, name=name + "_had") 137 | self.seq_forces = self.get_mlp(emb_size_edge, nHidden, activation) 138 | # no bias in final layer to ensure continuity 139 | self.out_forces = Dense( 140 | emb_size_edge, num_targets, bias=False, activation=None 141 | ) 142 | 143 | self.reset_parameters() 144 | 145 | def reset_parameters(self): 146 | if self.output_init.lower() == "heorthogonal": 147 | he_orthogonal_init(self.out_energy.weight) 148 | if self.direct_forces: 149 | he_orthogonal_init(self.out_forces.weight) 150 | elif self.output_init.lower() == "zeros": 151 | torch.nn.init.zeros_(self.out_energy.weight) 152 | if self.direct_forces: 153 | torch.nn.init.zeros_(self.out_forces.weight) 154 | else: 155 | raise UserWarning(f"Unknown output_init: {self.output_init}") 156 | 157 | def forward(self, h, m, rbf, id_j): 158 | """ 159 | Returns 160 | ------- 161 | (E, F): tuple 162 | - E: Tensor, shape=(nAtoms, num_targets) 163 | - F: Tensor, shape=(nEdges, num_targets) 164 | Energy and force prediction 165 | """ 166 | nAtoms = h.shape[0] 167 | 168 | rbf_mlp = self.dense_rbf(rbf) # (nEdges, emb_size_edge) 169 | x = m * rbf_mlp 170 | 171 | # -------------------------------------- Energy Prediction -------------------------------------- # 172 | x_E = scatter(x, id_j, dim=0, dim_size=nAtoms, reduce="add") # (nAtoms, emb_size_edge) 173 | x_E = self.scale_sum(m, x_E) 174 | 175 | for i, layer in enumerate(self.seq_energy): 176 | x_E = layer(x_E) # (nAtoms, emb_size_atom) 177 | 178 | x_E = self.out_energy(x_E) # (nAtoms, num_targets) 179 | 180 | # --------------------------------------- Force Prediction -------------------------------------- # 181 | if self.direct_forces: 182 | 183 | x_F = self.scale_rbf(m, x) 184 | 185 | for i, layer in enumerate(self.seq_forces): 186 | x_F = layer(x_F) # (nEdges, emb_size_edge) 187 | 188 | x_F = self.out_forces(x_F) # (nEdges, num_targets) 189 | else: 190 | x_F = 0 191 | # ----------------------------------------------------------------------------------------------- # 192 | 193 | return x_E, x_F 194 | -------------------------------------------------------------------------------- /gemnet/model/layers/base_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..initializers import he_orthogonal_init 3 | 4 | 5 | class Dense(torch.nn.Module): 6 | """ 7 | Combines dense layer 
and scaling for swish activation. 8 | 9 | Parameters 10 | ---------- 11 | units: int 12 | Output embedding size. 13 | activation: str 14 | Name of the activation function to use. 15 | bias: bool 16 | True if use bias. 17 | """ 18 | 19 | def __init__( 20 | self, in_features, out_features, bias=False, activation=None, name=None 21 | ): 22 | super().__init__() 23 | 24 | self.linear = torch.nn.Linear(in_features, out_features, bias=bias) 25 | self.reset_parameters() 26 | self.weight = self.linear.weight 27 | self.bias = self.linear.bias 28 | 29 | if isinstance(activation, str): 30 | activation = activation.lower() 31 | if activation in ["swish", "silu"]: 32 | self._activation = ScaledSiLU() 33 | elif activation is None: 34 | self._activation = torch.nn.Identity() 35 | else: 36 | raise NotImplementedError( 37 | "Activation function not implemented for GemNet (yet)." 38 | ) 39 | 40 | def reset_parameters(self): 41 | he_orthogonal_init(self.linear.weight) 42 | if self.linear.bias is not None: 43 | self.linear.bias.data.fill_(0) 44 | 45 | def forward(self, x): 46 | x = self.linear(x) 47 | x = self._activation(x) 48 | return x 49 | 50 | 51 | class ScaledSiLU(torch.nn.Module): 52 | def __init__(self): 53 | super().__init__() 54 | self.scale_factor = 1 / 0.6 55 | self._activation = torch.nn.SiLU() 56 | 57 | def forward(self, x): 58 | return self._activation(x) * self.scale_factor 59 | 60 | 61 | class ResidualLayer(torch.nn.Module): 62 | """ 63 | Residual block with output scaled by 1/sqrt(2). 64 | 65 | Parameters 66 | ---------- 67 | units: int 68 | Output embedding size. 69 | nLayers: int 70 | Number of dense layers. 71 | activation: str 72 | Name of the activation function to use. 73 | """ 74 | 75 | def __init__(self, units: int, nLayers: int = 2, activation=None, name=None): 76 | super().__init__() 77 | self.dense_mlp = torch.nn.Sequential( 78 | *[ 79 | Dense(units, units, activation=activation, bias=False) 80 | for i in range(nLayers) 81 | ] 82 | ) 83 | self.inv_sqrt_2 = 1 / (2.0 ** 0.5) 84 | 85 | def forward(self, inputs): 86 | x = self.dense_mlp(inputs) 87 | x = inputs + x 88 | x = x * self.inv_sqrt_2 89 | return x 90 | -------------------------------------------------------------------------------- /gemnet/model/layers/basis_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import sympy as sym 4 | 5 | from .envelope import Envelope 6 | from .basis_utils import bessel_basis, real_sph_harm 7 | 8 | 9 | class BesselBasisLayer(torch.nn.Module): 10 | """ 11 | 1D Bessel Basis 12 | 13 | Parameters 14 | ---------- 15 | num_radial: int 16 | Controls maximum frequency. 17 | cutoff: float 18 | Cutoff distance in Angstrom. 19 | envelope_exponent: int = 5 20 | Exponent of the envelope function. 
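    Note
    ----
    Reading off the forward pass: up to the smooth polynomial envelope u(d / cutoff),
    the n-th basis function of an edge with length d is

        e_n(d) = sqrt(2 / cutoff) * sin(n * pi * d / cutoff) / d ,   n = 1, ..., num_radial,

    i.e. zeroth-order spherical Bessel functions whose frequencies n * pi are used as
    the (trainable) initialization of ``self.frequencies``.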
21 | """ 22 | 23 | def __init__( 24 | self, 25 | num_radial: int, 26 | cutoff: float, 27 | envelope_exponent: int = 5, 28 | name="bessel_basis", 29 | ): 30 | super().__init__() 31 | self.num_radial = num_radial 32 | self.inv_cutoff = 1 / cutoff 33 | self.norm_const = (2 * self.inv_cutoff) ** 0.5 34 | 35 | self.envelope = Envelope(envelope_exponent) 36 | 37 | # Initialize frequencies at canonical positions 38 | self.frequencies = torch.nn.Parameter( 39 | data=torch.Tensor( 40 | np.pi * np.arange(1, self.num_radial + 1, dtype=np.float32) 41 | ), 42 | requires_grad=True, 43 | ) 44 | 45 | def forward(self, d): 46 | d = d[:, None] # (nEdges,1) 47 | d_scaled = d * self.inv_cutoff 48 | env = self.envelope(d_scaled) 49 | return env * self.norm_const * torch.sin(self.frequencies * d_scaled) / d 50 | 51 | 52 | class SphericalBasisLayer(torch.nn.Module): 53 | """ 54 | 2D Fourier Bessel Basis 55 | 56 | Parameters 57 | ---------- 58 | num_spherical: int 59 | Controls maximum frequency. 60 | num_radial: int 61 | Controls maximum frequency. 62 | cutoff: float 63 | Cutoff distance in Angstrom. 64 | envelope_exponent: int = 5 65 | Exponent of the envelope function. 66 | efficient: bool 67 | Whether to use the (memory) efficient implementation or not. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | num_spherical: int, 73 | num_radial: int, 74 | cutoff: float, 75 | envelope_exponent: int = 5, 76 | efficient: bool = False, 77 | name: str = "spherical_basis", 78 | ): 79 | super().__init__() 80 | 81 | assert num_radial <= 64 82 | self.efficient = efficient 83 | self.num_radial = num_radial 84 | self.num_spherical = num_spherical 85 | self.envelope = Envelope(envelope_exponent) 86 | self.inv_cutoff = 1 / cutoff 87 | 88 | # retrieve formulas 89 | bessel_formulas = bessel_basis(num_spherical, num_radial) 90 | Y_lm = real_sph_harm( 91 | num_spherical, spherical_coordinates=True, zero_m_only=True 92 | ) 93 | self.sph_funcs = [] # (num_spherical,) 94 | self.bessel_funcs = [] # (num_spherical * num_radial,) 95 | self.norm_const = self.inv_cutoff ** 1.5 96 | self.register_buffer( 97 | "device_buffer", torch.zeros(0), persistent=False 98 | ) # dummy buffer to get device of layer 99 | 100 | # convert to torch functions 101 | x = sym.symbols("x") 102 | theta = sym.symbols("theta") 103 | modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt} 104 | m = 0 # only single angle 105 | for l in range(len(Y_lm)): # num_spherical 106 | if l == 0: 107 | # Y_00 is only a constant -> function returns value and not tensor 108 | first_sph = sym.lambdify([theta], Y_lm[l][m], modules) 109 | self.sph_funcs.append( 110 | lambda theta: torch.zeros_like(theta) + first_sph(theta) 111 | ) 112 | else: 113 | self.sph_funcs.append(sym.lambdify([theta], Y_lm[l][m], modules)) 114 | for n in range(num_radial): 115 | self.bessel_funcs.append( 116 | sym.lambdify([x], bessel_formulas[l][n], modules) 117 | ) 118 | 119 | def forward(self, D_ca, Angle_cab, id3_reduce_ca, Kidx): 120 | 121 | d_scaled = D_ca * self.inv_cutoff # (nEdges,) 122 | u_d = self.envelope(d_scaled) 123 | rbf = [f(d_scaled) for f in self.bessel_funcs] 124 | # s: 0 0 0 0 1 1 1 1 ... 125 | # r: 0 1 2 3 0 1 2 3 ... 
126 | rbf = torch.stack(rbf, dim=1) # (nEdges, num_spherical * num_radial) 127 | rbf = rbf * self.norm_const 128 | rbf_env = u_d[:, None] * rbf # (nEdges, num_spherical * num_radial) 129 | 130 | sph = [f(Angle_cab) for f in self.sph_funcs] 131 | sph = torch.stack(sph, dim=1) # (nTriplets, num_spherical) 132 | 133 | if not self.efficient: 134 | rbf_env = rbf_env[id3_reduce_ca] # (nTriplets, num_spherical * num_radial) 135 | rbf_env = rbf_env.view(-1, self.num_spherical, self.num_radial) 136 | # e.g. num_spherical = 3, num_radial = 2 137 | # z_ln: l: 0 0 1 1 2 2 138 | # n: 0 1 0 1 0 1 139 | sph = sph.view(-1, self.num_spherical, 1) # (nTriplets, num_spherical, 1) 140 | # e.g. num_spherical = 3, num_radial = 2 141 | # Y_lm: l: 0 0 1 1 2 2 142 | # m: 0 0 0 0 0 0 143 | out = (rbf_env * sph).view(-1, self.num_spherical * self.num_radial) 144 | return out # (nTriplets, num_spherical * num_radial) 145 | else: 146 | rbf_env = rbf_env.view(-1, self.num_spherical, self.num_radial) 147 | rbf_env = torch.transpose( 148 | rbf_env, 0, 1 149 | ) # (num_spherical, nEdges, num_radial) 150 | 151 | # Zero padded dense matrix 152 | # maximum number of neighbors, catch empty id_reduce_ji with maximum 153 | Kmax = 0 if sph.shape[0]==0 else torch.max(torch.max(Kidx + 1), torch.tensor(0)) 154 | nEdges = d_scaled.shape[0] 155 | 156 | sph2 = torch.zeros( 157 | nEdges, Kmax, self.num_spherical, device=self.device_buffer.device, dtype=sph.dtype 158 | ) 159 | sph2[id3_reduce_ca, Kidx] = sph 160 | 161 | # (num_spherical, nEdges, num_radial), (nEdges, Kmax, num_spherical) 162 | return rbf_env, sph2 163 | 164 | 165 | class TensorBasisLayer(torch.nn.Module): 166 | """ 167 | 3D Fourier Bessel Basis 168 | 169 | Parameters 170 | ---------- 171 | num_spherical: int 172 | Controls maximum frequency. 173 | num_radial: int 174 | Controls maximum frequency. 175 | cutoff: float 176 | Cutoff distance in Angstrom. 177 | envelope_exponent: int = 5 178 | Exponent of the envelope function. 179 | efficient: bool 180 | Whether to use the (memory) efficient implementation or not. 
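    Note
    ----
    Each quadruplet basis function is (up to the envelope u(d / cutoff)) the product
    of a rescaled spherical Bessel function of the edge length d and a real spherical
    harmonic of the two angles,

        a_{l,m,n}(d, alpha, theta) ~ j_l(z_{l,n} * d / cutoff) * Y_{l,m}(alpha, theta),

    with 0 <= l < num_spherical, -l <= m <= l, 0 <= n < num_radial and z_{l,n} a
    positive root of j_l (see bessel_basis and real_sph_harm in basis_utils.py).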
181 | """ 182 | 183 | def __init__( 184 | self, 185 | num_spherical: int, 186 | num_radial: int, 187 | cutoff: float, 188 | envelope_exponent: int = 5, 189 | efficient=False, 190 | name: str = "tensor_basis", 191 | ): 192 | super().__init__() 193 | 194 | assert num_radial <= 64 195 | self.num_radial = num_radial 196 | self.num_spherical = num_spherical 197 | self.efficient = efficient 198 | 199 | self.inv_cutoff = 1 / cutoff 200 | self.envelope = Envelope(envelope_exponent) 201 | 202 | # retrieve formulas 203 | bessel_formulas = bessel_basis(num_spherical, num_radial) 204 | Y_lm = real_sph_harm( 205 | num_spherical, spherical_coordinates=True, zero_m_only=False 206 | ) 207 | self.sph_funcs = [] # (num_spherical**2,) 208 | self.bessel_funcs = [] # (num_spherical * num_radial,) 209 | self.norm_const = self.inv_cutoff ** 1.5 210 | 211 | # convert to tensorflow functions 212 | x = sym.symbols("x") 213 | theta = sym.symbols("theta") 214 | phi = sym.symbols("phi") 215 | modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt} 216 | for l in range(len(Y_lm)): # num_spherical 217 | for m in range(len(Y_lm[l])): 218 | if ( 219 | l == 0 220 | ): # Y_00 is only a constant -> function returns value and not tensor 221 | first_sph = sym.lambdify([theta, phi], Y_lm[l][m], modules) 222 | self.sph_funcs.append( 223 | lambda theta, phi: torch.zeros_like(theta) 224 | + first_sph(theta, phi) 225 | ) 226 | else: 227 | self.sph_funcs.append( 228 | sym.lambdify([theta, phi], Y_lm[l][m], modules) 229 | ) 230 | for j in range(num_radial): 231 | self.bessel_funcs.append( 232 | sym.lambdify([x], bessel_formulas[l][j], modules) 233 | ) 234 | 235 | self.register_buffer( 236 | "degreeInOrder", torch.arange(num_spherical) * 2 + 1, persistent=False 237 | ) 238 | 239 | def forward(self, D_ca, Alpha_cab, Theta_cabd, id4_reduce_ca, Kidx): 240 | 241 | d_scaled = D_ca * self.inv_cutoff 242 | u_d = self.envelope(d_scaled) 243 | 244 | rbf = [f(d_scaled) for f in self.bessel_funcs] 245 | # s: 0 0 0 0 1 1 1 1 ... 246 | # r: 0 1 2 3 0 1 2 3 ... 247 | rbf = torch.stack(rbf, dim=1) # (nEdges, num_spherical * num_radial) 248 | rbf = rbf * self.norm_const 249 | 250 | rbf_env = u_d[:, None] * rbf # (nEdges, num_spherical * num_radial) 251 | rbf_env = rbf_env.view( 252 | (-1, self.num_spherical, self.num_radial) 253 | ) # (nEdges, num_spherical, num_radial) 254 | rbf_env = torch.repeat_interleave( 255 | rbf_env, self.degreeInOrder, dim=1 256 | ) # (nEdges, num_spherical**2, num_radial) 257 | 258 | if not self.efficient: 259 | rbf_env = rbf_env.view( 260 | (-1, self.num_spherical ** 2 * self.num_radial) 261 | ) # (nEdges, num_spherical**2 * num_radial) 262 | rbf_env = rbf_env[ 263 | id4_reduce_ca 264 | ] # (nQuadruplets, num_spherical**2 * num_radial) 265 | # e.g. num_spherical = 3, num_radial = 2 266 | # j_ln: l: 0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 267 | # n: 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 268 | 269 | sph = [f(Alpha_cab, Theta_cabd) for f in self.sph_funcs] 270 | sph = torch.stack(sph, dim=1) # (nQuadruplets, num_spherical**2) 271 | 272 | if not self.efficient: 273 | sph = torch.repeat_interleave( 274 | sph, self.num_radial, axis=1 275 | ) # (nQuadruplets, num_spherical**2 * num_radial) 276 | # e.g. 
num_spherical = 3, num_radial = 2 277 | # Y_lm: l: 0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 278 | # m: 0 0 -1 -1 0 0 1 1 -2 -2 -1 -1 0 0 1 1 2 2 279 | return rbf_env * sph # (nQuadruplets, num_spherical**2 * num_radial) 280 | 281 | else: 282 | rbf_env = torch.transpose(rbf_env, 0, 1) # (num_spherical**2, nEdges, num_radial) 283 | 284 | # Zero padded dense matrix 285 | # maximum number of neighbors, catch empty id_reduce_ji with maximum 286 | Kmax = 0 if sph.shape[0]==0 else torch.max(torch.max(Kidx + 1), torch.tensor(0)) 287 | nEdges = d_scaled.shape[0] 288 | 289 | sph2 = torch.zeros( 290 | nEdges, Kmax, self.num_spherical ** 2, device=self.degreeInOrder.device, dtype=sph.dtype 291 | ) 292 | sph2[id4_reduce_ca, Kidx] = sph 293 | 294 | # (num_spherical**2, nEdges, num_radial), (nEdges, Kmax, num_spherical**2) 295 | return rbf_env, sph2 296 | -------------------------------------------------------------------------------- /gemnet/model/layers/basis_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import brentq 3 | from scipy import special as sp 4 | import sympy as sym 5 | 6 | 7 | def Jn(r, n): 8 | """ 9 | numerical spherical bessel functions of order n 10 | """ 11 | return sp.spherical_jn(n, r) 12 | 13 | 14 | def Jn_zeros(n, k): 15 | """ 16 | Compute the first k zeros of the spherical bessel functions up to order n (excluded) 17 | """ 18 | zerosj = np.zeros((n, k), dtype="float32") 19 | zerosj[0] = np.arange(1, k + 1) * np.pi 20 | points = np.arange(1, k + n) * np.pi 21 | racines = np.zeros(k + n - 1, dtype="float32") 22 | for i in range(1, n): 23 | for j in range(k + n - 1 - i): 24 | foo = brentq(Jn, points[j], points[j + 1], (i,)) 25 | racines[j] = foo 26 | points = racines 27 | zerosj[i][:k] = racines[:k] 28 | 29 | return zerosj 30 | 31 | 32 | def spherical_bessel_formulas(n): 33 | """ 34 | Computes the sympy formulas for the spherical bessel functions up to order n (excluded) 35 | """ 36 | x = sym.symbols("x") 37 | # j_i = (-x)^i * (1/x * d/dx)^î * sin(x)/x 38 | j = [sym.sin(x) / x] # j_0 39 | a = sym.sin(x) / x 40 | for i in range(1, n): 41 | b = sym.diff(a, x) / x 42 | j += [sym.simplify(b * (-x) ** i)] 43 | a = sym.simplify(b) 44 | return j 45 | 46 | 47 | def bessel_basis(n, k): 48 | """ 49 | Compute the sympy formulas for the normalized and rescaled spherical bessel functions up to 50 | order n (excluded) and maximum frequency k (excluded). 51 | 52 | Returns: 53 | bess_basis: list 54 | Bessel basis formulas taking in a single argument x. 55 | Has length n where each element has length k. -> In total n*k many. 56 | """ 57 | zeros = Jn_zeros(n, k) 58 | normalizer = [] 59 | for order in range(n): 60 | normalizer_tmp = [] 61 | for i in range(k): 62 | normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1) ** 2] 63 | normalizer_tmp = ( 64 | 1 / np.array(normalizer_tmp) ** 0.5 65 | ) # sqrt(2/(j_l+1)**2) , sqrt(1/c**3) not taken into account yet 66 | normalizer += [normalizer_tmp] 67 | 68 | f = spherical_bessel_formulas(n) 69 | x = sym.symbols("x") 70 | bess_basis = [] 71 | for order in range(n): 72 | bess_basis_tmp = [] 73 | for i in range(k): 74 | bess_basis_tmp += [ 75 | sym.simplify( 76 | normalizer[order][i] * f[order].subs(x, zeros[order, i] * x) 77 | ) 78 | ] 79 | bess_basis += [bess_basis_tmp] 80 | return bess_basis 81 | 82 | 83 | def sph_harm_prefactor(l, m): 84 | """Computes the constant pre-factor for the spherical harmonic of degree l and order m. 
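    The returned value is

        N_l^m = sqrt( (2*l + 1) / (4*pi) * (l - |m|)! / (l + |m|)! ),

    matching the normalization used in real_sph_harm below.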
85 | 86 | Parameters 87 | ---------- 88 | l: int 89 | Degree of the spherical harmonic. l >= 0 90 | m: int 91 | Order of the spherical harmonic. -l <= m <= l 92 | 93 | Returns 94 | ------- 95 | factor: float 96 | 97 | """ 98 | # sqrt((2*l+1)/4*pi * (l-m)!/(l+m)! ) 99 | return ( 100 | (2 * l + 1) 101 | / (4 * np.pi) 102 | * np.math.factorial(l - abs(m)) 103 | / np.math.factorial(l + abs(m)) 104 | ) ** 0.5 105 | 106 | 107 | def associated_legendre_polynomials(L, zero_m_only=True, pos_m_only=True): 108 | """Computes string formulas of the associated legendre polynomials up to degree L (excluded). 109 | 110 | Parameters 111 | ---------- 112 | L: int 113 | Degree up to which to calculate the associated legendre polynomials (degree L is excluded). 114 | zero_m_only: bool 115 | If True only calculate the polynomials where m=0. 116 | pos_m_only: bool 117 | If True only calculate the polynomials where m>=0. Overwritten by zero_m_only. 118 | 119 | Returns 120 | ------- 121 | polynomials: list 122 | Contains the sympy functions of the polynomials (in total L many if zero_m_only is True else L^2 many). 123 | """ 124 | # calculations from http://web.cmb.usc.edu/people/alber/Software/tomominer/docs/cpp/group__legendre__polynomials.html 125 | z = sym.symbols("z") 126 | P_l_m = [[0] * (2 * l + 1) for l in range(L)] # for order l: -l <= m <= l 127 | 128 | P_l_m[0][0] = 1 129 | if L > 0: 130 | if zero_m_only: 131 | # m = 0 132 | P_l_m[1][0] = z 133 | for l in range(2, L): 134 | P_l_m[l][0] = sym.simplify( 135 | ((2 * l - 1) * z * P_l_m[l - 1][0] - (l - 1) * P_l_m[l - 2][0]) / l 136 | ) 137 | return P_l_m 138 | else: 139 | # for m >= 0 140 | for l in range(1, L): 141 | P_l_m[l][l] = sym.simplify( 142 | (1 - 2 * l) * (1 - z ** 2) ** 0.5 * P_l_m[l - 1][l - 1] 143 | ) # P_00, P_11, P_22, P_33 144 | 145 | for m in range(0, L - 1): 146 | P_l_m[m + 1][m] = sym.simplify( 147 | (2 * m + 1) * z * P_l_m[m][m] 148 | ) # P_10, P_21, P_32, P_43 149 | 150 | for l in range(2, L): 151 | for m in range(l - 1): # P_20, P_30, P_31 152 | P_l_m[l][m] = sym.simplify( 153 | ( 154 | (2 * l - 1) * z * P_l_m[l - 1][m] 155 | - (l + m - 1) * P_l_m[l - 2][m] 156 | ) 157 | / (l - m) 158 | ) 159 | 160 | if not pos_m_only: 161 | # for m < 0: P_l(-m) = (-1)^m * (l-m)!/(l+m)! * P_lm 162 | for l in range(1, L): 163 | for m in range(1, l + 1): # P_1(-1), P_2(-1) P_2(-2) 164 | P_l_m[l][-m] = sym.simplify( 165 | (-1) ** m 166 | * np.math.factorial(l - m) 167 | / np.math.factorial(l + m) 168 | * P_l_m[l][m] 169 | ) 170 | 171 | return P_l_m 172 | 173 | 174 | def real_sph_harm(L, spherical_coordinates, zero_m_only=True): 175 | """ 176 | Computes formula strings of the real part of the spherical harmonics up to degree L (excluded). 177 | Variables are either spherical coordinates phi and theta (or cartesian coordinates x,y,z) on the UNIT SPHERE. 178 | 179 | Parameters 180 | ---------- 181 | L: int 182 | Degree up to which to calculate the spherical harmonics (degree L is excluded). 183 | spherical_coordinates: bool 184 | - True: Expects the input of the formula strings to be phi and theta. 185 | - False: Expects the input of the formula strings to be x, y and z. 186 | zero_m_only: bool 187 | If True only calculate the harmonics where m=0. 188 | 189 | Returns 190 | ------- 191 | Y_lm_real: list 192 | Formula strings of the real part of the spherical harmonics up 193 | to degree L (degree L is excluded). 194 | In total L^2 many sph harm exist up to degree L (excluded). 
However, if zero_m_only is True then 195 | the total count is reduced to be only L many. 196 | """ 197 | z = sym.symbols("z") 198 | P_l_m = associated_legendre_polynomials(L, zero_m_only) 199 | if zero_m_only: 200 | # for all m != 0: Y_lm = 0 201 | Y_l_m = [[0] for l in range(L)] 202 | else: 203 | Y_l_m = [[0] * (2 * l + 1) for l in range(L)] # for order l: -l <= m <= l 204 | 205 | # convert expressions to spherical coordinates 206 | if spherical_coordinates: 207 | # replace z by cos(theta) 208 | theta = sym.symbols("theta") 209 | for l in range(L): 210 | for m in range(len(P_l_m[l])): 211 | if not isinstance(P_l_m[l][m], int): 212 | P_l_m[l][m] = P_l_m[l][m].subs(z, sym.cos(theta)) 213 | 214 | ## calculate Y_lm 215 | # Y_lm = N * P_lm(cos(theta)) * exp(i*m*phi) 216 | # { sqrt(2) * (-1)^m * N * P_l|m| * sin(|m|*phi) if m < 0 217 | # Y_lm_real = { Y_lm if m = 0 218 | # { sqrt(2) * (-1)^m * N * P_lm * cos(m*phi) if m > 0 219 | 220 | for l in range(L): 221 | Y_l_m[l][0] = sym.simplify(sph_harm_prefactor(l, 0) * P_l_m[l][0]) # Y_l0 222 | 223 | if not zero_m_only: 224 | phi = sym.symbols("phi") 225 | for l in range(1, L): 226 | # m > 0 227 | for m in range(1, l + 1): 228 | Y_l_m[l][m] = sym.simplify( 229 | 2 ** 0.5 230 | * (-1) ** m 231 | * sph_harm_prefactor(l, m) 232 | * P_l_m[l][m] 233 | * sym.cos(m * phi) 234 | ) 235 | # m < 0 236 | for m in range(1, l + 1): 237 | Y_l_m[l][-m] = sym.simplify( 238 | 2 ** 0.5 239 | * (-1) ** m 240 | * sph_harm_prefactor(l, -m) 241 | * P_l_m[l][m] 242 | * sym.sin(m * phi) 243 | ) 244 | 245 | # convert expressions to cartesian coordinates 246 | if not spherical_coordinates: 247 | # replace phi by atan2(y,x) 248 | x = sym.symbols("x") 249 | y = sym.symbols("y") 250 | for l in range(L): 251 | for m in range(len(Y_l_m[l])): 252 | Y_l_m[l][m] = sym.simplify(Y_l_m[l][m].subs(phi, sym.atan2(y, x))) 253 | return Y_l_m 254 | -------------------------------------------------------------------------------- /gemnet/model/layers/efficient.py: -------------------------------------------------------------------------------- 1 | from ..initializers import he_orthogonal_init 2 | import torch 3 | 4 | 5 | class EfficientInteractionDownProjection(torch.nn.Module): 6 | """ 7 | Down projection in the efficient reformulation. 8 | 9 | Parameters 10 | ---------- 11 | num_spherical: int 12 | Same as the setting in the basis layers. 13 | num_radial: int 14 | Same as the setting in the basis layers. 15 | emb_size_interm: int 16 | Intermediate embedding size (down-projection size). 
17 | """ 18 | 19 | def __init__( 20 | self, 21 | num_spherical: int, 22 | num_radial: int, 23 | emb_size_interm: int, 24 | name="EfficientDownProj", 25 | ): 26 | super().__init__() 27 | 28 | self.num_spherical = num_spherical 29 | self.num_radial = num_radial 30 | self.emb_size_interm = emb_size_interm 31 | 32 | self.reset_parameters() 33 | 34 | def reset_parameters(self): 35 | self.weight = torch.nn.Parameter( 36 | torch.empty((self.num_spherical, self.num_radial, self.emb_size_interm)), 37 | requires_grad=True, 38 | ) 39 | he_orthogonal_init(self.weight) 40 | 41 | def forward(self, tbf): 42 | """ 43 | Returns 44 | ------- 45 | (rbf_W1, sph): tuple 46 | - rbf_W1: Tensor, shape=(nEdges, emb_size_interm, num_spherical) 47 | - sph: Tensor, shape=(nEdges, Kmax, num_spherical) 48 | """ 49 | rbf_env, sph = tbf 50 | # (num_spherical, nEdges, num_radial), (nEdges, Kmax, num_spherical) ; Kmax = maximum number of neighbors of the edges 51 | 52 | # MatMul: mul + sum over num_radial 53 | rbf_W1 = torch.matmul(rbf_env, self.weight) # (num_spherical, nEdges , emb_size_interm) 54 | rbf_W1 = rbf_W1.permute(1, 2, 0) # (nEdges, emb_size_interm, num_spherical) 55 | 56 | sph = torch.transpose(sph, 1, 2) # (nEdges, num_spherical, Kmax) 57 | return rbf_W1, sph 58 | 59 | 60 | class EfficientInteractionHadamard(torch.nn.Module): 61 | """ 62 | Efficient reformulation of the hadamard product and subsequent summation. 63 | 64 | Parameters 65 | ---------- 66 | emb_size_interm: int 67 | Intermediate embedding size (down-projection size). 68 | emb_size: int 69 | Embedding size. 70 | """ 71 | 72 | def __init__(self, emb_size_interm: int, emb_size: int, name="EfficientHadamard"): 73 | super().__init__() 74 | self.emb_size_interm = emb_size_interm 75 | self.emb_size = emb_size 76 | 77 | self.reset_parameters() 78 | 79 | def reset_parameters(self): 80 | self.weight = torch.nn.Parameter( 81 | torch.empty((self.emb_size, 1, self.emb_size_interm), requires_grad=True) 82 | ) 83 | he_orthogonal_init(self.weight) 84 | 85 | def forward(self, basis, m, id_reduce, Kidx): 86 | """ 87 | Returns 88 | ------- 89 | m_ca: Tensor, shape=(nEdges, emb_size) 90 | Edge embeddings. 91 | """ 92 | # quadruplets: m = m_db , triplets: m = m_ba 93 | # num_spherical is actually num_spherical**2 for quadruplets 94 | rbf_W1, sph = basis # (nEdges, emb_size_interm, num_spherical) , (nEdges, num_spherical, Kmax) 95 | nEdges = rbf_W1.shape[0] 96 | 97 | # Create (zero-padded) dense matrix of the neighboring edge embeddings. 98 | # maximum number of neighbors, catch empty id_reduce_ji with maximum 99 | if sph.shape[2]==0: 100 | Kmax = 0 101 | else: 102 | Kmax = torch.max(torch.max(Kidx + 1), torch.tensor(0)) 103 | m2 = torch.zeros(nEdges, Kmax, self.emb_size, device=self.weight.device, dtype=m.dtype) 104 | m2[id_reduce, Kidx] = m # (nQuadruplets or nTriplets, emb_size) -> (nEdges, Kmax, emb_size) 105 | 106 | sum_k = torch.matmul(sph, m2) # (nEdges, num_spherical, emb_size) 107 | 108 | # MatMul: mul + sum over num_spherical 109 | rbf_W1_sum_k = torch.matmul( 110 | rbf_W1, sum_k 111 | ) # (nEdges, emb_size_interm, emb_size) 112 | 113 | # MatMul: mul + sum over emb_size_interm 114 | m_ca = torch.matmul(self.weight, rbf_W1_sum_k.permute(2, 1, 0))[:, 0] # (emb_size, nEdges) 115 | m_ca = torch.transpose(m_ca, 0, 1) # (nEdges, emb_size) 116 | 117 | return m_ca 118 | 119 | 120 | class EfficientInteractionBilinear(torch.nn.Module): 121 | """ 122 | Efficient reformulation of the bilinear layer and subsequent summation. 
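The forward pass below contracts the neighbor, basis and intermediate dimensions with three batched matmuls instead of materializing the full bilinear form. A minimal sketch of that equivalence, using random stand-in tensors and arbitrary sizes:

```python
import torch

nEdges, S, Kmax, interm, emb, out = 7, 9, 4, 16, 32, 24
rbf_W1 = torch.randn(nEdges, interm, S, dtype=torch.float64)  # down-projected radial basis
sph = torch.randn(nEdges, S, Kmax, dtype=torch.float64)       # angular basis per neighbor
m2 = torch.randn(nEdges, Kmax, emb, dtype=torch.float64)      # zero-padded neighbor embeddings
W = torch.randn(emb, interm, out, dtype=torch.float64)        # bilinear weight

# naive formulation: contract neighbors (k), basis (s), intermediate (i) and embedding (h) at once
naive = torch.einsum("eis,esk,ekh,hiu->eu", rbf_W1, sph, m2, W)

# efficient formulation as in the layer: sum over k, then s, then h and i
eff = torch.matmul(rbf_W1, torch.matmul(sph, m2))       # (nEdges, interm, emb)
eff = torch.matmul(eff.permute(2, 0, 1), W).sum(dim=0)  # (nEdges, out)

print(torch.allclose(naive, eff))  # True
```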
123 | 124 | Parameters 125 | ---------- 126 | emb_size: int 127 | Edge embedding size. 128 | emb_size_interm: int 129 | Intermediate embedding size (down-projection size). 130 | units_out: int 131 | Embedding output size of the bilinear layer. 132 | kernel_initializer: callable 133 | Initializer of the weight matrix. 134 | """ 135 | 136 | def __init__( 137 | self, 138 | emb_size: int, 139 | emb_size_interm: int, 140 | units_out: int, 141 | name="EfficientBilinear", 142 | ): 143 | super().__init__() 144 | self.emb_size = emb_size 145 | self.emb_size_interm = emb_size_interm 146 | self.units_out = units_out 147 | 148 | self.reset_parameters() 149 | 150 | def reset_parameters(self): 151 | self.weight = torch.nn.Parameter( 152 | torch.empty( 153 | (self.emb_size, self.emb_size_interm, self.units_out), 154 | requires_grad=True, 155 | ) 156 | ) 157 | he_orthogonal_init(self.weight) 158 | 159 | def forward(self, basis, m, id_reduce, Kidx): 160 | """ 161 | Returns 162 | ------- 163 | m_ca: Tensor, shape=(nEdges, units_out) 164 | Edge embeddings. 165 | """ 166 | # quadruplets: m = m_db , triplets: m = m_ba 167 | # num_spherical is actually num_spherical**2 for quadruplets 168 | rbf_W1, sph = basis # (nEdges, emb_size_interm, num_spherical) , (nEdges, num_spherical, Kmax) 169 | nEdges = rbf_W1.shape[0] 170 | 171 | # Create (zero-padded) dense matrix of the neighboring edge embeddings. 172 | # maximum number of neighbors, catch empty id_reduce_ji with maximum 173 | Kmax = 0 if sph.shape[2]==0 else torch.max(torch.max(Kidx + 1), torch.tensor(0)) 174 | m2 = torch.zeros(nEdges, Kmax, self.emb_size, device=self.weight.device, dtype=m.dtype) 175 | m2[id_reduce, Kidx] = m # (nQuadruplets or nTriplets, emb_size) -> (nEdges, Kmax, emb_size) 176 | 177 | sum_k = torch.matmul(sph, m2) # (nEdges, num_spherical, emb_size) 178 | 179 | # MatMul: mul + sum over num_spherical 180 | rbf_W1_sum_k = torch.matmul( 181 | rbf_W1, sum_k 182 | ) # (nEdges, emb_size_interm, emb_size) 183 | 184 | # Bilinear: Sum over emb_size_interm and emb_size 185 | m_ca = torch.matmul( 186 | rbf_W1_sum_k.permute(2, 0, 1), self.weight 187 | ) # (emb_size, nEdges, units_out) 188 | m_ca = torch.sum(m_ca, dim=0) # (nEdges, units_out) 189 | return m_ca 190 | -------------------------------------------------------------------------------- /gemnet/model/layers/embedding_block.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .base_layers import Dense 5 | 6 | 7 | class AtomEmbedding(torch.nn.Module): 8 | """ 9 | Initial atom embeddings based on the atom type 10 | 11 | Parameters 12 | ---------- 13 | emb_size: int 14 | Atom embeddings size 15 | """ 16 | 17 | def __init__(self, emb_size, name=None): 18 | super().__init__() 19 | self.emb_size = emb_size 20 | 21 | # Atom embeddings: We go up to Pu (94). Use 93 dimensions because of 0-based indexing 22 | self.embeddings = torch.nn.Embedding(93, emb_size) 23 | # init by uniform distribution 24 | torch.nn.init.uniform_(self.embeddings.weight, a=-np.sqrt(3), b=np.sqrt(3)) 25 | 26 | def forward(self, Z): 27 | """ 28 | Returns 29 | ------- 30 | h: Tensor, shape=(nAtoms, emb_size) 31 | Atom embeddings. 32 | """ 33 | h = self.embeddings(Z - 1) # -1 because Z.min()=1 (==Hydrogen) 34 | return h 35 | 36 | 37 | class EdgeEmbedding(torch.nn.Module): 38 | """ 39 | Edge embedding based on the concatenation of atom embeddings and subsequent dense layer. 
40 | 41 | Parameters 42 | ---------- 43 | atom_features: int 44 | Embedding size of the atom embeddings. 45 | edge_features: int 46 | Embedding size of the edge embeddings. 47 | out_features: int 48 | Embedding size after the dense layer. 49 | activation: str 50 | Activation function used in the dense layer. 51 | """ 52 | 53 | def __init__( 54 | self, atom_features, edge_features, out_features, activation=None, name=None 55 | ): 56 | super().__init__() 57 | in_features = 2 * atom_features + edge_features 58 | self.dense = Dense(in_features, out_features, activation=activation, bias=False) 59 | 60 | def forward(self, h, m_rbf, idnb_a, idnb_c,): 61 | """ 62 | Returns 63 | ------- 64 | m_ca: Tensor, shape=(nEdges, emb_size) 65 | Edge embeddings. 66 | """ 67 | # m_rbf: shape (nEdges, nFeatures) 68 | # in embedding block: m_rbf = rbf ; In interaction block: m_rbf = m_ca 69 | 70 | h_a = h[idnb_a] # shape=(nEdges, emb_size) 71 | h_c = h[idnb_c] # shape=(nEdges, emb_size) 72 | 73 | m_ca = torch.cat([h_a, h_c, m_rbf], dim=-1) # (nEdges, 2*emb_size+nFeatures) 74 | m_ca = self.dense(m_ca) # (nEdges, emb_size) 75 | return m_ca 76 | -------------------------------------------------------------------------------- /gemnet/model/layers/envelope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Envelope(torch.nn.Module): 5 | """ 6 | Envelope function that ensures a smooth cutoff. 7 | 8 | Parameters 9 | ---------- 10 | p: int 11 | Exponent of the envelope function. 12 | """ 13 | 14 | def __init__(self, p, name="envelope"): 15 | super().__init__() 16 | assert p > 0 17 | self.p = p 18 | self.a = -(self.p + 1) * (self.p + 2) / 2 19 | self.b = self.p * (self.p + 2) 20 | self.c = -self.p * (self.p + 1) / 2 21 | 22 | def forward(self, d_scaled): 23 | env_val = ( 24 | 1 25 | + self.a * d_scaled ** self.p 26 | + self.b * d_scaled ** (self.p + 1) 27 | + self.c * d_scaled ** (self.p + 2) 28 | ) 29 | return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) 30 | -------------------------------------------------------------------------------- /gemnet/model/layers/interaction_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_layers import ResidualLayer, Dense 4 | from .embedding_block import EdgeEmbedding 5 | import numpy as np 6 | from .scaling import ScalingFactor 7 | from .atom_update_block import AtomUpdateBlock 8 | from .efficient import EfficientInteractionHadamard, EfficientInteractionBilinear 9 | 10 | 11 | class InteractionBlock(torch.nn.Module): 12 | """ 13 | Interaction block for GemNet-Q/dQ. 14 | 15 | Parameters 16 | ---------- 17 | emb_size_atom: int 18 | Embedding size of the atoms. 19 | emb_size_edge: int 20 | Embedding size of the edges. 21 | emb_size_trip: int 22 | (Down-projected) Embedding size in the triplet message passing block. 23 | emb_size_quad: int 24 | (Down-projected) Embedding size in the quadruplet message passing block. 25 | emb_size_rbf: int 26 | Embedding size of the radial basis transformation. 27 | emb_size_cbf: int 28 | Embedding size of the circular basis transformation (one angle). 29 | emb_size_sbf: int 30 | Embedding size of the spherical basis transformation (two angles). 31 | emb_size_bil_trip: int 32 | Embedding size of the edge embeddings in the triplet-based message passing block after the bilinear layer. 
33 | emb_size_bil_quad: int 34 | Embedding size of the edge embeddings in the quadruplet-based message passing block after the bilinear layer. 35 | num_before_skip: int 36 | Number of residual blocks before the first skip connection. 37 | num_after_skip: int 38 | Number of residual blocks after the first skip connection. 39 | num_concat: int 40 | Number of residual blocks after the concatenation. 41 | num_atom: int 42 | Number of residual blocks in the atom embedding blocks. 43 | activation: str 44 | Name of the activation function to use in the dense layers (except for the final dense layer). 45 | scale_file: str 46 | Path to the json file containing the scaling factors. 47 | """ 48 | 49 | def __init__( 50 | self, 51 | emb_size_atom, 52 | emb_size_edge, 53 | emb_size_trip, 54 | emb_size_quad, 55 | emb_size_rbf, 56 | emb_size_cbf, 57 | emb_size_sbf, 58 | emb_size_bil_trip, 59 | emb_size_bil_quad, 60 | num_before_skip, 61 | num_after_skip, 62 | num_concat, 63 | num_atom, 64 | activation=None, 65 | scale_file=None, 66 | name="Interaction", 67 | ): 68 | super().__init__() 69 | self.name = name 70 | 71 | block_nr = name.split("_")[-1] 72 | 73 | ## -------------------------------------------- Message Passing ------------------------------------------- ## 74 | # Dense transformation of skip connection 75 | self.dense_ca = Dense( 76 | emb_size_edge, 77 | emb_size_edge, 78 | activation=activation, 79 | bias=False, 80 | name="dense_ca", 81 | ) 82 | 83 | # Quadruplet Interaction 84 | self.quad_interaction = QuadrupletInteraction( 85 | emb_size_edge=emb_size_edge, 86 | emb_size_quad=emb_size_quad, 87 | emb_size_bilinear=emb_size_bil_quad, 88 | emb_size_rbf=emb_size_rbf, 89 | emb_size_cbf=emb_size_cbf, 90 | emb_size_sbf=emb_size_sbf, 91 | activation=activation, 92 | scale_file=scale_file, 93 | name=f"QuadInteraction_{block_nr}", 94 | ) 95 | 96 | # Triplet Interaction 97 | self.trip_interaction = TripletInteraction( 98 | emb_size_edge=emb_size_edge, 99 | emb_size_trip=emb_size_trip, 100 | emb_size_bilinear=emb_size_bil_trip, 101 | emb_size_rbf=emb_size_rbf, 102 | emb_size_cbf=emb_size_cbf, 103 | activation=activation, 104 | scale_file=scale_file, 105 | name=f"TripInteraction_{block_nr}", 106 | ) 107 | 108 | ## ---------------------------------------- Update Edge Embeddings ---------------------------------------- ## 109 | # Residual layers before skip connection 110 | self.layers_before_skip = torch.nn.ModuleList( 111 | [ 112 | ResidualLayer( 113 | emb_size_edge, activation=activation, name=f"res_bef_skip_{i}" 114 | ) 115 | for i in range(num_before_skip) 116 | ] 117 | ) 118 | 119 | # Residual layers after skip connection 120 | self.layers_after_skip = torch.nn.ModuleList( 121 | [ 122 | ResidualLayer( 123 | emb_size_edge, activation=activation, name=f"res_aft_skip_{i}" 124 | ) 125 | for i in range(num_after_skip) 126 | ] 127 | ) 128 | 129 | ## ---------------------------------------- Update Atom Embeddings ---------------------------------------- ## 130 | self.atom_update = AtomUpdateBlock( 131 | emb_size_atom=emb_size_atom, 132 | emb_size_edge=emb_size_edge, 133 | emb_size_rbf=emb_size_rbf, 134 | nHidden=num_atom, 135 | activation=activation, 136 | scale_file=scale_file, 137 | name=f"AtomUpdate_{block_nr}", 138 | ) 139 | 140 | ## ------------------------------ Update Edge Embeddings with Atom Embeddings ----------------------------- ## 141 | self.concat_layer = EdgeEmbedding( 142 | emb_size_atom, 143 | emb_size_edge, 144 | emb_size_edge, 145 | activation=activation, 146 | name="concat", 147 | ) 148 | 
self.residual_m = torch.nn.ModuleList( 149 | [ 150 | ResidualLayer(emb_size_edge, activation=activation, name=f"res_m_{i}") 151 | for i in range(num_concat) 152 | ] 153 | ) 154 | 155 | self.inv_sqrt_2 = 1 / (2.0 ** 0.5) 156 | self.inv_sqrt_3 = 1 / (3.0 ** 0.5) 157 | 158 | def forward(self, 159 | h, 160 | m, 161 | rbf4, 162 | cbf4, 163 | sbf4, 164 | Kidx4, 165 | rbf3, 166 | cbf3, 167 | Kidx3, 168 | id_swap, 169 | id3_expand_ba, 170 | id3_reduce_ca, 171 | id4_reduce_ca, 172 | id4_expand_intm_db, 173 | id4_expand_abd, 174 | rbf_h, 175 | id_c, 176 | id_a): 177 | """ 178 | Returns 179 | ------- 180 | h: Tensor, shape=(nEdges, emb_size_atom) 181 | Atom embeddings. 182 | m: Tensor, shape=(nEdges, emb_size_edge) 183 | Edge embeddings (c->a). 184 | """ 185 | # Initial transformation 186 | x_ca_skip = self.dense_ca(m) # (nEdges, emb_size_edge) 187 | 188 | x4 = self.quad_interaction( 189 | m, 190 | rbf4, 191 | cbf4, 192 | sbf4, 193 | Kidx4, 194 | id_swap, 195 | id4_reduce_ca, 196 | id4_expand_intm_db, 197 | id4_expand_abd, 198 | ) 199 | x3 = self.trip_interaction(m, rbf3, cbf3, Kidx3, id_swap, id3_expand_ba, id3_reduce_ca) 200 | 201 | ## ---------------------- Merge Embeddings after Quadruplet and Triplet Interaction ---------------------- ## 202 | x = x_ca_skip + x3 + x4 # (nEdges, emb_size_edge) 203 | x = x * self.inv_sqrt_3 204 | 205 | ## --------------------------------------- Update Edge Embeddings ---------------------------------------- ## 206 | # Transformations before skip connection 207 | for i, layer in enumerate(self.layers_before_skip): 208 | x = layer(x) # (nEdges, emb_size_edge) 209 | 210 | # Skip connection 211 | m = m + x # (nEdges, emb_size_edge) 212 | m = m * self.inv_sqrt_2 213 | 214 | # Transformations after skip connection 215 | for i, layer in enumerate(self.layers_after_skip): 216 | m = layer(m) # (nEdges, emb_size_edge) 217 | 218 | ## --------------------------------------- Update Atom Embeddings ---------------------------------------- ## 219 | h2 = self.atom_update(h, m, rbf_h, id_a) 220 | 221 | # Skip connection 222 | h = h + h2 # (nAtoms, emb_size_atom) 223 | h = h * self.inv_sqrt_2 224 | 225 | ## ----------------------------- Update Edge Embeddings with Atom Embeddings ----------------------------- ## 226 | m2 = self.concat_layer(h, m, id_c, id_a) # (nEdges, emb_size_edge) 227 | 228 | for i, layer in enumerate(self.residual_m): 229 | m2 = layer(m2) # (nEdges, emb_size_edge) 230 | 231 | # Skip connection 232 | m = m + m2 # (nEdges, emb_size_edge) 233 | m = m * self.inv_sqrt_2 234 | return h, m 235 | 236 | 237 | class InteractionBlockTripletsOnly(torch.nn.Module): 238 | """ 239 | Interaction block for GemNet-T/dT. 240 | 241 | Parameters 242 | ---------- 243 | emb_size_atom: int 244 | Embedding size of the atoms. 245 | emb_size_edge: int 246 | Embedding size of the edges. 247 | emb_size_trip: int 248 | (Down-projected) Embedding size in the triplet message passing block. 249 | emb_size_rbf: int 250 | Embedding size of the radial basis transformation. 251 | emb_size_cbf: int 252 | Embedding size of the circular basis transformation (one angle). 253 | emb_size_bil_trip: int 254 | Embedding size of the edge embeddings in the triplet-based message passing block after the bilinear layer. 255 | num_before_skip: int 256 | Number of residual blocks before the first skip connection. 257 | num_after_skip: int 258 | Number of residual blocks after the first skip connection. 259 | num_concat: int 260 | Number of residual blocks after the concatenation. 
261 | num_atom: int 262 | Number of residual blocks in the atom embedding blocks. 263 | activation: str 264 | Name of the activation function to use in the dense layers (except for the final dense layer). 265 | scale_file: str 266 | Path to the json file containing the scaling factors. 267 | """ 268 | 269 | def __init__( 270 | self, 271 | emb_size_atom, 272 | emb_size_edge, 273 | emb_size_trip, 274 | emb_size_quad, 275 | emb_size_rbf, 276 | emb_size_cbf, 277 | emb_size_bil_trip, 278 | num_before_skip, 279 | num_after_skip, 280 | num_concat, 281 | num_atom, 282 | activation=None, 283 | scale_file=None, 284 | name="Interaction", 285 | **kwargs, 286 | ): 287 | super().__init__() 288 | self.name = name 289 | 290 | block_nr = name.split("_")[-1] 291 | 292 | ## -------------------------------------------- Message Passing ------------------------------------------- ## 293 | # Dense transformation of skip connection 294 | self.dense_ca = Dense( 295 | emb_size_edge, 296 | emb_size_edge, 297 | activation=activation, 298 | bias=False, 299 | name="dense_ca", 300 | ) 301 | 302 | # Triplet Interaction 303 | self.trip_interaction = TripletInteraction( 304 | emb_size_edge=emb_size_edge, 305 | emb_size_trip=emb_size_trip, 306 | emb_size_bilinear=emb_size_bil_trip, 307 | emb_size_rbf=emb_size_rbf, 308 | emb_size_cbf=emb_size_cbf, 309 | activation=activation, 310 | scale_file=scale_file, 311 | name=f"TripInteraction_{block_nr}", 312 | ) 313 | 314 | ## ---------------------------------------- Update Edge Embeddings ---------------------------------------- ## 315 | # Residual layers before skip connection 316 | self.layers_before_skip = torch.nn.ModuleList( 317 | [ 318 | ResidualLayer( 319 | emb_size_edge, activation=activation, name=f"res_bef_skip_{i}" 320 | ) 321 | for i in range(num_before_skip) 322 | ] 323 | ) 324 | 325 | # Residual layers after skip connection 326 | self.layers_after_skip = torch.nn.ModuleList( 327 | [ 328 | ResidualLayer( 329 | emb_size_edge, activation=activation, name=f"res_aft_skip_{i}" 330 | ) 331 | for i in range(num_after_skip) 332 | ] 333 | ) 334 | 335 | ## ---------------------------------------- Update Atom Embeddings ---------------------------------------- ## 336 | self.atom_update = AtomUpdateBlock( 337 | emb_size_atom=emb_size_atom, 338 | emb_size_edge=emb_size_edge, 339 | emb_size_rbf=emb_size_rbf, 340 | nHidden=num_atom, 341 | activation=activation, 342 | scale_file=scale_file, 343 | name=f"AtomUpdate_{block_nr}", 344 | ) 345 | 346 | ## ------------------------------ Update Edge Embeddings with Atom Embeddings ----------------------------- ## 347 | self.concat_layer = EdgeEmbedding( 348 | emb_size_atom, 349 | emb_size_edge, 350 | emb_size_edge, 351 | activation=activation, 352 | name="concat", 353 | ) 354 | self.residual_m = torch.nn.ModuleList( 355 | [ 356 | ResidualLayer(emb_size_edge, activation=activation, name=f"res_m_{i}") 357 | for i in range(num_concat) 358 | ] 359 | ) 360 | 361 | self.inv_sqrt_2 = 1 / (2.0 ** 0.5) 362 | 363 | def forward(self, 364 | h, 365 | m, 366 | rbf3, 367 | cbf3, 368 | Kidx3, 369 | id_swap, 370 | id3_expand_ba, 371 | id3_reduce_ca, 372 | rbf_h, 373 | id_c, 374 | id_a, 375 | **kwargs): 376 | """ 377 | Returns 378 | ------- 379 | h: Tensor, shape=(nEdges, emb_size_atom) 380 | Atom embeddings. 381 | m: Tensor, shape=(nEdges, emb_size_edge) 382 | Edge embeddings (c->a). 
383 | """ 384 | # Initial transformation 385 | x_ca_skip = self.dense_ca(m) # (nEdges, emb_size_edge) 386 | 387 | x3 = self.trip_interaction(m, rbf3, cbf3, Kidx3, id_swap, id3_expand_ba, id3_reduce_ca) 388 | 389 | ## ----------------------------- Merge Embeddings after Triplet Interaction ------------------------------ ## 390 | x = x_ca_skip + x3 # (nEdges, emb_size_edge) 391 | x = x * self.inv_sqrt_2 392 | 393 | ## ---------------------------------------- Update Edge Embeddings --------------------------------------- ## 394 | # Transformations before skip connection 395 | for i, layer in enumerate(self.layers_before_skip): 396 | x = layer(x) # (nEdges, emb_size_edge) 397 | 398 | # Skip connection 399 | m = m + x # (nEdges, emb_size_edge) 400 | m = m * self.inv_sqrt_2 401 | 402 | # Transformations after skip connection 403 | for i, layer in enumerate(self.layers_after_skip): 404 | m = layer(m) # (nEdges, emb_size_edge) 405 | 406 | ## ---------------------------------------- Update Atom Embeddings --------------------------------------- ## 407 | h2 = self.atom_update(h, m, rbf_h, id_a) # (nAtoms, emb_size_atom) 408 | 409 | # Skip connection 410 | h = h + h2 # (nAtoms, emb_size_atom) 411 | h = h * self.inv_sqrt_2 412 | 413 | ## ----------------------------- Update Edge Embeddings with Atom Embeddings ----------------------------- ## 414 | m2 = self.concat_layer(h, m, id_c, id_a) # (nEdges, emb_size_edge) 415 | 416 | for i, layer in enumerate(self.residual_m): 417 | m2 = layer(m2) # (nEdges, emb_size_edge) 418 | 419 | # Skip connection 420 | m = m + m2 # (nEdges, emb_size_edge) 421 | m = m * self.inv_sqrt_2 422 | return h, m 423 | 424 | 425 | class QuadrupletInteraction(torch.nn.Module): 426 | """ 427 | Quadruplet-based message passing block. 428 | 429 | Parameters 430 | ---------- 431 | emb_size_edge: int 432 | Embedding size of the edges. 433 | emb_size_quad: int 434 | (Down-projected) Embedding size of the edge embeddings after the hadamard product with rbf. 435 | emb_size_bilinear: int 436 | Embedding size of the edge embeddings after the bilinear layer. 437 | emb_size_rbf: int 438 | Embedding size of the radial basis transformation. 439 | emb_size_cbf: int 440 | Embedding size of the circular basis transformation (one angle). 441 | emb_size_sbf: int 442 | Embedding size of the spherical basis transformation (two angles). 443 | activation: str 444 | Name of the activation function to use in the dense layers (except for the final dense layer). 445 | scale_file: str 446 | Path to the json file containing the scaling factors. 
447 | """ 448 | 449 | def __init__( 450 | self, 451 | emb_size_edge, 452 | emb_size_quad, 453 | emb_size_bilinear, 454 | emb_size_rbf, 455 | emb_size_cbf, 456 | emb_size_sbf, 457 | activation=None, 458 | scale_file=None, 459 | name="QuadrupletInteraction", 460 | **kwargs, 461 | ): 462 | super().__init__() 463 | self.name = name 464 | 465 | # Dense transformation 466 | self.dense_db = Dense( 467 | emb_size_edge, 468 | emb_size_edge, 469 | activation=activation, 470 | bias=False, 471 | name="dense_db", 472 | ) 473 | 474 | # Up projections of basis representations, bilinear layer and scaling factors 475 | self.mlp_rbf = Dense( 476 | emb_size_rbf, emb_size_edge, activation=None, name="MLP_rbf4_2", bias=False 477 | ) 478 | self.scale_rbf = ScalingFactor(scale_file=scale_file, name=name + "_had_rbf") 479 | 480 | self.mlp_cbf = Dense( 481 | emb_size_cbf, emb_size_quad, activation=None, name="MLP_cbf4_2", bias=False 482 | ) 483 | self.scale_cbf = ScalingFactor(scale_file=scale_file, name=name + "_had_cbf") 484 | 485 | self.mlp_sbf = EfficientInteractionBilinear( 486 | emb_size_quad, emb_size_sbf, emb_size_bilinear, name="MLP_sbf4_2" 487 | ) 488 | self.scale_sbf_sum = ScalingFactor( 489 | scale_file=scale_file, name=name + "_sum_sbf" 490 | ) # combines scaling for bilinear layer and summation 491 | 492 | # Down and up projections 493 | self.down_projection = Dense( 494 | emb_size_edge, 495 | emb_size_quad, 496 | activation=activation, 497 | bias=False, 498 | name="dense_down", 499 | ) 500 | self.up_projection_ca = Dense( 501 | emb_size_bilinear, 502 | emb_size_edge, 503 | activation=activation, 504 | bias=False, 505 | name="dense_up_ca", 506 | ) 507 | self.up_projection_ac = Dense( 508 | emb_size_bilinear, 509 | emb_size_edge, 510 | activation=activation, 511 | bias=False, 512 | name="dense_up_ac", 513 | ) 514 | 515 | self.inv_sqrt_2 = 1 / (2.0 ** 0.5) 516 | 517 | def forward(self, 518 | m, 519 | rbf, 520 | cbf, 521 | sbf, 522 | Kidx4, 523 | id_swap, 524 | id4_reduce_ca, 525 | id4_expand_intm_db, 526 | id4_expand_abd): 527 | """ 528 | Returns 529 | ------- 530 | m: Tensor, shape=(nEdges, emb_size_edge) 531 | Edge embeddings (c->a). 
532 | """ 533 | x_db = self.dense_db(m) # (nEdges, emb_size_edge) 534 | 535 | # Transform via radial bessel basis 536 | x_db2 = x_db * self.mlp_rbf(rbf) # (nEdges, emb_size_edge) 537 | x_db = self.scale_rbf(x_db, x_db2) 538 | 539 | # Down project embeddings 540 | x_db = self.down_projection(x_db) # (nEdges, emb_size_quad) 541 | 542 | # Transform via circular spherical bessel basis 543 | x_db = x_db[id4_expand_intm_db] # (intmTriplets, emb_size_quad) 544 | x_db2 = x_db * self.mlp_cbf(cbf) # (intmTriplets, emb_size_quad) 545 | x_db = self.scale_cbf(x_db, x_db2) 546 | 547 | # Transform via spherical bessel basis 548 | x_db = x_db[id4_expand_abd] # (nQuadruplets, emb_size_quad) 549 | x = self.mlp_sbf(sbf, x_db, id4_reduce_ca, Kidx4) # (nEdges, emb_size_bilinear) 550 | x = self.scale_sbf_sum(x_db, x) 551 | 552 | # Basis representation: 553 | # rbf(d_db) 554 | # cbf(d_ba, angle_abd) 555 | # sbf(d_ca, angle_cab, angle_cabd) 556 | 557 | # Upproject embeddings 558 | x_ca = self.up_projection_ca(x) # (nEdges, emb_size_edge) 559 | x_ac = self.up_projection_ac(x) # (nEdges, emb_size_edge) 560 | 561 | # Merge interaction of c->a and a->c 562 | x_ac = x_ac[id_swap] # swap to add to edge a->c and not c->a 563 | x4 = x_ca + x_ac 564 | x4 = x4 * self.inv_sqrt_2 565 | 566 | return x4 567 | 568 | 569 | class TripletInteraction(torch.nn.Module): 570 | """ 571 | Triplet-based message passing block. 572 | 573 | Parameters 574 | ---------- 575 | emb_size_edge: int 576 | Embedding size of the edges. 577 | emb_size_trip: int 578 | (Down-projected) Embedding size of the edge embeddings after the hadamard product with rbf. 579 | emb_size_bilinear: int 580 | Embedding size of the edge embeddings after the bilinear layer. 581 | emb_size_rbf: int 582 | Embedding size of the radial basis transformation. 583 | emb_size_cbf: int 584 | Embedding size of the circular basis transformation (one angle). 585 | activation: str 586 | Name of the activation function to use in the dense layers (except for the final dense layer). 587 | scale_file: str 588 | Path to the json file containing the scaling factors. 
589 | """ 590 | 591 | def __init__( 592 | self, 593 | emb_size_edge, 594 | emb_size_trip, 595 | emb_size_bilinear, 596 | emb_size_rbf, 597 | emb_size_cbf, 598 | activation=None, 599 | scale_file=None, 600 | name="TripletInteraction", 601 | **kwargs, 602 | ): 603 | super().__init__() 604 | self.name = name 605 | 606 | # Dense transformation 607 | self.dense_ba = Dense( 608 | emb_size_edge, 609 | emb_size_edge, 610 | activation=activation, 611 | bias=False, 612 | name="dense_ba", 613 | ) 614 | 615 | # Down projections of basis representations, bilinear layer and scaling factors 616 | self.mlp_rbf = Dense( 617 | emb_size_rbf, emb_size_edge, activation=None, name="MLP_rbf3_2", bias=False 618 | ) 619 | self.scale_rbf = ScalingFactor(scale_file=scale_file, name=name + "_had_rbf") 620 | 621 | self.mlp_cbf = EfficientInteractionBilinear( 622 | emb_size_trip, emb_size_cbf, emb_size_bilinear, name="MLP_cbf3_2" 623 | ) 624 | self.scale_cbf_sum = ScalingFactor( 625 | scale_file=scale_file, name=name + "_sum_cbf" 626 | ) # combines scaling for bilinear layer and summation 627 | 628 | # Down and up projections 629 | self.down_projection = Dense( 630 | emb_size_edge, 631 | emb_size_trip, 632 | activation=activation, 633 | bias=False, 634 | name="dense_down", 635 | ) 636 | self.up_projection_ca = Dense( 637 | emb_size_bilinear, 638 | emb_size_edge, 639 | activation=activation, 640 | bias=False, 641 | name="dense_up_ca", 642 | ) 643 | self.up_projection_ac = Dense( 644 | emb_size_bilinear, 645 | emb_size_edge, 646 | activation=activation, 647 | bias=False, 648 | name="dense_up_ac", 649 | ) 650 | 651 | self.inv_sqrt_2 = 1 / (2.0) ** 0.5 652 | 653 | def forward(self, 654 | m, 655 | rbf3, 656 | cbf3, 657 | Kidx3, 658 | id_swap, 659 | id3_expand_ba, 660 | id3_reduce_ca): 661 | """ 662 | Returns 663 | ------- 664 | m: Tensor, shape=(nEdges, emb_size_edge) 665 | Edge embeddings (c->a). 666 | """ 667 | # Dense transformation 668 | x_ba = self.dense_ba(m) # (nEdges, emb_size_edge) 669 | 670 | # Transform via radial bessel basis 671 | mlp_rbf = self.mlp_rbf(rbf3) # (nEdges, emb_size_edge) 672 | x_ba2 = x_ba * mlp_rbf 673 | x_ba = self.scale_rbf(x_ba, x_ba2) 674 | 675 | x_ba = self.down_projection(x_ba) # (nEdges, emb_size_trip) 676 | 677 | # Transform via circular spherical basis 678 | x_ba = x_ba[id3_expand_ba] # (nTriplets, emb_size_trip) 679 | 680 | # Efficient bilinear layer 681 | x = self.mlp_cbf(cbf3, x_ba, id3_reduce_ca, Kidx3) # (nEdges, emb_size_bilinear) 682 | x = self.scale_cbf_sum(x_ba, x) 683 | 684 | # Basis representation: 685 | # rbf(d_ba) 686 | # cbf(d_ca, angle_cab) 687 | 688 | # Up project embeddings 689 | x_ca = self.up_projection_ca(x) # (nEdges, emb_size_edge) 690 | x_ac = self.up_projection_ac(x) # (nEdges, emb_size_edge) 691 | 692 | # Merge interaction of c->a and a->c 693 | x_ac = x_ac[id_swap] # swap to add to edge a->c and not c->a 694 | x3 = x_ca + x_ac 695 | x3 = x3 * self.inv_sqrt_2 696 | return x3 697 | -------------------------------------------------------------------------------- /gemnet/model/layers/scaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from ..utils import read_value_json, update_json 4 | import logging 5 | 6 | 7 | class AutomaticFit: 8 | """ 9 | All added variables are processed in the order of creation. 
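The AutoScaleFit subclass further down picks each factor so that the variance of the scaled output matches the variance of a reference input. The core of that fit reduces to the following sketch (random stand-in activations, values illustrative):

```python
import torch

torch.manual_seed(0)
x_ref = torch.randn(1000, 64)    # reference activations ("input")
y = 3.0 * torch.randn(1000, 64)  # layer output with inflated variance

# alpha such that Var(alpha * y) ~ Var(x_ref), cf. AutoScaleFit.fit()
alpha = torch.sqrt(x_ref.var(dim=0).mean() / y.var(dim=0).mean())
print(alpha)                          # ~1/3
print((alpha * y).var(dim=0).mean())  # ~1, i.e. ~Var(x_ref)
```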
10 | """ 11 | 12 | activeVar = None 13 | queue = None 14 | fitting_mode = False 15 | 16 | def __init__(self, variable, scale_file, name): 17 | self.variable = variable # variable to find value for 18 | self.scale_file = scale_file 19 | self._name = name 20 | 21 | self._fitted = False 22 | self.load_maybe() 23 | 24 | # first instance created 25 | if AutomaticFit.fitting_mode and not self._fitted: 26 | 27 | # if first layer set to active 28 | if AutomaticFit.activeVar is None: 29 | AutomaticFit.activeVar = self 30 | AutomaticFit.queue = [] # initialize 31 | # else add to queue 32 | else: 33 | self._add2queue() 34 | 35 | def reset(): 36 | AutomaticFit.activeVar = None 37 | AutomaticFit.all_processed = False 38 | 39 | def fitting_completed(): 40 | return AutomaticFit.queue is None 41 | 42 | def set2fitmode(): 43 | AutomaticFit.reset() 44 | AutomaticFit.fitting_mode = True 45 | 46 | def _add2queue(self): 47 | logging.debug(f"Add {self._name} to queue.") 48 | # check that same variable is not added twice 49 | for var in AutomaticFit.queue: 50 | if self._name == var._name: 51 | raise ValueError( 52 | f"Variable with the same name ({self._name}) was already added to queue!" 53 | ) 54 | AutomaticFit.queue += [self] 55 | 56 | def set_next_active(self): 57 | """ 58 | Set the next variable in the queue that should be fitted. 59 | """ 60 | queue = AutomaticFit.queue 61 | if len(queue) == 0: 62 | logging.debug("Processed all variables.") 63 | AutomaticFit.queue = None 64 | AutomaticFit.activeVar = None 65 | return 66 | AutomaticFit.activeVar = queue.pop(0) 67 | 68 | def load_maybe(self): 69 | """ 70 | Load variable from file or set to initial value of the variable. 71 | """ 72 | value = read_value_json(self.scale_file, self._name) 73 | if value is None: 74 | logging.info( 75 | f"Initialize variable {self._name} to {self.variable.numpy():.3f}" 76 | ) 77 | else: 78 | self._fitted = True 79 | logging.debug(f"Set scale factor {self._name} : {value}") 80 | with torch.no_grad(): 81 | self.variable.copy_(torch.tensor(value)) 82 | 83 | 84 | class AutoScaleFit(AutomaticFit): 85 | """ 86 | Class to automatically fit the scaling factors depending on the observed variances. 87 | 88 | Parameters 89 | ---------- 90 | variable: torch.Tensor 91 | Variable to fit. 92 | scale_file: str 93 | Path to the json file where to store/load from the scaling factors. 94 | """ 95 | 96 | def __init__(self, variable, scale_file, name): 97 | super().__init__(variable, scale_file, name) 98 | 99 | if not self._fitted: 100 | self._init_stats() 101 | 102 | def _init_stats(self): 103 | self.variance_in = 0 104 | self.variance_out = 0 105 | self.nSamples = 0 106 | 107 | def observe(self, x, y): 108 | """ 109 | Observe variances for input x and output y. 110 | The scaling factor alpha is calculated s.t. Var(alpha * y) ~ Var(x) 111 | """ 112 | if self._fitted: 113 | return 114 | 115 | # only track stats for current variable 116 | if AutomaticFit.activeVar == self: 117 | nSamples = y.shape[0] 118 | self.variance_in += torch.mean(torch.var(x, dim=0)) * nSamples 119 | self.variance_out += torch.mean(torch.var(y, dim=0)) * nSamples 120 | self.nSamples += nSamples 121 | 122 | def fit(self): 123 | """ 124 | Fit the scaling factor based on the observed variances. 125 | """ 126 | if AutomaticFit.activeVar == self: 127 | if self.variance_in == 0: 128 | raise ValueError( 129 | f"Did not track the variable {self._name}. Add observe calls to track the variance before and after." 
130 | ) 131 | 132 | # calculate variance preserving scaling factor 133 | self.variance_in = self.variance_in / self.nSamples 134 | self.variance_out = self.variance_out / self.nSamples 135 | 136 | ratio = self.variance_out / self.variance_in 137 | value = np.sqrt(1 / ratio, dtype="float32") 138 | logging.info( 139 | f"Variable: {self._name}, Var_in: {self.variance_in.numpy():.3f}, Var_out: {self.variance_out.numpy():.3f}, " 140 | + f"Ratio: {ratio:.3f} => Scaling factor: {value:.3f}" 141 | ) 142 | 143 | # set variable to calculated value 144 | with torch.no_grad(): 145 | self.variable.copy_(self.variable * value) 146 | update_json(self.scale_file, {self._name: float(self.variable.numpy())}) 147 | self.set_next_active() # set next variable in queue to active 148 | 149 | 150 | class ScalingFactor(torch.nn.Module): 151 | """ 152 | Scale the output y of the layer s.t. the (mean) variance wrt. to the reference input x_ref is preserved. 153 | 154 | Parameters 155 | ---------- 156 | scale_file: str 157 | Path to the json file where to store/load from the scaling factors. 158 | name: str 159 | Name of the scaling factor 160 | """ 161 | 162 | def __init__(self, scale_file, name, device=None): 163 | super().__init__() 164 | 165 | self.scale_factor = torch.nn.Parameter( 166 | torch.tensor(1.0, device=device), requires_grad=False 167 | ) 168 | self.autofit = AutoScaleFit(self.scale_factor, scale_file, name) 169 | 170 | def forward(self, x_ref, y): 171 | y = y * self.scale_factor 172 | self.autofit.observe(x_ref, y) 173 | 174 | return y 175 | -------------------------------------------------------------------------------- /gemnet/model/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def read_json(path): 5 | """ """ 6 | if not path.endswith(".json"): 7 | raise UserWarning(f"Path {path} is not a json-path.") 8 | 9 | with open(path, "r") as f: 10 | content = json.load(f) 11 | return content 12 | 13 | 14 | def update_json(path, data): 15 | """ """ 16 | if not path.endswith(".json"): 17 | raise UserWarning(f"Path {path} is not a json-path.") 18 | 19 | content = read_json(path) 20 | content.update(data) 21 | write_json(path, content) 22 | 23 | 24 | def write_json(path, data): 25 | """ """ 26 | if not path.endswith(".json"): 27 | raise UserWarning(f"Path {path} is not a json-path.") 28 | 29 | with open(path, "w", encoding="utf-8") as f: 30 | json.dump(data, f, ensure_ascii=False, indent=4) 31 | 32 | 33 | def read_value_json(path, key): 34 | """ """ 35 | content = read_json(path) 36 | 37 | if key in content.keys(): 38 | return content[key] 39 | else: 40 | return None 41 | -------------------------------------------------------------------------------- /gemnet/training/data_container.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import numba 4 | import torch 5 | 6 | 7 | class DataContainer: 8 | """ 9 | Parameters 10 | ---------- 11 | path: str 12 | Absolute path of the dataset (in npz-format). 13 | cutoff: float 14 | Insert edge between atoms if distance is less than cutoff. 15 | int_cutoff: float 16 | Cutoff of edge embeddings involved in quadruplet-based message passing. 17 | triplets_only: bool 18 | Flag whether to load quadruplet indices as well. 19 | transforms: list 20 | Transforms that should be applied on the whole dataset. 21 | addID: bool 22 | Whether to add the molecule id to the output. 
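A minimal usage sketch (the file path is one of the bundled COLL splits; the cutoff values are illustrative, not prescribed here):

```python
from gemnet.training.data_container import DataContainer

data_container = DataContainer(
    "data/coll_v1.2_val.npz",  # npz file providing N, Z, R, F, E
    cutoff=5.0,                # distance cutoff for embedding edges
    int_cutoff=10.0,           # cutoff for quadruplet interaction edges
    triplets_only=False,
)
batch = data_container[[0, 1, 2]]          # batches the first three molecules
print(batch["R"].shape, batch["E"].shape)  # (nAtoms, 3) and (3, 1) tensors
```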
23 | """ 24 | def __init__( 25 | self, 26 | path, 27 | cutoff, 28 | int_cutoff, 29 | triplets_only=False, 30 | transforms=None, 31 | addID=False, 32 | ): 33 | self.index_keys = [ 34 | "batch_seg", 35 | "id_undir", 36 | "id_swap", 37 | "id_c", 38 | "id_a", 39 | "id3_expand_ba", 40 | "id3_reduce_ca", 41 | "Kidx3", 42 | ] 43 | if not triplets_only: 44 | self.index_keys += [ 45 | "id4_int_b", 46 | "id4_int_a", 47 | "id4_reduce_ca", 48 | "id4_expand_db", 49 | "id4_reduce_cab", 50 | "id4_expand_abd", 51 | "Kidx4", 52 | "id4_reduce_intm_ca", 53 | "id4_expand_intm_db", 54 | "id4_reduce_intm_ab", 55 | "id4_expand_intm_ab", 56 | ] 57 | self.triplets_only = triplets_only 58 | self.cutoff = cutoff 59 | self.int_cutoff = int_cutoff 60 | self.addID = addID 61 | self.keys = ["N", "Z", "R", "F", "E"] 62 | if addID: 63 | self.keys += ["id"] 64 | 65 | self._load_npz(path, self.keys) # set keys as attributes 66 | 67 | if transforms is None: 68 | self.transforms = [] 69 | else: 70 | assert isinstance(transforms, (list, tuple)) 71 | self.transforms = transforms 72 | 73 | # modify dataset 74 | for transform in self.transforms: 75 | transform(self) 76 | 77 | assert self.R is not None 78 | assert self.N is not None 79 | assert self.Z is not None 80 | assert self.E is not None 81 | assert self.F is not None 82 | 83 | assert len(self.E) > 0 84 | assert len(self.F) > 0 85 | 86 | self.E = self.E[:, None] # shape=(nMolecules,1) 87 | self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)]) 88 | 89 | self.dtypes, dtypes2 = self.get_dtypes() 90 | self.dtypes.update(dtypes2) # merge all dtypes in single dict 91 | self.targets = ["E", "F"] 92 | 93 | def _load_npz(self, path, keys): 94 | """Load the keys from the file and set as attributes. 95 | 96 | Parameters 97 | ---------- 98 | path: str 99 | Absolute path of the dataset (in npz-format). 100 | keys: list 101 | Contains keys in the dataset to load and set as attributes. 102 | 103 | Returns 104 | ------- 105 | None 106 | """ 107 | with np.load(path, allow_pickle=True) as data: 108 | for key in keys: 109 | if key not in data.keys(): 110 | if key != "F": 111 | raise UserWarning(f"Can not find key {key} in the dataset.") 112 | else: 113 | setattr(self, key, data[key]) 114 | 115 | @staticmethod 116 | def _bmat_fast(mats): 117 | """Combines multiple adjacency matrices into single sparse block matrix. 118 | 119 | Parameters 120 | ---------- 121 | mats: list 122 | Has adjacency matrices as elements. 
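The result is equivalent to stacking the per-molecule adjacency matrices into one block-diagonal matrix; a small sketch of the intended output using scipy's generic routine (toy matrices):

```python
import numpy as np
import scipy.sparse as sp

a = sp.csr_matrix(np.array([[0, 1], [1, 0]]))                    # 2-atom molecule
b = sp.csr_matrix(np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]))   # 3-atom molecule

# sp.block_diag builds the same block-diagonal structure that _bmat_fast assembles by hand
block = sp.block_diag([a, b], format="csr")
print(block.shape)      # (5, 5)
print(block.toarray())  # a in the top-left block, b in the bottom-right block
```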
123 | 124 | Returns 125 | ------- 126 | adj_matrix: sp.csr_matrix 127 | Combined adjacency matrix (sparse block matrix) 128 | """ 129 | assert len(mats) > 0 130 | new_data = np.concatenate([mat.data for mat in mats]) 131 | 132 | ind_offset = np.zeros(1 + len(mats), dtype="int32") 133 | ind_offset[1:] = np.cumsum([mat.shape[0] for mat in mats]) 134 | new_indices = np.concatenate( 135 | [mats[i].indices + ind_offset[i] for i in range(len(mats))] 136 | ) 137 | 138 | indptr_offset = np.zeros(1 + len(mats)) 139 | indptr_offset[1:] = np.cumsum([mat.nnz for mat in mats]) 140 | new_indptr = np.concatenate( 141 | [mats[i].indptr[i >= 1 :] + indptr_offset[i] for i in range(len(mats))] 142 | ) 143 | 144 | # Resulting matrix shape: sum of matrices 145 | shape = (ind_offset[-1], ind_offset[-1]) 146 | 147 | # catch case with no edges 148 | if len(new_data) == 0: 149 | return sp.csr_matrix(shape) 150 | 151 | return sp.csr_matrix((new_data, new_indices, new_indptr), shape=shape) 152 | 153 | def __len__(self): 154 | return len(self.N) 155 | 156 | def __getitem__(self, idx): 157 | """ 158 | Parameters 159 | ---------- 160 | idx: array-like 161 | Ids of the molecules to get. 162 | 163 | Returns 164 | ------- 165 | data: dict 166 | nMolecules = len(idx) 167 | nAtoms = total number of atoms in the selected molecules 168 | Contains the following keys and values: 169 | 170 | - id: np.ndarray, shape (nMolecules,) 171 | Ids of the molecules in the dataset. 172 | - N: np.ndarray, shape (nMolecules,) 173 | Number of atoms in the molecules. 174 | - Z: np.ndarray, shape (nAtoms,) 175 | Atomic numbers (German: Ordnungszahl). 176 | - R: np.ndarray, shape (nAtoms,3) 177 | Atom positions in Å. 178 | - F: np.ndarray, shape (nAtoms,3) 179 | Forces at the atoms in eV/Å. 180 | - E: np.ndarray, shape (nMolecules,1) 181 | Energy of the molecule in eV. 182 | - batch_seg: np.ndarray, shape (nAtoms,) 183 | Contains the index of the sample the atom belongs to. 184 | E.g. [0,0,0, 1,1,1,1, 2,...] where first molecule has 3 atoms, 185 | second molecule has 4 atoms etc. 186 | - id_c: np.ndarray, shape (nEdges,) 187 | Indices of edges' source atom. 188 | - id_a: np.ndarray, shape (nEdges,) 189 | Indices of edges' target atom. 190 | - id_undir: np.ndarray, shape (nEdges,) 191 | Indices where the same index denotes opposite edges, c->a and a->c. 192 | - id_swap: np.ndarray, shape (nEdges,) 193 | Indices to map c->a to a->c. 194 | - id3_expand_ba: np.ndarray, shape (nTriplets,) 195 | Indices to map the edges from c->a to b->a in the triplet-based message passing. 196 | - id3_reduce_ca: np.ndarray, shape (nTriplets,) 197 | Indices to map the edges from c->a to c->a in the triplet-based message passing. 198 | - Kidx3: np.ndarray, shape (nTriplets,) 199 | Indices to reshape the neighbor indices b->a into a dense matrix. 200 | - id4_int_a: np.ndarray, shape (nInterEdges,) 201 | Indices of the atom a of the interaction edge. 202 | - id4_int_b: np.ndarray, shape (nInterEdges,) 203 | Indices of the atom b of the interaction edge. 204 | - id4_reduce_ca: np.ndarray, shape (nQuadruplets,) 205 | Indices to map c->a to c->a in quadruplet-message passing. 206 | - id4_expand_db: np.ndarray, shape (nQuadruplets,) 207 | Indices to map c->a to d->b in quadruplet-message passing. 208 | - id4_reduce_intm_ca: np.ndarray, shape (intmTriplets,) 209 | Indices to map c->a to intermediate c->a. 210 | - id4_expand_intm_db: np.ndarray, shape (intmTriplets,) 211 | Indices to map d->b to intermediate d->b. 
212 | - id4_reduce_intm_ab: np.ndarray, shape (intmTriplets,) 213 | Indices to map b-a to intermediate b-a of the quadruplet's part c->a-b. 214 | - id4_expand_intm_ab: np.ndarray, shape (intmTriplets,) 215 | Indices to map b-a to intermediate b-a of the quadruplet's part a-b<-d. 216 | - id4_reduce_cab: np.ndarray, shape (nQuadruplets,) 217 | Indices to map from intermediate c->a to quadruplet c->a. 218 | - id4_expand_abd: np.ndarray, shape (nQuadruplets,) 219 | Indices to map from intermediate d->b to quadruplet d->b. 220 | - Kidx4: np.ndarray, shape (nTriplets,) 221 | Indices to reshape the neighbor indices d->b into a dense matrix. 222 | """ 223 | if isinstance(idx, (int, np.int64, np.int32)): 224 | idx = [idx] 225 | if isinstance(idx, tuple): 226 | idx = list(idx) 227 | if isinstance(idx, slice): 228 | idx = np.arange(idx.start, min(idx.stop, len(self)), idx.step) 229 | 230 | data = {} 231 | if self.addID: 232 | data["id"] = self.id[idx] 233 | data["E"] = self.E[idx] 234 | data["N"] = self.N[idx] 235 | data["batch_seg"] = np.repeat(np.arange(len(idx), dtype=np.int32), data["N"]) 236 | 237 | data["Z"] = np.zeros(np.sum(data["N"]), dtype=np.int32) 238 | data["R"] = np.zeros([np.sum(data["N"]), 3], dtype=np.float32) 239 | data["F"] = np.zeros([np.sum(data["N"]), 3], dtype=np.float32) 240 | 241 | nend = 0 242 | adj_matrices = [] 243 | adj_matrices_int = [] 244 | for k, i in enumerate(idx): 245 | n = data["N"][k] 246 | nstart = nend 247 | nend = nstart + n 248 | s, e = ( 249 | self.N_cumsum[i], 250 | self.N_cumsum[i + 1], 251 | ) # start and end idx of atoms belonging to molecule 252 | 253 | data["F"][nstart:nend] = self.F[s:e] 254 | data["Z"][nstart:nend] = self.Z[s:e] 255 | R = self.R[s:e] 256 | data["R"][nstart:nend] = R 257 | 258 | D_ij = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1) 259 | # get adjacency matrix for embeddings 260 | adj_mat = sp.csr_matrix(D_ij <= self.cutoff) 261 | adj_mat -= sp.eye(n, dtype=np.bool) 262 | adj_matrices.append(adj_mat) 263 | 264 | if not self.triplets_only: 265 | # get adjacency matrix for interaction 266 | adj_mat = sp.csr_matrix(D_ij <= self.int_cutoff) 267 | adj_mat -= sp.eye(n, dtype=np.bool) 268 | adj_matrices_int.append(adj_mat) 269 | 270 | #### Indices of the moleule structure 271 | idx_data = {key: None for key in self.index_keys if key != "batch_seg"} 272 | # Entry A_ij is edge j -> i (!) 273 | adj_matrix = self._bmat_fast(adj_matrices) 274 | idx_t, idx_s = adj_matrix.nonzero() # target and source nodes 275 | 276 | if not self.triplets_only: 277 | # Entry A_ij is edge j -> i (!) 278 | adj_matrix_int = self._bmat_fast(adj_matrices_int) 279 | idx_int_t, idx_int_s = adj_matrix_int.nonzero() # target and source nodes 280 | 281 | # catch no edge case 282 | if len(idx_t) == 0: 283 | for key in idx_data.keys(): 284 | data[key] = np.array([], dtype="int32") 285 | return self.convert_to_tensor(data) 286 | 287 | # Get mask for undirected edges 0 1 nEdges/2 nEdges/2+1 288 | # change order of indices such that edge = [[0,1],[0,2], ..., [1,0], [2,0], ...] 
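# e.g. with the two bonds 0-1 and 0-2 the reordering below yields idx_t = [0, 0, 1, 2] and idx_s = [1, 2, 0, 0]:
# edge i and edge i + nEdges/2 are then the two directions of the same bond, so id_undir = [0, 1, 0, 1]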
289 | edges = np.stack([idx_t, idx_s], axis=0) 290 | mask = edges[0] < edges[1] 291 | edges = edges[:, mask] 292 | edges = np.concatenate([edges, edges[::-1]], axis=-1).astype("int32") 293 | idx_t, idx_s = edges[0], edges[1] 294 | indices = np.arange(len(mask) / 2, dtype="int32") 295 | idx_data["id_undir"] = np.concatenate(2 * [indices], axis=-1).astype("int32") 296 | 297 | idx_data["id_c"] = idx_s # node c is source 298 | idx_data["id_a"] = idx_t # node a is target 299 | 300 | if not self.triplets_only: 301 | idx_data["id4_int_a"] = idx_int_t 302 | idx_data["id4_int_b"] = idx_int_s 303 | # 0 1 ... nEdges/2 nEdges/2+1 304 | ## swap indices a->c to c->a: [nEdges/2 nEdges/2+1 ... 0 1 ... ] 305 | N_undir_edges = int(len(idx_s) / 2) 306 | ind = np.arange(N_undir_edges, dtype="int32") 307 | id_swap = np.concatenate([ind + N_undir_edges, ind]) 308 | idx_data["id_swap"] = id_swap 309 | 310 | # assign an edge_id to each edge 311 | edge_ids = sp.csr_matrix( 312 | (np.arange(len(idx_s)), (idx_t, idx_s)), 313 | shape=adj_matrix.shape, 314 | dtype="int32", 315 | ) 316 | 317 | #### ------------------------------------ Triplets ------------------------------------ #### 318 | id3_expand_ba, id3_reduce_ca = self.get_triplets(idx_s, idx_t, edge_ids) 319 | # embed msg from c -> a with all quadruplets for k and l: c -> a <- k <- l 320 | # id3_reduce_ca is for k -> a -> c but we want c -> a <- k 321 | id3_reduce_ca = id_swap[id3_reduce_ca] 322 | 323 | # --------------------- Needed for efficient implementation --------------------- # 324 | if len(id3_reduce_ca) > 0: 325 | # id_reduce_ca must be sorted (i.e. grouped would suffice) for ragged_range ! 326 | idx_sorted = np.argsort(id3_reduce_ca) 327 | id3_reduce_ca = id3_reduce_ca[idx_sorted] 328 | id3_expand_ba = id3_expand_ba[idx_sorted] 329 | _, K = np.unique(id3_reduce_ca, return_counts=True) 330 | idx_data["Kidx3"] = DataContainer.ragged_range( 331 | K 332 | ) # K = [1 4 2 3] -> Kidx3 = [0 0 1 2 3 0 1 0 1 2] , (nTriplets,) 333 | else: 334 | idx_data["Kidx3"] = np.array([], dtype="int32") 335 | # ------------------------------------------------------------------------------- # 336 | 337 | idx_data["id3_expand_ba"] = id3_expand_ba # (nTriplets,) 338 | idx_data["id3_reduce_ca"] = id3_reduce_ca # (nTriplets,) 339 | 340 | # node indices in triplet 341 | # id3_c = idx_s[id3_reduce_ca] 342 | # id3_a = idx_t[id3_reduce_ca] 343 | # assert np.all(id3_j == idx_t[id3_expand_ba]) 344 | # id3_b = idx_s[id3_expand_ba] 345 | #### ---------------------------------------------------------------------------------- #### 346 | 347 | if self.triplets_only: 348 | data.update(idx_data) 349 | return self.convert_to_tensor(data) 350 | 351 | #### ----------------------------------- Quadruplets ---------------------------------- #### 352 | 353 | # get quadruplets 354 | output = self.get_quadruplets( 355 | idx_s, idx_t, adj_matrix, edge_ids, idx_int_s, idx_int_t 356 | ) 357 | ( 358 | id4_reduce_ca, 359 | id4_expand_db, 360 | id4_reduce_cab, 361 | id4_expand_abd, 362 | id4_reduce_intm_ca, 363 | id4_expand_intm_db, 364 | id4_reduce_intm_ab, 365 | id4_expand_intm_ab, 366 | ) = output 367 | 368 | # --------------------- Needed for efficient implementation --------------------- # 369 | if len(id4_reduce_ca) > 0: 370 | # id4_reduce_ca has to be sorted (i.e. grouped would suffice) for ragged range ! 
371 | sorted_idx = np.argsort(id4_reduce_ca) 372 | id4_reduce_ca = id4_reduce_ca[sorted_idx] 373 | id4_expand_db = id4_expand_db[sorted_idx] 374 | id4_reduce_cab = id4_reduce_cab[sorted_idx] 375 | id4_expand_abd = id4_expand_abd[sorted_idx] 376 | 377 | _, K = np.unique(id4_reduce_ca, return_counts=True) 378 | # K = [1 4 2 3] -> Kidx4 = [0 0 1 2 3 0 1 0 1 2] 379 | idx_data["Kidx4"] = DataContainer.ragged_range(K) # (nQuadruplets,) 380 | else: 381 | idx_data["Kidx4"] = np.array([], dtype="int32") 382 | # ------------------------------------------------------------------------------- # 383 | 384 | idx_data["id4_reduce_ca"] = id4_reduce_ca # (nQuadruplets,) 385 | idx_data["id4_expand_db"] = id4_expand_db # (nQuadruplets,) 386 | idx_data["id4_reduce_cab"] = id4_reduce_cab # (nQuadruplets,) 387 | idx_data["id4_expand_abd"] = id4_expand_abd # (nQuadruplets,) 388 | idx_data["id4_reduce_intm_ca"] = id4_reduce_intm_ca # (intmTriplets,) 389 | idx_data["id4_expand_intm_db"] = id4_expand_intm_db # (intmTriplets,) 390 | idx_data["id4_reduce_intm_ab"] = id4_reduce_intm_ab # (intmTriplets,) 391 | idx_data["id4_expand_intm_ab"] = id4_expand_intm_ab # (intmTriplets,) 392 | 393 | # # node indices in quadruplet 394 | # idx_c = idx_s[id4_reduce_ca] 395 | # idx_a = idx_t[id4_reduce_ca] 396 | # idx_b = idx_t[id4_expand_db] 397 | # idx_d = idx_s[id4_expand_db] 398 | # assert np.all(idx_c == idx_s[id4_reduce_intm_ca][id4_reduce_cab]) 399 | # assert np.all(idx_a == idx_t[id4_reduce_intm_ca][id4_reduce_cab]) 400 | # assert np.all(idx_a == idx_int_t[id4_reduce_intm_ab][id4_reduce_cab]) 401 | # assert np.all(idx_a == idx_int_t[id4_expand_intm_ab][id4_expand_abd]) 402 | # assert np.all(idx_b == idx_int_s[id4_reduce_intm_ab][id4_reduce_cab]) 403 | # assert np.all(idx_b == idx_int_s[id4_expand_intm_ab][id4_expand_abd]) 404 | # assert np.all(idx_b == idx_t[id4_expand_intm_db][id4_expand_abd]) 405 | # assert np.all(idx_d == idx_s[id4_expand_intm_db][id4_expand_abd]) 406 | 407 | data.update(idx_data) 408 | return self.convert_to_tensor(data) 409 | 410 | @staticmethod 411 | def get_triplets(idx_s, idx_t, edge_ids): 412 | """ 413 | Get triplets c -> a <- b 414 | """ 415 | # Edge indices of triplets k -> a -> i 416 | id3_expand_ba = edge_ids[idx_s].data.astype("int32").flatten() 417 | id3_reduce_ca = edge_ids[idx_s].tocoo().row.astype("int32").flatten() 418 | 419 | id3_i = idx_t[id3_reduce_ca] 420 | id3_k = idx_s[id3_expand_ba] 421 | mask = id3_i != id3_k 422 | id3_expand_ba = id3_expand_ba[mask] 423 | id3_reduce_ca = id3_reduce_ca[mask] 424 | 425 | return id3_expand_ba, id3_reduce_ca 426 | 427 | @staticmethod 428 | def get_quadruplets(idx_s, idx_t, adj_matrix, edge_ids, idx_int_s, idx_int_t): 429 | """ 430 | c -> a - b <- d where D_ab <= int_cutoff; D_ca & D_db <= cutoff 431 | """ 432 | # Number of incoming edges to target and source node of interaction edges 433 | nNeighbors_t = adj_matrix[idx_int_t].sum(axis=1).A1.astype("int32") 434 | nNeighbors_s = adj_matrix[idx_int_s].sum(axis=1).A1.astype("int32") 435 | id4_reduce_intm_ca = ( 436 | edge_ids[idx_int_t].data.astype("int32").flatten() 437 | ) # (intmTriplets,) 438 | id4_expand_intm_db = ( 439 | edge_ids[idx_int_s].data.astype("int32").flatten() 440 | ) # (intmTriplets,) 441 | # note that id4_reduce_intm_ca and id4_expand_intm_db have the same shape but 442 | # id4_reduce_intm_ca[i] and id4_expand_intm_db[i] may not belong to the same interacting quadruplet ! 
443 | 444 | # each reduce edge (c->a) has to be repeated as often as there are neighbors for node b 445 | # vice verca for the edges of the source node (d->b) and node a 446 | id4_reduce_cab = DataContainer.repeat_blocks( 447 | nNeighbors_t, nNeighbors_s 448 | ) # (nQuadruplets,) 449 | id4_reduce_ca = id4_reduce_intm_ca[id4_reduce_cab] # intmTriplets -> nQuadruplets 450 | 451 | N = np.repeat(nNeighbors_t, nNeighbors_s) 452 | id4_expand_abd = np.repeat( 453 | np.arange(len(id4_expand_intm_db)), N 454 | ) # (nQuadruplets,) 455 | id4_expand_db = id4_expand_intm_db[id4_expand_abd] # intmTriplets -> nQuadruplets 456 | 457 | id4_reduce_intm_ab = np.repeat( 458 | np.arange(len(idx_int_t)), nNeighbors_t 459 | ) # (intmTriplets,) 460 | id4_expand_intm_ab = np.repeat( 461 | np.arange(len(idx_int_t)), nNeighbors_s 462 | ) # (intmTriplets,) 463 | 464 | # Mask out all quadruplets where nodes appear more than once 465 | idx_c = idx_s[id4_reduce_ca] 466 | idx_a = idx_t[id4_reduce_ca] 467 | idx_b = idx_t[id4_expand_db] 468 | idx_d = idx_s[id4_expand_db] 469 | 470 | mask1 = idx_c != idx_b 471 | mask2 = idx_a != idx_d 472 | mask3 = idx_c != idx_d 473 | mask = mask1 * mask2 * mask3 # logical and 474 | 475 | id4_reduce_ca = id4_reduce_ca[mask] 476 | id4_expand_db = id4_expand_db[mask] 477 | id4_reduce_cab = id4_reduce_cab[mask] 478 | id4_expand_abd = id4_expand_abd[mask] 479 | 480 | return ( 481 | id4_reduce_ca, 482 | id4_expand_db, 483 | id4_reduce_cab, 484 | id4_expand_abd, 485 | id4_reduce_intm_ca, 486 | id4_expand_intm_db, 487 | id4_reduce_intm_ab, 488 | id4_expand_intm_ab, 489 | ) 490 | 491 | def convert_to_tensor(self, data): 492 | for key in data: 493 | data[key] = torch.tensor(data[key], dtype=self.dtypes[key]) 494 | return data 495 | 496 | def get_dtypes(self): 497 | """ 498 | Returns 499 | ------- 500 | dtypes: tuple 501 | (dtypes_input, dtypes_target) TF input types for the inputs and targets 502 | stored in dicts. 503 | """ 504 | # dtypes of dataset values 505 | dtypes_input = {} 506 | if self.addID: 507 | dtypes_input["id"] = torch.int64 508 | dtypes_input["Z"] = torch.int64 509 | dtypes_input["N"] = torch.int64 510 | dtypes_input["R"] = torch.float32 511 | for key in self.index_keys: 512 | dtypes_input[key] = torch.int64 513 | 514 | dtypes_target = {} 515 | dtypes_target["E"] = torch.float32 516 | dtypes_target["F"] = torch.float32 517 | 518 | return dtypes_input, dtypes_target 519 | 520 | @staticmethod 521 | @numba.njit(nogil=True) 522 | def repeat_blocks(sizes, repeats): 523 | """Repeat blocks of indices. 
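Block i consists of sizes[i] consecutive indices, offset by the cumulative
sum of the preceding sizes, and is written out repeats[i] times.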
524 | From https://stackoverflow.com/questions/51154989/numpy-vectorized-function-to-repeat-blocks-of-consecutive-elements 525 | 526 | Examples 527 | -------- 528 | sizes = [1,3,2] ; repeats = [3,2,3] 529 | Return: [0 0 0 1 2 3 1 2 3 4 5 4 5 4 5] 530 | sizes = [0,3,2] ; repeats = [3,2,3] 531 | Return: [0 1 2 0 1 2 3 4 3 4 3 4] 532 | sizes = [2,3,2] ; repeats = [2,0,2] 533 | Return: [0 1 0 1 5 6 5 6] 534 | """ 535 | a = np.arange(np.sum(sizes)) 536 | indices = np.empty((sizes * repeats).sum(), dtype=np.int32) 537 | start = 0 538 | oi = 0 539 | for i, size in enumerate(sizes): 540 | end = start + size 541 | for _ in range(repeats[i]): 542 | oe = oi + size 543 | indices[oi:oe] = a[start:end] 544 | oi = oe 545 | start = end 546 | return indices 547 | 548 | @staticmethod 549 | @numba.njit(nogil=True) 550 | def ragged_range(sizes): 551 | """ 552 | ------- 553 | Example 554 | ------- 555 | sizes = [1,3,2] ; 556 | Return: [0 0 1 2 0 1] 557 | """ 558 | a = np.arange(sizes.max()) 559 | indices = np.empty(sizes.sum(), dtype=np.int32) 560 | start = 0 561 | for size in sizes: 562 | end = start + size 563 | indices[start:end] = a[:size] 564 | start = end 565 | return indices 566 | -------------------------------------------------------------------------------- /gemnet/training/data_provider.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import DataLoader, Subset 5 | from torch.utils.data.sampler import ( 6 | BatchSampler, 7 | SubsetRandomSampler, 8 | SequentialSampler, 9 | ) 10 | 11 | def collate(batch, target_keys): 12 | """ 13 | custom batching function because batches have variable shape 14 | """ 15 | batch = batch[0] # already batched: Batching happens in DataContainer 16 | inputs = {} 17 | targets = {} 18 | for key in batch: 19 | if key in target_keys: 20 | targets[key] = batch[key] 21 | else: 22 | inputs[key] = batch[key] 23 | return inputs, targets 24 | 25 | class DataProvider: 26 | """ 27 | Parameters 28 | ---------- 29 | data_container: DataContainer 30 | Contains the dataset. 31 | ntrain: int 32 | Number of samples in the training set. 33 | nval: int 34 | Number of samples in the validation set. 35 | batch_size: int 36 | Number of samples to process at once. 37 | seed: int 38 | Seed for drawing samples into train and val set (and shuffle). 39 | random_split: bool 40 | If True put the samples randomly into the subsets else in order. 41 | shuffle: bool 42 | If True shuffle the samples after each epoch. 43 | sample_with_replacement: bool 44 | Sample data from the dataset with replacement. 45 | split: str/dict 46 | Overwrites settings of 'ntrain', 'nval', 'random_split' and 'sample_with_replacement'. 47 | If of type dict the dictionary is assumed to contain the index split of the subsets. 48 | If split is of type str then load the index split from the .npz-file. 49 | Dict and split file are assumed to have keys 'train', 'val', 'test'. 
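Example (illustrative): split={'train': [0, 1], 'val': [2], 'test': [3, 4]}
or split='split.npz', where the .npz file stores the same three index arrays.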
50 | """ 51 | 52 | def __init__( 53 | self, 54 | data_container, 55 | ntrain: int, 56 | nval: int, 57 | batch_size: int = 1, 58 | seed: int = None, 59 | random_split: bool = False, 60 | shuffle: bool = True, 61 | sample_with_replacement: bool = False, 62 | split = None, 63 | **kwargs 64 | ): 65 | self.kwargs = kwargs 66 | self.data_container = data_container 67 | self._ndata = len(data_container) 68 | self.batch_size = batch_size 69 | self.seed = seed 70 | self.random_split = random_split 71 | self.shuffle = shuffle 72 | self.sample_with_replacement = sample_with_replacement 73 | 74 | # Random state parameter, such that random operations are reproducible if wanted 75 | self._random_state = np.random.RandomState(seed=seed) 76 | 77 | if split is None: 78 | self.nsamples, self.idx = self._random_split_data(ntrain, nval) 79 | else: 80 | self.nsamples, self.idx = self._manual_split_data(split) 81 | 82 | def _manual_split_data(self, split): 83 | 84 | if isinstance(split, (dict,str)): 85 | if isinstance(split, str): 86 | # split is the path to the file containing the indices 87 | assert split.endswith(".npz") , "'split' has to be a .npz file if 'split' is of type str" 88 | split = np.load(split) 89 | 90 | keys = ["train", "val", "test"] 91 | for key in keys: 92 | assert key in split.keys(), f"{key} is not in {[k for k in split.keys()]}" 93 | 94 | idx = {key: np.array(split[key]) for key in keys} 95 | nsamples = {key: len(idx[key]) for key in keys} 96 | 97 | return nsamples, idx 98 | 99 | else: 100 | raise TypeError("'split' has to be either of type str or dict if not None.") 101 | 102 | def _random_split_data(self, ntrain, nval): 103 | 104 | nsamples = { 105 | "train": ntrain, 106 | "val": nval, 107 | "test": self._ndata - ntrain - nval, 108 | } 109 | 110 | all_idx = np.arange(self._ndata) 111 | if self.random_split: 112 | # Shuffle indices 113 | all_idx = self._random_state.permutation(all_idx) 114 | 115 | if self.sample_with_replacement: 116 | # Sample with replacement so as to train an ensemble of Dimenets 117 | all_idx = self._random_state.choice(all_idx, self._ndata, replace=True) 118 | 119 | # Store indices of training, validation and test data 120 | idx = { 121 | "train": all_idx[0:ntrain], 122 | "val": all_idx[ntrain : ntrain + nval], 123 | "test": all_idx[ntrain + nval :], 124 | } 125 | 126 | return nsamples, idx 127 | 128 | def save_split(self, path): 129 | """ 130 | Save the split of the samples to path. 131 | Data has keys 'train', 'val', 'test'. 
132 | """ 133 | assert isinstance(path, str) 134 | assert path.endswith(".npz"), "'path' has to end with .npz" 135 | np.savez(path, **self.idx) 136 | 137 | def get_dataset(self, split, batch_size=None): 138 | assert split in self.idx 139 | if batch_size is None: 140 | batch_size = self.batch_size 141 | shuffle = self.shuffle if split == "train" else False 142 | 143 | indices = self.idx[split] 144 | if shuffle: 145 | torch_generator = torch.Generator() 146 | if self.seed is not None: 147 | torch_generator.manual_seed(self.seed) 148 | idx_sampler = SubsetRandomSampler(indices, torch_generator) 149 | dataset = self.data_container 150 | else: 151 | subset = Subset(self.data_container, indices) 152 | idx_sampler = SequentialSampler(subset) 153 | dataset = subset 154 | 155 | batch_sampler = BatchSampler( 156 | idx_sampler, batch_size=batch_size, drop_last=False 157 | ) 158 | 159 | dataloader = DataLoader( 160 | dataset, 161 | sampler=batch_sampler, 162 | collate_fn=functools.partial(collate, target_keys=self.data_container.targets), 163 | pin_memory=True, # load on CPU push to GPU 164 | **self.kwargs 165 | ) 166 | 167 | # loop infinitely 168 | # we use the generator as the rest of the code is based on steps and not epochs 169 | def generator(): 170 | while True: 171 | for inputs, targets in dataloader: 172 | yield inputs, targets 173 | 174 | return generator() 175 | -------------------------------------------------------------------------------- /gemnet/training/ema_decay.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from: 3 | https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py 4 | """ 5 | 6 | from __future__ import division 7 | from __future__ import unicode_literals 8 | 9 | from typing import Iterable, Optional 10 | import weakref 11 | import copy 12 | 13 | import torch 14 | 15 | 16 | # Partially based on: 17 | # https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 18 | class ExponentialMovingAverage: 19 | """ 20 | Maintains (exponential) moving average of a set of parameters. 21 | 22 | Args: 23 | parameters: Iterable of `torch.nn.Parameter` (typically from 24 | `model.parameters()`). 25 | decay: The exponential decay. 26 | use_num_updates: Whether to use number of updates when computing 27 | averages. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | parameters: Iterable[torch.nn.Parameter], 33 | decay: float, 34 | use_num_updates: bool = False, 35 | ): 36 | if decay < 0.0 or decay > 1.0: 37 | raise ValueError("Decay must be between 0 and 1") 38 | self.decay = decay 39 | self.num_updates = 0 if use_num_updates else None 40 | parameters = list(parameters) 41 | self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] 42 | self.collected_params = [] 43 | # By maintaining only a weakref to each parameter, 44 | # we maintain the old GC behaviour of ExponentialMovingAverage: 45 | # if the model goes out of scope but the ExponentialMovingAverage 46 | # is kept, no references to the model or its parameters will be 47 | # maintained, and the model will be cleaned up. 
48 | self._params_refs = [weakref.ref(p) for p in parameters] 49 | 50 | def _get_parameters( 51 | self, parameters: Optional[Iterable[torch.nn.Parameter]] 52 | ) -> Iterable[torch.nn.Parameter]: 53 | if parameters is None: 54 | parameters = [p() for p in self._params_refs] 55 | if any(p is None for p in parameters): 56 | raise ValueError( 57 | "(One of) the parameters with which this " 58 | "ExponentialMovingAverage " 59 | "was initialized no longer exists (was garbage collected);" 60 | " please either provide `parameters` explicitly or keep " 61 | "the model to which they belong from being garbage " 62 | "collected." 63 | ) 64 | return parameters 65 | else: 66 | return parameters 67 | 68 | def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: 69 | """ 70 | Update currently maintained parameters. 71 | 72 | Call this every time the parameters are updated, such as the result of 73 | the `optimizer.step()` call. 74 | 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 77 | parameters used to initialize this object. If `None`, the 78 | parameters with which this `ExponentialMovingAverage` was 79 | initialized will be used. 80 | """ 81 | parameters = self._get_parameters(parameters) 82 | decay = self.decay 83 | if self.num_updates is not None: 84 | self.num_updates += 1 85 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 86 | one_minus_decay = 1.0 - decay 87 | with torch.no_grad(): 88 | parameters = [p for p in parameters if p.requires_grad] 89 | for s_param, param in zip(self.shadow_params, parameters): 90 | tmp = s_param - param 91 | # tmp will be a new tensor so we can do in-place 92 | tmp.mul_(one_minus_decay) 93 | s_param.sub_(tmp) 94 | 95 | def copy_to( 96 | self, parameters: Optional[Iterable[torch.nn.Parameter]] = None 97 | ) -> None: 98 | """ 99 | Copy current parameters into given collection of parameters. 100 | 101 | Args: 102 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 103 | updated with the stored moving averages. If `None`, the 104 | parameters with which this `ExponentialMovingAverage` was 105 | initialized will be used. 106 | """ 107 | parameters = self._get_parameters(parameters) 108 | for s_param, param in zip(self.shadow_params, parameters): 109 | if param.requires_grad: 110 | param.data.copy_(s_param.data) 111 | 112 | def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: 113 | """ 114 | Save the current parameters for restoring later. 115 | 116 | Args: 117 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 118 | temporarily stored. If `None`, the parameters of with which this 119 | `ExponentialMovingAverage` was initialized will be used. 120 | """ 121 | parameters = self._get_parameters(parameters) 122 | self.collected_params = [ 123 | param.clone() for param in parameters if param.requires_grad 124 | ] 125 | 126 | def restore( 127 | self, parameters: Optional[Iterable[torch.nn.Parameter]] = None 128 | ) -> None: 129 | """ 130 | Restore the parameters stored with the `store` method. 131 | Useful to validate the model with EMA parameters without affecting the 132 | original optimization process. Store the parameters before the 133 | `copy_to` method. After validation (or model saving), use this to 134 | restore the former parameters. 135 | 136 | Args: 137 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 138 | updated with the stored parameters. 
If `None`, the 139 | parameters with which this `ExponentialMovingAverage` was 140 | initialized will be used. 141 | """ 142 | parameters = self._get_parameters(parameters) 143 | for c_param, param in zip(self.collected_params, parameters): 144 | if param.requires_grad: 145 | param.data.copy_(c_param.data) 146 | 147 | def state_dict(self) -> dict: 148 | r"""Returns the state of the ExponentialMovingAverage as a dict.""" 149 | # Following PyTorch conventions, references to tensors are returned: 150 | # "returns a reference to the state and not its copy!" - 151 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict 152 | return { 153 | "decay": self.decay, 154 | "num_updates": self.num_updates, 155 | "shadow_params": self.shadow_params, 156 | "collected_params": self.collected_params, 157 | } 158 | 159 | def load_state_dict(self, state_dict: dict) -> None: 160 | r"""Loads the ExponentialMovingAverage state. 161 | 162 | Args: 163 | state_dict (dict): EMA state. Should be an object returned 164 | from a call to :meth:`state_dict`. 165 | """ 166 | # deepcopy, to be consistent with module API 167 | state_dict = copy.deepcopy(state_dict) 168 | self.decay = state_dict["decay"] 169 | if self.decay < 0.0 or self.decay > 1.0: 170 | raise ValueError("Decay must be between 0 and 1") 171 | self.num_updates = state_dict["num_updates"] 172 | assert self.num_updates is None or isinstance( 173 | self.num_updates, int 174 | ), "Invalid num_updates" 175 | self.shadow_params = state_dict["shadow_params"] 176 | assert isinstance(self.shadow_params, list), "shadow_params must be a list" 177 | assert all( 178 | isinstance(p, torch.Tensor) for p in self.shadow_params 179 | ), "shadow_params must all be Tensors" 180 | self.collected_params = state_dict["collected_params"] 181 | assert isinstance( 182 | self.collected_params, list 183 | ), "collected_params must be a list" 184 | assert all( 185 | isinstance(p, torch.Tensor) for p in self.collected_params 186 | ), "collected_params must all be Tensors" 187 | -------------------------------------------------------------------------------- /gemnet/training/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import logging 4 | 5 | 6 | class BestMetrics: 7 | """Class for saving the metrics. 8 | 9 | Parameters 10 | ---------- 11 | path: str 12 | Directory where to save the results in. 13 | metic: Metrics 14 | instance to save the best state of. 15 | assert_exist: bool 16 | If True raise UserWarning if the metrics should be restored but None are found. 17 | If False log Warning and frehsly initilaize the metrics. 18 | """ 19 | 20 | def __init__(self, path, metrics, assert_exist=True): 21 | self.path = os.path.join(path, "best_metrics.npz") 22 | self.metrics = metrics 23 | self.assert_exist = assert_exist 24 | self.state = {} 25 | 26 | def inititalize(self): 27 | self.state = {f"{k}_{self.metrics.tag}": np.inf for k in self.metrics.keys} 28 | self.state["step"] = 0 29 | np.savez(self.path, **self.state) 30 | 31 | def restore(self): 32 | if not os.path.isfile(self.path): 33 | string = f"Best metrics can not be restored as the file does not exist in the given path: {self.path}" 34 | if self.assert_exist: 35 | raise UserWarning(string) 36 | 37 | string += "\n Will initialize the best metrics." 
38 | logging.warning(string) 39 | self.inititalize() 40 | else: 41 | loss_file = np.load(self.path) 42 | self.state = {k: v.item() for k, v in loss_file.items()} 43 | 44 | def items(self): 45 | return self.state.items() 46 | 47 | def update(self, step, metrics): 48 | self.state["step"] = step 49 | self.state.update(metrics.result()) 50 | np.savez(self.path, **self.state) 51 | 52 | def write(self, summary_writer, step): 53 | for key, val in self.state.items(): 54 | if key != "step": 55 | summary_writer.add_scalar(key + "_best", val, step) 56 | 57 | @property 58 | def loss(self): 59 | return self.state["loss_val"] 60 | 61 | @property 62 | def step(self): 63 | return self.state["step"] 64 | 65 | 66 | class MeanMetric: 67 | def __init__(self): 68 | self.reset_states() 69 | 70 | def update_state(self, values, sample_weight): 71 | self.values += sample_weight * values 72 | self.sample_weights += sample_weight 73 | 74 | def result(self): 75 | return self.values / self.sample_weights 76 | 77 | def reset_states(self): 78 | self.sample_weights = 0 79 | self.values = 0 80 | 81 | 82 | class Metrics: 83 | """Class for saving the metrics. 84 | 85 | Parameters 86 | ---------- 87 | tag: str 88 | Tag to add to the metric (e.g 'train' or 'val'). 89 | keys: list 90 | Name of the different metrics to watch (e.g. 'loss', 'mae' etc) 91 | ex: sacred.Eperiment 92 | Sacred experiment that keeps track of the metrics. 93 | """ 94 | 95 | def __init__(self, tag, keys, ex=None): 96 | self.tag = tag 97 | self.keys = keys 98 | self.ex = ex 99 | 100 | assert "loss" in self.keys 101 | self.mean_metrics = {} 102 | for key in self.keys: 103 | self.mean_metrics[key] = MeanMetric() 104 | 105 | def update_state(self, nsamples, **updates): 106 | """Update the metrics. 107 | 108 | Parameters 109 | ---------- 110 | nsamples: int 111 | Number of samples for which the updates where calculated on. 112 | updates: dict 113 | Contains metric updates. 114 | """ 115 | assert set(updates.keys()).issubset(set(self.keys)) 116 | for key in updates: 117 | self.mean_metrics[key].update_state( 118 | updates[key].cpu(), sample_weight=nsamples 119 | ) 120 | 121 | def write(self, summary_writer, step): 122 | """Write metrics to summary_writer (and the Sacred experiment).""" 123 | for key, val in self.result().items(): 124 | summary_writer.add_scalar(key, val, global_step=step) 125 | if self.ex is not None: 126 | if key not in self.ex.current_run.info: 127 | self.ex.current_run.info[key] = [] 128 | self.ex.current_run.info[key].append(val) 129 | 130 | if self.ex is not None: 131 | if f"step_{self.tag}" not in self.ex.current_run.info: 132 | self.ex.current_run.info[f"step_{self.tag}"] = [] 133 | self.ex.current_run.info[f"step_{self.tag}"].append(step) 134 | 135 | def reset_states(self): 136 | for key in self.keys: 137 | self.mean_metrics[key].reset_states() 138 | 139 | def result(self, append_tag=True): 140 | """ 141 | Parameters 142 | ---------- 143 | append_tag: bool 144 | If True append the tag to the key of the returned dict 145 | 146 | Returns 147 | ------- 148 | result_dict: dict 149 | Contains the numpy values of the metrics. 
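e.g. {'loss_val': 0.017, 'mae_val': 0.012} for tag='val' and
keys=['loss', 'mae'] when append_tag is True (values purely illustrative).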
150 | """ 151 | result_dict = {} 152 | for key in self.keys: 153 | result_key = f"{key}_{self.tag}" if append_tag else key 154 | result_dict[result_key] = self.mean_metrics[key].result().numpy().item() 155 | return result_dict 156 | 157 | @property 158 | def loss(self): 159 | return self.mean_metrics["loss"].result().numpy().item() 160 | -------------------------------------------------------------------------------- /gemnet/training/schedules.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | 3 | class LinearWarmupExponentialDecay(LambdaLR): 4 | """This schedule combines a linear warmup with an exponential decay. 5 | 6 | Parameters 7 | ---------- 8 | optimizer: Optimizer 9 | Optimizer instance. 10 | decay_steps: float 11 | Number of steps until learning rate reaches learning_rate*decay_rate 12 | decay_rate: float 13 | Decay rate. 14 | warmup_steps: int 15 | Total number of warmup steps of the learning rate schedule. 16 | staircase: bool 17 | If True use staircase decay and not (continous) exponential decay. 18 | last_step: int 19 | Only needed when resuming training to resume learning rate schedule at this step. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | optimizer, 25 | warmup_steps, 26 | decay_steps, 27 | decay_rate, 28 | staircase=False, 29 | last_step=-1, 30 | verbose=False, 31 | ): 32 | assert decay_rate <= 1 33 | 34 | if warmup_steps == 0: 35 | warmup_steps = 1 36 | 37 | def lr_lambda(step): 38 | # step starts at 0 39 | warmup = min(1 / warmup_steps + 1 / warmup_steps * step, 1) 40 | exponent = step / decay_steps 41 | if staircase: 42 | exponent = int(exponent) 43 | decay = decay_rate ** exponent 44 | return warmup * decay 45 | 46 | super().__init__(optimizer, lr_lambda, last_epoch=last_step, verbose=verbose) 47 | -------------------------------------------------------------------------------- /predict.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\r\n", 10 | "import os\r\n", 11 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"1\"\r\n", 12 | "os.environ[\"AUTOGRAPH_VERBOSITY\"] = \"1\"\r\n", 13 | "\r\n", 14 | "# Set up logger\r\n", 15 | "import logging\r\n", 16 | "logger = logging.getLogger()\r\n", 17 | "logger.handlers = []\r\n", 18 | "ch = logging.StreamHandler()\r\n", 19 | "formatter = logging.Formatter(\r\n", 20 | " fmt=\"%(asctime)s (%(levelname)s): %(message)s\", datefmt=\"%Y-%m-%d %H:%M:%S\"\r\n", 21 | ")\r\n", 22 | "ch.setFormatter(formatter)\r\n", 23 | "logger.addHandler(ch)\r\n", 24 | "logger.setLevel(\"INFO\")\r\n", 25 | "\r\n", 26 | "import tensorflow as tf\r\n", 27 | "# TensorFlow logging verbosity\r\n", 28 | "tf.get_logger().setLevel(\"WARN\")\r\n", 29 | "tf.autograph.set_verbosity(1)\r\n", 30 | "\r\n", 31 | "# GemNet imports\r\n", 32 | "from gemnet.model.gemnet import GemNet\r\n", 33 | "from gemnet.training.data_container import DataContainer" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "# Custom molecule class to use molecules from ase" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "class Molecule(DataContainer):\r\n", 50 | " \"\"\"\r\n", 51 | " Implements the DataContainer but for a single molecule. 
Requires custom init method.\r\n", 52 | " \"\"\"\r\n", 53 | " def __init__(self, R, Z, cutoff, int_cutoff, triplets_only=False):\r\n", 54 | " self.index_keys = [\r\n", 55 | " \"batch_seg\",\r\n", 56 | " \"id_undir\",\r\n", 57 | " \"id_swap\",\r\n", 58 | " \"id_c\",\r\n", 59 | " \"id_a\",\r\n", 60 | " \"id3_expand_ba\",\r\n", 61 | " \"id3_reduce_ca\",\r\n", 62 | " \"Kidx3\",\r\n", 63 | " ]\r\n", 64 | " if not triplets_only:\r\n", 65 | " self.index_keys += [\r\n", 66 | " \"id4_int_b\",\r\n", 67 | " \"id4_int_a\",\r\n", 68 | " \"id4_reduce_ca\",\r\n", 69 | " \"id4_expand_db\",\r\n", 70 | " \"id4_reduce_cab\",\r\n", 71 | " \"id4_expand_abd\",\r\n", 72 | " \"Kidx4\",\r\n", 73 | " \"id4_reduce_intm_ca\",\r\n", 74 | " \"id4_expand_intm_db\",\r\n", 75 | " \"id4_reduce_intm_ab\",\r\n", 76 | " \"id4_expand_intm_ab\",\r\n", 77 | " ]\r\n", 78 | " self.triplets_only = triplets_only\r\n", 79 | " self.cutoff = cutoff\r\n", 80 | " self.int_cutoff = int_cutoff\r\n", 81 | " self.keys = [\"N\", \"Z\", \"R\", \"F\", \"E\"]\r\n", 82 | "\r\n", 83 | " assert R.shape == (len(Z), 3)\r\n", 84 | " self.R = R\r\n", 85 | " self.Z = Z\r\n", 86 | " self.N = np.array([len(Z)], dtype=np.int32)\r\n", 87 | " self.E = np.zeros(1, dtype=np.float32).reshape(1, 1)\r\n", 88 | " self.F = np.zeros((len(Z), 3), dtype=np.float32)\r\n", 89 | "\r\n", 90 | " self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])\r\n", 91 | " self.addID = False\r\n", 92 | " self.dtypes, dtypes2 = self.get_dtypes()\r\n", 93 | " self.dtypes.update(dtypes2) # merge all dtypes in single dict\r\n", 94 | "\r\n", 95 | " def get(self):\r\n", 96 | " \"\"\"\r\n", 97 | " Get the molecule representation in the expected format for the GemNet model.\r\n", 98 | " \"\"\"\r\n", 99 | " data = self.__getitem__(0)\r\n", 100 | " for var in [\"E\", \"F\"]:\r\n", 101 | " data.pop(var) # not needed i.e.e not kown -> want to calculate this\r\n", 102 | " return data" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "# Setup the model and the data" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "# Model setup\r\n", 119 | "scale_file = \"./scaling_factors.json\"\r\n", 120 | "pytorch_weights_file = \"./pretrained/best/model.pth\"\r\n", 121 | "# depends on GemNet model that is loaded\r\n", 122 | "triplets_only = False\r\n", 123 | "direct_forces = False\r\n", 124 | "cutoff = 5.0\r\n", 125 | "int_cutoff = 10.0\r\n", 126 | "\r\n", 127 | "# Data setup\r\n", 128 | "from ase.build import molecule as ase_molecule_db\r\n", 129 | "\r\n", 130 | "mol = ase_molecule_db('C7NH5')\r\n", 131 | "R = mol.get_positions()\r\n", 132 | "Z = mol.get_atomic_numbers()\r\n", 133 | "\r\n", 134 | "molecule = Molecule(\r\n", 135 | " R, Z, cutoff=cutoff, int_cutoff=int_cutoff, triplets_only=triplets_only\r\n", 136 | ")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "model = GemNet(\r\n", 146 | " num_spherical=7,\r\n", 147 | " num_radial=6,\r\n", 148 | " num_blocks=4,\r\n", 149 | " emb_size_atom=128,\r\n", 150 | " emb_size_edge=128,\r\n", 151 | " emb_size_trip=64,\r\n", 152 | " emb_size_quad=32,\r\n", 153 | " emb_size_rbf=16,\r\n", 154 | " emb_size_cbf=16,\r\n", 155 | " emb_size_sbf=32,\r\n", 156 | " emb_size_bil_trip=64,\r\n", 157 | " emb_size_bil_quad=32,\r\n", 158 | " num_before_skip=1,\r\n", 159 | " num_after_skip=1,\r\n", 160 | " num_concat=1,\r\n", 
161 | " num_atom=2,\r\n", 162 | " num_targets=1,\r\n", 163 | " cutoff=cutoff,\r\n", 164 | " int_cutoff=int_cutoff, # no effect for GemNet-(d)T\r\n", 165 | " scale_file=scale_file,\r\n", 166 | " triplets_only=triplets_only,\r\n", 167 | " direct_forces=direct_forces,\r\n", 168 | ")\r\n", 169 | "# model.load_weights(pytorch_weights_file)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "# Run the model" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "energy, forces = model.predict(molecule.get())\r\n", 186 | "\r\n", 187 | "print(\"Energy [eV]\", energy)\r\n", 188 | "print(\"Forces [eV/°A]\", forces)" 189 | ] 190 | } 191 | ], 192 | "metadata": { 193 | "interpreter": { 194 | "hash": "6d9d58ddb04bb635eba824a3c64b6d0110bcc4c6cff8b192a6f7cbbb2bf10de4" 195 | }, 196 | "kernelspec": { 197 | "display_name": "Python 3.5.4 64-bit", 198 | "name": "python3" 199 | }, 200 | "language_info": { 201 | "name": "python", 202 | "version": "" 203 | }, 204 | "orig_nbformat": 4 205 | }, 206 | "nbformat": 4, 207 | "nbformat_minor": 2 208 | } -------------------------------------------------------------------------------- /pretrained/GemNet-Q/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TUM-DAML/gemnet_pytorch/a0164f74217155232d39c35f0bb2c016bd3f44da/pretrained/GemNet-Q/model.pth -------------------------------------------------------------------------------- /pretrained/GemNet-Q/model_kwargs.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_spherical": 7, 3 | "num_radial": 6, 4 | "num_blocks": 4, 5 | "emb_size_atom": 128, 6 | "emb_size_edge": 128, 7 | "emb_size_trip": 64, 8 | "emb_size_quad": 32, 9 | "emb_size_rbf": 16, 10 | "emb_size_cbf": 16, 11 | "emb_size_sbf": 32, 12 | "emb_size_bil_trip": 64, 13 | "emb_size_bil_quad": 32, 14 | "num_before_skip": 1, 15 | "num_after_skip": 1, 16 | "num_concat": 1, 17 | "num_atom": 2, 18 | "triplets_only": false, 19 | "num_targets": 1, 20 | "direct_forces": false, 21 | "cutoff": 5.0, 22 | "int_cutoff": 10.0, 23 | "envelope_exponent": 5, 24 | "extensive": true, 25 | "forces_coupled": false, 26 | "output_init": "HeOrthogonal", 27 | "activation": "swish", 28 | "scale_file": "scaling_factors.json" 29 | } -------------------------------------------------------------------------------- /pretrained/GemNet-T/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TUM-DAML/gemnet_pytorch/a0164f74217155232d39c35f0bb2c016bd3f44da/pretrained/GemNet-T/model.pth -------------------------------------------------------------------------------- /pretrained/GemNet-T/model_kwargs.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_spherical": 7, 3 | "num_radial": 6, 4 | "num_blocks": 4, 5 | "emb_size_atom": 128, 6 | "emb_size_edge": 128, 7 | "emb_size_trip": 64, 8 | "emb_size_quad": 32, 9 | "emb_size_rbf": 16, 10 | "emb_size_cbf": 16, 11 | "emb_size_sbf": 32, 12 | "emb_size_bil_trip": 64, 13 | "emb_size_bil_quad": 32, 14 | "num_before_skip": 1, 15 | "num_after_skip": 1, 16 | "num_concat": 1, 17 | "num_atom": 2, 18 | "triplets_only": true, 19 | "num_targets": 1, 20 | "direct_forces": false, 21 | "cutoff": 5.0, 22 | "int_cutoff": 10.0, 23 | "envelope_exponent": 5, 24 | "extensive": true, 25 | 
"forces_coupled": false, 26 | "output_init": "HeOrthogonal", 27 | "activation": "swish", 28 | "scale_file": "scaling_factors.json" 29 | } -------------------------------------------------------------------------------- /pretrained/scaling_factors.json: -------------------------------------------------------------------------------- 1 | { 2 | "QuadInteraction_1_had_rbf": 3.838575780391693, 3 | "QuadInteraction_1_had_cbf": 30.91912269592285, 4 | "QuadInteraction_1_sum_sbf": 2.015521287918091, 5 | "TripInteraction_1_had_rbf": 2.9607054591178894, 6 | "TripInteraction_1_sum_cbf": 5.57607889175415, 7 | "AtomUpdate_1_sum": 1.0634181648492813, 8 | "QuadInteraction_2_had_rbf": 3.4999656677246094, 9 | "QuadInteraction_2_had_cbf": 29.30245018005371, 10 | "QuadInteraction_2_sum_sbf": 1.9548791646957397, 11 | "TripInteraction_2_had_rbf": 3.0770468711853027, 12 | "TripInteraction_2_sum_cbf": 6.400703430175781, 13 | "AtomUpdate_2_sum": 1.023792326450348, 14 | "QuadInteraction_3_had_rbf": 3.506244122982025, 15 | "QuadInteraction_3_had_cbf": 30.250303268432617, 16 | "QuadInteraction_3_sum_sbf": 1.9496761560440063, 17 | "TripInteraction_3_had_rbf": 3.4999406337738037, 18 | "TripInteraction_3_sum_cbf": 5.825993537902832, 19 | "AtomUpdate_3_sum": 0.8776205033063889, 20 | "QuadInteraction_4_had_rbf": 3.420105278491974, 21 | "QuadInteraction_4_had_cbf": 30.560321807861328, 22 | "QuadInteraction_4_sum_sbf": 2.013889789581299, 23 | "TripInteraction_4_had_rbf": 3.34897518157959, 24 | "TripInteraction_4_sum_cbf": 5.816178321838379, 25 | "AtomUpdate_4_sum": 0.8766722679138184, 26 | "OutBlock_0_sum": 1.1001640558242798, 27 | "OutBlock_0_had": 3.786764442920685, 28 | "OutBlock_1_sum": 0.989106222987175, 29 | "OutBlock_1_had": 2.9567965865135193, 30 | "OutBlock_2_sum": 0.9261481463909149, 31 | "OutBlock_2_had": 2.9033637046813965, 32 | "OutBlock_3_sum": 0.8048739284276962, 33 | "OutBlock_3_had": 2.95436292886734, 34 | "OutBlock_4_sum": 0.8166412264108658, 35 | "OutBlock_4_had": 3.0642566084861755 36 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy>=1.3 3 | sympy>=1.5 4 | tqdm 5 | ase 6 | nglview 7 | numba 8 | jupyterlab 9 | torch==1.10 10 | torch-scatter 11 | -------------------------------------------------------------------------------- /scaling_factors.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": "GemNet", 3 | "QuadInteraction_1_had_rbf": 3.838575780391693, 4 | "QuadInteraction_1_had_cbf": 30.91912269592285, 5 | "QuadInteraction_1_sum_sbf": 2.015521287918091, 6 | "TripInteraction_1_had_rbf": 2.9607054591178894, 7 | "TripInteraction_1_sum_cbf": 5.57607889175415, 8 | "AtomUpdate_1_sum": 1.0634181648492813, 9 | "QuadInteraction_2_had_rbf": 3.4999656677246094, 10 | "QuadInteraction_2_had_cbf": 29.30245018005371, 11 | "QuadInteraction_2_sum_sbf": 1.9548791646957397, 12 | "TripInteraction_2_had_rbf": 3.0770468711853027, 13 | "TripInteraction_2_sum_cbf": 6.400703430175781, 14 | "AtomUpdate_2_sum": 1.023792326450348, 15 | "QuadInteraction_3_had_rbf": 3.506244122982025, 16 | "QuadInteraction_3_had_cbf": 30.250303268432617, 17 | "QuadInteraction_3_sum_sbf": 1.9496761560440063, 18 | "TripInteraction_3_had_rbf": 3.4999406337738037, 19 | "TripInteraction_3_sum_cbf": 5.825993537902832, 20 | "AtomUpdate_3_sum": 0.8776205033063889, 21 | "QuadInteraction_4_had_rbf": 3.420105278491974, 22 | "QuadInteraction_4_had_cbf": 
30.560321807861328, 23 | "QuadInteraction_4_sum_sbf": 2.013889789581299, 24 | "TripInteraction_4_had_rbf": 3.34897518157959, 25 | "TripInteraction_4_sum_cbf": 5.816178321838379, 26 | "AtomUpdate_4_sum": 0.8766722679138184, 27 | "OutBlock_0_sum": 1.1001640558242798, 28 | "OutBlock_0_had": 3.786764442920685, 29 | "OutBlock_1_sum": 0.989106222987175, 30 | "OutBlock_1_had": 2.9567965865135193, 31 | "OutBlock_2_sum": 0.9261481463909149, 32 | "OutBlock_2_had": 2.9033637046813965, 33 | "OutBlock_3_sum": 0.8048739284276962, 34 | "OutBlock_3_had": 2.95436292886734, 35 | "OutBlock_4_sum": 0.8166412264108658, 36 | "OutBlock_4_had": 3.0642566084861755 37 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("requirements.txt", "r") as f: 4 | install_requires = f.read().splitlines() 5 | 6 | setup( 7 | name="gemnet_pytorch", 8 | version="1.0", 9 | description="GemNet: Universal Directional Graph Neural Networks for Molecules", 10 | author="Johannes Gasteiger, Florian Becker, Stephan Günnemann", 11 | author_email="j.gasteiger@in.tum.de", 12 | packages=["gemnet"], 13 | install_requires=install_requires, 14 | zip_safe=False, 15 | python_requires='>=3.8', 16 | ) 17 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Set up logger\r\n", 10 | "import os\r\n", 11 | "import logging\r\n", 12 | "\r\n", 13 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"1\"\r\n", 14 | "os.environ[\"AUTOGRAPH_VERBOSITY\"] = \"1\"\r\n", 15 | "\r\n", 16 | "logger = logging.getLogger()\r\n", 17 | "logger.handlers = []\r\n", 18 | "ch = logging.StreamHandler()\r\n", 19 | "formatter = logging.Formatter(\r\n", 20 | " fmt=\"%(asctime)s (%(levelname)s): %(message)s\", datefmt=\"%Y-%m-%d %H:%M:%S\"\r\n", 21 | ")\r\n", 22 | "ch.setFormatter(formatter)\r\n", 23 | "logger.addHandler(ch)\r\n", 24 | "logger.setLevel(\"INFO\")" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import numpy as np\r\n", 34 | "import yaml\r\n", 35 | "import string\r\n", 36 | "import ast\r\n", 37 | "import random\r\n", 38 | "import time\r\n", 39 | "from datetime import datetime\r\n", 40 | "\r\n", 41 | "from gemnet.model.gemnet import GemNet\r\n", 42 | "from gemnet.training.trainer import Trainer\r\n", 43 | "from gemnet.training.metrics import Metrics, BestMetrics\r\n", 44 | "from gemnet.training.data_container import DataContainer\r\n", 45 | "from gemnet.training.data_provider import DataProvider\r\n", 46 | "\r\n", 47 | "import torch\r\n", 48 | "from torch.utils.tensorboard import SummaryWriter" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "# Load config file" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "with open('config.yaml', 'r') as c:\r\n", 65 | " config = yaml.safe_load(c)\r\n", 66 | " \r\n", 67 | "# For strings that yaml doesn't parse (e.g. 
None)\r\n", 68 | "for key, val in config.items():\r\n", 69 | " if type(val) is str:\r\n", 70 | " try:\r\n", 71 | " config[key] = ast.literal_eval(val)\r\n", 72 | " except (ValueError, SyntaxError):\r\n", 73 | " pass" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "num_spherical = config[\"num_spherical\"]\r\n", 83 | "num_radial = config[\"num_radial\"]\r\n", 84 | "num_blocks = config[\"num_blocks\"]\r\n", 85 | "emb_size_atom = config[\"emb_size_atom\"]\r\n", 86 | "emb_size_edge = config[\"emb_size_edge\"]\r\n", 87 | "emb_size_trip = config[\"emb_size_trip\"]\r\n", 88 | "emb_size_quad = config[\"emb_size_quad\"]\r\n", 89 | "emb_size_rbf = config[\"emb_size_rbf\"]\r\n", 90 | "emb_size_cbf = config[\"emb_size_cbf\"]\r\n", 91 | "emb_size_sbf = config[\"emb_size_sbf\"]\r\n", 92 | "num_before_skip = config[\"num_before_skip\"]\r\n", 93 | "num_after_skip = config[\"num_after_skip\"]\r\n", 94 | "num_concat = config[\"num_concat\"]\r\n", 95 | "num_atom = config[\"num_atom\"]\r\n", 96 | "emb_size_bil_quad = config[\"emb_size_bil_quad\"]\r\n", 97 | "emb_size_bil_trip = config[\"emb_size_bil_trip\"]\r\n", 98 | "triplets_only = config[\"triplets_only\"]\r\n", 99 | "forces_coupled = config[\"forces_coupled\"]\r\n", 100 | "direct_forces = config[\"direct_forces\"]\r\n", 101 | "mve = config[\"mve\"]\r\n", 102 | "cutoff = config[\"cutoff\"]\r\n", 103 | "int_cutoff = config[\"int_cutoff\"]\r\n", 104 | "envelope_exponent = config[\"envelope_exponent\"]\r\n", 105 | "extensive = config[\"extensive\"]\r\n", 106 | "output_init = config[\"output_init\"]\r\n", 107 | "scale_file = config[\"scale_file\"]\r\n", 108 | "data_seed = config[\"data_seed\"]\r\n", 109 | "dataset = config[\"dataset\"]\r\n", 110 | "val_dataset = config[\"val_dataset\"]\r\n", 111 | "num_train = config[\"num_train\"]\r\n", 112 | "num_val = config[\"num_val\"]\r\n", 113 | "logdir = config[\"logdir\"]\r\n", 114 | "loss = config[\"loss\"]\r\n", 115 | "tfseed = config[\"tfseed\"]\r\n", 116 | "num_steps = config[\"num_steps\"]\r\n", 117 | "rho_force = config[\"rho_force\"]\r\n", 118 | "ema_decay = config[\"ema_decay\"]\r\n", 119 | "weight_decay = config[\"weight_decay\"]\r\n", 120 | "grad_clip_max = config[\"grad_clip_max\"]\r\n", 121 | "agc = config[\"agc\"]\r\n", 122 | "decay_patience = config[\"decay_patience\"]\r\n", 123 | "decay_factor = config[\"decay_factor\"]\r\n", 124 | "decay_cooldown = config[\"decay_cooldown\"]\r\n", 125 | "batch_size = config[\"batch_size\"]\r\n", 126 | "evaluation_interval = config[\"evaluation_interval\"]\r\n", 127 | "patience = config[\"patience\"]\r\n", 128 | "save_interval = config[\"save_interval\"]\r\n", 129 | "learning_rate = config[\"learning_rate\"]\r\n", 130 | "warmup_steps = config[\"warmup_steps\"]\r\n", 131 | "decay_steps = config[\"decay_steps\"]\r\n", 132 | "decay_rate = config[\"decay_rate\"]\r\n", 133 | "staircase = config[\"staircase\"]\r\n", 134 | "restart = config[\"restart\"]\r\n", 135 | "comment = config[\"comment\"]" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "# Set paths and create directories" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "torch.manual_seed(tfseed)\r\n", 152 | "\r\n", 153 | "logging.info(\"Start training\")\r\n", 154 | "num_gpus = torch.cuda.device_count()\r\n", 155 | "cuda_available = torch.cuda.is_available()\r\n", 156 | 
"logging.info(f\"Available GPUs: {num_gpus}\")\r\n", 157 | "logging.info(f\"CUDA Available: {cuda_available}\")\r\n", 158 | "if num_gpus == 0:\r\n", 159 | " logging.warning(\"No GPUs were found. Training is run on CPU!\")\r\n", 160 | "if not cuda_available:\r\n", 161 | " logging.warning(\"CUDA unavailable. Training is run on CPU!\")\r\n", 162 | "\r\n", 163 | "# Used for creating a \"unique\" id for a run (almost impossible to generate the same twice)\r\n", 164 | "def id_generator(\r\n", 165 | " size=6, chars=string.ascii_uppercase + string.ascii_lowercase + string.digits\r\n", 166 | "):\r\n", 167 | " return \"\".join(random.SystemRandom().choice(chars) for _ in range(size))\r\n", 168 | "\r\n", 169 | "# A unique directory name is created for this run based on the input\r\n", 170 | "if (restart is None) or (restart == \"None\"):\r\n", 171 | " directory = (\r\n", 172 | " logdir\r\n", 173 | " + \"/\"\r\n", 174 | " + datetime.now().strftime(\"%Y%m%d_%H%M%S\")\r\n", 175 | " + \"_\"\r\n", 176 | " + id_generator()\r\n", 177 | " + \"_\"\r\n", 178 | " + os.path.basename(dataset)\r\n", 179 | " + \"_\"\r\n", 180 | " + str(comment)\r\n", 181 | " )\r\n", 182 | "else:\r\n", 183 | " directory = restart\r\n", 184 | "\r\n", 185 | "logging.info(f\"Directory: {directory}\")\r\n", 186 | "logging.info(\"Create directories\")\r\n", 187 | "\r\n", 188 | "if not os.path.exists(directory):\r\n", 189 | " os.makedirs(directory, exist_ok=True)\r\n", 190 | "\r\n", 191 | "best_dir = os.path.join(directory, \"best\")\r\n", 192 | "if not os.path.exists(best_dir):\r\n", 193 | " os.makedirs(best_dir)\r\n", 194 | "log_dir = os.path.join(directory, \"logs\")\r\n", 195 | "if not os.path.exists(log_dir):\r\n", 196 | " os.makedirs(log_dir)\r\n", 197 | "\r\n", 198 | "extension = \".pth\"\r\n", 199 | "log_path_model = f\"{log_dir}/model{extension}\"\r\n", 200 | "log_path_training = f\"{log_dir}/training{extension}\"\r\n", 201 | "best_path_model = f\"{best_dir}/model{extension}\"" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "# Initialize model" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "logging.info(\"Initialize model\")\r\n", 218 | "model = GemNet(\r\n", 219 | " num_spherical=num_spherical,\r\n", 220 | " num_radial=num_radial,\r\n", 221 | " num_blocks=num_blocks,\r\n", 222 | " emb_size_atom=emb_size_atom,\r\n", 223 | " emb_size_edge=emb_size_edge,\r\n", 224 | " emb_size_trip=emb_size_trip,\r\n", 225 | " emb_size_quad=emb_size_quad,\r\n", 226 | " emb_size_rbf=emb_size_rbf,\r\n", 227 | " emb_size_cbf=emb_size_cbf,\r\n", 228 | " emb_size_sbf=emb_size_sbf,\r\n", 229 | " num_before_skip=num_before_skip,\r\n", 230 | " num_after_skip=num_after_skip,\r\n", 231 | " num_concat=num_concat,\r\n", 232 | " num_atom=num_atom,\r\n", 233 | " emb_size_bil_quad=emb_size_bil_quad,\r\n", 234 | " emb_size_bil_trip=emb_size_bil_trip,\r\n", 235 | " num_targets=2 if mve else 1,\r\n", 236 | " triplets_only=triplets_only,\r\n", 237 | " direct_forces=direct_forces,\r\n", 238 | " forces_coupled=forces_coupled,\r\n", 239 | " cutoff=cutoff,\r\n", 240 | " int_cutoff=int_cutoff,\r\n", 241 | " envelope_exponent=envelope_exponent,\r\n", 242 | " activation=\"swish\",\r\n", 243 | " extensive=extensive,\r\n", 244 | " output_init=output_init,\r\n", 245 | " scale_file=scale_file,\r\n", 246 | ")\r\n", 247 | "# push to GPU if available\r\n", 248 | "device = torch.device(\"cuda\" if torch.cuda.is_available() 
else \"cpu\")\r\n", 249 | "model.to(device)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "# Load dataset" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "train = {}\r\n", 266 | "validation = {}\r\n", 267 | "\r\n", 268 | "logging.info(\"Load dataset\")\r\n", 269 | "data_container = DataContainer(\r\n", 270 | " dataset, cutoff=cutoff, int_cutoff=int_cutoff, triplets_only=triplets_only\r\n", 271 | ")\r\n", 272 | "\r\n", 273 | "if val_dataset is not None:\r\n", 274 | " # Initialize DataProvider\r\n", 275 | " if num_train == 0:\r\n", 276 | " num_train = len(data_container)\r\n", 277 | " logging.info(f\"Training data size: {num_train}\")\r\n", 278 | " data_provider = DataProvider(\r\n", 279 | " data_container,\r\n", 280 | " num_train,\r\n", 281 | " 0,\r\n", 282 | " batch_size,\r\n", 283 | " seed=data_seed,\r\n", 284 | " shuffle=True,\r\n", 285 | " random_split=True,\r\n", 286 | " )\r\n", 287 | "\r\n", 288 | " # Initialize validation datasets\r\n", 289 | " val_data_container = DataContainer(\r\n", 290 | " val_dataset,\r\n", 291 | " cutoff=cutoff,\r\n", 292 | " int_cutoff=int_cutoff,\r\n", 293 | " triplets_only=triplets_only,\r\n", 294 | " )\r\n", 295 | " if num_val == 0:\r\n", 296 | " num_val = len(val_data_container)\r\n", 297 | " logging.info(f\"Validation data size: {num_val}\")\r\n", 298 | " val_data_provider = DataProvider(\r\n", 299 | " val_data_container,\r\n", 300 | " 0,\r\n", 301 | " num_val,\r\n", 302 | " batch_size,\r\n", 303 | " seed=data_seed,\r\n", 304 | " shuffle=True,\r\n", 305 | " random_split=True,\r\n", 306 | " )\r\n", 307 | "else:\r\n", 308 | " # Initialize DataProvider (splits dataset into 3 sets based on data_seed and provides tf.datasets)\r\n", 309 | " logging.info(f\"Training data size: {num_train}\")\r\n", 310 | " logging.info(f\"Validation data size: {num_val}\")\r\n", 311 | " assert num_train > 0\r\n", 312 | " assert num_val > 0\r\n", 313 | " data_provider = DataProvider(\r\n", 314 | " data_container,\r\n", 315 | " num_train,\r\n", 316 | " num_val,\r\n", 317 | " batch_size,\r\n", 318 | " seed=data_seed,\r\n", 319 | " shuffle=True,\r\n", 320 | " random_split=True,\r\n", 321 | " )\r\n", 322 | " val_data_provider = data_provider\r\n", 323 | "\r\n", 324 | "# Initialize datasets\r\n", 325 | "train[\"dataset_iter\"] = data_provider.get_dataset(\"train\")\r\n", 326 | "validation[\"dataset_iter\"] = val_data_provider.get_dataset(\"val\")" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "# Prepare training" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "logging.info(\"Prepare training\")\r\n", 343 | "# Initialize trainer\r\n", 344 | "trainer = Trainer(\r\n", 345 | " model,\r\n", 346 | " learning_rate=learning_rate,\r\n", 347 | " decay_steps=decay_steps,\r\n", 348 | " decay_rate=decay_rate,\r\n", 349 | " warmup_steps=warmup_steps,\r\n", 350 | " weight_decay=weight_decay,\r\n", 351 | " ema_decay=ema_decay,\r\n", 352 | " decay_patience=decay_patience,\r\n", 353 | " decay_factor=decay_factor,\r\n", 354 | " decay_cooldown=decay_cooldown,\r\n", 355 | " grad_clip_max=grad_clip_max,\r\n", 356 | " rho_force=rho_force,\r\n", 357 | " mve=mve,\r\n", 358 | " loss=loss,\r\n", 359 | " staircase=staircase,\r\n", 360 | " agc=agc,\r\n", 361 | ")\r\n", 362 | "\r\n", 363 | "# 
Initialize metrics\r\n", 364 | "train[\"metrics\"] = Metrics(\"train\", trainer.tracked_metrics)\r\n", 365 | "validation[\"metrics\"] = Metrics(\"val\", trainer.tracked_metrics)\r\n", 366 | "\r\n", 367 | "# Save/load best recorded loss (only the best model is saved)\r\n", 368 | "metrics_best = BestMetrics(best_dir, validation[\"metrics\"])\r\n", 369 | "\r\n", 370 | "# Set up checkpointing\r\n", 371 | "# Restore latest checkpoint\r\n", 372 | "if os.path.exists(log_path_model):\r\n", 373 | " logging.info(\"Restoring model and trainer\")\r\n", 374 | " model_checkpoint = torch.load(log_path_model)\r\n", 375 | " model.load_state_dict(model_checkpoint[\"model\"])\r\n", 376 | "\r\n", 377 | " train_checkpoint = torch.load(log_path_training)\r\n", 378 | " trainer.load_state_dict(train_checkpoint[\"trainer\"])\r\n", 379 | " # restore the best saved results\r\n", 380 | " metrics_best.restore()\r\n", 381 | " logging.info(f\"Restored best metrics: {metrics_best.loss}\")\r\n", 382 | " step_init = int(train_checkpoint[\"step\"])\r\n", 383 | "else:\r\n", 384 | " logging.info(\"Freshly initialize model\")\r\n", 385 | " metrics_best.inititalize()\r\n", 386 | " step_init = 0" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "# Training loop" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [ 402 | "summary_writer = SummaryWriter(log_dir)\r\n", 403 | "steps_per_epoch = int(np.ceil(num_train / batch_size))\r\n", 404 | "\r\n", 405 | "for step in range(step_init + 1, num_steps + 1):\r\n", 406 | "\r\n", 407 | " # keep track of the learning rate\r\n", 408 | " if step % 10 == 0:\r\n", 409 | " lr = trainer.schedulers[0].get_last_lr()[0]\r\n", 410 | " summary_writer.add_scalar(\"lr\", lr, global_step=step)\r\n", 411 | "\r\n", 412 | " # Perform training step\r\n", 413 | " trainer.train_on_batch(train[\"dataset_iter\"], train[\"metrics\"])\r\n", 414 | "\r\n", 415 | " # Save progress\r\n", 416 | " if step % save_interval == 0:\r\n", 417 | " torch.save({\"model\": model.state_dict()}, log_path_model)\r\n", 418 | " torch.save(\r\n", 419 | " {\"trainer\": trainer.state_dict(), \"step\": step}, log_path_training\r\n", 420 | " )\r\n", 421 | "\r\n", 422 | " # Check performance on the validation set\r\n", 423 | " if step % evaluation_interval == 0:\r\n", 424 | "\r\n", 425 | " # Save backup variables and load averaged variables\r\n", 426 | " trainer.save_variable_backups()\r\n", 427 | " trainer.load_averaged_variables()\r\n", 428 | "\r\n", 429 | " # Compute averages\r\n", 430 | " for i in range(int(np.ceil(num_val / batch_size))):\r\n", 431 | " trainer.test_on_batch(validation[\"dataset_iter\"], validation[\"metrics\"])\r\n", 432 | "\r\n", 433 | " # Update and save best result\r\n", 434 | " if validation[\"metrics\"].loss < metrics_best.loss:\r\n", 435 | " metrics_best.update(step, validation[\"metrics\"])\r\n", 436 | " torch.save(model.state_dict(), best_path_model)\r\n", 437 | "\r\n", 438 | " # write to summary writer\r\n", 439 | " metrics_best.write(summary_writer, step)\r\n", 440 | "\r\n", 441 | " epoch = step // steps_per_epoch\r\n", 442 | " train_metrics_res = train[\"metrics\"].result(append_tag=False)\r\n", 443 | " val_metrics_res = validation[\"metrics\"].result(append_tag=False)\r\n", 444 | " metrics_strings = [\r\n", 445 | " f\"{key}: train={train_metrics_res[key]:.6f}, val={val_metrics_res[key]:.6f}\"\r\n", 446 | " for key in validation[\"metrics\"].keys\r\n", 447 | " 
]\r\n", 448 | " logging.info(\r\n", 449 | " f\"{step}/{num_steps} (epoch {epoch}): \" + \"; \".join(metrics_strings)\r\n", 450 | " )\r\n", 451 | "\r\n", 452 | " # decay learning rate on plateau\r\n", 453 | " trainer.decay_maybe(validation[\"metrics\"].loss)\r\n", 454 | "\r\n", 455 | " train[\"metrics\"].write(summary_writer, step)\r\n", 456 | " validation[\"metrics\"].write(summary_writer, step)\r\n", 457 | " train[\"metrics\"].reset_states()\r\n", 458 | " validation[\"metrics\"].reset_states()\r\n", 459 | "\r\n", 460 | " # Restore backup variables\r\n", 461 | " trainer.restore_variable_backups()\r\n", 462 | "\r\n", 463 | " # early stopping\r\n", 464 | " if step - metrics_best.step > patience * evaluation_interval:\r\n", 465 | " break\r\n", 466 | "\r\n", 467 | "result = {key + \"_best\": val for key, val in metrics_best.items()}" 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": {}, 473 | "source": [ 474 | "# Print results" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "for key, val in metrics_best.items():\r\n", 484 | " print(f\"{key}: {val}\")\r\n" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [] 493 | } 494 | ], 495 | "metadata": { 496 | "interpreter": { 497 | "hash": "6d9d58ddb04bb635eba824a3c64b6d0110bcc4c6cff8b192a6f7cbbb2bf10de4" 498 | }, 499 | "kernelspec": { 500 | "display_name": "Python 3.5.4 64-bit", 501 | "name": "python3" 502 | }, 503 | "language_info": { 504 | "name": "python", 505 | "version": "" 506 | }, 507 | "orig_nbformat": 4 508 | }, 509 | "nbformat": 4, 510 | "nbformat_minor": 2 511 | } -------------------------------------------------------------------------------- /train_seml.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" 5 | os.environ["AUTOGRAPH_VERBOSITY"] = "1" 6 | import logging 7 | import string 8 | import random 9 | import time 10 | from datetime import datetime 11 | 12 | from gemnet.model.gemnet import GemNet 13 | from gemnet.training.trainer import Trainer 14 | from gemnet.training.metrics import Metrics, BestMetrics 15 | from gemnet.training.data_container import DataContainer 16 | from gemnet.training.data_provider import DataProvider 17 | 18 | from sacred import Experiment 19 | import torch 20 | from torch.utils.tensorboard import SummaryWriter 21 | import seml 22 | 23 | ex = Experiment() 24 | seml.setup_logger(ex) 25 | 26 | 27 | @ex.post_run_hook 28 | def collect_stats(_run): 29 | seml.collect_exp_stats(_run) 30 | 31 | 32 | @ex.config 33 | def config(): 34 | overwrite = None 35 | db_collection = None 36 | if db_collection is not None: 37 | ex.observers.append( 38 | seml.create_mongodb_observer(db_collection, overwrite=overwrite) 39 | ) 40 | 41 | 42 | @ex.automain 43 | def run( 44 | num_spherical, 45 | num_radial, 46 | num_blocks, 47 | emb_size_atom, 48 | emb_size_edge, 49 | emb_size_trip, 50 | emb_size_quad, 51 | emb_size_rbf, 52 | emb_size_cbf, 53 | emb_size_sbf, 54 | num_before_skip, 55 | num_after_skip, 56 | num_concat, 57 | num_atom, 58 | emb_size_bil_quad, 59 | emb_size_bil_trip, 60 | triplets_only, 61 | forces_coupled, 62 | direct_forces, 63 | mve, 64 | cutoff, 65 | int_cutoff, 66 | envelope_exponent, 67 | extensive, 68 | output_init, 69 | scale_file, 70 | data_seed, 71 | dataset, 72 | val_dataset, 73 | 
num_train, 74 | num_val, 75 | logdir, 76 | loss, 77 | tfseed, 78 | num_steps, 79 | rho_force, 80 | ema_decay, 81 | weight_decay, 82 | grad_clip_max, 83 | agc, 84 | decay_patience, 85 | decay_factor, 86 | decay_cooldown, 87 | batch_size, 88 | evaluation_interval, 89 | patience, 90 | save_interval, 91 | learning_rate, 92 | warmup_steps, 93 | decay_steps, 94 | decay_rate, 95 | staircase, 96 | restart, 97 | comment, 98 | ): 99 | 100 | torch.manual_seed(tfseed) 101 | 102 | logging.info("Start training") 103 | # log hyperparameters 104 | logging.info( 105 | "Hyperparams: \n" + "\n".join(f"{key}: {val}" for key, val in locals().items()) 106 | ) 107 | num_gpus = torch.cuda.device_count() 108 | cuda_available = torch.cuda.is_available() 109 | logging.info(f"Available GPUs: {num_gpus}") 110 | logging.info(f"CUDA Available: {cuda_available}") 111 | if num_gpus == 0: 112 | logging.warning("No GPUs were found. Training is run on CPU!") 113 | if not cuda_available: 114 | logging.warning("CUDA unavailable. Training is run on CPU!") 115 | 116 | # Used for creating a "unique" id for a run (almost impossible to generate the same twice) 117 | def id_generator( 118 | size=6, chars=string.ascii_uppercase + string.ascii_lowercase + string.digits 119 | ): 120 | return "".join(random.SystemRandom().choice(chars) for _ in range(size)) 121 | 122 | # A unique directory name is created for this run based on the input 123 | if (restart is None) or (restart == "None"): 124 | directory = ( 125 | logdir 126 | + "/" 127 | + datetime.now().strftime("%Y%m%d_%H%M%S") 128 | + "_" 129 | + id_generator() 130 | + "_" 131 | + os.path.basename(dataset) 132 | + "_" 133 | + str(comment) 134 | ) 135 | else: 136 | directory = restart 137 | 138 | logging.info(f"Directory: {directory}") 139 | logging.info("Create directories") 140 | 141 | if not os.path.exists(directory): 142 | os.makedirs(directory, exist_ok=True) 143 | 144 | best_dir = os.path.join(directory, "best") 145 | if not os.path.exists(best_dir): 146 | os.makedirs(best_dir) 147 | log_dir = os.path.join(directory, "logs") 148 | if not os.path.exists(log_dir): 149 | os.makedirs(log_dir) 150 | 151 | extension = ".pth" 152 | log_path_model = f"{log_dir}/model{extension}" 153 | log_path_training = f"{log_dir}/training{extension}" 154 | best_path_model = f"{best_dir}/model{extension}" 155 | 156 | logging.info("Initialize model") 157 | model = GemNet( 158 | num_spherical=num_spherical, 159 | num_radial=num_radial, 160 | num_blocks=num_blocks, 161 | emb_size_atom=emb_size_atom, 162 | emb_size_edge=emb_size_edge, 163 | emb_size_trip=emb_size_trip, 164 | emb_size_quad=emb_size_quad, 165 | emb_size_rbf=emb_size_rbf, 166 | emb_size_cbf=emb_size_cbf, 167 | emb_size_sbf=emb_size_sbf, 168 | num_before_skip=num_before_skip, 169 | num_after_skip=num_after_skip, 170 | num_concat=num_concat, 171 | num_atom=num_atom, 172 | emb_size_bil_quad=emb_size_bil_quad, 173 | emb_size_bil_trip=emb_size_bil_trip, 174 | num_targets=2 if mve else 1, 175 | triplets_only=triplets_only, 176 | direct_forces=direct_forces, 177 | forces_coupled=forces_coupled, 178 | cutoff=cutoff, 179 | int_cutoff=int_cutoff, 180 | envelope_exponent=envelope_exponent, 181 | activation="swish", 182 | extensive=extensive, 183 | output_init=output_init, 184 | scale_file=scale_file, 185 | ) 186 | # push to GPU if available 187 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 188 | model.to(device) 189 | 190 | # Initialize summary writer 191 | summary_writer = SummaryWriter(log_dir) 192 | train = {} 193 | 
validation = {} 194 | 195 | logging.info("Load dataset") 196 | data_container = DataContainer( 197 | dataset, cutoff=cutoff, int_cutoff=int_cutoff, triplets_only=triplets_only 198 | ) 199 | 200 | if val_dataset is not None: 201 | # Initialize DataProvider 202 | if num_train == 0: 203 | num_train = len(data_container) 204 | logging.info(f"Training data size: {num_train}") 205 | data_provider = DataProvider( 206 | data_container, 207 | num_train, 208 | 0, 209 | batch_size, 210 | seed=data_seed, 211 | shuffle=True, 212 | random_split=True, 213 | ) 214 | 215 | # Initialize validation datasets 216 | val_data_container = DataContainer( 217 | val_dataset, 218 | cutoff=cutoff, 219 | int_cutoff=int_cutoff, 220 | triplets_only=triplets_only, 221 | ) 222 | if num_val == 0: 223 | num_val = len(val_data_container) 224 | logging.info(f"Validation data size: {num_val}") 225 | val_data_provider = DataProvider( 226 | val_data_container, 227 | 0, 228 | num_val, 229 | batch_size, 230 | seed=data_seed, 231 | shuffle=True, 232 | random_split=True, 233 | ) 234 | else: 235 | # Initialize DataProvider (splits dataset into 3 sets based on data_seed and provides tf.datasets) 236 | logging.info(f"Training data size: {num_train}") 237 | logging.info(f"Validation data size: {num_val}") 238 | assert num_train > 0 239 | assert num_val > 0 240 | data_provider = DataProvider( 241 | data_container, 242 | num_train, 243 | num_val, 244 | batch_size, 245 | seed=data_seed, 246 | shuffle=True, 247 | random_split=True, 248 | ) 249 | val_data_provider = data_provider 250 | 251 | # Initialize datasets 252 | train["dataset_iter"] = data_provider.get_dataset("train") 253 | validation["dataset_iter"] = val_data_provider.get_dataset("val") 254 | 255 | 256 | logging.info("Prepare training") 257 | # Initialize trainer 258 | trainer = Trainer( 259 | model, 260 | learning_rate=learning_rate, 261 | decay_steps=decay_steps, 262 | decay_rate=decay_rate, 263 | warmup_steps=warmup_steps, 264 | weight_decay=weight_decay, 265 | ema_decay=ema_decay, 266 | decay_patience=decay_patience, 267 | decay_factor=decay_factor, 268 | decay_cooldown=decay_cooldown, 269 | grad_clip_max=grad_clip_max, 270 | rho_force=rho_force, 271 | mve=mve, 272 | loss=loss, 273 | staircase=staircase, 274 | agc=agc, 275 | ) 276 | 277 | # Initialize metrics 278 | train["metrics"] = Metrics("train", trainer.tracked_metrics, ex) 279 | validation["metrics"] = Metrics("val", trainer.tracked_metrics, ex) 280 | 281 | # Save/load best recorded loss (only the best model is saved) 282 | metrics_best = BestMetrics(best_dir, validation["metrics"]) 283 | 284 | # Set up checkpointing 285 | # Restore latest checkpoint 286 | if os.path.exists(log_path_model): 287 | logging.info("Restoring model and trainer") 288 | model_checkpoint = torch.load(log_path_model) 289 | model.load_state_dict(model_checkpoint["model"]) 290 | 291 | train_checkpoint = torch.load(log_path_training) 292 | trainer.load_state_dict(train_checkpoint["trainer"]) 293 | # restore the best saved results 294 | metrics_best.restore() 295 | logging.info(f"Restored best metrics: {metrics_best.loss}") 296 | step_init = int(train_checkpoint["step"]) 297 | else: 298 | logging.info("Freshly initialize model") 299 | metrics_best.inititalize() 300 | step_init = 0 301 | 302 | if ex is not None: 303 | ex.current_run.info = {"directory": directory} 304 | # save the number of parameters 305 | nparams = sum(p.numel() for p in model.parameters() if p.requires_grad) 306 | ex.current_run.info.update({"nParams": nparams}) 307 | 308 | # 
Training loop 309 | logging.info("Start training") 310 | 311 | steps_per_epoch = int(np.ceil(num_train / batch_size)) 312 | for step in range(step_init + 1, num_steps + 1): 313 | # start after evaluation to not include time on validation set 314 | if ex is not None: 315 | if step == evaluation_interval + 1: 316 | start = time.perf_counter() 317 | if step == 2 * evaluation_interval - 1: 318 | end = time.perf_counter() 319 | time_delta = end - start 320 | nsteps = evaluation_interval - 2 321 | ex.current_run.info.update( 322 | {"seconds_per_step": time_delta / nsteps, 323 | "min_per_epoch": int(time_delta / nsteps * steps_per_epoch * 100 / 60) / 100 # two digits only 324 | } 325 | ) 326 | 327 | # keep track of the learning rate 328 | if step % 10 == 0: 329 | lr = trainer.schedulers[0].get_last_lr()[0] 330 | summary_writer.add_scalar("lr", lr, global_step=step) 331 | 332 | # Perform training step 333 | trainer.train_on_batch(train["dataset_iter"], train["metrics"]) 334 | 335 | # Save progress 336 | if step % save_interval == 0: 337 | torch.save({"model": model.state_dict()}, log_path_model) 338 | torch.save( 339 | {"trainer": trainer.state_dict(), "step": step}, log_path_training 340 | ) 341 | 342 | # Check performance on the validation set 343 | if step % evaluation_interval == 0: 344 | 345 | # Save backup variables and load averaged variables 346 | trainer.save_variable_backups() 347 | trainer.load_averaged_variables() 348 | 349 | # Compute averages 350 | for i in range(int(np.ceil(num_val / batch_size))): 351 | trainer.test_on_batch(validation["dataset_iter"], validation["metrics"]) 352 | 353 | # Update and save best result 354 | if validation["metrics"].loss < metrics_best.loss: 355 | metrics_best.update(step, validation["metrics"]) 356 | torch.save(model.state_dict(), best_path_model) 357 | 358 | # write to summary writer 359 | metrics_best.write(summary_writer, step) 360 | 361 | epoch = step // steps_per_epoch 362 | train_metrics_res = train["metrics"].result(append_tag=False) 363 | val_metrics_res = validation["metrics"].result(append_tag=False) 364 | metrics_strings = [ 365 | f"{key}: train={train_metrics_res[key]:.6f}, val={val_metrics_res[key]:.6f}" 366 | for key in validation["metrics"].keys 367 | ] 368 | logging.info( 369 | f"{step}/{num_steps} (epoch {epoch}): " + "; ".join(metrics_strings) 370 | ) 371 | 372 | # decay learning rate on plateau 373 | trainer.decay_maybe(validation["metrics"].loss) 374 | 375 | train["metrics"].write(summary_writer, step) 376 | validation["metrics"].write(summary_writer, step) 377 | train["metrics"].reset_states() 378 | validation["metrics"].reset_states() 379 | 380 | # Restore backup variables 381 | trainer.restore_variable_backups() 382 | 383 | # early stopping 384 | if step - metrics_best.step > patience * evaluation_interval: 385 | break 386 | 387 | return {key + "_best": val for key, val in metrics_best.items()} 388 | 389 | --------------------------------------------------------------------------------
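For readers who want to exercise `train_seml.py` without setting up seml's MongoDB backend, the sketch below shows one way to launch the sacred experiment directly from Python. This is an illustrative addition, not part of the repository: `config_seml.yaml` is the file listed in the repository root, but the assumption that every argument of `run()` is provided under a top-level `fixed:` block follows seml's usual config layout and should be checked against the actual file before use.

```python
# Minimal local-launch sketch (assumption: config_seml.yaml supplies every
# argument of run() under seml's conventional "fixed:" block).
import yaml

from train_seml import ex  # the sacred Experiment defined in train_seml.py

with open("config_seml.yaml") as f:
    cfg = yaml.safe_load(f)

# With db_collection left at None, no MongoDB observer is attached and the
# run stays purely local; sacred forwards the dict entries as named config
# values to run().
ex.run(config_updates=cfg["fixed"])
```

With a MongoDB backend configured, the same config file would typically be queued and executed through the seml command line instead (e.g. `seml <collection> add config_seml.yaml` followed by `seml <collection> start`), which is the workflow this script is written for.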