├── .gitignore ├── pytype.cfg ├── .flake8 ├── ferminet ├── __init__.py ├── pbc │ ├── __init__.py │ ├── tests │ │ ├── __init__.py │ │ ├── hamiltonian_test.py │ │ └── features_test.py │ ├── feature_layer.py │ ├── envelopes.py │ └── hamiltonian.py ├── utils │ ├── __init__.py │ ├── utils.py │ ├── units.py │ ├── multi_host.py │ ├── tests │ │ ├── statistics_test.py │ │ ├── gto_test.py │ │ ├── units_test.py │ │ ├── scf_test.py │ │ ├── system_test.py │ │ └── elements_test.py │ ├── statistics.py │ ├── writers.py │ ├── system.py │ ├── analysis_tools.py │ └── pseudopotential.py ├── configs │ ├── __init__.py │ ├── excited │ │ ├── __init__.py │ │ ├── atoms.py │ │ ├── presets.py │ │ ├── carbon_dimer.py │ │ ├── benzene.py │ │ ├── twisted_ethylene.py │ │ ├── double_excitation.py │ │ └── oscillator.py │ ├── nh3.py │ ├── ch4.py │ ├── c2h4.py │ ├── hn.py │ ├── h4.py │ ├── heg.py │ ├── atom.py │ ├── hcl.py │ ├── li_excited.py │ ├── li_wqmc.py │ ├── diatomic.py │ └── organic.py ├── constants.py ├── main.py ├── tests │ ├── envelopes_test.py │ ├── network_blocks_test.py │ ├── psiformer_test.py │ ├── hamiltonian_test.py │ ├── excited_test.py │ └── networks_test.py ├── jastrows.py ├── sto.py ├── network_blocks.py ├── checkpoint.py └── curvature_tags_and_blocks.py ├── CONTRIBUTING.md ├── .github └── workflows │ └── ci-build.yaml └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | *.csv 4 | *.npz 5 | -------------------------------------------------------------------------------- /pytype.cfg: -------------------------------------------------------------------------------- 1 | [pytype] 2 | inputs = ferminet 3 | disable = wrong-arg-types 4 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 80 3 | ignore = 4 | # object names too complex 5 | C901, 6 | # 
four-space indents 7 | E111, E114, 8 | # line continuations 9 | E121, 10 | # line breaks around binary operators 11 | W503, W504 12 | max-complexity = 18 13 | select = B,C,F,W,T4,B9,E261 14 | exclude = 15 | .git, 16 | __pycache__ 17 | -------------------------------------------------------------------------------- /ferminet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /ferminet/pbc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /ferminet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /ferminet/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | -------------------------------------------------------------------------------- /ferminet/pbc/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /ferminet/configs/excited/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project.
There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /ferminet/configs/nh3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
"""Ammonia example config."""

from ferminet import base_config
from ferminet.utils import system
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  """Builds the FermiNet config for ammonia (NH3).

  Returns:
    The default FermiNet config with the NH3 nuclear geometry (in bohr) and
    electrons set to (5, 5).
  """
  # (symbol, (x, y, z)) in bohr: N on the z axis, the three H atoms sharing a
  # common z below it.
  nuclei = (
      ('N', (0.0, 0.0, 0.22013)),
      ('H', (0.0, 1.77583, -0.51364)),
      ('H', (1.53791, -0.88791, -0.51364)),
      ('H', (-1.53791, -0.88791, -0.51364)),
  )
  config = base_config.default()
  config.system.molecule = [
      system.Atom(symbol=symbol, coords=coords) for symbol, coords in nuclei
  ]
  config.system.electrons = (5, 5)
  return config


# ------------------------------------------------------------------------------
# /ferminet/constants.py:
# ------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Constants for FermiNet."""

import functools
import jax
import kfac_jax


# Axis name we pmap over.
PMAP_AXIS_NAME = 'qmc_pmap_axis'

# Shortcut for jax.pmap over PMAP_AXIS_NAME. Prefer this if pmapping any
# function which does communications or reductions.
pmap = functools.partial(jax.pmap, axis_name=PMAP_AXIS_NAME)

# Shortcuts for kfac_jax's pmap-aware collectives, bound to PMAP_AXIS_NAME.
# NOTE(review): per their `*_if_pmap` names these presumably only communicate
# when called inside a pmap -- confirm against kfac_jax.utils.
psum = functools.partial(
    kfac_jax.utils.psum_if_pmap, axis_name=PMAP_AXIS_NAME)
pmean = functools.partial(
    kfac_jax.utils.pmean_if_pmap, axis_name=PMAP_AXIS_NAME)
all_gather = functools.partial(
    kfac_jax.utils.wrap_if_pmap(jax.lax.all_gather), axis_name=PMAP_AXIS_NAME)

# ------------------------------------------------------------------------------
# /ferminet/main.py:
# ------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main wrapper for FermiNet in JAX."""

from absl import app
from absl import flags
from absl import logging
from ferminet import base_config
from ferminet import train
from ml_collections.config_flags import config_flags

# internal imports

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file('config', None, 'Path to config file.')


def main(_):
  """Resolves the config passed via --config, logs it and trains with it."""
  resolved = base_config.resolve(FLAGS.config)
  logging.info('System config:\n\n%s', resolved)
  train.train(resolved)


def main_wrapper():
  """Entry point for setuptools' console_script: runs `main` under absl."""
  app.run(main)


if __name__ == '__main__':
  main_wrapper()

# ------------------------------------------------------------------------------
# /ferminet/configs/ch4.py:
# ------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Methane example config."""

from ferminet import base_config
from ferminet.utils import system
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  """Builds the FermiNet config for methane (CH4).

  Returns:
    The default FermiNet config with a tetrahedral CH4 geometry (in bohr) and
    electrons set to (5, 5).
  """
  config = base_config.default()
  # C at the origin; each H is offset by 1.18886 bohr along every axis, with
  # the sign patterns below (alternate corners of a cube).
  offset = 1.18886
  corners = ((1, 1, 1), (-1, -1, 1), (1, -1, -1), (-1, 1, -1))
  carbon = system.Atom(symbol='C', coords=(0.0, 0.0, 0.0))
  hydrogens = [
      system.Atom(symbol='H', coords=tuple(sign * offset for sign in corner))
      for corner in corners
  ]
  config.system.molecule = [carbon] + hydrogens
  config.system.electrons = (5, 5)
  return config

# ------------------------------------------------------------------------------
# /ferminet/configs/c2h4.py:
# ------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ethene example config."""

from ferminet import base_config
from ferminet.utils import system
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  """Builds the FermiNet config for ethene (C2H4).

  Returns:
    The default FermiNet config with a planar C2H4 geometry (in bohr; every
    atom has x = 0 and the C-C axis is z) and electrons set to (8, 8).
  """
  config = base_config.default()
  carbons = [
      system.Atom(symbol='C', coords=(0.0, 0.0, z)) for z in (1.26135, -1.26135)
  ]
  # Two hydrogens beyond each carbon, mirrored in y.
  hydrogens = [
      system.Atom(symbol='H', coords=(0.0, y, z))
      for z in (2.33889, -2.33889)
      for y in (1.74390, -1.74390)
  ]
  config.system.molecule = carbons + hydrogens
  config.system.electrons = (8, 8)
  return config

# ------------------------------------------------------------------------------
# /ferminet/configs/excited/atoms.py:
# ------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config to reproduce Fig. 1 from Pfau et al. (2024)."""

from ferminet import base_config
from ferminet.configs import atom
from ferminet.configs.excited import presets
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  """Returns config for running generic atoms with qmc.

  system.atom is left empty. NOTE(review): it is presumably supplied as a
  config override. Five states are requested, and the shared excited-state
  presets are applied on top of the defaults. The molecule itself is built by
  atom.adjust_nuclear_charge, which is installed as system.set_molecule.
  """
  cfg = base_config.default()
  system_cfg = cfg.system
  system_cfg.atom = ''
  system_cfg.charge = 0
  system_cfg.delta_charge = 0.0
  system_cfg.states = 5
  # Typed, initially-unset reference so an int can be supplied later.
  system_cfg.spin_polarisation = ml_collections.FieldReference(
      None, field_type=int)
  cfg.pretrain.iterations = 10_000
  cfg.update_from_flattened_dict(presets.excited_states)
  # set_molecule holds a callable, so type checking must be relaxed.
  with cfg.ignore_type():
    cfg.system.set_molecule = atom.adjust_nuclear_charge
  return cfg

# ------------------------------------------------------------------------------
# /ferminet/utils/utils.py:
# ------------------------------------------------------------------------------
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic utils for all QMC calculations."""

import collections.abc
from typing import Any, Callable, Mapping, Sequence


def select_output(f: Callable[..., Sequence[Any]],
                  argnum: int) -> Callable[..., Any]:
  """Return the argnum-th result from callable f.

  Args:
    f: callable returning an indexable sequence (e.g. a tuple of outputs).
    argnum: index of the output to keep.

  Returns:
    A callable accepting the same arguments as f and returning only
    f(...)[argnum].
  """

  def f_selected(*args, **kwargs):
    return f(*args, **kwargs)[argnum]

  return f_selected


def flatten_dict_keys(input_dict: Mapping[str, Any],
                      prefix: str = '') -> dict[str, Any]:
  """Flattens the keys of the given, potentially nested, mapping.

  Nested keys are joined with '.', e.g. {'a': {'b': 1}, 'c': 2} becomes
  {'a.b': 1, 'c': 2}. Any nested value that is a Mapping (not only a dict) is
  recursed into, matching the Mapping annotation of input_dict. A nested
  mapping with no entries contributes no keys to the output.

  Args:
    input_dict: the (possibly nested) mapping to flatten.
    prefix: string prepended, followed by '.', to every output key. If empty,
      top-level keys are used unchanged.

  Returns:
    A new flat dict mapping dotted key paths to leaf values.
  """
  output_dict = {}
  for key, value in input_dict.items():
    nested_key = f'{prefix}.{key}' if prefix else key
    if isinstance(value, collections.abc.Mapping):
      output_dict.update(flatten_dict_keys(value, prefix=nested_key))
    else:
      output_dict[nested_key] = value
  return output_dict

# ------------------------------------------------------------------------------
# /ferminet/configs/hn.py:
# ------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""1D hydrogen chain."""

from ferminet import base_config
from ferminet.utils import system
import ml_collections


def _set_geometry(cfg: ml_collections.ConfigDict) -> ml_collections.ConfigDict:
  """Returns the config with the Hn molecule set.

  Places cfg.system.natoms hydrogen atoms on the x axis, spaced
  cfg.system.bond_length apart (in cfg.system.units) and centred on the
  origin. The electrons are split as
  (natoms // 2, natoms - natoms // 2).
  """
  start = -(cfg.system.bond_length * (cfg.system.natoms - 1)) / 2
  atom_position = lambda i: (start + i * cfg.system.bond_length, 0, 0)
  cfg.system.molecule = [
      system.Atom(symbol='H', coords=atom_position(i), units=cfg.system.units)
      for i in range(cfg.system.natoms)
  ]
  nalpha = cfg.system.natoms // 2
  cfg.system.electrons = (nalpha, cfg.system.natoms - nalpha)
  return cfg


def get_config() -> ml_collections.ConfigDict:
  """Returns config for running Hn with FermiNet.

  The chain length and spacing are controlled by cfg.system.natoms
  (default 2) and cfg.system.bond_length (default 1.4). The geometry is
  built by _set_geometry, which is installed as cfg.system.set_molecule.
  """
  cfg = base_config.default()
  cfg.system.update({
      'bond_length': 1.4,
      'natoms': 2,
  })
  with cfg.ignore_type():
    cfg.system.set_molecule = _set_geometry
    # Relative name of this config module (h4.py sets '.h4' for itself; this
    # previously also said '.h4', pointing at the wrong module).
    cfg.config_module = '.hn'
  return cfg

# ------------------------------------------------------------------------------
# /ferminet/utils/units.py:
# ------------------------------------------------------------------------------
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic definition of units and converters useful for chemistry."""

from typing import TypeVar
import numpy as np

# 1 Bohr = 0.52917721067 (12) x 10^{-10} m
# https://physics.nist.gov/cgi-bin/cuu/Value?bohrrada0
# Note: pyscf uses a slightly older definition of 0.52917721092 angstrom.
ANGSTROM_BOHR = 0.52917721067  # angstrom per bohr
BOHR_ANGSTROM = 1. / ANGSTROM_BOHR  # bohr per angstrom

# 1 Hartree = 627.509474 kcal/mol
# https://en.wikipedia.org/wiki/Hartree
KCAL_HARTREE = 627.509474  # kcal/mol per hartree
HARTREE_KCAL = 1. / KCAL_HARTREE  # hartree per kcal/mol

# Scalar or numpy array; each converter returns the same kind it was given.
NumericalLike = TypeVar('NumericalLike', float, np.ndarray)


def bohr2angstrom(x_b: NumericalLike) -> NumericalLike:
  """Converts a length from bohr to angstrom."""
  return ANGSTROM_BOHR * x_b


def angstrom2bohr(x_a: NumericalLike) -> NumericalLike:
  """Converts a length from angstrom to bohr."""
  return BOHR_ANGSTROM * x_a


def hartree2kcal(x_b: NumericalLike) -> NumericalLike:
  """Converts an energy from hartree to kcal/mol."""
  return KCAL_HARTREE * x_b


def kcal2hartree(x_a: NumericalLike) -> NumericalLike:
  """Converts an energy from kcal/mol to hartree."""
  return HARTREE_KCAL * x_a

# ------------------------------------------------------------------------------
# /ferminet/utils/multi_host.py:
# ------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic utilities."""

from absl import logging
import jax
import jax.numpy as jnp


def check_synced(obj, name):
  """Checks whether every leaf of `obj` matches device 0 on all local devices.

  Args:
    obj: PyTree with leaf nodes mapped over local devices (the leading axis of
      each leaf is indexed by local device number).
    name: the name of the object (for logging only).

  Returns:
    True if object is in sync across all devices and False otherwise. On the
    first mismatching device, the per-leaf difference norms are logged.
  """
  num_local = jax.local_device_count()
  for device in range(1, num_local):
    # Bind `device` as a default so the lambda uses this iteration's index.
    diff_norms = jax.tree.map(
        lambda leaf, d=device: jnp.linalg.norm(leaf[0] - leaf[d]), obj)
    total_norm = sum(jax.tree.leaves(diff_norms))
    if total_norm == 0.0:
      continue
    logging.info(
        '%s object is not synced across device 0 and %d. The total norm'
        ' of the difference is %.5e. For specific detail inspect '
        'the individual differences norms:\n %s.',
        name, device, total_norm, str(diff_norms)
    )
    return False
  logging.info('%s objects are synced.', name)
  return True

# ------------------------------------------------------------------------------
# /ferminet/configs/h4.py:
# ------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Four hydrogen atoms on a circle."""

import itertools
from ferminet import base_config
from ferminet.utils import system
import ml_collections
import numpy as np


def _set_geometry(cfg: ml_collections.ConfigDict) -> ml_collections.ConfigDict:
  """Returns the config with the H4 molecule set.

  The four atoms lie on a circle of radius cfg.system.radius in the z = 0
  plane, at (+-x, +-y) with x = r cos(angle / 2) and y = r sin(angle / 2).
  cfg.system.angle (in degrees) is therefore the angle subtended at the
  origin by the two atoms with positive x. Coordinates are in
  cfg.system.units.
  """
  t = np.radians(cfg.system.angle / 2)
  x = cfg.system.radius * np.cos(t)
  y = cfg.system.radius * np.sin(t)
  quadrants = itertools.product((1, -1), (1, -1))
  cfg.system.molecule = [
      system.Atom(
          symbol='H', coords=(i * x, j * y, 0.0), units=cfg.system.units)
      for i, j in quadrants
  ]

  return cfg


def get_config() -> ml_collections.ConfigDict:
  """Returns config for running H4 with FermiNet.

  The defaults are a square (angle 90) of radius 1.738 angstrom with
  electrons (2, 2). The geometry is built by _set_geometry, which is
  installed as cfg.system.set_molecule.
  """
  cfg = base_config.default()
  cfg.system.update({
      'angle': 90,
      'radius': 1.738,
      'units': 'angstrom',
      'electrons': (2, 2),
  })
  with cfg.ignore_type():
    cfg.system.set_molecule = _set_geometry
    cfg.config_module = '.h4'
  return cfg

# ------------------------------------------------------------------------------
# /ferminet/configs/excited/presets.py:
# ------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Commonly used sets of parameters for running experiments.

These are common optimization settings or network architectures that are
designed to be applied on top of the base configuration from
base_config.default().

Usage:

A flattened dict can be used directly to update a config, e.g.:

  config.update_from_flattened_dict(psiformer)
