├── .gitattributes ├── .gitignore ├── MANIFEST.in ├── README.md ├── SynthesisSimilarity ├── __init__.py ├── __main__.py ├── core │ ├── __init__.py │ ├── activations │ │ ├── __init__.py │ │ ├── gelu.py │ │ ├── gelu_test.py │ │ ├── swish.py │ │ └── swish_test.py │ ├── bert_modeling.py │ ├── bert_optimization.py │ ├── callbacks.py │ ├── circle_loss.py │ ├── encoders.py │ ├── exp_models.py │ ├── focal_loss.py │ ├── layers.py │ ├── losses.py │ ├── mat_featurization.py │ ├── model_framework.py │ ├── model_utils.py │ ├── task_models.py │ ├── tf_utils.py │ ├── utils.py │ └── vector_utils.py ├── examples │ ├── __init__.py │ └── synthesis_recommendation.py ├── models │ ├── SynthesisEncoding │ │ ├── model_config.json │ │ └── saved_model │ │ │ ├── saved_model.pb │ │ │ └── variables │ │ │ ├── variables.data-00000-of-00001 │ │ │ └── variables.index │ └── SynthesisRecommendation │ │ ├── cmd_parameters.json │ │ ├── model_meta.pkl │ │ └── saved_model │ │ ├── checkpoint │ │ ├── cp.ckpt.data-00000-of-00001 │ │ └── cp.ckpt.index ├── scripts │ ├── _00_download_model_and_data.py │ ├── _01_synthesis_recommendation.py │ ├── _02_target_material_similarity.py │ ├── _03_masked_precursor_completion.py │ ├── _04_reaction_relationship.py │ ├── _05_recommendation_benchmark.py │ ├── _06_computation_time_similarity.py │ └── __init__.py └── scripts_utils │ ├── FastTextSimilarity_utils.py │ ├── MatminerSimilarity_utils.py │ ├── TarMatSimilarity_utils.py │ ├── __init__.py │ ├── benchmark_utils.py │ ├── data_set_utils.py │ ├── multi_processing_utils.py │ ├── precursors_recommendation_utils.py │ ├── reaction_utils.py │ ├── recommendation_utils.py │ ├── similarity_utils.py │ └── train_utils.py ├── requirements.txt ├── requirements_optional.txt └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | SynthesisSimilarity/other_rsc/** filter=lfs diff=lfs merge=lfs -text 2 | SynthesisSimilarity/rsc/** filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | .static_storage/ 58 | .media/ 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | ### Example user template template 108 | ### Example user template 109 | 110 | # IntelliJ project files 111 | .idea 112 | *.iml 113 | out 114 | gen 115 | 116 | SynthesisSimilarity/scripts_not_used 117 | SynthesisSimilarity/scripts_debug 118 | SynthesisSimilarity/scratch 119 | SynthesisSimilarity/rsc_preparation 120 | SynthesisSimilarity/generated -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include SynthesisSimilarity/models/* 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Precursor recommendation for inorganic synthesis by machine learning materials similarity from scientific literature 2 | --- 3 | 4 | Data and code for "Precursor recommendation for inorganic synthesis by machine learning materials similarity from scientific literature". 5 | 6 | ## Installation 7 | 8 | ```bash 9 | git clone https://github.com/CederGroupHub/SynthesisSimilarity.git 10 | cd SynthesisSimilarity 11 | pip install -e . 12 | cd .. 13 | # download the data necessary for synthesis recommendation 14 | python -m SynthesisSimilarity download_necessary_data 15 | # The following command is not needed for synthesis recommendation. 16 | # It only downloads optional data used for benchmarking. 17 | # (optional) python -m SynthesisSimilarity download_optional_data 18 | ``` 19 | 20 | ## Description of the data and file structure 21 | 22 | Precursor recommendation is implemented by referring the synthesis of a novel target material to the known recipe of a similar material, mimicking the human synthesis design process. 23 | The similarity of two target materials is evaluated as the cosine similarity of the encoded vectors generated by the synthesis context-based encoding model (PrecursorSelector encoding) in this work. 24 | When the precursors from the reference material do not cover all the elements in the target, we use a masked precursor completion (MPC) model to predict the missing precursors. 25 | 26 | In brief, the scripts reproducing the main results of this work are in the folder "scripts". 27 | Other auxiliary code is in the folders "core" and "scripts_utils".
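For a quick start, the snippet below sketches how the recommendation interface exposed in `SynthesisSimilarity/__init__.py` could be called once the package and the necessary data are installed. The method name and arguments are illustrative assumptions rather than the verified API; the authoritative usage is in `SynthesisSimilarity/examples/synthesis_recommendation.py` and `SynthesisSimilarity/scripts/_01_synthesis_recommendation.py`.

```python
# Minimal sketch (run it outside the source directory, after downloading the data).
# NOTE: the method name and arguments below are assumptions for illustration only;
# see examples/synthesis_recommendation.py for the actual interface.
from SynthesisSimilarity import PrecursorsRecommendation

recommender = PrecursorsRecommendation()  # loads the trained model from "models/"
results = recommender.recommend_precursors(
    target_formula=["LaAlO3"],  # hypothetical argument: target composition(s)
    top_n=5,                    # hypothetical argument: number of recommended recipes
)
print(results)
```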
28 | The trained model is in the folder "models". 29 | If you download the data using "python -m SynthesisSimilarity download_necessary_data" and "python -m SynthesisSimilarity download_optional_data", the downloaded data will be saved in the folders "rsc" and "other_rsc". 30 | More details are given below. 31 | 32 | ``` 33 | SynthesisSimilarity 34 | ├── README.md # A simple introduction of the repo 35 | ├── setup.py # Used to install the repo as a python package 36 | ├── MANIFEST.in # MANIFEST file used by setup.py 37 | ├── requirements.txt # Python packages required for this repo 38 | ├── requirements_optional.txt # Optional packages not needed for basic use 39 | └── SynthesisSimilarity # The main directory 40 | ├── core # The directory of the core modules and framework for the PrecursorSelector model in this work 41 | │ ├── activations # Activation functions (from https://github.com/tensorflow/models) 42 | │ │ ├── gelu.py # Activation function of gelu() 43 | │ │ ├── gelu_test.py # Test for gelu() 44 | │ │ ├── __init__.py # Python init script for current directory 45 | │ │ ├── swish.py # Activation function of swish() 46 | │ │ └── swish_test.py # Test for swish() 47 | │ ├── bert_modeling.py # Attention block (from https://github.com/tensorflow/models) 48 | │ ├── bert_optimization.py # Additional optimization functions (from https://github.com/tensorflow/models) 49 | │ ├── callbacks.py # Callback functions for monitoring the training process and validation 50 | │ ├── circle_loss.py # Circle loss (adapted from https://github.com/zhen8838/Circle-Loss) 51 | │ ├── encoders.py # Encoder functions to convert the composition of a target material to an encoded vector 52 | │ ├── exp_models.py # An example of how to extend the current model to other synthesis prediction tasks 53 | │ ├── focal_loss.py # Focal loss (from https://github.com/artemmavrin/focal-loss) 54 | │ ├── __init__.py # Python init script for current directory 55 | │ ├── layers.py # Low-level neural network modules to be inserted as layers in a more complex network 56 | │ ├── losses.py # The loss function used for gradient descent 57 | │ ├── mat_featurization.py # An example of how to extend the input from composition to other material features 58 | │ ├── model_framework.py # The multi-task framework of the representation model in this work 59 | │ ├── model_utils.py # Handy functions to use the model 60 | │ ├── task_models.py # Neural network modules corresponding to different prediction tasks to be used in the multi-task framework 61 | │ ├── tf_utils.py # Handy functions for tensorflow (adapted from https://github.com/tensorflow/models) 62 | │ ├── utils.py # Handy functions for data processing 63 | │ └── vector_utils.py # Handy functions for operations with vectors 64 | ├── examples # The directory of useful examples 65 | │ ├── synthesis_recommendation.py # Precursor recommendation for the given composition of a target material 66 | │ └── __init__.py # Python init script for current directory 67 | ├── __init__.py # Python init script for current directory 68 | ├── __main__.py # Used for module commands such as "python -m SynthesisSimilarity download_necessary_data" 69 | ├── models # The directory of trained models 70 | │ ├── SynthesisEncoding # The directory of minimal model files for similarity evaluation (the encoder part of the whole model) 71 | │ │ ├── model_config.json # The configuration file summarizing important attributes of the model 72 | │ │ └── saved_model # The directory of files for reloading a tensorflow model 73
| │ │ ├── assets # A directory for reloading a tensorflow model 74 | │ │ ├── saved_model.pb # A file for reloading a tensorflow model 75 | │ │ └── variables # A directory of files for reloading a tensorflow model 76 | │ │ ├── variables.data-00000-of-00001 # A file for reloading a tensorflow model 77 | │ │ └── variables.index # A file for reloading a tensorflow model 78 | │ └── SynthesisRecommendation # The directory of model files for precursor recommendation 79 | │ ├── cmd_parameters.json # The configuration file summarizing important attributes of the model 80 | │ ├── model_meta.pkl # The configuration file of all attributes of the model 81 | │ └── saved_model # The directory of files for reloading a tensorflow model 82 | │ ├── checkpoint # A checkpoint file for reloading a tensorflow model 83 | │ ├── cp.ckpt.data-00000-of-00001 # A checkpoint file for reloading a tensorflow model 84 | │ └── cp.ckpt.index # A checkpoint file for reloading a tensorflow model 85 | ├── other_rsc # The directory of model files for benchmarking, not needed for the model in this study 86 | │ ├── fasttext_pretrained_matsci # The FastText encoding model from https://figshare.com/s/70455cfcd0084a504745 (Kim, E., Jensen, Z., Grootel, A.V., Huang, K., Staib, M., Mysore, S., Chang, H.S., Strubell, E., McCallum, A., Jegelka, S. and Olivetti, E., 2020. Inorganic Materials Synthesis Planning with Literature-Trained Neural Networks. Journal of Chemical Information and Modeling) 87 | │ │ ├── fasttext_embeddings-MINIFIED.model # A file for reloading the FastText model 88 | │ │ ├── fasttext_embeddings-MINIFIED.model.vectors_ngrams.npy # A file for reloading the FastText model 89 | │ │ ├── fasttext_embeddings-MINIFIED.model.vectors.npy # A file for reloading the FastText model 90 | │ │ └── fasttext_embeddings-MINIFIED.model.vectors_vocab.npy # A file for reloading the FastText model 91 | │ └── matminer # The Magpie encoding model retrieved from the matminer package (Ward, L., Dunn, A., Faghaninia, A., Zimmermann, N. E. R., Bajaj, S., Wang, Q., Montoya, J. H., Chen, J., Bystrom, K., Dylla, M., Chard, K., Asta, M., Persson, K., Snyder, G. J., Foster, I., Jain, A., Matminer: An open source toolkit for materials data mining. Comput. Mater. Sci. 152, 60-69 (2018). Ward, L., Agrawal, A., Choudhary, A., & Wolverton, C. (2016). A general-purpose machine learning framework for predicting properties of inorganic materials. npj Computational Materials, 2(1), 1-7.)
92 | │ ├── mp_imputer_preset_v1.0.2.pkl # A file for reloading the Magpie model 93 | │ └── mp_scaler_preset_v1.0.2.pkl # A file for reloading the Magpie model 94 | ├── rsc # The directory of data files for model training and evaluation in this work 95 | │ ├── data_split.npz # Data split based on publication year and prototype formula 96 | │ ├── ele_order_counter.json # Statistics of how often authors put one element in front of another when writing the string for a material formula 97 | │ ├── pre_count_normalized_by_rxn_ss.json # Statistics of how frequently each precursor is used in the literature-mined synthesis reactions 98 | │ └── reactions_v20_20210820_ss.jsonl # The text-mined solid-state synthesis dataset from materials science papers 99 | ├── scripts # The directory of useful scripts reproducing the main results in this work 100 | │ ├── _00_download_model_and_data.py # Download data from Google Drive for the PrecursorSelector model 101 | │ ├── _01_synthesis_recommendation.py # Precursor recommendation for the given composition of a target material 102 | │ ├── _02_target_material_similarity.py # Similarity evaluation for two target materials based on the PrecursorSelector encoding 103 | │ ├── _03_masked_precursor_completion.py # Prediction of the complete precursors given the target material and partial precursors 104 | │ ├── _04_reaction_relationship.py # Plot relationships between targets and their shared precursors 105 | │ ├── _05_recommendation_benchmark.py # Benchmark of precursor recommendation using various algorithms 106 | │ ├── _06_computation_time_similarity.py # Time cost for similarity evaluation 107 | │ └── __init__.py # Python init script for current directory 108 | └── scripts_utils # The directory of handy functions for the scripts 109 | ├── benchmark_utils.py # Handy functions for benchmarking 110 | ├── data_set_utils.py # Handy functions for loading data 111 | ├── FastTextSimilarity_utils.py # Handy functions for using the FastText model 112 | ├── __init__.py # Python init script for current directory 113 | ├── MatminerSimilarity_utils.py # Handy functions for using the Magpie model 114 | ├── multi_processing_utils.py # Handy functions for parallel processing with multiple CPU cores 115 | ├── precursors_recommendation_utils.py # Handy functions for precursor recommendation 116 | ├── recommendation_utils.py # Handy functions for general recommendation 117 | ├── similarity_utils.py # Handy functions for general similarity evaluation 118 | ├── TarMatSimilarity_utils.py # Handy functions for evaluation of target similarity 119 | └── train_utils.py # Handy functions for model training 120 | ``` 121 | 122 | 123 | ## Sharing/Access Information 124 | 125 | https://github.com/CederGroupHub/SynthesisSimilarity 126 | 127 | -------------------------------------------------------------------------------- /SynthesisSimilarity/__init__.py: -------------------------------------------------------------------------------- 1 | from .scripts_utils import PrecursorsRecommendation 2 | from .scripts import download_necessary_data 3 | from .scripts import download_optional_data -------------------------------------------------------------------------------- /SynthesisSimilarity/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import warnings 4 | 5 | from .scripts import download_necessary_data 6 | from .scripts import download_optional_data 7 | 8 | def main(): 9 | # Check if the script is being run from
the source directory 10 | if os.path.abspath(os.getcwd()) == os.path.abspath( 11 | os.path.join(os.path.dirname(__file__), '..') 12 | ): 13 | warnings.warn( 14 | """Running from the source directory is not supported. 15 | To avoid confusion, please switch to another directory and run again. 16 | For example, run "cd .." first and then "python -m SynthesisSimilarity download_necessary_data". 17 | """ 18 | ) 19 | return 20 | 21 | if len(sys.argv) > 1: 22 | command = sys.argv[1] 23 | if command in {"download_data", "download_necessary_data"}: 24 | download_necessary_data() 25 | elif command in {"download_optional_data"}: 26 | download_optional_data() 27 | else: 28 | print(f"Unknown command: {command}") 29 | else: 30 | print("Usage: python -m SynthesisSimilarity <command>") 31 | print( 32 | "Available commands: " 33 | "(1) download_necessary_data, " 34 | "(2) download_optional_data." 35 | ) 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | from __future__ import division 4 | from __future__ import absolute_import 5 | 6 | __author__ = 'Tanjin He' 7 | __maintainer__ = 'Tanjin He' 8 | __email__ = 'tanjin_he@berkeley.edu' 9 | 10 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/activations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Activations package definition.""" 16 | from .gelu import gelu 17 | from .swish import swish 18 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/activations/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Gaussian error linear unit.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | 23 | import tensorflow as tf 24 | 25 | 26 | def gelu(x): 27 | """Gaussian Error Linear Unit. 28 | 29 | This is a smoother version of the RELU. 30 | Original paper: https://arxiv.org/abs/1606.08415 31 | Args: 32 | x: float Tensor to perform activation. 33 | 34 | Returns: 35 | `x` with the GELU activation applied. 36 | """ 37 | cdf = 0.5 * (1.0 + tf.tanh( 38 | (math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))) 39 | return x * cdf 40 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/activations/gelu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for the Gaussian error linear unit.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | 24 | from . import gelu 25 | 26 | 27 | def test_gelu(): 28 | expected_data = [[0.14967535, 0., -0.10032465], 29 | [-0.15880796, -0.04540223, 2.9963627]] 30 | gelu_data = gelu([[.25, 0, -.25], [-1, -2, 3]]) 31 | # assertAllClose(expected_data, gelu_data) 32 | 33 | 34 | if __name__ == '__main__': 35 | test_gelu() 36 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/activations/swish.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Customized Swish activation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | 24 | def swish(features): 25 | """Computes the Swish activation function. 
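Functionally, this computes swish(x) = x * sigmoid(x), matching tf.nn.swish numerically.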
26 | 27 | The tf.nn.swish operation uses a custom gradient to reduce memory usage. 28 | Since saving custom gradients in SavedModel is currently not supported, and 29 | one would not be able to use an exported TF-Hub module for fine-tuning, we 30 | provide this wrapper that allows selecting whether to use the native 31 | TensorFlow swish operation, or a customized operation that 32 | uses the default TensorFlow gradient computation. 33 | 34 | Args: 35 | features: A `Tensor` representing preactivation values. 36 | 37 | Returns: 38 | The activation value. 39 | """ 40 | features = tf.convert_to_tensor(features) 41 | return features * tf.nn.sigmoid(features) 42 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/activations/swish_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for the customized Swish activation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from . import swish 24 | 25 | def test_swish(): 26 | customized_swish_data = swish([[.25, 0, -.25], [-1, -2, 3]]) 27 | swish_data = tf.nn.swish([[.25, 0, -.25], [-1, -2, 3]]) 28 | # assertAllClose(customized_swish_data, swish_data) 29 | 30 | 31 | if __name__ == '__main__': 32 | test_swish() 33 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/bert_modeling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 16 | The main BERT model and related functions.
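This copy keeps the attention-related pieces: the Attention layer, the Dense3D and Dense2DProjection projection layers, TransformerBlock, the Transformer stack, and the helpers get_initializer() and create_attention_mask_from_input_mask(). Adapted from: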
17 | https://github.com/tensorflow/models/blob/ 18 | 49ba237d35d2a049be7bede596f4b29fd85cfe28/official/nlp/bert_modeling.py 19 | """ 20 | 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import copy 27 | import json 28 | import math 29 | import six 30 | import tensorflow as tf 31 | 32 | from . import tf_utils 33 | 34 | class Attention(tf.keras.layers.Layer): 35 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 36 | 37 | This is an implementation of multi-headed attention based on "Attention 38 | Is All You Need". If `from_tensor` and `to_tensor` are the same, then 39 | this is self-attention. Each timestep in `from_tensor` attends to the 40 | corresponding sequence in `to_tensor`, and returns a fixed-width vector. 41 | 42 | This function first projects `from_tensor` into a "query" tensor and 43 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 44 | of tensors of length `num_attention_heads`, where each tensor is of shape 45 | [batch_size, seq_length, size_per_head]. 46 | 47 | Then, the query and key tensors are dot-producted and scaled. These are 48 | softmaxed to obtain attention probabilities. The value tensors are then 49 | interpolated by these probabilities, then concatenated back to a single 50 | tensor and returned. 51 | 52 | In practice, the multi-headed attention is done with tf.einsum as follows: 53 | Input_tensor: [BFD] 54 | Wq, Wk, Wv: [DNH] 55 | Q:[BFNH] = einsum('BFD,DNH->BFNH', Input_tensor, Wq) 56 | K:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wk) 57 | V:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wv) 58 | attention_scores:[BNFT] = einsum('BTNH,BFNH->BNFT', K, Q) / sqrt(H) 59 | attention_probs:[BNFT] = softmax(attention_scores) 60 | context_layer:[BFNH] = einsum('BNFT,BTNH->BFNH', attention_probs, V) 61 | Wout:[DNH] 62 | Output:[BFD] = einsum('BFNH,DNH->BFD', context_layer, Wout) 63 | """ 64 | 65 | def __init__(self, 66 | num_attention_heads=12, 67 | size_per_head=64, 68 | attention_probs_dropout_prob=0.0, 69 | initializer_range=0.02, 70 | backward_compatible=False, 71 | **kwargs): 72 | super(Attention, self).__init__(**kwargs) 73 | self.num_attention_heads = num_attention_heads 74 | self.size_per_head = size_per_head 75 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 76 | self.initializer_range = initializer_range 77 | self.backward_compatible = backward_compatible 78 | 79 | def build(self, unused_input_shapes): 80 | """Implements build() for the layer.""" 81 | self.query_dense = self._projection_dense_layer("query") 82 | self.key_dense = self._projection_dense_layer("key") 83 | self.value_dense = self._projection_dense_layer("value") 84 | self.attention_probs_dropout = tf.keras.layers.Dropout( 85 | rate=self.attention_probs_dropout_prob) 86 | super(Attention, self).build(unused_input_shapes) 87 | 88 | def reshape_to_matrix(self, input_tensor): 89 | """Reshape N > 2 rank tensor to rank 2 tensor for performance.""" 90 | ndims = input_tensor.shape.ndims 91 | if ndims < 2: 92 | raise ValueError("Input tensor must have at least rank 2."
93 | "Shape = %s" % (input_tensor.shape)) 94 | if ndims == 2: 95 | return input_tensor 96 | 97 | width = input_tensor.shape[-1] 98 | output_tensor = tf.reshape(input_tensor, [-1, width]) 99 | return output_tensor 100 | 101 | def __call__(self, from_tensor, to_tensor, attention_mask=None, **kwargs): 102 | inputs = tf_utils.pack_inputs([from_tensor, to_tensor, attention_mask]) 103 | return super(Attention, self).__call__(inputs, **kwargs) 104 | 105 | def call(self, inputs): 106 | """Implements call() for the layer.""" 107 | (from_tensor, to_tensor, attention_mask) = tf_utils.unpack_inputs(inputs) 108 | 109 | # from_tensor_norm = tf.sqrt( 110 | # tf.reduce_sum( 111 | # tf.square(from_tensor), 112 | # axis=-1, 113 | # keepdims=True 114 | # ) 115 | # ) + 1e-12 116 | # from_tensor_univec = from_tensor/from_tensor_norm 117 | 118 | # Scalar dimensions referenced here: 119 | # B = batch size (number of sequences) 120 | # F = `from_tensor` sequence length 121 | # T = `to_tensor` sequence length 122 | # N = `num_attention_heads` 123 | # H = `size_per_head` 124 | 125 | # `query_tensor` = [B, F, N ,H] 126 | query_tensor = self.query_dense(from_tensor) 127 | # query_tensor = self.query_dense(from_tensor_univec) 128 | 129 | # `key_tensor` = [B, T, N, H] 130 | key_tensor = self.key_dense(to_tensor) 131 | 132 | # `value_tensor` = [B, T, N, H] 133 | value_tensor = self.value_dense(to_tensor) 134 | 135 | # Take the dot product between "query" and "key" to get the raw 136 | # attention scores. 137 | attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor) 138 | attention_scores = tf.multiply(attention_scores, 139 | 1.0 / math.sqrt(float(self.size_per_head))) 140 | 141 | if attention_mask is not None: 142 | # `attention_mask` = [B, 1, F, T] 143 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 144 | 145 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 146 | # masked positions, this operation will create a tensor which is 0.0 for 147 | # positions we want to attend and -10000.0 for masked positions. 148 | adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0 149 | 150 | # Since we are adding it to the raw scores before the softmax, this is 151 | # effectively the same as removing these entirely. 152 | attention_scores += adder 153 | 154 | # Normalize the attention scores to probabilities. 155 | # `attention_probs` = [B, N, F, T] 156 | attention_probs = tf.nn.softmax(attention_scores) 157 | 158 | # This is actually dropping out entire tokens to attend to, which might 159 | # seem a bit unusual, but is taken from the original Transformer paper. 160 | attention_probs = self.attention_probs_dropout(attention_probs) 161 | 162 | # `context_layer` = [B, F, N, H] 163 | context_tensor = tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor) 164 | 165 | return context_tensor 166 | 167 | def _projection_dense_layer(self, name): 168 | """A helper to define a projection layer.""" 169 | return Dense3D( 170 | num_attention_heads=self.num_attention_heads, 171 | size_per_head=self.size_per_head, 172 | kernel_initializer=get_initializer(self.initializer_range), 173 | output_projection=False, 174 | backward_compatible=self.backward_compatible, 175 | name=name) 176 | 177 | 178 | class Dense3D(tf.keras.layers.Layer): 179 | """A Dense Layer using 3D kernel with tf.einsum implementation. 180 | 181 | Attributes: 182 | num_attention_heads: An integer, number of attention heads for each 183 | multihead attention layer. 
184 | size_per_head: An integer, hidden size per attention head. 185 | hidden_size: An integer, dimension of the hidden layer. 186 | kernel_initializer: An initializer for the kernel weight. 187 | bias_initializer: An initializer for the bias. 188 | activation: An activation function to use. If nothing is specified, no 189 | activation is applied. 190 | use_bias: A bool, whether the layer uses a bias. 191 | output_projection: A bool, whether the Dense3D layer is used for output 192 | linear projection. 193 | backward_compatible: A bool, whether the variables shape are compatible 194 | with checkpoints converted from TF 1.x. 195 | """ 196 | 197 | def __init__(self, 198 | num_attention_heads=12, 199 | size_per_head=72, 200 | kernel_initializer=None, 201 | bias_initializer="zeros", 202 | activation=None, 203 | use_bias=True, 204 | output_projection=False, 205 | backward_compatible=False, 206 | **kwargs): 207 | """Inits Dense3D.""" 208 | super(Dense3D, self).__init__(**kwargs) 209 | self.num_attention_heads = num_attention_heads 210 | self.size_per_head = size_per_head 211 | self.hidden_size = num_attention_heads * size_per_head 212 | self.kernel_initializer = kernel_initializer 213 | self.bias_initializer = bias_initializer 214 | self.activation = activation 215 | self.use_bias = use_bias 216 | self.output_projection = output_projection 217 | self.backward_compatible = backward_compatible 218 | 219 | @property 220 | def compatible_kernel_shape(self): 221 | if self.output_projection: 222 | return [self.hidden_size, self.hidden_size] 223 | return [self.last_dim, self.hidden_size] 224 | 225 | @property 226 | def compatible_bias_shape(self): 227 | return [self.hidden_size] 228 | 229 | @property 230 | def kernel_shape(self): 231 | if self.output_projection: 232 | return [self.num_attention_heads, self.size_per_head, self.hidden_size] 233 | return [self.last_dim, self.num_attention_heads, self.size_per_head] 234 | 235 | @property 236 | def bias_shape(self): 237 | if self.output_projection: 238 | return [self.hidden_size] 239 | return [self.num_attention_heads, self.size_per_head] 240 | 241 | def build(self, input_shape): 242 | """Implements build() for the layer.""" 243 | dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) 244 | if not (dtype.is_floating or dtype.is_complex): 245 | raise TypeError("Unable to build `Dense3D` layer with non-floating " 246 | "point (and non-complex) dtype %s" % (dtype,)) 247 | input_shape = tf.TensorShape(input_shape) 248 | if tf.compat.dimension_value(input_shape[-1]) is None: 249 | raise ValueError("The last dimension of the inputs to `Dense3D` " 250 | "should be defined. Found `None`.") 251 | self.last_dim = tf.compat.dimension_value(input_shape[-1]) 252 | self.input_spec = tf.keras.layers.InputSpec( 253 | min_ndim=3, axes={-1: self.last_dim}) 254 | # Determines variable shapes. 
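# When backward_compatible is True, the weights are created with the flat 2D shapes
# (compatible_kernel_shape / compatible_bias_shape) so that checkpoints converted
# from TF 1.x load directly; call() then reshapes them back to the per-head 3D shapes
# before the einsum. Otherwise the 3D kernel_shape / bias_shape are used as-is.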
255 | if self.backward_compatible: 256 | kernel_shape = self.compatible_kernel_shape 257 | bias_shape = self.compatible_bias_shape 258 | else: 259 | kernel_shape = self.kernel_shape 260 | bias_shape = self.bias_shape 261 | 262 | self.kernel = self.add_weight( 263 | "kernel", 264 | shape=kernel_shape, 265 | initializer=self.kernel_initializer, 266 | dtype=self.dtype, 267 | trainable=True) 268 | if self.use_bias: 269 | self.bias = self.add_weight( 270 | "bias", 271 | shape=bias_shape, 272 | initializer=self.bias_initializer, 273 | dtype=self.dtype, 274 | trainable=True) 275 | else: 276 | self.bias = None 277 | super(Dense3D, self).build(input_shape) 278 | 279 | def call(self, inputs): 280 | """Implements ``call()`` for Dense3D. 281 | 282 | Args: 283 | inputs: A float tensor of shape [batch_size, sequence_length, hidden_size] 284 | when output_projection is False, otherwise a float tensor of shape 285 | [batch_size, sequence_length, num_heads, dim_per_head]. 286 | 287 | Returns: 288 | The projected tensor with shape [batch_size, sequence_length, num_heads, 289 | dim_per_head] when output_projection is False, otherwise [batch_size, 290 | sequence_length, hidden_size]. 291 | """ 292 | if self.backward_compatible: 293 | kernel = tf.keras.backend.reshape(self.kernel, self.kernel_shape) 294 | bias = (tf.keras.backend.reshape(self.bias, self.bias_shape) 295 | if self.use_bias else None) 296 | else: 297 | kernel = self.kernel 298 | bias = self.bias 299 | 300 | if self.output_projection: 301 | ret = tf.einsum("abcd,cde->abe", inputs, kernel) 302 | else: 303 | ret = tf.einsum("abc,cde->abde", inputs, kernel) 304 | if self.use_bias: 305 | ret += bias 306 | if self.activation is not None: 307 | return self.activation(ret) 308 | return ret 309 | 310 | 311 | class Dense2DProjection(tf.keras.layers.Layer): 312 | """A 2D projection layer with tf.einsum implementation.""" 313 | 314 | def __init__(self, 315 | output_size, 316 | kernel_initializer=None, 317 | bias_initializer="zeros", 318 | activation=None, 319 | fp32_activation=False, 320 | **kwargs): 321 | super(Dense2DProjection, self).__init__(**kwargs) 322 | self.output_size = output_size 323 | self.kernel_initializer = kernel_initializer 324 | self.bias_initializer = bias_initializer 325 | self.activation = activation 326 | self.fp32_activation = fp32_activation 327 | 328 | def build(self, input_shape): 329 | """Implements build() for the layer.""" 330 | dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) 331 | if not (dtype.is_floating or dtype.is_complex): 332 | raise TypeError("Unable to build `Dense2DProjection` layer with " 333 | "non-floating point (and non-complex) " 334 | "dtype %s" % (dtype,)) 335 | input_shape = tf.TensorShape(input_shape) 336 | if tf.compat.dimension_value(input_shape[-1]) is None: 337 | raise ValueError("The last dimension of the inputs to " 338 | "`Dense2DProjection` should be defined. 
" 339 | "Found `None`.") 340 | last_dim = tf.compat.dimension_value(input_shape[-1]) 341 | self.input_spec = tf.keras.layers.InputSpec(min_ndim=3, axes={-1: last_dim}) 342 | self.kernel = self.add_weight( 343 | "kernel", 344 | shape=[last_dim, self.output_size], 345 | initializer=self.kernel_initializer, 346 | dtype=self.dtype, 347 | trainable=True) 348 | self.bias = self.add_weight( 349 | "bias", 350 | shape=[self.output_size], 351 | initializer=self.bias_initializer, 352 | dtype=self.dtype, 353 | trainable=True) 354 | super(Dense2DProjection, self).build(input_shape) 355 | 356 | def call(self, inputs): 357 | """Implements call() for Dense2DProjection. 358 | 359 | Args: 360 | inputs: float Tensor of shape [batch, from_seq_length, 361 | num_attention_heads, size_per_head]. 362 | 363 | Returns: 364 | A 3D Tensor. 365 | """ 366 | ret = tf.einsum("abc,cd->abd", inputs, self.kernel) 367 | ret += self.bias 368 | if self.activation is not None: 369 | if self.dtype == tf.float16 and self.fp32_activation: 370 | ret = tf.cast(ret, tf.float32) 371 | return self.activation(ret) 372 | return ret 373 | 374 | 375 | class TransformerBlock(tf.keras.layers.Layer): 376 | """Single transformer layer. 377 | 378 | It has two sub-layers. The first is a multi-head self-attention mechanism, and 379 | the second is a positionwise fully connected feed-forward network. 380 | """ 381 | 382 | def __init__(self, 383 | hidden_size=768, 384 | num_attention_heads=12, 385 | intermediate_size=3072, 386 | intermediate_activation="gelu", 387 | hidden_dropout_prob=0.0, 388 | attention_probs_dropout_prob=0.0, 389 | initializer_range=0.02, 390 | backward_compatible=False, 391 | float_type=tf.float32, 392 | **kwargs): 393 | super(TransformerBlock, self).__init__(**kwargs) 394 | self.hidden_size = hidden_size 395 | self.num_attention_heads = num_attention_heads 396 | self.intermediate_size = intermediate_size 397 | self.intermediate_activation = tf_utils.get_activation( 398 | intermediate_activation) 399 | self.hidden_dropout_prob = hidden_dropout_prob 400 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 401 | self.initializer_range = initializer_range 402 | self.backward_compatible = backward_compatible 403 | self.float_type = float_type 404 | 405 | if self.hidden_size % self.num_attention_heads != 0: 406 | raise ValueError( 407 | "The hidden size (%d) is not a multiple of the number of attention " 408 | "heads (%d)" % (self.hidden_size, self.num_attention_heads)) 409 | self.attention_head_size = int(self.hidden_size / self.num_attention_heads) 410 | 411 | def build(self, unused_input_shapes): 412 | """Implements build() for the layer.""" 413 | self.attention_layer = Attention( 414 | num_attention_heads=self.num_attention_heads, 415 | size_per_head=self.attention_head_size, 416 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 417 | initializer_range=self.initializer_range, 418 | backward_compatible=self.backward_compatible, 419 | name="self_attention") 420 | self.attention_output_dense = Dense3D( 421 | num_attention_heads=self.num_attention_heads, 422 | size_per_head=int(self.hidden_size / self.num_attention_heads), 423 | kernel_initializer=get_initializer(self.initializer_range), 424 | output_projection=True, 425 | backward_compatible=self.backward_compatible, 426 | name="self_attention_output") 427 | self.attention_dropout = tf.keras.layers.Dropout( 428 | rate=self.hidden_dropout_prob) 429 | self.attention_layer_norm = ( 430 | tf.keras.layers.LayerNormalization( 431 | 
name="self_attention_layer_norm", axis=-1, epsilon=1e-12, 432 | # We do layer norm in float32 for numeric stability. 433 | dtype=tf.float32)) 434 | self.intermediate_dense = Dense2DProjection( 435 | output_size=self.intermediate_size, 436 | kernel_initializer=get_initializer(self.initializer_range), 437 | activation=self.intermediate_activation, 438 | # Uses float32 so that gelu activation is done in float32. 439 | fp32_activation=True, 440 | name="intermediate") 441 | self.output_dense = Dense2DProjection( 442 | output_size=self.hidden_size, 443 | kernel_initializer=get_initializer(self.initializer_range), 444 | name="output") 445 | self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob) 446 | self.output_layer_norm = tf.keras.layers.LayerNormalization( 447 | name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32) 448 | super(TransformerBlock, self).build(unused_input_shapes) 449 | 450 | def common_layers(self): 451 | """Explicitly gets all layer objects inside a Transformer encoder block.""" 452 | return [ 453 | self.attention_layer, self.attention_output_dense, 454 | self.attention_dropout, self.attention_layer_norm, 455 | self.intermediate_dense, self.output_dense, self.output_dropout, 456 | self.output_layer_norm 457 | ] 458 | 459 | def __call__(self, input_tensor, attention_mask=None): 460 | inputs = tf_utils.pack_inputs([input_tensor, attention_mask]) 461 | return super(TransformerBlock, self).__call__(inputs) 462 | 463 | def call(self, inputs): 464 | """Implements call() for the layer.""" 465 | (input_tensor, attention_mask) = tf_utils.unpack_inputs(inputs) 466 | attention_output = self.attention_layer( 467 | from_tensor=input_tensor, 468 | to_tensor=input_tensor, 469 | attention_mask=attention_mask) 470 | attention_output = self.attention_output_dense(attention_output) 471 | attention_output = self.attention_dropout(attention_output) 472 | # Use float32 in keras layer norm and the gelu activation in the 473 | # intermediate dense layer for numeric stability 474 | 475 | attention_output = self.attention_layer_norm(input_tensor + 476 | attention_output) 477 | # attention_output = input_tensor + attention_output 478 | if self.float_type == tf.float16: 479 | attention_output = tf.cast(attention_output, tf.float16) 480 | intermediate_output = self.intermediate_dense(attention_output) 481 | if self.float_type == tf.float16: 482 | intermediate_output = tf.cast(intermediate_output, tf.float16) 483 | layer_output = self.output_dense(intermediate_output) 484 | layer_output = self.output_dropout(layer_output) 485 | # Use float32 in keras layer norm for numeric stability 486 | layer_output = self.output_layer_norm(layer_output + attention_output) 487 | # layer_output = layer_output + attention_output 488 | if self.float_type == tf.float16: 489 | layer_output = tf.cast(layer_output, tf.float16) 490 | return layer_output 491 | 492 | 493 | class Transformer(tf.keras.layers.Layer): 494 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 495 | 496 | This is almost an exact implementation of the original Transformer encoder. 
497 | 498 | See the original paper: 499 | https://arxiv.org/abs/1706.03762 500 | 501 | Also see: 502 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 503 | """ 504 | 505 | def __init__(self, 506 | num_hidden_layers=12, 507 | hidden_size=768, 508 | num_attention_heads=12, 509 | intermediate_size=3072, 510 | intermediate_activation="gelu", 511 | hidden_dropout_prob=0.0, 512 | attention_probs_dropout_prob=0.0, 513 | initializer_range=0.02, 514 | backward_compatible=False, 515 | float_type=tf.float32, 516 | **kwargs): 517 | super(Transformer, self).__init__(**kwargs) 518 | self.num_hidden_layers = num_hidden_layers 519 | self.hidden_size = hidden_size 520 | self.num_attention_heads = num_attention_heads 521 | self.intermediate_size = intermediate_size 522 | self.intermediate_activation = tf_utils.get_activation( 523 | intermediate_activation) 524 | self.hidden_dropout_prob = hidden_dropout_prob 525 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 526 | self.initializer_range = initializer_range 527 | self.backward_compatible = backward_compatible 528 | self.float_type = float_type 529 | 530 | def build(self, unused_input_shapes): 531 | """Implements build() for the layer.""" 532 | self.layers = [] 533 | for i in range(self.num_hidden_layers): 534 | self.layers.append( 535 | TransformerBlock( 536 | hidden_size=self.hidden_size, 537 | num_attention_heads=self.num_attention_heads, 538 | intermediate_size=self.intermediate_size, 539 | intermediate_activation=self.intermediate_activation, 540 | hidden_dropout_prob=self.hidden_dropout_prob, 541 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 542 | initializer_range=self.initializer_range, 543 | backward_compatible=self.backward_compatible, 544 | float_type=self.float_type, 545 | name=("layer_%d" % i))) 546 | super(Transformer, self).build(unused_input_shapes) 547 | 548 | def __call__(self, input_tensor, attention_mask=None, **kwargs): 549 | inputs = tf_utils.pack_inputs([input_tensor, attention_mask]) 550 | return super(Transformer, self).__call__(inputs=inputs, **kwargs) 551 | 552 | def call(self, inputs, return_all_layers=False): 553 | """Implements call() for the layer. 554 | 555 | Args: 556 | inputs: packed inputs. 557 | return_all_layers: bool, whether to return outputs of all layers inside 558 | encoders. 559 | Returns: 560 | Output tensor of the last layer or a list of output tensors. 561 | """ 562 | unpacked_inputs = tf_utils.unpack_inputs(inputs) 563 | input_tensor = unpacked_inputs[0] 564 | attention_mask = unpacked_inputs[1] 565 | output_tensor = input_tensor 566 | 567 | all_layer_outputs = [] 568 | for layer in self.layers: 569 | output_tensor = layer(output_tensor, attention_mask) 570 | all_layer_outputs.append(output_tensor) 571 | 572 | if return_all_layers: 573 | return all_layer_outputs 574 | 575 | return all_layer_outputs[-1] 576 | 577 | 578 | def get_initializer(initializer_range=0.02): 579 | """Creates a `tf.initializers.truncated_normal` with the given range. 580 | 581 | Args: 582 | initializer_range: float, initializer range for stddev. 583 | 584 | Returns: 585 | TruncatedNormal initializer with stddev = `initializer_range`. 586 | """ 587 | return tf.keras.initializers.TruncatedNormal(stddev=initializer_range) 588 | 589 | 590 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 591 | """Create 3D attention mask from a 2D tensor mask. 592 | 593 | Args: 594 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 
595 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 596 | 597 | Returns: 598 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 599 | """ 600 | from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3]) 601 | batch_size = from_shape[0] 602 | from_seq_length = from_shape[1] 603 | 604 | to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2) 605 | to_seq_length = to_shape[1] 606 | 607 | to_mask = tf.cast( 608 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), 609 | dtype=from_tensor.dtype) 610 | 611 | # We don't assume that `from_tensor` is a mask (although it could be). We 612 | # don't actually care if we attend *from* padding tokens (only *to* padding) 613 | # tokens so we create a tensor of all ones. 614 | # 615 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 616 | broadcast_ones = tf.ones( 617 | shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype) 618 | 619 | # Here we broadcast along two dimensions to create the mask. 620 | mask = broadcast_ones * to_mask 621 | 622 | return mask 623 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/bert_optimization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | 23 | import tensorflow as tf 24 | 25 | 26 | class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): 27 | """Applys a warmup schedule on a given learning rate decay schedule.""" 28 | 29 | def __init__( 30 | self, 31 | initial_learning_rate, 32 | decay_schedule_fn, 33 | warmup_steps, 34 | power=1.0, 35 | name=None): 36 | super(WarmUp, self).__init__() 37 | self.initial_learning_rate = initial_learning_rate 38 | self.warmup_steps = warmup_steps 39 | self.power = power 40 | self.decay_schedule_fn = decay_schedule_fn 41 | self.name = name 42 | 43 | def __call__(self, step): 44 | with tf.name_scope(self.name or 'WarmUp') as name: 45 | # Implements polynomial warmup. i.e., if global_step < warmup_steps, the 46 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
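# Worked example (illustrative numbers): with initial_learning_rate=1e-4,
# warmup_steps=1000 and power=1.0, step 100 gives 1e-4 * (100/1000)**1.0 = 1e-5;
# once step >= warmup_steps, the wrapped decay_schedule_fn takes over via tf.cond.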
47 | global_step_float = tf.cast(step, tf.float32) 48 | warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) 49 | warmup_percent_done = global_step_float / warmup_steps_float 50 | warmup_learning_rate = ( 51 | self.initial_learning_rate * 52 | tf.math.pow(warmup_percent_done, self.power)) 53 | return tf.cond(global_step_float < warmup_steps_float, 54 | lambda: warmup_learning_rate, 55 | lambda: self.decay_schedule_fn(step), 56 | name=name) 57 | 58 | def get_config(self): 59 | return { 60 | 'initial_learning_rate': self.initial_learning_rate, 61 | 'decay_schedule_fn': self.decay_schedule_fn, 62 | 'warmup_steps': self.warmup_steps, 63 | 'power': self.power, 64 | 'name': self.name 65 | } 66 | 67 | 68 | def create_optimizer(init_lr, num_train_steps, num_warmup_steps): 69 | """Creates an optimizer with learning rate schedule.""" 70 | # Implements linear decay of the learning rate. 71 | learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( 72 | initial_learning_rate=init_lr, 73 | decay_steps=num_train_steps, 74 | end_learning_rate=0.0) 75 | if num_warmup_steps: 76 | learning_rate_fn = WarmUp(initial_learning_rate=init_lr, 77 | decay_schedule_fn=learning_rate_fn, 78 | warmup_steps=num_warmup_steps) 79 | optimizer = AdamWeightDecay( 80 | learning_rate=learning_rate_fn, 81 | weight_decay_rate=0.01, 82 | beta_1=0.9, 83 | beta_2=0.999, 84 | epsilon=1e-6, 85 | exclude_from_weight_decay=['layer_norm', 'bias']) 86 | return optimizer 87 | 88 | 89 | class AdamWeightDecay(tf.keras.optimizers.Adam): 90 | """Adam enables L2 weight decay and clip_by_global_norm on gradients. 91 | 92 | Just adding the square of the weights to the loss function is *not* the 93 | correct way of using L2 regularization/weight decay with Adam, since that will 94 | interact with the m and v parameters in strange ways. 95 | 96 | Instead we want ot decay the weights in a manner that doesn't interact with 97 | the m/v parameters. This is equivalent to adding the square of the weights to 98 | the loss with plain (non-momentum) SGD. 
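Concretely, _decay_weights_op() applies w <- w - learning_rate * weight_decay_rate * w alongside each Adam update, and variables whose names match exclude_from_weight_decay (e.g. 'layer_norm' and 'bias' in create_optimizer above) are skipped.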
99 | """ 100 | 101 | def __init__(self, 102 | learning_rate=0.001, 103 | beta_1=0.9, 104 | beta_2=0.999, 105 | epsilon=1e-7, 106 | amsgrad=False, 107 | weight_decay_rate=0.0, 108 | include_in_weight_decay=None, 109 | exclude_from_weight_decay=None, 110 | name='AdamWeightDecay', 111 | **kwargs): 112 | super(AdamWeightDecay, self).__init__( 113 | learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs) 114 | self.weight_decay_rate = weight_decay_rate 115 | self._include_in_weight_decay = include_in_weight_decay 116 | self._exclude_from_weight_decay = exclude_from_weight_decay 117 | 118 | @classmethod 119 | def from_config(cls, config): 120 | """Creates an optimizer from its config with WarmUp custom object.""" 121 | custom_objects = {'WarmUp': WarmUp} 122 | return super(AdamWeightDecay, cls).from_config( 123 | config, custom_objects=custom_objects) 124 | 125 | def _prepare_local(self, var_device, var_dtype, apply_state): 126 | super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, 127 | apply_state) 128 | apply_state['weight_decay_rate'] = tf.constant( 129 | self.weight_decay_rate, name='adam_weight_decay_rate') 130 | 131 | def _decay_weights_op(self, var, learning_rate, apply_state): 132 | do_decay = self._do_use_weight_decay(var.name) 133 | if do_decay: 134 | return var.assign_sub( 135 | learning_rate * var * 136 | apply_state['weight_decay_rate'], 137 | use_locking=self._use_locking) 138 | return tf.no_op() 139 | 140 | def apply_gradients(self, grads_and_vars, name=None): 141 | grads, tvars = list(zip(*grads_and_vars)) 142 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 143 | return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars)) 144 | 145 | def _get_lr(self, var_device, var_dtype, apply_state): 146 | """Retrieves the learning rate with the given state.""" 147 | if apply_state is None: 148 | return self._decayed_lr_t[var_dtype], {} 149 | 150 | apply_state = apply_state or {} 151 | coefficients = apply_state.get((var_device, var_dtype)) 152 | if coefficients is None: 153 | coefficients = self._fallback_apply_state(var_device, var_dtype) 154 | apply_state[(var_device, var_dtype)] = coefficients 155 | 156 | return coefficients['lr_t'], dict(apply_state=apply_state) 157 | 158 | def _resource_apply_dense(self, grad, var, apply_state=None): 159 | lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) 160 | decay = self._decay_weights_op(var, lr_t, apply_state) 161 | with tf.control_dependencies([decay]): 162 | return super(AdamWeightDecay, self)._resource_apply_dense( 163 | grad, var, **kwargs) 164 | 165 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 166 | lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) 167 | decay = self._decay_weights_op(var, lr_t, apply_state) 168 | with tf.control_dependencies([decay]): 169 | return super(AdamWeightDecay, self)._resource_apply_sparse( 170 | grad, var, indices, **kwargs) 171 | 172 | def get_config(self): 173 | config = super(AdamWeightDecay, self).get_config() 174 | config.update({ 175 | 'weight_decay_rate': self.weight_decay_rate, 176 | }) 177 | return config 178 | 179 | def _do_use_weight_decay(self, param_name): 180 | """Whether to use L2 weight decay for `param_name`.""" 181 | if self.weight_decay_rate == 0: 182 | return False 183 | 184 | if self._include_in_weight_decay: 185 | for r in self._include_in_weight_decay: 186 | if re.search(r, param_name) is not None: 187 | return True 188 | 189 | if self._exclude_from_weight_decay: 190 | 
for r in self._exclude_from_weight_decay: 191 | if re.search(r, param_name) is not None: 192 | return False 193 | return True 194 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/circle_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class CircleLoss(tf.keras.losses.Loss): 5 | 6 | def __init__(self, 7 | gamma: int = 64, 8 | margin: float = 0.25, 9 | reduction='auto', 10 | name=None): 11 | super().__init__(reduction=reduction, name=name) 12 | self.gamma = gamma 13 | self.margin = margin 14 | self.O_p = 1 + self.margin 15 | self.O_n = -self.margin 16 | self.Delta_p = 1 - self.margin 17 | self.Delta_n = self.margin 18 | 19 | def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor: 20 | """ NOTE : y_pred must be cos similarity/dot similarity/logit 21 | 22 | Args: 23 | y_true (tf.Tensor): shape [None,num_total_labels] one-hot matrix (0, 1) 24 | y_pred (tf.Tensor): shape [None,num_total_labels] float logits 25 | 26 | Returns: 27 | tf.Tensor: loss 28 | """ 29 | 30 | alpha_p = tf.nn.relu(self.O_p - tf.stop_gradient(y_pred)) 31 | alpha_n = tf.nn.relu(tf.stop_gradient(y_pred) - self.O_n) 32 | # yapf: disable 33 | y_true = tf.cast(y_true, tf.float32) 34 | 35 | # (None, num_total_labels) 36 | logit_p = - y_true * (alpha_p * (y_pred - self.Delta_p)) * self.gamma 37 | # minus 10000.0 to make contribution from mask is zero 38 | logit_p = logit_p - (1.0 - y_true) * 10000.0 39 | # (None, ) 40 | loss_p = tf.reduce_logsumexp(logit_p, axis=-1) 41 | 42 | # (None, num_total_labels) 43 | logit_n = (1.0 - y_true) * (alpha_n * (y_pred - self.Delta_n)) * self.gamma 44 | logit_n = logit_n - y_true * 10000.0 45 | # (None, ) 46 | loss_n = tf.reduce_logsumexp(logit_n, axis=-1) 47 | 48 | loss = tf.math.softplus(loss_p+loss_n) 49 | 50 | return loss 51 | 52 | 53 | if __name__ == "__main__": 54 | batch_size = 2 55 | nclass = 5 56 | y_pred = tf.random.uniform((batch_size, nclass), -1, 1, dtype=tf.float32) 57 | 58 | # y_true = tf.random.uniform((batch_size,), 0, nclass, dtype=tf.int32) 59 | # y_true = tf.one_hot(y_true, nclass, dtype=tf.float32) 60 | y_true = tf.constant([[1,0,0,1,1], [0,0,1,0,1]]) 61 | 62 | mycircleloss = CircleLoss() 63 | 64 | print( 65 | 'mycircleloss:\n', 66 | mycircleloss.call(y_true, y_pred).numpy(), 67 | mycircleloss(y_true, y_pred) 68 | ) -------------------------------------------------------------------------------- /SynthesisSimilarity/core/encoders.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import numpy as np 4 | 5 | from . 
import tf_utils 6 | from .utils import get_mat_mask_in_mat_seq 7 | from .layers import UnifyVector 8 | from .layers import Sampling 9 | from .layers import ZeroShift 10 | 11 | __author__ = "Tanjin He" 12 | __maintainer__ = "Tanjin He" 13 | __email__ = "tanjin_he@berkeley.edu" 14 | 15 | 16 | class MaterialEncoder(keras.Model): 17 | def __init__( 18 | self, 19 | mat_feature_len, 20 | dim_features, 21 | latent_dim, 22 | zero_shift_init_value, 23 | zero_shift_trainable, 24 | num_attention_layers, 25 | num_attention_heads, 26 | hidden_activation, 27 | hidden_dropout, 28 | attention_dropout, 29 | initializer_range, 30 | normalize_output=True, 31 | mask_zero=True, 32 | **kwargs 33 | ): 34 | super().__init__(**kwargs) 35 | self.mat_feature_len = mat_feature_len 36 | self.dim_features = dim_features 37 | self.latent_dim = latent_dim 38 | self.zero_shift_init_value = zero_shift_init_value 39 | self.zero_shift_trainable = zero_shift_trainable 40 | self.num_attention_layers = num_attention_layers 41 | self.num_attention_heads = num_attention_heads 42 | self.hidden_activation = hidden_activation 43 | self.hidden_dropout = hidden_dropout 44 | self.attention_dropout = attention_dropout 45 | self.initializer_range = initializer_range 46 | self.mask_zero = mask_zero 47 | self.normalize_output = normalize_output 48 | 49 | if self.hidden_activation in { 50 | "gelu", 51 | }: 52 | self.hidden_activation = tf_utils.get_activation(self.hidden_activation) 53 | 54 | self.zero_shift_layer = ZeroShift( 55 | shift_init_value=self.zero_shift_init_value, 56 | shift_trainable=self.zero_shift_trainable, 57 | ) 58 | self.dens_1 = keras.layers.Dense( 59 | self.dim_features, 60 | activation=self.hidden_activation, 61 | ) 62 | self.hidden_layers = [ 63 | keras.layers.Dense( 64 | self.latent_dim, 65 | activation=self.hidden_activation, 66 | ) 67 | for _ in range(self.num_attention_layers) 68 | ] 69 | 70 | self.uni_vec = UnifyVector() 71 | 72 | # # by default linear activation is used for Dense 73 | # self.dense_mean = keras.layers.Dense(self.latent_dim) 74 | # self.dense_log_var = keras.layers.Dense(self.latent_dim) 75 | # self.sampling = Sampling(initializer_range=self.initializer_range) 76 | 77 | def call(self, inputs, mask=None): 78 | """ 79 | 80 | :param inputs: (None, mat_feature_len) 81 | :param mask: (None, ) mask of materials. 82 | list of True or False 83 | True represents the materials is to be calculated 84 | False represents the return is zero by default 85 | :return z_mean: (None, dim_features) 86 | z_log_var: (None, dim_features) 87 | z: (None, dim_features) 88 | """ 89 | # inputs: (None, mat_feature_len) 90 | input_shape = tf.shape(inputs) 91 | if mask is None: 92 | mask = self.compute_mask(inputs) 93 | # next: (None, mat_feature_len) 94 | if mask is not None: 95 | processed_inputs = inputs[mask] 96 | else: 97 | processed_inputs = inputs 98 | # next: (None, mat_feature_len) 99 | _x = self.zero_shift_layer(processed_inputs) 100 | # next: (None, dim_features) 101 | _x = self.dens_1(_x) 102 | # more layers here 103 | # TODO: is dropout useful here? 104 | # https://github.com/tensorflow/tensor2tensor/blob/3aca2ab360271a4684ffa7ac8767995e264cc558/tensor2tensor/models/transformer.py#L95 105 | # next: (None, mat_feature_len+1, latent_dim) 106 | ... 
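# As written, this placeholder is never filled in: MaterialEncoder stacks only the
# dense layers below, so num_attention_layers just sets how many Dense(latent_dim)
# blocks follow dens_1, and the attention-specific arguments (num_attention_heads,
# attention_dropout) are stored but not used in this class. That is consistent with
# the "encoder_type": "simple_hidden" setting shipped in
# models/SynthesisEncoding/model_config.json.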
107 | # next: (None, latent_dim) 108 | for tmp_layer in self.hidden_layers: 109 | _x = tmp_layer(_x) 110 | 111 | emb = _x 112 | 113 | if self.normalize_output: 114 | emb = self.uni_vec(emb) 115 | 116 | if mask is not None: 117 | effective_indices = tf.range(input_shape[0])[mask] 118 | output_shape = tf.concat((input_shape[:1], tf.shape(emb)[1:]), axis=0) 119 | emb = tf.scatter_nd( 120 | indices=tf.expand_dims(effective_indices, 1), 121 | updates=emb, 122 | shape=output_shape, 123 | ) 124 | return emb, emb, emb 125 | 126 | def compute_mask(self, inputs, mask=None): 127 | """ 128 | masked position is False. 129 | True represents a valid material 130 | 131 | :param inputs: (None, mat_feature_len) 132 | :param mask: not used 133 | :return mat_mask: (None, ) 134 | """ 135 | if not self.mask_zero: 136 | return None 137 | return get_mat_mask_in_mat_seq(inputs) 138 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/exp_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import tensorflow as tf 5 | from tensorflow import keras 6 | import numpy as np 7 | import math 8 | 9 | from .model_framework import MultiTasksOnRecipes 10 | 11 | # after experiments, we shall move ReactionTemperature 12 | # to .task_models.py if it is useful 13 | # from .task_models import ReactionTemperature 14 | 15 | __author__ = "Tanjin He" 16 | __maintainer__ = "Tanjin He" 17 | __email__ = "tanjin_he@berkeley.edu" 18 | 19 | 20 | ################################################# 21 | # ReactionTemperature is an example of task model, please design your task 22 | # based on this template 23 | ################################################# 24 | class ReactionTemperature(keras.Model): 25 | """ 26 | This is an example of task model, please design your task 27 | based on this template 28 | """ 29 | 30 | def __init__(self, num_eles, mat_encoder, **kwargs): 31 | super().__init__(**kwargs) 32 | self.num_eles = num_eles 33 | self.mat_encoder = mat_encoder 34 | 35 | def build(self, input_shape): 36 | """ 37 | 38 | :param input_shape: (None, num_eles) 39 | :return: 40 | """ 41 | self.alpha = self.add_weight( 42 | shape=(), 43 | initializer=keras.initializers.Constant(1), 44 | trainable=True, 45 | name="sim_diff_coeff", 46 | ) 47 | 48 | def call(self, inputs): 49 | """ 50 | 51 | :param inputs: (None, 2, max_mats_num, num_eles) 52 | :return loss: (None, dim_features) 53 | """ 54 | # inputs: (None, 2*num_eles+2) 55 | target_1 = inputs[:, : self.num_eles] 56 | target_2 = inputs[:, self.num_eles : 2 * self.num_eles] 57 | T_1 = inputs[:, 2 * self.num_eles] 58 | T_2 = inputs[:, 2 * self.num_eles + 1] 59 | 60 | # tar_emb_1 (None, latent_dim) 61 | tar_emb_1, _, _ = self.mat_encoder(target_1) 62 | tar_emb_2, _, _ = self.mat_encoder(target_2) 63 | 64 | # sim (None, ) 65 | sim_tar = tf.reduce_sum(tar_emb_1 * tar_emb_2, axis=-1) 66 | sim_T = 1 - tf.abs(T_1 - T_2) / (T_1 + T_2) 67 | 68 | return tf.square(sim_tar - self.alpha * sim_T) 69 | 70 | 71 | class ExpMultiTasksOnRecipes(MultiTasksOnRecipes): 72 | def __init__(self, **kwargs): 73 | super().__init__(**kwargs) 74 | ################################################# 75 | # add new tasks here 76 | ################################################# 77 | task_name = "reaction_T" 78 | if task_name in self.task_to_add: 79 | reaction_T = ReactionTemperature( 80 | num_eles=self.num_eles, 81 | mat_encoder=self.mat_encoder, 82 | ) 83 | self.add_task( 84 | 
task_name=task_name, 85 | task_model=reaction_T, 86 | task_loss=self.get_loss_target_T, 87 | task_weight=1.0, 88 | ) 89 | print("reaction_T here") 90 | 91 | print("tasks", self.task_names, self._weight_by_task) 92 | 93 | ################################################# 94 | # This is an example, please design your task 95 | # based on this template 96 | ################################################# 97 | def get_loss_target_T(self, inputs, task_name): 98 | """ 99 | regularize by sim(Tar) - w (|(T_1-T_2|/(T_1+T_2)) 100 | :param inputs: same as call(), which is a dict with keys 101 | reaction_1, reaction_1, 102 | temperature_1, temperature_2, ... 103 | :return: (batch_size, ) for each batch data 104 | """ 105 | # examples 106 | # to get all materials 107 | # all_mats: (batch_size*max_mats_num*2, mat_feature_len) 108 | all_mats_featurized = keras.layers.concatenate( 109 | [ 110 | tf.reshape(inputs["reaction_1_featurized"], (-1, self.num_eles)), 111 | tf.reshape(inputs["reaction_2_featurized"], (-1, self.num_eles)), 112 | ], 113 | axis=0, 114 | ) 115 | # to get all reactions 116 | # all_reactions: (batch_size*2, max_mats_num, mat_feature_len) 117 | all_reactions_featurized = keras.layers.concatenate( 118 | [ 119 | inputs["reaction_1_featurized"], 120 | inputs["reaction_2_featurized"], 121 | ], 122 | axis=0, 123 | ) 124 | # to get all reactions pairs 125 | # all_pairs: (batch_size, 2, max_mats_num, mat_feature_len) 126 | all_reaction_pairs_featurizec = tf.stack( 127 | [ 128 | inputs["reaction_1_featurized"], 129 | inputs["reaction_2_featurized"], 130 | ], 131 | axis=1, 132 | ) 133 | # to get all target T pairs 134 | # all_pairs: (batch_size, 2*mat_feature_len+2) 135 | all_target_T_pairs_featurized = tf.concat( 136 | [ 137 | inputs["reaction_1_featurized"][:, 0, :], 138 | inputs["reaction_2_featurized"][:, 0, :], 139 | tf.reshape(inputs["temperature_1"], (-1, 1)), 140 | tf.reshape(inputs["temperature_2"], (-1, 1)), 141 | ], 142 | axis=1, 143 | ) 144 | # we can use any one of all_mats, all_reactions, 145 | # all_reaction_pairs or all_target_T_pairs as the input to 146 | # a task model, and return the loss 147 | loss_reaction_T = self._model_by_task[task_name](all_target_T_pairs_featurized) 148 | 149 | ... 150 | 151 | # need to reshape the loss to be consistent with the 152 | # original shape of inputs, which is (batch_size, ) 153 | ... 
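# For this example task the reshape placeholder above is effectively a no-op:
# ReactionTemperature.call already reduces over the feature axis, so loss_reaction_T
# comes back with shape (batch_size,). A task model that instead returned, say,
# (batch_size, 1) could be flattened with something like
#     loss_reaction_T = tf.reshape(loss_reaction_T, (-1,))
# (an illustrative line only, not required here).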
154 | 155 | return loss_reaction_T 156 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import numpy as np 4 | 5 | from .utils import repeat_in_last_dimension 6 | 7 | 8 | __author__ = 'Tanjin He' 9 | __maintainer__ = 'Tanjin He' 10 | __email__ = 'tanjin_he@berkeley.edu' 11 | 12 | 13 | class AddEMBInComposition(keras.layers.Layer): 14 | """ 15 | add a "EMB" unit in composition vector, just like BERT 16 | This "EMB" is used as the material embedding because 17 | it will be used to reconstruct the composition in encoding- 18 | decoding process 19 | 20 | """ 21 | def __init__(self, **kwargs): 22 | super().__init__(**kwargs) 23 | 24 | def call(self, inputs): 25 | """ 26 | 27 | :param inputs: (None, mat_feature_len) 28 | :return: (None, 1+mat_feature_len) 29 | """ 30 | # EMB = tf.ones(tf.concat(tf.shape(inputs)[:-1], [1]) ) 31 | EMB = tf.tile( 32 | tf.ones((1,1)), 33 | multiples=[tf.shape(inputs)[0], 1] 34 | ) 35 | return keras.layers.concatenate([EMB, inputs], axis=-1) 36 | 37 | 38 | class EmpiricalEmbedding(keras.layers.Layer): 39 | """ 40 | encode composition vector with empirical element embedding 41 | the first layer after composition input 42 | by default it is randomly initialized 43 | but it is potential to be initialized with elementary 44 | features 45 | """ 46 | def __init__(self, 47 | mat_feature_len, 48 | dim_features, 49 | emp_features=None, 50 | initializer_max=10, 51 | **kwargs): 52 | super().__init__(**kwargs) 53 | self.mat_feature_len = mat_feature_len 54 | self.dim_features = dim_features 55 | self.emp_features = emp_features 56 | self.initializer_max = initializer_max 57 | 58 | def build(self, input_shape): 59 | """ 60 | shape of emp_features: (mat_feature_len, dim_features,) 61 | 62 | :param input_shape: (None, mat_feature_len) 63 | :return: 64 | """ 65 | assert input_shape[-1] == self.mat_feature_len 66 | if self.emp_features == None: 67 | self.ele_features = self.add_weight( 68 | shape=(self.mat_feature_len, self.dim_features), 69 | initializer=keras.initializers.RandomUniform( 70 | minval=-self.initializer_max, 71 | maxval=self.initializer_max 72 | ), 73 | trainable=True, 74 | name='ele_features' 75 | ) 76 | else: 77 | # this is for future use 78 | # TODO: the features for "EMB" should be trainable, 79 | # though those for elements shouldn't 80 | assert self.emp_features.shape == ( 81 | self.mat_feature_len, self.dim_features 82 | ) 83 | self.ele_features = self.add_weight( 84 | shape=(self.mat_feature_len, self.dim_features), 85 | initializer=keras.initializers.Constant( 86 | self.emp_features 87 | ), 88 | trainable=False, 89 | name='ele_features' 90 | ) 91 | 92 | def call(self, inputs): 93 | """ 94 | 95 | :param inputs: (None, mat_feature_len) 96 | :return _x: (None, mat_feature_len, dim_features) 97 | """ 98 | # TODO: remove assert after testing 99 | assert len(inputs.shape) == 2 100 | # _x = tf.reshape( 101 | # tf.tile( 102 | # tf.reshape(inputs, (-1, 1)), 103 | # multiples=[1, self.dim_features] 104 | # ), 105 | # (-1, self.mat_feature_len, self.dim_features) 106 | # ) 107 | _x = repeat_in_last_dimension( 108 | from_tensor=inputs, 109 | from_seq_length=self.mat_feature_len, 110 | to_latent_dim=self.dim_features 111 | ) 112 | _x = _x* self.ele_features 113 | return _x 114 | 115 | 116 | class ZeroShift(keras.layers.Layer): 117 | """ 118 | 119 | """ 120 | 
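# Adds a learnable (or fixed) bias to entries that are exactly zero, so that
# "element absent" stays distinguishable from "element present in a small fraction"
# before the composition vector reaches the dense layers. For example, with the
# default shift_init_value of -0.5, an input row [0.0, 0.25, 0.75] becomes
# [-0.5, 0.25, 0.75]; nonzero entries pass through unchanged.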
def __init__( 121 | self, 122 | shift_init_value=-0.5, 123 | shift_trainable=False, 124 | **kwargs, 125 | ): 126 | super().__init__(**kwargs) 127 | self.shift_init_value = shift_init_value 128 | self.shift_trainable = shift_trainable 129 | 130 | def build(self, input_shape): 131 | """ 132 | shape of emp_features: (mat_feature_len, dim_features,) 133 | 134 | :param input_shape: (None, mat_feature_len) 135 | :return: 136 | """ 137 | self.shift_bias = self.add_weight( 138 | shape=(1, ), 139 | initializer=tf.keras.initializers.Constant(self.shift_init_value), 140 | trainable=self.shift_trainable, 141 | name='zero_shift_bias' 142 | ) 143 | 144 | def call(self, inputs): 145 | """ 146 | 147 | :param inputs: (None, mat_feature_len) 148 | :return _x: (None, mat_feature_len) 149 | """ 150 | _x = inputs + tf.cast( 151 | tf.equal(inputs, 0), tf.float32 152 | ) * self.shift_bias 153 | return _x 154 | 155 | 156 | class UnifyVector(keras.layers.Layer): 157 | """ 158 | make length vector to be 1 along the last dimension 159 | """ 160 | def __init__(self, **kwargs): 161 | super().__init__(**kwargs) 162 | 163 | def call(self, inputs): 164 | """ 165 | 166 | :param inputs: (None, ..., dim_features) 167 | :return _x: (None, ..., dim_features) 168 | """ 169 | # TODO: is 1e-6 ok? utils.NEAR_ZERO 170 | tensor_norm = tf.sqrt( 171 | tf.reduce_sum( 172 | tf.square(inputs), 173 | axis=-1, 174 | keepdims=True 175 | ) 176 | ) + 1e-12 177 | _x = inputs/tensor_norm 178 | return _x 179 | 180 | 181 | class PrecursorsPooling_1(keras.layers.Layer): 182 | """ 183 | pooling precursor list by reduce mean 184 | """ 185 | def __init__(self, **kwargs): 186 | super().__init__(**kwargs) 187 | 188 | def call(self, inputs): 189 | return tf.reduce_mean(input, axis=-2) 190 | 191 | 192 | class Sampling(keras.layers.Layer): 193 | """ 194 | Uses (z_mean, z_log_var) to sample z, 195 | the vector encoding a digit. 196 | """ 197 | def __init__(self, 198 | initializer_range, 199 | **kwargs): 200 | super().__init__(**kwargs) 201 | self.initializer_range = initializer_range 202 | 203 | def call(self, inputs): 204 | """ 205 | 206 | :param inputs: (None, latent_dim) 207 | :return: 208 | """ 209 | # TODO: should avoid masks because return is not always 210 | # 0 for input zero vector 211 | z_mean, z_log_var = inputs 212 | batch = tf.shape(z_mean)[0] 213 | dim = tf.shape(z_mean)[1] 214 | # TODO: which value of mean and std of random_normal to 215 | # pick for initialization? 
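# This is the standard VAE reparameterization trick: draw
# eps ~ N(0, initializer_range**2) and return z = z_mean + exp(0.5 * z_log_var) * eps,
# which keeps the sample differentiable with respect to z_mean and z_log_var while
# confining the randomness to eps.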
216 | epsilon = tf.keras.backend.random_normal( 217 | shape=(batch, dim), 218 | stddev=self.initializer_range, 219 | ) 220 | return z_mean + tf.exp(0.5 * z_log_var) * epsilon 221 | 222 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import numpy as np 4 | 5 | __author__ = 'Tanjin He' 6 | __maintainer__ = 'Tanjin He' 7 | __email__ = 'tanjin_he@berkeley.edu' 8 | 9 | 10 | # Custom loss layer 11 | class MultiLossLayer(keras.layers.Layer): 12 | def __init__(self, task_names, **kwargs): 13 | super().__init__(**kwargs) 14 | self.task_names = task_names 15 | self.num_task = len(self.task_names) 16 | 17 | def build(self, input_shape=None): 18 | # initialise log_vars 19 | self.log_vars = self.add_weight( 20 | name='log_vars', 21 | shape=(self.num_task, ), 22 | initializer=keras.initializers.Constant( 23 | np.zeros((self.num_task, ), dtype=np.float32) 24 | ), 25 | trainable=True 26 | ) 27 | 28 | def call(self, inputs): 29 | precision = tf.exp(-self.log_vars) 30 | multi_loss = tf.reduce_sum(inputs*precision, axis=-1) + \ 31 | tf.reduce_sum(self.log_vars, axis=-1) 32 | return multi_loss 33 | 34 | 35 | class CustomLoss(keras.losses.Loss): 36 | def __init__(self, **kwargs): 37 | super().__init__(**kwargs) 38 | 39 | def call(self, y_true, y_pred): 40 | """ 41 | :param y_true: 42 | :param y_pred: 43 | :return: 44 | """ 45 | loss = y_pred 46 | return loss 47 | 48 | 49 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/mat_featurization.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import numpy as np 3 | 4 | from SynthesisSimilarity.core.utils import array_to_composition 5 | from SynthesisSimilarity.core.utils import composition_to_array 6 | from SynthesisSimilarity.core.utils import get_material_valence_details 7 | from SynthesisSimilarity.core.utils import valence_to_array 8 | from SynthesisSimilarity.core.utils import array_to_formula 9 | from SynthesisSimilarity.core.utils import NEAR_ZERO 10 | 11 | 12 | __author__ = "Tanjin He" 13 | __maintainer__ = "Tanjin He" 14 | __email__ = "tanjin_he@berkeley.edu" 15 | 16 | 17 | def default_mat_featurizer(composition, **kwargs): 18 | ###################################################### 19 | # this is only an example, please add a similar one to this to 20 | # design custom features 21 | ###################################################### 22 | return composition.copy() 23 | 24 | 25 | def ion_frac_mat_featurizer( 26 | composition, 27 | ele_order, 28 | ion_order, 29 | feature_array=None, 30 | **kwargs, 31 | ): 32 | if feature_array is None: 33 | mat_composition = array_to_composition( 34 | comp_array=composition, 35 | elements=ele_order, 36 | ) 37 | oxi_details = get_material_valence_details(mat_composition) 38 | mat_features = valence_to_array( 39 | mat_ion=oxi_details, 40 | ion_order=ion_order, 41 | ) 42 | else: 43 | mat_features = feature_array.copy() 44 | return mat_features 45 | 46 | 47 | _featurizers = { 48 | "default": default_mat_featurizer, 49 | "ion_frac": ion_frac_mat_featurizer, 50 | } 51 | 52 | 53 | def mat_featurizer( 54 | composition, 55 | ele_order, 56 | featurizer_type="default", 57 | **kwargs, 58 | ): 59 | mat_features = _featurizers[featurizer_type]( 60 | composition=composition, 61 | ele_order=ele_order, 62 | 
**kwargs, 63 | ) 64 | 65 | return mat_features 66 | 67 | 68 | def featurize_list_of_composition( 69 | comps, 70 | ele_order, 71 | featurizer_type="default", 72 | **kwargs, 73 | ): 74 | feature_vectors = [ 75 | mat_featurizer( 76 | composition=comp, 77 | ele_order=ele_order, 78 | featurizer_type=featurizer_type, 79 | **kwargs, 80 | ) 81 | for comp in comps 82 | ] 83 | return feature_vectors 84 | 85 | 86 | def featurize_reactions( 87 | reactions, 88 | ele_order, 89 | featurizer_type="default", 90 | ion_order=None, 91 | ): 92 | # convert normalized composition array to custom features 93 | # Note: reactions is changed here because deepcopy is not used 94 | 95 | for r in reactions: 96 | r["target_comp_featurized"] = [] 97 | for i, comp in enumerate(r["target_comp"]): 98 | r["target_comp_featurized"].append( 99 | mat_featurizer( 100 | composition=comp, 101 | ele_order=ele_order, 102 | featurizer_type=featurizer_type, 103 | ion_order=ion_order, 104 | feature_array=r["target_valence"][i], 105 | ) 106 | ) 107 | r["precursors_comp_featurized"] = [] 108 | for i in range(len(r["precursors_comp"])): 109 | r["precursors_comp_featurized"].append([]) 110 | for j, comp in enumerate(r["precursors_comp"][i]): 111 | r["precursors_comp_featurized"][-1].append( 112 | mat_featurizer( 113 | composition=comp, 114 | ele_order=ele_order, 115 | featurizer_type=featurizer_type, 116 | ion_order=ion_order, 117 | feature_array=r["precursors_valence"][i][j], 118 | ) 119 | ) 120 | mat_feature_len = len(reactions[0]["target_comp_featurized"][0]) 121 | 122 | return reactions, mat_feature_len 123 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/model_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | from . import utils 9 | from . 
import model_framework 10 | 11 | __author__ = 'Tanjin He' 12 | __maintainer__ = 'Tanjin He' 13 | __email__ = 'tanjin_he@berkeley.edu' 14 | 15 | 16 | def load_framework_model(model_dir): 17 | cp_path = os.path.join(model_dir, 'saved_model/cp.ckpt') 18 | config_path = os.path.join(model_dir, 'model_meta.pkl') 19 | 20 | with open(config_path, 'rb') as fr: 21 | model_meta = pickle.load(fr) 22 | model_config = model_meta['config'] 23 | model_config['model_path'] = model_dir 24 | 25 | # init model 26 | model = model_framework.MultiTasksOnRecipes(**model_config) 27 | batch_size = model_config['batch_size'] 28 | # run on one sample to build model 29 | zero_composition = np.zeros( 30 | shape=(len(model_config['all_eles']),), 31 | dtype=np.float32, 32 | ) 33 | zero_feature = np.zeros( 34 | shape=(model_config['mat_feature_len'],), 35 | dtype=np.float32, 36 | ) 37 | data_dicts =[ 38 | { 39 | 'reaction_1': [zero_composition]*model_config['max_mats_num'], 40 | 'reaction_2': [zero_composition]*model_config['max_mats_num'], 41 | 'reaction_1_featurized': [zero_feature] * model_config['max_mats_num'], 42 | 'reaction_2_featurized': [zero_feature] * model_config['max_mats_num'], 43 | 'precursors_1_conditional': [zero_composition] * (model_config['max_mats_num']-1), 44 | 'precursors_2_conditional': [zero_composition] * (model_config['max_mats_num']-1), 45 | 'temperature_1': 0.0, 46 | 'temperature_2': 0.0, 47 | 'synthesis_type_1': 'None', 48 | 'synthesis_type_2': 'None', 49 | } 50 | ]*batch_size 51 | data_type, data_shape, padded_data_shape = utils.get_input_format( 52 | model_type='MultiTasksOnRecipes', 53 | max_mats_num=model_config['max_mats_num'], 54 | ) 55 | data_X, data_Y = utils.dict_to_tf_dataset( 56 | data_dicts, 57 | data_type, 58 | data_shape, 59 | padded_shape=padded_data_shape, 60 | column_y=None, 61 | batch_size=batch_size, 62 | ) 63 | model.fit( 64 | x=tf.data.Dataset.zip((data_X, data_Y)), 65 | epochs=1, 66 | ) 67 | model.load_weights(cp_path) 68 | return model, model_config 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/tf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Common TF utilities.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import six 22 | import tensorflow as tf 23 | import regex 24 | 25 | from . import activations 26 | 27 | 28 | def pack_inputs(inputs): 29 | """Pack a list of `inputs` tensors to a tuple. 30 | 31 | Args: 32 | inputs: a list of tensors. 33 | 34 | Returns: 35 | a tuple of tensors. if any input is None, replace it with a special constant 36 | tensor. 
37 | """ 38 | inputs = tf.nest.flatten(inputs) 39 | outputs = [] 40 | for x in inputs: 41 | if x is None: 42 | outputs.append(tf.constant(0, shape=[], dtype=tf.int32)) 43 | else: 44 | outputs.append(x) 45 | return tuple(outputs) 46 | 47 | 48 | def unpack_inputs(inputs): 49 | """unpack a tuple of `inputs` tensors to a tuple. 50 | 51 | Args: 52 | inputs: a list of tensors. 53 | 54 | Returns: 55 | a tuple of tensors. if any input is a special constant tensor, replace it 56 | with None. 57 | """ 58 | inputs = tf.nest.flatten(inputs) 59 | outputs = [] 60 | for x in inputs: 61 | if is_special_none_tensor(x): 62 | outputs.append(None) 63 | else: 64 | outputs.append(x) 65 | x = tuple(outputs) 66 | 67 | # To trick the very pointless 'unbalanced-tuple-unpacking' pylint check 68 | # from triggering. 69 | if len(x) == 1: 70 | return x[0] 71 | return tuple(outputs) 72 | 73 | 74 | def is_special_none_tensor(tensor): 75 | """Checks if a tensor is a special None Tensor.""" 76 | return tensor.shape.ndims == 0 and tensor.dtype == tf.int32 77 | 78 | 79 | # TODO(hongkuny): consider moving custom string-map lookup to keras api. 80 | def get_activation(identifier): 81 | """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`. 82 | 83 | It checks string first and if it is one of customized activation not in TF, 84 | the corresponding activation will be returned. For non-customized activation 85 | names and callable identifiers, always fallback to tf.keras.activations.get. 86 | 87 | Args: 88 | identifier: String name of the activation function or callable. 89 | 90 | Returns: 91 | A Python function corresponding to the activation function. 92 | """ 93 | if isinstance(identifier, six.string_types): 94 | name_to_fn = { 95 | "gelu": activations.gelu, 96 | "custom_swish": activations.swish, 97 | } 98 | identifier = str(identifier).lower() 99 | if identifier in name_to_fn: 100 | return tf.keras.activations.get(name_to_fn[identifier]) 101 | return tf.keras.activations.get(identifier) 102 | 103 | 104 | def get_shape_list(tensor, expected_rank=None, name=None): 105 | """Returns a list of the shape of tensor, preferring static dimensions. 106 | 107 | Args: 108 | tensor: A tf.Tensor object to find the shape of. 109 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 110 | specified and the `tensor` has a different rank, and exception will be 111 | thrown. 112 | name: Optional name of the tensor for the error message. 113 | 114 | Returns: 115 | A list of dimensions of the shape of tensor. All static dimensions will 116 | be returned as python integers, and dynamic dimensions will be returned 117 | as tf.Tensor scalars. 118 | """ 119 | if expected_rank is not None: 120 | assert_rank(tensor, expected_rank, name) 121 | 122 | shape = tensor.shape.as_list() 123 | 124 | non_static_indexes = [] 125 | for (index, dim) in enumerate(shape): 126 | if dim is None: 127 | non_static_indexes.append(index) 128 | 129 | if not non_static_indexes: 130 | return shape 131 | 132 | dyn_shape = tf.shape(tensor) 133 | for index in non_static_indexes: 134 | shape[index] = dyn_shape[index] 135 | return shape 136 | 137 | 138 | def assert_rank(tensor, expected_rank, name=None): 139 | """Raises an exception if the tensor rank is not of the expected rank. 140 | 141 | Args: 142 | tensor: A tf.Tensor to check the rank of. 143 | expected_rank: Python integer or list of integers, expected rank. 144 | name: Optional name of the tensor for the error message. 
145 | 146 | Raises: 147 | ValueError: If the expected shape doesn't match the actual shape. 148 | """ 149 | expected_rank_dict = {} 150 | if isinstance(expected_rank, six.integer_types): 151 | expected_rank_dict[expected_rank] = True 152 | else: 153 | for x in expected_rank: 154 | expected_rank_dict[x] = True 155 | 156 | actual_rank = tensor.shape.ndims 157 | if actual_rank not in expected_rank_dict: 158 | raise ValueError( 159 | "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not " 160 | "equal to the expected tensor rank `%s`" % 161 | (name, actual_rank, str(tensor.shape), str(expected_rank))) 162 | 163 | 164 | def get_variables_by_name(layer, name): 165 | variables = layer.variables 166 | variables = list(filter(lambda x: regex.search(name, x.name), variables)) 167 | return variables 168 | 169 | -------------------------------------------------------------------------------- /SynthesisSimilarity/core/vector_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | 4 | __author__ = 'Tanjin He' 5 | __maintainer__ = 'Tanjin He' 6 | __email__ = 'tanjin_he@berkeley.edu' 7 | 8 | 9 | def most_similar_by_vector( 10 | target_vec: np.ndarray, 11 | target_candidate_formulas: List[str], 12 | target_candidate_normal_vecs: np.ndarray, 13 | positive_vecs: List[np.ndarray], 14 | negative_vecs: List[np.ndarray], 15 | top_n=10, 16 | ): 17 | sum_vec = np.sum( 18 | [target_vec] 19 | + positive_vecs 20 | +[-v for v in negative_vecs], 21 | axis=0, 22 | ) 23 | sum_vec = sum_vec/np.linalg.norm(sum_vec) 24 | all_similarity = sum_vec @ target_candidate_normal_vecs.T 25 | 26 | mat_idx_sort = np.argsort(all_similarity)[::-1] 27 | 28 | most_similar_mats = [ 29 | target_candidate_formulas[idx] 30 | for idx in mat_idx_sort[:top_n] 31 | ] 32 | most_similar_scores = list(all_similarity[mat_idx_sort[:top_n]]) 33 | 34 | return list(zip(most_similar_mats, most_similar_scores)) 35 | -------------------------------------------------------------------------------- /SynthesisSimilarity/examples/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Tanjin He' 2 | __maintainer__ = 'Tanjin He' 3 | __email__ = 'tanjin_he@berkeley.edu' 4 | 5 | -------------------------------------------------------------------------------- /SynthesisSimilarity/examples/synthesis_recommendation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Recommend precursors for given target materials using PrecursorSelector. 
3 | """ 4 | 5 | from pprint import pprint 6 | 7 | from SynthesisSimilarity import PrecursorsRecommendation 8 | 9 | 10 | __author__ = "Tanjin He" 11 | __maintainer__ = "Tanjin He" 12 | __email__ = "tanjin_he@berkeley.edu" 13 | 14 | 15 | def run_recommendations(): 16 | precursors_recommendator = PrecursorsRecommendation() 17 | 18 | test_targets_formulas = [ 19 | "LiFePO4", 20 | "LiNi0.333Mn0.333Co0.333O2", 21 | ] 22 | 23 | print("len(test_targets_formulas)", len(test_targets_formulas)) 24 | print("test_targets_formulas", test_targets_formulas) 25 | 26 | all_predicts = precursors_recommendator.recommend_precursors( 27 | target_formula=test_targets_formulas, 28 | top_n=10, 29 | validate_reaction=True, 30 | ) 31 | 32 | for i in range(len(test_targets_formulas)): 33 | pprint(all_predicts[i]) 34 | print() 35 | 36 | 37 | if __name__ == "__main__": 38 | run_recommendations() 39 | -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisEncoding/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_mats_num": 6, 3 | "all_eles": [ 4 | "Cs", 5 | "K", 6 | "Rb", 7 | "Ba", 8 | "Na", 9 | "Sr", 10 | "Li", 11 | "Ca", 12 | "La", 13 | "Tb", 14 | "Yb", 15 | "Ce", 16 | "Pr", 17 | "Nd", 18 | "Sm", 19 | "Eu", 20 | "Gd", 21 | "Dy", 22 | "Y", 23 | "Ho", 24 | "Er", 25 | "Tm", 26 | "Lu", 27 | "Pu", 28 | "Am", 29 | "Cm", 30 | "Hf", 31 | "Th", 32 | "Mg", 33 | "Zr", 34 | "Np", 35 | "Sc", 36 | "U", 37 | "Ta", 38 | "Ti", 39 | "Mn", 40 | "Be", 41 | "Nb", 42 | "Al", 43 | "Tl", 44 | "V", 45 | "Zn", 46 | "Cr", 47 | "Cd", 48 | "In", 49 | "Ga", 50 | "Fe", 51 | "Co", 52 | "Cu", 53 | "Re", 54 | "Si", 55 | "Tc", 56 | "Ni", 57 | "Ag", 58 | "Sn", 59 | "Hg", 60 | "Ge", 61 | "Bi", 62 | "B", 63 | "Sb", 64 | "Te", 65 | "Mo", 66 | "As", 67 | "P", 68 | "H", 69 | "Ir", 70 | "Os", 71 | "Pd", 72 | "Ru", 73 | "Pt", 74 | "Rh", 75 | "Pb", 76 | "W", 77 | "Au", 78 | "C", 79 | "Se", 80 | "S", 81 | "I", 82 | "Br", 83 | "N", 84 | "Cl", 85 | "O", 86 | "F" 87 | ], 88 | "all_ions": [ 89 | [ 90 | "Ag", 91 | 0 92 | ], 93 | [ 94 | "Ag", 95 | 1 96 | ], 97 | [ 98 | "Ag", 99 | 2 100 | ], 101 | [ 102 | "Ag", 103 | 3 104 | ], 105 | [ 106 | "Al", 107 | 0 108 | ], 109 | [ 110 | "Al", 111 | 1 112 | ], 113 | [ 114 | "Al", 115 | 3 116 | ], 117 | [ 118 | "Am", 119 | 3 120 | ], 121 | [ 122 | "Am", 123 | 4 124 | ], 125 | [ 126 | "As", 127 | -3 128 | ], 129 | [ 130 | "As", 131 | -2 132 | ], 133 | [ 134 | "As", 135 | -1 136 | ], 137 | [ 138 | "As", 139 | 0 140 | ], 141 | [ 142 | "As", 143 | 2 144 | ], 145 | [ 146 | "As", 147 | 3 148 | ], 149 | [ 150 | "As", 151 | 5 152 | ], 153 | [ 154 | "Au", 155 | -1 156 | ], 157 | [ 158 | "Au", 159 | 0 160 | ], 161 | [ 162 | "Au", 163 | 1 164 | ], 165 | [ 166 | "Au", 167 | 2 168 | ], 169 | [ 170 | "Au", 171 | 3 172 | ], 173 | [ 174 | "Au", 175 | 5 176 | ], 177 | [ 178 | "B", 179 | -3 180 | ], 181 | [ 182 | "B", 183 | 0 184 | ], 185 | [ 186 | "B", 187 | 3 188 | ], 189 | [ 190 | "Ba", 191 | 0 192 | ], 193 | [ 194 | "Ba", 195 | 2 196 | ], 197 | [ 198 | "Be", 199 | 0 200 | ], 201 | [ 202 | "Be", 203 | 2 204 | ], 205 | [ 206 | "Bi", 207 | 0 208 | ], 209 | [ 210 | "Bi", 211 | 1 212 | ], 213 | [ 214 | "Bi", 215 | 2 216 | ], 217 | [ 218 | "Bi", 219 | 3 220 | ], 221 | [ 222 | "Bi", 223 | 5 224 | ], 225 | [ 226 | "Br", 227 | -1 228 | ], 229 | [ 230 | "Br", 231 | 0 232 | ], 233 | [ 234 | "C", 235 | -4 236 | ], 237 | [ 238 | "C", 239 | -3 240 | ], 241 | [ 242 | "C", 243 | -2 244 | ], 245 | [ 246 | "C", 247 | 0 248 | ], 249 | 
[ 250 | "C", 251 | 2 252 | ], 253 | [ 254 | "C", 255 | 3 256 | ], 257 | [ 258 | "C", 259 | 4 260 | ], 261 | [ 262 | "Ca", 263 | 0 264 | ], 265 | [ 266 | "Ca", 267 | 2 268 | ], 269 | [ 270 | "Cd", 271 | 0 272 | ], 273 | [ 274 | "Cd", 275 | 1 276 | ], 277 | [ 278 | "Cd", 279 | 2 280 | ], 281 | [ 282 | "Ce", 283 | 0 284 | ], 285 | [ 286 | "Ce", 287 | 2 288 | ], 289 | [ 290 | "Ce", 291 | 3 292 | ], 293 | [ 294 | "Ce", 295 | 4 296 | ], 297 | [ 298 | "Cl", 299 | -1 300 | ], 301 | [ 302 | "Cl", 303 | 0 304 | ], 305 | [ 306 | "Cm", 307 | 4 308 | ], 309 | [ 310 | "Co", 311 | 0 312 | ], 313 | [ 314 | "Co", 315 | 1 316 | ], 317 | [ 318 | "Co", 319 | 2 320 | ], 321 | [ 322 | "Co", 323 | 3 324 | ], 325 | [ 326 | "Co", 327 | 4 328 | ], 329 | [ 330 | "Cr", 331 | 0 332 | ], 333 | [ 334 | "Cr", 335 | 1 336 | ], 337 | [ 338 | "Cr", 339 | 2 340 | ], 341 | [ 342 | "Cr", 343 | 3 344 | ], 345 | [ 346 | "Cr", 347 | 4 348 | ], 349 | [ 350 | "Cr", 351 | 5 352 | ], 353 | [ 354 | "Cr", 355 | 6 356 | ], 357 | [ 358 | "Cs", 359 | 0 360 | ], 361 | [ 362 | "Cs", 363 | 1 364 | ], 365 | [ 366 | "Cu", 367 | 0 368 | ], 369 | [ 370 | "Cu", 371 | 1 372 | ], 373 | [ 374 | "Cu", 375 | 2 376 | ], 377 | [ 378 | "Cu", 379 | 3 380 | ], 381 | [ 382 | "Cu", 383 | 4 384 | ], 385 | [ 386 | "Dy", 387 | 0 388 | ], 389 | [ 390 | "Dy", 391 | 2 392 | ], 393 | [ 394 | "Dy", 395 | 3 396 | ], 397 | [ 398 | "Er", 399 | 0 400 | ], 401 | [ 402 | "Er", 403 | 3 404 | ], 405 | [ 406 | "Eu", 407 | 0 408 | ], 409 | [ 410 | "Eu", 411 | 2 412 | ], 413 | [ 414 | "Eu", 415 | 3 416 | ], 417 | [ 418 | "F", 419 | -1 420 | ], 421 | [ 422 | "Fe", 423 | 0 424 | ], 425 | [ 426 | "Fe", 427 | 1 428 | ], 429 | [ 430 | "Fe", 431 | 2 432 | ], 433 | [ 434 | "Fe", 435 | 3 436 | ], 437 | [ 438 | "Fe", 439 | 4 440 | ], 441 | [ 442 | "Fe", 443 | 5 444 | ], 445 | [ 446 | "Fe", 447 | 6 448 | ], 449 | [ 450 | "Ga", 451 | 0 452 | ], 453 | [ 454 | "Ga", 455 | 1 456 | ], 457 | [ 458 | "Ga", 459 | 2 460 | ], 461 | [ 462 | "Ga", 463 | 3 464 | ], 465 | [ 466 | "Gd", 467 | 0 468 | ], 469 | [ 470 | "Gd", 471 | 1 472 | ], 473 | [ 474 | "Gd", 475 | 2 476 | ], 477 | [ 478 | "Gd", 479 | 3 480 | ], 481 | [ 482 | "Ge", 483 | 0 484 | ], 485 | [ 486 | "Ge", 487 | 2 488 | ], 489 | [ 490 | "Ge", 491 | 3 492 | ], 493 | [ 494 | "Ge", 495 | 4 496 | ], 497 | [ 498 | "H", 499 | -1 500 | ], 501 | [ 502 | "H", 503 | 1 504 | ], 505 | [ 506 | "Hf", 507 | 0 508 | ], 509 | [ 510 | "Hf", 511 | 2 512 | ], 513 | [ 514 | "Hf", 515 | 3 516 | ], 517 | [ 518 | "Hf", 519 | 4 520 | ], 521 | [ 522 | "Hg", 523 | 0 524 | ], 525 | [ 526 | "Hg", 527 | 1 528 | ], 529 | [ 530 | "Hg", 531 | 2 532 | ], 533 | [ 534 | "Hg", 535 | 4 536 | ], 537 | [ 538 | "Ho", 539 | 0 540 | ], 541 | [ 542 | "Ho", 543 | 3 544 | ], 545 | [ 546 | "I", 547 | -1 548 | ], 549 | [ 550 | "I", 551 | 0 552 | ], 553 | [ 554 | "In", 555 | 0 556 | ], 557 | [ 558 | "In", 559 | 1 560 | ], 561 | [ 562 | "In", 563 | 2 564 | ], 565 | [ 566 | "In", 567 | 3 568 | ], 569 | [ 570 | "Ir", 571 | 0 572 | ], 573 | [ 574 | "Ir", 575 | 1 576 | ], 577 | [ 578 | "Ir", 579 | 2 580 | ], 581 | [ 582 | "Ir", 583 | 3 584 | ], 585 | [ 586 | "Ir", 587 | 4 588 | ], 589 | [ 590 | "Ir", 591 | 5 592 | ], 593 | [ 594 | "Ir", 595 | 6 596 | ], 597 | [ 598 | "K", 599 | 0 600 | ], 601 | [ 602 | "K", 603 | 1 604 | ], 605 | [ 606 | "La", 607 | 0 608 | ], 609 | [ 610 | "La", 611 | 2 612 | ], 613 | [ 614 | "La", 615 | 3 616 | ], 617 | [ 618 | "Li", 619 | 0 620 | ], 621 | [ 622 | "Li", 623 | 1 624 | ], 625 | [ 626 | "Lu", 627 | 0 628 | ], 629 | [ 630 | "Lu", 631 | 3 632 | ], 633 | [ 634 | 
"Mg", 635 | 0 636 | ], 637 | [ 638 | "Mg", 639 | 1 640 | ], 641 | [ 642 | "Mg", 643 | 2 644 | ], 645 | [ 646 | "Mn", 647 | 0 648 | ], 649 | [ 650 | "Mn", 651 | 1 652 | ], 653 | [ 654 | "Mn", 655 | 2 656 | ], 657 | [ 658 | "Mn", 659 | 3 660 | ], 661 | [ 662 | "Mn", 663 | 4 664 | ], 665 | [ 666 | "Mn", 667 | 5 668 | ], 669 | [ 670 | "Mn", 671 | 6 672 | ], 673 | [ 674 | "Mn", 675 | 7 676 | ], 677 | [ 678 | "Mo", 679 | 0 680 | ], 681 | [ 682 | "Mo", 683 | 1 684 | ], 685 | [ 686 | "Mo", 687 | 2 688 | ], 689 | [ 690 | "Mo", 691 | 3 692 | ], 693 | [ 694 | "Mo", 695 | 4 696 | ], 697 | [ 698 | "Mo", 699 | 5 700 | ], 701 | [ 702 | "Mo", 703 | 6 704 | ], 705 | [ 706 | "N", 707 | -3 708 | ], 709 | [ 710 | "N", 711 | -2 712 | ], 713 | [ 714 | "N", 715 | -1 716 | ], 717 | [ 718 | "N", 719 | 1 720 | ], 721 | [ 722 | "N", 723 | 3 724 | ], 725 | [ 726 | "N", 727 | 5 728 | ], 729 | [ 730 | "Na", 731 | 0 732 | ], 733 | [ 734 | "Na", 735 | 1 736 | ], 737 | [ 738 | "Nb", 739 | 0 740 | ], 741 | [ 742 | "Nb", 743 | 2 744 | ], 745 | [ 746 | "Nb", 747 | 3 748 | ], 749 | [ 750 | "Nb", 751 | 4 752 | ], 753 | [ 754 | "Nb", 755 | 5 756 | ], 757 | [ 758 | "Nd", 759 | 0 760 | ], 761 | [ 762 | "Nd", 763 | 2 764 | ], 765 | [ 766 | "Nd", 767 | 3 768 | ], 769 | [ 770 | "Ni", 771 | 0 772 | ], 773 | [ 774 | "Ni", 775 | 1 776 | ], 777 | [ 778 | "Ni", 779 | 2 780 | ], 781 | [ 782 | "Ni", 783 | 3 784 | ], 785 | [ 786 | "Ni", 787 | 4 788 | ], 789 | [ 790 | "Np", 791 | 0 792 | ], 793 | [ 794 | "Np", 795 | 3 796 | ], 797 | [ 798 | "Np", 799 | 4 800 | ], 801 | [ 802 | "Np", 803 | 6 804 | ], 805 | [ 806 | "Np", 807 | 7 808 | ], 809 | [ 810 | "O", 811 | -2 812 | ], 813 | [ 814 | "O", 815 | -1 816 | ], 817 | [ 818 | "Os", 819 | -2 820 | ], 821 | [ 822 | "Os", 823 | -1 824 | ], 825 | [ 826 | "Os", 827 | 0 828 | ], 829 | [ 830 | "Os", 831 | 1 832 | ], 833 | [ 834 | "Os", 835 | 2 836 | ], 837 | [ 838 | "Os", 839 | 3 840 | ], 841 | [ 842 | "Os", 843 | 4 844 | ], 845 | [ 846 | "Os", 847 | 5 848 | ], 849 | [ 850 | "Os", 851 | 6 852 | ], 853 | [ 854 | "Os", 855 | 7 856 | ], 857 | [ 858 | "Os", 859 | 8 860 | ], 861 | [ 862 | "P", 863 | -3 864 | ], 865 | [ 866 | "P", 867 | -2 868 | ], 869 | [ 870 | "P", 871 | -1 872 | ], 873 | [ 874 | "P", 875 | 0 876 | ], 877 | [ 878 | "P", 879 | 3 880 | ], 881 | [ 882 | "P", 883 | 4 884 | ], 885 | [ 886 | "P", 887 | 5 888 | ], 889 | [ 890 | "Pb", 891 | 0 892 | ], 893 | [ 894 | "Pb", 895 | 2 896 | ], 897 | [ 898 | "Pb", 899 | 4 900 | ], 901 | [ 902 | "Pd", 903 | 0 904 | ], 905 | [ 906 | "Pd", 907 | 2 908 | ], 909 | [ 910 | "Pd", 911 | 4 912 | ], 913 | [ 914 | "Pr", 915 | 0 916 | ], 917 | [ 918 | "Pr", 919 | 2 920 | ], 921 | [ 922 | "Pr", 923 | 3 924 | ], 925 | [ 926 | "Pr", 927 | 4 928 | ], 929 | [ 930 | "Pt", 931 | -2 932 | ], 933 | [ 934 | "Pt", 935 | 0 936 | ], 937 | [ 938 | "Pt", 939 | 2 940 | ], 941 | [ 942 | "Pt", 943 | 4 944 | ], 945 | [ 946 | "Pt", 947 | 5 948 | ], 949 | [ 950 | "Pt", 951 | 6 952 | ], 953 | [ 954 | "Pu", 955 | 0 956 | ], 957 | [ 958 | "Pu", 959 | 3 960 | ], 961 | [ 962 | "Pu", 963 | 4 964 | ], 965 | [ 966 | "Pu", 967 | 6 968 | ], 969 | [ 970 | "Pu", 971 | 7 972 | ], 973 | [ 974 | "Rb", 975 | 0 976 | ], 977 | [ 978 | "Rb", 979 | 1 980 | ], 981 | [ 982 | "Re", 983 | 0 984 | ], 985 | [ 986 | "Re", 987 | 1 988 | ], 989 | [ 990 | "Re", 991 | 2 992 | ], 993 | [ 994 | "Re", 995 | 3 996 | ], 997 | [ 998 | "Re", 999 | 4 1000 | ], 1001 | [ 1002 | "Re", 1003 | 5 1004 | ], 1005 | [ 1006 | "Re", 1007 | 6 1008 | ], 1009 | [ 1010 | "Re", 1011 | 7 1012 | ], 1013 | [ 1014 | "Rh", 1015 | 0 1016 | ], 
1017 | [ 1018 | "Rh", 1019 | 1 1020 | ], 1021 | [ 1022 | "Rh", 1023 | 2 1024 | ], 1025 | [ 1026 | "Rh", 1027 | 3 1028 | ], 1029 | [ 1030 | "Rh", 1031 | 4 1032 | ], 1033 | [ 1034 | "Rh", 1035 | 6 1036 | ], 1037 | [ 1038 | "Ru", 1039 | 0 1040 | ], 1041 | [ 1042 | "Ru", 1043 | 1 1044 | ], 1045 | [ 1046 | "Ru", 1047 | 2 1048 | ], 1049 | [ 1050 | "Ru", 1051 | 3 1052 | ], 1053 | [ 1054 | "Ru", 1055 | 4 1056 | ], 1057 | [ 1058 | "Ru", 1059 | 5 1060 | ], 1061 | [ 1062 | "Ru", 1063 | 6 1064 | ], 1065 | [ 1066 | "Ru", 1067 | 8 1068 | ], 1069 | [ 1070 | "S", 1071 | -2 1072 | ], 1073 | [ 1074 | "S", 1075 | -1 1076 | ], 1077 | [ 1078 | "S", 1079 | 0 1080 | ], 1081 | [ 1082 | "S", 1083 | 2 1084 | ], 1085 | [ 1086 | "S", 1087 | 4 1088 | ], 1089 | [ 1090 | "S", 1091 | 6 1092 | ], 1093 | [ 1094 | "Sb", 1095 | -3 1096 | ], 1097 | [ 1098 | "Sb", 1099 | -2 1100 | ], 1101 | [ 1102 | "Sb", 1103 | -1 1104 | ], 1105 | [ 1106 | "Sb", 1107 | 0 1108 | ], 1109 | [ 1110 | "Sb", 1111 | 3 1112 | ], 1113 | [ 1114 | "Sb", 1115 | 5 1116 | ], 1117 | [ 1118 | "Sc", 1119 | 0 1120 | ], 1121 | [ 1122 | "Sc", 1123 | 1 1124 | ], 1125 | [ 1126 | "Sc", 1127 | 2 1128 | ], 1129 | [ 1130 | "Sc", 1131 | 3 1132 | ], 1133 | [ 1134 | "Se", 1135 | -2 1136 | ], 1137 | [ 1138 | "Se", 1139 | -1 1140 | ], 1141 | [ 1142 | "Se", 1143 | 0 1144 | ], 1145 | [ 1146 | "Se", 1147 | 4 1148 | ], 1149 | [ 1150 | "Se", 1151 | 6 1152 | ], 1153 | [ 1154 | "Si", 1155 | -4 1156 | ], 1157 | [ 1158 | "Si", 1159 | 0 1160 | ], 1161 | [ 1162 | "Si", 1163 | 4 1164 | ], 1165 | [ 1166 | "Sm", 1167 | 0 1168 | ], 1169 | [ 1170 | "Sm", 1171 | 2 1172 | ], 1173 | [ 1174 | "Sm", 1175 | 3 1176 | ], 1177 | [ 1178 | "Sn", 1179 | 0 1180 | ], 1181 | [ 1182 | "Sn", 1183 | 2 1184 | ], 1185 | [ 1186 | "Sn", 1187 | 3 1188 | ], 1189 | [ 1190 | "Sn", 1191 | 4 1192 | ], 1193 | [ 1194 | "Sr", 1195 | 0 1196 | ], 1197 | [ 1198 | "Sr", 1199 | 2 1200 | ], 1201 | [ 1202 | "Ta", 1203 | 0 1204 | ], 1205 | [ 1206 | "Ta", 1207 | 2 1208 | ], 1209 | [ 1210 | "Ta", 1211 | 3 1212 | ], 1213 | [ 1214 | "Ta", 1215 | 4 1216 | ], 1217 | [ 1218 | "Ta", 1219 | 5 1220 | ], 1221 | [ 1222 | "Tb", 1223 | 0 1224 | ], 1225 | [ 1226 | "Tb", 1227 | 1 1228 | ], 1229 | [ 1230 | "Tb", 1231 | 3 1232 | ], 1233 | [ 1234 | "Tb", 1235 | 4 1236 | ], 1237 | [ 1238 | "Tc", 1239 | 1 1240 | ], 1241 | [ 1242 | "Tc", 1243 | 2 1244 | ], 1245 | [ 1246 | "Tc", 1247 | 4 1248 | ], 1249 | [ 1250 | "Tc", 1251 | 7 1252 | ], 1253 | [ 1254 | "Te", 1255 | -2 1256 | ], 1257 | [ 1258 | "Te", 1259 | -1 1260 | ], 1261 | [ 1262 | "Te", 1263 | 0 1264 | ], 1265 | [ 1266 | "Te", 1267 | 4 1268 | ], 1269 | [ 1270 | "Te", 1271 | 6 1272 | ], 1273 | [ 1274 | "Th", 1275 | 0 1276 | ], 1277 | [ 1278 | "Th", 1279 | 3 1280 | ], 1281 | [ 1282 | "Th", 1283 | 4 1284 | ], 1285 | [ 1286 | "Ti", 1287 | 0 1288 | ], 1289 | [ 1290 | "Ti", 1291 | 2 1292 | ], 1293 | [ 1294 | "Ti", 1295 | 3 1296 | ], 1297 | [ 1298 | "Ti", 1299 | 4 1300 | ], 1301 | [ 1302 | "Tl", 1303 | 0 1304 | ], 1305 | [ 1306 | "Tl", 1307 | 1 1308 | ], 1309 | [ 1310 | "Tl", 1311 | 3 1312 | ], 1313 | [ 1314 | "Tm", 1315 | 0 1316 | ], 1317 | [ 1318 | "Tm", 1319 | 2 1320 | ], 1321 | [ 1322 | "Tm", 1323 | 3 1324 | ], 1325 | [ 1326 | "U", 1327 | 0 1328 | ], 1329 | [ 1330 | "U", 1331 | 3 1332 | ], 1333 | [ 1334 | "U", 1335 | 4 1336 | ], 1337 | [ 1338 | "U", 1339 | 5 1340 | ], 1341 | [ 1342 | "U", 1343 | 6 1344 | ], 1345 | [ 1346 | "V", 1347 | 0 1348 | ], 1349 | [ 1350 | "V", 1351 | 1 1352 | ], 1353 | [ 1354 | "V", 1355 | 2 1356 | ], 1357 | [ 1358 | "V", 1359 | 3 1360 | ], 1361 | [ 1362 | "V", 1363 | 4 
1364 | ], 1365 | [ 1366 | "V", 1367 | 5 1368 | ], 1369 | [ 1370 | "W", 1371 | 0 1372 | ], 1373 | [ 1374 | "W", 1375 | 1 1376 | ], 1377 | [ 1378 | "W", 1379 | 2 1380 | ], 1381 | [ 1382 | "W", 1383 | 3 1384 | ], 1385 | [ 1386 | "W", 1387 | 4 1388 | ], 1389 | [ 1390 | "W", 1391 | 5 1392 | ], 1393 | [ 1394 | "W", 1395 | 6 1396 | ], 1397 | [ 1398 | "X", 1399 | -1 1400 | ], 1401 | [ 1402 | "X", 1403 | 1 1404 | ], 1405 | [ 1406 | "Y", 1407 | 0 1408 | ], 1409 | [ 1410 | "Y", 1411 | 1 1412 | ], 1413 | [ 1414 | "Y", 1415 | 2 1416 | ], 1417 | [ 1418 | "Y", 1419 | 3 1420 | ], 1421 | [ 1422 | "Yb", 1423 | 0 1424 | ], 1425 | [ 1426 | "Yb", 1427 | 2 1428 | ], 1429 | [ 1430 | "Yb", 1431 | 3 1432 | ], 1433 | [ 1434 | "Zn", 1435 | 0 1436 | ], 1437 | [ 1438 | "Zn", 1439 | 1 1440 | ], 1441 | [ 1442 | "Zn", 1443 | 2 1444 | ], 1445 | [ 1446 | "Zr", 1447 | 0 1448 | ], 1449 | [ 1450 | "Zr", 1451 | 1 1452 | ], 1453 | [ 1454 | "Zr", 1455 | 2 1456 | ], 1457 | [ 1458 | "Zr", 1459 | 3 1460 | ], 1461 | [ 1462 | "Zr", 1463 | 4 1464 | ] 1465 | ], 1466 | "featurizer_type": "default", 1467 | "mat_feature_len": 83, 1468 | "ele_dim_features": 32, 1469 | "num_attention_layers": 2, 1470 | "num_attention_heads": 2, 1471 | "hidden_activation": "gelu", 1472 | "hidden_dropout": 0.1, 1473 | "attention_dropout": 0.1, 1474 | "initializer_range": 0.02, 1475 | "encoder_type": "simple_hidden", 1476 | "encoder_normalize_output": true, 1477 | "ele_emb_init_max": 10.0, 1478 | "zero_shift_init_value": -1.0, 1479 | "zero_shift_trainable": true 1480 | } -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisEncoding/saved_model/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CederGroupHub/SynthesisSimilarity/21f013ff6a1fe1f5eeb8c48e32bf20d601d2fb86/SynthesisSimilarity/models/SynthesisEncoding/saved_model/saved_model.pb -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisEncoding/saved_model/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CederGroupHub/SynthesisSimilarity/21f013ff6a1fe1f5eeb8c48e32bf20d601d2fb86/SynthesisSimilarity/models/SynthesisEncoding/saved_model/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisEncoding/saved_model/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CederGroupHub/SynthesisSimilarity/21f013ff6a1fe1f5eeb8c48e32bf20d601d2fb86/SynthesisSimilarity/models/SynthesisEncoding/saved_model/variables/variables.index -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisRecommendation/cmd_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_path": "rsc/reactions_v20_20210820_solid_state_40670.jsonl", 3 | "to_reload": true, 4 | "reload_path": "../../data_public/data_ss_v20_14.npz", 5 | "evaluation_data_path": "../../data_public/evaluation_data.npz", 6 | "cp_dir": "/global/scratch/users/tanjin_he/similarity_tasks/0001_hyper_structure_0980/TPSimilarity_003312/generated/model_1", 7 | "save_meta": false, 8 | "get_meta_only": false, 9 | "pretrain_data": "", 10 | "uncommon_reactions_only": 
false, 11 | "path_precursor_frequencies": "../../data_public/pre_count_normalized_by_rxn_ss_v20_12.json", 12 | "featurizer_type": "default", 13 | "ion_freq_threshold": 0, 14 | "common_ion_not_feature": false, 15 | "encoder_type": "simple_hidden", 16 | "encoder_normalize_output": true, 17 | "ele_emb_init_max": 10.0, 18 | "zero_shift_init_value": -1.0, 19 | "zero_shift_trainable": true, 20 | "task_to_add": [ 21 | "reaction_pre", 22 | "mat" 23 | ], 24 | "weight_mat_decoder": 0.1, 25 | "weight_pre_predict": 1.0, 26 | "weight_syn_type_predict": 1.0, 27 | "weight_sim_between_react": 1.0, 28 | "weight_mat_variance": 1.0, 29 | "use_adaptive_multi_loss": true, 30 | "max_mats_num": 6, 31 | "precursor_drop_n": -5, 32 | "batch_size": 8, 33 | "ele_dim_features": 32, 34 | "num_attention_layers": 2, 35 | "num_attention_heads": 2, 36 | "hidden_activation": "gelu", 37 | "hidden_dropout": 0.1, 38 | "attention_dropout": 0.1, 39 | "initializer_range": 0.02, 40 | "num_reserved_ids": 10, 41 | "decoder_final_activation": "sigmoid", 42 | "decoder_loss_fn": "mse", 43 | "ele_pred_stoi_scale": 1.0, 44 | "bias_in_element_layer": true, 45 | "constrain_element_layer": false, 46 | "norm_in_element_projection": false, 47 | "ele_pred_dot_prod_scale": 1.0, 48 | "ele_pred_balance_PN": false, 49 | "ele_pred_clip_logits": false, 50 | "ele_pred_focal_gamma": 2.0, 51 | "ele_pred_focal_alpha": 0.25, 52 | "ele_pred_focal_label_smoothing": 0.0, 53 | "ele_pred_circle_gamma": 64, 54 | "ele_pred_circle_margin": 0.25, 55 | "mat_variance_loss_fn": "abs_dot_sim", 56 | "pre_pred_under_mask": true, 57 | "pre_pred_atten_num_heads": 1, 58 | "pre_pred_atten_hidden_activation": "gelu", 59 | "pre_pred_atten_hidden_dropout": 0.1, 60 | "pre_pred_atten_dropout": 0.1, 61 | "pre_pred_atten_initializer_range": 0.02, 62 | "pre_pred_loss_fn": "circle", 63 | "pre_pred_kernel_initializer": "glorot_uniform", 64 | "pre_pred_initializer_max": 0.05, 65 | "pre_pred_lambda": 1.0, 66 | "pre_pred_dot_prod_scale": 1.0, 67 | "pre_pred_balance_PN": false, 68 | "pre_pred_clip_logits": true, 69 | "pre_pred_focal_gamma": 2.0, 70 | "pre_pred_focal_alpha": 0.25, 71 | "pre_pred_focal_label_smoothing": 0.0, 72 | "pre_pred_circle_gamma": 64, 73 | "pre_pred_circle_margin": 0.25, 74 | "constrain_precursor_layer": false, 75 | "bias_in_precursor_layer": true, 76 | "syn_type_pred_loss_fn": "cross_entropy", 77 | "syn_type_pred_balance_PN": true, 78 | "syn_type_pred_dot_prod_scale": 1.0, 79 | "constrain_syn_type_layer": false, 80 | "bias_in_syn_type_layer": false, 81 | "norm_in_syn_type_projection": true, 82 | "num_epochs": 50, 83 | "std_out": "output.txt", 84 | "steps_per_epoch": 10000, 85 | "num_warmup_epochs": 10, 86 | "lr_method_name": "adam", 87 | "init_learning_rate": 0.0005, 88 | "random_seed_str": "Similarity", 89 | "last_cp_dir": "/global/scratch/users/tanjin_he/similarity_tasks/0001_hyper_structure_0980/TPSimilarity_003312/generated/model_1/last_cp", 90 | "last_cp_path": "/global/scratch/users/tanjin_he/similarity_tasks/0001_hyper_structure_0980/TPSimilarity_003312/generated/model_1/last_cp/cp.ckpt", 91 | "opt_cp_dir": "/global/scratch/users/tanjin_he/similarity_tasks/0001_hyper_structure_0980/TPSimilarity_003312/generated/model_1/opt_cp", 92 | "opt_cp_path": "/global/scratch/users/tanjin_he/similarity_tasks/0001_hyper_structure_0980/TPSimilarity_003312/generated/model_1/opt_cp/cp.ckpt", 93 | "encoder_dir": "/global/scratch/users/tanjin_he/similarity_tasks/0001_hyper_structure_0980/TPSimilarity_003312/generated/model_1/encoder_model", 94 | "figure_dir": 
"/global/scratch/users/tanjin_he/similarity_tasks/0001_hyper_structure_0980/TPSimilarity_003312/generated/model_1/figures", 95 | "log_dir": "/global/scratch/users/tanjin_he/similarity_tasks/0001_hyper_structure_0980/TPSimilarity_003312/generated/model_1/logs", 96 | "num_train_steps": 500000, 97 | "num_warmup_steps": 100000 98 | } -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisRecommendation/model_meta.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CederGroupHub/SynthesisSimilarity/21f013ff6a1fe1f5eeb8c48e32bf20d601d2fb86/SynthesisSimilarity/models/SynthesisRecommendation/model_meta.pkl -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisRecommendation/saved_model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp.ckpt" 2 | all_model_checkpoint_paths: "cp.ckpt" 3 | -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisRecommendation/saved_model/cp.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CederGroupHub/SynthesisSimilarity/21f013ff6a1fe1f5eeb8c48e32bf20d601d2fb86/SynthesisSimilarity/models/SynthesisRecommendation/saved_model/cp.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /SynthesisSimilarity/models/SynthesisRecommendation/saved_model/cp.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CederGroupHub/SynthesisSimilarity/21f013ff6a1fe1f5eeb8c48e32bf20d601d2fb86/SynthesisSimilarity/models/SynthesisRecommendation/saved_model/cp.ckpt.index -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts/_00_download_model_and_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import shutil 4 | 5 | 6 | def download_necessary_data(): 7 | root = os.path.abspath( 8 | os.path.join( 9 | os.path.dirname(__file__), 10 | "..", 11 | ) 12 | ) 13 | print("root", root) 14 | 15 | # download model and data for PrecursorSelector 16 | file_id_1 = "1ack7mcyHtUVMe99kRARvdDV8UhweElJ4" 17 | url_1 = f"https://drive.google.com/uc?id={file_id_1}" 18 | path_zip_1 = os.path.join(root, "rsc.zip") 19 | gdown.download(url_1, path_zip_1, quiet=False) 20 | shutil.unpack_archive(path_zip_1, root) 21 | os.remove(path_zip_1) 22 | 23 | 24 | def download_optional_data(): 25 | root = os.path.abspath( 26 | os.path.join( 27 | os.path.dirname(__file__), 28 | "..", 29 | ) 30 | ) 31 | print("root", root) 32 | 33 | # (optional) download model and data for baseline models 34 | file_id_2 = "1JbVNctVpspwqjaev0TDW10cn2izxDwpy" 35 | url_2 = f"https://drive.google.com/uc?id={file_id_2}" 36 | path_zip_2 = os.path.join(root, "other_rsc.zip") 37 | gdown.download(url_2, path_zip_2, quiet=False) 38 | shutil.unpack_archive(path_zip_2, root) 39 | os.remove(path_zip_2) 40 | 41 | 42 | if __name__ == "__main__": 43 | # download model and data for PrecursorSelector 44 | download_necessary_data() 45 | 46 | # # (optional) download model and data for baseline models 47 | # download_optional_data() 48 | 
-------------------------------------------------------------------------------- /SynthesisSimilarity/scripts/_01_synthesis_recommendation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Recommend precursors for given target materials using PrecursorSelector. 3 | """ 4 | 5 | import os 6 | import sys 7 | 8 | parent_folder = os.path.abspath( 9 | os.path.join( 10 | os.path.dirname(__file__), 11 | "../..", 12 | ) 13 | ) 14 | print("parent_folder", parent_folder) 15 | if parent_folder not in sys.path: 16 | sys.path.append(parent_folder) 17 | 18 | from pprint import pprint 19 | 20 | from SynthesisSimilarity.scripts_utils.precursors_recommendation_utils import ( 21 | PrecursorsRecommendation, 22 | ) 23 | 24 | 25 | __author__ = "Tanjin He" 26 | __maintainer__ = "Tanjin He" 27 | __email__ = "tanjin_he@berkeley.edu" 28 | 29 | 30 | def run_recommendations(): 31 | precursors_recommendator = PrecursorsRecommendation( 32 | model_dir="../models/SynthesisRecommendation", 33 | freq_path="../rsc/pre_count_normalized_by_rxn_ss.json", 34 | data_path="../rsc/data_split.npz", 35 | all_to_knowledge_base=False, 36 | ) 37 | 38 | test_targets_formulas = [ 39 | "SrZnSO", 40 | "Na3TiV(PO4)3", 41 | "GdLu(MoO4)3", 42 | "BaYSi2O5N", 43 | "Cu3Yb(SeO3)2O2Cl", 44 | ] 45 | 46 | print("len(test_targets_formulas)", len(test_targets_formulas)) 47 | print("test_targets_formulas", test_targets_formulas) 48 | 49 | all_predicts = precursors_recommendator.recommend_precursors( 50 | target_formula=test_targets_formulas, 51 | top_n=10, 52 | ) 53 | 54 | for i in range(len(test_targets_formulas)): 55 | pprint(all_predicts[i]) 56 | print() 57 | 58 | 59 | if __name__ == "__main__": 60 | run_recommendations() 61 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts/_02_target_material_similarity.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculate similarity between two target materials using the PrecursorSelector encoding. 
3 | """ 4 | 5 | import os 6 | import sys 7 | 8 | parent_folder = os.path.abspath( 9 | os.path.join( 10 | os.path.dirname(__file__), 11 | "../..", 12 | ) 13 | ) 14 | print("parent_folder", parent_folder) 15 | if parent_folder not in sys.path: 16 | sys.path.append(parent_folder) 17 | 18 | from SynthesisSimilarity.scripts_utils.TarMatSimilarity_utils import TarMatSimilarity 19 | 20 | 21 | __author__ = "Tanjin He" 22 | __maintainer__ = "Tanjin He" 23 | __email__ = "tanjin_he@berkeley.edu" 24 | 25 | 26 | def calc_similarity( 27 | model_dir: str, 28 | ): 29 | formula1 = [ 30 | "NaZr2(PO4)3", 31 | ] 32 | formula2 = [ 33 | "Zr3(PO4)4", 34 | "Na3Zr2Si2PO12", 35 | "Na3Zr1.8Ge0.2Si2PO12", 36 | "Na3Ca0.1Zr1.9Si2PO11.9", 37 | "Na3Zr1.9Ti0.1Si2PO12", 38 | "LiZr2(PO4)3", 39 | "NaLa(PO3)4", 40 | "Sr0.125Ca0.375Zr2(PO4)3", 41 | "Na5Cu2(PO4)3", 42 | "LiGe2(PO4)3", 43 | "Li1.8ZrO3", 44 | "NaNbO3", 45 | "Li2Mg2(MoO4)3", 46 | "Sr2Ce2Ti5O16", 47 | "Ga0.75Al0.25FeO3", 48 | "Cu2Te", 49 | "Ni60Fe30Mn10", 50 | "AgCrSe2", 51 | "Zn0.1Cd0.9Cr2S4", 52 | "Cr2AlC", 53 | ] 54 | sim_calculator = TarMatSimilarity(model_dir) 55 | for f1 in formula1: 56 | for f2 in formula2: 57 | print("\nComparing %s to %s:" % (f1, f2)) 58 | print("Similarity = %.3f" % sim_calculator.compare(f1, f2)) 59 | 60 | 61 | if __name__ == "__main__": 62 | calc_similarity( 63 | model_dir="../models/SynthesisEncoding", 64 | ) 65 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts/_03_masked_precursor_completion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Predict complete the precursor list for the given target material with conditional precursors. 3 | """ 4 | 5 | import os 6 | import sys 7 | 8 | parent_folder = os.path.abspath( 9 | os.path.join( 10 | os.path.dirname(__file__), 11 | "../..", 12 | ) 13 | ) 14 | print("parent_folder", parent_folder) 15 | if parent_folder not in sys.path: 16 | sys.path.append(parent_folder) 17 | 18 | import numpy as np 19 | 20 | from SynthesisSimilarity.core import utils 21 | from SynthesisSimilarity.core import model_utils 22 | from SynthesisSimilarity.core import callbacks 23 | 24 | 25 | __author__ = "Tanjin He" 26 | __maintainer__ = "Tanjin He" 27 | __email__ = "tanjin_he@berkeley.edu" 28 | 29 | 30 | def get_decode_test_examples(): 31 | # decode target w conditional precursors 32 | target_formulas = [ 33 | "LaAlO3", 34 | "LaAlO3", 35 | "LaAlO3", 36 | "LaAlO3", 37 | "LaAlO3", 38 | ] 39 | 40 | precursor_formulas_conditional = [ 41 | [ 42 | "La(NO3)3", 43 | ], 44 | [ 45 | "Al(NO3)3", 46 | ], 47 | [ 48 | "La2O3", 49 | ], 50 | [ 51 | "Al2O3", 52 | ], 53 | [], 54 | ] 55 | return { 56 | "target_formulas": target_formulas, 57 | "precursor_formulas_conditional": precursor_formulas_conditional, 58 | } 59 | 60 | 61 | def decode_target_w_conditional_precursors( 62 | target_formulas, 63 | precursor_formulas_conditional, 64 | framework_model, 65 | all_elements, 66 | mat_feature_len, 67 | featurizer_type, 68 | max_mats_num, 69 | ): 70 | 71 | assert len(target_formulas) == len(precursor_formulas_conditional) 72 | 73 | predict_precursor_callback = callbacks.PredictPrecursorsCallback( 74 | all_elements=all_elements, 75 | mat_feature_len=mat_feature_len, 76 | test_data=None, 77 | output_thresh=0.5, 78 | featurizer_type=featurizer_type, 79 | ) 80 | 81 | target_compositions = [] 82 | precursors_conditional = [] 83 | zero_composition = np.zeros( 84 | shape=(len(all_elements),), 85 | dtype=np.float32, 86 | ) 87 | for (tar, pres) in 
zip(target_formulas, precursor_formulas_conditional): 88 | target_compositions.append( 89 | utils.formula_to_array(tar, all_elements), 90 | ) 91 | precursors_conditional.append([]) 92 | for i in range(max_mats_num - 1): 93 | if i < len(pres): 94 | precursors_conditional[-1].append( 95 | utils.formula_to_array(pres[i], all_elements) 96 | ) 97 | else: 98 | precursors_conditional[-1].append(zero_composition) 99 | 100 | target_compositions = np.array(target_compositions) 101 | precursors_conditional = np.array(precursors_conditional) 102 | 103 | if "reaction_pre" in framework_model.task_names: 104 | ( 105 | pre_lists_pred, 106 | pre_str_lists_pred, 107 | ) = predict_precursor_callback.predict_precursors( 108 | framework_model, 109 | target_compositions, 110 | precursors_conditional=precursors_conditional, 111 | to_print=True, 112 | ) 113 | 114 | 115 | if __name__ == "__main__": 116 | model_dir = "../models/SynthesisRecommendation" 117 | 118 | framework_model, model_config = model_utils.load_framework_model(model_dir) 119 | all_elements = model_config["all_eles"] 120 | max_mats_num = model_config["max_mats_num"] 121 | featurizer_type = model_config["featurizer_type"] 122 | mat_feature_len = model_config["mat_feature_len"] 123 | 124 | # decode target w conditional precursors 125 | decode_examples = get_decode_test_examples() 126 | target_formulas = decode_examples["target_formulas"] 127 | precursor_formulas_conditional = decode_examples["precursor_formulas_conditional"] 128 | 129 | decode_target_w_conditional_precursors( 130 | target_formulas=target_formulas, 131 | precursor_formulas_conditional=precursor_formulas_conditional, 132 | framework_model=framework_model, 133 | all_elements=all_elements, 134 | mat_feature_len=mat_feature_len, 135 | featurizer_type=featurizer_type, 136 | max_mats_num=max_mats_num, 137 | ) 138 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts/_04_reaction_relationship.py: -------------------------------------------------------------------------------- 1 | """ 2 | Explore the relationship of Tar_1 - Pre_1 vs Tar_2 - Pre_2, 3 | similar to king - man ~ queen - woman from word2vec. 
4 | """ 5 | 6 | import os 7 | import sys 8 | 9 | parent_folder = os.path.abspath( 10 | os.path.join( 11 | os.path.dirname(__file__), 12 | "../..", 13 | ) 14 | ) 15 | print("parent_folder", parent_folder) 16 | if parent_folder not in sys.path: 17 | sys.path.append(parent_folder) 18 | 19 | 20 | import numpy as np 21 | from pprint import pprint 22 | from typing import List 23 | 24 | from SynthesisSimilarity.core import model_utils 25 | from SynthesisSimilarity.core import callbacks 26 | from sklearn import decomposition 27 | import matplotlib.pyplot as plt 28 | 29 | 30 | __author__ = "Tanjin He" 31 | __maintainer__ = "Tanjin He" 32 | __email__ = "tanjin_he@berkeley.edu" 33 | 34 | 35 | def explore_reaction_relationship( 36 | vec_cb, 37 | framework_model, 38 | ): 39 | target_formulas = [ 40 | "InCuO2", 41 | "YCuO2", 42 | "Al2CuO4", 43 | "FeCuO2", 44 | "BaIn2O4", 45 | "Ba3Y4O9", 46 | "BaAl2O4", 47 | "BaFeO3", 48 | "TiIn2O5", 49 | "Ti3Y2O9", 50 | "Ti3Al2O9", 51 | "Ti3Fe2O9", 52 | ] 53 | precursor_formulas = [ 54 | "In2O3", 55 | "Y2O3", 56 | "Al2O3", 57 | "Fe2O3", 58 | "In2O3", 59 | "Y2O3", 60 | "Al2O3", 61 | "Fe2O3", 62 | "In2O3", 63 | "Y2O3", 64 | "Al2O3", 65 | "Fe2O3", 66 | ] 67 | 68 | decoder = vec_cb.get_decoder(framework_model) 69 | project_w_attention = False 70 | if decoder.predict_precursor_under_mask: 71 | project_w_attention = True 72 | pre_vec_mapping = vec_cb.get_pre_vec_mapping( 73 | decoder=decoder, 74 | ) 75 | precursor_vecs = vec_cb.get_pre_vecs_from_mapping( 76 | precursor_formulas=precursor_formulas, 77 | pre_vec_mapping=pre_vec_mapping, 78 | ) 79 | tar_vec_mapping, _ = vec_cb.get_tar_vec_mapping( 80 | target_formulas=target_formulas, 81 | model=framework_model, 82 | project_w_attention=project_w_attention, 83 | decoder=decoder, 84 | max_mats_num=vec_cb.max_mats_num, 85 | ) 86 | diff_vecs = [] 87 | for i in range(len(target_formulas)): 88 | diff_vecs.append(tar_vec_mapping[target_formulas[i]] - precursor_vecs[i]) 89 | 90 | diff_vecs = np.array(diff_vecs) 91 | pca = decomposition.PCA( 92 | n_components=2, 93 | svd_solver="arpack", 94 | ) 95 | pca.fit(diff_vecs) 96 | diff_vecs_2d = pca.transform(diff_vecs) 97 | 98 | print("diff_vecs_2d") 99 | pprint(diff_vecs_2d) 100 | 101 | plot_relation_shift( 102 | target_formulas=target_formulas, 103 | precursor_formulas=precursor_formulas, 104 | diff_vecs=diff_vecs_2d, 105 | ) 106 | 107 | 108 | def plot_relation_shift( 109 | target_formulas: List[str], 110 | precursor_formulas: List[str], 111 | diff_vecs: np.ndarray, 112 | ): 113 | 114 | jitter_scale = 0.05 115 | target_text = { 116 | "InCuO2": "InCuO$_2$", 117 | "YCuO2": "YCuO$_2$", 118 | "Al2CuO4": "Al$_2$CuO$_4$", 119 | "FeCuO2": "FeCuO$_2$", 120 | "BaIn2O4": "BaIn$_2$O$_4$", 121 | "Ba3Y4O9": "Ba$_3$Y$_4$O$_9$", 122 | "BaAl2O4": "BaAl$_2$O$_4$", 123 | "BaFeO3": "BaFeO$_3$", 124 | "TiIn2O5": "TiIn$_2$O$_5$", 125 | "Ti3Y2O9": "Ti$_3$Y$_2$O$_9$", 126 | "Ti3Al2O9": "Ti$_3$Al$_2$O$_9$", 127 | "Ti3Fe2O9": "Ti$_3$Fe$_2$O$_9$", 128 | } 129 | target_text_jitter = { 130 | "InCuO2": np.array([-2.2, 0.0]) * jitter_scale, 131 | "YCuO2": np.array([-2.5, 0]) * jitter_scale, 132 | "Al2CuO4": np.array([-3.5, -1]) * jitter_scale, 133 | "FeCuO2": np.array([-4, -5]) * jitter_scale, 134 | "BaIn2O4": np.array([-1, -3.0]) * jitter_scale, 135 | "Ba3Y4O9": np.array([-1, -3.0]) * jitter_scale, 136 | "BaAl2O4": np.array([-3.5, 0.0]) * jitter_scale, 137 | "BaFeO3": np.array([-1.5, -2.5]) * jitter_scale, 138 | "TiIn2O5": np.array([-3.5, 0.5]) * jitter_scale, 139 | "Ti3Y2O9": np.array([-5, -3]) * jitter_scale, 140 | 
"Ti3Al2O9": np.array([-3.2, 0]) * jitter_scale, 141 | "Ti3Fe2O9": np.array([-3, 0]) * jitter_scale, 142 | } 143 | target_text_angles = { 144 | "InCuO2": -85, 145 | "YCuO2": -90, 146 | "Al2CuO4": -80, 147 | "FeCuO2": -70, 148 | "BaIn2O4": 32, 149 | "Ba3Y4O9": 40, 150 | "BaAl2O4": 32, 151 | "BaFeO3": 18, 152 | "TiIn2O5": -10, 153 | "Ti3Y2O9": -15, 154 | "Ti3Al2O9": -15, 155 | "Ti3Fe2O9": -10, 156 | } 157 | precursor_init_points = { 158 | "Fe2O3": np.array([-3, 3]) * jitter_scale, 159 | "Al2O3": np.array([-1.5, 1.5]) * jitter_scale, 160 | "In2O3": np.array([0, 0]) * jitter_scale, 161 | "Y2O3": np.array([1.5, -1.5]) * jitter_scale, 162 | } 163 | precursor_text = { 164 | "In2O3": "In$_2$O$_3$", 165 | "Fe2O3": "Fe$_2$O$_3$", 166 | "Al2O3": "Al$_2$O$_3$", 167 | "Y2O3": "Y$_2$O$_3$", 168 | } 169 | precursor_text_jitter = { 170 | "In2O3": np.array([1.5, -0.0]) * jitter_scale, 171 | "Fe2O3": np.array([-5.2, 1.2]) * jitter_scale, 172 | "Al2O3": np.array([-6.0, -1.7]) * jitter_scale, 173 | "Y2O3": np.array([-6.2, -1.0]) * jitter_scale, 174 | } 175 | category_text = { 176 | "TiO2": "React w/ TiO$_2$", 177 | "CuO": "React w/ CuO", 178 | "BaCO3": "React w/ BaCO$_3$", 179 | } 180 | category_text_points = { 181 | "TiO2": np.array([2.3, -9.2]) * jitter_scale, 182 | "CuO": np.array([2, 4.0]) * jitter_scale, 183 | "BaCO3": np.array([-24, -3.2]) * jitter_scale, 184 | } 185 | category_text_angles = { 186 | "TiO2": -20, 187 | "CuO": -90, 188 | "BaCO3": 18, 189 | } 190 | category_colors = { 191 | "BaCO3": "slateblue", 192 | "CuO": "orchid", 193 | "TiO2": "seagreen", 194 | } 195 | target_colors = { 196 | "InCuO2": category_colors["CuO"], 197 | "YCuO2": category_colors["CuO"], 198 | "Al2CuO4": category_colors["CuO"], 199 | "FeCuO2": category_colors["CuO"], 200 | "BaIn2O4": category_colors["BaCO3"], 201 | "Ba3Y4O9": category_colors["BaCO3"], 202 | "BaAl2O4": category_colors["BaCO3"], 203 | "BaFeO3": category_colors["BaCO3"], 204 | "TiIn2O5": category_colors["TiO2"], 205 | "Ti3Y2O9": category_colors["TiO2"], 206 | "Ti3Al2O9": category_colors["TiO2"], 207 | "Ti3Fe2O9": category_colors["TiO2"], 208 | } 209 | precursor_colors = { 210 | "Fe2O3": "tab:blue", 211 | "Al2O3": "tab:green", 212 | "In2O3": "tab:purple", 213 | "Y2O3": "tab:orange", 214 | } 215 | precursor_markers = { 216 | "Fe2O3": "^", 217 | "Al2O3": "s", 218 | "In2O3": "p", 219 | "Y2O3": "o", 220 | } 221 | 222 | target_indices_to_plot = [target_formulas.index(x) for x in target_formulas] 223 | precursors_to_plot = list( 224 | set([precursor_formulas[t_i] for t_i in target_indices_to_plot]) 225 | ) 226 | 227 | fig = plt.figure( 228 | figsize=(12, 10), 229 | # constrained_layout=True, 230 | ) 231 | ax = fig.add_subplot(111) 232 | 233 | # sns.set(style="white", palette="muted", color_codes=True) 234 | # paper_rc = {'lines.linewidth': 5} 235 | # sns.set_context("paper", rc=paper_rc) 236 | 237 | for t_i in target_indices_to_plot: 238 | pre = precursor_formulas[t_i] 239 | pre_xy = precursor_init_points[pre] 240 | tar = target_formulas[t_i] 241 | tar_dxy = diff_vecs[t_i] 242 | tar_xy = tar_dxy + pre_xy 243 | tar_text = target_text[tar] 244 | tar_text_xy = tar_xy + target_text_jitter[tar] 245 | tar_text_angle = target_text_angles[tar] 246 | tar_color = target_colors[tar] 247 | ax.arrow( 248 | pre_xy[0], 249 | pre_xy[1], 250 | tar_dxy[0], 251 | tar_dxy[1], 252 | head_width=0.04, 253 | head_length=0.05, 254 | linewidth=2, 255 | facecolor=tar_color, 256 | edgecolor=tar_color, 257 | alpha=0.75, 258 | head_starts_at_zero=False, 259 | length_includes_head=True, 260 | ) 261 
| ax.text( 262 | tar_text_xy[0], 263 | tar_text_xy[1], 264 | tar_text, 265 | rotation=tar_text_angle, 266 | fontsize=28, 267 | color=tar_color, 268 | ) 269 | 270 | for pre in precursors_to_plot: 271 | pre_xy = precursor_init_points[pre] 272 | pre_text = precursor_text[pre] 273 | pre_text_xy = pre_xy + precursor_text_jitter[pre] 274 | pre_color = precursor_colors[pre] 275 | pre_marker = precursor_markers[pre] 276 | ax.plot( 277 | pre_xy[0], 278 | pre_xy[1], 279 | marker=pre_marker, 280 | markersize=20, 281 | color=pre_color, 282 | ) 283 | ax.text( 284 | pre_text_xy[0], 285 | pre_text_xy[1], 286 | pre_text, 287 | fontsize=28, 288 | color=pre_color, 289 | ) 290 | 291 | for cat in category_text_points: 292 | cat_text_xy = category_text_points[cat] 293 | cat_text_angle = category_text_angles[cat] 294 | cat_color = category_colors[cat] 295 | cat_text = category_text[cat] 296 | ax.text( 297 | cat_text_xy[0], 298 | cat_text_xy[1], 299 | cat_text, 300 | fontsize=32, 301 | color=cat_color, 302 | rotation=cat_text_angle, 303 | ) 304 | 305 | plt.xlabel("First principal component", size=36) 306 | plt.ylabel("Second principal component", size=36) 307 | ax.tick_params(axis="x", which="major", labelsize=26) 308 | ax.tick_params(axis="y", which="major", labelsize=26) 309 | plt.xlim(-1.22, 1.27) 310 | plt.ylim(-0.88, 1.2) 311 | ax.set_aspect("equal", adjustable="box") 312 | 313 | plt.tight_layout() 314 | path_save = "../generated/plots/relationship_shift.png" 315 | if not os.path.exists(os.path.dirname(path_save)): 316 | os.makedirs(os.path.dirname(path_save)) 317 | plt.savefig(path_save, dpi=300) 318 | print("Figure saved to {}".format(os.path.dirname(path_save))) 319 | # plt.show() 320 | 321 | 322 | if __name__ == "__main__": 323 | 324 | model_dir = "../models/SynthesisRecommendation" 325 | 326 | framework_model, model_config = model_utils.load_framework_model(model_dir) 327 | all_elements = model_config["all_eles"] 328 | max_mats_num = model_config["max_mats_num"] 329 | featurizer_type = model_config["featurizer_type"] 330 | 331 | # using callbacks 332 | vec_test = callbacks.VectorMathCallback( 333 | all_elements=all_elements, 334 | featurizer_type=featurizer_type, 335 | max_mats_num=max_mats_num, 336 | top_n=10, 337 | test_data=None, 338 | ) 339 | 340 | # plot relationship 341 | explore_reaction_relationship( 342 | vec_cb=vec_test, 343 | framework_model=framework_model, 344 | ) 345 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts/_05_recommendation_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Benchmark the performance of PrecursorSelector and baseline models. 3 | It takes a long time (~2h) to run everything since parallelization is not optimized here. 
4 | """ 5 | 6 | import os 7 | import pdb 8 | import sys 9 | import warnings 10 | 11 | parent_folder = os.path.abspath( 12 | os.path.join( 13 | os.path.dirname(__file__), 14 | "../..", 15 | ) 16 | ) 17 | print("parent_folder", parent_folder) 18 | if parent_folder not in sys.path: 19 | sys.path.append(parent_folder) 20 | 21 | 22 | import numpy as np 23 | import pkgutil 24 | from pprint import pprint 25 | 26 | import json 27 | import itertools 28 | 29 | from SynthesisSimilarity.core.utils import ( 30 | formula_to_array, 31 | get_elements_in_formula, 32 | use_file_as_stdout, 33 | ) 34 | from SynthesisSimilarity.scripts_utils.recommendation_utils import ( 35 | collect_targets_in_reactions, 36 | add_to_sorted_list, 37 | ) 38 | from SynthesisSimilarity.scripts_utils.precursors_recommendation_utils import ( 39 | PrecursorsRecommendation, 40 | ) 41 | from SynthesisSimilarity.scripts_utils.multi_processing_utils import ( 42 | run_multiprocessing_tasks, 43 | ) 44 | from SynthesisSimilarity.scripts_utils.MatminerSimilarity_utils import ( 45 | MatMiner_features_for_formulas, 46 | ) 47 | from SynthesisSimilarity.scripts_utils.train_utils import ( 48 | load_raw_reactions, 49 | ) 50 | from SynthesisSimilarity.scripts_utils.FastTextSimilarity_utils import ( 51 | composition_to_human_formula, 52 | ) 53 | 54 | __author__ = "Tanjin He" 55 | __maintainer__ = "Tanjin He" 56 | __email__ = "tanjin_he@berkeley.edu" 57 | 58 | 59 | def evaluation_prediction_precursors( 60 | test_targets, 61 | test_targets_formulas, 62 | all_pres_predict, 63 | ): 64 | # print prediction info 65 | assert len(test_targets_formulas) == len( 66 | all_pres_predict 67 | ), "len(test_targets_formulas) != len(all_pres_predict)" 68 | all_results = [] 69 | pres_predict_result = [] 70 | pres_predict_result_top2 = [] 71 | pres_predict_result_top3 = [] 72 | pres_predict_result_topn = [] 73 | for x, pres_predict in zip(test_targets_formulas, all_pres_predict): 74 | pres_true_set = set(test_targets[x]["pres"].keys()) 75 | 76 | all_results.append([]) 77 | for i, pres in enumerate(pres_predict): 78 | all_results[-1].append(len(set(pres_predict[: i + 1]) & pres_true_set) > 0) 79 | 80 | pres_predict_result.append(len(set(pres_predict[:1]) & pres_true_set) > 0) 81 | pres_predict_result_top2.append(len(set(pres_predict[:2]) & pres_true_set) > 0) 82 | pres_predict_result_top3.append(len(set(pres_predict[:3]) & pres_true_set) > 0) 83 | pres_predict_result_topn.append(len(set(pres_predict) & pres_true_set) > 0) 84 | 85 | pres_predict_result = np.array(pres_predict_result, dtype=np.int64) 86 | print( 87 | "pres_predict all correct: True {num_true}/Total {num_total} = {accuracy}".format( 88 | num_true=sum(pres_predict_result), 89 | num_total=len(pres_predict_result), 90 | accuracy=sum(pres_predict_result) / len(pres_predict_result), 91 | ) 92 | ) 93 | 94 | pres_predict_result_top2 = np.array(pres_predict_result_top2, dtype=np.int64) 95 | print( 96 | "pres_predict top2 all correct: True {num_true}/Total {num_total} = {accuracy}".format( 97 | num_true=sum(pres_predict_result_top2), 98 | num_total=len(pres_predict_result_top2), 99 | accuracy=sum(pres_predict_result_top2) / len(pres_predict_result_top2), 100 | ) 101 | ) 102 | 103 | pres_predict_result_top3 = np.array(pres_predict_result_top3, dtype=np.int64) 104 | print( 105 | "pres_predict top3 all correct: True {num_true}/Total {num_total} = {accuracy}".format( 106 | num_true=sum(pres_predict_result_top3), 107 | num_total=len(pres_predict_result_top3), 108 | accuracy=sum(pres_predict_result_top3) / 
len(pres_predict_result_top3), 109 | ) 110 | ) 111 | 112 | pres_predict_result_topn = np.array(pres_predict_result_topn, dtype=np.int64) 113 | print( 114 | "pres_predict topn all correct: True {num_true}/Total {num_total} = {accuracy}".format( 115 | num_true=sum(pres_predict_result_topn), 116 | num_total=len(pres_predict_result_topn), 117 | accuracy=sum(pres_predict_result_topn) / len(pres_predict_result_topn), 118 | ) 119 | ) 120 | 121 | max_len = max(map(len, all_results)) 122 | num_short_result = 0 123 | for i in range(len(all_results)): 124 | if len(all_results[i]) < max_len: 125 | if len(all_results[i]) > 0: 126 | all_results[i].extend( 127 | [all_results[i][-1]] * (max_len - len(all_results[i])) 128 | ) 129 | else: 130 | all_results[i] = [False] * max_len 131 | num_short_result += 1 132 | print("len(all_results)", len(all_results)) 133 | print("num_short_result", num_short_result) 134 | all_results = np.array(all_results, dtype=np.int64) 135 | all_results = np.sum(all_results, axis=0) / len(all_results) 136 | print(list(all_results)) 137 | print() 138 | 139 | return all_results 140 | 141 | 142 | def run_recommendations(): 143 | precursors_recommendator = PrecursorsRecommendation( 144 | model_dir="../models/SynthesisRecommendation", 145 | freq_path="../rsc/pre_count_normalized_by_rxn_ss.json", 146 | data_path="../rsc/data_split.npz", 147 | all_to_knowledge_base=False, 148 | ) 149 | 150 | data_path = "../rsc/data_split.npz" 151 | test_data = np.load(data_path, allow_pickle=True) 152 | val_reactions = test_data["val_reactions"] 153 | test_reactions = test_data["test_reactions"] 154 | 155 | (val_targets, val_targets_formulas, _,) = collect_targets_in_reactions( 156 | val_reactions, 157 | precursors_recommendator.all_elements, 158 | precursors_recommendator.common_precursors_set, 159 | ) 160 | 161 | (test_targets, test_targets_formulas, _,) = collect_targets_in_reactions( 162 | test_reactions, 163 | precursors_recommendator.all_elements, 164 | precursors_recommendator.common_precursors_set, 165 | ) 166 | 167 | ######################## 168 | # recommendation through synthesis similarity 169 | all_pres_predict = recommend_w_SynSym( 170 | precursors_recommendator=precursors_recommendator, 171 | test_targets_formulas=val_targets_formulas, 172 | top_n=10, 173 | ) 174 | 175 | evaluation_prediction_precursors( 176 | val_targets, 177 | val_targets_formulas, 178 | all_pres_predict, 179 | ) 180 | 181 | all_pres_predict = recommend_w_SynSym( 182 | precursors_recommendator=precursors_recommendator, 183 | test_targets_formulas=test_targets_formulas, 184 | top_n=10, 185 | ) 186 | 187 | evaluation_prediction_precursors( 188 | test_targets, 189 | test_targets_formulas, 190 | all_pres_predict, 191 | ) 192 | 193 | ######################## 194 | # recommendation through product of precursor frequencies 195 | all_pres_predict = recommend_w_freq( 196 | precursors_recommendator=precursors_recommendator, 197 | test_targets_formulas=val_targets_formulas, 198 | top_n=10, 199 | ) 200 | 201 | evaluation_prediction_precursors( 202 | val_targets, 203 | val_targets_formulas, 204 | all_pres_predict, 205 | ) 206 | 207 | all_pres_predict = recommend_w_freq( 208 | precursors_recommendator=precursors_recommendator, 209 | test_targets_formulas=test_targets_formulas, 210 | top_n=10, 211 | ) 212 | 213 | evaluation_prediction_precursors( 214 | test_targets, 215 | test_targets_formulas, 216 | all_pres_predict, 217 | ) 218 | 219 | # To run the following baseline models, you need to install extra packages via 220 | # pip install 
matminer scikit-learn==1.0.2 gensim==3.8.3 221 | # You also need to download the model and data for Magpie and FastText encodings via 222 | # download_optional_data() in scripts/_00_download_model_and_data.py 223 | 224 | # ######################## 225 | # # recommendation through similarity based on matminer representation 226 | # all_pres_predict = recommend_w_MatMiner( 227 | # precursors_recommendator=precursors_recommendator, 228 | # test_targets_formulas=val_targets_formulas, 229 | # top_n=10, 230 | # ) 231 | # 232 | # evaluation_prediction_precursors( 233 | # val_targets, 234 | # val_targets_formulas, 235 | # all_pres_predict, 236 | # ) 237 | # 238 | # all_pres_predict = recommend_w_MatMiner( 239 | # precursors_recommendator=precursors_recommendator, 240 | # test_targets_formulas=test_targets_formulas, 241 | # top_n=10, 242 | # ) 243 | # 244 | # evaluation_prediction_precursors( 245 | # test_targets, 246 | # test_targets_formulas, 247 | # all_pres_predict, 248 | # ) 249 | 250 | # ######################## 251 | # # recommendation through similarity based on fasttext representation 252 | # (all_pres_predict, fasttext_supported_val_targets_formulas,) = recommend_w_FastText( 253 | # precursors_recommendator=precursors_recommendator, 254 | # test_targets_formulas=val_targets_formulas, 255 | # test_targets=val_targets, 256 | # top_n=10, 257 | # ) 258 | # 259 | # evaluation_prediction_precursors( 260 | # val_targets, 261 | # fasttext_supported_val_targets_formulas, 262 | # all_pres_predict, 263 | # ) 264 | # 265 | # ( 266 | # all_pres_predict, 267 | # fasttext_supported_test_targets_formulas, 268 | # ) = recommend_w_FastText( 269 | # precursors_recommendator=precursors_recommendator, 270 | # test_targets_formulas=test_targets_formulas, 271 | # test_targets=test_targets, 272 | # top_n=10, 273 | # ) 274 | # 275 | # evaluation_prediction_precursors( 276 | # test_targets, 277 | # fasttext_supported_test_targets_formulas, 278 | # all_pres_predict, 279 | # ) 280 | 281 | # ######################## 282 | # # recommendation through raw composition 283 | # all_pres_predict = recommend_w_RawComp( 284 | # precursors_recommendator=precursors_recommendator, 285 | # test_targets_formulas=val_targets_formulas, 286 | # top_n=10, 287 | # ) 288 | # 289 | # evaluation_prediction_precursors( 290 | # val_targets, 291 | # val_targets_formulas, 292 | # all_pres_predict, 293 | # ) 294 | # 295 | # all_pres_predict = recommend_w_RawComp( 296 | # precursors_recommendator=precursors_recommendator, 297 | # test_targets_formulas=test_targets_formulas, 298 | # top_n=10, 299 | # ) 300 | # 301 | # evaluation_prediction_precursors( 302 | # test_targets, 303 | # test_targets_formulas, 304 | # all_pres_predict, 305 | # ) 306 | 307 | 308 | def recommend_w_SynSym( 309 | precursors_recommendator, 310 | test_targets_formulas, 311 | top_n=10, 312 | ): 313 | all_predicts = precursors_recommendator.recommend_precursors( 314 | target_formula=test_targets_formulas, 315 | top_n=top_n, 316 | validate_first_attempt=False, 317 | recommendation_strategy="SynSim_conditional", 318 | ) 319 | 320 | all_pres_predict = [x['precursors_predicts'] for x in all_predicts] 321 | 322 | return all_pres_predict 323 | 324 | 325 | def recommend_w_RawComp( 326 | precursors_recommendator, 327 | test_targets_formulas, 328 | top_n=10, 329 | ): 330 | train_targets_compositions = [ 331 | precursors_recommendator.train_targets[formula]["comp"] 332 | for formula in precursors_recommendator.train_targets_formulas 333 | ] 334 | test_targets_compositions = [ 335 | 
formula_to_array(formula, precursors_recommendator.all_elements) 336 | for formula in test_targets_formulas 337 | ] 338 | train_targets_vecs = np.array(train_targets_compositions) 339 | test_targets_vecs = np.array(test_targets_compositions) 340 | 341 | train_targets_vecs = train_targets_vecs / ( 342 | np.linalg.norm( 343 | train_targets_vecs, 344 | axis=-1, 345 | keepdims=True, 346 | ) 347 | ) 348 | test_targets_vecs = test_targets_vecs / ( 349 | np.linalg.norm( 350 | test_targets_vecs, 351 | axis=-1, 352 | keepdims=True, 353 | ) 354 | ) 355 | all_distance = test_targets_vecs @ train_targets_vecs.T 356 | 357 | all_pres_predict, all_resutls = precursors_recommendator.recommend_precursors_by_similarity( 358 | test_targets_formulas=test_targets_formulas, 359 | train_targets_recipes=precursors_recommendator.train_targets_recipes, 360 | all_distance=all_distance, 361 | top_n=top_n, 362 | strategy="naive_common", 363 | ) 364 | 365 | all_pres_predict = [x['precursors_predicts'] for x in all_resutls] 366 | 367 | return all_pres_predict 368 | 369 | 370 | def recommend_w_MatMiner( 371 | precursors_recommendator, 372 | test_targets_formulas, 373 | top_n=10, 374 | ): 375 | path_to_imputer = "../other_rsc/matminer/mp_imputer_preset_v1.0.2.pkl" 376 | path_to_scaler = "../other_rsc/matminer/mp_scaler_preset_v1.0.2.pkl" 377 | train_targets_features = run_multiprocessing_tasks( 378 | tasks=precursors_recommendator.train_targets_formulas, 379 | thread_func=MatMiner_features_for_formulas, 380 | func_args=( 381 | path_to_imputer, 382 | path_to_scaler, 383 | ), 384 | num_cores=4, 385 | join_results=True, 386 | use_threading=False, 387 | mp_context=None, 388 | ) 389 | test_targets_features = run_multiprocessing_tasks( 390 | tasks=test_targets_formulas, 391 | thread_func=MatMiner_features_for_formulas, 392 | func_args=( 393 | path_to_imputer, 394 | path_to_scaler, 395 | ), 396 | num_cores=4, 397 | join_results=True, 398 | use_threading=False, 399 | mp_context=None, 400 | ) 401 | train_targets_vecs = np.array(train_targets_features) 402 | test_targets_vecs = np.array(test_targets_features) 403 | 404 | train_targets_vecs = train_targets_vecs / ( 405 | np.linalg.norm( 406 | train_targets_vecs, 407 | axis=-1, 408 | keepdims=True, 409 | ) 410 | ) 411 | test_targets_vecs = test_targets_vecs / ( 412 | np.linalg.norm( 413 | test_targets_vecs, 414 | axis=-1, 415 | keepdims=True, 416 | ) 417 | ) 418 | all_distance = test_targets_vecs @ train_targets_vecs.T 419 | 420 | all_pres_predict, all_results = precursors_recommendator.recommend_precursors_by_similarity( 421 | test_targets_formulas=test_targets_formulas, 422 | train_targets_recipes=precursors_recommendator.train_targets_recipes, 423 | all_distance=all_distance, 424 | top_n=top_n, 425 | strategy="naive_common", 426 | ) 427 | 428 | all_pres_predict = [x['precursors_predicts'] for x in all_results] 429 | 430 | return all_pres_predict 431 | 432 | 433 | def recommend_w_FastText( 434 | precursors_recommendator, 435 | test_targets_formulas, 436 | test_targets, 437 | top_n=10, 438 | ): 439 | 440 | if pkgutil.find_loader("gensim"): 441 | import gensim 442 | else: 443 | warnings.warn( 444 | "FastText encoding needs the package gensim==3.8.3. " 445 | "You may want to install it with 'pip install gensim==3.8.3'." 
446 | ) 447 | 448 | path_model_fasttext = ( 449 | "../other_rsc/fasttext_pretrained_matsci/fasttext_embeddings-MINIFIED.model" 450 | ) 451 | 452 | if not os.path.exists(path_model_fasttext): 453 | warnings.warn( 454 | "You may want to download model and data for FastText encoding via " 455 | "download_optional_data() in scripts/_00_download_model_and_data.py. " 456 | ) 457 | 458 | fasttext = gensim.models.keyedvectors.KeyedVectors.load(path_model_fasttext) 459 | # Need to set this when loading from saved file 460 | fasttext.bucket = 2000000 461 | 462 | # load ele_order by statistics from text 463 | with open("../rsc/ele_order_counter.json", "r") as fr: 464 | stat_ele_order = json.load(fr) 465 | 466 | path_raw_reactions = "../rsc/reactions_v20_20210820_ss.jsonl" 467 | raw_reactions = load_raw_reactions(data_file=path_raw_reactions) 468 | print("len(raw_reactions)", len(raw_reactions)) 469 | 470 | # encode with fasttext 471 | train_targets_features = [] 472 | test_targets_features = [] 473 | fasttext_supported_train_targets_formulas = [] 474 | fasttext_supported_test_targets_formulas = [] 475 | for i, x in enumerate(precursors_recommendator.train_targets_formulas): 476 | try: 477 | human_formula = composition_to_human_formula( 478 | precursors_recommendator.train_targets[x]["comp"], 479 | raw_reactions[ 480 | list(precursors_recommendator.train_targets[x]["raw_index"])[0] 481 | ], 482 | precursors_recommendator.all_elements, 483 | stat_ele_order, 484 | ) 485 | except: 486 | print( 487 | "error in guess formula", 488 | x, 489 | list(precursors_recommendator.train_targets[x]["raw_index"]), 490 | ) 491 | 492 | try: 493 | train_targets_features.append(fasttext[human_formula.lower()]) 494 | fasttext_supported_train_targets_formulas.append(x) 495 | except: 496 | print("fasttext wrong train x skipped", x, human_formula) 497 | 498 | for i, x in enumerate(test_targets_formulas): 499 | try: 500 | human_formula = composition_to_human_formula( 501 | formula_to_array(x, precursors_recommendator.all_elements), 502 | raw_reactions[list(test_targets[x]["raw_index"])[0]], 503 | precursors_recommendator.all_elements, 504 | stat_ele_order, 505 | ) 506 | except: 507 | print( 508 | "error in guess formula", 509 | x, 510 | list(test_targets[x]["raw_index"]), 511 | ) 512 | 513 | try: 514 | test_targets_features.append(fasttext[human_formula.lower()]) 515 | fasttext_supported_test_targets_formulas.append(x) 516 | except: 517 | print("fasttext wrong test x skipped", x, human_formula) 518 | 519 | train_targets_formulas = fasttext_supported_train_targets_formulas 520 | test_targets_formulas = fasttext_supported_test_targets_formulas 521 | train_targets_vecs = np.array(train_targets_features) 522 | test_targets_vecs = np.array(test_targets_features) 523 | assert len(train_targets_formulas) == len(train_targets_vecs) 524 | assert len(test_targets_formulas) == len(test_targets_vecs) 525 | 526 | train_targets_vecs = train_targets_vecs / ( 527 | np.linalg.norm( 528 | train_targets_vecs, 529 | axis=-1, 530 | keepdims=True, 531 | ) 532 | ) 533 | test_targets_vecs = test_targets_vecs / ( 534 | np.linalg.norm( 535 | test_targets_vecs, 536 | axis=-1, 537 | keepdims=True, 538 | ) 539 | ) 540 | all_distance = test_targets_vecs @ train_targets_vecs.T 541 | 542 | train_targets_recipes = [ 543 | precursors_recommendator.train_targets[x] for x in train_targets_formulas 544 | ] 545 | 546 | all_pres_predict, all_results = precursors_recommendator.recommend_precursors_by_similarity( 547 | test_targets_formulas=test_targets_formulas, 548 
| train_targets_recipes=train_targets_recipes, 549 | all_distance=all_distance, 550 | top_n=top_n, 551 | strategy="naive_common", 552 | ) 553 | 554 | all_pres_predict = [x['precursors_predicts'] for x in all_results] 555 | 556 | return all_pres_predict, test_targets_formulas 557 | 558 | 559 | def recommend_w_freq( 560 | precursors_recommendator, 561 | test_targets_formulas, 562 | top_n=10, 563 | ): 564 | all_pres_predict = [] 565 | precursor_frequencies = precursors_recommendator.precursor_frequencies 566 | common_eles = set(["C", "H", "O", "N"]) 567 | nonvolatile_nonmetal_eles = { 568 | "P", 569 | "S", 570 | "Se", 571 | } 572 | 573 | for i, x in enumerate(test_targets_formulas): 574 | # prediction for precursors 575 | eles_x = set(get_elements_in_formula(x)) 576 | effective_eles_x = eles_x & set(precursor_frequencies.keys()) 577 | 578 | pres_multi_predicts = [] 579 | pres_candidates_by_ele = [ 580 | precursor_frequencies[ele] for ele in effective_eles_x 581 | ] 582 | if len(effective_eles_x & nonvolatile_nonmetal_eles) > 0: 583 | pres_candidates_by_ele_wo_nonmetal = [ 584 | precursor_frequencies[ele] 585 | for ele in (effective_eles_x - nonvolatile_nonmetal_eles) 586 | ] 587 | else: 588 | # no need to repeat iteration based on 589 | # pres_candidates_by_ele if no extra non-metal element 590 | pres_candidates_by_ele_wo_nonmetal = [] 591 | front_p_min = 10 592 | front_p_max = front_p_min 593 | for p_by_e in pres_candidates_by_ele: 594 | if len(p_by_e) > front_p_max: 595 | front_p_max = len(p_by_e) 596 | # front_p_min = front_p_max 597 | for front_p in range(front_p_min, front_p_max + 1): 598 | # get candidates in front first to reduce computational cost 599 | pres_candidates = [] 600 | pres_probabilities = [] 601 | for comb_i, pre_comb in enumerate( 602 | itertools.chain( 603 | itertools.product( 604 | *[p_by_e[:front_p] for p_by_e in pres_candidates_by_ele] 605 | ), 606 | itertools.product( 607 | *[ 608 | p_by_e[:front_p] 609 | for p_by_e in pres_candidates_by_ele_wo_nonmetal 610 | ] 611 | ), 612 | ) 613 | ): 614 | # It is safe to presume the first term (two precursors, 615 | # one for metal, one for nonmetal, comb_i==0) 616 | # has the largest frequency because the precursor 617 | # with both metal and nonmetal has low frequency, 618 | # which is always lower than the product of two 619 | # common precursors of the metal and the nonmetal 620 | # precursors recommended or not 621 | 622 | # make sure no duplication 623 | pres_formulas = tuple(sorted(set([x["formula"] for x in pre_comb]))) 624 | if pres_formulas in {x["precursors"] for x in pres_candidates}: 625 | # is_recommended = True 626 | continue 627 | 628 | # element matched or first attempt using all common precursors 629 | pres_eles = set(sum([x["elements"] for x in pre_comb], [])) 630 | if ( 631 | pres_eles.issubset(eles_x | common_eles) 632 | and eles_x.issubset( 633 | pres_eles 634 | | { 635 | "O", 636 | "H", 637 | } 638 | ) 639 | ) or (comb_i == 0): 640 | pres_prob = np.prod([x["frequency"] for x in pre_comb]) 641 | pres_candidates, pres_probabilities = add_to_sorted_list( 642 | items=pres_candidates, 643 | values=pres_probabilities, 644 | new_item={ 645 | "precursors": pres_formulas, 646 | "probability": pres_prob, 647 | "elements": pres_eles, 648 | }, 649 | new_value=pres_prob, 650 | ) 651 | if len(pres_candidates) > top_n: 652 | # sorting in add_to_sorted_list is from low to high 653 | pres_candidates.pop(0) 654 | pres_probabilities.pop(0) 655 | 656 | # check if candidates are sufficient 657 | if len(pres_candidates) >= top_n 
or front_p >= front_p_max: 658 | # sort candidates by probability 659 | pres_candidates = sorted( 660 | pres_candidates, 661 | key=lambda x: x["probability"], 662 | reverse=True, 663 | ) 664 | pres_multi_predicts = [ 665 | tuple(sorted(pres_cand["precursors"])) 666 | for pres_cand in pres_candidates[:top_n] 667 | ] 668 | break 669 | 670 | all_pres_predict.append(pres_multi_predicts) 671 | 672 | return all_pres_predict 673 | 674 | 675 | if __name__ == "__main__": 676 | # use_file_as_stdout("../generated/output.txt") 677 | run_recommendations() 678 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts/_06_computation_time_similarity.py: -------------------------------------------------------------------------------- 1 | """ 2 | Benchmark the time efficiency for batched and non-batched similarity calculation. 3 | """ 4 | 5 | import os 6 | import sys 7 | 8 | parent_folder = os.path.abspath( 9 | os.path.join( 10 | os.path.dirname(__file__), 11 | "../..", 12 | ) 13 | ) 14 | print("parent_folder", parent_folder) 15 | if parent_folder not in sys.path: 16 | sys.path.append(parent_folder) 17 | 18 | import numpy as np 19 | from timebudget import timebudget 20 | 21 | from SynthesisSimilarity.core.utils import formula_to_array 22 | from SynthesisSimilarity.core.mat_featurization import featurize_list_of_composition 23 | from SynthesisSimilarity.scripts_utils.recommendation_utils import ( 24 | collect_targets_in_reactions, 25 | ) 26 | from SynthesisSimilarity.scripts_utils.precursors_recommendation_utils import ( 27 | PrecursorsRecommendation, 28 | ) 29 | 30 | 31 | def similarity_time(): 32 | with timebudget("Loading model:"): 33 | precursors_recommendator = PrecursorsRecommendation( 34 | model_dir="../models/SynthesisRecommendation", 35 | freq_path="../rsc/pre_count_normalized_by_rxn_ss.json", 36 | data_path="../rsc/data_split.npz", 37 | all_to_knowledge_base=False, 38 | ) 39 | 40 | data_path = "../rsc/data_split.npz" 41 | test_data = np.load(data_path, allow_pickle=True) 42 | test_reactions = test_data["test_reactions"] 43 | ( 44 | test_targets, 45 | test_targets_formulas, 46 | test_targets_features, 47 | ) = collect_targets_in_reactions( 48 | test_reactions, 49 | precursors_recommendator.all_elements, 50 | precursors_recommendator.common_precursors_set, 51 | ) 52 | 53 | with timebudget("Evaluate similarity without batching"): 54 | all_distances = [] 55 | for x in test_targets_formulas: 56 | formulas = [x] 57 | # get target_candidate_normal_vecs 58 | # TODO: should this test_targets_compositions be ndarray? 
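            # The unbatched path below encodes one target formula at a time: the
            # formula is converted to a composition vector over
            # precursors_recommendator.all_elements, featurized, passed through
            # framework_model.get_mat_vector, and L2-normalized; the dot product
            # with precursors_recommendator.train_targets_vecs then gives one row
            # of similarity scores per formula (cosine similarity, assuming
            # train_targets_vecs is stored unit-normalized). The "with batching"
            # block further below performs the same steps once for the whole
            # formula list, which is the point of the timing comparison.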
59 | test_targets_compositions = [ 60 | formula_to_array(formula, precursors_recommendator.all_elements) 61 | for formula in formulas 62 | ] 63 | test_targets_features = featurize_list_of_composition( 64 | comps=test_targets_compositions, 65 | ele_order=precursors_recommendator.all_elements, 66 | featurizer_type=precursors_recommendator.featurizer_type, 67 | ) 68 | 69 | # TP similarity 70 | # train_targets_features is pre-transformed features 71 | # TODO: convert test_targets_features to np in advance 72 | test_targets_vecs = precursors_recommendator.framework_model.get_mat_vector( 73 | np.array(test_targets_features) 74 | ).numpy() 75 | 76 | test_targets_vecs = test_targets_vecs / ( 77 | np.linalg.norm(test_targets_vecs, axis=-1, keepdims=True) 78 | ) 79 | 80 | distance = test_targets_vecs @ precursors_recommendator.train_targets_vecs.T 81 | all_distances.append(distance) 82 | all_distances = np.concatenate( 83 | all_distances, 84 | axis=0, 85 | ) 86 | print("all_distances.shape", all_distances.shape) 87 | 88 | with timebudget("Evaluate similarity with batching"): 89 | all_distances = [] 90 | 91 | # get target_candidate_normal_vecs 92 | # TODO: should this test_targets_compositions be ndarray? 93 | test_targets_compositions = [ 94 | formula_to_array(formula, precursors_recommendator.all_elements) 95 | for formula in test_targets_formulas 96 | ] 97 | test_targets_features = featurize_list_of_composition( 98 | comps=test_targets_compositions, 99 | ele_order=precursors_recommendator.all_elements, 100 | featurizer_type=precursors_recommendator.featurizer_type, 101 | ) 102 | 103 | # TP similarity 104 | # train_targets_features is pre-transformed features 105 | # TODO: convert test_targets_features to np in advance 106 | test_targets_vecs = precursors_recommendator.framework_model.get_mat_vector( 107 | np.array(test_targets_features) 108 | ).numpy() 109 | 110 | test_targets_vecs = test_targets_vecs / ( 111 | np.linalg.norm(test_targets_vecs, axis=-1, keepdims=True) 112 | ) 113 | 114 | all_distances = ( 115 | test_targets_vecs @ precursors_recommendator.train_targets_vecs.T 116 | ) 117 | print("all_distances.shape", all_distances.shape) 118 | 119 | 120 | if __name__ == "__main__": 121 | similarity_time() 122 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | from ._00_download_model_and_data import download_necessary_data 2 | from ._00_download_model_and_data import download_optional_data 3 | 4 | __author__ = 'Tanjin He' 5 | __maintainer__ = 'Tanjin He' 6 | __email__ = 'tanjin_he@berkeley.edu' 7 | 8 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/FastTextSimilarity_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import json 4 | import itertools 5 | from functools import cmp_to_key 6 | 7 | from SynthesisSimilarity.core import utils 8 | 9 | 10 | def get_elements_of_material_in_raw_reactions(composition, element_substitution): 11 | elements = set() 12 | for comp in composition: 13 | elements.update(comp["elements"].keys()) 14 | 15 | for ele in elements & set(element_substitution.keys()): 16 | elements.add(element_substitution[ele]) 17 | elements = elements - set(element_substitution.keys()) 18 | 19 | return elements 20 | 21 | 22 | def elements_to_binary_composition(ele_in_mat, 
all_eles): 23 | binary_comp = np.zeros(shape=(len(all_eles)), dtype=np.float32) 24 | for ele in ele_in_mat: 25 | binary_comp[all_eles.index(ele)] = 1.0 26 | return binary_comp 27 | 28 | 29 | def get_binary_composition_of_material_in_raw_reactions( 30 | composition, 31 | element_substitution, 32 | all_eles, 33 | ): 34 | ele_in_mat = get_elements_of_material_in_raw_reactions( 35 | composition, element_substitution 36 | ) 37 | binary_comp = elements_to_binary_composition(ele_in_mat, all_eles) 38 | return binary_comp 39 | 40 | 41 | def collect_targets_in_raw_reactions( 42 | train_reactions, 43 | all_elements, 44 | common_precursors_set, 45 | exclude_common_precursors=False, 46 | ): 47 | 48 | train_targets = {} 49 | ref_precursors_comp = {} 50 | 51 | for r in train_reactions: 52 | tar_f = r["target"]["material_string"] 53 | 54 | pre_fs = set() 55 | for pre in r["precursors"]: 56 | # use material_formula for precursor but material_string for target 57 | # because we don't need to encode precursors with FastText 58 | # material_string can be name of materials sometimes 59 | non_hydrate_comp = pre["composition"] 60 | if len(non_hydrate_comp) > 1: 61 | non_hydrate_comp = list( 62 | filter( 63 | lambda x: ( 64 | set(x["elements"].keys()) != {"H", "O"} 65 | or x["formula"] == "H2O2" 66 | ), 67 | non_hydrate_comp, 68 | ) 69 | ) 70 | if len(non_hydrate_comp) == 1: 71 | try: 72 | unified_pre_formula = utils.dict_to_simple_formula( 73 | non_hydrate_comp[0]["elements"] 74 | ) 75 | except: 76 | unified_pre_formula = pre["material_formula"] 77 | else: 78 | unified_pre_formula = pre["material_formula"] 79 | pre_fs.add(unified_pre_formula) 80 | if unified_pre_formula not in ref_precursors_comp: 81 | ref_precursors_comp[ 82 | unified_pre_formula 83 | ] = get_binary_composition_of_material_in_raw_reactions( 84 | pre["composition"], 85 | r["reaction"]["element_substitution"], 86 | all_elements, 87 | ) 88 | pre_fs = tuple(sorted(pre_fs)) 89 | if exclude_common_precursors and set(pre_fs).issubset(common_precursors_set): 90 | continue 91 | if tar_f not in train_targets: 92 | train_targets[tar_f] = { 93 | "comp": get_binary_composition_of_material_in_raw_reactions( 94 | r["target"]["composition"], 95 | r["reaction"]["element_substitution"], 96 | all_elements, 97 | ), 98 | "pres": collections.Counter(), 99 | "syn_type": collections.Counter(), 100 | "raw_index": set(), 101 | "is_common": collections.Counter(), 102 | } 103 | 104 | train_targets[tar_f]["pres"][pre_fs] += 1 105 | train_targets[tar_f]["raw_index"].add(r["raw_index"]) 106 | 107 | if set(pre_fs).issubset(common_precursors_set): 108 | train_targets[tar_f]["is_common"]["common"] += 1 109 | else: 110 | train_targets[tar_f]["is_common"]["uncommon"] += 1 111 | if "synthesis_type" in r: 112 | train_targets[tar_f]["syn_type"][r["synthesis_type"]] += 1 113 | 114 | train_targets_formulas = list(train_targets.keys()) 115 | 116 | return train_targets, train_targets_formulas, ref_precursors_comp 117 | 118 | 119 | def save_ele_order( 120 | raw_reactions, 121 | save_path="generated/ele_order_counter.json", 122 | ): 123 | ele_order_counter = collections.Counter() 124 | 125 | for r in raw_reactions: 126 | for comp in r["target"]["composition"]: 127 | ele_substituion = r["reaction"]["element_substitution"] 128 | ele_pos = { 129 | ele: r["target"]["material_formula"].find(ele) 130 | for ele in comp["elements"] 131 | } 132 | for (e1, e2) in itertools.combinations(comp["elements"].keys(), 2): 133 | if ele_pos[e1] < 0: 134 | continue 135 | if ele_pos[e2] < 0: 136 | continue 137 
| 138 | e1_subbed = ele_substituion.get(e1, e1) 139 | e2_subbed = ele_substituion.get(e2, e2) 140 | if ele_pos[e1] <= ele_pos[e2]: 141 | left_ele = e1_subbed 142 | right_ele = e2_subbed 143 | else: 144 | left_ele = e2_subbed 145 | right_ele = e1_subbed 146 | ele_order_counter["{} before {}".format(left_ele, right_ele)] += 1 147 | 148 | print("len(ele_order_counter)", len(ele_order_counter)) 149 | print(ele_order_counter.most_common(100)) 150 | 151 | with open(save_path, "w") as fw: 152 | json.dump(ele_order_counter, fw, indent=2) 153 | 154 | 155 | def sort_elements_by_stat_order(elements, stat_ele_order): 156 | return sorted( 157 | elements, 158 | key=cmp_to_key( 159 | lambda ele_a, ele_b: compare_elements_by_stat_order( 160 | ele_a, ele_b, stat_ele_order 161 | ) 162 | ), 163 | ) 164 | 165 | 166 | def compare_elements_by_stat_order(ele_a, ele_b, stat_ele_order): 167 | a_before_b = stat_ele_order.get("{} before {}".format(ele_a, ele_b), 0) 168 | b_before_a = stat_ele_order.get("{} before {}".format(ele_b, ele_a), 0) 169 | return b_before_a - a_before_b 170 | 171 | 172 | def composition_to_human_formula( 173 | composition, raw_reaction, all_elements, stat_ele_order 174 | ): 175 | human_formula = None 176 | 177 | all_elements_indices = {ele: i for (i, ele) in enumerate(all_elements)} 178 | 179 | comp_eles = set(np.array(all_elements)[composition > 0]) 180 | 181 | ele_substituion = raw_reaction["reaction"]["element_substitution"] 182 | 183 | for comp in raw_reaction["target"]["composition"]: 184 | for (ele, num) in comp["elements"].items(): 185 | ele_subbed = ele_substituion.get(ele, ele) 186 | if ele_subbed not in comp_eles: 187 | continue 188 | if num.isdigit(): 189 | if abs(float(num)) < utils.NEAR_ZERO: 190 | continue 191 | ele_index = all_elements_indices[ele_subbed] 192 | composition = composition / composition[ele_index] * float(num) 193 | sorted_comp_eles = sort_elements_by_stat_order(comp_eles, stat_ele_order) 194 | human_formula = "" 195 | for ele in sorted_comp_eles: 196 | ele_index = all_elements_indices[ele] 197 | if composition[ele_index] == 1.0: 198 | human_formula += ele 199 | else: 200 | human_formula += ( 201 | "{}{:.3f}".format(ele, composition[ele_index]).rstrip("0").rstrip(".") 202 | ) 203 | return human_formula 204 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/MatminerSimilarity_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Sep 27 18:07:28 2019 5 | 6 | @author: chrisbartel 7 | """ 8 | import pdb 9 | 10 | import pkgutil 11 | 12 | if pkgutil.find_loader("matminer"): 13 | from matminer.featurizers.composition import ElementProperty 14 | 15 | from pymatgen.core import Composition 16 | import os 17 | import warnings 18 | import numpy as np 19 | import json 20 | import multiprocessing as mp 21 | from sklearn.impute import SimpleImputer 22 | from sklearn.preprocessing import StandardScaler 23 | import joblib 24 | 25 | 26 | def _dist(m1, m2): 27 | """ 28 | Args: 29 | m1 (1d-array) - feature vector 1 30 | m2 (1d-array) - feature vector 2 31 | 32 | Returns: 33 | Euclidean distance between vectors (float) 34 | """ 35 | 36 | return np.sqrt(np.sum([(m1[i] - m2[i]) ** 2 for i in range(len(m1))])) 37 | 38 | 39 | def _similarity(m1, m2): 40 | """ 41 | Args: 42 | m1 (1d-array) - feature vector 1 43 | m2 (1d-array) - feature vector 2 44 | 45 | Returns: 46 | inverse distance (similarity) between 
two vectors (float) 47 | """ 48 | 49 | return 1 / _dist(m1, m2) 50 | 51 | 52 | class MatminerSimilarity(object): 53 | def __init__( 54 | self, 55 | path_to_imputer, 56 | path_to_scaler, 57 | data_source="magpie", 58 | ): 59 | """ 60 | Args: 61 | path_to_imputer (str) - path to .pkl with SimpleImputer fit to MP 62 | path_to_scaler (str) - path to .pkl with StandardScaler fit to MP 63 | features (str) - 'magpie' ('pymatgen', etc. not implemented) 64 | data_source (str) - 'magpie' ('pymatgen', etc. not implemented) 65 | stats (list) - list of statistics (str) to manipulate features with 66 | 67 | Returns: 68 | list of features (str) 69 | loaded imputer and scaler 70 | 71 | """ 72 | 73 | if data_source == "magpie": 74 | self.data_source = "magpie" 75 | else: 76 | raise NotImplementedError 77 | 78 | if os.path.exists(path_to_imputer): 79 | self.imputer = joblib.load(path_to_imputer) 80 | else: 81 | self.imputer = None 82 | if os.path.exists(path_to_scaler): 83 | self.scaler = joblib.load(path_to_scaler) 84 | else: 85 | self.scaler = None 86 | 87 | def feature_vector(self, formula): 88 | vector = ElementProperty.from_preset(self.data_source).featurize( 89 | Composition(formula) 90 | ) 91 | vector = np.array(vector) 92 | return vector 93 | 94 | def feature_vector_normalized(self, formula): 95 | """ 96 | Args: 97 | formula (str) 98 | 99 | Returns: 100 | 1d array of feature values (float or NaN) 101 | """ 102 | vector = self.feature_vector(formula) 103 | imp, sc = self.imputer, self.scaler 104 | vector = np.expand_dims(vector, axis=0) 105 | vector = imp.transform(vector) 106 | vector = sc.transform(vector) 107 | vector = vector.squeeze(axis=0) 108 | return vector 109 | 110 | def compare(self, formula1, formula2): 111 | """ 112 | Args: 113 | formula1 (str) 114 | formula2 (str) 115 | 116 | Returns: 117 | inverse distance representation of similarity (float) 118 | """ 119 | m1 = self.feature_vector_normalized(formula1) 120 | m2 = self.feature_vector_normalized(formula2) 121 | return _similarity(m1, m2) 122 | 123 | 124 | def _get_feature_vector(obj, formula): 125 | return obj.feature_vector(formula) 126 | 127 | 128 | class RegenerateScalerImputer(object): 129 | """ 130 | For re-generating Scaler and Imputer when you e.g., change the feature or stats bases 131 | """ 132 | 133 | def __init__( 134 | self, 135 | path_to_formulas, 136 | inputs={ 137 | "data_source": "magpie", 138 | "path_to_scaler": "mp_scaler.pkl", 139 | "path_to_imputer": "mp_imputer.pkl", 140 | }, 141 | ): 142 | with open(path_to_formulas) as f: 143 | d = json.load(f) 144 | self.formulas = d["formulas"] 145 | self.inputs = inputs 146 | 147 | @property 148 | def MatminerSimilarityObject(self): 149 | inputs = self.inputs 150 | obj = MatminerSimilarity( 151 | path_to_imputer=inputs["path_to_imputer"], 152 | path_to_scaler=inputs["path_to_scaler"], 153 | data_source=inputs["data_source"], 154 | ) 155 | return obj 156 | 157 | @property 158 | def X(self): 159 | # THIS TAKES SEVERAL MINUTES ON MY NICE MAC 160 | obj = self.MatminerSimilarityObject 161 | formulas = self.formulas 162 | pool = mp.Pool(processes=mp.cpu_count() - 1) 163 | return [ 164 | r 165 | for r in pool.starmap( 166 | _get_feature_vector, [(obj, formula) for formula in formulas] 167 | ) 168 | ] 169 | 170 | @property 171 | def fit_imputer_and_scaler(self): 172 | X = self.X 173 | imp = SimpleImputer() 174 | imp.fit(X) 175 | X = imp.transform(X) 176 | sc = StandardScaler() 177 | sc.fit(X) 178 | 179 | fimp = self.inputs["path_to_imputer"] 180 | fsc = self.inputs["path_to_scaler"] 
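        # imp (SimpleImputer) and sc (StandardScaler) were fitted above on the
        # Magpie ElementProperty feature vectors of self.formulas (computed in
        # parallel by the X property). joblib.dump below persists them to the
        # configured paths so that MatminerSimilarity can reload the same
        # preprocessing with joblib.load in its __init__ and apply it in
        # feature_vector_normalized().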
181 | 182 | joblib.dump(imp, fimp) 183 | joblib.dump(sc, fsc) 184 | print("done") 185 | 186 | 187 | def MatMiner_features_for_formulas( 188 | formulas, 189 | path_to_imputer, 190 | path_to_scaler, 191 | ): 192 | if not pkgutil.find_loader("matminer"): 193 | warnings.warn( 194 | "Magpie encoding needs the package matminer and scikit-learn==1.0.2. " 195 | "You may want to install them with 'pip install matminer scikit-learn==1.0.2'. " 196 | ) 197 | 198 | if not os.path.exists(path_to_imputer): 199 | warnings.warn( 200 | "You may want to download model and data for Magpie encoding via " 201 | "download_optional_data() in scripts/_00_download_model_and_data.py. " 202 | ) 203 | 204 | all_features = [] 205 | obj = MatminerSimilarity( 206 | path_to_imputer=path_to_imputer, 207 | path_to_scaler=path_to_scaler, 208 | ) 209 | for x in formulas: 210 | all_features.append( 211 | obj.feature_vector_normalized(x), 212 | ) 213 | return all_features 214 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/TarMatSimilarity_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import os 4 | from pymatgen.core import Composition 5 | from SynthesisSimilarity.core import model_utils 6 | from SynthesisSimilarity.core import mat_featurization 7 | from SynthesisSimilarity.scripts_utils import similarity_utils 8 | 9 | __author__ = 'Tanjin He' 10 | __maintainer__ = 'Tanjin He' 11 | __email__ = 'tanjin_he@berkeley.edu' 12 | 13 | 14 | class TarMatSimilarity(object): 15 | def __init__(self, model_dir): 16 | """ 17 | Args: 18 | model_dir (str) - path to dir where the model is saved 19 | """ 20 | (self.model, self.model_config) = similarity_utils.load_encoding_model(model_dir) 21 | 22 | def feature_vector(self, formula): 23 | """ 24 | Args: 25 | formula (str) 26 | 27 | Returns: 28 | 1d array of feature values (float) 29 | """ 30 | mat_vector = self.get_mat_vector_by_formula(formula) 31 | return mat_vector 32 | 33 | def compare(self, formula1, formula2): 34 | """ 35 | Args: 36 | formula1 (str) 37 | formula2 (str) 38 | 39 | Returns: 40 | similarity by cosine distance of two vectors 41 | """ 42 | m1 = self.feature_vector(formula1) 43 | m2 = self.feature_vector(formula2) 44 | return self.cos_distance(m1, m2) 45 | 46 | def get_mat_vector_by_formula(self, formula): 47 | return self.get_mat_vector_by_formulas([formula])[0] 48 | 49 | def get_mat_vector_by_formulas(self, formulas): 50 | compositions = [] 51 | for f in formulas: 52 | mat = Composition(f).as_dict() 53 | compositions.append(mat) 54 | return self.get_mat_vector_by_compositions(compositions) 55 | 56 | def get_mat_vector_by_compositions(self, compositions): 57 | comps = [] 58 | for x in compositions: 59 | comps.append(similarity_utils.composition_to_array(x, self.model_config['all_eles'])) 60 | comps = mat_featurization.featurize_list_of_composition( 61 | comps=comps, 62 | ele_order=self.model_config['all_eles'], 63 | featurizer_type=self.model_config['featurizer_type'], 64 | ion_order=self.model_config['all_ions'], 65 | ) 66 | comps = np.array(comps) 67 | return self.model(comps) 68 | 69 | def cos_distance(self, vec_1, vec_2): 70 | dot = np.dot(vec_1, vec_2) 71 | norm_1 = np.linalg.norm(vec_1) 72 | norm_2 = np.linalg.norm(vec_2) 73 | return dot / (norm_1 * norm_2) 74 | 75 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .precursors_recommendation_utils import PrecursorsRecommendation 2 | 3 | 4 | __author__ = 'Tanjin He' 5 | __maintainer__ = 'Tanjin He' 6 | __email__ = 'tanjin_he@berkeley.edu' -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | # TODO: add evaluation and subfunctions from _05 2 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/data_set_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import random 3 | from pprint import pprint 4 | import warnings 5 | import numpy as np 6 | import tensorflow as tf 7 | from pymatgen.core import Composition 8 | 9 | from SynthesisSimilarity.scripts_utils import train_utils 10 | from SynthesisSimilarity.core import mat_featurization 11 | from SynthesisSimilarity.core import utils 12 | 13 | 14 | def load_and_generate_test_set( 15 | reload_path="generated/data.npz", 16 | featurizer_type="default", 17 | random_seed_str=None, 18 | all_elements=None, 19 | all_ions=None, 20 | ion_freq_threshold=0, 21 | common_ion_not_feature=False, 22 | ): 23 | ######################################### 24 | # load data 25 | ######################################### 26 | eles, reactions, ions, ion_counter = train_utils.load_synthesis_data( 27 | data_path=None, 28 | reload_path=reload_path, 29 | max_mats_num=6, 30 | reload=True, 31 | ) 32 | reactions, ions = train_utils.truncate_valence_array_in_reactions( 33 | reactions=reactions, 34 | all_ions=ions, 35 | ion_counter=ion_counter, 36 | ion_freq_threshold=ion_freq_threshold, 37 | common_ion_not_feature=common_ion_not_feature, 38 | ) 39 | 40 | if all_elements is not None: 41 | if all_elements != eles: 42 | warnings.warn( 43 | "all_elements (from model) != eles (from dataset)! " 44 | "Check the dataset version to make sure they are " 45 | "consistent with each other! Temporarily, " 46 | "all_elements from model is used. 
" 47 | ) 48 | # map indices of eles to all_elements 49 | assert len(eles) == len(all_elements) 50 | mapping_indices = [eles.index(e) for e in all_elements] 51 | for r in reactions: 52 | for i in range(len(r["target_comp"])): 53 | r["target_comp"][i] = r["target_comp"][i][mapping_indices] 54 | for i in range(len(r["precursors_comp"])): 55 | for j in range(len(r["precursors_comp"][i])): 56 | r["precursors_comp"][i][j] = r["precursors_comp"][i][j][ 57 | mapping_indices 58 | ] 59 | eles = all_elements 60 | 61 | if all_ions is not None: 62 | assert all_ions == ions 63 | # TODO: add code to allow all_ions to be a subset of ions 64 | 65 | reactions, mat_feature_len = mat_featurization.featurize_reactions( 66 | reactions, 67 | ele_order=eles, 68 | featurizer_type=featurizer_type, 69 | ion_order=ions, 70 | ) 71 | print("mat_feature_len", mat_feature_len) 72 | 73 | # random seed 74 | if random_seed_str: 75 | print("random_seed_str", random_seed_str) 76 | random_seed = sum(map(ord, random_seed_str.strip())) 77 | else: 78 | random_seed = None 79 | 80 | # split data to train/val/test sets 81 | train_reactions, val_reactions, test_reactions = utils.split_reactions( 82 | reactions, 83 | val_frac=0.05, 84 | test_frac=0.10, 85 | # keys=('doi', 'raw_index', 'target_comp'), 86 | keys=( 87 | "doi", 88 | "raw_index", 89 | "target_comp", 90 | "prototype_path", 91 | ), 92 | # keys=('raw_index',), 93 | # keys=(), 94 | random_seed=random_seed, 95 | by_year=True, 96 | ) 97 | 98 | num_train_reactions = train_utils.get_num_reactions(train_reactions) 99 | ele_counts = train_utils.get_ele_counts(train_reactions) 100 | print("num_train_reactions", num_train_reactions) 101 | print("len(eles)", len(eles)) 102 | print("len(ions)", len(ions)) 103 | print("len(ele_counts)", len(ele_counts)) 104 | print([(e, c) for (e, c) in zip(eles, ele_counts)]) 105 | 106 | ######################################### 107 | # get train, val, test in batch format 108 | ######################################### 109 | train_X, train_Y = train_utils.train_data_generator( 110 | train_reactions, 111 | num_batch=2000, 112 | max_mats_num=6, 113 | batch_size=8, 114 | ) 115 | train_XY = tf.data.Dataset.zip((train_X, train_Y)) 116 | train_XY = train_XY.prefetch(buffer_size=10) 117 | 118 | # print(next(iter(train_XY.unbatch()))) 119 | 120 | val_X, val_Y = train_utils.prepare_dataset( 121 | val_reactions, 122 | max_mats_num=6, 123 | batch_size=8, 124 | sampling_ratio=1e-3, 125 | random_seed=random_seed, 126 | ) 127 | test_X, test_Y = train_utils.prepare_dataset( 128 | test_reactions, 129 | max_mats_num=6, 130 | batch_size=8, 131 | sampling_ratio=1e-3, 132 | random_seed=random_seed, 133 | ) 134 | 135 | data = { 136 | "train_reactions": train_reactions, 137 | "val_reactions": val_reactions, 138 | "test_reactions": test_reactions, 139 | "train_X": train_X, 140 | "val_X": val_X, 141 | "test_X": test_X, 142 | "all_eles": eles, 143 | "all_ions": ions, 144 | } 145 | return data 146 | 147 | 148 | if __name__ == "__main__": 149 | 150 | from SynthesisSimilarity.core import model_utils 151 | 152 | print("---------------------loading data------------------------------") 153 | model_dir = "../models/SynthesisRecommendation" 154 | npz_reload_path = "../rsc_preparation/data_ss.npz" 155 | framework_model, model_config = model_utils.load_framework_model(model_dir) 156 | all_elements = model_config["all_eles"] 157 | featurizer_type = model_config["featurizer_type"] 158 | 159 | test_data = load_and_generate_test_set( 160 | reload_path=npz_reload_path, 161 | 
featurizer_type=featurizer_type, 162 | all_elements=all_elements, 163 | ) 164 | train_reactions = test_data["train_reactions"] 165 | val_reactions = test_data["val_reactions"] 166 | test_reactions = test_data["test_reactions"] 167 | 168 | print("---------------------saving data------------------------------") 169 | print("len(train_reactions)", len(train_reactions)) 170 | print("len(val_reactions)", len(val_reactions)) 171 | print("len(test_reactions)", len(test_reactions)) 172 | npz_save_path = "../rsc/data_split.npz" 173 | np.savez( 174 | npz_save_path, 175 | train_reactions=train_reactions, 176 | val_reactions=val_reactions, 177 | test_reactions=test_reactions, 178 | ) 179 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/multi_processing_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import time 5 | from pprint import pprint 6 | import numpy as np 7 | import multiprocessing as mp 8 | import argparse 9 | import sys 10 | 11 | import concurrent.futures 12 | 13 | __author__ = 'Tanjin He' 14 | __maintainer__ = 'Tanjin He' 15 | __email__ = 'tanjin_he@berkeley.edu' 16 | 17 | 18 | def run_multiprocessing_tasks( 19 | tasks, 20 | thread_func, 21 | func_args=(), 22 | num_cores=4, 23 | verbose=False, 24 | join_results=False, 25 | use_threading=False, 26 | mp_context=None, 27 | ): 28 | # execute pipeline in a parallel way 29 | last_time = time.time() 30 | 31 | # get parallel_arguments 32 | if tasks: 33 | parallel_arguments = [] 34 | num_tasks_per_core = math.ceil(len(tasks)/num_cores) 35 | for i in range(num_cores): 36 | parallel_arguments.append( 37 | (tasks[i*num_tasks_per_core: (i+1)*num_tasks_per_core], ) + func_args 38 | ) 39 | else: 40 | parallel_arguments = [func_args] * num_cores 41 | 42 | if not use_threading: 43 | # running using mp 44 | # use 'spawn' for tf 45 | mp_ctx = mp.get_context(mp_context) 46 | p = mp_ctx.Pool(processes=num_cores) 47 | all_summary = p.starmap(thread_func, parallel_arguments) 48 | p.close() 49 | p.join() 50 | else: 51 | # running using threading 52 | with concurrent.futures.ThreadPoolExecutor() as executor: 53 | futures =[ 54 | executor.submit(thread_func, *(parallel_arguments[i])) 55 | for i in range(num_cores) 56 | ] 57 | all_summary = [f.result() for f in futures] 58 | 59 | if verbose: 60 | # TODO: maybe the type of tmp_summary here is not very correct 61 | # reading results 62 | print('time used:', time.time()-last_time) 63 | if isinstance(all_summary[0], dict) and 'success_tasks' in all_summary[0]: 64 | # combine all results 65 | all_success_tasks = sum([tmp_summary['success_tasks'] for tmp_summary in all_summary], []) 66 | print('len(all_success_tasks)', len(all_success_tasks)) 67 | 68 | if isinstance(all_summary[0], dict) and 'error_tasks' in all_summary[0]: 69 | # combine all error tasks 70 | all_error_tasks = sum([tmp_summary['error_tasks'] for tmp_summary in all_summary], []) 71 | print('len(all_error_tasks)', len(all_error_tasks)) 72 | 73 | if join_results and isinstance(all_summary[0], list): 74 | # when output is a single variable, the mp output is a list with length of cores, 75 | # where each element is a list of results from each processor. 
76 | # Therefore, need to sum to combine results 77 | last_results = sum(all_summary, []) 78 | elif join_results and isinstance(all_summary[0], tuple): 79 | # when output is multiple variables (a tuple), the mp output is a tuple, 80 | # where each variable is a list of results from each processor. 81 | # Therefore, need to sum to combine results in each variable 82 | last_results = [] 83 | for i in range(len(all_summary[0])): 84 | last_results.append([x[i] for x in all_summary]) 85 | else: 86 | last_results = all_summary 87 | 88 | return last_results 89 | 90 | def save_results(results, dir_path='../generated/results', prefix='results'): 91 | if not os.path.exists(dir_path): 92 | os.makedirs(dir_path) 93 | file_name = '{}_{}.json'.format(prefix, str(hash(str(results)))) 94 | with open(os.path.join(dir_path, file_name), 'w') as fw: 95 | json.dump(results, fw, indent=2) 96 | return file_name 97 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/reaction_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import chain, combinations 3 | from pymatgen.core import Composition, Element 4 | from typing import Dict, List, Optional, Tuple, Union 5 | from copy import deepcopy 6 | 7 | from SynthesisSimilarity.core.utils import get_elements_in_formula 8 | 9 | 10 | TOLERANCE = 1e-6 # Tolerance for determining if a particular component fraction is > 0. 11 | 12 | 13 | def _balance_coeffs( 14 | reactants: List[Composition], products: List[Composition] 15 | ) -> Tuple[np.ndarray, Union[int, float], int]: 16 | """ 17 | Balances the reaction and returns the new coefficient matrix 18 | 19 | Adapted from reaction-network 20 | McDermott, M. J., Dwaraknath, S. S., and Persson, K. A. (2021). A graph-based network for predicting chemical reaction pathways in solid-state materials synthesis. Nature Communications, 12(1). 
https://doi.org/10.1038/s41467-021-23339-x 21 | """ 22 | compositions = reactants + products 23 | num_comp = len(compositions) 24 | 25 | all_elems = sorted({elem for c in compositions for elem in c.elements}) 26 | num_elems = len(all_elems) 27 | 28 | comp_matrix = np.array([[c[el] for el in all_elems] for c in compositions]).T 29 | 30 | rank = np.linalg.matrix_rank(comp_matrix) 31 | diff = num_comp - rank 32 | num_constraints = diff if diff >= 2 else 1 33 | 34 | # an error = a component changing sides or disappearing 35 | lowest_num_errors = np.inf 36 | 37 | first_product_idx = len(reactants) 38 | 39 | # start with simplest product constraints, work to more complex constraints 40 | product_constraints = chain.from_iterable( 41 | [ 42 | combinations(range(first_product_idx, num_comp), n_constr) 43 | for n_constr in range(num_constraints, 0, -1) 44 | ] 45 | ) 46 | reactant_constraints = chain.from_iterable( 47 | [ 48 | combinations(range(0, first_product_idx), n_constr) 49 | for n_constr in range(num_constraints, 0, -1) 50 | ] 51 | ) 52 | best_soln = np.zeros(num_comp) 53 | 54 | for constraints in chain(product_constraints, reactant_constraints): 55 | n_constr = len(constraints) 56 | 57 | comp_and_constraints = np.append( 58 | comp_matrix, np.zeros((n_constr, num_comp)), axis=0 59 | ) 60 | b = np.zeros((num_elems + n_constr, 1)) 61 | b[-n_constr:] = 1 if min(constraints) >= first_product_idx else -1 62 | 63 | for num, idx in enumerate(constraints): 64 | comp_and_constraints[num_elems + num, idx] = 1 65 | # arbitrarily fix coeff to 1 66 | 67 | coeffs = np.matmul(np.linalg.pinv(comp_and_constraints), b) 68 | 69 | num_errors = 0 70 | if np.allclose(np.matmul(comp_matrix, coeffs), np.zeros((num_elems, 1))): 71 | expected_signs = np.array([-1] * len(reactants) + [+1] * len(products)) 72 | num_errors = np.sum(np.multiply(expected_signs, coeffs.T) < TOLERANCE) 73 | if num_errors == 0: 74 | lowest_num_errors = 0 75 | best_soln = coeffs 76 | break 77 | if num_errors < lowest_num_errors: 78 | lowest_num_errors = num_errors 79 | best_soln = coeffs 80 | 81 | return np.squeeze(best_soln), lowest_num_errors, num_constraints 82 | 83 | 84 | def balance_w_rxn_network( 85 | target_formula, 86 | precursors_formulas, 87 | ref_materials_comp, 88 | ): 89 | rxn_predict = [ 90 | target_formula, 91 | { 92 | 'left': {}, 93 | 'right': {}, 94 | }, 95 | None, 96 | '', 97 | ] 98 | # balance the reaction 99 | pres = [] 100 | for x in precursors_formulas: 101 | if x in ref_materials_comp: 102 | pres.append( 103 | ref_materials_comp[x]['material_formula'] 104 | ) 105 | else: 106 | pres.append(x) 107 | tars = [target_formula] + ['H2O', 'CO2', 'NH3', 'O2', 'NO2', ]  # target plus common volatile by-products 108 | all_mats = pres + tars 109 | coeffs, lowest_num_errors, num_constraints = _balance_coeffs( 110 | reactants=[Composition(x) for x in pres], 111 | products=[Composition(x) for x in tars], 112 | ) 113 | # normalize coefficients so that the target coefficient is 1 114 | coeffs = coeffs / max(np.abs(coeffs[len(pres)]), TOLERANCE) 115 | 116 | coeffs = np.round(coeffs, 3) 117 | # reaction should be solvable 118 | if np.isinf(lowest_num_errors): 119 | rxn_predict = None 120 | # coefficients should be negative or zero for precursors 121 | for i, pre in enumerate(pres): 122 | if coeffs[i] > TOLERANCE: 123 | rxn_predict = None 124 | # coefficient of the target should be positive 125 | if coeffs[len(pres)] < TOLERANCE: 126 | rxn_predict = None 127 | # format rxn_predict as output (negative coefficients -> left/precursors, positive -> right/products) 128 | if rxn_predict is not None: 129 | for i, coeff in enumerate(coeffs): 130 | if coeff < -TOLERANCE: 131 |
rxn_predict[1]['left'][all_mats[i]] = -coeff 132 | elif coeff > TOLERANCE: 133 | rxn_predict[1]['right'][all_mats[i]] = coeff 134 | for mat, coeff in rxn_predict[1]['left'].items(): 135 | rxn_predict[3] += f'{coeff} {mat} + ' 136 | rxn_predict[3] = rxn_predict[3].strip('+ ') 137 | rxn_predict[3] += ' == ' 138 | for mat, coeff in rxn_predict[1]['right'].items(): 139 | rxn_predict[3] += f'{coeff} {mat} + ' 140 | rxn_predict[3] = rxn_predict[3].strip('+ ') 141 | rxn_predict = tuple(rxn_predict) 142 | 143 | return rxn_predict 144 | 145 | 146 | def are_coefficients_positive( 147 | target, 148 | precursors, 149 | reaction 150 | ): 151 | is_positive = True 152 | mat_coeff = { 153 | **reaction[1]['left'], 154 | **reaction[1]['right'], 155 | } 156 | for pre in precursors: 157 | if float(mat_coeff.get(pre, 0)) < 0 or pre in reaction[1]['right']: 158 | is_positive = False 159 | if float(mat_coeff[target]) <= 0 or target in reaction[1]['left']: 160 | is_positive = False 161 | return is_positive 162 | 163 | 164 | def clear_zero_coeff_precursors( 165 | precursors, 166 | reaction, 167 | ): 168 | pres_out = [] 169 | for pre in precursors: 170 | if pre in reaction[1]['left'] and float(reaction[1]['left'][pre]) > 0: 171 | pres_out.append(pre) 172 | return tuple(sorted(pres_out)) 173 | 174 | def reaction_coeff_to_float(reaction): 175 | reaction_out = deepcopy(reaction) 176 | for k in reaction_out[1]['left']: 177 | reaction_out[1]['left'][k] = float(reaction_out[1]['left'][k]) 178 | for k in reaction_out[1]['right']: 179 | reaction_out[1]['right'][k] = float(reaction_out[1]['right'][k]) 180 | return reaction_out -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/recommendation_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import bisect 3 | 4 | from SynthesisSimilarity.core import utils 5 | 6 | 7 | __author__ = "Tanjin He" 8 | __maintainer__ = "Tanjin He" 9 | __email__ = "tanjin_he@berkeley.edu" 10 | 11 | 12 | def collect_targets_in_reactions( 13 | train_reactions, 14 | all_elements, 15 | common_precursors_set, 16 | exclude_common_precursors=False, 17 | ): 18 | # TODO: clean the name of train_xx 19 | raw_indices_train = set() 20 | train_targets = {} 21 | for r in train_reactions: 22 | tar_f = utils.array_to_formula(r["target_comp"][0], all_elements) 23 | if len(r["target_comp"]) > 1: 24 | print("len(r['target_comp'])", len(r["target_comp"])) 25 | assert len(r["target_comp"]) == 1, "Reaction not expanded" 26 | for x in r["precursors_comp"]: 27 | assert len(x) == 1, "Reaction not expanded" 28 | pre_fs = set( 29 | [utils.array_to_formula(x[0], all_elements) for x in r["precursors_comp"]] 30 | ) 31 | assert len(pre_fs) == len( 32 | r["precursors_comp"] 33 | ), "len(pre_fs) != len(r['precursors_comp'])" 34 | pre_fs = tuple(sorted(pre_fs)) 35 | if exclude_common_precursors and set(pre_fs).issubset(common_precursors_set): 36 | continue 37 | if tar_f not in train_targets: 38 | train_targets[tar_f] = { 39 | "comp": r["target_comp"][0], 40 | "comp_fea": r["target_comp_featurized"][0], 41 | "pres": collections.Counter(), 42 | "syn_type": collections.Counter(), 43 | "syn_type_pres": collections.Counter(), 44 | "raw_index": set(), 45 | "is_common": collections.Counter(), 46 | "pres_raw_index": {}, 47 | } 48 | train_targets[tar_f]["pres"][pre_fs] += 1 49 | train_targets[tar_f]["raw_index"].add(r["raw_index"]) 50 | if pre_fs not in train_targets[tar_f]["pres_raw_index"]: 51 | 
train_targets[tar_f]["pres_raw_index"][pre_fs] = [] 52 | train_targets[tar_f]["pres_raw_index"][pre_fs].append(r["raw_index"]) 53 | raw_indices_train.add(r["raw_index"]) 54 | if set(pre_fs).issubset(common_precursors_set): 55 | train_targets[tar_f]["is_common"]["common"] += 1 56 | else: 57 | train_targets[tar_f]["is_common"]["uncommon"] += 1 58 | if "synthesis_type" in r: 59 | train_targets[tar_f]["syn_type"][r["synthesis_type"]] += 1 60 | train_targets[tar_f]["syn_type_pres"][(r["synthesis_type"],) + pre_fs] += 1 61 | 62 | train_targets_formulas = list(train_targets.keys()) 63 | # TODO: shall we make this np.ndarray? 64 | train_targets_features = [ 65 | train_targets[x]["comp_fea"] for x in train_targets_formulas 66 | ] 67 | print("len(train_targets)", len(train_targets)) 68 | return train_targets, train_targets_formulas, train_targets_features 69 | 70 | 71 | def add_to_sorted_list(items, values, new_item, new_value): 72 | new_idx = bisect.bisect_left(values, new_value) 73 | items.insert(new_idx, new_item) 74 | values.insert(new_idx, new_value) 75 | return items, values 76 | -------------------------------------------------------------------------------- /SynthesisSimilarity/scripts_utils/similarity_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | __author__ = 'Tanjin He' 8 | __maintainer__ = 'Tanjin He' 9 | __email__ = 'tanjin_he@berkeley.edu' 10 | 11 | NEAR_ZERO = 1e-6 12 | 13 | # TODO: do we need this gpu functions? 14 | def print_gpu_info(): 15 | gpus = tf.config.experimental.list_physical_devices('GPU') 16 | logical_gpus = tf.config.experimental.list_logical_devices('GPU') 17 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 18 | 19 | def allow_gpu_growth(): 20 | gpus = tf.config.experimental.list_physical_devices('GPU') 21 | if gpus: 22 | try: 23 | # Currently, memory growth needs to be the same across GPUs 24 | for gpu in gpus: 25 | tf.config.experimental.set_memory_growth(gpu, True) 26 | logical_gpus = tf.config.experimental.list_logical_devices('GPU') 27 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 28 | except RuntimeError as e: 29 | # Memory growth must be set before GPUs have been initialized 30 | print(e) 31 | 32 | if os.environ.get('tf_allow_gpu_growth', 'False') != 'True': 33 | allow_gpu_growth() 34 | os.environ['tf_allow_gpu_growth'] = 'True' 35 | 36 | def composition_to_array(composition, elements): 37 | comp_array = np.zeros((len(elements), ), dtype=np.float32) 38 | for c, v in composition.items(): 39 | comp_array[elements.index(c)] = v 40 | comp_array /= max(np.sum(comp_array), NEAR_ZERO) 41 | return comp_array 42 | 43 | def load_encoding_model(model_dir): 44 | model_path = os.path.join(model_dir, 'saved_model') 45 | model = tf.saved_model.load(model_path) 46 | with open(os.path.join(model_dir, 'model_config.json'), 'r') as fr: 47 | model_config = json.load(fr) 48 | if 'all_ions' in model_config: 49 | model_config['all_ions'] = [ 50 | tuple(x) for x in model_config['all_ions'] 51 | ] 52 | return model, model_config -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pymatgen 2 | adjustText 3 | colorcet 4 | psutil 5 | seaborn 6 | jsonlines 7 | matplotlib 8 | tensorflow==2.7.0 9 | protobuf==3.19.6 10 | regex 11 | timebudget 12 | scikit-learn 13 | gdown 14 | 
-------------------------------------------------------------------------------- /requirements_optional.txt: -------------------------------------------------------------------------------- 1 | # These packages are not necessary if only using PrecursorSelector for precursor recommendation. 2 | # matminer 3 | # scikit-learn==1.0.2 4 | # gensim==3.8.3 5 | # 6 | # unidecode 7 | # Synthepedia 8 | # ValenceSolver 9 | # synthesis_dataset 10 | # ReactionCompleter 11 | 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | __author__ = 'Tanjin He' 5 | __maintainer__ = 'Tanjin He' 6 | __email__ = 'tanjin_he@berkeley.edu' 7 | 8 | 9 | if __name__ == "__main__": 10 | setup(name='SynthesisSimilarity', 11 | version='1.0.0', 12 | author="Tanjin He", 13 | author_email="tanjin_he@berkeley.edu", 14 | license="MIT License", 15 | packages=find_packages(), 16 | include_package_data=True, 17 | install_requires=[ 18 | 'pymatgen', 19 | 'adjustText', 20 | 'colorcet', 21 | 'psutil', 22 | 'seaborn', 23 | 'jsonlines', 24 | 'matplotlib', 25 | 'tensorflow==2.7.0', 26 | 'protobuf==3.19.6', 27 | 'regex', 28 | 'timebudget', 29 | 'scikit-learn', 30 | 'gdown', 31 | ], 32 | zip_safe=False) 33 | 34 | --------------------------------------------------------------------------------
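Usage note: as a minimal, hedged sketch (not a documented entry point), the snippet below shows how the `TarMatSimilarity` wrapper from `SynthesisSimilarity/scripts_utils/TarMatSimilarity_utils.py` above might be called to compare two target materials with the target-material encoding model. The model directory (the bundled `SynthesisSimilarity/models/SynthesisEncoding`) and the example formulas are illustrative assumptions; adjust them to your local checkout.

```python
# Hedged sketch (assumed model path and formulas): encode two target materials
# and compute the cosine similarity of their encoded vectors.
from SynthesisSimilarity.scripts_utils.TarMatSimilarity_utils import TarMatSimilarity

tar_sim = TarMatSimilarity(model_dir="SynthesisSimilarity/models/SynthesisEncoding")
score = tar_sim.compare("LiFePO4", "LiMnPO4")  # cosine similarity of encoded vectors
print(score)
```

The bundled script `SynthesisSimilarity/scripts/_02_target_material_similarity.py` appears to exercise the same comparison in more detail.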