├── .gitignore ├── DeepSolid ├── __init__.py ├── config │ ├── poscar │ │ └── bcc_li.vasp │ ├── read_poscar.py │ ├── diamond.py │ ├── rock_salt.py │ ├── hydrogen_chain.py │ ├── two_hydrogen_cell.py │ └── graphene.py ├── utils │ ├── kfac_ferminet_alpha │ │ ├── __init__.py │ │ ├── distributions.py │ │ ├── vjp_rc.py │ │ ├── example.py │ │ ├── layers_and_loss_tags.py │ │ └── tracer.py │ ├── units.py │ ├── poscar_to_cell.py │ ├── system.py │ ├── writers.py │ └── elements.py ├── constants.py ├── distributed.py ├── estimator.py ├── init_guess.py ├── curvature_tags_and_blocks.py ├── supercell.py ├── checkpoint.py ├── base_config.py ├── distance.py ├── ewaldsum.py ├── hf.py ├── hamiltonian.py ├── train.py ├── pretrain.py └── qmc.py ├── test ├── test_cell.py └── test_network.py ├── bin └── deepsolid ├── setup.py ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | *.csv 4 | *.npz 5 | .idea 6 | -------------------------------------------------------------------------------- /DeepSolid/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
# --------------------------------------------------------------------------------
# /DeepSolid/config/poscar/bcc_li.vasp:
# --------------------------------------------------------------------------------
# Li2
# 1.0
# 3.4268178940 0.0000000000 0.0000000000
# 0.0000000000 3.4268178940 0.0000000000
# 0.0000000000 0.0000000000 3.4268178940
# Li
# 2
# Cartesian
# 0.000000000 0.000000000 0.000000000
# 1.713408947 1.713408947 1.713408947
#
# --------------------------------------------------------------------------------
# /test/test_cell.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Smoke test: build a rock-salt LiH primitive cell with pyscf.

The module-level ``cell`` is kept for interactive use. ``test_cell_builds``
gives pytest something to collect; a module without test functions is
reported as "no tests ran".
"""

from pyscf.pbc import gto
import numpy as np

# Define your test system
nk = 1
cell = gto.Cell()
# Cubic lattice constant: 2 Angstrom expressed in Bohr.
L = 2 / 0.529177
cell.atom = f"""
Li 0 0 0
H {L/2} {L/2} {L/2}
"""
cell.basis = "sto-3g"
# FCC primitive vectors (0, L/2, L/2), (L/2, 0, L/2), (L/2, L/2, 0).
cell.a = (1 - np.eye(3)) * L / 2
cell.unit = "B"
cell.verbose = 5
cell.spin = 0
cell.exp_to_discard = 0.1
cell.build()


def test_cell_builds():
    """The built cell holds one Li and one H: 2 atoms, 3 + 1 = 4 electrons."""
    assert cell.natm == 2
    assert cell.nelectron == 4
    assert cell.spin == 0

# --------------------------------------------------------------------------------
# /DeepSolid/config/read_poscar.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from DeepSolid import base_config
from DeepSolid import supercell
from DeepSolid.utils import poscar_to_cell
import numpy as np


def get_config(input_str):
    """Create a DeepSolid config for a crystal described by a POSCAR file.

    Args:
        input_str: ``"<poscar_path>,<S>,<basis>"``. ``S`` is an integer; the
            simulation cell is the primitive cell replicated ``S`` times along
            each lattice vector. ``basis`` is assigned to ``cell.basis``.

    Returns:
        ``base_config.default()`` with ``system.pyscf_cell`` set to the
        simulation cell. When the primitive cell has non-zero spin, the
        simulation cell's ``hf_type`` is switched to ``'uhf'``.
    """
    path, repeat, basis = input_str.split(',')
    primitive = poscar_to_cell.read_poscar(path)
    scaling = np.diag([int(repeat)] * 3)

    # pyscf settings for the primitive cell, then build it.
    primitive.verbose = 5
    primitive.basis = basis
    primitive.exp_to_discard = 0.1
    primitive.build()

    cfg = base_config.default()

    # Set up the simulation (super)cell.
    sim_cell = supercell.get_supercell(primitive, scaling)
    spin_polarized = primitive.spin != 0
    if spin_polarized:
        sim_cell.hf_type = 'uhf'
    cfg.system.pyscf_cell = sim_cell

    return cfg
# --------------------------------------------------------------------------------
# /bin/deepsolid:
# --------------------------------------------------------------------------------
#! /usr/bin/env python3
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# modified from FermiNet:https://github.com/deepmind/ferminet

import sys

from absl import app
from absl import flags
from absl import logging
from jax.config import config as jax_config
from DeepSolid import process
from ml_collections.config_flags import config_flags

# Send absl log output to stdout at INFO level.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# internal imports

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file('config', None, 'Path to config file.')


def main(_):
    """absl entry point.

    Reads the config passed via ``--config``. If its ``use_x64`` flag is set,
    64-bit floats are enabled in JAX before the config is handed to
    ``process.process``.
    """
    config = FLAGS.config
    wants_double_precision = config.use_x64
    if wants_double_precision:
        jax_config.update("jax_enable_x64", True)
    process.process(config)


if __name__ == '__main__':
    app.run(main)

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import setup, find_packages

REQUIRED_PACKAGES = (
    "absl-py",
    'attrs',
    # `dataclasses` (stdlib since Python 3.7) and `typing` (stdlib since 3.5)
    # are PyPI backports for older interpreters. Installing them on newer
    # interpreters is at best redundant and can conflict with the stdlib
    # module, so they are restricted with PEP 508 environment markers.
    'dataclasses; python_version < "3.7"',
    "networkx",
    "scipy==1.9.3",
    "numpy",
    "ordered-set",
    'typing; python_version < "3.5"',
    "chex==0.1.5",
    "jax==0.2.26",
    "jaxlib==0.1.75",
    "pandas",
    "ml_collections",
    "pyscf",
    "tables",
    "h5py==3.2.1",
    "optax==0.0.9",
    "opt_einsum==3.3.0",
)

setup(
    name="DeepSolid",
    version="1.0",
    description="A library combining solid quantum Monte Carlo and neural network.",
    author='ByteDance',
    author_email='lixiang.62770689@bytedance.com',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    scripts=['bin/deepsolid'],
    license='Apache 2.0',
)

# --------------------------------------------------------------------------------
# /DeepSolid/config/diamond.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from pyscf.pbc import gto

from DeepSolid import base_config
from DeepSolid import supercell
from DeepSolid.utils import units


def get_config(input_str):
    """Config for a two-atom diamond / zinc-blende primitive cell.

    Args:
        input_str: ``"X,Y,L_Ang,S,basis"``:
            - ``X``, ``Y``: element symbols of the two atoms.
            - ``L_Ang``: cubic lattice constant in Angstrom.
            - ``S``: integer supercell repetition along each lattice vector.
            - ``basis``: the basis name.

    Returns:
        ``base_config.default()`` with ``system.pyscf_cell`` set to the
        ``S x S x S`` simulation cell.
    """
    species_a, species_b, lattice_ang, repeat, basis = input_str.split(',')
    scaling = int(repeat) * np.eye(3)
    cfg = base_config.default()
    lattice_bohr = units.angstrom2bohr(float(lattice_ang))

    # FCC primitive cell; the second atom sits a quarter of the way along the
    # cubic body diagonal.
    quarter = 0.25 * lattice_bohr
    cell = gto.Cell()
    cell.atom = [
        [species_a, [0.0, 0.0, 0.0]],
        [species_b, [quarter] * 3],
    ]

    cell.basis = basis
    cell.a = (np.ones((3, 3)) - np.eye(3)) * lattice_bohr / 2
    cell.unit = "B"
    cell.verbose = 5
    cell.exp_to_discard = 0.1
    cell.build()

    cfg.system.pyscf_cell = supercell.get_supercell(cell, scaling)
    return cfg
# --------------------------------------------------------------------------------
# /DeepSolid/config/rock_salt.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from pyscf.pbc import gto

from DeepSolid import base_config
from DeepSolid import supercell
from DeepSolid.utils import units


def get_config(input_str):
    """Config for a two-atom rock-salt primitive cell.

    Args:
        input_str: ``"X,Y,L_Ang,S,basis"``:
            - ``X``, ``Y``: element symbols of the two atoms.
            - ``L_Ang``: cubic lattice constant in Angstrom.
            - ``S``: integer supercell repetition along each lattice vector.
            - ``basis``: the basis name.

    Returns:
        ``base_config.default()`` with ``system.pyscf_cell`` set to the
        ``S x S x S`` simulation cell.
    """
    species_a, species_b, lattice_ang, repeat, basis = input_str.split(',')
    scaling = int(repeat) * np.eye(3)
    cfg = base_config.default()
    lattice_bohr = units.angstrom2bohr(float(lattice_ang))

    # FCC primitive cell; the second atom sits at the centre of the cube.
    half = 0.5 * lattice_bohr
    cell = gto.Cell()
    cell.atom = [
        [species_a, [0.0, 0.0, 0.0]],
        [species_b, [half] * 3],
    ]

    cell.basis = basis
    cell.a = (np.ones((3, 3)) - np.eye(3)) * lattice_bohr / 2
    cell.unit = "B"
    cell.verbose = 5
    cell.exp_to_discard = 0.1
    cell.build()

    cfg.system.pyscf_cell = supercell.get_supercell(cell, scaling)
    return cfg
# --------------------------------------------------------------------------------
# /DeepSolid/utils/kfac_ferminet_alpha/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.
"""Module for anything that an end user would use."""

from DeepSolid.utils.kfac_ferminet_alpha.loss_functions import register_normal_predictive_distribution
from DeepSolid.utils.kfac_ferminet_alpha.loss_functions import register_squared_error_loss
from DeepSolid.utils.kfac_ferminet_alpha.optimizer import Optimizer

# --------------------------------------------------------------------------------
# /DeepSolid/config/hydrogen_chain.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import numpy as np
from pyscf.pbc import gto

from DeepSolid import base_config
from DeepSolid import supercell


def get_config(input_str):
    """Config for a periodic chain with one atom per primitive cell along x.

    Args:
        input_str: ``"symbol,Sx,Sy,Sz,L,spin,basis"``:
            - ``symbol``: element symbol.
            - ``Sx``, ``Sy``, ``Sz``: integer supercell repetitions.
            - ``L``: atom spacing in Bohr.
            - ``spin``: spin of the primitive cell.
            - ``basis``: the basis name.

    Returns:
        ``base_config.default()`` with ``system.pyscf_cell`` set to the
        simulation cell.
    """
    symbol, nx, ny, nz, spacing, spin, basis = input_str.split(',')
    scaling = np.diag([int(n) for n in (nx, ny, nz)])
    spacing = float(spacing)
    spin = int(spin)
    cfg = base_config.default()

    # One atom per cell along x. The 100 Bohr y/z lattice lengths keep
    # periodic images of the chain far apart.
    cell = gto.Cell()
    cell.atom = f"""
    {symbol} {spacing/2} {0} {0}
    """
    cell.basis = basis
    cell.a = np.array([[spacing, 0, 0],
                       [0, 100, 0],
                       [0, 0, 100]])
    cell.unit = "B"
    cell.spin = spin
    cell.verbose = 5
    cell.exp_to_discard = 0.1
    cell.build()

    cfg.system.pyscf_cell = supercell.get_supercell(cell, scaling)
    return cfg
# --------------------------------------------------------------------------------
# /DeepSolid/config/two_hydrogen_cell.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import numpy as np
from pyscf.pbc import gto

from DeepSolid import base_config
from DeepSolid import supercell


def get_config(input_str):
    """Config for a chain with two atoms per primitive cell along x.

    Args:
        input_str: ``"symbol,Sx,Sy,Sz,L,spin,basis"``:
            - ``symbol``: element symbol.
            - ``Sx``, ``Sy``, ``Sz``: integer supercell repetitions.
            - ``L``: atom spacing in Bohr; the primitive cell is ``2 * L`` long
              along x.
            - ``spin``: spin of the primitive cell.
            - ``basis``: the basis name.

    Returns:
        ``base_config.default()`` with ``system.pyscf_cell`` set to the
        simulation cell. The simulation cell always uses a UHF reference.
    """
    symbol, nx, ny, nz, spacing, spin, basis = input_str.split(',')
    scaling = np.diag([int(n) for n in (nx, ny, nz)])
    spacing = float(spacing)
    spin = int(spin)
    cfg = base_config.default()

    # Atoms at x = L and x = 0; the 100 Bohr y/z lengths separate the chain
    # from its periodic images.
    cell = gto.Cell()
    cell.atom = f"""
    {symbol} {spacing} {0} {0}
    {symbol} 0 0 0
    """
    cell.basis = basis
    cell.a = np.array([[2 * spacing, 0, 0],
                       [0, 100, 0],
                       [0, 0, 100]])
    cell.unit = "B"
    cell.spin = spin
    cell.verbose = 5
    cell.exp_to_discard = 0.1
    cell.build()

    sim_cell = supercell.get_supercell(cell, scaling)
    sim_cell.hf_type = 'uhf'
    cfg.system.pyscf_cell = sim_cell
    return cfg
# --------------------------------------------------------------------------------
# /DeepSolid/config/graphene.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from pyscf.pbc import gto

from DeepSolid import base_config
from DeepSolid import supercell
from DeepSolid.utils import units


def get_config(input_str):
    """Config for a two-atom honeycomb (graphene-like) layer.

    Args:
        input_str: ``"X,Y,L_Ang,S,z,basis"``:
            - ``X``, ``Y``: element symbols of the two sublattice atoms.
            - ``L_Ang``: in-plane lattice constant in Angstrom.
            - ``S``: in-plane supercell repetition (no repetition along z).
            - ``z``: out-of-plane lattice length. Note that this is used
              directly in Bohr and is NOT converted from Angstrom.
            - ``basis``: the basis name.

    Returns:
        ``base_config.default()`` with ``system.pyscf_cell`` set to the
        ``S x S x 1`` simulation cell.
    """
    species_a, species_b, lattice_ang, repeat, z_len, basis = input_str.split(',')
    n = int(repeat)
    scaling = np.diag([n, n, 1])
    cfg = base_config.default()
    lattice_ang = float(lattice_ang)
    z_len = float(z_len)
    lattice_bohr = units.angstrom2bohr(lattice_ang)

    # The two atoms lie on the x axis, L/sqrt(3) and 2L/sqrt(3) from the origin.
    inv_sqrt3 = 3 ** (-0.5)
    cell = gto.Cell()
    cell.atom = [[species_a, [inv_sqrt3 * lattice_bohr, 0.0, 0.0]],
                 [species_b, [2 * inv_sqrt3 * lattice_bohr, 0.0, 0.0]]]

    cell.basis = basis
    # In-plane vectors at +/-30 degrees from x, each of length L.
    cos30 = np.cos(np.pi / 6)
    cell.a = np.array([[lattice_bohr * cos30, -lattice_bohr * 0.5, 0],
                       [lattice_bohr * cos30, lattice_bohr * 0.5, 0],
                       [0, 0, z_len],
                       ])
    cell.unit = "B"
    cell.verbose = 5
    cell.exp_to_discard = 0.1
    cell.build()

    cfg.system.pyscf_cell = supercell.get_supercell(cell, scaling)
    return cfg
# --------------------------------------------------------------------------------
# /DeepSolid/utils/units.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

from typing import TypeVar
import numpy as np

# 1 Bohr = 0.52917721067 (12) x 10^{-10} m
# https://physics.nist.gov/cgi-bin/cuu/Value?bohrrada0
# Note: pyscf uses a slightly older definition of 0.52917721092 angstrom.
ANGSTROM_BOHR = 0.52917721067  # Angstrom per Bohr
BOHR_ANGSTROM = 1. / ANGSTROM_BOHR  # Bohr per Angstrom

# 1 Hartree = 627.509474 kcal/mol
# https://en.wikipedia.org/wiki/Hartree
KCAL_HARTREE = 627.509474  # kcal/mol per Hartree
HARTREE_KCAL = 1. / KCAL_HARTREE  # Hartree per kcal/mol

# Scalars or numpy arrays; every converter is a plain multiplication.
NumericalLike = TypeVar('NumericalLike', float, np.ndarray)


def bohr2angstrom(x_b: NumericalLike) -> NumericalLike:
    """Convert a length (or array of lengths) from Bohr to Angstrom."""
    return ANGSTROM_BOHR * x_b


def angstrom2bohr(x_a: NumericalLike) -> NumericalLike:
    """Convert a length (or array of lengths) from Angstrom to Bohr."""
    return BOHR_ANGSTROM * x_a


def hartree2kcal(x_b: NumericalLike) -> NumericalLike:
    """Convert an energy from Hartree to kcal/mol.

    The parameter keeps its historical name ``x_b`` for keyword compatibility.
    """
    return KCAL_HARTREE * x_b


def kcal2hartree(x_a: NumericalLike) -> NumericalLike:
    """Convert an energy from kcal/mol to Hartree.

    The parameter keeps its historical name ``x_a`` for keyword compatibility.
    """
    return HARTREE_KCAL * x_a

# --------------------------------------------------------------------------------
# /DeepSolid/constants.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import functools
import jax
import jax.numpy as jnp
from jax import core
from typing import TypeVar

T = TypeVar("T")

# Name of the device axis used by DeepSolid's pmaps; the pmean/psum helpers
# below refer to it.
PMAP_AXIS_NAME = 'qmc_pmap_axis'

# jax.pmap with the DeepSolid axis name pre-bound.
pmap = functools.partial(jax.pmap, axis_name=PMAP_AXIS_NAME)
# pmapped identity: splits the leading axis of its input across local devices.
broadcast_all_local_devices = jax.pmap(lambda x: x)
# pmapped jax.random.split: on each device, split that device's key into a
# tuple of two new keys.
p_split = jax.pmap(lambda key: tuple(jax.random.split(key)))


def wrap_if_pmap(p_func):
    """Make a collective (e.g. ``jax.lax.pmean``) a no-op outside of pmap.

    Args:
      p_func: collective with signature ``p_func(obj, axis_name)``.

    Returns:
      A function with the same signature. It first probes
      ``core.axis_frame(axis_name)``. If that raises ``NameError``, which this
      code treats as "axis_name is not bound", the input is returned
      unchanged; otherwise ``p_func(obj, axis_name)`` is returned.
    """
    def p_func_if_pmap(obj, axis_name):
        try:
            core.axis_frame(axis_name)
            # NOTE(review): p_func is also inside the try, so a NameError raised
            # by p_func itself would be swallowed and obj returned as well.
            return p_func(obj, axis_name)
        except NameError:
            return obj

    return p_func_if_pmap


# pmean / psum that fall back to the identity when the axis is not bound.
pmean_if_pmap = wrap_if_pmap(jax.lax.pmean)
psum_if_pmap = wrap_if_pmap(jax.lax.psum)


def replicate_all_local_devices(obj: T) -> T:
    """Place one identical copy of every leaf of ``obj`` on each local device.

    Each leaf is stacked ``jax.local_device_count()`` times along a new leading
    axis and then sharded over that axis with ``broadcast_all_local_devices``.
    """
    n = jax.local_device_count()
    obj_stacked = jax.tree_map(lambda x: jnp.stack([x] * n, axis=0), obj)
    return broadcast_all_local_devices(obj_stacked)


def make_different_rng_key_on_all_devices(rng: jnp.ndarray) -> jnp.ndarray:
    """Derive a distinct PRNG key for every local device from ``rng``.

    The host id is folded in first, so different hosts starting from the same
    ``rng`` derive different per-device keys. The result is split into one key
    per local device and sharded across them.
    """
    rng = jax.random.fold_in(rng, jax.host_id())
    rng = jax.random.split(rng, jax.local_device_count())
    return broadcast_all_local_devices(rng)

# --------------------------------------------------------------------------------
# /DeepSolid/distributed.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import functools

from absl import logging
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension

# Handle to the distributed runtime service; only ever set on process 0.
_service = None


def initialize(coordinator_address: str, num_processes: int, process_id: int,
               xla_client_config):
    """Initialize distributed system for topology discovery.

    Currently, calling ``initialize`` sets up the multi-host GPU backend, and
    is not required for CPU or TPU backends.

    Process 0 additionally starts the coordination service. Every process then
    connects a client to it and registers a GPU backend factory built on that
    client. Relies on private ``jax._src.lib`` APIs.

    Args:
      coordinator_address: IP address of the coordinator.
      num_processes: Number of processes.
      process_id: Id of the current process.
      xla_client_config: mapping of extra keyword arguments forwarded to
        ``xla_extension.get_distributed_runtime_client``.

    Raises:
      AssertionError: on process 0, if called more than once.
    """
    if process_id == 0:
        global _service
        assert _service is None, 'initialize should be called once only'
        logging.info('Starting JAX distributed service on %s', coordinator_address)
        _service = xla_extension.get_distributed_runtime_service(coordinator_address,
                                                                 num_processes)

    client = xla_extension.get_distributed_runtime_client(coordinator_address,
                                                          process_id,
                                                          **xla_client_config)
    logging.info('Connecting to JAX distributed service on %s', coordinator_address)
    client.connect()

    # Register the client-backed GPU backend (priority 300).
    factory = functools.partial(xla_client.make_gpu_client, client, process_id)
    xla_bridge.register_backend_factory('gpu', factory, priority=300)

# --------------------------------------------------------------------------------
# /DeepSolid/utils/kfac_ferminet_alpha/distributions.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.
"""Module for all distribution implementations needed for the loss functions."""
import math
from typing import Union
import jax
import jax.numpy as jnp


class MultivariateNormalDiag:
    """Multivariate normal distribution on `R^k` with diagonal covariance."""

    def __init__(
            self,
            loc: jnp.ndarray,
            scale_diag: jnp.ndarray):
        """Initializes a MultivariateNormalDiag distribution.

        Args:
          loc: Mean vector of the distribution. Can also be a batch of vectors.
          scale_diag: Vector of standard deviations.
        """
        super().__init__()
        self._loc = loc
        self._scale_diag = scale_diag

    @property
    def loc(self) -> jnp.ndarray:
        """Mean of the distribution."""
        return self._loc

    @property
    def scale_diag(self) -> jnp.ndarray:
        """Scale (standard deviations) of the distribution."""
        return self._scale_diag

    def _num_dims(self) -> int:
        """Dimensionality of the events."""
        return self._scale_diag.shape[-1]

    def _standardize(self, value: jnp.ndarray) -> jnp.ndarray:
        """Map ``value`` to zero mean / unit scale, elementwise."""
        return (value - self._loc) / self._scale_diag

    def log_prob(self, value: jnp.ndarray) -> jnp.ndarray:
        """Log-density of ``value``, summed over the last (event) axis."""
        log_unnormalized = -0.5 * jnp.square(self._standardize(value))
        log_normalization = 0.5 * math.log(2 * math.pi) + jnp.log(self._scale_diag)
        return jnp.sum(log_unnormalized - log_normalization, axis=-1)

    def mean(self) -> jnp.ndarray:
        """Calculates the mean."""
        return self.loc

    def sample(self, seed: Union[int, jnp.ndarray]) -> jnp.ndarray:
        """Samples an event.

        Args:
          seed: PRNG key or integer seed. An integer is turned into a key with
            ``jax.random.PRNGKey``; previously only keys actually worked.

        Returns:
          A sample with the same shape as ``loc``.
        """
        if isinstance(seed, int):
            seed = jax.random.PRNGKey(seed)
        eps = jax.random.normal(seed, self.loc.shape)
        return self.loc + eps * self.scale_diag

# --------------------------------------------------------------------------------
# /DeepSolid/utils/poscar_to_cell.py:
# --------------------------------------------------------------------------------
"""
I/O routines for crystal structure.
Author:
    Zhi-Hao Cui
    Bo-Xiao Zheng
