├── .gitignore ├── jaqmc ├── __init__.py ├── dmc │ ├── __init__.py │ ├── utils.py │ ├── effective_time_step_calculator.py │ ├── ckpt_handler.py │ ├── state.py │ ├── data_path.py │ ├── energy_estimator.py │ ├── branch.py │ ├── metric_manager.py │ ├── branch_fix_size.py │ ├── hamiltonian.py │ └── storage_handler.py ├── pp │ ├── special.py │ ├── pp_config.py │ ├── utils │ │ └── init_electrons.py │ ├── ph │ │ ├── data.py │ │ └── hamiltonian.py │ ├── ecp_potential.py │ └── quadrature.py └── loss │ ├── utils.py │ ├── factory.py │ └── vmc.py ├── examples ├── loss │ ├── loss_config.py │ └── lapnet │ │ ├── atom_spin_state.py │ │ └── run.py ├── pp │ └── lapnet │ │ ├── configs │ │ ├── ecp │ │ │ └── X.py │ │ └── ph │ │ │ ├── X.py │ │ │ └── XS.py │ │ └── run.py └── dmc │ ├── deeperwin │ ├── deeperwin_model.py │ └── run.py │ ├── dmc_config.py │ ├── lapnet │ └── run.py │ └── ferminet │ └── run.py ├── CONTRIBUTING.md ├── setup.py ├── tests └── dmc │ ├── utils_single_host_test.py │ ├── utils_two_hosts_test.py │ ├── metric_manager_test.py │ ├── branch_fix_size_test.py │ ├── branch_test.py │ ├── state_test.py │ ├── ckpt_handler_test.py │ ├── storage_handler_test.py │ ├── data_path_test.py │ ├── hdfs_storage_handler_test.py │ ├── energy_estimator_test.py │ ├── recovery_test.py │ ├── dmc_test.py │ └── position_update_test.py ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | *.csv 4 | *.npz 5 | __pycache__/ 6 | -------------------------------------------------------------------------------- /jaqmc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /jaqmc/dmc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dmc import run 8 | -------------------------------------------------------------------------------- /examples/loss/loss_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | ''' 8 | Loss configurations. 9 | ''' 10 | 11 | from jaqmc.loss.spin_penalty import DEFAULT_SPIN_PENALTY_CONFIG 12 | 13 | def get_config(): 14 | return {'enforce_spin': DEFAULT_SPIN_PENALTY_CONFIG} 15 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to JaQMC 2 | 3 | Contributions are welcomed and highly appreciated! 4 | 5 | We welcome [pull requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests). 6 | When contributing new features or bug fixes, please 7 | 1. Conform to [Python Style Guide](https://google.github.io/styleguide/pyguide.html). 8 | 2. Add unittests if possible. 9 | 3. Remove all (accidentally included) sensitive data or information. 10 | 4. Have sensible PR title and commit messages. 
def legendre(x, l):
    '''
    Evaluate the Legendre polynomial of degree `l` at `x`.

    Degrees 0, 1 and 2 use their closed forms; higher degrees are built with
    Bonnet's recurrence
        (n + 1) * P_{n+1}(x) = (2n + 1) * x * P_n(x) - n * P_{n-1}(x).

    Args:
      x: point(s) at which to evaluate the polynomial.
      l: non-negative integer degree.

    Returns:
      P_l(x), element-wise over `x`.

    Raises:
      ValueError: if `l` is negative or not an integer value. (Previously an
        unsupported degree silently returned `None`.)
    '''
    degree = int(l)
    if degree != l or degree < 0:
        raise ValueError(
            f'Legendre degree must be a non-negative integer, got {l!r}.')

    if degree == 0:
        return jnp.ones_like(x)
    if degree == 1:
        return x

    # Start from P_1 and the closed form of P_2, then climb one degree per step.
    p_prev, p_curr = x, (3 * x**2 - 1) / 2
    for n in range(2, degree):
        p_prev, p_curr = p_curr, ((2 * n + 1) * x * p_curr - n * p_prev) / (n + 1)
    return p_curr


def legendre_list(x, l_list):
    '''
    Evaluate several Legendre polynomials at `x` and stack them.

    Args:
      x: array of evaluation points.
      l_list: iterable of non-negative integer degrees.

    Returns:
      Array whose leading axis indexes `l_list` and whose remaining axes are
      those of `x`.
    '''
    # A new leading axis on each result lets concatenate act as a stack.
    return jnp.concatenate([legendre(x, l)[None, ...] for l in l_list], axis=0)
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pyscf import gto 8 | 9 | from lapnet import base_config 10 | from jaqmc.pp.pp_config import get_config as get_ecp_config 11 | 12 | def get_config(input_str): 13 | symbol, spin = input_str.split(',') 14 | spin = int(spin) 15 | 16 | cfg = base_config.default() 17 | cfg['ecp'] = get_ecp_config() 18 | mol = gto.Mole() 19 | # Set up molecule 20 | mol.build( 21 | atom=f'{symbol} 0 0 0', 22 | basis={symbol: 'ccecpccpvdz'}, 23 | ecp={symbol: 'ccecp'}, 24 | spin=spin) 25 | 26 | cfg.system.pyscf_mol = mol 27 | return cfg -------------------------------------------------------------------------------- /examples/pp/lapnet/configs/ph/X.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pyscf import gto 8 | 9 | from lapnet import base_config 10 | from jaqmc.pp.ph.data import PH_config 11 | from jaqmc.pp.pp_config import get_config as get_ecp_config 12 | 13 | @PH_config 14 | def get_config(input_str): 15 | symbol, spin = input_str.split(',') 16 | spin = int(spin) 17 | 18 | cfg = base_config.default() 19 | cfg['ecp'] = get_ecp_config() 20 | mol = gto.Mole() 21 | # Set up molecule 22 | mol.build( 23 | atom=f'{symbol} 0 0 0', 24 | basis={symbol: 'ccecpccpvdz'}, 25 | ecp={symbol: 'ccecp'}, 26 | spin=spin) 27 | 28 | cfg.system.pyscf_mol = mol 29 | return cfg -------------------------------------------------------------------------------- /tests/dmc/utils_single_host_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 
def atom_set(cfg, spin):
    '''
    Fill in the single-atom system fields of `cfg` for a given spin.

    Reads `cfg.system.atom`, builds the corresponding `system.Atom`, and writes
    `cfg.system.molecule`, `cfg.system.electrons` and
    `cfg.system.atom_spin_configs`.

    Args:
      cfg: config object whose `system.atom` holds the atom symbol.
      spin: number of spin-up minus number of spin-down electrons.

    Returns:
      The same `cfg`, updated in place.

    Raises:
      ValueError: if `spin` has the wrong parity for `atom.charge` electrons,
        or if `|spin|` exceeds `atom.charge` (which would otherwise produce a
        negative electron count).
    '''
    atom = system.Atom(cfg.system.atom)
    cfg.system.molecule = [atom]

    # NOTE(review): `atom.charge` is treated as the electron count here, as the
    # error message below already does - confirm for non-neutral setups.
    if (atom.charge - spin) % 2 != 0:
        raise ValueError(f"Wrongly assign spin! Difference between spin-up and -down cannot be {spin} for {atom.charge} electrons!")
    if abs(spin) > atom.charge:
        raise ValueError(f"Wrongly assign spin! |spin| = {abs(spin)} exceeds the {atom.charge} electrons available!")

    # alpha - beta == spin and alpha + beta == atom.charge.
    beta = int((atom.charge - spin) // 2)
    alpha = int(spin + (atom.charge - spin) // 2)
    cfg.system.electrons = [alpha, beta]
    cfg.system.atom_spin_configs = [[alpha, beta],]
    return cfg


def get_config(input_str):
    '''
    Build a single-atom config from a string of the form "<symbol>,<spin>".

    Args:
      input_str: comma-separated atom symbol and integer spin, e.g. "N,3".

    Returns:
      The default lapnet config with the atom and its spin assignment set.
    '''
    name, spin = input_str.split(',')
    spin = int(spin)
    cfg = base_config.default()
    cfg.system.atom = name
    cfg = atom_set(cfg, spin)
    return cfg
21 | cfg = base_config.default() 22 | cfg['ecp'] = get_ecp_config() 23 | 24 | mol = gto.Mole() 25 | 26 | # Set up molecule 27 | mol.build( 28 | atom=f'{symbol} 0 0 0; S 0 0 {dist}', 29 | basis={symbol: 'ccecpccpvdz', 'S': 'ccecpccpvdz'}, 30 | ecp={symbol: 'ccecp', 'S': 'ccecp'}, 31 | spin=int(spin), unit=unit) 32 | 33 | cfg.system.pyscf_mol = mol 34 | cfg.system.atom_spin_configs = [(int(Xup), int(Xdn)), (int(Yup), int(Ydn))] 35 | cfg.ecp.ph_elements = (symbol, 'S') 36 | return cfg -------------------------------------------------------------------------------- /tests/dmc/utils_two_hosts_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from absl.testing import absltest 8 | import jax.test_util as jtu 9 | import jax.numpy as jnp 10 | 11 | from jaqmc.dmc.utils import agg_sum, agg_mean 12 | from jaqmc.dmc.runner import initialize_distributed_runtime 13 | 14 | NUM_HOST = 2 15 | 16 | class UtilsTwoHostsTest(jtu.JaxTestCase): 17 | def test_agg_mean(self): 18 | x = jnp.arange(1, 10) 19 | actual_mean = agg_mean(x) 20 | 21 | self.assertEqual(actual_mean, 5.0) 22 | 23 | def test_agg_mean_weighted(self): 24 | x = jnp.arange(1, 10) 25 | weights = jnp.array([1.0] * 5 + [0.0] * 4) 26 | actual_mean = agg_mean(x, weights=weights) 27 | 28 | self.assertEqual(actual_mean, 3) 29 | 30 | def test_agg_sum(self): 31 | x = jnp.arange(1, 101) 32 | actual_sum = agg_sum(x) 33 | 34 | single_host_result = 5050 35 | self.assertEqual(actual_sum, single_host_result * NUM_HOST) 36 | 37 | if __name__ == '__main__': 38 | initialize_distributed_runtime() 39 | absltest.main() 40 | -------------------------------------------------------------------------------- /tests/dmc/metric_manager_test.py: 
class BranchFixSizeTest(chex.TestCase):
    '''Checks fixed-size branching against a small hand-computed walker set.'''

    def test_branch(self):
        # The first walker is 0-weight and therefore it's expected to be discarded.
        test_weight = jnp.array([0.0, 0.1, 0.5, 0.5, 4.0])
        # The key is still drawn from the wall clock, but the seed is kept so
        # that it can be reported: a failure can then be replayed with the
        # exact same key instead of being unreproducible.
        seed = int(1e6 * time.time())
        key = jax.random.PRNGKey(seed)
        before_branch_array = jnp.array([1, 2, 3, 4, 5])

        # Expected outcome: slot 0 holds a copy of the heaviest walker (value 5)
        # and that walker's weight 4.0 is shared as 2.0 + 2.0.
        expected_after_branch_array = jnp.array([5, 2, 3, 4, 5])
        expected_weight = jnp.array([2.0, 0.1, 0.5, 0.5, 2.0])

        for test_min_thres, test_max_thres in [(0.3, 2.0),
                                               (-1.0, 2.0),
                                               (0.2, 5.0)]:

            with self.subTest(msg=f'check min_thres {test_min_thres} and max_thres {test_max_thres} (PRNG seed {seed})'):
                actual_weight, [actual_branch_array] = branch(test_weight, key, [before_branch_array],
                                                              min_thres=test_min_thres, max_thres=test_max_thres)
                self.assertSequenceAlmostEqual(expected_weight, actual_weight)
                self.assertSequenceAlmostEqual(expected_after_branch_array, actual_branch_array)
class BranchTest(jtu.JaxTestCase):
    '''Regression tests for `do_branch` and `round_merge_pairs`.'''

    def test_do_branch(self):
        # Fixed key so the branching outcome below is reproducible.
        rng_key = jax.random.PRNGKey(42)
        lower, upper = 0.5, 2
        weight = jnp.array([0.01, 0.02, 0.1, 0.1, 0.2, 0.3, 0.44, 1.6, 1.8, 3, 4])

        new_weight, copies = do_branch(
            weight,
            rng_key,
            5,
            min_thres=lower,
            max_thres=upper)

        want_weight = jnp.array([0.01, 0.03, 0.2 , 0.1 , 0.5 , 0.3 , 0.44, 1.02, 1.8 , 1.5 , 2])
        want_copies = jnp.array([0, 1, 1, 0, 1, 0, 0, 2, 1, 2, 2])

        # Branching must conserve the total weight once copies are counted.
        weight_after = jnp.sum(new_weight * copies)
        self.assertAlmostEqual(weight_after, jnp.sum(weight), places=5)
        self.assertArraysEqual(new_weight, want_weight)
        self.assertArraysEqual(copies, want_copies)

    def test_round_merge_pairs(self):
        # Values from 0 to 10 are expected to come back unchanged.
        for small in range(11):
            self.assertEqual(round_merge_pairs(small), small)

        # Larger values are expected to be rounded down as listed.
        rounding_cases = [
            (25, 20),
            (35, 30),
            (120, 100),
            (1200, 1000),
            (2200, 2000),
            (9200, 9000)]
        for value, rounded in rounding_cases:
            self.assertEqual(round_merge_pairs(value), rounded)
class StateTest(chex.TestCase):
    '''Checks the fields of a freshly built `State.default(...)`.'''

    def test_default_state(self):
        init_position = jnp.arange(12).reshape((6, 2))
        test_energy = -88.88

        # Stub energy function: ignores its arguments and returns a constant.
        def calc_energy_func(*args, **kwargs):
            return test_energy

        mixed_estimator_num_steps = 88
        energy_window_size = 22
        time_step = 1e-3

        state = State.default(
            init_position=init_position,
            calc_energy_func=calc_energy_func,
            mixed_estimator_num_steps=mixed_estimator_num_steps,
            energy_window_size=energy_window_size,
            time_step=time_step)

        # `assert_trees_all_close` is the spelling used by the sibling
        # checkpoint test; keep the two tests consistent.
        chex.assert_trees_all_close(state.position, init_position)
        chex.assert_trees_all_close(state.walker_age, jnp.ones(len(init_position)))
        chex.assert_trees_all_close(state.weight, jnp.ones(len(init_position)))
        self.assertIsNone(state.local_energy)
        self.assertEqual(state.energy_offset, test_energy)
        self.assertEqual(state.target_num_walkers, len(init_position))
        self.assertEqual(state.mixed_estimator, test_energy)
        self.assertEqual(state.mixed_estimator_calculator.mixed_estimator_num_steps,
                         mixed_estimator_num_steps)
        self.assertEqual(state.mixed_estimator_calculator.all_energy.maxlen,
                         energy_window_size)
        self.assertEqual(state.effective_time_step_calculator.time_step, time_step)
def _pmean_over_devices(x):
    return jax.lax.pmean(x, "i")


def _psum_over_devices(x):
    return jax.lax.psum(x, "i")


# Collective mean / sum over the pmapped device axis "i".
compute_mean = jax.pmap(_pmean_over_devices, axis_name="i")
compute_sum = jax.pmap(_psum_over_devices, axis_name="i")


def agg_helper(x, p_agg_func, local_agg_func):
    '''
    Two-stage aggregation: reduce `x` locally with `local_agg_func`
    (e.g. `jnp.sum`), then reduce across devices with `p_agg_func`
    (e.g. `compute_sum`).
    '''
    # `jnp.asarray` makes plain Python scalars acceptable as input.
    local_value = local_agg_func(jnp.asarray(x))

    # Reshaping `x` into a pmap'able shape would be the alternative, but that
    # needs the array length to be divisible by the number of (local) devices,
    # which may not hold in DMC. Instead, the local result is replicated once
    # per local device and the collective runs on those copies.
    replicated = jnp.ones(jax.local_device_count()) * local_value
    per_device = p_agg_func(replicated)
    return per_device[0]


_agg_mean = partial(agg_helper, p_agg_func=compute_mean, local_agg_func=jnp.mean)
_agg_sum = partial(agg_helper, p_agg_func=compute_sum, local_agg_func=jnp.sum)


# Ideally we should just do `jnp.average` to avoid overhead here for single-host case.
# However we didn't find a neat way to do such logic branching. Will leave it for future work
# TODO Do `jnp.average` for single-host case and do `agg_mean` for multiple-host case.
def agg_mean(x, weights=None):
    '''
    Global, across-hosts counterpart of `jnp.average`.
    '''
    if weights is not None:
        # Normalise by the global weight first, then sum globally.
        normalised = x * weights / agg_sum(weights)
        return agg_sum(normalised)
    return _agg_mean(x)


def agg_sum(x):
    '''
    Global, across-hosts counterpart of `jnp.sum`.
    '''
    # The local sum is replicated on every local device before the collective
    # sum, so `_agg_sum` over-counts by a factor of `jax.local_device_count()`.
    return _agg_sum(x) / jax.local_device_count()
def get_config():
    '''
    Return the default ECP (Effective Core Potential) / PH settings as an
    `ml_collections.ConfigDict`.
    '''
    return ml_collections.ConfigDict({
        'use_ecp_optim':False, # whether to use the ECP optimization
        'ecp_select_core':{
            # When calculating ECP (Effective Core Potential),
            # in principle, it is necessary to integrate over
            # all atomic nuclei. However, due to the form of the integration,
            # the integral for nuclei that are far from the electrons is zero.
            # Utilizing this property,
            # one can choose to integrate only the nuclei that are close to the electrons.
            # The parameter max_core determines the number of nearby nuclei to be selected for integration.
            # You can choose max_core according to the ECP cutoff range;
            # max_core = 2 meets the needs of most molecular systems.
            'max_core': 2
        },
        'ecp_quadrature_id': 'icosahedron_12',# quadrature rule
        # PH related information.
        # NOT supposed to be specified over command line.
        # When using the "PH_config" decorator in the config file,
        # this config will be updated automatically.
        'ph_info': None,
        # None (the default value) means all available ph elements.
        # If you want to use a subset of ph elements, you can specify them as a set in the config file.
        # For example, if you want to use Cr and Mn, you can specify it as {'Cr', 'Mn'}.
        # WARNING: Customization through command line is NOT supported.
        'ph_elements': None,
        # PH rv_type could be 'spline' or 'linear'.
        # Not much difference from our tests. We use 'spline' as default since
        # that's (what we believe) what's used when PH is constructed.
        # That said, we strongly discourage mixing the results from calculations
        # with 'spline' and 'linear'. If you choose 'spline', stick with it in
        # all your calculations. Comparing a 'spline' result with a 'linear' one
        # is technically not an apples-to-apples comparison.
        'ph_rv_type': 'spline'
    },
    )
6 | 7 | import os 8 | import shutil 9 | import tempfile 10 | 11 | from absl.testing import absltest 12 | import chex 13 | import jax.numpy as jnp 14 | import numpy as np 15 | 16 | from jaqmc.dmc.ckpt_handler import CkptHandler 17 | from jaqmc.dmc.data_path import DataPath 18 | from jaqmc.dmc.state import State 19 | 20 | class CkptHandlerTest(chex.TestCase): 21 | 22 | def setUp(self): 23 | self.tmpdir = tempfile.mkdtemp() 24 | 25 | def tearDown(self): 26 | shutil.rmtree(self.tmpdir) 27 | 28 | def test_saving_ckpt(self): 29 | test_step = 5 30 | test_prefix = 'test_none' 31 | ckpt_handler = CkptHandler(ckpt_file_prefix=test_prefix, local_save_path=self.tmpdir) 32 | 33 | init_position = jnp.arange(12).reshape((6, 2)) 34 | test_energy = -88.88 35 | def calc_energy_func(self, *args, **kwargs): 36 | return test_energy 37 | 38 | mixed_estimator_num_steps = 88 39 | energy_window_size = 22 40 | time_step = 1e-3 41 | init_step = 66 42 | 43 | default_state = State.default( 44 | init_position=init_position, 45 | calc_energy_func=calc_energy_func, 46 | mixed_estimator_num_steps=mixed_estimator_num_steps, 47 | energy_window_size=energy_window_size, 48 | time_step=time_step) 49 | ckpt_handler.save(test_step, default_state) 50 | 51 | target_ckpt_path = os.path.join( 52 | ckpt_handler.local_save_path, 53 | ckpt_handler.ckpt_file_pattern.format(step=test_step)) 54 | step, state = ckpt_handler._load_ckpt(target_ckpt_path) 55 | 56 | self.assertEqual(step, test_step) 57 | chex.assert_trees_all_close(state.position, init_position) 58 | chex.assert_trees_all_close(state.walker_age, jnp.ones(len(init_position))) 59 | chex.assert_trees_all_close(state.weight, jnp.ones(len(init_position))) 60 | self.assertIsNone(state.local_energy) 61 | self.assertEqual(state.energy_offset, test_energy) 62 | self.assertEqual(state.target_num_walkers, len(init_position)) 63 | self.assertEqual(state.mixed_estimator, test_energy) 64 | self.assertEqual(state.mixed_estimator_calculator.mixed_estimator_num_steps, 
class CkptHandler:
    '''
    Handling saving and loading checkpoint (ckpt for short) files.
    Only support local paths.
    '''
    def __init__(self,
                 ckpt_file_prefix: str,
                 local_save_path: str):
        '''
        Args:
            ckpt_file_prefix: The name of a ckpt file consists of two parts:
                1. this `ckpt_file_prefix`
                2. the step corresponding to this ckpt.
            local_save_path: Where to save and load ckpts.
        '''
        self.ckpt_file_prefix = ckpt_file_prefix
        # The step is zero-padded to six digits, e.g. `<prefix>_000005.npz`.
        self.ckpt_file_pattern = self.ckpt_file_prefix + '_{step:06d}.npz'
        self.local_save_path = local_save_path

    def load_ckpt(self, step: int) -> 'State':
        '''
        Load the ckpt file corresponding to `step`
        '''
        local_path = self._get_ckpt_path(step)
        _, state = self._load_ckpt(local_path)
        return state

    #TODO Have unittest to enforce the step in the name of ckpt file is
    # the same as the one stored inside the ckpt.
    def save(self, step: int, state: 'State') -> str:
        '''
        Save the `state` in a ckpt corresponding to `step`.

        Return the local path of the saved ckpt file.
        '''
        local_path = self._get_ckpt_path(step)
        with open(local_path, 'wb') as f:
            np.savez(
                f,
                step=step,
                state=state)
        return local_path

    def rm_ckpt_at(self, target_step: Optional[int]):
        '''
        Remove the ckpt file of `target_step` if it exists.

        Only logs a warning when there is no save path or no step to remove.
        '''
        if self.local_save_path is None or target_step is None:
            logging.warning(f'No ckpt to remove at step {target_step} in path {self.local_save_path}')
            return
        # Same path construction as save / load, so the three cannot drift apart.
        to_rm_path = self._get_ckpt_path(target_step)
        if os.path.exists(to_rm_path):
            os.remove(to_rm_path)

    def _get_ckpt_path(self, step: int) -> str:
        '''
        Return the full local path of the ckpt file for `step`.
        '''
        return os.path.join(self.local_save_path, self.ckpt_file_pattern.format(step=step))

    @staticmethod
    def _load_ckpt(ckpt_path: str) -> Tuple[int, 'State']:
        '''
        Return the step and state in the ckpt file at `ckpt_path`.
        '''
        # NOTE(security): `allow_pickle=True` unpickles arbitrary objects, so
        # only load ckpt files that come from a trusted source.
        # The `with` block closes the underlying .npz file handle, which would
        # otherwise stay open after the values are read.
        with np.load(ckpt_path, allow_pickle=True) as ckpt_data:
            step = ckpt_data['step'].tolist()
            state = ckpt_data['state'].tolist()
        return step, state
6 | 7 | import os 8 | import pathlib 9 | import shutil 10 | import tempfile 11 | 12 | from absl.testing import absltest 13 | from absl.testing import parameterized 14 | import chex 15 | 16 | from jaqmc.dmc.storage_handler import local_storage_handler 17 | 18 | class LocalStorageHandlerTest(chex.TestCase, parameterized.TestCase): 19 | 20 | def setUp(self): 21 | self.tmpdir = tempfile.mkdtemp() 22 | self.tmpdir_path = pathlib.Path(self.tmpdir) 23 | 24 | def tearDown(self): 25 | shutil.rmtree(self.tmpdir) 26 | 27 | def test_rm_file(self): 28 | tmp_file = self.tmpdir_path / 'test.txt' 29 | tmp_file.touch() 30 | self.assertTrue(tmp_file.exists()) 31 | 32 | local_storage_handler.rm(str(tmp_file)) 33 | self.assertFalse(tmp_file.exists()) 34 | 35 | def test_rm_dir(self): 36 | tmp_dir = self.tmpdir_path / 'test_dir/' 37 | tmp_dir.mkdir() 38 | self.assertTrue(tmp_dir.exists() and os.path.isdir(tmp_dir)) 39 | 40 | local_storage_handler.rm(str(tmp_dir)) 41 | self.assertFalse(tmp_dir.exists()) 42 | 43 | @parameterized.parameters(True, False) 44 | def test_ls_path(self, return_fullpath): 45 | all_file_paths = [] 46 | for i in range(6): 47 | tmp_filename = f'test_{i}.txt' 48 | tmp_file = self.tmpdir_path / tmp_filename 49 | tmp_file.touch() 50 | all_file_paths.append(str(tmp_file)) 51 | expected_set = (set(all_file_paths) 52 | if return_fullpath 53 | else set(os.path.basename(f) for f in all_file_paths)) 54 | 55 | ls_results = local_storage_handler.ls(self.tmpdir, return_fullpath=return_fullpath) 56 | self.assertSetEqual(set(ls_results), expected_set) 57 | 58 | def test_exists_no_file(self): 59 | fake_path = str(self.tmpdir_path / 'fake.txt') 60 | self.assertFalse(local_storage_handler.exists(fake_path)) 61 | 62 | def test_exists(self): 63 | test_file = self.tmpdir_path / 'test.txt' 64 | test_file.touch() 65 | test_dir = self.tmpdir_path / 'test_dir/' 66 | test_dir.mkdir() 67 | 68 | self.assertTrue(local_storage_handler.exists(str(test_file))) 69 | 
def init_electrons(
    key,
    molecule: 'Sequence[system.Atom]',
    electrons: Sequence[int],
    batch_size: int,
    init_width=1.0,
    given_atomic_spin_configs: Sequence[Tuple[int, int]] = None
) -> jnp.ndarray:
    """Initializes electron positions around each atom.

    Args:
      key: JAX RNG state.
      molecule: system.Atom objects making up the molecule.
      electrons: number of alpha and beta electrons, as a pair (a tuple or a
        list).
      batch_size: total number of MCMC configurations to generate across all
        devices.
      init_width: width of (atom-centred) Gaussian used to generate initial
        electron configurations.
      given_atomic_spin_configs: optional per-atom `(nalpha, nbeta)` pairs, in
        the same order as `molecule`. When omitted, the pairs are derived from
        each atom's element data. The passed sequence is never modified.

    Returns:
      array of (batch_size, (nalpha+nbeta)*ndim) of initial (random) electron
      positions in the initial MCMC configurations and ndim is the dimensionality
      of the space (i.e. typically 3).

    Raises:
      NotImplementedError: if the total atomic charge differs from the number
        of electrons, the molecule has more than one atom and no spin
        configuration is given.
    """
    if given_atomic_spin_configs is None:
        logging.warning('no spin assignment in the system config, may lead to unexpected initialization')

    # The balancing loop below compares against a tuple; without this a list
    # argument would never compare equal and the loop would never terminate.
    electrons = tuple(electrons)

    if (sum(atom.charge for atom in molecule) != sum(electrons)
            and given_atomic_spin_configs is None):
        if len(molecule) == 1:
            atomic_spin_configs = [electrons]
        else:
            raise NotImplementedError('No initialization policy yet '
                                      'exists for charged molecules.')
    else:
        if given_atomic_spin_configs is None:
            atomic_spin_configs = [
                (atom.element.nalpha - int((atom.atomic_number - atom.charge) // 2),
                 atom.element.nbeta - int((atom.atomic_number - atom.charge) // 2))
                for atom in molecule
            ]
        else:
            # Work on a private, mutable copy: the loop below reassigns entries,
            # which must neither fail on a tuple nor leak into the caller's data.
            atomic_spin_configs = [tuple(cfg) for cfg in given_atomic_spin_configs]

        assert sum(sum(x) for x in atomic_spin_configs) == sum(electrons)
        # Flip the spin of one electron on a randomly chosen atom until the
        # totals match. The total electron count is already correct (asserted
        # above), so only the alpha/beta split can be off; move in whichever
        # direction reduces the mismatch.
        while tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons:
            i = np.random.randint(len(atomic_spin_configs))
            nalpha, nbeta = atomic_spin_configs[i]
            total_nalpha = sum(cfg[0] for cfg in atomic_spin_configs)
            if total_nalpha > electrons[0]:
                if nalpha > 0:
                    atomic_spin_configs[i] = nalpha - 1, nbeta + 1
            elif nbeta > 0:
                atomic_spin_configs[i] = nalpha + 1, nbeta - 1

    # Assign each electron to an atom initially: all alpha electrons first
    # (atom by atom), then all beta electrons.
    electron_positions = []
    for spin in range(2):
        for j, atom in enumerate(molecule):
            atom_position = jnp.asarray(atom.coords)
            electron_positions.append(
                jnp.tile(atom_position, atomic_spin_configs[j][spin]))
    electron_positions = jnp.concatenate(electron_positions)
    # Create a batch of configurations with a Gaussian distribution about each
    # atom.
    _, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, shape=(batch_size, electron_positions.size))
    return electron_positions + init_width * noise
@attr.s(auto_attribs=True)
class IterationOutput:
    '''
    The output of a DMC iteration.

    A ckpt file should be able to recover all the data listed so that the process
    can be continued.
    '''
    # Mandatory fields (no default value).
    # NOTE(review): presumably `succeeded` tells whether the iteration finished
    # normally; confirm against the code producing this object.
    succeeded: bool
    # The DMC `State` attached to this iteration's output.
    state: State
    # JAX PRNG key.
    # NOTE(review): presumably the key to use for the next iteration; confirm.
    key: jax.random.PRNGKey
    # Optional per-iteration quantities; each defaults to None when not provided.
    average_energy: Optional[float] = None
    num_old_walkers: Optional[int] = None
    acceptance_ratio: Optional[float] = None
    effective_time_step: Optional[float] = None
    debug_info: Optional[dict] = None
def evaluate_sum_of_determinants_with_sign(mo_matrix_up, mo_matrix_dn, use_full_det):
    '''
    Sum a set of determinants and return `(sign, log(psi^2))` of the sum.

    Args:
        mo_matrix_up, mo_matrix_dn: orbital matrices of the two spin channels;
            the determinants are indexed by the axis in front of the two
            matrix axes.
        use_full_det: if true, the two matrices are concatenated along
            axis -2 and a single determinant is taken; otherwise the
            determinants of the two matrices are multiplied.
    Return:
        The sign of the determinant sum and twice the log of its magnitude,
        the latter regularized by a small epsilon inside the log.
    '''
    log_epsilon = 1e-8

    if not use_full_det:
        up_sign, up_logabs = jnp.linalg.slogdet(mo_matrix_up)
        dn_sign, dn_logabs = jnp.linalg.slogdet(mo_matrix_dn)
        det_logabs = up_logabs + dn_logabs
        det_sign = up_sign * dn_sign
    else:
        stacked = jnp.concatenate([mo_matrix_up, mo_matrix_dn], axis=-2)
        det_sign, det_logabs = jnp.linalg.slogdet(stacked)

    # Factor out the largest determinant so the exponentials cannot overflow;
    # the shift is added back in log space below.
    shift = jnp.max(det_logabs, axis=-1, keepdims=True)
    scaled_dets = jnp.exp(det_logabs - shift) * det_sign
    scaled_sum = jnp.sum(scaled_dets, axis=-1)

    log_magnitude = jnp.log(jnp.abs(scaled_sum) + log_epsilon) + jnp.squeeze(shift, -1)
    return jnp.sign(scaled_sum), 2 * log_magnitude
def build_log_psi_squared_with_sign(config: ModelConfig, phys_config: PhysicalConfig, fixed_params=None):
    '''
    Build the `UpdatedWavefunction` model and initialize its parameters.

    Return:
        A tuple `(log_psi_sqr_with_sign, orbitals, params, fixed_params)`. The
        two functions take the parameters first, followed by the batch
        arguments, and call the first / second haiku apply function with the
        rng argument fixed to None.
    '''
    # Fall back to freshly initialized fixed parameters when none (or an empty
    # set) is supplied.
    if not fixed_params:
        fixed_params = init_model_fixed_params(config, phys_config)

    def _build():
        return UpdatedWavefunction(config, phys_config).init_for_multitransform()

    model = hk.multi_transform(_build)

    # Trainable parameters are initialized from a single random dummy walker.
    n_el, _, R, Z = phys_config.get_basic_params()
    dummy_r = np.random.normal(size=[1, n_el, 3])
    init_key = jax.random.PRNGKey(np.random.randint(2**31))
    params = model.init(init_key, dummy_r, R, Z, fixed_params)

    # Hide the rng argument of the haiku apply functions (always None).
    def log_psi_sqr_with_sign(params, *batch):
        return model.apply[0](params, None, *batch)

    def orbitals(params, *batch):
        return model.apply[1](params, None, *batch)

    return log_psi_sqr_with_sign, orbitals, params, fixed_params
def _setup_path(save_path: DataPath,
                restore_path: DataPath,
                remote_storage_handler: StorageHandler) -> Tuple[DataPath, DataPath]:
    '''
    Make sure the local and remote directories of both paths exist.

    Args:
        save_path: the path pair used for saving.
        restore_path: the path pair used for restoring.
        remote_storage_handler: handler used to query and create remote
            directories.
    Return:
        The updated `(save_path, restore_path)`. A path without a local part
        gets a newly created temporary directory as its local path; the remote
        part is passed through unchanged.
    Raises:
        Exception: if a local or remote path exists but is not a directory.
    '''

    def setup(ckpt_path: DataPath):
        if ckpt_path.has_local():
            local_path = ckpt_path.local_path
            if os.path.exists(local_path) and not os.path.isdir(local_path):
                raise Exception(f'The path {local_path} already exists and it is not a directory')
            # `exist_ok=True` closes the check-then-create race: another
            # process creating the same directory in between no longer makes
            # `makedirs` fail.
            os.makedirs(local_path, exist_ok=True)
        else:
            # If no local path is specified, we use a temporary directory.
            # This is necessary for both `save_path` (so that we can save
            # ckpts) and `restore_path` (so that we can download remote ckpts to
            # this path).
            local_path = tempfile.mkdtemp()

        if ckpt_path.has_remote():
            remote_path = ckpt_path.remote_path
            if not remote_storage_handler.exists_dir(remote_path):
                if remote_storage_handler.exists(remote_path):
                    raise Exception(f'The remote path {remote_path} already exists and it is not a directory')
                remote_storage_handler.mkdir(remote_path)
        return DataPath(local_path, ckpt_path.remote_path)

    save_path = setup(save_path)
    restore_path = setup(restore_path)
    return save_path, restore_path
class Loss(Protocol):
    """
    Protocol for loss functions: a callable taking
    `(params, func_state, key, data)`.
    """

    def __call__(self,
                 params: ParamTree,
                 func_state: BaseFuncState,
                 key: chex.PRNGKey,
                 data: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[BaseFuncState, BaseAuxData]]:
        """
        Evaluates the loss.

        Note: kfac_jax.optimizer.Optimizer should turn on flags `value_func_has_rng=True` and
        `value_func_has_aux=True` when working with loss functions of this interface.

        Args:
            params: network parameters.
            func_state: function state passed to the loss function to control its behavior.
            key: JAX PRNG state.
            data: MCMC configuration to evaluate.
        Returns:
            (loss value, (updated func_state, auxiliary data))
        """
def test_setup_remote_path(self):
    '''
    `_setup_path` should keep the given local / remote paths, create the
    missing directories and replace a missing local path by a temp directory.
    '''
    with tempfile.TemporaryDirectory() as tmpdir, \
            tempfile.TemporaryDirectory() as tmp_remote_dir:
        save_path = DataPath(
            local_path=os.path.join(tmpdir, self.test_local_path),
            remote_path=os.path.join(tmp_remote_dir, self.test_remote_path))
        restore_path = DataPath(remote_path=os.path.join(tmpdir, self.test_remote_path2))
        updated_save_path, updated_restore_path = _setup_path(
            save_path=save_path,
            restore_path=restore_path,
            remote_storage_handler=local_storage_handler)

        # The local restore directory lives outside the two managed temp dirs.
        # Register its removal right away so that it is cleaned up even when
        # one of the assertions below fails.
        self.addCleanup(shutil.rmtree, updated_restore_path.local_path,
                        ignore_errors=True)

        self.assertEqual(updated_save_path.local_path, save_path.local_path)
        self.assertEqual(updated_save_path.remote_path, save_path.remote_path)
        self.assertEqual(updated_restore_path.remote_path, restore_path.remote_path)
        # The empty local path should be replaced by some temp path
        self.assertNotEqual(updated_restore_path.local_path, restore_path.local_path)

        self.assertTrue(local_storage_handler.exists_dir(updated_save_path.local_path))
        self.assertTrue(local_storage_handler.exists_dir(updated_save_path.remote_path))
        self.assertTrue(local_storage_handler.exists_dir(updated_restore_path.local_path))
def append_wrapper(que, elem):
    '''
    Append `float(elem)` to the deque `que` and report what got pushed out.

    Args:
        que: a `collections.deque`, bounded or unbounded.
        elem: the value to append; it is stored as a Python float.
    Return:
        The element dropped from the left end to make room for the new one,
        or None if nothing was dropped.
    '''
    # An unbounded deque (`maxlen is None`) never drops anything, and a
    # zero-length deque has no old element to report.
    is_full = que.maxlen is not None and len(que) >= que.maxlen
    to_pop = que[0] if is_full and que else None
    que.append(float(elem))
    return to_pop
def get_Pi_log(self, time_step, energy_offset):
    '''
    Record `energy_offset` and return the new `Pi_log` value.

    The offset is appended to `self.all_energy_offsets`; the result is
    `-time_step` times the sum of the offsets currently held there.
    '''
    recent_offsets = self.all_energy_offsets
    recent_offsets.append(energy_offset)
    offset_total = np.sum(recent_offsets)
    return -time_step * offset_total
def run_wrapper(deeperwin_ckpt_file, dmc_cfg):
    '''
    Run DMC starting from a DeepErwin checkpoint.

    Args:
        deeperwin_ckpt_file: DeepErwin checkpoint file, loaded with `load_run`.
        dmc_cfg: DMC configuration; its attributes are forwarded to `run`.
    '''
    loaded = load_run(deeperwin_ckpt_file)
    model_cfg = loaded.config.model
    phys_cfg = loaded.config.physical

    # Build wavefunction and initialize parameters, then overlay the trained
    # parameters stored in the checkpoint.
    built = build_log_psi_squared_with_sign(model_cfg, phys_cfg, loaded.fixed_params)
    log_psi_squared_with_sign, _, init_params, fixed_params = built
    params = merge_params(init_params, loaded.params)

    nuclei = jnp.array(phys_cfg.R)
    charges = jnp.array(phys_cfg.Z)

    def log_psi_with_sign(x):
        # The model works with log(psi^2); halve it to get log|psi|.
        sign, log_squared = log_psi_squared_with_sign(
            params, x.reshape((-1, 3)), nuclei, charges, fixed_params)
        return sign, log_squared / 2

    # Merge the second and third axes of the stored walker positions.
    walkers = loaded.mcmc_state.r
    walkers_shape = walkers.shape
    position = walkers.reshape(
        (walkers_shape[0], walkers_shape[1] * walkers_shape[2]))

    key = jax.random.PRNGKey(int(1e6 * time.time()))

    # Optional arguments of `run` that are copied verbatim from the config.
    forwarded_names = (
        'mixed_estimator_num_steps',
        'energy_window_size',
        'weight_branch_threshold',
        'update_energy_offset_interval',
        'energy_offset_update_amplitude',
        'energy_cutoff_alpha',
        'effective_time_step_update_period',
        'energy_outlier_rel_threshold',
        'fix_size',
        'ebye_move',
        'block_size',
        'max_restore_nums',
    )
    optional_args = {name: getattr(dmc_cfg, name) for name in forwarded_names}
    optional_args['save_path'] = DataPath(dmc_cfg.log.save_path,
                                          dmc_cfg.log.remote_save_path)

    run(
        position,
        dmc_cfg.iterations,
        log_psi_with_sign,
        dmc_cfg.time_step,
        key,
        nuclei=nuclei,
        charges=charges,
        **optional_args,
    )
def branch(weight, key, branch_arrays, min_thres=0.3, max_thres=2):
    '''
    Merge walkers with small weight and split walkers with large weight.

    Branching is only done within the same host; there is no balancing
    across hosts for now.

    Args:
        weight: The weight according to which the split and merge are done.
        key: The RNG key used to decide which walker survives when two
            walkers are merged.
        branch_arrays: A list of arrays with the same length as `weight`.
            They are branched in the same way as `weight`.
        min_thres, max_thres: The thresholds used to decide which walkers to
            merge and which to split.

    Returns:
        A pair `(weight, arrays)` holding the branched weight and the list of
        branched arrays.
    '''
    num_below_threshold = int(jnp.sum(weight < min_thres))
    merge_pairs = round_merge_pairs(num_below_threshold // 2 + 1)
    if merge_pairs > 0.05 * len(weight):
        logging.warning(f'large number of pairs to merge: {merge_pairs}')

    weight, repeat_num = do_branch(
        weight, key, merge_pairs, min_thres=min_thres, max_thres=max_thres)

    # `repeat_num` holds, per walker, how many copies survive (0, 1 or 2).
    branched_arrays = [arr.repeat(repeat_num, axis=0) for arr in branch_arrays]
    branched_weight = weight.repeat(repeat_num, axis=0)
    return branched_weight, branched_arrays
@partial(jax.jit, static_argnums=(2,))
def do_branch(weight, key, merge_pairs, min_thres=0.1, max_thres=1):
    '''
    Compute the merged / split weights and the per-walker repeat count.

    Args:
        weight: 1-D array of walker weights.
        key: The RNG key used to pick the survivor of each merged pair.
        merge_pairs: Static upper bound on the number of pairs to merge. The
            function is re-compiled for every new value. It is clamped to
            `len(weight) // 2`, since a pair needs two distinct walkers.
        min_thres: A pair is merged only if its smaller weight is below this.
        max_thres: Walkers whose weight exceeds this are split in two.

    Returns:
        `(weight, repeat_num)`: the updated weights and an int32 array giving
        how many copies of each walker to keep (0 removed, 1 kept, 2 split).
    '''
    repeat_num = jnp.ones(weight.shape, dtype='int32')

    # `merge_pairs` and the array length are both static under jit, so the
    # clamp is plain Python. Without it the reshape to (merge_pairs, 2) below
    # fails whenever 2 * merge_pairs exceeds the number of walkers.
    merge_pairs = min(merge_pairs, len(weight) // 2)

    if merge_pairs > 0:
        # We take the k-smallest value of weight and their indices by first
        # multiplying the weight array by -1 then take the top-k elements.
        neg_smallest_k, smallest_k_indices = jax.lax.top_k(-1 * weight, 2 * merge_pairs)
        smallest_k = -neg_smallest_k.reshape((merge_pairs, 2))
        smallest_k_indices = smallest_k_indices.reshape((merge_pairs, 2))

        def helper(index, values):
            weight, repeat_num, key = values
            k1, k2 = smallest_k[index]
            k1_index, k2_index = smallest_k_indices[index]

            # k1 is the smaller weight of the pair; merge only if it is
            # below the threshold, otherwise pass the state through.
            weight, repeat_num, key = jax.lax.cond(
                k1 < min_thres,
                update_weight,
                lambda x: x[0],
                ((weight, repeat_num, key), k1, k2, k1_index, k2_index))
            return weight, repeat_num, key

        weight, repeat_num, key = jax.lax.fori_loop(
            0,
            merge_pairs,
            helper,
            (weight, repeat_num, key))

    # Split heavy walkers: double the repeat count and halve the weight.
    repeat_num *= 1 + (weight > max_thres)
    weight *= 1 - (weight > max_thres) / 2

    return weight, repeat_num

def update_weight(_input):
    '''
    Merge one pair of walkers.

    The walker with weight `k1` survives with probability `k1 / (k1 + k2)`,
    otherwise the one with weight `k2` does. The survivor gets the combined
    weight and the other walker's repeat count is set to 0.

    Args:
        _input: `((weight, repeat_num, key), k1, k2, k1_index, k2_index)`.

    Returns:
        The updated `(weight, repeat_num, key)`; `key` has been advanced.
    '''
    (weight, repeat_num, key), k1, k2, k1_index, k2_index = _input
    key, sub_key = jax.random.split(key)
    keep_index, rm_index = jax.lax.cond(jax.random.uniform(sub_key) < k1 / (k1 + k2),
                                        lambda _: (k1_index, k2_index),
                                        lambda _: (k2_index, k1_index),
                                        operand=None)
    weight = weight.at[keep_index].set(k1 + k2)
    repeat_num = repeat_num.at[rm_index].set(0)
    return weight, repeat_num, key
-------------------------------------------------------------------------------- /examples/dmc/dmc_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | ''' 8 | DMC configurations. 9 | ''' 10 | 11 | import ml_collections 12 | 13 | def get_config(): 14 | # Some terminology / notation may refer to 15 | # Umrigar C J, Nightingale M P, Runge K J. A diffusion Monte Carlo algorithm with very small time‐step errors[J]. The Journal of chemical physics, 1993, 99(4): 2865-2890. 16 | 17 | cfg = ml_collections.ConfigDict({ 18 | # Number of DMC iterations to run. 19 | # Note that in practice we will run more iterations (by one more "block") 20 | # then the one specified here just to make sure that 21 | # the data in the final checkpoint is outlier-free. 22 | 'iterations': 10000, 23 | # DMC time step. Should be 24 | # 1. small enough to avoid finite-time error 25 | # 2. not too small otherwise the DMC process will last way too long. 26 | 'time_step': 0.001, 27 | # The energy offset E_T will be updated every `update_energy_offset_interval` 28 | # iterations. 29 | 'update_energy_offset_interval': 1, 30 | # The amplitude of adjustment on E_T given the weight calculated in each step. 31 | 'energy_offset_update_amplitude':1, 32 | # Metric info will be printed out every `print_info_interval`. 33 | 'print_info_interval': 20, 34 | # The lower and upper limits for branching / merging. 35 | 'weight_branch_threshold': (0.3, 2), 36 | # The T_p used in the calculation of \Pi to adjust the effect of updating E_T. 37 | 'energy_window_size': 1000, 38 | # The size of rolling window in which the energy mixed estimator is calculated. 
39 | 'mixed_estimator_num_steps': 5000, 40 | # The relative threshold to determine whether some calculated local 41 | # energies are outlier, in which case we will rerun the iteration with 42 | # a different random number wishing for better luck. 43 | # Negative value means turning off such mechanism. 44 | 'energy_outlier_rel_threshold': -1.0, 45 | # If `energy_cutoff_alpha` < 0, then fallback to UNR algo. for 46 | # weight update, otherwise do energy cutoff following 47 | # Zen A, Sorella S, Gillan M J, et al. Boosting the accuracy and speed of quantum Monte Carlo: Size consistency and time step[J]. Physical Review B, 2016, 93(24): 241118. 48 | 'energy_cutoff_alpha': 0.2, 49 | # Whether do fix-size branching or not. 50 | # It can be turned on to boost efficiency due to better JAX jitting. 51 | 'fix_size': False, 52 | # If True, use elec-by-elec moves rather than walker-by-walker moves. 53 | # By default it's turned off due to efficiency concern. 54 | 'ebye_move': False, 55 | # Negative `effective_time_step_update_period` means always 56 | # update effective time step. 57 | 'effective_time_step_update_period': -1, 58 | 59 | # The size of a block of iterations. The recovery mechanism will 60 | # roll back to the previous block when error happens. 61 | 'block_size': 5000, 62 | # The max number of rolling-back which the recovery mechasim will 63 | # perform before it gives up and abort the process. 64 | 'max_restore_nums': 3, 65 | 66 | 'log': { 67 | # The local path that the checkpoint will be saved to. 68 | 'save_path': '', 69 | # The remote path that the checkpoint will be upload to. 70 | 'remote_save_path': '', 71 | # The local path that the previous checkpoint will be loaded from. 72 | 'restore_path': '', 73 | # The remote path that the previous checkpoint will be downloaded from. 
class MetricManager:
    '''
    A context manager for the metric csv file.

    It supports resetting the underlying file from a given metric file, for
    instance when the state needs to be reverted after a failure.
    '''

    def __init__(self,
                 file_name: str,
                 schema: List[str],
                 local_save_path: str):
        '''
        Args:
            file_name: The name of the metric file
            schema: The column names of the metric file. "step" is prepended
                to it, so there is no need to keep "step" in `schema`.
            local_save_path: The local directory storing the metric file
        '''
        self.file_name = file_name
        self.full_schema = ['step'] + list(schema)
        self.local_metric_path = os.path.join(local_save_path, self.file_name)

        self.file_handle = None
        self.csv_writer = None

    def __enter__(self):
        self._setup_csv_writer()
        # Return the manager so `with MetricManager(...) as manager:` binds it.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Both calls log and swallow I/O errors, so a failure while closing
        # cannot mask an exception raised inside the `with` body.
        self.flush()
        self.close()

    def _setup_csv_writer(self):
        '''
        Create the file handle.

        Write schema to metric file if needed, i.e. when the file does not
        exist yet or exists but is empty.
        '''
        path = self.local_metric_path
        csv_should_write_header = (not os.path.exists(path)
                                   or os.path.getsize(path) == 0)
        # The csv module asks for newline='' so that it controls line endings.
        self.file_handle = open(path, 'a+', newline='')
        self.csv_writer = csv.writer(self.file_handle, delimiter=',')
        if csv_should_write_header:
            self.csv_writer.writerow(self.full_schema)

    def flush(self):
        '''
        Flush pending rows to disk; errors are logged, not raised.
        '''
        if self.file_handle is not None and not self.file_handle.closed:
            try:
                self.file_handle.flush()
            except Exception as e:
                logging.warning(f'Hitting error {e} when flushing metric file')

    def close(self):
        '''
        Close the file handle if it is open; errors are logged, not raised.
        '''
        if self.file_handle is not None and not self.file_handle.closed:
            try:
                self.file_handle.close()
            except Exception as e:
                logging.warning(f'Hitting error {e} when closing metric file handle')

    def get_metric_data(self) -> pd.DataFrame:
        '''
        Return the content of the current metric file.
        '''
        self.flush()
        data = pd.read_csv(self.local_metric_path, names=self.full_schema, header=0)
        return data

    def write(self, step: int, data: List[Union[int, float]]):
        '''
        Append one row `[step] + data`. Does nothing before the writer is set up.
        '''
        if self.csv_writer is not None:
            self.csv_writer.writerow([step] + list(data))

    def reset(self, target_path: str, rm_source: bool = True):
        '''
        Reset the managed metric file by the file at `target_path`.

        If `rm_source` is True, remove the file at `target_path` afterwards.
        '''
        self.close()
        if not self._samefile(target_path, self.local_metric_path):
            if rm_source:
                shutil.move(target_path, self.local_metric_path)
            else:
                shutil.copy(target_path, self.local_metric_path)
        self._setup_csv_writer()

    @staticmethod
    def _samefile(f1: str, f2: str) -> bool:
        '''
        Return True only if both paths exist and refer to the same file.
        '''
        if not os.path.exists(f1) or not os.path.exists(f2):
            return False
        return os.path.samefile(f1, f2)
def main(_):
    '''
    absl entry point: resolve the configs, send logs to stdout and start DMC.

    The single positional argument supplied by `app.run` is ignored.
    '''
    resolved_cfg = base_config.resolve(FLAGS.config)
    dmc_settings = FLAGS.dmc_config

    # Send absl log records to stdout and log at INFO verbosity.
    absl_handler = logging.get_absl_handler()
    absl_handler.python_handler.stream = sys.stdout
    logging.set_verbosity(logging.INFO)

    run_wrapper(resolved_cfg, dmc_settings)
def run_wrapper(cfg, dmc_cfg):
    '''
    Run DMC on top of a wavefunction restored from a FermiNet VMC checkpoint.

    Args:
        cfg: Resolved FermiNet config. It provides the molecule, the network
            settings, the VMC checkpoint directory and the batch size.
        dmc_cfg: DMC configuration whose attributes are forwarded to `run`.
    '''
    atoms, charges, spins = get_molecule_configuration(cfg)
    # Fixed seed in deterministic debug mode, wall-clock seed otherwise.
    key = jax.random.PRNGKey(666 if cfg.debug.deterministic else int(1e6 * time.time()))

    envelope = envelopes.make_isotropic_envelope()
    _network = networks.make_fermi_net(
        spins, charges,
        envelope=envelope,
        bias_orbitals=cfg.network.bias_orbitals,
        full_det=cfg.network.full_det,
        **cfg.network.ferminet)

    vmc_ckpt_save_path = checkpoint.create_save_path(cfg.log.save_path)
    vmc_ckpt_restore_filename = checkpoint.find_last_checkpoint(vmc_ckpt_save_path)
    _, _data, params, *_ = checkpoint.restore(vmc_ckpt_restore_filename, cfg.batch_size)

    # Flatten every leading axis so that each row is one walker.
    data = _data.positions
    position = data.reshape((-1, data.shape[-1]))

    # Get a single copy of network params from the replicated one.
    # `jax.tree_util.tree_map` is used instead of the `jax.tree_map` alias
    # because the alias is deprecated and removed in recent JAX releases.
    single_params = jax.tree_util.tree_map(lambda x: x[0], params)
    network = lambda params, pos: _network.apply(params, pos, spins, atoms, charges)
    network_wrapper = lambda x: network(params=single_params, pos=x)

    run(
        position,
        dmc_cfg.iterations,
        network_wrapper,
        dmc_cfg.time_step, key,
        nuclei=atoms,
        charges=charges,

        # Below are optional arguments
        mixed_estimator_num_steps=dmc_cfg.mixed_estimator_num_steps,
        energy_window_size=dmc_cfg.energy_window_size,
        weight_branch_threshold=dmc_cfg.weight_branch_threshold,
        update_energy_offset_interval=dmc_cfg.update_energy_offset_interval,
        energy_offset_update_amplitude=dmc_cfg.energy_offset_update_amplitude,
        energy_cutoff_alpha=dmc_cfg.energy_cutoff_alpha,
        effective_time_step_update_period=dmc_cfg.effective_time_step_update_period,
        energy_outlier_rel_threshold=dmc_cfg.energy_outlier_rel_threshold,
        fix_size=dmc_cfg.fix_size,
        ebye_move=dmc_cfg.ebye_move,
        block_size=dmc_cfg.block_size,
        max_restore_nums=dmc_cfg.max_restore_nums,
        save_path=DataPath(dmc_cfg.log.save_path, dmc_cfg.log.remote_save_path),
    )
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | ''' 8 | fix-size branching. 9 | ''' 10 | 11 | from absl import logging 12 | import jax 13 | import jax.numpy as jnp 14 | 15 | def branch(weight, key, branch_arrays, min_thres=0.3, max_thres=2): 16 | ''' 17 | Split large weight, merge small weight. 18 | We will only be able to do branching within the same host, no balancing 19 | across hosts for now 20 | 21 | We apply a simple trick to keep the population size fixed: 22 | Whenever we want to merge a pair, we also split a walker (with largest weight); 23 | similarly whenever we want to split a walker, we merge a pair(with smallest weight). 24 | 25 | This trick should be harmless with reasonable branching thresholds. That said, 26 | one potential issue is that we merge two not-so-small-weight walkers and generate 27 | one large-weight walker whose weight may exceed the `max_thres`. This may 28 | happen when `max_thres` is too small while `min_thres` is too large (for instance, 29 | if we choose the thresholds as (0.8, 1.2)). If our thresholds are reasonble, (`min_thres` 30 | 0.2~0.4, `max_thres` 1.8 ~ 2.0), then together with assumption that the majority 31 | of walkers have weight around 1, we should be fine. 32 | 33 | Args: 34 | weight: The weight according to which the split and merge will be done. 35 | key: The RNG key used to determine which walker should be kept when merging 36 | two walkers. 37 | branch_arrays: A list of arrays with the same length as `weight`. Those 38 | arrays will be branched accordingly. 39 | min_thres, max_thres: The threshold used to determine which walkers to 40 | split and merge. 41 | ''' 42 | # Round upwards. Namely whenever we have a weight below the threshold, we 43 | # merge it with the second smallest one. It should not matter much compared 44 | # to the more conservative approach doing downward rounding. 
@jax.jit
def do_branch(weight, key, smallest_k_indices, largest_k_indices, *branch_arrays):
    '''
    Merge the given small-weight pairs and split the given large-weight walkers.

    Every merge frees one slot and every split fills one, so the population
    size is unchanged.

    Args:
        weight: 1-D array of walker weights.
        key: The RNG key used to pick the survivor of each merged pair.
        smallest_k_indices: Integer array of shape (num_to_change, 2); each
            row is a pair of walker indices to merge.
        largest_k_indices: Integer array of shape (num_to_change,) with the
            indices of the walkers to split.
        *branch_arrays: Arrays indexed like `weight`, branched the same way.

    Returns:
        A list `[weight, *branch_arrays]` with the updated arrays.
    '''
    first_indices = smallest_k_indices[:, 0]
    second_indices = smallest_k_indices[:, 1]

    # The first walker of a pair survives with probability w1 / (w1 + w2).
    thresholds = weight[first_indices] / (weight[first_indices] + weight[second_indices])
    random_num = jax.random.uniform(key, shape=thresholds.shape)

    # Derive both roles from one mask. Using two independent comparisons
    # (`<` for kept, `>` for removed) makes the same walker both kept and
    # removed when the draw equals the threshold or the threshold is NaN.
    keep_first = random_num < thresholds
    kept_indices = jnp.where(keep_first, first_indices, second_indices)
    removed_indices = jnp.where(keep_first, second_indices, first_indices)

    # For arrays in `branch_arrays`, the spots for removed elements will be simply filled with the branched ones.
    branch_arrays = [arr.at[removed_indices].set(arr[largest_k_indices]) for arr in branch_arrays]

    # For weights, it's trickier. We need to add the removed elements' weight to the
    # winner in the merging process. And then halve the weight of the branched elements,
    # then copy it to the spots for the removed elements.
    weight = weight.at[kept_indices].add(weight[removed_indices])
    weight = weight.at[largest_k_indices].set(weight[largest_k_indices] / 2)
    weight = weight.at[removed_indices].set(weight[largest_k_indices])
    return [weight] + branch_arrays
shutil.rmtree(self.tmp_local_dir) 45 | 46 | def make_temp_remote_directory(self): 47 | directory_path = os.path.join(self.tmpdir, 'test_dir') 48 | if self.command_prefix: 49 | subprocess.check_call([self.command_prefix, 'hdfs', 'dfs', '-mkdir', directory_path]) 50 | else: 51 | subprocess.check_call(['hdfs', 'dfs', '-mkdir', directory_path]) 52 | return directory_path 53 | 54 | def touch_temp_remote_file(self, filename='test.txt'): 55 | tmp_file = pathlib.Path(self.tmp_local_dir) / filename 56 | tmp_file.touch() 57 | remote_path = os.path.join(self.tmpdir, filename) 58 | if self.command_prefix: 59 | subprocess.check_call([self.command_prefix, 'hdfs', 'dfs', '-put', str(tmp_file), remote_path]) 60 | else: 61 | subprocess.check_call(['hdfs', 'dfs', '-put', str(tmp_file), remote_path]) 62 | return remote_path 63 | 64 | def test_rm_file(self): 65 | tmp_file = self.touch_temp_remote_file() 66 | self.assertTrue(self.handler.exists(tmp_file)) 67 | 68 | self.handler.rm(tmp_file) 69 | self.assertFalse(self.handler.exists(tmp_file)) 70 | 71 | def test_rm_dir(self): 72 | tmp_dir = self.make_temp_remote_directory() 73 | self.assertTrue(self.handler.exists_dir(tmp_dir)) 74 | 75 | self.handler.rm(str(tmp_dir)) 76 | self.assertFalse(self.handler.exists(tmp_dir)) 77 | 78 | @parameterized.parameters(True, False) 79 | def test_ls_path(self, return_fullpath): 80 | all_file_paths = [] 81 | for i in range(3): 82 | tmp_filename = f'test_{i}.txt' 83 | remote_path = self.touch_temp_remote_file(tmp_filename) 84 | all_file_paths.append(remote_path) 85 | 86 | remote_dir = self.make_temp_remote_directory() 87 | all_file_paths.append(remote_dir) 88 | 89 | expected_set = (set(all_file_paths) 90 | if return_fullpath 91 | else set(os.path.basename(f) for f in all_file_paths)) 92 | 93 | ls_results = self.handler.ls(self.tmpdir, return_fullpath=return_fullpath) 94 | print('ls: ', ls_results) 95 | self.assertSetEqual(set(ls_results), expected_set) 96 | 97 | def test_exists_no_file(self): 98 | 
fake_path = str(self.tmpdir_path / 'fake.txt') 99 | self.assertFalse(self.handler.exists(fake_path)) 100 | 101 | def test_exists(self): 102 | test_file = self.touch_temp_remote_file() 103 | test_dir = self.make_temp_remote_directory() 104 | 105 | self.assertTrue(self.handler.exists(test_file)) 106 | self.assertTrue(self.handler.exists(test_dir)) 107 | 108 | def test_exists_dir_no_dir(self): 109 | test_file = self.touch_temp_remote_file() 110 | fake_path = str(self.tmpdir_path / 'fake/') 111 | 112 | self.assertFalse(self.handler.exists_dir(test_file)) 113 | self.assertFalse(self.handler.exists_dir(fake_path)) 114 | 115 | def test_exists_dir(self): 116 | test_dir = self.make_temp_remote_directory() 117 | 118 | self.assertTrue(self.handler.exists_dir(test_dir)) 119 | 120 | if __name__ == '__main__': 121 | absltest.main() 122 | -------------------------------------------------------------------------------- /jaqmc/loss/factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import inspect 8 | from typing import Optional 9 | 10 | from absl import logging 11 | import chex 12 | import jax 13 | import jax.numpy as jnp 14 | 15 | from jaqmc.loss import utils 16 | from jaqmc.loss.spin_penalty import SpinAuxData, SpinFuncState, make_spin_penalty 17 | from jaqmc.loss.overlap_penalty import OverlapAuxData, OverlapFuncState, make_overlap_penalty 18 | from jaqmc.loss.vmc import VMCAuxData, VMCFuncState, make_vmc_loss 19 | 20 | @chex.dataclass 21 | class FuncState: 22 | ''' 23 | Parent Func State containing ones for each loss component. 
def make_loss(
    signed_network: utils.WaveFuncLike,
    local_energy: utils.LocalEnergy,

    # Flags to control loss behavior by selecting loss components.
    with_spin=False,
    with_overlap=False,

    # kwargs for each loss components.
    **kwargs
):
    '''
    User-facing loss factory.

    Note: kfac_jax.optimizer.Optimizer should turn on flags `value_func_has_rng=True` and
    `value_func_has_aux=True` when working with the output loss function.

    Args:
        signed_network: Callable taking in params and data, returning sign
            and log-abs of the neural network wavefunction.
        local_energy: Callable which evaluates the local energy.

        with_spin: If True, add the spin penalty to the loss.
        with_overlap: If True, add the overlap penalty to the loss.

        kwargs: Other flags. Each loss component receives only the ones that
            appear in its factory's signature.

    Returns:
        Callable with signature (params, func_state, key, data) and returns (loss, (func_state, aux_data))
    '''
    def invoke(component_factory):
        relevant_kwargs = _get_relevant_kwargs(kwargs, component_factory)
        logging.info(f'arguments for {component_factory.__name__} are {relevant_kwargs}')
        return component_factory(signed_network, local_energy, **relevant_kwargs)

    # The VMC loss is always present; penalties are appended when requested.
    enabled_factories = [(make_vmc_loss, 'vmc')]
    if with_spin:
        enabled_factories.append((make_spin_penalty, 'spin'))
    if with_overlap:
        enabled_factories.append((make_overlap_penalty, 'overlap'))

    # Pairs of (loss_func, loss_identifier), built in the order listed above.
    all_components = [(invoke(factory), identifier)
                      for factory, identifier in enabled_factories]

    def total_loss(
        params: utils.ParamTree,
        func_state: FuncState,
        key: chex.PRNGKey,
        data: jnp.ndarray
    ):
        accumulated = 0.0
        new_states = {}
        aux_outputs = {}
        for component_func, identifier in all_components:
            # Every component gets its own sub-key.
            key, sub_key = jax.random.split(key)
            primal, (component_state, component_aux) = component_func(
                params, func_state[identifier], sub_key, data)
            accumulated += primal
            new_states[identifier] = component_state
            aux_outputs[identifier] = component_aux
        return accumulated, (FuncState(**new_states), AuxData(**aux_outputs))

    return total_loss
133 | ''' 134 | def __get_args(func): 135 | sig = inspect.signature(func) 136 | return set(p.name for p in sig.parameters.values()) 137 | 138 | func_args = __get_args(func) 139 | return {k: v for (k, v) in all_kwargs.items() if k in func_args} 140 | -------------------------------------------------------------------------------- /tests/dmc/energy_estimator_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import deque 8 | 9 | from absl.testing import absltest 10 | import jax.test_util as jtu 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | 15 | from jaqmc.dmc.energy_estimator import MixedEstimatorCalculator 16 | 17 | def get_mixed_estimator_Pi(energy_offsets, time_step): 18 | log = -time_step * np.sum(np.array(energy_offsets)) 19 | return np.exp(log) 20 | 21 | class EnergyEstimatorTest(jtu.JaxTestCase): 22 | 23 | @staticmethod 24 | def calc_estimator(all_energy_offsets, all_energy, all_total_weight, 25 | mixed_estimator_num_steps, energy_window_size, all_time_step): 26 | all_Pi = [] 27 | for i in range(len(all_energy_offsets)): 28 | start = max(0, i + 1 - mixed_estimator_num_steps) 29 | Pi = get_mixed_estimator_Pi(all_energy_offsets[start: (i + 1)], all_time_step[i]) 30 | all_Pi.append(Pi) 31 | numerator = 0 32 | denominator = 0 33 | start_index = len(all_energy) - energy_window_size 34 | for i, (energy, total_weight, Pi) in enumerate(zip(all_energy, all_total_weight, all_Pi)): 35 | if i < start_index: 36 | continue 37 | numerator += Pi * energy * total_weight 38 | denominator += Pi * total_weight 39 | return numerator / denominator 40 | 41 | def test_mixed_estimator_Pi(self): 42 | all_energy_offsets = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0] 43 | all_energy = [1.1, 2.2, 
3.5, 1.0, 2.2, 3.0] 44 | all_total_weight = [1.0, 2, 1.0, 1.0, 1.0, 1.0] 45 | all_time_step = [1e-2, 1e-1, 1e-1, 1e-2, 1e-3] 46 | 47 | mixed_estimator_num_steps = 2 48 | energy_window_size = 6 49 | calculator = MixedEstimatorCalculator(mixed_estimator_num_steps, 50 | energy_window_size) 51 | for i, (energy_offset, energy, total_weight, time_step) in enumerate(zip(all_energy_offsets, all_energy, all_total_weight, all_time_step)): 52 | expected_result = self.calc_estimator( 53 | all_energy_offsets[:(i + 1)], 54 | all_energy[:(i + 1)], 55 | all_total_weight[:(i + 1)], 56 | mixed_estimator_num_steps=mixed_estimator_num_steps, 57 | energy_window_size=energy_window_size, 58 | all_time_step=all_time_step) 59 | result = calculator.run(energy_offset, energy, total_weight, time_step) 60 | self.assertAlmostEqual(result, expected_result, places=5) 61 | 62 | def test_mixed_estimator_Pi_2(self): 63 | all_energy_offsets = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0] 64 | all_energy = [1.1, 2.2, 3.5, 1.0, 2.2, 3.0] 65 | all_total_weight = [1.0, 2, 1.0, 1.0, 1.0, 1.0] 66 | all_time_step = [1e-2, 1e-1, 1e-1, 1e-2, 1e-3] 67 | 68 | mixed_estimator_num_steps = 2 69 | energy_window_size = 3 70 | calculator = MixedEstimatorCalculator(mixed_estimator_num_steps, 71 | energy_window_size) 72 | for i, (energy_offset, energy, total_weight, time_step) in enumerate(zip(all_energy_offsets, all_energy, all_total_weight, all_time_step)): 73 | expected_result = self.calc_estimator( 74 | all_energy_offsets[:(i + 1)], 75 | all_energy[:(i + 1)], 76 | all_total_weight[:(i + 1)], 77 | mixed_estimator_num_steps=mixed_estimator_num_steps, 78 | energy_window_size=energy_window_size, 79 | all_time_step=all_time_step) 80 | result = calculator.run(energy_offset, energy, total_weight, time_step=time_step) 81 | self.assertAlmostEqual(result, expected_result, places=5) 82 | 83 | def test_mixed_estimator_Pi_3(self): 84 | all_energy_offsets = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0] 85 | all_energy = [1.1, 2.2, 3.5, 1.0, 2.2, 3.0] 86 | 
all_total_weight = [1.0, 2, 1.0, 1.0, 1.0, 1.0] 87 | all_time_step = [1e-2, 1e-1, 1e-1, 1e-2, 1e-3] 88 | 89 | mixed_estimator_num_steps = 4 90 | energy_window_size = 3 91 | calculator = MixedEstimatorCalculator(mixed_estimator_num_steps, 92 | energy_window_size) 93 | for i, (energy_offset, energy, total_weight, time_step) in enumerate(zip(all_energy_offsets, all_energy, all_total_weight, all_time_step)): 94 | expected_result = self.calc_estimator( 95 | all_energy_offsets[:(i + 1)], 96 | all_energy[:(i + 1)], 97 | all_total_weight[:(i + 1)], 98 | mixed_estimator_num_steps=mixed_estimator_num_steps, 99 | energy_window_size=energy_window_size, 100 | all_time_step=all_time_step) 101 | result = calculator.run(energy_offset, energy, total_weight, time_step=time_step) 102 | self.assertAlmostEqual(result, expected_result, places=5) 103 | 104 | if __name__ == '__main__': 105 | absltest.main() 106 | -------------------------------------------------------------------------------- /tests/dmc/recovery_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
import os

from absl.testing import absltest
import chex
import jax
import jax.numpy as jnp
import tempfile
import shutil

from jaqmc.dmc.ckpt_metric_manager import CkptMetricManager, NoSafeDataAvailable
from jaqmc.dmc.dmc import recovery_wrapper, IterationOutput
from jaqmc.dmc.state import State

class TestException(Exception):
    '''Marker exception raised by the fake DMC iterations in this module.'''

class RecoveryWrapperTest(chex.TestCase):
    '''Tests for `recovery_wrapper` with and without restorable checkpoint data.'''

    def setUp(self):
        super().setUp()
        # Scratch directory handed to `CkptMetricManager` as `save_path`.
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tmpdir)
        super().tearDown()

    _test_iteration_output = IterationOutput(
        succeeded=True,
        state=State(position=jnp.arange(20).reshape((10, 2))),
        key=None)
    _test_step = 666
    def _test_dmc_iteration(self, should_raise):
        '''Fake DMC iteration: raise `TestException` or return fixed (step, output).'''
        if should_raise:
            raise TestException()
        return self._test_step, self._test_iteration_output

    def test_no_failure(self):
        # When no failure happens, the raw function's behavior should not
        # be modified at all
        wrapped_func = recovery_wrapper(self._test_dmc_iteration, 0,
                                        ckpt_metric_manager=None,
                                        key=None)
        step, output = wrapped_func(should_raise=False)
        self.assertEqual(step, self._test_step)
        chex.assert_trees_all_close(output.state.position,
                                    self._test_iteration_output.state.position)

    def test_no_safe_data_available(self):
        ckpt_metric_manager = CkptMetricManager(metric_schema=[], block_size=3,
                                                lazy_setup=False)

        wrapped_func = recovery_wrapper(self._test_dmc_iteration,
                                        max_restore_nums=1,
                                        ckpt_metric_manager=ckpt_metric_manager,
                                        key=None)
        self.assertRaises(NoSafeDataAvailable, wrapped_func, should_raise=True)

    def get_manager_with_safe_data(self, block_size=10):
        '''
        Build a `CkptMetricManager` saving into `self.tmpdir` and feed it
        `block_size + 1` steps of a fixed state.

        Returns:
            (ckpt_metric_manager, state) where `state` is the state that was
            passed to every `run` call.
        '''
        position = jnp.arange(120).reshape((10, 12))
        state = State(position=position)

        ckpt_metric_manager = CkptMetricManager(metric_schema=[],
                                                block_size=block_size,
                                                save_path=self.tmpdir,
                                                lazy_setup=False)
        with ckpt_metric_manager:
            for i in range(block_size + 1):
                ckpt_metric_manager.run(i, state, [])
        return ckpt_metric_manager, state

    def test_has_safe_data_available(self):
        max_restore_nums = 3

        counter = 0
        def test_func():
            # `nonlocal` is required: without it `counter += 1` makes
            # `counter` a local of `test_func` and raises UnboundLocalError
            # instead of the intended one-time failure below.
            nonlocal counter
            counter += 1
            if counter == 1:
                raise Exception()
            return counter, None

        def get_metric_latest_step(ckpt_metric_manager):
            data = ckpt_metric_manager.alive_metric_manager.get_metric_data()
            return data.step.iloc[-1]

        block_size = 10
        ckpt_metric_manager, state = self.get_manager_with_safe_data(block_size)
        self.assertEqual(get_metric_latest_step(ckpt_metric_manager), block_size)

        key = jax.random.PRNGKey(666)
        wrapped_func = recovery_wrapper(test_func,
                                        max_restore_nums=max_restore_nums,
                                        ckpt_metric_manager=ckpt_metric_manager,
                                        key=key)
        step, output = wrapped_func()
        self.assertEqual(step, 1)
        self.assertFalse(output.succeeded)
        chex.assert_trees_all_close(output.state.position, state.position)

        # metric should also be reverted
        self.assertEqual(get_metric_latest_step(ckpt_metric_manager), 0)

    def test_has_safe_data_available_but_exceeds_max_num(self):
        max_restore_nums = 3

        def test_func():
            raise TestException()
        ckpt_metric_manager, state = self.get_manager_with_safe_data()
        key = jax.random.PRNGKey(666)
        wrapped_func = recovery_wrapper(test_func,
                                        max_restore_nums=max_restore_nums,
                                        ckpt_metric_manager=ckpt_metric_manager,
                                        key=key)
        # First `max_restore_nums` exception will be captured and retry be activated.
        for _ in range(max_restore_nums):
            step, output = wrapped_func()
            self.assertFalse(output.succeeded)

        self.assertRaises(TestException, wrapped_func)

if __name__ == '__main__':
    absltest.main()

# --------------------------------------------------------------------------------
# /jaqmc/pp/ph/data.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pathlib

# value: spin
PH_TM_ELEMENTS = {'Cr': 6,
                  'Mn': 5,
                  'Fe': 4,
                  'Co': 3,
                  'Ni': 2,
                  'Cu': 1,
                  'Zn': 0}
# value: (charge, use_hf)
TM_ELEMENTS_INFO = {'Cr': (14.0, False),
                    'Mn': (15.0, True),
                    'Fe': (16.0, False),
                    'Co': (17.0, False),
                    'Ni': (18.0, True),
                    'Cu': (19.0, True),
                    'Zn': (20.0, False)}

# value: spin
PH_MG_ELEMENTS = {'S': 2}
# value: (charge, use_hf)
MG_ELEMENTS_INFO = {'S': (6.0, False)}

PH_ELEMENTS = {**PH_TM_ELEMENTS, **PH_MG_ELEMENTS}

def PH_config(get_config,
              get_ecp_cfg_ref_from_cfg=lambda x: x.ecp,
              get_pyscf_mol_from_cfg=lambda x: x.system.pyscf_mol):
    '''
    Wrap a config factory so that the returned config carries PH info.

    Args:
        get_config: callable producing the config object.
        get_ecp_cfg_ref_from_cfg: extracts the ECP sub-config from the config;
            its `ph_elements` is read and its `ph_info` is written.
        get_pyscf_mol_from_cfg: extracts the pyscf molecule from the config;
            its `_atom` list is passed to `gen_ph_info`.

    Returns:
        A callable taking the same positional arguments as `get_config` and
        returning the config with `ph_info` filled in.
    '''
    def wrapper(*args):
        cfg = get_config(*args)
        pyscf_mol = get_pyscf_mol_from_cfg(cfg)
        ecp_cfg_ref = get_ecp_cfg_ref_from_cfg(cfg)
        ecp_cfg_ref.ph_info = gen_ph_info(pyscf_mol._atom,
                                          ph_elements=ecp_cfg_ref.ph_elements)
        return cfg
    return wrapper

def parse_TM_xml(xml_file):
    '''
    Read the s, p and d radial data arrays from a PH XML file.

    The arrays are taken positionally from `root[2][index][0][1].text` for
    index 0 (s), 1 (p) and 2 (d).

    Returns:
        Tuple `(local_nl + v0_nl, -v0_nl / 6)` of numpy arrays, where
        `local_nl` is the d array and `v0_nl` is `s - d`.
    '''
    import xml.etree.ElementTree as ET
    tree = ET.parse(xml_file)
    root = tree.getroot()

    def get_data_arr(index):
        data = [float(y) for x in root[2][index][0][1].text.split('\n') for y in x.strip().split(' ') if y != '']
        return np.array(data)
    s_arr = get_data_arr(0)
    p_arr = get_data_arr(1)
    d_arr = get_data_arr(2)

    local_nl = d_arr
    v0_nl = s_arr - local_nl
    # `v1_nl` is not used in the result; it is kept to spell out the relation
    # documented below.
    v1_nl = p_arr - local_nl

    # This relation should hold: 2 * v0_nl == 3 * v1_nl
    return local_nl + v0_nl, - v0_nl / 6

def _gen_ph_data_from_xml(element, elements_info, subdir):
    '''
    Load PH data for `element` from `raw_data/<subdir>/<element>.<hf|cc>.xml`.

    Args:
        element: element symbol, a key of `elements_info`.
        elements_info: mapping from symbol to `(charge, use_hf)`.
        subdir: sub-directory of `raw_data` (next to this file) holding the XML.

    Returns:
        Tuple `(loc_data + charge, l2_data)`.
    '''
    charge, use_hf = elements_info[element]
    filename = f'{element}.{"hf" if use_hf else "cc"}.xml'
    xml_file = pathlib.Path(__file__).parent.resolve() / 'raw_data' / subdir / filename
    loc_data, l2_data = parse_TM_xml(str(xml_file))

    # The data in XML is in form:
    #
    #   l2_data = r * v_{L^2}
    #   loc_data + charge = r * \tilde{v}_loc
    return loc_data + charge, l2_data

def gen_ph_data_TM(element):
    '''Return `(loc_data + charge, l2_data)` for an element of `TM_ELEMENTS_INFO`.'''
    return _gen_ph_data_from_xml(element, TM_ELEMENTS_INFO, 'TM')

def gen_ph_data_MG(element):
    '''Return `(loc_data + charge, l2_data)` for an element of `MG_ELEMENTS_INFO`.'''
    return _gen_ph_data_from_xml(element, MG_ELEMENTS_INFO, 'MG')

def gen_ph_data(element):
    '''
    Dispatch to the TM or MG loader depending on which table lists `element`.

    Raises:
        NotImplementedError: if `element` is in neither `PH_TM_ELEMENTS` nor
            `PH_MG_ELEMENTS`.
    '''
    if element in PH_TM_ELEMENTS:
        return gen_ph_data_TM(element)
    elif element in PH_MG_ELEMENTS:
        return gen_ph_data_MG(element)
    else:
        raise NotImplementedError(f'The PH for {element} is not supported yet')

def gen_ph_info(atoms, ph_elements=None):
    '''
    Collect the atoms that should use PH, and the PH data of their elements.

    Args:
        atoms: iterable of `(symbol, position)` pairs.
        ph_elements: container of symbols to treat with PH. If None, every
            symbol present in `PH_ELEMENTS` is treated.

    Returns:
        Tuple `(ph_atom_pos, ph_data)`: the selected `(symbol, position)` pairs
        in input order, and a dict from symbol to `gen_ph_data(symbol)`
        (loaded once per distinct symbol).
    '''
    def should_consider_ph(element):
        # Try all available PH elements
        if ph_elements is None:
            return element in PH_ELEMENTS

        return element in ph_elements

    ph_atom_pos = []
    ph_data = dict()
    for symbol, pos in atoms:
        if should_consider_ph(symbol):
            ph_atom_pos.append((symbol, pos))
            if symbol not in ph_data:
                ph_data[symbol] = gen_ph_data(symbol)
    return (ph_atom_pos, ph_data)

def gen_ph_data_Co():
    '''
    This function generate Pseudo-Hamiltonian corresponding to
    https://pubs.acs.org/doi/10.1021/acs.jctc.1c00992
    , which only support element Co.

    This PH is deprecated due to the presence of
    https://pubs.aip.org/aip/jcp/article/159/16/164114/2918607/Locality-error-free-effective-core-potentials-for
    , which support all the first row transition metals, including Co.
    '''
    def parse_xml(xml_file):
        import xml.etree.ElementTree as ET
        tree = ET.parse(xml_file)
        root = tree.getroot()
        loc_data = [float(y) for x in root[3][0][0][1].text.split('\n') for y in x.strip().split(' ') if y != '']
        L2_data = [float(y) for x in root[2][0][1].text.split('\n') for y in x.strip().split(' ') if y != '']
        return np.array(loc_data), np.array(L2_data)

    xml_file = pathlib.Path(__file__).parent.resolve() / 'raw_data' / 'TM' / 'Co.pure.xml'
    loc_data, l2_data = parse_xml(str(xml_file))

    # The data in XML is in form:
    #
    #   l2_data = r * v_{L^2}
    #   loc_data + 17 = r * \tilde{v}_loc
    #
    # Here `l2_data` and `loc_data` are the data in "L2 section" and s-channel
    # in "semilocal section" in the XML file respectively.
    return loc_data + 17.0, l2_data

# --------------------------------------------------------------------------------
# /tests/dmc/dmc_test.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from absl.testing import absltest
import jax.test_util as jtu
import jax
import jax.numpy as jnp
import numpy as np

from jaqmc.dmc.dmc import do_run_dmc_single_walker, get_to_pad_num

class DmcTest(jtu.JaxTestCase):
    '''Single-walker DMC step tests and walker-padding arithmetic tests.'''

    @staticmethod
    def make_dummy_func(expected_return_val):
        '''Return a function of `position` that always yields `expected_return_val`.'''
        def dummy_func(position):
            return expected_return_val
        return dummy_func

    @staticmethod
    def _switch_on_start(at_start, elsewhere):
        '''Return a function yielding `at_start` when position[0] == 1.0, else `elsewhere`.'''
        def switched(position):
            return at_start if position[0] == 1.0 else elsewhere
        return switched

    @staticmethod
    def _step_walker(position, walker_age, wave_func, velocity_func,
                     energy_func, time_step):
        '''Run one walker step with the fixed settings shared by every test here.'''
        return do_run_dmc_single_walker(
            position,
            walker_age,
            0.0,  # local energy carried in from the previous step
            wave_func,
            velocity_func,
            energy_func,
            time_step=time_step,
            key=jax.random.PRNGKey(42),
            energy_offset=0.1,
            mixed_estimator=0.1,
            nuclei=jnp.ones((1, 3)),
            charges=jnp.ones(3))

    def test_single_walker_across_node(self):
        start = np.array([1.0, 0.0, 0.0])
        age = 0.0

        # The sign component flips as soon as the walker leaves the start point.
        wave_func = self._switch_on_start((1, 1.0), (-1, 1.0))
        energy_func = self._switch_on_start(0.1, -0.1)
        velocity_func = self.make_dummy_func(jnp.ones(3) * 0.1)

        (moved, mean_energy, new_age, _local_energy, _weight_delta_log,
         _delta_R, accept_rate, debug_info) = self._step_walker(
             start, age, wave_func, velocity_func, energy_func, time_step=0.01)
        accepted, *_ = debug_info
        self.assertArraysEqual(start, moved)
        self.assertEqual(mean_energy, 0.0)
        self.assertEqual(new_age, age + 1)
        self.assertEqual(accepted, False)
        self.assertEqual(accept_rate, 0.0)

    def test_single_walker_should_accept_diffusion(self):
        start = np.array([1.0, 0.0, 0.0])
        age = 0.0

        # Same wave-function value everywhere.
        wave_func = self._switch_on_start((1, 1.0), (1, 1.0))
        energy_func = self._switch_on_start(0.1, -0.1)

        drift = jnp.ones(3) * 0.1
        velocity_func = self.make_dummy_func(drift)
        dt = 0.01

        (moved, mean_energy, _new_age, _local_energy, _weight_delta_log,
         _delta_R, accept_rate, debug_info) = self._step_walker(
             start, age, wave_func, velocity_func, energy_func, time_step=dt)
        accepted, *_ = debug_info
        expected_mean_energy = accept_rate * (-0.1) + (1 - accept_rate) * 0.1
        residual_norm = jnp.linalg.norm(moved - start - drift * dt)

        self.assertEqual(accepted, True)
        self.assertNotAlmostEqual(start[0], moved[0])
        self.assertLess(residual_norm, 2 * jnp.sqrt(dt))
        self.assertEqual(mean_energy, expected_mean_energy)

    def test_single_walker_should_accept_old_walker(self):
        start = np.array([1.0, 0.0, 0.0])
        young_age = 0.0
        old_age = 100.0

        # Second component of the wave function drops away from the start point.
        wave_func = self._switch_on_start((1, 1.0), (1, 0.1))
        energy_func = self._switch_on_start(0.1, 0.1)

        drift = jnp.ones(3) * 0.1
        velocity_func = self.make_dummy_func(drift)
        dt = 0.01

        # A fresh walker is expected to reject this move ...
        *_, debug_info = self._step_walker(
            start, young_age, wave_func, velocity_func, energy_func, time_step=dt)
        accepted_young, accept_rate, *_ = debug_info
        self.assertEqual(accepted_young, False)

        # ... while an old walker with the same inputs is expected to accept it.
        (moved, _mean_energy, _new_age, _local_energy, _weight_delta_log,
         _delta_R, accept_rate, debug_info) = self._step_walker(
             start, old_age, wave_func, velocity_func, energy_func, time_step=dt)

        accepted_old, *_ = debug_info
        residual_norm = jnp.linalg.norm(moved - start - drift * dt)

        self.assertEqual(accepted_old, True)
        self.assertNotAlmostEqual(start[0], moved[0])
        self.assertLess(residual_norm, 2 * jnp.sqrt(dt))
        self.assertGreater(accept_rate, 0.8)

    def test_get_to_pad_num(self):
        # (num_walkers, num_device, expected walker count after padding)
        cases = [
            (4096, 8, 4160),
            (4095, 8, 4160),
            (4072, 8, 4160),
            (4071, 8, 4080),
            (40960, 8, 41600),
            (40920, 8, 41600),
            (40961, 8, 41600),
            (40800, 8, 41600),
            (40799, 8, 41600),
            (40791, 8, 40800),
        ]
        for walkers, devices, padded_total in cases:
            self.assertEqual(get_to_pad_num(walkers, devices),
                             padded_total - walkers)

if __name__ == '__main__':
    absltest.main()


# --------------------------------------------------------------------------------
# /jaqmc/pp/ecp_potential.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import jax
import jax.numpy as jnp

from .quadrature import Quadrature
from .special import legendre_list

from typing import Tuple

def _angular_integral_for_loop(eval_displaced, res_den, rea_v, rea, walkers,
                               max_l, key, quadrature, ecp_range):
    r'''
    Shared core of the two `*_for_loop` integrators.

    For every electron index `i` in the two ranges of `ecp_range`, electron
    `i` is moved to each quadrature point (scaled by `rea[i]` and offset by
    `walkers[i] - rea_v[i]`), the wavefunction ratio to the undisplaced
    configuration is evaluated, and the result is integrated against
    (2l+1) * P_l for l = 0 .. max_l - 2.

    inputs:
        eval_displaced: callable (i, new_walkers) -> pair of arrays for the
            batch of displaced configurations; combined with `res_den` as
            `num[0] * den[0] * exp(num[1] - den[1])`.
        res_den: the same pair evaluated at the undisplaced `walkers`.
        rea_v, rea, walkers, max_l, key, quadrature, ecp_range: see
            `numerical_integral_for_loop`.
    returns:
        array of shape (n_electron, max_l - 1); rows of electrons outside
        `ecp_range` stay zero.
    '''
    n_electron = walkers.shape[0]
    ls = list(range(max_l-1))
    res = jnp.zeros((n_electron, len(ls)))
    normal_walkers = rea_v / rea[:, None]

    def psi_r(i, x):
        coords = x.reshape(-1, 3) * rea[i] + walkers[i] - rea_v[i]
        new_walkers = jnp.tile(walkers, (coords.shape[0],) + (1, 1))
        new_walkers = new_walkers.at[:, i, :].set(coords)
        res_num = eval_displaced(i, new_walkers)
        ratio = res_num[0] * res_den[0] * jnp.exp(res_num[1] - res_den[1])
        return ratio.reshape(x.shape[:-1])
    npts = quadrature.discrete_pts(key, N=n_electron)

    def Pl(i, x):
        ## Pl(cos(\theta))
        return legendre_list(jnp.matmul(x, normal_walkers[i, :]), ls)

    def integral(i, res):
        result = quadrature.integral_value(Pl(i, npts[i]), psi_r(i, npts[i])) * (2 * jnp.array(ls) + 1)
        res = res.at[i, :].set(result)
        return res
    n_up_start, n_up_end, n_down_start, n_down_end = ecp_range
    res = jax.lax.fori_loop(n_up_start, n_up_end, integral, res)
    res = jax.lax.fori_loop(n_down_start, n_down_end, integral, res)

    return res

def numerical_integral_for_loop(psi, rea_v, rea, walkers, max_l, key, quadrature: Quadrature, ecp_range: Tuple):
    r'''
    ref: Nonlocal pseudopotentials and diffusion Monte Carlo, equation 28

    inputs:
        psi: wave function; psi(walkers) returns a pair that is combined as
            `num[0] * den[0] * exp(num[1] - den[1])`, i.e. presumably
            (sign, log|psi|) -- confirm against the caller.
        rea_v: vector between electron and one atom, shape (n_electron, 3)
        rea: distance between electron and that atom, shape (n_electron,)
        walkers: shape (n_electron, 3)
        max_l: int; channels l = 0 .. max_l - 2 are evaluated
        key: random number
        quadrature: A quadrature object to do numerical integration.
        ecp_range: (n_up_start, n_up_end, n_down_start, n_down_end)
    returns:
        value of the integral \int (2l+1) * P_l(cos theta) psi(r1,..,ri,..)
        shape (n_electron, max_l - 1)
    '''
    psi_vec = jax.vmap(psi, in_axes=0)
    return _angular_integral_for_loop(
        lambda i, new_walkers: psi_vec(new_walkers),
        psi(walkers),
        rea_v, rea, walkers, max_l, key, quadrature, ecp_range)

def numerical_integral_optim_for_loop(f_modify, f_memory, f_memory_update,
                                      rea_v, rea, walkers, max_l, key, quadrature: Quadrature,
                                      ecp_range: Tuple):
    r'''
    ref: Nonlocal pseudopotentials and diffusion Monte Carlo, equation 28

    Same integral as `numerical_integral_for_loop`, but the wavefunction is
    evaluated through a memory object instead of a plain `psi` call.

    inputs:
        f_modify: callable (memory, walkers) -> the pair used as the
            wavefunction value (see `numerical_integral_for_loop`).
        f_memory: callable walkers -> memory for the undisplaced configuration.
        f_memory_update: callable (memory, new_walkers, i) -> memory for a
            configuration in which electron `i` has been moved.
        rea_v: vector between electron and one atom, shape (n_electron, 3)
        rea: distance between electron and that atom, shape (n_electron,)
        walkers: shape (n_electron, 3)
        max_l: int; channels l = 0 .. max_l - 2 are evaluated
        key: random number
        quadrature: A quadrature object to do numerical integration.
        ecp_range: (n_up_start, n_up_end, n_down_start, n_down_end)
    returns:
        value of the integral \int (2l+1) * P_l(cos theta) psi(r1,..,ri,..)
        shape (n_electron, max_l - 1)
    '''
    memory = f_memory(walkers)
    f_memory_update_vec = jax.vmap(f_memory_update, in_axes=(None, 0, None))
    f_modify_vec = jax.vmap(f_modify, in_axes=(0, 0), out_axes=(0))

    def eval_displaced(i, new_walkers):
        memory_vec = f_memory_update_vec(memory, new_walkers, i)
        return f_modify_vec(memory_vec, new_walkers)

    return _angular_integral_for_loop(
        eval_displaced,
        f_modify(memory, walkers),
        rea_v, rea, walkers, max_l, key, quadrature, ecp_range)


def numerical_integral(psi, rea_v, rea, walkers, max_l, key,
                       quadrature: Quadrature, ecp_range: Tuple):
    r'''
    ref: Nonlocal pseudopotentials and diffusion Monte Carlo, equation 28

    Maps `numerical_integral_for_loop` over axis 1 of `rea_v` and `rea`
    (presumably the atom axis).

    inputs:
        psi: wave function, see `numerical_integral_for_loop`
        rea_v: vector between electron and atoms; axis 1 is mapped over
        rea: distance between electron and atoms; axis 1 is mapped over
        walkers: shape (n_electron, 3)
        max_l: int; channels l = 0 .. max_l - 2 are evaluated
        key: random number
        quadrature: A quadrature object to do numerical integration.
        ecp_range: (n_up_start, n_up_end, n_down_start, n_down_end)
    returns:
        value of the integral \int (2l+1) * P_l(cos theta) psi(r1,..,ri,..)
        shape (n_electron, size of mapped axis, max_l - 1)
    '''

    numerical_integral_exact_closure = lambda rv, r : \
        numerical_integral_for_loop(psi, rv, r, walkers, max_l, key, quadrature, ecp_range)
    numerical_integral_exact_vmap = jax.vmap(numerical_integral_exact_closure, in_axes=(1,1), out_axes=(0))
    res = numerical_integral_exact_vmap(rea_v, rea)
    # vmap stacks the mapped axis first; move the electron axis back to front.
    return res.transpose(1,0,2)

def numerical_integral_optim(f_modify, f_memory, f_memory_update, rea_v, rea,
                             walkers, max_l, key, quadrature: Quadrature, ecp_range: Tuple):
    r'''
    ref: Nonlocal pseudopotentials and diffusion Monte Carlo, equation 28

    Maps `numerical_integral_optim_for_loop` over axis 1 of `rea_v` and `rea`
    (presumably the atom axis).

    inputs:
        f_modify, f_memory, f_memory_update: see
            `numerical_integral_optim_for_loop`
        rea_v: vector between electron and atoms; axis 1 is mapped over
        rea: distance between electron and atoms; axis 1 is mapped over
        walkers: shape (n_electron, 3)
        max_l: int; channels l = 0 .. max_l - 2 are evaluated
        key: random number
        quadrature: A quadrature object to do numerical integration.
        ecp_range: (n_up_start, n_up_end, n_down_start, n_down_end)
    returns:
        value of the integral \int (2l+1) * P_l(cos theta) psi(r1,..,ri,..)
        shape (n_electron, size of mapped axis, max_l - 1)
    '''

    numerical_integral_exact_closure = lambda rv, r : \
        numerical_integral_optim_for_loop(f_modify, f_memory, f_memory_update, rv, r, walkers, max_l, key, quadrature, ecp_range)
    numerical_integral_exact_vmap = jax.vmap(numerical_integral_exact_closure, in_axes=(1,1), out_axes=(0))
    res = numerical_integral_exact_vmap(rea_v, rea)
    # vmap stacks the mapped axis first; move the electron axis back to front.
    return res.transpose(1,0,2)

# --------------------------------------------------------------------------------
# /jaqmc/pp/ph/hamiltonian.py:
# --------------------------------------------------------------------------------
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
def get_forward_laplacian_for_kinetic_ph(f, raw_ph_data, ph_atom_pos, rv_type='spline'):
    """Build the pseudo-Hamiltonian (ph) energy function for the chosen interpolation.

    Args:
      f: callable invoked as `f(params, x)`. It is differentiated with
        `jax.grad` and also evaluated on a `lapjax.LapTuple`, so it must return
        a scalar (presumably the log wavefunction -- confirm with the caller).
      raw_ph_data: mapping from atom identifier to a `(loc_data, l2_data)` pair
        of tabulated radial data.
      ph_atom_pos: iterable of `(atom, position)` pairs; `atom` must be a key
        of `raw_ph_data`.
      rv_type: 'spline' (cubic spline interpolation) or 'linear'
        (`jnp.interp`).

    Returns:
      Callable `kinetic_ph(params, data)`.

    Raises:
      NotImplementedError: if `rv_type` is neither 'spline' nor 'linear'.
    """
    if rv_type == 'spline':
        return get_forward_laplacian_for_kinetic_ph_spline(f, raw_ph_data, ph_atom_pos)
    elif rv_type == 'linear':
        return get_forward_laplacian_for_kinetic_ph_linear(f, raw_ph_data, ph_atom_pos)
    else:
        raise NotImplementedError(f'rv_type {rv_type} not supported yet')

def get_forward_laplacian_for_kinetic_ph_linear(f, raw_ph_data, ph_atom_pos):
    """Variant that interpolates the tabulated data linearly with `jnp.interp`.

    The per-atom tables are stacked into arrays so that the sum over atoms can
    be expressed with `jax.lax.scan`. See `get_forward_laplacian_for_kinetic_ph`
    for the meaning of the arguments.

    Returns:
      Callable `kinetic_ph(params, data)` returning the sum of a second-order,
      a first-order and a zero-order term evaluated at `data`.
    """
    # lapjax is imported lazily so that the module can be imported without it.
    from lapjax import LapTuple, TupType
    from lapjax.numpy import matmul as lapjax_matmul

    # Radial grid on which the tables are sampled: 10001 points on [0, 10].
    rx = jnp.array(np.linspace(0, 10.0, 10001))
    def rV(r, arr):
        # Linear interpolation of table `arr` at radius `r`.
        # The callers divide the l = 0 result by r, so the tables presumably
        # hold r * V(r) -- TODO confirm against the data files.
        return jnp.interp(r, rx, jnp.asarray(arr))

    def prepare_ph_for_all_atoms():
        # Stack the tables and positions of every atom listed in `ph_atom_pos`
        # (one leading axis entry per atom) so they can be scanned over.
        all_loc_data = []
        all_l2_data = []
        all_atom_pos = []
        for atom, _atom_pos in ph_atom_pos:
            loc_data, l2_data = raw_ph_data[atom]
            all_loc_data.append(loc_data)
            all_l2_data.append(l2_data)
            all_atom_pos.append(_atom_pos)
        return (
            jnp.array(all_loc_data),
            jnp.array(all_l2_data),
            jnp.array(all_atom_pos))
    ph_data = prepare_ph_for_all_atoms()

    @jax.vmap
    def get_eff_mass(x):
        # For one electron position `x` (vmapped over electrons) build the 3x3
        # matrix 0.5 * I + sum_atoms rV_l2(r) * (r * I - d d^T / r), where
        # d = x - atom_pos and r = |d|.
        def for_scan_f(carry, input):
            loc_data, l2_data, atom_pos = input
            r = jnp.linalg.norm(x - atom_pos)
            rv_l2_val = rV(r, l2_data)
            diag = rv_l2_val * r
            mass = -rv_l2_val * jnp.outer(x - atom_pos, x - atom_pos) / r + jnp.identity(3) * diag
            return mass + carry, None

        mass, _ = jax.lax.scan(for_scan_f, jnp.zeros([3, 3]), ph_data)
        mass += 0.5 * jnp.identity(3)
        return mass

    def kinetic_ph(params, data):
        """Evaluate second_order() + first_order() + zero_order() at `data`.

        `data` is reshaped to (-1, 3), i.e. treated as flattened 3D positions.
        """
        f_closure = lambda x: f(params, x)
        data_per_elec = data.reshape(-1, 3)

        def second_order():
            # Forward-Laplacian evaluation of `f` with the input gradients
            # multiplied by the transposed Cholesky factor of the per-electron
            # matrix from `get_eff_mass`.
            mass = get_eff_mass(data_per_elec)
            dim = 3
            L = jnp.linalg.cholesky(mass).transpose(0, 2, 1)
            input_laptuple = LapTuple(data, is_input=True)
            # NOTE(review): the multiplication by the identity looks like a way
            # to obtain a LapTuple with a dense per-electron gradient that can
            # be overwritten below -- confirm against lapjax.
            input_laptuple: LapTuple = lapjax_matmul(input_laptuple.reshape(-1, dim), jnp.eye(dim))
            input_laptuple.grad = jnp.matmul(L,input_laptuple.grad.transpose(1, 0, 2)).transpose(1, 0, 2)
            output = f(params,input_laptuple)
            return -output.get(TupType.LAP) - jnp.sum(output.get(TupType.GRAD) ** 2)

        def first_order():
            @jax.vmap
            def v_dot(grad, x):
                # Dot product of one electron's gradient block with
                # sum_atoms 2 * rV_l2(r) / r * (x - atom_pos).
                def for_scan_f(carry, input):
                    loc_data, l2_data, atom_pos = input
                    rel_pos = x - atom_pos
                    r = jnp.linalg.norm(rel_pos)
                    return carry + rV(r, l2_data) * 2 / r * rel_pos, None

                output, _ = jax.lax.scan(for_scan_f, jnp.zeros(3), ph_data)
                return jnp.dot(grad, output)

            grad = jax.grad(f_closure)(data)
            grad_blocks = grad.reshape(-1, 3)
            return jnp.sum(v_dot(grad_blocks, data_per_elec))

        def zero_order():
            # NOTE: this local `f` shadows the outer `f` only inside
            # `zero_order`; the other terms still use the outer callable.
            @jax.vmap
            def f(x):
                # sum_atoms rV_loc(r) / r for one electron position.
                def for_scan_f(carry, input):
                    loc_data, l2_data, atom_pos = input
                    r = jnp.linalg.norm(x - atom_pos)
                    return carry + rV(r, loc_data) / r, None
                output, _ = jax.lax.scan(for_scan_f, 0.0, ph_data)
                return output

            return jnp.sum(f(data_per_elec))
        return (
            second_order()
            + first_order() + zero_order())
    return kinetic_ph

def get_forward_laplacian_for_kinetic_ph_spline(f, raw_ph_data, ph_atom_pos):
    """Variant that interpolates the tabulated data with cubic splines.

    One `InterpolatedUnivariateSpline` (k=3) is built per table and per entry
    of `raw_ph_data`; the sum over atoms is a plain Python loop. See
    `get_forward_laplacian_for_kinetic_ph` for the meaning of the arguments.

    Returns:
      Callable `kinetic_ph(params, data)` returning the sum of a second-order,
      a first-order and a zero-order term evaluated at `data`.
    """
    # lapjax is imported lazily so that the module can be imported without it.
    from lapjax import LapTuple, TupType
    from lapjax.numpy import matmul as lapjax_matmul

    def prepare_rV_cubic_spline():
        # Same radial grid as the linear variant: 10001 points on [0, 10].
        rx = jnp.array(np.linspace(0, 10.0, 10001))
        rV_dict = {}
        i_atom = 0
        # Splines are keyed by (index of the atom in `raw_ph_data`, mode).
        for atom, (loc_data, l2_data) in raw_ph_data.items():
            rV_dict[(i_atom, 0)] = interpolate.InterpolatedUnivariateSpline(rx, loc_data, k=3)
            rV_dict[(i_atom, 1)] = interpolate.InterpolatedUnivariateSpline(rx, l2_data, k=3)
            i_atom += 1

        def rV(r, i_atom, mode=0):
            # mode=0 for loc_data, mode=1 for l2_data
            return rV_dict[(i_atom, mode)](r)

        return rV

    rV = prepare_rV_cubic_spline()

    def prepare_ph_for_all_atoms():
        # Pair every atom position with the index of its species in
        # `raw_ph_data`, matching the keys used by `prepare_rV_cubic_spline`.
        res = []
        for atom, _atom_pos in ph_atom_pos:
            res.append((jnp.array(_atom_pos), list(raw_ph_data).index(atom)))
        return res
    ph_data = prepare_ph_for_all_atoms()

    @jax.vmap
    def get_eff_mass(x):
        # For one electron position `x` (vmapped over electrons) build the 3x3
        # matrix 0.5 * I + sum_atoms rV_l2(r) * (r * I - d d^T / r), where
        # d = x - atom_pos and r = |d|.
        mass = jnp.zeros([3, 3])
        for atom_pos, i_atom in ph_data:
            r = jnp.linalg.norm(x - atom_pos)
            rv_l2_val = rV(r, i_atom, mode=1)
            diag = rv_l2_val * r
            mass += -rv_l2_val * jnp.outer(x - atom_pos, x - atom_pos) / r + jnp.identity(3) * diag
        mass += 0.5 * jnp.identity(3)
        return mass

    def kinetic_ph(params, data):
        """Evaluate second_order() + first_order() + zero_order() at `data`.

        `data` is reshaped to (-1, 3), i.e. treated as flattened 3D positions.
        """
        f_closure = lambda x: f(params, x)
        data_per_elec = data.reshape(-1, 3)

        def second_order():
            # Forward-Laplacian evaluation of `f` with the input gradients
            # multiplied by the transposed Cholesky factor of the per-electron
            # matrix from `get_eff_mass`.
            mass = get_eff_mass(data_per_elec)
            dim = 3
            L = jnp.linalg.cholesky(mass).transpose(0, 2, 1)
            input_laptuple = LapTuple(data, is_input=True)
            input_laptuple: LapTuple = lapjax_matmul(input_laptuple.reshape(-1, dim), jnp.eye(dim))
            input_laptuple.grad = jnp.matmul(L,input_laptuple.grad.transpose(1, 0, 2)).transpose(1, 0, 2)
            output = f(params,input_laptuple)
            return -output.get(TupType.LAP) - jnp.sum(output.get(TupType.GRAD) ** 2)

        def first_order():
            @jax.vmap
            def v_dot(grad, x):
                # Dot product of one electron's gradient block with
                # sum_atoms 2 * rV_l2(r) / r * (x - atom_pos).
                carry = jnp.zeros(3)
                for atom_pos, i_atom in ph_data:
                    rel_pos = x - atom_pos
                    r = jnp.linalg.norm(rel_pos)
                    carry += rV(r, i_atom, mode=1) * 2 / r * rel_pos
                return jnp.dot(grad, carry)

            grad = jax.grad(f_closure)(data)
            grad_blocks = grad.reshape(-1, 3)
            return jnp.sum(v_dot(grad_blocks, data_per_elec))

        def zero_order():
            # NOTE: this local `f` shadows the outer `f` only inside
            # `zero_order`; the other terms still use the outer callable.
            @jax.vmap
            def f(x):
                # sum_atoms rV_loc(r) / r for one electron position.
                carry = 0.0
                for atom_pos, i_atom in ph_data:
                    r = jnp.linalg.norm(x - atom_pos)
                    carry += rV(r, i_atom, mode=0) / r
                return carry

            return jnp.sum(f(data_per_elec))
        return (
            second_order()
            + first_order()
            + zero_order())
    return kinetic_ph
def local_kinetic_energy(f, partition_num=0):
    r"""Build the local kinetic energy function of a log wavefunction.

    Args:
      f: log wave function; maps a flat electron configuration to a scalar.
      partition_num: selects the Laplacian implementation.
        0: accumulate one Hessian diagonal entry at a time with `fori_loop`.
        1: materialize the full Hessian.
        any other positive integer: split the unit vectors into that many
          chunks and process one chunk per `lax.scan` step; must divide the
          input dimension.

    Returns:
      Callable `lapl(x)` evaluating
      -1/2 (\nabla^2 log|f| + (\nabla log|f|)^2) at `x`.
    """
    batched_jvp = jax.vmap(jax.jvp, in_axes=(None, None, 0))

    def _lapl_over_f(x):
        dim = x.shape[0]
        basis = jnp.eye(dim)
        grad_f = jax.grad(f)

        if partition_num == 1:
            # Full Hessian: simplest, most memory hungry.
            gradient = grad_f(x)
            hessian = jax.hessian(f)(x)
            return -0.5 * (jnp.trace(hessian) + jnp.sum(gradient ** 2))

        if partition_num == 0:
            # One forward-over-reverse pass per coordinate.
            def _accumulate(i, total):
                grad_val, hess_col = jax.jvp(grad_f, (x,), (basis[i],))
                return total + grad_val[i]**2 + hess_col[i]
            return -0.5 * lax.fori_loop(0, dim, _accumulate, 0.0)

        # Chunked evaluation: each scan step handles dim / partition_num
        # unit vectors at once.
        assert dim % partition_num == 0, f'partition_num {partition_num} does not divide the dimension {dim}'
        chunks = jnp.asarray(jnp.array_split(basis, partition_num))

        def _scan_step(carry, chunk):
            grad_vals, hess_cols = batched_jvp(grad_f, (x,), (chunk,))
            return carry, (grad_vals, hess_cols)

        _, (grad_vals, hess_cols) = lax.scan(_scan_step, None, chunks)
        # Merge the (chunk, within-chunk) axes back into a (dim, dim) layout.
        grad_vals = grad_vals.reshape((-1, grad_vals.shape[-1]))
        hess_cols = hess_cols.reshape((-1, hess_cols.shape[-1]))
        return -0.5 * (jnp.sum(jnp.diagonal(grad_vals) ** 2) + jnp.trace(hess_cols))

    return _lapl_over_f


def potential_energy(r_ae, r_ee, atoms, charges):
    """Coulomb potential energy of one electron configuration.

    Args:
      r_ae: Shape (nelectrons, natoms, 1). Electron-atom distances.
      r_ee: Shape (nelectrons, nelectrons, :). `r_ee[i, j, 0]` is the distance
        between electrons i and j; other entries of the last axis are unused.
      atoms: Shape (natoms, ndim). Positions of the atoms.
      charges: Shape (natoms). Nuclear charges of the atoms.

    Returns:
      Sum of the electron-electron, electron-atom and atom-atom terms.
    """
    # Strict upper triangles count every pair once and drop the diagonal.
    electron_electron = jnp.sum(jnp.triu(1 / r_ee[..., 0], k=1))
    electron_atom = -jnp.sum(charges / r_ae[..., 0])  # pylint: disable=invalid-unary-operand-type
    atom_atom_dist = jnp.linalg.norm(atoms[None, ...] - atoms[:, None], axis=-1)
    charge_products = charges[None, ...] * charges[..., None]
    atom_atom = jnp.sum(jnp.triu(charge_products / atom_atom_dist, k=1))
    return electron_electron + electron_atom + atom_atom


def get_dist(
    x: jnp.ndarray,
    atoms: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute electron-atom and electron-electron distances.

    Args:
      x: electron positions. Shape (nelectrons * 3,).
      atoms: atom positions. Shape (natoms, 3).

    Returns:
      `(r_ae, r_ee)` where
      r_ae: electron-atom distances. Shape (nelectron, natom, 1).
      r_ee: electron-electron distances. Shape (nelectron, nelectron, 1).
      The diagonal of r_ee is masked to zero in a way that also keeps its
      gradient zero.
    """
    ndim = 3
    electrons = jnp.reshape(x, [-1, ndim])
    elec_to_atom = electrons[:, None, :] - atoms[None, ...]
    elec_to_elec = electrons[None, :, :] - electrons[:, None, :]

    r_ae = jnp.linalg.norm(elec_to_atom, axis=2, keepdims=True)
    # The norm of the zero vector has an undefined gradient: shift the diagonal
    # away from zero before taking the norm, then mask it out.
    diagonal = jnp.eye(elec_to_elec.shape[0])
    shifted_norm = jnp.linalg.norm(elec_to_elec + diagonal[..., None], axis=-1)
    r_ee = shifted_norm * (1.0 - diagonal)

    return r_ae, r_ee[..., None]


def local_energy(f, atoms, charges, el_partition_num=0):
    """Build the local energy function (potential + kinetic).

    Args:
      f: Callable `f(data)` returning the log magnitude of the wavefunction
        for one configuration.
      atoms: Shape (natoms, ndim). Positions of the atoms.
      charges: Shape (natoms). Nuclear charges of the atoms.
      el_partition_num: forwarded to `local_kinetic_energy` as `partition_num`
        (0: fori_loop, 1: Hessian, other positive integer: chunked Laplacian).

    Returns:
      Callable `e_l(data)` evaluating the local energy of a single MCMC
      configuration.
    """
    kinetic_fn = local_kinetic_energy(f, el_partition_num)

    def _e_l(x):
        """Returns the total energy of the MCMC configuration `x`."""
        r_ae, r_ee = get_dist(x, atoms)
        return potential_energy(r_ae, r_ee, atoms, charges) + kinetic_fn(x)

    return _e_l

def make_calc_energy_func(el_fun, clip_pair=None):
    '''
    Factory for the walker-averaged energy, computed with the local energy
    function `el_fun` and spread over all local devices.

    The walkers are zero-padded so that their number is a multiple of the
    local device count; padded entries carry mask 0 and skip `el_fun`.
    '''
    def _masked_local_energy(position, mask):
        # Padded walkers (mask == 0) contribute a constant 0.0.
        return jax.lax.cond(
            mask,
            el_fun,
            lambda _: 0.0,
            position)

    num_device = jax.local_device_count()
    # Both position and mask are mapped along their leading axis, first over
    # devices (pmap) and then over walkers on a device (vmap).
    pmaped_energy_func = jax.pmap(
        jax.vmap(_masked_local_energy, in_axes=(0, 0)),
        in_axes=(0, 0))

    def calc_energy(flatten_position):
        num_walkers, walker_dim = flatten_position.shape
        # Number of dummy walkers needed to reach a multiple of num_device.
        to_pad_num = -num_walkers % num_device
        mask = jnp.pad(
            jnp.ones(num_walkers),
            ((0, to_pad_num),),
            constant_values=0).reshape((num_device, -1))
        position = jnp.pad(
            flatten_position,
            ((0, to_pad_num), (0, 0)),
            constant_values=0).reshape((num_device, -1, walker_dim))
        energies = pmaped_energy_func(position, mask)
        return calc_masked_energy(energies, mask, clip_pair=clip_pair)

    return calc_energy

def calc_masked_energy(local_energy, mask, clip_pair=None):
    '''
    Masked mean (via `agg_mean`) of `local_energy`, clipped to
    `clip_pair = (clip_min, clip_max)` first when `clip_pair` is given.
    '''
    if clip_pair is not None:
        lower, upper = clip_pair
        local_energy = jnp.clip(local_energy, lower, upper)
    return agg_mean(local_energy, mask)
# Placeholder state type: the VMC loss has no variables that need to be carried
# through the training loop, so this dataclass has no fields.
@chex.dataclass
class VMCFuncState:
    pass

@chex.dataclass
class VMCAuxData(utils.BaseAuxData):
    """
    Auxiliary data returned from energy calculation.

    Attributes:
      variance: masked mean of the squared deviation of the local energy from
        the loss, as computed by `utils.pmean_with_mask`.
      local_energy: local energy for each MCMC configuration, after
        `jnp.nan_to_num`.
      outlier_mask: boolean array which is True for the walkers that are kept,
        i.e. finite and (when outlier removal is enabled) not an outlier.
    """
    variance: jnp.ndarray
    local_energy: jnp.ndarray
    outlier_mask: jnp.ndarray

def make_vmc_loss(
    signed_network: utils.WaveFuncLike,
    local_energy: utils.LocalEnergy,
    clip_local_energy=0.0,
    rm_outlier=False,
    el_partition=1,
    local_energy_outlier_width=0.0) -> utils.Loss:
    """
    Creates the loss function corresponding to the energy calculation.

    Args:
      signed_network: network wavefunction returning both sign and log magnitude.
      local_energy: callable which evaluates the local energy. It is called as
        `local_energy(params, key, data)` for a single configuration.
      clip_local_energy: If greater than zero, clip local energies that are
        outside [E_m - n D, E_m + n D] to the boundaries, where E_m is the median
        local energy, n is this value and D the masked mean absolute deviation
        of the local energies from the median. The clipped local energies are
        only used to evaluate gradients.
      rm_outlier: If True, outliers will be removed from the computation from both
        loss and its gradients, otherwise outliers would be clipped when
        computing gradients, in which case clipping won't happen in the computation
        of the loss value.
      el_partition: If greater than one, the batch is reshaped into this many
        folds and the local energy is evaluated one fold at a time with
        `jax.lax.scan` to save memory. The reshape requires it to divide the
        batch size.
      local_energy_outlier_width: If greater than zero, the local energy outliers
        will be identified as the ones that are
        outside [E_L - n D, E_L + n D], where E_L is the mean local energy, n is
        this value and D the mean absolute deviation of the local energies from
        the mean. Those outliers will be removed from the calculation
        of both the energy and its gradient, if `rm_outlier` is True.
    Returns:
      Callable with signature (params, func_state, key, data) which returns
      (loss, (None, aux_data)), where loss is the mean energy, and aux_data is
      of VMCAuxData.
      The loss is averaged with `utils.pmean_with_mask` (presumably over the
      batch and over all devices inside a pmap -- see that helper).
    """
    # Keep only the log magnitude (second output) of the signed network.
    network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1]
    # params are shared; keys and data are mapped along the batch axis.
    batch_local_energy = jax.vmap(local_energy, in_axes=(None, 0, 0), out_axes=0)
    batch_network = jax.vmap(network, in_axes=(None, 0), out_axes=0)


    @jax.custom_jvp
    def total_energy(
        params: utils.ParamTree,
        func_state: None,
        key: chex.PRNGKey,
        data: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, Tuple[None, VMCAuxData]]:
        """
        Evaluates the total energy of the neural network wavefunction.

        Args:
          params: parameters of the neural network wavefunction.
          func_state: To pass the variables to be updated in training loop.
            Unused here.
          key: PRNG state; split into one key per configuration.
          data: Batched MCMC configurations to pass to the local energy function.

        Returns:
          (loss, (None, aux_data)), where loss is the averaged energy, and `None` for
          func_state, we don't need func_state so we just return `None`, aux_data
          is a VMCAuxData object containing the variance of the energy, the
          local energy per MCMC configuration and the mask of kept walkers.
        """

        # we don't have any variable to be updated in training loop.
        del func_state

        keys = jax.random.split(key, num=data.shape[0])
        if el_partition > 1 :
            # create el_partition folds to save the memory when computing local energy
            data = data.reshape((el_partition,-1)+data.shape[1:])
            keys = keys.reshape((el_partition,-1)+keys.shape[1:])
            def batch_el_scan(carry,x):
                # x is one fold of [keys, data]; the carry is unused.
                return carry, batch_local_energy(params, *x)
            _,e_l = jax.lax.scan(batch_el_scan, None, [keys,data])
            # Flatten the (fold, within-fold) axes back into one batch axis.
            e_l = e_l.reshape(-1)
        else:
            e_l = batch_local_energy(params, keys, data)
        # is_finite is false for inf and nan. We should throw them away anyways.
        is_finite = jnp.isfinite(e_l)
        # Then we convert nan to 0 and inf to large numbers, otherwise we won't
        # be able to mask them out. It's ok to do this cast because they will be
        # masked away in the following computation.
        e_l = jnp.nan_to_num(e_l)

        # if not `rm_outlier`, which means we will do clipping instead, in which case
        # we don't clip when computing the energy but do clip in gradient computation.
        if rm_outlier and local_energy_outlier_width > 0.:
            # This loss is computed only for outlier computation
            loss = utils.pmean_with_mask(e_l, is_finite)
            tv = utils.pmean_with_mask(jnp.abs(e_l - loss), is_finite)
            # Keep walkers strictly inside the window around the mean.
            mask = (
                (loss - local_energy_outlier_width * tv < e_l) &
                (loss + local_energy_outlier_width * tv > e_l) &
                is_finite)
        else:
            mask = is_finite

        loss = utils.pmean_with_mask(e_l, mask)
        variance = utils.pmean_with_mask((e_l - loss)**2, mask)

        return loss, (None, VMCAuxData(variance=variance,
                                       local_energy=e_l,
                                       outlier_mask=mask))

    @total_energy.defjvp
    def total_energy_jvp(primals, tangents):
        """Custom Jacobian-vector product for unbiased local energy gradients."""

        # func_state is not needed and assigned as `_`.
        params, _, key, data = primals
        loss, (func_state, aux_data) = total_energy(params, None, key, data)

        if clip_local_energy > 0.0:
            # We have to gather the el from all devices and then compute the median
            # otherwise the median would be different on different devices
            median = jnp.median(utils.gather(aux_data.local_energy))

            # We have to apply mask here to remove the effect of possible inf and nan.
            tv = utils.pmean_with_mask(jnp.abs(aux_data.local_energy - median), aux_data.outlier_mask)
            diff = jnp.clip(aux_data.local_energy,
                            median - clip_local_energy * tv,
                            median + clip_local_energy * tv)
            # renormalize diff
            diff = diff - utils.pmean_with_mask(diff, aux_data.outlier_mask)
            # Normalize by the number of kept walkers.
            # NOTE(review): this is zero if every walker is masked out -- confirm
            # that this cannot happen in practice.
            device_batch_size = jnp.sum(aux_data.outlier_mask)
        else:
            diff = aux_data.local_energy - loss
            # Without clipping the full per-device batch size is used, even
            # though masked walkers are zeroed out just below.
            device_batch_size = jnp.shape(aux_data.local_energy)[0]
        # Masked walkers must not contribute to the gradient.
        diff *= aux_data.outlier_mask

        # Due to the simultaneous requirements of KFAC (calling convention must be
        # (params, rng, data)) and Laplacian calculation (only want to take
        # Laplacian wrt electron positions) we need to change up the calling
        # convention between total_energy and batch_network
        primals = primals[0], primals[3]
        tangents = tangents[0], tangents[3]
        psi_primal, psi_tangent = jax.jvp(batch_network, primals, tangents)
        kfac_jax.register_normal_predictive_distribution(psi_primal[:, None])
        primals_out = loss, (func_state, aux_data)

        tangents_out = (jnp.dot(psi_tangent, diff) / device_batch_size, (func_state, aux_data))
        return primals_out, tangents_out

    return total_energy
class StorageHandler(ABC):
    '''
    Abstract interface of a storage backend (local disk, HDFS, ...).

    put, get, mv, cp would do overwrite when the destined path already exists.
    '''

    @abstractmethod
    def put(self, src_path: str, dst_path: str):
        """
        Put the file at local `src_path` to the `dst_path` at the file system
        handled by this handler.
        If `dst_path` already exists, then it will be overwritten.
        """

    @abstractmethod
    def get(self, src_path: str, dst_path: str):
        """
        Get the file from `src_path` at the file system handled by this handler
        to local `dst_path`.
        If `dst_path` already exists, then it will be overwritten.
        """

    @abstractmethod
    def rm(self, path: str):
        '''
        Should be able to remove both file and directory.
        '''

    @abstractmethod
    def ls(self, path: str, return_fullpath: bool = True) -> List[str]:
        '''
        List the entries under `path`.
        If `return_fullpath` is True, then the return values should be full
        paths, otherwise it should only return file names.
        '''

    @abstractmethod
    def mkdir(self, path: str):
        '''
        Create the directory `path`, including missing parents.
        '''

    @abstractmethod
    def mv(self, src_path: str, dst_path: str):
        """
        Both `src_path` and `dst_path` should be at the file system handled by this handler.
        If `dst_path` already exists, then it will be overwritten.
        """

    @abstractmethod
    def cp(self, src_path: str, dst_path: str):
        """
        Both `src_path` and `dst_path` should be at the file system handled by this handler.
        If `dst_path` already exists, then it will be overwritten.
        """

    @abstractmethod
    def exists(self, path: str) -> bool:
        '''
        Whether `path`, file or directory, exists.
        '''

    @abstractmethod
    def exists_dir(self, directory: str) -> bool:
        '''
        Whether `directory` exists. Return True only when `directory` exists
        and it's actually a directory.
        '''

class DummyStorageHandler(StorageHandler):
    '''
    A dummy implementation of `StorageHandler` which does no-op for
    all the methods: nothing is stored, listings are empty and nothing exists.
    '''

    @staticmethod
    def put(src_path: str, dst_path: str):
        return

    @staticmethod
    def get(src_path: str, dst_path: str):
        return

    @staticmethod
    def rm(path: str):
        return

    @staticmethod
    def ls(path: str, return_fullpath: bool = True) -> List[str]:
        return []

    @staticmethod
    def mkdir(path: str):
        return

    @staticmethod
    def mv(src_path: str, dst_path: str):
        return

    @staticmethod
    def cp(src_path: str, dst_path: str):
        return

    @staticmethod
    def exists(path: str) -> bool:
        return False

    @staticmethod
    def exists_dir(path: str) -> bool:
        return False

dummy_storage_handler = DummyStorageHandler()

class LocalStorageHandler(StorageHandler):
    '''
    `StorageHandler` backed by the local file system.
    All methods are static, so they can be used without an instance.
    '''

    @staticmethod
    def put(src_path: str, dst_path: str):
        # Source and destination are both local: put is a plain copy.
        LocalStorageHandler.cp(src_path, dst_path)

    @staticmethod
    def get(src_path: str, dst_path: str):
        # Source and destination are both local: get is a plain copy.
        LocalStorageHandler.cp(src_path, dst_path)

    @staticmethod
    def rm(path: str):
        '''
        Remove the file or directory tree at `path`; no-op if it is missing.
        '''
        if not os.path.exists(path):
            return
        if os.path.isdir(path):
            shutil.rmtree(path)
        else:
            os.remove(path)

    @staticmethod
    def ls(path: str, return_fullpath: bool = True) -> List[str]:
        if return_fullpath:
            return [str(entry.absolute()) for entry in pathlib.Path(path).glob('*')]
        return os.listdir(path)

    @staticmethod
    def mkdir(path: str):
        # exist_ok makes this idempotent, matching `hdfs dfs -mkdir -p` used by
        # HdfsStorageHandler; it still raises if `path` exists as a file.
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def mv(src_path: str, dst_path: str):
        shutil.move(src_path, dst_path)

    @staticmethod
    def cp(src_path: str, dst_path: str):
        shutil.copy(src_path, dst_path)

    @staticmethod
    def exists(path: str) -> bool:
        return os.path.exists(path)

    @staticmethod
    def exists_dir(directory: str) -> bool:
        # os.path.isdir is already False for a missing path.
        return os.path.isdir(directory)

local_storage_handler = LocalStorageHandler()

class HdfsStorageHandler(StorageHandler):
    '''
    `StorageHandler` which shells out to the `hdfs dfs` command line tool.
    '''

    def __init__(self, command_prefix='', env_variables=None):
        '''
        Args:
          command_prefix: optional single executable put in front of
            `hdfs dfs ...`.
          env_variables: extra environment variables for the subprocess; they
            are layered on top of the current environment.
        '''
        self.command_prefix = command_prefix
        env_variables = {} if env_variables is None else env_variables
        self.env = dict(os.environ, **env_variables)

    def call_helper(self, *args, check_output=False):
        '''
        Run `hdfs dfs <args>`.

        Returns:
          The raw stdout bytes when `check_output` is True (a failing command
          then raises `subprocess.CalledProcessError`); otherwise True/False
          depending on whether the command exited successfully.
        '''
        prefix = [self.command_prefix] if self.command_prefix else []
        full_args = prefix + ['hdfs', 'dfs'] + list(args)

        if check_output:
            return subprocess.check_output(full_args, env=self.env)
        try:
            subprocess.check_call(full_args, env=self.env)
        except subprocess.CalledProcessError:
            return False
        return True

    def put(self, src_path: str, dst_path: str):
        # -f overwrites an existing destination.
        self.call_helper('-put', '-f', src_path, dst_path)

    def get(self, src_path: str, dst_path: str):
        # `hdfs dfs -get` refuses to overwrite, so drop an existing local file
        # first. An existing directory is kept on purpose.
        exists_file = (LocalStorageHandler.exists(dst_path)
                       and (not LocalStorageHandler.exists_dir(dst_path)))
        if exists_file:
            LocalStorageHandler.rm(dst_path)
        self.call_helper('-get', src_path, dst_path)

    def rm(self, path: str):
        self.call_helper('-rm', '-r', path)

    def ls(self, path: str, return_fullpath: bool = True) -> List[str]:
        # A sample line of the return value of `hdfs dfs -ls` is like
        # '-rw-r--r-- 3 renweiluo supergroup 276979794 2021-12-29 00:39 hdfs:///user/renweiluo/test/qmcjax_ckpt_000000.npz'
        # So we basically extract the last part from it given it starts with 'hdfs:/'.
        #
        # However if `path` is not an absolute path, then the last part may not
        # start with 'hdfs:/'. That said, it should always contain the given `path`
        # and we use it as the pattern to extract the needed info.
        if not path:
            return []

        # `path` is user data, not a regex: escape it so that characters such as
        # '+', '(' or '[' in a directory name are matched literally.
        hdfs_path_pattern = re.compile(f'.* ([^ ]*{re.escape(path)}[^ ]*)$')
        raw_output = self.call_helper('-ls', path, check_output=True).decode()
        filenames = []
        for line in raw_output.split('\n'):
            # Blank lines and the 'Found N items' header do not match.
            matched = hdfs_path_pattern.match(line)
            if matched is None:
                continue
            filenames.append(os.path.basename(matched.group(1)))

        if return_fullpath:
            return [os.path.join(path, name) for name in filenames]
        return filenames

    def mkdir(self, path: str):
        self.call_helper('-mkdir', '-p', path)

    def mv(self, src_path: str, dst_path: str):
        # NOTE(review): the interface promises overwrite semantics, but
        # `hdfs dfs -mv` is called without removing an existing destination --
        # confirm whether callers rely on overwriting here.
        self.call_helper('-mv', src_path, dst_path)

    def cp(self, src_path: str, dst_path: str):
        # NOTE(review): called without '-f'; presumably an existing destination
        # is not overwritten -- confirm against the interface contract.
        self.call_helper('-cp', src_path, dst_path)

    def exists(self, path: str) -> bool:
        return self.call_helper('-test', '-e', path)

    def exists_dir(self, directory: str) -> bool:
        return self.call_helper('-test', '-d', directory)
INTEGRAL_SAMPLE_SIZE = 1
DEFAULT_QUADRATURE_ID = 'icosahedron_12'

def get_quadrature(quadrature_id):
    '''
    Build the quadrature rule identified by `quadrature_id`.

    Args:
      quadrature_id: Expected to be "quadrature_type" + "_" + "number of points
        used in quadrature", e.g. 'octahedron_26' or 'icosahedron_12'. It could
        also be None (or empty), in which case a default one will be used.

    Returns:
      A freshly constructed `Octahedron` or `Icosahedron`.

    Raises:
      KeyError: if the type is unknown or the number of points is not
        supported by that type.
    '''
    quadrature_id = quadrature_id or DEFAULT_QUADRATURE_ID

    quadrature_classes = {
        'octahedron': Octahedron,
        'icosahedron': Icosahedron,
    }
    kind, _, n_points = quadrature_id.rpartition('_')
    if kind not in quadrature_classes or not n_points.isdigit():
        raise KeyError(quadrature_id)
    # Only the requested rule is built; unsupported point counts are rejected
    # by the class constructor.
    return quadrature_classes[kind](int(n_points))

def expand_sign(l):
    '''
    Expand a point into all of its sign combinations. Zero entries are not
    duplicated, and the first entry varies slowest.

    example: [a, b] => [[a, b], [a, -b], [-a, b], [-a, -b]]
             [a, 0] => [[a, 0], [-a, 0]]
             []     => [[]]
    '''
    results = [[]]
    for value in l:
        signed = [value, -value] if value != 0 else [value]
        results = [prefix + [s] for prefix in results for s in signed]
    return results



class Quadrature():
    '''
    Base class of the spherical quadrature rules.

    Subclasses set `self.pts` (shape [np, 3], points on the unit sphere) and
    `self.coefs` (shape [np], weights summing to one).
    '''

    def __init__(self, n_p):
        # Number of quadrature points.
        self.np = n_p

    def _rotated_points(self, rotationM):
        '''Apply each rotation in `rotationM` ([ro, 3, 3]) to the points; returns [ro, np, 3].'''
        pts = jnp.einsum('ijk,kl->ijl', rotationM, self.pts.T)  # [ro,3,np]
        return pts.transpose(0, 2, 1)  # [ro,np,3]

    def integrate(self, f, rotationM):
        '''
        Weighted sum of `f` over the rotated points, averaged over rotations.

        Args:
          f: callable taking points of shape [ro, np, 3] and returning values
            whose last two axes are [ro, np].
          rotationM: rotation matrices of shape [ro, 3, 3].
        '''
        evl = f(self._rotated_points(rotationM))  # (..., ro, np)
        nums = jnp.sum(evl * self.coefs, axis=-1)
        return jnp.mean(nums, axis=-1)

    @staticmethod
    def sample_orientation(N, key):
        '''
        Sample `N` rotation matrices from the directions of a random z axis.

        Returns:
          Array of shape [N, 3, 3]; for N == 0 a single identity matrix of
          shape [1, 3, 3] (no random rotation).
        '''
        if (N == 0):
            return jnp.eye(3)[None, ...]
        key, sub_key = jax.random.split(key)

        phi = jax.random.uniform(sub_key, shape=(N,)) * jnp.pi * 2
        key, sub_key = jax.random.split(key)
        # cos(theta) uniform in (-1, 1].
        costheta = 1.0 - 2 * jax.random.uniform(sub_key, shape=(N,))
        sintheta = jnp.sqrt(1.0 - costheta ** 2)

        sinphi = jnp.sin(phi)
        cosphi = jnp.cos(phi)
        sinphi2 = sinphi ** 2
        cosphi2 = cosphi ** 2

        M11 = sinphi2 + costheta * cosphi2
        M12 = sinphi * cosphi * (costheta - 1)
        M13 = sintheta * cosphi

        M21 = M12
        M22 = cosphi2 + costheta * sinphi2
        M23 = sintheta * sinphi

        M31 = -M13
        M32 = -M23
        M33 = costheta

        M = jnp.vstack([M11, M12, M13, M21, M22, M23, M31, M32, M33]).T
        return M.reshape(-1, 3, 3)

    def discrete_pts(self, key, N=INTEGRAL_SAMPLE_SIZE):
        '''Quadrature points under `N` random rotations, shape [N, np, 3].'''
        rotationM = self.sample_orientation(N, key)
        return self._rotated_points(rotationM)

    def integral_value(self, Pl, psi_r):
        '''4 * pi * sum over the last axis of `Pl * psi_r * coefs`.'''
        result = jnp.sum(Pl * psi_r * self.coefs, axis=-1)
        return result * jnp.pi * 4.0

    def obtain_coefs(self):
        "for future DMC t-move"
        return self.coefs

    def x_pts(self):
        "for future DMC t-move"
        return self.pts

    def __call__(self, f, key, N=INTEGRAL_SAMPLE_SIZE):
        '''
        Approximate the integral of `f` over the unit sphere.

        :param f: function of cartesian coordinates; it receives the rotated
          points with shape [N, np, 3] and returns one value per point.
        :param key: PRNG key used to sample the random rotations.
        :param N: number of random rotations to average over.
        :return: 4 * pi * (rotation-averaged weighted sum of f).
        '''
        Ms = self.sample_orientation(N, key)
        return self.integrate(f, Ms) * jnp.pi * 4.0




class Octahedron(Quadrature):
    '''Octahedral quadrature with 6, 18, 26 or 50 points.'''

    def __init__(self, n_p):
        super().__init__(n_p)

        # Sizes of the four point families, in the order they are generated.
        A_num = 6
        B_num = 12
        C_num = 8
        D_num = 24

        coefs = {
            6: jnp.array([1. / 6.] * A_num),
            18: jnp.array([1. / 30.] * A_num + [1. / 15.] * B_num),
            26: jnp.array([1. / 21.] * A_num + [4. / 105.] * B_num + [27. / 840.] * C_num),
            50: jnp.array(
                [4. / 315.] * A_num + [64. / 2835.] * B_num + [27. / 1280.] * C_num + [14641. / 725760.] * D_num)
        }
        if n_p not in coefs:
            raise KeyError(f'Octahedron supports {sorted(coefs)} points, got {n_p}')

        pts = expand_sign([1, 0, 0]) + expand_sign([0, 1, 0]) + expand_sign([0, 0, 1])
        p = 1. / jnp.sqrt(2.)
        pts += expand_sign([p, p, 0]) + expand_sign([p, 0, p]) + expand_sign([0, p, p])
        q = 1. / jnp.sqrt(3.)
        pts += expand_sign([q, q, q])
        r = 1. / jnp.sqrt(11.)
        s = 3. / jnp.sqrt(11.)
        pts += expand_sign([r, r, s]) + expand_sign([r, s, r]) + expand_sign([s, r, r])
        # Smaller rules use a prefix of the full 50-point set.
        self.pts = jnp.array(pts)[:self.np, :]
        self.coefs = coefs[self.np]


class Icosahedron(Quadrature):
    '''Icosahedral quadrature with 12 or 32 points.'''

    def __init__(self, n_p):
        super().__init__(n_p)

        # Sizes of the point families, in the order they are generated.
        A_num = 2
        B_num = 10
        C_num = 20

        coefs = {
            12: jnp.array([1. / 12.] * (A_num + B_num)),
            32: jnp.array([5. / 168.] * (A_num + B_num) + [27. / 840.] * C_num)
        }
        if n_p not in coefs:
            raise KeyError(f'Icosahedron supports {sorted(coefs)} points, got {n_p}')

        # Points as (polar angle, azimuthal angle) pairs.
        polars = [[0, 0], [jnp.pi, 0]]
        polars += [[jnp.arctan(2), 2 * k * jnp.pi / 5] for k in range(5)]
        polars += [[jnp.pi - jnp.arctan(2), (2 * k + 1) / 5. * jnp.pi] for k in range(5)]
        down = jnp.sqrt(15 + 6 * jnp.sqrt(5))
        theta1 = jnp.arccos((2 + jnp.sqrt(5)) / down)
        theta2 = jnp.arccos(1. / down)

        polars += [[theta1, (2 * k + 1) * jnp.pi / 5.] for k in range(5)]
        polars += [[theta2, (2 * k + 1) * jnp.pi / 5.] for k in range(5)]
        polars += [[jnp.pi - theta1, 2 * k * jnp.pi / 5] for k in range(5)]
        polars += [[jnp.pi - theta2, 2 * k * jnp.pi / 5] for k in range(5)]

        toCartesian = lambda p: [jnp.sin(p[0]) * jnp.cos(p[1]), jnp.sin(p[0]) * jnp.sin(p[1]), jnp.cos(p[0])]

        # The 12-point rule uses a prefix of the full 32-point set.
        self.pts = jnp.array([toCartesian(polar) for polar in polars])[:self.np, :]
        self.coefs = coefs[self.np]


if __name__ == "__main__":
    # Manual sanity check: integrate products of spherical harmonics.
    import functools
    from scipy.special import sph_harm


    def to_polar(c):
        r = jnp.linalg.norm(c, axis=-1)
        theta = jnp.arccos(c[..., 2] / r)
        # arctan2 keeps the quadrant and is defined for x == 0.
        phi = jnp.arctan2(c[..., 1], c[..., 0])
        return (theta, phi)


    def sph_harm_car(c, l, m):
        p = to_polar(c)
        return sph_harm(m, l, p[1], p[0])  # scipy theta phi are non-conventional


    Y10 = functools.partial(sph_harm_car, l=1, m=0)
    Y20 = functools.partial(sph_harm_car, l=2, m=0)
    Y21 = functools.partial(sph_harm_car, l=2, m=1)
    Y91 = functools.partial(sph_harm_car, l=9, m=1)

    octahedron = Octahedron(6)
    octahedron_50 = Octahedron(50)
    icosahedron = Icosahedron(12)

    # The quadrature call needs a PRNG key for the random rotation.
    demo_key = jax.random.PRNGKey(0)

    def report(label, quadrature, bra, ket):
        value = quadrature(lambda x: jnp.conj(bra(x)) * ket(x), demo_key)
        print("{0} {1.real:.5f} + {1.imag:.5f}i ".format(label, value))

    print("=" * 100)
    report("Y10 norm Octahedron", octahedron, Y10, Y10)
    report("Y10 norm Icosahedron", icosahedron, Y10, Y10)

    print("=" * 100)
    report("Y21 norm Octahedron", octahedron, Y21, Y21)
    report("Y21 norm Octahedron", octahedron_50, Y21, Y21)
    report("Y21 norm Icosahedron", icosahedron, Y21, Y21)

    print("=" * 100)
    report("Y91 norm Octahedron", octahedron, Y91, Y91)
    report("Y91 norm Octahedron", octahedron_50, Y91, Y91)
    report("Y91 norm Icosahedron", icosahedron, Y91, Y91)

    print("=" * 100)
    report("Y21 * Y20 Octahedron", octahedron, Y21, Y20)
    report("Y21 * Y20 Octahedron", octahedron_50, Y21, Y20)
    report("Y21 * Y20 Icosahedron", icosahedron, Y21, Y20)
9 | 10 | LapNet: https://github.com/bytedance/LapNet 11 | ''' 12 | 13 | 14 | import os 15 | import sys 16 | import time 17 | 18 | from absl import app, logging 19 | import jax 20 | import jax.numpy as jnp 21 | import kfac_jax 22 | from kfac_jax import utils as kfac_utils 23 | import numpy as np 24 | 25 | from lapnet import base_config, checkpoint, curvature_tags_and_blocks, hamiltonian, mcmc, networks 26 | from lapnet.utils import writers 27 | from lapnet.train import init_electrons, make_should_save_ckpt 28 | 29 | from jaqmc.loss.factory import build_func_state, make_loss 30 | from jaqmc.loss import utils 31 | 32 | def train(cfg): 33 | 34 | atoms = jnp.stack([jnp.array(atom.coords) for atom in cfg.system.molecule]) 35 | charges = jnp.array([atom.charge for atom in cfg.system.molecule]) 36 | nspins = cfg.system.electrons 37 | local_batch_size = cfg.batch_size 38 | signed_network, params, data, sharded_key = prepare(cfg, atoms, charges, nspins, local_batch_size) 39 | 40 | local_energy = hamiltonian.local_energy( 41 | f=signed_network, 42 | atoms=atoms, 43 | charges=charges, 44 | nspins=nspins) 45 | 46 | total_loss = make_loss( 47 | signed_network, 48 | local_energy, 49 | with_spin=cfg.optim.enforce_spin.with_spin, 50 | 51 | # Energy related 52 | clip_local_energy=cfg.optim.clip_el, 53 | rm_outlier=cfg.optim.rm_outlier, 54 | el_partition=cfg.optim.el_partition_num, 55 | local_energy_outlier_width=cfg.optim.local_energy_outlier_width, 56 | # Spin related 57 | nspins=nspins, 58 | spin_cfg=cfg.optim.enforce_spin, 59 | ) 60 | func_state = build_func_state(step=kfac_utils.replicate_all_local_devices(0)) 61 | 62 | val_and_grad = jax.value_and_grad(total_loss, argnums=0, has_aux=True) 63 | def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray: 64 | fg = 1.0 * (t_ >= cfg.optim.lr.warmup) 65 | orig_lr = cfg.optim.lr.rate * jnp.power( 66 | (1.0 / (1.0 + fg * (t_ - cfg.optim.lr.warmup)/cfg.optim.lr.delay)), cfg.optim.lr.decay) 67 | linear_lr = cfg.optim.lr.rate * t_ / 
(cfg.optim.lr.warmup + (cfg.optim.lr.warmup == 0.0)) 68 | return fg * orig_lr + (1 - fg) * linear_lr 69 | 70 | optimizer = kfac_jax.Optimizer( 71 | val_and_grad, 72 | l2_reg=cfg.optim.kfac.l2_reg, 73 | norm_constraint=cfg.optim.kfac.norm_constraint, 74 | value_func_has_aux=True, 75 | value_func_has_state=True, 76 | value_func_has_rng=True, 77 | learning_rate_schedule=learning_rate_schedule, 78 | curvature_ema=cfg.optim.kfac.cov_ema_decay, 79 | inverse_update_period=cfg.optim.kfac.invert_every, 80 | min_damping=cfg.optim.kfac.min_damping, 81 | num_burnin_steps=0, 82 | register_only_generic=cfg.optim.kfac.register_only_generic, 83 | estimation_mode='fisher_exact', 84 | multi_device=True, 85 | pmap_axis_name=utils.PMAP_AXIS_NAME, 86 | auto_register_kwargs=dict( 87 | graph_patterns=curvature_tags_and_blocks.GRAPH_PATTERNS, 88 | ), 89 | ) 90 | 91 | sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key) 92 | opt_state = optimizer.init(params, subkeys, data, func_state) 93 | 94 | time_of_last_ckpt = time.time() 95 | ckpt_save_path = checkpoint.create_save_path(cfg.log.save_path) 96 | should_save_ckpt = make_should_save_ckpt(cfg) 97 | 98 | mcmc_step, mcmc_width, update_mcmc_width = prepare_mcmc(cfg, signed_network, local_batch_size) 99 | do_logging, writer_manager, write_to_csv = prepare_logging(cfg.optim.enforce_spin.with_spin, ckpt_save_path) 100 | 101 | with writer_manager as writer: 102 | for t in range(cfg.optim.iterations): 103 | 104 | sharded_key, mcmc_keys, loss_keys = kfac_jax.utils.p_split_num(sharded_key, 3) 105 | func_state.spin.step = kfac_utils.replicate_all_local_devices(t) 106 | data, pmove = mcmc_step(params, data, mcmc_keys, mcmc_width) 107 | 108 | # Optimization step 109 | params, opt_state, _, stats = optimizer.step( 110 | params=params, 111 | func_state=func_state, 112 | state=opt_state, 113 | rng=loss_keys, 114 | data_iterator=iter([data]), 115 | momentum=kfac_jax.utils.replicate_all_local_devices(jnp.zeros([])), 116 | 
damping=kfac_jax.utils.replicate_all_local_devices(jnp.asarray(cfg.optim.kfac.damping))) 117 | mcmc_width = update_mcmc_width(t, mcmc_width, pmove[0]) 118 | 119 | do_logging(t, pmove, stats) 120 | write_to_csv(writer, t, pmove, stats) 121 | 122 | # Checkpointing 123 | if should_save_ckpt(t, time_of_last_ckpt): 124 | if cfg.optim.optimizer != 'none': 125 | checkpoint.save(ckpt_save_path, t, data, params, opt_state, mcmc_width, sharded_key) 126 | time_of_last_ckpt = time.time() 127 | 128 | def prepare(cfg, atoms, charges, nspins, local_batch_size): 129 | network_init, signed_network, *_ = networks.network_provider(cfg)(atoms, nspins, charges) 130 | key = jax.random.PRNGKey(int(1e6 * time.time())) 131 | 132 | params_initialization_key, sharded_key = jax.random.split(key) 133 | params = network_init(params_initialization_key) 134 | params = kfac_utils.replicate_all_local_devices(params) 135 | 136 | 137 | subkey, sharded_key = jax.random.split(sharded_key) 138 | data = init_electrons(subkey, cfg.system.molecule, cfg.system.electrons, 139 | local_batch_size, 140 | init_width=cfg.mcmc.init_width, 141 | given_atomic_spin_configs=cfg.system.get('atom_spin_configs')) 142 | data = data.reshape([jax.local_device_count(), local_batch_size // jax.local_device_count(), *data.shape[1:]]) 143 | sharded_key = kfac_jax.utils.make_different_rng_key_on_all_devices(sharded_key) 144 | return signed_network, params, data, sharded_key 145 | 146 | def prepare_mcmc(cfg, signed_network, local_batch_size): 147 | network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1] 148 | batch_network = jax.vmap(network, in_axes=(None, 0)) 149 | mcmc_width = kfac_jax.utils.replicate_all_local_devices( 150 | jnp.asarray(cfg.mcmc.move_width)) 151 | mcmc_step = mcmc.make_mcmc_step( 152 | batch_network, 153 | local_batch_size // jax.local_device_count(), 154 | steps=cfg.mcmc.steps, 155 | blocks=cfg.mcmc.blocks, 156 | ) 157 | mcmc_step = utils.pmap(mcmc_step, donate_argnums=1) 158 | 159 | pmoves = 
np.zeros(cfg.mcmc.adapt_frequency) 160 | def update_mcmc_width(t, mcmc_width, pmove): 161 | if t > 0 and t % cfg.mcmc.adapt_frequency == 0: 162 | if np.mean(pmoves) > 0.55: 163 | mcmc_width *= 1.1 164 | if np.mean(pmoves) < 0.5: 165 | mcmc_width /= 1.1 166 | pmoves[:] = 0 167 | pmoves[t%cfg.mcmc.adapt_frequency] = pmove 168 | return mcmc_width 169 | 170 | return mcmc_step, mcmc_width, update_mcmc_width 171 | 172 | def prepare_logging(enforce_spin, save_dir): 173 | schema = ['energy', 'var', 'pmove'] 174 | if enforce_spin: 175 | schema += ['spin', 'spin_var'] 176 | message = '{t} ' + ' '.join(f'{key}: {{{key}:.4f}}' for key in schema) 177 | writer_manager = writers.Writer( 178 | name='result', 179 | schema=schema, 180 | directory=save_dir, 181 | iteration_key=None, 182 | log=False) 183 | 184 | def _prepare(t, pmove, stats): 185 | aux = stats['aux'] 186 | vmc_aux = aux.vmc 187 | logging_dict = { 188 | 't': t, 189 | 'energy': stats['loss'][0], 190 | 'var': vmc_aux.variance[0], 191 | 'pmove': pmove[0]} 192 | 193 | if enforce_spin: 194 | spin_aux = aux.spin 195 | logging_dict['spin'] = spin_aux.estimator[0] 196 | logging_dict['spin_var'] = spin_aux.variance[0] 197 | return logging_dict 198 | 199 | def do_logging(t, pmove, stats): 200 | logging_dict = _prepare(t, pmove, stats) 201 | logging.info(message.format(**logging_dict)) 202 | 203 | def write_to_csv(writer, t, pmove, stats): 204 | logging_dict = _prepare(t, pmove, stats) 205 | writer.write(**logging_dict) 206 | 207 | return do_logging, writer_manager, write_to_csv 208 | 209 | def main(_): 210 | cfg = FLAGS.config 211 | cfg = base_config.resolve(cfg) 212 | loss_cfg = FLAGS.loss_config 213 | cfg['optim'] = {**cfg['optim'], **loss_cfg} 214 | 215 | logging.get_absl_handler().python_handler.stream = sys.stdout 216 | logging.set_verbosity(logging.INFO) 217 | train(cfg) 218 | 219 | if __name__ == '__main__': 220 | from absl import flags 221 | from ml_collections.config_flags import config_flags 222 | import pathlib 223 | 
FLAGS = flags.FLAGS 224 | 225 | config_flags.DEFINE_config_file('config', None, 'Path to config file.') 226 | loss_config_file = str(pathlib.Path(os.path.abspath(__file__)).parents[1].absolute() / 'loss_config.py') 227 | config_flags.DEFINE_config_file('loss_config', loss_config_file, 'Path to loss config file.') 228 | app.run(main) 229 | -------------------------------------------------------------------------------- /tests/dmc/position_update_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from absl.testing import absltest 8 | import jax.test_util as jtu 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | import scipy as sp 13 | 14 | from jaqmc.dmc.position_update import * 15 | 16 | class PositionUpdateTest(jtu.JaxTestCase): 17 | def test_velocity_clipping_mild_v(self): 18 | velocity = jnp.array([0.3, 0.5, 0.2]) 19 | position = jnp.array([1.0, 0, 0]) 20 | nearest_nucleus = jnp.array([2.0, 0, 0]) 21 | nearest_charge = 1.0 22 | time_step = 1e-2 23 | 24 | clipped_velocity = clip_velocity_helper( 25 | velocity, position, nearest_nucleus, nearest_charge, time_step) 26 | self.assertArraysAllClose(velocity, clipped_velocity, rtol=1e-3) 27 | 28 | def test_velocity_clipping_outrageous_v(self): 29 | velocity = jnp.array([10, 0, 0]) 30 | position = jnp.array([1.0, 0, 0]) 31 | nearest_nucleus = jnp.array([0.0, 0, 0]) 32 | nearest_charge = 10.0 33 | time_step = 1e-1 34 | 35 | a = 1 + 100 / 1040 36 | v_norm = np.linalg.norm(velocity) 37 | expected_clip_factor = (-1 + np.sqrt(1 + 2 * a * v_norm**2 * time_step)) / (a * v_norm ** 2 * time_step) 38 | 39 | clipped_velocity = clip_velocity_helper( 40 | velocity, position, nearest_nucleus, nearest_charge, time_step) 41 | 
self.assertArraysAllClose(expected_clip_factor * velocity, clipped_velocity) 42 | 43 | 44 | def test_do_drift_mild(self): 45 | position = jnp.array([1.0, 0, 0]) 46 | nearest_nucleus = jnp.array([2.0, 0, 0]) 47 | clipped_velocity = jnp.array([0.1, 0.2, 0.3]) 48 | time_step = 1e-2 49 | drifted_position = do_drift(position, nearest_nucleus, clipped_velocity, time_step) 50 | naive_drift_position = position + clipped_velocity * time_step 51 | self.assertArraysAllClose(drifted_position, naive_drift_position, rtol=1e-3) 52 | 53 | def test_do_drift_overshoot(self): 54 | position = jnp.array([1.9, 0, 0]) 55 | nearest_nucleus = jnp.array([2.0, 0, 0]) 56 | clipped_velocity = jnp.array([2.0, 2, 3]) 57 | time_step = 1e-1 58 | drifted_position = do_drift(position, nearest_nucleus, clipped_velocity, time_step) 59 | expected_position = nearest_nucleus 60 | self.assertArraysAllClose(drifted_position, nearest_nucleus) 61 | 62 | def test_do_drift_close_to_overshoot(self): 63 | position = jnp.array([1, 0, 0]) 64 | nearest_nucleus = jnp.array([2.0, 0, 0]) 65 | clipped_velocity = jnp.array([9, 2, 3]) 66 | time_step = 1e-1 67 | drifted_position = do_drift(position, nearest_nucleus, clipped_velocity, time_step) 68 | rho_factor = (2 * 0.1 / 1.1) 69 | expected_drift_position = position + jnp.array((clipped_velocity[0], ) + tuple(clipped_velocity[1:] * rho_factor)) * time_step 70 | naive_drift_position = position + clipped_velocity * time_step 71 | self.assertArraysAllClose(drifted_position, expected_drift_position) 72 | self.assertNotEqual( 73 | jnp.linalg.norm(naive_drift_position[1:]), 74 | jnp.linalg.norm(drifted_position[1:])) 75 | 76 | 77 | def test_overshoot_prob_small(self): 78 | position = jnp.array([1, 0, 0]) 79 | nearest_nucleus = jnp.array([2.0, 0, 0]) 80 | clipped_velocity = jnp.array([1., 2, 3]) 81 | time_step = 1e-2 82 | overshoot_prob = calc_nucleus_overshoot_prob_helper( 83 | position, 84 | nearest_nucleus, 85 | clipped_velocity, 86 | time_step) 87 | 
self.assertLess(overshoot_prob, 0.1) 88 | 89 | 90 | def test_overshoot_prob_large(self): 91 | position = jnp.array([1.9, 0, 0]) 92 | nearest_nucleus = jnp.array([2.0, 0, 0]) 93 | clipped_velocity = jnp.array([100., 2, 3]) 94 | time_step = 1e-2 95 | overshoot_prob = calc_nucleus_overshoot_prob_helper( 96 | position, 97 | nearest_nucleus, 98 | clipped_velocity, 99 | time_step) 100 | self.assertGreater(overshoot_prob, 0.9) 101 | 102 | def test_overshoot_prob_exact(self): 103 | position = jnp.array([0.9, 0, 0]) 104 | nearest_nucleus = jnp.array([0.0, 0, 0]) 105 | clipped_velocity = jnp.array([1., 2, 3]) 106 | time_step = 0.1 107 | overshoot_prob = calc_nucleus_overshoot_prob_helper( 108 | position, 109 | nearest_nucleus, 110 | clipped_velocity, 111 | time_step) 112 | expected_prob = 0.5 * jax.lax.erfc(1 / jnp.sqrt(0.2)) 113 | self.assertEqual(overshoot_prob, expected_prob) 114 | 115 | def test_do_diffusion_gaussian(self): 116 | drifted_position = jnp.array([10.0, 0, 0]) 117 | nearest_nucleus = jnp.array([-10.0, 0, 0]) 118 | overshoot_prob = 0 119 | 120 | seed = 42 121 | key = jax.random.PRNGKey(seed) 122 | 123 | num_data = int(1e4) 124 | all_keys = jax.random.split(key, num_data) 125 | time_step = 1e-2 126 | laplace_zeta = 1 / jnp.sqrt(time_step) 127 | vmapped_do_diffusion = jax.vmap( 128 | lambda key: do_diffusion(drifted_position, nearest_nucleus, 129 | overshoot_prob, 130 | key, 131 | time_step, 132 | laplace_zeta)[0]) 133 | diffused_position = vmapped_do_diffusion(all_keys) 134 | gaussian_vars = ( 135 | (diffused_position - drifted_position) / jnp.sqrt(time_step) 136 | ).reshape((-1,)) 137 | 138 | def expected_sample_gaussian(size): 139 | return np.random.normal(size=size) 140 | test_result = sp.stats.kstest(expected_sample_gaussian, gaussian_vars) 141 | self.assertGreater(test_result.pvalue, 0.1) 142 | 143 | def test_do_diffusion_gamma(self): 144 | drifted_position = jnp.array([10.0, 0, 0]) 145 | nearest_nucleus = jnp.array([-10.0, 0, 0]) 146 | overshoot_prob = 1 
147 | 148 | seed = 42 149 | key = jax.random.PRNGKey(seed) 150 | 151 | num_data = int(1e6) 152 | all_keys = jax.random.split(key, num_data) 153 | time_step = 1e-2 154 | laplace_zeta = 1 / jnp.sqrt(time_step) 155 | vmapped_do_diffusion = jax.vmap( 156 | lambda key: do_diffusion(drifted_position, nearest_nucleus, 157 | overshoot_prob, 158 | key, 159 | time_step, 160 | laplace_zeta)[0]) 161 | diffused_position = vmapped_do_diffusion(all_keys) 162 | gamma_vars = (diffused_position - nearest_nucleus) 163 | 164 | self.do_test_gamma(gamma_vars, laplace_zeta) 165 | 166 | def do_test_gamma(self, data, laplace_zeta): 167 | def expected_sample_gamma(size): 168 | return np.random.gamma(shape=3, size=size) 169 | def expected_sample_uniform(size): 170 | return np.random.uniform(size=size) 171 | data_norm = jnp.linalg.norm(data, axis=-1) 172 | x, y, z = data.T 173 | theta = jnp.arccos(z / data_norm) 174 | sin_phi_sign = y / jnp.sin(theta) 175 | 176 | phi_between_half_pi = jnp.arctan(y / x) 177 | phi =( 178 | phi_between_half_pi 179 | + (phi_between_half_pi < 0) * (sin_phi_sign > 0) * jnp.pi 180 | - (phi_between_half_pi > 0) * (sin_phi_sign < 0) * jnp.pi 181 | ) 182 | 183 | gamma_vars = data_norm * 2 * laplace_zeta 184 | uniform_vars = (phi + jnp.pi) / 2 / jnp.pi 185 | 186 | test_result_norm = sp.stats.kstest(expected_sample_gamma, gamma_vars) 187 | test_result_phi = sp.stats.kstest(expected_sample_uniform, uniform_vars) 188 | self.assertGreater(test_result_norm.pvalue, 0.1) 189 | self.assertGreater(test_result_phi.pvalue, 0.1) 190 | 191 | 192 | correct_l = 0 193 | total_l = 0 194 | for l in np.arange(0, np.pi, 0.01): 195 | n = np.sum(theta < l) 196 | if np.isclose(n / len(theta), (1 - np.cos(l)) / 2, rtol=0.05): 197 | correct_l += 1 198 | total_l += 1 199 | self.assertGreater(correct_l / total_l, 0.95) 200 | 201 | 202 | def test_sample_gamma(self): 203 | 204 | laplace_zeta = 10 205 | num_data = 1000000 206 | seed = 42 207 | key = jax.random.PRNGKey(seed) 208 | 209 | all_keys = 
jax.random.split(key, num_data) 210 | vmapped_sample_func = jax.vmap(sample_gamma, in_axes=(0, None)) 211 | data = vmapped_sample_func(all_keys, laplace_zeta) 212 | self.do_test_gamma(data, laplace_zeta) 213 | 214 | def test_calc_G_log(self): 215 | updated_position = jnp.zeros((3, )) 216 | nearest_nucleus = jnp.array([1.0, 0.0, 0.0]) 217 | drifted_position = jnp.array([2.0, 0.0, 0.0]) 218 | nucleus_overshoot_prob = 0.1 219 | 220 | time_step = 0.01 221 | laplace_zeta = 1 / jnp.sqrt(time_step) 222 | 223 | G_log = calc_G_log(updated_position, nearest_nucleus, drifted_position, 224 | nucleus_overshoot_prob, laplace_zeta, time_step) 225 | expected_G_log = jnp.log( 226 | (1 - nucleus_overshoot_prob) * (2 * jnp.pi * time_step) ** (-1.5) * jnp.exp(-(jnp.linalg.norm(updated_position - drifted_position) ** 2) / 2 / time_step) 227 | + nucleus_overshoot_prob * laplace_zeta ** 3 / jnp.pi * jnp.exp(-2 * laplace_zeta * jnp.linalg.norm(updated_position - nearest_nucleus)) 228 | ) 229 | self.assertAlmostEqual(G_log, expected_G_log) 230 | 231 | if __name__ == '__main__': 232 | absltest.main() 233 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JaQMC: JAX accelerated Quantum Monte Carlo 2 | 3 | A collection of GPU-friendly and neural-network-friendly scalable 4 | Quantum Monte Carlo (QMC) implementations in JAX. 5 | 6 | Currently supported functionalities: 7 | - Diffusion Monte Carlo (DMC) 8 | - Spin Symmetry Enforcement 9 | - Pseudopotential (PP) 10 | 11 | ## Installation 12 | JaQMC can be installed via the supplied setup.py file. 13 | ```shell 14 | pip3 install -e . 15 | ``` 16 | 17 | ## Introduction 18 | 19 | JaQMC has a modular design for easier integration with various Neural 20 | Network based Quantum Monte Carlo (NNQMC) projects. 21 | 22 | The functionalities are developed in the `jaqmc` module, while we provide a number 23 | of scripts integrating with different NNQMC projects in the `examples` directory.
24 | 25 | ### Diffusion Monte Carlo (DMC) 26 | The fixed-node DMC implementation introduced in 27 | [Towards the ground state of molecules via diffusion Monte Carlo on neural networks](https://www.nature.com/articles/s41467-023-37609-3) 28 | 29 | See [DMC](#dmc) section for more details. 30 | 31 | ### Spin Symmetry Enforcement with $\hat{S}_+$ penalties 32 | The spin symmetry enforced solution introduced in [Symmetry enforced solution of the many-body Schrödinger equation with deep neural network](https://arxiv.org/abs/2406.01222) 33 | 34 | See [Spin Symmetry](#spin-symmetry) section for more details. 35 | 36 | ### Pseudopotential (PP) 37 | The (semi-local or local) pseudopotentials introduced in [Fermionic neural network with effective core potential](https://journals.aps.org/prresearch/abstract/10.1103/PhysRevResearch.4.013021) and [Local Pseudopotential Unlocks the True Potential of Neural Network-based Quantum Monte Carlo](https://arxiv.org/abs/2505.19909) 38 | 39 | See [PP](#pp) section for more details. 40 | 41 | ## DMC 42 | The fixed-node diffusion Monte Carlo (FNDMC) implementation here has a simple interface. 43 | In the simplest case, it requires only a (real-valued) trial wavefunction, taking in a 3N-dimensional electron configuration and producing two outputs: 44 | the sign of the wavefunction value and the logarithm of its absolute value. 45 | In more sophisticated cases, users can also provide the implementation of local energy and quantum force, for instance, when ECP is considered. 46 | 47 | Several examples are provided integrating with neural-network-based trial wavefunctions. The DMC-related config can be found in `examples/dmc/dmc_config.py`. 48 | See [here](https://github.com/google/ml_collections/tree/master#config-flags) for instructions on how to work with those configs or flags. 49 | 50 | ### Integration with FermiNet 51 | Please first install FermiNet following instructions in https://github.com/deepmind/ferminet.
52 | Then train FermiNet for your favorite atom / molecule and generate a checkpoint to be reused in DMC as the trial wavefunction. 53 | ```shell 54 | python3 examples/dmc/ferminet/run.py --config $YOUR_FERMINET_CONFIG_FILE --config.log.save_path $YOUR_FERMINET_CKPT_DIRECTORY --dmc_config.iterations 100 --dmc_config.fix_size --dmc_config.block_size 10 --dmc_config.log.save_path $YOUR_DMC_CKPT_DIRECTORY 55 | ``` 56 | 57 | ### Integration with LapNet 58 | Please first install LapNet following instructions in https://github.com/bytedance/lapnet. 59 | Then train LapNet for your favorite atom / molecule and generate a checkpoint to be reused in DMC as the trial wavefunction. 60 | ```shell 61 | python3 examples/dmc/lapnet/run.py --config $YOUR_LAPNET_CONFIG_FILE --config.log.save_path $YOUR_LAPNET_CKPT_DIRECTORY --dmc_config.iterations 100 --dmc_config.fix_size --dmc_config.block_size 10 --dmc_config.log.save_path $YOUR_DMC_CKPT_DIRECTORY 62 | ``` 63 | 64 | ### Integration with DeepErwin 65 | Please first install DeepErwin following instructions in https://mdsunivie.github.io/deeperwin/. 66 | Then train DeepErwin for your favorite atom / molecule and generate a checkpoint to be reused in DMC as the trial wavefunction. 67 | ```shell 68 | python3 examples/dmc/deeperwin/run.py --deeperwin_ckpt $YOUR_DEEPERVIN_CKPT_FILE --dmc_config.iterations 100 --dmc_config.fix_size --dmc_config.block_size 10 --dmc_config.log.save_path $YOUR_DMC_CKPT_DIRECTORY 69 | ``` 70 | 71 | ### Do Your Own Integration 72 | The entry point for DMC integration is the `run` function in `jaqmc/dmc/dmc.py`, which is quite heavily commented. 73 | Basically you only need to construct your favorite trial wavefunction in JAX, then simply pass it to this `run` function and it should work smoothly. 74 | Please don't hesitate to file an issue if you need help to integrate with your favorite (JAX-implemented) trial wavefunction. 
75 | 76 | Note that our DMC implementation is "multi-node calculation ready" in the sense that if you initialize the distributed JAX runtime 77 | on a multi-node cluster, then our DMC implementation can do multi-node calculation correctly, i.e. aggregation across 78 | different computing nodes. See [here](https://jax.readthedocs.io/en/latest/multi_process.html?highlight=multi-node) for instructions on initialization of the distributed JAX runtime. 79 | 80 | 81 | ### Output 82 | The data at each checkpoint step will be stored in the specified path (namely `$YOUR_DMC_CKPT_DIRECTORY` in the examples above) with the naming pattern 83 | ``` 84 | dmc_data_{step}.tgz 85 | ``` 86 | which contains a csv file with the metrics produced from each DMC step up to the checkpoint step. 87 | The columns of the metric file are 88 | 1. step: The step index in DMC. 89 | 2. estimator: The mixed estimator calculated at each step, calculated and smoothed within a certain time window. 90 | 3. offset: The energy offset used to update DMC walker weights. 91 | 4. average: The local energy weighted average calculated at each DMC step. 92 | 5. num_walkers: The total number of walkers across all the computing nodes. 93 | 6. old_walkers: The number of walkers that have been rejected too many times in the process. 94 | 7. total_weight: The total weight of all walkers across all the computing nodes. 95 | 8. acceptance_ratio: The acceptance ratio of the acceptance-rejection action. 96 | 9. effective_time_step: The effective time step. 97 | 10. num_cutoff_updated, num_cutoff_orig: Debug related, indicating the number of outliers in terms of local energy. 98 | 99 | ## Spin Symmetry 100 | We enforce the spin symmetry in two steps: 101 | 1. Set the spin magnetic quantum number to the target spin value $s_z = s$, by setting the number of spin-up and spin-down electrons in the input of the neural network wavefunction. 102 | 2. Integrate the $\hat{S}_+$ penalty into the loss function to enforce the spin symmetry.
103 | 104 | We implement `loss` module in JaQMC for that purpose. 105 | For each component of loss, such as VMC energy and spin related penalties, we build a 106 | factory method to produce losses with the same interface: 107 | ``` 108 | class Loss(Protocol): 109 | def __call__(self, 110 | params: ParamTree, 111 | func_state: BaseFuncState, 112 | key: chex.PRNGKey, 113 | data: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[BaseFuncState, BaseAuxData]]: 114 | """ 115 | Args: 116 | params: network parameters. 117 | func_state: function state passed to the loss function to control its behavior. 118 | key: JAX PRNG state. 119 | data: QMC walkers with electronic configuration to evaluate. 120 | Returns: 121 | (loss value, (updated func_state, auxillary data) 122 | """ 123 | ``` 124 | This loss interface works well with [KFAC optimizer](https://github.com/google-deepmind/kfac-jax). 125 | It is also flexible enough to work with optimizers in [optax](https://github.com/google-deepmind/optax), 126 | [SPRING](https://github.com/jeffminlin/vmcnet/blob/master/vmcnet/updates/spring.py) and etc. 127 | 128 | We also provide user-facing entry points in `jaqmc/loss/factory.py`. 129 | One for building `func_state`, one of the inputs to the loss function, and 130 | another one for building the loss function. 131 | ``` 132 | def build_func_state(step=None) -> FuncState: 133 | ''' 134 | Helper function to create parent FuncState from actual data. 135 | ''' 136 | ...... 137 | 138 | ``` 139 | 140 | ### Integration with LapNet 141 | Please first install LapNet following instructions in https://github.com/bytedance/lapnet. 142 | To simulate singlet state for Oxygen atom with LapNet and spin symmetry enforced, simply turn on `loss_config.enforce_spin.with_spin` flag 143 | as follows. 
144 | ```shell 145 | python3 $JAQMC_PATH/examples/loss/lapnet/run.py --config $JAQMC_PATH/examples/loss/lapnet/atom_spin_state.py:O,0 \ 146 | --loss_config.enforce_spin.with_spin --config.$OTHER_LAPNET_CONFIGS --loss_config.enforce_spin.$OTHER_SPIN_CONFIGS 147 | ``` 148 | Note that this example script is by no means "production-ready". It is just a 149 | showcase of how to integrate the `loss` module with existing NNQMC projects. 150 | For instance, it does not include the pretrain phase since it has nothing to do 151 | with the `loss` module. 152 | 153 | ## PP 154 | We provide two types of PP interfaces, namely effective core potential (ECP) and Pseudo-Hamiltonian (PH). 155 | These interfaces can be easily combined with various network structures and used for the calculation of VMC and DMC. 156 | 157 | For ECP, we have provided all available ccECP atoms. The available ccECP atoms can be queried at https://pseudopotentiallibrary.org/. 158 | The ECP configs can be found at `examples/pp/lapnet/configs/ecp/`. 159 | 160 | - `X.py` is used for single-atom calculations. 161 | 162 | 163 | For PH, we have provided all available PH atoms. The currently available PH atoms include S, Cr, Mn, Fe, Co, Ni, Cu, and Zn. 164 | The PH configs can be found at `examples/pp/lapnet/configs/ph/`. 165 | 166 | - `X.py` is used for single-atom calculations. 167 | 168 | - `XS.py` is used for sulfide calculations. 169 | 170 | 171 | We implement the `PP` module in JaQMC and introduce different PPs with the function: 172 | ```python 173 | def pp_energy(f: WavefunctionLike, 174 | atoms: jnp.ndarray, 175 | nspins: Sequence[int], 176 | charges: jnp.ndarray, 177 | pyscf_mol: pyscf.gto.mole, 178 | pp_cfg, 179 | energy_local : EnergyPattern = None, 180 | use_scan: bool = False, 181 | el_partition_num = 0, 182 | forward_laplacian=True, 183 | ) -> EnergyPattern: 184 | """Returns the total energy function. 185 | Args: 186 | f: network parameters. 187 | atoms: Shape (natoms, ndim).
Positions of the atoms. 188 | charges:Shape (natoms). Nuclear charges of the atoms. 189 | pyscf_mol: pyscf molecule object. 190 | pp_cfg: pp config. 191 | energy_local: local energy function. 192 | use_scan: whether to use scan. 193 | el_partition_num: number of electrons. 194 | forward_laplacian: whether to use forward mode for the laplacian. 195 | Returns: 196 | energy: total energy. 197 | """ 198 | ``` 199 | 200 | ### Integration with LapNet 201 | Please first install LapNet following instructions in https://github.com/bytedance/lapnet. 202 | Then train LapNet for your favorite atom / molecule with pseudopotential. Taking the example of a S atom. 203 | - Training LapNet with ph in the VMC framework 204 | ```shell 205 | python3 examples/pp/lapnet/run.py --config examples/pp/lapnet/configs/ph/X.py:S,2 --config.batch_size 256 --config.pretrain.iterations 10 --config.optim.iterations 10 --config.log.save_path $YOUR_VMC_CKPT_DIRECTORY 206 | ``` 207 | 208 | - Training LapNet with ecp in the VMC framework 209 | ```shell 210 | python3 examples/pp/lapnet/run.py --config examples/pp/lapnet/configs/ecp/X.py:S,2 --config.batch_size 256 --config.pretrain.iterations 10 --config.optim.iterations 10 --config.log.save_path $YOUR_VMC_CKPT_DIRECTORY 211 | ``` 212 | 213 | 214 | ## Giving Credit 215 | If you use certain functionalities of JaQMC in your work, please consider citing the corresponding papers. 
216 | ### DMC paper 217 | ``` 218 | @article{ren2023towards, 219 | title={Towards the ground state of molecules via diffusion Monte Carlo on neural networks}, 220 | author={Ren, Weiluo and Fu, Weizhong and Wu, Xiaojie and Chen, Ji}, 221 | journal={Nature Communications}, 222 | volume={14}, 223 | number={1}, 224 | pages={1860}, 225 | year={2023}, 226 | publisher={Nature Publishing Group UK London} 227 | } 228 | ``` 229 | 230 | ### Spin Symmetry paper 231 | ``` 232 | @article{li2024spin, 233 | title={Spin-symmetry-enforced solution of the many-body Schr{\"o}dinger equation with a deep neural network}, 234 | author={Li, Zhe and Lu, Zixiang and Li, Ruichen and Wen, Xuelan and Li, Xiang and Wang, Liwei and Chen, Ji and Ren, Weiluo}, 235 | journal={Nature Computational Science}, 236 | volume={4}, 237 | number={12}, 238 | pages={910--919}, 239 | year={2024}, 240 | publisher={Nature Publishing Group} 241 | } 242 | ``` 243 | 244 | ### ECP paper 245 | ``` 246 | @article{li2022fermionic, 247 | title={Fermionic neural network with effective core potential}, 248 | author={Li, Xiang and Fan, Cunwei and Ren, Weiluo and Chen, Ji}, 249 | journal={Physical Review Research}, 250 | volume={4}, 251 | number={1}, 252 | pages={013021}, 253 | year={2022}, 254 | publisher={APS} 255 | } 256 | ``` 257 | 258 | ### PH paper 259 | ``` 260 | @article{fu2025local, 261 | title={Local Pseudopotential Unlocks the True Potential of Neural Network-based Quantum Monte Carlo}, 262 | author={Fu, Weizhong and Fujimaru, Ryunosuke and Li, Ruichen and Liu, Yuzhi and Wen, Xuelan and Li, Xiang and Hongo, Kenta and Wang, Liwei and Ichibha, Tom and Maezono, Ryo and others}, 263 | journal={arXiv preprint arXiv:2505.19909}, 264 | year={2025} 265 | } 266 | ``` -------------------------------------------------------------------------------- /examples/pp/lapnet/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

'''
A show case for LapNet with PH and ECP.

LapNet: https://github.com/bytedance/LapNet
'''


import sys
import time
from typing import Optional, Sequence, Tuple

from absl import app, flags, logging
import jax
import jax.numpy as jnp
import kfac_jax
from kfac_jax import utils as kfac_utils
import ml_collections
import numpy as np
import pyscf

from lapnet import base_config, checkpoint, curvature_tags_and_blocks, hamiltonian, mcmc, networks
from lapnet.utils import writers, system
from lapnet.train import make_should_save_ckpt

from jaqmc.loss.factory import build_func_state, make_loss
from jaqmc.loss import utils

# Bound at import time so that `main` can be called from an importing module
# as well; the `config` flag itself is only defined when run as a script.
FLAGS = flags.FLAGS


def train(cfg):
    """Runs the KFAC optimization loop for a LapNet wavefunction.

    Builds the network, the local energy (through JaQMC's `pp_energy` when
    `use_ecp_or_ph(cfg)` is true, LapNet's all-electron local energy
    otherwise), the JaQMC loss and a `kfac_jax.Optimizer`, then alternates
    MCMC and optimizer steps for `cfg.optim.iterations` iterations, logging
    and checkpointing along the way.

    Args:
      cfg: resolved config. The fields read here live under `system`,
        `batch_size`, `optim`, `mcmc`, `log` and, when a PP is used, `ecp`.
    """

    # Check if mol is a pyscf molecule and convert to internal representation
    if cfg.system.pyscf_mol:
        cfg.update(
            pyscf_mol_to_internal_representation(cfg.system.pyscf_mol))

    atoms = jnp.stack([jnp.array(atom.coords) for atom in cfg.system.molecule])
    charges = jnp.array([atom.charge for atom in cfg.system.molecule])
    nspins = cfg.system.electrons
    local_batch_size = cfg.batch_size

    signed_network, params, data, sharded_key, network_options = prepare(
        cfg, atoms, charges, nspins, local_batch_size)

    if use_ecp_or_ph(cfg):
        # Imported lazily: only needed when a pseudopotential is requested.
        from jaqmc.pp.hamiltonian import pp_energy
        logging.info('Applying ECP or PH from JaQMC')
        local_energy = pp_energy(
            f=signed_network,
            atoms=atoms,
            nspins=nspins,
            charges=charges,
            pyscf_mol=cfg.system.pyscf_mol,
            pp_cfg=cfg.ecp,
            energy_local=None,
            use_scan=False,
            el_partition_num=cfg.optim.el_partition_num,
            forward_laplacian=cfg.optim.forward_laplacian,
        )
    else:
        local_energy = hamiltonian.local_energy(
            f=signed_network,
            atoms=atoms,
            charges=charges,
            nspins=nspins)

    total_loss = make_loss(
        signed_network,
        local_energy,

        # Energy related
        clip_local_energy=cfg.optim.clip_el,
        rm_outlier=cfg.optim.rm_outlier,
        el_partition=cfg.optim.el_partition_num,
        local_energy_outlier_width=cfg.optim.local_energy_outlier_width,
        nspins=nspins,
    )
    func_state = build_func_state(step=kfac_utils.replicate_all_local_devices(0))

    val_and_grad = jax.value_and_grad(total_loss, argnums=0, has_aux=True)

    def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
        # `fg` is 1 once warmup is over and 0 before: linear ramp during
        # warmup, inverse-power decay afterwards.
        fg = 1.0 * (t_ >= cfg.optim.lr.warmup)
        orig_lr = cfg.optim.lr.rate * jnp.power(
            (1.0 / (1.0 + fg * (t_ - cfg.optim.lr.warmup)/cfg.optim.lr.delay)), cfg.optim.lr.decay)
        # The `(warmup == 0.0)` term avoids a division by zero without warmup.
        linear_lr = cfg.optim.lr.rate * t_ / (cfg.optim.lr.warmup + (cfg.optim.lr.warmup == 0.0))
        return fg * orig_lr + (1 - fg) * linear_lr

    optimizer = kfac_jax.Optimizer(
        val_and_grad,
        l2_reg=cfg.optim.kfac.l2_reg,
        norm_constraint=cfg.optim.kfac.norm_constraint,
        value_func_has_aux=True,
        value_func_has_state=True,
        value_func_has_rng=True,
        learning_rate_schedule=learning_rate_schedule,
        curvature_ema=cfg.optim.kfac.cov_ema_decay,
        inverse_update_period=cfg.optim.kfac.invert_every,
        min_damping=cfg.optim.kfac.min_damping,
        num_burnin_steps=0,
        register_only_generic=cfg.optim.kfac.register_only_generic,
        estimation_mode='fisher_exact',
        multi_device=True,
        pmap_axis_name=utils.PMAP_AXIS_NAME,
        auto_register_kwargs=dict(
            graph_patterns=curvature_tags_and_blocks.GRAPH_PATTERNS,
        ),
    )

    sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
    opt_state = optimizer.init(params, subkeys, data, func_state)

    time_of_last_ckpt = time.time()
    ckpt_save_path = checkpoint.create_save_path(cfg.log.save_path)
    should_save_ckpt = make_should_save_ckpt(cfg)

    mcmc_step, mcmc_width, update_mcmc_width = prepare_mcmc(cfg, signed_network, local_batch_size)
    do_logging, writer_manager, write_to_csv = prepare_logging(ckpt_save_path)

    with writer_manager as writer:
        for t in range(cfg.optim.iterations):

            sharded_key, mcmc_keys, loss_keys = kfac_jax.utils.p_split_num(sharded_key, 3)
            func_state.spin.step = kfac_utils.replicate_all_local_devices(t)
            data, pmove = mcmc_step(params, data, mcmc_keys, mcmc_width)

            # Optimization step
            params, opt_state, _, stats = optimizer.step(
                params=params,
                func_state=func_state,
                state=opt_state,
                rng=loss_keys,
                data_iterator=iter([data]),
                momentum=kfac_jax.utils.replicate_all_local_devices(jnp.zeros([])),
                damping=kfac_jax.utils.replicate_all_local_devices(jnp.asarray(cfg.optim.kfac.damping)))
            mcmc_width = update_mcmc_width(t, mcmc_width, pmove[0])

            do_logging(t, pmove, stats)
            write_to_csv(writer, t, pmove, stats)

            # Checkpointing
            if should_save_ckpt(t, time_of_last_ckpt):
                if cfg.optim.optimizer != 'none':
                    checkpoint.save(ckpt_save_path, t, data, params, opt_state, mcmc_width, sharded_key)
                time_of_last_ckpt = time.time()


def pyscf_mol_to_internal_representation(
        mol: pyscf.gto.Mole) -> ml_collections.ConfigDict:
    """Converts a PySCF Mole object to an internal representation.

    Args:
      mol: Mole object describing the system of interest.

    Returns:
      A ConfigDict with the fields required to describe the system set.
    """
    # Ensure Mole is built so all attributes are appropriately set.
    mol.build()
    atoms = [
        system.Atom(mol.atom_symbol(i), mol.atom_coord(i), charge=mol.atom_charge(i))
        for i in range(mol.natm)
    ]
    return ml_collections.ConfigDict({
        'system': {
            'molecule': atoms,
            'electrons': mol.nelec,
        },
        'pretrain': {
            # If mol.basis isn't a string, assume that mol is passed into
            # pretraining as well and pretraining uses the basis already set in
            # mol, rather than complicating the configuration here.
            'basis': mol.basis if isinstance(mol.basis, str) else None,
        },
    })


def init_electrons(
    key,
    molecule: Sequence[system.Atom],
    electrons: Sequence[int],
    batch_size: int,
    init_width=1.0,
    given_atomic_spin_configs: Optional[Sequence[Tuple[int, int]]] = None
) -> jnp.ndarray:
    """Initializes electron positions around each atom.

    Args:
      key: JAX RNG state.
      molecule: system.Atom objects making up the molecule.
      electrons: tuple of number of alpha and beta electrons.
      batch_size: total number of MCMC configurations to generate across all
        devices.
      init_width: width of (atom-centred) Gaussian used to generate initial
        electron configurations.
      given_atomic_spin_configs: optional per-atom (nalpha, nbeta) assignment.
        If None, the assignment is derived from each atom's element and
        charge. The sequence is copied and never modified.

    Returns:
      array of (batch_size, (nalpha+nbeta)*ndim) of initial (random) electron
      positions in the initial MCMC configurations and ndim is the dimensionality
      of the space (i.e. typically 3).

    Raises:
      NotImplementedError: for a charged system with more than one atom and no
        explicit spin assignment.
    """
    if given_atomic_spin_configs is None:
        logging.warning('no spin assignment in the system config, may lead to unexpected initialization')

    # Normalise to a tuple so the comparison below also works for a list.
    target_spins = tuple(electrons)

    if (sum(atom.charge for atom in molecule) != sum(electrons)
            and given_atomic_spin_configs is None):
        if len(molecule) == 1:
            atomic_spin_configs = [target_spins]
        else:
            raise NotImplementedError('No initialization policy yet '
                                      'exists for charged molecules.')
    else:
        if given_atomic_spin_configs is None:
            atomic_spin_configs = [
                (atom.element.nalpha - int((atom.atomic_number - atom.charge) // 2),
                 atom.element.nbeta - int((atom.atomic_number - atom.charge) // 2))
                for atom in molecule
            ]
        else:
            # Work on a private list so the re-balancing below never writes
            # into the caller's config (which may also be an immutable tuple).
            atomic_spin_configs = [tuple(spins) for spins in given_atomic_spin_configs]

        assert sum(sum(x) for x in atomic_spin_configs) == sum(electrons)
        # Flip one electron at a time on a randomly chosen atom until the
        # totals match the requested spin polarisation. The flip direction
        # follows the sign of the mismatch, so the loop terminates whether the
        # per-atom configs start with too many or too few alpha electrons.
        while True:
            current_spins = tuple(sum(x) for x in zip(*atomic_spin_configs))
            if current_spins == target_spins:
                break
            i = np.random.randint(len(atomic_spin_configs))
            nalpha, nbeta = atomic_spin_configs[i]
            if current_spins[0] > target_spins[0]:
                if nalpha > 0:
                    atomic_spin_configs[i] = nalpha - 1, nbeta + 1
            elif nbeta > 0:
                atomic_spin_configs[i] = nalpha + 1, nbeta - 1

    # Assign each electron to an atom initially.
    electron_positions = []
    for i in range(2):
        for j in range(len(molecule)):
            atom_position = jnp.asarray(molecule[j].coords)
            electron_positions.append(
                jnp.tile(atom_position, atomic_spin_configs[j][i]))
    electron_positions = jnp.concatenate(electron_positions)
    # Create a batch of configurations with a Gaussian distribution about each
    # atom.
    key, subkey = jax.random.split(key)
    return (
        electron_positions +
        init_width *
        jax.random.normal(subkey, shape=(batch_size, electron_positions.size)))


def prepare(cfg, atoms, charges, nspins, local_batch_size):
    """Builds the network, its parameters, the walkers and the RNG keys.

    Args:
      cfg: resolved config; `system` and `mcmc.init_width` are read here.
      atoms: atom positions passed to the network provider.
      charges: nuclear charges passed to the network provider.
      nspins: number of alpha and beta electrons.
      local_batch_size: number of walkers, split evenly across local devices.

    Returns:
      Tuple of (signed_network, device-replicated params, walkers reshaped to
      (local_device_count, walkers_per_device, ...), per-device RNG key,
      network_options).
    """
    network_init, signed_network, network_options, *_ = networks.network_provider(cfg)(atoms, nspins, charges)
    # Time-based seed: runs are intentionally not reproducible.
    key = jax.random.PRNGKey(int(1e6 * time.time()))

    params_initialization_key, sharded_key = jax.random.split(key)
    params = network_init(params_initialization_key)
    params = kfac_utils.replicate_all_local_devices(params)

    subkey, sharded_key = jax.random.split(sharded_key)
    data = init_electrons(subkey, cfg.system.molecule, cfg.system.electrons,
                          local_batch_size,
                          init_width=cfg.mcmc.init_width,
                          given_atomic_spin_configs=cfg.system.get('atom_spin_configs'))
    data = data.reshape([jax.local_device_count(), local_batch_size // jax.local_device_count(), *data.shape[1:]])
    sharded_key = kfac_jax.utils.make_different_rng_key_on_all_devices(sharded_key)
    return signed_network, params, data, sharded_key, network_options


def prepare_mcmc(cfg, signed_network, local_batch_size):
    """Builds the pmapped MCMC step and the move-width adaptation rule.

    Args:
      cfg: resolved config; `mcmc.move_width`, `mcmc.steps`, `mcmc.blocks` and
        `mcmc.adapt_frequency` are read here.
      signed_network: network whose second output is used as the sampling
        function (presumably the log-magnitude of the wavefunction; confirm
        against the network provider).
      local_batch_size: number of walkers, split evenly across local devices.

    Returns:
      Tuple of (mcmc_step, initial device-replicated mcmc_width,
      update_mcmc_width).
    """
    network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1]
    batch_network = jax.vmap(network, in_axes=(None, 0))
    mcmc_width = kfac_jax.utils.replicate_all_local_devices(
        jnp.asarray(cfg.mcmc.move_width))
    mcmc_step = mcmc.make_mcmc_step(
        batch_network,
        local_batch_size // jax.local_device_count(),
        steps=cfg.mcmc.steps,
        blocks=cfg.mcmc.blocks,
    )
    mcmc_step = utils.pmap(mcmc_step, donate_argnums=1)

    # Rolling buffer holding the acceptance ratios of the current window.
    pmoves = np.zeros(cfg.mcmc.adapt_frequency)

    def update_mcmc_width(t, mcmc_width, pmove):
        # Once per window, nudge the width so that the mean acceptance ratio
        # moves towards the [0.5, 0.55] band, then start a new window.
        if t > 0 and t % cfg.mcmc.adapt_frequency == 0:
            if np.mean(pmoves) > 0.55:
                mcmc_width *= 1.1
            if np.mean(pmoves) < 0.5:
                mcmc_width /= 1.1
            pmoves[:] = 0
        pmoves[t%cfg.mcmc.adapt_frequency] = pmove
        return mcmc_width

    return mcmc_step, mcmc_width, update_mcmc_width


def use_ecp_or_ph(cfg):
    """Returns True if the config asks for an ECP or a pseudo-Hamiltonian.

    An ECP is detected from the `_ecp` attribute of `cfg.system.pyscf_mol`; a
    PH from a non-empty first entry of `cfg.ecp.ph_info`. Without a pyscf
    molecule neither can be used and False is returned.
    """
    if cfg.system.get('pyscf_mol') is None:
        return False
    # Including PH
    use_ecp = bool(cfg.system.pyscf_mol._ecp)
    if 'ecp' not in cfg or cfg.ecp.ph_info is None:
        use_ph = False
    else:
        use_ph = len(cfg.ecp.ph_info[0]) > 0
    logging.info(f'Use_ECP (including PH): {use_ecp}; Use_PH: {use_ph}')
    return use_ecp or use_ph


def prepare_logging(save_dir):
    """Builds the console-logging and CSV-writing helpers.

    Args:
      save_dir: directory handed to the `result` writer.

    Returns:
      Tuple of (do_logging, writer_manager, write_to_csv). `writer_manager` is
      meant to be used as a context manager that yields the writer passed to
      `write_to_csv`.
    """
    schema = ['energy', 'var', 'pmove']
    message = '{t} ' + ' '.join(f'{key}: {{{key}:.4f}}' for key in schema)
    writer_manager = writers.Writer(
        name='result',
        schema=schema,
        directory=save_dir,
        iteration_key=None,
        log=False)

    def _prepare(t, pmove, stats):
        # Index 0 picks the first device's copy of each per-device statistic.
        aux = stats['aux']
        vmc_aux = aux.vmc
        logging_dict = {
            't': t,
            'energy': stats['loss'][0],
            'var': vmc_aux.variance[0],
            'pmove': pmove[0]}

        return logging_dict

    def do_logging(t, pmove, stats):
        logging_dict = _prepare(t, pmove, stats)
        logging.info(message.format(**logging_dict))

    def write_to_csv(writer, t, pmove, stats):
        logging_dict = _prepare(t, pmove, stats)
        writer.write(**logging_dict)

    return do_logging, writer_manager, write_to_csv


def main(_):
    """absl entry point: resolves the config, sets up logging, then trains."""
    cfg = FLAGS.config
    cfg = base_config.resolve(cfg)

    # Send absl logs to stdout (instead of the default stream) at INFO level.
    logging.get_absl_handler().python_handler.stream = sys.stdout
    logging.set_verbosity(logging.INFO)
    train(cfg)


if __name__ == '__main__':
    from ml_collections.config_flags import config_flags

    config_flags.DEFINE_config_file('config', None, 'Path to config file.')
    app.run(main)
--------------------------------------------------------------------------------