├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── isomerization
│   │   ├── isomerization_dataset.hd5f
│   │   ├── p_test.xyz
│   │   ├── p_train.xyz
│   │   ├── parser.py
│   │   ├── r_test.xyz
│   │   ├── r_train.xyz
│   │   └── ts_train.xyz
│   └── ts.hdf5
├── dataset_statistics.py
├── docker-compose.yml
├── runs
│   └── final
│       ├── from_scratch
│       │   ├── tsnet-distance
│       │   │   └── experiment.py
│       │   ├── tsnet-isom
│       │   │   └── experiment.py
│       │   ├── tsnet-shared-isom
│       │   │   └── experiment.py
│       │   ├── tsnet-shared
│       │   │   └── experiment.py
│       │   └── tsnet
│       │       └── experiment.py
│       └── pre_trained
│           ├── tsnet-energy
│           │   └── experiment.py
│           ├── tsnet-isom
│           │   └── experiment.py
│           ├── tsnet-shared-isom
│           │   └── experiment.py
│           ├── tsnet-shared
│           │   └── experiment.py
│           └── tsnet
│               └── experiment.py
├── setup.py
├── submit_experiments.py
├── tests
│   ├── __init__.py
│   ├── layer_tests
│   │   ├── conftest.py
│   │   ├── test_layers.py
│   │   └── test_subclass_models.py
│   └── tool_tests
│       ├── __init__.py
│       ├── conftest.py
│       ├── test_jobs
│       │   ├── __init__.py
│       │   ├── test_classifiers.py
│       │   ├── test_cross_validation.py
│       │   ├── test_job.py
│       │   ├── test_pipeline.py
│       │   ├── test_regression.py
│       │   └── test_search.py
│       └── test_loaders.py
├── tfn
│   ├── __init__.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── atomic_images.py
│   │   ├── layers.py
│   │   ├── molecular_layers.py
│   │   ├── radial_factories.py
│   │   ├── utility_layers.py
│   │   └── utils.py
│   └── tools
│       ├── __init__.py
│       ├── builders
│       │   ├── __init__.py
│       │   ├── builder.py
│       │   ├── cartesian_builder.py
│       │   ├── classifier_builder.py
│       │   ├── energy_builder.py
│       │   ├── force_builder.py
│       │   ├── missing_point_builder.py
│       │   ├── multi_trunk_builder.py
│       │   └── siamese_builder.py
│       ├── callbacks.py
│       ├── converters.py
│       ├── ingredients.py
│       ├── jobs
│       │   ├── __init__.py
│       │   ├── classification.py
│       │   ├── config_defaults.py
│       │   ├── cross_validate.py
│       │   ├── job.py
│       │   ├── keras_job.py
│       │   ├── load_model.py
│       │   ├── pipeline.py
│       │   ├── regression.py
│       │   └── search.py
│       ├── loaders
│       │   ├── __init__.py
│       │   ├── data_loader.py
│       │   ├── iso17_loader.py
│       │   ├── isom_loader.py
│       │   ├── qm9_loader.py
│       │   ├── sn2_loader.py
│       │   └── ts_loader.py
│       ├── loggers.py
│       └── radials.py
└── tutorials
    ├── cat_pic.png
    ├── cat_pic_rotated.png
    └── tutorial.ipynb
/.gitignore: -------------------------------------------------------------------------------- 1 | # Editor temporary/working/backup files # 2 | ######################################### 3 | .#* 4 | [#]*# 5 | *~ 6 | *$ 7 | *.bak 8 | *.diff 9 | .idea/ 10 | *.iml 11 | *.ipr 12 | *.iws 13 | *.org 14 | .project 15 | *.rej 16 | .settings/ 17 | .*.sw[nop] 18 | .sw[nop] 19 | *.tmp 20 | *.vim 21 | tags 22 | cscope.out 23 | # gnu global 24 | GPATH 25 | GRTAGS 26 | GSYMS 27 | GTAGS 28 | .vscode 29 | 30 | # Compiled source # 31 | ################### 32 | *.a 33 | *.com 34 | *.class 35 | *.dll 36 | *.exe 37 | *.o 38 | *.py[ocd] 39 | *.so 40 | 41 | # Packages # 42 | ############ 43 | # it's better to unpack these files and commit the raw source 44 | # git has its own built in compression methods 45 | *.7z 46 | *.bz2 47 | *.bzip2 48 | *.dmg 49 | *.gz 50 | *.iso 51 | *.jar 52 | *.rar 53 | *.tar 54 | *.tbz2 55 | *.tgz 56 | *.zip 57 | 58 | # Python files # 59 | ################ 60 | # setup.py working directory 61 | build 62 | # sphinx build directory 63 | _build 64 | # setup.py dist directory 65 | dist 66 | doc/build 67 | doc/sphinx/source/*.rst 68 | doc/cdoc/build 69 | # Egg metadata 70 | *.egg-info 71 | # The shelf plugin uses this dir 72 | ./.shelf 73 | MANIFEST 74 | 75 | # Paver generated files # 76 | ######################### 77 | /release 78 | 79 | # Logs and databases # 80 | ###################### 81 | *.log 82 | *.sql 83 | *.sqlite 84 | *.db
85 | 86 | # Patches # 87 | ########### 88 | *.patch 89 | *.diff 90 | 91 | # OS generated files # 92 | ###################### 93 | .DS_Store* 94 | .VolumeIcon.icns 95 | .fseventsd 96 | Icon? 97 | .gdb_history 98 | ehthumbs.db 99 | Thumbs.db 100 | 101 | # Docs 102 | ###################### 103 | docs/html 104 | 105 | # Custom Patterns 106 | ###################### 107 | examples/* 108 | *.h5 109 | *archived/ 110 | *storage/ 111 | logs/ 112 | mongo_db/ 113 | *.sh 114 | *.out 115 | *.pickle 116 | *.csv 117 | *.png 118 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Riley Jackson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## How to Install 2 | 3 | Requirements: `atomic-images`, `tensorflow 2.0` 4 | 5 | This installation guide assumes the user has pip installed and knows how to use it. This package depends on 6 | `atomic-images`, available [here](https://github.com/UPEIChemistry/atomic-images). 7 | Clone `atomic-images` using `git clone git@github.com:UPEIChemistry/atomic-images.git`, then check out the `tf_2.0` branch 8 | with 9 | `git checkout tf_2.0`. Once this branch is checked out, use `pip install -e ./atomic-images` to install the package. 10 | To install tensor-field-networks, start by cloning the repo with 11 | `git clone git@github.com:UPEIChemistry/tensor-field-networks.git`, 12 | then install it with pip: `pip install -e ./tensor-field-networks`. The setup.py script contained in this package 13 | should install tensorflow 2, numpy, and any other 'official' dependencies. Be sure to install `tensorflow-gpu==2.0.0` 14 | and CUDA/cuDNN if you intend to run this code on a GPU (recommended for the performance boost). 15 | 16 | # Tensor Field Networks 17 | 18 | Tensor Field Networks (TFN) are **Rotationally Equivariant Continuous Graph Convolution Neural Networks** capable of 19 | taking continuous 3D point-clouds (e.g. molecules) as input and making scalar, vector, and higher-order tensor 20 | predictions which rotate with the original input point-cloud ([Thomas et al., 2018](https://arxiv.org/abs/1802.08219)).
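Concretely, equivariance means that for a rotation matrix R, predicting on rotated coordinates yields the rotated prediction: f(Rx) = R f(x). The snippet below is a minimal sketch of how to verify this property numerically; it mirrors the equivariance test in `tests/layer_tests/test_subclass_models.py`, and assumes `model` is a trained TFN (e.g. the `Vector` subclass model defined in that test) mapping atomic numbers and cartesians to per-point vectors.

```python
import numpy as np
from tfn.layers.utils import rotation_matrix

# Assumed: `model` is a trained TFN taking [atomic_nums, cartesians] and
# returning per-point vectors (see tests/layer_tests/test_subclass_models.py).
z = np.random.randint(6, size=(2, 10))          # atomic numbers
r = np.random.rand(2, 10, 3).astype("float32")  # cartesian coordinates

rot_mat = rotation_matrix([1, 0, 0], theta=np.radians(45))

pred = model.predict([z, r])                           # f(x)
pred_rotated = model.predict([z, np.dot(r, rot_mat)])  # f(Rx)

# Equivariance check: f(Rx) == R f(x) to within numerical tolerance.
assert np.all(np.isclose(np.dot(pred, rot_mat), pred_rotated, rtol=0, atol=1.0e-5))
```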
21 | 22 | Ignoring the **continuous convolution** part, this means that TFNs can tell when an image has been 23 | rotated, something vanilla convolutional networks cannot. For example, a traditional conv. net trained to 24 | recognize cats on **non-rotated images** would not identify a cat in the second picture: 25 | 26 | ![cat](tutorials/cat_pic.png) ![cat_rotated](tutorials/cat_pic_rotated.png) 27 | 28 | A TFN, by contrast, will still identify the cat in the rotated image despite being trained only on images in a single orientation. To see a 29 | demonstration of this equivariance, and a further explanation of TFNs, check out the Jupyter notebook located in the 30 | `tutorials` directory. If the user is not familiar with Jupyter notebooks, they can read up on them 31 | [here](https://jupyter.readthedocs.io/en/latest/content-quickstart.html). 32 | -------------------------------------------------------------------------------- /data/isomerization/isomerization_dataset.hd5f: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UPEIChemistry/tensor-field-networks/5c25583ee4108a13af8e73eabd3c448f42cb70a0/data/isomerization/isomerization_dataset.hd5f -------------------------------------------------------------------------------- /data/isomerization/parser.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import h5py 5 | import numpy as np 6 | 7 | 8 | def split_xyz_file(path): 9 | with open(path) as file: 10 | lines = file.readlines() 11 | separated_cartesian_lines = consume_lines(lines) 12 | cartesians_collection = [] 13 | atomic_nums_collection = [] 14 | for cartesian_lines in separated_cartesian_lines: 15 | coordinates = convert_xyz_lines_to_list_of_numbers(cartesian_lines[2:]) 16 | cartesians, atomic_nums = coordinate_to_array(coordinates) 17 | cartesians_collection.append(cartesians) 18 | atomic_nums_collection.append(atomic_nums) 19 | return cartesians_collection, atomic_nums_collection 20 | 21 | 22 | def consume_lines(lines): 23 | def c(l: list):  # recursively peel one xyz block (count line + comment line + atom lines) off the front 24 | if len(l) > 0: 25 | i = int(l[0]) 26 | separated_cartesians.append(l[0 : i + 2]) 27 | l[0 : i + 2] = [] 28 | c(l) 29 | 30 | separated_cartesians = [] 31 | sys.setrecursionlimit(10000)  # one recursive call per molecule in the file 32 | c(lines) 33 | sys.setrecursionlimit(1500) 34 | return separated_cartesians 35 | 36 | 37 | def convert_xyz_lines_to_list_of_numbers(lines): 38 | coords = [] 39 | for l in lines: 40 | element, x, y, z = l.split() 41 | if not element.isdigit(): 42 | element = element_mapping()[element] 43 | else: 44 | element = int(element) 45 | coords.append((element, float(x), float(y), float(z))) 46 | return coords 47 | 48 | 49 | def coordinate_to_array(coordinates): 50 | coordinate_array = np.array(coordinates) 51 | cartesians = coordinate_array[:, 1:] 52 | atomic_nums = coordinate_array[:, :1].astype("int").reshape((-1, 1)) 53 | return cartesians, atomic_nums 54 | 55 | 56 | def pad_array(arr, atom_padding, value=np.nan): 57 | return np.pad( 58 | arr, 59 | ((0, atom_padding - arr.shape[0]), (0, 0)), 60 | mode="constant", 61 | constant_values=value, 62 | ) 63 | 64 | 65 | def element_mapping(): 66 | mapping = { 67 | "C": 6, 68 | "H": 1, 69 | "B": 5, 70 | "Br": 35, 71 | "Cl": 17, 72 | "D": 0, 73 | "F": 9, 74 | "I": 53, 75 | "N": 7, 76 | "O": 8, 77 | "P": 15, 78 | "S": 16, 79 | "Se": 34, 80 | "Si": 14, 81 | } 82 | reverse_mapping = dict([reversed(pair) for pair in mapping.items()]) 83 | mapping.update(reverse_mapping) 84 |
return mapping 85 | 86 | 87 | def parse(): 88 | names = ["p_train", "r_train", "ts_train"] 89 | paths = [Path(f"./{n}.xyz") for n in names] 90 | 91 | arrays = {} 92 | for path in paths: 93 | cartesians_list, atomic_nums_list = split_xyz_file(path) 94 | for i, (c, a) in enumerate(zip(cartesians_list, atomic_nums_list)): 95 | cartesians_list[i] = pad_array(c, 21) 96 | atomic_nums_list[i] = pad_array(a, 21, value=0) 97 | arrays[path.stem] = [ 98 | np.array(arr) for arr in (cartesians_list, atomic_nums_list) 99 | ] 100 | 101 | with h5py.File("./isomerization_dataset.hd5f", "w") as file: 102 | for name, (cartesians, atomic_nums) in arrays.items(): 103 | file.create_dataset(f"{name}/cartesians", data=cartesians) 104 | file.create_dataset( 105 | f"{name}/atomic_nums", 106 | data=np.nan_to_num(np.squeeze(atomic_nums, axis=-1), nan=0), 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | parse() 112 | -------------------------------------------------------------------------------- /data/ts.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UPEIChemistry/tensor-field-networks/5c25583ee4108a13af8e73eabd3c448f42cb70a0/data/ts.hdf5 -------------------------------------------------------------------------------- /dataset_statistics.py: -------------------------------------------------------------------------------- 1 | from tfn.tools.loaders import TSLoader 2 | import numpy as np 3 | 4 | 5 | def get_atomic_histogram_data(z): 6 | z = np.where(z == 0, np.nan, z) 7 | z = np.where(z == 1, np.nan, z) 8 | return [np.count_nonzero(z == i) for i in range(36)] 9 | 10 | 11 | loader = TSLoader( 12 | path="/home/riley/dev/python/data/ts.hdf5", splitting=None, map_points=False 13 | ) 14 | x, _ = loader.load_data(remove_noise=True, shuffle=False)[0] 15 | z, *_ = x 16 | 17 | print(f"Size data: {np.count_nonzero(np.where(z == 1, 0, z), axis=1)}") 18 | print(f"Type data: {get_atomic_histogram_data(z)}") 19 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | 4 | mongo: 5 | image: mongo 6 | ports: 7 | - 27017:27017 8 | environment: 9 | MONGO_INITDB_ROOT_USERNAME: sample 10 | MONGO_INITDB_ROOT_PASSWORD: password 11 | MONGO_INITDB_DATABASE: db 12 | expose: 13 | - 27017 14 | networks: 15 | - omniboard 16 | 17 | omniboard: 18 | image: vivekratnavel/omniboard:latest 19 | command: ["--mu", "mongodb://sample:password@mongo:27017/db?authSource=admin"] 20 | ports: 21 | - 9000:9000 22 | networks: 23 | - omniboard 24 | depends_on: 25 | - mongo 26 | 27 | networks: 28 | omniboard: -------------------------------------------------------------------------------- /runs/final/from_scratch/tsnet-distance/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import CrossValidate 3 | 4 | 5 | job = CrossValidate( 6 | exp_config={ 7 | "name": f"{Path(__file__).parent}", 8 | "notes": "", 9 | "seed": 1, 10 | "run_config": {"epochs": 1000, "test": False, "batch_size": 48}, 11 | "loader_config": { 12 | "loader_type": "ts_loader", 13 | "splitting": 5, 14 | "load_kwargs": { 15 | "remove_noise": True, 16 | "shuffle": False, 17 | "output_distance_matrix": True, 18 | }, 19 | }, 20 | "builder_config": { 21 | "builder_type": "cartesian_builder", 22 | "radial_factory": "multi_dense", 23 | "prediction_type": "cartesians", 
24 | "output_type": "distance_matrix", 25 | }, 26 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 27 | } 28 | ) 29 | job.run() 30 | -------------------------------------------------------------------------------- /runs/final/from_scratch/tsnet-isom/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import StructurePrediction 3 | 4 | 5 | job = StructurePrediction( 6 | exp_config={ 7 | "name": f"{Path(__file__).parent}", 8 | "notes": "", 9 | "seed": 1, 10 | "run_config": { 11 | "epochs": 100, 12 | "test": True, 13 | "batch_size": 32, 14 | }, 15 | "loader_config": { 16 | "loader_type": "isom_loader", 17 | "path": "/home/rjackson/dev/tensor-field-networks/data/isomerization/isomerization_dataset.hd5f", 18 | "splitting": "75:20:5", 19 | }, 20 | "builder_config": { 21 | "builder_type": "cartesian_builder", 22 | "radial_factory": "multi_dense", 23 | "prediction_type": "cartesians", 24 | "output_type": "cartesians", 25 | }, 26 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 27 | } 28 | ) 29 | job.run() 30 | -------------------------------------------------------------------------------- /runs/final/from_scratch/tsnet-shared-isom/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import StructurePrediction 3 | 4 | 5 | job = StructurePrediction( 6 | exp_config={ 7 | "name": f"{Path(__file__).parent}", 8 | "notes": "", 9 | "seed": 1, 10 | "run_config": { 11 | "epochs": 100, 12 | "test": True, 13 | "batch_size": 32, 14 | }, 15 | "loader_config": { 16 | "loader_type": "isom_loader", 17 | "path": "/home/rjackson/dev/tensor-field-networks/data/isomerization/isomerization_dataset.hd5f", 18 | "splitting": "75:20:5", 19 | }, 20 | "builder_config": { 21 | "builder_type": "cartesian_builder", 22 | "radial_factory": "single_dense", 23 | "prediction_type": "cartesians", 24 | "output_type": "cartesians", 25 | }, 26 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 27 | } 28 | ) 29 | job.run() 30 | -------------------------------------------------------------------------------- /runs/final/from_scratch/tsnet-shared/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import CrossValidate 3 | 4 | 5 | job = CrossValidate( 6 | exp_config={ 7 | "name": f"{Path(__file__).parent}", 8 | "notes": "", 9 | "seed": 1, 10 | "run_config": {"epochs": 1000, "test": False, "batch_size": 48}, 11 | "loader_config": { 12 | "loader_type": "ts_loader", 13 | "splitting": 5, 14 | "load_kwargs": {"remove_noise": True, "shuffle": False}, 15 | }, 16 | "builder_config": { 17 | "builder_type": "cartesian_builder", 18 | "radial_factory": "single_dense", 19 | "prediction_type": "cartesians", 20 | "output_type": "cartesians", 21 | }, 22 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 23 | } 24 | ) 25 | job.run() 26 | -------------------------------------------------------------------------------- /runs/final/from_scratch/tsnet/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import CrossValidate 3 | 4 | 5 | job = CrossValidate( 6 | exp_config={ 7 | "name": f"{Path(__file__).parent}", 8 | "notes": "", 9 | "seed": 1, 10 | "run_config": {"epochs": 1000, "test": False, "batch_size": 
48}, 11 | "loader_config": { 12 | "path": "../../../../data/ts.hdf5", 13 | "loader_type": "ts_loader", 14 | "splitting": 5, 15 | "load_kwargs": {"remove_noise": True, "shuffle": False}, 16 | }, 17 | "builder_config": { 18 | "builder_type": "cartesian_builder", 19 | "radial_factory": "multi_dense", 20 | "prediction_type": "cartesians", 21 | "output_type": "cartesians", 22 | }, 23 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 24 | } 25 | ) 26 | job.run() 27 | -------------------------------------------------------------------------------- /runs/final/pre_trained/tsnet-energy/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import Pipeline, CrossValidate, Regression 3 | 4 | 5 | job = Pipeline( 6 | exp_config={"name": f"{Path(__file__).parent}", "seed": 1}, 7 | jobs=[ 8 | Regression( 9 | exp_config={ 10 | "name": f"{Path(__file__).parent} QM9", 11 | "seed": 1, 12 | "run_config": {"epochs": 50, "test": False,}, 13 | "loader_config": { 14 | "loader_type": "qm9_loader", 15 | "splitting": "90:10:0", 16 | "map_points": False, 17 | "load_kwargs": {"custom_maxz": 36}, 18 | }, 19 | "builder_config": {"builder_type": "energy_builder"}, 20 | } 21 | ), 22 | CrossValidate( 23 | exp_config={ 24 | "name": f"{Path(__file__).parent} TS", 25 | "seed": 1, 26 | "run_config": {"epochs": 1000, "test": False, "batch_size": 48,}, 27 | "loader_config": { 28 | "loader_type": "ts_loader", 29 | "splitting": 5, 30 | "map_points": False, 31 | "load_kwargs": {"remove_noise": True, "shuffle": False}, 32 | }, 33 | "builder_config": { 34 | "builder_type": "cartesian_builder", 35 | "radial_factory": "multi_dense", 36 | "prediction_type": "cartesians", 37 | "output_type": "cartesians", 38 | }, 39 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 40 | } 41 | ), 42 | ], 43 | ) 44 | job.run() 45 | -------------------------------------------------------------------------------- /runs/final/pre_trained/tsnet-isom/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import Pipeline, CrossValidate, StructurePrediction 3 | 4 | 5 | job = Pipeline( 6 | exp_config={"name": f"{Path(__file__).parent}", "seed": 1}, 7 | jobs=[ 8 | StructurePrediction( 9 | exp_config={ 10 | "name": f"{Path(__file__).parent} QM9", 11 | "seed": 1, 12 | "run_config": { 13 | "epochs": 100, 14 | "test": True, 15 | "batch_size": 32 16 | }, 17 | "loader_config": { 18 | "loader_type": "isom_loader", 19 | "path": "/home/rjackson/dev/tensor-field-networks/data/isomerization/isomerization_dataset" 20 | ".hd5f", 21 | "splitting": "75:20:5", 22 | }, 23 | "builder_config": { 24 | "builder_type": "cartesian_builder", 25 | "radial_factory": "multi_dense", 26 | "prediction_type": "cartesians", 27 | "output_type": "cartesians", 28 | }, 29 | } 30 | ), 31 | CrossValidate( 32 | exp_config={ 33 | "name": f"{Path(__file__).parent} TS", 34 | "seed": 1, 35 | "run_config": {"epochs": 1000, "test": False, "batch_size": 48,}, 36 | "loader_config": { 37 | "loader_type": "ts_loader", 38 | "splitting": 5, 39 | "map_points": False, 40 | "load_kwargs": {"remove_noise": True, "shuffle": False}, 41 | }, 42 | "builder_config": { 43 | "builder_type": "cartesian_builder", 44 | "radial_factory": "multi_dense", 45 | "prediction_type": "cartesians", 46 | "output_type": "cartesians", 47 | }, 48 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 49 | } 50
| ), 51 | ], 52 | ) 53 | job.run() 54 | -------------------------------------------------------------------------------- /runs/final/pre_trained/tsnet-shared-isom/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import Pipeline, CrossValidate, StructurePrediction 3 | 4 | 5 | job = Pipeline( 6 | exp_config={"name": f"{Path(__file__).parent}", "seed": 1}, 7 | jobs=[ 8 | StructurePrediction( 9 | exp_config={ 10 | "name": f"{Path(__file__).parent} QM9", 11 | "seed": 1, 12 | "run_config": { 13 | "epochs": 100, 14 | "test": True, 15 | "batch_size": 32 16 | }, 17 | "loader_config": { 18 | "loader_type": "isom_loader", 19 | "path": "/home/rjackson/dev/tensor-field-networks/data/isomerization/isomerization_dataset" 20 | ".hd5f", 21 | "splitting": "75:20:5", 22 | }, 23 | "builder_config": { 24 | "builder_type": "cartesian_builder", 25 | "radial_factory": "single_dense", 26 | "prediction_type": "cartesians", 27 | "output_type": "cartesians", 28 | }, 29 | } 30 | ), 31 | CrossValidate( 32 | exp_config={ 33 | "name": f"{Path(__file__).parent} TS", 34 | "seed": 1, 35 | "run_config": {"epochs": 1000, "test": False, "batch_size": 48,}, 36 | "loader_config": { 37 | "loader_type": "ts_loader", 38 | "splitting": 5, 39 | "map_points": False, 40 | "load_kwargs": {"remove_noise": True, "shuffle": False}, 41 | }, 42 | "builder_config": { 43 | "builder_type": "cartesian_builder", 44 | "radial_factory": "single_dense", 45 | "prediction_type": "cartesians", 46 | "output_type": "cartesians", 47 | }, 48 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 49 | } 50 | ), 51 | ], 52 | ) 53 | job.run() 54 | -------------------------------------------------------------------------------- /runs/final/pre_trained/tsnet-shared/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tfn.tools.jobs import Pipeline, CrossValidate, StructurePrediction 3 | 4 | 5 | job = Pipeline( 6 | exp_config={"name": f"{Path(__file__).parent}", "seed": 1}, 7 | jobs=[ 8 | StructurePrediction( 9 | exp_config={ 10 | "name": f"{Path(__file__).parent} QM9", 11 | "seed": 1, 12 | "run_config": {"epochs": 50, "test": False,}, 13 | "loader_config": { 14 | "loader_type": "qm9_loader", 15 | "splitting": "90:10:0", 16 | "load_kwargs": {"modify_structures": True, "custom_maxz": 36}, 17 | }, 18 | "builder_config": { 19 | "builder_type": "cartesian_builder", 20 | "radial_factory": "single_dense", 21 | "prediction_type": "cartesians", 22 | "output_type": "cartesians", 23 | }, 24 | } 25 | ), 26 | CrossValidate( 27 | exp_config={ 28 | "name": f"{Path(__file__).parent} TS", 29 | "seed": 1, 30 | "run_config": {"epochs": 1000, "test": False, "batch_size": 48,}, 31 | "loader_config": { 32 | "loader_type": "ts_loader", 33 | "splitting": 5, 34 | "map_points": False, 35 | "load_kwargs": {"remove_noise": True, "shuffle": False}, 36 | }, 37 | "builder_config": { 38 | "builder_type": "cartesian_builder", 39 | "radial_factory": "single_dense", 40 | "prediction_type": "cartesians", 41 | "output_type": "cartesians", 42 | }, 43 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 44 | } 45 | ), 46 | ], 47 | ) 48 | job.run() 49 | -------------------------------------------------------------------------------- /runs/final/pre_trained/tsnet/experiment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from 
tfn.tools.jobs import Pipeline, CrossValidate, StructurePrediction 3 | 4 | 5 | job = Pipeline( 6 | exp_config={"name": f"{Path(__file__).parent}", "seed": 1}, 7 | jobs=[ 8 | StructurePrediction( 9 | exp_config={ 10 | "name": f"{Path(__file__).parent} QM9", 11 | "seed": 1, 12 | "run_config": {"epochs": 50, "test": False,}, 13 | "loader_config": { 14 | "loader_type": "qm9_loader", 15 | "splitting": "90:10:0", 16 | "load_kwargs": {"modify_structures": True, "custom_maxz": 36}, 17 | }, 18 | "builder_config": { 19 | "builder_type": "cartesian_builder", 20 | "radial_factory": "multi_dense", 21 | "prediction_type": "cartesians", 22 | "output_type": "cartesians", 23 | }, 24 | } 25 | ), 26 | CrossValidate( 27 | exp_config={ 28 | "name": f"{Path(__file__).parent} TS", 29 | "seed": 1, 30 | "run_config": {"epochs": 1000, "test": False, "batch_size": 48,}, 31 | "loader_config": { 32 | "loader_type": "ts_loader", 33 | "splitting": 5, 34 | "map_points": False, 35 | "load_kwargs": {"remove_noise": True, "shuffle": False}, 36 | }, 37 | "builder_config": { 38 | "builder_type": "cartesian_builder", 39 | "radial_factory": "multi_dense", 40 | "prediction_type": "cartesians", 41 | "output_type": "cartesians", 42 | }, 43 | "lr_config": {"min_delta": 0.01, "patience": 30, "cooldown": 20}, 44 | } 45 | ), 46 | ], 47 | ) 48 | job.run() 49 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | from tfn import __author__, __description__, __email__, __version__ 4 | 5 | setup( 6 | name="tensor-field-networks", 7 | author=__author__, 8 | author_email=__email__, 9 | version=__version__, 10 | description=__description__, 11 | packages=find_packages(), 12 | install_requires=[ 13 | "tensorflow>=2.0", 14 | "sacred>=0.8.0", 15 | "numpy", 16 | "scikit-learn", 17 | "h5py", 18 | ], 19 | extras_require={"dev": ["pytest", "black"]}, 20 | ) 21 | -------------------------------------------------------------------------------- /submit_experiments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import getpass 3 | import socket 4 | import time 5 | from warnings import warn 6 | from pathlib import Path 7 | from typing import List 8 | import os 9 | 10 | TIME = '24:00:00' 11 | MEM = '16G' 12 | NUM_GPUS = 1 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser(description='Script for running experiments on tater.') 17 | parser.add_argument('experiments', nargs='+', 18 | help='Path to one or more experiment file(s) or the directories ' 19 | 'containing them') 20 | parser.add_argument('--local', default=False, action='store_true', 21 | help='Flag to tell script NOT to add SLURM commands. NOT IMPLEMENTED.') 22 | parser.add_argument('--time', default=TIME, 23 | help='Wall-time for each SLURM job. Defaults to {}'.format(TIME)) 24 | parser.add_argument('--mem', default=MEM, 25 | help='Memory required for each SLURM job. Defaults to {}'.format(MEM)) 26 | parser.add_argument('--num_gpus', default=NUM_GPUS, 27 | help=('Number of gpus for each SLURM job. 
Defaults to {}'.format(NUM_GPUS))) 28 | return parser.parse_args() 29 | 30 | 31 | def main(experiments: List[Path], **kwargs): 32 | 33 | paths = sanitize_paths([e.resolve() for e in experiments]) 34 | 35 | if not kwargs.pop('local', False):  # --local is not implemented; both branches currently submit via SLURM 36 | run_slurm_job(paths, **kwargs) 37 | else: 38 | run_slurm_job(paths, **kwargs) 39 | 40 | 41 | def sanitize_paths(experiments): 42 | paths = [] 43 | for exp in experiments: 44 | if exp.is_dir(): 45 | paths.extend([ 46 | Path(root) / Path(file) for root, dirs, files in os.walk(exp.__str__()) 47 | for file in files if 'exp' in file and '.py' in file 48 | ]) 49 | elif 'exp' in exp.stem.lower() and exp.suffix == '.py': 50 | paths.append(exp) 51 | else: 52 | warn('incompatible experiment path: `{}`, skipping...'.format(exp)) 53 | continue 54 | if not paths: 55 | raise ValueError('No compatible experiment paths passed, please ensure script name ' 56 | 'contains "experiment", or is a path to a directory which contains a ' 57 | 'script with "experiment" in its name.') 58 | return paths 59 | 60 | 61 | def run_slurm_job(paths, **kwargs): 62 | import subprocess 63 | slurm_params_lines = [ 64 | '#!/bin/bash\n', 65 | '#SBATCH --export=ALL\n', 66 | '#SBATCH --time={}\n'.format(kwargs.pop('time', TIME)), 67 | '#SBATCH --mem={}\n'.format(kwargs.pop('mem', MEM)), 68 | '#SBATCH --gres=gpu:{}\n'.format(str(kwargs.pop('num_gpus', NUM_GPUS))) 69 | ] 70 | for i, path in enumerate(paths): 71 | if path.is_dir(): 72 | d = path 73 | else: 74 | d = path.parent 75 | submission_path = d/'{}.sh'.format(path.stem) 76 | with open(submission_path, 'w') as file: 77 | file.write(''.join(line for line in slurm_params_lines)) 78 | file.write('#SBATCH --chdir={}\n'.format(path.parent.__str__())) 79 | file.write('#SBATCH --output=output.out\n\n') 80 | if socket.gethostname() == 'tater': 81 | file.write('source /mnt/fast-data/common/miniconda3/bin/activate mlenv\n') 82 | else: 83 | file.write('module load python/3.7\n') 84 | file.write('. 
~/dev/environments/mlenv/bin/activate\n') 85 | file.write('python {}\n'.format(path)) 86 | subprocess.run( 87 | ['sbatch', str(submission_path)], 88 | check=True 89 | ) 90 | if i == len(paths) - 1:  # no sleep needed after the final submission 91 | break 92 | time.sleep(kwargs.get('sleep_time', 3)) 93 | subprocess.run(['squeue', '-u', getpass.getuser()]) 94 | 95 | 96 | if __name__ == '__main__': 97 | args = get_args() 98 | main( 99 | [Path(p) for p in args.experiments], 100 | local=args.local, 101 | time=args.time, 102 | mem=args.mem, 103 | num_gpus=args.num_gpus 104 | ) 105 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UPEIChemistry/tensor-field-networks/5c25583ee4108a13af8e73eabd3c448f42cb70a0/tests/__init__.py -------------------------------------------------------------------------------- /tests/layer_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | 5 | # ===== Parser ===== # 6 | 7 | 8 | def pytest_addoption(parser): 9 | parser.addoption("--not-eager", action="store_false", default=True) 10 | parser.addoption("--not-dynamic", action="store_false", default=True) 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def dynamic(request): 15 | return request.config.getoption("--not-dynamic") 16 | 17 | 18 | @pytest.fixture(scope="session") 19 | def eager(request): 20 | return request.config.getoption("--not-eager") 21 | 22 | 23 | # ===== Data Fixtures ===== # 24 | 25 | 26 | @pytest.fixture(scope="session") 27 | def random_onehot_rbf_vectors(): 28 | one_hot = np.random.randint(0, 2, size=[2, 1, 6]) 29 | rbf = np.random.rand(2, 1, 1, 80).astype("float32") 30 | vectors = np.random.rand(2, 1, 1, 3).astype("float32") 31 | return one_hot, rbf, vectors 32 | 33 | 34 | @pytest.fixture(scope="session") 35 | def random_z_and_cartesians(): 36 | r = np.random.rand(2, 1, 3).astype("float32") 37 | z = np.random.randint(6, size=(2, 1)) 38 | return z, r 39 | 40 | 41 | @pytest.fixture(scope="session") 42 | def random_features_and_targets(): 43 | features = [ 44 | np.random.rand(2, 1, 16, 1).astype("float32"), 45 | np.random.rand(2, 1, 8, 3).astype("float32"), 46 | ] 47 | targets = [ 48 | np.random.rand(2, 1, 16, 1).astype("float32"), 49 | np.random.rand(2, 1, 16, 3).astype("float32"), 50 | ] 51 | return features, targets 52 | -------------------------------------------------------------------------------- /tests/layer_tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | from tensorflow.python.keras.layers import Dense 5 | from tensorflow.keras.utils import get_custom_objects 6 | from tfn import layers 7 | 8 | 9 | class TestRadialFactory: 10 | def test_get_radial(self): 11 | _ = layers.DenseRadialFactory().get_radial(32) 12 | 13 | def test_export_and_creation_json(self): 14 | config = layers.DenseRadialFactory().to_json() 15 | factory = layers.DenseRadialFactory.from_json(config) 16 | assert factory.num_layers == 2 17 | assert factory.units == 32 18 | 19 | def test_radial_sum_atoms(self): 20 | radial = layers.Radial(units=16, sum_points=True) 21 | output = radial(np.random.randn(2, 1, 160)) 22 | assert output.shape == (2, 1, 16) 23 | 24 | 25 | class TestConvolution: 26 | def test_defaults(self, random_onehot_rbf_vectors, random_features_and_targets): 27 | _, *point_cloud = 
random_onehot_rbf_vectors 28 | features, targets = random_features_and_targets 29 | output = layers.Convolution()(list(point_cloud) + features) 30 | assert len(output) == 2 31 | 32 | def test_provided_radial_json( 33 | self, random_onehot_rbf_vectors, random_features_and_targets 34 | ): 35 | _, *point_cloud = random_onehot_rbf_vectors 36 | features, targets = random_features_and_targets 37 | d = { 38 | "type": "DenseRadialFactory", 39 | "num_layers": 3, 40 | "units": 4, 41 | "kernel_lambda": 0.01, 42 | } 43 | _ = layers.Convolution(radial_factory=json.dumps(d))( 44 | list(point_cloud) + list(features) 45 | ) 46 | 47 | def test_provided_radial_string( 48 | self, random_onehot_rbf_vectors, random_features_and_targets 49 | ): 50 | _, *point_cloud = random_onehot_rbf_vectors 51 | features, targets = random_features_and_targets 52 | conv = layers.Convolution(radial_factory="DenseRadialFactory") 53 | _ = conv(list(point_cloud) + list(features)) 54 | config = json.loads(conv.radial_factory.to_json()) 55 | assert config["type"] == "DenseRadialFactory" 56 | assert config["units"] == 32 57 | 58 | def test_provided_radial_string_and_kwargs( 59 | self, random_onehot_rbf_vectors, random_features_and_targets 60 | ): 61 | _, *point_cloud = random_onehot_rbf_vectors 62 | features, targets = random_features_and_targets 63 | conv = layers.Convolution( 64 | radial_factory="DenseRadialFactory", 65 | factory_kwargs={"num_layers": 3, "units": 4, "kernel_lambda": 0.01}, 66 | ) 67 | _ = conv(list(point_cloud) + list(features)) 68 | config = json.loads(conv.radial_factory.to_json()) 69 | assert config["units"] == 4 70 | 71 | def test_custom_radial( 72 | self, random_onehot_rbf_vectors, random_features_and_targets 73 | ): 74 | class MyRadial(layers.DenseRadialFactory): 75 | def __init__(self, num_units, **kwargs): 76 | super().__init__() 77 | self.num_units = num_units 78 | 79 | def get_radial(self, feature_dim, input_order=None, filter_order=None): 80 | return Dense(feature_dim) 81 | 82 | @classmethod 83 | def from_json(cls, json_str: str): 84 | return cls(**json.loads(json_str)) 85 | 86 | get_custom_objects().update({MyRadial.__name__: MyRadial}) 87 | _, *point_cloud = random_onehot_rbf_vectors 88 | features, targets = random_features_and_targets 89 | conv = layers.Convolution( 90 | radial_factory="MyRadial", factory_kwargs={"num_units": 6} 91 | ) 92 | _ = conv(list(point_cloud) + list(features)) 93 | config = json.loads(conv.radial_factory.to_json()) 94 | assert config["type"] == "MyRadial" 95 | assert config["num_units"] == 6 96 | 97 | def test_get_config(self, random_onehot_rbf_vectors, random_features_and_targets): 98 | _, *point_cloud = random_onehot_rbf_vectors 99 | features, targets = random_features_and_targets 100 | conv = layers.Convolution() 101 | _ = conv(list(point_cloud) + list(features)) 102 | config = conv.get_config() 103 | assert config["trainable"] is True and config["si_units"] == 16 104 | 105 | 106 | class TestMolecularConvolution: 107 | def test_defaults(self, random_onehot_rbf_vectors, random_features_and_targets): 108 | features, targets = random_features_and_targets 109 | output = layers.MolecularConvolution()( 110 | list(random_onehot_rbf_vectors) + list(features) 111 | ) 112 | assert len(output) == 2 113 | 114 | def test_sum_points(self, random_z_and_cartesians, random_features_and_targets): 115 | point_cloud = layers.Preprocessing(max_z=5, sum_points=True)( 116 | random_z_and_cartesians 117 | ) 118 | assert len(point_cloud[1].shape) == 3 119 | features, targets = 
random_features_and_targets 120 | output = layers.MolecularConvolution(sum_points=True)(point_cloud + features) 121 | assert len(output) == 2 122 | 123 | def test_one_in_one_out(self, random_onehot_rbf_vectors): 124 | point_cloud = random_onehot_rbf_vectors 125 | features = np.random.rand(2, 1, 1, 1).astype("float32") 126 | output = layers.MolecularConvolution(si_units=1)(list(point_cloud) + [features]) 127 | assert len(output) == 2 128 | 129 | 130 | class TestHarmonicFilter: 131 | def test_ro0_filter(self, random_onehot_rbf_vectors): 132 | _, rbf, vectors = random_onehot_rbf_vectors 133 | output = layers.HarmonicFilter(radial=Dense(16), filter_order=0)([rbf, vectors]) 134 | assert output.shape[-1] == 1 135 | 136 | def test_ro1_filter(self, random_onehot_rbf_vectors): 137 | _, rbf, vectors = random_onehot_rbf_vectors 138 | output = layers.HarmonicFilter(radial=Dense(16), filter_order=1)([rbf, vectors]) 139 | assert output.shape[-1] == 3 140 | 141 | def test_ro2_filter(self, random_onehot_rbf_vectors): 142 | _, rbf, vectors = random_onehot_rbf_vectors 143 | output = layers.HarmonicFilter(radial=Dense(16), filter_order=2)([rbf, vectors]) 144 | assert output.shape[-1] == 5 145 | 146 | 147 | class TestSelfInteraction: 148 | def test_correct_output_shapes(self, random_features_and_targets): 149 | inputs, targets = random_features_and_targets 150 | si = layers.SelfInteraction(32) 151 | outputs = si(inputs) 152 | assert outputs[0].shape == (2, 1, 32, 1) and outputs[1].shape == (2, 1, 32, 3) 153 | 154 | def test_molecular_si(self, random_features_and_targets, random_onehot_rbf_vectors): 155 | one_hot, *_ = random_onehot_rbf_vectors 156 | inputs, targets = random_features_and_targets 157 | si = layers.MolecularSelfInteraction(32) 158 | outputs = si([one_hot] + inputs) 159 | assert outputs[0].shape == (2, 1, 32, 1) and outputs[1].shape == (2, 1, 32, 3) 160 | 161 | def test_one_to_one_si(self): 162 | inputs = [np.random.rand(2, 10, 1, 1).astype("float32")] 163 | output = layers.SelfInteraction(1)(inputs)[0] 164 | assert output.shape == (2, 10, 1, 1) 165 | 166 | 167 | class TestEquivariantActivation: 168 | def test_correct_output_shapes(self, random_features_and_targets): 169 | inputs, targets = random_features_and_targets 170 | outputs = layers.EquivariantActivation()(inputs) 171 | assert all([i.shape == o.shape for i, o in zip(inputs, outputs)]) 172 | 173 | def test_molecular_activation( 174 | self, random_features_and_targets, random_onehot_rbf_vectors 175 | ): 176 | one_hot, *_ = random_onehot_rbf_vectors 177 | inputs, targets = random_features_and_targets 178 | outputs = layers.MolecularActivation()([one_hot] + inputs) 179 | assert all([i.shape == o.shape for i, o in zip(inputs, outputs)]) 180 | 181 | 182 | class TestPreprocessing: 183 | def test_3_output_tensors(self, random_z_and_cartesians): 184 | pre_block = layers.Preprocessing( 185 | max_z=5, 186 | basis_config={ 187 | "width": 0.2, 188 | "spacing": 0.2, 189 | "min_value": -1.0, 190 | "max_value": 15.0, 191 | }, 192 | dynamic=True, 193 | ) 194 | outputs = pre_block(random_z_and_cartesians) 195 | assert len(outputs) == 3 196 | 197 | def test_cosine_basis(self, random_z_and_cartesians): 198 | pre_block = layers.Preprocessing(max_z=5, basis_type="cosine") 199 | outputs = pre_block(random_z_and_cartesians) 200 | assert len(outputs) == 3 201 | assert pre_block.basis_type == "cosine" 202 | assert outputs[1].shape == (2, 1, 1, 80) 203 | 204 | def test_shifted_cosine_basis(self, random_z_and_cartesians): 205 | pre_block = 
layers.Preprocessing(max_z=5, basis_type="shifted_cosine") 206 | outputs = pre_block(random_z_and_cartesians) 207 | assert len(outputs) == 3 208 | assert pre_block.basis_type == "shifted_cosine" 209 | assert outputs[1].shape == (2, 1, 1, 80) 210 | -------------------------------------------------------------------------------- /tests/layer_tests/test_subclass_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.keras import Model 4 | from tfn.layers import ( 5 | Preprocessing, 6 | MolecularConvolution, 7 | MolecularSelfInteraction, 8 | ) 9 | from tfn.layers.utils import rotation_matrix 10 | 11 | 12 | # ===== Model Subclasses ===== # 13 | 14 | 15 | class Scalar(Model): 16 | def __init__(self, max_z: int = 6, **kwargs): 17 | super().__init__(**kwargs) 18 | self.max_z = max_z 19 | self.preprocessing = Preprocessing(max_z) 20 | self.embedding = MolecularSelfInteraction(16) 21 | self.convolution = MolecularConvolution() 22 | self.output_layer = MolecularConvolution(output_orders=[0], si_units=1) 23 | 24 | def call(self, inputs, training=None, mask=None): 25 | point_cloud = self.preprocessing(inputs) 26 | embedding = self.embedding( 27 | [point_cloud[0], tf.expand_dims(point_cloud[0], axis=-1)] 28 | ) 29 | output = self.convolution(point_cloud + embedding) 30 | output = self.output_layer(point_cloud + output) 31 | return tf.reduce_sum(tf.reduce_sum(output[0], axis=-2), axis=-2) 32 | 33 | def compute_output_shape(self, input_shape): 34 | batch, points, _ = input_shape[0] 35 | return tf.TensorShape([batch, 1]) 36 | 37 | def get_config(self): 38 | return dict(max_z=self.max_z,) 39 | 40 | 41 | class Vector(Scalar): 42 | def __init__(self, *args, **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.output_layer = MolecularConvolution(output_orders=[1]) 45 | 46 | def call(self, inputs, training=None, mask=None): 47 | point_cloud = self.preprocessing(inputs) 48 | embedding = self.embedding( 49 | [point_cloud[0], tf.expand_dims(point_cloud[0], axis=-1)] 50 | ) 51 | output = self.convolution(point_cloud + embedding) 52 | output = self.output_layer(point_cloud + output) 53 | return tf.reduce_sum(output[0], axis=-2) 54 | 55 | 56 | # ===== Tests ===== # 57 | 58 | 59 | class TestEquivariance: 60 | def test_dummy_atom_masked_predicted_vectors_rotate_correctly( 61 | self, random_z_and_cartesians, dynamic, eager 62 | ): 63 | z, start = random_z_and_cartesians 64 | end = np.random.rand(2, 10, 3) 65 | model = Vector(dynamic=dynamic) 66 | model.compile(optimizer="adam", loss="mae", run_eagerly=eager) 67 | model.fit([z, start], end, epochs=5) 68 | predicted_end = model.predict([z, start]) 69 | rot_mat = rotation_matrix([1, 0, 0], theta=np.radians(45)) 70 | rotated_start = np.dot(start, rot_mat) 71 | predicted_rotated_end = model.predict([z, rotated_start]) 72 | assert np.all( 73 | np.isclose( 74 | np.dot(predicted_end, rot_mat), 75 | predicted_rotated_end, 76 | rtol=0, 77 | atol=1.0e-5, 78 | ) 79 | ) 80 | 81 | 82 | class TestScalars: 83 | def test_default_model_predict_molecular_energies( 84 | self, random_z_and_cartesians, dynamic, eager 85 | ): 86 | e = np.random.rand(2, 1).astype("float32") 87 | model = Scalar(dynamic=dynamic) 88 | model.compile(optimizer="adam", loss="mae", run_eagerly=eager) 89 | model.fit(x=random_z_and_cartesians, y=e, epochs=2) 90 | -------------------------------------------------------------------------------- /tests/tool_tests/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UPEIChemistry/tensor-field-networks/5c25583ee4108a13af8e73eabd3c448f42cb70a0/tests/tool_tests/__init__.py -------------------------------------------------------------------------------- /tests/tool_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import socket 4 | from pathlib import Path 5 | from shutil import rmtree 6 | 7 | import pytest 8 | from pytest import fixture 9 | 10 | 11 | test_dir = Path(os.path.abspath(os.path.dirname(__file__))) / "test_jobs" 12 | test_model_path = test_dir / "test_model.h5" 13 | 14 | if socket.gethostname() == "tater": 15 | os.environ["DATADIR"] = "/mnt/fast-data/riley/data" 16 | else: 17 | os.environ["DATADIR"] = "/home/riley/dev/python/data" 18 | 19 | 20 | def clear_directories(): 21 | rmtree(test_dir / "logs", ignore_errors=True) 22 | rmtree(test_dir / "sacred_storage", ignore_errors=True) 23 | rmtree(test_dir / "tuner_storage", ignore_errors=True) 24 | try: 25 | for f in os.listdir(test_dir): 26 | if re.search(r".*\.h5", f): 27 | os.remove(test_dir / f) 28 | os.remove(test_model_path) 29 | except IsADirectoryError: 30 | rmtree(test_model_path, ignore_errors=True) 31 | except FileNotFoundError: 32 | pass 33 | 34 | 35 | @fixture(scope="session", autouse=True) 36 | def clear_logdirs(request): 37 | request.addfinalizer(clear_directories) 38 | 39 | 40 | @pytest.fixture(scope="session") 41 | def model(): 42 | return str(test_model_path) 43 | 44 | 45 | @pytest.fixture 46 | def run_config(): 47 | return { 48 | "epochs": 2, 49 | "test": False, 50 | "save_model": True, 51 | "use_strategy": False, 52 | "select_few": 50, 53 | "run_eagerly": True, 54 | "model_path": str(test_model_path), 55 | } 56 | 57 | 58 | @pytest.fixture 59 | def builder_config(): 60 | return {"dynamic": True} 61 | -------------------------------------------------------------------------------- /tests/tool_tests/test_jobs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UPEIChemistry/tensor-field-networks/5c25583ee4108a13af8e73eabd3c448f42cb70a0/tests/tool_tests/test_jobs/__init__.py -------------------------------------------------------------------------------- /tests/tool_tests/test_jobs/test_classifiers.py: -------------------------------------------------------------------------------- 1 | from tfn.tools.jobs import Regression 2 | 3 | 4 | class TestClassifiers: 5 | def test_siamese_network(self, run_config, builder_config): 6 | loader_config = { 7 | "loader_type": "ts_loader", 8 | "load_kwargs": {"output_type": "siamese"}, 9 | } 10 | job = Regression( 11 | { 12 | "name": "test", 13 | "run_config": run_config, 14 | "loader_config": loader_config, 15 | "builder_config": dict( 16 | **builder_config, builder_type="siamese_builder" 17 | ), 18 | } 19 | ) 20 | job.run() 21 | 22 | def test_basic_classifier(self, run_config, builder_config): 23 | loader_config = { 24 | "loader_type": "ts_loader", 25 | "load_kwargs": {"output_type": "classifier"}, 26 | } 27 | job = Regression( 28 | { 29 | "name": "test", 30 | "run_config": run_config, 31 | "loader_config": loader_config, 32 | "builder_config": dict( 33 | **builder_config, builder_type="classifier_builder" 34 | ), 35 | } 36 | ) 37 | job.run() 38 | 39 | def test_qm9_basic_classification(self, run_config, builder_config): 40 | loader_config = { 41 | "loader_type": "qm9_loader", 42 
| "load_kwargs": {"modify_structures": True, "classifier_output": True}, 43 | } 44 | job = Regression( 45 | { 46 | "name": "test", 47 | "run_config": run_config, 48 | "loader_config": loader_config, 49 | "builder_config": dict( 50 | **builder_config, builder_type="classifier_builder" 51 | ), 52 | } 53 | ) 54 | job.run() 55 | -------------------------------------------------------------------------------- /tests/tool_tests/test_jobs/test_cross_validation.py: -------------------------------------------------------------------------------- 1 | from tfn.tools.jobs import CrossValidate 2 | 3 | 4 | class TestCrossValidation: 5 | def test_cartesian_ts_cross_validated(self, run_config, builder_config): 6 | loader_config = { 7 | "loader_type": "ts_loader", 8 | "splitting": 5, 9 | "load_kwargs": {"output_distance_matrix": True}, 10 | } 11 | job = CrossValidate( 12 | { 13 | "name": "test", 14 | "run_config": dict(**run_config, metrics=["cumulative_loss"]), 15 | "loader_config": loader_config, 16 | "builder_config": dict( 17 | **builder_config, 18 | builder_type="cartesian_builder", 19 | prediction_type="vectors", 20 | output_type="distance_matrix" 21 | ), 22 | } 23 | ) 24 | job.run() 25 | -------------------------------------------------------------------------------- /tests/tool_tests/test_jobs/test_job.py: -------------------------------------------------------------------------------- 1 | from tfn.tools.jobs import KerasJob 2 | from tfn.tools.jobs.config_defaults import loader_config, run_config 3 | from tfn.tools.loaders import DataLoader 4 | 5 | 6 | class TestJob: 7 | def test_load_data(self, clear_logdirs): 8 | job = KerasJob() 9 | loader, data = job._load_data() 10 | assert isinstance(loader, DataLoader) 11 | assert len(data) == 3 and len(data[0]) == 2 12 | 13 | def test_default_config(self, clear_logdirs): 14 | job = KerasJob() 15 | assert job.exp_config["loader_config"] == loader_config 16 | assert job.exp_config["run_config"] == run_config 17 | -------------------------------------------------------------------------------- /tests/tool_tests/test_jobs/test_pipeline.py: -------------------------------------------------------------------------------- 1 | from tfn.tools.jobs import Pipeline, Regression, StructurePrediction, CrossValidate 2 | 3 | 4 | class TestPipeline: 5 | def test_regression_to_structure_prediction_to_cross_validation( 6 | self, builder_config, run_config 7 | ): 8 | job = Pipeline( 9 | jobs=[ 10 | Regression( 11 | exp_config={ 12 | "run_config": run_config, 13 | "loader_config": {"loader_type": "iso17_loader"}, 14 | "builder_config": dict( 15 | **builder_config, builder_type="force_builder" 16 | ), 17 | } 18 | ), 19 | StructurePrediction( 20 | exp_config={ 21 | "run_config": run_config, 22 | "loader_config": { 23 | "loader_type": "qm9_loader", 24 | "load_kwargs": {"modify_structures": True}, 25 | }, 26 | "builder_config": dict( 27 | **builder_config, builder_type="cartesian_builder" 28 | ), 29 | } 30 | ), 31 | CrossValidate( 32 | exp_config={ 33 | "run_config": run_config, 34 | "loader_config": {"loader_type": "ts_loader", "splitting": 5}, 35 | "builder_config": dict( 36 | **builder_config, builder_type="cartesian_builder" 37 | ), 38 | } 39 | ), 40 | ] 41 | ) 42 | job.run() 43 | -------------------------------------------------------------------------------- /tests/tool_tests/test_jobs/test_regression.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.models import load_model 2 | 3 | from tfn.tools.jobs import 
Regression, StructurePrediction 4 | 5 | 6 | class TestScalarModels: 7 | def test_qm9(self, run_config, builder_config): 8 | job = Regression( 9 | {"name": "test", "run_config": run_config, "builder_config": builder_config} 10 | ) 11 | job.run() 12 | 13 | def test_non_residual(self, run_config, builder_config): 14 | job = Regression( 15 | { 16 | "name": "test", 17 | "run_config": run_config, 18 | "builder_config": dict(**builder_config, residual=False), 19 | } 20 | ) 21 | job.run() 22 | 23 | def test_sum_points(self, run_config, builder_config): 24 | job = Regression( 25 | { 26 | "name": "test", 27 | "run_config": run_config, 28 | "builder_config": dict(**builder_config, sum_points=True), 29 | } 30 | ) 31 | job.run() 32 | 33 | def test_cosine_basis(self, run_config, builder_config): 34 | job = Regression( 35 | { 36 | "name": "test", 37 | "run_config": run_config, 38 | "builder_config": dict(**builder_config, basis_type="cosine"), 39 | } 40 | ) 41 | job.run() 42 | 43 | def test_single_dense_radial(self, run_config, builder_config): 44 | job = Regression( 45 | { 46 | "name": "test", 47 | "run_config": run_config, 48 | "builder_config": dict( 49 | **builder_config, 50 | **{ 51 | "embedding_units": 32, 52 | "model_num_layers": (3, 3, 3), 53 | "si_units": 32, 54 | "radial_factory": "single_dense", 55 | "radial_kwargs": { 56 | "num_layers": 1, 57 | "units": 64, 58 | "activation": "ssp", 59 | "kernel_lambda": 0.01, 60 | "bias_lambda": 0.01, 61 | }, 62 | } 63 | ), 64 | } 65 | ) 66 | job.run() 67 | 68 | def test_default_loads_graphly(self, run_config, builder_config, model): 69 | run_config["run_eagerly"] = False 70 | builder_config["dynamic"] = False 71 | job = Regression( 72 | {"name": "test", "run_config": run_config, "builder_config": builder_config} 73 | ) 74 | job.run() 75 | model = load_model(model) 76 | assert True 77 | 78 | def test_default_loads_eagerly(self, run_config, builder_config, model): 79 | run_config["run_eagerly"] = True 80 | builder_config["dynamic"] = True 81 | job = Regression( 82 | {"name": "test", "run_config": run_config, "builder_config": builder_config} 83 | ) 84 | job.run() 85 | model = load_model(model) 86 | assert True 87 | 88 | 89 | class TestDualModels: 90 | def test_iso17(self, run_config, builder_config): 91 | loader_config = {"loader_type": "iso17_loader"} 92 | job = Regression( 93 | { 94 | "name": "test", 95 | "run_config": run_config, 96 | "loader_config": loader_config, 97 | "builder_config": dict(**builder_config, builder_type="force_builder"), 98 | } 99 | ) 100 | job.run() 101 | 102 | def test_sn2(self, run_config, builder_config): 103 | loader_config = {"loader_type": "sn2_loader"} 104 | job = Regression( 105 | { 106 | "name": "test", 107 | "run_config": run_config, 108 | "loader_config": loader_config, 109 | "builder_config": dict(**builder_config, builder_type="force_builder"), 110 | } 111 | ) 112 | job.run() 113 | 114 | 115 | class TestCartesianModels: 116 | def test_vector_prediction_cartesian_output(self, run_config, builder_config): 117 | loader_config = { 118 | "loader_type": "ts_loader", 119 | "load_kwargs": {"output_distance_matrix": False}, 120 | } 121 | job = StructurePrediction( 122 | { 123 | "name": "test", 124 | "run_config": run_config, 125 | "loader_config": loader_config, 126 | "builder_config": dict( 127 | **builder_config, 128 | builder_type="cartesian_builder", 129 | prediction_type="vectors", 130 | output_type="cartesians" 131 | ), 132 | } 133 | ) 134 | job.run() 135 | 136 | def test_vector_prediction_distance_matrix_output(self, 
run_config, builder_config): 137 | loader_config = { 138 | "loader_type": "ts_loader", 139 | "load_kwargs": {"output_distance_matrix": True}, 140 | } 141 | job = StructurePrediction( 142 | { 143 | "name": "test", 144 | "run_config": run_config, 145 | "loader_config": loader_config, 146 | "builder_config": dict( 147 | **builder_config, 148 | builder_type="cartesian_builder", 149 | prediction_type="vectors", 150 | output_type="distance_matrix" 151 | ), 152 | } 153 | ) 154 | job.run() 155 | 156 | def test_cartesian_prediction_cartesian_output(self, run_config, builder_config): 157 | loader_config = { 158 | "loader_type": "ts_loader", 159 | "load_kwargs": {"output_distance_matrix": False}, 160 | } 161 | job = StructurePrediction( 162 | { 163 | "name": "test", 164 | "run_config": run_config, 165 | "loader_config": loader_config, 166 | "builder_config": dict( 167 | **builder_config, 168 | builder_type="cartesian_builder", 169 | prediction_type="cartesians", 170 | output_type="cartesians" 171 | ), 172 | } 173 | ) 174 | job.run() 175 | 176 | def test_cartesian_prediction_distance_matrix_output( 177 | self, run_config, builder_config 178 | ): 179 | loader_config = { 180 | "loader_type": "ts_loader", 181 | "load_kwargs": {"output_distance_matrix": True}, 182 | } 183 | job = StructurePrediction( 184 | { 185 | "name": "test", 186 | "run_config": run_config, 187 | "loader_config": loader_config, 188 | "builder_config": dict( 189 | **builder_config, 190 | builder_type="cartesian_builder", 191 | prediction_type="cartesians", 192 | output_type="distance_matrix" 193 | ), 194 | } 195 | ) 196 | job.run() 197 | 198 | def test_modified_qm9_vector_prediction_cartesian_output( 199 | self, run_config, builder_config 200 | ): 201 | loader_config = { 202 | "loader_type": "qm9_loader", 203 | "load_kwargs": {"modify_structures": True}, 204 | } 205 | job = Regression( 206 | { 207 | "name": "test", 208 | "run_config": run_config, 209 | "loader_config": loader_config, 210 | "builder_config": dict( 211 | **builder_config, builder_type="cartesian_builder" 212 | ), 213 | } 214 | ) 215 | job.run() 216 | -------------------------------------------------------------------------------- /tests/tool_tests/test_jobs/test_search.py: -------------------------------------------------------------------------------- 1 | from tfn.tools.jobs import CrossValidate, GridSearch, Regression 2 | 3 | 4 | class TestGridSearch: 5 | GRID_CONFIG = { 6 | "sum_atoms": [True, False], 7 | "model_num_layers": [(1, 1), (1, 1, 1)], 8 | "radial_kwargs": [ 9 | { 10 | "num_layers": 2, 11 | "units": 18, 12 | "activation": "ssp", 13 | "kernel_lambda": 0.01, 14 | "bias_lambda": 0.01, 15 | }, 16 | { 17 | "num_layers": 1, 18 | "units": 32, 19 | "activation": "ssp", 20 | "kernel_lambda": 0.01, 21 | "bias_lambda": 0.01, 22 | }, 23 | ], 24 | } 25 | 26 | def test_basic_grid_search(self, run_config): 27 | job = GridSearch( 28 | job=Regression(exp_config={"name": "test", "run_config": run_config}), 29 | grid=self.GRID_CONFIG, 30 | total_models=3, 31 | ) 32 | job.run() 33 | 34 | def test_cross_validate_grid_search(self, run_config, builder_config): 35 | loader_config = { 36 | "loader_type": "ts_loader", 37 | "splitting": 3, 38 | "load_kwargs": {"output_distance_matrix": True}, 39 | } 40 | job = GridSearch( 41 | job=CrossValidate( 42 | { 43 | "name": "test", 44 | "run_config": dict(**run_config), 45 | "loader_config": loader_config, 46 | "builder_config": dict( 47 | **builder_config, 48 | builder_type="cartesian_builder", 49 | prediction_type="vectors", 50 | 
output_type="distance_matrix" 51 | ), 52 | } 53 | ), 54 | grid=self.GRID_CONFIG, 55 | total_models=3, 56 | ) 57 | job.run() 58 | -------------------------------------------------------------------------------- /tests/tool_tests/test_loaders.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | 4 | from math import ceil, isclose 5 | import numpy as np 6 | 7 | from tfn.tools.loaders import ( 8 | ISO17DataLoader, 9 | QM9DataDataLoader, 10 | TSLoader, 11 | SN2Loader, 12 | IsomLoader, 13 | ) 14 | 15 | 16 | class TestQM9Loader: 17 | def test_load_data(self): 18 | loader = QM9DataDataLoader(os.environ["DATADIR"] + "/QM9_data_original.hdf5") 19 | data = loader.load_data() 20 | assert len(data) == 3 21 | assert len(data[0]) == 2 22 | assert len(data[0][0]) == 2 23 | 24 | def test_train_val_test_splitting(self): 25 | loader = QM9DataDataLoader( 26 | os.environ["DATADIR"] + "/QM9_data_original.hdf5", splitting="70:20:10" 27 | ) 28 | data = loader.load_data() 29 | assert len(data) == 3 30 | assert isclose(len(data[0][0][0]), ceil(0.70 * 133885), abs_tol=1) 31 | assert isclose(len(data[1][0][0]), ceil(0.20 * 133885), abs_tol=1) 32 | assert isclose(len(data[2][0][0]), ceil(0.10 * 133885), abs_tol=1) 33 | 34 | def test_train_val_splitting(self): 35 | loader = QM9DataDataLoader( 36 | os.environ["DATADIR"] + "/QM9_data_original.hdf5", splitting="70:30:0" 37 | ) 38 | data = loader.load_data() 39 | assert len(data) == 3 40 | assert isclose(len(data[0][0][0]), ceil(0.70 * 133885), abs_tol=1) 41 | assert isclose(len(data[1][0][0]), ceil(0.30 * 133885), abs_tol=1) 42 | assert data[2] is None 43 | 44 | def test_train_test_splitting(self): 45 | loader = QM9DataDataLoader( 46 | os.environ["DATADIR"] + "/QM9_data_original.hdf5", splitting="70:0:30" 47 | ) 48 | data = loader.load_data() 49 | assert len(data) == 3 50 | assert isclose(len(data[0][0][0]), ceil(0.70 * 133885), abs_tol=1) 51 | assert data[1] is None 52 | assert isclose(len(data[2][0][0]), ceil(0.30 * 133885), abs_tol=1) 53 | 54 | def test_cross_validation_splitting(self): 55 | loader = QM9DataDataLoader( 56 | os.environ["DATADIR"] + "/QM9_data_original.hdf5", splitting=7 57 | ) 58 | data = loader.load_data() 59 | assert len(data) == 7 60 | assert len(np.concatenate([d[0][0] for d in data], axis=0)) == 133885 61 | 62 | def test_modified(self): 63 | loader = QM9DataDataLoader( 64 | os.environ["DATADIR"] + "/QM9_data_original.hdf5", splitting="70:20:10" 65 | ) 66 | data = loader.load_data(modify_structures=True, modify_distance=1) 67 | assert len(data[0][0]) == 3 68 | assert data[0][1][0].shape[1:] == (loader.num_points, 3) 69 | 70 | def test_distance_matrix(self): 71 | loader = QM9DataDataLoader( 72 | os.environ["DATADIR"] + "/QM9_data_original.hdf5", splitting="70:20:10" 73 | ) 74 | data = loader.load_data( 75 | modify_structures=True, modify_distance=1, output_distance_matrix=True 76 | ) 77 | assert len(data[0][0]) == 3 78 | assert data[0][1][0].shape[1:] == (loader.num_points, loader.num_points) 79 | 80 | def test_classifier(self): 81 | loader = QM9DataDataLoader( 82 | os.environ["DATADIR"] + "/QM9_data_original.hdf5", splitting="70:20:10" 83 | ) 84 | data = loader.load_data(modify_structures=True, classifier_output=True) 85 | assert len(data[0][0]) == 2 # tiled atomic_nums, tiled cartesians 86 | assert len(data[0][1][0].shape) == 1 87 | 88 | 89 | class TestISO17Loader: 90 | def test_load_dual_data(self): 91 | loader = ISO17DataLoader(os.environ["DATADIR"] + "/iso17.hdf5") 92 | data = 
loader.load_data()
93 |         assert len(data) == 3  # train, val, test
94 |         assert len(data[0]) == 2  # x_train, y_train
95 |         assert len(data[0][0]) == 2  # r, z
96 |         assert (
97 |             ceil(len(data[0][0][0]) / 0.70) == 461715
98 |         )  # ensure the train split is 70% of the 461715 total reference examples
99 |         assert data[0][0][0].shape[1] == 29
100 |
101 |     def test_load_force_data(self):
102 |         loader = ISO17DataLoader(
103 |             os.environ["DATADIR"] + "/iso17.hdf5", use_energies=False
104 |         )
105 |         data = loader.load_data()
106 |         assert len(data[0][1]) == 1  # Only one y value
107 |
108 |
109 | class TestTSLoader:
110 |     def test_load_ts_data(self):
111 |         loader = TSLoader(os.environ["DATADIR"] + "/ts.hdf5")
112 |         data = loader.load_data()
113 |         assert len(data) == 3
114 |         assert len(data[0]) == 2
115 |         assert len(data[0][0]) == 3  # Z, R, P
116 |         assert len(data[0][1]) == 1  # TS
117 |
118 |     def test_complexes(self):
119 |         loader = TSLoader(os.environ["DATADIR"] + "/ts.hdf5")
120 |         data = loader.load_data("use_complexes")
121 |         assert len(data) == 3
122 |         assert len(data[0]) == 2
123 |         assert len(data[0][0]) == 3  # Z, RC, PC
124 |         assert len(data[0][1]) == 1  # TS
125 |
126 |     def test_remove_noise(self):
127 |         loader = TSLoader(os.environ["DATADIR"] + "/ts.hdf5", splitting=None)
128 |         data = loader.load_data(remove_noise=True)
129 |         assert len(data[0][0][0]) == 55
130 |
131 |     def test_energy_serving(self):
132 |         loader = TSLoader(os.environ["DATADIR"] + "/ts.hdf5", pre_load=False)
133 |         data = loader.load_data(output_type="energies")
134 |         assert len(data) == 3
135 |         assert len(data[0][0]) == 3
136 |         assert len(data[0][1][0].shape) == 1
137 |
138 |     def test_serving_both(self):
139 |         loader = TSLoader(os.environ["DATADIR"] + "/ts.hdf5", pre_load=False)
140 |         data = loader.load_data(output_type="both")
141 |         assert len(data) == 3
142 |         assert len(data[0][0]) == 3
143 |         assert len(data[0][1]) == 2
144 |         assert data[0][1][0].shape[1:] == (loader.num_points, 3)
145 |         assert len(data[0][1][1].shape) == 1
146 |
147 |     def test_distance_matrix(self):
148 |         loader = TSLoader(os.environ["DATADIR"] + "/ts.hdf5", pre_load=False)
149 |         data = loader.load_data(output_type="both", output_distance_matrix=True)
150 |         assert len(data) == 3
151 |         assert len(data[0][0]) == 3
152 |         assert len(data[0][1]) == 2
153 |         assert data[0][1][0].shape[1:] == (loader.num_points, loader.num_points)
154 |         assert len(data[0][1][1].shape) == 1
155 |
156 |     def test_classification_data(self):
157 |         loader = TSLoader(os.environ["DATADIR"] + "/ts.hdf5", pre_load=False)
158 |         data = loader.load_data(output_type="classifier")
159 |         assert len(data) == 3
160 |         assert len(data[0][0]) == 2
161 |         assert len(data[0][1]) == 1
162 |         assert len(data[0][1][0].shape) == 1
163 |
164 |     def test_siamese_data(self):
165 |         loader = TSLoader(os.environ["DATADIR"] + "/ts.hdf5", pre_load=False)
166 |         data = loader.load_data(output_type="siamese")
167 |         assert len(data) == 3
168 |         assert data[0][0][0].shape[1:] == (2, loader.num_points,)
169 |         assert data[0][0][1].shape[1:] == (2, loader.num_points, 3)
170 |         assert len(data[0][1]) == 1
171 |
172 |     def test_custom_split(self):
173 |         loader = TSLoader(
174 |             os.environ["DATADIR"] + "/ts.hdf5", pre_load=False, splitting="custom"
175 |         )
176 |         train, val, test = loader.load_data(shuffle=False)
177 |         assert test is None
178 |         assert len(val[0][0]) == 7
179 |
180 |
181 | class TestSN2Loader:
182 |     def test_load_sn2_data(self):
183 |         loader = SN2Loader(os.environ["DATADIR"] + "/sn2_reactions.npz")
184 |         data = loader.load_data()
185 |         assert
len(data) == 3 186 | assert len(data[0]) == 2 187 | assert len(data[0][0]) == 2 # R, Z 188 | assert len(data[0][1]) == 2 # E, F 189 | assert len(loader.mu) == 2 190 | assert isclose(loader.sigma[1], 0.71, abs_tol=0.3) 191 | 192 | 193 | class TestIsomLoader: 194 | def test_load_isomerization_data(self): 195 | loader = IsomLoader( 196 | "/home/riley/Documents/tensor-field-networks/data/isomerization/isomerization_dataset.hd5f" 197 | ) 198 | data = loader.load_data() 199 | assert data 200 | -------------------------------------------------------------------------------- /tfn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Package containing Tensorfield Network tf.keras layers built using TF 2.0 3 | """ 4 | import logging 5 | import os 6 | 7 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # FATAL 8 | logging.getLogger("tensorflow").setLevel(logging.FATAL) 9 | 10 | __version__ = "2.5.0" 11 | __author__ = "Riley Jackson" 12 | __email__ = "rjjackson@upei.ca" 13 | __description__ = ( 14 | "Package containing Tensor Field Network tf.keras layers built using Tensorflow 2" 15 | ) 16 | -------------------------------------------------------------------------------- /tfn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.utils import get_custom_objects 2 | 3 | from tfn.layers.utils import shifted_softplus, tfn_mae 4 | 5 | from .atomic_images import ( 6 | OneHot, 7 | DistanceMatrix, 8 | KernelBasis, 9 | GaussianBasis, 10 | AtomicNumberBasis, 11 | Unstandardization, 12 | DummyAtomMasking, 13 | CutoffLayer, 14 | CosineCutoff, 15 | TanhCutoff, 16 | LongTanhCutoff, 17 | ) 18 | 19 | from .utility_layers import Preprocessing, UnitVectors, MaskedDistanceMatrix 20 | from .radial_factories import RadialFactory, DenseRadialFactory, Radial 21 | from .layers import ( 22 | EquivariantLayer, 23 | Convolution, 24 | HarmonicFilter, 25 | SelfInteraction, 26 | EquivariantActivation, 27 | ) 28 | from .molecular_layers import ( 29 | MolecularConvolution, 30 | MolecularSelfInteraction, 31 | MolecularActivation, 32 | ) 33 | 34 | 35 | get_custom_objects().update( 36 | { 37 | "ssp": shifted_softplus, 38 | "shifted_softplus": shifted_softplus, 39 | "tfn_mae": tfn_mae, 40 | RadialFactory.__name__: RadialFactory, 41 | DenseRadialFactory.__name__: DenseRadialFactory, 42 | Radial.__name__: Radial, 43 | EquivariantLayer.__name__: EquivariantLayer, 44 | Convolution.__name__: Convolution, 45 | MolecularConvolution.__name__: MolecularConvolution, 46 | HarmonicFilter.__name__: HarmonicFilter, 47 | SelfInteraction.__name__: SelfInteraction, 48 | MolecularSelfInteraction.__name__: MolecularSelfInteraction, 49 | EquivariantActivation.__name__: EquivariantActivation, 50 | MolecularActivation.__name__: MolecularActivation, 51 | Preprocessing.__name__: Preprocessing, 52 | UnitVectors.__name__: UnitVectors, 53 | MaskedDistanceMatrix.__name__: MaskedDistanceMatrix, 54 | OneHot.__name__: OneHot, 55 | DistanceMatrix.__name__: DistanceMatrix, 56 | KernelBasis.__name__: KernelBasis, 57 | GaussianBasis.__name__: GaussianBasis, 58 | AtomicNumberBasis.__name__: AtomicNumberBasis, 59 | Unstandardization.__name__: Unstandardization, 60 | DummyAtomMasking.__name__: DummyAtomMasking, 61 | CutoffLayer.__name__: CutoffLayer, 62 | CosineCutoff.__name__: CosineCutoff, 63 | TanhCutoff.__name__: TanhCutoff, 64 | LongTanhCutoff.__name__: LongTanhCutoff, 65 | } 66 | ) 67 | 
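#
# Usage sketch (illustrative, not part of the package): because the layers,
# radial factories, and activations above are registered with Keras' global
# custom-object dict at import time, a serialized TFN model reloads without an
# explicit `custom_objects` argument. The path "tfn_model.h5" is hypothetical.
#
#     import tfn.layers  # noqa: F401 -- importing runs the registration above
#     from tensorflow.keras.models import load_model
#
#     model = load_model("tfn_model.h5")  # "ssp", MolecularConvolution, etc. resolve
#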
-------------------------------------------------------------------------------- /tfn/layers/atomic_images.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import Layer
3 |
4 | from tensorflow.keras import backend as K
5 |
6 | import numpy as np
7 |
8 |
9 | def linspace(*args, **kwargs):
10 |     """
11 |     Keras backend equivalent to numpy and TF's linspace
12 |
13 |     Arguments:
14 |         start (float or int): the starting point; defaults to 0 when only
15 |             two arguments are provided
16 |         stop (float or int): the stopping point
17 |         n (int): the number of points to return. With two arguments, they
18 |             are interpreted as (stop, n)
19 |     """
20 |     endpoint = kwargs.get("endpoint", True)
21 |     if len(args) == 1:
22 |         raise ValueError("must provide the number of points")
23 |     elif len(args) == 2:
24 |         stop, n = args
25 |         start = 0
26 |     elif len(args) == 3:
27 |         start, stop, n = args
28 |     else:
29 |         raise ValueError("invalid call to linspace")
30 |
31 |     range_ = stop - start
32 |     if endpoint:
33 |         step_ = range_ / (n - 1)
34 |     else:
35 |         step_ = range_ / n
36 |
37 |     points = tf.range(0, n, dtype=tf.float32)
38 |     points *= step_
39 |     points += start
40 |
41 |     return points
42 |
43 |
44 | class OneHot(Layer):
45 |     """One-hot atomic number layer
46 |
47 |     Converts a list of atomic numbers to one-hot vectors
48 |
49 |     Input: atomic numbers (batch, atoms)
50 |     Output: one-hot atomic number (batch, atoms, atomic_number)
51 |     """
52 |
53 |     def __init__(self, max_atomic_number, **kwargs):
54 |         # Parameters
55 |         self.max_atomic_number = max_atomic_number
56 |
57 |         super(OneHot, self).__init__(**kwargs)
58 |
59 |     def call(self, inputs, **kwargs):
60 |         atomic_numbers = inputs
61 |         return tf.one_hot(atomic_numbers, self.max_atomic_number)
62 |
63 |     def compute_output_shape(self, input_shapes):
64 |         atomic_numbers = input_shapes
65 |         return tf.TensorShape(list(atomic_numbers) + [self.max_atomic_number])
66 |
67 |     def get_config(self):
68 |         base_config = super(OneHot, self).get_config()
69 |         config = {"max_atomic_number": self.max_atomic_number}
70 |         return {**base_config, **config}
71 |
72 |
73 | class DistanceMatrix(Layer):
74 |     """
75 |     Distance matrix layer
76 |
77 |     Expands Cartesian coordinates into a distance matrix.
78 |
79 |     Input: coordinates (..., atoms, 3)
80 |     Output: distance matrix (..., atoms, atoms)
81 |     """
82 |
83 |     def call(self, inputs, **kwargs):
84 |         positions = inputs
85 |         # `positions` should be Cartesian coordinates of shape
86 |         #    (..., atoms, 3)
87 |         v1 = tf.expand_dims(positions, axis=-2)
88 |         v2 = tf.expand_dims(positions, axis=-3)
89 |
90 |         sum_squares = tf.reduce_sum(tf.square(v2 - v1), axis=-1)
91 |         sqrt = tf.sqrt(sum_squares + 1e-7)
92 |         sqrt = tf.where(sqrt >= 1e-7, sqrt, tf.zeros_like(sqrt))  # tf.where is not in place; keep the masked result
93 |         return sqrt
94 |
95 |     def compute_output_shape(self, input_shape):
96 |         return tf.TensorShape(
97 |             list(input_shape[:-2]) + [input_shape[-2], input_shape[-2]]
98 |         )
99 |
100 |
101 | #
102 | # Kernel functions
103 | #
104 | class KernelBasis(Layer):
105 |     """Expand tensor using kernel of width=width, spacing=spacing,
106 |     starting at min_value ending at max_value (inclusive if endpoint=True).
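    For example, with the defaults min_value=-1, max_value=9, and spacing=0.2,
    the grid holds ceil((9 - (-1)) / 0.2) = 50 kernel centres.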
107 | 108 | Input: tensor (batch, atoms, [atoms, [atoms...]) 109 | Output: tensor expanded into kernel basis set (batch, atoms, [atoms, [atoms...]], n_gaussians) 110 | 111 | Args: 112 | min_value (float, optional): minimum value 113 | max_value (float, optional): maximum value (non-inclusive) 114 | width (float, optional): width of kernel functions 115 | spacing (float, optional): spacing between kernel functions 116 | self_thresh (float, optional): value below which a distance is 117 | considered to be a self interaction (i.e. zero) 118 | include_self_interactions (bool, optional): whether or not to include 119 | self-interactions (i.e. distance is zero) 120 | """ 121 | 122 | def __init__( 123 | self, 124 | min_value=-1, 125 | max_value=9, 126 | width=0.2, 127 | spacing=0.2, 128 | self_thresh=1e-5, 129 | include_self_interactions=True, 130 | endpoint=False, 131 | **kwargs 132 | ): 133 | super(KernelBasis, self).__init__(**kwargs) 134 | self._n_centers = int(np.ceil((max_value - min_value) / spacing)) 135 | self.min_value = min_value 136 | self.max_value = max_value 137 | self.spacing = spacing 138 | self.width = width 139 | self.self_thresh = self_thresh 140 | self.include_self_interactions = include_self_interactions 141 | self.endpoint = endpoint 142 | 143 | def call(self, inputs, **kwargs): 144 | in_tensor = tf.expand_dims(inputs, -1) 145 | mu = linspace( 146 | self.min_value, self.max_value, self._n_centers, endpoint=self.endpoint 147 | ) 148 | 149 | mu_prefix_shape = tuple([1 for _ in range(len(K.int_shape(in_tensor)) - 1)]) 150 | mu = tf.reshape(mu, mu_prefix_shape + (-1,)) 151 | values = self.kernel_func(in_tensor, mu) 152 | 153 | if not self.include_self_interactions: 154 | mask = tf.cast(in_tensor >= self.self_thresh, tf.float32) 155 | values *= mask 156 | 157 | return values 158 | 159 | def kernel_func(self, inputs, centres): 160 | raise NotImplementedError 161 | 162 | def compute_output_shape(self, input_shape): 163 | return tf.TensorShape(list(input_shape) + [self._n_centers]) 164 | 165 | def get_config(self): 166 | config = { 167 | "width": self.width, 168 | "spacing": self.spacing, 169 | "min_value": self.min_value, 170 | "max_value": self.max_value, 171 | "self_thresh": self.self_thresh, 172 | "include_self_interactions": self.include_self_interactions, 173 | "endpoint": self.endpoint, 174 | } 175 | base_config = super().get_config() 176 | return {**base_config, **config} 177 | 178 | 179 | class GaussianBasis(KernelBasis): 180 | """Expand distance matrix into Gaussians of width=width, spacing=spacing, 181 | starting at min_value ending at max_value (inclusive if endpoint=True). 182 | 183 | -(x - u)^2 184 | exp(----------) 185 | 2 * w^2 186 | 187 | where: u is linspace(min_value, max_value, ceil((max_value - min_value) / width)) 188 | w is width 189 | 190 | Input: distance_matrix (batch, atoms, atoms) 191 | Output: distance_matrix expanded into Gaussian basis set (batch, atoms, atoms, n_centres) 192 | 193 | Args: 194 | min_value (float, optional): minimum value 195 | max_value (float, optional): maximum value (non-inclusive) 196 | width (float, optional): width of Gaussians 197 | spacing (float, optional): spacing between Gaussians 198 | self_thresh (float, optional): value below which a distance is 199 | considered to be a self interaction (i.e. zero) 200 | include_self_interactions (bool, optional): whether or not to include 201 | self-interactions (i.e. 
distance is zero)
202 |
203 |     """
204 |
205 |     def kernel_func(self, inputs, centres):
206 |         gamma = -0.5 / (self.width ** 2)
207 |         return tf.exp(gamma * tf.square(inputs - centres))
208 |
209 |
210 | class CosineBasis(KernelBasis):
211 |     """
212 |     Expand distance value into a vector of dampened cosine activations, each element representing
213 |     the activation of a cosine function parameterized by a grid of kappa values, where kappa
214 |     sets the frequency (period 2 * pi / kappa) of the cosine function.\n
215 |
216 |     f(kappa, x) = cos(kappa * x) * e^(-w * x)
217 |
218 |     Where:
219 |         x is our distance value;\n
220 |         w is the width parameter of the dampening;
221 |
222 |     Input: distance_matrix (batch, atoms, atoms);\n
223 |     Output: distance_matrix expanded into Cosine basis set (batch, atoms, atoms, n_centres)
224 |
225 |     Args:
226 |         min_value (float, optional): minimum value of kappa
227 |         max_value (float, optional): maximum value (non-inclusive) of kappa
228 |         width (float, optional): Parameter for dampening, lower width means the cosine function
229 |             dampens earlier, and only shorter distances are probed. Keep around 0.2
230 |         spacing (float, optional): spacing on the grid of kappa values being used to generate
231 |             cosine basis functions
232 |         self_thresh (float, optional): value below which a distance is
233 |             considered to be a self interaction (i.e. zero)
234 |         include_self_interactions (bool, optional): whether or not to include
235 |             self-interactions (i.e. distance is zero)
236 |
237 |     """
238 |
239 |     def kernel_func(self, inputs, centres):
240 |         return tf.cos(centres * inputs) * self.cutoff(inputs)
241 |
242 |     def cutoff(self, inputs):
243 |         return tf.exp(-self.width * inputs)
244 |
245 |
246 | class ShiftedCosineBasis(CosineBasis):
247 |     def kernel_func(self, inputs, centres):
248 |         return (0.5 * (tf.cos(centres * inputs) + 1)) * self.cutoff(inputs)
249 |
250 |
251 | #
252 | # Atom-related functions
253 | #
254 | class AtomicNumberBasis(Layer):
255 |     """Expands Gaussian matrix into the one-hot atomic numbers basis
256 |
257 |     Inputs:
258 |         one_hot_numbers (batch, atoms, max_atomic_number + 1)
259 |         gaussians_matrix (batch, atoms, atoms, n_gaussians)
260 |     Output:
261 |         gaussians_atom_matrix (batch, atoms, atoms, n_gaussians, max_atomic_number + 1)
262 |     """
263 |
264 |     def __init__(self, zero_dummy_atoms=False, **kwargs):
265 |         kwargs.pop("max_atomic_number", None)  # Backward compatibility
266 |         super(AtomicNumberBasis, self).__init__(**kwargs)
267 |         self.zero_dummy_atoms = zero_dummy_atoms
268 |
269 |     def call(self, inputs, **kwargs):
270 |         one_hot_numbers, gaussian_mat = inputs
271 |
272 |         gaussian_mat = tf.expand_dims(gaussian_mat, axis=-1)
273 |         if self.zero_dummy_atoms:
274 |             mask = np.eye(one_hot_numbers.shape[-1], dtype=np.float32)  # NumPy: TF tensors forbid item assignment
275 |             mask[0] = 0
276 |             one_hot_numbers = K.dot(one_hot_numbers, tf.constant(mask))
277 |         one_hot_numbers = tf.expand_dims(one_hot_numbers, axis=1)
278 |         one_hot_numbers = tf.expand_dims(one_hot_numbers, axis=3)
279 |         return gaussian_mat * one_hot_numbers
280 |
281 |     def compute_output_shape(self, input_shapes):
282 |         one_hot_numbers_shape, gaussian_mat_shape = input_shapes
283 |         return tf.TensorShape(list(gaussian_mat_shape) + [one_hot_numbers_shape[-1]])
284 |
285 |     def get_config(self):
286 |         config = {"zero_dummy_atoms": self.zero_dummy_atoms}
287 |         base_config = super(AtomicNumberBasis, self).get_config()
288 |         return {**base_config, **config}
289 |
290 |
291 | #
292 | # Normalization-related
layers 293 | # 294 | class Unstandardization(Layer): 295 | """ 296 | Offsets energies by mean and standard deviation (optionally, per-atom) 297 | 298 | `mu` and `sigma` both follow the following: 299 | If the value is a scalar, apply it equally to all properties 300 | and all types of atoms 301 | 302 | If the value is a vector, each component corresponds to an 303 | output property. It is expanded into a matrix where the 304 | first axis shape is 1. It then follows the matrix rules. 305 | 306 | If the value is a matrix, rows correspond to types of atoms and 307 | columns correspond to properties. 308 | 309 | If there is only one row, then the row vector applies to every 310 | type of atom equally. 311 | 312 | If there is one column, then the scalars are applied to every 313 | property equally. 314 | 315 | If there is a single scalar, then it is treated as a scalar. 316 | 317 | Inputs: the inputs to this layer depend on whether or not mu and sigma 318 | are given as a single scalar or per atom type. 319 | 320 | If scalar: 321 | atomic_props (batch, atoms, energies) 322 | If per type: 323 | one_hot_atomic_numbers (batch, atoms, atomic_number) 324 | atomic_props (batch, atoms, energies) 325 | Output: atomic_props (batch, atoms, energies) 326 | 327 | Attributes: 328 | mu (float, list, or np.ndarray): the mean values by which 329 | to offset the inputs to this layer 330 | sigma (float, list, or np.ndarray): the standard deviation 331 | values by which to scale the inputs to this layer 332 | """ 333 | 334 | def __init__( 335 | self, mu, sigma, trainable=False, per_type=False, use_float64=False, **kwargs 336 | ): 337 | super(Unstandardization, self).__init__(trainable=trainable, **kwargs) 338 | self.init_mu = mu 339 | self.init_sigma = sigma 340 | self.use_float64 = use_float64 341 | 342 | self.mu = np.asanyarray(self.init_mu) 343 | self.sigma = np.asanyarray(self.init_sigma) 344 | 345 | self.per_type = len(self.mu.shape) > 0 or per_type 346 | 347 | @staticmethod 348 | def expand_ones_to_shape(arr, shape): 349 | if len(arr.shape) == 0: 350 | arr = arr.reshape((1, 1)) 351 | if 1 in arr.shape: 352 | tile_shape = tuple( 353 | shape[i] if arr.shape[i] == 1 else 1 for i in range(len(shape)) 354 | ) 355 | arr = np.tile(arr, tile_shape) 356 | if arr.shape != shape: 357 | raise ValueError( 358 | "the arrays were not of the right shape: " 359 | "expected %s but was %s" % (shape, arr.shape) 360 | ) 361 | return arr 362 | 363 | def build(self, input_shapes): 364 | # If mu is given as a vector, assume it applies to all atoms 365 | if len(self.mu.shape) == 1: 366 | self.mu = np.expand_dims(self.mu, axis=0) 367 | if len(self.sigma.shape) == 1: 368 | self.sigma = np.expand_dims(self.sigma, axis=0) 369 | 370 | if self.per_type: 371 | one_hot_atomic_numbers, atomic_props = input_shapes 372 | w_shape = (one_hot_atomic_numbers[-1], atomic_props[-1]) 373 | 374 | self.mu = self.expand_ones_to_shape(self.mu, w_shape) 375 | self.sigma = self.expand_ones_to_shape(self.sigma, w_shape) 376 | else: 377 | w_shape = self.mu.shape 378 | 379 | self.mu = self.add_weight( 380 | name="mu", shape=w_shape, initializer=lambda x, dtype=self.dtype: self.mu 381 | ) 382 | self.sigma = self.add_weight( 383 | name="sigma", 384 | shape=w_shape, 385 | initializer=lambda x, dtype=self.dtype: self.sigma, 386 | ) 387 | super(Unstandardization, self).build(input_shapes) 388 | 389 | def call(self, inputs, **kwargs): 390 | # `atomic_props` should be of shape (batch, atoms, energies) 391 | 392 | # If mu and sigma are given per atom type, need 
atomic numbers 393 | # to know how to apply them. Otherwise, just energies is enough. 394 | if self.per_type or isinstance(inputs, (list, tuple)): 395 | self.per_type = True 396 | one_hot_atomic_numbers, atomic_props = inputs 397 | atomic_props *= K.dot(one_hot_atomic_numbers, self.sigma) 398 | atomic_props += K.dot(one_hot_atomic_numbers, self.mu) 399 | else: 400 | atomic_props = inputs 401 | atomic_props *= self.sigma 402 | atomic_props += self.mu 403 | 404 | return atomic_props 405 | 406 | def compute_output_shape(self, input_shapes): 407 | if self.per_type or isinstance(input_shapes, list): 408 | atomic_props = input_shapes[-1] 409 | else: 410 | atomic_props = input_shapes 411 | return atomic_props 412 | 413 | def get_config(self): 414 | mu = self.init_mu 415 | if isinstance(mu, (np.ndarray, np.generic)): 416 | if len(mu.shape) > 0: 417 | mu = mu.tolist() 418 | else: 419 | mu = float(mu) 420 | 421 | sigma = self.init_sigma 422 | if isinstance(sigma, (np.ndarray, np.generic)): 423 | if len(sigma.shape) > 0: 424 | sigma = sigma.tolist() 425 | else: 426 | sigma = float(sigma) 427 | 428 | config = { 429 | "mu": mu, 430 | "sigma": sigma, 431 | "per_type": self.per_type, 432 | "use_float64": self.use_float64, 433 | } 434 | base_config = super(Unstandardization, self).get_config() 435 | return {**base_config, **config} 436 | 437 | 438 | # 439 | # Dummy atom-related layers 440 | # 441 | class DummyAtomMasking(Layer): 442 | """ 443 | Masks dummy atoms (atomic number = 0 by default) with zeros 444 | 445 | Inputs: atomic_numbers 446 | Either or both in this order: 447 | atomic_numbers (batch, atoms) 448 | or 449 | one_hot_atomic_numbers (batch, atoms, atomic_number) 450 | value (batch, atoms, ...) 451 | Output: value with zeroes for dummy atoms (batch, atoms, ...) 452 | 453 | Args: 454 | atom_axes (int or iterable of int): axes to which to apply 455 | the masking 456 | 457 | Keyword Args: 458 | dummy_index (int): the index to mask (default: 0) 459 | invert_mask (bool): if True, zeroes all but the desired index rather 460 | than zeroeing the desired index 461 | """ 462 | 463 | def __init__(self, atom_axes=1, **kwargs): 464 | self.invert_mask = kwargs.pop("invert_mask", False) 465 | self.dummy_index = kwargs.pop("dummy_index", 0) 466 | super(DummyAtomMasking, self).__init__(**kwargs) 467 | if isinstance(atom_axes, int): 468 | atom_axes = [atom_axes] 469 | elif isinstance(atom_axes, tuple): 470 | atom_axes = list(atom_axes) 471 | self.atom_axes = atom_axes 472 | 473 | def call(self, inputs, **kwargs): 474 | # `value` should be of shape (batch, atoms, ...) 
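        # The one-hot input is collapsed back to integer atomic numbers with
        # argmax below, so the mask can be built by comparison to `dummy_index`.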
475 | one_hot_atomic_numbers, value = inputs 476 | atomic_numbers = tf.argmax(one_hot_atomic_numbers, axis=-1) 477 | 478 | # Form the mask that removes dummy atoms (atomic number = dummy_index) 479 | if self.invert_mask: 480 | selection_mask = tf.equal(atomic_numbers, self.dummy_index) 481 | else: 482 | selection_mask = tf.not_equal(atomic_numbers, self.dummy_index) 483 | selection_mask = tf.cast(selection_mask, value.dtype) 484 | 485 | for axis in self.atom_axes: 486 | mask = selection_mask 487 | for _ in range(axis - 1): 488 | mask = tf.expand_dims(mask, axis=1) 489 | # Add one since tf.int_shape does not return batch dim 490 | while len(K.int_shape(value)) != len(K.int_shape(mask)): 491 | mask = tf.expand_dims(mask, axis=-1) 492 | 493 | # Zeros the energies of dummy atoms 494 | value *= mask 495 | return value 496 | 497 | def compute_output_shape(self, input_shapes): 498 | value = input_shapes[-1] 499 | return value 500 | 501 | def get_config(self): 502 | config = { 503 | "atom_axes": self.atom_axes, 504 | "invert_mask": self.invert_mask, 505 | "dummy_index": self.dummy_index, 506 | } 507 | base_config = super(DummyAtomMasking, self).get_config() 508 | return {**base_config, **config} 509 | 510 | 511 | # 512 | # Cutoff functions 513 | # 514 | class CutoffLayer(Layer): 515 | """Base layer for cutoff functions. 516 | 517 | Applies a cutoff function to the expanded distance matrix 518 | 519 | Inputs: 520 | distance_matrix (batch, atoms, atoms) 521 | basis_functions (batch, atoms, atoms, n_centres) 522 | Output: basis_functions with cutoff function multiplied (batch, atoms, atoms, n_centres) 523 | """ 524 | 525 | def __init__(self, cutoff, **kwargs): 526 | super(CutoffLayer, self).__init__(**kwargs) 527 | self.cutoff = cutoff 528 | 529 | def call(self, inputs, **kwargs): 530 | distance_matrix, basis_functions = inputs 531 | 532 | cutoffs = self.cutoff_function(distance_matrix) 533 | cutoffs = K.expand_dims(cutoffs, axis=-1) 534 | 535 | return basis_functions * cutoffs 536 | 537 | def cutoff_function(self, distance_matrix): 538 | """Function responsible for the cutoff. It should also return zeros 539 | for anything greater than the cutoff. 540 | 541 | Args: 542 | distance_matrix (Tensor): the distance matrix tensor 543 | """ 544 | raise NotImplementedError 545 | 546 | def compute_output_shape(self, input_shapes): 547 | _, basis_functions_shape = input_shapes 548 | return basis_functions_shape 549 | 550 | def get_config(self): 551 | config = {"cutoff": self.cutoff} 552 | base_config = super(CutoffLayer, self).get_config() 553 | return dict(list(base_config.items()) + list(config.items())) 554 | 555 | 556 | class CosineCutoff(CutoffLayer): 557 | """The cosine cutoff originally proposed by Behler et al. for ACSFs. 558 | """ 559 | 560 | def cutoff_function(self, distance_matrix): 561 | cos_component = 0.5 * (1 + K.cos(np.pi * distance_matrix / self.cutoff)) 562 | return K.switch( 563 | distance_matrix <= self.cutoff, cos_component, K.zeros_like(distance_matrix) 564 | ) 565 | 566 | 567 | class TanhCutoff(CutoffLayer): 568 | """Alternate tanh^3 cutoff function mentioned in some of the ACSF papers. 
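    In functional form: f(r) = tanh^3(1 - r / cutoff) / tanh^3(1) for r <= cutoff
    and 0 otherwise, so the function is 1 at r = 0 and decays smoothly to 0 at
    the cutoff.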
569 | """ 570 | 571 | def cutoff_function(self, distance_matrix): 572 | normalization_factor = 1.0 / (K.tanh(1.0) ** 3) 573 | tanh_component = (K.tanh(1.0 - (distance_matrix / self.cutoff))) ** 3 574 | return K.switch( 575 | distance_matrix <= self.cutoff, 576 | normalization_factor * tanh_component, 577 | K.zeros_like(distance_matrix), 578 | ) 579 | 580 | 581 | class LongTanhCutoff(CutoffLayer): 582 | """Custom tanh cutoff function that keeps symmetry functions relatively unscaled 583 | longer than the previously proposed tanh function 584 | """ 585 | 586 | def cutoff_function(self, distance_matrix): 587 | normalization_factor = 1.0 / (K.tanh(float(self.cutoff)) ** 3) 588 | tanh_component = (K.tanh(self.cutoff - distance_matrix)) ** 3 589 | return K.switch( 590 | distance_matrix <= self.cutoff, 591 | normalization_factor * tanh_component, 592 | K.zeros_like(distance_matrix), 593 | ) 594 | 595 | 596 | tf.keras.utils.get_custom_objects().update( 597 | { 598 | OneHot.__name__: OneHot, 599 | DistanceMatrix.__name__: DistanceMatrix, 600 | KernelBasis.__name__: KernelBasis, 601 | GaussianBasis.__name__: GaussianBasis, 602 | AtomicNumberBasis.__name__: AtomicNumberBasis, 603 | Unstandardization.__name__: Unstandardization, 604 | DummyAtomMasking.__name__: DummyAtomMasking, 605 | CutoffLayer.__name__: CutoffLayer, 606 | CosineCutoff.__name__: CosineCutoff, 607 | TanhCutoff.__name__: TanhCutoff, 608 | LongTanhCutoff.__name__: LongTanhCutoff, 609 | } 610 | ) 611 | -------------------------------------------------------------------------------- /tfn/layers/molecular_layers.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Layer 2 | 3 | from .atomic_images import DummyAtomMasking 4 | 5 | from . import Convolution, SelfInteraction, EquivariantActivation 6 | 7 | 8 | class MolecularLayer(Layer): 9 | def build(self, input_shape): 10 | if len(input_shape) < self.total_required_inputs: 11 | raise ValueError( 12 | "Ensure one_hot tensor is passed before other relevant tensors" 13 | ) 14 | _, *input_shape = input_shape 15 | super().build(input_shape) 16 | 17 | def call(self, inputs, **kwargs): 18 | """ 19 | :param inputs: List of tensors, with 'one_hot' as the first tensor 20 | :return: Output tensors of shape (batch, points, si_units, representation_index) 21 | with dummy atom values zeroed. 22 | """ 23 | one_hot, *inputs = inputs 24 | activated_output = super().call(inputs, **kwargs) 25 | if not isinstance(activated_output, list): 26 | activated_output = [activated_output] 27 | return [DummyAtomMasking()([one_hot, tensor]) for tensor in activated_output] 28 | 29 | 30 | class MolecularConvolution(MolecularLayer, Convolution): 31 | """ 32 | Input: 33 | one_hot (batch, points, depth) 34 | image (batch, points, points, basis_functions) 35 | vectors (batch, points, points, 3) 36 | feature_tensors [(batch, points, features_dim, representation_index), ...] 37 | Output: 38 | [(batch, points, si_units, representation_index), ...] 39 | """ 40 | 41 | def __init__(self, *args, **kwargs): 42 | self.total_required_inputs = 4 43 | super().__init__(*args, **kwargs) 44 | 45 | 46 | class MolecularSelfInteraction(MolecularLayer, SelfInteraction): 47 | """ 48 | Input: 49 | one_hot (batch, points, depth) 50 | feature_tensors [(batch, points, features_dim, representation_index), ...] 51 | Output: 52 | [(batch, points, si_units, representation_index), ...] 
53 | """ 54 | 55 | def __init__(self, *args, **kwargs): 56 | self.total_required_inputs = 2 57 | super().__init__(*args, **kwargs) 58 | 59 | 60 | class MolecularActivation(MolecularLayer, EquivariantActivation): 61 | """ 62 | Input: 63 | one_hot (batch, points, depth) 64 | feature_tensors [(batch, points, features_dim, representation_index), ...] 65 | Output: 66 | [(batch, points, si_units, representation_index), ...] 67 | """ 68 | 69 | def __init__(self, *args, **kwargs): 70 | self.total_required_inputs = 2 71 | super().__init__(*args, **kwargs) 72 | -------------------------------------------------------------------------------- /tfn/layers/radial_factories.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import tensorflow as tf 4 | from tensorflow.keras import activations, regularizers, Sequential 5 | from tensorflow.keras.layers import Layer 6 | 7 | 8 | class RadialFactory(object): 9 | """ 10 | Abstract class for RadialFactory objects, defines the interface. Subclass 11 | """ 12 | 13 | def __init__( 14 | self, 15 | num_layers: int = 2, 16 | units: int = 32, 17 | activation: str = "ssp", 18 | l2_lambda: float = 0.0, 19 | **kwargs, 20 | ): 21 | self.num_layers = num_layers 22 | self.units = units 23 | if activation is None: 24 | activation = "ssp" 25 | if isinstance(activation, str): 26 | self.activation = activation 27 | else: 28 | raise ValueError( 29 | "Expected `str` for param `activation`, but got `{}` instead. " 30 | "Ensure `activation` is a string mapping to a valid keras " 31 | "activation function" 32 | ) 33 | self.l2_lambda = l2_lambda 34 | self.sum_points = kwargs.pop("sum_points", False) 35 | self.dispensed_radials = 0 36 | 37 | def get_radial(self, feature_dim, input_order=None, filter_order=None): 38 | raise NotImplementedError 39 | 40 | def to_json(self): 41 | self.__dict__["type"] = type(self).__name__ 42 | return json.dumps(self.__dict__) 43 | 44 | @classmethod 45 | def from_json(cls, config: str): 46 | raise NotImplementedError 47 | 48 | 49 | class DenseRadialFactory(RadialFactory): 50 | """ 51 | Default factory class for supplying radial functions to a Convolution layer. Subclass this 52 | factory and override its `get_radial` method to return custom radial instances/templates. 53 | You must also override the `to_json` and `from_json` and register any custom `RadialFactory` 54 | classes to a unique string in the keras global custom objects dict. 55 | """ 56 | 57 | def get_radial(self, feature_dim, input_order=None, filter_order=None): 58 | """ 59 | Factory method for obtaining radial functions of a specified architecture, or an instance 60 | of a radial function (i.e. object which inherits from Layer). 61 | 62 | :param feature_dim: Dimension of the feature tensor being point convolved with the filter 63 | produced by this radial function. Use to ensure radial function outputs a filter of 64 | shape (points, feature_dim, filter_order) 65 | :param input_order: Optional. Rotation order of the of the feature tensor point convolved 66 | with the filter produced by this radial function 67 | :param filter_order: Optional. Rotation order of the filter being produced by this radial 68 | function. 69 | :return: Keras Layer object, or subclass of Layer. Must have attr dynamic == True and 70 | trainable == True. 
71 | """ 72 | layers = [ 73 | Radial( 74 | self.units, 75 | self.activation, 76 | self.l2_lambda, 77 | sum_points=self.sum_points, 78 | name=f"radial_{self.dispensed_radials}/layer_{i}", 79 | ) 80 | for i in range(self.num_layers) 81 | ] 82 | layers.append( 83 | Radial( 84 | feature_dim, 85 | self.activation, 86 | self.l2_lambda, 87 | sum_points=self.sum_points, 88 | name=f"radial_{self.dispensed_radials}/layer_{self.num_layers}", 89 | ) 90 | ) 91 | 92 | self.dispensed_radials += 1 93 | return Sequential(layers) 94 | 95 | @classmethod 96 | def from_json(cls, config: str): 97 | return cls(**json.loads(config)) 98 | 99 | 100 | class Radial(Layer): 101 | def __init__( 102 | self, units: int = 32, activation: str = "ssp", l2_lambda: float = 0.0, **kwargs 103 | ): 104 | self.sum_points = kwargs.pop("sum_points", False) 105 | super().__init__(**kwargs) 106 | self.units = units 107 | self.activation = activations.get(activation) 108 | self.l2_lambda = l2_lambda 109 | self.kernel = None 110 | self.bias = None 111 | 112 | def build(self, input_shape): 113 | self.kernel = self.add_weight( 114 | name="kernel", 115 | shape=(input_shape[-1], self.units), 116 | regularizer=regularizers.l2(self.l2_lambda), 117 | ) 118 | self.bias = self.add_weight( 119 | name="bias", 120 | shape=(self.units,), 121 | regularizer=regularizers.l2(self.l2_lambda), 122 | ) 123 | self.built = True 124 | 125 | def compute_output_shape(self, input_shape): 126 | return tf.TensorShape(list(input_shape)[:-1] + [self.units]) 127 | 128 | def get_config(self): 129 | base = super().get_config() 130 | updates = dict(units=self.units, activation=self.activation,) 131 | return {**base, **updates} 132 | 133 | def call(self, inputs, training=None, mask=None): 134 | equation = "bpf,fu->bpu" if self.sum_points else "bpqf,fu->bpqu" 135 | return self.activation(tf.einsum(equation, inputs, self.kernel) + self.bias) 136 | -------------------------------------------------------------------------------- /tfn/layers/utility_layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Layer 3 | 4 | from .atomic_images import ( 5 | OneHot, 6 | DistanceMatrix, 7 | CosineBasis, 8 | ShiftedCosineBasis, 9 | CosineCutoff, 10 | GaussianBasis, 11 | DummyAtomMasking, 12 | ) 13 | 14 | from . import utils 15 | 16 | 17 | class Preprocessing(Layer): 18 | """ 19 | Convenience layer for obtaining required tensors from cartesian and point type tensors. 20 | Defaults to gaussians for image basis functions. 21 | 22 | Input: 23 | cartesian positions (batch, points, 3) 24 | point types (batch, point_type) 25 | Output: 26 | one_hot (batch, points, depth) 27 | image (batch, points, points, basis_functions) 28 | vectors (batch, points, points, 3) 29 | 30 | :param max_z: int. Total number of point types + 1 (for 0 type points) 31 | :param basis_config: dict. Contains: 'width' which specifies the size of the gaussian basis 32 | functions, 'spacing' which defines the size of the grid, 'min_value' which specifies the 33 | beginning point probed by the grid, and 'max_value' which defines the end point of the grid. 
34 | """ 35 | 36 | def __init__(self, max_z, basis_config=None, basis_type="gaussian", **kwargs): 37 | self.sum_points = kwargs.pop("sum_points", False) 38 | super().__init__(**kwargs) 39 | self.max_z = max_z 40 | if basis_config is None: 41 | basis_config = { 42 | "width": 0.2, 43 | "spacing": 0.2, 44 | "min_value": -1.0, 45 | "max_value": 15.0, 46 | } 47 | self.basis_config = basis_config 48 | self.basis_type = basis_type 49 | self.one_hot = OneHot(self.max_z) 50 | if self.basis_type == "cosine": 51 | basis_function = CosineBasis(**self.basis_config) 52 | elif self.basis_type == "shifted_cosine": 53 | basis_function = ShiftedCosineBasis(**self.basis_config) 54 | else: 55 | basis_function = GaussianBasis(**self.basis_config) 56 | self.basis_function = basis_function 57 | self.cutoff = CosineCutoff(cutoff=kwargs.pop("cutoff", 15.0)) 58 | self.distance_matrix = MaskedDistanceMatrix() 59 | self.unit_vectors = UnitVectors(self.sum_points) 60 | 61 | def call(self, inputs, **kwargs): 62 | z, r = inputs 63 | one_hot = self.one_hot(z) 64 | dist_matrix = self.distance_matrix([one_hot, r]) 65 | # (batch, points, points, basis_functions) 66 | rbf = self.cutoff([dist_matrix, self.basis_function(dist_matrix)]) 67 | # (batch, points, points, 3) 68 | vectors = self.unit_vectors(r) 69 | if self.sum_points: 70 | rbf = tf.reduce_sum(rbf, axis=-2) 71 | return [one_hot, rbf, vectors] 72 | 73 | def get_config(self): 74 | base = super().get_config() 75 | updates = dict(max_z=self.max_z, basis_config=self.basis_config) 76 | return {**base, **updates} 77 | 78 | def compute_output_shape(self, input_shape): 79 | _, r = input_shape 80 | batch, points, _ = r 81 | return [ 82 | tf.TensorShape([batch, points, self.max_z]), 83 | tf.TensorShape([batch, points, points, self.basis_function._n_centers]), 84 | tf.TensorShape([batch, points, points, 3]), 85 | ] 86 | 87 | 88 | class UnitVectors(Layer): 89 | """ 90 | Input: 91 | cartesian positions (..., batch, points, 3) 92 | Output: 93 | unit vectors between every point in every batch (..., batch, point, point, 3) 94 | """ 95 | 96 | def __init__(self, sum_points: bool = False, **kwargs): 97 | self.sum_points = sum_points 98 | super().__init__(**kwargs) 99 | 100 | def call(self, inputs, **kwargs): 101 | if self.sum_points: 102 | v = inputs 103 | else: 104 | i = tf.expand_dims(inputs, axis=-2) 105 | j = tf.expand_dims(inputs, axis=-3) 106 | v = i - j 107 | den = utils.norm_with_epsilon(v, axis=-1, keepdims=True) 108 | return v / den 109 | 110 | 111 | class MaskedDistanceMatrix(DistanceMatrix): 112 | def call(self, inputs, **kwargs): 113 | one_hot, inputs = inputs 114 | d = super().call(inputs, **kwargs) 115 | return DummyAtomMasking(atom_axes=1)( 116 | [one_hot, DummyAtomMasking(atom_axes=2)([one_hot, d])] 117 | ) 118 | -------------------------------------------------------------------------------- /tfn/layers/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module containing basic utility functions for the TFN layers 3 | """ 4 | 5 | import math 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | 10 | def norm_with_epsilon(x, axis=None, keepdims=False): 11 | """ 12 | Normalizes tensor `x` along `axis`. 13 | 14 | :param x: Tensor being normalized 15 | :param axis: int. Defaults to None, which normalizes the entire tensor. Axis to normalize along. 16 | :param keepdims: bool. Defaults to False. 17 | :return: Normalized tensor. 
18 | """ 19 | return tf.sqrt( 20 | tf.maximum(tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims), 1e-7) 21 | ) 22 | 23 | 24 | def rotation_matrix(axis_matrix=None, theta=math.pi / 2): 25 | """ 26 | Return the 3D rotation matrix associated with counterclockwise rotation about 27 | the given `axis` by `theta` radians. 28 | 29 | :param axis_matrix: np.ndarray. Defaults to [1, 0, 0], the x-axis. 30 | :param theta: float. Defaults to pi/2. Rotation in radians. 31 | """ 32 | axis_matrix = axis_matrix or [1, 0, 0] 33 | axis_matrix = np.asarray(axis_matrix) 34 | axis_matrix = axis_matrix / math.sqrt(np.dot(axis_matrix, axis_matrix)) 35 | a = math.cos(theta / 2.0) 36 | b, c, d = -axis_matrix * math.sin(theta / 2.0) 37 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 38 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 39 | return np.array( 40 | [ 41 | [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 42 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 43 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc], 44 | ] 45 | ) 46 | 47 | 48 | def shifted_softplus(x): 49 | y = tf.where(x < 14.0, tf.math.softplus(tf.where(x < 14.0, x, tf.zeros_like(x))), x) 50 | return y - tf.math.log(2.0) 51 | 52 | 53 | def tfn_mae(y_pred, y_true): 54 | loss = tf.abs(y_pred - y_true) 55 | return tf.reduce_mean(loss[loss != 0]) 56 | -------------------------------------------------------------------------------- /tfn/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UPEIChemistry/tensor-field-networks/5c25583ee4108a13af8e73eabd3c448f42cb70a0/tfn/tools/__init__.py -------------------------------------------------------------------------------- /tfn/tools/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sub-package containing all builder classes 3 | """ 4 | from .builder import Builder 5 | from .energy_builder import EnergyBuilder 6 | from .force_builder import ForceBuilder 7 | from .missing_point_builder import MissingPointBuilder 8 | from .cartesian_builder import CartesianBuilder 9 | from .classifier_builder import ClassifierBuilder 10 | from .siamese_builder import SiameseBuilder 11 | -------------------------------------------------------------------------------- /tfn/tools/builders/builder.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple 2 | 3 | import tensorflow as tf 4 | from tensorflow.keras import Model 5 | from tensorflow.keras.layers import Add, Input, Lambda 6 | from tfn.layers import ( 7 | DenseRadialFactory, 8 | MolecularConvolution, 9 | Preprocessing, 10 | RadialFactory, 11 | MolecularSelfInteraction, 12 | MolecularActivation, 13 | ) 14 | 15 | 16 | class Builder(object): 17 | def __init__( 18 | self, 19 | max_z: int, 20 | num_points: int, 21 | name: str = "model", 22 | mu: Union[int, list] = None, 23 | sigma: Union[int, list] = None, 24 | standardize: bool = True, 25 | trainable_offsets: bool = True, 26 | embedding_units: int = 32, 27 | radial_factory: Union[RadialFactory, str] = DenseRadialFactory(), 28 | num_layers: Union[int, Tuple[int]] = (3,), 29 | si_units: Union[int, Tuple[int]] = (64, 32, 16), 30 | output_orders: list = None, 31 | residual: bool = True, 32 | activation: str = "ssp", 33 | dynamic: bool = True, 34 | **kwargs, 35 | ): 36 | self.max_z = max_z 37 | self.num_points = num_points 38 | self.name = name 39 | self.mu = mu 40 | self.sigma = sigma 41 | 
self.standardize = standardize 42 | self.trainable_offsets = trainable_offsets 43 | self.embedding_units = embedding_units 44 | self.radial_factory = radial_factory 45 | 46 | if not isinstance(num_layers, tuple): 47 | num_layers = (num_layers,) 48 | self.num_layers = num_layers 49 | if not isinstance(si_units, tuple): 50 | si_units = self.tuplize_si_units(si_units, self.num_layers) 51 | self.si_units = si_units 52 | 53 | self.output_orders = output_orders 54 | self.residual = residual 55 | self.activation = activation 56 | self.dynamic = dynamic 57 | self.sum_points = kwargs.pop("sum_points", False) 58 | 59 | self.use_scalars = kwargs.pop("use_scalars", True) 60 | if not self.use_scalars: 61 | kwargs.setdefault("final_output_orders", [1]) 62 | self.normalize_max = kwargs.pop("normalize_max", None) 63 | self.normalize_min = kwargs.pop("normalize_min", None) 64 | self.num_final_si_layers = kwargs.pop("num_final_si_layers", 0) 65 | self.final_si_units = kwargs.pop("final_si_units", 32) 66 | 67 | self.point_cloud_layer = Preprocessing( 68 | self.max_z, 69 | kwargs.pop("basis_config", None), 70 | kwargs.pop("basis_type", "gaussian"), 71 | sum_points=self.sum_points, 72 | ) 73 | self.model = None 74 | 75 | @staticmethod 76 | def tuplize_si_units(si_units, num_layers): 77 | return tuple(si_units for _ in range(len(num_layers))) 78 | 79 | def normalize_array(self, array): 80 | if self.normalize_max and self.normalize_min: 81 | return (array - self.normalize_max) / ( 82 | self.normalize_min - self.normalize_max 83 | ) 84 | else: 85 | return array 86 | 87 | def get_model(self, use_cache=True): 88 | if self.model is not None and use_cache: 89 | return self.model 90 | inputs = self.get_inputs() 91 | point_cloud, learned_tensors = self.get_learned_output(inputs) 92 | output = self.get_model_output(point_cloud, learned_tensors) 93 | self.model = Model(inputs=inputs, outputs=output, name=self.name) 94 | return self.model 95 | 96 | def get_inputs(self): 97 | return [ 98 | Input([self.num_points,], dtype="int32", name="atomic_nums"), 99 | Input([self.num_points, 3], dtype="float32", name="cartesians"), 100 | ] 101 | 102 | def get_layers(self, **kwargs): 103 | num_layers = kwargs.pop("num_layers", self.num_layers) 104 | si_units = kwargs.pop("si_units", self.si_units) 105 | clusters, skips = [], [] 106 | for cluster_num, num_layers_in_cluster in enumerate(num_layers): 107 | skips.append( 108 | MolecularConvolution( 109 | name=f"cluster{cluster_num}_skip", 110 | radial_factory=self.radial_factory, 111 | si_units=si_units[cluster_num], 112 | output_orders=self.output_orders, 113 | activation=self.activation, 114 | dynamic=self.dynamic, 115 | sum_points=self.sum_points, 116 | ) 117 | ) 118 | layers = [] 119 | for layer_num in range(num_layers_in_cluster): 120 | layers.append( 121 | MolecularConvolution( 122 | name=f"cluster_{cluster_num}/layer_{layer_num}", 123 | radial_factory=self.radial_factory, 124 | si_units=si_units[cluster_num], 125 | output_orders=self.output_orders, 126 | activation=self.activation, 127 | dynamic=self.dynamic, 128 | sum_points=self.sum_points, 129 | ) 130 | ) 131 | clusters.append(layers) 132 | if self.residual: 133 | return clusters, skips 134 | else: 135 | return [layer for cluster in clusters for layer in cluster] 136 | 137 | def get_learned_tensors(self, tensors, point_cloud, clusters=None): 138 | clusters = clusters or self.get_layers() 139 | output = tensors 140 | if self.residual: 141 | clusters, skips = clusters 142 | for cluster, skip in zip(clusters, skips): 143 | shortcut = 
output 144 | for layer in cluster: 145 | output = layer(point_cloud + output) 146 | shortcut = skip(point_cloud + shortcut) 147 | output = [Add()([o, s]) for o, s in zip(output, shortcut)] 148 | else: 149 | for layer_num, layer in enumerate(clusters): 150 | output = layer(point_cloud + output) 151 | 152 | return output 153 | 154 | def make_embedding(self, one_hot, layer: MolecularSelfInteraction = None): 155 | scalar = Lambda(lambda x: tf.expand_dims(x, axis=-1))(one_hot) 156 | vector = Lambda(lambda x: tf.tile(x, (1, 1, 1, 3)))(scalar) 157 | if self.residual: 158 | pre_embedding = [one_hot, scalar, vector] 159 | else: 160 | pre_embedding = [one_hot, scalar] 161 | layer = layer or MolecularSelfInteraction(self.embedding_units) 162 | return layer(pre_embedding) 163 | 164 | def get_learned_output(self, inputs: list): 165 | inputs = [ 166 | inputs[0], 167 | inputs[-1], 168 | ] # General case for a single molecule as input (z, r) 169 | point_cloud = self.point_cloud_layer(inputs) # one_hot, rbf, vectors 170 | embedding = self.make_embedding(one_hot=point_cloud[0]) 171 | output = self.get_learned_tensors(embedding, point_cloud) 172 | return point_cloud, output 173 | 174 | def get_final_output(self, one_hot: tf.Tensor, inputs: list, output_dim: int = 1): 175 | output = inputs 176 | for i in range(self.num_final_si_layers): 177 | output = MolecularSelfInteraction(self.final_si_units, name=f"si_{i}")( 178 | [one_hot] + output 179 | ) 180 | output = MolecularActivation(name=f"ea_{i}")([one_hot] + output) 181 | output = MolecularSelfInteraction( 182 | self.final_si_units, name=f"si_{self.num_final_si_layers}" 183 | )([one_hot] + output) 184 | output = MolecularActivation(name=f"ea_{self.num_final_si_layers}")( 185 | [one_hot] + output 186 | ) 187 | return output 188 | 189 | def get_model_output(self, point_cloud: list, inputs: list): 190 | raise NotImplementedError 191 | 192 | @property 193 | def model_config(self): 194 | if self.model: 195 | return self.model.to_json() 196 | -------------------------------------------------------------------------------- /tfn/tools/builders/cartesian_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model 3 | from tensorflow.keras.layers import Input, Lambda, Add 4 | 5 | from .multi_trunk_builder import DualTrunkBuilder 6 | from ...layers.utility_layers import MaskedDistanceMatrix 7 | 8 | 9 | class CartesianBuilder(DualTrunkBuilder): 10 | def __init__(self, *args, **kwargs): 11 | self.prediction_type = kwargs.pop( 12 | "prediction_type", "cartesians" 13 | ) # or 'vectors' 14 | self.output_type = kwargs.pop( 15 | "output_type", "cartesians" 16 | ) # or 'distance_matrix' 17 | super().__init__(*args, **kwargs) 18 | 19 | def get_inputs(self): 20 | return [ 21 | Input([self.num_points,], name="atomic_nums", dtype="int32"), 22 | Input([self.num_points, 3], name="reactant_cartesians", dtype="float32"), 23 | Input([self.num_points, 3], name="product_cartesians", dtype="float32"), 24 | ] 25 | 26 | def get_model(self, use_cache=True): 27 | if self.model is not None and use_cache: 28 | return self.model 29 | inputs = self.get_inputs() 30 | point_cloud, learned_tensors = self.get_learned_output(inputs) 31 | output = self.get_model_output(point_cloud, learned_tensors) 32 | if self.prediction_type == "vectors": 33 | # mix reactant and product cartesians commutatively 34 | midpoint = Lambda(lambda x: (x[0] + x[1]) / 2, name="midpoint")( 35 | [inputs[1], inputs[2]] 36 | ) 37 | 
output = Add(name="cartesians")([midpoint, output]) # (batch, points, 3) 38 | if self.output_type == "distance_matrix": 39 | output = MaskedDistanceMatrix(name="distance_matrix")( 40 | [point_cloud[0][0], output] 41 | ) # (batch, points, points) 42 | output = Lambda( 43 | lambda x: tf.linalg.band_part(x, 0, -1), name="upper_triangle" 44 | )(output) 45 | self.model = Model(inputs=inputs, outputs=output, name=self.name) 46 | return self.model 47 | 48 | def get_learned_output(self, inputs: list): 49 | z, r, p = inputs 50 | point_clouds = [self.point_cloud_layer([z, x]) for x in (r, p)] 51 | inputs = self.get_dual_trunks(point_clouds) 52 | return point_clouds, inputs 53 | 54 | def get_model_output(self, point_cloud: list, inputs: list): 55 | one_hot, output = self.mix_dual_trunks( 56 | point_cloud, inputs, output_order=1, output_type=self.output_type 57 | ) 58 | output = Lambda(lambda x: tf.squeeze(x, axis=-2), name=self.prediction_type)( 59 | output[0] 60 | ) 61 | return output # (batch, points, 3) 62 | -------------------------------------------------------------------------------- /tfn/tools/builders/classifier_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.keras.layers import Lambda 3 | 4 | from tfn.layers import EquivariantActivation, MolecularConvolution 5 | from tfn.tools.builders import Builder 6 | 7 | 8 | class ClassifierMixIn(object): 9 | @staticmethod 10 | def average_votes(inputs: list): 11 | output = EquivariantActivation(activation="sigmoid")(inputs) 12 | output = Lambda( 13 | lambda x: tf.squeeze(tf.squeeze(x, axis=-1), axis=-1), name="squeeze" 14 | )( 15 | output[0] 16 | ) # (batch, points) 17 | output = Lambda( 18 | lambda x: tf.reduce_mean(x, axis=-1, keepdims=True), 19 | name="molecular_average", 20 | )(output) 21 | return output 22 | 23 | 24 | class ClassifierBuilder(Builder, ClassifierMixIn): 25 | def get_model_output(self, point_cloud: list, inputs: list): 26 | one_hot = point_cloud[0] 27 | output = MolecularConvolution( 28 | name="energy_layer", 29 | radial_factory=self.radial_factory, 30 | si_units=1, # For molecular energy output 31 | activation=self.activation, 32 | output_orders=[0], 33 | dynamic=self.dynamic, 34 | sum_points=self.sum_points, 35 | )(point_cloud + inputs) 36 | output = self.get_final_output(point_cloud[0], output) 37 | return self.average_votes(output) 38 | -------------------------------------------------------------------------------- /tfn/tools/builders/energy_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Lambda 3 | 4 | from tfn.layers.atomic_images import Unstandardization 5 | from tfn.layers import MolecularConvolution 6 | 7 | from . import Builder 8 | 9 | 10 | class EnergyBuilder(Builder): 11 | def get_model_output(self, point_cloud: list, inputs: list): 12 | """ 13 | :return: tf.keras `Model` object. 
Outputs molecular energy tensor of shape (batch, 1) 14 | """ 15 | output = MolecularConvolution( 16 | name="energy_layer", 17 | radial_factory=self.radial_factory, 18 | si_units=1, # For molecular energy output 19 | activation=self.activation, 20 | output_orders=[0], 21 | dynamic=self.dynamic, 22 | sum_points=self.sum_points, 23 | )(point_cloud + inputs) 24 | output = self.get_final_output(point_cloud[0], output) 25 | atomic_energies = Lambda(lambda x: tf.squeeze(x, axis=-1), name="squeeze")( 26 | output[0] 27 | ) 28 | atomic_energies = Unstandardization( 29 | self.mu, self.sigma, trainable=self.trainable_offsets, name="atomic_energy" 30 | )([point_cloud[0], atomic_energies]) 31 | return Lambda(lambda x: tf.reduce_sum(x, axis=-2), name="molecular_energy")( 32 | atomic_energies 33 | ) 34 | -------------------------------------------------------------------------------- /tfn/tools/builders/force_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Lambda 3 | 4 | from tfn.layers.atomic_images import Unstandardization 5 | from . import Builder 6 | 7 | 8 | class ForceBuilder(Builder): 9 | def get_model_output(self, point_cloud: list, inputs: list): 10 | tensors = self.get_final_output(point_cloud[0], inputs) 11 | outputs = [] 12 | if self.use_scalars: 13 | outputs.append(self.sanitize_molecular_energy(tensors[0], point_cloud[0])) 14 | outputs.append(self.sanitize_atomic_forces(tensors[-1])) 15 | return outputs 16 | 17 | def sanitize_molecular_energy(self, energy, one_hot): 18 | atomic_energies = Lambda( 19 | lambda x: tf.squeeze(x, axis=-1), name="energy_squeeze" 20 | )(energy) 21 | atomic_energies = Unstandardization( 22 | self.mu[0], 23 | self.sigma[0], 24 | trainable=self.trainable_offsets, 25 | name="atomic_energy", 26 | )([one_hot, atomic_energies]) 27 | return Lambda(lambda x: tf.reduce_sum(x, axis=-2), name="molecular_energy")( 28 | atomic_energies 29 | ) 30 | 31 | def sanitize_atomic_forces(self, forces): 32 | atomic_forces = Lambda(lambda x: tf.squeeze(x, axis=-2), name="force_squeeze")( 33 | forces 34 | ) 35 | if self.standardize: 36 | atomic_forces = Unstandardization( 37 | self.mu[1], 38 | self.sigma[1], 39 | trainable=self.trainable_offsets, 40 | name="atomic_forces", 41 | )(atomic_forces) 42 | return atomic_forces 43 | -------------------------------------------------------------------------------- /tfn/tools/builders/missing_point_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Lambda 3 | from tfn.layers import MolecularConvolution 4 | 5 | from tfn.tools.builders import Builder 6 | 7 | 8 | class MissingPointBuilder(Builder): 9 | def get_model_output(self, point_cloud: list, inputs: list): 10 | output = MolecularConvolution( 11 | name="dual_layer", 12 | radial_factory=self.radial_factory, 13 | si_units=1, # For molecular energy output 14 | activation=self.activation, 15 | dynamic=self.dynamic, 16 | )(point_cloud + inputs) 17 | 18 | # Get Energies (batch, points, 1, 1) -> (batch, 1) 19 | atomic_energies = Lambda( 20 | lambda x: tf.squeeze(x, axis=-1), name="energy_squeeze" 21 | )(output[0]) 22 | molecular_energy = Lambda( 23 | lambda x: tf.reduce_sum(x, axis=-2), name="molecular_energy" 24 | )(atomic_energies) 25 | 26 | # Get Forces (batch, points, 1, 3) -> (batch, points, 3) 27 | atomic_forces = Lambda(lambda x: tf.squeeze(x, axis=-2), name="force")( 28 | 
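# Illustrative note: output[1] is the order-1 (vector) head of shape
# (batch, points, 1, 3); squeezing the representation axis leaves per-point
# 3-vectors, which are averaged over the points axis below to place the
# missing atom.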
output[1] 29 | ) 30 | missing_atom = Lambda(lambda x: tf.reduce_mean(x, axis=1), name="missing_atom")( 31 | atomic_forces 32 | ) 33 | return molecular_energy, missing_atom 34 | -------------------------------------------------------------------------------- /tfn/tools/builders/multi_trunk_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.keras.layers import Lambda 3 | 4 | from tfn.layers import MolecularConvolution, MolecularSelfInteraction, SelfInteraction 5 | from tfn.tools.builders import Builder 6 | 7 | 8 | class DualTrunkBuilder(Builder): 9 | def get_dual_trunks(self, point_clouds: list): 10 | embedding_layer = MolecularSelfInteraction( 11 | self.embedding_units, name="embedding" 12 | ) 13 | embeddings = [ 14 | self.make_embedding(pc[0], embedding_layer) for pc in point_clouds 15 | ] 16 | layers = self.get_layers() 17 | inputs = [ 18 | self.get_learned_tensors(e, pc, layers) 19 | for e, pc in zip(embeddings, point_clouds) 20 | ] 21 | return inputs 22 | 23 | def mix_dual_trunks( 24 | self, 25 | point_cloud: list, 26 | inputs: list, 27 | output_order: int = 0, 28 | output_type: str = None, 29 | ): 30 | # Select smaller molecule 31 | one_hots = [p[0] for p in point_cloud] # [(batch, points, max_z), ...] 32 | one_hot = Lambda( 33 | lambda x: tf.where(tf.reduce_sum(x[0]) > tf.reduce_sum(x[1]), x[1], x[0],), 34 | name="one_hot_select", 35 | )(one_hots) 36 | # Truncate to RO0 outputs 37 | layer = MolecularConvolution( 38 | name="truncate_layer", 39 | radial_factory=self.radial_factory, 40 | si_units=self.final_si_units, 41 | activation=self.activation, 42 | output_orders=[output_order], 43 | dynamic=self.dynamic, 44 | sum_points=self.sum_points, 45 | ) 46 | outputs = [layer(z + x)[0] for x, z in zip(inputs, point_cloud)] 47 | output_type = output_type or "vectors" 48 | if output_type == "cartesians": 49 | output = Lambda(lambda x: (x[0] + x[1]), name="learned_midpoint")(outputs) 50 | else: 51 | output = Lambda(lambda x: tf.abs(x[1] - x[0]), name="absolute_difference")( 52 | outputs 53 | ) 54 | output = self.get_final_output(one_hot, output) 55 | return one_hot, output 56 | 57 | def get_final_output(self, one_hot: tf.Tensor, inputs: list, output_dim: int = 1): 58 | output = inputs 59 | for i in range(self.num_final_si_layers): 60 | output = SelfInteraction(self.final_si_units, name=f"si_{i}")(output) 61 | return SelfInteraction(output_dim, name=f"si_{self.num_final_si_layers}")( 62 | output 63 | ) 64 | 65 | def get_model_output(self, point_cloud: list, inputs: list): 66 | raise NotImplementedError 67 | -------------------------------------------------------------------------------- /tfn/tools/builders/siamese_builder.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras import Input 2 | 3 | from tfn.tools.builders.classifier_builder import ClassifierMixIn 4 | from tfn.tools.builders.multi_trunk_builder import DualTrunkBuilder 5 | 6 | 7 | class SiameseBuilder(DualTrunkBuilder, ClassifierMixIn): 8 | def get_inputs(self): 9 | return [ 10 | Input([2, self.num_points,], name="atomic_nums", dtype="int32"), 11 | Input([2, self.num_points, 3], name="cartesians", dtype="float32"), 12 | ] 13 | 14 | def get_learned_output(self, inputs: list): 15 | z, c = inputs 16 | point_clouds = [ 17 | self.point_cloud_layer([a, b]) 18 | for a, b in zip( 19 | [z[:, 0], z[:, 1]], [c[:, 0], c[:, 1]] # Split z, c into 4 arrays 20 | ) 21 | ] 22 | inputs = 
self.get_dual_trunks(point_clouds) 23 | return point_clouds, inputs 24 | 25 | def get_model_output(self, point_cloud: list, inputs: list): 26 | one_hot, output = self.mix_dual_trunks(point_cloud, inputs) 27 | return self.average_votes(output) 28 | -------------------------------------------------------------------------------- /tfn/tools/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.keras.callbacks import Callback 8 | from tensorflow.keras.models import Model 9 | from tensorflow.keras.metrics import categorical_accuracy 10 | from tensorflow.keras.losses import get as get_loss 11 | 12 | from .converters import ndarrays_to_xyz 13 | from ..layers import MaskedDistanceMatrix, OneHot 14 | 15 | 16 | class TestModel(Callback): 17 | def __init__(self, x_test, y_test): 18 | super().__init__() 19 | self.x_test = x_test 20 | self.y_test = y_test 21 | 22 | def on_train_end(self, logs=None): 23 | predictions = self.model.evaluate(x=self.x_test, y=self.y_test) 24 | if not isinstance(predictions, list): 25 | predictions = [predictions] 26 | logs = logs if logs is not None else {}  # Keras may pass logs=None here 27 | for pred, name in zip(predictions, self.model.metrics_names): 28 | logs["test_{}".format(name)] = pred 29 | 30 | class CartesianMetrics(Callback): 31 | def __init__( 32 | self, 33 | path: Union[str, Path], 34 | train: list = None, 35 | validation: list = None, 36 | test: list = None, 37 | max_structures: int = 64, 38 | write_rate: int = 50, 39 | tensorboard_logdir: str = None, 40 | ): 41 | """ 42 | :param path: Base path to which .xyz files will be written. Will create subdirectories 43 | at this path as needed. 44 | :param train: Defaults to None. List of train data to write. Of shape [[z, r, p], [ts]]. 45 | :param validation: Defaults to None. Same shape as train. If None, will not write validation .xyz files. 46 | :param test: list. Defaults to None. Same as validation, but for test data. 47 | :param max_structures: int. Defaults to 64. Max number of structures to write .xyz files 48 | for. 49 | :param write_rate: int. Defaults to 50. Number of epochs to take before writing 50 | validation Cartesians (if validation data provided).
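        Example (illustrative sketch; data follows the [[z, r, p], [ts]] layout
        described above):

            metrics = CartesianMetrics("output/cartesians", train, val, test,
                                       write_rate=25)
            model.fit(x_train, y_train, validation_data=val, callbacks=[metrics])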
51 | """ 52 | super().__init__() 53 | self.path = Path(path) 54 | self.train = train 55 | self.validation = validation 56 | self.test = test 57 | self.max_structures = max_structures 58 | self.write_rate = write_rate 59 | logdir = tensorboard_logdir or self.path.parent / "logs" 60 | self.file_writers = [ 61 | tf.summary.create_file_writer(str(logdir / "train")), 62 | tf.summary.create_file_writer(str(logdir / "train")), 63 | tf.summary.create_file_writer(str(logdir / "validation")), 64 | tf.summary.create_file_writer(str(logdir / "validation")), 65 | ] 66 | 67 | self.total_epochs = -1 68 | self._prediction_type = "vectors" 69 | self._output_type = "cartesians" 70 | 71 | def get_vectors(self, inputs): 72 | model = Model( 73 | inputs=self.model.input, outputs=self.model.get_layer("vectors").output, 74 | ) 75 | vectors = model.predict([np.expand_dims(a, axis=0) for a in inputs]) 76 | return np.squeeze(vectors, axis=0) 77 | 78 | @staticmethod 79 | def write_vectors(vectors, path, i): 80 | vector_path = path / "vectors" 81 | os.makedirs(vector_path, exist_ok=True) 82 | np.savetxt(vector_path / f"{i}_vectors.txt", vectors) # Write vectors .txt file 83 | 84 | def _get_prediction(self, x): 85 | if self._output_type == "distance_matrix": 86 | model = Model( 87 | inputs=self.model.input, 88 | outputs=self.model.get_layer("cartesians").output, 89 | ) 90 | else: 91 | model = self.model 92 | return model.predict(x) 93 | 94 | def _unwrap_data_lazily(self, data: list): 95 | """ 96 | :param data: data in format (x, y) where x is [atomic_nums, reactants, products] and y is 97 | [ts_cartesians]. 98 | :return: i, z, r, p, ts_true, ts_pred 99 | """ 100 | predicted_transition_states = self._get_prediction(data[0]) 101 | ((atomic_nums, reactants, products), (true_transition_states,),) = data 102 | for i, structures in enumerate( 103 | zip( 104 | atomic_nums[: self.max_structures], 105 | reactants[: self.max_structures], 106 | products[: self.max_structures], 107 | true_transition_states[: self.max_structures], 108 | predicted_transition_states[: self.max_structures], 109 | ) 110 | ): 111 | output = [np.expand_dims(a, 0) for a in structures] 112 | output.insert(0, i) 113 | yield output 114 | 115 | def loss(self, a, b): 116 | if a.shape != b.shape: 117 | return 0 118 | else: 119 | return np.mean(get_loss(self.model.loss)(a, b)) 120 | 121 | @staticmethod 122 | def structure_loss(z, y_pred, y_true): 123 | d = MaskedDistanceMatrix() 124 | one_hot = OneHot(np.max(z) + 1)(z) 125 | dist_matrix = np.abs(d([one_hot, y_pred]) - d([one_hot, y_true])) 126 | dist_matrix = np.triu(dist_matrix) 127 | return ( 128 | float(np.mean(dist_matrix[dist_matrix != 0])), 129 | float(np.mean(np.sum(np.sum(dist_matrix, axis=-1), axis=-1), axis=0)), 130 | ) 131 | 132 | def write_cartesians( 133 | self, data: list, path: Path, write_static_structures: bool = False 134 | ): 135 | for i, z, r, p, true, pred in self._unwrap_data_lazily(data): 136 | # Make .xyz message lines 137 | arrays = ( 138 | {"reactant": r, "product": p, "true": true, "predicted": pred} 139 | if write_static_structures 140 | else {"predicted": pred} 141 | ) 142 | for name, array in arrays.items(): 143 | loss = self.loss(array, true) 144 | error = self.structure_loss(z, array, true)[0] if name != "true" else 0 145 | message = f"loss: {loss} " f"-- distance_error: {error} " 146 | ndarrays_to_xyz(array[0], z[0], path / f"{i}_{name}.xyz", message) 147 | 148 | # Add vector information if relevant 149 | if self._prediction_type == "vectors": 150 | # Write vectors 151 | vectors 
= self.get_vectors([z, r, p]) 152 | self.write_vectors(vectors, path, i) 153 | 154 | def compute_metrics(self, epoch, split: str = "train"): 155 | if split == "train": 156 | data = self.train 157 | file_writers = self.file_writers[:2] 158 | else: 159 | data = self.validation 160 | file_writers = self.file_writers[2:] 161 | self.write_metrics( 162 | {"distance_error": self._compute_metrics(data)[0]}, 163 | epoch, 164 | file_writers, 165 | split, 166 | ) 167 | 168 | def _compute_metrics(self, data): 169 | z = data[0][0] 170 | y_true = data[1][0] 171 | y_pred = self.model.predict(data[0]) 172 | return self.structure_loss(z, y_pred, y_true) 173 | 174 | def _writing(self, epoch): 175 | return ( 176 | False 177 | if self.write_rate <= 0 178 | else (epoch in range(25) or (epoch + 1) % self.write_rate == 0) 179 | ) 180 | 181 | def write_metrics( 182 | self, 183 | metrics: dict, 184 | epoch: int, 185 | file_writers: list = None, 186 | prefix: str = "scalar", 187 | ): 188 | file_writers = file_writers or self.file_writers 189 | print( 190 | " -- ".join( 191 | [ 192 | f"{prefix}_{name}: {round(metric, 8)}" 193 | for name, metric in metrics.items() 194 | ] 195 | ) 196 | ) 197 | 198 | if self.write_rate > 0: 199 | for writer, (name, metric) in zip(file_writers, metrics.items()): 200 | with writer.as_default(): 201 | tf.summary.scalar(name, metric, epoch) 202 | 203 | def on_train_begin(self, logs=None): 204 | data = self.validation or self.test 205 | if data is None: 206 | return 207 | else: 208 | if data[1][-1].shape[-1] != 3: 209 | self._output_type = "distance_matrix" 210 | if "vectors" not in [layer.name for layer in self.model.layers]: 211 | self._prediction_type = "cartesians" 212 | 213 | self.write_cartesians( 214 | self.train, 215 | self.path / "pre_training" / "train", 216 | write_static_structures=True, 217 | ) 218 | self.write_cartesians( 219 | self.validation, 220 | self.path / "pre_training" / "validation", 221 | write_static_structures=True, 222 | ) 223 | 224 | def on_epoch_end(self, epoch, logs=None): 225 | if self.write_rate <= 0: 226 | return 227 | 228 | if self._output_type == "cartesians": 229 | self.compute_metrics(epoch, "train") 230 | self.compute_metrics(epoch, "validation") 231 | 232 | if self._writing(epoch): 233 | self.total_epochs = epoch 234 | if self.validation is not None and self.train is not None: 235 | 236 | self.write_cartesians( 237 | self.train, self.path / "epochs" / f"epoch_{epoch + 1}" / "train", 238 | ) 239 | self.write_cartesians( 240 | self.validation, 241 | self.path / "epochs" / f"epoch_{epoch + 1}" / "validation", 242 | ) 243 | 244 | def on_train_end(self, logs=None): 245 | for name, data in zip( 246 | ["train", "val", "test"], [self.train, self.validation, self.test] 247 | ): 248 | if data is not None: 249 | print( 250 | f"final {name} loss: {round(self.model.evaluate(*data, verbose=0), 8)}" 251 | ) 252 | if self._output_type == "cartesians": 253 | self.compute_metrics(self.total_epochs + 1, name) 254 | if self.write_rate > 0: 255 | self.write_cartesians(data, self.path / "post_training" / name) 256 | 257 | 258 | class ClassificationMetrics(Callback): 259 | def __init__(self, validation, log_dir): 260 | super().__init__() 261 | self.validation = validation 262 | self.val_f1s = None 263 | self.val_recalls = None 264 | self.val_precisions = None 265 | self.file_writer = tf.summary.create_file_writer(log_dir + "/metrics") 266 | self.file_writer.set_as_default() 267 | 268 | def on_epoch_end(self, epoch, logs=None): 269 | target = self.validation[1] 270 | 
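# Illustrative note: the precision/recall helpers below are counts-based and
# micro-averaged over every element of the one-hot arrays; the 1e-7 terms
# guard against division by zero.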
prediction = np.asarray(self.model.predict(self.validation[0])) 271 | f1score = self.f1_score(target, prediction) 272 | precision = self.precision(target, prediction) 273 | recall = self.recall(target, prediction) 274 | accuracy = np.mean(categorical_accuracy(target, prediction)) 275 | tf.summary.scalar("f1score", f1score, epoch) 276 | tf.summary.scalar("precision", precision, epoch) 277 | tf.summary.scalar("recall", recall, epoch) 278 | tf.summary.scalar("accuracy", accuracy, epoch) 279 | print( 280 | f"Metrics for epoch {epoch}:" 281 | f" -- val_f1score: {f1score} -- val_precision: {precision} -- val_recall: {recall} " 282 | f" -- val_accuracy: {accuracy}" 283 | ) 284 | 285 | def f1_score(self, y_true, y_pred): 286 | recall = self.recall(y_true, y_pred) 287 | precision = self.precision(y_true, y_pred) 288 | return 2 * ((precision * recall) / (precision + recall + 1e-7)) 289 | 290 | def recall(self, y_true, y_pred): 291 | true_positives = np.sum(np.round(np.clip(y_true * y_pred, 0, 1))) 292 | possible_positives = np.sum(np.round(np.clip(y_true, 0, 1))) 293 | return true_positives / (possible_positives + 1e-7) 294 | 295 | def precision(self, y_true, y_pred): 296 | true_positives = np.sum(np.round(np.clip(y_true * y_pred, 0, 1))) 297 | predicted_positives = np.sum(np.round(np.clip(y_pred, 0, 1))) 298 | return true_positives / (predicted_positives + 1e-7) 299 | -------------------------------------------------------------------------------- /tfn/tools/converters.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | 7 | def element_mapping(): 8 | mapping = { 9 | "C": 6, 10 | "H": 1, 11 | "B": 5, 12 | "Br": 35, 13 | "Cl": 17, 14 | "D": 0, 15 | "F": 9, 16 | "I": 53, 17 | "N": 7, 18 | "O": 8, 19 | "P": 15, 20 | "S": 16, 21 | "Se": 34, 22 | "Si": 14, 23 | } 24 | reverse_mapping = dict([reversed(pair) for pair in mapping.items()]) 25 | mapping.update(reverse_mapping) 26 | return mapping 27 | 28 | 29 | def parse_xyz(f: str): 30 | """ 31 | :param f: str. Path to .xyz file. 32 | :return: List[tuple]. List of atomic coordinates. Tuple indices: (atomic_num, x, y, z) 33 | """ 34 | with open(f, "r") as file: 35 | lines = file.readlines()[2:] 36 | coords = [] 37 | for l in lines: 38 | element, x, y, z = l.split() 39 | if not element.isdigit(): 40 | element = element_mapping()[element] 41 | else: 42 | element = int(element) 43 | coords.append((element, float(x), float(y), float(z))) 44 | return coords 45 | 46 | 47 | def xyz_to_ndarray(path): 48 | """Single .xyz file to cartesian, atomic_nums arrays""" 49 | coordinates = parse_xyz(path) 50 | coordinate_array = np.array(coordinates) 51 | cartesians = coordinate_array[:, 1:] 52 | atomic_nums = coordinate_array[:, :1].astype("int").reshape((-1,)) 53 | return cartesians, atomic_nums 54 | 55 | 56 | def ndarrays_to_xyz(c, z, path, message: "str" = None): 57 | """ 58 | 59 | :param c: cartesian array of shape (points, 3) 60 | :param z: atomic_nums arrray of shape (points,) 61 | :param path: path to .xyz file 62 | :param message: str to add to message portion of .xyz. 
Defaults to None 63 | :return: 64 | """ 65 | os.makedirs(Path(path).parent, exist_ok=True) 66 | message = message or "" 67 | first_dummy_atom = np.where(z == 0)[0][0] 68 | coordinates = np.concatenate([z.reshape((-1, 1)), c], axis=-1) # (atoms, 4) 69 | text = [ 70 | " ".join([element_mapping()[c[0]], str(c[1]), str(c[2]), str(c[3]),]) 71 | for c in coordinates[:first_dummy_atom] 72 | ] 73 | with open(path, "w") as file: 74 | file.write(str(len(z[:first_dummy_atom])) + "\n") 75 | file.write(f"{message}\n") 76 | file.write("\n".join(text)) 77 | file.write("\n") 78 | -------------------------------------------------------------------------------- /tfn/tools/ingredients.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from sacred import Ingredient 4 | 5 | from .builders import ( 6 | EnergyBuilder, 7 | ForceBuilder, 8 | CartesianBuilder, 9 | SiameseBuilder, 10 | ClassifierBuilder, 11 | ) 12 | from .loaders import ISO17DataLoader, QM9DataDataLoader, TSLoader, SN2Loader, IsomLoader 13 | from .loggers import SacredMetricLogger 14 | from .radials import get_radial_factory 15 | 16 | 17 | # ===== Dataset Ingredient(s) ===== # 18 | data_ingredient = Ingredient("data_loader") 19 | 20 | 21 | @data_ingredient.capture 22 | def get_data_loader( 23 | loader_type: str = "qm9_loader", **kwargs, 24 | ): 25 | """ 26 | :param loader_type: str. Defaults to 'qm9_loader'. Used to specify which loader type is 27 | being used. Supported identifiers: 'qm9_loader', 'iso17_loader', 'ts_loader', 28 | 'isom_loader', 'sn2_loader' 29 | :param kwargs: kwargs passed directly to Loader classes 30 | :return: DataLoader object specified by `loader_type` 31 | """ 32 | if loader_type == "qm9_loader": 33 | return QM9DataDataLoader(**kwargs) 34 | elif loader_type == "iso17_loader": 35 | return ISO17DataLoader(**kwargs) 36 | elif loader_type == "ts_loader": 37 | return TSLoader(**kwargs) 38 | elif loader_type == "isom_loader": 39 | return IsomLoader(**kwargs) 40 | elif loader_type == "sn2_loader": 41 | return SN2Loader(**kwargs) 42 | else: 43 | raise ValueError( 44 | "arg `loader_type` had value: {} which is not supported. " 45 | "Check ingredient docs for supported strings " 46 | "identifiers".format(loader_type) 47 | ) 48 | 49 | 50 | # ===== Builder Ingredient(s) ===== # 51 | builder_ingredient = Ingredient("model_builder") 52 | 53 | 54 | @builder_ingredient.capture 55 | def get_builder( 56 | builder_type: str = "energy_builder", **kwargs, 57 | ): 58 | """ 59 | 60 | :param builder_type: str. Defaults to 'energy_builder'. Possible values include: 61 | 'energy_builder', 'force_builder', 'ts_builder'. 62 | :param kwargs: kwargs passed directly to Builder classes 63 | :return: Builder object specified by 'builder_type' 64 | """ 65 | kwargs["radial_factory"] = get_radial_factory( 66 | kwargs.get("radial_factory", "multi_dense"), kwargs.get("radial_kwargs", None) 67 | ) 68 | if builder_type == "energy_builder": 69 | return EnergyBuilder(**kwargs) 70 | elif builder_type == "force_builder": 71 | return ForceBuilder(**kwargs) 72 | elif builder_type == "cartesian_builder": 73 | return CartesianBuilder(**kwargs) 74 | elif builder_type == "siamese_builder": 75 | return SiameseBuilder(**kwargs) 76 | elif builder_type == "classifier_builder": 77 | return ClassifierBuilder(**kwargs) 78 | else: 79 | raise ValueError( 80 | "arg `builder_type` had value: {} which is not supported. 
Check " 81 | "ingredient docs for supported string identifiers".format(builder_type) 82 | ) 83 | 84 | 85 | # ===== Logger Ingredient(s) ===== # 86 | logger_ingredient = Ingredient("metric_logger") 87 | get_logger = logger_ingredient.capture(SacredMetricLogger) 88 | -------------------------------------------------------------------------------- /tfn/tools/jobs/__init__.py: -------------------------------------------------------------------------------- 1 | from .job import Job 2 | from .keras_job import KerasJob 3 | from .regression import Regression, StructurePrediction 4 | from .classification import Classification 5 | from .pipeline import Pipeline 6 | from .search import GridSearch 7 | from .cross_validate import CrossValidate 8 | from .load_model import LoadModel 9 | -------------------------------------------------------------------------------- /tfn/tools/jobs/classification.py: -------------------------------------------------------------------------------- 1 | from sacred.run import Run 2 | from tensorflow.keras.models import Model 3 | 4 | from . import KerasJob 5 | from ..callbacks import ClassificationMetrics 6 | 7 | 8 | class Classification(KerasJob): 9 | def _fit( 10 | self, run: Run, fitable: Model, data: tuple, callbacks: list = None, 11 | ) -> Model: 12 | return super()._fit( 13 | run, 14 | fitable, 15 | data, 16 | callbacks=[ClassificationMetrics(data[-1], run.observers[0].dir + "/logs")], 17 | ) 18 | -------------------------------------------------------------------------------- /tfn/tools/jobs/config_defaults.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | # General to all Jobs 5 | run_config = { # Defines kwargs used when running a job 6 | "storage_dir": "./sacred_storage", 7 | "model_path": "./model.hdf5", 8 | "epochs": 100, 9 | "batch_size": 32, 10 | "test": True, 11 | "write_test_results": False, 12 | "save_model": True, 13 | "use_strategy": False, 14 | "loss": "tfn_mae", 15 | "optimizer": "adam", 16 | "loss_weights": None, 17 | "class_weight": None, 18 | "metrics": None, 19 | "run_eagerly": False, 20 | "select_few": False, 21 | "capture_output": True, 22 | "num_models_to_test": 1, 23 | "fit_verbosity": 2, 24 | "save_verbosity": 0, 25 | "freeze_layers": False, 26 | "use_default_callbacks": True, 27 | "root_dir": None, 28 | } 29 | loader_config = { # Passed directly to loader classes 30 | "loader_type": "qm9_loader", 31 | "map_points": True, 32 | "splitting": "75:15:10", 33 | "pre_load": False, 34 | "load_kwargs": {"cache": True}, 35 | } 36 | 37 | # SingleModel Only 38 | builder_config = { # Passed directly to builder classes 39 | "builder_type": "energy_builder", 40 | "standardize": True, 41 | "trainable_offsets": True, 42 | "name": "model", 43 | "embedding_units": 64, 44 | "num_layers": (2, 2, 2), 45 | "si_units": 64, 46 | "max_filter_order": 1, 47 | "residual": True, 48 | "activation": "ssp", 49 | "dynamic": False, 50 | "sum_points": False, 51 | "basis_type": "gaussian", 52 | "basis_config": { # 800 functions 53 | "width": 0.2, 54 | "spacing": 0.02, 55 | "min_value": -1.0, 56 | "max_value": 15.0, 57 | }, 58 | "num_final_si_layers": 1, 59 | "final_si_units": 32, 60 | "radial_factory": "multi_dense", 61 | "radial_kwargs": { 62 | "num_layers": 2, 63 | "units": 64, 64 | "activation": "ssp", 65 | "kernel_lambda": 0.01, 66 | "bias_lambda": 0.01, 67 | }, 68 | } 69 | 70 | # Search Only 71 | factory_config = dict( # Passed directly to factory class 72 | **builder_config, **run_config, 
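    # NOTE: dict(**a, **b) raises TypeError on duplicate keys, so these config
    # dicts must remain key-disjoint.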
**{"factory_type": "energy_factory"} 73 | ) 74 | 75 | # Pipeline only 76 | pipeline_config = { 77 | "configs": [ 78 | {"builder_config": builder_config, "loader_config": loader_config}, 79 | {"builder_config": builder_config, "loader_config": loader_config}, 80 | ], 81 | "freeze_layers": False, 82 | } 83 | 84 | tuner_config = { # Passed directly to tuner classes 85 | "project_name": datetime.utcnow().strftime("%Y-%m-%d-%H%MZ"), 86 | "directory": "./tuner_storage", 87 | "objective": "val_loss", 88 | } 89 | 90 | # Callbacks 91 | tb_config = { # Passed directly to tensorboard callback 92 | "histogram_freq": 10, 93 | "update_freq": "epoch", 94 | "write_images": False, 95 | } 96 | lr_config = { # Passed directly to ReduceLROnPlateau callback 97 | "monitor": "val_loss", 98 | "factor": 0.5, 99 | "patience": 20, 100 | "verbose": 1, 101 | "min_delta": 0.0001, 102 | "cooldown": 20, 103 | "min_lr": 0.000001, 104 | } 105 | 106 | cm_config = {"max_structures": 64, "write_rate": 50} 107 | 108 | 109 | # Precanned search-spaces 110 | default_architecture_search = { # 1994 possible models 111 | "si_units": {"type": "choice", "kwargs": {"values": [16, 32, 64, 128]}}, 112 | "model_num_layers": { 113 | "type": "choice", 114 | "kwargs": {"values": [4, 8, 16, 32, 64, 128]}, 115 | }, 116 | "num_final_si_layers": {"type": "choice", "kwargs": {"values": [0, 1, 2]}}, 117 | "final_si_units": {"type": "choice", "kwargs": {"values": [8, 16, 32]}}, 118 | "radial_num_layers": {"type": "choice", "kwargs": {"values": [1, 2, 3]}}, 119 | "radial_units": {"type": "choice", "kwargs": {"values": [16, 32, 64]}}, 120 | } 121 | 122 | 123 | default_grid_search = { # 432 models 124 | "residual": [True, False], 125 | "model_num_layers": [[2 for _ in range(i + 1)] for i in [0, 2, 8, 16]], # 4 126 | "num_final_si_layers": [1, 2, 3], # 3 127 | "final_si_units": [16, 32, 64], # 3 128 | "radial_factory": ["single_dense", "multi_dense"], # 2 129 | "radial_kwargs": [ # 3 130 | { 131 | "num_layers": 1, 132 | "units": 64, 133 | "activation": "ssp", 134 | "kernel_lambda": 0.01, 135 | "bias_lambda": 0.01, 136 | }, 137 | { 138 | "num_layers": 2, 139 | "units": 64, 140 | "activation": "ssp", 141 | "kernel_lambda": 0.01, 142 | "bias_lambda": 0.01, 143 | }, 144 | { 145 | "num_layers": 3, 146 | "units": 64, 147 | "activation": "ssp", 148 | "kernel_lambda": 0.01, 149 | "bias_lambda": 0.01, 150 | }, 151 | ], 152 | } 153 | -------------------------------------------------------------------------------- /tfn/tools/jobs/cross_validate.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | import numpy as np 4 | from sacred.run import Run 5 | from tensorflow.keras.models import Model 6 | 7 | from . import KerasJob, config_defaults 8 | from ..callbacks import CartesianMetrics 9 | 10 | 11 | class CrossValidate(KerasJob): 12 | @property 13 | def config_defaults(self): 14 | base = super().config_defaults 15 | base["loader_config"][ 16 | "map_points" 17 | ] = False # Ensure reconstruction works properly 18 | base["cm_config"] = copy(config_defaults.cm_config) 19 | return base 20 | 21 | def _main( 22 | self, 23 | run: Run, 24 | seed: int, 25 | fitable: Model = None, 26 | fitable_config: dict = None, 27 | loader_config: dict = None, 28 | ): 29 | # folds: (((x1, x2, ...), (y1, y2, ...)), ...) 
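# Leave-one-fold-out: fold i is held out for validation while the remaining
# folds are concatenated for training; a fresh model is built per fold (and
# optionally seeded from `fitable`'s weights) so no state leaks between folds.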
30 | model = None 31 | train_loss = [] 32 | val_loss = [] 33 | loader, folds = self._load_data(loader_config) 34 | print(f"**CROSS VALIDATING WITH {len(folds)} FOLDS**") 35 | root = self.exp_config["run_config"]["root_dir"] 36 | 37 | # Loop over folds 38 | for i in range(len(folds)): 39 | print(f"CROSS VALIDATING USING FOLD {i} AS VAL FOLD...") 40 | val = folds[i] 41 | train = self._combine_folds(folds[:i] + folds[i + 1 :]) 42 | data = (train, val, None) # No testing data 43 | model = self._load_fitable(loader, fitable_config) 44 | 45 | # Preload weights if necessary 46 | if fitable is not None: 47 | fitable.save_weights("./temp_weights.hdf5") 48 | model.load_weights("./temp_weights.hdf5") 49 | 50 | # fit the new model 51 | self.exp_config["run_config"]["root_dir"] = root / f"cv_model_{i}" 52 | model = self._fit( 53 | run, 54 | model, 55 | data, 56 | callbacks=[ 57 | CartesianMetrics( 58 | self.exp_config["run_config"]["root_dir"] / "cartesians", 59 | *data, 60 | **self.exp_config["cm_config"], 61 | ) 62 | ], 63 | ) 64 | 65 | # [(loss, metric1, metric2, ...), ...] 66 | train_loss.append(self._evaluate_fold(model, train)) 67 | val_loss.append(self._evaluate_fold(model, val)) 68 | 69 | loss = np.array([train_loss, val_loss]) # (2, num_folds, ?) 70 | print(f"AVERAGE TRAIN LOSS ACROSS MODELS {np.mean(loss[0], axis=0).tolist()}") 71 | print(f"STANDARD DEVIATION: {np.std(loss[0], axis=0).tolist()}") 72 | print("Final train losses: {}".format("\n".join(map(str, train_loss)))) 73 | 74 | print(f"AVERAGE VAL LOSS ACROSS MODELS {np.mean(loss[1], axis=0).tolist()}") 75 | print(f"STANDARD DEVIATION: {np.std(loss[1], axis=0).tolist()}") 76 | print("Final val losses: {}".format("\n".join(map(str, val_loss)))) 77 | return model 78 | 79 | def _evaluate_fold(self, fitable: Model, data: list): 80 | loss = fitable.evaluate(*data, verbose=0) 81 | if not isinstance(loss, list): 82 | loss = [loss] 83 | loss.append( 84 | CartesianMetrics.structure_loss( 85 | data[0][0], fitable.predict(data[0]), data[1][0] 86 | )[0] 87 | ) 88 | return loss 89 | 90 | @staticmethod 91 | def _combine_folds(folds): 92 | """ 93 | :param folds: list. Folds to be combined. Of shape ((x, y), ...) where x and y are lists 94 | of ndarrays 95 | :return: list. Folds concatenated to the shape (x, y), where x and y are lists of 96 | ndarrays concatenated along axis 0 across all folds. 
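        Example (shapes illustrative): combining two folds whose x arrays each
        have shape (8, ...) yields x arrays of shape (16, ...); y arrays are
        concatenated the same way, so example order is preserved along axis 0.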
97 | """ 98 | x_arrays = [[] for _ in folds[0][0]] 99 | y_arrays = [[] for _ in folds[0][1]] 100 | for (x, y) in folds: 101 | for j, array in enumerate(x): 102 | x_arrays[j].append(array) 103 | for j, array in enumerate(y): 104 | y_arrays[j].append(array) 105 | combined_folds = [ 106 | [np.concatenate(x, axis=0) for x in x_arrays], 107 | [np.concatenate(y, axis=0) for y in y_arrays], 108 | ] 109 | return combined_folds 110 | -------------------------------------------------------------------------------- /tfn/tools/jobs/job.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from abc import ABCMeta, abstractmethod 3 | from copy import copy 4 | from pathlib import Path 5 | from typing import List 6 | 7 | from sacred import Experiment 8 | from sacred.observers import FileStorageObserver, MongoObserver, RunObserver 9 | from tensorflow.keras.models import Model 10 | 11 | from tfn.tools.ingredients import builder_ingredient, data_ingredient 12 | from tfn.tools.jobs import config_defaults as cd 13 | 14 | 15 | class Job(metaclass=ABCMeta): 16 | def __init__( 17 | self, 18 | exp_config: dict = None, 19 | add_defaults: bool = True, 20 | mongo_hostnames: list = None, 21 | ): 22 | exp_config = exp_config or dict() 23 | if add_defaults: 24 | self.exp_config = self.add_config_defaults(exp_config) 25 | else: 26 | self.exp_config = exp_config 27 | if mongo_hostnames is None: 28 | mongo_hostnames = ["tater"] 29 | self.mongo_hostnames = mongo_hostnames 30 | 31 | self._experiment = None 32 | self._observers = [] 33 | 34 | @property 35 | def default_observers(self): 36 | observers = [] 37 | if socket.gethostname() in self.mongo_hostnames: 38 | observers.append( 39 | MongoObserver( 40 | url=f"mongodb://sample:password@localhost:27017/?authMechanism=SCRAM-SHA-1", 41 | db_name="db", 42 | ) 43 | ) 44 | observers.append( 45 | FileStorageObserver(self.exp_config.get("storage_dir", "./sacred_storage")) 46 | ) 47 | return observers 48 | 49 | @property 50 | def experiment(self): 51 | """ 52 | Experiment object required for Sacred. 53 | 54 | :return: sacred.Experiment object. 55 | """ 56 | if self._experiment is None: 57 | self._experiment = Experiment( 58 | name=self.exp_config.get("name"), 59 | ingredients=self.exp_config.get("ingredients"), 60 | ) 61 | observers = self._observers or self.default_observers 62 | self._experiment.observers.extend(observers) 63 | self._experiment.add_config(self.exp_config) 64 | if not self.exp_config["run_config"]["capture_output"]: 65 | self._experiment.captured_out_filter = ( 66 | lambda *args, **kwargs: "Output capturing turned off." 67 | ) 68 | return self._experiment 69 | 70 | @staticmethod 71 | def set_config_defaults(d: dict, values: dict): 72 | for k, v in values.items(): 73 | d.setdefault(k, v) 74 | 75 | def add_config_defaults(self, ec: dict): 76 | for name, conf in self.config_defaults.items(): 77 | if name in ec: 78 | self.set_config_defaults(ec[name], conf) 79 | else: 80 | ec.setdefault(name, conf) 81 | return ec 82 | 83 | def update_observers(self, o: List[RunObserver]): 84 | """ 85 | ONLY USE BEFORE CALLING `self.experiment` AS OBSERVERS CANNOT BE SET AFTER THE EXPERIMENT 86 | IS CREATED. 87 | 88 | :param o: List of sacred RunObservers to update Job observers. 89 | """ 90 | self._observers.extend(o) 91 | 92 | def override_observers(self, o: List[RunObserver]): 93 | """ 94 | ONLY USE BEFORE CALLING `self.experiment`. Replace defaults with new list of 95 | RunObserver objects. 
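        Example (illustrative): job.override_observers([FileStorageObserver("./alt_storage")])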
96 | :param o: List of new sacred RunObservers 97 | """ 98 | self._observers = o 99 | 100 | @abstractmethod 101 | def _main(self, run, seed, fitable, fitable_config, loader_config): 102 | """ 103 | Private method containing the actual work completed by the job. Implemented is a default 104 | workflow for a basic keras/kerastuner type job. 105 | 106 | :param run: sacred.Run object. See sacred documentation for more details on utility. 107 | :param fitable: Optional tensorflow.keras.Model or kerastuner.Tuner object. 108 | Model-like which contains a fit method. 109 | :param fitable_config: Optional dict. Contains data which can be used to create a new 110 | fitable instance. 111 | :param loader_config: Optional dict. Contains data which can be used to create a new 112 | DataLoader instance. 113 | """ 114 | pass 115 | 116 | def run( 117 | self, 118 | fitable: Model = None, 119 | fitable_config: dict = None, 120 | loader_config: dict = None, 121 | ): 122 | """ 123 | Exposed method of the particular job. Runs whatever work is entailed by the job based on 124 | the content provided in `self.exp_config`. 125 | """ 126 | 127 | @self.experiment.main 128 | def main(_run, _seed): 129 | self.exp_config["run_config"]["root_dir"] = Path( 130 | _run.observers[0].dir 131 | ).absolute() 132 | self._main(_run, _seed, fitable, fitable_config, loader_config) 133 | 134 | self.experiment.run() 135 | 136 | @abstractmethod 137 | def _load_data(self, config): 138 | """ 139 | Obtains a loader using ingredients.get_loader and self.exp_config['loader_config'] 140 | 141 | :return: Loader object and the data returned by that Loader's get_data method. 142 | """ 143 | pass 144 | 145 | @abstractmethod 146 | def _load_fitable(self, loader, fitable_config): 147 | """ 148 | Defines and compiles a fitable (keras.model or keras_tuner.tuner) which implements 149 | a 'fit' method. This method calls either get_builder, or get_hyper_factory, depending on 150 | which type of fitable is beind loaded. 151 | 152 | :return: Model or Tuner object. 153 | """ 154 | pass 155 | 156 | @abstractmethod 157 | def _fit(self, run, fitable, data, callbacks): 158 | """ 159 | 160 | :param run: sacred.Run object. See sacred documentation for details on utility. 161 | :param fitable: tensorflow.keras.Model object. 162 | :param data: tuple. train and validation data in the form (train, val), where train is 163 | the tuple (x_train, y_train). 164 | :param callbacks: Optional list. List of tensorflow.keras.Callback objects to pass to 165 | fitable.fit method. 166 | :return: tensorflow.keras.Model object. 167 | """ 168 | pass 169 | 170 | @abstractmethod 171 | def _test_fitable(self, run, fitable, test_data): 172 | """ 173 | :param fitable: tensorflow.keras.Model object. 174 | :param test_data: tuple. contains (x_test, y_test). 175 | :return: float. Scalar test_loss value. 176 | """ 177 | pass 178 | 179 | @abstractmethod 180 | def _save_fitable(self, run, fitable): 181 | """ 182 | :param run: sacred.Run object. see sacred documentation for more details on utility. 183 | :param fitable: tensorflow.keras.Model object. 184 | """ 185 | pass 186 | 187 | @abstractmethod 188 | def _new_model_path(self, i): 189 | pass 190 | 191 | @property 192 | def config_defaults(self): 193 | """ 194 | Defines default values for the various config dictionaries required for the Job. 195 | 196 | :return: dict. Experiment dictionary containing necessary config(s) for the Job. 
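        The base Job supplies: ingredients, run_config, loader_config,
        builder_config, tb_config, and lr_config (each copied from
        config_defaults.py); subclasses extend this dict, e.g. CrossValidate
        and StructurePrediction add cm_config.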
197 | """ 198 | return { 199 | "ingredients": [data_ingredient, builder_ingredient], 200 | "run_config": copy(cd.run_config), 201 | "loader_config": copy(cd.loader_config), 202 | "builder_config": copy(cd.builder_config), 203 | "tb_config": copy(cd.tb_config), 204 | "lr_config": copy(cd.lr_config), 205 | } 206 | -------------------------------------------------------------------------------- /tfn/tools/jobs/keras_job.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Tuple 3 | 4 | from sacred.run import Run 5 | import tensorflow as tf 6 | from tensorflow.keras.callbacks import ReduceLROnPlateau, TensorBoard 7 | from tensorflow.keras.models import Model 8 | 9 | from .job import Job 10 | from ..ingredients import ( 11 | get_data_loader, 12 | get_builder, 13 | ) 14 | from ..loaders import DataLoader 15 | 16 | 17 | class KerasJob(Job): 18 | def _main( 19 | self, 20 | run: Run, 21 | seed: int, 22 | fitable: Model = None, 23 | fitable_config: dict = None, 24 | loader_config: dict = None, 25 | ): 26 | """ 27 | Private method containing the actual work completed by the job. Implemented is a default 28 | workflow for a basic keras/kerastuner type job. 29 | 30 | :param run: sacred.Run object. See sacred documentation for more details on utility. 31 | :param fitable: Optional tensorflow.keras.Model or kerastuner.Tuner object. 32 | Model-like which contains a fit method. 33 | :param fitable_config: Optional dict. Contains data which can be used to create a new 34 | fitable instance. 35 | :param loader_config: Optional dict. Contains data which can be used to create a new 36 | DataLoader instance. 37 | """ 38 | loader, data = self._load_data(loader_config) 39 | fitable = fitable or self._load_fitable(loader, fitable_config) 40 | fitable = self._fit(run, fitable, data) 41 | if self.exp_config["run_config"]["test"]: 42 | self._test_fitable(run, fitable, data[-1]) 43 | if self.exp_config["run_config"]["save_model"]: 44 | self._save_fitable(run, fitable) 45 | return fitable 46 | 47 | def _load_data(self, config: dict = None) -> Tuple[DataLoader, Tuple]: 48 | """ 49 | Obtains a loader using ingredients.get_loader and self.exp_config['loader_config'] 50 | 51 | :param config: Optional dict. config passed to get_data_loader to obtain specific 52 | data_loader class. 53 | :return: Loader object and the data returned by that Loader's get_data method. 54 | """ 55 | config = config or self.exp_config["loader_config"] 56 | loader = get_data_loader(**config) 57 | if self.exp_config["run_config"]["select_few"]: 58 | data = loader.few_examples(**config["load_kwargs"]) 59 | else: 60 | data = loader.load_data(**config["load_kwargs"]) 61 | return loader, data 62 | 63 | def _load_fitable(self, loader: DataLoader, fitable_config: dict = None) -> Model: 64 | """ 65 | Defines and compiles a fitable (keras.model or keras_tuner.tuner) which implements 66 | a 'fit' method. This method calls either get_builder, or get_hyper_factory, depending on 67 | which type of fitable is beind loaded. 68 | 69 | :return: Model or Tuner object. 
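        Example (illustrative sketch of the calls made in `_main`):

            loader, data = self._load_data()
            model = self._load_fitable(loader)  # compiled tf.keras Model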
70 | """ 71 | fitable_config = fitable_config or self.exp_config["builder_config"] 72 | conf = dict( 73 | **fitable_config, 74 | max_z=loader.max_z, 75 | num_points=loader.num_points, 76 | mu=loader.mu, 77 | sigma=loader.sigma, 78 | ) 79 | builder = get_builder(**conf) 80 | run_config = self.exp_config["run_config"] 81 | compile_kwargs = dict( 82 | loss=run_config["loss"], 83 | loss_weights=run_config["loss_weights"], 84 | optimizer=run_config["optimizer"], 85 | metrics=run_config["metrics"], 86 | run_eagerly=run_config["run_eagerly"], 87 | ) 88 | if run_config["use_strategy"]: 89 | strategy = tf.distribute.MirroredStrategy() 90 | with strategy.scope(): 91 | model = builder.get_model() 92 | model.compile(**compile_kwargs) 93 | else: 94 | model = builder.get_model() 95 | model.compile(**compile_kwargs) 96 | return model 97 | 98 | def _fit( 99 | self, run: Run, fitable: Model, data: tuple, callbacks: list = None, 100 | ) -> Model: 101 | """ 102 | 103 | :param run: sacred.Run object. See sacred documentation for details on utility. 104 | :param fitable: tensorflow.keras.Model object. 105 | :param data: tuple. train, validation, and test data in the form (train, val, test), 106 | where train is 107 | the tuple (x_train, y_train). 108 | :param callbacks: Optional list. List of tensorflow.keras.Callback objects to pass to 109 | fitable.fit method. 110 | :return: tensorflow.keras.Model object. 111 | """ 112 | tensorboard_directory = self.exp_config["run_config"]["root_dir"] / "logs" 113 | (x_train, y_train), val, _ = data 114 | callbacks = callbacks or [] 115 | if self.exp_config["run_config"]["use_default_callbacks"]: 116 | callbacks.extend( 117 | [ 118 | TensorBoard( 119 | **dict( 120 | **self.exp_config["tb_config"], 121 | log_dir=tensorboard_directory, 122 | ) 123 | ), 124 | ReduceLROnPlateau(**self.exp_config["lr_config"]), 125 | ] 126 | ) 127 | kwargs = dict( 128 | x=x_train, 129 | y=y_train, 130 | epochs=self.exp_config["run_config"]["epochs"], 131 | batch_size=self.exp_config["run_config"]["batch_size"], 132 | validation_data=val, 133 | class_weight=self.exp_config["run_config"]["class_weight"], 134 | callbacks=callbacks, 135 | verbose=self.exp_config["run_config"]["fit_verbosity"], 136 | ) 137 | fitable.fit(**kwargs) 138 | return fitable 139 | 140 | def _test_fitable(self, run: Run, fitable: Model, test_data: tuple) -> float: 141 | """ 142 | :param fitable: tensorflow.keras.Model object. 143 | :param test_data: tuple. contains (x_test, y_test). 144 | :return: float. Scalar test_loss value. 145 | """ 146 | if test_data is None: 147 | return 0.0 148 | x_test, y_test = test_data 149 | loss = fitable.evaluate(x=x_test, y=y_test, verbose=0) 150 | print(f"Test split results: {loss}") 151 | return loss 152 | 153 | def _save_fitable(self, run: Run, fitable: Model): 154 | """ 155 | :param run: sacred.Run object. see sacred documentation for more details on utility. 156 | :param fitable: tensorflow.keras.Model object. 
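        The model is written to run_config["model_path"] and attached to the
        sacred run as an artifact.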
157 | """ 158 | path = self.exp_config["run_config"]["model_path"] 159 | if self.exp_config["run_config"]["save_verbosity"] > 0: 160 | fitable.summary() 161 | fitable.save(self.exp_config["run_config"]["model_path"]) 162 | run.add_artifact(path) 163 | 164 | def _new_model_path(self, name: str): 165 | model_path = Path(self.exp_config["run_config"]["model_path"]).parent / name 166 | self.exp_config["run_config"]["model_path"] = model_path 167 | return model_path 168 | -------------------------------------------------------------------------------- /tfn/tools/jobs/load_model.py: -------------------------------------------------------------------------------- 1 | from sacred.run import Run 2 | from tensorflow.keras.models import load_model, Model 3 | 4 | from . import KerasJob 5 | from ..loaders import DataLoader 6 | 7 | 8 | class LoadModel(KerasJob): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | path = self.exp_config["run_config"]["model_path"] 12 | print(f"Loading pre-trained model from file {path}") 13 | self.model = load_model(path) 14 | 15 | def _main( 16 | self, 17 | run: Run, 18 | seed: int, 19 | fitable: Model = None, 20 | fitable_config: dict = None, 21 | loader_config: dict = None, 22 | ): 23 | self._save_fitable(run, self.model) 24 | return self.model 25 | 26 | def _load_fitable(self, loader: DataLoader, fitable_config: dict = None) -> Model: 27 | return self.model 28 | -------------------------------------------------------------------------------- /tfn/tools/jobs/pipeline.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | from tensorflow.keras.models import Model, load_model 5 | from tensorflow.keras.layers import InputLayer 6 | from tfn.layers.atomic_images import Unstandardization 7 | 8 | from . import KerasJob 9 | 10 | 11 | class Pipeline(KerasJob): 12 | NONTRANSFERABLE_LAYERS = (Unstandardization, InputLayer) 13 | BLACKLISTED_LAYER_NAMES = ["embedding"] 14 | 15 | def __init__(self, jobs: List[KerasJob], *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.jobs = jobs 18 | 19 | def _main(self, run, seed, fitable=None, loader_config=None, fitable_config=None): 20 | model_path = None 21 | for i, job in enumerate(self.jobs): 22 | loader, _ = job._load_data() 23 | fitable = job._load_fitable(loader) 24 | try: 25 | fitable = self.initialize_fitable_weights(fitable, model_path) 26 | except (FileNotFoundError, TypeError) as _: 27 | pass 28 | 29 | model_path = self._new_model_path(f"model_{i}.h5") 30 | job.exp_config["run_config"]["model_path"] = model_path 31 | job.exp_config["run_config"]["root_dir"] = ( 32 | self.exp_config["run_config"]["root_dir"] / f"pipeline_model_{i}" 33 | ) 34 | fitable = job._main(run, seed, fitable) 35 | 36 | return fitable 37 | 38 | def layer_is_valid(self, layer): 39 | if layer is None: 40 | return False 41 | elif any( 42 | [ 43 | isinstance(layer, self.NONTRANSFERABLE_LAYERS), 44 | any([layer.name in name for name in self.BLACKLISTED_LAYER_NAMES]), 45 | ] 46 | ): 47 | return False 48 | else: 49 | return True 50 | 51 | def initialize_fitable_weights(self, fitable: Model, path) -> Model: 52 | if not Path(path).exists(): 53 | raise FileNotFoundError( 54 | f"hdf5 file {path} does not exist - cannot read weights." 
55 | ) 56 | temp_path = self.exp_config["run_config"]["root_dir"] / "temp_model.h5" 57 | fitable.save(temp_path) 58 | fitable = load_model(temp_path) 59 | fitable.load_weights(path, by_name=True, skip_mismatch=True) 60 | return fitable 61 | -------------------------------------------------------------------------------- /tfn/tools/jobs/regression.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from pathlib import Path 3 | 4 | from sacred.run import Run 5 | from tensorflow.keras.models import Model 6 | 7 | from . import KerasJob 8 | from .config_defaults import cm_config 9 | from ..callbacks import CartesianMetrics 10 | 11 | 12 | class Regression(KerasJob): 13 | pass 14 | 15 | 16 | class StructurePrediction(Regression): 17 | @property 18 | def config_defaults(self): 19 | base = super().config_defaults 20 | base["loader_config"][ 21 | "map_points" 22 | ] = False # Ensure reconstruction works properly 23 | base["cm_config"] = copy(cm_config) 24 | return base 25 | 26 | def _fit( 27 | self, run: Run, fitable: Model, data: tuple, callbacks: list = None, 28 | ) -> Model: 29 | path = self.exp_config["run_config"]["root_dir"] / "cartesians" 30 | metric_data = data if self.exp_config["run_config"]["test"] else data[:2] 31 | return super()._fit( 32 | run, 33 | fitable, 34 | data, 35 | callbacks=[ 36 | CartesianMetrics(path, *metric_data, **self.exp_config["cm_config"]) 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /tfn/tools/jobs/search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.model_selection import ParameterGrid 3 | 4 | from . import KerasJob 5 | 6 | 7 | class GridSearch(KerasJob): 8 | def __init__( 9 | self, job: KerasJob, grid: dict, total_models: int = None, *args, **kwargs 10 | ): 11 | super().__init__(*args, **kwargs) 12 | self.job = job 13 | self.grid = ParameterGrid(grid) 14 | self.total_models = total_models or np.inf 15 | 16 | def _main( 17 | self, run, seed, fitable=None, dataloader_config=None, fitable_config=None 18 | ): 19 | for i, config in enumerate(self.grid): 20 | if i >= self.total_models: # Stop when hitting max models 21 | print( 22 | f"## Max model count of {self.total_models} exceeded, ending search ##" 23 | ) 24 | break 25 | print(f"### Performing Grid searh on model {i} ###") 26 | print(f"Config set (not showing defaults): {config}") 27 | [ 28 | config.setdefault(k, v) 29 | for k, v in self.job.exp_config["builder_config"].items() 30 | ] 31 | self.job._new_model_path(f"model_{i}.h5") 32 | self.job._main(run, seed, fitable_config=config) 33 | print(f"# Completed search on model {i} #\n") 34 | try: 35 | pass 36 | except Exception as e: 37 | print( 38 | f"Error message: {e}\n" 39 | f"Encountered {type(e)} in search, skipping configuration..." 
40 | ) 41 | pass 42 | -------------------------------------------------------------------------------- /tfn/tools/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sub-package containg all loader classes 3 | """ 4 | from .data_loader import DataLoader 5 | from .qm9_loader import QM9DataDataLoader 6 | from .iso17_loader import ISO17DataLoader 7 | from .ts_loader import TSLoader 8 | from .sn2_loader import SN2Loader 9 | from .isom_loader import IsomLoader 10 | -------------------------------------------------------------------------------- /tfn/tools/loaders/data_loader.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Union 3 | 4 | import numpy as np 5 | 6 | 7 | class DataLoader(object): 8 | EV_PER_HARTREE = 27.2116 9 | KCAL_PER_HARTREE = 627.509 10 | KCAL_PER_EV = 23.06035 11 | 12 | def __init__( 13 | self, 14 | path: str, 15 | map_points: bool = True, 16 | splitting: Union[str, int, None] = "70:20:10", 17 | pre_load: bool = False, 18 | num_points: int = 29, 19 | **kwargs 20 | ): 21 | """ 22 | 23 | :param path: str. Path to dataset, typically a .npz or .hdf5 file. 24 | :param map_points: bool. Defaults to True. Whether or not to map point integers 25 | (e.g. atomic numbers) into a smaller set, such that no integer is left unassigned. 26 | pass False if reconstructing .xyz files after model training. 27 | :param splitting: Union[str, int, None]. Defaults to '70:20:10' 28 | :param pre_load: 29 | :param num_points: 30 | :param kwargs: 31 | """ 32 | self.path = path 33 | self.map_points = map_points 34 | self.splitting = splitting 35 | self.num_points = num_points 36 | self.data = None 37 | self.dataset_length = None 38 | 39 | self._max_z = None 40 | self._mu = None 41 | self._sigma = None 42 | if pre_load: 43 | self.load_data() 44 | 45 | @property 46 | def max_z(self): 47 | if self._max_z is None: 48 | self.load_data(return_maxz=True) 49 | return self._max_z 50 | 51 | @property 52 | def mu(self): 53 | raise NotImplementedError 54 | 55 | @property 56 | def sigma(self): 57 | raise NotImplementedError 58 | 59 | def load_data(self, *args, **kwargs): 60 | """ 61 | :return: data in the format: List[Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray], ...]. e.g. 62 | [ 63 | (x_train, y_train), 64 | (x_val, y_val), 65 | (x_test, y_test) 66 | ] 67 | """ 68 | if self.data is None or self.dataset_length is None: 69 | raise NotImplementedError( 70 | "Data and Length must be specified. Make sure a DataLoader " 71 | "subclass is being used." 72 | ) 73 | if isinstance(self.splitting, int): 74 | return self.cross_validate() 75 | else: 76 | return self.three_way_split() 77 | 78 | def few_examples(self, num_examples: int = 5, **kwargs): 79 | data = self.load_data(**kwargs) 80 | truncated_data = [] 81 | for split in data: # [x, y] 82 | truncated_split = [] 83 | for l in split: 84 | truncated_array = [] 85 | for array in l: 86 | truncated_array.append(array[:num_examples]) 87 | truncated_split.append(truncated_array) 88 | truncated_data.append(truncated_split) 89 | return truncated_data 90 | 91 | def cross_validate(self, data: list = None, length: int = None): 92 | """ 93 | :return: data in the format: [(x_0, y_0), (x_1, y_1), ...] for number of folds specified 94 | in splitting param. 
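        Example (illustrative): splitting=5 on a 103-example dataset yields
        folds of [20, 20, 20, 20, 23]; the remainder (3) is added to the last
        fold because it is not larger than 0.25 * 103.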
95 | """ 96 | data = data or self.data 97 | length = length or self.dataset_length 98 | if self.splitting < 2: 99 | raise ValueError( 100 | "Must provide a splitting param of 2 or more folds for cross " 101 | "validation" 102 | ) 103 | fold_length, remainder = divmod(length, self.splitting) 104 | folds = [fold_length for _ in range(self.splitting)] 105 | if remainder > 0.25 * length: 106 | folds.append(remainder) 107 | else: 108 | folds[-1] += remainder 109 | return self.split_data(data, folds) 110 | 111 | def three_way_split(self, data: list = None, length: int = None): 112 | """ 113 | :return: data in the format: [ 114 | [ 115 | (x_train, y_train), 116 | (x_val, y_val), 117 | (x_test, y_test) 118 | ] 119 | ] 120 | """ 121 | data = data or self.data 122 | length = length or self.dataset_length 123 | if self.splitting is None: 124 | splits = [length] # Use 100 percent of dataset as train data 125 | else: 126 | splits = [ 127 | int(int(x) / 100 * length) 128 | for x in re.findall(r"(\d{1,2})", self.splitting) 129 | ] 130 | splits[int(np.argmax(splits))] += length - sum( 131 | splits 132 | ) # Add remainder to largest split 133 | return self.split_data(data, splits) 134 | 135 | @staticmethod 136 | def split_data(data: list, splits: list): 137 | """ 138 | :param data: dataset in the form [x, y], where x and y are lists of ndarrays. 139 | :param splits: List[int]. 140 | :return: data split according to splits, with None for splits with length == 0. 141 | """ 142 | x_data, y_data = data 143 | output_data = [] 144 | for i, split in enumerate(splits): 145 | cursor = sum(splits[:i]) 146 | boundary = cursor + split 147 | output_data.append( 148 | ( 149 | [x[cursor:boundary] for x in x_data], 150 | [y[cursor:boundary] for y in y_data], 151 | ) 152 | ) 153 | output_data = [o if len(o[0][0]) != 0 else None for o in output_data] 154 | return output_data 155 | 156 | @staticmethod 157 | def remap_points(atomic_nums): 158 | atom_mapping = np.unique(atomic_nums) 159 | for remapped_z, original_z in enumerate(atom_mapping): 160 | atomic_nums[atomic_nums == original_z] = remapped_z 161 | 162 | @staticmethod 163 | def pad_along_axis(array: np.ndarray, target_length, axis=1): 164 | pad_size = target_length - array.shape[axis] 165 | axis_nb = len(array.shape) 166 | if pad_size < 0: 167 | return array 168 | npad = [(0, 0) for _ in range(axis_nb)] 169 | npad[axis] = (0, pad_size) 170 | b = np.pad(array, pad_width=npad, mode="constant", constant_values=0) 171 | return b 172 | 173 | @staticmethod 174 | def shuffle_arrays(x, y, length): 175 | """ 176 | :param x: list. input data to be shuffled. 177 | :param y: list. output data to be shuffled. 178 | :param length: int. number of examples in dataset. 179 | :return: List[list]. Input and output shuffled. 
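        The same permutation is applied to every array in x and y, so
        input/output pairing is preserved across the shuffle.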
180 | """ 181 | s = np.arange(length) 182 | np.random.shuffle(s) 183 | inp = [a[s] for a in x] 184 | out = [a[s] for a in y] 185 | return inp, out 186 | -------------------------------------------------------------------------------- /tfn/tools/loaders/iso17_loader.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import h5py 4 | import numpy as np 5 | 6 | from .data_loader import DataLoader 7 | 8 | 9 | class ISO17DataLoader(DataLoader): 10 | def __init__(self, *args, **kwargs): 11 | self.use_energies = kwargs.pop("use_energies", True) 12 | self._force_mu = None 13 | self._force_sigma = None 14 | super().__init__(*args, **kwargs) 15 | 16 | @property 17 | def mu(self): 18 | atomic_means = np.array( 19 | [ 20 | 0.0, # Dummy points 21 | -13.61312172, # Hydrogens 22 | -1029.86312267, # Carbons 23 | -2042.61123593, # Oxygens 24 | ] 25 | ).reshape((-1, 1)) 26 | if self._force_mu is None: 27 | self.load_data() 28 | return atomic_means, self._force_mu 29 | 30 | @property 31 | def sigma(self): 32 | if self._force_mu is None: 33 | self.load_data() 34 | return np.ones_like(self.mu[0]), self._force_sigma 35 | 36 | def load_data(self, *args, **kwargs): 37 | """ 38 | The ISO17 has 640982 structures split across 5 different datasets. The ISO17 is composed of 39 | 129 isomers of C7O2H10 each with 5000 structures (frames with 1 femtosecond resolution) 40 | along various MD trajectories. Each structure has an associated molecular energy, 41 | and set of atomic forces. 42 | 43 | :param kwargs: possible kwargs: 44 | return_stats: bool. Defaults to False. exit early by returning mu, sigma 45 | return_maxz: bool. Defaults to False. exit early by returning max_z 46 | dataset_type: str. Defaults to 'reference'. Indicator for which dataset to select for 47 | train/val split test_type: str. Defaults to 'test_other'. Indictator for which datset 48 | to select for testing 49 | :return: dict. The structure of the returned data is as such: 50 | { 51 | dataset: (x, y) 52 | } 53 | Possible values for dataset: 54 | 'reference' - 80% of steps of 80% of MD trajectories, (404000 examples). 55 | 'reference_eq' - equilibrium conformations of those molecules, (101 examples). 56 | 'test_within' - remaining 20% unseen steps of reference trajectories, 57 | (101000 examples). 58 | 'test_other' - remaining 20% unseen MD trajectories, (130000 examples). 59 | 'test_eq' - equilibrium conformations of test trajectories, (5881 examples). 
60 | 61 | """ 62 | if ( 63 | self.data is not None 64 | and "dataset_type" not in kwargs 65 | and "test_type" not in kwargs 66 | ): 67 | return self.data 68 | dataset_name = kwargs.get("dataset_type", "reference") 69 | test_name = kwargs.get("test_type", "test_other") 70 | data = { 71 | dataset_name: [], # Populated to [atomic_nums, positions, energies, forces] 72 | test_name: [], 73 | } 74 | 75 | # Load from hdf5 file 76 | with h5py.File(self.path, "r") as file: 77 | for name, l in data.items(): 78 | positions = self.pad_along_axis( 79 | np.array(file["{}/positions".format(name)]), self.num_points 80 | ) 81 | atomic_nums = self.pad_along_axis( 82 | np.tile( 83 | np.expand_dims(file["{}/atomic_numbers".format(name)], axis=0), 84 | (len(positions), 1), 85 | ), 86 | self.num_points, 87 | ) 88 | 89 | energies = np.array(file["{}/energies".format(name)]) 90 | forces = ( 91 | self.pad_along_axis( 92 | np.array(file["{}/forces".format(name)]), self.num_points 93 | ) 94 | * self.KCAL_PER_EV 95 | ) 96 | 97 | if self.use_energies: 98 | l.extend([atomic_nums, positions, energies, forces]) 99 | else: 100 | l.extend([atomic_nums, positions, forces]) 101 | 102 | # Remapping 103 | if self.map_points: 104 | [self.remap_points(d[0]) for d in data.values()] 105 | self._max_z = np.max(data[dataset_name][0]) + 1 106 | if kwargs.get("return_maxz", False): 107 | return 108 | 109 | # Get Force mu/sigma 110 | self._force_mu = np.mean(data[dataset_name][-1]) 111 | self._force_sigma = np.std(data[dataset_name][-1]) 112 | 113 | # Split data 114 | self.splitting = re.search(r"\d{1,2}:\d{1,2}", self.splitting).group(0) 115 | self.data = self.three_way_split( 116 | data=[data[dataset_name][:2], data[dataset_name][2:]], 117 | length=len(data[dataset_name][0]), 118 | ) 119 | self.splitting = None 120 | self.data.extend( 121 | self.three_way_split( 122 | data=[data[test_name][:2], data[test_name][2:]], 123 | length=len(data[test_name][0]), 124 | ) 125 | ) 126 | return self.data 127 | -------------------------------------------------------------------------------- /tfn/tools/loaders/isom_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | from .
import DataLoader 4 | 5 | 6 | class IsomLoader(DataLoader): 7 | @property 8 | def mu(self): 9 | return 0 10 | 11 | @property 12 | def sigma(self): 13 | return 1 14 | 15 | def load_data(self, *args, **kwargs): 16 | with h5py.File(self.path, "r") as dataset: 17 | atomic_nums = self.pad_along_axis( 18 | np.asarray(dataset["ts_train/atomic_nums"], dtype="int"), 19 | self.num_points, 20 | ) 21 | cartesians = { 22 | structure_type: self.pad_along_axis( 23 | np.nan_to_num(dataset["{}/cartesians".format(structure_type)]), 24 | self.num_points, 25 | ) 26 | for structure_type in ("ts_train", "r_train", "p_train",) 27 | } 28 | 29 | # Remap points 30 | if self.map_points: 31 | self.remap_points(atomic_nums) 32 | self._max_z = kwargs.get("custom_maxz", None) or np.max(atomic_nums) + 1 33 | if kwargs.get("return_maxz", False): 34 | return 35 | 36 | x = [atomic_nums, cartesians["r_train"], cartesians["p_train"]] 37 | y = [cartesians["ts_train"]] 38 | self.data = [x, y] 39 | self.dataset_length = len(atomic_nums) 40 | return super().load_data() 41 | -------------------------------------------------------------------------------- /tfn/tools/loaders/qm9_loader.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 4 | from tfn.layers.atomic_images import OneHot 5 | 6 | from ...layers.utility_layers import MaskedDistanceMatrix 7 | from .data_loader import DataLoader 8 | 9 | 10 | class QM9DataDataLoader(DataLoader): 11 | @property 12 | def mu(self): 13 | a = np.zeros((self.max_z, 1), dtype="float32") 14 | a[1] = -13.61312172 15 | a[6] = -1029.86312267 16 | a[7] = -1485.30251237 17 | a[8] = -2042.61123593 18 | a[9] = -2713.48485589 19 | return a 20 | 21 | @property 22 | def sigma(self): 23 | return np.ones_like(self.mu) 24 | 25 | def load_data(self, *args, **kwargs): 26 | """ 27 | QM9 is a dataset of 133,885 small molecules with up to 9 heavy atoms (C, N, O, or F). The dataset has 13 28 | different chemical properties associated with each structure, the most valuable being energy, of which there 29 | are several forms. This `DataLoader` is responsible for returning U0 energies: the internal energy of the 30 | structures at 0 Kelvin. 31 | 32 | :param kwargs: possible kwargs: 33 | return_stats: Exit early by returning mu, sigma 34 | return_maxz: Exit early by returning max_z 35 | :return: List[Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray], ...].
36 | QM9 data in the format: 37 | [ 38 | (x_train, y_train), 39 | (x_val, y_val), 40 | (x_test, y_test) 41 | ], where x = (cartesians, atomic_nums) and y = energies 42 | """ 43 | if self.data is not None: 44 | return self.data 45 | with h5py.File(self.path, "r") as dataset: 46 | cartesians = self.pad_along_axis( 47 | np.nan_to_num(dataset["QM9/R"]), self.num_points 48 | ) 49 | atomic_nums = self.pad_along_axis( 50 | np.array(dataset["QM9/Z"]), self.num_points 51 | ) 52 | energies = np.array(dataset["QM9/U_naught"]) * self.EV_PER_HARTREE 53 | 54 | if self.map_points: 55 | self.remap_points(atomic_nums) 56 | self._max_z = kwargs.get("custom_maxz", None) or np.max(atomic_nums) + 1 57 | if kwargs.get("return_maxz", False): 58 | return 59 | 60 | if kwargs.get("modify_structures", False): 61 | forward_cartesians, reverse_cartesians = self.modify_structures( 62 | cartesians, 63 | kwargs.get("modify_distance", 0.5), 64 | kwargs.get("modify_seed", 0), 65 | ) 66 | if kwargs.get("classifier_output", False): 67 | tiled_cartesians = np.concatenate( 68 | [cartesians, forward_cartesians, reverse_cartesians], axis=0 69 | ) 70 | tiled_atomic_nums = np.tile(atomic_nums, (3, 1)) 71 | labels = np.zeros((len(tiled_cartesians),), dtype="int32") 72 | labels[: len(cartesians)] = 1 73 | x, y = self.shuffle_arrays( 74 | [tiled_atomic_nums, tiled_cartesians], [labels], len(labels) 75 | ) 76 | length = len(labels) 77 | else: 78 | x = [atomic_nums, forward_cartesians, reverse_cartesians] 79 | y = [ 80 | np.triu( 81 | MaskedDistanceMatrix()( 82 | [OneHot(self.max_z)(atomic_nums), cartesians] 83 | ) 84 | ) 85 | if kwargs.get("output_distance_matrix", False) 86 | else cartesians 87 | ] 88 | length = len(atomic_nums) 89 | 90 | else: 91 | x = [atomic_nums, cartesians] 92 | y = [energies] 93 | length = len(atomic_nums) 94 | 95 | self.data = [x, y] 96 | self.dataset_length = length 97 | return super().load_data(*args, **kwargs) 98 | 99 | def modify_structures(self, c, distance=0.75, seed=0): 100 | np.random.seed(seed) 101 | indices = np.random.randint(3, size=(len(c))) 102 | forward, reverse = np.copy(c), np.copy(c) 103 | forward += 0.2 * distance 104 | reverse -= 0.2 * distance 105 | for i, j in enumerate(indices): 106 | forward[i, j] += 0.8 * distance 107 | reverse[i, j] -= 0.8 * distance 108 | return forward, reverse 109 | -------------------------------------------------------------------------------- /tfn/tools/loaders/sn2_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from . 
import DataLoader 4 | 5 | 6 | class SN2Loader(DataLoader): 7 | def __init__(self, *args, **kwargs): 8 | self.use_energies = kwargs.pop("use_energies", True) 9 | self._force_mu = None 10 | self._force_sigma = None 11 | super().__init__(*args, **kwargs) 12 | 13 | @property 14 | def mu(self): 15 | atomic_means = np.array( 16 | [ 17 | 0.0, # Dummy points 18 | -13.579407869766147, # H 19 | -1028.9362774711024, # C 20 | -2715.578463075019, # F 21 | -12518.663203367176, # Cl 22 | -70031.09203874589, # Br 23 | -8096.587166328217, # I 24 | ] 25 | ).reshape([-1, 1]) 26 | if self._force_mu is None: 27 | self.load_data() 28 | return atomic_means, self._force_mu 29 | 30 | @property 31 | def sigma(self): 32 | if self._force_mu is None: 33 | self.load_data() 34 | return np.ones_like(self.mu[0]), self._force_sigma 35 | 36 | def load_data(self, *args, **kwargs): 37 | if self.data is not None: 38 | return self.data 39 | 40 | # Load from .npz 41 | data = np.load(self.path) 42 | cartesians = self.pad_along_axis( 43 | data["R"], self.num_points 44 | ) # (batch, points, 3) 45 | atomic_nums = self.pad_along_axis(data["Z"], self.num_points) # (batch, points) 46 | energies = data["E"] # (batch, ) 47 | forces = self.pad_along_axis(data["F"], self.num_points) 48 | dipoles = data["D"] # (batch, 3) 49 | 50 | # Remap points 51 | if self.map_points: 52 | self.remap_points(atomic_nums) 53 | self._max_z = np.max(atomic_nums) + 1 54 | if kwargs.get("return_maxz", False): 55 | return 56 | 57 | # Stats for forces/dipoles 58 | self._force_mu = np.mean(forces) 59 | self._force_sigma = np.std(forces) 60 | 61 | self.data = [[atomic_nums, cartesians], [energies, forces]] 62 | self.dataset_length = len(atomic_nums) 63 | return super().load_data() 64 | -------------------------------------------------------------------------------- /tfn/tools/loaders/ts_loader.py: -------------------------------------------------------------------------------- 1 | from h5py import File 2 | import numpy as np 3 | 4 | from tfn.layers.atomic_images import OneHot 5 | 6 | from ...layers.utility_layers import MaskedDistanceMatrix 7 | from . import DataLoader 8 | 9 | 10 | class TSLoader(DataLoader): 11 | def __init__(self, *args, **kwargs): 12 | self.use_energies = kwargs.pop("use_energies", False) 13 | super().__init__(*args, **kwargs) 14 | 15 | @property 16 | def mu(self): 17 | mu = np.array( 18 | [ 19 | 0.0, # Dummy points 20 | -13.61312172, # Hydrogens 21 | -1029.86312267, # Carbons 22 | -1485.30251237, # Nitrogens 23 | -2042.61123593, # Oxygens 24 | -2715.57846308, # Fluorines 25 | -17497.9266683, # Silicon 26 | -19674.5108670, # Phosphorus 27 | -10831.2647155, # Sulfur 28 | -12518.6632034, # Chlorine 29 | -61029.6106422, # Selenium 30 | -70031.0920387, # Bromine 31 | ] 32 | ).reshape((-1, 1)) 33 | if self.use_energies: 34 | return mu 35 | else: 36 | return np.zeros_like(mu) 37 | 38 | @property 39 | def sigma(self): 40 | return np.ones_like(self.mu) 41 | 42 | def load_data(self, *args, **kwargs): 43 | """ 44 | The MP2 TS Dataset is a dataset of 74 gas-phase SN2 structures, comprised of 45 | reactant/ts/product cartesians & atomic numbers obtained using the MP2 method with the 46 | cc-PVDZ basis set. 47 | 48 | x: [atomic_numbers, reactant_cartesians, reactant_complex_cartesians, 49 | product_cartesians, product_complex_cartesians] if kwarg `input_type` == 50 | 'cartesians' (Default). If kwarg `input_type` == 'energies', then x is: [ 51 | atomic_numbers, reactant_energies, product_energies]. 
Very rarely will this be 52 | useful, but the functionality exists. 53 | y: [ts_cartesians] if kwarg `output_type` == 'cartesians' (Default). If kwarg 54 | 'output_type' == 'energies' then y is: [ts_energies], and if `output_type` == 'both' 55 | then y is: [ts_cartesians, ts_energies] 56 | 57 | :param kwargs: Possible kwargs: 58 | 'cache': bool. Defaults to True. 59 | 'input_type': str. Defaults to 'cartesians'. Possible values include ['cartesians', 60 | 'classifier', 'siamese'], 61 | 'output_type': str. Defaults to 'cartesians'. Possible values include ['cartesians', 62 | 'energies', 'both', 'classifier', 'siamese'] 63 | :return: data in the format: [(x_train, y_train), (x_val, y_val), (x_test, y_test)] 64 | """ 65 | if ( 66 | self.data is not None 67 | and self.dataset_length is not None 68 | and kwargs.get("cache", True) 69 | ): 70 | return super().load_data() 71 | 72 | # Load Data 73 | with File(self.path, "r") as dataset: 74 | atomic_nums = self.pad_along_axis( 75 | np.asarray(dataset["ts/atomic_numbers"], dtype="int"), self.num_points 76 | ) 77 | cartesians = { 78 | structure_type: self.pad_along_axis( 79 | np.nan_to_num(dataset["{}/cartesians".format(structure_type)]), 80 | self.num_points, 81 | ) 82 | for structure_type in ( 83 | "ts", 84 | "reactant", 85 | "reactant_complex", 86 | "product_complex", 87 | "product", 88 | ) 89 | } 90 | energies = { 91 | structure_type: np.asarray( 92 | dataset["{}/energies".format(structure_type)] 93 | ) 94 | * self.EV_PER_HARTREE 95 | for structure_type in ("ts", "reactant", "product") 96 | } 97 | noisy_indices = np.asarray( 98 | dataset["noisy_reactions"], dtype="int" 99 | ) # (16, ) 100 | 101 | # Pull out noise 102 | if kwargs.get("remove_noise", False): 103 | atomic_nums = np.delete(atomic_nums, noisy_indices, axis=0) 104 | cartesians = { 105 | k: np.delete(c, noisy_indices, axis=0) for k, c in cartesians.items() 106 | } 107 | energies = { 108 | k: np.delete(e, noisy_indices, axis=0) for k, e in energies.items() 109 | } 110 | 111 | # Remap 112 | if self.map_points: 113 | self.remap_points(atomic_nums) 114 | self._max_z = np.max(atomic_nums) + 1 115 | if kwargs.get("return_maxz", False): 116 | return 117 | 118 | # Determine I/O data 119 | input_type = kwargs.get("input_type", "cartesians").lower() 120 | output_type = kwargs.get("output_type", "cartesians").lower() 121 | 122 | if input_type == "classifier" or output_type == "classifier": 123 | tiled_atomic_nums, tiled_cartesians, labels = self.tile_arrays( 124 | atomic_nums, cartesians 125 | ) 126 | x, y = self.shuffle_arrays( 127 | [tiled_atomic_nums, tiled_cartesians], [labels], len(labels) 128 | ) 129 | length = len(labels) 130 | 131 | elif input_type == "siamese" or output_type == "siamese": 132 | x, y = self.make_siamese_dataset( 133 | *self.tile_arrays( 134 | atomic_nums, cartesians, blacklist=kwargs.pop("blacklist", None) 135 | ) 136 | ) 137 | if kwargs.get("shuffle", True): 138 | x, y = self.shuffle_arrays(x, y, len(y[0])) 139 | length = len(y[0]) 140 | 141 | else: # Regression dataset 142 | length = len(atomic_nums) 143 | x = [ 144 | atomic_nums, 145 | cartesians["reactant_complex"] 146 | if kwargs.get("use_complexes", False) 147 | else cartesians["reactant"], 148 | cartesians["product_complex"] 149 | if kwargs.get("use_complexes", False) 150 | else cartesians["product"], 151 | ] 152 | y = [ 153 | np.triu( 154 | MaskedDistanceMatrix()( 155 | [OneHot(self.max_z)(atomic_nums), cartesians["ts"]] 156 | ) 157 | ) 158 | if kwargs.get("output_distance_matrix", False) 159 | else 
cartesians["ts"], 160 | energies["ts"], 161 | ] 162 | if output_type == "energies": 163 | y.pop(0) 164 | elif output_type == "both": 165 | pass 166 | else: 167 | y.pop(1) 168 | 169 | # shuffle dataset 170 | if kwargs.get("shuffle", True): 171 | x, y = self.shuffle_arrays(x, y, length) 172 | 173 | # Split and serve data 174 | self.data = [x, y] 175 | self.dataset_length = length 176 | if self.splitting == "custom": 177 | split = [ 178 | 0, # hetero-ring structure, complex 179 | 3, # 3 member double bond ring, simple reaction 180 | 7, # methyl-chloride, super simple 181 | 11, # ? 182 | 16, # ispropyl-chloride, little more complex 183 | 22, # ? 184 | 24, # Triple bond, perfect midpoint 185 | ] 186 | val = [[a[split] for a in x], [a[split] for a in y]] 187 | train = [ 188 | [np.delete(a, split, 0) for a in x], 189 | [np.delete(a, split, 0) for a in y], 190 | ] 191 | return train, val, None 192 | else: 193 | return super().load_data() 194 | 195 | def make_siamese_dataset(self, tiled_atomic_nums, tiled_cartesians, labels): 196 | # Make x shape: (batch, batch, 2, points, 3) Convert for output -> (batch, 2, points, 3) 197 | c = np.zeros((len(labels), len(labels), 2, self.num_points, 3)) 198 | a = np.zeros(c.shape[:-1]) 199 | diff = np.where( 200 | (np.expand_dims(labels, -1) - np.expand_dims(labels, -2)) != 0, 1, 0 201 | ) 202 | indices = np.triu_indices(diff.shape[0], 1) 203 | for i, (i_atomic_nums, i_cartesians) in enumerate( 204 | zip(tiled_atomic_nums, tiled_cartesians) 205 | ): 206 | for j, (j_atomic_nums, j_cartesians) in enumerate( 207 | zip(tiled_atomic_nums, tiled_cartesians) 208 | ): 209 | a[i, j, 1], c[i, j, 1] = i_atomic_nums, i_cartesians 210 | a[i, j, 0], c[i, j, 0] = j_atomic_nums, j_cartesians 211 | 212 | # assign data 213 | labels = [diff[indices]] 214 | x = [a[indices], c[indices]] 215 | return x, labels 216 | 217 | @staticmethod 218 | def tile_arrays(atomic_nums, cartesians, blacklist: list = None): 219 | """:return: tiled/concatenated arrays: [atomic_nums, cartesians <- (concat), labels]""" 220 | blacklist = blacklist or [] 221 | tiled_atomic_nums = np.tile(atomic_nums, (5 - len(blacklist), 1)) 222 | tiled_cartesians = np.concatenate( 223 | [a for key, a in cartesians.items() if key not in blacklist], axis=0 224 | ) 225 | labels = np.zeros((len(tiled_atomic_nums),), dtype="int32") 226 | labels[: len(atomic_nums)] = 1 227 | return tiled_atomic_nums, tiled_cartesians, labels 228 | -------------------------------------------------------------------------------- /tfn/tools/loggers.py: -------------------------------------------------------------------------------- 1 | from sacred.run import Run 2 | from tensorflow.keras.callbacks import Callback 3 | 4 | 5 | class SacredMetricLogger(Callback): 6 | def __init__(self, _run: Run): 7 | super().__init__() 8 | self.run = _run 9 | self.batch_log_rate = 500 10 | self.epoch_log_rate = 1 11 | 12 | def on_epoch_end(self, epoch, logs=None): 13 | if epoch == 0 or epoch % self.epoch_log_rate == 0: 14 | for key, value in logs.items(): 15 | if key != "epoch" and key != "size": 16 | self.run.log_scalar("{}".format(key), value=value, step=epoch) 17 | if epoch % (self.epoch_log_rate * 10) == 0: 18 | self.run.result = value 19 | -------------------------------------------------------------------------------- /tfn/tools/radials.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import tensorflow as tf 4 | from tensorflow.keras.models import model_from_json 5 | from tfn.layers import 
DenseRadialFactory 6 | 7 | 8 | def get_radial_factory(identifier="multi_dense", radial_kwargs: dict = None): 9 | radial = None 10 | if identifier == "single_dense": 11 | radial = SingleModelDenseRadialFactory 12 | else: 13 | radial = DenseRadialFactory 14 | 15 | if radial_kwargs is not None: 16 | return radial(**radial_kwargs) 17 | else: 18 | return radial() 19 | 20 | 21 | class SingleModelDenseRadialFactory(DenseRadialFactory): 22 | def __init__(self, *args, **kwargs): 23 | self.radial = kwargs.pop("radial", None) 24 | super().__init__(*args, **kwargs) 25 | 26 | def to_json(self): 27 | self.__dict__["radial"] = None 28 | return super().to_json() 29 | 30 | @classmethod 31 | def from_json(cls, config: str): 32 | config = json.loads(config) 33 | if config["radial"]: 34 | config["radial"] = model_from_json(config["radial"]) 35 | else: 36 | config["radial"] = None 37 | return cls(**config) 38 | 39 | def get_radial(self, feature_dim, input_order=None, filter_order=None): 40 | if self.radial is None: 41 | self.radial = super().get_radial(feature_dim, input_order, filter_order) 42 | return self.radial 43 | 44 | 45 | tf.keras.utils.get_custom_objects().update( 46 | {SingleModelDenseRadialFactory.__name__: SingleModelDenseRadialFactory,} 47 | ) 48 | -------------------------------------------------------------------------------- /tutorials/cat_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UPEIChemistry/tensor-field-networks/5c25583ee4108a13af8e73eabd3c448f42cb70a0/tutorials/cat_pic.png -------------------------------------------------------------------------------- /tutorials/cat_pic_rotated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UPEIChemistry/tensor-field-networks/5c25583ee4108a13af8e73eabd3c448f42cb70a0/tutorials/cat_pic_rotated.png --------------------------------------------------------------------------------
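Putting the loader and radial-factory utilities above together, a minimal end-to-end sketch (the dataset path and keyword values are illustrative assumptions, not shipped defaults):

    from tfn.tools.loaders import TSLoader
    from tfn.tools.radials import get_radial_factory

    # Hypothetical path to the TS dataset; the loader pads point axes and
    # remaps atomic numbers as described by DataLoader above.
    loader = TSLoader("data/ts.hdf5", splitting="70:20:10")
    train, val, test = loader.load_data(output_type="cartesians")

    # "single_dense" shares one radial network across all rotation orders;
    # any other identifier falls back to a fresh DenseRadialFactory.
    factory = get_radial_factory("single_dense")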