"""
# modified from libdmet_preview:
# https://github.com/gkclab/libdmet_preview/blob/faee119f18755314d945393595301f66baf40ae5/libdmet/utils/iotools.py

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.
import numpy as np
from pyscf.data.nist import BOHR
from collections import OrderedDict
import scipy.linalg as la
import sys
import os


def Frac2Real(cellsize, coord):
    """Convert fractional coordinates to Cartesian ones.

    Args:
        cellsize: square matrix whose rows are the lattice vectors.
        coord: fractional coordinates (one vector, or one row per point).

    Returns:
        ``coord @ cellsize``, in the length unit of ``cellsize``.
    """
    assert cellsize.ndim == 2 and cellsize.shape[0] == cellsize.shape[1]
    return np.dot(coord, cellsize)


def Real2Frac(cellsize, coord):
    """Convert Cartesian coordinates to fractional ones (inverse of Frac2Real)."""
    assert cellsize.ndim == 2 and cellsize.shape[0] == cellsize.shape[1]
    return np.dot(coord, la.inv(cellsize))


def read_poscar(fname="POSCAR"):
    """
    Read cell structure from a VASP POSCAR file.

    Handles:
      - both the VASP 4 layout (no species-name line; species are labelled "X")
        and the VASP 5 layout;
      - an optional "Selective dynamics" line;
      - a negative scale factor, interpreted as in VASP as the target cell
        volume in Angstrom^3.

    Args:
        fname: file name.
    Returns:
        cell: pyscf.pbc.gto.Cell, not yet built. Lattice vectors and atomic
            coordinates are in Bohr (``cell.unit = 'Bohr'``).
    """
    from pyscf.pbc import gto
    with open(fname, 'r') as f:
        lines = f.readlines()

    # line 1: universal scale factor
    line = lines[1].split()
    assert len(line) == 1
    factor = float(line[0])

    # lines 2-4: lattice vectors (Angstrom, before scaling)
    a = np.array([[float(x) for x in lines[i].split()[:3]]
                  for i in range(2, 5)])
    if factor < 0:
        # VASP convention: a negative scale factor is the desired cell volume.
        factor = (-factor / abs(la.det(a))) ** (1.0 / 3.0)
    a = a * factor
    a = a / BOHR

    # lines 5, 6: species names (VASP 5 only) and counts
    sp_names = lines[5].split()
    if all(name.isdigit() for name in sp_names):
        # VASP 4 layout: line 5 already holds the counts; no names are given.
        sp_nums = [int(n) for n in sp_names]
        sp_names = ["X"] * len(sp_nums)
        line_no = 6
    else:
        sp_nums = [int(n) for n in lines[6].split()]
        line_no = 7

    # optional "Selective dynamics" line (VASP only checks the first letter)
    if lines[line_no].strip()[:1] in ('S', 's'):
        line_no += 1

    # coordinate mode: Cartesian / Direct (VASP only checks the first letter)
    line = lines[line_no].split()
    use_cart = line[0].startswith(('C', 'K', 'c', 'k'))
    line_no += 1

    # remaining lines: coordinates, one atom per line
    atom_col = []
    for sp_name, sp_num in zip(sp_names, sp_nums):
        for _ in range(sp_num):
            # there may be extra columns for comments or selective-dynamics flags.
            coord = np.array([float(x) for x in lines[line_no].split()[:3]])
            if use_cart:
                coord = coord * factor / BOHR
            else:
                coord = Frac2Real(a, coord)
            atom_col.append((sp_name, coord))
            line_no += 1

    cell = gto.Cell()
    cell.a = a
    cell.atom = atom_col
    cell.unit = 'Bohr'
    return cell

# --------------------------------------------------------------------------------
# /DeepSolid/estimator.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import jax
import jax.numpy as jnp
import functools

import pyscf.pbc.gto
from DeepSolid import constants


def make_complex_polarization(simulation_cell: pyscf.pbc.gto.Cell,
                              direction: int = 0,
                              ndim=3):
    '''
    Generate the complex-polarization order parameter (used for the hydrogen chain).

    The order parameter is z = < exp(i G . sum_i r_i) >, where G is the
    selected reciprocal lattice vector of the simulation cell. It is averaged
    over walkers and, under pmap, across devices.

    :param simulation_cell: pyscf object of simulation cell.
    :param direction: index of the reciprocal lattice vector G onto which the
        electron positions are projected.
    :param ndim: spatial dimension of each electron coordinate.
    :return: a function mapping electron walkers to the (complex) order parameter.
    '''

    rec_vec = simulation_cell.reciprocal_vectors()[direction]

    def complex_polarization(data):
        """

        :param data: electron walkers with shape [batch, ne * ndim]
        :return: complex polarization with shape []
        """
        leading_shape = list(data.shape[:-1])
        data = data.reshape(leading_shape + [-1, ndim])
        dots = jnp.einsum('i,...i->...', rec_vec, data)
        dots = jnp.sum(dots, axis=-1)
        polarization = jnp.exp(1j * dots)
        polarization = jnp.mean(polarization, axis=-1)
        polarization = constants.pmean_if_pmap(polarization, axis_name=constants.PMAP_AXIS_NAME)
        return polarization

    return complex_polarization


def make_structure_factor(simulation_cell: pyscf.pbc.gto.Cell,
                          nq=4,
                          ndim=3):
    '''
    Generate the structure-factor estimator used for finite-size error reduction.
    See PHYSICAL REVIEW B 94, 035126 (2016) for details.

    :param simulation_cell: pyscf object of simulation cell.
    :param nq: number of sampled crystal momenta along each reciprocal lattice
        direction. The q-grid is {0, ..., nq-1}^3 in units of the simulation
        cell's reciprocal lattice vectors: nq**3 points, including q = 0.
    :param ndim: spatial dimension of each electron coordinate.
    :return: a function mapping electron walkers to S(q) on that grid.
    '''
    mesh_grid = jnp.meshgrid(*[jnp.array(range(0, nq)) for _ in range(3)])
    point_list = jnp.stack([m.ravel() for m in mesh_grid], axis=0).T
    rec_vec = simulation_cell.reciprocal_vectors()

    # Cartesian q-vectors, shape [nq**3, 3].
    qvecs = point_list @ rec_vec
    nelec = simulation_cell.nelectron

    def structure_factor(data):
        """

        :param data: electron walkers with shape [batch, ne * ndim]; any extra
            leading axes are treated as walker axes as well.
        :return: structure factor S(q), normalized per electron, with shape [nq**3]
        """
        leading_shape = list(data.shape[:-1])
        data = data.reshape(leading_shape + [-1, ndim])
        # q . r_i for every electron and q-point: [..., ne, nq**3]
        dots = jnp.einsum('kj,...j->...k', qvecs, data)
        # rho(q) = sum_i exp(i q . r_i) for each walker: [..., nq**3]
        rho_k = jnp.sum(jnp.exp(1j * dots), axis=-2)
        # Average over every walker (leading) axis, then across devices.
        walker_axes = tuple(range(rho_k.ndim - 1))
        rho_k_one = jnp.mean(rho_k, axis=walker_axes)
        rho_k_one_mean = constants.pmean_if_pmap(rho_k_one, axis_name=constants.PMAP_AXIS_NAME)
        rho_k_two = jnp.mean(jnp.abs(rho_k)**2, axis=walker_axes)
        rho_k_two_mean = constants.pmean_if_pmap(rho_k_two, axis_name=constants.PMAP_AXIS_NAME)

        # S(q) = (<|rho(q)|^2> - |<rho(q)>|^2) / N_electrons
        sk = rho_k_two_mean - jnp.abs(rho_k_one_mean)**2
        sk = sk / nelec

        return sk

    return structure_factor
# --------------------------------------------------------------------------------
# /DeepSolid/utils/system.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

from typing import Sequence
import attr
import numpy as np

from DeepSolid.utils import elements
from DeepSolid.utils import units as unit_conversion


@attr.s
class Atom:
    """Atom information for Hamiltonians.

    The nuclear charge is inferred from the symbol if not given, in which case the
    symbol must be the IUPAC symbol of the desired element.

    Attributes:
      symbol: Element symbol.
      coords: An iterable of atomic coordinates, in bohr after initialisation.
        Default: place atom at origin. The converter stores it as a tuple of
        floats. When ``units='angstrom'``, ``__attrs_post_init__`` reassigns it
        as a list of bohr values. NOTE(review): this assumes ``attr.s`` does
        not re-run the converter on assignment (its default) — confirm.
      charge: Nuclear charge. Default: nuclear charge (atomic number) of atom of
        the given name.
      atomic_number: Atomic number associated with element. Default: atomic number
        of element of the given symbol. Should match charge unless fractional
        nuclear charges are being used.
      units: String giving units of coords. Either bohr or angstrom. Default:
        bohr. If angstrom, coords are converted to be in bohr and units to the
        string 'bohr'.
      coords_angstrom: list of atomic coordinates in angstrom.
      coords_array: Numpy array of atomic coordinates in bohr.
      element: elements.Element corresponding to the symbol.
    """
    # Declaration order below defines the positional order of __init__.
    symbol = attr.ib(type=str)
    coords = attr.ib(
        type=Sequence[float],
        converter=lambda xs: tuple(float(x) for x in xs),
        default=(0.0, 0.0, 0.0))
    # Defaults for charge / atomic_number come from the decorated methods below.
    charge = attr.ib(type=float, converter=float)
    atomic_number = attr.ib(type=int, converter=int)
    units = attr.ib(
        type=str,
        default='bohr',
        validator=attr.validators.in_(['bohr', 'angstrom']))

    @charge.default
    def _set_default_charge(self):
        # Default nuclear charge: the element's atomic number.
        return self.element.atomic_number

    @atomic_number.default
    def _set_default_atomic_number(self):
        return self.element.atomic_number

    def __attrs_post_init__(self):
        # Normalise to bohr so downstream code only ever sees bohr coordinates.
        if self.units == 'angstrom':
            self.coords = [unit_conversion.angstrom2bohr(x) for x in self.coords]
            self.units = 'bohr'

    @property
    def coords_angstrom(self):
        return [unit_conversion.bohr2angstrom(x) for x in self.coords]

    @property
    def coords_array(self):
        # Cached on first access. NOTE(review): not refreshed if coords is
        # reassigned afterwards.
        if not hasattr(self, '_coords_arr'):
            self._coords_arr = np.array(self.coords)
        return self._coords_arr

    @property
    def element(self):
        # Raises KeyError for symbols missing from elements.SYMBOLS.
        return elements.SYMBOLS[self.symbol]

# --------------------------------------------------------------------------------
# /DeepSolid/utils/kfac_ferminet_alpha/vjp_rc.py:
# --------------------------------------------------------------------------------
# Copyright 2020, 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # enable x64 on jax 16 | # must be done at startup. 17 | 18 | # This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). 19 | # All Bytedance Modifications are Copyright 2022 Bytedance Inc. 20 | 21 | import jax 22 | import jax.numpy as jnp 23 | from jax.tree_util import ( 24 | tree_map, 25 | tree_multimap, 26 | ) 27 | 28 | def vjp_rc( 29 | fun, *primals, has_aux: bool = False, conjugate: bool = False): 30 | ''' 31 | realize the vjp of R->C function 32 | :param fun: 33 | :param primals: 34 | :param has_aux: 35 | :param conjugate: 36 | :return: 37 | ''' 38 | if has_aux: 39 | 40 | def real_fun(*primals): 41 | val, aux = fun(*primals) 42 | real_val = jax.tree_map(lambda x:x.real, val) 43 | return real_val, aux 44 | 45 | def imag_fun(*primals): 46 | val, aux = fun(*primals) 47 | imag_val = jax.tree_map(lambda x: x.imag, val) 48 | return imag_val, aux 49 | 50 | vals_r, vjp_r_fun, aux = jax.vjp(real_fun, *primals, has_aux=True) 51 | vals_j, vjp_j_fun, _ = jax.vjp(imag_fun, *primals, has_aux=True) 52 | 53 | else: 54 | real_fun = lambda *primals: fun(*primals).real 55 | imag_fun = lambda *primals: fun(*primals).imag 56 | 57 | vals_r, vjp_r_fun = jax.vjp(real_fun, *primals, has_aux=False) 58 | vals_j, vjp_j_fun = jax.vjp(imag_fun, *primals, has_aux=False) 59 | 60 | primals_out = jax.tree_multimap(lambda x,y:x + 1j*y, vals_r, vals_j) 61 | 62 | def vjp_fun(ȳ): 63 | """ 64 | function computing the vjp product for a R->C function. 
65 | """ 66 | ȳ_r = jax.tree_map(lambda x:x.real, ȳ) 67 | # ȳ_r = jax.tree_map(lambda x:jnp.asarray(x, dtype=vals_r.dtype), ȳ_r) 68 | ȳ_j = jax.tree_map(lambda x:x.imag, ȳ) 69 | # ȳ_j = jax.tree_map(lambda x:jnp.asarray(x, dtype=vals_j.dtype), ȳ_j) 70 | 71 | # val = vals_r + vals_j 72 | vr_jr = vjp_r_fun(jax.tree_map(lambda x,v:jnp.asarray(x, dtype=v.dtype), ȳ_r, vals_r)) 73 | vj_jr = vjp_r_fun(jax.tree_map(lambda x,v:jnp.asarray(x, dtype=v.dtype), ȳ_j, vals_r)) 74 | vr_jj = vjp_j_fun(jax.tree_map(lambda x,v:jnp.asarray(x, dtype=v.dtype), ȳ_r, vals_j)) 75 | vj_jj = vjp_j_fun(jax.tree_map(lambda x,v:jnp.asarray(x, dtype=v.dtype), ȳ_j, vals_j)) 76 | 77 | r = tree_multimap( 78 | lambda re, im: re + 1j * im, 79 | vr_jr, 80 | vj_jr, 81 | ) 82 | i = tree_multimap(lambda re, im: re + 1j * im, vr_jj, vj_jj) 83 | out = tree_multimap(lambda re, im: re + 1j * im, r, i) 84 | 85 | if conjugate: 86 | out = tree_map(jnp.conjugate, out) 87 | 88 | return out 89 | 90 | if has_aux: 91 | return primals_out, vjp_fun, aux 92 | else: 93 | return primals_out, vjp_fun -------------------------------------------------------------------------------- /DeepSolid/init_guess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). 
16 | # All Bytedance Modifications are Copyright 2022 Bytedance Inc. 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | import pyscf.pbc.gto 22 | from typing import Sequence 23 | from DeepSolid.utils import system 24 | from DeepSolid import distance 25 | 26 | 27 | def init_electrons( 28 | key, 29 | cell: Sequence[system.Atom], 30 | latvec, 31 | electrons: Sequence[int], 32 | batch_size: int, 33 | init_width=0.5, 34 | ) -> jnp.ndarray: 35 | """ 36 | Initializes electron positions around each atom. 37 | 38 | :param key: jax key for random 39 | :param cell: internal representation of simulation cell 40 | :param latvec: lattice vector of cell 41 | :param electrons: list of up, down electrons 42 | :param batch_size: batch_size for simulation 43 | :param init_width: std of gaussian used for initialization 44 | :return: jnp.array with shape [Batch_size, N_ele * ndim] 45 | """ 46 | if sum(atom.charge for atom in cell) != sum(electrons): 47 | if len(cell) == 1: 48 | atomic_spin_configs = [electrons] 49 | else: 50 | raise NotImplementedError('No initialization policy yet ' 51 | 'exists for charged molecules.') 52 | else: 53 | 54 | atomic_spin_configs = [ 55 | (atom.element.nalpha - int((atom.atomic_number - atom.charge) // 2), 56 | atom.element.nbeta - int((atom.atomic_number - atom.charge) // 2)) 57 | for atom in cell 58 | ] 59 | # element.nalpha return the up spin number of the single element, if ecp is used, [nalpha,nbeta] should be reduce 60 | # with the the core charge which equals atomic_number - atom.charge 61 | assert sum(sum(x) for x in atomic_spin_configs) == sum(electrons) 62 | while tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons: 63 | i = np.random.randint(len(atomic_spin_configs)) 64 | nalpha, nbeta = atomic_spin_configs[i] 65 | if atomic_spin_configs[i][0] > 0: 66 | atomic_spin_configs[i] = nalpha - 1, nbeta + 1 67 | 68 | # Assign each electron to an atom initially. 
69 | electron_positions = [] 70 | for i in range(2): 71 | for j in range(len(cell)): 72 | atom_position = jnp.asarray(cell[j].coords) 73 | electron_positions.append(jnp.tile(atom_position, atomic_spin_configs[j][i])) 74 | electron_positions = jnp.concatenate(electron_positions) 75 | # Create a batch of configurations with a Gaussian distribution about each 76 | # atom. 77 | key, subkey = jax.random.split(key) 78 | guess = electron_positions + init_width * jax.random.normal(subkey, shape=(batch_size, electron_positions.size)) 79 | replaced_guess, _ = distance.enforce_pbc(latvec, guess) 80 | return replaced_guess 81 | 82 | 83 | 84 | def pyscf_to_cell(cell: pyscf.pbc.gto.Cell): 85 | """ 86 | Converts the pyscf cell to the internal representation. 87 | 88 | :param cell: pyscf.cell object 89 | :return: internal cell representation 90 | """ 91 | internal_cell = [system.Atom(cell.atom_symbol(i), 92 | cell.atom_coords()[i], 93 | charge=cell.atom_charges()[i], ) 94 | for i in range(cell.natm)] 95 | ## cfg.system.pyscf_mol.atom_charges()[i] return the screen charge of i atom if ecp is used 96 | return internal_cell -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepSolid 2 | 3 | An implementation of the algorithm given in 4 | ["Ab initio calculation of real solids via neural network ansatz"](https://rdcu.be/c4rNI). 5 | A periodic neural network is proposed as wavefunction ansatz for solid quantum Monte Carlo and achieves 6 | unprecedented accuracy compared with other state-of-the-art methods. 7 | This repository is developed upon [FermiNet](https://github.com/deepmind/ferminet/tree/jax) 8 | and [PyQMC](https://github.com/WagnerGroup/pyqmc). 9 | 10 | ## Installation 11 | 12 | DeepSolid can be installed via the supplied setup.py file. 13 | ```shell 14 | # Install with CPU only 15 | pip3 install -e . 
-f https://storage.googleapis.com/jax-releases/jax_releases.html 16 | # or with GPU 17 | pip3 install -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 18 | ``` 19 | 20 | Python 3.9 is recommended. 21 | If a GPU is available, we recommend installing jax and jaxlib with CUDA 11.4+. 22 | Our experiments were carried out with `jax==0.2.26` and `jaxlib==0.1.75`. 23 | 24 | ## Usage 25 | 26 | The [ml_collections](https://github.com/google/ml_collections) package is used for system definition. Below is a simple example of H10 in PBC: 27 | ``` 28 | deepsolid --config=PATH/TO/DeepSolid/config/two_hydrogen_cell.py:H,5,1,1,2.0,0,ccpvdz --config.batch_size 4096 29 | ``` 30 | 31 | ### Customize your system 32 | The simulation system can be customized in a config.py file, for example: 33 | 34 | ```python 35 | import numpy as np 36 | from pyscf.pbc import gto 37 | from DeepSolid import base_config 38 | from DeepSolid import supercell 39 | 40 | 41 | def get_config(input_str): 42 | symbol, S = input_str.split(',') 43 | cfg = base_config.default() 44 | 45 | # Set up cell. 46 | cell = gto.Cell() 47 | 48 | # Define the atoms in the primitive cell. 49 | cell.atom = f""" 50 | {symbol} 0.000000000000 0.000000000000 0.000000000000 51 | """ 52 | 53 | # Define the pretrain basis. 54 | cell.basis = "ccpvdz" 55 | 56 | # Define the lattice vectors of the primitive cell. 57 | # In this example it's a simple cubic. 58 | cell.a = np.array([[3.0, 0.0, 0.0], 59 | [0.0, 3.0, 0.0], 60 | [0.0, 0.0, 3.0]]) 61 | 62 | # Define the unit used in cell definition, only support Bohr now. 63 | cell.unit = "B" 64 | cell.verbose = 5 65 | 66 | # Define the threshold to discard gaussian basis used in pretrain. 67 | cell.exp_to_discard = 0.1 68 | cell.build() 69 | 70 | # Define the supercell for QMC, S specifies how to tile the primitive cell. 71 | S = np.eye(3) * int(S) 72 | simulation_cell = supercell.get_supercell(cell, S) 73 | 74 | # Assign the defined supercell to cfg.
75 | cfg.system.pyscf_cell = simulation_cell 76 | 77 | return cfg 78 | ``` 79 | After defining the config file, simply use the following command to launch the simulation: 80 | 81 | ```shell 82 | deepsolid --config=PATH/TO/config.py:He,1 --config.batch_size 4096 83 | ``` 84 | 85 | 86 | ### Read structure from a POSCAR file 87 | 88 | We also support reading the structure from a POSCAR file, a commonly used format. Simply use the following command: 89 | ```shell 90 | deepsolid --config=DeepSolid/config/read_poscar.py:PATH/TO/POSCAR/bcc_li.vasp,1,ccpvdz 91 | ``` 92 | ## Distributed training 93 | The currently released code does not support multi-node training. See [this link](https://github.com/google/jax/pull/8364) 94 | for help. 95 | 96 | ## Tricks to accelerate 97 | The bottleneck of DeepSolid is the laplacian evaluation of the neural network. We recommend 98 | using partition mode instead, which only requires adding two more flags: 99 | ```shell 100 | deepsolid --config=PATH/TO/config.py --config.optim.laplacian_mode=partition --config.optim.partition_number=3 101 | ``` 102 | Partition mode tries to parallelize the laplacian calculation, and the partition number must be a factor of 103 | (electron number * 3). Note that partition mode requires a lot of GPU memory. 104 | 105 | ## Precision 106 | DeepSolid supports both FP32 and FP64. However, we recommend turning off the TF32 mode, which 107 | is automatically enabled on A100 GPUs when FP32 is chosen. TF32 can be turned off using the following command: 108 | 109 | ```shell 110 | NVIDIA_TF32_OVERRIDE=0 deepsolid --config.use_x64=False 111 | ``` 112 | 113 | ## Giving Credit 114 | 115 | If you use this code in your work, please cite the associated paper.
116 | 117 | ``` 118 | @article{li2022ab, 119 | title={Ab initio calculation of real solids via neural network ansatz}, 120 | author={Li, Xiang and Li, Zhe and Chen, Ji}, 121 | journal={Nature Communications}, 122 | volume={13}, 123 | number={1}, 124 | pages={7895}, 125 | year={2022}, 126 | publisher={Nature Publishing Group UK London} 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /test/test_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | import logging 11 | from time import time 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | jax.config.update("jax_enable_x64", True) 15 | 16 | import test_cell 17 | from DeepSolid import network 18 | from DeepSolid import supercell 19 | from DeepSolid import init_guess 20 | from DeepSolid import hf 21 | from DeepSolid import base_config 22 | 23 | # Define your test system 24 | S = np.eye(3) # Define how to tile primitive cell 25 | cell = test_cell.cell 26 | simulation_cell = supercell.get_supercell(cell, S=S) 27 | 28 | # Define the scaled twist momentum k_s 29 | scaled_twist = 0.0 30 | twist = scaled_twist * jnp.ones(3) 31 | 32 | # Do HF calculation to get k-points 33 | scf_approx = hf.SCF(simulation_cell, twist=twist) 34 | scf_approx.init_scf() 35 | 36 | # Define your neural network settings 37 | cfg = base_config.default() 38 | cfg.network.detnet.determinants = 8 39 | system_dict = {'klist': scf_approx.klist, # occupied k points from HF 40 | 'simulation_cell': simulation_cell, 41 | } 42 | system_dict.update(cfg.network.detnet) 43 | system_dict['envelope_type'] = 'isotropic' 44 | system_dict['full_det'] = False 45 | 46 | # quantum 
number of periodic boundary condition 47 | kp = sum([jnp.sum(k, axis=0) for k in system_dict['klist']]) 48 | 49 | # make callable neural network functions 50 | slater_forward = network.make_solid_fermi_net(**system_dict, method_name='eval_logdet') 51 | slater_phase_and_slogdet = network.make_solid_fermi_net(**system_dict, method_name='eval_phase_and_slogdet') 52 | slater_mat_forward = network.make_solid_fermi_net(**system_dict, method_name='eval_mats') 53 | 54 | # initialize parameters and electron positions 55 | key = jax.random.PRNGKey(int(time())) 56 | internal_sim_cell = init_guess.pyscf_to_cell(simulation_cell) 57 | coord = init_guess.init_electrons(key, 58 | internal_sim_cell, 59 | simulation_cell.a, 60 | simulation_cell.nelec, 61 | batch_size=1)[0] 62 | p = slater_forward.init(key, data=coord) 63 | 64 | 65 | def test_periodic_bc(params, x): 66 | """ 67 | test periodic boundary condition of wf 68 | :param params: 69 | :param x: 70 | :return: 71 | """ 72 | trans = cell.lattice_vectors()[2] 73 | x1 = x 74 | x2 = x + jnp.tile(trans, simulation_cell.nelectron) 75 | # simultaneous translation of all electron over a primitive lattice vector 76 | p1, s1 = slater_phase_and_slogdet.apply(params=params, x=x1) 77 | p2, s2 = slater_phase_and_slogdet.apply(params=params, x=x2) 78 | logging.info(f'original:{p1, s1}') 79 | logging.info(f'translated:{p2, s2}') 80 | logging.info(f'kp angle:{jnp.angle(p2 / p1) / np.pi} pi') 81 | assert jnp.allclose(s1, s2) 82 | assert jnp.allclose(p1 * jnp.exp(1j * jnp.dot(kp, trans)), p2) 83 | logging.info('Periodic BC checked') 84 | 85 | 86 | def test_twisted_bc(params, x): 87 | """ 88 | test twist boundary condition of wf 89 | :param params: 90 | :param x: 91 | :return: 92 | """ 93 | x1 = x 94 | x2 = x + jnp.concatenate([simulation_cell.lattice_vectors()[1][None, ...], 95 | jnp.zeros(shape=[simulation_cell.nelectron - 1, 3]) 96 | ], axis=0).ravel() 97 | # translation of a single electron over a supercell lattice vector 98 | p1, s1 = 
slater_phase_and_slogdet.apply(params=params, x=x1) 99 | p2, s2 = slater_phase_and_slogdet.apply(params=params, x=x2) 100 | logging.info(f'original:{p1, s1}') 101 | logging.info(f'translated:{p2, s2}') 102 | logging.info(f'ks angle:{jnp.angle(p2 / p1) / jnp.pi} pi') 103 | assert jnp.allclose(s1, s2) 104 | assert jnp.allclose(p2 / p1, 105 | jnp.exp(1j * scaled_twist * 2 * jnp.pi)) 106 | logging.info('Twisted BC checked') 107 | 108 | 109 | def test_anti_symmetry(params, x): 110 | """ 111 | test anti-symmetry condition of wf 112 | :param params: 113 | :param x: 114 | :return: 115 | """ 116 | x1 = x 117 | x2 = jnp.concatenate([x1[3:6], x1[:3], x1[6:]]) 118 | p1, s1 = slater_phase_and_slogdet.apply(params=params, x=x1) 119 | p2, s2 = slater_phase_and_slogdet.apply(params=params, x=x2) 120 | assert jnp.allclose(p1, -p2) 121 | assert jnp.allclose(s1, s2) 122 | logging.info('Anti symmetry checked') 123 | 124 | 125 | if __name__ == '__main__': 126 | test_periodic_bc(x=coord, params=p) 127 | test_twisted_bc(x=coord, params=p) 128 | test_anti_symmetry(x=coord, params=p) 129 | -------------------------------------------------------------------------------- /DeepSolid/utils/writers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This file may have been modified by Bytedance Inc. 
(“Bytedance Modifications”). 17 | # All Bytedance Modifications are Copyright 2022 Bytedance Inc. 18 | 19 | import contextlib 20 | import os 21 | from typing import Mapping, Optional, Sequence 22 | 23 | from absl import logging 24 | import tables 25 | 26 | 27 | class Writer(contextlib.AbstractContextManager): 28 | """Write data to CSV, as well as logging data to stdout if desired.""" 29 | 30 | def __init__(self, 31 | name: str, 32 | schema: Sequence[str], 33 | directory: str = 'logs/', 34 | iteration_key: Optional[str] = 't', 35 | log: bool = True): 36 | """Initialise Writer. 37 | 38 | Args: 39 | name: file name for CSV. 40 | schema: sequence of keys, corresponding to each data item. 41 | directory: directory path to write file to. 42 | iteration_key: if not None or a null string, also include the iteration 43 | index as the first column in the CSV output with the given key. 44 | log: Also log each entry to stdout. 45 | """ 46 | self._schema = schema 47 | if not os.path.isdir(directory): 48 | os.mkdir(directory) 49 | self._filename = os.path.join(directory, name + '.csv') 50 | self._iteration_key = iteration_key 51 | self._log = log 52 | 53 | def __enter__(self): 54 | should_add_header = not os.path.exists(self._filename) 55 | 56 | self._file = open(self._filename, 'a+') 57 | 58 | if should_add_header: 59 | # write top row of csv 60 | if self._iteration_key: 61 | self._file.write(f'{self._iteration_key},') 62 | self._file.write(','.join(self._schema) + '\n') 63 | return self 64 | 65 | def write(self, t: int, **data): 66 | """Writes to file and stdout. 67 | 68 | Args: 69 | t: iteration index. 70 | **data: data items with keys as given in schema. 
71 | """ 72 | row = [str(data.get(key, '')) for key in self._schema] 73 | if self._iteration_key: 74 | row.insert(0, str(t)) 75 | for key in data: 76 | if key not in self._schema: 77 | raise ValueError('Not a recognized key for writer: %s' % key) 78 | 79 | # write the data to csv 80 | self._file.write(','.join(row) + '\n') 81 | 82 | # write the data to abseil logs 83 | if self._log: 84 | logging.info('Iteration %s: %s', t, data) 85 | 86 | def flush(self): 87 | self._file.flush() 88 | 89 | def __exit__(self, exc_type, exc_val, exc_tb): 90 | self.flush() 91 | self._file.close() 92 | 93 | 94 | class H5Writer(contextlib.AbstractContextManager): 95 | """Write data to HDF5 files.""" 96 | 97 | def __init__(self, 98 | name: str, 99 | schema: Mapping[str, Sequence[int]], 100 | directory: str = '', 101 | index_key: str = 't', 102 | compression_level: int = 5): 103 | """Initialise H5Writer. 104 | 105 | Args: 106 | name: file name for CSV. 107 | schema: dict of keys, corresponding to each data item . All data is 108 | assumed ot be 32-bit floats. 109 | directory: directory path to write file to. 110 | index_key: name of (integer) key used to index each entry. 111 | compression_level: compression level (0-9) used to compress HDF5 file. 
112 | """ 113 | self._path = os.path.join(directory, name) 114 | self._schema = schema 115 | self._index_key = index_key 116 | self._description = {} 117 | self._file = None 118 | self._complevel = compression_level 119 | 120 | def __enter__(self): 121 | if not self._schema: 122 | return self 123 | pos = 1 124 | self._description[self._index_key] = tables.Int32Col(pos=pos) 125 | for key, shape in self._schema.items(): 126 | pos += 1 127 | self._description[key] = tables.Float32Col(pos=pos, shape=shape) 128 | if not os.path.isdir(os.path.dirname(self._path)): 129 | os.mkdir(os.path.dirname(self._path)) 130 | self._file = tables.open_file( 131 | self._path, 132 | mode='w', 133 | title='Fermi Net Data', 134 | filters=tables.Filters(complevel=self._complevel)) 135 | self._table = self._file.create_table( 136 | where=self._file.root, name='data', description=self._description) 137 | return self 138 | 139 | def write(self, index: int, data): 140 | """Write data to HDF5 file. 141 | 142 | Args: 143 | index: iteration index. 144 | data: dict of arrays to write to file. Only elements with keys in the 145 | schema are written. 146 | """ 147 | if self._file: 148 | h5_data = (index,) 149 | for key in self._description: 150 | if key != self._index_key: 151 | h5_data += (data[key],) 152 | self._table.append([h5_data]) 153 | self._table.flush() 154 | 155 | def __exit__(self, exc_type, exc_val, exc_tb): 156 | if self._file: 157 | self._file.close() 158 | self._file = None 159 | -------------------------------------------------------------------------------- /DeepSolid/curvature_tags_and_blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). 16 | # All Bytedance Modifications are Copyright 2022 Bytedance Inc. 17 | 18 | """Curvature blocks for FermiNet.""" 19 | import functools 20 | from typing import Optional, Mapping, Union 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | from DeepSolid.utils.kfac_ferminet_alpha import curvature_blocks as blocks 25 | from DeepSolid.utils.kfac_ferminet_alpha import layers_and_loss_tags as tags 26 | from DeepSolid.utils.kfac_ferminet_alpha import utils 27 | 28 | 29 | vmap_psd_inv_cholesky = jax.vmap(utils.psd_inv_cholesky, (0, None), 0) 30 | vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0) 31 | 32 | 33 | qmc1_tag = tags.LayerTag("qmc1_tag", 1, 1) 34 | 35 | 36 | def register_qmc1(y, x, w, **kwargs): 37 | return qmc1_tag.bind(y, x, w, **kwargs) 38 | 39 | 40 | qmc2_tag = tags.LayerTag("qmc2_tag", 1, 1) 41 | 42 | 43 | def register_qmc2(y, x, w, **kwargs): 44 | return qmc2_tag.bind(y, x, w, **kwargs) 45 | 46 | 47 | repeated_dense_tag = tags.LayerTag("repeated_dense_tag", 1, 1) 48 | 49 | 50 | def register_repeated_dense(y, x, w, b): 51 | if b is None: 52 | return repeated_dense_tag.bind(y, x, w) 53 | return repeated_dense_tag.bind(y, x, w, b) 54 | 55 | 56 | class QmcBlockedDense(blocks.TwoKroneckerFactored): 57 | """A factor that is the Kronecker product of two matrices.""" 58 | 59 | def update_curvature_inverse_estimate(self, diagonal_weight, pmap_axis_name): 60 | self.inputs_factor.sync(pmap_axis_name) 61 | 62 | 
self.outputs_factor.sync(pmap_axis_name) 63 | vmap_pi_adjusted_inverse = jax.vmap( 64 | functools.partial(utils.pi_adjusted_inverse, 65 | pmap_axis_name=pmap_axis_name), 66 | (0, 0, None), (0, 0) 67 | ) 68 | self.inputs_factor_inverse, self.outputs_factor_inverse = ( 69 | vmap_pi_adjusted_inverse(self.inputs_factor.value, 70 | self.outputs_factor.value, 71 | diagonal_weight / self.extra_scale)) 72 | 73 | def multiply_matpower(self, vec, exp, diagonal_weight): 74 | w, = vec 75 | # kmjn 76 | v = w 77 | k, m, j, n = v.shape 78 | if exp == 1: 79 | inputs_factor = self.inputs_factor.value 80 | outputs_factor = self.outputs_factor.value 81 | scale = self.extra_scale 82 | elif exp == -1: 83 | inputs_factor = self.inputs_factor_inverse 84 | outputs_factor = self.outputs_factor_inverse 85 | scale = 1.0 / self.extra_scale 86 | diagonal_weight = 0.0 87 | else: 88 | raise NotImplementedError() 89 | # jk(mn) 90 | v = jnp.transpose(v, [2, 0, 1, 3]).reshape([j, k, m * n]) 91 | v = vmap_matmul(inputs_factor, v) 92 | v = vmap_matmul(v, outputs_factor) 93 | # kmjn 94 | v = jnp.transpose(v.reshape([j, k, m, n]), [1, 2, 0, 3]) 95 | v = v * scale + diagonal_weight * w 96 | return (v,) 97 | 98 | def update_curvature_matrix_estimate( 99 | self, 100 | info: blocks._BlockInfo, # pylint: disable=protected-access 101 | batch_size: int, 102 | ema_old: Union[float, jnp.ndarray], 103 | ema_new: Union[float, jnp.ndarray], 104 | pmap_axis_name: str 105 | ) -> None: 106 | (x,), (dy,) = info["inputs"], info["outputs_tangent"] 107 | assert batch_size == x.shape[0] 108 | normalizer = x.shape[0] * x.shape[1] 109 | # The forward computation is 110 | # einsum(x,w): bijk,bkmjn -> bijmn 111 | inputs_cov = jnp.einsum("bijk,bijl->jkl", x, x) / normalizer 112 | dy = jnp.reshape(dy, dy.shape[:-2] + (-1,)) 113 | outputs_cov = jnp.einsum("bijk,bijl->jkl", dy, dy) / normalizer 114 | self.inputs_factor.update(inputs_cov, ema_old, ema_new) 115 | self.outputs_factor.update(outputs_cov, ema_old, ema_new) 116 | 117 | 
def init(self, rng): 118 | del rng 119 | k, m, j, n = self.params_shapes[0] 120 | return dict( 121 | inputs_factor=utils.WeightedMovingAverage.zero([j, k, k]), 122 | inputs_factor_inverse=jnp.zeros([j, k, k]), 123 | outputs_factor=utils.WeightedMovingAverage.zero([j, m * n, m * n]), 124 | outputs_factor_inverse=jnp.zeros([j, m * n, m * n]), 125 | extra_scale=jnp.asarray(m) 126 | ) 127 | 128 | def input_size(self) -> int: 129 | raise NotImplementedError() 130 | 131 | def output_size(self) -> int: 132 | raise NotImplementedError() 133 | 134 | 135 | class RepeatedDenseBlock(blocks.DenseTwoKroneckerFactored): 136 | """Dense block that is repeated.""" 137 | 138 | def compute_extra_scale(self) -> Optional[jnp.ndarray]: 139 | (x_shape,) = self.inputs_shapes 140 | return utils.product(x_shape) // (x_shape[0] * x_shape[-1]) 141 | 142 | def update_curvature_matrix_estimate( 143 | self, 144 | info: Mapping[str, blocks._Arrays], # pylint: disable=protected-access 145 | batch_size: int, 146 | ema_old: Union[float, jnp.ndarray], 147 | ema_new: Union[float, jnp.ndarray], 148 | pmap_axis_name: str 149 | ) -> None: 150 | info = dict(**info) 151 | (x,), (dy,) = info["inputs"], info["outputs_tangent"] 152 | assert x.shape[0] == batch_size 153 | info["inputs"] = (x.reshape([-1, x.shape[-1]]),) 154 | info["outputs_tangent"] = (dy.reshape([-1, dy.shape[-1]]),) 155 | super().update_curvature_matrix_estimate(info, x.size // x.shape[-1], 156 | ema_old, ema_new, pmap_axis_name) 157 | 158 | 159 | blocks.set_default_tag_to_block("qmc1_tag", QmcBlockedDense) 160 | blocks.set_default_tag_to_block("repeated_dense_tag", RepeatedDenseBlock) 161 | -------------------------------------------------------------------------------- /DeepSolid/supercell.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Lucas K Wagner 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this 
software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). 24 | # All Bytedance Modifications are Copyright 2022 Bytedance Inc. 25 | 26 | import logging 27 | 28 | import numpy as np 29 | import pyscf.pbc.gto 30 | 31 | 32 | def get_supercell_kpts(supercell): 33 | """ 34 | 35 | :param supercell: pyscf object of simulation cell. 
36 | :return:supercell k points which belong to the unit box primitive cell k point space 37 | """ 38 | Sinv = np.linalg.inv(supercell.S).T 39 | u = [0, 1] 40 | unit_box = np.stack([x.ravel() for x in np.meshgrid(*[u] * 3, indexing="ij")]).T 41 | unit_box_ = np.dot(unit_box, supercell.S.T) 42 | xyz_range = np.stack([f(unit_box_, axis=0) for f in (np.amin, np.amax)]).T 43 | kptmesh = np.meshgrid(*[np.arange(*r) for r in xyz_range], indexing="ij") 44 | possible_kpts = np.dot(np.stack([x.ravel() for x in kptmesh]).T, Sinv) 45 | in_unit_box = (possible_kpts >= 0) * (possible_kpts < 1 - 1e-12) 46 | select = np.where(np.all(in_unit_box, axis=1))[0] 47 | reclatvec = np.linalg.inv(supercell.original_cell.lattice_vectors()).T * 2 * np.pi 48 | return np.dot(possible_kpts[select], reclatvec) 49 | 50 | 51 | def get_supercell_copies(latvec, S): 52 | Sinv = np.linalg.inv(S).T 53 | u = [0, 1] 54 | unit_box = np.stack([x.ravel() for x in np.meshgrid(*[u] * 3, indexing="ij")]).T 55 | unit_box_ = np.dot(unit_box, S) 56 | xyz_range = np.stack([f(unit_box_, axis=0) for f in (np.amin, np.amax)]).T 57 | mesh = np.meshgrid(*[np.arange(*r) for r in xyz_range], indexing="ij") 58 | possible_pts = np.dot(np.stack([x.ravel() for x in mesh]).T, Sinv.T) 59 | in_unit_box = (possible_pts >= 0) * (possible_pts < 1 - 1e-12) 60 | select = np.where(np.all(in_unit_box, axis=1))[0] 61 | return np.linalg.multi_dot((possible_pts[select], S, latvec)) 62 | 63 | 64 | def get_supercell(cell, S, sym_type='minimal') -> pyscf.pbc.gto.Cell: 65 | """ 66 | generate supercell from primitive cell with S specified 67 | 68 | :param cell: pyscf Cell object 69 | :param S: (3, 3) supercell matrix for QMC from cell defined by cell.a. 
70 | :return: QMC simulation cell 71 | """ 72 | import pyscf.pbc 73 | scale = np.abs(int(np.round(np.linalg.det(S)))) 74 | superlattice = np.dot(S, cell.lattice_vectors()) 75 | Rpts = get_supercell_copies(cell.lattice_vectors(), S) 76 | atom = [] 77 | for (name, xyz) in cell._atom: 78 | atom.extend([(name, xyz + R) for R in Rpts]) 79 | supercell = pyscf.pbc.gto.Cell() 80 | supercell.a = superlattice 81 | supercell.atom = atom 82 | supercell.ecp = cell.ecp 83 | supercell.basis = cell.basis 84 | supercell.exp_to_discard = cell.exp_to_discard 85 | supercell.unit = "Bohr" 86 | supercell.spin = cell.spin * scale 87 | supercell.build() 88 | supercell.original_cell = cell 89 | supercell.S = S 90 | supercell.scale = scale 91 | supercell.output = None 92 | supercell.stdout = None 93 | supercell = set_symmetry_lat(supercell, sym_type) 94 | logging.info(f'Use {sym_type} type feature.') 95 | return supercell 96 | 97 | 98 | def set_symmetry_lat(supercell, sym_type='minimal'): 99 | ''' 100 | Attach corresponding lattice vectors to the simulation cell. 101 | 102 | :param supercell: 103 | :param sym_type:specify the symmetry of constructed distance feature, 104 | Minimal is used as default, and other type hasn't been tested. 105 | :return: simulation cell with symmetry specified. 
106 | ''' 107 | prim_bv = supercell.original_cell.reciprocal_vectors() 108 | sim_bv = supercell.reciprocal_vectors() 109 | if sym_type == 'minimal': 110 | mat = np.eye(3) 111 | elif sym_type == 'fcc': 112 | mat = np.array([[1, 0, 0], 113 | [0, 1, 0], 114 | [0, 0, 1], 115 | [1, 1, 1]]) 116 | elif sym_type == 'bcc': 117 | mat = np.array([[1, 0, 0], 118 | [0, 1, 0], 119 | [0, 0, 1], 120 | [1, -1, 0], 121 | [1, 0, -1], 122 | [0, 1, -1]]) 123 | elif sym_type == 'hexagonal': 124 | mat = np.array([[1, 0, 0], 125 | [0, 1, 0], 126 | [0, 0, 1], 127 | [-1, -1, 0]]) 128 | else: 129 | mat = np.eye(3) 130 | 131 | prim_bv = mat @ prim_bv 132 | sim_bv = mat @ sim_bv 133 | 134 | prim_av = np.linalg.pinv(prim_bv).T 135 | sim_av = np.linalg.pinv(sim_bv).T 136 | supercell.BV = sim_bv 137 | supercell.AV = sim_av 138 | supercell.original_cell.BV = prim_bv 139 | supercell.original_cell.AV = prim_av 140 | return supercell 141 | 142 | 143 | def get_k_indices(cell, mf, kpts, tol=1e-6): 144 | """Given a list of kpts, return inds such that mf.kpts[inds] is a list of kpts equivalent to the input list""" 145 | kdiffs = mf.kpts[None] - kpts[:, None] 146 | frac_kdiffs = np.dot(kdiffs, cell.lattice_vectors().T) / (2 * np.pi) 147 | kdiffs = np.mod(frac_kdiffs + 0.5, 1) - 0.5 148 | return np.nonzero(np.linalg.norm(kdiffs, axis=-1) < tol)[1] 149 | -------------------------------------------------------------------------------- /DeepSolid/utils/kfac_ferminet_alpha/example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of running KFAC."""
from absl import app
from absl import flags
import jax
import jax.numpy as jnp

