├── .github └── workflows │ └── travis.yml ├── .gitignore ├── README.md ├── pyproject.toml ├── requirements.txt ├── schnax ├── __init__.py ├── energy.py ├── model │ ├── __init__.py │ ├── gaussian_smearing.py │ ├── interaction │ │ ├── __init__.py │ │ ├── aggregate.py │ │ ├── cfconv.py │ │ ├── cosine_cutoff.py │ │ ├── filter_network.py │ │ └── interaction.py │ ├── misc.py │ └── schnet.py └── utils │ ├── __init__.py │ ├── ase.py │ ├── distances.py │ ├── layer_hooks.py │ ├── schnetkit.py │ ├── stateless.py │ └── train.py ├── tests ├── assets │ ├── model_n1.torch │ ├── model_n5.torch │ ├── zro2_n_1500.in │ ├── zro2_n_324.in │ ├── zro2_n_768.in │ └── zro2_n_96.in ├── interaction_test_case.py ├── test_cfconv.py ├── test_distance_expansion.py ├── test_distances.py ├── test_embeddings.py ├── test_interaction.py ├── test_schnetkit_end_to_end.py ├── test_schnetkit_end_to_end_no_activations.py └── test_utils │ ├── __init__.py │ ├── activation.py │ ├── initialize.py │ └── mock_environment_provider.py └── train ├── __init__.py ├── iso17.ipynb └── train.py /.github/workflows/travis.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.8", "3.9"] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install flake8 pytest poetry setuptools-scm 23 | poetry install 24 | - name: Lint with flake8 25 | run: | 26 | # stop the build if there are Python syntax errors or undefined names 27 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 28 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 29 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 30 | - name: Test with nose2 31 | run: | 32 | poetry run nose2 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | poetry.lock 2 | 3 | # Created by https://www.toptal.com/developers/gitignore/api/macos,python,pycharm+all 4 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,python,pycharm+all 5 | 6 | ### macOS ### 7 | # General 8 | .DS_Store 9 | .AppleDouble 10 | .LSOverride 11 | 12 | # Icon must end with two \r 13 | Icon 14 | 15 | 16 | # Thumbnails 17 | ._* 18 | 19 | # Files that might appear in the root of a volume 20 | .DocumentRevisions-V100 21 | .fseventsd 22 | .Spotlight-V100 23 | .TemporaryItems 24 | .Trashes 25 | .VolumeIcon.icns 26 | .com.apple.timemachine.donotpresent 27 | 28 | # Directories potentially created on remote AFP share 29 | .AppleDB 30 | .AppleDesktop 31 | Network Trash Folder 32 | Temporary Items 33 | .apdisk 34 | 35 | ### PyCharm+all ### 36 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 37 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 38 | 39 | # User-specific stuff 40 | .idea/**/workspace.xml 41 | .idea/**/tasks.xml 42 | .idea/**/usage.statistics.xml 43 | .idea/**/dictionaries 44 | .idea/**/shelf 45 | 46 | # AWS User-specific 47 | .idea/**/aws.xml 48 | 49 | # Generated files 50 | .idea/**/contentModel.xml 51 | 52 | # Sensitive or high-churn files 53 | .idea/**/dataSources/ 54 | .idea/**/dataSources.ids 55 | .idea/**/dataSources.local.xml 56 | .idea/**/sqlDataSources.xml 57 | .idea/**/dynamic.xml 58 | .idea/**/uiDesigner.xml 59 | .idea/**/dbnavigator.xml 60 | 61 | # Gradle 62 | .idea/**/gradle.xml 63 | .idea/**/libraries 64 | 65 | # Gradle and Maven with auto-import 66 | # When using Gradle or Maven with 
auto-import, you should exclude module files, 67 | # since they will be recreated, and may cause churn. Uncomment if using 68 | # auto-import. 69 | # .idea/artifacts 70 | # .idea/compiler.xml 71 | # .idea/jarRepositories.xml 72 | # .idea/modules.xml 73 | # .idea/*.iml 74 | # .idea/modules 75 | # *.iml 76 | # *.ipr 77 | 78 | # CMake 79 | cmake-build-*/ 80 | 81 | # Mongo Explorer plugin 82 | .idea/**/mongoSettings.xml 83 | 84 | # File-based project format 85 | *.iws 86 | 87 | # IntelliJ 88 | out/ 89 | 90 | # mpeltonen/sbt-idea plugin 91 | .idea_modules/ 92 | 93 | # JIRA plugin 94 | atlassian-ide-plugin.xml 95 | 96 | # Cursive Clojure plugin 97 | .idea/replstate.xml 98 | 99 | # Crashlytics plugin (for Android Studio and IntelliJ) 100 | com_crashlytics_export_strings.xml 101 | crashlytics.properties 102 | crashlytics-build.properties 103 | fabric.properties 104 | 105 | # Editor-based Rest Client 106 | .idea/httpRequests 107 | 108 | # Android studio 3.1+ serialized cache file 109 | .idea/caches/build_file_checksums.ser 110 | 111 | ### PyCharm+all Patch ### 112 | # Ignores the whole .idea folder and all .iml files 113 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 114 | 115 | .idea/ 116 | 117 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 118 | 119 | *.iml 120 | modules.xml 121 | .idea/misc.xml 122 | *.ipr 123 | 124 | # Sonarlint plugin 125 | .idea/sonarlint 126 | 127 | ### Python ### 128 | # Byte-compiled / optimized / DLL files 129 | __pycache__/ 130 | *.py[cod] 131 | *$py.class 132 | 133 | # C extensions 134 | *.so 135 | 136 | # Distribution / packaging 137 | .Python 138 | build/ 139 | develop-eggs/ 140 | dist/ 141 | downloads/ 142 | eggs/ 143 | .eggs/ 144 | lib/ 145 | lib64/ 146 | parts/ 147 | sdist/ 148 | var/ 149 | wheels/ 150 | share/python-wheels/ 151 | *.egg-info/ 152 | .installed.cfg 153 | *.egg 154 | MANIFEST 155 | 156 | # PyInstaller 157 | # Usually these 
files are written by a python script from a template 158 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 159 | *.manifest 160 | *.spec 161 | 162 | # Installer logs 163 | pip-log.txt 164 | pip-delete-this-directory.txt 165 | 166 | # Unit test / coverage reports 167 | htmlcov/ 168 | .tox/ 169 | .nox/ 170 | .coverage 171 | .coverage.* 172 | .cache 173 | nosetests.xml 174 | coverage.xml 175 | *.cover 176 | *.py,cover 177 | .hypothesis/ 178 | .pytest_cache/ 179 | cover/ 180 | 181 | # Translations 182 | *.mo 183 | *.pot 184 | 185 | # Django stuff: 186 | *.log 187 | local_settings.py 188 | db.sqlite3 189 | db.sqlite3-journal 190 | 191 | # Flask stuff: 192 | instance/ 193 | .webassets-cache 194 | 195 | # Scrapy stuff: 196 | .scrapy 197 | 198 | # Sphinx documentation 199 | docs/_build/ 200 | 201 | # PyBuilder 202 | .pybuilder/ 203 | target/ 204 | 205 | # Jupyter Notebook 206 | .ipynb_checkpoints 207 | 208 | # IPython 209 | profile_default/ 210 | ipython_config.py 211 | 212 | # pyenv 213 | # For a library or package, you might want to ignore these files since the code is 214 | # intended to run in multiple environments; otherwise, check them in: 215 | # .python-version 216 | 217 | # pipenv 218 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 219 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 220 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 221 | # install all needed dependencies. 222 | #Pipfile.lock 223 | 224 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 225 | __pypackages__/ 226 | 227 | # Celery stuff 228 | celerybeat-schedule 229 | celerybeat.pid 230 | 231 | # SageMath parsed files 232 | *.sage.py 233 | 234 | # Environments 235 | .env 236 | .venv 237 | env/ 238 | venv/ 239 | ENV/ 240 | env.bak/ 241 | venv.bak/ 242 | 243 | # Spyder project settings 244 | .spyderproject 245 | .spyproject 246 | 247 | # Rope project settings 248 | .ropeproject 249 | 250 | # mkdocs documentation 251 | /site 252 | 253 | # mypy 254 | .mypy_cache/ 255 | .dmypy.json 256 | dmypy.json 257 | 258 | # Pyre type checker 259 | .pyre/ 260 | 261 | # pytype static type analyzer 262 | .pytype/ 263 | 264 | # Cython debug symbols 265 | cython_debug/ 266 | 267 | # End of https://www.toptal.com/developers/gitignore/api/macos,python,pycharm+all 268 | n -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![build](https://github.com/fabiannagel/schnax/actions/workflows/travis.yml/badge.svg) 2 | 3 | # `schnax`: SchNet in JAX and JAX-MD 4 | This is a re-implementation of the `SchNet` neural network architecture in [JAX](https://github.com/google/jax), [haiku](https://github.com/deepmind/dm-haiku), and [JAX-MD](https://github.com/google/jax-md). 5 | `schnax` is intended as a drop-in replacement for the original `pytorch` implementation, allowing the use of trained weights obtained with [SchNetPack](https://github.com/atomistic-machine-learning/schnetpack). 6 | 7 | 8 | ## References 9 | * [1] K.T. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko, K.-R. Müller. 10 | *SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.* 11 | Advances in Neural Information Processing Systems 30, pp. 992-1002 (2017) [link](http://papers.nips.cc/paper/6700-schnet-a-continuous-filter-convolutional-neural-network-for-modeling-quantum-interactions) 12 | 13 | * [2] K.T.
Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko, K.-R. Müller. 14 | *SchNet - a deep learning architecture for molecules and materials.* 15 | The Journal of Chemical Physics 148(24), 241722 (2018) [10.1063/1.5019779](https://doi.org/10.1063/1.5019779) 16 | 17 | * [3] K.T. Schütt, P. Kessel, M. Gastegger, K. Nicoli, A. Tkatchenko, K.-R. Müller. *SchNetPack: A Deep Learning Toolbox For Atomistic Systems.* J. Chem. Theory Comput. [10.1021/acs.jctc.8b00908](https://doi.org/10.1021/acs.jctc.8b00908) [arXiv:1809.01072](https://arxiv.org/abs/1809.01072v1) (2018) 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "schnax" 3 | version = "0.1.0" 4 | description = "The SchNet neural network architecture in JAX and JAX-MD" 5 | authors = ["Fabian Nagel ", "Marcel Langer "] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.7" 9 | jax = {extras = ["cpu"], version = "^0.3.4"} 10 | jax_md = "^0.1.25" 11 | dm-haiku = "^0.0.6" 12 | numpy = "^1.21" 13 | 14 | [tool.poetry.dev-dependencies] 15 | ase = "^3.22" 16 | schnetpack = { git = "https://github.com/atomistic-machine-learning/schnetpack.git", branch = "master" } 17 | schnetkit = { git = "https://github.com/sirmarcel/schnetkit.git", branch = "main", rev = "d7cba6a2d5dafdeb72de3c3f37e5b211a30e3e5a" } 18 | nose2 = "^0.11.0" 19 | torch = "^1.11.0" 20 | 21 | [build-system] 22 | requires = ["poetry>=0.12"] 23 | build-backend = "poetry.masonry.api" 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ase==3.22.1 2 | dm-haiku==0.0.6 3 | jax-md==0.1.25 4 | nose2==0.11.0 5 | schnax @ git+https://github.com/fabiannagel/schnax.git@b901791112066e49fc4c322427bfd09805882a1c 6 | schnetkit @
git+https://github.com/sirmarcel/schnetkit.git@d7cba6a2d5dafdeb72de3c3f37e5b211a30e3e5a
schnetpack==1.0.0
torch==1.11.0

--------------------------------------------------------------------------------
/schnax/__init__.py:
--------------------------------------------------------------------------------
from .model import SchNet

--------------------------------------------------------------------------------
/schnax/energy.py:
--------------------------------------------------------------------------------
import haiku as hk
import jax
import jax_md
from jax import numpy as jnp
from jax_md.energy import DisplacementFn, Box

from .model import SchNet
from .utils import compute_distances
from .utils.stateless import transform_stateless


def _get_model(displacement_fn: DisplacementFn, n_atom_basis: int, max_z: int, n_gaussians: int, n_filters: int,
               mean: float, stddev: float, r_cutoff: float, n_interactions: int, normalize_filter: bool,
               per_atom: bool, return_activations: bool):
    """Build the haiku-transformed SchNet model.

    Moved to a dedicated function for better testing access.

    The wrapped model takes positions ``R``, atomic numbers ``Z`` and a jax_md
    neighbor list, computes neighbor distances with ``displacement_fn`` and
    evaluates ``SchNet`` on them. All remaining arguments are forwarded to ``SchNet``.

    Returns:
        If ``return_activations`` is True, the ``hk.transform_with_state`` pair with
        ``hk.without_apply_rng`` applied; its state holds the intermediate layer
        outputs that the modules record via ``hk.set_state``. Otherwise, the result of
        ``transform_stateless`` on that pair (see ``utils.stateless``; presumably it
        hides the haiku state — verify there).
    """

    def model(R: jnp.ndarray, Z: jnp.ndarray, neighbor: jax_md.partition.NeighborList):
        # distances from every atom to each neighbor-list slot
        dR = compute_distances(R, neighbor, displacement_fn)
        net = SchNet(n_atom_basis=n_atom_basis,
                     max_z=max_z,
                     n_gaussians=n_gaussians,
                     n_filters=n_filters,
                     mean=mean,
                     stddev=stddev,
                     r_cutoff=r_cutoff,
                     n_interactions=n_interactions,
                     normalize_filter=normalize_filter,
                     per_atom=per_atom)
        return net(dR, Z, neighbor)

    fun = hk.transform_with_state(model)
    # the model uses no randomness at apply time
    fun = hk.without_apply_rng(fun)
    if return_activations:
        return fun

    # how to get rid of states? see https://github.com/deepmind/dm-haiku/issues/267
    # fixed key: transform_stateless receives an RNG together with init/apply
    rng = jax.random.PRNGKey(0)
    return transform_stateless(rng, fun.init, fun.apply)


def schnet_neighbor_list(
        displacement_fn: DisplacementFn,
        box_size: Box,
        r_cutoff: float,
        dr_threshold: float,
        per_atom=False,
        n_interactions=1,
        n_atom_basis=128,
        max_z=100,
        n_gaussians=25,
        n_filters=128,
        mean=0.0,
        stddev=20.0,
        normalize_filter=False,
        return_activations=False
):
    """Convenience wrapper around SchNet.

    Builds a jax_md neighbor-list function (``mask_self=True``, non-fractional
    coordinates) together with the SchNet model.

    Args:
        displacement_fn: jax_md displacement function, used both for the neighbor
            list and for the neighbor distances fed to SchNet.
        box_size: box passed to ``jax_md.partition.neighbor_list``.
        r_cutoff: neighbor-list cutoff; also the upper end of SchNet's Gaussian
            expansion and the radius of its cosine cutoff.
        dr_threshold: ``dr_threshold`` of the neighbor list.
        per_atom: if True, the model returns per-atom outputs instead of their sum.
        n_interactions, n_atom_basis, max_z, n_gaussians, n_filters, normalize_filter:
            SchNet hyperparameters, forwarded unchanged.
        mean, stddev: atom-wise outputs are rescaled as ``yi * stddev + mean``.
        return_activations: if True, return the stateful haiku functions so that
            intermediate layer outputs are available as haiku state.

    Returns:
        Tuple ``(neighbor_fn, init_fn, apply_fn)``.
    """
    model = _get_model(displacement_fn=displacement_fn,
                       n_atom_basis=n_atom_basis,
                       max_z=max_z,
                       n_gaussians=n_gaussians,
                       n_filters=n_filters,
                       mean=mean,
                       stddev=stddev,
                       r_cutoff=r_cutoff,
                       n_interactions=n_interactions,
                       normalize_filter=normalize_filter,
                       per_atom=per_atom,
                       return_activations=return_activations)

    neighbor_fn = jax_md.partition.neighbor_list(
        displacement_fn,
        box_size,
        r_cutoff,
        dr_threshold=dr_threshold,
        mask_self=True,
        fractional_coordinates=False,
    )

    return neighbor_fn, model.init, model.apply

--------------------------------------------------------------------------------
/schnax/model/__init__.py:
--------------------------------------------------------------------------------
from .schnet import SchNet

--------------------------------------------------------------------------------
/schnax/model/gaussian_smearing.py:
--------------------------------------------------------------------------------
import haiku as hk
import jax.numpy as jnp


class GaussianSmearing(hk.Module):
    """Expand distances in Gaussians centered on ``jnp.linspace(start, stop, n_gaussians)``."""

    def __init__(
        self, start=0.0, stop=5.0, n_gaussians=50, centered=False, trainable=False
    ):
        super().__init__(name="GaussianSmearing")
        self.offset = jnp.linspace(start, stop, n_gaussians)
        # uniform width equal to the spacing of the centers; assumes n_gaussians >= 2
        self.widths = (self.offset[1] - self.offset[0]) * jnp.ones_like(self.offset)
12 | # TODO: trainable 13 | self.centered = centered 14 | 15 | def _smearing(self, distances: jnp.ndarray) -> jnp.ndarray: 16 | """Smear interatomic distance values using Gaussian functions.""" 17 | 18 | if not self.centered: 19 | # compute width of Gaussian functions (using an overlap of 1 STDDEV) 20 | coeff = -0.5 / jnp.power(self.widths, 2) 21 | # Use advanced indexing to compute the individual components 22 | # diff = distances[:, :, :, None] - self.offset[None, None, None, :] 23 | diff = ( 24 | distances[:, :, None] - self.offset[None, None, :] 25 | ) # skip batches for now 26 | else: 27 | # if Gaussian functions are centered, use offsets to compute widths 28 | coeff = -0.5 / jnp.power(self.offset, 2) 29 | # if Gaussian functions are centered, no offset is subtracted 30 | # diff = distances[:, :, :, None] 31 | diff = distances[:, :, None] # skip batches for now 32 | 33 | # compute smear distance values 34 | gauss = jnp.exp(coeff * jnp.power(diff, 2)) 35 | return gauss 36 | 37 | def __call__(self, distances: jnp.ndarray, *args, **kwargs): 38 | smearing = self._smearing(distances) 39 | hk.set_state(self.name, smearing) 40 | return smearing 41 | -------------------------------------------------------------------------------- /schnax/model/interaction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiannagel/schnax/e153581b6b44976e255955716a7aada07ee0838d/schnax/model/interaction/__init__.py -------------------------------------------------------------------------------- /schnax/model/interaction/aggregate.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | from jax import numpy as jnp 3 | 4 | 5 | class Aggregate(hk.Module): 6 | """Pooling layer based on sum or average with optional masking. 7 | 8 | Args: 9 | axis (int): axis along which pooling is done. 
10 | mean_pooling (bool, optional): if True, use average instead for sum pooling. 11 | keepdim (bool, optional): whether the output tensor has dim retained or not. 12 | 13 | """ 14 | 15 | def __init__(self, axis: int, mean_pooling=False, keepdim=True): 16 | super().__init__(name="Aggregate") 17 | self.axis = axis 18 | self.mean_pooling = mean_pooling 19 | self.keepdim = keepdim 20 | 21 | def __call__(self, input: jnp.ndarray, mask=None): 22 | # mask input 23 | if mask is not None: 24 | input = input * mask[..., None] 25 | 26 | # compute sum of input along axis 27 | y = jnp.sum(input, self.axis) 28 | 29 | # compute average of input along axis 30 | if self.mean_pooling: 31 | 32 | # get the number of items along axis 33 | if mask is not None: 34 | N = jnp.sum(mask, self.axis, keepdim=self.keepdim) 35 | N = jnp.max(N, other=jnp.ones_like(N)) 36 | 37 | else: 38 | # N = input.size(self.axis) 39 | N = input.shape(self.axis) 40 | 41 | y = y / N 42 | 43 | return y 44 | -------------------------------------------------------------------------------- /schnax/model/interaction/cfconv.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | import jax.numpy as jnp 3 | from jax_md.partition import NeighborList 4 | 5 | from schnax.model.interaction.aggregate import Aggregate 6 | from schnax.model.interaction.filter_network import FilterNetwork 7 | from schnax.model.interaction.cosine_cutoff import CosineCutoff 8 | 9 | 10 | class CFConv(hk.Module): 11 | 12 | def __init__(self, n_filters: int, n_out: int, r_cutoff: float, activation, normalize_filter=False, axis=1): 13 | super().__init__(name="CFConv") 14 | 15 | self.filter_network = FilterNetwork(n_filters) 16 | self.cutoff_network = CosineCutoff(r_cutoff) 17 | 18 | self.in2f = hk.Linear(n_filters, with_bias=False, name="in2f") 19 | self.f2out = hk.Sequential([ 20 | hk.Linear(n_out, with_bias=True, name="f2out"), activation 21 | ]) 22 | 23 | self.aggregate = 
Aggregate(axis=axis, mean_pooling=normalize_filter) 24 | 25 | @staticmethod 26 | def _reshape_y(y: jnp.ndarray, neighbors: NeighborList) -> jnp.ndarray: 27 | nbh_size = neighbors.idx.shape 28 | 29 | # use SchNetPack's 0-padding approach to obtain *exactly* the same results after this reshaping operation. 30 | nbh_indices = neighbors.idx 31 | # padding_mask = nbh_indices == nbh_indices.shape[0] 32 | # nbh_indices[padding_mask] = 0 33 | 34 | # (n_atoms, max_occupancy) -> (n_atoms * max_occupancy, 1) 35 | # for batches, use shape (-1, nbh_size[1] * nbh_size[2], 1) 36 | nbh = nbh_indices.reshape((nbh_size[0] * nbh_size[1], 1)) 37 | 38 | # (n_atoms * max_occupancy, 1) -> (n_atoms * max_occupancy, n_filters) 39 | nbh = jnp.tile(nbh, (1, y.shape[1])) 40 | 41 | # (n_atoms, n_filters) -> (n_atoms * max_occupancy, n_filters) 42 | y = jnp.take_along_axis(y, indices=nbh, axis=0) 43 | 44 | # (n_atoms * max_occupancy, n_filters) -> (n_atoms, max_occupancy, n_filters) 45 | y = jnp.reshape(y, (nbh_size[0], nbh_size[1], -1)) 46 | return y 47 | 48 | def __call__(self, x: jnp.ndarray, dR: jnp.ndarray, neighbors: NeighborList, dR_expanded: jnp.ndarray): 49 | if dR_expanded is None: 50 | # Insert a new dimension (size 1) at the last position 51 | # (n_atoms, max_occupancy) -> (n_atoms, max_occupancy, 1) 52 | dR_expanded = jnp.expand_dims(dR, axis=-1) 53 | 54 | # pass expanded interactomic distances through filter block 55 | W = self.filter_network(dR_expanded) 56 | 57 | # apply cutoff 58 | if self.cutoff_network is not None: 59 | C = self.cutoff_network(dR) 60 | W = W * jnp.expand_dims(C, axis=-1) 61 | 62 | # pass initial embeddings through dense layer. reshape y for element-wise multiplication by W. 63 | y = self.in2f(x) 64 | hk.set_state(self.in2f.name, y) 65 | y = self._reshape_y(y, neighbors) 66 | 67 | # element-wise multiplication, aggregation and dense output layer. 68 | y = y * W 69 | 70 | # aggregate over neighborhoods, skip padded indices. 
71 | actual_indices_mask = neighbors.idx != neighbors.idx.shape[0] 72 | y = self.aggregate(y, actual_indices_mask) 73 | hk.set_state(self.aggregate.name, y) 74 | 75 | y = self.f2out(y) 76 | hk.set_state(self.f2out.layers[0].name, y) 77 | 78 | return y 79 | -------------------------------------------------------------------------------- /schnax/model/interaction/cosine_cutoff.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | import jax.numpy as jnp 3 | 4 | 5 | class CosineCutoff(hk.Module): 6 | 7 | def __init__(self, r_cutoff: float): 8 | super().__init__(name="CosineCutoff") 9 | self.r_cutoff = jnp.float32(r_cutoff) 10 | 11 | def __call__(self, dR: jnp.ndarray) -> jnp.ndarray: 12 | cutoffs = 0.5 * (jnp.cos(dR * jnp.pi / self.r_cutoff) + 1.0) 13 | cutoffs *= (dR < self.r_cutoff) 14 | hk.set_state(self.name, cutoffs) 15 | return cutoffs 16 | -------------------------------------------------------------------------------- /schnax/model/interaction/filter_network.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | from jax import numpy as jnp 3 | 4 | from schnax.model.misc import shifted_softplus 5 | 6 | 7 | class FilterNetwork(hk.Module): 8 | def __init__(self, n_filters: int): 9 | super().__init__(name="FilterNetwork") 10 | 11 | self.linear_0 = hk.Sequential( 12 | [hk.Linear(n_filters, name="linear_0"), shifted_softplus] 13 | ) # n_spatial_basis -> n_filters 14 | self.linear_1 = hk.Linear(n_filters, name="linear_1") # n_filters -> n_filters 15 | 16 | def __call__(self, x: jnp.ndarray): 17 | x = self.linear_0(x) 18 | hk.set_state( 19 | self.linear_0.layers[0].name, x 20 | ) # w/o referencing layer 0, the key would be "Sequential". 

        x = self.linear_1(x)
        hk.set_state(self.linear_1.name, x)

        return x

--------------------------------------------------------------------------------
/schnax/model/interaction/interaction.py:
--------------------------------------------------------------------------------
import haiku as hk
import jax.numpy as jnp
from jax_md.partition import NeighborList

from schnax.model.misc import shifted_softplus

from .cfconv import CFConv


class Interaction(hk.Module):
    """SchNet interaction block: a CFConv followed by a dense layer named "Output"."""

    def __init__(
        self,
        idx: int,
        n_atom_basis: int,
        n_filters: int,
        n_spatial_basis: int,
        r_cutoff: float,
        normalize_filter: bool
    ):
        # NOTE(review): n_spatial_basis is accepted but not used by this module.
        super().__init__(name="Interaction_{}".format(idx))
        self.cfconv = CFConv(
            n_filters, n_atom_basis, r_cutoff, activation=shifted_softplus, normalize_filter=normalize_filter
        )
        self.dense = hk.Linear(n_atom_basis, name="Output")

    def __call__(
        self, x: jnp.ndarray, dR: jnp.ndarray, neighbors: NeighborList, dR_expanded=None
    ):
        """Compute convolution block.

        Args:
            x: input representation/embedding of atomic environments with (N_a, n_in) shape.
            dR: interatomic distances of (N_a, N_nbh) shape.
            neighbors: neighbor list with neighbor indices in (N_a, N_nbh) shape. Padded
                neighbor slots are masked inside CFConv based on ``neighbors.idx``.
            dR_expanded (optional): expanded interatomic distances in a basis.
                If None, CFConv uses ``jnp.expand_dims(dR, axis=-1)``.

        Returns:
            jnp.ndarray: block output with (N_a, n_atom_basis) shape, also recorded in
            haiku state under the dense layer's name.

        """
        x = self.cfconv(x, dR, neighbors, dR_expanded)
        x = self.dense(x)
        hk.set_state(self.dense.name, x)
        return x

--------------------------------------------------------------------------------
/schnax/model/misc.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp


def shifted_softplus(x: jnp.ndarray) -> jnp.ndarray:
    """Return ``softplus(x) - log(2)``, i.e. a softplus shifted to be 0 at x = 0."""
    return jax.nn.softplus(x) - jnp.log(2.0)

--------------------------------------------------------------------------------
/schnax/model/schnet.py:
--------------------------------------------------------------------------------
import haiku as hk
import jax.numpy as jnp
from jax_md.partition import NeighborList

from .misc import shifted_softplus
from .gaussian_smearing import GaussianSmearing
from .interaction.aggregate import Aggregate
from .interaction.interaction import Interaction


class SchNet(hk.Module):
    """SchNet: embeddings, Gaussian distance expansion, residual interaction blocks,
    an atom-wise MLP with scale/shift, and a sum over atoms unless ``per_atom``."""

    # defaults used by schnax.energy.schnet_neighbor_list:
    # n_atom_basis = 128
    # max_z = 100
    # n_gaussians = 25
    #
    # n_filters = 128
    #
    # mean = 0.0
    # stddev = 20.0

    # config_atomwise = {'n_in': 128, 'mean': 0.0, 'stddev': 20.0, 'n_layers': 2, 'n_neurons': None}

    def __init__(self, n_atom_basis: int, max_z: int, n_gaussians: int, n_filters: int, mean: float, stddev: float,
                 r_cutoff: float, n_interactions: int, normalize_filter: bool, per_atom: bool):

        self.n_atom_basis = n_atom_basis
        self.max_z = max_z
        self.n_gaussians = n_gaussians
        self.n_filters = n_filters
        self.mean = mean
        self.stddev = stddev

        super().__init__(name="SchNet")
        self.n_interactions = n_interactions
        self.per_atom = per_atom

        self.embedding = hk.Embed(
            self.max_z, self.n_atom_basis, name="embeddings"
        )
        # Gaussian centers span [0, r_cutoff]
        self.distance_expansion = GaussianSmearing(0.0, r_cutoff, self.n_gaussians)

        # hk.Sequential is only used as a container; __call__ iterates .layers itself
        self.interactions = hk.Sequential(
            [
Interaction( 45 | idx=i, 46 | n_atom_basis=self.n_atom_basis, 47 | n_filters=self.n_filters, 48 | n_spatial_basis=self.n_gaussians, 49 | r_cutoff=r_cutoff, 50 | normalize_filter=normalize_filter 51 | ) 52 | for i in range(self.n_interactions) 53 | ] 54 | ) 55 | 56 | self.atomwise = hk.nets.MLP( 57 | output_sizes=[64, 1], activation=shifted_softplus, name="atomwise" 58 | ) 59 | self.aggregate = Aggregate(axis=0, mean_pooling=False) 60 | 61 | @staticmethod 62 | def standardize(yi: jnp.ndarray, mean: float, stddev: float): 63 | return yi * stddev + mean 64 | 65 | def __call__( 66 | self, dR: jnp.ndarray, Z: jnp.ndarray, neighbors: NeighborList, *args, **kwargs 67 | ) -> jnp.ndarray: 68 | # TODO: Move hk.set_state() calls into layer modules. Use self.name as key. 69 | 70 | # get embedding for Z 71 | x = self.embedding(Z) 72 | hk.set_state("embedding", x) 73 | 74 | # expand interatomic distances 75 | dR_expanded = self.distance_expansion(dR) 76 | # hk.set_state("distance_expansion", dR_expanded) 77 | 78 | # compute interactions 79 | for i, interaction in enumerate(self.interactions.layers): 80 | v = interaction(x, dR, neighbors, dR_expanded) 81 | x = x + v 82 | 83 | # energy contributions 84 | yi = self.atomwise(x) 85 | yi = self.standardize(yi, self.mean, self.stddev) 86 | 87 | if self.per_atom: 88 | return jnp.squeeze(yi) 89 | 90 | y = self.aggregate(yi) 91 | return jnp.squeeze(y) 92 | -------------------------------------------------------------------------------- /schnax/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .distances import compute_distances 2 | from .ase import atoms_to_input 3 | from .schnetkit import get_params 4 | -------------------------------------------------------------------------------- /schnax/utils/ase.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | from ase import Atoms 4 | 5 | 6 | def 
atoms_to_input(atoms: Atoms): 7 | R = atoms.positions.astype(np.float32) 8 | Z = atoms.numbers.astype(np.int) 9 | box = np.array(atoms.cell.array, dtype=np.float32) 10 | return jnp.float32(R), jnp.int32(Z), jnp.float32(box) 11 | -------------------------------------------------------------------------------- /schnax/utils/distances.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from jax_md import space 4 | from jax_md.energy import DisplacementFn 5 | from jax_md.partition import NeighborList 6 | 7 | 8 | def compute_distances( 9 | R: jnp.ndarray, neighbors: NeighborList, displacement_fn: DisplacementFn 10 | ) -> jnp.ndarray: 11 | R_neighbors = R[neighbors.idx] 12 | 13 | nl_displacement_fn = space.map_neighbor(displacement_fn) 14 | displacements = nl_displacement_fn(R, R_neighbors) 15 | distances_with_padding = space.distance(displacements) 16 | 17 | padding_mask = neighbors.idx < R.shape[0] 18 | distances_without_padding = distances_with_padding * padding_mask 19 | return distances_without_padding 20 | -------------------------------------------------------------------------------- /schnax/utils/layer_hooks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, OrderedDict 2 | 3 | 4 | def register_layer_hook(layer_outputs: Dict, layer, layer_name: str): 5 | hook = get_layer_hook(layer_outputs, layer_name) 6 | layer.register_forward_hook(hook) 7 | 8 | 9 | def get_layer_hook(layer_outputs: Dict, layer_name: str) -> Callable: 10 | def hook(model, input, output): 11 | layer_outputs[layer_name] = output.detach() 12 | 13 | return hook 14 | 15 | 16 | def register_representation_layer_hooks(layer_outputs: Dict, model: OrderedDict): 17 | # layer_member_names = list(dict(model.named_modules()).keys())[1:] 18 | register_layer_hook( 19 | layer_outputs, model.representation.embedding, "representation.embedding" 20 | ) 21 | 
# (continuation of register_representation_layer_hooks)
    register_layer_hook(
        layer_outputs, model.representation.distances, "representation.distances"
    )
    register_layer_hook(
        layer_outputs,
        model.representation.distance_expansion,
        "representation.distance_expansion",
    )

    # one set of hooks per interaction block; keys are prefixed with the block's
    # attribute path, e.g. "representation.interactions.0."
    for interaction_network_idx, interaction_network in enumerate(
        model.representation.interactions
    ):
        base_name = "representation.interactions.{}.".format(interaction_network_idx)

        # cfconv layer
        # NOTE(review): this hook is attached to the whole filter_network but stored under
        # "filter_network.1" — presumably because the module's output equals that of its
        # last sub-layer (index 1); confirm with the consumers of these keys.
        register_layer_hook(
            layer_outputs,
            interaction_network.filter_network,
            base_name + "filter_network.1",
        )
        register_layer_hook(
            layer_outputs,
            interaction_network.cutoff_network,
            base_name + "cutoff_network",
        )
        register_layer_hook(
            layer_outputs, interaction_network.cfconv.in2f, base_name + "cfconv.in2f"
        )
        register_layer_hook(
            layer_outputs, interaction_network.cfconv.f2out, base_name + "cfconv.f2out"
        )
        register_layer_hook(
            layer_outputs, interaction_network.cfconv.agg, base_name + "cfconv.agg"
        )

        # dense output layer
        register_layer_hook(layer_outputs, interaction_network.dense, base_name + "dense")


def register_output_layer_hooks(layer_outputs: Dict, model: OrderedDict):
    """Register output hooks on the sub-layers of each module in ``model.output_modules``.

    Keys have the form ``"output_modules.{idx}.<sub-layer path>"``. The sub-layer roles
    (GetItem, 2-layer MLP, ScaleShift, Aggregate) follow the inline comments below.
    ``model`` is used through attribute access, so despite the ``OrderedDict``
    annotation it is presumably a schnetpack model — verify.
    """
    for output_module_idx, output_module in enumerate(model.output_modules):
        base_name = "output_modules.{}.".format(output_module_idx)

        # GetItem()
        register_layer_hook(layer_outputs, output_module.out_net[0], base_name + "0")

        # 2-layer MLP
        register_layer_hook(
            layer_outputs, output_module.out_net[1].out_net[0], base_name + "1.out_net.0"
        )
        register_layer_hook(
            layer_outputs, output_module.out_net[1].out_net[1], base_name + "1.out_net.1"
        )

        # ScaleShift()
        register_layer_hook(
            layer_outputs, output_module.standardize, base_name + "standardize"
        )

        # Aggregate()
        register_layer_hook(
            layer_outputs,
output_module.atom_pool, base_name + "atom_pool" 83 | ) 84 | -------------------------------------------------------------------------------- /schnax/utils/schnetkit.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import numpy as np 4 | from jax import random 5 | from jax_md import space, partition 6 | from schnax import energy, utils 7 | from schnetkit.engine import load_file 8 | import jax.numpy as jnp 9 | 10 | 11 | def get_interaction_count(file: str) -> int: 12 | spec, state = load_file(file) 13 | return spec['schnet']['representation']['n_interactions'] 14 | 15 | 16 | def get_params(file: str) -> Dict: 17 | """Read weights from an existing schnetkit torch model.""" 18 | spec, state = load_file(file) 19 | n_interactions = get_interaction_count(file) 20 | params = {} 21 | 22 | def get_param(key): 23 | return state[key].cpu().numpy() 24 | 25 | def set_params(layer_key: str, weight_key: str, bias_key=None, interaction_idx=None): 26 | if interaction_idx is not None: 27 | layer_key = layer_key.format(interaction_idx) 28 | weight_key = weight_key.format(interaction_idx) 29 | if bias_key: 30 | bias_key = bias_key.format(interaction_idx) 31 | 32 | params[layer_key] = {} 33 | params[layer_key]['w'] = get_param(weight_key).T 34 | if bias_key: 35 | params[layer_key]['b'] = get_param(bias_key) 36 | 37 | # embeddings layer (special case, no transpose) 38 | params['SchNet/~/embeddings'] = { 39 | 'embeddings': get_param('representation.embedding.weight') 40 | } 41 | 42 | for i in range(n_interactions): 43 | # interaction block // cfconv block // filter network 44 | set_params( 45 | layer_key='SchNet/~/Interaction_{}/~/CFConv/~/FilterNetwork/~/linear_0', 46 | weight_key='representation.interactions.{}.filter_network.0.weight', 47 | bias_key='representation.interactions.{}.filter_network.0.bias', 48 | interaction_idx=i, 49 | ) 50 | 51 | set_params( 52 | 
layer_key='SchNet/~/Interaction_{}/~/CFConv/~/FilterNetwork/~/linear_1', 53 | weight_key='representation.interactions.{}.filter_network.1.weight', 54 | bias_key='representation.interactions.{}.filter_network.1.bias', 55 | interaction_idx=i, 56 | ) 57 | 58 | # interaction block // cfconv block // in2f 59 | set_params( 60 | layer_key='SchNet/~/Interaction_{}/~/CFConv/~/in2f', 61 | weight_key='representation.interactions.{}.cfconv.in2f.weight', 62 | interaction_idx=i, 63 | ) 64 | 65 | # interaction block // cfconv block // f2out 66 | set_params( 67 | layer_key='SchNet/~/Interaction_{}/~/CFConv/~/f2out', 68 | weight_key='representation.interactions.{}.cfconv.f2out.weight', 69 | bias_key='representation.interactions.{}.cfconv.f2out.bias', 70 | interaction_idx=i, 71 | ) 72 | 73 | # interaction block // output layer 74 | set_params( 75 | layer_key='SchNet/~/Interaction_{}/~/Output', 76 | weight_key='representation.interactions.{}.dense.weight', 77 | bias_key='representation.interactions.{}.dense.bias', 78 | interaction_idx=i, 79 | ) 80 | 81 | set_params( 82 | layer_key='SchNet/~/atomwise/~/linear_0', 83 | weight_key='output_modules.0.out_net.1.out_net.0.weight', 84 | bias_key='output_modules.0.out_net.1.out_net.0.bias', 85 | ) 86 | set_params( 87 | layer_key='SchNet/~/atomwise/~/linear_1', 88 | weight_key='output_modules.0.out_net.1.out_net.1.weight', 89 | bias_key='output_modules.0.out_net.1.out_net.1.bias', 90 | ) 91 | 92 | return params 93 | 94 | 95 | def normalize_representation_config(repr_config: Dict) -> Dict: 96 | """Normalize existing representation config to contain all required arguments. 
Also changed some naming conventions 97 | to passthrough kwargs to schnax with less hassle.""" 98 | 99 | normalized_repr = {} 100 | 101 | def set_value_or_default(key: str, default: Any, new_key=None): 102 | if new_key is None: 103 | new_key = key 104 | 105 | try: 106 | normalized_repr.update({new_key: repr_config[key]}) 107 | except KeyError: 108 | normalized_repr.update({new_key: default}) 109 | 110 | # skipping 'trainable_gaussians' (training not implemented) 111 | set_value_or_default('cutoff', 5.0, new_key='r_cutoff') 112 | set_value_or_default('n_interactions', 1) 113 | set_value_or_default('n_atom_basis', 128) 114 | set_value_or_default('max_z', 100) 115 | set_value_or_default('n_gaussians', 25) 116 | set_value_or_default('n_filters', 128) 117 | set_value_or_default('mean', 0.0) 118 | set_value_or_default('stddev', 1.0) 119 | set_value_or_default('normalize_filter', False) 120 | return normalized_repr 121 | 122 | 123 | def initialize_from_schnetkit_model( 124 | file: str, box: np.ndarray, dr_threshold=0.0, per_atom=False, return_activations=True 125 | ): 126 | spec, weights = load_file(file) 127 | model_config = normalize_representation_config(spec['schnet']['representation']) 128 | atomwise_config = spec['schnet']['atomwise'] 129 | 130 | # these keys are in atomwise 131 | model_config['mean'] = atomwise_config['mean'] 132 | model_config['stddev'] = atomwise_config['stddev'] 133 | 134 | box = jnp.float32(box) 135 | displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False) 136 | 137 | r_cutoff = jnp.float32(model_config['r_cutoff']) 138 | neighbor_fn = partition.neighbor_list( 139 | displacement_fn, 140 | box, 141 | r_cutoff, 142 | dr_threshold=dr_threshold, # as the effective cutoff = r_cutoff + dr_threshold 143 | mask_self=True, # an atom is not a neighbor of itself 144 | fractional_coordinates=False, 145 | ) 146 | 147 | # 'max_z', 'n_gaussians', 'mean', and 'stddev' missing in repr 148 | init_fn, apply_fn = energy._get_model( 
149 | displacement_fn=displacement_fn, per_atom=per_atom, return_activations=return_activations, **model_config 150 | ) 151 | 152 | params = utils.get_params(file) 153 | return neighbor_fn, displacement_fn, shift_fn, params, init_fn, apply_fn 154 | -------------------------------------------------------------------------------- /schnax/utils/stateless.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax.numpy as jnp 4 | from haiku import Transformed 5 | from haiku._src.data_structures import FlatMapping 6 | from jax._src.random import KeyArray 7 | from jax_md.partition import NeighborList 8 | 9 | 10 | def transform_stateless(rng: KeyArray, init_fn: Callable, apply_fn: Callable) -> Transformed: 11 | def stateless_init_fn(rng: KeyArray, R: jnp.ndarray, Z: jnp.ndarray, neighbor: NeighborList): 12 | params, state = init_fn(rng, R, Z, neighbor) 13 | return params 14 | 15 | def stateless_apply_fn(params: FlatMapping, R: jnp.ndarray, Z: jnp.ndarray, neighbor: NeighborList, **kwargs): 16 | _, state = init_fn(rng, R, Z, neighbor) 17 | 18 | pred, state = apply_fn(params, state, R, Z, neighbor, **kwargs) 19 | return pred 20 | 21 | return Transformed(stateless_init_fn, stateless_apply_fn) 22 | -------------------------------------------------------------------------------- /schnax/utils/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ase.db import connect 3 | from jax import numpy as jnp 4 | 5 | 6 | def build_dataset(): 7 | 8 | def extract_properties(): 9 | positions = [] 10 | charges = [] 11 | energies = [] 12 | forces = [] 13 | 14 | db = connect("train/data/iso17/reference.db") 15 | for row in db.select(limit=350): 16 | atoms = row.toatoms() 17 | 18 | positions += [atoms.get_positions()] 19 | charges += [atoms.get_atomic_numbers()] 20 | energies += [row['total_energy']] 21 | forces += [row.data['atomic_forces']] 22 | 23 
| return jnp.float32(positions), jnp.int32(charges), jnp.float32(energies), jnp.float32(forces) 24 | 25 | def shuffle_dataset(positions, charges, energies, forces): 26 | dataset_size = positions.shape[0] 27 | lookup = np.arange(dataset_size) 28 | np.random.shuffle(lookup) 29 | permute = lambda a, lookup: jnp.take_along_axis(a, lookup, axis=0) 30 | return permute(positions, lookup[:, None, None]), permute(charges, lookup[:, None]), permute(energies, lookup), permute(forces, lookup[:, None, None]) 31 | 32 | def split_dataset(positions, charges, energies, forces, train_size=0.8, val_size=0.1, test_size=0.1): 33 | assert train_size + val_size + test_size == 1.0 34 | dataset_size = positions.shape[0] 35 | idx_train_end = int(dataset_size * train_size) 36 | idx_val_start = int(dataset_size * train_size + dataset_size * val_size) 37 | 38 | train_positions, val_positions, test_positions = jnp.split(positions, [idx_train_end, idx_val_start]) 39 | train_charges, val_charges, test_charges = jnp.split(charges, [idx_train_end, idx_val_start]) 40 | train_energies, val_energies, test_energies = jnp.split(energies, [idx_train_end, idx_val_start]) 41 | train_forces, val_forces, test_forces = jnp.split(forces, [idx_train_end, idx_val_start]) 42 | 43 | train_set = train_positions, train_charges, train_energies, train_forces 44 | val_set = val_positions, val_charges, val_energies, val_forces 45 | test_set = test_positions, test_charges, test_energies, test_forces 46 | return train_set, val_set, test_set 47 | 48 | properties = extract_properties() 49 | properties = shuffle_dataset(*properties) 50 | train_set, val_set, test_set = split_dataset(*properties) 51 | # train_set = make_batches(*train_set, batch_size) 52 | return train_set, val_set, test_set 53 | 54 | 55 | def make_batches(train_positions, train_charges, train_energies, train_forces, batch_size): 56 | batch_positions = [] 57 | batch_charges = [] 58 | batch_energies = [] 59 | batch_forces = [] 60 | 61 | dataset_size = 
train_positions.shape[0] 62 | lookup = jnp.arange(0, dataset_size) 63 | for i in range(0, dataset_size, batch_size): 64 | if i + batch_size > len(lookup): 65 | break 66 | 67 | idx = lookup[i:i + batch_size] 68 | 69 | batch_positions += [train_positions[idx]] 70 | batch_charges += [train_charges[idx]] 71 | batch_energies += [train_energies[idx]] 72 | batch_forces += [train_forces[idx]] 73 | 74 | return jnp.stack(batch_positions), jnp.stack(batch_charges), jnp.stack(batch_energies), jnp.stack(batch_forces) -------------------------------------------------------------------------------- /tests/assets/model_n1.torch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiannagel/schnax/e153581b6b44976e255955716a7aada07ee0838d/tests/assets/model_n1.torch -------------------------------------------------------------------------------- /tests/assets/model_n5.torch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiannagel/schnax/e153581b6b44976e255955716a7aada07ee0838d/tests/assets/model_n5.torch -------------------------------------------------------------------------------- /tests/assets/zro2_n_324.in: -------------------------------------------------------------------------------- 1 | #======================================================= 2 | # FHI-aims file: geometry.in.supercell.1000K 3 | # Created using the Atomic Simulation Environment (ASE) 4 | # Sat Nov 13 17:45:37 2021 5 | 6 | # Additional information: 7 | # created from MB distrubtion 8 | # T = 1000.0 K 9 | # quantum: False 10 | # deterministic: False 11 | # plus_minus: False 12 | # gauge_eigenvectors: False 13 | # Random seed: 616700959 14 | # Sample number: 1 15 | #======================================================= 16 | lattice_vector 15.2273168400000003 0.0000000000000000 0.0000000000000000 17 | lattice_vector 0.0000000000000000 15.2273168400000003 0.0000000000000000 18 
| lattice_vector 0.0000000000000000 0.0000000000000000 15.5603641799999988 19 | atom 0.0000000000000000 0.0000000000000000 0.0000000000000000 Zr 20 | velocity 0.2943979846273991 -2.2803895797531433 -1.6261915070374302 21 | atom 2.5378861400000008 2.5378861400000008 0.0000000000000000 Zr 22 | velocity 3.6504635765269255 -2.8499856954186193 -0.7226410952267897 23 | atom 5.0757722800000016 5.0757722800000016 0.0000000000000000 Zr 24 | velocity -3.1030659808639802 5.2893256348754374 -1.8841991705185073 25 | atom 7.6136584200000001 7.6136584200000001 0.0000000000000000 Zr 26 | velocity 1.0199255295791863 -0.0108330399902401 4.5756500389920713 27 | atom 10.1515445600000032 10.1515445600000032 0.0000000000000000 Zr 28 | velocity 2.6752960278115734 -1.2025211755773180 -5.3589451269615216 29 | atom 12.6894307000000026 12.6894307000000026 0.0000000000000000 Zr 30 | velocity 1.5316738984678691 1.2879652337507410 3.5893070921816141 31 | atom 12.6894307000000044 2.5378861400000008 0.0000000000000000 Zr 32 | velocity 4.9339324445658921 -1.5343361855012503 3.3557664272777790 33 | atom 0.0000000000000000 5.0757722800000016 0.0000000000000000 Zr 34 | velocity 2.2737985690993741 0.1730290651819895 3.2769549097309629 35 | atom 2.5378861400000008 7.6136584200000001 0.0000000000000000 Zr 36 | velocity 1.7221566719778805 1.0945604723113838 -0.6012149395103167 37 | atom 5.0757722800000025 10.1515445600000032 0.0000000000000000 Zr 38 | velocity -0.3638443984185772 4.5329148842601779 3.1916675833325554 39 | atom 7.6136584200000001 12.6894307000000026 0.0000000000000000 Zr 40 | velocity -2.2476928090757822 -3.6861575489507454 3.2027864712736802 41 | atom 10.1515445600000032 0.0000000000000000 0.0000000000000000 Zr 42 | velocity -0.1523586238010664 3.7838147733807972 -1.0398409564409636 43 | atom 10.1515445600000049 5.0757722800000016 0.0000000000000000 Zr 44 | velocity 1.9632053265158220 3.4271260692153489 -4.3370232558850335 45 | atom 12.6894307000000044 7.6136584200000001 
0.0000000000000000 Zr 46 | velocity -2.3903290089835196 0.0121873293437342 -0.7319851632943658 47 | atom 0.0000000000000000 10.1515445600000032 0.0000000000000000 Zr 48 | velocity -2.8785025945388467 0.1656668582391811 -1.8698356838676760 49 | atom 2.5378861400000008 12.6894307000000026 0.0000000000000000 Zr 50 | velocity 0.1207379433846996 -0.3025739781599428 -0.4441632691407430 51 | atom 5.0757722800000016 0.0000000000000000 0.0000000000000000 Zr 52 | velocity -4.1145620980861057 0.7708386041791669 -0.1066909122200843 53 | atom 7.6136584200000001 2.5378861399999972 0.0000000000000000 Zr 54 | velocity -6.0304276631619906 -1.2429333776137523 -0.0008858252422341 55 | atom 0.0000000000000000 0.0000000000000000 5.1867880599999996 Zr 56 | velocity -1.7565911628085868 -2.7704440421599839 -0.9513073149413184 57 | atom 2.5378861400000008 2.5378861400000008 5.1867880599999996 Zr 58 | velocity -8.5118948285982547 4.8244859151913380 -0.3185057722266884 59 | atom 5.0757722800000016 5.0757722800000016 5.1867880599999996 Zr 60 | velocity -0.3767504947427329 -2.7981092107887959 -1.3196547855314262 61 | atom 7.6136584200000001 7.6136584200000001 5.1867880599999996 Zr 62 | velocity -1.5686271740668443 0.4199366767718719 -1.8853731727128717 63 | atom 10.1515445600000032 10.1515445600000032 5.1867880599999996 Zr 64 | velocity 2.8850829965483009 -7.1381232525154967 -0.8320095512152426 65 | atom 12.6894307000000026 12.6894307000000026 5.1867880599999996 Zr 66 | velocity 4.4813336857036097 3.6491077562758685 -2.4054991597245614 67 | atom 12.6894307000000044 2.5378861400000008 5.1867880599999996 Zr 68 | velocity 0.0095020359266324 0.3874452026844213 0.9779369554148606 69 | atom 0.0000000000000000 5.0757722800000016 5.1867880599999996 Zr 70 | velocity 6.4176840239344424 2.9484730934256405 -0.1381367623374735 71 | atom 2.5378861400000008 7.6136584200000001 5.1867880599999996 Zr 72 | velocity 4.4153572851483656 5.2573650346455727 -0.6213918027213922 73 | atom 5.0757722800000025 
10.1515445600000032 5.1867880599999996 Zr 74 | velocity 3.4359293479869635 -1.0397026749201017 -2.0088479706882585 75 | atom 7.6136584200000001 12.6894307000000026 5.1867880599999996 Zr 76 | velocity 0.0047821594432962 0.5630983455376529 -3.7160633267272929 77 | atom 10.1515445600000032 0.0000000000000000 5.1867880599999996 Zr 78 | velocity 2.5198311832499587 2.3333550598997213 2.9283691458950218 79 | atom 10.1515445600000049 5.0757722800000016 5.1867880599999996 Zr 80 | velocity -4.7255617084149160 2.5313923028053451 1.4343357930497846 81 | atom 12.6894307000000044 7.6136584200000001 5.1867880599999996 Zr 82 | velocity 3.8036868035780684 -0.7880405861752369 3.5583298785588040 83 | atom 0.0000000000000000 10.1515445600000032 5.1867880599999996 Zr 84 | velocity -2.2950924152567658 2.2263877466853832 6.0898581306040684 85 | atom 2.5378861400000008 12.6894307000000026 5.1867880599999996 Zr 86 | velocity 5.2206343814991536 1.6115161747769757 0.5676999094675882 87 | atom 5.0757722800000016 0.0000000000000000 5.1867880599999996 Zr 88 | velocity 0.6010045619998635 3.9179838592472773 -0.2072811658195699 89 | atom 7.6136584200000001 2.5378861399999972 5.1867880599999996 Zr 90 | velocity -2.3043776247094967 0.8369925367203924 -4.8947036005536164 91 | atom 0.0000000000000000 0.0000000000000000 10.3735761199999992 Zr 92 | velocity 5.9401784910282389 -7.5554608496759998 -0.2429931042654560 93 | atom 2.5378861400000008 2.5378861400000008 10.3735761199999992 Zr 94 | velocity -2.0485482799307562 0.6051649075174594 -0.1635865568813885 95 | atom 5.0757722800000016 5.0757722800000016 10.3735761199999992 Zr 96 | velocity -1.6510369410443546 1.8371678980931934 -4.1453663958377485 97 | atom 7.6136584200000001 7.6136584200000001 10.3735761199999992 Zr 98 | velocity -0.4824113365568052 0.5883255163099623 -0.6985466753800362 99 | atom 10.1515445600000032 10.1515445600000032 10.3735761199999992 Zr 100 | velocity 1.4195391747759305 1.0035136763758219 1.3106908065226748 101 | atom 
12.6894307000000026 12.6894307000000026 10.3735761199999992 Zr 102 | velocity -2.7813363534796527 -0.5335043116116627 4.5541455529722725 103 | atom 12.6894307000000044 2.5378861400000008 10.3735761199999992 Zr 104 | velocity 1.9377367548557298 -0.4685501657763533 1.1541801592289902 105 | atom 0.0000000000000000 5.0757722800000016 10.3735761199999992 Zr 106 | velocity -0.8264320523889126 2.7616565771369390 2.3853256573370354 107 | atom 2.5378861400000008 7.6136584200000001 10.3735761199999992 Zr 108 | velocity 0.7677011579813771 0.4317643223840620 -1.5418584697190343 109 | atom 5.0757722800000025 10.1515445600000032 10.3735761199999992 Zr 110 | velocity -6.9026548287374716 1.9593459604985950 1.1178874135431673 111 | atom 7.6136584200000001 12.6894307000000026 10.3735761199999992 Zr 112 | velocity -1.5693307098236764 -1.7093339603303812 -2.2737447969359001 113 | atom 10.1515445600000032 0.0000000000000000 10.3735761199999992 Zr 114 | velocity -0.4263959728648616 -2.9479468619921763 3.1815516775686503 115 | atom 10.1515445600000049 5.0757722800000016 10.3735761199999992 Zr 116 | velocity 2.0599470388679162 5.5606087655360401 -0.6230483969491378 117 | atom 12.6894307000000044 7.6136584200000001 10.3735761199999992 Zr 118 | velocity -0.4357464551336434 6.6291177692308167 1.4520401313074882 119 | atom 0.0000000000000000 10.1515445600000032 10.3735761199999992 Zr 120 | velocity -0.8726744986404369 0.2984330689305262 2.3981257596551702 121 | atom 2.5378861400000008 12.6894307000000026 10.3735761199999992 Zr 122 | velocity 3.4847380925953950 -0.3285738285110676 6.8548717198924987 123 | atom 5.0757722800000016 0.0000000000000000 10.3735761199999992 Zr 124 | velocity -0.9895504691811475 -1.1220447618538976 -0.1583520180313331 125 | atom 7.6136584200000001 2.5378861399999972 10.3735761199999992 Zr 126 | velocity -4.9511202765930458 -1.8685099717715665 -7.2253534855408530 127 | atom 0.0000000000000000 2.5378861400000008 2.5933940300000007 Zr 128 | velocity -2.2471790141681116 
-5.8769600495633059 1.9889519292915052 129 | atom 2.5378861400000008 5.0757722800000016 2.5933940300000007 Zr 130 | velocity -2.2763739729980386 -0.1466547373170067 4.4420535595961059 131 | atom 5.0757722800000016 7.6136584200000001 2.5933940300000007 Zr 132 | velocity -4.0946946131652906 0.3554953930218806 2.7321678141109338 133 | atom 7.6136584200000001 10.1515445600000032 2.5933940300000007 Zr 134 | velocity -1.7251462883363968 4.8743843661462440 3.6112560757479151 135 | atom 10.1515445600000032 12.6894307000000044 2.5933940300000007 Zr 136 | velocity -2.0432291323022005 7.5869080763466217 -3.2477005457139470 137 | atom 12.6894307000000026 0.0000000000000000 2.5933940300000007 Zr 138 | velocity -0.8900983378195415 3.9429202320039529 -5.9218123291054408 139 | atom 12.6894307000000026 5.0757722800000016 2.5933940300000007 Zr 140 | velocity 1.1631785768940672 0.3111701562497449 5.3506357255343930 141 | atom 0.0000000000000000 7.6136584200000001 2.5933940300000007 Zr 142 | velocity 0.9867878329509551 -1.6517644109309941 -1.0970905309164298 143 | atom 2.5378861399999990 10.1515445600000032 2.5933940300000007 Zr 144 | velocity -0.3790543890279178 1.6282365528823799 0.9959107664004554 145 | atom 5.0757722799999998 12.6894307000000026 2.5933940300000007 Zr 146 | velocity 1.2722468112243683 1.6300239181875347 -4.7699265154236317 147 | atom 7.6136584200000001 0.0000000000000000 2.5933940300000007 Zr 148 | velocity -7.8066358246133092 -0.5503827676810736 0.1243256500377454 149 | atom 10.1515445600000032 2.5378861399999972 2.5933940300000007 Zr 150 | velocity -0.1706604753117558 -1.8463737691920390 -3.3160319422219264 151 | atom 10.1515445600000049 7.6136584200000001 2.5933940300000007 Zr 152 | velocity 1.3760448068263407 4.2171084541502202 1.7381107140931809 153 | atom 12.6894307000000044 10.1515445600000032 2.5933940300000007 Zr 154 | velocity -0.1322719739457492 -0.0443744547185619 -1.1465811376339208 155 | atom 0.0000000000000000 12.6894307000000026 2.5933940300000007 
Zr 156 | velocity -1.6278848536029082 3.6050858936714709 -0.0631283634835853 157 | atom 2.5378861399999990 0.0000000000000000 2.5933940300000007 Zr 158 | velocity -0.4724206487909376 -1.5630568440815771 0.8783675048140031 159 | atom 5.0757722800000025 2.5378861399999972 2.5933940300000007 Zr 160 | velocity 1.9686175209128998 -2.1777905095938599 -3.3548137725375051 161 | atom 7.6136584200000001 5.0757722799999998 2.5933940300000007 Zr 162 | velocity 0.5936706340344706 0.9354232820918819 0.1645118491953596 163 | atom 0.0000000000000000 2.5378861400000008 7.7801820899999994 Zr 164 | velocity -0.0927572212851131 0.8884081678329232 -3.3279972110275517 165 | atom 2.5378861400000008 5.0757722800000016 7.7801820899999994 Zr 166 | velocity 1.1292997398686129 -3.3264718303119540 -3.0712181873749187 167 | atom 5.0757722800000016 7.6136584200000001 7.7801820899999994 Zr 168 | velocity 1.8282705547485856 -1.9066542110706637 -4.3903142422874124 169 | atom 7.6136584200000001 10.1515445600000032 7.7801820899999994 Zr 170 | velocity 0.2669963776324322 1.0987135203058971 4.3418258428962089 171 | atom 10.1515445600000032 12.6894307000000044 7.7801820899999994 Zr 172 | velocity 1.7280163127556263 1.5399834543414666 -1.2824271741004147 173 | atom 12.6894307000000026 0.0000000000000000 7.7801820899999994 Zr 174 | velocity 0.5188230652005044 1.0047998252903405 6.5089383469046274 175 | atom 12.6894307000000026 5.0757722800000016 7.7801820899999994 Zr 176 | velocity -4.0480520945270007 -5.2216298822395029 -1.0994666585602195 177 | atom 0.0000000000000000 7.6136584200000001 7.7801820899999994 Zr 178 | velocity 1.8756262552744498 -2.3602581879416440 -1.8840692966749499 179 | atom 2.5378861399999990 10.1515445600000032 7.7801820899999994 Zr 180 | velocity 3.3289211857590630 -1.2302799588196627 1.3157990447644206 181 | atom 5.0757722799999998 12.6894307000000026 7.7801820899999994 Zr 182 | velocity -1.9517265754027255 5.4348315966891398 -2.6057194734174596 183 | atom 7.6136584200000001 
0.0000000000000000 7.7801820899999994 Zr 184 | velocity 0.2837378107741934 2.8600360170359989 -1.2862967731234607 185 | atom 10.1515445600000032 2.5378861399999972 7.7801820899999994 Zr 186 | velocity -1.2564436305099660 1.4314693517159247 5.6846205337567008 187 | atom 10.1515445600000049 7.6136584200000001 7.7801820899999994 Zr 188 | velocity -1.6451178230477184 -8.0126782301354673 -1.2173786210489856 189 | atom 12.6894307000000044 10.1515445600000032 7.7801820899999994 Zr 190 | velocity 3.4133832124338177 -2.4172916256456105 -1.7063491653831833 191 | atom 0.0000000000000000 12.6894307000000026 7.7801820899999994 Zr 192 | velocity 0.6901201650675416 0.5904204562167733 -2.0463071962058144 193 | atom 2.5378861399999990 0.0000000000000000 7.7801820899999994 Zr 194 | velocity 0.0726421477144409 -4.1508565832125717 -2.6961926661403330 195 | atom 5.0757722800000025 2.5378861399999972 7.7801820899999994 Zr 196 | velocity 1.2511814398695418 -4.9144550356598522 3.2794373382363160 197 | atom 7.6136584200000001 5.0757722799999998 7.7801820899999994 Zr 198 | velocity 1.6270618727065775 -0.6690478701449339 4.8421756795706203 199 | atom 0.0000000000000000 2.5378861400000008 12.9669701499999981 Zr 200 | velocity -1.1062480602548239 4.3883853076514798 0.0020812126835117 201 | atom 2.5378861400000008 5.0757722800000016 12.9669701499999981 Zr 202 | velocity 0.4592875125776427 0.5809673943198469 0.4259798918984733 203 | atom 5.0757722800000016 7.6136584200000001 12.9669701499999981 Zr 204 | velocity -4.8050290668998157 -2.4398573277991953 2.3808131839099671 205 | atom 7.6136584200000001 10.1515445600000032 12.9669701499999981 Zr 206 | velocity -2.0263823349276100 -0.5084484809209403 1.7094225902963021 207 | atom 10.1515445600000032 12.6894307000000044 12.9669701499999981 Zr 208 | velocity 1.2984269529940753 -1.6857359742694851 -4.8081590120892805 209 | atom 12.6894307000000026 0.0000000000000000 12.9669701499999981 Zr 210 | velocity -0.6099312803031103 -2.0859595221523888 
1.2598265386981007 211 | atom 12.6894307000000026 5.0757722800000016 12.9669701499999981 Zr 212 | velocity -4.0755738500112990 0.4857131226938052 1.8982520335391626 213 | atom 0.0000000000000000 7.6136584200000001 12.9669701499999981 Zr 214 | velocity -0.1169275378061145 -0.4236983061967228 0.4284806163432516 215 | atom 2.5378861399999990 10.1515445600000032 12.9669701499999981 Zr 216 | velocity -1.9925037797247755 -0.3811559505989436 3.9093945271014814 217 | atom 5.0757722799999998 12.6894307000000026 12.9669701499999981 Zr 218 | velocity -3.4866029432618899 -2.0213879110176491 0.4362079751564558 219 | atom 7.6136584200000001 0.0000000000000000 12.9669701499999981 Zr 220 | velocity 6.9074706593540425 -2.3668095386940529 -2.4572462174445504 221 | atom 10.1515445600000032 2.5378861399999972 12.9669701499999981 Zr 222 | velocity -1.9400125337383947 2.9504612092739593 -4.6652597254130264 223 | atom 10.1515445600000049 7.6136584200000001 12.9669701499999981 Zr 224 | velocity -0.1479071319622944 -3.9782044061418631 6.0574511910006352 225 | atom 12.6894307000000044 10.1515445600000032 12.9669701499999981 Zr 226 | velocity 1.5047184177075579 -2.4513763171185490 4.1790798419066029 227 | atom 0.0000000000000000 12.6894307000000026 12.9669701499999981 Zr 228 | velocity 1.3506629892419568 2.9892125340884155 0.7160416837532774 229 | atom 2.5378861399999990 0.0000000000000000 12.9669701499999981 Zr 230 | velocity 1.0783737407588476 -1.8861087519646595 0.9858944361480352 231 | atom 5.0757722800000025 2.5378861399999972 12.9669701499999981 Zr 232 | velocity -1.0180225672768433 2.2181451850490919 1.1024179680670192 233 | atom 7.6136584200000001 5.0757722799999998 12.9669701499999981 Zr 234 | velocity 0.7144643046849545 -1.6626277516877273 -2.3972360561911179 235 | atom 1.2689430700000011 1.2689430700000011 1.0494925176094136 O 236 | velocity -9.1660885594522270 -7.5537430899539091 11.8940714423486131 237 | atom 3.8068292100000001 3.8068292100000001 1.0494925176094136 O 238 | 
velocity 10.4511377810829380 4.8608004447725763 1.2854396927072482 239 | atom 6.3447153500000004 6.3447153500000004 1.0494925176094136 O 240 | velocity -5.7586335284564560 -2.5983407062861477 3.8770736367077769 241 | atom 8.8826014899999990 8.8826014899999990 1.0494925176094136 O 242 | velocity -4.1062572046705599 -12.0812290468738173 3.2853286240770583 243 | atom 11.4204876300000038 11.4204876300000038 1.0494925176094136 O 244 | velocity 0.6431145686500094 1.7554769652874107 4.7702655305156494 245 | atom 13.9583737700000032 13.9583737700000032 1.0494925176094136 O 246 | velocity -5.4452414050169144 -3.1595847293613222 7.0782721896243279 247 | atom 13.9583737700000032 3.8068292100000001 1.0494925176094136 O 248 | velocity -1.2457966962579130 -6.6807259411816364 0.5431781745365515 249 | atom 1.2689430700000011 6.3447153500000004 1.0494925176094136 O 250 | velocity 6.8974540067516870 7.6999975838780585 3.6623557316566804 251 | atom 3.8068292100000001 8.8826014899999990 1.0494925176094136 O 252 | velocity 2.0088517807159803 5.4881700143578769 15.3439306814779179 253 | atom 6.3447153500000004 11.4204876300000020 1.0494925176094136 O 254 | velocity -0.9004339107992031 -3.3719609518370088 -5.1191131677262165 255 | atom 8.8826014900000043 13.9583737700000032 1.0494925176094136 O 256 | velocity 3.2072935658583241 -3.6592354458155580 3.8738541777973863 257 | atom 11.4204876300000038 1.2689430699999995 1.0494925176094136 O 258 | velocity 13.2976786276343280 0.0454408642193988 12.6891413019910022 259 | atom 11.4204876300000038 6.3447153500000004 1.0494925176094136 O 260 | velocity -1.0650178839606050 2.0578634838126248 -2.5546191580258033 261 | atom 13.9583737700000050 8.8826014899999990 1.0494925176094136 O 262 | velocity -9.6541863011583686 19.0568691696580288 -12.3149300755311799 263 | atom 1.2689430699999995 11.4204876300000038 1.0494925176094136 O 264 | velocity -3.4612992004287504 -4.6844507622590674 -17.5577416256374406 265 | atom 3.8068292100000001 13.9583737700000015 
1.0494925176094136 O 266 | velocity 5.7295753075712872 -4.2135283574004392 -27.1385985518327431 267 | atom 6.3447153500000022 1.2689430699999995 1.0494925176094136 O 268 | velocity 3.8832396844793857 3.4997770453003767 -1.1850282998798853 269 | atom 8.8826014899999990 3.8068292100000001 1.0494925176094136 O 270 | velocity 7.3624321383036140 -7.4102206688385328 2.1298402788113076 271 | atom 1.2689430700000011 1.2689430700000011 6.2362805776094143 O 272 | velocity -3.5272039576225147 3.4303303834921119 15.1387810826521942 273 | atom 3.8068292100000001 3.8068292100000001 6.2362805776094143 O 274 | velocity -4.5266272944560804 0.2719676248998419 1.2187834140689096 275 | atom 6.3447153500000004 6.3447153500000004 6.2362805776094143 O 276 | velocity 6.8119949571755365 4.3339961591729388 -0.9627361307979836 277 | atom 8.8826014899999990 8.8826014899999990 6.2362805776094143 O 278 | velocity 0.6611453169968187 1.3414128397064189 2.4899184277979884 279 | atom 11.4204876300000038 11.4204876300000038 6.2362805776094143 O 280 | velocity -2.8922683197521835 3.1230220876843737 11.4957675980423879 281 | atom 13.9583737700000032 13.9583737700000032 6.2362805776094143 O 282 | velocity 7.0586335385139307 -7.2778371594009732 -2.3758649577791142 283 | atom 13.9583737700000032 3.8068292100000001 6.2362805776094143 O 284 | velocity -8.1580207319402227 -7.5333235444053894 -2.3962800398143429 285 | atom 1.2689430700000011 6.3447153500000004 6.2362805776094143 O 286 | velocity -3.4991157652178302 -3.8213635305961380 -4.2628201035099504 287 | atom 3.8068292100000001 8.8826014899999990 6.2362805776094143 O 288 | velocity -12.9247279505666821 7.0329710529688239 11.2558534367839851 289 | atom 6.3447153500000004 11.4204876300000020 6.2362805776094143 O 290 | velocity -1.3815052871301303 -6.2902597376737077 7.1539027624872817 291 | atom 8.8826014900000043 13.9583737700000032 6.2362805776094143 O 292 | velocity -1.8827546382718423 -0.4356944141048841 -0.8778886468622976 293 | atom 
11.4204876300000038 1.2689430699999995 6.2362805776094143 O 294 | velocity -7.0214835292654056 -11.7367791369879519 -4.9621041619367050 295 | atom 11.4204876300000038 6.3447153500000004 6.2362805776094143 O 296 | velocity -13.9263315226871569 -3.3871513960817996 -2.4910489532130975 297 | atom 13.9583737700000050 8.8826014899999990 6.2362805776094143 O 298 | velocity 3.2075319573054406 -7.0003167796099577 -17.6928637961333557 299 | atom 1.2689430699999995 11.4204876300000038 6.2362805776094143 O 300 | velocity -11.1539101382172401 11.2400011826016879 3.4846379952336450 301 | atom 3.8068292100000001 13.9583737700000015 6.2362805776094143 O 302 | velocity 6.1551572599687665 -14.0695090119834312 14.9282628779358095 303 | atom 6.3447153500000022 1.2689430699999995 6.2362805776094143 O 304 | velocity 13.8344600488326588 4.1302640908447117 -9.5241583823609997 305 | atom 8.8826014899999990 3.8068292100000001 6.2362805776094143 O 306 | velocity 4.3588092796503792 16.4518353001358832 0.1719724526543752 307 | atom 1.2689430700000011 1.2689430700000011 11.4230686376094130 O 308 | velocity -1.5731092080031239 14.4617075895086877 5.6464191219764004 309 | atom 3.8068292100000001 3.8068292100000001 11.4230686376094130 O 310 | velocity 8.1829905470145814 4.7930497697333898 0.5501298289811878 311 | atom 6.3447153500000004 6.3447153500000004 11.4230686376094130 O 312 | velocity 3.9252469115082014 -1.8774004799650756 -5.1432625441124324 313 | atom 8.8826014899999990 8.8826014899999990 11.4230686376094130 O 314 | velocity 0.1079693245447899 -7.2188782837274639 5.5378080994603858 315 | atom 11.4204876300000038 11.4204876300000038 11.4230686376094130 O 316 | velocity -2.7015979876658758 2.6215305908816426 -12.5291355532003141 317 | atom 13.9583737700000032 13.9583737700000032 11.4230686376094130 O 318 | velocity -0.0473208478307459 3.0770000613750779 -5.2454599118439376 319 | atom 13.9583737700000032 3.8068292100000001 11.4230686376094130 O 320 | velocity 10.9299385369905568 
-2.9217611794129796 3.3207569563481072 321 | atom 1.2689430700000011 6.3447153500000004 11.4230686376094130 O 322 | velocity -12.1395885263285095 2.0275741721349427 -9.2708422153284111 323 | atom 3.8068292100000001 8.8826014899999990 11.4230686376094130 O 324 | velocity -5.5099974264548477 12.0219257108277748 -3.4795702201124916 325 | atom 6.3447153500000004 11.4204876300000020 11.4230686376094130 O 326 | velocity 6.0228566991049988 -0.8051698779171211 8.4769198754577744 327 | atom 8.8826014900000043 13.9583737700000032 11.4230686376094130 O 328 | velocity -10.8075186144300197 9.3284900944528975 -7.4443443880240681 329 | atom 11.4204876300000038 1.2689430699999995 11.4230686376094130 O 330 | velocity -11.9436904146195655 2.2091936409771464 -1.3609621550841347 331 | atom 11.4204876300000038 6.3447153500000004 11.4230686376094130 O 332 | velocity -15.8333746854382795 8.9078876805392646 -2.5823536623635781 333 | atom 13.9583737700000050 8.8826014899999990 11.4230686376094130 O 334 | velocity 0.5706270037040955 9.9022938572355290 -9.3852690387591213 335 | atom 1.2689430699999995 11.4204876300000038 11.4230686376094130 O 336 | velocity -8.4148596809087266 -8.9161162215002907 3.3962059376619265 337 | atom 3.8068292100000001 13.9583737700000015 11.4230686376094130 O 338 | velocity 3.2383535873546672 -14.7932357521853390 -5.2409003484824925 339 | atom 6.3447153500000022 1.2689430699999995 11.4230686376094130 O 340 | velocity -1.5146996462566045 8.4325475309404982 -5.4845692594672739 341 | atom 8.8826014899999990 3.8068292100000001 11.4230686376094130 O 342 | velocity 2.9193266119061825 0.1898376175764067 -5.5713129273176785 343 | atom 1.2689430700000011 1.2689430700000011 3.6428865476094141 O 344 | velocity -7.1342515768217947 10.1194854634224729 -14.4084055238780504 345 | atom 3.8068292100000001 3.8068292100000001 3.6428865476094141 O 346 | velocity -7.0647326453656571 9.6007651961895721 -0.1560157109318254 347 | atom 6.3447153500000004 6.3447153500000004 
3.6428865476094141 O 348 | velocity 1.6157213840645930 10.1200640874617704 -3.6569770622372220 349 | atom 8.8826014899999990 8.8826014899999990 3.6428865476094141 O 350 | velocity -4.8231217735800742 2.9240458260865525 -11.5650416631713693 351 | atom 11.4204876300000038 11.4204876300000038 3.6428865476094141 O 352 | velocity 2.1255590076615607 -4.7033470307096419 2.8370963518779284 353 | atom 13.9583737700000032 13.9583737700000032 3.6428865476094141 O 354 | velocity -13.7222980048816847 3.1284184371691874 -0.4910222524666513 355 | atom 13.9583737700000032 3.8068292100000001 3.6428865476094141 O 356 | velocity 4.4295712485792134 2.4668492982502661 -2.9019416497713517 357 | atom 1.2689430700000011 6.3447153500000004 3.6428865476094141 O 358 | velocity 12.7854660372388231 11.5132228952107258 -11.0492543126302216 359 | atom 3.8068292100000001 8.8826014899999990 3.6428865476094141 O 360 | velocity -1.0189251066170935 -7.6397511290338089 -11.2532489114257430 361 | atom 6.3447153500000004 11.4204876300000020 3.6428865476094141 O 362 | velocity 3.2992706257479751 -0.0766765147498132 11.8572126777842950 363 | atom 8.8826014900000043 13.9583737700000032 3.6428865476094141 O 364 | velocity -1.3670486791449232 2.7009185456365357 -12.0528036381672301 365 | atom 11.4204876300000038 1.2689430699999995 3.6428865476094141 O 366 | velocity 7.5348457560310784 4.2371653472187694 -4.7772801216414296 367 | atom 11.4204876300000038 6.3447153500000004 3.6428865476094141 O 368 | velocity -1.6292575759901264 -13.2403583899695150 -7.1732515428230545 369 | atom 13.9583737700000050 8.8826014899999990 3.6428865476094141 O 370 | velocity 0.3638136904601378 11.2053756965406386 9.6870215468679532 371 | atom 1.2689430699999995 11.4204876300000038 3.6428865476094141 O 372 | velocity 5.7217254016211854 -6.7525517474054695 3.9222495394484387 373 | atom 3.8068292100000001 13.9583737700000015 3.6428865476094141 O 374 | velocity 6.8539292566000656 -0.0765779420487294 -5.9489681213686074 375 | atom 
6.3447153500000022 1.2689430699999995 3.6428865476094141 O 376 | velocity -1.4489922514656675 5.9160801065356488 5.3744978810981809 377 | atom 8.8826014899999990 3.8068292100000001 3.6428865476094141 O 378 | velocity 2.9443020028612077 -4.9940133553659356 2.6123496945678233 379 | atom 1.2689430700000011 1.2689430700000011 8.8296746076094141 O 380 | velocity 8.0015626093223791 9.6327060673050209 -9.7136945257768197 381 | atom 3.8068292100000001 3.8068292100000001 8.8296746076094141 O 382 | velocity 6.4183327897736682 -4.0053821646168632 -4.9392056183080904 383 | atom 6.3447153500000004 6.3447153500000004 8.8296746076094141 O 384 | velocity -4.7664081405460026 -5.7679670786215640 -15.0562627632010528 385 | atom 8.8826014899999990 8.8826014899999990 8.8296746076094141 O 386 | velocity -10.8277282668554040 -10.5309692668149104 6.2161991594976573 387 | atom 11.4204876300000038 11.4204876300000038 8.8296746076094141 O 388 | velocity 13.4204964190180380 -8.0858950428541796 11.0070571744734611 389 | atom 13.9583737700000032 13.9583737700000032 8.8296746076094141 O 390 | velocity -4.3136330117399631 -9.7179660654521367 1.2743975880378622 391 | atom 13.9583737700000032 3.8068292100000001 8.8296746076094141 O 392 | velocity -9.2070057982670264 -8.8295043712467436 -7.5396246270762814 393 | atom 1.2689430700000011 6.3447153500000004 8.8296746076094141 O 394 | velocity 8.7003818199889409 7.8206254913746704 -3.3672952766684419 395 | atom 3.8068292100000001 8.8826014899999990 8.8296746076094141 O 396 | velocity -2.1567101458613180 -1.6905002555222424 -0.3593007342069916 397 | atom 6.3447153500000004 11.4204876300000020 8.8296746076094141 O 398 | velocity 8.5986027721976939 -8.3924672993731093 3.6460849882936501 399 | atom 8.8826014900000043 13.9583737700000032 8.8296746076094141 O 400 | velocity -3.7369562211877669 5.8662385941623416 -5.3892617719903040 401 | atom 11.4204876300000038 1.2689430699999995 8.8296746076094141 O 402 | velocity 5.9683088170349219 13.4910406974500621 
0.5471623469052094 403 | atom 11.4204876300000038 6.3447153500000004 8.8296746076094141 O 404 | velocity -11.4312934769396328 10.9485278091962961 -15.7917723823015894 405 | atom 13.9583737700000050 8.8826014899999990 8.8296746076094141 O 406 | velocity 1.6604995751670362 6.6653705630181097 2.7641975980482867 407 | atom 1.2689430699999995 11.4204876300000038 8.8296746076094141 O 408 | velocity -2.5023148765427039 -9.9279727295686815 -2.0938018011534987 409 | atom 3.8068292100000001 13.9583737700000015 8.8296746076094141 O 410 | velocity 0.4837496494789859 0.3804085090974915 -11.2911257311814097 411 | atom 6.3447153500000022 1.2689430699999995 8.8296746076094141 O 412 | velocity 0.6711639709629819 1.9377233664397249 0.1651181121503962 413 | atom 8.8826014899999990 3.8068292100000001 8.8296746076094141 O 414 | velocity -10.3405681837594337 4.6130327993262155 -5.6060789858474820 415 | atom 1.2689430700000011 1.2689430700000011 14.0164626676094120 O 416 | velocity 2.4295534847399014 -7.3491737964535444 -6.0687297044060529 417 | atom 3.8068292100000001 3.8068292100000001 14.0164626676094120 O 418 | velocity 2.0299872892987310 -8.3303647516551749 3.6889558970278995 419 | atom 6.3447153500000004 6.3447153500000004 14.0164626676094120 O 420 | velocity 18.4190300017067408 -9.1786965953480504 3.6604843380579744 421 | atom 8.8826014899999990 8.8826014899999990 14.0164626676094120 O 422 | velocity -0.3015538376647183 6.3216755813994459 7.0822791776771146 423 | atom 11.4204876300000038 11.4204876300000038 14.0164626676094120 O 424 | velocity 1.9205469652416627 3.6215256028622820 -6.5989495628796959 425 | atom 13.9583737700000032 13.9583737700000032 14.0164626676094120 O 426 | velocity -2.2029367928013972 -7.9841593931199624 -1.7072507421000904 427 | atom 13.9583737700000032 3.8068292100000001 14.0164626676094120 O 428 | velocity 1.2125050520904952 0.8554944780815864 -5.0085519302368091 429 | atom 1.2689430700000011 6.3447153500000004 14.0164626676094120 O 430 | velocity 
4.2222666159165732 2.1912299024121205 -2.0927357075040613 431 | atom 3.8068292100000001 8.8826014899999990 14.0164626676094120 O 432 | velocity -5.1905280800070992 -6.5776990299849247 -11.4925452024722787 433 | atom 6.3447153500000004 11.4204876300000020 14.0164626676094120 O 434 | velocity -7.1592426959864328 3.3558319426903460 -5.4882345288300511 435 | atom 8.8826014900000043 13.9583737700000032 14.0164626676094120 O 436 | velocity -0.9866328964691482 9.0671824549891920 3.7969331242679938 437 | atom 11.4204876300000038 1.2689430699999995 14.0164626676094120 O 438 | velocity 9.4790082238803333 -17.4200177972141219 -3.3067558985530856 439 | atom 11.4204876300000038 6.3447153500000004 14.0164626676094120 O 440 | velocity -2.2367947595994608 1.5081599880679029 -2.0600648639830883 441 | atom 13.9583737700000050 8.8826014899999990 14.0164626676094120 O 442 | velocity -2.6250176308835749 -5.8760182405450969 2.9435615001467763 443 | atom 1.2689430699999995 11.4204876300000038 14.0164626676094120 O 444 | velocity 5.6564947332560056 -9.6038844899670472 9.0036247353061878 445 | atom 3.8068292100000001 13.9583737700000015 14.0164626676094120 O 446 | velocity 11.0263492895902058 -1.6861419303172720 12.8313235172568234 447 | atom 6.3447153500000022 1.2689430699999995 14.0164626676094120 O 448 | velocity 3.7422673878331589 -7.4707130926875456 0.6581663989664271 449 | atom 8.8826014899999990 3.8068292100000001 14.0164626676094120 O 450 | velocity -1.7053320831577932 8.3230080345371391 6.1086765790427648 451 | atom 13.9583737700000032 1.2689430700000011 1.5439015123905850 O 452 | velocity -0.3521538608847399 -0.2263149093549525 12.8570026965465978 453 | atom 1.2689430700000011 3.8068292100000001 1.5439015123905850 O 454 | velocity -2.6418847346860934 -14.6318514999135818 -8.3635909419767245 455 | atom 3.8068292100000001 6.3447153500000004 1.5439015123905850 O 456 | velocity 6.1292355831158964 0.9508825426889151 -0.8319465104090895 457 | atom 6.3447153500000022 8.8826014900000043 
1.5439015123905850 O 458 | velocity 7.2581426357902323 12.5583359425693892 7.6520261531371849 459 | atom 8.8826014899999990 11.4204876300000038 1.5439015123905850 O 460 | velocity 1.1763132829067089 4.0833517245776259 -19.2394641506927293 461 | atom 11.4204876300000020 13.9583737700000032 1.5439015123905850 O 462 | velocity 3.2299147716716550 -13.4321401021369446 3.3432973389939176 463 | atom 11.4204876300000038 3.8068292100000001 1.5439015123905850 O 464 | velocity 2.6567118741037690 -2.5377510122142835 -6.1473181368457457 465 | atom 13.9583737700000032 6.3447153500000004 1.5439015123905850 O 466 | velocity 2.6918091238517210 -2.3320147681185595 3.2105752409445523 467 | atom 1.2689430699999995 8.8826014899999990 1.5439015123905850 O 468 | velocity 6.6846174955870037 -2.1060193667593845 -7.8733752409308160 469 | atom 3.8068292100000001 11.4204876300000038 1.5439015123905850 O 470 | velocity -4.8924149950218494 -3.1402803597918130 2.7094819344897347 471 | atom 6.3447153500000004 13.9583737700000032 1.5439015123905850 O 472 | velocity 7.6514568389422299 4.6941003483260832 -1.1742921703318225 473 | atom 8.8826014899999990 1.2689430699999995 1.5439015123905850 O 474 | velocity -4.7474571009028486 -7.2796762987773462 -3.2078293823848929 475 | atom 8.8826014900000043 6.3447153500000004 1.5439015123905850 O 476 | velocity 11.0453470053501732 -15.3842979025251267 0.7581446466147325 477 | atom 11.4204876300000038 8.8826014899999990 1.5439015123905850 O 478 | velocity -10.3625513407277534 6.5851936671153348 -0.4892306496458373 479 | atom 13.9583737700000050 11.4204876300000038 1.5439015123905850 O 480 | velocity 8.6469466240965964 -2.5571356799310490 -9.2566189287506440 481 | atom 1.2689430700000011 13.9583737700000032 1.5439015123905850 O 482 | velocity 6.5175006284102874 7.5823855082369533 9.1366589480262572 483 | atom 3.8068292100000001 1.2689430699999995 1.5439015123905850 O 484 | velocity -13.4744599939509513 -1.5158528825682021 -3.8381737665176971 485 | atom 
6.3447153500000004 3.8068292100000001 1.5439015123905850 O 486 | velocity -3.4713657526391457 -0.5326404567614987 -5.5359831339287329 487 | atom 13.9583737700000032 1.2689430700000011 6.7306895723905846 O 488 | velocity -1.0002400757191054 -2.4199744439001192 1.0883614758733631 489 | atom 1.2689430700000011 3.8068292100000001 6.7306895723905846 O 490 | velocity 9.2013750999198525 2.7988447679789470 -10.6010441724685371 491 | atom 3.8068292100000001 6.3447153500000004 6.7306895723905846 O 492 | velocity -7.0596849352604503 3.2777242302728258 0.3407610166675124 493 | atom 6.3447153500000022 8.8826014900000043 6.7306895723905846 O 494 | velocity -7.5023266436961729 13.1555073758821628 -7.4420221183135835 495 | atom 8.8826014899999990 11.4204876300000038 6.7306895723905846 O 496 | velocity 5.3081184174547626 7.1543777657070384 3.7880363685692142 497 | atom 11.4204876300000020 13.9583737700000032 6.7306895723905846 O 498 | velocity 12.5072859845846853 3.2637228178101254 1.4229933653089237 499 | atom 11.4204876300000038 3.8068292100000001 6.7306895723905846 O 500 | velocity 7.7798212114528642 5.7679011117902581 4.9190397854668007 501 | atom 13.9583737700000032 6.3447153500000004 6.7306895723905846 O 502 | velocity -7.9412004366841273 -6.4017714735973357 3.5338779909581097 503 | atom 1.2689430699999995 8.8826014899999990 6.7306895723905846 O 504 | velocity -4.1177356799483436 5.6819127130626192 -5.4391968359563965 505 | atom 3.8068292100000001 11.4204876300000038 6.7306895723905846 O 506 | velocity -6.2227911864330290 -6.4424802414295224 -2.3107746408518937 507 | atom 6.3447153500000004 13.9583737700000032 6.7306895723905846 O 508 | velocity 3.1854754214418510 0.5156499995504454 0.7208791368180817 509 | atom 8.8826014899999990 1.2689430699999995 6.7306895723905846 O 510 | velocity 4.8004807372907186 -4.3914501416926397 3.2065651090892988 511 | atom 8.8826014900000043 6.3447153500000004 6.7306895723905846 O 512 | velocity 1.8886259787991011 -9.6289130756482244 
-2.5068627482657364 513 | atom 11.4204876300000038 8.8826014899999990 6.7306895723905846 O 514 | velocity 3.5066783732840920 -1.3540170779380891 -7.8132750754184705 515 | atom 13.9583737700000050 11.4204876300000038 6.7306895723905846 O 516 | velocity 2.1236097222699084 2.2689662140626008 0.5814691592435226 517 | atom 1.2689430700000011 13.9583737700000032 6.7306895723905846 O 518 | velocity -2.8130796196209573 -8.4928034077466279 -2.8238143878303208 519 | atom 3.8068292100000001 1.2689430699999995 6.7306895723905846 O 520 | velocity 7.6512306708986157 0.0187160288006447 -8.1397300679789666 521 | atom 6.3447153500000004 3.8068292100000001 6.7306895723905846 O 522 | velocity 2.3488856239977771 7.0392410486622463 -7.8497082399073621 523 | atom 13.9583737700000032 1.2689430700000011 11.9174776323905842 O 524 | velocity 4.7258448023009940 -2.4054342093981385 2.0500241804466022 525 | atom 1.2689430700000011 3.8068292100000001 11.9174776323905842 O 526 | velocity -0.3280360230992813 1.5418192638813102 13.1668816565579370 527 | atom 3.8068292100000001 6.3447153500000004 11.9174776323905842 O 528 | velocity -7.7739952494976547 3.0852794806787491 -1.4661362059003333 529 | atom 6.3447153500000022 8.8826014900000043 11.9174776323905842 O 530 | velocity -0.6523779426851724 -4.7059239826259773 -6.4723261414377076 531 | atom 8.8826014899999990 11.4204876300000038 11.9174776323905842 O 532 | velocity -1.2668877804400061 -3.5102037661275234 3.0354564671993640 533 | atom 11.4204876300000020 13.9583737700000032 11.9174776323905842 O 534 | velocity -15.5979197040320106 4.7737982788635289 11.8996235504803689 535 | atom 11.4204876300000038 3.8068292100000001 11.9174776323905842 O 536 | velocity 12.3412532868100424 -9.3741247530507099 8.0805425778494904 537 | atom 13.9583737700000032 6.3447153500000004 11.9174776323905842 O 538 | velocity 11.1836864108869047 3.2399766074610885 2.4869487358951519 539 | atom 1.2689430699999995 8.8826014899999990 11.9174776323905842 O 540 | velocity 
-7.8797175532312504 6.2962507946239441 3.3790316811290793 541 | atom 3.8068292100000001 11.4204876300000038 11.9174776323905842 O 542 | velocity -7.5103736216329775 1.0921474835571783 -6.2620980077269150 543 | atom 6.3447153500000004 13.9583737700000032 11.9174776323905842 O 544 | velocity -0.0622435386642375 -6.1067437502584454 -2.7800642859593108 545 | atom 8.8826014899999990 1.2689430699999995 11.9174776323905842 O 546 | velocity 2.6929514904954206 8.2237675729433288 -17.8010607644783470 547 | atom 8.8826014900000043 6.3447153500000004 11.9174776323905842 O 548 | velocity 2.1406649094360639 -5.8546644338673630 -5.3781954900404774 549 | atom 11.4204876300000038 8.8826014899999990 11.9174776323905842 O 550 | velocity 0.5348601241533025 -15.2917669394792632 2.1549276229281036 551 | atom 13.9583737700000050 11.4204876300000038 11.9174776323905842 O 552 | velocity -4.5992621012257793 16.4602482349719459 -12.3114602734893595 553 | atom 1.2689430700000011 13.9583737700000032 11.9174776323905842 O 554 | velocity 5.1719230624182195 -11.0468619746612031 3.7107065567743902 555 | atom 3.8068292100000001 1.2689430699999995 11.9174776323905842 O 556 | velocity 10.7403430300879315 -18.1305143335417362 6.4692762747414330 557 | atom 6.3447153500000004 3.8068292100000001 11.9174776323905842 O 558 | velocity -0.3937904655760134 4.4820858829862482 -1.4331021451315660 559 | atom 13.9583737700000032 1.2689430700000011 4.1372955423905839 O 560 | velocity 2.3668985433670899 0.1476942167919024 9.8095972408522485 561 | atom 1.2689430700000011 3.8068292100000001 4.1372955423905839 O 562 | velocity -7.2228183976445424 8.8669536439268555 4.1009913682227772 563 | atom 3.8068292100000001 6.3447153500000004 4.1372955423905839 O 564 | velocity -5.4032588532899641 -14.9174792902691760 7.3986143911899012 565 | atom 6.3447153500000022 8.8826014900000043 4.1372955423905839 O 566 | velocity -2.8543310629590604 -3.8621556322530108 -3.7200571844595047 567 | atom 8.8826014899999990 11.4204876300000038 
4.1372955423905839 O 568 | velocity -6.0994387718011414 7.8523390023679083 1.1657509919271229 569 | atom 11.4204876300000020 13.9583737700000032 4.1372955423905839 O 570 | velocity 1.8758068007087425 10.2533854389311632 -12.5512666657439702 571 | atom 11.4204876300000038 3.8068292100000001 4.1372955423905839 O 572 | velocity -11.9442181718117677 -3.9847887076726054 5.3110891643501681 573 | atom 13.9583737700000032 6.3447153500000004 4.1372955423905839 O 574 | velocity 0.9916447291151665 1.8135576493250838 0.0351077959046652 575 | atom 1.2689430699999995 8.8826014899999990 4.1372955423905839 O 576 | velocity -2.1070525915974803 5.7471430154552650 2.0417891219654871 577 | atom 3.8068292100000001 11.4204876300000038 4.1372955423905839 O 578 | velocity -8.1032867928969488 -3.4006475990600609 15.6475171685948915 579 | atom 6.3447153500000004 13.9583737700000032 4.1372955423905839 O 580 | velocity 5.6286730819708861 -0.1158625218584759 0.5568166012940949 581 | atom 8.8826014899999990 1.2689430699999995 4.1372955423905839 O 582 | velocity 3.1492833269485669 -1.3874368087444751 -7.6380182207200651 583 | atom 8.8826014900000043 6.3447153500000004 4.1372955423905839 O 584 | velocity -5.1667617950076083 -2.3079170212989144 1.2837320257519156 585 | atom 11.4204876300000038 8.8826014899999990 4.1372955423905839 O 586 | velocity -14.7226102410234425 -5.5093185891691352 10.0105522118223789 587 | atom 13.9583737700000050 11.4204876300000038 4.1372955423905839 O 588 | velocity 7.3974212441911114 2.5404470549909548 4.3808202535021712 589 | atom 1.2689430700000011 13.9583737700000032 4.1372955423905839 O 590 | velocity -1.8192849685215620 -9.6142049271284975 2.9569601335503499 591 | atom 3.8068292100000001 1.2689430699999995 4.1372955423905839 O 592 | velocity 2.2616909755904442 -8.6349231592391735 -1.5329392987051687 593 | atom 6.3447153500000004 3.8068292100000001 4.1372955423905839 O 594 | velocity 9.0011509416699855 -3.6473797915457982 -0.1695839059494839 595 | atom 
13.9583737700000032 1.2689430700000011 9.3240836023905835 O 596 | velocity 13.7874294928194630 -6.0117939601791521 4.3796069080421933 597 | atom 1.2689430700000011 3.8068292100000001 9.3240836023905835 O 598 | velocity 4.9505660160462543 -0.0587824345455037 8.6811893223225702 599 | atom 3.8068292100000001 6.3447153500000004 9.3240836023905835 O 600 | velocity 16.7587266974382025 -3.5647184712770259 -3.7411235858553451 601 | atom 6.3447153500000022 8.8826014900000043 9.3240836023905835 O 602 | velocity 0.0890246909381031 -3.4673125250648678 10.5079756993917321 603 | atom 8.8826014899999990 11.4204876300000038 9.3240836023905835 O 604 | velocity -5.0209844671691481 -8.3610671931180534 -9.3529449222510532 605 | atom 11.4204876300000020 13.9583737700000032 9.3240836023905835 O 606 | velocity 11.1566134411352955 -1.4278722406744526 2.9231402721251669 607 | atom 11.4204876300000038 3.8068292100000001 9.3240836023905835 O 608 | velocity 5.3432973873381844 -19.5255850630742707 8.4944692760867930 609 | atom 13.9583737700000032 6.3447153500000004 9.3240836023905835 O 610 | velocity 3.2448655325661289 7.4814840653375363 5.9624703388634321 611 | atom 1.2689430699999995 8.8826014899999990 9.3240836023905835 O 612 | velocity -13.8977343002437532 3.7411532883783756 -12.1236190597745708 613 | atom 3.8068292100000001 11.4204876300000038 9.3240836023905835 O 614 | velocity -6.2019723890710345 0.7239259445393957 2.2754565929201482 615 | atom 6.3447153500000004 13.9583737700000032 9.3240836023905835 O 616 | velocity 1.5129830300980267 2.2309662335478251 -1.3269250404554858 617 | atom 8.8826014899999990 1.2689430699999995 9.3240836023905835 O 618 | velocity -10.0624710145050038 3.4915361824312141 5.8506800046626681 619 | atom 8.8826014900000043 6.3447153500000004 9.3240836023905835 O 620 | velocity 14.6461107951736995 -4.6447745614011016 14.2455033816090317 621 | atom 11.4204876300000038 8.8826014899999990 9.3240836023905835 O 622 | velocity -2.7570136708272632 0.1837710715135169 
-1.2384404146057040 623 | atom 13.9583737700000050 11.4204876300000038 9.3240836023905835 O 624 | velocity 2.3068244990937354 -15.4847812781141609 -1.4282135069569046 625 | atom 1.2689430700000011 13.9583737700000032 9.3240836023905835 O 626 | velocity 0.8113841045178914 -3.5180161280940099 4.9215015238592947 627 | atom 3.8068292100000001 1.2689430699999995 9.3240836023905835 O 628 | velocity 0.1866060691663847 2.2759306721023465 6.1627456114294548 629 | atom 6.3447153500000004 3.8068292100000001 9.3240836023905835 O 630 | velocity -16.6310607223196172 -4.3225289756534080 3.8040125341521014 631 | atom 13.9583737700000032 1.2689430700000011 14.5108716623905813 O 632 | velocity -13.5775830222978762 -3.9715319333223662 -0.9101338355292713 633 | atom 1.2689430700000011 3.8068292100000001 14.5108716623905813 O 634 | velocity 11.0124137941101061 10.5870928525617458 -7.7390631023284202 635 | atom 3.8068292100000001 6.3447153500000004 14.5108716623905813 O 636 | velocity -6.6072435239492391 5.2896193114737127 -4.1931947640527127 637 | atom 6.3447153500000022 8.8826014900000043 14.5108716623905813 O 638 | velocity 1.0648638280664589 -2.4221923422757654 9.0094857419542773 639 | atom 8.8826014899999990 11.4204876300000038 14.5108716623905813 O 640 | velocity 0.9482837730385327 -0.8025289226781915 4.6693746417182576 641 | atom 11.4204876300000020 13.9583737700000032 14.5108716623905813 O 642 | velocity 8.7897444463594336 2.0304625262230092 -1.2727225848388475 643 | atom 11.4204876300000038 3.8068292100000001 14.5108716623905813 O 644 | velocity 11.8530997215694107 4.4842068803560506 -6.2383852329364844 645 | atom 13.9583737700000032 6.3447153500000004 14.5108716623905813 O 646 | velocity 3.7058430677658993 2.7462643508654572 -5.5946423447767284 647 | atom 1.2689430699999995 8.8826014899999990 14.5108716623905813 O 648 | velocity -2.6346271551432068 2.5278710170661189 -3.3181956730146376 649 | atom 3.8068292100000001 11.4204876300000038 14.5108716623905813 O 650 | velocity 
14.1072769474807753 -5.0766325482500836 11.9260067299135191 651 | atom 6.3447153500000004 13.9583737700000032 14.5108716623905813 O 652 | velocity 9.5430544639099253 -9.2976735033215121 0.8852098394123070 653 | atom 8.8826014899999990 1.2689430699999995 14.5108716623905813 O 654 | velocity 0.0357185727243592 -2.4696620449659266 -0.0623126898690627 655 | atom 8.8826014900000043 6.3447153500000004 14.5108716623905813 O 656 | velocity 3.3215708343680492 4.6256774320556193 9.9537366410762491 657 | atom 11.4204876300000038 8.8826014899999990 14.5108716623905813 O 658 | velocity -1.8195858409885133 -22.1070138275505101 4.9145534438228786 659 | atom 13.9583737700000050 11.4204876300000038 14.5108716623905813 O 660 | velocity 1.5713643389587826 5.7586643721998430 8.9519132945090014 661 | atom 1.2689430700000011 13.9583737700000032 14.5108716623905813 O 662 | velocity 4.1749122107987731 12.2931983822067359 -4.6310461158817322 663 | atom 3.8068292100000001 1.2689430699999995 14.5108716623905813 O 664 | velocity -3.4468321835801112 -11.4511310686508754 -1.4683258159071526 665 | atom 6.3447153500000004 3.8068292100000001 14.5108716623905813 O 666 | velocity 0.7342302340189361 19.7130969321527303 -15.1900460570633129 667 | -------------------------------------------------------------------------------- /tests/assets/zro2_n_96.in: -------------------------------------------------------------------------------- 1 | #======================================================= 2 | # FHI-aims file: starting_geometries/geometry_2400_0.in 3 | # Created using the Atomic Simulation Environment (ASE) 4 | # Wed Jan 13 17:56:38 2021 5 | 6 | # Additional information: 7 | # starting geometry from CC, T=2400, ens=0 8 | #======================================================= 9 | lattice_vector 10.1515389999999996 0.0000000000000000 0.0000000000000000 10 | lattice_vector 0.0000000000000000 10.1515500000000003 0.0000000000000000 11 | lattice_vector 0.0000000000000000 0.0000000000000000 
10.3735759999999999 12 | atom_frac 0.3234503891478918 0.5928700799385315 0.7423298185697970 Zr 13 | velocity 0.5713616863993055 0.6945967180438689 -6.9826345736665640 14 | atom_frac 0.8335255147027461 0.5535007580123232 0.7746102751837938 Zr 15 | velocity -1.3607901101886146 0.8050175163512294 -2.1648154913140978 16 | atom_frac 0.3349683895220223 0.0537455521570598 0.7418256115345373 Zr 17 | velocity 2.7270978701608874 -2.9334589506254316 3.8087937754478385 18 | atom_frac 0.7792406481421192 0.0746522698504169 0.7209692665287264 Zr 19 | velocity 4.7484388251808358 2.3931682133953212 -2.1262144585453226 20 | atom_frac 0.3168908083789069 0.5518454048889087 0.2621888122283000 Zr 21 | velocity 0.2869084107445167 4.3560551539599341 3.7044368081627908 22 | atom_frac 0.8180810269260652 0.5584295127345085 0.2680236101803274 Zr 23 | velocity 7.4509275257964180 -4.4911954274459287 1.8584698331640266 24 | atom_frac 0.3268082031699824 0.0635500105895159 0.2538195960582927 Zr 25 | velocity -5.2594287223610143 -6.8999253479553584 -2.9591886089246477 26 | atom_frac 0.8103185231323055 0.0602214666725771 0.2469619271117308 Zr 27 | velocity 2.1841307816437023 2.6339607517787718 -0.6140877576634283 28 | atom_frac 0.3079483701929334 0.7961881131452833 0.0128262288722809 Zr 29 | velocity -2.6056875783807758 3.1872306517775137 -3.3169113777857828 30 | atom_frac 0.8087460886472485 0.8054861385699721 0.9735624079873710 Zr 31 | velocity -0.0613724434862874 -0.3336224978003149 -3.5791277759477556 32 | atom_frac 0.2832146662688289 0.3029634252897341 0.0015553113024862 Zr 33 | velocity -2.7208190639946843 7.2579085237542973 -1.8866189495531311 34 | atom_frac 0.8213556161287467 0.2791528042515675 0.9974229426766622 Zr 35 | velocity 4.5702013119187850 -0.5149573743058480 5.8049915577627393 36 | atom_frac 0.3183746444750890 0.8028702976392766 0.4839715263087676 Zr 37 | velocity 8.7290949188013176 -4.5338290211459364 -6.7875338285808455 38 | atom_frac 0.8323534904412032 0.7974899389748363 
0.4907984440466817 Zr 39 | velocity 0.5181432458843832 0.0453741836167826 -4.9451936905600515 40 | atom_frac 0.3201507722129620 0.3041760213957475 0.5216971061859478 Zr 41 | velocity -12.7773985111097037 -2.9030406325919449 1.2765403706717096 42 | atom_frac 0.8296376539557204 0.3142970502041560 0.5111824273519565 Zr 43 | velocity 1.8504068058179413 0.3577455322399059 0.6349468981283710 44 | atom_frac 0.5808748594671213 0.5486954415828125 0.9896472923126991 Zr 45 | velocity 4.4082519711592880 -1.9703219771194609 4.2798691624951264 46 | atom_frac 0.0594704674828122 0.5562166959725362 0.0227198017347152 Zr 47 | velocity -8.8323464629696833 4.5248582448560057 -6.0264820928584655 48 | atom_frac 0.5713110209200792 0.0616910274785624 0.0186708999866584 Zr 49 | velocity -4.5775049208885417 -1.4166370056757569 2.0423700813699837 50 | atom_frac 0.0848147507486303 0.0487826026567372 0.9763443368034320 Zr 51 | velocity 1.4112150911946761 5.8810408536289929 5.3256636301151170 52 | atom_frac 0.5890285512374036 0.5566270904443164 0.5062917898321658 Zr 53 | velocity 0.1130456561201160 -3.9907536974411753 5.3423264260287082 54 | atom_frac 0.0846452798930290 0.5671453137698185 0.5306978393950167 Zr 55 | velocity -2.7183749038797993 1.9734006273246707 6.6884960465310170 56 | atom_frac 0.5734680682406875 0.0819701464308406 0.4674270617962407 Zr 57 | velocity 0.5716415810857182 -0.2907028219248353 1.6042013003790425 58 | atom_frac 0.0361458681289605 0.0396957252833311 0.4861559899884090 Zr 59 | velocity 2.6514882930560146 4.4944846153761393 -0.7497333280991189 60 | atom_frac 0.5483013728263272 0.8146331210504800 0.7427069749139544 Zr 61 | velocity -2.8911481609450886 7.2180651002658767 4.8087585717189469 62 | atom_frac 0.0594988355952728 0.8321377710792933 0.7456877772910711 Zr 63 | velocity 0.1060819990526292 10.0406925145668051 -1.9363797763288200 64 | atom_frac 0.5530500577301629 0.3205775167338978 0.7443758372233451 Zr 65 | velocity -5.2944216233288053 -7.6661494983393625 
4.3372817318187469 66 | atom_frac 0.0674866776357752 0.3268211642557048 0.7471192200259582 Zr 67 | velocity -3.1964861285712680 5.6497240419042258 -1.8889174335457974 68 | atom_frac 0.5552862654618183 0.7904836443695790 0.2308418697660286 Zr 69 | velocity 0.1499474127537301 2.6575213849627097 -7.7057251975440852 70 | atom_frac 0.0657522972625136 0.8108545916633421 0.2286983938807601 Zr 71 | velocity -3.0323446735857464 -0.9661928987743672 7.6388425729202307 72 | atom_frac 0.5572237864623285 0.3241477065078732 0.2621454433842293 Zr 73 | velocity 0.6176741857661247 -1.3489515633473188 -4.1610198586667870 74 | atom_frac 0.0634479126760977 0.3089673833059976 0.2477965284102608 Zr 75 | velocity -0.1567235051519292 -5.7116583178500582 3.6004030887618157 76 | atom_frac 0.4271954488871097 0.7120571784604321 0.8932203301927898 O 77 | velocity 11.5126094716460372 -2.9379140039110796 9.5538698054652542 78 | atom_frac 0.8958982918747592 0.6335581699346405 0.9154732745969181 O 79 | velocity 8.2857752723281237 7.3281802651555505 -11.7830737127942111 80 | atom_frac 0.4924084722523354 0.1809722741847304 0.9092329771334398 O 81 | velocity 2.9046503152353509 11.0333644834298212 -7.9739792190934926 82 | atom_frac 0.9574674342481472 0.1814120562869710 0.8855749868704871 O 83 | velocity 13.7403316982469850 -4.1130728843105215 13.2951063920476624 84 | atom_frac 0.4609495703065319 0.6866003733420020 0.3618528065924422 O 85 | velocity -14.5800056728314864 10.3261601918467090 -0.1679295929380796 86 | atom_frac 0.9438594965748544 0.6654330767222739 0.4056178968563974 O 87 | velocity 1.0327990551678965 0.0175595548642565 -0.7023336145299418 88 | atom_frac 0.4298804269973253 0.2059347872984914 0.3840154918612443 O 89 | velocity 5.2459530777620342 -16.0393706814456820 -12.8124104278983371 90 | atom_frac 0.9455493211423411 0.1905994404795327 0.3741144384540104 O 91 | velocity 20.5355056607847999 2.3948546048745207 2.7535386472463075 92 | atom_frac 0.4444653357485994 0.6706019799932030 
0.1163907952281836 O 93 | velocity -13.5118274385632624 7.2559146945978306 8.1489450936098802 94 | atom_frac 0.9515191666997488 0.6643712369047091 0.1405775452939275 O 95 | velocity 0.7121827805314864 -2.3693325047451519 1.3004471974363767 96 | atom_frac 0.4524207974771116 0.1627770704966237 0.1517602348505472 O 97 | velocity -11.3684565388520724 4.9590830673126325 -10.7583762587430236 98 | atom_frac 0.9610971587657793 0.2000251597046756 0.1068794897728614 O 99 | velocity 18.8467883075062979 -30.5388429798784280 -19.6354409147077966 100 | atom_frac 0.4687372949067132 0.6863881289064232 0.6286452347772841 O 101 | velocity -26.9614836970991654 21.6737069794785349 -7.1777478976015443 102 | atom_frac 0.0059719516420121 0.6683810068413198 0.6739393310464973 O 103 | velocity -1.9795755500503349 -2.7798221905166227 3.4929613065018859 104 | atom_frac 0.4740572951549514 0.1461070447370106 0.6335871930759460 O 105 | velocity 5.9662661918256736 0.0708842734038097 12.7970245724097111 106 | atom_frac 0.9337810759531142 0.1777779363742483 0.6345032012104601 O 107 | velocity -3.1181860946287729 4.1823809597517716 -13.7619087484147933 108 | atom_frac 0.6511103873018663 0.6582613728937946 0.8457262404015742 O 109 | velocity 7.6164487602619841 -0.0679608099094410 1.8738549788445318 110 | atom_frac 0.1872673887181047 0.6723658633410662 0.8713402070799885 O 111 | velocity -12.9813065403579220 22.0518613672857882 15.6656828075185484 112 | atom_frac 0.7343179433187422 0.1496830385507632 0.8993010298473737 O 113 | velocity 11.6424078445315047 -2.0091651986075383 -6.3157731490877103 114 | atom_frac 0.1935689662424584 0.1756135309386251 0.8766046086711081 O 115 | velocity 8.0193568555443733 15.7836109865161909 -9.6234239880721617 116 | atom_frac 0.6976276119315505 0.6536494141288769 0.3860640814700735 O 117 | velocity 4.4919489736292002 -14.7383112914306924 -3.4204048488625833 118 | atom_frac 0.1980342507672974 0.6799383946293917 0.3893132724915690 O 119 | velocity 42.6825771014840996 
1.4145185728115526 -4.2292385259586638 120 | atom_frac 0.7055327285843064 0.1834304268806241 0.3545185305433728 O 121 | velocity -23.4708353105369802 3.6913107986409144 -8.7858468962831289 122 | atom_frac 0.1899861085102466 0.1726435490146825 0.3671679476778307 O 123 | velocity -8.9514447202983387 -24.2635598806931121 -8.0328235856182246 124 | atom_frac 0.7147223696820748 0.6624527515502558 0.1134792254859848 O 125 | velocity -18.6284530892564000 2.0880369405231867 -6.8726899877586884 126 | atom_frac 0.2107551978079383 0.6789413724997660 0.1119815278742836 O 127 | velocity -6.7974094130603646 -18.1197506183793884 -1.1351024810555155 128 | atom_frac 0.6908874634673620 0.1812809373937970 0.1132758867337551 O 129 | velocity 4.9003212843341153 -23.4252773437961892 -6.6844464891995807 130 | atom_frac 0.2106132094847885 0.1532061606355680 0.1120646361486145 O 131 | velocity 11.5959464580720066 -12.0646635374607563 -2.9100512506719816 132 | atom_frac 0.7390910688517278 0.7107236500829922 0.6318223436161262 O 133 | velocity 8.9577609877862603 6.9354529237039335 -2.7393361950752197 134 | atom_frac 0.2044507616037332 0.7212750387871802 0.6303907090476804 O 135 | velocity 16.6198136865647399 -13.6646979545384486 10.0800940888404416 136 | atom_frac 0.6978965041655262 0.2152198432751649 0.6245685518667815 O 137 | velocity 10.6407191852658602 5.2719181921957761 2.7572888058761258 138 | atom_frac 0.2085892119411648 0.2136788746546094 0.6456768861576760 O 139 | velocity -12.8660483182505931 3.9911681290243259 11.9196426805194378 140 | atom_frac 0.4632009688383210 0.9539985844526206 0.8624587162613933 O 141 | velocity -0.1970943432763341 -8.4188760439811183 23.2464690147692288 142 | atom_frac 0.9248510772603051 0.9107849658426546 0.8438587069685515 O 143 | velocity 6.1787793058408917 2.9960102880114809 -9.6581418668737644 144 | atom_frac 0.4258722288315103 0.4327677655136408 0.8513009207239624 O 145 | velocity 1.6499487043670586 -17.9545284526808224 2.6640156702116795 146 | 
atom_frac 0.9381538158893938 0.4181898882436672 0.8511057401999078 O 147 | velocity -6.2508351305494232 -4.9081894180285746 -2.0487005721010121 148 | atom_frac 0.4715441914767801 0.9498360033689437 0.3834556598418906 O 149 | velocity 10.1663336036997514 -7.1644989653850768 -11.2703215604623974 150 | atom_frac 0.9518162418525902 0.9362535701444608 0.3386724973143302 O 151 | velocity 5.9714157423390866 2.4742109971225377 -3.5553205300408228 152 | atom_frac 0.4734646992933781 0.4696524550438110 0.3598325697907838 O 153 | velocity -6.4422682032680756 -5.6801995128792075 12.4368141373326537 154 | atom_frac 0.0069012373394812 0.4650788451024720 0.3534951852668742 O 155 | velocity 8.6098742776280694 -32.3098665878795686 17.3065251828291800 156 | atom_frac 0.4330538847360977 0.9017546187528014 0.1332885564245155 O 157 | velocity -10.7760718712019727 7.9305908363129634 -4.5037970484097007 158 | atom_frac 0.9405227512793874 0.9194520383586744 0.0875583308976577 O 159 | velocity 15.5725163522646906 -1.2387385810852587 -20.6908633599528962 160 | atom_frac 0.4298099293121960 0.4208812496613817 0.1203142175851413 O 161 | velocity -22.0500954581887996 0.9290171722431881 -4.9958748935916981 162 | atom_frac 0.9505291207569612 0.4224778137328783 0.0979028832487466 O 163 | velocity -9.6056364710230397 3.0263949061457329 18.7005079369894389 164 | atom_frac 0.3828565038266612 0.9123754963527737 0.6339854106240702 O 165 | velocity 7.0884520008192569 14.9981073439590222 -6.7909784659221053 166 | atom_frac 0.9205779409407776 0.9275362540695756 0.6305808517718480 O 167 | velocity 20.8972880748579151 8.7791640237780584 -8.9474403086150112 168 | atom_frac 0.4376490776423161 0.4213227053996680 0.5995198521705533 O 169 | velocity -0.5579611007103814 0.4477758833775834 -0.7252374768403282 170 | atom_frac 0.9327426688702080 0.4403553683920189 0.6211520453506101 O 171 | velocity -11.9885402097383302 10.6256089033067482 -3.8719114811714301 172 | atom_frac 0.7070193780470134 0.9086888780531053 
0.8519929540208698 O 173 | velocity 14.3738791924278342 -10.2911924666323333 19.7870438193027880 174 | atom_frac 0.2140161348934383 0.9235337381976152 0.8807014774847168 O 175 | velocity -7.6598732382425476 8.7297986458502255 -1.8789851986331869 176 | atom_frac 0.7057862635409272 0.3976496505459757 0.8716853368597289 O 177 | velocity -19.5830750992321647 -11.7357202073228475 -6.2365125762749152 178 | atom_frac 0.1615792886182086 0.4129942787062075 0.8920597506587892 O 179 | velocity -13.7353777690631969 6.8198345734235453 -19.8429001957091593 180 | atom_frac 0.7410336491836361 0.9427599903463019 0.4092117163840127 O 181 | velocity 10.0911916866000819 -35.6306281453462859 16.0188243982140008 182 | atom_frac 0.1961722158581078 0.9204481975658890 0.3663380525674078 O 183 | velocity -1.1243337274426830 -2.9128038277498480 -27.5669171520166394 184 | atom_frac 0.6992730383048325 0.4095198654392679 0.4042878955145265 O 185 | velocity 4.9474412561202499 -14.2536895948260227 -7.7312833907263130 186 | atom_frac 0.2393471827276633 0.4549846102319350 0.4006424602277941 O 187 | velocity -8.7935962474951150 -10.8036246432117782 0.5354247450927119 188 | atom_frac 0.6854889342394292 0.9342950514945993 0.1290158051572572 O 189 | velocity -7.6980033796485676 0.7301569414230618 11.3512297725917346 190 | atom_frac 0.1778497378574815 0.9294099846821422 0.1358533759235966 O 191 | velocity 5.6403301546621041 14.1450667078430250 -1.6733660279495710 192 | atom_frac 0.6691993174630960 0.4269615763110067 0.1336563399159557 O 193 | velocity 5.2937414308830517 -4.7039893331805596 1.6753612063963361 194 | atom_frac 0.1750481281705168 0.4225274032044367 0.1436459086047087 O 195 | velocity 0.9900954521188362 10.1220974651784594 6.4911813220396475 196 | atom_frac 0.6418049755805501 0.9571137304155524 0.6265642503607242 O 197 | velocity 1.3751897561995075 1.7476417811656049 10.2958801195367240 198 | atom_frac 0.1575063761268119 0.9539916909240460 0.5880714712072288 O 199 | velocity 
2.0877493450491840 -4.2518340236823873 5.6316839238620737 200 | atom_frac 0.6746895963262319 0.4583917795804582 0.6464185031275618 O 201 | velocity 10.9664790369913394 -14.0070789614973954 2.3977741422214778 202 | atom_frac 0.2028737593383624 0.4642065231417862 0.6261392734771500 O 203 | velocity 1.2142801480541554 24.6577154233525384 17.2481525639260447 204 | -------------------------------------------------------------------------------- /tests/interaction_test_case.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from schnax.utils.schnetkit import get_interaction_count 4 | 5 | 6 | class InteractionTestCase(TestCase): 7 | 8 | def __init__(self, method_name: str, geometry_file: str, weights_file: str): 9 | super().__init__(method_name) 10 | self.geometry_file = geometry_file 11 | self.weights_file = weights_file 12 | self.n_interactions = get_interaction_count(weights_file) -------------------------------------------------------------------------------- /tests/test_cfconv.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import torch 5 | from jax_md.partition import NeighborList 6 | import jax.numpy as jnp 7 | 8 | from schnax.model.interaction.cfconv import CFConv 9 | 10 | import tests.test_utils.initialize as init 11 | import tests.test_utils.activation as activation 12 | from tests.interaction_test_case import InteractionTestCase 13 | 14 | 15 | class CFConvTest(InteractionTestCase): 16 | r_cutoff = 5.0 17 | 18 | def __init__(self, method_name: str): 19 | super().__init__(method_name, geometry_file="tests/assets/zro2_n_96.in", 20 | weights_file="tests/assets/model_n1.torch") 21 | 22 | def setUp(self): 23 | _, self.schnet_activations, __ = init.initialize_and_predict_schnet( 24 | self.geometry_file, self.weights_file, self.r_cutoff, sort_nl_indices=True 25 | ) 26 | 27 | 
self.schnax_activations, self.schnax_neighbors = self.initialize_schnax() 28 | 29 | def initialize_schnax(self): 30 | R, Z, box, neighbors, displacement_fn = init.initialize_schnax( 31 | self.geometry_file, self.r_cutoff, sort_nl_indices=True 32 | ) 33 | 34 | energy, energies, forces, schnax_activations = init.predict_schnax( 35 | R, Z, box, displacement_fn, neighbors, self.r_cutoff, self.weights_file, return_actiations=True 36 | ) 37 | 38 | return schnax_activations, neighbors 39 | 40 | def test_filter_network(self): 41 | for i in range(self.n_interactions): 42 | schnet_interaction, schnax_interaction = activation.get_cfconv_filters( 43 | self.schnet_activations, self.schnax_activations, interaction_block_idx=i 44 | ) 45 | 46 | relevant_schnax_interactions = schnax_interaction[:, 0:48, :] 47 | np.testing.assert_allclose( 48 | schnet_interaction, relevant_schnax_interactions, rtol=1e-6, atol=3 * 1e-6 49 | ) 50 | 51 | def test_cutoff_network(self): 52 | for i in range(self.n_interactions): 53 | schnet_cutoff, schnax_cutoff = activation.get_cutoff_network( 54 | self.schnet_activations, self.schnax_activations, interaction_block_idx=i 55 | ) 56 | self.assertEqual((96, 48), schnet_cutoff.shape) 57 | self.assertEqual((96, 60), schnax_cutoff.shape) 58 | 59 | relevant_schnax_cutoff = schnax_cutoff[:, 0:48] 60 | np.testing.assert_allclose( 61 | schnet_cutoff, relevant_schnax_cutoff, rtol=1e-6, atol=1e-6 62 | ) 63 | 64 | def test_in2f(self): 65 | for i in range(self.n_interactions): 66 | schnet_in2f, schnax_in2f = activation.get_in2f( 67 | self.schnet_activations, self.schnax_activations, interaction_block_idx=i 68 | ) 69 | 70 | np.testing.assert_allclose(schnet_in2f, schnax_in2f, rtol=1e-6, atol=3 * 1e-6) 71 | 72 | def test_reshaping_and_elementwise_product_equality(self): 73 | """Assert equality of reshaping logic and element-wise product (after in2f, before aggregation). 
74 | As this operation is sensitive to the NL's padding strategy, we use schnax's neighbor list as input for both. 75 | 76 | To make it work with SchNetPack, we have to: 77 | (1) add a dimension for batches 78 | (2) change the padding strategy to something that works with how PyTorch indexes tensors 79 | """ 80 | 81 | # TODO: Maybe this test could be improved. 82 | # Instead of copy&pasting SchNet reshaping logic from schnetpack/nn/cfconv.py, we could mock the entire class 83 | # and all unnecessary forward pass calls. That way, we'd be able to test reshaping & element-wise multiplication 84 | # while keeping the unit test more robust.""" 85 | 86 | def schnet_reshape(schnet_in2f: jnp.ndarray, neighbors: NeighborList): 87 | def do_reshape(y: torch.Tensor, neighbors: torch.Tensor): 88 | """Reshape y for element-wise multiplication by W (filter block output). 89 | (n_batches, 96, 128) -> (n_batches, 96, 48, 128) 90 | 91 | Taken from schnetpack/nn/cfconv.py 92 | """ 93 | import torch 94 | 95 | nbh_size = neighbors.size() 96 | 97 | # (n_batches, n_atoms, max_occupancy) -> (n_batches, n_atoms * max_occupancy, 1) 98 | nbh = neighbors.view(-1, nbh_size[1] * nbh_size[2], 1) 99 | # (n_batches, n_atoms * max_occupancy, 1) -> (n_batches, n_atoms * max_occupancy, n_filters) 100 | nbh = nbh.expand(-1, -1, y.size(2)) 101 | # (n_batches, n_atoms, n_filters) -> (n_batches, n_atoms * max_occupancy, n_filters) 102 | y = torch.gather(y, 1, nbh) 103 | 104 | # (n_batches, n_atoms * max_occupancy, n_filters) -> (n_batches, n_atoms, max_occupancy, n_filters) 105 | y = y.view(nbh_size[0], nbh_size[1], nbh_size[2], -1) 106 | return y 107 | 108 | # convert to torch tensors and add a batch dimension for compatibility 109 | schnet_in2f = torch.tensor(schnet_in2f)[None, ...] 110 | 111 | self.assertEqual((96, 60), neighbors.idx.shape) 112 | neighbor_indices = neighbors.idx[:, 0:48] 113 | 114 | neighbors = torch.tensor(neighbor_indices, dtype=torch.int64)[None, ...] 
115 | 116 | # replace padding by n_atoms through (n_atoms - 1) to prevent out-of-bound errors 117 | # background: JAX falls back to returning the last element (n_atoms - 1) when passed an index that is out-of-bounds. 118 | # to reproduce the same behavior in PyTorch, we have to explicitly provide the (n_atoms - 1) as padding. 119 | neighbors[neighbors == neighbors.shape[1]] = neighbors.shape[1] - 1 120 | 121 | return do_reshape(schnet_in2f, neighbors) 122 | 123 | for i in range(self.n_interactions): 124 | schnet_in2f, schnax_in2f = activation.get_in2f( 125 | self.schnet_activations, self.schnax_activations, interaction_block_idx=i 126 | ) 127 | np.testing.assert_allclose( 128 | schnet_in2f, schnax_in2f, rtol=1e-6, atol=3 * 1e-6 129 | ) 130 | 131 | schnet_in2f = schnet_reshape(schnet_in2f, self.schnax_neighbors) 132 | schnax_in2f = CFConv._reshape_y(schnax_in2f, self.schnax_neighbors) 133 | np.testing.assert_allclose( 134 | schnet_in2f[0], schnax_in2f[:, 0:48], rtol=1e-6, atol=3 * 1e-6 135 | ) 136 | 137 | schnet_W, schnax_W = activation.get_cfconv_filters( 138 | self.schnet_activations, self.schnax_activations, interaction_block_idx=i 139 | ) 140 | schnet_W = torch.tensor(schnet_W)[None, ...] # add batches dimension 141 | 142 | schnet_y = schnet_in2f * schnet_W 143 | schnax_y = schnax_in2f * schnax_W 144 | 145 | # TODO: Test seems volatile. See above for improvements. 
146 | np.testing.assert_allclose( 147 | schnet_y[0], schnax_y[:, 0:48], rtol=1e-5, atol=8 * 1e-6 148 | ) 149 | 150 | def test_aggregate(self): 151 | for i in range(self.n_interactions): 152 | schnet_agg, schnax_agg = activation.get_aggregate( 153 | self.schnet_activations, self.schnax_activations, interaction_block_idx=i 154 | ) 155 | np.testing.assert_allclose( 156 | schnet_agg, schnax_agg, rtol=5 * 1e-6, atol=5 * 1e-6 157 | ) 158 | 159 | def test_f2out(self): 160 | for i in range(self.n_interactions): 161 | schnet_f2out, schnax_f2out = activation.get_f2out( 162 | self.schnet_activations, self.schnax_activations, interaction_block_idx=i 163 | ) 164 | np.testing.assert_allclose( 165 | schnet_f2out, schnax_f2out, rtol=1e-6, atol=2 * 1e-6 166 | ) 167 | -------------------------------------------------------------------------------- /tests/test_distance_expansion.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | 5 | import tests.test_utils.initialize as init 6 | import tests.test_utils.activation as activation 7 | 8 | 9 | class DistanceExpansionTest(TestCase): 10 | r_cutoff = 5.0 11 | atol = 1e-6 12 | rtol = 2 * 1e-5 # close to the edge - 1 * 1e-5 already fails. 13 | 14 | def setUp(self): 15 | _, schnet_activations, __ = init.initialize_and_predict_schnet( 16 | sort_nl_indices=True 17 | ) 18 | 19 | energy, energies, forces, schnax_activations = init.initialize_and_predict_schnax( 20 | r_cutoff=self.r_cutoff, sort_nl_indices=True, return_activations=True 21 | ) 22 | ( 23 | self.schnet_expansions, 24 | self.schnax_expansions, 25 | ) = activation.get_distance_expansion(schnet_activations, schnax_activations) 26 | 27 | def test_distance_expansions(self): 28 | self.assertEqual((96, 48, 25), self.schnet_expansions.shape) 29 | self.assertEqual((96, 60, 25), self.schnax_expansions.shape) 30 | 31 | # SchNetPack's neighborhood size is 48, so we only compare up to that index. 
32 | # anything past idx 48 is just padding anyways. 33 | relevant_schnax_expansions = self.schnax_expansions[:, 0:48, :] 34 | np.testing.assert_allclose( 35 | self.schnet_expansions, 36 | relevant_schnax_expansions, 37 | rtol=self.rtol, 38 | atol=self.atol, 39 | ) 40 | -------------------------------------------------------------------------------- /tests/test_distances.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | from unittest import TestCase 3 | 4 | import numpy as np 5 | 6 | from schnax import utils 7 | 8 | import tests.test_utils.initialize as init 9 | 10 | 11 | class DistancesTest(TestCase): 12 | r_cutoff = 5.0 13 | atol = 1e-6 14 | rtol = 1e-6 15 | 16 | def __init__(self, method_name: str): 17 | super().__init__(method_name) 18 | 19 | def setUp(self): 20 | self.schnet_R, self.schnet_nl, self.schnet_dR = self.initialize_schnet() 21 | self.schnax_R, self.schnax_nl, self.schnax_dR = self.initialize_schnax() 22 | 23 | def initialize_schnet(self): 24 | inputs, schnet_activations, preds = init.initialize_and_predict_schnet( 25 | sort_nl_indices=True 26 | ) 27 | # skip batches for now 28 | R = inputs['_positions'][0].detach().numpy() 29 | nl = inputs['_neighbors'][0].detach().numpy() 30 | dR = schnet_activations['representation.distances'][0].numpy() 31 | return R, nl, dR 32 | 33 | def initialize_schnax(self): 34 | R, Z, box, neighbors, displacement_fn = init.initialize_schnax( 35 | r_cutoff=self.r_cutoff, sort_nl_indices=True 36 | ) 37 | dR = utils.compute_distances(R, neighbors, displacement_fn) 38 | return R, neighbors.idx, dR 39 | 40 | def test_position_equality(self): 41 | self.assertEqual(self.schnet_R.shape, self.schnax_R.shape) 42 | np.testing.assert_allclose( 43 | self.schnet_R, self.schnax_R, atol=self.atol, rtol=self.rtol 44 | ) 45 | 46 | def test_nl_shapes(self): 47 | self.assertEqual((96, 48), self.schnet_nl.shape) 48 | self.assertEqual((96, 60), self.schnax_nl.shape) 49 
| 50 | def test_neighborhood_equality(self): 51 | """Asserts that every atom indices the same neighboring atoms in both neighbor list implementations.""" 52 | 53 | # for the provided input data, we expect SchNetPack to default to a smaller neighborhood size than schnax. 54 | assert self.schnet_nl.shape[1] < self.schnax_nl.shape[1] 55 | 56 | # loop over neighborhoods 57 | for reference_atom_idx, (schnet_ngbhhood, schnax_ngbhhood) in enumerate( 58 | zip(self.schnet_nl, self.schnax_nl) 59 | ): 60 | 61 | # loop over indices within a neighborhood. use fill values to account for different neighborhood sizes. 62 | for position_in_neighborhood, (schnet_idx, schnax_idx) in enumerate( 63 | zip_longest(schnet_ngbhhood, schnax_ngbhhood, fillvalue=-1) 64 | ): 65 | 66 | # fill values should not affect schnax_nl; 67 | # schnet_nl should have a smaller neighborhood size and require filling. 68 | assert schnax_idx != -1 69 | 70 | # mismatching indices within a neighborhood are acceptable if caused by 71 | # (1) different numerical values used for padding 72 | # (2) different neighborhood sizes, requiring fill values as we iterate over both neighborhoods. 73 | if not schnet_idx == schnax_idx: 74 | 75 | # (1) In SchNetPack, a 0 is padding, if it does not occur at the neighborhoods 0-th position. 76 | if schnet_idx == 0 and position_in_neighborhood > 0: 77 | # as long as schnax pads at the same position (using n_atoms), this is fine. 
78 | assert schnax_idx == self.schnax_nl.shape[0] 79 | 80 | # (2) SchNetPack fill values should only pairwise match with padding in schnax 81 | if schnet_idx != -1: 82 | assert schnax_idx == self.schnax_nl.shape[0] 83 | 84 | def test_distances_equality(self): 85 | # loop over neighborhoods 86 | for reference_atom_idx, (schnet_ngbhhood, schnax_ngbhhood) in enumerate( 87 | zip(self.schnet_dR, self.schnax_dR) 88 | ): 89 | 90 | # loop over distances within a neighborhood 91 | for position_in_neighborhood, (schnet_dr, schnax_dr) in enumerate( 92 | zip_longest(schnet_ngbhhood, schnax_ngbhhood, fillvalue=-1) 93 | ): 94 | 95 | # fill values should not affect schnax distances 96 | assert schnax_dr != -1 97 | 98 | try: 99 | np.testing.assert_allclose( 100 | schnet_dr, schnax_dr, rtol=self.rtol, atol=self.atol 101 | ) 102 | 103 | except AssertionError as ex: 104 | # this is fine, if we have 105 | # (1) encountered a fill value for schnet_dr 106 | # (2) schnax_dr is 0 since it comes from a padded index in the neighbor list. 107 | if ( 108 | schnet_dr == -1 109 | and schnax_dr == 0 110 | and self.schnax_nl[reference_atom_idx][position_in_neighborhood] 111 | == self.schnax_nl.shape[0] 112 | ): 113 | continue 114 | 115 | raise ex 116 | -------------------------------------------------------------------------------- /tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import numpy as np 3 | 4 | import tests.test_utils.activation as activation 5 | import tests.test_utils.initialize as init 6 | 7 | 8 | class EmbeddingsTest(TestCase): 9 | """Asserts output equality of both embedding layers, i.e. 
equal shape and approximate and exact equality of output tensors.""" 10 | 11 | atol = 1e-6 12 | rtol = 1e-6 13 | 14 | def __init__(self, method_name: str): 15 | super().__init__(method_name) 16 | 17 | def setUp(self): 18 | _, schnet_activations, __ = init.initialize_and_predict_schnet() 19 | energy, energies, forces, schnax_activations = init.initialize_and_predict_schnax(return_activations=True) 20 | self.schnet_embeddings, self.schnax_embeddings = activation.get_embeddings( 21 | schnet_activations, schnax_activations 22 | ) 23 | 24 | def test_embeddings_shape_equality(self): 25 | self.assertEqual(self.schnet_embeddings.shape, (96, 128)) 26 | self.assertEqual(self.schnax_embeddings.shape, (96, 128)) 27 | 28 | def test_embeddings_approx_equality(self): 29 | """Simply added as a safe guard to notice numerical instabilities (e.g. when enabling double precision in JAX).""" 30 | np.testing.assert_allclose( 31 | self.schnet_embeddings, self.schnax_embeddings, self.rtol, self.atol 32 | ) 33 | 34 | def test_embeddings_exact_equality(self): 35 | """As the embedding layer simply performs a dictionary lookup from the same trained representations, the return values should be exactly equal.""" 36 | np.testing.assert_equal(self.schnet_embeddings, self.schnax_embeddings) 37 | -------------------------------------------------------------------------------- /tests/test_interaction.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | 5 | import tests.test_utils.initialize as init 6 | import tests.test_utils.activation as activation 7 | from tests.interaction_test_case import InteractionTestCase 8 | 9 | 10 | class InteractionTest(InteractionTestCase): 11 | """Asserts equal output of interactions blocks as wholes as well as individual layers. 
12 | To bypass the fact that input neighbor lists are still not 100% equal, we temporarily use SchNetPack's representation (adapted to JAX-MD's conventions). 13 | That way, we can test an interaction block without having to deal with errors cascading down from the distance layer. 14 | """ 15 | r_cutoff = 5.0 16 | rtol = 1e-6 17 | atol = 1e-6 18 | 19 | def __init__(self, method_name: str): 20 | super().__init__(method_name, geometry_file="tests/assets/zro2_n_96.in", weights_file="tests/assets/model_n1.torch") 21 | 22 | def setUp(self): 23 | _, self.schnet_activations, _ = init.initialize_and_predict_schnet( 24 | geometry_file=self.geometry_file, 25 | weights_file=self.weights_file, 26 | r_cutoff=self.r_cutoff, 27 | sort_nl_indices=True, 28 | ) 29 | energy, energies, forces, self.schnax_activations = init.initialize_and_predict_schnax( 30 | self.geometry_file, self.weights_file, self.r_cutoff, sort_nl_indices=True, return_activations=True 31 | ) 32 | 33 | def test_interaction_block(self): 34 | for i in range(self.n_interactions): 35 | schnet_interaction, schnax_interaction = activation.get_interaction_output( 36 | self.schnet_activations, self.schnax_activations, interaction_block_idx=i 37 | ) 38 | np.testing.assert_allclose( 39 | schnet_interaction, schnax_interaction, rtol=self.rtol, atol=2 * self.atol 40 | ) 41 | -------------------------------------------------------------------------------- /tests/test_schnetkit_end_to_end.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | from ase.io import read 5 | from schnetkit import Calculator 6 | 7 | import tests.test_utils.initialize as init 8 | 9 | 10 | class EndToEndTest(TestCase): 11 | geometry_file = "tests/assets/zro2_n_96.in" 12 | weights_file = "tests/assets/model_n1.torch" 13 | 14 | r_cutoff = 5.0 15 | 16 | def __init__(self, method_name: str): 17 | super().__init__(method_name) 18 | 19 | def setUp(self): 20 | 
self.schnet_preds = self._initalize_and_predict_schnet() 21 | self.schnax_preds = self._initialize_and_predict_schnax() 22 | 23 | def _initalize_and_predict_schnet(self): 24 | schnet = Calculator(self.weights_file, skin=0.0, energies=True, stress=False) 25 | atoms = read(self.geometry_file) 26 | preds = schnet.calculate(atoms) 27 | 28 | return { 29 | 'energy': preds["energy"], 30 | 'energies': preds["energies"], 31 | 'forces': preds["forces"] 32 | } 33 | 34 | def _initialize_and_predict_schnax(self): 35 | energy, energies, forces, state = init.initialize_and_predict_schnax( 36 | self.geometry_file, self.weights_file, self.r_cutoff, sort_nl_indices=False, return_activations=True 37 | ) 38 | 39 | return { 40 | 'energy': energy, 41 | 'energies': energies, 42 | 'forces': forces 43 | } 44 | 45 | def test_energy_equality(self): 46 | np.testing.assert_allclose( 47 | self.schnet_preds['energy'], self.schnax_preds['energy'], rtol=1e-6, atol=4 * 1e-5 48 | ) 49 | 50 | def test_energies_equality(self): 51 | np.testing.assert_allclose( 52 | self.schnet_preds['energies'], self.schnax_preds['energies'], rtol=1e-6, atol=1e-5 53 | ) 54 | 55 | def test_forces_equality(self): 56 | np.testing.assert_allclose( 57 | self.schnet_preds['forces'], self.schnax_preds['forces'], rtol=1e-6, atol=3 * 1e-5 58 | ) 59 | -------------------------------------------------------------------------------- /tests/test_schnetkit_end_to_end_no_activations.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | from ase.io import read 5 | from schnetkit import Calculator 6 | 7 | import tests.test_utils.initialize as init 8 | 9 | 10 | class EndToEndTest(TestCase): 11 | geometry_file = "tests/assets/zro2_n_96.in" 12 | weights_file = "tests/assets/model_n1.torch" 13 | 14 | r_cutoff = 5.0 15 | 16 | def __init__(self, method_name: str): 17 | super().__init__(method_name) 18 | 19 | def setUp(self): 20 | self.schnet_preds 
= self._initalize_and_predict_schnet() 21 | self.schnax_preds = self._initialize_and_predict_schnax() 22 | 23 | def _initalize_and_predict_schnet(self): 24 | schnet = Calculator(self.weights_file, skin=0.0, energies=True, stress=False) 25 | atoms = read(self.geometry_file) 26 | preds = schnet.calculate(atoms) 27 | 28 | return { 29 | 'energy': preds["energy"], 30 | 'energies': preds["energies"], 31 | 'forces': preds["forces"] 32 | } 33 | 34 | def _initialize_and_predict_schnax(self): 35 | energy, energies, forces = init.initialize_and_predict_schnax( 36 | self.geometry_file, self.weights_file, self.r_cutoff, sort_nl_indices=False, return_activations=False 37 | ) 38 | 39 | return { 40 | 'energy': energy, 41 | 'energies': energies, 42 | 'forces': forces 43 | } 44 | 45 | def test_energy_equality(self): 46 | np.testing.assert_allclose( 47 | self.schnet_preds['energy'], self.schnax_preds['energy'], rtol=1e-6, atol=4 * 1e-5 48 | ) 49 | 50 | def test_energies_equality(self): 51 | np.testing.assert_allclose( 52 | self.schnet_preds['energies'], self.schnax_preds['energies'], rtol=1e-6, atol=1e-5 53 | ) 54 | 55 | def test_forces_equality(self): 56 | np.testing.assert_allclose( 57 | self.schnet_preds['forces'], self.schnax_preds['forces'], rtol=1e-6, atol=3 * 1e-5 58 | ) -------------------------------------------------------------------------------- /tests/test_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiannagel/schnax/e153581b6b44976e255955716a7aada07ee0838d/tests/test_utils/__init__.py -------------------------------------------------------------------------------- /tests/test_utils/activation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def _dispatch_to_numpy(tensor: torch.tensor): 8 | return tensor.cpu().numpy()[0] 9 | 10 | 11 
| def get_embeddings(schnet_activations: Dict, schnax_activations: Dict): 12 | return ( 13 | schnet_activations['representation.embedding'][0].numpy(), 14 | schnax_activations['SchNet']['embedding'], 15 | ) 16 | 17 | 18 | def get_distance_expansion(schnet_activations: Dict, schnax_activations: Dict): 19 | return ( 20 | schnet_activations['representation.distance_expansion'][0].numpy(), 21 | schnax_activations['SchNet/~/GaussianSmearing']['GaussianSmearing'], 22 | ) 23 | 24 | 25 | def get_interaction_output( 26 | schnet_activations: Dict, schnax_activations: Dict, interaction_block_idx: int 27 | ): 28 | k = 'representation.interactions.{}.dense'.format(interaction_block_idx) 29 | a_schnet = schnet_activations[k][0] 30 | 31 | k = 'SchNet/~/Interaction_{}'.format(interaction_block_idx) 32 | a_schnax = schnax_activations[k]['Output'] 33 | return a_schnax, a_schnet 34 | 35 | 36 | def get_cfconv_filters( 37 | schnet_activations: Dict, schnax_activations: Dict, interaction_block_idx: int 38 | ): 39 | k = 'representation.interactions.{}.filter_network.1'.format(interaction_block_idx) 40 | a_schnet_1 = _dispatch_to_numpy(schnet_activations[k]) 41 | 42 | k = 'SchNet/~/Interaction_{}/~/CFConv/~/FilterNetwork'.format(interaction_block_idx) 43 | a_schnax_1 = schnax_activations[k]['linear_1'] 44 | 45 | return a_schnet_1, a_schnax_1 46 | 47 | 48 | def get_cutoff_network( 49 | schnet_activations: Dict, schnax_activations: Dict, interaction_block_idx: int 50 | ): 51 | k = 'representation.interactions.{}.cutoff_network'.format(interaction_block_idx) 52 | schnet_cutoff = _dispatch_to_numpy(schnet_activations[k]) 53 | 54 | k = 'SchNet/~/Interaction_{}/~/CFConv/~/CosineCutoff'.format(interaction_block_idx) 55 | schnax_cutoff = schnax_activations[k]['CosineCutoff'] 56 | 57 | return schnet_cutoff, schnax_cutoff 58 | 59 | 60 | def get_in2f(schnet_activations: Dict, schnax_activations: Dict, interaction_block_idx: int): 61 | k = 
'representation.interactions.{}.cfconv.in2f'.format(interaction_block_idx) 62 | schnet_in2f = _dispatch_to_numpy(schnet_activations[k]) 63 | 64 | k = 'SchNet/~/Interaction_{}/~/CFConv'.format(interaction_block_idx) 65 | schnax_in2f = schnax_activations[k]['in2f'] 66 | 67 | return schnet_in2f, schnax_in2f 68 | 69 | 70 | def get_aggregate( 71 | schnet_activations: Dict, schnax_activations: Dict, interaction_block_idx: int 72 | ): 73 | k = 'representation.interactions.{}.cfconv.agg'.format(interaction_block_idx) 74 | schnet_agg = _dispatch_to_numpy(schnet_activations[k]) 75 | 76 | k = 'SchNet/~/Interaction_{}/~/CFConv'.format(interaction_block_idx) 77 | schnax_agg = schnax_activations[k]['Aggregate'] 78 | 79 | return schnet_agg, schnax_agg 80 | 81 | 82 | def get_f2out( 83 | schnet_activations: Dict, schnax_activations: Dict, interaction_block_idx: int 84 | ): 85 | k = 'representation.interactions.{}.cfconv.f2out'.format(interaction_block_idx) 86 | schnet_f2out = _dispatch_to_numpy(schnet_activations[k]) 87 | 88 | k = 'SchNet/~/Interaction_{}/~/CFConv'.format(interaction_block_idx) 89 | schnax_f2out = schnax_activations[k]['f2out'] 90 | 91 | return schnet_f2out, schnax_f2out 92 | -------------------------------------------------------------------------------- /tests/test_utils/initialize.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax_md 3 | import numpy as np 4 | from ase.io import read 5 | from jax import numpy as jnp 6 | from jax_md import quantity 7 | from jax_md.partition import NeighborList 8 | from jax_md.space import DisplacementFn, Box 9 | from schnetpack import AtomsConverter 10 | from schnetpack.environment import AseEnvironmentProvider 11 | 12 | from schnetkit import load 13 | 14 | from schnax import energy, utils 15 | from schnax.energy import schnet_neighbor_list 16 | from schnax.utils.layer_hooks import ( 17 | register_representation_layer_hooks, 18 | register_output_layer_hooks, 19 | ) 
20 | from schnax.utils.schnetkit import get_interaction_count 21 | from .mock_environment_provider import MockEnvironmentProvider 22 | 23 | 24 | def initialize_schnax( 25 | geometry_file="tests/assets/zro2_n_96.in", r_cutoff=5.0, sort_nl_indices=False 26 | ): 27 | atoms = read(geometry_file) 28 | R, Z, box = utils.atoms_to_input(atoms) 29 | displacement_fn, shift_fn = jax_md.space.periodic_general( 30 | box, fractional_coordinates=False 31 | ) 32 | 33 | neighbor_fn = jax_md.partition.neighbor_list( 34 | displacement_fn, 35 | box, 36 | r_cutoff, 37 | dr_threshold=0.0, # as the effective cutoff = r_cutoff + dr_threshold 38 | mask_self=True, # an atom is not a neighbor of itself 39 | fractional_coordinates=False, 40 | ) 41 | 42 | neighbors = neighbor_fn.allocate(R) 43 | if sort_nl_indices: 44 | neighbors = sort_schnax_nl(neighbors) 45 | 46 | return R, Z, box, neighbors, displacement_fn 47 | 48 | 49 | def sort_schnax_nl(neighbors: NeighborList) -> NeighborList: 50 | # constructing the NL with mask_self=True pads an *already existing* self-reference, 51 | # causing a padding index at position 0. sort in ascending order to move it to the end. 
52 | new_indices = np.sort(neighbors.idx, axis=1) 53 | object.__setattr__(neighbors, 'idx', new_indices) 54 | return neighbors 55 | 56 | 57 | def predict_schnax( 58 | R: jnp.ndarray, Z: jnp.ndarray, box: Box, displacement_fn: DisplacementFn, neighbors: NeighborList, 59 | r_cutoff: float, weights_file="tests/assets/model_n1.torch", return_actiations=False 60 | ): 61 | neighbor_fn, init_fn, apply_fn = schnet_neighbor_list(displacement_fn, box, r_cutoff, dr_threshold=0.0, 62 | n_interactions=get_interaction_count(weights_file), 63 | per_atom=True, 64 | return_activations=return_actiations) 65 | params = utils.get_params(weights_file) 66 | rng = jax.random.PRNGKey(0) 67 | 68 | def pred_fn_stateless(energies_fn): 69 | # energies_fn = lambda R: apply_fn(R) 70 | energy_fn = lambda R: jnp.sum(energies_fn(R)) 71 | force_fn = quantity.force(energy_fn) 72 | return energy_fn(R), energies_fn(R), force_fn(R) 73 | # energy_fn = lambda R: jnp.sum(apply_fn) 74 | # return energy_fn(R), apply_fn(R), force_fn(R) 75 | 76 | def pred_fn_stateful(energies_fn): 77 | energy_fn = lambda R: jnp.sum(energies_fn(R)[0]) 78 | force_fn = quantity.force(energy_fn) 79 | energies, state = energies_fn(R) 80 | return energy_fn(R), energies, force_fn(R), state 81 | 82 | if return_actiations: 83 | _, state = init_fn(rng, R, Z, neighbors) 84 | apply_fn_stateful = lambda R: apply_fn(params, state, R, Z, neighbors) 85 | energy, energies, forces, state = pred_fn_stateful(apply_fn_stateful) 86 | return energy, energies, forces, state 87 | 88 | apply_fn_stateless = lambda R: apply_fn(params, R, Z, neighbors) 89 | energy, energies, forces = pred_fn_stateless(apply_fn_stateless) 90 | return energy, energies, forces 91 | 92 | 93 | def initialize_and_predict_schnax( 94 | geometry_file="tests/assets/zro2_n_96.in", 95 | weights_file="tests/assets/model_n1.torch", 96 | r_cutoff=5.0, 97 | sort_nl_indices=False, 98 | return_activations=False 99 | ): 100 | R, Z, box, neighbors, displacement_fn = initialize_schnax( 101 | 
geometry_file, r_cutoff, sort_nl_indices 102 | ) 103 | return predict_schnax(R, Z, box, displacement_fn, neighbors, r_cutoff, weights_file, return_activations) 104 | 105 | 106 | def get_schnet_inputs( 107 | geometry_file="tests/assets/zro2_n_96.in", r_cutoff=5.0, mock_environment_provider=None 108 | ): 109 | atoms = read(geometry_file, format="aims") 110 | if not mock_environment_provider: 111 | converter = AtomsConverter( 112 | environment_provider=AseEnvironmentProvider(cutoff=r_cutoff), device="cpu" 113 | ) 114 | else: 115 | converter = AtomsConverter( 116 | environment_provider=mock_environment_provider, device="cpu" 117 | ) 118 | 119 | return converter(atoms) 120 | 121 | 122 | def initialize_and_predict_schnet( 123 | geometry_file="tests/assets/zro2_n_96.in", 124 | weights_file="tests/assets/model_n1.torch", 125 | r_cutoff=5.0, 126 | sort_nl_indices=False, 127 | ): 128 | layer_outputs = {} 129 | 130 | mock_provider = None 131 | if sort_nl_indices: 132 | mock_provider = MockEnvironmentProvider(AseEnvironmentProvider(cutoff=r_cutoff)) 133 | 134 | inputs = get_schnet_inputs( 135 | geometry_file, r_cutoff, mock_environment_provider=mock_provider 136 | ) 137 | # inputs["_neighbor_mask"] = None 138 | 139 | model = load(weights_file) 140 | assert model.cutoff == r_cutoff 141 | model = model.model # get raw schnetpack model without schnetkit wrapper 142 | register_representation_layer_hooks(layer_outputs, model) 143 | register_output_layer_hooks(layer_outputs, model) 144 | 145 | preds = model(inputs) 146 | return inputs, layer_outputs, preds 147 | -------------------------------------------------------------------------------- /tests/test_utils/mock_environment_provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from schnetpack.environment import AseEnvironmentProvider 3 | 4 | 5 | class MockEnvironmentProvider: 6 | """Wraps around the default AseEnvironmentProvider to equalize NL conventions with 
JAX-MD. 7 | If we apply a consistent ordering to both the neighborhoods and offsets here, AtomsConverter() will implicitly apply it to all other inputs as well, making our life easier down the line.""" 8 | 9 | def __init__(self, environment_provider: AseEnvironmentProvider): 10 | self.environment_provider = environment_provider 11 | 12 | def get_environment(self, atoms, **kwargs): 13 | neighborhood_idx, offset = self.environment_provider.get_environment( 14 | atoms, **kwargs 15 | ) 16 | 17 | # replace -1 padding w/ atom count 18 | neighborhood_idx[neighborhood_idx == -1] = neighborhood_idx.shape[0] 19 | 20 | # that way, we can sort in ascending order and the padded indices stay "at the end" 21 | # as the same permutation has to be applied to offsets as well, we do this in two steps: 22 | # (1) sort and obtain indices of the new permutation. (2) apply the permutation to the neighborhoods. 23 | sorted_indices = np.argsort(neighborhood_idx, axis=1) 24 | neighborhood_idx = np.take_along_axis(neighborhood_idx, sorted_indices, axis=1) 25 | 26 | # reverse padding to -1 to stay compatible to the original AtomsConverter() 27 | # this gives us a SchNetPack-compatible NL with nice, ascending ordering and -1 padding (only!) at the end. 28 | # makes our life easier for comparing individual neighborhoods from both SchNet and schnax. 29 | neighborhood_idx[neighborhood_idx == neighborhood_idx.shape[0]] = -1 30 | 31 | # apply the same ordering to offsets. 
32 | sorted_offset = np.empty_like(offset) 33 | for i, idx_row in enumerate(sorted_indices): 34 | for j, idx in enumerate(idx_row): 35 | 36 | matching_offset = offset[i][idx] 37 | sorted_offset[i][j] = matching_offset 38 | 39 | return neighborhood_idx, sorted_offset 40 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiannagel/schnax/e153581b6b44976e255955716a7aada07ee0838d/train/__init__.py -------------------------------------------------------------------------------- /train/iso17.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "iso17.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": { 23 | "colab": { 24 | "base_uri": "https://localhost:8080/" 25 | }, 26 | "id": "gN4EFP-t-ODL", 27 | "outputId": "a8e93321-0779-4b5b-896b-c9d19b7804c6" 28 | }, 29 | "outputs": [ 30 | { 31 | "output_type": "stream", 32 | "name": "stdout", 33 | "text": [ 34 | "--2022-04-05 13:46:45-- http://quantum-machine.org/datasets/iso17.tar.gz\n", 35 | "Resolving quantum-machine.org (quantum-machine.org)... 130.149.80.145\n", 36 | "Connecting to quantum-machine.org (quantum-machine.org)|130.149.80.145|:80... connected.\n", 37 | "HTTP request sent, awaiting response... 
200 OK\n", 38 | "Length: 799712882 (763M) [application/x-gzip]\n", 39 | "Saving to: ‘train/data/iso17/iso17.tar.gz’\n", 40 | "\n", 41 | "iso17.tar.gz 100%[===================>] 762.67M 30.4MB/s in 26s \n", 42 | "\n", 43 | "2022-04-05 13:47:12 (29.3 MB/s) - ‘train/data/iso17/iso17.tar.gz’ saved [799712882/799712882]\n", 44 | "\n", 45 | "tar (child): iso17.tar.gz: Cannot open: No such file or directory\n", 46 | "tar (child): Error is not recoverable: exiting now\n", 47 | "tar: Child returned status 2\n", 48 | "tar: Error is not recoverable: exiting now\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "!mkdir -p train/data/iso17\n", 54 | "!wget http://quantum-machine.org/datasets/iso17.tar.gz -P train/data/iso17\n", 55 | "!tar xzvf train/data/iso17/iso17.tar.gz" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "source": [ 61 | "!tar xzvf train/data/iso17/iso17.tar.gz" 62 | ], 63 | "metadata": { 64 | "colab": { 65 | "base_uri": "https://localhost:8080/" 66 | }, 67 | "id": "l5saGKlA-3X_", 68 | "outputId": "6b647591-cc23-4409-eb9c-d1edd417bb5a" 69 | }, 70 | "execution_count": 12, 71 | "outputs": [ 72 | { 73 | "output_type": "stream", 74 | "name": "stdout", 75 | "text": [ 76 | "iso17/reference.db\n", 77 | "iso17/train_ids.txt\n", 78 | "iso17/validation_ids.txt\n", 79 | "iso17/\n", 80 | "iso17/test_other.db\n", 81 | "iso17/test_within.db\n", 82 | "iso17/README\n", 83 | "iso17/reference_eq.db\n", 84 | "iso17/test_eq.db\n" 85 | ] 86 | } 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "source": [ 92 | "" 93 | ], 94 | "metadata": { 95 | "id": "29mXldA6-_Wg" 96 | }, 97 | "execution_count": null, 98 | "outputs": [] 99 | } 100 | ] 101 | } -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import optax 3 | from jax import vmap, grad, jit, random, lax 4 | from jax_md import space 5 | from schnax.energy import 
schnet_neighbor_list 6 | from schnax.utils.train import build_dataset, make_batches 7 | 8 | # mkdir -p train/data/iso17 9 | # !wget http://quantum-machine.org/datasets/iso17.tar.gz -P train/data/iso17 10 | # !tar xzvf iso17.tar.gz 11 | train_set, val_set, test_set = build_dataset() 12 | train_positions, train_charges, train_energies, train_forces = train_set 13 | val_positions, val_charges, val_energies, val_forces = val_set 14 | test_positions, test_charges, test_energies, test_forces = test_set 15 | 16 | # TODO: What box size? Doesn't seem to be included in ISO17 17 | box_size = 10.862 18 | r_cutoff = 3.0 19 | dr_threshold = 0.1 20 | 21 | displacement_fn, shift_fn = space.periodic_general(box_size, fractional_coordinates=False) 22 | neighbor_fn, init_fn, energy_fn = schnet_neighbor_list(displacement_fn, box_size, r_cutoff, dr_threshold) 23 | neighbor = neighbor_fn.allocate(train_positions[0]) 24 | 25 | 26 | @jit 27 | def train_energy_fn(params, R, Z): 28 | _neighbor = neighbor.update(R) 29 | return energy_fn(params, R, Z, _neighbor) 30 | 31 | 32 | # Vectorize over states, not parameters 33 | vectorized_energy_fn = vmap(train_energy_fn, (None, 0, 0)) 34 | grad_fn = grad(train_energy_fn, argnums=1) 35 | force_fn = lambda params, R, Z, **kwargs: -grad_fn(params, R, Z) 36 | vectorized_force_fn = vmap(force_fn, (None, 0, 0)) 37 | 38 | # Initialize random parameters 39 | key = random.PRNGKey(0) 40 | params = init_fn(key, train_positions[0], train_charges[0], neighbor=neighbor) 41 | 42 | n_predictions = 500 43 | example_positions = train_positions[:n_predictions] 44 | example_charges = train_charges[:n_predictions] 45 | example_energies = train_energies[:n_predictions] 46 | example_forces = train_forces[:n_predictions] 47 | 48 | predicted = vmap(train_energy_fn, (None, 0, 0))(params, example_positions, example_charges) 49 | 50 | # seemingly no correlation from model priors - surprising? 
51 | # import matplotlib.pyplot as plt 52 | # plt.plot(example_energies, predicted, 'o') 53 | # plt.show() 54 | 55 | @jit 56 | def energy_loss(params, R, Z, energies): 57 | return np.mean((vectorized_energy_fn(params, R, Z) - energies) ** 2) 58 | 59 | @jit 60 | def force_loss(params, R, Z, forces): 61 | dforces = vectorized_force_fn(params, R, Z) - forces 62 | return np.mean(np.sum(dforces ** 2, axis=(1, 2))) 63 | 64 | @jit 65 | def loss(params, R, Z, energies, forces): 66 | return energy_loss(params, R, Z, energies) + force_loss(params, R, Z, forces) 67 | 68 | 69 | opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)) 70 | 71 | @jit 72 | def update_step(params, opt_state, R, Z, energies, forces): 73 | updates, opt_state = opt.update(grad(loss)(params, R, Z, energies, forces), opt_state) 74 | return optax.apply_updates(params, updates), opt_state 75 | 76 | @jit 77 | def update_epoch(params_and_opt_state, batches): 78 | def inner_update(params_and_opt_state, batch): 79 | params, opt_state = params_and_opt_state 80 | # b_xs, b_labels = batch 81 | # return update_step(params, opt_state, b_xs, b_labels), 0 82 | 83 | training_points, labels = batch 84 | positions, charges = training_points 85 | energies, forces = labels 86 | return update_step(params, opt_state, positions, charges, energies, forces), 0 87 | 88 | return lax.scan(inner_update, params_and_opt_state, batches)[0] 89 | 90 | 91 | batch_positions, batch_charges, batch_energies, batch_forces = make_batches(train_positions, train_charges, train_energies, train_forces, batch_size=128) 92 | train_epochs = 20 93 | opt_state = opt.init(params) 94 | 95 | train_energy_error = [] 96 | test_energy_error = [] 97 | 98 | for iteration in range(train_epochs): 99 | train_energy_error += [float(np.sqrt(energy_loss(params, batch_positions[0], batch_charges[0], batch_energies[0])))] 100 | # test_energy_error += [float(np.sqrt(energy_loss(params, , test_energies)))] 101 | # draw_training(params) 102 | 103 | params, 
opt_state = update_epoch((params, opt_state), 104 | ((batch_positions, batch_charges), (batch_energies, batch_forces))) 105 | 106 | # why shuffle here? 107 | # np.random.shuffle(lookup) 108 | # batch_positions, batch_charges, batch_energies, batch_forces = make_batches(lookup) 109 | 110 | print("Epoch {}/{}".format(iteration, train_epochs)) 111 | print("Training error: {}".format(train_energy_error[-1])) 112 | --------------------------------------------------------------------------------