"""

from ferminet.utils import utils

# PsiFormer network settings (with a simple electron-electron Jastrow).
psiformer = utils.flatten_dict_keys({
    'network': {
        'network_type': 'psiformer',
        'determinants': 16,
        'jastrow': 'simple_ee',
        'rescale_inputs': True,
        'psiformer': {
            'num_heads': 4,
            'mlp_hidden_dims': (256,),
            'num_layers': 4,
            'use_layer_norm': True,
        }
    }
})


# FermiNet network settings: four (256, 32) hidden layers.
ferminet = utils.flatten_dict_keys({
    'network': {
        'network_type': 'ferminet',
        'determinants': 16,
        'ferminet': {
            'hidden_dims': 4 * ((256, 32),)
        }
    }
})


# Optimisation and pretraining settings shared by the excited-state configs.
excited_states = utils.flatten_dict_keys({
    'optim': {
        'clip_median': True,
        'reset_if_nan': True,
        'laplacian': 'folx',
    },
    'pretrain': {
        'basis': 'ccpvdz',
        'scf_fraction': 1.0
    }
})

# ------------------------------------------------------------------------------
# /.github/workflows/ci-build.yaml:
# ------------------------------------------------------------------------------
# name: CI
#
# on:
#   # Trigger the workflow on push or pull request,
#   # only on the main branch
#   push:
#     branches:
#       - main
#   pull_request:
#     branches:
#       - main
#   workflow_dispatch:
#
# jobs:
#   lint_and_typecheck:
#     runs-on: ubuntu-latest
#     strategy:
#       matrix:
#         python-version: ['3.11']
#     steps:
#     - uses: actions/checkout@v4
#     - name: Set up Python ${{ matrix.python-version }}
#       uses: actions/setup-python@v5
#       with:
#         python-version: ${{ matrix.python-version }}
#     - name: Install dependencies
#       run: |
#         pip install -e '.[testing]'
#     - name: Lint with
flake8 30 | run: | 31 | flake8 . 32 | - name: Lint with pylint 33 | run: | 34 | pylint --fail-under 9.5 ferminet 35 | - name: Type check with pytype 36 | run: | 37 | pytype ferminet 38 | build: 39 | name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }})" 40 | runs-on: ${{ matrix.os }} 41 | strategy: 42 | matrix: 43 | include: 44 | - name-prefix: "all tests" 45 | python-version: '3.11' 46 | os: ubuntu-latest 47 | package-overrides: "none" 48 | steps: 49 | - uses: actions/checkout@v4 50 | - name: Set up Python ${{ matrix.python-version }} 51 | uses: actions/setup-python@v5 52 | with: 53 | python-version: ${{ matrix.python-version }} 54 | - name: Install dependencies 55 | run: | 56 | pip install --upgrade pip 57 | pip install -e '.[testing]' 58 | if [ ${{ matrix.package-overrides }} != none ]; then 59 | pip install ${{ matrix.package-overrides }} 60 | fi 61 | - name: Run tests 62 | run: | 63 | python -m pytest 64 | - name: Run multi-device tests 65 | if: matrix.python-version == '3.11' 66 | run: | 67 | FERMINET_CHEX_N_CPU_DEVICES=2 python -m pytest ferminet/tests/train_test.py 68 | -------------------------------------------------------------------------------- /ferminet/configs/excited/carbon_dimer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Config to reproduce Fig. 3 from Pfau et al. (2024).""" 16 | 17 | from ferminet import base_config 18 | from ferminet.configs.excited import presets 19 | from ferminet.utils import system 20 | import ml_collections 21 | 22 | 23 | def finalise( 24 | experiment_config: ml_collections.ConfigDict) -> ml_collections.ConfigDict: 25 | """Returns the experiment config with the molecule set.""" 26 | # Equilibrium bond length is 1.244 Angstrom 27 | bond_length = experiment_config.system.equilibrium_multiple * 1.244 * 1.88973 28 | experiment_config.system.molecule = [ 29 | system.Atom('C', coords=(0, 0, bond_length / 2)), 30 | system.Atom('C', coords=(0, 0, -bond_length / 2))] 31 | return experiment_config 32 | 33 | 34 | def get_config() -> ml_collections.ConfigDict: 35 | """Returns config for running generic atoms with qmc.""" 36 | cfg = base_config.default() 37 | cfg.system.charge = 0 38 | cfg.system.delta_charge = 0.0 39 | cfg.system.molecule_name = 'C2' 40 | cfg.system.states = 8 41 | cfg.system.spin_polarisation = ml_collections.FieldReference( 42 | None, field_type=int) 43 | cfg.system.units = 'bohr' 44 | cfg.system.electrons = (6, 6) 45 | cfg.pretrain.iterations = 100_000 46 | cfg.optim.iterations = 100_000 47 | cfg.update_from_flattened_dict(presets.psiformer) 48 | cfg.update_from_flattened_dict(presets.excited_states) 49 | with cfg.ignore_type(): 50 | cfg.system.set_molecule = finalise 51 | return cfg 52 | -------------------------------------------------------------------------------- /ferminet/utils/tests/statistics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for ferminet.utils.statistics.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | from ferminet.utils import statistics 19 | import numpy as np 20 | import pandas as pd 21 | 22 | 23 | class StatisticsTest(parameterized.TestCase): 24 | 25 | @parameterized.parameters(0.1, 0.2, 0.5) 26 | def test_exponentially_weighted_stats(self, alpha): 27 | # Generate some data which (loosely) mimics converging simulation. 28 | n = 10_000 29 | data = (0.1 * np.random.uniform(size=n) + np.exp(-np.arange(n) / 100) + 30 | (1 + 0.05 * np.random.normal(size=n))) 31 | 32 | stats = [] 33 | for x in data: 34 | stat = statistics.exponentialy_weighted_stats( 35 | alpha=alpha, 36 | observation=x, 37 | previous_stats=stats[-1] if stats else None) 38 | stats.append(stat) 39 | 40 | # Exponentially weighted algorithm is equivalent to that in pandas without 41 | # bias correction or adjusting for the initial iterations: 42 | ewm = pd.Series(data).ewm(adjust=False, alpha=alpha) 43 | expected_mean = ewm.mean() 44 | expected_variance = ewm.var(bias=True) 45 | with self.subTest('Check mean'): 46 | np.testing.assert_allclose([s.mean for s in stats], expected_mean) 47 | with self.subTest('Check variance'): 48 | np.testing.assert_allclose([s.variance for s in stats], expected_variance) 49 | 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /ferminet/utils/statistics.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Simple statistics utilities.""" 15 | 16 | from typing import Generic, Optional, TypeVar, Union 17 | 18 | import attr 19 | from jax import numpy as jnp 20 | import numpy as np 21 | 22 | T = TypeVar('T', float, np.ndarray, jnp.ndarray) 23 | 24 | 25 | @attr.s(auto_attribs=True) 26 | class WeightedStats(Generic[T]): 27 | mean: T 28 | variance: T 29 | 30 | 31 | def exponentialy_weighted_stats( 32 | alpha: Union[float, T], 33 | observation: T, 34 | previous_stats: Optional[WeightedStats[T]] = None, 35 | ) -> WeightedStats[T]: 36 | """Returns the exponentially-weighted mean and variance. 37 | 38 | mu_t = (1-alpha) mu_{t-1} + alpha x_t 39 | 40 | and sigma^2_t = (1-alpha) (sigma^2_{t-1} + alpha (x_t - mu_{t-1})^2). 41 | 42 | Args: 43 | alpha: weighting factor applied to the new observation. 44 | observation: new (t-th) value to include in the mean and variance. 45 | previous_stats: previous value of the mean and variance after (t-1) 46 | observations. Pass in None to indicate no prior observations have been 47 | made.
48 | """ 49 | if previous_stats is None: 50 | return WeightedStats[T](mean=observation, variance=0.0 * observation) 51 | else: 52 | # See Incremental calculation of weighted mean and variance, Tony Finch, 53 | # https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf 54 | diff = observation - previous_stats.mean 55 | incr = alpha * diff 56 | mean = previous_stats.mean + incr 57 | variance = (1 - alpha) * (previous_stats.variance + diff * incr) 58 | return WeightedStats[T](mean=mean, variance=variance) 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited and Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Setup for pip package.""" 16 | 17 | import unittest 18 | 19 | from setuptools import find_packages 20 | from setuptools import setup 21 | 22 | REQUIRED_PACKAGES = [ 23 | 'absl-py', 24 | 'attrs', 25 | 'chex', 26 | 'h5py', 27 | 'folx @ git+https://github.com/microsoft/folx', 28 | 'jax', 29 | 'jaxlib', 30 | # TODO(b/230487443) - use released version of kfac. 
31 | 'kfac_jax @ git+https://github.com/deepmind/kfac-jax', 32 | 'ml-collections', 33 | 'optax', 34 | 'numpy', 35 | 'pandas', 36 | 'pyscf', 37 | 'pyblock', 38 | 'scipy', 39 | 'typing_extensions', 40 | ] 41 | 42 | 43 | def ferminet_test_suite(): 44 | test_loader = unittest.TestLoader() 45 | test_suite = test_loader.discover('ferminet/tests', pattern='*_test.py') 46 | return test_suite 47 | 48 | 49 | setup( 50 | name='ferminet', 51 | version='0.2', 52 | description=( 53 | 'A library to train networks to represent ground ' 54 | 'state wavefunctions of fermionic systems' 55 | ), 56 | url='https://github.com/deepmind/ferminet', 57 | author='DeepMind', 58 | author_email='no-reply@google.com', 59 | # Contained modules and scripts. 60 | entry_points={ 61 | 'console_scripts': [ 62 | 'ferminet = ferminet.main:main_wrapper', 63 | ], 64 | }, 65 | packages=find_packages(), 66 | install_requires=REQUIRED_PACKAGES, 67 | extras_require={'testing': ['flake8', 'pylint', 'pytest', 'pytype']}, 68 | platforms=['any'], 69 | license='Apache 2.0', 70 | test_suite='setup.ferminet_test_suite', 71 | ) 72 | -------------------------------------------------------------------------------- /ferminet/configs/heg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """Unpolarised 14 electron simple cubic homogeneous electron gas.""" 16 | 17 | from ferminet import base_config 18 | from ferminet.pbc import envelopes 19 | from ferminet.utils import system 20 | 21 | import numpy as np 22 | 23 | 24 | def _sc_lattice_vecs(rs: float, nelec: int) -> np.ndarray: 25 | """Returns simple cubic lattice vectors with Wigner-Seitz radius rs.""" 26 | volume = (4 / 3) * np.pi * (rs**3) * nelec 27 | length = volume**(1 / 3) 28 | return length * np.eye(3) 29 | 30 | 31 | def get_config(): 32 | """Returns config for running unpolarised 14 electron gas with FermiNet.""" 33 | # Get default options. 34 | cfg = base_config.default() 35 | cfg.system.electrons = (7, 7) 36 | # A ghost atom at the origin defines one-electron coordinate system. 37 | # Element 'X' is a dummy nucleus with zero charge 38 | cfg.system.molecule = [system.Atom("X", (0., 0., 0.))] 39 | # Pretraining is not currently implemented for systems in PBC 40 | cfg.pretrain.method = None 41 | 42 | lattice = _sc_lattice_vecs(1.0, sum(cfg.system.electrons)) 43 | kpoints = envelopes.make_kpoints(lattice, cfg.system.electrons) 44 | 45 | cfg.system.make_local_energy_fn = "ferminet.pbc.hamiltonian.local_energy" 46 | cfg.system.make_local_energy_kwargs = {"lattice": lattice, "heg": True} 47 | cfg.network.make_feature_layer_fn = ( 48 | "ferminet.pbc.feature_layer.make_pbc_feature_layer") 49 | cfg.network.make_feature_layer_kwargs = { 50 | "lattice": lattice, 51 | "include_r_ae": False 52 | } 53 | cfg.network.make_envelope_fn = ( 54 | "ferminet.pbc.envelopes.make_multiwave_envelope") 55 | cfg.network.make_envelope_kwargs = {"kpoints": kpoints} 56 | cfg.network.full_det = True 57 | return cfg 58 | -------------------------------------------------------------------------------- /ferminet/utils/tests/gto_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for gto.py.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from ferminet.utils import gto 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | import pyscf.dft 24 | import pyscf.gto 25 | import pyscf.lib 26 | 27 | 28 | class GtoTest(parameterized.TestCase): 29 | 30 | def setUp(self): 31 | super(GtoTest, self).setUp() 32 | pyscf.lib.param.TMPDIR = None 33 | pyscf.lib.num_threads(1) 34 | 35 | @parameterized.parameters([['a', True], ['a', False], ['b', False]]) 36 | def test_eval_gto(self, unit, jit): 37 | mol = pyscf.gto.M(atom='Na 0 0 -1; F 0 0 1', basis='def2-qzvp', unit=unit) 38 | solver = pyscf.dft.RKS(mol) 39 | solver.grids.build() 40 | coords = solver.grids.coords 41 | 42 | aos_pyscf = mol.eval_gto('GTOval_sph', coords) 43 | 44 | mol_jax = gto.Mol.from_pyscf_mol(mol) 45 | jax_eval_gto = jax.jit(mol_jax.eval_gto) if jit else mol_jax.eval_gto 46 | aos_jax = jax_eval_gto(coords) 47 | 48 | # Loose tolerances due to float32. With float64, these agree to better than 49 | # 1e-10. 
50 | np.testing.assert_allclose(aos_pyscf, aos_jax, atol=4.e-4, rtol=2.e-4) 51 | 52 | def test_grad_solid_harmonic(self): 53 | np.random.seed(0) 54 | r = np.random.randn(100, 3) 55 | l_max = 5 56 | 57 | jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max)) 58 | expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])), 59 | [1, 2, 0, 3]) 60 | 61 | with self.subTest('by hand'): 62 | observed = gto.grad_solid_harmonic(r, l_max) 63 | np.testing.assert_allclose(observed, expected, atol=1.e-4) 64 | with self.subTest('by jax'): 65 | observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max) 66 | np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4) 67 | 68 | if __name__ == '__main__': 69 | absltest.main() 70 | -------------------------------------------------------------------------------- /ferminet/tests/envelopes_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ferminet.envelopes.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from ferminet import envelopes 20 | import jax.numpy as jnp 21 | import numpy as np 22 | 23 | 24 | def _shape_options(dim2=None): 25 | shapes = [[(3, 4, 5), (5, 6, 4, 2)], [(3, 4, 1), (1, 6, 4, 2)], 26 | [(3, 1, 5), (5, 6, 1, 2)], [(3, 1, 1), (1, 6, 1, 2)]] 27 | for shape1, shape2 in shapes: 28 | if dim2: 29 | shape2 = shape2[:dim2] 30 | yield { 31 | 'testcase_name': f'_shape1={shape1}_shape2={shape2}', 32 | 'shapes': (shape1, shape2), 33 | } 34 | 35 | 36 | class ApplyCovarianceTest(parameterized.TestCase): 37 | 38 | @parameterized.named_parameters(_shape_options()) 39 | def test_apply_covariance(self, shapes): 40 | rng = np.random.RandomState(0).standard_normal 41 | if jnp.ones(1).dtype == jnp.float64: 42 | dtype = np.float64 43 | atol = 0 44 | else: 45 | dtype = np.float32 46 | atol = 1.e-6 47 | x = rng(shapes[0]).astype(dtype) 48 | y = rng(shapes[1]).astype(dtype) 49 | np.testing.assert_allclose( 50 | envelopes._apply_covariance(x, y), 51 | jnp.einsum('ijk,kmjn->ijmn', x, y), 52 | atol=atol, 53 | ) 54 | 55 | @parameterized.named_parameters(_shape_options(dim2=3)) 56 | def test_reduced_apply_covariance(self, shapes): 57 | rng = np.random.RandomState(0).standard_normal 58 | if jnp.ones(1).dtype == jnp.float64: 59 | dtype = np.float64 60 | atol = 0 61 | else: 62 | dtype = np.float32 63 | atol = 1.e-6 64 | x = rng(shapes[0]).astype(dtype) 65 | y = rng(shapes[1]).astype(dtype) 66 | np.testing.assert_allclose( 67 | jnp.squeeze( 68 | envelopes._apply_covariance(x, jnp.expand_dims(y, -1)), axis=-1), 69 | jnp.einsum('ijk,klj->ijl', x, y), 70 | atol=atol, 71 | ) 72 | 73 | 74 | if __name__ == '__main__': 75 | absltest.main() 76 | -------------------------------------------------------------------------------- /ferminet/configs/atom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 
DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Generic single-atom configuration for FermiNet.""" 16 | 17 | from ferminet import base_config 18 | from ferminet.utils import elements 19 | from ferminet.utils import system 20 | import ml_collections 21 | 22 | 23 | def adjust_nuclear_charge(cfg): 24 | """Sets the molecule, nuclear charge and electrons for the atom. 25 | 26 | Note: function name predates this logic but is kept for compatibility with 27 | xm_expt.py. 28 | 29 | Args: 30 | cfg: ml_collections.ConfigDict after all argument parsing. 31 | 32 | Returns: 33 | ml_collections.ConfigDict with the (possibly shifted) nuclear charge in 34 | cfg.system.molecule and, if unset, cfg.system.electrons set.
35 | """ 36 | if cfg.system.molecule: 37 | atom = cfg.system.molecule[0] 38 | else: 39 | atom = system.Atom(symbol=cfg.system.atom, coords=(0, 0, 0)) 40 | 41 | if abs(cfg.system.delta_charge) > 1.e-8: 42 | nuclear_charge = atom.charge + cfg.system.delta_charge 43 | cfg.system.molecule = [ 44 | system.Atom(atom.symbol, atom.coords, nuclear_charge) 45 | ] 46 | else: 47 | cfg.system.molecule = [atom] 48 | 49 | if not cfg.system.electrons: 50 | atomic_number = elements.SYMBOLS[atom.symbol].atomic_number 51 | if 'charge' in cfg.system: 52 | atomic_number -= cfg.system.charge 53 | if ('spin_polarisation' in cfg.system 54 | and cfg.system.spin_polarisation is not None): 55 | spin_polarisation = cfg.system.spin_polarisation 56 | else: 57 | spin_polarisation = elements.ATOMIC_NUMS[atomic_number].spin_config 58 | nalpha = (atomic_number + spin_polarisation) // 2 59 | cfg.system.electrons = (nalpha, atomic_number - nalpha) 60 | 61 | return cfg 62 | 63 | 64 | def get_config(): 65 | """Returns config for running generic atoms with qmc.""" 66 | cfg = base_config.default() 67 | cfg.system.atom = '' 68 | cfg.system.charge = 0 69 | cfg.system.delta_charge = 0.0 70 | cfg.system.spin_polarisation = ml_collections.FieldReference( 71 | None, field_type=int) 72 | with cfg.ignore_type(): 73 | cfg.system.set_molecule = adjust_nuclear_charge 74 | cfg.config_module = '.atom' 75 | return cfg 76 | -------------------------------------------------------------------------------- /ferminet/utils/writers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Writer utility classes.""" 15 | 16 | import contextlib 17 | import os 18 | from typing import Optional, Sequence 19 | 20 | from absl import logging 21 | 22 | 23 | class Writer(contextlib.AbstractContextManager): 24 | """Write data to CSV, as well as logging data to stdout if desired.""" 25 | 26 | def __init__(self, 27 | name: str, 28 | schema: Sequence[str], 29 | directory: str = 'logs/', 30 | iteration_key: Optional[str] = 't', 31 | log: bool = True): 32 | """Initialise Writer. 33 | 34 | Args: 35 | name: file name for CSV. 36 | schema: sequence of keys, corresponding to each data item. 37 | directory: directory path to write file to. 38 | iteration_key: if not None or a null string, also include the iteration 39 | index as the first column in the CSV output with the given key. 40 | log: Also log each entry to stdout. 41 | """ 42 | self._schema = schema 43 | if not os.path.isdir(directory): 44 | os.mkdir(directory) 45 | self._filename = os.path.join(directory, name + '.csv') 46 | self._iteration_key = iteration_key 47 | self._log = log 48 | 49 | def __enter__(self): 50 | self._file = open(self._filename, 'w', encoding='UTF-8') 51 | # write top row of csv 52 | if self._iteration_key: 53 | self._file.write(f'{self._iteration_key},') 54 | self._file.write(','.join(self._schema) + '\n') 55 | return self 56 | 57 | def write(self, t: int, **data): 58 | """Writes to file and stdout. 59 | 60 | Args: 61 | t: iteration index. 62 | **data: data items with keys as given in schema. 
63 | """ 64 | row = [str(data.get(key, '')) for key in self._schema] 65 | if self._iteration_key: 66 | row.insert(0, str(t)) 67 | for key in data: 68 | if key not in self._schema: 69 | raise ValueError(f'Not a recognized key for writer: {key}') 70 | 71 | # write the data to csv 72 | self._file.write(','.join(row) + '\n') 73 | 74 | # write the data to abseil logs 75 | if self._log: 76 | logging.info('Iteration %s: %s', t, data) 77 | 78 | def __exit__(self, exc_type, exc_val, exc_tb): 79 | self._file.close() 80 | -------------------------------------------------------------------------------- /ferminet/configs/hcl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example excited states config for HCl with FermiNet and pseudopotentials.""" 16 | 17 | from ferminet import base_config 18 | from ferminet.utils import system 19 | import ml_collections 20 | import pyscf 21 | 22 | 23 | def finalize(cfg): 24 | """Builds the pyscf Mole (with pseudopotentials) used for pretraining. 25 | 26 | Args: 27 | cfg: ml_collections.ConfigDict after all argument parsing. 28 | 29 | Returns: 30 | ml_collections.ConfigDict with cfg.system.pyscf_mol set to the built 31 | pyscf Mole (per-element basis and ECPs taken from cfg.system.pp).
32 | """ 33 | 34 | # Create a pyscf Mole object with pseudopotentials to be used for 35 | # pretraining and updating the config for consistency 36 | mol = pyscf.gto.Mole() 37 | mol.atom = [[atom.symbol, atom.coords] for atom in cfg.system.molecule] 38 | 39 | atoms = list(set([atom.symbol for atom in cfg.system.molecule])) 40 | pseudo_atoms = cfg.system.pp.symbols if cfg.system.use_pp else [] 41 | mol.basis = { 42 | atom: 43 | cfg.system.pp.basis if atom in pseudo_atoms else 'cc-pvdz' 44 | for atom in atoms 45 | } 46 | mol.ecp = { 47 | atom: cfg.system.pp.type 48 | for atom in atoms if atom in pseudo_atoms 49 | } 50 | 51 | mol.charge = 0 52 | mol.spin = 0 53 | mol.unit = 'angstrom' 54 | mol.build() 55 | 56 | cfg.system.pyscf_mol = mol 57 | 58 | return cfg 59 | 60 | 61 | def get_config(): 62 | """Returns config for running generic atoms with qmc.""" 63 | cfg = base_config.default() 64 | cfg.system.molecule = [ 65 | system.Atom(symbol='H', coords=(0.0, 0.0, 0.0), units='angstrom'), 66 | system.Atom(symbol='Cl', coords=(0.0, 0.0, 1.2799799), units='angstrom'), 67 | ] 68 | cfg.system.electrons = (9, 9) # Core electrons are removed automatically 69 | cfg.system.use_pp = True # Enable pseudopotentials 70 | cfg.system.pp.symbols = ['Cl'] # Indicate which atoms to apply PP to 71 | cfg.system.charge = 0 72 | cfg.system.delta_charge = 0.0 73 | cfg.system.states = 3 74 | cfg.pretrain.iterations = 10_000 75 | cfg.optim.reset_if_nan = True 76 | cfg.system.spin_polarisation = ml_collections.FieldReference( 77 | None, field_type=int) 78 | with cfg.ignore_type(): 79 | cfg.system.set_molecule = finalize 80 | cfg.config_module = '.diatomic' 81 | return cfg 82 | -------------------------------------------------------------------------------- /ferminet/configs/li_excited.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example excited states config for lithium atom with FermiNet.""" 16 | 17 | from ferminet import base_config 18 | from ferminet.utils import elements 19 | from ferminet.utils import system 20 | import ml_collections 21 | 22 | 23 | def _adjust_nuclear_charge(cfg): 24 | """Sets the molecule, nuclear charge and electrons for the atom. 25 | 26 | Note: function name predates this logic but is kept for compatibility with 27 | xm_expt.py. 28 | 29 | Args: 30 | cfg: ml_collections.ConfigDict after all argument parsing. 31 | 32 | Returns: 33 | ml_collections.ConfigDict with the (possibly shifted) nuclear charge in 34 | cfg.system.molecule and, if unset, cfg.system.electrons set.
35 | """ 36 | if cfg.system.molecule: 37 | atom = cfg.system.molecule[0] 38 | else: 39 | atom = system.Atom(symbol=cfg.system.atom, coords=(0, 0, 0)) 40 | 41 | if abs(cfg.system.delta_charge) > 1.e-8: 42 | nuclear_charge = atom.charge + cfg.system.delta_charge 43 | cfg.system.molecule = [ 44 | system.Atom(atom.symbol, atom.coords, nuclear_charge) 45 | ] 46 | else: 47 | cfg.system.molecule = [atom] 48 | 49 | if not cfg.system.electrons: 50 | atomic_number = elements.SYMBOLS[atom.symbol].atomic_number 51 | if 'charge' in cfg.system: 52 | atomic_number -= cfg.system.charge 53 | if ('spin_polarisation' in cfg.system 54 | and cfg.system.spin_polarisation is not None): 55 | spin_polarisation = cfg.system.spin_polarisation 56 | else: 57 | spin_polarisation = elements.ATOMIC_NUMS[atomic_number].spin_config 58 | nalpha = (atomic_number + spin_polarisation) // 2 59 | cfg.system.electrons = (nalpha, atomic_number - nalpha) 60 | 61 | return cfg 62 | 63 | 64 | def get_config(): 65 | """Returns config for running generic atoms with qmc.""" 66 | cfg = base_config.default() 67 | cfg.system.atom = 'Li' 68 | cfg.system.charge = 0 69 | cfg.system.delta_charge = 0.0 70 | cfg.system.states = 3 71 | cfg.pretrain.iterations = 1000 72 | cfg.optim.reset_if_nan = True 73 | cfg.observables.s2 = True 74 | cfg.observables.dipole = True 75 | cfg.observables.density = True 76 | cfg.system.spin_polarisation = ml_collections.FieldReference( 77 | None, field_type=int) 78 | with cfg.ignore_type(): 79 | cfg.system.set_molecule = _adjust_nuclear_charge 80 | cfg.config_module = '.atom' 81 | return cfg 82 | -------------------------------------------------------------------------------- /ferminet/utils/tests/units_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ferminet.utils.units.""" 16 | 17 | from absl.testing import absltest 18 | from ferminet.utils import units 19 | import numpy as np 20 | 21 | 22 | class UnitsTest(absltest.TestCase): 23 | 24 | def test_angstrom2bohr(self): 25 | self.assertAlmostEqual(units.angstrom2bohr(2), 3.77945225091, places=10) 26 | 27 | def test_angstrom2bohr_numpy(self): 28 | x = np.random.uniform(size=(3,)) 29 | x1 = units.angstrom2bohr(x) 30 | x2 = np.array([units.angstrom2bohr(v) for v in x]) 31 | np.testing.assert_allclose(x1, x2) 32 | 33 | def test_bohr2angstrom(self): 34 | self.assertAlmostEqual(units.bohr2angstrom(2), 1.05835442134, places=10) 35 | 36 | def test_bohr2angstrom_numpy(self): 37 | x = np.random.uniform(size=(3,)) 38 | x1 = units.bohr2angstrom(x) 39 | x2 = np.array([units.bohr2angstrom(v) for v in x]) 40 | np.testing.assert_allclose(x1, x2) 41 | 42 | def test_angstrom_bohr_idempotent(self): 43 | x = np.random.uniform() 44 | x1 = units.bohr2angstrom(units.angstrom2bohr(x)) 45 | self.assertAlmostEqual(x, x1, places=10) 46 | 47 | def test_bohr_angstrom_idempotent(self): 48 | x = np.random.uniform() 49 | x1 = units.angstrom2bohr(units.bohr2angstrom(x)) 50 | self.assertAlmostEqual(x, x1, places=10) 51 | 52 | def test_hartree2kcal(self): 53 | self.assertAlmostEqual(units.hartree2kcal(2), 1255.018948, places=10) 54 | 55 | def test_hartree2kcal_numpy(self): 56 | x = 
np.random.uniform(size=(3,)) 57 | x1 = units.hartree2kcal(x) 58 | x2 = np.array([units.hartree2kcal(v) for v in x]) 59 | np.testing.assert_allclose(x1, x2) 60 | 61 | def test_kcal2hartree(self): 62 | self.assertAlmostEqual(units.kcal2hartree(2), 0.00318720287, places=10) 63 | 64 | def test_kcal2hartree_numpy(self): 65 | x = np.random.uniform(size=(3,)) 66 | x1 = units.kcal2hartree(x) 67 | x2 = np.array([units.kcal2hartree(v) for v in x]) 68 | np.testing.assert_allclose(x1, x2) 69 | 70 | def test_hartree_kcal_idempotent(self): 71 | x = np.random.uniform() 72 | x1 = units.kcal2hartree(units.hartree2kcal(x)) 73 | self.assertAlmostEqual(x, x1, places=10) 74 | 75 | def test_kcal_hartree_idempotent(self): 76 | x = np.random.uniform() 77 | x1 = units.hartree2kcal(units.kcal2hartree(x)) 78 | self.assertAlmostEqual(x, x1, places=10) 79 | 80 | 81 | if __name__ == '__main__': 82 | absltest.main() 83 | -------------------------------------------------------------------------------- /ferminet/configs/li_wqmc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Generic single-atom configuration for FermiNet.""" 16 | 17 | from ferminet import base_config 18 | from ferminet.utils import elements 19 | from ferminet.utils import system 20 | import ml_collections 21 | 22 | 23 | def _adjust_nuclear_charge(cfg): 24 | """Sets the molecule, nuclear charge electrons for the atom. 25 | 26 | Note: function name predates this logic but is kept for compatibility with 27 | xm_expt.py. 28 | 29 | Args: 30 | cfg: ml_collections.ConfigDict after all argument parsing. 31 | 32 | Returns: 33 | ml_collections.ConfictDict with the nuclear charge for the atom in 34 | cfg.system.molecule and cfg.system.charge appropriately set. 35 | """ 36 | if cfg.system.molecule: 37 | atom = cfg.system.molecule[0] 38 | else: 39 | atom = system.Atom(symbol=cfg.system.atom, coords=(0, 0, 0)) 40 | 41 | if abs(cfg.system.delta_charge) > 1.0e-8: 42 | nuclear_charge = atom.charge + cfg.system.delta_charge 43 | cfg.system.molecule = [ 44 | system.Atom(atom.symbol, atom.coords, nuclear_charge) 45 | ] 46 | else: 47 | cfg.system.molecule = [atom] 48 | 49 | if not cfg.system.electrons: 50 | atomic_number = elements.SYMBOLS[atom.symbol].atomic_number 51 | if 'charge' in cfg.system: 52 | atomic_number -= cfg.system.charge 53 | if ( 54 | 'spin_polarisation' in cfg.system 55 | and cfg.system.spin_polarisation is not None 56 | ): 57 | spin_polarisation = cfg.system.spin_polarisation 58 | else: 59 | spin_polarisation = elements.ATOMIC_NUMS[atomic_number].spin_config 60 | nalpha = (atomic_number + spin_polarisation) // 2 61 | cfg.system.electrons = (nalpha, atomic_number - nalpha) 62 | 63 | return cfg 64 | 65 | 66 | def get_config(): 67 | """Returns config for running generic atoms with qmc.""" 68 | cfg = base_config.default() 69 | cfg.system.atom = 'Li' 70 | cfg.system.charge = 0 71 | cfg.system.delta_charge = 0.0 72 | cfg.system.spin_polarisation = ml_collections.FieldReference( 73 | None, field_type=int 74 | ) 75 | with cfg.ignore_type(): 76 | 
cfg.system.set_molecule = _adjust_nuclear_charge 77 | cfg.config_module = '.atom' 78 | cfg.network.network_type = 'psiformer' 79 | cfg.optim.iterations = 10_000 80 | cfg.optim.lr.delay = 5_000 81 | cfg.optim.clip_median = True 82 | cfg.debug.deterministic = True 83 | cfg.optim.kfac.norm_constraint = 1e-3 84 | cfg.optim.objective = 'wqmc' 85 | return cfg 86 | -------------------------------------------------------------------------------- /ferminet/tests/network_blocks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ferminet.network_blocks.""" 16 | 17 | import itertools 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from ferminet import network_blocks 22 | import numpy as np 23 | 24 | 25 | class NetworkBlocksTest(parameterized.TestCase): 26 | 27 | @parameterized.parameters([ 28 | {'sizes': [], 'expected_indices': []}, 29 | {'sizes': [3], 'expected_indices': []}, 30 | {'sizes': [3, 0], 'expected_indices': [3]}, 31 | {'sizes': [3, 6], 'expected_indices': [3]}, 32 | {'sizes': [3, 6, 0], 'expected_indices': [3, 9]}, 33 | {'sizes': [2, 0, 6], 'expected_indices': [2, 2]}, 34 | ]) 35 | def test_array_partitions(self, sizes, expected_indices): 36 | self.assertEqual(network_blocks.array_partitions(sizes), expected_indices) 37 | 38 | @parameterized.parameters( 39 | {'shape': shape} for shape in [(1, 1, 1), (10, 2, 2), (10, 3, 3)]) 40 | def test_slogdet(self, shape, dtype=np.float32): 41 | a = np.random.normal(size=shape).astype(dtype) 42 | s1, ld1 = network_blocks.slogdet(a) 43 | s2, ld2 = np.linalg.slogdet(a) 44 | np.testing.assert_allclose(s1, s2, atol=1E-5, rtol=1E-5) 45 | np.testing.assert_allclose(ld1, ld2, atol=1E-5, rtol=1E-5) 46 | 47 | @parameterized.parameters([ 48 | { 49 | 'dims': [4] 50 | }, 51 | { 52 | 'dims': [4, 3] 53 | }, 54 | { 55 | 'dims': [4, 3, 1] 56 | }, 57 | ]) 58 | def test_split_into_blocks(self, dims): 59 | 60 | trailing_dims = [5] 61 | shape = [sum(dims), sum(dims)] + trailing_dims 62 | xs = np.random.uniform(size=shape).astype(np.float32) 63 | blocks = network_blocks.split_into_blocks(xs, dims) 64 | 65 | expected_shapes = [ 66 | list(dims) + trailing_dims for dims in itertools.product(dims, dims) 67 | ] 68 | shapes = [list(block.shape) for block in blocks] 69 | with self.subTest('check shapes'): 70 | self.assertEqual(shapes, expected_shapes) 71 | 72 | # each group of N blocks was split along axis=1. 
73 | blocks1 = [ 74 | np.concatenate(blocks[i:i + len(dims)], axis=1) 75 | for i in range(0, len(blocks), len(dims)) 76 | ] 77 | # groups were split along axis=0. 78 | reconstructed_xs = np.concatenate(blocks1, axis=0) 79 | with self.subTest('check can reconstruct original array'): 80 | np.testing.assert_allclose(xs, reconstructed_xs) 81 | 82 | 83 | if __name__ == '__main__': 84 | absltest.main() 85 | -------------------------------------------------------------------------------- /ferminet/configs/excited/benzene.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Config to reproduce Fig. 6 from Pfau et al. 
(2024).""" 16 | 17 | from ferminet import base_config 18 | from ferminet.configs.excited import presets 19 | from ferminet.utils import system 20 | import ml_collections 21 | 22 | 23 | _GEOM = """C 0.00000000 2.63144965 0.00000000 24 | C -2.27890225 1.31572483 0.00000000 25 | C -2.27890225 -1.31572483 0.00000000 26 | C 0.00000000 -2.63144965 0.00000000 27 | C 2.27890225 -1.31572483 0.00000000 28 | C 2.27890225 1.31572483 0.00000000 29 | H -4.04725813 2.33668557 0.00000000 30 | H -4.04725813 -2.33668557 0.00000000 31 | H -0.00000000 -4.67337115 0.00000000 32 | H 4.04725813 -2.33668557 0.00000000 33 | H 4.04725813 2.33668557 0.00000000 34 | H 0.00000000 4.67337115 0.00000000""".split('\n') 35 | 36 | 37 | def finalise( 38 | experiment_config: ml_collections.ConfigDict) -> ml_collections.ConfigDict: 39 | """Returns the experiment config with the molecule commpletely set.""" 40 | molecule = [] 41 | for atom in _GEOM: 42 | element, x, y, z = atom.split() 43 | coords = [float(xx) for xx in (x, y, z)] 44 | molecule.append(system.Atom(symbol=element, coords=coords, units='bohr')) 45 | 46 | if not experiment_config.system.electrons: # Don't override if already set 47 | nelectrons = int(sum(atom.charge for atom in molecule)) 48 | na = nelectrons // 2 49 | experiment_config.system.electrons = (na, nelectrons - na) 50 | experiment_config.system.molecule = molecule 51 | return experiment_config 52 | 53 | 54 | def get_config() -> ml_collections.ConfigDict: 55 | """Returns config for running generic atoms with qmc.""" 56 | cfg = base_config.default() 57 | cfg.system.charge = 0 58 | cfg.system.delta_charge = 0.0 59 | cfg.system.states = 6 60 | cfg.system.spin_polarisation = ml_collections.FieldReference( 61 | None, field_type=int) 62 | # While this value was used in the paper, it can be lowered. 
63 | cfg.pretrain.iterations = 100_000 64 | cfg.mcmc.blocks = 4 65 | # While this envelope was used in the paper, it can be replaced with the 66 | # default 'isotropic' envelope without any noticeable change in the results. 67 | cfg.network.envelope_type = 'bottleneck' 68 | cfg.network.num_envelopes = 32 69 | cfg.update_from_flattened_dict(presets.excited_states) 70 | with cfg.ignore_type(): 71 | cfg.system.set_molecule = finalise 72 | 73 | return cfg 74 | -------------------------------------------------------------------------------- /ferminet/jastrows.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Multiplicative Jastrow factors.""" 16 | 17 | import enum 18 | from typing import Any, Callable, Iterable, Mapping, Union 19 | 20 | import jax.numpy as jnp 21 | 22 | ParamTree = Union[jnp.ndarray, Iterable['ParamTree'], Mapping[Any, 'ParamTree']] 23 | 24 | 25 | class JastrowType(enum.Enum): 26 | """Available multiplicative Jastrow factors.""" 27 | 28 | NONE = enum.auto() 29 | SIMPLE_EE = enum.auto() 30 | 31 | 32 | def _jastrow_ee( 33 | r_ee: jnp.ndarray, 34 | params: ParamTree, 35 | nspins: tuple[int, int], 36 | jastrow_fun: Callable[[jnp.ndarray, float, jnp.ndarray], jnp.ndarray], 37 | ) -> jnp.ndarray: 38 | """Jastrow factor for electron-electron cusps.""" 39 | r_ees = [ 40 | jnp.split(r, nspins[0:1], axis=1) 41 | for r in jnp.split(r_ee, nspins[0:1], axis=0) 42 | ] 43 | r_ees_parallel = jnp.concatenate([ 44 | r_ees[0][0][jnp.triu_indices(nspins[0], k=1)], 45 | r_ees[1][1][jnp.triu_indices(nspins[1], k=1)], 46 | ]) 47 | 48 | if r_ees_parallel.shape[0] > 0: 49 | jastrow_ee_par = jnp.sum( 50 | jastrow_fun(r_ees_parallel, 0.25, params['ee_par']) 51 | ) 52 | else: 53 | jastrow_ee_par = jnp.asarray(0.0) 54 | 55 | if r_ees[0][1].shape[0] > 0: 56 | jastrow_ee_anti = jnp.sum(jastrow_fun(r_ees[0][1], 0.5, params['ee_anti'])) 57 | else: 58 | jastrow_ee_anti = jnp.asarray(0.0) 59 | 60 | return jastrow_ee_anti + jastrow_ee_par 61 | 62 | 63 | def make_simple_ee_jastrow() -> ...: 64 | """Creates a Jastrow factor for electron-electron cusps.""" 65 | 66 | def simple_ee_cusp_fun( 67 | r: jnp.ndarray, cusp: float, alpha: jnp.ndarray 68 | ) -> jnp.ndarray: 69 | """Jastrow function satisfying electron cusp condition.""" 70 | return -(cusp * alpha**2) / (alpha + r) 71 | 72 | def init() -> Mapping[str, jnp.ndarray]: 73 | params = {} 74 | params['ee_par'] = jnp.ones( 75 | shape=1, 76 | ) 77 | params['ee_anti'] = jnp.ones( 78 | shape=1, 79 | ) 80 | return params 81 | 82 | def apply( 83 | r_ee: jnp.ndarray, 84 | params: ParamTree, 85 | nspins: tuple[int, int], 86 | ) -> 
jnp.ndarray: 87 | """Jastrow factor for electron-electron cusps.""" 88 | return _jastrow_ee(r_ee, params, nspins, jastrow_fun=simple_ee_cusp_fun) 89 | 90 | return init, apply 91 | 92 | 93 | def get_jastrow(jastrow: JastrowType) -> ...: 94 | jastrow_init, jastrow_apply = None, None 95 | if jastrow == JastrowType.SIMPLE_EE: 96 | jastrow_init, jastrow_apply = make_simple_ee_jastrow() 97 | elif jastrow != JastrowType.NONE: 98 | raise ValueError(f'Unknown Jastrow Factor type: {jastrow}') 99 | 100 | return jastrow_init, jastrow_apply 101 | -------------------------------------------------------------------------------- /ferminet/pbc/tests/hamiltonian_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """Tests for ferminet.pbc.hamiltonian.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from ferminet import base_config 20 | from ferminet import networks 21 | from ferminet.pbc import envelopes 22 | from ferminet.pbc import feature_layer as pbc_feature_layer 23 | from ferminet.pbc import hamiltonian 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | 28 | 29 | class PbcHamiltonianTest(parameterized.TestCase): 30 | 31 | def test_periodicity(self): 32 | cfg = base_config.default() 33 | 34 | nspins = (6, 5) 35 | atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) 36 | natom = atoms.shape[0] 37 | charges = jnp.asarray([2, 5, 7]) 38 | spins = np.ones(shape=(1,)) 39 | key = jax.random.PRNGKey(42) 40 | key, subkey = jax.random.split(key) 41 | xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) 42 | 43 | feature_layer = pbc_feature_layer.make_pbc_feature_layer( 44 | natom, nspins, ndim=3, lattice=jnp.eye(3), include_r_ae=False 45 | ) 46 | 47 | kpoints = envelopes.make_kpoints(jnp.eye(3), nspins) 48 | 49 | network = networks.make_fermi_net( 50 | nspins, 51 | charges, 52 | envelope=envelopes.make_multiwave_envelope(kpoints), 53 | feature_layer=feature_layer, 54 | bias_orbitals=cfg.network.bias_orbitals, 55 | full_det=cfg.network.full_det, 56 | **cfg.network.ferminet, 57 | ) 58 | 59 | key, subkey = jax.random.split(key) 60 | params = network.init(subkey) 61 | 62 | local_energy = hamiltonian.local_energy( 63 | f=network.apply, 64 | charges=charges, 65 | nspins=nspins, 66 | use_scan=False, 67 | lattice=jnp.eye(3), 68 | heg=False, 69 | ) 70 | 71 | data = networks.FermiNetData( 72 | positions=xs.flatten(), spins=spins, atoms=atoms, charges=charges 73 | ) 74 | 75 | key, subkey = jax.random.split(key) 76 | e1, _ = local_energy(params, subkey, data) 77 | 78 | # Select 
random electron coordinate to displace by a random lattice vec 79 | key, subkey = jax.random.split(key) 80 | e_idx = jax.random.randint(subkey, (1,), 0, xs.shape[0]) 81 | key, subkey = jax.random.split(key) 82 | randvec = jax.random.randint(subkey, (3,), 0, 100).astype(jnp.float32) 83 | xs = xs.at[e_idx].add(randvec) 84 | 85 | data2 = networks.FermiNetData( 86 | positions=xs.flatten(), spins=spins, atoms=atoms, charges=charges 87 | ) 88 | 89 | key, subkey = jax.random.split(key) 90 | e2, _ = local_energy(params, subkey, data2) 91 | 92 | atol, rtol = 4.e-3, 4.e-3 93 | np.testing.assert_allclose(e1, e2, atol=atol, rtol=rtol) 94 | 95 | 96 | if __name__ == '__main__': 97 | absltest.main() 98 | -------------------------------------------------------------------------------- /ferminet/pbc/tests/features_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """Tests for ferminet.pbc.feature_layer.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from ferminet import networks 20 | from ferminet.pbc import feature_layer as pbc_feature_layer 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | 25 | 26 | class FeatureLayerTest(parameterized.TestCase): 27 | 28 | @parameterized.parameters([True, False]) 29 | def test_shape(self, heg): 30 | """Asserts that output shape of apply matches what is expected by init.""" 31 | nspins = (6, 5) 32 | atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) 33 | natom = atoms.shape[0] 34 | key = jax.random.PRNGKey(42) 35 | key, subkey = jax.random.split(key) 36 | xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) 37 | 38 | feature_layer = pbc_feature_layer.make_pbc_feature_layer( 39 | natom, nspins, 3, lattice=jnp.eye(3), include_r_ae=heg 40 | ) 41 | 42 | dims, params = feature_layer.init() 43 | ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) 44 | 45 | ae_features, ee_features = feature_layer.apply( 46 | ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) 47 | 48 | assert dims[0] == ae_features.shape[-1] 49 | assert dims[1] == ee_features.shape[-1] 50 | 51 | def test_periodicity(self): 52 | nspins = (6, 5) 53 | atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) 54 | natom = atoms.shape[0] 55 | key = jax.random.PRNGKey(42) 56 | key, subkey = jax.random.split(key) 57 | xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) 58 | 59 | feature_layer = pbc_feature_layer.make_pbc_feature_layer( 60 | natom, nspins, 3, lattice=jnp.eye(3), include_r_ae=False 61 | ) 62 | 63 | _, params = feature_layer.init() 64 | ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) 65 | 66 | ae_features_1, ee_features_1 = feature_layer.apply( 67 | ae=ae, 
r_ae=r_ae, ee=ee, r_ee=r_ee, **params) 68 | 69 | # Select random electron coordinate to displace by a random lattice vector 70 | key, subkey = jax.random.split(key) 71 | e_idx = jax.random.randint(subkey, (1,), 0, xs.shape[0]) 72 | key, subkey = jax.random.split(key) 73 | randvec = jax.random.randint(subkey, (3,), 0, 100).astype(jnp.float32) 74 | xs = xs.at[e_idx].add(randvec) 75 | 76 | ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) 77 | 78 | ae_features_2, ee_features_2 = feature_layer.apply( 79 | ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) 80 | 81 | atol, rtol = 4.e-3, 4.e-3 82 | np.testing.assert_allclose( 83 | ae_features_1, ae_features_2, atol=atol, rtol=rtol) 84 | np.testing.assert_allclose( 85 | ee_features_1, ee_features_2, atol=atol, rtol=rtol) 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /ferminet/utils/tests/scf_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ferminet.utils.scf.""" 16 | 17 | from typing import List, Tuple 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from ferminet.utils import scf 21 | from ferminet.utils import system 22 | import numpy as np 23 | import pyscf 24 | 25 | 26 | class ScfTest(parameterized.TestCase): 27 | 28 | def setUp(self): 29 | super().setUp() 30 | # disable use of temp directory in pyscf. 31 | # Test calculations are small enough to fit in RAM and we don't need 32 | # checkpoint files. 33 | pyscf.lib.param.TMPDIR = None 34 | 35 | @parameterized.parameters( 36 | { 37 | 'molecule': [system.Atom('He', (0, 0, 0))], 38 | 'nelectrons': (1, 1) 39 | }, 40 | { 41 | 'molecule': [system.Atom('N', (0, 0, 0))], 42 | 'nelectrons': (5, 2) 43 | }, 44 | { 45 | 'molecule': [system.Atom('N', (0, 0, 0))], 46 | 'nelectrons': (5, 3) 47 | }, 48 | { 49 | 'molecule': [system.Atom('N', (0, 0, 0))], 50 | 'nelectrons': (4, 2) 51 | }, 52 | { 53 | 'molecule': [system.Atom('O', (0, 0, 0))], 54 | 'nelectrons': (5, 3), 55 | 'restricted': False, 56 | }, 57 | { 58 | 'molecule': [ 59 | system.Atom('N', (0, 0, 0)), 60 | system.Atom('N', (0, 0, 1.4)) 61 | ], 62 | 'nelectrons': (7, 7) 63 | }, 64 | { 65 | 'molecule': [ 66 | system.Atom('O', (0, 0, 0)), 67 | system.Atom('O', (0, 0, 1.4)) 68 | ], 69 | 'nelectrons': (9, 7), 70 | 'restricted': False, 71 | }, 72 | ) 73 | def test_scf_interface(self, 74 | molecule: List[system.Atom], 75 | nelectrons: Tuple[int, int], 76 | restricted: bool = True): 77 | """Tests SCF interface to a pyscf calculation. 78 | 79 | pyscf has its own tests so only check that we can run calculations over 80 | atoms and simple diatomics using the interface in ferminet.scf. 81 | 82 | Args: 83 | molecule: List of system.Atom objects giving atoms in the molecule. 84 | nelectrons: Tuple containing number of alpha and beta electrons. 
85 | restricted: If true, run a restricted Hartree-Fock calculation, otherwise 86 | run an unrestricted Hartree-Fock calculation. 87 | """ 88 | npts = 100 89 | xs = np.random.randn(npts, 3) 90 | hf = scf.Scf(molecule=molecule, 91 | nelectrons=nelectrons, 92 | restricted=restricted) 93 | hf.run() 94 | mo_vals = hf.eval_mos(xs) 95 | self.assertLen(mo_vals, 2) # alpha-spin orbitals and beta-spin orbitals. 96 | for spin_mo_vals in mo_vals: 97 | # Evaluate npts points on M orbitals/functions - (npts, M) array. 98 | self.assertEqual(spin_mo_vals.shape, (npts, hf._mol.nao_nr())) 99 | 100 | 101 | if __name__ == '__main__': 102 | absltest.main() 103 | -------------------------------------------------------------------------------- /ferminet/configs/diatomic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Diatomic molecule config for FermiNet.""" 16 | 17 | import re 18 | 19 | from ferminet import base_config 20 | from ferminet.utils import system 21 | 22 | 23 | # Default bond lengths in angstrom for a few diatomics of interest. 24 | # N2 bond length from the G3 dataset. Ref: 25 | # 1. http://www.cse.anl.gov/OldCHMwebsiteContent/compmat/comptherm.htm 26 | # 2. L. A. Curtiss, P. C. Redfern, K. Raghavachari, and J. A. Pople, 27 | # J. Chem. Phys, 109, 42 (1998). 
28 | BOND_LENGTHS = { 29 | 'BeH': 1.348263, 30 | 'CN': 1.134797, 31 | 'ClF': 1.659091, 32 | 'F2': 1.420604, 33 | 'H2': 0.737164, 34 | 'HCl': 1.2799799, 35 | 'Li2': 2.77306, 36 | 'LiH': 1.639999, 37 | 'N2': 1.129978, 38 | 'NH': 1.039428, 39 | 'CO': 1.150338, 40 | 'BH': 1.2324, # Not in G3 41 | 'PN': 1.491, # Not in G3 42 | 'AlH': 1.648, # Not in G3 43 | 'AlN': 1.786, # Not in G3 44 | } 45 | 46 | # Default spin polarisation for a few diatomics of interest. 47 | # Otherwise default to either singlet (doublet) for even (odd) numbers of 48 | # electrons. Units: number of unpaired electrons. 49 | SPIN_POLARISATION = { 50 | 'B2': 2, 51 | 'O2': 2, 52 | 'NH': 2, 53 | 'AlN': 2, 54 | } 55 | 56 | 57 | def molecule(cfg): 58 | """Creates molecule in cfg.""" 59 | if cfg.system.molecule_name.endswith('2'): 60 | atom1 = atom2 = cfg.system.molecule_name.strip('2') 61 | else: 62 | atom1, atom2 = re.findall('[A-Z][a-z]*', cfg.system.molecule_name) 63 | if cfg.system.bond_length < 0: 64 | if cfg.system.molecule_name in BOND_LENGTHS: 65 | cfg.system.bond_length = BOND_LENGTHS[cfg.system.molecule_name] 66 | else: 67 | raise ValueError('bond length not set.') 68 | pos = (cfg.system.bond_length * cfg.system.bond_length_multiple) / 2 69 | atom_coords = ((-pos, 0., 0.), (pos, 0., 0.)) 70 | cfg.system.molecule = [ 71 | system.Atom(symbol=atom, coords=coord, units=cfg.system.units) 72 | for atom, coord in zip((atom1, atom2), atom_coords) 73 | ] 74 | 75 | if any( 76 | abs(atom.charge - round(atom.charge)) > 1e-6 77 | for atom in cfg.system.molecule): 78 | raise RuntimeError( 79 | 'Cannot set the number of electrons for a fractional charge atom.') 80 | electrons = sum(int(round(atom.charge)) for atom in cfg.system.molecule) 81 | 82 | if not cfg.system.electrons: 83 | if cfg.system.molecule_name in SPIN_POLARISATION: 84 | spin = SPIN_POLARISATION[cfg.system.molecule_name] 85 | else: 86 | spin = electrons % 2 87 | nalpha = (electrons + spin) // 2 88 | cfg.system.electrons = (nalpha, electrons - 
nalpha) 89 | 90 | return cfg 91 | 92 | 93 | def get_config(): 94 | """Returns the config for running a diatomic molecule with qmc.""" 95 | cfg = base_config.default() 96 | # Can specify homonuclear diatomics using X2 or heteronuclear diaomics using 97 | # XY. 98 | cfg.system.molecule_name = 'N2' 99 | 100 | cfg.system.bond_length = -1.0 101 | cfg.system.units = 'angstrom' 102 | cfg.system.bond_length_multiple = 1.0 103 | with cfg.ignore_type(): 104 | cfg.system.set_molecule = molecule 105 | cfg.config_module = '.diatomic' 106 | 107 | return cfg 108 | -------------------------------------------------------------------------------- /ferminet/utils/tests/system_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ferminet.utils.system.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from ferminet.utils import system 20 | from ferminet.utils import units 21 | import numpy as np 22 | import pyscf 23 | 24 | 25 | class SystemAtomCoordsTest(absltest.TestCase): 26 | 27 | def test_atom_coords(self): 28 | xs = np.random.uniform(size=3) 29 | atom = system.Atom(symbol='H', coords=xs, units='angstrom') 30 | np.testing.assert_allclose(atom.coords / xs, [units.BOHR_ANGSTROM]*3) 31 | np.testing.assert_allclose(atom.coords_angstrom, xs) 32 | 33 | def test_atom_units(self): 34 | system.Atom(symbol='H', coords=[1, 2, 3], units='bohr') 35 | system.Atom(symbol='H', coords=[1, 2, 3], units='angstrom') 36 | with self.assertRaises(ValueError): 37 | system.Atom(symbol='H', coords=[1, 2, 3], units='dummy') 38 | 39 | 40 | class PyscfConversionTest(parameterized.TestCase): 41 | 42 | @parameterized.parameters([ 43 | { 44 | 'mol_string': 'H 0 0 -1; H 0 0 1' 45 | }, 46 | { 47 | 'mol_string': 'O 0 0 0; H 0 1 0; H 0 0 1' 48 | }, 49 | { 50 | 'mol_string': 'H 0 0 0; Cl 0 0 1.1' 51 | }, 52 | ]) 53 | def test_conversion_pyscf(self, mol_string): 54 | mol = pyscf.gto.Mole() 55 | mol.build( 56 | atom=mol_string, 57 | basis='sto-3g', unit='bohr') 58 | cfg = system.pyscf_mol_to_internal_representation(mol) 59 | # Assert that the alpha and beta electrons are the same 60 | self.assertEqual(mol.nelec, cfg.system.electrons) 61 | # Assert that the basis are the same 62 | self.assertEqual(mol.basis, cfg.pretrain.basis) 63 | # Assert that atom symbols are the same 64 | self.assertEqual([mol.atom_symbol(i) for i in range(mol.natm)], 65 | [atom.symbol for atom in cfg.system.molecule]) 66 | # Assert that atom coordinates are the same 67 | pyscf_coords = [mol.atom_coords()[i] for i in range(mol.natm)] 68 | internal_coords = [np.array(atom.coords) for atom in cfg.system.molecule] 69 | np.testing.assert_allclose(pyscf_coords, internal_coords) 70 | 71 | def 
test_conversion_pyscf_ang(self): 72 | mol = pyscf.gto.Mole() 73 | mol.build( 74 | atom='H 0 0 -1; H 0 0 1', 75 | basis='sto-3g', unit='ang') 76 | cfg = system.pyscf_mol_to_internal_representation(mol) 77 | # Assert that the coordinates are now in bohr internally 78 | bohr_coords = [[0, 0, -units.BOHR_ANGSTROM], [0, 0, units.BOHR_ANGSTROM]] 79 | np.testing.assert_allclose([atom.coords for atom in cfg.system.molecule], 80 | bohr_coords) 81 | # Assert that the alpha and beta electrons are the same 82 | self.assertEqual(mol.nelec, cfg.system.electrons) 83 | # Assert that the basis are the same 84 | self.assertEqual(mol.basis, cfg.pretrain.basis) 85 | # Assert that atom symbols are the same 86 | self.assertEqual([mol.atom_symbol(i) for i in range(mol.natm)], 87 | [atom.symbol for atom in cfg.system.molecule]) 88 | # Assert that atom coordinates are the same 89 | pyscf_coords = [mol.atom_coords()[i] for i in range(mol.natm)] 90 | internal_coords = [np.array(atom.coords) for atom in cfg.system.molecule] 91 | np.testing.assert_allclose(pyscf_coords, internal_coords) 92 | 93 | 94 | if __name__ == '__main__': 95 | absltest.main() 96 | -------------------------------------------------------------------------------- /ferminet/utils/system.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Functions to create different kinds of systems."""