import numpy as np
import DeepSolid.utils.kfac_ferminet_alpha as kfac_ferminet_alpha
from DeepSolid.utils.kfac_ferminet_alpha import utils


TRAINING_STEPS = flags.DEFINE_integer(
    name="training_steps",
    default=100,
    help="Number of training steps to perform")
BATCH_SIZE = flags.DEFINE_integer(
    name="batch_size", default=128, help="Batch size")
LEARNING_RATE = flags.DEFINE_float(
    name="learning_rate", default=1e-3, help="Learning rate")
L2_REG = flags.DEFINE_float(
    name="l2_reg", default=1e-3, help="L2 regularization coefficient")
MOMENTUM = flags.DEFINE_float(
    name="momentum", default=0.8, help="Momentum coefficient")
DAMPING = flags.DEFINE_float(
    name="damping", default=1e-2, help="Damping coefficient")
MULTI_DEVICE = flags.DEFINE_bool(
    name="multi_device",
    default=False,
    help="Whether the computation should be replicated across multiple devices")
SEED = flags.DEFINE_integer(name="seed", default=12412321, help="JAX RNG seed")


def glorot_uniform(shape, key):
    """Sample a Glorot/Xavier-uniform weight array of the given shape."""
    dim_in = np.prod(shape[:-1])
    dim_out = shape[-1]
    c = jnp.sqrt(6 / (dim_in + dim_out))
    return jax.random.uniform(key, shape=shape, minval=-c, maxval=c)


def fully_connected_layer(params, x):
    """Affine layer: x @ w + b, with params = (w, b)."""
    w, b = params
    return jnp.matmul(x, w) + b[None]


def model_init(rng_key, batch, encoder_sizes=(1000, 500, 250, 30)):
    """Initialize the standard autoencoder.

    The decoder mirrors the encoder without repeating the bottleneck size.
    Returns (params, None); params is a list of (w, b) tuples.
    """
    x_size = batch.shape[-1]
    decoder_sizes = encoder_sizes[len(encoder_sizes) - 2::-1]
    sizes = (x_size,) + encoder_sizes + decoder_sizes + (x_size,)
    keys = jax.random.split(rng_key, len(sizes) - 1)
    params = []
    for rng_key, dim_in, dim_out in zip(keys, sizes, sizes[1:]):
        # Glorot uniform initialization
        w = glorot_uniform((dim_in, dim_out), rng_key)
        b = jnp.zeros([dim_out])
        params.append((w, b))
    return params, None


def model_loss(params, inputs, l2_reg):
    """Evaluate the standard autoencoder.

    Returns (regularized_loss, {"mean_squared_error": ...}).
    """
    h = inputs.reshape([inputs.shape[0], -1])
    for i, layer_params in enumerate(params):
        h = fully_connected_layer(layer_params, h)
        # Last layer does not have a nonlinearity
        if i % 4 != 3:
            h = jnp.tanh(h)
    # jax.tree_leaves was removed from recent JAX; tree_util works everywhere.
    l2_value = 0.5 * sum(jnp.square(p).sum() for p in jax.tree_util.tree_leaves(params))
    error = jax.nn.sigmoid(h) - inputs.reshape([inputs.shape[0], -1])
    mean_squared_error = jnp.mean(jnp.sum(error * error, axis=1), axis=0)
    regularized_loss = mean_squared_error + l2_reg * l2_value

    return regularized_loss, dict(mean_squared_error=mean_squared_error)


def random_data(multi_device, batch_shape, rng):
    """Yield an endless stream of standard-normal batches.

    Args:
      multi_device: falsy for single-device runs. ``True`` prepends a leading
        axis of size ``jax.local_device_count()`` so batches can be fed to
        ``jax.pmap``; an integer is used as an explicit device count.
      batch_shape: per-device batch shape.
      rng: JAX PRNG key, split afresh for every batch.

    Yields:
      Arrays of shape ``(num_devices, *batch_shape)`` or ``batch_shape``.
    """
    shape = tuple(batch_shape)
    if multi_device:
        # bool is an int subclass: without this check True would become a
        # device axis of length 1 and break pmap on multi-device hosts.
        if isinstance(multi_device, bool):
            num_devices = jax.local_device_count()
        else:
            num_devices = int(multi_device)
        shape = (num_devices,) + shape
    while True:
        rng, key = jax.random.split(rng)
        yield jax.random.normal(key, shape)


def main(argv):
    """Train the toy autoencoder with KFAC on random data."""
    del argv  # Unused.

    learning_rate = jnp.asarray([LEARNING_RATE.value])
    momentum = jnp.asarray([MOMENTUM.value])
    damping = jnp.asarray([DAMPING.value])

    # RNG keys
    global_step = jnp.zeros([])
    rng = jax.random.PRNGKey(SEED.value)
    params_key, opt_key, step_key, data_key = jax.random.split(rng, 4)
    dataset = random_data(MULTI_DEVICE.value, (BATCH_SIZE.value, 20), data_key)
    example_batch = next(dataset)

    if MULTI_DEVICE.value:
        global_step = utils.replicate_all_local_devices(global_step)
        learning_rate = utils.replicate_all_local_devices(learning_rate)
        momentum = utils.replicate_all_local_devices(momentum)
        damping = utils.replicate_all_local_devices(damping)
        params_key, opt_key = utils.replicate_all_local_devices(
            (params_key, opt_key))
        step_key = utils.make_different_rng_key_on_all_devices(step_key)
        split_key = jax.pmap(lambda x: tuple(jax.random.split(x)))
        jit_init_parameters_func = jax.pmap(model_init)
    else:
        split_key = jax.random.split
        jit_init_parameters_func = jax.jit(model_init)

    # Initialize or load parameters
    params, func_state = jit_init_parameters_func(params_key, example_batch)

    # Make optimizer
    optim = kfac_ferminet_alpha.Optimizer(
        value_and_grad_func=jax.value_and_grad(
            lambda p, x: model_loss(p, x, L2_REG.value), has_aux=True),
        l2_reg=L2_REG.value,
        value_func_has_aux=True,
        value_func_has_state=False,
        value_func_has_rng=False,
        learning_rate_schedule=None,
        momentum_schedule=None,
        damping_schedule=None,
        norm_constraint=1.0,
        num_burnin_steps=10,
    )

    # Initialize optimizer
    opt_state = optim.init(params, opt_key, example_batch, func_state)

    for t in range(TRAINING_STEPS.value):
        step_key, key_t = split_key(step_key)
        params, opt_state, stats = optim.step(
            params,
            opt_state,
            key_t,
            dataset,
            learning_rate=learning_rate,
            momentum=momentum,
            damping=damping)
        global_step = global_step + 1

        # Log any of the statistics
        print(f"iteration: {t}")
        print(f"mini-batch loss = {stats['loss']}")
        if "aux" in stats:
            for k, v in stats["aux"].items():
                print(f"{k} = {v}")
        print("----")


