├── .gitignore ├── LICENSE ├── README.md ├── argparser.py ├── attn_vis.py ├── baselines ├── README.md ├── __init__.py ├── data_utils.py ├── dimenet_pp.py ├── egnn.py ├── painn.py ├── schnet.py └── spk_utils │ ├── __init__.py │ ├── acsf.py │ ├── activations.py │ ├── base.py │ ├── blocks.py │ ├── cfconv.py │ ├── cutoff.py │ ├── initializers.py │ └── neighbors.py ├── featurization └── data_utils.py ├── image ├── 3dstructgen-mof.png ├── Fig1.jpg └── Fig1.png ├── model_shap.py ├── models └── transformer.py ├── nist_test.py ├── pressure_adapt.py ├── process ├── README ├── create_geo_features.sh ├── prepare_mof_features.py ├── process_csd_data.py ├── process_csd_data_baselines.py ├── process_nist_data.py └── tools │ ├── get_atom_features.py │ ├── get_bond_features.py │ └── remove_waters.py ├── requirements.txt ├── train_baselines.py ├── train_ml.py ├── train_mofnet.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Matgen-project 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MOFNet 2 | MOFNet is a deep learning model that predicts adsorption isotherms for MOFs based on a hierarchical representation, a graph transformer and a pressure-adaptive mechanism. We design a hierarchical representation to describe the MOF structure. A graph transformer is used to capture atomic-level information, which helps the model learn the chemical features required under low-pressure conditions. A pressure-adaptive mechanism interpolates and extrapolates the given limited data points by transfer learning, so a single model can predict adsorption isotherms over a wider pressure range. The following is the architecture of MOFNet. 3 | 4 | 5 | 6 | ## Installation 7 | Please see the dependencies in requirements.txt 8 | 9 | ## Dataset 10 | 11 | We released the training and testing data on the [Matgen website](https://matgen.nscc-gz.cn/dataset.html), which can be obtained by the following commands. 12 | ``` 13 | $ wget https://matgen.nscc-gz.cn/dataset/download/CSD-MOFDB_xx.tar.gz #xx: released version 14 | $ wget https://matgen.nscc-gz.cn/dataset/download/NIST-ISODB_xx.tar.gz 15 | ``` 16 | 17 | You can construct the data directory from the downloaded data as follows. 18 | 19 | ``` 20 | |-- data 21 | ||-- CSD-MOFDB 22 | ||-- NIST-ISODB 23 | ``` 24 | 25 | ## CSD-MOFDB 26 | We collected 7306, 6998 and 8562 MOFs for N2, CO2 and CH4, respectively, from the Cambridge Structural Database (CSD, version 5.4). 27 | GCMC simulations were carried out with the RASPA software to calculate the adsorption data of the MOFs for N2, CO2 and CH4. 28 | We set 8 pressure points in the ranges 0.2 kPa–80 kPa, 5 kPa–20,000 kPa and 100 kPa–10,000 kPa for N2, CO2 and CH4, respectively. 29 | ``` 30 | | --CSD-MOFDB 31 | ||--CIFs # CIF format files. 32 | ||--global_features 33 | ||--label_by_GCMC # adsorption data calculated by the GCMC method. 34 | ||--local_features 35 | ||--mol_unit # molecule units in mol format 36 | ||--README 37 | ``` 38 | 39 | ## NIST-ISODB 40 | We obtained 54 MOFs with 1876 pressure data points covering the N2, CO2 and CH4 adsorbate molecules from the NIST/ARPA-E database. 41 | 42 | ``` 43 | |--NIST-ISODB 44 | ||--CIFs # CIF format files. 45 | ||--global_features 46 | ||--isotherm_data # experimental data. 47 | ||--local_features 48 | ||--MOFNet # MOFNet prediction results. 49 | ||--mol_unit # molecule units in mol format 50 | ||--README 51 | ``` 52 | 53 | 54 | ## Processing 55 | 56 | ### How to generate local features? 57 | First, the CSD package needs to be installed on your server; the CSD Python API is then used to obtain the CIF files. We provide scripts in the process directory; run the following command to generate the local features file. 58 | ``` 59 | $ python process/process_csd_data.py 60 | ``` 61 | 62 | ### How to obtain global features? 63 | Important structural properties, including the largest cavity diameter (LCD), pore-limiting diameter (PLD) and helium void fraction, were calculated using the open-source software Zeo++.
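For reference, typical Zeo++ invocations for these descriptors look like the following (an illustrative sketch, not a script shipped with this repo; the probe radii and sample counts are assumptions you should adapt to your setup, and see also process/create_geo_features.sh):
```
$ network -ha -res MOF.res MOF.cif                  # largest included/free sphere -> LCD / PLD
$ network -ha -volpo 1.2 1.2 2000 MOF.volpo MOF.cif # probe-occupiable volume -> void fraction
$ network -ha -sa 1.86 1.86 2000 MOF.sa MOF.cif     # accessible surface area
```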
64 | 65 | 66 | ## Model training 67 | ``` 68 | $ python -u train_mofnet.py --data_dir data/CSD-MOFDB --gas_type <GAS> --pressure <PRESSURE> --save_dir <SAVE_DIR> --use_global_feature 69 | ``` 70 | 71 | ## Transfer learning 72 | ``` 73 | $ python -u pressure_adapt.py --data_dir data/CSD-MOFDB --gas_type <GAS> --pressure <PRESSURE> --save_dir <SAVE_DIR> --ori_dir /_ --adapter_dim 8 74 | ``` 75 | 76 | ## Prediction 77 | ``` 78 | $ python -u nist_test.py --data_dir data/NIST-ISODB --gas_type <GAS> --pressure <PRESSURE> --save_dir <SAVE_DIR> --img_dir <IMG_DIR> 79 | ``` 80 | Here `<GAS>` is one of N2/CO2/CH4; for nist_test.py the pressure condition is given as `min,max,num_points` (see parse_predict_args in argparser.py). 81 | We also welcome users to use our [3DStructGen UI interface](https://matgen.nscc-gz.cn/3dstructgen/v2/mod/3dstructgen_newUI.html) to predict crystal properties by the following steps: 82 | ``` 83 | # Upload your CIF crystal files into the 3DStructGen interface; 84 | # Click the "Calculate" button and use the app "Artificial Intelligence - MOF"; 85 | # Choose the uptake gas and pressure range you want to calculate, then submit. 86 | ``` 87 | 88 | 89 | 90 | ## Acknowledgments 91 | The implementation of the Graph Transformer module is built upon [Molecule Attention Transformer](https://github.com/ardigen/MAT). 92 | 93 | ## Reference: 94 | [1]. Maziarka, Łukasz; Danel, Tomasz; Mucha, Sławomir; Rataj, Krzysztof; Tabor, Jacek; Jastrzębski, Stanisław: Molecule Attention Transformer. arXiv preprint arXiv:2002.08264, 2020. 95 | 96 | [2]. Pin Chen, Yu Wang, Hui Yan, Sen Gao, Zexin Xu, Yangzhong Li, Qing Mo, Junkang Huang, Jun Tao, GeChuanqi Pan, Jiahui Li & Yunfei Du. 3DStructGen: an interactive web-based 3D structure generation for non-periodic molecule and crystal. J Cheminform 12, 7 (2020). https://doi.org/10.1186/s13321-020-0411-2 97 | -------------------------------------------------------------------------------- /argparser.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | 4 | def parse_train_args(): 5 | parser = ArgumentParser() 6 | add_data_args(parser) 7 | add_train_args(parser) 8 | args = parser.parse_args() 9 | args = vars(args) 10 | lambda_mat = [float(_) for _ in args['weight_split'].split(',')] 11 | assert len(lambda_mat) == 3 12 | lambda_sum = sum(lambda_mat) 13 | args['lambda_attention'] = lambda_mat[0] / lambda_sum 14 | args['lambda_distance'] = lambda_mat[-1] / lambda_sum # the adjacency weight is the remaining lambda_mat[1] / lambda_sum 15 | if args['d_mid_list'] == 'None': 16 | args['d_mid_list'] = [] 17 | else: 18 | args['d_mid_list'] = [int(_) for _ in args['d_mid_list'].split(',')] 19 | makedirs(args['save_dir'] + f"/{args['gas_type']}_{args['pressure']}/") 20 | return args 21 | 22 | def parse_predict_args(): 23 | parser = ArgumentParser() 24 | add_data_args(parser) 25 | add_train_args(parser) 26 | args = parser.parse_args() 27 | args = vars(args) 28 | lambda_mat = [float(_) for _ in args['weight_split'].split(',')] 29 | assert len(lambda_mat) == 3 30 | lambda_sum = sum(lambda_mat) 31 | args['lambda_attention'] = lambda_mat[0] / lambda_sum 32 | args['lambda_distance'] = lambda_mat[-1] / lambda_sum 33 | if args['d_mid_list'] == 'None': 34 | args['d_mid_list'] = [] 35 | else: 36 | args['d_mid_list'] = [int(_) for _ in args['d_mid_list'].split(',')] 37 | p_cond = args['pressure'].split(',') 38 | assert len(p_cond) == 3 39 | args['pressure'] = (float(p_cond[0]), float(p_cond[1]), int(p_cond[2])) 40 | return args 41 | 42 | def parse_baseline_args(): 43 | parser = ArgumentParser() 44 | add_data_args(parser) 45 | add_baseline_args(parser) 46 | args = parser.parse_args() 47 | args = vars(args) 48 | makedirs(args['save_dir'] + 
f"/{args['gas_type']}_{args['pressure']}/") 49 | return args 50 | 51 | def parse_finetune_args(): 52 | parser = ArgumentParser() 53 | add_data_args(parser) 54 | add_finetune_args(parser) 55 | args = parser.parse_args() 56 | args = vars(args) 57 | makedirs(args['save_dir'] + f"/{args['gas_type']}_{args['pressure']}/") 58 | return args 59 | 60 | def parse_ml_args(): 61 | parser = ArgumentParser() 62 | add_data_args(parser) 63 | add_ml_args(parser) 64 | args = parser.parse_args() 65 | args = vars(args) 66 | makedirs(args['save_dir'] + f"/{args['ml_type']}/{args['gas_type']}_{args['pressure']}/") 67 | return args 68 | 69 | def makedirs(path: str, isfile: bool = False): 70 | if isfile: 71 | path = os.path.dirname(path) 72 | if path != '': 73 | os.makedirs(path, exist_ok=True) 74 | 75 | def add_ml_args(parser: ArgumentParser): 76 | parser.add_argument('--ml_type', type=str, default='RF', 77 | help='ML algorithm, SVR/DT/RF.') 78 | 79 | parser.add_argument('--seed', type=int, default=9999, 80 | help='Random seed to use when splitting data into train/val/test sets.' 81 | 'When `num_folds` > 1, the first fold uses this seed and all' 82 | 'subsequent folds add 1 to the seed.') 83 | parser.add_argument('--fold', type=int, default=10, 84 | help='Fold num.') 85 | 86 | def add_data_args(parser: ArgumentParser): 87 | parser.add_argument('--data_dir', type=str, 88 | help='Dataset directory, containing label/ and processed/ subdirectories.') 89 | 90 | parser.add_argument('--save_dir', type=str, 91 | help='Model directory.') 92 | 93 | 94 | parser.add_argument('--gas_type', type=str, 95 | help='Gas type for prediction.') 96 | 97 | parser.add_argument('--pressure', type=str, 98 | help='Pressure condition for prediction.') 99 | 100 | parser.add_argument('--img_dir', type=str, default='', 101 | help='Directory for visualized isotherms') 102 | 103 | 104 | parser.add_argument('--name', type=str, default='', 105 | help='Target MOF name for attention visualization.') 106 | 107 | def add_finetune_args(parser: ArgumentParser): 108 | parser.add_argument('--ori_dir', type=str, 109 | help='Pretrained model directory, containing model of different Folds.') 110 | 111 | parser.add_argument('--epoch', type=int, default=100, 112 | help='Epoch num.') 113 | 114 | parser.add_argument('--batch_size', type=int, default=32, 115 | help='Batch size.') 116 | 117 | parser.add_argument('--fold', type=int, default=10, 118 | help='Fold num.') 119 | 120 | parser.add_argument('--lr', type=float, default=0.0007, 121 | help='Learning rate.') 122 | 123 | parser.add_argument('--adapter_dim', type=int, default=8, 124 | help='Adapted vector dimension') 125 | 126 | parser.add_argument('--seed', type=int, default=9999, 127 | help='Random seed to use when splitting data into train/val/test sets.') 128 | 129 | def add_baseline_args(parser: ArgumentParser): 130 | 131 | parser.add_argument('--model_name',type=str,default='gin', 132 | help='Baseline Model, gin/egnn/schnet/painn.') 133 | 134 | parser.add_argument('--gpu', type=int, 135 | help='GPU id to allocate.') 136 | 137 | parser.add_argument('--seed', type=int, default=9999, 138 | help='Random seed to use when splitting data into train/val/test sets.') 139 | 140 | parser.add_argument('--d_model', type=int, default=1024, 141 | help='Hidden size of baseline model.') 142 | 143 | parser.add_argument('--N', type=int, default=2, 144 | help='Layer num of baseline model.') 145 | 146 | parser.add_argument('--use_global_feature', action='store_true', 147 | help='Whether to use global features(graph-level 
features).') 148 | 149 | parser.add_argument('--warmup_step', type=int, default=2000, 150 | help='Warmup steps.') 151 | 152 | parser.add_argument('--epoch', type=int, default=100, 153 | help='Epoch num.') 154 | 155 | parser.add_argument('--batch_size', type=int, default=32, 156 | help='Batch size.') 157 | 158 | parser.add_argument('--fold', type=int, default=10, 159 | help='Fold num.') 160 | 161 | parser.add_argument('--lr', type=float, default=0.0007, 162 | help='Maximum learning rate, (warmup_step * d_model) ** -0.5 .') 163 | 164 | def add_train_args(parser: ArgumentParser): 165 | 166 | parser.add_argument('--seed', type=int, default=9999, 167 | help='Random seed to use when splitting data into train/val/test sets.') 168 | 169 | parser.add_argument('--d_model', type=int, default=1024, 170 | help='Hidden size of transformer model.') 171 | 172 | parser.add_argument('--N', type=int, default=2, 173 | help='Layer num of transformer model.') 174 | 175 | parser.add_argument('--h', type=int, default=16, 176 | help='Attention head num of transformer model.') 177 | 178 | parser.add_argument('--n_generator_layers', type=int, default=2, 179 | help='Layer num of generator (MLP) model.') 180 | 181 | parser.add_argument('--weight_split', type=str, default='1,1,1', 182 | help='Unnormalized weights of the Self-Attention/Adjacency/Distance matrices, respectively, in the Graph Transformer.') 183 | 184 | parser.add_argument('--leaky_relu_slope', type=float, default=0.0, 185 | help='Leaky ReLU slope for activation functions.') 186 | 187 | parser.add_argument('--dense_output_nonlinearity', type=str, default='silu', 188 | help='Activation function for the predict module, silu/relu/tanh/none.') 189 | 190 | parser.add_argument('--distance_matrix_kernel', type=str, default='bessel', 191 | help='Kernel applied on the distance matrix, bessel/softmax/exp. For example, exp maps the distance d between nodes i and j to D(i,j) = exp(-d).') 192 | 193 | parser.add_argument('--dropout', type=float, default=0.1, 194 | help='Dropout ratio.') 195 | 196 | parser.add_argument('--aggregation_type', type=str, default='mean', 197 | help='Type for aggregating node features into a graph feature, mean/sum/dummy_node.') 198 | 199 | parser.add_argument('--use_global_feature', action='store_true', 200 | help='Whether to use global features (graph-level features).') 201 | 202 | parser.add_argument('--use_ffn_only', action='store_true', 203 | help='Use a DNN generator which only considers global features.
') 204 | 205 | parser.add_argument('--d_mid_list', type=str, default='128,512', 206 | help='Projection layers to augment the global feature dim to the local feature dim.') 207 | 208 | parser.add_argument('--warmup_step', type=int, default=2000, 209 | help='Warmup steps.') 210 | 211 | parser.add_argument('--epoch', type=int, default=300, 212 | help='Epoch num.') 213 | 214 | parser.add_argument('--batch_size', type=int, default=64, 215 | help='Batch size.') 216 | 217 | parser.add_argument('--fold', type=int, default=10, 218 | help='Fold num.') 219 | 220 | parser.add_argument('--lr', type=float, default=0.0007, 221 | help='Maximum learning rate, (warmup_step * d_model) ** -0.5 .') 222 | 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /attn_vis.py: -------------------------------------------------------------------------------- 1 | import shap 2 | import torch 3 | from collections import defaultdict 4 | from featurization.data_utils import load_data_from_df, construct_loader_gf_pressurever, construct_dataset_gf_pressurever, data_prefetcher 5 | from models.transformer import make_model 6 | import numpy as np 7 | import os 8 | from argparser import parse_train_args 9 | import pickle 10 | from tqdm import tqdm 11 | import matplotlib.pyplot as plt 12 | import seaborn as sns 13 | from utils import * 14 | 15 | periodic_table = ('Dummy','H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 16 | 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 17 | 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 18 | 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 19 | 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Unk') 20 | 21 | 22 | model_params = parse_train_args() 23 | img_dir = os.path.join(model_params['img_dir'],'attn') 24 | os.makedirs(img_dir,exist_ok=True) 25 | 26 | 27 | def heatmap(atoms, attn, name): 28 | plt.cla() 29 | f, ax = plt.subplots(figsize=(20, 15)) 30 | colormap = 'Reds' 31 | h = sns.heatmap(attn, vmax=attn.max(), yticklabels = atoms, xticklabels = atoms, square=True, cmap=colormap, cbar=False) 32 | fontsize = 15 33 | cb=h.figure.colorbar(h.collections[0]) 34 | cb.ax.tick_params(labelsize=fontsize) 35 | ax.tick_params(labelsize=fontsize,rotation=0) 36 | ax.set_xticklabels(ax.get_xticklabels(), rotation=90) 37 | plt.savefig(os.path.join(img_dir, name + '.pdf')) 38 | 39 | def test(model, data_loader, name_list): 40 | model.eval() 41 | batch_idx = -1 42 | ans = {} 43 | for data in tqdm(data_loader): 44 | batch_idx += 1 45 | adjacency_matrix, node_features, distance_matrix, global_features, y = (_.cpu() for _ in data) 46 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0 47 | graph_rep = model.encode(node_features, batch_mask, adjacency_matrix, distance_matrix, None) 48 | attn = model.encoder.layers[0].self_attn.self_attn.detach().cpu().numpy() 49 | atoms = node_features.numpy()[:,:,:83].argmax(axis=-1).reshape(-1) 50 | attn = attn[0].mean(axis=0) # average the first-layer attention over heads 51 | atoms = applyIndexOnList(periodic_table, atoms) 52 | ans[name_list[batch_idx]] = { 53 | 'atoms':atoms, 54 | 'attn':attn 55 | } 56 | heatmap(atoms, attn, name_list[batch_idx]) 57 | return ans 58 | 59 | if __name__ == '__main__': 60 | batch_size = 1 61 | device_ids = [0,1,2,3] 62 | X, f, y,p = 
load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure="all",add_dummy_node = True,use_global_features = True, return_names=True) 63 | print("X,f,y,p") 64 | tar_idx = np.where(p==model_params['pressure'])[0][0] 65 | y = np.array(y) 66 | mean = y[...,tar_idx].mean() 67 | std = y[...,tar_idx].std() 68 | f = np.array(f) 69 | fmean = f.mean(axis=0) 70 | fstd = f.std(axis=0) 71 | test_errors_all = [] 72 | f = (f - fmean) / fstd 73 | X, names = X 74 | 75 | print(f'Loaded {len(X)} data.') 76 | 77 | fold_idx = 1 78 | save_dir = model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}" 79 | ckpt_handler = CheckpointHandler(save_dir) 80 | state = ckpt_handler.checkpoint_best(use_cuda=False) 81 | model = make_model(**state['params']) 82 | model = torch.nn.DataParallel(model) 83 | model.load_state_dict(state['model']) 84 | model = model.module 85 | if model_params['name'] == '': 86 | sample_idx = np.arange(1000) 87 | tar_name = 'all' 88 | else: 89 | if model_params['name'] in names: 90 | sample_idx = [names.index(model_params['name'])] 91 | tar_name = model_params['name'] 92 | else: 93 | sample_idx = [0] 94 | tar_name = 'random' 95 | train_sample = construct_dataset_gf_pressurever(applyIndexOnList(X,sample_idx), f[sample_idx], y[sample_idx],p, is_train=False, tar_point=model_params['pressure'],mask_point=model_params['pressure']) 96 | sample_loader = construct_loader_gf_pressurever(train_sample, 1, shuffle=False) 97 | ans = test(model, sample_loader, applyIndexOnList(names, sample_idx)) 98 | 99 | with open(os.path.join(img_dir,f"attn_{tar_name}.p"),'wb') as f: 100 | pickle.dump(ans, f) 101 | 102 | -------------------------------------------------------------------------------- /baselines/README.md: -------------------------------------------------------------------------------- 1 | ### Baselines 2 | 3 | Four adapted baselines: 4 | 5 | - SchNet https://arxiv.org/abs/1706.08566 6 | - DimeNet++ https://arxiv.org/abs/2011.14115 7 | - EGNN https://arxiv.org/abs/2102.09844 8 | - PaiNN https://arxiv.org/abs/2102.03150 -------------------------------------------------------------------------------- /baselines/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .egnn import * 4 | from .painn import * 5 | from .schnet import * 6 | from .dimenet_pp import * 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | def make_baseline_model(d_atom, model_name, N=2, d_model=128, use_global_feature=False, d_feature=9, **kwargs): 11 | model = None 12 | if model_name == 'egnn': 13 | representation = EGNN(in_node_nf=d_atom, hidden_nf=d_model, n_layers=N, attention=True) 14 | use_adj = True 15 | elif model_name == 'dimenetpp': 16 | representation = DimeNetPlusPlus(hidden_channels=d_model, out_channels=d_model, num_input=d_atom, num_blocks=N, int_emb_size=d_model // 2, basis_emb_size=8, out_emb_channels=d_model * 2, num_spherical=7, num_radial=6) 17 | use_adj = True 18 | elif model_name == 'schnet': 19 | representation = SchNet(n_atom_basis=d_model, n_filters=d_model, n_interactions=N, max_z=d_atom) 20 | use_adj = False 21 | elif model_name == 'painn': 22 | representation = PaiNN(n_atom_basis=d_model, n_interactions=N, max_z=d_atom) 23 | use_adj = False 24 | if use_global_feature: 25 | out = Generator_with_gf(d_model=d_model, d_gf=d_feature) 26 | else: 27 | out = Generator(d_model=d_model) 28 | model = 
BaselineModel(representation=representation, output=out, use_adj=use_adj) 29 | return model 30 | 31 | class Generator(nn.Module): 32 | def __init__(self, d_model): 33 | super(Generator, self).__init__() 34 | self.hidden_nf = d_model 35 | self.node_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf), 36 | nn.SiLU(), 37 | nn.Linear(self.hidden_nf, self.hidden_nf)) 38 | 39 | self.graph_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf), 40 | nn.SiLU(), 41 | nn.Linear(self.hidden_nf, 1)) 42 | 43 | def forward(self, h, atom_mask, global_feature=None): 44 | h = self.node_dec(h) 45 | h = h * atom_mask.unsqueeze(-1) 46 | h = torch.sum(h, dim=1) 47 | pred = self.graph_dec(h) 48 | return pred.squeeze(1) 49 | 50 | class Generator_with_gf(nn.Module): 51 | def __init__(self, d_model, d_gf): 52 | super(Generator_with_gf, self).__init__() 53 | self.hidden_nf = d_model 54 | self.input_nf = d_gf 55 | self.node_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf), 56 | nn.SiLU(), 57 | nn.Linear(self.hidden_nf, self.hidden_nf)) 58 | 59 | self.gf_enc = nn.Sequential(nn.Linear(self.input_nf, self.hidden_nf // 2), 60 | nn.SiLU(), 61 | nn.Linear(self.hidden_nf // 2, self.hidden_nf)) 62 | 63 | self.graph_dec = nn.Sequential(nn.Linear(self.hidden_nf * 2, self.hidden_nf), 64 | nn.SiLU(), 65 | nn.Linear(self.hidden_nf, 1)) 66 | 67 | def forward(self, h, atom_mask, global_feature): 68 | h = self.node_dec(h) 69 | h = h * atom_mask.unsqueeze(-1) 70 | h = torch.sum(h, dim=1) 71 | g = self.gf_enc(global_feature) 72 | h = torch.cat([h,g], dim=1) 73 | pred = self.graph_dec(h) 74 | return pred.squeeze(1) 75 | 76 | class BaselineModel(nn.Module): 77 | def __init__(self, representation, output, use_adj=True): 78 | super(BaselineModel, self).__init__() 79 | self.representation = representation 80 | self.output = output 81 | self.use_adj = use_adj 82 | def forward(self, node_features, batch_mask, pos, adj, global_feature=None): 83 | if not self.use_adj: 84 | neighbors, neighbor_mask = adj 85 | rep = self.representation(node_features, pos, neighbors, neighbor_mask, batch_mask) 86 | else: 87 | rep = self.representation(node_features, batch_mask, pos, adj) 88 | out = self.output(rep, batch_mask, global_feature) 89 | return out --------------------------------------------------------------------------------
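A minimal usage sketch for `make_baseline_model` (an illustrative example, not shipped with the repo; the shapes and values such as `d_atom=16` are made up, and it assumes the torch_geometric/torch_scatter/torch_sparse dependencies are installed):
```python
import torch
from baselines import make_baseline_model

# EGNN consumes a dense adjacency matrix (use_adj=True in the factory above).
model = make_baseline_model(d_atom=16, model_name='egnn', N=2, d_model=128,
                            use_global_feature=True, d_feature=9)

node_features = torch.randn(2, 4, 16)   # (batch, n_nodes, d_atom)
batch_mask = torch.ones(2, 4)           # 1 for real atoms, 0 for padding
pos = torch.randn(2, 4, 3)              # Cartesian coordinates
adj = torch.ones(2, 4, 4).long()        # dense adjacency, as built by mof_collate_func_adj
global_feature = torch.randn(2, 9)      # Zeo++-style graph-level descriptors

pred = model(node_features, batch_mask, pos, adj, global_feature)  # -> shape (2,)
```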
/baselines/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | 
import pandas as pd 7 | import torch 8 | from sklearn.metrics import pairwise_distances 9 | from torch.utils.data import Dataset 10 | from scipy.sparse import coo_matrix 11 | import json 12 | import copy 13 | 14 | 15 | FloatTensor = torch.FloatTensor 16 | LongTensor = torch.LongTensor 17 | IntTensor = torch.IntTensor 18 | DoubleTensor = torch.DoubleTensor 19 | 20 | def load_data_from_df(dataset_path, gas_type, pressure, use_global_features=False): 21 | print(dataset_path + f'/label/{gas_type}/{gas_type}_ads_all.csv') 22 | data_df = pd.read_csv(dataset_path + f'/label/{gas_type}/{gas_type}_ads_all.csv',header=0) 23 | 24 | data_x = data_df['name'].values 25 | if pressure == 'all': 26 | data_y = data_df.iloc[:,1:].values 27 | else: 28 | data_y = data_df[pressure].values 29 | 30 | if data_y.dtype == np.float64: 31 | data_y = data_y.astype(np.float32) 32 | 33 | x_all, y_all, name_all = load_data_from_processed(dataset_path, data_x, data_y) 34 | 35 | if use_global_features: 36 | f_all = load_data_with_global_features(dataset_path, name_all, gas_type) 37 | if pressure == 'all': 38 | return x_all, f_all, y_all, data_df.columns.values[1:] 39 | return x_all, f_all, y_all 40 | 41 | if pressure == 'all': 42 | return x_all, y_all, data_df.columns.values[1:] 43 | return x_all, y_all 44 | 45 | def load_data_with_global_features(dataset_path, processed_files, gas_type): 46 | global_feature_path = dataset_path + f'/label/{gas_type}/{gas_type}_global_features_update.csv' 47 | data_df = pd.read_csv(global_feature_path,header=0) 48 | data_x = data_df.iloc[:, 0].values 49 | data_f = data_df.iloc[:,1:].values.astype(np.float32) 50 | data_dict = {} 51 | for i in range(data_x.shape[0]): 52 | data_dict[data_x[i]] = data_f[i] 53 | f_all = [data_dict[_] for _ in processed_files] 54 | return f_all 55 | 56 | 57 | 58 | def load_data_from_processed(dataset_path, processed_files, labels): 59 | x_all, y_all, name_all = [], [], [] 60 | 61 | for files, label in zip(processed_files, labels): 62 | 63 | data_file = dataset_path + '/processed_en/' + files + '.p' 64 | try: 65 | afm, row, col, pos = pickle.load(open(data_file, "rb")) 66 | x_all.append([afm, row, col, pos]) 67 | y_all.append([label]) 68 | name_all.append(files) 69 | except: # silently skip MOFs whose processed file is missing or unreadable 70 | pass 71 | 72 | return x_all, y_all, name_all 73 | 74 | class MOF: 75 | 76 | def __init__(self, x, y, index, feature = None): 77 | self.node_features = x[0] 78 | self.edges = np.array([x[1],x[2]]) 79 | self.pos = x[3] 80 | self.y = y 81 | self.index = index 82 | self.global_feature = feature 83 | self.size = x[0].shape[0] 84 | self.adj, self.nbh, self.nbh_mask = self.neighbor_matrix() 85 | 86 | def neighbor_matrix(self): # builds a dense adjacency matrix plus a padded neighbor list for SchNet/PaiNN-style models 87 | csr = coo_matrix((np.ones_like(self.edges[0]), self.edges), shape=(self.size, self.size)).tocsr() 88 | rowptr, col = csr.indptr, csr.indices 89 | degree = rowptr[1:] - rowptr[:-1] 90 | max_d = degree.max() 91 | _range = np.tile(np.arange(max_d),(self.size,1)).reshape(-1) 92 | _degree = degree.repeat(max_d).reshape(-1) 93 | mask = _range < _degree 94 | ret_nbh = np.zeros(self.size * max_d) 95 | ret_nbh[mask] = col 96 | return csr.toarray(), ret_nbh.reshape(self.size, max_d), mask.reshape(self.size, max_d) 97 | 98 | 99 | class MOFDataset(Dataset): 100 | 101 | def __init__(self, data_list): 102 | 103 | self.data_list = data_list 104 | 105 | def __len__(self): 106 | return len(self.data_list) 107 | 108 | def __getitem__(self, key): 109 | if type(key) == slice: 110 | return MOFDataset(self.data_list[key]) 111 | return self.data_list[key] 112 | 113 | def 
construct_dataset_gf(x_all, f_all, y_all): 114 | output = [MOF(data[0], data[2], i, data[1]) 115 | for i, data in enumerate(zip(x_all, f_all, y_all))] 116 | return MOFDataset(output) 117 | 118 | def pad_array(array, shape, dtype=np.float32): 119 | padded_array = np.zeros(shape, dtype=dtype) 120 | padded_array[:array.shape[0], :array.shape[1]] = array 121 | return padded_array 122 | 123 | def mof_collate_func_adj(batch): 124 | pos_list, features_list,global_features_list = [], [], [] 125 | adjs = [] 126 | labels = [] 127 | 128 | max_size = 0 129 | for molecule in batch: 130 | if type(molecule.y[0]) == np.ndarray: 131 | labels.append(molecule.y[0]) 132 | else: 133 | labels.append(molecule.y) 134 | if molecule.node_features.shape[0] > max_size: 135 | max_size = molecule.node_features.shape[0] 136 | 137 | for molecule in batch: 138 | pos_list.append(pad_array(molecule.pos, (max_size, 3))) 139 | features_list.append(pad_array(molecule.node_features, (max_size, molecule.node_features.shape[1]))) 140 | adjs.append(pad_array(molecule.adj, (max_size, max_size))) 141 | global_features_list.append(molecule.global_feature) 142 | 143 | return [FloatTensor(features_list), FloatTensor(pos_list), LongTensor(adjs), FloatTensor(global_features_list), FloatTensor(labels)] 144 | 145 | def mof_collate_func_nbh(batch): 146 | pos_list, features_list, global_features_list = [], [], [] 147 | nbhs, nbh_masks = [],[] 148 | labels = [] 149 | 150 | max_size = 0 151 | max_degree = 0 152 | for molecule in batch: 153 | if type(molecule.y[0]) == np.ndarray: 154 | labels.append(molecule.y[0]) 155 | else: 156 | labels.append(molecule.y) 157 | if molecule.node_features.shape[0] > max_size: 158 | max_size = molecule.node_features.shape[0] 159 | if molecule.nbh.shape[1] > max_degree: 160 | max_degree = molecule.nbh.shape[1] 161 | 162 | for molecule in batch: 163 | pos_list.append(pad_array(molecule.pos, (max_size, 3))) 164 | features_list.append(pad_array(molecule.node_features, (max_size, molecule.node_features.shape[1]))) 165 | nbhs.append(pad_array(molecule.nbh, (max_size, max_degree))) 166 | nbh_masks.append(pad_array(molecule.nbh_mask, (max_size, max_degree))) 167 | global_features_list.append(molecule.global_feature) 168 | 169 | return [FloatTensor(features_list), FloatTensor(pos_list), LongTensor(nbhs), FloatTensor(nbh_masks), FloatTensor(global_features_list), FloatTensor(labels)] 170 | 171 | def construct_loader(x, f, y, batch_size, shuffle=True, use_adj=True): 172 | data_set = construct_dataset_gf(x, f, y) 173 | loader = torch.utils.data.DataLoader(dataset=data_set, 174 | batch_size=batch_size, 175 | num_workers=8, 176 | collate_fn=mof_collate_func_adj if use_adj else mof_collate_func_nbh, 177 | pin_memory=True, 178 | shuffle=shuffle) 179 | return loader 180 | 181 | class data_prefetcher(): 182 | def __init__(self, loader, device): 183 | self.loader = iter(loader) 184 | self.stream = torch.cuda.Stream(device) 185 | self.preload() 186 | 187 | def preload(self): 188 | try: 189 | self.next_data = next(self.loader) 190 | except StopIteration: 191 | self.next_data = None 192 | return 193 | with torch.cuda.stream(self.stream): 194 | self.next_data = tuple(_.cuda(non_blocking=True) for _ in self.next_data) 195 | 196 | def next(self): 197 | torch.cuda.current_stream().wait_stream(self.stream) 198 | batch = self.next_data 199 | self.preload() 200 | return batch -------------------------------------------------------------------------------- /baselines/dimenet_pp.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.nn import radius_graph 4 | from torch_geometric.nn.acts import swish 5 | from torch_geometric.nn.inits import glorot_orthogonal 6 | from torch_geometric.nn.models.dimenet import ( 7 | BesselBasisLayer, 8 | Envelope, 9 | ResidualLayer, 10 | SphericalBasisLayer, 11 | ) 12 | from torch_scatter import scatter 13 | from torch_sparse import SparseTensor 14 | import sympy as sym 15 | 16 | def dense_to_sparse(adj): 17 | r"""Converts a dense adjacency matrix to a sparse adjacency matrix defined 18 | by edge indices and edge attributes. 19 | 20 | Args: 21 | adj (Tensor): The dense adjacency matrix. 22 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 23 | """ 24 | assert adj.dim() >= 2 and adj.dim() <= 3 25 | assert adj.size(-1) == adj.size(-2) 26 | 27 | index = adj.nonzero(as_tuple=True) 28 | edge_attr = adj[index] 29 | 30 | if len(index) == 3: 31 | batch = index[0] * adj.size(-1) 32 | index = (batch + index[1], batch + index[2]) 33 | 34 | return torch.stack(index, dim=0), edge_attr 35 | 36 | class MLP(torch.nn.Module): 37 | def __init__(self, input_size, output_size, hidden_sizes, activation_hidden, activation_out, biases, dropout): 38 | super(MLP, self).__init__() 39 | self.activation_hidden = activation_hidden 40 | self.activation_out = activation_out 41 | self.dropout = dropout 42 | 43 | if len(hidden_sizes) > 0: 44 | self.linear_layers = torch.nn.ModuleList([torch.nn.Linear(input_size, hidden_sizes[0], bias = biases)]) 45 | self.linear_layers.extend([torch.nn.Linear(in_size, out_size, bias = biases) 46 | for (in_size, out_size) 47 | in zip(hidden_sizes[0:-1], (hidden_sizes[1:]))]) 48 | self.linear_layers.append(torch.nn.Linear(hidden_sizes[-1], output_size, bias = biases)) 49 | 50 | else: 51 | self.linear_layers = torch.nn.ModuleList([torch.nn.Linear(input_size, output_size, bias = biases)]) 52 | 53 | def forward(self, x): 54 | if len(self.linear_layers) == 1: 55 | out = self.activation_out(self.linear_layers[0](x)) 56 | 57 | else: 58 | out = self.activation_hidden(self.linear_layers[0](x)) 59 | for i, layer in enumerate(self.linear_layers[1:-1]): 60 | out = self.activation_hidden(layer(out)) 61 | out = torch.nn.functional.dropout(out, p = self.dropout, training = self.training) 62 | out = self.activation_out(self.linear_layers[-1](out)) 63 | 64 | return out 65 | 66 | class InteractionPPBlock(torch.nn.Module): 67 | def __init__( 68 | self, 69 | hidden_channels, 70 | int_emb_size, 71 | basis_emb_size, 72 | num_spherical, 73 | num_radial, 74 | num_before_skip, 75 | num_after_skip, 76 | act=swish, 77 | ): 78 | super(InteractionPPBlock, self).__init__() 79 | self.act = act 80 | 81 | # Transformations of Bessel and spherical basis representations. 82 | self.lin_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False) 83 | self.lin_rbf2 = nn.Linear(basis_emb_size, hidden_channels, bias=False) 84 | self.lin_sbf1 = nn.Linear( 85 | num_spherical * num_radial, basis_emb_size, bias=False 86 | ) 87 | self.lin_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False) 88 | 89 | # Dense transformations of input messages. 90 | self.lin_kj = nn.Linear(hidden_channels, hidden_channels) 91 | self.lin_ji = nn.Linear(hidden_channels, hidden_channels) 92 | 93 | # Embedding projections for interaction triplets. 
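# (lin_down / lin_up form the DimeNet++ bottleneck: messages are projected from hidden_channels down to int_emb_size before the expensive triplet interaction and back up afterwards; this is the main efficiency gain of DimeNet++ over the original DimeNet.)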
94 | self.lin_down = nn.Linear(hidden_channels, int_emb_size, bias=False) 95 | self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False) 96 | 97 | # Residual layers before and after skip connection. 98 | self.layers_before_skip = torch.nn.ModuleList( 99 | [ 100 | ResidualLayer(hidden_channels, act) 101 | for _ in range(num_before_skip) 102 | ] 103 | ) 104 | self.lin = nn.Linear(hidden_channels, hidden_channels) 105 | self.layers_after_skip = torch.nn.ModuleList( 106 | [ 107 | ResidualLayer(hidden_channels, act) 108 | for _ in range(num_after_skip) 109 | ] 110 | ) 111 | 112 | #self.reset_parameters() 113 | 114 | def reset_parameters(self): 115 | glorot_orthogonal(self.lin_rbf1.weight, scale=2.0) 116 | glorot_orthogonal(self.lin_rbf2.weight, scale=2.0) 117 | glorot_orthogonal(self.lin_sbf1.weight, scale=2.0) 118 | glorot_orthogonal(self.lin_sbf2.weight, scale=2.0) 119 | 120 | glorot_orthogonal(self.lin_kj.weight, scale=2.0) 121 | self.lin_kj.bias.data.fill_(0) 122 | glorot_orthogonal(self.lin_ji.weight, scale=2.0) 123 | self.lin_ji.bias.data.fill_(0) 124 | 125 | glorot_orthogonal(self.lin_down.weight, scale=2.0) 126 | glorot_orthogonal(self.lin_up.weight, scale=2.0) 127 | 128 | for res_layer in self.layers_before_skip: 129 | res_layer.reset_parameters() 130 | 131 | glorot_orthogonal(self.lin.weight, scale=2.0) 132 | self.lin.bias.data.fill_(0) 133 | 134 | for res_layer in self.layers_after_skip: 135 | res_layer.reset_parameters() 136 | 137 | def forward(self, x, rbf, sbf, idx_kj, idx_ji): 138 | # Initial transformations. 139 | x_ji = self.act(self.lin_ji(x)) 140 | x_kj = self.act(self.lin_kj(x)) 141 | 142 | # Transformation via Bessel basis. 143 | rbf = self.lin_rbf1(rbf) 144 | rbf = self.lin_rbf2(rbf) 145 | x_kj = x_kj * rbf 146 | 147 | # Down-project embeddings and generate interaction triplet embeddings. 148 | x_kj = self.act(self.lin_down(x_kj)) 149 | 150 | # Transform via 2D spherical basis. 151 | sbf = self.lin_sbf1(sbf) 152 | sbf = self.lin_sbf2(sbf) 153 | x_kj = x_kj[idx_kj] * sbf 154 | 155 | # Aggregate interactions and up-project embeddings. 
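# (scatter sums, for each directed edge j->i, the messages coming from all of its triplets k->j->i; idx_ji maps every triplet back to the slot of its j->i edge, so the result again has one row per edge.)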
156 | x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0)) 157 | x_kj = self.act(self.lin_up(x_kj)) 158 | 159 | h = x_ji + x_kj 160 | for layer in self.layers_before_skip: 161 | h = layer(h) 162 | h = self.act(self.lin(h)) + x 163 | for layer in self.layers_after_skip: 164 | h = layer(h) 165 | 166 | return h 167 | 168 | 169 | class OutputPPBlock(torch.nn.Module): 170 | def __init__( 171 | self, 172 | num_radial, 173 | hidden_channels, 174 | out_emb_channels, 175 | out_channels, 176 | num_layers, 177 | act=swish, 178 | ): 179 | super(OutputPPBlock, self).__init__() 180 | self.act = act 181 | 182 | self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) 183 | self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True) 184 | self.lins = torch.nn.ModuleList() 185 | for _ in range(num_layers): 186 | self.lins.append(nn.Linear(out_emb_channels, out_emb_channels)) 187 | self.lin = nn.Linear(out_emb_channels, out_channels, bias=False) 188 | 189 | #self.reset_parameters() 190 | 191 | def reset_parameters(self): 192 | glorot_orthogonal(self.lin_rbf.weight, scale=2.0) 193 | glorot_orthogonal(self.lin_up.weight, scale=2.0) 194 | for lin in self.lins: 195 | glorot_orthogonal(lin.weight, scale=2.0) 196 | lin.bias.data.fill_(0) 197 | self.lin.weight.data.fill_(0) 198 | 199 | def forward(self, x, rbf, i, num_nodes=None): 200 | x = self.lin_rbf(rbf) * x 201 | x = scatter(x, i, dim=0, dim_size=num_nodes) 202 | x = self.lin_up(x) 203 | for lin in self.lins: 204 | x = self.act(lin(x)) 205 | return self.lin(x) 206 | 207 | class EmbeddingBlock(torch.nn.Module): 208 | def __init__(self, num_input, num_radial, hidden_channels, act=swish): 209 | super().__init__() 210 | self.act = act 211 | 212 | self.emb = nn.Linear(num_input, hidden_channels) 213 | self.lin_rbf = nn.Linear(num_radial, hidden_channels) 214 | self.lin = nn.Linear(3 * hidden_channels, hidden_channels) 215 | 216 | self.reset_parameters() 217 | 218 | def reset_parameters(self): 219 | # self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) 220 | self.emb.reset_parameters() 221 | self.lin_rbf.reset_parameters() 222 | self.lin.reset_parameters() 223 | 224 | def forward(self, x, rbf, i, j): 225 | x = self.emb(x) 226 | rbf = self.act(self.lin_rbf(rbf)) 227 | return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) 228 | 229 | class DimeNetPlusPlus(torch.nn.Module): 230 | r"""DimeNet++ implementation based on https://github.com/klicperajo/dimenet. 231 | Args: 232 | hidden_channels (int): Hidden embedding size. 233 | out_channels (int): Size of each output sample. 234 | num_blocks (int): Number of building blocks. 235 | int_emb_size (int): Embedding size used for interaction triplets 236 | basis_emb_size (int): Embedding size used in the basis transformation 237 | out_emb_channels(int): Embedding size used for atoms in the output block 238 | num_spherical (int): Number of spherical harmonics. 239 | num_radial (int): Number of radial basis functions. 240 | cutoff: (float, optional): Cutoff distance for interatomic 241 | interactions. (default: :obj:`5.0`) 242 | envelope_exponent (int, optional): Shape of the smooth cutoff. 243 | (default: :obj:`5`) 244 | num_before_skip: (int, optional): Number of residual layers in the 245 | interaction blocks before the skip connection. (default: :obj:`1`) 246 | num_after_skip: (int, optional): Number of residual layers in the 247 | interaction blocks after the skip connection. (default: :obj:`2`) 248 | num_output_layers: (int, optional): Number of linear layers for the 249 | output blocks. 
(default: :obj:`3`) 250 | act: (function, optional): The activation function. 251 | (default: :obj:`swish`) 252 | """ 253 | 254 | url = "https://github.com/klicperajo/dimenet/raw/master/pretrained" 255 | 256 | def __init__( 257 | self, 258 | hidden_channels, 259 | out_channels, 260 | num_blocks, 261 | int_emb_size, 262 | basis_emb_size, 263 | out_emb_channels, 264 | num_spherical, 265 | num_radial, 266 | num_input, 267 | cutoff=5.0, 268 | envelope_exponent=5, 269 | num_before_skip=1, 270 | num_after_skip=2, 271 | num_output_layers=3, 272 | act=swish, 273 | MLP_hidden_sizes = [], 274 | ): 275 | super(DimeNetPlusPlus, self).__init__() 276 | 277 | self.MLP_hidden_sizes = MLP_hidden_sizes 278 | self.hidden_channels = hidden_channels 279 | 280 | self.cutoff = cutoff 281 | 282 | if sym is None: 283 | raise ImportError("Package `sympy` could not be found.") 284 | 285 | self.num_blocks = num_blocks 286 | 287 | self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent) 288 | self.sbf = SphericalBasisLayer( 289 | num_spherical, num_radial, cutoff, envelope_exponent 290 | ) 291 | 292 | # self.emb = EmbeddingBlock(num_radial, hidden_channels, act) 293 | self.emb = EmbeddingBlock(num_input, num_radial, hidden_channels, act) 294 | 295 | self.output_blocks = torch.nn.ModuleList( 296 | [ 297 | OutputPPBlock( 298 | num_radial, 299 | hidden_channels, 300 | out_emb_channels, 301 | out_channels, 302 | num_output_layers, 303 | act, 304 | ) 305 | for _ in range(num_blocks + 1) 306 | ] 307 | ) 308 | 309 | self.interaction_blocks = torch.nn.ModuleList( 310 | [ 311 | InteractionPPBlock( 312 | hidden_channels, 313 | int_emb_size, 314 | basis_emb_size, 315 | num_spherical, 316 | num_radial, 317 | num_before_skip, 318 | num_after_skip, 319 | act, 320 | ) 321 | for _ in range(num_blocks) 322 | ] 323 | ) 324 | 325 | # if len(self.MLP_hidden_sizes) > 0: 326 | # self.Output_MLP = MLP(input_size = out_channels, output_size = 1, hidden_sizes = MLP_hidden_sizes, activation_hidden = torch.nn.LeakyReLU(negative_slope=0.01), activation_out = torch.nn.Identity(), biases = True, dropout = 0.0) 327 | 328 | self.reset_parameters() 329 | 330 | def reset_parameters(self): 331 | self.rbf.reset_parameters() 332 | self.emb.reset_parameters() 333 | #for out in self.output_blocks: 334 | # out.reset_parameters() 335 | for interaction in self.interaction_blocks: 336 | interaction.reset_parameters() 337 | 338 | def triplets(self, edge_index, num_nodes): 339 | row, col = edge_index # j->i 340 | 341 | value = torch.arange(row.size(0), device=row.device) 342 | adj_t = SparseTensor( 343 | row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes) 344 | ) 345 | adj_t_row = adj_t[row] 346 | num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) 347 | 348 | # Node indices (k->j->i) for triplets. 349 | idx_i = col.repeat_interleave(num_triplets) 350 | idx_j = row.repeat_interleave(num_triplets) 351 | idx_k = adj_t_row.storage.col() 352 | mask = idx_i != idx_k # Remove i == k triplets. 353 | idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] 354 | 355 | # Edge indices (k->j, j->i) for triplets. 
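# (both index into the flat edge list: for every triplet k->j->i, idx_kj is the position of edge k->j and idx_ji the position of edge j->i.)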
356 | idx_kj = adj_t_row.storage.value()[mask] 357 | idx_ji = adj_t_row.storage.row()[mask] 358 | 359 | return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji 360 | 361 | def forward(self, node_features, batch_mask, pos, adj): 362 | """""" 363 | # edge_index = radius_graph(pos, r=self.cutoff, batch=batch) 364 | batch_size, n_nodes, in_node_nf = node_features.shape 365 | edge_index, _ = dense_to_sparse(adj) 366 | 367 | node_features = node_features.reshape(-1, in_node_nf) 368 | pos = pos.reshape(-1, 3) 369 | j, i = edge_index 370 | dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() 371 | 372 | _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( 373 | edge_index, num_nodes=node_features.size(0) 374 | ) 375 | 376 | # Calculate angles. 377 | pos_i = pos[idx_i].detach() 378 | pos_j = pos[idx_j].detach() 379 | pos_ji, pos_kj = ( 380 | pos[idx_j].detach() - pos_i, 381 | pos[idx_k].detach() - pos_j, 382 | ) 383 | 384 | a = (pos_ji * pos_kj).sum(dim=-1) 385 | b = torch.cross(pos_ji, pos_kj).norm(dim=-1) 386 | angle = torch.atan2(b, a) 387 | 388 | rbf = self.rbf(dist) 389 | sbf = self.sbf(dist, angle, idx_kj) 390 | 391 | # Embedding block. 392 | x = self.emb(node_features, rbf, i, j) 393 | P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0)) 394 | 395 | # Interaction blocks. 396 | for interaction_block, output_block in zip( 397 | self.interaction_blocks, self.output_blocks[1:] 398 | ): 399 | x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) 400 | P += output_block(x, rbf, i, num_nodes=pos.size(0)) 401 | 402 | # out = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0) 403 | 404 | # #if we are using a MLP for downstream target prediction 405 | # if len(self.MLP_hidden_sizes) > 0: 406 | # target = self.Output_MLP(out) 407 | # return target, out 408 | return P.view(-1, n_nodes, self.hidden_channels) 409 | 410 | # return out -------------------------------------------------------------------------------- /baselines/egnn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | def dense_to_sparse(adj): 5 | r"""Converts a dense adjacency matrix to a sparse adjacency matrix defined 6 | by edge indices and edge attributes. 7 | 8 | Args: 9 | adj (Tensor): The dense adjacency matrix. 
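For a batched input of shape [B, N, N], the returned node indices are offset per graph (graph b contributes indices in [b * N, (b + 1) * N)), which matches the flattened [B * N, ...] node and position tensors used in the forward pass below.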
10 | :rtype: (:class:`LongTensor`, :class:`Tensor`) 11 | """ 12 | assert adj.dim() >= 2 and adj.dim() <= 3 13 | assert adj.size(-1) == adj.size(-2) 14 | 15 | index = adj.nonzero(as_tuple=True) 16 | edge_attr = adj[index] 17 | 18 | if len(index) == 3: 19 | batch = index[0] * adj.size(-1) 20 | index = (batch + index[1], batch + index[2]) 21 | 22 | return torch.stack(index, dim=0), edge_attr 23 | 24 | class E_GCL(nn.Module): 25 | """ 26 | E(n) Equivariant Convolutional Layer 27 | 28 | """ 29 | 30 | def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False): 31 | super(E_GCL, self).__init__() 32 | input_edge = input_nf * 2 33 | self.residual = residual 34 | self.attention = attention 35 | self.normalize = normalize 36 | self.coords_agg = coords_agg 37 | self.tanh = tanh 38 | self.epsilon = 1e-8 39 | edge_coords_nf = 1 40 | 41 | self.edge_mlp = nn.Sequential( 42 | nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf), 43 | act_fn, 44 | nn.Linear(hidden_nf, hidden_nf), 45 | act_fn) 46 | self.node_mlp = nn.Sequential( 47 | nn.Linear(hidden_nf + input_nf, hidden_nf), 48 | act_fn, 49 | nn.Linear(hidden_nf, output_nf)) 50 | 51 | layer = nn.Linear(hidden_nf, 1, bias=False) 52 | torch.nn.init.xavier_uniform_(layer.weight, gain=0.001) 53 | 54 | coord_mlp = [] 55 | coord_mlp.append(nn.Linear(hidden_nf, hidden_nf)) 56 | coord_mlp.append(act_fn) 57 | coord_mlp.append(layer) 58 | if self.tanh: 59 | coord_mlp.append(nn.Tanh()) 60 | self.coord_mlp = nn.Sequential(*coord_mlp) 61 | 62 | if self.attention: 63 | self.att_mlp = nn.Sequential( 64 | nn.Linear(hidden_nf, 1), 65 | nn.Sigmoid()) 66 | 67 | def edge_model(self, source, target, radial, edge_attr): 68 | if edge_attr is None: # Unused. 
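# (the concatenation below builds the EGNN edge message m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2 [, a_ij]), where radial is the squared distance computed in coord2radial.)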
69 | out = torch.cat([source, target, radial], dim=1) 70 | else: 71 | out = torch.cat([source, target, radial, edge_attr], dim=1) 72 | out = self.edge_mlp(out) 73 | if self.attention: 74 | att_val = self.att_mlp(out) 75 | out = out * att_val 76 | return out 77 | 78 | def node_model(self, x, edge_index, edge_attr, node_attr): 79 | row, col = edge_index 80 | agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0)) 81 | if node_attr is not None: 82 | agg = torch.cat([x, agg, node_attr], dim=1) 83 | else: 84 | agg = torch.cat([x, agg], dim=1) 85 | out = self.node_mlp(agg) 86 | # if self.residual: 87 | # out = x + out 88 | return out, agg 89 | 90 | def coord_model(self, coord, edge_index, coord_diff, edge_feat): 91 | row, col = edge_index 92 | trans = coord_diff * self.coord_mlp(edge_feat) 93 | if self.coords_agg == 'sum': 94 | agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0)) 95 | elif self.coords_agg == 'mean': 96 | agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) 97 | else: 98 | raise Exception('Wrong coords_agg parameter: %s' % self.coords_agg) 99 | coord = coord + agg 100 | return coord 101 | 102 | def coord2radial(self, edge_index, coord): 103 | row, col = edge_index 104 | coord_diff = coord[row] - coord[col] 105 | radial = torch.sum(coord_diff**2, 1).unsqueeze(1) 106 | 107 | if self.normalize: 108 | norm = torch.sqrt(radial).detach() + self.epsilon 109 | coord_diff = coord_diff / norm 110 | 111 | return radial, coord_diff 112 | 113 | def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None, edge_mask=None, update_coords=True): 114 | row, col = edge_index 115 | radial, coord_diff = self.coord2radial(edge_index, coord) 116 | h0 = h 117 | edge_feat = self.edge_model(h[row], h[col], radial, edge_attr) 118 | if edge_mask is not None: 119 | edge_feat = edge_feat * edge_mask 120 | if update_coords: 121 | coord = self.coord_model(coord, edge_index, coord_diff, edge_feat) 122 | h, agg = self.node_model(h, edge_index, edge_feat, node_attr) 123 | if self.residual: 124 | h = h0 + h 125 | return h, coord, edge_attr 126 | 127 | 128 | def unsorted_segment_sum(data, segment_ids, num_segments): 129 | result_shape = (num_segments, data.size(1)) 130 | result = data.new_full(result_shape, 0) # Init empty result tensor. 131 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) 132 | result.scatter_add_(0, segment_ids, data) 133 | return result 134 | 135 | 136 | def unsorted_segment_mean(data, segment_ids, num_segments): 137 | result_shape = (num_segments, data.size(1)) 138 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) 139 | result = data.new_full(result_shape, 0) # Init empty result tensor. 
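# (scatter-style mean: 'result' accumulates per-segment sums and 'count' the number of contributions; clamp(min=1) below avoids division by zero for empty segments.)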
140 | count = data.new_full(result_shape, 0) 141 | result.scatter_add_(0, segment_ids, data) 142 | count.scatter_add_(0, segment_ids, torch.ones_like(data)) 143 | return result / count.clamp(min=1) 144 | 145 | 146 | def get_edges(n_nodes): 147 | rows, cols = [], [] 148 | for i in range(n_nodes): 149 | for j in range(n_nodes): 150 | if i != j: 151 | rows.append(i) 152 | cols.append(j) 153 | 154 | edges = [rows, cols] 155 | return edges 156 | 157 | 158 | def get_edges_batch(n_nodes, batch_size): 159 | edges = get_edges(n_nodes) 160 | edge_attr = torch.ones(len(edges[0]) * batch_size, 1) 161 | edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])] 162 | if batch_size == 1: 163 | return edges, edge_attr 164 | elif batch_size > 1: 165 | rows, cols = [], [] 166 | for i in range(batch_size): 167 | rows.append(edges[0] + n_nodes * i) 168 | cols.append(edges[1] + n_nodes * i) 169 | edges = [torch.cat(rows), torch.cat(cols)] 170 | return edges, edge_attr 171 | 172 | class EGNN(nn.Module): 173 | def __init__(self, in_node_nf, hidden_nf, in_edge_nf=0, act_fn=nn.SiLU(), n_layers=4, residual=True, attention=False, normalize=False, tanh=False): 174 | ''' 175 | 176 | :param in_node_nf: Number of features for 'h' at the input 177 | :param hidden_nf: Number of hidden features 178 | :param out_node_nf: Number of features for 'h' at the output 179 | :param in_edge_nf: Number of features for the edge features 180 | :param device: Device (e.g. 'cpu', 'cuda:0',...) 181 | :param act_fn: Non-linearity 182 | :param n_layers: Number of layers for the EGNN 183 | :param residual: Use residual connections, we recommend not changing this one 184 | :param attention: Whether to use attention or not 185 | :param normalize: Normalizes the coordinate messages such that: 186 | instead of: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij) 187 | we get: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)/||x_i - x_j|| 188 | We noticed it may help with stability or generalization in some cases. 189 | We didn't use it in our paper. 190 | :param tanh: Sets a tanh activation function at the output of phi_x(m_ij). I.e. it bounds the output of 191 | phi_x(m_ij), which definitely improves stability but may decrease accuracy. 192 | We didn't use it in our paper. 
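In this repo, the forward pass takes (node_features [B, N, in_node_nf], batch_mask [B, N], pos [B, N, 3], dense adjacency [B, N, N]) and returns node embeddings of shape [B, N, hidden_nf]; graph-level pooling is done by the Generator heads in baselines/__init__.py.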
193 |         '''
194 | 
195 |         super(EGNN, self).__init__()
196 |         self.hidden_nf = hidden_nf
197 |         # self.device = device
198 |         self.n_layers = n_layers
199 |         self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
200 |         # self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
201 |         for i in range(0, n_layers):
202 |             self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
203 |                                                 act_fn=act_fn, residual=residual, attention=attention,
204 |                                                 normalize=normalize, tanh=tanh))
205 |         # self.node_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf),
206 |         #                               act_fn,
207 |         #                               nn.Linear(self.hidden_nf, self.hidden_nf))
208 | 
209 |         # self.graph_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf),
210 |         #                                act_fn,
211 |         #                                nn.Linear(self.hidden_nf, 1))
212 |         # self.to(self.device)
213 | 
214 |     def forward(self, node_features, batch_mask, pos, adj):
215 |         batch_size, n_nodes, in_node_nf = node_features.shape
216 |         edge_index, _ = dense_to_sparse(adj)
217 | 
218 |         h = node_features.reshape(-1, in_node_nf)
219 |         x = pos.reshape(-1, 3)
220 |         h = self.embedding_in(h)
221 |         for i in range(0, self.n_layers):
222 |             h, x, _ = self._modules["gcl_%d" % i](h, edge_index, x, edge_attr=None, edge_mask=None, update_coords=False)
223 |         # h = self.node_dec(h)
224 |         # h = h.view(-1, n_nodes, self.hidden_nf)
225 |         # h = h * batch_mask.unsqueeze(-1)
226 |         # h = torch.sum(h, dim=1)
227 |         # pred = self.graph_dec(h)
228 |         # return pred.squeeze(1)
229 |         return h.view(-1, n_nodes, self.hidden_nf)
230 | 
-------------------------------------------------------------------------------- /baselines/painn.py: --------------------------------------------------------------------------------
1 | import math
2 | from . import spk_utils as snn
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .spk_utils.neighbors import atom_distances
7 | from typing import Union, Callable
8 | 
9 | class BesselBasis(nn.Module):
10 |     """
11 |     Sine-based radial basis expansion with 1/r decay (zeroth-order Bessel functions, as in DimeNet).
12 |     """
13 | 
14 |     def __init__(self, cutoff=5.0, n_rbf=None):
15 |         """
16 |         Args:
17 |             cutoff: radial cutoff
18 |             n_rbf: number of basis functions.
19 |         """
20 |         super(BesselBasis, self).__init__()
21 |         # compute the frequencies n * pi / cutoff of the sine basis functions
22 |         freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
23 |         self.register_buffer("freqs", freqs)
24 | 
25 |     def forward(self, inputs):
26 |         a = self.freqs[None, None, None, :]
27 |         ax = inputs * a
28 |         sinax = torch.sin(ax)
29 | 
30 |         norm = torch.where(inputs == 0, torch.tensor(1.0, device=inputs.device), inputs)
31 |         y = sinax / norm
32 | 
33 |         return y
34 | 
35 | act_class_mapping = {
36 |     "ssp": snn.ShiftedSoftplus,
37 |     "silu": nn.SiLU,
38 |     "tanh": nn.Tanh,
39 |     "sigmoid": nn.Sigmoid,
40 | }
41 | 
42 | 
43 | class GatedEquivariantBlock(nn.Module):
44 |     """Gated Equivariant Block as defined in Schütt et al.
(2021): 45 | Equivariant message passing for the prediction of tensorial properties and molecular spectra 46 | """ 47 | 48 | def __init__( 49 | self, 50 | hidden_channels, 51 | out_channels, 52 | intermediate_channels=None, 53 | activation="silu", 54 | scalar_activation=False, 55 | ): 56 | super(GatedEquivariantBlock, self).__init__() 57 | self.out_channels = out_channels 58 | 59 | if intermediate_channels is None: 60 | intermediate_channels = hidden_channels 61 | 62 | self.vec1_proj = nn.Linear(hidden_channels, hidden_channels) 63 | self.vec2_proj = nn.Linear(hidden_channels, out_channels) 64 | 65 | act_class = act_class_mapping[activation] 66 | self.update_net = nn.Sequential( 67 | nn.Linear(hidden_channels * 2, intermediate_channels), 68 | act_class(), 69 | nn.Linear(intermediate_channels, out_channels * 2), 70 | ) 71 | 72 | self.act = act_class() if scalar_activation else None 73 | 74 | def reset_parameters(self): 75 | nn.init.xavier_uniform_(self.vec1_proj.weight) 76 | nn.init.xavier_uniform_(self.vec2_proj.weight) 77 | nn.init.xavier_uniform_(self.update_net[0].weight) 78 | self.update_net[0].bias.data.fill_(0) 79 | nn.init.xavier_uniform_(self.update_net[2].weight) 80 | self.update_net[2].bias.data.fill_(0) 81 | 82 | def forward(self, x, v): 83 | vec1 = torch.norm(self.vec1_proj(v), dim=-2) 84 | vec2 = self.vec2_proj(v) 85 | 86 | x = torch.cat([x, vec1], dim=-1) 87 | x, v = torch.split(self.update_net(x), self.out_channels, dim=-1) 88 | v = v.unsqueeze(2) * vec2 89 | 90 | if self.act is not None: 91 | x = self.act(x) 92 | return x, v 93 | 94 | class PaiNN(nn.Module): 95 | """ Polarizable atom interaction neural network """ 96 | def __init__( 97 | self, 98 | n_atom_basis: int = 128, 99 | n_interactions: int = 3, 100 | n_rbf: int = 20, 101 | cutoff: float = 5., 102 | cutoff_network: Union[nn.Module, str] = 'cosine', 103 | radial_basis: Callable = BesselBasis, 104 | activation=F.silu, 105 | max_z: int = 100, 106 | store_neighbors: bool = False, 107 | store_embeddings: bool = False, 108 | n_edge_features: int = 0, 109 | ): 110 | super(PaiNN, self).__init__() 111 | 112 | self.n_atom_basis = n_atom_basis 113 | self.n_interactions = n_interactions 114 | self.cutoff = cutoff 115 | self.cutoff_network = snn.get_cutoff_by_string(cutoff_network)(cutoff) 116 | self.radial_basis = radial_basis(cutoff=cutoff, n_rbf=n_rbf) 117 | self.embedding = nn.Linear(max_z, n_atom_basis) 118 | 119 | self.store_neighbors = store_neighbors 120 | self.store_embeddings = store_embeddings 121 | self.n_edge_features = n_edge_features 122 | 123 | # if self.n_edge_features: 124 | # self.edge_embedding = nn.Embedding(n_edge_features, self.n_interactions * 3 * n_atom_basis, padding_idx=0, max_norm=1.0) 125 | 126 | if type(activation) is str: 127 | if activation == 'swish': 128 | activation = F.silu 129 | elif activation == 'softplus': 130 | activation = snn.shifted_softplus 131 | 132 | self.filter_net = snn.Dense( 133 | n_rbf + n_edge_features, self.n_interactions * 3 * n_atom_basis, activation=None 134 | ) 135 | 136 | self.interatomic_context_net = nn.ModuleList( 137 | [ 138 | nn.Sequential( 139 | snn.Dense(n_atom_basis, n_atom_basis, activation=activation), 140 | snn.Dense(n_atom_basis, 3 * n_atom_basis, activation=None), 141 | ) 142 | for _ in range(self.n_interactions) 143 | ] 144 | ) 145 | 146 | self.intraatomic_context_net = nn.ModuleList( 147 | [ 148 | nn.Sequential( 149 | snn.Dense( 150 | 2 * n_atom_basis, n_atom_basis, activation=activation 151 | ), 152 | snn.Dense(n_atom_basis, 3 * n_atom_basis, 
activation=None), 153 | ) 154 | for _ in range(self.n_interactions) 155 | ] 156 | ) 157 | 158 | self.mu_channel_mix = nn.ModuleList( 159 | [ 160 | nn.Sequential( 161 | snn.Dense(n_atom_basis, 2 * n_atom_basis, activation=None, bias=False) 162 | ) 163 | for _ in range(self.n_interactions) 164 | ] 165 | ) 166 | 167 | # self.node_dec = nn.Sequential(snn.Dense(self.n_atom_basis, self.n_atom_basis, activation=F.silu), 168 | # snn.Dense(self.n_atom_basis, self.n_atom_basis)) 169 | 170 | # self.graph_dec = nn.Sequential(snn.Dense(self.n_atom_basis, self.n_atom_basis, activation=F.silu), 171 | # snn.Dense(self.n_atom_basis, 1)) 172 | 173 | def forward(self, node_features, positions, neighbors, neighbor_mask, atom_mask): 174 | cell = None 175 | cell_offset = None 176 | # get interatomic vectors and distances 177 | rij, dir_ij = atom_distances( 178 | positions=positions, 179 | neighbors=neighbors, 180 | neighbor_mask=neighbor_mask, 181 | cell=cell, 182 | cell_offsets=cell_offset, 183 | return_vecs=True, 184 | normalize_vecs=True, 185 | ) 186 | 187 | phi_ij = self.radial_basis(rij[..., None]) 188 | 189 | fcut = self.cutoff_network(rij) * neighbor_mask 190 | # fcut = neighbor_mask 191 | fcut = fcut.unsqueeze(-1) 192 | 193 | filters = self.filter_net(phi_ij) 194 | 195 | # if self.n_edge_features: 196 | # edge_types = inputs['edge_types'] 197 | # filters = filters + self.edge_embedding(edge_types) 198 | 199 | filters = filters * fcut 200 | filters = torch.split(filters, 3 * self.n_atom_basis, dim=-1) 201 | 202 | # initialize scalar and vector embeddings 203 | scalars = self.embedding(node_features) 204 | 205 | sshape = scalars.shape 206 | vectors = torch.zeros((sshape[0], sshape[1], 3, sshape[2]), device=scalars.device) 207 | 208 | for i in range(self.n_interactions): 209 | # message function 210 | h_i = self.interatomic_context_net[i](scalars) 211 | h_j, vectors_j = self.collect_neighbors(h_i, vectors, neighbors) 212 | 213 | # neighborhood context 214 | h_i = filters[i] * h_j 215 | 216 | dscalars, dvR, dvv = torch.split(h_i, self.n_atom_basis, dim=-1) 217 | dvectors = torch.einsum("bijf,bijd->bidf", dvR, dir_ij) + torch.einsum( 218 | "bijf,bijdf->bidf", dvv, vectors_j 219 | ) 220 | dscalars = torch.sum(dscalars, dim=2) 221 | scalars = scalars + dscalars 222 | vectors = vectors + dvectors 223 | 224 | # update function 225 | mu_mix = self.mu_channel_mix[i](vectors) 226 | vectors_V, vectors_U = torch.split(mu_mix, self.n_atom_basis, dim=-1) 227 | mu_Vn = torch.norm(vectors_V, dim=2) 228 | 229 | ctx = torch.cat([scalars, mu_Vn], dim=-1) 230 | h_i = self.intraatomic_context_net[i](ctx) 231 | ds, dv, dsv = torch.split(h_i, self.n_atom_basis, dim=-1) 232 | dv = dv.unsqueeze(2) * vectors_U 233 | dsv = dsv * torch.einsum("bidf,bidf->bif", vectors_V, vectors_U) 234 | 235 | # calculate atomwise updates 236 | scalars = scalars + ds + dsv 237 | vectors = vectors + dv 238 | 239 | # h = self.node_dec(scalars) 240 | # h = h * atom_mask.unsqueeze(-1) 241 | # h = torch.sum(h, dim=1) 242 | # pred = self.graph_dec(h) 243 | # return pred.squeeze(1) 244 | return scalars 245 | 246 | # for layer in self.output_network: 247 | # scalars, vectors = layer(scalars, vectors) 248 | # # include v in output to make sure all parameters have a gradient 249 | # pred = scalars + vectors.sum() * 0 250 | # pred = pred.squeeze(-1) * atom_mask 251 | # return torch.sum(pred, dim = -1) 252 | # # scalars = self.scalar_LN(scalars) 253 | # # vectors = self.vector_LN(vectors) 254 | 255 | 256 | 257 | def collect_neighbors(self, scalars, vectors, 
neighbors): 258 | nbh_size = neighbors.size() 259 | nbh = neighbors.view(-1, nbh_size[1] * nbh_size[2], 1) 260 | 261 | scalar_nbh = nbh.expand(-1, -1, scalars.size(2)) 262 | scalars_j = torch.gather(scalars, 1, scalar_nbh) 263 | scalars_j = scalars_j.view(nbh_size[0], nbh_size[1], nbh_size[2], -1) 264 | 265 | vectors_nbh = nbh[..., None].expand(-1, -1, vectors.size(2), vectors.size(3)) 266 | vectors_j = torch.gather(vectors, 1, vectors_nbh) 267 | vectors_j = vectors_j.view(nbh_size[0], nbh_size[1], nbh_size[2], 3, -1) 268 | return scalars_j, vectors_j 269 | -------------------------------------------------------------------------------- /baselines/schnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .spk_utils.base import Dense 5 | from .spk_utils.cfconv import CFConv 6 | from .spk_utils.cutoff import CosineCutoff 7 | from .spk_utils.acsf import GaussianSmearing 8 | from .spk_utils.neighbors import AtomDistances 9 | from .spk_utils.activations import shifted_softplus 10 | 11 | 12 | __all__ = ["SchNetInteraction", "SchNet"] 13 | 14 | 15 | class SchNetInteraction(nn.Module): 16 | r"""SchNet interaction block for modeling interactions of atomistic systems. 17 | 18 | Args: 19 | n_atom_basis (int): number of features to describe atomic environments. 20 | n_spatial_basis (int): number of input features of filter-generating networks. 21 | n_filters (int): number of filters used in continuous-filter convolution. 22 | cutoff (float): cutoff radius. 23 | cutoff_network (nn.Module, optional): cutoff layer. 24 | normalize_filter (bool, optional): if True, divide aggregated filter by number 25 | of neighbors over which convolution is applied. 26 | 27 | """ 28 | 29 | def __init__( 30 | self, 31 | n_atom_basis, 32 | n_spatial_basis, 33 | n_filters, 34 | cutoff, 35 | cutoff_network=CosineCutoff, 36 | normalize_filter=False, 37 | ): 38 | super(SchNetInteraction, self).__init__() 39 | # filter block used in interaction block 40 | self.filter_network = nn.Sequential( 41 | Dense(n_spatial_basis, n_filters, activation=shifted_softplus), 42 | Dense(n_filters, n_filters), 43 | ) 44 | # cutoff layer used in interaction block 45 | self.cutoff_network = cutoff_network(cutoff) 46 | # interaction block 47 | self.cfconv = CFConv( 48 | n_atom_basis, 49 | n_filters, 50 | n_atom_basis, 51 | self.filter_network, 52 | cutoff_network=self.cutoff_network, 53 | activation=shifted_softplus, 54 | normalize_filter=normalize_filter, 55 | ) 56 | # dense layer 57 | self.dense = Dense(n_atom_basis, n_atom_basis, bias=True, activation=None) 58 | 59 | def forward(self, x, r_ij, neighbors, neighbor_mask, f_ij=None): 60 | """Compute interaction output. 61 | 62 | Args: 63 | x (torch.Tensor): input representation/embedding of atomic environments 64 | with (N_b, N_a, n_atom_basis) shape. 65 | r_ij (torch.Tensor): interatomic distances of (N_b, N_a, N_nbh) shape. 66 | neighbors (torch.Tensor): indices of neighbors of (N_b, N_a, N_nbh) shape. 67 | neighbor_mask (torch.Tensor): mask to filter out non-existing neighbors 68 | introduced via padding. 69 | f_ij (torch.Tensor, optional): expanded interatomic distances in a basis. 70 | If None, r_ij.unsqueeze(-1) is used. 71 | 72 | Returns: 73 | torch.Tensor: block output with (N_b, N_a, n_atom_basis) shape. 
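
        In SchNet.forward below, this block output is applied as a residual
        update, x = x + v (a summary of the surrounding code, not new
        behavior).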
74 | 
75 |         """
76 |         # continuous-filter convolution interaction block followed by Dense layer
77 |         v = self.cfconv(x, r_ij, neighbors, neighbor_mask, f_ij)
78 |         v = self.dense(v)
79 |         return v
80 | 
81 | 
82 | class SchNet(nn.Module):
83 |     """SchNet architecture for learning representations of atomistic systems.
84 | 
85 |     Args:
86 |         n_atom_basis (int, optional): number of features to describe atomic environments.
87 |             This determines the size of each embedding vector; i.e. embeddings_dim.
88 |         n_filters (int, optional): number of filters used in continuous-filter convolution
89 |         n_interactions (int, optional): number of interaction blocks.
90 |         cutoff (float, optional): cutoff radius.
91 |         n_gaussians (int, optional): number of Gaussian functions used to expand
92 |             atomic distances.
93 |         normalize_filter (bool, optional): if True, divide aggregated filter by number
94 |             of neighbors over which convolution is applied.
95 |         coupled_interactions (bool, optional): if True, share the weights across
96 |             interaction blocks and filter-generating networks.
97 |         return_intermediate (bool, optional): if True, `forward` method also returns
98 |             intermediate atomic representations after each interaction block is applied.
99 |         max_z (int, optional): maximum nuclear charge allowed in database. This
100 |             determines the size of the dictionary of embedding; i.e. num_embeddings.
101 |         cutoff_network (nn.Module, optional): cutoff layer.
102 |         trainable_gaussians (bool, optional): If True, widths and offset of Gaussian
103 |             functions are adjusted during training process.
104 |         distance_expansion (nn.Module, optional): layer for expanding interatomic
105 |             distances in a basis.
106 |         charged_systems (bool, optional):
107 | 
108 |     References:
109 |     .. [#schnet1] Schütt, Arbabzadah, Chmiela, Müller, Tkatchenko:
110 |        Quantum-chemical insights from deep tensor neural networks.
111 |        Nature Communications, 8, 13890. 2017.
112 |     .. [#schnet_transfer] Schütt, Kindermans, Sauceda, Chmiela, Tkatchenko, Müller:
113 |        SchNet: A continuous-filter convolutional neural network for modeling quantum
114 |        interactions.
115 |        In Advances in Neural Information Processing Systems, pp. 992-1002. 2017.
116 |     .. [#schnet3] Schütt, Sauceda, Kindermans, Tkatchenko, Müller:
117 |        SchNet - a deep learning architecture for molecules and materials.
118 |        The Journal of Chemical Physics 148 (24), 241722. 2018.
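
    A minimal usage sketch (shapes follow the forward() docstring; the
    feature width of node_features must equal max_z, since the embedding in
    this variant is a linear layer over one-hot element vectors):

        net = SchNet(n_atom_basis=128, n_interactions=3, cutoff=5.0)
        # node_features: (N_b, N_at, 100) one-hot, with the default max_z=100
        # positions: (N_b, N_at, 3); neighbors, neighbor_mask: (N_b, N_at, N_nbh)
        # x = net(node_features, positions, neighbors, neighbor_mask, atom_mask)
        # -> x: (N_b, N_at, 128)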
119 | 120 | """ 121 | 122 | def __init__( 123 | self, 124 | n_atom_basis=128, 125 | n_filters=128, 126 | n_interactions=3, 127 | cutoff=5.0, 128 | n_gaussians=25, 129 | normalize_filter=False, 130 | coupled_interactions=False, 131 | return_intermediate=False, 132 | max_z=100, 133 | cutoff_network=CosineCutoff, 134 | trainable_gaussians=False, 135 | distance_expansion=None, 136 | charged_systems=False, 137 | ): 138 | super(SchNet, self).__init__() 139 | 140 | self.n_atom_basis = n_atom_basis 141 | # make a lookup table to store embeddings for each element (up to atomic 142 | # number max_z) each of which is a vector of size n_atom_basis 143 | self.embedding = nn.Linear(max_z, n_atom_basis) 144 | 145 | # layer for computing interatomic distances 146 | self.distances = AtomDistances() 147 | 148 | # layer for expanding interatomic distances in a basis 149 | if distance_expansion is None: 150 | self.distance_expansion = GaussianSmearing( 151 | 0.0, cutoff, n_gaussians, trainable=trainable_gaussians 152 | ) 153 | else: 154 | self.distance_expansion = distance_expansion 155 | 156 | # block for computing interaction 157 | if coupled_interactions: 158 | # use the same SchNetInteraction instance (hence the same weights) 159 | self.interactions = nn.ModuleList( 160 | [ 161 | SchNetInteraction( 162 | n_atom_basis=n_atom_basis, 163 | n_spatial_basis=n_gaussians, 164 | n_filters=n_filters, 165 | cutoff_network=cutoff_network, 166 | cutoff=cutoff, 167 | normalize_filter=normalize_filter, 168 | ) 169 | ] 170 | * n_interactions 171 | ) 172 | else: 173 | # use one SchNetInteraction instance for each interaction 174 | self.interactions = nn.ModuleList( 175 | [ 176 | SchNetInteraction( 177 | n_atom_basis=n_atom_basis, 178 | n_spatial_basis=n_gaussians, 179 | n_filters=n_filters, 180 | cutoff_network=cutoff_network, 181 | cutoff=cutoff, 182 | normalize_filter=normalize_filter, 183 | ) 184 | for _ in range(n_interactions) 185 | ] 186 | ) 187 | 188 | # self.node_dec = nn.Sequential(Dense(self.n_atom_basis, self.n_atom_basis, activation=shifted_softplus), 189 | # Dense(self.n_atom_basis, self.n_atom_basis)) 190 | 191 | # self.graph_dec = nn.Sequential(Dense(self.n_atom_basis, self.n_atom_basis, activation=shifted_softplus), 192 | # Dense(self.n_atom_basis, 1)) 193 | 194 | # set attributes 195 | self.return_intermediate = return_intermediate 196 | self.charged_systems = charged_systems 197 | if charged_systems: 198 | self.charge = nn.Parameter(torch.Tensor(1, n_atom_basis)) 199 | self.charge.data.normal_(0, 1.0 / n_atom_basis ** 0.5) 200 | 201 | def forward(self, node_features, positions, neighbors, neighbor_mask, atom_mask): 202 | """Compute atomic representations/embeddings. 203 | 204 | Args: 205 | inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors. 206 | 207 | Returns: 208 | torch.Tensor: atom-wise representation. 209 | list of torch.Tensor: intermediate atom-wise representations, if 210 | return_intermediate=True was used. 
211 | 
212 |         """
213 |         # unpack inputs (no periodic cell information is used here)
214 |         cell = None
215 |         cell_offset = None
216 |         _, n_nodes, _ = node_features.shape
217 |         # get atom embeddings for the input atomic numbers
218 |         x = self.embedding(node_features)
219 | 
220 |         # compute interatomic distance of every atom to its neighbors
221 |         r_ij = self.distances(
222 |             positions, neighbors, cell, cell_offset, neighbor_mask=neighbor_mask
223 |         )
224 |         # expand interatomic distances (for example, Gaussian smearing)
225 |         f_ij = self.distance_expansion(r_ij)
226 |         # store intermediate representations
227 |         if self.return_intermediate:
228 |             xs = [x]
229 |         # compute interaction block to update atomic embeddings
230 |         for interaction in self.interactions:
231 |             v = interaction(x, r_ij, neighbors, neighbor_mask, f_ij=f_ij)
232 |             x = x + v
233 |             if self.return_intermediate:
234 |                 xs.append(x)
235 | 
236 |         # h = self.node_dec(x)
237 |         # h = h * atom_mask.unsqueeze(-1)
238 |         # h = torch.sum(h, dim=1)
239 |         # pred = self.graph_dec(h)
240 |         # return pred.squeeze(1)
241 |         return (x, xs) if self.return_intermediate else x  # matches the Returns section above
-------------------------------------------------------------------------------- /baselines/spk_utils/__init__.py: --------------------------------------------------------------------------------
1 | """
2 | Basic building blocks of SchNetPack models. Contains various basic and specialized network layers, layers for
3 | cutoff functions, as well as several auxiliary layers and functions.
4 | """
5 | 
6 | from .acsf import *
7 | from .activations import *
8 | from .base import *
9 | from .blocks import *
10 | from .cfconv import *
11 | from .cutoff import *
12 | from .initializers import *
13 | from .neighbors import *
-------------------------------------------------------------------------------- /baselines/spk_utils/acsf.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from .cutoff import CosineCutoff
5 | 
6 | __all__ = [
7 |     "AngularDistribution",
8 |     "BehlerAngular",
9 |     "GaussianSmearing",
10 |     "RadialDistribution",
11 | ]
12 | 
13 | 
14 | class AngularDistribution(nn.Module):
15 |     """
16 |     Routine used to compute angular type symmetry functions between all atoms i-j-k, where i is the central atom.
17 | 
18 |     Args:
19 |         radial_filter (callable): Function used to expand distances (e.g. Gaussians)
20 |         angular_filter (callable): Function used to expand angles between triples of atoms (e.g. BehlerAngular)
21 |         cutoff_functions (callable): Cutoff function
22 |         crossterms (bool): Include radial contributions of the distances r_jk
23 |         pairwise_elements (bool): Recombine elemental embedding vectors via an outer product. If e.g. one-hot encoding
24 |             is used for the elements, this is equivalent to standard Behler functions
25 |             (default=False).
26 | 
27 |     """
28 | 
29 |     def __init__(
30 |         self,
31 |         radial_filter,
32 |         angular_filter,
33 |         cutoff_functions=CosineCutoff,
34 |         crossterms=False,
35 |         pairwise_elements=False,
36 |     ):
37 |         super(AngularDistribution, self).__init__()
38 |         self.radial_filter = radial_filter
39 |         self.angular_filter = angular_filter
40 |         self.cutoff_function = cutoff_functions
41 |         self.crossterms = crossterms
42 |         self.pairwise_elements = pairwise_elements
43 | 
44 |     def forward(self, r_ij, r_ik, r_jk, triple_masks=None, elemental_weights=None):
45 |         """
46 |         Args:
47 |             r_ij (torch.Tensor): Distances to neighbor j
48 |             r_ik (torch.Tensor): Distances to neighbor k
49 |             r_jk (torch.Tensor): Distances between neighbor j and k
50 |             triple_masks (torch.Tensor): Tensor mask for non-counted pairs (e.g.
due to cutoff)
51 |             elemental_weights (tuple of two torch.Tensor): Weighting functions for neighboring elements, first is for
52 |                 neighbors j, second for k
53 | 
54 |         Returns:
55 |             torch.Tensor: Angular distribution functions
56 | 
57 |         """
58 | 
59 |         nbatch, natoms, npairs = r_ij.size()
60 | 
61 |         # expand the distances to neighbors j and k in the radial basis
62 |         radial_ij = self.radial_filter(r_ij)
63 |         radial_ik = self.radial_filter(r_ik)
64 |         angular_distribution = radial_ij * radial_ik
65 | 
66 |         if self.crossterms:
67 |             radial_jk = self.radial_filter(r_jk)
68 |             angular_distribution = angular_distribution * radial_jk
69 | 
70 |         # Use cosine rule to compute cos( theta_ijk )
71 |         cos_theta = (torch.pow(r_ij, 2) + torch.pow(r_ik, 2) - torch.pow(r_jk, 2)) / (
72 |             2.0 * r_ij * r_ik
73 |         )
74 | 
75 |         # Required in order to catch NaNs during backprop
76 |         if triple_masks is not None:
77 |             cos_theta[triple_masks == 0] = 0.0
78 | 
79 |         angular_term = self.angular_filter(cos_theta)
80 | 
81 |         if self.cutoff_function is not None:
82 |             cutoff_ij = self.cutoff_function(r_ij).unsqueeze(-1)
83 |             cutoff_ik = self.cutoff_function(r_ik).unsqueeze(-1)
84 |             angular_distribution = angular_distribution * cutoff_ij * cutoff_ik
85 | 
86 |             if self.crossterms:
87 |                 cutoff_jk = self.cutoff_function(r_jk).unsqueeze(-1)
88 |                 angular_distribution = angular_distribution * cutoff_jk
89 | 
90 |         # Zero out contributions from masked triples
91 |         if triple_masks is not None:
92 |             # Filter out nan divisions via boolean mask, since
93 |             # angular_term = angular_term * triple_masks
94 |             # is not working (nan*0 = nan)
95 |             angular_term[triple_masks == 0] = 0.0
96 |             angular_distribution[triple_masks == 0] = 0.0
97 | 
98 |         # Apply weights here, since dimension is still the same
99 |         if elemental_weights is not None:
100 |             if not self.pairwise_elements:
101 |                 Z_ij, Z_ik = elemental_weights
102 |                 Z_ijk = Z_ij * Z_ik
103 |                 angular_distribution = (
104 |                     torch.unsqueeze(angular_distribution, -1)
105 |                     * torch.unsqueeze(Z_ijk, -2).float()
106 |                 )
107 |             else:
108 |                 # Outer product to emulate vanilla SF behavior
109 |                 Z_ij, Z_ik = elemental_weights
110 |                 B, A, N, E = Z_ij.size()
111 |                 pair_elements = Z_ij[:, :, :, :, None] * Z_ik[:, :, :, None, :]
112 |                 pair_elements = pair_elements + pair_elements.permute(0, 1, 2, 4, 3)
113 |                 # Filter out lower triangular components
114 |                 pair_filter = torch.triu(torch.ones(E, E)) == 1
115 |                 pair_elements = pair_elements[:, :, :, pair_filter]
116 |                 angular_distribution = torch.unsqueeze(
117 |                     angular_distribution, -1
118 |                 ) * torch.unsqueeze(pair_elements, -2)
119 | 
120 |         # Dimension is (Nb x Nat x Nneighpair x Nrad) for angular_distribution and
121 |         # (Nb x Nat x NNeigpair x Nang) for angular_term, where the latter dims are orthogonal
122 |         # To multiply them:
123 |         angular_distribution = (
124 |             angular_distribution[:, :, :, :, None, :]
125 |             * angular_term[:, :, :, None, :, None]
126 |         )
127 |         # For the sum over all contributions
128 |         angular_distribution = torch.sum(angular_distribution, 2)
129 |         # Finally, we flatten the last two dimensions
130 |         angular_distribution = angular_distribution.view(nbatch, natoms, -1)
131 | 
132 |         return angular_distribution
133 | 
134 | 
135 | class BehlerAngular(nn.Module):
136 |     """
137 |     Compute Behler type angular contribution of the angle spanned by three atoms:
138 | 
139 |     :math:`2^{(1-\zeta)} (1 + \lambda \cos( {\\theta}_{ijk} ) )^\zeta`
140 | 
141 |     Sets of zetas with lambdas of -1 and +1 are generated automatically.
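
    Example (doctest-style sketch; uses the default zetas={1}, for which the
    two filters reduce to (1 - cos) and (1 + cos)):

        >>> import torch
        >>> BehlerAngular()(torch.tensor([0.5])).tolist()
        [[0.5, 1.5]]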
142 | 
143 |     Args:
144 |         zetas (set of int): Set of exponents used to compute angular Behler term (default={1})
145 | 
146 |     """
147 | 
148 |     def __init__(self, zetas={1}):
149 |         super(BehlerAngular, self).__init__()
150 |         self.zetas = zetas
151 | 
152 |     def forward(self, cos_theta):
153 |         """
154 |         Args:
155 |             cos_theta (torch.Tensor): Cosines between all pairs of neighbors of the central atom.
156 | 
157 |         Returns:
158 |             torch.Tensor: Tensor containing values of the angular filters.
159 |         """
160 |         angular_pos = [
161 |             2 ** (1 - zeta) * ((1.0 - cos_theta) ** zeta).unsqueeze(-1)
162 |             for zeta in self.zetas
163 |         ]
164 |         angular_neg = [
165 |             2 ** (1 - zeta) * ((1.0 + cos_theta) ** zeta).unsqueeze(-1)
166 |             for zeta in self.zetas
167 |         ]
168 |         angular_all = angular_pos + angular_neg
169 |         return torch.cat(angular_all, -1)
170 | 
171 | 
172 | def gaussian_smearing(distances, offset, widths, centered=False):
173 |     r"""Smear interatomic distance values using Gaussian functions.
174 | 
175 |     Args:
176 |         distances (torch.Tensor): interatomic distances of (N_b x N_at x N_nbh) shape.
177 |         offset (torch.Tensor): offset values of Gaussian functions.
178 |         widths: width values of Gaussian functions.
179 |         centered (bool, optional): If True, Gaussians are centered at the origin and
180 |             the offsets are used as their widths (used e.g. for angular functions).
181 | 
182 |     Returns:
183 |         torch.Tensor: smeared distances (N_b x N_at x N_nbh x N_g).
184 | 
185 |     """
186 |     if not centered:
187 |         # compute width of Gaussian functions (using an overlap of 1 STDDEV)
188 |         coeff = -0.5 / torch.pow(widths, 2)
189 |         # Use advanced indexing to compute the individual components
190 |         diff = distances[:, :, :, None] - offset[None, None, None, :]
191 |     else:
192 |         # if Gaussian functions are centered, use offsets to compute widths
193 |         coeff = -0.5 / torch.pow(offset, 2)
194 |         # if Gaussian functions are centered, no offset is subtracted
195 |         diff = distances[:, :, :, None]
196 |     # compute smear distance values
197 |     gauss = torch.exp(coeff * torch.pow(diff, 2))
198 |     return gauss
199 | 
200 | 
201 | class GaussianSmearing(nn.Module):
202 |     r"""Smear layer using a set of Gaussian functions.
203 | 
204 |     Args:
205 |         start (float, optional): center of first Gaussian function, :math:`\mu_0`.
206 |         stop (float, optional): center of last Gaussian function, :math:`\mu_{N_g}`.
207 |         n_gaussians (int, optional): total number of Gaussian functions, :math:`N_g`.
208 |         centered (bool, optional): If True, Gaussians are centered at the origin and
209 |             the offsets are used as their widths (used e.g. for angular functions).
210 |         trainable (bool, optional): If True, widths and offset of Gaussian functions
211 |             are adjusted during training process.
212 | 
213 |     """
214 | 
215 |     def __init__(
216 |         self, start=0.0, stop=5.0, n_gaussians=50, centered=False, trainable=False
217 |     ):
218 |         super(GaussianSmearing, self).__init__()
219 |         # compute offset and width of Gaussian functions
220 |         offset = torch.linspace(start, stop, n_gaussians)
221 |         widths = torch.FloatTensor((offset[1] - offset[0]) * torch.ones_like(offset))
222 |         if trainable:
223 |             self.width = nn.Parameter(widths)
224 |             self.offsets = nn.Parameter(offset)
225 |         else:
226 |             self.register_buffer("width", widths)
227 |             self.register_buffer("offsets", offset)
228 |         self.centered = centered
229 | 
230 |     def forward(self, distances):
231 |         """Compute smeared-gaussian distance values.
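
        A quick shape check (doctest-style sketch, using the default
        settings, i.e. 50 Gaussians on [0, 5]):

            >>> import torch
            >>> GaussianSmearing()(torch.zeros(1, 1, 1)).shape
            torch.Size([1, 1, 1, 50])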
232 | 
233 |         Args:
234 |             distances (torch.Tensor): interatomic distance values of
235 |                 (N_b x N_at x N_nbh) shape.
236 | 
237 |         Returns:
238 |             torch.Tensor: layer output of (N_b x N_at x N_nbh x N_g) shape.
239 | 
240 |         """
241 |         return gaussian_smearing(
242 |             distances, self.offsets, self.width, centered=self.centered
243 |         )
244 | 
245 | 
246 | class RadialDistribution(nn.Module):
247 |     """
248 |     Radial distribution function used e.g. to compute Behler type radial symmetry functions.
249 | 
250 |     Args:
251 |         radial_filter (callable): Function used to expand distances (e.g. Gaussians)
252 |         cutoff_function (callable): Cutoff function
253 |     """
254 | 
255 |     def __init__(self, radial_filter, cutoff_function=CosineCutoff):
256 |         super(RadialDistribution, self).__init__()
257 |         self.radial_filter = radial_filter
258 |         self.cutoff_function = cutoff_function
259 | 
260 |     def forward(self, r_ij, elemental_weights=None, neighbor_mask=None):
261 |         """
262 |         Args:
263 |             r_ij (torch.Tensor): Interatomic distances
264 |             elemental_weights (torch.Tensor): Element-specific weights for distance functions
265 |             neighbor_mask (torch.Tensor): Mask to identify positions of neighboring atoms
266 | 
267 |         Returns:
268 |             torch.Tensor: Nbatch x Natoms x Nfilter tensor containing radial distribution functions.
269 |         """
270 | 
271 |         nbatch, natoms, nneigh = r_ij.size()
272 | 
273 |         radial_distribution = self.radial_filter(r_ij)
274 | 
275 |         # If requested, apply cutoff function
276 |         if self.cutoff_function is not None:
277 |             cutoffs = self.cutoff_function(r_ij)
278 |             radial_distribution = radial_distribution * cutoffs.unsqueeze(-1)
279 | 
280 |         # Apply neighbor mask
281 |         if neighbor_mask is not None:
282 |             radial_distribution = radial_distribution * torch.unsqueeze(
283 |                 neighbor_mask, -1
284 |             )
285 | 
286 |         # Weigh elements if requested
287 |         if elemental_weights is not None:
288 |             radial_distribution = (
289 |                 radial_distribution[:, :, :, :, None]
290 |                 * elemental_weights[:, :, :, None, :].float()
291 |             )
292 | 
293 |         radial_distribution = torch.sum(radial_distribution, 2)
294 |         radial_distribution = radial_distribution.view(nbatch, natoms, -1)
295 |         return radial_distribution
-------------------------------------------------------------------------------- /baselines/spk_utils/activations.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | from torch.nn import functional
4 | 
5 | 
6 | def shifted_softplus(x):
7 |     r"""Compute shifted soft-plus activation function.
8 | 
9 |     .. math::
10 |        y = \ln\left(1 + e^{x}\right) - \ln(2)
11 | 
12 |     Args:
13 |         x (torch.Tensor): input tensor.
14 | 
15 |     Returns:
16 |         torch.Tensor: shifted soft-plus of input.
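
        For example, shifted_softplus(0) = ln(1 + e^0) - ln(2) = 0, so the
        activation passes through the origin and approaches x - ln(2) for
        large positive x.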
17 | 
18 |     """
19 |     return functional.softplus(x) - np.log(2.0)
20 | 
21 | class ShiftedSoftplus(nn.Module):
22 |     def __init__(self):
23 |         super(ShiftedSoftplus, self).__init__()
24 |         self.shift = float(np.log(2.0))  # ln(2); the bare `torch` name is not imported in this module
25 | 
26 |     def forward(self, x):
27 |         return functional.softplus(x) - self.shift
28 | 
-------------------------------------------------------------------------------- /baselines/spk_utils/base.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.init import xavier_uniform_
4 | 
5 | from .initializers import zeros_initializer
6 | 
7 | 
8 | __all__ = ["Dense", "GetItem", "ScaleShift", "Standardize", "Aggregate"]
9 | 
10 | 
11 | class Dense(nn.Linear):
12 |     r"""Fully connected linear layer with activation function.
13 | 
14 |     .. math::
15 |        y = activation(xW^T + b)
16 | 
17 |     Args:
18 |         in_features (int): number of input features :math:`x`.
19 |         out_features (int): number of output features :math:`y`.
20 |         bias (bool, optional): if False, the layer will not adapt bias :math:`b`.
21 |         activation (callable, optional): if None, no activation function is used.
22 |         weight_init (callable, optional): weight initializer from current weight.
23 |         bias_init (callable, optional): bias initializer from current bias.
24 | 
25 |     """
26 | 
27 |     def __init__(
28 |         self,
29 |         in_features,
30 |         out_features,
31 |         bias=True,
32 |         activation=None,
33 |         weight_init=xavier_uniform_,
34 |         bias_init=zeros_initializer,
35 |     ):
36 |         self.weight_init = weight_init
37 |         self.bias_init = bias_init
38 |         self.activation = activation
39 |         # initialize linear layer y = xW^T + b
40 |         super(Dense, self).__init__(in_features, out_features, bias)
41 | 
42 |     def reset_parameters(self):
43 |         """Reinitialize model weight and bias values."""
44 |         self.weight_init(self.weight)
45 |         if self.bias is not None:
46 |             self.bias_init(self.bias)
47 | 
48 |     def forward(self, inputs):
49 |         """Compute layer output.
50 | 
51 |         Args:
52 |             inputs (torch.Tensor): batch of input values.
53 | 
54 |         Returns:
55 |             torch.Tensor: layer output.
56 | 
57 |         """
58 |         # compute linear layer y = xW^T + b
59 |         y = super(Dense, self).forward(inputs)
60 |         # add activation function
61 |         if self.activation:
62 |             y = self.activation(y)
63 |         return y
64 | 
65 | 
66 | class GetItem(nn.Module):
67 |     """Extraction layer to get an item from SchNetPack dictionary of input tensors.
68 | 
69 |     Args:
70 |         key (str): Property to be extracted from SchNetPack input tensors.
71 | 
72 |     """
73 | 
74 |     def __init__(self, key):
75 |         super(GetItem, self).__init__()
76 |         self.key = key
77 | 
78 |     def forward(self, inputs):
79 |         """Compute layer output.
80 | 
81 |         Args:
82 |             inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors.
83 | 
84 |         Returns:
85 |             torch.Tensor: layer output.
86 | 
87 |         """
88 |         return inputs[self.key]
89 | 
90 | 
91 | class ScaleShift(nn.Module):
92 |     r"""Scale and shift layer for standardization.
93 | 
94 |     .. math::
95 |        y = x \times \sigma + \mu
96 | 
97 |     Args:
98 |         mean (torch.Tensor): mean value :math:`\mu`.
99 |         stddev (torch.Tensor): standard deviation value :math:`\sigma`.
100 | 
101 |     """
102 | 
103 |     def __init__(self, mean, stddev):
104 |         super(ScaleShift, self).__init__()
105 |         self.register_buffer("mean", mean)
106 |         self.register_buffer("stddev", stddev)
107 | 
108 |     def forward(self, input):
109 |         """Compute layer output.
110 | 
111 |         Args:
112 |             input (torch.Tensor): input data.
113 | 
114 |         Returns:
115 |             torch.Tensor: layer output.
116 | 
117 |         """
118 |         y = input * self.stddev + self.mean
119 |         return y
120 | 
121 | 
122 | class Standardize(nn.Module):
123 |     r"""Standardize layer for shifting and scaling.
124 | 
125 |     .. math::
126 |        y = \frac{x - \mu}{\sigma}
127 | 
128 |     Args:
129 |         mean (torch.Tensor): mean value :math:`\mu`.
130 |         stddev (torch.Tensor): standard deviation value :math:`\sigma`.
131 |         eps (float, optional): small offset value to avoid zero division.
132 | 
133 |     """
134 | 
135 |     def __init__(self, mean, stddev, eps=1e-9):
136 |         super(Standardize, self).__init__()
137 |         self.register_buffer("mean", mean)
138 |         self.register_buffer("stddev", stddev)
139 |         self.register_buffer("eps", torch.ones_like(stddev) * eps)
140 | 
141 |     def forward(self, input):
142 |         """Compute layer output.
143 | 
144 |         Args:
145 |             input (torch.Tensor): input data.
146 | 
147 |         Returns:
148 |             torch.Tensor: layer output.
149 | 
150 |         """
151 |         # Add small number to catch divide by zero
152 |         y = (input - self.mean) / (self.stddev + self.eps)
153 |         return y
154 | 
155 | 
156 | class Aggregate(nn.Module):
157 |     """Pooling layer based on sum or average with optional masking.
158 | 
159 |     Args:
160 |         axis (int): axis along which pooling is done.
161 |         mean (bool, optional): if True, use average instead of sum pooling.
162 |         keepdim (bool, optional): whether the output tensor has dim retained or not.
163 | 
164 |     """
165 | 
166 |     def __init__(self, axis, mean=False, keepdim=True):
167 |         super(Aggregate, self).__init__()
168 |         self.average = mean
169 |         self.axis = axis
170 |         self.keepdim = keepdim
171 | 
172 |     def forward(self, input, mask=None):
173 |         r"""Compute layer output.
174 | 
175 |         Args:
176 |             input (torch.Tensor): input data.
177 |             mask (torch.Tensor, optional): mask to be applied; e.g. neighbors mask.
178 | 
179 |         Returns:
180 |             torch.Tensor: layer output.
181 | 
182 |         """
183 |         # mask input
184 |         if mask is not None:
185 |             input = input * mask[..., None]
186 |         # compute sum of input along axis
187 |         y = torch.sum(input, self.axis)
188 |         # compute average of input along axis
189 |         if self.average:
190 |             # get the number of items along axis
191 |             if mask is not None:
192 |                 N = torch.sum(mask, self.axis, keepdim=self.keepdim)
193 |                 N = torch.max(N, other=torch.ones_like(N))
194 |             else:
195 |                 N = input.size(self.axis)
196 |             y = y / N
197 |         return y
198 | 
199 | 
200 | class MaxAggregate(nn.Module):
201 |     """Pooling layer that computes the maximum for each feature over all atoms.
202 | 
203 |     Args:
204 |         axis (int): axis along which pooling is done.
205 |     """
206 | 
207 |     def __init__(self, axis):
208 |         super().__init__()
209 |         self.axis = axis
210 | 
211 |     def forward(self, input, mask=None):
212 |         r"""Compute layer output.
213 | 
214 |         Args:
215 |             input (torch.Tensor): input data.
216 |             mask (torch.Tensor, optional): mask to be applied; e.g. neighbors mask.
217 | 
218 |         Returns:
219 |             torch.Tensor: layer output.
220 |         """
221 |         # mask input
222 |         if mask is not None:
223 |             # If the mask is lower dimensional than the array being masked,
224 |             # inject an extra dimension to the end
225 |             if mask.dim() < input.dim():
226 |                 mask = torch.unsqueeze(mask, -1)
227 |             input = torch.where(mask > 0, input, torch.min(input))
228 | 
229 |         # compute maximum of input along axis
230 |         return torch.max(input, self.axis)[0]
231 | 
232 | 
233 | class SoftmaxAggregate(nn.Module):
234 |     """Pooling layer that computes the maximum for each feature over all atoms
235 |     using the "softmax" function to weigh the contribution of each atom to
236 |     the "maximum."
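
    Concretely, along the pooling axis this computes
    out = sum_i softmax(x)_i * x_i,
    i.e. a smooth, differentiable relaxation of max pooling (a restatement
    of the code below, not a change to it).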
237 | 
238 |     Args:
239 |         axis (int): axis along which pooling is done.
240 |     """
241 | 
242 |     def __init__(self, axis):
243 |         super().__init__()
244 |         self.axis = axis
245 | 
246 |     def forward(self, input, mask=None):
247 |         r"""Compute layer output.
248 | 
249 |         Args:
250 |             input (torch.Tensor): input data.
251 |             mask (torch.Tensor, optional): mask to be applied; e.g. neighbors mask.
252 | 
253 |         Returns:
254 |             torch.Tensor: layer output.
255 |         """
256 | 
257 |         # Compute the exponentials of the inputs
258 |         exp_input = torch.exp(input)
259 | 
260 |         # Set the contributions of "masked" atoms to zero
261 |         if mask is not None:
262 |             # If the mask is lower dimensional than the array being masked,
263 |             # inject an extra dimension to the end
264 |             if mask.dim() < input.dim():
265 |                 mask = torch.unsqueeze(mask, -1)
266 |             exp_input = torch.where(mask > 0, exp_input, torch.zeros_like(exp_input))
267 | 
268 |         # Sum exponentials along the desired axis
269 |         exp_input_sum = torch.sum(exp_input, self.axis, keepdim=True)
270 | 
271 |         # Normalize the exponentials by their sum to obtain softmax weights
272 |         weights = exp_input / exp_input_sum
273 | 
274 |         # compute the weighted sum of the input along the axis
275 |         output = torch.sum(input * weights, self.axis)
276 |         return output
-------------------------------------------------------------------------------- /baselines/spk_utils/blocks.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from . import shifted_softplus, Dense
5 | 
6 | __all__ = ["MLP", "TiledMultiLayerNN", "ElementalGate", "GatedNetwork"]
7 | 
8 | 
9 | class MLP(nn.Module):
10 |     """Multiple layer fully connected perceptron neural network.
11 | 
12 |     Args:
13 |         n_in (int): number of input nodes.
14 |         n_out (int): number of output nodes.
15 |         n_hidden (list of int or int, optional): number of hidden layer nodes.
16 |             If an integer, the same number of nodes is used for all hidden layers, resulting
17 |             in a rectangular network.
18 |             If None, the number of neurons is divided by two after each layer, starting
19 |             from n_in, resulting in a pyramidal network.
20 |         n_layers (int, optional): number of layers.
21 |         activation (callable, optional): activation function. All hidden layers use
22 |             the same activation function; the output layer does not apply
23 |             any activation function.
24 | 
25 |     """
26 | 
27 |     def __init__(
28 |         self, n_in, n_out, n_hidden=None, n_layers=2, activation=shifted_softplus
29 |     ):
30 |         super(MLP, self).__init__()
31 |         # get list of number of nodes in input, hidden & output layers
32 |         if n_hidden is None:
33 |             c_neurons = n_in
34 |             self.n_neurons = []
35 |             for i in range(n_layers):
36 |                 self.n_neurons.append(c_neurons)
37 |                 c_neurons = max(n_out, c_neurons // 2)
38 |             self.n_neurons.append(n_out)
39 |         else:
40 |             # get list of number of nodes in hidden layers
41 |             if type(n_hidden) is int:
42 |                 n_hidden = [n_hidden] * (n_layers - 1)
43 |             self.n_neurons = [n_in] + n_hidden + [n_out]
44 | 
45 |         # assign a Dense layer (with activation function) to each hidden layer
46 |         layers = [
47 |             Dense(self.n_neurons[i], self.n_neurons[i + 1], activation=activation)
48 |             for i in range(n_layers - 1)
49 |         ]
50 |         # assign a Dense layer (without activation function) to the output layer
51 |         layers.append(Dense(self.n_neurons[-2], self.n_neurons[-1], activation=None))
52 |         # put all layers together to make the network
53 |         self.out_net = nn.Sequential(*layers)
54 | 
55 |     def forward(self, inputs):
56 |         """Compute neural network output.
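
        For example (doctest-style sketch), the pyramidal default halves the
        layer width after each layer, floored at n_out:

            >>> MLP(n_in=32, n_out=1, n_layers=3).n_neurons
            [32, 16, 8, 1]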
57 | 
58 |         Args:
59 |             inputs (torch.Tensor): network input.
60 | 
61 |         Returns:
62 |             torch.Tensor: network output.
63 | 
64 |         """
65 |         return self.out_net(inputs)
66 | 
67 | 
68 | class TiledMultiLayerNN(nn.Module):
69 |     """
70 |     Tiled multilayer networks which are applied to the input and produce n_tiles different outputs.
71 |     These outputs are then stacked and returned. Used e.g. to construct element-dependent prediction
72 |     networks of the Behler-Parrinello type.
73 | 
74 |     Args:
75 |         n_in (int): number of input nodes
76 |         n_out (int): number of output nodes
77 |         n_tiles (int): number of networks to be tiled
78 |         n_hidden (int): number of nodes in hidden nn (default 50)
79 |         n_layers (int): number of layers (default: 3)
80 |     """
81 | 
82 |     def __init__(
83 |         self, n_in, n_out, n_tiles, n_hidden=50, n_layers=3, activation=shifted_softplus
84 |     ):
85 |         super(TiledMultiLayerNN, self).__init__()
86 |         self.mlps = nn.ModuleList(
87 |             [
88 |                 MLP(
89 |                     n_in,
90 |                     n_out,
91 |                     n_hidden=n_hidden,
92 |                     n_layers=n_layers,
93 |                     activation=activation,
94 |                 )
95 |                 for _ in range(n_tiles)
96 |             ]
97 |         )
98 | 
99 |     def forward(self, inputs):
100 |         """
101 |         Args:
102 |             inputs (torch.Tensor): Network inputs.
103 | 
104 |         Returns:
105 |             torch.Tensor: Tiled network outputs.
106 | 
107 |         """
108 |         return torch.cat([net(inputs) for net in self.mlps], 2)
109 | 
110 | 
111 | class ElementalGate(nn.Module):
112 |     """
113 |     Produces a Nbatch x Natoms x Nelem mask depending on the nuclear charges passed as an argument.
114 |     If onehot is set, the mask is a one-hot mask; otherwise a random embedding is used.
115 |     If the trainable flag is set to true, the gate values can be adapted during training.
116 | 
117 |     Args:
118 |         elements (set of int): Set of atomic numbers present in the data
119 |         onehot (bool): Use one-hot encoding for the elemental gate. If set to False, a random embedding is used instead.
120 |         trainable (bool): If set to true, gate can be learned during training (default False)
121 |     """
122 | 
123 |     def __init__(self, elements, onehot=True, trainable=False):
124 |         super(ElementalGate, self).__init__()
125 |         self.trainable = trainable
126 | 
127 |         # Get the number of elements, as well as the highest nuclear charge to use in the embedding vector
128 |         self.nelems = len(elements)
129 |         maxelem = int(max(elements) + 1)
130 | 
131 |         self.gate = nn.Embedding(maxelem, self.nelems)
132 | 
133 |         # if requested, initialize as one hot gate for all elements
134 |         if onehot:
135 |             weights = torch.zeros(maxelem, self.nelems)
136 |             for idx, Z in enumerate(elements):
137 |                 weights[Z, idx] = 1.0
138 |             self.gate.weight.data = weights
139 | 
140 |         # Set trainable flag
141 |         if not trainable:
142 |             self.gate.weight.requires_grad = False
143 | 
144 |     def forward(self, atomic_numbers):
145 |         """
146 |         Args:
147 |             atomic_numbers (torch.Tensor): Tensor containing atomic numbers of each atom.
148 | 
149 |         Returns:
150 |             torch.Tensor: One-hot vector which is one at the position of the element and zero otherwise.
151 | 
152 |         """
153 |         return self.gate(atomic_numbers)
154 | 
155 | 
156 | class GatedNetwork(nn.Module):
157 |     """
158 |     Combines the TiledMultiLayerNN with the elemental gate to obtain element-specific atomistic networks as in typical
159 |     Behler--Parrinello networks [#behler1]_.
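
    In effect, each atom's output comes from the network tile matching its
    element: the one-hot gate zeroes out every other tile before the final
    sum in forward() (a summary of the code below, not a change to it).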
160 | 
161 |     Args:
162 |         nin (int): number of input nodes
163 |         nout (int): number of output nodes
164 |         n_hidden (int): number of nodes in hidden layers (default 50)
165 |         n_layers (int): number of layers (default 3)
166 |         elements (set of ints): Set of atomic numbers present in the data
167 |         onehot (bool): Use one-hot encoding for the elemental gate. If set to False, a random embedding is used instead.
168 |         trainable (bool): If set to true, gate can be learned during training (default False)
169 |         activation (callable): activation function
170 | 
171 |     References
172 |     ----------
173 |     .. [#behler1] Behler, Parrinello:
174 |        Generalized Neural-Network Representation of High-Dimensional Potential-Energy Surfaces.
175 |        Phys. Rev. Lett. 98, 146401. 2007.
176 | 
177 |     """
178 | 
179 |     def __init__(
180 |         self,
181 |         nin,
182 |         nout,
183 |         elements,
184 |         n_hidden=50,
185 |         n_layers=3,
186 |         trainable=False,
187 |         onehot=True,
188 |         activation=shifted_softplus,
189 |     ):
190 |         super(GatedNetwork, self).__init__()
191 |         self.nelem = len(elements)
192 |         self.gate = ElementalGate(elements, trainable=trainable, onehot=onehot)
193 |         self.network = TiledMultiLayerNN(
194 |             nin,
195 |             nout,
196 |             self.nelem,
197 |             n_hidden=n_hidden,
198 |             n_layers=n_layers,
199 |             activation=activation,
200 |         )
201 | 
202 |     def forward(self, atomic_numbers, representation):
203 |         """
204 |         Args:
205 |             atomic_numbers (torch.Tensor): atomic numbers of each atom.
206 |             representation (torch.Tensor): atom-wise input representation.
207 | 
208 |         Returns:
209 |             torch.Tensor: Output of the gated network.
210 |         """
211 |         # the elemental gate selects the matching network tile for each atom
212 |         gated_network = self.gate(atomic_numbers) * self.network(representation)
213 |         return torch.sum(gated_network, -1, keepdim=True)
-------------------------------------------------------------------------------- /baselines/spk_utils/cfconv.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from . import Dense
5 | from .base import Aggregate
6 | 
7 | 
8 | __all__ = ["CFConv"]
9 | 
10 | 
11 | class CFConv(nn.Module):
12 |     r"""Continuous-filter convolution block used in SchNet module.
13 | 
14 |     Args:
15 |         n_in (int): number of input (i.e. atomic embedding) dimensions.
16 |         n_filters (int): number of filter dimensions.
17 |         n_out (int): number of output dimensions.
18 |         filter_network (nn.Module): filter block.
19 |         cutoff_network (nn.Module, optional): if None, no cutoff function is used.
20 |         activation (callable, optional): if None, no activation function is used.
21 |         normalize_filter (bool, optional): If True, normalize filter to the number
22 |             of neighbors when aggregating.
23 |         axis (int, optional): axis over which convolution should be applied.
24 | 
25 |     """
26 | 
27 |     def __init__(
28 |         self,
29 |         n_in,
30 |         n_filters,
31 |         n_out,
32 |         filter_network,
33 |         cutoff_network=None,
34 |         activation=None,
35 |         normalize_filter=False,
36 |         axis=2,
37 |     ):
38 |         super(CFConv, self).__init__()
39 |         self.in2f = Dense(n_in, n_filters, bias=False, activation=None)
40 |         self.f2out = Dense(n_filters, n_out, bias=True, activation=activation)
41 |         self.filter_network = filter_network
42 |         self.cutoff_network = cutoff_network
43 |         self.agg = Aggregate(axis=axis, mean=normalize_filter)
44 | 
45 |     def forward(self, x, r_ij, neighbors, pairwise_mask, f_ij=None):
46 |         """Compute convolution block.
47 | 
48 |         Args:
49 |             x (torch.Tensor): input representation/embedding of atomic environments
50 |                 with (N_b, N_a, n_in) shape.
51 |             r_ij (torch.Tensor): interatomic distances of (N_b, N_a, N_nbh) shape.
52 |             neighbors (torch.Tensor): indices of neighbors of (N_b, N_a, N_nbh) shape.
53 |             pairwise_mask (torch.Tensor): mask to filter out non-existing neighbors
54 |                 introduced via padding.
55 |             f_ij (torch.Tensor, optional): expanded interatomic distances in a basis.
56 |                 If None, r_ij.unsqueeze(-1) is used.
57 | 
58 |         Returns:
59 |             torch.Tensor: block output with (N_b, N_a, n_out) shape.
60 | 
61 |         """
62 |         if f_ij is None:
63 |             f_ij = r_ij.unsqueeze(-1)
64 | 
65 |         # pass expanded interatomic distances through filter block
66 |         W = self.filter_network(f_ij)
67 |         # apply cutoff
68 |         if self.cutoff_network is not None:
69 |             C = self.cutoff_network(r_ij)
70 |             W = W * C.unsqueeze(-1)
71 | 
72 |         # pass initial embeddings through Dense layer
73 |         y = self.in2f(x)
74 |         # reshape y for element-wise multiplication by W
75 |         nbh_size = neighbors.size()
76 |         nbh = neighbors.view(-1, nbh_size[1] * nbh_size[2], 1)
77 |         nbh = nbh.expand(-1, -1, y.size(2))
78 |         y = torch.gather(y, 1, nbh)
79 |         y = y.view(nbh_size[0], nbh_size[1], nbh_size[2], -1)
80 | 
81 |         # element-wise multiplication, aggregating and Dense layer
82 |         y = y * W
83 |         y = self.agg(y, pairwise_mask)
84 |         y = self.f2out(y)
85 |         return y
-------------------------------------------------------------------------------- /baselines/spk_utils/cutoff.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | 
5 | 
6 | __all__ = ["CosineCutoff", "MollifierCutoff", "HardCutoff", "get_cutoff_by_string"]
7 | 
8 | 
9 | def get_cutoff_by_string(key):
10 |     # build cutoff module
11 |     if key == "hard":
12 |         cutoff_network = HardCutoff
13 |     elif key == "cosine":
14 |         cutoff_network = CosineCutoff
15 |     elif key == "mollifier":
16 |         cutoff_network = MollifierCutoff
17 |     else:
18 |         raise NotImplementedError("cutoff_function {} is unknown".format(key))
19 |     return cutoff_network
20 | 
21 | 
22 | class CosineCutoff(nn.Module):
23 |     r"""Class of Behler cosine cutoff.
24 | 
25 |     .. math::
26 |        f(r) = \begin{cases}
27 |         0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
28 |           & r < r_\text{cutoff} \\
29 |         0 & r \geqslant r_\text{cutoff} \\
30 |         \end{cases}
31 | 
32 |     Args:
33 |         cutoff (float, optional): cutoff radius.
34 | 
35 |     """
36 | 
37 |     def __init__(self, cutoff=5.0):
38 |         super(CosineCutoff, self).__init__()
39 |         self.register_buffer("cutoff", torch.FloatTensor([cutoff]))
40 | 
41 |     def forward(self, distances):
42 |         """Compute cutoff.
43 | 
44 |         Args:
45 |             distances (torch.Tensor): values of interatomic distances.
46 | 
47 |         Returns:
48 |             torch.Tensor: values of cutoff function.
49 | 
50 |         """
51 |         # Compute values of cutoff function
52 |         cutoffs = 0.5 * (torch.cos(distances * np.pi / self.cutoff) + 1.0)
53 |         # Remove contributions beyond the cutoff radius
54 |         cutoffs *= (distances < self.cutoff).float()
55 |         return cutoffs
56 | 
57 | 
58 | class MollifierCutoff(nn.Module):
59 |     r"""Class for mollifier cutoff scaled to have a value of 1 at :math:`r=0`.
60 | 
61 |     .. math::
62 |        f(r) = \begin{cases}
63 |         \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
64 |           & r < r_\text{cutoff} \\
65 |         0 & r \geqslant r_\text{cutoff} \\
66 |         \end{cases}
67 | 
68 |     Args:
69 |         cutoff (float, optional): Cutoff radius.
70 |         eps (float, optional): offset added to distances for numerical stability.
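
    Example (doctest-style sketch; as the docstring above notes, the
    mollifier equals 1 at r = 0):

        >>> import torch
        >>> float(MollifierCutoff(cutoff=5.0)(torch.tensor([0.0])))
        1.0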
71 | 72 | """ 73 | 74 | def __init__(self, cutoff=5.0, eps=1.0e-7): 75 | super(MollifierCutoff, self).__init__() 76 | self.register_buffer("cutoff", torch.FloatTensor([cutoff])) 77 | self.register_buffer("eps", torch.FloatTensor([eps])) 78 | 79 | def forward(self, distances): 80 | """Compute cutoff. 81 | 82 | Args: 83 | distances (torch.Tensor): values of interatomic distances. 84 | 85 | Returns: 86 | torch.Tensor: values of cutoff function. 87 | 88 | """ 89 | mask = (distances + self.eps < self.cutoff).float() 90 | exponent = 1.0 - 1.0 / (1.0 - torch.pow(distances * mask / self.cutoff, 2)) 91 | cutoffs = torch.exp(exponent) 92 | cutoffs = cutoffs * mask 93 | return cutoffs 94 | 95 | 96 | class HardCutoff(nn.Module): 97 | r"""Class of hard cutoff. 98 | 99 | .. math:: 100 | f(r) = \begin{cases} 101 | 1 & r \leqslant r_\text{cutoff} \\ 102 | 0 & r > r_\text{cutoff} \\ 103 | \end{cases} 104 | 105 | Args: 106 | cutoff (float): cutoff radius. 107 | 108 | """ 109 | 110 | def __init__(self, cutoff=5.0): 111 | super(HardCutoff, self).__init__() 112 | self.register_buffer("cutoff", torch.FloatTensor([cutoff])) 113 | 114 | def forward(self, distances): 115 | """Compute cutoff. 116 | 117 | Args: 118 | distances (torch.Tensor): values of interatomic distances. 119 | 120 | Returns: 121 | torch.Tensor: values of cutoff function. 122 | 123 | """ 124 | mask = (distances <= self.cutoff).float() 125 | return mask 126 | -------------------------------------------------------------------------------- /baselines/spk_utils/initializers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from torch.nn.init import constant_ 4 | 5 | zeros_initializer = partial(constant_, val=0.0) 6 | -------------------------------------------------------------------------------- /baselines/spk_utils/neighbors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def atom_distances( 6 | positions, 7 | neighbors, 8 | cell=None, 9 | cell_offsets=None, 10 | return_vecs=False, 11 | normalize_vecs=False, 12 | neighbor_mask=None, 13 | ): 14 | r"""Compute distance of every atom to its neighbors. 15 | 16 | This function uses advanced torch indexing to compute differentiable distances 17 | of every central atom to its relevant neighbors. 18 | 19 | Args: 20 | positions (torch.Tensor): 21 | atomic Cartesian coordinates with (N_b x N_at x 3) shape 22 | neighbors (torch.Tensor): 23 | indices of neighboring atoms to consider with (N_b x N_at x N_nbh) shape 24 | cell (torch.tensor, optional): 25 | periodic cell of (N_b x 3 x 3) shape 26 | cell_offsets (torch.Tensor, optional) : 27 | offset of atom in cell coordinates with (N_b x N_at x N_nbh x 3) shape 28 | return_vecs (bool, optional): if True, also returns direction vectors. 29 | normalize_vecs (bool, optional): if True, normalize direction vectors. 30 | neighbor_mask (torch.Tensor, optional): boolean mask for neighbor positions. 31 | 32 | Returns: 33 | (torch.Tensor, torch.Tensor): 34 | distances: 35 | distance of every atom to its neighbors with 36 | (N_b x N_at x N_nbh) shape. 37 | 38 | dist_vec: 39 | direction cosines of every atom to its 40 | neighbors with (N_b x N_at x N_nbh x 3) shape (optional). 
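        Example (a small self-contained check: one batch of three collinear
        atoms, each listing the other two as neighbors)::

            import torch

            positions = torch.tensor([[[0.0, 0.0, 0.0],
                                       [1.0, 0.0, 0.0],
                                       [3.0, 0.0, 0.0]]])          # (1, 3, 3)
            neighbors = torch.tensor([[[1, 2], [0, 2], [0, 1]]])   # (1, 3, 2)
            print(atom_distances(positions, neighbors))
            # tensor([[[1., 3.], [1., 2.], [3., 2.]]])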
41 | 42 | """ 43 | 44 | # Construct auxiliary index vector 45 | n_batch = positions.size()[0] 46 | idx_m = torch.arange(n_batch, device=positions.device, dtype=torch.long)[ 47 | :, None, None 48 | ] 49 | # Get atomic positions of all neighboring indices 50 | pos_xyz = positions[idx_m, neighbors[:, :, :], :] 51 | 52 | # Subtract positions of central atoms to get distance vectors 53 | dist_vec = pos_xyz - positions[:, :, None, :] 54 | 55 | # add cell offset 56 | if cell is not None: 57 | B, A, N, D = cell_offsets.size() 58 | cell_offsets = cell_offsets.view(B, A * N, D) 59 | offsets = cell_offsets.bmm(cell) 60 | offsets = offsets.view(B, A, N, D) 61 | dist_vec += offsets 62 | 63 | # Compute vector lengths 64 | distances = torch.norm(dist_vec, 2, 3) 65 | 66 | if neighbor_mask is not None: 67 | # Avoid problems with zero distances in forces (instability of square 68 | # root derivative at 0) This way is neccessary, as gradients do not 69 | # work with inplace operations, such as e.g. 70 | # -> distances[mask==0] = 0.0 71 | tmp_distances = torch.zeros_like(distances) 72 | tmp_distances[neighbor_mask != 0] = distances[neighbor_mask != 0] 73 | distances = tmp_distances 74 | 75 | if return_vecs: 76 | tmp_distances = torch.ones_like(distances) 77 | tmp_distances[neighbor_mask != 0] = distances[neighbor_mask != 0] 78 | 79 | if normalize_vecs: 80 | dist_vec = dist_vec / tmp_distances[:, :, :, None] 81 | return distances, dist_vec 82 | return distances 83 | 84 | 85 | class AtomDistances(nn.Module): 86 | r"""Layer for computing distance of every atom to its neighbors. 87 | 88 | Args: 89 | return_directions (bool, optional): if True, the `forward` method also returns 90 | normalized direction vectors. 91 | 92 | """ 93 | 94 | def __init__(self, return_directions=False): 95 | super(AtomDistances, self).__init__() 96 | self.return_directions = return_directions 97 | 98 | def forward( 99 | self, positions, neighbors, cell=None, cell_offsets=None, neighbor_mask=None 100 | ): 101 | r"""Compute distance of every atom to its neighbors. 102 | 103 | Args: 104 | positions (torch.Tensor): atomic Cartesian coordinates with 105 | (N_b x N_at x 3) shape. 106 | neighbors (torch.Tensor): indices of neighboring atoms to consider 107 | with (N_b x N_at x N_nbh) shape. 108 | cell (torch.tensor, optional): periodic cell of (N_b x 3 x 3) shape. 109 | cell_offsets (torch.Tensor, optional): offset of atom in cell coordinates 110 | with (N_b x N_at x N_nbh x 3) shape. 111 | neighbor_mask (torch.Tensor, optional): boolean mask for neighbor 112 | positions. Required for the stable computation of forces in 113 | molecules with different sizes. 114 | 115 | Returns: 116 | torch.Tensor: layer output of (N_b x N_at x N_nbh) shape. 117 | 118 | """ 119 | return atom_distances( 120 | positions, 121 | neighbors, 122 | cell, 123 | cell_offsets, 124 | return_vecs=self.return_directions, 125 | normalize_vecs=True, 126 | neighbor_mask=neighbor_mask, 127 | ) 128 | 129 | 130 | def triple_distances( 131 | positions, 132 | neighbors_j, 133 | neighbors_k, 134 | offset_idx_j=None, 135 | offset_idx_k=None, 136 | cell=None, 137 | cell_offsets=None, 138 | ): 139 | """ 140 | Get all distances between atoms forming a triangle with the central atoms. 141 | Required e.g. for angular symmetry functions. 
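    Example (an illustrative call on one right triangle of atoms; the
    neighbor index pairs below are hypothetical)::

        import torch

        positions = torch.tensor([[[0.0, 0.0, 0.0],
                                   [1.0, 0.0, 0.0],
                                   [0.0, 1.0, 0.0]]])   # (1, 3, 3)
        # for each central atom, one (j, k) pair completing the triangle
        neighbors_j = torch.tensor([[[1], [0], [0]]])
        neighbors_k = torch.tensor([[[2], [2], [1]]])
        r_ij, r_ik, r_jk = triple_distances(positions, neighbors_j, neighbors_k)
        # r_ij ~ [[1.], [1.], [1.]], r_ik ~ [[1.], [1.414], [1.414]],
        # r_jk ~ [[1.414], [1.], [1.]]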
142 | 143 | Args: 144 | positions (torch.Tensor): Atomic positions 145 | neighbors_j (torch.Tensor): Indices of first neighbor in triangle 146 | neighbors_k (torch.Tensor): Indices of second neighbor in triangle 147 | offset_idx_j (torch.Tensor): Indices for offsets of neighbors j (for PBC) 148 | offset_idx_k (torch.Tensor): Indices for offsets of neighbors k (for PBC) 149 | cell (torch.tensor, optional): periodic cell of (N_b x 3 x 3) shape. 150 | cell_offsets (torch.Tensor, optional): offset of atom in cell coordinates 151 | with (N_b x N_at x N_nbh x 3) shape. 152 | 153 | Returns: 154 | torch.Tensor: Distance between central atom and neighbor j 155 | torch.Tensor: Distance between central atom and neighbor k 156 | torch.Tensor: Distance between neighbors 157 | 158 | """ 159 | nbatch, _, _ = neighbors_k.size() 160 | idx_m = torch.arange(nbatch, device=positions.device, dtype=torch.long)[ 161 | :, None, None 162 | ] 163 | 164 | pos_j = positions[idx_m, neighbors_j[:], :] 165 | pos_k = positions[idx_m, neighbors_k[:], :] 166 | 167 | if cell is not None: 168 | # Get the offsets into true Cartesian values 169 | B, A, N, D = cell_offsets.size() 170 | 171 | cell_offsets = cell_offsets.view(B, A * N, D) 172 | offsets = cell_offsets.bmm(cell) 173 | offsets = offsets.view(B, A, N, D) 174 | 175 | # Get the offset values for j and k atoms 176 | B, A, T = offset_idx_j.size() 177 | 178 | # Collapse batch and atom positions for easier indexing 179 | offset_idx_j = offset_idx_j.view(B * A, T) 180 | offset_idx_k = offset_idx_k.view(B * A, T) 181 | offsets = offsets.view(B * A, -1, D) 182 | 183 | # Construct auxiliary array for advanced indexing 184 | idx_offset_m = torch.arange(B * A, device=positions.device, dtype=torch.long)[ 185 | :, None 186 | ] 187 | 188 | # Restore proper dimensions 189 | offset_j = offsets[idx_offset_m, offset_idx_j[:]].view(B, A, T, D) 190 | offset_k = offsets[idx_offset_m, offset_idx_k[:]].view(B, A, T, D) 191 | 192 | # Add offsets 193 | pos_j = pos_j + offset_j 194 | pos_k = pos_k + offset_k 195 | 196 | # if positions.is_cuda: 197 | # idx_m = idx_m.pin_memory().cuda(async=True) 198 | 199 | # Get the real positions of j and k 200 | R_ij = pos_j - positions[:, :, None, :] 201 | R_ik = pos_k - positions[:, :, None, :] 202 | R_jk = pos_j - pos_k 203 | 204 | # + 1e-9 to avoid division by zero 205 | r_ij = torch.norm(R_ij, 2, 3) + 1e-9 206 | r_ik = torch.norm(R_ik, 2, 3) + 1e-9 207 | r_jk = torch.norm(R_jk, 2, 3) + 1e-9 208 | 209 | return r_ij, r_ik, r_jk 210 | 211 | 212 | class TriplesDistances(nn.Module): 213 | """ 214 | Layer that gets all distances between atoms forming a triangle with the 215 | central atoms. Required e.g. for angular symmetry functions. 216 | """ 217 | 218 | def __init__(self): 219 | super(TriplesDistances, self).__init__() 220 | 221 | def forward(self, positions, neighbors_j, neighbors_k): 222 | """ 223 | Args: 224 | positions (torch.Tensor): Atomic positions 225 | neighbors_j (torch.Tensor): Indices of first neighbor in triangle 226 | neighbors_k (torch.Tensor): Indices of second neighbor in triangle 227 | 228 | Returns: 229 | torch.Tensor: Distance between central atom and neighbor j 230 | torch.Tensor: Distance between central atom and neighbor k 231 | torch.Tensor: Distance between neighbors 232 | 233 | """ 234 | return triple_distances(positions, neighbors_j, neighbors_k) 235 | 236 | 237 | def neighbor_elements(atomic_numbers, neighbors): 238 | """ 239 | Return the atomic numbers associated with the neighboring atoms.
Can also 240 | be used to gather other properties by neighbors if different atom-wise 241 | Tensor is passed instead of atomic_numbers. 242 | 243 | Args: 244 | atomic_numbers (torch.Tensor): Atomic numbers (Nbatch x Nat x 1) 245 | neighbors (torch.Tensor): Neighbor indices (Nbatch x Nat x Nneigh) 246 | 247 | Returns: 248 | torch.Tensor: Atomic numbers of neighbors (Nbatch x Nat x Nneigh) 249 | 250 | """ 251 | # Get molecules in batch 252 | n_batch = atomic_numbers.size()[0] 253 | # Construct auxiliary index 254 | idx_m = torch.arange(n_batch, device=atomic_numbers.device, dtype=torch.long)[ 255 | :, None, None 256 | ] 257 | # Get neighbors via advanced indexing 258 | neighbor_numbers = atomic_numbers[idx_m, neighbors[:, :, :]] 259 | return neighbor_numbers 260 | 261 | 262 | class NeighborElements(nn.Module): 263 | """ 264 | Layer to obtain the atomic numbers associated with the neighboring atoms. 265 | """ 266 | 267 | def __init__(self): 268 | super(NeighborElements, self).__init__() 269 | 270 | def forward(self, atomic_numbers, neighbors): 271 | """ 272 | Args: 273 | atomic_numbers (torch.Tensor): Atomic numbers (Nbatch x Nat x 1) 274 | neighbors (torch.Tensor): Neighbor indices (Nbatch x Nat x Nneigh) 275 | 276 | Returns: 277 | torch.Tensor: Atomic numbers of neighbors (Nbatch x Nat x Nneigh) 278 | """ 279 | return neighbor_elements(atomic_numbers, neighbors) 280 | -------------------------------------------------------------------------------- /featurization/__pycache__/data_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Matgen-project/MOFNet/47132ace072a8446ef43379a9ab019743ff746e7/featurization/__pycache__/data_utils.cpython-37.pyc -------------------------------------------------------------------------------- /featurization/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from torch.utils.data import Dataset, dataset 9 | import json 10 | import copy 11 | 12 | 13 | FloatTensor = torch.FloatTensor 14 | LongTensor = torch.LongTensor 15 | IntTensor = torch.IntTensor 16 | DoubleTensor = torch.DoubleTensor 17 | 18 | def cutname(ori_name): 19 | ori_name = ori_name[:-3] 20 | if ori_name.endswith('out.'): 21 | ori_name = ori_name[:-4] 22 | elif ori_name.endswith('faps.'): 23 | ori_name = ori_name[:-5] 24 | return ori_name + 'p' 25 | 26 | 27 | def load_data_from_df(dataset_path, gas_type, pressure, add_dummy_node=True, use_global_features=False, return_names=False): 28 | 29 | data_df = pd.read_csv(dataset_path + f'/label_by_GCMC/{gas_type}_ads_all.csv',header=0) 30 | data_x = data_df['name'].values 31 | if pressure == 'all': 32 | data_y = data_df.iloc[:,1:].values 33 | else: 34 | data_y = data_df[pressure].values 35 | 36 | if data_y.dtype == np.float64: 37 | data_y = data_y.astype(np.float32) 38 | 39 | x_all, y_all, name_all = load_data_from_processed(dataset_path, data_x, data_y, add_dummy_node=add_dummy_node) 40 | 41 | if return_names: 42 | x_all = (x_all, name_all) 43 | 44 | if use_global_features: 45 | f_all = load_data_with_global_features(dataset_path, name_all, gas_type) 46 | if pressure == 'all': 47 | return x_all, f_all, y_all, data_df.columns.values[1:] 48 | return x_all, f_all, y_all 49 | 50 | if pressure == 'all': 51 | return x_all, y_all, data_df.columns.values[1:] 52 | return x_all, y_all 53 | 54 | def norm_str(ori): 55 | 
ori = ori.split('.')[0].split('-') 56 | if ori[-1] == 'clean': 57 | ori = ori[:-1] 58 | elif ori[-2] == 'clean': 59 | ori = ori[:-2] 60 | return '-'.join(ori[1:]) 61 | 62 | 63 | def load_real_data(dataset_path, gas_type): 64 | 65 | data_df = pd.read_csv(dataset_path + f'/global_features/exp_geo_all.csv', header=0) 66 | data_x = data_df['name'].values 67 | data_y = data_df.iloc[:,1:].values 68 | global_dic = {} 69 | for x,y in zip(data_x, data_y): 70 | global_dic[x] = y 71 | with open(dataset_path + '/isotherm_data/all.json') as f: 72 | labels = json.load(f)[gas_type]['data'] 73 | label_dict = {_['name']:_["isotherm_data"] for _ in labels} 74 | 75 | with open(dataset_path + f'/isotherm_data/{gas_type}.txt','r') as f: 76 | ls = f.readlines() 77 | ls = [_.strip().split() for _ in ls] 78 | X_all, y_all, f_all, p_all, n_all = [],[],[],[],[] 79 | for l in ls: 80 | if l[0] not in global_dic: 81 | continue 82 | gf = global_dic[l[0]] 83 | afm, adj, dist = pickle.load(open(dataset_path + f'/local_features/{l[0]}.cif.p', "rb")) 84 | afm, adj, dist = add_dummy_node_func(afm, adj, dist) 85 | iso = label_dict[norm_str(l[0])] 86 | p,y = [],[] 87 | for _ in iso: 88 | if _['pressure'] > 0: 89 | p.append(_['pressure']) 90 | y.append(_['adsorption']) 91 | if len(p) == 0: 92 | continue 93 | X_all.append([afm,adj,dist]) 94 | f_all.append(gf) 95 | p_all.append(p) 96 | y_all.append(y) 97 | n_all.append(norm_str(l[0])) 98 | return X_all, f_all, y_all, p_all, n_all 99 | 100 | 101 | 102 | 103 | 104 | def load_data_with_global_features(dataset_path, processed_files, gas_type): 105 | global_feature_path = dataset_path + f'/global_features/{gas_type}_global_features_update.csv' 106 | data_df = pd.read_csv(global_feature_path,header=0) 107 | data_x = data_df.iloc[:, 0].values 108 | data_f = data_df.iloc[:,1:].values.astype(np.float32) 109 | data_dict = {} 110 | for i in range(data_x.shape[0]): 111 | data_dict[data_x[i]] = data_f[i] 112 | f_all = [data_dict[_] for _ in processed_files] 113 | return f_all 114 | 115 | 116 | 117 | def load_data_from_processed(dataset_path, processed_files, labels, add_dummy_node=True): 118 | x_all, y_all, name_all = [], [], [] 119 | 120 | for files, label in zip(processed_files, labels): 121 | 122 | data_file = dataset_path + '/local_features/' + files + '.p' 123 | try: 124 | afm, adj, dist = pickle.load(open(data_file, "rb")) 125 | if add_dummy_node: 126 | afm, adj, dist = add_dummy_node_func(afm, adj, dist) 127 | x_all.append([afm, adj, dist]) 128 | y_all.append([label]) 129 | name_all.append(files) 130 | except: 131 | pass 132 | 133 | return x_all, y_all, name_all 134 | 135 | def add_dummy_node_func(node_features, adj_matrix, dist_matrix): 136 | m = np.zeros((node_features.shape[0] + 1, node_features.shape[1] + 1)) 137 | m[1:, 1:] = node_features 138 | m[0, 0] = 1. 
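Aside: the dummy-node padding assembled here gives every graph an extra "virtual" atom. A minimal sketch of its effect on a toy two-atom graph (the toy values are invented; the shapes and the 1e6 sentinel come from the function itself):

```
import numpy as np

node_features = np.ones((2, 2), dtype=np.float32)   # 2 atoms, 2 features
adj_matrix = np.ones((2, 2), dtype=np.float32)
dist_matrix = np.array([[0.0, 1.5], [1.5, 0.0]], dtype=np.float32)

afm, adj, dist = add_dummy_node_func(node_features, adj_matrix, dist_matrix)
# afm gains a dummy-node row plus a dummy-indicator column (afm[0, 0] = 1),
# adj connects the dummy node to every atom, and dist keeps it effectively
# infinitely far away (1e6).
print(afm.shape, adj.shape, dist.shape)   # (3, 3) (3, 3) (3, 3)
```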
139 | node_features = m 140 | 141 | m = np.ones((adj_matrix.shape[0] + 1, adj_matrix.shape[1] + 1)) 142 | m[1:, 1:] = adj_matrix 143 | adj_matrix = m 144 | 145 | m = np.full((dist_matrix.shape[0] + 1, dist_matrix.shape[1] + 1), 1e6) 146 | m[1:, 1:] = dist_matrix 147 | dist_matrix = m 148 | 149 | return node_features, adj_matrix, dist_matrix 150 | 151 | 152 | class MOF: 153 | def __init__(self, x, y, index, feature = None): 154 | self.node_features = x[0] 155 | self.adjacency_matrix = x[1] 156 | self.distance_matrix = x[2] 157 | self.y = y 158 | self.index = index 159 | self.global_feature = feature 160 | 161 | 162 | class MOFDataset(Dataset): 163 | 164 | def __init__(self, data_list): 165 | self.data_list = data_list 166 | 167 | def __len__(self): 168 | return len(self.data_list) 169 | 170 | def __getitem__(self, key): 171 | if type(key) == slice: 172 | return MOFDataset(self.data_list[key]) 173 | return self.data_list[key] 174 | 175 | 176 | class RealMOFDataset(Dataset): 177 | def __init__(self, data_list, pressure_list, ori_point): 178 | self.data_list = data_list 179 | self.pressure_list = pressure_list 180 | self.ori_point = np.log(np.float32(ori_point)) 181 | def __len__(self): 182 | return len(self.data_list) 183 | def __getitem__(self,key): 184 | if type(key) == slice: 185 | return RealMOFDataset(self.data_list[key], self.pressure_list[key], self.ori_point) 186 | tar_mol = copy.deepcopy(self.data_list[key]) 187 | tar_p = np.log(self.pressure_list[key]) - self.ori_point 188 | tar_mol.global_feature = np.append(tar_mol.global_feature, tar_p) 189 | tar_mol.y = tar_mol.y 190 | return tar_mol 191 | 192 | class MOFDatasetPressureVer(Dataset): 193 | 194 | def __init__(self, data_list, pressure_list, mask_point=None, is_train=True, tar_point=None): 195 | self.data_list = data_list 196 | self.pressure_list = pressure_list 197 | self.mask_point = mask_point 198 | self.is_train = is_train 199 | self.tar_point = tar_point 200 | if is_train: 201 | self.use_idx = np.where(pressure_list != mask_point)[0] 202 | else: 203 | self.use_idx = np.where(pressure_list == tar_point)[0] 204 | self.calcMid() 205 | 206 | def __len__(self): 207 | return len(self.data_list) 208 | 209 | def toStr(self): 210 | return {"data_list":self.data_list,"pressure_list":self.pressure_list,"mask_point":self.mask_point,"is_train":self.is_train, "tar_point":self.tar_point} 211 | def __getitem__(self, key): 212 | if type(key) == slice: 213 | return MOFDataset(self.data_list[key], self.pressure_list, self.mask_point, self.is_train) 214 | tar_mol = copy.deepcopy(self.data_list[key]) 215 | if self.is_train: 216 | tar_p = self.float_pressure - self.mid 217 | tar_mol.global_feature = np.append(tar_mol.global_feature, tar_p) 218 | tar_mol.y = tar_mol.y[0] 219 | else: 220 | tar_idx = self.use_idx 221 | tar_p = self.float_pressure[tar_idx] - self.mid 222 | tar_mol.global_feature = np.append(tar_mol.global_feature, tar_p) 223 | tar_mol.y = [tar_mol.y[0][tar_idx]] 224 | return tar_mol 225 | 226 | def changeTarPoint(self,tar_point): 227 | self.tar_point = tar_point 228 | if not tar_point: 229 | self.is_train = True 230 | else: 231 | self.is_train = False 232 | if not self.is_train: 233 | self.use_idx = np.where(self.pressure_list == tar_point)[0] 234 | 235 | def calcMid(self): 236 | self.float_pressure = np.log(self.pressure_list.astype(np.float)) 237 | self.mid = np.log(np.float(self.mask_point)) 238 | 239 | 240 | def pad_array(array, shape, dtype=np.float32): 241 | padded_array = np.zeros(shape, dtype=dtype) 242 | 
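Aside: the zero-padding defined here is what lets the collate function below stack graphs of different sizes into a single batch; a quick check on toy matrices:

```
import numpy as np

small = np.ones((2, 2), dtype=np.float32)
large = np.ones((4, 4), dtype=np.float32)
max_size = max(small.shape[0], large.shape[0])
batch = [pad_array(m, (max_size, max_size)) for m in (small, large)]
print(batch[0])
# [[1. 1. 0. 0.]
#  [1. 1. 0. 0.]
#  [0. 0. 0. 0.]
#  [0. 0. 0. 0.]]
```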
padded_array[:array.shape[0], :array.shape[1]] = array 243 | return padded_array 244 | 245 | 246 | def mof_collate_func_gf(batch): 247 | adjacency_list, distance_list, features_list, global_features_list = [], [], [], [] 248 | labels = [] 249 | 250 | max_size = 0 251 | for molecule in batch: 252 | if type(molecule.y[0]) == np.ndarray: 253 | labels.append(molecule.y[0]) 254 | else: 255 | labels.append(molecule.y) 256 | if molecule.adjacency_matrix.shape[0] > max_size: 257 | max_size = molecule.adjacency_matrix.shape[0] 258 | 259 | for molecule in batch: 260 | adjacency_list.append(pad_array(molecule.adjacency_matrix, (max_size, max_size))) 261 | distance_list.append(pad_array(molecule.distance_matrix, (max_size, max_size))) 262 | features_list.append(pad_array(molecule.node_features, (max_size, molecule.node_features.shape[1]))) 263 | global_features_list.append(molecule.global_feature) 264 | 265 | return [FloatTensor(features) for features in (adjacency_list, features_list, distance_list, global_features_list, labels)] 266 | 267 | 268 | def construct_dataset(x_all, y_all): 269 | output = [MOF(data[0], data[1], i) 270 | for i, data in enumerate(zip(x_all, y_all))] 271 | return MOFDataset(output) 272 | 273 | def construct_dataset_gf(x_all, f_all, y_all): 274 | output = [MOF(data[0], data[2], i, data[1]) 275 | for i, data in enumerate(zip(x_all, f_all, y_all))] 276 | return MOFDataset(output) 277 | 278 | def construct_dataset_gf_pressurever(x_all, f_all, y_all, pressure_list, is_train=True, mask_point=None, tar_point=None): 279 | output = [MOF(data[0], data[2], i, data[1]) 280 | for i, data in enumerate(zip(x_all, f_all, y_all))] 281 | return MOFDatasetPressureVer(output, pressure_list, is_train=is_train, mask_point=mask_point,tar_point=tar_point) 282 | 283 | def construct_dataset_real(x_all, f_all, y_all, pressure_list, tar_point=None): 284 | output = [MOF(data[0], data[2], i, data[1]) 285 | for i, data in enumerate(zip(x_all, f_all, y_all))] 286 | return RealMOFDataset(output, pressure_list, ori_point=tar_point) 287 | 288 | def construct_loader_gf(x,f,y, batch_size, shuffle=True): 289 | data_set = construct_dataset_gf(x, f, y) 290 | loader = torch.utils.data.DataLoader(dataset=data_set, 291 | batch_size=batch_size, 292 | num_workers=0, 293 | collate_fn=mof_collate_func_gf, 294 | pin_memory=True, 295 | shuffle=shuffle) 296 | return loader 297 | 298 | def construct_loader_gf_pressurever(data_set, batch_size, shuffle=True): 299 | loader = torch.utils.data.DataLoader(dataset=data_set, 300 | batch_size=batch_size, 301 | num_workers=0, 302 | collate_fn=mof_collate_func_gf, 303 | pin_memory=True, 304 | shuffle=shuffle) 305 | return loader 306 | 307 | class data_prefetcher(): 308 | def __init__(self, loader): 309 | self.loader = iter(loader) 310 | self.stream = torch.cuda.Stream() 311 | self.preload() 312 | 313 | def preload(self): 314 | try: 315 | self.next_data = next(self.loader) 316 | except StopIteration: 317 | self.next_data = None 318 | return 319 | with torch.cuda.stream(self.stream): 320 | self.next_data = tuple(_.cuda(non_blocking=True) for _ in self.next_data) 321 | 322 | def next(self): 323 | torch.cuda.current_stream().wait_stream(self.stream) 324 | batch = self.next_data 325 | self.preload() 326 | return batch 327 | -------------------------------------------------------------------------------- /image/3dstructgen-mof.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Matgen-project/MOFNet/47132ace072a8446ef43379a9ab019743ff746e7/image/3dstructgen-mof.png -------------------------------------------------------------------------------- /image/Fig1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Matgen-project/MOFNet/47132ace072a8446ef43379a9ab019743ff746e7/image/Fig1.jpg -------------------------------------------------------------------------------- /image/Fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Matgen-project/MOFNet/47132ace072a8446ef43379a9ab019743ff746e7/image/Fig1.png -------------------------------------------------------------------------------- /model_shap.py: -------------------------------------------------------------------------------- 1 | import shap 2 | import torch 3 | from collections import defaultdict 4 | from featurization.data_utils import load_data_from_df, construct_loader_gf_pressurever, construct_dataset_gf_pressurever, data_prefetcher 5 | from models.transformer import make_model 6 | import numpy as np 7 | import os 8 | from argparser import parse_train_args 9 | import pickle 10 | from tqdm import tqdm 11 | from utils import * 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | def gradient_shap(model, sample_loader, test_loader, batch_size): 16 | model.eval() 17 | model.set_adapter_dim(1) 18 | graph_reps, global_feas = [],[] 19 | for data in tqdm(sample_loader): 20 | adjacency_matrix, node_features, distance_matrix, global_features, y = (_.cpu() for _ in data) 21 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0 22 | batch_mask = batch_mask.float() 23 | graph_rep = model.encode(node_features, batch_mask, adjacency_matrix, distance_matrix, None) 24 | graph_reps.append(graph_rep) 25 | global_feas.append(global_features) 26 | graph_reps = torch.cat(graph_reps) 27 | global_feas = torch.cat(global_feas) 28 | e = shap.GradientExplainer(model.generator, [graph_reps, global_feas]) 29 | shap_all = [] 30 | for data in tqdm(test_loader): 31 | adjacency_matrix, node_features, distance_matrix, global_features, y = (_.cpu() for _ in data) 32 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0 33 | batch_mask = batch_mask.float() 34 | graph_rep = model.encode(node_features, batch_mask, adjacency_matrix, distance_matrix, None) 35 | ans = e.shap_values([graph_rep, global_features],nsamples=10) 36 | local_shap = np.abs(ans[0].sum(axis=1)).reshape(-1,1) 37 | global_shap = np.abs(ans[-1])[:,:9] 38 | shap_values = np.concatenate([local_shap, global_shap],axis=1) 39 | shap_all.append(shap_values) 40 | shap_all = np.concatenate(shap_all, axis=0) 41 | return shap_all 42 | 43 | if __name__ == '__main__': 44 | model_params = parse_train_args() 45 | device_ids = [0,1,2,3] 46 | X, f, y, p = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure='all',add_dummy_node = True,use_global_features = True) 47 | tar_idx = np.where(p==model_params['pressure'])[0][0] 48 | print(f'Loaded {len(X)} data.') 49 | y = np.array(y) 50 | mean = y[...,tar_idx].mean() 51 | std = y[...,tar_idx].std() 52 | y = (y - mean) / std 53 | f = np.array(f) 54 | fmean = f.mean(axis=0) 55 | fstd = f.std(axis=0) 56 | f = (f - fmean) / fstd 57 | batch_size = model_params['batch_size'] 58 | fold_num = model_params['fold'] 59 | idx_list = np.arange(len(X)) 60 | set_seed(model_params['seed']) 61 | 
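Aside: gradient_shap above wraps model.generator in shap.GradientExplainer, using encoded graph representations plus global features as the background distribution. A minimal self-contained sketch of the same pattern (the toy model and tensors are invented for illustration and are not part of MOFNet):

```
import shap
import torch

toy_model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(),
                                torch.nn.Linear(8, 1))
background = torch.randn(50, 4)                 # reference samples
explainer = shap.GradientExplainer(toy_model, background)
test_batch = torch.randn(5, 4)
shap_values = explainer.shap_values(test_batch, nsamples=10)
# one attribution per input feature; gradient_shap aggregates the analogous
# values into one local (graph) score and nine global-feature scores
```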
np.random.shuffle(idx_list) 62 | X = applyIndexOnList(X,idx_list) 63 | f = f[idx_list] 64 | y = y[idx_list] 65 | 66 | 67 | 68 | for fold_idx in range(1,2): 69 | set_seed(model_params['seed']) 70 | save_dir = model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}" 71 | ckpt_handler = CheckpointHandler(save_dir) 72 | state = ckpt_handler.checkpoint_best() 73 | model = make_model(**state['params']) 74 | model = torch.nn.DataParallel(model) 75 | model.load_state_dict(state['model']) 76 | model = model.module 77 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num,fold_idx) 78 | train_sample = construct_dataset_gf_pressurever(applyIndexOnList(X,train_idx), f[train_idx], y[train_idx],p, is_train=False, tar_point=model_params['pressure'],mask_point=model_params['pressure']) 79 | test_set = construct_dataset_gf_pressurever(applyIndexOnList(X,test_idx), f[test_idx], y[test_idx],p, is_train=False, tar_point=model_params['pressure'],mask_point=model_params['pressure']) 80 | shaps = {pres:[] for pres in [p[3]]} 81 | for pres in [p[3]]: 82 | train_sample.changeTarPoint(pres) 83 | test_set.changeTarPoint(pres) 84 | sample_loader = construct_loader_gf_pressurever(train_sample, batch_size, shuffle=False) 85 | test_loader = construct_loader_gf_pressurever(test_set, batch_size, shuffle=False) 86 | shap_values = gradient_shap(model, sample_loader, test_loader, batch_size) 87 | shaps[pres].append(shap_values) 88 | 89 | for pres in [p[3]]: 90 | shaps[pres] = np.concatenate(shaps[pres],axis=0) 91 | 92 | with open(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/shap_result_{p[3]}.p",'wb') as f: 93 | pickle.dump(shaps, f) 94 | 95 | 96 | -------------------------------------------------------------------------------- /nist_test.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import os 3 | import pandas as pd 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import time 8 | from featurization.data_utils import load_data_from_df, construct_loader_gf_pressurever, construct_dataset_gf_pressurever, data_prefetcher, load_real_data, construct_dataset_real 9 | from models.transformer import make_model 10 | from argparser import parse_train_args 11 | from utils import * 12 | import matplotlib.pyplot as plt 13 | from tqdm import tqdm 14 | import pickle 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | def ensemble_test(models,data_loader, mean, std, img_dir, names, p_ori): 19 | os.makedirs(img_dir,exist_ok=True) 20 | for model in models: 21 | model.eval() 22 | batch_idx = 0 23 | p_ori = np.log(float(p_ori)) 24 | ans = {} 25 | for data in tqdm(data_loader): 26 | adjacency_matrix, node_features, distance_matrix, global_features, y = data 27 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0 28 | adapter_dim = global_features.shape[-1] - 9 29 | pressure = global_features[...,-adapter_dim:] 30 | outputs = [] 31 | for model in models: 32 | model.module.set_adapter_dim(adapter_dim) 33 | output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features) 34 | outputs.append(output.cpu().detach().numpy().reshape(-1) * std + mean) 35 | y_tmp = y.cpu().detach().numpy().reshape(-1) 36 | futures_tmp = np.mean(np.array(outputs),axis=0) 37 | pres = pressure.cpu().detach().numpy().reshape(-1) + p_ori 38 | 39 | plt.xlabel('log pressure(Pa)') 40 | 
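Aside: the pressure arithmetic above undoes the offset applied when the dataset was built. RealMOFDataset stores log(p) - log(p_ori) as the extra global feature, so adding log(p_ori) back and exponentiating recovers the raw pressures; a small check (the p_ori value here is hypothetical):

```
import numpy as np

p_ori = 101325.0                                # Pa, illustrative only
pressures = np.array([5e3, 1e5, 2e6])
feature = np.log(pressures) - np.log(p_ori)     # what RealMOFDataset stores
recovered = np.exp(feature + np.log(p_ori))     # what ensemble_test reports
print(np.allclose(recovered, pressures))        # True
```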
plt.ylabel('adsorption(mol/kg)') 41 | l1 = plt.scatter(pres, y_tmp, c ='r', marker = 'o') 42 | l2 = plt.scatter(pres, futures_tmp, c = 'g', marker = 'x') 43 | plt.legend(handles=[l1,l2],labels=['label','prediction'],loc='best') 44 | plt.savefig(f'{img_dir}/{names[batch_idx]}.png') 45 | plt.cla() 46 | ans[names[batch_idx]] = { 47 | 'pressure':np.exp(pres), 48 | 'label':y_tmp, 49 | 'pred':futures_tmp 50 | } 51 | batch_idx += 1 52 | return ans 53 | 54 | if __name__ == '__main__': 55 | 56 | model_params = parse_train_args() 57 | batch_size = 1 58 | device_ids = [0,1,2,3] 59 | 60 | save_dir = f"{model_params['save_dir']}/{model_params['gas_type']}_{model_params['pressure']}" 61 | 62 | with open(os.path.join(save_dir,f'offset.p'),'rb') as f: 63 | p_ori, mean, std, fmean, fstd = pickle.load(f) 64 | 65 | test_errors_all = [] 66 | 67 | X, f, y, p, names = load_real_data(model_params['data_dir'], model_params['gas_type']) 68 | f = np.array(f) 69 | f = (f - fmean) / fstd 70 | test_errors = [] 71 | models = [] 72 | img_dir = os.path.join(model_params['img_dir'],model_params['gas_type']) 73 | predict_res = [] 74 | for fold_idx in range(1,11): 75 | save_dir_fold = f"{save_dir}/Fold-{fold_idx}" 76 | state = CheckpointHandler(save_dir_fold).checkpoint_best() 77 | model = make_model(**state['params']) 78 | model = torch.nn.DataParallel(model) 79 | model.load_state_dict(state['model']) 80 | model = model.to(device) 81 | models.append(model) 82 | test_set = construct_dataset_real(X, f, y, p, p_ori) 83 | test_loader = construct_loader_gf_pressurever(test_set,1,shuffle=False) 84 | test_res = ensemble_test(models, test_loader, mean, std, img_dir, names, p_ori) 85 | with open(os.path.join(img_dir,f"results.p"),'wb') as f: 86 | pickle.dump(test_res,f) 87 | -------------------------------------------------------------------------------- /pressure_adapt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import time 7 | from featurization.data_utils import load_data_from_df, construct_loader_gf_pressurever, construct_dataset_gf_pressurever, data_prefetcher 8 | from models.transformer import make_model 9 | from argparser import parse_finetune_args 10 | import pickle 11 | from utils import * 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | 16 | def train(model, epoch, train_loader, optimizer, scheduler, adapter_dim): 17 | model.train() 18 | loss = 0 19 | loss_all = 0 20 | prefetcher = data_prefetcher(train_loader) 21 | batch_idx = 0 22 | data = prefetcher.next() 23 | while data is not None: 24 | lr = scheduler.optimizer.param_groups[0]['lr'] 25 | adjacency_matrix, node_features, distance_matrix, global_features, y = data 26 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0 27 | 28 | optimizer.zero_grad() 29 | output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features) 30 | loss = F.mse_loss(output.reshape(-1), y.reshape(-1)) 31 | loss.backward() 32 | step_loss = loss.cpu().detach().numpy() 33 | loss_all += step_loss 34 | optimizer.step() 35 | scheduler.step() 36 | print(f'After Step {batch_idx} of Epoch {epoch}, Loss = {step_loss}, Lr = {lr}') 37 | batch_idx += 1 38 | data = prefetcher.next() 39 | return loss_all / len(train_loader.dataset) 40 | 41 | 42 | 43 | def test(model, data_loader, mean, std, adapter_dim): 44 | model.eval() 45 | error = 0 46 | prefetcher = 
data_prefetcher(data_loader) 47 | batch_idx = 0 48 | data = prefetcher.next() 49 | futures, ys = None, None 50 | while data is not None: 51 | adjacency_matrix, node_features, distance_matrix, global_features, y = data 52 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0 53 | output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features) 54 | output = output.reshape(y.shape).cpu().detach().numpy() 55 | y = y.cpu().detach().numpy() 56 | ys = y if ys is None else np.concatenate([ys,y], axis=0) 57 | futures = output if futures is None else np.concatenate([futures,output], axis=0) 58 | batch_idx += 1 59 | data = prefetcher.next() 60 | 61 | futures = np.array(futures) * std + mean 62 | ys = np.array(ys) * std + mean 63 | mae = np.mean(np.abs(futures - ys), axis=0) 64 | rmse = np.sqrt(np.mean((futures - ys)**2, axis=0)) 65 | # pcc = np.corrcoef(futures,ys)[0][1] 66 | pcc = np.array([np.corrcoef(futures[:,i],ys[:,i])[0][1] for i in range(adapter_dim)]) 67 | smape = 2 * np.mean(np.abs(futures-ys)/(np.abs(futures)+np.abs(ys)), axis=0) 68 | 69 | return {'MAE':mae, 'RMSE':rmse, 'PCC':pcc, 'sMAPE':smape} 70 | 71 | 72 | 73 | def get_RdecayFactor(warmup_step): 74 | 75 | def warmupRdecayFactor(step): 76 | if step < warmup_step: 77 | return step / warmup_step 78 | else: 79 | return (warmup_step / step) ** 0.5 80 | 81 | return warmupRdecayFactor 82 | 83 | if __name__ == '__main__': 84 | 85 | model_params = parse_finetune_args() 86 | batch_size = model_params['batch_size'] 87 | device_ids = [0,1,2,3] 88 | logger = get_logger(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}") 89 | X, f, y, p = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure='all',add_dummy_node = True,use_global_features = True) 90 | tar_idx = np.where(p==model_params['pressure'])[0][0] 91 | print(f'Loaded {len(X)} data.') 92 | logger.info(f'Loaded {len(X)} data.') 93 | y = np.array(y) 94 | mean = y[...,tar_idx].mean() 95 | std = y[...,tar_idx].std() 96 | y = (y - mean) / std 97 | f = np.array(f) 98 | fmean = f.mean(axis=0) 99 | fstd = f.std(axis=0) 100 | f = (f - fmean) / fstd 101 | 102 | with open(os.path.join(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}",f'offset.p'),'wb') as file: 103 | pickle.dump((model_params['pressure'], mean, std, fmean, fstd), file) 104 | 105 | printParams(model_params,logger) 106 | fold_num = model_params['fold'] 107 | epoch_num = model_params['epoch'] 108 | test_errors = [] 109 | idx_list = np.arange(len(X)) 110 | set_seed(model_params['seed']) 111 | np.random.shuffle(idx_list) 112 | X = applyIndexOnList(X,idx_list) 113 | f = f[idx_list] 114 | y = y[idx_list] 115 | test_errors = [] 116 | 117 | for fold_idx in range(1, fold_num + 1): 118 | 119 | set_seed(model_params['seed']) 120 | ori_state = CheckpointHandler(model_params['ori_dir']+f'/Fold-{fold_idx}').checkpoint_avg() 121 | ori_params = ori_state['params'] 122 | ori_params['adapter_finetune'] = True 123 | model = make_model(**ori_params) 124 | model.set_adapter_dim(model_params['adapter_dim']) 125 | model = torch.nn.DataParallel(model, device_ids=device_ids) 126 | model.load_state_dict(ori_state['model'],strict=False) 127 | model = model.to(device) 128 | lr = model_params['lr'] 129 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 130 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda = get_RdecayFactor(ori_params['warmup_step'])) 131 | best_val_error = 0 132 | best_val_error_s = 0 
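Aside: get_RdecayFactor above implements a Transformer-style schedule, linear warmup to the base learning rate followed by inverse-square-root decay. A quick numeric check (warmup_step=4000 is an arbitrary illustrative value):

```
factor = get_RdecayFactor(4000)
for step in (1, 2000, 4000, 16000):
    print(step, factor(step))   # 0.00025, 0.5, 1.0, (4000/16000)**0.5 = 0.5
```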
133 | test_error = 0 134 | best_epoch = -1 135 | 136 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num, fold_idx) 137 | 138 | train_set = construct_dataset_gf_pressurever(applyIndexOnList(X,train_idx), f[train_idx], y[train_idx],p, is_train=True, mask_point=model_params['pressure']) 139 | 140 | 141 | val_set = construct_dataset_gf_pressurever(applyIndexOnList(X,val_idx), f[val_idx], y[val_idx],p, is_train=True, mask_point=model_params['pressure']) 142 | 143 | 144 | test_set = construct_dataset_gf_pressurever(applyIndexOnList(X,test_idx), f[test_idx], y[test_idx],p, is_train=True, mask_point=model_params['pressure']) 145 | 146 | ckpt_handler = CheckpointHandler(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}") 147 | 148 | for epoch in range(1,epoch_num + 1): 149 | train_adapter_dim = model_params['adapter_dim'] 150 | train_loader = construct_loader_gf_pressurever(train_set,batch_size) 151 | loss = train(model, epoch, train_loader,optimizer,scheduler, train_adapter_dim) 152 | val_loader = construct_loader_gf_pressurever(val_set, batch_size, shuffle=False) 153 | val_error = test(model, val_loader, mean, std, train_adapter_dim)['MAE'] 154 | val_error_ = np.mean(val_error) 155 | ckpt_handler.save_model(model,ori_params,epoch,val_error_) 156 | 157 | if best_val_error == 0 or val_error_ <= best_val_error: 158 | print("Enter test step.\n") 159 | best_epoch = epoch 160 | best_val_error = val_error_ 161 | test_loader = construct_loader_gf_pressurever(test_set, batch_size, shuffle=False) 162 | test_error = test(model, test_loader, mean, std, train_adapter_dim) 163 | for idx, pres in enumerate(p): 164 | for _ in test_error.keys(): 165 | print('Fold: {:02d}, Epoch: {:03d}, Pressure: {}, Test {}: {:.7f}'.format(fold_idx, epoch, pres, _, test_error[_][idx])) 166 | logger.info('Fold: {:02d}, Epoch: {:03d}, Pressure: {}, Test {}: {:.7f}'.format(fold_idx, epoch, pres, _, test_error[_][idx])) 167 | lr = scheduler.optimizer.param_groups[0]['lr'] 168 | p_str = 'Fold: {:02d}, Epoch: {:03d}, Val MAE: {:.7f}, Best Val MAE: {:.7f}'.format(fold_idx, epoch, val_error_, best_val_error) 169 | print(p_str) 170 | logger.info(p_str) 171 | 172 | for idx, pres in enumerate(p): 173 | for _ in test_error.keys(): 174 | print('Fold: {:02d}, Epoch: {:03d}, Pressure: {}, Test {}: {:.7f}'.format(fold_idx, epoch, pres, _, test_error[_][idx])) 175 | logger.info('Fold: {:02d}, Epoch: {:03d}, Pressure: {}, Test {}: {:.7f}'.format(fold_idx, epoch, pres, _, test_error[_][idx])) 176 | 177 | test_errors.append(test_error) 178 | 179 | for idx, pres in enumerate(p): 180 | for _ in test_errors[0].keys(): 181 | mt_list = [__[_][idx] for __ in test_errors] 182 | p_str = 'Pressure {}, Test {} of {:02d}-Folds: {:.7f}({:.7f})'.format(pres, _, fold_num, np.mean(mt_list), np.std(mt_list)) 183 | print(p_str) 184 | logger.info(p_str) -------------------------------------------------------------------------------- /process/README: -------------------------------------------------------------------------------- 1 | python process_data.py ABIJUS 2 | -------------------------------------------------------------------------------- /process/create_geo_features.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | data_dir=${root_dir_for_cifs}/ 4 | cell_num=$2 5 | i=$1 6 | name=`echo $i |cut -d '.' 
-f 1` 7 | #argument 1 and 2 8 | #../network -ha -res ~/wsl/work/clean/$i >/dev/null 9 | ~/bin/network -ha -res ${data_dir}/${i} >/dev/null 10 | LCD=`head -n 1 ${data_dir}/${name}.res | awk '{print $4}'` 11 | PLD=`head -n 1 ${data_dir}/${name}.res | awk '{print $3}'` 12 | #exit 13 | rm ${data_dir}/${name}.res 14 | #argument 3 and 4 and 5 15 | #../network -ha -sa 1.86 1.86 2000 ~/wsl/work/clean/$i >/dev/null 16 | ~/bin/network -ha -sa 1.86 1.86 2000 ${data_dir}/${i} >/dev/null 17 | VSA=`head -n 1 ${data_dir}/${name}.sa | awk '{print $10}'` 18 | GSA=`head -n 1 ${data_dir}/${name}.sa | awk '{print $12}'` 19 | Density=`head -n 1 ${data_dir}/${name}.sa | awk '{print $6}'` 20 | chan_num_sa=`sed -n '2p' ${data_dir}/${name}.sa | awk '{print $2}'` 21 | rm ${data_dir}/${name}.sa 22 | #argument 6 and 7 23 | # ../network -ha -vol 0 0 50000 ~/wsl/work/clean/$i >/dev/null 24 | ~/bin/network -ha -vol 0 0 50000 ${data_dir}/${i} >/dev/null 25 | voidfract=`head -n 1 ${data_dir}/${name}.vol | awk '{print $10}'` 26 | porevolume=`head -n 1 ${data_dir}/${name}.vol | awk '{print $12}'` 27 | rm ${data_dir}/${name}.vol 28 | 29 | ~/bin/network -oms /tmp/${i}.cif >/dev/null 30 | oms=`tail -n 1 /tmp/${i}.oms | awk '{print $3}'` 31 | rm /tmp/${i}.oms 32 | 33 | chan_num_sa=`echo "scale=6;${chan_num_sa}/${cell_num}" | bc` 34 | porevolume=`echo "scale=6;${porevolume}/${cell_num}" | bc` 35 | oms=`echo "scale=6;${oms}/${cell_num}" | bc` 36 | #printf "%-20s%-10s%-10s%-10s%-10s%-10s%-10s%-15s%-15s\n" $name $Density $PLD $LCD $VSA $GSA $voidfract $porevolume $chan_num_sa $oms 37 | echo "$i,$LCD,$PLD,$VSA,$GSA,$Density,$voidfract,$porevolume,$chan_num_sa,$oms" 38 | -------------------------------------------------------------------------------- /process/prepare_mof_features.py: -------------------------------------------------------------------------------- 1 | from ccdc.descriptors import MolecularDescriptors as MD, GeometricDescriptors as GD 2 | from ccdc.io import EntryReader 3 | csd = EntryReader('CSD') 4 | import ccdc.molecule 5 | import sys 6 | import os 7 | import numpy as np 8 | import math 9 | 10 | import pickle 11 | from script.get_atom_features import get_atom_features 12 | from script.get_bond_features import get_bond_features 13 | from script.remove_waters import remove_waters, remove_single_oxygen, get_largest_components 14 | 15 | mol_name = sys.argv[1] 16 | mol = csd.molecule(mol_name) 17 | 18 | mol = remove_waters(mol) 19 | mol = remove_single_oxygen(mol) 20 | if len(mol.components) > 1: 21 | lg_id = get_largest_components(mol) 22 | mol = mol.components[lg_id] 23 | 24 | mol.remove_hydrogens() 25 | 26 | atom_features = get_atom_features(mol) 27 | bond_features = get_bond_features(mol) 28 | 29 | mol_features = [atom_features, bond_features] 30 | 31 | save_path = './processed/' + mol_name + '.p' 32 | 33 | if not os.path.exists(save_path): 34 | pickle.dump(mol_features,open(save_path, "wb")) 35 | 36 | 37 | -------------------------------------------------------------------------------- /process/process_csd_data.py: -------------------------------------------------------------------------------- 1 | from ccdc.descriptors import MolecularDescriptors as MD, GeometricDescriptors as GD 2 | from ccdc.io import EntryReader 3 | csd = EntryReader('CSD') 4 | import ccdc.molecule 5 | import sys 6 | import os 7 | import numpy as np 8 | import math 9 | 10 | import pickle 11 | from tools.get_atom_features import get_atom_features 12 | from tools.get_bond_features import get_bond_features 13 | from tools.remove_waters import 
remove_waters, remove_single_oxygen, get_largest_components 14 | import numpy as np 15 | from sklearn.metrics import pairwise_distances 16 | 17 | mol_name = sys.argv[1] 18 | mol = csd.molecule(mol_name) 19 | 20 | 21 | # remove waters 22 | mol = remove_waters(mol) 23 | mol = remove_single_oxygen(mol) 24 | 25 | # remove other solvates, here we remove all small components. 26 | 27 | if len(mol.components) > 1: 28 | lg_id = get_largest_components(mol) 29 | mol = mol.components[lg_id] 30 | 31 | mol.remove_hydrogens() 32 | 33 | atom_features = np.array([get_atom_features(atom) for atom in mol.atoms]) 34 | bond_matrix = get_bond_features(mol) 35 | 36 | pos_matrix = np.array([[atom.coordinates.x, atom.coordinates.y, atom.coordinates.z] for atom in mol.atoms]) 37 | dist_matrix = pairwise_distances(pos_matrix) 38 | 39 | mol_features = [atom_features, bond_matrix, dist_matrix] 40 | 41 | save_path = '../data/processed/' + mol_name + '.p' 42 | 43 | if not os.path.exists(save_path): 44 | pickle.dump(mol_features,open(save_path, "wb")) 45 | 46 | 47 | -------------------------------------------------------------------------------- /process/process_csd_data_baselines.py: -------------------------------------------------------------------------------- 1 | from ccdc.descriptors import MolecularDescriptors as MD, GeometricDescriptors as GD 2 | from ccdc.io import EntryReader 3 | csd = EntryReader('CSD') 4 | import ccdc.molecule 5 | import sys 6 | import os 7 | import numpy as np 8 | import math 9 | 10 | import pickle 11 | from tools.get_atom_features import get_atom_features 12 | from tools.get_bond_features import get_bond_features_en 13 | from tools.remove_waters import remove_waters, remove_single_oxygen, get_largest_components 14 | import numpy as np 15 | 16 | mol_name = sys.argv[1] 17 | mol = csd.molecule(mol_name) 18 | 19 | 20 | # remove waters 21 | mol = remove_waters(mol) 22 | mol = remove_single_oxygen(mol) 23 | 24 | # remove other solvates, here we remove all small components. 
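Aside: the dist_matrix written by process_csd_data.py above is a plain all-pairs Euclidean distance matrix over the Cartesian coordinates; a toy check of the same sklearn call (coordinates invented):

```
import numpy as np
from sklearn.metrics import pairwise_distances

pos_matrix = np.array([[0.0, 0.0, 0.0],
                       [1.0, 0.0, 0.0],
                       [0.0, 2.0, 0.0]])
print(np.round(pairwise_distances(pos_matrix), 3))
# [[0.    1.    2.   ]
#  [1.    0.    2.236]
#  [2.    2.236 0.   ]]
```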
25 | 26 | if len(mol.components) > 1: 27 | lg_id = get_largest_components(mol) 28 | mol = mol.components[lg_id] 29 | 30 | mol.remove_hydrogens() 31 | 32 | atom_features = np.array([get_atom_features(atom) for atom in mol.atoms]) 33 | row, col = get_bond_features_en(mol) 34 | 35 | pos_matrix = np.array([[atom.coordinates.x, atom.coordinates.y, atom.coordinates.z] for atom in mol.atoms]) 36 | 37 | mol_features = [atom_features, row, col, pos_matrix] 38 | 39 | save_path = '../data/processed_en/' + mol_name + '.p' 40 | 41 | os.makedirs('../data/processed_en/', exist_ok=True) 42 | 43 | if not os.path.exists(save_path): 44 | pickle.dump(mol_features,open(save_path, "wb")) 45 | 46 | 47 | -------------------------------------------------------------------------------- /process/process_nist_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas as pd 4 | from tqdm import tqdm 5 | from argparse import ArgumentParser 6 | 7 | t_dict = {77:"Nitrogen", 273:"Carbon Dioxide", 298:"Methane"} 8 | unit_dic = {"mmol/g":1, "mol/g":0.001, 'mmol/kg':1000} 9 | m_dic = {"Nitrogen":28.0134, "Methane":16.0424, "Carbon Dioxide":44.0094} 10 | def get_unit_factor(unit,ads): 11 | if unit in unit_dic: 12 | return 1 / unit_dic[unit] 13 | elif unit == "cm3(STP)/g": 14 | return 1 / 22.4139 15 | elif unit == 'mg/g': 16 | return 1 / m_dic[ads] 17 | else: 18 | return None 19 | 20 | def norm_str(ori): 21 | ori = ori.split('.')[0].split('-') 22 | if ori[-1] == 'clean': 23 | ori = ori[:-1] 24 | elif ori[-2] == 'clean': 25 | ori = ori[:-2] 26 | return '-'.join(ori[1:]) 27 | 28 | if __name__ == "__main__": 29 | parser = ArgumentParser() 30 | parser.add_argument('--data_dir', type=str, 31 | help='NIST data directory.') 32 | args = parser.parse_args() 33 | prefix = os.path.join(args.data_dir,'isotherm_data') 34 | pres_all = {"CH4":{"num":0, "data":[]}, "CO2":{"num":0, "data":[]}, "N2":{"num":0, "data":[]}} 35 | for gas_type in ['CH4','CO2','N2']: 36 | gas_pref = os.path.join(prefix, gas_type) 37 | files = os.listdir(gas_pref) 38 | for js in tqdm(files): 39 | with open(os.path.join(gas_pref, js), "r") as f: 40 | dic = json.load(f) 41 | name = dic['adsorbent']['name'] 42 | t = dic['temperature'] 43 | if t not in t_dict: 44 | continue 45 | tar_obj = t_dict[t] 46 | unit_factor = get_unit_factor(dic['adsorptionUnits'], tar_obj) 47 | if not unit_factor: 48 | continue 49 | tar_key = None 50 | for ads in dic['adsorbates']: 51 | if ads['name'] == tar_obj: 52 | tar_key = ads['InChIKey'] 53 | break 54 | if not tar_key: 55 | continue 56 | pres_ret = [] 57 | for d in dic['isotherm_data']: 58 | pres = d['pressure'] * 1e5 59 | for sd in d['species_data']: 60 | if sd['InChIKey'] == tar_key: 61 | tar_abs = sd['adsorption'] * unit_factor 62 | pres_ret.append({'pressure':pres, 'adsorption':tar_abs}) 63 | pres_all[gas_type]['num'] += 1 64 | pres_all[gas_type]['data'].append({"name":name, "filename":js, "isotherm_data":pres_ret}) 65 | with open(os.path.join(prefix,'all.json'),'w') as f: 66 | json.dump(pres_all, f) 67 | 68 | -------------------------------------------------------------------------------- /process/tools/get_atom_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_atom_features(atom): 4 | attributes = [] 5 | attributes += one_hot_vector( 6 | atom.atomic_number, 7 | [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, \ 8 | 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, \ 9 | 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, \ 10 | 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, \ 11 | 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, \ 12 | 73, 74, 75, 76, 77, 78, 79, 80, 81, 999] 13 | ) 14 | # Number of connected neighbours 15 | attributes += one_hot_vector( 16 | len(atom.neighbours), 17 | [0, 1, 2, 3, 4, 5, 6, 999] 18 | ) 19 | 20 | # Test whether or not the atom is a hydrogen bond acceptor 21 | attributes.append(atom.is_acceptor) 22 | attributes.append(atom.is_chiral) 23 | 24 | # Test whether the atom is part of a ring system. 25 | attributes.append(atom.is_cyclic) 26 | attributes.append(atom.is_metal) 27 | 28 | # Test whether this is a spiro atom. 29 | attributes.append(atom.is_spiro) 30 | 31 | return np.array(list(attributes), dtype=np.float32) 32 | 33 | def one_hot_vector(val, lst): 34 | """Converts a value to a one-hot vector based on options in lst""" 35 | if val not in lst: 36 | val = lst[-1] 37 | return map(lambda x: x == val, lst) -------------------------------------------------------------------------------- /process/tools/get_bond_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | def get_bond_features(mol): 5 | """Calculate bond features. 6 | 7 | Args: 8 | mol (ccdc.molecule.Molecule): a CSD molecule object. 9 | 10 | Returns: 11 | adjacency matrix. 12 | (the bond-distance computation below is currently commented out.) 13 | """ 14 | adj_matrix = np.eye(len(mol.atoms)) 15 | dis_matrix = [] 16 | 17 | for bond in mol.bonds: 18 | atom1,atom2 = bond.atoms 19 | # construct adjacency matrix. 20 | adj_matrix[atom1.index, atom2.index] = adj_matrix[atom2.index, atom1.index] = 1 21 | 22 | # calculate bond distance. 23 | #print(atom1,atom2) 24 | #a_array = [atom1.coordinates.x, atom1.coordinates.y, atom1.coordinates.z] 25 | #b_array = [atom2.coordinates.x, atom2.coordinates.y, atom2.coordinates.z] 26 | #bond_length = calc_distance(a_array, b_array) 27 | #dis_matrix.append(bond_length) 28 | 29 | return adj_matrix 30 | 31 | def get_bond_features_en(mol): 32 | """Calculate bond features. 33 | 34 | Args: 35 | mol (ccdc.molecule.Molecule): a CSD molecule object. 36 | 37 | Returns: 38 | bond matrix in COO format (row and column index lists). 39 | """ 40 | row, col = [], [] 41 | 42 | for bond in mol.bonds: 43 | atom1,atom2 = bond.atoms 44 | # construct adjacency matrix.
45 | row.append(atom1.index) 46 | col.append(atom2.index) 47 | row.append(atom2.index) 48 | col.append(atom1.index) 49 | 50 | return row, col 51 | 52 | # function to obtain bond distance 53 | def calc_distance(a_array, b_array): 54 | delt_d = np.array(a_array) - np.array(b_array) 55 | distance = math.sqrt(delt_d[0]**2 + delt_d[1]**2 + delt_d[2]**2) 56 | return round(distance,3) 57 | 58 | 59 | -------------------------------------------------------------------------------- /process/tools/remove_waters.py: -------------------------------------------------------------------------------- 1 | import ccdc.molecule 2 | 3 | def get_largest_components(m): 4 | s = [] 5 | for c in m.components: 6 | n = len(c.atoms) 7 | id_n = int(str(c.identifier)) 8 | l = [(n, id_n)] 9 | s.append(l) 10 | t = sorted(s, key=lambda k: k[0]) 11 | largest_id = t[-1][0][1] - 1 12 | 13 | return largest_id 14 | 15 | def remove_waters(m): 16 | keep = [] 17 | waters = 0 18 | for s in m.components: 19 | ats = [at.atomic_symbol for at in s.atoms] 20 | if len(ats) == 3: 21 | ats.sort() 22 | if ats[0] == 'H' and ats[1] == 'H' and ats[2] == 'O': 23 | waters += 1 24 | else: 25 | keep.append(s) 26 | else: 27 | keep.append(s) 28 | new = ccdc.molecule.Molecule(m.identifier) 29 | for k in keep: 30 | new.add_molecule(k) 31 | return new 32 | 33 | def remove_single_oxygen(m): 34 | keep = [] 35 | waters = 0 36 | for s in m.components: 37 | ats = [at.atomic_symbol for at in s.atoms] 38 | if len(ats) == 1: 39 | ats.sort() 40 | if ats[0] == 'O': 41 | waters += 1 42 | else: 43 | keep.append(s) 44 | else: 45 | keep.append(s) 46 | new = ccdc.molecule.Molecule(m.identifier) 47 | for k in keep: 48 | new.add_molecule(k) 49 | return new 50 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ase==3.21.1 2 | backcall==0.2.0 3 | certifi==2020.12.5 4 | charset-normalizer==2.0.6 5 | cloudpickle==2.0.0 6 | cycler==0.10.0 7 | decorator==5.1.0 8 | future==0.18.2 9 | googledrivedownloader==0.4 10 | greenlet==1.1.0 11 | idna==3.2 12 | importlib-metadata==4.0.1 13 | ipython==7.27.0 14 | isodate==0.6.0 15 | jedi==0.18.0 16 | Jinja2==3.0.3 17 | joblib==1.0.1 18 | kiwisolver==1.3.1 19 | llvmlite==0.37.0 20 | MarkupSafe==2.0.1 21 | matplotlib==3.4.2 22 | matplotlib-inline==0.1.3 23 | monty==2021.8.17 24 | mpmath==1.2.1 25 | networkx==2.6.3 26 | numba==0.54.1 27 | numpy==1.20.3 28 | olefile==0.46 29 | packaging==21.3 30 | palettable==3.3.0 31 | pandas==1.1.5 32 | parso==0.8.2 33 | pexpect==4.8.0 34 | pickleshare==0.7.5 35 | Pillow==8.1.2 36 | plotly==5.3.1 37 | prompt-toolkit==3.0.20 38 | ptyprocess==0.7.0 39 | pycairo==1.20.0 40 | Pygments==2.10.0 41 | pymatgen==2022.0.0 42 | pyparsing==2.4.7 43 | python-dateutil==2.8.1 44 | python-louvain==0.15 45 | pytz==2021.1 46 | PyYAML==6.0 47 | rdflib==5.0.0 48 | reportlab==3.5.67 49 | requests==2.26.0 50 | ruamel.yaml==0.17.16 51 | ruamel.yaml.clib==0.2.6 52 | scikit-learn==0.24.2 53 | scipy==1.6.3 54 | seaborn==0.11.2 55 | shap==0.40.0 56 | six==1.16.0 57 | sklearn==0.0 58 | slicer==0.0.7 59 | spglib==1.16.2 60 | SQLAlchemy==1.4.15 61 | sympy==1.9 62 | tabulate==0.8.9 63 | tenacity==8.0.1 64 | threadpoolctl==2.1.0 65 | torch==1.8.1+cu102 66 | torch-cluster==1.5.9 67 | torch-geometric==2.0.3 68 | torch-scatter==2.0.8 69 | torch-sparse==0.6.12 70 | torch-spline-conv==1.2.1 71 | torchaudio==0.8.1 72 | torchvision==0.6.0+cu101 73 | tornado==6.1 74 | tqdm==4.62.2 75 | traitlets==5.1.0 76 | 
typing-extensions==3.10.0.0 77 | uncertainties==3.1.6 78 | urllib3==1.26.7 79 | wcwidth==0.2.5 80 | yacs==0.1.8 81 | zipp==3.4.1 82 | -------------------------------------------------------------------------------- /train_baselines.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import time 7 | from baselines.data_utils import load_data_from_df, construct_loader, data_prefetcher 8 | from baselines import make_baseline_model 9 | from argparser import parse_baseline_args 10 | from utils import * 11 | 12 | model_params = parse_baseline_args() 13 | 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def warmupRdecayFactor(step): 19 | warmup_step = model_params['warmup_step'] 20 | if step < warmup_step: 21 | return step / warmup_step 22 | else: 23 | return (warmup_step / step) ** 0.5 24 | 25 | 26 | def train(epoch, train_loader, optimizer, scheduler, use_adj=True): 27 | model.train() 28 | loss = 0 29 | loss_all = 0 30 | prefetcher = data_prefetcher(train_loader, device) 31 | batch_idx = 0 32 | data = prefetcher.next() 33 | while data is not None: 34 | lr = scheduler.optimizer.param_groups[0]['lr'] 35 | if use_adj: 36 | node_features, pos, adj, global_feature, y = data 37 | else: 38 | node_features, pos, nbh, nbh_mask, global_feature, y = data 39 | adj = (nbh, nbh_mask) 40 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0 41 | 42 | optimizer.zero_grad() 43 | output = model(node_features, batch_mask, pos, adj, global_feature) 44 | y = y.squeeze(-1) 45 | loss = F.mse_loss(output, y) 46 | loss.backward() 47 | step_loss = loss.cpu().detach().numpy() 48 | loss_all += step_loss 49 | optimizer.step() 50 | scheduler.step() 51 | print(f'After Step {batch_idx} of Epoch {epoch}, Loss = {step_loss}, Lr = {lr}') 52 | batch_idx += 1 53 | data = prefetcher.next() 54 | return loss_all / len(train_loader.dataset) 55 | 56 | 57 | def test(data_loader, mean, std, use_adj=True): 58 | model.eval() 59 | error = 0 60 | prefetcher = data_prefetcher(data_loader, device) 61 | batch_idx = 0 62 | data = prefetcher.next() 63 | futures, ys = [], [] 64 | while data is not None: 65 | 66 | if use_adj: 67 | node_features, pos, adj, global_feature, y = data 68 | else: 69 | node_features, pos, nbh, nbh_mask, global_feature, y = data 70 | adj = (nbh, nbh_mask) 71 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0 72 | 73 | optimizer.zero_grad() 74 | output = model(node_features, batch_mask, pos, adj, global_feature) 75 | ys += list(y.cpu().detach().numpy().reshape(-1)) 76 | futures += list(output.cpu().detach().numpy().reshape(-1)) 77 | batch_idx += 1 78 | data = prefetcher.next() 79 | 80 | futures = np.array(futures) * std + mean 81 | ys = np.array(ys) * std + mean 82 | mae = np.mean(np.abs(futures - ys)) 83 | rmse = np.sqrt(np.mean((futures - ys)**2)) 84 | pcc = np.corrcoef(futures,ys)[0][1] 85 | smape = 2 * np.mean(np.abs(futures-ys)/(np.abs(futures)+np.abs(ys))) 86 | 87 | return {'MAE':mae, 'RMSE':rmse, 'PCC':pcc, 'sMAPE':smape} 88 | 89 | if __name__ == '__main__': 90 | 91 | model_name = model_params['model_name'] 92 | if model_name in ('egnn', 'dimenetpp'): 93 | use_adj = True 94 | else: 95 | use_adj = False 96 | batch_size = model_params['batch_size'] 97 | device_ids = [0,1,2,3] 98 | logger = get_logger(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}") 99 | X, f, y =
load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure=model_params['pressure'],use_global_features = True) 100 | print(f'Loaded {len(X)} data.') 101 | logger.info(f'Loaded {len(X)} data.') 102 | y = np.array(y) 103 | mean = y.mean() 104 | std = y.std() 105 | y = (y - mean) / std 106 | f = np.array(f) 107 | fmean = f.mean(axis=0) 108 | fstd = f.std(axis=0) 109 | f = (f - fmean) / fstd 110 | 111 | model_params['d_atom'] = X[0][0].shape[1] 112 | model_params['d_feature'] = f.shape[-1] 113 | 114 | printParams(model_params,logger) 115 | fold_num = model_params['fold'] 116 | epoch_num = model_params['epoch'] 117 | test_errors = [] 118 | idx_list = np.arange(len(X)) 119 | set_seed(model_params['seed']) 120 | np.random.shuffle(idx_list) 121 | X = applyIndexOnList(X,idx_list) 122 | f = f[idx_list] 123 | y = y[idx_list] 124 | 125 | for fold_idx in range(1,fold_num + 1): 126 | set_seed(model_params['seed']) 127 | model = make_baseline_model(**model_params) 128 | model = torch.nn.DataParallel(model, device_ids=device_ids) 129 | model = model.to(device) 130 | lr = model_params['lr'] 131 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 132 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda = warmupRdecayFactor) 133 | best_val_error = 0 134 | test_error = 0 135 | best_epoch = -1 136 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num,fold_idx) 137 | 138 | train_loader = construct_loader(applyIndexOnList(X,train_idx), f[train_idx], y[train_idx],batch_size, shuffle=True, use_adj=use_adj) 139 | val_loader = construct_loader(applyIndexOnList(X,val_idx), f[val_idx], y[val_idx],batch_size, shuffle=False, use_adj=use_adj) 140 | test_loader = construct_loader(applyIndexOnList(X,test_idx),f[test_idx], y[test_idx],batch_size, shuffle=False, use_adj=use_adj) 141 | 142 | ckpt_handler = CheckpointHandler(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}") 143 | 144 | for epoch in range(1,epoch_num + 1): 145 | loss = train(epoch,train_loader,optimizer,scheduler, use_adj=use_adj) 146 | val_error = test(val_loader, mean, std, use_adj=use_adj)['MAE'] 147 | ckpt_handler.save_model(model,model_params,epoch,val_error) 148 | if best_val_error == 0 or val_error <= best_val_error: 149 | print("Enter test step.\n") 150 | best_epoch = epoch 151 | test_error = test(test_loader, mean, std, use_adj=use_adj) 152 | best_val_error = val_error 153 | state = {"params":model_params, "epoch":epoch, "model":model.state_dict()} 154 | lr = scheduler.optimizer.param_groups[0]['lr'] 155 | 156 | epoch_op_str = 'Fold: {:02d}, Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, \ 157 | Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}, Best Val MAE {:.7f}(epoch {:03d})'.format(fold_idx, epoch, lr, loss, val_error, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE'], best_val_error, best_epoch) 158 | 159 | print(epoch_op_str) 160 | 161 | logger.info(epoch_op_str) 162 | 163 | test_errors.append(test_error) 164 | print('Fold: {:02d}, Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}'.format(fold_idx, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE'])) 165 | logger.info('Fold: {:02d}, Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}'.format(fold_idx, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE'])) 166 | for _ in test_errors[0].keys(): 167 | err_mean = np.mean([__[_] 
for __ in test_errors])
168 |         err_std = np.std([__[_] for __ in test_errors])
169 |         print('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
170 |         logger.info('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
171 | 
--------------------------------------------------------------------------------
/train_ml.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn import tree, svm, ensemble
3 | from featurization.data_utils import load_data_from_df
4 | from argparser import parse_ml_args
5 | from utils import *
6 | 
7 | def get_metric_dict(predicted, ground_truth):
8 |     mae = np.mean(np.abs(predicted - ground_truth))
9 |     smape = np.mean(np.abs(predicted - ground_truth) / ((np.abs(ground_truth) + np.abs(predicted)) / 2))
10 |     pcc = np.corrcoef(predicted, ground_truth)[0][1]
11 |     rmse = np.sqrt(np.mean((predicted - ground_truth) ** 2))
12 |     return {'MAE':mae, 'sMAPE':smape, 'PCC': pcc, 'RMSE':rmse}
13 | 
14 | if __name__ == '__main__':
15 | 
16 |     model_params = parse_ml_args()
17 |     device_ids = [0,1,2,3]
18 |     logger = get_logger(model_params['save_dir'] + f"/{model_params['ml_type']}/{model_params['gas_type']}_{model_params['pressure']}/")
19 |     X, f, y = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure=model_params['pressure'],add_dummy_node = True,use_global_features = True)
20 |     print(f'Loaded {len(X)} data.')
21 |     logger.info(f'Loaded {len(X)} data.')
22 |     y = np.array(y).reshape(-1)
23 |     mean = y.mean()
24 |     std = y.std()
25 |     y = (y - mean) / std
26 |     f = np.array(f)
27 |     fmean = f.mean(axis=0)
28 |     fstd = f.std(axis=0)
29 |     f = (f - fmean) / fstd
30 | 
31 |     Xs = [np.mean(_[0][:,1:],axis=0) for _ in X]  # mean-pool atom features (dummy-node column dropped)
32 |     f = np.concatenate((Xs,f),axis=1)
33 | 
34 |     printParams(model_params,logger)
35 |     fold_num = model_params['fold']
36 |     test_errors = []
37 |     idx_list = np.arange(len(X))
38 |     set_seed(model_params['seed'])
39 |     np.random.shuffle(idx_list)
40 |     X = applyIndexOnList(X,idx_list)
41 |     f = f[idx_list]
42 |     y = y[idx_list]
43 | 
44 |     for fold_idx in range(1,fold_num + 1):
45 |         set_seed(model_params['seed'])
46 | 
47 |         train_idx, val_idx, test_idx = splitdata(len(X),fold_num,fold_idx)
48 | 
49 |         train_f,train_y = f[train_idx], y[train_idx]
50 |         test_f,test_y = f[test_idx], y[test_idx]
51 | 
52 |         if model_params['ml_type'] == 'RF':
53 | 
54 |             model = ensemble.RandomForestRegressor(n_estimators=100,criterion='mse',min_samples_split=2,min_samples_leaf=1,max_features='auto')
55 | 
56 |         elif model_params['ml_type'] == 'SVR':
57 | 
58 |             model = svm.SVR()
59 | 
60 |         elif model_params['ml_type'] == 'DT':
61 | 
62 |             model = tree.DecisionTreeRegressor()
63 | 
64 |         elif model_params['ml_type'] == 'GBRT':
65 | 
66 |             model = ensemble.GradientBoostingRegressor()
67 | 
68 |         model.fit(train_f,train_y)
69 | 
70 |         future = model.predict(test_f) * std + mean
71 | 
72 |         test_y = test_y * std + mean
73 |         test_error = get_metric_dict(future, test_y)
74 |         for _ in test_error.keys():
75 |             print('Fold: {:02d}, Test {}: {:.7f}'.format(fold_idx, _, test_error[_]))
76 |             logger.info('Fold: {:02d}, Test {}: {:.7f}'.format(fold_idx, _, test_error[_]))
77 |         test_errors.append(test_error)
78 |     for _ in test_errors[0].keys():
79 |         err_mean = np.mean([__[_] for __ in test_errors])
80 |         err_std = np.std([__[_] for __ in test_errors])
81 |         print('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
82 |         logger.info('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
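To make the flow of `train_ml.py` above concrete, here is a self-contained sketch of its descriptor construction: mean-pool the per-atom features into a fixed-length vector, concatenate the global pore descriptors, then fit a RandomForest. All shapes and values below are synthetic placeholders, not the real CSD-MOFDB features.

```
# Minimal sketch of the classical-ML baseline pipeline, on synthetic data.
import numpy as np
from sklearn import ensemble

rng = np.random.default_rng(0)
n_mofs, n_atoms, d_atom, d_global = 64, 30, 8, 5   # placeholder sizes

node_feats = rng.normal(size=(n_mofs, n_atoms, d_atom))   # per-atom features
global_feats = rng.normal(size=(n_mofs, d_global))        # LCD, PLD, void fraction, ...
y = rng.normal(size=n_mofs)                               # normalized uptake

pooled = node_feats.mean(axis=1)                          # like Xs in train_ml.py
descriptors = np.concatenate([pooled, global_feats], axis=1)

model = ensemble.RandomForestRegressor(n_estimators=100, random_state=0)
model.fit(descriptors[:48], y[:48])
pred = model.predict(descriptors[48:])
print('MAE:', np.mean(np.abs(pred - y[48:])))
```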
--------------------------------------------------------------------------------
/train_mofnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import time
7 | from featurization.data_utils import load_data_from_df, construct_loader_gf, data_prefetcher
8 | from models.transformer import make_model
9 | from argparser import parse_train_args
10 | from utils import *
11 | 
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 | 
14 | def warmupRdecayFactor(step):
15 |     warmup_step = model_params['warmup_step']
16 |     if step < warmup_step:
17 |         return step / warmup_step
18 |     else:
19 |         return (warmup_step / step) ** 0.5
20 | 
21 | 
22 | def train(epoch, train_loader, optimizer, scheduler):
23 |     model.train()
24 |     loss = 0
25 |     loss_all = 0
26 |     prefetcher = data_prefetcher(train_loader)
27 |     batch_idx = 0
28 |     data = prefetcher.next()
29 |     while data is not None:
30 |         lr = scheduler.optimizer.param_groups[0]['lr']
31 |         adjacency_matrix, node_features, distance_matrix, global_features, y = data
32 |         batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
33 | 
34 |         optimizer.zero_grad()
35 |         output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features)
36 |         loss = F.mse_loss(output, y)
37 |         loss.backward()
38 |         step_loss = loss.cpu().detach().numpy()
39 |         loss_all += step_loss
40 |         optimizer.step()
41 |         scheduler.step()
42 |         print(f'After Step {batch_idx} of Epoch {epoch}, Loss = {step_loss}, Lr = {lr}')
43 |         batch_idx += 1
44 |         data = prefetcher.next()
45 |     return loss_all / len(train_loader.dataset)
46 | 
47 | 
48 | def test(data_loader, mean, std):
49 |     model.eval()
50 |     error = 0
51 |     prefetcher = data_prefetcher(data_loader)
52 |     batch_idx = 0
53 |     data = prefetcher.next()
54 |     futures, ys = [], []
55 |     while data is not None:
56 |         adjacency_matrix, node_features, distance_matrix, global_features, y = data
57 |         batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
58 |         output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features)
59 |         ys += list(y.cpu().detach().numpy().reshape(-1))
60 |         futures += list(output.cpu().detach().numpy().reshape(-1))
61 |         batch_idx += 1
62 |         data = prefetcher.next()
63 | 
64 |     futures = np.array(futures) * std + mean
65 |     ys = np.array(ys) * std + mean
66 |     mae = np.mean(np.abs(futures - ys))
67 |     rmse = np.sqrt(np.mean((futures - ys)**2))
68 |     pcc = np.corrcoef(futures,ys)[0][1]
69 |     smape = 2 * np.mean(np.abs(futures-ys)/(np.abs(futures)+np.abs(ys)))
70 | 
71 |     return {'MAE':mae, 'RMSE':rmse, 'PCC':pcc, 'sMAPE':smape}
72 | 
73 | if __name__ == '__main__':
74 | 
75 |     model_params = parse_train_args()
76 |     batch_size = model_params['batch_size']
77 |     device_ids = [0,1,2,3]
78 |     logger = get_logger(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}")
79 |     X, f, y = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure=model_params['pressure'],add_dummy_node = True,use_global_features = True)
80 |     print(f'Loaded {len(X)} data.')
81 |     logger.info(f'Loaded {len(X)} data.')
82 |     y = np.array(y)
83 |     mean = y.mean()
84 |     std = y.std()
85 |     y = (y - mean) / std
86 |     f = np.array(f)
87 |     fmean = f.mean(axis=0)
88 |     fstd = f.std(axis=0)
89 |     f = (f - fmean) / fstd
90 | 
91 | 
model_params['d_atom'] = X[0][0].shape[1] 92 | model_params['d_feature'] = f.shape[-1] 93 | 94 | printParams(model_params,logger) 95 | fold_num = model_params['fold'] 96 | epoch_num = model_params['epoch'] 97 | test_errors = [] 98 | idx_list = np.arange(len(X)) 99 | set_seed(model_params['seed']) 100 | np.random.shuffle(idx_list) 101 | X = applyIndexOnList(X,idx_list) 102 | f = f[idx_list] 103 | y = y[idx_list] 104 | 105 | for fold_idx in range(1,fold_num + 1): 106 | set_seed(model_params['seed']) 107 | model = make_model(**model_params) 108 | model = torch.nn.DataParallel(model, device_ids=device_ids) 109 | model = model.to(device) 110 | lr = model_params['lr'] 111 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 112 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda = warmupRdecayFactor) 113 | best_val_error = 0 114 | test_error = 0 115 | best_epoch = -1 116 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num,fold_idx) 117 | 118 | train_loader = construct_loader_gf(applyIndexOnList(X,train_idx), f[train_idx], y[train_idx],batch_size) 119 | val_loader = construct_loader_gf(applyIndexOnList(X,val_idx), f[val_idx], y[val_idx],batch_size) 120 | test_loader = construct_loader_gf(applyIndexOnList(X,test_idx),f[test_idx], y[test_idx],batch_size) 121 | 122 | ckpt_handler = CheckpointHandler(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}") 123 | 124 | for epoch in range(1,epoch_num + 1): 125 | loss = train(epoch,train_loader,optimizer,scheduler) 126 | val_error = test(val_loader, mean, std)['MAE'] 127 | ckpt_handler.save_model(model,model_params,epoch,val_error) 128 | if best_val_error == 0 or val_error <= best_val_error: 129 | print("Enter test step.\n") 130 | best_epoch = epoch 131 | test_error = test(test_loader, mean, std) 132 | best_val_error = val_error 133 | state = {"params":model_params, "epoch":epoch, "model":model.state_dict()} 134 | # torch.save(state, model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}.pt") 135 | lr = scheduler.optimizer.param_groups[0]['lr'] 136 | 137 | epoch_op_str = 'Fold: {:02d}, Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, \ 138 | Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}, Best Val MAE {:.7f}(epoch {:03d})'.format(fold_idx, epoch, lr, loss, val_error, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE'], best_val_error, best_epoch) 139 | 140 | print(epoch_op_str) 141 | 142 | logger.info(epoch_op_str) 143 | 144 | test_errors.append(test_error) 145 | print('Fold: {:02d}, Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}'.format(fold_idx, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE'])) 146 | logger.info('Fold: {:02d}, Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}'.format(fold_idx, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE'])) 147 | for _ in test_errors[0].keys(): 148 | err_mean = np.mean([__[_] for __ in test_errors]) 149 | err_std = np.std([__[_] for __ in test_errors]) 150 | print('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std)) 151 | logger.info('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std)) 152 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import numpy as np 3 | import logging 4 | import os 5 | 6 | def splitdata(length,fold,index): 7 | fold_length = length // fold 8 | index_list = np.arange(length) 9 | if index == 1: 10 | val = index_list[:fold_length] 11 | test = index_list[fold_length * (fold - 1):] 12 | train = index_list[fold_length : fold_length * (fold - 1)] 13 | elif index == fold: 14 | val = index_list[fold_length * (fold - 1):] 15 | test = index_list[fold_length * (fold - 2) : fold_length * (fold - 1)] 16 | train = index_list[:fold_length * (fold - 2)] 17 | else: 18 | val = index_list[fold_length * (index - 1) : fold_length * index] 19 | test = index_list[fold_length * (index - 2) : fold_length * (index - 1)] 20 | train = np.concatenate([index_list[:fold_length * (index - 2)],index_list[fold_length * index:]]) 21 | return train,val,test 22 | 23 | 24 | def printParams(model_params, logger=None): 25 | print("=========== Parameters ==========") 26 | for k,v in model_params.items(): 27 | print(f'{k} : {v}') 28 | print("=================================") 29 | print() 30 | if logger: 31 | for k,v in model_params.items(): 32 | logger.info(f'{k} : {v}') 33 | 34 | def applyIndexOnList(lis,idx): 35 | ans = [] 36 | for _ in idx: 37 | ans.append(lis[_]) 38 | return ans 39 | 40 | def set_seed(seed): 41 | torch.manual_seed(seed) # set seed for cpu 42 | torch.cuda.manual_seed(seed) # set seed for gpu 43 | torch.backends.cudnn.deterministic = True # cudnn 44 | torch.backends.cudnn.benchmark = False 45 | np.random.seed(seed) # numpy 46 | 47 | def get_logger(save_dir): 48 | logger = logging.getLogger(__name__) 49 | logger.setLevel(level = logging.INFO) 50 | handler = logging.FileHandler(save_dir + "/log.txt") 51 | handler.setLevel(logging.INFO) 52 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 53 | handler.setFormatter(formatter) 54 | logger.addHandler(handler) 55 | return logger 56 | 57 | class CheckpointHandler(object): 58 | def __init__(self, save_dir, max_save=5): 59 | self.save_dir = save_dir 60 | self.max_save = max_save 61 | self.init_info() 62 | 63 | def init_info(self): 64 | os.makedirs(self.save_dir, exist_ok=True) 65 | self.metric_dic = {} 66 | if os.path.exists(self.save_dir+'/eval_log.txt'): 67 | with open(self.save_dir+'/eval_log.txt','r') as f: 68 | ls = f.readlines() 69 | for l in ls: 70 | l = l.strip().split(':') 71 | assert len(l) == 2 72 | self.metric_dic[l[0]] = float(l[1]) 73 | 74 | 75 | def save_model(self, model, model_params, epoch, eval_metric): 76 | max_in_dic = max(self.metric_dic.values()) if len(self.metric_dic) else 1e9 77 | if eval_metric > max_in_dic: 78 | return 79 | if len(self.metric_dic) == self.max_save: 80 | self.remove_last() 81 | self.metric_dic['model-'+str(epoch)+'.pt'] = eval_metric 82 | state = {"params":model_params, "epoch":epoch, "model":model.state_dict()} 83 | torch.save(state, self.save_dir + '/' + 'model-'+str(epoch)+'.pt') 84 | log_str = '\n'.join(['{}:{:.7f}'.format(k,v) for k,v in self.metric_dic.items()]) 85 | with open(self.save_dir+'/eval_log.txt','w') as f: 86 | f.write(log_str) 87 | 88 | 89 | def remove_last(self): 90 | last_model = sorted(list(self.metric_dic.keys()),key = lambda x:self.metric_dic[x])[-1] 91 | if os.path.exists(self.save_dir+'/'+last_model): 92 | os.remove(self.save_dir+'/'+last_model) 93 | self.metric_dic.pop(last_model) 94 | 95 | def checkpoint_best(self, use_cuda=True): 96 | best_model = sorted(list(self.metric_dic.keys()),key = lambda x:self.metric_dic[x])[0] 97 | if use_cuda: 98 | state = 
torch.load(self.save_dir + '/' + best_model) 99 | else: 100 | state = torch.load(self.save_dir + '/' + best_model,map_location='cpu') 101 | return state 102 | 103 | def checkpoint_avg(self, use_cuda=True): 104 | return_dic = None 105 | model_num = 0 106 | tmp_model_params = None 107 | for ckpt in os.listdir(self.save_dir): 108 | if not ckpt.endswith('.pt'): 109 | continue 110 | model_num += 1 111 | if use_cuda: 112 | state = torch.load(self.save_dir + '/' + ckpt) 113 | else: 114 | state = torch.load(self.save_dir + '/' + ckpt,map_location='cpu') 115 | model,tmp_model_params = state['model'], state['params'] 116 | if not return_dic: 117 | return_dic = model 118 | else: 119 | for k in return_dic: 120 | return_dic[k] += model[k] 121 | for k in return_dic: 122 | return_dic[k] = return_dic[k]/model_num 123 | return {'params':tmp_model_params, 'model':return_dic} --------------------------------------------------------------------------------
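Finally, a minimal sketch of how the utilities in `utils.py` compose: `splitdata`'s rotating train/val/test fold assignment and `CheckpointHandler`'s keep-best checkpointing. It assumes the repository root is on `PYTHONPATH`; the temporary save directory, the stand-in model, and the metric values are all illustrative.

```
# Illustrative use of utils.splitdata and utils.CheckpointHandler.
import tempfile
import torch
from utils import splitdata, CheckpointHandler

# Rotating folds: each fold index yields disjoint train/val/test partitions.
for fold_idx in (1, 2, 5):
    train, val, test = splitdata(100, 5, fold_idx)
    print(f'fold {fold_idx}: train={len(train)}, val={len(val)}, test={len(test)}')

model = torch.nn.Linear(4, 1)                       # stand-in for MOFNet
handler = CheckpointHandler(tempfile.mkdtemp(), max_save=3)
for epoch, val_mae in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
    # only checkpoints beating the worst currently-kept MAE are saved;
    # once max_save is reached, the worst checkpoint is evicted
    handler.save_model(model, {'d_atom': 4}, epoch, val_mae)

best_state = handler.checkpoint_best(use_cuda=False)
print('best epoch:', best_state['epoch'])           # epoch 4 (MAE 0.6)
```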