├── MIT LICENSE.txt ├── Mat2Spec_Codes ├── Mat2Spec │ ├── Mat2Spec.py │ ├── SinkhornDistance.py │ ├── __init__.py │ ├── data.py │ ├── file_setter.py │ ├── pytorch_stats_loss.py │ └── utils.py ├── SCRIPTS │ ├── test_dos128_norm_sum_kl.sh │ ├── test_dos128_norm_sum_wd.sh │ ├── test_dos128_std_mae.sh │ ├── test_nolabel128_norm_sum_kl.sh │ ├── test_nolabel128_norm_sum_wd.sh │ ├── test_nolabel128_std_mae.sh │ ├── test_phdos51_norm_max_mae.sh │ ├── test_phdos51_norm_max_mse.sh │ ├── test_phdos51_norm_sum_kl.sh │ ├── test_phdos51_norm_sum_wd.sh │ ├── train_dos128_norm_sum_kl.sh │ ├── train_dos128_norm_sum_wd.sh │ ├── train_dos128_std_mae.sh │ ├── train_phdos51_norm_max_mae.sh │ ├── train_phdos51_norm_max_mse.sh │ ├── train_phdos51_norm_sum_kl.sh │ └── train_phdos51_norm_sum_wd.sh ├── test_Mat2Spec.py └── train_Mat2Spec.py └── README.md /MIT LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021-2022 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
class COMPOSITION_Attention(torch.nn.Module):
    """Per-atom attention weights conditioned on the crystal composition.

    Each atom embedding is concatenated with the composition vector of the
    crystal it belongs to, passed through a small two-layer network, and the
    resulting scores are soft-maxed over the atoms of each crystal.

    Args:
        neurons: length of each atom embedding.
        comp_dim: length of each composition vector. Defaults to 103, the
            value that was previously hard-coded.
    """

    def __init__(self, neurons, comp_dim=103):
        super(COMPOSITION_Attention, self).__init__()
        self.node_layer1 = Linear(neurons + comp_dim, 32)
        self.atten_layer = Linear(32, 1)

    def forward(self, x, batch, global_feat):
        """Return one attention weight per atom.

        Args:
            x: atom embeddings, [num_atom, neurons].
            batch: for every atom, the index of the crystal it belongs to.
            global_feat: [bs, comp_dim]; each row is the composition vector
                of one crystal.

        Returns:
            Tensor [num_atom, 1] produced by ``softmax(scores, batch)``.
        """
        # Number of atoms per crystal, in sorted order of the crystal index.
        counts = torch.unique(batch, return_counts=True)[-1]

        # Repeat every composition row once per atom of that crystal so it
        # lines up row-for-row with ``x``.
        graph_embed = torch.repeat_interleave(global_feat, counts, dim=0)

        chunk = torch.cat([x, graph_embed], dim=-1)
        hidden = F.softplus(self.node_layer1(chunk))  # [num_atom, 32]
        scores = self.atten_layer(hidden)             # [num_atom, 1]
        weights = softmax(scores, batch)              # [num_atom, 1]
        return weights
class FractionalEncoder(nn.Module):
    """
    Encoding element fractional amount using a "fractional encoding" inspired
    by the positional encoder discussed by Vaswani.
    https://arxiv.org/abs/1706.03762

    A sinusoidal lookup table with ``resolution`` rows and ``d_model // 2``
    columns is built once and registered as the buffer ``pe``; ``forward``
    maps each fraction to a row of that table.
    """

    def __init__(self,
                 d_model,
                 resolution=100,
                 log10=False,
                 compute_device=None):
        super().__init__()
        self.d_model = d_model // 2
        self.resolution = resolution
        self.log10 = log10
        # Stored for callers; not used anywhere inside this class.
        self.compute_device = compute_device

        n_rows, n_cols = self.resolution, self.d_model

        # Row index of the table as a column vector: (resolution, 1)
        grid = torch.linspace(0, n_rows - 1, n_rows,
                              requires_grad=False).view(n_rows, 1)
        # Column index, repeated for every row: (resolution, d_model)
        channel = torch.linspace(0, n_cols - 1, n_cols,
                                 requires_grad=False).view(1, n_cols)
        channel = channel.repeat(n_rows, 1)

        # Even columns carry the sine, odd columns the cosine.
        table = torch.zeros(n_rows, n_cols)
        even_freq = torch.pow(50, 2 * channel[:, 0::2] / n_cols)
        odd_freq = torch.pow(50, 2 * channel[:, 1::2] / n_cols)
        table[:, 0::2] = torch.sin(grid / even_freq)
        table[:, 1::2] = torch.cos(grid / odd_freq)
        self.register_buffer('pe', table)  # (resolution, d_model)

    def forward(self, x):
        """Look up the encoding row for every fraction in ``x``.

        ``x`` is cloned first, so the caller's tensor is left untouched.
        """
        frac = x.clone()
        if self.log10:
            frac = 0.0025 * (torch.log2(frac)) ** 2
            frac[frac > 1] = 1
        # Clamp from below so the smallest index is 0 after the "- 1".
        smallest = 1 / self.resolution
        frac[frac < smallest] = smallest
        row = torch.round(frac * self.resolution).to(dtype=torch.long) - 1
        return self.pe[row]  # (bs, n_elem, d_model)
i in range(nl)]) 181 | 182 | self.comp_atten = COMPOSITION_Attention(n_h) 183 | 184 | self.emb_scaler = nn.parameter.Parameter(torch.tensor([1.])) 185 | self.pos_scaler = nn.parameter.Parameter(torch.tensor([1.])) 186 | self.pos_scaler_log = nn.parameter.Parameter(torch.tensor([1.])) 187 | self.pe = FractionalEncoder(n_h, resolution=5000, log10=False) 188 | self.ple = FractionalEncoder(n_h, resolution=5000, log10=True) 189 | self.pe_linear = nn.Linear(103, 1) 190 | self.ple_linear = nn.Linear(103, 1) 191 | 192 | if self.concat_comp : reg_h = n_hX2 193 | else : reg_h = n_h 194 | 195 | self.linear1 = nn.Linear(reg_h,reg_h) 196 | self.linear2 = nn.Linear(reg_h,reg_h) 197 | 198 | def forward(self,data): 199 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 200 | 201 | batch, global_feat, cluster = data.batch, data.global_feature, data.cluster 202 | 203 | x = self.embed_n(x) # [num_atom, emb_len] 204 | 205 | edge_attr = F.leaky_relu(self.embed_e(edge_attr),self.neg_slope) # [num_edges, emb_len] 206 | 207 | for a_idx in range(len(self.node_att)): 208 | x = self.node_att[a_idx](x,edge_index,edge_attr) # [num_atom, emb_len] 209 | x = self.batch_norm[a_idx](x) 210 | x = F.softplus(x) 211 | 212 | ag = self.comp_atten(x,batch,global_feat) # [num_atom * 1] 213 | x = (x)*ag # [num_atom, emb_len] 214 | 215 | # CRYSTAL FEATURE-AGGREGATION 216 | y = global_mean_pool(x,batch)#*2**self.emb_scaler#.unsqueeze(1).squeeze() # [bs, emb_len] 217 | #y = F.relu(self.linear1(y)) # [bs, emb_len] 218 | #y = F.relu(self.linear2(y)) # [bs, emb_len] 219 | 220 | if self.concat_comp: 221 | pe = torch.zeros([global_feat.shape[0], global_feat.shape[1], y.shape[1]]).to(device) 222 | ple = torch.zeros([global_feat.shape[0], global_feat.shape[1], y.shape[1]]).to(device) 223 | pe_scaler = 2 ** (1 - self.pos_scaler) ** 2 224 | ple_scaler = 2 ** (1 - self.pos_scaler_log) ** 2 225 | pe[:, :, :y.shape[1] // 2] = self.pe(global_feat)# * pe_scaler 226 | ple[:, :, y.shape[1] // 2:] = 
self.ple(global_feat)# * ple_scaler 227 | pe = self.pe_linear(torch.transpose(pe, 1,2)).squeeze()* pe_scaler 228 | ple = self.ple_linear(torch.transpose(ple, 1,2)).squeeze()* ple_scaler 229 | y = y + pe + ple 230 | #y = torch.cat([y, pe+ple], dim=-1) 231 | #y = torch.cat([y, F.leaky_relu(self.embed_comp(global_feat), self.neg_slope)], dim=-1) 232 | 233 | return y 234 | 235 | class Mat2Spec(nn.Module): 236 | def __init__(self, args, NORMALIZER): 237 | super(Mat2Spec, self).__init__() 238 | n_heads = args.num_heads 239 | number_neurons = args.num_neurons 240 | number_layers = args.num_layers 241 | concat_comp = args.concat_comp 242 | self.graph_encoder = GNN(n_heads, neurons=number_neurons, nl=number_layers, concat_comp=concat_comp) 243 | 244 | self.loss_type = args.Mat2Spec_loss_type 245 | self.NORMALIZER = NORMALIZER 246 | self.input_dim = args.Mat2Spec_input_dim 247 | self.latent_dim = args.Mat2Spec_latent_dim 248 | self.emb_size = args.Mat2Spec_emb_size 249 | self.label_dim = args.Mat2Spec_label_dim 250 | self.scale_coeff = args.Mat2Spec_scale_coeff 251 | self.keep_prob = args.Mat2Spec_keep_prob 252 | self.K = args.Mat2Spec_K 253 | self.args = args 254 | 255 | self.fx1 = nn.Linear(self.input_dim, 256) 256 | self.fx2 = nn.Linear(256, 512) 257 | self.fx3 = nn.Linear(512, 256) 258 | self.fx_mu = nn.Linear(256, self.latent_dim*self.K) 259 | self.fx_logvar = nn.Linear(256, self.latent_dim*self.K) 260 | self.fx_mix_coeff = nn.Linear(256, self.K) 261 | 262 | self.fe_mix_coeff = nn.Sequential( 263 | nn.Linear(self.label_dim, 128), 264 | nn.ReLU(), 265 | nn.Linear(128, self.label_dim) 266 | ) 267 | 268 | self.fd_x1 = nn.Linear(self.input_dim + self.latent_dim, 512) 269 | self.fd_x2 = torch.nn.Sequential( 270 | nn.Linear(512, self.emb_size) 271 | ) 272 | self.feat_mp_mu = nn.Linear(self.emb_size, self.label_dim) 273 | 274 | # label layers 275 | self.fe0 = nn.Linear(self.label_dim, self.emb_size) 276 | self.fe1 = nn.Linear(self.label_dim, 512) 277 | self.fe2 = 
nn.Linear(512, 256) 278 | self.fe_mu = nn.Linear(256, self.latent_dim) 279 | self.fe_logvar = nn.Linear(256, self.latent_dim) 280 | 281 | self.fd1 = self.fd_x1 282 | self.fd2 = self.fd_x2 283 | #self.fd = self.fd_x 284 | self.label_mp_mu = self.feat_mp_mu 285 | 286 | self.bias = nn.Parameter(torch.zeros(self.label_dim)) 287 | 288 | assert id(self.fd_x1) == id(self.fd1) 289 | assert id(self.fd_x2) == id(self.fd2) 290 | 291 | self.dropout = nn.Dropout(p=self.keep_prob) 292 | self.emb_proj = nn.Linear(args.Mat2Spec_emb_size, 1024) 293 | self.W = nn.Linear(args.Mat2Spec_label_dim, args.Mat2Spec_emb_size) # linear transformation for label 294 | 295 | def label_encode(self, x): 296 | #h0 = self.dropout(F.relu(self.fe0(x))) # [label_dim, emb_size] 297 | h1 = self.dropout(F.relu(self.fe1(x))) # [label_dim, 512] 298 | h2 = self.dropout(F.relu(self.fe2(h1))) # [label_dim, 256] 299 | mu = self.fe_mu(h2) * self.scale_coeff # [label_dim, latent_dim] 300 | logvar = self.fe_logvar(h2) * self.scale_coeff # [label_dim, latent_dim] 301 | 302 | fe_output = { 303 | 'fe_mu': mu, 304 | 'fe_logvar': logvar 305 | } 306 | return fe_output 307 | 308 | def feat_encode(self, x): 309 | h1 = self.dropout(F.relu(self.fx1(x))) 310 | h2 = self.dropout(F.relu(self.fx2(h1))) 311 | h3 = self.dropout(F.relu(self.fx3(h2))) 312 | mu = self.fx_mu(h3) * self.scale_coeff # [bs, latent_dim] 313 | logvar = self.fx_logvar(h3) * self.scale_coeff 314 | mix_coeff = self.fx_mix_coeff(h3) # [bs, K] 315 | 316 | if self.K > 1: 317 | mu = mu.view(x.shape[0], self.K, self.args.Mat2Spec_latent_dim) # [bs, K, latent_dim] 318 | logvar = logvar.view(x.shape[0], self.K, self.args.Mat2Spec_latent_dim) # [bs, K, latent_dim] 319 | 320 | fx_output = { 321 | 'fx_mu': mu, 322 | 'fx_logvar': logvar, 323 | 'fx_mix_coeff': mix_coeff 324 | } 325 | return fx_output 326 | 327 | def label_reparameterize(self, mu, logvar): 328 | std = torch.exp(0.5 * logvar) 329 | eps = torch.randn_like(std) 330 | return mu + eps * std 331 | 332 | def 
feat_reparameterize(self, mu, logvar, coeff=1.0): 333 | std = torch.exp(0.5 * logvar) 334 | eps = torch.randn_like(std) 335 | return mu + eps * std 336 | 337 | def label_decode(self, z): 338 | d1 = F.relu(self.fd1(z)) 339 | d2 = F.leaky_relu(self.fd2(d1)) 340 | return d2 341 | 342 | def feat_decode(self, z): 343 | d1 = F.relu(self.fd_x1(z)) 344 | d2 = F.leaky_relu(self.fd_x2(d1)) 345 | return d2 346 | 347 | def label_forward(self, x, feat): # x is label 348 | n_label = x.shape[1] # label_dim 349 | all_labels = torch.eye(n_label).to(x.device) # [label_dim, label_dim] 350 | fe_output = self.label_encode(all_labels) # map each label to a Gaussian mixture. 351 | mu = fe_output['fe_mu'] 352 | logvar = fe_output['fe_logvar'] 353 | fe_output['fe_mix_coeff'] = self.fe_mix_coeff(x) 354 | mix_coeff = F.softmax(fe_output['fe_mix_coeff'], dim=-1) 355 | 356 | if self.args.train: 357 | z = self.label_reparameterize(mu, logvar) # [label_dim, latent_dim] 358 | else: 359 | z = mu 360 | z = torch.matmul(mix_coeff, z) 361 | 362 | label_emb = self.label_decode(torch.cat((feat, z), 1)) 363 | fe_output['label_emb'] = label_emb 364 | return fe_output 365 | 366 | def feat_forward(self, x): 367 | fx_output = self.feat_encode(x) 368 | mu = fx_output['fx_mu'] # [bs, latent_dim] 369 | logvar = fx_output['fx_logvar'] # [bs, latent_dim] 370 | 371 | if self.args.train: 372 | z = self.feat_reparameterize(mu, logvar) 373 | else: 374 | z = mu 375 | if self.K > 1: 376 | mix_coeff = fx_output['fx_mix_coeff'] # [bs, K] 377 | mix_coeff = F.softmax(mix_coeff, dim=-1) 378 | mix_coeff = mix_coeff.unsqueeze(-1).expand_as(z) 379 | z = z * mix_coeff 380 | z = torch.sum(z, dim=1) # [bs, latent_dim] 381 | 382 | feat_emb = self.feat_decode(torch.cat((x, z), 1)) # [bs, emb_size] 383 | fx_output['feat_emb'] = feat_emb 384 | return fx_output 385 | 386 | def forward(self, data): 387 | label = data.y 388 | feature = self.graph_encoder(data) 389 | 390 | fe_output = self.label_forward(label, feature) 391 | label_emb = 
def compute_c_loss(BX, BY, tau=1):
    """Contrastive loss between two batches of paired embeddings.

    Row ``i`` of ``BX`` and row ``i`` of ``BY`` form the positive pair; every
    other row of ``BY`` acts as a negative for row ``i`` of ``BX``. With
    ``s_ij`` the cosine similarity divided by ``tau`` the loss is::

        mean_i( log(sum_{j != i} exp(s_ij)) - s_ii )

    which equals ``-mean(log(exp(s_ii) / sum_{j != i} exp(s_ij)))``. It is
    evaluated in the log domain so that a small temperature does not
    overflow ``exp`` and turn the loss into NaN.

    Args:
        BX: tensor [bs, dim].
        BY: tensor [bs, dim], paired row-by-row with ``BX``.
        tau: temperature dividing the cosine similarities.

    Returns:
        Scalar tensor. With a single row there are no negatives and the
        result is ``-inf``.
    """
    BX = F.normalize(BX, dim=1)
    BY = F.normalize(BY, dim=1)
    logits = torch.matmul(BX, torch.transpose(BY, 0, 1)) / tau  # [bs, bs]
    positives = torch.diagonal(logits, 0)                       # [bs]

    # Exclude the positive pair from the denominator by masking the diagonal
    # with -inf before the log-sum-exp over each row.
    on_diagonal = torch.eye(logits.shape[0], dtype=torch.bool,
                            device=logits.device)
    negatives = torch.logsumexp(
        logits.masked_fill(on_diagonal, float('-inf')), dim=-1)  # [bs]

    c_loss = torch.mean(negatives - positives)
    return c_loss