if __name__ == "__main__":
    app.run(main)
# -------------------------------------------------------------------------------- /DeepSolid/checkpoint.py: --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import datetime
import os
from typing import Optional
import zipfile

from absl import logging
import jax
import numpy as np


def get_restore_path(restore_path: Optional[str] = None) -> Optional[str]:
    """Gets the path containing checkpoints from a previous calculation.

    Args:
      restore_path: path to checkpoints.

    Returns:
      The path or None if restore_path is falsy.
    """
    return restore_path if restore_path else None


def find_last_checkpoint(ckpt_path: Optional[str] = None) -> Optional[str]:
    """Finds most recent valid checkpoint in a directory.

    Args:
      ckpt_path: Directory containing checkpoints.

    Returns:
      Last QMC checkpoint (ordered by sorting all checkpoints by name in reverse)
      or None if no valid checkpoint is found or ckpt_path is not given or doesn't
      exist. A checkpoint is regarded as not valid if it cannot be read
      successfully using np.load.
    """
    if ckpt_path and os.path.exists(ckpt_path):
        files = [f for f in os.listdir(ckpt_path) if 'qmcjax_ckpt_' in f]
        # Handle case where last checkpoint is corrupt/empty.
        for file in sorted(files, reverse=True):
            fname = os.path.join(ckpt_path, file)
            with open(fname, 'rb') as f:
                try:
                    # allow_pickle: checkpoints hold pickled pytrees, so only
                    # point this at trusted directories.
                    np.load(f, allow_pickle=True)
                    return fname
                except (OSError, EOFError, zipfile.BadZipFile):
                    logging.info('Error loading checkpoint %s. Trying next checkpoint...',
                                 fname)
    return None


def create_save_path(save_path: Optional[str],) -> str:
    """Creates the directory for saving checkpoints, if it doesn't exist.

    Args:
      save_path: directory to use. If false, create a directory in the working
        directory based upon the current time.

    Returns:
      Path to save checkpoints to.
    """
    timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    default_save_path = os.path.join(os.getcwd(), f'DeepSolid_{timestamp}')
    ckpt_save_path = save_path or default_save_path

    if ckpt_save_path and not os.path.isdir(ckpt_save_path):
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(ckpt_save_path, exist_ok=True)

    return ckpt_save_path


def save(save_path: str, t: int, data, params, opt_state, mcmc_width,
         remote_save_path: Optional[str] = None) -> str:
    """Saves checkpoint information to a npz file.

    Args:
      save_path: path to directory to save checkpoint to. The checkpoint file is
        save_path/qmcjax_ckpt_$t.npz, where $t is the number of completed
        iterations.
      t: number of completed iterations.
      data: MCMC walker configurations.
      params: pytree of network parameters.
      opt_state: optimization state.
      mcmc_width: width to use in the MCMC proposal distribution.
      remote_save_path: accepted for interface compatibility; not used here.

    Returns:
      path to checkpoint file.
    """
    ckpt_filename = os.path.join(save_path, f'qmcjax_ckpt_{t:06d}.npz')
    logging.info('Saving checkpoint %s', ckpt_filename)
    with open(ckpt_filename, 'wb') as f:
        np.savez(
            f,
            t=t,
            data=data,
            params=params,
            opt_state=opt_state,
            mcmc_width=mcmc_width)

    return ckpt_filename


def restore(restore_filename: str, batch_size: Optional[int] = None, shape_check=True):
    """Restores data saved in a checkpoint.

    Args:
      restore_filename: filename containing checkpoint. It is loaded with
        allow_pickle=True, so it must come from a trusted source.
      batch_size: total batch size to be used. If present, check the data saved in
        the checkpoint is consistent with the batch size requested for the
        calculation.
      shape_check: if False, skip both the device-count and batch-size checks.

    Returns:
      (t, data, params, opt_state, mcmc_width) tuple, where
      t: number of completed iterations.
      data: MCMC walker configurations.
      params: pytree of network parameters.
      opt_state: optimization state.
      mcmc_width: width to use in the MCMC proposal distribution.

    Raises:
      ValueError: if the leading dimension of data does not match the number of
        devices (i.e. the number of devices being parallelised over has changed) or
        if the total batch size is not equal to the number of MCMC configurations in
        data.
    """
    logging.info('Loading checkpoint %s', restore_filename)
    with open(restore_filename, 'rb') as f:
        ckpt_data = np.load(f, allow_pickle=True)
        # Retrieve data from npz file. Non-array variables need to be converted back
        # to natives types using .tolist().
        t = ckpt_data['t'].tolist() + 1  # Return the iterations completed.
        data = ckpt_data['data']
        params = ckpt_data['params'].tolist()
        opt_state = ckpt_data['opt_state'].tolist()
        mcmc_width = ckpt_data['mcmc_width'].tolist()
        if shape_check:
            if data.shape[0] != jax.local_device_count():
                raise ValueError(
                    'Incorrect number of devices found. Expected {}, found {}.'.format(
                        data.shape[0], jax.local_device_count()))
            if batch_size and data.shape[0] * data.shape[1] != batch_size:
                raise ValueError(
                    'Wrong batch size in loaded data. Expected {}, found {}.'.format(
                        batch_size, data.shape[0] * data.shape[1]))
    return t, data, params, opt_state, mcmc_width
# -------------------------------------------------------------------------------- /DeepSolid/base_config.py: --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import ml_collections
from ml_collections import config_dict


def default() -> ml_collections.ConfigDict:
    """Create set of default parameters for running qmc.py.

    Note: cfg.system.pyscf_cell is a None placeholder here; it presumably has to
    be replaced with the simulation cell by a system-specific config before use.

    Returns:
      ml_collections.ConfigDict containing default settings.
    """
    # wavefunction output.
    cfg = ml_collections.ConfigDict({
        'batch_size': 100,  # batch size
        # Config module used. Should be set in get_config function as either the
        # absolute module or relative to the configs subdirectory. Relative
        # imports must start with a '.' (e.g. .atom). Do *not* override on
        # command-line. Do *not* set using __name__ from inside a get_config
        # function, as config_flags overrides this when importing the module using
        # importlib.import_module.
        'config_module': __name__,
        'use_x64': True,  # use float64 (True) or float32 (False)
        'optim': {
            'iterations': 1000000,  # number of iterations
            'optimizer': 'kfac',
            'local_energy_outlier_width': 5.0,
            'lr': {
                # Learning rate. Differs from the lr reported in FermiNet because
                # the DeepSolid energy gradient is not batch-size dependent.
                'rate': 5.e-2,
                'decay': 1.0,  # exponent of learning rate decay
                'delay': 10000.0,  # term that sets the scale of the rate decay
            },
            'clip_el': 5.0,  # If not none, scale at which to clip local energy
            'clip_type': 'real',  # Clip real and imag part of gradient.
            'gradient_clip': 5.0,
            # ADAM hyperparameters. See optax documentation for details.
            'adam': {
                'b1': 0.9,
                'b2': 0.999,
                'eps': 1.e-8,
                'eps_root': 0.0,
            },
            'kfac': {
                'invert_every': 1,
                'cov_update_every': 1,
                'damping': 0.001,
                'cov_ema_decay': 0.95,
                'momentum': 0.0,
                'momentum_type': 'regular',
                # Warning: adaptive damping is not currently available.
                'min_damping': 1.e-4,
                'norm_constraint': 0.001,
                'mean_center': True,
                'l2_reg': 0.0,
                'register_only_generic': False,
            },
            'ministeps': 1,
            # Laplacian evaluation mode, one of 'for', 'partition' or 'hessian':
            #  - 'for' computes the laplacian one electron at a time: slow, but
            #    saves GPU memory.
            #  - 'hessian' computes it in a highly parallelized way: fast, but
            #    requires more GPU memory.
            #  - 'partition' is a compromise between the two.
            'laplacian_mode': 'for',
            # Only used in 'partition' mode. partition_number must be divisible
            # by (dim * number of electrons). The smaller the faster, but
            # requires more memory.
            # NOTE(review): if partition_number is the number of chunks, the
            # divisibility runs the other way; confirm in the laplacian code.
            'partition_number': 3,
        },
        'log': {
            'stats_frequency': 1,  # iterations between logging of stats
            'save_frequency': 10.0,  # minutes between saving network params
            'save_frequency_in_step': -1,
            'save_path': '',  # local save path
            'restore_path': '',  # path containing saved model parameters to restore
            'local_energies': False,
            # Log the polarization order parameter, useful for hydrogen chains.
            'complex_polarization': False,
            # Return the structure factor S(k) at reciprocal lattice points of the
            # supercell. Logging S(k) requires a lot of storage space, be careful.
            'structure_factor': False,
            'stats_file_name': 'train_stats'
        },
        'system': {
            'pyscf_cell': None,  # simulation cell object
            'ndim': 3,  # dimension of the system
            'internal_cell': None,
        },
        'mcmc': {
            # Note: HMC options are not currently used.
            # Number of burn in steps after pretraining. If zero do not burn in
            # or reinitialize walkers.
            'burn_in': 100,
            'steps': 20,  # Number of MCMC steps to make between network updates.
            # Width of (atom-centred) Gaussian used to generate initial electron
            # configurations.
            'init_width': 0.8,
            # Width of Gaussian used for random moves for RMW or step size for
            # HMC.
            'move_width': 0.02,
            # Number of steps after which to update the adaptive MCMC step size
            'adapt_frequency': 100,
            'init_means': (),  # Not implemented in JAX.
            # Whether to use importance sampling in the MCMC step (untested);
            # Metropolis sampling is used if False.
            'importance_sampling': False,
            # If true, use one-electron moves (untested).
            'one_electron': False
        },
        'network': {
            'detnet': {
                'envelope_type': 'isotropic',  # only isotropic mode has been tested
                'bias_orbitals': False,
                'use_last_layer': False,
                'full_det': False,
                'hidden_dims': ((256, 32), (256, 32), (256, 32)),
                'determinants': 8,
                'after_determinants': 1,
                'distance_type': 'nu',
            },
            # Twist of the wavefunction, given in fractions of the supercell
            # reciprocal vectors.
            'twist': (0.0, 0.0, 0.0),
        },
        'debug': {
            # Check optimizer state, parameters and loss and raise an exception if
            # NaN is found.
            'check_nan': False,  # check whether the gradient contains NaNs before optimizing; if True, retry.
            'deterministic': False,  # Use a deterministic seed.
        },
        'pretrain': {
            'method': 'net',  # Method is one of 'hf', 'net'.
            'iterations': 1000,
            'lr': 3e-4,
            'steps': 1,  # MCMC steps between pretraining iterations
        },
    })

    return cfg


def resolve(cfg):
    """Return a copy of cfg with all field references resolved."""
    cfg = cfg.copy_and_resolve_references()
    return cfg
# -------------------------------------------------------------------------------- /DeepSolid/distance.py: --------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2019 Lucas K Wagner
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

from functools import partial
import logging

import jax
import jax.numpy as jnp
import numpy as np


