├── .gitattributes ├── CSPML (2022).yml ├── CSPML.py ├── CSPML_latest_codes ├── CSPML.yml ├── CSPML_Creating_MLdata.ipynb ├── CSPML_Structure_Prediction.ipynb ├── CSPML_training.ipynb ├── Create_strcmp_fgp.ipynb ├── KmdPlus.py ├── cif_files_for_90crystals.zip ├── data_set │ ├── CSML_90_opt_results.pd.xz │ ├── CSPML_models.xz │ ├── MP_stable_20211107.pd.xz │ ├── all_searching_targets_20211107_with_predictions.pd.xz │ ├── cmpfgp_stable_meanstd_20211107.npy │ ├── element_dissimilarity.npy │ ├── element_features.csv │ ├── paper_used_mp_data_20211107.pd.xz │ └── strfgp_test_20211107.npy ├── readme.txt └── tools.ipynb ├── LICENSE ├── README.md ├── data_set ├── MP_candidates.pkl ├── MP_structures.pkl ├── candidates_paper.pkl ├── descriptor_standardization.pkl ├── element_dissimilarity.pkl ├── model1_tau=0.3 ├── model2_tau=0.3 ├── model3_tau=0.3 ├── model4_tau=0.3 └── model5_tau=0.3 ├── read_me.txt └── tutorial.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | data_set/* filter=lfs diff=lfs merge=lfs -text 2 | CSPML_latest_codes/data_set/* filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /CSPML (2022).yml: -------------------------------------------------------------------------------- 1 | name: CSPML 2 | channels: 3 | - apple 4 | - conda-forge 5 | dependencies: 6 | - ca-certificates=2024.3.11=haa95532_0 7 | - lz4-c=1.9.4=h2bbff1b_1 8 | - openssl=1.1.1w=h2bbff1b_0 9 | - pip=24.0=py38haa95532_0 10 | - python=3.8.8=hdbf39b2_5 11 | - setuptools=69.5.1=py38haa95532_0 12 | - sqlite=3.45.3=h2bbff1b_0 13 | - vc=14.2=h2eaa2aa_1 14 | - vs2015_runtime=14.29.30133=h43f2093_3 15 | - wheel=0.43.0=py38haa95532_0 16 | - xz=5.4.6=h8cc25b3_1 17 | - zlib=1.2.13=h8cc25b3_1 18 | - zstd=1.5.5=hd43e919_2 19 | - pip: 20 | - anyio==4.4.0 21 | - argon2-cffi==23.1.0 22 | - argon2-cffi-bindings==21.2.0 23 | - arrow==1.3.0 24 | - asttokens==2.4.1 25 | - async-lru==2.0.4 26 | - attrs==23.2.0 27 | - babel==2.15.0 28 | - backcall==0.2.0 29 | - beautifulsoup4==4.12.3 30 | - bleach==6.1.0 31 | - certifi==2024.6.2 32 | - cffi==1.16.0 33 | - charset-normalizer==3.3.2 34 | - colorama==0.4.6 35 | - comm==0.2.2 36 | - contourpy==1.1.1 37 | - cycler==0.12.1 38 | - debugpy==1.8.1 39 | - decorator==5.1.1 40 | - defusedxml==0.7.1 41 | - exceptiongroup==1.2.1 42 | - executing==2.0.1 43 | - fastjsonschema==2.20.0 44 | - fonttools==4.53.0 45 | - fqdn==1.5.1 46 | - h11==0.14.0 47 | - httpcore==1.0.5 48 | - httpx==0.27.0 49 | - idna==3.7 50 | - importlib-metadata==7.2.0 51 | - importlib-resources==6.4.0 52 | - ipykernel==6.29.4 53 | - ipython==8.12.3 54 | - ipywidgets==8.1.3 55 | - isoduration==20.11.0 56 | - jedi==0.19.1 57 | - jinja2==3.1.4 58 | - joblib==1.4.2 59 | - json5==0.9.25 60 | - jsonpointer==3.0.0 61 | - jsonschema==4.22.0 62 | - jsonschema-specifications==2023.12.1 63 | - jupyter==1.0.0 64 | - jupyter-client==8.6.2 65 | - jupyter-console==6.6.3 66 | - jupyter-core==5.7.2 67 | - jupyter-events==0.10.0 68 | - jupyter-lsp==2.2.5 69 | - jupyter-server==2.14.1 70 | - jupyter-server-terminals==0.5.3 71 | - jupyterlab==4.2.2 72 | - jupyterlab-pygments==0.3.0 73 | - jupyterlab-server==2.27.2 74 | - jupyterlab-widgets==3.0.11 75 | - keras==2.6.0 76 | - kiwisolver==1.4.5 77 | - markupsafe==2.1.5 78 | - matplotlib==3.7.5 79 | - matplotlib-inline==0.1.7 80 | - mistune==3.0.2 81 | - monty==2023.9.25 82 | - mordred==1.2.0 83 | - mpmath==1.3.0 84 | - nbclient==0.10.0 85 | - nbconvert==7.16.4 86 | - nbformat==5.10.4 87 
| - nest-asyncio==1.6.0 88 | - networkx==2.8.8 89 | - notebook==7.2.1 90 | - notebook-shim==0.2.4 91 | - numpy==1.19.2 92 | - overrides==7.7.0 93 | - packaging==24.1 94 | - palettable==3.3.3 95 | - pandocfilters==1.5.1 96 | - parso==0.8.4 97 | - pickleshare==0.7.5 98 | - pillow==10.3.0 99 | - pkgutil-resolve-name==1.3.10 100 | - platformdirs==4.2.2 101 | - plotly==5.22.0 102 | - prometheus-client==0.20.0 103 | - prompt-toolkit==3.0.47 104 | - protobuf==3.20.3 105 | - psutil==6.0.0 106 | - pure-eval==0.2.2 107 | - pycparser==2.22 108 | - pygments==2.18.0 109 | - pyparsing==3.1.2 110 | - python-dateutil==2.9.0.post0 111 | - python-json-logger==2.0.7 112 | - pytz==2024.1 113 | - pywin32==306 114 | - pywinpty==2.0.13 115 | - pyyaml==5.4.1 116 | - pyzmq==26.0.3 117 | - qtconsole==5.5.2 118 | - qtpy==2.4.1 119 | - rdkit==2024.3.1 120 | - referencing==0.35.1 121 | - requests==2.32.3 122 | - rfc3339-validator==0.1.4 123 | - rfc3986-validator==0.1.1 124 | - rpds-py==0.18.1 125 | - ruamel-yaml==0.18.6 126 | - ruamel-yaml-clib==0.2.8 127 | - scikit-learn==1.3.2 128 | - scipy==1.10.1 129 | - seaborn==0.13.2 130 | - send2trash==1.8.3 131 | - sniffio==1.3.1 132 | - soupsieve==2.5 133 | - spglib==2.4.0 134 | - stack-data==0.6.3 135 | - sympy==1.12.1 136 | - tabulate==0.9.0 137 | - tenacity==8.4.1 138 | - tensorflow==2.6.0 139 | - terminado==0.18.1 140 | - threadpoolctl==3.5.0 141 | - tinycss2==1.3.0 142 | - tomli==2.0.1 143 | - tornado==6.4.1 144 | - tqdm==4.66.4 145 | - traitlets==5.14.3 146 | - types-python-dateutil==2.9.0.20240316 147 | - typing-extensions==4.12.2 148 | - uri-template==1.3.0 149 | - urllib3==2.2.2 150 | - wcwidth==0.2.13 151 | - webcolors==24.6.0 152 | - webencodings==0.5.1 153 | - websocket-client==1.8.0 154 | - werkzeug==3.0.3 155 | - widgetsnbextension==4.0.11 156 | - zipp==3.19.2 157 | prefix: /Users/minorukusaba/miniforge3/envs/CSPML 158 | -------------------------------------------------------------------------------- /CSPML.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: Minoru Kusaba (SOKENDAI, kusaba@ism.ac.jp) 3 | # last update: 2022/01/12 4 | 5 | """ 6 | CSPML is a unique methodology for crystal structure prediction (CSP) that relies on a 7 | machine learning algorithm (a binary classification neural network model). CSPML predicts a stable structure 8 | for any given query composition by automatically selecting, from a crystal structure database, a set of 9 | template crystals with nearly identical stable structures, to which atomic substitution is then applied. 10 | The pre-trained models are used for the selection of the template crystals. 11 | 33,153 candidate compounds (all candidate templates; obtained from the Materials Project) and the pre-trained models 12 | are embedded in CSPML. 13 | """ 14 | 15 | # Import libraries. 16 | import pandas as pd 17 | import numpy as np 18 | from pymatgen.core.composition import Composition 19 | from xenonpy.descriptor import Compositions 20 | import pickle 21 | import itertools 22 | import copy 23 | import os 24 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 25 | import tensorflow as tf 26 | from tensorflow.keras.utils import to_categorical 27 | 28 | # Load preset data. 29 | # Elements handled in CSPML.
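# The position of each element in this list also serves as its row/column index in the element_dissimilarity matrix loaded below (this ordering is how Structure_prediction indexes that matrix).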
30 | elements = ["H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", 31 | "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", 32 | "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", 33 | "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", 34 | "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr"] 35 | 36 | # Candidate templates for CSPML. 37 | with open("./data_set/MP_candidates.pkl", "rb") as f: # preset 33,153 candidate compounds. 38 | MP_candidates = pickle.load(f) 39 | 40 | with open("./data_set/MP_structures.pkl", 'rb') as f: # preset 33,153 candidate structures. 41 | MP_structures = pickle.load(f) 42 | 43 | # Pre-calculated velues for standardizing the XenonPy-calculated descroptor. 44 | with open("./data_set/descriptor_standardization.pkl", 'rb') as f: 45 | descriptor_standardization = pickle.load(f) 46 | 47 | xenonpy_mean = descriptor_standardization["mean"] # equal to the mean of the 33,153 XenonPy-calculated descriptor. 48 | xenonpy_std = descriptor_standardization["std"] # equal to the std of the 33,153 XenonPy-calculated descriptor. 49 | 50 | # Dissimilarity of any element pairs for the above-defined elements. 51 | with open("./data_set/element_dissimilarity.pkl", 'rb') as f: 52 | element_dissimilarity = pickle.load(f) 53 | 54 | # Load pre-trained models (Ensemble of NN-binary classifieres). 55 | model1 = tf.keras.models.load_model("./data_set/model1_tau=0.3") 56 | model2 = tf.keras.models.load_model("./data_set/model2_tau=0.3") 57 | model3 = tf.keras.models.load_model("./data_set/model3_tau=0.3") 58 | model4 = tf.keras.models.load_model("./data_set/model4_tau=0.3") 59 | model5 = tf.keras.models.load_model("./data_set/model5_tau=0.3") 60 | 61 | models = list([model1, model2, model3, model4, model5]) 62 | 63 | # Define functions. 64 | def formula_to_composition(formula, elements = elements): 65 | """ 66 | Transform a pretty formulas (single str object) to a vector of the composition ratio (np.array). 67 | Args: 68 | formula (str): single pretty formula (like "SiO2"). 69 | elements (list): a list consists of the element names for creating the vector of the composition ratio. 70 | 71 | Returns: a vector of the composition ratio (np.array). 72 | """ 73 | comp = Composition(formula) 74 | vec = np.zeros(len(elements)) 75 | for i in range(0, len(elements)): 76 | vec[i] = comp.get_atomic_fraction(elements[i]) 77 | return vec 78 | 79 | def formula_to_Composition(formula): 80 | """ 81 | Transform a list of pretty formulas to Composition class objects (pymatgen.core.composition). 82 | Args: 83 | formula (list): a list of pretty formulas (like ["SiO2","Li4Ti5O12"]). 84 | 85 | Returns: a list of Composition class objects. 86 | """ 87 | comp = [] 88 | for i in range(len(formula)): 89 | comp.append(Composition(formula[i])) 90 | return comp 91 | 92 | def Composition_to_descriptor(comp, mean = xenonpy_mean, std = xenonpy_std): 93 | """ 94 | Transform a list of Composition class objects (pymatgen.core.composition) to the descriptors 95 | calculated by xenonpy.descriptor. 96 | Args: 97 | comp (list): a list of Composition class objects. 98 | mean = xenonpy_mean (pandas.Series): pre-calculated mean for nomalizing the descriptors. 
99 | std = xenonpy_std (pandas.Series): pre-calculated standard deviation for normalizing the descriptors. 100 | Returns: a pd.DataFrame containing the XenonPy-calculated descriptors (d=290). 101 | """ 102 | descp = Compositions().transform(comp) 103 | descp_scaled = (descp - mean)/std # standardize with the supplied mean/std (defaults are the preset values above). 104 | return descp_scaled 105 | 106 | def formula_to_sortedcomposition(formula, elements = elements): 107 | """ 108 | Transform a list of pretty formulas to the sorted composition ratios. 109 | Args: 110 | formula (list): a list of pretty formulas (like ["SiO2","Li4Ti5O12"]). 111 | elements (list): a list of element names (str). 112 | 113 | Returns: a pd.DataFrame containing the sorted composition ratios of the given formulas. 114 | """ 115 | N_data = len(formula) 116 | sorted_composition = np.zeros((N_data, len(elements))) 117 | for i in range(0, N_data): 118 | sorted_composition[i,] = np.sort(formula_to_composition(formula[i], elements))[::-1] 119 | sorted_composition_pd = pd.DataFrame(sorted_composition) 120 | return sorted_composition_pd 121 | 122 | def ensemble_models(X, models = models): 123 | """ 124 | Calculate an ensemble of the estimated class probabilities of being classified into similar pairs. 125 | Args: 126 | X (np.array): the descriptors for paired formulas (the absolute value of the difference of the XenonPy descriptors). 127 | models (list): a list of pre-trained models (keras.engine.functional.Functional). 128 | 129 | Returns: a np.array containing an ensemble of the estimated class probabilities of being classified into similar pairs. 130 | """ 131 | preds = list() 132 | for i in range(0, len(models)): 133 | pred = models[i](X) 134 | preds.append(pred[:,1]) 135 | return np.sum(np.array(preds), axis = 0)/len(models) 136 | 137 | def Narrowingdown_candidates(query_formula, candidates = MP_candidates, elements = elements): 138 | """ 139 | Narrow down the candidate compounds by the composition ratios of the given query formulas. 140 | Args: 141 | query_formula (list): a list of (query) pretty formulas (like ["SiO2","Li4Ti5O12"]). 142 | candidates (dictionary): a dictionary consisting of three keys: 'property', 'composition', 'descriptor'. 143 | Each key contains a pandas.DataFrame object which lists properties, composition ratios, and 144 | chemical composition descriptors of the candidate compounds, respectively. 145 | elements (list): a list of element names (str). 146 | 147 | Returns: a list of dictionaries, each consisting of three keys: 'query_formula', 'candidates_num', 'candidates_id'. 148 | The 'query_formula' shows the query formula (str) which was used for narrowing down candidates. The 'candidates_num' 149 | shows the number of narrowed-down candidates for a given query formula. The 'candidates_id' shows the material-ids 150 | of the narrowed-down candidates for a given query formula.
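Example (illustrative sketch only; the actual counts and ids depend on the preset MP_candidates data):
>>> survived = Narrowingdown_candidates(["SiO2"])
>>> survived[0]["query_formula"]   # the query formula, "SiO2"
>>> survived[0]["candidates_num"]  # number of candidate templates sharing the same 1:2 composition ratio
>>> survived[0]["candidates_id"]   # their material-ids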
151 | """ 152 | all_comp = candidates["composition"] 153 | query_comp = formula_to_sortedcomposition(query_formula, elements) 154 | 155 | survived = [] 156 | for i in range(len(query_formula)): 157 | ix = np.sum(all_comp == query_comp.iloc[i,:], axis = 1) == len(all_comp.columns) 158 | candidates_id = candidates["property"][ix]["material_id"].reset_index(drop=True) 159 | 160 | if len(candidates_id) == 0: 161 | print(f"None of the candidates had the same composition ratio as {query_formula[i]}.") 162 | candidates_id = list() 163 | candidates_num = 0 164 | 165 | else: 166 | candidates_id = list(candidates_id) 167 | candidates_num = len(candidates_id) 168 | 169 | result = {"query_formula":query_formula[i], "candidates_num":candidates_num, 170 | "candidates_id":candidates_id} 171 | survived.append(result) 172 | 173 | return survived 174 | 175 | def Screening_candidates(query_formula, top_K, candidates=MP_candidates, prediction_models=models, 176 | mean=xenonpy_mean,std=xenonpy_std,cut_off=0.5, elements = elements): 177 | """ 178 | Screening the candidate compounds by the pre-trained models into top-K candidates for the given query formulas. 179 | Args: 180 | query_formula (list): a list of (query) pretty formulas (like ["SiO2","Li4Ti5O12"]). 181 | top_K (int): Candidates are screened up to top-K candidates. 182 | candidates (dictionary): a dictionary consists of three keys,'property', 'composition', 'descriptor'. 183 | Each of their keys contains pandas.DataFrame object which lists properties, composition ratios, and 184 | chemical composition descriptors of the candidate compounds, respectively. 185 | models (list): a list of pre-trained models (keras.engine.functional.Functional). 186 | mean = xenonpy_mean (pandas.Series): pre-calculated mean for nomalizing the descriptors. 187 | std = xenonpy_std (pandas.Series): pre-calculated standard deviation for nomalizing the descriptors. 188 | cut_off (float; default = 0.5): The probability used for cutting-off any candidates of which 189 | the estimated class-probabilities (of being classified into similar pairs) are not greater than the value. 190 | elements (list): a list of element's names (str). 191 | 192 | Returns: a list of the dictionaries consists of four keys,"query_formula","topK_formula","topK_id" 193 | , and "topK_pred". The "query_formula" shows a query formula (str) which was used for screening candidates. 194 | The "topK_formula" shows the formulas of the screened top-K candidates for a given query formula. 195 | The "topK_id" shows the material-ids of the screened top-K candidates for a given query formula. 196 | The "topK_pred" shows the estimated class-probabilities (of being classified into similar pairs) 197 | of the screened top-K candidates for a given query formula. 
198 | """ 199 | all_comp = candidates["composition"] 200 | query_comp = formula_to_sortedcomposition(query_formula,elements) 201 | x = formula_to_Composition(query_formula) 202 | query_descp = Composition_to_descriptor(x, mean, std) 203 | 204 | predictions = [] 205 | for i in range(len(query_formula)): 206 | ix = np.sum(all_comp == query_comp.iloc[i,:], axis = 1) == len(all_comp.columns) 207 | candidates_descp = candidates["descriptor"][ix] 208 | candidates_id = candidates["property"][ix]["material_id"].reset_index(drop=True) 209 | candidates_formula = candidates["property"][ix]["pretty_formula"].reset_index(drop=True) 210 | 211 | if len(candidates_id) == 0: 212 | print(f"None of the candidates had the same composition ratio as {query_formula[i]}.") 213 | topK_id = list() 214 | topK_pred = 0 215 | topK_formula = list() 216 | 217 | else: 218 | pred = ensemble_models(abs(candidates_descp - query_descp.iloc[i,:]).values, 219 | prediction_models) 220 | topK_id = list(candidates_id[np.argsort(pred)[::-1]][:top_K]) 221 | topK_formula = list(candidates_formula[np.argsort(pred)[::-1]][:top_K]) 222 | topK_pred = np.sort(pred)[::-1][:top_K] 223 | 224 | # Cutting-off candidates. 225 | surviving = topK_pred>cut_off 226 | 227 | if sum(surviving) == 0: 228 | print(f"None of the candidates had the class probabilities greater than {cut_off} at {query_formula[i]}.") 229 | topK_id = list() 230 | topK_pred = 0 231 | topK_formula = list() 232 | else: 233 | topK_id = [topK_id[j] for j in range(len(topK_id)) if surviving[j]] 234 | topK_formula = [topK_formula[j] for j in range(len(topK_formula)) if surviving[j]] 235 | topK_pred = topK_pred[surviving] 236 | 237 | prediction_result = {"query_formula":query_formula[i],"topK_formula":topK_formula, 238 | "topK_id":topK_id,"topK_pred":topK_pred} 239 | predictions.append(prediction_result) 240 | 241 | return predictions 242 | 243 | def Structure_prediction(query_formula, top_K, candidates=MP_candidates, structures=MP_structures,elements = elements, 244 | prediction_models=models, mean=xenonpy_mean, std=xenonpy_std, element_dissimilarity = element_dissimilarity, 245 | cut_off=0.5, SI = False, save_cif = False, save_cif_filename = ""): 246 | """ 247 | Predicting stable structures for the given query fomulas by element-substitution of the screened top-K candidate 248 | structures. The screening is performed using pre-trained models with pre-defined candidate set. 249 | The predicted structures are automatically saved as .cif files into the directory (save_cif_filename), if save_cif = True. 250 | Args: 251 | query_formula (list): a list of (query) pretty formulas (like ["SiO2","Li4Ti5O12"]). 252 | top_K (int): Candidates are screened up to top-K candidates. 253 | candidates (dictionary): a dictionary consists of three keys,'property', 'composition', 'descriptor'. 254 | Each of their keys contains pandas.DataFrame object which lists properties, composition ratios, and 255 | chemical composition descriptors of the candidate compounds, respectively. 256 | structures (dictionary): a dictionary consists of (at least) two keys,'material_id', 'structure'. 257 | The 'material_id' should be a np.array containing material-ids for the candidate compounds. 258 | The 'structure' should be a list containing Structure objects (pymatgen.Structure) for the candidate compounds. 259 | elements (list): a list of element's names (str). 260 | models (list): a list of pre-trained models (keras.engine.functional.Functional). 
261 | mean = xenonpy_mean (pandas.Series): pre-calculated mean for normalizing the descriptors. 262 | std = xenonpy_std (pandas.Series): pre-calculated standard deviation for normalizing the descriptors. 263 | element_dissimilarity (np.array): a np.array containing dissimilarities for all pairs of the elements. 264 | cut_off (float; default = 0.5): The probability threshold used for cutting off any candidates whose 265 | estimated class probabilities (of being classified into similar pairs) are not greater than this value. 266 | SI (bool; default = False): If True, supplementary information on the predicted structures is also returned. 267 | save_cif (bool; default = False): If True, the predicted structures are saved as .cif files. 268 | The top-jth predicted structure of the ith query formula (query_formula[i]) is saved as "query_formula[i]_j.cif". 269 | save_cif_filename (str): Name of the directory in which the .cif files are saved. 270 | 271 | Returns: (predictions) a list of lists containing pymatgen.Structure objects. predictions[i][j] shows 272 | the top-(j+1)th predicted structure for query_formula[i]. 273 | (screened; optionally returned if SI=True) a list of dictionaries, each consisting of four keys: "query_formula", "topK_formula", "topK_id", 274 | and "topK_pred". The "query_formula" shows the query formula (str) which was used for screening candidates. 275 | The "topK_formula" shows the formulas of the screened top-K candidates for a given query formula. 276 | The "topK_id" shows the material-ids of the screened top-K candidates for a given query formula. 277 | The "topK_pred" shows the estimated class probabilities (of being classified into similar pairs) 278 | of the screened top-K candidates for a given query formula. These screened top-K candidates are the template structures 279 | which are used for generating the predicted structures by element substitution. 280 | """ 281 | # Screen the top_K candidates using the pre-trained models for each query formula. 282 | screened = Screening_candidates(query_formula, top_K, candidates, prediction_models, 283 | mean,std,cut_off) 284 | element_symbol = np.array(elements) 285 | predictions = [] 286 | 287 | for i in range(len(query_formula)): 288 | 289 | predicted_structures = [] 290 | scr_num = len(screened[i]["topK_id"]) 291 | 292 | if scr_num == 0: 293 | pass 294 | 295 | else: 296 | for j in range(scr_num): 297 | 298 | # The ith query formula. 299 | vec = formula_to_composition(query_formula[i],elements) 300 | N_ele = sum(vec != 0) 301 | comp_index = np.argsort(vec)[::-1][:N_ele] 302 | 303 | # Top-jth suggested formula for the ith query formula. 304 | sug_formula = screened[i]['topK_formula'][j] 305 | vec_sug = formula_to_composition(sug_formula,elements) 306 | comp_sug_index = np.argsort(vec_sug)[::-1][:N_ele] 307 | 308 | # Composition of the ith formula (query & suggested) and its unique composition ratios. 309 | comp = np.sort(vec)[::-1][:N_ele] 310 | keys = np.sort(list(set(comp)))[::-1] 311 | 312 | # Group composition indices (= element species) according to the unique composition ratios. 313 | group_index = [] 314 | group_sug_index = [] 315 | for k in range(0, len(keys)): 316 | x = (comp == keys[k]) 317 | group_index.append(comp_index[x]) 318 | group_sug_index.append(comp_sug_index[x]) 319 | 320 | # Find the element replacement that minimizes the element dissimilarity and build a dict describing the replacement. 321 | replacement = [] 322 | for l in range(0, len(keys)): 323 | # Replacement is unique.
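# (The ratio group contains a single element, so the template element is mapped directly onto the corresponding query element.)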
324 | if len(group_index[l]) == 1: 325 | replacement.append(group_sug_index[l]) 326 | # Replacement is not unique. 327 | else : 328 | seq = group_sug_index[l] 329 | pmt = list(itertools.permutations(seq)) 330 | K = len(pmt) 331 | dis_sum = np.zeros(K) 332 | for m in range(0, K): 333 | dis_sum[m] = sum(element_dissimilarity[group_index[l], pmt[m]]) 334 | replacement.append(np.array(pmt[np.argmin(dis_sum)])) 335 | rep_index = np.concatenate(replacement) 336 | q_ele = element_symbol[comp_index] 337 | rep_ele = element_symbol[rep_index] 338 | rep_dict = dict(zip(rep_ele,q_ele)) 339 | 340 | # Generating top-jth candidate structure for ith query formula. 341 | str_index = np.where(structures["material_id"] == screened[i]["topK_id"][j])[0][0] # id to index 342 | query_str = copy.deepcopy(structures["structure"][str_index]) 343 | query_str.replace_species(rep_dict) 344 | predicted_structures.append(query_str) 345 | 346 | # Save the structure object as a .cif file into dir = filename (if save_cif=True). 347 | if save_cif: 348 | text = f"{save_cif_filename}/{query_formula[i]}_{j+1}.cif" 349 | query_str.to(filename=text) 350 | else: 351 | pass 352 | 353 | predictions.append(predicted_structures) 354 | 355 | # Return the predicted structures (+ optionally the supplementary information of the predicted structures). 356 | if SI: 357 | return predictions, screened 358 | else: 359 | return predictions 360 | -------------------------------------------------------------------------------- /CSPML_latest_codes/CSPML.yml: -------------------------------------------------------------------------------- 1 | name: CSPML 2 | channels: 3 | - apple 4 | - conda-forge 5 | dependencies: 6 | - bzip2=1.0.8=h3422bc3_4 7 | - c-ares=1.18.1=h3422bc3_0 8 | - ca-certificates=2022.12.7=h4653dfc_0 9 | - cached-property=1.5.2=hd8ed1ab_1 10 | - cached_property=1.5.2=pyha770c72_1 11 | - grpcio=1.52.1=py39h23fbdae_1 12 | - h5py=3.6.0=nompi_py39hd982b79_100 13 | - hdf5=1.12.1=nompi_hd9dbc9e_104 14 | - krb5=1.20.1=h69eda48_0 15 | - libabseil=20230125.0=cxx17_hb7217d7_1 16 | - libblas=3.9.0=16_osxarm64_openblas 17 | - libcblas=3.9.0=16_osxarm64_openblas 18 | - libcurl=7.88.1=h9049daf_0 19 | - libcxx=15.0.7=h75e25f2_0 20 | - libedit=3.1.20191231=hc8eb9b7_2 21 | - libev=4.33=h642e427_1 22 | - libffi=3.4.2=h3422bc3_5 23 | - libgfortran=5.0.0=12_2_0_hd922786_31 24 | - libgfortran5=12.2.0=h0eea778_31 25 | - libgrpc=1.52.1=he98ff75_1 26 | - liblapack=3.9.0=16_osxarm64_openblas 27 | - libnghttp2=1.52.0=hae82a92_0 28 | - libopenblas=0.3.21=openmp_hc731615_3 29 | - libprotobuf=3.21.12=hb5ab8b9_0 30 | - libsqlite=3.40.0=h76d750c_0 31 | - libssh2=1.10.0=h7a5bd25_3 32 | - libzlib=1.2.13=h03a7124_4 33 | - llvm-openmp=15.0.7=h7cfbb63_0 34 | - ncurses=6.3=h07bb92c_1 35 | - numpy=1.22.4=py39h7df2422_0 36 | - openssl=3.1.0=h03a7124_0 37 | - pip=23.0.1=pyhd8ed1ab_0 38 | - python=3.9.16=hea58f1e_0_cpython 39 | - python_abi=3.9=3_cp39 40 | - re2=2023.02.02=hb7217d7_0 41 | - readline=8.1.2=h46ed386_0 42 | - setuptools=67.6.0=pyhd8ed1ab_0 43 | - tensorflow-deps=2.9.0=0 44 | - tk=8.6.12=he1e0b03_0 45 | - tzdata=2022g=h191b570_0 46 | - wheel=0.40.0=pyhd8ed1ab_0 47 | - xz=5.2.6=h57fd34a_0 48 | - zlib=1.2.13=h03a7124_4 49 | - pip: 50 | - absl-py==1.4.0 51 | - aioitertools==0.11.0 52 | - alembic==1.10.2 53 | - anyio==3.6.2 54 | - appnope==0.1.3 55 | - argon2-cffi==21.3.0 56 | - argon2-cffi-bindings==21.2.0 57 | - arrow==1.2.3 58 | - asttokens==2.2.1 59 | - astunparse==1.6.3 60 | - attrs==22.2.0 61 | - autopage==0.5.1 62 | - backcall==0.2.0 63 | - bcrypt==4.0.1 
64 | - beautifulsoup4==4.11.2 65 | - bleach==6.0.0 66 | - blinker==1.6.2 67 | - boltons==23.0.0 68 | - boto3==1.26.129 69 | - botocore==1.29.129 70 | - bravado==11.0.3 71 | - bravado-core==5.17.1 72 | - cachetools==5.3.0 73 | - certifi==2022.12.7 74 | - cffi==1.15.1 75 | - charset-normalizer==3.1.0 76 | - click==8.1.3 77 | - cliff==4.2.0 78 | - cmaes==0.9.1 79 | - cmd2==2.4.3 80 | - colorlog==6.7.0 81 | - comm==0.1.2 82 | - contourpy==1.0.7 83 | - cryptography==40.0.2 84 | - cycler==0.11.0 85 | - debugpy==1.6.6 86 | - decorator==5.1.1 87 | - defusedxml==0.7.1 88 | - deprecated==1.2.14 89 | - dnspython==2.3.0 90 | - emmet-core==0.56.1 91 | - executing==1.2.0 92 | - fastapi==0.95.1 93 | - fastjsonschema==2.16.3 94 | - filelock==3.12.2 95 | - filetype==1.2.0 96 | - flask==2.3.2 97 | - flatbuffers==1.12 98 | - flatten-dict==0.4.2 99 | - fonttools==4.39.0 100 | - fqdn==1.5.1 101 | - future==0.18.3 102 | - gast==0.4.0 103 | - google-auth==2.16.2 104 | - google-auth-oauthlib==0.4.6 105 | - google-pasta==0.2.0 106 | - idna==3.4 107 | - importlib-metadata==4.13.0 108 | - importlib-resources==5.12.0 109 | - ipykernel==6.21.3 110 | - ipython==8.11.0 111 | - ipython-genutils==0.2.0 112 | - ipywidgets==8.0.4 113 | - isoduration==20.11.0 114 | - itsdangerous==2.1.2 115 | - jedi==0.18.2 116 | - jinja2==3.1.2 117 | - jmespath==1.0.1 118 | - joblib==1.2.0 119 | - json2html==1.3.0 120 | - jsonpointer==2.3 121 | - jsonref==1.1.0 122 | - jsonschema==4.17.3 123 | - jupyter==1.0.0 124 | - jupyter-client==8.0.3 125 | - jupyter-console==6.6.3 126 | - jupyter-core==5.2.0 127 | - jupyter-events==0.6.3 128 | - jupyter-server==2.4.0 129 | - jupyter-server-terminals==0.4.4 130 | - jupyterlab-pygments==0.2.2 131 | - jupyterlab-widgets==3.0.5 132 | - keras==2.9.0 133 | - keras-preprocessing==1.1.2 134 | - kiwisolver==1.4.4 135 | - latexcodec==2.0.1 136 | - libclang==15.0.6.1 137 | - maggma==0.50.4 138 | - mako==1.2.4 139 | - markdown==3.4.1 140 | - markupsafe==2.1.2 141 | - matminer==0.7.8 142 | - matplotlib==3.7.1 143 | - matplotlib-inline==0.1.6 144 | - mistune==2.0.5 145 | - mongogrant==0.3.3 146 | - mongomock==4.1.2 147 | - monotonic==1.6 148 | - monty==2022.9.9 149 | - mordred==1.2.0 150 | - mp-api==0.33.3 151 | - mpcontribs-client==5.3.5 152 | - mpmath==1.3.0 153 | - msgpack==1.0.5 154 | - nbclassic==0.5.3 155 | - nbclient==0.7.2 156 | - nbconvert==7.2.10 157 | - nbformat==5.7.3 158 | - nest-asyncio==1.5.6 159 | - networkx==2.8.8 160 | - notebook==6.5.3 161 | - notebook-shim==0.2.2 162 | - oauthlib==3.2.2 163 | - opt-einsum==3.3.0 164 | - optuna==3.0.3 165 | - orjson==3.8.12 166 | - packaging==23.0 167 | - palettable==3.3.0 168 | - pandas==1.5.1 169 | - pandocfilters==1.5.0 170 | - paramiko==3.1.0 171 | - parso==0.8.3 172 | - pbr==5.11.1 173 | - pexpect==4.8.0 174 | - pickleshare==0.7.5 175 | - pillow==9.4.0 176 | - pint==0.19.2 177 | - platformdirs==3.1.1 178 | - plotly==5.13.1 179 | - prettytable==3.6.0 180 | - prometheus-client==0.16.0 181 | - prompt-toolkit==3.0.38 182 | - protobuf==3.19.6 183 | - psutil==5.9.4 184 | - ptyprocess==0.7.0 185 | - pure-eval==0.2.2 186 | - pyasn1==0.4.8 187 | - pyasn1-modules==0.2.8 188 | - pybtex==0.24.0 189 | - pycparser==2.21 190 | - pydantic==1.10.6 191 | - pydash==7.0.3 192 | - pygments==2.14.0 193 | - pyisemail==2.0.1 194 | - pymatgen==2022.5.26 195 | - pymongo==4.3.3 196 | - pynacl==1.5.0 197 | - pyparsing==3.0.9 198 | - pyperclip==1.8.2 199 | - pyrsistent==0.19.3 200 | - python-dateutil==2.8.2 201 | - python-json-logger==2.0.7 202 | - pytz==2022.7.1 203 | - pyyaml==6.0 204 
| - pyzmq==24.0.1 205 | - qpsolvers==2.6.0 206 | - qtconsole==5.4.1 207 | - qtpy==2.3.0 208 | - quadprog==0.1.11 209 | - rdkit==2023.3.2 210 | - requests==2.28.2 211 | - requests-futures==1.0.0 212 | - requests-oauthlib==1.3.1 213 | - rfc3339-validator==0.1.4 214 | - rfc3986-validator==0.1.1 215 | - rfc3987==1.3.8 216 | - rsa==4.9 217 | - ruamel-yaml==0.17.21 218 | - ruamel-yaml-clib==0.2.7 219 | - s3transfer==0.6.1 220 | - scikit-learn==1.1.3 221 | - scipy==1.8.1 222 | - seaborn==0.12.2 223 | - semantic-version==2.10.0 224 | - send2trash==1.8.0 225 | - sentinels==1.0.0 226 | - simplejson==3.19.1 227 | - six==1.15.0 228 | - sniffio==1.3.0 229 | - soupsieve==2.4 230 | - spglib==2.0.2 231 | - sqlalchemy==2.0.6 232 | - sshtunnel==0.4.0 233 | - stack-data==0.6.2 234 | - starlette==0.26.1 235 | - stevedore==5.0.0 236 | - swagger-spec-validator==3.0.3 237 | - sympy==1.11.1 238 | - tabulate==0.9.0 239 | - tenacity==8.2.2 240 | - tensorboard==2.9.1 241 | - tensorboard-data-server==0.6.1 242 | - tensorboard-plugin-wit==1.8.1 243 | - tensorflow-estimator==2.9.0 244 | - tensorflow-macos==2.9.0 245 | - tensorflow-metal==0.5.1 246 | - termcolor==2.2.0 247 | - terminado==0.17.1 248 | - threadpoolctl==3.1.0 249 | - tinycss2==1.2.1 250 | - torch==2.0.1 251 | - tornado==6.2 252 | - tqdm==4.65.0 253 | - traitlets==5.9.0 254 | - typing-extensions==4.5.0 255 | - ujson==5.7.0 256 | - uncertainties==3.1.7 257 | - uri-template==1.2.0 258 | - urllib3==1.26.15 259 | - wcwidth==0.2.6 260 | - webcolors==1.12 261 | - webencodings==0.5.1 262 | - websocket-client==1.5.1 263 | - werkzeug==2.3.3 264 | - widgetsnbextension==4.0.5 265 | - wrapt==1.15.0 266 | - xenonpy==0.6.7 267 | - zipp==3.15.0 268 | prefix: /Users/minorukusaba/miniforge3/envs/CSPML 269 | -------------------------------------------------------------------------------- /CSPML_latest_codes/CSPML_Structure_Prediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "a021746a", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "Metal device set to: Apple M1 Max\n" 14 | ] 15 | }, 16 | { 17 | "name": "stderr", 18 | "output_type": "stream", 19 | "text": [ 20 | "2024-07-03 12:47:33.698045: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. 
Your kernel may not have been built with NUMA support.\n", 21 | "2024-07-03 12:47:33.698188: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: )\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "# Import libraries.\n", 27 | "import pandas as pd\n", 28 | "import numpy as np\n", 29 | "from pymatgen.core.composition import Composition\n", 30 | "from KmdPlus import StatsDescriptor, formula_to_composition \n", 31 | "import pickle\n", 32 | "import itertools\n", 33 | "import copy\n", 34 | "import os\n", 35 | "import tensorflow as tf\n", 36 | "from tensorflow.keras.utils import to_categorical\n", 37 | "import math\n", 38 | "from scipy.spatial import distance_matrix\n", 39 | "import shutil\n", 40 | "\n", 41 | "# Read all templates.\n", 42 | "MP_stable = pd.read_pickle(\"data_set/MP_stable_20211107.pd.xz\")\n", 43 | "\n", 44 | "# Element-level descriptors of shape (94, 58).\n", 45 | "element_features = pd.read_csv(\"data_set/element_features.csv\", index_col= 0)\n", 46 | "\n", 47 | "# Load test data (90 crystals).\n", 48 | "test_data = pd.read_pickle(\"data_set/all_searching_targets_20211107_with_predictions.pd.xz\")\n", 49 | "\n", 50 | "# Load the pre-trained models.\n", 51 | "with open(\"data_set/CSPML_models.xz\", \"rb\") as f:\n", 52 | " models = pickle.load(f)\n", 53 | " \n", 54 | "# Load stats.\n", 55 | "cmpfgp_stable_meanstd = np.load(\"data_set/cmpfgp_stable_meanstd_20211107.npy\") \n", 56 | "\n", 57 | "# Load element dissimilarity.\n", 58 | "element_dissimilarity = np.load('data_set/element_dissimilarity.npy')" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 2, 64 | "id": "7084939f", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# ensemble models.\n", 69 | "def ensemble_models(X, models):\n", 70 | " y_pred = np.array([models[i].predict(X, verbose=0)[:,1] for i in range(len(models))])\n", 71 | " \n", 72 | " return y_pred.mean(0)\n", 73 | "\n", 74 | "# Formula to ratio label.\n", 75 | "def formula_to_ratiolabel(formula):\n", 76 | " # Convert chemical formulas to compositions.\n", 77 | " weight = np.array([formula_to_composition(formula[i]) for i in range(len(formula))])\n", 78 | " \n", 79 | " ratio_label = []\n", 80 | " for i in range(len(formula)):\n", 81 | " sorted_weight = np.sort(weight[i])[::-1]\n", 82 | " comp = Composition(formula[i])\n", 83 | " comp_ratio = comp.num_atoms * sorted_weight\n", 84 | " x = [int(round(comp_ratio[j])) for j in range(len(comp))]\n", 85 | " gcd_x = math.gcd(*x) \n", 86 | " # For collection in the case like \"O2\", \"Na2O2\".\n", 87 | " if gcd_x != 1:\n", 88 | " x = [int(round(x[k]/gcd_x)) for k in range(len(x))]\n", 89 | " else:\n", 90 | " pass\n", 91 | " # Get ratio label for collected x.\n", 92 | " label = \"\"\n", 93 | " for j in range(len(x)):\n", 94 | " label += f\"{x[j]}:\"\n", 95 | " # Save results.\n", 96 | " ratio_label.append(label[:-1])\n", 97 | " \n", 98 | " return np.array(ratio_label, dtype = \"object\")\n", 99 | "\n", 100 | "# Screening for CSPML.\n", 101 | "def Screening_candidates(query_formula, top_K, templates, cutoff = 0.5, element_features = element_features,\n", 102 | " meanstd = cmpfgp_stable_meanstd, models = models):\n", 103 | " \n", 104 | " # Calculate cmpfgp for quary formula.\n", 105 | " query_cmpfgp = StatsDescriptor(query_formula, element_features)\n", 106 | " query_cmpfgp = (query_cmpfgp - 
meanstd[0])/meanstd[1] # scaling.\n", 107 | " \n", 108 | " # Calculate ratio label.\n", 109 | " query_ratiolabel = formula_to_ratiolabel(query_formula)\n", 110 | " \n", 111 | " # Make predictions.\n", 112 | " predictions = []\n", 113 | "\n", 114 | " for i in range(len(query_formula)):\n", 115 | " ix = np.where(templates.comp_ratio_label.values == query_ratiolabel[i])[0]\n", 116 | "\n", 117 | " if len(ix) < 1:\n", 118 | " print(f\"None of the candidates had the same composition ratio as {query_formula[i]}.\")\n", 119 | " topK_id, topK_pred, topK_formula = [], 0, []\n", 120 | " else:\n", 121 | " x = templates.iloc[ix]\n", 122 | " # Composition fingerprint for x.\n", 123 | " x_cmpfgp = x.cmpfgp.values\n", 124 | " x_cmpfgp = np.array([x_cmpfgp[i] for i in range(len(x_cmpfgp))])\n", 125 | " X = np.abs(x_cmpfgp - query_cmpfgp[i,:])\n", 126 | " y_pred = ensemble_models(X, models)\n", 127 | " topK_ix = np.argsort(y_pred)[::-1][:top_K]\n", 128 | "\n", 129 | " topK_id = x.materials_id.values[topK_ix]\n", 130 | " topK_pred = y_pred[topK_ix]\n", 131 | " topK_formula = x.pretty_formula.values[topK_ix]\n", 132 | "\n", 133 | " survived = (topK_pred > cutoff)\n", 134 | "\n", 135 | " if sum(survived) < 1:\n", 136 | " print(f\"None of the candidates had the class probabilities greater than {cutoff} at {query_formula[i]}.\")\n", 137 | " topK_id, topK_pred, topK_formula = [], 0, []\n", 138 | " else:\n", 139 | " topK_id, topK_pred, topK_formula = topK_id[survived], topK_pred[survived], topK_formula[survived]\n", 140 | "\n", 141 | "\n", 142 | " prediction_result = {\"query_formula\":query_formula[i],\"topK_formula\":topK_formula,\n", 143 | " \"topK_id\":topK_id,\"topK_pred\":topK_pred}\n", 144 | " predictions.append(prediction_result)\n", 145 | " \n", 146 | " return predictions\n", 147 | "\n", 148 | "# CSPML.\n", 149 | "def Structure_prediction(query_formula, top_K, templates, cutoff = 0.5, element_features = element_features,\n", 150 | " meanstd = cmpfgp_stable_meanstd, models = models, element_dissimilarity = element_dissimilarity,\n", 151 | " SI = False, save_cif = False, save_cif_filename = \"\"):\n", 152 | " \n", 153 | " # Screening top_K candidates using pre-trained model for each query formula.\n", 154 | " screened = Screening_candidates(query_formula, top_K, templates, cutoff, element_features,\n", 155 | " meanstd, models)\n", 156 | " \n", 157 | " element_symbol = element_features.index.values\n", 158 | " predictions = []\n", 159 | "\n", 160 | " for i in range(len(query_formula)):\n", 161 | "\n", 162 | " predicted_structures = []\n", 163 | " scr_num = len(screened[i][\"topK_id\"])\n", 164 | "\n", 165 | " if scr_num == 0:\n", 166 | " pass\n", 167 | "\n", 168 | " else:\n", 169 | " for j in range(scr_num):\n", 170 | "\n", 171 | " # The ith query formula.\n", 172 | " vec = formula_to_composition(query_formula[i])\n", 173 | " N_ele = sum(vec != 0)\n", 174 | " comp_index = np.argsort(vec)[::-1][:N_ele]\n", 175 | "\n", 176 | " # Top-jth suggested formula for ith query formula.\n", 177 | " sug_formula = screened[i]['topK_formula'][j]\n", 178 | " vec_sug = formula_to_composition(sug_formula)\n", 179 | " comp_sug_index = np.argsort(vec_sug)[::-1][:N_ele]\n", 180 | "\n", 181 | " # Composition of ith fomula (quary & suggested) and it's unique composition ratio.\n", 182 | " comp = np.sort(vec)[::-1][:N_ele]\n", 183 | " keys = np.sort(list(set(comp)))[::-1]\n", 184 | "\n", 185 | " # Grouping composition-index(=element species) according to unique composition ratio.\n", 186 | " group_index = []\n", 187 | " 
group_sug_index = []\n", 188 | " for k in range(0, len(keys)):\n", 189 | " x = (comp == keys[k])\n", 190 | " group_index.append(comp_index[x])\n", 191 | " group_sug_index.append(comp_sug_index[x])\n", 192 | "\n", 193 | " # Find out elements-replacement that minimize element-dissimilarity and make dict showing replacement.\n", 194 | " replacement = []\n", 195 | " for l in range(0, len(keys)):\n", 196 | " # Replacement is unique.\n", 197 | " if len(group_index[l]) == 1:\n", 198 | " replacement.append(group_sug_index[l])\n", 199 | " # Replacement is not unique.\n", 200 | " else :\n", 201 | " seq = group_sug_index[l]\n", 202 | " pmt = list(itertools.permutations(seq))\n", 203 | " K = len(pmt)\n", 204 | " dis_sum = np.zeros(K)\n", 205 | " for m in range(0, K):\n", 206 | " dis_sum[m] = sum(element_dissimilarity[group_index[l], pmt[m]]) # element_dissimilarity.\n", 207 | " replacement.append(np.array(pmt[np.argmin(dis_sum)]))\n", 208 | " rep_index = np.concatenate(replacement)\n", 209 | " q_ele = element_symbol[comp_index]\n", 210 | " rep_ele = element_symbol[rep_index]\n", 211 | " rep_dict = dict(zip(rep_ele,q_ele))\n", 212 | "\n", 213 | " # Generating top-jth candidate structure for ith query formula.\n", 214 | " query_str = copy.deepcopy(templates[templates.materials_id.values == screened[i][\"topK_id\"][j]].structure[0])\n", 215 | " query_str.replace_species(rep_dict)\n", 216 | " predicted_structures.append(query_str)\n", 217 | "\n", 218 | " # Save the structure object as a .cif file into dir = filename (if save_cif=True).\n", 219 | " if save_cif:\n", 220 | " text = f\"{save_cif_filename}/{query_formula[i]}_{j+1}.cif\"\n", 221 | " query_str.to(filename=text)\n", 222 | " else:\n", 223 | " pass\n", 224 | "\n", 225 | " predictions.append(predicted_structures)\n", 226 | "\n", 227 | " # Return the predicted structures (+ optionally the supplementary information of the predicted structures).\n", 228 | " if SI:\n", 229 | " return predictions, screened\n", 230 | " else:\n", 231 | " return predictions" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 3, 237 | "id": "a8829ecb", 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "name": "stderr", 242 | "output_type": "stream", 243 | "text": [ 244 | "2024-07-03 12:48:14.818388: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n", 245 | "2024-07-03 12:48:14.877639: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n", 246 | "2024-07-03 12:48:15.034916: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n", 247 | "2024-07-03 12:48:15.160049: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n", 248 | "2024-07-03 12:48:15.291992: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n", 249 | "2024-07-03 12:48:15.432128: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.\n" 250 | ] 251 | }, 252 | { 253 | "name": "stdout", 254 | "output_type": "stream", 255 | "text": [ 256 | "WARNING:tensorflow:5 out of the last 13 calls to .predict_function at 0x2e1df83a0> triggered tf.function retracing. 
Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", 257 | "None of the candidates had the same composition ratio as NaCaAlPHO5F2.\n", 258 | "None of the candidates had the class probabilities greater than 0.5 at MgB7.\n", 259 | "None of the candidates had the class probabilities greater than 0.5 at Ba2CaSi4(BO7)2.\n", 260 | "None of the candidates had the same composition ratio as K5Ag2(AsSe3)3.\n", 261 | "None of the candidates had the same composition ratio as Na(WO3)9.\n", 262 | "None of the candidates had the same composition ratio as Li6V3P8O29.\n", 263 | "None of the candidates had the same composition ratio as Mg3Si2H4O9.\n", 264 | "None of the candidates had the class probabilities greater than 0.5 at Y4Si5Ir9.\n" 265 | ] 266 | } 267 | ], 268 | "source": [ 269 | "# Create a directory for saving results (results should be same as cif_files_for_90crystals/predicted_structures (pre DFT)).\n", 270 | "new_dir = \"CSPML_test90\"\n", 271 | "if os.path.exists(new_dir):\n", 272 | " shutil.rmtree(new_dir)\n", 273 | "os.mkdir(new_dir)\n", 274 | "\n", 275 | "# Make CSPML prediction.\n", 276 | "predictions, screened = Structure_prediction(test_data.pretty_formula.values, 10, MP_stable,\n", 277 | " SI = True, save_cif=True, save_cif_filename=new_dir)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 4, 283 | "id": "c7486fc8", 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "dissim<=0.2: 51/90\n" 291 | ] 292 | }, 293 | { 294 | "data": { 295 | "text/html": [ 296 | "
\n", 297 | "\n", 310 | "\n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | "
Min. dissim (top10)<=0.2
C2.534663e+000.0
Si3.500210e-171.0
GaAs3.877854e-161.0
ZnO4.216784e-010.0
BN1.726447e+000.0
.........
VPt34.285599e-021.0
SmVO45.189530e-021.0
VCl51.235371e-011.0
LaSi2Ni92.569991e-021.0
Ca3Ni7B21.037236e-011.0
\n", 376 | "

90 rows × 2 columns

\n", 377 | "
" 378 | ], 379 | "text/plain": [ 380 | " Min. dissim (top10) <=0.2\n", 381 | "C 2.534663e+00 0.0\n", 382 | "Si 3.500210e-17 1.0\n", 383 | "GaAs 3.877854e-16 1.0\n", 384 | "ZnO 4.216784e-01 0.0\n", 385 | "BN 1.726447e+00 0.0\n", 386 | "... ... ...\n", 387 | "VPt3 4.285599e-02 1.0\n", 388 | "SmVO4 5.189530e-02 1.0\n", 389 | "VCl5 1.235371e-01 1.0\n", 390 | "LaSi2Ni9 2.569991e-02 1.0\n", 391 | "Ca3Ni7B2 1.037236e-01 1.0\n", 392 | "\n", 393 | "[90 rows x 2 columns]" 394 | ] 395 | }, 396 | "execution_count": 4, 397 | "metadata": {}, 398 | "output_type": "execute_result" 399 | } 400 | ], 401 | "source": [ 402 | "# Load test (ground truth) structure fingerprints (see Create_strcmp_fgp.ipynb for details).\n", 403 | "test_strfgp = np.load('data_set/strfgp_test_20211107.npy')\n", 404 | "\n", 405 | "# Get structure fingerprints for the predicted structures (pre DFT).\n", 406 | "predicted_strfgp = []\n", 407 | "\n", 408 | "for i in range(len(screened)):\n", 409 | " strfgp = []\n", 410 | " len_j = len(predictions[i])\n", 411 | " if len_j==0:\n", 412 | " pass\n", 413 | " else:\n", 414 | " for j in range(len_j):\n", 415 | " strfgp.append(MP_stable[MP_stable.index.values == screened[i][\"topK_id\"][j]].strfgp[0] )\n", 416 | " \n", 417 | " predicted_strfgp.append(strfgp)\n", 418 | " \n", 419 | "# Get dissimilarity for predicted (pre DFT) vs true.\n", 420 | "predicted_dissim = []\n", 421 | "\n", 422 | "for i in range(len(predicted_strfgp)):\n", 423 | " dissim = []\n", 424 | " len_j = len(predicted_strfgp[i])\n", 425 | " if len_j == 0:\n", 426 | " dissim.append(1000)\n", 427 | " else:\n", 428 | " for j in range(len_j):\n", 429 | " dissim.append(np.sum((test_strfgp[i, :] - predicted_strfgp[i][j])**2)**(1/2) )\n", 430 | " \n", 431 | " predicted_dissim.append(dissim)\n", 432 | " \n", 433 | "# Get min dissim for each i.\n", 434 | "predicted_dissim_min = np.array([np.min(np.array(predicted_dissim[i])) for i in range(len(predicted_dissim))])\n", 435 | "predicted_formula = np.array([screened[i][\"query_formula\"] for i in range(len(screened))])\n", 436 | "\n", 437 | "# Sumarize and save the results.\n", 438 | "tau = 0.2\n", 439 | "\n", 440 | "template_dissim = pd.DataFrame(np.array([predicted_dissim_min, (predicted_dissim_min <= tau)]).T,\n", 441 | " columns=[\"Min. dissim (top10)\", \"<=0.2\"], index=predicted_formula)\n", 442 | "\n", 443 | "template_dissim.to_csv(\"template_dissim.csv\")\n", 444 | "\n", 445 | "print(f\"dissim<=0.2: {int(sum(template_dissim.iloc[:,1]))}/{template_dissim.shape[0]}\")\n", 446 | "\n", 447 | "template_dissim" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 5, 453 | "id": "3511b0d8", 454 | "metadata": {}, 455 | "outputs": [ 456 | { 457 | "name": "stdout", 458 | "output_type": "stream", 459 | "text": [ 460 | "dissim<=0.2: 59/90\n" 461 | ] 462 | }, 463 | { 464 | "data": { 465 | "text/html": [ 466 | "
\n", 467 | "\n", 480 | "\n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | "
Min. dissim (top10)<=0.2
Ag8GeS60.1183031.0
Al2CoO40.3840120.0
Al2O30.0015001.0
AlH12(ClO2)30.0577711.0
BN0.0429201.0
.........
ZnO0.4234290.0
ZnSb1.4435400.0
Zr4O2.0961720.0
ZrO20.1187921.0
ZrTe50.1657891.0
\n", 546 | "

77 rows × 2 columns

\n", 547 | "
" 548 | ], 549 | "text/plain": [ 550 | " Min. dissim (top10) <=0.2\n", 551 | "Ag8GeS6 0.118303 1.0\n", 552 | "Al2CoO4 0.384012 0.0\n", 553 | "Al2O3 0.001500 1.0\n", 554 | "AlH12(ClO2)3 0.057771 1.0\n", 555 | "BN 0.042920 1.0\n", 556 | "... ... ...\n", 557 | "ZnO 0.423429 0.0\n", 558 | "ZnSb 1.443540 0.0\n", 559 | "Zr4O 2.096172 0.0\n", 560 | "ZrO2 0.118792 1.0\n", 561 | "ZrTe5 0.165789 1.0\n", 562 | "\n", 563 | "[77 rows x 2 columns]" 564 | ] 565 | }, 566 | "execution_count": 5, 567 | "metadata": {}, 568 | "output_type": "execute_result" 569 | } 570 | ], 571 | "source": [ 572 | "# Load relaxed predictions (after DFT, structural optimizations by DFT were performed on a separate machine).\n", 573 | "relaxed = pd.read_pickle('data_set/CSML_90_opt_results.pd.xz')\n", 574 | "relaxed_converged = relaxed[relaxed.converged == True]\n", 575 | "relaxed_formula = np.unique(relaxed_converged.target_formula.values)\n", 576 | "\n", 577 | "# Save predictred structures (after DFT, results should be same as cif_files_for_90crystals/predicted_structures (after DFT)).\n", 578 | "new_dir = \"predicted_structures (after DFT)\"\n", 579 | "if os.path.exists(new_dir):\n", 580 | " shutil.rmtree(new_dir)\n", 581 | "os.mkdir(new_dir)\n", 582 | " \n", 583 | "for i in range(len(relaxed_formula)):\n", 584 | " str_formula = relaxed_converged[relaxed_converged.target_formula == relaxed_formula[i]]\n", 585 | " \n", 586 | " for j in range(str_formula.shape[0]):\n", 587 | " str_x = str_formula.final_structure[j]\n", 588 | " text = f\"{new_dir}/{relaxed_formula[i]}_{str_formula.index[j]}.cif\"\n", 589 | " str_x.to(filename=text)\n", 590 | " \n", 591 | "# Load libraries for calculating strfgp.\n", 592 | "from matminer.featurizers.site import CrystalNNFingerprint # matminer version = 0.6.2 (later version will give same calculation results).\n", 593 | "from matminer.featurizers.structure import SiteStatsFingerprint\n", 594 | "# Parallel calculation.\n", 595 | "import joblib\n", 596 | "\n", 597 | "# Site featurizer.\n", 598 | "cnnf = CrystalNNFingerprint.from_preset('ops', distance_cutoffs=None, x_diff_weight=0)\n", 599 | "\n", 600 | "def parallel_cnnf(featurizer, str_x):\n", 601 | " return np.array(joblib.Parallel(n_jobs=-1)(joblib.delayed(featurizer)(str_x, i) for i in range(len(str_x.sites))))\n", 602 | "\n", 603 | "# SiteStats.\n", 604 | "def SiteStats(site_fgps):\n", 605 | " return np.array([site_fgps.mean(0), site_fgps.std(0), site_fgps.min(0), site_fgps.max(0)]).T.flatten()\n", 606 | " \n", 607 | "# Calculate structure finger prints for the relaxed structures.\n", 608 | "dft_fgp = []\n", 609 | "for i in range(len(relaxed_formula)):\n", 610 | " x = relaxed_converged[relaxed_converged.target_formula == relaxed_formula[i]]\n", 611 | " str_x = np.array([SiteStats(parallel_cnnf(cnnf.featurize, x.final_structure[j])) for j in range(x.shape[0])])\n", 612 | " dft_fgp.append(str_x)\n", 613 | " \n", 614 | "true_fgp = []\n", 615 | "for i in range(len(relaxed_formula)):\n", 616 | " x = test_data[test_data.pretty_formula == relaxed_formula[i]].structure[0]\n", 617 | " true_fgp.append(SiteStats(parallel_cnnf(cnnf.featurize, x)))\n", 618 | "\n", 619 | "# Get dissimilarity for predicted (pre DFT) vs true.\n", 620 | "dissim = np.array([np.min(np.sum((true_fgp[i] - dft_fgp[i])**2,1)**(1/2)) for i in range(len(relaxed_formula))])\n", 621 | "\n", 622 | "# Sumarize and save the results.\n", 623 | "template_dissim_dft = pd.DataFrame(np.array([dissim, (dissim <= tau)]).T,\n", 624 | " columns=[\"Min. 
dissim (top10)\", \"<=0.2\"], index=relaxed_formula)\n", 625 | "\n", 626 | "template_dissim_dft.to_csv(\"template_dissim_dft.csv\")\n", 627 | "\n", 628 | "print(f\"dissim<=0.2: {int(sum(template_dissim_dft.iloc[:,1]))}/{template_dissim.shape[0]}\")\n", 629 | "\n", 630 | "template_dissim_dft" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 6, 636 | "id": "b84111fc", 637 | "metadata": {}, 638 | "outputs": [], 639 | "source": [ 640 | "# Save test (ground truth) structures as cif files (results should be same as cif_files_for_90crystals/ground_truth_structures).\n", 641 | "new_dir = \"ground_truth_structures\"\n", 642 | "if os.path.exists(new_dir):\n", 643 | " shutil.rmtree(new_dir)\n", 644 | "os.mkdir(new_dir)\n", 645 | "\n", 646 | "for i in range(test_data.shape[0]):\n", 647 | " str_x = test_data.structure[i]\n", 648 | " text = f\"{new_dir}/{test_data.pretty_formula[i]}.cif\"\n", 649 | " str_x.to(filename=text)" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": null, 655 | "id": "93a39c8b", 656 | "metadata": {}, 657 | "outputs": [], 658 | "source": [] 659 | } 660 | ], 661 | "metadata": { 662 | "kernelspec": { 663 | "display_name": "Python 3 (ipykernel)", 664 | "language": "python", 665 | "name": "python3" 666 | }, 667 | "language_info": { 668 | "codemirror_mode": { 669 | "name": "ipython", 670 | "version": 3 671 | }, 672 | "file_extension": ".py", 673 | "mimetype": "text/x-python", 674 | "name": "python", 675 | "nbconvert_exporter": "python", 676 | "pygments_lexer": "ipython3", 677 | "version": "3.9.16" 678 | } 679 | }, 680 | "nbformat": 4, 681 | "nbformat_minor": 5 682 | } 683 | -------------------------------------------------------------------------------- /CSPML_latest_codes/Create_strcmp_fgp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "23f9ff8c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Load libraries.\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import time\n", 14 | "import pickle\n", 15 | "from matminer.featurizers.site import CrystalNNFingerprint \n", 16 | "from matminer.featurizers.structure import SiteStatsFingerprint\n", 17 | "from KmdPlus import StatsDescriptor, formula_to_composition \n", 18 | "from pymatgen.core.composition import Composition\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from collections import Counter\n", 21 | "# For parallel calculation.\n", 22 | "import joblib\n", 23 | "\n", 24 | "MP_data = pd.read_pickle(\"data_set/paper_used_mp_data_20211107.pd.xz\") # All crystal data from Materials Project.\n", 25 | "test_data = pd.read_pickle(\"data_set/all_searching_targets_20211107_with_predictions.pd.xz\") # Preselected crystal data for testing." 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "id": "ea380c44", 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "text/html": [ 37 | "
\n", 38 | "\n", 51 | "\n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 
327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | "
full_formulacompositioncomposition_ratiototal_atomselementsn_elementsspace_group_numspace_groupwy_cfgwy_reformat...efermifinal_energy_per_atomformation_energy_per_atomhas_bandstructureis_orderedoxide_typepoint_grouppretty_formulatotal_magnetizationvolume
id
mp-1006278Ac1Eu1Au2{'Ac': 1.0, 'Eu': 1.0, 'Au': 2.0}(1.0, 1.0, 2.0)4.0(Ac, Au, Eu)3225Fm-3m{'Ac': {'b': 4}, 'Eu': {'a': 4}, 'Au': {'c': 8}}{'Ac': ('b',), 'Eu': ('a',), 'Au': ('c',)}...4.883417-6.019130-0.776843TrueTrueNonem-3mAcEuAu21.627705117.080578
mp-1017985Ti2Ag2{'Ti': 2.0, 'Ag': 2.0}(2.0, 2.0)4.0(Ag, Ti)2129P4/nmm{'Ti': {'c': 2}, 'Ag': {'c': 2}}{'Ti': ('c',), 'Ag': ('c',)}...4.709549-5.429487-0.065696TrueTrueNone4/mmmTiAg0.00043270.460966
mp-1018128Sc1Ag2{'Sc': 1.0, 'Ag': 2.0}(1.0, 2.0)3.0(Ag, Sc)2139I4/mmm{'Sc': {'a': 2}, 'Ag': {'e': 4}}{'Sc': ('a',), 'Ag': ('e',)}...3.934398-4.301338-0.302162TrueTrueNone4/mmmScAg20.00358457.497334
mp-1018131Lu1Ag2{'Lu': 1.0, 'Ag': 2.0}(1.0, 2.0)3.0(Ag, Lu)2139I4/mmm{'Lu': {'a': 2}, 'Ag': {'e': 4}}{'Lu': ('a',), 'Ag': ('e',)}...3.456485-3.736455-0.341119TrueTrueNone4/mmmLuAg20.00429262.417938
mp-1025059La2Ag4{'La': 2.0, 'Ag': 4.0}(2.0, 4.0)6.0(Ag, La)274Imma{'La': {'e': 4}, 'Ag': {'h': 8}}{'La': ('e',), 'Ag': ('h',)}...5.563882-3.832468-0.298780TrueTrueNonemmmLaAg20.000054150.182757
..................................................................
mp-945077Y2Te6{'Y': 2.0, 'Te': 6.0}(2.0, 6.0)8.0(Te, Y)263Cmcm{'Y': {'c': 4}, 'Te': {'c': 12}}{'Y': ('c',), 'Te': ('c', 'c', 'c')}...5.933911-4.907222-0.933227TrueTrueNonemmmYTe30.000283245.997038
mp-972256Xe1{'Xe': 1.0}(1.0,)1.0(Xe,)1166R-3m{'Xe': {'a': 1}}{'Xe': ('a',)}...-6.965177-0.0361740.000000TrueTrueNone-3mXe0.00000085.786507
mp-972364Yb3{'Yb': 3.0}(3.0,)3.0(Yb,)1166R-3m{'Yb': {'a': 1, 'c': 2}}{'Yb': ('a', 'c')}...1.419946-1.5395950.000000TrueTrueNone-3mYb0.000007123.042457
mp-977585Zr3Tl1{'Zr': 3.0, 'Tl': 1.0}(1.0, 3.0)4.0(Tl, Zr)2221Pm-3m{'Zr': {'c': 3}, 'Tl': {'a': 1}}{'Zr': ('c',), 'Tl': ('a',)}...5.632566-7.113048-0.111859TrueTrueNonem-3mZr3Tl0.00202590.588661
mp-9948V5Te4{'V': 5.0, 'Te': 4.0}(4.0, 5.0)9.0(Te, V)287I4/m{'V': {'a': 2, 'h': 8}, 'Te': {'h': 8}}{'V': ('a', 'h'), 'Te': ('h',)}...7.122283-6.749386-0.306610TrueTrueNone4/mV5Te40.837220186.898902
\n", 369 | "

33064 rows × 33 columns

\n", 370 | "
" 371 | ], 372 | "text/plain": [ 373 | " full_formula composition composition_ratio \\\n", 374 | "id \n", 375 | "mp-1006278 Ac1Eu1Au2 {'Ac': 1.0, 'Eu': 1.0, 'Au': 2.0} (1.0, 1.0, 2.0) \n", 376 | "mp-1017985 Ti2Ag2 {'Ti': 2.0, 'Ag': 2.0} (2.0, 2.0) \n", 377 | "mp-1018128 Sc1Ag2 {'Sc': 1.0, 'Ag': 2.0} (1.0, 2.0) \n", 378 | "mp-1018131 Lu1Ag2 {'Lu': 1.0, 'Ag': 2.0} (1.0, 2.0) \n", 379 | "mp-1025059 La2Ag4 {'La': 2.0, 'Ag': 4.0} (2.0, 4.0) \n", 380 | "... ... ... ... \n", 381 | "mp-945077 Y2Te6 {'Y': 2.0, 'Te': 6.0} (2.0, 6.0) \n", 382 | "mp-972256 Xe1 {'Xe': 1.0} (1.0,) \n", 383 | "mp-972364 Yb3 {'Yb': 3.0} (3.0,) \n", 384 | "mp-977585 Zr3Tl1 {'Zr': 3.0, 'Tl': 1.0} (1.0, 3.0) \n", 385 | "mp-9948 V5Te4 {'V': 5.0, 'Te': 4.0} (4.0, 5.0) \n", 386 | "\n", 387 | " total_atoms elements n_elements space_group_num \\\n", 388 | "id \n", 389 | "mp-1006278 4.0 (Ac, Au, Eu) 3 225 \n", 390 | "mp-1017985 4.0 (Ag, Ti) 2 129 \n", 391 | "mp-1018128 3.0 (Ag, Sc) 2 139 \n", 392 | "mp-1018131 3.0 (Ag, Lu) 2 139 \n", 393 | "mp-1025059 6.0 (Ag, La) 2 74 \n", 394 | "... ... ... ... ... \n", 395 | "mp-945077 8.0 (Te, Y) 2 63 \n", 396 | "mp-972256 1.0 (Xe,) 1 166 \n", 397 | "mp-972364 3.0 (Yb,) 1 166 \n", 398 | "mp-977585 4.0 (Tl, Zr) 2 221 \n", 399 | "mp-9948 9.0 (Te, V) 2 87 \n", 400 | "\n", 401 | " space_group wy_cfg \\\n", 402 | "id \n", 403 | "mp-1006278 Fm-3m {'Ac': {'b': 4}, 'Eu': {'a': 4}, 'Au': {'c': 8}} \n", 404 | "mp-1017985 P4/nmm {'Ti': {'c': 2}, 'Ag': {'c': 2}} \n", 405 | "mp-1018128 I4/mmm {'Sc': {'a': 2}, 'Ag': {'e': 4}} \n", 406 | "mp-1018131 I4/mmm {'Lu': {'a': 2}, 'Ag': {'e': 4}} \n", 407 | "mp-1025059 Imma {'La': {'e': 4}, 'Ag': {'h': 8}} \n", 408 | "... ... ... \n", 409 | "mp-945077 Cmcm {'Y': {'c': 4}, 'Te': {'c': 12}} \n", 410 | "mp-972256 R-3m {'Xe': {'a': 1}} \n", 411 | "mp-972364 R-3m {'Yb': {'a': 1, 'c': 2}} \n", 412 | "mp-977585 Pm-3m {'Zr': {'c': 3}, 'Tl': {'a': 1}} \n", 413 | "mp-9948 I4/m {'V': {'a': 2, 'h': 8}, 'Te': {'h': 8}} \n", 414 | "\n", 415 | " wy_reformat ... efermi \\\n", 416 | "id ... \n", 417 | "mp-1006278 {'Ac': ('b',), 'Eu': ('a',), 'Au': ('c',)} ... 4.883417 \n", 418 | "mp-1017985 {'Ti': ('c',), 'Ag': ('c',)} ... 4.709549 \n", 419 | "mp-1018128 {'Sc': ('a',), 'Ag': ('e',)} ... 3.934398 \n", 420 | "mp-1018131 {'Lu': ('a',), 'Ag': ('e',)} ... 3.456485 \n", 421 | "mp-1025059 {'La': ('e',), 'Ag': ('h',)} ... 5.563882 \n", 422 | "... ... ... ... \n", 423 | "mp-945077 {'Y': ('c',), 'Te': ('c', 'c', 'c')} ... 5.933911 \n", 424 | "mp-972256 {'Xe': ('a',)} ... -6.965177 \n", 425 | "mp-972364 {'Yb': ('a', 'c')} ... 1.419946 \n", 426 | "mp-977585 {'Zr': ('c',), 'Tl': ('a',)} ... 5.632566 \n", 427 | "mp-9948 {'V': ('a', 'h'), 'Te': ('h',)} ... 7.122283 \n", 428 | "\n", 429 | " final_energy_per_atom formation_energy_per_atom has_bandstructure \\\n", 430 | "id \n", 431 | "mp-1006278 -6.019130 -0.776843 True \n", 432 | "mp-1017985 -5.429487 -0.065696 True \n", 433 | "mp-1018128 -4.301338 -0.302162 True \n", 434 | "mp-1018131 -3.736455 -0.341119 True \n", 435 | "mp-1025059 -3.832468 -0.298780 True \n", 436 | "... ... ... ... 
\n", 437 | "mp-945077 -4.907222 -0.933227 True \n", 438 | "mp-972256 -0.036174 0.000000 True \n", 439 | "mp-972364 -1.539595 0.000000 True \n", 440 | "mp-977585 -7.113048 -0.111859 True \n", 441 | "mp-9948 -6.749386 -0.306610 True \n", 442 | "\n", 443 | " is_ordered oxide_type point_group pretty_formula \\\n", 444 | "id \n", 445 | "mp-1006278 True None m-3m AcEuAu2 \n", 446 | "mp-1017985 True None 4/mmm TiAg \n", 447 | "mp-1018128 True None 4/mmm ScAg2 \n", 448 | "mp-1018131 True None 4/mmm LuAg2 \n", 449 | "mp-1025059 True None mmm LaAg2 \n", 450 | "... ... ... ... ... \n", 451 | "mp-945077 True None mmm YTe3 \n", 452 | "mp-972256 True None -3m Xe \n", 453 | "mp-972364 True None -3m Yb \n", 454 | "mp-977585 True None m-3m Zr3Tl \n", 455 | "mp-9948 True None 4/m V5Te4 \n", 456 | "\n", 457 | " total_magnetization volume \n", 458 | "id \n", 459 | "mp-1006278 1.627705 117.080578 \n", 460 | "mp-1017985 0.000432 70.460966 \n", 461 | "mp-1018128 0.003584 57.497334 \n", 462 | "mp-1018131 0.004292 62.417938 \n", 463 | "mp-1025059 0.000054 150.182757 \n", 464 | "... ... ... \n", 465 | "mp-945077 0.000283 245.997038 \n", 466 | "mp-972256 0.000000 85.786507 \n", 467 | "mp-972364 0.000007 123.042457 \n", 468 | "mp-977585 0.002025 90.588661 \n", 469 | "mp-9948 0.837220 186.898902 \n", 470 | "\n", 471 | "[33064 rows x 33 columns]" 472 | ] 473 | }, 474 | "execution_count": 2, 475 | "metadata": {}, 476 | "output_type": "execute_result" 477 | } 478 | ], 479 | "source": [ 480 | "# Exclude all formula in test data from MP data.\n", 481 | "MP_data_left = MP_data[np.invert(MP_data.pretty_formula.isin(test_data.pretty_formula))]\n", 482 | "# Get stable data.\n", 483 | "MP_stable = MP_data_left[MP_data_left.e_above_hull.values == 0]\n", 484 | "# Delete overlapping formula in stable data.\n", 485 | "count = Counter(MP_stable.pretty_formula).most_common()\n", 486 | "keys = np.array([count[i][0] for i in range(len(count))])\n", 487 | "freqs = np.array([count[i][1] for i in range(len(count))])\n", 488 | "overlapping_formulas = keys[freqs>1]\n", 489 | "\n", 490 | "excl_ids = []\n", 491 | "for i in range(len(overlapping_formulas)):\n", 492 | " x = MP_stable[MP_stable.pretty_formula.values == overlapping_formulas[i]]\n", 493 | " x_sorted = x.sort_values(\"final_energy_per_atom\")\n", 494 | " excl_ids.append(np.asarray(x_sorted.index[1:]))\n", 495 | " \n", 496 | "MP_stable = MP_stable[np.invert(MP_stable.index.isin(np.concatenate(excl_ids)))]\n", 497 | "MP_stable" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": 3, 503 | "id": "e53575b1", 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "# Calculate the local order parameter fingerprints for all stable structures (DOI: 10.3389/fmats.2017.00034.).\n", 508 | "structures = MP_stable.structure.values\n", 509 | "\n", 510 | "# Site featurizer.\n", 511 | "cnnf = CrystalNNFingerprint.from_preset('ops', distance_cutoffs=None, x_diff_weight=0)\n", 512 | "\n", 513 | "def parallel_cnnf(featurizer, str_x):\n", 514 | " return np.array(joblib.Parallel(n_jobs=-1)(joblib.delayed(featurizer)(str_x, i) for i in range(len(str_x.sites))))\n", 515 | "\n", 516 | "# SiteStats.\n", 517 | "def SiteStats(site_fgps):\n", 518 | " return np.array([site_fgps.mean(0), site_fgps.std(0), site_fgps.min(0), site_fgps.max(0)]).T.flatten()" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 4, 524 | "id": "0764c5f3", 525 | "metadata": {}, 526 | "outputs": [ 527 | { 528 | "name": "stdout", 529 | "output_type": "stream", 530 | "text": [ 
531 | "time: 4228.889889001846\n", 532 | "time per iteration: 0.12790012971817827\n", 533 | "(33064, 244)\n" 534 | ] 535 | } 536 | ], 537 | "source": [ 538 | "# Calculate structure fingerprints for all stable data.\n", 539 | "n_iter = len(structures)\n", 540 | "\n", 541 | "strfgp_stable = []\n", 542 | "errors_i = []\n", 543 | "\n", 544 | "s = time.time()\n", 545 | "\n", 546 | "for i in range(n_iter):\n", 547 | " str_x = structures[i] # ith str.\n", 548 | " \n", 549 | " try:\n", 550 | " strfgp_stable.append(SiteStats(parallel_cnnf(cnnf.featurize, str_x))) # site fgps for the ith str.\n", 551 | " \n", 552 | " except:\n", 553 | " strfgp_stable.append(\"NA\")\n", 554 | " errors_i.append(i)\n", 555 | " print(f\"error at {i}\")\n", 556 | " \n", 557 | "e = time.time()\n", 558 | "print(f\"time: {e-s}\")\n", 559 | "print(f\"time per iteration: {(e-s)/n_iter}\")\n", 560 | "\n", 561 | "# Save results.\n", 562 | "strfgp_stable_array = np.array(strfgp_stable)\n", 563 | "\n", 564 | "print(strfgp_stable_array.shape)\n", 565 | "\n", 566 | "np.save('data_set/strfgp_stable_20211107', strfgp_stable_array)" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 5, 572 | "id": "3ec28d65", 573 | "metadata": {}, 574 | "outputs": [ 575 | { 576 | "name": "stdout", 577 | "output_type": "stream", 578 | "text": [ 579 | "(33064, 290)\n" 580 | ] 581 | } 582 | ], 583 | "source": [ 584 | "# Calculate fingerprints for chemical compositions (five statistics of element_features).\n", 585 | "\n", 586 | "# Element-level descriptors of shape (94, 58).\n", 587 | "element_features = pd.read_csv(\"data_set/element_features.csv\", index_col= 0)\n", 588 | "\n", 589 | "cmpfgp_stable_array = StatsDescriptor(MP_stable.pretty_formula.values, element_features)\n", 590 | "\n", 591 | "# Save results.\n", 592 | "print(cmpfgp_stable_array.shape)\n", 593 | "\n", 594 | "np.save('data_set/cmpfgp_stable_20211107', cmpfgp_stable_array)" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": 6, 600 | "id": "04b2ca14", 601 | "metadata": {}, 602 | "outputs": [ 603 | { 604 | "name": "stdout", 605 | "output_type": "stream", 606 | "text": [ 607 | "time: 20.505083799362183\n", 608 | "time per iteration: 0.22783426443735758\n", 609 | "(90, 244)\n" 610 | ] 611 | } 612 | ], 613 | "source": [ 614 | "# Calculate structure fingerprints for test data.\n", 615 | "structures = test_data.structure.values\n", 616 | "\n", 617 | "n_iter = len(structures)\n", 618 | "\n", 619 | "strfgp_stable = []\n", 620 | "errors_i = []\n", 621 | "\n", 622 | "s = time.time()\n", 623 | "\n", 624 | "for i in range(n_iter):\n", 625 | " str_x = structures[i] # ith str.\n", 626 | " \n", 627 | " try:\n", 628 | " strfgp_stable.append(SiteStats(parallel_cnnf(cnnf.featurize, str_x))) # site fgps for the ith str.\n", 629 | " \n", 630 | " except:\n", 631 | " strfgp_stable.append(\"NA\")\n", 632 | " errors_i.append(i)\n", 633 | " print(f\"error at {i}\")\n", 634 | " \n", 635 | "e = time.time()\n", 636 | "print(f\"time: {e-s}\")\n", 637 | "print(f\"time per iteration: {(e-s)/n_iter}\")\n", 638 | "\n", 639 | "# Save results.\n", 640 | "strfgp_stable_array = np.array(strfgp_stable)\n", 641 | "\n", 642 | "print(strfgp_stable_array.shape)\n", 643 | "\n", 644 | "np.save('data_set/strfgp_test_20211107', strfgp_stable_array)" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "id": "03d767b5", 651 | "metadata": {}, 652 | "outputs": [], 653 | "source": [] 654 | } 655 | ], 656 | "metadata": { 657 | "kernelspec": { 658 | 
"display_name": "Python 3 (ipykernel)", 659 | "language": "python", 660 | "name": "python3" 661 | }, 662 | "language_info": { 663 | "codemirror_mode": { 664 | "name": "ipython", 665 | "version": 3 666 | }, 667 | "file_extension": ".py", 668 | "mimetype": "text/x-python", 669 | "name": "python", 670 | "nbconvert_exporter": "python", 671 | "pygments_lexer": "ipython3", 672 | "version": "3.9.16" 673 | } 674 | }, 675 | "nbformat": 4, 676 | "nbformat_minor": 5 677 | } 678 | -------------------------------------------------------------------------------- /CSPML_latest_codes/KmdPlus.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: Minoru Kusaba (ISM, kusaba@ism.ac.jp) 3 | # last update: 2022/07/31 4 | # minor update: 2023/07/21 5 | 6 | """ 7 | This module contains a class for treating Kernel mean descriptor (KMD), 8 | and a function for generating descriptors with summary statistics. 9 | """ 10 | 11 | # Import libraries. 12 | import pandas as pd 13 | import numpy as np 14 | from statistics import median 15 | from scipy.spatial import distance_matrix 16 | from pymatgen.core.composition import Composition 17 | from qpsolvers import solve_qp 18 | 19 | # Load preset data. 20 | element_features = pd.read_csv("data_set/element_features.csv", index_col= 0) # element-level descriptors of shape (94, 58). 21 | elements = list(element_features.index) # 94 elements, "H" ~ "Pu". 22 | 23 | # Define functions. 24 | def formula_to_composition(formula, elements = elements): 25 | """ 26 | Convert a chemical formula to a composition vector for the predifened elements. 27 | 28 | Args 29 | ---- 30 | formula: str 31 | Chemical formula (e.g. "SiO2"). 32 | elements: a list of str 33 | Chemical elements (e.g. ["H", "He", ...]). 34 | Returns 35 | ---- 36 | vec: numpy.ndarray of shape (len(elements),). 37 | """ 38 | comp = Composition(formula) 39 | vec = np.array([comp.get_atomic_fraction(elements[i]) for i in range(len(elements))]) 40 | return vec 41 | 42 | class KMD(): 43 | """ 44 | Kernel mean descriptor (KMD). 45 | """ 46 | def __init__(self, method = "1d"): 47 | """ 48 | Parameters 49 | ---- 50 | method: str, default = "1d" 51 | method must be "md" or "1d". 52 | For "md", KMD is generated on a multidimensional feature space. 53 | For "1d", KMD is generated for each feature, then combined. 54 | ---- 55 | """ 56 | self.method = method 57 | 58 | def transform(self, weight, component_features, n_grids = None, sigma = "auto", scale = True): 59 | """ 60 | Generate kernel mean descriptor (KMD) with the Gaussian kernel (materials → descriptors). 61 | 62 | Args 63 | ---- 64 | weight: array-like of shape (n_samples, n_components) 65 | Mixing ratio of constituent elements that make up each sample. 66 | component_features: array-like of shape (n_components, n_features) 67 | Features for each constituent element. 68 | n_grids: int, default = None 69 | The number of grids for discretizing the kernel mean. 70 | The kernel mean is discretized at the n_grids equally spaced grids 71 | between a maximum and minimum values for each feature. 72 | This argument is only necessary for "1d". 73 | sigma: str or float, default = "auto" 74 | A hyper parameter defines the kernel width. 75 | If sima = "auto", the kernel width is given as the inverse median of the nearest distances 76 | for "md", and as the inverse of the grid width for "1d". 77 | scale: bool, default = True 78 | IF scale = True, component_features is scaled. 
79 | Returns 80 | ---- 81 | KMD: numpy array of shape (n_samples, n_components) for "md", and (n_samples, n_features*n_grids) for "1d". 82 | """ 83 | self.component_features = component_features 84 | self.sigma = sigma 85 | self.scale = scale 86 | # Generate KMD on a multidimensional feature space. 87 | if self.method == "md": 88 | 89 | # Standardize each feature to have mean 0 and variance 1 (for "md"). 90 | if scale == True: 91 | component_features = (component_features - component_features.mean(axis=0))/component_features.std(axis=0, ddof=1) 92 | else: 93 | pass 94 | 95 | # Set the kernel width as the inverse median of the nearest distances. 96 | if sigma == "auto": 97 | d = distance_matrix(component_features, component_features)**2 98 | min_dist = [np.sort(d[i,:])[1] for i in range(component_features.shape[0])] # the nearest distances 99 | gamma = 1/median(min_dist) 100 | kernelized_component_features = np.exp(-d * gamma) 101 | KMD = np.dot(weight, kernelized_component_features) 102 | return KMD 103 | 104 | # Manually set the kernel width. 105 | else: 106 | d = distance_matrix(component_features, component_features)**2 107 | kernelized_component_features = np.exp(-d/(2*sigma**2)) 108 | KMD = np.dot(weight, kernelized_component_features) 109 | return KMD 110 | 111 | # Generate KMD for each feature, then combine them. 112 | elif self.method == "1d": 113 | 114 | if n_grids == None: 115 | print('For self.method = "1d", please set n_grids') 116 | return 117 | else: 118 | pass 119 | 120 | # Min-Max Normalization (for "1d"). 121 | if scale == True: 122 | component_features = (component_features - component_features.min(axis=0))/(component_features.max(axis=0) - component_features.min(axis=0)) 123 | else: 124 | pass 125 | 126 | # Set the kernel width as the inverse of the grid width. 127 | if sigma == "auto": 128 | max_cf = component_features.max(axis=0) 129 | min_cf = component_features.min(axis=0) 130 | x = np.asarray(component_features) 131 | k = [] 132 | for i in range(component_features.shape[1]): 133 | grid_points = np.linspace(min_cf[i], max_cf[i], n_grids) 134 | gamma = 1/(grid_points[1] - grid_points[0])**2 135 | d = np.array([(x[j,i] - grid_points)**2 for j in range(x.shape[0])]) 136 | k.append(np.exp(-d*gamma)) 137 | kernelized_component_features = np.concatenate(k, axis=1) 138 | KMD = np.dot(weight, kernelized_component_features) 139 | return KMD 140 | 141 | # Manually set the kernel width. 142 | else: 143 | max_cf = component_features.max(axis=0) 144 | min_cf = component_features.min(axis=0) 145 | x = np.asarray(component_features) 146 | k = [] 147 | for i in range(component_features.shape[1]): 148 | grid_points = np.linspace(min_cf[i], max_cf[i], n_grids) 149 | d = np.array([(x[j,i] - grid_points)**2 for j in range(x.shape[0])]) 150 | k.append(np.exp(-d/(2*sigma**2))) 151 | kernelized_component_features = np.concatenate(k, axis=1) 152 | KMD = np.dot(weight, kernelized_component_features) 153 | return KMD 154 | else: 155 | print('self.method must be "md" or "1d"') 156 | 157 | def inverse_transform(self, KMD): 158 | """ 159 | Derive the weights of the constituent elements for a given kernel mean descriptors 160 | by solving a quadratic programming (descriptors → materials). 161 | 162 | Args 163 | ---- 164 | KMD: array-like of shape (n_samples, n_components) for "md", (n_samples, n_features*n_grids) for "1d". 165 | Kernel mean descriptor (KMD). 166 | Returns 167 | ---- 168 | weight: numpy array of shape (n_samples, n_components). 
169 | """ 170 | component_features = self.component_features 171 | sigma = self.sigma 172 | scale = self.scale 173 | if self.method == "md": 174 | 175 | # Standardize each feature to have mean 0 and variance 1 (for "md"). 176 | if scale == True: 177 | component_features = (component_features - component_features.mean(axis=0))/component_features.std(axis=0, ddof=1) 178 | else: 179 | pass 180 | 181 | KMD = np.asarray(KMD) 182 | n_components = KMD.shape[1] 183 | 184 | # Set the kernel width as the inverse median of the nearest distances. 185 | if sigma == "auto": 186 | d = distance_matrix(component_features, component_features)**2 187 | min_dist = [np.sort(d[i,:])[1] for i in range(component_features.shape[0])] # the nearest distances 188 | gamma = 1/median(min_dist) 189 | kernelized_component_features = np.exp(-d * gamma) 190 | P = np.dot(kernelized_component_features, kernelized_component_features.T) 191 | if min(np.linalg.eigvals(P)) <= 0: 192 | print("Given KMD is not inversible: smaller sigma may solve the problem") 193 | return 194 | else: 195 | pass 196 | # Equality constraints. 197 | A = np.ones(P.shape[0]) 198 | b = np.array([1.]) 199 | # Inequality constraints. 200 | G = np.diag(-A) 201 | h = np.zeros(P.shape[0]) 202 | # Solve quadratic programming. 203 | w_raw = np.array([solve_qp(P, -np.dot(kernelized_component_features, KMD[i]) 204 | , G, h, A, b, solver="quadprog") for i in range(KMD.shape[0])]) 205 | w = np.round(abs(w_raw), 12) 206 | weight = w/w.sum(axis=1)[:, None] 207 | return weight 208 | 209 | # Manually set the kernel width. 210 | else: 211 | d = distance_matrix(component_features, component_features)**2 212 | kernelized_component_features = np.exp(-d/(2*sigma**2)) 213 | P = np.dot(kernelized_component_features, kernelized_component_features.T) 214 | if min(np.linalg.eigvals(P)) <= 0: 215 | print("Given KMD is not inversible: smaller sigma may solve the problem") 216 | return 217 | else: 218 | pass 219 | # Equality constraints. 220 | A = np.ones(P.shape[0]) 221 | b = np.array([1.]) 222 | # Inequality constraints. 223 | G = np.diag(-A) 224 | h = np.zeros(P.shape[0]) 225 | # Solve quadratic programming. 226 | w_raw = np.array([solve_qp(P, -np.dot(kernelized_component_features, KMD[i]) 227 | , G, h, A, b, solver="quadprog") for i in range(KMD.shape[0])]) 228 | w = np.round(abs(w_raw), 12) 229 | weight = w/w.sum(axis=1)[:, None] 230 | return weight 231 | 232 | elif self.method == "1d": 233 | 234 | KMD = np.asarray(KMD) 235 | n_grids = int(KMD.shape[1]/component_features.shape[1]) 236 | 237 | # Min-Max Normalization (for "1d"). 238 | if scale == True: 239 | component_features = (component_features - component_features.min(axis=0))/(component_features.max(axis=0) - component_features.min(axis=0)) 240 | else: 241 | pass 242 | 243 | # Set the kernel width as the inverse of the grid width. 
244 | if sigma == "auto": 245 | max_cf = component_features.max(axis=0) 246 | min_cf = component_features.min(axis=0) 247 | x = np.asarray(component_features) 248 | k = [] 249 | for i in range(component_features.shape[1]): 250 | grid_points = np.linspace(min_cf[i], max_cf[i], n_grids) 251 | gamma = 1/(grid_points[1] - grid_points[0])**2 252 | d = np.array([(x[j,i] - grid_points)**2 for j in range(x.shape[0])]) 253 | k.append(np.exp(-d*gamma)) 254 | kernelized_component_features = np.concatenate(k, axis=1) 255 | P = np.dot(kernelized_component_features, kernelized_component_features.T) 256 | if min(np.linalg.eigvals(P)) <= 0: 257 | print("Given KMD is not inversible: consider increasing the number of grids (n_grids)") 258 | return 259 | else: 260 | pass 261 | # Equality constraints. 262 | A = np.ones(P.shape[0]) 263 | b = np.array([1.]) 264 | # Inequality constraints. 265 | G = np.diag(-A) 266 | h = np.zeros(P.shape[0]) 267 | # Solve quadratic programming. 268 | w_raw = np.array([solve_qp(P, -np.dot(kernelized_component_features, KMD[i]) 269 | , G, h, A, b, solver="quadprog") for i in range(KMD.shape[0])]) 270 | w = np.round(abs(w_raw), 12) 271 | weight = w/w.sum(axis=1)[:, None] 272 | return weight 273 | 274 | # Manually set the kernel width. 275 | else: 276 | max_cf = component_features.max(axis=0) 277 | min_cf = component_features.min(axis=0) 278 | x = np.asarray(component_features) 279 | k = [] 280 | for i in range(component_features.shape[1]): 281 | grid_points = np.linspace(min_cf[i], max_cf[i], n_grids) 282 | d = np.array([(x[j,i] - grid_points)**2 for j in range(x.shape[0])]) 283 | k.append(np.exp(-d/(2*sigma**2))) 284 | kernelized_component_features = np.concatenate(k, axis=1) 285 | P = np.dot(kernelized_component_features, kernelized_component_features.T) 286 | if min(np.linalg.eigvals(P)) <= 0: 287 | print("Given KMD is not inversible: consider increasing the number of grids (n_grids)") 288 | return 289 | else: 290 | pass 291 | # Equality constraints. 292 | A = np.ones(P.shape[0]) 293 | b = np.array([1.]) 294 | # Inequality constraints. 295 | G = np.diag(-A) 296 | h = np.zeros(P.shape[0]) 297 | # Solve quadratic programming. 298 | w_raw = np.array([solve_qp(P, -np.dot(kernelized_component_features, KMD[i]) 299 | , G, h, A, b, solver="quadprog") for i in range(KMD.shape[0])]) 300 | w = np.round(abs(w_raw), 12) 301 | weight = w/w.sum(axis=1)[:, None] 302 | return weight 303 | 304 | else: 305 | print('self.method must be "md" or "1d"') 306 | 307 | def StatsDescriptor(formula, component_features, stats = ["mean", "sum", "var", "max", "min"]): 308 | """ 309 | Generate descriptors for mixture systems using summary statistics. 310 | 311 | Args 312 | ---- 313 | weight: array-like of shape (n_samples, n_components) 314 | Mixing ratio of constituent elements that make up each sample. 315 | component_features: array-like of shape (n_components, n_features) 316 | Features for each constituent element. 317 | stats: a list of str, default = ["mean", "sum", "var", "max", "min"] 318 | Type of summary statistics for generating descriptors. 319 | Only "mean", "sum", "var", "max" and "min" are supported. 320 | Returns 321 | ---- 322 | SD: numpy array of shape (n_samples, n_features*len(stats)). 323 | """ 324 | n_samples = len(formula) 325 | # Get comp weight. 326 | w = np.array([formula_to_composition(formula[i]) for i in range(n_samples)]) 327 | # as array. 328 | cf = np.asarray(component_features) 329 | 330 | s = [] 331 | for x in stats: 332 | # Weighted mean. 
333 | if x == "mean": 334 | wm = np.dot(w, cf) 335 | s.append(wm) 336 | # Weighted mean. 337 | elif x == "sum": 338 | wm = np.dot(w, cf) 339 | n_atoms = np.array([Composition(formula[i]).num_atoms for i in range(n_samples)]) # only for sum. 340 | n_atoms_array = np.array([n_atoms for i in range(wm.shape[1])]).T 341 | s.append(wm * n_atoms_array) 342 | # Weighted variance. 343 | elif x == "var": 344 | wm = np.dot(w, cf) 345 | wv = np.array([np.dot(w[i], (cf - wm[i])**2) for i in range(n_samples)]) 346 | s.append(wv) 347 | # Maximum pooling. 348 | elif x == "max": 349 | nonzero = (w != 0) 350 | maxp = np.array([cf[nonzero[i]].max(axis = 0) for i in range(n_samples)]) 351 | s.append(maxp) 352 | # Minimum pooling. 353 | elif x == "min": 354 | nonzero = (w != 0) 355 | minp = np.array([cf[nonzero[i]].min(axis = 0) for i in range(n_samples)]) 356 | s.append(minp) 357 | else: 358 | print(f'"{x}" is not supported: only "mean", "var", "max" and "min" are supported as stats') 359 | 360 | SD = np.concatenate(s, axis = 1) 361 | return SD -------------------------------------------------------------------------------- /CSPML_latest_codes/cif_files_for_90crystals.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minoru938/CSPML/4410c98b55c82b1f5d00bd581676281889f87cd3/CSPML_latest_codes/cif_files_for_90crystals.zip -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/CSML_90_opt_results.pd.xz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d125d085a152133e93edde640030a7d2564ef3be87302b79dfca93a08913bb5c 3 | size 234768 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/CSPML_models.xz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:84c6ba963f32a0af2ed5f07731e8440c42dae177b23eaa9e4f5696dc6cf60ec0 3 | size 28262838 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/MP_stable_20211107.pd.xz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6b3a8cdf2db483c867ab9e99cec9188ac79192d85503fae8ec98e812c8a1aa87 3 | size 78013796 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/all_searching_targets_20211107_with_predictions.pd.xz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ad4edf58a907dfa496ca8e39c9439789d4331ceec3e89feee6a809007637f324 3 | size 89148 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/cmpfgp_stable_meanstd_20211107.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:57d65d65f91a54b27dd6c248bebc241d3ee1ae56f9863cd3f18c0961639a7636 3 | size 4768 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/element_dissimilarity.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:da2af89e31e94d6df9cd88d856e035a2a699271c998f193d7a0fab23eb3912c4 3 | size 70816 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/element_features.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:06a2c633a3b569bb5ec7a4a47294c59b7858742abbde74d5571c80cc09ed8f31 3 | size 37527 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/paper_used_mp_data_20211107.pd.xz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7cf23bfea7a3ac1dea3078d81ef0e932735146bda54f2e79f8f9ff91c575cbfc 3 | size 72937780 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/data_set/strfgp_test_20211107.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8ff5d1e0c7fbba15b293e8b2b0c1baf76591d4b957354350258dfab4820d4f8d 3 | size 175808 4 | -------------------------------------------------------------------------------- /CSPML_latest_codes/readme.txt: -------------------------------------------------------------------------------- 1 |  2 | About this file: 3 | 4 | This file contains the latest version of CSPML including training codes. This code corresponds to the result of the paper "Shotgun crystal structure prediction using machine-learned formation energies" (https://arxiv.org/abs/2305.02158). See "Details of the CSPML model" section in supplementary information of the paper for details. 5 | 6 | 7 | How to build a conda environment for CSPML: 8 | 9 | 1. cd into this directory. 10 | 2. Build a conda environment from CSPML.yml by conda env create -n CSPML -f CSPML.yml 11 | 12 | 13 | Usage: 14 | 15 | ・ To immediately reproduce the crystal structure prediction results reported in the paper, run "CSPML_Structure_Prediction.ipynb" in Jupyter Notebook. You should get the same prediction results as those contained in "cif_files_for_90crystals.zip". 16 | 17 | ・ If you want to start with training the model, run "Create_strcmp_fgp.ipynb" → "CSPML_Creating_MLdata.ipynb" → "CSPML_training.ipynb" → "CSPML_Structure_Prediction.ipynb " in that order. 18 | 19 | ################################################################ 20 | # If the yml file does not work properly, please refer to the following to build the environment manually 21 | 22 | Dependencies: 23 | 24 | pandas version = 1.5.1 25 | numpy version = 1.22.4 26 | tensorflow-macos version = 2.9.0 27 | tensorflow-metal = 0.5.1 28 | pymatgen version = 2022.5.26 29 | matminer version = 0.7.8 30 | scipy version == 1.8.1 31 | joblib version == 1.2.0 32 | matplotlib version == 3.7.1 33 | scikit-learn version == 1.1.3 34 | keras version == 2.9.0 35 | optuna version == 3.0.3 36 | qpsolvers version == 2.6.0 # peer dependency for KmdPlus.py. 
37 | 38 | Environment of author: 39 | Python 3.9.16 40 | macOS Ventura 13.4.1 41 | -------------------------------------------------------------------------------- /CSPML_latest_codes/tools.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Divide a dataset into training, validation, test set by random sampling (randomness included).\n", 10 | "def divide_dataset(N, test_prop = 0.1, val_prop = 0.1):\n", 11 | " \n", 12 | " # Get sample numbers for test, validation, and training.\n", 13 | " test_N = round(N * test_prop)\n", 14 | " val_N = round(N * val_prop)\n", 15 | " train_N = N - test_N - val_N\n", 16 | "\n", 17 | " # Sampling test index.\n", 18 | " test_index = np.sort(np.random.choice(N, test_N, replace=False))\n", 19 | "\n", 20 | " # Get trainval index by removing test index from all index.\n", 21 | " trainval_index = np.setdiff1d(np.array(range(N)), test_index)\n", 22 | "\n", 23 | " # Sampling validation index.\n", 24 | " val_index = np.sort(np.random.choice(trainval_index, val_N, replace = False))\n", 25 | "\n", 26 | " # Get train index by removing val index from trainval index.\n", 27 | " train_index = np.setdiff1d(trainval_index, val_index)\n", 28 | "\n", 29 | " # Summarize as dictionary.\n", 30 | " result = {\"train_ids\":train_index,\"val_ids\":val_index,\"test_ids\":test_index}\n", 31 | "\n", 32 | " return result\n", 33 | "\n", 34 | "# Divide a dataset into training and test set by random sampling (randomness included).\n", 35 | "def train_test_split(N, test_prop = 0.2):\n", 36 | " \n", 37 | " # Get sample sizes for training and test.\n", 38 | " test_N = round(N * test_prop)\n", 39 | " train_N = N - test_N\n", 40 | " \n", 41 | " # Sampling test index.\n", 42 | " test_index = np.sort(np.random.choice(N, test_N, replace=False))\n", 43 | " \n", 44 | " # Get train index by removing test index.\n", 45 | " train_index = np.setdiff1d(np.array(range(N)), test_index)\n", 46 | " \n", 47 | " result = {\"train_ids\":train_index, \"test_ids\":test_index}\n", 48 | " return result\n", 49 | "\n", 50 | "# Get a sub-dataset of the given X, and Y by bagging.\n", 51 | "def bagging(X, Y):\n", 52 | " \n", 53 | " N = X.shape[0]\n", 54 | " ids = np.random.choice(N, N, replace=True) # bagging.\n", 55 | " \n", 56 | " return X[ids], Y[ids]\n", 57 | "\n", 58 | "# Get a sub-dataset of the given X, and Y by pasting.\n", 59 | "def pasting(X, Y, prop = 0.6):\n", 60 | " \n", 61 | " N = X.shape[0]\n", 62 | " sample_N = round(N * prop)\n", 63 | " \n", 64 | " ids = np.random.choice(N, sample_N, replace=False) # pasting.\n", 65 | " \n", 66 | " return X[ids], Y[ids]" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": "Python 3 (ipykernel)", 73 | "language": "python", 74 | "name": "python3" 75 | }, 76 | "language_info": { 77 | "codemirror_mode": { 78 | "name": "ipython", 79 | "version": 3 80 | }, 81 | "file_extension": ".py", 82 | "mimetype": "text/x-python", 83 | "name": "python", 84 | "nbconvert_exporter": "python", 85 | "pygments_lexer": "ipython3", 86 | "version": "3.9.16" 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 4 91 | } 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Minoru Kusaba 4 | 5 | Permission is hereby granted, 
free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CSPML (crystal structure prediction with machine learning-based element substitution) 2 | 3 | CSPML is a unique methodology for the crystal structure prediction (CSP) that relies on a machine learning algorithm (binary classification neural network model). CSPML predicts a stable structure 4 | for any given query composition, by automatically selecting from a crystal structure database a set of template crystals with nearly identical stable structures to which atomic substitution is to 5 | be applied. Pre-trained models are used to select the template crystals. The 33,153 stable compounds (all candidate crystals; obtained from the [Materials Project](https://materialsproject.org) database) and the pre-trained models are embedded in CSPML. 6 | 7 | For more details, please see our paper: 8 | [Crystal structure prediction with machine learning-based element substitution](https://doi.org/10.1016/j.commatsci.2022.111496) (Accepted 3 May 2022). 9 | 10 | # Dependencies 11 | 12 | * pandas version = 1.3.3 13 | * numpy version = 1.19.2 # tensorflow is compatible with numpy=<1.19.2 (01/14/2022). 14 | * tensorflow version = 2.6.0 15 | * pymatgen version = 2020.1.28 16 | * xenonpy version = 0.4.2 (see [this page](https://xenonpy.readthedocs.io/en/latest/installation.html) for installation) 17 | * torch version = 1.10.0 # peer dependency for xenonpy. 18 | * matminer version = 0.6.2 (optional; for calculating the structure fingerprint with [local structure order parameters](https://pubs.rsc.org/en/content/articlelanding/2020/ra/c9ra07755c)) 19 | 20 | # Usage 21 | 22 | 1. First install the dependencies listed above. 23 | 24 | 2. Clone the `CSPML` github repository: 25 | ```bash 26 | git clone https://github.com/Minoru938/CSPML.git 27 | ``` 28 | 29 | Note: Due to the size of this repository (about 500MB), this operation can take tens of minutes. 30 | 31 | 3. `cd` into `CSPML` directory. 32 | 33 | 4. Run `jupyter notebook` and open `tutorial.ipynb` to demonstrate `CSPML`. 34 | 35 | 36 | # Environment of author 37 | * Python 3.8.8 38 | * macOS Big Sur version 11.6 39 | 40 | # Addition of the latest version of CSPML (2024/07/09) 41 | 42 | The latest version of CSPML has been added to this repository as the file "CSPML_latest_codes." 
This file contains the CSPML training codes, which addressed bias in training data with an updated TensorFlow environment. Please refer to read_me.txt in this file for details on usage. This file corresponds to the result of the paper "[Shotgun crystal structure prediction using machine-learned formation energies](https://doi.org/10.1038/s41524-024-01471-8)". See the "Details of the CSPML model" section in the paper's supplementary information for details. If you want to use CSPML for actual crystal structure prediction or as a comparison method, I recommend using this version of CSPML. 43 | 44 | The article titled “Shotgun crystal structure prediction using machine-learned formation energies” has been officially published in *npj Computational Materials* (20 December 2024). 45 | 46 | # Reference 47 | 48 | 1. [Materials Project]: A. Jain, S. P. Ong, G. Hautier, W. Chen, W. D. Richards, S. Dacek, S. Cholia, D. Gunter, D. Skinner, G. Ceder, et al., Commentary: The materials project: 49 | A materials genome approach to accelerating materi- als innovation, APL materials 1 (1) (2013) 011002. 50 | 51 | 2. [XenonPy]: C. Liu, E. Fujita, Y. Katsura, Y. Inada, A. Ishikawa, R. Tamura, K. Kimura, R. Yoshida, Machine learning to predict quasicrystals from chemical compositions, 52 | Advanced Materials 33 (36) (2021) 2170284. 53 | 54 | 3. [Local structure order parameters]: N. E. Zimmermann, A. Jain, Local structure order parameters and site fingerprints for quantification of coordination environment and 55 | crystal structure similarity, RSC Advances 10 (10) (2020) 6063–6081. 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /data_set/MP_candidates.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7f84e8f1e530b480fd27dcfe53daeb1c6e4e9c1d99760342f7dc77f36c9467ba 3 | size 111272232 4 | -------------------------------------------------------------------------------- /data_set/MP_structures.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:19debac0a6c7c1fc646caebbf34f80050f980ad120546e07c1c3cbbc643743f8 3 | size 253995327 4 | -------------------------------------------------------------------------------- /data_set/candidates_paper.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fd1a58ce3258c58be0f7a05bacb39db571b40526e3041ac5fe0ee0a01e6a20c5 3 | size 110615715 4 | -------------------------------------------------------------------------------- /data_set/descriptor_standardization.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3f1ce6fa3eda6d93ebd2935d4d7097f85c73a0515c9f71919572e7c0f172469c 3 | size 14138 4 | -------------------------------------------------------------------------------- /data_set/element_dissimilarity.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:40999e03de2eed9e8704ce2e65f743e183a11bfb8856a54d435dc5b61615f37c 3 | size 85033 4 | -------------------------------------------------------------------------------- /data_set/model1_tau=0.3: -------------------------------------------------------------------------------- 1 | version 
https://git-lfs.github.com/spec/v1 2 | oid sha256:1ad3510ed550ccc1da12ee787e5cd0226e3ef9a49149f521651f7ed7997fd3e1 3 | size 271080 4 | -------------------------------------------------------------------------------- /data_set/model2_tau=0.3: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7b11d42491bc9f3ccca1c3705864db94ffa51a0cc4d9a6be10c230c9ee0fdc4a 3 | size 271080 4 | -------------------------------------------------------------------------------- /data_set/model3_tau=0.3: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2ea490d667ae28e2da9559cc4e556d2e8847f6b320861fc4a178f93b1b1c8bd7 3 | size 271096 4 | -------------------------------------------------------------------------------- /data_set/model4_tau=0.3: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:13b0e68719af8a612e47ae1e69db1f3df85c16f07be2fc5fb45b370d34b9d8c4 3 | size 271104 4 | -------------------------------------------------------------------------------- /data_set/model5_tau=0.3: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3aad52bceea0794f67a49d9d5d889da5f66a6d94407c008321f28cf289888912 3 | size 271104 4 | -------------------------------------------------------------------------------- /read_me.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minoru938/CSPML/4410c98b55c82b1f5d00bd581676281889f87cd3/read_me.txt -------------------------------------------------------------------------------- /tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4c98a308", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Import CSPML module.\n", 11 | "import CSPML\n", 12 | "\n", 13 | "# Import libraries.\n", 14 | "import pandas as pd\n", 15 | "import numpy as np\n", 16 | "import pickle\n", 17 | "import os" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "id": "f117940e", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Set the 38 benchmark sets (chemical formulas for crystal structure prediction).\n", 28 | "query_formula = ['Ag8GeS6','Al2O3','BN','Ba(FeAs)2','Ba2CaSi4(BO7)2','Bi2Te3','C','Ca14MnSb11','CaCO3','Cd3As2',\n", 29 | " 'CoSb3','CsPbI3','Cu12Sb4S13','Fe3O4','GaAs','GeH4','La2CuO4','Li3PS4','Li4Ti5O12','LiBF4','LiCoO2','LiFePO4',\n", 30 | " 'LiPF6','MgB7','Mn(FeO2)2','NaCaAlPHO5F2','Si','Si3N4','SiO2','SrTiO3','TiO2','V2O5','VO2','Y3Al5O12','ZnO',\n", 31 | " 'ZnSb','ZrO2','ZrTe5'] # (N=38)\n", 32 | "\n", 33 | "# Load candidate compounds (N=33,115) used in the paper (https://doi.org/10.1016/j.commatsci.2022.111496).\n", 34 | "with open(\"./data_set/candidates_paper.pkl\", \"rb\") as f:\n", 35 | " candidates_paper = pickle.load(f)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "id": "ade713bf", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "None of the candidates had the class probabilities greater than 0.5 at Ba2CaSi4(BO7)2.\n", 49 | "None of the candidates had the class probabilities 
greater than 0.5 at MgB7.\n", 50 | "None of the candidates had the same composition ratio as NaCaAlPHO5F2.\n", 51 | "The top-1th predicted structure for CoSb3 is shown below;\n", 52 | "Full Formula (Co4 Sb12)\n", 53 | "Reduced Formula: CoSb3\n", 54 | "abc : 7.948651 7.948651 7.948651\n", 55 | "angles: 109.471221 109.471221 109.471221\n", 56 | "Sites (16)\n", 57 | " # SP a b c magmom\n", 58 | "--- ---- -------- -------- -------- --------\n", 59 | " 0 Co 0.5 0.5 0.5 1.047\n", 60 | " 1 Co 0.5 0 0 0.95\n", 61 | " 2 Co 0 0 0.5 0.979\n", 62 | " 3 Co 0 0.5 0 0.994\n", 63 | " 4 Sb 0.669087 0.840293 0.50938 -0.001\n", 64 | " 5 Sb 0.840293 0.50938 0.669087 -0.007\n", 65 | " 6 Sb 0.50938 0.669087 0.840293 -0.007\n", 66 | " 7 Sb 0.669087 0.159707 0.828795 -0.002\n", 67 | " 8 Sb 0.840293 0.171205 0.330913 -0.01\n", 68 | " 9 Sb 0.159707 0.828795 0.669087 -0.01\n", 69 | " 10 Sb 0.171205 0.330913 0.840293 -0.011\n", 70 | " 11 Sb 0.159707 0.49062 0.330913 -0.007\n", 71 | " 12 Sb 0.330913 0.159707 0.49062 -0.001\n", 72 | " 13 Sb 0.49062 0.330913 0.159707 -0.007\n", 73 | " 14 Sb 0.330913 0.840293 0.171205 -0.002\n", 74 | " 15 Sb 0.828795 0.669087 0.159707 -0.011\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "# Predict stable structures for the given query formulas using the candidate compounds used in the paper.\n", 80 | "# In the example below, up to 12 structures are suggested for each query fomula.\n", 81 | "predicted = CSPML.Structure_prediction(query_formula, 12, candidates_paper)\n", 82 | "\n", 83 | "i = 10\n", 84 | "j = 0\n", 85 | "print(f\"The top-{j+1}th predicted structure for {query_formula[i]} is shown below;\\n{predicted[i][j]}\")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "id": "e86e1656", 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "None of the candidates had the class probabilities greater than 0.5 at Ba2CaSi4(BO7)2.\n", 99 | "None of the candidates had the class probabilities greater than 0.5 at MgB7.\n", 100 | "None of the candidates had the same composition ratio as NaCaAlPHO5F2.\n", 101 | "The top-1th predicted structure for CoSb3 is shown below; \n", 102 | "Full Formula (Co4 Sb12)\n", 103 | "Reduced Formula: CoSb3\n", 104 | "abc : 7.948651 7.948651 7.948651\n", 105 | "angles: 109.471221 109.471221 109.471221\n", 106 | "Sites (16)\n", 107 | " # SP a b c magmom\n", 108 | "--- ---- -------- -------- -------- --------\n", 109 | " 0 Co 0.5 0.5 0.5 1.047\n", 110 | " 1 Co 0.5 0 0 0.95\n", 111 | " 2 Co 0 0 0.5 0.979\n", 112 | " 3 Co 0 0.5 0 0.994\n", 113 | " 4 Sb 0.669087 0.840293 0.50938 -0.001\n", 114 | " 5 Sb 0.840293 0.50938 0.669087 -0.007\n", 115 | " 6 Sb 0.50938 0.669087 0.840293 -0.007\n", 116 | " 7 Sb 0.669087 0.159707 0.828795 -0.002\n", 117 | " 8 Sb 0.840293 0.171205 0.330913 -0.01\n", 118 | " 9 Sb 0.159707 0.828795 0.669087 -0.01\n", 119 | " 10 Sb 0.171205 0.330913 0.840293 -0.011\n", 120 | " 11 Sb 0.159707 0.49062 0.330913 -0.007\n", 121 | " 12 Sb 0.330913 0.159707 0.49062 -0.001\n", 122 | " 13 Sb 0.49062 0.330913 0.159707 -0.007\n", 123 | " 14 Sb 0.330913 0.840293 0.171205 -0.002\n", 124 | " 15 Sb 0.828795 0.669087 0.159707 -0.011\n", 125 | "This predicted structure was generated by element-substitution of the template structure; \n", 126 | "formula = FeSb3, material id = mp-971669\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "# If SI = True, the supplementary information of the predicted structures are also returned.\n", 132 | "# In the example below, up to 5 
structures are suggested for each query fomula.\n", 133 | "predicted, SI = CSPML.Structure_prediction(query_formula, 5, candidates_paper, SI=True)\n", 134 | "\n", 135 | "print(f\"The top-{j+1}th predicted structure for {query_formula[i]} is shown below; \\n{predicted[i][j]}\")\n", 136 | "print(f\"This predicted structure was generated by element-substitution of the template structure; \\nformula = {SI[i]['topK_formula'][j]}, material id = {SI[i]['topK_id'][j]}\")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 5, 142 | "id": "b83e2a34", 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "None of the candidates had the class probabilities greater than 0.5 at Ba2CaSi4(BO7)2.\n", 150 | "None of the candidates had the class probabilities greater than 0.5 at MgB7.\n", 151 | "None of the candidates had the same composition ratio as NaCaAlPHO5F2.\n", 152 | "The top-1th predicted structure for CoSb3 is saved as a CoSb3_1.cif.\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "# Make new directory for saving .cif files of the predicted structures.\n", 158 | "os.mkdir(\"predicted_structures_paper\")\n", 159 | "\n", 160 | "# If save_cif = True, the .cif files of the predicted structures are automatically saved into save_cif_filename.\n", 161 | "# In the example below, up to 5 structures are suggested for each query fomula.\n", 162 | "predicted = CSPML.Structure_prediction(query_formula, 5, candidates_paper, save_cif = True, \n", 163 | " save_cif_filename = \"predicted_structures_paper\")\n", 164 | "\n", 165 | "print(f\"The top-{j+1}th predicted structure for {query_formula[i]} is saved as a {query_formula[i]}_{j+1}.cif.\")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 6, 171 | "id": "b2ee335b", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "# For general use.\n", 176 | "\n", 177 | "# Make new directory for saving .cif files of the predicted structures.\n", 178 | "os.mkdir(\"predicted_structures\")\n", 179 | "\n", 180 | "# Perform structure prediction with embedded candidate compounds in the CSPML module.\n", 181 | "# Since the embedded candidate compounds (=CSPML.MP_candidates; N=33,153) contains true structures for \n", 182 | "# all query formulas defined in this program, their true structures are suggested as top-1th predicted structures. \n", 183 | "# In the example below, up to 6 structures are suggested for each query fomula.\n", 184 | "\n", 185 | "predicted, SI = CSPML.Structure_prediction(query_formula, 6, SI = True, save_cif = True, \n", 186 | " save_cif_filename = \"predicted_structures\")\n", 187 | "\n", 188 | "# The simplest form is \"predicted= CSPML.Structure_prediction(query_formula, 6)\".\n", 189 | "# Since the candidate set is embedded in the module, the user only needs to set \"query_formula\" and \"top_K\".\n", 190 | "# Since candidates_paper is a subset of the embedded candidate compounds (candidates_paper is for reproducing\n", 191 | "# the result of the paper), if you use this module for general use, you should use this form." 
192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 7, 197 | "id": "9d592c9b", 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "None of the candidates had the same composition ratio as Ca2As4Xe5F34.\n", 205 | "None of the candidates had the same composition ratio as CsAg2(B5O8)3.\n", 206 | "None of the candidates had the same composition ratio as K4U5Te2O21.\n", 207 | "None of the candidates had the same composition ratio as Mn3Sb5(IO3)3.\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "# For reproducing the results of the crystal structure prediction of the 50 randomly selected benchmark sets\n", 213 | "# presented in the paper.\n", 214 | "\n", 215 | "# The 50 randomly selected query compositions.\n", 216 | "random50 = ['Ac2SnHg','Al6Ru','Ba2MgTl','Ba3Rh(CN)3','BaMg','BaNaP','Be17Ru3','Ca2Al2Sn2O9','Ca2As4Xe5F34','Cd3TeO6',\n", 217 | " 'CdPtF6','Cr2GaC','CsAg2(B5O8)3','Er4MgNi','Eu2Mo(WO6)2','EuCoO3','Ho(Al5Ru)2','Ho6Al7Cu16','HoFeSi','K3Y3(BO3)4',\n", 218 | " 'K4U5Te2O21','Li2BPt3','Li2DyIn','Li2SO4','Li7PN4','LiCeHg2','LiGdO2','LiIr','LuBiPd','Mg2Ga5Cu6','MgSn4Ru',\n", 219 | " 'Mn3Sb5(IO3)3','Na2Ga2As3','NaCaVO4','NdEuCuS3','NpAs2','Pr9Ga5S21','Rb12Sn2As6O','Rb2ScInCl6','Rb3H(SO4)2','SmH3',\n", 220 | " 'SrHI','SrNdVO4','Tb3Ni','TiMnSi2','TmRhO3','Y2GeRh3','Y2Te5O13','ZnH2SO5','Zr5CuSb3'] # (N=50)\n", 221 | "\n", 222 | "# Load preset 33,153 candidate compounds.\n", 223 | "with open(\"./data_set/MP_candidates.pkl\", \"rb\") as f:\n", 224 | " MP_candidates = pickle.load(f)\n", 225 | "\n", 226 | "# Prepare the candidate templates for the 50 query compositions.\n", 227 | "random50_property = MP_candidates[\"property\"][np.logical_not(MP_candidates[\"property\"][\"pretty_formula\"].isin(random50))].reset_index(drop=True)\n", 228 | "random50_composition = MP_candidates[\"composition\"][np.logical_not(MP_candidates[\"property\"][\"pretty_formula\"].isin(random50))].reset_index(drop=True)\n", 229 | "random50_descriptor = MP_candidates[\"descriptor\"][np.logical_not(MP_candidates[\"property\"][\"pretty_formula\"].isin(random50))].reset_index(drop=True)\n", 230 | "\n", 231 | "candidate_random50 = {'property':random50_property, 'composition':random50_composition, 'descriptor':random50_descriptor}\n", 232 | "\n", 233 | "# Crystal structure prediction for the 50 query compositions.\n", 234 | "# Make new directory for saving .cif files of the predicted structures.\n", 235 | "os.mkdir(\"predicted_random50\")\n", 236 | "\n", 237 | "# If save_cif = True, the .cif files of the predicted structures are automatically saved into save_cif_filename.\n", 238 | "# In the example below, up to 5 structures are suggested for each query fomula.\n", 239 | "predicted_50, SI_50 = CSPML.Structure_prediction(random50, 5, candidate_random50, save_cif = True, SI = True,\n", 240 | " save_cif_filename = \"predicted_random50\")\n", 241 | "\n", 242 | "# Note that the crystal structure data (CIF files) provided in the Supplementary Data (https://doi.org/10.1016/j.commatsci.2022.111496)\n", 243 | "# are locally optimized suructures of the above prediction results using DFT calculations." 
244 | ] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "Python 3 (ipykernel)", 250 | "language": "python", 251 | "name": "python3" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.8.8" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 5 268 | } 269 | --------------------------------------------------------------------------------