from typing import Sequence
import attr
from ferminet.utils import elements
from ferminet.utils import units as unit_conversion
import ml_collections
import numpy as np
import pyscf


@attr.s
class Atom:
  """Atom information for Hamiltonians.

  The nuclear charge is inferred from the symbol if not given, in which case the
  symbol must be the IUPAC symbol of the desired element.

  Attributes:
    symbol: Element symbol.
    coords: Atomic coordinates, always in bohr after initialisation. The attrs
      converter stores them as a tuple of floats; when `units` is 'angstrom'
      they are converted to bohr in `__attrs_post_init__` and a list of floats
      is assigned instead. Default: place atom at origin.
    charge: Nuclear charge. Default: nuclear charge (atomic number) of atom of
      the given name.
    atomic_number: Atomic number associated with element. Default: atomic number
      of element of the given symbol. Should match charge unless fractional
      nuclear charges are being used.
    units: String giving units of coords. Either bohr or angstrom. Default:
      bohr. If angstrom, coords are converted to be in bohr and units to the
      string 'bohr'.
    coords_angstrom: list of atomic coordinates in angstrom.
    coords_array: Numpy array of atomic coordinates in bohr. Built on first
      access and cached; it is not refreshed if `coords` is reassigned later.
    element: elements.Element corresponding to the symbol.
  """
  symbol = attr.ib(type=str)
  coords = attr.ib(
      type=Sequence[float],
      converter=lambda xs: tuple(float(x) for x in xs),
      default=(0.0, 0.0, 0.0))
  charge = attr.ib(type=float, converter=float)
  atomic_number = attr.ib(type=int, converter=int)
  units = attr.ib(
      type=str,
      default='bohr',
      validator=attr.validators.in_(['bohr', 'angstrom']))

  @charge.default
  def _set_default_charge(self):
    return self.element.atomic_number

  @atomic_number.default
  def _set_default_atomic_number(self):
    return self.element.atomic_number

  def __attrs_post_init__(self):
    # Normalise to bohr so that users of Atom never need to inspect `units`.
    if self.units == 'angstrom':
      self.coords = [unit_conversion.angstrom2bohr(x) for x in self.coords]
      self.units = 'bohr'

  @property
  def coords_angstrom(self):
    return [unit_conversion.bohr2angstrom(x) for x in self.coords]

  @property
  def coords_array(self):
    # Lazily computed and cached on the instance; never invalidated.
    if not hasattr(self, '_coords_arr'):
      self._coords_arr = np.array(self.coords)
    return self._coords_arr

  @property
  def element(self):
    return elements.SYMBOLS[self.symbol]


def pyscf_mol_to_internal_representation(
    mol: pyscf.gto.Mole) -> ml_collections.ConfigDict:
  """Converts a PySCF Mole object to an internal representation.

  Args:
    mol: Mole object describing the system of interest. Note that
      `mol.build()` is called on it, so it is modified in place.

  Returns:
    A ConfigDict with the fields required to describe the system set:
    `system.molecule` (a list of Atom built from `mol.atom_symbol(i)` and
    `mol.atom_coord(i)`, the coordinates being taken as bohr, the Atom
    default), `system.electrons` (`mol.nelec`) and `pretrain.basis`
    (`mol.basis` if it is a string, otherwise None).
  """
  # Ensure Mole is built so all attributes are appropriately set.
  mol.build()
  atoms = [
      Atom(mol.atom_symbol(i), mol.atom_coord(i))
      for i in range(mol.natm)
  ]
  return ml_collections.ConfigDict({
      'system': {
          'molecule': atoms,
          'electrons': mol.nelec,
      },
      'pretrain': {
          # If mol.basis isn't a string, assume that mol is passed into
          # pretraining as well and pretraining uses the basis already set in
          # mol, rather than complicating the configuration here.
          'basis': mol.basis if isinstance(mol.basis, str) else None,
      },
  })

# ------------------------------------------------------------------------------
# /ferminet/configs/excited/twisted_ethylene.py:
# ------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config to reproduce Fig. 4 from Pfau et al. (2024)."""

from ferminet import base_config
from ferminet.configs.excited import presets
from ferminet.utils import system
import ml_collections
import numpy as np