output['fx_mu'], output['fx_logvar'], output['feat_emb'], output['feat_proj'] 437 | 438 | fx_mix_coeff = output['fx_mix_coeff'] # [bs, K] 439 | fe_mix_coeff = output['fe_mix_coeff'] 440 | fx_mix_coeff = F.softmax(fx_mix_coeff, dim=-1) 441 | fe_mix_coeff = F.softmax(fe_mix_coeff, dim=-1) 442 | fe_mix_coeff = fe_mix_coeff.repeat(1, args.Mat2Spec_K) 443 | fx_mix_coeff = fx_mix_coeff.repeat(1, args.Mat2Spec_label_dim) 444 | mix_coeff = fe_mix_coeff * fx_mix_coeff 445 | fx_mu = fx_mu.repeat(1, args.Mat2Spec_label_dim, 1) 446 | fx_logvar = fx_logvar.repeat(1, args.Mat2Spec_label_dim, 1) 447 | fe_mu = fe_mu.squeeze(0).expand(fx_mu.shape[0], fe_mu.shape[0], fe_mu.shape[1]) 448 | fe_logvar = fe_logvar.squeeze(0).expand(fx_mu.shape[0], fe_logvar.shape[0], fe_logvar.shape[1]) 449 | fe_mu = fe_mu.repeat(1, args.Mat2Spec_K, 1) 450 | fe_logvar = fe_logvar.repeat(1, args.Mat2Spec_K, 1) 451 | kl_all = kl(fx_mu, fe_mu, fx_logvar, fe_logvar) 452 | kl_all_inv = kl(fe_mu, fx_mu, fe_logvar, fx_logvar) 453 | kl_loss = torch.mean(torch.sum(mix_coeff * (0.5*kl_all + 0.5*kl_all_inv), dim=-1)) 454 | #c_loss = torch.mean(-1 * F.cosine_similarity(label_proj, feat_proj)) 455 | c_loss = compute_c_loss(label_proj, feat_proj) 456 | 457 | if args.label_scaling == 'normalized_sum': 458 | assert args.Mat2Spec_loss_type == 'KL' or args.Mat2Spec_loss_type == 'WD' 459 | #input_label_normalize = F.softmax(torch.log(input_label+1e-6), dim=1) 460 | input_label_normalize = input_label / (torch.sum(input_label, dim=1, keepdim=True)+1e-8) 461 | pred_e = F.softmax(fe_out, dim=1) 462 | pred_x = F.softmax(fx_out, dim=1) 463 | #nll_loss = kl_loss_fn(torch.log(pred_e+1e-8), input_label_normalize) 464 | #nll_loss_x = kl_loss_fn(torch.log(pred_x+1e-8), input_label_normalize) 465 | P = input_label_normalize 466 | Q_e = pred_e 467 | Q_x = pred_x 468 | c1, c2, c3 = 1, 1.1, 0.1 469 | if args.ablation_LE: 470 | c2 = 0.0 471 | if args.ablation_CL: 472 | c3 = 0.0 473 | 474 | if args.Mat2Spec_loss_type == 'KL': 475 | 
# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    r"""Entropy-regularised optimal-transport cost between batched histograms.

    Both inputs are histograms over the same ``N`` integer grid positions.
    The ground cost between bin ``i`` of ``mu`` and bin ``j`` of ``nu`` is
    ``|i - j|**p + 1`` (see :meth:`_cost_matrix`).

    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the per-sample costs are averaged, 'sum': the
            per-sample costs are summed. Default: 'none'

    Shape:
        - Input: :math:`(bs, N)`, :math:`(bs, N)`
        - Output: cost :math:`(bs)` or :math:`()` depending on `reduction`,
          transport plan :math:`(bs, N, N)`, cost matrix :math:`(bs, N, N)`
    """

    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, mu, nu):
        """Run log-domain Sinkhorn iterations and return ``(cost, pi, C)``."""
        C = self._cost_matrix(mu.shape[0], mu.shape[1])
        # Keep the cost matrix on the inputs' device (and floating dtype)
        # instead of a module-level default device, so the module works
        # wherever its inputs live.
        target_dtype = mu.dtype if mu.is_floating_point() else C.dtype
        C = C.to(device=mu.device, dtype=target_dtype)

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        # Stop early once the mean absolute change of u drops below this.
        thresh = 1e-2

        for _ in range(self.max_iter):
            u_prev = u
            u = self.eps * (torch.log(mu + 1e-8)
                            - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8)
                            - torch.logsumexp(self.M(C, u, v).transpose(-2, -1),
                                              dim=-1)) + v
            err = (u - u_prev).abs().sum(-1).mean()
            if err.item() < thresh:
                break

        # Transport plan pi = diag(a) * K * diag(b), evaluated in log space.
        pi = torch.exp(self.M(C, u, v))
        # Sinkhorn distance: total transported mass weighted by the cost.
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost, pi, C

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates.

        :math:`M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon`
        """
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps  # [bs,N,N]

    @staticmethod
    def _cost_matrix(batch_size, n, p=2):
        """Return the float32 cost tensor of shape [batch_size, n, n].

        Source points sit at ``(i, 0)`` and target points at ``(j, 1)``, so
        entry ``(i, j)`` is ``|i - j|**p + |0 - 1|**p = |i - j|**p + 1``.
        The same matrix is repeated for every batch element.
        """
        grid = torch.arange(n, dtype=torch.float)
        gap = torch.abs(grid.unsqueeze(-1) - grid.unsqueeze(-2)) ** p
        return (gap + 1.0).unsqueeze(0).repeat(batch_size, 1, 1)

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1
class DATA_normalizer:
    """Column-wise standardisation and log10 helpers for target values.

    The mean and (unbiased) standard deviation are taken over ``dim=0`` of
    ``array`` at construction time and stored as float32 tensors.
    """

    def __init__(self, array):
        tensor = torch.as_tensor(array)
        # torch.mean / torch.std reject integer tensors, so promote first.
        if not tensor.is_floating_point():
            tensor = tensor.double()
        self.mean = torch.mean(tensor, dim=0).float()
        self.std = torch.std(tensor, dim=0).float()

    def reg(self, x):
        """Return ``x`` cast to float32, otherwise unchanged."""
        return x.float()

    def log10(self, x):
        """Return the element-wise base-10 logarithm of ``x``."""
        return torch.log10(x)

    def delog10(self, x):
        """Invert :meth:`log10`, i.e. return ``10 ** x`` element-wise."""
        return 10 ** x

    def norm(self, x):
        """Standardise ``x`` with the stored mean and standard deviation."""
        return (x - self.mean) / self.std

    def denorm(self, x):
        """Invert :meth:`norm`."""
        return x * self.std + self.mean
self.duration = [] 105 | self.dataframe = self.to_frame() 106 | 107 | def __str__(self): 108 | x = self.to_frame() 109 | return x.to_string() 110 | 111 | def to_frame(self): 112 | metrics_df = pd.DataFrame(list(zip(self.training_loss1, self.training_loss2, 113 | self.valid_loss1, self.valid_loss2, self.duration)), 114 | columns=['training_1', 'training_2', 'valid_1', 'valid_2', 'time']) 115 | return metrics_df 116 | 117 | def set_label(self, which_phase, graph_data): 118 | use_label = graph_data.y 119 | return use_label 120 | 121 | def save_time(self, e_duration): 122 | self.duration.append(e_duration) 123 | 124 | def __call__(self, which_phase, tensor_pred, tensor_true, measure=1): 125 | if measure == 1: 126 | if which_phase == 'training': 127 | loss = self.criterion(tensor_pred, tensor_true) 128 | self.training_measure1 += loss 129 | elif which_phase == 'validation': 130 | loss = self.criterion(tensor_pred, tensor_true) 131 | self.valid_measure1 += loss 132 | else: 133 | if which_phase == 'training': 134 | loss = self.eval_func(tensor_pred, tensor_true) 135 | self.training_measure2 += loss 136 | elif which_phase == 'validation': 137 | loss = self.eval_func(tensor_pred, tensor_true) 138 | self.valid_measure2 += loss 139 | return loss 140 | 141 | def reset_parameters(self, which_phase, epoch): 142 | if which_phase == 'training': 143 | # AVERAGES 144 | t1 = self.training_measure1 / (self.training_counter) 145 | t2 = self.training_measure2 / (self.training_counter) 146 | 147 | self.training_loss1.append(t1.item()) 148 | self.training_loss2.append(t2.item()) 149 | self.training_measure1 = torch.tensor(0.0).to(self.dv) 150 | self.training_measure2 = torch.tensor(0.0).to(self.dv) 151 | self.training_counter = 0 152 | else: 153 | # AVERAGES 154 | v1 = self.valid_measure1 / (self.valid_counter) 155 | v2 = self.valid_measure2 / (self.valid_counter) 156 | 157 | self.valid_loss1.append(v1.item()) 158 | self.valid_loss2.append(v2.item()) 159 | self.valid_measure1 = 
class GaussianDistance(object):
    """Expand scalar distances onto a grid of Gaussian basis functions.

    Grid centres are ``np.arange(dmin, dmax + step, step)``; every basis
    function uses the same width ``var`` (``step`` when ``var`` is omitted).
    """

    def __init__(self, dmin, dmax, step, var=None):
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax + step, step)
        self.var = step if var is None else var

    def expand(self, distances):
        """Return ``exp(-(d - centre)**2 / var**2)`` for each distance/centre pair.

        A new trailing axis of length ``len(self.filter)`` is appended to the
        shape of ``distances``.
        """
        offsets = distances[..., np.newaxis] - self.filter
        return np.exp(-offsets ** 2 / self.var ** 2)


