├── sgdml ├── intf │ ├── __init__.py │ └── ase_calc.py ├── utils │ ├── __init__.py │ ├── ui.py │ ├── desc.py │ ├── io.py │ └── perm.py ├── solvers │ ├── __init__.py │ ├── analytic.py │ └── iterative.py ├── __init__.py └── get.py ├── pyproject.toml ├── .gitignore ├── setup.cfg ├── LICENSE.txt ├── setup.py ├── scripts ├── sgdml_dataset_to_extxyz.py ├── sgdml_datasets_from_model.py ├── sgdml_dataset_from_aims.py ├── sgdml_dataset_from_ipi.py ├── sgdml_dataset_via_ase.py └── sgdml_dataset_from_extxyz.py └── README.md /sgdml/intf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgdml/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sgdml/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | skip-string-normalization = true 3 | skip-numeric-underscore-normalization = true 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | 4 | # Compiled python modules. 5 | *.pyc 6 | 7 | # Setuptools distribution folder. 8 | /dist/ 9 | 10 | # Python egg metadata, regenerated from source files by setuptools. 11 | /*.egg-info 12 | /*.egg 13 | sgdml/_bmark_cache.npz 14 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-complexity = 12 3 | ignore = E501,W503,E741 4 | select = C,E,F,W 5 | 6 | [isort] 7 | multi_line_output = 3 8 | include_trailing_comma = 1 9 | line_length = 85 10 | sections = FUTURE,STDLIB,TYPING,THIRDPARTY,FIRSTPARTY,LOCALFOLDER 11 | known_typing = typing, typing_extensions 12 | no_lines_before = TYPING 13 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018-2022 Stefan Chmiela 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from io import open 4 | from setuptools import setup, find_packages 5 | 6 | 7 | def get_property(property, package): 8 | result = re.search( 9 | r'{}\s*=\s*[\'"]([^\'"]*)[\'"]'.format(property), 10 | open(package + '/__init__.py').read(), 11 | ) 12 | return result.group(1) 13 | 14 | 15 | from os import path 16 | 17 | this_dir = path.abspath(path.dirname(__file__)) 18 | with open(path.join(this_dir, 'README.md'), encoding='utf8') as f: 19 | long_description = f.read() 20 | 21 | # Scripts 22 | scripts = [] 23 | for dirname, dirnames, filenames in os.walk('scripts'): 24 | for filename in filenames: 25 | if filename.endswith('.py'): 26 | scripts.append(os.path.join(dirname, filename)) 27 | 28 | setup( 29 | name='sgdml', 30 | version=get_property('__version__', 'sgdml'), 31 | description='Reference implementation of the GDML and sGDML force field models.', 32 | long_description=long_description, 33 | long_description_content_type='text/markdown', 34 | classifiers=[ 35 | 'Development Status :: 4 - Beta', 36 | 'Environment :: Console', 37 | 'Intended Audience :: Science/Research', 38 | 'Intended Audience :: Education', 39 | 'Intended Audience :: Developers', 40 | 'License :: OSI Approved :: MIT License', 41 | 'Operating System :: MacOS :: MacOS X', 42 | 'Operating System :: POSIX :: Linux', 43 | 'Programming Language :: Python :: 3.7', 44 | 'Topic :: Scientific/Engineering :: Chemistry', 45 | 'Topic :: Scientific/Engineering :: Physics', 46 | 'Topic :: Software Development :: Libraries :: Python Modules', 47 | ], 48 | url='http://www.sgdml.org', 49 | author='Stefan Chmiela', 50 | author_email='sgdml@chmiela.com', 51 | license='LICENSE.txt', 52 | packages=find_packages(), 53 | install_requires=['torch >= 1.8', 'numpy >= 1.19.0', 'scipy >= 1.1.0', 'psutil', 'future'], 54 | entry_points={ 55 | 'console_scripts': ['sgdml=sgdml.cli:main', 'sgdml-get=sgdml.get:main'] 56 | }, 57 | extras_require={'ase': ['ase >= 3.16.2']}, 58 | scripts=scripts, 59 | include_package_data=True, 60 | zip_safe=False, 61 | ) 62 | -------------------------------------------------------------------------------- /scripts/sgdml_dataset_to_extxyz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018-2019 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import print_function 26 | 27 | import argparse 28 | import os 29 | import sys 30 | 31 | import numpy as np 32 | 33 | from sgdml.utils import io, ui 34 | 35 | 36 | parser = argparse.ArgumentParser( 37 | description='Converts a native dataset file to extended XYZ format.' 38 | ) 39 | parser.add_argument( 40 | 'dataset', 41 | metavar='', 42 | type=lambda x: io.is_file_type(x, 'dataset'), 43 | help='path to dataset file', 44 | ) 45 | parser.add_argument( 46 | '-o', 47 | '--overwrite', 48 | dest='overwrite', 49 | action='store_true', 50 | help='overwrite existing xyz dataset file', 51 | ) 52 | 53 | args = parser.parse_args() 54 | dataset_path, dataset = args.dataset 55 | 56 | name = os.path.splitext(os.path.basename(dataset_path))[0] 57 | dataset_file_name = name + '.xyz' 58 | 59 | xyz_exists = os.path.isfile(dataset_file_name) 60 | if xyz_exists and args.overwrite: 61 | print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing xyz dataset file.') 62 | if not xyz_exists or args.overwrite: 63 | print(ui.color_str('[INFO]', bold=True) + ' Writing dataset to \'{}\'...'.format(dataset_file_name)) 64 | else: 65 | sys.exit( 66 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'{}\' already exists.'.format(dataset_file_name) 67 | ) 68 | 69 | R = dataset['R'] 70 | z = dataset['z'] 71 | F = dataset['F'] 72 | 73 | lattice = dataset['lattice'] if 'lattice' in dataset else None 74 | 75 | try: 76 | with open(dataset_file_name, 'w') as file: 77 | 78 | n = R.shape[0] 79 | for i, r in enumerate(R): 80 | 81 | e = np.squeeze(dataset['E'][i]) if 'E' in dataset else None 82 | f = dataset['F'][i,:,:] 83 | ext_xyz_str = io.generate_xyz_str(r, z, e=e, f=f, lattice=lattice) + '\n' 84 | 85 | file.write(ext_xyz_str) 86 | 87 | progr = float(i) / (n - 1) 88 | ui.callback(i, n - 1, disp_str='Exporting %d data points...' % n) 89 | 90 | except IOError: 91 | sys.exit("ERROR: Writing xyz file failed.") 92 | 93 | print() 94 | -------------------------------------------------------------------------------- /scripts/sgdml_datasets_from_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import print_function 26 | 27 | import argparse 28 | import os 29 | import sys 30 | 31 | import numpy as np 32 | 33 | from sgdml.utils import io, ui 34 | 35 | parser = argparse.ArgumentParser( 36 | description='Extracts the training and test data subsets from a dataset that were used to construct a model.' 37 | ) 38 | parser.add_argument( 39 | 'model', 40 | metavar='', 41 | type=lambda x: io.is_file_type(x, 'model'), 42 | help='path to model file', 43 | ) 44 | parser.add_argument( 45 | 'dataset', 46 | metavar='', 47 | type=lambda x: io.is_file_type(x, 'dataset'), 48 | help='path to dataset file referenced in model', 49 | ) 50 | parser.add_argument( 51 | '-o', 52 | '--overwrite', 53 | dest='overwrite', 54 | action='store_true', 55 | help='overwrite existing files', 56 | ) 57 | args = parser.parse_args() 58 | 59 | model_path, model = args.model 60 | dataset_path, dataset = args.dataset 61 | 62 | 63 | for s in ['train', 'valid']: 64 | 65 | if dataset['md5'] != model['md5_' + s]: 66 | sys.exit( 67 | ui.fail_str('[FAIL]') 68 | + ' Dataset fingerprint does not match the one referenced in model for \'%s\'.' 69 | % s 70 | ) 71 | 72 | idxs = model['idxs_' + s] 73 | R = dataset['R'][idxs, :, :] 74 | E = dataset['E'][idxs] 75 | F = dataset['F'][idxs, :, :] 76 | 77 | base_vars = { 78 | 'type': 'd', 79 | 'name': dataset['name'].astype(str), 80 | 'theory': dataset['theory'].astype(str), 81 | 'z': dataset['z'], 82 | 'R': R, 83 | 'E': E, 84 | 'F': F, 85 | } 86 | base_vars['md5'] = io.dataset_md5(base_vars) 87 | 88 | subset_file_name = '%s_%s.npz' % ( 89 | os.path.splitext(os.path.basename(dataset_path))[0], 90 | s, 91 | ) 92 | file_exists = os.path.isfile(subset_file_name) 93 | if file_exists and args.overwrite: 94 | print(ui.info_str('[INFO]') + ' Overwriting existing model file.') 95 | if not file_exists or args.overwrite: 96 | np.savez_compressed(subset_file_name, **base_vars) 97 | ui.callback(1, disp_str='Extracted %s dataset saved to \'%s\'' % (s, subset_file_name)) # DONE 98 | else: 99 | print( 100 | ui.warn_str('[WARN]') 101 | + ' %s dataset \'%s\' already exists.' % (s.capitalize(), subset_file_name) 102 | + '\n Run \'python %s -o %s %s\' to overwrite.\n' 103 | % (os.path.basename(__file__), model_path, dataset_path) 104 | ) 105 | sys.exit() 106 | -------------------------------------------------------------------------------- /sgdml/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2019-2025 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | __version__ = '1.0.3' 26 | 27 | MAX_PRINT_WIDTH = 100 28 | LOG_LEVELNAME_WIDTH = 7 # do not modify 29 | 30 | # more descriptive callback status 31 | DONE = 1 32 | NOT_DONE = 0 33 | 34 | 35 | # Logging 36 | 37 | import copy 38 | import logging 39 | import re 40 | import textwrap 41 | 42 | from .utils import ui 43 | 44 | 45 | class ColoredFormatter(logging.Formatter): 46 | 47 | LEVEL_COLORS = { 48 | 'DEBUG': (ui.CYAN, ui.BLACK), 49 | 'INFO': (ui.WHITE, ui.BLACK), 50 | 'DONE': (ui.GREEN, ui.BLACK), 51 | 'WARNING': (ui.YELLOW, ui.BLACK), 52 | 'ERROR': (ui.RED, ui.BLACK), 53 | 'CRITICAL': (ui.BLACK, ui.RED), 54 | } 55 | 56 | LEVEL_NAMES = { 57 | 'DEBUG': '[DEBG]', 58 | 'INFO': '[INFO]', 59 | 'DONE': '[DONE]', 60 | 'WARNING': '[WARN]', 61 | 'ERROR': '[FAIL]', 62 | 'CRITICAL': '[CRIT]', 63 | } 64 | 65 | def __init__(self, msg, use_color=True): 66 | 67 | logging.Formatter.__init__(self, msg) 68 | self.use_color = use_color 69 | 70 | def format(self, record): 71 | 72 | _record = copy.copy(record) 73 | levelname = _record.levelname 74 | msg = _record.msg 75 | 76 | levelname = ui.color_str( 77 | self.LEVEL_NAMES[levelname], 78 | self.LEVEL_COLORS[levelname][0], 79 | self.LEVEL_COLORS[levelname][1], 80 | bold=True, 81 | ) 82 | 83 | if _record.levelname != 'CRITICAL': 84 | # wrap long messages (except for critical [i.e. exceptions, since they print a formatted traceback string]) 85 | msg = ui.wrap_str(msg) 86 | 87 | # indent multiline strings after the first line 88 | msg = ui.indent_str(msg, LOG_LEVELNAME_WIDTH)[LOG_LEVELNAME_WIDTH:] 89 | 90 | _record.levelname = levelname 91 | _record.msg = msg 92 | return logging.Formatter.format(self, _record) 93 | 94 | 95 | class ColoredLogger(logging.Logger): 96 | def __init__(self, name): 97 | 98 | logging.Logger.__init__(self, name, logging.DEBUG) 99 | 100 | # add 'DONE' logging level 101 | logging.DONE = logging.INFO + 1 102 | logging.addLevelName(logging.DONE, 'DONE') 103 | 104 | # only display levelname and message 105 | formatter = ColoredFormatter('%(levelname)s %(message)s') 106 | 107 | # this handler will write to sys.stderr by default 108 | hd = logging.StreamHandler() 109 | hd.setFormatter(formatter) 110 | hd.setLevel( 111 | logging.INFO 112 | ) # control logging level here 113 | 114 | self.addHandler(hd) 115 | return 116 | 117 | def done(self, msg, *args, **kwargs): 118 | 119 | if self.isEnabledFor(logging.DONE): 120 | self._log(logging.DONE, msg, args, **kwargs) 121 | 122 | 123 | logging.setLoggerClass(ColoredLogger) 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Symmetric Gradient Domain Machine Learning (sGDML) 2 | 3 | For more details visit: [sgdml.org](http://sgdml.org/) 4 | Documentation can be found here: [docs.sgdml.org](http://docs.sgdml.org/) 5 | 6 | #### Requirements: 7 | - Python 3.7+ 8 | - PyTorch (>=1.8) 9 | - NumPy (>=1.19) 10 | - SciPy (>=1.1) 11 | 12 | #### Optional: 13 | - ASE (>=3.16.2) (to run atomistic simulations) 14 | 15 | ## Getting started 16 | 17 | ### Stable release 18 | 19 | Most systems come with the default package manager for Python ``pip`` already preinstalled. Install ``sgdml`` by simply calling: 20 | 21 | ``` 22 | $ pip install sgdml 23 | ``` 24 | 25 | The ``sgdml`` command-line interface and the corresponding Python API can now be used from anywhere on the system. 26 | 27 | ### Development version 28 | 29 | #### (1) Clone the repository 30 | 31 | ``` 32 | $ git clone https://github.com/stefanch/sGDML.git 33 | $ cd sGDML 34 | ``` 35 | 36 | ...or update your existing local copy with 37 | 38 | ``` 39 | $ git pull origin master 40 | ``` 41 | 42 | #### (2) Install 43 | 44 | ``` 45 | $ pip install -e . 46 | ``` 47 | 48 | Using the flag ``--user``, you can tell ``pip`` to install the package to the current users's home directory, instead of system-wide. This option might require you to update your system's ``PATH`` variable accordingly. 49 | 50 | 51 | ### Optional dependencies 52 | 53 | Some functionality of this package relies on third-party libraries that are not installed by default. These optional dependencies (or "package extras") are specified during installation using the "square bracket syntax": 54 | 55 | ``` 56 | $ pip install sgdml[] 57 | ``` 58 | 59 | #### Atomic Simulation Environment (ASE) 60 | 61 | If you are interested in interfacing with [ASE](https://wiki.fysik.dtu.dk/ase/) to perform atomistic simulations (see [here](http://docs.sgdml.org/applications.html) for examples), use the ``ase`` keyword: 62 | 63 | ``` 64 | $ pip install sgdml[ase] 65 | ``` 66 | 67 | ## Reconstruct your first force field 68 | 69 | Download one of the example datasets: 70 | 71 | ``` 72 | $ sgdml-get dataset ethanol_dft 73 | ``` 74 | 75 | Train a force field model: 76 | 77 | ``` 78 | $ sgdml all ethanol_dft.npz 200 1000 5000 79 | ``` 80 | 81 | ## Query a force field 82 | 83 | ```python 84 | import numpy as np 85 | from sgdml.predict import GDMLPredict 86 | from sgdml.utils import io 87 | 88 | r,_ = io.read_xyz('geometries/ethanol.xyz') # 9 atoms 89 | print(r.shape) # (1,27) 90 | 91 | model = np.load('models/ethanol.npz') 92 | gdml = GDMLPredict(model) 93 | e,f = gdml.predict(r) 94 | print(e.shape) # (1,) 95 | print(f.shape) # (1,27) 96 | ``` 97 | 98 | ## Authors 99 | 100 | * Stefan Chmiela 101 | * Jan Hermann 102 | 103 | We appreciate and welcome contributions and would like to thank the following people for participating in this project: 104 | 105 | * Huziel Sauceda 106 | * Igor Poltavsky 107 | * Luis Gálvez 108 | * Danny Panknin 109 | * Grégory Fonseca 110 | * Anton Charkin-Gorbulin 111 | 112 | ## References 113 | 114 | * [1] Chmiela, S., Tkatchenko, A., Sauceda, H. E., Poltavsky, I., Schütt, K. T., Müller, K.-R., 115 | *Machine Learning of Accurate Energy-conserving Molecular Force Fields.* 116 | Science Advances, 3(5), e1603015 (2017) 117 | [10.1126/sciadv.1603015](http://dx.doi.org/10.1126/sciadv.1603015) 118 | 119 | * [2] Chmiela, S., Sauceda, H. E., Müller, K.-R., Tkatchenko, A., 120 | *Towards Exact Molecular Dynamics Simulations with Machine-Learned Force Fields.* 121 | Nature Communications, 9(1), 3887 (2018) 122 | [10.1038/s41467-018-06169-2](https://doi.org/10.1038/s41467-018-06169-2) 123 | 124 | * [3] Chmiela, S., Sauceda, H. E., Poltavsky, I., Müller, K.-R., Tkatchenko, A., 125 | *sGDML: Constructing Accurate and Data Efficient Molecular Force Fields Using Machine Learning.* 126 | Computer Physics Communications, 240, 38-45 (2019) 127 | [10.1016/j.cpc.2019.02.007](https://doi.org/10.1016/j.cpc.2019.02.007) 128 | 129 | * [4] Chmiela, S., Vassilev-Galindo, V., Unke, O. T., Kabylda, A., Sauceda, H. E., Tkatchenko, A., Müller, K.-R., 130 | *Accurate Global Machine Learning Force Fields for Molecules With Hundreds of Atoms.* 131 | Science Advances, 9(2), e1603015 (2023) 132 | [10.1126/sciadv.adf0873](https://doi.org/10.1126/sciadv.adf0873) -------------------------------------------------------------------------------- /sgdml/intf/ase_calc.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018-2020 Stefan Chmiela 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import logging 24 | import numpy as np 25 | 26 | try: 27 | from ase.calculators.calculator import Calculator 28 | from ase.units import kcal, mol 29 | except ImportError: 30 | raise ImportError( 31 | 'Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.' 32 | ) 33 | 34 | from ..predict import GDMLPredict 35 | 36 | 37 | class SGDMLCalculator(Calculator): 38 | 39 | implemented_properties = ['energy', 'forces'] 40 | 41 | def __init__( 42 | self, 43 | model_path, 44 | E_to_eV=kcal / mol, 45 | F_to_eV_Ang=kcal / mol, 46 | use_torch=False, 47 | *args, 48 | **kwargs 49 | ): 50 | """ 51 | ASE calculator for the sGDML force field. 52 | 53 | A calculator takes atomic numbers and atomic positions from an Atoms object and calculates the energy and forces. 54 | 55 | Note 56 | ---- 57 | ASE uses eV and Angstrom as energy and length unit, respectively. Unless the paramerters `E_to_eV` and `F_to_eV_Ang` are specified, the sGDML model is assumed to use kcal/mol and Angstorm and the appropriate conversion factors are set accordingly. 58 | Here is how to find them: `ASE units `_. 59 | 60 | Parameters 61 | ---------- 62 | model_path : :obj:`str` 63 | Path to a sGDML model file 64 | E_to_eV : float, optional 65 | Conversion factor from whatever energy unit is used by the model to eV. By default this parameter is set to convert from kcal/mol. 66 | F_to_eV_Ang : float, optional 67 | Conversion factor from whatever length unit is used by the model to Angstrom. By default, the length unit is not converted (assumed to be in Angstrom) 68 | use_torch : boolean, optional 69 | Use PyTorch to calculate predictions 70 | """ 71 | 72 | super(SGDMLCalculator, self).__init__(*args, **kwargs) 73 | 74 | self.log = logging.getLogger(__name__) 75 | 76 | model = np.load(model_path, allow_pickle=True) 77 | self.gdml_predict = GDMLPredict(model, use_torch=use_torch) 78 | self.gdml_predict.prepare_parallel(n_bulk=1) 79 | 80 | self.log.warning( 81 | 'Please remember to specify the proper conversion factors, if your model does not use \'kcal/mol\' and \'Ang\' as units.' 82 | ) 83 | 84 | # Converts energy from the unit used by the sGDML model to eV. 85 | self.E_to_eV = E_to_eV 86 | 87 | # Converts length from eV to unit used in sGDML model. 88 | self.Ang_to_R = F_to_eV_Ang / E_to_eV 89 | 90 | # Converts force from the unit used by the sGDML model to eV/Ang. 91 | self.F_to_eV_Ang = F_to_eV_Ang 92 | 93 | def calculate(self, atoms=None, *args, **kwargs): 94 | 95 | super(SGDMLCalculator, self).calculate(atoms, *args, **kwargs) 96 | 97 | # convert model units to ASE default units 98 | r = np.array(atoms.get_positions()) * self.Ang_to_R 99 | 100 | e, f = self.gdml_predict.predict(r.ravel()) 101 | 102 | # convert model units to ASE default units (eV and Ang) 103 | e *= self.E_to_eV 104 | f *= self.F_to_eV_Ang 105 | 106 | self.results = {'energy': e, 'forces': f.reshape(-1, 3)} 107 | -------------------------------------------------------------------------------- /sgdml/get.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018-2023 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import print_function 26 | 27 | import argparse 28 | import os 29 | import re 30 | import sys 31 | 32 | from . import __version__ 33 | from .utils import ui 34 | 35 | if sys.version[0] == '3': 36 | raw_input = input 37 | 38 | try: 39 | from urllib.request import urlopen 40 | except ImportError: 41 | from urllib2 import urlopen 42 | 43 | 44 | def download(command, file_name): 45 | 46 | base_url = 'http://www.quantum-machine.org/gdml/' + ( 47 | 'data/npz/' if command == 'dataset' else 'models/' 48 | ) 49 | request = urlopen(base_url + file_name) 50 | file = open(file_name, 'wb') 51 | filesize = int(request.headers['Content-Length']) 52 | 53 | size = 0 54 | block_sz = 1024 55 | while True: 56 | buffer = request.read(block_sz) 57 | if not buffer: 58 | break 59 | size += len(buffer) 60 | file.write(buffer) 61 | 62 | ui.callback( 63 | size, 64 | filesize, 65 | disp_str='Downloading: {}'.format(file_name), 66 | sec_disp_str='{:,} bytes'.format(filesize), 67 | ) 68 | file.close() 69 | 70 | 71 | def main(): 72 | 73 | base_url = 'http://www.quantum-machine.org/gdml/' 74 | 75 | parser = argparse.ArgumentParser() 76 | 77 | parent_parser = argparse.ArgumentParser(add_help=False) 78 | parent_parser.add_argument( 79 | '-o', 80 | '--overwrite', 81 | dest='overwrite', 82 | action='store_true', 83 | help='overwrite existing files', 84 | ) 85 | 86 | subparsers = parser.add_subparsers(title='commands', dest='command') 87 | subparsers.required = True 88 | parser_dataset = subparsers.add_parser( 89 | 'dataset', help='download benchmark dataset', parents=[parent_parser] 90 | ) 91 | parser_model = subparsers.add_parser( 92 | 'model', help='download pre-trained model', parents=[parent_parser] 93 | ) 94 | 95 | for subparser in [parser_dataset, parser_model]: 96 | subparser.add_argument( 97 | 'name', 98 | metavar='', 99 | type=str, 100 | help='item name', 101 | nargs='?', 102 | default=None, 103 | ) 104 | 105 | args = parser.parse_args() 106 | 107 | print("Contacting server (%s)..." % base_url) 108 | 109 | if args.name is not None: 110 | 111 | url = '%sget.php?version=%s&%s=%s' % ( 112 | base_url, 113 | __version__, 114 | args.command, 115 | args.name, 116 | ) 117 | response = urlopen(url) 118 | match, score = response.read().decode().split(',') 119 | response.close() 120 | 121 | if int(score) == 0 or ui.yes_or_no('Do you mean \'%s\'?' % match): 122 | download(args.command, match + '.npz') 123 | return 124 | 125 | response = urlopen( 126 | '%sget.php?version=%s&%s' % (base_url, __version__, args.command) 127 | ) 128 | line = response.readlines() 129 | response.close() 130 | 131 | print() 132 | print('Available %ss:' % args.command) 133 | 134 | print('{:<2} {:<31} {:>4}'.format('ID', 'Name', 'Size')) 135 | print('-' * 42) 136 | 137 | items = line[0].split(b';') 138 | for i, item in enumerate(items): 139 | name, size = item.split(b',') 140 | size = int(size) / 1024**2 # Bytes to MBytes 141 | 142 | print('{:>2d} {:<30} {:>5.1f} MB'.format(i, name.decode("utf-8"), size)) 143 | print() 144 | 145 | down_list = raw_input( 146 | 'Please list which %ss to download (e.g. 0 1 2 6) or type \'all\': ' 147 | % args.command 148 | ) 149 | down_idxs = [] 150 | if 'all' in down_list.lower(): 151 | down_idxs = list(range(len(items))) 152 | elif re.match( 153 | "^ *[0-9][0-9 ]*$", down_list 154 | ): # only digits and spaces, at least one digit 155 | down_idxs = [int(idx) for idx in re.split(r'\s+', down_list.strip())] 156 | down_idxs = list(set(down_idxs)) 157 | else: 158 | print(ui.color_str('ABORTED', fore_color=ui.RED, bold=True)) 159 | 160 | for idx in down_idxs: 161 | if idx not in range(len(items)): 162 | print( 163 | ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True) 164 | + ' Index ' 165 | + str(idx) 166 | + ' out of range, skipping.' 167 | ) 168 | else: 169 | name = items[idx].split(b',')[0].decode("utf-8") 170 | if os.path.exists(name): 171 | print("'%s' exists, skipping." % (name)) 172 | continue 173 | 174 | download(args.command, name + '.npz') 175 | 176 | 177 | if __name__ == "__main__": 178 | main() 179 | -------------------------------------------------------------------------------- /sgdml/solvers/analytic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2020-2022 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import sys 26 | import logging 27 | import warnings 28 | from functools import partial 29 | 30 | import numpy as np 31 | import scipy as sp 32 | import timeit 33 | 34 | from .. import DONE, NOT_DONE 35 | 36 | 37 | class Analytic(object): 38 | def __init__(self, gdml_train, desc, callback=None): 39 | 40 | self.log = logging.getLogger(__name__) 41 | 42 | self.gdml_train = gdml_train 43 | self.desc = desc 44 | 45 | self.callback = callback 46 | 47 | # from memory_profiler import profile 48 | # @profile 49 | def solve(self, task, R_desc, R_d_desc, tril_perms_lin, y): 50 | 51 | sig = task['sig'] 52 | lam = task['lam'] 53 | use_E_cstr = task['use_E_cstr'] 54 | 55 | n_train, dim_d = R_d_desc.shape[:2] 56 | n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2) 57 | dim_i = 3 * n_atoms 58 | 59 | if self.callback is not None: 60 | self.callback = partial( 61 | self.callback, 62 | disp_str='Assembling kernel matrix', 63 | ) 64 | 65 | K = -self.gdml_train._assemble_kernel_mat( 66 | R_desc, 67 | R_d_desc, 68 | tril_perms_lin, 69 | sig, 70 | self.desc, 71 | use_E_cstr=use_E_cstr, 72 | callback=self.callback, 73 | ) # Flip sign to make convex 74 | 75 | start = timeit.default_timer() 76 | 77 | with warnings.catch_warnings(): 78 | warnings.simplefilter('ignore') 79 | 80 | if K.shape[0] == K.shape[1]: 81 | 82 | K[np.diag_indices_from(K)] += lam # Regularize 83 | 84 | if self.callback is not None: 85 | self.callback = partial( 86 | self.callback, 87 | disp_str='Solving linear system (Cholesky factorization)', 88 | ) 89 | self.callback(NOT_DONE) 90 | 91 | try: 92 | 93 | # Cholesky (do not overwrite K in case we need to retry) 94 | L, lower = sp.linalg.cho_factor( 95 | K, overwrite_a=False, check_finite=False 96 | ) 97 | alphas = -sp.linalg.cho_solve( 98 | (L, lower), y, overwrite_b=False, check_finite=False 99 | ) 100 | 101 | except np.linalg.LinAlgError: # Try a solver that makes less assumptions 102 | 103 | if self.callback is not None: 104 | self.callback = partial( 105 | self.callback, 106 | disp_str='Solving linear system (LU factorization) ', # Keep whitespaces! 107 | ) 108 | self.callback(NOT_DONE) 109 | 110 | try: 111 | # LU 112 | alphas = -sp.linalg.solve( 113 | K, y, overwrite_a=True, overwrite_b=True, check_finite=False 114 | ) 115 | except MemoryError: 116 | self.log.critical( 117 | 'Not enough memory to train this system using a closed form solver.' 118 | ) 119 | print() 120 | os._exit(1) 121 | 122 | except MemoryError: 123 | self.log.critical( 124 | 'Not enough memory to train this system using a closed form solver.' 125 | ) 126 | print() 127 | os._exit(1) 128 | else: 129 | 130 | if self.callback is not None: 131 | self.callback = partial( 132 | self.callback, 133 | disp_str='Solving over-determined linear system (least squares approximation)', 134 | ) 135 | self.callback(NOT_DONE) 136 | 137 | # Least squares for non-square K 138 | alphas = -np.linalg.lstsq(K, y, rcond=-1)[0] 139 | 140 | stop = timeit.default_timer() 141 | 142 | if self.callback is not None: 143 | dur_s = stop - start 144 | sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else '' 145 | self.callback( 146 | DONE, 147 | disp_str='Training on {:,} points'.format(n_train), 148 | sec_disp_str=sec_disp_str, 149 | ) 150 | 151 | return alphas 152 | 153 | @staticmethod 154 | def est_memory_requirement(n_train, n_atoms): 155 | 156 | est_bytes = 3 * (n_train * 3 * n_atoms) ** 2 * 8 # K + factor(s) of K 157 | est_bytes += (n_train * 3 * n_atoms) * 8 # alpha 158 | 159 | return est_bytes 160 | -------------------------------------------------------------------------------- /scripts/sgdml_dataset_from_aims.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018-2022 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import print_function 26 | 27 | import argparse 28 | import os 29 | import sys 30 | 31 | import numpy as np 32 | 33 | from sgdml.utils import io, ui 34 | 35 | 36 | def read_reference_data(f): # noqa C901 37 | eV_to_kcalmol = 0.036749326 / 0.0015946679 38 | 39 | e_next, f_next, geo_next = False, False, False 40 | n_atoms = None 41 | R, z, E, F = [], [], [], [] 42 | 43 | geo_idx = 0 44 | for line in f: 45 | if n_atoms: 46 | cols = line.split() 47 | if e_next: 48 | E.append(float(cols[5])) 49 | e_next = False 50 | elif f_next: 51 | a = int(cols[1]) - 1 52 | F.append(list(map(float, cols[2:5]))) 53 | if a == n_atoms - 1: 54 | f_next = False 55 | elif geo_next: 56 | if 'atom' in cols: 57 | a_count += 1 # noqa: F821 58 | R.append(list(map(float, cols[1:4]))) 59 | 60 | if geo_idx == 0: 61 | z.append(io._z_str_to_z_dict[cols[4]]) 62 | 63 | if a_count == n_atoms: 64 | geo_next = False 65 | geo_idx += 1 66 | elif 'Energy and forces in a compact form:' in line: 67 | e_next = True 68 | elif 'Total atomic forces (unitary forces cleaned) [eV/Ang]:' in line: 69 | f_next = True 70 | elif ( 71 | 'Atomic structure (and velocities) as used in the preceding time step:' 72 | in line 73 | ): 74 | geo_next = True 75 | a_count = 0 76 | elif 'The structure contains' in line and 'atoms, and a total of' in line: 77 | n_atoms = int(line.split()[3]) 78 | print('Number atoms per geometry: {:>7d}'.format(n_atoms)) 79 | continue 80 | 81 | if geo_idx > 0 and geo_idx % 1000 == 0: 82 | sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(geo_idx)) 83 | sys.stdout.flush() 84 | sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(geo_idx)) 85 | sys.stdout.flush() 86 | print( 87 | '\n' 88 | + ui.color_str('[INFO]', bold=True) 89 | + ' Energies and forces have been converted from eV to kcal/mol(/Ang)' 90 | ) 91 | 92 | R = np.array(R).reshape(-1, n_atoms, 3) 93 | z = np.array(z) 94 | E = np.array(E) * eV_to_kcalmol 95 | F = np.array(F).reshape(-1, n_atoms, 3) * eV_to_kcalmol 96 | 97 | f.close() 98 | return (R, z, E, F) 99 | 100 | 101 | parser = argparse.ArgumentParser(description='Creates a dataset from FHI-aims format.') 102 | parser.add_argument( 103 | 'dataset', 104 | metavar='', 105 | type=argparse.FileType('r'), 106 | help='path to xyz dataset file', 107 | ) 108 | parser.add_argument( 109 | '-o', 110 | '--overwrite', 111 | dest='overwrite', 112 | action='store_true', 113 | help='overwrite existing dataset file', 114 | ) 115 | args = parser.parse_args() 116 | dataset = args.dataset 117 | 118 | name = os.path.splitext(os.path.basename(dataset.name))[0] 119 | dataset_file_name = name + '.npz' 120 | 121 | dataset_exists = os.path.isfile(dataset_file_name) 122 | if dataset_exists and args.overwrite: 123 | print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.') 124 | if not dataset_exists or args.overwrite: 125 | print('Writing dataset to \'%s\'...' % dataset_file_name) 126 | else: 127 | sys.exit( 128 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'%s\' already exists.' % dataset_file_name 129 | ) 130 | 131 | R, z, E, F = read_reference_data(dataset) 132 | 133 | # Prune all arrays to same length. 134 | n_mols = min(min(R.shape[0], F.shape[0]), E.shape[0]) 135 | if n_mols != R.shape[0] or n_mols != F.shape[0] or n_mols != E.shape[0]: 136 | print( 137 | ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True) 138 | + ' Incomplete output detected: Final dataset was pruned to %d points.' % n_mols 139 | ) 140 | R = R[:n_mols, :, :] 141 | F = F[:n_mols, :, :] 142 | E = E[:n_mols] 143 | 144 | # Base variables contained in every model file. 145 | base_vars = { 146 | 'type': 'd', 147 | 'R': R, 148 | 'z': z, 149 | 'E': E[:, None], 150 | 'F': F, 151 | 'e_unit': 'kcal/mol', 152 | 'r_unit': 'Ang', 153 | 'name': name, 154 | 'theory': 'unknown', 155 | } 156 | 157 | base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel()) 158 | base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel()) 159 | 160 | base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E) 161 | base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E) 162 | 163 | base_vars['md5'] = io.dataset_md5(base_vars) 164 | 165 | np.savez_compressed(dataset_file_name, **base_vars) 166 | print(ui.color_str('DONE', fore_color=ui.GREEN, bold=True)) 167 | -------------------------------------------------------------------------------- /scripts/sgdml_dataset_from_ipi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import print_function 26 | 27 | import argparse 28 | import os 29 | import sys 30 | 31 | import numpy as np 32 | 33 | from sgdml.utils import io, ui 34 | 35 | 36 | def raw_input_float(prompt): 37 | while True: 38 | try: 39 | return float(input(prompt)) 40 | except ValueError: 41 | print(ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' That is not a valid float.') 42 | 43 | 44 | # Assumes that the atoms in each molecule are in the same order. 45 | def read_concat_xyz(f): 46 | n_atoms = None 47 | 48 | R, z = [], [] 49 | for i, line in enumerate(f): 50 | line = line.strip() 51 | if not n_atoms: 52 | n_atoms = int(line) 53 | print('Number atoms per geometry: {:>7d}'.format(n_atoms)) 54 | 55 | file_i, line_i = divmod(i, n_atoms + 2) 56 | 57 | cols = line.split() 58 | if line_i >= 2: 59 | if file_i == 0: # first molecule 60 | z.append(io._z_str_to_z_dict[cols[0]]) 61 | R.append(list(map(float, cols[1:4]))) 62 | 63 | if file_i % 1000 == 0: 64 | sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(file_i)) 65 | sys.stdout.flush() 66 | sys.stdout.write("\rNumber geometries found so far: {:>7d}\n".format(file_i)) 67 | sys.stdout.flush() 68 | 69 | # Only keep complete entries. 70 | R = R[: int(n_atoms * np.floor(len(R) / float(n_atoms)))] 71 | 72 | R = np.array(R).reshape(-1, n_atoms, 3) 73 | z = np.array(z) 74 | 75 | f.close() 76 | return (R, z) 77 | 78 | 79 | def read_out_file(f, col): 80 | 81 | E = [] 82 | for i, line in enumerate(f): 83 | line = line.strip() 84 | if line[0] != '#': # Ignore comments. 85 | E.append(float(line.split()[col])) 86 | if i % 1000 == 0: 87 | sys.stdout.write("\rNumber lines processed so far: {:>7d}".format(len(E))) 88 | sys.stdout.flush() 89 | sys.stdout.write("\rNumber lines processed so far: {:>7d}\n".format(len(E))) 90 | sys.stdout.flush() 91 | 92 | return np.array(E) 93 | 94 | 95 | parser = argparse.ArgumentParser( 96 | description='Creates a dataset from extended [TODO] format.' 97 | ) 98 | parser.add_argument( 99 | 'geometries', 100 | metavar='', 101 | type=argparse.FileType('r'), 102 | help='path to XYZ geometry file', 103 | ) 104 | parser.add_argument( 105 | 'forces', 106 | metavar='', 107 | type=argparse.FileType('r'), 108 | help='path to XYZ force file', 109 | ) 110 | parser.add_argument( 111 | 'energies', 112 | metavar='', 113 | type=argparse.FileType('r'), 114 | help='path to CSV force file', 115 | ) 116 | parser.add_argument( 117 | 'energy_col', 118 | metavar='', 119 | type=lambda x: io.is_strict_pos_int(x), 120 | help='which column to parse from energy file (zero based)', 121 | nargs='?', 122 | default=0, 123 | ) 124 | parser.add_argument( 125 | '-o', 126 | '--overwrite', 127 | dest='overwrite', 128 | action='store_true', 129 | help='overwrite existing dataset file', 130 | ) 131 | args = parser.parse_args() 132 | geometries = args.geometries 133 | forces = args.forces 134 | energies = args.energies 135 | energy_col = args.energy_col 136 | 137 | name = os.path.splitext(os.path.basename(geometries.name))[0] 138 | dataset_file_name = name + '.npz' 139 | 140 | dataset_exists = os.path.isfile(dataset_file_name) 141 | if dataset_exists and args.overwrite: 142 | print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.') 143 | if not dataset_exists or args.overwrite: 144 | print('Writing dataset to \'%s\'...' % dataset_file_name) 145 | else: 146 | sys.exit( 147 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'%s\' already exists.' % dataset_file_name 148 | ) 149 | 150 | 151 | print('Reading geometries...') 152 | R, z = read_concat_xyz(geometries) 153 | 154 | print('Reading forces...') 155 | F, _ = read_concat_xyz(forces) 156 | 157 | print('Reading energies from column %d...' % energy_col) 158 | E = read_out_file(energies, energy_col) 159 | 160 | # Prune all arrays to same length. 161 | n_mols = min(min(R.shape[0], F.shape[0]), E.shape[0]) 162 | if n_mols != R.shape[0] or n_mols != F.shape[0] or n_mols != E.shape[0]: 163 | print( 164 | ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True) 165 | + ' Incomplete output detected: Final dataset was pruned to %d points.' % n_mols 166 | ) 167 | R = R[:n_mols, :, :] 168 | F = F[:n_mols, :, :] 169 | E = E[:n_mols] 170 | 171 | print( 172 | ui.color_str('[INFO]', bold=True) 173 | + ' Geometries, forces and energies must have consistent units.' 174 | ) 175 | R_conv_fact = raw_input_float('Unit conversion factor for geometries: ') 176 | R = R * R_conv_fact 177 | F_conv_fact = raw_input_float('Unit conversion factor for forces: ') 178 | F = F * F_conv_fact 179 | E_conv_fact = raw_input_float('Unit conversion factor for energies: ') 180 | E = E * E_conv_fact 181 | 182 | # Base variables contained in every model file. 183 | base_vars = { 184 | 'type': 'd', 185 | 'R': R, 186 | 'z': z, 187 | 'E': E[:, None], 188 | 'F': F, 189 | 'name': name, 190 | 'theory': 'unknown', 191 | } 192 | base_vars['md5'] = io.dataset_md5(base_vars) 193 | 194 | np.savez_compressed(dataset_file_name, **base_vars) 195 | ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True) 196 | -------------------------------------------------------------------------------- /scripts/sgdml_dataset_via_ase.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018-2022 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import print_function 26 | 27 | import argparse 28 | import os 29 | import sys 30 | 31 | try: 32 | from ase.io import read 33 | except ImportError: 34 | raise ImportError('Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.') 35 | 36 | import numpy as np 37 | 38 | from sgdml import __version__ 39 | from sgdml.utils import io, ui 40 | 41 | if sys.version[0] == '3': 42 | raw_input = input 43 | 44 | 45 | parser = argparse.ArgumentParser( 46 | description='Creates a dataset from any input format supported by ASE.' 47 | ) 48 | parser.add_argument( 49 | 'dataset', 50 | metavar='', 51 | type=argparse.FileType('r'), 52 | help='path to input dataset file', 53 | ) 54 | parser.add_argument( 55 | '-o', 56 | '--overwrite', 57 | dest='overwrite', 58 | action='store_true', 59 | help='overwrite existing dataset file', 60 | ) 61 | args = parser.parse_args() 62 | dataset = args.dataset 63 | 64 | 65 | name = os.path.splitext(os.path.basename(dataset.name))[0] 66 | dataset_file_name = name + '.npz' 67 | 68 | dataset_exists = os.path.isfile(dataset_file_name) 69 | if dataset_exists and args.overwrite: 70 | print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.') 71 | if not dataset_exists or args.overwrite: 72 | print('Writing dataset to \'{}\'...'.format(dataset_file_name)) 73 | else: 74 | sys.exit( 75 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) 76 | + ' Dataset \'{}\' already exists.'.format(dataset_file_name) 77 | ) 78 | 79 | mols = read(dataset.name, index=':') 80 | 81 | # filter incomplete outputs from trajectory 82 | mols = [mol for mol in mols if mol.get_calculator() is not None] 83 | 84 | lattice, R, z, E, F = None, None, None, None, None 85 | 86 | calc = mols[0].get_calculator() 87 | 88 | print("\rNumber geometries: {:,}".format(len(mols))) 89 | #print("\rAvailable properties: " + ', '.join(calc.results)) 90 | print() 91 | 92 | if 'forces' not in calc.results: 93 | sys.exit( 94 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) 95 | + ' Forces are missing in the input file!' 96 | ) 97 | 98 | lattice = np.array(mols[0].get_cell().T) 99 | if not np.any(lattice): 100 | print( 101 | ui.color_str('[INFO]', bold=True) 102 | + ' No lattice vectors specified.' 103 | ) 104 | lattice = None 105 | 106 | Z = np.array([mol.get_atomic_numbers() for mol in mols]) 107 | all_z_the_same = (Z == Z[0]).all() 108 | if not all_z_the_same: 109 | sys.exit( 110 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) 111 | + ' Order of atoms changes accross dataset.' 112 | ) 113 | 114 | R = np.array([mol.get_positions() for mol in mols]) 115 | z = Z[0] 116 | 117 | if 'Energy' in mols[0].info: 118 | E = np.array([float(mol.info['Energy']) for mol in mols]) 119 | else: 120 | E = np.array([mol.get_potential_energy() for mol in mols]) 121 | F = np.array([mol.get_forces() for mol in mols]) 122 | 123 | print('Please provide a name for this dataset. Otherwise the original filename will be reused.') 124 | custom_name = raw_input('> ').strip() 125 | if custom_name != '': 126 | name = custom_name 127 | 128 | print('Please provide a descriptor for the level of theory used to create this dataset.') 129 | theory = raw_input('> ').strip() 130 | if theory == '': 131 | theory = 'unknown' 132 | 133 | # Base variables contained in every model file. 134 | base_vars = { 135 | 'type': 'd', 136 | 'code_version': __version__, 137 | 'name': name, 138 | 'theory': theory, 139 | 'R': R, 140 | 'z': z, 141 | 'F': F, 142 | } 143 | 144 | base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel()) 145 | base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel()) 146 | 147 | print('If you want to convert your original length unit, please provide a conversion factor (default: 1.0): ') 148 | R_to_new_unit = raw_input('> ').strip() 149 | if R_to_new_unit != '': 150 | R_to_new_unit = float(R_to_new_unit) 151 | else: 152 | R_to_new_unit = 1.0 153 | 154 | print('If you want to convert your original energy unit, please provide a conversion factor (default: 1.0): ') 155 | E_to_new_unit = raw_input('> ').strip() 156 | if E_to_new_unit != '': 157 | E_to_new_unit = float(E_to_new_unit) 158 | else: 159 | E_to_new_unit = 1.0 160 | 161 | print('Please provide a description of the length unit, e.g. \'Ang\' or \'au\': ') 162 | print('Note: This string will be stored in the dataset file and passed on to models files for later reference.') 163 | r_unit = raw_input('> ').strip() 164 | if r_unit != '': 165 | base_vars['r_unit'] = r_unit 166 | 167 | print('Please provide a description of the energy unit, e.g. \'kcal/mol\' or \'eV\': ') 168 | print('Note: This string will be stored in the dataset file and passed on to models files for later reference.') 169 | e_unit = raw_input('> ').strip() 170 | if e_unit != '': 171 | base_vars['e_unit'] = e_unit 172 | 173 | if E is not None: 174 | base_vars['E'] = E * E_to_new_unit 175 | base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E) 176 | base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E) 177 | else: 178 | print(ui.color_str('[INFO]', bold=True) + ' No energy labels found in dataset.') 179 | 180 | base_vars['R'] *= R_to_new_unit 181 | base_vars['F'] *= E_to_new_unit / R_to_new_unit 182 | 183 | if lattice is not None: 184 | base_vars['lattice'] = lattice 185 | 186 | base_vars['md5'] = io.dataset_md5(base_vars) 187 | np.savez_compressed(dataset_file_name, **base_vars) 188 | print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True)) 189 | -------------------------------------------------------------------------------- /scripts/sgdml_dataset_from_extxyz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018-2022 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import print_function 26 | 27 | import argparse 28 | import os 29 | import sys 30 | 31 | try: 32 | from ase.io import read 33 | except ImportError: 34 | raise ImportError('Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.') 35 | 36 | import numpy as np 37 | 38 | from sgdml import __version__ 39 | from sgdml.utils import io, ui 40 | 41 | if sys.version[0] == '3': 42 | raw_input = input 43 | 44 | 45 | # Note: assumes that the atoms in each molecule are in the same order. 46 | def read_nonstd_ext_xyz(f): 47 | n_atoms = None 48 | 49 | R, z, E, F = [], [], [], [] 50 | for i, line in enumerate(f): 51 | line = line.strip() 52 | if not n_atoms: 53 | n_atoms = int(line) 54 | print('Number atoms per geometry: {:,}'.format(n_atoms)) 55 | 56 | file_i, line_i = divmod(i, n_atoms + 2) 57 | 58 | if line_i == 1: 59 | try: 60 | e = float(line) 61 | except ValueError: 62 | pass 63 | else: 64 | E.append(e) 65 | 66 | cols = line.split() 67 | if line_i >= 2: 68 | R.append(list(map(float, cols[1:4]))) 69 | if file_i == 0: # first molecule 70 | z.append(io._z_str_to_z_dict[cols[0]]) 71 | F.append(list(map(float, cols[4:7]))) 72 | 73 | if file_i % 1000 == 0: 74 | sys.stdout.write('\rNumber geometries found so far: {:,}'.format(file_i)) 75 | sys.stdout.flush() 76 | sys.stdout.write('\rNumber geometries found so far: {:,}'.format(file_i)) 77 | sys.stdout.flush() 78 | print() 79 | 80 | R = np.array(R).reshape(-1, n_atoms, 3) 81 | z = np.array(z) 82 | E = None if not E else np.array(E) 83 | F = np.array(F).reshape(-1, n_atoms, 3) 84 | 85 | if F.shape[0] != R.shape[0]: 86 | sys.exit( 87 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) 88 | + ' Force labels are missing from dataset or are incomplete!' 89 | ) 90 | 91 | f.close() 92 | return (R, z, E, F) 93 | 94 | # Extracts info string for each frame. 95 | def extract_info_from_extxyz(file_path): 96 | infos = [] 97 | 98 | with open(file_path) as f: 99 | lines = f.readlines() 100 | 101 | i = 0 102 | while i < len(lines): 103 | try: 104 | n_atoms = int(lines[i]) 105 | except ValueError: 106 | raise ValueError(f"Invalid atom count at line {i + 1}") 107 | 108 | if i + 1 >= len(lines): 109 | break 110 | 111 | comment_line = lines[i + 1].strip() 112 | info = {} 113 | for token in comment_line.split(): 114 | if "=" in token: 115 | key, val = token.split("=", 1) 116 | val = val.strip('"') 117 | try: 118 | val = float(val) 119 | except ValueError: 120 | pass 121 | info[key] = val 122 | infos.append(info) 123 | 124 | i += 2 + n_atoms 125 | 126 | return infos 127 | 128 | 129 | parser = argparse.ArgumentParser( 130 | description='Creates a dataset from extended XYZ format.' 131 | ) 132 | parser.add_argument( 133 | 'dataset', 134 | metavar='', 135 | type=argparse.FileType('r'), 136 | help='path to extended xyz dataset file', 137 | ) 138 | parser.add_argument( 139 | '-o', 140 | '--overwrite', 141 | dest='overwrite', 142 | action='store_true', 143 | help='overwrite existing dataset file', 144 | ) 145 | args = parser.parse_args() 146 | dataset = args.dataset 147 | 148 | 149 | name = os.path.splitext(os.path.basename(dataset.name))[0] 150 | dataset_file_name = name + '.npz' 151 | 152 | dataset_exists = os.path.isfile(dataset_file_name) 153 | if dataset_exists and args.overwrite: 154 | print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.') 155 | if not dataset_exists or args.overwrite: 156 | print('Writing dataset to \'{}\'...'.format(dataset_file_name)) 157 | else: 158 | sys.exit( 159 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) 160 | + ' Dataset \'{}\' already exists.'.format(dataset_file_name) 161 | ) 162 | 163 | lattice, R, z, E, F = None, None, None, None, None 164 | 165 | mols = read(dataset.name, format='extxyz', index=':') 166 | #calc = mols[0].get_calculator() # depreciated 167 | calc = mols[0].calc 168 | is_extxyz = calc is not None 169 | if is_extxyz: 170 | 171 | print("\rNumber geometries found: {:,}\n".format(len(mols))) 172 | 173 | if 'forces' not in calc.results: 174 | sys.exit( 175 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) 176 | + ' Forces are missing in the input file!' 177 | ) 178 | 179 | lattice = np.array(mols[0].get_cell().T) 180 | if not np.any(lattice): # all zeros 181 | print( 182 | ui.color_str('[INFO]', bold=True) 183 | + ' No lattice vectors specified in extended XYZ file.' 184 | ) 185 | lattice = None 186 | 187 | Z = np.array([mol.get_atomic_numbers() for mol in mols]) 188 | all_z_the_same = (Z == Z[0]).all() 189 | if not all_z_the_same: 190 | sys.exit( 191 | ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) 192 | + ' Order of atoms changes accross dataset.' 193 | ) 194 | 195 | R = np.array([mol.get_positions() for mol in mols]) 196 | z = Z[0] 197 | 198 | # ASE did not parse info string. Try doing it manually. 199 | if not mols[0].info: 200 | 201 | print( 202 | ui.color_str('[INFO]', bold=True) 203 | + ' ASE did not parse info string completely. Try doing it manually.' 204 | ) 205 | 206 | infos = extract_info_from_extxyz(dataset.name) 207 | for mol, info in zip(mols, infos): 208 | mol.info.update(info) 209 | 210 | if 'Energy' in mols[0].info: 211 | E = np.array([mol.info['Energy'] for mol in mols]) 212 | if 'energy' in mols[0].info: 213 | E = np.array([mol.info['energy'] for mol in mols]) 214 | F = np.array([mol.get_forces() for mol in mols]) 215 | 216 | else: # legacy non-standard XYZ format 217 | 218 | with open(dataset.name) as f: 219 | R, z, E, F = read_nonstd_ext_xyz(f) 220 | 221 | # Base variables contained in every model file. 222 | base_vars = { 223 | 'type': 'd', 224 | 'code_version': __version__, 225 | 'name': name, 226 | 'theory': 'unknown', 227 | 'R': R, 228 | 'z': z, 229 | 'F': F, 230 | } 231 | 232 | base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel()) 233 | base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel()) 234 | 235 | print('Please provide a description of the length unit used in your input file, e.g. \'Ang\' or \'au\': ') 236 | print('Note: This string will be stored in the dataset file and passed on to models files for later reference.') 237 | r_unit = raw_input('> ').strip() 238 | if r_unit != '': 239 | base_vars['r_unit'] = r_unit 240 | 241 | print('Please provide a description of the energy unit used in your input file, e.g. \'kcal/mol\' or \'eV\': ') 242 | print('Note: This string will be stored in the dataset file and passed on to models files for later reference.') 243 | e_unit = raw_input('> ').strip() 244 | if e_unit != '': 245 | base_vars['e_unit'] = e_unit 246 | 247 | if E is not None: 248 | base_vars['E'] = E 249 | base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E) 250 | base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E) 251 | else: 252 | print(ui.color_str('[INFO]', bold=True) + ' No energy labels found in dataset.') 253 | 254 | if lattice is not None: 255 | base_vars['lattice'] = lattice 256 | 257 | base_vars['md5'] = io.dataset_md5(base_vars) 258 | np.savez_compressed(dataset_file_name, **base_vars) 259 | print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True)) 260 | -------------------------------------------------------------------------------- /sgdml/utils/ui.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018-2021 Stefan Chmiela 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import print_function 26 | from functools import partial 27 | 28 | from .. import __version__, MAX_PRINT_WIDTH, LOG_LEVELNAME_WIDTH 29 | import textwrap 30 | import re 31 | import sys 32 | 33 | if sys.version[0] == '3': 34 | raw_input = input 35 | 36 | import numpy as np 37 | 38 | 39 | def yes_or_no(question): 40 | """ 41 | Ask for yes/no user input on a question. 42 | 43 | Any response besides ``y`` yields a negative answer. 44 | 45 | Parameters 46 | ---------- 47 | question : :obj:`str` 48 | User question. 49 | """ 50 | 51 | reply = raw_input(question + ' (y/n): ').lower().strip() 52 | if not reply or reply[0] != 'y': 53 | return False 54 | else: 55 | return True 56 | 57 | 58 | last_callback_pct = 0 59 | 60 | 61 | def callback( 62 | current, 63 | total=1, 64 | disp_str='', 65 | sec_disp_str=None, 66 | done_with_warning=False, 67 | newline_when_done=True, 68 | ): 69 | """ 70 | Print progress or toggle bar. 71 | 72 | Example (progress): 73 | ``[ 45%] Task description (secondary string)`` 74 | 75 | Example (toggle, not done): 76 | ``[ .. ] Task description (secondary string)`` 77 | 78 | Example (toggle, done): 79 | ``[DONE] Task description (secondary string)`` 80 | 81 | Parameters 82 | ---------- 83 | current : int 84 | How many items already processed? 85 | total : int, optional 86 | Total number of items? If there is only 87 | one item, the toggle style is used. 88 | disp_str : :obj:`str`, optional 89 | Task description. 90 | sec_disp_str : :obj:`str`, optional 91 | Additional string shown in gray. 92 | done_with_warning : bool, optional 93 | Indicate that the process did not 94 | finish successfully. 95 | newline_when_done : bool, optional 96 | Finish with a newline character once 97 | current=total (default: True)? 98 | """ 99 | 100 | global last_callback_pct 101 | 102 | is_toggle = total == 1 103 | is_done = np.isclose(current - total, 0.0) 104 | 105 | bold_color_str = partial(color_str, bold=True) 106 | 107 | if is_toggle: 108 | 109 | if is_done: 110 | if done_with_warning: 111 | flag_str = bold_color_str('[WARN]', fore_color=YELLOW) 112 | else: 113 | flag_str = bold_color_str('[DONE]', fore_color=GREEN) 114 | 115 | else: 116 | flag_str = bold_color_str('[' + blink_str(' .. ') + ']') 117 | else: 118 | 119 | # Only show progress in 10 percent steps when not printing to terminal. 120 | pct = int(float(current) * 100 / total) 121 | pct = int(np.ceil(pct / 10.0)) * 10 if not sys.stdout.isatty() else pct 122 | 123 | # Do not print, if there is no need to. 124 | if not is_done and pct == last_callback_pct: 125 | return 126 | else: 127 | last_callback_pct = pct 128 | 129 | flag_str = bold_color_str( 130 | '[{:3d}%]'.format(pct), fore_color=GREEN if is_done else WHITE 131 | ) 132 | 133 | sys.stdout.write('\r{} {}'.format(flag_str, disp_str)) 134 | 135 | if sec_disp_str is not None: 136 | w = MAX_PRINT_WIDTH - LOG_LEVELNAME_WIDTH - len(disp_str) - 1 137 | # sys.stdout.write(' \x1b[90m{0: >{width}}\x1b[0m'.format(sec_disp_str, width=w)) 138 | sys.stdout.write( 139 | color_str(' {:>{width}}'.format(sec_disp_str, width=w), fore_color=GRAY) 140 | ) 141 | 142 | if is_done and newline_when_done: 143 | sys.stdout.write('\n') 144 | 145 | sys.stdout.flush() 146 | 147 | 148 | # use this to integrate a callback for a subtask with an existing callback function 149 | # 'subtask_callback = partial(ui.sec_callback, main_callback=self.callback)' 150 | def sec_callback( 151 | current, total=1, disp_str=None, sec_disp_str=None, main_callback=None, **kwargs 152 | ): 153 | global last_callback_pct 154 | 155 | assert main_callback is not None 156 | 157 | is_toggle = total == 1 158 | is_done = np.isclose(current - total, 0.0) 159 | 160 | sec_disp_str = disp_str 161 | if is_toggle: 162 | sec_disp_str = '{} | {}'.format(disp_str, 'DONE' if is_done else ' .. ') 163 | else: 164 | 165 | # Only show progress in 10 percent steps when not printing to terminal. 166 | pct = int(float(current) * 100 / total) 167 | pct = int(np.ceil(pct / 10.0)) * 10 if not sys.stdout.isatty() else pct 168 | 169 | # Do not print, if there is no need to. 170 | if pct == last_callback_pct: 171 | return 172 | 173 | last_callback_pct = pct 174 | sec_disp_str = '{} | {:3d}%'.format(disp_str, pct) 175 | 176 | main_callback(0, sec_disp_str=sec_disp_str, **kwargs) 177 | 178 | 179 | # COLORS 180 | 181 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, GRAY = list(range(8)) + [60] 182 | COLOR_SEQ, RESET_SEQ = '\033[{:d};{:d};{:d}m', '\033[0m' 183 | 184 | ENABLE_COLORED_OUTPUT = ( 185 | sys.stdout.isatty() 186 | ) # Running in a real terminal or piped/redirected? 187 | 188 | 189 | def color_str(str, fore_color=WHITE, back_color=BLACK, bold=False): 190 | 191 | if ENABLE_COLORED_OUTPUT: 192 | 193 | # foreground is set with 30 plus the number of the color, background with 40 194 | return ( 195 | COLOR_SEQ.format(1 if bold else 0, 30 + fore_color, 40 + back_color) 196 | + str 197 | + RESET_SEQ 198 | ) 199 | else: 200 | return str 201 | 202 | 203 | def blink_str(str): 204 | 205 | return '\x1b[5m' + str + '\x1b[0m' if ENABLE_COLORED_OUTPUT else str 206 | 207 | 208 | def unicode_str(s): 209 | 210 | if sys.version[0] == '3': 211 | s = str(s, 'utf-8', 'ignore') 212 | else: 213 | s = str(s) 214 | 215 | return s.rstrip('\x00') # remove null-characters 216 | 217 | 218 | def gen_memory_str(bytes): 219 | 220 | pwr = 1024 221 | n = 0 222 | pwr_strs = {0: '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'} 223 | while bytes > pwr and n < 4: 224 | bytes /= pwr 225 | n += 1 226 | 227 | return '{:.{num_dec_pts}f} {}B'.format( 228 | bytes, pwr_strs[n], num_dec_pts=max(0, n - 2) 229 | ) # 1 decimal point for GB, 2 for TB 230 | 231 | 232 | def gen_lattice_str(lat): 233 | 234 | lat_str, col_widths = gen_mat_str(lat) 235 | desc_str = (' '.join([('{:' + str(w) + '}') for w in col_widths])).format( 236 | 'a', 'b', 'c' 237 | ) + '\n' 238 | 239 | lat_str = indent_str(lat_str, 21) 240 | 241 | return desc_str + lat_str 242 | 243 | 244 | def str_plen(str): 245 | """ 246 | Returns printable length of string. This function can only account for invisible characters due to string styling with ``color_str``. 247 | 248 | Parameters 249 | ---------- 250 | str : :obj:`str` 251 | String. 252 | 253 | Returns 254 | ------- 255 | :obj:`str` 256 | 257 | """ 258 | 259 | num_colored_subs = str.count(RESET_SEQ) 260 | return len(str) - ( 261 | 14 * num_colored_subs 262 | ) # 14: length of invisible characters per colored segment 263 | 264 | 265 | def wrap_str(str, width=MAX_PRINT_WIDTH - LOG_LEVELNAME_WIDTH): 266 | """ 267 | Wrap multiline string after a given number of characters. The default maximum line already accounts for the indentation due to the logging level label. 268 | 269 | Parameters 270 | ---------- 271 | str : :obj:`str` 272 | Multiline string. 273 | width : int, optional 274 | Max number of characters in a line. 275 | 276 | Returns 277 | ------- 278 | :obj:`str` 279 | 280 | """ 281 | 282 | return '\n'.join( 283 | [ 284 | '\n'.join( 285 | textwrap.wrap( 286 | line, 287 | width + (len(line) - str_plen(line)), 288 | break_long_words=False, 289 | replace_whitespace=False, 290 | ) 291 | ) 292 | for line in str.splitlines() 293 | ] 294 | ) 295 | 296 | 297 | def indent_str(str, indent): 298 | """ 299 | Indents all lines of a multiline string right by a given number of 300 | characters. 301 | 302 | Parameters 303 | ---------- 304 | str : :obj:`str` 305 | Multiline string. 306 | indent : int 307 | Number of characters added in front of each line. 308 | 309 | Returns 310 | ------- 311 | :obj:`str` 312 | 313 | """ 314 | 315 | return re.sub('^', ' ' * indent, str, flags=re.MULTILINE) 316 | 317 | 318 | def wrap_indent_str(label, str, width=MAX_PRINT_WIDTH - LOG_LEVELNAME_WIDTH): 319 | """ 320 | Wraps and indents a multiline string to arrange it with the provided label in two columns. The default maximum line already accounts for the indentation due to the logging level label. 321 | 322 | Example: 323 | ``