├── molgym
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── covariant
│   │   │   ├── __init__.py
│   │   │   ├── gmm.py
│   │   │   ├── tools.py
│   │   │   ├── so3_tools.py
│   │   │   ├── modules.py
│   │   │   ├── spherical_dists.py
│   │   │   └── agent.py
│   │   ├── internal
│   │   │   ├── __init__.py
│   │   │   └── zmat.py
│   │   └── base.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── analysis.py
│   │   ├── qm9_parser.py
│   │   ├── model_util.py
│   │   ├── arg_parser.py
│   │   └── util.py
│   ├── version.py
│   ├── minimizer.py
│   ├── modules.py
│   ├── buffer_container.py
│   ├── calculator.py
│   ├── reward.py
│   ├── spaces.py
│   ├── env_container.py
│   ├── buffer.py
│   ├── environment.py
│   └── ppo.py
├── tests
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── internal
│   │   │   ├── __init__.py
│   │   │   └── test_zmat.py
│   │   └── covariant
│   │       ├── __init__.py
│   │       ├── resources
│   │       │   ├── h2o.xyz
│   │       │   ├── ch3.xyz
│   │       │   └── ch4.xyz
│   │       ├── test_gmm.py
│   │       ├── test_tools.py
│   │       ├── test_sphs.py
│   │       ├── test_so3_tools.py
│   │       ├── test_agent.py
│   │       └── test_spherical_distr.py
│   ├── resources
│   │   ├── h2o.xyz
│   │   ├── ethanol.xyz
│   │   ├── energy.dat
│   │   └── gradients.dat
│   ├── test_tools.py
│   ├── test_reward.py
│   ├── test_modules.py
│   ├── test_sparrow.py
│   ├── test_environment.py
│   ├── test_minimizer.py
│   └── test_spaces.py
├── .flake8
├── resources
│   └── intro.png
├── .gitignore
├── requirements.txt
├── .mypy.ini
├── setup.py
├── LICENSE
├── scripts
│   ├── structures.py
│   ├── plot.py
│   ├── run.py
│   ├── run_stochastic.py
│   └── run_solvation.py
├── README.md
└── .style.yapf

--------------------------------------------------------------------------------
/molgym/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/molgym/agents/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/molgym/tools/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/tests/agents/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/tests/agents/internal/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/molgym/agents/covariant/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/molgym/agents/internal/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/tests/agents/covariant/__init__.py:
--------------------------------------------------------------------------------
1 |

--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | extend-ignore = E501

--------------------------------------------------------------------------------
/molgym/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '2.0.0'
2 |
-------------------------------------------------------------------------------- /resources/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gncs/molgym/HEAD/resources/intro.png -------------------------------------------------------------------------------- /tests/resources/h2o.xyz: -------------------------------------------------------------------------------- 1 | 3 2 | 3 | O -0.27939703 0.83823215 0.00973345 4 | H -0.52040310 1.77677325 0.21391146 5 | H 0.54473632 0.90669722 -0.53501306 6 | -------------------------------------------------------------------------------- /tests/agents/covariant/resources/h2o.xyz: -------------------------------------------------------------------------------- 1 | 3 2 | 3 | O -0.27939703 0.83823215 0.00973345 4 | H -0.52040310 1.77677325 0.21391146 5 | H 0.54473632 0.90669722 -0.53501306 6 | -------------------------------------------------------------------------------- /tests/agents/covariant/resources/ch3.xyz: -------------------------------------------------------------------------------- 1 | 4 2 | 3 | C -0.64199507 0.00300222 0.00102728 4 | H -1.00569497 -0.76355062 0.67642535 5 | H -1.00565717 -0.19575086 -1.00111577 6 | H -0.99863651 0.97308735 0.32936706 7 | -------------------------------------------------------------------------------- /tests/agents/covariant/resources/ch4.xyz: -------------------------------------------------------------------------------- 1 | 5 2 | 3 | C -0.64199507 0.00300222 0.00102728 4 | H 0.44248372 -0.00188810 -0.00060391 5 | H -1.00569497 -0.76355062 0.67642535 6 | H -1.00565717 -0.19575086 -1.00111577 7 | H -0.99863651 0.97308735 0.32936706 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | .idea/ 3 | 4 | # mypy 5 | .mypy_cache/ 6 | .pytest_cache/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | 12 | # Distribution and packaging 13 | *.egg-info/ 14 | .eggs/ 15 | 16 | # Jupyter 17 | .ipynb_checkpoints/ 18 | 19 | data/ 20 | log/ 21 | tf_log/ 22 | src/ 23 | output/ 24 | *.ipynb 25 | *.out 26 | -------------------------------------------------------------------------------- /tests/resources/ethanol.xyz: -------------------------------------------------------------------------------- 1 | 9 2 | 3 | O -1.207454 0.251814 0.015195 4 | C -0.030747 -0.571012 0.019581 5 | C 1.237302 0.268221 0.009126 6 | H -0.076875 -1.243951 0.894697 7 | H -0.152197 -1.158922 -0.910633 8 | H 2.133158 -0.361916 -0.023873 9 | H 1.265320 0.929390 -0.867509 10 | H 1.316863 0.906659 0.895476 11 | H -1.196271 0.874218 0.765639 12 | -------------------------------------------------------------------------------- /molgym/agents/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | import torch.distributions 6 | 7 | from molgym.spaces import ObservationType, ObservationSpace, ActionSpace 8 | 9 | 10 | class AbstractActorCritic(torch.nn.Module, abc.ABC): 11 | def __init__(self, observation_space: ObservationSpace, action_space: ActionSpace): 12 | super().__init__() 13 | 14 | self.observation_space = observation_space 15 | self.action_space = action_space 16 | 17 | @abc.abstractmethod 18 | def step(self, observations: List[ObservationType], actions: Optional[np.ndarray] = None) -> dict: 19 
|         raise NotImplementedError
20 |
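The step method is the entire contract between an agent and the training loop: it receives a batch of observations (and, optionally, previously taken actions to re-evaluate) and returns a dictionary of rollout quantities. A minimal sketch of a concrete subclass, not part of the package; the exact keys expected in the returned dict are fixed by molgym/ppo.py, so 'actions', 'logp' and 'v' below are illustrative assumptions:

    class RandomAgent(AbstractActorCritic):
        """Hypothetical baseline that samples actions uniformly from the action space."""
        def step(self, observations: List[ObservationType], actions: Optional[np.ndarray] = None) -> dict:
            batch_size = len(observations)
            return {
                'actions': np.array([self.action_space.sample() for _ in range(batch_size)], dtype=object),
                'logp': np.zeros(batch_size, dtype=np.float32),  # placeholder log-probabilities
                'v': np.zeros(batch_size, dtype=np.float32),  # placeholder value estimates
            }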
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ase==3.20.0
2 | cloudpickle==1.3.0
3 | -e git+https://github.com/risilab/cormorant.git@6a4b6370e8a7cfd7bf253ecfb5783b6d7787ba3f#egg=cormorant
4 | cycler==0.10.0
5 | future==0.18.2
6 | gym==0.17.2
7 | h5py==2.10.0
8 | importlib-metadata==1.7.0
9 | kiwisolver==1.2.0
10 | matplotlib==3.3.0
11 | mpmath==1.1.0
12 | ndim==0.1.4
13 | numpy==1.19.1
14 | orthopy==0.8.4
15 | pandas==1.1.0
16 | Pillow==7.2.0
17 | protobuf==3.12.4
18 | pyglet==1.5.0
19 | pyparsing==2.4.7
20 | python-dateutil==2.8.1
21 | pytz==2020.1
22 | PyYAML==5.3.1
23 | quadpy==0.16.2
24 | schnetpack==0.3
25 | scipy==1.5.2
26 | six==1.15.0
27 | sympy==1.6.2
28 | tensorboardX==2.1
29 | torch==1.5.1
30 | torch-scatter==2.0.5
31 | zipp==3.1.0
32 |

--------------------------------------------------------------------------------
/tests/resources/energy.dat:
--------------------------------------------------------------------------------
1 | # SPARROW, command-line program
2 | # Host: gabriel
3 | # Start: Fri Oct 18 18:26:07 2019
4 | #
5 | # Program call: sparrow --molecular_charge 0 --spin_multiplicity 1 --method PM6 --unrestricted_calculation --structure h2o.xyz --gradients --output_to_file
6 | # Method: PM6
7 | # Calculation Modus: Spin-unrestricted formalism
8 | # Molecular charge: 0
9 | # Spin multiplicity: 1
10 | # Convergence threshold: 1e-05 hartree
11 | # Max Iterations: 100
12 | # Convergence accelerator: diis
13 | # Parameters: /home/gregor/local/sparrow-1.0.0-gcc-9.2.1/resources/Parameters/Pm6/parameters.xml
14 | #
15 | # Energy [hartree]:
16 | -11.72459668
17 |

--------------------------------------------------------------------------------
/tests/resources/gradients.dat:
--------------------------------------------------------------------------------
1 | # SPARROW, command-line program
2 | # Host: gabriel
3 | # Start: Fri Oct 18 18:26:07 2019
4 | #
5 | # Program call: sparrow --molecular_charge 0 --spin_multiplicity 1 --method PM6 --unrestricted_calculation --structure h2o.xyz --gradients --output_to_file
6 | # Method: PM6
7 | # Calculation Modus: Spin-unrestricted formalism
8 | # Molecular charge: 0
9 | # Spin multiplicity: 1
10 | # Convergence threshold: 1.000000e-05 hartree
11 | # Max Iterations: 100
12 | # Convergence accelerator: diis
13 | # Parameters: /home/gregor/local/sparrow-1.0.0-gcc-9.2.1/resources/Parameters/Pm6/parameters.xml
14 | #
15 | # Gradients [hartree/bohr]:
16 | -8.700857e-03 -1.502556e-02 5.081632e-03
17 | -4.048210e-03 1.437334e-02 3.364464e-03
18 | 1.274907e-02 6.522202e-04 -8.446095e-03
19 |

--------------------------------------------------------------------------------
/tests/test_tools.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import numpy as np
4 |
5 | from molgym.tools.util import discount_cumsum, split_formula_strings, zs_to_formula
6 |
7 |
8 | class TestTools(TestCase):
9 |     def test_parse_formula(self):
10 |         s = 'H2O, CH4, O2'
11 |         formulas = split_formula_strings(s)
12 |
13 |         self.assertEqual(len(formulas), 3)
14 |
15 |     def test_zs_to_formula(self):
16 |         formula = zs_to_formula([1, 1, 2, 4])
17 |         self.assertEqual(len(formula), 3)
18 |
19 |     def test_cumsum(self):
20 |         discount = 0.5
21 |         x = np.ones(3, dtype=np.float32)
22 |         y = discount_cumsum(x, discount=discount)
23 |
24 |         self.assertAlmostEqual(y[0], x[0] + discount * x[1] + discount**2 * x[2])
25 |         self.assertAlmostEqual(y[1], x[1] + discount * x[2])
26 |         self.assertAlmostEqual(y[2], x[2])
27 |

--------------------------------------------------------------------------------
/.mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | # Platform configuration
3 | python_version = 3.7
4 |
5 | # Untyped definitions and calls
6 | check_untyped_defs = True
7 |
8 | [mypy-ase.*]
9 | ignore_missing_imports = True
10 |
11 | [mypy-numpy.*]
12 | ignore_missing_imports = True
13 |
14 | [mypy-scipy.*]
15 | ignore_missing_imports = True
16 |
17 | [mypy-gym.*]
18 | ignore_missing_imports = True
19 |
20 | [mypy-matplotlib.*]
21 | ignore_missing_imports = True
22 |
23 | [mypy-cormorant.*]
24 | ignore_missing_imports = True
25 |
26 | [mypy-schnetpack.*]
27 | ignore_missing_imports = True
28 |
29 | [mypy-scine_sparrow.*]
30 | ignore_missing_imports = True
31 |
32 | [mypy-setuptools.*]
33 | ignore_missing_imports = True
34 |
35 | [mypy-pandas.*]
36 | ignore_missing_imports = True
37 |
38 | [mypy-torch_scatter.*]
39 | ignore_missing_imports = True
40 |
41 | [mypy-networkx.*]
42 | ignore_missing_imports = True
43 |
44 | [mypy-rdkit.*]
45 | ignore_missing_imports = True
46 |
47 | [mypy-quadpy.*]
48 | ignore_missing_imports = True
49 |
50 |

--------------------------------------------------------------------------------
/tests/agents/covariant/test_gmm.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from molgym.agents.covariant.gmm import GaussianMixtureModel
7 | from molgym.tools.util import to_numpy
8 |
9 |
10 | class GaussianMixtureModelTest(TestCase):
11 |     def setUp(self):
12 |         self.log_probs = torch.log(torch.tensor([[0.7, 0.3], [0.5, 0.5]]))
13 |         self.means = torch.tensor([[-0.5, 0.3], [0.0, 0.2]])
14 |         self.log_stds = torch.log(torch.tensor([[0.2, 0.5], [0.3, 0.2]]))
15 |         self.distr = GaussianMixtureModel(log_probs=self.log_probs, means=self.means, stds=torch.exp(self.log_stds))
16 |
17 |     def test_samples(self):
18 |         s = self.distr.sample(torch.Size((3, )))
19 |         self.assertEqual(s.shape, (3, 2))
20 |
21 |     def test_argmax(self):
22 |         torch.manual_seed(1)
23 |         argmax = self.distr.argmax(128)
24 |         self.assertEqual(argmax.shape, (2, ))
25 |         self.assertTrue(np.allclose(to_numpy(argmax), np.array([-0.495, 0.156]), atol=1.e-2))
26 |

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict
3 |
4 | from setuptools import setup, find_packages
5 |
6 |
7 | def readme() -> str:
8 |     with open('README.md') as f:
9 |         return f.read()
10 |
11 |
12 | version_dict = {}  # type: Dict[str, str]
13 | with open(os.path.join('molgym', 'version.py')) as file:
14 |     exec(file.read(), version_dict)
15 |
16 | setup(
17 |     name='molgym',
18 |     version=version_dict['__version__'],
19 |     description='',
20 |     long_description=readme(),
21 |     classifiers=['Programming Language :: Python :: 3.7'],
22 |     author='Gregor Simm and Robert Pinsler',
23 |     author_email='gncs2@cam.ac.uk, rp586@cam.ac.uk',
24 |     python_requires='>=3.7',
25 |     packages=find_packages(),
26 |     include_package_data=True,
27 |     install_requires=[
28 |         'gym',
29 |         'numpy',
30 |         'scipy',
31 |         'pandas',
32 |         'matplotlib',
33 |         'ase',
34 |         'schnetpack',
35 |     ],
36 |     zip_safe=False,
37 |     test_suite='pytest',
38 |     tests_require=['pytest'],
39 | )
40 |

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Gregor Simm
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |

--------------------------------------------------------------------------------
/tests/agents/covariant/test_tools.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import torch
4 |
5 | from molgym.agents.covariant.tools import pad_sequence
6 |
7 |
8 | class ToolsTest(TestCase):
9 |     def test_pad_sequence(self):
10 |         a = torch.rand(size=(3, 2))
11 |         b = torch.rand(size=(4, 2))
12 |         max_length = 5
13 |         c = pad_sequence(sequences=[a, b], max_length=max_length, padding_value=0.0)
14 |         self.assertEqual(c.shape, (2, max_length, 2))
15 |         self.assertTrue(torch.all(c[0, 3:] == 0.0))
16 |         self.assertTrue(torch.all(c[1, 4:] == 0.0))
17 |
18 |     def test_pad_sequence_too_small(self):
19 |         a = torch.rand(size=(3, 2))
20 |         b = torch.rand(size=(4, 3))
21 |         max_length = 3
22 |         with self.assertRaises(RuntimeError):
23 |             pad_sequence(sequences=[a, b], max_length=max_length, padding_value=0.0)
24 |
25 |     def test_pad_sequence_mismatch(self):
26 |         a = torch.rand(size=(3, 2))
27 |         b = torch.rand(size=(4, 3))
28 |         max_length = 5
29 |         with self.assertRaises(RuntimeError):
30 |             pad_sequence(sequences=[a, b], max_length=max_length, padding_value=0.0)
31 |

--------------------------------------------------------------------------------
/molgym/agents/covariant/gmm.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | import torch
4 | import torch.distributions as D
5 | from torch.distributions import MixtureSameFamily
6 |
7 |
8 | class GaussianMixtureModel(MixtureSameFamily, ABC):
9 |     def __init__(
10 |         self,
11 |         log_probs: torch.Tensor,
12 |         means: torch.Tensor,
13 |         stds: torch.Tensor,
14 |         validate_args=None,
15 |     ) -> None:
16 |         categoricals = D.Categorical(logits=log_probs, validate_args=validate_args)
17 |         normals = D.Normal(loc=means, scale=stds, validate_args=validate_args)
18 |         super().__init__(mixture_distribution=categoricals, component_distribution=normals, validate_args=validate_args)
19 |
20 |     def argmax(self, count=128) -> torch.Tensor:
21 |         # This can also be implemented using the EM
algorithm 22 | # http://www.cs.columbia.edu/~jebara/htmlpapers/ARL/node61.html 23 | samples = self.sample(torch.Size((count, ))) # (samples, batches) 24 | log_probs = self.log_prob(samples) # (samples, batches) 25 | indices = torch.argmax(log_probs, dim=0).unsqueeze(0) # (1, batches) 26 | result = torch.gather(samples, dim=0, index=indices) # (1, batches) 27 | return result.squeeze(0) # (batches, ) 28 | -------------------------------------------------------------------------------- /tests/test_reward.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import pkg_resources 4 | from ase import Atoms, Atom 5 | 6 | from molgym.reward import InteractionReward 7 | 8 | RESOURCES_FOLDER = 'resources' 9 | 10 | 11 | class TestReward(TestCase): 12 | RESOURCES = pkg_resources.resource_filename(__package__, RESOURCES_FOLDER) 13 | 14 | def setUp(self): 15 | self.reward = InteractionReward() 16 | 17 | def test_calculation(self): 18 | reward, info = self.reward.calculate(Atoms(), Atom('H')) 19 | self.assertEqual(reward, 0) 20 | 21 | def test_h2(self): 22 | atom1 = Atom('H', position=(0, 0, 0)) 23 | atom2 = Atom('H', position=(1, 0, 0)) 24 | 25 | atoms = Atoms() 26 | atoms.append(atom1) 27 | 28 | reward, info = self.reward.calculate(atoms, atom2) 29 | 30 | self.assertAlmostEqual(reward, 0.1696435) 31 | 32 | def test_addition(self): 33 | atom1 = Atom('H', position=(0, 0, 0)) 34 | atom2 = Atom('H', position=(1, 0, 0)) 35 | atom3 = Atom('H', position=(2, 0, 0)) 36 | 37 | atoms = Atoms() 38 | atoms.append(atom1) 39 | 40 | reward1, _ = self.reward.calculate(atoms, atom2) 41 | atoms.append(atom2) 42 | 43 | reward2, _ = self.reward.calculate(atoms, atom3) 44 | atoms.append(atom3) 45 | 46 | self.assertAlmostEqual(reward1 + reward2, 0.2141968) 47 | -------------------------------------------------------------------------------- /tests/test_modules.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from molgym.modules import to_one_hot, masked_softmax 7 | 8 | 9 | class TestModules(TestCase): 10 | def test_one_hot(self): 11 | positions = np.array([[1], [3], [2]]) 12 | indices = torch.from_numpy(positions) 13 | 14 | result = to_one_hot(indices=indices, num_classes=4).detach() 15 | expected = [ 16 | [0, 1, 0, 0], 17 | [0, 0, 0, 1], 18 | [0, 0, 1, 0], 19 | ] 20 | self.assertTrue(np.allclose(expected, result)) 21 | 22 | def test_one_hot_wrong_index(self): 23 | positions = np.array([ 24 | [5], 25 | ]) 26 | indices = torch.from_numpy(positions) 27 | 28 | with self.assertRaises(RuntimeError): 29 | to_one_hot(indices=indices, num_classes=3).detach() 30 | 31 | def test_softmax(self): 32 | logits = torch.from_numpy(np.array([ 33 | [0.5, 0.5], 34 | [1.0, 0.5], 35 | ], dtype=np.float)) 36 | 37 | mask_1 = torch.ones(size=logits.shape, dtype=torch.bool) 38 | 39 | y1 = masked_softmax(logits=logits, mask=mask_1) 40 | self.assertEqual(y1.shape, (2, 2)) 41 | self.assertAlmostEqual(y1.sum().item(), 2.0) 42 | 43 | mask_2 = torch.from_numpy(np.array([[1, 0], [1, 0]], dtype=np.bool)) 44 | y2 = masked_softmax(logits=logits, mask=mask_2) 45 | 46 | total = y2.sum(dim=0, keepdim=False) 47 | self.assertTrue(np.allclose(total, np.array([2, 0]))) 48 | -------------------------------------------------------------------------------- /molgym/minimizer.py: -------------------------------------------------------------------------------- 1 | from typing 
import Tuple
2 |
3 | import numpy as np
4 | import scipy.optimize
5 | from ase import Atoms
6 |
7 |
8 | def minimize(
9 |         calculator,
10 |         atoms: Atoms,
11 |         charge: int,
12 |         spin_multiplicity: int,
13 |         max_iter=120,
14 |         fixed_indices=None,
15 |         verbose=False,
16 | ) -> Tuple[Atoms, bool]:
17 |     atoms = atoms.copy()
18 |     calculator.set_elements(list(atoms.symbols))
19 |     calculator.set_settings({'molecular_charge': charge, 'spin_multiplicity': spin_multiplicity})
20 |
21 |     mask = np.ones((len(atoms) * 3, ), dtype=np.float)
22 |     if fixed_indices:
23 |         for index in fixed_indices:
24 |             mask[index * 3:(index + 1) * 3] = 0
25 |
26 |     def function(coords: np.ndarray) -> Tuple[float, np.ndarray]:
27 |         calculator.set_positions(coords.reshape(-1, 3))
28 |         energy = calculator.calculate_energy()
29 |         gradients = calculator.calculate_gradients()
30 |         return energy, gradients.flatten() * mask
31 |
32 |     initial_coords = atoms.positions.flatten()
33 |
34 |     minimize_result = scipy.optimize.minimize(
35 |         function,
36 |         x0=initial_coords,
37 |         jac=True,
38 |         method='BFGS',
39 |         options={
40 |             'maxiter': max_iter,
41 |             'disp': verbose,
42 |             'norm': np.inf,  # equivalent to taking numpy.amax(numpy.abs(gradient))
43 |             'gtol': 3e-4,  # TolMaxG=3e-4 (ORCA)
44 |         },
45 |     )
46 |
47 |     atoms.positions = minimize_result.x.reshape(-1, 3)
48 |
49 |     return atoms, minimize_result.success
50 |

--------------------------------------------------------------------------------
/molgym/tools/analysis.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import re
5 | from typing import List
6 |
7 |
8 | def parse_json_lines_file(path: str) -> List[dict]:
9 |     dicts = []
10 |     with open(path, mode='r') as f:
11 |         for line in f:
12 |             dicts.append(json.loads(line))
13 |     return dicts
14 |
15 |
16 | def parse_buffer_filename(filename: str) -> dict:
17 |     regex = re.compile(r'(?P<name>.*?)_run-(?P<seed>\d+)_steps-(?P<steps>\d+)(_rank-(?P<rank>\d+))?_(?P<mode>.*?)\.pkl')
18 |     match = regex.match(filename)
19 |     if not match:
20 |         raise RuntimeError(f'Cannot parse filename: {filename}')
21 |     return {
22 |         'name': match.group('name'),
23 |         'seed': int(match.group('seed')),
24 |         'steps': int(match.group('steps')),
25 |         'rank': int(match.group('rank')) if match.group('rank') else 0,
26 |         'mode': match.group('mode'),
27 |     }
28 |
29 |
30 | def parse_results_filename(filename: str) -> dict:
31 |     regex = re.compile(r'(?P<name>.*?)_run-(?P<seed>\d+)_(?P<mode>.*?)\.txt')
32 |     match = regex.match(filename)
33 |     if not match:
34 |         raise RuntimeError(f'Cannot parse filename: {filename}')
35 |     return {
36 |         'name': match.group('name'),
37 |         'seed': int(match.group('seed')),
38 |         'mode': match.group('mode'),
39 |     }
40 |
41 |
42 | def collect_results_paths(directory: str, mode: str) -> List[str]:
43 |     return glob.glob(os.path.join(directory, f'*_run-*_{mode}.txt'))
44 |
45 |
46 | def collect_buffer_paths(directory: str, mode: str) -> List[str]:
47 |     return glob.glob(os.path.join(directory, f'*_{mode}.pkl'))
48 |
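With the group names spelled out, parse_buffer_filename maps a buffer filename directly to its metadata. For example (hypothetical filename):

    from molgym.tools.analysis import parse_buffer_filename

    info = parse_buffer_filename('exp1_run-2_steps-1000_rank-0_train.pkl')
    # {'name': 'exp1', 'seed': 2, 'steps': 1000, 'rank': 0, 'mode': 'train'}

The `_rank-*` part is optional and defaults to 0, which matches files written by single-process runs.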
--------------------------------------------------------------------------------
/molgym/modules.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | import torch.distributions
5 | import torch_scatter
6 |
7 |
8 | def to_one_hot(indices: torch.Tensor, num_classes: int, device=None) -> torch.Tensor:
9 |     """
10 |     Generates one-hot encoding with <num_classes> classes from <indices>
11 |
12 |     :param indices: (N x 1) tensor
13 |     :param num_classes: number of classes
14 |     :param device: torch device
15 |     :return: (N x num_classes) tensor
16 |     """
17 |     shape = (*indices.shape[:-1], num_classes)
18 |     oh = torch.zeros(shape, device=device).view(-1, num_classes)
19 |
20 |     # scatter_ is the in-place version of scatter
21 |     oh.scatter_(1, indices.view(-1, 1), 1)
22 |
23 |     return oh.view(*shape)
24 |
25 |
26 | def masked_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
27 |     return torch_scatter.composite.scatter_softmax(src=logits, index=mask.to(torch.long), dim=-1) * mask
28 |
29 |
30 | def init_layer(layer: torch.nn.Linear, w_scale=1.0) -> torch.nn.Linear:
31 |     torch.nn.init.orthogonal_(layer.weight.data)
32 |     layer.weight.data.mul_(w_scale)  # type: ignore
33 |     torch.nn.init.constant_(layer.bias.data, 0)
34 |     return layer
35 |
36 |
37 | class MLP(torch.nn.Module):
38 |     def __init__(self, input_dim: int, output_dims: Tuple[int, ...] = (64, 64), gate=torch.nn.functional.relu):
39 |         super().__init__()
40 |         dims = (input_dim, ) + output_dims
41 |         self.layers = torch.nn.ModuleList(
42 |             [init_layer(torch.nn.Linear(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
43 |         self.gate = gate
44 |         self.output_dim = dims[-1]
45 |
46 |     def forward(self, x):
47 |         for layer in self.layers[:-1]:
48 |             x = self.gate(layer(x))
49 |         x = self.layers[-1](x)
50 |         return x
51 |
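masked_softmax leans on torch_scatter's scatter_softmax: the boolean mask, cast to long, serves as the grouping index along the last dimension, so the unmasked entries are normalized only among themselves, and the final multiplication by the mask zeroes out the rest. A small illustration:

    import torch
    from molgym.modules import masked_softmax

    logits = torch.tensor([[1.0, 2.0, 3.0]])
    mask = torch.tensor([[True, True, False]])
    probs = masked_softmax(logits=logits, mask=mask)
    # probs: tensor([[0.2689, 0.7311, 0.0000]]) (rounded); probability mass is
    # distributed over the unmasked entries only.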
--------------------------------------------------------------------------------
/tests/agents/covariant/test_sphs.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import numpy as np
4 | import torch
5 | from cormorant.cg_lib import SphericalHarmonics
6 |
7 | from molgym.agents.covariant.so3_tools import spherical_to_cartesian
8 |
9 |
10 | class TestSphericalHarmonics(TestCase):
11 |     def test_conversion(self):
12 |         theta_phi = np.array([np.pi / 3, np.pi / 4])
13 |         pos = spherical_to_cartesian(theta_phi)
14 |
15 |         expected = np.array([0.612372, 0.612372, 0.5])
16 |         self.assertTrue(np.allclose(pos, expected))
17 |
18 |     def test_l_1(self):
19 |         theta_phi = np.array([np.pi / 2, 0.0])
20 |         pos = spherical_to_cartesian(theta_phi)
21 |         pos_tensor = torch.tensor(pos, dtype=torch.float32)
22 |
23 |         # To match the definition Mathematica uses, sh_norm='qm' is required
24 |         sph = SphericalHarmonics(maxl=1, normalize=True, sh_norm='qm')
25 |         output = sph.forward(pos_tensor)
26 |
27 |         # Mathematica output:
28 |         expected = np.array([
29 |             [0.345494, 0],
30 |             [0, 0],
31 |             [-0.345494, 0],
32 |         ], dtype=np.float32)
33 |
34 |         self.assertTrue(np.allclose(output[1].cpu().detach().numpy(), expected))
35 |
36 |     def test_l_2(self):
37 |         theta_phi = np.array([np.pi / 3, np.pi / 4])
38 |         pos = spherical_to_cartesian(theta_phi)
39 |         pos_tensor = torch.tensor(pos, dtype=torch.float32)
40 |
41 |         # To match the definition Mathematica uses, sh_norm='qm' is required
42 |         sph = SphericalHarmonics(maxl=2, normalize=False, sh_norm='qm')
43 |         output = sph.forward(pos_tensor)
44 |
45 |         # Mathematica output:
46 |         expected = np.array([
47 |             [0, -0.289706],
48 |             [0.236544, -0.236544],
49 |             [-0.0788479, 0],
50 |             [-0.236544, -0.236544],
51 |             [0, 0.289706],
52 |         ],
53 |             dtype=np.float32)
54 |
55 |         self.assertTrue(np.allclose(output[2].cpu().detach().numpy(), expected))
56 |

--------------------------------------------------------------------------------
/molgym/agents/covariant/tools.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 |
3 | import ase.data
4 | import torch
5 | from ase import Atoms
6 |
7 |
8 | def atoms_to_feats(atoms: Atoms, dtype: torch.dtype, device: torch.device) -> Dict[str, torch.Tensor]:
9 |     return {
10 |         'num_atoms': torch.tensor(len(atoms), dtype=torch.int, device=device),
11 |         'charges': torch.tensor([ase.data.atomic_numbers[atom.symbol] for atom in atoms],
12 |                                 dtype=torch.int,
13 |                                 device=device),
14 |         'positions': torch.tensor(atoms.positions, dtype=dtype, device=device),
15 |     }
16 |
17 |
18 | def pad_sequence(sequences: List[torch.Tensor], max_length: int, padding_value=0) -> torch.Tensor:
19 |     # assuming trailing dimensions and type of all the Tensors
20 |     # in sequences are same and fetching those from sequences[0]
21 |     max_size = sequences[0].size()
22 |     trailing_dims = max_size[1:]
23 |     out_dims = (len(sequences), max_length) + trailing_dims
24 |
25 |     out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)  # type: ignore
26 |     for i, tensor in enumerate(sequences):
27 |         length = tensor.size(0)
28 |         # use index notation to prevent duplicate references to the tensor
29 |         out_tensor[i, :length, ...] = tensor
30 |
31 |     return out_tensor
32 |
33 |
34 | def process_atoms_list(atoms_list: List[Atoms], max_num_atoms: int, dtype: torch.dtype,
35 |                        device: torch.device) -> Dict[str, torch.Tensor]:
36 |     # Gather features
37 |     feats_list = [atoms_to_feats(atoms, dtype=dtype, device=device) for atoms in atoms_list]
38 |
39 |     # Convert list-of-dicts to dict-of-lists
40 |     props = feats_list[0].keys()
41 |     prop_dict = {prop: [feats[prop] for feats in feats_list] for prop in props}
42 |
43 |     # Pad and stack
44 |     molecules = {
45 |         key: pad_sequence(val, max_length=max_num_atoms) if val[0].dim() > 0 else torch.stack(val)
46 |         for key, val in prop_dict.items()
47 |     }
48 |
49 |     return molecules
50 |
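process_atoms_list batches a ragged list of molecules into fixed-shape tensors: per-atom features are padded up to max_num_atoms with pad_sequence, while scalar features are simply stacked. A sketch of the resulting shapes (the geometries are illustrative):

    import torch
    from ase import Atoms
    from molgym.agents.covariant.tools import process_atoms_list

    mols = [
        Atoms('H2', positions=[(0.0, 0.0, 0.0), (0.74, 0.0, 0.0)]),
        Atoms('OH2', positions=[(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]),
    ]
    batch = process_atoms_list(mols, max_num_atoms=5, dtype=torch.float32, device=torch.device('cpu'))
    # batch['num_atoms'] -> tensor([2, 3]); batch['charges'].shape -> (2, 5);
    # batch['positions'].shape -> (2, 5, 3); entries beyond each molecule's
    # atoms are zero-padded.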
--------------------------------------------------------------------------------
/scripts/structures.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | from typing import Tuple
5 |
6 | import ase.data
7 | import ase.io
8 |
9 | from molgym.spaces import CanvasSpace
10 | from molgym.tools.analysis import parse_buffer_filename, collect_buffer_paths
11 |
12 |
13 | def parse_args() -> argparse.Namespace:
14 |     parser = argparse.ArgumentParser(description='Analyse MolGym output')
15 |
16 |     parser.add_argument('--dir', help='path to data directory of experiment(s)', required=True)
17 |     parser.add_argument('--symbols',
18 |                         help='symbols representing elements on canvas (comma separated)',
19 |                         type=str,
20 |                         required=True)
21 |     parser.add_argument('--canvas_size',
22 |                         help='maximum number of atoms that can be placed on the canvas',
23 |                         type=int,
24 |                         default=128)
25 |     parser.add_argument('--mode', help='select from train or eval mode', default='eval', choices=['train', 'eval'])
26 |     parser.add_argument('--name', help='name of experiment')
27 |
28 |     return parser.parse_args()
29 |
30 |
31 | def path_to_sort_key(path: str) -> Tuple:
32 |     parsed_path = parse_buffer_filename(os.path.basename(path))
33 |     return parsed_path['steps'], parsed_path['mode'], parsed_path['name'], parsed_path['seed'], parsed_path['rank']
34 |
35 |
36 | def main() -> None:
37 |     args = parse_args()
38 |
39 |     paths = collect_buffer_paths(args.dir, mode=args.mode)
40 |     print(f'Parsed paths: {len(paths)}')
41 |
42 |     canvas_space = CanvasSpace(size=args.canvas_size, zs=[ase.data.atomic_numbers[s] for s in args.symbols.split(',')])
43 |
44 |     # Sort paths
45 |     paths = sorted(paths, key=path_to_sort_key)
46 |
47 |     atoms_list = []
48 |     for path in paths:
49 |         info = parse_buffer_filename(os.path.basename(path))
50 |
51 |         if args.name and info['name'] != args.name:
52 |             continue
53 |
54 |         with open(path, mode='rb') as f:
55 |             buffer = pickle.load(f)
56 |
57 |         atoms = [
58 |             canvas_space.to_atoms(obs[0]) for terminal, obs in zip(buffer.term_buf, buffer.next_obs_buf)
59 |             if obs and terminal
60 |         ]
61 |         for atom in atoms:
62 |             atom.info = info
63 |
64 |         atoms_list += atoms
65 |
66 |     if args.name:
67 |         filename = f'structures_{args.name}_{args.mode}.xyz'
68 |     else:
69 |         filename = f'structures_{args.mode}.xyz'
70 |
71 |     ase.io.write(filename, images=atoms_list)
72 |
73 |
74 | if __name__ == '__main__':
75 |     main()
76 |

--------------------------------------------------------------------------------
/molgym/tools/qm9_parser.py:
--------------------------------------------------------------------------------
1 | import re
2 | import tarfile
3 | from typing import Tuple, Iterator
4 |
5 | from ase import Atoms
6 |
7 |
8 | class ParserError(Exception):
9 |     """Error raised when an error occurs while parsing the GDB9 dataset"""
10 |
11 |
12 | _coord_line = (br'(?P<element>\D+)\s+(?P<x>-?\d+\.\d*(E-?\d+)?)\s+(?P<y>-?\d+\.\d*(E-?\d+)?)\s+'
13 |                br'(?P<z>-?\d+\.\d*(E-?\d+)?)\s+(?P<charge>-?\d+\.\d*(E-?\d+)?)\s*')
14 | _coord_re = re.compile(_coord_line)
15 | _data_re = re.compile(
16 |     br'^(?P<natoms>\d+)\n'
17 |     br'gdb (?P<id>\d+)\s+(?P<a>-?\d+(\.\d*)?)\s+(?P<b>-?\d+\.\d*)\s+(?P<c>-?\d+\.\d*)\s+(?P<mu>-?\d+\.\d*)'
18 |     br'\s+(?P<alpha>-?\d+\.\d*)\s+(?P<homo>-?\d+\.\d*)\s+(?P<lumo>-?\d+\.\d*)\s+(?P<gap>-?\d+\.\d*)'
19 |     br'\s+(?P<r2>-?\d+\.\d*)\s+(?P<zpve>-?\d+\.\d*)\s+(?P<u0>-?\d+\.\d*)\s+(?P<u>-?\d+\.\d*)'
20 |     br'\s+(?P<h>-?\d+\.\d*)\s+(?P<g>-?\d+\.\d*)\s+(?P<cv>-?\d+\.\d*)\s+'
21 |     br'(?P<coordinates>(' + _coord_line + br')+)'
22 |     br'(?P<frequencies>(\s*-?\d+\.\d*)+)'
23 |     br'(?P<smiles_gdb>(\s*\S+))'
24 |     br'(?P<smiles_opt>(\s*\S+))'
25 |     br'(?P<inchi_gdb>(\s*\S+))'
26 |     br'(?P<inchi_opt>(\s*\S+){2})\s*$')
27 |
28 |
29 | def parse_entry(string: bytes) -> Tuple[str, Atoms, dict]:
30 |     elements = []
31 |     positions = []
32 |
33 |     match = _data_re.match(string)
34 |     try:
35 |         if not match:
36 |             raise ParserError('String does not match pattern')
37 |
38 |         for coord in _coord_re.finditer(match.group('coordinates')):
39 |             elements.append(coord.group('element').decode('ascii').strip())
40 |             positions.append((float(coord.group('x')), float(coord.group('y')), float(coord.group('z'))))
41 |
42 |         info = {'smiles': match.group('smiles_opt').decode('ascii').strip()}
43 |
44 |         return match.group('id').decode('ascii'), Atoms(symbols=elements, positions=positions), info
45 |
46 |     except (ValueError, AttributeError) as e:
47 |         raise ParserError(e)
48 |
49 |
50 | def parse_dataset(file_path: str, strict=False) -> Iterator[Tuple[str, Atoms, dict]]:
51 |     with tarfile.open(file_path, mode='r') as archive:
52 |         for i, entry in enumerate(archive):
53 |             f = archive.extractfile(entry)
54 |
55 |             if not f:
56 |                 raise RuntimeError('File cannot be read')
57 |
58 |             string = f.read().replace(b'*^', b'E')
59 |
60 |             try:
61 |                 yield parse_entry(string)
62 |             except ParserError as e:
63 |                 if not strict:
64 |                     print('Could not parse: ' + entry.name + ': ' + str(e))
65 |                 else:
66 |                     raise
67 |
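Only the 'id', 'coordinates' and 'smiles_opt' groups are actually referenced; the remaining names above follow the documented order of the fifteen GDB-9/QM9 scalar properties (rotational constants A, B, C, dipole moment, polarizability, HOMO, LUMO, gap, <R^2>, ZPVE, U0, U, H, G, Cv). parse_dataset streams (id, Atoms, info) triples directly from the compressed archive and, unless strict=True, skips entries it cannot parse. A usage sketch (the archive name is the standard QM9 download, assumed to sit in the working directory):

    from molgym.tools.qm9_parser import parse_dataset

    for identifier, atoms, info in parse_dataset('dsgdb9nsd.xyz.tar.bz2'):
        print(identifier, len(atoms), info['smiles'])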
--------------------------------------------------------------------------------
/tests/test_sparrow.py:
--------------------------------------------------------------------------------
1 | import os
2 | from unittest import TestCase
3 |
4 | import ase.io
5 | import numpy as np
6 | import pkg_resources
7 | from ase import Atoms
8 |
9 | from molgym.calculator import Sparrow
10 |
11 | RESOURCES_FOLDER = 'resources'
12 |
13 |
14 | class TestSparrow(TestCase):
15 |     RESOURCES = pkg_resources.resource_filename(__package__, RESOURCES_FOLDER)
16 |
17 |     def setUp(self):
18 |         self.atoms = Atoms(symbols='HH', positions=[(0, 0, 0), (1.2, 0, 0)])
19 |         self.charge = 0
20 |         self.spin_multiplicity = 1
21 |
22 |     def test_calculator(self):
23 |         calculator = Sparrow('PM6')
24 |         calculator.set_elements(list(self.atoms.symbols))
25 |         calculator.set_positions(self.atoms.positions)
26 |         calculator.set_settings({'molecular_charge': 0, 'spin_multiplicity': 1})
27 |
28 |         gradients = calculator.calculate_gradients()
29 |         energy = calculator.calculate_energy()
30 |
31 |         self.assertAlmostEqual(energy, -0.9379853016)
32 |         self.assertEqual(gradients.shape, (2, 3))
33 |
34 |     def test_atomic_energies(self):
35 |         calculator = Sparrow('PM6')
36 |         calculator.set_positions([(0, 0, 0)])
37 |
38 |         calculator.set_elements(['H'])
39 |         calculator.set_settings({'molecular_charge': 0, 'spin_multiplicity': 2})
40 |         self.assertAlmostEqual(calculator.calculate_energy(), -0.4133180865)
41 |
42 |         calculator.set_elements(['C'])
43 |         calculator.set_settings({'molecular_charge': 0, 'spin_multiplicity': 1})
44 |         self.assertAlmostEqual(calculator.calculate_energy(), -4.162353543)
45 |
46 |         calculator.set_elements(['O'])
47 |         calculator.set_settings({'molecular_charge': 0, 'spin_multiplicity': 1})
48 |         self.assertAlmostEqual(calculator.calculate_energy(), -10.37062419)
49 |
50 |     def test_energy_gradients(self):
51 |         calculator = Sparrow('PM6')
52 |         atoms = ase.io.read(filename=os.path.join(self.RESOURCES, 'h2o.xyz'), format='xyz', index=0)
53 |         calculator.set_positions(atoms.positions)
54 |         calculator.set_elements(list(atoms.symbols))
55 |         calculator.set_settings({'molecular_charge': 0, 'spin_multiplicity': 1})
56 |
57 |         energy = calculator.calculate_energy()
58 |         gradients = calculator.calculate_gradients()
59 |
60 |         energy_file = os.path.join(self.RESOURCES, 'energy.dat')
61 |         expected_energy = float(np.genfromtxt(energy_file))
62 |         self.assertAlmostEqual(energy, expected_energy)
63 |
64 |         gradients_file = os.path.join(self.RESOURCES, 'gradients.dat')
65 |         expected_gradients = np.genfromtxt(gradients_file)
66 |         self.assertTrue(np.allclose(gradients, expected_gradients))
67 |

--------------------------------------------------------------------------------
/molgym/buffer_container.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import List
3 |
4 | import numpy as np
5 |
6 | from molgym.buffer import DynamicPPOBuffer
7 | from molgym.spaces import ObservationType
8 |
9 |
10 | class PPOBufferContainer:
11 |     def __init__(self, size: int, gamma: float, lam: float) -> None:
12 |         super().__init__()
13 |
14 |         self.gamma = gamma
15 |         self.lam = lam
16 |         self.size = size
17 |
18 |         self.buffers = [DynamicPPOBuffer(gamma=self.gamma, lam=self.lam) for _ in range(self.size)]
19 |
20 |         self.episodic_returns: List[float] = []
21 |         self.episode_lengths: List[int] = []
22 |
23 |     def get_num_episodes(self) -> int:
24 |         num_returns = len(self.episodic_returns)
25 |         assert num_returns == len(self.episode_lengths)
26 |         return num_returns
27 |
28 |     def store(
29 |         self,
30 |         observations: List[ObservationType],
31 |         actions: np.ndarray,
32 |         rewards: np.ndarray,
33 |         next_observations: List[ObservationType],
34 |         terminals: np.ndarray,
35 |         values: np.ndarray,
36 |         logps: np.ndarray,
37 |     ) -> None:
38 |         assert len(observations) == actions.shape[0] == rewards.shape[0] ==
len( 39 | next_observations) == terminals.shape[0] == values.shape[0] == logps.shape[0] == len(self.buffers) 40 | 41 | for i, buffer in enumerate(self.buffers): 42 | buffer.store( 43 | obs=observations[i], 44 | act=actions[i], 45 | reward=rewards[i], 46 | next_obs=next_observations[i], 47 | terminal=terminals[i], 48 | value=values[i], 49 | logp=logps[i], 50 | ) 51 | 52 | if terminals[i]: 53 | episodic_ret, episode_length = buffer.finish_path(0.0) 54 | assert episodic_ret is not None and episode_length > 0 55 | self.episodic_returns.append(episodic_ret) 56 | self.episode_lengths.append(episode_length) 57 | 58 | def finish_paths(self, values: np.ndarray): 59 | assert values.shape[0] == self.size 60 | 61 | for buffer, value in zip(self.buffers, values): 62 | # the buffer could be already finished so we have to check 63 | if not buffer.is_finished(): 64 | # Don't record unfinished paths 65 | buffer.finish_path(value) 66 | 67 | def merge(self) -> DynamicPPOBuffer: 68 | new = DynamicPPOBuffer(gamma=self.gamma, lam=self.lam) 69 | 70 | assert all(buffer.is_finished() for buffer in self.buffers) 71 | 72 | for field in DynamicPPOBuffer.BUFFER_FIELDS: 73 | setattr(new, field, list(itertools.chain.from_iterable(getattr(buffer, field) for buffer in self.buffers))) 74 | 75 | return new 76 | -------------------------------------------------------------------------------- /tests/test_environment.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from ase import Atom 4 | 5 | from molgym.environment import MolecularEnvironment 6 | from molgym.reward import InteractionReward 7 | from molgym.spaces import ObservationSpace, ActionSpace 8 | from molgym.tools.util import string_to_formula 9 | 10 | 11 | class TestEnvironment(TestCase): 12 | def setUp(self): 13 | self.reward = InteractionReward() 14 | self.zs = [0, 1, 6, 7, 8] 15 | self.observation_space = ObservationSpace(canvas_size=5, zs=self.zs) 16 | self.action_space = ActionSpace(zs=self.zs) 17 | 18 | def test_addition(self): 19 | formula = string_to_formula('H2CO') 20 | env = MolecularEnvironment(reward=self.reward, 21 | observation_space=self.observation_space, 22 | action_space=self.action_space, 23 | formulas=[formula]) 24 | action = self.action_space.from_atom(Atom(symbol='H', position=(0.0, 1.0, 0.0))) 25 | obs, reward, done, info = env.step(action=action) 26 | 27 | atoms1, formula = self.observation_space.parse(obs) 28 | 29 | self.assertEqual(atoms1[0].symbol, 'H') 30 | self.assertEqual(formula, ((0, 0), (1, 1), (6, 1), (7, 0), (8, 1))) 31 | self.assertEqual(reward, 0.0) 32 | self.assertFalse(done) 33 | 34 | def test_invalid_action(self): 35 | formula = string_to_formula('H2CO') 36 | env = MolecularEnvironment(reward=self.reward, 37 | observation_space=self.observation_space, 38 | action_space=self.action_space, 39 | formulas=[formula]) 40 | action = self.action_space.from_atom(Atom(symbol='N', position=(0, 1, 0))) 41 | with self.assertRaises(RuntimeError): 42 | env.step(action) 43 | 44 | def test_invalid_formula(self): 45 | formula = string_to_formula('He2') 46 | with self.assertRaises(AssertionError): 47 | self.observation_space.bag_space.from_formula(formula) 48 | 49 | def test_solo_distance(self): 50 | formula = string_to_formula('H2CO') 51 | env = MolecularEnvironment( 52 | reward=self.reward, 53 | observation_space=self.observation_space, 54 | action_space=self.action_space, 55 | formulas=[formula], 56 | max_solo_distance=1.0, 57 | ) 58 | 59 | # First H can be on its own 
60 |         action = self.action_space.from_atom(atom=Atom(symbol='H', position=(0, 0, 0)))
61 |         obs, reward, done, info = env.step(action=action)
62 |         self.assertFalse(done)
63 |
64 |         # Second H cannot
65 |         action = self.action_space.from_atom(atom=Atom(symbol='H', position=(0, 1.5, 0)))
66 |         obs, reward, done, info = env.step(action=action)
67 |         self.assertTrue(done)
68 |

--------------------------------------------------------------------------------
/tests/test_minimizer.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import numpy as np
4 | from ase import Atoms
5 |
6 | from molgym.calculator import Sparrow
7 | from molgym.minimizer import minimize
8 |
9 |
10 | class TestMinimizer(TestCase):
11 |     def setUp(self):
12 |         self.atoms = Atoms(symbols='OHH',
13 |                            positions=[
14 |                                (-0.27939703, 0.83823215, 0.00973345),
15 |                                (-0.52040310, 1.77677325, 0.21391146),
16 |                                (0.54473632, 0.90669722, -0.53501306),
17 |                            ])
18 |
19 |         self.charge = 0
20 |         self.spin_multiplicity = 1
21 |
22 |     def test_minimize(self):
23 |         calculator = Sparrow('PM6')
24 |
25 |         calculator.set_elements(list(self.atoms.symbols))
26 |         calculator.set_positions(self.atoms.positions)
27 |         calculator.set_settings({'molecular_charge': self.charge, 'spin_multiplicity': self.spin_multiplicity})
28 |         energy1 = calculator.calculate_energy()
29 |         gradients1 = calculator.calculate_gradients()
30 |
31 |         opt_atoms, success = minimize(calculator=calculator,
32 |                                       atoms=self.atoms,
33 |                                       charge=self.charge,
34 |                                       spin_multiplicity=self.spin_multiplicity)
35 |
36 |         calculator.set_positions(opt_atoms.positions)
37 |         energy2 = calculator.calculate_energy()
38 |         gradients2 = calculator.calculate_gradients()
39 |
40 |         self.assertTrue(energy1 > energy2)
41 |         self.assertTrue(np.sum(np.square(gradients1)) > np.sum(np.square(gradients2)))
42 |         self.assertTrue(np.all(np.abs(gradients2) < 1E-3))
43 |
44 |     def test_minimize_fail(self):
45 |         calculator = Sparrow('PM6')
46 |         calculator.set_elements(list(self.atoms.symbols))
47 |         calculator.set_positions(self.atoms.positions)
48 |         calculator.set_settings({'molecular_charge': self.charge, 'spin_multiplicity': self.spin_multiplicity})
49 |
50 |         opt_atoms, success = minimize(
51 |             calculator=calculator,
52 |             atoms=self.atoms,
53 |             charge=self.charge,
54 |             spin_multiplicity=self.spin_multiplicity,
55 |             max_iter=1,
56 |         )
57 |
58 |         self.assertFalse(success)
59 |
60 |     def test_minimize_fixed(self):
61 |         calculator = Sparrow('PM6')
62 |
63 |         calculator.set_elements(list(self.atoms.symbols))
64 |         calculator.set_positions(self.atoms.positions)
65 |         calculator.set_settings({'molecular_charge': self.charge, 'spin_multiplicity': self.spin_multiplicity})
66 |
67 |         fixed_index = 2
68 |         opt_atoms, success = minimize(
69 |             calculator=calculator,
70 |             atoms=self.atoms,
71 |             charge=self.charge,
72 |             spin_multiplicity=self.spin_multiplicity,
73 |             fixed_indices=[fixed_index],
74 |         )
75 |
76 |         self.assertTrue(np.all(np.abs((self.atoms.positions - opt_atoms.positions)[fixed_index]) < 1E-6))
77 |

--------------------------------------------------------------------------------
/molgym/calculator.py:
--------------------------------------------------------------------------------
1 | # Try loading energy computation backends.
2 | # For each one that is successful, set its entry in `calculators`.
3 |
4 | import numpy as np
5 |
6 | calculators = {"sparrow_v2": None, "sparrow_v3": None}
7 |
8 |
9 | class SparrowCalc:
10 |     """
11 |     Calculation object for sparrow v3.
12 | """ 13 | 14 | def __init__(self, method): 15 | self.calc = manager.get("calculator", method) 16 | self.calc.set_required_properties([su.Property.Energy, su.Property.Gradients]) 17 | self.elements = None 18 | self.positions = None 19 | 20 | def set_elements(self, codes): 21 | elems = [] 22 | for code in codes: 23 | if isinstance(code, str): 24 | code = getattr(su.ElementType, code) 25 | elems.append(code) 26 | 27 | self.elements = elems 28 | 29 | def set_positions(self, crd): 30 | self.positions = np.array(crd) * su.BOHR_PER_ANGSTROM 31 | 32 | def set_settings(self, attr): 33 | """ 34 | This routine will be called with `attr`: 35 | 36 | { 'unrestricted_calculation' : int 37 | 'spin_multiplicity' : int 38 | } 39 | 40 | Available attributes in self.calc.settings: 41 | 42 | molecular_charge 0 43 | spin_multiplicity 1 44 | spin_mode any 45 | temperature 298.15 46 | electronic_temperature 0.0 47 | symmetry_number 1 48 | self_consistence_criterion 1e-07 49 | density_rmsd_criterion 1e-05 50 | max_scf_iterations 100 51 | scf_mixer diis 52 | method_parameters 53 | nddo_dipole True 54 | """ 55 | # for k,v in self.calc.settings.items(): 56 | # print(k, v) 57 | for k, v in attr.items(): 58 | if k == "unrestricted_calculation": 59 | if v: 60 | self.calc.settings["spin_mode"] = "unrestricted" 61 | else: 62 | self.calc.settings["spin_mode"] = "restricted" 63 | continue 64 | try: 65 | self.calc.settings[k] = v 66 | except RuntimeError as e: 67 | print(f"Unable to set {k} = {v}: {e}") 68 | 69 | def _structure(self): 70 | structure = su.AtomCollection(len(self.elements)) 71 | structure.elements = self.elements 72 | structure.positions = self.positions 73 | return structure 74 | 75 | def calculate_energy(self): 76 | self.calc.structure = self._structure() 77 | return self.calc.calculate().energy 78 | 79 | def calculate_gradients(self): 80 | self.calc.structure = self._structure() 81 | return self.calc.calculate().gradients 82 | 83 | 84 | try: # try sparrow v2 85 | from scine_sparrow import Calculation 86 | 87 | calculators["sparrow_v2"] = Calculation 88 | except: # try sparrow v3 89 | import scine_utilities as su 90 | import scine_sparrow 91 | 92 | manager = su.core.ModuleManager() 93 | calculators["sparrow_v3"] = SparrowCalc 94 | 95 | # Use the first loaded backend. 96 | for k, v in calculators.items(): 97 | if v is not None: 98 | calculator = k 99 | Sparrow = v 100 | break 101 | -------------------------------------------------------------------------------- /molgym/reward.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import time 3 | from typing import Tuple, Dict 4 | 5 | import ase.data 6 | import numpy as np 7 | from ase import Atoms, Atom 8 | 9 | from molgym.calculator import Sparrow 10 | 11 | 12 | class MolecularReward(abc.ABC): 13 | @abc.abstractmethod 14 | def calculate(self, atoms: Atoms, new_atom: Atom) -> Tuple[float, dict]: 15 | raise NotImplementedError 16 | 17 | @staticmethod 18 | def get_minimum_spin_multiplicity(atoms) -> int: 19 | return sum(ase.data.atomic_numbers[atom.symbol] for atom in atoms) % 2 + 1 20 | 21 | 22 | class InteractionReward(MolecularReward): 23 | def __init__(self) -> None: 24 | # Due to some mysterious bug in Sparrow, calculations get slower and slower over time. 25 | # Therefore, we generate a new Sparrow object every time. 
26 | self.calculator = Sparrow('PM6') 27 | 28 | self.settings = { 29 | 'molecular_charge': 0, 30 | 'max_scf_iterations': 128, 31 | 'unrestricted_calculation': 1, 32 | } 33 | 34 | self.atom_energies: Dict[str, float] = {} 35 | 36 | def calculate(self, atoms: Atoms, new_atom: Atom) -> Tuple[float, dict]: 37 | start = time.time() 38 | self.calculator = Sparrow('PM6') 39 | 40 | all_atoms = atoms.copy() 41 | all_atoms.append(new_atom) 42 | 43 | e_tot = self._calculate_energy(all_atoms) 44 | e_parts = self._calculate_energy(atoms) + self._calculate_atomic_energy(new_atom) 45 | delta_e = e_tot - e_parts 46 | 47 | elapsed = time.time() - start 48 | 49 | reward = -1 * delta_e 50 | 51 | info = { 52 | 'elapsed_time': elapsed, 53 | } 54 | 55 | return reward, info 56 | 57 | def _calculate_atomic_energy(self, atom: Atom) -> float: 58 | if atom.symbol not in self.atom_energies: 59 | atoms = Atoms() 60 | atoms.append(atom) 61 | self.atom_energies[atom.symbol] = self._calculate_energy(atoms) 62 | return self.atom_energies[atom.symbol] 63 | 64 | def _calculate_energy(self, atoms: Atoms) -> float: 65 | if len(atoms) == 0: 66 | return 0.0 67 | 68 | self.calculator.set_elements(list(atoms.symbols)) 69 | self.calculator.set_positions(atoms.positions) 70 | self.settings['spin_multiplicity'] = self.get_minimum_spin_multiplicity(atoms) 71 | self.calculator.set_settings(self.settings) 72 | return self.calculator.calculate_energy() 73 | 74 | 75 | class SolvationReward(InteractionReward): 76 | def __init__(self, distance_penalty=0.01) -> None: 77 | super().__init__() 78 | 79 | self.distance_penalty = distance_penalty 80 | 81 | def calculate(self, atoms: Atoms, new_atom: Atom) -> Tuple[float, dict]: 82 | start_time = time.time() 83 | self.calculator = Sparrow('PM6') 84 | 85 | all_atoms = atoms.copy() 86 | all_atoms.append(new_atom) 87 | 88 | e_tot = self._calculate_energy(all_atoms) 89 | e_parts = self._calculate_energy(atoms) + self._calculate_atomic_energy(new_atom) 90 | delta_e = e_tot - e_parts 91 | 92 | distance = np.linalg.norm(new_atom.position) 93 | 94 | reward = -1 * (delta_e + self.distance_penalty * distance) 95 | 96 | info = { 97 | 'elapsed_time': time.time() - start_time, 98 | } 99 | 100 | return reward, info 101 | -------------------------------------------------------------------------------- /tests/test_spaces.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | from ase import Atoms, Atom 5 | 6 | from molgym.spaces import ObservationSpace, CanvasItemSpace, ActionSpace, CanvasSpace, BagSpace 7 | 8 | 9 | class TestAtomicSpace(TestCase): 10 | def test_atom(self): 11 | space = CanvasItemSpace(zs=[0, 1, 6]) 12 | sample = space.sample() 13 | element, position = sample 14 | 15 | self.assertTrue(isinstance(element, int)) 16 | self.assertEqual(len(position), 3) 17 | 18 | atom = space.to_atom(sample) 19 | tup = space.from_atom(atom) 20 | 21 | self.assertEqual(element, tup[0]) 22 | self.assertTrue(np.isclose(position, tup[1]).all()) 23 | 24 | def test_invalid_atom(self): 25 | space = CanvasItemSpace(zs=[1, 6]) 26 | 27 | with self.assertRaises(IndexError): 28 | space.to_atom((2, (0, 0, 0))) 29 | 30 | with self.assertRaises(RuntimeError): 31 | space.to_atom((-1, (0, 0, 0))) 32 | 33 | with self.assertRaises(ValueError): 34 | space.to_atom((1, ('H', 0, 0))) # type: ignore 35 | 36 | 37 | class TestMolecularSpace(TestCase): 38 | def test_atoms(self): 39 | space = CanvasSpace(size=5, zs=[0, 1]) 40 | sample = space.sample() 
41 | 42 | self.assertEqual(len(sample), 5) 43 | 44 | atoms = space.to_atoms(sample) 45 | self.assertLessEqual(len(atoms), 5) 46 | 47 | atoms = Atoms() 48 | tup = space.from_atoms(atoms) 49 | 50 | self.assertEqual(len(tup), 5) 51 | for element, position in tup: 52 | self.assertEqual(element, 0) 53 | 54 | parsed = space.to_atoms(tup) 55 | self.assertEqual(len(parsed), 0) 56 | 57 | def test_invalid_atoms(self): 58 | space = CanvasSpace(size=2, zs=[0, 1]) 59 | atoms = Atoms(symbols='HHH') 60 | with self.assertRaises(RuntimeError): 61 | space.from_atoms(atoms) 62 | 63 | 64 | class TestBagSpace(TestCase): 65 | def setUp(self): 66 | self.atomic_numbers = [1, 6, 7, 8] 67 | self.bag_space = BagSpace(self.atomic_numbers) 68 | 69 | def test_bag(self): 70 | for item in self.bag_space.sample(): 71 | self.assertIsInstance(item, int) 72 | self.assertEqual(len(self.bag_space.sample()), len(self.atomic_numbers)) 73 | 74 | def test_invalid_bag(self): 75 | with self.assertRaises(AssertionError): 76 | self.bag_space.to_formula((1, 0, 0, 0, 0)) 77 | 78 | 79 | class TestActionSpace(TestCase): 80 | def setUp(self) -> None: 81 | self.space = ActionSpace(zs=[0, 1, 6]) 82 | 83 | def test_action(self): 84 | self.assertIsNone(self.space.shape) 85 | 86 | def test_shape(self): 87 | action = self.space.sample() 88 | self.assertEqual(len(action), 2) 89 | self.assertEqual(len(action[1]), 3) 90 | 91 | def test_build(self): 92 | symbol = 'C' 93 | action = self.space.from_atom(Atom(symbol=symbol)) 94 | self.assertEqual(action[0], 2) 95 | self.assertEqual(self.space.to_atom(action).symbol, symbol) 96 | 97 | 98 | class TestObservationSpace(TestCase): 99 | def setUp(self): 100 | self.atomic_numbers = [0, 1, 6, 7] 101 | 102 | def test_molecular_observation(self): 103 | space = ObservationSpace(canvas_size=5, zs=self.atomic_numbers) 104 | canvas, bag = space.sample() 105 | 106 | self.assertEqual(len(canvas), 5) 107 | self.assertEqual(len(bag), len(self.atomic_numbers)) 108 | -------------------------------------------------------------------------------- /molgym/agents/internal/zmat.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | 6 | def get_distance(p_i: np.ndarray, p_j: np.ndarray) -> float: 7 | """ 8 | Compute distance between points i and j 9 | 10 | :param p_i: point i 11 | :param p_j: point j 12 | :return: distance 13 | """ 14 | return np.sqrt(np.sum(np.square(p_i - p_j))) 15 | 16 | 17 | def get_angle(p_i: np.ndarray, p_j: np.ndarray, p_k: np.ndarray) -> float: 18 | """ 19 | Compute angle between points i, j, and k 20 | 21 | :param p_i: point i 22 | :param p_j: point j 23 | :param p_k: point k 24 | :return: angle in radians 25 | """ 26 | rij = p_i - p_j 27 | rkj = p_k - p_j 28 | 29 | sin_theta = np.linalg.norm(np.cross(rij, rkj)) 30 | cos_theta = np.dot(rij, rkj) 31 | return np.arctan2(sin_theta, cos_theta) 32 | 33 | 34 | def get_dihedral(p_i: np.ndarray, p_j: np.ndarray, p_k: np.ndarray, p_l: np.ndarray) -> float: 35 | """ 36 | Return dihedral between points i, j, k, and l. 
37 |
38 |     :param p_i: point i
39 |     :param p_j: point j
40 |     :param p_k: point k
41 |     :param p_l: point l
42 |     :return: dihedral angle in radians
43 |     """
44 |     r_ji = p_j - p_i
45 |     r_kj = p_k - p_j
46 |     r_lk = p_l - p_k
47 |
48 |     v1 = np.cross(r_ji, r_kj)
49 |     v1 = v1 / np.linalg.norm(v1)
50 |
51 |     v2 = np.cross(r_lk, r_kj)
52 |     v2 = v2 / np.linalg.norm(v2)
53 |
54 |     m1 = np.cross(v1, r_kj) / np.linalg.norm(r_kj)
55 |
56 |     x = np.dot(v1, v2)
57 |     y = np.dot(m1, v2)
58 |
59 |     psi = np.arctan2(y, x)
60 |     if psi < 0:
61 |         return -psi - np.pi
62 |     else:
63 |         return np.pi - psi
64 |
65 |
66 | def position_point(p0: np.ndarray, p1: np.ndarray, p2: np.ndarray, distance: float, angle: float,
67 |                    dihedral: float) -> np.ndarray:
68 |     """
69 |     Determine the point p in space that is:
70 |     - at distance <distance> from p2
71 |     - at angle <angle> formed with p2 and p1
72 |     - at dihedral angle <dihedral> formed with p2, p1, and p0
73 |
74 |     :param p0: position for dihedral
75 |     :param p1: position for angle
76 |     :param p2: position for distance
77 |     :param distance: distance between p and p2
78 |     :param angle: angle between p, p2 and p1
79 |     :param dihedral: dihedral angle between p, p2, p1, and p0
80 |     :return: coordinates of p
81 |     """
82 |     x = distance * np.cos(angle)
83 |     y = distance * np.cos(dihedral) * np.sin(angle)
84 |     z = distance * np.sin(dihedral) * np.sin(angle)
85 |
86 |     v_a = p1 - p0
87 |
88 |     v_b = p2 - p1
89 |     v_b = v_b / np.linalg.norm(v_b)
90 |
91 |     c_ab = np.cross(v_a, v_b)
92 |     c_ab = c_ab / np.linalg.norm(c_ab)
93 |
94 |     c_ab_b = np.cross(c_ab, v_b)
95 |
96 |     return p2 - v_b * x + c_ab_b * y + c_ab * z
97 |
98 |
99 | def position_atom_helper(
100 |         positions: List[np.ndarray],
101 |         focus: int,
102 |         distance: float,
103 |         angle: float,
104 |         dihedral: float,
105 | ) -> np.ndarray:
106 |     if len(positions) == 0:
107 |         return np.array([0, 0, 0], dtype=np.float)
108 |
109 |     if focus >= len(positions):
110 |         raise RuntimeError('Focus greater than number of atoms')
111 |
112 |     focal_point = positions[focus]
113 |     sorted_positions = sorted(positions, key=lambda p: get_distance(p, focal_point))
114 |
115 |     p_aux_1 = np.array([1, 0, 0], dtype=np.float)
116 |     p_aux_0 = np.array([0, 1, 0], dtype=np.float)
117 |
118 |     if len(positions) == 1:
119 |         p2 = sorted_positions[0]
120 |         p1 = p2 + p_aux_1
121 |         p0 = p2 + p_aux_0
122 |
123 |     elif len(positions) == 2:
124 |         p2 = sorted_positions[0]
125 |         p1 = sorted_positions[1]
126 |         p0 = p2 + p1 + p_aux_0 + p_aux_1
127 |
128 |     else:
129 |         p2 = sorted_positions[0]
130 |         p1 = sorted_positions[1]
131 |         p0 = sorted_positions[2]
132 |
133 |     return position_point(p0, p1, p2, distance=distance, angle=angle, dihedral=dihedral)
134 |
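A quick consistency check ties the internal-coordinate helpers together: a point placed with position_point should read back the same distance and angle (the dihedral is recovered only up to the sign convention used in get_dihedral). A small sketch with arbitrary reference points:

    import numpy as np
    from molgym.agents.internal.zmat import get_angle, get_distance, position_point

    p0, p1, p2 = np.array([0.0, 1.0, 0.0]), np.array([1.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0])
    p = position_point(p0, p1, p2, distance=1.5, angle=np.pi / 2, dihedral=np.pi / 3)

    assert np.isclose(get_distance(p, p2), 1.5)
    assert np.isclose(get_angle(p, p2, p1), np.pi / 2)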
17 | 18 | NULL_SYMBOL = 'X' 19 | 20 | 21 | class CanvasItemSpace(gym.spaces.Tuple): 22 | def __init__(self, zs: List[int]) -> None: 23 | self.zs = zs 24 | 25 | label = gym.spaces.Discrete(n=len(zs)) 26 | 27 | low = np.array([-np.inf, -np.inf, -np.inf], dtype=np.float) 28 | high = np.array([np.inf, np.inf, np.inf], dtype=np.float) 29 | position = gym.spaces.Box(low=low, high=high, dtype=np.float) 30 | 31 | super().__init__((label, position)) 32 | 33 | def to_atom(self, canvas_item: CanvasItemType) -> Atom: 34 | label, position = canvas_item 35 | if label < 0: 36 | raise RuntimeError(f'Invalid atomic number: {label}') 37 | 38 | return Atom(symbol=self.zs[label], position=position) 39 | 40 | def from_atom(self, atom: Atom) -> CanvasItemType: 41 | return self.zs.index(ase.data.atomic_numbers[atom.symbol]), tuple(atom.position) # type: ignore 42 | 43 | 44 | ActionSpace = CanvasItemSpace 45 | 46 | 47 | class CanvasSpace(gym.spaces.Tuple): 48 | def __init__(self, size: int, zs: List[int]) -> None: 49 | assert 0 in zs, '0 has to be in the list of atomic numbers' 50 | self.size = size 51 | self.zs = zs 52 | self.canvas_item_space = CanvasItemSpace(zs) 53 | super().__init__((self.canvas_item_space, ) * self.size) 54 | 55 | def to_atoms(self, canvas: CanvasType) -> Atoms: 56 | atoms = Atoms() 57 | for canvas_item in canvas: 58 | atom = self.canvas_item_space.to_atom(canvas_item) 59 | if atom.symbol != NULL_SYMBOL: 60 | atoms.append(atom) 61 | return atoms 62 | 63 | def from_atoms(self, atoms: Atoms) -> CanvasType: 64 | if len(atoms) > self.size: 65 | raise RuntimeError(f'Too many atoms: {len(atoms)} > {self.size}') 66 | 67 | elif len(atoms) < self.size: 68 | atoms = atoms.copy() 69 | 70 | dummy = Atom(symbol=NULL_SYMBOL, position=(0, 0, 0)) 71 | while len(atoms) < self.size: 72 | atoms.append(dummy) 73 | 74 | return tuple(self.canvas_item_space.from_atom(atom) for atom in atoms) 75 | 76 | 77 | class BagSpace(gym.spaces.Tuple): 78 | def __init__(self, zs: List[int]): 79 | self.zs = zs 80 | self.size = len(zs) 81 | self.bag_item_space = gym.spaces.Discrete(n=sys.maxsize) 82 | 83 | super().__init__((self.bag_item_space, ) * self.size) 84 | 85 | def to_formula(self, bag: BagType) -> FormulaType: 86 | assert len(bag) == self.size 87 | return tuple(zip(self.zs, bag)) 88 | 89 | def from_formula(self, formula: FormulaType) -> BagType: 90 | assert all(z in self.zs for z, count in formula) 91 | formula_dict: Dict[int, int] = defaultdict(int) 92 | formula_dict.update(formula) 93 | return tuple(formula_dict[z] for z in self.zs) 94 | 95 | 96 | class ObservationSpace(gym.spaces.Tuple): 97 | def __init__(self, canvas_size: int, zs: List[int]): 98 | self.zs = zs 99 | self.canvas_space = CanvasSpace(size=canvas_size, zs=zs) 100 | self.bag_space = BagSpace(zs=zs) 101 | super().__init__((self.canvas_space, self.bag_space)) 102 | 103 | def build(self, atoms: Atoms, formula: FormulaType) -> ObservationType: 104 | return self.canvas_space.from_atoms(atoms), self.bag_space.from_formula(formula) 105 | 106 | def parse(self, observation: ObservationType) -> Tuple[Atoms, FormulaType]: 107 | return self.canvas_space.to_atoms(observation[0]), self.bag_space.to_formula(observation[1]) 108 | -------------------------------------------------------------------------------- /scripts/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import List 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from 
molgym.tools.analysis import parse_json_lines_file, parse_results_filename, collect_results_paths 10 | 11 | # Styling 12 | fig_width = 0.45 * 5.50107 13 | fig_height = 2.1 14 | 15 | plt.style.use('ggplot') 16 | plt.rcParams.update({'font.size': 6}) 17 | 18 | colors = [ 19 | '#1f77b4', # muted blue 20 | '#d62728', # brick red 21 | '#ff7f0e', # safety orange 22 | '#2ca02c', # cooked asparagus green 23 | '#9467bd', # muted purple 24 | '#8c564b', # chestnut brown 25 | '#e377c2', # raspberry yogurt pink 26 | '#7f7f7f', # middle gray 27 | '#bcbd22', # curry yellow-green 28 | '#17becf', # blue-teal 29 | ] 30 | 31 | 32 | def parse_args() -> argparse.Namespace: 33 | parser = argparse.ArgumentParser(description='Plot MolGym output') 34 | 35 | parser.add_argument('--dir', help='path to results directory (repeatable)', required=True, action='append') 36 | parser.add_argument('--baseline', help='baseline (repeatable)', required=False, action='append') 37 | parser.add_argument('--max_num_steps', help='analyse up to maximum number of steps', required=False, type=int) 38 | parser.add_argument('--min_num_steps', help='analyse after minimum number of steps', required=False, type=int) 39 | parser.add_argument('--mode', 40 | help='train or eval mode', 41 | required=False, 42 | type=str, 43 | choices=['train', 'eval'], 44 | default='eval') 45 | 46 | return parser.parse_args() 47 | 48 | 49 | def get_data(directories: List[str], mode: str) -> pd.DataFrame: 50 | paths = [] 51 | for directory in directories: 52 | paths += collect_results_paths(directory=directory, mode=mode) 53 | 54 | assert len(paths) > 0 55 | 56 | frames = [] 57 | for path in paths: 58 | df = pd.DataFrame(parse_json_lines_file(path)) 59 | 60 | info = parse_results_filename(os.path.basename(path)) 61 | df['seed'] = info['seed'] 62 | df['name'] = info['name'] 63 | df['mode'] = info['mode'] 64 | 65 | frames.append(df) 66 | 67 | data = pd.concat(frames) 68 | 69 | # Compute average and std over seeds 70 | data = data.groupby(['name', 'mode', 'total_num_steps']).agg([np.mean, np.std]).reset_index() 71 | 72 | return data 73 | 74 | 75 | def main() -> None: 76 | args = parse_args() 77 | data = get_data(directories=args.dir, mode=args.mode) 78 | 79 | if args.max_num_steps: 80 | data = data[data['total_num_steps'] <= args.max_num_steps] 81 | 82 | if args.min_num_steps: 83 | data = data[data['total_num_steps'] >= args.min_num_steps] 84 | 85 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(fig_width, fig_height), constrained_layout=True) 86 | color_iter = iter(colors) 87 | 88 | prop = 'return_mean' 89 | for j, (name, group) in enumerate(data.groupby('name')): 90 | color = next(color_iter) 91 | 92 | if group[prop]['mean'].isna().all(): 93 | continue 94 | ax.plot( 95 | group['total_num_steps'] / 1000, 96 | group[prop]['mean'], 97 | zorder=2 * j + 3, 98 | label=name, 99 | color=color, 100 | ) 101 | ax.fill_between( 102 | x=group['total_num_steps'] / 1000, 103 | y1=group[prop]['mean'] - group[prop]['std'], 104 | y2=group[prop]['mean'] + group[prop]['std'], 105 | alpha=0.5, 106 | zorder=2 * j + 2, 107 | color=color, 108 | ) 109 | 110 | color_iter = iter(colors) 111 | if args.baseline: 112 | for baseline in args.baseline: 113 | color = next(color_iter) 114 | ax.axhline(float(baseline), color=color, linestyle='dashed', zorder=1) 115 | 116 | ax.set_ylabel('Average Return') 117 | ax.set_xlabel('Steps x 1000') 118 | 119 | ax.legend(loc='lower right') 120 | 121 | fig.savefig('average_return.pdf') 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | 
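# A note on get_data above: aggregating with a list of functions, as in
# agg([np.mean, np.std]), yields a column MultiIndex, which is why the plotting
# loop indexes group[prop]['mean'] and group[prop]['std']. A minimal,
# self-contained sketch of that pattern (synthetic data, not actual MolGym
# output):
#
#     df = pd.DataFrame({
#         'name': ['SF6'] * 4,
#         'total_num_steps': [100, 200, 100, 200],
#         'return_mean': [0.10, 0.40, 0.30, 0.60],
#     })
#     agg = df.groupby(['name', 'total_num_steps']).agg([np.mean, np.std]).reset_index()
#     agg['return_mean']['mean']  # average over seeds at each step count
#     agg['return_mean']['std']   # spread over seeds at each step count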
-------------------------------------------------------------------------------- /molgym/env_container.py: --------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Tuple
3 |
4 | import gym
5 | import numpy as np
6 |
7 | # The class is based on: Baselines https://github.com/openai/baselines.
8 | from molgym.spaces import ObservationType
9 |
10 |
11 | class VecEnv(ABC):
12 | """
13 | An abstract asynchronous, vectorized environment.
14 | Used to batch data from multiple copies of an environment, so that
15 | each observation becomes a batch of observations, and each expected action is a batch of actions
16 | applied per environment.
17 | """
18 | @abstractmethod
19 | def reset(self) -> List[ObservationType]:
20 | """
21 | Reset all the environments and return an array of
22 | observations, or a dict of observation arrays.
23 |
24 | If step_async is still doing work, that work will
25 | be cancelled and step_wait() should not be called
26 | until step_async() is invoked again.
27 | """
28 | raise NotImplementedError
29 |
30 | @abstractmethod
31 | def step_async(self, actions) -> None:
32 | """
33 | Tell all the environments to start taking a step
34 | with the given actions.
35 | Call step_wait() to get the results of the step.
36 |
37 | You should not call this if a step_async run is
38 | already pending.
39 | """
40 | raise NotImplementedError
41 |
42 | @abstractmethod
43 | def step_wait(self) -> Tuple[List[ObservationType], np.ndarray, np.ndarray, List[dict]]:
44 | """
45 | Wait for the step taken with step_async().
46 |
47 | Returns (obs, rews, dones, infos):
48 | - obs: an array of observations, or a dict of
49 | arrays of observations.
50 | - rews: an array of rewards
51 | - dones: an array of "episode done" booleans
52 | - infos: a sequence of info objects
53 | """
54 | raise NotImplementedError
55 |
56 | def step(self, actions) -> Tuple[List[ObservationType], np.ndarray, np.ndarray, List[dict]]:
57 | """
58 | Step the environments synchronously.
59 |
60 | This is available for backwards compatibility.
61 | """
62 | self.step_async(actions)
63 | return self.step_wait()
64 |
65 | def render(self, mode='human'):
66 | raise NotImplementedError
67 |
68 | @abstractmethod
69 | def get_size(self) -> int:
70 | raise NotImplementedError
71 |
72 | @abstractmethod
73 | def reset_if_terminal(self, observations: List[ObservationType], terminals: List[bool]):
74 | raise NotImplementedError
75 |
76 |
77 | # This class is based on: DeepRL https://github.com/ShangtongZhang/DeepRL.
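# Illustrative use of the VecEnv protocol defined above (hypothetical envs
# instance, e.g. the SimpleEnvContainer below):
#
#     observations = envs.reset()
#     envs.step_async(actions)  # one action per environment
#     observations, rewards, dones, infos = envs.step_wait()
#
# or, equivalently, via the synchronous shortcut:
#
#     observations, rewards, dones, infos = envs.step(actions)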
78 | class SimpleEnvContainer(VecEnv):
79 | def __init__(self, environments: List[gym.Env]):
80 | super().__init__()
81 | self.environments = environments
82 |
83 | self.actions = None
84 |
85 | def step_async(self, actions: np.ndarray) -> None:
86 | self.actions = actions
87 |
88 | def step_wait(self):
89 | assert self.actions is not None and len(self.environments) == len(self.actions)
90 |
91 | data = []
92 | for env, action in zip(self.environments, self.actions):
93 | obs, reward, done, info = env.step(action)
94 | data.append([obs, reward, done, info])
95 |
96 | obs_list, rewards, done_list, infos = zip(*data)
97 | return obs_list, np.array(rewards), np.array(done_list), infos
98 |
99 | def reset(self):
100 | return [env.reset() for env in self.environments]
101 |
102 | def reset_if_terminal(self, observations: List[ObservationType], terminals: List[bool]) -> List[ObservationType]:
103 | assert len(self.environments) == len(observations) == len(terminals)
104 |
105 | new_observations = []
106 | for env, observation, terminal in zip(self.environments, observations, terminals):
107 | if terminal:
108 | new_observations.append(env.reset())
109 | else:
110 | new_observations.append(observation)
111 |
112 | return new_observations
113 |
114 | def get_size(self) -> int:
115 | return len(self.environments)
116 |
117 | def close(self):
118 | pass
119 |
120 | def render(self, mode='human'):
121 | raise NotImplementedError
122 | -------------------------------------------------------------------------------- /molgym/tools/model_util.py: --------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import re
4 | from dataclasses import dataclass
5 | from typing import Tuple, Optional, Sequence
6 |
7 | import torch
8 |
9 | from molgym.agents.base import AbstractActorCritic
10 | from molgym.agents.covariant.agent import CovariantAC
11 | from molgym.agents.internal.agent import SchNetAC
12 | from molgym.spaces import ObservationSpace, ActionSpace
13 |
14 |
15 | def build_model(config: dict, observation_space: ObservationSpace, action_space: ActionSpace,
16 | device: torch.device) -> AbstractActorCritic:
17 | if config['model'] == 'internal':
18 | return SchNetAC(
19 | observation_space=observation_space,
20 | action_space=action_space,
21 | min_max_distance=(config['min_mean_distance'], config['max_mean_distance']),
22 | network_width=config['network_width'],
23 | device=device,
24 | )
25 | elif config['model'] == 'covariant':
26 | return CovariantAC(
27 | observation_space=observation_space,
28 | action_space=action_space,
29 | min_max_distance=(config['min_mean_distance'], config['max_mean_distance']),
30 | network_width=config['network_width'],
31 | maxl=config['maxl'],
32 | num_cg_levels=config['num_cg_levels'],
33 | num_channels_hidden=config['num_channels_hidden'],
34 | num_channels_per_element=config['num_channels_per_element'],
35 | num_gaussians=config['num_gaussians'],
36 | bag_scale=config['bag_scale'],
37 | beta=float(config['beta']) if config['beta'] is not None else config['beta'],
38 | device=device,
39 | )
40 | else:
41 | raise RuntimeError(f'Model \'{config["model"]}\' is not available.')
42 |
43 |
44 | @dataclass
45 | class ModelPathInfo:
46 | path: str
47 | tag: str
48 | num_steps: int
49 |
50 |
51 | class ModelIO:
52 | def __init__(self, directory: str, tag: str, keep: bool = False) -> None:
53 | self.directory = directory
54 | self.tag = tag
55 | self.keep = keep
56 | self.old_path: Optional[str] = None
57 |
58 | self._steps_string =
'_steps-'
59 | self._suffix = '.model'
60 | self._iter_suffix = '.txt'
61 |
62 | def _get_model_filename(self, num_steps: int) -> str:
63 | return self.tag + self._steps_string + str(num_steps) + self._suffix
64 |
65 | def _list_file_paths(self) -> Sequence[str]:
66 | all_paths = [os.path.join(self.directory, f) for f in os.listdir(self.directory)]
67 | return [path for path in all_paths if os.path.isfile(path)]
68 |
69 | def _parse_model_path(self, path: str) -> Optional[ModelPathInfo]:
70 | filename = os.path.basename(path)
71 | regex = re.compile(rf'(?P<tag>.+){self._steps_string}(?P<num_steps>\d+){self._suffix}')
72 | match = regex.match(filename)
73 | if not match:
74 | return None
75 |
76 | return ModelPathInfo(
77 | path=path,
78 | tag=match.group('tag'),
79 | num_steps=int(match.group('num_steps')),
80 | )
81 |
82 | def save(self, module: AbstractActorCritic, num_steps: int) -> None:
83 | if not self.keep and self.old_path:
84 | logging.debug(f'Deleting old model: {self.old_path}')
85 | os.remove(self.old_path)
86 |
87 | filename = self._get_model_filename(num_steps)
88 | path = os.path.join(self.directory, filename)
89 | logging.debug(f'Saving model: {path}')
90 | torch.save(obj=module, f=path)
91 | self.old_path = path
92 |
93 | def load(self, device: torch.device, path: str) -> Tuple[AbstractActorCritic, int]:
94 | model_info = self._parse_model_path(path)
95 |
96 | if model_info is None:
97 | raise RuntimeError(f"Cannot parse model path '{path}'")
98 |
99 | logging.info(f'Loading model: {model_info.path}')
100 | model = torch.load(f=model_info.path, map_location=device)
101 |
102 | return model, model_info.num_steps
103 |
104 | def load_latest(self, device: torch.device) -> Tuple[AbstractActorCritic, int]:
105 | all_file_paths = self._list_file_paths()
106 | model_infos = [self._parse_model_path(path) for path in all_file_paths]
107 | selected_model_infos = [info for info in model_infos if info and info.tag == self.tag]
108 |
109 | if len(selected_model_infos) == 0:
110 | raise RuntimeError(f"Cannot find model to load in '{self.directory}'")
111 |
112 | latest_model_info = max(selected_model_infos, key=lambda info: info.num_steps)
113 |
114 | logging.info(f'Loading model: {latest_model_info.path}')
115 | model = torch.load(f=latest_model_info.path, map_location=device)
116 |
117 | return model, latest_model_info.num_steps
118 | -------------------------------------------------------------------------------- /tests/agents/internal/test_zmat.py: --------------------------------------------------------------------------------
1 | import io
2 | from unittest import TestCase
3 |
4 | import ase.io
5 | import numpy as np
6 |
7 | from molgym.agents.internal.zmat import get_distance, get_angle, get_dihedral, position_point
8 |
9 |
10 | class TestZMat(TestCase):
11 | def test_distance(self):
12 | p1 = np.array([0, 0, 0], dtype=np.float)
13 | p2 = np.array([0, 1, 0], dtype=np.float)
14 | p3 = np.array([1, 0, 0], dtype=np.float)
15 |
16 | self.assertAlmostEqual(get_distance(p1, p1), 0)
17 | self.assertAlmostEqual(get_distance(p1, p2), 1)
18 | self.assertAlmostEqual(get_distance(p1, p3), 1)
19 | self.assertAlmostEqual(get_distance(p2, p3), np.sqrt(2))
20 |
21 | def test_angle(self):
22 | p1 = np.array([1, 0, 0], dtype=np.float)
23 | p2 = np.array([0, 0, 0], dtype=np.float)
24 | p3 = np.array([0, 1, 0], dtype=np.float)
25 | p4 = np.array([-1, 0, 0], dtype=np.float)
26 |
27 | self.assertAlmostEqual(get_angle(p1, p2, p1), 0)
28 | self.assertAlmostEqual(get_angle(p1, p2, p3), np.pi / 2)
29 |
self.assertAlmostEqual(get_angle(p1, p2, p4), np.pi) 30 | 31 | def test_dihedral(self): 32 | p1 = np.array([0, 0, 1.5], dtype=np.float) 33 | p2 = np.array([0, 0, 0], dtype=np.float) 34 | p3 = np.array([0, 0.5, 0], dtype=np.float) 35 | 36 | for psi in np.arange(start=-np.pi, stop=np.pi, step=np.pi / 17): 37 | p4 = np.array([np.sin(psi), 0.5, np.cos(psi)], dtype=np.float) 38 | dihedral = get_dihedral(p1, p2, p3, p4) 39 | self.assertAlmostEqual(psi, dihedral) 40 | 41 | def test_dihedral_2(self): 42 | p1 = np.array([0, 0, 1.5], dtype=np.float) 43 | p2 = np.array([0, 0, 0], dtype=np.float) 44 | p3 = np.array([0, 0.5, 0], dtype=np.float) 45 | 46 | # Add delta so that the corner case of -180, 180 goes away 47 | delta = 1E-4 48 | for psi in np.arange(start=-np.pi + delta, stop=np.pi - delta, step=np.pi / 17): 49 | p4 = np.array([np.sin(2 * np.pi + psi), 0.5, np.cos(2 * np.pi + psi)], dtype=np.float) 50 | dihedral = get_dihedral(p1, p2, p3, p4) 51 | self.assertAlmostEqual(psi, dihedral) 52 | 53 | def test_dihedral_sign(self): 54 | p0 = np.array([0, 0, 1], dtype=np.float) 55 | p1 = np.array([0, 0, 0], dtype=np.float) 56 | p2 = np.array([0, 1, 0], dtype=np.float) 57 | 58 | p3_1 = np.array([1, 0, 0], dtype=np.float) 59 | dihedral = get_dihedral(p0, p1, p2, p3_1) 60 | self.assertEqual(dihedral, np.pi / 2) 61 | 62 | p3_2 = np.array([-1, 0, 0], dtype=np.float) 63 | dihedral = get_dihedral(p0, p1, p2, p3_2) 64 | self.assertEqual(dihedral, -np.pi / 2) 65 | 66 | def test_dihedral_nan(self): 67 | string = '4\n\nC 0.5995394918 0.0 1.0\nC -0.5995394918 0.0 1.0\nH -1.6616385861 0.0 1.0\nH 1.6616385861 0.0 1.0' 68 | atoms = ase.io.read(io.StringIO(string), format='xyz') 69 | dihedral = get_dihedral(*(a.position for a in atoms)) 70 | self.assertTrue(np.isnan(dihedral)) 71 | 72 | def test_positioning(self): 73 | p0 = np.array([0, 0, 1], dtype=np.float) 74 | p1 = np.array([0, 0, 0], dtype=np.float) 75 | p2 = np.array([0, 1, 0], dtype=np.float) 76 | 77 | distance = 2.5 78 | angle = 2 * np.pi / 3 79 | 80 | # Add delta so that the corner case of -180, 180 goes away 81 | delta = 1E-4 82 | for psi in np.arange(start=-np.pi + delta, stop=np.pi - delta, step=np.pi / 17): 83 | p_new = position_point(p0=p0, p1=p1, p2=p2, distance=distance, angle=angle, dihedral=psi) 84 | 85 | self.assertAlmostEqual(get_distance(p2, p_new), distance) 86 | self.assertAlmostEqual(get_angle(p1, p2, p_new), angle) 87 | self.assertAlmostEqual(get_dihedral(p0, p1, p2, p_new), psi) 88 | 89 | def test_neg_angles(self): 90 | p0 = np.array([0, 0, 1], dtype=np.float) 91 | p1 = np.array([0, 0, 0], dtype=np.float) 92 | p2 = np.array([0, 1, 0], dtype=np.float) 93 | 94 | angle = 1 * np.pi / 3 95 | p_neg = position_point(p0=p0, p1=p1, p2=p2, distance=2.5, angle=-1 * angle, dihedral=np.pi) 96 | self.assertAlmostEqual(get_angle(p1, p2, p_neg), angle) 97 | 98 | def test_neg_distance(self): 99 | p0 = np.array([0, 0, 1], dtype=np.float) 100 | p1 = np.array([0, 0, 0], dtype=np.float) 101 | p2 = np.array([0, 1, 0], dtype=np.float) 102 | 103 | distance = 2.5 104 | angle = 3 * np.pi / 2 105 | dihedral = 3 * np.pi / 2 106 | 107 | p_new = position_point(p0=p0, p1=p1, p2=p2, distance=-1 * distance, angle=angle, dihedral=dihedral) 108 | 109 | self.assertAlmostEqual(get_distance(p2, p_new), distance) 110 | 111 | # If the distance is negative, the angle is messed up! 
112 | self.assertNotAlmostEqual(get_angle(p1, p2, p_new), angle) 113 | -------------------------------------------------------------------------------- /molgym/buffer.py: -------------------------------------------------------------------------------- 1 | # The content of this file is based on: OpenAI Spinning Up https://spinningup.openai.com/. 2 | from typing import Optional, List, Tuple 3 | 4 | import numpy as np 5 | 6 | from molgym.spaces import ObservationType 7 | from molgym.tools import util 8 | 9 | 10 | class DynamicPPOBuffer: 11 | """ 12 | A buffer for storing trajectories experienced by a PPO agent interacting 13 | with the environment, and using Generalized Advantage Estimation (GAE-Lambda) 14 | for calculating the advantages of state-action pairs. 15 | """ 16 | BUFFER_FIELDS = [ 17 | 'obs_buf', 'act_buf', 'rew_buf', 'next_obs_buf', 'term_buf', 'val_buf', 'logp_buf', 'adv_buf', 'ret_buf' 18 | ] 19 | 20 | def __init__(self, gamma=0.99, lam=0.95) -> None: 21 | self.obs_buf: List[ObservationType] = [] 22 | self.act_buf: List[np.ndarray] = [] 23 | self.rew_buf: List[float] = [] 24 | self.next_obs_buf: List[ObservationType] = [] 25 | self.term_buf: List[bool] = [] 26 | 27 | self.val_buf: List[float] = [] 28 | self.logp_buf: List[float] = [] 29 | 30 | # Filled when path is finished 31 | self.adv_buf: List[float] = [] 32 | self.ret_buf: List[float] = [] 33 | 34 | self.gamma = gamma 35 | self.lam = lam 36 | 37 | self.current_index = 0 38 | self.start_index = 0 39 | 40 | def store(self, obs: ObservationType, act: np.ndarray, reward: float, next_obs: ObservationType, terminal: bool, 41 | value: float, logp: float) -> None: 42 | """Append one time step of agent-environment interaction to the buffer.""" 43 | self.obs_buf.append(obs) 44 | self.act_buf.append(act) 45 | self.rew_buf.append(reward) 46 | self.next_obs_buf.append(next_obs) 47 | self.term_buf.append(terminal) 48 | 49 | self.val_buf.append(value) 50 | self.logp_buf.append(logp) 51 | 52 | self.current_index += 1 53 | 54 | def finish_path(self, last_val: float) -> Tuple[Optional[float], int]: 55 | """ 56 | Call this at the end of a trajectory, or when one gets cut off 57 | by an epoch ending. This looks back in the buffer to where the 58 | trajectory started, and uses rewards and value estimates from 59 | the whole trajectory to compute advantage estimates with GAE-Lambda, 60 | as well as compute the rewards-to-go for each state, to use as 61 | the targets for the value function. 62 | 63 | The "last_val" argument should be 0 if the trajectory ended 64 | because the agent reached a terminal state (died), and otherwise 65 | should be V(s_T), the value function estimated for the last state. 66 | This allows us to bootstrap the reward-to-go calculation to account 67 | for timesteps beyond the arbitrary episode horizon (or epoch cutoff). 
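Concretely, as implemented below, for each step t of the finished path:

            delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            A_t = sum_{l >= 0} (gamma * lam)^l * delta_{t+l}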
68 | """ 69 | 70 | if self.is_finished(): 71 | return None, 0 72 | 73 | path_slice = slice(self.start_index, self.current_index) 74 | rews = np.array(self.rew_buf[path_slice] + [last_val]) 75 | vals = np.array(self.val_buf[path_slice] + [last_val]) 76 | 77 | # the next two lines implement GAE-Lambda advantage calculation 78 | deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1] 79 | self.adv_buf += util.discount_cumsum(deltas, self.gamma * self.lam).tolist() 80 | 81 | # the next line computes rewards-to-go, to be targets for the value function 82 | self.ret_buf += util.discount_cumsum(rews, self.gamma).tolist()[:-1] 83 | 84 | episodic_return = self.ret_buf[self.start_index] 85 | episode_length = self.current_index - self.start_index 86 | 87 | self.start_index = self.current_index 88 | 89 | # Ensure that all buffer fields have the same length 90 | assert all(len(getattr(self, field)) == self.current_index for field in DynamicPPOBuffer.BUFFER_FIELDS) 91 | 92 | return episodic_return, episode_length 93 | 94 | def is_finished(self) -> bool: 95 | return self.start_index == self.current_index 96 | 97 | def get_data(self) -> dict: 98 | """ 99 | Call this at the end of an epoch to get all of the data from 100 | the buffer, with advantages appropriately normalized (shifted to have 101 | mean zero and std one). Also, resets some pointers in the buffer. 102 | """ 103 | assert self.is_finished() 104 | 105 | # advantage normalization trick 106 | adv_buf = np.array(self.adv_buf) 107 | adv_mean = np.mean(adv_buf) 108 | adv_std = np.std(adv_buf) 109 | 110 | adv_buf_standard = (adv_buf - adv_mean) / adv_std 111 | 112 | return dict(obs=self.obs_buf, 113 | act=np.array(self.act_buf), 114 | ret=np.array(self.ret_buf), 115 | adv=adv_buf_standard, 116 | logp=np.array(self.logp_buf)) 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MolGym: Reinforcement Learning for 3D Molecular Design 2 | 3 | This repository allows to train reinforcement learning policies for designing molecules directly in Cartesian coordinates. The agent builds molecules by repeatedly taking atoms from a given _bag_ and placing them onto a 3D _canvas_. 4 | 5 | 6 | 7 | Check out our [blog post](https://mlg-blog.com/2021/04/30/reinforcement-learning-for-3d-molecular-design.html) for a gentle introduction. For more details, see our papers: 8 | 9 | **Reinforcement Learning for Molecular Design Guided by Quantum Mechanics**
10 | Gregor N. C. Simm*, Robert Pinsler* and José Miguel Hernández-Lobato
11 | *Proceedings of the 37th International Conference on Machine Learning*, Vienna, Austria, PMLR 119, 2020.
12 | http://proceedings.mlr.press/v119/simm20b.html 13 | 14 | **Symmetry-Aware Actor-Critic for 3D Molecular Design**
15 | Gregor N. C. Simm, Robert Pinsler, Gábor Csányi and José Miguel Hernández-Lobato
16 | *International Conference on Learning Representations*, 2021.
17 | https://openreview.net/forum?id=jEYKjPE1xYN
18 |
19 | ## Setup
20 |
21 | Dependencies:
22 | * Python >= 3.7
23 | * [ase](https://wiki.fysik.dtu.dk/ase/)
24 | * [cormorant](https://github.com/risilab/cormorant)
25 | * [gym](https://www.gymlibrary.ml/)
26 | * [matplotlib](https://matplotlib.org/)
27 | * [pandas](https://pandas.pydata.org/)
28 | * [quadpy](https://github.com/sigma-py/quadpy)
29 | * [schnetpack](https://schnetpack.readthedocs.io)
30 | * [sparrow](https://github.com/qcscine/sparrow) >= 2.0.1
31 | * torch >= 1.5.1
32 | * [torch-scatter](https://github.com/rusty1s/pytorch_scatter) >= 2.0.5
33 |
34 | Install the required packages and the library itself:
35 | ```
36 | pip install -r requirements.txt
37 | pip install -e .
38 | ```
39 |
40 | **Note:** Make sure that the CUDA versions associated with `torch` and `torch-scatter` match. Check the [documentation](https://github.com/rusty1s/pytorch_scatter) if you run into any errors when installing `torch-scatter`.
41 |
42 | ### Sparrow Setup
43 |
44 | Sparrow can be installed using the *conda* package manager and is available on the *conda-forge* channel.
45 | To install the *conda* package manager, we recommend the [miniforge](https://github.com/conda-forge/miniforge/releases) installer.
46 | If the *conda-forge* channel is not yet enabled, add it to your channels with
47 |
48 | ```
49 | conda config --add channels conda-forge
50 | conda config --set channel_priority strict
51 | ```
52 |
53 | Once the `conda-forge` channel has been enabled, `scine-sparrow-python` can be installed with `conda`:
54 |
55 | ```
56 | conda install scine-sparrow-python
57 | ```
58 |
59 |
60 | ## Usage
61 |
62 | You can use this code to train and evaluate reinforcement learning agents for 3D molecular design. We currently support running experiments given a specific bag (single-bag), a stochastic bag, or multiple bags (multi-bag).
63 |
64 | ### Training
65 | To perform the single-bag experiment with SF6, run
66 | ```shell
67 | python3 scripts/run.py \
68 | --name=SF6 \
69 | --symbols=X,F,S \
70 | --formulas=SF6 \
71 | --min_mean_distance=1.10 \
72 | --max_mean_distance=2.10 \
73 | --bag_scale=5 \
74 | --beta=-10 \
75 | --model=covariant \
76 | --canvas_size=7 \
77 | --num_envs=10 \
78 | --num_steps=15000 \
79 | --num_steps_per_iter=140 \
80 | --mini_batch_size=140 \
81 | --save_rollouts=eval \
82 | --device=cuda \
83 | --seed=1
84 | ```
85 | Hyper-parameters for the other experiments can be found in the papers.
86 |
87 | ### Evaluation
88 |
89 | To generate learning curves, run the following command:
90 | ```shell
91 | python3 scripts/plot.py --dir=results
92 | ```
93 | The resulting figure of the learning curve is saved as `average_return.pdf`.
94 |
95 | To write out the generated structures, run the following command:
96 | ```shell
97 | python3 scripts/structures.py --dir=data --symbols=X,F,S
98 | ```
99 | You can visualize the structures in the generated XYZ file using, for example, [PyMOL](https://pymol.org/2/).
100 |
101 | ## Citation
102 |
103 | If you use this code, please cite our papers:
104 | ```txt
105 | @inproceedings{Simm2020Reinforcement,
106 | title = {Reinforcement Learning for Molecular Design Guided by Quantum Mechanics},
107 | booktitle = {Proceedings of the 37th International Conference on Machine Learning},
108 | author = {Simm, Gregor N. C.
and Pinsler, Robert and {Hern{\'a}ndez-Lobato}, Jos{\'e} Miguel},
109 | editor = {III, Hal Daum{\'e} and Singh, Aarti},
110 | year = {2020},
111 | volume = {119},
112 | pages = {8959--8969},
113 | publisher = {{PMLR}},
114 | series = {Proceedings of Machine Learning Research},
115 | url = {http://proceedings.mlr.press/v119/simm20b.html}
116 | }
117 |
118 | @inproceedings{Simm2021SymmetryAware,
119 | title = {Symmetry-Aware Actor-Critic for 3D Molecular Design},
120 | author = {Gregor N. C. Simm and Robert Pinsler and G{\'a}bor Cs{\'a}nyi and Jos{\'e} Miguel Hern{\'a}ndez-Lobato},
121 | booktitle = {International Conference on Learning Representations},
122 | year = {2021},
123 | url = {https://openreview.net/forum?id=jEYKjPE1xYN}
124 | }
125 | ```
126 |
-------------------------------------------------------------------------------- /scripts/run.py: --------------------------------------------------------------------------------
1 | import logging
2 |
3 | import ase.data
4 | import ase.io
5 |
6 | from molgym.env_container import SimpleEnvContainer
7 | from molgym.environment import MolecularEnvironment
8 | from molgym.ppo import batch_ppo
9 | from molgym.reward import InteractionReward
10 | from molgym.spaces import ActionSpace, ObservationSpace
11 | from molgym.tools import util
12 | from molgym.tools.arg_parser import build_default_argparser
13 | from molgym.tools.model_util import ModelIO, build_model
14 |
15 |
16 | def get_config() -> dict:
17 | parser = build_default_argparser()
18 | args = parser.parse_args()
19 | config = vars(args)
20 | return config
21 |
22 |
23 | def main() -> None:
24 | config = get_config()
25 |
26 | util.create_directories([config['log_dir'], config['model_dir'], config['data_dir'], config['results_dir']])
27 |
28 | tag = util.get_tag(config)
29 | util.setup_logger(config, directory=config['log_dir'], tag=tag)
30 | util.save_config(config, directory=config['log_dir'], tag=tag)
31 |
32 | util.set_seeds(seed=config['seed'])
33 | device = util.init_device(config['device'])
34 |
35 | zs = [ase.data.atomic_numbers[s] for s in config['symbols'].split(',')]
36 | action_space = ActionSpace(zs=zs)
37 | observation_space = ObservationSpace(canvas_size=config['canvas_size'], zs=zs)
38 |
39 | # Evaluation formulas
40 | if not config['eval_formulas']:
41 | config['eval_formulas'] = config['formulas']
42 |
43 | train_formulas = util.split_formula_strings(config['formulas'])
44 | eval_formulas = util.split_formula_strings(config['eval_formulas'])
45 |
46 | logging.info(f'Training bags: {train_formulas}')
47 | logging.info(f'Evaluation bags: {eval_formulas}')
48 |
49 | model_handler = ModelIO(directory=config['model_dir'], tag=tag, keep=config['keep_models'])
50 |
51 | if config['load_latest']:
52 | model, start_num_steps = model_handler.load_latest(device=device)
53 | model.action_space = action_space
54 | model.observation_space = observation_space
55 | elif config['load_model'] is not None:
56 | model, start_num_steps = model_handler.load(device=device, path=config['load_model'])
57 | model.action_space = action_space
58 | model.observation_space = observation_space
59 | else:
60 | model = build_model(config, observation_space=observation_space, action_space=action_space, device=device)
61 | start_num_steps = 0
62 |
63 | var_counts = util.count_vars(model)
64 | logging.info(f'Number of parameters: {var_counts}')
65 |
66 | reward = InteractionReward()
67 |
68 | # Number of episodes during evaluation
69 | if not config['num_eval_episodes']:
70 | config['num_eval_episodes'] =
len(eval_formulas) 71 | 72 | training_envs = SimpleEnvContainer([ 73 | MolecularEnvironment( 74 | reward=reward, 75 | observation_space=observation_space, 76 | action_space=action_space, 77 | formulas=[util.string_to_formula(f) for f in train_formulas], 78 | min_atomic_distance=config['min_atomic_distance'], 79 | max_solo_distance=config['max_solo_distance'], 80 | min_reward=config['min_reward'], 81 | ) for _ in range(config['num_envs']) 82 | ]) 83 | 84 | eval_envs = SimpleEnvContainer([ 85 | MolecularEnvironment( 86 | reward=reward, 87 | observation_space=observation_space, 88 | action_space=action_space, 89 | formulas=[util.string_to_formula(f) for f in eval_formulas], 90 | min_atomic_distance=config['min_atomic_distance'], 91 | max_solo_distance=config['max_solo_distance'], 92 | min_reward=config['min_reward'], 93 | ) 94 | ]) 95 | 96 | batch_ppo( 97 | envs=training_envs, 98 | eval_envs=eval_envs, 99 | ac=model, 100 | optimizer=util.get_optimizer(name=config['optimizer'], 101 | learning_rate=config['learning_rate'], 102 | parameters=model.parameters()), 103 | gamma=config['discount'], 104 | start_num_steps=start_num_steps, 105 | max_num_steps=config['max_num_steps'], 106 | num_steps_per_iter=config['num_steps_per_iter'], 107 | mini_batch_size=config['mini_batch_size'], 108 | clip_ratio=config['clip_ratio'], 109 | vf_coef=config['vf_coef'], 110 | entropy_coef=config['entropy_coef'], 111 | max_num_train_iters=config['max_num_train_iters'], 112 | lam=config['lam'], 113 | target_kl=config['target_kl'], 114 | gradient_clip=config['gradient_clip'], 115 | eval_freq=config['eval_freq'], 116 | model_handler=model_handler, 117 | save_freq=config['save_freq'], 118 | num_eval_episodes=config['num_eval_episodes'], 119 | rollout_saver=util.RolloutSaver(directory=config['data_dir'], tag=tag), 120 | save_train_rollout=config['save_rollouts'] == 'train' or config['save_rollouts'] == 'all', 121 | save_eval_rollout=config['save_rollouts'] == 'eval' or config['save_rollouts'] == 'all', 122 | info_saver=util.InfoSaver(directory=config['results_dir'], tag=tag), 123 | device=device, 124 | ) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /scripts/run_stochastic.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import ase.data 4 | import ase.io 5 | 6 | from molgym.env_container import SimpleEnvContainer 7 | from molgym.environment import StochasticEnvironment, MolecularEnvironment 8 | from molgym.ppo import batch_ppo 9 | from molgym.reward import InteractionReward 10 | from molgym.spaces import ActionSpace, ObservationSpace 11 | from molgym.tools import util 12 | from molgym.tools.arg_parser import build_default_argparser 13 | from molgym.tools.model_util import ModelIO, build_model 14 | 15 | 16 | def get_config() -> dict: 17 | parser = build_default_argparser() 18 | parser.add_argument('--size_range', help='minimum and maximum bag size (comma-separated)', type=str, required=True) 19 | args = parser.parse_args() 20 | config = vars(args) 21 | return config 22 | 23 | 24 | def main() -> None: 25 | config = get_config() 26 | 27 | util.create_directories([config['log_dir'], config['model_dir'], config['data_dir'], config['results_dir']]) 28 | 29 | tag = util.get_tag(config) 30 | util.setup_logger(config, directory=config['log_dir'], tag=tag) 31 | util.save_config(config, directory=config['log_dir'], tag=tag) 32 | 33 | util.set_seeds(seed=config['seed']) 34 | 
device = util.init_device(config['device']) 35 | 36 | zs = [ase.data.atomic_numbers[s] for s in config['symbols'].split(',')] 37 | action_space = ActionSpace(zs=zs) 38 | observation_space = ObservationSpace(canvas_size=config['canvas_size'], zs=zs) 39 | 40 | # Evaluation formulas 41 | if not config['eval_formulas']: 42 | config['eval_formulas'] = config['formulas'] 43 | 44 | train_formula = util.split_formula_strings(config['formulas'])[0] 45 | eval_formulas = util.split_formula_strings(config['eval_formulas']) 46 | size_range = util.parse_size_range(config['size_range']) 47 | 48 | logging.info(f'Statistical training bag: {train_formula}, size range: {size_range}') 49 | logging.info(f'Evaluation bag(s): {eval_formulas}') 50 | 51 | model_handler = ModelIO(directory=config['model_dir'], tag=tag, keep=config['keep_models']) 52 | 53 | start_num_steps = 0 54 | if not config['load_latest']: 55 | model = build_model(config, observation_space=observation_space, action_space=action_space, device=device) 56 | else: 57 | model, start_num_steps = model_handler.load_latest(device=device) 58 | model.action_space = action_space 59 | model.observation_space = observation_space 60 | 61 | var_counts = util.count_vars(model) 62 | logging.info(f'Number of parameters: {var_counts}') 63 | 64 | reward = InteractionReward() 65 | 66 | # Number of episodes during evaluation 67 | if not config['num_eval_episodes']: 68 | config['num_eval_episodes'] = len(eval_formulas) 69 | 70 | training_envs = SimpleEnvContainer([ 71 | StochasticEnvironment( 72 | reward=reward, 73 | observation_space=observation_space, 74 | action_space=action_space, 75 | formula=util.string_to_formula(train_formula), 76 | size_range=size_range, 77 | min_atomic_distance=config['min_atomic_distance'], 78 | max_solo_distance=config['max_solo_distance'], 79 | min_reward=config['min_reward'], 80 | ) for _ in range(config['num_envs']) 81 | ]) 82 | 83 | eval_envs = SimpleEnvContainer([ 84 | MolecularEnvironment( 85 | reward=reward, 86 | observation_space=observation_space, 87 | action_space=action_space, 88 | formulas=[util.string_to_formula(formula) for formula in eval_formulas], 89 | min_atomic_distance=config['min_atomic_distance'], 90 | max_solo_distance=config['max_solo_distance'], 91 | min_reward=config['min_reward'], 92 | ) 93 | ]) 94 | 95 | batch_ppo( 96 | envs=training_envs, 97 | eval_envs=eval_envs, 98 | ac=model, 99 | optimizer=util.get_optimizer(name=config['optimizer'], 100 | learning_rate=config['learning_rate'], 101 | parameters=model.parameters()), 102 | gamma=config['discount'], 103 | start_num_steps=start_num_steps, 104 | max_num_steps=config['max_num_steps'], 105 | num_steps_per_iter=config['num_steps_per_iter'], 106 | mini_batch_size=config['mini_batch_size'], 107 | clip_ratio=config['clip_ratio'], 108 | vf_coef=config['vf_coef'], 109 | entropy_coef=config['entropy_coef'], 110 | max_num_train_iters=config['max_num_train_iters'], 111 | lam=config['lam'], 112 | target_kl=config['target_kl'], 113 | gradient_clip=config['gradient_clip'], 114 | eval_freq=config['eval_freq'], 115 | model_handler=model_handler, 116 | save_freq=config['save_freq'], 117 | num_eval_episodes=config['num_eval_episodes'], 118 | rollout_saver=util.RolloutSaver(directory=config['data_dir'], tag=tag), 119 | save_train_rollout=config['save_rollouts'] == 'train' or config['save_rollouts'] == 'all', 120 | save_eval_rollout=config['save_rollouts'] == 'eval' or config['save_rollouts'] == 'all', 121 | info_saver=util.InfoSaver(directory=config['results_dir'], tag=tag), 122 
| device=device, 123 | ) 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /tests/agents/covariant/test_so3_tools.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import torch 5 | from cormorant.cg_lib import SphericalHarmonics 6 | from cormorant.so3_lib import SO3WignerD 7 | from cormorant.so3_lib.rotations import rotate_rep 8 | 9 | from molgym.agents.covariant.so3_tools import (spherical_to_cartesian, estimate_alms, concat_so3vecs, AtomicScalars, 10 | generate_fibonacci_grid, cartesian_to_spherical, complex_product, 11 | get_normalization_constant, normalize_alms) 12 | from molgym.tools.util import to_numpy 13 | 14 | 15 | class FibonacciGridTest(TestCase): 16 | def test_generation(self): 17 | count = 10 18 | grid = generate_fibonacci_grid(n=count) 19 | self.assertEqual(grid.shape, (count, 3)) 20 | 21 | def test_empty(self): 22 | count = 0 23 | grid = generate_fibonacci_grid(n=count) 24 | self.assertEqual(grid.shape, (count, 3)) 25 | 26 | 27 | class SphericalCartesianTransformationTest(TestCase): 28 | def test_spherical_to_cartesian(self): 29 | theta_phi = np.array([[np.pi / 2, np.pi]]) 30 | xyz = spherical_to_cartesian(theta_phi) 31 | self.assertTrue(np.all(np.isclose(xyz, np.array([[-1.0, 0.0, 0.0]])))) 32 | 33 | def test_spherical_to_cartesian_2(self): 34 | theta_phi = np.array([[np.pi / 2, 3 / 2 * np.pi]]) 35 | xyz = spherical_to_cartesian(theta_phi) 36 | self.assertTrue(np.all(np.isclose(xyz, np.array([[0.0, -1.0, 0.0]])))) 37 | 38 | def test_cartesian_to_spherical(self): 39 | xyz = np.array([[0.0, -1.0, 0.0]]) 40 | theta_phi = cartesian_to_spherical(xyz) 41 | self.assertTrue(np.all(np.isclose(theta_phi, np.array([[np.pi / 2, -np.pi / 2]])))) 42 | 43 | def test_cycle(self): 44 | xyz = np.array([[0.0, -1.0, 0.0]]) 45 | theta_phi = cartesian_to_spherical(xyz) 46 | xyz_new = spherical_to_cartesian(theta_phi) 47 | self.assertTrue(np.all(np.isclose(xyz, xyz_new))) 48 | 49 | def test_cycle_2(self): 50 | theta_phi = np.array([[0.3, -1.2]]) 51 | xyz = spherical_to_cartesian(theta_phi) 52 | theta_phi_2 = cartesian_to_spherical(xyz) 53 | self.assertTrue(np.all(np.isclose(theta_phi, theta_phi_2))) 54 | 55 | 56 | class ComplexNumbersTest(TestCase): 57 | def test_multiplication(self): 58 | a = torch.tensor([2.0, -1.0], dtype=torch.float) 59 | b = torch.tensor([3.0, -2.0], dtype=torch.float) 60 | c = to_numpy(complex_product(a, b)) 61 | expected = np.array([4.0, -7.0]) 62 | self.assertTrue(np.allclose(c, expected)) 63 | 64 | def test_multiplication_2(self): 65 | a = torch.tensor([2.0, 0.0], dtype=torch.float) 66 | b = torch.tensor([3.0, 0.0], dtype=torch.float) 67 | c = to_numpy(complex_product(a, b)) 68 | expected = np.array([6.0, 0.0]) 69 | self.assertTrue(np.allclose(c, expected)) 70 | 71 | 72 | class NormalizationTest(TestCase): 73 | def setUp(self): 74 | self.maxl = 4 75 | self.sphs = SphericalHarmonics(maxl=self.maxl) 76 | self.sphs_conj = SphericalHarmonics(maxl=self.maxl, conj=True, sh_norm='unit') 77 | 78 | def test_concat(self): 79 | theta_phi = np.array([[np.pi / 2, np.pi / 2]]) 80 | xyz_refs = spherical_to_cartesian(theta_phi) 81 | y_lms_conj = self.sphs_conj.forward(torch.tensor(xyz_refs, dtype=torch.float)) 82 | 83 | a_lms = estimate_alms(y_lms_conj) 84 | a_lms = concat_so3vecs([a_lms] * 3) 85 | 86 | self.assertTrue(all(a_lm.shape[0] == 3 for a_lm in a_lms)) 87 | 88 | def 
test_normalization(self):
89 | theta_phi = np.array([[np.pi / 2, np.pi / 2]])
90 | xyz_refs = spherical_to_cartesian(theta_phi)
91 | y_lms_conj = self.sphs_conj.forward(torch.tensor(xyz_refs, dtype=torch.float))
92 |
93 | a_lms = estimate_alms(y_lms_conj)
94 | k1 = get_normalization_constant(a_lms)
95 |
96 | self.assertEqual(k1.shape, (1, ))
97 |
98 | # If sh_norm='unit', the sum over m is 1, so k1 should equal maxl + 1.
99 | self.assertAlmostEqual(k1.item(), self.maxl + 1, places=4)
100 |
101 | normalized_a_lms = normalize_alms(a_lms)
102 |
103 | k2 = get_normalization_constant(normalized_a_lms)
104 | self.assertAlmostEqual(k2.item(), 1.0)
105 |
106 |
107 | class AtomicScalarsTest(TestCase):
108 | def test_invariant(self):
109 | max_ell = 4
110 | sphs_conj = SphericalHarmonics(maxl=max_ell, conj=True, sh_norm='unit')
111 | atomic_scalars = AtomicScalars(maxl=max_ell)
112 |
113 | theta_phi = np.array([[np.pi / 3, np.pi / 4], [2 * np.pi / 3, np.pi / 2]])
114 | xyz_refs = spherical_to_cartesian(theta_phi)
115 | y_lms_conj = sphs_conj.forward(torch.tensor(xyz_refs, dtype=torch.float))
116 |
117 | a_lms = estimate_alms(y_lms_conj)
118 |
119 | invariant = atomic_scalars(a_lms)
120 |
121 | self.assertEqual(invariant.shape[-1], atomic_scalars.get_output_dim(channels=1))
122 |
123 | random_rotation = SO3WignerD.euler(maxl=max_ell, dtype=torch.float)
124 | a_lms_rotated = rotate_rep(random_rotation, a_lms)
125 |
126 | self.assertFalse(np.allclose(to_numpy(a_lms[1]), to_numpy(a_lms_rotated[1])))
127 |
128 | invariant_rotated = atomic_scalars(a_lms_rotated)
129 |
130 | self.assertTrue(np.allclose(invariant, invariant_rotated))
131 | -------------------------------------------------------------------------------- /tests/agents/covariant/test_agent.py: --------------------------------------------------------------------------------
1 | import os
2 | from unittest import TestCase
3 |
4 | import ase.data
5 | import ase.io
6 | import numpy as np
7 | import pkg_resources
8 | import torch
9 | from cormorant.so3_lib import rotations
10 |
11 | from molgym.agents.covariant.agent import CovariantAC
12 | from molgym.agents.covariant.so3_tools import generate_fibonacci_grid, AtomicScalars
13 | from molgym.spaces import ActionSpace, ObservationSpace
14 | from molgym.tools import util
15 |
16 | RESOURCES_FOLDER = 'resources'
17 |
18 |
19 | class CovariantAgentTest(TestCase):
20 | RESOURCES = pkg_resources.resource_filename(__package__, RESOURCES_FOLDER)
21 |
22 | def setUp(self) -> None:
23 | util.set_seeds(0)
24 | self.device = torch.device('cpu')
25 | self.action_space = ActionSpace(zs=[1])
26 | self.observation_space = ObservationSpace(canvas_size=5, zs=[0, 1, 6, 8])
27 | self.agent = CovariantAC(
28 | observation_space=self.observation_space,
29 | action_space=self.action_space,
30 | min_max_distance=(0.9, 1.8),
31 | network_width=64,
32 | bag_scale=1,
33 | device=self.device,
34 | beta=100,
35 | maxl=4,
36 | num_cg_levels=3,
37 | num_channels_hidden=10,
38 | num_channels_per_element=4,
39 | num_gaussians=3,
40 | )
41 | self.formula = ((1, 1), )
42 |
43 | def verify_alms(self, atoms):
44 | observation = self.observation_space.build(atoms, formula=self.formula)
45 | util.set_seeds(0)
46 | action = self.agent.step([observation])
47 | so3_dist = action['dists'][-1]
48 |
49 | # Rotate
50 | wigner_d, rot_mat, angles = rotations.gen_rot(self.agent.max_sh, dtype=self.agent.dtype)
51 | atoms.positions = np.einsum('ij,...j->...i', rot_mat, atoms.positions)
52 |
53 | observation = self.observation_space.build(atoms, formula=self.formula)
54 | util.set_seeds(0)
55 | action
= self.agent.step([observation]) 56 | so3_dist_rot = action['dists'][-1] 57 | 58 | rotated_b_lms = so3_dist.coefficients.apply_wigner(wigner_d) 59 | for part1, part2 in zip(so3_dist_rot.coefficients, rotated_b_lms): 60 | max_delta = torch.max(torch.abs(part1 - part2)) 61 | self.assertTrue(max_delta < 1e-5) 62 | 63 | def test_rotations(self): 64 | for file in ['h2o.xyz', 'ch3.xyz', 'ch4.xyz']: 65 | self.verify_alms(atoms=ase.io.read(filename=os.path.join(self.RESOURCES, file), format='xyz', index=0)) 66 | 67 | def verify_probs(self, atoms): 68 | grid_points = torch.tensor(generate_fibonacci_grid(n=100_000), dtype=torch.float, device=self.device) 69 | grid_points = grid_points.unsqueeze(-2) 70 | 71 | observation = self.observation_space.build(atoms, formula=self.formula) 72 | util.set_seeds(0) 73 | action = self.agent.step([observation]) 74 | so3_dist = action['dists'][-1] 75 | 76 | # Rotate atoms 77 | wigner_d, rot_mat, angles = rotations.gen_rot(self.agent.max_sh, dtype=self.agent.dtype) 78 | atoms_rotated = atoms.copy() 79 | atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat, atoms.positions) 80 | 81 | observation = self.observation_space.build(atoms_rotated, formula=self.formula) 82 | util.set_seeds(0) 83 | action = self.agent.step([observation]) 84 | so3_dist_rot = action['dists'][-1] 85 | 86 | log_probs = so3_dist.log_prob(grid_points) # (samples, batches) 87 | log_probs_rot = so3_dist_rot.log_prob(grid_points) # (samples, batches) 88 | 89 | # Maximum over grid points 90 | maximum, max_indices = torch.max(log_probs, dim=0) 91 | minimum, min_indices = torch.min(log_probs, dim=0) 92 | 93 | maximum_rot, max_indices_rot = torch.max(log_probs_rot, dim=0) 94 | minimum_rot, min_indices_rot = torch.min(log_probs_rot, dim=0) 95 | 96 | self.assertTrue(torch.allclose(maximum, maximum_rot, atol=5e-3)) 97 | self.assertTrue(torch.allclose(minimum, minimum_rot, atol=5e-3)) 98 | 99 | def test_distribution(self): 100 | for file in ['h2o.xyz', 'ch3.xyz', 'ch4.xyz']: 101 | self.verify_probs(atoms=ase.io.read(filename=os.path.join(self.RESOURCES, file), format='xyz', index=0)) 102 | 103 | def verify_invariance(self, atoms): 104 | atomic_scalars = AtomicScalars(maxl=self.agent.max_sh) 105 | 106 | observation = self.observation_space.build(atoms, formula=self.formula) 107 | util.set_seeds(0) 108 | action = self.agent.step([observation]) 109 | so3_dist = action['dists'][-1] 110 | scalars = atomic_scalars(so3_dist.coefficients) 111 | 112 | # Rotate atoms 113 | wigner_d, rot_mat, angles = rotations.gen_rot(self.agent.max_sh, dtype=self.agent.dtype) 114 | atoms_rotated = atoms.copy() 115 | atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat, atoms.positions) 116 | 117 | observation = self.observation_space.build(atoms_rotated, formula=self.formula) 118 | util.set_seeds(0) 119 | action = self.agent.step([observation]) 120 | so3_dist_rot = action['dists'][-1] 121 | scalars_rot = atomic_scalars(so3_dist_rot.coefficients) 122 | 123 | self.assertTrue(torch.allclose(scalars, scalars_rot, atol=1e-05)) 124 | 125 | def test_invariance(self): 126 | for file in ['h2o.xyz', 'ch3.xyz', 'ch4.xyz']: 127 | self.verify_invariance( 128 | atoms=ase.io.read(filename=os.path.join(self.RESOURCES, file), format='xyz', index=0)) 129 | -------------------------------------------------------------------------------- /scripts/run_solvation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import ase.data 4 | import ase.io 5 | 6 | from molgym.env_container 
import SimpleEnvContainer 7 | from molgym.environment import RefillableMolecularEnvironment 8 | from molgym.ppo import batch_ppo 9 | from molgym.reward import SolvationReward 10 | from molgym.spaces import ActionSpace, ObservationSpace 11 | from molgym.tools import util 12 | from molgym.tools.arg_parser import build_default_argparser 13 | from molgym.tools.model_util import ModelIO, build_model 14 | 15 | 16 | def get_config() -> dict: 17 | parser = build_default_argparser() 18 | parser.add_argument('--num_refills', 19 | help='number of times the bag gets refilled by the environment', 20 | type=int, 21 | required=False, 22 | default=0) 23 | parser.add_argument('--initial_structure', help='path to initial structure', type=str, required=False) 24 | args = parser.parse_args() 25 | config = vars(args) 26 | return config 27 | 28 | 29 | def main() -> None: 30 | config = get_config() 31 | 32 | util.create_directories([config['log_dir'], config['model_dir'], config['data_dir'], config['results_dir']]) 33 | 34 | tag = util.get_tag(config) 35 | util.setup_logger(config, directory=config['log_dir'], tag=tag) 36 | util.save_config(config, directory=config['log_dir'], tag=tag) 37 | 38 | util.set_seeds(seed=config['seed']) 39 | device = util.init_device(config['device']) 40 | 41 | zs = [ase.data.atomic_numbers[s] for s in config['symbols'].split(',')] 42 | action_space = ActionSpace(zs=zs) 43 | observation_space = ObservationSpace(canvas_size=config['canvas_size'], zs=zs) 44 | 45 | # Evaluation formulas 46 | if not config['eval_formulas']: 47 | config['eval_formulas'] = config['formulas'] 48 | 49 | train_formulas = util.split_formula_strings(config['formulas']) 50 | eval_formulas = util.split_formula_strings(config['eval_formulas']) 51 | 52 | logging.info(f'Training bags: {train_formulas}') 53 | logging.info(f'Evaluation bags: {eval_formulas}') 54 | 55 | model_handler = ModelIO(directory=config['model_dir'], tag=tag, keep=config['keep_models']) 56 | 57 | if config['load_latest']: 58 | model, start_num_steps = model_handler.load_latest(device=device) 59 | model.action_space = action_space 60 | model.observation_space = observation_space 61 | elif config['load_model'] is not None: 62 | model, start_num_steps = model_handler.load(device=device, path=config['load_model']) 63 | model.action_space = action_space 64 | model.observation_space = observation_space 65 | else: 66 | model = build_model(config, observation_space=observation_space, action_space=action_space, device=device) 67 | start_num_steps = 0 68 | 69 | var_counts = util.count_vars(model) 70 | logging.info(f'Number of parameters: {var_counts}') 71 | 72 | reward = SolvationReward() 73 | 74 | # Number of episodes during evaluation 75 | if not config['num_eval_episodes']: 76 | config['num_eval_episodes'] = len(eval_formulas) 77 | 78 | if config['initial_structure']: 79 | initial_structure = ase.io.read(config['initial_structure'], index=0, format='xyz') 80 | else: 81 | initial_structure = ase.Atoms() 82 | 83 | training_envs = SimpleEnvContainer([ 84 | RefillableMolecularEnvironment( 85 | reward=reward, 86 | observation_space=observation_space, 87 | action_space=action_space, 88 | formulas=[util.string_to_formula(f) for f in train_formulas], 89 | initial_structure=initial_structure, 90 | num_refills=config['num_refills'], 91 | min_atomic_distance=config['min_atomic_distance'], 92 | max_solo_distance=config['max_solo_distance'], 93 | min_reward=config['min_reward'], 94 | ) for _ in range(config['num_envs']) 95 | ]) 96 | 97 | eval_envs = 
SimpleEnvContainer([ 98 | RefillableMolecularEnvironment( 99 | reward=reward, 100 | observation_space=observation_space, 101 | action_space=action_space, 102 | formulas=[util.string_to_formula(f) for f in eval_formulas], 103 | initial_structure=initial_structure, 104 | num_refills=config['num_refills'], 105 | min_atomic_distance=config['min_atomic_distance'], 106 | max_solo_distance=config['max_solo_distance'], 107 | min_reward=config['min_reward'], 108 | ) 109 | ]) 110 | 111 | batch_ppo( 112 | envs=training_envs, 113 | eval_envs=eval_envs, 114 | ac=model, 115 | optimizer=util.get_optimizer(name=config['optimizer'], 116 | learning_rate=config['learning_rate'], 117 | parameters=model.parameters()), 118 | gamma=config['discount'], 119 | start_num_steps=start_num_steps, 120 | max_num_steps=config['max_num_steps'], 121 | num_steps_per_iter=config['num_steps_per_iter'], 122 | mini_batch_size=config['mini_batch_size'], 123 | clip_ratio=config['clip_ratio'], 124 | vf_coef=config['vf_coef'], 125 | entropy_coef=config['entropy_coef'], 126 | max_num_train_iters=config['max_num_train_iters'], 127 | lam=config['lam'], 128 | target_kl=config['target_kl'], 129 | gradient_clip=config['gradient_clip'], 130 | eval_freq=config['eval_freq'], 131 | model_handler=model_handler, 132 | save_freq=config['save_freq'], 133 | num_eval_episodes=config['num_eval_episodes'], 134 | rollout_saver=util.RolloutSaver(directory=config['data_dir'], tag=tag), 135 | save_train_rollout=config['save_rollouts'] == 'train' or config['save_rollouts'] == 'all', 136 | save_eval_rollout=config['save_rollouts'] == 'eval' or config['save_rollouts'] == 'all', 137 | info_saver=util.InfoSaver(directory=config['results_dir'], tag=tag), 138 | device=device, 139 | ) 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /molgym/tools/arg_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def build_default_argparser() -> argparse.ArgumentParser: 5 | parser = argparse.ArgumentParser(description='Command line tool of MolGym') 6 | 7 | # Name and seed 8 | parser.add_argument('--name', help='experiment name', required=True) 9 | parser.add_argument('--seed', help='run ID', type=int, default=0) 10 | 11 | # Directories 12 | parser.add_argument('--log_dir', help='directory for log files', type=str, default='logs') 13 | parser.add_argument('--model_dir', help='directory for model files', type=str, default='models') 14 | parser.add_argument('--data_dir', help='directory for saved rollouts', type=str, default='data') 15 | parser.add_argument('--results_dir', help='directory for results', type=str, default='results') 16 | 17 | # Device 18 | parser.add_argument('--device', help='select device', type=str, choices=['cpu', 'cuda'], default='cpu') 19 | 20 | # Spaces 21 | parser.add_argument('--canvas_size', 22 | help='maximum number of atoms that can be placed on the canvas', 23 | type=int, 24 | default=25) 25 | parser.add_argument('--symbols', 26 | help='chemical symbols available on canvas and in bag (comma separated)', 27 | type=str, 28 | default='X,H,C,N,O,F') 29 | 30 | # Environment 31 | parser.add_argument('--formulas', 32 | help='list of formulas for environment (comma separated)', 33 | type=str, 34 | required=True) 35 | parser.add_argument('--eval_formulas', 36 | help='list of formulas for environment (comma separated) used for evaluation', 37 | type=str, 38 | required=False) 39 | 
parser.add_argument('--bag_scale', help='maximum bag size', type=int, required=True) 40 | parser.add_argument('--min_atomic_distance', help='minimum allowed atomic distance', type=float, default=0.6) 41 | parser.add_argument('--max_solo_distance', 42 | help='maximum distance hydrogen or halogens can be away from the nearest heavy atom', 43 | type=float, 44 | default=2.0) 45 | parser.add_argument('--min_reward', help='minimum reward given by environment', type=float, default=-0.6) 46 | 47 | # Model 48 | parser.add_argument('--model', 49 | help='model representation', 50 | type=str, 51 | default='internal', 52 | choices=['internal', 'covariant']) 53 | parser.add_argument('--min_mean_distance', help='minimum mean distance', type=float, default=0.8) 54 | parser.add_argument('--max_mean_distance', help='maximum mean distance', type=float, default=1.8) 55 | parser.add_argument('--network_width', help='width of FC layers', type=int, default=128) 56 | parser.add_argument('--maxl', help='maximum L in spherical harmonics expansion', type=int, default=4) 57 | parser.add_argument('--num_cg_levels', help='number of CG layers', type=int, default=3) 58 | parser.add_argument('--num_channels_hidden', help='number of channels in hidden layers', type=int, default=10) 59 | parser.add_argument('--num_channels_per_element', help='number of channels per element', type=int, default=4) 60 | parser.add_argument('--num_gaussians', help='number of Gaussians in GMM', type=int, default=3) 61 | parser.add_argument('--beta', help='set beta parameter of spherical distribution', type=float, required=False, default=None) 62 | 63 | parser.add_argument('--load_latest', help='load latest checkpoint file', action='store_true', default=False) 64 | parser.add_argument('--load_model', help='load checkpoint file', type=str, default=None) 65 | parser.add_argument('--save_freq', help='save model every N iterations', type=int, default=10) 66 | parser.add_argument('--eval_freq', help='evaluate model every N iterations', type=int, default=10) 67 | parser.add_argument('--num_eval_episodes', help='number of episodes per evaluation', type=int, default=None) 68 | 69 | # Training algorithm 70 | parser.add_argument('--optimizer', 71 | help='optimizer for parameter optimization', 72 | type=str, 73 | default='adam', 74 | choices=['adam', 'amsgrad']) 75 | parser.add_argument('--discount', help='discount factor', type=float, default=1.0) 76 | parser.add_argument('--num_steps', dest='max_num_steps', help='maximum number of steps', type=int, default=50000) 77 | parser.add_argument('--num_steps_per_iter', 78 | help='number of environment steps per iteration', 79 | type=int, 80 | default=128) 81 | parser.add_argument('--mini_batch_size', help='mini-batch size for training', type=int, default=64) 82 | parser.add_argument('--num_envs', help='number of environment copies', type=int, default=8) 83 | parser.add_argument('--clip_ratio', help='PPO clip ratio', type=float, default=0.2) 84 | parser.add_argument('--learning_rate', help='learning rate of Adam optimizer', type=float, default=3e-4) 85 | parser.add_argument('--vf_coef', help='coefficient for value function loss', type=float, default=0.5) 86 | parser.add_argument('--entropy_coef', help='coefficient for entropy loss', type=float, default=0.01) 87 | parser.add_argument('--max_num_train_iters', help='maximum number of training iterations', type=int, default=7) 88 | parser.add_argument('--gradient_clip', help='maximum norm of gradients', type=float, default=0.5) 89 | parser.add_argument('--lam', help='Lambda for 
GAE-Lambda', type=float, default=0.97) 90 | parser.add_argument('--target_kl', 91 | help='KL divergence between new and old policies after an update for early stopping', 92 | type=float, 93 | default=0.01) 94 | 95 | # Logging 96 | parser.add_argument('--log_level', help='log level', type=str, default='INFO') 97 | parser.add_argument('--keep_models', help='keep all models', action='store_true', default=False) 98 | parser.add_argument('--save_rollouts', 99 | help='which rollouts to save', 100 | type=str, 101 | default='none', 102 | choices=['none', 'train', 'eval', 'all']) 103 | 104 | return parser 105 | -------------------------------------------------------------------------------- /molgym/tools/util.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import logging 4 | import os 5 | import pickle 6 | import sys 7 | from typing import Optional, List, Iterable, Tuple, Dict 8 | 9 | import ase.data 10 | import ase.formula 11 | import numpy as np 12 | import scipy.signal 13 | import torch 14 | from ase.formula import Formula 15 | from torch.optim import Adam 16 | from torch.optim.optimizer import Optimizer 17 | 18 | from molgym.spaces import FormulaType 19 | 20 | 21 | def string_to_formula(string: str) -> FormulaType: 22 | d = Formula(string).count().items() 23 | return tuple((ase.data.atomic_numbers[symbol], count) for symbol, count in d) 24 | 25 | 26 | def zs_to_formula(zs: List[int]) -> FormulaType: 27 | counter: Dict[int, int] = collections.Counter() 28 | for z in zs: 29 | counter[z] += 1 30 | return tuple(counter.items()) 31 | 32 | 33 | def remove_atom_from_formula(formula: FormulaType, atomic_number: int) -> FormulaType: 34 | copy = list(formula) 35 | for i, (z, count) in enumerate(formula): 36 | if z == atomic_number and count >= 1: 37 | copy[i] = (z, count - 1) 38 | return tuple(copy) 39 | 40 | raise RuntimeError(f"Could not remove atomic number {atomic_number} from bag {formula}") 41 | 42 | 43 | def get_formula_size(formula: FormulaType) -> int: 44 | return sum(count for z, count in formula) 45 | 46 | 47 | def to_numpy(t: torch.Tensor) -> np.ndarray: 48 | return t.cpu().detach().numpy() 49 | 50 | 51 | def combined_shape(length: int, shape: Optional[tuple] = None) -> tuple: 52 | if shape is None: 53 | return length, 54 | return (length, shape) if np.isscalar(shape) else (length, *shape) 55 | 56 | 57 | def count_vars(module: torch.nn.Module) -> int: 58 | return sum(np.prod(p.shape) for p in module.parameters()) 59 | 60 | 61 | def compute_gradient_norm(parameters: Iterable[torch.nn.Parameter], norm_type: int = 2) -> float: 62 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 63 | if len(parameters) == 0: 64 | return 0.0 65 | device = parameters[0].grad.device # type: ignore 66 | total_norm = torch.norm( 67 | torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), # type: ignore 68 | norm_type) 69 | return total_norm.item() 70 | 71 | 72 | def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray: 73 | """ 74 | magic from rllab for computing discounted cumulative sums of vectors. 
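    Equivalent to the backward recursion out[t] = x[t] + discount * out[t + 1] with out[T] = x[T],
    evaluated in a single pass via scipy.signal.lfilter on the reversed input;
    e.g. for discount = 0.5, [1, 1, 1] becomes [1.75, 1.5, 1.0].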
75 | 76 | input: 77 | vector x, 78 | [x0, 79 | x1, 80 | x2] 81 | 82 | output: 83 | [x0 + discount * x1 + discount^2 * x2, 84 | x1 + discount * x2, 85 | x2] 86 | """ 87 | return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1] 88 | 89 | 90 | def set_seeds(seed: int) -> None: 91 | np.random.seed(seed) 92 | torch.manual_seed(seed) 93 | 94 | 95 | def split_formula_strings(formulas: str) -> List[str]: 96 | return formulas.split(',') 97 | 98 | 99 | def parse_size_range(size_range: str) -> Tuple[int, int]: 100 | parsed_range = [int(i) for i in size_range.split(',')] 101 | assert len(parsed_range) == 2 102 | return parsed_range[0], parsed_range[1] 103 | 104 | 105 | def get_tag(config: dict) -> str: 106 | return '{exp}_run-{seed}'.format(exp=config['name'], seed=config['seed']) 107 | 108 | 109 | def save_config(config: dict, directory: str, tag: str, verbose=True): 110 | formatted = json.dumps(config, indent=4, sort_keys=True) 111 | 112 | if verbose: 113 | logging.info(formatted) 114 | 115 | path = os.path.join(directory, tag + '.json') 116 | with open(file=path, mode='w') as f: 117 | f.write(formatted) 118 | 119 | 120 | def create_directories(directories: List[str]): 121 | for directory in directories: 122 | os.makedirs(directory, exist_ok=True) 123 | 124 | 125 | def setup_logger(config: dict, directory, tag: str): 126 | logger = logging.getLogger() 127 | logger.setLevel(config['log_level']) 128 | 129 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 130 | 131 | ch = logging.StreamHandler(stream=sys.stdout) 132 | ch.setFormatter(formatter) 133 | logger.addHandler(ch) 134 | 135 | path = os.path.join(directory, tag + '.log') 136 | fh = logging.FileHandler(path) 137 | fh.setFormatter(formatter) 138 | 139 | logger.addHandler(fh) 140 | 141 | 142 | def setup_simple_logger(path: str = None, log_level=logging.INFO): 143 | logger = logging.getLogger() 144 | logger.setLevel(log_level) 145 | formatter = logging.Formatter('%(message)s') 146 | 147 | ch = logging.StreamHandler(stream=sys.stdout) 148 | ch.setFormatter(formatter) 149 | logger.addHandler(ch) 150 | 151 | if path: 152 | fh = logging.FileHandler(path, mode='w') 153 | fh.setFormatter(formatter) 154 | logger.addHandler(fh) 155 | 156 | 157 | class RolloutSaver: 158 | def __init__(self, directory: str, tag: str): 159 | self.directory = directory 160 | self.tag = tag 161 | self._suffix = '.pkl' 162 | 163 | def save(self, obj: object, num_steps: int, info: str): 164 | added = f'steps-{num_steps}' 165 | 166 | path = os.path.join(self.directory, self.tag + '_' + added + '_' + info + self._suffix) 167 | logging.debug(f'Saving rollout: {path}') 168 | with open(path, mode='wb') as f: 169 | pickle.dump(obj, f) 170 | 171 | 172 | class InfoSaver: 173 | def __init__(self, directory: str, tag: str): 174 | self.directory = directory 175 | self.tag = tag 176 | self._suffix = '.txt' 177 | 178 | def save(self, obj: object, name: str): 179 | path = os.path.join(self.directory, self.tag + '_' + name + self._suffix) 180 | logging.debug(f'Saving info: {path}') 181 | with open(path, mode='a') as f: 182 | f.write(json.dumps(obj)) 183 | f.write('\n') 184 | 185 | 186 | def init_device(device_str: str) -> torch.device: 187 | if device_str == 'cuda': 188 | assert (torch.cuda.is_available()), 'No CUDA device available!' 
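        # torch.cuda.current_device() below reports the index of the default CUDA device,
        # and torch.cuda.init() eagerly initializes the CUDA context.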
189 | logging.info('CUDA Device: {}'.format(torch.cuda.current_device())) 190 | torch.cuda.init() 191 | return torch.device('cuda') 192 | else: 193 | logging.info('Using CPU') 194 | return torch.device('cpu') 195 | 196 | 197 | def get_optimizer(name: str, learning_rate: float, parameters: Iterable[torch.Tensor]) -> Optimizer: 198 | if name == 'adam': 199 | amsgrad = False 200 | elif name == 'amsgrad': 201 | amsgrad = True 202 | else: 203 | raise RuntimeError(f"Unknown optimizer '{name}'") 204 | 205 | return Adam(parameters, lr=learning_rate, amsgrad=amsgrad) 206 | -------------------------------------------------------------------------------- /molgym/agents/covariant/so3_tools.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | from cormorant.so3_lib import SO3Vec 6 | 7 | 8 | def generate_fibonacci_grid(n: int) -> np.ndarray: 9 | # Based on: http://extremelearning.com.au/how-to-evenly-distribute-points-on-a-sphere-more-effectively-than-the-canonical-fibonacci-lattice/ 10 | golden_ratio = (1 + 5**0.5) / 2 11 | offset = 0.5 12 | 13 | index = np.arange(0, n) 14 | theta = np.arccos(1 - 2 * (index + offset) / n) 15 | phi = 2 * np.pi * index / golden_ratio 16 | 17 | theta_phi = np.stack([theta, phi], axis=-1) 18 | 19 | return spherical_to_cartesian(theta_phi) 20 | 21 | 22 | def spherical_to_cartesian(theta_phi: np.ndarray) -> np.ndarray: 23 | theta, phi = theta_phi[..., 0], theta_phi[..., 1] 24 | x = np.sin(theta) * np.cos(phi) 25 | y = np.sin(theta) * np.sin(phi) 26 | z = np.cos(theta) 27 | return np.stack([x, y, z], axis=-1) 28 | 29 | 30 | def cartesian_to_spherical(pos: np.ndarray) -> np.ndarray: 31 | theta_phi = np.empty(shape=pos.shape[:-1] + (2, )) 32 | 33 | x, y, z = pos[..., 0], pos[..., 1], pos[..., 2] 34 | r = np.linalg.norm(pos, axis=-1) 35 | theta_phi[..., 0] = np.arccos(z / r) # theta 36 | theta_phi[..., 1] = np.arctan2(y, x) # phi 37 | 38 | return theta_phi 39 | 40 | 41 | def complex_product(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 42 | a_r, a_i = a.unbind(-1) 43 | b_r, b_i = b.unbind(-1) 44 | return torch.stack([a_r * b_r - a_i * b_i, a_i * b_r + a_r * b_i], dim=-1) 45 | 46 | 47 | def sum_product_alms_ylms(a_lms: SO3Vec, y_lms: SO3Vec) -> torch.Tensor: 48 | # Dimensions of SO3Vec's for each ell: (batches, taus, ms, 2) 49 | assert a_lms.ells == y_lms.ells 50 | 51 | summands = [] 52 | for ell in a_lms.ells: 53 | product = complex_product(a_lms[ell], y_lms[ell]) 54 | summand = torch.einsum('...tmx->...x', product) # sum over tau and m 55 | summands.append(summand) 56 | 57 | # sum over ell's 58 | return torch.sum(torch.stack(summands, dim=0), dim=0) # (..., batches, 2) 59 | 60 | 61 | def get_normalization_constant(a_lms: SO3Vec) -> torch.Tensor: 62 | # Dimensions of SO3Vec's for each ell: (batches, taus, ms, 2) 63 | summands = [] 64 | for ell in a_lms.ells: 65 | a_lm = torch.einsum('...btmx->...bmx', a_lms[ell]) # sum over tau's 66 | squared = torch.square(a_lm) 67 | item = torch.einsum('...bmx->...b', squared) # sum over m's and real and imaginary components 68 | summands.append(item) 69 | 70 | return torch.sum(torch.stack(summands, dim=0), dim=0) # sum over ell's 71 | 72 | 73 | def normalize_alms(a_lms: SO3Vec) -> SO3Vec: 74 | # Normalize a_lms such that: 75 | # \sum_\ell \sum_m | a_lm |^2 = 1 76 | k = get_normalization_constant(a_lms) # [batches] 77 | clamped_k = k.clamp(min=1e-10) 78 | sqrt_k = torch.sqrt(clamped_k).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) # 
[batches, 1, 1, 1] 79 | return SO3Vec([part / sqrt_k for part in a_lms]) 80 | 81 | 82 | def estimate_alms(y_lms_conj: SO3Vec) -> SO3Vec: 83 | # Dimensions of SO3Vec's for each ell: (batches, taus, ms, 2) 84 | 85 | # Compute mean over samples 86 | means = [] 87 | for ell in y_lms_conj.ells: 88 | # select all batch dimensions 89 | dim = list(range(len(y_lms_conj[ell].shape) - 3)) 90 | means.append(torch.mean(y_lms_conj[ell], dim=dim, keepdim=True)) 91 | return SO3Vec(means) 92 | 93 | 94 | def concat_so3vecs(so3vecs: List[SO3Vec]) -> SO3Vec: 95 | # Concat SO3Vecs along batch dimension 96 | # Dimensions of SO3Vec's for each ell: (batches, taus, ms, 2) 97 | 98 | # Ensure that all SO3 vectors are of the same kind 99 | assert all(so3vec.ells == so3vecs[0].ells for so3vec in so3vecs) 100 | 101 | return SO3Vec(list(map(lambda tensors: torch.cat(tensors, dim=0), zip(*so3vecs)))) 102 | 103 | 104 | def unsqueeze_so3vec(vec: SO3Vec, dim: int) -> SO3Vec: 105 | return SO3Vec([t.unsqueeze(dim) for t in vec]) 106 | 107 | 108 | def select_atomic_covariats(vec: SO3Vec, focus: torch.Tensor) -> SO3Vec: 109 | # vec (per ell): [batches, atoms, taus, ms, 2] 110 | # focus: [batches, atoms] 111 | vectors = [] 112 | for ell in vec.ells: 113 | vectors.append(torch.einsum('ba,batmx->btmx', focus, vec[ell])) # type: ignore 114 | 115 | return SO3Vec(vectors) # (batches, taus, ms, 2) 116 | 117 | 118 | def select_taus(vec: SO3Vec, indices: torch.Tensor) -> SO3Vec: 119 | vectors = [] 120 | # vec: (..., taus, ms, 2) 121 | for ell in vec.ells: 122 | gather_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, (2 * ell + 1), 2) 123 | vectors.append(torch.gather(vec[ell], dim=1, index=gather_indices)) 124 | 125 | return SO3Vec(vectors) # (..., sliced_taus, ms, 2) 126 | 127 | 128 | def select_atomic_invariats(invariats: torch.Tensor, focus: torch.Tensor) -> torch.Tensor: 129 | # invariats: [batches, atoms, feats] 130 | # focus: [batches, atoms] 131 | # return: [batches, feats] 132 | return torch.einsum('ba,baf->bf', focus, invariats) # type: ignore 133 | 134 | 135 | def select_element(vec: SO3Vec, element_oh: torch.Tensor) -> SO3Vec: 136 | # vec (per ell): [batches, taus, ms, 2] 137 | # element_oh: [batches, taus] 138 | tensors = [] 139 | for ell in vec.ells: 140 | t = torch.einsum('bt,btmx->bmx', element_oh, vec[ell]) # type: ignore # [batches, ms, 2] 141 | t = t.unsqueeze(dim=-3) # [batches, 1, ms, 2] 142 | tensors.append(t) 143 | 144 | return SO3Vec(tensors) # [batches, 1, ms, 2] 145 | 146 | 147 | class AtomicScalars(torch.nn.Module): 148 | """ 149 | Based on Cormorant's GetScalarsAtom class. 150 | Construct a set of scalar feature vectors for each atom by using the 151 | covariant atom :class:`SO3Vec` representations. 
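    For each ell, the collected invariants are the ell = 0 component and, if full_scalars is True,
    the SO(3)-invariant self-product and squared norm of every ell component, concatenated along
    the channel (tau) dimension.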
152 | """ 153 | def __init__(self, maxl: int, full_scalars=True, device=None, dtype=torch.float): 154 | super().__init__() 155 | 156 | self.device = device 157 | self.dtype = dtype 158 | 159 | self.maxl = maxl 160 | 161 | signs = [torch.pow(-1, torch.arange(-m, m + 1)) for m in range(self.maxl + 1)] 162 | signs = [torch.stack([s, -s], dim=-1) for s in signs] 163 | self.signs = [s.to(device=self.device, dtype=self.dtype) for s in signs] # (ms, 2) 164 | 165 | self.full_scalars = full_scalars 166 | 167 | def get_output_dim(self, channels: int) -> int: 168 | if self.full_scalars: 169 | return (self.maxl + 2) * channels * 2 170 | else: 171 | return channels * 2 172 | 173 | def forward(self, vec: SO3Vec) -> torch.Tensor: 174 | # Selection of invariant part 175 | scalars = [vec[0]] # (..., taus, 1, 2) 176 | 177 | if self.full_scalars: 178 | # Scalar product with itself 179 | scalars_prod = [(sign * part * part.flip(-2)).sum(dim=(-1, -2), keepdim=True) 180 | for part, sign in zip(vec, self.signs)] # (..., taus, 1, 1) 181 | 182 | # SO3 invariant norm 183 | scalars_norm = [(part * part).sum(dim=(-1, -2), keepdim=True) for part in vec] # (..., taus, 1, 1) 184 | 185 | # Put invariant components together 186 | # (..., taus, 1, 2) 187 | scalars += [torch.cat([s_prod, s_norm], dim=-1) for s_prod, s_norm in zip(scalars_prod, scalars_norm)] 188 | 189 | # Concat parts together along tau dimension 190 | scalars_cat = torch.cat(scalars, dim=-3) # (..., x * taus, 1, 2) 191 | 192 | return scalars_cat.flatten(start_dim=-3) # (..., output_dim) 193 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | # Align closing bracket with visual indentation. 3 | align_closing_bracket_with_visual_indent=True 4 | 5 | # Allow dictionary keys to exist on multiple lines. For example: 6 | # 7 | # x = { 8 | # ('this is the first element of a tuple', 9 | # 'this is the second element of a tuple'): 10 | # value, 11 | # } 12 | allow_multiline_dictionary_keys=False 13 | 14 | # Allow lambdas to be formatted on more than one line. 15 | allow_multiline_lambdas=False 16 | 17 | # Allow splits before the dictionary value. 18 | allow_split_before_dict_value=True 19 | 20 | # Number of blank lines surrounding top-level function and class 21 | # definitions. 22 | blank_lines_around_top_level_definition=2 23 | 24 | # Insert a blank line before a class-level docstring. 25 | blank_line_before_class_docstring=False 26 | 27 | # Insert a blank line before a module docstring. 28 | blank_line_before_module_docstring=False 29 | 30 | # Insert a blank line before a 'def' or 'class' immediately nested 31 | # within another 'def' or 'class'. For example: 32 | # 33 | # class Foo: 34 | # # <------ this blank line 35 | # def method(): 36 | # ... 37 | blank_line_before_nested_class_or_def=False 38 | 39 | # Do not split consecutive brackets. Only relevant when 40 | # dedent_closing_brackets is set. For example: 41 | # 42 | # call_func_that_takes_a_dict( 43 | # { 44 | # 'key1': 'value1', 45 | # 'key2': 'value2', 46 | # } 47 | # ) 48 | # 49 | # would reformat to: 50 | # 51 | # call_func_that_takes_a_dict({ 52 | # 'key1': 'value1', 53 | # 'key2': 'value2', 54 | # }) 55 | coalesce_brackets=False 56 | 57 | # The column limit. 58 | column_limit=120 59 | 60 | # The style for continuation alignment. Possible values are: 61 | # 62 | # - SPACE: Use spaces for continuation alignment. This is default behavior. 
63 | # - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns 64 | # (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs) for continuation 65 | # alignment. 66 | # - LESS: Slightly left if cannot vertically align continuation lines with 67 | # indent characters. 68 | # - VALIGN-RIGHT: Vertically align continuation lines with indent 69 | # characters. Slightly right (one more indent character) if cannot 70 | # vertically align continuation lines with indent characters. 71 | # 72 | # For options FIXED, and VALIGN-RIGHT are only available when USE_TABS is 73 | # enabled. 74 | continuation_align_style=SPACE 75 | 76 | # Indent width used for line continuations. 77 | continuation_indent_width=4 78 | 79 | # Put closing brackets on a separate line, dedented, if the bracketed 80 | # expression can't fit in a single line. Applies to all kinds of brackets, 81 | # including function definitions and calls. For example: 82 | # 83 | # config = { 84 | # 'key1': 'value1', 85 | # 'key2': 'value2', 86 | # } # <--- this bracket is dedented and on a separate line 87 | # 88 | # time_series = self.remote_client.query_entity_counters( 89 | # entity='dev3246.region1', 90 | # key='dns.query_latency_tcp', 91 | # transform=Transformation.AVERAGE(window=timedelta(seconds=60)), 92 | # start_ts=now()-timedelta(days=3), 93 | # end_ts=now(), 94 | # ) # <--- this bracket is dedented and on a separate line 95 | dedent_closing_brackets=False 96 | 97 | # Place each dictionary entry onto its own line. 98 | each_dict_entry_on_separate_line=True 99 | 100 | # The regex for an i18n comment. The presence of this comment stops 101 | # reformatting of that line, because the comments are required to be 102 | # next to the string they translate. 103 | i18n_comment= 104 | 105 | # The i18n function call names. The presence of this function stops 106 | # reformattting on that line, because the string it has cannot be moved 107 | # away from the i18n comment. 108 | i18n_function_call= 109 | 110 | # Indent the dictionary value if it cannot fit on the same line as the 111 | # dictionary key. For example: 112 | # 113 | # config = { 114 | # 'key1': 115 | # 'value1', 116 | # 'key2': value1 + 117 | # value2, 118 | # } 119 | indent_dictionary_value=False 120 | 121 | # The number of columns to use for indentation. 122 | indent_width=4 123 | 124 | # Join short lines into one line. E.g., single line 'if' statements. 125 | join_multiple_lines=True 126 | 127 | # Do not include spaces around selected binary operators. For example: 128 | # 129 | # 1 + 2 * 3 - 4 / 5 130 | # 131 | # will be formatted as follows when configured with a value "*,/": 132 | # 133 | # 1 + 2*3 - 4/5 134 | # 135 | no_spaces_around_selected_binary_operators=set() 136 | 137 | # Use spaces around default or named assigns. 138 | spaces_around_default_or_named_assign=False 139 | 140 | # Use spaces around the power operator. 141 | spaces_around_power_operator=False 142 | 143 | # The number of spaces required before a trailing comment. 144 | spaces_before_comment=2 145 | 146 | # Insert a space between the ending comma and closing bracket of a list, 147 | # etc. 148 | space_between_ending_comma_and_closing_bracket=True 149 | 150 | # Split before arguments 151 | split_all_comma_separated_values=False 152 | 153 | # Split before arguments if the argument list is terminated by a 154 | # comma. 155 | split_arguments_when_comma_terminated=False 156 | 157 | # Set to True to prefer splitting before '&', '|' or '^' rather than 158 | # after. 
159 | split_before_bitwise_operator=True 160 | 161 | # Split before the closing bracket if a list or dict literal doesn't fit on 162 | # a single line. 163 | split_before_closing_bracket=True 164 | 165 | # Split before a dictionary or set generator (comp_for). For example, note 166 | # the split before the 'for': 167 | # 168 | # foo = { 169 | # variable: 'Hello world, have a nice day!' 170 | # for variable in bar if variable != 42 171 | # } 172 | split_before_dict_set_generator=True 173 | 174 | # Split after the opening paren which surrounds an expression if it doesn't 175 | # fit on a single line. 176 | split_before_expression_after_opening_paren=False 177 | 178 | # If an argument / parameter list is going to be split, then split before 179 | # the first argument. 180 | split_before_first_argument=False 181 | 182 | # Set to True to prefer splitting before 'and' or 'or' rather than 183 | # after. 184 | split_before_logical_operator=True 185 | 186 | # Split named assignments onto individual lines. 187 | split_before_named_assigns=True 188 | 189 | # Set to True to split list comprehensions and generators that have 190 | # non-trivial expressions and multiple clauses before each of these 191 | # clauses. For example: 192 | # 193 | # result = [ 194 | # a_long_var + 100 for a_long_var in xrange(1000) 195 | # if a_long_var % 10] 196 | # 197 | # would reformat to something like: 198 | # 199 | # result = [ 200 | # a_long_var + 100 201 | # for a_long_var in xrange(1000) 202 | # if a_long_var % 10] 203 | split_complex_comprehension=False 204 | 205 | # The penalty for splitting right after the opening bracket. 206 | split_penalty_after_opening_bracket=30 207 | 208 | # The penalty for splitting the line after a unary operator. 209 | split_penalty_after_unary_operator=10000 210 | 211 | # The penalty for splitting right before an if expression. 212 | split_penalty_before_if_expr=0 213 | 214 | # The penalty of splitting the line around the '&', '|', and '^' 215 | # operators. 216 | split_penalty_bitwise_operator=300 217 | 218 | # The penalty for splitting a list comprehension or generator 219 | # expression. 220 | split_penalty_comprehension=120 221 | 222 | # The penalty for characters over the column limit. 223 | split_penalty_excess_character=4500 224 | 225 | # The penalty incurred by adding a line split to the unwrapped line. The 226 | # more line splits added the higher the penalty. 227 | split_penalty_for_added_line_split=30 228 | 229 | # The penalty of splitting a list of "import as" names. For example: 230 | # 231 | # from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, 232 | # long_argument_2, 233 | # long_argument_3) 234 | # 235 | # would reformat to something like: 236 | # 237 | # from a_very_long_or_indented_module_name_yada_yad import ( 238 | # long_argument_1, long_argument_2, long_argument_3) 239 | split_penalty_import_names=0 240 | 241 | # The penalty of splitting the line around the 'and' and 'or' 242 | # operators. 243 | split_penalty_logical_operator=300 244 | 245 | # Use the Tab character for indentation. 
246 | use_tabs=False 247 | 248 | -------------------------------------------------------------------------------- /molgym/agents/covariant/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from cormorant.cg_lib import CGModule, SphericalHarmonicsRel, CGProduct 5 | from cormorant.models.cormorant_cg import CormorantCG 6 | from cormorant.models.cormorant_qm9 import expand_var_list 7 | from cormorant.nn import NoLayer, RadialFilters, CatMixReps, InputLinear 8 | from cormorant.so3_lib import SO3Vec 9 | 10 | 11 | class Cormorant(CGModule): 12 | def __init__( 13 | self, 14 | maxl, 15 | max_sh, 16 | num_cg_levels, 17 | num_channels, 18 | num_species, 19 | cutoff_type, 20 | hard_cut_rad, 21 | soft_cut_rad, 22 | soft_cut_width, 23 | weight_init, 24 | level_gain, 25 | charge_power, 26 | basis_set, 27 | charge_scale, 28 | bag_scale, 29 | device=None, 30 | dtype=None, 31 | cg_dict=None, 32 | ) -> None: 33 | # Parameters 34 | level_gain = expand_var_list(level_gain, num_cg_levels) 35 | hard_cut_rad = expand_var_list(hard_cut_rad, num_cg_levels) 36 | soft_cut_rad = expand_var_list(soft_cut_rad, num_cg_levels) 37 | soft_cut_width = expand_var_list(soft_cut_width, num_cg_levels) 38 | maxl = expand_var_list(maxl, num_cg_levels) 39 | max_sh = expand_var_list(max_sh, num_cg_levels) 40 | num_channels = expand_var_list(num_channels, num_cg_levels + 1) 41 | 42 | super().__init__(maxl=max(maxl + max_sh), device=device, dtype=dtype, cg_dict=cg_dict) 43 | 44 | self.num_cg_levels = num_cg_levels 45 | self.num_channels = num_channels 46 | self.charge_power = charge_power 47 | self.charge_scale = charge_scale 48 | self.bag_scale = bag_scale 49 | self.num_species = num_species 50 | 51 | # Set up spherical harmonics 52 | self.sph_harms = SphericalHarmonicsRel(maxl=max(max_sh), 53 | conj=True, 54 | device=self.device, 55 | dtype=self.dtype, 56 | cg_dict=self.cg_dict) 57 | 58 | # Set up position functions, now independent of spherical harmonics 59 | self.rad_funcs = RadialFilters( 60 | max_sh=max_sh, 61 | basis_set=basis_set, 62 | num_channels_out=num_channels, 63 | num_levels=num_cg_levels, 64 | device=self.device, 65 | dtype=self.dtype, 66 | ) 67 | tau_pos = self.rad_funcs.tau 68 | 69 | num_scalars_in = self.num_species * (self.charge_power + 1) + self.num_species 70 | num_scalars_out = num_channels[0] 71 | 72 | self.input_func_atom = InputLinear(num_scalars_in, num_scalars_out, device=self.device, dtype=self.dtype) 73 | self.input_func_edge = NoLayer() 74 | 75 | tau_in_atom = self.input_func_atom.tau 76 | tau_in_edge = self.input_func_edge.tau 77 | 78 | self.cormorant_cg = CormorantCG(maxl=maxl, 79 | max_sh=max_sh, 80 | tau_in_atom=tau_in_atom, 81 | tau_in_edge=tau_in_edge, 82 | tau_pos=tau_pos, 83 | num_cg_levels=num_cg_levels, 84 | num_channels=num_channels, 85 | level_gain=level_gain, 86 | weight_init=weight_init, 87 | cutoff_type=cutoff_type, 88 | hard_cut_rad=hard_cut_rad, 89 | soft_cut_rad=soft_cut_rad, 90 | soft_cut_width=soft_cut_width, 91 | cat=True, 92 | gaussian_mask=False, 93 | device=self.device, 94 | dtype=self.dtype, 95 | cg_dict=self.cg_dict) 96 | 97 | def forward(self, data) -> SO3Vec: 98 | # Get and prepare the data 99 | atom_scalars, atom_mask, edge_scalars, edge_mask, atom_positions = self.prepare_input(data) 100 | 101 | # Calculate spherical harmonics and radial functions 102 | spherical_harmonics, norms = self.sph_harms(atom_positions, atom_positions) 103 | rad_func_levels = self.rad_funcs(norms, 
edge_mask * (norms > 0)) 104 | 105 | # Prepare the input reps for both the atom and edge network 106 | atom_reps_in = self.input_func_atom(atom_scalars, atom_mask, edge_scalars, edge_mask, norms) 107 | edge_net_in = self.input_func_edge(atom_scalars, atom_mask, edge_scalars, edge_mask, norms) 108 | 109 | # Clebsch-Gordan layers central to the network 110 | atoms_all, edges_all = self.cormorant_cg(atom_reps_in, atom_mask, edge_net_in, edge_mask, rad_func_levels, 111 | norms, spherical_harmonics) 112 | 113 | # Return last atomic layer 114 | return atoms_all[-1] 115 | 116 | def prepare_input(self, data) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 117 | atom_positions = data['positions'].to(self.device, self.dtype) 118 | one_hot = data['one_hot'].to(self.device, self.dtype) 119 | charges = data['charges'].to(self.device, self.dtype) 120 | 121 | atom_mask = data['atom_mask'].to(self.device) 122 | edge_mask = data['edge_mask'].to(self.device) 123 | 124 | charge_tensor = (charges.unsqueeze(-1) / self.charge_scale).pow( 125 | torch.arange(self.charge_power + 1, device=self.device, dtype=self.dtype)) 126 | charge_tensor = charge_tensor.view(charges.shape + (1, self.charge_power + 1)) 127 | charge_tensor = (one_hot.unsqueeze(-1) * charge_tensor).view(charges.shape[:2] + (-1, )) 128 | 129 | bag_tiled = (data['bags'] / self.bag_scale).unsqueeze(1) # (batches, 1, feats) 130 | bag_tiled = bag_tiled.expand(charge_tensor.shape[:-1] + (-1, )) # (batches, atoms, feats) 131 | atom_scalars = torch.cat([charge_tensor, bag_tiled], dim=-1) 132 | 133 | edge_scalars = torch.tensor([]) 134 | 135 | return atom_scalars, atom_mask, edge_scalars, edge_mask, atom_positions 136 | 137 | 138 | class CormorantMixer(CGModule): 139 | def __init__(self, 140 | tau_in, 141 | tau_other, 142 | maxl, 143 | num_channels, 144 | level_gain, 145 | weight_init, 146 | device=None, 147 | dtype=None, 148 | cg_dict=None) -> None: 149 | super().__init__(maxl=maxl, device=device, dtype=dtype, cg_dict=cg_dict) 150 | 151 | self.tau_in = tau_in 152 | self.tau_other = tau_other 153 | 154 | # Operations linear in input reps 155 | self.cg_aggregate = CGProduct(self.tau_other, 156 | self.tau_in, 157 | maxl=self.maxl, 158 | device=self.device, 159 | dtype=self.dtype, 160 | cg_dict=self.cg_dict) 161 | tau_ag = list(self.cg_aggregate.tau) 162 | 163 | self.cg_power = CGProduct(tau_ag, 164 | tau_ag, 165 | maxl=self.maxl, 166 | device=self.device, 167 | dtype=self.dtype, 168 | cg_dict=self.cg_dict) 169 | tau_sq = list(self.cg_power.tau) 170 | 171 | self.cat_mix = CatMixReps([tau_ag, tau_sq, self.tau_in], 172 | num_channels, 173 | maxl=self.maxl, 174 | weight_init=weight_init, 175 | gain=level_gain, 176 | device=self.device, 177 | dtype=self.dtype) 178 | self.tau = self.cat_mix.tau 179 | 180 | def forward(self, atom_reps, other_reps): 181 | # Aggregate information based upon other reps 182 | reps_ag = self.cg_aggregate(other_reps, atom_reps) 183 | 184 | # CG non-linearity 185 | reps_sq = self.cg_power(reps_ag, reps_ag) 186 | 187 | # Concatenate and mix results 188 | reps_out = self.cat_mix([reps_ag, reps_sq, atom_reps]) 189 | 190 | return reps_out 191 | -------------------------------------------------------------------------------- /tests/agents/covariant/test_spherical_distr.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import torch 5 | from cormorant.cg_lib import SphericalHarmonics 6 | 7 | from 
molgym.agents.covariant.so3_tools import (spherical_to_cartesian, estimate_alms, concat_so3vecs, 8 | cartesian_to_spherical, generate_fibonacci_grid) 9 | from molgym.agents.covariant.spherical_dists import SphericalUniform, SO3Distribution, ExpSO3Distribution 10 | from molgym.tools.util import to_numpy 11 | 12 | 13 | class SphericalUniformTest(TestCase): 14 | def setUp(self): 15 | self.dist = SphericalUniform() 16 | 17 | def test_shape(self): 18 | num_samples = 1000 19 | samples = self.dist.sample(torch.Size((num_samples, ))) 20 | self.assertTrue(samples.shape == (num_samples, 3)) 21 | 22 | def test_min_max(self): 23 | self.dist = SphericalUniform(batch_shape=(3, )) 24 | self.assertTrue(self.dist.get_max_prob().shape == (3, )) 25 | 26 | def test_distance(self): 27 | num_samples = 1000 28 | samples = self.dist.sample(torch.Size((num_samples, ))) 29 | self.assertTrue(np.allclose(samples.norm(dim=-1), 1.)) 30 | 31 | def test_mean(self): 32 | torch.manual_seed(1) 33 | num_samples = 200_000 34 | self.dist = SphericalUniform() 35 | samples = self.dist.sample(torch.Size((num_samples, ))) 36 | self.assertAlmostEqual(samples.mean(0).norm().item(), 0, places=2) 37 | 38 | def test_argmax(self): 39 | dist = SphericalUniform(batch_shape=torch.Size((3, ))) 40 | arg_maxes = dist.argmax() 41 | self.assertEqual(arg_maxes.shape, dist.batch_shape + dist.event_shape) 42 | 43 | 44 | class SphericalDistributionTest(TestCase): 45 | def setUp(self): 46 | self.maxl = 3 47 | self.sphs = SphericalHarmonics(maxl=self.maxl, sh_norm='qm') 48 | sphs_conj = SphericalHarmonics(maxl=self.maxl, conj=True, sh_norm='qm') 49 | 50 | # Generate some reference point(s) 51 | phi_refs = np.array([ 52 | np.pi / 2, 53 | -np.pi / 2, 54 | ]) 55 | theta_refs = np.pi / 2 * np.ones_like(phi_refs) 56 | theta_phi_refs = np.stack([theta_refs, phi_refs], axis=-1) 57 | xyz_refs = spherical_to_cartesian(theta_phi_refs) 58 | y_lms_conj = sphs_conj.forward(torch.tensor(xyz_refs, dtype=torch.float)) 59 | self.a_lms_1 = estimate_alms(y_lms_conj) 60 | 61 | # Another set of a_lms 62 | phi_refs = np.array([np.pi / 3]) 63 | theta_refs = np.pi / 3 * np.ones_like(phi_refs) 64 | theta_phi_refs = np.stack([theta_refs, phi_refs], axis=-1) 65 | xyz_refs = spherical_to_cartesian(theta_phi_refs) 66 | y_lms_conj = sphs_conj.forward(torch.tensor(xyz_refs, dtype=torch.float)) 67 | self.a_lms_2 = estimate_alms(y_lms_conj) 68 | 69 | def test_max(self): 70 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1]) 71 | so3_distr = SO3Distribution(a_lms=a_lms, sphs=self.sphs) 72 | self.assertEqual(so3_distr.get_max_prob().shape, (3, )) 73 | 74 | def test_sample(self): 75 | torch.manual_seed(1) 76 | samples_shape = (2048, ) 77 | 78 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2]) 79 | so3_distr = SO3Distribution(a_lms=a_lms, sphs=self.sphs) 80 | samples = so3_distr.sample(samples_shape) 81 | 82 | self.assertEqual(samples.shape, samples_shape + so3_distr.batch_shape + so3_distr.event_shape) 83 | 84 | angles = cartesian_to_spherical(to_numpy(samples)) # [S, B, 2] 85 | mean_angles = np.mean(angles, axis=0) # [B, 2] 86 | 87 | self.assertEqual(mean_angles.shape, (2, 2)) 88 | 89 | so3_distr_1 = SO3Distribution(a_lms=self.a_lms_1, sphs=self.sphs) 90 | samples_1 = so3_distr_1.sample(samples_shape) 91 | angles_1 = cartesian_to_spherical(to_numpy(samples_1)) # [S, 1, 2] 92 | mean_angles_1 = np.mean(angles_1, axis=0) # [1, 2] 93 | 94 | so3_distr_2 = SO3Distribution(a_lms=self.a_lms_2, sphs=self.sphs) 95 | samples_2 = so3_distr_2.sample(samples_shape) 96 | 
angles_2 = cartesian_to_spherical(to_numpy(samples_2)) # [S, 1, 2] 97 | mean_angles_2 = np.mean(angles_2, axis=0) # [1, 2] 98 | 99 | # Assert that batching does not affect the result 100 | self.assertTrue(np.allclose(mean_angles[0], mean_angles_1, atol=0.1)) 101 | self.assertTrue(np.allclose(mean_angles[1], mean_angles_2, atol=0.1)) 102 | 103 | def test_prob(self): 104 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1]) 105 | so3_distr = SO3Distribution(a_lms=a_lms, sphs=self.sphs) 106 | samples = torch.tensor([ 107 | [1.0, 0.0, 0.0], 108 | [0.0, 0.0, 1.0], 109 | [0.0, 1.0, 0.0], 110 | ]) 111 | 112 | self.assertEqual(so3_distr.log_prob(samples).shape, (3, )) 113 | self.assertEqual(so3_distr.log_prob(samples[[0]]).shape, (3, )) 114 | 115 | with self.assertRaises(RuntimeError): 116 | so3_distr.log_prob(samples[:2]) 117 | 118 | def test_max_sample(self): 119 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2]) 120 | so3_distr = SO3Distribution(a_lms=a_lms, sphs=self.sphs, dtype=torch.float) 121 | samples = so3_distr.argmax(count=17) 122 | self.assertEqual(samples.shape, (2, 3)) 123 | 124 | def test_normalization(self): 125 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2]) 126 | so3_distr = SO3Distribution(a_lms=a_lms, sphs=self.sphs, dtype=torch.float) 127 | grid = generate_fibonacci_grid(n=1024) 128 | grid_t = torch.tensor(grid, dtype=torch.float).unsqueeze(1) 129 | probs = so3_distr.prob(grid_t) 130 | integral = 4 * np.pi * torch.mean(probs, dim=0) 131 | self.assertTrue(np.allclose(to_numpy(integral), 1.0)) 132 | 133 | 134 | class ExpSphericalDistributionTest(TestCase): 135 | def setUp(self) -> None: 136 | self.maxl = 3 137 | self.sphs = SphericalHarmonics(maxl=self.maxl, sh_norm='qm') 138 | sphs_conj = SphericalHarmonics(maxl=self.maxl, conj=True, sh_norm='qm') 139 | 140 | # Generate some reference point(s) 141 | phi_refs = np.array([ 142 | np.pi / 2, 143 | -np.pi / 2, 144 | ]) 145 | theta_refs = np.pi / 2 * np.ones_like(phi_refs) 146 | theta_phi_refs = np.stack([theta_refs, phi_refs], axis=-1) 147 | xyz_refs = spherical_to_cartesian(theta_phi_refs) 148 | y_lms_conj = sphs_conj.forward(torch.tensor(xyz_refs, dtype=torch.float)) 149 | self.a_lms_1 = estimate_alms(y_lms_conj) 150 | 151 | # Another set of a_lms 152 | phi_refs = np.array([np.pi / 3]) 153 | theta_refs = np.pi / 3 * np.ones_like(phi_refs) 154 | theta_phi_refs = np.stack([theta_refs, phi_refs], axis=-1) 155 | xyz_refs = spherical_to_cartesian(theta_phi_refs) 156 | y_lms_conj = sphs_conj.forward(torch.tensor(xyz_refs, dtype=torch.float)) 157 | self.a_lms_2 = estimate_alms(y_lms_conj) 158 | 159 | def test_max(self): 160 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1]) 161 | distr = ExpSO3Distribution(a_lms=a_lms, sphs=self.sphs, beta=100) 162 | self.assertEqual(distr.get_max_log_prob().shape, (3, )) 163 | 164 | def test_sample(self): 165 | torch.manual_seed(1) 166 | samples_shape = (2048, ) 167 | 168 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2]) 169 | distr = ExpSO3Distribution(a_lms=a_lms, sphs=self.sphs, beta=100) 170 | samples = distr.sample(samples_shape) 171 | 172 | self.assertEqual(samples.shape, samples_shape + distr.batch_shape + distr.event_shape) 173 | 174 | angles = cartesian_to_spherical(to_numpy(samples)) # [S, B, 2] 175 | mean_angles = np.mean(angles, axis=0) # [B, 2] 176 | 177 | self.assertEqual(mean_angles.shape, (2, 2)) 178 | 179 | distr_1 = ExpSO3Distribution(a_lms=self.a_lms_1, sphs=self.sphs, beta=100) 180 | samples_1 = distr_1.sample(samples_shape) 181 | 
angles_1 = cartesian_to_spherical(to_numpy(samples_1)) # [S, 1, 2] 182 | mean_angles_1 = np.mean(angles_1, axis=0) # [1, 2] 183 | 184 | distr_2 = ExpSO3Distribution(a_lms=self.a_lms_2, sphs=self.sphs, beta=100) 185 | samples_2 = distr_2.sample(samples_shape) 186 | angles_2 = cartesian_to_spherical(to_numpy(samples_2)) # [S, 1, 2] 187 | mean_angles_2 = np.mean(angles_2, axis=0) # [1, 2] 188 | 189 | # Assert that batching does not affect the result 190 | self.assertTrue(np.allclose(mean_angles[0], mean_angles_1, atol=0.1)) 191 | self.assertTrue(np.allclose(mean_angles[1], mean_angles_2, atol=0.1)) 192 | 193 | def test_prob(self): 194 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1]) 195 | distr = ExpSO3Distribution(a_lms=a_lms, sphs=self.sphs, beta=100) 196 | samples = torch.tensor([ 197 | [1.0, 0.0, 0.0], 198 | [0.0, 0.0, 1.0], 199 | [0.0, 1.0, 0.0], 200 | ]) 201 | 202 | self.assertEqual(distr.log_prob(samples).shape, (3, )) 203 | self.assertEqual(distr.log_prob(samples[[0]]).shape, (3, )) 204 | 205 | with self.assertRaises(RuntimeError): 206 | distr.log_prob(samples[:2]) 207 | 208 | def test_max_sample(self): 209 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2]) 210 | distr = ExpSO3Distribution(a_lms=a_lms, sphs=self.sphs, dtype=torch.float, beta=100) 211 | samples = distr.argmax(count=17) 212 | self.assertEqual(samples.shape, (2, 3)) 213 | 214 | def test_normalization(self): 215 | a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2]) 216 | distr = ExpSO3Distribution(a_lms=a_lms, sphs=self.sphs, dtype=torch.float, beta=100) 217 | grid = generate_fibonacci_grid(n=1024) 218 | grid_t = torch.tensor(grid, dtype=torch.float).unsqueeze(1) 219 | probs = torch.exp(distr.log_prob(grid_t)) 220 | integral = 4 * np.pi * torch.mean(probs, dim=0) 221 | self.assertTrue(np.allclose(to_numpy(integral), 1.0, atol=5e-3)) 222 | -------------------------------------------------------------------------------- /molgym/environment.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import itertools 3 | import logging 4 | from typing import Tuple, List 5 | 6 | import ase.data 7 | import gym 8 | import numpy as np 9 | from ase import Atoms, Atom 10 | from scipy.spatial.qhull import ConvexHull, Delaunay 11 | 12 | from molgym.reward import InteractionReward 13 | from molgym.spaces import ActionSpace, ObservationSpace, ActionType, ObservationType, FormulaType 14 | from molgym.tools.util import remove_atom_from_formula, get_formula_size, zs_to_formula 15 | 16 | 17 | class AbstractMolecularEnvironment(gym.Env, abc.ABC): 18 | # Negative reward should be on the same order of magnitude as the positive ones. 
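    # (Rewards here are interaction energies in Hartree; the min_reward default of -0.6 below is the cutoff.)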
19 | # Memory agent on QM9: mean 0.26, std 0.13, min -0.54, max 1.23 (negative reward indeed possible 20 | # but avoidable and probably due to PM6) 21 | 22 | def __init__( 23 | self, 24 | reward: InteractionReward, 25 | observation_space: ObservationSpace, 26 | action_space: ActionSpace, 27 | min_atomic_distance=0.6, # Angstrom 28 | max_solo_distance=2.0, # Angstrom 29 | min_reward=-0.6, # Hartree 30 | seed=0, 31 | ): 32 | self.reward = reward 33 | self.observation_space = observation_space 34 | self.action_space = action_space 35 | 36 | self.random_state = np.random.RandomState(seed=seed) 37 | 38 | self.min_atomic_distance = min_atomic_distance 39 | self.max_solo_distance = max_solo_distance 40 | self.min_reward = min_reward 41 | 42 | self.current_atoms = Atoms() 43 | self.current_formula: FormulaType = tuple() 44 | 45 | @abc.abstractmethod 46 | def reset(self) -> ObservationType: 47 | raise NotImplementedError 48 | 49 | def step(self, action: ActionType) -> Tuple[ObservationType, float, bool, dict]: 50 | atomic_number_index, position = action 51 | atomic_number = self.action_space.zs[atomic_number_index] 52 | done = atomic_number == 0 53 | 54 | if done: 55 | return self.observation_space.build(self.current_atoms, self.current_formula), 0.0, done, {} 56 | 57 | new_atom = self.action_space.to_atom(action) 58 | if not self._is_valid_action(current_atoms=self.current_atoms, new_atom=new_atom): 59 | return ( 60 | self.observation_space.build(self.current_atoms, self.current_formula), 61 | self.min_reward, 62 | True, 63 | {}, 64 | ) 65 | 66 | reward, info = self._calculate_reward(new_atom) 67 | 68 | if reward < self.min_reward: 69 | done = True 70 | reward = self.min_reward 71 | 72 | self.current_atoms.append(new_atom) 73 | self.current_formula = remove_atom_from_formula(self.current_formula, atomic_number) 74 | 75 | # Check if state is terminal 76 | if self._is_terminal(): 77 | done = True 78 | 79 | return self.observation_space.build(self.current_atoms, self.current_formula), reward, done, info 80 | 81 | def _is_terminal(self) -> bool: 82 | return len(self.current_atoms) == self.observation_space.canvas_space.size or get_formula_size( 83 | self.current_formula) == 0 84 | 85 | def _is_valid_action(self, current_atoms: Atoms, new_atom: Atom) -> bool: 86 | if self._is_too_close(current_atoms, new_atom): 87 | return False 88 | 89 | return self._all_covered(current_atoms, new_atom) 90 | 91 | def _is_too_close(self, existing_atoms: Atoms, new_atom: Atom) -> bool: 92 | # Check distances between new and old atoms 93 | for existing_atom in existing_atoms: 94 | if np.linalg.norm(existing_atom.position - new_atom.position) < self.min_atomic_distance: 95 | logging.debug('Atoms are too close') 96 | return True 97 | 98 | return False 99 | 100 | def _calculate_reward(self, new_atom: Atom) -> Tuple[float, dict]: 101 | return self.reward.calculate(self.current_atoms, new_atom) 102 | 103 | def _all_covered(self, existing_atoms: Atoms, new_atom: Atom) -> bool: 104 | # Ensure that certain atoms are not too far away from the nearest heavy atom to avoid H2, F2,... 
formation 105 | candidates = ['H', 'F', 'Cl', 'Br'] 106 | if len(existing_atoms) == 0 or new_atom.symbol not in candidates: 107 | return True 108 | 109 | for existing_atom in existing_atoms: 110 | if existing_atom.symbol in candidates: 111 | continue 112 | 113 | distance = np.linalg.norm(existing_atom.position - new_atom.position) 114 | if distance < self.max_solo_distance: 115 | return True 116 | 117 | logging.debug('There is a single atom floating around') 118 | return False 119 | 120 | def render(self, mode='human'): 121 | pass 122 | 123 | def seed(self, seed=None) -> int: 124 | seed = seed or np.random.randint(int(1e5)) 125 | self.random_state = np.random.RandomState(seed) 126 | return seed 127 | 128 | 129 | class MolecularEnvironment(AbstractMolecularEnvironment): 130 | def __init__(self, formulas: List[FormulaType], *args, **kwargs): 131 | super().__init__(*args, **kwargs) 132 | 133 | self.formulas = formulas 134 | self.formula_cycle = itertools.cycle(self.formulas) 135 | self.reset() 136 | 137 | def reset(self) -> ObservationType: 138 | self.current_atoms = Atoms() 139 | self.current_formula = next(self.formula_cycle) 140 | return self.observation_space.build(self.current_atoms, self.current_formula) 141 | 142 | 143 | class ConstrainedMolecularEnvironment(MolecularEnvironment): 144 | def __init__(self, scaffold: Atoms, scaffold_z: int, *args, **kwargs): 145 | self.scaffold = scaffold 146 | self.scaffold_z = scaffold_z 147 | 148 | super().__init__(*args, **kwargs) 149 | 150 | def reset(self) -> ObservationType: 151 | self.current_atoms = self.scaffold.copy() 152 | self.current_formula = next(self.formula_cycle) 153 | return self.observation_space.build(self.current_atoms, self.current_formula) 154 | 155 | def _is_valid_action(self, current_atoms: Atoms, new_atom: Atom) -> bool: 156 | is_scaffold = list(ase.data.atomic_numbers[symbol] == self.scaffold_z for symbol in current_atoms.symbols) 157 | scaffold_atoms = current_atoms[is_scaffold] 158 | 159 | if not self._is_inside_scaffold(scaffold_positions=scaffold_atoms.positions, new_position=new_atom.position): 160 | logging.debug(f'Atom {new_atom} is not inside scaffold') 161 | return False 162 | 163 | # Make sure atom is not too close to _any_ other atom (also scaffold atoms) 164 | return super()._is_valid_action(current_atoms=current_atoms, new_atom=new_atom) 165 | 166 | @staticmethod 167 | def _is_inside_scaffold(scaffold_positions: np.ndarray, new_position: np.ndarray): 168 | hull = ConvexHull(scaffold_positions, incremental=False) 169 | vertices = scaffold_positions[hull.vertices] 170 | delaunay = Delaunay(vertices) 171 | return delaunay.find_simplex(new_position) >= 0 172 | 173 | def _calculate_reward(self, new_atom: Atom) -> Tuple[float, dict]: 174 | is_scaffold = list(ase.data.atomic_numbers[symbol] == self.scaffold_z for symbol in self.current_atoms.symbols) 175 | return self.reward.calculate(self.current_atoms[np.logical_not(is_scaffold)], new_atom) 176 | 177 | 178 | class RefillableMolecularEnvironment(AbstractMolecularEnvironment): 179 | def __init__(self, formulas: List[FormulaType], initial_structure: Atoms, num_refills: int, *args, **kwargs): 180 | super().__init__(*args, **kwargs) 181 | 182 | self.formulas = formulas 183 | self.atoms = initial_structure.copy() 184 | self.num_refills = num_refills 185 | self.formulas_cycle = itertools.cycle(self.formulas) 186 | 187 | self.current_refill_counter = 0 188 | self.reset() 189 | 190 | def _is_terminal(self) -> bool: 191 | if len(self.current_atoms) == 
self.observation_space.canvas_space.size: 192 | return True 193 | 194 | if get_formula_size(self.current_formula) == 0: 195 | if self.current_refill_counter < self.num_refills: 196 | self.current_formula = next(self.formulas_cycle) 197 | self.current_refill_counter += 1 198 | else: 199 | return True 200 | 201 | return False 202 | 203 | def reset(self) -> ObservationType: 204 | self.current_refill_counter = 0 205 | self.current_atoms = self.atoms.copy() 206 | self.current_formula = next(self.formulas_cycle) 207 | return self.observation_space.build(self.current_atoms, self.current_formula) 208 | 209 | 210 | class StochasticEnvironment(AbstractMolecularEnvironment): 211 | def __init__(self, formula: FormulaType, size_range: Tuple[int, int], *args, **kwargs): 212 | super().__init__(*args, **kwargs) 213 | 214 | self.formula = formula 215 | self.min_size, self.max_size = size_range 216 | 217 | formula_size = get_formula_size(formula) 218 | self.zs = [z for z, count in formula] 219 | self.z_probs = [count / formula_size for z, count in formula] 220 | 221 | self.z_to_bond_count = { 222 | 1: 1, 223 | 5: 3, 224 | 6: 4, 225 | 7: 3, 226 | 8: 2, 227 | 9: 1, 228 | } 229 | 230 | self.reset() 231 | 232 | def reset(self) -> ObservationType: 233 | self.current_atoms = Atoms() 234 | self.current_formula = self.sample_formula() 235 | while not self.is_valid_formula(self.current_formula): 236 | self.current_formula = self.sample_formula() 237 | 238 | return self.observation_space.build(self.current_atoms, self.current_formula) 239 | 240 | def sample_formula(self) -> FormulaType: 241 | if self.min_size < self.max_size: 242 | size = self.random_state.randint(low=self.min_size, high=self.max_size, size=1) 243 | else: 244 | size = self.max_size 245 | zs = np.random.choice(self.zs, size=size, replace=True, p=self.z_probs) 246 | return zs_to_formula(zs) 247 | 248 | def is_valid_formula(self, formula: FormulaType) -> bool: 249 | return sum(count * self.z_to_bond_count[z] for z, count in formula) % 2 == 0 250 | -------------------------------------------------------------------------------- /molgym/agents/covariant/spherical_dists.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC 3 | 4 | import numpy as np 5 | import quadpy 6 | import torch 7 | from cormorant.cg_lib import SphericalHarmonics 8 | from cormorant.so3_lib import SO3Vec 9 | from torch.distributions import Uniform 10 | from torch.distributions.distribution import Distribution 11 | 12 | from .so3_tools import sum_product_alms_ylms, generate_fibonacci_grid, normalize_alms 13 | 14 | 15 | class SphericalDistribution(Distribution, ABC): 16 | arg_constraints = {} # type: ignore 17 | has_rsample = False 18 | 19 | def __init__(self, batch_shape=torch.Size(), validate_args=None, device=None, dtype=torch.float) -> None: 20 | super().__init__(batch_shape, event_shape=torch.Size((3, )), validate_args=validate_args) 21 | self.device = device 22 | self.dtype = dtype 23 | 24 | def expand(self, batch_shape, _instance=None): 25 | new = self._get_checked_instance(SphericalDistribution, _instance) 26 | batch_shape = torch.Size(batch_shape) 27 | new.device = self.device 28 | new.dtype = self.dtype 29 | super(SphericalDistribution, new).__init__(batch_shape, event_shape=self.event_shape, validate_args=False) 30 | new._validate_args = self._validate_args 31 | return new 32 | 33 | @staticmethod 34 | def _spherical_to_cartesian(theta: torch.Tensor, phi: torch.Tensor) -> torch.Tensor: 35 | x = 
torch.sin(theta) * torch.cos(phi) 36 | y = torch.sin(theta) * torch.sin(phi) 37 | z = torch.cos(theta) 38 | return torch.stack([x, y, z], dim=-1) 39 | 40 | def argmax(self) -> torch.Tensor: 41 | raise NotImplementedError 42 | 43 | 44 | class SphericalUniform(SphericalDistribution): 45 | def __init__(self, batch_shape=torch.Size(), validate_args=None, device=None, dtype=torch.float) -> None: 46 | super().__init__(batch_shape, validate_args=validate_args, device=device, dtype=dtype) 47 | self.uniform_dist = Uniform(0.0, 1.0) 48 | 49 | def sample(self, sample_shape=torch.Size()) -> torch.Tensor: 50 | # Based on: http://corysimon.github.io/articles/uniformdistn-on-sphere/ 51 | # Get shape 52 | if not isinstance(sample_shape, torch.Size): 53 | sample_shape = torch.Size(sample_shape) 54 | shape = sample_shape + self._batch_shape 55 | 56 | # Sample from transformed uniform 57 | theta = torch.acos(1 - 2 * self.uniform_dist.sample(shape).to(self.device)) 58 | phi = 2 * np.pi * self.uniform_dist.sample(shape).to(self.device) 59 | 60 | # Convert to Cartesian coordinates 61 | return self._spherical_to_cartesian(theta=theta, phi=phi) 62 | 63 | def prob(self, value: torch.Tensor) -> torch.Tensor: 64 | if self._validate_args: 65 | self._validate_sample(value) 66 | 67 | return torch.ones(size=value.shape[:-1], device=self.device) / (4 * np.pi) 68 | 69 | def log_prob(self, value: torch.Tensor) -> torch.Tensor: 70 | return torch.log(self.prob(value).clamp(min=1e-10)) 71 | 72 | def get_max_prob(self) -> torch.Tensor: 73 | return torch.ones(size=self.batch_shape, device=self.device) / (4 * np.pi) 74 | 75 | def argmax(self) -> torch.Tensor: 76 | return self.sample() 77 | 78 | 79 | class SO3Distribution(SphericalDistribution): 80 | def __init__(self, 81 | a_lms: SO3Vec, 82 | sphs: SphericalHarmonics, 83 | empty: torch.Tensor = None, 84 | validate_args=None, 85 | device=None, 86 | dtype=torch.float) -> None: 87 | # SO3Vec: -ell, ..., ell: (batch size, tau's, m's, 2) 88 | assert all(a_lm.shape[:-3] == a_lms[0].shape[:-3] for a_lm in a_lms) 89 | super().__init__(batch_shape=a_lms[0].shape[:-3], validate_args=validate_args, device=device, dtype=dtype) 90 | 91 | assert sphs.sh_norm == 'qm' 92 | self.sphs = sphs 93 | 94 | assert empty is None or empty.shape == self.batch_shape 95 | self.empty = empty 96 | 97 | self.coefficients = normalize_alms(a_lms) # (batches, taus, ms, 2) 98 | 99 | self.spherical_uniform = SphericalUniform(batch_shape=self.batch_shape, 100 | device=device, 101 | dtype=dtype, 102 | validate_args=validate_args) 103 | self.uniform_dist = Uniform(low=0.0, high=1.0, validate_args=validate_args) 104 | 105 | def get_max_prob(self) -> torch.Tensor: 106 | # grid_points: (samples, 1, 3) 107 | grid_points = torch.tensor(generate_fibonacci_grid(n=1024), dtype=self.dtype, device=self.device).unsqueeze(-2) 108 | 109 | probs = self.prob(grid_points) # (samples, batches) 110 | 111 | # Maximum over grid points 112 | maximum, _ = torch.max(probs, dim=0) 113 | 114 | return maximum # (batches, ) 115 | 116 | def sample(self, sample_shape=torch.Size()) -> torch.Tensor: 117 | assert len(self.batch_shape) == 1 118 | num_batches = self.batch_shape[0] 119 | 120 | accepted_t = torch.empty(size=(0, num_batches), dtype=torch.bool, device=self.device) 121 | candidates_t = torch.empty(size=(0, num_batches) + self.event_shape, dtype=self.dtype, device=self.device) 122 | 123 | max_prob = self.get_max_prob() 124 | max_prob_proposal = self.spherical_uniform.get_max_prob() 125 | 126 | m_value = max_prob / max_prob_proposal # 
(batches, ) 127 | logging.debug(f'Mean M value: {torch.mean(m_value).item():.3f}') 128 | count = min(max(1, int(2 * torch.max(m_value).item())), 1024) 129 | 130 | # number of samples per batch item 131 | num_samples = int(np.prod(sample_shape)) 132 | 133 | while torch.any(accepted_t.sum(dim=0) < num_samples): 134 | candidates = self.spherical_uniform.sample(torch.Size((count, ))) # (count, batches, event) 135 | threshold = self.prob(candidates) / (m_value * self.spherical_uniform.prob(candidates)) # (count, batches) 136 | u = self.uniform_dist.sample(torch.Size((count, ))).unsqueeze(1).to(self.device) # (count, 1) 137 | accepted = u < threshold # (count, batches) 138 | 139 | accepted_t = torch.cat([accepted_t, accepted], dim=0) 140 | candidates_t = torch.cat([candidates_t, candidates], dim=0) 141 | 142 | # Collect accepted samples 143 | samples = [] 144 | for i in range(num_batches): 145 | cs = candidates_t[:, i] # (count, event) 146 | acs = accepted_t[:, i] # (count, ) 147 | samples.append(cs[acs][:num_samples]) 148 | 149 | samples_t = torch.stack(samples, dim=0) # (batches, samples, event) 150 | return samples_t.transpose(0, 1).reshape(sample_shape + self.batch_shape + self.event_shape).contiguous() 151 | 152 | def argmax(self, count=256) -> torch.Tensor: 153 | samples = self.sample(sample_shape=torch.Size((count, ))) # (samples, batches, 3) 154 | probs = self.prob(samples) # (samples, batches) 155 | indices = torch.argmax(probs, dim=0) # (batches, ) 156 | gather_indices = indices.unsqueeze(0).unsqueeze(-1).expand((-1, -1) + self.event_shape) # (1, batches, 3) 157 | result = torch.gather(samples, dim=0, index=gather_indices) # (1, batches, 3) 158 | return result.squeeze(0) # squeeze out samples dimension 159 | 160 | def prob(self, value: torch.Tensor) -> torch.Tensor: 161 | # value: (..., batches, 3) 162 | y_lms = self.sphs.forward(value) # (..., batches, taus, ms, 2) 163 | 164 | # Compute sum of products over ells, taus, and ms 165 | s = sum_product_alms_ylms(a_lms=self.coefficients, y_lms=y_lms) # (...., batches, 2) 166 | 167 | # Compute sum of squares 168 | p = torch.sum(torch.square(s), dim=-1, keepdim=False) # (..., batches) 169 | 170 | # Apply mask where probability is not defined 171 | if self.empty is not None: 172 | empty = self.empty.reshape((1, ) * (len(p.shape) - 1) + self.batch_shape) 173 | constant = self.spherical_uniform.prob(value) 174 | p = torch.where(empty, constant, p) 175 | 176 | return p 177 | 178 | def log_prob(self, value: torch.Tensor): 179 | return torch.log(self.prob(value).clamp(min=1e-10)) 180 | 181 | 182 | class ExpSO3Distribution(SphericalDistribution): 183 | def __init__(self, 184 | a_lms: SO3Vec, 185 | sphs: SphericalHarmonics, 186 | beta: float, 187 | validate_args=None, 188 | device=None, 189 | dtype=torch.float) -> None: 190 | # SO3Vec: -ell, ..., ell: (batch size, tau's, m's, 2) 191 | assert all(a_lm.shape[:-3] == a_lms[0].shape[:-3] for a_lm in a_lms) 192 | super().__init__(batch_shape=a_lms[0].shape[:-3], validate_args=validate_args, device=device, dtype=dtype) 193 | 194 | assert sphs.sh_norm == 'qm' 195 | self.sphs = sphs 196 | 197 | self.coefficients = normalize_alms(a_lms) # (batches, taus, ms, 2) 198 | 199 | self.beta = beta 200 | 201 | self.spherical_uniform = SphericalUniform(batch_shape=self.batch_shape, 202 | device=device, 203 | dtype=dtype, 204 | validate_args=validate_args) 205 | self.uniform_dist = Uniform(low=0.0, high=1.0, validate_args=validate_args) 206 | self.log_z = self.compute_log_z() 207 | 208 | def 
compute_log_z(self) -> torch.Tensor: 209 | grid = quadpy.u3._lebedev.lebedev_071() 210 | # grid_points: (samples, 1, 3) 211 | grid_points = torch.tensor(grid.points.transpose(), dtype=self.dtype, device=self.device).unsqueeze(-2) 212 | weights = torch.tensor(grid.weights, dtype=self.dtype, device=self.device).unsqueeze(-1) # (samples, 1) 213 | log_probs_unnormalized = self.log_prob_unnormalized(grid_points) # (samples, batches) 214 | result = np.log(4 * np.pi) + torch.logsumexp(log_probs_unnormalized + torch.log(weights), dim=0) 215 | return result 216 | 217 | def get_max_log_prob(self) -> torch.Tensor: 218 | # grid_points: (samples, 1, 3) 219 | grid_points = torch.tensor(generate_fibonacci_grid(n=4096), dtype=self.dtype, device=self.device).unsqueeze(1) 220 | log_probs = self.log_prob(grid_points) # (samples, batches) 221 | 222 | # Maximum over grid points 223 | maximum, _ = torch.max(log_probs, dim=0) 224 | 225 | return maximum # (batches, ) 226 | 227 | def sample(self, sample_shape=torch.Size()) -> torch.Tensor: 228 | assert len(self.batch_shape) == 1 229 | num_batches = self.batch_shape[0] 230 | 231 | accepted_t = torch.empty(size=(0, num_batches), dtype=torch.bool, device=self.device) 232 | candidates_t = torch.empty(size=(0, num_batches) + self.event_shape, dtype=self.dtype, device=self.device) 233 | 234 | max_log_prob = self.get_max_log_prob() 235 | max_log_prob_proposal = torch.log(self.spherical_uniform.get_max_prob()) 236 | 237 | log_m_value = max_log_prob - max_log_prob_proposal # (batches, ) 238 | m_value = torch.exp(log_m_value.clamp(-8, 8)) # clamp to keep exp() in a numerically safe range 239 | 240 | logging.debug(f'Mean M value: {torch.mean(m_value):.3f}') 241 | count = min(max(1, int(2 * torch.max(m_value).item())), 1024) 242 | 243 | # number of samples per batch item 244 | num_samples = int(np.prod(sample_shape)) 245 | 246 | while torch.any(accepted_t.sum(dim=0) < num_samples): 247 | candidates = self.spherical_uniform.sample(torch.Size((count, ))) # (count, batches, event) 248 | log_threshold = self.log_prob(candidates) - log_m_value - self.spherical_uniform.log_prob(candidates) 249 | u = self.uniform_dist.sample(torch.Size((count, ))).unsqueeze(1).to(self.device) # (count, 1) 250 | accepted = u < torch.exp(log_threshold) # (count, batches) 251 | 252 | accepted_t = torch.cat([accepted_t, accepted], dim=0) 253 | candidates_t = torch.cat([candidates_t, candidates], dim=0) 254 | 255 | # Collect accepted samples 256 | samples = [] 257 | for i in range(num_batches): 258 | cs = candidates_t[:, i] # (count, event) 259 | acs = accepted_t[:, i] # (count, ) 260 | samples.append(cs[acs][:num_samples]) 261 | 262 | samples_t = torch.stack(samples, dim=0) # (batches, samples, event) 263 | return samples_t.transpose(0, 1).reshape(sample_shape + self.batch_shape + self.event_shape).contiguous() 264 | 265 | def argmax(self, count=128) -> torch.Tensor: 266 | samples = self.sample(sample_shape=torch.Size((count, ))) # (samples, batches, 3) 267 | log_probs_unnormalized = self.log_prob_unnormalized(samples) # (samples, batches) 268 | indices = torch.argmax(log_probs_unnormalized, dim=0) # (batches, ) 269 | gather_indices = indices.unsqueeze(0).unsqueeze(-1).expand((-1, -1) + self.event_shape) # (1, batches, 3) 270 | result = torch.gather(samples, dim=0, index=gather_indices) # (1, batches, 3) 271 | return result.squeeze(0) # squeeze out samples dimension 272 | 273 | def log_prob_unnormalized(self, value: torch.Tensor) -> torch.Tensor: 274 | # value: (..., batches, 3) 275 | y_lms = self.sphs.forward(value) # (..., batches, taus, ms, 2) 
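# (Annotation added in editing; not part of the original source.) The remainder of this method
# computes the unnormalized log-density of the exponential model: with the complex expansion
# f(x) = sum_{l, tau, m} a_lm Y_lm(x) evaluated above, log p~(x) = -beta * |f(x)|^2.
# compute_log_z above then normalizes it via Lebedev quadrature: the quadpy weights w_i sum to one,
# so Z = integral of p~ over the sphere ~= 4*pi * sum_i w_i * p~(x_i), which is exactly the
# log(4*pi) + logsumexp(log p~ + log w) computed there. A quick normalization sanity check,
# assuming `dist` is an ExpSO3Distribution built from valid SO3Vec coefficients:
#
#     grid = torch.tensor(generate_fibonacci_grid(n=8192), dtype=dist.dtype, device=dist.device).unsqueeze(-2)
#     mass = (4 * np.pi) * torch.exp(dist.log_prob(grid)).mean(dim=0)  # quasi-uniform grid estimate, ~1.0 per batch entry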
276 | 277 | # Compute sum of products over ells, taus, and ms 278 | s = sum_product_alms_ylms(a_lms=self.coefficients, y_lms=y_lms) # (...., batches, 2) 279 | 280 | # Compute sum of squares 281 | log_p_unnormalized = -self.beta * torch.sum(torch.square(s), dim=-1, keepdim=False) # (..., batches) 282 | 283 | return log_p_unnormalized 284 | 285 | def log_prob(self, value: torch.Tensor): 286 | return self.log_prob_unnormalized(value) - self.log_z 287 | -------------------------------------------------------------------------------- /molgym/agents/covariant/agent.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Dict, Any 2 | 3 | import numpy as np 4 | import torch 5 | import torch.distributions 6 | from cormorant.cg_lib import SphericalHarmonics, CGDict 7 | from cormorant.so3_lib import SO3Tau, SO3Vec 8 | 9 | from molgym.agents.base import AbstractActorCritic 10 | from molgym.agents.covariant import tools, so3_tools 11 | from molgym.agents.covariant.gmm import GaussianMixtureModel 12 | from molgym.agents.covariant.modules import Cormorant, CormorantMixer 13 | from molgym.agents.covariant.so3_tools import AtomicScalars 14 | from molgym.agents.covariant.spherical_dists import SphericalDistribution, ExpSO3Distribution, SO3Distribution 15 | from molgym.modules import MLP, masked_softmax, to_one_hot 16 | from molgym.spaces import ObservationType, ActionType, ObservationSpace, ActionSpace 17 | from molgym.tools.util import to_numpy 18 | 19 | 20 | class CovariantAC(AbstractActorCritic): 21 | def __init__( 22 | self, 23 | observation_space: ObservationSpace, 24 | action_space: ActionSpace, 25 | min_max_distance: Tuple[float, float], 26 | network_width: int, 27 | maxl: int, 28 | num_cg_levels: int, 29 | num_channels_hidden: int, 30 | num_channels_per_element: int, 31 | num_gaussians: int, 32 | bag_scale: int, 33 | beta: Optional[float] = None, 34 | device=None, 35 | ): 36 | super().__init__(observation_space, action_space) 37 | self.device = device 38 | self.dtype = torch.float 39 | 40 | self.zs = self.observation_space.zs 41 | self.zs_tensor = torch.tensor(self.zs, dtype=self.dtype, device=self.device) 42 | 43 | self.min_distance, self.max_distance = min_max_distance 44 | assert self.min_distance < self.max_distance 45 | self.beta = beta 46 | 47 | self.max_sh = maxl 48 | self.num_cg_levels = num_cg_levels 49 | self.num_channels_hidden = num_channels_hidden 50 | self.num_channels_per_element = num_channels_per_element 51 | self.num_gaussians = num_gaussians 52 | 53 | self.num_channels_out = len(self.zs) * self.num_channels_per_element 54 | self.channel_offsets = torch.arange(start=0, 55 | end=self.num_channels_per_element, 56 | dtype=torch.long, 57 | device=self.device).unsqueeze(0) 58 | 59 | self.cg_dict = CGDict(maxl=self.max_sh, device=self.device, dtype=self.dtype) 60 | self.cg_model = Cormorant( 61 | maxl=self.max_sh, # Cutoff in CG operations (default: [3]) 62 | max_sh=self.max_sh, # Number of spherical harmonic powers to use (default: [3]) 63 | num_cg_levels=self.num_cg_levels, # Number of CG levels (default: 4) 64 | num_channels=[self.num_channels_hidden] * self.num_cg_levels + [self.num_channels_out], 65 | num_species=len(self.zs), 66 | cutoff_type=['soft'], # Types of cutoffs to include 67 | hard_cut_rad=min(self.max_distance, 2.1), # Radius of hard cutoff (in AA) 68 | soft_cut_rad=min(self.max_distance, 2.1), # Radius of soft cutoff (in AA) 69 | soft_cut_width=0.2, # Width of SOFT cutoff in Angstroms (default: 
0.2) 70 | weight_init='rand', # Weight initialization function to use (default: rand) 71 | level_gain=[10.0], # Gain at each level (default: [10.]) 72 | charge_power=2, # Maximum power to take in one-hot (default: 2) 73 | basis_set=[3, 3], # Use gaussian mask instead of sigmoid mask. 74 | charge_scale=max(self.zs), 75 | bag_scale=bag_scale, 76 | device=self.device, 77 | dtype=self.dtype, 78 | cg_dict=self.cg_dict, 79 | ) 80 | 81 | self.cg_mix = CormorantMixer( 82 | tau_in=SO3Tau([self.num_channels_per_element] * (self.max_sh + 1)), 83 | tau_other=SO3Tau([self.num_channels_per_element]), 84 | maxl=self.max_sh, 85 | num_channels=self.num_channels_per_element, 86 | level_gain=10.0, 87 | weight_init='rand', 88 | device=self.device, 89 | dtype=self.dtype, 90 | cg_dict=self.cg_dict, 91 | ) 92 | 93 | self.sph_harms = SphericalHarmonics(maxl=self.max_sh, 94 | conj=False, 95 | sh_norm='qm', 96 | device=self.device, 97 | dtype=self.dtype, 98 | cg_dict=self.cg_dict) 99 | 100 | self.atomic_scalars = AtomicScalars(maxl=self.max_sh, full_scalars=True, device=self.device, dtype=self.dtype) 101 | 102 | self.num_latent = self.atomic_scalars.get_output_dim(self.num_channels_out) 103 | self.num_latent_element = self.atomic_scalars.get_output_dim(self.num_channels_per_element) 104 | 105 | # Focus 106 | self.phi_focus = MLP( 107 | input_dim=self.num_latent, 108 | output_dims=(network_width, 1), 109 | ) 110 | 111 | # Element 112 | self.phi_element = MLP( 113 | input_dim=self.num_latent, 114 | output_dims=(network_width, len(self.zs)), 115 | ) 116 | 117 | # Distance: Gaussian Mixture Model 118 | self.phi_d = MLP( 119 | input_dim=self.num_latent_element, 120 | output_dims=(network_width, 2 * self.num_gaussians), 121 | ) 122 | self.pad_zeros = torch.nn.ConstantPad1d(padding=(0, 1), value=0.0) # Pad with one 0.0 to the right 123 | 124 | self.distance_half_width = torch.tensor((self.max_distance - self.min_distance) / 2, 125 | dtype=self.dtype, 126 | device=self.device) 127 | self.distance_center = torch.tensor((self.min_distance + self.max_distance) / 2, 128 | dtype=self.dtype, 129 | device=self.device) 130 | 131 | self.distance_log_stds = torch.nn.Parameter(torch.log( 132 | torch.tensor([0.1] * self.num_gaussians, dtype=self.dtype, device=self.device)), 133 | requires_grad=True) # (gaussians, ) 134 | 135 | # Value function 136 | self.phi_trans = MLP( 137 | input_dim=self.num_latent, 138 | output_dims=(network_width, network_width), 139 | ) 140 | self.phi_v = MLP( 141 | input_dim=network_width, 142 | output_dims=(network_width, 1), 143 | ) 144 | 145 | self.to(self.device) 146 | 147 | def to_action_space(self, action: torch.Tensor, observation: ObservationType) -> ActionType: 148 | assert action.shape == (6, ) 149 | action = to_numpy(action) 150 | 151 | focus = int(round(action[0].item())) 152 | element_index = int(round(action[1].item())) 153 | d = action[2] 154 | so3 = action[-3:] 155 | 156 | atoms, bag = self.observation_space.parse(observation) 157 | 158 | if len(atoms): 159 | position = atoms[focus].position + d * so3 160 | else: 161 | position = (0.0, 0.0, 0.0) 162 | 163 | return element_index, position 164 | 165 | def parse_observations(self, observations: List[ObservationType]) -> Dict[str, torch.Tensor]: 166 | parsed_observations = [self.observation_space.parse(observation) for observation in observations] 167 | atoms_list = [tup[0] for tup in parsed_observations] 168 | bags = [observation[1] for observation in observations] 169 | 170 | # Canvas 171 | data = tools.process_atoms_list(atoms_list, 172 | 
max_num_atoms=self.observation_space.canvas_space.size, 173 | dtype=self.dtype, 174 | device=self.device) 175 | 176 | data['one_hot'] = data['charges'].unsqueeze(-1) == self.zs_tensor.unsqueeze(0).unsqueeze(0) 177 | data['atom_mask'] = data['charges'] > 0 178 | data['edge_mask'] = data['atom_mask'].unsqueeze(1) * data['atom_mask'].unsqueeze(2) 179 | 180 | # At least one atom needs to be selectable 181 | default = torch.zeros_like(data['atom_mask']) 182 | default[..., 0] = 1 183 | 184 | # If the canvas is empty, focus 0th index 185 | data['focus_mask'] = torch.logical_or(data['atom_mask'], default) 186 | 187 | # Is canvas empty? 188 | data['empty'] = torch.tensor([len(atoms) == 0 for atoms in atoms_list], dtype=torch.bool, device=self.device) 189 | 190 | # Bag 191 | data['bags'] = torch.tensor([list(bag) for bag in bags], dtype=self.dtype, device=self.device) # (batches, zs) 192 | data['element_mask'] = data['bags'] > 0 # (batches, zs) 193 | 194 | # Value mask 195 | data['value_mask'] = data['atom_mask'] 196 | 197 | return data 198 | 199 | def get_so3_distribution(self, a_lms: SO3Vec, empty: torch.Tensor) -> SphericalDistribution: 200 | if self.beta is not None: 201 | return ExpSO3Distribution(a_lms=a_lms, 202 | sphs=self.sph_harms, 203 | beta=self.beta, 204 | dtype=self.dtype, 205 | device=self.device) 206 | else: 207 | return SO3Distribution(a_lms=a_lms, sphs=self.sph_harms, empty=empty, dtype=self.dtype, device=self.device) 208 | 209 | def step(self, observations: List[ObservationType], actions: Optional[np.ndarray] = None) -> dict: 210 | data = self.parse_observations(observations) 211 | 212 | # Cast action to tensor 213 | if actions is not None: 214 | actions = torch.as_tensor(actions, dtype=torch.float, device=self.device) 215 | 216 | # SO3Vec (batches, atoms, taus, ms, 2) 217 | covariats = self.cg_model(data) 218 | 219 | # Compute invariants 220 | invariats = self.atomic_scalars(covariats) # (batches, atoms, inv_feats) 221 | 222 | # Focus 223 | focus_logits = self.phi_focus(invariats) # (batches, atoms, 1) 224 | focus_logits = focus_logits.squeeze(-1) # (batches, atoms) 225 | focus_probs = masked_softmax(focus_logits, mask=data['focus_mask']) # (batches, atoms) 226 | focus_dist = torch.distributions.Categorical(probs=focus_probs) 227 | 228 | # focus: (batches, 1) 229 | if actions is not None: 230 | focus = torch.round(actions[:, :1]).long() 231 | elif self.training: 232 | focus = focus_dist.sample().unsqueeze(-1) 233 | else: 234 | focus = torch.argmax(focus_probs, dim=-1).unsqueeze(-1) 235 | 236 | focus_oh = to_one_hot(focus, num_classes=self.observation_space.canvas_space.size, 237 | device=self.device) # (batches, atoms) 238 | 239 | focused_cov = so3_tools.select_atomic_covariats(covariats, focus_oh) # (batches, taus, ms, 2) 240 | focused_inv = so3_tools.select_atomic_invariats(invariats, focus_oh) # (batches, feats) 241 | 242 | # Element 243 | element_logits = self.phi_element(focused_inv) # (batches, zs) 244 | element_probs = masked_softmax(element_logits, mask=data['element_mask']) # (batches, zs) 245 | element_dist = torch.distributions.Categorical(probs=element_probs) 246 | 247 | # element: (batches, 1) 248 | if actions is not None: 249 | element = torch.round(actions[:, 1:2]).long() 250 | elif self.training: 251 | element = element_dist.sample().unsqueeze(-1) 252 | else: 253 | element = torch.argmax(element_probs, dim=-1).unsqueeze(-1) 254 | 255 | # Crop element 256 | offsets = self.channel_offsets.expand(len(observations), -1) # (batches, channels_per_element) 257 | indices 
= offsets + element * self.num_channels_per_element 258 | element_cov = so3_tools.select_taus(focused_cov, indices=indices) 259 | element_inv = self.atomic_scalars(element_cov) # (batches, inv_feats) 260 | 261 | # Distance: Gaussian mixture model 262 | # gmm_log_probs, d_mean_trans: (batches, gaussians) 263 | gmm_log_probs, d_mean_trans = self.phi_d(element_inv).split(self.num_gaussians, dim=-1) 264 | distance_mean = torch.tanh(d_mean_trans) * self.distance_half_width + self.distance_center 265 | distance_dist = GaussianMixtureModel(log_probs=gmm_log_probs, 266 | means=distance_mean, 267 | stds=torch.exp(self.distance_log_stds).clamp(1e-6)) 268 | 269 | # distance: (batches, 1) 270 | if actions is not None: 271 | distance = actions[:, 2:3] 272 | elif self.training: 273 | # Ensure that the sampled distance is > 0 274 | distance = distance_dist.sample().clamp(0.001).unsqueeze(-1) 275 | else: 276 | distance = distance_dist.argmax().unsqueeze(-1) 277 | 278 | # Condition on distance 279 | transformed_d = distance.unsqueeze(1).unsqueeze(1).expand(-1, self.num_channels_per_element, 1, -1) 280 | transformed_d = self.pad_zeros(transformed_d) 281 | distance_so3 = SO3Vec([transformed_d]) 282 | cond_cov = self.cg_mix(element_cov, distance_so3) 283 | 284 | so3_dist = self.get_so3_distribution(a_lms=cond_cov, empty=data['empty']) 285 | 286 | # so3: (batches, 3) 287 | if actions is not None: 288 | orientation = actions[..., 3:6] 289 | elif self.training: 290 | orientation = so3_dist.sample() 291 | else: 292 | orientation = so3_dist.argmax() 293 | 294 | # Log prob 295 | log_prob_list = [ 296 | focus_dist.log_prob(focus.squeeze(-1)), 297 | element_dist.log_prob(element.squeeze(-1)), 298 | distance_dist.log_prob(distance.squeeze(-1)), 299 | so3_dist.log_prob(orientation), 300 | ] 301 | log_prob = torch.stack(log_prob_list, dim=-1).sum(dim=-1) # (batches, ) 302 | 303 | # Entropy 304 | entropy_list = [ 305 | focus_dist.entropy(), 306 | element_dist.entropy(), 307 | ] 308 | entropy = torch.stack(entropy_list, dim=-1).sum(dim=-1) # (batches, ) 309 | 310 | # Value function 311 | # atom_mask: (batches, atoms) 312 | # invariants: (batches, atoms, feats) 313 | trans_invariats = self.phi_trans(invariats) 314 | value_feats = torch.einsum( # type: ignore 315 | 'ba,baf->bf', data['value_mask'].to(self.dtype), trans_invariats) # (batches, inv_feats) 316 | value = self.phi_v(value_feats).squeeze(-1) # (batches, ) 317 | 318 | # Action 319 | response: Dict[str, Any] = {} 320 | if actions is None: 321 | actions = torch.cat([focus.float(), element.float(), distance, orientation], dim=-1) 322 | 323 | # Build the corresponding action in the action space 324 | response['actions'] = [self.to_action_space(a, o) for a, o in zip(actions, observations)] 325 | 326 | response.update({ 327 | 'a': actions, # (batches, subactions) 328 | 'logp': log_prob, # (batches, ) 329 | 'ent': entropy, # (batches, ) 330 | 'v': value, # (batches, ) 331 | 'dists': [focus_dist, element_dist, distance_dist, so3_dist], 332 | }) 333 | 334 | return response 335 | -------------------------------------------------------------------------------- /molgym/ppo.py: -------------------------------------------------------------------------------- 1 | # The content of this file is based on: DeepRL https://github.com/ShangtongZhang/DeepRL. 
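# (Annotation added in editing; not part of the original source.) compute_loss below implements the
# standard PPO clipped surrogate objective: with probability ratio r = exp(logp_new - logp_old) and
# advantage estimate A, the policy term is -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A)), to which
# a squared-error value-function loss (scaled by vf_coef) and an entropy bonus (scaled by entropy_coef)
# are added. For example, with clip_ratio = 0.2, r = 1.5 and A = 1.0: min(1.5 * 1.0, 1.2 * 1.0) = 1.2,
# so pushing the policy further past the 1 + eps boundary yields no additional objective gain.
# train() below additionally stops an update early once the approximate KL divergence,
# mean(logp_old - logp_new), exceeds 1.5 * target_kl.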
2 | import logging 3 | import time 4 | from typing import Dict, Optional, Tuple, Sequence, List, Iterator 5 | 6 | import numpy as np 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | 10 | from molgym.agents.base import AbstractActorCritic 11 | from molgym.buffer import DynamicPPOBuffer 12 | from molgym.buffer_container import PPOBufferContainer 13 | from molgym.env_container import VecEnv 14 | from molgym.tools.model_util import ModelIO 15 | from molgym.tools.util import RolloutSaver, to_numpy, InfoSaver, compute_gradient_norm 16 | 17 | 18 | def compute_loss( 19 | ac: AbstractActorCritic, 20 | data: dict, 21 | clip_ratio: float, 22 | vf_coef: float, 23 | entropy_coef: float, 24 | device=None, 25 | ) -> Tuple[torch.Tensor, Dict[str, float]]: 26 | pred = ac.step(data['obs'], data['act']) 27 | 28 | old_logp = torch.as_tensor(data['logp'], device=device) 29 | adv = torch.as_tensor(data['adv'], device=device) 30 | ret = torch.as_tensor(data['ret'], device=device) 31 | 32 | # Policy loss 33 | ratio = torch.exp(pred['logp'] - old_logp) 34 | obj = ratio * adv 35 | clipped_obj = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv 36 | policy_loss = -torch.min(obj, clipped_obj).mean() 37 | 38 | # Entropy loss 39 | entropy_loss = -entropy_coef * pred['ent'].mean() 40 | 41 | # Value loss 42 | vf_loss = vf_coef * (pred['v'] - ret).pow(2).mean() 43 | 44 | # Total loss 45 | loss = policy_loss + entropy_loss + vf_loss 46 | 47 | # Approximate KL for early stopping 48 | approx_kl = (old_logp - pred['logp']).mean() 49 | 50 | # Extra info 51 | clipped = ratio.lt(1 - clip_ratio) | ratio.gt(1 + clip_ratio) 52 | clip_fraction = torch.as_tensor(clipped, dtype=torch.float32).mean() 53 | 54 | info = dict( 55 | policy_loss=to_numpy(policy_loss).item(), 56 | entropy_loss=to_numpy(entropy_loss).item(), 57 | vf_loss=to_numpy(vf_loss).item(), 58 | total_loss=to_numpy(loss).item(), 59 | approx_kl=to_numpy(approx_kl).item(), 60 | clip_fraction=to_numpy(clip_fraction).item(), 61 | ) 62 | 63 | return loss, info 64 | 65 | 66 | def get_batch_generator(indices: np.ndarray, batch_size: int) -> Iterator[np.ndarray]: 67 | assert len(indices.shape) == 1 68 | indices = np.random.permutation(indices) 69 | batches = indices[:len(indices) // batch_size * batch_size].reshape(-1, batch_size) 70 | for batch in batches: 71 | yield batch 72 | remainder = len(indices) % batch_size 73 | if remainder: 74 | yield indices[-remainder:] 75 | 76 | 77 | def collect_data_batch(data: Dict[str, Sequence], indices: np.ndarray) -> Dict[str, Sequence]: 78 | batch: Dict[str, Sequence] = {} 79 | for key, value in data.items(): 80 | if isinstance(value, np.ndarray): 81 | batch[key] = value[indices] 82 | elif isinstance(value, list): 83 | items = [] 84 | for index in indices: 85 | items.append(value[index]) 86 | batch[key] = items 87 | else: 88 | raise ValueError(value) 89 | return batch 90 | 91 | 92 | def compute_mean_dict(dicts: List[Dict[str, float]]) -> Dict[str, float]: 93 | # Assert all dicts have the same keys 94 | assert all(d.keys() == dicts[0].keys() for d in dicts) 95 | return {key: np.mean([d[key] for d in dicts]) for key in dicts[0].keys()} 96 | 97 | 98 | # Train policy with multiple steps of gradient descent 99 | def train( 100 | ac: AbstractActorCritic, 101 | optimizer: Optimizer, 102 | data: Dict[str, Sequence], 103 | mini_batch_size: int, 104 | clip_ratio: float, 105 | target_kl: float, 106 | vf_coef: float, 107 | entropy_coef: float, 108 | gradient_clip: float, 109 | max_num_steps: int, 110 | device=None, 111 | ) -> dict: 112 | 
infos = {} 113 | 114 | start_time = time.time() 115 | 116 | num_epochs = 0 117 | for i in range(max_num_steps): 118 | optimizer.zero_grad() 119 | 120 | batch_infos = [] 121 | batch_generator = get_batch_generator(indices=np.arange(len(data['obs'])), batch_size=mini_batch_size) 122 | for batch_indices in batch_generator: 123 | data_batch = collect_data_batch(data, indices=batch_indices) 124 | batch_loss, batch_info = compute_loss(ac, 125 | data=data_batch, 126 | clip_ratio=clip_ratio, 127 | vf_coef=vf_coef, 128 | entropy_coef=entropy_coef, 129 | device=device) 130 | 131 | batch_loss.backward(retain_graph=False) # type: ignore 132 | batch_infos.append(batch_info) 133 | 134 | loss_info = compute_mean_dict(batch_infos) 135 | loss_info['grad_norm'] = compute_gradient_norm(ac.parameters()) 136 | 137 | # Check KL 138 | if loss_info['approx_kl'] > 1.5 * target_kl: 139 | logging.debug(f'Early stopping at step {i} for reaching max KL.') 140 | break 141 | 142 | # Take gradient step 143 | logging.debug('Taking gradient step') 144 | torch.nn.utils.clip_grad_norm_(ac.parameters(), max_norm=gradient_clip) 145 | optimizer.step() 146 | optimizer.zero_grad() 147 | 148 | num_epochs += 1 149 | 150 | # Logging 151 | logging.debug(f'Loss {i}: {loss_info}') 152 | infos.update(loss_info) 153 | 154 | infos['num_opt_steps'] = num_epochs 155 | infos['time'] = time.time() - start_time 156 | 157 | if num_epochs > 0: 158 | logging.info(f'Optimization: policy loss={infos["policy_loss"]:.3f}, vf loss={infos["vf_loss"]:.3f}, ' 159 | f'entropy loss={infos["entropy_loss"]:.3f}, total loss={infos["total_loss"]:.3f}, ' 160 | f'num steps={num_epochs}') 161 | return infos 162 | 163 | 164 | def batch_rollout(ac: AbstractActorCritic, 165 | envs: VecEnv, 166 | buffer_container: PPOBufferContainer, 167 | num_steps: int = None, 168 | num_episodes: int = None) -> dict: 169 | assert num_steps is not None or num_episodes is not None 170 | 171 | if num_steps is not None: 172 | assert num_steps % envs.get_size() == 0 173 | num_iters = num_steps // envs.get_size() 174 | else: 175 | num_iters = np.inf 176 | 177 | if num_episodes is not None: 178 | assert envs.get_size() == 1 179 | else: 180 | num_episodes = np.inf 181 | 182 | start_time = time.time() 183 | 184 | counter = 0 185 | observations = envs.reset() 186 | 187 | while counter < num_iters and buffer_container.get_num_episodes() < num_episodes: 188 | predictions = ac.step(observations) 189 | 190 | next_observations, rewards, terminals, _ = envs.step(predictions['actions']) 191 | 192 | buffer_container.store(observations=observations, 193 | actions=to_numpy(predictions['a']), 194 | rewards=rewards, 195 | next_observations=next_observations, 196 | terminals=terminals, 197 | values=to_numpy(predictions['v']), 198 | logps=to_numpy(predictions['logp'])) 199 | 200 | # Reset environment if state is terminal to get valid next observation 201 | observations = envs.reset_if_terminal(next_observations, terminals) 202 | 203 | if counter == num_iters - 1: 204 | # Note: finished trajectories will not be affected by this 205 | predictions = ac.step(observations) 206 | buffer_container.finish_paths(to_numpy(predictions['v'])) 207 | 208 | counter += 1 209 | 210 | info = { 211 | 'time': time.time() - start_time, 212 | 'return_mean': np.mean(buffer_container.episodic_returns).item(), 213 | 'return_std': np.std(buffer_container.episodic_returns).item(), 214 | 'episode_length_mean': np.mean(buffer_container.episode_lengths).item(), 215 | 'episode_length_std': 
np.std(buffer_container.episode_lengths).item(), 216 | } 217 | 218 | return info 219 | 220 | 221 | def compute_buffer_stats(buffer: DynamicPPOBuffer) -> Dict[str, float]: 222 | return { 223 | 'value_mean': np.mean(buffer.val_buf).item(), 224 | 'value_std': np.std(buffer.val_buf).item(), 225 | 'logp_mean': np.mean(buffer.logp_buf).item(), 226 | 'logp_std': np.std(buffer.logp_buf).item(), 227 | } 228 | 229 | 230 | def batch_ppo( 231 | envs: VecEnv, 232 | eval_envs: VecEnv, 233 | ac: AbstractActorCritic, 234 | optimizer: Optimizer, 235 | gamma=0.99, 236 | start_num_steps=0, 237 | max_num_steps=4096, 238 | num_steps_per_iter=200, 239 | mini_batch_size=64, 240 | clip_ratio=0.2, 241 | vf_coef=0.5, 242 | entropy_coef=0.0, 243 | max_num_train_iters=80, 244 | lam=0.97, 245 | target_kl=0.01, 246 | gradient_clip=0.5, 247 | save_freq=5, 248 | model_handler: Optional[ModelIO] = None, 249 | eval_freq=10, 250 | num_eval_episodes=1, 251 | rollout_saver: Optional[RolloutSaver] = None, 252 | save_train_rollout=False, 253 | save_eval_rollout=True, 254 | info_saver: Optional[InfoSaver] = None, 255 | device=None, 256 | ): 257 | """ 258 | Proximal Policy Optimization (by clipping), with early stopping based on approximate KL 259 | 260 | Args: 261 | :param envs: VecEnv for training. 262 | :param eval_envs: VecEnv for evaluation. 263 | :param ac: Instance of an AbstractActorCritic 264 | :param optimizer: Optimizer to optimize agent's parameters 265 | :param num_steps_per_iter: Number of agent-environment interaction steps per iteration. 266 | :param start_num_steps: Initial number of steps 267 | :param max_num_steps: Maximum number of steps 268 | :param mini_batch_size: mini batch size for loss calculation 269 | :param gamma: Discount factor. (Always between 0 and 1.) 270 | :param clip_ratio: Hyperparameter for clipping in the policy objective. 271 | Roughly: how far can the new policy go from the old policy while 272 | still profiting (improving the objective function)? The new policy 273 | can still go farther than the clip_ratio says, but it doesn't help 274 | on the objective anymore. (Usually small, 0.1 to 0.3.) 275 | :param vf_coef: coefficient for value function loss term 276 | :param entropy_coef: coefficient for entropy loss term 277 | :param gradient_clip: clip norm of gradients before optimization step is taken 278 | :param max_num_train_iters: Maximum number of gradient descent steps to take 279 | on policy loss per epoch. (Early stopping may cause optimizer to take fewer than this.) 280 | :param lam: Lambda for GAE-Lambda. (Always between 0 and 1, close to 1.) 281 | :param target_kl: Roughly what KL divergence we think is appropriate 282 | between new and old policies after an update. This will get used 283 | for early stopping. (Usually small, 0.01 or 0.05.) 
284 | :param eval_freq: How often to evaluate the policy 285 | :param num_eval_episodes: Number of evaluation episodes 286 | :param model_handler: Save model to file 287 | :param save_freq: How often the model is saved 288 | :param rollout_saver: Saves rollout buffers 289 | :param save_train_rollout: Save training rollout 290 | :param save_eval_rollout: Save evaluation rollout 291 | :param info_saver: Save statistics 292 | :param device: device on which to run the calculations 293 | """ 294 | 295 | # Total number of steps 296 | total_num_steps = start_num_steps 297 | num_iterations = (max_num_steps - total_num_steps) // num_steps_per_iter 298 | 299 | logging.info('Starting PPO') 300 | 301 | # Main loop 302 | for iteration in range(num_iterations): 303 | logging.info(f'Iteration: {iteration}/{num_iterations - 1}, steps: {total_num_steps}') 304 | 305 | # Training rollout 306 | train_container = PPOBufferContainer(size=envs.get_size(), gamma=gamma, lam=lam) 307 | train_rollout = batch_rollout(ac=ac, envs=envs, buffer_container=train_container, num_steps=num_steps_per_iter) 308 | logging.info( 309 | f'Training rollout: return={train_rollout["return_mean"]:.3f} ({train_rollout["return_std"]:.1f}), ' 310 | f'episode length={train_rollout["episode_length_mean"]:.1f}') 311 | 312 | train_buffer = train_container.merge() 313 | 314 | if info_saver: 315 | train_rollout['total_num_steps'] = total_num_steps 316 | train_rollout.update(compute_buffer_stats(train_buffer)) 317 | info_saver.save(train_rollout, name='train') 318 | 319 | # Save training buffer 320 | if rollout_saver and save_train_rollout: 321 | rollout_saver.save(train_buffer, num_steps=total_num_steps, info='train') 322 | 323 | # Obtain (standardized) data for training 324 | data = train_buffer.get_data() 325 | 326 | # Train policy 327 | opt_info = train( 328 | ac=ac, 329 | optimizer=optimizer, 330 | data=data, 331 | mini_batch_size=mini_batch_size, 332 | clip_ratio=clip_ratio, 333 | vf_coef=vf_coef, 334 | entropy_coef=entropy_coef, 335 | target_kl=target_kl, 336 | gradient_clip=gradient_clip, 337 | max_num_steps=max_num_train_iters, 338 | device=device, 339 | ) 340 | 341 | if info_saver: 342 | opt_info['total_num_steps'] = total_num_steps 343 | info_saver.save(opt_info, name='opt') 344 | 345 | # Update number of steps taken / trained 346 | total_num_steps += num_steps_per_iter 347 | 348 | # Evaluate policy 349 | if (iteration % eval_freq == 0) or (iteration == num_iterations - 1): 350 | eval_container = PPOBufferContainer(size=eval_envs.get_size(), gamma=gamma, lam=lam) 351 | 352 | with torch.no_grad(): 353 | ac.training = False 354 | eval_rollout = batch_rollout(ac, 355 | eval_envs, 356 | buffer_container=eval_container, 357 | num_episodes=num_eval_episodes) 358 | logging.info( 359 | f'Evaluation rollout: return={eval_rollout["return_mean"]:.3f} ({eval_rollout["return_std"]:.1f}), ' 360 | f'episode length={eval_rollout["episode_length_mean"]:.1f}') 361 | ac.training = True 362 | 363 | eval_buffer = eval_container.merge() 364 | 365 | # Log information 366 | if info_saver: 367 | eval_rollout['total_num_steps'] = total_num_steps 368 | eval_rollout.update(compute_buffer_stats(eval_buffer)) 369 | info_saver.save(eval_rollout, name='eval') 370 | 371 | # Save evaluation buffer 372 | if rollout_saver and save_eval_rollout: 373 | rollout_saver.save(eval_buffer, num_steps=total_num_steps, info='eval') 374 | 375 | # Save model 376 | if model_handler and ((iteration % save_freq == 0) or (iteration == num_iterations - 1)): 377 | model_handler.save(ac, 
num_steps=total_num_steps) 378 | 379 | logging.info('Finished PPO') 380 | --------------------------------------------------------------------------------
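For orientation, the following sketch shows how the pieces of molgym/ppo.py fit together in a training driver, in the spirit of scripts/run.py. Note that build_envs and build_agent are hypothetical placeholders for the project-specific setup code that constructs the vectorized environments and the actor-critic agent; only batch_ppo and its keyword arguments are taken from the module above, and the optimizer choice and learning rate are illustrative.

import torch

from molgym.ppo import batch_ppo

# Hypothetical helpers (not part of the library shown above): build the vectorized
# training/evaluation environments (VecEnv) and an actor-critic agent (e.g. CovariantAC).
envs, eval_envs = build_envs()
ac = build_agent(envs)

optimizer = torch.optim.Adam(ac.parameters(), lr=3e-4)  # illustrative learning rate

batch_ppo(
    envs=envs,
    eval_envs=eval_envs,
    ac=ac,
    optimizer=optimizer,
    gamma=0.99,  # discount factor
    max_num_steps=4096,  # total number of environment interaction steps
    num_steps_per_iter=200,  # rollout length per PPO iteration
    mini_batch_size=64,
    clip_ratio=0.2,
    target_kl=0.01,
)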