├── LICENSE ├── LICENSE_SHORT ├── README.md ├── chemtrain ├── __init__.py ├── data_processing.py ├── dimenet_basis_util.py ├── dropout.py ├── force_matching.py ├── jax_md_mod │ ├── __init__.py │ ├── custom_energy.py │ ├── custom_interpolate.py │ ├── custom_quantity.py │ ├── custom_space.py │ └── io.py ├── layers.py ├── max_likelihood.py ├── neural_networks.py ├── reweighting.py ├── sparse_graph.py ├── trainers.py ├── traj_quantity.py ├── traj_util.py └── util.py ├── examples ├── alanine_dipeptide │ ├── alanine_force_matching.py │ ├── alanine_relative_entropy.py │ ├── alanine_simulation.py │ ├── data │ │ ├── confs │ │ │ └── heavy_2_7nm.gro │ │ ├── dataset │ │ │ ├── phi_angles_r100ns.npy │ │ │ └── psi_angles_r100ns.npy │ │ └── prior │ │ │ ├── Alanine_dipeptide_heavy_dihedral_constant.npy │ │ │ ├── Alanine_dipeptide_heavy_dihedral_multiplicity.npy │ │ │ ├── Alanine_dipeptide_heavy_dihedral_phase.npy │ │ │ ├── Alanine_dipeptide_heavy_epsilon.npy │ │ │ ├── Alanine_dipeptide_heavy_eq_angle.npy │ │ │ ├── Alanine_dipeptide_heavy_eq_angle_variance.npy │ │ │ ├── Alanine_dipeptide_heavy_eq_bond_length.npy │ │ │ ├── Alanine_dipeptide_heavy_eq_bond_variance.npy │ │ │ └── Alanine_dipeptide_heavy_sigma.npy │ └── visualization.py └── water │ ├── CG_water_force_matching.py │ ├── CG_water_relative_entropy.py │ ├── CG_water_simulation.py │ └── data │ ├── dataset │ └── box.npy │ └── water_models │ ├── TIP4P-2005_150_COM_ADF.csv │ ├── TIP4P-2005_1k_50b_TCF_cut05.npy │ ├── TIP4P-2005_1k_50b_TCF_cut06.npy │ ├── TIP4P-2005_1k_50b_TCF_cut08.npy │ └── TIP4P-2005_300_COM_RDF.csv ├── setup.py └── util ├── Initialization.py └── Postprocessing.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE_SHORT: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Relative Entropy Minimization 2 | 3 | Implementation of the Relative Entropy Minimization and Force Matching methods 4 | as employed in the paper 5 | [Deep coarse-grained potentials via relative entropy minimization](https://aip.scitation.org/doi/10.1063/5.0124538). 6 | 7 | ## Getting started 8 | This repository provides code to train and simulate two systems, coarse-grained 9 | water and alanine dipeptide, with force matching and relative 10 | entropy minimization. The training examples can be found in 11 | [CG_water_force_matching.py](examples/water/CG_water_force_matching.py) 12 | and [CG_water_relative_entropy.py](examples/water/CG_water_relative_entropy.py) 13 | for water and in [alanine_force_matching.py](examples/alanine_dipeptide/alanine_force_matching.py) 14 | and [alanine_relative_entropy.py](examples/alanine_dipeptide/alanine_relative_entropy.py) 15 | for alanine dipeptide. Training the model with force matching will take a few 16 | hours and more than a day with relative entropy. 17 | 18 | MD simulation employing the trained DimeNet++ models can be found 19 | in [CG_water_simulation.py](examples/water/CG_water_simulation.py) 20 | and [alanine_simulation.py](examples/alanine_dipeptide/alanine_simulation.py) 21 | respectively. 22 | 23 | ## Data sets 24 | The data sets for alanine dipeptide and water can be downloaded from Google 25 | Drive via the following link:
26 | [https://drive.google.com/drive/folders/1IBZbuSBIBhvFbVhuo9s-ENE2IyWG-YI_?usp=sharing](https://drive.google.com/drive/folders/1IBZbuSBIBhvFbVhuo9s-ENE2IyWG-YI_?usp=sharing)
27 | Once downloaded, you can move the conf and force files into the dataset folder 28 | of [water](examples/water/data/dataset) and 29 | [alanine dipeptide](examples/alanine_dipeptide/data/dataset). 30 | 31 | ## Installation 32 | All dependencies can be installed locally with pip: 33 | ``` 34 | pip install -e .[all] 35 | ``` 36 | 37 | However, this only installs a CPU version of Jax. If you want to enable GPU 38 | support, please overwrite the jaxlib version: 39 | ``` 40 | pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html 41 | ``` 42 | 43 | ## Requirements 44 | The repository uses with the following packages: 45 | ``` 46 | 'jax>=0.4.3', 47 | 'jax-md>=0.2.5', 48 | 'optax>=0.0.9', 49 | 'dm-haiku>=0.0.9', 50 | 'sympy', 51 | 'cloudpickle', 52 | 'chex', 53 | 'jax-sgmc', 54 | ``` 55 | The code was run with Python 3.8. The packages used in the paper 56 | are listed in [setup.py](setup.py). 57 | 58 | ## Citation 59 | Please cite our paper if you use this code in your own work: 60 | ``` 61 | @article{thaler_entropy_2022, 62 | title = {Deep coarse-grained potentials via relative entropy minimization}, 63 | author = {Thaler, Stephan and Stupp, Maximilian and Zavadlav, Julija}, 64 | journal={The Journal of Chemical Physics}, 65 | volume = {157}, 66 | number = {24}, 67 | pages = {244103}, 68 | year = {2022}, 69 | doi = {10.1063/5.0124538} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /chemtrain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/chemtrain/__init__.py -------------------------------------------------------------------------------- /chemtrain/data_processing.py: -------------------------------------------------------------------------------- 1 | """Functions that facilitate common data processing operations for machine 2 
| learning. 3 | """ 4 | import numpy as onp 5 | from jax import lax 6 | from jax.tree_util import tree_flatten 7 | from jax_sgmc.data import numpy_loader 8 | 9 | from chemtrain.jax_md_mod import custom_space 10 | from chemtrain import util 11 | 12 | 13 | def get_dataset(data_location_str, retain=None, subsampling=1): 14 | """Loads .pyy numpy dataset. 15 | 16 | Args: 17 | data_location_str: String of .npy data location 18 | retain: Number of samples to keep in the dataset. All by default. 19 | subsampling: Only keep every subsampled sample of the data, e.g. 2. 20 | 21 | Returns: 22 | Subsampled data array 23 | """ 24 | loaded_data = onp.load(data_location_str) 25 | loaded_data = loaded_data[:retain:subsampling] 26 | return loaded_data 27 | 28 | 29 | def train_val_test_split(dataset, train_ratio=0.7, val_ratio=0.1, shuffle=False, 30 | shuffle_seed=0): 31 | """Train-validation-test split for datasets. Works on arbitrary pytrees, 32 | including chex.dataclasses, dictionaries and single arrays. 33 | 34 | If a subset ratio ratios is 0, returns None for the respective subset. 35 | 36 | Args: 37 | dataset: Dataset pytree. Samples are assumed to be stacked along 38 | axis 0. 39 | train_ratio: Percantage of dataset to use for training. 40 | val_ratio: Percantage of dataset to use for validation. 41 | shuffle: If True, shuffles data before splitting into train-val-test. 42 | Shuffling copies the dataset. 43 | shuffle_seed: PRNG Seed for data shuffling 44 | 45 | Returns: 46 | Tuple (train_data, val_data, test_data) with the same shape as the input 47 | pytree, but split along axis 0. 48 | """ 49 | assert train_ratio + val_ratio <= 1., 'Distribution of data exceeds 100%.' 
50 | leaves, _ = tree_flatten(dataset) 51 | dataset_size = leaves[0].shape[0] 52 | train_size = int(dataset_size * train_ratio) 53 | val_size = int(dataset_size * val_ratio) 54 | 55 | if shuffle: 56 | dataset_idxs = onp.arange(dataset_size) 57 | numpy_rng = onp.random.default_rng(shuffle_seed) 58 | numpy_rng.shuffle(dataset_idxs) 59 | 60 | def retreive_datasubset(idxs): 61 | data_subset = util.tree_take(dataset, idxs, axis=0) 62 | subset_leaves, _ = tree_flatten(data_subset) 63 | subset_size = subset_leaves[0].shape[0] 64 | if subset_size == 0: 65 | data_subset = None 66 | return data_subset 67 | 68 | train_data = retreive_datasubset(dataset_idxs[:train_size]) 69 | val_data = retreive_datasubset(dataset_idxs[train_size: 70 | val_size + train_size]) 71 | test_data = retreive_datasubset(dataset_idxs[val_size + train_size:]) 72 | 73 | else: 74 | def retreive_datasubset(start, end): 75 | data_subset = util.tree_get_slice(dataset, start, end, 76 | to_device=False) 77 | subset_leaves, _ = tree_flatten(data_subset) 78 | subset_size = subset_leaves[0].shape[0] 79 | if subset_size == 0: 80 | data_subset = None 81 | return data_subset 82 | 83 | train_data = retreive_datasubset(0, train_size) 84 | val_data = retreive_datasubset(train_size, train_size + val_size) 85 | test_data = retreive_datasubset(train_size + val_size, None) 86 | return train_data, val_data, test_data 87 | 88 | 89 | def init_dataloaders(dataset, train_ratio=0.7, val_ratio=0.1, shuffle=False): 90 | """Splits dataset and initializes dataloaders. 91 | 92 | If the validation or test ratios are 0, returns None for the respective 93 | dataloader. 94 | 95 | Args: 96 | dataset: Dictionary containing the whole dataset. The NumpyDataLoader 97 | returns batches with the same kwargs as provided in dataset. 98 | train_ratio: Percantage of dataset to use for training. 99 | val_ratio: Percantage of dataset to use for validation. 100 | shuffle: Whether to shuffle data before splitting into train-val-test. 
101 | 102 | Returns: 103 | A tuple (train_loader, val_loader, test_loader) of NumpyDataLoaders. 104 | """ 105 | def init_subloader(data_subset): 106 | if data_subset is None: 107 | loader = None 108 | else: 109 | loader = numpy_loader.NumpyDataLoader(**data_subset, copy=False) 110 | return loader 111 | 112 | train_set, val_set, test_set = train_val_test_split( 113 | dataset, train_ratio, val_ratio, shuffle=shuffle) 114 | train_loader = init_subloader(train_set) 115 | val_loader = init_subloader(val_set) 116 | test_loader = init_subloader(test_set) 117 | return train_loader, val_loader, test_loader 118 | 119 | 120 | def scale_dataset_fractional(traj, box): 121 | """Scales a dataset of positions from real space to fractional coordinates. 122 | 123 | Args: 124 | traj: A (N_snapshots, N_particles, 3) array of particle positions 125 | box: A 1 or 2-dimensional jax_md box 126 | 127 | Returns: 128 | A (N_snapshots, N_particles, 3) array of particle positions in 129 | fractional coordinates. 130 | """ 131 | _, scale_fn = custom_space.init_fractional_coordinates(box) 132 | scaled_traj = lax.map(scale_fn, traj) 133 | return scaled_traj 134 | -------------------------------------------------------------------------------- /chemtrain/dimenet_basis_util.py: -------------------------------------------------------------------------------- 1 | """The spherical Bessel Function utility file from the original DimeNet 2 | implementation (https://github.com/klicperajo/dimenet). 3 | Licenced under Hippocratic License 2.1. 
4 | """ 5 | 6 | import numpy as np 7 | from scipy.optimize import brentq 8 | from scipy import special as sp 9 | import sympy as sym 10 | 11 | 12 | def Jn(r, n): 13 | """ 14 | numerical spherical bessel functions of order n 15 | """ 16 | return np.sqrt(np.pi/(2*r)) * sp.jv(n+0.5, r) 17 | 18 | 19 | def Jn_zeros(n, k): 20 | """ 21 | Compute the first k zeros of the spherical bessel functions up to order n (excluded) 22 | """ 23 | zerosj = np.zeros((n, k), dtype="float32") 24 | zerosj[0] = np.arange(1, k + 1) * np.pi 25 | points = np.arange(1, k + n) * np.pi 26 | racines = np.zeros(k + n - 1, dtype="float32") 27 | for i in range(1, n): 28 | for j in range(k + n - 1 - i): 29 | foo = brentq(Jn, points[j], points[j + 1], (i,)) 30 | racines[j] = foo 31 | points = racines 32 | zerosj[i][:k] = racines[:k] 33 | 34 | return zerosj 35 | 36 | 37 | def spherical_bessel_formulas(n): 38 | """ 39 | Computes the sympy formulas for the spherical bessel functions up to order n (excluded) 40 | """ 41 | x = sym.symbols('x') 42 | 43 | f = [sym.sin(x)/x] 44 | a = sym.sin(x)/x 45 | for i in range(1, n): 46 | b = sym.diff(a, x)/x 47 | f += [sym.simplify(b*(-x)**i)] 48 | a = sym.simplify(b) 49 | return f 50 | 51 | 52 | def bessel_basis(n, k): 53 | """ 54 | Compute the sympy formulas for the normalized and rescaled spherical bessel functions up to 55 | order n (excluded) and maximum frequency k (excluded). 
56 | """ 57 | 58 | zeros = Jn_zeros(n, k) 59 | normalizer = [] 60 | for order in range(n): 61 | normalizer_tmp = [] 62 | for i in range(k): 63 | normalizer_tmp += [0.5*Jn(zeros[order, i], order+1)**2] 64 | normalizer_tmp = 1/np.array(normalizer_tmp)**0.5 65 | normalizer += [normalizer_tmp] 66 | 67 | f = spherical_bessel_formulas(n) 68 | x = sym.symbols('x') 69 | bess_basis = [] 70 | for order in range(n): 71 | bess_basis_tmp = [] 72 | for i in range(k): 73 | bess_basis_tmp += [sym.simplify(normalizer[order] 74 | [i]*f[order].subs(x, zeros[order, i]*x))] 75 | bess_basis += [bess_basis_tmp] 76 | return bess_basis 77 | 78 | 79 | def sph_harm_prefactor(l, m): 80 | """ 81 | Computes the constant pre-factor for the spherical harmonic of degree l and order m 82 | input: 83 | l: int, l>=0 84 | m: int, -l<=m<=l 85 | """ 86 | return ((2*l+1) * np.math.factorial(l-abs(m)) / (4*np.pi*np.math.factorial(l+abs(m))))**0.5 87 | 88 | 89 | def associated_legendre_polynomials(l, zero_m_only=True): 90 | """ 91 | Computes sympy formulas of the associated legendre polynomials up to order l (excluded). 92 | """ 93 | z = sym.symbols('z') 94 | P_l_m = [[0]*(j+1) for j in range(l)] 95 | 96 | P_l_m[0][0] = 1 97 | if l > 0: 98 | P_l_m[1][0] = z 99 | 100 | for j in range(2, l): 101 | P_l_m[j][0] = sym.simplify( 102 | ((2*j-1)*z*P_l_m[j-1][0] - (j-1)*P_l_m[j-2][0])/j) 103 | if not zero_m_only: 104 | for i in range(1, l): 105 | P_l_m[i][i] = sym.simplify((1-2*i)*P_l_m[i-1][i-1]) 106 | if i + 1 < l: 107 | P_l_m[i+1][i] = sym.simplify((2*i+1)*z*P_l_m[i][i]) 108 | for j in range(i + 2, l): 109 | P_l_m[j][i] = sym.simplify( 110 | ((2*j-1) * z * P_l_m[j-1][i] - (i+j-1) * P_l_m[j-2][i]) / (j - i)) 111 | 112 | return P_l_m 113 | 114 | 115 | def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True): 116 | """ 117 | Computes formula strings of the the real part of the spherical harmonics up to order l (excluded). 
118 | Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta. 119 | """ 120 | if not zero_m_only: 121 | S_m = [0] 122 | C_m = [1] 123 | for i in range(1, l): 124 | x = sym.symbols('x') 125 | y = sym.symbols('y') 126 | S_m += [x*S_m[i-1] + y*C_m[i-1]] 127 | C_m += [x*C_m[i-1] - y*S_m[i-1]] 128 | 129 | P_l_m = associated_legendre_polynomials(l, zero_m_only) 130 | if spherical_coordinates: 131 | theta = sym.symbols('theta') 132 | z = sym.symbols('z') 133 | for i in range(len(P_l_m)): 134 | for j in range(len(P_l_m[i])): 135 | if type(P_l_m[i][j]) != int: 136 | P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) 137 | if not zero_m_only: 138 | phi = sym.symbols('phi') 139 | for i in range(len(S_m)): 140 | S_m[i] = S_m[i].subs(x, sym.sin( 141 | theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi)) 142 | for i in range(len(C_m)): 143 | C_m[i] = C_m[i].subs(x, sym.sin( 144 | theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi)) 145 | 146 | Y_func_l_m = [['0']*(2*j + 1) for j in range(l)] 147 | for i in range(l): 148 | Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) 149 | 150 | if not zero_m_only: 151 | for i in range(1, l): 152 | for j in range(1, i + 1): 153 | Y_func_l_m[i][j] = sym.simplify( 154 | 2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) 155 | for i in range(1, l): 156 | for j in range(1, i + 1): 157 | Y_func_l_m[i][-j] = sym.simplify( 158 | 2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) 159 | 160 | return Y_func_l_m -------------------------------------------------------------------------------- /chemtrain/dropout.py: -------------------------------------------------------------------------------- 1 | """Customn function for enabling dropout applications in haiku.""" 2 | from functools import wraps 3 | 4 | import haiku as hk 5 | from jax import random, vmap, numpy as jnp 6 | 7 | 8 | # Note: This implementation currently stores the RNG key as float32 9 | # 
rather than uint32. This way, both energy_params and the RNG key 10 | # can be treated analogously in grad and optimizer of trainers. 11 | # Therefore, the RNG key is updated twice per parameter update: 12 | # By optimizer via a (arbitrary) gradient and by the standard RNG update. 13 | # This allows a more unified treatment in trainers, without frequent 14 | # branching if dropout is used or not. 15 | 16 | 17 | def _dropout_wrapper(fun): 18 | """Wraps Haiku apply fuctions such that combined dropout and models params 19 | are split and supplied seperately as expected by model. 20 | """ 21 | # first argument is params for Haiku models 22 | @wraps(fun) 23 | def dropout_fn(params, *args, **kwargs): 24 | if dropout_is_used(params): 25 | params, drop_key = split_dropout_params(params) 26 | kwargs['dropout_key'] = drop_key 27 | return fun(params, *args, **kwargs) 28 | return dropout_fn 29 | 30 | 31 | def model_init_apply(model, model_kwargs): 32 | """Returns Haiku model.init and model.apply that adapivly use dropout if 33 | 'dropout_key' is provided alongside parameters. If not, no dropout is 34 | applied. 35 | """ 36 | dropout_mode = model_kwargs.get('dropout_mode', None) 37 | if dropout_mode is None: 38 | return model.init, model.apply 39 | else: 40 | return model.init, _dropout_wrapper(model.apply) 41 | 42 | 43 | def split_dropout_params(meta_params): 44 | """Splits up meta params, built up by energy_params and 45 | dropout_key. 
    """
    return (meta_params['haiku_params'],
            jnp.uint32(meta_params['Dropout_RNG_key']))


def next_dropout_params(meta_params):
    """Steps dropout_key and re-packs it in meta_params."""
    _, old_dropout_key = split_dropout_params(meta_params)
    new_dropout_key, _ = random.split(old_dropout_key, 2)
    return build_dropout_params(meta_params['haiku_params'], new_dropout_key)


def dropout_is_used(meta_params):
    """A function that returns whether dropout is used by
    checking if the 'Dropout_RNG_key' is set or exists at all.
    """
    # Objects without a .keys() method (i.e. not the meta_params dict built by
    # build_dropout_params) count as "no dropout".
    try:
        return 'Dropout_RNG_key' in meta_params.keys()
    except AttributeError:
        return False


def build_dropout_params(energy_params, dropout_key):
    """Combines meta_params, built up by energy_params and
    dropout_key.
    """
    # NOTE(review): the uint32 RNG key is stored as float32 and cast back via
    # jnp.uint32 in split_dropout_params. Presumably this keeps the params
    # pytree float-valued for optimizers, but float32 cannot represent all
    # integers >= 2**24 exactly -- TODO confirm this round trip is intended.
    return {'haiku_params': energy_params,
            'Dropout_RNG_key': jnp.float32(dropout_key)}


def construct_dropout_params(dropout_key, dropout_setup):
    """Splits and distributes the random key for all dropout.Linear
    layers.

    Args:
        dropout_key: Base PRNG key. If None, an empty dict is returned.
        dropout_setup: Dict mapping layer names to dropout rates
            (see dimenetpp_setup).

    Returns:
        Dict mapping each layer name to {'key': ..., 'do_rate': ...}.
    """
    dropout_key_dict = {}
    if dropout_key is not None and len(dropout_setup) != 0:  # dropout
        n_keys = len(dropout_setup)
        split = random.split(dropout_key, n_keys)
        for i, (layer_name, do_rate) in enumerate(dropout_setup.items()):
            dropout_key_dict[layer_name] = {'key': split[i], 'do_rate': do_rate}

    return dropout_key_dict


class Linear(hk.Module):
    """Wrapper around hk.Linear that applies dropout if the module
    name exists as key in dropout_dict during call.
    """
    def __init__(self, embed_size, name, with_bias=True, **init_kwargs):
        super().__init__(name=name)
        self.linear = hk.Linear(embed_size, with_bias=with_bias, **init_kwargs)
        # False: independent dropout of output features
        # True: same parts of feature vector are dropped for all edge embeddings
        self.shared = False

    def __call__(self, inputs, dropout_dict=None):
        linear_output = self.linear(inputs)
        if dropout_dict is None or self.module_name not in dropout_dict.keys():
            return linear_output
        else:  # apply dropout
            dropout_key = dropout_dict[self.module_name]['key']
            if self.shared:
                # same key for every leading-axis entry: identical feature
                # components are dropped across all embeddings
                vectorized_dropout = vmap(hk.dropout, in_axes=(None, None, 0))
                dropped_array = vectorized_dropout(dropout_key, dropout_dict[
                    self.module_name]['do_rate'], linear_output)
            else:
                dropped_array = hk.dropout(dropout_key, dropout_dict[
                    self.module_name]['do_rate'], linear_output)
            return dropped_array


def dimenetpp_setup(setup_dict,
                    num_dense_out,
                    n_interaction_blocks,
                    num_res_before_skip,
                    num_res_after_skip,
                    overall_dropout_rate=None):
    """Builds the Dropout hyperparameters for DimeNet++.

    Args:
        setup_dict: Dict containing the block name to be dropouted
                    alongside the target dropout rate for each block.
                    Available blocks are 'output', 'interaction' and
                    'embedding'. If None, no layers will be dropped.
        num_dense_out: Number of fully-connected layers in output block
        n_interaction_blocks: Number of interaction blocks
        num_res_before_skip: Number of residual blocks before the skip
        num_res_after_skip: Number of residual blocks after the skip connection
        overall_dropout_rate: If given, will override all dropout rates
                              with this global rate.

    Returns: dropout_setup - a dict encoding dropout structure of DimeNet++
    """
    def index_to_layer_name(i):
        # haiku numbers repeated modules '', '_1', '_2', ...
        if i == 0:
            return ''
        else:
            return '_' + str(i)

    # For maximum flexibility, each layer can be addressed separately.
    # These names are hardcoded to naming of layers in DimeNetPP
    # If you change layer names, make sure to adjust these
    delimiter = '/~/'
    net_prefix = 'DimeNetPP'
    output_prefix = 'Output'
    interaction_prefix = 'Interaction'
    embedding_prefix = 'Embedding'
    residual_prefix = 'ResLayer'
    output_layers = ['Dense_Series' + index_to_layer_name(i)
                     for i in range(num_dense_out)]
    output_layers.extend(['RBF_Dense', 'Upprojection'])
    interaction_layers = ['Dense_ji', 'Dense_kj', 'Downprojection',
                          'Upprojection', 'FinalBeforeSkip', 'rbf1', 'rbf2',
                          'sbf1', 'sbf2']
    residual_layers = ['ResidualSubLayer' + index_to_layer_name(i)
                       for i in range(2)]
    embedding_layers = ['Concat_Dense', 'RBF_Dense']

    if setup_dict is None:  # no dropout
        setup_dict = {}

    dropout_setup = {}
    if 'output' in setup_dict.keys():
        layer_prefix = net_prefix + delimiter + output_prefix
        droprate = setup_dict['output']

        # there is one output block more than there are interaction blocks
        for i_outblock in range(n_interaction_blocks + 1):
            i_interaction_prefix = layer_prefix + \
                                   index_to_layer_name(i_outblock) + delimiter
            for layer_name in output_layers:
                name = i_interaction_prefix + layer_name
                dropout_setup[name] = droprate

    if 'interaction' in setup_dict.keys():
        layer_prefix = net_prefix + delimiter + interaction_prefix
        droprate = setup_dict['interaction']

        for i_interaction in range(n_interaction_blocks):
            i_interaction_prefix = layer_prefix + \
                                   index_to_layer_name(i_interaction) \
                                   + delimiter
            for layer_name in interaction_layers:
                name = i_interaction_prefix + layer_name
                dropout_setup[name] = droprate
            # residual sub-layers of this interaction block
            for i_res_block in range(num_res_before_skip + num_res_after_skip):
                res_block_prefix = i_interaction_prefix + residual_prefix + \
                                   index_to_layer_name(i_res_block) + delimiter
                for res_layer_name in residual_layers:
                    name = res_block_prefix + res_layer_name
                    dropout_setup[name] = droprate

    if 'embedding' in setup_dict.keys():
        layer_prefix = net_prefix + delimiter + embedding_prefix
        droprate = setup_dict['embedding']
        for layer_name in embedding_layers:
            name = layer_prefix + delimiter + layer_name
            dropout_setup[name] = droprate

    if overall_dropout_rate is not None:  # override dropout rate

        for layer in dropout_setup:
            dropout_setup[layer] = overall_dropout_rate

    return dropout_setup
--------------------------------------------------------------------------------
/chemtrain/force_matching.py:
--------------------------------------------------------------------------------
"""Functions for learning via direct matching of per-snapshot quantities,
such as energy, forces and virial pressure.
"""
from collections import namedtuple

from jax import vmap, value_and_grad, numpy as jnp

from chemtrain import max_likelihood
from chemtrain.jax_md_mod import custom_quantity

# Note:
# Computing the neighborlist in each snapshot is not efficient for DimeNet++,
# which constructs a sparse graph representation afterwards. However, other
# models such as the tabulated potential are inefficient if used without
# neighbor list as many cut-off interactions are otherwise computed.
# For the sake of a simpler implementation, the slight inefficiency
# in the case of DimeNet++ is accepted for now.

State = namedtuple(
    'State',
    ['position']
)
State.__doc__ = """Emulates structure of simulation state for
compatibility with quantity functions.

position: atomic positions
"""


def build_dataset(position_data, energy_data=None, force_data=None,
                  virial_data=None):
    """Builds the force-matching dataset depending on available data.

    Interface of force-loss functions depends on dict keys set here.

    Args:
        position_data: Particle position snapshots, stored under 'R'.
        energy_data: Optional potential energy targets, stored under 'U'.
        force_data: Optional force targets, stored under 'F'.
        virial_data: Optional virial/pressure targets, stored under 'p'.

    Returns:
        Tuple (dataset, target_keys) of the dataset dict and the list of
        target keys (all keys except 'R').
    """
    dataset = {'R': position_data}
    if energy_data is not None:
        dataset['U'] = energy_data
    if force_data is not None:
        dataset['F'] = force_data
    if virial_data is not None:
        dataset['p'] = virial_data
    return dataset, _dataset_target_keys(dataset)


def _dataset_target_keys(dataset):
    """Dataset keys excluding particle positions for validation loss with
    possibly masked atoms.
    """
    target_key_list = list(dataset)
    target_key_list.remove('R')
    assert target_key_list, 'At least one target quantity needs to be supplied.'
    return target_key_list


def init_virial_fn(virial_data, energy_fn_template, box_tensor):
    """Initializes the correct virial function depending on the target data
    type.

    A 3-dim target dataset (n_snapshots x 3 x 3) selects the full pressure
    tensor; 1 or 2 dims select the scalar pressure. Returns None if no
    virial data is given.
    """
    if virial_data is not None:
        assert box_tensor is not None, ('If the virial is to be matched, '
                                        'box_tensor is a mandatory input.')
        if virial_data.ndim == 3:
            virial_fn = custom_quantity.init_virial_stress_tensor(
                energy_fn_template, box_tensor, include_kinetic=False,
                pressure_tensor=True
            )
        elif virial_data.ndim in [1, 2]:
            virial_fn = custom_quantity.init_pressure(
                energy_fn_template, box_tensor, include_kinetic=False)
        else:
            raise ValueError('Format of virial dataset incompatible.')
    else:
        virial_fn = None

    return virial_fn


def init_model(nbrs_init, energy_fn_template, virial_fn=None):
    """Initialize predictions of energy, forces and virial (if applicable)
    for a single snapshot.

    Beware, currently overflow of neighbor list is not checked.

    Args:
        nbrs_init: Initial neighbor list.
        energy_fn_template: Energy_fn_template to get energy_fn from params.
        virial_fn: Function to compute virial pressure. If None, no virial
                   pressure is predicted.

    Returns:
        A function(params, single_observation) returning a dict of predictions
        containing energy ('U'), forces ('F') and if applicable virial ('p').
        The single_observation is assumed to be a dict containing particle
        positions under 'R'.
    """
    def fm_model(params, single_observation):
        positions = single_observation['R']
        energy_fn = energy_fn_template(params)
        # TODO check for neighborlist overflow and hand through
        nbrs = nbrs_init.update(positions)
        # forces are the negative gradient of the potential energy
        energy, negative_forces = value_and_grad(energy_fn)(positions,
                                                            neighbor=nbrs)
        predictions = {'U': energy, 'F': -negative_forces}
        if virial_fn is not None:
            predictions['p'] = - virial_fn(State(positions), nbrs, params)
        return predictions
    return fm_model


def init_loss_fn(gamma_u=1., gamma_f=1., gamma_p=1.e-6,
                 error_fn=max_likelihood.mse_loss, individual=False):
    """Initializes loss function for energy/force matching.

    Args:
        gamma_u: Weight for potential energy loss component
        gamma_f: Weight for force loss component
        gamma_p: Weight for virial loss component
        error_fn: Function quantifying the deviation of the model and the
                  targets. By default, a mean-squared error.
        individual: Default False initializes a loss function that returns
                    scalar loss weighted by gammas. If True, returns all
                    individual components, e.g. for testing purposes. In this
                    case, gamma values are unused.

    Returns:
        loss_fn(predictions, targets, mask=None), which returns a scalar loss
        value for a batch of predictions and targets. The optional mask is
        applied to the force error only.
    """
    def loss_fn(predictions, targets, mask=None):
        errors = {}
        loss_val = 0.
        if 'U' in targets.keys():  # energy loss component
            errors['energy'] = error_fn(predictions['U'], targets['U'])
            loss_val += gamma_u * errors['energy']
        if 'F' in targets.keys():  # forces loss component
            if mask is None:  # only forces need mask, U and p are unchanged
                mask = jnp.ones_like(predictions['F'])
            errors['forces'] = error_fn(predictions['F'], targets['F'], mask)
            loss_val += gamma_f * errors['forces']
        if 'p' in targets.keys():  # virial loss component
            errors['pressure'] = error_fn(predictions['p'], targets['p'])
            loss_val += gamma_p * errors['pressure']

        if individual:
            return errors
        else:
            return loss_val

    return loss_fn


def init_mae_fn(val_loader, nbrs_init, energy_fn_template, batch_size=1,
                batch_cache=1, virial_fn=None):
    """Returns a function that computes for each observable - energy, forces and
    virial (if applicable) - the individual mean absolute error on the
    validation set. These metrics are usually better interpretable than a
    (combined) MSE loss value.
    """
    model = init_model(nbrs_init, energy_fn_template, virial_fn)
    batched_model = vmap(model, in_axes=(None, 0))

    abs_error = init_loss_fn(error_fn=max_likelihood.mae_loss, individual=True)

    # NOTE(review): reaches into the data loader's private attribute
    # '_reference_data'; consider exposing a public accessor on the loader.
    target_keys = _dataset_target_keys(val_loader._reference_data)
    mean_abs_error, data_release_fn = max_likelihood.init_val_loss_fn(
        batched_model, abs_error, val_loader, target_keys, batch_size,
        batch_cache)

    return mean_abs_error, data_release_fn
--------------------------------------------------------------------------------
/chemtrain/jax_md_mod/__init__.py:
https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/chemtrain/jax_md_mod/__init__.py
--------------------------------------------------------------------------------
/chemtrain/jax_md_mod/custom_energy.py:
--------------------------------------------------------------------------------
"""Custom definition of some potential energy functions."""
from functools import partial
from typing import Callable, Any

from jax import vmap
import jax.numpy as jnp
from jax_md import space, partition, util, energy, smap

from chemtrain.jax_md_mod import custom_interpolate, custom_quantity
from chemtrain import sparse_graph

# Types
f32 = util.f32
f64 = util.f64
Array = util.Array

PyTree = Any
Box = space.Box
DisplacementFn = space.DisplacementFn
DisplacementOrMetricFn = space.DisplacementOrMetricFn

NeighborFn = partition.NeighborFn
NeighborList = partition.NeighborList


def harmonic_angle(displacement_or_metric: DisplacementOrMetricFn,
                   angle_idxs: Array,
                   eq_mean: Array,
                   eq_variance: Array,
                   kbt: [float, Array]):
    """Harmonic Angle interaction.

    The variance of the angle is used to determine the force constant.
    https://manual.gromacs.org/documentation/2019/reference-manual/functions/bonded-interactions.html

    Args:
        displacement_or_metric: Displacement function
        angle_idxs: Indices of particles (i, j, k)
        eq_mean: Equilibrium angle in degrees
        eq_variance: Angle Variance
        kbt: kbT

    Returns:
        Harmonic angle potential energy function.
    """

    kbt = jnp.array(kbt, dtype=f32)
    angle_mask = jnp.ones([angle_idxs.shape[0], 1])
    # force constant from equipartition: epsilon = kbT / variance
    harmonic_fn = partial(energy.simple_spring, length=eq_mean,
                          epsilon=kbt / eq_variance)

    def energy_fn(pos, **unused_kwargs):
        angles = sparse_graph.angle_triplets(pos, displacement_or_metric,
                                             angle_idxs, angle_mask)
        return jnp.sum(harmonic_fn(jnp.rad2deg(angles)))

    return energy_fn


def dihedral_energy(angle,
                    phase_angle: Array,
                    force_constant: Array,
                    n: [int, Array]):
    """Energy of dihedral angles.

    https://manual.gromacs.org/documentation/2019/reference-manual/functions/bonded-interactions.html
    """
    cos_angle = jnp.cos(n * angle - phase_angle)
    energies = force_constant * (1 + cos_angle)
    return jnp.sum(energies)


def periodic_dihedral(displacement_or_metric: DisplacementOrMetricFn,
                      dihedral_idxs: Array,
                      phase_angle: Array,
                      force_constant: Array,
                      multiplicity: [float, Array]):
    """Periodic dihedral angle interaction.

    https://manual.gromacs.org/documentation/2019/reference-manual/functions/bonded-interactions.html

    Args:
        displacement_or_metric: Displacement function
        dihedral_idxs: Indices of particles (i, j, k, l) building the dihedrals
        phase_angle: Dihedral phase angle in degrees.
        force_constant: Force constant
        multiplicity: Dihedral multiplicity

    Returns:
        Periodic dihedral potential energy function.
    """

    multiplicity = jnp.array(multiplicity, dtype=f32)
    phase_angle = jnp.deg2rad(phase_angle)

    def energy_fn(pos, **unused_kwargs):
        dihedral_angles = custom_quantity.dihedral_displacement(
            pos, displacement_or_metric, dihedral_idxs, degrees=False)
        per_angle_u = vmap(dihedral_energy)(dihedral_angles, phase_angle,
                                            force_constant, multiplicity)
        return jnp.sum(per_angle_u)

    return energy_fn


def generic_repulsion(dr: Array,
                      sigma: Array = 1.,
                      epsilon: Array = 1.,
                      exp: Array = 12.,
                      **unused_dynamic_kwargs) -> Array:
    """
    Repulsive interaction between soft sphere particles:
    U = epsilon * (sigma / r)**exp.

    Args:
        dr: An ndarray of pairwise distances between particles.
        sigma: Repulsion length scale
        epsilon: Interaction energy scale
        exp: Exponent specifying interaction stiffness

    Returns:
        Array of energies
    """
    del unused_dynamic_kwargs
    # safeguard: map masked near-zero distances to a huge dr so the division
    # below cannot blow up; their energy contribution then is ~0
    dr = jnp.where(dr > 1.e-7, dr, 1.e7)
    idr = (sigma / dr)
    pot_energy = epsilon * idr ** exp
    return pot_energy


def generic_repulsion_neighborlist(
        displacement_or_metric: DisplacementOrMetricFn,
        box_size: Box = None,
        species: Array = None,
        sigma: Array = 1.0,
        epsilon: Array = 1.0,
        exp: [int, Array] = 12.,
        r_onset: Array = 0.9,
        r_cutoff: Array = 1.,
        dr_threshold: float = 0.2,
        per_particle: bool = False,
        capacity_multiplier: float = 1.25,
        initialize_neighbor_list: bool = True):
    """Convenience wrapper to compute generic repulsion energy over a system
    with neighborlist.

    Provides option not to initialize neighborlist. This is useful if energy
    function needs to be initialized within a jitted function.
    """
    sigma = jnp.array(sigma, dtype=f32)
    epsilon = jnp.array(epsilon, dtype=f32)
    exp = jnp.array(exp, dtype=f32)
    r_onset = jnp.array(r_onset, dtype=f32)
    r_cutoff = jnp.array(r_cutoff, dtype=f32)

    energy_fn = smap.pair_neighbor_list(
        energy.multiplicative_isotropic_cutoff(generic_repulsion, r_onset,
                                               r_cutoff),
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        species=species,
        sigma=sigma,
        epsilon=epsilon,
        exp=exp,
        reduce_axis=(1,) if per_particle else None)

    if initialize_neighbor_list:
        # a box is required to build the neighbor list
        assert box_size is not None
        neighbor_fn = partition.neighbor_list(
            displacement_or_metric, box_size, r_cutoff, dr_threshold,
            capacity_multiplier=capacity_multiplier
        )
        return neighbor_fn, energy_fn

    return energy_fn


def generic_repulsion_nonbond(displacement_or_metric: DisplacementOrMetricFn,
                              pair_idxs: Array,
                              sigma: Array = 1.,
                              epsilon: Array = 1.,
                              exp: Array = 12.) -> Callable[[Array], Array]:
    """Convenience wrapper to compute repulsive part of Lennard Jones energy of
    particles via connection idxs.

    Args:
        displacement_or_metric: Displacement_fn
        pair_idxs: Set of pair indices (i, j) defining repulsion pairs
        sigma: sigma
        epsilon: epsilon
        exp: LJ exponent

    Returns:
        Pairwise nonbonded repulsion potential energy function.
    """
    sigma = jnp.array(sigma, f32)
    epsilon = jnp.array(epsilon, f32)
    exp = jnp.array(exp, dtype=f32)

    return smap.bond(
        generic_repulsion,
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        pair_idxs,
        ignore_unused_parameters=True,
        sigma=sigma,
        epsilon=epsilon,
        exp=exp)


def lennard_jones_nonbond(displacement_or_metric: DisplacementOrMetricFn,
                          pair_idxs: Array,
                          sigma: Array = 1.,
                          epsilon: Array = 1.) -> Callable[[Array], Array]:
    """Convenience wrapper to compute lennard jones energy of nonbonded
    particles.

    Args:
        displacement_or_metric: Displacement_fn
        pair_idxs: Set of pair indices (i, j) defining repulsion pairs
        sigma: sigma
        epsilon: epsilon

    Returns:
        Pairwise nonbonded Lennard-Jones potential energy function.
    """
    sigma = jnp.array(sigma, f32)
    epsilon = jnp.array(epsilon, f32)
    return smap.bond(
        energy.lennard_jones,
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        pair_idxs,
        ignore_unused_parameters=True,
        sigma=sigma,
        epsilon=epsilon)


def tabulated(dr: Array, spline: Callable[[Array], Array], **unused_kwargs
              ) -> Array:
    """
    Tabulated radial potential between particles given a spline function.

    Args:
        dr: An ndarray of pairwise distances between particles
        spline: A function computing the spline values at a given pairwise
                distance.

    Returns:
        Array of energies
    """

    return spline(dr)


def tabulated_neighbor_list(displacement_or_metric: DisplacementOrMetricFn,
                            x_vals: Array,
                            y_vals: Array,
                            box_size: Box,
                            degree: int = 3,
                            monotonic: bool = True,
                            r_onset: Array = 0.9,
                            r_cutoff: Array = 1.,
                            dr_threshold: Array = 0.2,
                            species: Array = None,
                            capacity_multiplier: float = 1.25,
                            initialize_neighbor_list: bool = True,
                            per_particle: bool = False,
                            fractional=True):
    """
    Convenience wrapper to compute tabulated energy using a neighbor list.

    Provides option not to initialize neighborlist. This is useful if energy
    function needs to be initialized within a jitted function.
    """

    x_vals = jnp.array(x_vals, f32)
    y_vals = jnp.array(y_vals, f32)
    box_size = jnp.array(box_size, f32)
    r_onset = jnp.array(r_onset, f32)
    r_cutoff = jnp.array(r_cutoff, f32)
    dr_threshold = jnp.array(dr_threshold, f32)

    # Note: cannot provide the spline parameters via kwargs because only
    # per-particle parameters are supported
    if monotonic:
        spline = custom_interpolate.MonotonicInterpolate(x_vals, y_vals)
    else:
        spline = custom_interpolate.InterpolatedUnivariateSpline(x_vals, y_vals,
                                                                 k=degree)
    tabulated_partial = partial(tabulated, spline=spline)

    energy_fn = smap.pair_neighbor_list(
        energy.multiplicative_isotropic_cutoff(tabulated_partial, r_onset,
                                               r_cutoff),
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        species=species,
        reduce_axis=(1,) if per_particle else None)

    if initialize_neighbor_list:
        neighbor_fn = partition.neighbor_list(
            displacement_or_metric, box_size, r_cutoff, dr_threshold,
            capacity_multiplier=capacity_multiplier,
            fractional_coordinates=fractional)
        return neighbor_fn, energy_fn
    return energy_fn
--------------------------------------------------------------------------------
/chemtrain/jax_md_mod/custom_space.py:
--------------------------------------------------------------------------------
"""Custom functions simplifying the handling of fractional coordinates."""
from typing import Union, Tuple, Callable

from jax_md import space, util
import jax.numpy as jnp

Box = Union[float, util.Array]


def _rectangular_boxtensor(box: Box) -> Box:
    """Transforms a 1-dimensional box to a 2D box tensor."""
    spatial_dim = box.shape[0]
    return jnp.eye(spatial_dim).at[jnp.diag_indices(spatial_dim)].set(box)


def init_fractional_coordinates(box: Box) -> Tuple[Box, Callable]:
    """Returns a 2D box tensor and a scale function that projects positions
    within a box in real space to the unit-hypercube as required by fractional
    coordinates.

    Args:
        box: A 1 or 2-dimensional box

    Returns:
        A tuple (box, scale_fn) of a 2D box tensor and a scale_fn that scales
        positions in real-space to the unit hypercube.
    """
    if box.ndim != 2:  # we need to transform to box tensor
        box_tensor = _rectangular_boxtensor(box)
    else:
        box_tensor = box
    inv_box_tensor = space.inverse(box_tensor)
    scale_fn = lambda positions: jnp.dot(positions, inv_box_tensor)
    return box_tensor, scale_fn
--------------------------------------------------------------------------------
/chemtrain/jax_md_mod/io.py:
--------------------------------------------------------------------------------
"""Functions for io: Loading data to and from Jax M.D."""
import jax.numpy as jnp
import mdtraj
import numpy as onp


def load_box(filename):
    """Loads initial configuration using the file loader from MDTraj.

    Args:
        filename: String providing the location of the file to load.
12 | 13 | Returns: 14 | Tuple of jnp arrays of box, coordinates, mass, and species. 15 | """ 16 | traj = mdtraj.load(filename) 17 | coordinates = traj.xyz[0] 18 | box = traj.unitcell_lengths[0] 19 | 20 | species = onp.zeros(coordinates.shape[0]) 21 | masses = onp.zeros_like(species) 22 | for atom in traj.topology.atoms: 23 | species[atom.index] = atom.element.number 24 | masses[atom.index] = atom.element.mass 25 | 26 | # _, bonds = traj.topology.to_dataframe() 27 | 28 | return (jnp.array(box), jnp.array(coordinates), jnp.array(masses), 29 | jnp.array(species, dtype=jnp.int32)) 30 | -------------------------------------------------------------------------------- /chemtrain/neural_networks.py: -------------------------------------------------------------------------------- 1 | """Neural network models for potential energy and molecular property 2 | prediction. 3 | """ 4 | from functools import partial 5 | from typing import Callable, Dict, Any, Tuple 6 | 7 | import haiku as hk 8 | from jax import numpy as jnp, nn as jax_nn 9 | from jax_md import smap, space, partition, nn, util 10 | 11 | from chemtrain import layers, sparse_graph, dropout 12 | 13 | 14 | class DimeNetPP(hk.Module): 15 | """DimeNet++ for molecular property prediction. 16 | 17 | This model takes as input a sparse representation of a molecular graph 18 | - consisting of pairwise distances and angular triplets - and predicts 19 | per-atom properties. Global properties can be obtained by summing over 20 | per-atom predictions. 21 | 22 | The default values correspond to the orinal values of DimeNet++. 23 | 24 | This custom implementation follows the original DimeNet / DimeNet++ 25 | (https://arxiv.org/abs/2011.14115), while correcting for known issues 26 | (see https://github.com/klicperajo/dimenet). 
27 | """ 28 | def __init__(self, 29 | r_cutoff: float, 30 | n_species: int, 31 | num_targets: int, 32 | kbt_dependent: bool = False, 33 | embed_size: int = 128, 34 | n_interaction_blocks: int = 4, 35 | num_residual_before_skip: int = 1, 36 | num_residual_after_skip: int = 2, 37 | out_embed_size: int = None, 38 | type_embed_size: int = None, 39 | angle_int_embed_size: int = None, 40 | basis_int_embed_size: int = 8, 41 | num_dense_out: int = 3, 42 | num_rbf: int = 6, 43 | num_sbf: int = 7, 44 | activation: Callable = jax_nn.swish, 45 | envelope_p: int = 6, 46 | init_kwargs: Dict[str, Any] = None, 47 | dropout_mode: Dict[str, Any] = None, 48 | name: str = 'DimeNetPP'): 49 | """Initializes the DimeNet++ model 50 | 51 | The default values correspond to the orinal values of DimeNet++. 52 | 53 | Args: 54 | r_cutoff: Radial cut-off distance of edges 55 | n_species: Number of different atom species the network is supposed 56 | to process. 57 | num_targets: Number of different atomic properties to predict 58 | kbt_dependent: True, if DimeNet explicitly depends on temperature. 59 | In this case 'kT' needs to be provided as a kwarg 60 | during the model call to the energy_fn. Default False 61 | results in a model independent of temperature. 62 | embed_size: Size of message embeddings. Scale interaction and output 63 | embedding sizes accordingly, if not specified 64 | explicitly. 65 | n_interaction_blocks: Number of interaction blocks 66 | num_residual_before_skip: Number of residual blocks before the skip 67 | connection in the Interaction block. 68 | num_residual_after_skip: Number of residual blocks after the skip 69 | connection in the Interaction block. 70 | out_embed_size: Embedding size of output block. 71 | If None is set to 2 * embed_size. 72 | type_embed_size: Embedding size of atom type embeddings. 73 | If None is set to 0.5 * embed_size. 74 | angle_int_embed_size: Embedding size of Linear layers for 75 | down-projected triplet interation. 
76 | If None is 0.5 * embed_size. 77 | basis_int_embed_size: Embedding size of Linear layers for interation 78 | of RBS/ SBF basis in interaction block 79 | num_dense_out: Number of final Linear layers in output block 80 | num_rbf: Number of radial Bessel embedding functions 81 | num_sbf: Number of spherical Bessel embedding functions 82 | activation: Activation function 83 | envelope_p: Power of envelope polynomial 84 | init_kwargs: Kwargs for initializaion of Linear layers 85 | dropout_mode: A dict defining which fully connected layers to apply 86 | dropout and at which rate 87 | (see dropout.dimenetpp_setup). If None, no Dropout is 88 | applied. 89 | name: Name of DimeNet++ model 90 | """ 91 | super().__init__(name=name) 92 | self.dropout_setup = dropout.dimenetpp_setup(dropout_mode, 93 | num_dense_out, 94 | n_interaction_blocks, 95 | num_residual_before_skip, 96 | num_residual_after_skip) 97 | if init_kwargs is None: 98 | init_kwargs = { 99 | 'w_init': layers.OrthogonalVarianceScalingInit(scale=1.), 100 | 'b_init': hk.initializers.Constant(0.), 101 | } 102 | 103 | # input representation: 104 | self.r_cutoff = r_cutoff 105 | self._rbf_layer = layers.RadialBesselLayer(r_cutoff, num_rbf, 106 | envelope_p) 107 | self._sbf_layer = layers.SphericalBesselLayer(r_cutoff, num_sbf, 108 | num_rbf, envelope_p) 109 | 110 | # build GNN structure 111 | self._n_interactions = n_interaction_blocks 112 | self._output_blocks = [] 113 | self._int_blocks = [] 114 | self._embedding_layer = layers.EmbeddingBlock( 115 | embed_size, n_species, type_embed_size, activation, init_kwargs, 116 | kbt_dependent) 117 | self._output_blocks.append(layers.OutputBlock( 118 | embed_size, out_embed_size, num_dense_out, num_targets, activation, 119 | init_kwargs) 120 | ) 121 | 122 | for _ in range(n_interaction_blocks): 123 | self._int_blocks.append(layers.InteractionBlock( 124 | embed_size, num_residual_before_skip, num_residual_after_skip, 125 | activation, init_kwargs, angle_int_embed_size, 126 | 
basis_int_embed_size) 127 | ) 128 | self._output_blocks.append(layers.OutputBlock( 129 | embed_size, out_embed_size, num_dense_out, num_targets, 130 | activation, init_kwargs) 131 | ) 132 | 133 | def __call__(self, 134 | graph: sparse_graph.SparseDirectionalGraph, 135 | **dyn_kwargs) -> jnp.ndarray: 136 | """Predicts per-atom quantities for a given molecular graph. 137 | 138 | Args: 139 | graph: An instance of sparse_graph.SparseDirectionalGraph defining 140 | the molecular graph connectivity. 141 | **dyn_kwargs: Kwargs supplied on-the-fly, such as 'kT' for 142 | temperature-dependent models or 'dropout_key' for 143 | Dropout. 144 | 145 | Returns: 146 | An (n_partciles, num_targets) array of predicted per-atom quantities 147 | """ 148 | dropout_key = dyn_kwargs.get('dropout_key', None) 149 | dropout_params = dropout.construct_dropout_params(dropout_key, 150 | self.dropout_setup) 151 | n_particles = graph.species.size 152 | # cutoff all non-existing edges: are encoded as 0 by rbf envelope 153 | # non-existing triplets will be masked explicitly in DimeNet++ 154 | pair_distances = jnp.where(graph.edge_mask, graph.distance_ij, 155 | 2. 
* self.r_cutoff) 156 | 157 | rbf = self._rbf_layer(pair_distances) 158 | # explicitly masked via mask array in angular_connections 159 | sbf = self._sbf_layer(pair_distances, graph.angles, graph.triplet_mask, 160 | graph.expand_to_kj) 161 | 162 | messages = self._embedding_layer(rbf, graph.species, graph.idx_i, 163 | graph.idx_j, dropout_params, 164 | **dyn_kwargs) 165 | per_atom_quantities = self._output_blocks[0]( 166 | messages, rbf, graph.idx_i, n_particles, dropout_params) 167 | 168 | for i in range(self._n_interactions): 169 | messages = self._int_blocks[i]( 170 | messages, rbf, sbf, graph.reduce_to_ji, graph.expand_to_kj, 171 | dropout_params) 172 | per_atom_quantities += self._output_blocks[i + 1]( 173 | messages, rbf, graph.idx_i, n_particles, dropout_params) 174 | return per_atom_quantities 175 | 176 | 177 | def dimenetpp_neighborlist(displacement: space.DisplacementFn, 178 | r_cutoff: float, 179 | n_species: int = 100, 180 | positions_test: jnp.ndarray = None, 181 | neighbor_test: partition.NeighborList = None, 182 | max_triplet_multiplier: float = 1.25, 183 | max_edge_multiplier: float = 1.25, 184 | **dimenetpp_kwargs 185 | ) -> Tuple[nn.InitFn, Callable[[Any, util.Array], 186 | util.Array]]: 187 | """DimeNet++ energy function for Jax, M.D. 188 | 189 | This function provides an interface for the DimeNet++ haiku model to be used 190 | as a jax_md energy_fn. Analogous to jax_md energy_fns, the initialized 191 | DimeNet++ energy_fn requires particle positions and a dense neighbor list as 192 | input - plus an array for species or other dynamic kwargs, if applicable. 193 | 194 | From particle positions and neighbor list, the sparse graph representation 195 | with edges and angle triplets is computed. Due to the constant shape 196 | requirement of jit of the neighborlist in jax_md, the neighbor list contains 197 | many masked edges, i.e. pairwise interactions that only "fill" the neighbor 198 | list, but are set to 0 during computation. 
This translates to masked edges 199 | and triplets in the sparse graph representation. 200 | 201 | For improved computational efficiency during jax_md simulations, the 202 | maximum number of edges and triplets can be estimated during model 203 | initialization. Edges and triplets beyond this maximum estimate can be 204 | capped to reduce computational and memory requirements. Capping is enabled 205 | by providing sample inputs (positions_test and neighbor_test) at 206 | initialization time. However, beware that currently, an overflow of 207 | max_edges and max_angles is not caught, as this requires passing an error 208 | code throgh jax_md simulators - analogous to the overflow detection in 209 | jax_md neighbor lists. If in doubt, increase the max edges/angles 210 | multipliers or disable capping. 211 | 212 | Args: 213 | displacement: Jax_md displacement function 214 | r_cutoff: Radial cut-off distance of DimeNetPP and the neighbor list 215 | n_species: Number of different atom species the network is supposed 216 | to process. 217 | positions_test: Sample positions to estimate max_edges / max_angles. 218 | Needs to be provided to enable capping. 219 | neighbor_test: Sample neighborlist to estimate max_edges / max_angles. 220 | Needs to be provided to enable capping. 221 | max_edge_multiplier: Multiplier for initial estimate of maximum edges. 222 | max_triplet_multiplier: Multiplier for initial estimate of maximum 223 | triplets. 224 | dimenetpp_kwargs: Kwargs to change the default structure of DimeNet++. 225 | For definition of the kwargs, see DimeNetPP. 226 | 227 | Returns: 228 | A tuple of 2 functions: A init_fn that initializes the model parameters 229 | and an energy function that computes the energy for a particular state 230 | given model parameters. The energy function requires the same input as 231 | other energy functions with neighbor lists in jax_md.energy. 
232 | """ 233 | r_cutoff = jnp.array(r_cutoff, dtype=util.f32) 234 | 235 | if positions_test is not None and neighbor_test is not None: 236 | print('Capping edges and triplets. Beware of overflow, which is' 237 | ' currently not being detected.') 238 | 239 | testgraph, _ = sparse_graph.sparse_graph_from_neighborlist( 240 | displacement, positions_test, neighbor_test, r_cutoff) 241 | max_triplets = jnp.int32(jnp.ceil(testgraph.n_triplets 242 | * max_triplet_multiplier)) 243 | max_edges = jnp.int32(jnp.ceil(testgraph.n_edges * max_edge_multiplier)) 244 | 245 | # cap maximum edges and angles to avoid overflow from multiplier 246 | n_particles, n_neighbors = neighbor_test.idx.shape 247 | max_edges = min(max_edges, n_particles * n_neighbors) 248 | max_triplets = min(max_triplets, n_particles * n_neighbors**2) 249 | else: 250 | max_triplets = None 251 | max_edges = None 252 | 253 | @hk.without_apply_rng 254 | @hk.transform 255 | def model(positions: jnp.ndarray, 256 | neighbor: partition.NeighborList, 257 | species: jnp.ndarray = None, 258 | **dynamic_kwargs) -> jnp.ndarray: 259 | """Evalues the DimeNet++ model and predicts the potential energy. 260 | 261 | Args: 262 | positions: Jax_md state-position. (N_particles x dim) array of 263 | particle positions 264 | neighbor: Jax_md dense neighbor list corresponding to positions 265 | species: (N_particles,) Array encoding atom types. If None, assumes 266 | all particles to belong to the same species 267 | **dynamic_kwargs: Dynamic kwargs, such as 'box' or 'kT'. 
def checkpoint_quantities(compute_fns):
    """Wraps every compute function with ``jax.checkpoint``.

    Mutates ``compute_fns`` in place so that intermediate activations of each
    quantity function are rematerialized on the backward pass instead of being
    stored, trading extra compute for a smaller memory footprint.
    """
    for key, fn in list(compute_fns.items()):
        compute_fns[key] = checkpoint(fn)
def _build_weights(exponents):
    """Returns weights and the effective sample size from exponents
    of the reweighting formulas in a numerically stable way.

    Args:
        exponents: Array of reweighting exponents (the logits of the
            softmax below), one entry per trajectory snapshot.

    Returns:
        Tuple (weights, n_eff) of normalized per-snapshot weights and the
        effective sample size of the reweighted ensemble.
    """

    # The reweighting scheme is a softmax, where the exponent above
    # represents the logits. To improve numerical stability and
    # guard against overflow it is good practice to subtract the
    # max of the exponent using the identity softmax(x + c) =
    # softmax(x). With all values in the exponent <=0, this
    # rules out overflow and the 0 value guarantees a denominator >=1.
    exponents -= jnp.max(exponents)
    prob_ratios = jnp.exp(exponents)
    # high-precision sum guards against accumulation error in the normalizer
    weights = prob_ratios / jax_md_util.high_precision_sum(prob_ratios)
    n_eff = _estimate_effective_samples(weights)
    return weights, n_eff


def init_pot_reweight_propagation_fns(energy_fn_template, simulator_template,
                                      neighbor_fn, timings, ref_kbt,
                                      ref_press=None, reweight_ratio=0.9,
                                      npt_ensemble=False, energy_batch_size=10):
    """
    Initializes all functions necessary for trajectory reweighting for
    a single state point.

    Initialized functions include a function that computes weights for a
    given trajectory and a function that propagates the trajectory forward
    if the statistical error does not allow a re-use of the trajectory.
    The propagation function also ensures that generated trajectories
    did not encounter any neighbor list overflow.

    Args:
        energy_fn_template: Function mapping energy parameters to an energy
            function (see its use in init_rel_entropy_gradient below).
        simulator_template: Function returning a simulator for a given
            energy function; forwarded to the trajectory generator.
        neighbor_fn: Jax_md neighbor list function, used here to re-allocate
            a larger neighbor list after a buffer overflow.
        timings: Timing specification forwarded to the trajectory generator.
        ref_kbt: Reference kbT of the state point.
        ref_press: Reference pressure, forwarded to the trajectory generator
            (relevant for the NPT ensemble).
        reweight_ratio: Minimum ratio of effective samples to trajectory
            snapshots below which the trajectory is re-computed.
        npt_ensemble: If True, additionally computes pressure as a
            reweighting quantity (currently only printed for verification).
        energy_batch_size: Batch size for evaluating energies over the
            trajectory snapshots.

    Returns:
        Tuple (init_first_traj, compute_weights, propagate) of functions.
    """
    traj_energy_fn = custom_quantity.energy_wrapper(energy_fn_template)
    reweighting_quantities = {'energy': traj_energy_fn}

    if npt_ensemble:
        # pressure currently only used to print pressure of generated trajectory
        # such that user can ensure correct statepoint of reference trajectory
        pressure_fn = custom_quantity.init_pressure(energy_fn_template)
        reweighting_quantities['pressure'] = pressure_fn

    trajectory_generator = traj_util.trajectory_generator_init(
        simulator_template, energy_fn_template, timings, reweighting_quantities)

    beta = 1. / ref_kbt
    # checkpoint quantity functions to save memory on the backward pass
    checkpoint_quantities(reweighting_quantities)

    def compute_weights(params, traj_state):
        """Computes weights for the reweighting approach.

        Returns a tuple (weights, n_eff): per-snapshot weights under the
        perturbed potential defined by params and the effective sample size.
        """

        # reweighting properties (U and pressure) under perturbed potential
        reweight_properties = traj_util.quantity_traj(
            traj_state, reweighting_quantities, params, energy_batch_size)

        # Note: Difference in pot. Energy is difference in total energy
        # as kinetic energy is the same and cancels
        exponent = -beta * (reweight_properties['energy']
                            - traj_state.aux['energy'])
        return _build_weights(exponent)

    def trajectory_identity_mapping(inputs):
        """Re-uses trajectory if no recomputation needed."""
        # inputs = (params, traj_state); params unused on this branch
        traj_state = inputs[1]
        return traj_state

    def recompute_trajectory(inputs):
        """Recomputes the reference trajectory, starting from the last
        state of the previous trajectory to save equilibration time.
        """
        params, traj_state = inputs
        # give kT here as additional input to be handed through to energy_fn
        # for kbt-dependent potentials
        updated_traj = trajectory_generator(params, traj_state.sim_state,
                                            kT=ref_kbt, pressure=ref_press)
        return updated_traj

    @jit
    def propagation_fn(params, traj_state):
        """Checks if a trajectory can be re-used. If not, a new trajectory
        is generated ensuring trajectories are always valid.
        Takes params and the traj_state as input and returns a
        trajectory valid for reweighting. A possible neighbor list overflow
        during trajectory generation is recorded in the overflow flag of the
        returned trajectory state (checked in safe_propagate_traj below).
        """
        _, n_eff = compute_weights(params, traj_state)
        n_snapshots = traj_state.aux['energy'].size
        # re-use only if enough effective samples remain after reweighting
        recompute = n_eff < reweight_ratio * n_snapshots
        propagated_state = lax.cond(recompute,
                                    recompute_trajectory,
                                    trajectory_identity_mapping,
                                    (params, traj_state))
        return propagated_state

    def safe_propagate_traj(params, traj_state):
        """Recomputes trajectory and neighbor list until a trajectory without
        overflow was obtained.

        Raises:
            RuntimeError: If overflow persists after multiple neighbor list
                re-allocations or if the restart state contains NaN positions.
        """
        reset_counter = 0
        while traj_state.overflow:
            warnings.warn('Neighborlist buffer overflowed. '
                          'Initializing larger neighborlist.')
            if reset_counter == 3:  # still overflow after multiple resets
                raise RuntimeError('Multiple neighbor list re-computations did '
                                   'not yield a trajectory without overflow. '
                                   'Consider increasing the neighbor list '
                                   'capacity multiplier.')

            # Note: We restart the simulation from the last trajectory, where
            # we know that no overflow has occurred. Re-starting from a state
            # that was generated with overflow is dangerous as overflown
            # particles could cause exploding forces once re-considered.
            last_state, _ = traj_state.sim_state
            if jnp.any(jnp.isnan(last_state.position)):
                raise RuntimeError('Last state is NaN. Currently there is no '
                                   'recovering from this. Restart from the last'
                                   ' non-overflown state might help, but comes'
                                   ' at the cost that the reference state is '
                                   'likely not representative.')
            # ndim > 2 indicates a batch of states (vmapped simulations):
            # allocate one enlarged neighbor list and broadcast the update
            if last_state.position.ndim > 2:
                single_enlarged_nbrs = util.neighbor_allocate(
                    neighbor_fn, util.tree_get_single(last_state))
                enlarged_nbrs = vmap(util.neighbor_update, (None, 0))(
                    single_enlarged_nbrs, last_state)
            else:
                enlarged_nbrs = util.neighbor_allocate(neighbor_fn, last_state)
            reset_traj_state = traj_state.replace(
                sim_state=(last_state, enlarged_nbrs))
            traj_state = recompute_trajectory((params, reset_traj_state))
            reset_counter += 1
        return traj_state

    def propagate(params, old_traj_state):
        """Wrapper around jitted propagation function that ensures that
        if neighbor list buffer overflowed, the trajectory is recomputed and
        the neighbor list size is increased until valid trajectory was obtained.
        Due to the recomputation of the neighbor list, this function cannot be
        jit.
        """
        new_traj_state = propagation_fn(params, old_traj_state)
        new_traj_state = safe_propagate_traj(params, new_traj_state)
        return new_traj_state

    def init_first_traj(params, reference_state):
        """Initializes initial trajectory to start optimization from.

        We dump the initial trajectory for equilibration, as initial
        equilibration usually takes much longer than equilibration time
        of each trajectory. If this is still not sufficient, the simulation
        should equilibrate over the course of subsequent updates.

        Returns:
            Tuple (init_traj, runtime) of the initial trajectory state and
            the generation runtime in minutes (excluding the dumped
            equilibration run).
        """
        # first run is discarded: serves only to equilibrate the system
        dump_traj = trajectory_generator(params, reference_state,
                                         kT=ref_kbt, pressure=ref_press)

        t_start = time.time()
        init_traj = trajectory_generator(params, dump_traj.sim_state,
                                         kT=ref_kbt, pressure=ref_press)
        runtime = (time.time() - t_start) / 60.  # in mins
        init_traj = safe_propagate_traj(params, init_traj)
        return init_traj, runtime

    return init_first_traj, compute_weights, propagate
class PropagationBase(max_likelihood.MLETrainerTemplate):
    """Trainer base class for shared functionality whenever (multiple)
    simulations are run during training. Can be used as a template to
    build other trainers. Currently used for DiffTRe and relative entropy.

    We only save the latest generated trajectory for each state point.
    While accumulating trajectories would enable more frequent reweighting,
    this effect is likely minor as past trajectories become exponentially
    less useful with changing potential. Additionally, saving long trajectories
    for each statepoint would increase memory requirements over the course of
    the optimization.
    """
    def __init__(self, init_trainer_state, optimizer, checkpoint_path,
                 reweight_ratio=0.9, sim_batch_size=1, energy_fn_template=None):
        """Initializes shared trainer state.

        Args:
            init_trainer_state: Initial optimization state (forwarded to the
                MLETrainerTemplate base class).
            optimizer: Optimizer instance (forwarded to the base class).
            checkpoint_path: Path for checkpoint storage (forwarded to the
                base class).
            reweight_ratio: Minimum ratio of effective samples below which
                trajectories are re-computed; forwarded per state point.
            sim_batch_size: Number of state points per batch; currently only
                1 (one at a time) or -1 (all at once) are supported.
            energy_fn_template: Energy function template (forwarded to the
                base class).
        """
        super().__init__(optimizer, init_trainer_state, checkpoint_path,
                         energy_fn_template)
        self.sim_batch_size = sim_batch_size
        self.reweight_ratio = reweight_ratio

        # store for each state point corresponding traj_state and grad_fn
        # save in distinct dicts as grad_fns need to be deleted for checkpoint
        self.grad_fns, self.trajectory_states, self.statepoints = {}, {}, {}
        self.n_statepoints = 0
        self.shuffle_key = random.PRNGKey(0)

    def _init_statepoint(self, reference_state, energy_fn_template,
                         simulator_template, neighbor_fn, timings, kbt,
                         set_key=None, energy_batch_size=10,
                         initialize_traj=True, ref_press=None):
        """Initializes the simulation and reweighting functions as well
        as the initial trajectory for a statepoint.

        Args:
            reference_state: Tuple (sim_state, nbrs) to start simulations from.
            energy_fn_template: Energy function template for this state point.
            simulator_template: Simulator template for this state point.
            neighbor_fn: Jax_md neighbor list function.
            timings: Timing specification for trajectory generation.
            kbt: Reference kbT of the state point.
            set_key: Optional custom key for the state point; defaults to a
                running integer index.
            energy_batch_size: Batch size for trajectory energy evaluation.
            initialize_traj: If True, generates an initial trajectory. Only
                set False when a checkpoint is loaded afterwards.
            ref_press: Reference pressure (NPT ensemble only).

        Returns:
            Tuple (key, compute_weights, propagate) of the state point key
            and its reweighting / propagation functions.
        """
        if set_key is not None:
            key = set_key
            # only count the state point if the key was not used before
            if set_key not in self.statepoints:
                self.n_statepoints += 1
        else:
            key = self.n_statepoints
            self.n_statepoints += 1
        self.statepoints[key] = {'kbT': kbt}
        npt_ensemble = util.is_npt_ensemble(reference_state[0])
        if npt_ensemble:
            self.statepoints[key]['pressure'] = ref_press

        gen_init_traj, compute_weights, propagate = \
            init_pot_reweight_propagation_fns(energy_fn_template,
                                              simulator_template,
                                              neighbor_fn,
                                              timings,
                                              kbt,
                                              ref_press,
                                              self.reweight_ratio,
                                              npt_ensemble,
                                              energy_batch_size)
        if initialize_traj:
            init_traj, runtime = gen_init_traj(self.params, reference_state)
            print(f'Time for trajectory initialization {key}: {runtime} mins')
            self.trajectory_states[key] = init_traj
        else:
            print('Not initializing the initial trajectory is only valid if '
                  'a checkpoint is loaded. In this case, please be sure to add '
                  'state points in the same sequence, otherwise loaded '
                  'trajectories will not match their respective simulations.')

        return key, compute_weights, propagate

    @abstractmethod
    def add_statepoint(self, *args, **kwargs):
        """User interface to add additional state point to train model on."""
        raise NotImplementedError()

    @property
    def params(self):
        """Current energy parameters stored in the optimization state."""
        return self.state.params

    @params.setter
    def params(self, loaded_params):
        self.state = self.state.replace(params=loaded_params)

    def get_sim_state(self, key):
        """Returns the current simulation state of state point 'key'."""
        return self.trajectory_states[key].sim_state

    def _get_batch(self):
        """Helper function to re-shuffle simulations and split into batches."""
        self.shuffle_key, used_key = random.split(self.shuffle_key, 2)
        shuffled_indices = random.permutation(used_key, self.n_statepoints)
        if self.sim_batch_size == 1:
            batch_list = jnp.split(shuffled_indices, shuffled_indices.size)
        elif self.sim_batch_size == -1:
            batch_list = jnp.split(shuffled_indices, 1)
        else:
            raise NotImplementedError('Only batch_size = 1 or -1 implemented.')

        return (batch.tolist() for batch in batch_list)

    def _print_measured_statepoint(self):
        """Print measured kbT (and pressure for npt ensemble) for all
        statepoints to ensure the simulation is indeed carried out at the
        prescribed state point.
        """
        for sim_key, traj in self.trajectory_states.items():
            statepoint = self.statepoints[sim_key]
            measured_kbt = jnp.mean(traj.aux['kbT'])
            if 'pressure' in statepoint:  # NPT
                measured_press = jnp.mean(traj.aux['pressure'])
                press_print = (f' press = {measured_press:.2f} ref_press = '
                               f'{statepoint["pressure"]:.2f}')
            else:
                press_print = ''
            print(f'Statepoint {sim_key}: kbT = {measured_kbt:.3f} ref_kbt = '
                  f'{statepoint["kbT"]:.3f}' + press_print)

    def train(self, max_epochs, thresh=None, checkpoint_freq=None):
        """Starts training; requires at least one registered state point."""
        assert self.n_statepoints > 0, ('Add at least 1 state point via '
                                        '"add_statepoint" to start training.')
        super().train(max_epochs, thresh=thresh,
                      checkpoint_freq=checkpoint_freq)

    @abstractmethod
    def _update(self, batch):
        """Implementation of gradient computation, stepping of the optimizer
        and logging of auxiliary results. Takes batch of simulation indices
        as input.
        """
21 | 22 | Attributes: 23 | distance_ij: A (N_edges,) array storing for each edge the radial distance 24 | between particles i and j 25 | idx_i: A (N_edges,) array storing for each edge particle index i 26 | idx_j: A (N_edges,) array storing for each edge particle index j 27 | angles: A (N_triplets,) array storing for each triplet the angle formed 28 | by the 3 particles 29 | reduce_to_ji: A (N_triplets,) array storing for each triplet kji edge 30 | index j->i to aggregate messages via a segment_sum: each 31 | m_ji is a distinct segment containing all incoming m_kj. 32 | expand_to_kj: A (N_triplets,) array storing for each triplet kji edge 33 | index k->j to gather all incoming edges for message 34 | passing. 35 | edge_mask: A (N_edges,) boolean array storing for each edge whether the 36 | edge exists. By default, all edges are considered. 37 | triplet_mask: A (N_triplets,) boolean array storing for each triplet 38 | whether the triplet exists. By default, all triplets are 39 | considered. 40 | n_edges: Number of non-masked edges in the graph. None assumes all 41 | edges are real. 42 | n_triplets: Number of non-masked triplets in the graph. None assumes 43 | all triplets are real. 44 | n_particles: Number of non-masked particles in the graph. 
    # required graph arrays: inputs to DimeNetPP (see class docstring)
    species: jnp.ndarray
    distance_ij: jnp.ndarray
    idx_i: jnp.ndarray
    idx_j: jnp.ndarray
    angles: jnp.ndarray
    reduce_to_ji: jnp.ndarray
    expand_to_kj: jnp.ndarray
    # optional masks: replaced by all-True defaults in __post_init__
    species_mask: Optional[jnp.ndarray] = None
    edge_mask: Optional[jnp.ndarray] = None
    triplet_mask: Optional[jnp.ndarray] = None
    # counts of real (non-masked) entities; None means all are real
    n_edges: Optional[int] = None
    n_triplets: Optional[int] = None

    def __post_init__(self):
        # default masks: mark every particle / edge / triplet as existing
        if self.species_mask is None:
            self.species_mask = jnp.ones_like(self.species, dtype=bool)
        if self.edge_mask is None:
            self.edge_mask = jnp.ones_like(self.distance_ij, dtype=bool)
        if self.triplet_mask is None:
            self.triplet_mask = jnp.ones_like(self.angles, dtype=bool)

    @property
    def n_particles(self):
        # number of real (non-masked) particles in the graph
        return jnp.sum(self.species_mask)

    def to_dict(self):
        """Returns the stored graph data as a dictionary of arrays.
        This format is often beneficial for dataloaders.

        Note: n_edges and n_triplets are intentionally omitted; from_dict
        re-initializes them to their defaults (None).
        """
        return {
            'species': self.species,
            'distance_ij': self.distance_ij,
            'idx_i': self.idx_i,
            'idx_j': self.idx_j,
            'angles': self.angles,
            'reduce_to_ji': self.reduce_to_ji,
            'expand_to_kj': self.expand_to_kj,
            'species_mask': self.species_mask,
            'edge_mask': self.edge_mask,
            'triplet_mask': self.triplet_mask
        }

    @classmethod
    def from_dict(cls, graph_dict):
        """Initializes instance from dictionary containing all necessary keys
        for initialization.

        Extra keys that are not constructor parameters are ignored.
        """
        return cls(**{
            key: value for key, value in graph_dict.items()
            if key in inspect.signature(cls).parameters
        })

    def cap_exactly(self):
        """Deletes all non-existing edges and triplets from the stored graph.

        This is a non-pure function and hence not available in a jit-context.
        Returning the capped graph does not solve the problem when n_edges
        and n_triplets are computed within the jit-compiled function.
        """
        # edges are sorted, hence all non-existing edges are at the end
        self.species = self.species[:self.n_particles]
        self.species_mask = self.species_mask[:self.n_particles]

        self.distance_ij = self.distance_ij[:self.n_edges]
        self.idx_i = self.idx_i[:self.n_edges]
        self.idx_j = self.idx_j[:self.n_edges]
        self.edge_mask = self.edge_mask[:self.n_edges]

        self.angles = self.angles[:self.n_triplets]
        self.reduce_to_ji = self.reduce_to_ji[:self.n_triplets]
        self.expand_to_kj = self.expand_to_kj[:self.n_triplets]
        self.triplet_mask = self.triplet_mask[:self.n_triplets]
def angle(r_ij, r_kj):
    """Returns the angle (kj, ij) between two vectors in [0, pi].

    Computed via :math:`\\tan(\\theta) = |r_{ij} \\times r_{kj}| /
    (r_{ij} \\cdot r_{kj})` using arctan2, which selects the correct
    quadrant. Beware: arctan2(0, 0) is not differentiable.

    Args:
        r_ij: Vector pointing to position of particle i from particle j
        r_kj: Vector pointing to position of particle k from particle j

    Returns:
        Angle between vectors
    """
    sin_term = jnp.linalg.norm(jnp.cross(r_ij, r_kj))
    cos_term = jnp.dot(r_ij, r_kj)
    return jnp.arctan2(sin_term, cos_term)


def safe_angle_mask(r_ji, r_kj, angle_mask):
    """Substitutes masked triplet vectors to keep the angle differentiable.

    Masked entries are replaced by fixed orthogonal unit vectors, i.e. the
    resulting angle for masked triplets is pi/2.

    Args:
        r_ji: Array (N_triplets, dim) of vectors pointing to position of
            particle i from particle j
        r_kj: Array (N_triplets, dim) of vectors pointing to position of
            particle k from particle j
        angle_mask: (N_triplets,) or (N_triplets, 1) boolean mask; False for
            triplets that need to be masked.

    Returns:
        Tuple (r_ji_safe, r_kj_safe) with masked triplets replaced such that
        the angle between them is pi/2.
    """
    # broadcast the mask over the vector components if it is 1D
    mask = angle_mask if angle_mask.ndim > 1 else angle_mask[:, jnp.newaxis]
    fallback_ji = jnp.array([1., 0., 0.], dtype=jnp.float32)
    fallback_kj = jnp.array([0., 1., 0.], dtype=jnp.float32)
    return (jnp.where(mask, r_ji, fallback_ji),
            jnp.where(mask, r_kj, fallback_kj))


def angle_triplets(positions, displacement_fn, angle_idxs, angle_mask):
    """Returns the angle in [0, pi] for every triplet of particles.

    Masked angles are set to pi/2 via safe_angle_mask.

    Args:
        positions: Array of particle positions (N_particles x 3)
        displacement_fn: Jax_md displacement function
        angle_idxs: Array of particle indices forming triplets
            (N_triplets x 3)
        angle_mask: Boolean mask per triplet; False for masked triplets.

    Returns:
        A (N_triplets,) array with the angle for each triplet.
    """
    pos_i = positions[angle_idxs[:, 0]]
    pos_j = positions[angle_idxs[:, 1]]
    pos_k = positions[angle_idxs[:, 2]]

    # Note: The original DimeNet implementation uses R_ji; r_ij is the
    # correct vector for the angle between both vectors (known DimeNet
    # issue). Displacements respect periodic BCs.
    batched_displacement = vmap(displacement_fn)
    r_ij = batched_displacement(pos_i, pos_j)  # r_i - r_j under PBCs
    r_kj = batched_displacement(pos_k, pos_j)
    # mask co-linear r_ij / r_kj pairs: they generate NaNs on the
    # backward pass otherwise
    r_ij_safe, r_kj_safe = safe_angle_mask(r_ij, r_kj, angle_mask)
    return vmap(angle)(r_ij_safe, r_kj_safe)


def _flatten_sort_and_capp(matrix, sorting_args, cap_size):
    """Flattens a 2D array, reorders it via ``sorting_args`` (usually from
    argsort) and keeps only the first ``cap_size`` entries. Used to delete
    non-existing edges; returns the capped vector.
    """
    return jnp.ravel(matrix)[sorting_args][:cap_size]
def sparse_graph_from_neighborlist(displacement_fn: Callable,
                                   positions: jnp.ndarray,
                                   neighbor: partition.NeighborList,
                                   r_cutoff: jnp.array,
                                   species: Optional[jnp.ndarray] = None,
                                   max_edges: Optional[int] = None,
                                   max_triplets: Optional[int] = None,
                                   species_mask: Optional[jnp.ndarray] = None,
                                   ) -> Tuple[SparseDirectionalGraph, bool]:
    """Constructs a sparse representation of graph edges and angles to save
    memory and computations over neighbor list.

    The speed-up over simply using the dense jax_md neighbor list is
    significant, particularly regarding triplets. To allow for a representation
    of constant size required by jit, we pad the resulting vectors.

    Args:
        displacement_fn: Jax_MD displacement function encoding box dimensions
        positions: (N_particles, dim) array of particle positions
        neighbor: Jax_MD neighbor list that is in sync with positions
        r_cutoff: Radial cutoff distance, below which 2 particles are considered
            to be connected by an edge.
        species: (N_particles,) array encoding atom types. If None, assumes
            type 0 for all atoms.
        max_edges: Maximum number of edges storable in the graph. Can be used
            to reduce the number of padded edges, but should be used
            carefully, such that no existing edges are capped. Default
            None uses the maximum possible number of edges as given by
            the dense neighbor list.
        max_triplets: Maximum number of triplets storable in the graph. Can be
            used to reduce the number of padded triplets, but should be
            used carefully, such that no existing triplets are capped.
            Default None uses the maximum possible number of triplets as
            given by the dense neighbor list.
        species_mask: (N_particles,) array encoding atom types. Default None,
            assumes no masking necessary.

    Returns:
        Tuple (sparse_graph, too_many_edges_error_code) containing the
        SparseDirectionalGraph and whether max_edges or max_triplets
        overflowed.
    """
    assert neighbor.format.name == 'Dense', ('Currently only dense neighbor'
                                             ' lists supported.')
    n_particles, max_neighbors = neighbor.idx.shape
    species = _canonicalize_species(species, n_particles)

    neighbor_displacement_fn = space.map_neighbor(displacement_fn)

    # compute pairwise distances
    pos_neigh = positions[neighbor.idx]
    pair_displacement = neighbor_displacement_fn(positions, pos_neigh)
    pair_distances = space.distance(pair_displacement)

    # compute adjacency matrix via neighbor_list, then build sparse graph
    # representation to avoid part of padding overhead in dense neighborlist
    # adds all edges > cut-off to masked edges
    edge_idx_ji = jnp.where(pair_distances < r_cutoff, neighbor.idx,
                            n_particles)
    # neighbor.idx: an index j in row i encodes a directed edge from
    # particle j to particle i.
    # edge_idx[i, j]: j->i. if j == N: encodes masked edge.
    # Index N would index out-of-bounds, but in jax the last element is
    # returned instead

    # conservative estimates for initialization run
    # use guess from initialization for tighter bound to save memory and
    # computations during production runs
    if max_edges is None:
        max_edges = n_particles * max_neighbors
    if max_triplets is None:
        max_triplets = max_edges * max_neighbors

    # sparse edge representation:
    # construct vectors from adjacency matrix and only keep existing edges
    # Target node (i) and source (j) of edges
    pair_mask = edge_idx_ji != n_particles  # non-existing neighbor encoded as N
    # due to undirectedness, each edge is included twice
    n_edges = jnp.count_nonzero(pair_mask)
    pair_mask_flat = jnp.ravel(pair_mask)
    # non-existing edges are sorted to the end for capping
    # (argsort of ~mask: True/existing entries compare smaller and come first)
    sorting_idxs = jnp.argsort(~pair_mask_flat)
    # yy[i, j] = i: row index (target particle i) for every dense edge slot
    _, yy = jnp.meshgrid(jnp.arange(max_neighbors), jnp.arange(n_particles))  # pylint: disable=unbalanced-tuple-unpacking
    idx_i = _flatten_sort_and_capp(yy, sorting_idxs, max_edges)
    idx_j = _flatten_sort_and_capp(edge_idx_ji, sorting_idxs, max_edges)
    d_ij = _flatten_sort_and_capp(pair_distances, sorting_idxs, max_edges)
    sparse_pair_mask = _flatten_sort_and_capp(pair_mask_flat, sorting_idxs,
                                              max_edges)

    # build sparse angle combinations from adjacency matrix:
    # angle defined for 3 particles with connections k->j and j->i
    # directional message passing accumulates all k->j to update each m_ji
    idx3_i = jnp.repeat(idx_i, max_neighbors)
    idx3_j = jnp.repeat(idx_j, max_neighbors)
    # retrieves for each j in idx_j its neighbors k: stored in 2nd axis
    idx3_k_mat = edge_idx_ji[idx_j]
    idx3_k = idx3_k_mat.ravel()
    angle_idxs = jnp.column_stack([idx3_i, idx3_j, idx3_k])

    # masking:
    # k and j are different particles, by edge_idx_ji construction.
    # The same applies to j - i, except for masked ones
    mask_i_eq_k = idx3_i != idx3_k
    # mask for ij known a priori
    mask_ij = jnp.repeat(sparse_pair_mask, max_neighbors)
    mask_k = idx3_k != n_particles
    angle_mask = mask_ij * mask_k * mask_i_eq_k  # union of masks
    # top_k moves existing triplets (mask True) to the front and caps
    angle_mask, sorting_idx3 = lax.top_k(angle_mask, max_triplets)
    angle_idxs = angle_idxs[sorting_idx3]
    n_triplets = jnp.count_nonzero(angle_mask)
    angles = angle_triplets(positions, displacement_fn, angle_idxs, angle_mask)

    # retrieving edge_id m_ji from nodes i and j:
    # idx_i < N by construction, but idx_j can be N: will override
    # lookup[i, N-1], which is problematic if [i, N-1] is an existing edge.
    # Hence, the lookup table is extended by 1.
    edge_id_lookup = jnp.zeros([n_particles, n_particles + 1], dtype=jnp.int32)
    edge_id_lookup_direct = edge_id_lookup.at[(idx_i, idx_j)].set(
        jnp.arange(max_edges))

    # stores for each angle kji edge index j->i to aggregate messages via a
    # segment_sum: each m_ji is a distinct segment containing all incoming m_kj
    reduce_to_ji = edge_id_lookup_direct[(angle_idxs[:, 0], angle_idxs[:, 1])]
    # stores for each angle kji edge index k->j to gather all incoming edges
    # for message passing
    expand_to_kj = edge_id_lookup_direct[(angle_idxs[:, 1], angle_idxs[:, 2])]

    # True if the requested caps were too small for the real edges / triplets
    too_many_edges_error_code = lax.cond(
        jnp.bitwise_or(n_edges > max_edges, n_triplets > max_triplets),
        lambda _: True, lambda _: False, n_edges
    )

    sparse_graph = SparseDirectionalGraph(
        species=species, distance_ij=d_ij, idx_i=idx_i, idx_j=idx_j,
        angles=angles, reduce_to_ji=reduce_to_ji, expand_to_kj=expand_to_kj,
        edge_mask=sparse_pair_mask, triplet_mask=angle_mask, n_edges=n_edges,
        n_triplets=n_triplets, species_mask=species_mask
    )
    return sparse_graph, too_many_edges_error_code
_canonicalize_species(species, n_particles): 349 | """Ensures species are integer and initializes species to 0 if species=None. 350 | 351 | Args: 352 | species: (N_particles,) array of atom types or None 353 | n_particles: Number of particles 354 | 355 | Returns: 356 | Integer species array. 357 | """ 358 | if species is None: 359 | species = jnp.zeros(n_particles, dtype=jnp.int32) 360 | else: 361 | smap._check_species_dtype(species) # assert species are int 362 | return species -------------------------------------------------------------------------------- /chemtrain/trainers.py: -------------------------------------------------------------------------------- 1 | """This file contains several Trainer classes as a quickstart for users.""" 2 | from jax import numpy as jnp 3 | from jax_sgmc import data 4 | from jax_sgmc.data import numpy_loader 5 | 6 | from chemtrain import (util, force_matching, reweighting, 7 | max_likelihood) 8 | 9 | 10 | class ForceMatching(max_likelihood.DataParallelTrainer): 11 | """Force-matching trainer. 12 | 13 | This implementation assumes a constant number of particles per box and 14 | constant box sizes for each snapshot. 15 | If this is not the case, please use the ForceMatchingPrecomputed trainer 16 | based on padded sparse neighborlists. 17 | Caution: Currently neighborlist overflow is not checked. 18 | Make sure to build nbrs_init large enough. 19 | 20 | Virial data is pressure tensor, i.e. 
negative stress tensor 21 | """ 22 | def __init__(self, init_params, energy_fn_template, nbrs_init, 23 | optimizer, position_data, energy_data=None, force_data=None, 24 | virial_data=None, box_tensor=None, gamma_f=1., gamma_p=1.e-6, 25 | batch_per_device=1, batch_cache=10, train_ratio=0.7, 26 | val_ratio=0.1, shuffle=False, 27 | convergence_criterion='window_median', 28 | checkpoint_folder='Checkpoints'): 29 | 30 | checkpoint_path = 'output/force_matching/' + str(checkpoint_folder) 31 | dataset_dict = {'position_data': position_data, 32 | 'energy_data': energy_data, 33 | 'force_data': force_data, 34 | 'virial_data': virial_data 35 | } 36 | 37 | virial_fn = force_matching.init_virial_fn( 38 | virial_data, energy_fn_template, box_tensor) 39 | model = force_matching.init_model(nbrs_init, energy_fn_template, 40 | virial_fn) 41 | loss_fn = force_matching.init_loss_fn(gamma_f=gamma_f, gamma_p=gamma_p) 42 | 43 | super().__init__(dataset_dict, loss_fn, model, init_params, optimizer, 44 | checkpoint_path, batch_per_device, batch_cache, 45 | train_ratio, val_ratio, shuffle=shuffle, 46 | convergence_criterion=convergence_criterion, 47 | energy_fn_template=energy_fn_template) 48 | self._virial_fn = virial_fn 49 | self._nbrs_init = nbrs_init 50 | self._init_test_fn() 51 | 52 | @staticmethod 53 | def _build_dataset(position_data, energy_data=None, force_data=None, 54 | virial_data=None): 55 | return force_matching.build_dataset(position_data, energy_data, 56 | force_data, virial_data) 57 | 58 | def evaluate_mae_testset(self): 59 | assert self.test_loader is not None, ('No test set available. 
Check' 60 | ' train and val ratios or add a' 61 | ' test_loader manually.') 62 | maes = self.mae_fn(self.best_inference_params_replicated) 63 | for key, mae_value in maes.items(): 64 | print(f'{key}: MAE = {mae_value:.4f}') 65 | 66 | def _init_test_fn(self): 67 | if self.test_loader is not None: 68 | self.mae_fn, data_release_fn = force_matching.init_mae_fn( 69 | self.test_loader, self._nbrs_init, 70 | self.reference_energy_fn_template, self.batch_size, 71 | self.batch_cache, self._virial_fn 72 | ) 73 | self.release_fns.append(data_release_fn) 74 | else: 75 | self.mae_fn = None 76 | 77 | 78 | class RelativeEntropy(reweighting.PropagationBase): 79 | """Trainer for relative entropy minimization.""" 80 | def __init__(self, init_params, optimizer, 81 | reweight_ratio=0.9, sim_batch_size=1, energy_fn_template=None, 82 | convergence_criterion='window_median', 83 | checkpoint_folder='Checkpoints'): 84 | """ 85 | Initializes a relative entropy trainer instance. 86 | 87 | Uses first order method optimizer as Hessian is very expensive 88 | for neural networks. Both reweighting and the gradient formula 89 | currently assume a NVT ensemble. 90 | 91 | Args: 92 | init_params: Initial energy parameters 93 | optimizer: Optimizer from optax 94 | reweight_ratio: Ratio of reference samples required for n_eff to 95 | surpass to allow re-use of previous reference 96 | trajectory state. If trajectories should not be 97 | re-used, a value > 1 can be specified. 98 | sim_batch_size: Number of state-points to be processed as a single 99 | batch. Gradients will be averaged over the batch 100 | before stepping the optimizer. 101 | energy_fn_template: Function that takes energy parameters and 102 | initializes an new energy function. Here, the 103 | energy_fn_template is only a reference that 104 | will be saved alongside the trainer. Each 105 | state point requires its own due to the 106 | dependence on the box size via the displacement 107 | function, which can vary between state points. 
108 | convergence_criterion: Either 'max_loss' or 'ave_loss'. 109 | If 'max_loss', stops if the gradient norm 110 | across all batches in the epoch is smaller 111 | than convergence_thresh. 'ave_loss' evaluates 112 | the average gradient norm across the batch. 113 | For a single state point, both are 114 | equivalent. A criterion based on the rolling 115 | standard deviation 'std' might be implemented 116 | in the future. 117 | checkpoint_folder: Name of folders to store checkpoints in. 118 | """ 119 | 120 | checkpoint_path = 'output/rel_entropy/' + str(checkpoint_folder) 121 | init_trainer_state = util.TrainerState( 122 | params=init_params, opt_state=optimizer.init(init_params)) 123 | super().__init__(init_trainer_state, optimizer, checkpoint_path, 124 | reweight_ratio, sim_batch_size, energy_fn_template) 125 | 126 | # in addition to the standard trajectory state, we also need to keep 127 | # track of dataloader states for reference snapshots 128 | self.data_states = {} 129 | 130 | self.early_stop = max_likelihood.EarlyStopping(self.params, 131 | convergence_criterion) 132 | 133 | def _set_dataset(self, key, reference_data, reference_batch_size, 134 | batch_cache=1): 135 | """Set dataset and loader corresponding to current state point.""" 136 | reference_loader = numpy_loader.NumpyDataLoader(R=reference_data, 137 | copy=False) 138 | init_reference_batch, get_ref_batch, _ = data.random_reference_data( 139 | reference_loader, batch_cache, reference_batch_size) 140 | init_reference_batch_state = init_reference_batch(shuffle=True) 141 | self.data_states[key] = init_reference_batch_state 142 | return get_ref_batch 143 | 144 | def add_statepoint(self, reference_data, energy_fn_template, 145 | simulator_template, neighbor_fn, timings, kbt, 146 | reference_state, reference_batch_size=None, 147 | batch_cache=1, initialize_traj=True, set_key=None, 148 | vmap_batch=10): 149 | """ 150 | Adds a state point to the pool of simulations. 
151 | 152 | As each reference dataset / trajectory corresponds to a single 153 | state point, we initialize the dataloader together with the 154 | simulation. 155 | 156 | Currently only supports NVT simulations. 157 | 158 | Args: 159 | reference_data: De-correlated reference trajectory 160 | energy_fn_template: Function that takes energy parameters and 161 | initializes a new energy function. 162 | simulator_template: Function that takes an energy function and 163 | returns a simulator function. 164 | neighbor_fn: Neighbor function 165 | timings: Instance of TimingClass containing information 166 | about the trajectory length and which states to retain 167 | kbt: Temperature in kbT 168 | reference_state: Tuple of initial simulation state and neighbor list 169 | reference_batch_size: Batch size of dataloader for reference 170 | trajectory. If None, will use the same number 171 | of snapshots as generated via the optimizer. 172 | batch_cache: Number of reference batches to cache in order to 173 | minimize host-device communication. Make sure the 174 | cached data size does not exceed the full dataset size. 175 | initialize_traj: True, if an initial trajectory should be generated. 176 | Should only be set to False if a checkpoint is 177 | loaded before starting any training. 178 | set_key: Specify a key in order to restart from same statepoint. 179 | By default, uses the index of the sequence statepoints are 180 | added, i.e. self.trajectory_states[0] for the first added 181 | statepoint. Can be used for changing the timings of the 182 | simulation during training. 183 | vmap_batch: Batch size of vmapping of per-snapshot energy and 184 | gradient calculation. 185 | """ 186 | if reference_batch_size is None: 187 | print('No reference batch size provided. 
Using number of generated' 188 | ' CG snapshots by default.') 189 | states_per_traj = jnp.size(timings.t_production_start) 190 | if reference_state[0].position.ndim > 2: 191 | n_trajctories = reference_state[0].position.shape[0] 192 | reference_batch_size = n_trajctories * states_per_traj 193 | else: 194 | reference_batch_size = states_per_traj 195 | 196 | key, weights_fn, propagate = self._init_statepoint(reference_state, 197 | energy_fn_template, 198 | simulator_template, 199 | neighbor_fn, 200 | timings, 201 | kbt, 202 | set_key, 203 | vmap_batch, 204 | initialize_traj) 205 | 206 | reference_dataloader = self._set_dataset(key, 207 | reference_data, 208 | reference_batch_size, 209 | batch_cache) 210 | 211 | grad_fn = reweighting.init_rel_entropy_gradient( 212 | energy_fn_template, weights_fn, kbt, vmap_batch) 213 | 214 | def propagation_and_grad(params, traj_state, batch_state): 215 | """Propagates the trajectory, if necessary, and computes the 216 | gradient via the relative entropy formalism. 
217 | """ 218 | traj_state = propagate(params, traj_state) 219 | new_batch_state, reference_batch = reference_dataloader(batch_state) 220 | reference_positions = reference_batch['R'] 221 | grad = grad_fn(params, traj_state, reference_positions) 222 | return traj_state, grad, new_batch_state 223 | 224 | self.grad_fns[key] = propagation_and_grad 225 | 226 | def _update(self, batch): 227 | """Updates the potential using the gradient from relative entropy.""" 228 | grads = [] 229 | for sim_key in batch: 230 | grad_fn = self.grad_fns[sim_key] 231 | 232 | self.trajectory_states[sim_key], curr_grad, \ 233 | self.data_states[sim_key] = grad_fn(self.params, 234 | self.trajectory_states[sim_key], 235 | self.data_states[sim_key]) 236 | grads.append(curr_grad) 237 | 238 | batch_grad = util.tree_mean(grads) 239 | self._step_optimizer(batch_grad) 240 | self.gradient_norm_history.append(util.tree_norm(batch_grad)) 241 | 242 | def _evaluate_convergence(self, duration, thresh): 243 | curr_grad_norm = self.gradient_norm_history[-1] 244 | print(f'\nEpoch {self._epoch}: Elapsed time = {duration:.3f} min') 245 | self._print_measured_statepoint() 246 | 247 | self._converged = self.early_stop.early_stopping(curr_grad_norm, thresh, 248 | save_best_params=False) 249 | -------------------------------------------------------------------------------- /chemtrain/traj_quantity.py: -------------------------------------------------------------------------------- 1 | """Molecular dynamics observable functions acting on trajectories rather than 2 | single snapshots. 3 | 4 | Builds on the TrajectoryState object defined in traj_util.py. 5 | """ 6 | from jax import numpy as jnp 7 | 8 | 9 | def init_traj_mean_fn(quantity_key): 10 | """Initializes the 'traj_fn' for the DiffTRe 'target' dict for simple 11 | trajectory-averaged observables. 
12 | 13 | This function builds the 'traj_fn' of the DiffTRe 'target' dict for the 14 | common case of target observables that simply consist of a 15 | trajectory-average of instantaneous quantities, such as RDF, ADF, pressure 16 | or density. 17 | 18 | This function also serves as a template on how to build the 'traj_fn' for 19 | observables that are a general function of one or many instantaneous 20 | quantities, such as stiffness via the stress-fluctuation method or 21 | fluctuation formulas in this module. The 'traj_fn' receives a dict of all 22 | quantity trajectories as input under the same keys as instantaneous 23 | quantities are defined in 'quantities'. The 'traj_fn' then returns the 24 | ensemble-averaged quantity, possibly taking advantage of fluctuation 25 | formulas defined in the traj_quantity module. 26 | 27 | Args: 28 | quantity_key: Quantity key used in 'quantities' to generate the 29 | quantity trajectory at hand, to be averaged over. 30 | 31 | Returns: 32 | The 'traj_fn' to be used in building the 'targets' dict for DiffTRe. 33 | """ 34 | def traj_mean(quantity_trajs): 35 | quantity_traj = quantity_trajs[quantity_key] 36 | return jnp.mean(quantity_traj, axis=0) 37 | return traj_mean 38 | -------------------------------------------------------------------------------- /chemtrain/traj_util.py: -------------------------------------------------------------------------------- 1 | """Utility functions to process whole MD trajectories rather than 2 | single snapshots. 3 | """ 4 | from functools import partial 5 | from typing import Any, Dict 6 | 7 | import chex 8 | from jax import lax, jit, vmap, numpy as jnp 9 | from jax_md import simulate, util as jax_md_util 10 | 11 | from chemtrain import util 12 | from chemtrain.jax_md_mod import custom_quantity 13 | 14 | Array = jax_md_util.Array 15 | 16 | 17 | @partial(chex.dataclass, frozen=True) 18 | class TimingClass: 19 | """A dataclass containing run-times for the simulation. 
20 | 21 | Attributes: 22 | t_equilib_start: Starting time of all printouts that will be dumped 23 | for equilibration 24 | t_production_start: Starting time of all runs that result in a printout 25 | t_production_end: Generation time of all printouts 26 | timesteps_per_printout: Number of simulation timesteps to run forward 27 | from each starting time 28 | time_step: Simulation time step 29 | """ 30 | t_equilib_start: Array 31 | t_production_start: Array 32 | t_production_end: Array 33 | timesteps_per_printout: int 34 | time_step: float 35 | 36 | 37 | @partial(chex.dataclass, frozen=True) 38 | class TrajectoryState: 39 | """A dataclass storing information of a generated trajectory. 40 | 41 | Attributes: 42 | sim_state: Last simulation state, a tuple of last state and nbrs 43 | trajectory: Generated trajectory 44 | overflow: True if neighbor list overflowed during trajectory generation 45 | thermostat_kbT: Target thermostat kbT at time of respective snapshots 46 | barostat_press: Target barostat pressure at time of respective snapshots 47 | aux: Dict of auxilary per-snapshot quantities as defined by quantities 48 | in trajectory generator. 49 | """ 50 | sim_state: Any 51 | trajectory: Any 52 | overflow: Array = False 53 | thermostat_kbt: Array = None 54 | barostat_press: Array = None 55 | aux: Dict = None 56 | 57 | 58 | def process_printouts(time_step, total_time, t_equilib, print_every, t_start=0): 59 | """Initializes a dataclass containing information for the simulator 60 | on simulation time and saving states. 61 | 62 | This function is not jitable as array sizes depend on input values. 63 | 64 | Args: 65 | time_step: Time step size 66 | total_time: Total simulation run length 67 | t_equilib: Equilibration run length 68 | print_every: Time after which a state is saved 69 | t_start: Starting time. Only relevant for time-dependent 70 | thermostat/barostat. 71 | 72 | Returns: 73 | A class containing information for the simulator 74 | on which states to save. 
75 | """ 76 | assert total_time > 0. and t_equilib > 0., 'Times need to be positive.' 77 | assert total_time > t_equilib, ('Total time needs to exceed equilibration ' 78 | 'time, otherwise no trajectory will be ' 79 | 'sampled.') 80 | timesteps_per_printout = int(print_every / time_step) 81 | n_production = int((total_time - t_equilib) / print_every) 82 | n_dumped = int(t_equilib / print_every) 83 | equilibration_t_start = jnp.arange(n_dumped) * print_every + t_start 84 | production_t_start = (jnp.arange(n_production) * print_every 85 | + t_equilib + t_start) 86 | production_t_end = production_t_start + print_every 87 | timings = TimingClass(t_equilib_start=equilibration_t_start, 88 | t_production_start=production_t_start, 89 | t_production_end=production_t_end, 90 | timesteps_per_printout=timesteps_per_printout, 91 | time_step=time_step) 92 | return timings 93 | 94 | 95 | def _run_to_next_printout_neighbors(apply_fn, timings, **kwargs): 96 | """Initializes a function that runs simulation to next printout 97 | state and returns that state. 98 | 99 | Run simulation forward to each printout point and return state. 100 | Used to sample a specified number of states 101 | 102 | Args: 103 | apply_fn: Apply function from initialization of simulator 104 | neighbor_fn: Neighbor function 105 | timings: Instance of TimingClass containing information 106 | about which states to retain and simulation time 107 | kwargs: Kwargs to supply 'kT' and/or 'pressure' time-dependent 108 | functions to allow for non-equilibrium MD 109 | 110 | Returns: 111 | A function that takes the current simulation state, runs the 112 | simulation forward to the next printout state and returns it. 
113 | """ 114 | def do_step(cur_state, t): 115 | apply_kwargs = {} 116 | if 'kT' in kwargs: 117 | apply_kwargs['kT'] = kwargs['kT'](t) 118 | if 'pressure' in kwargs: 119 | apply_kwargs['pressure'] = kwargs['pressure'](t) 120 | 121 | state, nbrs = cur_state 122 | new_state = apply_fn(state, neighbor=nbrs, **apply_kwargs) 123 | nbrs = util.neighbor_update(nbrs, new_state) 124 | new_sim_state = (new_state, nbrs) 125 | return new_sim_state, t 126 | 127 | @jit 128 | def run_small_simulation(start_state, t_start=0.): 129 | simulation_time_points = jnp.arange(timings.timesteps_per_printout) \ 130 | * timings.time_step + t_start 131 | printout_state, _ = lax.scan(do_step, 132 | start_state, 133 | xs=simulation_time_points) 134 | cur_state, _ = printout_state 135 | return printout_state, cur_state 136 | return run_small_simulation 137 | 138 | 139 | def _canonicalize_dynamic_state_kwargs(state_kwargs, t_snapshots, *keys): 140 | """Converts constant state_kwargs, such as 'kT' and 'pressure' to constant 141 | functions over time and deletes None kwargs. Additionally, return the 142 | values of state_kwargs at production printout times. 
143 | """ 144 | def constant_fn(_, c): 145 | return c 146 | 147 | state_point_vals = [] 148 | for key in keys: 149 | if key in state_kwargs: 150 | if state_kwargs[key] is None: 151 | state_kwargs.pop(key) # ignore kwarg if None is provided 152 | state_points = None 153 | else: 154 | if jnp.isscalar(state_kwargs[key]): 155 | state_kwargs[key] = partial(constant_fn, 156 | c=state_kwargs[key]) 157 | state_points = vmap(state_kwargs[key])(t_snapshots) 158 | else: 159 | state_points = None 160 | state_point_vals.append(state_points) 161 | return state_kwargs, tuple(state_point_vals) 162 | 163 | 164 | def _traj_replicate_if_not_none(thermostat_values, n_traj): 165 | """Replicates thermostat targets to multiple trajectories, if not None.""" 166 | if thermostat_values is not None: 167 | thermostat_values = jnp.tile(thermostat_values, n_traj) 168 | return thermostat_values 169 | 170 | 171 | def trajectory_generator_init(simulator_template, energy_fn_template, 172 | ref_timings=None, quantities=None, vmap_batch=10): 173 | """Initializes a trajectory_generator function that computes a new 174 | trajectory stating at the last traj_state. 175 | 176 | Args: 177 | simulator_template: Function returning new simulator given 178 | current energy function 179 | energy_fn_template: Energy function template 180 | ref_timings: Instance of TimingClass containing information about the 181 | times states need to be retained 182 | quantities: Quantities dict to compute and store auxilary quantities 183 | alongside trajectory. This is particularly helpful for 184 | storing energy and pressure in a reweighting context. 185 | vmap_batch: Batch size for computation of auxillary quantities. 186 | 187 | Returns: 188 | A function taking energy params and the current traj_state (including 189 | neighbor list) that runs the simulation forward generating the 190 | next TrajectoryState. 
191 | """ 192 | if quantities is None: 193 | quantities = {} 194 | 195 | # temperature is inexpensive and generally useful: compute it by default 196 | quantities['kbT'] = custom_quantity.temperature 197 | 198 | def generate_reference_trajectory(params, sim_state, **kwargs): 199 | """ 200 | Returns a new TrajectoryState with auxilary variables. 201 | 202 | Args: 203 | params: Energy function parameters 204 | sim_state: Initial simulation state(s). Mulriple states can be 205 | provided to run multiple trajectories in parallel. 206 | **kwargs: Kwargs to supply 'kT' and/or 'pressure' to change these 207 | thermostat/barostat values on the fly. Can be constant 208 | or function of t. 209 | 210 | Returns: 211 | TrajectoryState object containing the newly generated trajectory 212 | """ 213 | timings = kwargs.pop('timings', ref_timings) 214 | assert timings is not None 215 | 216 | kwargs, (kbt, barostat_press) = _canonicalize_dynamic_state_kwargs( 217 | kwargs, timings.t_production_end, 'kT', 'pressure') 218 | 219 | energy_fn = energy_fn_template(params) 220 | _, apply_fn = simulator_template(energy_fn) 221 | run_to_printout = _run_to_next_printout_neighbors(apply_fn, timings, 222 | **kwargs) 223 | 224 | if sim_state[0].position.ndim > 2: 225 | def run_trajectory(state, starting_time): 226 | state, trajectory = lax.scan( 227 | run_to_printout, state, xs=starting_time) 228 | return state, trajectory 229 | 230 | if timings.t_equilib_start.size > 0: 231 | sim_state, _ = vmap(run_trajectory, (0, None))( # equilibration 232 | sim_state, timings.t_equilib_start) 233 | 234 | new_sim_state, traj = vmap(run_trajectory, (0, None))( # production 235 | sim_state, timings.t_production_start) 236 | 237 | # combine parallel trajectories to single large one for streamlined 238 | # postprocessing via traj_quantity, DiffTRe, relative entropy, etc. 
239 | traj = util.tree_combine(traj) 240 | overflow = jnp.any(new_sim_state[1].did_buffer_overflow) 241 | n_traj = sim_state[0].position.shape[0] 242 | kbt = _traj_replicate_if_not_none(kbt, n_traj) 243 | barostat_press = _traj_replicate_if_not_none(barostat_press, n_traj) 244 | 245 | else: 246 | if timings.t_equilib_start.size > 0: 247 | sim_state, _ = lax.scan( # equilibration 248 | run_to_printout, sim_state, xs=timings.t_equilib_start) 249 | 250 | new_sim_state, traj = lax.scan( # production 251 | run_to_printout, sim_state, xs=timings.t_production_start) 252 | overflow = new_sim_state[1].did_buffer_overflow 253 | 254 | traj_state = TrajectoryState(sim_state=new_sim_state, 255 | trajectory=traj, 256 | overflow=overflow, 257 | thermostat_kbt=kbt, 258 | barostat_press=barostat_press) 259 | 260 | aux_trajectory = quantity_traj(traj_state, quantities, params, 261 | vmap_batch) 262 | return traj_state.replace(aux=aux_trajectory) 263 | 264 | return generate_reference_trajectory 265 | 266 | 267 | def quantity_traj(traj_state, quantities, energy_params=None, batch_size=1): 268 | """Computes quantities of interest for all states in a trajectory. 269 | 270 | Arbitrary quantity functions can be provided via the quantities dict. 271 | The quantities dict should provide each quantity function via its own 272 | key that contains another dict containing the function under the 273 | 'compute_fn' key. The resulting quantity trajectory will be saved in 274 | a dict under the same key as the input quantity function. 275 | 276 | Args: 277 | traj_state: TrajectoryState as output from trajectory generator 278 | quantities: The quantity dict containing for each target quantity 279 | the snapshot compute function 280 | energy_params: Energy params for energy_fn_template to initialize 281 | the current energy_fn 282 | batch_size: Number of batches for vmap 283 | 284 | Returns: 285 | A dict of quantity trajectories saved under the same key as the 286 | input quantity function. 
287 | """ 288 | if traj_state.sim_state[0].position.ndim > 2: 289 | last_state, fixed_reference_nbrs = util.tree_get_single( 290 | traj_state.sim_state) 291 | else: 292 | last_state, fixed_reference_nbrs = traj_state.sim_state 293 | npt_ensemble = util.is_npt_ensemble(last_state) 294 | 295 | @jit 296 | def single_state_quantities(single_snapshot): 297 | state, kbt = single_snapshot 298 | nbrs = util.neighbor_update(fixed_reference_nbrs, state) 299 | kwargs = {'neighbor': nbrs, 'energy_params': energy_params, 'kT': kbt} 300 | if npt_ensemble: 301 | box = simulate.npt_box(state) 302 | kwargs['box'] = box 303 | 304 | computed_quantities = { 305 | quantity_fn_key: quantities[quantity_fn_key](state, **kwargs) 306 | for quantity_fn_key in quantities 307 | } 308 | return computed_quantities 309 | 310 | batched_traj = util.tree_vmap_split(traj_state.trajectory, batch_size) 311 | if traj_state.thermostat_kbt is not None: 312 | thermo_kbt = traj_state.thermostat_kbt.reshape((-1, batch_size)) 313 | else: 314 | thermo_kbt = traj_state.thermostat_kbt 315 | 316 | bachted_quantity_trajs = lax.map( 317 | vmap(single_state_quantities), (batched_traj, thermo_kbt) 318 | ) 319 | quantity_trajs = util.tree_combine(bachted_quantity_trajs) 320 | return quantity_trajs -------------------------------------------------------------------------------- /chemtrain/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions helpful in designing new trainers.""" 2 | import abc 3 | from functools import partial 4 | import pathlib 5 | from typing import Any 6 | 7 | import chex 8 | import cloudpickle as pickle 9 | from jax import tree_map, device_count, numpy as jnp 10 | from jax.tree_util import tree_flatten, tree_leaves 11 | from jax_md import simulate 12 | import numpy as onp 13 | 14 | 15 | # freezing seems to give slight performance improvement 16 | @partial(chex.dataclass, frozen=True) 17 | class TrainerState: 18 | """Each trainer at least 
contains the state of parameter and 19 | optimizer. 20 | """ 21 | params: Any 22 | opt_state: Any 23 | 24 | 25 | def _get_box_kwargs_if_npt(state): 26 | kwargs = {} 27 | if is_npt_ensemble(state): 28 | box = simulate.npt_box(state) 29 | kwargs['box'] = box 30 | return kwargs 31 | 32 | 33 | def neighbor_update(neighbors, state): 34 | """Update neighbor lists irrespective of the ensemble. 35 | 36 | Fetches the box to the neighbor list update function in case of the 37 | NPT ensemble. 38 | 39 | Args: 40 | neighbors: Neighbor list to be updated 41 | state: Simulation state 42 | 43 | Returns: 44 | Updated neighbor list 45 | """ 46 | kwargs = _get_box_kwargs_if_npt(state) 47 | nbrs = neighbors.update(state.position, **kwargs) 48 | return nbrs 49 | 50 | 51 | def neighbor_allocate(neighbor_fn, state, extra_capacity=0): 52 | """Re-allocates neighbor list irrespective of ensemble. Not jitable. 53 | 54 | Args: 55 | neighbor_fn: Neighbor function to re-allocate neighbor list 56 | state: Simulation state 57 | extra_capacity: Additional capacity of new neighbor list 58 | 59 | Returns: 60 | Updated neighbor list 61 | """ 62 | kwargs = _get_box_kwargs_if_npt(state) 63 | nbrs = neighbor_fn.allocate(state.position, extra_capacity, **kwargs) 64 | return nbrs 65 | 66 | 67 | def is_npt_ensemble(state): 68 | """Whether a state belongs to the NPT ensemble.""" 69 | return hasattr(state, 'box_position') 70 | 71 | 72 | def tree_combine(tree): 73 | """Combines the first two axes of `tree`, e.g. after batching.""" 74 | return tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), tree) 75 | 76 | 77 | def tree_norm(tree): 78 | """Returns the squared Euclidean norm (sum of squared leaf elements) of a PyTree.""" 79 | leaves, _ = tree_flatten(tree) 80 | return sum(jnp.vdot(x, x) for x in leaves) 81 | 82 | 83 | def tree_get_single(tree, n=0): 84 | """Returns the n-th tree of a tree-replica, e.g. from pmap. 85 | By default, the first tree is returned. 
86 | """ 87 | single_tree = tree_map(lambda x: jnp.array(x[n]), tree) 88 | return single_tree 89 | 90 | 91 | def tree_set(tree, new_data, end, start=0): 92 | """Overrides entries of a tree from index start:end along axis 0 93 | with new_data. 94 | """ 95 | return tree_map(lambda leaf, new_data_leaf: 96 | leaf.at[start:end, ...].set(new_data_leaf), tree, new_data) 97 | 98 | 99 | def tree_get_slice(tree, idx_start, idx_stop, take_every=1, to_device=True): 100 | """Returns a slice of trees taken from a tree-replica along axis 0.""" 101 | if to_device: 102 | return tree_map(lambda x: jnp.array(x[idx_start:idx_stop:take_every]), 103 | tree) 104 | else: 105 | return tree_map(lambda x: x[idx_start:idx_stop:take_every], tree) 106 | 107 | 108 | def tree_take(tree, indicies, axis=0, on_cpu=True): 109 | """Tree-wise application of numpy.take.""" 110 | numpy = onp if on_cpu else jnp 111 | return tree_map(lambda x: numpy.take(x, indicies, axis), tree) 112 | 113 | 114 | def tree_replicate(tree): 115 | """Replicates a pytree along the first axis for pmap.""" 116 | return tree_map(lambda x: jnp.array([x] * device_count()), tree) 117 | 118 | 119 | def tree_concat(tree): 120 | """For output computed in parallel via pmap, restacks all leaves such that 121 | the parallel dimension is again along axis 0 and the leading pmap dimension 122 | vanishes. 123 | """ 124 | return tree_map(partial(jnp.concatenate, axis=0), tree) 125 | 126 | 127 | def tree_pmap_split(tree, n_devices): 128 | """Splits the first axis of `tree` evenly across the number of devices for 129 | pmap batching (size of first axis is n_devices). 130 | """ 131 | assert tree_leaves(tree)[0].shape[0] % n_devices == 0, \ 132 | 'First dimension needs to be multiple of number of devices.' 
133 | return tree_map(lambda x: jnp.reshape(x, (n_devices, x.shape[0]//n_devices, 134 | *x.shape[1:])), tree) 135 | 136 | 137 | def tree_vmap_split(tree, batch_size): 138 | """Splits the first axis of a 'tree' with leaf sizes (N, X)`into 139 | (n_batches, batch_size, X) to allow straightforward vmapping over axis0. 140 | """ 141 | assert tree_leaves(tree)[0].shape[0] % batch_size == 0, \ 142 | 'First dimension of tree needs to be splittable by batch_size' \ 143 | ' without remainder.' 144 | return tree_map(lambda x: jnp.reshape(x, (x.shape[0] // batch_size, 145 | batch_size, *x.shape[1:])), 146 | tree) 147 | 148 | 149 | def tree_sum(tree_list, axis=None): 150 | """Computes the sum of equal-shaped leafs of a pytree.""" 151 | @partial(partial, tree_map) 152 | def leaf_add(leafs): 153 | return jnp.sum(leafs, axis=axis) 154 | return leaf_add(tree_list) 155 | 156 | 157 | def tree_mean(tree_list): 158 | """Computes the mean a list of equal-shaped pytrees.""" 159 | @partial(partial, tree_map) 160 | def tree_add_imp(*leafs): 161 | return jnp.mean(jnp.stack(leafs), axis=0) 162 | 163 | return tree_add_imp(*tree_list) 164 | 165 | 166 | def tree_multiplicity(tree): 167 | """Returns the number of stacked trees along axis 0.""" 168 | leaves, _ = tree_flatten(tree) 169 | return leaves[0].shape[0] 170 | 171 | 172 | def load_trainer(file_path): 173 | """Returns the trainer saved via 'trainer.save_trainer'. 174 | 175 | Args: 176 | file_path: Path of pickle file containing trainer. 177 | 178 | """ 179 | with open(file_path, 'rb') as pickle_file: 180 | trainer = pickle.load(pickle_file) 181 | trainer.move_to_device() 182 | return trainer 183 | 184 | 185 | def format_not_recognized_error(file_format): 186 | raise ValueError(f'File format {file_format} not recognized. ' 187 | f'Expected ".hdf5" or ".pkl".') 188 | 189 | 190 | class TrainerInterface(abc.ABC): 191 | """Abstract class defining the user interface of trainers as well as 192 | checkpointing functionality. 
class TrainerInterface(abc.ABC):
    """Abstract class defining the user interface of trainers as well as
    checkpointing functionality.
    """
    def __init__(self, checkpoint_path, reference_energy_fn_template=None):
        """A reference energy_fn_template can be provided, but is not mandatory
        due to the dependence of the template on the box via the displacement
        function.
        """
        self.checkpoint_path = checkpoint_path
        self._epoch = 0
        self.reference_energy_fn_template = reference_energy_fn_template

    @property
    def energy_fn(self):
        # The template closes over everything except the learned parameters.
        if self.reference_energy_fn_template is None:
            raise ValueError('Cannot construct energy_fn as no reference '
                             'energy_fn_template was provided during '
                             'initialization.')
        return self.reference_energy_fn_template(self.params)

    def _dump_checkpoint_occasionally(self, frequency=None):
        """Dumps a checkpoint during training, from which training can
        be resumed.
        """
        assert self.checkpoint_path is not None
        if frequency is None:
            return
        pathlib.Path(self.checkpoint_path).mkdir(parents=True, exist_ok=True)
        if self._epoch % frequency == 0:  # checkpoint model
            checkpoint_file = (self.checkpoint_path
                               + f'/epoch{self._epoch - 1}.pkl')
            self.save_trainer(checkpoint_file)

    def save_trainer(self, save_path):
        """Saves whole trainer, e.g. for production after training."""
        with open(save_path, 'wb') as pickle_file:
            pickle.dump(self, pickle_file)

    def save_energy_params(self, file_path, save_format='.hdf5'):
        # NOTE(review): '.hdf5' export is not implemented yet; only '.pkl'
        # is currently usable.
        if save_format == '.hdf5':
            raise NotImplementedError
        if save_format == '.pkl':
            with open(file_path, 'wb') as pickle_file:
                pickle.dump(self.params, pickle_file)
        else:
            format_not_recognized_error(save_format)

    def load_energy_params(self, file_path):
        if file_path.endswith('.pkl'):
            with open(file_path, 'rb') as pickle_file:
                params = pickle.load(pickle_file)
        elif file_path.endswith('.hdf5'):
            raise NotImplementedError
        else:
            format_not_recognized_error(file_path[-4:])
        self.params = tree_map(jnp.array, params)  # move state on device

    @property
    @abc.abstractmethod
    def params(self):
        """Short-cut for parameters. Depends on specific trainer."""

    @params.setter
    @abc.abstractmethod
    def params(self, loaded_params):
        raise NotImplementedError()

    @abc.abstractmethod
    def train(self, *args, **kwargs):
        """Training of any trainer should start by calling train."""

    @abc.abstractmethod
    def move_to_device(self):
        """Move all attributes that are expected to be on device to device to
        avoid TracerExceptions after loading trainers from disk, i.e.
        loading numpy rather than device arrays.
        """
--------------------------------------------------------------------------------
/examples/alanine_dipeptide/alanine_force_matching.py:
--------------------------------------------------------------------------------
"""Training a CG model for alanine dipeptide via force matching."""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import warnings
warnings.filterwarnings('ignore')  # disable warnings about float64 usage

import cloudpickle as pickle
from pathlib import Path

from jax import random
from jax_md import space
import matplotlib.pyplot as plt
import optax

from chemtrain import trainers, data_processing
from chemtrain.jax_md_mod import custom_space, io
from util import Initialization

Path('output/figures').mkdir(parents=True, exist_ok=True)
Path('output/force_matching').mkdir(parents=True, exist_ok=True)

# user input
configuration_str = 'data/dataset/confs_heavy_100ns.npy'
force_str = 'data/dataset/forces_heavy_100ns.npy'
file_topology = 'data/confs/heavy_2_7nm.gro'

training_name = 'FM_model_alanine'
save_plot = f'output/figures/FM_losses_alanine_{training_name}.png'
save_params_path = ('output/force_matching/'
                    f'trained_params_alanine_{training_name}.pkl')

used_dataset_size = 500000
train_ratio = 0.8
val_ratio = 0.08
batch_per_device = 500
batch_cache = 50

initial_lr = 0.001
epochs = 100
check_freq = 10

system_temperature = 300  # Kelvin
boltzmann_constant = 0.0083145107  # in kJ / mol K
kbt = system_temperature * boltzmann_constant

model = 'CGDimeNet'

key = random.PRNGKey(0)
model_init_key, shuffle_key = random.split(key, 2)

# build datasets
position_data = data_processing.get_dataset(configuration_str,
                                            retain=used_dataset_size)
force_data = data_processing.get_dataset(force_str,
                                         retain=used_dataset_size)

box, _, masses, _ = io.load_box(file_topology)
box_tensor, _ = custom_space.init_fractional_coordinates(box)
displacement, _ = space.periodic_general(box_tensor,
                                         fractional_coordinates=True)

position_data = data_processing.scale_dataset_fractional(position_data,
                                                         box_tensor)
r_init = position_data[0]

# total number of parameter updates, used as the decay horizon of the schedule
lrd = int(used_dataset_size / batch_per_device * epochs)
# negative sign: optax.scale_by_schedule multiplies the (ascent-direction)
# adam updates, so a negative learning rate yields gradient descent
lr_schedule = optax.exponential_decay(-initial_lr, lrd, 0.01)
optimizer = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(lr_schedule)
)

priors = ['bond', 'angle', 'dihedral']
species, prior_idxs, prior_constants = Initialization.select_protein(
    'heavy_alanine_dipeptide', priors)

energy_fn_template, _, init_params, nbrs_init = \
    Initialization.select_model(
        model, r_init, displacement, box, model_init_key, kbt, fractional=True,
        species=species, prior_constants=prior_constants, prior_idxs=prior_idxs)

trainer = trainers.ForceMatching(init_params, energy_fn_template, nbrs_init,
                                 optimizer, position_data,
                                 force_data=force_data,
                                 batch_per_device=batch_per_device,
                                 box_tensor=box_tensor,
                                 batch_cache=batch_cache,
                                 train_ratio=train_ratio,
                                 val_ratio=val_ratio)


trainer.train(epochs)

# persist the parameters with the best validation loss seen during training
best_params = trainer.best_params
with open(save_params_path, 'wb') as pickle_file:
    pickle.dump(best_params, pickle_file)

plt.figure()
plt.plot(trainer.train_losses, label='Train', color='#3C5488FF')
plt.plot(trainer.val_losses, label='Val', color='#00A087FF')
plt.legend()
plt.ylabel('MSE Loss')
plt.xlabel('Updates')
plt.savefig(save_plot)
--------------------------------------------------------------------------------
/examples/alanine_dipeptide/alanine_relative_entropy.py:
--------------------------------------------------------------------------------
"""Training a CG model of alanine dipeptide via relative entropy minimization.
"""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import warnings
warnings.filterwarnings('ignore')  # disable warnings about float64 usage

from pathlib import Path

from jax import random
import optax

from chemtrain import trainers, traj_util, data_processing
from chemtrain.jax_md_mod import io
from util import Initialization

Path('output/figures').mkdir(parents=True, exist_ok=True)
Path('output/rel_entropy').mkdir(parents=True, exist_ok=True)

# user input
configuration_str = 'data/dataset/confs_heavy_100ns.npy'
file_topology = 'data/confs/heavy_2_7nm.gro'

save_params_path = ('output/rel_entropy/'
                    'trained_params_alanine_RE_model_alanine.pkl')

used_dataset_size = 400000
n_trajectory = 50  # number of parallel trajectories
num_updates = 300

# simulation parameters
system_temperature = 300  # Kelvin
boltzmann_constant = 0.0083145107  # in kJ / mol K
kbt = system_temperature * boltzmann_constant

# time quantities presumably in ps (dt=0.002 matches the 2 fs naming used
# elsewhere in these examples) — TODO confirm against traj_util
time_step = 0.002
total_time = 1005
t_equilib = 5.
print_every = 0.2

model = 'CGDimeNet'

initial_lr = 0.003
# negative learning rate: scale_by_schedule multiplies the adam updates,
# turning the chained optimizer into gradient descent
lr_schedule = optax.exponential_decay(-initial_lr, num_updates, 0.01)
optimizer = optax.chain(
    optax.scale_by_adam(0.1, 0.4),
    optax.scale_by_schedule(lr_schedule)
)

timings = traj_util.process_printouts(time_step, total_time, t_equilib,
                                      print_every)

# initial configuration
box, _, masses, _ = io.load_box(file_topology)

priors = ['bond', 'angle', 'dihedral']
species, prior_idxs, prior_constants = Initialization.select_protein(
    'heavy_alanine_dipeptide', priors)

position_data = data_processing.get_dataset(configuration_str,
                                            retain=used_dataset_size)

# Random starting configurations
key = random.PRNGKey(0)
r_init = random.choice(key, position_data, (n_trajectory,), replace=False)

simulation_data = Initialization.InitializationClass(
    r_init=r_init, box=box, kbt=kbt, masses=masses, dt=time_step,
    species=species)

init_sim_states, init_params, simulation_fns, _, _ = \
    Initialization.initialize_simulation(simulation_data,
                                         model,
                                         integrator='Langevin',
                                         prior_constants=prior_constants,
                                         prior_idxs=prior_idxs)

simulator_template, energy_fn_template, neighbor_fn = simulation_fns

reference_data = data_processing.scale_dataset_fractional(position_data, box)

# a reweight_ratio > 1 disables reweighting
trainer = trainers.RelativeEntropy(init_params, optimizer, reweight_ratio=1.1,
                                   energy_fn_template=energy_fn_template)

trainer.add_statepoint(
    reference_data, energy_fn_template, simulator_template, neighbor_fn,
    timings, kbt, init_sim_states, reference_batch_size=used_dataset_size,
    vmap_batch=n_trajectory)

trainer.train(num_updates)
trainer.save_energy_params(save_params_path, '.pkl')
--------------------------------------------------------------------------------
/examples/alanine_dipeptide/alanine_simulation.py:
--------------------------------------------------------------------------------
"""Forward simulation of alanine dipeptide."""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import warnings
warnings.filterwarnings('ignore')  # disable warnings about float64 usage

import cloudpickle as pickle
from pathlib import Path
import time

from jax import vmap, random, tree_util, numpy as jnp
from jax_md import space
import numpy as onp

from chemtrain.jax_md_mod import io, custom_space, custom_quantity
from chemtrain import data_processing, traj_util
from util import Initialization
import visualization

Path('output/trajectories').mkdir(parents=True, exist_ok=True)

# user input
save_name = 'FM_2fs_100ns'
folder_name = 'FM_2fs_100ns/'
labels = ['Reference', 'Predicted']

file_topology = 'data/confs/heavy_2_7nm.gro'
configuration_str = 'data/dataset/confs_heavy_100ns.npy'
used_dataset_size = 500000
n_trajectory = 50

model = 'CGDimeNet'


# switch to the rel_entropy path below to simulate the RE-trained model
saved_params_path = ('output/force_matching/'
                     'trained_params_alanine_FM_model_alanine.pkl')
# saved_params_path = ('output/rel_entropy/'
#                      'trained_params_alanine_RE_model_alanine.pkl')

# simulation parameters
system_temperature = 300  # Kelvin
boltzmann_constant = 0.0083145107  # in kJ / mol K
kbt = system_temperature * boltzmann_constant
time_step = 0.002

total_time = 110000
t_equilib = 10000.
print_every = 0.2
###############################

box, _, masses, _ = io.load_box(file_topology)
priors = ['bond', 'angle', 'dihedral']
species, prior_idxs, prior_constants = Initialization.select_protein(
    'heavy_alanine_dipeptide', priors)

# random starting points (first frame excluded)
position_data = data_processing.get_dataset(configuration_str)[1:]
key = random.PRNGKey(0)
r_init = random.choice(key, position_data, (n_trajectory,), replace=False)

simulation_data = Initialization.InitializationClass(
    r_init=r_init, box=box, kbt=kbt, masses=masses, dt=time_step,
    species=species)
timings = traj_util.process_printouts(time_step, total_time, t_equilib,
                                      print_every)

# compute_fns and targets are unused in this forward-simulation script
reference_state, energy_params, simulation_fns, compute_fns, targets = \
    Initialization.initialize_simulation(simulation_data,
                                         model,
                                         integrator='Langevin',
                                         prior_constants=prior_constants,
                                         prior_idxs=prior_idxs)

simulator_template, energy_fn_template, neighbor_fn = simulation_fns

if saved_params_path is not None:
    print('Using saved params')
    with open(saved_params_path, 'rb') as pickle_file:
        params = pickle.load(pickle_file)
    # move loaded numpy parameters onto the device
    energy_params = tree_util.tree_map(jnp.array, params)

trajectory_generator = traj_util.trajectory_generator_init(simulator_template,
                                                           energy_fn_template,
                                                           timings)

t_start = time.time()
traj_state = trajectory_generator(energy_params, reference_state)
t_end = time.time() - t_start
print('total runtime in min:', t_end / 60.)

assert not traj_state.overflow, ('Neighborlist overflow during trajectory '
                                 'generation. Increase capacity and re-run.')

# postprocessing
traj_positions = traj_state.trajectory.position
jnp.save(f'output/trajectories/confs_alanine_{save_name}', traj_positions)

box_tensor, _ = custom_space.init_fractional_coordinates(box)
displacement, _ = space.periodic_general(box_tensor,
                                         fractional_coordinates=True)
position_data = data_processing.scale_dataset_fractional(position_data, box)

# dihedrals
dihedral_idxs = jnp.array([[1, 3, 4, 6], [3, 4, 6, 8]])  # 0: phi 1: psi
batched_dihedrals = vmap(custom_quantity.dihedral_displacement, (0, None, None))

dihedrals_ref = batched_dihedrals(position_data, displacement, dihedral_idxs)
dihedral_angles = batched_dihedrals(traj_positions, displacement, dihedral_idxs)

phi = dihedral_angles[:, 0].reshape((n_trajectory, -1))
psi = dihedral_angles[:, 1].reshape((n_trajectory, -1))

# mean squared error (in rad)
nbins = 60
dihedrals_ref_rad = jnp.deg2rad(dihedrals_ref)
dihedral_angles_rad = jnp.deg2rad(dihedral_angles)
h_ref, y1, y2 = onp.histogram2d(
    dihedrals_ref_rad[:, 0], dihedrals_ref_rad[:, 1], bins=nbins, density=True)
h_pred, _, _ = onp.histogram2d(
    dihedral_angles_rad[:, 0], dihedral_angles_rad[:, 1], bins=nbins,
    density=True)

mse = onp.mean((h_ref - h_pred)**2)
print('MSE of the phi-psi dihedral density histrogram: ', mse)

# unstack parallel trajectories
dihedral_angles_split = dihedral_angles.reshape((n_trajectory, -1, 2))

# Plots
phi_angles_ref = onp.load('data/dataset/phi_angles_r100ns.npy')
psi_angles_ref = onp.load('data/dataset/psi_angles_r100ns.npy')

# dihedral histograms
Path(f'output/postprocessing/{folder_name}').mkdir(parents=True, exist_ok=True)


visualization.plot_histogram_density(dihedral_angles_split[0, :],
                                     save_name + '_first_predicted_',
                                     folder=folder_name)
visualization.plot_histogram_density(dihedrals_ref, save_name + '_REF',
                                     folder=folder_name)

visualization.plot_1d_dihedral(
    [phi_angles_ref, phi], 'phi_' + save_name, labels=labels[0:2],
    folder=folder_name)
visualization.plot_1d_dihedral(
    [psi_angles_ref, psi], 'psi_' + save_name, location='upper left',
    labels=labels[0:2], xlabel='$\psi$ in deg', folder=folder_name)

visualization.plot_histogram_free_energy(dihedral_angles_split[0, :],
                                         save_name + '_pred',
                                         kbt, folder=folder_name)
visualization.plot_histogram_free_energy(dihedrals_ref, save_name + '_REF',
                                         kbt, folder=folder_name)

visualization.plot_compare_histogram_free_energy(
    [dihedrals_ref, dihedral_angles_split[0, :]], save_name, kbt, titles=labels,
    folder=folder_name)

visualization.plot_compare_histogram_density(
    [dihedrals_ref, dihedral_angles_split[0, :]], save_name, titles=labels,
    folder=folder_name)
--------------------------------------------------------------------------------
/examples/alanine_dipeptide/data/confs/heavy_2_7nm.gro:
--------------------------------------------------------------------------------
Protein in water
10
1ACE CH3 1 1.186 1.717 1.471 -0.9155 0.3007 0.6508
1ACE C 2 1.282 1.595 1.497 0.4413 0.1841 0.3569
1ACE O 3 1.321 1.582 1.611 0.0613 0.0256 -0.5025
2ALA N 4 1.314 1.512 1.396 -0.1254 0.1223 -0.4071
2ALA CA 5 1.394 1.389 1.414 0.8537 -0.1251 -1.0147
2ALA CB 6 1.300 1.278 1.473 0.5590 -0.2504 0.8860
2ALA C 7 1.459 1.339 1.288 0.0259 -0.6055 -0.2998
2ALA O 8 1.404 1.360 1.178 0.9320 -0.4149 -0.3561
3NME N 9 1.573 1.280 1.306 -0.1968 -0.0872 0.2543
3NME CH3 10 1.650 1.232 1.190 -0.3199 0.2708 -0.1091
2.71381 2.71381 2.71381
--------------------------------------------------------------------------------
/examples/alanine_dipeptide/data/dataset/phi_angles_r100ns.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/dataset/phi_angles_r100ns.npy -------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/dataset/psi_angles_r100ns.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/dataset/psi_angles_r100ns.npy -------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_dihedral_constant.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_dihedral_constant.npy -------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_dihedral_multiplicity.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_dihedral_multiplicity.npy -------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_dihedral_phase.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_dihedral_phase.npy 
-------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_epsilon.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_epsilon.npy -------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_eq_angle.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_eq_angle.npy -------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_eq_angle_variance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_eq_angle_variance.npy -------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_eq_bond_length.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_eq_bond_length.npy -------------------------------------------------------------------------------- /examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_eq_bond_variance.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_eq_bond_variance.npy
--------------------------------------------------------------------------------
/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_sigma.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/alanine_dipeptide/data/prior/Alanine_dipeptide_heavy_sigma.npy
--------------------------------------------------------------------------------
/examples/alanine_dipeptide/visualization.py:
--------------------------------------------------------------------------------
"""Plot functions to visualize free energy surface of alanine dipeptide."""
from jax import numpy as jnp
from matplotlib.animation import FuncAnimation
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import numpy as onp
from scipy.interpolate import interp1d


def plot_scatter_forces(predicted, reference, save_as, model, line=2000):
    """Scatter plot of predicted and reference forces."""
    plt.figure()
    hex_bin = plt.hexbin(predicted, reference, gridsize=50, mincnt=1,
                         vmax=200000)
    # diagonal: perfect agreement between prediction and reference
    plt.plot([-line, line], [-line, line], 'r--')
    plt.ylabel('Reference Force Components [kJ $\mathrm{mol^{-1} \ nm^{-1}}$]')
    plt.xlabel('Predicted Force Components [kJ $\mathrm{mol^{-1} \ nm^{-1}}$]')
    plt.title(model)
    # plt.xlim([-line,line])
    # plt.ylim([-line,line])
    plt.tight_layout()
    cbar = plt.colorbar(hex_bin)
    # relabel colorbar ticks in thousands of data points
    cbar.ax.set_yticklabels(['{:,.0f}K'.format(i/1000)
                             for i in cbar.get_ticks()])
    cbar.set_label('Number of data points')
    plt.savefig('output/postprocessing/force_scatter_' + save_as + '.png')


def dihedral_map():
    # 19-step grayscale colormap from light gray to black
    mymap = onp.array([[0.9, 0.9, 0.9],
                       [0.85, 0.85, 0.85],
                       [0.8, 0.8, 0.8],
                       [0.75, 0.75, 0.75],
                       [0.7, 0.7, 0.7],
                       [0.65, 0.65, 0.65],
                       [0.6, 0.6, 0.6],
                       [0.55, 0.55, 0.55],
                       [0.5, 0.5, 0.5],
                       [0.45, 0.45, 0.45],
                       [0.4, 0.4, 0.4],
                       [0.35, 0.35, 0.35],
                       [0.3, 0.3, 0.3],
                       [0.25, 0.25, 0.25],
                       [0.2, 0.2, 0.2],
                       [0.15, 0.15, 0.15],
                       [0.1, 0.1, 0.1],
                       [0.05, 0.05, 0.05],
                       [0, 0, 0]])
    newcmp = clr.ListedColormap(mymap)
    return newcmp


def annotate_alanine_histrogram(axis=None):
    # Labels the metastable states (C5, C7eq, alpha_R, alpha_L) and draws
    # their boundaries on a phi-psi Ramachandran plot, either on the current
    # pyplot figure or on the given axis.
    if axis is None:
        target = plt
        target.xlabel('$\phi$ in $\mathrm{deg}$')
        target.ylabel('$\psi$ in $\mathrm{deg}$')
        target.xlim([-180, 180])
        target.ylim([-180, 180])
    else:
        target = axis
        target.set_xlabel('$\phi$ in $\mathrm{deg}$')
        target.set_ylabel('$\psi$ in $\mathrm{deg}$')
        target.set_xlim([-180, 180])
        target.set_ylim([-180, 180])

    target.text(-155, 90, '$C5$', fontsize=18)
    target.text(-70, 90, '$C7eq$', fontsize=18)
    target.text(145, 90, '$C5$', fontsize=18)
    target.text(-155, -150, '$C5$', fontsize=18)
    target.text(-70, -150, '$C7eq$', fontsize=18)
    target.text(145, -150, '$C5$', fontsize=18)
    target.text(-170, -90, r'$\alpha_R$"', fontsize=18)
    target.text(140, -90, r'$\alpha_R$"', fontsize=18)
    target.text(-70, -90, r'$\alpha_R$', fontsize=18)
    target.text(70, 0, r'$\alpha_L$', fontsize=18)
    target.plot([-180, 13], [74, 74], 'k', linewidth=0.5)
    target.plot([128, 180], [74, 74], 'k', linewidth=0.5)
    target.plot([13, 13], [-180, 180], 'k', linewidth=0.5)
    target.plot([128, 128], [-180, 180], 'k', linewidth=0.5)
    target.plot([-180, 13], [-125, -125], 'k', linewidth=0.5)
    target.plot([128, 180], [-125, -125], 'k', linewidth=0.5)
    target.plot([-134, -134], [-125, 74], 'k', linewidth=0.5)
    target.plot([-110, -110], [-180, -125], 'k', linewidth=0.5)
    target.plot([-110, -110], [74, 180], 'k', linewidth=0.5)
86 | 87 | 88 | def plot_histogram_density(angles, saveas, folder=''): 89 | """Plot 2D histogram for alanine from the dihedral angles.""" 90 | newcmp = dihedral_map() 91 | h, x_edges, y_edges = jnp.histogram2d(angles[:, 0], angles[:, 1], 92 | bins=60, density=True) 93 | 94 | h_masked = onp.where(h == 0, onp.nan, h) 95 | x, y = onp.meshgrid(x_edges, y_edges) 96 | 97 | plt.figure() 98 | # vmin=1, vmax=5.25 99 | plt.pcolormesh(x, y, h_masked.T, cmap=newcmp, vmax=0.000225) 100 | # axs.xaxis.set_major_formatter(tck.FormatStrFormatter('%g $\pi$')) 101 | # axs.xaxis.set_major_locator(tck.MultipleLocator(base=1.0)) 102 | cbar = plt.colorbar() 103 | cbar.formatter.set_powerlimits((0, 0)) 104 | cbar.formatter.set_useMathText(True) 105 | cbar.set_label('Density') 106 | annotate_alanine_histrogram() 107 | plt.savefig(f'output/postprocessing/{folder}histogram_density_{saveas}.png') 108 | plt.close('all') 109 | 110 | 111 | def plot_compare_histogram_density(list_angles, saveas, titles=None, folder=''): 112 | """Plot 2D density histogram for alanine from the dihedral angles.""" 113 | newcmp = dihedral_map() 114 | 115 | n_plots = len(list_angles) 116 | fig, axs = plt.subplots(ncols=n_plots, figsize=(6.4 * n_plots, 4.8), 117 | constrained_layout=True) 118 | 119 | images = [] 120 | for i in range(n_plots): 121 | h, x_edges, y_edges = jnp.histogram2d( 122 | list_angles[i][:, 0], list_angles[i][:, 1], bins=60, density=True) 123 | h_masked = onp.where(h == 0, onp.nan, h) 124 | x, y = onp.meshgrid(x_edges, y_edges) 125 | images.append(axs[i].pcolormesh(x,y,h_masked.T, cmap=newcmp)) 126 | if titles: 127 | axs[i].set_title(titles[i]) 128 | annotate_alanine_histrogram(axs[i]) 129 | 130 | # Find the min and max of all colors for use in setting the color scale. 
131 | vmin = min(image.get_array().min() for image in images) 132 | vmax = max(image.get_array().max() for image in images) 133 | norm = clr.Normalize(vmin=vmin, vmax=vmax) 134 | for im in images: 135 | im.set_norm(norm) 136 | 137 | cbar = fig.colorbar(images[0], ax=axs) 138 | cbar.formatter.set_powerlimits((0, 0)) 139 | cbar.formatter.set_useMathText(True) 140 | cbar.set_label('Density') 141 | 142 | plt.savefig(f'output/postprocessing/{folder}histogram_compare_' 143 | f'density_{saveas}.png') 144 | plt.close('all') 145 | 146 | 147 | def visualize_dihedral_series(angles, saveas, dihedral_ref=None): 148 | """Visualize the initial points of parallel trajectories 149 | during training. Option to add a reference histogram as a background.""" 150 | fig, ax = plt.subplots() 151 | if dihedral_ref is not None: 152 | newcmp = dihedral_map() 153 | plt.hist2d(dihedral_ref[:, 0], dihedral_ref[:, 1], bins=60, cmap=newcmp, 154 | cmin=1) 155 | plt.xlim(-180, 180) 156 | plt.ylim(-180, 180) 157 | plt.xlabel('$\phi$ in $\mathrm{deg}$') 158 | plt.ylabel('$\psi$ in $\mathrm{deg}$') 159 | #Plot initial points 160 | plt.scatter(angles[0, :, 0], angles[0, :, 1], color='tab:orange') 161 | graph = plt.scatter([], [], color='tab:blue') 162 | 163 | def update(i): 164 | if i % 10 == 0: 165 | label = f'Epoch {i}' 166 | print(label) 167 | # Plot next points dynamically 168 | graph.set_offsets(angles[1:i+1].reshape((i * angles.shape[1], 2))) 169 | ax.set_title('Epoche ' + str(i+1)) 170 | return graph 171 | 172 | anim = FuncAnimation(fig, update, frames=angles.shape[0] - 1, interval=200, 173 | repeat=False) 174 | anim.save(f'output/postprocessing/dihedral_series_{saveas}.gif', dpi=80, 175 | writer='imagemagick', extra_args=['-loop', '1']) 176 | 177 | 178 | def plot_dihedral_diff(angles, saveas, dihedral_ref): 179 | """Plots the difference between the reference dihedrals and the predicted 180 | angles.""" 181 | h_ref, x_ref, y_ref = jnp.histogram2d( 182 | dihedral_ref[:, 0], dihedral_ref[:, 1], 
bins=60, density=True) 183 | h_predicted, _, _ = jnp.histogram2d(angles[:, 0], angles[:, 1], 184 | bins=60, density=True) 185 | 186 | x, y = onp.meshgrid(x_ref, y_ref) 187 | h_diff = h_ref - h_predicted 188 | h_max = jnp.max(h_diff) 189 | vmin, vmax = -h_max, h_max 190 | norm = clr.Normalize(vmin, vmax) 191 | cmap = plt.get_cmap('PiYG') 192 | 193 | plt.figure() 194 | plt.pcolormesh(x, y, h_diff.T, cmap=cmap, norm=norm) 195 | plt.colorbar() 196 | plt.savefig(f'output/postprocessing/diff_dihedral_{saveas}.png') 197 | 198 | 199 | def plot_histogram_free_energy(angles, saveas, kbt, degrees=True, folder=''): 200 | """Plot 2D free energy histogram for alanine from the dihedral angles.""" 201 | cmap = plt.get_cmap('magma') 202 | 203 | if degrees: 204 | angles = jnp.deg2rad(angles) 205 | 206 | h, x_edges, y_edges = jnp.histogram2d(angles[:, 0], angles[:, 1], 207 | bins=60, density=True) 208 | 209 | h = jnp.log(h) * -(kbt / 4.184) 210 | x, y = onp.meshgrid(x_edges, y_edges) 211 | 212 | plt.figure() 213 | plt.pcolormesh(x, y, h.T, cmap=cmap, vmax=5.25) 214 | cbar = plt.colorbar() 215 | cbar.set_label('Free Energy (kcal/mol)') 216 | plt.xlabel('$\phi$ in rad') 217 | plt.ylabel('$\psi$ in rad') 218 | plt.savefig(f'output/postprocessing/{folder}histogram_free' 219 | f'_energy_{saveas}.png') 220 | plt.close('all') 221 | 222 | 223 | def plot_compare_histogram_free_energy(list_angles, saveas, kbt, degrees=True, 224 | titles=None, folder=''): 225 | """Plot 2D free energy histogram for alanine from the dihedral angles. 
226 | Comparison of multiple curves.""" 227 | cmap = plt.get_cmap('magma') 228 | 229 | n_plots = len(list_angles) 230 | fig, axs = plt.subplots(ncols=n_plots, figsize=(6.4 * n_plots, 4.8), 231 | constrained_layout=True) 232 | 233 | images = [] 234 | for i in range(n_plots): 235 | if degrees: 236 | angles = jnp.deg2rad(list_angles[i]) 237 | else: 238 | angles = list_angles[i] 239 | h, x_edges, y_edges = jnp.histogram2d(angles[:, 0], angles[:, 1], 240 | bins=60, density=True) 241 | h_masked = jnp.log(h) * -kbt / 4.184 242 | x, y = onp.meshgrid(x_edges, y_edges) 243 | images.append(axs[i].pcolormesh(x, y, h_masked.T, cmap=cmap)) 244 | axs[i].set_xlabel('$\phi$ in rad') 245 | axs[i].set_ylabel('$\psi$ in rad') 246 | if titles: 247 | axs[i].set_title(titles[i]) 248 | 249 | # Find the min and max of all colors for use in setting the color scale. 250 | vmin = min(image.get_array().min() for image in images) 251 | vmax = max(image.get_array().max() for image in images) 252 | norm = clr.Normalize(vmin=vmin, vmax=vmax) 253 | for im in images: 254 | im.set_norm(norm) 255 | 256 | cbar = fig.colorbar(images[0], ax=axs) 257 | cbar.set_label('Free Energy (kcal/mol)') 258 | 259 | plt.savefig(f'output/postprocessing/{folder}histogram_compare' 260 | f'_free_energy_{saveas}.png') 261 | plt.close('all') 262 | 263 | 264 | def _spline_free_energy(angles, bins, kbt): 265 | h, x_bins = jnp.histogram(angles, bins=bins, density=True) 266 | h = jnp.log(h) * -kbt / 4.184 267 | h_spline = onp.where(h == jnp.inf, 15, h) 268 | width = x_bins[1] - x_bins[0] 269 | bin_center = x_bins + width / 2 270 | cubic_interploation_model = interp1d(bin_center[:-1], h_spline, 271 | kind='cubic') 272 | x_ = onp.linspace(bin_center[:-1].min(), bin_center[:-1].max(), 40) 273 | y_ = cubic_interploation_model(x_) 274 | return x_, y_ 275 | 276 | 277 | def plot_compare_1d_free_energy(angles, reference, saveas, labels, kbt, bins=60, 278 | degrees=True, xlabel='$\phi$ in rad', 279 | folder=''): 280 | """Plot and save 
def plot_1d_dihedral(angles, saveas, labels, bins=60, degrees=True,
                     location=None, xlabel=r'$\phi$ in deg', folder='',
                     color=None, line=None):
    """Plot 1D dihedral histograms with mean and standard deviation
    over trajectories for several models, and save the figure.

    Args:
        angles: List over models of [Ntrajectory, Nangles] arrays, or an
            array of shape [Nmodels, Ntrajectory, Nangles].
        saveas: Suffix of the saved figure file name.
        labels: Legend label per model.
        bins: Number of histogram bins.
        degrees: If True, angles are assumed to already be in degrees;
            otherwise they are converted from rad to degrees.
        location: Optional matplotlib legend location.
        xlabel: X-axis label; default is for phi angles. Use
            r'$\\psi$ in deg' for psi angles.
        folder: Subfolder of 'output/postprocessing/' to save into.
        color: Line colors per model; defaults to 3 preset colors.
        line: Line styles per model; defaults to ['--', '-', '-'].
    """
    plt.figure()
    if color is None:
        color = ['k', '#00A087FF', '#3C5488FF']
    if line is None:
        line = ['--', '-', '-']
    n_models = len(angles)
    for i in range(n_models):
        # cycle through the style lists so that more models than preset
        # colors/styles no longer raises an IndexError
        model_color = color[i % len(color)]
        model_line = line[i % len(line)]
        angles_conv = angles[i] if degrees else onp.rad2deg(angles[i])
        n_traj = angles_conv.shape[0]
        h_temp = onp.zeros((n_traj, bins))
        for j in range(n_traj):
            h, x_bins = jnp.histogram(angles_conv[j, :], bins=bins,
                                      density=True)
            width = x_bins[1] - x_bins[0]
            # NOTE(review): bin centers of the most recently histogrammed
            # trajectory are used for plotting all of model i's data;
            # assumes all trajectories span the same angle range — confirm
            bin_center = x_bins + width / 2
            h_temp[j] = h
        h_mean = jnp.mean(h_temp, axis=0)
        h_std = jnp.std(h_temp, axis=0)
        plt.plot(bin_center[:-1], h_mean, label=labels[i], color=model_color,
                 linestyle=model_line, linewidth=2.0, zorder=n_models-i)
        plt.fill_between(bin_center[:-1], h_mean-h_std, h_mean+h_std,
                         color=model_color, alpha=0.4, zorder=n_models-i)
    plt.xlabel(xlabel)
    plt.ylabel('Density')
    if location is not None:
        plt.legend(loc=location)
    else:
        plt.legend()
    plt.savefig(f'output/postprocessing/{folder}dihedral_1D_{saveas}.png')
    plt.close('all')
Path('output/figures').mkdir(parents=True, exist_ok=True) 22 | Path('output/force_matching').mkdir(parents=True, exist_ok=True) 23 | 24 | # user input 25 | configuration_str = 'data/dataset/conf_COM_10k.npy' 26 | force_str = 'data/dataset/forces_COM_10k.npy' 27 | box_str = 'data/dataset/box.npy' 28 | 29 | training_name = 'FM_model_water' 30 | save_plot = f'output/figures/force_matching_losses_{training_name}.png' 31 | save_params_path = ('output/force_matching/' 32 | f'trained_params_water_{training_name}.pkl') 33 | 34 | used_dataset_size = 10000 35 | train_ratio = 0.8 36 | val_ratio = 0.08 37 | batch_per_device = 10 38 | batch_cache = 10 39 | 40 | epochs = 100 41 | 42 | box_length = onp.load(box_str) 43 | box = jnp.ones(3) * box_length 44 | 45 | model = 'CGDimeNet' 46 | # model = 'Tabulated' 47 | 48 | # build datasets 49 | position_data = data_processing.get_dataset(configuration_str, 50 | retain=used_dataset_size) 51 | force_data = data_processing.get_dataset(force_str, retain=used_dataset_size) 52 | 53 | dataset_size = position_data.shape[0] 54 | print('Dataset size:', dataset_size) 55 | 56 | if model == 'Tabulated': 57 | initial_lr = 0.1 58 | elif model == 'CGDimeNet': 59 | initial_lr = 0.001 60 | else: 61 | raise NotImplementedError 62 | 63 | decay_length = int(dataset_size / batch_per_device * epochs) 64 | lr_schedule = optax.exponential_decay(-initial_lr, decay_length, 0.01) 65 | optimizer = optax.chain( 66 | optax.scale_by_adam(), 67 | optax.scale_by_schedule(lr_schedule) 68 | ) 69 | 70 | box_tensor, _ = custom_space.init_fractional_coordinates(box) 71 | displacement, _ = space.periodic_general(box_tensor, 72 | fractional_coordinates=True) 73 | position_data = data_processing.scale_dataset_fractional(position_data, 74 | box_tensor) 75 | R_init = position_data[0] 76 | 77 | model_init_key = random.PRNGKey(0) 78 | constants = {'repulsive': (0.3165, 1., 0.5, 12)} 79 | idxs = {} 80 | energy_fn_template, _, init_params, nbrs_init = \ 81 | 
Initialization.select_model(model, R_init, displacement, box, 82 | model_init_key, fractional=True, 83 | prior_constants=constants, prior_idxs=idxs) 84 | 85 | trainer = trainers.ForceMatching(init_params, energy_fn_template, nbrs_init, 86 | optimizer, position_data, 87 | force_data=force_data, 88 | batch_per_device=batch_per_device, 89 | box_tensor=box_tensor, 90 | batch_cache=batch_cache, 91 | train_ratio=train_ratio, 92 | val_ratio=val_ratio) 93 | 94 | trainer.train(epochs) 95 | 96 | with open(save_params_path, 'wb') as pickle_file: 97 | pickle.dump(trainer.best_params, pickle_file) 98 | 99 | plt.figure() 100 | plt.plot(trainer.train_losses, label='Train', color='#3C5488FF') 101 | plt.plot(trainer.val_losses, label='Val', color='#00A087FF') 102 | plt.legend() 103 | plt.ylabel('MSE Loss') 104 | plt.xlabel('Update step') 105 | plt.savefig(save_plot) 106 | -------------------------------------------------------------------------------- /examples/water/CG_water_relative_entropy.py: -------------------------------------------------------------------------------- 1 | """Training a CG water model via relative entropy minimization.""" 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 4 | 5 | import warnings 6 | warnings.filterwarnings('ignore') # disable warnings about float64 usage 7 | 8 | from pathlib import Path 9 | 10 | from jax import numpy as jnp 11 | import numpy as onp 12 | import optax 13 | 14 | from chemtrain import trainers, traj_util, data_processing 15 | from util import Initialization 16 | 17 | Path('output/figures').mkdir(parents=True, exist_ok=True) 18 | Path('output/rel_entropy').mkdir(parents=True, exist_ok=True) 19 | 20 | # dataset 21 | configuration_str = 'data/dataset/conf_COM_10k.npy' 22 | box_str = 'data/dataset/box.npy' 23 | 24 | save_param_path = 'output/rel_entropy/trained_params_water_RE_model_water.pkl' 25 | used_dataset_size = 8000 26 | num_updates = 300 27 | 28 | # simulation parameters 29 | system_temperature = 298 # Kelvin 30 | 
boltzmann_constant = 0.0083145107 # in kJ / mol K 31 | kbt = system_temperature * boltzmann_constant 32 | mass = 18.0154 33 | rand_seed = 0 34 | 35 | time_step = 0.002 36 | total_time = 75. 37 | t_equilib = 5. 38 | print_every = 0.1 39 | 40 | model = 'CGDimeNet' 41 | # model = 'Tabulated' 42 | 43 | if model == 'Tabulated': 44 | initial_lr = 0.1 45 | elif model == 'CGDimeNet': 46 | initial_lr = 0.003 47 | else: 48 | raise NotImplementedError 49 | 50 | lr_schedule = optax.exponential_decay(-initial_lr, num_updates, 0.01) 51 | optimizer = optax.chain( 52 | optax.scale_by_adam(0.1, 0.4), 53 | optax.scale_by_schedule(lr_schedule) 54 | ) 55 | 56 | timings = traj_util.process_printouts(time_step, total_time, t_equilib, 57 | print_every) 58 | 59 | box_length = onp.load(box_str) 60 | box = jnp.ones(3) * box_length 61 | 62 | position_data = data_processing.get_dataset(configuration_str, 63 | retain=used_dataset_size) 64 | r_init = position_data[0] 65 | 66 | constants = {'repulsive': (0.3165, 1., 0.5, 12)} 67 | idxs = {} 68 | 69 | simulation_data = Initialization.InitializationClass( 70 | r_init=r_init, box=box, kbt=kbt, masses=mass, dt=time_step) 71 | reference_state, init_params, simulation_fns, _, _ = \ 72 | Initialization.initialize_simulation( 73 | simulation_data, model, key_init=rand_seed, prior_constants=constants, 74 | prior_idxs=idxs) 75 | simulator_template, energy_fn_template, neighbor_fn = simulation_fns 76 | 77 | reference_data = data_processing.scale_dataset_fractional(position_data, box) 78 | trainer = trainers.RelativeEntropy( 79 | init_params, optimizer, energy_fn_template=energy_fn_template, 80 | reweight_ratio=1.1) 81 | 82 | trainer.add_statepoint(reference_data, energy_fn_template, simulator_template, 83 | neighbor_fn, timings, kbt, reference_state) 84 | 85 | trainer.train(num_updates) 86 | 87 | # save parameters 88 | trainer.save_energy_params(save_param_path, '.pkl') 89 | -------------------------------------------------------------------------------- 
/examples/water/CG_water_simulation.py: -------------------------------------------------------------------------------- 1 | """Runs a CG water simulation in Jax M.D with loaded parameters. 2 | Trajectory generation for postprocessing and analysis of simulations. 3 | """ 4 | import os 5 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 6 | 7 | import warnings 8 | warnings.filterwarnings('ignore') # disable warnings about float64 usage 9 | 10 | from pathlib import Path 11 | import time 12 | 13 | import cloudpickle as pickle 14 | from jax import tree_util, numpy as jnp 15 | import numpy as onp 16 | 17 | from chemtrain import traj_util, data_processing 18 | from util import Postprocessing, Initialization 19 | 20 | Path('output/figures').mkdir(parents=True, exist_ok=True) 21 | Path('output/trajectories').mkdir(parents=True, exist_ok=True) 22 | Path('output/properties').mkdir(parents=True, exist_ok=True) 23 | 24 | configuration_str = 'data/dataset/conf_COM_10k.npy' 25 | box_str = 'data/dataset/box.npy' 26 | 27 | # model = 'Tabulated' 28 | model = 'CGDimeNet' 29 | 30 | plotname = 'FM_2fs_1ns' 31 | 32 | 33 | saved_params_path = 'output/rel_entropy/trained_params_water_RE_model_water.pkl' 34 | # saved_params_path = ('output/force_matching/' 35 | # 'trained_params_water_FM_model_water.pkl') 36 | 37 | system_temperature = 298 # Kelvin 38 | boltzmann_constant = 0.0083145107 # in kJ / mol K 39 | kbt = system_temperature * boltzmann_constant 40 | mass = 18.0154 41 | time_step = 0.002 42 | 43 | total_time = 1100. 44 | t_equilib = 100. 
45 | print_every = 0.1 46 | 47 | target_name = 'TIP4P/2005' 48 | rdf_struct = Initialization.select_target_rdf(target_name) 49 | tcf_struct = Initialization.select_target_tcf(target_name, 0.5, nbins=50) 50 | adf_struct = Initialization.select_target_adf(target_name, 0.318) 51 | 52 | # targets are only dummy 53 | target_dict = {'rdf': rdf_struct, 'adf': adf_struct, 'tcf': tcf_struct, 54 | 'pressure': 1.} 55 | 56 | ############################### 57 | used_dataset_size = 1000 58 | box_length = onp.load(box_str) 59 | box = jnp.ones(3) * box_length 60 | 61 | position_data = data_processing.get_dataset(configuration_str, 62 | retain=used_dataset_size) 63 | r_init = onp.array(position_data[0]) 64 | 65 | constants = {'repulsive': (0.3165, 1., 0.5, 12)} 66 | idxs = {} 67 | 68 | simulation_data = Initialization.InitializationClass( 69 | r_init=r_init, box=box, kbt=kbt, masses=mass, dt=time_step) 70 | timings = traj_util.process_printouts(time_step, total_time, t_equilib, 71 | print_every) 72 | 73 | reference_state, energy_params, simulation_fns, compute_fns, _ = \ 74 | Initialization.initialize_simulation( 75 | simulation_data, model, target_dict, 76 | prior_idxs=idxs, prior_constants=constants 77 | ) 78 | simulator_template, energy_fn_template, neighbor_fn = simulation_fns 79 | 80 | if saved_params_path is not None: 81 | with open(saved_params_path, 'rb') as pickle_file: 82 | params = pickle.load(pickle_file) 83 | energy_params = tree_util.tree_map(jnp.array, params) 84 | 85 | trajectory_generator = traj_util.trajectory_generator_init(simulator_template, 86 | energy_fn_template, 87 | timings) 88 | 89 | # compute trajectory and quantities 90 | t_start = time.time() 91 | traj_state = trajectory_generator(energy_params, reference_state) 92 | print('trajectory ps/min: ', total_time / ((time.time() - t_start) / 60.)) 93 | 94 | jnp.save(f'output/trajectories/CG_water_trajectory_{plotname}', 95 | traj_state.trajectory.position) 96 | assert not traj_state.overflow, ('Neighborlist 
overflow during trajectory ' 97 | 'generation. Increase capacity and re-run.') 98 | 99 | print('average kbT:', jnp.mean(traj_state.aux['kbT']), 'vs reference:', kbt) 100 | 101 | t_post_start = time.time() 102 | quantity_trajectory = traj_util.quantity_traj(traj_state, compute_fns, 103 | energy_params, batch_size=2) 104 | print('quantity runtime in min: ', (time.time() - t_post_start) / 60.) 105 | 106 | if 'rdf' in quantity_trajectory: 107 | computed_RDF = jnp.mean(quantity_trajectory['rdf'], axis=0) 108 | jnp.save(f'output/properties/{plotname}_RDF', 109 | jnp.array([rdf_struct.rdf_bin_centers, computed_RDF]).T) 110 | Postprocessing.plot_initial_and_predicted_rdf(rdf_struct.rdf_bin_centers, 111 | computed_RDF, model, 112 | plotname, 113 | rdf_struct.reference) 114 | 115 | if ('pressure_tensor' in quantity_trajectory 116 | or 'pressure' in quantity_trajectory): 117 | if 'pressure_tensor' in quantity_trajectory: 118 | pressure_traj = quantity_trajectory['pressure_tensor'] 119 | else: 120 | pressure_traj = quantity_trajectory['pressure'] 121 | mean_pressure = jnp.mean(pressure_traj, axis=0) 122 | std_pressure = jnp.std(pressure_traj, axis=0) 123 | print('Pressure scalar mean:', mean_pressure, 'and standard deviation:', 124 | std_pressure) 125 | 126 | if 'adf' in quantity_trajectory: 127 | computed_ADF = jnp.mean(quantity_trajectory['adf'], axis=0) 128 | jnp.save(f'output/properties/{plotname}_ADF', 129 | jnp.array([adf_struct.adf_bin_centers, computed_ADF]).T) 130 | Postprocessing.plot_initial_and_predicted_adf(adf_struct.adf_bin_centers, 131 | computed_ADF, model, 132 | plotname, 133 | adf_struct.reference) 134 | 135 | if 'tcf' in quantity_trajectory: 136 | computed_TCF = jnp.mean(quantity_trajectory['tcf'], axis=0) 137 | equilateral = jnp.diagonal(jnp.diagonal(computed_TCF)) 138 | jnp.save(f'output/properties/{plotname}_TCF', 139 | jnp.array([tcf_struct.tcf_x_bin_centers[0, :, 0], equilateral]).T) 140 | Postprocessing.plot_initial_and_predicted_tcf( 141 | 
tcf_struct.tcf_x_bin_centers[0, :, 0], equilateral, model, plotname, 142 | tcf_struct.reference) 143 | -------------------------------------------------------------------------------- /examples/water/data/dataset/box.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/water/data/dataset/box.npy -------------------------------------------------------------------------------- /examples/water/data/water_models/TIP4P-2005_150_COM_ADF.csv: -------------------------------------------------------------------------------- 1 | 1.047197543084621429e-02 0.000000000000000000e+00 2 | 3.141592442989349365e-02 0.000000000000000000e+00 3 | 5.235987529158592224e-02 0.000000000000000000e+00 4 | 7.330383360385894775e-02 0.000000000000000000e+00 5 | 9.424777328968048096e-02 0.000000000000000000e+00 6 | 1.151917278766632080e-01 0.000000000000000000e+00 7 | 1.361356824636459351e-01 0.000000000000000000e+00 8 | 1.570796370506286621e-01 0.000000000000000000e+00 9 | 1.780235767364501953e-01 0.000000000000000000e+00 10 | 1.989675313234329224e-01 0.000000000000000000e+00 11 | 2.199114710092544556e-01 0.000000000000000000e+00 12 | 2.408554404973983765e-01 0.000000000000000000e+00 13 | 2.617993950843811035e-01 0.000000000000000000e+00 14 | 2.827433347702026367e-01 0.000000000000000000e+00 15 | 3.036872744560241699e-01 0.000000000000000000e+00 16 | 3.246312141418457031e-01 0.000000000000000000e+00 17 | 3.455751836299896240e-01 0.000000000000000000e+00 18 | 3.665191531181335449e-01 0.000000000000000000e+00 19 | 3.874630630016326904e-01 0.000000000000000000e+00 20 | 4.084070324897766113e-01 0.000000000000000000e+00 21 | 4.293509721755981445e-01 0.000000000000000000e+00 22 | 4.502949416637420654e-01 0.000000000000000000e+00 23 | 4.712388813495635986e-01 0.000000000000000000e+00 24 | 4.921828210353851318e-01 0.000000000000000000e+00 25 | 
5.131267905235290527e-01 0.000000000000000000e+00 26 | 5.340707302093505859e-01 2.696798894593110453e-41 27 | 5.550146698951721191e-01 4.748788753793651716e-36 28 | 5.759586095809936523e-01 3.099701255590088052e-31 29 | 5.969026088714599609e-01 7.518099419880879376e-27 30 | 6.178465485572814941e-01 6.799678711506241071e-23 31 | 6.387904882431030273e-01 2.306602185392067470e-19 32 | 6.597344875335693359e-01 2.963818310326244094e-16 33 | 6.806783676147460938e-01 1.468772206723173435e-13 34 | 7.016223073005676270e-01 2.907225675474744975e-11 35 | 7.225663065910339355e-01 2.458349079859090125e-09 36 | 7.435102462768554688e-01 9.963414271396686672e-08 37 | 7.644541859626770020e-01 2.233220357084064744e-06 38 | 7.853981852531433105e-01 3.063649637624621391e-05 39 | 8.063421249389648438e-01 2.677272714208811522e-04 40 | 8.272860050201416016e-01 1.543953432701528072e-03 41 | 8.482299447059631348e-01 6.202130112797021866e-03 42 | 8.691739439964294434e-01 1.831641234457492828e-02 43 | 8.901178836822509766e-01 4.156225547194480896e-02 44 | 9.110618233680725098e-01 7.547598332166671753e-02 45 | 9.320058226585388184e-01 1.138126924633979797e-01 46 | 9.529497623443603516e-01 1.474943608045578003e-01 47 | 9.738937020301818848e-01 1.702717691659927368e-01 48 | 9.948376417160034180e-01 1.814334839582443237e-01 49 | 1.015781641006469727e+00 1.847356110811233521e-01 50 | 1.036725640296936035e+00 1.853542476892471313e-01 51 | 1.057669520378112793e+00 1.874053627252578735e-01 52 | 1.078613400459289551e+00 1.934888809919357300e-01 53 | 1.099557399749755859e+00 2.046962380409240723e-01 54 | 1.120501279830932617e+00 2.206497937440872192e-01 55 | 1.141445279121398926e+00 2.406460344791412354e-01 56 | 1.162389278411865234e+00 2.637881338596343994e-01 57 | 1.183333277702331543e+00 2.892001271247863770e-01 58 | 1.204277157783508301e+00 3.162643611431121826e-01 59 | 1.225221037864685059e+00 3.440644443035125732e-01 60 | 1.246165037155151367e+00 3.722534477710723877e-01 61 | 
1.267108917236328125e+00 4.014416038990020752e-01 62 | 1.288053035736083984e+00 4.314111173152923584e-01 63 | 1.308996915817260742e+00 4.612493515014648438e-01 64 | 1.329940915107727051e+00 4.909719526767730713e-01 65 | 1.350884795188903809e+00 5.203309655189514160e-01 66 | 1.371828794479370117e+00 5.490950942039489746e-01 67 | 1.392772674560546875e+00 5.779589414596557617e-01 68 | 1.413716554641723633e+00 6.066927313804626465e-01 69 | 1.434660673141479492e+00 6.347096562385559082e-01 70 | 1.455604553222656250e+00 6.623813509941101074e-01 71 | 1.476548552513122559e+00 6.895257234573364258e-01 72 | 1.497492432594299316e+00 7.158519625663757324e-01 73 | 1.518436431884765625e+00 7.416029572486877441e-01 74 | 1.539380311965942383e+00 7.667118906974792480e-01 75 | 1.560324311256408691e+00 7.906714677810668945e-01 76 | 1.581268310546875000e+00 8.132739067077636719e-01 77 | 1.602212190628051758e+00 8.347730040550231934e-01 78 | 1.623156189918518066e+00 8.547309637069702148e-01 79 | 1.644100069999694824e+00 8.726347684860229492e-01 80 | 1.665044069290161133e+00 8.887214660644531250e-01 81 | 1.685987949371337891e+00 9.029079079627990723e-01 82 | 1.706931948661804199e+00 9.147949218750000000e-01 83 | 1.727875947952270508e+00 9.240854978561401367e-01 84 | 1.748819828033447266e+00 9.308882355690002441e-01 85 | 1.769763827323913574e+00 9.355124831199645996e-01 86 | 1.790707707405090332e+00 9.380925297737121582e-01 87 | 1.811651706695556641e+00 9.378729462623596191e-01 88 | 1.832595705986022949e+00 9.343109130859375000e-01 89 | 1.853539705276489258e+00 9.280441999435424805e-01 90 | 1.874483585357666016e+00 9.193806648254394531e-01 91 | 1.895427465438842773e+00 9.083833694458007812e-01 92 | 1.916371464729309082e+00 8.952869176864624023e-01 93 | 1.937315344810485840e+00 8.797358274459838867e-01 94 | 1.958259463310241699e+00 8.616499900817871094e-01 95 | 1.979203343391418457e+00 8.419786691665649414e-01 96 | 2.000147342681884766e+00 8.209761381149291992e-01 97 | 
2.021091222763061523e+00 7.986261844635009766e-01 98 | 2.042035341262817383e+00 7.753762006759643555e-01 99 | 2.062979221343994141e+00 7.515185475349426270e-01 100 | 2.083923101425170898e+00 7.267671823501586914e-01 101 | 2.104866981506347656e+00 7.011711597442626953e-01 102 | 2.125811100006103516e+00 6.754121184349060059e-01 103 | 2.146754980087280273e+00 6.494687199592590332e-01 104 | 2.167698860168457031e+00 6.236197948455810547e-01 105 | 2.188642740249633789e+00 5.985466837882995605e-01 106 | 2.209586620330810547e+00 5.739730596542358398e-01 107 | 2.230530738830566406e+00 5.494124889373779297e-01 108 | 2.251474618911743164e+00 5.253456830978393555e-01 109 | 2.272418498992919922e+00 5.020490884780883789e-01 110 | 2.293362617492675781e+00 4.792396426200866699e-01 111 | 2.314306735992431641e+00 4.567476511001586914e-01 112 | 2.335250616073608398e+00 4.350789785385131836e-01 113 | 2.356194496154785156e+00 4.145214855670928955e-01 114 | 2.377138376235961914e+00 3.947218060493469238e-01 115 | 2.398082256317138672e+00 3.756109774112701416e-01 116 | 2.419026374816894531e+00 3.569985628128051758e-01 117 | 2.439970254898071289e+00 3.387740254402160645e-01 118 | 2.460914134979248047e+00 3.214335739612579346e-01 119 | 2.481858015060424805e+00 3.051315546035766602e-01 120 | 2.502801895141601562e+00 2.895469069480895996e-01 121 | 2.523746013641357422e+00 2.745556533336639404e-01 122 | 2.544689893722534180e+00 2.599911689758300781e-01 123 | 2.565634012222290039e+00 2.457677125930786133e-01 124 | 2.586577892303466797e+00 2.321528047323226929e-01 125 | 2.607522010803222656e+00 2.192973047494888306e-01 126 | 2.628465890884399414e+00 2.069451361894607544e-01 127 | 2.649409770965576172e+00 1.949867457151412964e-01 128 | 2.670353651046752930e+00 1.834066063165664673e-01 129 | 2.691297531127929688e+00 1.722830235958099365e-01 130 | 2.712241649627685547e+00 1.617439389228820801e-01 131 | 2.733185529708862305e+00 1.515560448169708252e-01 132 | 2.754129409790039062e+00 
1.415209621191024780e-01 133 | 2.775073289871215820e+00 1.317665427923202515e-01 134 | 2.796017408370971680e+00 1.224943697452545166e-01 135 | 2.816961288452148438e+00 1.135734096169471741e-01 136 | 2.837905406951904297e+00 1.047595441341400146e-01 137 | 2.858849287033081055e+00 9.626057744026184082e-02 138 | 2.879793167114257812e+00 8.825404942035675049e-02 139 | 2.900737285614013672e+00 8.043953776359558105e-02 140 | 2.921681165695190430e+00 7.269820570945739746e-02 141 | 2.942625045776367188e+00 6.510291248559951782e-02 142 | 2.963568925857543945e+00 5.773638933897018433e-02 143 | 2.984512805938720703e+00 5.062547326087951660e-02 144 | 3.005456924438476562e+00 4.359810799360275269e-02 145 | 3.026400804519653320e+00 3.665751963853836060e-02 146 | 3.047344684600830078e+00 2.980031445622444153e-02 147 | 3.068288564682006836e+00 2.304768748581409454e-02 148 | 3.089232683181762695e+00 1.642253622412681580e-02 149 | 3.110176801681518555e+00 1.004248671233654022e-02 150 | 3.131120681762695312e+00 4.602672532200813293e-03 151 | -------------------------------------------------------------------------------- /examples/water/data/water_models/TIP4P-2005_1k_50b_TCF_cut05.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/water/data/water_models/TIP4P-2005_1k_50b_TCF_cut05.npy -------------------------------------------------------------------------------- /examples/water/data/water_models/TIP4P-2005_1k_50b_TCF_cut06.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/water/data/water_models/TIP4P-2005_1k_50b_TCF_cut06.npy -------------------------------------------------------------------------------- /examples/water/data/water_models/TIP4P-2005_1k_50b_TCF_cut08.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tummfm/relative-entropy/cb7067bf396474b475c36e0b89eb55bb383360fd/examples/water/data/water_models/TIP4P-2005_1k_50b_TCF_cut08.npy -------------------------------------------------------------------------------- /examples/water/data/water_models/TIP4P-2005_300_COM_RDF.csv: -------------------------------------------------------------------------------- 1 | 1.416666666045784950e-03 0.000000000000000000e+00 2 | 4.250000230967998505e-03 0.000000000000000000e+00 3 | 7.083333563059568405e-03 0.000000000000000000e+00 4 | 9.916666895151138306e-03 0.000000000000000000e+00 5 | 1.275000069290399551e-02 0.000000000000000000e+00 6 | 1.558333355933427811e-02 0.000000000000000000e+00 7 | 1.841666735708713531e-02 0.000000000000000000e+00 8 | 2.125000022351741791e-02 0.000000000000000000e+00 9 | 2.408333495259284973e-02 0.000000000000000000e+00 10 | 2.691666781902313232e-02 0.000000000000000000e+00 11 | 2.975000068545341492e-02 0.000000000000000000e+00 12 | 3.258333355188369751e-02 0.000000000000000000e+00 13 | 3.541666641831398010e-02 0.000000000000000000e+00 14 | 3.824999928474426270e-02 0.000000000000000000e+00 15 | 4.108333587646484375e-02 0.000000000000000000e+00 16 | 4.391666874289512634e-02 0.000000000000000000e+00 17 | 4.675000160932540894e-02 0.000000000000000000e+00 18 | 4.958333447575569153e-02 0.000000000000000000e+00 19 | 5.241666734218597412e-02 0.000000000000000000e+00 20 | 5.525000020861625671e-02 0.000000000000000000e+00 21 | 5.808333680033683777e-02 0.000000000000000000e+00 22 | 6.091666966676712036e-02 0.000000000000000000e+00 23 | 6.374999880790710449e-02 0.000000000000000000e+00 24 | 6.658332794904708862e-02 0.000000000000000000e+00 25 | 6.941667199134826660e-02 0.000000000000000000e+00 26 | 7.225000858306884766e-02 0.000000000000000000e+00 27 | 7.508333772420883179e-02 0.000000000000000000e+00 28 | 7.791666686534881592e-02 0.000000000000000000e+00 
29 | 8.075000345706939697e-02 0.000000000000000000e+00 30 | 8.358334004878997803e-02 0.000000000000000000e+00 31 | 8.641666918992996216e-02 0.000000000000000000e+00 32 | 8.924999833106994629e-02 0.000000000000000000e+00 33 | 9.208333492279052734e-02 0.000000000000000000e+00 34 | 9.491667151451110840e-02 0.000000000000000000e+00 35 | 9.775000065565109253e-02 0.000000000000000000e+00 36 | 1.005833297967910767e-01 0.000000000000000000e+00 37 | 1.034166663885116577e-01 0.000000000000000000e+00 38 | 1.062500104308128357e-01 0.000000000000000000e+00 39 | 1.090833395719528198e-01 0.000000000000000000e+00 40 | 1.119166687130928040e-01 0.000000000000000000e+00 41 | 1.147500053048133850e-01 0.000000000000000000e+00 42 | 1.175833418965339661e-01 0.000000000000000000e+00 43 | 1.204166710376739502e-01 0.000000000000000000e+00 44 | 1.232500001788139343e-01 0.000000000000000000e+00 45 | 1.260833442211151123e-01 0.000000000000000000e+00 46 | 1.289166808128356934e-01 0.000000000000000000e+00 47 | 1.317500025033950806e-01 0.000000000000000000e+00 48 | 1.345833390951156616e-01 0.000000000000000000e+00 49 | 1.374166756868362427e-01 0.000000000000000000e+00 50 | 1.402500122785568237e-01 0.000000000000000000e+00 51 | 1.430833488702774048e-01 0.000000000000000000e+00 52 | 1.459166705608367920e-01 0.000000000000000000e+00 53 | 1.487500071525573730e-01 0.000000000000000000e+00 54 | 1.515833437442779541e-01 0.000000000000000000e+00 55 | 1.544166654348373413e-01 0.000000000000000000e+00 56 | 1.572500020265579224e-01 0.000000000000000000e+00 57 | 1.600833386182785034e-01 0.000000000000000000e+00 58 | 1.629166752099990845e-01 0.000000000000000000e+00 59 | 1.657500118017196655e-01 0.000000000000000000e+00 60 | 1.685833334922790527e-01 0.000000000000000000e+00 61 | 1.714166700839996338e-01 0.000000000000000000e+00 62 | 1.742500066757202148e-01 0.000000000000000000e+00 63 | 1.770833283662796021e-01 0.000000000000000000e+00 64 | 1.799166649580001831e-01 0.000000000000000000e+00 65 | 
1.827500015497207642e-01 0.000000000000000000e+00 66 | 1.855833381414413452e-01 0.000000000000000000e+00 67 | 1.884166747331619263e-01 0.000000000000000000e+00 68 | 1.912499964237213135e-01 0.000000000000000000e+00 69 | 1.940833330154418945e-01 0.000000000000000000e+00 70 | 1.969166696071624756e-01 0.000000000000000000e+00 71 | 1.997499912977218628e-01 0.000000000000000000e+00 72 | 2.025833427906036377e-01 3.469614997668247068e-42 73 | 2.054166793823242188e-01 8.026747913867663832e-37 74 | 2.082500159740447998e-01 6.851288890151846825e-32 75 | 2.110833525657653809e-01 2.160860849386418483e-27 76 | 2.139166742563247681e-01 2.523800327490867406e-23 77 | 2.167500108480453491e-01 1.095818114731221748e-19 78 | 2.195833474397659302e-01 1.780484960347812704e-16 79 | 2.224166691303253174e-01 1.095904779684769847e-13 80 | 2.252500057220458984e-01 2.617174185703952105e-11 81 | 2.280833423137664795e-01 2.542059895915826928e-09 82 | 2.309166789054870605e-01 1.096716033543998492e-07 83 | 2.337500154972076416e-01 2.402863628958584741e-06 84 | 2.365833371877670288e-01 3.068153819185681641e-05 85 | 2.394166737794876099e-01 2.520468551665544510e-04 86 | 2.422500103712081909e-01 1.462224754504859447e-03 87 | 2.450833320617675781e-01 6.552447564899921417e-03 88 | 2.479166686534881592e-01 2.369282394647598267e-02 89 | 2.507500052452087402e-01 7.054819166660308838e-02 90 | 2.535833418369293213e-01 1.764161586761474609e-01 91 | 2.564166784286499023e-01 3.764983713626861572e-01 92 | 2.592500150203704834e-01 6.968575716018676758e-01 93 | 2.620833218097686768e-01 1.136678099632263184e+00 94 | 2.649166584014892578e-01 1.656738996505737305e+00 95 | 2.677499949932098389e-01 2.187455892562866211e+00 96 | 2.705833315849304199e-01 2.651867151260375977e+00 97 | 2.734166681766510010e-01 2.988822221755981445e+00 98 | 2.762500047683715820e-01 3.166519403457641602e+00 99 | 2.790833413600921631e-01 3.184405326843261719e+00 100 | 2.819166779518127441e-01 3.069432973861694336e+00 101 | 
2.847500145435333252e-01 2.862239360809326172e+00 102 | 2.875833213329315186e-01 2.601718187332153320e+00 103 | 2.904166579246520996e-01 2.319228887557983398e+00 104 | 2.932499945163726807e-01 2.042148351669311523e+00 105 | 2.960833311080932617e-01 1.788080453872680664e+00 106 | 2.989166676998138428e-01 1.564203143119812012e+00 107 | 3.017500042915344238e-01 1.372956037521362305e+00 108 | 3.045833408832550049e-01 1.214646100997924805e+00 109 | 3.074166476726531982e-01 1.087616205215454102e+00 110 | 3.102499842643737793e-01 9.873774051666259766e-01 111 | 3.130833208560943604e-01 9.090203046798706055e-01 112 | 3.159166872501373291e-01 8.486949205398559570e-01 113 | 3.187500238418579102e-01 8.031448721885681152e-01 114 | 3.215833604335784912e-01 7.705770730972290039e-01 115 | 3.244166970252990723e-01 7.485921978950500488e-01 116 | 3.272500336170196533e-01 7.343463897705078125e-01 117 | 3.300833702087402344e-01 7.263634800910949707e-01 118 | 3.329166769981384277e-01 7.230204939842224121e-01 119 | 3.357500135898590088e-01 7.229611277580261230e-01 120 | 3.385833501815795898e-01 7.259953022003173828e-01 121 | 3.414166867733001709e-01 7.318863272666931152e-01 122 | 3.442500233650207520e-01 7.398442029953002930e-01 123 | 3.470833599567413330e-01 7.493991255760192871e-01 124 | 3.499166965484619141e-01 7.608982324600219727e-01 125 | 3.527500033378601074e-01 7.730560302734375000e-01 126 | 3.555833399295806885e-01 7.843612432479858398e-01 127 | 3.584166765213012695e-01 7.962299585342407227e-01 128 | 3.612500131130218506e-01 8.099718689918518066e-01 129 | 3.640833497047424316e-01 8.242074251174926758e-01 130 | 3.669166862964630127e-01 8.382154107093811035e-01 131 | 3.697500228881835938e-01 8.525349497795104980e-01 132 | 3.725833594799041748e-01 8.668604493141174316e-01 133 | 3.754166960716247559e-01 8.810650110244750977e-01 134 | 3.782500028610229492e-01 8.958851099014282227e-01 135 | 3.810833394527435303e-01 9.111163020133972168e-01 136 | 3.839166760444641113e-01 
9.262808561325073242e-01 137 | 3.867500126361846924e-01 9.414162039756774902e-01 138 | 3.895833492279052734e-01 9.562956690788269043e-01 139 | 3.924166858196258545e-01 9.712432622909545898e-01 140 | 3.952500224113464355e-01 9.860259294509887695e-01 141 | 3.980833292007446289e-01 9.998610019683837891e-01 142 | 4.009166657924652100e-01 1.013574123382568359e+00 143 | 4.037500023841857910e-01 1.027941346168518066e+00 144 | 4.065833389759063721e-01 1.041846990585327148e+00 145 | 4.094166755676269531e-01 1.055045485496520996e+00 146 | 4.122500121593475342e-01 1.067820072174072266e+00 147 | 4.150833487510681152e-01 1.079607367515563965e+00 148 | 4.179166853427886963e-01 1.090632557868957520e+00 149 | 4.207500219345092773e-01 1.101311922073364258e+00 150 | 4.235833287239074707e-01 1.110913991928100586e+00 151 | 4.264166653156280518e-01 1.119199633598327637e+00 152 | 4.292500019073486328e-01 1.126748561859130859e+00 153 | 4.320833384990692139e-01 1.133677959442138672e+00 154 | 4.349166750907897949e-01 1.139968752861022949e+00 155 | 4.377500116825103760e-01 1.144941806793212891e+00 156 | 4.405833482742309570e-01 1.148324966430664062e+00 157 | 4.434166550636291504e-01 1.150982975959777832e+00 158 | 4.462499916553497314e-01 1.152776002883911133e+00 159 | 4.490833282470703125e-01 1.153429031372070312e+00 160 | 4.519166648387908936e-01 1.152806401252746582e+00 161 | 4.547500014305114746e-01 1.151133656501770020e+00 162 | 4.575833380222320557e-01 1.148886442184448242e+00 163 | 4.604166746139526367e-01 1.145385146141052246e+00 164 | 4.632500112056732178e-01 1.140898108482360840e+00 165 | 4.660833477973937988e-01 1.136256217956542969e+00 166 | 4.689166843891143799e-01 1.131097078323364258e+00 167 | 4.717500209808349609e-01 1.124847769737243652e+00 168 | 4.745833575725555420e-01 1.117028236389160156e+00 169 | 4.774166941642761230e-01 1.107944250106811523e+00 170 | 4.802500307559967041e-01 1.098916411399841309e+00 171 | 4.830833673477172852e-01 1.090118288993835449e+00 172 | 
4.859167039394378662e-01 1.081097126007080078e+00 173 | 4.887500107288360596e-01 1.071637511253356934e+00 174 | 4.915833473205566406e-01 1.061913013458251953e+00 175 | 4.944166839122772217e-01 1.052093029022216797e+00 176 | 4.972500205039978027e-01 1.041418552398681641e+00 177 | 5.000833272933959961e-01 1.030466437339782715e+00 178 | 5.029166936874389648e-01 1.020531773567199707e+00 179 | 5.057500004768371582e-01 1.011157989501953125e+00 180 | 5.085833668708801270e-01 1.001176238059997559e+00 181 | 5.114166736602783203e-01 9.911072254180908203e-01 182 | 5.142500400543212891e-01 9.818866848945617676e-01 183 | 5.170833468437194824e-01 9.731761217117309570e-01 184 | 5.199167132377624512e-01 9.646103978157043457e-01 185 | 5.227499604225158691e-01 9.561593532562255859e-01 186 | 5.255833268165588379e-01 9.482639431953430176e-01 187 | 5.284166336059570312e-01 9.409123063087463379e-01 188 | 5.312500000000000000e-01 9.333912134170532227e-01 189 | 5.340833067893981934e-01 9.258452653884887695e-01 190 | 5.369166731834411621e-01 9.189936518669128418e-01 191 | 5.397499799728393555e-01 9.127939343452453613e-01 192 | 5.425833463668823242e-01 9.066441059112548828e-01 193 | 5.454167127609252930e-01 9.005212783813476562e-01 194 | 5.482500195503234863e-01 8.951601386070251465e-01 195 | 5.510833859443664551e-01 8.906797766685485840e-01 196 | 5.539166927337646484e-01 8.870040774345397949e-01 197 | 5.567500591278076172e-01 8.840076923370361328e-01 198 | 5.595833659172058105e-01 8.815358877182006836e-01 199 | 5.624167323112487793e-01 8.799689412117004395e-01 200 | 5.652500391006469727e-01 8.793771266937255859e-01 201 | 5.680834054946899414e-01 8.792802691459655762e-01 202 | 5.709166526794433594e-01 8.796753883361816406e-01 203 | 5.737500190734863281e-01 8.810513615608215332e-01 204 | 5.765833258628845215e-01 8.836045861244201660e-01 205 | 5.794166922569274902e-01 8.867932558059692383e-01 206 | 5.822499990463256836e-01 8.904874920845031738e-01 207 | 5.850833654403686523e-01 
8.950030803680419922e-01 208 | 5.879166722297668457e-01 9.001734852790832520e-01 209 | 5.907500386238098145e-01 9.057089686393737793e-01 210 | 5.935833454132080078e-01 9.117854237556457520e-01 211 | 5.964167118072509766e-01 9.183089733123779297e-01 212 | 5.992500185966491699e-01 9.250890612602233887e-01 213 | 6.020833849906921387e-01 9.320988059043884277e-01 214 | 6.049166917800903320e-01 9.392644166946411133e-01 215 | 6.077500581741333008e-01 9.466698765754699707e-01 216 | 6.105833649635314941e-01 9.541507363319396973e-01 217 | 6.134166717529296875e-01 9.612823128700256348e-01 218 | 6.162499785423278809e-01 9.681364893913269043e-01 219 | 6.190833449363708496e-01 9.753326177597045898e-01 220 | 6.219166517257690430e-01 9.824585914611816406e-01 221 | 6.247500181198120117e-01 9.890159964561462402e-01 222 | 6.275833249092102051e-01 9.955735206604003906e-01 223 | 6.304166913032531738e-01 1.002019882202148438e+00 224 | 6.332499980926513672e-01 1.007885932922363281e+00 225 | 6.360833644866943359e-01 1.013558626174926758e+00 226 | 6.389166712760925293e-01 1.019194841384887695e+00 227 | 6.417500376701354980e-01 1.024528384208679199e+00 228 | 6.445833444595336914e-01 1.029541850090026855e+00 229 | 6.474167108535766602e-01 1.034276008605957031e+00 230 | 6.502500176429748535e-01 1.038602948188781738e+00 231 | 6.530833840370178223e-01 1.042269110679626465e+00 232 | 6.559166908264160156e-01 1.045462012290954590e+00 233 | 6.587500572204589844e-01 1.048409581184387207e+00 234 | 6.615833044052124023e-01 1.051094055175781250e+00 235 | 6.644166707992553711e-01 1.053415894508361816e+00 236 | 6.672499775886535645e-01 1.055243253707885742e+00 237 | 6.700833439826965332e-01 1.056801199913024902e+00 238 | 6.729166507720947266e-01 1.058075904846191406e+00 239 | 6.757500171661376953e-01 1.058975815773010254e+00 240 | 6.785833239555358887e-01 1.059502482414245605e+00 241 | 6.814166903495788574e-01 1.059800505638122559e+00 242 | 6.842499971389770508e-01 1.059618473052978516e+00 243 | 
6.870833635330200195e-01 1.058826327323913574e+00 244 | 6.899166703224182129e-01 1.057646989822387695e+00 245 | 6.927500367164611816e-01 1.056127429008483887e+00 246 | 6.955833435058593750e-01 1.054541826248168945e+00 247 | 6.984167098999023438e-01 1.052804708480834961e+00 248 | 7.012500762939453125e-01 1.050705671310424805e+00 249 | 7.040833234786987305e-01 1.048386335372924805e+00 250 | 7.069166898727416992e-01 1.045899152755737305e+00 251 | 7.097499966621398926e-01 1.043227910995483398e+00 252 | 7.125833630561828613e-01 1.040268778800964355e+00 253 | 7.154166698455810547e-01 1.037242531776428223e+00 254 | 7.182500362396240234e-01 1.034327268600463867e+00 255 | 7.210833430290222168e-01 1.031489253044128418e+00 256 | 7.239167094230651855e-01 1.028607010841369629e+00 257 | 7.267500162124633789e-01 1.025450587272644043e+00 258 | 7.295833826065063477e-01 1.022058129310607910e+00 259 | 7.324166893959045410e-01 1.018610358238220215e+00 260 | 7.352500557899475098e-01 1.015341401100158691e+00 261 | 7.380833625793457031e-01 1.012277364730834961e+00 262 | 7.409167289733886719e-01 1.009602069854736328e+00 263 | 7.437500357627868652e-01 1.007086634635925293e+00 264 | 7.465834021568298340e-01 1.004308342933654785e+00 265 | 7.494167089462280273e-01 1.001388549804687500e+00 266 | 7.522500157356262207e-01 9.988106489181518555e-01 267 | 7.550833225250244141e-01 9.964964389801025391e-01 268 | 7.579166889190673828e-01 9.943118691444396973e-01 269 | 7.607499957084655762e-01 9.923703670501708984e-01 270 | 7.635833621025085449e-01 9.907506704330444336e-01 271 | 7.664166688919067383e-01 9.893094301223754883e-01 272 | 7.692500352859497070e-01 9.878864884376525879e-01 273 | 7.720833420753479004e-01 9.864303469657897949e-01 274 | 7.749167084693908691e-01 9.850019216537475586e-01 275 | 7.777500152587890625e-01 9.839083552360534668e-01 276 | 7.805833816528320312e-01 9.832831621170043945e-01 277 | 7.834166884422302246e-01 9.828054308891296387e-01 278 | 7.862500548362731934e-01 
9.822917580604553223e-01 279 | 7.890833616256713867e-01 9.817476272583007812e-01 280 | 7.919167280197143555e-01 9.814655780792236328e-01 281 | 7.947499752044677734e-01 9.815892577171325684e-01 282 | 7.975833415985107422e-01 9.819374680519104004e-01 283 | 8.004166483879089355e-01 9.823878407478332520e-01 284 | 8.032500147819519043e-01 9.826974272727966309e-01 285 | 8.060833215713500977e-01 9.829730987548828125e-01 286 | 8.089166879653930664e-01 9.833093285560607910e-01 287 | 8.117499947547912598e-01 9.837819933891296387e-01 288 | 8.145833611488342285e-01 9.845438003540039062e-01 289 | 8.174166679382324219e-01 9.852983355522155762e-01 290 | 8.202500343322753906e-01 9.860043525695800781e-01 291 | 8.230833411216735840e-01 9.869421720504760742e-01 292 | 8.259167075157165527e-01 9.880825877189636230e-01 293 | 8.287500143051147461e-01 9.891009330749511719e-01 294 | 8.315833806991577148e-01 9.898670911788940430e-01 295 | 8.344166874885559082e-01 9.907552599906921387e-01 296 | 8.372500538825988770e-01 9.918808341026306152e-01 297 | 8.400833606719970703e-01 9.928250908851623535e-01 298 | 8.429166674613952637e-01 9.934182167053222656e-01 299 | 8.457499742507934570e-01 9.942603111267089844e-01 300 | 8.485833406448364258e-01 9.954217672348022461e-01 301 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | install_requires = [ 4 | 'jax>=0.4.3', 5 | 'jax-md>=0.2.5', 6 | 'optax>=0.0.9', 7 | 'dm-haiku>=0.0.9', 8 | 'sympy', 9 | 'cloudpickle', 10 | 'chex', 11 | 'jax-sgmc', 12 | ] 13 | 14 | extras_requires = { 15 | 'all': ['mdtraj<=1.9.6', 'matplotlib'], 16 | } 17 | 18 | with open('README.md', 'rt') as f: 19 | long_description = f.read() 20 | 21 | setup( 22 | name='rel-entropy', 23 | version='0.0.1', 24 | license='Apache 2.0', 25 | description=('Train neural network potentials via relative entropy' 26 | ' and 
def box_density(R_snapshot, bin_edges, axis=0):
    """Return the normalized density profile of one snapshot along ``axis``.

    Histograms particle coordinates of a single configuration along the
    chosen box axis and normalizes the counts so that a perfectly uniform
    distribution yields a profile of 1 in every bin.

    Assumes all particles are wrapped into the same box, i.e. ``bin_edges``
    spans the full coordinate range along ``axis``.

    Args:
        R_snapshot: (n_particles, dim) array of particle positions.
        bin_edges: Monotonic array of n_bins + 1 bin edges.
        axis: Box axis (column of ``R_snapshot``) to histogram over.

    Returns:
        (n_bins,) array: histogram counts scaled by n_bins / n_particles.
    """
    profile, _ = jnp.histogram(R_snapshot[:, axis], bins=bin_edges)
    # Normalize via n_bins / n_particles WITHOUT an in-place multiply:
    # histogram counts are integer-typed, and `int_array *= float` is a
    # TypeError under plain NumPy semantics — the original only worked
    # because jnp's `*=` rebinds to a new, dtype-promoted array.
    return profile * (profile.shape[0] / R_snapshot.shape[0])
def plot_density(file_name, n_bins=50):
    """Plot the mean density profile along x for a pickled trajectory.

    Loads a ``(R_traj_list, box)`` tuple from ``file_name``, histograms the
    x-coordinates of every snapshot, averages over snapshots and saves the
    resulting profile as a PNG next to the input file.

    Args:
        file_name: Path to a pickle containing ``(R_traj_list, box)``.
        n_bins: Number of histogram bins along the x axis.
    """
    with open(file_name, 'rb') as f:
        R_traj_list, box = pickle.load(f)

    # NOTE(review): hard-coded selection of trajectory index 10 — looks
    # intentional for a specific experiment, but confirm against callers.
    R_traj = R_traj_list[10]
    bin_edges = jnp.linspace(0., box[0], n_bins + 1)
    bin_centers = get_bin_centers_from_edges(bin_edges)
    compute_box_density = partial(box_density, bin_edges=bin_edges)
    density_snapshots = lax.map(compute_box_density, R_traj)
    density = jnp.mean(density_snapshots, axis=0)

    # Assumes a 3-character extension such as '.pkl' — TODO confirm inputs.
    file_name = file_name[:-4]
    plt.figure()
    plt.plot(bin_centers, density)
    plt.ylabel('Normalized Density')  # fixed typo: was 'Normalizes Density'
    plt.xlabel('x')
    plt.savefig(file_name + '.png')


def plot_initial_and_predicted_rdf(rdf_bin_centers, g_average_final, model,
                                   visible_device, reference_rdf=None,
                                   g_average_init=None, after_pretraining=False,
                                   std=None, T=None, color=None):
    """Plot a predicted RDF against optional initial guess and reference.

    Saves the figure to ``output/figures/predicted_RDF_*.png``; the file name
    encodes model, temperature, device and whether pretraining preceded.

    Args:
        rdf_bin_centers: Bin centers (x axis) of the RDF.
        g_average_final: Predicted RDF values.
        model: Model identifier used in the output file name.
        visible_device: Device identifier used in the output file name.
        reference_rdf: Optional reference RDF to plot dashed.
        g_average_init: Optional RDF of the initial guess.
        after_pretraining: If True, append ``_after_pretrain`` to the name.
        std: Optional per-bin uncertainty; shaded band around prediction.
        T: Optional temperature tag for the file name.
        color: Optional list of 3 colors [reference, predicted, initial].
    """
    if color is None:
        color = ['k', '#00A087FF', '#3C5488FF']
    pretrain_str = '_after_pretrain' if after_pretraining else ''

    plt.figure()
    plt.plot(rdf_bin_centers, g_average_final, label='predicted',
             color=color[1])
    if g_average_init is not None:
        plt.plot(rdf_bin_centers, g_average_init, label='initial guess',
                 color=color[2])
    if reference_rdf is not None:
        plt.plot(rdf_bin_centers, reference_rdf, label='reference',
                 dashes=(4, 3), color=color[0], linestyle='--')
    if std is not None:
        plt.fill_between(rdf_bin_centers, g_average_final - std,
                         g_average_final + std, alpha=0.3,
                         facecolor='#00A087FF', label='Uncertainty')
    plt.legend()
    # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on
    # modern Python); the rendered label is byte-identical.
    plt.xlabel(r'r in $\mathrm{nm}$')
    plt.savefig(f'output/figures/predicted_RDF_'
                f'{model}_{T or ""}_{visible_device}{pretrain_str}.png')
def plot_initial_and_predicted_tcf(bin_centers, g_average_final, model,
                                   visible_device, reference_tcf=None,
                                   tcf_init=None, labels=None,
                                   axis_label=None, color=None):
    """Plot a predicted TCF against optional initial guess and reference.

    Saves the figure to ``output/figures/predicted_TCF_*.png``; the file
    name encodes the model and the visible device.

    Args:
        bin_centers: Bin centers (x axis) of the TCF.
        g_average_final: Predicted TCF values.
        model: Model identifier used in the output file name.
        visible_device: Device identifier used in the output file name.
        reference_tcf: Optional reference TCF to plot dashed.
        tcf_init: Optional TCF of the initial guess.
        labels: Optional [reference, predicted, initial] legend labels.
        axis_label: Optional (ylabel, xlabel) pair; defaults to an
            'r in nm' x label only.
        color: Optional list of 3 colors [reference, predicted, initial].
    """
    if color is None:
        color = ['k', '#00A087FF', '#3C5488FF']
    if labels is None:
        labels = ['reference', 'predicted', 'initial guess']
    plt.figure()
    plt.plot(bin_centers, g_average_final, label=labels[1], color=color[1])
    if tcf_init is not None:
        plt.plot(bin_centers, tcf_init, label=labels[2], color=color[2])
    if reference_tcf is not None:
        plt.plot(bin_centers, reference_tcf, label=labels[0], dashes=(4, 3),
                 color=color[0], linestyle='--')
    plt.legend()
    if axis_label is not None:
        plt.ylabel(axis_label[0])
        plt.xlabel(axis_label[1])
    else:
        # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on
        # modern Python); the rendered label is byte-identical.
        plt.xlabel(r'r in $\mathrm{nm}$')

    plt.savefig(f'output/figures/predicted_TCF_'
                f'{model}_{visible_device}.png')