class AtomInitializer(object):
    """Lookup table from an atom type to its stored embedding."""

    def __init__(self, atom_types):
        self.atom_types = set(atom_types)
        self._embedding = {}

    def _build_decode_table(self):
        # Reverse mapping: stored value -> atom type.
        return {value: atom_type for atom_type, value in self._embedding.items()}

    def get_atom_fea(self, atom_type):
        """Return the embedding stored for ``atom_type`` (must be a known type)."""
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        """Replace the embedding table and rebuild the derived lookups."""
        self._embedding = state_dict
        self.atom_types = set(self._embedding.keys())
        self._decodedict = self._build_decode_table()

    def state_dict(self):
        """Return the embedding table itself (not a copy)."""
        return self._embedding

    def decode(self, idx):
        """Return the atom type whose stored value equals ``idx``.

        The reverse table is built lazily on first use.
        """
        try:
            table = self._decodedict
        except AttributeError:
            table = self._decodedict = self._build_decode_table()
        return table[idx]


class AtomCustomJSONInitializer(AtomInitializer):
    """``AtomInitializer`` populated from a JSON file.

    The file maps keys convertible to ``int`` to feature vectors; each vector
    is stored as a float ``numpy`` array under its integer key.
    """

    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as handle:
            raw_table = json.load(handle)
        by_number = {int(name): vector for name, vector in raw_table.items()}
        super().__init__(by_number.keys())
        for number, vector in by_number.items():
            self._embedding[number] = np.array(vector, dtype=float)
class CIF_Lister(Dataset):
    """Index-based view over a crystal dataset that yields graph ``Data`` objects.

    ``full_dataset[i]`` is expected to return a tuple laid out as
    ``((node_feat, edge_feat, edge_index), groups, enc_compo, coords, target, ...)``
    -- the layout produced by ``CIF_Dataset.__getitem__`` in this file.
    """

    def __init__(self, crystals_ids, full_dataset, df=None):
        """Store the selected indices and, when ``df`` is given, their ids.

        Args:
            crystals_ids: indices into ``full_dataset`` (and rows of ``df``).
            full_dataset: indexable dataset returning the tuple described above.
            df: optional DataFrame whose first column holds the material ids.
        """
        self.crystals_ids = crystals_ids
        self.full_dataset = full_dataset
        # ``df`` is optional in the signature; only dereference it when given.
        if df is None:
            self.material_ids = None
        else:
            self.material_ids = df.iloc[crystals_ids].values[:, 0].squeeze()  # MP-xxx

    def __len__(self):
        return len(self.crystals_ids)

    def extract_ids(self, original_dataset):
        """Return the rows of ``original_dataset`` selected by this lister."""
        return original_dataset.iloc[self.crystals_ids]

    def __getitem__(self, idx):
        """Build the graph ``Data`` object for the ``idx``-th selected crystal."""
        i = self.crystals_ids[idx]
        material = self.full_dataset[i]

        n_features, e_features, a_matrix = material[0][0], material[0][1], material[0][2]
        # Flatten to one row per edge.  The width is taken from the tensor
        # instead of assuming the 41 filters of the default radius/step.
        e_features = e_features.view(-1, e_features.shape[-1])

        groups = material[1]
        enc_compo = material[2]
        coordinates = material[3]
        y = material[4]  # target

        return Data(x=n_features, y=y, edge_attr=e_features, edge_index=a_matrix,
                    global_feature=enc_compo, cluster=groups,
                    num_atoms=torch.tensor([len(n_features)]).float(),
                    coords=coordinates, the_idx=torch.tensor([float(i)]))
class CIF_Dataset(Dataset):
    """Builds per-crystal graph inputs and caches them on disk.

    ``__getitem__`` returns
    ``((atom_fea, nbr_fea, nbr_fea_idx), groups, enc_compo, coordinates, target, cif_id, atom_id)``.
    """

    # Directory holding the cached ``.chkpt`` files of each supported data source.
    _CACHE_DIRS = {
        'binned_dos_128': '../Mat2Spec_DATA/materials_with_edos_processed/',
        'ph_dos_51': '../Mat2Spec_DATA/materials_with_phdos_processed/',
        'no_label_128': '../Mat2Spec_DATA/materials_without_dos_processed/',
    }

    def __init__(self, args, pd_data=None, np_data=None, norm_obj=None, normalization=None, max_num_nbr=12, radius=8,
                 dmin=0, step=0.2, cls_num=3, root_dir='DATA/'):
        """Set up featurisers and, for ``ph_dos_51``, load the structures.

        Args:
            args: namespace providing ``data_src`` and ``use_catached_data``.
            pd_data: DataFrame whose first column holds the crystal ids.
            np_data: array of targets, indexed like ``pd_data``.
            norm_obj, normalization: accepted but not used by this class.
            max_num_nbr: number of neighbours kept per atom.
            radius: neighbour search radius, also the upper end of the
                Gaussian distance grid.
            dmin, step: lower end and spacing of the Gaussian distance grid.
            cls_num: number of clusters for the coordinate clustering.
            root_dir: directory containing ``atom_init.json`` and raw data.
        """
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        self.pd_data = pd_data
        self.np_data = np_data
        self.ari = AtomCustomJSONInitializer(self.root_dir + 'atom_init.json')
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        self.clusterizer = SPCL(n_clusters=cls_num, random_state=None, assign_labels='discretize')
        self.clusterizer2 = KMeans(n_clusters=cls_num, random_state=None)
        self.encoder_elem = ELEM_Encoder()
        self.update_root = None
        self.args = args
        if self.args.data_src == 'ph_dos_51':
            # NOTE(review): pickle executes code on load; only use trusted files.
            with open('../Mat2Spec_DATA/phdos/ph_structures.pkl', 'rb') as pkl_file:
                self.structures = pickle.load(pkl_file)

    def __len__(self):
        return len(self.pd_data)

    def _load_structure(self, idx, cif_id):
        """Return the crystal structure for one sample of the active data source."""
        data_src = self.args.data_src
        if data_src == 'ph_dos_51':
            return self.structures[idx]
        if data_src == 'binned_dos_128':
            json_path = os.path.join(self.root_dir + 'materials_with_edos/', 'dos_' + cif_id + '.json')
        elif data_src == 'no_label_128':
            json_path = os.path.join(self.root_dir + 'materials_without_dos/', cif_id + '.json')
        else:
            raise ValueError(f'Unsupported data_src: {data_src!r}')
        with open(json_path) as json_file:
            data = json.load(json_file)
        return Structure.from_dict(data['structure'])

    def __getitem__(self, idx):
        """Return the featurised sample ``idx``, reading or writing the cache.

        A cached file is used only when ``args.use_catached_data`` is truthy
        and the file exists; a freshly computed sample is always written back.

        Raises:
            ValueError: if ``args.data_src`` is not a supported source.
        """
        cif_id = self.pd_data.iloc[idx][0]
        target = self.np_data[idx]

        cache_dir = self._CACHE_DIRS.get(self.args.data_src)
        cache_file = None if cache_dir is None else cache_dir + str(cif_id) + '.chkpt'

        if self.args.use_catached_data and cache_file is not None and path.exists(cache_file):
            # NOTE(review): recent PyTorch releases default ``torch.load`` to
            # ``weights_only=True``, which may reject the pickled species
            # objects stored under 'atom_id' -- confirm for the target version.
            tmp_dist = torch.load(cache_file)
            return ((tmp_dist['atom_fea'], tmp_dist['nbr_fea'], tmp_dist['nbr_fea_idx']),
                    tmp_dist['groups'], tmp_dist['enc_compo'], tmp_dist['coordinates'],
                    tmp_dist['target'], tmp_dist['cif_id'], tmp_dist['atom_id'])

        crystal = self._load_structure(idx, cif_id)

        atom_fea = np.vstack([self.ari.get_atom_fea(crystal[i].specie.number) for i in range(len(crystal))])
        atom_fea = torch.Tensor(atom_fea)

        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)  # (site, distance, index, image)
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]  # one list per atom
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            # Keep the closest ``max_num_nbr`` neighbours; pad short lists with
            # index 0 and a distance just beyond the cutoff radius.
            nbr = nbr[:self.max_num_nbr]
            pad = self.max_num_nbr - len(nbr)
            nbr_fea_idx.append([x[2] for x in nbr] + [0] * pad)
            nbr_fea.append([x[1] for x in nbr] + [self.radius + 1.] * pad)
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)  # both [n_atom, nbr]

        nbr_fea = self.gdf.expand(nbr_fea)  # [n_atom, nbr, n_filters]

        g_coords = crystal.cart_coords  # [n_atom, 3]
        groups = [0] * len(g_coords)
        if len(g_coords) > 2:
            try:
                groups = self.clusterizer.fit_predict(g_coords)
            except Exception:
                # Fall back to the second clusterer when the first one fails,
                # without swallowing KeyboardInterrupt/SystemExit.
                groups = self.clusterizer2.fit_predict(g_coords)
        groups = torch.tensor(groups).long()  # [n_atom]

        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = self.format_adj_matrix(torch.LongTensor(nbr_fea_idx))  # [2, E]

        target = torch.Tensor(target.astype(float)).view(1, -1)

        coordinates = torch.tensor(g_coords)  # [n_atom, 3]
        enc_compo = self.encoder_elem.encode(crystal.composition)  # [1, n_elements]
        atom_id = [crystal[i].specie for i in range(len(crystal))]

        tmp_dist = {
            'atom_fea': atom_fea,
            'nbr_fea': nbr_fea,
            'nbr_fea_idx': nbr_fea_idx,
            'groups': groups,
            'enc_compo': enc_compo,
            'coordinates': coordinates,
            'target': target,
            'cif_id': cif_id,
            'atom_id': atom_id,
        }

        if cache_file is not None:
            mkdirs(cache_dir)
            torch.save(tmp_dist, cache_file)

        return (atom_fea, nbr_fea, nbr_fea_idx), groups, enc_compo, coordinates, target, cif_id, atom_id

    def format_adj_matrix(self, adj_matrix):
        """Convert an ``[n_atom, nbr]`` neighbour-index matrix to a ``[2, E]`` edge list.

        Row 0 holds the source atom of every edge (each atom index repeated
        once per neighbour column); row 1 holds the flattened neighbour indices.
        """
        n_atoms = len(adj_matrix)
        all_src_nodes = torch.arange(n_atoms, device=adj_matrix.device) \
            .repeat_interleave(adj_matrix.shape[1]).unsqueeze(0)
        all_dst_nodes = adj_matrix.view(-1).unsqueeze(0)

        return torch.cat((all_src_nodes, all_dst_nodes), dim=0)