# Geometries from Entwistle, Schätzle, Erdman, Hermann and Noé (2023), taken
# from Barbatti, Paier and Lischka (2004). All geometries are in Angstroms.
_SYSTEMS = {
    'planar': ['C -0.675 0.0 0.0',
               'C 0.675 0.0 0.0',
               'H -1.2429 0.0 -0.93037',
               'H -1.2429 0.0 0.93037',
               'H 1.2429 0.0 -0.93037',
               'H 1.2429 0.0 0.93037'],
    'twisted': ['C -0.6885 0.0 0.0',
                'C 0.6885 0.0 0.0',
                'H -1.307207 0.0 -0.915547',
                'H -1.307207 0.0 0.915547',
                'H 1.307207 -0.915547 0.0',
                'H 1.307207 0.915547 0.0'],
}

# Conversion factor applied to the tabulated angstrom coordinates.
_ANGSTROM_TO_BOHR = 1.88973

# Indices (into each geometry above) of the hydrogens that get twisted.
_TWISTED_HYDROGENS = (4, 5)


def finalise(
    experiment_config: ml_collections.ConfigDict) -> ml_collections.ConfigDict:
  """Attaches the (possibly distorted) ethylene geometry to the config.

  The geometry named by `system.molecule_name` is converted to bohr. If
  `system.twist.tau` (degrees) is non-zero, hydrogens 4 and 5 are rotated by
  tau about the x axis through the origin. If `system.twist.phi` (degrees) is
  non-zero, they are then rotated by phi about the y direction, pivoting on
  the second carbon atom. `system.electrons` is filled in only when it is not
  already set.

  Args:
    experiment_config: config with `system.molecule_name` and `system.twist`.

  Returns:
    The same config with `system.molecule` (and possibly `system.electrons`)
    set.
  """
  geometry = _SYSTEMS[experiment_config.system.molecule_name]
  twist = experiment_config.system.twist

  torsion = None
  if twist.tau != 0:
    tau = twist.tau * np.pi / 180.0  # deg to rad
    torsion = np.array([[1.0, 0.0, 0.0],
                        [0.0, np.cos(tau), -np.sin(tau)],
                        [0.0, np.sin(tau), np.cos(tau)]])

  pyramidalisation = None
  pivot = None
  if twist.phi != 0:
    phi = twist.phi * np.pi / 180.0  # deg to rad
    pyramidalisation = np.array([[np.cos(phi), 0.0, -np.sin(phi)],
                                 [0.0, 1.0, 0.0],
                                 [np.sin(phi), 0.0, np.cos(phi)]])
    # Position of the second carbon atom, about which the rotation pivots.
    pivot = (np.array([float(v) for v in geometry[1].split()[1:]])
             * _ANGSTROM_TO_BOHR)

  molecule = []
  for index, line in enumerate(geometry):
    symbol, x, y, z = line.split()
    position = np.array([float(v) * _ANGSTROM_TO_BOHR for v in (x, y, z)])
    if index in _TWISTED_HYDROGENS:
      if torsion is not None:
        position = torsion @ position
      if pyramidalisation is not None:
        position = (pyramidalisation @ (position - pivot)) + pivot
    molecule.append(
        system.Atom(symbol=symbol, coords=list(position), units='bohr'))

  if not experiment_config.system.electrons:  # Respect a user-supplied value.
    total = int(sum(atom.charge for atom in molecule))
    n_alpha = total // 2
    experiment_config.system.electrons = (n_alpha, total - n_alpha)
  experiment_config.system.molecule = molecule
  return experiment_config
def get_config() -> ml_collections.ConfigDict:
  """Returns config for excited states of (twisted) ethylene.

  `system.twist.tau` and `system.twist.phi` (degrees) control the distortion
  applied by `finalise`. Electrons are fixed to (8, 8) here, so `finalise`
  does not recompute them.
  """
  cfg = base_config.default()
  cfg.system.charge = 0
  cfg.system.delta_charge = 0.0
  cfg.system.molecule_name = 'planar'
  cfg.system.twist = {
      'tau': 0.0,  # torsion angle, in degrees
      'phi': 0.0,  # pyramidalization angle, in degrees
  }
  cfg.system.states = 3  # Note that for equilibrium only, we computed 5 states.
  cfg.system.spin_polarisation = ml_collections.FieldReference(
      None, field_type=int)
  cfg.system.units = 'bohr'
  cfg.system.electrons = (8, 8)
  cfg.pretrain.iterations = 10_000
  cfg.optim.iterations = 100_000
  cfg.update_from_flattened_dict(presets.psiformer)
  cfg.update_from_flattened_dict(presets.excited_states)
  with cfg.ignore_type():
    cfg.system.set_molecule = finalise
  return cfg

# ------------------------------------------------------------------------------
# /ferminet/configs/excited/double_excitation.py:
# ------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config to reproduce Fig. 5 from Pfau et al. (2024)."""

from ferminet import base_config
from ferminet.configs.excited import presets
from ferminet.utils import system
import ml_collections


# Geometries are taken from Loos, Boggio-Pasqua, Scemama, Caffarel and
# Jacquemin, JCTC (2019). The units listed in the supplemental material are
# inconsistent: nitrosomethane and glyoxal are in Bohr, the rest are Angstrom.
_SYSTEMS = {
    'nitrosomethane': """C -1.78426612 0.00000000 -1.07224050
N -0.00541753 0.00000000 1.08060391
O 2.18814985 0.00000000 0.43452135
H -0.77343975 0.00000000 -2.86415606
H -2.97471478 1.66801808 -0.86424584
H -2.97471478 -1.66801808 -0.86424584""".split('\n'),
    'butadiene': """C 1.740343 0.616556 0.00000000
C -1.740343 -0.616556 0.00000000
C 0.397343 0.616556 0.00000000
C -0.397343 -0.616556 0.00000000
H 0.126346 -1.577069 0.00000000
H -0.126346 1.577069 0.00000000
H 2.279054 1.568725 0.00000000
H -2.279054 -1.568725 0.00000000
H 2.279054 -0.335614 0.00000000
H -2.279054 0.335614 0.00000000""".split('\n'),
    'glyoxal': """C 1.21360282 0.75840215 0.00000000
C -1.21360282 -0.75840215 0.00000000
O 3.25581408 -0.26453186 0.00000000
O -3.25581408 0.26453186 0.00000000
H 0.96135276 2.81883243 0.00000000
H -0.96135276 -2.81883243 0.00000000""".split('\n'),
    'tetrazine': """C 0.00000000 0.00000000 1.26054332
C 0.00000000 0.00000000 -1.26054332
N 0.00000000 1.19421138 0.66133002
N 0.00000000 -1.19421138 0.66133002
N 0.00000000 1.19421138 -0.66133002
N 0.00000000 -1.19421138 -0.66133002
H 0.00000000 0.00000000 2.33817427
H 0.00000000 0.00000000 -2.33817427""".split('\n'),
    'cyclopentadienone': """C 0.00000000 0.00000000 0.76853878
C 0.00000000 1.19974276 -0.13448057
C 0.00000000 -1.19974276 -0.13448057
C 0.00000000 0.74909075 -1.39624830
C 0.00000000 -0.74909075 -1.39624830
O 0.00000000 0.00000000 1.98144505
H 0.00000000 2.21416694 0.22305399
H 0.00000000 -2.21416694 0.22305399
H 0.00000000 1.34284493 -2.29584273
H 0.00000000 -1.34284493 -2.29584273""".split('\n')
}


def finalise(
    experiment_config: ml_collections.ConfigDict) -> ml_collections.ConfigDict:
  """Returns the experiment config with the molecule set.

  Reads the geometry `_SYSTEMS[system.molecule_name]` and builds
  `system.molecule`; `system.Atom` converts angstrom coordinates to bohr. If
  `system.electrons` is not already set, it is set for the neutral molecule
  with the electrons split as evenly as possible between the two spins.
  """
  system_name = experiment_config.system.molecule_name
  geom = _SYSTEMS[system_name]
  # Per the note above _SYSTEMS, only these geometries are tabulated in bohr.
  units = 'bohr' if system_name in ('glyoxal', 'nitrosomethane') else 'angstrom'

  molecule = []
  for atom in geom:
    element, x, y, z = atom.split()
    coords = [float(xx) for xx in (x, y, z)]
    molecule.append(system.Atom(symbol=element,
                                coords=coords,
                                units=units))

  if not experiment_config.system.electrons:  # Don't override if already set
    nelectrons = int(sum(atom.charge for atom in molecule))
    na = nelectrons // 2
    experiment_config.system.electrons = (na, nelectrons - na)
  experiment_config.system.molecule = molecule
  return experiment_config


def get_config() -> ml_collections.ConfigDict:
  """Returns config for excited states of the double-excitation molecules.

  `system.molecule_name` selects one of the `_SYSTEMS` geometries.
  """
  cfg = base_config.default()
  cfg.system.charge = 0
  cfg.system.delta_charge = 0.0
  cfg.system.molecule_name = 'nitrosomethane'
  cfg.system.states = 6
  cfg.system.spin_polarisation = ml_collections.FieldReference(
      None, field_type=int)
  cfg.pretrain.iterations = 50_000
  cfg.mcmc.blocks = 2
  cfg.update_from_flattened_dict(presets.excited_states)
  with cfg.ignore_type():
    cfg.system.set_molecule = finalise
  return cfg

# ------------------------------------------------------------------------------
# /ferminet/pbc/feature_layer.py:
# ------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

"""Feature layer for periodic boundary conditions.

See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D.,
Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions
with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.
"""

from typing import Optional, Tuple

import chex
from ferminet import networks
import jax.numpy as jnp


def periodic_norm(metric: jnp.ndarray, scaled_r: jnp.ndarray) -> jnp.ndarray:
  """Returns a smooth, lattice-periodic norm of a set of vectors.

  Each fractional coordinate s is mapped to a = 1 - cos(2 pi s) and
  b = sin(2 pi s); the result is sqrt(a^T G a + b^T G b) / (2 pi), with G the
  metric. It is periodic in every fractional coordinate and is (up to
  rounding) zero when all of them are integers.

  Args:
    metric: metric tensor in fractional coordinate system, A.T A, where A is the
      lattice vectors. Must be rank 2.
    scaled_r: vectors in fractional coordinates of the lattice cell, with
      trailing dimension ndim, to compute the periodic norm of.

  Returns:
    The norms, i.e. `scaled_r` with its trailing ndim axis contracted away.
  """
  chex.assert_rank(metric, expected_ranks=2)
  phase = 2 * jnp.pi * scaled_r
  cos_part = 1 - jnp.cos(phase)
  sin_part = jnp.sin(phase)

  def quadratic_form(v):
    return jnp.einsum('...m,mn,...n->...', v, metric, v)

  return (1 / (2 * jnp.pi)) * jnp.sqrt(
      quadratic_form(cos_part) + quadratic_form(sin_part))


def make_pbc_feature_layer(
    natoms: Optional[int] = None,
    nspins: Optional[Tuple[int, ...]] = None,
    ndim: int = 3,
    rescale_inputs: bool = False,
    lattice: Optional[jnp.ndarray] = None,
    include_r_ae: bool = True,
) -> networks.FeatureLayer:
  """Builds a feature layer whose features are periodic in the lattice.

  Args:
    natoms: number of atoms.
    nspins: tuple of the number of spin-up and spin-down electrons. Unused.
    ndim: dimension of the system.
    rescale_inputs: If true, r_ae is replaced by log(1 + r_ae). Unlike the
      open-boundary case, r_ee is not rescaled.
    lattice: Matrix whose columns are the primitive lattice vectors of the
      system, shape (ndim, ndim). Defaults to the identity.
    include_r_ae: Flag to enable electron-atom distance features. Set to False
      to avoid cusps with ghost atoms in, e.g., homogeneous electron gas.

  Returns:
    A networks.FeatureLayer. Its init returns the (one-electron,
    two-electron) feature widths and empty params; its apply returns
    sin/cos features of the fractional displacements, optionally preceded by
    the periodic distance.
  """
  del nspins  # Periodic features do not depend on the spin partition.

  if lattice is None:
    lattice = jnp.eye(ndim)

  # Reciprocal vectors with the 2 pi factor omitted, so reciprocal_vecs @ r
  # gives fractional (phase) coordinates.
  reciprocal_vecs = jnp.linalg.inv(lattice)
  lattice_metric = lattice.T @ lattice

  ee_width = 2 * ndim + 1
  ae_width_per_atom = 2 * ndim + 1 if include_r_ae else 2 * ndim

  def init() -> Tuple[Tuple[int, int], networks.Param]:
    return (natoms * ae_width_per_atom, ee_width), {}

  def _periodise(frac: jnp.ndarray) -> jnp.ndarray:
    # (sin 2 pi s, cos 2 pi s) along the trailing axis.
    return jnp.concatenate(
        (jnp.sin(2 * jnp.pi * frac), jnp.cos(2 * jnp.pi * frac)), axis=-1)

  def apply(ae, r_ae, ee, r_ee) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # The incoming distances are recomputed periodically below.
    del r_ae, r_ee
    # Fractional coordinates: s_i = sum_l reciprocal_vecs[i, l] * x_l.
    s_ae = jnp.einsum('il,jkl->jki', reciprocal_vecs, ae)
    s_ee = jnp.einsum('il,jkl->jki', reciprocal_vecs, ee)
    eye = jnp.eye(ee.shape[0])

    ae_periodic = _periodise(s_ae)
    ee_periodic = _periodise(s_ee)

    # Don't take gradients through |0|: the self (diagonal) displacements are
    # shifted by one whole period, which leaves the periodic norm unchanged
    # mathematically, and the diagonal is then masked to zero.
    r_ee_periodic = periodic_norm(
        lattice_metric, s_ee + eye[..., None]) * (1.0 - eye)
    ee_features = jnp.concatenate(
        (r_ee_periodic[..., None], ee_periodic), axis=2)

    if include_r_ae:
      r_ae_periodic = periodic_norm(lattice_metric, s_ae)
      if rescale_inputs:
        r_ae_periodic = jnp.log(1 + r_ae_periodic)
      ae_features = jnp.concatenate(
          (r_ae_periodic[..., None], ae_periodic), axis=2)
    else:
      ae_features = ae_periodic
    ae_features = jnp.reshape(ae_features, [jnp.shape(ae_features)[0], -1])
    return ae_features, ee_features

  return networks.FeatureLayer(init=init, apply=apply)

# ------------------------------------------------------------------------------
# /ferminet/utils/analysis_tools.py:
# ------------------------------------------------------------------------------
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for reading and analysing QMC data."""

from typing import Iterable, Optional, Union

from absl import logging
import numpy as np
import pandas as pd

from pyblock import pd_utils as blocking


def _format_network(network_option: Union[int, Iterable[int]]) -> str:
  """Formats a network configuration to a (short) string.

  Args:
    network_option: an integer or iterable of integers.

  Returns:
    String representation of the network option. If the network option is a
    non-empty sequence of the form [V, V, ...], return NxV, where N is the
    length of the sequence. Anything else (scalars, empty or non-uniform
    sequences, unsized iterables) is formatted with str().
  """
  try:
    # pytype doesn't handle try...except TypeError gracefully.
    # len() raises TypeError for scalars and unsized iterables. The length
    # check also keeps an empty sequence from reaching network_option[0],
    # which would raise an uncaught IndexError.
    if len(network_option) and all(  # pytype: disable=wrong-arg-types
        xi == network_option[0] for xi in network_option[1:]):  # pytype: disable=unsupported-operands
      return f'{len(network_option)}x{network_option[0]}'  # pytype: disable=unsupported-operands,wrong-arg-types
    return str(network_option)
  except TypeError:
    return str(network_option)


def estimate_stats(df: pd.DataFrame,
                   burn_in: int,
                   groups: Optional[Iterable[str]] = None,
                   group_by_work_unit: bool = True) -> pd.DataFrame:
  """Estimates statistics for the (local) energy.

  Args:
    df: pd.DataFrame containing local energy data in the 'eigenvalues' column.
    burn_in: number of data points to discard before accumulating statistics to
      allow for learning and equilibration time.
    groups: list of column names in df to group by. The statistics for each
      group are returned, along with the corresponding data for the group. The
      group columns should be sufficient to distinguish between separate work
      units/calculations in df.
    group_by_work_unit: add 'work_unit_id' to the list of groups if not already
      present and 'work_unit_id' is a column in df. This is usually helpful
      for safety, when each work unit is a separate calculation and should be
      treated separately statistically.

  Returns:
    pandas DataFrame containing estimates of the mean, standard error and error
    in the standard error from a blocking analysis of the local energy for each
    group in df. Groups with no data after burn_in are omitted.

  Raises:
    RuntimeError: If groups is empty or None and group_by_work_unit is False. If
      df does not contain a key to group over, insert a dummy column with
      identical values or use pyblock directly.
  """
  wid = 'work_unit_id'
  if groups is None:
    groups = []
  else:
    groups = list(groups)
  if group_by_work_unit and wid not in groups and wid in df.columns:
    groups.append(wid)
  if not groups:
    raise RuntimeError(
        'Must group by at least one key or set group_by_work_unit to True.')
  if len(groups) == 1:
    index_dict = {'index': groups[0]}
    # Group by the bare column name. Grouping by a length-1 list makes
    # pandas >= 2.0 yield length-1 tuple keys, which would build a MultiIndex
    # (reset to 'level_0' rather than 'index') and break the rename below.
    grouper = groups[0]
  else:
    index_dict = {f'level_{i}': group for i, group in enumerate(groups)}
    grouper = groups
  stats_dict = {
      'mean': 'energy',
      'standard error': 'stderr',
      'standard error error': 'stderrerr'
  }

  def block(key, values):
    """Returns pyblock's summary row for one group, or NaNs if it fails."""
    blocked = blocking.reblock_summary(blocking.reblock(values)[1])
    if not blocked.empty:
      return blocked.iloc[0]
    else:
      logging.warning('Reblocking failed to estimate statistics for %s.', key)
      return pd.Series({statistic: np.nan for statistic in stats_dict})

  stats = (
      pd.DataFrame.from_dict({
          n: block(n, d.eigenvalues[burn_in:])
          for n, d in df.groupby(grouper) if not d[burn_in:].eigenvalues.empty
      }, orient='index')
      .reset_index()
      .rename(index_dict, axis=1)
      .rename(stats_dict, axis=1)
  )
  stats = stats.sort_values(by=groups).reset_index(drop=True)
  stats['burn_in'] = burn_in
  return stats

# ------------------------------------------------------------------------------
# /ferminet/utils/pseudopotential.py:
# ------------------------------------------------------------------------------
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for downloading and processing pseudopotentials."""

from ferminet.utils import elements
import jax
import jax.numpy as jnp
import numpy as np
import pyscf


def gaussian(r, a, b, n):
  """Returns a * r**n * exp(-b * r**2)."""
  return a * r**n * jnp.exp(-b * r**2)


def eval_ecp(r, coeffs):
  r"""Evaluates r^2 U_l = \\sum A_{lk} r^{n_{lk}} e^{- B_{lk} r^2}.

  Args:
    r: radial distance(s) at which to evaluate.
    coeffs: sequence indexed by the power of r; entry n is an iterable of
      (exponent B, linear coefficient A) pairs for the r**n terms.

  Returns:
    The sum of all terms, i.e. r^2 U_l (not U_l itself).
  """
  val = 0.
  for r_exponent, v in enumerate(coeffs):
    for (exp_coeff, linear_coeff) in v:
      val += gaussian(r, linear_coeff, exp_coeff, r_exponent)
  return val


def calc_gaussian_cutoff(
    coeffs,
    v_tol,
    rmin=0.005,
    rmax=10,
    nr=10001,
):
  """Calculates the radius beyond which every ECP term is below tolerance.

  Each term of `coeffs` (same layout as in `eval_ecp`) is divided by r**2 and
  evaluated on `nr` uniformly spaced points in [rmin, rmax].

  Args:
    coeffs: ECP coefficients for one angular momentum channel.
    v_tol: tolerance on |term| / r**2.
    rmin: first grid point.
    rmax: last grid point.
    nr: number of grid points.

  Returns:
    The largest, over all terms, of the first grid radius after the last point
    where the term exceeds `v_tol`. A term that never exceeds `v_tol`
    contributes `rmin`; a term still above `v_tol` at the last grid point
    contributes `rmax`. Returns 0 if `coeffs` has no terms.
  """
  r = np.linspace(rmin, rmax, nr)
  r_c_max = 0
  for r_exponent, terms in enumerate(coeffs):
    for (exp_coeff, linear_coeff) in terms:
      v = np.abs(gaussian(r, linear_coeff, exp_coeff, r_exponent)) / r**2
      above_tol = np.where(v > v_tol)[0]
      if above_tol.size:
        # Clamp to the grid: a term still above tolerance at rmax would
        # otherwise index one past the end of r and raise IndexError.
        ind = min(int(above_tol[-1]) + 1, r.size - 1)
      else:
        # No point exceeds the tolerance (e.g. all values are zero).
        ind = 0
      r_c_max = max(r[ind], r_c_max)
  return r_c_max


def calc_r_c(ecp, v_tol, **kwargs):
  """Returns a dict mapping each channel l of `ecp` to its cutoff radius."""
  out = {}
  for l, coeffs in ecp:
    out[l] = calc_gaussian_cutoff(coeffs, v_tol, **kwargs)
  return out


def eval_ecp_on_grid(
    ecp_all,
    r_grid=None,
    log_r0=-5,
    log_rn=10,
    n_grid=10001,
):
  """Evaluates every ECP channel of every element on a radial grid.

  r_grid overrules log_r0, log_rn and n_grid if set.

  Args:
    ecp_all: dict mapping atomic number to a pair (number of core electrons,
      [(l, coeffs), ...]), as built by `make_pp_args`.
    r_grid: optional explicit radial grid.
    log_r0: base-10 log of the first point of the default logarithmic grid.
    log_rn: base-10 log of the last point of the default logarithmic grid.
    n_grid: number of points of the default grid.

  Returns:
    Tuple (n_cores, v_grid_dict, r_grid, n_channels): core-electron counts per
    atomic number; per atomic number an array of shape (n_channels, n_grid)
    whose row l holds U_l(r); the radial grid; and the largest number of
    channels over all elements.
  """

  if r_grid is None:
    r_grid = jnp.logspace(log_r0, log_rn, n_grid)
  else:
    n_grid = r_grid.size

  n_channels = max(len(val[1]) for val in ecp_all.values())
  n_cores = {z: val[0] for z, val in ecp_all.items()}
  v_grid_dict = {}

  for z, (_, ecp_val) in ecp_all.items():
    v_grid = jnp.zeros((n_channels, n_grid))
    for l, coeffs in ecp_val:
      # Divide out the r^2 from eval_ecp. NOTE(review): a negative l (if the
      # ECP format uses one, presumably for the local channel) indexes rows
      # from the end — confirm this is the intended layout.
      v_grid = v_grid.at[l].set(eval_ecp(r_grid, coeffs) / r_grid**2)

    v_grid_dict[z] = jnp.asarray(v_grid)

  return n_cores, v_grid_dict, r_grid, n_channels


def make_pp_args(symbols, v_tol=1e-5, quad_degree=4, zeta="d", return_ecp=True):
  """Creates arguments to use in pp_hamiltonian.

  Loads the ccECP pseudopotential and the matching ccecp-cc-pV{zeta}Z basis
  for each element from PySCF.

  Args:
    symbols: element symbols to load pseudopotentials for.
    v_tol: tolerance used to determine each element's cutoff radius r_c.
    quad_degree: quadrature degree, passed through in the returned args.
    zeta: basis set cardinality letter, e.g. "d" for double zeta.
    return_ecp: if True, also return the raw ECP and basis dictionaries.

  Returns:
    pp_args = (n_cores, r_c, n_channels, v_grid_dict, r_grid, quad_degree),
    with dicts keyed by atomic number. If return_ecp is True, returns
    ((ecp_all, ecp_basis), pp_args) instead.
  """
  ecp_all = {}
  ecp_basis = {}
  for symb in symbols:
    z = elements.SYMBOLS[symb].atomic_number
    aug = ""  # Always empty here, so the basis name is f"ccecpccpv{zeta}z".
    ecp_symb = pyscf.gto.basis.load_ecp("ccecp", symb)
    ecp_basis_symb = pyscf.gto.basis.load(f"ccecp{aug}ccpv{zeta}z", symb)

    ecp_all[z] = ecp_symb
    ecp_basis[z] = ecp_basis_symb

  # Per-element cutoff: the largest cutoff over all of its channels.
  r_c = {
      z: max(calc_r_c(ecp_symb, v_tol).values())
      for z, (_, ecp_symb) in ecp_all.items()
  }

  n_cores, v_grid_dict, r_grid, n_channels = eval_ecp_on_grid(ecp_all)

  pp_args = n_cores, r_c, n_channels, v_grid_dict, r_grid, quad_degree
  if return_ecp:
    return (ecp_all, ecp_basis), pp_args
  else:
    return pp_args


def check_pp_args(pp_args):
  """Checks that pseudopotential arguments are consistent.

  Raises:
    AssertionError: if the dict keys disagree, quad_degree is not a positive
      int, or a grid has the wrong number of points or too many channels.
  """
  # TODO(hsutterud): switch to a better error type
  n_cores, r_c, n_channels, v_grid_dict, r_grid, quad_degree = pp_args

  assert n_cores.keys() == r_c.keys() == v_grid_dict.keys()
  assert isinstance(quad_degree, int)
  assert quad_degree > 0

  for _, v_grid in v_grid_dict.items():
    assert v_grid.shape[1] == r_grid.shape[0]
    assert v_grid.shape[0] <= n_channels


def leg_l0(x):
  """Legendre polynomial P_0(x) = 1."""
  return jnp.ones_like(x)


def leg_l1(x):
  """Legendre polynomial P_1(x) = x."""
  return x


def leg_l2(x):
  """Legendre polynomial P_2(x) = (3x^2 - 1) / 2."""
  return 0.5 * (3 * x**2 - 1)


def leg_l3(x):
  """Legendre polynomial P_3(x) = (5x^3 - 3x) / 2."""
  return 0.5 * (5 * x**3 - 3 * x)


def eval_leg(x, l):
  """Evaluates P_l(x) for l in 0..3, with l allowed to be a traced value."""
  return jax.lax.switch(l, [leg_l0, leg_l1, leg_l2, leg_l3], x)

# ------------------------------------------------------------------------------
# /ferminet/pbc/envelopes.py:
# ------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

"""Multiplicative envelopes appropriate for periodic boundary conditions.