class MinimalImageDistance:
    """Compute minimal image distance between particles and its images."""

    def __init__(self, latvec, verbose=0):
        """
        Pick the cheapest minimal-image routine that is valid for latvec.

        :param latvec: array with shape [3,3], each row with a lattice vector
        :param verbose: the detected lattice type is logged only when verbose == 0
        """
        latvec = jnp.asarray(latvec)
        ortho_tol = 1e-10
        diagonal = jnp.all(jnp.abs(latvec - jnp.diag(jnp.diagonal(latvec))) < ortho_tol)
        if diagonal:
            self.dist_i = self.diagonal_dist_i
            if verbose == 0:
                logging.info("Diagonal lattice vectors")
        else:
            # Mutually orthogonal rows <=> off-diagonal Gram entries vanish.
            # The absolute value is essential: obtuse lattices (e.g. hexagonal,
            # a1.a2 < 0) must NOT be treated as orthogonal, because fractional
            # wrapping does not give the minimal image for them.
            gram = latvec @ latvec.T
            gram_off_diagonal = gram - jnp.diag(jnp.diagonal(gram))
            orthogonal = jnp.all(jnp.abs(gram_off_diagonal) < ortho_tol)
            if orthogonal:
                self.dist_i = self.orthogonal_dist_i
                if verbose == 0:
                    logging.info("Orthogonal lattice vectors")
            else:
                self.dist_i = self.general_dist_i
                if verbose == 0:
                    logging.info("Non-orthogonal lattice vectors")
        self._latvec = latvec
        self._invvec = jnp.linalg.inv(latvec)
        self.dim = self._latvec.shape[-1]
        # list of all 26 neighboring cells (plus the home cell), as integer offsets
        mesh_grid = jnp.meshgrid(*[jnp.array([0, 1, 2]) for _ in range(3)])
        self.point_list = jnp.stack([m.ravel() for m in mesh_grid], axis=0).T - 1
        self.shifts = self.point_list @ self._latvec

    def general_dist_i(self, configs, vec, return_wrap=False):
        """
        calculate minimal distance between electron and ion in the most general lattice vector

        Every displacement is tried against the 27 nearest cell shifts and the
        shortest one is kept.

        :param configs: ion coordinate with shape [N_atom * 3]
        :param vec: electron coordinate with shape [N_ele * 3]
        :param return_wrap: if True, also return the applied cell offsets (negated).
        :return: minimal image distance between electron and atom with shape [N_ele, N_atom, 3]
        """
        configs = configs.reshape([1, -1, self.dim])
        v = vec.reshape([-1, 1, self.dim])
        d1 = v - configs
        shifts = self.shifts.reshape((-1, *[1] * (len(d1.shape) - 1), 3))
        d1all = d1[None] + shifts
        dists = jnp.linalg.norm(d1all, axis=-1)
        mininds = jnp.argmin(dists, axis=0)
        inds = jnp.meshgrid(*[jnp.arange(n) for n in mininds.shape], indexing='ij')
        if return_wrap:
            return d1all[(mininds, *inds)], -self.point_list[mininds]
        return d1all[(mininds, *inds)]

    def orthogonal_dist_i(self, configs, vec, return_wrap=False):
        """
        calculate minimal distance between electron and ion in the orthogonal lattice vector

        :param configs: ion coordinate with shape [N_atom * 3]
        :param vec: electron coordinate with shape [N_ele * 3]
        :param return_wrap: if True, also return the (negated) number of cells wrapped.
        :return: minimal image distance between electron and atom with shape [N_ele, N_atom, 3]
        """
        configs = configs.reshape([1, -1, self.dim]).real
        v = vec.reshape([-1, 1, self.dim]).real
        d1 = v - configs
        frac_disps = jnp.einsum("...ij,jk->...ik", d1, self._invvec)
        # Wrap fractional displacements into [-0.5, 0.5).
        replace_frac_disps = (frac_disps + 0.5) % 1 - 0.5
        if not return_wrap:
            return jnp.einsum("...ij,jk->...ik", replace_frac_disps, self._latvec)
        wrap = -((frac_disps + 0.5) // 1)
        return jnp.einsum("...ij,jk->...ik", replace_frac_disps, self._latvec), wrap

    def diagonal_dist_i(self, configs, vec, return_wrap=False):
        """
        calculate minimal distance between electron and ion in the diagonal lattice vector

        :param configs: ion coordinate with shape [N_atom * 3]
        :param vec: electron coordinate with shape [N_ele * 3]
        :param return_wrap: if True, also return the (negated) number of cells wrapped.
        :return: minimal image distance between electron and atom with shape [N_ele, N_atom, 3]
        """
        configs = configs.reshape([1, -1, self.dim]).real
        v = vec.reshape([-1, 1, self.dim]).real
        d1 = v - configs
        latvec_diag = jnp.diagonal(self._latvec)
        replace_d1 = (d1 + latvec_diag / 2) % latvec_diag - latvec_diag / 2
        if not return_wrap:
            return replace_d1
        ## minus applies after //, order of // and - sign matters
        wrap = -((d1 + latvec_diag / 2) // latvec_diag)
        return replace_d1, wrap

    def dist_matrix(self, configs):
        """
        calculate minimal distance between electrons

        :param configs: electron coordinate with shape [N_ele * 3]
        :return: vs: electron coordinate diffs with shape [N_ele, N_ele, 3],
            with the self-pair (diagonal) entries set to zero
        """
        vs = self.dist_i(configs, configs)
        vs = vs * (1 - jnp.eye(vs.shape[0]))[..., None]

        return vs


@partial(jax.vmap, in_axes=(None, 0), out_axes=0)
def enforce_pbc(latvec, epos):
    """
    Enforces periodic boundary conditions on a set of configs.

    Vectorized over the leading axis of epos; latvec is shared.

    :param latvec: non-singular lattice vectors (one per row) defining the 3D torus: (3,3)
    :param epos: attempted new electron coordinates: (N_ele * 3)
    :return: final electron coordinates with PBCs imposed: (N_ele * 3), and the
        integer number of cells wrapped per electron and direction: (N_ele, 3)
    """

    # Writes epos in terms of (lattice vecs) fractional coordinates
    dim = latvec.shape[-1]
    epos = epos.reshape(-1, dim)
    recpvecs = jnp.linalg.inv(latvec)
    epos_lvecs_coord = jnp.einsum("ij,jk->ik", epos, recpvecs)

    wrap, frac_in_cell = jnp.divmod(epos_lvecs_coord, 1)
    final_epos = jnp.matmul(frac_in_cell, latvec).ravel()
    return final_epos, wrap


def np_enforce_pbc(latvec, epos):
    """
    Enforces periodic boundary conditions on a set of configs. Used in float 32 mode.

    NumPy counterpart of enforce_pbc for a single configuration.

    :param latvec: non-singular lattice vectors (one per row) defining the 3D torus: (3,3)
    :param epos: attempted new electron coordinates: (N_ele * 3)
    :return: final electron coordinates with PBCs imposed: (N_ele * 3), and the
        integer number of cells wrapped per electron and direction: (N_ele, 3)
    """

    # Writes epos in terms of (lattice vecs) fractional coordinates
    dim = latvec.shape[-1]
    epos = epos.reshape(-1, dim)
    recpvecs = np.linalg.inv(latvec)
    epos_lvecs_coord = np.einsum("ij,jk->ik", epos, recpvecs)

    wrap, frac_in_cell = np.divmod(epos_lvecs_coord, 1)
    final_epos = np.matmul(frac_in_cell, latvec).ravel()
    return final_epos, wrap
# -------------------------------------------------------------------------------- /DeepSolid/ewaldsum.py: --------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2019 Lucas K Wagner
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import logging

import jax
import jax.numpy as jnp
from DeepSolid import distance


class EwaldSum:
    """Ewald summation of electron-electron, electron-ion and ion-ion Coulomb energies."""

    def __init__(self, cell, ewald_gmax=200, nlatvec=1):
        """
        :parameter cell: pyscf Cell object (simulation cell)
        :parameter int ewald_gmax: how far to take reciprocal sum; probably never needs to be changed.
        :parameter int nlatvec: how far to take real-space sum; probably never needs to be changed.
        """
        self.nelec = cell.nelec
        self.atom_coords = jnp.asarray(cell.atom_coords())
        self.atom_charges = jnp.asarray(cell.atom_charges())
        self.latvec = jnp.asarray(cell.lattice_vectors())
        self.dist = distance.MinimalImageDistance(self.latvec)
        self.set_lattice_displacements(nlatvec)
        self.set_up_reciprocal_ewald_sum(ewald_gmax)

    def set_lattice_displacements(self, nlatvec):
        """
        Generates list of lattice-vector displacements to add together for real-space sum

        :parameter int nlatvec: sum goes from `-nlatvec` to `nlatvec` in each lattice direction.
        """
        XYZ = jnp.meshgrid(*[jnp.arange(-nlatvec, nlatvec + 1)] * 3, indexing="ij")
        xyz = jnp.stack(XYZ, axis=-1).reshape((-1, 3))
        self.lattice_displacements = jnp.asarray(jnp.dot(xyz, self.latvec))

    def set_up_reciprocal_ewald_sum(self, ewald_gmax):
        """Choose alpha and collect the reciprocal vectors (with weights) kept in the sum."""
        # |det|: left-handed lattice vectors have a negative determinant, which
        # would flip the sign of every G weight (and of the constants) and make
        # select_big discard all reciprocal points.
        cellvolume = jnp.abs(jnp.linalg.det(self.latvec))
        recvec = jnp.linalg.inv(self.latvec).T

        # Determine alpha
        smallestheight = jnp.amin(1 / jnp.linalg.norm(recvec, axis=1))
        self.alpha = 5.0 / smallestheight
        logging.info(f"Setting Ewald alpha to {self.alpha.item()}")

        # Determine G points to include in reciprocal Ewald sum. Only one half
        # space is enumerated: (x > 0), (x == 0, y > 0), (x == y == 0, z > 0).
        gptsXpos = jnp.meshgrid(
            jnp.arange(1, ewald_gmax + 1),
            *[jnp.arange(-ewald_gmax, ewald_gmax + 1)] * 2,
            indexing="ij"
        )
        zero = jnp.asarray([0])
        gptsX0Ypos = jnp.meshgrid(
            zero,
            jnp.arange(1, ewald_gmax + 1),
            jnp.arange(-ewald_gmax, ewald_gmax + 1),
            indexing="ij",
        )
        gptsX0Y0Zpos = jnp.meshgrid(
            zero, zero, jnp.arange(1, ewald_gmax + 1), indexing="ij"
        )
        gs = zip(
            *[
                select_big(x, cellvolume, recvec, self.alpha)
                for x in (gptsXpos, gptsX0Ypos, gptsX0Y0Zpos)
            ]
        )
        self.gpoints, self.gweight = [jnp.concatenate(x, axis=0) for x in gs]
        self.set_ewald_constants(cellvolume)

    def set_ewald_constants(self, cellvolume):
        """Precompute the configuration-independent constants and the ion-ion energy."""
        self.i_sum = jnp.sum(self.atom_charges)
        ii_sum2 = jnp.sum(self.atom_charges ** 2)
        ii_sum = (self.i_sum ** 2 - ii_sum2) / 2

        self.ijconst = -jnp.pi / (cellvolume * self.alpha ** 2)
        self.squareconst = -self.alpha / jnp.sqrt(jnp.pi) + self.ijconst / 2

        self.ii_const = ii_sum * self.ijconst + ii_sum2 * self.squareconst
        self.e_single_test = -self.i_sum * self.ijconst + self.squareconst
        self.ion_ion = self.ewald_ion()

        # An XC correction (cexc / r_s with cexc = 0.36) is deliberately not
        # applied, so that results can be compared to other codes.

    def ee_const(self, ne):
        return ne * (ne - 1) / 2 * self.ijconst + ne * self.squareconst

    def ei_const(self, ne):
        return -ne * self.i_sum * self.ijconst

    def e_single(self, ne):
        return (
            0.5 * (ne - 1) * self.ijconst - self.i_sum * self.ijconst + self.squareconst
        )

    def ewald_ion(self):
        """Ion-ion Ewald energy (real-space + reciprocal parts); also caches self.ion_exp."""
        # Real space part
        if len(self.atom_charges) == 1:
            ion_ion_real = 0
        else:
            ion_distances = self.dist.dist_matrix(self.atom_coords.ravel())
            rvec = ion_distances[None, :, :, :] + self.lattice_displacements[:, None, None, :]
            r = jnp.linalg.norm(rvec, axis=-1)
            charge_ij = self.atom_charges[..., None] * self.atom_charges[None, ...]
            # triu(k=1) keeps each distinct pair once and drops the i == i terms.
            ion_ion_real = jnp.sum(jnp.triu(charge_ij * jax.lax.erfc(self.alpha * r) / r, k=1))
        # Reciprocal space part
        GdotR = jnp.dot(self.gpoints, jnp.asarray(self.atom_coords.T))
        self.ion_exp = jnp.dot(jnp.exp(1j * GdotR), self.atom_charges)
        ion_ion_rec = jnp.dot(self.gweight, jnp.abs(self.ion_exp) ** 2)

        ion_ion = ion_ion_real + ion_ion_rec
        return ion_ion

    def _real_cij(self, dists):
        """Sum of erfc(alpha r) / r over lattice displacements for each pair in dists."""
        r = dists[:, :, None, :] + self.lattice_displacements
        r = jnp.linalg.norm(r, axis=-1)
        cij = jnp.sum(jax.lax.erfc(self.alpha * r) / r, axis=-1)
        return cij

    def ewald_electron(self, configs):
        """Return (ee, ei) configuration-dependent Ewald energies, without constants."""
        nelec = sum(self.nelec)

        # Real space electron-ion part
        # ei_distances shape (elec, atom, dim)
        ei_distances = self.dist.dist_i(self.atom_coords.ravel(), configs)
        ei_cij = self._real_cij(ei_distances)
        ei_real_separated = jnp.sum(-self.atom_charges[None, :] * ei_cij)

        # Real space electron-electron part
        ee_real_separated = jnp.array(0.)
        if nelec > 1:
            ee_distances = self.dist.dist_matrix(configs)
            rvec = ee_distances[None, :, :, :] + self.lattice_displacements[:, None, None, :]
            r = jnp.linalg.norm(rvec, axis=-1)
            ee_real_separated = jnp.sum(jnp.triu(jax.lax.erfc(self.alpha * r) / r, k=1))

        ee_recip, ei_recip = self.reciprocal_space_electron(configs)
        ee = ee_real_separated + ee_recip
        ei = ei_real_separated + ei_recip
        return ee, ei

    def reciprocal_space_electron(self, configs):
        """Reciprocal-space (ee, ei) contributions for one electron configuration."""
        # Reciprocal space electron-electron part
        e_GdotR = jnp.einsum("ik,jk->ij", configs.reshape(sum(self.nelec), -1), self.gpoints)
        sum_e_sin = jnp.sin(e_GdotR).sum(axis=0)
        sum_e_cos = jnp.cos(e_GdotR).sum(axis=0)
        ee_recip = jnp.dot(sum_e_sin ** 2 + sum_e_cos ** 2, self.gweight)
        ## Reciprocal space electron-ion part
        coscos_sinsin = -self.ion_exp.real * sum_e_cos - self.ion_exp.imag * sum_e_sin
        ei_recip = 2 * jnp.dot(coscos_sinsin, self.gweight)
        return ee_recip, ei_recip

    def energy(self, configs):
        """Return (ee, ei, ii) Ewald energies including the constant terms."""
        nelec = sum(self.nelec)
        ee, ei = self.ewald_electron(configs)
        ee += self.ee_const(nelec)
        ei += self.ei_const(nelec)
        ii = self.ion_ion + self.ii_const
        return jnp.asarray(ee), jnp.asarray(ei), jnp.asarray(ii)


def select_big(gpts, cellvolume, recvec, alpha):
    """
    Convert integer G indices to reciprocal vectors and keep those with weight > 1e-12.

    :param gpts: meshgrid output (3 arrays) of integer reciprocal-lattice indices.
    :param cellvolume: cell volume; must be positive for the weight filter to work.
    :param recvec: inverse-transposed lattice vectors (without the 2*pi factor).
    :param alpha: Ewald splitting parameter.
    :return: (gpoints, gweight) restricted to the significant points.
    """
    gpoints = jnp.einsum("j...,jk->...k", gpts, recvec) * 2 * jnp.pi
    gsquared = jnp.einsum("...k,...k->...", gpoints, gpoints)
    gweight = 4 * jnp.pi * jnp.exp(-gsquared / (4 * alpha ** 2))
    gweight /= cellvolume * gsquared
    bigweight = gweight > 1e-12
    return gpoints[bigweight], gweight[bigweight]
-------------------------------------------------------------------------------- /DeepSolid/hf.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Lucas K Wagner 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). 24 | # All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
from pyscf.pbc import gto, scf
from DeepSolid import supercell
from DeepSolid import distance
import numpy as np

# Slices of the leading (derivative) axis of pyscf AO arrays kept by `_aostack_mol`.
_gldict = {"laplacian": np.s_[:1], "gradient_laplacian": np.s_[0:4]}


def _aostack_mol(ao, gl):
    """
    Stacks the AO components selected by `gl` with the AO Laplacian.

    :param ao: AO values/derivatives with the derivative components on axis 0;
        indices 4, 7, 9 are assumed to be xx, yy, zz (pyscf "deriv2" ordering).
    :param gl: key of `_gldict` ("laplacian" or "gradient_laplacian").
    :return: array whose last entry along axis 0 is the AO Laplacian.
    """
    return np.concatenate(
        [ao[_gldict[gl]], ao[[4, 7, 9]].sum(axis=0, keepdims=True)], axis=0
    )


def _aostack_pbc(ao, gl):
    """Applies `_aostack_mol` to the AO array of every k-point."""
    return [_aostack_mol(ak, gl) for ak in ao]


class SCF:
    def __init__(self, cell, twist=None):
        """
        Hartree Fock wave function class for QMC simulation

        :param cell: pyscf.pbc.gto.Cell, simulation object
        :param twist: np.array with shape [3], twist in fractional coordinates of the
            reciprocal lattice of `cell` (wrapped into [0, 1)). Defaults to (0.5, 0.5, 0.5).
        """
        if twist is None:
            # Built per call instead of sharing one array as a default argument.
            twist = np.ones(3) * 0.5
        self._aostack = _aostack_pbc
        self.coeff_key = ("mo_coeff_alpha", "mo_coeff_beta")
        self.param_split = {}
        self.parameters = {}
        self.k_split = {}
        # Must equal the number of k-points (see the reshape in eval_orb_mat).
        self.ns_tol = cell.scale
        self.simulation_cell = cell
        self.primitive_cell = cell.original_cell
        self.sim_nelec = self.simulation_cell.nelec
        self.kpts = supercell.get_supercell_kpts(self.simulation_cell)
        # 2*pi * inv(a) @ t == sum_j t_j * b_j, with b_j the reciprocal vectors of `cell`.
        self.kpts = self.kpts + np.dot(np.linalg.inv(cell.a), np.mod(twist, 1.0)) * 2 * np.pi
        hf_type = getattr(self.simulation_cell, 'hf_type', 'rhf')

        if hf_type == 'uhf':
            self.kmf = scf.KUHF(self.primitive_cell, exxdiv='ewald', kpts=self.kpts).density_fit()

            # break initial guess symmetry for UHF
            dm_up, dm_down = self.kmf.get_init_guess()
            dm_down[:, :2, :2] = 0
            dm = (dm_up, dm_down)
        elif hf_type == 'rhf':
            self.kmf = scf.KHF(self.primitive_cell, exxdiv='ewald', kpts=self.kpts).density_fit()
            dm = self.kmf.get_init_guess()
        else:
            raise ValueError('Unrecognized Hartree Fock type.')

        self.kmf.kernel(dm)
        # NOTE: init_scf() is not called here; klist / parameters / param_split /
        # k_split only exist after an explicit call.

    def init_scf(self):
        """
        initialization function to set up HF ansatz.

        Collects the occupied MO coefficients of every k-point for each spin and
        records how many belong to each k-point.
        """
        self.klist = []
        for s, key in enumerate(self.coeff_key):
            mclist = []
            for k in range(self.kmf.kpts.shape[0]):
                # restrict or not: unrestricted mo_coeff is indexed [spin][k] (a 2-D
                # matrix two levels down); restricted mo_coeff is indexed [k].
                if len(self.kmf.mo_coeff[0][0].shape) == 2:
                    mca = self.kmf.mo_coeff[s][k][:, np.asarray(self.kmf.mo_occ[s][k] > 0.9)]
                else:
                    # Presumably occupations 0/1/2: spin up keeps occ > 0.9,
                    # spin down keeps only doubly occupied orbitals (occ > 1.1).
                    minocc = (0.9, 1.1)[s]
                    mca = self.kmf.mo_coeff[k][:, np.asarray(self.kmf.mo_occ[k] > minocc)]
                mclist.append(mca)
            self.param_split[key] = np.cumsum([m.shape[1] for m in mclist])
            self.parameters[key] = np.concatenate(mclist, axis=-1)
            self.k_split[key] = np.array([m.shape[1] for m in mclist])
            self.klist.append(np.concatenate([np.tile(kpt[None, :], (split, 1))
                                              for kpt, split in
                                              zip(self.kmf.kpts, self.k_split[self.coeff_key[s]])]))

    def eval_orbitals_pbc(self, coord, eval_str="GTOval_sph"):
        """
        eval the atomic orbital values of HF.
        :param coord: electron coordinates with a leading batch axis; trailing axes are flattened.
        :param eval_str: pyscf eval_gto name (without the "PBC" prefix).
        :return: list over k-points of atomic orbital values, including the Bloch
            phase of the lattice translation applied when wrapping coordinates.
        """
        prim_coord, wrap = distance.np_enforce_pbc(self.primitive_cell.a, coord.reshape([coord.shape[0], -1]))
        prim_coord = prim_coord.reshape([-1, 3])
        wrap = wrap.reshape([-1, 3])
        ao = self.primitive_cell.eval_gto("PBC" + eval_str, prim_coord, kpts=self.kmf.kpts)

        # exp(i k . R), with R = wrap @ a (wrap presumably integer lattice counts).
        kdotR = np.einsum('ij,kj,nk->in', self.kmf.kpts, self.primitive_cell.a, wrap)
        wrap_phase = np.exp(1j * kdotR)
        ao = [ao[k] * wrap_phase[k][:, None] for k in range(len(self.kmf.kpts))]

        return ao

    def eval_mos_pbc(self, aos, s):
        """
        eval the molecular orbital values.
        :param aos: atomic orbital values, one array per k-point.
        :param s: spin index.
        :return: molecular orbital values concatenated over k-points on the last axis.
        """
        c = self.coeff_key[s]
        p = np.split(self.parameters[c], self.param_split[c], axis=-1)
        mo = [ao.dot(p[k]) for k, ao in enumerate(aos)]
        return np.concatenate(mo, axis=-1)

    def eval_orb_mat(self, coord):
        """
        eval the orbital matrix of HF.
        :param coord: electron walkers with shape [batch, ne, ndim].
        :return: [up, down] orbital matrices, each with shape [batch, ne_s, ne_s].
        """
        batch, nelec, ndim = coord.shape
        aos = self.eval_orbitals_pbc(coord)
        aos_shape = (self.ns_tol, batch, nelec, -1)

        aos = np.reshape(aos, aos_shape)
        mos = []
        for s in [0, 1]:
            i0, i1 = s * self.sim_nelec[0], self.sim_nelec[0] + s * self.sim_nelec[1]
            ne = self.sim_nelec[s]
            mo = self.eval_mos_pbc(aos[:, :, i0:i1], s).reshape([batch, ne, ne])
            mos.append(mo)
        return mos

    def eval_slogdet(self, coord):
        """
        eval the slogdet of HF
        :param coord: electron walkers with shape [batch, ne, ndim].
        :return: (phase, log|psi|) of the product of the spin-up and spin-down determinants.
        """
        mos = self.eval_orb_mat(coord)
        (sign_up, logdet_up), (sign_down, logdet_down) = [np.linalg.slogdet(mo) for mo in mos]
        return sign_up * sign_down, logdet_up + logdet_down

    def eval_phase(self, coord):
        """
        Bloch phases exp(i k . r) of every occupied orbital for every electron.

        :param coord: electron walkers with shape [batch, ne, ndim].
        :return: a list (one per spin) of phases with shape [B, ne_s, nk * nao]
        """
        coords = np.split(coord, (self.sim_nelec[0], sum(self.sim_nelec)), axis=1)
        kdots = [np.einsum('ijl, kl->ijk', cor, kpt) for cor, kpt in zip(coords, self.klist)]
        phase = [np.exp(1j * kdot) for kdot in kdots]
        return phase

    def pure_periodic(self, coord):
        """Orbital matrices multiplied by exp(-i k . r), removing the Bloch phase."""
        orbitals = self.eval_orb_mat(coord)
        ## minus symbol makes mos to be periodical
        phases = self.eval_phase(-coord)
        return [orbital * phase for orbital, phase in zip(orbitals, phases)]

    def eval_inverse(self, coord):
        """Inverses of the spin-up and spin-down orbital matrices."""
        mats = self.eval_orb_mat(coord)
        inverse = [np.linalg.inv(mat) for mat in mats]

        return inverse

    def _testrow(self, e, vec, inverse, mask=None, spin=None):
        """vec is a nconfig,nmo vector which replaces row e"""
        s = int(e >= self.sim_nelec[0]) if spin is None else spin
        elec = e - s * self.sim_nelec[0]
        if mask is None:
            return np.einsum("i...j,ij...->i...", vec, inverse[s][:, :, elec])

        return np.einsum("i...j,ij...->i...", vec, inverse[s][mask][:, :, elec])

    def laplacian(self, e, coord, inverse):
        """
        Ratio (laplacian_e psi) / psi for electron `e`.

        :param coord: coordinates of electron `e` for every walker, shape [batch, 3].
        :param inverse: output of `eval_inverse` at the current configuration.
        """
        s = int(e >= self.sim_nelec[0])
        ao = self.eval_orbitals_pbc(coord, eval_str="GTOval_sph_deriv2")
        mo = self.eval_mos_pbc(self._aostack(ao, "laplacian"), s)
        ratios = np.asarray([self._testrow(e, x, inverse=inverse) for x in mo])
        return ratios[1] / ratios[0]

    def kinetic(self, coord):
        """Local kinetic energy -1/2 * sum_e Re(laplacian_e psi / psi) for each walker."""
        ke = np.zeros(coord.shape[0])
        inverse = self.eval_inverse(coord)
        for e in range(self.simulation_cell.nelectron):
            ke = ke - 0.5 * np.real(self.laplacian(e,
                                                   coord[:, e, :],
                                                   inverse=inverse))
        return ke

    def __call__(self, coord):
        phase, slogdet = self.eval_slogdet(coord)
        psi = np.exp(slogdet) * phase
        return psi

# --------------------------------------------------------------------------------
# /DeepSolid/hamiltonian.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.

import jax
import jax.numpy as jnp
from DeepSolid import ewaldsum
from DeepSolid import network


def _real_imag_grad_closures(f, params):
    """Returns (g_re, g_im): y -> d Re f(params, y)/dy and y -> d Im f(params, y)/dy."""
    grad_f_real = jax.grad(lambda p, y: f(p, y).real, argnums=1)
    grad_f_imag = jax.grad(lambda p, y: f(p, y).imag, argnums=1)
    return (lambda y: grad_f_real(params, y)), (lambda y: grad_f_imag(params, y))


def local_kinetic_energy(f):
    '''
    holomorphic mode, which seems dangerous since many op don't support complex number now.
    :param f: function return the logdet of wavefunction
    :return: local kinetic energy
    '''
    def _lapl_over_f(params, x):
        ne = x.shape[-1]
        x_complex = x + 0j
        # Tangents must share the primal's complex dtype for jax.jvp.
        eye = jnp.eye(ne, dtype=x_complex.dtype)
        grad_f = jax.grad(f, argnums=1, holomorphic=True)
        grad_f_closure = lambda y: grad_f(params, y)

        def _body_fun(i, val):
            primal, tangent = jax.jvp(grad_f_closure, (x_complex,), (eye[i],))
            return val + tangent[i] + primal[i] ** 2

        # fori_loop needs a loop-invariant carry type; the body yields complex
        # values, so start from a complex zero instead of the real 0.0.
        init = jnp.zeros((), dtype=x_complex.dtype)
        return -0.5 * jax.lax.fori_loop(0, ne, _body_fun, init)

    return _lapl_over_f


def local_kinetic_energy_real_imag(f):
    '''
    evaluate real and imaginary part of laplacian.
    :param f: function return the logdet of wavefunction
    :return: local kinetic energy as [real part, 1j * imaginary part]
    '''
    def _lapl_over_f(params, x):
        ne = x.shape[-1]
        eye = jnp.eye(ne)
        grad_f_real_closure, grad_f_imag_closure = _real_imag_grad_closures(f, params)

        def _body_fun(i, val):
            primal_real, tangent_real = jax.jvp(grad_f_real_closure, (x,), (eye[i],))
            primal_imag, tangent_imag = jax.jvp(grad_f_imag_closure, (x,), (eye[i],))
            # laplacian(psi)/psi = laplacian(f) + (grad f)^2 with f = a + ib.
            kine_real = val[0] + tangent_real[i] + primal_real[i] ** 2 - primal_imag[i] ** 2
            kine_imag = val[1] + tangent_imag[i] + 2 * primal_real[i] * primal_imag[i]
            return [kine_real, kine_imag]

        result = jax.lax.fori_loop(0, ne, _body_fun, [0.0, 0.0])
        factors = [1., 1j]
        return [-0.5 * re * com for re, com in zip(result, factors)]

    return _lapl_over_f


def local_kinetic_energy_real_imag_dim_batch(f):
    '''
    evaluate real and imaginary part of laplacian, in which vmap is used to accelerate.
    :param f: function return the logdet of wavefunction
    :return: local kinetic energy as [real part, 1j * imaginary part]
    '''

    def _lapl_over_f(params, x):
        ne = x.shape[-1]
        eye = jnp.eye(ne)
        grad_f_real_closure, grad_f_imag_closure = _real_imag_grad_closures(f, params)

        def _body_fun(dummy_eye):
            primal_real, tangent_real = jax.jvp(grad_f_real_closure, (x,), (dummy_eye,))
            primal_imag, tangent_imag = jax.jvp(grad_f_imag_closure, (x,), (dummy_eye,))
            kine_real = ((tangent_real + primal_real ** 2 - primal_imag ** 2) * dummy_eye).sum()
            kine_imag = ((tangent_imag + 2 * primal_real * primal_imag) * dummy_eye).sum()
            return [kine_real, kine_imag]

        result = jax.vmap(_body_fun, in_axes=0)(eye)
        result = [re.sum() for re in result]
        factors = [1., 1j]
        return [-0.5 * re * com for re, com in zip(result, factors)]

    return _lapl_over_f


def local_kinetic_energy_real_imag_hessian(f):
    '''
    Use jax.hessian to evaluate laplacian, which requires huge amount of memory.
    :param f: function return the logdet of wavefunction
    :return: local kinetic energy as [real part, 1j * imaginary part]
    '''
    def _lapl_over_f(params, x):
        grad_f_real = jax.grad(lambda p, y: f(p, y).real, argnums=1)
        grad_f_imag = jax.grad(lambda p, y: f(p, y).imag, argnums=1)
        hessian_f_real = jax.hessian(lambda p, y: f(p, y).real, argnums=1)
        hessian_f_imag = jax.hessian(lambda p, y: f(p, y).imag, argnums=1)
        v_grad_f_real = grad_f_real(params, x)
        v_grad_f_imag = grad_f_imag(params, x)
        real_kinetic = jnp.trace(hessian_f_real(params, x),) + jnp.sum(v_grad_f_real**2) - jnp.sum(v_grad_f_imag**2)
        imag_kinetic = jnp.trace(hessian_f_imag(params, x),) + jnp.sum(2 * v_grad_f_real * v_grad_f_imag)

        factors = [1., 1j]
        return [-0.5 * re * com for re, com in zip([real_kinetic, imag_kinetic], factors)]

    return _lapl_over_f


def local_kinetic_energy_partition(f, partition_number=3):
    '''
    Try to parallelize the evaluation of laplacian
    :param f: function return the logdet of wavefunction
    :param partition_number: number of chunks the coordinate directions are split into.
            (dim * number of electrons) must be divisible by partition_number.
            The smaller the faster, but requires more memory.
    :return: local kinetic energy as [real part, 1j * imaginary part]
    :raises ValueError: if partition_number does not divide the number of coordinates.
    '''
    vjvp = jax.vmap(jax.jvp, in_axes=(None, None, 0))

    def _lapl_over_f(params, x):
        n = x.shape[0]
        # Equal-sized chunks are required to stack them for lax.scan.
        if n % partition_number:
            raise ValueError(
                f'partition_number ({partition_number}) must divide the number of '
                f'electron coordinates ({n}).')
        eye = jnp.eye(n)
        grad_f_closure_real, grad_f_closure_imag = _real_imag_grad_closures(f, params)

        eyes = jnp.asarray(jnp.array_split(eye, partition_number))

        def _body_fun(val, e):
            primal_real, tangent_real = vjvp(grad_f_closure_real, (x,), (e,))
            primal_imag, tangent_imag = vjvp(grad_f_closure_imag, (x,), (e,))
            return val, ([primal_real, primal_imag], [tangent_real, tangent_imag])

        _, (plist, tlist) = jax.lax.scan(_body_fun, None, eyes)
        primal = [primal.reshape((-1, primal.shape[-1])) for primal in plist]
        tangent = [tangent.reshape((-1, tangent.shape[-1])) for tangent in tlist]

        # Every row of primal[...] is the same gradient, so trace(primal**2) = |grad|^2.
        real_kinetic = jnp.trace(tangent[0]) + jnp.trace(primal[0]**2) - jnp.trace(primal[1]**2)
        imag_kinetic = jnp.trace(tangent[1]) + jnp.trace(2 * primal[0] * primal[1])
        return [-0.5 * real_kinetic, -0.5 * 1j * imag_kinetic]

    return _lapl_over_f


def local_ewald_energy(simulation_cell):
    """
    generate local energy of ewald part.
    :param simulation_cell: pyscf object of simulation cell.
    :return: function mapping electron coordinates to the total Ewald energy.
    """
    ewald = ewaldsum.EwaldSum(simulation_cell)
    ## check pyscf madelung constant agrees with DeepSolid
    assert jnp.allclose(simulation_cell.energy_nuc(),
                        (ewald.ion_ion + ewald.ii_const),
                        rtol=1e-8, atol=1e-5), \
        'Ion-ion Ewald energy disagrees with pyscf energy_nuc().'

    def _local_ewald_energy(x):
        energy = ewald.energy(x)
        return sum(energy)

    return _local_ewald_energy


def local_energy(f, simulation_cell):
    """
    generate the local energy function using the holomorphic kinetic energy.
    :param f: function return the logdet of wavefunction.
    :param simulation_cell: pyscf object of simulation cell.
    :return: function (params, x) -> kinetic + Ewald energy.
    """
    ke = local_kinetic_energy(f)
    ew = local_ewald_energy(simulation_cell)

    def _local_energy(params, x):
        kinetic = ke(params, x)
        ewald = ew(x)
        return kinetic + ewald

    return _local_energy


def local_energy_seperate(f, simulation_cell, mode='for', partition_number=3):
    """
    genetate the local energy function.
    :param f: function return the logdet of wavefunction.
    :param simulation_cell: pyscf object of simulation cell.
    :param mode: specify the evaluation style of local energy.
                'for' mode calculates the laplacian of each electron one by one, which is slow but save GPU memory
                'hessian' mode calculates the laplacian in a highly parallized mode, which is fast but require GPU memory
                'dim_batch' mode vmaps over all coordinate directions at once.
                'partition' mode calculate the laplacian in a moderate way.
    :param partition_number: Only used if 'partition' mode is employed.
            (dim * number of electrons) must be divisible by partition_number.
            The smaller the faster, but requires more memory.
    :return: the local energy function returning (kinetic, ewald).
    :raises ValueError: for an unrecognized mode.
    """

    if mode == 'for':
        ke_ri = local_kinetic_energy_real_imag(f)
    elif mode == 'hessian':
        ke_ri = local_kinetic_energy_real_imag_hessian(f)
    elif mode == 'dim_batch':
        ke_ri = local_kinetic_energy_real_imag_dim_batch(f)
    elif mode == 'partition':
        ke_ri = local_kinetic_energy_partition(f, partition_number=partition_number)
    else:
        raise ValueError('Unrecognized laplacian evaluation mode.')
    # real part + 1j * imaginary part -> complex kinetic energy
    ke = lambda p, y: sum(ke_ri(p, y))
    ew = local_ewald_energy(simulation_cell)

    def _local_energy(params, x):
        kinetic = ke(params, x)
        ewald = ew(x)
        return kinetic, ewald

    return _local_energy

# --------------------------------------------------------------------------------
# /DeepSolid/train.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.
import chex
import jax
import jax.numpy as jnp
import functools

from DeepSolid import hamiltonian
from DeepSolid import constants
from DeepSolid.utils.kfac_ferminet_alpha import loss_functions


@chex.dataclass
class AuxiliaryLossData:
    """Auxiliary outputs returned by the loss function alongside the energy."""
    # jnp.ndarray is used instead of jnp.DeviceArray, which newer JAX releases no
    # longer provide; dataclass annotations are not enforced at runtime.
    variance: jnp.ndarray
    local_energy: jnp.ndarray
    imaginary: jnp.ndarray
    kinetic: jnp.ndarray
    ewald: jnp.ndarray


def make_loss(network, batch_network,
              simulation_cell,
              clip_local_energy=5.0,
              clip_type='real',
              mode='for',
              partition_number=3):
    """
    generates loss function used for wavefunction trains.
    :param network: unbatched logdet function of wavefunction
    :param batch_network: batched logdet function of wavefunction
    :param simulation_cell: pyscf object of simulation cell.
    :param clip_local_energy: clip window width of local energy; <= 0 disables clipping.
    :param clip_type: specify the clip style. real mode clips the local energy in Cartesian style,
                      and complex mode in polar style
    :param mode: specify the evaluation style of local energy.
                'for' mode calculates the laplacian of each electron one by one, which is slow but save GPU memory
                'hessian' mode calculates the laplacian in a highly parallized mode, which is fast but require GPU memory
                'dim_batch' mode vmaps over all coordinate directions at once.
                'partition' mode calculate the laplacian in a moderate way.
    :param partition_number: Only used if 'partition' mode is employed.
            (dim * number of electrons) must be divisible by partition_number.
            The smaller the faster, but requires more memory.
    :return: the loss function
    """
    el_fun = hamiltonian.local_energy_seperate(network,
                                               simulation_cell=simulation_cell,
                                               mode=mode,
                                               partition_number=partition_number)
    batch_local_energy = jax.vmap(el_fun, in_axes=(None, 0), out_axes=0)

    @jax.custom_jvp
    def total_energy(params, data):
        """

        :param params: a dictionary of parameters
        :param data: batch electron coord with shape [Batch, Nelec * Ndim]
        :return: energy expectation of corresponding walkers (only take real part),
                 and AuxiliaryLossData with per-walker local energies of shape [Batch]
        """
        ke, ew = batch_local_energy(params, data)
        e_l = ke + ew
        mean_e_l = jnp.mean(e_l)

        pmean_loss = constants.pmean_if_pmap(mean_e_l, axis_name=constants.PMAP_AXIS_NAME)
        # E|E_L|^2 - (Re<E_L>)^2 about the *global* mean energy, i.e. the mean squared
        # deviation of E_L from the real energy. Averaging per-device variances about
        # per-device means would omit the spread between devices.
        mean_sq_e_l = constants.pmean_if_pmap(jnp.mean(jnp.abs(e_l) ** 2),
                                              axis_name=constants.PMAP_AXIS_NAME)
        variance = mean_sq_e_l - pmean_loss.real ** 2
        loss = pmean_loss.real
        imaginary = pmean_loss.imag

        return loss, AuxiliaryLossData(variance=variance,
                                       local_energy=e_l,
                                       imaginary=imaginary,
                                       kinetic=ke,
                                       ewald=ew,
                                       )

    @total_energy.defjvp
    def total_energy_jvp(primals, tangents):
        """
        customised jvp function of loss function.
        :param primals: inputs of total_energy function (params, data)
        :param tangents: tangent vectors corresponding to the primal (params, data)
        :return: Jacobian-vector product of total energy.
        """
        params, data = primals
        loss, aux_data = total_energy(params, data)
        diff = (aux_data.local_energy - loss)
        if clip_local_energy > 0.0:
            if clip_type == 'complex':
                # Clip |diff| around its (device-averaged) median, keep the phase.
                radius, phase = jnp.abs(diff), jnp.angle(diff)
                radius_tv = constants.pmean_if_pmap(radius.std(), axis_name=constants.PMAP_AXIS_NAME)
                radius_median = jnp.median(radius)
                radius_median = constants.pmean_if_pmap(radius_median, axis_name=constants.PMAP_AXIS_NAME)
                clip_radius = jnp.clip(radius,
                                       radius_median - radius_tv * clip_local_energy,
                                       radius_median + radius_tv * clip_local_energy)
                clip_diff = clip_radius * jnp.exp(1j * phase)
            elif clip_type == 'real':
                # Clip real and imaginary parts separately by their mean absolute deviation.
                tv_re = jnp.mean(jnp.abs(diff.real))
                tv_re = constants.pmean_if_pmap(tv_re, axis_name=constants.PMAP_AXIS_NAME)
                tv_im = jnp.mean(jnp.abs(diff.imag))
                tv_im = constants.pmean_if_pmap(tv_im, axis_name=constants.PMAP_AXIS_NAME)
                clip_diff_re = jnp.clip(diff.real,
                                        -clip_local_energy * tv_re,
                                        clip_local_energy * tv_re)
                clip_diff_im = jnp.clip(diff.imag,
                                        -clip_local_energy * tv_im,
                                        clip_local_energy * tv_im)
                clip_diff = clip_diff_re + clip_diff_im * 1j
            else:
                raise ValueError('Unrecognized clip type.')
        else:
            clip_diff = diff

        psi_primal, psi_tangent = jax.jvp(batch_network, primals, tangents)
        conj_psi_tangent = jnp.conjugate(psi_tangent)
        conj_psi_primal = jnp.conjugate(psi_primal)

        loss_functions.register_normal_predictive_distribution(conj_psi_primal[:, None])

        primals_out = loss, aux_data
        # mean rather than dot: a dot product would make the gradient extensive in
        # the batch size, which matters for KFAC.
        tangents_dot = jnp.mean((clip_diff * conj_psi_tangent).real)

        tangents_out = (tangents_dot, aux_data)

        return primals_out, tangents_out

    return total_energy


def make_training_step(mcmc_step, val_and_grad, opt_update):
    """
    generates the function used for wavefunction train.
    :param mcmc_step: MCMC sample function
    :param val_and_grad: value and grad of loss function
    :param opt_update: optimizer update function
    :return: one iteration train function
    """

    @functools.partial(constants.pmap, donate_argnums=(1, 2, 3, 4))
    def step(t, data, params, state, key, mcmc_width):
        """
        one iteration train of energy
        :param t: current step of energy train.
        :param data: batched electron walkers with shape [batch, ne * ndim]
        :param params: a dictionary of parameters.
        :param state: optimizer state.
        :param key: PRNG key.
        :param mcmc_width: mcmc move width
        :return:
        data: moved electron walkers
        params: updated parameters
        state: updated optimizer state
        loss: value of loss function
        aux_data: auxiliary value of loss function
        pmove: accept rate of MCMC move
        search_direction: calculated gradient.
        """
        data, pmove = mcmc_step(params, data, key, mcmc_width)

        # Optimization step
        (loss, aux_data), search_direction = val_and_grad(params, data)
        search_direction = constants.pmean_if_pmap(search_direction,
                                                   axis_name=constants.PMAP_AXIS_NAME)
        state, params = opt_update(t, search_direction, params, state)
        return data, params, state, loss, aux_data, pmove, search_direction

    return step


@functools.partial(jax.vmap, in_axes=(0, 0), out_axes=0)
def direct_product(x, y):
    """Per-sample outer product of the flattened x and y."""
    return x.ravel()[:, None] * y.ravel()[None, :]


def make_sr_matrix(network):
    '''
    which is used to calculate the fisher matrix, abandoned now.
    :param network: object whose `apply(params, x)` returns the logdet of the wavefunction.
    :return: function (params, data) -> pytree of S matrices <O* O> - <O*><O>.
    '''
    network_grad = jax.grad(network.apply, argnums=0, holomorphic=True)
    batch_network_grad = jax.vmap(network_grad, in_axes=(None, 0))

    def sr_matrix(params, data):
        complex_params = jax.tree_util.tree_map(lambda x: x + 0j, params)
        batch_diffs = batch_network_grad(complex_params, data)

        s1 = jax.tree_util.tree_map(lambda x: jnp.mean(direct_product(jnp.conjugate(x), x),
                                                       axis=0),
                                    batch_diffs)
        s2 = jax.tree_util.tree_map(lambda x: (jnp.mean(jnp.conjugate(x), axis=0).ravel()[:, None] *
                                               jnp.mean(x, axis=0).ravel()[None, :]
                                               ),
                                    batch_diffs)
        s1 = constants.pmean_if_pmap(s1, axis_name=constants.PMAP_AXIS_NAME)
        s2 = constants.pmean_if_pmap(s2, axis_name=constants.PMAP_AXIS_NAME)
        # tree_util.tree_map maps over several trees; jax.tree_multimap was removed.
        matrix = jax.tree_util.tree_map(lambda x, y: x - y, s1, s2)
        return matrix

    return sr_matrix

# --------------------------------------------------------------------------------
# /LICENSE:
# --------------------------------------------------------------------------------
#                                  Apache License
#                            Version 2.0, January 2004
#                         http://www.apache.org/licenses/
#
#    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
#
#    1.
Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /DeepSolid/utils/elements.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). 17 | # All Bytedance Modifications are Copyright 2022 Bytedance Inc. 