def use_property(property_name, source, do_prediction=False):
    """Write ``DATA/<cif_dir>/id_prop.csv`` for one property and reference source.

    The property's reference CSV is read from ``DATA/properties-reference/``,
    rows whose value is missing are dropped, and the remaining rows are
    filtered according to ``source`` before being written out.

    Args:
        property_name: one of the aliases listed in the table below.
        source: ``'CGCNN'``, ``'MEGNET'`` or ``'NEW'``.
        do_prediction: when true, the final "ready" message is not printed.

    Returns:
        ``(source, num_T, CIF_dict)`` where ``num_T`` is the size associated
        with the chosen property/source and ``CIF_dict`` holds the
        ``radius`` / ``step`` / ``max_num_nbr`` settings.

    Raises:
        ValueError: if ``property_name`` or ``source`` is not recognised.
    """
    print('> Preparing dataset to use for Property Prediction. Please wait ...')

    # aliases -> (reference csv, column index in megnet.csv, default num_T)
    property_table = (
        (('band', 'bandgap', 'band-gap'), 'bandgap.csv', 1, 36720),
        (('bulk', 'bulkmodulus', 'bulk-modulus', 'bulk-moduli'), 'bulkmodulus.csv', 3, 4664),
        (('energy-1', 'formationenergy', 'formation-energy'), 'formationenergy.csv', 2, 60000),
        (('energy-2', 'fermienergy', 'fermi-energy'), 'fermienergy.csv', 2, 60000),
        (('energy-3', 'absoluteenergy', 'absolute-energy'), 'absoluteenergy.csv', 2, 60000),
        (('shear', 'shearmodulus', 'shear-modulus', 'shear-moduli'), 'shearmodulus.csv', 4, 4664),
        (('poisson', 'poissonratio', 'poisson-ratio'), 'poissonratio.csv', 4, 4664),
        (('is_metal', 'is_not_metal'), 'ismetal.csv', 2, 55391),
        (('new-property',), 'newproperty.csv', None, None),
    )
    for aliases, filename, p, num_T in property_table:
        if property_name in aliases:
            break
    else:
        raise ValueError(f'Unknown property_name: {property_name!r}')

    # Validate before touching any file so a typo fails with a clear message.
    if source not in ('CGCNN', 'MEGNET', 'NEW'):
        raise ValueError(f"Unknown source: {source!r} (expected 'CGCNN', 'MEGNET' or 'NEW')")

    df = pd.read_csv(f'DATA/properties-reference/{filename}', names=['material_id', 'value']) \
        .replace(to_replace='None', value=np.nan).dropna()

    # CGCNN
    if source == 'CGCNN':
        cif_dir = 'CIF-DATA'
        # reference csv -> (id list restricting the materials, num_T)
        cgcnn_subsets = {
            'bulkmodulus.csv': ('mp-ids-3402.csv', 2041),
            'shearmodulus.csv': ('mp-ids-3402.csv', 2041),
            'poissonratio.csv': ('mp-ids-3402.csv', 2041),
            'bandgap.csv': ('mp-ids-27430.csv', 16458),
            'formationenergy.csv': ('mp-ids-46744.csv', 28046),
            'fermienergy.csv': ('mp-ids-46744.csv', 28046),
            'ismetal.csv': ('mp-ids-46744.csv', 28046),
            'absoluteenergy.csv': ('mp-ids-46744.csv', 28046),
        }
        if filename in cgcnn_subsets:
            ids_file, num_T = cgcnn_subsets[filename]
            # ravel() keeps a 1-D array even when the list has a single id.
            keep_ids = pd.read_csv(f'DATA/cgcnn-reference/{ids_file}', names=['mp_ids']).values.ravel()
            df = df[df.material_id.isin(keep_ids)]
        CIF_dict = {'radius': 8, 'step': 0.2, 'max_num_nbr': 12}

    # MEGNET
    elif source == 'MEGNET':
        cif_dir = 'CIF-DATA'
        megnet_df = pd.read_csv('DATA/megnet-reference/megnet.csv')
        use_ids = megnet_df[megnet_df.iloc[:, p] == 1].material_id.values.ravel()
        df = df[df.material_id.isin(use_ids)]
        CIF_dict = {'radius': 4, 'step': 0.5, 'max_num_nbr': 16}

    # CUSTOM
    else:
        cif_dir = 'CIF-DATA_NEW'
        CIF_dict = {'radius': 8, 'step': 0.2, 'max_num_nbr': 12}
        d_src = 'DATA'
        src, dst = d_src + '/CIF-DATA/atom_init.json', d_src + '/CIF-DATA_NEW/atom_init.json'
        copyfile(src, dst)

    # ADDITIONAL CLEANING: properties with index 3 or 4 keep positive values only.
    if p in [3, 4]:
        df = df[df.value > 0]

    df.to_csv(f'DATA/{cif_dir}/id_prop.csv', index=False, header=False)
    if not do_prediction:
        print(f'> Dataset for {source}---{property_name} ready !\n\n')
    return source, num_T, CIF_dict
def torch_wasserstein_loss(tensor_a, tensor_b):
    """First Wasserstein distance between two 1D weight distributions (``p=1``)."""
    return torch_cdf_loss(tensor_a, tensor_b, p=1)


def torch_energy_loss(tensor_a, tensor_b):
    """Energy distance between two 1D weight distributions: ``sqrt(2)`` times the ``p=2`` CDF loss."""
    scale = 2 ** 0.5
    return scale * torch_cdf_loss(tensor_a, tensor_b, p=2)


def torch_cdf_loss(tensor_a, tensor_b, p=1):
    """Mean p-norm distance between the cumulative sums of two weight tensors.

    The last dimension of each input is treated as a weight distribution.  It
    is divided by its sum (plus ``1e-14`` to avoid 0/0), turned into a CDF
    with ``cumsum``, and the p-norm of the CDF difference is taken along that
    dimension.  The result is averaged over all remaining dimensions.

    Args:
        tensor_a, tensor_b: weight tensors of matching shape.
        p: order of the norm; ``1`` gives the first Wasserstein distance.
    """
    eps = 1e-14
    mass_a = torch.sum(tensor_a, dim=-1, keepdim=True) + eps
    mass_b = torch.sum(tensor_b, dim=-1, keepdim=True) + eps

    cdf_a = torch.cumsum(tensor_a / mass_a, dim=-1)
    cdf_b = torch.cumsum(tensor_b / mass_b, dim=-1)
    cdf_gap = cdf_a - cdf_b

    # Dedicated formulas for the two common norms, generic one otherwise.
    if p == 1:
        per_sample = torch.abs(cdf_gap).sum(dim=-1)
    elif p == 2:
        per_sample = torch.pow(cdf_gap, 2).sum(dim=-1).sqrt()
    else:
        per_sample = torch.pow(torch.abs(cdf_gap), p).sum(dim=-1).pow(1 / p)

    return per_sample.mean()


def torch_validate_distibution(tensor_a, tensor_b):
    """Raise ``ValueError`` unless both weight tensors have the same size.

    Only the size is checked; non-negativity and a positive, finite sum are
    not verified here.
    """
    same_size = tensor_a.size() == tensor_b.size()
    if not same_size:
        raise ValueError("Input weight tensors must be of the same size")
def set_device(gpu_id=0):
    """Return the ``cuda:<gpu_id>`` device when CUDA is available, else the CPU device."""
    name = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)


def set_model_properties(crystal_property):
    """Map a property name to ``(norm_action, classification)``.

    Listed regression properties give ``(None, None)``, the two metal flags
    give a ``'classification-*'`` action with ``classification=1``, and every
    other name gives ``('log', None)``.
    """
    unscaled = ('poisson-ratio', 'band-gap', 'absolute-energy', 'fermi-energy', 'formation-energy')
    if crystal_property in unscaled:
        return None, None
    if crystal_property == 'is_metal':
        return 'classification-1', 1
    if crystal_property == 'is_not_metal':
        return 'classification-0', 1
    return 'log', None


def torch_MAE(tensor1, tensor2):
    """Mean absolute difference between two tensors."""
    return (tensor1 - tensor2).abs().mean()


def torch_accuracy(pred_tensor, true_tensor):
    """Fraction of rows whose highest-scoring column (dim 1) equals the label."""
    predicted = torch.max(pred_tensor, dim=1).indices
    n_correct = (predicted == true_tensor).sum().float()
    return n_correct / predicted.size(0)


def output_training(metrics_obj, epoch, estop_val, extra='---'):
    """Print and return a table with the losses recorded for ``epoch``.

    The column headers switch to cross-entropy/accuracy when the tracked
    property is ``'is_metal'`` or ``'is_not_metal'``.
    """
    if metrics_obj.c_property in ['is_metal', 'is_not_metal']:
        header_1, header_2 = 'Cross_E | e-stop', 'Accuracy | TIME'
    else:
        header_1, header_2 = 'MSE | e-stop', 'MAE | TIME'

    train_1 = metrics_obj.training_loss1[epoch]
    train_2 = metrics_obj.training_loss2[epoch]
    valid_1 = metrics_obj.valid_loss1[epoch]
    valid_2 = metrics_obj.valid_loss2[epoch]

    rows = [
        ['TRAINING', f'{train_1:.4f}', f'{train_2:.4f}'],
        ['VALIDATION', f'{valid_1:.4f}', f'{valid_2:.4f}'],
        ['E-STOPPING', f'{estop_val}', f'{extra}'],
    ]
    output = tabulate(rows, headers=[f'EPOCH # {epoch}', header_1, header_2], tablefmt='fancy_grid')
    print(output)
    return output
def load_metrics():
    """Load and return the object pickled at ``MODELS/metrics_.pickle``."""
    # Imported locally: this module's top-level imports do not include pickle.
    import pickle

    # NOTE(review): pickle executes code on load; only read trusted files.
    with open("MODELS/metrics_.pickle", "rb") as metrics_file:
        return pickle.load(metrics_file)


def freeze_params(model, params_to_freeze_list):
    """Disable gradients for the named parameters and drop their stored grads.

    Args:
        model: object on which the dotted names are resolved.
        params_to_freeze_list: dotted attribute paths such as ``'enc.weight'``.
    """
    for param_path in params_to_freeze_list:
        param = attrgetter(param_path)(model)
        param.requires_grad = False
        param.grad = None


def unfreeze_params(model, params_to_unfreeze_list):
    """Re-enable gradients for the named parameters (dotted attribute paths)."""
    for param_path in params_to_unfreeze_list:
        param = attrgetter(param_path)(model)
        param.requires_grad = True


def RobustL1(output, log_std, target):
    """
    Robust L1 loss using a lorentzian prior. Allows for estimation
    of an aleatoric uncertainty.

    Computes ``mean(sqrt(2) * |output - target| * exp(-log_std) + log_std)``.
    """
    absolute = torch.abs(output - target)
    # Plain Python float so the product is a pure torch operation.
    loss = float(np.sqrt(2.0)) * absolute * torch.exp(-log_std) + log_std
    return torch.mean(loss)


def RobustL2(output, log_std, target):
    """
    Robust L2 loss using a gaussian prior. Allows for estimation
    of an aleatoric uncertainty.

    Computes ``mean(0.5 * (output - target)**2 * exp(-2 * log_std) + log_std)``.
    """
    squared = torch.pow(output - target, 2.0)
    loss = 0.5 * squared * torch.exp(-2.0 * log_std) + log_std
    return torch.mean(loss)
# ---- Mat2Spec_Codes/SCRIPTS/test_dos128_norm_sum_kl.sh ----
# Evaluate on binned_dos_128: KL loss, labels normalized by sum.
#!/bin/bash

CUDA_VISIBLE_DEVICES=1 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'KL' \
    --label_scaling 'normalized_sum' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --Mat2Spec-label-dim 128 \

# ---- Mat2Spec_Codes/SCRIPTS/test_dos128_norm_sum_wd.sh ----
# Evaluate on binned_dos_128: WD loss, labels normalized by sum.
CUDA_VISIBLE_DEVICES=1 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'WD' \
    --label_scaling 'normalized_sum' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --Mat2Spec-label-dim 128

# ---- Mat2Spec_Codes/SCRIPTS/test_dos128_std_mae.sh ----
# Evaluate on binned_dos_128: MAE loss, standardized labels.
#!/bin/bash

CUDA_VISIBLE_DEVICES=1 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MAE' \
    --label_scaling 'standardized' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --Mat2Spec-label-dim 128 \

# ---- Mat2Spec_Codes/SCRIPTS/test_nolabel128_norm_sum_kl.sh ----
# Run on no_label_128 with the checkpoint named after binned_dos_128 / normalized_sum / KL.
CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'KL' \
    --label_scaling 'normalized_sum' \
    --data_src 'no_label_128' \
    --trainset_subset_ratio 1.0 \
    --check-point-path './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_KL_trainsize1.0.chkpt' \
    --Mat2Spec-label-dim 128

# ---- Mat2Spec_Codes/SCRIPTS/test_nolabel128_norm_sum_wd.sh ----
# Run on no_label_128 with the checkpoint named after binned_dos_128 / normalized_sum / WD.
CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'WD' \
    --label_scaling 'normalized_sum' \
    --data_src 'no_label_128' \
    --trainset_subset_ratio 1.0 \
    --check-point-path './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_WD_trainsize1.0.chkpt' \
    --Mat2Spec-label-dim 128

# ---- Mat2Spec_Codes/SCRIPTS/test_nolabel128_std_mae.sh ----
# Run on no_label_128 with the checkpoint named after binned_dos_128 / standardized / MAE.
CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MAE' \
    --label_scaling 'standardized' \
    --data_src 'no_label_128' \
    --trainset_subset_ratio 1.0 \
    --check-point-path './TRAINED/model_Mat2Spec_binned_dos_128_standardized_MAE_trainsize1.0.chkpt' \
    --Mat2Spec-label-dim 128

# ---- Mat2Spec_Codes/SCRIPTS/test_phdos51_norm_max_mae.sh ----
# Evaluate on ph_dos_51: MAE loss, labels normalized by max, explicit checkpoint.
# NOTE(review): --train is passed to the test script here -- confirm it is intended.
CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MAE' \
    --label_scaling 'normalized_max' \
    --data_src 'ph_dos_51' \
    --trainset_subset_ratio 1.0 \
    --train \
    --check-point-path './TRAINED/model_Mat2Spec_ph_dos_51_normalized_max_MAE_trainsize1.0.chkpt' \
    --Mat2Spec-label-dim 51 \
    --Mat2Spec-keep-prob 0.5 \
    --batch-size 8

# ---- Mat2Spec_Codes/SCRIPTS/test_phdos51_norm_max_mse.sh ----
# Evaluate on ph_dos_51: MSE loss, labels normalized by max, explicit checkpoint.
# NOTE(review): --train is passed to the test script here -- confirm it is intended.
CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MSE' \
    --label_scaling 'normalized_max' \
    --data_src 'ph_dos_51' \
    --trainset_subset_ratio 1.0 \
    --train \
    --check-point-path './TRAINED/model_Mat2Spec_ph_dos_51_normalized_max_MSE_trainsize1.0.chkpt' \
    --Mat2Spec-label-dim 51 \
    --Mat2Spec-keep-prob 0.5 \
    --batch-size 8