See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D.,
Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions
with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.
"""

import itertools
from typing import Mapping, Optional, Sequence, Tuple, Union

from ferminet import envelopes

import jax.numpy as jnp
import numpy as np


def make_multiwave_envelope(kpoints: jnp.ndarray) -> envelopes.Envelope:
  """Returns an oscillatory envelope.

  Envelope consists of a sum of truncated 3D Fourier series, one centered on
  each atom, with Fourier frequencies given by kpoints. With nk = number of
  kpoints, each series is

    sum_i sigma_i**2 * cos(kpoints_i.r_{ae}) + sigma_{nk+i}**2 * sin(kpoints_i.r_{ae})

  i.e. the first nk rows of sigma weight the cosines, the last nk rows weight
  the sines, and the weights applied are the squares of sigma.

  Initialization sets the first row of sigma to 1 and all other rows to 0,
  so the envelope starts as the cosine of the first entry in kpoints. If this
  is [0, 0, 0], the envelope will evaluate to unity at the beginning of
  training.

  Args:
    kpoints: Reciprocal lattice vectors of terms included in the Fourier
      series. Shape (nkpoints, ndim) (Note that ndim=3 is currently
      a hard-coded default).

  Returns:
    An instance of ferminet.envelopes.Envelope with apply_type
    envelopes.EnvelopeType.PRE_DETERMINANT
  """

  def init(
      natom: int, output_dims: Sequence[int], ndim: int = 3
  ) -> Sequence[Mapping[str, jnp.ndarray]]:
    """See ferminet.envelopes.EnvelopeInit."""
    del natom, ndim  # unused
    params = []
    nk = kpoints.shape[0]
    for output_dim in output_dims:
      params.append({'sigma': jnp.zeros((2 * nk, output_dim))})
      params[-1]['sigma'] = params[-1]['sigma'].at[0, :].set(1.0)
    return params

  def apply(*, ae: jnp.ndarray, r_ae: jnp.ndarray, r_ee: jnp.ndarray,
            sigma: jnp.ndarray) -> jnp.ndarray:
    """See ferminet.envelopes.EnvelopeApply."""
    del r_ae, r_ee  # unused
    phase_coords = ae @ kpoints.T
    # All cosines first, then all sines, along the k-point axis.
    waves = jnp.concatenate((jnp.cos(phase_coords), jnp.sin(phase_coords)),
                            axis=2)
    env = waves @ (sigma**2.0)
    # Sum the per-atom series (assumes axis 1 of ae indexes atoms — confirm).
    return jnp.sum(env, axis=1)

  return envelopes.Envelope(envelopes.EnvelopeType.PRE_DETERMINANT, init, apply)


def make_kpoints(
    lattice: Union[np.ndarray, jnp.ndarray],
    spins: Tuple[int, int],
    min_kpoints: Optional[int] = None,
) -> jnp.ndarray:
  """Generates an array of reciprocal lattice vectors.

  Args:
    lattice: Matrix whose columns are the primitive lattice vectors of the
      system, shape (ndim, ndim). (Note that ndim=3 is currently
      a hard-coded default).
    spins: Tuple of the number of spin-up and spin-down electrons.
    min_kpoints: If specified, the number of kpoints which must be included in
      the output. The number of kpoints returned will be the
      first filled shell which is larger than this value. Defaults to None,
      which results in min_kpoints == sum(spins).

  Raises:
    ValueError: Fewer kpoints requested by min_kpoints than number of
      electrons in the system.

  Returns:
    jnp.ndarray, shape (nkpoints, ndim), an array of reciprocal lattice
    vectors sorted in ascending order according to length.
  """
  rec_lattice = 2 * jnp.pi * jnp.linalg.inv(lattice)
  # Calculate required no. of k points
  if min_kpoints is None:
    min_kpoints = sum(spins)
  elif min_kpoints < sum(spins):
    raise ValueError(
        'Number of kpoints must be equal or greater than number of electrons')

  # Relative tolerance: every k-point whose length is within this factor of
  # the min_kpoints-th shortest one is kept, so the shell is closed.
  dk = 1 + 1e-5
  # Generate ordinals of the lowest min_kpoints kpoints
  max_k = int(jnp.ceil(min_kpoints * dk)**(1 / 3.))
  ordinals = sorted(range(-max_k, max_k+1), key=abs)
  ordinals = jnp.asarray(list(itertools.product(ordinals, repeat=3)))

  kpoints = ordinals @ rec_lattice.T
  kpoints = jnp.asarray(sorted(kpoints, key=jnp.linalg.norm))
  k_norms = jnp.linalg.norm(kpoints, axis=1)

  return kpoints[k_norms <= k_norms[min_kpoints - 1] * dk]

# ------------------------------------------------------------------------------
# /ferminet/configs/excited/oscillator.py:
# ------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config to reproduce Figs. 2 and S3 from Pfau et al. (2024)."""


from ferminet import base_config
from ferminet.configs.excited import presets
from ferminet.utils import system
import ml_collections
import pyscf


# Geometries from Chrayteh, Blondel, Loos and Jacquemin, JCTC (2021)
# All geometries in atomic units (Bohr)
_SYSTEMS = {
    'BH': ['B 0.00000000 0.00000000 0.00000000',
           'H 0.00000000 0.00000000 2.31089693'],
    'HCl': ['H 0.00000000 0.00000000 2.38483140',
            'Cl 0.00000000 0.00000000 -0.02489783'],
    'H2O': ['O 0.00000000 0.00000000 -0.13209669',
            'H 0.00000000 1.43152878 0.97970006',
            'H 0.00000000 -1.43152878 0.97970006'],
    'H2S': ['S 0.00000000 0.00000000 -0.50365086',
            'H 0.00000000 1.81828105 1.25212288',
            'H 0.00000000 -1.81828105 1.25212288'],
    'BF': ['B 0.00000000 0.00000000 0.00000000',
           'F 0.00000000 0.00000000 2.39729626'],
    'CO': ['C 0.00000000 0.00000000 -1.24942055',
           'O 0.00000000 0.00000000 0.89266692'],
    'N2': ['N 0.00000000 0.00000000 1.04008632',
           'N 0.00000000 0.00000000 -1.04008632'],
    'C2H4': ['C 0.00000000 1.26026583 0.00000000',
             'C 0.00000000 -1.26026583 0.00000000',
             'H 0.00000000 2.32345976 1.74287672',
             'H 0.00000000 -2.32345976 1.74287672',
             'H 0.00000000 2.32345976 -1.74287672',
             'H 0.00000000 -2.32345976 -1.74287672'],
    'CH2O': ['C 0.00000000 0.00000000 1.13947666',
             'O 0.00000000 0.00000000 -1.14402883',
             'H 0.00000000 1.76627623 2.23398653',
             'H 0.00000000 -1.76627623 2.23398653'],
    'CH2S': ['C 0.00000000 0.00000000 2.08677304',
             'S 0.00000000 0.00000000 -0.97251194',
             'H 0.00000000 1.73657773 3.17013507',
             'H 0.00000000 -1.73657773 3.17013507'],
    'HNO': ['O 0.21099695 0.00000000 2.15462460',
            'N -0.44776863 0.00000000 -0.03589263',
            'H 1.18163475 0.00000000 -1.17386890'],
    'HCF': ['C -0.13561085 0.00000000 1.20394474',
            'F 1.85493976 0.00000000 -0.27610752',
            'H -1.71932891 0.00000000 -0.18206846'],
    'H2CSi': ['C 0.00000000 0.00000000 -2.09539928',
              'Si 0.00000000 0.00000000 1.14992930',
              'H 0.00000000 1.70929524 -3.22894481',
              'H 0.00000000 -1.70929524 -3.22894481'],
}


def finalise(
    experiment_config: ml_collections.ConfigDict) -> ml_collections.ConfigDict:
  """Returns the experiment config with the molecule completely set.

  Builds `system.molecule` from `_SYSTEMS[system.molecule_name]` (bohr), sets
  `system.electrons` if not already set, and attaches a built PySCF Mole as
  `system.pyscf_mol`. Elements listed in `pseudo_atoms` get the ECP
  `system.ecp` and basis `system.ecp_basis`; all other elements get cc-pVDZ
  and no ECP. The ECP element symbols are forwarded through
  `system.make_local_energy_kwargs`.
  """
  geom = _SYSTEMS[experiment_config.system.molecule_name]

  molecule = []
  for atom in geom:
    element, x, y, z = atom.split()
    coords = [float(xx) for xx in (x, y, z)]
    molecule.append(system.Atom(symbol=element, coords=coords, units='bohr'))

  if not experiment_config.system.electrons:  # Don't override if already set
    nelectrons = int(sum(atom.charge for atom in molecule))
    na = nelectrons // 2
    experiment_config.system.electrons = (na, nelectrons - na)
  experiment_config.system.molecule = molecule

  pseudo_atoms = ['S', 'Si', 'Cl']  # use ECP for second-row atoms

  mol = pyscf.gto.Mole()
  mol.atom = [[atom.symbol, atom.coords] for atom in molecule]

  # Unique element symbols in order of first appearance. Iterating a set of
  # strings would make this order, and hence `ecp_symbols` below, depend on
  # the per-process string hash seed.
  atoms = list(dict.fromkeys(atom.symbol for atom in molecule))
  mol.basis = {
      atom:
      experiment_config.system.ecp_basis if atom in pseudo_atoms else 'cc-pvdz'
      for atom in atoms
  }
  mol.ecp = {
      atom: experiment_config.system.ecp
      for atom in atoms if atom in pseudo_atoms
  }

  mol.charge = 0
  mol.spin = 0
  mol.unit = 'bohr'
  mol.build()

  experiment_config.system.pyscf_mol = mol
  experiment_config.system.make_local_energy_kwargs = {
      'ecp_symbols': list(mol.ecp.keys()),
      'ecp_type': experiment_config.system.ecp,
  }

  return experiment_config


def get_config() -> ml_collections.ConfigDict:
  """Returns the config for excited states of one of the `_SYSTEMS` molecules.

  `system.molecule_name` defaults to '' and must be set to a key of
  `_SYSTEMS` (e.g. 'H2O'); otherwise `finalise` raises KeyError.
  """
  cfg = base_config.default()

  cfg.system.molecule_name = ''
  cfg.system.ecp = 'ccecp'
  cfg.system.ecp_basis = 'ccecp-cc-pVDZ'
  cfg.system.states = 5
  cfg.pretrain.iterations = 10_000
  cfg.update_from_flattened_dict(presets.excited_states)
  with cfg.ignore_type():
    cfg.system.set_molecule = finalise

  return cfg

