├── .gitignore
├── LICENSE
├── README.md
├── argparser.py
├── attn_vis.py
├── baselines
│   ├── README.md
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── data_utils.cpython-37.pyc
│   │   ├── data_utils.cpython-38.pyc
│   │   ├── dimenet_pp.cpython-37.pyc
│   │   ├── dimenet_pp.cpython-38.pyc
│   │   ├── egnn.cpython-37.pyc
│   │   ├── egnn.cpython-38.pyc
│   │   ├── gin.cpython-37.pyc
│   │   ├── painn.cpython-37.pyc
│   │   ├── painn.cpython-38.pyc
│   │   ├── schnet.cpython-37.pyc
│   │   └── schnet.cpython-38.pyc
│   ├── data_utils.py
│   ├── dimenet_pp.py
│   ├── egnn.py
│   ├── painn.py
│   ├── schnet.py
│   └── spk_utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-37.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── acsf.cpython-37.pyc
│       │   ├── acsf.cpython-38.pyc
│       │   ├── activations.cpython-37.pyc
│       │   ├── activations.cpython-38.pyc
│       │   ├── base.cpython-37.pyc
│       │   ├── base.cpython-38.pyc
│       │   ├── blocks.cpython-37.pyc
│       │   ├── blocks.cpython-38.pyc
│       │   ├── cfconv.cpython-37.pyc
│       │   ├── cfconv.cpython-38.pyc
│       │   ├── cutoff.cpython-37.pyc
│       │   ├── cutoff.cpython-38.pyc
│       │   ├── initializers.cpython-37.pyc
│       │   ├── initializers.cpython-38.pyc
│       │   ├── neighbors.cpython-37.pyc
│       │   └── neighbors.cpython-38.pyc
│       ├── acsf.py
│       ├── activations.py
│       ├── base.py
│       ├── blocks.py
│       ├── cfconv.py
│       ├── cutoff.py
│       ├── initializers.py
│       └── neighbors.py
├── featurization
│   ├── __pycache__
│   │   └── data_utils.cpython-37.pyc
│   └── data_utils.py
├── image
│   ├── 3dstructgen-mof.png
│   ├── Fig1.jpg
│   └── Fig1.png
├── model_shap.py
├── models
│   └── transformer.py
├── nist_test.py
├── pressure_adapt.py
├── process
│   ├── README
│   ├── create_geo_features.sh
│   ├── prepare_mof_features.py
│   ├── process_csd_data.py
│   ├── process_csd_data_baselines.py
│   ├── process_nist_data.py
│   └── tools
│       ├── get_atom_features.py
│       ├── get_bond_features.py
│       └── remove_waters.py
├── requirements.txt
├── train_baselines.py
├── train_ml.py
├── train_mofnet.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Matgen-project
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MOFNet
2 | MOFNet is a deep learning model that predicts adsorption isotherms for MOFs based on a hierarchical representation, a graph transformer and a pressure-adaptive mechanism. We elaborately design a hierarchical representation to describe the MOF structure. A graph transformer captures atomic-level information, which helps the model learn the chemical features required at low-pressure conditions. A pressure-adaptive mechanism interpolates and extrapolates the given limited data points via transfer learning, so a single model can predict adsorption isotherms over a wider pressure range. The following is the architecture of MOFNet.
3 |
4 | ![MOFNet architecture](image/Fig1.png)
5 |
6 | ## Installation
7 | Please see dependencies in requirements.txt
8 |
9 | ## Dataset
10 |
11 | We released the training and testing data on the [Matgen website](https://matgen.nscc-gz.cn/dataset.html), which can be obtained with the following commands.
12 | ```
13 | $ wget https://matgen.nscc-gz.cn/dataset/download/CSD-MOFDB_xx.tar.gz   # xx: release version
14 | $ wget https://matgen.nscc-gz.cn/dataset/download/NIST-ISODB_xx.tar.gz
15 | ```
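After downloading, unpack both archives into a `data/` directory (a minimal sketch; `xx` is the release version you fetched):

```
$ mkdir -p data
$ tar -xzf CSD-MOFDB_xx.tar.gz -C data/
$ tar -xzf NIST-ISODB_xx.tar.gz -C data/
```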
16 |
17 | You can construct the data directory from the downloaded data as follows.
18 |
19 | ```
20 | |-- data
21 | |   |-- CSD-MOFDB
22 | |   |-- NIST-ISODB
23 | ```
24 |
25 | ## CSD-MOFDB
26 | We collected 7306, 6998 and 8562 MOFs for N2, CO2 and CH4, respectively, from the Cambridge Structural Database (CSD, version 5.4).
27 | GCMC simulations were carried out with the RASPA software to calculate the adsorption data of these MOFs for N2, CO2 and CH4.
28 | We set 8 pressure points in the ranges 0.2 kPa - 80 kPa, 5 kPa - 20,000 kPa and 100 kPa - 10,000 kPa for N2, CO2 and CH4, respectively.
29 | ```
30 | |-- CSD-MOFDB
31 | |   |-- CIFs              # CIF format files
32 | |   |-- global_features
33 | |   |-- label_by_GCMC     # adsorption data calculated by the GCMC method
34 | |   |-- local_features
35 | |   |-- mol_unit          # molecular units in MOL format
36 | |   |-- README
37 | ```
38 |
39 | ## NIST-ISODB
40 | We obtained 54 MOFs with 1876 pressure data points covering N2, CO2 and CH4 adsorbate molecules from the NIST/ARPA-E database.
41 |
42 | ```
43 | |-- NIST-ISODB
44 | |   |-- CIFs              # CIF format files
45 | |   |-- global_features
46 | |   |-- isotherm_data     # experimental data
47 | |   |-- local_features
48 | |   |-- MOFNet            # MOFNet prediction results
49 | |   |-- mol_unit          # molecular units in MOL format
50 | |   |-- README
51 | ```
52 |
53 |
54 | ## Processing
55 |
56 | ### How to generate local features?
57 | First, the CSD software package needs to be installed on your server, and its Python API is used to obtain the CIF files. We provide scripts in the process/ directory; run the following command to generate the local-feature files.
58 | ```
59 | $ python process/process_csd_data.py
60 | ```
61 |
62 | ### How to obtain global features?
63 | The important structural properties, including the largest cavity diameter (LCD), pore-limiting diameter (PLD) and helium void fraction, were calculated using the open-source software Zeo++.
64 |
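For a single CIF file, these quantities can be computed with Zeo++'s `network` tool. A minimal sketch (the file names, probe radius and sample count are illustrative assumptions, not the exact settings used for the released features):

```
$ network -ha -res MOF.res MOF.cif                # largest included (LCD) and free (PLD) sphere diameters
$ network -ha -vol 1.2 1.2 50000 MOF.vol MOF.cif  # accessible volume, from which the helium void fraction follows
```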
65 |
66 | ## Model training
67 | ```
68 | $ python -u train_mofnet.py --data_dir data/CSD-MOFDB --gas_type <gas> --pressure <pressure> --save_dir <save_dir> --use_global_feature
69 | ```
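`--gas_type` takes the adsorbate name and `--pressure` a pressure-column label from the label CSV (see `add_data_args` in `argparser.py`). An illustrative run, assuming a CH4 column named `100000` exists in your release:

```
$ python -u train_mofnet.py --data_dir data/CSD-MOFDB --gas_type CH4 --pressure 100000 --save_dir ckpt/CH4 --use_global_feature
```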
70 |
71 | ## Transfer learning
72 | ```
73 | $ python -u pressure_adapt.py --data_dir data/CSD-MOFDB --gas_type <gas> --pressure <pressure> --save_dir <save_dir> --ori_dir <pretrained_model_dir> --adapter_dim 8
74 | ```
75 |
76 | ## Prediction
77 | ```
78 | $ python -u nist_test.py --data_dir data/NIST-ISODB --gas_type <gas> --pressure <min,max,points> --save_dir <save_dir> --img_dir <img_dir>
79 | ```
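Judging from `parse_predict_args` in `argparser.py`, `--pressure` here is a comma-separated triple `min,max,points` giving the pressure range to evaluate. Illustrative values:

```
$ python -u nist_test.py --data_dir data/NIST-ISODB --gas_type CO2 --pressure 1000,100000,8 --save_dir ckpt/CO2 --img_dir img/CO2
```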
80 |
81 | We also welcome users to use our [3DStructGen UI interface](https://matgen.nscc-gz.cn/3dstructgen/v2/mod/3dstructgen_newUI.html) to predict crystal properties via the following steps:
82 | ```
83 | # Upload your CIF crystal files into the 3DStructGen interface;
84 | # Click the "Calculate" button and open the "Artificial Intelligence - MOF" app;
85 | # Choose the uptake gas and pressure range you want to calculate, then submit.
86 | ```
87 |
88 | ![MOF property prediction in 3DStructGen](image/3dstructgen-mof.png)
89 |
90 | ## Acknowledgments
91 | The implementation of the Graph Transformer module is built upon [Molecule Attention Transformer](https://github.com/ardigen/MAT).
92 |
93 | ## Reference:
94 | [1]. Maziarka, Łukasz; Danel, Tomasz; Mucha, Sławomir; Rataj, Krzysztof; Tabor, Jacek; Jastrzębski, Stanisław: Molecule Attention Transformer. arXiv preprint arXiv:2002.08264 (2020).
95 |
96 | [2]. Pin Chen, Yu Wang, Hui Yan, Sen Gao, Zexin Xu, Yangzhong Li, Qing Mo, Junkang Huang, Jun Tao, GeChuanqi Pan, Jiahui Li & Yunfei Du. 3DStructGen: an interactive web-based 3D structure generation for non-periodic molecule and crystal. J Cheminform 12, 7 (2020). https://doi.org/10.1186/s13321-020-0411-2
97 |
--------------------------------------------------------------------------------
/argparser.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | 
4 | def parse_train_args():
5 |     parser = ArgumentParser()
6 |     add_data_args(parser)
7 |     add_train_args(parser)
8 |     args = parser.parse_args()
9 |     args = vars(args)
10 |     lambda_mat = [float(_) for _ in args['weight_split'].split(',')]  # attention / adjacency / distance weights
11 |     assert len(lambda_mat) == 3
12 |     lambda_sum = sum(lambda_mat)
13 |     args['lambda_attention'] = lambda_mat[0] / lambda_sum
14 |     args['lambda_distance'] = lambda_mat[-1] / lambda_sum  # the adjacency weight is the remainder
15 |     if args['d_mid_list'] == 'None':
16 |         args['d_mid_list'] = []
17 |     else:
18 |         args['d_mid_list'] = [int(_) for _ in args['d_mid_list'].split(',')]
19 |     makedirs(args['save_dir'] + f"/{args['gas_type']}_{args['pressure']}/")
20 |     return args
21 | 
22 | def parse_predict_args():
23 |     parser = ArgumentParser()
24 |     add_data_args(parser)
25 |     add_train_args(parser)
26 |     args = parser.parse_args()
27 |     args = vars(args)
28 |     lambda_mat = [float(_) for _ in args['weight_split'].split(',')]
29 |     assert len(lambda_mat) == 3
30 |     lambda_sum = sum(lambda_mat)
31 |     args['lambda_attention'] = lambda_mat[0] / lambda_sum
32 |     args['lambda_distance'] = lambda_mat[-1] / lambda_sum
33 |     if args['d_mid_list'] == 'None':
34 |         args['d_mid_list'] = []
35 |     else:
36 |         args['d_mid_list'] = [int(_) for _ in args['d_mid_list'].split(',')]
37 |     p_cond = args['pressure'].split(',')
38 |     assert len(p_cond) == 3
39 |     args['pressure'] = (float(p_cond[0]), float(p_cond[1]), int(p_cond[2]))  # (min, max, number of points)
40 |     return args
41 | 
42 | def parse_baseline_args():
43 |     parser = ArgumentParser()
44 |     add_data_args(parser)
45 |     add_baseline_args(parser)
46 |     args = parser.parse_args()
47 |     args = vars(args)
48 |     makedirs(args['save_dir'] + f"/{args['gas_type']}_{args['pressure']}/")
49 |     return args
50 | 
51 | def parse_finetune_args():
52 |     parser = ArgumentParser()
53 |     add_data_args(parser)
54 |     add_finetune_args(parser)
55 |     args = parser.parse_args()
56 |     args = vars(args)
57 |     makedirs(args['save_dir'] + f"/{args['gas_type']}_{args['pressure']}/")
58 |     return args
59 | 
60 | def parse_ml_args():
61 |     parser = ArgumentParser()
62 |     add_data_args(parser)
63 |     add_ml_args(parser)
64 |     args = parser.parse_args()
65 |     args = vars(args)
66 |     makedirs(args['save_dir'] + f"/{args['ml_type']}/{args['gas_type']}_{args['pressure']}/")
67 |     return args
68 | 
69 | def makedirs(path: str, isfile: bool = False):
70 |     if isfile:
71 |         path = os.path.dirname(path)
72 |     if path != '':
73 |         os.makedirs(path, exist_ok=True)
74 | 
75 | def add_ml_args(parser: ArgumentParser):
76 |     parser.add_argument('--ml_type', type=str, default='RF',
77 |                         help='ML algorithm, SVR/DT/RF.')
78 | 
79 |     parser.add_argument('--seed', type=int, default=9999,
80 |                         help='Random seed to use when splitting data into train/val/test sets. '
81 |                              'When `num_folds` > 1, the first fold uses this seed and all '
82 |                              'subsequent folds add 1 to the seed.')
83 |     parser.add_argument('--fold', type=int, default=10,
84 |                         help='Fold num.')
85 | 
86 | def add_data_args(parser: ArgumentParser):
87 |     parser.add_argument('--data_dir', type=str,
88 |                         help='Dataset directory, containing label/ and processed/ subdirectories.')
89 | 
90 |     parser.add_argument('--save_dir', type=str,
91 |                         help='Model directory.')
92 | 
93 | 
94 |     parser.add_argument('--gas_type', type=str,
95 |                         help='Gas type for prediction.')
96 | 
97 |     parser.add_argument('--pressure', type=str,
98 |                         help='Pressure condition for prediction.')
99 | 
100 |     parser.add_argument('--img_dir', type=str, default='',
101 |                         help='Directory for visualized isotherms.')
102 | 
103 | 
104 |     parser.add_argument('--name', type=str, default='',
105 |                         help='Target MOF name for attention visualization.')
106 | 
107 | def add_finetune_args(parser: ArgumentParser):
108 |     parser.add_argument('--ori_dir', type=str,
109 |                         help='Pretrained model directory, containing the models of the different folds.')
110 | 
111 |     parser.add_argument('--epoch', type=int, default=100,
112 |                         help='Epoch num.')
113 | 
114 |     parser.add_argument('--batch_size', type=int, default=32,
115 |                         help='Batch size.')
116 | 
117 |     parser.add_argument('--fold', type=int, default=10,
118 |                         help='Fold num.')
119 | 
120 |     parser.add_argument('--lr', type=float, default=0.0007,
121 |                         help='Learning rate.')
122 | 
123 |     parser.add_argument('--adapter_dim', type=int, default=8,
124 |                         help='Adapter vector dimension.')
125 | 
126 |     parser.add_argument('--seed', type=int, default=9999,
127 |                         help='Random seed to use when splitting data into train/val/test sets.')
128 | 
129 | def add_baseline_args(parser: ArgumentParser):
130 | 
131 |     parser.add_argument('--model_name', type=str, default='gin',
132 |                         help='Baseline model, gin/egnn/schnet/painn.')
133 | 
134 |     parser.add_argument('--gpu', type=int,
135 |                         help='GPU id to allocate.')
136 | 
137 |     parser.add_argument('--seed', type=int, default=9999,
138 |                         help='Random seed to use when splitting data into train/val/test sets.')
139 | 
140 |     parser.add_argument('--d_model', type=int, default=1024,
141 |                         help='Hidden size of the baseline model.')
142 | 
143 |     parser.add_argument('--N', type=int, default=2,
144 |                         help='Layer num of the baseline model.')
145 | 
146 |     parser.add_argument('--use_global_feature', action='store_true',
147 |                         help='Whether to use global (graph-level) features.')
148 | 
149 |     parser.add_argument('--warmup_step', type=int, default=2000,
150 |                         help='Warmup steps.')
151 | 
152 |     parser.add_argument('--epoch', type=int, default=100,
153 |                         help='Epoch num.')
154 | 
155 |     parser.add_argument('--batch_size', type=int, default=32,
156 |                         help='Batch size.')
157 | 
158 |     parser.add_argument('--fold', type=int, default=10,
159 |                         help='Fold num.')
160 | 
161 |     parser.add_argument('--lr', type=float, default=0.0007,
162 |                         help='Maximum learning rate, (warmup_step * d_model) ** -0.5 .')
163 | 
164 | def add_train_args(parser: ArgumentParser):
165 | 
166 |     parser.add_argument('--seed', type=int, default=9999,
167 |                         help='Random seed to use when splitting data into train/val/test sets.')
168 | 
169 |     parser.add_argument('--d_model', type=int, default=1024,
170 |                         help='Hidden size of the transformer model.')
171 | 
172 |     parser.add_argument('--N', type=int, default=2,
173 |                         help='Layer num of the transformer model.')
174 | 
175 |     parser.add_argument('--h', type=int, default=16,
176 |                         help='Attention head num of the transformer model.')
177 | 
178 |     parser.add_argument('--n_generator_layers', type=int, default=2,
179 |                         help='Layer num of the generator (MLP) model.')
180 | 
181 |     parser.add_argument('--weight_split', type=str, default='1,1,1',
182 |                         help='Unnormalized weights of the Self-Attention/Adjacency/Distance matrices, respectively, in the Graph Transformer.')
183 | 
184 |     parser.add_argument('--leaky_relu_slope', type=float, default=0.0,
185 |                         help='Leaky ReLU slope for activation functions.')
186 | 
187 |     parser.add_argument('--dense_output_nonlinearity', type=str, default='silu',
188 |                         help='Activation function for the prediction module, silu/relu/tanh/none.')
189 | 
190 |     parser.add_argument('--distance_matrix_kernel', type=str, default='bessel',
191 |                         help='Kernel applied to the distance matrix, bessel/softmax/exp. For example, exp maps nodes i, j at distance d to D(i,j) = exp(-d).')
192 | 
193 |     parser.add_argument('--dropout', type=float, default=0.1,
194 |                         help='Dropout ratio.')
195 | 
196 |     parser.add_argument('--aggregation_type', type=str, default='mean',
197 |                         help='Type for aggregating node features into a graph feature, mean/sum/dummy_node.')
198 | 
199 |     parser.add_argument('--use_global_feature', action='store_true',
200 |                         help='Whether to use global (graph-level) features.')
201 | 
202 |     parser.add_argument('--use_ffn_only', action='store_true',
203 |                         help='Use a DNN generator which only considers global features.')
204 | 
205 |     parser.add_argument('--d_mid_list', type=str, default='128,512',
206 |                         help='Projection layers to augment the global feature dim to the local feature dim.')
207 | 
208 |     parser.add_argument('--warmup_step', type=int, default=2000,
209 |                         help='Warmup steps.')
210 | 
211 |     parser.add_argument('--epoch', type=int, default=300,
212 |                         help='Epoch num.')
213 | 
214 |     parser.add_argument('--batch_size', type=int, default=64,
215 |                         help='Batch size.')
216 | 
217 |     parser.add_argument('--fold', type=int, default=10,
218 |                         help='Fold num.')
219 | 
220 |     parser.add_argument('--lr', type=float, default=0.0007,
221 |                         help='Maximum learning rate, (warmup_step * d_model) ** -0.5 .')
222 | 
--------------------------------------------------------------------------------
/attn_vis.py:
--------------------------------------------------------------------------------
1 | import shap
2 | import torch
3 | from collections import defaultdict
4 | from featurization.data_utils import load_data_from_df, construct_loader_gf_pressurever, construct_dataset_gf_pressurever, data_prefetcher
5 | from models.transformer import make_model
6 | import numpy as np
7 | import os
8 | from argparser import parse_train_args
9 | import pickle
10 | from tqdm import tqdm
11 | import matplotlib.pyplot as plt
12 | import seaborn as sns
13 | from utils import *
14 | 
15 | periodic_table = ('Dummy','H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
16 |                   'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br',
17 |                   'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te',
18 |                   'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm',
19 |                   'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Unk')
20 | 
21 | 
22 | model_params = parse_train_args()
23 | img_dir = os.path.join(model_params['img_dir'], 'attn')
24 | os.makedirs(img_dir, exist_ok=True)
25 | 
26 | 
27 | def heatmap(atoms, attn, name):
28 |     plt.cla()
29 |     f, ax = plt.subplots(figsize=(20, 15))
30 |     colormap = 'Reds'
31 |     h = sns.heatmap(attn, vmax=attn.max(), yticklabels=atoms, xticklabels=atoms, square=True, cmap=colormap, cbar=False)
32 |     fontsize = 15
33 |     cb = h.figure.colorbar(h.collections[0])
34 |     cb.ax.tick_params(labelsize=fontsize)
35 |     ax.tick_params(labelsize=fontsize, rotation=0)
36 |     ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
37 |     plt.savefig(os.path.join(img_dir, name + '.pdf'))
38 | 
39 | def test(model, data_loader, name_list):
40 |     model.eval()
41 |     batch_idx = -1
42 |     ans = {}
43 |     for data in tqdm(data_loader):
44 |         batch_idx += 1
45 |         adjacency_matrix, node_features, distance_matrix, global_features, y = (_.cpu() for _ in data)
46 |         batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
47 |         graph_rep = model.encode(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
48 |         attn = model.encoder.layers[0].self_attn.self_attn.detach().cpu().numpy()  # attention weights of the first encoder layer
49 |         atoms = node_features.numpy()[:, :, :83].argmax(axis=-1).reshape(-1)  # recover atomic numbers from the one-hot part of the features
50 |         attn = attn[0].mean(axis=0)  # average over attention heads
51 |         atoms = applyIndexOnList(periodic_table, atoms)
52 |         ans[name_list[batch_idx]] = {
53 |             'atoms': atoms,
54 |             'attn': attn
55 |         }
56 |         heatmap(atoms, attn, name_list[batch_idx])
57 |     return ans
58 | 
59 | if __name__ == '__main__':
60 |     batch_size = 1
61 |     device_ids = [0, 1, 2, 3]
62 |     X, f, y, p = load_data_from_df(model_params['data_dir'], gas_type=model_params['gas_type'], pressure="all", add_dummy_node=True, use_global_features=True, return_names=True)
63 |     # X: graph data, f: global features, y: labels, p: available pressure points
64 |     tar_idx = np.where(p == model_params['pressure'])[0][0]
65 |     y = np.array(y)
66 |     mean = y[..., tar_idx].mean()
67 |     std = y[..., tar_idx].std()
68 |     f = np.array(f)
69 |     fmean = f.mean(axis=0)
70 |     fstd = f.std(axis=0)
71 |     test_errors_all = []
72 |     f = (f - fmean) / fstd  # standardize the global features
73 |     X, names = X
74 | 
75 |     print(f'Loaded {len(X)} data.')
76 | 
77 |     fold_idx = 1
78 |     save_dir = model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}"
79 |     ckpt_handler = CheckpointHandler(save_dir)
80 |     state = ckpt_handler.checkpoint_best(use_cuda=False)
81 |     model = make_model(**state['params'])
82 |     model = torch.nn.DataParallel(model)
83 |     model.load_state_dict(state['model'])
84 |     model = model.module
85 |     if model_params['name'] == '':
86 |         sample_idx = np.arange(1000)
87 |         tar_name = 'all'
88 |     else:
89 |         if model_params['name'] in names:
90 |             sample_idx = [names.index(model_params['name'])]
91 |             tar_name = model_params['name']
92 |         else:
93 |             sample_idx = [0]
94 |             tar_name = 'random'
95 |     train_sample = construct_dataset_gf_pressurever(applyIndexOnList(X, sample_idx), f[sample_idx], y[sample_idx], p, is_train=False, tar_point=model_params['pressure'], mask_point=model_params['pressure'])
96 |     sample_loader = construct_loader_gf_pressurever(train_sample, 1, shuffle=False)
97 |     ans = test(model, sample_loader, applyIndexOnList(names, sample_idx))
98 | 
99 |     with open(os.path.join(img_dir, f"attn_{tar_name}.p"), 'wb') as f:
100 |         pickle.dump(ans, f)
101 | 
--------------------------------------------------------------------------------
/baselines/README.md:
--------------------------------------------------------------------------------
1 | ### Baselines
2 |
3 | Four adapted baselines:
4 | 
5 | - SchNet https://arxiv.org/abs/1706.08566
6 | - DimeNet++ https://arxiv.org/abs/2011.14115
7 | - EGNN https://arxiv.org/abs/2102.09844
8 | - PaiNN https://arxiv.org/abs/2102.03150
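9 | 
10 | These are trained with `train_baselines.py` using the flags in `add_baseline_args` (see `argparser.py`). An illustrative invocation with placeholder values:
11 | 
12 | ```
13 | $ python -u train_baselines.py --data_dir data/CSD-MOFDB --gas_type <gas> --pressure <pressure> --save_dir <save_dir> --model_name schnet --gpu 0
14 | ```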
--------------------------------------------------------------------------------
/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .egnn import *
4 | from .painn import *
5 | from .schnet import *
6 | from .dimenet_pp import *
7 | from torch import nn
8 | from torch.nn import functional as F
9 | 
10 | def make_baseline_model(d_atom, model_name, N=2, d_model=128, use_global_feature=False, d_feature=9, **kwargs):
11 |     model = None
12 |     if model_name == 'egnn':
13 |         representation = EGNN(in_node_nf=d_atom, hidden_nf=d_model, n_layers=N, attention=True)
14 |         use_adj = True
15 |     elif model_name == 'dimenetpp':
16 |         representation = DimeNetPlusPlus(hidden_channels=d_model, out_channels=d_model, num_input=d_atom, num_blocks=N, int_emb_size=d_model // 2, basis_emb_size=8, out_emb_channels=d_model * 2, num_spherical=7, num_radial=6)
17 |         use_adj = True
18 |     elif model_name == 'schnet':
19 |         representation = SchNet(n_atom_basis=d_model, n_filters=d_model, n_interactions=N, max_z=d_atom)
20 |         use_adj = False
21 |     elif model_name == 'painn':
22 |         representation = PaiNN(n_atom_basis=d_model, n_interactions=N, max_z=d_atom)
23 |         use_adj = False
24 |     if use_global_feature:
25 |         out = Generator_with_gf(d_model=d_model, d_gf=d_feature)
26 |     else:
27 |         out = Generator(d_model=d_model)
28 |     model = BaselineModel(representation=representation, output=out, use_adj=use_adj)
29 |     return model
30 | 
31 | class Generator(nn.Module):
32 |     def __init__(self, d_model):
33 |         super(Generator, self).__init__()
34 |         self.hidden_nf = d_model
35 |         self.node_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf),
36 |                                       nn.SiLU(),
37 |                                       nn.Linear(self.hidden_nf, self.hidden_nf))
38 | 
39 |         self.graph_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf),
40 |                                        nn.SiLU(),
41 |                                        nn.Linear(self.hidden_nf, 1))
42 | 
43 |     def forward(self, h, atom_mask, global_feature=None):
44 |         h = self.node_dec(h)
45 |         h = h * atom_mask.unsqueeze(-1)  # zero out padded atoms
46 |         h = torch.sum(h, dim=1)  # sum-pool node embeddings into a graph embedding
47 |         pred = self.graph_dec(h)
48 |         return pred.squeeze(1)
49 | 
50 | class Generator_with_gf(nn.Module):
51 |     def __init__(self, d_model, d_gf):
52 |         super(Generator_with_gf, self).__init__()
53 |         self.hidden_nf = d_model
54 |         self.input_nf = d_gf
55 |         self.node_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf),
56 |                                       nn.SiLU(),
57 |                                       nn.Linear(self.hidden_nf, self.hidden_nf))
58 | 
59 |         self.gf_enc = nn.Sequential(nn.Linear(self.input_nf, self.hidden_nf // 2),
60 |                                     nn.SiLU(),
61 |                                     nn.Linear(self.hidden_nf // 2, self.hidden_nf))
62 | 
63 |         self.graph_dec = nn.Sequential(nn.Linear(self.hidden_nf * 2, self.hidden_nf),
64 |                                        nn.SiLU(),
65 |                                        nn.Linear(self.hidden_nf, 1))
66 | 
67 |     def forward(self, h, atom_mask, global_feature):
68 |         h = self.node_dec(h)
69 |         h = h * atom_mask.unsqueeze(-1)
70 |         h = torch.sum(h, dim=1)
71 |         g = self.gf_enc(global_feature)
72 |         h = torch.cat([h, g], dim=1)  # concatenate graph and global-feature embeddings
73 |         pred = self.graph_dec(h)
74 |         return pred.squeeze(1)
75 | 
76 | class BaselineModel(nn.Module):
77 |     def __init__(self, representation, output, use_adj=True):
78 |         super(BaselineModel, self).__init__()
79 |         self.representation = representation
80 |         self.output = output
81 |         self.use_adj = use_adj
82 |     def forward(self, node_features, batch_mask, pos, adj, global_feature=None):
83 |         if not self.use_adj:
84 |             neighbors, neighbor_mask = adj
85 |             rep = self.representation(node_features, pos, neighbors, neighbor_mask, batch_mask)
86 |         else:
87 |             rep = self.representation(node_features, batch_mask, pos, adj)
88 |         out = self.output(rep, batch_mask, global_feature)
89 |         return out
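90 | 
91 | # Usage sketch (the dimensions here are assumptions for illustration; d_atom and
92 | # d_feature must match the processed node-feature and global-feature sizes):
93 | #   model = make_baseline_model(d_atom=84, model_name='schnet', N=2, d_model=128,
94 | #                               use_global_feature=True, d_feature=9)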
--------------------------------------------------------------------------------
/baselines/data_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | 
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from sklearn.metrics import pairwise_distances
9 | from torch.utils.data import Dataset
10 | from scipy.sparse import coo_matrix
11 | import json
12 | import copy
13 | 
14 | 
15 | FloatTensor = torch.FloatTensor
16 | LongTensor = torch.LongTensor
17 | IntTensor = torch.IntTensor
18 | DoubleTensor = torch.DoubleTensor
19 | 
20 | def load_data_from_df(dataset_path, gas_type, pressure, use_global_features=False):
21 |     print(dataset_path + f'/label/{gas_type}/{gas_type}_ads_all.csv')
22 |     data_df = pd.read_csv(dataset_path + f'/label/{gas_type}/{gas_type}_ads_all.csv', header=0)
23 | 
24 |     data_x = data_df['name'].values
25 |     if pressure == 'all':
26 |         data_y = data_df.iloc[:, 1:].values
27 |     else:
28 |         data_y = data_df[pressure].values
29 | 
30 |     if data_y.dtype == np.float64:
31 |         data_y = data_y.astype(np.float32)
32 | 
33 |     x_all, y_all, name_all = load_data_from_processed(dataset_path, data_x, data_y)
34 | 
35 |     if use_global_features:
36 |         f_all = load_data_with_global_features(dataset_path, name_all, gas_type)
37 |         if pressure == 'all':
38 |             return x_all, f_all, y_all, data_df.columns.values[1:]
39 |         return x_all, f_all, y_all
40 | 
41 |     if pressure == 'all':
42 |         return x_all, y_all, data_df.columns.values[1:]
43 |     return x_all, y_all
44 | 
45 | def load_data_with_global_features(dataset_path, processed_files, gas_type):
46 |     global_feature_path = dataset_path + f'/label/{gas_type}/{gas_type}_global_features_update.csv'
47 |     data_df = pd.read_csv(global_feature_path, header=0)
48 |     data_x = data_df.iloc[:, 0].values
49 |     data_f = data_df.iloc[:, 1:].values.astype(np.float32)
50 |     data_dict = {}
51 |     for i in range(data_x.shape[0]):
52 |         data_dict[data_x[i]] = data_f[i]
53 |     f_all = [data_dict[_] for _ in processed_files]
54 |     return f_all
55 | 
56 | 
57 | 
58 | def load_data_from_processed(dataset_path, processed_files, labels):
59 |     x_all, y_all, name_all = [], [], []
60 | 
61 |     for files, label in zip(processed_files, labels):
62 | 
63 |         data_file = dataset_path + '/processed_en/' + files + '.p'
64 |         try:
65 |             afm, row, col, pos = pickle.load(open(data_file, "rb"))
66 |             x_all.append([afm, row, col, pos])
67 |             y_all.append([label])
68 |             name_all.append(files)
69 |         except:  # skip MOFs without a processed feature file
70 |             pass
71 | 
72 |     return x_all, y_all, name_all
73 | 
74 | class MOF:
75 | 
76 |     def __init__(self, x, y, index, feature=None):
77 |         self.node_features = x[0]
78 |         self.edges = np.array([x[1], x[2]])
79 |         self.pos = x[3]
80 |         self.y = y
81 |         self.index = index
82 |         self.global_feature = feature
83 |         self.size = x[0].shape[0]
84 |         self.adj, self.nbh, self.nbh_mask = self.neighbor_matrix()
85 | 
86 |     def neighbor_matrix(self):
87 |         csr = coo_matrix((np.ones_like(self.edges[0]), self.edges), shape=(self.size, self.size)).tocsr()
88 |         rowptr, col = csr.indptr, csr.indices
89 |         degree = rowptr[1:] - rowptr[:-1]  # neighbor count per atom
90 |         max_d = degree.max()
91 |         _range = np.tile(np.arange(max_d), (self.size, 1)).reshape(-1)
92 |         _degree = degree.repeat(max_d).reshape(-1)
93 |         mask = _range < _degree  # marks valid (non-padding) neighbor slots
94 |         ret_nbh = np.zeros(self.size * max_d)
95 |         ret_nbh[mask] = col
96 |         return csr.toarray(), ret_nbh.reshape(self.size, max_d), mask.reshape(self.size, max_d)
97 | 
98 | 
99 | class MOFDataset(Dataset):
100 | 
101 |     def __init__(self, data_list):
102 | 
103 |         self.data_list = data_list
104 | 
105 |     def __len__(self):
106 |         return len(self.data_list)
107 | 
108 |     def __getitem__(self, key):
109 |         if type(key) == slice:
110 |             return MOFDataset(self.data_list[key])
111 |         return self.data_list[key]
112 | 
113 | def construct_dataset_gf(x_all, f_all, y_all):
114 |     output = [MOF(data[0], data[2], i, data[1])
115 |               for i, data in enumerate(zip(x_all, f_all, y_all))]
116 |     return MOFDataset(output)
117 | 
118 | def pad_array(array, shape, dtype=np.float32):
119 |     padded_array = np.zeros(shape, dtype=dtype)
120 |     padded_array[:array.shape[0], :array.shape[1]] = array
121 |     return padded_array
122 | 
123 | def mof_collate_func_adj(batch):
124 |     pos_list, features_list, global_features_list = [], [], []
125 |     adjs = []
126 |     labels = []
127 | 
128 |     max_size = 0
129 |     for molecule in batch:
130 |         if type(molecule.y[0]) == np.ndarray:
131 |             labels.append(molecule.y[0])
132 |         else:
133 |             labels.append(molecule.y)
134 |         if molecule.node_features.shape[0] > max_size:
135 |             max_size = molecule.node_features.shape[0]
136 | 
137 |     for molecule in batch:
138 |         pos_list.append(pad_array(molecule.pos, (max_size, 3)))
139 |         features_list.append(pad_array(molecule.node_features, (max_size, molecule.node_features.shape[1])))
140 |         adjs.append(pad_array(molecule.adj, (max_size, max_size)))
141 |         global_features_list.append(molecule.global_feature)
142 | 
143 |     return [FloatTensor(features_list), FloatTensor(pos_list), LongTensor(adjs), FloatTensor(global_features_list), FloatTensor(labels)]
144 | 
145 | def mof_collate_func_nbh(batch):
146 |     pos_list, features_list, global_features_list = [], [], []
147 |     nbhs, nbh_masks = [], []
148 |     labels = []
149 | 
150 |     max_size = 0
151 |     max_degree = 0
152 |     for molecule in batch:
153 |         if type(molecule.y[0]) == np.ndarray:
154 |             labels.append(molecule.y[0])
155 |         else:
156 |             labels.append(molecule.y)
157 |         if molecule.node_features.shape[0] > max_size:
158 |             max_size = molecule.node_features.shape[0]
159 |         if molecule.nbh.shape[1] > max_degree:
160 |             max_degree = molecule.nbh.shape[1]
161 | 
162 |     for molecule in batch:
163 |         pos_list.append(pad_array(molecule.pos, (max_size, 3)))
164 |         features_list.append(pad_array(molecule.node_features, (max_size, molecule.node_features.shape[1])))
165 |         nbhs.append(pad_array(molecule.nbh, (max_size, max_degree)))
166 |         nbh_masks.append(pad_array(molecule.nbh_mask, (max_size, max_degree)))
167 |         global_features_list.append(molecule.global_feature)
168 | 
169 |     return [FloatTensor(features_list), FloatTensor(pos_list), LongTensor(nbhs), FloatTensor(nbh_masks), FloatTensor(global_features_list), FloatTensor(labels)]
170 | 
171 | def construct_loader(x, f, y, batch_size, shuffle=True, use_adj=True):
172 |     data_set = construct_dataset_gf(x, f, y)
173 |     loader = torch.utils.data.DataLoader(dataset=data_set,
174 |                                          batch_size=batch_size,
175 |                                          num_workers=8,
176 |                                          collate_fn=mof_collate_func_adj if use_adj else mof_collate_func_nbh,
177 |                                          pin_memory=True,
178 |                                          shuffle=shuffle)
179 |     return loader
180 | 
181 | class data_prefetcher():
182 |     def __init__(self, loader, device):
183 |         self.loader = iter(loader)
184 |         self.stream = torch.cuda.Stream(device)
185 |         self.preload()
186 | 
187 |     def preload(self):
188 |         try:
189 |             self.next_data = next(self.loader)
190 |         except StopIteration:
191 |             self.next_data = None
192 |             return
193 |         with torch.cuda.stream(self.stream):
194 |             self.next_data = tuple(_.cuda(non_blocking=True) for _ in self.next_data)
195 | 
196 |     def next(self):
197 |         torch.cuda.current_stream().wait_stream(self.stream)
198 |         batch = self.next_data
199 |         self.preload()
200 |         return batch
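201 | 
202 | # Usage sketch for data_prefetcher (assumed device id); iterate until it returns None:
203 | #   prefetcher = data_prefetcher(loader, torch.device('cuda:0'))
204 | #   batch = prefetcher.next()
205 | #   while batch is not None:
206 | #       ...  # one training step on batch
207 | #       batch = prefetcher.next()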
--------------------------------------------------------------------------------
/baselines/dimenet_pp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_geometric.nn import radius_graph
4 | from torch_geometric.nn.acts import swish
5 | from torch_geometric.nn.inits import glorot_orthogonal
6 | from torch_geometric.nn.models.dimenet import (
7 | BesselBasisLayer,
8 | Envelope,
9 | ResidualLayer,
10 | SphericalBasisLayer,
11 | )
12 | from torch_scatter import scatter
13 | from torch_sparse import SparseTensor
14 | import sympy as sym
15 |
16 | def dense_to_sparse(adj):
17 | r"""Converts a dense adjacency matrix to a sparse adjacency matrix defined
18 | by edge indices and edge attributes.
19 |
20 | Args:
21 | adj (Tensor): The dense adjacency matrix.
22 | :rtype: (:class:`LongTensor`, :class:`Tensor`)
23 | """
24 | assert adj.dim() >= 2 and adj.dim() <= 3
25 | assert adj.size(-1) == adj.size(-2)
26 |
27 | index = adj.nonzero(as_tuple=True)
28 | edge_attr = adj[index]
29 |
30 | if len(index) == 3:
31 | batch = index[0] * adj.size(-1)
32 | index = (batch + index[1], batch + index[2])
33 |
34 | return torch.stack(index, dim=0), edge_attr
35 |
36 | class MLP(torch.nn.Module):
37 | def __init__(self, input_size, output_size, hidden_sizes, activation_hidden, activation_out, biases, dropout):
38 | super(MLP, self).__init__()
39 | self.activation_hidden = activation_hidden
40 | self.activation_out = activation_out
41 | self.dropout = dropout
42 |
43 | if len(hidden_sizes) > 0:
44 | self.linear_layers = torch.nn.ModuleList([torch.nn.Linear(input_size, hidden_sizes[0], bias = biases)])
45 | self.linear_layers.extend([torch.nn.Linear(in_size, out_size, bias = biases)
46 | for (in_size, out_size)
47 | in zip(hidden_sizes[0:-1], (hidden_sizes[1:]))])
48 | self.linear_layers.append(torch.nn.Linear(hidden_sizes[-1], output_size, bias = biases))
49 |
50 | else:
51 | self.linear_layers = torch.nn.ModuleList([torch.nn.Linear(input_size, output_size, bias = biases)])
52 |
53 | def forward(self, x):
54 | if len(self.linear_layers) == 1:
55 | out = self.activation_out(self.linear_layers[0](x))
56 |
57 | else:
58 | out = self.activation_hidden(self.linear_layers[0](x))
59 | for i, layer in enumerate(self.linear_layers[1:-1]):
60 | out = self.activation_hidden(layer(out))
61 | out = torch.nn.functional.dropout(out, p = self.dropout, training = self.training)
62 | out = self.activation_out(self.linear_layers[-1](out))
63 |
64 | return out
65 |
66 | class InteractionPPBlock(torch.nn.Module):
67 | def __init__(
68 | self,
69 | hidden_channels,
70 | int_emb_size,
71 | basis_emb_size,
72 | num_spherical,
73 | num_radial,
74 | num_before_skip,
75 | num_after_skip,
76 | act=swish,
77 | ):
78 | super(InteractionPPBlock, self).__init__()
79 | self.act = act
80 |
81 | # Transformations of Bessel and spherical basis representations.
82 | self.lin_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False)
83 | self.lin_rbf2 = nn.Linear(basis_emb_size, hidden_channels, bias=False)
84 | self.lin_sbf1 = nn.Linear(
85 | num_spherical * num_radial, basis_emb_size, bias=False
86 | )
87 | self.lin_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False)
88 |
89 | # Dense transformations of input messages.
90 | self.lin_kj = nn.Linear(hidden_channels, hidden_channels)
91 | self.lin_ji = nn.Linear(hidden_channels, hidden_channels)
92 |
93 | # Embedding projections for interaction triplets.
94 | self.lin_down = nn.Linear(hidden_channels, int_emb_size, bias=False)
95 | self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False)
96 |
97 | # Residual layers before and after skip connection.
98 | self.layers_before_skip = torch.nn.ModuleList(
99 | [
100 | ResidualLayer(hidden_channels, act)
101 | for _ in range(num_before_skip)
102 | ]
103 | )
104 | self.lin = nn.Linear(hidden_channels, hidden_channels)
105 | self.layers_after_skip = torch.nn.ModuleList(
106 | [
107 | ResidualLayer(hidden_channels, act)
108 | for _ in range(num_after_skip)
109 | ]
110 | )
111 |
112 | #self.reset_parameters()
113 |
114 | def reset_parameters(self):
115 | glorot_orthogonal(self.lin_rbf1.weight, scale=2.0)
116 | glorot_orthogonal(self.lin_rbf2.weight, scale=2.0)
117 | glorot_orthogonal(self.lin_sbf1.weight, scale=2.0)
118 | glorot_orthogonal(self.lin_sbf2.weight, scale=2.0)
119 |
120 | glorot_orthogonal(self.lin_kj.weight, scale=2.0)
121 | self.lin_kj.bias.data.fill_(0)
122 | glorot_orthogonal(self.lin_ji.weight, scale=2.0)
123 | self.lin_ji.bias.data.fill_(0)
124 |
125 | glorot_orthogonal(self.lin_down.weight, scale=2.0)
126 | glorot_orthogonal(self.lin_up.weight, scale=2.0)
127 |
128 | for res_layer in self.layers_before_skip:
129 | res_layer.reset_parameters()
130 |
131 | glorot_orthogonal(self.lin.weight, scale=2.0)
132 | self.lin.bias.data.fill_(0)
133 |
134 | for res_layer in self.layers_after_skip:
135 | res_layer.reset_parameters()
136 |
137 | def forward(self, x, rbf, sbf, idx_kj, idx_ji):
138 | # Initial transformations.
139 | x_ji = self.act(self.lin_ji(x))
140 | x_kj = self.act(self.lin_kj(x))
141 |
142 | # Transformation via Bessel basis.
143 | rbf = self.lin_rbf1(rbf)
144 | rbf = self.lin_rbf2(rbf)
145 | x_kj = x_kj * rbf
146 |
147 | # Down-project embeddings and generate interaction triplet embeddings.
148 | x_kj = self.act(self.lin_down(x_kj))
149 |
150 | # Transform via 2D spherical basis.
151 | sbf = self.lin_sbf1(sbf)
152 | sbf = self.lin_sbf2(sbf)
153 | x_kj = x_kj[idx_kj] * sbf
154 |
155 | # Aggregate interactions and up-project embeddings.
156 | x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0))
157 | x_kj = self.act(self.lin_up(x_kj))
158 |
159 | h = x_ji + x_kj
160 | for layer in self.layers_before_skip:
161 | h = layer(h)
162 | h = self.act(self.lin(h)) + x
163 | for layer in self.layers_after_skip:
164 | h = layer(h)
165 |
166 | return h
167 |
168 |
169 | class OutputPPBlock(torch.nn.Module):
170 | def __init__(
171 | self,
172 | num_radial,
173 | hidden_channels,
174 | out_emb_channels,
175 | out_channels,
176 | num_layers,
177 | act=swish,
178 | ):
179 | super(OutputPPBlock, self).__init__()
180 | self.act = act
181 |
182 | self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False)
183 | self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True)
184 | self.lins = torch.nn.ModuleList()
185 | for _ in range(num_layers):
186 | self.lins.append(nn.Linear(out_emb_channels, out_emb_channels))
187 | self.lin = nn.Linear(out_emb_channels, out_channels, bias=False)
188 |
189 | #self.reset_parameters()
190 |
191 | def reset_parameters(self):
192 | glorot_orthogonal(self.lin_rbf.weight, scale=2.0)
193 | glorot_orthogonal(self.lin_up.weight, scale=2.0)
194 | for lin in self.lins:
195 | glorot_orthogonal(lin.weight, scale=2.0)
196 | lin.bias.data.fill_(0)
197 | self.lin.weight.data.fill_(0)
198 |
199 | def forward(self, x, rbf, i, num_nodes=None):
200 | x = self.lin_rbf(rbf) * x
201 | x = scatter(x, i, dim=0, dim_size=num_nodes)
202 | x = self.lin_up(x)
203 | for lin in self.lins:
204 | x = self.act(lin(x))
205 | return self.lin(x)
206 |
207 | class EmbeddingBlock(torch.nn.Module):
208 | def __init__(self, num_input, num_radial, hidden_channels, act=swish):
209 | super().__init__()
210 | self.act = act
211 |
212 | self.emb = nn.Linear(num_input, hidden_channels)
213 | self.lin_rbf = nn.Linear(num_radial, hidden_channels)
214 | self.lin = nn.Linear(3 * hidden_channels, hidden_channels)
215 |
216 | self.reset_parameters()
217 |
218 | def reset_parameters(self):
219 | # self.emb.weight.data.uniform_(-sqrt(3), sqrt(3))
220 | self.emb.reset_parameters()
221 | self.lin_rbf.reset_parameters()
222 | self.lin.reset_parameters()
223 |
224 | def forward(self, x, rbf, i, j):
225 | x = self.emb(x)
226 | rbf = self.act(self.lin_rbf(rbf))
227 | return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1)))
228 |
229 | class DimeNetPlusPlus(torch.nn.Module):
230 | r"""DimeNet++ implementation based on https://github.com/klicperajo/dimenet.
231 | Args:
232 | hidden_channels (int): Hidden embedding size.
233 | out_channels (int): Size of each output sample.
234 | num_blocks (int): Number of building blocks.
235 | int_emb_size (int): Embedding size used for interaction triplets
236 | basis_emb_size (int): Embedding size used in the basis transformation
237 | out_emb_channels(int): Embedding size used for atoms in the output block
238 | num_spherical (int): Number of spherical harmonics.
239 | num_radial (int): Number of radial basis functions.
240 | cutoff: (float, optional): Cutoff distance for interatomic
241 | interactions. (default: :obj:`5.0`)
242 | envelope_exponent (int, optional): Shape of the smooth cutoff.
243 | (default: :obj:`5`)
244 | num_before_skip: (int, optional): Number of residual layers in the
245 | interaction blocks before the skip connection. (default: :obj:`1`)
246 | num_after_skip: (int, optional): Number of residual layers in the
247 | interaction blocks after the skip connection. (default: :obj:`2`)
248 | num_output_layers: (int, optional): Number of linear layers for the
249 | output blocks. (default: :obj:`3`)
250 |         act: (function, optional): The activation function.
251 | (default: :obj:`swish`)
252 | """
253 |
254 | url = "https://github.com/klicperajo/dimenet/raw/master/pretrained"
255 |
256 | def __init__(
257 | self,
258 | hidden_channels,
259 | out_channels,
260 | num_blocks,
261 | int_emb_size,
262 | basis_emb_size,
263 | out_emb_channels,
264 | num_spherical,
265 | num_radial,
266 | num_input,
267 | cutoff=5.0,
268 | envelope_exponent=5,
269 | num_before_skip=1,
270 | num_after_skip=2,
271 | num_output_layers=3,
272 | act=swish,
273 | MLP_hidden_sizes = [],
274 | ):
275 | super(DimeNetPlusPlus, self).__init__()
276 |
277 | self.MLP_hidden_sizes = MLP_hidden_sizes
278 | self.hidden_channels = hidden_channels
279 |
280 | self.cutoff = cutoff
281 |
282 | if sym is None:
283 | raise ImportError("Package `sympy` could not be found.")
284 |
285 | self.num_blocks = num_blocks
286 |
287 | self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)
288 | self.sbf = SphericalBasisLayer(
289 | num_spherical, num_radial, cutoff, envelope_exponent
290 | )
291 |
292 | # self.emb = EmbeddingBlock(num_radial, hidden_channels, act)
293 | self.emb = EmbeddingBlock(num_input, num_radial, hidden_channels, act)
294 |
295 | self.output_blocks = torch.nn.ModuleList(
296 | [
297 | OutputPPBlock(
298 | num_radial,
299 | hidden_channels,
300 | out_emb_channels,
301 | out_channels,
302 | num_output_layers,
303 | act,
304 | )
305 | for _ in range(num_blocks + 1)
306 | ]
307 | )
308 |
309 | self.interaction_blocks = torch.nn.ModuleList(
310 | [
311 | InteractionPPBlock(
312 | hidden_channels,
313 | int_emb_size,
314 | basis_emb_size,
315 | num_spherical,
316 | num_radial,
317 | num_before_skip,
318 | num_after_skip,
319 | act,
320 | )
321 | for _ in range(num_blocks)
322 | ]
323 | )
324 |
325 | # if len(self.MLP_hidden_sizes) > 0:
326 | # self.Output_MLP = MLP(input_size = out_channels, output_size = 1, hidden_sizes = MLP_hidden_sizes, activation_hidden = torch.nn.LeakyReLU(negative_slope=0.01), activation_out = torch.nn.Identity(), biases = True, dropout = 0.0)
327 |
328 | self.reset_parameters()
329 |
330 | def reset_parameters(self):
331 | self.rbf.reset_parameters()
332 | self.emb.reset_parameters()
333 | #for out in self.output_blocks:
334 | # out.reset_parameters()
335 | for interaction in self.interaction_blocks:
336 | interaction.reset_parameters()
337 |
338 | def triplets(self, edge_index, num_nodes):
339 | row, col = edge_index # j->i
340 |
341 | value = torch.arange(row.size(0), device=row.device)
342 | adj_t = SparseTensor(
343 | row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes)
344 | )
345 | adj_t_row = adj_t[row]
346 | num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)
347 |
348 | # Node indices (k->j->i) for triplets.
349 | idx_i = col.repeat_interleave(num_triplets)
350 | idx_j = row.repeat_interleave(num_triplets)
351 | idx_k = adj_t_row.storage.col()
352 | mask = idx_i != idx_k # Remove i == k triplets.
353 | idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]
354 |
355 | # Edge indices (k-j, j->i) for triplets.
356 | idx_kj = adj_t_row.storage.value()[mask]
357 | idx_ji = adj_t_row.storage.row()[mask]
358 |
359 | return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji
360 |
361 | def forward(self, node_features, batch_mask, pos, adj):
362 | """"""
363 | # edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
364 | batch_size, n_nodes, in_node_nf = node_features.shape
365 | edge_index, _ = dense_to_sparse(adj)
366 |
367 | node_features = node_features.reshape(-1, in_node_nf)
368 | pos = pos.reshape(-1, 3)
369 | j, i = edge_index
370 | dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()
371 |
372 | _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
373 | edge_index, num_nodes=node_features.size(0)
374 | )
375 |
376 | # Calculate angles.
377 | pos_i = pos[idx_i].detach()
378 | pos_j = pos[idx_j].detach()
379 | pos_ji, pos_kj = (
380 | pos[idx_j].detach() - pos_i,
381 | pos[idx_k].detach() - pos_j,
382 | )
383 |
384 | a = (pos_ji * pos_kj).sum(dim=-1)
385 | b = torch.cross(pos_ji, pos_kj).norm(dim=-1)
386 | angle = torch.atan2(b, a)
387 |
388 | rbf = self.rbf(dist)
389 | sbf = self.sbf(dist, angle, idx_kj)
390 |
391 | # Embedding block.
392 | x = self.emb(node_features, rbf, i, j)
393 | P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))
394 |
395 | # Interaction blocks.
396 | for interaction_block, output_block in zip(
397 | self.interaction_blocks, self.output_blocks[1:]
398 | ):
399 | x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
400 | P += output_block(x, rbf, i, num_nodes=pos.size(0))
401 |
402 | # out = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)
403 |
404 | # #if we are using a MLP for downstream target prediction
405 | # if len(self.MLP_hidden_sizes) > 0:
406 | # target = self.Output_MLP(out)
407 | # return target, out
408 | return P.view(-1, n_nodes, self.hidden_channels)
409 |
410 | # return out
--------------------------------------------------------------------------------
/baselines/egnn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 |
4 | def dense_to_sparse(adj):
5 | r"""Converts a dense adjacency matrix to a sparse adjacency matrix defined
6 | by edge indices and edge attributes.
7 |
8 | Args:
9 | adj (Tensor): The dense adjacency matrix.
10 | :rtype: (:class:`LongTensor`, :class:`Tensor`)
11 | """
12 | assert adj.dim() >= 2 and adj.dim() <= 3
13 | assert adj.size(-1) == adj.size(-2)
14 |
15 | index = adj.nonzero(as_tuple=True)
16 | edge_attr = adj[index]
17 |
18 | if len(index) == 3:
19 | batch = index[0] * adj.size(-1)
20 | index = (batch + index[1], batch + index[2])
21 |
22 | return torch.stack(index, dim=0), edge_attr
23 |
24 | class E_GCL(nn.Module):
25 | """
26 | E(n) Equivariant Convolutional Layer
27 |
28 | """
29 |
30 | def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False):
31 | super(E_GCL, self).__init__()
32 | input_edge = input_nf * 2
33 | self.residual = residual
34 | self.attention = attention
35 | self.normalize = normalize
36 | self.coords_agg = coords_agg
37 | self.tanh = tanh
38 | self.epsilon = 1e-8
39 | edge_coords_nf = 1
40 |
41 | self.edge_mlp = nn.Sequential(
42 | nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
43 | act_fn,
44 | nn.Linear(hidden_nf, hidden_nf),
45 | act_fn)
46 | self.node_mlp = nn.Sequential(
47 | nn.Linear(hidden_nf + input_nf, hidden_nf),
48 | act_fn,
49 | nn.Linear(hidden_nf, output_nf))
50 |
51 | layer = nn.Linear(hidden_nf, 1, bias=False)
52 | torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
53 |
54 | coord_mlp = []
55 | coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
56 | coord_mlp.append(act_fn)
57 | coord_mlp.append(layer)
58 | if self.tanh:
59 | coord_mlp.append(nn.Tanh())
60 | self.coord_mlp = nn.Sequential(*coord_mlp)
61 |
62 | if self.attention:
63 | self.att_mlp = nn.Sequential(
64 | nn.Linear(hidden_nf, 1),
65 | nn.Sigmoid())
66 |
67 | def edge_model(self, source, target, radial, edge_attr):
68 | if edge_attr is None: # Unused.
69 | out = torch.cat([source, target, radial], dim=1)
70 | else:
71 | out = torch.cat([source, target, radial, edge_attr], dim=1)
72 | out = self.edge_mlp(out)
73 | if self.attention:
74 | att_val = self.att_mlp(out)
75 | out = out * att_val
76 | return out
77 |
78 | def node_model(self, x, edge_index, edge_attr, node_attr):
79 | row, col = edge_index
80 | agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
81 | if node_attr is not None:
82 | agg = torch.cat([x, agg, node_attr], dim=1)
83 | else:
84 | agg = torch.cat([x, agg], dim=1)
85 | out = self.node_mlp(agg)
86 | # if self.residual:
87 | # out = x + out
88 | return out, agg
89 |
90 | def coord_model(self, coord, edge_index, coord_diff, edge_feat):
91 | row, col = edge_index
92 | trans = coord_diff * self.coord_mlp(edge_feat)
93 | if self.coords_agg == 'sum':
94 | agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
95 | elif self.coords_agg == 'mean':
96 | agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
97 | else:
98 |             raise Exception('Wrong coords_agg parameter: %s' % self.coords_agg)
99 | coord = coord + agg
100 | return coord
101 |
102 | def coord2radial(self, edge_index, coord):
103 | row, col = edge_index
104 | coord_diff = coord[row] - coord[col]
105 | radial = torch.sum(coord_diff**2, 1).unsqueeze(1)
106 |
107 | if self.normalize:
108 | norm = torch.sqrt(radial).detach() + self.epsilon
109 | coord_diff = coord_diff / norm
110 |
111 | return radial, coord_diff
112 |
113 | def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None, edge_mask=None, update_coords=True):
114 | row, col = edge_index
115 | radial, coord_diff = self.coord2radial(edge_index, coord)
116 | h0 = h
117 | edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
118 | if edge_mask is not None:
119 | edge_feat = edge_feat * edge_mask
120 | if update_coords:
121 | coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
122 | h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
123 | if self.residual:
124 | h = h0 + h
125 | return h, coord, edge_attr
126 |
127 |
128 | def unsorted_segment_sum(data, segment_ids, num_segments):
129 | result_shape = (num_segments, data.size(1))
130 | result = data.new_full(result_shape, 0) # Init empty result tensor.
131 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
132 | result.scatter_add_(0, segment_ids, data)
133 | return result
134 |
135 |
136 | def unsorted_segment_mean(data, segment_ids, num_segments):
137 | result_shape = (num_segments, data.size(1))
138 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
139 | result = data.new_full(result_shape, 0) # Init empty result tensor.
140 | count = data.new_full(result_shape, 0)
141 | result.scatter_add_(0, segment_ids, data)
142 | count.scatter_add_(0, segment_ids, torch.ones_like(data))
143 | return result / count.clamp(min=1)
144 |
145 |
146 | def get_edges(n_nodes):
147 | rows, cols = [], []
148 | for i in range(n_nodes):
149 | for j in range(n_nodes):
150 | if i != j:
151 | rows.append(i)
152 | cols.append(j)
153 |
154 | edges = [rows, cols]
155 | return edges
156 |
157 |
158 | def get_edges_batch(n_nodes, batch_size):
159 | edges = get_edges(n_nodes)
160 | edge_attr = torch.ones(len(edges[0]) * batch_size, 1)
161 | edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])]
162 | if batch_size == 1:
163 | return edges, edge_attr
164 | elif batch_size > 1:
165 | rows, cols = [], []
166 | for i in range(batch_size):
167 | rows.append(edges[0] + n_nodes * i)
168 | cols.append(edges[1] + n_nodes * i)
169 | edges = [torch.cat(rows), torch.cat(cols)]
170 | return edges, edge_attr
171 |
172 | class EGNN(nn.Module):
173 | def __init__(self, in_node_nf, hidden_nf, in_edge_nf=0, act_fn=nn.SiLU(), n_layers=4, residual=True, attention=False, normalize=False, tanh=False):
174 | '''
175 |
176 | :param in_node_nf: Number of features for 'h' at the input
177 | :param hidden_nf: Number of hidden features
178 | :param out_node_nf: Number of features for 'h' at the output
179 | :param in_edge_nf: Number of features for the edge features
180 | :param device: Device (e.g. 'cpu', 'cuda:0',...)
181 | :param act_fn: Non-linearity
182 | :param n_layers: Number of layer for the EGNN
183 |         :param residual: Use residual connections; we recommend not changing this one
184 |         :param attention: Whether to use attention
185 |         :param normalize: Normalizes the coordinate messages such that:
186 |                     instead of: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)
187 |                     we get:     x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)/||x_i - x_j||
188 |                     We noticed it may help with stability or generalization in some settings.
189 |                     We didn't use it in our paper.
190 |         :param tanh: Sets a tanh activation at the output of phi_x(m_ij), i.e. it bounds the output of
191 |                        phi_x(m_ij), which improves stability but may decrease accuracy.
192 | We didn't use it in our paper.
193 | '''
194 |
195 | super(EGNN, self).__init__()
196 | self.hidden_nf = hidden_nf
197 | # self.device = device
198 | self.n_layers = n_layers
199 | self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
200 | # self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
201 | for i in range(0, n_layers):
202 | self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
203 | act_fn=act_fn, residual=residual, attention=attention,
204 | normalize=normalize, tanh=tanh))
205 | # self.node_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf),
206 | # act_fn,
207 | # nn.Linear(self.hidden_nf, self.hidden_nf))
208 |
209 | # self.graph_dec = nn.Sequential(nn.Linear(self.hidden_nf, self.hidden_nf),
210 | # act_fn,
211 | # nn.Linear(self.hidden_nf, 1))
212 | # self.to(self.device)
213 |
214 | def forward(self, node_features, batch_mask, pos, adj):
215 | batch_size, n_nodes, in_node_nf = node_features.shape
216 | edge_index, _ = dense_to_sparse(adj)
217 |
218 | h = node_features.reshape(-1, in_node_nf)
219 | x = pos.reshape(-1, 3)
220 | h = self.embedding_in(h)
221 | for i in range(0, self.n_layers):
222 | h, x, _ = self._modules["gcl_%d" % i](h, edge_index, x, edge_attr=None, edge_mask=None, update_coords=False)
223 | # h = self.node_dec(h)
224 | # h = h.view(-1, n_nodes, self.hidden_nf)
225 | # h = h * batch_mask.unsqueeze(-1)
226 | # h = torch.sum(h, dim=1)
227 | # pred = self.graph_dec(h)
228 | # return pred.squeeze(1)
229 | return h.view(-1, n_nodes, self.hidden_nf)
230 |
--------------------------------------------------------------------------------
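A minimal smoke test for the EGNN wrapper above. This is a sketch: the sizes are illustrative, and it assumes `dense_to_sparse` is imported from `torch_geometric.utils` earlier in this file (PyG offsets the edge indices of a batched dense adjacency, which matches the flattened `batch * n_nodes` node layout used in `forward`).

import torch

model = EGNN(in_node_nf=16, hidden_nf=32, n_layers=2)
batch_size, n_nodes = 2, 5
node_features = torch.randn(batch_size, n_nodes, 16)
pos = torch.randn(batch_size, n_nodes, 3)
batch_mask = torch.ones(batch_size, n_nodes)  # unused in forward, kept for the API
# fully connected adjacency without self-loops, one block per batch element
adj = torch.ones(batch_size, n_nodes, n_nodes) - torch.eye(n_nodes)
h = model(node_features, batch_mask, pos, adj)  # (batch_size, n_nodes, 32)
--------------------------------------------------------------------------------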
/baselines/painn.py:
--------------------------------------------------------------------------------
1 | import math
2 | from . import spk_utils as snn
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .spk_utils.neighbors import atom_distances
7 | from typing import Union, Callable
8 |
9 | class BesselBasis(nn.Module):
10 | """
11 |     Sine functions for radial basis expansion with Coulomb decay (zeroth-order Bessel functions, as in DimeNet).
12 | """
13 |
14 | def __init__(self, cutoff=5.0, n_rbf=None):
15 | """
16 | Args:
17 | cutoff: radial cutoff
18 | n_rbf: number of basis functions.
19 | """
20 | super(BesselBasis, self).__init__()
21 |         # compute the Bessel frequencies n * pi / cutoff
22 | freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
23 | self.register_buffer("freqs", freqs)
24 |
25 | def forward(self, inputs):
26 | a = self.freqs[None, None, None, :]
27 | ax = inputs * a
28 | sinax = torch.sin(ax)
29 |
30 | norm = torch.where(inputs == 0, torch.tensor(1.0, device=inputs.device), inputs)
31 | y = sinax / norm
32 |
33 | return y
34 |
35 | act_class_mapping = {
36 | "ssp": snn.ShiftedSoftplus,
37 | "silu": nn.SiLU,
38 | "tanh": nn.Tanh,
39 | "sigmoid": nn.Sigmoid,
40 | }
41 |
42 |
43 | class GatedEquivariantBlock(nn.Module):
44 | """Gated Equivariant Block as defined in Schütt et al. (2021):
45 | Equivariant message passing for the prediction of tensorial properties and molecular spectra
46 | """
47 |
48 | def __init__(
49 | self,
50 | hidden_channels,
51 | out_channels,
52 | intermediate_channels=None,
53 | activation="silu",
54 | scalar_activation=False,
55 | ):
56 | super(GatedEquivariantBlock, self).__init__()
57 | self.out_channels = out_channels
58 |
59 | if intermediate_channels is None:
60 | intermediate_channels = hidden_channels
61 |
62 | self.vec1_proj = nn.Linear(hidden_channels, hidden_channels)
63 | self.vec2_proj = nn.Linear(hidden_channels, out_channels)
64 |
65 | act_class = act_class_mapping[activation]
66 | self.update_net = nn.Sequential(
67 | nn.Linear(hidden_channels * 2, intermediate_channels),
68 | act_class(),
69 | nn.Linear(intermediate_channels, out_channels * 2),
70 | )
71 |
72 | self.act = act_class() if scalar_activation else None
73 |
74 | def reset_parameters(self):
75 | nn.init.xavier_uniform_(self.vec1_proj.weight)
76 | nn.init.xavier_uniform_(self.vec2_proj.weight)
77 | nn.init.xavier_uniform_(self.update_net[0].weight)
78 | self.update_net[0].bias.data.fill_(0)
79 | nn.init.xavier_uniform_(self.update_net[2].weight)
80 | self.update_net[2].bias.data.fill_(0)
81 |
82 | def forward(self, x, v):
83 | vec1 = torch.norm(self.vec1_proj(v), dim=-2)
84 | vec2 = self.vec2_proj(v)
85 |
86 | x = torch.cat([x, vec1], dim=-1)
87 | x, v = torch.split(self.update_net(x), self.out_channels, dim=-1)
88 | v = v.unsqueeze(2) * vec2
89 |
90 | if self.act is not None:
91 | x = self.act(x)
92 | return x, v
93 |
94 | class PaiNN(nn.Module):
95 | """ Polarizable atom interaction neural network """
96 | def __init__(
97 | self,
98 | n_atom_basis: int = 128,
99 | n_interactions: int = 3,
100 | n_rbf: int = 20,
101 | cutoff: float = 5.,
102 | cutoff_network: Union[nn.Module, str] = 'cosine',
103 | radial_basis: Callable = BesselBasis,
104 | activation=F.silu,
105 | max_z: int = 100,
106 | store_neighbors: bool = False,
107 | store_embeddings: bool = False,
108 | n_edge_features: int = 0,
109 | ):
110 | super(PaiNN, self).__init__()
111 |
112 | self.n_atom_basis = n_atom_basis
113 | self.n_interactions = n_interactions
114 | self.cutoff = cutoff
115 | self.cutoff_network = snn.get_cutoff_by_string(cutoff_network)(cutoff)
116 | self.radial_basis = radial_basis(cutoff=cutoff, n_rbf=n_rbf)
117 | self.embedding = nn.Linear(max_z, n_atom_basis)
118 |
119 | self.store_neighbors = store_neighbors
120 | self.store_embeddings = store_embeddings
121 | self.n_edge_features = n_edge_features
122 |
123 | # if self.n_edge_features:
124 | # self.edge_embedding = nn.Embedding(n_edge_features, self.n_interactions * 3 * n_atom_basis, padding_idx=0, max_norm=1.0)
125 |
126 | if type(activation) is str:
127 | if activation == 'swish':
128 | activation = F.silu
129 | elif activation == 'softplus':
130 | activation = snn.shifted_softplus
131 |
132 | self.filter_net = snn.Dense(
133 | n_rbf + n_edge_features, self.n_interactions * 3 * n_atom_basis, activation=None
134 | )
135 |
136 | self.interatomic_context_net = nn.ModuleList(
137 | [
138 | nn.Sequential(
139 | snn.Dense(n_atom_basis, n_atom_basis, activation=activation),
140 | snn.Dense(n_atom_basis, 3 * n_atom_basis, activation=None),
141 | )
142 | for _ in range(self.n_interactions)
143 | ]
144 | )
145 |
146 | self.intraatomic_context_net = nn.ModuleList(
147 | [
148 | nn.Sequential(
149 | snn.Dense(
150 | 2 * n_atom_basis, n_atom_basis, activation=activation
151 | ),
152 | snn.Dense(n_atom_basis, 3 * n_atom_basis, activation=None),
153 | )
154 | for _ in range(self.n_interactions)
155 | ]
156 | )
157 |
158 | self.mu_channel_mix = nn.ModuleList(
159 | [
160 | nn.Sequential(
161 | snn.Dense(n_atom_basis, 2 * n_atom_basis, activation=None, bias=False)
162 | )
163 | for _ in range(self.n_interactions)
164 | ]
165 | )
166 |
167 | # self.node_dec = nn.Sequential(snn.Dense(self.n_atom_basis, self.n_atom_basis, activation=F.silu),
168 | # snn.Dense(self.n_atom_basis, self.n_atom_basis))
169 |
170 | # self.graph_dec = nn.Sequential(snn.Dense(self.n_atom_basis, self.n_atom_basis, activation=F.silu),
171 | # snn.Dense(self.n_atom_basis, 1))
172 |
173 | def forward(self, node_features, positions, neighbors, neighbor_mask, atom_mask):
174 | cell = None
175 | cell_offset = None
176 | # get interatomic vectors and distances
177 | rij, dir_ij = atom_distances(
178 | positions=positions,
179 | neighbors=neighbors,
180 | neighbor_mask=neighbor_mask,
181 | cell=cell,
182 | cell_offsets=cell_offset,
183 | return_vecs=True,
184 | normalize_vecs=True,
185 | )
186 |
187 | phi_ij = self.radial_basis(rij[..., None])
188 |
189 | fcut = self.cutoff_network(rij) * neighbor_mask
190 | # fcut = neighbor_mask
191 | fcut = fcut.unsqueeze(-1)
192 |
193 | filters = self.filter_net(phi_ij)
194 |
195 | # if self.n_edge_features:
196 | # edge_types = inputs['edge_types']
197 | # filters = filters + self.edge_embedding(edge_types)
198 |
199 | filters = filters * fcut
200 | filters = torch.split(filters, 3 * self.n_atom_basis, dim=-1)
201 |
202 | # initialize scalar and vector embeddings
203 | scalars = self.embedding(node_features)
204 |
205 | sshape = scalars.shape
206 | vectors = torch.zeros((sshape[0], sshape[1], 3, sshape[2]), device=scalars.device)
207 |
208 | for i in range(self.n_interactions):
209 | # message function
210 | h_i = self.interatomic_context_net[i](scalars)
211 | h_j, vectors_j = self.collect_neighbors(h_i, vectors, neighbors)
212 |
213 | # neighborhood context
214 | h_i = filters[i] * h_j
215 |
216 | dscalars, dvR, dvv = torch.split(h_i, self.n_atom_basis, dim=-1)
217 | dvectors = torch.einsum("bijf,bijd->bidf", dvR, dir_ij) + torch.einsum(
218 | "bijf,bijdf->bidf", dvv, vectors_j
219 | )
220 | dscalars = torch.sum(dscalars, dim=2)
221 | scalars = scalars + dscalars
222 | vectors = vectors + dvectors
223 |
224 | # update function
225 | mu_mix = self.mu_channel_mix[i](vectors)
226 | vectors_V, vectors_U = torch.split(mu_mix, self.n_atom_basis, dim=-1)
227 | mu_Vn = torch.norm(vectors_V, dim=2)
228 |
229 | ctx = torch.cat([scalars, mu_Vn], dim=-1)
230 | h_i = self.intraatomic_context_net[i](ctx)
231 | ds, dv, dsv = torch.split(h_i, self.n_atom_basis, dim=-1)
232 | dv = dv.unsqueeze(2) * vectors_U
233 | dsv = dsv * torch.einsum("bidf,bidf->bif", vectors_V, vectors_U)
234 |
235 | # calculate atomwise updates
236 | scalars = scalars + ds + dsv
237 | vectors = vectors + dv
238 |
239 | # h = self.node_dec(scalars)
240 | # h = h * atom_mask.unsqueeze(-1)
241 | # h = torch.sum(h, dim=1)
242 | # pred = self.graph_dec(h)
243 | # return pred.squeeze(1)
244 | return scalars
245 |
246 | # for layer in self.output_network:
247 | # scalars, vectors = layer(scalars, vectors)
248 | # # include v in output to make sure all parameters have a gradient
249 | # pred = scalars + vectors.sum() * 0
250 | # pred = pred.squeeze(-1) * atom_mask
251 | # return torch.sum(pred, dim = -1)
252 | # # scalars = self.scalar_LN(scalars)
253 | # # vectors = self.vector_LN(vectors)
254 |
255 |
256 |
257 | def collect_neighbors(self, scalars, vectors, neighbors):
258 | nbh_size = neighbors.size()
259 | nbh = neighbors.view(-1, nbh_size[1] * nbh_size[2], 1)
260 |
261 | scalar_nbh = nbh.expand(-1, -1, scalars.size(2))
262 | scalars_j = torch.gather(scalars, 1, scalar_nbh)
263 | scalars_j = scalars_j.view(nbh_size[0], nbh_size[1], nbh_size[2], -1)
264 |
265 | vectors_nbh = nbh[..., None].expand(-1, -1, vectors.size(2), vectors.size(3))
266 | vectors_j = torch.gather(vectors, 1, vectors_nbh)
267 | vectors_j = vectors_j.view(nbh_size[0], nbh_size[1], nbh_size[2], 3, -1)
268 | return scalars_j, vectors_j
269 |
--------------------------------------------------------------------------------
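A toy forward pass through PaiNN, sketched under the SchNetPack batching convention the module expects: `neighbors[b, i]` lists the indices of atom i's neighbors, here simply all other atoms. The shapes and the all-carbon one-hot features are illustrative, not taken from the MOFNet pipeline.

import torch

model = PaiNN(n_atom_basis=64, n_interactions=2, n_rbf=20, cutoff=5.0)
n_batch, n_atoms, max_z = 2, 4, 100
node_features = torch.zeros(n_batch, n_atoms, max_z)
node_features[..., 6] = 1.0  # one-hot atomic numbers, e.g. all carbon
positions = torch.randn(n_batch, n_atoms, 3)
# dense neighbor list: row i holds every index j != i
idx = torch.arange(n_atoms).expand(n_atoms, n_atoms)
neighbors = idx[~torch.eye(n_atoms, dtype=torch.bool)].view(n_atoms, n_atoms - 1)
neighbors = neighbors.unsqueeze(0).repeat(n_batch, 1, 1)
neighbor_mask = torch.ones(n_batch, n_atoms, n_atoms - 1)
atom_mask = torch.ones(n_batch, n_atoms)
scalars = model(node_features, positions, neighbors, neighbor_mask, atom_mask)
print(scalars.shape)  # torch.Size([2, 4, 64]) atomwise representation
--------------------------------------------------------------------------------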
/baselines/schnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .spk_utils.base import Dense
5 | from .spk_utils.cfconv import CFConv
6 | from .spk_utils.cutoff import CosineCutoff
7 | from .spk_utils.acsf import GaussianSmearing
8 | from .spk_utils.neighbors import AtomDistances
9 | from .spk_utils.activations import shifted_softplus
10 |
11 |
12 | __all__ = ["SchNetInteraction", "SchNet"]
13 |
14 |
15 | class SchNetInteraction(nn.Module):
16 | r"""SchNet interaction block for modeling interactions of atomistic systems.
17 |
18 | Args:
19 | n_atom_basis (int): number of features to describe atomic environments.
20 | n_spatial_basis (int): number of input features of filter-generating networks.
21 | n_filters (int): number of filters used in continuous-filter convolution.
22 | cutoff (float): cutoff radius.
23 | cutoff_network (nn.Module, optional): cutoff layer.
24 | normalize_filter (bool, optional): if True, divide aggregated filter by number
25 | of neighbors over which convolution is applied.
26 |
27 | """
28 |
29 | def __init__(
30 | self,
31 | n_atom_basis,
32 | n_spatial_basis,
33 | n_filters,
34 | cutoff,
35 | cutoff_network=CosineCutoff,
36 | normalize_filter=False,
37 | ):
38 | super(SchNetInteraction, self).__init__()
39 | # filter block used in interaction block
40 | self.filter_network = nn.Sequential(
41 | Dense(n_spatial_basis, n_filters, activation=shifted_softplus),
42 | Dense(n_filters, n_filters),
43 | )
44 | # cutoff layer used in interaction block
45 | self.cutoff_network = cutoff_network(cutoff)
46 | # interaction block
47 | self.cfconv = CFConv(
48 | n_atom_basis,
49 | n_filters,
50 | n_atom_basis,
51 | self.filter_network,
52 | cutoff_network=self.cutoff_network,
53 | activation=shifted_softplus,
54 | normalize_filter=normalize_filter,
55 | )
56 | # dense layer
57 | self.dense = Dense(n_atom_basis, n_atom_basis, bias=True, activation=None)
58 |
59 | def forward(self, x, r_ij, neighbors, neighbor_mask, f_ij=None):
60 | """Compute interaction output.
61 |
62 | Args:
63 | x (torch.Tensor): input representation/embedding of atomic environments
64 | with (N_b, N_a, n_atom_basis) shape.
65 | r_ij (torch.Tensor): interatomic distances of (N_b, N_a, N_nbh) shape.
66 | neighbors (torch.Tensor): indices of neighbors of (N_b, N_a, N_nbh) shape.
67 | neighbor_mask (torch.Tensor): mask to filter out non-existing neighbors
68 | introduced via padding.
69 | f_ij (torch.Tensor, optional): expanded interatomic distances in a basis.
70 | If None, r_ij.unsqueeze(-1) is used.
71 |
72 | Returns:
73 | torch.Tensor: block output with (N_b, N_a, n_atom_basis) shape.
74 |
75 | """
76 | # continuous-filter convolution interaction block followed by Dense layer
77 | v = self.cfconv(x, r_ij, neighbors, neighbor_mask, f_ij)
78 | v = self.dense(v)
79 | return v
80 |
81 |
82 | class SchNet(nn.Module):
83 | """SchNet architecture for learning representations of atomistic systems.
84 |
85 | Args:
86 | n_atom_basis (int, optional): number of features to describe atomic environments.
87 | This determines the size of each embedding vector; i.e. embeddings_dim.
88 | n_filters (int, optional): number of filters used in continuous-filter convolution
89 | n_interactions (int, optional): number of interaction blocks.
90 | cutoff (float, optional): cutoff radius.
91 | n_gaussians (int, optional): number of Gaussian functions used to expand
92 | atomic distances.
93 | normalize_filter (bool, optional): if True, divide aggregated filter by number
94 | of neighbors over which convolution is applied.
95 | coupled_interactions (bool, optional): if True, share the weights across
96 | interaction blocks and filter-generating networks.
97 | return_intermediate (bool, optional): if True, `forward` method also returns
98 | intermediate atomic representations after each interaction block is applied.
99 | max_z (int, optional): maximum nuclear charge allowed in database. This
100 | determines the size of the dictionary of embedding; i.e. num_embeddings.
101 | cutoff_network (nn.Module, optional): cutoff layer.
102 | trainable_gaussians (bool, optional): If True, widths and offset of Gaussian
103 | functions are adjusted during training process.
104 | distance_expansion (nn.Module, optional): layer for expanding interatomic
105 | distances in a basis.
106 | charged_systems (bool, optional):
107 |
108 | References:
109 | .. [#schnet1] Schütt, Arbabzadah, Chmiela, Müller, Tkatchenko:
110 | Quantum-chemical insights from deep tensor neural networks.
111 | Nature Communications, 8, 13890. 2017.
112 | .. [#schnet_transfer] Schütt, Kindermans, Sauceda, Chmiela, Tkatchenko, Müller:
113 | SchNet: A continuous-filter convolutional neural network for modeling quantum
114 | interactions.
115 | In Advances in Neural Information Processing Systems, pp. 992-1002. 2017.
116 | .. [#schnet3] Schütt, Sauceda, Kindermans, Tkatchenko, Müller:
117 |             SchNet - a deep learning architecture for molecules and materials.
118 | The Journal of Chemical Physics 148 (24), 241722. 2018.
119 |
120 | """
121 |
122 | def __init__(
123 | self,
124 | n_atom_basis=128,
125 | n_filters=128,
126 | n_interactions=3,
127 | cutoff=5.0,
128 | n_gaussians=25,
129 | normalize_filter=False,
130 | coupled_interactions=False,
131 | return_intermediate=False,
132 | max_z=100,
133 | cutoff_network=CosineCutoff,
134 | trainable_gaussians=False,
135 | distance_expansion=None,
136 | charged_systems=False,
137 | ):
138 | super(SchNet, self).__init__()
139 |
140 | self.n_atom_basis = n_atom_basis
141 |         # linear layer mapping one-hot element features (up to atomic number
142 |         # max_z) to embedding vectors of size n_atom_basis
143 |         self.embedding = nn.Linear(max_z, n_atom_basis)
144 |
145 | # layer for computing interatomic distances
146 | self.distances = AtomDistances()
147 |
148 | # layer for expanding interatomic distances in a basis
149 | if distance_expansion is None:
150 | self.distance_expansion = GaussianSmearing(
151 | 0.0, cutoff, n_gaussians, trainable=trainable_gaussians
152 | )
153 | else:
154 | self.distance_expansion = distance_expansion
155 |
156 | # block for computing interaction
157 | if coupled_interactions:
158 | # use the same SchNetInteraction instance (hence the same weights)
159 | self.interactions = nn.ModuleList(
160 | [
161 | SchNetInteraction(
162 | n_atom_basis=n_atom_basis,
163 | n_spatial_basis=n_gaussians,
164 | n_filters=n_filters,
165 | cutoff_network=cutoff_network,
166 | cutoff=cutoff,
167 | normalize_filter=normalize_filter,
168 | )
169 | ]
170 | * n_interactions
171 | )
172 | else:
173 | # use one SchNetInteraction instance for each interaction
174 | self.interactions = nn.ModuleList(
175 | [
176 | SchNetInteraction(
177 | n_atom_basis=n_atom_basis,
178 | n_spatial_basis=n_gaussians,
179 | n_filters=n_filters,
180 | cutoff_network=cutoff_network,
181 | cutoff=cutoff,
182 | normalize_filter=normalize_filter,
183 | )
184 | for _ in range(n_interactions)
185 | ]
186 | )
187 |
188 | # self.node_dec = nn.Sequential(Dense(self.n_atom_basis, self.n_atom_basis, activation=shifted_softplus),
189 | # Dense(self.n_atom_basis, self.n_atom_basis))
190 |
191 | # self.graph_dec = nn.Sequential(Dense(self.n_atom_basis, self.n_atom_basis, activation=shifted_softplus),
192 | # Dense(self.n_atom_basis, 1))
193 |
194 | # set attributes
195 | self.return_intermediate = return_intermediate
196 | self.charged_systems = charged_systems
197 | if charged_systems:
198 | self.charge = nn.Parameter(torch.Tensor(1, n_atom_basis))
199 | self.charge.data.normal_(0, 1.0 / n_atom_basis ** 0.5)
200 |
201 | def forward(self, node_features, positions, neighbors, neighbor_mask, atom_mask):
202 | """Compute atomic representations/embeddings.
203 |
204 | Args:
205 |             node_features, positions, neighbors, neighbor_mask, atom_mask (torch.Tensor): batched atomistic input tensors.
206 |
207 | Returns:
208 | torch.Tensor: atom-wise representation.
209 | list of torch.Tensor: intermediate atom-wise representations, if
210 | return_intermediate=True was used.
211 |
212 | """
213 | # get tensors from input dictionary
214 | cell = None
215 | cell_offset = None
216 | _, n_nodes, _ = node_features.shape
217 | # get atom embeddings for the input atomic numbers
218 | x = self.embedding(node_features)
219 |
220 | # compute interatomic distance of every atom to its neighbors
221 | r_ij = self.distances(
222 | positions, neighbors, cell, cell_offset, neighbor_mask=neighbor_mask
223 | )
224 | # expand interatomic distances (for example, Gaussian smearing)
225 | f_ij = self.distance_expansion(r_ij)
226 | # store intermediate representations
227 | if self.return_intermediate:
228 | xs = [x]
229 | # compute interaction block to update atomic embeddings
230 | for interaction in self.interactions:
231 | v = interaction(x, r_ij, neighbors, neighbor_mask, f_ij=f_ij)
232 | x = x + v
233 | if self.return_intermediate:
234 | xs.append(x)
235 |
236 | # h = self.node_dec(x)
237 | # h = h * atom_mask.unsqueeze(-1)
238 | # h = torch.sum(h, dim=1)
239 | # pred = self.graph_dec(h)
240 | # return pred.squeeze(1)
241 | return x
242 |
--------------------------------------------------------------------------------
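One detail in the SchNet constructor above is easy to miss: `coupled_interactions` shares weights by building the `ModuleList` from a single instance via list multiplication, whereas the list comprehension creates independent blocks. A short sketch of the difference:

import torch.nn as nn

m = nn.Linear(4, 4)
shared = nn.ModuleList([m] * 3)              # three references to one module
assert shared[0] is shared[2]                # one set of parameters, updated together
independent = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
assert independent[0] is not independent[2]  # separate parameters per block
--------------------------------------------------------------------------------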
/baselines/spk_utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic building blocks of SchNetPack models. Contains various basic and specialized network layers, layers for
3 | cutoff functions, as well as several auxiliary layers and functions.
4 | """
5 |
6 | from .acsf import *
7 | from .activations import *
8 | from .base import *
9 | from .blocks import *
10 | from .cfconv import *
11 | from .cutoff import *
12 | from .initializers import *
13 | from .neighbors import *
14 |
--------------------------------------------------------------------------------
/baselines/spk_utils/acsf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .cutoff import CosineCutoff
5 |
6 | __all__ = [
7 | "AngularDistribution",
8 | "BehlerAngular",
9 | "GaussianSmearing",
10 | "RadialDistribution",
11 | ]
12 |
13 |
14 | class AngularDistribution(nn.Module):
15 | """
16 | Routine used to compute angular type symmetry functions between all atoms i-j-k, where i is the central atom.
17 |
18 | Args:
19 | radial_filter (callable): Function used to expand distances (e.g. Gaussians)
20 | angular_filter (callable): Function used to expand angles between triples of atoms (e.g. BehlerAngular)
21 | cutoff_functions (callable): Cutoff function
22 | crossterms (bool): Include radial contributions of the distances r_jk
23 | pairwise_elements (bool): Recombine elemental embedding vectors via an outer product. If e.g. one-hot encoding
24 | is used for the elements, this is equivalent to standard Behler functions
25 | (default=False).
26 |
27 | """
28 |
29 | def __init__(
30 | self,
31 | radial_filter,
32 | angular_filter,
33 | cutoff_functions=CosineCutoff,
34 | crossterms=False,
35 | pairwise_elements=False,
36 | ):
37 | super(AngularDistribution, self).__init__()
38 | self.radial_filter = radial_filter
39 | self.angular_filter = angular_filter
40 | self.cutoff_function = cutoff_functions
41 | self.crossterms = crossterms
42 | self.pairwise_elements = pairwise_elements
43 |
44 | def forward(self, r_ij, r_ik, r_jk, triple_masks=None, elemental_weights=None):
45 | """
46 | Args:
47 | r_ij (torch.Tensor): Distances to neighbor j
48 | r_ik (torch.Tensor): Distances to neighbor k
49 | r_jk (torch.Tensor): Distances between neighbor j and k
50 | triple_masks (torch.Tensor): Tensor mask for non-counted pairs (e.g. due to cutoff)
51 | elemental_weights (tuple of two torch.Tensor): Weighting functions for neighboring elements, first is for
52 | neighbors j, second for k
53 |
54 | Returns:
55 | torch.Tensor: Angular distribution functions
56 |
57 | """
58 |
59 | nbatch, natoms, npairs = r_ij.size()
60 |
61 |         # compute Gaussian-expanded distances and cutoffs to neighbor atoms
62 | radial_ij = self.radial_filter(r_ij)
63 | radial_ik = self.radial_filter(r_ik)
64 | angular_distribution = radial_ij * radial_ik
65 |
66 | if self.crossterms:
67 | radial_jk = self.radial_filter(r_jk)
68 | angular_distribution = angular_distribution * radial_jk
69 |
70 | # Use cosine rule to compute cos( theta_ijk )
71 | cos_theta = (torch.pow(r_ij, 2) + torch.pow(r_ik, 2) - torch.pow(r_jk, 2)) / (
72 | 2.0 * r_ij * r_ik
73 | )
74 |
75 | # Required in order to catch NaNs during backprop
76 | if triple_masks is not None:
77 | cos_theta[triple_masks == 0] = 0.0
78 |
79 | angular_term = self.angular_filter(cos_theta)
80 |
81 | if self.cutoff_function is not None:
82 | cutoff_ij = self.cutoff_function(r_ij).unsqueeze(-1)
83 | cutoff_ik = self.cutoff_function(r_ik).unsqueeze(-1)
84 | angular_distribution = angular_distribution * cutoff_ij * cutoff_ik
85 |
86 | if self.crossterms:
87 | cutoff_jk = self.cutoff_function(r_jk).unsqueeze(-1)
88 | angular_distribution = angular_distribution * cutoff_jk
89 |
90 | # Compute radial part of descriptor
91 | if triple_masks is not None:
92 | # Filter out nan divisions via boolean mask, since
93 | # angular_term = angular_term * triple_masks
94 | # is not working (nan*0 = nan)
95 | angular_term[triple_masks == 0] = 0.0
96 | angular_distribution[triple_masks == 0] = 0.0
97 |
98 | # Apply weights here, since dimension is still the same
99 | if elemental_weights is not None:
100 | if not self.pairwise_elements:
101 | Z_ij, Z_ik = elemental_weights
102 | Z_ijk = Z_ij * Z_ik
103 | angular_distribution = (
104 | torch.unsqueeze(angular_distribution, -1)
105 | * torch.unsqueeze(Z_ijk, -2).float()
106 | )
107 | else:
108 | # Outer product to emulate vanilla SF behavior
109 | Z_ij, Z_ik = elemental_weights
110 | B, A, N, E = Z_ij.size()
111 | pair_elements = Z_ij[:, :, :, :, None] * Z_ik[:, :, :, None, :]
112 | pair_elements = pair_elements + pair_elements.permute(0, 1, 2, 4, 3)
113 | # Filter out lower triangular components
114 | pair_filter = torch.triu(torch.ones(E, E)) == 1
115 | pair_elements = pair_elements[:, :, :, pair_filter]
116 | angular_distribution = torch.unsqueeze(
117 | angular_distribution, -1
118 | ) * torch.unsqueeze(pair_elements, -2)
119 |
120 | # Dimension is (Nb x Nat x Nneighpair x Nrad) for angular_distribution and
121 | # (Nb x Nat x NNeigpair x Nang) for angular_term, where the latter dims are orthogonal
122 | # To multiply them:
123 | angular_distribution = (
124 | angular_distribution[:, :, :, :, None, :]
125 | * angular_term[:, :, :, None, :, None]
126 | )
127 | # For the sum over all contributions
128 | angular_distribution = torch.sum(angular_distribution, 2)
129 | # Finally, we flatten the last two dimensions
130 | angular_distribution = angular_distribution.view(nbatch, natoms, -1)
131 |
132 | return angular_distribution
133 |
134 |
135 | class BehlerAngular(nn.Module):
136 | """
137 | Compute Behler type angular contribution of the angle spanned by three atoms:
138 |
139 |     :math:`2^{(1-\zeta)} (1 + \lambda \cos( \theta_{ijk} ) )^\zeta`
140 |
141 | Sets of zetas with lambdas of -1 and +1 are generated automatically.
142 |
143 | Args:
144 | zetas (set of int): Set of exponents used to compute angular Behler term (default={1})
145 |
146 | """
147 |
148 | def __init__(self, zetas={1}):
149 | super(BehlerAngular, self).__init__()
150 | self.zetas = zetas
151 |
152 | def forward(self, cos_theta):
153 | """
154 | Args:
155 | cos_theta (torch.Tensor): Cosines between all pairs of neighbors of the central atom.
156 |
157 | Returns:
158 | torch.Tensor: Tensor containing values of the angular filters.
159 | """
160 | angular_pos = [
161 | 2 ** (1 - zeta) * ((1.0 - cos_theta) ** zeta).unsqueeze(-1)
162 | for zeta in self.zetas
163 | ]
164 | angular_neg = [
165 | 2 ** (1 - zeta) * ((1.0 + cos_theta) ** zeta).unsqueeze(-1)
166 | for zeta in self.zetas
167 | ]
168 | angular_all = angular_pos + angular_neg
169 | return torch.cat(angular_all, -1)
170 |
171 |
172 | def gaussian_smearing(distances, offset, widths, centered=False):
173 | r"""Smear interatomic distance values using Gaussian functions.
174 |
175 | Args:
176 | distances (torch.Tensor): interatomic distances of (N_b x N_at x N_nbh) shape.
177 | offset (torch.Tensor): offsets values of Gaussian functions.
178 | widths: width values of Gaussian functions.
179 | centered (bool, optional): If True, Gaussians are centered at the origin and
180 |             the offsets are used as their widths (used e.g. for angular functions).
181 |
182 | Returns:
183 | torch.Tensor: smeared distances (N_b x N_at x N_nbh x N_g).
184 |
185 | """
186 | if not centered:
187 | # compute width of Gaussian functions (using an overlap of 1 STDDEV)
188 | coeff = -0.5 / torch.pow(widths, 2)
189 | # Use advanced indexing to compute the individual components
190 | diff = distances[:, :, :, None] - offset[None, None, None, :]
191 | else:
192 | # if Gaussian functions are centered, use offsets to compute widths
193 | coeff = -0.5 / torch.pow(offset, 2)
194 | # if Gaussian functions are centered, no offset is subtracted
195 | diff = distances[:, :, :, None]
196 | # compute smear distance values
197 | gauss = torch.exp(coeff * torch.pow(diff, 2))
198 | return gauss
199 |
200 |
201 | class GaussianSmearing(nn.Module):
202 | r"""Smear layer using a set of Gaussian functions.
203 |
204 | Args:
205 | start (float, optional): center of first Gaussian function, :math:`\mu_0`.
206 | stop (float, optional): center of last Gaussian function, :math:`\mu_{N_g}`
207 | n_gaussians (int, optional): total number of Gaussian functions, :math:`N_g`.
208 | centered (bool, optional): If True, Gaussians are centered at the origin and
209 |             the offsets are used as their widths (used e.g. for angular functions).
210 | trainable (bool, optional): If True, widths and offset of Gaussian functions
211 | are adjusted during training process.
212 |
213 | """
214 |
215 | def __init__(
216 | self, start=0.0, stop=5.0, n_gaussians=50, centered=False, trainable=False
217 | ):
218 | super(GaussianSmearing, self).__init__()
219 | # compute offset and width of Gaussian functions
220 | offset = torch.linspace(start, stop, n_gaussians)
221 | widths = torch.FloatTensor((offset[1] - offset[0]) * torch.ones_like(offset))
222 | if trainable:
223 | self.width = nn.Parameter(widths)
224 | self.offsets = nn.Parameter(offset)
225 | else:
226 | self.register_buffer("width", widths)
227 | self.register_buffer("offsets", offset)
228 | self.centered = centered
229 |
230 | def forward(self, distances):
231 | """Compute smeared-gaussian distance values.
232 |
233 | Args:
234 | distances (torch.Tensor): interatomic distance values of
235 | (N_b x N_at x N_nbh) shape.
236 |
237 | Returns:
238 | torch.Tensor: layer output of (N_b x N_at x N_nbh x N_g) shape.
239 |
240 | """
241 | return gaussian_smearing(
242 | distances, self.offsets, self.width, centered=self.centered
243 | )
244 |
245 |
246 | class RadialDistribution(nn.Module):
247 | """
248 | Radial distribution function used e.g. to compute Behler type radial symmetry functions.
249 |
250 | Args:
251 | radial_filter (callable): Function used to expand distances (e.g. Gaussians)
252 | cutoff_function (callable): Cutoff function
253 | """
254 |
255 | def __init__(self, radial_filter, cutoff_function=CosineCutoff):
256 | super(RadialDistribution, self).__init__()
257 | self.radial_filter = radial_filter
258 | self.cutoff_function = cutoff_function
259 |
260 | def forward(self, r_ij, elemental_weights=None, neighbor_mask=None):
261 | """
262 | Args:
263 | r_ij (torch.Tensor): Interatomic distances
264 | elemental_weights (torch.Tensor): Element-specific weights for distance functions
265 | neighbor_mask (torch.Tensor): Mask to identify positions of neighboring atoms
266 |
267 | Returns:
268 | torch.Tensor: Nbatch x Natoms x Nfilter tensor containing radial distribution functions.
269 | """
270 |
271 | nbatch, natoms, nneigh = r_ij.size()
272 |
273 | radial_distribution = self.radial_filter(r_ij)
274 |
275 | # If requested, apply cutoff function
276 | if self.cutoff_function is not None:
277 | cutoffs = self.cutoff_function(r_ij)
278 | radial_distribution = radial_distribution * cutoffs.unsqueeze(-1)
279 |
280 | # Apply neighbor mask
281 | if neighbor_mask is not None:
282 | radial_distribution = radial_distribution * torch.unsqueeze(
283 | neighbor_mask, -1
284 | )
285 |
286 | # Weigh elements if requested
287 | if elemental_weights is not None:
288 | radial_distribution = (
289 | radial_distribution[:, :, :, :, None]
290 | * elemental_weights[:, :, :, None, :].float()
291 | )
292 |
293 | radial_distribution = torch.sum(radial_distribution, 2)
294 | radial_distribution = radial_distribution.view(nbatch, natoms, -1)
295 | return radial_distribution
296 |
--------------------------------------------------------------------------------
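A quick shape check for `GaussianSmearing` above (a sketch; the distance tensor is random and the sizes are illustrative):

import torch

smearing = GaussianSmearing(start=0.0, stop=5.0, n_gaussians=25)
r_ij = torch.rand(2, 4, 3) * 5.0  # (N_b, N_at, N_nbh) distances
f_ij = smearing(r_ij)             # each distance expanded over 25 Gaussians
print(f_ij.shape)                 # torch.Size([2, 4, 3, 25])
--------------------------------------------------------------------------------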
/baselines/spk_utils/activations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | from torch.nn import functional
4 |
5 |
6 | def shifted_softplus(x):
7 | r"""Compute shifted soft-plus activation function.
8 |
9 | .. math::
10 |         y = \ln\left(1 + e^{x}\right) - \ln(2)
11 |
12 | Args:
13 | x (torch.Tensor): input tensor.
14 |
15 | Returns:
16 | torch.Tensor: shifted soft-plus of input.
17 |
18 | """
19 | return functional.softplus(x) - np.log(2.0)
20 |
21 | class ShiftedSoftplus(nn.Module):
22 | def __init__(self):
23 | super(ShiftedSoftplus, self).__init__()
24 |         self.shift = np.log(2.0)  # ln 2; torch itself is not imported in this module
25 |
26 | def forward(self, x):
27 | return functional.softplus(x) - self.shift
28 |
--------------------------------------------------------------------------------
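A one-line sanity check on the shift (a sketch): `shifted_softplus(0) = ln(1 + e^0) - ln 2 = 0`, so the activation passes through the origin.

import torch

y = shifted_softplus(torch.zeros(3))
assert y.abs().max().item() < 1e-6  # softplus(0) = ln 2, so the shift cancels
--------------------------------------------------------------------------------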
/baselines/spk_utils/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.init import xavier_uniform_
4 |
5 | from .initializers import zeros_initializer
6 |
7 |
8 | __all__ = ["Dense", "GetItem", "ScaleShift", "Standardize", "Aggregate"]
9 |
10 |
11 | class Dense(nn.Linear):
12 | r"""Fully connected linear layer with activation function.
13 |
14 | .. math::
15 | y = activation(xW^T + b)
16 |
17 | Args:
18 |         in_features (int): number of input features :math:`x`.
19 | out_features (int): number of output features :math:`y`.
20 | bias (bool, optional): if False, the layer will not adapt bias :math:`b`.
21 | activation (callable, optional): if None, no activation function is used.
22 | weight_init (callable, optional): weight initializer from current weight.
23 | bias_init (callable, optional): bias initializer from current bias.
24 |
25 | """
26 |
27 | def __init__(
28 | self,
29 | in_features,
30 | out_features,
31 | bias=True,
32 | activation=None,
33 | weight_init=xavier_uniform_,
34 | bias_init=zeros_initializer,
35 | ):
36 | self.weight_init = weight_init
37 | self.bias_init = bias_init
38 | self.activation = activation
39 | # initialize linear layer y = xW^T + b
40 | super(Dense, self).__init__(in_features, out_features, bias)
41 |
42 | def reset_parameters(self):
43 | """Reinitialize model weight and bias values."""
44 | self.weight_init(self.weight)
45 | if self.bias is not None:
46 | self.bias_init(self.bias)
47 |
48 | def forward(self, inputs):
49 | """Compute layer output.
50 |
51 | Args:
52 |             inputs (torch.Tensor): batch of input values.
53 |
54 | Returns:
55 | torch.Tensor: layer output.
56 |
57 | """
58 | # compute linear layer y = xW^T + b
59 | y = super(Dense, self).forward(inputs)
60 | # add activation function
61 | if self.activation:
62 | y = self.activation(y)
63 | return y
64 |
65 |
66 | class GetItem(nn.Module):
67 | """Extraction layer to get an item from SchNetPack dictionary of input tensors.
68 |
69 | Args:
70 | key (str): Property to be extracted from SchNetPack input tensors.
71 |
72 | """
73 |
74 | def __init__(self, key):
75 | super(GetItem, self).__init__()
76 | self.key = key
77 |
78 | def forward(self, inputs):
79 | """Compute layer output.
80 |
81 | Args:
82 | inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors.
83 |
84 | Returns:
85 | torch.Tensor: layer output.
86 |
87 | """
88 | return inputs[self.key]
89 |
90 |
91 | class ScaleShift(nn.Module):
92 | r"""Scale and shift layer for standardization.
93 |
94 | .. math::
95 | y = x \times \sigma + \mu
96 |
97 | Args:
98 | mean (torch.Tensor): mean value :math:`\mu`.
99 | stddev (torch.Tensor): standard deviation value :math:`\sigma`.
100 |
101 | """
102 |
103 | def __init__(self, mean, stddev):
104 | super(ScaleShift, self).__init__()
105 | self.register_buffer("mean", mean)
106 | self.register_buffer("stddev", stddev)
107 |
108 | def forward(self, input):
109 | """Compute layer output.
110 |
111 | Args:
112 | input (torch.Tensor): input data.
113 |
114 | Returns:
115 | torch.Tensor: layer output.
116 |
117 | """
118 | y = input * self.stddev + self.mean
119 | return y
120 |
121 |
122 | class Standardize(nn.Module):
123 | r"""Standardize layer for shifting and scaling.
124 |
125 | .. math::
126 | y = \frac{x - \mu}{\sigma}
127 |
128 | Args:
129 | mean (torch.Tensor): mean value :math:`\mu`.
130 | stddev (torch.Tensor): standard deviation value :math:`\sigma`.
131 | eps (float, optional): small offset value to avoid zero division.
132 |
133 | """
134 |
135 | def __init__(self, mean, stddev, eps=1e-9):
136 | super(Standardize, self).__init__()
137 | self.register_buffer("mean", mean)
138 | self.register_buffer("stddev", stddev)
139 | self.register_buffer("eps", torch.ones_like(stddev) * eps)
140 |
141 | def forward(self, input):
142 | """Compute layer output.
143 |
144 | Args:
145 | input (torch.Tensor): input data.
146 |
147 | Returns:
148 | torch.Tensor: layer output.
149 |
150 | """
151 | # Add small number to catch divide by zero
152 | y = (input - self.mean) / (self.stddev + self.eps)
153 | return y
154 |
155 |
156 | class Aggregate(nn.Module):
157 | """Pooling layer based on sum or average with optional masking.
158 |
159 | Args:
160 | axis (int): axis along which pooling is done.
161 |         mean (bool, optional): if True, use average instead of sum pooling.
162 | keepdim (bool, optional): whether the output tensor has dim retained or not.
163 |
164 | """
165 |
166 | def __init__(self, axis, mean=False, keepdim=True):
167 | super(Aggregate, self).__init__()
168 | self.average = mean
169 | self.axis = axis
170 | self.keepdim = keepdim
171 |
172 | def forward(self, input, mask=None):
173 | r"""Compute layer output.
174 |
175 | Args:
176 | input (torch.Tensor): input data.
177 | mask (torch.Tensor, optional): mask to be applied; e.g. neighbors mask.
178 |
179 | Returns:
180 | torch.Tensor: layer output.
181 |
182 | """
183 | # mask input
184 | if mask is not None:
185 | input = input * mask[..., None]
186 | # compute sum of input along axis
187 | y = torch.sum(input, self.axis)
188 | # compute average of input along axis
189 | if self.average:
190 | # get the number of items along axis
191 | if mask is not None:
192 | N = torch.sum(mask, self.axis, keepdim=self.keepdim)
193 | N = torch.max(N, other=torch.ones_like(N))
194 | else:
195 | N = input.size(self.axis)
196 | y = y / N
197 | return y
198 |
199 |
200 | class MaxAggregate(nn.Module):
201 | """Pooling layer that computes the maximum for each feature over all atoms
202 |
203 | Args:
204 | axis (int): axis along which pooling is done.
205 | """
206 |
207 | def __init__(self, axis):
208 | super().__init__()
209 | self.axis = axis
210 |
211 | def forward(self, input, mask=None):
212 | r"""Compute layer output.
213 |
214 | Args:
215 | input (torch.Tensor): input data.
216 | mask (torch.Tensor, optional): mask to be applied; e.g. neighbors mask.
217 |
218 | Returns:
219 | torch.Tensor: layer output.
220 | """
221 | # mask input
222 | if mask is not None:
223 | # If the mask is lower dimensional than the array being masked,
224 | # inject an extra dimension to the end
225 | if mask.dim() < input.dim():
226 | mask = torch.unsqueeze(mask, -1)
227 | input = torch.where(mask > 0, input, torch.min(input))
228 |
229 |         # compute maximum of input along axis
230 | return torch.max(input, self.axis)[0]
231 |
232 |
233 | class SoftmaxAggregate(nn.Module):
234 |     """Pooling layer that computes a soft maximum for each feature over all
235 |     atoms, using the softmax function to weight each atom's contribution to
236 |     the "maximum."
237 |
238 | Args:
239 | axis (int): axis along which pooling is done.
240 | """
241 |
242 | def __init__(self, axis):
243 | super().__init__()
244 | self.axis = axis
245 |
246 | def forward(self, input, mask=None):
247 | r"""Compute layer output.
248 |
249 | Args:
250 | input (torch.Tensor): input data.
251 | mask (torch.Tensor, optional): mask to be applied; e.g. neighbors mask.
252 |
253 | Returns:
254 | torch.Tensor: layer output.
255 | """
256 |
257 |         # Compute exponentials of the input
258 | exp_input = torch.exp(input)
259 |
260 | # Set the contributions of "masked" atoms to zero
261 | if mask is not None:
262 | # If the mask is lower dimensional than the array being masked,
263 | # inject an extra dimension to the end
264 | if mask.dim() < input.dim():
265 | mask = torch.unsqueeze(mask, -1)
266 | exp_input = torch.where(mask > 0, exp_input, torch.zeros_like(exp_input))
267 |
268 | # Sum exponentials along the desired axis
269 | exp_input_sum = torch.sum(exp_input, self.axis, keepdim=True)
270 |
271 |         # Normalize the exponentials by their sum to obtain softmax weights
272 |         weights = exp_input / exp_input_sum
273 |
274 |         # compute softmax-weighted sum of input along axis
275 | output = torch.sum(input * weights, self.axis)
276 | return output
277 |
--------------------------------------------------------------------------------
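A short sketch of masked mean pooling with `Aggregate`: masked-out atoms contribute neither to the sum nor to the neighbor count, so padded structures average correctly.

import torch

pool = Aggregate(axis=1, mean=True)
x = torch.randn(2, 5, 8)              # (batch, atoms, features)
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])
y = pool(x, mask)                     # (2, 8); row 0 averages 3 atoms, row 1 averages 5
--------------------------------------------------------------------------------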
/baselines/spk_utils/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from . import shifted_softplus, Dense
5 |
6 | __all__ = ["MLP", "TiledMultiLayerNN", "ElementalGate", "GatedNetwork"]
7 |
8 |
9 | class MLP(nn.Module):
10 | """Multiple layer fully connected perceptron neural network.
11 |
12 | Args:
13 | n_in (int): number of input nodes.
14 | n_out (int): number of output nodes.
15 |         n_hidden (list of int or int, optional): number of hidden layer nodes.
16 |             If an integer, the same number of nodes is used for all hidden layers,
17 |             resulting in a rectangular network.
18 |             If None, the number of neurons is divided by two after each layer,
19 |             starting from n_in, resulting in a pyramidal network.
20 | n_layers (int, optional): number of layers.
21 |         activation (callable, optional): activation function. All hidden layers use
22 |             the same activation function; the output layer does not apply
23 |             any activation function.
24 |
25 | """
26 |
27 | def __init__(
28 | self, n_in, n_out, n_hidden=None, n_layers=2, activation=shifted_softplus
29 | ):
30 | super(MLP, self).__init__()
31 | # get list of number of nodes in input, hidden & output layers
32 | if n_hidden is None:
33 | c_neurons = n_in
34 | self.n_neurons = []
35 | for i in range(n_layers):
36 | self.n_neurons.append(c_neurons)
37 | c_neurons = max(n_out, c_neurons // 2)
38 | self.n_neurons.append(n_out)
39 | else:
40 | # get list of number of nodes hidden layers
41 | if type(n_hidden) is int:
42 | n_hidden = [n_hidden] * (n_layers - 1)
43 | self.n_neurons = [n_in] + n_hidden + [n_out]
44 |
45 | # assign a Dense layer (with activation function) to each hidden layer
46 | layers = [
47 | Dense(self.n_neurons[i], self.n_neurons[i + 1], activation=activation)
48 | for i in range(n_layers - 1)
49 | ]
50 | # assign a Dense layer (without activation function) to the output layer
51 | layers.append(Dense(self.n_neurons[-2], self.n_neurons[-1], activation=None))
52 | # put all layers together to make the network
53 | self.out_net = nn.Sequential(*layers)
54 |
55 | def forward(self, inputs):
56 | """Compute neural network output.
57 |
58 | Args:
59 | inputs (torch.Tensor): network input.
60 |
61 | Returns:
62 | torch.Tensor: network output.
63 |
64 | """
65 | return self.out_net(inputs)
66 |
67 |
68 | class TiledMultiLayerNN(nn.Module):
69 | """
70 | Tiled multilayer networks which are applied to the input and produce n_tiled different outputs.
71 | These outputs are then stacked and returned. Used e.g. to construct element-dependent prediction
72 | networks of the Behler-Parrinello type.
73 |
74 | Args:
75 | n_in (int): number of input nodes
76 | n_out (int): number of output nodes
77 | n_tiles (int): number of networks to be tiled
78 | n_hidden (int): number of nodes in hidden nn (default 50)
79 | n_layers (int): number of layers (default: 3)
80 | """
81 |
82 | def __init__(
83 | self, n_in, n_out, n_tiles, n_hidden=50, n_layers=3, activation=shifted_softplus
84 | ):
85 | super(TiledMultiLayerNN, self).__init__()
86 | self.mlps = nn.ModuleList(
87 | [
88 | MLP(
89 | n_in,
90 | n_out,
91 | n_hidden=n_hidden,
92 | n_layers=n_layers,
93 | activation=activation,
94 | )
95 | for _ in range(n_tiles)
96 | ]
97 | )
98 |
99 | def forward(self, inputs):
100 | """
101 | Args:
102 | inputs (torch.Tensor): Network inputs.
103 |
104 | Returns:
105 | torch.Tensor: Tiled network outputs.
106 |
107 | """
108 | return torch.cat([net(inputs) for net in self.mlps], 2)
109 |
110 |
111 | class ElementalGate(nn.Module):
112 | """
113 | Produces a Nbatch x Natoms x Nelem mask depending on the nuclear charges passed as an argument.
114 |     If onehot is set, the mask is a one-hot mask; otherwise a random embedding is used.
115 | If the trainable flag is set to true, the gate values can be adapted during training.
116 |
117 | Args:
118 | elements (set of int): Set of atomic number present in the data
119 |         onehot (bool): Use one-hot encoding for the elemental gate. If set to False, a random embedding is used instead.
120 | trainable (bool): If set to true, gate can be learned during training (default False)
121 | """
122 |
123 | def __init__(self, elements, onehot=True, trainable=False):
124 | super(ElementalGate, self).__init__()
125 | self.trainable = trainable
126 |
127 | # Get the number of elements, as well as the highest nuclear charge to use in the embedding vector
128 | self.nelems = len(elements)
129 | maxelem = int(max(elements) + 1)
130 |
131 | self.gate = nn.Embedding(maxelem, self.nelems)
132 |
133 | # if requested, initialize as one hot gate for all elements
134 | if onehot:
135 | weights = torch.zeros(maxelem, self.nelems)
136 | for idx, Z in enumerate(elements):
137 | weights[Z, idx] = 1.0
138 | self.gate.weight.data = weights
139 |
140 | # Set trainable flag
141 | if not trainable:
142 | self.gate.weight.requires_grad = False
143 |
144 | def forward(self, atomic_numbers):
145 | """
146 | Args:
147 | atomic_numbers (torch.Tensor): Tensor containing atomic numbers of each atom.
148 |
149 | Returns:
150 | torch.Tensor: One-hot vector which is one at the position of the element and zero otherwise.
151 |
152 | """
153 | return self.gate(atomic_numbers)
154 |
155 |
156 | class GatedNetwork(nn.Module):
157 | """
158 | Combines the TiledMultiLayerNN with the elemental gate to obtain element specific atomistic networks as in typical
159 | Behler--Parrinello networks [#behler1]_.
160 |
161 | Args:
162 | nin (int): number of input nodes
163 | nout (int): number of output nodes
164 |         n_hidden (int): number of nodes in hidden nn (default 50)
165 |         n_layers (int): number of layers (default 3)
166 | elements (set of ints): Set of atomic number present in the data
167 |         onehot (bool): Use one-hot encoding for the elemental gate. If set to False, a random embedding is used instead.
168 | trainable (bool): If set to true, gate can be learned during training (default False)
169 | activation (callable): activation function
170 |
171 | References
172 | ----------
173 | .. [#behler1] Behler, Parrinello:
174 | Generalized Neural-Network Representation of High-Dimensional Potential-Energy Surfaces.
175 | Phys. Rev. Lett. 98, 146401. 2007.
176 |
177 | """
178 |
179 | def __init__(
180 | self,
181 | nin,
182 | nout,
183 | elements,
184 | n_hidden=50,
185 | n_layers=3,
186 | trainable=False,
187 | onehot=True,
188 | activation=shifted_softplus,
189 | ):
190 | super(GatedNetwork, self).__init__()
191 | self.nelem = len(elements)
192 | self.gate = ElementalGate(elements, trainable=trainable, onehot=onehot)
193 | self.network = TiledMultiLayerNN(
194 | nin,
195 | nout,
196 | self.nelem,
197 | n_hidden=n_hidden,
198 | n_layers=n_layers,
199 | activation=activation,
200 | )
201 |
202 | def forward(self, atomic_numbers, representation):
203 | """
204 | Args:
205 |             atomic_numbers, representation (torch.Tensor): atomic numbers and atomwise representations.
206 |
207 | Returns:
208 | torch.Tensor: Output of the gated network.
209 | """
210 |         # gate the tiled element networks with the per-atom element mask
211 | gated_network = self.gate(atomic_numbers) * self.network(representation)
212 | return torch.sum(gated_network, -1, keepdim=True)
213 |
--------------------------------------------------------------------------------
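A sketch of the gated element-wise networks above: each atom's prediction comes from the sub-network selected by its one-hot-gated atomic number. The element set and shapes here are illustrative.

import torch

net = GatedNetwork(nin=16, nout=1, elements={1, 6, 8})  # H, C, O
atomic_numbers = torch.tensor([[1, 6, 8, 6]])           # (batch, atoms), long
representation = torch.randn(1, 4, 16)                  # (batch, atoms, nin)
y = net(atomic_numbers, representation)
print(y.shape)  # torch.Size([1, 4, 1])
--------------------------------------------------------------------------------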
/baselines/spk_utils/cfconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from . import Dense
5 | from .base import Aggregate
6 |
7 |
8 | __all__ = ["CFConv"]
9 |
10 |
11 | class CFConv(nn.Module):
12 | r"""Continuous-filter convolution block used in SchNet module.
13 |
14 | Args:
15 | n_in (int): number of input (i.e. atomic embedding) dimensions.
16 | n_filters (int): number of filter dimensions.
17 | n_out (int): number of output dimensions.
18 | filter_network (nn.Module): filter block.
19 | cutoff_network (nn.Module, optional): if None, no cut off function is used.
20 | activation (callable, optional): if None, no activation function is used.
21 | normalize_filter (bool, optional): If True, normalize filter to the number
22 | of neighbors when aggregating.
23 | axis (int, optional): axis over which convolution should be applied.
24 |
25 | """
26 |
27 | def __init__(
28 | self,
29 | n_in,
30 | n_filters,
31 | n_out,
32 | filter_network,
33 | cutoff_network=None,
34 | activation=None,
35 | normalize_filter=False,
36 | axis=2,
37 | ):
38 | super(CFConv, self).__init__()
39 | self.in2f = Dense(n_in, n_filters, bias=False, activation=None)
40 | self.f2out = Dense(n_filters, n_out, bias=True, activation=activation)
41 | self.filter_network = filter_network
42 | self.cutoff_network = cutoff_network
43 | self.agg = Aggregate(axis=axis, mean=normalize_filter)
44 |
45 | def forward(self, x, r_ij, neighbors, pairwise_mask, f_ij=None):
46 | """Compute convolution block.
47 |
48 | Args:
49 | x (torch.Tensor): input representation/embedding of atomic environments
50 | with (N_b, N_a, n_in) shape.
51 | r_ij (torch.Tensor): interatomic distances of (N_b, N_a, N_nbh) shape.
52 | neighbors (torch.Tensor): indices of neighbors of (N_b, N_a, N_nbh) shape.
53 | pairwise_mask (torch.Tensor): mask to filter out non-existing neighbors
54 | introduced via padding.
55 | f_ij (torch.Tensor, optional): expanded interatomic distances in a basis.
56 | If None, r_ij.unsqueeze(-1) is used.
57 |
58 | Returns:
59 | torch.Tensor: block output with (N_b, N_a, n_out) shape.
60 |
61 | """
62 | if f_ij is None:
63 | f_ij = r_ij.unsqueeze(-1)
64 |
65 |         # pass expanded interatomic distances through filter block
66 | W = self.filter_network(f_ij)
67 | # apply cutoff
68 | if self.cutoff_network is not None:
69 | C = self.cutoff_network(r_ij)
70 | W = W * C.unsqueeze(-1)
71 |
72 | # pass initial embeddings through Dense layer
73 | y = self.in2f(x)
74 | # reshape y for element-wise multiplication by W
75 | nbh_size = neighbors.size()
76 | nbh = neighbors.view(-1, nbh_size[1] * nbh_size[2], 1)
77 | nbh = nbh.expand(-1, -1, y.size(2))
78 | y = torch.gather(y, 1, nbh)
79 | y = y.view(nbh_size[0], nbh_size[1], nbh_size[2], -1)
80 |
81 | # element-wise multiplication, aggregating and Dense layer
82 | y = y * W
83 | y = self.agg(y, pairwise_mask)
84 | y = self.f2out(y)
85 | return y
86 |
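87 | # Hedged usage sketch (not part of the upstream file): a minimal CFConv call
88 | # with a single-layer filter network on a two-atom toy batch.
89 | if __name__ == "__main__":
90 |     conv = CFConv(n_in=8, n_filters=8, n_out=8, filter_network=Dense(1, 8))
91 |     x = torch.rand(1, 2, 8)                  # (N_b, N_a, n_in) embeddings
92 |     r_ij = torch.rand(1, 2, 1)               # (N_b, N_a, N_nbh) distances
93 |     neighbors = torch.tensor([[[1], [0]]])   # each atom's only neighbor
94 |     mask = torch.ones(1, 2, 1)
95 |     print(conv(x, r_ij, neighbors, mask).shape)  # -> torch.Size([1, 2, 8])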
--------------------------------------------------------------------------------
/baselines/spk_utils/cutoff.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
6 | __all__ = ["CosineCutoff", "MollifierCutoff", "HardCutoff", "get_cutoff_by_string"]
7 |
8 |
9 | def get_cutoff_by_string(key):
10 | # build cutoff module
11 | if key == "hard":
12 | cutoff_network = HardCutoff
13 | elif key == "cosine":
14 | cutoff_network = CosineCutoff
15 | elif key == "mollifier":
16 | cutoff_network = MollifierCutoff
17 | else:
18 | raise NotImplementedError("cutoff_function {} is unknown".format(key))
19 | return cutoff_network
20 |
21 |
22 | class CosineCutoff(nn.Module):
23 | r"""Class of Behler cosine cutoff.
24 |
25 | .. math::
26 | f(r) = \begin{cases}
27 | 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
28 | & r < r_\text{cutoff} \\
29 | 0 & r \geqslant r_\text{cutoff} \\
30 | \end{cases}
31 |
32 | Args:
33 | cutoff (float, optional): cutoff radius.
34 |
35 | """
36 |
37 | def __init__(self, cutoff=5.0):
38 | super(CosineCutoff, self).__init__()
39 | self.register_buffer("cutoff", torch.FloatTensor([cutoff]))
40 |
41 | def forward(self, distances):
42 | """Compute cutoff.
43 |
44 | Args:
45 | distances (torch.Tensor): values of interatomic distances.
46 |
47 | Returns:
48 | torch.Tensor: values of cutoff function.
49 |
50 | """
51 | # Compute values of cutoff function
52 | cutoffs = 0.5 * (torch.cos(distances * np.pi / self.cutoff) + 1.0)
53 | # Remove contributions beyond the cutoff radius
54 | cutoffs *= (distances < self.cutoff).float()
55 | return cutoffs
56 |
57 |
58 | class MollifierCutoff(nn.Module):
59 | r"""Class for mollifier cutoff scaled to have a value of 1 at :math:`r=0`.
60 |
61 | .. math::
62 | f(r) = \begin{cases}
63 | \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
64 | & r < r_\text{cutoff} \\
65 | 0 & r \geqslant r_\text{cutoff} \\
66 | \end{cases}
67 |
68 | Args:
69 | cutoff (float, optional): Cutoff radius.
70 | eps (float, optional): offset added to distances for numerical stability.
71 |
72 | """
73 |
74 | def __init__(self, cutoff=5.0, eps=1.0e-7):
75 | super(MollifierCutoff, self).__init__()
76 | self.register_buffer("cutoff", torch.FloatTensor([cutoff]))
77 | self.register_buffer("eps", torch.FloatTensor([eps]))
78 |
79 | def forward(self, distances):
80 | """Compute cutoff.
81 |
82 | Args:
83 | distances (torch.Tensor): values of interatomic distances.
84 |
85 | Returns:
86 | torch.Tensor: values of cutoff function.
87 |
88 | """
89 | mask = (distances + self.eps < self.cutoff).float()
90 | exponent = 1.0 - 1.0 / (1.0 - torch.pow(distances * mask / self.cutoff, 2))
91 | cutoffs = torch.exp(exponent)
92 | cutoffs = cutoffs * mask
93 | return cutoffs
94 |
95 |
96 | class HardCutoff(nn.Module):
97 | r"""Class of hard cutoff.
98 |
99 | .. math::
100 | f(r) = \begin{cases}
101 | 1 & r \leqslant r_\text{cutoff} \\
102 | 0 & r > r_\text{cutoff} \\
103 | \end{cases}
104 |
105 | Args:
106 | cutoff (float): cutoff radius.
107 |
108 | """
109 |
110 | def __init__(self, cutoff=5.0):
111 | super(HardCutoff, self).__init__()
112 | self.register_buffer("cutoff", torch.FloatTensor([cutoff]))
113 |
114 | def forward(self, distances):
115 | """Compute cutoff.
116 |
117 | Args:
118 | distances (torch.Tensor): values of interatomic distances.
119 |
120 | Returns:
121 | torch.Tensor: values of cutoff function.
122 |
123 | """
124 | mask = (distances <= self.cutoff).float()
125 | return mask
126 |
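127 | # Hedged usage sketch (not part of the upstream file): compare the three
128 | # cutoff functions on a toy distance grid.
129 | if __name__ == "__main__":
130 |     r = torch.linspace(0.0, 6.0, 7)
131 |     for cls in (HardCutoff, CosineCutoff, MollifierCutoff):
132 |         print(cls.__name__, cls(cutoff=5.0)(r))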
--------------------------------------------------------------------------------
/baselines/spk_utils/initializers.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from torch.nn.init import constant_
4 |
5 | zeros_initializer = partial(constant_, val=0.0)
6 |
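7 | # Usage note: zeros_initializer(t) zero-fills tensor t in place; it is simply
8 | # torch.nn.init.constant_ with val=0.0 pre-bound via functools.partial.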
--------------------------------------------------------------------------------
/baselines/spk_utils/neighbors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def atom_distances(
6 | positions,
7 | neighbors,
8 | cell=None,
9 | cell_offsets=None,
10 | return_vecs=False,
11 | normalize_vecs=False,
12 | neighbor_mask=None,
13 | ):
14 | r"""Compute distance of every atom to its neighbors.
15 |
16 | This function uses advanced torch indexing to compute differentiable distances
17 | of every central atom to its relevant neighbors.
18 |
19 | Args:
20 | positions (torch.Tensor):
21 | atomic Cartesian coordinates with (N_b x N_at x 3) shape
22 | neighbors (torch.Tensor):
23 | indices of neighboring atoms to consider with (N_b x N_at x N_nbh) shape
24 | cell (torch.tensor, optional):
25 | periodic cell of (N_b x 3 x 3) shape
26 | cell_offsets (torch.Tensor, optional) :
27 | offset of atom in cell coordinates with (N_b x N_at x N_nbh x 3) shape
28 | return_vecs (bool, optional): if True, also returns direction vectors.
29 | normalize_vecs (bool, optional): if True, normalize direction vectors.
30 | neighbor_mask (torch.Tensor, optional): boolean mask for neighbor positions.
31 |
32 | Returns:
33 | (torch.Tensor, torch.Tensor):
34 | distances:
35 | distance of every atom to its neighbors with
36 | (N_b x N_at x N_nbh) shape.
37 |
38 | dist_vec:
39 | direction cosines of every atom to its
40 | neighbors with (N_b x N_at x N_nbh x 3) shape (optional).
41 |
42 | """
43 |
44 | # Construct auxiliary index vector
45 | n_batch = positions.size()[0]
46 | idx_m = torch.arange(n_batch, device=positions.device, dtype=torch.long)[
47 | :, None, None
48 | ]
49 | # Get atomic positions of all neighboring indices
50 | pos_xyz = positions[idx_m, neighbors[:, :, :], :]
51 |
52 | # Subtract positions of central atoms to get distance vectors
53 | dist_vec = pos_xyz - positions[:, :, None, :]
54 |
55 | # add cell offset
56 | if cell is not None:
57 | B, A, N, D = cell_offsets.size()
58 | cell_offsets = cell_offsets.view(B, A * N, D)
59 | offsets = cell_offsets.bmm(cell)
60 | offsets = offsets.view(B, A, N, D)
61 | dist_vec += offsets
62 |
63 | # Compute vector lengths
64 | distances = torch.norm(dist_vec, 2, 3)
65 |
66 | if neighbor_mask is not None:
67 | # Avoid problems with zero distances in forces (instability of square
68 |             # root derivative at 0). This way is necessary, as gradients do not
69 | # work with inplace operations, such as e.g.
70 | # -> distances[mask==0] = 0.0
71 | tmp_distances = torch.zeros_like(distances)
72 | tmp_distances[neighbor_mask != 0] = distances[neighbor_mask != 0]
73 | distances = tmp_distances
74 |
75 | if return_vecs:
76 | tmp_distances = torch.ones_like(distances)
77 | tmp_distances[neighbor_mask != 0] = distances[neighbor_mask != 0]
78 |
79 | if normalize_vecs:
80 | dist_vec = dist_vec / tmp_distances[:, :, :, None]
81 | return distances, dist_vec
82 | return distances
83 |
84 |
85 | class AtomDistances(nn.Module):
86 | r"""Layer for computing distance of every atom to its neighbors.
87 |
88 | Args:
89 | return_directions (bool, optional): if True, the `forward` method also returns
90 | normalized direction vectors.
91 |
92 | """
93 |
94 | def __init__(self, return_directions=False):
95 | super(AtomDistances, self).__init__()
96 | self.return_directions = return_directions
97 |
98 | def forward(
99 | self, positions, neighbors, cell=None, cell_offsets=None, neighbor_mask=None
100 | ):
101 | r"""Compute distance of every atom to its neighbors.
102 |
103 | Args:
104 | positions (torch.Tensor): atomic Cartesian coordinates with
105 | (N_b x N_at x 3) shape.
106 | neighbors (torch.Tensor): indices of neighboring atoms to consider
107 | with (N_b x N_at x N_nbh) shape.
108 | cell (torch.tensor, optional): periodic cell of (N_b x 3 x 3) shape.
109 | cell_offsets (torch.Tensor, optional): offset of atom in cell coordinates
110 | with (N_b x N_at x N_nbh x 3) shape.
111 | neighbor_mask (torch.Tensor, optional): boolean mask for neighbor
112 | positions. Required for the stable computation of forces in
113 | molecules with different sizes.
114 |
115 | Returns:
116 | torch.Tensor: layer output of (N_b x N_at x N_nbh) shape.
117 |
118 | """
119 | return atom_distances(
120 | positions,
121 | neighbors,
122 | cell,
123 | cell_offsets,
124 | return_vecs=self.return_directions,
125 | normalize_vecs=True,
126 | neighbor_mask=neighbor_mask,
127 | )
128 |
129 |
130 | def triple_distances(
131 | positions,
132 | neighbors_j,
133 | neighbors_k,
134 | offset_idx_j=None,
135 | offset_idx_k=None,
136 | cell=None,
137 | cell_offsets=None,
138 | ):
139 | """
140 | Get all distances between atoms forming a triangle with the central atoms.
141 | Required e.g. for angular symmetry functions.
142 |
143 | Args:
144 | positions (torch.Tensor): Atomic positions
145 | neighbors_j (torch.Tensor): Indices of first neighbor in triangle
146 | neighbors_k (torch.Tensor): Indices of second neighbor in triangle
147 |         offset_idx_j (torch.Tensor): Indices for offsets of neighbors j (for PBC)
148 |         offset_idx_k (torch.Tensor): Indices for offsets of neighbors k (for PBC)
149 | cell (torch.tensor, optional): periodic cell of (N_b x 3 x 3) shape.
150 | cell_offsets (torch.Tensor, optional): offset of atom in cell coordinates
151 | with (N_b x N_at x N_nbh x 3) shape.
152 |
153 | Returns:
154 | torch.Tensor: Distance between central atom and neighbor j
155 | torch.Tensor: Distance between central atom and neighbor k
156 | torch.Tensor: Distance between neighbors
157 |
158 | """
159 | nbatch, _, _ = neighbors_k.size()
160 | idx_m = torch.arange(nbatch, device=positions.device, dtype=torch.long)[
161 | :, None, None
162 | ]
163 |
164 | pos_j = positions[idx_m, neighbors_j[:], :]
165 | pos_k = positions[idx_m, neighbors_k[:], :]
166 |
167 | if cell is not None:
168 | # Get the offsets into true cartesian values
169 | B, A, N, D = cell_offsets.size()
170 |
171 | cell_offsets = cell_offsets.view(B, A * N, D)
172 | offsets = cell_offsets.bmm(cell)
173 | offsets = offsets.view(B, A, N, D)
174 |
175 | # Get the offset values for j and k atoms
176 | B, A, T = offset_idx_j.size()
177 |
178 | # Collapse batch and atoms position for easier indexing
179 | offset_idx_j = offset_idx_j.view(B * A, T)
180 | offset_idx_k = offset_idx_k.view(B * A, T)
181 | offsets = offsets.view(B * A, -1, D)
182 |
183 |         # Construct auxiliary array for advanced indexing
184 | idx_offset_m = torch.arange(B * A, device=positions.device, dtype=torch.long)[
185 | :, None
186 | ]
187 |
188 |         # Restore proper dimensions
189 | offset_j = offsets[idx_offset_m, offset_idx_j[:]].view(B, A, T, D)
190 | offset_k = offsets[idx_offset_m, offset_idx_k[:]].view(B, A, T, D)
191 |
192 | # Add offsets
193 | pos_j = pos_j + offset_j
194 | pos_k = pos_k + offset_k
195 |
196 | # if positions.is_cuda:
197 | # idx_m = idx_m.pin_memory().cuda(async=True)
198 |
199 | # Get the real positions of j and k
200 | R_ij = pos_j - positions[:, :, None, :]
201 | R_ik = pos_k - positions[:, :, None, :]
202 | R_jk = pos_j - pos_k
203 |
204 | # + 1e-9 to avoid division by zero
205 | r_ij = torch.norm(R_ij, 2, 3) + 1e-9
206 | r_ik = torch.norm(R_ik, 2, 3) + 1e-9
207 | r_jk = torch.norm(R_jk, 2, 3) + 1e-9
208 |
209 | return r_ij, r_ik, r_jk
210 |
211 |
212 | class TriplesDistances(nn.Module):
213 | """
214 | Layer that gets all distances between atoms forming a triangle with the
215 | central atoms. Required e.g. for angular symmetry functions.
216 | """
217 |
218 | def __init__(self):
219 | super(TriplesDistances, self).__init__()
220 |
221 | def forward(self, positions, neighbors_j, neighbors_k):
222 | """
223 | Args:
224 | positions (torch.Tensor): Atomic positions
225 | neighbors_j (torch.Tensor): Indices of first neighbor in triangle
226 | neighbors_k (torch.Tensor): Indices of second neighbor in triangle
227 |
228 | Returns:
229 | torch.Tensor: Distance between central atom and neighbor j
230 | torch.Tensor: Distance between central atom and neighbor k
231 | torch.Tensor: Distance between neighbors
232 |
233 | """
234 | return triple_distances(positions, neighbors_j, neighbors_k)
235 |
236 |
237 | def neighbor_elements(atomic_numbers, neighbors):
238 | """
239 | Return the atomic numbers associated with the neighboring atoms. Can also
240 | be used to gather other properties by neighbors if different atom-wise
241 | Tensor is passed instead of atomic_numbers.
242 |
243 | Args:
244 | atomic_numbers (torch.Tensor): Atomic numbers (Nbatch x Nat x 1)
245 | neighbors (torch.Tensor): Neighbor indices (Nbatch x Nat x Nneigh)
246 |
247 | Returns:
248 | torch.Tensor: Atomic numbers of neighbors (Nbatch x Nat x Nneigh)
249 |
250 | """
251 | # Get molecules in batch
252 | n_batch = atomic_numbers.size()[0]
253 | # Construct auxiliary index
254 | idx_m = torch.arange(n_batch, device=atomic_numbers.device, dtype=torch.long)[
255 | :, None, None
256 | ]
257 | # Get neighbors via advanced indexing
258 | neighbor_numbers = atomic_numbers[idx_m, neighbors[:, :, :]]
259 | return neighbor_numbers
260 |
261 |
262 | class NeighborElements(nn.Module):
263 | """
264 | Layer to obtain the atomic numbers associated with the neighboring atoms.
265 | """
266 |
267 | def __init__(self):
268 | super(NeighborElements, self).__init__()
269 |
270 | def forward(self, atomic_numbers, neighbors):
271 | """
272 | Args:
273 | atomic_numbers (torch.Tensor): Atomic numbers (Nbatch x Nat x 1)
274 | neighbors (torch.Tensor): Neighbor indices (Nbatch x Nat x Nneigh)
275 |
276 | Returns:
277 | torch.Tensor: Atomic numbers of neighbors (Nbatch x Nat x Nneigh)
278 | """
279 | return neighbor_elements(atomic_numbers, neighbors)
280 |
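281 | # Hedged usage sketch (not part of the upstream file): pairwise distances for
282 | # a 3-atom molecule in which every atom lists the other two as neighbors.
283 | if __name__ == "__main__":
284 |     pos = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
285 |     nbh = torch.tensor([[[1, 2], [0, 2], [0, 1]]])
286 |     print(atom_distances(pos, nbh))   # (1, 3, 2) tensor of distances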
--------------------------------------------------------------------------------
/featurization/__pycache__/data_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Matgen-project/MOFNet/47132ace072a8446ef43379a9ab019743ff746e7/featurization/__pycache__/data_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/featurization/data_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from torch.utils.data import Dataset
9 | import json
10 | import copy
11 |
12 |
13 | FloatTensor = torch.FloatTensor
14 | LongTensor = torch.LongTensor
15 | IntTensor = torch.IntTensor
16 | DoubleTensor = torch.DoubleTensor
17 |
18 | def cutname(ori_name):
19 | ori_name = ori_name[:-3]
20 | if ori_name.endswith('out.'):
21 | ori_name = ori_name[:-4]
22 | elif ori_name.endswith('faps.'):
23 | ori_name = ori_name[:-5]
24 | return ori_name + 'p'
25 |
26 |
27 | def load_data_from_df(dataset_path, gas_type, pressure, add_dummy_node=True, use_global_features=False, return_names=False):
28 |
29 | data_df = pd.read_csv(dataset_path + f'/label_by_GCMC/{gas_type}_ads_all.csv',header=0)
30 | data_x = data_df['name'].values
31 | if pressure == 'all':
32 | data_y = data_df.iloc[:,1:].values
33 | else:
34 | data_y = data_df[pressure].values
35 |
36 | if data_y.dtype == np.float64:
37 | data_y = data_y.astype(np.float32)
38 |
39 | x_all, y_all, name_all = load_data_from_processed(dataset_path, data_x, data_y, add_dummy_node=add_dummy_node)
40 |
41 | if return_names:
42 | x_all = (x_all, name_all)
43 |
44 | if use_global_features:
45 | f_all = load_data_with_global_features(dataset_path, name_all, gas_type)
46 | if pressure == 'all':
47 | return x_all, f_all, y_all, data_df.columns.values[1:]
48 | return x_all, f_all, y_all
49 |
50 | if pressure == 'all':
51 | return x_all, y_all, data_df.columns.values[1:]
52 | return x_all, y_all
53 |
54 | def norm_str(ori):
55 | ori = ori.split('.')[0].split('-')
56 | if ori[-1] == 'clean':
57 | ori = ori[:-1]
58 | elif ori[-2] == 'clean':
59 | ori = ori[:-2]
60 | return '-'.join(ori[1:])
61 |
62 |
63 | def load_real_data(dataset_path, gas_type):
64 |
65 | data_df = pd.read_csv(dataset_path + f'/global_features/exp_geo_all.csv', header=0)
66 | data_x = data_df['name'].values
67 | data_y = data_df.iloc[:,1:].values
68 | global_dic = {}
69 | for x,y in zip(data_x, data_y):
70 | global_dic[x] = y
71 | with open(dataset_path + '/isotherm_data/all.json') as f:
72 | labels = json.load(f)[gas_type]['data']
73 | label_dict = {_['name']:_["isotherm_data"] for _ in labels}
74 |
75 | with open(dataset_path + f'/isotherm_data/{gas_type}.txt','r') as f:
76 | ls = f.readlines()
77 | ls = [_.strip().split() for _ in ls]
78 | X_all, y_all, f_all, p_all, n_all = [],[],[],[],[]
79 | for l in ls:
80 | if l[0] not in global_dic:
81 | continue
82 | gf = global_dic[l[0]]
83 | afm, adj, dist = pickle.load(open(dataset_path + f'/local_features/{l[0]}.cif.p', "rb"))
84 | afm, adj, dist = add_dummy_node_func(afm, adj, dist)
85 | iso = label_dict[norm_str(l[0])]
86 | p,y = [],[]
87 | for _ in iso:
88 | if _['pressure'] > 0:
89 | p.append(_['pressure'])
90 | y.append(_['adsorption'])
91 | if len(p) == 0:
92 | continue
93 | X_all.append([afm,adj,dist])
94 | f_all.append(gf)
95 | p_all.append(p)
96 | y_all.append(y)
97 | n_all.append(norm_str(l[0]))
98 | return X_all, f_all, y_all, p_all, n_all
99 |
100 |
101 |
102 |
103 |
104 | def load_data_with_global_features(dataset_path, processed_files, gas_type):
105 | global_feature_path = dataset_path + f'/global_features/{gas_type}_global_features_update.csv'
106 | data_df = pd.read_csv(global_feature_path,header=0)
107 | data_x = data_df.iloc[:, 0].values
108 | data_f = data_df.iloc[:,1:].values.astype(np.float32)
109 | data_dict = {}
110 | for i in range(data_x.shape[0]):
111 | data_dict[data_x[i]] = data_f[i]
112 | f_all = [data_dict[_] for _ in processed_files]
113 | return f_all
114 |
115 |
116 |
117 | def load_data_from_processed(dataset_path, processed_files, labels, add_dummy_node=True):
118 | x_all, y_all, name_all = [], [], []
119 |
120 | for files, label in zip(processed_files, labels):
121 |
122 | data_file = dataset_path + '/local_features/' + files + '.p'
123 | try:
124 | afm, adj, dist = pickle.load(open(data_file, "rb"))
125 | if add_dummy_node:
126 | afm, adj, dist = add_dummy_node_func(afm, adj, dist)
127 | x_all.append([afm, adj, dist])
128 | y_all.append([label])
129 | name_all.append(files)
130 |         except Exception:
131 |             # skip MOFs whose processed feature file is missing or unreadable
132 |             pass
133 | return x_all, y_all, name_all
134 |
135 | def add_dummy_node_func(node_features, adj_matrix, dist_matrix):
136 | m = np.zeros((node_features.shape[0] + 1, node_features.shape[1] + 1))
137 | m[1:, 1:] = node_features
138 | m[0, 0] = 1.
139 | node_features = m
140 |
141 | m = np.ones((adj_matrix.shape[0] + 1, adj_matrix.shape[1] + 1))
142 | m[1:, 1:] = adj_matrix
143 | adj_matrix = m
144 |
145 | m = np.full((dist_matrix.shape[0] + 1, dist_matrix.shape[1] + 1), 1e6)
146 | m[1:, 1:] = dist_matrix
147 | dist_matrix = m
148 |
149 | return node_features, adj_matrix, dist_matrix
150 |
151 |
152 | class MOF:
153 | def __init__(self, x, y, index, feature = None):
154 | self.node_features = x[0]
155 | self.adjacency_matrix = x[1]
156 | self.distance_matrix = x[2]
157 | self.y = y
158 | self.index = index
159 | self.global_feature = feature
160 |
161 |
162 | class MOFDataset(Dataset):
163 |
164 | def __init__(self, data_list):
165 | self.data_list = data_list
166 |
167 | def __len__(self):
168 | return len(self.data_list)
169 |
170 | def __getitem__(self, key):
171 | if type(key) == slice:
172 | return MOFDataset(self.data_list[key])
173 | return self.data_list[key]
174 |
175 |
176 | class RealMOFDataset(Dataset):
177 | def __init__(self, data_list, pressure_list, ori_point):
178 | self.data_list = data_list
179 | self.pressure_list = pressure_list
180 | self.ori_point = np.log(np.float32(ori_point))
181 | def __len__(self):
182 | return len(self.data_list)
183 | def __getitem__(self,key):
184 | if type(key) == slice:
185 | return RealMOFDataset(self.data_list[key], self.pressure_list[key], self.ori_point)
186 | tar_mol = copy.deepcopy(self.data_list[key])
187 | tar_p = np.log(self.pressure_list[key]) - self.ori_point
188 | tar_mol.global_feature = np.append(tar_mol.global_feature, tar_p)
189 |         # y already holds the full measured isotherm for this MOF
190 | return tar_mol
191 |
192 | class MOFDatasetPressureVer(Dataset):
193 |
194 | def __init__(self, data_list, pressure_list, mask_point=None, is_train=True, tar_point=None):
195 | self.data_list = data_list
196 | self.pressure_list = pressure_list
197 | self.mask_point = mask_point
198 | self.is_train = is_train
199 | self.tar_point = tar_point
200 | if is_train:
201 | self.use_idx = np.where(pressure_list != mask_point)[0]
202 | else:
203 | self.use_idx = np.where(pressure_list == tar_point)[0]
204 | self.calcMid()
205 |
206 | def __len__(self):
207 | return len(self.data_list)
208 |
209 | def toStr(self):
210 | return {"data_list":self.data_list,"pressure_list":self.pressure_list,"mask_point":self.mask_point,"is_train":self.is_train, "tar_point":self.tar_point}
211 | def __getitem__(self, key):
212 | if type(key) == slice:
213 |             return MOFDatasetPressureVer(self.data_list[key], self.pressure_list, self.mask_point, self.is_train)
214 | tar_mol = copy.deepcopy(self.data_list[key])
215 | if self.is_train:
216 | tar_p = self.float_pressure - self.mid
217 | tar_mol.global_feature = np.append(tar_mol.global_feature, tar_p)
218 | tar_mol.y = tar_mol.y[0]
219 | else:
220 | tar_idx = self.use_idx
221 | tar_p = self.float_pressure[tar_idx] - self.mid
222 | tar_mol.global_feature = np.append(tar_mol.global_feature, tar_p)
223 | tar_mol.y = [tar_mol.y[0][tar_idx]]
224 | return tar_mol
225 |
226 | def changeTarPoint(self,tar_point):
227 | self.tar_point = tar_point
228 | if not tar_point:
229 | self.is_train = True
230 | else:
231 | self.is_train = False
232 | if not self.is_train:
233 | self.use_idx = np.where(self.pressure_list == tar_point)[0]
234 |
235 | def calcMid(self):
236 |         self.float_pressure = np.log(self.pressure_list.astype(float))
237 |         self.mid = np.log(float(self.mask_point))
238 |
239 |
240 | def pad_array(array, shape, dtype=np.float32):
241 | padded_array = np.zeros(shape, dtype=dtype)
242 | padded_array[:array.shape[0], :array.shape[1]] = array
243 | return padded_array
244 |
245 |
246 | def mof_collate_func_gf(batch):
247 | adjacency_list, distance_list, features_list, global_features_list = [], [], [], []
248 | labels = []
249 |
250 | max_size = 0
251 | for molecule in batch:
252 | if type(molecule.y[0]) == np.ndarray:
253 | labels.append(molecule.y[0])
254 | else:
255 | labels.append(molecule.y)
256 | if molecule.adjacency_matrix.shape[0] > max_size:
257 | max_size = molecule.adjacency_matrix.shape[0]
258 |
259 | for molecule in batch:
260 | adjacency_list.append(pad_array(molecule.adjacency_matrix, (max_size, max_size)))
261 | distance_list.append(pad_array(molecule.distance_matrix, (max_size, max_size)))
262 | features_list.append(pad_array(molecule.node_features, (max_size, molecule.node_features.shape[1])))
263 | global_features_list.append(molecule.global_feature)
264 |
265 | return [FloatTensor(features) for features in (adjacency_list, features_list, distance_list, global_features_list, labels)]
266 |
267 |
268 | def construct_dataset(x_all, y_all):
269 | output = [MOF(data[0], data[1], i)
270 | for i, data in enumerate(zip(x_all, y_all))]
271 | return MOFDataset(output)
272 |
273 | def construct_dataset_gf(x_all, f_all, y_all):
274 | output = [MOF(data[0], data[2], i, data[1])
275 | for i, data in enumerate(zip(x_all, f_all, y_all))]
276 | return MOFDataset(output)
277 |
278 | def construct_dataset_gf_pressurever(x_all, f_all, y_all, pressure_list, is_train=True, mask_point=None, tar_point=None):
279 | output = [MOF(data[0], data[2], i, data[1])
280 | for i, data in enumerate(zip(x_all, f_all, y_all))]
281 | return MOFDatasetPressureVer(output, pressure_list, is_train=is_train, mask_point=mask_point,tar_point=tar_point)
282 |
283 | def construct_dataset_real(x_all, f_all, y_all, pressure_list, tar_point=None):
284 | output = [MOF(data[0], data[2], i, data[1])
285 | for i, data in enumerate(zip(x_all, f_all, y_all))]
286 | return RealMOFDataset(output, pressure_list, ori_point=tar_point)
287 |
288 | def construct_loader_gf(x,f,y, batch_size, shuffle=True):
289 | data_set = construct_dataset_gf(x, f, y)
290 | loader = torch.utils.data.DataLoader(dataset=data_set,
291 | batch_size=batch_size,
292 | num_workers=0,
293 | collate_fn=mof_collate_func_gf,
294 | pin_memory=True,
295 | shuffle=shuffle)
296 | return loader
297 |
298 | def construct_loader_gf_pressurever(data_set, batch_size, shuffle=True):
299 | loader = torch.utils.data.DataLoader(dataset=data_set,
300 | batch_size=batch_size,
301 | num_workers=0,
302 | collate_fn=mof_collate_func_gf,
303 | pin_memory=True,
304 | shuffle=shuffle)
305 | return loader
306 |
307 | class data_prefetcher():
308 | def __init__(self, loader):
309 | self.loader = iter(loader)
310 | self.stream = torch.cuda.Stream()
311 | self.preload()
312 |
313 | def preload(self):
314 | try:
315 | self.next_data = next(self.loader)
316 | except StopIteration:
317 | self.next_data = None
318 | return
319 | with torch.cuda.stream(self.stream):
320 | self.next_data = tuple(_.cuda(non_blocking=True) for _ in self.next_data)
321 |
322 | def next(self):
323 | torch.cuda.current_stream().wait_stream(self.stream)
324 | batch = self.next_data
325 | self.preload()
326 | return batch
327 |
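328 | # Hedged usage sketch (not part of the original file): the dummy node adds
329 | # one extra node with feature flag 1, full adjacency to every atom, and an
330 | # effectively infinite (1e6) distance; pad_array zero-pads to a batch shape.
331 | if __name__ == "__main__":
332 |     afm, adj, dist = add_dummy_node_func(np.ones((2, 3)), np.eye(2), np.zeros((2, 2)))
333 |     print(afm.shape, adj.shape, dist[0, 1])   # (3, 4) (3, 3) 1000000.0
334 |     print(pad_array(adj, (5, 5)).shape)       # (5, 5)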
--------------------------------------------------------------------------------
/image/3dstructgen-mof.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Matgen-project/MOFNet/47132ace072a8446ef43379a9ab019743ff746e7/image/3dstructgen-mof.png
--------------------------------------------------------------------------------
/image/Fig1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Matgen-project/MOFNet/47132ace072a8446ef43379a9ab019743ff746e7/image/Fig1.jpg
--------------------------------------------------------------------------------
/image/Fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Matgen-project/MOFNet/47132ace072a8446ef43379a9ab019743ff746e7/image/Fig1.png
--------------------------------------------------------------------------------
/model_shap.py:
--------------------------------------------------------------------------------
1 | import shap
2 | import torch
3 | from collections import defaultdict
4 | from featurization.data_utils import load_data_from_df, construct_loader_gf_pressurever, construct_dataset_gf_pressurever, data_prefetcher
5 | from models.transformer import make_model
6 | import numpy as np
7 | import os
8 | from argparser import parse_train_args
9 | import pickle
10 | from tqdm import tqdm
11 | from utils import *
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 | def gradient_shap(model, sample_loader, test_loader, batch_size):
16 | model.eval()
17 | model.set_adapter_dim(1)
18 | graph_reps, global_feas = [],[]
19 | for data in tqdm(sample_loader):
20 | adjacency_matrix, node_features, distance_matrix, global_features, y = (_.cpu() for _ in data)
21 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
22 | batch_mask = batch_mask.float()
23 | graph_rep = model.encode(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
24 | graph_reps.append(graph_rep)
25 | global_feas.append(global_features)
26 | graph_reps = torch.cat(graph_reps)
27 | global_feas = torch.cat(global_feas)
28 | e = shap.GradientExplainer(model.generator, [graph_reps, global_feas])
29 | shap_all = []
30 | for data in tqdm(test_loader):
31 | adjacency_matrix, node_features, distance_matrix, global_features, y = (_.cpu() for _ in data)
32 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
33 | batch_mask = batch_mask.float()
34 | graph_rep = model.encode(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
35 | ans = e.shap_values([graph_rep, global_features],nsamples=10)
36 | local_shap = np.abs(ans[0].sum(axis=1)).reshape(-1,1)
37 | global_shap = np.abs(ans[-1])[:,:9]
38 | shap_values = np.concatenate([local_shap, global_shap],axis=1)
39 | shap_all.append(shap_values)
40 | shap_all = np.concatenate(shap_all, axis=0)
41 | return shap_all
42 |
43 | if __name__ == '__main__':
44 | model_params = parse_train_args()
45 | device_ids = [0,1,2,3]
46 | X, f, y, p = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure='all',add_dummy_node = True,use_global_features = True)
47 | tar_idx = np.where(p==model_params['pressure'])[0][0]
48 | print(f'Loaded {len(X)} data.')
49 | y = np.array(y)
50 | mean = y[...,tar_idx].mean()
51 | std = y[...,tar_idx].std()
52 | y = (y - mean) / std
53 | f = np.array(f)
54 | fmean = f.mean(axis=0)
55 | fstd = f.std(axis=0)
56 | f = (f - fmean) / fstd
57 | batch_size = model_params['batch_size']
58 | fold_num = model_params['fold']
59 | idx_list = np.arange(len(X))
60 | set_seed(model_params['seed'])
61 | np.random.shuffle(idx_list)
62 | X = applyIndexOnList(X,idx_list)
63 | f = f[idx_list]
64 | y = y[idx_list]
65 |
66 |
67 |
68 | for fold_idx in range(1,2):
69 | set_seed(model_params['seed'])
70 | save_dir = model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}"
71 | ckpt_handler = CheckpointHandler(save_dir)
72 | state = ckpt_handler.checkpoint_best()
73 | model = make_model(**state['params'])
74 | model = torch.nn.DataParallel(model)
75 | model.load_state_dict(state['model'])
76 | model = model.module
77 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num,fold_idx)
78 | train_sample = construct_dataset_gf_pressurever(applyIndexOnList(X,train_idx), f[train_idx], y[train_idx],p, is_train=False, tar_point=model_params['pressure'],mask_point=model_params['pressure'])
79 | test_set = construct_dataset_gf_pressurever(applyIndexOnList(X,test_idx), f[test_idx], y[test_idx],p, is_train=False, tar_point=model_params['pressure'],mask_point=model_params['pressure'])
80 | shaps = {pres:[] for pres in [p[3]]}
81 | for pres in [p[3]]:
82 | train_sample.changeTarPoint(pres)
83 | test_set.changeTarPoint(pres)
84 | sample_loader = construct_loader_gf_pressurever(train_sample, batch_size, shuffle=False)
85 | test_loader = construct_loader_gf_pressurever(test_set, batch_size, shuffle=False)
86 | shap_values = gradient_shap(model, sample_loader, test_loader, batch_size)
87 | shaps[pres].append(shap_values)
88 |
89 | for pres in [p[3]]:
90 | shaps[pres] = np.concatenate(shaps[pres],axis=0)
91 |
92 | with open(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/shap_result_{p[3]}.p",'wb') as f:
93 | pickle.dump(shaps, f)
94 |
95 |
96 |
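97 | # Note: gradient_shap sums SHAP values over the pooled graph representation
98 | # into one "local" attribution per MOF and keeps the first nine global-
99 | # feature attributions (excluding the appended pressure feature).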
--------------------------------------------------------------------------------
/nist_test.py:
--------------------------------------------------------------------------------
1 | # Evaluate trained MOFNet fold models on experimental NIST isotherm data.
2 | import os
3 | import pandas as pd
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import time
8 | from featurization.data_utils import load_data_from_df, construct_loader_gf_pressurever, construct_dataset_gf_pressurever, data_prefetcher, load_real_data, construct_dataset_real
9 | from models.transformer import make_model
10 | from argparser import parse_train_args
11 | from utils import *
12 | import matplotlib.pyplot as plt
13 | from tqdm import tqdm
14 | import pickle
15 |
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 |
18 | def ensemble_test(models,data_loader, mean, std, img_dir, names, p_ori):
19 | os.makedirs(img_dir,exist_ok=True)
20 | for model in models:
21 | model.eval()
22 | batch_idx = 0
23 | p_ori = np.log(float(p_ori))
24 | ans = {}
25 | for data in tqdm(data_loader):
26 | adjacency_matrix, node_features, distance_matrix, global_features, y = data
27 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
28 | adapter_dim = global_features.shape[-1] - 9
29 | pressure = global_features[...,-adapter_dim:]
30 | outputs = []
31 | for model in models:
32 | model.module.set_adapter_dim(adapter_dim)
33 | output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features)
34 | outputs.append(output.cpu().detach().numpy().reshape(-1) * std + mean)
35 | y_tmp = y.cpu().detach().numpy().reshape(-1)
36 | futures_tmp = np.mean(np.array(outputs),axis=0)
37 | pres = pressure.cpu().detach().numpy().reshape(-1) + p_ori
38 |
39 | plt.xlabel('log pressure(Pa)')
40 | plt.ylabel('adsorption(mol/kg)')
41 | l1 = plt.scatter(pres, y_tmp, c ='r', marker = 'o')
42 | l2 = plt.scatter(pres, futures_tmp, c = 'g', marker = 'x')
43 | plt.legend(handles=[l1,l2],labels=['label','prediction'],loc='best')
44 | plt.savefig(f'{img_dir}/{names[batch_idx]}.png')
45 | plt.cla()
46 | ans[names[batch_idx]] = {
47 | 'pressure':np.exp(pres),
48 | 'label':y_tmp,
49 | 'pred':futures_tmp
50 | }
51 | batch_idx += 1
52 | return ans
53 |
54 | if __name__ == '__main__':
55 |
56 | model_params = parse_train_args()
57 | batch_size = 1
58 | device_ids = [0,1,2,3]
59 |
60 | save_dir = f"{model_params['save_dir']}/{model_params['gas_type']}_{model_params['pressure']}"
61 |
62 | with open(os.path.join(save_dir,f'offset.p'),'rb') as f:
63 | p_ori, mean, std, fmean, fstd = pickle.load(f)
64 |
65 | test_errors_all = []
66 |
67 | X, f, y, p, names = load_real_data(model_params['data_dir'], model_params['gas_type'])
68 | f = np.array(f)
69 | f = (f - fmean) / fstd
70 | test_errors = []
71 | models = []
72 | img_dir = os.path.join(model_params['img_dir'],model_params['gas_type'])
73 | predict_res = []
74 | for fold_idx in range(1,11):
75 | save_dir_fold = f"{save_dir}/Fold-{fold_idx}"
76 | state = CheckpointHandler(save_dir_fold).checkpoint_best()
77 | model = make_model(**state['params'])
78 | model = torch.nn.DataParallel(model)
79 | model.load_state_dict(state['model'])
80 | model = model.to(device)
81 | models.append(model)
82 | test_set = construct_dataset_real(X, f, y, p, p_ori)
83 | test_loader = construct_loader_gf_pressurever(test_set,1,shuffle=False)
84 | test_res = ensemble_test(models, test_loader, mean, std, img_dir, names, p_ori)
85 | with open(os.path.join(img_dir,f"results.p"),'wb') as f:
86 | pickle.dump(test_res,f)
87 |
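88 | # Note: ensemble_test averages the denormalized predictions of the ten fold
89 | # models and writes one measured-vs-predicted isotherm plot per MOF to img_dir.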
--------------------------------------------------------------------------------
/pressure_adapt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import time
7 | from featurization.data_utils import load_data_from_df, construct_loader_gf_pressurever, construct_dataset_gf_pressurever, data_prefetcher
8 | from models.transformer import make_model
9 | from argparser import parse_finetune_args
10 | import pickle
11 | from utils import *
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 |
16 | def train(model, epoch, train_loader, optimizer, scheduler, adapter_dim):
17 | model.train()
18 | loss = 0
19 | loss_all = 0
20 | prefetcher = data_prefetcher(train_loader)
21 | batch_idx = 0
22 | data = prefetcher.next()
23 | while data is not None:
24 | lr = scheduler.optimizer.param_groups[0]['lr']
25 | adjacency_matrix, node_features, distance_matrix, global_features, y = data
26 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
27 |
28 | optimizer.zero_grad()
29 | output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features)
30 | loss = F.mse_loss(output.reshape(-1), y.reshape(-1))
31 | loss.backward()
32 | step_loss = loss.cpu().detach().numpy()
33 | loss_all += step_loss
34 | optimizer.step()
35 | scheduler.step()
36 | print(f'After Step {batch_idx} of Epoch {epoch}, Loss = {step_loss}, Lr = {lr}')
37 | batch_idx += 1
38 | data = prefetcher.next()
39 | return loss_all / len(train_loader.dataset)
40 |
41 |
42 |
43 | def test(model, data_loader, mean, std, adapter_dim):
44 | model.eval()
45 | error = 0
46 | prefetcher = data_prefetcher(data_loader)
47 | batch_idx = 0
48 | data = prefetcher.next()
49 | futures, ys = None, None
50 | while data is not None:
51 | adjacency_matrix, node_features, distance_matrix, global_features, y = data
52 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
53 | output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features)
54 | output = output.reshape(y.shape).cpu().detach().numpy()
55 | y = y.cpu().detach().numpy()
56 | ys = y if ys is None else np.concatenate([ys,y], axis=0)
57 | futures = output if futures is None else np.concatenate([futures,output], axis=0)
58 | batch_idx += 1
59 | data = prefetcher.next()
60 |
61 | futures = np.array(futures) * std + mean
62 | ys = np.array(ys) * std + mean
63 | mae = np.mean(np.abs(futures - ys), axis=0)
64 | rmse = np.sqrt(np.mean((futures - ys)**2, axis=0))
65 | # pcc = np.corrcoef(futures,ys)[0][1]
66 | pcc = np.array([np.corrcoef(futures[:,i],ys[:,i])[0][1] for i in range(adapter_dim)])
67 | smape = 2 * np.mean(np.abs(futures-ys)/(np.abs(futures)+np.abs(ys)), axis=0)
68 |
69 | return {'MAE':mae, 'RMSE':rmse, 'PCC':pcc, 'sMAPE':smape}
70 |
71 |
72 |
73 | def get_RdecayFactor(warmup_step):
74 |
75 | def warmupRdecayFactor(step):
76 | if step < warmup_step:
77 | return step / warmup_step
78 | else:
79 | return (warmup_step / step) ** 0.5
80 |
81 | return warmupRdecayFactor
82 |
83 | if __name__ == '__main__':
84 |
85 | model_params = parse_finetune_args()
86 | batch_size = model_params['batch_size']
87 | device_ids = [0,1,2,3]
88 | logger = get_logger(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}")
89 | X, f, y, p = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure='all',add_dummy_node = True,use_global_features = True)
90 | tar_idx = np.where(p==model_params['pressure'])[0][0]
91 | print(f'Loaded {len(X)} data.')
92 | logger.info(f'Loaded {len(X)} data.')
93 | y = np.array(y)
94 | mean = y[...,tar_idx].mean()
95 | std = y[...,tar_idx].std()
96 | y = (y - mean) / std
97 | f = np.array(f)
98 | fmean = f.mean(axis=0)
99 | fstd = f.std(axis=0)
100 | f = (f - fmean) / fstd
101 |
102 | with open(os.path.join(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}",f'offset.p'),'wb') as file:
103 | pickle.dump((model_params['pressure'], mean, std, fmean, fstd), file)
104 |
105 | printParams(model_params,logger)
106 | fold_num = model_params['fold']
107 | epoch_num = model_params['epoch']
108 | test_errors = []
109 | idx_list = np.arange(len(X))
110 | set_seed(model_params['seed'])
111 | np.random.shuffle(idx_list)
112 | X = applyIndexOnList(X,idx_list)
113 | f = f[idx_list]
114 | y = y[idx_list]
115 | test_errors = []
116 |
117 | for fold_idx in range(1, fold_num + 1):
118 |
119 | set_seed(model_params['seed'])
120 | ori_state = CheckpointHandler(model_params['ori_dir']+f'/Fold-{fold_idx}').checkpoint_avg()
121 | ori_params = ori_state['params']
122 | ori_params['adapter_finetune'] = True
123 | model = make_model(**ori_params)
124 | model.set_adapter_dim(model_params['adapter_dim'])
125 | model = torch.nn.DataParallel(model, device_ids=device_ids)
126 | model.load_state_dict(ori_state['model'],strict=False)
127 | model = model.to(device)
128 | lr = model_params['lr']
129 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
130 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda = get_RdecayFactor(ori_params['warmup_step']))
131 | best_val_error = 0
132 |         # best_val_error tracks the lowest validation MAE seen so far
133 | test_error = 0
134 | best_epoch = -1
135 |
136 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num, fold_idx)
137 |
138 | train_set = construct_dataset_gf_pressurever(applyIndexOnList(X,train_idx), f[train_idx], y[train_idx],p, is_train=True, mask_point=model_params['pressure'])
139 |
140 |
141 | val_set = construct_dataset_gf_pressurever(applyIndexOnList(X,val_idx), f[val_idx], y[val_idx],p, is_train=True, mask_point=model_params['pressure'])
142 |
143 |
144 | test_set = construct_dataset_gf_pressurever(applyIndexOnList(X,test_idx), f[test_idx], y[test_idx],p, is_train=True, mask_point=model_params['pressure'])
145 |
146 | ckpt_handler = CheckpointHandler(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}")
147 |
148 | for epoch in range(1,epoch_num + 1):
149 | train_adapter_dim = model_params['adapter_dim']
150 | train_loader = construct_loader_gf_pressurever(train_set,batch_size)
151 | loss = train(model, epoch, train_loader,optimizer,scheduler, train_adapter_dim)
152 | val_loader = construct_loader_gf_pressurever(val_set, batch_size, shuffle=False)
153 | val_error = test(model, val_loader, mean, std, train_adapter_dim)['MAE']
154 | val_error_ = np.mean(val_error)
155 | ckpt_handler.save_model(model,ori_params,epoch,val_error_)
156 |
157 | if best_val_error == 0 or val_error_ <= best_val_error:
158 | print("Enter test step.\n")
159 | best_epoch = epoch
160 | best_val_error = val_error_
161 | test_loader = construct_loader_gf_pressurever(test_set, batch_size, shuffle=False)
162 | test_error = test(model, test_loader, mean, std, train_adapter_dim)
163 | for idx, pres in enumerate(p):
164 | for _ in test_error.keys():
165 | print('Fold: {:02d}, Epoch: {:03d}, Pressure: {}, Test {}: {:.7f}'.format(fold_idx, epoch, pres, _, test_error[_][idx]))
166 | logger.info('Fold: {:02d}, Epoch: {:03d}, Pressure: {}, Test {}: {:.7f}'.format(fold_idx, epoch, pres, _, test_error[_][idx]))
167 | lr = scheduler.optimizer.param_groups[0]['lr']
168 | p_str = 'Fold: {:02d}, Epoch: {:03d}, Val MAE: {:.7f}, Best Val MAE: {:.7f}'.format(fold_idx, epoch, val_error_, best_val_error)
169 | print(p_str)
170 | logger.info(p_str)
171 |
172 | for idx, pres in enumerate(p):
173 | for _ in test_error.keys():
174 | print('Fold: {:02d}, Epoch: {:03d}, Pressure: {}, Test {}: {:.7f}'.format(fold_idx, epoch, pres, _, test_error[_][idx]))
175 | logger.info('Fold: {:02d}, Epoch: {:03d}, Pressure: {}, Test {}: {:.7f}'.format(fold_idx, epoch, pres, _, test_error[_][idx]))
176 |
177 | test_errors.append(test_error)
178 |
179 | for idx, pres in enumerate(p):
180 | for _ in test_errors[0].keys():
181 | mt_list = [__[_][idx] for __ in test_errors]
182 | p_str = 'Pressure {}, Test {} of {:02d}-Folds: {:.7f}({:.7f})'.format(pres, _, fold_num, np.mean(mt_list), np.std(mt_list))
183 | print(p_str)
184 | logger.info(p_str)
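185 | 
186 | # Note on get_RdecayFactor above: the returned factor ramps linearly from 0 to
187 | # 1 over `warmup_step` optimizer steps, then decays as sqrt(warmup_step / step),
188 | # e.g. ~0.71x at twice and 0.5x at four times the warmup step count.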
--------------------------------------------------------------------------------
/process/README:
--------------------------------------------------------------------------------
1 | python process_csd_data.py ABIJUS
2 |
--------------------------------------------------------------------------------
/process/create_geo_features.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | data_dir=${root_dir_for_cifs}/
4 | cell_num=$2
5 | i=$1
6 | name=`echo $i |cut -d '.' -f 1`
7 | #argument 1 and 2
8 | #../network -ha -res ~/wsl/work/clean/$i >/dev/null
9 | ~/bin/network -ha -res ${data_dir}/${i} >/dev/null
10 | LCD=`head -n 1 ${data_dir}/${name}.res | awk '{print $4}'`
11 | PLD=`head -n 1 ${data_dir}/${name}.res | awk '{print $3}'`
12 | #exit
13 | rm ${data_dir}/${name}.res
14 | #argument 3 and 4 and 5
15 | #../network -ha -sa 1.86 1.86 2000 ~/wsl/work/clean/$i >/dev/null
16 | ~/bin/network -ha -sa 1.86 1.86 2000 ${data_dir}/${i} >/dev/null
17 | VSA=`head -n 1 ${data_dir}/${name}.sa | awk '{print $10}'`
18 | GSA=`head -n 1 ${data_dir}/${name}.sa | awk '{print $12}'`
19 | Density=`head -n 1 ${data_dir}/${name}.sa | awk '{print $6}'`
20 | chan_num_sa=`sed -n '2p' ${data_dir}/${name}.sa | awk '{print $2}'`
21 | rm ${data_dir}/${name}.sa
22 | #argument 6 and 7
23 | # ../network -ha -vol 0 0 50000 ~/wsl/work/clean/$i >/dev/null
24 | ~/bin/network -ha -vol 0 0 50000 ${data_dir}/${i} >/dev/null
25 | voidfract=`head -n 1 ${data_dir}/${name}.vol | awk '{print $10}'`
26 | porevolume=`head -n 1 ${data_dir}/${name}.vol | awk '{print $12}'`
27 | rm ${data_dir}/${name}.vol
28 |
29 | ~/bin/network -oms /tmp/${i}.cif >/dev/null
30 | oms=`tail -n 1 /tmp/${i}.oms | awk '{print $3}'`
31 | rm /tmp/${i}.oms
32 |
33 | chan_num_sa=`echo "scale=6;${chan_num_sa}/${cell_num}" | bc`
34 | porevolume=`echo "scale=6;${porevolume}/${cell_num}" | bc`
35 | oms=`echo "scale=6;${oms}/${cell_num}" | bc`
36 | #printf "%-20s%-10s%-10s%-10s%-10s%-10s%-10s%-15s%-15s\n" $name $Density $PLD $LCD $VSA $GSA $voidfract $porevolume $chan_num_sa $oms
37 | echo "$i,$LCD,$PLD,$VSA,$GSA,$Density,$voidfract,$porevolume,$chan_num_sa,$oms"
38 |
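39 | # Hedged usage note, inferred from the arguments above: requires the Zeo++
40 | # `network` binary in ~/bin and the root_dir_for_cifs environment variable:
41 | #   root_dir_for_cifs=/path/to/cifs ./create_geo_features.sh NAME.cif 8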
--------------------------------------------------------------------------------
/process/prepare_mof_features.py:
--------------------------------------------------------------------------------
1 | from ccdc.descriptors import MolecularDescriptors as MD, GeometricDescriptors as GD
2 | from ccdc.io import EntryReader
3 | csd = EntryReader('CSD')
4 | import ccdc.molecule
5 | import sys
6 | import os
7 | import numpy as np
8 | import math
9 |
10 | import pickle
11 | from tools.get_atom_features import get_atom_features
12 | from tools.get_bond_features import get_bond_features
13 | from tools.remove_waters import remove_waters, remove_single_oxygen, get_largest_components
14 |
15 | mol_name = sys.argv[1]
16 | mol = csd.molecule(mol_name)
17 |
18 | mol = remove_waters(mol)
19 | mol = remove_single_oxygen(mol)
20 | if len(mol.components) > 1:
21 | lg_id = get_largest_components(mol)
22 | mol = mol.components[lg_id]
23 |
24 | mol.remove_hydrogens()
25 |
26 | atom_features = np.array([get_atom_features(atom) for atom in mol.atoms])
27 | bond_features = get_bond_features(mol)
28 |
29 | mol_features = [atom_features, bond_features]
30 |
31 | save_path = './processed/' + mol_name + '.p'
32 |
33 | if not os.path.exists(save_path):
34 | pickle.dump(mol_features,open(save_path, "wb"))
35 |
36 |
37 |
--------------------------------------------------------------------------------
/process/process_csd_data.py:
--------------------------------------------------------------------------------
1 | from ccdc.descriptors import MolecularDescriptors as MD, GeometricDescriptors as GD
2 | from ccdc.io import EntryReader
3 | csd = EntryReader('CSD')
4 | import ccdc.molecule
5 | import sys
6 | import os
7 | import numpy as np
8 | import math
9 |
10 | import pickle
11 | from tools.get_atom_features import get_atom_features
12 | from tools.get_bond_features import get_bond_features
13 | from tools.remove_waters import remove_waters, remove_single_oxygen, get_largest_components
14 | 
15 | from sklearn.metrics import pairwise_distances
16 |
17 | mol_name = sys.argv[1]
18 | mol = csd.molecule(mol_name)
19 |
20 |
21 | # remove waters
22 | mol = remove_waters(mol)
23 | mol = remove_single_oxygen(mol)
24 |
25 | # remove other solvates, here we remove all small components.
26 |
27 | if len(mol.components) > 1:
28 | lg_id = get_largest_components(mol)
29 | mol = mol.components[lg_id]
30 |
31 | mol.remove_hydrogens()
32 |
33 | atom_features = np.array([get_atom_features(atom) for atom in mol.atoms])
34 | bond_matrix = get_bond_features(mol)
35 |
36 | pos_matrix = np.array([[atom.coordinates.x, atom.coordinates.y, atom.coordinates.z] for atom in mol.atoms])
37 | dist_matrix = pairwise_distances(pos_matrix)
38 |
39 | mol_features = [atom_features, bond_matrix, dist_matrix]
40 |
41 | save_path = '../data/processed/' + mol_name + '.p'
42 |
43 | if not os.path.exists(save_path):
44 | pickle.dump(mol_features,open(save_path, "wb"))
45 |
46 |
47 |
--------------------------------------------------------------------------------
/process/process_csd_data_baselines.py:
--------------------------------------------------------------------------------
1 | from ccdc.descriptors import MolecularDescriptors as MD, GeometricDescriptors as GD
2 | from ccdc.io import EntryReader
3 | csd = EntryReader('CSD')
4 | import ccdc.molecule
5 | import sys
6 | import os
7 | import numpy as np
8 | import math
9 |
10 | import pickle
11 | from tools.get_atom_features import get_atom_features
12 | from tools.get_bond_features import get_bond_features_en
13 | from tools.remove_waters import remove_waters, remove_single_oxygen, get_largest_components
14 | 
15 |
16 | mol_name = sys.argv[1]
17 | mol = csd.molecule(mol_name)
18 |
19 |
20 | # remove waters
21 | mol = remove_waters(mol)
22 | mol = remove_single_oxygen(mol)
23 |
24 | # remove other solvates, here we remove all small components.
25 |
26 | if len(mol.components) > 1:
27 | lg_id = get_largest_components(mol)
28 | mol = mol.components[lg_id]
29 |
30 | mol.remove_hydrogens()
31 |
32 | atom_features = np.array([get_atom_features(atom) for atom in mol.atoms])
33 | row, col = get_bond_features_en(mol)
34 |
35 | pos_matrix = np.array([[atom.coordinates.x, atom.coordinates.y, atom.coordinates.z] for atom in mol.atoms])
36 |
37 | mol_features = [atom_features, row, col, pos_matrix]
38 |
39 | save_path = '../data/processed_en/' + mol_name + '.p'
40 |
41 | os.makedirs('../data/processed_en/', exist_ok=True)
42 |
43 | if not os.path.exists(save_path):
44 | pickle.dump(mol_features,open(save_path, "wb"))
45 |
46 |
47 |
--------------------------------------------------------------------------------
/process/process_nist_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pandas as pd
4 | from tqdm import tqdm
5 | from argparse import ArgumentParser
6 |
7 | t_dict = {77:"Nitrogen", 273:"Carbon Dioxide", 298:"Methane"}
8 | unit_dic = {"mmol/g":1, "mol/g":0.001, 'mmol/kg':1000}
9 | m_dic = {"Nitrogen":28.0134, "Methane":16.0424, "Carbon Dioxide":44.0094}
10 | def get_unit_factor(unit,ads):
11 | if unit in unit_dic:
12 | return 1 / unit_dic[unit]
13 | elif unit == "cm3(STP)/g":
14 | return 1 / 22.4139
15 | elif unit == 'mg/g':
16 | return 1 / m_dic[ads]
17 | else:
18 | return None
19 |
20 | def norm_str(ori):
21 | ori = ori.split('.')[0].split('-')
22 | if ori[-1] == 'clean':
23 | ori = ori[:-1]
24 | elif ori[-2] == 'clean':
25 | ori = ori[:-2]
26 | return '-'.join(ori[1:])
27 |
28 | if __name__ == "__main__":
29 | parser = ArgumentParser()
30 | parser.add_argument('--data_dir', type=str,
31 | help='NIST data directory.')
32 | args = parser.parse_args()
33 | prefix = os.path.join(args.data_dir,'isotherm_data')
34 | pres_all = {"CH4":{"num":0, "data":[]}, "CO2":{"num":0, "data":[]}, "N2":{"num":0, "data":[]}}
35 | for gas_type in ['CH4','CO2','N2']:
36 | gas_pref = os.path.join(prefix, gas_type)
37 | files = os.listdir(gas_pref)
38 | for js in tqdm(files):
39 | with open(os.path.join(gas_pref, js), "r") as f:
40 | dic = json.load(f)
41 | name = dic['adsorbent']['name']
42 | t = dic['temperature']
43 | if t not in t_dict:
44 | continue
45 | tar_obj = t_dict[t]
46 | unit_factor = get_unit_factor(dic['adsorptionUnits'], tar_obj)
47 | if not unit_factor:
48 | continue
49 | tar_key = None
50 | for ads in dic['adsorbates']:
51 | if ads['name'] == tar_obj:
52 | tar_key = ads['InChIKey']
53 | break
54 | if not tar_key:
55 | continue
56 | pres_ret = []
57 | for d in dic['isotherm_data']:
58 | pres = d['pressure'] * 1e5
59 | for sd in d['species_data']:
60 | if sd['InChIKey'] == tar_key:
61 | tar_abs = sd['adsorption'] * unit_factor
62 | pres_ret.append({'pressure':pres, 'adsorption':tar_abs})
63 | pres_all[gas_type]['num'] += 1
64 | pres_all[gas_type]['data'].append({"name":name, "filename":js, "isotherm_data":pres_ret})
65 | with open(os.path.join(prefix,'all.json'),'w') as f:
66 | json.dump(pres_all, f)
67 |
68 |
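69 | # Hedged examples of the unit handling above (target unit: mmol/g):
70 | #   get_unit_factor('mg/g', 'Methane')        -> 1 / 16.0424  (molar mass)
71 | #   get_unit_factor('cm3(STP)/g', 'Nitrogen') -> 1 / 22.4139  (molar volume)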
--------------------------------------------------------------------------------
/process/tools/get_atom_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def get_atom_features(atom):
4 | attributes = []
5 | attributes += one_hot_vector(
6 | atom.atomic_number,
7 | [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, \
8 |          17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, \
9 | 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, \
10 | 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, \
11 | 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, \
12 | 73, 74, 75, 76, 77, 78, 79, 80, 81, 999]
13 | )
14 | # Connected numbers
15 | attributes += one_hot_vector(
16 | len(atom.neighbours),
17 | [0, 1, 2, 3, 4, 5, 6, 999]
18 | )
19 |
20 | # Test whether or not the atom is a hydrogen bond acceptor
21 | attributes.append(atom.is_acceptor)
22 | attributes.append(atom.is_chiral)
23 |
24 | # Test whether the atom is part of a ring system.
25 | attributes.append(atom.is_cyclic)
26 | attributes.append(atom.is_metal)
27 |
28 |     # Test whether this is a spiro atom.
29 | attributes.append(atom.is_spiro)
30 |
31 | return np.array(list(attributes), dtype=np.float32)
32 |
33 | def one_hot_vector(val, lst):
34 | """Converts a value to a one-hot vector based on options in lst"""
35 | if val not in lst:
36 | val = lst[-1]
37 | return map(lambda x: x == val, lst)
38 |
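39 | # Hedged usage sketch: values outside the list fall into the final 999 bucket.
40 | # list(one_hot_vector(6, [1, 6, 999]))  -> [False, True, False]
41 | # list(one_hot_vector(83, [1, 6, 999])) -> [False, False, True]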
--------------------------------------------------------------------------------
/process/tools/get_bond_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 |
4 | def get_bond_features(mol):
5 | """Calculate bond features.
6 |
7 | Args:
8 |         mol (ccdc.molecule.Molecule): a CSD molecule object.
9 | 
10 |     Returns:
11 |         np.ndarray: bond adjacency matrix (identity plus entries for bonded
12 |         atom pairs); the per-bond distance code below is currently disabled.
13 | """
14 | adj_matrix = np.eye(len(mol.atoms))
15 | dis_matrix = []
16 |
17 | for bond in mol.bonds:
18 | atom1,atom2 = bond.atoms
19 | # construct atom matrix.
20 | adj_matrix[atom1.index, atom2.index] = adj_matrix[atom2.index, atom1.index] = 1
21 |
22 | # calculate bond distance.
23 | #print(atom1,atom2)
24 | #a_array = [atom1.coordinates.x, atom1.coordinates.y, atom1.coordinates.z]
25 | #b_array = [atom2.coordinates.x, atom2.coordinates.y, atom2.coordinates.z]
26 | #bond_length = calc_distance(a_array, b_array)
27 | #dis_matrix.append(bond_length)
28 |
29 | return adj_matrix
30 |
31 | def get_bond_features_en(mol):
32 | """Calculate bond features.
33 |
34 | Args:
35 |         mol (ccdc.molecule.Molecule): a CSD molecule object.
36 |
37 | Returns:
38 |         bond adjacency in COO format as (row, col) index lists.
39 | """
40 | row, col = [], []
41 |
42 | for bond in mol.bonds:
43 | atom1,atom2 = bond.atoms
44 | # construct atom matrix.
45 | row.append(atom1.index)
46 | col.append(atom2.index)
47 | row.append(atom2.index)
48 | col.append(atom1.index)
49 |
50 | return row, col
51 |
52 | # function to obtain bond distance
53 | def calc_distance(a_array, b_array):
54 | delt_d = np.array(a_array) - np.array(b_array)
55 | distance = math.sqrt(delt_d[0]**2 + delt_d[1]**2 + delt_d[2]**2)
56 | return round(distance,3)
57 |
58 |
59 |
--------------------------------------------------------------------------------
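Note: `get_bond_features_en` emits each undirected bond as two directed edges in COO order. A toy run with stand-in objects (the `Atom`/`Bond` namedtuples merely mimic the ccdc attributes used above):

```python
from collections import namedtuple

Atom = namedtuple('Atom', 'index')
Bond = namedtuple('Bond', 'atoms')

# a three-atom chain with bonds 0-1 and 1-2
bonds = [Bond((Atom(0), Atom(1))), Bond((Atom(1), Atom(2)))]

row, col = [], []
for bond in bonds:
    atom1, atom2 = bond.atoms
    row += [atom1.index, atom2.index]  # each bond contributes both directions
    col += [atom2.index, atom1.index]

print(row)  # [0, 1, 1, 2]
print(col)  # [1, 0, 2, 1]
```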
/process/tools/remove_waters.py:
--------------------------------------------------------------------------------
1 | import ccdc.molecule
2 |
3 | def get_largest_components(m):
4 |     # collect (atom_count, identifier) pairs for every component
5 |     s = []
6 |     for c in m.components:
7 |         n = len(c.atoms)
8 |         id_n = int(str(c.identifier))
9 |         s.append((n, id_n))
10 |     # after sorting, the component with the most atoms comes last
11 |     largest_id = sorted(s)[-1][1] - 1  # identifier -> 0-based index
12 |
13 |     return largest_id
14 |
15 | def remove_waters(m):
16 | keep = []
17 | waters = 0
18 | for s in m.components:
19 | ats = [at.atomic_symbol for at in s.atoms]
20 | if len(ats) == 3:
21 | ats.sort()
22 |             if ats[0] == 'H' and ats[1] == 'H' and ats[2] == 'O':  # two H and one O -> a free water molecule
23 | waters += 1
24 | else:
25 | keep.append(s)
26 | else:
27 | keep.append(s)
28 | new = ccdc.molecule.Molecule(m.identifier)
29 | for k in keep:
30 | new.add_molecule(k)
31 | return new
32 |
33 | def remove_single_oxygen(m):
34 |     keep = []
35 |     oxygens = 0
36 |     for s in m.components:
37 |         ats = [at.atomic_symbol for at in s.atoms]
38 |         # drop components that consist of a single oxygen atom
39 |         if len(ats) == 1 and ats[0] == 'O':
40 |             oxygens += 1
41 |         else:
42 |             keep.append(s)
43 |     new = ccdc.molecule.Molecule(m.identifier)
44 |     for k in keep:
45 |         new.add_molecule(k)
46 |     return new
47 |
--------------------------------------------------------------------------------
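Note: these helpers operate on `ccdc.molecule.Molecule` objects. Typical usage might look like the sketch below; it assumes a licensed CSD installation, that the two functions above are in scope, and it uses a made-up refcode:

```python
from ccdc import io  # CSD Python API (licensed)

csd_reader = io.MoleculeReader('CSD')
mol = csd_reader.molecule('ABAVIJ')  # hypothetical refcode, for illustration only
mol = remove_waters(mol)             # drop free-water components
mol = remove_single_oxygen(mol)      # drop isolated oxygen atoms
print(len(mol.components), 'components remain')
```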
/requirements.txt:
--------------------------------------------------------------------------------
1 | ase==3.21.1
2 | backcall==0.2.0
3 | certifi==2020.12.5
4 | charset-normalizer==2.0.6
5 | cloudpickle==2.0.0
6 | cycler==0.10.0
7 | decorator==5.1.0
8 | future==0.18.2
9 | googledrivedownloader==0.4
10 | greenlet==1.1.0
11 | idna==3.2
12 | importlib-metadata==4.0.1
13 | ipython==7.27.0
14 | isodate==0.6.0
15 | jedi==0.18.0
16 | Jinja2==3.0.3
17 | joblib==1.0.1
18 | kiwisolver==1.3.1
19 | llvmlite==0.37.0
20 | MarkupSafe==2.0.1
21 | matplotlib==3.4.2
22 | matplotlib-inline==0.1.3
23 | monty==2021.8.17
24 | mpmath==1.2.1
25 | networkx==2.6.3
26 | numba==0.54.1
27 | numpy==1.20.3
28 | olefile==0.46
29 | packaging==21.3
30 | palettable==3.3.0
31 | pandas==1.1.5
32 | parso==0.8.2
33 | pexpect==4.8.0
34 | pickleshare==0.7.5
35 | Pillow==8.1.2
36 | plotly==5.3.1
37 | prompt-toolkit==3.0.20
38 | ptyprocess==0.7.0
39 | pycairo==1.20.0
40 | Pygments==2.10.0
41 | pymatgen==2022.0.0
42 | pyparsing==2.4.7
43 | python-dateutil==2.8.1
44 | python-louvain==0.15
45 | pytz==2021.1
46 | PyYAML==6.0
47 | rdflib==5.0.0
48 | reportlab==3.5.67
49 | requests==2.26.0
50 | ruamel.yaml==0.17.16
51 | ruamel.yaml.clib==0.2.6
52 | scikit-learn==0.24.2
53 | scipy==1.6.3
54 | seaborn==0.11.2
55 | shap==0.40.0
56 | six==1.16.0
57 | sklearn==0.0
58 | slicer==0.0.7
59 | spglib==1.16.2
60 | SQLAlchemy==1.4.15
61 | sympy==1.9
62 | tabulate==0.8.9
63 | tenacity==8.0.1
64 | threadpoolctl==2.1.0
65 | torch==1.8.1+cu102
66 | torch-cluster==1.5.9
67 | torch-geometric==2.0.3
68 | torch-scatter==2.0.8
69 | torch-sparse==0.6.12
70 | torch-spline-conv==1.2.1
71 | torchaudio==0.8.1
72 | torchvision==0.6.0+cu101
73 | tornado==6.1
74 | tqdm==4.62.2
75 | traitlets==5.1.0
76 | typing-extensions==3.10.0.0
77 | uncertainties==3.1.6
78 | urllib3==1.26.7
79 | wcwidth==0.2.5
80 | yacs==0.1.8
81 | zipp==3.4.1
82 |
--------------------------------------------------------------------------------
/train_baselines.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import time
7 | from baselines.data_utils import load_data_from_df, construct_loader, data_prefetcher
8 | from baselines import make_baseline_model
9 | from argparser import parse_baseline_args
10 | from utils import *
11 |
12 | model_params = parse_baseline_args()
13 |
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 |
18 | def warmupRdecayFactor(step):
19 | warmup_step = model_params['warmup_step']
20 | if step < warmup_step:
21 | return step / warmup_step
22 | else:
23 | return (warmup_step / step) ** 0.5
24 |
25 |
26 | def train(epoch, train_loader, optimizer, scheduler, use_adj=True):
27 | model.train()
28 | loss = 0
29 | loss_all = 0
30 | prefetcher = data_prefetcher(train_loader, device)
31 | batch_idx = 0
32 | data = prefetcher.next()
33 | while data is not None:
34 | lr = scheduler.optimizer.param_groups[0]['lr']
35 | if use_adj:
36 | node_features, pos, adj, global_feature, y = data
37 | else:
38 | node_features, pos, nbh, nbh_mask, global_feature, y = data
39 | adj = (nbh, nbh_mask)
40 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
41 |
42 | optimizer.zero_grad()
43 | output = model(node_features, batch_mask, pos, adj, global_feature)
44 | y = y.squeeze(-1)
45 | loss = F.mse_loss(output, y)
46 | loss.backward()
47 | step_loss = loss.cpu().detach().numpy()
48 | loss_all += step_loss
49 | optimizer.step()
50 | scheduler.step()
51 | print(f'After Step {batch_idx} of Epoch {epoch}, Loss = {step_loss}, Lr = {lr}')
52 | batch_idx += 1
53 | data = prefetcher.next()
54 | return loss_all / len(train_loader.dataset)
55 |
56 |
57 | def test(data_loader, mean, std, use_adj=True):
58 | model.eval()
59 | error = 0
60 | prefetcher = data_prefetcher(data_loader, device)
61 | batch_idx = 0
62 | data = prefetcher.next()
63 | futures, ys = [], []
64 | while data is not None:
65 |
66 | if use_adj:
67 | node_features, pos, adj, global_feature, y = data
68 | else:
69 | node_features, pos, nbh, nbh_mask, global_feature, y = data
70 | adj = (nbh, nbh_mask)
71 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
72 |
73 |         with torch.no_grad():  # evaluation only; no gradients or optimizer involvement needed
74 |             output = model(node_features, batch_mask, pos, adj, global_feature)
75 | ys += list(y.cpu().detach().numpy().reshape(-1))
76 | futures += list(output.cpu().detach().numpy().reshape(-1))
77 | batch_idx += 1
78 | data = prefetcher.next()
79 |
80 | futures = np.array(futures) * std + mean
81 | ys = np.array(ys) * std + mean
82 | mae = np.mean(np.abs(futures - ys))
83 | rmse = np.sqrt(np.mean((futures - ys)**2))
84 | pcc = np.corrcoef(futures,ys)[0][1]
85 | smape = 2 * np.mean(np.abs(futures-ys)/(np.abs(futures)+np.abs(ys)))
86 |
87 | return {'MAE':mae, 'RMSE':rmse, 'PCC':pcc, 'sMAPE':smape}
88 |
89 | if __name__ == '__main__':
90 |
91 | model_name = model_params['model_name']
92 |     if model_name in ('egnn', 'dimenetpp'):
93 | use_adj = True
94 | else:
95 | use_adj = False
96 | batch_size = model_params['batch_size']
97 | device_ids = [0,1,2,3]
98 | logger = get_logger(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}")
99 | X, f, y = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure=model_params['pressure'],use_global_features = True)
100 | print(f'Loaded {len(X)} data.')
101 | logger.info(f'Loaded {len(X)} data.')
102 | y = np.array(y)
103 | mean = y.mean()
104 | std = y.std()
105 | y = (y - mean) / std
106 | f = np.array(f)
107 | fmean = f.mean(axis=0)
108 | fstd = f.std(axis=0)
109 | f = (f - fmean) / fstd
110 |
111 | model_params['d_atom'] = X[0][0].shape[1]
112 | model_params['d_feature'] = f.shape[-1]
113 |
114 | printParams(model_params,logger)
115 | fold_num = model_params['fold']
116 | epoch_num = model_params['epoch']
117 | test_errors = []
118 | idx_list = np.arange(len(X))
119 | set_seed(model_params['seed'])
120 | np.random.shuffle(idx_list)
121 | X = applyIndexOnList(X,idx_list)
122 | f = f[idx_list]
123 | y = y[idx_list]
124 |
125 | for fold_idx in range(1,fold_num + 1):
126 | set_seed(model_params['seed'])
127 | model = make_baseline_model(**model_params)
128 | model = torch.nn.DataParallel(model, device_ids=device_ids)
129 | model = model.to(device)
130 | lr = model_params['lr']
131 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
132 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda = warmupRdecayFactor)
133 | best_val_error = 0
134 | test_error = 0
135 | best_epoch = -1
136 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num,fold_idx)
137 |
138 | train_loader = construct_loader(applyIndexOnList(X,train_idx), f[train_idx], y[train_idx],batch_size, shuffle=True, use_adj=use_adj)
139 | val_loader = construct_loader(applyIndexOnList(X,val_idx), f[val_idx], y[val_idx],batch_size, shuffle=False, use_adj=use_adj)
140 | test_loader = construct_loader(applyIndexOnList(X,test_idx),f[test_idx], y[test_idx],batch_size, shuffle=False, use_adj=use_adj)
141 |
142 | ckpt_handler = CheckpointHandler(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}")
143 |
144 | for epoch in range(1,epoch_num + 1):
145 | loss = train(epoch,train_loader,optimizer,scheduler, use_adj=use_adj)
146 | val_error = test(val_loader, mean, std, use_adj=use_adj)['MAE']
147 | ckpt_handler.save_model(model,model_params,epoch,val_error)
148 | if best_val_error == 0 or val_error <= best_val_error:
149 | print("Enter test step.\n")
150 | best_epoch = epoch
151 | test_error = test(test_loader, mean, std, use_adj=use_adj)
152 | best_val_error = val_error
153 | state = {"params":model_params, "epoch":epoch, "model":model.state_dict()}
154 | lr = scheduler.optimizer.param_groups[0]['lr']
155 |
156 | epoch_op_str = 'Fold: {:02d}, Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, \
157 | Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}, Best Val MAE {:.7f}(epoch {:03d})'.format(fold_idx, epoch, lr, loss, val_error, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE'], best_val_error, best_epoch)
158 |
159 | print(epoch_op_str)
160 |
161 | logger.info(epoch_op_str)
162 |
163 | test_errors.append(test_error)
164 | print('Fold: {:02d}, Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}'.format(fold_idx, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE']))
165 | logger.info('Fold: {:02d}, Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}'.format(fold_idx, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE']))
166 | for _ in test_errors[0].keys():
167 | err_mean = np.mean([__[_] for __ in test_errors])
168 | err_std = np.std([__[_] for __ in test_errors])
169 | print('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
170 | logger.info('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
171 |
--------------------------------------------------------------------------------
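Note: `warmupRdecayFactor` implements a linear warm-up followed by inverse-square-root decay, peaking at a multiplier of 1.0 at `warmup_step`. A standalone sketch of the resulting multipliers (the `warmup_step` value of 1000 is illustrative):

```python
def warmup_rdecay_factor(step, warmup_step=1000):
    # linear ramp to 1.0, then decay as sqrt(warmup_step / step)
    if step < warmup_step:
        return step / warmup_step
    return (warmup_step / step) ** 0.5

for step in (100, 500, 1000, 4000, 16000):
    print(step, round(warmup_rdecay_factor(step), 3))
# 100 0.1 | 500 0.5 | 1000 1.0 | 4000 0.5 | 16000 0.25
```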
/train_ml.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn import tree, svm, ensemble
3 | from featurization.data_utils import load_data_from_df, construct_loader_gf, data_prefetcher
4 | from argparser import parse_train_args,parse_ml_args
5 | from utils import *
6 |
7 | def get_metric_dict(predicted, ground_truth):
8 | mae = np.mean(np.abs(predicted - ground_truth))
9 | smape = np.mean(np.abs(predicted - ground_truth) / ((np.abs(ground_truth) + np.abs(predicted)) / 2))
10 | pcc = np.corrcoef(predicted, ground_truth)[0][1]
11 | rmse = np.sqrt(np.mean((predicted - ground_truth) ** 2))
12 | return {'MAE':mae, 'sMAPE':smape, 'PCC': pcc, 'RMSE':rmse}
13 |
14 | if __name__ == '__main__':
15 |
16 | model_params = parse_ml_args()
17 | device_ids = [0,1,2,3]
18 | logger = get_logger(model_params['save_dir'] + f"/{model_params['ml_type']}/{model_params['gas_type']}_{model_params['pressure']}/")
19 | X, f, y = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure=model_params['pressure'],add_dummy_node = True,use_global_features = True)
20 | print(f'Loaded {len(X)} data.')
21 | logger.info(f'Loaded {len(X)} data.')
22 | y = np.array(y).reshape(-1)
23 | mean = y.mean()
24 | std = y.std()
25 | y = (y - mean) / std
26 | f = np.array(f)
27 | fmean = f.mean(axis=0)
28 | fstd = f.std(axis=0)
29 | f = (f - fmean) / fstd
30 |
31 |     Xs = [np.mean(x[0][:, 1:], axis=0) for x in X]  # mean-pool per-atom features into one fixed-size vector per MOF
32 |     f = np.concatenate((Xs, f), axis=1)
33 |
34 | printParams(model_params,logger)
35 | fold_num = model_params['fold']
36 | test_errors = []
37 | idx_list = np.arange(len(X))
38 | set_seed(model_params['seed'])
39 | np.random.shuffle(idx_list)
40 | X = applyIndexOnList(X,idx_list)
41 | f = f[idx_list]
42 | y = y[idx_list]
43 |
44 | for fold_idx in range(1,fold_num + 1):
45 | set_seed(model_params['seed'])
46 |
47 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num,fold_idx)
48 |
49 | train_f,train_y = f[train_idx], y[train_idx]
50 | test_f,test_y = f[test_idx], y[test_idx]
51 |
52 | if model_params['ml_type'] == 'RF':
53 |
54 | model = ensemble.RandomForestRegressor(n_estimators=100,criterion='mse',min_samples_split=2,min_samples_leaf=1,max_features='auto')
55 |
56 | elif model_params['ml_type'] == 'SVR':
57 |
58 | model = svm.SVR()
59 |
60 | elif model_params['ml_type'] == 'DT':
61 |
62 | model = tree.DecisionTreeRegressor()
63 |
64 | elif model_params['ml_type'] == 'GBRT':
65 |
66 | model = ensemble.GradientBoostingRegressor()
67 |         else: raise ValueError(f"Unknown ml_type: {model_params['ml_type']}")
68 | model.fit(train_f,train_y)
69 |
70 | future = model.predict(test_f) * std + mean
71 |
72 | test_y = test_y * std + mean
73 | test_error = get_metric_dict(future, test_y)
74 | for _ in test_error.keys():
75 | print('Fold: {:02d}, Test {}: {:.7f}'.format(fold_idx, _, test_error[_]))
76 | logger.info('Fold: {:02d}, Test {}: {:.7f}'.format(fold_idx, _, test_error[_]))
77 | test_errors.append(test_error)
78 | for _ in test_errors[0].keys():
79 | err_mean = np.mean([__[_] for __ in test_errors])
80 | err_std = np.std([__[_] for __ in test_errors])
81 | print('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
82 | logger.info('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
--------------------------------------------------------------------------------
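Note: a hand-computed check of `get_metric_dict` on toy arrays, to make the metric definitions concrete:

```python
import numpy as np

predicted    = np.array([1.0, 2.0, 3.0])
ground_truth = np.array([1.0, 2.0, 5.0])

mae   = np.mean(np.abs(predicted - ground_truth))          # (0 + 0 + 2) / 3 ~ 0.667
rmse  = np.sqrt(np.mean((predicted - ground_truth) ** 2))  # sqrt(4 / 3)    ~ 1.155
# sMAPE: error relative to the mean of |prediction| and |ground truth|
smape = np.mean(np.abs(predicted - ground_truth) /
                ((np.abs(ground_truth) + np.abs(predicted)) / 2))  # 0.5 / 3 ~ 0.167
print(mae, rmse, smape)
```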
/train_mofnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import time
7 | from featurization.data_utils import load_data_from_df, construct_loader_gf, data_prefetcher
8 | from models.transformer import make_model
9 | from argparser import parse_train_args
10 | from utils import *
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 | def warmupRdecayFactor(step):
15 | warmup_step = model_params['warmup_step']
16 | if step < warmup_step:
17 | return step / warmup_step
18 | else:
19 | return (warmup_step / step) ** 0.5
20 |
21 |
22 | def train(epoch, train_loader, optimizer, scheduler):
23 | model.train()
24 | loss = 0
25 | loss_all = 0
26 | prefetcher = data_prefetcher(train_loader)
27 | batch_idx = 0
28 | data = prefetcher.next()
29 | while data is not None:
30 | lr = scheduler.optimizer.param_groups[0]['lr']
31 | adjacency_matrix, node_features, distance_matrix, global_features, y = data
32 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
33 |
34 | optimizer.zero_grad()
35 | output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features)
36 | loss = F.mse_loss(output, y)
37 | loss.backward()
38 | step_loss = loss.cpu().detach().numpy()
39 | loss_all += step_loss
40 | optimizer.step()
41 | scheduler.step()
42 | print(f'After Step {batch_idx} of Epoch {epoch}, Loss = {step_loss}, Lr = {lr}')
43 | batch_idx += 1
44 | data = prefetcher.next()
45 | return loss_all / len(train_loader.dataset)
46 |
47 | @torch.no_grad()  # evaluation only; disable gradient tracking
48 | def test(data_loader, mean, std):
49 | model.eval()
50 | error = 0
51 | prefetcher = data_prefetcher(data_loader)
52 | batch_idx = 0
53 | data = prefetcher.next()
54 | futures, ys = [], []
55 | while data is not None:
56 | adjacency_matrix, node_features, distance_matrix, global_features, y = data
57 | batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
58 | output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, global_features)
59 | ys += list(y.cpu().detach().numpy().reshape(-1))
60 | futures += list(output.cpu().detach().numpy().reshape(-1))
61 | batch_idx += 1
62 | data = prefetcher.next()
63 |
64 | futures = np.array(futures) * std + mean
65 | ys = np.array(ys) * std + mean
66 | mae = np.mean(np.abs(futures - ys))
67 | rmse = np.sqrt(np.mean((futures - ys)**2))
68 | pcc = np.corrcoef(futures,ys)[0][1]
69 | smape = 2 * np.mean(np.abs(futures-ys)/(np.abs(futures)+np.abs(ys)))
70 |
71 | return {'MAE':mae, 'RMSE':rmse, 'PCC':pcc, 'sMAPE':smape}
72 |
73 | if __name__ == '__main__':
74 |
75 | model_params = parse_train_args()
76 | batch_size = model_params['batch_size']
77 | device_ids = [0,1,2,3]
78 | logger = get_logger(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}")
79 | X, f, y = load_data_from_df(model_params['data_dir'],gas_type=model_params['gas_type'], pressure=model_params['pressure'],add_dummy_node = True,use_global_features = True)
80 | print(f'Loaded {len(X)} data.')
81 | logger.info(f'Loaded {len(X)} data.')
82 | y = np.array(y)
83 | mean = y.mean()
84 | std = y.std()
85 | y = (y - mean) / std
86 | f = np.array(f)
87 | fmean = f.mean(axis=0)
88 | fstd = f.std(axis=0)
89 | f = (f - fmean) / fstd
90 |
91 | model_params['d_atom'] = X[0][0].shape[1]
92 | model_params['d_feature'] = f.shape[-1]
93 |
94 | printParams(model_params,logger)
95 | fold_num = model_params['fold']
96 | epoch_num = model_params['epoch']
97 | test_errors = []
98 | idx_list = np.arange(len(X))
99 | set_seed(model_params['seed'])
100 | np.random.shuffle(idx_list)
101 | X = applyIndexOnList(X,idx_list)
102 | f = f[idx_list]
103 | y = y[idx_list]
104 |
105 | for fold_idx in range(1,fold_num + 1):
106 | set_seed(model_params['seed'])
107 | model = make_model(**model_params)
108 | model = torch.nn.DataParallel(model, device_ids=device_ids)
109 | model = model.to(device)
110 | lr = model_params['lr']
111 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
112 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda = warmupRdecayFactor)
113 | best_val_error = 0
114 | test_error = 0
115 | best_epoch = -1
116 | train_idx, val_idx, test_idx = splitdata(len(X),fold_num,fold_idx)
117 |
118 | train_loader = construct_loader_gf(applyIndexOnList(X,train_idx), f[train_idx], y[train_idx],batch_size)
119 | val_loader = construct_loader_gf(applyIndexOnList(X,val_idx), f[val_idx], y[val_idx],batch_size)
120 | test_loader = construct_loader_gf(applyIndexOnList(X,test_idx),f[test_idx], y[test_idx],batch_size)
121 |
122 | ckpt_handler = CheckpointHandler(model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}")
123 |
124 | for epoch in range(1,epoch_num + 1):
125 | loss = train(epoch,train_loader,optimizer,scheduler)
126 | val_error = test(val_loader, mean, std)['MAE']
127 | ckpt_handler.save_model(model,model_params,epoch,val_error)
128 | if best_val_error == 0 or val_error <= best_val_error:
129 | print("Enter test step.\n")
130 | best_epoch = epoch
131 | test_error = test(test_loader, mean, std)
132 | best_val_error = val_error
133 | state = {"params":model_params, "epoch":epoch, "model":model.state_dict()}
134 | # torch.save(state, model_params['save_dir'] + f"/{model_params['gas_type']}_{model_params['pressure']}/Fold-{fold_idx}.pt")
135 | lr = scheduler.optimizer.param_groups[0]['lr']
136 |
137 | epoch_op_str = 'Fold: {:02d}, Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, \
138 | Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}, Best Val MAE {:.7f}(epoch {:03d})'.format(fold_idx, epoch, lr, loss, val_error, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE'], best_val_error, best_epoch)
139 |
140 | print(epoch_op_str)
141 |
142 | logger.info(epoch_op_str)
143 |
144 | test_errors.append(test_error)
145 | print('Fold: {:02d}, Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}'.format(fold_idx, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE']))
146 | logger.info('Fold: {:02d}, Test MAE: {:.7f}, Test RMSE: {:.7f}, Test PCC: {:.7f}, Test sMAPE: {:.7f}'.format(fold_idx, test_error['MAE'], test_error['RMSE'], test_error['PCC'], test_error['sMAPE']))
147 | for _ in test_errors[0].keys():
148 | err_mean = np.mean([__[_] for __ in test_errors])
149 | err_std = np.std([__[_] for __ in test_errors])
150 | print('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
151 | logger.info('Test {} of {:02d}-Folds : {:.7f}({:.7f})'.format(_,fold_num,err_mean,err_std))
152 |
--------------------------------------------------------------------------------
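Note: both training scripts standardize the targets before fitting and undo the transform for metric computation. A minimal sketch of that round-trip (toy values):

```python
import numpy as np

y = np.array([10.0, 20.0, 30.0])
mean, std = y.mean(), y.std()
y_norm = (y - mean) / std      # what the model is trained against

pred_norm = y_norm             # pretend the model predicts perfectly
pred = pred_norm * std + mean  # de-standardize before MAE/RMSE/PCC/sMAPE
print(np.allclose(pred, y))    # True
```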
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import logging
4 | import os
5 |
6 | def splitdata(length,fold,index):
7 | fold_length = length // fold
8 | index_list = np.arange(length)
9 | if index == 1:
10 | val = index_list[:fold_length]
11 | test = index_list[fold_length * (fold - 1):]
12 | train = index_list[fold_length : fold_length * (fold - 1)]
13 | elif index == fold:
14 | val = index_list[fold_length * (fold - 1):]
15 | test = index_list[fold_length * (fold - 2) : fold_length * (fold - 1)]
16 | train = index_list[:fold_length * (fold - 2)]
17 | else:
18 | val = index_list[fold_length * (index - 1) : fold_length * index]
19 | test = index_list[fold_length * (index - 2) : fold_length * (index - 1)]
20 | train = np.concatenate([index_list[:fold_length * (index - 2)],index_list[fold_length * index:]])
21 | return train,val,test
22 |
23 |
24 | def printParams(model_params, logger=None):
25 | print("=========== Parameters ==========")
26 | for k,v in model_params.items():
27 | print(f'{k} : {v}')
28 | print("=================================")
29 | print()
30 | if logger:
31 | for k,v in model_params.items():
32 | logger.info(f'{k} : {v}')
33 |
34 | def applyIndexOnList(lis,idx):
35 | ans = []
36 | for _ in idx:
37 | ans.append(lis[_])
38 | return ans
39 |
40 | def set_seed(seed):
41 | torch.manual_seed(seed) # set seed for cpu
42 | torch.cuda.manual_seed(seed) # set seed for gpu
43 | torch.backends.cudnn.deterministic = True # cudnn
44 | torch.backends.cudnn.benchmark = False
45 | np.random.seed(seed) # numpy
46 |
47 | def get_logger(save_dir):
48 |     os.makedirs(save_dir, exist_ok=True)  # the log directory may not exist yet on a fresh run
49 |     logger = logging.getLogger(__name__)
50 |     logger.setLevel(logging.INFO)
51 |     handler = logging.FileHandler(save_dir + "/log.txt")
52 |     handler.setLevel(logging.INFO)
53 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
54 |     handler.setFormatter(formatter)
55 |     logger.addHandler(handler)
56 |     return logger
57 | class CheckpointHandler(object):
58 | def __init__(self, save_dir, max_save=5):
59 | self.save_dir = save_dir
60 | self.max_save = max_save
61 | self.init_info()
62 |
63 | def init_info(self):
64 | os.makedirs(self.save_dir, exist_ok=True)
65 | self.metric_dic = {}
66 | if os.path.exists(self.save_dir+'/eval_log.txt'):
67 | with open(self.save_dir+'/eval_log.txt','r') as f:
68 | ls = f.readlines()
69 | for l in ls:
70 | l = l.strip().split(':')
71 | assert len(l) == 2
72 | self.metric_dic[l[0]] = float(l[1])
73 |
74 |
75 | def save_model(self, model, model_params, epoch, eval_metric):
76 | max_in_dic = max(self.metric_dic.values()) if len(self.metric_dic) else 1e9
77 |         if eval_metric > max_in_dic:  # only keep checkpoints that beat the current worst saved one
78 | return
79 | if len(self.metric_dic) == self.max_save:
80 | self.remove_last()
81 | self.metric_dic['model-'+str(epoch)+'.pt'] = eval_metric
82 | state = {"params":model_params, "epoch":epoch, "model":model.state_dict()}
83 | torch.save(state, self.save_dir + '/' + 'model-'+str(epoch)+'.pt')
84 | log_str = '\n'.join(['{}:{:.7f}'.format(k,v) for k,v in self.metric_dic.items()])
85 | with open(self.save_dir+'/eval_log.txt','w') as f:
86 | f.write(log_str)
87 |
88 |
89 | def remove_last(self):
90 |         last_model = sorted(list(self.metric_dic.keys()),key = lambda x:self.metric_dic[x])[-1]  # the saved checkpoint with the worst metric
91 | if os.path.exists(self.save_dir+'/'+last_model):
92 | os.remove(self.save_dir+'/'+last_model)
93 | self.metric_dic.pop(last_model)
94 |
95 | def checkpoint_best(self, use_cuda=True):
96 | best_model = sorted(list(self.metric_dic.keys()),key = lambda x:self.metric_dic[x])[0]
97 | if use_cuda:
98 | state = torch.load(self.save_dir + '/' + best_model)
99 | else:
100 | state = torch.load(self.save_dir + '/' + best_model,map_location='cpu')
101 | return state
102 |
103 | def checkpoint_avg(self, use_cuda=True):
104 | return_dic = None
105 | model_num = 0
106 | tmp_model_params = None
107 | for ckpt in os.listdir(self.save_dir):
108 | if not ckpt.endswith('.pt'):
109 | continue
110 | model_num += 1
111 | if use_cuda:
112 | state = torch.load(self.save_dir + '/' + ckpt)
113 | else:
114 | state = torch.load(self.save_dir + '/' + ckpt,map_location='cpu')
115 | model,tmp_model_params = state['model'], state['params']
116 | if not return_dic:
117 | return_dic = model
118 | else:
119 | for k in return_dic:
120 | return_dic[k] += model[k]
121 | for k in return_dic:
122 | return_dic[k] = return_dic[k]/model_num
123 | return {'params':tmp_model_params, 'model':return_dic}
--------------------------------------------------------------------------------
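Note: `splitdata` rotates contiguous blocks across folds: fold i validates on block i-1 and tests on block i-2 (0-based, with wrap-around at the ends), training on the rest. A quick demonstration on 10 samples and 5 folds, run from the repository root:

```python
from utils import splitdata

for fold_idx in range(1, 6):
    train, val, test = splitdata(10, 5, fold_idx)
    print(fold_idx, 'val:', list(val), 'test:', list(test))
# fold 1 -> val [0, 1], test [8, 9]
# fold 3 -> val [4, 5], test [2, 3]
# fold 5 -> val [8, 9], test [6, 7]
```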