# ---- Mat2Spec_Codes/SCRIPTS/test_phdos51_norm_sum_kl.sh ----
# Evaluate on ph_dos_51: KL loss, labels normalized by sum.
CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'KL' \
    --label_scaling 'normalized_sum' \
    --data_src 'ph_dos_51' \
    --trainset_subset_ratio 1.0 \
    --Mat2Spec-label-dim 51 \
    --Mat2Spec-keep-prob 0.5 \
    --batch-size 8

# ---- Mat2Spec_Codes/SCRIPTS/test_phdos51_norm_sum_wd.sh ----
# Evaluate on ph_dos_51: WD loss, labels normalized by sum.
CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'WD' \
    --label_scaling 'normalized_sum' \
    --data_src 'ph_dos_51' \
    --trainset_subset_ratio 1.0 \
    --Mat2Spec-label-dim 51 \
    --Mat2Spec-keep-prob 0.5 \
    --batch-size 8

# ---- Mat2Spec_Codes/SCRIPTS/train_dos128_norm_sum_kl.sh ----
# Train on binned_dos_128: KL loss, labels normalized by sum.
#!/bin/bash

CUDA_VISIBLE_DEVICES=1 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'KL' \
    --label_scaling 'normalized_sum' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 128 \

# ---- Mat2Spec_Codes/SCRIPTS/train_dos128_norm_sum_wd.sh ----
# Train on binned_dos_128: WD loss, labels normalized by sum.
CUDA_VISIBLE_DEVICES=1 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'WD' \
    --label_scaling 'normalized_sum' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 128

# ---- Mat2Spec_Codes/SCRIPTS/train_dos128_std_mae.sh ----
# Train on binned_dos_128: MAE loss, standardized labels.
#!/bin/bash

CUDA_VISIBLE_DEVICES=1 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MAE' \
    --label_scaling 'standardized' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 128 \

# ---- Mat2Spec_Codes/SCRIPTS/train_phdos51_norm_max_mae.sh ----
# Train on ph_dos_51: MAE loss, labels normalized by max.
CUDA_VISIBLE_DEVICES=0 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MAE' \
    --label_scaling 'normalized_max' \
    --data_src 'ph_dos_51' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 51 \
    --Mat2Spec-keep-prob 0.5 \
    --batch-size 8

# ---- Mat2Spec_Codes/SCRIPTS/train_phdos51_norm_max_mse.sh ----
# Train on ph_dos_51: MSE loss, labels normalized by max.
CUDA_VISIBLE_DEVICES=0 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MSE' \
    --label_scaling 'normalized_max' \
    --data_src 'ph_dos_51' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 51 \
    --Mat2Spec-keep-prob 0.5 \
    --batch-size 8

# ---- Mat2Spec_Codes/SCRIPTS/train_phdos51_norm_sum_kl.sh ----
# Train on ph_dos_51: KL loss, labels normalized by sum.
CUDA_VISIBLE_DEVICES=0 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'KL' \
    --label_scaling 'normalized_sum' \
    --data_src 'ph_dos_51' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 51 \
    --Mat2Spec-keep-prob 0.5 \
    --batch-size 8

# ---- Mat2Spec_Codes/SCRIPTS/train_phdos51_norm_sum_wd.sh ----
# Train on ph_dos_51: WD loss, labels normalized by sum.
CUDA_VISIBLE_DEVICES=0 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'WD' \
    --label_scaling 'normalized_sum' \
    --data_src 'ph_dos_51' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 51 \
    --Mat2Spec-keep-prob 0.5 \
    --batch-size 8
# ---- Mat2Spec_Codes/test_Mat2Spec.py (module-level setup up to the test loader) ----
from Mat2Spec.data import *
from Mat2Spec.Mat2Spec import *
from Mat2Spec.file_setter import use_property
from Mat2Spec.utils import *
import matplotlib.pyplot as plt
import argparse
import json
import random
import sys
from tqdm import tqdm
import gc
import pickle
from copy import copy, deepcopy
from os import makedirs
torch.autograd.set_detect_anomaly(True)
device = set_device()


def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as true.  Here the empty string and the usual "false"
    spellings map to ``False``; anything else stays ``True`` so existing
    invocations such as ``--concat_comp ''`` keep their meaning.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n')


# MOST CRUCIAL DATA PARAMETERS
parser = argparse.ArgumentParser(description='Mat2Spec')
parser.add_argument('--data_src', default='binned_dos_128',
                    choices=['binned_dos_128', 'binned_dos_32', 'ph_dos_51', 'no_label_32', 'no_label_128'])
parser.add_argument('--label_scaling', default='standardized',
                    choices=['standardized', 'normalized_sum', 'normalized_max'])
# MOST CRUCIAL MODEL PARAMETERS
parser.add_argument('--num_layers', default=3, type=int,
                    help='number of AGAT layers to use in model (default:3)')
parser.add_argument('--num_neurons', default=128, type=int,
                    help='number of neurons to use per AGAT Layer(default:128)')
parser.add_argument('--num_heads', default=4, type=int,
                    help='number of Attention-Heads to use per AGAT Layer (default:4)')
parser.add_argument('--concat_comp', default=False, type=_str2bool,
                    help='option to re-use vector of elemental composition after global summation of crystal feature.(default: False)')
parser.add_argument('--train_size', default=0.8, type=float, help='ratio size of the training-set (default:0.8)')
parser.add_argument('--trainset_subset_ratio', default=0.5, type=float, help='ratio size of the training-set subset (default:0.5)')
parser.add_argument('--use_catached_data', default=True, type=_str2bool)
parser.add_argument("--train", action="store_true")  # default value is false
parser.add_argument('--num-epochs', default=200, type=int)
parser.add_argument('--batch-size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--Mat2Spec-input-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-label-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-latent-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-emb-size', default=512, type=int)
parser.add_argument('--Mat2Spec-keep-prob', default=0.5, type=float)
parser.add_argument('--Mat2Spec-scale-coeff', default=1.0, type=float)
parser.add_argument('--Mat2Spec-loss-type', default='MAE', type=str, choices=['MAE', 'KL', 'WD', 'MSE'])
parser.add_argument('--Mat2Spec-K', default=10, type=int)
parser.add_argument('--check-point-path', default=None, type=str)
parser.add_argument('--test-mpid', default='mpids.csv', type=str)
parser.add_argument("--finetune", action="store_true")  # default value is false
parser.add_argument("--finetune-dataset", default='null', type=str)
parser.add_argument("--ablation-LE", action="store_true")  # default value is false
parser.add_argument("--ablation-CL", action="store_true")  # default value is false
args = parser.parse_args(sys.argv[1:])

# GNN --- parameters
data_src = args.data_src
RSM = {'radius': 8, 'step': 0.2, 'max_num_nbr': 12}

number_layers = args.num_layers
number_neurons = args.num_neurons
n_heads = args.num_heads
concat_comp = args.concat_comp

# DATA PARAMETERS: seed every RNG used below for reproducible splits
random_num = 1
random.seed(random_num)
np.random.seed(random_num)
torch.manual_seed(random_num)
# MODEL HYPER-PARAMETERS
num_epochs = args.num_epochs
learning_rate = args.lr
batch_size = args.batch_size

stop_patience = 150
best_epoch = 1
adj_epochs = 50
milestones = [150, 250]
train_param = {'batch_size': batch_size, 'shuffle': True}
valid_param = {'batch_size': batch_size, 'shuffle': False}

# DATALOADER/ TARGET NORMALIZATION
if args.data_src == 'binned_dos_128':
    pd_data = pd.read_csv('../Mat2Spec_DATA/label_edos/' + args.test_mpid)
    np_data = np.load('../Mat2Spec_DATA/label_edos/total_dos_128.npy')
elif args.data_src == 'ph_dos_51':
    pd_data = pd.read_csv('../Mat2Spec_DATA/phdos/' + args.test_mpid)
    np_data = np.load('../Mat2Spec_DATA/phdos/ph_dos.npy')
elif args.data_src == 'no_label_128':
    pd_data = pd.read_csv('../Mat2Spec_DATA/no_label/' + args.test_mpid)
    np_data = np.random.rand(len(pd_data), 128)  # dummy label
else:
    # The parser accepts further choices for which no loader exists here.
    raise ValueError(f'data_src {args.data_src!r} is not supported by this script.')

NORMALIZER = DATA_normalizer(np_data)

if args.data_src == 'no_label_128':
    # The labels are dummies, so use the statistics stored alongside the data
    # (files named after binned_dos_128).
    mean_tmp = torch.tensor(np.load('../Mat2Spec_DATA/no_label/label_mean_binned_dos_128.npy'))
    std_tmp = torch.tensor(np.load('../Mat2Spec_DATA/no_label/label_std_binned_dos_128.npy'))
    NORMALIZER.mean = mean_tmp
    NORMALIZER.std = std_tmp

CRYSTAL_DATA = CIF_Dataset(args, pd_data=pd_data, np_data=np_data, root_dir='../Mat2Spec_DATA/', **RSM)

# ``train_idx_all`` is defined in every branch so the subset step below never
# hits an undefined name.
if args.data_src == 'ph_dos_51':
    with open('../Mat2Spec_DATA/phdos/200801_trteva_indices.pkl', 'rb') as f:
        train_idx_all, val_idx, test_idx = pickle.load(f)
elif args.data_src == 'no_label_128':
    train_idx_all = []
    test_idx = list(range(len(pd_data)))
else:
    idx_list = list(range(len(pd_data)))
    random.shuffle(idx_list)
    train_idx_all, test_val = train_test_split(idx_list, train_size=args.train_size, random_state=random_num)
    test_idx, val_idx = train_test_split(test_val, test_size=0.5, random_state=random_num)

if args.trainset_subset_ratio < 1.0 and len(train_idx_all) > 0:
    train_idx, _ = train_test_split(train_idx_all, train_size=args.trainset_subset_ratio, random_state=random_num)
else:
    train_idx = train_idx_all

if args.finetune:
    if args.finetune_dataset == 'null':
        raise ValueError('--finetune requires --finetune-dataset.')
    if args.data_src == 'binned_dos_128':
        with open('../Mat2Spec_DATA/20210619_binned_32_128/materials_classes/' + args.finetune_dataset + '/test_idx.json') as f:
            test_idx = json.load(f)
    else:
        raise ValueError('Finetuning is only supported on the binned dos 128 dataset.')

print('testing size:', len(test_idx))

testing_set = CIF_Lister(test_idx, CRYSTAL_DATA, df=pd_data)

print(f'> USING MODEL Mat2Spec!')
the_network = Mat2Spec(args, NORMALIZER)
net = the_network.to(device)
# load checkpoint
if args.finetune:
    check_point_path = './TRAINED/finetune/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling \
                       + '_' + args.Mat2Spec_loss_type + '_finetune_' + args.finetune_dataset + '.chkpt'
else:
    check_point_path = './TRAINED/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling \
                       + '_' + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.chkpt'

if args.ablation_LE:
    check_point_path = './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_KL_trainsize1.0_ablation_LE.chkpt'

if args.ablation_CL:
    check_point_path = './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_KL_trainsize1.0_ablation_CL.chkpt'

# An explicit --check-point-path overrides the derived path.
if args.check_point_path is not None:
    check_point_path = args.check_point_path
# map_location lets a checkpoint saved on a GPU load on whatever device is active.
check_point = torch.load(check_point_path, map_location=device)
net.load_state_dict(check_point['model'])