# ------------------------------------------------------------------------------
# /ferminet/sto.py:
# ------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | """Constants for Slater-type orbitals.""" 16 | 17 | # Zeta and c coefficient for Slater-type orbitals reverse-fit from STO-6G, as {element: {orbital: (zeta, c)}}. 18 | STO_6G_COEFFS = { 19 | 'Al': { 20 | '1s': (12.557246008517494, 25.106698511179705), 21 | '2s': (4.359619258385573, 12.927028645707773), 22 | '3s': (1.6997831803065009, 0.7615604776424957), 23 | '2p': (6.554000435218195, 94.73883433477683), 24 | '3p': (2.2351712617093247, 1.3827832310783092), 25 | }, 26 | 'Ar': { 27 | '1s': (17.39539980503338, 40.93071746434361), 28 | '2s': (6.7395274109023875, 38.411458472102076), 29 | '3s': (2.329717983859512, 2.295580734042915), 30 | '2p': (10.084117748829376, 424.5775400483707), 31 | '3p': (3.060265563420203, 5.664664351842427), 32 | }, 33 | 'B': { 34 | '1s': (4.67868616774112, 5.70938862708008), 35 | '2s': (1.4999288235962795, 0.8976437475218287), 36 | '2p': (2.2716468438938575, 2.3554214877163155), 37 | }, 38 | 'Be': { 39 | '1s': (3.6790735046348035, 3.9812194646017587), 40 | '2s': (1.1499342729534248, 0.4619650556409617), 41 | '2p': (1.737594734756709, 0.9201071798798099), 42 | }, 43 | 'C': { 44 | '1s': (5.668891371758942, 7.615256589029782), 45 | '2s': (1.719877299055101, 1.2637414721616658), 46 | '2p': (2.603224975467389, 3.791736047605733), 47 | }, 48 | 'Ca': { 49 | '1s': (19.5709944613605, 48.83583953845113), 50 | '2s': (7.738732968586929, 54.260484887280676), 51 | '3s': (3.0096391625807035, 5.625243917454751), 52 | '4s': (1.3601380588639749, 0.12698321714720479), 53 | '2p': (11.44277220950772, 656.1541713931504), 54 | '3p': (3.951008639686456, 17.85199403770938), 55 | '4p': (1.6981898284489827, 0.14655638210969082), 56 | }, 57 | 'Cl': { 58 | '1s': (16.42714984359141, 37.564449406190995), 59 | '2s': (6.258720160253116, 31.915499495709224), 60 | '3s': (2.0997377724242736, 1.5955272516248895), 61 | '2p': (9.379020216891835, 331.94837422065166), 62 | '3p': (2.7582728076424314, 3.550609030735584), 63 | }, 64 | 'F': { 65 | '1s': (8.647725610389443, 14.3478532498991), 66 | '2s':
(2.5498686974182796, 3.3822880320590616), 67 | '2p': (3.8381983263962307, 14.689479350772942), 68 | }, 69 | 'H': { 70 | '1s': (1.2396569932253754, 0.7786482598438814), 71 | }, 72 | 'He': { 73 | '1s': (1.6895399166477436, 1.2389169650053704), 74 | }, 75 | 'K': { 76 | '1s': (18.603597538983905, 45.2640252413769), 77 | '2s': (7.259222119835751, 46.24588987232522), 78 | '3s': (2.7496572800200547, 4.100159535853956), 79 | '4s': (1.4301411982280272, 0.1591563672673338), 80 | '2p': (10.773926382192002, 534.0140415145601), 81 | '3p': (3.62088857101528, 12.145661267587803), 82 | '4p': (1.7834260775485395, 0.19104940003516785), 83 | }, 84 | 'Li': { 85 | '1s': (2.689237663677221, 2.487921264218123), 86 | '2s': (0.7999592565921534, 0.18646529682488083), 87 | '2p': (1.2131148771860853, 0.2626425031316319), 88 | }, 89 | 'Mg': { # NOTE(review): every value here is identical to the 'Na' entry below; the 1s zeta (~10.61) looks like Na's standard STO-nG scale factor, not Mg's (~11.59) -- verify and refit. 90 | '1s': (10.606505744603167, 19.48727485747987), 91 | '2s': (3.4797896476024155, 7.358433661799946), 92 | '3s': (1.7497617130189205, 0.8428386648793419), 93 | '2p': (5.214100079417301, 42.5456609714464), 94 | '3p': (2.304293894314714, 1.5934349371819814), 95 | }, 96 | 'N': { 97 | '1s': (6.668046567851146, 9.714119765863773), 98 | '2s': (1.9498640428999283, 1.7295555428587879), 99 | '2p': (2.953104555611183, 5.903584171692387), 100 | }, 101 | 'Na': { 102 | '1s': (10.606505744603167, 19.48727485747987), 103 | '2s': (3.4797896476024155, 7.358433661799946), 104 | '3s': (1.7497617130189205, 0.8428386648793419), 105 | '2p': (5.214100079417301, 42.5456609714464), 106 | '3p': (2.304293894314714, 1.5934349371819814), 107 | }, 108 | 'Ne': { 109 | '1s': (9.638169483462994, 16.882135191702933), 110 | '2s': (2.879839514936209, 4.5849782106230474), 111 | '2p': (4.313575648976462, 21.89010672369113), 112 | }, 113 | 'O': { 114 | '1s': (7.658312143014643, 11.957311431836787), 115 | '2s': (2.2498858861805417, 2.4735485371920998), 116 | '2p': (3.382554479984454, 9.404304291880152), 117 | }, 118 | 'P': { 119 | '1s': (14.49840434649223, 31.149470127449554), 120 | '2s':
(5.3092768063608204, 21.157258029235972), 121 | '3s': (1.8997434641776487, 1.1239662327153015), 122 | '2p': (7.868086018599397, 177.5301649474899), 123 | '3p': (2.504478963796671, 2.3183467992744076), 124 | }, 125 | 'S': { 126 | '1s': (15.468177117665267, 34.325592817926776), 127 | '2s': (5.788995339687754, 26.26240044766998), 128 | '3s': (2.0497271436527353, 1.4664153516476472), 129 | '2p': (8.713206244747633, 257.8924790569185), 130 | '3p': (2.69317030051437, 3.1966886670849703), 131 | }, 132 | 'Si': { 133 | '1s': (13.528008654654725, 28.074801445575115), 134 | '2s': (4.8295674907671104, 16.69869883644322), 135 | '3s': (1.7497617129910332, 0.8428386647959631), 136 | '2p': (7.211841900405304, 131.56735728268595), 137 | '3p': (2.304293894345575, 1.5934349373181378), 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /ferminet/utils/tests/elements_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for ferminet.utils.elements.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | from ferminet.utils import elements 21 | 22 | 23 | class ElementsTest(parameterized.TestCase): 24 | 25 | def test_elements(self): 26 | for n, element in elements.ATOMIC_NUMS.items(): 27 | self.assertEqual(n, element.atomic_number) 28 | self.assertEqual(elements.SYMBOLS[element.symbol], element) 29 | if element.symbol == 'X': 30 | continue 31 | elif element.symbol in ['Li', 'Na', 'K', 'Rb', 'Cs', 'Fr']: 32 | self.assertEqual(element.period, elements.ATOMIC_NUMS[n - 1].period + 1) 33 | elif element.symbol != 'H': 34 | self.assertEqual(element.period, elements.ATOMIC_NUMS[n - 1].period) 35 | self.assertCountEqual( 36 | (element.symbol for element in elements.ATOMIC_NUMS.values()), 37 | elements.SYMBOLS.keys()) 38 | self.assertCountEqual( 39 | (element.atomic_number for element in elements.SYMBOLS.values()), 40 | elements.ATOMIC_NUMS.keys()) 41 | 42 | @parameterized.parameters( 43 | (elements.SYMBOLS['H'], 1, 1, 1, 1, 0), 44 | (elements.SYMBOLS['He'], 1, 18, 0, 1, 1), 45 | (elements.SYMBOLS['Li'], 2, 1, 1, 2, 1), 46 | (elements.SYMBOLS['Be'], 2, 2, 0, 2, 2), 47 | (elements.SYMBOLS['C'], 2, 14, 2, 4, 2), 48 | (elements.SYMBOLS['N'], 2, 15, 3, 5, 2), 49 | (elements.SYMBOLS['Al'], 3, 13, 1, 7, 6), 50 | (elements.SYMBOLS['Zn'], 4, 12, 0, 15, 15), 51 | (elements.SYMBOLS['Ga'], 4, 13, 1, 16, 15), 52 | (elements.SYMBOLS['Kr'], 4, 18, 0, 18, 18), 53 | (elements.SYMBOLS['Ce'], 6, -1, -1, None, None), 54 | (elements.SYMBOLS['Ac'], 7, 3, -1, None, None), 55 | ) 56 | def test_element_group_period(self, element, period, group, spin_config, 57 | nalpha, nbeta): 58 | # Validate subset of elements. See below for more thorough tests using 59 | # properties of the periodic table. 
60 | with self.subTest('Verify period'): 61 | self.assertEqual(element.period, period) 62 | with self.subTest('Verify group'): 63 | self.assertEqual(element.group, group) 64 | with self.subTest('Verify spin configuration'): 65 | if (element.period > 5 and 66 | (element.group == -1 or 3 <= element.group <= 12)): 67 | with self.assertRaises(NotImplementedError): 68 | _ = element.spin_config 69 | else: 70 | self.assertEqual(element.spin_config, spin_config) 71 | with self.subTest('Verify electrons per spin'): 72 | if (element.period > 5 and 73 | (element.group == -1 or 3 <= element.group <= 12)): 74 | with self.assertRaises(NotImplementedError): 75 | _ = element.nalpha 76 | with self.assertRaises(NotImplementedError): 77 | _ = element.nbeta 78 | else: 79 | self.assertEqual(element.nalpha, nalpha) 80 | self.assertEqual(element.nbeta, nbeta) 81 | 82 | def test_periods(self): 83 | self.assertLen(elements.ATOMIC_NUMS, 84 | sum(len(period) for period in elements.PERIODS.values())) 85 | period_length = {0: 1, 1: 2, 2: 8, 3: 8, 4: 18, 5: 18, 6: 32, 7: 32} 86 | for p, es in elements.PERIODS.items(): 87 | self.assertLen(es, period_length[p]) 88 | 89 | def test_groups(self): 90 | # Atomic numbers of first element in each period. 91 | period_starts = sorted([ 92 | period_elements[0].atomic_number 93 | for period_elements in elements.PERIODS.values() 94 | ]) 95 | # Iterate over all elements in order of atomic number. Group should 96 | # increment monotonically (except for accommodating absence of d block and 97 | # presence of f block) and reset to 1 on the first element in each period. 98 | for i in range(1, len(elements.ATOMIC_NUMS)): 99 | element = elements.ATOMIC_NUMS[i] 100 | if element.atomic_number in period_starts: 101 | prev_group = 0 102 | fblock = 0 103 | if element.symbol == 'He': 104 | # Full shell, not just full s subshell. 105 | self.assertEqual(element.group, 18) 106 | elif element.group == -1: 107 | # Found a lanthanide (period 6) or actinide (period 7). 
108 | self.assertIn(element.period, [6, 7]) 109 | fblock += 1 110 | elif element.atomic_number == 5 or element.atomic_number == 13: 111 | # No d block (10 elements, groups 3-12) in periods 2 and 3. 112 | self.assertEqual(element.group, prev_group + 11) 113 | else: 114 | # Group should increment monotonically. 115 | self.assertEqual(element.group, prev_group + 1) 116 | if element.group != -1: 117 | prev_group = element.group 118 | self.assertGreaterEqual(prev_group, 1) 119 | self.assertLessEqual(prev_group, 18) 120 | if element.group == 4 and element.period > 6: 121 | # Should have seen 14 lanthanides (period 6) or 14 actinides (period 7). 122 | self.assertEqual(fblock, 14) 123 | 124 | # The periodic table (up to element 118) contains 7 periods. 125 | # Hydrogen and Helium are placed in groups 1 and 18 respectively. 126 | # Groups 1-2 (s-block) and 13-18 (p-block) are present in the second 127 | # period onwards, groups 3-12 (d-block) the fourth period onwards. 128 | # Check each group contains the expected number of elements. 129 | nelements_in_group = [0]*18 130 | for element in elements.ATOMIC_NUMS.values(): 131 | if element.group != -1 and element.period != 0: 132 | nelements_in_group[element.group-1] += 1 133 | self.assertListEqual(nelements_in_group, [7, 6] + [4]*10 + [6]*5 + [7]) 134 | 135 | 136 | if __name__ == '__main__': 137 | absltest.main() 138 | -------------------------------------------------------------------------------- /ferminet/network_blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Neural network building blocks.""" 16 | 17 | import functools 18 | import itertools 19 | from typing import MutableMapping, Optional, Sequence, Tuple 20 | 21 | import chex 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | 26 | def array_partitions(sizes: Sequence[int]) -> Sequence[int]: 27 | """Returns the indices for splitting an array into separate partitions. 28 | 29 | Args: 30 | sizes: size of each of N partitions. The dimension of the array along 31 | the relevant axis is assumed to be sum(sizes). 32 | 33 | Returns: 34 | sequence of indices (length len(sizes)-1) at which an array should be split 35 | to give the desired partitions. 36 | """ 37 | return list(itertools.accumulate(sizes))[:-1] 38 | 39 | 40 | def split_into_blocks(block_arr: jnp.ndarray, 41 | block_dims: Tuple[int, ...]) -> Sequence[jnp.ndarray]: 42 | """Split a square array into blocks along the leading two axes. 43 | 44 | Consider the (N,N) array 45 | A B 46 | C D 47 | where A=(N1,N1), B=(N1,N2), C=(N2,N1), D=(N2,N2), and N=N1+N2. Split the array 48 | into the given blocks. 49 | 50 | Args: 51 | block_arr: block array to split. 52 | block_dims: the size of each block along each axis (i.e. N1, N2, ...). 53 | 54 | Returns: 55 | blocks of the array split along the two leading axes into chunks with 56 | dimensions given by block_dims. 
57 | """ 58 | partitions = array_partitions(block_dims) 59 | block1 = jnp.split(block_arr, partitions, axis=0) 60 | block12 = [jnp.split(arr, partitions, axis=1) for arr in block1] 61 | return tuple(itertools.chain.from_iterable(block12)) 62 | 63 | 64 | def init_linear_layer( 65 | key: chex.PRNGKey, in_dim: int, out_dim: int, include_bias: bool = True 66 | ) -> MutableMapping[str, jnp.ndarray]: 67 | """Initialises parameters for a linear layer, x w + b. 68 | 69 | Args: 70 | key: JAX PRNG state. 71 | in_dim: input dimension to linear layer. 72 | out_dim: output dimension (number of hidden units) of linear layer. 73 | include_bias: if true, include a bias in the linear layer. 74 | 75 | Returns: 76 | A mapping containing the weight matrix (key 'w') and, if required, bias 77 | unit (key 'b'). 78 | """ 79 | key1, key2 = jax.random.split(key) 80 | weight = ( 81 | jax.random.normal(key1, shape=(in_dim, out_dim)) / 82 | jnp.sqrt(float(in_dim))) 83 | if include_bias: 84 | bias = jax.random.normal(key2, shape=(out_dim,)) 85 | return {'w': weight, 'b': bias} 86 | else: 87 | return {'w': weight} 88 | 89 | 90 | def linear_layer(x: jnp.ndarray, 91 | w: jnp.ndarray, 92 | b: Optional[jnp.ndarray] = None) -> jnp.ndarray: 93 | """Evaluates a linear layer, x w + b. 94 | 95 | Args: 96 | x: inputs. 97 | w: weights. 98 | b: optional bias. 99 | 100 | Returns: 101 | x w + b if b is given, x w otherwise. 102 | """ 103 | y = jnp.dot(x, w) 104 | return y + b if b is not None else y 105 | 106 | vmap_linear_layer = jax.vmap(linear_layer, in_axes=(0, None, None), out_axes=0) 107 | 108 | 109 | def slogdet(x): 110 | """Computes sign and log of determinants of matrices. 111 | 112 | This is a jnp.linalg.slogdet with a special (fast) path for small matrices. 113 | 114 | Args: 115 | x: square matrix. 116 | 117 | Returns: 118 | sign, (natural) logarithm of the determinant of x. 
119 | """ 120 | if x.shape[-1] == 1: 121 | if x.dtype == jnp.complex64 or x.dtype == jnp.complex128: 122 | sign = x[..., 0, 0] / jnp.abs(x[..., 0, 0]) 123 | else: 124 | sign = jnp.sign(x[..., 0, 0]) 125 | logdet = jnp.log(jnp.abs(x[..., 0, 0])) 126 | else: 127 | sign, logdet = jnp.linalg.slogdet(x) 128 | 129 | return sign, logdet 130 | 131 | 132 | def logdet_matmul( 133 | xs: Sequence[jnp.ndarray], w: Optional[jnp.ndarray] = None 134 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 135 | """Combines determinants and takes dot product with weights in log-domain. 136 | 137 | We use the log-sum-exp trick to reduce numerical instabilities. 138 | 139 | Args: 140 | xs: FermiNet orbitals in each determinant. Either of length 1 with shape 141 | (ndet, nelectron, nelectron) (full_det=True) or length 2 with shapes 142 | (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta) (full_det=False, 143 | determinants are factorised into block-diagonals for each spin channel). 144 | w: weight of each determinant. If none, a uniform weight is assumed. 145 | 146 | Returns: 147 | sum_i w_i D_i in the log domain, where w_i is the weight of D_i, the i-th 148 | determinant (or product of the i-th determinant in each spin channel, if 149 | full_det is not used). 150 | """ 151 | # 1x1 determinants appear to be numerically sensitive and can become 0 152 | # (especially when multiple determinants are used with the spin-factored 153 | # wavefunction). Avoid this by not going into the log domain for 1x1 matrices. 154 | # Pass initial value to functools so det1d = 1 if all matrices are larger than 155 | # 1x1. 156 | det1d = functools.reduce(lambda a, b: a * b, 157 | [x.reshape(-1) for x in xs if x.shape[-1] == 1], 1) 158 | # Pass initial value to functools so sign_in = 1, logdet = 0 if all matrices 159 | # are 1x1. 
160 | phase_in, logdet = functools.reduce( 161 | lambda a, b: (a[0] * b[0], a[1] + b[1]), 162 | [slogdet(x) for x in xs if x.shape[-1] > 1], (1, 0)) 163 | 164 | # log-sum-exp trick 165 | maxlogdet = jnp.max(logdet) 166 | det = phase_in * det1d * jnp.exp(logdet - maxlogdet) 167 | if w is None: 168 | result = jnp.sum(det) 169 | else: 170 | result = jnp.matmul(det, w)[0] 171 | # return phase as a unit-norm complex number, rather than as an angle 172 | if result.dtype == jnp.complex64 or result.dtype == jnp.complex128: 173 | phase_out = jnp.angle(result) # result / jnp.abs(result) 174 | else: 175 | phase_out = jnp.sign(result) 176 | log_out = jnp.log(jnp.abs(result)) + maxlogdet 177 | return phase_out, log_out 178 | -------------------------------------------------------------------------------- /ferminet/tests/psiformer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import itertools 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from ferminet import psiformer 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | 25 | 26 | def _network_options(): 27 | """Yields the set of all combinations of options to pass into test_fermi_net. 
28 | 29 | Example output: 30 | { 31 | 'vmap': True, 32 | 'ndim': 3, 33 | 'determinants': 1, 34 | 'nspins': (1, 0), 35 | 'jastrow': 'none', 36 | 'rescale_inputs': True, 37 | } 38 | """ 39 | # Key for each option and corresponding values to test. 40 | all_options = { 41 | 'vmap': [True, False], 42 | 'ndim': [2, 3], 43 | 'determinants': [1, 4], 44 | 'nspins': [(1, 1), (1, 0), (2, 1)], 45 | 'jastrow': ['none', 'simple_ee'], 46 | 'rescale_inputs': [True, False], 47 | } 48 | # Create the product of all options. 49 | for options in itertools.product(*all_options.values()): 50 | options_dict = dict(zip(all_options.keys(), options)) 51 | yield options_dict 52 | 53 | 54 | class PsiformerTest(parameterized.TestCase): 55 | 56 | @parameterized.parameters( 57 | {'jastrow': 'simple_ee'}, 58 | {'jastrow': 'none'}, 59 | ) 60 | def test_antisymmetry(self, jastrow): 61 | """Check that the Psiformer is antisymmetric.""" 62 | 63 | key = jax.random.PRNGKey(42) 64 | natom = 4 65 | key, *subkeys = jax.random.split(key, num=3) 66 | atoms = jax.random.normal(subkeys[0], shape=(natom, 3)) 67 | charges = jax.random.normal(subkeys[1], shape=(natom,)) 68 | nspins = (3, 4) 69 | determinants = 6 70 | 71 | network = psiformer.make_fermi_net( 72 | nspins, 73 | charges, 74 | determinants=determinants, 75 | ndim=3, 76 | jastrow=jastrow, 77 | num_layers=4, 78 | num_heads=8, 79 | heads_dim=32, 80 | mlp_hidden_dims=(64, 128), 81 | use_layer_norm=True, 82 | tf32=False, 83 | ) 84 | 85 | key, subkey = jax.random.split(key) 86 | params = network.init(subkey) 87 | 88 | # Randomize parameters of envelope 89 | for i in range(len(params['envelope'])): 90 | if params['envelope'][i]: 91 | key, *subkeys = jax.random.split(key, num=3) 92 | params['envelope'][i]['sigma'] = jax.random.normal( 93 | subkeys[0], params['envelope'][i]['sigma'].shape 94 | ) 95 | params['envelope'][i]['pi'] = jax.random.normal( 96 | subkeys[1], params['envelope'][i]['pi'].shape 97 | ) 98 | 99 | key, subkey = jax.random.split(key) 100 | 
pos1 = jax.random.normal(subkey, shape=(sum(nspins) * 3,)) 101 | # Switch position and spin of first and second electrons. 102 | pos2 = jnp.concatenate((pos1[3:6], pos1[:3], pos1[6:])) 103 | # Switch position and spin of fourth and fifth electrons. 104 | pos3 = jnp.concatenate((pos1[:9], pos1[12:15], pos1[9:12], pos1[15:])) 105 | 106 | key, subkey = jax.random.split(key) 107 | spins1 = jax.random.uniform(subkey, shape=(sum(nspins),)) 108 | spins2 = jnp.concatenate((spins1[1:2], spins1[:1], spins1[2:])) 109 | spins3 = jnp.concatenate((spins1[:3], spins1[4:5], spins1[3:4], spins1[5:])) 110 | 111 | out1 = network.apply(params, pos1, spins1, atoms, charges) 112 | 113 | if out1[0].dtype == jnp.float32: 114 | rtol = 1.5e-4 115 | atol = 1.0e-4 116 | else: 117 | rtol = 1.0e-7 118 | atol = 0 119 | 120 | # Output should have the same magnitude but different sign. 121 | out2 = network.apply(params, pos2, spins2, atoms, charges) 122 | with self.subTest('swap up electrons'): 123 | np.testing.assert_allclose(out1[1], out2[1], rtol=rtol, atol=atol) 124 | np.testing.assert_allclose(out1[0], -1 * out2[0]) 125 | 126 | # Output should have the same magnitude but different sign. 
127 | out3 = network.apply(params, pos3, spins3, atoms, charges) 128 | with self.subTest('swap down electrons'): 129 | np.testing.assert_allclose(out1[1], out3[1], rtol, atol=atol) 130 | np.testing.assert_allclose(out1[0], -1 * out3[0]) 131 | 132 | @parameterized.parameters(_network_options()) 133 | def test_psiformer(self, **network_options): 134 | nspins = network_options['nspins'] 135 | ndim = network_options['ndim'] 136 | atoms_3d = jnp.asarray( 137 | [[0.0, 0.0, 0.2], [1.2, 1.0, -0.2], [2.5, -0.8, 0.6]] 138 | ) 139 | atoms = atoms_3d[:, :ndim] 140 | charges = jnp.asarray([2, 5, 7]) 141 | key = jax.random.PRNGKey(42) 142 | 143 | psiformer_config = { 144 | 'num_layers': 4, 145 | 'num_heads': 4, 146 | 'heads_dim': 128, 147 | 'mlp_hidden_dims': (64, 32), 148 | 'use_layer_norm': True, 149 | 'tf32': False, 150 | } 151 | 152 | network = psiformer.make_fermi_net( 153 | nspins, 154 | charges, 155 | determinants=network_options['determinants'], 156 | ndim=ndim, 157 | rescale_inputs=network_options['rescale_inputs'], 158 | jastrow=network_options['jastrow'], 159 | **psiformer_config 160 | ) 161 | 162 | key, subkey = jax.random.split(key) 163 | if network_options['vmap']: 164 | batch = 10 165 | xs = jax.random.uniform(subkey, shape=(batch, sum(nspins), ndim)) 166 | 167 | network_apply = jax.vmap(network.apply, in_axes=(None, 0, 0, None, None)) 168 | expected_shape = (batch,) 169 | else: 170 | batch = 1 171 | xs = jax.random.uniform(subkey, shape=(sum(nspins), ndim)) 172 | network_apply = network.apply 173 | expected_shape = () 174 | 175 | key, subkey = jax.random.split(key) 176 | spins = jax.random.uniform(subkey, shape=(batch, sum(nspins))) 177 | if not network_options['vmap']: 178 | spins = jnp.squeeze(spins, axis=0) 179 | 180 | key, subkey = jax.random.split(key) 181 | params = network.init(subkey) 182 | 183 | sign_out, log_out = network_apply(params, xs, spins, atoms, charges) 184 | self.assertSequenceEqual(sign_out.shape, expected_shape) 185 | 
self.assertSequenceEqual(log_out.shape, expected_shape) 186 | 187 | 188 | if __name__ == '__main__': 189 | absltest.main() 190 | -------------------------------------------------------------------------------- /ferminet/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Super simple checkpoints using numpy.""" 16 | 17 | import dataclasses 18 | import datetime 19 | import os 20 | from typing import Optional 21 | import zipfile 22 | 23 | from absl import logging 24 | from ferminet import networks 25 | from ferminet import observables 26 | import jax 27 | import jax.numpy as jnp 28 | import numpy as np 29 | 30 | 31 | def find_last_checkpoint(ckpt_path: Optional[str] = None) -> Optional[str]: 32 | """Finds most recent valid checkpoint in a directory. 33 | 34 | Args: 35 | ckpt_path: Directory containing checkpoints. 36 | 37 | Returns: 38 | Last QMC checkpoint (ordered by sorting all checkpoints by name in reverse) 39 | or None if no valid checkpoint is found or ckpt_path is not given or doesn't 40 | exist. A checkpoint is regarded as not valid if it cannot be read 41 | successfully using np.load. 
42 | """ 43 | if ckpt_path and os.path.exists(ckpt_path): 44 | files = [f for f in os.listdir(ckpt_path) if 'qmcjax_ckpt_' in f] 45 | # Handle case where last checkpoint is corrupt/empty. 46 | for file in sorted(files, reverse=True): 47 | fname = os.path.join(ckpt_path, file) 48 | with open(fname, 'rb') as f: 49 | try: 50 | np.load(f, allow_pickle=True) 51 | return fname 52 | except (OSError, EOFError, zipfile.BadZipFile): 53 | logging.info('Error loading checkpoint %s. Trying next checkpoint...', 54 | fname) 55 | return None 56 | 57 | 58 | def create_save_path(save_path: Optional[str]) -> str: 59 | """Creates the directory for saving checkpoints, if it doesn't exist. 60 | 61 | Args: 62 | save_path: directory to use. If false, create a directory in the working 63 | directory based upon the current time. 64 | 65 | Returns: 66 | Path to save checkpoints to. 67 | """ 68 | timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H:%M:%S') 69 | default_save_path = os.path.join(os.getcwd(), f'ferminet_{timestamp}') 70 | ckpt_save_path = save_path or default_save_path 71 | if ckpt_save_path and not os.path.isdir(ckpt_save_path): 72 | os.makedirs(ckpt_save_path) 73 | return ckpt_save_path 74 | 75 | 76 | def get_restore_path(restore_path: Optional[str] = None) -> Optional[str]: 77 | """Gets the path containing checkpoints from a previous calculation. 78 | 79 | Args: 80 | restore_path: path to checkpoints. 81 | 82 | Returns: 83 | The path or None if restore_path is falsy. 84 | """ 85 | if restore_path: 86 | ckpt_restore_path = restore_path 87 | else: 88 | ckpt_restore_path = None 89 | return ckpt_restore_path 90 | 91 | 92 | def save(save_path: str, 93 | t: int, 94 | data: networks.FermiNetData, 95 | params, 96 | opt_state, 97 | mcmc_width, 98 | density_state: Optional[observables.DensityState] = None) -> str: 99 | """Saves checkpoint information to a npz file. 100 | 101 | Args: 102 | save_path: path to directory to save checkpoint to. 
The checkpoint file is 103 | save_path/qmcjax_ckpt_$t.npz, where $t is the number of completed 104 | iterations. 105 | t: number of completed iterations. 106 | data: MCMC walker configurations. 107 | params: pytree of network parameters. 108 | opt_state: optimization state. 109 | mcmc_width: width to use in the MCMC proposal distribution. 110 | density_state: optional state of the density matrix calculation 111 | 112 | Returns: 113 | path to checkpoint file. 114 | """ 115 | ckpt_filename = os.path.join(save_path, f'qmcjax_ckpt_{t:06d}.npz') 116 | logging.info('Saving checkpoint %s', ckpt_filename) 117 | with open(ckpt_filename, 'wb') as f: 118 | np.savez( 119 | f, 120 | t=t, 121 | data=dataclasses.asdict(data), 122 | params=params, 123 | opt_state=np.asarray(opt_state, dtype=object), 124 | mcmc_width=mcmc_width, 125 | density_state=(dataclasses.asdict(density_state) 126 | if density_state else None)) 127 | return ckpt_filename 128 | 129 | 130 | def restore(restore_filename: str, batch_size: Optional[int] = None): 131 | """Restores data saved in a checkpoint. 132 | 133 | Args: 134 | restore_filename: filename containing checkpoint. 135 | batch_size: total batch size to be used. If present, check the data saved in 136 | the checkpoint is consistent with the batch size requested for the 137 | calculation. 138 | 139 | Returns: 140 | (t, data, params, opt_state, mcmc_width) tuple, where 141 | t: number of completed iterations. 142 | data: MCMC walker configurations. 143 | params: pytree of network parameters. 144 | opt_state: optimization state. 145 | mcmc_width: width to use in the MCMC proposal distribution. 146 | density_state: optional state of the density matrix calculation 147 | 148 | Raises: 149 | ValueError: if the leading dimension of data does not match the number of 150 | devices (i.e. the number of devices being parallelised over has changed) or 151 | if the total batch size is not equal to the number of MCMC configurations in 152 | data. 
153 | """ 154 | logging.info('Loading checkpoint %s', restore_filename) 155 | with open(restore_filename, 'rb') as f: 156 | ckpt_data = np.load(f, allow_pickle=True) 157 | # Retrieve data from npz file. Non-array variables need to be converted back 158 | # to natives types using .tolist(). 159 | t = ckpt_data['t'].tolist() + 1 # Return the iterations completed. 160 | data = networks.FermiNetData(**ckpt_data['data'].item()) 161 | params = ckpt_data['params'].tolist() 162 | opt_state = ckpt_data['opt_state'].tolist() 163 | mcmc_width = jnp.array(ckpt_data['mcmc_width'].tolist()) 164 | if ckpt_data['density_state']: 165 | density_state = observables.DensityState( 166 | **ckpt_data['density_state'].item()) 167 | else: 168 | density_state = None 169 | if data.positions.shape[0] != jax.device_count(): 170 | raise ValueError( 171 | 'Incorrect number of devices found. Expected' 172 | f' {data.positions.shape[0]}, found {jax.device_count()}.' 173 | ) 174 | if ( 175 | batch_size 176 | and data.positions.shape[0] * data.positions.shape[1] != batch_size 177 | ): 178 | raise ValueError( 179 | f'Wrong batch size in loaded data. Expected {batch_size}, found ' 180 | f'{data.positions.shape[0] * data.positions.shape[1]}.') 181 | return t, data, params, opt_state, mcmc_width, density_state 182 | -------------------------------------------------------------------------------- /ferminet/configs/organic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Organic molecule config. Bicyclobutane, butadiene and cyclobutadiene.""" 16 | # Geometries for bicyclobutane to butadiene transition taken from 17 | # CASSCF(10,10)/cc-pVDZ calculations used in 18 | # A. Kinal and P. Piecuch, J. Phys. Chem. A 2007, 111, 734-742 19 | # Geometries for cyclobutadiene automerization taken from MR-BWCCSD(T) 20 | # a.c./cc-pVTZ calculations from 21 | # Bhaskaran-Nair et al., J. Chem. Phys. 2008, 129, 184104, and are the same as 22 | # those in J. Hermann, Z. Schätzle and F Noé, Nat. Chem. 2020, 12, 891–897 23 | 24 | from ferminet import base_config 25 | from ferminet.utils import system 26 | 27 | systems = { 28 | 'bicbut': [['C', (1.0487346562, 0.5208579773, 0.2375867187)], 29 | ['C', (0.2497284256, -0.7666691493, 0.0936474818)], 30 | ['C', (-0.1817326465, 0.4922777820, -0.6579637266)], 31 | ['C', (-1.1430708301, -0.1901383337, 0.3048494250)], 32 | ['H', (2.0107137141, 0.5520589541, -0.2623459977)], 33 | ['H', (1.0071921280, 1.0672669240, 1.1766131856)], 34 | ['H', (0.5438033167, -1.7129829738, -0.3260782874)], 35 | ['H', (-0.2580605320, 0.6268443026, -1.7229636111)], 36 | ['H', (-1.3778676954, 0.2935640723, 1.2498189977)], 37 | ['H', (-1.9664163102, -0.7380906148, -0.1402911727)]], 38 | 'con_TS': [['C', (1.0422528085, 0.5189448459, 0.2893513723)], 39 | ['C', (0.6334392052, -0.8563584473, -0.1382423606)], 40 | ['C', (-0.2492035181, 0.3134656784, -0.5658962512)], 41 | ['C', (-1.3903646889, 0.0535204487, 0.2987506023)], 42 | ['H', (1.8587636947, 0.9382817031, -0.2871146890)], 43 | ['H', 
(0.9494853889, 0.8960565051, 1.3038563129)], 44 | ['H', (0.3506375894, -1.7147937260, 0.4585707483)], 45 | ['H', (-0.3391417369, 0.6603641863, -1.5850373819)], 46 | ['H', (-1.2605467656, 0.0656225945, 1.3701508857)], 47 | ['H', (-2.3153892612, -0.3457478660, -0.0991685880)]], 48 | 'dis_TS': [['C', (1.5864390444, -0.1568990400, -0.1998155990)], 49 | ['C', (-0.8207390911, 0.8031532550, -0.2771554962)], 50 | ['C', (0.2514913592, 0.0515423448, 0.4758741643)], 51 | ['C', (-1.0037104567, -0.6789877402, -0.0965401189)], 52 | ['H', (2.4861305372, 0.1949133826, 0.2874101433)], 53 | ['H', (1.6111805503, -0.2769458302, -1.2753251100)], 54 | ['H', (-1.4350764228, 1.6366792379, 0.0289087336)], 55 | ['H', (0.2833919284, 0.1769734467, 1.5525271253)], 56 | ['H', (-1.7484283536, -1.0231589431, 0.6120702030)], 57 | ['H', (-0.8524391649, -1.3241689195, -0.9544331346)]], 58 | 'g-but': [['C', (1.4852019019, 0.4107781008, 0.5915178362)], 59 | ['C', (0.7841417614, -0.4218449588, -0.2276848579)], 60 | ['C', (-0.6577970182, -0.2577617373, -0.6080850660)], 61 | ['C', (-1.6247236649, 0.2933006709, 0.1775352473)], 62 | ['H', (1.0376813593, 1.2956518484, 1.0267024109)], 63 | ['H', (2.5232360753, 0.2129135014, 0.8248568552)], 64 | ['H', (1.2972328960, -1.2700686671, -0.6686116041)], 65 | ['H', (-0.9356614935, -0.6338686329, -1.5871170536)], 66 | ['H', (-1.4152018269, 0.6472889925, 1.1792563311)], 67 | ['H', (-2.6423222755, 0.3847635835, -0.1791755263)]], 68 | 'gt-TS': [['C', (1.7836595975, 0.4683155866, -0.4860478101)], 69 | ['C', (0.7828892933, -0.4014025715, -0.1873880949)], 70 | ['C', (-0.6557274850, -0.2156646805, -0.6243545354)], 71 | ['C', (-1.6396999531, 0.2526943506, 0.1877948644)], 72 | ['H', (1.6003117673, 1.3693309737, -1.0595471944)], 73 | ['H', (2.7986234673, 0.2854595500, -0.1564989895)], 74 | ['H', (1.0128486304, -1.2934621995, 0.3872559845)], 75 | ['H', (-0.9003245968, -0.4891235826, -1.6462438855)], 76 | ['H', (-1.4414954784, 0.5345813494, 1.2152198579)], 77 | ['H', 
(-2.6556262424, 0.3594422237, -0.1709361970)]], 78 | 't-but': [['C', (0.6109149108, 1.7798412991, -0.0000000370)], 79 | ['C', (0.6162339625, 0.4163908910, -0.0000000070)], 80 | ['C', (-0.6162376752, -0.4163867945, -0.0000000601)], 81 | ['C', (-0.6109129465, -1.7798435851, 0.0000000007)], 82 | ['H', (1.5340442204, 2.3439205382, 0.0000000490)], 83 | ['H', (-0.3156117962, 2.3419017314, 0.0000000338)], 84 | ['H', (1.5642720455, -0.1114324578, -0.0000000088)], 85 | ['H', (-1.5642719469, 0.1114307897, -0.0000000331)], 86 | ['H', (-1.5340441021, -2.3439203971, 0.0000000714)], 87 | ['H', (0.3156133277, -2.3419020150, -0.0000000088)]], 88 | 'cycbut-ground': [['C', (0.0000000e+00, 0.0000000e+00, 0.0000000e+00)], 89 | ['C', (2.9555318e+00, 0.0000000e+00, 0.0000000e+00)], 90 | ['C', (2.9555318e+00, 2.5586891e+00, 0.0000000e+00)], 91 | ['C', (0.0000000e+00, 2.5586891e+00, 0.0000000e+00)], 92 | ['H', (-1.4402903e+00, -1.4433100e+00, 1.7675451e-16)], 93 | ['H', (4.3958220e+00, -1.4433100e+00, -1.7675451e-16)], 94 | ['H', (4.3958220e+00, 4.0019994e+00, 1.7675451e-16)], 95 | ['H', (-1.4402903e+00, 4.0019994e+00, -1.7675451e-16)]], 96 | 'cycbut-trans': [['C', (0.0000000e+00, 0.0000000e+00, 0.0000000e+00)], 97 | ['C', (2.7419927e+00, 0.0000000e+00, 0.0000000e+00)], 98 | ['C', (2.7419927e+00, 2.7419927e+00, 0.0000000e+00)], 99 | ['C', (0.0000000e+00, 2.7419927e+00, 0.0000000e+00)], 100 | ['H', (-1.4404647e+00, -1.4404647e+00, 1.7640606e-16)], 101 | ['H', (4.1824574e+00, -1.4404647e+00, -1.7640606e-16)], 102 | ['H', (4.1824574e+00, 4.1824574e+00, 1.7640606e-16)], 103 | ['H', (-1.4404647e+00, 4.1824574e+00, -1.7640606e-16)]] 104 | } 105 | 106 | 107 | def organic_molecule(cfg): 108 | """Sets molecule and electrons fields in ConfigDict.""" 109 | cfg.system.electrons = (15, 15) 110 | units = 'angstrom' 111 | if 'cycbut' in cfg.system.molecule_name: 112 | cfg.system.electrons = (14, 14) 113 | units = 'bohr' 114 | if cfg.system.molecule_name not in systems: 115 | raise 
ValueError(f'Unrecognized molecule: {cfg.system.molecule_name}') 116 | molecule = [] 117 | for element, coords in systems[cfg.system.molecule_name]: 118 | molecule.append(system.Atom(symbol=element, coords=coords, units=units)) 119 | cfg.system.molecule = molecule 120 | return cfg 121 | 122 | 123 | def get_config(): 124 | """Returns config for running bibut.""" 125 | name = 'organic' 126 | cfg = base_config.default() 127 | cfg.system.molecule_name = name 128 | with cfg.ignore_type(): 129 | cfg.system.set_molecule = organic_molecule 130 | cfg.config_module = '.organic' 131 | return cfg 132 | -------------------------------------------------------------------------------- /ferminet/tests/hamiltonian_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ferminet.hamiltonian.""" 16 | 17 | import itertools 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from ferminet import base_config 22 | from ferminet import hamiltonian 23 | from ferminet import networks 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | 28 | 29 | def h_atom_log_psi(param, xs, spins, atoms=None, charges=None): 30 | del param, spins, atoms, charges 31 | # log of exact hydrogen wavefunction. 
32 | return -jnp.abs(jnp.linalg.norm(xs)) 33 | 34 | 35 | def h_atom_log_psi_signed(param, xs, spins, atoms=None, charges=None): 36 | log_psi = h_atom_log_psi(param, xs, spins, atoms, charges) 37 | return jnp.ones_like(log_psi), log_psi 38 | 39 | 40 | def kinetic_from_hessian(log_f): 41 | 42 | def kinetic_operator(params, pos, spins, atoms, charges): 43 | f = lambda x: jnp.exp(log_f(params, x, spins, atoms, charges)) 44 | ys = f(pos) 45 | hess = jax.hessian(f)(pos) 46 | return -0.5 * jnp.trace(hess) / ys 47 | 48 | return kinetic_operator 49 | 50 | 51 | def kinetic_from_hessian_log(log_f): 52 | 53 | def kinetic_operator(params, pos, spins, atoms, charges): 54 | f = lambda x: log_f(params, x, spins, atoms, charges) 55 | grad_f = jax.grad(f)(pos) 56 | hess = jax.hessian(f)(pos) 57 | return -0.5 * (jnp.trace(hess) + jnp.sum(grad_f**2)) 58 | 59 | return kinetic_operator 60 | 61 | 62 | class HamiltonianTest(parameterized.TestCase): 63 | 64 | @parameterized.parameters(['default', 'folx']) 65 | def test_local_kinetic_energy(self, laplacian): 66 | 67 | dummy_params = {} 68 | xs = np.random.normal(size=(3,)) 69 | spins = np.ones(shape=(1,)) 70 | atoms = np.random.normal(size=(1, 3)) 71 | charges = 2 * np.ones(shape=(1,)) 72 | expected_kinetic_energy = -(1 - 2 / np.abs(np.linalg.norm(xs))) / 2 73 | 74 | kinetic = hamiltonian.local_kinetic_energy(h_atom_log_psi_signed, 75 | laplacian_method=laplacian) 76 | kinetic_energy = kinetic( 77 | dummy_params, 78 | networks.FermiNetData( 79 | positions=xs, spins=spins, atoms=atoms, charges=charges 80 | ), 81 | ) 82 | np.testing.assert_allclose( 83 | kinetic_energy, expected_kinetic_energy, rtol=1.e-5) 84 | 85 | def test_potential_energy_null(self): 86 | 87 | # with one electron and a nuclear charge of zero, the potential energy is 88 | # zero. 
89 | xs = np.random.normal(size=(1, 3)) 90 | r_ae = jnp.linalg.norm(xs, axis=-1) 91 | r_ee = jnp.zeros(shape=(1, 1, 1)) 92 | atoms = jnp.zeros(shape=(1, 3)) 93 | charges = jnp.zeros(shape=(1,)) 94 | v = hamiltonian.potential_energy(r_ae, r_ee, atoms, charges) 95 | np.testing.assert_allclose(v, 0.0, rtol=1E-5) 96 | 97 | def test_potential_energy_ee(self): 98 | 99 | xs = np.random.normal(size=(5, 3)) 100 | r_ae = jnp.linalg.norm(xs, axis=-1) 101 | r_ee = jnp.linalg.norm(xs[None, ...] - xs[:, None, :], axis=-1) 102 | atoms = jnp.zeros(shape=(1, 3)) 103 | charges = jnp.zeros(shape=(1,)) 104 | mask = ~jnp.eye(r_ee.shape[0], dtype=bool) 105 | expected_v_ee = 0.5 * np.sum(1.0 / r_ee[mask]) 106 | v = hamiltonian.potential_energy(r_ae, r_ee[..., None], atoms, charges) 107 | np.testing.assert_allclose(v, expected_v_ee, rtol=1E-5) 108 | 109 | def test_potential_energy_he2_ion(self): 110 | 111 | xs = np.random.normal(size=(1, 3)) 112 | atoms = jnp.array([[0, 0, -1], [0, 0, 1]]) 113 | r_ae = jnp.linalg.norm(xs - atoms, axis=-1) 114 | r_ee = jnp.zeros(shape=(1, 1, 1)) 115 | charges = jnp.array([2, 2]) 116 | v_ee = -jnp.sum(charges / r_ae) 117 | v_ae = jnp.prod(charges) / jnp.linalg.norm(jnp.diff(atoms, axis=0)) 118 | expected_v = v_ee + v_ae 119 | v = hamiltonian.potential_energy(r_ae[..., None], r_ee, atoms, charges) 120 | np.testing.assert_allclose(v, expected_v, rtol=1E-5) 121 | 122 | def test_local_energy(self): 123 | 124 | spins = np.ones(shape=(1,)) 125 | atoms = np.zeros(shape=(1, 3)) 126 | charges = np.ones(shape=(1,)) 127 | dummy_params = {} 128 | local_energy = hamiltonian.local_energy( 129 | h_atom_log_psi_signed, charges, nspins=(1, 0), use_scan=False 130 | ) 131 | 132 | xs = np.random.normal(size=(100, 3)) 133 | key = jax.random.PRNGKey(4) 134 | keys = jax.random.split(key, num=xs.shape[0]) 135 | batch_local_energy = jax.vmap( 136 | local_energy, 137 | in_axes=( 138 | None, 139 | 0, 140 | networks.FermiNetData( 141 | positions=0, spins=None, atoms=None, charges=None 
142 | ), 143 | ), 144 | ) 145 | energies, _ = batch_local_energy( 146 | dummy_params, 147 | keys, 148 | networks.FermiNetData( 149 | positions=xs, spins=spins, atoms=atoms, charges=charges 150 | ), 151 | ) 152 | 153 | np.testing.assert_allclose( 154 | energies, -0.5 * np.ones_like(energies), rtol=1E-5) 155 | 156 | 157 | class LaplacianTest(parameterized.TestCase): 158 | 159 | @parameterized.parameters(['default', 'folx']) 160 | def test_laplacian(self, laplacian): 161 | 162 | xs = np.random.uniform(size=(100, 3)) 163 | spins = np.ones(shape=(1,)) 164 | atoms = np.random.normal(size=(1, 3)) 165 | charges = 3 * np.ones(shape=(1,)) 166 | data = networks.FermiNetData( 167 | positions=xs, spins=spins, atoms=atoms, charges=charges 168 | ) 169 | dummy_params = {} 170 | t_l_fn = jax.vmap( 171 | hamiltonian.local_kinetic_energy(h_atom_log_psi_signed, 172 | laplacian_method=laplacian), 173 | in_axes=( 174 | None, 175 | networks.FermiNetData( 176 | positions=0, spins=None, atoms=None, charges=None 177 | ), 178 | ), 179 | ) 180 | t_l = t_l_fn(dummy_params, data) 181 | hess_t = jax.vmap( 182 | kinetic_from_hessian(h_atom_log_psi), 183 | in_axes=(None, 0, None, None, None), 184 | )(dummy_params, xs, spins, atoms, charges) 185 | np.testing.assert_allclose(t_l, hess_t, rtol=1E-5) 186 | 187 | @parameterized.parameters( 188 | itertools.product([True, False], ['default', 'folx']) 189 | ) 190 | def test_fermi_net_laplacian(self, full_det, laplacian): 191 | natoms = 2 192 | np.random.seed(12) 193 | atoms = np.random.uniform(low=-5.0, high=5.0, size=(natoms, 3)) 194 | nspins = (2, 3) 195 | charges = 2 * np.ones(shape=(natoms,)) 196 | batch = 4 197 | cfg = base_config.default() 198 | cfg.network.full_det = full_det 199 | cfg.network.ferminet.hidden_dims = ((8, 4),) * 2 200 | cfg.network.determinants = 2 201 | feature_layer = networks.make_ferminet_features( 202 | natoms, 203 | cfg.system.electrons, 204 | cfg.system.ndim, 205 | ) 206 | network = networks.make_fermi_net( 207 | nspins, 208 
| charges, 209 | full_det=full_det, 210 | feature_layer=feature_layer, 211 | **cfg.network.ferminet 212 | ) 213 | log_network = lambda *args, **kwargs: network.apply(*args, **kwargs)[1] 214 | key = jax.random.PRNGKey(47) 215 | params = network.init(key) 216 | xs = np.random.normal(scale=5, size=(batch, sum(nspins) * 3)) 217 | spins = np.sign(np.random.normal(scale=1, size=(batch, sum(nspins)))) 218 | t_l_fn = jax.jit( 219 | jax.vmap( 220 | hamiltonian.local_kinetic_energy(network.apply, 221 | laplacian_method=laplacian), 222 | in_axes=( 223 | None, 224 | networks.FermiNetData( 225 | positions=0, spins=0, atoms=None, charges=None 226 | ), 227 | ), 228 | ) 229 | ) 230 | t_l = t_l_fn( 231 | params, 232 | networks.FermiNetData( 233 | positions=xs, spins=spins, atoms=atoms, charges=charges 234 | ), 235 | ) 236 | hess_t_fn = jax.jit( 237 | jax.vmap( 238 | kinetic_from_hessian_log(log_network), 239 | in_axes=(None, 0, 0, None, None), 240 | ) 241 | ) 242 | hess_t = hess_t_fn(params, xs, spins, atoms, charges) 243 | if hess_t.dtype == jnp.float64: 244 | atol, rtol = 1.e-10, 1.e-10 245 | else: 246 | # This needs a low tolerance because on fast math optimization in CPU can 247 | # substantially affect floating point expressions. See 248 | # https://github.com/jax-ml/jax/issues/6566. 249 | atol, rtol = 4.e-3, 4.e-3 250 | np.testing.assert_allclose(t_l, hess_t, atol=atol, rtol=rtol) 251 | 252 | 253 | if __name__ == '__main__': 254 | absltest.main() 255 | -------------------------------------------------------------------------------- /ferminet/pbc/hamiltonian.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ewald summation of Coulomb Hamiltonian in periodic boundary conditions.

See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D.,
Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions
with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.
"""