import collections
from typing import Optional
import attr


@attr.s
class Element(object):
    """A chemical element plus ground-state data derived from its position.

    Attributes:
      symbol: official chemical symbol, e.g. 'Li'.
      atomic_number: atomic number Z.
      period: periodic-table row the element belongs to.
      spin: optional override of the ground-state spin configuration; when
        omitted a default is derived from the element's (main) group.
    """
    symbol: str = attr.ib()
    atomic_number: int = attr.ib()
    period: int = attr.ib()
    _spin: Optional[int] = attr.ib(default=None, repr=False)

    @property
    def group(self) -> int:
        """Periodic-table group; -1 for lanthanides (Z 58-71) and actinides (Z 90-103)."""
        z = self.atomic_number
        if 58 <= z <= 71 or 90 <= z <= 103:
            return -1
        if self.symbol == 'He':
            # The n=1 shell has only an s orbital, so He is placed with the
            # noble gases.
            return 18
        first_z_of_period = (1, 3, 11, 19, 37, 55, 87)[self.period - 1]
        column = z - first_z_of_period + 1
        if self.period < 4 and column > 2:
            # Periods 2 and 3 have no d block: jump over the ten d columns.
            column += 10
        if self.period >= 6 and column > 3:
            # Periods 6 and 7 contain the 14 f-block elements before group 4.
            column -= 14
        return column

    @property
    def spin_config(self) -> int:
        """Canonical spin configuration (via Hund's rules) of the neutral atom.

        Returns:
          Number of unpaired electrons (as required by PySCF): the explicit
          ``spin`` override if one was given, otherwise a per-group default for
          groups 1, 2, 3 and 13-18.

        Raises:
          NotImplementedError: if no override was given and the group has no
            default (transition metals; also lanthanides/actinides, whose group
            is -1).
        """
        if self._spin is not None:
            return self._spin
        default_unpaired = {1: 1, 2: 0, 3: 1, 13: 1, 14: 2, 15: 3, 16: 2,
                            17: 1, 18: 0}
        group = self.group
        if group not in default_unpaired:
            raise NotImplementedError(
                'Spin configuration for transition metals not set.')
        return default_unpaired[group]

    @property
    def nalpha(self) -> int:
        """Number of alpha electrons of the neutral ground-state atom.

        By convention alpha electrons are the majority spin, so this is
        ``(Z + unpaired) // 2``.
        """
        return (self.atomic_number + self.spin_config) // 2

    @property
    def nbeta(self) -> int:
        """Number of beta electrons of the neutral ground-state atom.

        By convention beta electrons are the minority spin, so this is
        ``(Z - unpaired) // 2``.
        """
        return (self.atomic_number - self.spin_config) // 2


# Atomic symbols for all known elements
# Generated using
# def _element(symbol, atomic_number):
#   # period_start[n] = atomic number of group 1 element in (n+1)-th period.
#   period_start = (1, 3, 11, 19, 37, 55, 87)
#   for p, group1_no in enumerate(period_start):
#     if atomic_number < group1_no:
#       # In previous period but n is 0-based.
#       period = p
#       break
#   else:
#     period = p + 1
#   return Element(symbol=symbol, atomic_number=atomic_number, period=period)
# [_element(s, n+1) for n, s in enumerate(symbols)]
# where symbols is the list of chemical symbols of all elements.
120 | _ELEMENTS = ( 121 | Element(symbol='H', atomic_number=1, period=1), 122 | Element(symbol='He', atomic_number=2, period=1), 123 | Element(symbol='Li', atomic_number=3, period=2), 124 | Element(symbol='Be', atomic_number=4, period=2), 125 | Element(symbol='B', atomic_number=5, period=2), 126 | Element(symbol='C', atomic_number=6, period=2), 127 | Element(symbol='N', atomic_number=7, period=2), 128 | Element(symbol='O', atomic_number=8, period=2), 129 | Element(symbol='F', atomic_number=9, period=2), 130 | Element(symbol='Ne', atomic_number=10, period=2), 131 | Element(symbol='Na', atomic_number=11, period=3), 132 | Element(symbol='Mg', atomic_number=12, period=3), 133 | Element(symbol='Al', atomic_number=13, period=3), 134 | Element(symbol='Si', atomic_number=14, period=3), 135 | Element(symbol='P', atomic_number=15, period=3), 136 | Element(symbol='S', atomic_number=16, period=3), 137 | Element(symbol='Cl', atomic_number=17, period=3), 138 | Element(symbol='Ar', atomic_number=18, period=3), 139 | Element(symbol='K', atomic_number=19, period=4), 140 | Element(symbol='Ca', atomic_number=20, period=4), 141 | Element(symbol='Sc', atomic_number=21, period=4, spin=1), 142 | Element(symbol='Ti', atomic_number=22, period=4, spin=2), 143 | Element(symbol='V', atomic_number=23, period=4, spin=3), 144 | Element(symbol='Cr', atomic_number=24, period=4, spin=6), 145 | Element(symbol='Mn', atomic_number=25, period=4, spin=5), 146 | Element(symbol='Fe', atomic_number=26, period=4, spin=4), 147 | Element(symbol='Co', atomic_number=27, period=4, spin=3), 148 | Element(symbol='Ni', atomic_number=28, period=4, spin=2), 149 | Element(symbol='Cu', atomic_number=29, period=4, spin=1), 150 | Element(symbol='Zn', atomic_number=30, period=4, spin=0), 151 | Element(symbol='Ga', atomic_number=31, period=4), 152 | Element(symbol='Ge', atomic_number=32, period=4), 153 | Element(symbol='As', atomic_number=33, period=4), 154 | Element(symbol='Se', atomic_number=34, period=4), 155 | 
Element(symbol='Br', atomic_number=35, period=4), 156 | Element(symbol='Kr', atomic_number=36, period=4), 157 | Element(symbol='Rb', atomic_number=37, period=5), 158 | Element(symbol='Sr', atomic_number=38, period=5), 159 | Element(symbol='Y', atomic_number=39, period=5, spin=1), 160 | Element(symbol='Zr', atomic_number=40, period=5, spin=2), 161 | Element(symbol='Nb', atomic_number=41, period=5, spin=5), 162 | Element(symbol='Mo', atomic_number=42, period=5, spin=6), 163 | Element(symbol='Tc', atomic_number=43, period=5, spin=5), 164 | Element(symbol='Ru', atomic_number=44, period=5, spin=4), 165 | Element(symbol='Rh', atomic_number=45, period=5, spin=3), 166 | Element(symbol='Pd', atomic_number=46, period=5, spin=0), 167 | Element(symbol='Ag', atomic_number=47, period=5, spin=1), 168 | Element(symbol='Cd', atomic_number=48, period=5, spin=0), 169 | Element(symbol='In', atomic_number=49, period=5), 170 | Element(symbol='Sn', atomic_number=50, period=5), 171 | Element(symbol='Sb', atomic_number=51, period=5), 172 | Element(symbol='Te', atomic_number=52, period=5), 173 | Element(symbol='I', atomic_number=53, period=5), 174 | Element(symbol='Xe', atomic_number=54, period=5), 175 | Element(symbol='Cs', atomic_number=55, period=6), 176 | Element(symbol='Ba', atomic_number=56, period=6), 177 | Element(symbol='La', atomic_number=57, period=6), 178 | Element(symbol='Ce', atomic_number=58, period=6), 179 | Element(symbol='Pr', atomic_number=59, period=6), 180 | Element(symbol='Nd', atomic_number=60, period=6), 181 | Element(symbol='Pm', atomic_number=61, period=6), 182 | Element(symbol='Sm', atomic_number=62, period=6), 183 | Element(symbol='Eu', atomic_number=63, period=6), 184 | Element(symbol='Gd', atomic_number=64, period=6), 185 | Element(symbol='Tb', atomic_number=65, period=6), 186 | Element(symbol='Dy', atomic_number=66, period=6), 187 | Element(symbol='Ho', atomic_number=67, period=6), 188 | Element(symbol='Er', atomic_number=68, period=6), 189 | 
Element(symbol='Tm', atomic_number=69, period=6), 190 | Element(symbol='Yb', atomic_number=70, period=6), 191 | Element(symbol='Lu', atomic_number=71, period=6), 192 | Element(symbol='Hf', atomic_number=72, period=6), 193 | Element(symbol='Ta', atomic_number=73, period=6), 194 | Element(symbol='W', atomic_number=74, period=6), 195 | Element(symbol='Re', atomic_number=75, period=6), 196 | Element(symbol='Os', atomic_number=76, period=6), 197 | Element(symbol='Ir', atomic_number=77, period=6), 198 | Element(symbol='Pt', atomic_number=78, period=6), 199 | Element(symbol='Au', atomic_number=79, period=6), 200 | Element(symbol='Hg', atomic_number=80, period=6), 201 | Element(symbol='Tl', atomic_number=81, period=6), 202 | Element(symbol='Pb', atomic_number=82, period=6), 203 | Element(symbol='Bi', atomic_number=83, period=6), 204 | Element(symbol='Po', atomic_number=84, period=6), 205 | Element(symbol='At', atomic_number=85, period=6), 206 | Element(symbol='Rn', atomic_number=86, period=6), 207 | Element(symbol='Fr', atomic_number=87, period=7), 208 | Element(symbol='Ra', atomic_number=88, period=7), 209 | Element(symbol='Ac', atomic_number=89, period=7), 210 | Element(symbol='Th', atomic_number=90, period=7), 211 | Element(symbol='Pa', atomic_number=91, period=7), 212 | Element(symbol='U', atomic_number=92, period=7), 213 | Element(symbol='Np', atomic_number=93, period=7), 214 | Element(symbol='Pu', atomic_number=94, period=7), 215 | Element(symbol='Am', atomic_number=95, period=7), 216 | Element(symbol='Cm', atomic_number=96, period=7), 217 | Element(symbol='Bk', atomic_number=97, period=7), 218 | Element(symbol='Cf', atomic_number=98, period=7), 219 | Element(symbol='Es', atomic_number=99, period=7), 220 | Element(symbol='Fm', atomic_number=100, period=7), 221 | Element(symbol='Md', atomic_number=101, period=7), 222 | Element(symbol='No', atomic_number=102, period=7), 223 | Element(symbol='Lr', atomic_number=103, period=7), 224 | Element(symbol='Rf', 
atomic_number=104, period=7), 225 | Element(symbol='Db', atomic_number=105, period=7), 226 | Element(symbol='Sg', atomic_number=106, period=7), 227 | Element(symbol='Bh', atomic_number=107, period=7), 228 | Element(symbol='Hs', atomic_number=108, period=7), 229 | Element(symbol='Mt', atomic_number=109, period=7), 230 | Element(symbol='Ds', atomic_number=110, period=7), 231 | Element(symbol='Rg', atomic_number=111, period=7), 232 | Element(symbol='Cn', atomic_number=112, period=7), 233 | Element(symbol='Nh', atomic_number=113, period=7), 234 | Element(symbol='Fl', atomic_number=114, period=7), 235 | Element(symbol='Mc', atomic_number=115, period=7), 236 | Element(symbol='Lv', atomic_number=116, period=7), 237 | Element(symbol='Ts', atomic_number=117, period=7), 238 | Element(symbol='Og', atomic_number=118, period=7), 239 | ) 240 | 241 | ATOMIC_NUMS = {element.atomic_number: element for element in _ELEMENTS} 242 | 243 | # Lookup by symbol instead of atomic number. 244 | SYMBOLS = {element.symbol: element for element in _ELEMENTS} 245 | 246 | # Lookup by period. 247 | PERIODS = collections.defaultdict(list) 248 | for element in _ELEMENTS: 249 | PERIODS[element.period].append(element) 250 | PERIODS = {period: tuple(elements) for period, elements in PERIODS.items()} 251 | -------------------------------------------------------------------------------- /DeepSolid/utils/kfac_ferminet_alpha/layers_and_loss_tags.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.
"""A module for registering already known functions for tagging patterns."""
import functools

from typing import Sequence, Tuple, TypeVar

import jax
from jax import core as jax_core
from jax import lax
from jax import lib as jax_lib
from jax.interpreters import batching as jax_batching
import jax.numpy as jnp

_T = TypeVar("_T")


class LossTag(jax_core.Primitive):
    """A tagging primitive specifically for losses.

    Bound on ``num_inputs`` input arrays optionally followed by ``num_targets``
    target arrays. Its outputs are the operands themselves, or, with
    ``return_loss=True`` (which requires targets), the value of
    ``cls(*operands, weight=weight, **kwargs).evaluate()``.
    """
    multiple_results = True

    def __init__(self, cls, num_inputs: int, num_targets: int = 1):
        super().__init__(cls.__name__ + "_tag")
        self._cls = cls
        self._num_inputs = num_inputs
        self._num_targets = num_targets
        jax.xla.translations[self] = self.xla_translation
        jax.ad.primitive_jvps[self] = self.jvp
        # This line defines how the tag behaves under vmap. It is required for
        # any primitive that can be used inside a vmap. We want to allow this
        # for two reasons: to not break user code when the tags are not used at
        # all, and to be able to define a network with code for a single
        # example which is then vmap-ed for a batch.
        jax_batching.primitive_batchers[self] = self.batching

    @property
    def num_inputs(self) -> int:
        return self._num_inputs

    @property
    def num_targets(self) -> int:
        return self._num_targets

    def loss(self, *args, weight: float = 1.0, **kwargs):
        """Instantiate the wrapped loss class on ``args``."""
        return self._cls(*args, weight=weight, **kwargs)

    def loss_evaluate(self, *args, weight: float = 1.0, **kwargs):
        """Instantiate the wrapped loss class and return ``.evaluate()``."""
        return self.loss(*args, weight=weight, **kwargs).evaluate()

    def get_outputs(self, *args, weight: float, return_loss: bool, **kwargs):
        """Validate the operand count and compute the tag's outputs.

        Raises:
          ValueError: if fewer than ``num_inputs`` or more than
            ``num_inputs + num_targets`` operands are given, if a partial set
            of targets is given, or if ``return_loss`` is requested without
            targets.
        """
        if len(args) < self.num_inputs:
            raise ValueError("Inputs to the tag are not enough.")
        if len(args) < self.num_inputs + self.num_targets:
            if len(args) != self.num_inputs:
                raise ValueError("Inputs to the tag are not quite enough.")
            if return_loss:
                raise ValueError("Can not have return_loss=True when there are no "
                                 "targets.")
            return args
        if len(args) > self.num_inputs + self.num_targets:
            raise ValueError("Inputs to the tag are too many.")
        if return_loss:
            return self.loss(*args, weight=weight, **kwargs).evaluate()
        else:
            return args

    def impl(self, *args, weight: float, return_loss: bool, **kwargs):
        # Forward **kwargs like xla_translation/jvp do; previously they were
        # dropped, so return_loss=True ignored loss-specific parameters here.
        return self.get_outputs(
            *args, weight=weight, return_loss=return_loss, **kwargs)

    def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs):
        return self.get_outputs(
            *args, weight=weight, return_loss=return_loss, **kwargs)

    def xla_translation(
            self,
            c,
            *args,
            weight: float = 1.0,
            return_loss: bool = False,
            **kwargs,
    ):
        outputs = self.get_outputs(
            *args, weight=weight, return_loss=return_loss, **kwargs)
        if isinstance(outputs, tuple):
            return jax_lib.xla_client.ops.Tuple(c, outputs)
        return outputs

    def jvp(
            self,
            arg_values,
            arg_tangents,
            weight: float,
            return_loss: bool,
            **kwargs,
    ):
        """JVP rule: identity on tangents, or the loss's JVP for return_loss."""
        if len(arg_values) != len(arg_tangents):
            raise ValueError("Values and tangents are not the same length.")
        primal_output = self.bind(
            *arg_values, weight=weight, return_loss=return_loss, **kwargs)
        if len(arg_values) == self.num_inputs:
            tangents_out = self.get_outputs(
                *arg_tangents, weight=weight, return_loss=return_loss, **kwargs)
        elif return_loss:
            # Differentiate the loss at the primal point `arg_values` in the
            # direction `arg_tangents` (previously the tangents were also used
            # as the primal point, evaluating the derivative at the wrong place).
            tangents_out = jax.jvp(
                functools.partial(self.loss_evaluate, weight=weight, **kwargs),
                arg_values, arg_tangents)[1]
        else:
            tangents_out = arg_tangents
        return primal_output, tangents_out

    def batching(self, batched_args, batched_dims, **kwargs):
        return self.bind(*batched_args, **kwargs), batched_dims[0]