print(f'> TESTING MODEL ...')
test_loader = torch_DataLoader(dataset=testing_set, **valid_param)
'./TRAINED/finetune/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling \ 135 | + '_' + args.Mat2Spec_loss_type + '_finetune_' + args.finetune_dataset + '.chkpt' 136 | else: 137 | check_point_path = './TRAINED/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling \ 138 | + '_' + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.chkpt' 139 | 140 | if args.ablation_LE: 141 | check_point_path = './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_KL_trainsize1.0_ablation_LE.chkpt' 142 | 143 | if args.ablation_CL: 144 | check_point_path = './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_KL_trainsize1.0_ablation_CL.chkpt' 145 | 146 | if args.check_point_path is not None: 147 | check_point = torch.load(args.check_point_path) 148 | else: 149 | check_point = torch.load(check_point_path) 150 | net.load_state_dict(check_point['model']) 151 | 152 | print(f'> TESTING MODEL ...') 153 | test_loader = torch_DataLoader(dataset=testing_set, **valid_param) 154 | 155 | def test(): 156 | training_counter=0 157 | training_loss=0 158 | valid_counter=0 159 | valid_loss=0 160 | best_valid_loss=1e+10 161 | check_fre = 10 162 | current_step = 0 163 | checkpoint_path = './TRAINED/' 164 | 165 | total_loss_smooth = 0 166 | nll_loss_smooth = 0 167 | nll_loss_x_smooth = 0 168 | kl_loss_smooth = 0 169 | cpc_loss_smooth = 0 170 | prediction = [] 171 | prediction_x = [] 172 | label_gt = [] 173 | label_scale_value = [] 174 | sum_pred_smooth = 0 175 | 176 | start_time = time.time() 177 | 178 | # TESTING-PHASE 179 | net.eval() 180 | args.train = True 181 | for data in tqdm(test_loader, mininterval=0.5, desc='(testing)', position=0, leave=True, ascii=True): 182 | data = data.to(device) 183 | valid_label = deepcopy(data.y).float().to(device) 184 | 185 | if args.label_scaling == 'standardized': 186 | valid_label_normalize = (valid_label - NORMALIZER.mean.to(device)) / NORMALIZER.std.to(device) 187 | elif args.label_scaling == 'normalized_max': 188 | 
#valid_label_normalize = F.normalize(valid_label, dim=1, p=1) 189 | valid_label_normalize = valid_label/(torch.max(valid_label,dim=1)[0].unsqueeze(1)) 190 | 191 | elif args.label_scaling == 'normalized_sum': 192 | valid_label_normalize = valid_label / torch.sum(valid_label, dim=1, keepdim=True) 193 | 194 | with torch.no_grad(): 195 | predictions = net(data) 196 | total_loss, nll_loss, nll_loss_x, kl_loss, cpc_loss, pred_e, pred_x = \ 197 | compute_loss(valid_label_normalize, predictions, NORMALIZER, args) 198 | 199 | prediction.append(pred_e.detach().cpu().numpy()) 200 | prediction_x.append(pred_x.detach().cpu().numpy()) 201 | label_gt.append(valid_label.detach().cpu().numpy()) 202 | 203 | total_loss_smooth += total_loss 204 | nll_loss_smooth += nll_loss 205 | nll_loss_x_smooth += nll_loss_x 206 | kl_loss_smooth += kl_loss 207 | cpc_loss_smooth += cpc_loss 208 | valid_counter += 1 209 | 210 | total_loss_smooth = total_loss_smooth / valid_counter 211 | nll_loss_smooth = nll_loss_smooth / valid_counter 212 | nll_loss_x_smooth = nll_loss_x_smooth / valid_counter 213 | kl_loss_smooth = kl_loss_smooth / valid_counter 214 | cpc_loss_smooth = cpc_loss_smooth / valid_counter 215 | 216 | prediction = np.concatenate(prediction, axis=0) 217 | prediction_x = np.concatenate(prediction_x, axis=0) 218 | label_gt = np.concatenate(label_gt, axis=0) 219 | 220 | return prediction, prediction_x, label_gt, total_loss_smooth.cpu().numpy(), nll_loss_smooth.cpu().numpy(), nll_loss_x_smooth.cpu().numpy(), kl_loss_smooth.cpu().numpy() 221 | 222 | prediction_list = [] 223 | prediction_x_list = [] 224 | label_gt_list = [] 225 | total_loss_smooth_list = [] 226 | nll_loss_smooth_list = [] 227 | nll_loss_x_smooth_list = [] 228 | kl_loss_smooth_list = [] 229 | 230 | for i in range(3): 231 | print(i) 232 | prediction, prediction_x, label_gt, total_loss_smooth, nll_loss_smooth, nll_loss_x_smooth, kl_loss_smooth = test() 233 | prediction_list.append(np.expand_dims(prediction, axis=0)) 234 | 
prediction_x_list.append(np.expand_dims(prediction_x, axis=0)) 235 | label_gt_list.append(np.expand_dims(label_gt, axis=0)) 236 | total_loss_smooth_list.append(total_loss_smooth) 237 | nll_loss_smooth_list.append(nll_loss_smooth) 238 | nll_loss_x_smooth_list.append(nll_loss_x_smooth) 239 | kl_loss_smooth_list.append(kl_loss_smooth) 240 | 241 | total_loss_smooth = np.mean(total_loss_smooth_list) 242 | nll_loss_smooth = np.mean(nll_loss_smooth_list) 243 | nll_loss_x_smooth = np.mean(nll_loss_x_smooth_list) 244 | kl_loss_smooth = np.mean(kl_loss_smooth_list) 245 | 246 | prediction = np.concatenate(prediction_list, axis=0) 247 | prediction_x = np.concatenate(prediction_x_list, axis=0) 248 | label_gt = np.concatenate(label_gt_list, axis=0) 249 | 250 | #np.save('./RESULT/prediction_Mat2Spec_allsamples_' + args.data_src + '_' + args.label_scaling + '_' \ 251 | # + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', prediction_x) 252 | 253 | prediction_x_std = np.std(prediction_x, axis=0) 254 | prediction = np.mean(prediction, axis=0) 255 | prediction_x = np.mean(prediction_x, axis=0) 256 | label_gt = np.mean(label_gt, axis=0) 257 | 258 | result_path = './RESULT/' 259 | 260 | if args.finetune: 261 | result_path = result_path + '/finetune/' + args.finetune_dataset + '/' 262 | 263 | if args.ablation_LE: 264 | result_path = result_path + '/ablation_LE/' 265 | 266 | if args.ablation_CL: 267 | result_path = result_path + '/ablation_CL/' 268 | 269 | makedirs(result_path, exist_ok=True) 270 | 271 | if args.label_scaling == 'standardized': 272 | print('\n > label scaling: std') 273 | mean = NORMALIZER.mean.detach().numpy() 274 | std = NORMALIZER.std.detach().numpy() 275 | label_gt_standardized = (label_gt - mean) / std 276 | mae = np.mean(np.abs((prediction) - label_gt_standardized)) 277 | mae_x = np.mean(np.abs((prediction_x) - label_gt_standardized)) 278 | #if args.data_src != 'no_label_128' and args.data_src != 'no_label_32': 279 | prediction = 
prediction * std + mean 280 | prediction_x = prediction_x * std + mean 281 | prediction_x_std = prediction_x_std * std 282 | prediction[prediction < 0] = 1e-6 283 | prediction_x[prediction_x < 0] = 1e-6 284 | mae_ori = np.mean(np.abs((prediction)-label_gt)) 285 | mae_x_ori = np.mean(np.abs((prediction_x)-label_gt)) 286 | 287 | ## save results ## 288 | if args.data_src != 'no_label_128' and args.data_src != 'no_label_32': 289 | np.save(result_path + 'label_gt_' + args.data_src + '_' + args.label_scaling + '_' \ 290 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', label_gt) 291 | np.save(result_path + 'label_mean_' + args.data_src + '_' + args.label_scaling + '_' \ 292 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', mean) 293 | np.save(result_path + 'label_std_' + args.data_src + '_' + args.label_scaling + '_' \ 294 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', std) 295 | np.save(result_path + 'prediction_Mat2Spec_' + args.data_src + '_' + args.label_scaling + '_' \ 296 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', prediction_x) 297 | #np.save(result_path + 'prediction_Mat2Spec_standard_deviation_' + args.data_src + '_' + args.label_scaling + '_' \ 298 | # + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', prediction_x_std) 299 | testing_mpid = pd_data.iloc[test_idx] 300 | testing_mpid.to_csv(result_path + 'testing_mpids' + args.data_src + '_' + args.label_scaling + '_' \ 301 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.csv', index=False, header=True) 302 | 303 | elif args.label_scaling == 'normalized_max': 304 | print('\n > label scaling: norm max') 305 | label_max = np.expand_dims(np.max(label_gt, axis=1), axis=1) 306 | label_gt_standardized = label_gt / label_max 307 | mae = np.mean(np.abs((prediction) - label_gt_standardized)) 308 | mae_x 
= np.mean(np.abs((prediction_x) - label_gt_standardized)) 309 | if args.data_src != 'no_label_128' and args.data_src != 'no_label_32': 310 | prediction = prediction * label_max 311 | prediction_x = prediction_x * label_max 312 | prediction_x_std = prediction_x_std * label_max 313 | mae_ori = np.mean(np.abs((prediction) - label_gt)) 314 | mae_x_ori = np.mean(np.abs((prediction_x) - label_gt)) 315 | 316 | ## save results ## 317 | if args.data_src != 'no_label_128' and args.data_src != 'no_label_32': 318 | np.save(result_path + 'label_gt_' + args.data_src + '_' + args.label_scaling + '_' \ 319 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', label_gt) 320 | np.save(result_path + 'label_max_' + args.data_src + '_' + args.label_scaling + '_' \ 321 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', label_max) 322 | np.save(result_path + 'prediction_Mat2Spec_' + args.data_src + '_' + args.label_scaling + '_' \ 323 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', prediction_x) 324 | #np.save(result_path + 'prediction_Mat2Spec_standard_deviation_' + args.data_src + '_' + args.label_scaling + '_' \ 325 | # + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', prediction_x_std) 326 | testing_mpid = pd_data.iloc[test_idx] 327 | testing_mpid.to_csv('testing_mpids' + args.data_src + '_' + args.label_scaling + '_' \ 328 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.csv', index=False, header=True) 329 | 330 | elif args.label_scaling == 'normalized_sum': 331 | print('\n > label scaling: norm sum') 332 | assert args.Mat2Spec_loss_type == 'KL' or args.Mat2Spec_loss_type == 'WD' 333 | label_sum = np.sum(label_gt, axis=1, keepdims=True) 334 | label_gt_standardized = label_gt / label_sum 335 | mae = np.mean(np.abs((prediction) - label_gt_standardized)) 336 | mae_x = np.mean(np.abs((prediction_x) - 
label_gt_standardized)) 337 | if args.data_src != 'no_label_128' and args.data_src != 'no_label_32': 338 | prediction = prediction * label_sum 339 | prediction_x = prediction_x * label_sum 340 | prediction_x_std = prediction_x_std * label_sum 341 | mae_ori = np.mean(np.abs((prediction) - label_gt)) 342 | mae_x_ori = np.mean(np.abs((prediction_x) - label_gt)) 343 | 344 | ## save results ## 345 | if args.data_src != 'no_label_128' and args.data_src != 'no_label_32': 346 | np.save(result_path + 'label_gt_' + args.data_src + '_' + args.label_scaling + '_' \ 347 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', label_gt) 348 | np.save(result_path + 'label_sum_' + args.data_src + '_' + args.label_scaling + '_' \ 349 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', label_sum) 350 | np.save(result_path + 'prediction_Mat2Spec_' + args.data_src + '_' + args.label_scaling + '_' \ 351 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', prediction_x) 352 | #np.save(result_path + 'prediction_Mat2Spec_standard_deviation_' + args.data_src + '_' + args.label_scaling + '_' \ 353 | # + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.npy', prediction_x_std) 354 | testing_mpid = pd_data.iloc[test_idx] 355 | testing_mpid.to_csv(result_path + 'testing_mpids_' + args.data_src + '_' + args.label_scaling + '_' \ 356 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.csv', index=False, header=True) 357 | 358 | print("\n********** TESTING STATISTIC ***********") 359 | print("total_loss =%.6f\t nll_loss =%.6f\t nll_loss_x =%.6f\t kl_loss =%.6f\t" % 360 | (total_loss_smooth, nll_loss_smooth, nll_loss_x_smooth, kl_loss_smooth)) 361 | print("mae=%.6f\t mae_x=%.6f\t mae_ori=%.6f\t mae_x_ori=%.6f" % (mae, mae_x, mae_ori, mae_x_ori)) 362 | print("\n*****************************************") 363 | 364 | print(f"> DONE TESTING !") 365 
| -------------------------------------------------------------------------------- /Mat2Spec_Codes/train_Mat2Spec.py: -------------------------------------------------------------------------------- 1 | from Mat2Spec.data import * 2 | from Mat2Spec.Mat2Spec import * 3 | from Mat2Spec.file_setter import use_property 4 | from Mat2Spec.utils import * 5 | import matplotlib.pyplot as plt 6 | import random 7 | from tqdm import tqdm 8 | import gc 9 | import pickle 10 | from copy import copy, deepcopy 11 | import json 12 | torch.autograd.set_detect_anomaly(True) 13 | device = set_device() 14 | 15 | # MOST CRUCIAL DATA PARAMETERS 16 | parser = argparse.ArgumentParser(description='Mat2Spec') 17 | parser.add_argument('--data_src', default='binned_dos_128',choices=['binned_dos_128','binned_dos_32','ph_dos_51']) 18 | parser.add_argument('--label_scaling', default='standardized',choices=['standardized','normalized_sum', 'normalized_max']) 19 | # MOST CRUCIAL MODEL PARAMETERS 20 | parser.add_argument('--num_layers',default=3, type=int, 21 | help='number of AGAT layers to use in model (default:3)') 22 | parser.add_argument('--num_neurons',default=128, type=int, 23 | help='number of neurons to use per AGAT Layer(default:128)') 24 | parser.add_argument('--num_heads',default=4, type=int, 25 | help='number of Attention-Heads to use per AGAT Layer (default:4)') 26 | parser.add_argument('--concat_comp',default=False, type=bool, 27 | help='option to re-use vector of elemental composition after global summation of crystal feature.(default: False)') 28 | parser.add_argument('--train_size',default=0.8, type=float, help='ratio size of the training-set (default:0.8)') 29 | parser.add_argument('--trainset_subset_ratio',default=0.5, type=float, help='ratio size of the training-set subset (default:0.5)') 30 | parser.add_argument('--use_catached_data', default=True, type=bool) 31 | parser.add_argument("--train",action="store_true") # default value is false 32 | 
parser.add_argument('--num-epochs',default=200, type=int) 33 | parser.add_argument('--batch-size',default=256, type=int) 34 | parser.add_argument('--lr',default=0.001, type=float) 35 | parser.add_argument('--Mat2Spec-input-dim',default=128, type=int) 36 | parser.add_argument('--Mat2Spec-label-dim',default=128, type=int) 37 | parser.add_argument('--Mat2Spec-latent-dim',default=128, type=int) 38 | parser.add_argument('--Mat2Spec-emb-size',default=512, type=int) 39 | parser.add_argument('--Mat2Spec-keep-prob',default=0.5, type=float) 40 | parser.add_argument('--Mat2Spec-scale-coeff',default=1.0, type=float) 41 | parser.add_argument('--Mat2Spec-loss-type',default='MAE', type=str, choices=['MAE', 'KL', 'WD', 'MSE']) 42 | parser.add_argument('--Mat2Spec-K',default=10, type=int) 43 | parser.add_argument("--finetune",action="store_true") # default value is false 44 | parser.add_argument("--ablation-LE",action="store_true") # default value is false 45 | parser.add_argument("--ablation-CL",action="store_true") # default value is false 46 | parser.add_argument("--finetune-dataset",default='null',type=str) 47 | parser.add_argument('--check-point-path', default=None, type=str) 48 | args = parser.parse_args(sys.argv[1:]) 49 | 50 | # GNN --- parameters 51 | data_src = args.data_src 52 | RSM = {'radius': 8, 'step': 0.2, 'max_num_nbr': 12} 53 | 54 | number_layers = args.num_layers 55 | number_neurons = args.num_neurons 56 | n_heads = args.num_heads 57 | concat_comp = args.concat_comp 58 | 59 | # DATA PARAMETERS 60 | random_num = 1; random.seed(random_num) 61 | np.random.seed(random_num) 62 | torch.manual_seed(random_num) 63 | # MODEL HYPER-PARAMETERS 64 | num_epochs = args.num_epochs 65 | learning_rate = args.lr 66 | batch_size = args.batch_size 67 | 68 | stop_patience = 150 69 | best_epoch = 1 70 | adj_epochs = 50 71 | milestones = [150,250] 72 | train_param = {'batch_size':batch_size, 'shuffle': True} 73 | valid_param = {'batch_size':batch_size, 'shuffle': False} 74 | 75 | # 
DATALOADER/ TARGET NORMALIZATION 76 | if args.data_src == 'binned_dos_128': 77 | pd_data = pd.read_csv(f'../Mat2Spec_DATA/label_edos/mpids.csv') 78 | np_data = np.load(f'../Mat2Spec_DATA/label_edos/total_dos_128.npy') 79 | elif args.data_src == 'ph_dos_51': 80 | pd_data = pd.read_csv(f'../Mat2Spec_DATA/phdos/mpids.csv') 81 | np_data = np.load(f'../Mat2Spec_DATA/phdos/ph_dos.npy') 82 | else: 83 | raise ValueError('') 84 | 85 | NORMALIZER = DATA_normalizer(np_data) 86 | 87 | CRYSTAL_DATA = CIF_Dataset(args, pd_data=pd_data, np_data=np_data, root_dir=f'../Mat2Spec_DATA/', **RSM) 88 | 89 | if args.data_src == 'ph_dos_51': 90 | with open('../Mat2Spec_DATA/phdos/200801_trteva_indices.pkl', 'rb') as f: 91 | train_idx, val_idx, test_idx = pickle.load(f) 92 | else: 93 | idx_list = list(range(len(pd_data))) 94 | random.shuffle(idx_list) 95 | train_idx_all, test_val = train_test_split(idx_list, train_size=args.train_size, random_state=random_num) 96 | test_idx, val_idx = train_test_split(test_val, test_size=0.5, random_state=random_num) 97 | 98 | if args.trainset_subset_ratio < 1.0: 99 | train_idx, _ = train_test_split(train_idx_all, train_size=args.trainset_subset_ratio, random_state=random_num) 100 | elif args.data_src != 'ph_dos_51': 101 | train_idx = train_idx_all 102 | 103 | if args.finetune: 104 | assert args.finetune_dataset != 'null' 105 | if args.data_src == 'binned_dos_128': 106 | with open(f'../Mat2Spec_DATA/label_edos/materials_classes/' + args.finetune_dataset + '/train_idx.json', ) as f: 107 | train_idx = json.load(f) 108 | 109 | with open(f'../Mat2Spec_DATA/label_edos/materials_classes/' + args.finetune_dataset + '/val_idx.json', ) as f: 110 | val_idx = json.load(f) 111 | 112 | with open(f'../Mat2Spec_DATA/label_edos/materials_classes/' + args.finetune_dataset + '/test_idx.json', ) as f: 113 | test_idx = json.load(f) 114 | else: 115 | raise ValueError('Finetuning is only supported on the binned dos 128 dataset.') 116 | 117 | #print('total size:', len(idx_list)) 
118 | print('training size:', len(train_idx)) 119 | print('validation size:', len(val_idx)) 120 | print('testing size:', len(test_idx)) 121 | print('total size:', len(train_idx)+len(val_idx)+len(test_idx)) 122 | 123 | training_set = CIF_Lister(train_idx,CRYSTAL_DATA,df=pd_data) 124 | validation_set = CIF_Lister(val_idx,CRYSTAL_DATA,df=pd_data) 125 | 126 | print(f'> USING MODEL Mat2Spec!') 127 | the_network = Mat2Spec(args, NORMALIZER) 128 | net = the_network.to(device) 129 | 130 | if args.finetune: 131 | # load checkpoint 132 | check_point = torch.load(args.check_point_path) 133 | net.load_state_dict(check_point['model']) 134 | learning_rate = learning_rate/5 135 | 136 | # LOSS & OPTMIZER & SCHEDULER 137 | optimizer = optim.AdamW(net.parameters(), lr = learning_rate, weight_decay = 1e-2) 138 | #optimizer = optim.SGD(net.parameters(), lr = learning_rate, momentum=0.9) 139 | 140 | decay_times = 4 141 | decay_ratios = 0.5 142 | one_epoch_iter = np.ceil(len(train_idx) / batch_size) 143 | 144 | if args.finetune: 145 | decay_ratios = 0.5 146 | 147 | scheduler = lr_scheduler.StepLR(optimizer, one_epoch_iter * (num_epochs / decay_times), decay_ratios) 148 | 149 | print(f'> TRAINING MODEL ...') 150 | train_loader = torch_DataLoader(dataset=training_set, **train_param) 151 | valid_loader = torch_DataLoader(dataset=validation_set, **valid_param) 152 | 153 | training_counter=0 154 | training_loss=0 155 | valid_counter=0 156 | valid_loss=0 157 | best_valid_loss=1e+10 158 | check_fre = 10 159 | current_step = 0 160 | 161 | total_loss_smooth = 0 162 | nll_loss_smooth = 0 163 | nll_loss_x_smooth = 0 164 | kl_loss_smooth = 0 165 | cpc_loss_smooth = 0 166 | prediction = [] 167 | prediction_x = [] 168 | label_gt = [] 169 | label_scale_value = [] 170 | sum_pred_smooth = 0 171 | 172 | start_time = time.time() 173 | for epoch in range(num_epochs): 174 | 175 | # TRAINING-STAGE 176 | net.train() 177 | args.train = True 178 | for data in tqdm(train_loader, mininterval=0.5, 
desc=f'(EPOCH:{epoch} TRAINING)', position=0, leave=True, ascii=True): 179 | current_step += 1 180 | data = data.to(device) 181 | train_label = deepcopy(data.y).to(device) 182 | if args.label_scaling == 'standardized': 183 | train_label_normalize = (train_label - NORMALIZER.mean.to(device)) / NORMALIZER.std.to(device) 184 | elif args.label_scaling == 'normalized_max': 185 | train_label_normalize = train_label / (torch.max(train_label,dim=1)[0].unsqueeze(1)) 186 | elif args.label_scaling == 'normalized_sum': 187 | train_label_normalize = train_label / torch.sum(train_label, dim=1, keepdim=True) 188 | 189 | predictions = net(data) 190 | total_loss, nll_loss, nll_loss_x, kl_loss, c_loss, pred_e, pred_x = \ 191 | compute_loss(train_label_normalize, predictions, NORMALIZER, args) 192 | 193 | optimizer.zero_grad() 194 | total_loss.backward() 195 | optimizer.step() 196 | scheduler.step() 197 | 198 | prediction.append(pred_e.detach().cpu().numpy()) 199 | prediction_x.append(pred_x.detach().cpu().numpy()) 200 | label_gt.append(train_label.detach().cpu().numpy()) 201 | 202 | total_loss_smooth += total_loss 203 | nll_loss_smooth += nll_loss 204 | nll_loss_x_smooth += nll_loss_x 205 | kl_loss_smooth += kl_loss 206 | cpc_loss_smooth += c_loss 207 | training_counter +=1 208 | 209 | total_loss_smooth = total_loss_smooth / training_counter 210 | nll_loss_smooth = nll_loss_smooth / training_counter 211 | nll_loss_x_smooth = nll_loss_x_smooth / training_counter 212 | kl_loss_smooth = kl_loss_smooth / training_counter 213 | cpc_loss_smooth = cpc_loss_smooth / training_counter 214 | 215 | prediction = np.concatenate(prediction, axis=0) 216 | prediction_x = np.concatenate(prediction_x, axis=0) 217 | label_gt = np.concatenate(label_gt, axis=0) 218 | 219 | if args.label_scaling == 'standardized': 220 | mean = NORMALIZER.mean.detach().numpy() 221 | std = NORMALIZER.std.detach().numpy() 222 | label_gt_standardized = (label_gt-mean)/std 223 | mae = 
np.mean(np.abs((prediction)-label_gt_standardized)) 224 | mae_x = np.mean(np.abs((prediction_x)-label_gt_standardized)) 225 | prediction = prediction*std+mean 226 | prediction_x = prediction_x*std+mean 227 | prediction[prediction < 0] = 0 228 | prediction_x[prediction_x < 0] = 0 229 | mae_ori = np.mean(np.abs((prediction)-label_gt)) 230 | mae_x_ori = np.mean(np.abs((prediction_x)-label_gt)) 231 | 232 | elif args.label_scaling == 'normalized_max': 233 | label_max = np.expand_dims(np.max(label_gt, axis=1), axis=1) 234 | label_gt_standardized = label_gt / label_max 235 | mae = np.mean(np.abs((prediction)-label_gt_standardized)) 236 | mae_x = np.mean(np.abs((prediction_x)-label_gt_standardized)) 237 | prediction = prediction*label_max 238 | prediction_x = prediction_x*label_max 239 | mae_ori = np.mean(np.abs((prediction)-label_gt)) 240 | mae_x_ori = np.mean(np.abs((prediction_x)-label_gt)) 241 | 242 | elif args.label_scaling == 'normalized_sum': 243 | assert args.Mat2Spec_loss_type == 'KL' or args.Mat2Spec_loss_type == 'WD' 244 | label_sum = np.sum(label_gt, axis=1, keepdims=True) 245 | label_gt_standardized = label_gt / label_sum 246 | mae = np.mean(np.abs((prediction)-label_gt_standardized)) 247 | mae_x = np.mean(np.abs((prediction_x)-label_gt_standardized)) 248 | prediction = prediction*label_sum 249 | prediction_x = prediction_x*label_sum 250 | mae_ori = np.mean(np.abs((prediction)-label_gt)) 251 | mae_x_ori = np.mean(np.abs((prediction_x)-label_gt)) 252 | 253 | print("\n********** TRAINING STATISTIC ***********") 254 | print("total_loss =%.6f\t nll_loss =%.6f\t nll_loss_x =%.6f\t kl_loss =%.6f\t cpc_loss=%.6f\t" % 255 | (total_loss_smooth, nll_loss_smooth, nll_loss_x_smooth, kl_loss_smooth, cpc_loss_smooth)) 256 | print("mae=%.6f\t mae_x=%.6f\t mae_ori=%.6f\t mae_x_ori=%.6f" % (mae, mae_x, mae_ori, mae_x_ori)) 257 | print("\n*****************************************") 258 | 259 | training_counter = 0 260 | total_loss_smooth = 0 261 | nll_loss_smooth = 0 262 | 
nll_loss_x_smooth = 0 263 | kl_loss_smooth = 0 264 | cpc_loss_smooth = 0 265 | prediction = [] 266 | prediction_x = [] 267 | label_gt = [] 268 | label_scale_value = [] 269 | sum_pred_smooth = 0 270 | 271 | # VALIDATION-PHASE 272 | net.eval() 273 | for data in tqdm(valid_loader, mininterval=0.5, desc='(validating)', position=0, leave=True, ascii=True): 274 | data = data.to(device) 275 | valid_label = deepcopy(data.y).float().to(device) 276 | 277 | if args.label_scaling == 'standardized': 278 | valid_label_normalize = (valid_label - NORMALIZER.mean.to(device)) / NORMALIZER.std.to(device) 279 | elif args.label_scaling == 'normalized_max': 280 | valid_label_normalize = valid_label/(torch.max(valid_label,dim=1)[0].unsqueeze(1)) 281 | 282 | elif args.label_scaling == 'normalized_sum': 283 | valid_label_normalize = valid_label / (torch.sum(valid_label, dim=1, keepdim=True)+1e-8) 284 | 285 | with torch.no_grad(): 286 | predictions = net(data) 287 | total_loss, nll_loss, nll_loss_x, kl_loss, cpc_loss, pred_e, pred_x = \ 288 | compute_loss(valid_label_normalize, predictions, NORMALIZER, args) 289 | 290 | prediction.append(pred_e.detach().cpu().numpy()) 291 | prediction_x.append(pred_x.detach().cpu().numpy()) 292 | label_gt.append(valid_label.detach().cpu().numpy()) 293 | 294 | total_loss_smooth += total_loss 295 | nll_loss_smooth += nll_loss 296 | nll_loss_x_smooth += nll_loss_x 297 | kl_loss_smooth += kl_loss 298 | cpc_loss_smooth += cpc_loss 299 | valid_counter += 1 300 | 301 | total_loss_smooth = total_loss_smooth / valid_counter 302 | nll_loss_smooth = nll_loss_smooth / valid_counter 303 | nll_loss_x_smooth = nll_loss_x_smooth / valid_counter 304 | kl_loss_smooth = kl_loss_smooth / valid_counter 305 | cpc_loss_smooth = cpc_loss_smooth / valid_counter 306 | 307 | prediction = np.concatenate(prediction, axis=0) 308 | prediction_x = np.concatenate(prediction_x, axis=0) 309 | label_gt = np.concatenate(label_gt, axis=0) 310 | 311 | if args.label_scaling == 'standardized': 312 
| mean = NORMALIZER.mean.detach().numpy() 313 | std = NORMALIZER.std.detach().numpy() 314 | label_gt_standardized = (label_gt-mean)/std 315 | mae = np.mean(np.abs((prediction)-label_gt_standardized)) 316 | mae_x = np.mean(np.abs((prediction_x)-label_gt_standardized)) 317 | prediction = prediction*std+mean 318 | prediction_x = prediction_x*std+mean 319 | prediction[prediction < 0] = 0 320 | prediction_x[prediction_x < 0] = 0 321 | mae_ori = np.mean(np.abs((prediction)-label_gt)) 322 | mae_x_ori = np.mean(np.abs((prediction_x)-label_gt)) 323 | 324 | elif args.label_scaling == 'normalized_max': 325 | label_max = np.expand_dims(np.max(label_gt, axis=1), axis=1) 326 | label_gt_standardized = label_gt / label_max 327 | mae = np.mean(np.abs((prediction)-label_gt_standardized)) 328 | mae_x = np.mean(np.abs((prediction_x)-label_gt_standardized)) 329 | prediction = prediction*label_max 330 | prediction_x = prediction_x*label_max 331 | mae_ori = np.mean(np.abs((prediction)-label_gt)) 332 | mae_x_ori = np.mean(np.abs((prediction_x)-label_gt)) 333 | 334 | elif args.label_scaling == 'normalized_sum': 335 | assert args.Mat2Spec_loss_type == 'KL' or args.Mat2Spec_loss_type == 'WD' 336 | label_sum = np.sum(label_gt, axis=1, keepdims=True) 337 | label_gt_standardized = label_gt / label_sum 338 | mae = np.mean(np.abs((prediction)-label_gt_standardized)) 339 | mae_x = np.mean(np.abs((prediction_x)-label_gt_standardized)) 340 | prediction = prediction*label_sum 341 | prediction_x = prediction_x*label_sum 342 | mae_ori = np.mean(np.abs((prediction)-label_gt)) 343 | mae_x_ori = np.mean(np.abs((prediction_x)-label_gt)) 344 | 345 | print("\n********** VALIDATING STATISTIC ***********") 346 | print("total_loss =%.6f\t nll_loss =%.6f\t nll_loss_x =%.6f\t kl_loss =%.6f\t cpc_loss = %.6f\t" % 347 | (total_loss_smooth, nll_loss_smooth, nll_loss_x_smooth, kl_loss_smooth, cpc_loss_smooth)) 348 | print("mae=%.6f\t mae_x=%.6f\t mae_ori=%.6f\t mae_x_ori=%.6f" % (mae, mae_x, mae_ori, mae_x_ori)) 349 
| print("\n*****************************************") 350 | 351 | if best_valid_loss > mae_x_ori: 352 | best_valid_loss = mae_x_ori 353 | print("\n********** SAVING MODEL ***********") 354 | checkpoint = {'model': net.state_dict(), 'args': args} 355 | if not args.finetune: 356 | #checkpoint_path = './TRAINED/' 357 | save_path = './TRAINED/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling + '_' \ 358 | + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) 359 | else: 360 | save_path = './TRAINED/finetune/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling + '_' \ 361 | + args.Mat2Spec_loss_type + '_finetune_' + str(args.finetune_dataset) 362 | 363 | if args.ablation_LE: 364 | save_path = save_path + '_ablation_LE' 365 | 366 | if args.ablation_CL: 367 | save_path = save_path + '_ablation_CL' 368 | 369 | save_path = save_path + '.chkpt' 370 | torch.save(checkpoint, save_path) 371 | print("A new model has been saved to " + save_path) 372 | print("\n*****************************************") 373 | 374 | valid_counter=0 375 | total_loss_smooth = 0 376 | nll_loss_smooth = 0 377 | nll_loss_x_smooth = 0 378 | kl_loss_smooth = 0 379 | cpc_loss_smooth = 0 380 | prediction = [] 381 | prediction_x = [] 382 | label_gt = [] 383 | label_scale_value = [] 384 | sum_pred_smooth = 0 385 | gc.collect() 386 | 387 | end_time = time.time() 388 | e_time = end_time - start_time 389 | print('Best validation loss=%.6f, training time (min)=%.6f'%(best_valid_loss, e_time/60)) 390 | print(f"> DONE TRAINING !") 391 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Density of States Prediction for Materials Discovery via Contrastive Learning from Probabilistic Embeddings 2 | 3 | Authors: Shufeng Kong 1, Francesco Ricci 2,4, Dan Guevarra 3, Jeffrey B. Neaton 2,5,6, Carla P. Gomes 1, and John M. 
Gregoire 3 4 | 1) Department of Computer Science, Cornell University, Ithaca, NY, USA 5 | 2) Material Science Division, Lawrence Berkeley National Laboratory, Berkeley, CA, USA 6 | 3) Division of Engineering and Applied Science, California Institute of Technology, Pasadena, CA, USA 7 | 4) Chemical Science Division, Lawrence Berkeley National Laboratory, Berkeley, CA, USA 8 | 5) Department of Physics, University of California, Berkeley, Berkeley, CA, USA 9 | 6) Kavli Energy NanoSciences Institute at Berkeley, Berkeley, CA, USA 10 | 11 | This is a PyTorch implementation of the machine learning model "Mat2Spec" presented in this paper (https://www.nature.com/articles/s41467-022-28543-x). 12 | Please send any questions or suggestions about the code directly to sk2299@cornell.edu 13 | 14 | ### Installation 15 | Install the following packages if they are not already installed (installing all of them may take about 30 minutes on a typical machine): 16 | * Python (tested on 3.8.11) 17 | * PyTorch (tested on 1.4.0) 18 | * CUDA (tested on 10.0) 19 | * Pandas (tested on 1.3.3) 20 | * Pymatgen (tested on 2022.0.14) 21 | * PyTorch-Geometric (tested on 1.5.0) 22 | 23 | Please follow these steps to create an environment: 24 | 25 | 1) Download packages - example: 26 | https://download.pytorch.org/whl/cu100/torch-1.4.0%2Bcu100-cp38-cp38-linux_x86_64.whl 27 | https://download.pytorch.org/whl/cu100/torchvision-0.5.0%2Bcu100-cp38-cp38-linux_x86_64.whl 28 | https://data.pyg.org/whl/torch-1.4.0/torch_cluster-1.5.4%2Bcu100-cp38-cp38-linux_x86_64.whl 29 | https://data.pyg.org/whl/torch-1.4.0/torch_scatter-2.0.4%2Bcu100-cp38-cp38-linux_x86_64.whl 30 | https://data.pyg.org/whl/torch-1.4.0/torch_sparse-0.6.1%2Bcu100-cp38-cp38-linux_x86_64.whl 31 | https://data.pyg.org/whl/torch-1.4.0/torch_spline_conv-1.2.0%2Bcu100-cp38-cp38-linux_x86_64.whl 32 | 33 | 2) Install packages - example 34 | 35 | ```bash 36 | conda create --name mat2spec python=3.8 37 | conda activate mat2spec 38 | pip install
torch-1.4.0+cu100-cp38-cp38-linux_x86_64.whl 39 | pip install torchvision-0.5.0+cu100-cp38-cp38-linux_x86_64.whl 40 | pip install torch_cluster-1.5.4+cu100-cp38-cp38-linux_x86_64.whl 41 | pip install torch_scatter-2.0.4+cu100-cp38-cp38-linux_x86_64.whl 42 | pip install torch_sparse-0.6.1+cu100-cp38-cp38-linux_x86_64.whl 43 | pip install torch_spline_conv-1.2.0+cu100-cp38-cp38-linux_x86_64.whl 44 | pip install torch-geometric==1.5.0 45 | pip install pandas 46 | pip install pymatgen 47 | ``` 48 | 49 | When you finish using our model, you can deactivate the environment: 50 | ```bash 51 | conda deactivate 52 | ``` 53 | 54 | Remember to activate the environment before using our model next time: 55 | ```bash 56 | conda activate mat2spec 57 | ``` 58 | 59 | ### Datasets 60 | 61 | 1) Phonon density of states: see our data repository link below, or data can be downloaded from here https://github.com/zhantaochen/phonondos_e3nn. 62 | 2) Electronic density of states: see our data repository link below, or data can be downloaded from the Materials Project. 63 | 3) Initial element embeddings: please refer to "Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties" by Tian Xie and Jeffrey C. Grossman. 64 | 65 | These initial element embeddings include the embeddings of the following elements: 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr'. 
66 | 67 | Datasets for this work are available at https://data.caltech.edu/records/8975 68 | 69 | Please download the data folder and unzip it under the main folder 'Mat2Spec'. 70 | 71 | ### Example Usage 72 | 73 | Model training typically takes 20 min for phDOS and 3 hours for eDOS on a GPU. 74 | 75 | To train the model on phDOS with maxnorm and MSE: 76 | ```bash 77 | bash SCRIPTS/train_phdos51_norm_max_mse.sh 78 | ``` 79 | Note that the bash scripts manually assign the CUDA device index via environment variable CUDA_VISIBLE_DEVICES and should be adjusted to the correct index (usually '0' for single GPU systems) prior to training or else Pytorch will only leverage CPU. 80 | 81 | To train the model on eDOS with std and MAE: 82 | ```bash 83 | bash SCRIPTS/train_dos128_std_mae.sh 84 | ``` 85 | 86 | To train the model on eDOS with norm sum and KL: 87 | ```bash 88 | bash SCRIPTS/train_dos128_norm_sum_kl.sh 89 | ``` 90 | 91 | To test the trained models: 92 | ```bash 93 | bash SCRIPTS/test_phdos51_norm_max_mse.sh 94 | bash SCRIPTS/test_dos128_std_mae.sh 95 | bash SCRIPTS/test_dos128_norm_sum_kl.sh 96 | ``` 97 | 98 | To use the trained models for predicting eDOS for materials without labels: 99 | 100 | 1) Place your json files under ./Mat2Spec_DATA/materials_without_dos/ 101 | Each json file should include a key 'structure' which maps to a material in the pymatgen format. 
102 | 103 | 2) Place a csv file named 'mpids.csv' that contains all your json files' names under ./DATA/20210623_no_label 104 | 105 | 3) If you want to use trained models with std and MAE: 106 | 107 | ```bash 108 | bash SCRIPTS/test_nolabel128_std_mae.sh 109 | ``` 110 | 111 | 4) If you want to use trained models with norm sum and KL: 112 | 113 | ```bash 114 | bash SCRIPTS/test_nolabel128_std_mae.sh 115 | bash SCRIPTS/test_nolabel128_norm_sum_kl.sh 116 | ``` 117 | 118 | Then rescale the KL prediction with the std prediction: 119 | ```python 120 | x_sd = np.load('prediction_Mat2Spec_no_label_128_standardized_MAE_trainsize1.0.npy') 121 | x_kl = np.load('prediction_Mat2Spec_no_label_128_normalized_sum_KL_trainsize1.0.npy') 122 | x = x_kl*np.sum(x_sd, axis=-1, keepdims=True) 123 | ``` 124 | 125 | 126 | All test results (model-predicted DOS) are placed under ./RESULT 127 | 128 | 129 | ### Disclaimer 130 | This is research code shared without support or any guarantee on its quality. However, please do raise an issue or submit a pull request if you spot something wrong or that could be improved and I will try my best to solve it. 131 | 132 | ### Acknowledgements 133 | Implementation of the GNN is modified from GATGNN: https://github.com/superlouis/GATGNN. 134 | --------------------------------------------------------------------------------