import itertools
from typing import Callable, Optional, Sequence, Tuple

import chex
from ferminet import hamiltonian
from ferminet import networks
import jax
import jax.numpy as jnp


def make_ewald_potential(
    lattice: jnp.ndarray,
    atoms: jnp.ndarray,
    charges: jnp.ndarray,
    truncation_limit: int = 5,
    include_heg_background: bool = True
) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
  """Creates a function to evaluate infinite Coulomb sum for periodic lattice.

  Args:
    lattice: Shape (3, 3). Matrix whose columns are the primitive lattice
      vectors.
    atoms: Shape (natoms, ndim). Positions of the atoms.
    charges: Shape (natoms). Nuclear charges of the atoms.
    truncation_limit: Integer. Half side length of cube of nearest neighbours
      to primitive cell which are summed over in evaluation of Ewald sum.
      Must be large enough to achieve convergence for the real and reciprocal
      space sums.
    include_heg_background: bool. When True, includes cell-neutralizing
      background term for homogeneous electron gas. Concretely, the
      electron-electron pair terms are not shifted by the Madelung constant;
      instead 0.5 * nelec * madelung_const is added once. When False, the
      Madelung constant is subtracted from every electron-electron pair term.

  Returns:
    Callable with signature f(ae, ee), where (ae, ee) are atom-electron and
    electron-electron displacement vectors respectively, which evaluates the
    Coulomb sum for the periodic lattice via the Ewald method. ae is expected
    to have shape (nelec, natoms, 3) and ee shape (nelec, nelec, 3). The
    returned value is the real part of the total (atom-electron +
    electron-electron + atom-atom) potential.
  """
  # Rows of rec are the reciprocal lattice vectors (2*pi*inverse of lattice).
  rec = 2 * jnp.pi * jnp.linalg.inv(lattice)
  volume = jnp.abs(jnp.linalg.det(lattice))
  # the factor gamma tunes the width of the summands in real / reciprocal space
  # and this value is chosen to optimize the convergence trade-off between the
  # two sums. See CASINO QMC manual.
  gamma = (2.8 / volume**(1 / 3))**2
  # Sorting by |n| puts 0 first, so row 0 of the product below is the origin
  # (0, 0, 0). The [1:] slices below therefore exclude the R=0 / G=0 terms.
  ordinals = sorted(range(-truncation_limit, truncation_limit + 1), key=abs)
  ordinals = jnp.array(list(itertools.product(ordinals, repeat=3)))
  lat_vectors = jnp.einsum('kj,ij->ik', lattice, ordinals)
  rec_vectors = jnp.einsum('jk,ij->ik', rec, ordinals[1:])
  rec_vec_square = jnp.einsum('ij,ij->i', rec_vectors, rec_vectors)
  lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1)

  def real_space_ewald(separation: jnp.ndarray):
    """Real-space Ewald potential between charges separated by separation."""
    displacements = jnp.linalg.norm(
        separation - lat_vectors, axis=-1)  # |r - R|
    return jnp.sum(
        jax.scipy.special.erfc(gamma**0.5 * displacements) / displacements)

  def recp_space_ewald(separation: jnp.ndarray):
    """Returns reciprocal-space Ewald potential between charges."""
    # Complex-valued; callers take the real part of the final total.
    return (4 * jnp.pi / volume) * jnp.sum(
        jnp.exp(1.0j * jnp.dot(rec_vectors, separation)) *
        jnp.exp(-rec_vec_square / (4 * gamma)) / rec_vec_square)

  def ewald_sum(separation: jnp.ndarray):
    """Evaluates combined real and reciprocal space Ewald potential."""
    return (real_space_ewald(separation) + recp_space_ewald(separation) -
            jnp.pi / (volume * gamma))

  # Self-interaction of a charge with its own periodic images: the same sums
  # as ewald_sum with the R=0 real-space term removed, minus the
  # 2*sqrt(gamma/pi) self term.
  madelung_const = (
      jnp.sum(jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / lat_vec_norm)
      - 2 * gamma**0.5 / jnp.pi**0.5)
  madelung_const += (
      (4 * jnp.pi / volume) *
      jnp.sum(jnp.exp(-rec_vec_square / (4 * gamma)) / rec_vec_square) -
      jnp.pi / (volume * gamma))

  batch_ewald_sum = jax.vmap(ewald_sum, in_axes=(0,))

  def atom_electron_potential(ae: jnp.ndarray):
    """Evaluates periodic atom-electron potential."""
    nelec = ae.shape[0]
    ae = jnp.reshape(ae, [-1, 3])  # flatten electronxatom axis
    # calculate potential for each ae pair
    ewald = batch_ewald_sum(ae) - madelung_const
    # Electron-major flattening above, so the charges repeat once per electron.
    return jnp.sum(-jnp.tile(charges, nelec) * ewald)

  def electron_electron_potential(ee: jnp.ndarray):
    """Evaluates periodic electron-electron potential."""
    nelec = ee.shape[0]
    ee = jnp.reshape(ee, [-1, 3])
    if include_heg_background:
      ewald = batch_ewald_sum(ee)
    else:
      ewald = batch_ewald_sum(ee) - madelung_const
    ewald = jnp.reshape(ewald, [nelec, nelec])
    # Diagonal (zero-separation) terms are discarded.
    ewald = ewald.at[jnp.diag_indices(nelec)].set(0.0)
    # Factor 0.5 because each pair appears twice in the (nelec, nelec) matrix.
    if include_heg_background:
      return 0.5 * jnp.sum(ewald) + 0.5 * nelec * madelung_const
    else:
      return 0.5 * jnp.sum(ewald)

  # Atom-atom potential. Computed once here since it depends only on the
  # atoms and charges passed to this factory, not on the electron positions.
  natom = atoms.shape[0]
  if natom > 1:
    aa = jnp.reshape(atoms, [1, -1, 3]) - jnp.reshape(atoms, [-1, 1, 3])
    aa = jnp.reshape(aa, [-1, 3])
    chargeprods = (charges[..., None] @ charges[..., None].T).flatten()
    ewald = batch_ewald_sum(aa) - madelung_const
    ewald = jnp.reshape(ewald, [natom, natom])
    ewald = ewald.at[jnp.diag_indices(natom)].set(0.0)
    ewald = ewald.flatten()
    atom_atom_potential = 0.5 * jnp.sum(chargeprods * ewald)
  else:
    atom_atom_potential = 0.0

  def potential(ae: jnp.ndarray, ee: jnp.ndarray):
    """Accumulates atom-electron, atom-atom, and electron-electron potential."""
    # Reduce vectors into first unit cell - Ewald summation
    # is only guaranteed to converge close to the origin
    # phase_* are fractional coordinates (inverse lattice times displacement);
    # `% 1` maps them into [0, 1) before converting back to Cartesian.
    phase_ae = jnp.einsum('il,jkl->jki', rec / (2 * jnp.pi), ae)
    phase_ee = jnp.einsum('il,jkl->jki', rec / (2 * jnp.pi), ee)
    phase_prim_ae = phase_ae % 1
    phase_prim_ee = phase_ee % 1
    prim_ae = jnp.einsum('il,jkl->jki', lattice, phase_prim_ae)
    prim_ee = jnp.einsum('il,jkl->jki', lattice, phase_prim_ee)
    return jnp.real(
        atom_electron_potential(prim_ae) +
        electron_electron_potential(prim_ee) + atom_atom_potential)

  return potential