class LayerTag(jax_core.Primitive):
    """A tagging primitive that is used to mark/tag computation.

    Operands are laid out as ``(outputs, inputs, params)`` with
    ``num_outputs`` (currently only 1) outputs and ``num_inputs`` inputs; the
    tag returns the first operand unchanged.
    """

    def __init__(self, name: str, num_inputs: int, num_outputs: int):
        super().__init__(name)
        if num_outputs > 1:
            raise NotImplementedError(
                f"Only single outputs are supported, got: num_outputs={num_outputs}")
        self._num_outputs = num_outputs
        self._num_inputs = num_inputs
        jax.xla.translations[self] = self.xla_translation
        jax.ad.deflinear(self, self.transpose)
        # NOTE(review): deflinear presumably also installs a wrapped transpose
        # rule; this assignment replaces it with `transpose` called directly as
        # transpose(cotangent, *operands, **params). Confirm against the pinned
        # JAX version before changing either line.
        jax.ad.primitive_transposes[self] = self.transpose
        # This line defines how the tag behaves under vmap. It is required for
        # any primitive that can be used inside a vmap. We want to allow this
        # for two reasons: to not break user code when the tags are not used at
        # all, and to be able to define a network with code for a single
        # example which is then vmap-ed for a batch.
        jax_batching.primitive_batchers[self] = self.batching

    @property
    def num_outputs(self) -> int:
        return self._num_outputs

    @property
    def num_inputs(self) -> int:
        return self._num_inputs

    def split_all_inputs(
            self,
            all_inputs: Sequence[_T],
    ) -> Tuple[Sequence[_T], Sequence[_T], Sequence[_T]]:
        """Split the flat operand list into (outputs, inputs, params) tuples."""
        outputs = tuple(all_inputs[:self.num_outputs])
        inputs = tuple(all_inputs[self.num_outputs:self.num_outputs +
                                  self.num_inputs])
        params = tuple(all_inputs[self.num_outputs + self.num_inputs:])
        return outputs, inputs, params

    def get_outputs(self, *operands: _T, **kwargs) -> _T:
        assert self.num_outputs == 1
        return operands[0]

    def xla_translation(self, c, *operands: _T, **kwargs) -> _T:
        return self.get_outputs(*operands, **kwargs)

    @staticmethod
    def transpose(cotangent, *operands, **kwargs):
        # The cotangent flows only to the first (output) operand.
        return (cotangent,) + (None,) * (len(operands) - 1)

    def impl(self, *operands, **kwargs):
        return self.get_outputs(*operands, **kwargs)

    def abstract_eval(self, *abstract_operands, **kwargs):
        return self.get_outputs(*abstract_operands, **kwargs)

    def batching(self, batched_operands, batched_dims, **kwargs):
        return self.bind(*batched_operands, **kwargs), batched_dims[0]


# ----------------------------------------------------------------------------
# Generic
# ----------------------------------------------------------------------------

generic_tag = LayerTag(name="generic_tag", num_inputs=0, num_outputs=1)


def register_generic(parameter: _T) -> _T:
    """Tag a single parameter with no associated layer structure."""
    return generic_tag.bind(parameter)


# ----------------------------------------------------------------------------
# Dense
# ----------------------------------------------------------------------------
dense_tag = LayerTag(name="dense_tag", num_inputs=1, num_outputs=1)


def register_dense(y, x, w, b=None):
  """Tags `y` as the output of a dense layer with input `x`, weights `w`.

  The bias `b` is optional and is only bound when provided.
  """
  if b is None:
    return dense_tag.bind(y, x, w)
  return dense_tag.bind(y, x, w, b)


def dense_func(x, params):
  """Example of a dense layer function.

  Computes `x @ params[0]`, adding `params[1]` as a bias when it is present.
  """
  w = params[0]
  y = jnp.matmul(x, w)
  if len(params) == 1:
    # No bias
    return y
  # Add bias
  return y + params[1]


def dense_tagging(jaxpr, inverse_map, values_map):
  """Correctly registers a dense layer pattern.

  Looks up the values bound to the pattern's input and output variables in
  `values_map` and re-emits them through `register_dense`.
  """
  del inverse_map
  in_values = [values_map[v] for v in jaxpr.invars]
  out_values = [values_map[v] for v in jaxpr.outvars]
  return register_dense(out_values[0], *in_values)


# ---------------------------------------------------------------------------
#  2D Convolution
# ---------------------------------------------------------------------------

conv2d_tag = LayerTag(name="conv2d_tag", num_inputs=1, num_outputs=1)


def register_conv2d(y, x, w, b=None, **kwargs):
  """Tags a 2D convolution; extra `kwargs` are bound as tag parameters."""
  if b is None:
    return conv2d_tag.bind(y, x, w, **kwargs)
  return conv2d_tag.bind(y, x, w, b, **kwargs)


def conv2d_func(x, params):
  """Example of a conv2d layer function (NHWC input, HWIO kernel)."""
  w = params[0]
  y = lax.conv_general_dilated(
      x,
      w,
      window_strides=(2, 2),
      padding="SAME",
      dimension_numbers=("NHWC", "HWIO", "NHWC"))
  if len(params) == 1:
    # No bias
    return y
  # Add bias, broadcast over the N, H and W axes.
  return y + params[1][None, None, None]


def conv2d_tagging(jaxpr, inverse_map, values_map):
  """Correctly registers a conv2d layer pattern.

  The convolution hyper-parameters are taken from the `.params` of the single
  string-keyed `conv_general_dilated...` entry of `inverse_map` and forwarded
  to the tag.

  Raises:
    ValueError: if `inverse_map` holds no such entry, or more than one.
  """
  in_values = [values_map[v] for v in jaxpr.invars]
  out_values = [values_map[v] for v in jaxpr.outvars]
  keys = [k for k in inverse_map.keys()
          if isinstance(k, str) and k.startswith("conv_general_dilated")]
  if not keys:
    raise ValueError("Did not find any conv_general_dilated!")
  if len(keys) > 1:
    # Previously reported as "not found", which hid the real problem.
    raise ValueError(
        f"Expected exactly one conv_general_dilated, found {len(keys)}: "
        f"{keys}")
  kwargs = inverse_map[keys[0]].params
  return register_conv2d(out_values[0], *in_values, **kwargs)


# ---------------------------------------------------------------------------
#  Scale and Shift
# ---------------------------------------------------------------------------

scale_and_shift_tag = LayerTag(
    name="scale_and_shift_tag", num_inputs=1, num_outputs=1)


def register_scale_and_shift(y, args, has_scale: bool, has_shift: bool):
  """Tags `y` as a scale and/or shift of `args[0]` by the params `args[1:]`.

  Raises:
    ValueError: if neither `has_scale` nor `has_shift` is set.
  """
  # An explicit check instead of `assert`, which is stripped under `-O`.
  if not (has_scale or has_shift):
    raise ValueError("At least one of has_scale or has_shift must be True.")
  x, args = args[0], args[1:]
  return scale_and_shift_tag.bind(
      y, x, *args, has_scale=has_scale, has_shift=has_shift)


def scale_and_shift_func(x, params, has_scale: bool, has_shift: bool):
  """Example of a scale and shift function.

  Returns `x * scale + shift`, `x * scale` or `x + shift` depending on the
  flags; `params` holds the enabled parameters in (scale, shift) order.

  Raises:
    ValueError: if neither `has_scale` nor `has_shift` is set.
  """
  if has_scale and has_shift:
    scale, shift = params
    return x * scale + shift
  elif has_scale:
    return x * params[0]
  elif has_shift:
    return x + params[0]
  raise ValueError("At least one of has_scale or has_shift must be True.")


def scale_and_shift_tagging(
    jaxpr,
    inverse_map,
    values_map,
    has_scale: bool,
    has_shift: bool,
):
  """Correctly registers a scale and shift layer pattern."""
  del inverse_map
  in_values = [values_map[v] for v in jaxpr.invars]
  out_values = [values_map[v] for v in jaxpr.outvars]
  return register_scale_and_shift(out_values[0], in_values, has_scale,
                                  has_shift)


def batch_norm_func(
    inputs: Tuple[jnp.ndarray, jnp.ndarray],
    params: Tuple[jnp.ndarray, jnp.ndarray],
) -> jnp.ndarray:
  """Example of batch norm as is defined in Haiku.

  Returns `x * (scale * y) + shift` for `inputs = (x, y)`. `y` is presumably
  the inverse standard deviation in Haiku's batch norm — TODO confirm.
  """
  x, y = inputs
  scale, shift = params
  inv = scale * y
  return x * inv + shift


def batch_norm_tagging_func(
    jaxpr,
    inverse_map,
    values_map,
    has_scale: bool,
    has_shift: bool,
):
  """Correctly registers a batch norm layer pattern as is defined in Haiku."""
  del inverse_map
  in_values = [values_map[v] for v in jaxpr.invars]
  out_values = [values_map[v] for v in jaxpr.outvars]
  # The first two are both multipliers with the scale so we merge them
  in_values = [in_values[0] * in_values[1]] + in_values[2:]
  return register_scale_and_shift(out_values[0], in_values, has_scale,
                                  has_shift)

# --------------------------------------------------------------------------------
# /DeepSolid/pretrain.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.
import functools

import numpy as np
from absl import logging
import jax
import jax.numpy as jnp
import optax

from DeepSolid import hf
from DeepSolid import qmc
from DeepSolid import constants


def _batch_slater_slogdet(scf: hf.SCF, dim=3):
    """
    Wrap the Hartree-Fock ansatz as a batched `f(params, x)` function.
    :param scf: hf.SCF object in DeepSolid whose `eval_slogdet` is evaluated.
    :param dim: spatial dimension of each electron coordinate.
    :return: function of (params, x) that ignores `params`, reshapes the
        [batch, nelectron * dim] input to [batch, nelectron, dim] and returns
        the `[1]` component of `scf.eval_slogdet` (presumably log|psi|, by the
        slogdet convention; callers here double it to get a log-probability).
    """

    def batch_slater_slogdet(params, x):
        del params  # The HF ansatz has no trainable parameters.
        batch = x.shape[0]
        x = x.reshape([batch, -1, dim])
        return scf.eval_slogdet(x)[1]

    return batch_slater_slogdet


def _block_diagonal_target(target):
    """
    Merge per-spin HF orbital matrices into one block-diagonal matrix.
    :param target: list with one [batch, n, n] matrix per occupied spin channel.
    :return: single-element list holding the [batch, na+nb, na+nb] matrix with
        the alpha block top-left and the beta block bottom-right. When only one
        spin channel is occupied, its matrix already is the full target.
    """
    if len(target) == 1:
        # Fully spin-polarised system: nothing to merge (indexing target[1]
        # here used to raise IndexError).
        return list(target)
    alpha, beta = target
    batch_size = alpha.shape[0]
    na = alpha.shape[1]
    nb = beta.shape[1]
    top = jnp.concatenate((alpha, jnp.zeros((batch_size, na, nb))), axis=-1)
    bottom = jnp.concatenate((jnp.zeros((batch_size, nb, na)), beta), axis=-1)
    return [jnp.concatenate((top, bottom), axis=-2)]


def _make_fit_step(batch_orbitals, optimizer, full_det):
    """
    Build one optimizer step fitting the network orbitals to the HF orbitals.
    :param batch_orbitals: batched function return the orbital matrix of wavefunction
    :param optimizer: optax optimizer used for the parameter update.
    :param full_det: if True, the HF spin blocks are merged into a single
        block-diagonal target to match a dense network determinant.
    :return: fit_step(data, target, params, state) -> (params, state, loss),
        with loss and gradients averaged across devices when run under pmap.
    """

    def loss_fn(x, p, target):
        """
        loss function
        :param x: batched input data, a [batch, 3N] dimensional vector.
        :param p: A dictionary of parameters.
        :param target: corresponding HF matrix values.
        :return: mean squared deviation between network and HF orbitals.
        """
        predict = batch_orbitals(p, x)
        if full_det:
            target = _block_diagonal_target(target)
        # tar[:, None] broadcasts each HF matrix against every determinant.
        result = jnp.array([jnp.mean(jnp.abs(tar[:, None, ...] - pre) ** 2)
                            for tar, pre in zip(target, predict)]).mean()
        return constants.pmean_if_pmap(result, axis_name=constants.PMAP_AXIS_NAME)

    val_and_grad = jax.value_and_grad(loss_fn, argnums=1)

    def fit_step(data, target, params, state):
        loss_val, search_direction = val_and_grad(data, params, target)
        search_direction = constants.pmean_if_pmap(
            search_direction, axis_name=constants.PMAP_AXIS_NAME)
        updates, state = optimizer.update(search_direction, state, params)
        params = optax.apply_updates(params, updates)
        return params, state, loss_val

    return fit_step


def _hf_orbital_targets(scf_approx, data, cell, leading_shape):
    """
    Evaluate the HF orbital matrices at the current walker positions.
    :param scf_approx: hf.SCF object providing `eval_orb_mat`.
    :param data: electron positions of shape [*leading_shape, 3N].
    :param cell: pyscf object of simulation cell
    :param leading_shape: leading shape of `data` (e.g. (devices, batch)).
    :return: list with one [*leading_shape, n, n] matrix per spin channel that
        holds n > 0 electrons.
    """
    # PYSCF PBC eval_gto seems only accept float64 array, float32 array will
    # easily cause nan or underflow.
    positions = np.array(data.reshape([-1, cell.nelectron, 3]), dtype=np.float64)
    target = scf_approx.eval_orb_mat(positions)
    return [jnp.array(tar).reshape([*leading_shape, ne, ne])
            for tar, ne in zip(target, cell.nelec) if ne > 0]


def make_pretrain_step(batch_orbitals,
                       batch_network,
                       latvec,
                       optimizer,
                       full_det=False,
                       ):
    """
    generate the low-level pretrain function
    :param batch_orbitals: batched function return the orbital matrix of wavefunction
    :param batch_network: batched function return the slogdet of wavefunction
    :param latvec: lattice vector of primitive cell
    :param optimizer: optax optimizer used for the parameter update
    :param full_det: if True, the network orbitals form one dense determinant
        and the HF target is assembled block-diagonally to match.
    :return: the low-level pretrain function
    """
    fit_step = _make_fit_step(batch_orbitals, optimizer, full_det)

    def pretrain_step(data, target, params, state, key):
        """
        One iteration of pretraining to match HF, followed by one MCMC move
        sampled from the updated network.
        :param data: batched input data, a [batch, 3N] dimensional vector.
        :param target: corresponding HF matrix values.
        :param params: A dictionary of parameters.
        :param state: optimizer state.
        :param key: PRNG key.
        :return: moved data, pretrained params, state, loss value, log
            probability (2 * slogdet) of the neural network, and number of
            accepted MCMC moves.
        """
        params, state, loss_val = fit_step(data, target, params, state)
        logprob = 2 * batch_network(params, data)
        data, key, logprob, num_accepts = qmc.mh_update(params=params,
                                                        f=batch_network,
                                                        x1=data,
                                                        key=key,
                                                        lp_1=logprob,
                                                        num_accepts=0,
                                                        latvec=latvec)
        return data, params, state, loss_val, logprob, num_accepts

    return pretrain_step


def pretrain_hartree_fock(params,
                          data,
                          batch_network,
                          batch_orbitals,
                          sharded_key,
                          cell,
                          scf_approx: hf.SCF,
                          full_det=False,
                          iterations=1000,
                          learning_rate=5e-3,
                          ):
    """
    Pretrain the network orbitals to match Hartree-Fock; the walkers are
    sampled from the neural network.
    :param params: A dictionary of parameters.
    :param data: The input data, a 3N dimensional vector.
    :param batch_network: batched function return the slogdet of wavefunction
    :param batch_orbitals: batched function return the orbital matrix of wavefunction
    :param sharded_key: PRNG key
    :param cell: pyscf object of simulation cell
    :param scf_approx: hf.SCF object in DeepSolid. Used to eval the orbital value of Hartree Fock ansatz.
    :param full_det: If true, the determinants are dense, rather than block-sparse.
        False by default. The output shape of the orbitals is (ndet, nalpha+nbeta,
        nalpha+nbeta) if True, and (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta)
        if False.
    :param iterations: pretrain iterations
    :param learning_rate: learning rate of pretrain
    :return: pretrained parameters and electron positions.
    """
    optimizer = optax.adam(learning_rate)
    opt_state_pt = constants.pmap(optimizer.init)(params)
    leading_shape = data.shape[:-1]

    pretrain_step = make_pretrain_step(batch_orbitals=batch_orbitals,
                                       batch_network=batch_network,
                                       latvec=cell.lattice_vectors(),
                                       optimizer=optimizer,
                                       full_det=full_det,)
    pretrain_step = constants.pmap(pretrain_step)

    for t in range(iterations):
        target = _hf_orbital_targets(scf_approx, data, cell, leading_shape)
        slogprob_target = sum(2 * jnp.linalg.slogdet(tar)[1] for tar in target)
        sharded_key, subkeys = constants.p_split(sharded_key)
        data, params, opt_state_pt, loss, logprob, num_accepts = pretrain_step(
            data, target, params, opt_state_pt, subkeys)
        # num_accepts is per device, so normalise by the per-device batch.
        logging.info('Pretrain iter %05d: Loss=%03.6f, pmove=%0.2f, '
                     'Norm of Net prob=%03.4f, Norm of HF prob=%03.4f',
                     t, loss[0],
                     jnp.mean(num_accepts) / leading_shape[-1],
                     jnp.mean(logprob),
                     jnp.mean(slogprob_target))

    return params, data


def pretrain_hartree_fock_usingHF(params,
                                  data,
                                  batch_orbitals,
                                  sharded_key,
                                  cell,
                                  scf_approx: hf.SCF,
                                  iterations=1000,
                                  learning_rate=5e-3,
                                  nsteps=1,
                                  full_det=False,
                                  ):
    """
    Pretrain the network orbitals to match Hartree-Fock; the walkers are
    sampled from the HF ansatz.
    :param params: A dictionary of parameters.
    :param data: The input data, a 3N dimensional vector.
    :param batch_orbitals: batched function return the orbital matrix of wavefunction
    :param sharded_key: PRNG key
    :param cell: pyscf object of simulation cell
    :param scf_approx: hf.SCF object in DeepSolid. Used to eval the orbital value of Hartree Fock ansatz.
    :param iterations: pretrain iterations
    :param learning_rate: learning rate of pretrain
    :param nsteps: number of MCMC moves on the HF ansatz per iteration.
    :param full_det: If true, the determinants are dense, rather than block-sparse.
        False by default. The output shape of the orbitals is (ndet, nalpha+nbeta,
        nalpha+nbeta) if True, and (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta)
        if False.
    :return: pretrained parameters and electron positions.
    """
    optimizer = optax.adam(learning_rate)
    opt_state_pt = constants.pmap(optimizer.init)(params)
    leading_shape = data.shape[:-1]
    num_walkers = np.prod(leading_shape)
    latvec = cell.lattice_vectors()

    pretrain_step = constants.pmap(
        _make_fit_step(batch_orbitals, optimizer, full_det))
    # Built once instead of re-wrapping batch_orbitals on every iteration.
    pmapped_orbitals = constants.pmap(batch_orbitals)
    hf_network = _batch_slater_slogdet(scf_approx)
    logprob = 2 * hf_network(None, data.reshape([-1, cell.nelectron * 3]))
    num_accepts = 0.  # Defined even when nsteps == 0.

    for t in range(iterations):

        for _ in range(nsteps):
            sharded_key, subkeys = constants.p_split(sharded_key)
            # Consume the freshly split subkey. Feeding sharded_key[0] itself
            # (as before) reused a key that mh_update splits internally and
            # that is split again on the next iteration.
            data, _, logprob, num_accepts = qmc.mh_update(
                params,
                hf_network,
                data.reshape([-1, cell.nelectron * 3]),
                subkeys[0],
                logprob,
                0.,
                latvec=latvec,
            )

        data = data.reshape([*leading_shape, -1])
        target = _hf_orbital_targets(scf_approx, data, cell, leading_shape)

        slogprob_net = sum(2 * jnp.linalg.slogdet(net_mat)[1]
                           for net_mat in pmapped_orbitals(params, data))

        params, opt_state_pt, loss = pretrain_step(data, target, params, opt_state_pt)

        # The HF walkers are not sharded, so normalise by all walkers.
        logging.info('Pretrain iter %05d: Loss=%03.6f, pmove=%0.2f, '
                     'Norm of Net prob=%03.4f, Norm of HF prob=%03.4f',
                     t, loss[0],
                     jnp.mean(num_accepts) / num_walkers,
                     jnp.mean(slogprob_net),
                     jnp.mean(logprob))

    return params, data

# --------------------------------------------------------------------------------
# /DeepSolid/utils/kfac_ferminet_alpha/tracer.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”).
# All Bytedance Modifications are Copyright 2022 Bytedance Inc.
"""Module for the Jax tracer functionality for tags."""
import functools
from typing import Any, Callable, Sequence, Tuple

import jax
from jax import core
from jax import util as jax_util
import jax.numpy as jnp

from DeepSolid.utils.kfac_ferminet_alpha import layers_and_loss_tags as tags
from DeepSolid.utils.kfac_ferminet_alpha import tag_graph_matcher as tgm
from DeepSolid.utils.kfac_ferminet_alpha import utils
from DeepSolid.utils.kfac_ferminet_alpha import vjp_rc

# Generic callable type used in signatures below.
_Function = Callable[[Any], Any]
_Loss = tags.LossTag


def extract_tags(
    jaxpr: core.Jaxpr
) -> Tuple[Sequence[core.JaxprEqn], Sequence[core.JaxprEqn]]:
  """Extracts all of the tag equations.

  Args:
    jaxpr: The jaxpr to scan.

  Returns:
    A pair `(layer_tags, loss_tags)` of tuples holding, in program order, the
    equations whose primitive is a `tags.LayerTag` and a `tags.LossTag`
    respectively. Nothing is evaluated.
  """
  # Scan the equations and collect those whose primitive is a tag.
  layer_tags = []
  loss_tags = []
  for eqn in jaxpr.eqns:
    if isinstance(eqn.primitive, tags.LossTag):
      loss_tags.append(eqn)
    elif isinstance(eqn.primitive, tags.LayerTag):
      layer_tags.append(eqn)
  return tuple(layer_tags), tuple(loss_tags)


def construct_compute_losses_inputs(
    jaxpr: core.Jaxpr,
    consts: Sequence[Any],
    num_losses: int,
    primals: Any,
    params_index: int) -> Callable[[Any], Sequence[Sequence[jnp.ndarray]]]:
  """Constructs a function that computes all of the inputs to all losses.

  Args:
    jaxpr: The traced jaxpr of the tagged function.
    consts: Values bound to `jaxpr.constvars`.
    num_losses: Number of loss tags after which evaluation stops early; `None`
      evaluates every equation.
    primals: The positional arguments the jaxpr was traced with.
    params_index: Position in `primals` that is replaced by the argument of the
      returned function.

  Returns:
    A function of the parameters that re-evaluates the jaxpr equation by
    equation and returns, for every loss tag encountered, a tuple with the
    values of that tag's input variables.
  """
  # NOTE: this list is shared by every call of the returned function and is
  # mutated in place at `params_index`.
  primals_ = list(primals)

  def forward_compute_losses(
      params_primals: Any,
  ) -> Sequence[Sequence[jnp.ndarray]]:
    primals_[params_index] = params_primals
    flat_args = jax.tree_flatten(primals_)[0]
    # Mapping from variable -> value
    env = dict()
    read = functools.partial(tgm.read_env, env)
    write = functools.partial(tgm.write_env, env)

    # Bind args and consts to environment
    write(jax.core.unitvar, jax.core.unit)
    jax_util.safe_map(write, jaxpr.invars, flat_args)
    jax_util.safe_map(write, jaxpr.constvars, consts)

    # Loop through equations and evaluate them with `tgm.evaluate_eqn`,
    # stopping once `num_losses` loss tags have been seen.
    losses_so_far = 0
    loss_tags = []
    for eqn in jaxpr.eqns:
      tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
      if isinstance(eqn.primitive, tags.LossTag):
        loss_tags.append(eqn)
        losses_so_far += 1
      if num_losses is not None and losses_so_far == num_losses:
        break
    return tuple(tuple(read(v) for v in tag.invars) for tag in loss_tags)
  return forward_compute_losses


# We know when `.primitive` will be either a `LossTag` or a `LayerTag`, however
# pytype cannot infer its subclass, so we need to unbox it.


def _unbox_loss_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LossTag:
  """Returns the equation's primitive, asserting it is a `LossTag`."""
  assert isinstance(jaxpr_eqn.primitive, tags.LossTag)
  return jaxpr_eqn.primitive


def _unbox_layer_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LayerTag:
  """Returns the equation's primitive, asserting it is a `LayerTag`."""
  assert isinstance(jaxpr_eqn.primitive, tags.LayerTag)
  return jaxpr_eqn.primitive


def trace_losses_matrix_vector_vjp(tagged_func: _Function,
                                   params_index: int = 0):
  """Returns the Jacobian-transposed vector product (backward mode) function in equivalent form to jax.vjp.

  Args:
    tagged_func: A function annotated with loss tags.
    params_index: Index of the positional argument treated as the parameters.

  Returns:
    `vjp(*primals) -> (losses, vjp_func)`, where `losses` are the tagged loss
    instances and `vjp_func(tangents)` maps tangents for the loss inputs to a
    1-tuple holding the tangent w.r.t. `primals[params_index]`.
  """
  def vjp(*primals):
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    _, loss_jaxpr_eqns = extract_tags(jaxpr)
    n = len(loss_jaxpr_eqns)
    losses_func = construct_compute_losses_inputs(
        jaxpr, consts, n, primals, params_index)
    losses_inputs, full_vjp_func = jax.vjp(losses_func, primals[params_index])
    losses = []
    for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs):
      loss_tag = _unbox_loss_tag(jaxpr_eqn)
      losses.append(loss_tag.loss(*inputs, weight=jaxpr_eqn.params["weight"]))
    losses = tuple(losses)

    def vjp_func(tangents):
      flat_tangents = jax.tree_flatten(tangents)[0]
      loss_invars = []
      loss_targets = []
      for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs):
        num_inputs = _unbox_loss_tag(jaxpr_eqn).num_inputs
        loss_invars.append(tuple(jaxpr_eqn.invars[:num_inputs]))
        loss_targets.append(inputs[num_inputs:])
      treedef = jax.tree_structure(loss_invars)
      tangents = jax.tree_unflatten(treedef, flat_tangents)
      # The losses may also take targets as inputs. We don't want this function
      # to compute the vjp w.r.t. those (the user should provide tangent
      # vectors only for the inputs, not for the targets), so we manually fill
      # in these "extra" tangents with zeros.
      targets_tangents = jax.tree_map(jnp.zeros_like, loss_targets)
      # Tuple concatenation: input tangents followed by zero target tangents.
      tangents = tuple(ti + tti for ti, tti in zip(tangents, targets_tangents))
      input_tangents = full_vjp_func(tangents)[0]
      return input_tangents,
    return losses, vjp_func
  return vjp


