├── images ├── example_molecules_1.png └── example_molecules_2.png ├── LICENSE ├── .gitignore ├── template_data.py ├── qm9_invalid.txt ├── qm9_data.py ├── nn_classes.py ├── template_preprocess_dataset.py ├── template_filter_generated.py ├── display_molecules.py ├── qm9_preprocess_dataset.py ├── README.md ├── gschnet_script.py └── utility_classes.py /images/example_molecules_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/G-SchNet/HEAD/images/example_molecules_1.png -------------------------------------------------------------------------------- /images/example_molecules_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/G-SchNet/HEAD/images/example_molecules_2.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Niklas Gebauer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea 6 | *.DS_Store 7 | 8 | # test data 9 | src/sacred_scripts/data/* 10 | src/sacred_scripts/models/* 11 | src/sacred_scripts/experiments/* 12 | src/scripts/data/ 13 | src/scripts/training/ 14 | 15 | docs/tutorials/*.db 16 | docs/tutorials/*.xyz 17 | docs/tutorials/qm9tut 18 | 19 | # C extensions 20 | *.so 21 | 22 | 23 | # Distribution / packaging 24 | .Python 25 | env/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # dotenv 98 | .env 99 | 100 | # virtualenv 101 | .venv 102 | venv/ 103 | ENV/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | -------------------------------------------------------------------------------- /template_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | import numpy as np 4 | import torch 5 | from ase.db import connect 6 | 7 | from schnetpack import Properties 8 | from schnetpack.datasets import AtomsData 9 | from utility_classes import ConnectivityCompressor 10 | from template_preprocess_dataset import preprocess_dataset 11 | 12 | 13 | class TemplateData(AtomsData): 14 | """ Simple template dataset class. We assume molecules made of C, N, O, F, 15 | and H atoms as illustration here. 16 | 17 | The class basically serves as interface to a database. It initiates 18 | pre-processing of the data in order to prepare it for usage with G-SchNet. 19 | To this end, it calls the template_preprocess_dataset script which provides 20 | very basic pre-processing (e.g. calculation of connectivity matrices) and can 21 | also be adapted to the data at hand. 22 | Single (pre-processed) data points are read from the database in the 23 | get_properties method (which is called in __getitem__). The class builds upon 24 | the AtomsData class from SchNetPack. 25 | 26 | Args: 27 | path (str): path to directory containing database 28 | subset (list, optional): indices of subset, set to None for entire dataset 29 | (default: None). 30 | precompute_distances (bool, optional): if True and the pre-processed 31 | database does not yet exist, the pairwise distances of atoms in the 32 | dataset's molecules will be computed during pre-processing and stored in 33 | the database (increases storage demand of the dataset but decreases 34 | computational cost during training as otherwise the distances will be 35 | computed once in every epoch, default: True) 36 | remove_invalid (bool, optional): if True, molecules that do not pass the 37 | implemented validity checks will be removed from the training data ( 38 | in the simple template_preprocess_dataset script this is only a check 39 | for disconnectedness, i.e. 
if all atoms are connected by some path as 40 | otherwise no proper generation trace can be sampled, 41 | note: only works if the pre-processed database does not yet exist, 42 | default: True) 43 | """ 44 | 45 | ##### Adjust the following settings to fit your data: ##### 46 | # name of the database 47 | db_name = 'template_data.db' 48 | # name of the database after pre-processing (if the same as db_name, the original 49 | # database will be renamed to .bak.db) 50 | preprocessed_db_name = 'template_data_gschnet.db' 51 | # all atom types found in molecules of the dataset 52 | available_atom_types = [1, 6, 7, 8, 9] # for example H, C, N, O, and F 53 | # valence constraints of the atom types (does not need to be provided unless a 54 | # valence check is implemented, but this is not the case in the template script) 55 | atom_types_valence = [1, 4, 3, 2, 1] 56 | # minimum and maximum distance between neighboring atoms in angstrom (this is 57 | # used to determine which atoms are considered as connected in the connectivity 58 | # matrix, i.e. for sampling generation traces during training, and also to restrict 59 | # the grid around the focused atom during generation, as the next atom will always 60 | # be a neighbor of the focused atom) 61 | radial_limits = [0.9, 1.7] 62 | 63 | # used to decompress connectivity matrices 64 | connectivity_compressor = ConnectivityCompressor() 65 | 66 | def __init__(self, path, subset=None, precompute_distances=True, 67 | remove_invalid=True): 68 | self.path_to_dir = Path(path) 69 | self.db_path = self.path_to_dir / self.preprocessed_db_name 70 | self.source_db_path = self.path_to_dir / self.db_name 71 | self.precompute_distances = precompute_distances 72 | self.remove_invalid = remove_invalid 73 | 74 | # do pre-processing (if database is not already pre-processed) 75 | found_connectivity = False 76 | if self.db_path.is_file(): 77 | with connect(self.db_path) as conn: 78 | n_mols = conn.count() 79 | if n_mols > 0: 80 | first_row = conn.get(1) 81 | found_connectivity = 'con_mat' in first_row.data 82 | if not found_connectivity: 83 | self._preprocess_data() 84 | 85 | super().__init__(str(self.db_path), subset=subset) 86 | 87 | def create_subset(self, idx): 88 | """ 89 | Returns a new dataset that only consists of provided indices. 
90 | 91 | Args: 92 | idx (numpy.ndarray): subset indices 93 | 94 | Returns: 95 | schnetpack.data.AtomsData: dataset with subset of original data 96 | """ 97 | idx = np.array(idx) 98 | subidx = idx if self.subset is None or len(idx) == 0 \ 99 | else np.array(self.subset)[idx] 100 | return type(self)(self.path_to_dir, subidx) 101 | 102 | def get_properties(self, idx): 103 | _idx = self._subset_index(idx) 104 | with connect(self.db_path) as conn: 105 | row = conn.get(_idx + 1) 106 | at = row.toatoms() 107 | 108 | # extract/calculate structure (atom positions, types and cell) 109 | properties = {} 110 | properties[Properties.Z] = torch.LongTensor(at.numbers.astype(np.int)) 111 | positions = at.positions.astype(np.float32) 112 | positions -= at.get_center_of_mass() # center positions 113 | properties[Properties.R] = torch.FloatTensor(positions) 114 | properties[Properties.cell] = torch.FloatTensor(at.cell.astype(np.float32)) 115 | 116 | # recover connectivity matrix from compressed format 117 | con_mat = self.connectivity_compressor.decompress(row.data['con_mat']) 118 | # save in dictionary 119 | properties['_con_mat'] = torch.FloatTensor(con_mat.astype(np.float32)) 120 | 121 | # extract pre-computed distances (if they exist) 122 | if 'dists' in row.data: 123 | properties['dists'] = row.data['dists'][:, None] 124 | 125 | # get atom environment 126 | nbh_idx, offsets = self.environment_provider.get_environment(at) 127 | # store neighbors, cell, and index 128 | properties[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(np.int)) 129 | properties[Properties.cell_offset] = torch.FloatTensor( 130 | offsets.astype(np.float32)) 131 | properties["_idx"] = torch.LongTensor(np.array([idx], dtype=np.int)) 132 | 133 | return at, properties 134 | 135 | def _preprocess_data(self): 136 | # check if pre-processing source db has different name than target db (if 137 | # not, rename it) 138 | source_db = self.path_to_dir / self.db_name 139 | if self.db_name == self.preprocessed_db_name: 140 | new_name = self.path_to_dir / (self.db_name + '.bak.db') 141 | source_db.rename(new_name) 142 | source_db = new_name 143 | # look for pre-computed list of invalid molecules 144 | invalid_list_path = self.source_db_path.parent / \ 145 | (self.source_db_path.stem + f'_invalid.txt') 146 | if invalid_list_path.is_file(): 147 | invalid_list = np.loadtxt(invalid_list_path) 148 | else: 149 | invalid_list = None 150 | # initialize pre-processing (calculation and validation of connectivity 151 | # matrices as well as computation of pairwise distances between atoms) 152 | valence_list = \ 153 | np.array([self.available_atom_types, self.atom_types_valence]).flatten('F') 154 | preprocess_dataset(datapath=source_db, 155 | cutoff=self.radial_limits[-1], 156 | valence_list=list(valence_list), 157 | logging_print=True, 158 | new_db_path=self.db_path, 159 | precompute_distances=self.precompute_distances, 160 | remove_invalid=self.remove_invalid, 161 | invalid_list=invalid_list) 162 | return True 163 | -------------------------------------------------------------------------------- /qm9_invalid.txt: -------------------------------------------------------------------------------- 1 | 270 2 | 281 3 | 1116 4 | 1641 5 | 1647 6 | 1669 7 | 3837 8 | 3894 9 | 3898 10 | 4009 11 | 4014 12 | 4806 13 | 4870 14 | 4871 15 | 5033 16 | 5556 17 | 5598 18 | 5803 19 | 6032 20 | 6058 21 | 6074 22 | 6169 23 | 6329 24 | 6345 25 | 6682 26 | 6725 27 | 6732 28 | 7325 29 | 7334 30 | 7374 31 | 7375 32 | 7409 33 | 7416 34 | 7418 35 | 7428 36 | 7601 37 | 7653 38 
| 7656 39 | 8397 40 | 9985 41 | 10015 42 | 10275 43 | 10294 44 | 10348 45 | 10784 46 | 11387 47 | 13480 48 | 13804 49 | 13832 50 | 14120 51 | 14317 52 | 15511 53 | 17807 54 | 17808 55 | 18307 56 | 20389 57 | 20397 58 | 20481 59 | 20489 60 | 20641 61 | 20678 62 | 20682 63 | 20691 64 | 21118 65 | 21126 66 | 21164 67 | 21363 68 | 21463 69 | 21520 70 | 21533 71 | 21534 72 | 21554 73 | 21566 74 | 21610 75 | 21619 76 | 21702 77 | 21705 78 | 21716 79 | 21717 80 | 21724 81 | 21740 82 | 21747 83 | 21757 84 | 21763 85 | 21837 86 | 21856 87 | 21858 88 | 21869 89 | 21874 90 | 21967 91 | 21968 92 | 21969 93 | 21970 94 | 21971 95 | 21972 96 | 21973 97 | 21974 98 | 21975 99 | 21976 100 | 21977 101 | 21978 102 | 21979 103 | 21980 104 | 21981 105 | 21982 106 | 21983 107 | 21984 108 | 21985 109 | 21986 110 | 21987 111 | 22008 112 | 22053 113 | 22089 114 | 22092 115 | 22096 116 | 22110 117 | 22115 118 | 22194 119 | 22202 120 | 22231 121 | 22237 122 | 22248 123 | 22264 124 | 22451 125 | 22459 126 | 22464 127 | 22468 128 | 22470 129 | 22471 130 | 22481 131 | 22494 132 | 22497 133 | 22498 134 | 22503 135 | 22542 136 | 22680 137 | 22685 138 | 22971 139 | 22973 140 | 22979 141 | 23148 142 | 23149 143 | 23371 144 | 23792 145 | 23798 146 | 23818 147 | 23821 148 | 23827 149 | 25530 150 | 25764 151 | 25811 152 | 25828 153 | 25829 154 | 25859 155 | 25868 156 | 25892 157 | 25914 158 | 26121 159 | 26152 160 | 26153 161 | 26186 162 | 26228 163 | 26229 164 | 26538 165 | 27271 166 | 27293 167 | 27322 168 | 27335 169 | 27388 170 | 27860 171 | 28082 172 | 28250 173 | 28383 174 | 28401 175 | 29149 176 | 29167 177 | 29539 178 | 29557 179 | 29563 180 | 30525 181 | 30526 182 | 30528 183 | 30529 184 | 30537 185 | 30539 186 | 30545 187 | 30546 188 | 30548 189 | 30550 190 | 30551 191 | 30705 192 | 30712 193 | 30760 194 | 30761 195 | 30762 196 | 30786 197 | 30787 198 | 30797 199 | 30901 200 | 30902 201 | 30903 202 | 30993 203 | 30994 204 | 30995 205 | 30999 206 | 31012 207 | 31106 208 | 31108 209 | 31109 210 | 31110 211 | 31111 212 | 31170 213 | 31502 214 | 31598 215 | 32413 216 | 32464 217 | 32759 218 | 32813 219 | 32865 220 | 32884 221 | 32941 222 | 32942 223 | 33399 224 | 36994 225 | 36995 226 | 37991 227 | 38082 228 | 42423 229 | 43212 230 | 43242 231 | 43474 232 | 43519 233 | 45540 234 | 45544 235 | 45545 236 | 45926 237 | 46610 238 | 49722 239 | 50308 240 | 50449 241 | 50619 242 | 50735 243 | 51245 244 | 51246 245 | 52007 246 | 53819 247 | 53820 248 | 53844 249 | 53891 250 | 53895 251 | 53938 252 | 53940 253 | 53942 254 | 53943 255 | 53953 256 | 54077 257 | 54078 258 | 54101 259 | 54118 260 | 54123 261 | 54125 262 | 54228 263 | 54243 264 | 54295 265 | 54383 266 | 54386 267 | 54399 268 | 54408 269 | 54409 270 | 54411 271 | 54421 272 | 54447 273 | 54448 274 | 54486 275 | 54537 276 | 54568 277 | 54581 278 | 54610 279 | 54614 280 | 54617 281 | 54618 282 | 54623 283 | 54628 284 | 54656 285 | 54690 286 | 54691 287 | 54765 288 | 54793 289 | 54794 290 | 54795 291 | 54810 292 | 54873 293 | 54895 294 | 54899 295 | 54903 296 | 54993 297 | 55144 298 | 55145 299 | 55186 300 | 55189 301 | 55266 302 | 55399 303 | 55407 304 | 55409 305 | 55437 306 | 55449 307 | 55475 308 | 55476 309 | 55478 310 | 55483 311 | 55498 312 | 55517 313 | 55557 314 | 55609 315 | 55610 316 | 55618 317 | 55619 318 | 55620 319 | 55700 320 | 55702 321 | 55790 322 | 55909 323 | 55943 324 | 56015 325 | 56054 326 | 56071 327 | 56240 328 | 56342 329 | 56343 330 | 57735 331 | 57736 332 | 57944 333 | 58280 334 | 58612 335 | 58613 336 | 58981 337 | 59826 338 | 59848 339 | 
59965 340 | 59976 341 | 60659 342 | 60717 343 | 60779 344 | 61434 345 | 61439 346 | 61450 347 | 62028 348 | 62083 349 | 66510 350 | 66602 351 | 66603 352 | 71535 353 | 72316 354 | 72318 355 | 74136 356 | 74175 357 | 74199 358 | 74201 359 | 74240 360 | 74241 361 | 74242 362 | 74312 363 | 75052 364 | 75169 365 | 76134 366 | 76135 367 | 76142 368 | 76371 369 | 76372 370 | 76379 371 | 76393 372 | 76394 373 | 76396 374 | 77141 375 | 77459 376 | 80207 377 | 80594 378 | 80596 379 | 81048 380 | 81053 381 | 81056 382 | 81566 383 | 81567 384 | 81572 385 | 81577 386 | 81578 387 | 81579 388 | 81580 389 | 82081 390 | 83400 391 | 83410 392 | 83413 393 | 83414 394 | 83416 395 | 84309 396 | 84799 397 | 85156 398 | 85354 399 | 85487 400 | 85779 401 | 85951 402 | 85961 403 | 86562 404 | 86587 405 | 86635 406 | 86738 407 | 86741 408 | 87034 409 | 87036 410 | 89621 411 | 89625 412 | 89627 413 | 90286 414 | 90692 415 | 90693 416 | 90695 417 | 91152 418 | 91257 419 | 91258 420 | 91518 421 | 92759 422 | 93323 423 | 93346 424 | 93566 425 | 93571 426 | 93940 427 | 93941 428 | 93985 429 | 93987 430 | 93996 431 | 94181 432 | 94603 433 | 94605 434 | 95437 435 | 96611 436 | 96612 437 | 96636 438 | 96637 439 | 96639 440 | 96678 441 | 97115 442 | 97259 443 | 97324 444 | 97357 445 | 97362 446 | 97454 447 | 97457 448 | 97475 449 | 97528 450 | 97529 451 | 98010 452 | 98232 453 | 98233 454 | 98234 455 | 99224 456 | 99716 457 | 99725 458 | 99727 459 | 99730 460 | 99732 461 | 99744 462 | 99808 463 | 100075 464 | 100091 465 | 100442 466 | 100456 467 | 100514 468 | 100518 469 | 100625 470 | 100626 471 | 100709 472 | 100733 473 | 101806 474 | 101940 475 | 102014 476 | 102130 477 | 102224 478 | 102627 479 | 102633 480 | 102793 481 | 102795 482 | 102796 483 | 103797 484 | 103798 485 | 103812 486 | 103820 487 | 104600 488 | 104601 489 | 105193 490 | 105210 491 | 105214 492 | 105578 493 | 108409 494 | 108890 495 | 110173 496 | 112229 497 | 112337 498 | 112354 499 | 112496 500 | 112945 501 | 112946 502 | 112954 503 | 112989 504 | 113156 505 | 113160 506 | 113173 507 | 113174 508 | 113175 509 | 113183 510 | 115697 511 | 115698 512 | 116536 513 | 116638 514 | 116798 515 | 116943 516 | 117294 517 | 117522 518 | 117629 519 | 117642 520 | 118440 521 | 118447 522 | 119757 523 | 120430 524 | 120722 525 | 121012 526 | 121588 527 | 121595 528 | 121599 529 | 121610 530 | 121612 531 | 121779 532 | 121863 533 | 121881 534 | 122766 535 | 123125 536 | 123128 537 | 123544 538 | 123567 539 | 123588 540 | 123592 541 | 123615 542 | 123619 543 | 123629 544 | 123641 545 | 123654 546 | 123673 547 | 123685 548 | 123698 549 | 123901 550 | 123907 551 | 123964 552 | 123997 553 | 124017 554 | 124033 555 | 124121 556 | 124204 557 | 124221 558 | 124249 559 | 124709 560 | 124711 561 | 124713 562 | 124721 563 | 124722 564 | 124723 565 | 124730 566 | 124731 567 | 124736 568 | 124934 569 | 125054 570 | 125099 571 | 125275 572 | 125283 573 | 125360 574 | 125388 575 | 125470 576 | 125618 577 | 125629 578 | 125758 579 | 125792 580 | 125904 581 | 125916 582 | 126007 583 | 126024 584 | 126080 585 | 126088 586 | 126092 587 | 126291 588 | 126346 589 | 126350 590 | 126359 591 | 126864 592 | 126872 593 | 127082 594 | 127323 595 | 127355 596 | 127394 597 | 127406 598 | 127542 599 | 127605 600 | 127633 601 | 127777 602 | 127838 603 | 127892 604 | 127893 605 | 127894 606 | 128141 607 | 128142 608 | 128146 609 | 128170 610 | 128182 611 | 128194 612 | 128251 613 | 128259 614 | 128391 615 | 128393 616 | 128396 617 | 128406 618 | 128417 619 | 128421 620 | 128498 621 | 128527 622 | 
128528 623 | 128557 624 | 128567 625 | 128618 626 | 128626 627 | 128932 628 | 128947 629 | 129099 630 | 129105 631 | 129113 632 | 129135 633 | 129136 634 | 129144 635 | 129145 636 | 129146 637 | 129148 638 | 129149 639 | 129150 640 | 129155 641 | 129156 642 | 129158 643 | 129169 644 | 129174 645 | 129176 646 | 129181 647 | 129242 648 | 129249 649 | 129316 650 | 129335 651 | 129336 652 | 129339 653 | 129392 654 | 129400 655 | 129405 656 | 129409 657 | 129410 658 | 129411 659 | 129577 660 | 129578 661 | 129580 662 | 129653 663 | 129735 664 | 129859 665 | 129867 666 | 129914 667 | 129939 668 | 129993 669 | 129995 670 | 129996 671 | 129998 672 | 130006 673 | 130008 674 | 130035 675 | 130037 676 | 130120 677 | 130181 678 | 130296 679 | 130335 680 | 130336 681 | 130337 682 | 130338 683 | 130344 684 | 130345 685 | 130354 686 | 130355 687 | 130356 688 | 130357 689 | 130365 690 | 130369 691 | 130373 692 | 130376 693 | 130381 694 | 130382 695 | 130384 696 | 130385 697 | 130386 698 | 130387 699 | 130392 700 | 130403 701 | 130405 702 | 130406 703 | 130415 704 | 130423 705 | 130434 706 | 130437 707 | 130439 708 | 130440 709 | 130449 710 | 130452 711 | 130453 712 | 130462 713 | 130466 714 | 130469 715 | 130475 716 | 130479 717 | 130530 718 | 130536 719 | 130537 720 | 130582 721 | 130583 722 | 130587 723 | 130591 724 | 130602 725 | 130619 726 | 130629 727 | 130634 728 | 130661 729 | 130663 730 | 130664 731 | 130665 732 | 130666 733 | 130668 734 | 130669 735 | 130679 736 | 130683 737 | 130685 738 | 130691 739 | 130740 740 | 130746 741 | 130793 742 | 130860 743 | 130878 744 | 130882 745 | 130918 746 | 131091 747 | 131164 748 | 131199 749 | 131224 750 | 131513 751 | 131541 752 | 131554 753 | 131658 754 | 131693 755 | 131695 756 | 131704 757 | 131881 758 | 131882 759 | 131883 760 | 131884 761 | 131885 762 | 131886 763 | 131887 764 | 131888 765 | 131889 766 | 131890 767 | 131891 768 | 131892 769 | 131893 770 | 131894 771 | 131895 772 | 131896 773 | 131897 774 | 131898 775 | 131899 776 | 131900 777 | 131901 778 | 131902 779 | 131903 780 | 131904 781 | 131905 782 | 131906 783 | 131907 784 | 131908 785 | 131909 786 | 131910 787 | 131911 788 | 131912 789 | 131913 790 | 131914 791 | 131915 792 | 131916 793 | 131917 794 | 131918 795 | 131919 796 | 131920 797 | 131921 798 | 131922 799 | 131923 800 | 131924 801 | 131925 802 | 131926 803 | 131927 804 | 131928 805 | 131929 806 | 131930 807 | 131931 808 | 131932 809 | 131933 810 | 131934 811 | 131935 812 | 131936 813 | 131937 814 | 131938 815 | 131939 816 | 131940 817 | 131941 818 | 131942 819 | 131943 820 | 131944 821 | 131945 822 | 131946 823 | 131947 824 | 131948 825 | 131949 826 | 131950 827 | 131951 828 | 131952 829 | 131953 830 | 131954 831 | 131955 832 | 131956 833 | 131957 834 | 131958 835 | 131959 836 | 131960 837 | 131961 838 | 131962 839 | 131963 840 | 131964 841 | 131965 842 | 131966 843 | 131967 844 | 131968 845 | 131969 846 | 131970 847 | 131971 848 | 131972 849 | 131973 850 | 131974 851 | 131975 852 | 131976 853 | 131977 854 | 131978 855 | 131979 856 | 131980 857 | 131981 858 | 131982 859 | 131983 860 | 131984 861 | 131985 862 | 131986 863 | 131987 864 | 131988 865 | 131989 866 | 131990 867 | 131991 868 | 131992 869 | 131993 870 | 131994 871 | 131995 872 | 131996 873 | 131997 874 | 131998 875 | 131999 876 | 132000 877 | 132001 878 | 132071 879 | 132883 880 | 133142 881 | 133166 882 | 133167 883 | 133262 884 | 133273 885 | 133310 886 | 133336 887 | 133337 888 | 133395 889 | 133396 890 | 133402 891 | 133403 892 | 133812 893 | 133815 894 | 133819 895 | 
133821 896 | 133825 897 | 133827 898 | 133828 899 | 133831 900 | 133832 901 | 133833 902 | 133839 903 | 133842 904 | 133843 905 | 133844 906 | 133845 907 | 133846 908 | 133848 909 | 133850 910 | 133851 911 | 133853 912 | 133857 913 | 133863 914 | 133864 915 | 133865 916 | -------------------------------------------------------------------------------- /qm9_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import shutil 5 | import tarfile 6 | import tempfile 7 | from urllib import request as request 8 | from urllib.error import HTTPError, URLError 9 | from base64 import b64encode, b64decode 10 | 11 | import numpy as np 12 | import torch 13 | from ase.db import connect 14 | from ase.io.extxyz import read_xyz 15 | from ase.units import Debye, Bohr, Hartree, eV 16 | 17 | from schnetpack import Properties 18 | from schnetpack.datasets import DownloadableAtomsData 19 | from utility_classes import ConnectivityCompressor 20 | from qm9_preprocess_dataset import preprocess_dataset 21 | 22 | 23 | class QM9gen(DownloadableAtomsData): 24 | """ QM9 benchmark dataset for organic molecules with up to nine non-hydrogen atoms 25 | from {C, O, N, F}. 26 | 27 | This class adds convenience functions to download QM9 from figshare, 28 | pre-process the data such that it can be used for molecule generation with the 29 | G-SchNet model, and load the data into pytorch. 30 | 31 | Args: 32 | path (str): path to directory containing qm9 database 33 | subset (list, optional): indices of subset, set to None for entire dataset 34 | (default: None). 35 | download (bool, optional): enable downloading if qm9 database does not 36 | exist (default: True) 37 | precompute_distances (bool, optional): if True and the pre-processed 38 | database does not yet exist, the pairwise distances of atoms in the 39 | dataset's molecules will be computed during pre-processing and stored in 40 | the database (increases storage demand of the dataset but decreases 41 | computational cost during training as otherwise the distances will be 42 | computed once in every epoch, default: True) 43 | remove_invalid (bool, optional): if True, QM9 molecules that do not pass the 44 | valence check will be removed from the training data (note 1: the 45 | validity is by default inferred from a pre-computed list in our 46 | repository but will be assessed locally if the download fails, 47 | note 2: only works if the pre-processed database does not yet exist, 48 | default: True) 49 | 50 | References: 51 | .. 
[#qm9_1] https://ndownloader.figshare.com/files/3195404 52 | """ 53 | 54 | # general settings for the dataset 55 | available_atom_types = [1, 6, 7, 8, 9] # all atom types found in the dataset 56 | atom_types_valence = [1, 4, 3, 2, 1] # valence constraints of the atom types 57 | radial_limits = [0.9, 1.7] # minimum and maximum distance between neighboring atoms 58 | 59 | # properties 60 | A = 'rotational_constant_A' 61 | B = 'rotational_constant_B' 62 | C = 'rotational_constant_C' 63 | mu = 'dipole_moment' 64 | alpha = 'isotropic_polarizability' 65 | homo = 'homo' 66 | lumo = 'lumo' 67 | gap = 'gap' 68 | r2 = 'electronic_spatial_extent' 69 | zpve = 'zpve' 70 | U0 = 'energy_U0' 71 | U = 'energy_U' 72 | H = 'enthalpy_H' 73 | G = 'free_energy' 74 | Cv = 'heat_capacity' 75 | 76 | properties = [ 77 | A, B, C, mu, alpha, 78 | homo, lumo, gap, r2, zpve, 79 | U0, U, H, G, Cv 80 | ] 81 | 82 | units = [1., 1., 1., Debye, Bohr ** 3, 83 | Hartree, Hartree, Hartree, 84 | Bohr ** 2, Hartree, 85 | Hartree, Hartree, Hartree, 86 | Hartree, 1., 87 | ] 88 | 89 | units_dict = dict(zip(properties, units)) 90 | 91 | connectivity_compressor = ConnectivityCompressor() 92 | 93 | def __init__(self, path, subset=None, download=True, precompute_distances=True, 94 | remove_invalid=True): 95 | self.path = path 96 | self.dbpath = os.path.join(self.path, f'qm9gen.db') 97 | self.precompute_distances = precompute_distances 98 | self.remove_invalid = remove_invalid 99 | 100 | super().__init__(self.dbpath, subset=subset, 101 | available_properties=self.properties, 102 | units=self.units, download=download) 103 | 104 | def create_subset(self, idx): 105 | """ 106 | Returns a new dataset that only consists of provided indices. 107 | 108 | Args: 109 | idx (numpy.ndarray): subset indices 110 | 111 | Returns: 112 | schnetpack.data.AtomsData: dataset with subset of original data 113 | """ 114 | idx = np.array(idx) 115 | subidx = idx if self.subset is None or len(idx) == 0 \ 116 | else np.array(self.subset)[idx] 117 | return type(self)(self.path, subidx, download=False) 118 | 119 | def get_properties(self, idx): 120 | _idx = self._subset_index(idx) 121 | with connect(self.dbpath) as conn: 122 | row = conn.get(_idx + 1) 123 | at = row.toatoms() 124 | 125 | # extract/calculate structure 126 | properties = {} 127 | properties[Properties.Z] = torch.LongTensor(at.numbers.astype(np.int)) 128 | positions = at.positions.astype(np.float32) 129 | positions -= at.get_center_of_mass() # center positions 130 | properties[Properties.R] = torch.FloatTensor(positions) 131 | properties[Properties.cell] = torch.FloatTensor(at.cell.astype(np.float32)) 132 | 133 | # recover connectivity matrix from compressed format 134 | con_mat = self.connectivity_compressor.decompress(row.data['con_mat']) 135 | # save in dictionary 136 | properties['_con_mat'] = torch.FloatTensor(con_mat.astype(np.float32)) 137 | 138 | # extract pre-computed distances (if they exist) 139 | if 'dists' in row.data: 140 | properties['dists'] = row.data['dists'] 141 | 142 | # get atom environment 143 | nbh_idx, offsets = self.environment_provider.get_environment(at) 144 | # store neighbors, cell, and index 145 | properties[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(np.int)) 146 | properties[Properties.cell_offset] = torch.FloatTensor( 147 | offsets.astype(np.float32)) 148 | properties["_idx"] = torch.LongTensor(np.array([idx], dtype=np.int)) 149 | 150 | return at, properties 151 | 152 | def _download(self): 153 | works = True 154 | if not os.path.exists(self.dbpath): 155 | 
qm9_path = os.path.join(self.path, f'qm9.db') 156 | if not os.path.exists(qm9_path): 157 | works = works and self._load_data() 158 | works = works and self._preprocess_qm9() 159 | return works 160 | 161 | def _load_data(self): 162 | logging.info('Downloading GDB-9 data...') 163 | tmpdir = tempfile.mkdtemp('gdb9') 164 | tar_path = os.path.join(tmpdir, 'gdb9.tar.gz') 165 | raw_path = os.path.join(tmpdir, 'gdb9_xyz') 166 | url = 'https://ndownloader.figshare.com/files/3195389' 167 | 168 | try: 169 | request.urlretrieve(url, tar_path) 170 | logging.info('Done.') 171 | except HTTPError as e: 172 | logging.error('HTTP Error:', e.code, url) 173 | return False 174 | except URLError as e: 175 | logging.error('URL Error:', e.reason, url) 176 | return False 177 | 178 | logging.info('Extracting data from tar file...') 179 | tar = tarfile.open(tar_path) 180 | tar.extractall(raw_path) 181 | tar.close() 182 | logging.info('Done.') 183 | 184 | logging.info('Parsing xyz files...') 185 | with connect(os.path.join(self.path, 'qm9.db')) as con: 186 | ordered_files = sorted(os.listdir(raw_path), 187 | key=lambda x: (int(re.sub('\D', '', x)), x)) 188 | for i, xyzfile in enumerate(ordered_files): 189 | xyzfile = os.path.join(raw_path, xyzfile) 190 | 191 | if (i + 1) % 10000 == 0: 192 | logging.info('Parsed: {:6d} / 133885'.format(i + 1)) 193 | properties = {} 194 | tmp = os.path.join(tmpdir, 'tmp.xyz') 195 | 196 | with open(xyzfile, 'r') as f: 197 | lines = f.readlines() 198 | l = lines[1].split()[2:] 199 | for pn, p in zip(self.properties, l): 200 | properties[pn] = float(p) * self.units[pn] 201 | with open(tmp, "wt") as fout: 202 | for line in lines: 203 | fout.write(line.replace('*^', 'e')) 204 | 205 | with open(tmp, 'r') as f: 206 | ats = list(read_xyz(f, 0))[0] 207 | con.write(ats, data=properties) 208 | logging.info('Done.') 209 | 210 | shutil.rmtree(tmpdir) 211 | 212 | return True 213 | 214 | def _preprocess_qm9(self): 215 | # try to download pre-computed list of invalid molecules 216 | logging.info('Downloading pre-computed list of invalid QM9 molecules...') 217 | raw_path = os.path.join(self.path, 'qm9_invalid.txt') 218 | url = 'https://github.com/atomistic-machine-learning/G-SchNet/blob/master/' \ 219 | 'qm9_invalid.txt?raw=true' 220 | 221 | try: 222 | request.urlretrieve(url, raw_path) 223 | logging.info('Done.') 224 | invalid_list = np.loadtxt(raw_path) 225 | except HTTPError as e: 226 | logging.error('HTTP Error:', e.code, url) 227 | logging.info('CAUTION: Could not download pre-computed list, will assess ' 228 | 'validity during pre-processing.') 229 | invalid_list = None 230 | except URLError as e: 231 | logging.error('URL Error:', e.reason, url) 232 | logging.info('CAUTION: Could not download pre-computed list, will assess ' 233 | 'validity during pre-processing.') 234 | invalid_list = None 235 | # check validity of molecules and store connectivity matrices and inter-atomic 236 | # distances in database as a pre-processing step 237 | qm9_db = os.path.join(self.path, f'qm9.db') 238 | valence_list = \ 239 | np.array([self.available_atom_types, self.atom_types_valence]).flatten('F') 240 | preprocess_dataset(datapath=qm9_db, valence_list=valence_list, 241 | n_threads=8, n_mols_per_thread=125, logging_print=True, 242 | new_db_path=self.dbpath, 243 | precompute_distances=self.precompute_distances, 244 | remove_invalid=self.remove_invalid, 245 | invalid_list=invalid_list) 246 | return True 247 | 248 | def get_available_properties(self, available_properties): 249 | # we don't use properties other than 
stored connectivity matrices (and 250 | # distances, if they were precomputed) so we skip this part 251 | return available_properties 252 | -------------------------------------------------------------------------------- /nn_classes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from collections import Iterable 6 | 7 | import schnetpack as spk 8 | from schnetpack.nn import MLP 9 | from schnetpack.metrics import Metric 10 | 11 | 12 | ### OUTPUT MODULE ### 13 | class AtomwiseWithProcessing(nn.Module): 14 | r""" 15 | Atom-wise dense layers that allow to use additional pre- and post-processing layers. 16 | 17 | Args: 18 | n_in (int): input dimension of representation (default: 128) 19 | n_out (int): output dimension (default: 1) 20 | n_layers (int): number of atom-wise dense layers in output network (default: 5) 21 | n_neurons (list of int or None): number of neurons in each layer of the output 22 | network. If `None`, interpolate linearly between n_in and n_out. 23 | activation (function): activation function for hidden layers 24 | (default: spk.nn.activations.shifted_softplus). 25 | preprocess_layers (nn.Module): a torch.nn.Module or list of Modules for 26 | preprocessing the representation given by the first part of the network 27 | (default: None). 28 | postprocess_layers (nn.Module): a torch.nn.Module or list of Modules for 29 | postprocessing the output given by the second part of the network 30 | (default: None). 31 | in_key (str): keyword to access the representation in the inputs dictionary, 32 | it is automatically inferred from the preprocessing layers, if at least one 33 | is given (default: 'representation'). 34 | out_key (str): a string as key to the output dictionary (if set to 'None', the 35 | output will not be wrapped into a dictionary, default: 'y') 36 | 37 | Returns: 38 | result: dictionary with predictions stored in result[out_key] 39 | """ 40 | 41 | def __init__(self, n_in=128, n_out=1, n_layers=5, n_neurons=None, 42 | activation=spk.nn.activations.shifted_softplus, 43 | preprocess_layers=None, postprocess_layers=None, 44 | in_key='representation', out_key='y'): 45 | 46 | super(AtomwiseWithProcessing, self).__init__() 47 | 48 | self.n_in = n_in 49 | self.n_out = n_out 50 | self.n_layers = n_layers 51 | self.in_key = in_key 52 | self.out_key = out_key 53 | 54 | if isinstance(preprocess_layers, Iterable): 55 | self.preprocess_layers = nn.ModuleList(preprocess_layers) 56 | self.in_key = self.preprocess_layers[-1].out_key 57 | elif preprocess_layers is not None: 58 | self.preprocess_layers = preprocess_layers 59 | self.in_key = self.preprocess_layers.out_key 60 | else: 61 | self.preprocess_layers = None 62 | 63 | if isinstance(postprocess_layers, Iterable): 64 | self.postprocess_layers = nn.ModuleList(postprocess_layers) 65 | else: 66 | self.postprocess_layers = postprocess_layers 67 | 68 | if n_neurons is None: 69 | # linearly interpolate between n_in and n_out 70 | n_neurons = list(np.linspace(n_in, n_out, n_layers + 1).astype(int)[1:-1]) 71 | self.out_net = MLP(n_in, n_out, n_neurons, n_layers, activation) 72 | 73 | self.derivative = None # don't compute derivative w.r.t. inputs 74 | 75 | def forward(self, inputs): 76 | """ 77 | Compute layer output and apply pre-/postprocessing if specified. 78 | 79 | Args: 80 | inputs (dict of torch.Tensor): batch of input values. 81 | Returns: 82 | torch.Tensor: layer output. 
83 | """ 84 | # apply pre-processing layers 85 | if self.preprocess_layers is not None: 86 | if isinstance(self.preprocess_layers, Iterable): 87 | for pre_layer in self.preprocess_layers: 88 | inputs = pre_layer(inputs) 89 | else: 90 | inputs = self.preprocess_layers(inputs) 91 | 92 | # get (pre-processed) representation 93 | if isinstance(inputs[self.in_key], tuple): 94 | repr = inputs[self.in_key][0] 95 | else: 96 | repr = inputs[self.in_key] 97 | 98 | # apply output network 99 | result = self.out_net(repr) 100 | 101 | # apply post-processing layers 102 | if self.postprocess_layers is not None: 103 | if isinstance(self.postprocess_layers, Iterable): 104 | for post_layer in self.postprocess_layers: 105 | result = post_layer(inputs, result) 106 | else: 107 | result = self.postprocess_layers(inputs, result) 108 | 109 | # use provided key to store result 110 | if self.out_key is not None: 111 | result = {self.out_key: result} 112 | 113 | return result 114 | 115 | 116 | ### METRICS ### 117 | class KLDivergence(Metric): 118 | r""" 119 | Metric for mean KL-Divergence. 120 | 121 | Args: 122 | target (str): name of target property 123 | model_output ([int], [str]): indices or keys to unpack the desired output 124 | from the model in case of multiple outputs, e.g. ['x', 'y'] to get 125 | output['x']['y'] (default: 'y'). 126 | name (str): name used in logging for this metric. If set to `None`, 127 | `KLD_[target]` will be used (default: None). 128 | mask (str): key for a mask in the examined batch which hides irrelevant output 129 | values. If 'None' is provided, no mask will be applied (default: None). 130 | inverse_mask (bool): whether the mask needs to be inverted prior to application 131 | (default: False). 132 | """ 133 | 134 | def __init__(self, target='_labels', model_output='y', name=None, 135 | mask=None, inverse_mask=False): 136 | name = 'KLD_' + target if name is None else name 137 | super(KLDivergence, self).__init__(name) 138 | self.target = target 139 | self.model_output = model_output 140 | self.loss = 0. 141 | self.n_entries = 0. 142 | self.mask_str = mask 143 | self.inverse_mask = inverse_mask 144 | 145 | def reset(self): 146 | self.loss = 0. 147 | self.n_entries = 0. 148 | 149 | def add_batch(self, batch, result): 150 | # extract true labels 151 | y = batch[self.target] 152 | 153 | # extract predictions 154 | yp = result 155 | if self.model_output is not None: 156 | if isinstance(self.model_output, list): 157 | for key in self.model_output: 158 | yp = yp[key] 159 | else: 160 | yp = yp[self.model_output] 161 | 162 | # normalize output 163 | log_yp = F.log_softmax(yp, -1) 164 | 165 | # apply KL divergence formula entry-wise 166 | loss = F.kl_div(log_yp, y, reduction='none') 167 | 168 | # sum over last dimension to get KL divergence per distribution 169 | loss = torch.sum(loss, -1) 170 | 171 | # apply mask to filter padded dimensions 172 | if self.mask_str is not None: 173 | atom_mask = batch[self.mask_str] 174 | if self.inverse_mask: 175 | atom_mask = 1.-atom_mask 176 | loss = torch.where(atom_mask > 0, loss, torch.zeros_like(loss)) 177 | n_entries = torch.sum(atom_mask > 0) 178 | else: 179 | n_entries = torch.prod(torch.tensor(loss.size())) 180 | 181 | # calculate loss and n_entries 182 | self.n_entries += n_entries.detach().cpu().data.numpy() 183 | self.loss += torch.sum(loss).detach().cpu().data.numpy() 184 | 185 | def aggregate(self): 186 | return self.loss / max(self.n_entries, 1.) 
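# --- Hedged usage sketch (editorial addition, not part of the original nn_classes.py). ---
# The small demo below only illustrates how the KLDivergence metric defined above is
# driven; the tensor shapes and the dictionary keys '_labels', 'y' and '_mask' are
# made-up placeholders for this example, not keys used elsewhere in the repository.
def _kl_divergence_metric_demo():
    batch_size, n_atoms, n_classes = 2, 5, 7
    # target distributions (normalized over the last axis) and a mask that hides
    # one padded atom per molecule
    targets = torch.softmax(torch.randn(batch_size, n_atoms, n_classes), dim=-1)
    mask = torch.ones(batch_size, n_atoms)
    mask[:, -1] = 0
    # raw (unnormalized) model outputs wrapped in a result dictionary
    result = {'y': torch.randn(batch_size, n_atoms, n_classes)}
    metric = KLDivergence(target='_labels', model_output='y', mask='_mask')
    metric.add_batch({'_labels': targets, '_mask': mask}, result)
    return metric.aggregate()  # mean KL divergence over the unmasked entries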
187 | 188 | 189 | ### PRE- AND POST-PROCESSING LAYERS ### 190 | class EmbeddingMultiplication(nn.Module): 191 | r""" 192 | Layer that multiplies embeddings of given types with the representation. 193 | 194 | Args: 195 | embedding (torch.nn.Embedding instance): the embedding layer used to embed atom 196 | types. 197 | in_key_types (str): the keyword to obtain types for embedding from inputs. 198 | in_key_representation (str): the keyword to obtain the representation from 199 | inputs. 200 | out_key (str): the keyword used to store the calculated product in the inputs 201 | dictionary. 202 | """ 203 | 204 | def __init__(self, embedding, in_key_types='_next_types', 205 | in_key_representation='representation', 206 | out_key='preprocessed_representation'): 207 | super(EmbeddingMultiplication, self).__init__() 208 | self.embedding = embedding 209 | self.in_key_types = in_key_types 210 | self.in_key_representation = in_key_representation 211 | self.out_key = out_key 212 | 213 | def forward(self, inputs): 214 | """ 215 | Compute layer output. 216 | 217 | Args: 218 | inputs (dict of torch.Tensor): batch of input values containing the atomic 219 | numbers for embedding as well as the representation. 220 | Returns: 221 | torch.Tensor: layer output. 222 | """ 223 | # get types to embed from inputs 224 | types = inputs[self.in_key_types] 225 | st = types.size() 226 | 227 | # embed types 228 | if len(st) == 1: 229 | emb = self.embedding(types.view(st[0], 1)) 230 | elif len(st) == 2: 231 | emb = self.embedding(types.view(*st[:-1], 1, st[-1])) 232 | 233 | # get representation 234 | if isinstance(inputs[self.in_key_representation], tuple): 235 | repr = inputs[self.in_key_representation][0] 236 | else: 237 | repr = inputs[self.in_key_representation] 238 | if len(st) == 2: 239 | # if multiple types are provided per molecule, expand 240 | # dimensionality of representation 241 | repr = repr.view(*repr.size()[:-1], 1, repr.size()[-1]) 242 | 243 | # multiply embedded types with representation 244 | features = repr * emb 245 | 246 | # store result in input dictionary 247 | inputs.update({self.out_key: features}) 248 | 249 | return inputs 250 | 251 | 252 | class NormalizeAndAggregate(nn.Module): 253 | r""" 254 | Layer that normalizes and aggregates given input along specifiable axes. 255 | 256 | Args: 257 | normalize (bool): set True to normalize the input (default: True). 258 | normalization_axis (int): axis along which normalization is applied 259 | (default: -1). 260 | normalization_mode (str): which normalization to apply (currently only 261 | 'logsoftmax' is supported, default: 'logsoftmax'). 262 | aggregate (bool): set True to aggregate the input (default: True). 263 | aggregation_axis (int): axis along which aggregation is applied 264 | (default: -1). 265 | aggregation_mode (str): which aggregation to apply (currently 'sum' and 266 | 'mean' are supported, default: 'sum'). 267 | keepdim (bool): set True to keep the number of dimensions after aggregation 268 | (default: True). 269 | in_key_mask (str): key to extract a mask from the inputs dictionary, 270 | which hides values during aggregation (default: None). 271 | squeeze (bool): whether to squeeze the input before applying normalization 272 | (default: False). 273 | 274 | Returns: 275 | torch.Tensor: input after normalization and aggregation along specified axes. 
276 | """ 277 | 278 | def __init__(self, normalize=True, normalization_axis=-1, 279 | normalization_mode='logsoftmax', aggregate=True, 280 | aggregation_axis=-1, aggregation_mode='sum', keepdim=True, 281 | mask=None, squeeze=False): 282 | 283 | super(NormalizeAndAggregate, self).__init__() 284 | 285 | if normalize: 286 | if normalization_mode.lower() == 'logsoftmax': 287 | self.normalization = nn.LogSoftmax(normalization_axis) 288 | else: 289 | self.normalization = None 290 | 291 | if aggregate: 292 | if aggregation_mode.lower() == 'sum': 293 | self.aggregation =\ 294 | spk.nn.base.Aggregate(aggregation_axis, mean=False, 295 | keepdim=keepdim) 296 | elif aggregation_mode.lower() == 'mean': 297 | self.aggregation =\ 298 | spk.nn.base.Aggregate(aggregation_axis, mean=True, 299 | keepdim=keepdim) 300 | else: 301 | self.aggregation = None 302 | 303 | self.mask = mask 304 | self.squeeze = squeeze 305 | 306 | def forward(self, inputs, result): 307 | """ 308 | Compute layer output. 309 | 310 | Args: 311 | inputs (dict of torch.Tensor): batch of input values containing the mask 312 | result (torch.Tensor): batch of result values to which normalization and 313 | aggregation is applied 314 | Returns: 315 | torch.Tensor: normalized and aggregated result. 316 | """ 317 | 318 | res = result 319 | 320 | if self.squeeze: 321 | res = torch.squeeze(res) 322 | 323 | if self.normalization is not None: 324 | res = self.normalization(res) 325 | 326 | if self.aggregation is not None: 327 | if self.mask is not None: 328 | mask = inputs[self.mask] 329 | else: 330 | mask = None 331 | res = self.aggregation(res, mask) 332 | 333 | return res 334 | -------------------------------------------------------------------------------- /template_preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import argparse 3 | import sys 4 | import time 5 | import numpy as np 6 | import logging 7 | from ase.db import connect 8 | from scipy.spatial.distance import pdist, squareform 9 | from utility_classes import ConnectivityCompressor, Molecule 10 | from multiprocessing import Process, Queue 11 | from pathlib import Path 12 | 13 | # list names of collected statistics here (e.g. the number of atoms of each type) 14 | stat_heads = ['n_atoms', 'C', 'N', 'O', 'F', 'H'] 15 | atom_types = [6, 7, 8, 9, 1] # atom type charges in the same order as in stat_heads 16 | 17 | 18 | def preprocess_dataset(datapath, new_db_path=None, cutoff=2.0, 19 | precompute_distances=True, remove_invalid=True, 20 | invalid_list=None, valence_list=None, logging_print=True): 21 | ''' 22 | Pre-processes all molecules of a dataset. 23 | Along with a new database containing the pre-processed molecules, a 24 | "input_db_invalid.txt" file holding the indices of removed molecules and a 25 | "new_db_statistics.npz" file (containing atom count statistics for all molecules in 26 | the new database) are stored. 27 | 28 | Args: 29 | datapath (str): full path to dataset (ase.db database) 30 | new_db_path (str, optional): full path to new database where pre-processed 31 | molecules shall be stored (None to simply append "gen" to the name in 32 | datapath, default: None) 33 | cutoff (float, optional): cutoff value in angstrom used to determine which 34 | atoms in a molecule are considered as neighbors (i.e. 
connected, default: 35 | 2.0) 36 | precompute_distances (bool, optional): if True, the pairwise distances between 37 | atoms in each molecule are computed and stored in the database (default: 38 | True) 39 | remove_invalid (bool, optional): if True, molecules that do not pass the 40 | validity or connectivity checks are removed from the new database (default: 41 | True) 42 | invalid_list (list of int, optional): precomputed list containing indices of 43 | molecules that are marked as invalid (default: None) 44 | valence_list (list, optional): the valence of atom types in the form 45 | [type1 valence type2 valence ...] which could be used for valence checks 46 | (not implemented, default: None) 47 | logging_print (bool, optional): set True to show output with logging.info 48 | instead of standard printing (default: True) 49 | ''' 50 | # convert paths 51 | datapath = Path(datapath) 52 | if new_db_path is None: 53 | new_db_path = datapath.parent / (datapath.stem + 'gen.db') 54 | else: 55 | new_db_path = Path(new_db_path) 56 | 57 | def _print(x, end='\n', flush=False): 58 | if logging_print: 59 | logging.info(x) 60 | else: 61 | print(x, end=end, flush=flush) 62 | 63 | with connect(datapath) as db: 64 | n_all = db.count() 65 | if n_all == 0: 66 | _print('No molecules found in data base!') 67 | sys.exit(0) 68 | _print('\nPre-processing data...') 69 | if logging_print: 70 | _print(f'Processed: 0 / {n_all}...') 71 | else: 72 | _print(f'0.00%', end='', flush=True) 73 | 74 | # setup counter etc. 75 | count = 0 # count number of discarded (invalid etc.) molecules 76 | disc = [] # indices of disconnected structures 77 | inval = [] # indices of invalid structures 78 | stats = np.empty((len(stat_heads), 0)) # scaffold for statistics 79 | start_time = time.time() 80 | compressor = ConnectivityCompressor() # used to compress connectivity matrices 81 | # check if list of invalid molecules was provided and cast it into a set (allows 82 | # for faster lookup) 83 | if invalid_list is not None and remove_invalid: 84 | invalid_list = {*invalid_list} 85 | n_inval = len(invalid_list) 86 | else: 87 | n_inval = 0 88 | 89 | # preprocess each structure in the source db and write results into target db 90 | with connect(datapath) as source_db: 91 | with connect(new_db_path) as target_db: 92 | for i in range(source_db.count()): 93 | 94 | # skip molecule if index is present in precomputed list of invalid 95 | # molecules and if remove_invalid is True 96 | if remove_invalid and invalid_list is not None: 97 | if i in invalid_list: 98 | continue 99 | 100 | # get molecule from database 101 | row = source_db.get(i + 1) 102 | # extract additional data stored with molecule 103 | data = row.data 104 | # get ase.Atoms object 105 | at = row.toatoms() 106 | # get positions and atomic numbers 107 | pos = at.positions 108 | numbers = at.numbers 109 | 110 | # the algorithm to sample generation traces (atom placement steps) 111 | # assumes that the atoms in our structures are ordered by their 112 | # distance to the center of mass, thus we order them in that way here: 113 | 114 | # center positions (using center of mass) 115 | pos = pos - at.get_center_of_mass() 116 | # order atoms by distance to center of mass 117 | center_dists = np.sqrt(np.maximum(np.sum(pos ** 2, axis=1), 0)) 118 | idcs_sorted = np.argsort(center_dists) 119 | pos = pos[idcs_sorted] 120 | numbers = numbers[idcs_sorted] 121 | # update positions and atomic numbers accordingly in ase.Atoms object 122 | at.positions = pos 123 | at.numbers = numbers 124 | 125 | # 
retrieve connectivity matrix (and pairwise distances) 126 | connectivity, pairwise_distances = get_connectivity(at, cutoff) 127 | 128 | # check if the connectivity matrix represents a proper structure (i.e. 129 | # if all atoms are connected to each other via some path) as 130 | # disconnected structures cannot be used for training (there must be 131 | # an atom placement trajectory for G-SchNet) 132 | if is_disconnected(connectivity): 133 | count += 1 134 | disc += [i] 135 | continue 136 | 137 | # you could potentially implement some valency constraint checking here 138 | # and remove or mark molecules that do not pass the test 139 | # val = [check validity e.g. with connectivity and valence list] 140 | # if remove_invalid: 141 | # if not val: 142 | # count += 1 143 | # inval += [i] 144 | # continue 145 | 146 | # update data stored in db with a compressed version of the 147 | # connectivity matrix (we store only indices of entries >= 1 148 | data.update({'con_mat': compressor.compress(connectivity)}) 149 | 150 | # if desired, also store precomputed distances (in condensed format) 151 | if precompute_distances: 152 | data.update({'dists': pairwise_distances}) 153 | 154 | # write preprocessed molecule and data to target database 155 | target_db.write(at, data=data) 156 | 157 | # you can additionally gather some statistics about the training data 158 | # (these statistics can for example be used to filter molecules when 159 | # displaying them with the display_molecules.py script) 160 | # e.g. for QM9 we collected the atom, bond, and ring count statistics 161 | # when doing valency checks 162 | # here we simply count the number of atoms of each type 163 | atom_type_counts = np.bincount(numbers, minlength=10) 164 | # store counts [n_atoms, C, N, O, F, H] as listed in stat_heads 165 | statistics = np.array([len(numbers), *atom_type_counts[atom_types]]) 166 | # update stats array with statistics of current molecule 167 | stats = np.hstack((stats, statistics[:, None])) 168 | 169 | # print progress every 1000 molecules 170 | if (i+1) % 1000 == 0: 171 | _print(f'Processed: {i+1:6d} / {n_all}...') 172 | 173 | if not logging_print: 174 | _print('\033[K', end='\n', flush=True) 175 | _print(f'... successfully validated {n_all - count - n_inval} data ' 176 | f'points!', flush=True) 177 | if invalid_list is not None: 178 | _print(f'{n_inval} structures were removed because they are on the ' 179 | f'pre-computed list of invalid molecules!', flush=True) 180 | if len(disc)+len(inval) > 0: 181 | _print(f'CAUTION: Could not validate {len(disc)+len(inval)} additional ' 182 | f'molecules. You might want to increase the cutoff (currently ' 183 | f'{cutoff} angstrom) in order to have less disconnected structures. ' 184 | f'The molecules were removed and their indices are ' 185 | f'appended to the list of invalid molecules stored at ' 186 | f'{datapath.parent / (datapath.stem + f"_invalid.txt")}', 187 | flush=True) 188 | np.savetxt(datapath.parent / (datapath.stem + f'_invalid.txt'), 189 | np.append(np.sort(list(invalid_list)), np.sort(inval + disc)), 190 | fmt='%d') 191 | elif remove_invalid: 192 | _print(f'Identified {len(disc)} disconnected structures, and {len(inval)} ' 193 | f'invalid structures! 
You might want to increase the cutoff (currently ' 194 | f'{cutoff} angstrom) in order to have less disconnected structures.', 195 | flush=True) 196 | np.savetxt(datapath.parent / (datapath.stem + f'_invalid.txt'), 197 | np.sort(inval + disc), fmt='%d') 198 | _print('\nCompressing and storing statistics with numpy...') 199 | np.savez_compressed(new_db_path.parent/(new_db_path.stem+f'_statistics.npz'), 200 | stats=stats, 201 | stat_heads=stat_heads) 202 | 203 | end_time = time.time() - start_time 204 | m, s = divmod(end_time, 60) 205 | h, m = divmod(m, 60) 206 | h, m, s = int(h), int(m), int(s) 207 | _print(f'Done! Pre-processing needed {h:d}h{m:02d}m{s:02d}s.') 208 | 209 | 210 | def get_connectivity(mol, cutoff=2.0): 211 | ''' 212 | Write code to obtain a connectivity matrix given a molecule from your database 213 | here. The simple default implementation calculates pairwise distances and then 214 | uses a radial cutoff (e.g. 2 angstrom) to determine which atoms are labeled as 215 | connected. The matrix only needs to be binary as it is only used to sample 216 | generation traces, i.e. an order of atom placement steps for training. 217 | However, one could for example also use chemoinformatics tools in order to obtain 218 | bond order information and check the valence of provided structures on the run if 219 | the structures allow this (we did this for our experiments with QM9 in order to 220 | allow for comparison to related work, but we think that using a radial cutoff is 221 | actually more robust and more general as it does not depend on usually unreliable 222 | bond order assignment algorithms and can be used for all kinds of materials or 223 | molecules). 224 | Args: 225 | mol (ase.Atoms): one molecule from the database 226 | cutoff (float, optional): cutoff value in angstrom used to determine which 227 | atoms are connected 228 | 229 | Returns: 230 | numpy.ndarray: the computed connectivity matrix (n_atoms x n_atoms, float) 231 | numpy.ndarray: the computed pairwise distances in a condensed format 232 | (length is n_atoms*(n_atoms-1)/2), see scipy.spatial.distance.pdist for 233 | more information 234 | ''' 235 | # retrieve positions 236 | atom_positions = mol.get_positions() 237 | # get pairwise distances (condensed) 238 | pairwise_distances = pdist(atom_positions) 239 | # use cutoff to obtain connectivity matrix (condensed format) 240 | connectivity = np.array(pairwise_distances <= cutoff, dtype=float) 241 | # cast to redundant square matrix format 242 | connectivity = squareform(connectivity) 243 | # set diagonal entries to zero (as we do not assume atoms to be their own neighbors) 244 | connectivity[np.diag_indices_from(connectivity)] = 0 245 | return connectivity, pairwise_distances 246 | 247 | 248 | def is_disconnected(connectivity): 249 | ''' 250 | Assess whether all atoms of a molecule are connected using a connectivity matrix 251 | 252 | Args: 253 | connectivity (numpy.ndarray): matrix (n_atoms x n_atoms) indicating bonds 254 | between atoms 255 | 256 | Returns 257 | bool: True if the molecule consists of at least two disconnected graphs, 258 | False if all atoms are connected by some path 259 | ''' 260 | con_mat = connectivity 261 | seen, queue = {0}, collections.deque([0]) # start at node (atom) 0 262 | while queue: 263 | vertex = queue.popleft() 264 | # iterate over (bonded) neighbors of current node 265 | for node in np.argwhere(con_mat[vertex] > 0).flatten(): 266 | # add node to queue and list of seen nodes if it has not been seen before 267 | if node not in seen: 268 | 
seen.add(node) 269 | queue.append(node) 270 | # if the seen nodes do not include all nodes, there are disconnected parts 271 | return seen != {*range(len(con_mat))} 272 | -------------------------------------------------------------------------------- /template_filter_generated.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import argparse 5 | import time 6 | 7 | from scipy.spatial.distance import pdist 8 | from schnetpack import Properties 9 | from utility_classes import Molecule, ConnectivityCompressor 10 | from utility_functions import update_dict 11 | from ase import Atoms 12 | from ase.db import connect 13 | 14 | 15 | def get_parser(): 16 | """ Setup parser for command line arguments """ 17 | main_parser = argparse.ArgumentParser() 18 | main_parser.add_argument('data_path', 19 | help='Path to generated molecules in .mol_dict format, ' 20 | 'a database called "generated_molecules.db" with the ' 21 | 'filtered molecules along with computed statistics ' 22 | '("generated_molecules_statistics.npz") will be ' 23 | 'stored in the same directory as the input file/s ' 24 | '(if the path points to a directory, all .mol_dict ' 25 | 'files in the directory will be merged and filtered ' 26 | 'in one pass)') 27 | main_parser.add_argument('--valence', 28 | default=[1, 1, 6, 4, 7, 3, 8, 2, 9, 1], type=int, 29 | nargs='+', 30 | help='the valence of atom types in the form ' 31 | '[type1 valence type2 valence ...] ' 32 | '(default: %(default)s)') 33 | main_parser.add_argument('--filters', type=str, nargs='*', 34 | default=['valence', 'disconnected', 'unique'], 35 | choices=['valence', 'disconnected', 'unique'], 36 | help='Select the filters applied to identify ' 37 | 'invalid molecules (default: %(default)s)') 38 | main_parser.add_argument('--store', type=str, default='valid', 39 | choices=['all', 'valid'], 40 | help='How much information shall be stored ' 41 | 'after filtering: \n"all" keeps all ' 42 | 'generated molecules and statistics, ' 43 | '\n"valid" keeps only valid molecules' 44 | '(default: %(default)s)') 45 | main_parser.add_argument('--print_file', 46 | help='Use to limit the printing if results are ' 47 | 'written to a file instead of the console (' 48 | 'e.g. if running on a cluster)', 49 | action='store_true') 50 | return main_parser 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = get_parser() 55 | args = parser.parse_args() 56 | print_file = args.print_file 57 | printed_todos = False 58 | 59 | # read input file or fuse dictionaries if data_path is a folder 60 | if not os.path.isdir(args.data_path): 61 | if not os.path.isfile(args.data_path): 62 | print(f'\n\nThe specified data path ({args.data_path}) is neither a file ' 63 | f'nor a directory! Please specify a different data path.') 64 | raise FileNotFoundError 65 | else: 66 | with open(args.data_path, 'rb') as f: 67 | res = pickle.load(f) # read input file 68 | target_db = os.path.join(os.path.dirname(args.data_path), 69 | 'generated_molecules.db') 70 | else: 71 | print(f'\n\nFusing .mol_dict files in folder {args.data_path}...') 72 | mol_files = [f for f in os.listdir(args.data_path) 73 | if f.endswith(".mol_dict")] 74 | if len(mol_files) == 0: 75 | print(f'Could not find any .mol_dict files at {args.data_path}! 
Please ' 76 | f'specify a different data path!') 77 | raise FileNotFoundError 78 | res = {} 79 | for file in mol_files: 80 | with open(os.path.join(args.data_path, file), 'rb') as f: 81 | cur_res = pickle.load(f) 82 | update_dict(res, cur_res) 83 | res = dict(sorted(res.items())) # sort dictionary keys 84 | print(f'...done!') 85 | target_db = os.path.join(args.data_path, 'generated_molecules.db') 86 | 87 | # compute array with valence of provided atom types 88 | max_type = max(args.valence[::2]) 89 | valence = np.zeros(max_type+1, dtype=int) 90 | valence[args.valence[::2]] = args.valence[1::2] 91 | 92 | # print the chosen settings 93 | valence_str = '' 94 | for i in range(max_type+1): 95 | if valence[i] > 0: 96 | valence_str += f'type {i}: {valence[i]}, ' 97 | filters = [] 98 | if 'valence' in args.filters: 99 | filters += ['valency'] 100 | if 'disconnected' in args.filters: 101 | filters += ['connectedness'] 102 | if 'unique' in args.filters: 103 | filters += ['uniqueness'] 104 | if len(filters) >= 3: 105 | edit = ', ' 106 | else: 107 | edit = ' ' 108 | for i in range(len(filters) - 1): 109 | filters[i] = filters[i] + edit 110 | if len(filters) >= 2: 111 | filters = filters[:-1] + ['and '] + filters[-1:] 112 | string = ''.join(filters) 113 | print(f'\n\n1. Filtering molecules according to {string}...') 114 | print(f'\nTarget valence:\n{valence_str[:-2]}\n') 115 | 116 | # initial setup of array for statistics and some counters 117 | n_generated = 0 118 | n_valid = 0 119 | n_non_unique = 0 120 | stat_heads = ['n_atoms', 'id', 'valid', 'duplicating', 'n_duplicates', 121 | 'known', 'equals', 'C', 'N', 'O', 'F', 'H'] 122 | stats = np.empty((len(stat_heads), 0)) 123 | all_mols = [] 124 | connectivity_compressor = ConnectivityCompressor() 125 | 126 | # iterate over generated molecules by length (all generated molecules with n 127 | # atoms are stored in one batch, so we loop over all available lengths n) 128 | # this is useful e.g. for finding duplicates, since we only need to compare 129 | # molecules of the same length (and can actually further narrow down the 130 | # candidates by looking at the exact atom type composition of each molecule) 131 | start_time = time.time() 132 | for n_atoms in res: 133 | if not isinstance(n_atoms, int) or n_atoms == 0: 134 | continue 135 | 136 | prog_str = lambda x: f'Checking {x} for molecules of length {n_atoms}' 137 | work_str = 'valence' if 'valence' in args.filters else 'dictionary' 138 | if not print_file: 139 | print('\033[K', end='\r', flush=True) 140 | print(prog_str(work_str) + ' (0.00%)', end='\r', flush=True) 141 | else: 142 | print(prog_str(work_str), flush=True) 143 | 144 | d = res[n_atoms] # dictionary containing molecules of length n_atoms 145 | all_pos = d[Properties.R] # n_mols x n_atoms x 3 matrix with atom positions 146 | all_numbers = d[Properties.Z] # n_mols x n_atoms matrix with atom types 147 | n_mols = len(all_pos) 148 | valid = np.ones(n_mols, dtype=int) # all molecules are valid in the beginning 149 | 150 | # check valency of molecules with length n 151 | if 'valence' in args.filters: 152 | if not printed_todos: 153 | print('Please implement a procedure to check the valence in generated ' 154 | 'molecules! Skipping valence check...') 155 | # TODO 156 | # Implement a procedure to assess the valence of generated molecules here! 157 | # You can adapt and use the Molecule class in utility_classes.py, 158 | # but the current code is tailored towards the QM9 dataset. 
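# --- illustrative addition (not part of the original template) ---
# A minimal sketch of such a valence check, mirroring what the QM9 scripts do
# with bond order matrices from Open Babel; run_valence_sketch is a
# hypothetical switch so that the template still skips the check by default:
run_valence_sketch = False
if run_valence_sketch:
    for i in range(n_mols):
        mol = Molecule(all_pos[i], all_numbers[i])
        con_mat = mol.get_connectivity()  # bond order matrix from Open Babel
        if not np.all(np.sum(con_mat, axis=0) == valence[all_numbers[i]]):
            valid[i] = 0  # valence constraint of at least one atom violated
# --- end of illustrative addition ---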
In fact,
159 | # the OpenBabel algorithm to kekulize bond orders is not very reliable
160 | # and we implemented some heuristics in the Molecule class to fix these
161 | # flaws for structures made of C, N, O, and F atoms. However, when using
162 | # more complex structures with a more diverse set of atom types, we think
163 | # that the reliability of bond assignment in OpenBabel might further
164 | # degrade and therefore do not recommend using valence checks for
165 | # analysis unless it is very important for your use case.
166 |
167 | # detect molecules with disconnected parts if desired
168 | if 'disconnected' in args.filters:
169 | if not print_file:
170 | print('\033[K', end='\r', flush=True)
171 | print(prog_str("connectedness")+'...', end='\r', flush=True)
172 | if not printed_todos:
173 | print('Please implement a procedure to check the connectedness of '
174 | 'generated molecules! In this template script we will now remove '
175 | 'molecules where two atoms are closer than 0.3 angstrom as an '
176 | 'example processing step...')
177 | # TODO
178 | # Implement a procedure to assess the connectedness of generated
179 | # molecules here! You can for example use a connectivity matrix obtained
180 | # from kekulized bond orders (as we do in our QM9 experiments) or
181 | # calculate the connectivity with a simple cutoff (e.g. all atoms less
182 | # than 2.0 angstrom apart are connected, see get_connectivity function in
183 | # template_preprocess_dataset script).
184 | # We will remove all molecules where two atoms are closer than 0.3
185 | # angstrom in the following as an example filtering step
186 |
187 | # loop over all molecules of length n_atoms
188 | for i in range(len(all_pos)):
189 | positions = all_pos[i] # extract atom positions
190 | dists = pdist(positions) # compute pair-wise distances
191 | if np.any(dists < 0.3): # check if any two atoms are closer than 0.3 A
192 | valid[i] = 0 # mark current molecule as invalid
193 |
194 |
195 | # identify identical molecules (e.g. using fingerprints)
196 | if not print_file:
197 | print('\033[K', end='\r', flush=True)
198 | print(prog_str('uniqueness')+'...', end='\r', flush=True)
199 | if not printed_todos:
200 | print('Please implement a procedure to check the uniqueness of '
201 | 'generated molecules! Skipping check for uniqueness...')
202 | printed_todos = True
203 | # TODO
204 | # Implement a procedure to identify duplicate structures here.
205 | # This can (heuristically) be achieved in many ways but perfectly identifying
206 | # all duplicate structures without false positives or false negatives is
207 | # probably impossible (or computationally prohibitive).
208 | # For our QM9 experiments, we compared fingerprints and canonical smiles
209 | # strings of generated molecules using the Molecule class in utility_classes.py
210 | # that provides functions to obtain these. It would also be possible to compare
211 | # learned embeddings, e.g. from SchNet or G-SchNet, either as an average over
212 | # all atoms, over all atoms of the same type, or combined with an algorithm
213 | # to find the best match between atoms of two molecules considering the
214 | # distances between embeddings. A similar procedure could be implemented
215 | # using the root-mean-square deviation (RMSD) of atomic positions. Then it
216 | # would be required to find the best match between atoms of two structures if
217 | # they are rotated such that the RMSD given the match is minimal.
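# --- illustrative addition (not part of the original template) ---
# A cheap, purely geometric duplicate heuristic compares the sorted atomic
# numbers and the sorted pairwise distances of two structures; note that it
# cannot distinguish mirror images and may, in rare degenerate cases, confuse
# different structures that share the same multiset of distances.
# _looks_identical is a hypothetical helper defined here only for illustration:
def _looks_identical(pos_1, num_1, pos_2, num_2, eps=1e-2):
    if not np.array_equal(np.sort(num_1), np.sort(num_2)):
        return False  # different atomic compositions
    return np.allclose(np.sort(pdist(pos_1)), np.sort(pdist(pos_2)), atol=eps)
# --- end of illustrative addition ---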
Again, 218 | # the best procedure really depends on the experimental setup, e.g. the 219 | # goals of the experiment, used data and size of molecules in the dataset etc. 220 | 221 | # duplicate_count contains the number of duplicates found for each structure 222 | duplicate_count = np.zeros(n_mols, dtype=int) 223 | # duplicating contains -1 for original structures and the id of the duplicated 224 | # original structure for duplicates 225 | duplicating = -np.ones(n_mols, dtype=int) 226 | # remove duplicate structures from list of valid molecules if desired 227 | if 'unique' in args.filters: 228 | valid[duplicating != -1] = 0 229 | # count number of non-unique structures 230 | n_non_unique += np.sum(duplicate_count) 231 | 232 | # store list of valid molecules in dictionary 233 | d.update({'valid': valid}) 234 | 235 | # collect statistics of generated data 236 | n_generated += len(valid) 237 | n_valid += np.sum(valid) 238 | # count number of atoms per type (here for C, N, O, F, and H as example) 239 | n_of_types = [np.sum(all_numbers == i, axis=1) for i in [6, 7, 8, 9, 1]] 240 | stats_new = np.stack( 241 | (np.ones(len(valid)) * n_atoms, # n_atoms 242 | np.arange(0, len(valid)), # id 243 | valid, # valid 244 | duplicating, # id of duplicated molecule 245 | duplicate_count, # number of duplicates 246 | -np.ones(len(valid)), # known 247 | -np.ones(len(valid)), # equals 248 | *n_of_types, # n_atoms per type 249 | ), 250 | axis=0) 251 | stats = np.hstack((stats, stats_new)) 252 | 253 | if not print_file: 254 | print('\033[K', end='\r', flush=True) 255 | end_time = time.time() - start_time 256 | m, s = divmod(end_time, 60) 257 | h, m = divmod(m, 60) 258 | h, m, s = int(h), int(m), int(s) 259 | print(f'Needed {h:d}h{m:02d}m{s:02d}s.') 260 | 261 | # Update and print results 262 | res.update({'n_generated': n_generated, 263 | 'n_valid': n_valid, 264 | 'stats': stats, 265 | 'stat_heads': stat_heads}) 266 | 267 | print(f'Number of generated molecules: {n_generated}\n' 268 | f'Number of duplicate molecules: {n_non_unique}') 269 | if 'unique' in args.filters: 270 | print(f'Number of unique and valid molecules: {n_valid}') 271 | else: 272 | print(f'Number of valid molecules (including duplicates): {n_valid}') 273 | 274 | # Remove invalid molecules from results if desired 275 | if args.store != 'all': 276 | shrunk_res = {} 277 | shrunk_stats = np.empty((len(stats), 0)) 278 | i = 0 279 | for key in res: 280 | if isinstance(key, str): 281 | shrunk_res[key] = res[key] 282 | continue 283 | if key == 0: 284 | continue 285 | d = res[key] 286 | start = i 287 | end = i + len(d['valid']) 288 | idcs = np.where(d['valid'])[0] 289 | if len(idcs) < 1: 290 | i = end 291 | continue 292 | # shrink stats 293 | idx_id = stat_heads.index('id') 294 | idx_known = stat_heads.index('known') 295 | new_stats = stats[:, start:end] 296 | new_stats = new_stats[:, idcs] 297 | new_stats[idx_id] = np.arange(len(new_stats[idx_id])) # adjust ids 298 | shrunk_stats = np.hstack((shrunk_stats, new_stats)) 299 | # shrink positions and atomic numbers 300 | shrunk_res[key] = {Properties.R: d[Properties.R][idcs], 301 | Properties.Z: d[Properties.Z][idcs]} 302 | i = end 303 | 304 | shrunk_res['stats'] = shrunk_stats 305 | res = shrunk_res 306 | 307 | # transfer results to ASE db 308 | # get filename that is not yet taken for db 309 | if os.path.isfile(target_db): 310 | file_name, _ = os.path.splitext(target_db) 311 | expand = 0 312 | while True: 313 | expand += 1 314 | new_file_name = file_name + '_' + str(expand) 315 | if 
os.path.isfile(new_file_name + '.db'): 316 | continue 317 | else: 318 | target_db = new_file_name + '.db' 319 | break 320 | print(f'Transferring generated molecules to database at {target_db}...') 321 | # open db 322 | with connect(target_db) as conn: 323 | # store metadata 324 | conn.metadata = {'n_generated': int(n_generated), 325 | 'n_non_unique': int(n_non_unique), 326 | 'n_valid': int(n_valid), 327 | 'non_unique_removed_from_valid': 'unique' in args.filters} 328 | # store molecules 329 | for n_atoms in res: 330 | if isinstance(n_atoms, str) or n_atoms == 0: 331 | continue 332 | d = res[n_atoms] 333 | all_pos = d[Properties.R] 334 | all_numbers = d[Properties.Z] 335 | for pos, num in zip(all_pos, all_numbers): 336 | at = Atoms(num, positions=pos) 337 | conn.write(at) 338 | 339 | # store gathered statistics in separate file 340 | np.savez_compressed(os.path.splitext(target_db)[0] + f'_statistics.npz', 341 | stats=res['stats'], stat_heads=res['stat_heads']) 342 | -------------------------------------------------------------------------------- /display_molecules.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import subprocess 5 | import numpy as np 6 | import tempfile 7 | 8 | from ase.db import connect 9 | from ase.io import write 10 | from utility_classes import IndexProvider 11 | 12 | 13 | def get_parser(): 14 | """ Setup parser for command line arguments """ 15 | main_parser = argparse.ArgumentParser() 16 | main_parser.add_argument('--data_path', type=str, default=None, 17 | help='Path to database with filtered, generated molecules ' 18 | '(.db format, needs to be provided if generated ' 19 | 'molecules shall be displayed, default: %(default)s)') 20 | main_parser.add_argument('--train_data_path', type=str, 21 | help='Path to training data base (.db format, needs to be ' 22 | 'provided if molecules from the training data set ' 23 | 'shall be displayed, e.g. when using --train or ' 24 | '--test, default: %(default)s)', 25 | default=None) 26 | main_parser.add_argument('--select', type=str, nargs='*', 27 | help='Selection strings that specify which molecules ' 28 | 'shall be shown, if None all molecules from ' 29 | 'data_path and/or train_data_path are shown, ' 30 | 'providing multiple strings' 31 | ' will open multiple windows (one per string), ' 32 | '(default: %(default)s). The selection string has ' 33 | 'the general format "Property,OperatorTarget" (e.g. ' 34 | '"C,>8"to filter for all molecules with more than ' 35 | 'eight carbon atoms where "C" is the statistic ' 36 | 'counting the number of carbon atoms in a molecule, ' 37 | '">" is the operator, and "8" is the target value). ' 38 | 'Multiple conditions can be combined to form one ' 39 | 'selection string using "&" (e.g "C,>8&R5,>0" to ' 40 | 'get all molecules with more than 8 carbon atoms ' 41 | 'and at least 1 ring of size 5). Prepending ' 42 | '"training" to the selection string will filter and ' 43 | 'display molecules from the training data base ' 44 | 'instead of generated molecules (e.g. "training C,>8"' 45 | '). 
An overview of the available properties for ' 46 | 'molecuels generated with G-SchNet trained on QM9 can' 47 | ' be found in the README.md.', 48 | default=None) 49 | main_parser.add_argument('--print_indices', 50 | help='For each provided selection print out the indices ' 51 | 'of molecules that match the respective selection ' 52 | 'string', 53 | action='store_true') 54 | main_parser.add_argument('--export_to_dir', type=str, 55 | help='Optionally, provide a path to an directory to which ' 56 | 'indices of molecules matching the corresponding ' 57 | 'query shall be written (one .npy-file (numpy) per ' 58 | 'selection string, if None is provided, the ' 59 | 'indices will not be exported, default: %(default)s)', 60 | default=None) 61 | main_parser.add_argument('--train', 62 | help='Display all generated molecules that match ' 63 | 'structures used during training and the ' 64 | 'corresponding molecules from the training data set.', 65 | action='store_true') 66 | main_parser.add_argument('--test', 67 | help='Display all generated molecules that match ' 68 | 'held out test data structures and the ' 69 | 'corresponding molecules from the training data set.', 70 | action='store_true') 71 | main_parser.add_argument('--novel', 72 | help='Display all generated molecules that match neither ' 73 | 'structures used during training nor those held out ' 74 | 'as test data.', 75 | action='store_true') 76 | main_parser.add_argument('--block', 77 | help='Make the call to ASE GUI blocking (such that the ' 78 | 'script stops until the GUI window is closed).', 79 | action='store_true') 80 | 81 | return main_parser 82 | 83 | 84 | def view_ase(mols, name, block=False): 85 | ''' 86 | Display a list of molecules using the ASE GUI. 87 | 88 | Args: 89 | mols (list of ase.Atoms): molecules as ase.Atoms objects 90 | name (str): the name that shall be displayed in the windows top bar 91 | block (bool, optional): whether the call to ase gui shall block or not block 92 | the script (default: False) 93 | ''' 94 | dir = tempfile.mkdtemp('', 'generated_molecules_') # make temporary directory 95 | filename = os.path.join(dir, name) # path of temporary file 96 | format = 'traj' # use trajectory format for temporary file 97 | command = sys.executable + ' -m ase gui -b' # command to execute ase gui viewer 98 | write(filename, mols, format=format) # write molecules to temporary file 99 | # show molecules in ase gui and remove temporary file and directory afterwards 100 | if block: 101 | subprocess.call(command.split() + [filename]) 102 | os.remove(filename) 103 | os.rmdir(dir) 104 | else: 105 | subprocess.Popen(command.split() + [filename]) 106 | subprocess.Popen(['sleep 60; rm "{0}"'.format(filename)], shell=True) 107 | subprocess.Popen(['sleep 65; rmdir "{0}"'.format(dir)], shell=True) 108 | 109 | 110 | def print_indices(idcs, name='', per_line=10): 111 | ''' 112 | Prints provided indices in a clean formatting. 
113 | 114 | Args: 115 | idcs (list of int): indices that shall be printed 116 | name (str): the selection string that was used to obtain the indices 117 | per_line (int, optional): the number of indices that are printed per line ( 118 | default: 10) 119 | ''' 120 | biggest = len(str(max(idcs))) 121 | new_line = '\n' 122 | format = f'>{biggest}d' 123 | str_idcs = [f'{j:{format}} ' + (new_line if (i+1) % per_line == 0 else '') 124 | for i, j in enumerate(idcs)] 125 | print(f'\nAll {len(idcs)} indices for selection {name}:') 126 | print(''.join(str_idcs)) 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = get_parser() 131 | args = parser.parse_args() 132 | 133 | # make sure that at least one path was provided 134 | if args.data_path is None and args.train_data_path is None: 135 | print(f'\nPlease specify --data_path to display generated molecules or ' 136 | f'--train_data_path to display training molecules (or both)!') 137 | sys.exit(0) 138 | 139 | # sort queries into those concerning generated structures and those concerning 140 | # training data molecules 141 | train_selections = [] 142 | gen_selections = [] 143 | if args.select is not None: 144 | for selection in args.select: 145 | if selection.startswith('training'): 146 | # put queries concerning training structures aside for later 147 | train_selections += [selection] 148 | else: 149 | gen_selections += [selection] 150 | 151 | # make sure that the required paths were provided 152 | if args.train or args.test: 153 | if args.data_path is None: 154 | print('\nYou need to specify --data_path (and optionally ' 155 | '--train_data_path) if using --train or --test!') 156 | sys.exit(0) 157 | if args.novel: 158 | if args.data_path is None: 159 | print('\nYou need to specify --data_path if you want to display novel ' 160 | 'molecules!') 161 | sys.exit(0) 162 | if len(gen_selections) > 0: 163 | if args.data_path is None: 164 | print(f'\nYou need to specify --data_path to process the selections ' 165 | f'{gen_selections}!') 166 | sys.exit(0) 167 | if len(train_selections) > 0: 168 | if args.train_data_path is None: 169 | print(f'\nYou need to specify --train_data_path to process the selections ' 170 | f'{train_selections}!') 171 | sys.exit(0) 172 | 173 | # check if statistics files are needed 174 | need_gen_stats = (len(gen_selections) > 0) or args.train or args.test or args.novel 175 | need_train_stats = (len(train_selections) > 0) or args.train or args.test 176 | 177 | # check if there is a database with generated molecules at the provided path 178 | # and load accompanying statistics file 179 | if args.data_path is not None: 180 | if not os.path.isfile(args.data_path): 181 | print(f'\nThe specified data path ({args.data_path}) is not a file! Please ' 182 | f'specify a different data path.') 183 | raise FileNotFoundError 184 | elif need_gen_stats: 185 | stats_path = os.path.splitext(args.data_path)[0] + f'_statistics.npz' 186 | if not os.path.isfile(stats_path): 187 | print(f'\nCannot find statistics file belonging to {args.data_path} (' 188 | f'expected it at {stats_path}. 
Please make sure that the file ' 189 | f'exists.') 190 | raise FileNotFoundError 191 | else: 192 | stats_dict = np.load(stats_path) 193 | index_provider = IndexProvider(stats_dict['stats'], 194 | stats_dict['stat_heads']) 195 | 196 | # check if there is a database with training molecules at the provided path 197 | # and load accompanying statistics file 198 | if args.train_data_path is not None: 199 | if not os.path.isfile(args.train_data_path): 200 | print(f'\nThe specified training data path ({args.train_data_path}) is ' 201 | f'not a file! Please specify --train_data_path correctly.') 202 | raise FileNotFoundError 203 | elif need_train_stats: 204 | stats_path = os.path.splitext(args.train_data_path)[0] + f'_statistics.npz' 205 | if not os.path.isfile(stats_path) and len(train_selections) > 0: 206 | print(f'\nCannot find statistics file belonging to ' 207 | f'{args.train_data_path} (expected it at {stats_path}. Please ' 208 | f'make sure that the file exists.') 209 | raise FileNotFoundError 210 | else: 211 | train_stats_dict = np.load(stats_path) 212 | train_index_provider = IndexProvider(train_stats_dict['stats'], 213 | train_stats_dict['stat_heads']) 214 | 215 | # create folder(s) for export of indices if necessary 216 | if args.export_to_dir is not None: 217 | if not os.path.isdir(args.export_to_dir): 218 | print(f'\nDirectory {args.export_to_dir} does not exist, creating ' 219 | f'it to store indices of molecules matching the queries!') 220 | os.makedirs(args.export_to_dir) 221 | else: 222 | print(f'\nWill store indices of molecules matching the queries at ' 223 | f'{args.export_to_dir}!') 224 | 225 | # display all generated molecules if desired 226 | if (len(gen_selections) == 0) and not (args.train or args.test or args.novel) and\ 227 | args.data_path is not None: 228 | with connect(args.data_path) as con: 229 | _ats = [con.get(int(idx) + 1).toatoms() for idx in range(con.count())] 230 | view_ase(_ats, 'all generated molecules', args.block) 231 | 232 | # display generated molecules matching selection strings 233 | if len(gen_selections) > 0: 234 | for selection in gen_selections: 235 | # display queries concerning generated molecules 236 | idcs = index_provider.get_selected(selection) 237 | if len(idcs) == 0: 238 | print(f'\nNo molecules match selection {selection}!') 239 | continue 240 | with connect(args.data_path) as con: 241 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs] 242 | if args.print_indices: 243 | print_indices(idcs, selection) 244 | view_ase(_ats, f'generated molecules ({selection})', args.block) 245 | if args.export_to_dir is not None: 246 | np.save(os.path.join(args.export_to_dir, selection), idcs) 247 | 248 | # display all training molecules if desired 249 | if (len(train_selections) == 0) and not (args.train or args.test) and \ 250 | args.train_data_path is not None: 251 | with connect(args.train_data_path) as con: 252 | _ats = [con.get(int(idx) + 1).toatoms() for idx in range(con.count())] 253 | view_ase(_ats, 'all molecules in the training data set', args.block) 254 | 255 | # display training molecules matching selection strings 256 | if len(train_selections) > 0: 257 | # display training molecules that match the selection strings 258 | for selection in train_selections: 259 | _selection = selection.split()[1] 260 | stats_queries = [] 261 | db_queries = [] 262 | # sort into queries handled by looking into the statistics or the db 263 | for _sel_str in _selection.split('&'): 264 | prop = _sel_str.split(',')[0] 265 | if prop in 
train_stats_dict['stat_heads']: 266 | stats_queries += [_sel_str] 267 | elif len(prop.split('+')) > 0: 268 | found = True 269 | for p in prop.split('+'): 270 | if p not in train_stats_dict['stat_heads']: 271 | found = False 272 | break 273 | if found: 274 | stats_queries += [_sel_str] 275 | else: 276 | db_queries += [_sel_str] 277 | else: 278 | db_queries += [_sel_str] 279 | # process queries concerning the statistics 280 | if len(stats_queries) > 0: 281 | idcs = train_index_provider.get_selected('&'.join(stats_queries)) 282 | else: 283 | idcs = range(connect(args.train_data_path).count()) 284 | # process queries concerning the db entries 285 | if len(db_queries) > 0: 286 | with connect(args.train_data_path) as con: 287 | for query in db_queries: 288 | head, condition = query.split(',') 289 | if head not in con.get(1).data: 290 | print(f'Entry {head} not found for molecules in the ' 291 | f'database, skipping query {query}.') 292 | continue 293 | else: 294 | op = train_index_provider.rel_re.search(condition).group(0) 295 | op = train_index_provider.op_dict[op] # extract operator 296 | num = float(train_index_provider.num_re.search( 297 | condition).group(0)) # extract numerical value 298 | remaining_idcs = [] 299 | for idx in idcs: 300 | if op(con.get(int(idx)+1).data[head], num): 301 | remaining_idcs += [idx] 302 | idcs = remaining_idcs 303 | # extract molecules matching the query from db and display them 304 | if len(idcs) == 0: 305 | print(f'\nNo training molecules match selection {_selection}!') 306 | continue 307 | with connect(args.train_data_path) as con: 308 | _ats = [con.get(int(idx)+1).toatoms() for idx in idcs] 309 | if args.print_indices: 310 | print_indices(idcs, selection) 311 | view_ase(_ats, f'training data set molecules ({_selection})', args.block) 312 | if args.export_to_dir is not None: 313 | np.save(os.path.join(args.export_to_dir, selection), idcs) 314 | 315 | # display generated molecules that match structures used for training 316 | if args.train: 317 | idcs = index_provider.get_selected('known,>=1&known,<=2') 318 | if len(idcs) == 0: 319 | print(f'\nNo generated molecules found that match structures used ' 320 | f'during training!') 321 | else: 322 | with connect(args.data_path) as con: 323 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs] 324 | if args.print_indices: 325 | print_indices(idcs, 'generated train') 326 | view_ase(_ats, f'generated molecules (matching train structures)', 327 | args.block) 328 | if args.export_to_dir is not None: 329 | np.save(os.path.join(args.export_to_dir, 'generated train'), idcs) 330 | # display corresponding training structures 331 | if args.train_data_path is not None: 332 | _row_idx = list(stats_dict['stat_heads']).index('equals') 333 | t_idcs = stats_dict['stats'][_row_idx, idcs].astype(int) 334 | with connect(args.train_data_path) as con: 335 | _ats = [con.get(int(idx) + 1).toatoms() for idx in t_idcs] 336 | if args.print_indices: 337 | print_indices(t_idcs, 'reference train') 338 | view_ase(_ats, f'training molecules (train structures)', args.block) 339 | if args.export_to_dir is not None: 340 | np.save(os.path.join(args.export_to_dir, 'reference train'), t_idcs) 341 | 342 | # display generated molecules that match held out test structures 343 | if args.test: 344 | idcs = index_provider.get_selected('known,==3') 345 | if len(idcs) == 0: 346 | print(f'\nNo generated molecules found that match held out test ' 347 | f'structures!') 348 | else: 349 | with connect(args.data_path) as con: 350 | _ats = [con.get(int(idx) + 
1).toatoms() for idx in idcs] 351 | if args.print_indices: 352 | print_indices(idcs, 'generated test') 353 | view_ase(_ats, f'generated molecules (matching test structures)', 354 | args.block) 355 | if args.export_to_dir is not None: 356 | np.save(os.path.join(args.export_to_dir, 'generated test'), idcs) 357 | # display corresponding training structures 358 | if args.train_data_path is not None: 359 | _row_idx = list(stats_dict['stat_heads']).index('equals') 360 | t_idcs = stats_dict['stats'][_row_idx, idcs].astype(int) 361 | with connect(args.train_data_path) as con: 362 | _ats = [con.get(int(idx) + 1).toatoms() for idx in t_idcs] 363 | if args.print_indices: 364 | print_indices(t_idcs, 'reference test') 365 | view_ase(_ats, f'training molecules (test structures)', args.block) 366 | if args.export_to_dir is not None: 367 | np.save(os.path.join(args.export_to_dir, 'reference test'), t_idcs) 368 | 369 | # display generated molecules that are novel (i.e. that do not match held out 370 | # test structures or structures used during training) 371 | if args.novel: 372 | idcs = index_provider.get_selected('known,==0') 373 | if len(idcs) == 0: 374 | print(f'\nNo novel molecules found!') 375 | else: 376 | with connect(args.data_path) as con: 377 | _ats = [con.get(int(idx) + 1).toatoms() for idx in idcs] 378 | if args.print_indices: 379 | print_indices(idcs, 'novel') 380 | view_ase(_ats, f'generated molecules (novel)', args.block) 381 | if args.export_to_dir is not None: 382 | np.save(os.path.join(args.export_to_dir, 'generated novel'), idcs) 383 | -------------------------------------------------------------------------------- /qm9_preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import argparse 3 | import sys 4 | import time 5 | import numpy as np 6 | import logging 7 | from ase.db import connect 8 | from scipy.spatial.distance import pdist 9 | from utility_classes import ConnectivityCompressor, Molecule 10 | from multiprocessing import Process, Queue 11 | from pathlib import Path 12 | 13 | 14 | def get_parser(): 15 | """ Setup parser for command line arguments """ 16 | main_parser = argparse.ArgumentParser() 17 | main_parser.add_argument('datapath', help='Full path to dataset (e.g. ' 18 | '/home/qm9.db)') 19 | main_parser.add_argument('--valence_list', 20 | default=[1, 1, 6, 4, 7, 3, 8, 2, 9, 1], type=int, 21 | nargs='+', 22 | help='The valence of atom types in the form ' 23 | '[type1 valence type2 valence ...] 
' 24 | '(default: %(default)s)') 25 | main_parser.add_argument('--n_threads', type=int, default=16, 26 | help='Number of extra threads used while ' 27 | 'processing the data') 28 | main_parser.add_argument('--n_mols_per_thread', type=int, default=100, 29 | help='Number of molecules processed by each ' 30 | 'thread in one iteration') 31 | return main_parser 32 | 33 | 34 | def is_disconnected(connectivity): 35 | ''' 36 | Assess whether all atoms of a molecule are connected using a connectivity matrix 37 | 38 | Args: 39 | connectivity (numpy.ndarray): matrix (n_atoms x n_atoms) indicating bonds 40 | between atoms 41 | 42 | Returns 43 | bool: True if the molecule consists of at least two disconnected graphs, 44 | False if all atoms are connected by some path 45 | ''' 46 | con_mat = connectivity 47 | seen, queue = {0}, collections.deque([0]) # start at node (atom) 0 48 | while queue: 49 | vertex = queue.popleft() 50 | # iterate over (bonded) neighbors of current node 51 | for node in np.argwhere(con_mat[vertex] > 0).flatten(): 52 | # add node to queue and list of seen nodes if it has not been seen before 53 | if node not in seen: 54 | seen.add(node) 55 | queue.append(node) 56 | # if the seen nodes do not include all nodes, there are disconnected parts 57 | return seen != {*range(len(con_mat))} 58 | 59 | 60 | def get_count_statistics(mol=None, get_stat_heads=False): 61 | ''' 62 | Collects atom, bond, and ring count statistics of a provided molecule 63 | 64 | Args: 65 | mol (utility_classes.Molecule): Molecule to be examined 66 | get_stat_heads (bool, optional): set True to only return the headers of 67 | gathered statistics (default: False) 68 | 69 | Returns: 70 | numpy.ndarray: (n_statistics x 1) array containing the gathered statistics. Use 71 | get_stat_heads parameter to obtain the corresponding row headers (where RX 72 | describes number of X-membered rings and CXC indicates the number of 73 | carbon-carbon bonds of order X etc.). 74 | ''' 75 | stat_heads = ['n_atoms', 'C', 'N', 'O', 'F', 'H', 'H1C', 'H1N', 76 | 'H1O', 'C1C', 'C2C', 'C3C', 'C1N', 'C2N', 'C3N', 'C1O', 77 | 'C2O', 'C1F', 'N1N', 'N2N', 'N1O', 'N2O', 'N1F', 'O1O', 78 | 'O1F', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R>8'] 79 | if get_stat_heads: 80 | return stat_heads 81 | if mol is None: 82 | return None 83 | key_idx_dict = dict(zip(stat_heads, range(len(stat_heads)))) 84 | stats = np.zeros((len(stat_heads), 1)) 85 | # process all bonds and store statistics about bond and ring counts 86 | bond_stats = mol.get_bond_stats() 87 | for key, value in bond_stats.items(): 88 | if key in key_idx_dict: 89 | idx = key_idx_dict[key] 90 | stats[idx, 0] = value 91 | # store simple statistics about number of atoms 92 | stats[key_idx_dict['n_atoms'], 0] = mol.n_atoms 93 | for key in ['C', 'N', 'O', 'F', 'H']: 94 | idx = key_idx_dict[key] 95 | charge = mol.type_charges[key] 96 | if charge in mol._unique_numbers: 97 | stats[idx, 0] = np.sum(mol.numbers == charge) 98 | return stats 99 | 100 | 101 | def preprocess_molecules(mol_idcs, source_db, valence, 102 | precompute_distances=True, remove_invalid=True, 103 | invalid_list=None, print_progress=False): 104 | ''' 105 | Checks the validity of selected molecules and collects atom, bond, 106 | and ring count statistics for the valid structures. Molecules are classified as 107 | invalid if they consist of disconnected parts or fail a valence check, where the 108 | valency constraints of all atoms in a molecule have to be satisfied (e.g. carbon 109 | has four bonds, nitrogen has three bonds etc.) 
110 | 111 | Args: 112 | mol_idcs (array): the indices of molecules from the source database that 113 | shall be examined 114 | source_db (str): full path to the source database (in ase.db sqlite format) 115 | valence (array): an array where the i-th entry contains the valency 116 | constraint of atoms with atomic charge i (e.g. a valency of 4 at array 117 | position 6 representing carbon) 118 | precompute_distances (bool, optional): if True, the pairwise distances between 119 | atoms in each molecule are computed and stored in the database (default: 120 | True) 121 | remove_invalid (bool, optional): if True, molecules that do not pass the 122 | valency or connectivity checks (or are on the invalid_list) are removed from 123 | the new database (default: True) 124 | invalid_list (list of int, optional): precomputed list containing indices of 125 | molecules that are marked as invalid (because they did not pass the 126 | valency or connectivity checks in earlier runs, default: None) 127 | print_progress (bool, optional): set True to print the progress in percent 128 | (default: False) 129 | 130 | Returns 131 | list of ase.Atoms: list of all valid molecules 132 | list of dict: list of corresponding dictionaries with data of each molecule 133 | numpy.ndarray: (n_statistics x n_valid_molecules) matrix with atom, bond, 134 | and ring count statistics 135 | list of int: list with indices of molecules that failed the valency check 136 | list of int: list with indices of molecules that consist of disconnected parts 137 | int: number of molecules processed 138 | ''' 139 | # initial setup 140 | count = 0 # count the number of invalid molecules 141 | disc = [] # store indices of disconnected molecules 142 | inval = [] # store indices of invalid molecules 143 | data_list = [] # store data fields of molecules for new db 144 | mols = [] # store molecules (as ase.Atoms objects) 145 | compressor = ConnectivityCompressor() # (de)compress sparse connectivity matrices 146 | stats = np.empty((len(get_count_statistics(get_stat_heads=True)), 0)) 147 | n_all = len(mol_idcs) 148 | 149 | with connect(source_db) as source_db: 150 | # iterate over provided indices 151 | for i in mol_idcs: 152 | i = int(i) 153 | # skip molecule if present in invalid_list and remove_invalid is True 154 | if remove_invalid and invalid_list is not None: 155 | if i in invalid_list: 156 | continue 157 | # get molecule from database 158 | row = source_db.get(i + 1) 159 | data = row.data 160 | at = row.toatoms() 161 | # get positions and atomic numbers 162 | pos = at.positions 163 | numbers = at.numbers 164 | # center positions (using center of mass) 165 | pos = pos - at.get_center_of_mass() 166 | # order atoms by distance to center of mass 167 | center_dists = np.sqrt(np.maximum(np.sum(pos ** 2, axis=1), 0)) 168 | idcs_sorted = np.argsort(center_dists) 169 | pos = pos[idcs_sorted] 170 | numbers = numbers[idcs_sorted] 171 | # update positions and atomic numbers accordingly in Atoms object 172 | at.positions = pos 173 | at.numbers = numbers 174 | # instantiate utility_classes.Molecule object 175 | mol = Molecule(pos, numbers) 176 | # get connectivity matrix (detecting bond orders with Open Babel) 177 | con_mat = mol.get_connectivity() 178 | # stop if molecule is disconnected (and therefore invalid) 179 | if remove_invalid: 180 | if is_disconnected(con_mat): 181 | count += 1 182 | disc += [i] 183 | continue 184 | 185 | # check if valency constraints of all atoms in molecule are satisfied: 186 | # since the detection of bond orders for the 
connectivity matrix with Open 187 | # Babel is unreliable for certain cases (e.g. some aromatic rings) we 188 | # try to fix it manually (with heuristics) or by reshuffling the atom 189 | # order (as the bond order detection of Open Babel is sensitive to the 190 | # order of atoms) 191 | nums = numbers 192 | random_ord = np.arange(len(numbers)) 193 | for _ in range(10): # try 10 times before dismissing as invalid 194 | if np.all(np.sum(con_mat, axis=0) == valence[nums]): 195 | # valency is correct -> mark as valid and stop check 196 | val = True 197 | break 198 | else: 199 | # try to fix bond orders using heuristics 200 | val = False 201 | con_mat = mol.get_fixed_connectivity() 202 | if np.all(np.sum(con_mat, axis=0) == valence[nums]): 203 | # valency is now correct -> mark as valid and stop check 204 | val = True 205 | break 206 | # shuffle atom order before checking valency again 207 | random_ord = np.random.permutation(range(len(pos))) 208 | mol = Molecule(pos[random_ord], numbers[random_ord]) 209 | con_mat = mol.get_connectivity() 210 | nums = numbers[random_ord] 211 | if remove_invalid: 212 | if not val: 213 | # stop if molecule is invalid (it failed the repeated valence checks) 214 | count += 1 215 | inval += [i] 216 | continue 217 | 218 | if precompute_distances: 219 | # calculate pairwise distances of atoms and store them in data 220 | dists = pdist(pos)[:, None] 221 | data.update({'dists': dists}) 222 | 223 | # store compressed connectivity matrix in data 224 | rand_ord_rev = np.argsort(random_ord) 225 | con_mat = con_mat[rand_ord_rev][:, rand_ord_rev] 226 | data.update( 227 | {'con_mat': compressor.compress(con_mat)}) 228 | 229 | # update atom, bond, and ring count statistics 230 | stats = np.hstack((stats, get_count_statistics(mol=mol))) 231 | 232 | # add results to the lists 233 | mols += [at] 234 | data_list += [data] 235 | 236 | # print progress if desired 237 | if print_progress: 238 | if i % 100 == 0: 239 | print('\033[K', end='\r', flush=True) 240 | print(f'{100 * (i + 1) / n_all:.2f}%', end='\r', flush=True) 241 | 242 | return mols, data_list, stats, inval, disc, count 243 | 244 | 245 | def _processing_worker(q_in, q_out, task): 246 | ''' 247 | Simple worker function that repeatedly fulfills a task using transmitted input and 248 | sends back the results until a stop signal is received. Can be used as target in 249 | a multiprocessing.Process object. 250 | 251 | Args: 252 | q_in (multiprocessing.Queue): queue to receive a list with data. The first 253 | entry signals whether worker can stop and the remaining entries are used as 254 | input arguments to the task function 255 | q_out (multiprocessing.Queue): queue to send results from task back 256 | task (callable function): function that is called using the received data 257 | ''' 258 | while True: 259 | data = q_in.get(True) # receive data 260 | if data[0]: # stop if stop signal is received 261 | break 262 | results = task(*data[1:]) # fulfill task with received data 263 | q_out.put(results) # send back results 264 | 265 | 266 | def _submit_jobs(qs_out, count, chunk_size, n_all, working_flag, 267 | n_per_thread): 268 | ''' 269 | Function that submits a job to preprocess molecules to every provided worker. 
270 | 271 | Args: 272 | qs_out (list of multiprocessing.Queue): queues used to send data to workers (one 273 | queue per worker) 274 | count (int): index of the earliest, not yet preprocessed molecule in the db 275 | chunk_size (int): number of molecules to be divided amongst workers 276 | n_all (int): total number of molecules in the db 277 | working_flag (array): flags indicating whether workers are running 278 | n_per_thread (int): number of molecules to be given to each thread 279 | 280 | Returns: 281 | numpy.ndarray: array with flags indicating whether workers got 282 | a job 283 | int: index of the new earliest, not yet preprocessed molecule in 284 | the db (after the submitted preprocessing jobs have been done) 285 | ''' 286 | # calculate indices of molecules that shall be preprocessed by workers 287 | idcs = np.arange(count, min(n_all, count + chunk_size)) 288 | start = 0 289 | for i, q in enumerate(qs_out): 290 | if start >= len(idcs): 291 | # stop if no more indices are left to submit 292 | break 293 | end = start + n_per_thread 294 | q.put((False, idcs[start:end])) # submit indices (and signal to not stop) 295 | working_flag[i] = 1 # set flag that current worker got a job 296 | start = end 297 | new_count = count + len(idcs) 298 | return working_flag, new_count 299 | 300 | 301 | def preprocess_dataset(datapath, valence_list, n_threads, n_mols_per_thread=100, 302 | logging_print=True, new_db_path=None, precompute_distances=True, 303 | remove_invalid=True, invalid_list=None): 304 | ''' 305 | Pre-processes all molecules of a dataset using the provided valency information. 306 | Multi-threading is used to speed up the process. 307 | Along with a new database containing the pre-processed molecules, a 308 | "input_db_invalid.txt" file holding the indices of removed molecules (which 309 | do not pass the valence or connectivity checks, omitted if remove_invalid is False) 310 | and a "new_db_statistics.npz" file (containing atom, bond, and ring count statistics 311 | for all molecules in the new database) are stored. 312 | 313 | Args: 314 | datapath (str): full path to dataset (ase.db database) 315 | valence_list (list): the valence of atom types in the form 316 | [type1 valence type2 valence ...] 
317 | n_threads (int): number of threads used (0 for no extra threads) 318 | n_mols_per_thread (int, optional): number of molecules processed by each 319 | thread at each iteration (default: 100) 320 | logging_print (bool, optional): set True to show output with logging.info 321 | instead of standard printing (default: True) 322 | new_db_path (str, optional): full path to new database where pre-processed 323 | molecules shall be stored (None to simply append "gen" to the name in 324 | datapath, default: None) 325 | precompute_distances (bool, optional): if True, the pairwise distances between 326 | atoms in each molecule are computed and stored in the database (default: 327 | True) 328 | remove_invalid (bool, optional): if True, molecules that do not pass the 329 | valency or connectivity check are removed from the new database (default: 330 | True) 331 | invalid_list (list of int, optional): precomputed list containing indices of 332 | molecules that are marked as invalid (because they did not pass the 333 | valency or connectivity checks in earlier runs, default: None) 334 | ''' 335 | # convert paths 336 | datapath = Path(datapath) 337 | if new_db_path is None: 338 | new_db_path = datapath.parent / (datapath.stem + 'gen.db') 339 | else: 340 | new_db_path = Path(new_db_path) 341 | 342 | # compute array where the valency constraint of atom type i is stored at entry i 343 | max_type = max(valence_list[::2]) 344 | valence = np.zeros(max_type + 1, dtype=int) 345 | valence[valence_list[::2]] = valence_list[1::2] 346 | 347 | def _print(x, end='\n', flush=False): 348 | if logging_print: 349 | logging.info(x) 350 | else: 351 | print(x, end=end, flush=flush) 352 | 353 | with connect(datapath) as db: 354 | n_all = db.count() 355 | if n_all == 0: 356 | _print('No molecules found in data base!') 357 | sys.exit(0) 358 | _print('\nPre-processing data...') 359 | if logging_print: 360 | _print(f'Processed: 0 / {n_all}...') 361 | else: 362 | _print(f'0.00%', end='', flush=True) 363 | 364 | # initial setup 365 | n_iterations = 0 366 | chunk_size = n_threads * n_mols_per_thread 367 | current = 0 368 | count = 0 # count number of discarded (invalid etc.) 
molecules 369 | disc = [] 370 | inval = [] 371 | stats = np.empty((len(get_count_statistics(get_stat_heads=True)), 0)) 372 | working_flag = np.zeros(n_threads, dtype=bool) 373 | start_time = time.time() 374 | if invalid_list is not None and remove_invalid: 375 | invalid_list = {*invalid_list} 376 | n_inval = len(invalid_list) 377 | else: 378 | n_inval = 0 379 | 380 | with connect(new_db_path) as new_db: 381 | 382 | if n_threads >= 1: 383 | # set up threads and queues 384 | threads = [] 385 | qs_in = [] 386 | qs_out = [] 387 | for i in range(n_threads): 388 | qs_in += [Queue(1)] 389 | qs_out += [Queue(1)] 390 | threads += \ 391 | [Process(target=_processing_worker, 392 | name=str(i), 393 | args=(qs_out[-1], 394 | qs_in[-1], 395 | lambda x: 396 | preprocess_molecules(x, 397 | datapath, 398 | valence, 399 | precompute_distances, 400 | remove_invalid, 401 | invalid_list)))] 402 | threads[-1].start() 403 | 404 | # submit first round of jobs 405 | working_flag, current = \ 406 | _submit_jobs(qs_out, current, chunk_size, n_all, 407 | working_flag, n_mols_per_thread) 408 | 409 | while np.any(working_flag == 1): 410 | n_iterations += 1 411 | 412 | # initialize new iteration 413 | results = [] 414 | 415 | # gather results 416 | for i, q in enumerate(qs_in): 417 | if working_flag[i]: 418 | results += [q.get()] 419 | working_flag[i] = 0 420 | 421 | # submit new jobs 422 | working_flag, current_new = \ 423 | _submit_jobs(qs_out, current, chunk_size, n_all, working_flag, 424 | n_mols_per_thread) 425 | 426 | # store gathered results 427 | for res in results: 428 | mols, data_list, _stats, _inval, _disc, _c = res 429 | for (at, data) in zip(mols, data_list): 430 | new_db.write(at, data=data) 431 | stats = np.hstack((stats, _stats)) 432 | inval += _inval 433 | disc += _disc 434 | count += _c 435 | 436 | # print progress 437 | if logging_print and n_iterations % 10 == 0: 438 | _print(f'Processed: {current:6d} / {n_all}...') 439 | elif not logging_print: 440 | _print('\033[K', end='\r', flush=True) 441 | _print(f'{100 * current / n_all:.2f}%', end='\r', 442 | flush=True) 443 | current = current_new # update current position in database 444 | 445 | # stop worker threads and join 446 | for i, q_out in enumerate(qs_out): 447 | q_out.put((True,)) 448 | threads[i].join() 449 | threads[i].terminate() 450 | if logging_print: 451 | _print(f'Processed: {n_all} / {n_all}...') 452 | 453 | else: 454 | results = preprocess_molecules(range(n_all), datapath, valence, 455 | precompute_distances, remove_invalid, 456 | invalid_list, print_progress=True) 457 | mols, data_list, stats, inval, disc, count = results 458 | for (at, data) in zip(mols, data_list): 459 | new_db.write(at, data=data) 460 | 461 | if not logging_print: 462 | _print('\033[K', end='\n', flush=True) 463 | _print(f'... successfully validated {n_all - count - n_inval} data ' 464 | f'points!', flush=True) 465 | if invalid_list is not None: 466 | _print(f'{n_inval} structures were removed because they are on the ' 467 | f'pre-computed list of invalid molecules!', flush=True) 468 | if len(disc)+len(inval) > 0: 469 | _print(f'CAUTION: Could not validate {len(disc)+len(inval)} additional ' 470 | f'molecules. 
These were also removed and their indices are ' 471 | f'appended to the list of invalid molecules stored at ' 472 | f'{datapath.parent / (datapath.stem + f"_invalid.txt")}', 473 | flush=True) 474 | np.savetxt(datapath.parent / (datapath.stem + f'_invalid.txt'), 475 | np.append(np.sort(list(invalid_list)), np.sort(inval + disc)), 476 | fmt='%d') 477 | elif remove_invalid: 478 | _print(f'Identified {len(disc)} disconnected structures, and {len(inval)} ' 479 | f'structures with invalid valence!', flush=True) 480 | np.savetxt(datapath.parent / (datapath.stem + f'_invalid.txt'), 481 | np.sort(inval + disc), fmt='%d') 482 | _print('\nCompressing and storing statistics with numpy...') 483 | np.savez_compressed(new_db_path.parent/(new_db_path.stem+f'_statistics.npz'), 484 | stats=stats, 485 | stat_heads=get_count_statistics(get_stat_heads=True)) 486 | 487 | end_time = time.time() - start_time 488 | m, s = divmod(end_time, 60) 489 | h, m = divmod(m, 60) 490 | h, m, s = int(h), int(m), int(s) 491 | _print(f'Done! Pre-processing needed {h:d}h{m:02d}m{s:02d}s.') 492 | 493 | 494 | if __name__ == '__main__': 495 | parser = get_parser() 496 | args = parser.parse_args() 497 | preprocess_dataset(**vars(args)) 498 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ⚠️ **_Disclaimer: This repository is deprecated and only meant for reproduction of the published results of G-SchNet on QM9. If you want to use custom data sets or build on top of our model, please refer to the [up-to-date implementation](https://github.com/atomistic-machine-learning/schnetpack-gschnet)._** 2 | 3 | # G-SchNet 4 | 5 | ![generated molecules](./images/example_molecules_1.png) 6 | Implementation of G-SchNet - a generative model for 3d molecular structures - accompanying the paper [_"Symmetry-adapted generation of 3d point sets for the targeted discovery of molecules"_](http://papers.nips.cc/paper/8974-symmetry-adapted-generation-of-3d-point-sets-for-the-targeted-discovery-of-molecules) published at NeurIPS 2019. 7 | 8 | G-SchNet generates molecules in an autoregressive fashion, placing one atom after another in 3d euclidean space. The model can be trained on data sets with molecules of variable size and composition. It only uses the positions and types of atoms in a molecule, needing no bond-based information such as molecular graphs. 9 | 10 | The code provided in this repository allows to train G-SchNet on the QM9 data set which consists of approximately 130k small molecules with up to nine heavy atoms from fluorine, oxygen, nitrogen, and carbon. We provide the pre- and post-processing routines used in our paper's experiments with QM9 molecules in order to make our results reproducible. Although the code and the following guide is very much tailored to QM9, we also provide a few template scripts with basic functionality that can serve as a solid starting point for the application of G-SchNet to other data sets with differently composed molecular structures. The corresponding description can be found at the bottom of this readme after the introduction to the QM9 scripts. 
11 |
12 | ### Requirements
13 | - schnetpack 0.3
14 | - pytorch >= 1.2
15 | - python >= 3.7
16 | - ASE >= 3.17.0
17 | - Open Babel 2.41
18 | - rdkit >= 2019.03.4.0
19 |
20 | The following commands will create a new conda environment called _"gschnet"_ and install all dependencies (tested on Ubuntu 18.04):
21 |
22 | conda create -n gschnet python=3.7 pytorch=1.5.1 torchvision cudatoolkit=10.2 ase=3.19.0 openbabel=2.4.1 rdkit=2019.09.2.0 -c pytorch -c openbabel -c defaults -c conda-forge
23 | conda activate gschnet
24 | pip install 'schnetpack==0.3'
25 |
26 | Replace _"cudatoolkit=10.2"_ with _"cpuonly"_ if you do not want to utilize a GPU for training/generation. However, we strongly recommend using a GPU if available.
27 |
28 |
29 | # Getting started with G-SchNet and QM9
30 | Clone the repository into your folder of choice:
31 |
32 | git clone https://github.com/atomistic-machine-learning/G-SchNet.git
33 |
34 | ### Training a model
35 | A model with the same settings as described in the paper can be trained by running gschnet_script.py with standard parameters:
36 |
37 | python ./G-SchNet/gschnet_script.py train gschnet ./data/ ./models/gschnet/ --split 50000 5000 --cuda
38 |
39 | The training data (QM9) is automatically downloaded and preprocessed if not present in ./data/ and the model will be stored in ./models/gschnet/.
40 | With _--split 50000 5000_, 50k molecules are used as the training set, 5k are used for validation, and the remaining structures are left out as a test set.
41 | We recommend training on a GPU, but you can remove _--cuda_ from the call to use the CPU instead. If your GPU has less than 16GB VRAM, you need to decrease the number of features (e.g. _--features 64_) or the depth of the network (e.g. _--interactions 6_).
42 |
43 | At the bottom of this page, we provide a model trained exactly as described above for download. Feel free to use it instead of training your own model.
44 |
45 | ### Generating molecules
46 | Running the script with the following arguments will generate 1000 molecules using the trained model at ./models/gschnet/ and store them in ./models/gschnet/generated/generated.mol_dict:
47 |
48 | python ./G-SchNet/gschnet_script.py generate gschnet ./models/gschnet/ 1000 --cuda
49 |
50 | Remove _--cuda_ from the call if you want to run on the CPU. Add _--show_gen_ to display the molecules with ASE after generation. If you are running into problems due to small VRAM, decrease the size of mini-batches during generation (e.g. _--chunk_size 500_, default is 1000).
51 |
52 | ### Filtering and analysis of generated molecules
53 | After generation, the generated molecules can be filtered for invalid and duplicate structures by running qm9_filter_generated.py:
54 |
55 | python ./G-SchNet/qm9_filter_generated.py ./models/gschnet/generated/generated.mol_dict --train_data_path ./data/qm9gen.db --model_path ./models/gschnet
56 |
57 | The script will print its progress and the gathered results. To store them in a file, please redirect the console output to a file (e.g. ./results.txt) and use the _--print_file_ argument when calling the script:
58 |
59 | python ./G-SchNet/qm9_filter_generated.py ./models/gschnet/generated/generated.mol_dict --train_data_path ./data/qm9gen.db --model_path ./models/gschnet --print_file >> ./results.txt
60 |
61 | The script checks the valency constraints (e.g. every hydrogen atom should have exactly one bond), the connectedness (i.e. all atoms in a molecule should be connected to each other via a path over bonds), and removes duplicates*.
The remaining valid structures are stored in an sqlite database with ASE (at ./models/gschnet/generated/generated_molecules.db) along with an .npz-file that records certain statistics (e.g. the number of rings of certain sizes, the number of single, double, and triple bonds, the index of the matching training/test data molecule etc. for each molecule, see tables below for an overview showing all stored statistics). 62 | 63 | *_Please note that, as described in the paper, we use molecular fingerprints and canonical smiles representations to identify duplicates which means that different spatial conformers corresponding to the same canonical smiles string are tagged as duplicates and removed in the process. Add '--filters valence disconnected' to the call in order to not remove but keep identified duplicates in the created database._ 64 | 65 | ### Displaying generated and QM9 training data molecules 66 | After filtering, all generated molecules stored in the sqlite database can be displayed with ASE as follows: 67 | 68 | python ./G-SchNet/display_molecules.py --data_path ./models/gschnet/generated/generated_molecules.db 69 | 70 | The script allows to query the generated molecules for structures with certain properties using _--select "selection string"_. The selection string has the general format _"Property,OperatorTarget"_ (e.g. _"C,>8"_ to filter for all molecules with more than eight carbon atoms where _"C"_ is the statistic counting the number of carbon atoms in a molecule, _">"_ is the operator, and _"8"_ is the target value). Multiple conditions can be combined to form one selection string using _"&"_ (e.g _"C,>8&R5,>0"_ to get all molecules with more than 8 carbon atoms and at least 1 ring of size 5). Furthermore, multiple selection strings may be provided such that multiple windows with molecule plots are opened (one per selection string). The available operators are _"<", "<=", "=", "!=", ">=",_ and _">"_. Properties may be summed together using _"+"_ (e.g. _"R5+R6,=1"_ to get molecules with exactly one ring of size 5 or 6). For a list of all available properties, see the tables below. 71 | 72 | An example call to display all generated molecules that consist of at least 7 carbon atoms and two rings of size 6 or 5 and to display all generated molecules that have at least 1 Fluorine atom: 73 | 74 | python ./G-SchNet/display_molecules.py --data_path ./models/gschnet/generated/generated_molecules.db --select "C,>=7&R5+R6,=2" "F,>=1" 75 | 76 | The same script can also be used to display molecules from the QM9 training database using _--train_data_path_: 77 | 78 | python ./G-SchNet/display_molecules.py --train_data_path ./data/qm9gen.db 79 | 80 | Note that displaying all ~130k molecules from the database is quite slow. However, the training database can also be queried in a similar manner by prepending _"training"_ to the selection string. For example, the following call will display all molecules from the QM9 database that have at least one Fluorine atom and not more than 5 other heavy atoms: 81 | 82 | python ./G-SchNet/display_molecules.py --train_data_path ./data/qm9gen.db --select "training F,>=1&C+N+O,<=5" 83 | 84 | Using _--train_ or _--test_ with _--data_path_ will display all generated molecules that match structures used for training or held out test data, respectively, and the corresponding reference molecules from the QM9 database if _--train_data_path_ is also provided. 
_--novel_ will display all generated molecules that match neither structures used for training nor held out test data. 85 | 86 | The indices of molecules matching the queries can be exported using _--export_to_dir_. For example, the indices of all molecules from QM9 with a HOMO-LUMO gap <= 4.5 eV will be stored as the numpy-readable file _training gap,<=4.5.npy_ at ./data/subsets/ by calling: 87 | 88 | python ./G-SchNet/display_molecules.py --train_data_path ./data/qm9gen.db --select "training gap,<=4.5" --export_to_dir ./data/subsets 89 | 90 | The following properties are available for __both generated molecules and structures in the QM9 training database__: 91 | 92 | | property | description | 93 | |---|---| 94 | | n_atoms | total number of atoms | 95 | | C, N, O, F, H | number of atoms of the respective type | 96 | | H1C, C2C, N1O, ... | number of covalent bonds of a certain kind (single, double, triple) between two specific atom types (the types are ordered by increasing nuclear charge, i.e. write C3N, not N3C) | 97 | | R3, ..., R8, R>8 | number of rings of a certain size (3-8, >8) | 98 | 99 | Additionally, the following properties can be used in selection strings for __generated molecules__: 100 | 101 | | property | description | 102 | |---|---| 103 | | known | whether the molecule is novel (0) or matches a structure used for training (1), used for validation (2), or from the held out test data (3) | 104 | | equals | the index of the matching molecule in the training database (if known is 1, 2, or 3) or -1 (if known is 0) | 105 | | n_duplicates | the number of times the particular molecule was generated (0 if duplicating is not -1) | 106 | | duplicating | this is -1 for all "original" structures (i.e. the first occurrence of a generated molecule) and the index of the original structure if the generated molecule is a duplicate (in the default settings, only original, i.e. unique, structures are stored in the database) | 107 | | valid | whether the molecule passed the validity check during filtering (i.e. the valency, connectedness, and uniqueness checks; in the default settings, only valid molecules are stored in the database) | 108 | 109 | Finally, molecules from the __QM9 training database__ can also be queried for properties available in the QM9 data set: 110 | 111 | | property | unit | description | 112 | |---|---|---| 113 | | dipole_moment | e*Ångström | length of the dipole moment | 114 | | isotropic_polarizability | Ångström³ | isotropic polarizability | 115 | | homo | eV | energy of highest occupied molecular orbital (HOMO) | 116 | | lumo | eV | energy of lowest unoccupied molecular orbital (LUMO) | 117 | | gap | eV | energy difference between the HOMO and LUMO (HOMO-LUMO gap) | 118 | | electronic_spatial_extent | Ångström² | electronic spatial extent | 119 | | zpve | eV | zero point vibrational energy | 120 | | energy_U0 | eV | internal energy at 0 K | 121 | | energy_U | eV | internal energy at 298.15 K | 122 | | enthalpy_H | eV | enthalpy at 298.15 K | 123 | | free_energy | eV | free energy at 298.15 K | 124 | | heat_capacity | cal/(mol K) | heat capacity at 298.15 K | 125 | 126 | All properties use the ASE-internal units and can therefore easily be converted with ASE. For example, you can get the dipole moment in Debye by multiplying it with 1/ase.units.Debye. Similarly, the isotropic polarizability can be converted to Bohr³ using 1/ase.units.Bohr³, and the electronic spatial extent may be obtained in Bohr² with 1/ase.units.Bohr².
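For example, assuming you have read a dipole moment, an isotropic polarizability, and an electronic spatial extent from the training database (the numbers below are made up), the conversions could look like this:

    from ase import units

    dipole_moment = 0.5               # e*Angstrom (ASE-internal units)
    isotropic_polarizability = 9.0    # Angstrom^3
    electronic_spatial_extent = 25.0  # Angstrom^2

    dipole_in_debye = dipole_moment / units.Debye                        # = dipole_moment * (1 / units.Debye)
    polarizability_in_bohr3 = isotropic_polarizability / units.Bohr**3
    spatial_extent_in_bohr2 = electronic_spatial_extent / units.Bohr**2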
127 | 128 | ### Training a biased model 129 | The generation of molecules with G-SchNet can be biased towards desired target properties of QM9 molecules. To this end, the weights of an already trained model are fine-tuned in a second training run where only a small number of molecules that exhibit the desired target property is used as training data. For example, we biased the generation towards molecules with a small HOMO-LUMO gap in our paper. We found that pre-training with a large and diverse set of molecules increases the robustness of the learned model (e.g. it increases the number of generated molecules that are valid structures) compared to training on the small subset directly. 130 | 131 | The second training run for biasing is started with the same script as the usual training but requires two additional parameters: the path to an already trained model that is used to initialize the weights, and the path to a file holding the indices of molecules that exhibit the desired target property. Such a file can be obtained using the display_molecules script (see the description in the section above, where we extract the indices of all molecules with a HOMO-LUMO gap smaller than 4.5 eV). 132 | 133 | Assuming there is a model already trained on 50k examples from QM9 at ./models/gschnet and a file with the indices of 3000 molecules that exhibit the desired target property at ./data/subsets/indices.npy, a biased model can be trained with the following call: 134 | 135 | python ./G-SchNet/gschnet_script.py train gschnet ./data/ ./models/biased_gschnet/ --split 2000 500 --cuda --pretrained_path ./models/gschnet --subset_path ./data/subsets/indices.npy 136 | 137 | The argument _--split_ needs to be adjusted according to the number of molecules available in the subset. Note that the learning rate parameters can also be adjusted with _--lr, --lr_decay, --lr_patience,_ and _--lr_min_, which are 1e-4, 0.5, 10, and 1e-6 per default, respectively. In our paper, we used these standard parameters for the pre-training as well as for the fine-tuning with respect to small HOMO-LUMO gaps, where we had 3.3k molecules for training and 0.5k for validation. If there are significantly fewer molecules exhibiting the target property, it could be better to decrease the learning rate for the fine-tuning step to prevent overfitting, as more information from the pre-trained weights is then retained. Conversely, if there is a larger subset of molecules with the target property, training G-SchNet directly on that subset might lead to similarly good results as starting from the pre-trained weights. 138 | 139 | After the training has converged, molecules can be sampled from the biased distribution and filtered afterwards just as described in the previous sections (but of course the path to the model directory ./models/gschnet needs to be replaced with ./models/biased_gschnet in the arguments when calling the scripts).
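The subset file passed via _--subset_path_ is simply an array of integer indices into the training database stored in NumPy's .npy format. If you already have suitable indices from another analysis, you could also write such a file yourself; a minimal sketch with placeholder indices:

    import os
    import numpy as np

    # placeholder indices of molecules that exhibit the target property
    indices = np.array([12, 57, 133, 2048], dtype=int)

    os.makedirs('./data/subsets', exist_ok=True)
    np.save('./data/subsets/indices.npy', indices)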
140 | 141 | # Applying G-SchNet to other data sets 142 | 143 | _Disclaimer: Since the code was mainly written to run experiments on QM9, experiments with other data sets will most likely need code adaptations. Although G-SchNet generalizes well to larger structures in theory, there are a few key points missing in the current implementation that might hinder its applicability to structures with many atoms (e.g. significantly more than 100). For example, all already placed atoms currently predict a distance to the new position at each step. Here it would be better to introduce a cutoff that limits the number of atoms used to predict distances to a smaller region around the focus token. Furthermore, we currently also do not use the provided cutoff to limit the number of neighbors that are examined in the SchNet feature extraction (i.e. we always use the full distance matrix in the continuous-filter convolutions instead of removing atoms from the computations that are masked by the cutoff anyway). Therefore, we recommend using our [re-implementation of G-SchNet](https://github.com/atomistic-machine-learning/schnetpack-gschnet) instead, which uses such cutoffs and is designed to work with larger molecules and custom data._ 144 | 145 | In the following, we describe the provided template scripts and how they can be adjusted in order to use our implementation of G-SchNet on data sets other than QM9. There are three relevant files: template_data.py, template_preprocess_dataset.py, and template_filter_generated.py. They take care of loading the data, pre-processing the data, and filtering molecules after generation, respectively. 146 | 147 | ### Adjusting the template data class 148 | 149 | The file template_data.py contains a template data class that loads molecules from an SQLite database assembled with ASE. It already provides all the necessary functionality to run with gschnet_script.py (e.g. using only subsets of data, splitting of the data set, initialisation of pre-processing etc.). However, a few basic properties of the used data set must be set as static class variables that can be found at the top of the class definition. These are the name of the original database file (db_name), the desired name of the database file after pre-processing (preprocessed_db_name), a list of all the atom types that occur in the database (available_atom_types), the valence constraints of these types (atom_types_valence, currently not used, can safely be set to None), and the minimum and maximum distance between two atoms that are considered as neighbors for the data set (radial_limits). Note that the setting for radial_limits determines which atoms are considered to be connected when calculating connectivity matrices during pre-processing as well as the extent of the radial grid around the focus token used during generation.
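For illustration, these class variables could look like the following for a hypothetical organic data set (all names and values below are placeholders and need to be adapted to your data):

    # illustrative settings only -- adapt them to your own data set
    db_name = 'my_dataset.db'                   # name of the original ASE database file
    preprocessed_db_name = 'my_dataset_gen.db'  # name of the database after pre-processing
    available_atom_types = [1, 6, 7, 8, 9]      # nuclear charges occurring in the data (H, C, N, O, F)
    atom_types_valence = [1, 4, 3, 2, 1]        # valence constraints (currently unused, may be None)
    radial_limits = [0.9, 1.7]                  # min/max distance (in Angstrom) between neighboring atoms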
150 | 151 | ### Pre-processing of data 152 | 153 | The file template_preprocess_dataset.py contains functions for very basic pre-processing of the data set. In contrast to the procedure that we implemented for QM9, we do not check for valency constraints here, as this can get very complicated when allowing atom types other than those in QM9. Furthermore, the detection of bonds with Open Babel is not very reliable, especially for the kekulization of aromatic structures containing nitrogen and carbon atoms. We provided heuristics to compensate for the shortcomings of the Open Babel implementation for QM9 molecules but cannot guarantee that they lead to meaningful results on substantially different molecular structures. Thus, you would need to implement your own strategy for valency checks if desired when using data sets other than QM9. 154 | 155 | The provided script only uses the limits specified under radial_limits in template_data.py to determine which atoms in a molecule are connected and stores the calculated connectivity matrices in the target database in a compressed format. As our generation procedure places new atoms only in the proximity of the focused atom, it is required that every pair of atoms in a molecule is connected by some path. Otherwise, the algorithm would not be able to generate the structure since there are disconnected parts that cannot be reached. Thus, we remove such disconnected structures. If you encounter a lot of removals due to disconnectedness, you should consider increasing the maximum value in radial_limits in the template data class. 156 | 157 | ### Filtering generated molecules 158 | 159 | Finally, the script template_filter_generated.py can be used to filter molecules after generation. As a simple filter example, we remove generated molecules in which two atoms are closer than 0.3 Ångström, since we consider them to be invalid. The script then only stores generated molecules in an ASE SQLite database such that they can be visualized with the display_molecules.py script, and it also generates a statistics file that includes very basic statistics (i.e. the number of atoms of a certain type). All other, more sophisticated filtering routines need to be implemented according to the specific data set and experimental setup. These could, for example, be valency checks or the identification of duplicate molecules. Again, we removed our QM9 implementation of those routines from the template as it cannot easily be generalized to arbitrarily composed molecules.
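For reference, the distance-based check described above amounts to something like the following (a simplified sketch, not necessarily the exact code in template_filter_generated.py):

    import numpy as np
    from scipy.spatial.distance import pdist

    def has_clashing_atoms(positions, threshold=0.3):
        # True if any two atoms in the (n_atoms x 3) position array are
        # closer to each other than `threshold` (in Angstrom)
        return bool(np.any(pdist(positions) < threshold))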
160 | 161 | ### Initializing training on a custom data set 162 | 163 | After setting the required class variables in template_data.py and, optionally, implementing further routines for pre-processing of the data set or for filtering of generated molecules, you can start training G-SchNet on your data set by adding _--dataset_name template_data_ to the training call: 164 | 165 | python ./G-SchNet/gschnet_script.py train gschnet ./data/ ./models/gschnet_my_data/ --dataset_name template_data --split 2000 200 --cuda 166 | 167 | Do not forget to adjust the number of training and validation molecules according to your data set with _--split #train_samples #val_samples_ (2000 and 200 in the example call above). In order to generate 1000 molecules and filter them afterwards, simply use the following calls: 168 | 169 | python ./G-SchNet/gschnet_script.py generate gschnet ./models/gschnet_my_data/ 1000 --cuda 170 | python ./G-SchNet/template_filter_generated.py ./models/gschnet_my_data/generated/generated.mol_dict 171 | 172 | The filtered molecules can then be viewed with ASE: 173 | 174 | ase gui ./models/gschnet_my_data/generated/generated_molecules.db 175 | 176 | ### Tweaking model and batch sizes 177 | 178 | For training on data sets with large molecules, or when using a GPU with less than 16GB VRAM, there are a few parameters that can be tweaked to manage the model and batch sizes. We recommend lowering the number of interaction layers in the feature extraction part, e.g. with _--interactions 6_ (default is 9), or lowering the number of features, e.g. with _--features 64_ (default is 128). Furthermore, the number of molecules considered per batch can be lowered, e.g. with _--batch_size 2_ (default is 5). However, the size of the batches also depends on the number of atoms in the molecules of the batch, as G-SchNet is an autoregressive model that predicts each atom placement step individually. In the default setup, we always sample a complete atom placement trajectory for each molecule of the batch, which means that batch sizes will become very large for molecules with many atoms. To mitigate this problem, you can set e.g. _--draw_random_samples 10_ to randomly draw ten atom placement steps for each molecule in the batch instead of sampling the whole trajectory. In this way, it should be possible to train G-SchNet on larger molecules than those in QM9, even though the cutoffs are not properly implemented yet (as explained in the disclaimer above). 179 | 180 | If you need help training G-SchNet on your own data, don't hesitate to open an issue or drop us an e-mail. 181 | 182 | 183 | # Citation 184 | If you are using G-SchNet in your research, please cite the corresponding paper: 185 | 186 | N. Gebauer, M. Gastegger, and K. Schütt. Symmetry-adapted generation of 3d point sets for the targeted discovery of molecules. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, _Advances in Neural Information Processing Systems 32_, pages 7566–7578. Curran Associates, Inc., 2019. 187 | 188 | @incollection{NIPS2019_8974, 189 | title = {Symmetry-adapted generation of 3d point sets for the targeted discovery of molecules}, 190 | author = {Gebauer, Niklas and Gastegger, Michael and Sch\"{u}tt, Kristof}, 191 | booktitle = {Advances in Neural Information Processing Systems 32}, 192 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, 193 | pages = {7566--7578}, 194 | year = {2019}, 195 | publisher = {Curran Associates, Inc.}, 196 | url = {http://papers.nips.cc/paper/8974-symmetry-adapted-generation-of-3d-point-sets-for-the-targeted-discovery-of-molecules.pdf} 197 | } 198 | 199 | # Trained G-SchNet model 200 | Here we provide an already trained G-SchNet model ready to be used for molecule generation or further fine-tuning and biasing. The model was trained as described in the paper, using the standard settings of gschnet_script.py and 50k structures from QM9 (as explained in "Training a model" above). Simply extract the folder "gschnet" from the provided zip-file into ./models and continue with the steps described in "Generating molecules" or "Training a biased model" from the guide above. 201 | We used an environment with pytorch 1.5.0, cudatoolkit 10.2, and schnetpack 0.3 for training. 202 | 203 | [Download here.](http://www.quantum-machine.org/data/trained_gschnet_model.zip) 204 | 205 | The QM9 training data is usually downloaded and pre-processed as a first step of the training script. If you use our trained model from here instead of training your own model, you might still need the training data (e.g. for visualization or filtering of generated molecules).
In this case, you can simply start a dummy training with zero epochs to initialize the data download and remove the dummy model afterwards: 206 | 207 | python ./G-SchNet/gschnet_script.py train gschnet ./data/ ./models/_dummy/ --split 1 1 --max_epochs 0 208 | rm -r ./models/_dummy 209 | 210 | ![more generated molecules](./images/example_molecules_2.png) 211 | -------------------------------------------------------------------------------- /gschnet_script.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import time 6 | from shutil import copyfile, rmtree 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.optim import Adam 12 | from torch.utils.data.sampler import RandomSampler 13 | from ase import Atoms 14 | import ase.visualize as asv 15 | 16 | import schnetpack as spk 17 | from schnetpack.utils import count_params, to_json, read_from_json 18 | from schnetpack import Properties 19 | from schnetpack.datasets import DownloadableAtomsData 20 | 21 | from nn_classes import AtomwiseWithProcessing, EmbeddingMultiplication,\ 22 | NormalizeAndAggregate, KLDivergence 23 | from utility_functions import boolean_string, collate_atoms, generate_molecules, \ 24 | update_dict, get_dict_count 25 | 26 | # add your own dataset classes here: 27 | from qm9_data import QM9gen 28 | from template_data import TemplateData 29 | dataset_name_to_class_mapping = {'qm9': QM9gen, 30 | 'template_data': TemplateData} 31 | 32 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) 33 | 34 | 35 | def get_parser(): 36 | """ Setup parser for command line arguments """ 37 | main_parser = argparse.ArgumentParser() 38 | 39 | ## command-specific 40 | cmd_parser = argparse.ArgumentParser(add_help=False) 41 | cmd_parser.add_argument('--cuda', help='Set flag to use GPU(s)', 42 | action='store_true') 43 | cmd_parser.add_argument('--parallel', 44 | help='Run data-parallel on all available GPUs ' 45 | '(specify with environment variable' 46 | + ' CUDA_VISIBLE_DEVICES)', 47 | action='store_true') 48 | cmd_parser.add_argument('--batch_size', type=int, 49 | help='Mini-batch size for training and prediction ' 50 | '(default: %(default)s)', 51 | default=5) 52 | cmd_parser.add_argument('--draw_random_samples', type=int, default=0, 53 | help='Only draw x generation steps per molecule ' 54 | 'in each batch (if x=0, all generation ' 55 | 'steps are included for each molecule,' 56 | 'default: %(default)s)') 57 | cmd_parser.add_argument('--checkpoint', type=int, default=-1, 58 | help='The checkpoint of the model that is going ' 59 | 'to be loaded for evaluation or generation ' 60 | '(set to -1 to load the best model ' 61 | 'according to validation error, ' 62 | 'default: %(default)s)') 63 | cmd_parser.add_argument('--precompute_distances', type=boolean_string, 64 | default='true', 65 | help='Store precomputed distances in the database ' 66 | 'during pre-processing (caution, has no effect if ' 67 | 'the dataset has already been downloaded, ' 68 | 'pre-processed, and stored before, ' 69 | 'default: %(default)s)') 70 | 71 | ## training 72 | train_parser = argparse.ArgumentParser(add_help=False, 73 | parents=[cmd_parser]) 74 | train_parser.add_argument('datapath', 75 | help='Path / destination of dataset '\ 76 | 'directory') 77 | train_parser.add_argument('modelpath', 78 | help='Destination for models and logs') 79 | train_parser.add_argument('--dataset_name', type=str, default='qm9', 80 | 
help=f'Name of the dataset used (choose from ' 81 | f'{list(dataset_name_to_class_mapping.keys())}, ' 82 | f'default: %(default)s)'), 83 | train_parser.add_argument('--subset_path', type=str, 84 | help='A path to a npy file containing indices ' 85 | 'of a subset of the data set at datapath ' 86 | '(default: %(default)s)', 87 | default=None) 88 | train_parser.add_argument('--seed', type=int, default=None, 89 | help='Set random seed for torch and numpy.') 90 | train_parser.add_argument('--overwrite', 91 | help='Remove previous model directory.', 92 | action='store_true') 93 | train_parser.add_argument('--pretrained_path', 94 | help='Start training from the pre-trained model at the ' 95 | 'provided path (reset optimizer parameters such as ' 96 | 'best loss and learning rate and create new split)', 97 | default=None) 98 | train_parser.add_argument('--split_path', 99 | help='Path/destination of npz with data splits', 100 | default=None) 101 | train_parser.add_argument('--split', 102 | help='Split into [train] [validation] and use ' 103 | 'remaining for testing', 104 | type=int, nargs=2, default=[None, None]) 105 | train_parser.add_argument('--max_epochs', type=int, 106 | help='Maximum number of training epochs ' 107 | '(default: %(default)s)', 108 | default=500) 109 | train_parser.add_argument('--lr', type=float, 110 | help='Initial learning rate ' 111 | '(default: %(default)s)', 112 | default=1e-4) 113 | train_parser.add_argument('--lr_patience', type=int, 114 | help='Epochs without improvement before reducing' 115 | ' the learning rate (default: %(default)s)', 116 | default=10) 117 | train_parser.add_argument('--lr_decay', type=float, 118 | help='Learning rate decay ' 119 | '(default: %(default)s)', 120 | default=0.5) 121 | train_parser.add_argument('--lr_min', type=float, 122 | help='Minimal learning rate ' 123 | '(default: %(default)s)', 124 | default=1e-6) 125 | train_parser.add_argument('--logger', 126 | help='Choose logger for training process ' 127 | '(default: %(default)s)', 128 | choices=['csv', 'tensorboard'], 129 | default='tensorboard') 130 | train_parser.add_argument('--log_every_n_epochs', type=int, 131 | help='Log metrics every given number of epochs ' 132 | '(default: %(default)s)', 133 | default=1) 134 | train_parser.add_argument('--checkpoint_every_n_epochs', type=int, 135 | help='Create checkpoint every given number of ' 136 | 'epochs' 137 | '(default: %(default)s)', 138 | default=25) 139 | train_parser.add_argument('--label_width_factor', type=float, 140 | help='A factor that is multiplied with the ' 141 | 'range between two distance bins in order ' 142 | 'to determine the width of the Gaussians ' 143 | 'used to obtain labels from distances ' 144 | '(set to 0. 
to use one-hot ' 145 | 'encodings of distances as labels, ' 146 | 'default: %(default)s)', 147 | default=0.1) 148 | 149 | ## evaluation 150 | eval_parser = argparse.ArgumentParser(add_help=False, parents=[cmd_parser]) 151 | eval_parser.add_argument('datapath', help='Path of dataset directory') 152 | eval_parser.add_argument('modelpath', help='Path of stored model') 153 | eval_parser.add_argument('--split', 154 | help='Evaluate trained model on given split', 155 | choices=['train', 'validation', 'test'], 156 | default=['test'], nargs='+') 157 | 158 | ## molecule generation 159 | gen_parser = argparse.ArgumentParser(add_help=False, parents=[cmd_parser]) 160 | gen_parser.add_argument('modelpath', help='Path of stored model') 161 | gen_parser.add_argument('amount_gen', type=int, 162 | help='The amount of generated molecules') 163 | gen_parser.add_argument('--show_gen', 164 | help='Whether to open plots of generated ' 165 | 'molecules for visual evaluation', 166 | action='store_true') 167 | gen_parser.add_argument('--chunk_size', type=int, 168 | help='The size of mini batches during generation ' 169 | '(default: %(default)s)', 170 | default=1000) 171 | gen_parser.add_argument('--max_length', type=int, 172 | help='The maximum number of atoms per molecule ' 173 | '(default: %(default)s)', 174 | default=35) 175 | gen_parser.add_argument('--file_name', type=str, 176 | help='The name of the file in which generated ' 177 | 'molecules are stored (please note that ' 178 | 'increasing numbers are appended to the file name ' 179 | 'if it already exists and that the extension ' 180 | '.mol_dict is automatically added to the chosen ' 181 | 'file name, default: %(default)s)', 182 | default='generated') 183 | gen_parser.add_argument('--store_unfinished', 184 | help='Store molecules which have not been ' 185 | 'finished after sampling max_length atoms', 186 | action='store_true') 187 | gen_parser.add_argument('--print_file', 188 | help='Use to limit the printing if results are ' 189 | 'written to a file instead of the console (' 190 | 'e.g. 
if running on a cluster)', 191 | action='store_true') 192 | gen_parser.add_argument('--temperature', type=float, 193 | help='The temperature T to use for sampling ' 194 | '(default: %(default)s)', 195 | default=0.1) 196 | 197 | # model-specific parsers 198 | model_parser = argparse.ArgumentParser(add_help=False) 199 | model_parser.add_argument('--aggregation_mode', type=str, default='sum', 200 | choices=['sum', 'avg'], 201 | help=' (default: %(default)s)') 202 | 203 | ####### G-SchNet ####### 204 | gschnet_parser = argparse.ArgumentParser(add_help=False, 205 | parents=[model_parser]) 206 | gschnet_parser.add_argument('--features', type=int, 207 | help='Size of atom-wise representation ' 208 | '(default: %(default)s)', 209 | default=128) 210 | gschnet_parser.add_argument('--interactions', type=int, 211 | help='Number of regular SchNet interaction ' 212 | 'blocks (default: %(default)s)', 213 | default=9) 214 | gschnet_parser.add_argument('--cutoff', type=float, default=10., 215 | help='Cutoff radius of local environment ' 216 | '(default: %(default)s)') 217 | gschnet_parser.add_argument('--num_gaussians', type=int, default=25, 218 | help='Number of Gaussians to expand distances ' 219 | '(default: %(default)s)') 220 | gschnet_parser.add_argument('--max_distance', type=float, default=15., 221 | help='Maximum distance covered by the discrete ' 222 | 'distributions over distances learned by ' 223 | 'the model ' 224 | '(default: %(default)s)') 225 | gschnet_parser.add_argument('--num_distance_bins', type=int, default=300, 226 | help='Number of bins used in the discrete ' 227 | 'distributions over distances learned by ' 228 | 'the model(default: %(default)s)') 229 | 230 | ## setup subparser structure 231 | cmd_subparsers = main_parser.add_subparsers(dest='mode', 232 | help='Command-specific ' 233 | 'arguments') 234 | cmd_subparsers.required = True 235 | subparser_train = cmd_subparsers.add_parser('train', help='Training help') 236 | subparser_eval = cmd_subparsers.add_parser('eval', help='Eval help') 237 | subparser_gen = cmd_subparsers.add_parser('generate', help='Generate help') 238 | 239 | train_subparsers = subparser_train.add_subparsers(dest='model', 240 | help='Model-specific ' 241 | 'arguments') 242 | train_subparsers.required = True 243 | train_subparsers.add_parser('gschnet', help='G-SchNet help', 244 | parents=[train_parser, gschnet_parser]) 245 | 246 | eval_subparsers = subparser_eval.add_subparsers(dest='model', 247 | help='Model-specific ' 248 | 'arguments') 249 | eval_subparsers.required = True 250 | eval_subparsers.add_parser('gschnet', help='G-SchNet help', 251 | parents=[eval_parser, gschnet_parser]) 252 | 253 | gen_subparsers = subparser_gen.add_subparsers(dest='model', 254 | help='Model-specific ' 255 | 'arguments') 256 | gen_subparsers.required = True 257 | gen_subparsers.add_parser('gschnet', help='G-SchNet help', 258 | parents=[gen_parser, gschnet_parser]) 259 | 260 | return main_parser 261 | 262 | 263 | def get_model(args, parallelize=False): 264 | # get SchNet layers for feature extraction 265 | representation =\ 266 | spk.representation.SchNet(n_atom_basis=args.features, 267 | n_filters=args.features, 268 | n_interactions=args.interactions, 269 | cutoff=args.cutoff, 270 | n_gaussians=args.num_gaussians, 271 | max_z=100) 272 | 273 | # get output layers for prediction of next atom type 274 | preprocess_type = \ 275 | EmbeddingMultiplication(representation.embedding, 276 | in_key_types='_all_types', 277 | in_key_representation='representation', 278 | 
out_key='preprocessed_representation') 279 | postprocess_type = NormalizeAndAggregate(normalize=True, 280 | normalization_axis=-1, 281 | normalization_mode='logsoftmax', 282 | aggregate=True, 283 | aggregation_axis=-2, 284 | aggregation_mode='sum', 285 | keepdim=False, 286 | mask='_type_mask', 287 | squeeze=True) 288 | out_module_type = \ 289 | AtomwiseWithProcessing(n_in=args.features, 290 | n_out=1, 291 | n_layers=5, 292 | preprocess_layers=preprocess_type, 293 | postprocess_layers=postprocess_type, 294 | out_key='type_predictions') 295 | 296 | # get output layers for predictions of distances 297 | preprocess_dist = \ 298 | EmbeddingMultiplication(representation.embedding, 299 | in_key_types='_next_types', 300 | in_key_representation='representation', 301 | out_key='preprocessed_representation') 302 | out_module_dist = \ 303 | AtomwiseWithProcessing(n_in=args.features, 304 | n_out=args.num_distance_bins, 305 | n_layers=5, 306 | preprocess_layers=preprocess_dist, 307 | out_key='distance_predictions') 308 | 309 | # combine layers into an atomistic model 310 | model = spk.atomistic.AtomisticModel(representation, 311 | [out_module_type, out_module_dist]) 312 | 313 | if parallelize: 314 | model = nn.DataParallel(model) 315 | 316 | logging.info("The model you built has: %d parameters" % 317 | count_params(model)) 318 | 319 | return model 320 | 321 | 322 | def train(args, model, train_loader, val_loader, device): 323 | 324 | # setup hooks and logging 325 | hooks = [ 326 | spk.hooks.MaxEpochHook(args.max_epochs) 327 | ] 328 | 329 | # filter for trainable parameters 330 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 331 | # setup optimizer 332 | optimizer = Adam(trainable_params, lr=args.lr) 333 | schedule = spk.hooks.ReduceLROnPlateauHook(optimizer, 334 | patience=args.lr_patience, 335 | factor=args.lr_decay, 336 | min_lr=args.lr_min, 337 | window_length=1, 338 | stop_after_min=True) 339 | hooks.append(schedule) 340 | 341 | # set up metrics to log KL divergence on distributions of types and distances 342 | metrics = [KLDivergence(target='_type_labels', 343 | model_output='type_predictions', 344 | name='KLD_types'), 345 | KLDivergence(target='_labels', 346 | model_output='distance_predictions', 347 | mask='_dist_mask', 348 | name='KLD_dists')] 349 | 350 | if args.logger == 'csv': 351 | logger =\ 352 | spk.hooks.CSVHook(os.path.join(args.modelpath, 'log'), 353 | metrics, 354 | every_n_epochs=args.log_every_n_epochs) 355 | hooks.append(logger) 356 | elif args.logger == 'tensorboard': 357 | logger =\ 358 | spk.hooks.TensorboardHook(os.path.join(args.modelpath, 'log'), 359 | metrics, 360 | every_n_epochs=args.log_every_n_epochs) 361 | hooks.append(logger) 362 | 363 | norm_layer = nn.LogSoftmax(-1).to(device) 364 | loss_layer = nn.KLDivLoss(reduction='none').to(device) 365 | 366 | # setup loss function 367 | def loss(batch, result): 368 | # loss for type predictions (KLD) 369 | out_type = norm_layer(result['type_predictions']) 370 | loss_type = loss_layer(out_type, batch['_type_labels']) 371 | loss_type = torch.sum(loss_type, -1) 372 | loss_type = torch.mean(loss_type) 373 | 374 | # loss for distance predictions (KLD) 375 | mask_dist = batch['_dist_mask'] 376 | N = torch.sum(mask_dist) 377 | out_dist = norm_layer(result['distance_predictions']) 378 | loss_dist = loss_layer(out_dist, batch['_labels']) 379 | loss_dist = torch.sum(loss_dist, -1) 380 | loss_dist = torch.sum(loss_dist * mask_dist) / torch.max(N, torch.ones_like(N)) 381 | 382 | return loss_type + loss_dist 383 | 
384 | # initialize trainer 385 | trainer = spk.train.Trainer(args.modelpath, 386 | model, 387 | loss, 388 | optimizer, 389 | train_loader, 390 | val_loader, 391 | hooks=hooks, 392 | checkpoint_interval=args.checkpoint_every_n_epochs, 393 | keep_n_checkpoints=10) 394 | 395 | # reset optimizer and hooks if starting from pre-trained model (e.g. for 396 | # fine-tuning) 397 | if args.pretrained_path is not None: 398 | logging.info('starting from pre-trained model...') 399 | # reset epoch and step 400 | trainer.epoch = 0 401 | trainer.step = 0 402 | trainer.best_loss = float('inf') 403 | # reset optimizer 404 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 405 | optimizer = Adam(trainable_params, lr=args.lr) 406 | trainer.optimizer = optimizer 407 | # reset scheduler 408 | schedule =\ 409 | spk.hooks.ReduceLROnPlateauHook(optimizer, 410 | patience=args.lr_patience, 411 | factor=args.lr_decay, 412 | min_lr=args.lr_min, 413 | window_length=1, 414 | stop_after_min=True) 415 | trainer.hooks[1] = schedule 416 | # remove checkpoints of pre-trained model 417 | rmtree(os.path.join(args.modelpath, 'checkpoints')) 418 | os.makedirs(os.path.join(args.modelpath, 'checkpoints')) 419 | # store first checkpoint 420 | trainer.store_checkpoint() 421 | 422 | # start training 423 | trainer.train(device) 424 | 425 | 426 | def evaluate(args, model, train_loader, val_loader, test_loader, device): 427 | header = ['Subset', 'distances KLD', 'types KLD'] 428 | 429 | metrics = [KLDivergence(target='_labels', 430 | model_output='distance_predictions', 431 | mask='_dist_mask'), 432 | KLDivergence(target='_type_labels', 433 | model_output='type_predictions')] 434 | 435 | results = [] 436 | if 'train' in args.split: 437 | results.append(['training'] + 438 | ['%.5f' % i for i in 439 | evaluate_dataset(metrics, model, 440 | train_loader, device)]) 441 | 442 | if 'validation' in args.split: 443 | results.append(['validation'] + 444 | ['%.5f' % i for i in 445 | evaluate_dataset(metrics, model, 446 | val_loader, device)]) 447 | 448 | if 'test' in args.split: 449 | results.append(['test'] + ['%.5f' % i for i in evaluate_dataset( 450 | metrics, model, test_loader, device)]) 451 | 452 | header = ','.join(header) 453 | results = np.array(results) 454 | 455 | np.savetxt(os.path.join(args.modelpath, 'evaluation.csv'), results, 456 | header=header, fmt='%s', delimiter=',') 457 | 458 | 459 | def evaluate_dataset(metrics, model, loader, device): 460 | for metric in metrics: 461 | metric.reset() 462 | 463 | for batch in loader: 464 | batch = { 465 | k: v.to(device) 466 | for k, v in batch.items() 467 | } 468 | result = model(batch) 469 | 470 | for metric in metrics: 471 | metric.add_batch(batch, result) 472 | 473 | results = [ 474 | metric.aggregate() for metric in metrics 475 | ] 476 | return results 477 | 478 | 479 | def generate(args, train_args, model, device): 480 | # generate molecules (in chunks) and print progress 481 | 482 | dataclass = dataset_name_to_class_mapping[train_args.dataset_name] 483 | types = sorted(dataclass.available_atom_types) # retrieve available atom types 484 | all_types = types + [types[-1] + 1] # add stop token to list (largest type + 1) 485 | start_token = types[-1] + 2 # define start token (largest type + 2) 486 | amount = args.amount_gen 487 | chunk_size = args.chunk_size 488 | if chunk_size >= amount: 489 | chunk_size = amount 490 | 491 | # set parameters for printing progress 492 | if int(amount / 10.) 
< chunk_size: 493 | step = chunk_size 494 | else: 495 | step = int(amount / 10.) 496 | increase = lambda x, y: y + step if x >= y else y 497 | thresh = step 498 | if args.print_file: 499 | progress = lambda x, y: print(f'Generated {x}.', flush=True) \ 500 | if x >= y else print('', end='', flush=True) 501 | else: 502 | progress = lambda x, y: print(f'\x1b[2K\rSuccessfully generated' 503 | f' {x}', end='', flush=True) 504 | 505 | # generate 506 | generated = {} 507 | left = args.amount_gen 508 | done = 0 509 | start_time = time.time() 510 | with torch.no_grad(): 511 | while left > 0: 512 | if left - chunk_size < 0: 513 | batch = left 514 | else: 515 | batch = chunk_size 516 | update_dict(generated, 517 | generate_molecules( 518 | batch, 519 | model, 520 | all_types=all_types, 521 | start_token=start_token, 522 | max_length=args.max_length, 523 | save_unfinished=args.store_unfinished, 524 | device=device, 525 | max_dist=train_args.max_distance, 526 | n_bins=train_args.num_distance_bins, 527 | radial_limits=dataclass.radial_limits, 528 | t=args.temperature) 529 | ) 530 | left -= batch 531 | done += batch 532 | n = np.sum(get_dict_count(generated, args.max_length)) 533 | progress(n, thresh) 534 | thresh = increase(n, thresh) 535 | print('') 536 | end_time = time.time() - start_time 537 | m, s = divmod(end_time, 60) 538 | h, m = divmod(m, 60) 539 | h, m, s = int(h), int(m), int(s) 540 | print(f'Time consumed: {h:d}:{m:02d}:{s:02d}') 541 | 542 | # sort keys in resulting dictionary 543 | generated = dict(sorted(generated.items())) 544 | 545 | # show generated molecules and print some statistics if desired 546 | if args.show_gen: 547 | ats = [] 548 | n_total_atoms = 0 549 | n_molecules = 0 550 | for key in generated: 551 | n = 0 552 | for i in range(len(generated[key][Properties.Z])): 553 | at = Atoms(generated[key][Properties.Z][i], 554 | positions=generated[key][Properties.R][i]) 555 | ats += [at] 556 | n += 1 557 | n_molecules += 1 558 | n_total_atoms += n * key 559 | asv.view(ats) 560 | print(f'Total number of atoms placed: {n_total_atoms} ' 561 | f'(avg {n_total_atoms / n_molecules:.2f})', flush=True) 562 | 563 | return generated 564 | 565 | 566 | def main(args): 567 | # set device (cpu or gpu) 568 | device = torch.device('cuda' if args.cuda else 'cpu') 569 | 570 | # store (or load) arguments 571 | argparse_dict = vars(args) 572 | jsonpath = os.path.join(args.modelpath, 'args.json') 573 | 574 | if args.mode == 'train': 575 | # overwrite existing model if desired 576 | if args.overwrite and os.path.exists(args.modelpath): 577 | rmtree(args.modelpath) 578 | logging.info('existing model will be overwritten...') 579 | 580 | # create model directory if it does not exist 581 | if not os.path.exists(args.modelpath): 582 | os.makedirs(args.modelpath) 583 | 584 | # get latest checkpoint of pre-trained model if a path was provided 585 | if args.pretrained_path is not None: 586 | model_chkpt_path = os.path.join(args.modelpath, 'checkpoints') 587 | pretrained_chkpt_path = os.path.join(args.pretrained_path, 'checkpoints') 588 | if os.path.exists(model_chkpt_path) \ 589 | and len(os.listdir(model_chkpt_path)) > 0: 590 | logging.info(f'found existing checkpoints in model directory ' 591 | f'({model_chkpt_path}), please use --overwrite or choose ' 592 | f'empty model directory to start from a pre-trained ' 593 | f'model...') 594 | logging.warning(f'will ignore pre-trained model and start from latest ' 595 | f'checkpoint at {model_chkpt_path}...') 596 | args.pretrained_path = None 597 | else: 598 | 
logging.info(f'fetching latest checkpoint from pre-trained model at ' 599 | f'{pretrained_chkpt_path}...') 600 | if not os.path.exists(pretrained_chkpt_path): 601 | logging.warning(f'did not find checkpoints of pre-trained model, ' 602 | f'will train from scratch...') 603 | args.pretrained_path = None 604 | else: 605 | chkpt_files = [f for f in os.listdir(pretrained_chkpt_path) 606 | if f.startswith("checkpoint")] 607 | if len(chkpt_files) == 0: 608 | logging.warning(f'did not find checkpoints of pre-trained ' 609 | f'model, will train from scratch...') 610 | args.pretrained_path = None 611 | else: 612 | epoch = max([int(f.split(".")[0].split("-")[-1]) 613 | for f in chkpt_files]) 614 | chkpt = os.path.join(pretrained_chkpt_path, 615 | "checkpoint-" + str(epoch) + ".pth.tar") 616 | if not os.path.exists(model_chkpt_path): 617 | os.makedirs(model_chkpt_path) 618 | copyfile(chkpt, os.path.join(model_chkpt_path, 619 | f'checkpoint-{epoch}.pth.tar')) 620 | 621 | # store arguments for training in model directory 622 | to_json(jsonpath, argparse_dict) 623 | train_args = args 624 | 625 | # set seed 626 | spk.utils.set_random_seed(args.seed) 627 | else: 628 | # load arguments used for training from model directory 629 | train_args = read_from_json(jsonpath) 630 | 631 | # load data for training/evaluation 632 | if args.mode in ['train', 'eval']: 633 | # find correct data class 634 | assert train_args.dataset_name in dataset_name_to_class_mapping, \ 635 | f'Could not find data class for dataset {train_args.dataset}. Please ' \ 636 | f'specify a correct dataset name!' 637 | dataclass = dataset_name_to_class_mapping[train_args.dataset_name] 638 | 639 | # load the dataset 640 | logging.info(f'{train_args.dataset_name} will be loaded...') 641 | subset = None 642 | if train_args.subset_path is not None: 643 | logging.info(f'Using subset from {train_args.subset_path}') 644 | subset = np.load(train_args.subset_path) 645 | subset = [int(i) for i in subset] 646 | if issubclass(dataclass, DownloadableAtomsData): 647 | data = dataclass(args.datapath, subset=subset, 648 | precompute_distances=args.precompute_distances, 649 | download=True if args.mode == 'train' else False) 650 | else: 651 | data = dataclass(args.datapath, subset=subset, 652 | precompute_distances=args.precompute_distances) 653 | 654 | # splits the dataset in test, val, train sets 655 | split_path = os.path.join(args.modelpath, 'split.npz') 656 | if args.mode == 'train': 657 | if args.split_path is not None: 658 | copyfile(args.split_path, split_path) 659 | 660 | logging.info('create splits...') 661 | data_train, data_val, data_test = data.create_splits(*train_args.split, 662 | split_file=split_path) 663 | 664 | logging.info('load data...') 665 | types = sorted(dataclass.available_atom_types) 666 | max_type = types[-1] 667 | # set up collate function according to args 668 | collate = lambda x: \ 669 | collate_atoms(x, 670 | all_types=types + [max_type+1], 671 | start_token=max_type+2, 672 | draw_samples=args.draw_random_samples, 673 | label_width_scaling=train_args.label_width_factor, 674 | max_dist=train_args.max_distance, 675 | n_bins=train_args.num_distance_bins) 676 | 677 | train_loader = spk.data.AtomsLoader(data_train, batch_size=args.batch_size, 678 | sampler=RandomSampler(data_train), 679 | num_workers=4, pin_memory=True, 680 | collate_fn=collate) 681 | val_loader = spk.data.AtomsLoader(data_val, batch_size=args.batch_size, 682 | num_workers=2, pin_memory=True, 683 | collate_fn=collate) 684 | 685 | # construct the model 686 | if 
args.mode == 'train' or args.checkpoint >= 0: 687 | model = get_model(train_args, parallelize=args.parallel) 688 | logging.info(f'running on {device}') 689 | 690 | # load model or checkpoint for evaluation or generation 691 | if args.mode in ['eval', 'generate']: 692 | if args.checkpoint < 0: # load best model 693 | logging.info(f'restoring best model') 694 | model = torch.load(os.path.join(args.modelpath, 'best_model')).to(device) 695 | else: 696 | logging.info(f'restoring checkpoint {args.checkpoint}') 697 | chkpt = os.path.join(args.modelpath, 'checkpoints', 698 | 'checkpoint-' + str(args.checkpoint) + '.pth.tar') 699 | state_dict = torch.load(chkpt) 700 | model.load_state_dict(state_dict['model'], strict=True) 701 | 702 | # execute training, evaluation, or generation 703 | if args.mode == 'train': 704 | logging.info("training...") 705 | train(args, model, train_loader, val_loader, device) 706 | logging.info("...training done!") 707 | 708 | elif args.mode == 'eval': 709 | logging.info("evaluating...") 710 | test_loader = spk.data.AtomsLoader(data_test, 711 | batch_size=args.batch_size, 712 | num_workers=2, 713 | pin_memory=True, 714 | collate_fn=collate) 715 | with torch.no_grad(): 716 | evaluate(args, model, train_loader, val_loader, test_loader, device) 717 | logging.info("... done!") 718 | 719 | elif args.mode == 'generate': 720 | logging.info(f'generating {args.amount_gen} molecules...') 721 | generated = generate(args, train_args, model, device) 722 | gen_path = os.path.join(args.modelpath, 'generated/') 723 | if not os.path.exists(gen_path): 724 | os.makedirs(gen_path) 725 | # get untaken filename and store results 726 | file_name = os.path.join(gen_path, args.file_name) 727 | if os.path.isfile(file_name + '.mol_dict'): 728 | expand = 0 729 | while True: 730 | expand += 1 731 | new_file_name = file_name + '_' + str(expand) 732 | if os.path.isfile(new_file_name + '.mol_dict'): 733 | continue 734 | else: 735 | file_name = new_file_name 736 | break 737 | with open(file_name + '.mol_dict', 'wb') as f: 738 | pickle.dump(generated, f) 739 | logging.info('...done!') 740 | else: 741 | logging.info(f'Unknown mode: {args.mode}') 742 | 743 | 744 | if __name__ == '__main__': 745 | parser = get_parser() 746 | args = parser.parse_args() 747 | main(args) 748 | -------------------------------------------------------------------------------- /utility_classes.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import re 3 | import numpy as np 4 | import openbabel as ob 5 | import pybel 6 | from multiprocessing import Process 7 | from rdkit import Chem 8 | from scipy.spatial.distance import squareform 9 | 10 | 11 | class Molecule: 12 | ''' 13 | Molecule class that allows to get statistics such as the connectivity matrix, 14 | molecular fingerprint, canonical smiles representation, or ring count given 15 | positions of atoms and their atomic numbers. Currently supports molecules made of 16 | carbon, nitrogen, oxygen, fluorine, and hydrogen (such as in the QM9 benchmark 17 | dataset). Mainly relies on routines from Open Babel and RdKit. 18 | 19 | Args: 20 | pos (numpy.ndarray): positions of atoms in euclidean space (n_atoms x 3) 21 | atomic_numbers (numpy.ndarray): list with nuclear charge/type of each atom 22 | (e.g. 1 for hydrogens, 6 for carbons etc.). 
23 | connectivity_matrix (numpy.ndarray, optional): optionally, a pre-calculated 24 | connectivity matrix (n_atoms x n_atoms) containing the bond order between 25 | atom pairs can be provided (default: None). 26 | store_positions (bool, optional): set True to store the positions of atoms in 27 | self.positions (only for convenience, not needed for computations, default: 28 | False). 29 | ''' 30 | 31 | type_infos = {1: {'name': 'H', 32 | 'n_bonds': 1}, 33 | 6: {'name': 'C', 34 | 'n_bonds': 4}, 35 | 7: {'name': 'N', 36 | 'n_bonds': 3}, 37 | 8: {'name': 'O', 38 | 'n_bonds': 2}, 39 | 9: {'name': 'F', 40 | 'n_bonds': 1}, 41 | } 42 | type_charges = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9} 43 | 44 | def __init__(self, pos, atomic_numbers, connectivity_matrix=None, 45 | store_positions=False): 46 | # set comparison metrics to None (will be computed just in time) 47 | self._fp = None 48 | self._fp_bits = None 49 | self._can = None 50 | self._mirror_can = None 51 | self._inchi_key = None 52 | self._bond_stats = None 53 | self._fixed_connectivity = False 54 | self._row_indices = {} 55 | self._obmol = None 56 | self._rings = None 57 | self._n_atoms_per_type = None 58 | self._connectivity = connectivity_matrix 59 | 60 | # set statistics 61 | self.n_atoms = len(pos) 62 | self.numbers = atomic_numbers 63 | self._unique_numbers = {*self.numbers} # set for fast query 64 | self.positions = pos 65 | if not store_positions: 66 | self._obmol = self.get_obmol() # create obmol before removing pos 67 | self.positions = None 68 | 69 | def sanity_check(self): 70 | ''' 71 | Check whether the sum of valence of all atoms can be divided by 2. 72 | 73 | Returns: 74 | bool: True if the test is passed, False otherwise 75 | ''' 76 | count = 0 77 | for atom in self.numbers: 78 | count += self.type_infos[atom]['n_bonds'] 79 | if count % 2 == 0: 80 | return True 81 | else: 82 | return False 83 | 84 | def get_obmol(self): 85 | ''' 86 | Retrieve the underlying Open Babel OBMol object. 87 | 88 | Returns: 89 | OBMol object: Open Babel OBMol representation 90 | ''' 91 | if self._obmol is None: 92 | if self.positions is None: 93 | print('Error, cannot create obmol without positions!') 94 | return 95 | if self.numbers is None: 96 | print('Error, cannot create obmol without atomic numbers!') 97 | return 98 | # use openbabel to infer bonds and bond order: 99 | obmol = ob.OBMol() 100 | obmol.BeginModify() 101 | 102 | # set positions and atomic numbers of all atoms in the molecule 103 | for p, n in zip(self.positions, self.numbers): 104 | obatom = obmol.NewAtom() 105 | obatom.SetAtomicNum(int(n)) 106 | obatom.SetVector(*p.tolist()) 107 | 108 | # infer bonds and bond order 109 | obmol.ConnectTheDots() 110 | obmol.PerceiveBondOrders() 111 | 112 | obmol.EndModify() 113 | self._obmol = obmol 114 | return self._obmol 115 | 116 | def get_fp(self): 117 | ''' 118 | Retrieve the molecular fingerprint (the path-based FP2 from Open Babel is used, 119 | which means that paths of length up to 7 are considered). 120 | 121 | Returns: 122 | pybel.Fingerprint object: moleculer fingerprint (use "fp1 | fp2" to 123 | calculate the Tanimoto coefficient of two fingerprints) 124 | ''' 125 | if self._fp is None: 126 | # calculate fingerprint 127 | self._fp = pybel.Molecule(self.get_obmol()).calcfp() 128 | return self._fp 129 | 130 | def get_fp_bits(self): 131 | ''' 132 | Retrieve the bits set in the molecular fingerprint. 
133 | 134 | Returns: 135 | Set of int: object containing the bits set in the molecular fingerprint 136 | ''' 137 | if self._fp_bits is None: 138 | self._fp_bits = {*self.get_fp().bits} 139 | return self._fp_bits 140 | 141 | def get_can(self): 142 | ''' 143 | Retrieve the canonical SMILES representation of the molecule. 144 | 145 | Returns: 146 | String: canonical SMILES string 147 | ''' 148 | if self._can is None: 149 | # calculate canonical SMILES 150 | self._can = pybel.Molecule(self.get_obmol()).write('can') 151 | return self._can 152 | 153 | def get_mirror_can(self): 154 | ''' 155 | Retrieve the canonical SMILES representation of the mirrored molecule (the 156 | z-coordinates are flipped). 157 | 158 | Returns: 159 | String: canonical SMILES string of the mirrored molecule 160 | ''' 161 | if self._mirror_can is None: 162 | # calculate canonical SMILES of mirrored molecule 163 | self._flip_z() # flip z to mirror molecule using x-y plane 164 | self._mirror_can = pybel.Molecule(self.get_obmol()).write('can') 165 | self._flip_z() # undo mirroring 166 | return self._mirror_can 167 | 168 | def get_inchi_key(self): 169 | ''' 170 | Retrieve the InChI-key of the molecule. 171 | 172 | Returns: 173 | String: InChI-key 174 | ''' 175 | if self._inchi_key is None: 176 | # calculate inchi key 177 | self._inchi_key = pybel.Molecule(self.get_obmol()).\ 178 | write('inchikey') 179 | return self._inchi_key 180 | 181 | def _flip_z(self): 182 | ''' 183 | Flips the z-coordinates of atom positions (to get a mirrored version of the 184 | molecule). 185 | ''' 186 | if self._obmol is None: 187 | self.get_obmol() 188 | for atom in ob.OBMolAtomIter(self._obmol): 189 | x, y, z = atom.x(), atom.y(), atom.z() 190 | atom.SetVector(x, y, -z) 191 | self._obmol.ConnectTheDots() 192 | self._obmol.PerceiveBondOrders() 193 | 194 | def get_connectivity(self): 195 | ''' 196 | Retrieve the connectivity matrix of the molecule. 197 | 198 | Returns: 199 | numpy.ndarray: (n_atoms x n_atoms) array containing the pairwise bond orders 200 | between atoms (0 for no bond). 201 | ''' 202 | if self._connectivity is None: 203 | # get connectivity matrix 204 | connectivity = np.zeros((self.n_atoms, len(self.numbers))) 205 | for atom in ob.OBMolAtomIter(self.get_obmol()): 206 | index = atom.GetIdx() - 1 207 | # loop over all neighbors of atom 208 | for neighbor in ob.OBAtomAtomIter(atom): 209 | idx = neighbor.GetIdx() - 1 210 | bond_order = neighbor.GetBond(atom).GetBO() 211 | #print(f'{index}-{idx}: {bond_order}') 212 | # do not count bonds between two hydrogen atoms 213 | if (self.numbers[index] == 1 and self.numbers[idx] == 1 214 | and bond_order > 0): 215 | bond_order = 0 216 | connectivity[index, idx] = bond_order 217 | self._connectivity = connectivity 218 | return self._connectivity 219 | 220 | def get_ring_counts(self): 221 | ''' 222 | Retrieve a list containing the sizes of rings in the symmetric smallest set 223 | of smallest rings (S-SSSR from RdKit) in the molecule (e.g. [5, 6, 5] for two 224 | rings of size 5 and one ring of size 6). 
225 | 226 | Returns: 227 | List of int: list with ring sizes 228 | ''' 229 | if self._rings is None: 230 | # calculate symmetric SSSR with RdKit using the canonical smiles 231 | # representation as input 232 | can = self.get_can() 233 | mol = Chem.MolFromSmiles(can) 234 | if mol is not None: 235 | ssr = Chem.GetSymmSSSR(mol) 236 | self._rings = [len(ssr[i]) for i in range(len(ssr))] 237 | else: 238 | self._rings = [] # cannot count rings 239 | return self._rings 240 | 241 | def get_n_atoms_per_type(self): 242 | ''' 243 | Retrieve the number of atoms in the molecule per type. 244 | 245 | Returns: 246 | numpy.ndarray: number of atoms in the molecule per type, where the order 247 | corresponds to the order specified in Molecule.type_infos 248 | ''' 249 | if self._n_atoms_per_type is None: 250 | _types = np.array(list(self.type_infos.keys()), dtype=int) 251 | self._n_atoms_per_type =\ 252 | np.bincount(self.numbers, minlength=np.max(_types)+1)[_types] 253 | return self._n_atoms_per_type 254 | 255 | def remove_unpicklable_attributes(self, restorable=True): 256 | ''' 257 | Some attributes of the class cannot be processed by pickle. This method 258 | allows to remove these attributes prior to pickling. 259 | 260 | Args: 261 | restorable (bool, optional): Set True to allow restoring the deleted 262 | attributes later on (default: True) 263 | ''' 264 | # set attributes which are not picklable (SwigPyObjects) to None 265 | if restorable and self.positions is None and self._obmol is not None: 266 | # store positions to allow restoring obmol object later on 267 | pos = [atom.coords for atom in pybel.Molecule(self._obmol).atoms] 268 | self.positions = np.array(pos) 269 | self._obmol = None 270 | self._fp = None 271 | 272 | def tanimoto_similarity(self, other_mol, use_bits=True): 273 | ''' 274 | Get the Tanimoto (fingerprint) similarity to another molecule. 275 | 276 | Args: 277 | other_mol (Molecule or pybel.Fingerprint/list of bits set): 278 | representation of the second molecule (if it is not a Molecule object, 279 | it needs to be a pybel.Fingerprint if use_bits is False and a list of bits 280 | set in the fingerprint if use_bits is True). 281 | use_bits (bool, optional): set True to calculate Tanimoto similarity 282 | from bits set in the fingerprint (default: True) 283 | 284 | Returns: 285 | float: Tanimoto similarity to the other molecule 286 | ''' 287 | if use_bits: 288 | a = self.get_fp_bits() 289 | b = other_mol.get_fp_bits() if isinstance(other_mol, Molecule) \ 290 | else other_mol 291 | n_equal = len(a.intersection(b)) 292 | if len(a) + len(b) == 0: # edge case with no set bits 293 | return 1. 294 | return n_equal / (len(a)+len(b)-n_equal) 295 | else: 296 | fp_other = other_mol.get_fp() if isinstance(other_mol, Molecule)\ 297 | else other_mol 298 | return self.get_fp() | fp_other 299 | 300 | def _update_bond_orders(self, idc_lists): 301 | ''' 302 | Updates the bond orders in the underlying OBMol object. 303 | 304 | Args: 305 | idc_lists (list of list of int): nested list containing bonds, i.e. 
pairs 306 | of row indices (list1) and column indices (list2) which shall be updated 307 | ''' 308 | con_mat = self.get_connectivity() 309 | self._obmol.BeginModify() 310 | for i in range(len(idc_lists[0])): 311 | idx1 = idc_lists[0][i] 312 | idx2 = idc_lists[1][i] 313 | obbond = self._obmol.GetBond(int(idx1+1), int(idx2+1)) 314 | obbond.SetBO(int(con_mat[idx1, idx2])) 315 | self._obmol.EndModify() 316 | 317 | # reset fingerprints etc 318 | self._fp = None 319 | self._can = None 320 | self._mirror_can = None 321 | self._inchi_key = None 322 | 323 | def get_fixed_connectivity(self, recursive_call=False): 324 | ''' 325 | Attempts to fix the connectivity matrix using some heuristics (as some valid 326 | QM9 molecules do not pass the valency check using the connectivity matrix 327 | obtained with Open Babel, which seems to have problems with assigning correct 328 | bond orders to aromatic rings containing Nitrogen). 329 | 330 | Args: 331 | recursive_call (bool, do not set True): flag that indicates a recursive 332 | call (used internally, do not set to True) 333 | 334 | Returns: 335 | numpy.ndarray: (n_atoms x n_atoms) array containing the pairwise bond orders 336 | between atoms (0 for no bond) after the attempted fix. 337 | ''' 338 | 339 | # if fix has already been attempted, return the connectivity matrix 340 | if self._fixed_connectivity: 341 | return self._connectivity 342 | 343 | # define helpers: 344 | # increases bond order between two atoms in connectivity matrix 345 | def increase_bond(con_mat, idx1, idx2): 346 | con_mat[idx1, idx2] += 1 347 | con_mat[idx2, idx1] += 1 348 | return con_mat 349 | 350 | # decreases bond order between two atoms in connectivity matrix 351 | def decrease_bond(con_mat, idx1, idx2): 352 | con_mat[idx1, idx2] -= 1 353 | con_mat[idx2, idx1] -= 1 354 | return con_mat 355 | 356 | # returns only the rows of the connectivity matrix corresponding to atoms of 357 | # certain types (and the indices of these atoms) 358 | def get_typewise_connectivity(con_mat, types): 359 | idcs = [] 360 | for type in types: 361 | idcs += list(self._get_row_idcs(type)) 362 | return con_mat[idcs], np.array(idcs).astype(int) 363 | 364 | # store old connectivity matrix for later comparison 365 | old_mat = self.get_connectivity().copy() 366 | 367 | # get connectivity matrix and find indices of N and C atoms 368 | con_mat = self.get_connectivity() 369 | if 6 not in self._unique_numbers and 7 not in self._unique_numbers: 370 | # do not attempt fixing if there is no carbon and no nitrogen 371 | return con_mat 372 | N_mat, N_idcs = get_typewise_connectivity(con_mat, [7]) 373 | C_mat, C_idcs = get_typewise_connectivity(con_mat, [6]) 374 | NC_idcs = np.hstack((N_idcs, C_idcs)) # indices of all N and C atoms 375 | NC_valences = self._get_valences()[NC_idcs] # array with valency constraints 376 | 377 | # return connectivity if valency constraints of N and C atoms are already met 378 | if np.all(np.sum(con_mat[NC_idcs], axis=1) == NC_valences): 379 | return con_mat 380 | 381 | # if a C or N atom is "overcharged" (total bond order too high) we decrease 382 | # double to single bonds between N-N or N-C until it is not overcharged anymore 383 | # (e.g. 
384 | if 7 in self._unique_numbers: # only necessary if molecule contains N
385 | for cur in NC_idcs:
386 | type = self.numbers[cur]
387 | if np.sum(con_mat[cur]) <= self.type_infos[type]['n_bonds']:
388 | continue
389 | if type == 6: # for carbon look only at nitrogen neighbors
390 | neighbors = self._get_neighbors(cur, types=[7], strength=2)
391 | else:
392 | neighbors = self._get_neighbors(cur, types=[6, 7],
393 | strength=2)
394 | for neighbor in neighbors:
395 | con_mat = decrease_bond(con_mat, cur, neighbor)
396 | self._connectivity = con_mat
397 | if np.sum(con_mat[cur]) == \
398 | self.type_infos[type]['n_bonds']:
399 | break
400 | 
401 | # get updated partial connectivity matrices for N and C
402 | N_mat, _ = get_typewise_connectivity(con_mat, [7])
403 | C_mat, _ = get_typewise_connectivity(con_mat, [6])
404 | 
405 | # increase total number of bonds by transferring the strength of a
406 | # double C-N bond to two neighboring bonds, if the involved atoms
407 | # are not yet saturated (e.g. H2C-H2C=N-H2C -> H2C=H2C-N=H2C)
408 | if (np.sum(N_mat) < len(N_idcs) * 3 or np.sum(C_mat) < len(C_idcs) * 4) \
409 | and 7 in self._unique_numbers:
410 | for cur in NC_idcs:
411 | type = self.numbers[cur]
412 | if sum(con_mat[cur]) >= self.type_infos[type]['n_bonds']:
413 | continue
414 | CN_nbors = self._get_CN_neighbors(cur)
415 | for nbor_1, nbor_2 in CN_nbors:
416 | if con_mat[nbor_1, nbor_2] <= 1:
417 | continue
418 | else:
419 | nbor_2_nbors = np.where(con_mat[nbor_2] == 1)[0]
420 | for nbor_2_nbor in nbor_2_nbors:
421 | nbor_2_nbor_type = self.numbers[nbor_2_nbor]
422 | if (np.sum(con_mat[nbor_2_nbor]) <
423 | self.type_infos[nbor_2_nbor_type]['n_bonds']):
424 | con_mat = increase_bond(con_mat, cur, nbor_1)
425 | con_mat = increase_bond(con_mat, nbor_2, nbor_2_nbor)
426 | con_mat = decrease_bond(con_mat, nbor_1, nbor_2)
427 | self._connectivity = con_mat
428 | 
429 | # increase bond strength between two undercharged neighbors C-N,
430 | # C-C or N-N (e.g. HN-CH2 -> HN=CH2, starting from those atoms with least
431 | # available neighbors if there are multiple undercharged neighbors)
432 | undercharged_pairs = True
433 | while (undercharged_pairs):
434 | NC_charges = np.sum(con_mat[NC_idcs], axis=1)
435 | undercharged = NC_idcs[np.where(NC_charges < NC_valences)[0]]
436 | partial_con_mat = con_mat[undercharged][:, undercharged]
437 | # if none of the undercharged atoms are neighbors, stop
438 | if np.sum(partial_con_mat) == 0:
439 | break
440 | # sort by number of undercharged neighbors
441 | n_nbors = np.sum(partial_con_mat > 0, axis=0)
442 | # mask indices with zero undercharged neighbors to ignore them when sorting
443 | n_nbors[np.where(n_nbors == 0)[0]] = 1000
444 | cur = np.argmin(n_nbors)
445 | cur_nbor = np.where(partial_con_mat[cur] > 0)[0][0]
446 | con_mat = increase_bond(con_mat, undercharged[cur], undercharged[cur_nbor])
447 | self._connectivity = con_mat
448 | 
449 | # if the molecule is still not valid, try to flip double bonds if an atom
450 | # forms a double bond and has at least one other neighbor that has too few bonds
451 | # (e.g.
C-N=C -> C=N-C) and repeat above heuristics with a recursive call of 452 | # this function 453 | if not recursive_call and \ 454 | not np.all(np.sum(con_mat[NC_idcs], axis=1) == NC_valences): 455 | changed = False 456 | candidates = np.where(np.any(con_mat[NC_idcs][:, NC_idcs] == 2, axis=0))[0] 457 | for cand in NC_idcs[candidates]: 458 | if np.sum(con_mat[cand, NC_idcs] == 2) == 0: 459 | continue 460 | NC_charges = np.sum(con_mat[NC_idcs], axis=1) 461 | undercharged = NC_charges < NC_valences 462 | uc_neighbors = np.logical_and(con_mat[cand, NC_idcs] == 1, undercharged) 463 | if np.any(uc_neighbors): 464 | uc_neighbor = NC_idcs[np.where(uc_neighbors)[0][0]] 465 | oc_neighbor = NC_idcs[ 466 | np.where(con_mat[cand, NC_idcs] == 2)[0][0]] 467 | con_mat = increase_bond(con_mat, cand, uc_neighbor) 468 | con_mat = decrease_bond(con_mat, cand, oc_neighbor) 469 | self._connectivity = con_mat 470 | changed = True 471 | if changed: 472 | self._connectivity = self.get_fixed_connectivity( 473 | recursive_call=True) 474 | 475 | # store that fixing the connectivity matrix has already been attempted 476 | if not recursive_call: 477 | self._fixed_connectivity = True 478 | if np.any(old_mat != self._connectivity): 479 | # update bond orders in underlying OBMol object (where they changed) 480 | self._update_bond_orders(np.where(old_mat != self._connectivity)) 481 | 482 | return self._connectivity 483 | 484 | def _get_valences(self): 485 | ''' 486 | Retrieve the valency constraints of all atoms in the molecule. 487 | 488 | Returns: 489 | numpy.ndarray: valency constraints (one per atom) 490 | ''' 491 | valence = [] 492 | for atom in self.numbers: 493 | valence += [self.type_infos[atom]['n_bonds']] 494 | return np.array(valence) 495 | 496 | def _get_CN_neighbors(self, idx): 497 | ''' 498 | For a focus atom of type K returns indices of atoms C (carbon) and N (nitrogen) 499 | on two-step paths of the form K-C-N (and K-C-C only for K=N since one atom 500 | needs to be nitrogen). 501 | 502 | Args: 503 | idx (int): the index of the focus atom from which paths are examined 504 | 505 | Returns: 506 | list of lists: list1[i] contains an index of a direct neighbor of the 507 | focus atom and list2[i] contains the index of a second neighbor on the 508 | i-th identified two-step path 509 | ''' 510 | con_mat = self.get_connectivity() 511 | nbors = con_mat[idx] > 0 512 | C_nbors = np.where(np.logical_and(self.numbers == 6, nbors))[0] 513 | type = self.numbers[idx] 514 | # mask types to exclude idx from neighborhood 515 | _numbers = self.numbers.copy() 516 | _numbers[idx] = 0 517 | CN_nbors = np.where(np.logical_and(_numbers == 7, con_mat[C_nbors] > 0)) 518 | CN_nbors = [(C_nbors[CN_nbors[0][i]], CN_nbors[1][i]) 519 | for i in range(len(CN_nbors[0]))] 520 | if type == 7: # for N atoms, also add C-C neighbors 521 | CC_nbors = np.where(np.logical_and( 522 | _numbers == 6, con_mat[C_nbors] > 0)) 523 | CC_nbors = [ 524 | (C_nbors[CC_nbors[0][i]], CC_nbors[1][i]) 525 | for i in range(len(CC_nbors[0]))] 526 | CN_nbors += CC_nbors 527 | return CN_nbors 528 | 529 | def _get_neighbors(self, idx, types=None, strength=1): 530 | ''' 531 | Retrieve the indices of neighbors of an atom. 
532 | 
533 | Args:
534 | idx (int): index of the atom
535 | types (list of int, optional): restrict the returned neighbors to
536 | contain only atoms of the specified types (set None to apply no type
537 | filter, default: None)
538 | strength (int, optional): restrict the returned neighbors to contain
539 | only atoms with a certain minimal bond order to the atom at idx
540 | (default: 1)
541 | 
542 | Returns:
543 | list of int: indices of all neighbors that meet the requirements
544 | '''
545 | con_mat = self.get_connectivity()
546 | neighbors = con_mat[idx] >= strength
547 | if types is not None:
548 | type_arr = np.zeros(len(neighbors)).astype(bool)
549 | for type in types:
550 | type_arr = np.logical_or(type_arr, self.numbers == type)
551 | return np.where(np.logical_and(neighbors, type_arr))[0]
552 | return np.where(neighbors)[0] # no type filter requested, return all neighbors
553 | def get_bond_stats(self):
554 | '''
555 | Retrieve the bond and ring count of the molecule. The bond count is
556 | calculated for every pair of types (e.g. C1N are all single bonds between
557 | carbon and nitrogen atoms in the molecule, C2N are all double bonds between
558 | such atoms etc.). The ring count is provided for rings from size 3 to 8 (R3,
559 | R4, ..., R8) and for rings greater than size eight (R>8).
560 | 
561 | Returns:
562 | dict (str->int): bond and ring counts
563 | '''
564 | if self._bond_stats is None:
565 | 
566 | # 1st analyze bonds
567 | unique_types = np.sort(list(self._unique_numbers))
568 | # get connectivity and read bonds from matrix
569 | con_mat = self.get_connectivity()
570 | d = {}
571 | for i, type1 in enumerate(unique_types):
572 | row_idcs = self._get_row_idcs(type1)
573 | n_bonds1 = self.type_infos[type1]['n_bonds']
574 | for type2 in unique_types[i:]:
575 | col_idcs = self._get_row_idcs(type2)
576 | n_bonds2 = self.type_infos[type2]['n_bonds']
577 | max_bond_strength = min(n_bonds1, n_bonds2)
578 | if n_bonds1 == n_bonds2: # exclude bond orders only possible in trivial two-atom molecules
579 | max_bond_strength -= 1
580 | for n in range(1, max_bond_strength + 1):
581 | id = self.type_infos[type1]['name'] + str(n) + \
582 | self.type_infos[type2]['name']
583 | d[id] = np.sum(con_mat[row_idcs][:, col_idcs] == n)
584 | if type1 == type2:
585 | d[id] = int(d[id]/2) # remove twice counted bonds
586 | 
587 | # 2nd analyze rings
588 | ring_counts = self.get_ring_counts()
589 | if len(ring_counts) > 0:
590 | ring_counts = np.bincount(np.array(ring_counts))
591 | n_bigger_8 = 0
592 | for i in np.nonzero(ring_counts)[0]:
593 | if i < 9:
594 | d[f'R{i}'] = ring_counts[i]
595 | else:
596 | n_bigger_8 += ring_counts[i]
597 | if n_bigger_8 > 0:
598 | d['R>8'] = n_bigger_8
599 | self._bond_stats = d
600 | 
601 | return self._bond_stats
602 | 
603 | def _get_row_idcs(self, type):
604 | '''
605 | Retrieve the indices of all atoms in the molecule corresponding to a selected
606 | type.
607 | 
608 | Args:
609 | type (int): the atom type (atomic number, e.g. 6 for carbon)
610 | 
611 | Returns:
612 | list of int: indices of all atoms with the selected type
613 | '''
614 | if type not in self._row_indices:
615 | self._row_indices[type] = np.where(self.numbers == type)[0]
616 | return self._row_indices[type]
617 | 
618 | 
619 | class ConnectivityCompressor():
620 | '''
621 | Utility class that provides methods to compress and decompress connectivity
622 | matrices.
623 | '''
624 | 
625 | def __init__(self):
626 | pass
627 | 
628 | def compress(self, connectivity_matrix):
629 | '''
630 | Compresses a single connectivity matrix.
631 | 632 | Args: 633 | connectivity_matrix (numpy.ndarray): array (n_atoms x n_atoms) 634 | containing the bond orders of bonds between atoms of a molecule 635 | 636 | Returns: 637 | dict (str/int->int): the length of the non-redundant connectivity 638 | matrix (list with upper triangular part) and the indices of that list for 639 | bond orders > 0 640 | ''' 641 | smaller = squareform(connectivity_matrix) # get list of upper triangular part 642 | d = {'n_entries': len(smaller)} # store length of list 643 | for i in np.unique(smaller).astype(int): # store indices per bond order > 0 644 | if i > 0: 645 | d[int(i)] = np.where(smaller == i)[0] 646 | return d 647 | 648 | def decompress(self, idcs_dict): 649 | ''' 650 | Retrieve the full (n_atoms x n_atoms) connectivity matrix from compressed 651 | format. 652 | 653 | Args: 654 | idcs_dict (dict str/int->int): compressed connectivity matrix 655 | (obtained with the compress method) 656 | 657 | Returns: 658 | numpy.ndarray: full connectivity matrix as an array of shape (n_atoms x 659 | n_atoms) 660 | ''' 661 | n_entries = idcs_dict['n_entries'] 662 | con_mat = np.zeros(n_entries) 663 | for i in idcs_dict: 664 | if isinstance(i, int) or i.isdigit(): 665 | con_mat[idcs_dict[i]] = int(i) 666 | return squareform(con_mat) 667 | 668 | def compress_batch(self, connectivity_batch): 669 | ''' 670 | Compress a batch of connectivity matrices. 671 | 672 | Args: 673 | connectivity_batch (list of numpy.ndarray): list of connectivity matrices 674 | 675 | Returns: 676 | list of dict: batch of compressed connectivity matrices (see compress) 677 | ''' 678 | dict_list = [] 679 | for matrix in connectivity_batch: 680 | dict_list += [self.compress(matrix)] 681 | return dict_list 682 | 683 | def decompress_batch(self, idcs_dict_batch): 684 | ''' 685 | Retrieve a list of full connectivity matrices from a batch of compressed 686 | connectivity matrices. 687 | 688 | Args: 689 | idcs_dict_batch (list of dict): list with compressed connectivity 690 | matrices 691 | 692 | Return: 693 | list numpy.ndarray: batch of full connectivity matrices (see decompress) 694 | ''' 695 | matrix_list = [] 696 | for idcs_dict in idcs_dict_batch: 697 | matrix_list += [self.decompress(idcs_dict)] 698 | return matrix_list 699 | 700 | 701 | class IndexProvider(): 702 | ''' 703 | Class which allows to filter a large set of molecules for desired structures 704 | according to provided statistics. The filtering is done using a selection string 705 | of the general format 'Statistics_nameDelimiterOperatorTarget_value' 706 | (e.g. 'C,>8' to filter for all molecules with more than eight carbon atoms where 707 | 'C' is the statistic counting the number of carbon atoms in a molecule, ',' is the 708 | delimiter, '>' is the operator, and '8' is the target value). 709 | 710 | Args: 711 | statistics (numpy.ndarray): 712 | statistics of all molecules where columns correspond to molecules and rows 713 | correspond to available statistics (n_statistics x n_molecules) 714 | row_headlines (numpy.ndarray): 715 | the names of the statistics stored in each row (e.g. 'F' for the number of 716 | fluorine atoms or 'R5' for the number of rings of size 5) 717 | default_filter (str, optional): 718 | the default behaviour of the filter if no operator and target value are 719 | given (e.g. 
filtering for 'F' will give all molecules with at least 1 720 | fluorine atom if default_filter='>0' or all molecules with exactly 2 721 | fluorine atoms if default_filter='==2', default: '>0') 722 | delimiter (str, optional): 723 | the delimiter used to separate names of statistics from the operator and 724 | target value in the selection strings (default: ',') 725 | ''' 726 | 727 | # dictionary mapping strings of available operators to corresponding function: 728 | op_dict = {'<': operator.lt, 729 | '<=': operator.le, 730 | '==': operator.eq, 731 | '=': operator.eq, 732 | '!=': operator.ne, 733 | '>': operator.gt, 734 | '>=': operator.ge} 735 | 736 | rel_re = re.compile('<=|<|={1,2}|!=|>=|>') # regular expression for operators 737 | num_re = re.compile('[\-]*[0-9]+[.]*[0-9]*') # regular expression for target values 738 | 739 | def __init__(self, statistics, row_headlines, default_filter='>0', delimiter=','): 740 | self.statistics = np.array(statistics) 741 | self.headlines = list(row_headlines) 742 | self.default_relation = self.rel_re.search(default_filter).group(0) 743 | self.default_number = float(self.num_re.search(default_filter).group(0)) 744 | self.delimiter = delimiter 745 | 746 | def get_selected(self, selection_str, idcs=None): 747 | ''' 748 | Retrieve the indices of all molecules which fulfill the selection criteria. 749 | The selection string is of the general format 750 | 'Statistics_nameDelimiterOperatorTarget_value' (e.g. 'C,>8' to filter for all 751 | molecules with more than eight carbon atoms where 'C' is the statistic counting 752 | the number of carbon atoms in a molecule, ',' is the delimiter, '>' is the 753 | operator, and '8' is the target value). 754 | 755 | The following operators are available: 756 | '<' 757 | '<=' 758 | '==' 759 | '!=' 760 | '>=' 761 | '>' 762 | 763 | The target value can be any positive or negative integer or float value. 764 | 765 | Multiple statistics can be summed using '+' (e.g. 'F+N,=0' gives all 766 | molecules that have no fluorine and no nitrogen atoms). 767 | 768 | Multiple filters can be concatenated using '&' (e.g. 'H,>8&C,=5' gives all 769 | molecules that have more than 8 hydrogen atoms and exactly 5 carbon atoms). 
770 | 
771 | Args:
772 | selection_str (str): string describing the criterion(s) for filtering (built
773 | as described above)
774 | idcs (numpy.ndarray, optional): if provided, only this subset of all
775 | molecules is filtered for structures fulfilling the selection criteria
776 | 
777 | Returns:
778 | list of int: indices of all the molecules in the dataset that fulfill the
779 | selection criterion(s)
780 | '''
781 | 
782 | delimiter = self.delimiter
783 | if idcs is None:
784 | idcs = np.arange(len(self.statistics[0])) # take all to begin with
785 | criterions = selection_str.split('&') # split criteria
786 | for criterion in criterions:
787 | rel_strs = criterion.split(delimiter)
788 | 
789 | # add multiple statistics if specified
790 | heads = rel_strs[0].split('+')
791 | statistics = self.statistics[self.headlines.index(heads[0])][idcs]
792 | for head in heads[1:]:
793 | statistics += self.statistics[self.headlines.index(head)][idcs]
794 | 
795 | if len(rel_strs) == 1:
796 | relation = self.op_dict[self.default_relation](
797 | statistics, self.default_number)
798 | elif len(rel_strs) == 2:
799 | rel = self.rel_re.search(rel_strs[1]).group(0)
800 | num = float(self.num_re.search(rel_strs[1]).group(0))
801 | relation = self.op_dict[rel](statistics, num)
802 | new_idcs = np.where(relation)[0]
803 | idcs = idcs[new_idcs]
804 | 
805 | return idcs
806 | 
807 | 
808 | class ProcessQ(Process):
809 | '''
810 | multiprocessing.Process subclass that runs a provided function using provided
811 | (keyword) arguments and puts the result into a provided multiprocessing.Queue
812 | object (such that the result of the function can easily be obtained by the host
813 | process).
814 | 
815 | Args:
816 | queue (multiprocessing.Queue): the queue into which the results of running
817 | the target function will be put (the object in the queue will be a tuple
818 | containing the provided name as first entry and the function return as
819 | second entry).
820 | name (str): name of the object (is returned as first value in the tuple put
821 | into the queue).
822 | target (callable object): the function that is executed in the process's run
823 | method
824 | args (list of any): sequential arguments target is called with
825 | kwargs (dict (str->any)): keyword arguments target is called with
826 | '''
827 | 
828 | def __init__(self, queue, name=None, target=None, args=(), kwargs={}):
829 | super(ProcessQ, self).__init__(None, target, name, args, kwargs)
830 | self._name = name
831 | self._q = queue
832 | self._target = target
833 | self._args = args
834 | self._kwargs = kwargs
835 | 
836 | def run(self):
837 | '''
838 | Method representing the process's activity.
839 | 
840 | Invokes the callable object passed as the target argument, if any, with
841 | sequential and keyword arguments taken from the args and kwargs arguments,
842 | respectively. Puts the string passed as name argument and the returned result
843 | of the callable object into the queue as (name, result).
844 | '''
845 | if self._target is not None:
846 | res = (self.name, self._target(*self._args, **self._kwargs))
847 | self._q.put(res)
848 | 
--------------------------------------------------------------------------------
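
For reference, the bit-based branch of Molecule.tanimoto_similarity above boils down to the standard Tanimoto formula on sets of fingerprint bits: the number of bits set in both molecules divided by the number of bits set in at least one of them. A minimal sketch with plain Python sets (the bit indices below are made up purely for illustration):

# sketch of the formula used by Molecule.tanimoto_similarity(use_bits=True);
# the bit indices are invented example data
a = {1, 5, 17, 42, 99}   # bits set in the fingerprint of molecule A
b = {5, 17, 42, 64}      # bits set in the fingerprint of molecule B
n_equal = len(a.intersection(b))  # bits shared by both fingerprints
if len(a) + len(b) == 0:          # edge case with no set bits at all
    similarity = 1.
else:
    similarity = n_equal / (len(a) + len(b) - n_equal)
print(similarity)  # 3 shared bits out of 6 distinct bits -> 0.5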
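
The ConnectivityCompressor class above stores only the length of the condensed (upper triangular) form of a bond-order matrix plus the indices of its non-zero entries per bond order. A small round-trip sketch, assuming utility_classes.py and its dependencies (NumPy, SciPy, Open Babel/pybel, RDKit) are importable and using a hand-written 4-atom toy matrix:

import numpy as np
from utility_classes import ConnectivityCompressor

# toy 4-atom bond-order matrix (0 = no bond); symmetric with zero diagonal
con_mat = np.array([[0, 1, 0, 0],
                    [1, 0, 2, 0],
                    [0, 2, 0, 1],
                    [0, 0, 1, 0]])
compressor = ConnectivityCompressor()
compressed = compressor.compress(con_mat)
# -> {'n_entries': 6, 1: array([0, 5]), 2: array([3])}
restored = compressor.decompress(compressed)  # full (4 x 4) matrix again
assert np.array_equal(restored, con_mat)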
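
In the same spirit, a minimal sketch of IndexProvider.get_selected, using invented statistics for five molecules (the row names follow the conventions described in the class docstring above; the counts are made up):

import numpy as np
from utility_classes import IndexProvider

row_headlines = ['C', 'F', 'H', 'R5']        # carbon, fluorine, hydrogen, 5-rings
statistics = np.array([[9, 4, 10, 7, 5],     # number of carbon atoms per molecule
                       [0, 1, 0, 0, 2],      # number of fluorine atoms
                       [12, 5, 9, 10, 8],    # number of hydrogen atoms
                       [1, 0, 2, 0, 1]])     # number of rings of size 5
provider = IndexProvider(statistics, row_headlines)
print(provider.get_selected('C,>8'))       # more than 8 carbon atoms -> [0 2]
print(provider.get_selected('H,>8&C,=5'))  # more than 8 H and exactly 5 C -> []
print(provider.get_selected('F+R5'))       # default filter '>0' on the sum -> [0 1 2 4]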
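
Finally, a hedged sketch of how ProcessQ might be used to run a function in a child process and collect its result from the shared queue; the target function count_heavy_atoms is a hypothetical stand-in and not part of the repository:

from multiprocessing import Queue
from utility_classes import ProcessQ

def count_heavy_atoms(numbers):
    # hypothetical example target: count non-hydrogen atoms
    return sum(1 for n in numbers if n > 1)

if __name__ == '__main__':
    queue = Queue()
    proc = ProcessQ(queue, name='mol_0', target=count_heavy_atoms,
                    args=([6, 6, 8, 1, 1, 1, 1, 1, 1],))
    proc.start()
    name, result = queue.get()  # blocks until the child puts ('mol_0', 3)
    proc.join()
    print(name, result)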