def local_energy(
    f: networks.FermiNetLike,
    charges: jnp.ndarray,
    nspins: Sequence[int],
    use_scan: bool = False,
    complex_output: bool = False,
    laplacian_method: str = 'default',
    states: int = 0,
    lattice: Optional[jnp.ndarray] = None,
    heg: bool = True,
    convergence_radius: int = 5,
) -> hamiltonian.LocalEnergy:
  """Creates the local energy function in periodic boundary conditions.

  Args:
    f: Callable which returns the sign and log of the magnitude of the
      wavefunction given the network parameters and configurations data.
    charges: Shape (natoms). Nuclear charges of the atoms.
    nspins: Number of particles of each spin. Unused here.
    use_scan: Whether to use a `lax.scan` for computing the laplacian.
    complex_output: If true, the output of f is complex-valued.
    laplacian_method: Laplacian calculation method. One of:
      'default': take jvp(grad), looping over inputs
      'folx': use Microsoft's implementation of forward laplacian
    states: Number of excited states to compute. Not implemented, only present
      for consistency of calling convention.
    lattice: Shape (ndim, ndim). Matrix of lattice vectors. Default: identity
      matrix. make_ewald_potential reshapes displacements to [-1, 3], so in
      practice ndim must be 3.
    heg: bool. Flag to enable features specific to the electron gas. Passed to
      make_ewald_potential as include_heg_background.
    convergence_radius: int. Radius of cluster summed over by Ewald sums.
      Passed to make_ewald_potential as truncation_limit.

  Returns:
    Callable with signature e_l(params, key, data) which evaluates the local
    energy of the wavefunction given the parameters params, RNG state key,
    and a single MCMC configuration in data. It returns a tuple
    (local energy, None).

  Raises:
    NotImplementedError: if states is nonzero.
  """
  if states:
    raise NotImplementedError('Excited states not implemented with PBC.')
  del nspins
  if lattice is None:
    lattice = jnp.eye(3)

  ke = hamiltonian.local_kinetic_energy(f,
                                        use_scan=use_scan,
                                        complex_output=complex_output,
                                        laplacian_method=laplacian_method)

  def _e_l(
      params: networks.ParamTree, key: chex.PRNGKey, data: networks.FermiNetData
  ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Returns the total energy.

    Args:
      params: network parameters.
      key: RNG state.
      data: MCMC configuration.

    Returns:
      (potential + kinetic energy, None).
    """
    del key  # unused
    # The Ewald potential is rebuilt on every call because the atom positions
    # come from `data`.
    # NOTE(review): presumably cheap under jit (traced once) - confirm if this
    # is ever called un-jitted.
    potential_energy = make_ewald_potential(
        lattice, data.atoms, charges, convergence_radius, heg
    )
    ae, ee, _, _ = networks.construct_input_features(
        data.positions, data.atoms)
    potential = potential_energy(ae, ee)
    kinetic = ke(params, data)
    return potential + kinetic, None

  return _e_l

--------------------------------------------------------------------------------
/ferminet/tests/excited_test.py:
--------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Test experiments from Pfau, Axelrod, Sutterud, von Glehn, Spencer (2024).""" 16 | import itertools 17 | import os 18 | 19 | from absl import flags 20 | from absl import logging 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | import chex 24 | from ferminet import base_config 25 | from ferminet import train 26 | from ferminet.configs.excited import atoms 27 | from ferminet.configs.excited import benzene 28 | from ferminet.configs.excited import carbon_dimer 29 | from ferminet.configs.excited import double_excitation 30 | from ferminet.configs.excited import oscillator 31 | from ferminet.configs.excited import presets 32 | from ferminet.configs.excited import twisted_ethylene 33 | import jax 34 | 35 | import pyscf 36 | 37 | FLAGS = flags.FLAGS 38 | # Default flags are sufficient so mark FLAGS as parsed so we can run the tests 39 | # with py.test, which imports this file rather than runs it. 40 | FLAGS.mark_as_parsed() 41 | 42 | 43 | def setUpModule(): 44 | # Allow chex_n_cpu_devices to be set via an environment variable as well as 45 | # --chex_n_cpu_devices to play nicely with pytest. 46 | fake_devices = os.environ.get('FERMINET_CHEX_N_CPU_DEVICES') 47 | if fake_devices is not None: 48 | fake_devices = int(fake_devices) 49 | try: 50 | chex.set_n_cpu_devices(n=fake_devices) 51 | except RuntimeError: 52 | # jax has already been initialised (e.g. because this is being run with 53 | # other tests via a test runner such as pytest) 54 | logging.info('JAX already initialised so cannot set number of CPU devices. 
' 55 | 'Using a single device in train_test.') 56 | jax.clear_caches() 57 | 58 | 59 | def minimize_config(cfg): 60 | """Change fields in config to minimal values appropriate for fast testing.""" 61 | cfg.network.ferminet.hidden_dims = ((16, 4),) * 2 62 | cfg.network.psiformer.heads_dim = 4 63 | cfg.network.psiformer.mlp_hidden_dims = (16,) 64 | cfg.network.determinants = 2 65 | cfg.batch_size = 32 66 | cfg.pretrain.iterations = 10 67 | cfg.mcmc.burn_in = 10 68 | cfg.optim.iterations = 3 69 | return cfg 70 | 71 | 72 | class ExcitedStateTest(parameterized.TestCase): 73 | 74 | def setUp(self): 75 | super(ExcitedStateTest, self).setUp() 76 | # disable use of temp directory in pyscf. 77 | # Test calculations are small enough to fit in RAM and we don't need 78 | # checkpoint files. 79 | pyscf.lib.param.TMPDIR = None 80 | 81 | @parameterized.parameters(itertools.product( 82 | ['Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne'], 83 | [(5, 'ferminet'), (5, 'psiformer'), (10, 'psiformer')])) 84 | def test_atoms(self, atom, state_and_network): 85 | """Test experiments from Fig. 1 of Pfau et al. (2024).""" 86 | states, network = state_and_network 87 | cfg = atoms.get_config() 88 | cfg.system.atom = atom 89 | cfg.system.states = states 90 | match network: 91 | case 'ferminet': 92 | cfg.update_from_flattened_dict(presets.ferminet) 93 | case 'psiformer': 94 | cfg.update_from_flattened_dict(presets.psiformer) 95 | case _: 96 | raise ValueError(f'Unknown network type: {network}') 97 | cfg = minimize_config(cfg) 98 | cfg.log.save_path = self.create_tempdir().full_path 99 | cfg = base_config.resolve(cfg) 100 | # Calculation is too small to test the results for accuracy. Test just to 101 | # ensure they actually run without a top-level error. 
102 | train.train(cfg) 103 | 104 | @parameterized.parameters(itertools.product( 105 | ['BH', 'HCl', 'H2O', 'H2S', 'BF', 'CO', 'C2H4', 106 | 'CH2O', 'CH2S', 'HNO', 'HCF', 'H2CSi'], 107 | ['ferminet', 'psiformer'])) 108 | def test_oscillator(self, system, network): 109 | """Test experiments from Fig. 2 of Pfau et al. (2024).""" 110 | cfg = oscillator.get_config() 111 | cfg.system.molecule_name = system 112 | if system in ['HNO', 'HCF']: 113 | cfg.pretrain.excitation_type = 'random' 114 | match network: 115 | case 'ferminet': 116 | cfg.update_from_flattened_dict(presets.ferminet) 117 | case 'psiformer': 118 | cfg.update_from_flattened_dict(presets.psiformer) 119 | case _: 120 | raise ValueError(f'Unknown network type: {network}') 121 | cfg = minimize_config(cfg) 122 | cfg.log.save_path = self.create_tempdir().full_path 123 | cfg = base_config.resolve(cfg) 124 | # Calculation is too small to test the results for accuracy. Test just to 125 | # ensure they actually run without a top-level error. 126 | train.train(cfg) 127 | 128 | @parameterized.parameters( 129 | [0.8, 0.9, 0.95, 1.0, 1.05, 1.1, 1.2, 1.3, 1.4, 1.5]) 130 | def test_carbon_dimer(self, equilibrium_multiple): 131 | """Test experiments from Fig. 3 of Pfau et al. (2024).""" 132 | cfg = carbon_dimer.get_config() 133 | cfg.system.equilibrium_multiple = equilibrium_multiple 134 | cfg = minimize_config(cfg) 135 | cfg.log.save_path = self.create_tempdir().full_path 136 | cfg = base_config.resolve(cfg) 137 | # Calculation is too small to test the results for accuracy. Test just to 138 | # ensure they actually run without a top-level error. 139 | train.train(cfg) 140 | 141 | @parameterized.parameters(itertools.chain( 142 | itertools.product(['planar'], [0, 15, 30, 45, 60, 70, 80, 85, 90]), 143 | itertools.product(['twisted'], [0, 20, 40, 60, 70, 80, 90, 95, 97.5, 144 | 100, 102.5, 105, 110, 120]), 145 | )) 146 | def test_twisted_ethylene(self, system, angle): 147 | """Test experiments from Fig. 4 of Pfau et al. 
(2024).""" 148 | cfg = twisted_ethylene.get_config() 149 | cfg.system.molecule_name = system 150 | match system: 151 | case 'planar': 152 | cfg.system.twist.tau = angle 153 | case 'twisted': 154 | cfg.system.twist.phi = angle 155 | case _: 156 | raise ValueError(f'Unknown system type: {system}') 157 | cfg = minimize_config(cfg) 158 | cfg.log.save_path = self.create_tempdir().full_path 159 | cfg = base_config.resolve(cfg) 160 | # Calculation is too small to test the results for accuracy. Test just to 161 | # ensure they actually run without a top-level error. 162 | train.train(cfg) 163 | 164 | @parameterized.parameters( 165 | itertools.product(['nitrosomethane', 166 | 'butadiene', 167 | 'glyoxal', 168 | 'tetrazine', 169 | 'cyclopentadienone'], 170 | ['ferminet'])) 171 | def test_double_excitation(self, system, network): 172 | """Test experiments from Fig. 5 of Pfau et al. (2024).""" 173 | cfg = double_excitation.get_config() 174 | cfg.system.molecule_name = system 175 | 176 | match system: 177 | case 'nitrosomethane': 178 | cfg.system.states = 6 179 | cfg.pretrain.excitation_type = 'random' 180 | case 'butadiene': 181 | cfg.system.states = 7 182 | cfg.pretrain.excitation_type = 'random' 183 | case 'glyoxal': 184 | cfg.system.states = 7 185 | cfg.pretrain.excitation_type = 'random' 186 | case 'tetrazine': 187 | cfg.system.states = 5 188 | cfg.optim.spin_energy = 0.5 189 | cfg.mcmc.blocks = 4 190 | case 'cyclopentadienone': 191 | cfg.system.states = 6 192 | cfg.optim.spin_energy = 1.0 193 | cfg.mcmc.blocks = 4 194 | case _: 195 | raise ValueError(f'Unknown system type: {system}') 196 | 197 | match network: 198 | case 'ferminet': 199 | cfg.update_from_flattened_dict(presets.ferminet) 200 | case'psiformer': 201 | cfg.update_from_flattened_dict(presets.psiformer) 202 | case _: 203 | raise ValueError(f'Unknown network type: {network}') 204 | 205 | cfg = minimize_config(cfg) 206 | cfg.log.save_path = self.create_tempdir().full_path 207 | cfg = base_config.resolve(cfg) 208 
| # Calculation is too small to test the results for accuracy. Test just to 209 | # ensure they actually run without a top-level error. 210 | train.train(cfg) 211 | 212 | @parameterized.parameters(['ferminet', 'psiformer']) 213 | def test_benzene(self, network): 214 | """Test experiments from Fig. 6 of Pfau et al. (2024).""" 215 | cfg = benzene.get_config() 216 | match network: 217 | case 'ferminet': 218 | cfg.update_from_flattened_dict(presets.ferminet) 219 | case'psiformer': 220 | cfg.update_from_flattened_dict(presets.psiformer) 221 | case _: 222 | raise ValueError(f'Unknown network type: {network}') 223 | cfg = minimize_config(cfg) 224 | cfg.log.save_path = self.create_tempdir().full_path 225 | cfg = base_config.resolve(cfg) 226 | # Calculation is too small to test the results for accuracy. Test just to 227 | # ensure they actually run without a top-level error. 228 | train.train(cfg) 229 | 230 | @parameterized.parameters(itertools.product( 231 | ['HNO', 'HCF'], 232 | ['ferminet', 'psiformer'], 233 | ['vmc', 'vmc_overlap'], 234 | ['ordered', 'random'])) 235 | def test_pretrain_and_penalty(self, system, network, objective, excitation): 236 | """Test experiments from Fig. S3 of Pfau et al. (2024).""" 237 | cfg = oscillator.get_config() 238 | cfg.system.molecule_name = system 239 | cfg.pretrain.excitation_type = excitation 240 | cfg.optim.objective = objective 241 | match network: 242 | case 'ferminet': 243 | cfg.update_from_flattened_dict(presets.ferminet) 244 | case 'psiformer': 245 | cfg.update_from_flattened_dict(presets.psiformer) 246 | case _: 247 | raise ValueError(f'Unknown network type: {network}') 248 | cfg = minimize_config(cfg) 249 | cfg.log.save_path = self.create_tempdir().full_path 250 | cfg = base_config.resolve(cfg) 251 | # Calculation is too small to test the results for accuracy. Test just to 252 | # ensure they actually run without a top-level error. 
253 | train.train(cfg) 254 | 255 | 256 | if __name__ == '__main__': 257 | absltest.main() 258 | -------------------------------------------------------------------------------- /ferminet/curvature_tags_and_blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Curvature blocks for FermiNet.""" 16 | import dataclasses 17 | import functools 18 | from typing import Sequence, Set, Tuple 19 | import jax 20 | import jax.numpy as jnp 21 | import kfac_jax 22 | import numpy as np 23 | 24 | 25 | Array = kfac_jax.utils.Array 26 | Scalar = kfac_jax.utils.Scalar 27 | Numeric = kfac_jax.utils.Numeric 28 | 29 | vmap_psd_inv = jax.vmap(kfac_jax.utils.psd_inv, (0, None), 0) 30 | vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0) 31 | 32 | 33 | def register_repeated_dense(y, x, w, b, **kwargs): 34 | return kfac_jax.register_dense(y, x, w, b, variant="repeated_dense", **kwargs) 35 | 36 | 37 | def register_qmc(y, x, w, **kwargs): 38 | return kfac_jax.register_dense(y, x, w, variant="qmc", **kwargs) 39 | 40 | 41 | _dense = functools.partial( 42 | kfac_jax.tag_graph_matcher._dense, # pylint: disable=protected-access 43 | axes=1, 44 | with_reshape=False, 45 | ) 46 | 47 | 48 | _repeated_dense_parameter_extractor = functools.partial( 49 | kfac_jax.tag_graph_matcher._dense_parameter_extractor, # 
pylint: disable=protected-access 50 | variant="repeated_dense", 51 | ) 52 | 53 | 54 | class RepeatedDenseBlock(kfac_jax.DenseTwoKroneckerFactored): 55 | """Dense block that is repeatedly applied to multiple inputs (e.g. vmap).""" 56 | 57 | def fixed_scale(self) -> Numeric: 58 | (x_shape,) = self.inputs_shapes 59 | return float(kfac_jax.utils.product(x_shape) // (x_shape[0] * x_shape[-1])) 60 | 61 | def update_curvature_matrix_estimate( 62 | self, 63 | state: kfac_jax.KroneckerFactored.State, 64 | estimation_data: kfac_jax.LayerVjpData[Array], 65 | ema_old: Numeric, 66 | ema_new: Numeric, 67 | identity_weight: Numeric, 68 | batch_size: int, 69 | ) -> kfac_jax.KroneckerFactored.State: 70 | [x] = estimation_data.primals.inputs 71 | [dy] = estimation_data.tangents.outputs 72 | assert x.shape[0] == batch_size 73 | 74 | estimation_data = dataclasses.replace( 75 | estimation_data, 76 | primals=dataclasses.replace( 77 | estimation_data.primals, 78 | inputs=(x.reshape([-1, x.shape[-1]]),), 79 | ), 80 | tangents=dataclasses.replace( 81 | estimation_data.tangents, 82 | outputs=(dy.reshape([-1, dy.shape[-1]]),), 83 | ), 84 | ) 85 | 86 | batch_size = x.size // x.shape[-1] 87 | return super().update_curvature_matrix_estimate( 88 | state=state, 89 | estimation_data=estimation_data, 90 | ema_old=ema_old, 91 | ema_new=ema_new, 92 | identity_weight=identity_weight, 93 | batch_size=batch_size, 94 | ) 95 | 96 | 97 | class QmcBlockedDense(kfac_jax.KroneckerFactored): 98 | """A factor that is the Kronecker product of two matrices.""" 99 | 100 | def input_size(self) -> int: 101 | raise NotImplementedError() 102 | 103 | def output_size(self) -> int: 104 | raise NotImplementedError() 105 | 106 | def fixed_scale(self) -> Numeric: 107 | return float(self.parameters_shapes[0][1]) 108 | 109 | def update_curvature_matrix_estimate( 110 | self, 111 | state: kfac_jax.KroneckerFactored.State, 112 | estimation_data: kfac_jax.LayerVjpData[Array], 113 | ema_old: Numeric, 114 | ema_new: Numeric, 115 | 
identity_weight: Numeric, 116 | batch_size: int, 117 | ) -> kfac_jax.KroneckerFactored.State: 118 | del identity_weight 119 | 120 | [x] = estimation_data.primals.inputs 121 | [dy] = estimation_data.tangents.outputs 122 | assert batch_size == x.shape[0] 123 | normalizer = x.shape[0] * x.shape[1] 124 | # The forward computation is 125 | # einsum(x,w): bijk,bkmjn -> bijmn 126 | inputs_cov = jnp.einsum("bijk,bijl->jkl", x, x) / normalizer 127 | dy = jnp.reshape(dy, dy.shape[:-2] + (-1,)) 128 | outputs_cov = jnp.einsum("bijk,bijl->jkl", dy, dy) / normalizer 129 | 130 | state.inputs_factor.update(inputs_cov, ema_old, ema_new) 131 | state.outputs_factor.update(outputs_cov, ema_old, ema_new) 132 | 133 | return state 134 | 135 | def _init( 136 | self, 137 | rng: kfac_jax.utils.PRNGKey, 138 | exact_powers_to_cache: Set[Scalar], 139 | approx_powers_to_cache: Set[Scalar], 140 | cache_eigenvalues: bool, 141 | ) -> kfac_jax.KroneckerFactored.State: 142 | del rng, cache_eigenvalues 143 | k, m, j, n = self.parameters_shapes[0] 144 | cache = dict() 145 | if exact_powers_to_cache: 146 | raise NotImplementedError( 147 | "Caching of exact powers is not yet implemented for QmcBlockedDense.") 148 | for power in approx_powers_to_cache: 149 | if power != -1: 150 | raise NotImplementedError(f"Approximations for power {power} is not " 151 | f"yet implemented.") 152 | cache[str(power)] = dict( 153 | inputs_factor=jnp.zeros([j, k, k]), 154 | outputs_factor=jnp.zeros([j, m * n, m * n]), 155 | ) 156 | return kfac_jax.KroneckerFactored.State( 157 | cache=cache, 158 | inputs_factor= 159 | kfac_jax.utils.WeightedMovingAverage.zeros_array((j, k, k)), 160 | outputs_factor=kfac_jax.utils.WeightedMovingAverage.zeros_array( 161 | (j, m * n, m * n)), 162 | ) 163 | 164 | def _update_cache( 165 | self, 166 | state: kfac_jax.KroneckerFactored.State, 167 | identity_weight: kfac_jax.utils.Numeric, 168 | exact_powers: set[kfac_jax.utils.Scalar], 169 | approx_powers: set[kfac_jax.utils.Scalar], 170 | 
eigenvalues: bool, 171 | ) -> kfac_jax.KroneckerFactored.State: 172 | del eigenvalues 173 | 174 | if exact_powers: 175 | raise NotImplementedError( 176 | "Caching of exact powers is not yet implemented for QmcBlockedDense.") 177 | for power in approx_powers: 178 | if power != -1: 179 | raise NotImplementedError(f"Approximations for power {power} is not " 180 | f"yet implemented.") 181 | cache = state.cache[str(power)] 182 | pi_adjusted_inverse = jax.vmap( 183 | kfac_jax.utils.pi_adjusted_kronecker_inverse, 184 | (0, None), (0, 0) 185 | ) 186 | cache["inputs_factor"], cache["outputs_factor"] = pi_adjusted_inverse( 187 | state.inputs_factor.value, 188 | state.outputs_factor.value, 189 | damping=identity_weight, 190 | ) 191 | return state 192 | 193 | def multiply_matpower( 194 | self, 195 | state: kfac_jax.KroneckerFactored.State, 196 | vector: Sequence[Array], 197 | identity_weight: Numeric, 198 | power: Scalar, 199 | exact_power: bool, 200 | use_cached: bool, 201 | ) -> Tuple[Array, ...]: 202 | w, = vector 203 | # kmjn 204 | v = w 205 | k, m, j, n = v.shape 206 | if power == 1: 207 | # jk(mn) 208 | v = jnp.transpose(v, [2, 0, 1, 3]).reshape([j, k, m * n]) 209 | v = vmap_matmul(state.inputs_factor.value, v) 210 | v = vmap_matmul(v, state.outputs_factor.value) 211 | # kmjn 212 | v = jnp.transpose(v.reshape([j, k, m, n]), [1, 2, 0, 3]) 213 | v = v + identity_weight * w 214 | elif exact_power: 215 | raise NotImplementedError( 216 | "Exact powers is not yet implemented for QmcBlockedDense.") 217 | else: 218 | if not use_cached: 219 | raise NotImplementedError( 220 | "Caching of exact powers is not yet implemented for " 221 | "QmcBlockedDense.") 222 | else: 223 | # jk(mn) 224 | v = jnp.transpose(v, [2, 0, 1, 3]).reshape([j, k, m * n]) 225 | v = vmap_matmul(state.cache[str(power)]["inputs_factor"], v) 226 | v = vmap_matmul(v, state.cache[str(power)]["outputs_factor"]) 227 | # kmjn 228 | v = jnp.transpose(v.reshape([j, k, m, n]), [1, 2, 0, 3]) 229 | # kmjn 230 | return (v,) 
231 | 232 | 233 | # repeating a dense layer once 234 | _repeated_dense1 = jax.vmap(_dense, in_axes=[0, [None, None]]) 235 | _repeated_dense2 = jax.vmap(_repeated_dense1, in_axes=[0, [None, None]]) 236 | _repeated_dense1_no_b = jax.vmap(_dense, in_axes=[0, [None]]) 237 | _repeated_dense2_no_b = jax.vmap(_repeated_dense1_no_b, in_axes=[0, [None]]) 238 | 239 | # Computation for repeated dense layer 240 | repeated_dense1_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern( 241 | name="repeated_dense1_with_bias", 242 | tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag, 243 | compute_func=_repeated_dense1, 244 | parameters_extractor_func=_repeated_dense_parameter_extractor, 245 | example_args=[np.zeros([9, 11, 13]), [np.zeros([13, 7]), np.zeros([7])]], 246 | ) 247 | 248 | repeated_dense1_no_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern( 249 | name="repeated_dense1_no_bias", 250 | tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag, 251 | compute_func=_repeated_dense1_no_b, 252 | parameters_extractor_func=_repeated_dense_parameter_extractor, 253 | example_args=[np.zeros([9, 11, 13]), [np.zeros([13, 7])]], 254 | ) 255 | 256 | repeated_dense2_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern( 257 | name="repeated_dense2_with_bias", 258 | tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag, 259 | compute_func=_repeated_dense2, 260 | parameters_extractor_func=_repeated_dense_parameter_extractor, 261 | example_args=[np.zeros([8, 9, 11, 13]), [np.zeros([13, 7]), np.zeros([7])]], 262 | ) 263 | 264 | repeated_dense2_no_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern( 265 | name="repeated_dense2_no_bias", 266 | tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag, 267 | compute_func=_repeated_dense2_no_b, 268 | parameters_extractor_func=_repeated_dense_parameter_extractor, 269 | example_args=[np.zeros([8, 9, 11, 13]), [np.zeros([13, 7])]], 270 | ) 271 | 272 | GRAPH_PATTERNS = ( 273 | repeated_dense1_with_bias_pattern, 274 | 
repeated_dense2_with_bias_pattern, 275 | repeated_dense1_no_bias_pattern, 276 | repeated_dense2_no_bias_pattern, 277 | ) + kfac_jax.tag_graph_matcher.DEFAULT_GRAPH_PATTERNS 278 | 279 | 280 | kfac_jax.set_default_tag_to_block_ctor( 281 | "repeated_dense", RepeatedDenseBlock 282 | ) 283 | kfac_jax.set_default_tag_to_block_ctor("qmc", QmcBlockedDense) 284 | -------------------------------------------------------------------------------- /ferminet/tests/networks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ferminet.networks.""" 16 | 17 | import itertools 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import chex 22 | from ferminet import envelopes 23 | from ferminet import networks 24 | import jax 25 | from jax import random 26 | import jax.numpy as jnp 27 | import numpy as np 28 | 29 | 30 | def rand_default(): 31 | randn = np.random.RandomState(0).randn 32 | def generator(shape, dtype): 33 | return randn(*shape).astype(dtype) 34 | return generator 35 | 36 | 37 | def _antisymmtry_options(): 38 | for envelope in envelopes.EnvelopeLabel: 39 | yield { 40 | 'testcase_name': f'_envelope={envelope}', 41 | 'envelope_label': envelope, 42 | 'dtype': np.float32, 43 | } 44 | 45 | 46 | def _network_options(): 47 | """Yields the set of all combinations of options to pass into test_fermi_net. 48 | 49 | Example output: 50 | { 51 | 'vmap': True, 52 | 'envelope': envelopes.EnvelopeLabel.ISOTROPIC, 53 | 'bias_orbitals': False, 54 | 'full_det': True, 55 | 'use_last_layer': False, 56 | 'hidden_dims': ((32, 8), (32, 8)), 57 | } 58 | """ 59 | # Key for each option and corresponding values to test. 60 | all_options = { 61 | 'vmap': [True, False], 62 | 'envelope_label': list(envelopes.EnvelopeLabel), 63 | 'bias_orbitals': [True, False], 64 | 'full_det': [True, False], 65 | 'use_last_layer': [True, False], 66 | 'hidden_dims': [((32, 8), (32, 8))], 67 | } 68 | # Create the product of all options. 69 | for options in itertools.product(*all_options.values()): 70 | # Yield dict of the current combination of options. 
71 | yield dict(zip(all_options.keys(), options)) 72 | 73 | 74 | class NetworksTest(parameterized.TestCase): 75 | 76 | @parameterized.named_parameters(_antisymmtry_options()) 77 | def test_antisymmetry(self, envelope_label, dtype): 78 | """Check that the Fermi Net is symmetric.""" 79 | del dtype # unused 80 | 81 | key = random.PRNGKey(42) 82 | 83 | key, *subkeys = random.split(key, num=3) 84 | natoms = 4 85 | atoms = random.normal(subkeys[0], shape=(natoms, 3)) 86 | charges = random.normal(subkeys[1], shape=(natoms,)) 87 | nspins = (3, 4) 88 | 89 | key, subkey = random.split(key) 90 | pos1 = random.normal(subkey, shape=(sum(nspins) * 3,)) 91 | pos2 = jnp.concatenate((pos1[3:6], pos1[:3], pos1[6:])) 92 | pos3 = jnp.concatenate((pos1[:9], pos1[12:15], pos1[9:12], pos1[15:])) 93 | 94 | key, subkey = random.split(key) 95 | spins1 = jax.random.uniform(subkey, shape=(sum(nspins),)) 96 | spins2 = jnp.concatenate((spins1[1:2], spins1[:1], spins1[2:])) 97 | spins3 = jnp.concatenate((spins1[:3], spins1[4:5], spins1[3:4], spins1[5:])) 98 | 99 | feature_layer = networks.make_ferminet_features(natoms, nspins, ndim=3) 100 | 101 | kwargs = {} 102 | network = networks.make_fermi_net( 103 | nspins=nspins, 104 | charges=charges, 105 | hidden_dims=((16, 16), (16, 16)), 106 | envelope=envelopes.get_envelope(envelope_label, **kwargs), 107 | feature_layer=feature_layer, 108 | ) 109 | 110 | key, subkey = random.split(key) 111 | params = network.init(subkey) 112 | 113 | # Randomize parameters of envelope 114 | if isinstance(params['envelope'], list): 115 | for i in range(len(params['envelope'])): 116 | if params['envelope'][i]: 117 | key, *subkeys = random.split(key, num=3) 118 | params['envelope'][i]['sigma'] = random.normal( 119 | subkeys[0], params['envelope'][i]['sigma'].shape) 120 | params['envelope'][i]['pi'] = random.normal( 121 | subkeys[1], params['envelope'][i]['pi'].shape) 122 | else: 123 | assert isinstance(params['envelope'], dict) 124 | key, *subkeys = random.split(key, 
num=3) 125 | params['envelope']['sigma'] = random.normal( 126 | subkeys[0], params['envelope']['sigma'].shape) 127 | params['envelope']['pi'] = random.normal( 128 | subkeys[1], params['envelope']['pi'].shape 129 | ) 130 | 131 | out1 = network.apply(params, pos1, spins1, atoms, charges) 132 | 133 | out2 = network.apply(params, pos2, spins2, atoms, charges) 134 | np.testing.assert_allclose(out1[1], out2[1], atol=1E-5, rtol=1E-5) 135 | np.testing.assert_allclose(out1[0], -1*out2[0], atol=1E-5, rtol=1E-5) 136 | 137 | out3 = network.apply(params, pos3, spins3, atoms, charges) 138 | np.testing.assert_allclose(out1[1], out3[1], atol=1E-5, rtol=1E-5) 139 | np.testing.assert_allclose(out1[0], -1*out3[0], atol=1E-5, rtol=1E-5) 140 | 141 | def test_create_input_features(self): 142 | dtype = np.float32 143 | ndim = 3 144 | nelec = 6 145 | xs = np.random.normal(scale=3, size=(nelec, ndim)).astype(dtype) 146 | atoms = jnp.array([[0.2, 0.5, 0.3], [1.2, 0.3, 0.7]]) 147 | input_features = networks.construct_input_features(xs, atoms) 148 | d_input_features = jax.jacfwd(networks.construct_input_features)( 149 | xs, atoms, ndim=3) 150 | r_ee = input_features[-1][:, :, 0] 151 | d_r_ee = d_input_features[-1][:, :, 0] 152 | # The gradient of |r_i - r_j| wrt r_k should only be non-zero for k = i or 153 | # k = j and the i = j term should be explicitly masked out. 154 | mask = np.fromfunction( 155 | lambda i, j, k: np.logical_and(np.logical_or(i == k, j == k), i != j), 156 | d_r_ee.shape[:-1], 157 | ) 158 | d_r_ee_non_zeros = d_r_ee[mask] 159 | d_r_ee_zeros = d_r_ee[~mask] 160 | with self.subTest('check forward pass'): 161 | chex.assert_tree_all_finite(input_features) 162 | # |r_i - r_j| should be zero. 163 | np.testing.assert_allclose(np.diag(r_ee), np.zeros(6), atol=1E-5) 164 | with self.subTest('check backwards pass'): 165 | # Most importantly, check the gradient of the electron-electron distances, 166 | # |x_i - x_j|, is masked out for i==j. 
167 | chex.assert_tree_all_finite(d_input_features) 168 | # We mask out the |r_i-r_j| terms for i == j. Check these are zero. 169 | np.testing.assert_allclose( 170 | d_r_ee_zeros, np.zeros_like(d_r_ee_zeros), atol=1E-5, rtol=1E-5) 171 | self.assertTrue(np.all(np.abs(d_r_ee_non_zeros) > 0.0)) 172 | 173 | @parameterized.parameters(None, 4) 174 | def test_construct_symmetric_features(self, naux_features): 175 | dtype = np.float32 176 | hidden_units_one = 8 # 128 177 | hidden_units_two = 4 # 32 178 | nspins = (6, 5) 179 | h_one = np.random.uniform( 180 | low=-5, high=5, size=(sum(nspins), hidden_units_one)).astype(dtype) 181 | h_two = np.random.uniform( 182 | low=-5, 183 | high=5, 184 | size=(sum(nspins), sum(nspins), hidden_units_two)).astype(dtype) 185 | if naux_features: 186 | h_aux = np.random.uniform(size=(sum(nspins), naux_features)).astype(dtype) 187 | else: 188 | h_aux = None 189 | h_two = h_two + np.transpose(h_two, axes=(1, 0, 2)) 190 | features = networks.construct_symmetric_features( 191 | h_one, h_two, nspins, h_aux=h_aux 192 | ) 193 | # Swap electrons 194 | swaps = np.arange(sum(nspins)) 195 | np.random.shuffle(swaps[:nspins[0]]) 196 | np.random.shuffle(swaps[nspins[0]:]) 197 | inverse_swaps = [0] * len(swaps) 198 | for i, j in enumerate(swaps): 199 | inverse_swaps[j] = i 200 | inverse_swaps = np.asarray(inverse_swaps) 201 | h_aux_swap = h_aux if h_aux is None else h_aux[swaps] 202 | features_swap = networks.construct_symmetric_features( 203 | h_one[swaps], h_two[swaps][:, swaps], nspins, h_aux=h_aux_swap 204 | ) 205 | np.testing.assert_allclose( 206 | features, features_swap[inverse_swaps], atol=1E-5, rtol=1E-5) 207 | 208 | @parameterized.parameters(_network_options()) 209 | def test_fermi_net(self, vmap, **network_options): 210 | # Warning: this only tests we can build and run the network. It does not 211 | # test correctness of output nor test changing network width or depth. 
212 | nspins = (6, 5) 213 | natoms = 3 214 | atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) 215 | charges = jnp.asarray([2, 5, 7]) 216 | key = jax.random.PRNGKey(42) 217 | feature_layer = networks.make_ferminet_features(natoms, nspins, ndim=3) 218 | kwargs = {} 219 | network_options['envelope'] = envelopes.get_envelope( 220 | network_options['envelope_label'], **kwargs) 221 | del network_options['envelope_label'] 222 | 223 | envelope = network_options['envelope'] 224 | if ( 225 | envelope.apply_type == envelopes.EnvelopeType.PRE_ORBITAL 226 | and network_options['bias_orbitals'] 227 | ): 228 | with self.assertRaises(ValueError): 229 | networks.make_fermi_net( 230 | nspins, charges, feature_layer=feature_layer, **network_options 231 | ) 232 | else: 233 | network = networks.make_fermi_net( 234 | nspins, charges, feature_layer=feature_layer, **network_options 235 | ) 236 | 237 | key, *subkeys = jax.random.split(key, num=3) 238 | if vmap: 239 | batch = 10 240 | xs = jax.random.uniform(subkeys[0], shape=(batch, sum(nspins), 3)) 241 | spins = jax.random.uniform(subkeys[1], shape=(batch, sum(nspins))) 242 | fermi_net = jax.vmap(network.apply, in_axes=(None, 0, 0, None, None)) 243 | expected_shape = (batch,) 244 | else: 245 | xs = jax.random.uniform(subkeys[0], shape=(sum(nspins), 3)) 246 | spins = jax.random.uniform(subkeys[1], shape=(sum(nspins),)) 247 | fermi_net = network.apply 248 | expected_shape = () 249 | 250 | key, subkey = jax.random.split(key) 251 | params = network.init(subkey) 252 | sign_psi, log_psi = fermi_net(params, xs, spins, atoms, charges) 253 | self.assertSequenceEqual(sign_psi.shape, expected_shape) 254 | self.assertSequenceEqual(log_psi.shape, expected_shape) 255 | 256 | @parameterized.parameters( 257 | *(itertools.product([(1, 0), (2, 0), (0, 1)], [True, False]))) 258 | def test_spin_polarised_fermi_net(self, nspins, full_det): 259 | natoms = 1 260 | atoms = jnp.zeros(shape=(1, 3)) 261 | charges = jnp.ones(shape=1) 262 | key 
= jax.random.PRNGKey(42) 263 | feature_layer = networks.make_ferminet_features(natoms, nspins, ndim=3) 264 | network = networks.make_fermi_net( 265 | nspins, charges, feature_layer=feature_layer, full_det=full_det 266 | ) 267 | key, *subkeys = jax.random.split(key, num=4) 268 | params = network.init(subkeys[0]) 269 | xs = jax.random.uniform(subkeys[1], shape=(sum(nspins) * 3,)) 270 | spins = jax.random.uniform(subkeys[2], shape=(sum(nspins),)) 271 | # Test fermi_net runs without raising exceptions for spin-polarised systems. 272 | sign_out, log_out = network.apply(params, xs, spins, atoms, charges) 273 | self.assertEqual(sign_out.size, 1) 274 | self.assertEqual(log_out.size, 1) 275 | 276 | 277 | if __name__ == '__main__': 278 | absltest.main() 279 | --------------------------------------------------------------------------------