def trace_losses_matrix_vector_jvp(
    tagged_func: _Function,
    params_index: int = 0):
  """Returns the Jacobian vector product (forward mode) function in equivalent form to jax.jvp.

  Args:
    tagged_func: A function annotated with loss tags.
    params_index: Index of the positional argument treated as the parameters.

  Returns:
    `jvp(primals, params_tangents) -> (losses, tangents_out)`, where
    `tangents_out` keeps, per loss, only the tangents of its first
    `num_inputs` inputs (the targets are dropped).
  """
  def jvp(primals, params_tangents):
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    _, loss_tags = extract_tags(jaxpr)
    n = len(loss_tags)
    losses_func = construct_compute_losses_inputs(jaxpr, consts, n,
                                                  primals, params_index)
    primals = (primals[params_index],)
    tangents = (params_tangents,)
    (primals_out, tangents_out) = jax.jvp(losses_func, primals, tangents)
    tangents_out = tuple(tuple(t[:tag.primitive.num_inputs])
                         for t, tag in zip(tangents_out, loss_tags))
    losses = tuple(tag.primitive.loss(*inputs, weight=tag.params["weight"])
                   for tag, inputs in zip(loss_tags, primals_out))
    return losses, tangents_out
  return jvp
def trace_losses_matrix_vector_hvp(tagged_func, params_index=0):
  """Returns the Hessian vector product function of **the tagged losses**, rather than the output value of `tagged_func`.

  Args:
    tagged_func: A function annotated with loss tags.
    params_index: Index of the positional argument treated as the parameters.

  Returns:
    `hvp(primals, params_tangents)`, the gradient w.r.t. the parameters of
    `<grad(sum of tagged losses), params_tangents>`.
  """
  # Computes grad(<grad(L), v>), i.e. reverse-over-reverse differentiation
  # (two nested `jax.grad` calls).

  def hvp(primals, params_tangents):
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    _, loss_tags = extract_tags(jaxpr)
    n = len(loss_tags)
    losses_func = construct_compute_losses_inputs(
        jaxpr, consts, n, primals, params_index)

    def losses_sum(param_primals):
      loss_inputs = losses_func(param_primals)
      losses = [
          _unbox_loss_tag(jaxpr_eqn).loss(
              *inputs, weight=jaxpr_eqn.params["weight"])
          for jaxpr_eqn, inputs in zip(loss_tags, loss_inputs)
      ]
      # This computes the sum of losses evaluated. Makes it easier as we can
      # now use jax.grad rather than jax.vjp for taking derivatives.
      return sum(jnp.sum(loss.evaluate(None)) for loss in losses)

    def grads_times_tangents(params_primals):
      grads = jax.grad(losses_sum)(params_primals)
      return utils.inner_product(grads, params_tangents)

    return jax.grad(grads_times_tangents)(primals[params_index])
  return hvp


def trace_estimator_vjp(tagged_func: _Function) -> _Function:
  """Creates the function needed for an estimator of curvature matrices.

  Args:
    tagged_func: A function that has been annotated with tags both for layers
      and losses.

  Returns:
    A function taking a single argument `func_args` (the positional arguments
    of `tagged_func`, as a sequence), which returns two things:
      1. The instances of all loss objects that are tagged.
      2. A second function, which when provided with tangent vectors for each
         of the loss instances' inputs, returns for every tagged layer a
         dictionary containing the following elements:
           outputs - The primal values of the outputs of the layer.
           inputs - The primal values of the inputs to the layer.
           params - The primal values of the layer's parameters.
           outputs_tangent - The tangent value of the layer outputs, given the
             provided tangents of the losses.
           inputs_tangent - The tangent value of the layer inputs, given the
             provided tangents of the losses.
           params_tangent - The tangent value of the layer parameters, given
             the provided tangents of the losses.
  """
  def full_vjp_func(func_args):
    # Trace the tagged function
    typed_jaxpr = jax.make_jaxpr(tagged_func)(*func_args)
    jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals
    layer_tags, loss_tags = extract_tags(jaxpr)

    # Deduplicated variables that feed any layer tag. The order is arbitrary,
    # but it is used consistently below.
    layer_vars_flat = jax.tree_flatten([tag.invars for tag in layer_tags])[0]
    layer_input_vars = tuple(set(layer_vars_flat))

    def forward():
      """Evaluates the jaxpr up to the last loss tag and returns the values of
      `layer_input_vars`."""
      own_func_args = func_args
      # Mapping from variable -> value
      env = dict()
      read = functools.partial(tgm.read_env, env)
      write = functools.partial(tgm.write_env, env)

      # Bind args and consts to environment
      write(jax.core.unitvar, jax.core.unit)
      jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
      jax_util.safe_map(write, jaxpr.constvars, consts)

      # Loop through equations and evaluate them, stopping after the last loss.
      num_losses_passed = 0
      for eqn in jaxpr.eqns:
        tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
        if isinstance(eqn.primitive, tags.LossTag):
          num_losses_passed += 1
          if num_losses_passed == len(loss_tags):
            break
      if num_losses_passed != len(loss_tags):
        raise ValueError("This should be unreachable.")

      return jax_util.safe_map(read, layer_input_vars)

    def forward_aux(aux):
      """Like `forward`, but adds `aux[var]` to every written variable found in
      `aux`. Differentiating w.r.t. these zero offsets yields the tangents of
      the layer-tag variables."""
      own_func_args = func_args
      # Mapping from variable -> value
      env = dict()
      read = functools.partial(tgm.read_env, env)
      def write(var, val):
        # Literals and the unit var cannot be perturbed.
        if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)):
          val = val + aux[var] if var in aux else val
        env[var] = val

      # Bind args and consts to environment
      write(jax.core.unitvar, jax.core.unit)
      jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0])
      jax_util.safe_map(write, jaxpr.constvars, consts)

      # Loop through equations and evaluate them, recording each loss's inputs.
      num_losses_passed = 0
      losses_inputs_values = []
      losses_kwargs_values = []
      for eqn in jaxpr.eqns:
        input_values = jax_util.safe_map(read, eqn.invars)
        tgm.evaluate_eqn(eqn, input_values, write)
        if isinstance(eqn.primitive, tags.LossTag):
          loss = eqn.primitive.loss(*input_values, weight=eqn.params["weight"])
          losses_inputs_values.append(loss.inputs)
          losses_kwargs_values.append(dict(
              targets=loss.targets,
              weight=eqn.params["weight"]
          ))
          num_losses_passed += 1
          if num_losses_passed == len(loss_tags):
            break
      if num_losses_passed != len(loss_tags):
        raise ValueError("This should be unreachable.")
      # Read the inputs to the loss functions, but also return the target values
      return tuple(losses_inputs_values), tuple(losses_kwargs_values)

    layer_input_values = forward()
    primals_dict = dict(zip(layer_input_vars, layer_input_values))
    primals_dict.update(zip(jaxpr.invars, jax.tree_flatten(func_args)[0]))
    # Zero offsets: forward_aux evaluated at aux_dict equals the plain forward.
    aux_values = jax.tree_map(lambda x:jnp.zeros_like(x), layer_input_values)
    aux_dict = dict(zip(layer_input_vars, aux_values))

    losses_args, aux_vjp, losses_kwargs = vjp_rc.vjp_rc(forward_aux, aux_dict,
                                                       has_aux=True)
    losses = tuple(tag.primitive.loss(*inputs, **kwargs)
                   for tag, inputs, kwargs in
                   zip(loss_tags, losses_args, losses_kwargs))

    def vjp_func(tangents):
      # Cast tangents to complex; presumably vjp_rc expects complex cotangents
      # — NOTE(review): confirm against vjp_rc.
      tangents = jax.tree_map(lambda x:x+0j, tangents)
      all_tangents = aux_vjp(tangents)
      tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
      inputs_tangents = jax.tree_flatten(inputs_tangents)[0]
      tangents_dict.update(zip(jaxpr.invars, inputs_tangents))

      read_primals = functools.partial(tgm.read_env, primals_dict)
      read_tangents = functools.partial(tgm.read_env, tangents_dict)
      layers_info = []
      for jaxpr_eqn in layer_tags:
        layer_tag = _unbox_layer_tag(jaxpr_eqn)
        info = dict()
        primals = jax_util.safe_map(read_primals, tuple(jaxpr_eqn.invars))
        (
            info["outputs"],
            info["inputs"],
            info["params"],
        ) = layer_tag.split_all_inputs(primals)
        tangents = jax_util.safe_map(read_tangents, tuple(jaxpr_eqn.invars))
        (
            info["outputs_tangent"],
            info["inputs_tangent"],
            info["params_tangent"],
        ) = layer_tag.split_all_inputs(tangents)
        layers_info.append(info)
      return tuple(layers_info)

    return losses, vjp_func
  return full_vjp_func

# --------------------------------------------------------------------------------
# /DeepSolid/qmc.py:
# --------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). 16 | # All Bytedance Modifications are Copyright 2022 Bytedance Inc. 17 | 18 | import logging 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | from DeepSolid import constants 23 | from DeepSolid import distance 24 | 25 | 26 | def _log_prob_gaussian(x, mu, sigma): 27 | """Calculates the log probability of Gaussian with diagonal covariance. 28 | 29 | Args: 30 | x: Positions. Shape (batch, nelectron, 1, ndim) - as used in mh_update. 31 | mu: means of Gaussian distribution. Same shape as or broadcastable to x. 32 | sigma: standard deviation of the distribution. Same shape as or 33 | broadcastable to x. 34 | 35 | Returns: 36 | Log probability of Gaussian distribution with shape as required for 37 | mh_update - (batch, nelectron, 1, 1). 38 | """ 39 | numer = jnp.sum(-0.5 * ((x - mu) ** 2) / (sigma ** 2), axis=[1, 2, 3]) 40 | denom = x.shape[-1] * jnp.sum(jnp.log(sigma), axis=[1, 2, 3]) 41 | return numer - denom 42 | 43 | 44 | def _harmonic_mean(x, atoms): 45 | """Calculates the harmonic mean of each electron distance to the nuclei. 46 | 47 | Args: 48 | x: electron positions. Shape (batch, nelectrons, 1, ndim). Note the third 49 | dimension is already expanded, which allows for avoiding additional 50 | reshapes in the MH algorithm. 51 | atoms: atom positions. Shape (natoms, ndim) 52 | 53 | Returns: 54 | Array of shape (batch, nelectrons, 1, 1), where the (i, j, 0, 0) element is 55 | the harmonic mean of the distance of the j-th electron of the i-th MCMC 56 | configuration to all atoms. 57 | """ 58 | ae = x - atoms[None, ...] 
59 | r_ae = jnp.linalg.norm(ae, axis=-1, keepdims=True) 60 | return 1.0 / jnp.mean(1.0 / r_ae, axis=-2, keepdims=True) 61 | 62 | 63 | def limdrift(g:jnp.array, cutoff=1): 64 | """ 65 | Limit a vector to have a maximum magnitude of cutoff while maintaining direction 66 | 67 | Args: 68 | g: a [nconf,ndim] vector 69 | 70 | cutoff: the maximum magnitude 71 | 72 | Returns: 73 | The vector with the cut off applied. 74 | """ 75 | g_shape = g.shape 76 | g = g.reshape([-1, 3]) 77 | tot = jnp.linalg.norm(g, axis=-1) 78 | normalize = jnp.clip(tot, a_min=cutoff, a_max=jnp.max(tot)) 79 | g = cutoff * g / normalize[:, None] 80 | g = g.reshape(g_shape) 81 | return g 82 | 83 | def importance_update(params, 84 | f, 85 | x1, 86 | key, 87 | lp_1, 88 | num_accepts, 89 | latvec, 90 | stddev=0.02, 91 | atoms=None, 92 | i=0, 93 | ): 94 | """ 95 | Performs one importance sampling step using an all-electron move. 96 | :param params: a dictionary of parameters. 97 | :param f: val_and_grad of batch_slogdet 98 | :param x1: Initial MCMC configurations. Shape (batch, nelectrons*ndim). 99 | :param key: PRNG key. 100 | :param lp_1: slogdet of wavefunction at original position x1. 101 | :param num_accepts: number of accepted MCMC moves. 102 | :param latvec: lattice vector of primitive cell. 103 | :param stddev: MCMC move width. 104 | :param atoms:atoms positions in the primitive cell 105 | :param i: 106 | :return: moved electron position x_new, key, slogdet value of x_new, and number of accepted MCMC moves. 
107 | """ 108 | del i 109 | key, subkey = jax.random.split(key) 110 | if atoms is None: # symmetric proposal, same stddev everywhere 111 | _, grad = f(params, x1) 112 | grad = limdrift(grad) 113 | gauss = stddev * jax.random.normal(subkey, shape=x1.shape) 114 | x2 = x1 + gauss + stddev**2 * grad # proposal 115 | x2, _ = distance.enforce_pbc(latvec, x2) 116 | 117 | # Compute reverse move 118 | lpsi_2, new_grad = f(params, x2) 119 | lp_2 = 2 * lpsi_2 120 | new_grad = limdrift(new_grad) 121 | forward = jnp.sum(gauss ** 2, axis=-1) 122 | backward = jnp.sum((gauss + stddev**2 * (grad + new_grad)) ** 2, 123 | axis=-1) 124 | lp_2 = lp_2 + 1 / (2 * stddev**2) * (forward - backward) 125 | 126 | ratio = lp_2 - lp_1 127 | else: # asymmetric proposal, stddev propto harmonic mean of nuclear distances 128 | n = x1.shape[0] 129 | x1 = jnp.reshape(x1, [n, -1, 1, 3]) 130 | hmean1 = _harmonic_mean(x1, atoms) # harmonic mean of distances to nuclei 131 | 132 | x2 = x1 + stddev * hmean1 * jax.random.normal(subkey, shape=x1.shape) 133 | lp_2 = 2. 
* f(params, x2) # log prob of proposal 134 | hmean2 = _harmonic_mean(x2, atoms) # needed for probability of reverse jump 135 | 136 | lq_1 = _log_prob_gaussian(x1, x2, stddev * hmean1) # forward probability 137 | lq_2 = _log_prob_gaussian(x2, x1, stddev * hmean2) # reverse probability 138 | ratio = lp_2 + lq_2 - lp_1 - lq_1 139 | 140 | x1 = jnp.reshape(x1, [n, -1]) 141 | x2 = jnp.reshape(x2, [n, -1]) 142 | 143 | key, subkey = jax.random.split(key) 144 | rnd = jnp.log(jax.random.uniform(subkey, shape=lp_1.shape)) 145 | cond = ratio > rnd 146 | x_new = jnp.where(cond[..., None], x2, x1) 147 | lp_new = jnp.where(cond, lp_2, lp_1) 148 | num_accepts += jnp.sum(cond) 149 | 150 | return x_new, key, lp_new, num_accepts 151 | 152 | 153 | def mh_update(params, 154 | f, 155 | x1, 156 | key, 157 | lp_1, 158 | num_accepts, 159 | latvec, 160 | stddev=0.02, 161 | atoms=None, 162 | i=0, 163 | ): 164 | """Performs one Metropolis-Hastings step using an all-electron move. 165 | 166 | Args: 167 | params: Wavefuncttion parameters. 168 | f: Callable with signature f(params, x) which returns the log of the 169 | wavefunction (i.e. the sqaure root of the log probability of x). 170 | x1: Initial MCMC configurations. Shape (batch, nelectrons*ndim). 171 | key: RNG state. 172 | lp_1: log probability of f evaluated at x1 given parameters params. 173 | num_accepts: Number of MH move proposals accepted. 174 | latvec: lattice vector of primitive cell. 175 | stddev: width of Gaussian move proposal. 176 | atoms: If not None, atom positions. Shape (natoms, 3). If present, then the 177 | Metropolis-Hastings move proposals are drawn from a Gaussian distribution, 178 | N(0, (h_i stddev)^2), where h_i is the harmonic mean of distances between 179 | the i-th electron and the atoms, otherwise the move proposal drawn from 180 | N(0, stddev^2). 181 | 182 | Returns: 183 | (x, key, lp, num_accepts), where: 184 | x: Updated MCMC configurations. 185 | key: RNG state. 
186 | lp: log probability of f evaluated at x. 187 | num_accepts: update running total of number of accepted MH moves. 188 | """ 189 | del i 190 | key, subkey = jax.random.split(key) 191 | if atoms is None: # symmetric proposal, same stddev everywhere 192 | x2 = x1 + stddev * jax.random.normal(subkey, shape=x1.shape) # proposal 193 | x2, _ = distance.enforce_pbc(latvec, x2) 194 | # reduce the electrons into the simulation cell. 195 | lp_2 = 2. * f(params, x2) # log prob of proposal 196 | ratio = lp_2 - lp_1 197 | else: # asymmetric proposal, stddev propto harmonic mean of nuclear distances 198 | n = x1.shape[0] 199 | x1 = jnp.reshape(x1, [n, -1, 1, 3]) 200 | hmean1 = _harmonic_mean(x1, atoms) # harmonic mean of distances to nuclei 201 | 202 | x2 = x1 + stddev * hmean1 * jax.random.normal(subkey, shape=x1.shape) 203 | x2 = jnp.reshape(x2, [n, -1]) 204 | x2, _ = distance.enforce_pbc(latvec, x2) 205 | lp_2 = 2. * f(params, x2) 206 | 207 | x2 = jnp.reshape(x2, [n, -1, 1, 3]) 208 | hmean2 = _harmonic_mean(x2, atoms) # needed for probability of reverse jump 209 | 210 | lq_1 = _log_prob_gaussian(x1, x2, stddev * hmean1) # forward probability 211 | lq_2 = _log_prob_gaussian(x2, x1, stddev * hmean2) # reverse probability 212 | ratio = lp_2 + lq_2 - lp_1 - lq_1 213 | 214 | x1 = jnp.reshape(x1, [n, -1]) 215 | x2 = jnp.reshape(x2, [n, -1]) 216 | 217 | key, subkey = jax.random.split(key) 218 | rnd = jnp.log(jax.random.uniform(subkey, shape=lp_1.shape)) 219 | cond = ratio > rnd 220 | x_new = jnp.where(cond[..., None], x2, x1) 221 | lp_new = jnp.where(cond, lp_2, lp_1) 222 | num_accepts += jnp.sum(cond) 223 | 224 | return x_new, key, lp_new, num_accepts 225 | 226 | 227 | def mh_one_electron_update(params, 228 | f, 229 | x1, 230 | key, 231 | lp_1, 232 | num_accepts, 233 | latvec, 234 | stddev=0.02, 235 | atoms=None, 236 | i=0): 237 | """Performs one Metropolis-Hastings step for a single electron. 238 | 239 | Args: 240 | params: Wavefuncttion parameters. 
241 | f: Callable with signature f(params, x) which returns the log of the 242 | wavefunction (i.e. the sqaure root of the log probability of x). 243 | x1: Initial MCMC configurations. Shape (batch, nelectrons*ndim). 244 | key: RNG state. 245 | lp_1: log probability of f evaluated at x1 given parameters params. 246 | num_accepts: Number of MH move proposals accepted. 247 | latvec: lattice vector of primitive cell. 248 | stddev: width of Gaussian move proposal. 249 | atoms: Ignored. Asymmetric move proposals are not implemented for 250 | single-electron moves. 251 | i: index of electron to move. 252 | 253 | Returns: 254 | (x, key, lp, num_accepts), where: 255 | x: Updated MCMC configurations. 256 | key: RNG state. 257 | lp: log probability of f evaluated at x. 258 | num_accepts: update running total of number of accepted MH moves. 259 | 260 | Raises: 261 | NotImplementedError: if atoms is supplied. 262 | """ 263 | key, subkey = jax.random.split(key) 264 | n = x1.shape[0] 265 | x1 = jnp.reshape(x1, [n, -1, 1, 3]) 266 | nelec = x1.shape[1] 267 | ii = i % nelec 268 | if atoms is None: # symmetric proposal, same stddev everywhere 269 | x2 = x1.at[:, ii].add(stddev * 270 | jax.random.normal(subkey, shape=x1[:, ii].shape)) 271 | x2, _ = distance.enforce_pbc(latvec, x2) 272 | lp_2 = 2. 
* f(params, x2) # log prob of proposal 273 | ratio = lp_2 - lp_1 274 | else: # asymmetric proposal, stddev propto harmonic mean of nuclear distances 275 | raise NotImplementedError('Still need to work out reverse probabilities ' 276 | 'for asymmetric moves.') 277 | 278 | x1 = jnp.reshape(x1, [n, -1]) 279 | x2 = jnp.reshape(x2, [n, -1]) 280 | key, subkey = jax.random.split(key) 281 | rnd = jnp.log(jax.random.uniform(subkey, shape=lp_1.shape)) 282 | cond = ratio > rnd 283 | x_new = jnp.where(cond[..., None], x2, x1) 284 | lp_new = jnp.where(cond, lp_2, lp_1) 285 | num_accepts += jnp.sum(cond) 286 | 287 | return x_new, key, lp_new, num_accepts 288 | 289 | 290 | def make_mcmc_step(batch_slog_network, 291 | batch_per_device, 292 | latvec, 293 | steps=10, 294 | atoms=None, 295 | importance_sampling=None, 296 | one_electron_moves=False, 297 | ): 298 | """Creates the MCMC step function. 299 | 300 | Args: 301 | batch_slog_network: function, signature (params, x), which evaluates the log of 302 | the wavefunction (square root of the log probability distribution) at x 303 | given params. Inputs and outputs are batched. 304 | batch_per_device: Batch size per device. 305 | latvec: lattice vector of primitive cell. 306 | steps: Number of MCMC moves to attempt in a single call to the MCMC step 307 | function. 308 | atoms: atom positions. If given, an asymmetric move proposal is used based 309 | on the harmonic mean of electron-atom distances for each electron. 310 | Otherwise the (conventional) normal distribution is used. 311 | importance_sampling: if true, importance sampling is used for MCMC. 312 | Otherwise, Metropolis method is used. 313 | one_electron_moves: If true, attempt to move one electron at a time. 314 | Otherwise, attempt one all-electron move per MCMC step. 315 | 316 | Returns: 317 | Callable which performs the set of MCMC steps. 
318 | """ 319 | if importance_sampling is not None: 320 | if one_electron_moves: 321 | raise ValueError('Importance sampling for one elec move is not implemented yet') 322 | else: 323 | logging.info('Using importance sampling') 324 | func = jax.vmap(jax.value_and_grad(importance_sampling, argnums=1), in_axes=(None, 0)) 325 | inner_fun = importance_update 326 | else: 327 | func = batch_slog_network 328 | if one_electron_moves: 329 | logging.info('Using one electron Metropolis sampling') 330 | inner_fun = mh_one_electron_update 331 | else: 332 | logging.info('Using Metropolis sampling') 333 | inner_fun = mh_update 334 | 335 | @jax.jit 336 | def mcmc_step(params, data, key, width): 337 | """Performs a set of MCMC steps. 338 | 339 | Args: 340 | params: parameters to pass to the network. 341 | data: (batched) MCMC configurations to pass to the network. 342 | key: RNG state. 343 | width: standard deviation to use in the move proposal. 344 | 345 | Returns: 346 | (data, pmove), where data is the updated MCMC configurations, key the 347 | updated RNG state and pmove the average probability a move was accepted. 348 | """ 349 | 350 | def step_fn(i, x): 351 | return inner_fun(params, func, *x, 352 | latvec=latvec, stddev=width, 353 | atoms=atoms, i=i) 354 | 355 | nelec = data.shape[-1] // 3 356 | nsteps = nelec * steps if one_electron_moves else steps 357 | logprob = 2. * batch_slog_network(params, data) 358 | data, key, _, num_accepts = jax.lax.fori_loop(0, nsteps, step_fn, 359 | (data, key, logprob, 0.)) 360 | pmove = jnp.sum(num_accepts) / (nsteps * batch_per_device) 361 | pmove = constants.pmean_if_pmap(pmove, axis_name=constants.PMAP_AXIS_NAME) 362 | return data, pmove 363 | 364 | return mcmc_step 365 | --------------------------------------------------------------------------------