├── MIT LICENSE.txt
├── Mat2Spec_Codes
├── Mat2Spec
│ ├── Mat2Spec.py
│ ├── SinkhornDistance.py
│ ├── __init__.py
│ ├── data.py
│ ├── file_setter.py
│ ├── pytorch_stats_loss.py
│ └── utils.py
├── SCRIPTS
│ ├── test_dos128_norm_sum_kl.sh
│ ├── test_dos128_norm_sum_wd.sh
│ ├── test_dos128_std_mae.sh
│ ├── test_nolabel128_norm_sum_kl.sh
│ ├── test_nolabel128_norm_sum_wd.sh
│ ├── test_nolabel128_std_mae.sh
│ ├── test_phdos51_norm_max_mae.sh
│ ├── test_phdos51_norm_max_mse.sh
│ ├── test_phdos51_norm_sum_kl.sh
│ ├── test_phdos51_norm_sum_wd.sh
│ ├── train_dos128_norm_sum_kl.sh
│ ├── train_dos128_norm_sum_wd.sh
│ ├── train_dos128_std_mae.sh
│ ├── train_phdos51_norm_max_mae.sh
│ ├── train_phdos51_norm_max_mse.sh
│ ├── train_phdos51_norm_sum_kl.sh
│ └── train_phdos51_norm_sum_wd.sh
├── test_Mat2Spec.py
└── train_Mat2Spec.py
└── README.md
/MIT LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021-2022
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/Mat2Spec/Mat2Spec.py:
--------------------------------------------------------------------------------
1 | import torch, numpy as np
2 | import torch.optim as optim
3 | from torch.optim import lr_scheduler
4 | from torch.nn import Linear, Dropout, Parameter
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 |
8 | from torch_geometric.nn.conv import MessagePassing
9 | from torch_geometric.utils import softmax
10 | from torch_geometric.nn import global_add_pool, global_mean_pool
11 | from torch_geometric.nn import GATConv
12 | from torch_scatter import scatter_add
13 | from torch_geometric.nn.inits import glorot, zeros
14 |
15 | from random import sample
16 | from copy import copy, deepcopy
17 | from Mat2Spec.utils import *
18 | from Mat2Spec.SinkhornDistance import SinkhornDistance
19 | from Mat2Spec.pytorch_stats_loss import torch_wasserstein_loss
20 |
# Module-level setup; these statements run once when the module is imported.
# `set_device` is not defined in this file; presumably it is provided by the
# `Mat2Spec.utils` star import -- TODO confirm.
device = set_device()
torch.cuda.empty_cache()
# NOTE(review): in the code visible in this file, `kl_loss_fn` and `sinkhorn`
# are only referenced from commented-out lines inside `compute_loss`.
kl_loss_fn = torch.nn.KLDivLoss()
sinkhorn = SinkhornDistance(eps=0.1, max_iter=50, reduction='mean').to(device)
25 |
26 |
27 | # Note: the part of GNN implementation is modified from https://github.com/superlouis/GATGNN/
28 |
class COMPOSITION_Attention(torch.nn.Module):
    """Per-atom attention weights conditioned on the crystal composition.

    Every atom embedding is concatenated with the 103-wide composition vector
    of the crystal it belongs to, pushed through a small two-layer scorer and
    normalised with a softmax taken over the atoms of the same crystal.
    """

    def __init__(self, neurons):
        super().__init__()
        # Input = atom embedding (`neurons`) + composition vector (103).
        self.node_layer1 = Linear(neurons + 103, 32)
        self.atten_layer = Linear(32, 1)

    def forward(self, x, batch, global_feat):
        """Return one attention weight per atom.

        x:           [num_atom, atom_emb_len] atom embeddings.
        batch:       for every atom, the index of its crystal inside the batch.
        global_feat: [bs, 103]; each row is an atom-composition vector.
        """
        # Number of atoms in each crystal, used to repeat each composition row
        # once per atom of that crystal.
        atoms_per_crystal = torch.unique(batch, return_counts=True)[-1]
        per_atom_composition = torch.repeat_interleave(global_feat, atoms_per_crystal, dim=0)

        scorer_input = torch.cat([x, per_atom_composition], dim=-1)
        hidden = F.softplus(self.node_layer1(scorer_input))  # [num_atom, 32]
        scores = self.atten_layer(hidden)                    # [num_atom, 1]
        return softmax(scores, batch)                        # [num_atom, 1]
50 |
51 |
class GAT_Crystal(MessagePassing):
    """Multi-head graph-attention layer that can mix edge attributes into the messages.

    Messages are summed (`aggr='add'`) and flow `target_to_source`. The heads
    are either concatenated or averaged in `update`, depending on `concat`.

    Args:
        in_features:   width of the incoming node embeddings.
        out_features:  width of each attention head's output.
        edge_dim:      width of the edge attributes (only used when
                       `has_edge_attr` is true).
        heads:         number of attention heads.
        concat:        concatenate the heads (True) or average them (otherwise).
        dropout:       dropout probability applied to the attention coefficients.
        bias:          whether to learn an additive bias on the output.
        has_edge_attr: whether `W` is sized for node features concatenated with
                       edge attributes.
    """

    def __init__(self, in_features, out_features, edge_dim, heads, concat=False,
                 dropout=0.0, bias=True, has_edge_attr=True, **kwargs):
        super(GAT_Crystal, self).__init__(aggr='add', flow='target_to_source', **kwargs)
        self.in_features = in_features
        self.out_features = out_features
        self.heads = heads
        self.concat = concat
        self.dropout = nn.Dropout(p=dropout)
        self.neg_slope = 0.2
        # `prelu` is not used by the methods below; it is kept so the set of
        # registered parameters (and therefore saved checkpoints) is unchanged.
        self.prelu = nn.PReLU()
        self.bn1 = nn.BatchNorm1d(heads)
        if has_edge_attr:
            self.W = Parameter(torch.Tensor(in_features + edge_dim, heads * out_features))
        else:
            self.W = Parameter(torch.Tensor(in_features, heads * out_features))
        self.att = Parameter(torch.Tensor(1, heads, 2 * out_features))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_features))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise the weight/attention tensors and zero the bias."""
        glorot(self.W)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_attr=None):
        """Run one round of attention-weighted message passing.

        x:          [num_node, emb_len]
        edge_index: [2, num_edge]
        edge_attr:  [num_edge, emb_len] or None
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, edge_index_i, x_i, x_j, size_i, edge_attr):
        """Build the attention-weighted message for every edge.

        The argument names are resolved by `MessagePassing.propagate`, so they
        must not be renamed.

        edge_index_i: [num_edge]
        x_i, x_j:     [num_edge, emb_len]
        size_i:       num_node
        edge_attr:    [num_edge, emb_len] or None
        """
        if edge_attr is not None:
            x_i = torch.cat([x_i, edge_attr], dim=-1)
            x_j = torch.cat([x_j, edge_attr], dim=-1)

        x_i = F.softplus(torch.matmul(x_i, self.W))
        x_j = F.softplus(torch.matmul(x_j, self.W))

        x_i = x_i.view(-1, self.heads, self.out_features)  # [num_edge, num_head, emb_len]
        x_j = x_j.view(-1, self.heads, self.out_features)  # [num_edge, num_head, emb_len]

        # self.att is (1, heads, 2*out_features): one scoring vector per head.
        alpha = F.softplus((torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1))  # [num_edge, num_head]
        alpha = F.softplus(self.bn1(alpha))
        # Pass the node count by keyword: the third positional parameter of
        # torch_geometric.utils.softmax is `ptr` in newer releases, so a
        # positional `size_i` would be misinterpreted there.
        alpha = softmax(alpha, edge_index_i, num_nodes=size_i)  # [num_edge, num_head]
        alpha = self.dropout(alpha)

        return x_j * alpha.view(-1, self.heads, 1)  # [num_edge, num_head, emb_len]

    def update(self, aggr_out):
        """Merge the heads and add the bias.

        aggr_out: [num_node, num_head, emb_len]; returns [num_node, emb_len]
        when averaging, [num_node, num_head * emb_len] when concatenating.
        """
        if self.concat is True:
            aggr_out = aggr_out.view(-1, self.heads * self.out_features)
        else:
            aggr_out = aggr_out.mean(dim=1)
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out
120 |
class FractionalEncoder(nn.Module):
    """
    Sinusoidal "fractional encoding" of element fractions, inspired by the
    positional encoder discussed by Vaswani.
    https://arxiv.org/abs/1706.03762

    A lookup table with `resolution` rows and `d_model // 2` columns is built
    once and stored as the buffer `pe`; `forward` quantises each fraction to a
    row index and returns the matching rows.
    """

    def __init__(self,
                 d_model,
                 resolution=100,
                 log10=False,
                 compute_device=None):
        super().__init__()
        self.d_model = d_model // 2
        self.resolution = resolution
        self.log10 = log10
        self.compute_device = compute_device

        n_rows, n_cols = self.resolution, self.d_model
        # Row positions 0 .. resolution-1 as a column vector: (resolution, 1)
        positions = torch.linspace(0, n_rows - 1, n_rows,
                                   requires_grad=False).view(n_rows, 1)
        # Column indices 0 .. d_model-1 as a row vector: (1, d_model);
        # broadcasting against `positions` fills the whole table.
        columns = torch.linspace(0, n_cols - 1, n_cols,
                                 requires_grad=False).view(1, n_cols)

        table = torch.zeros(n_rows, n_cols)  # (resolution, d_model)
        even_scale = torch.pow(50, 2 * columns[:, 0::2] / n_cols)
        odd_scale = torch.pow(50, 2 * columns[:, 1::2] / n_cols)
        table[:, 0::2] = torch.sin(positions / even_scale)
        table[:, 1::2] = torch.cos(positions / odd_scale)
        self.register_buffer('pe', table)

    def forward(self, x):
        """Map fractions `x` to rows of the table; the output gains a trailing d_model axis."""
        x = x.clone()
        if self.log10:
            # Log-style rescaling of the fraction, capped at 1.
            x = 0.0025 * (torch.log2(x)) ** 2
            x[x > 1] = 1
        # Anything below one quantisation step is lifted to the first row.
        smallest = 1 / self.resolution
        x[x < smallest] = smallest
        row_idx = torch.round(x * (self.resolution)).to(dtype=torch.long) - 1
        return self.pe[row_idx]
162 |
class GNN(torch.nn.Module):
    """Graph encoder: stacked GAT_Crystal layers, composition attention and mean pooling.

    Args:
        heads:       number of attention heads in every GAT_Crystal layer.
        neurons:     width of the node/edge embeddings.
        nl:          number of attention layers.
        concat_comp: when true, fractional encodings of the composition vector
                     are added to the pooled crystal embedding.
    """

    def __init__(self, heads, neurons=64, nl=3, concat_comp=False):
        super(GNN, self).__init__()

        self.n_heads = heads
        self.number_layers = nl
        self.concat_comp = concat_comp

        n_h, n_hX2 = neurons, neurons * 2
        self.neurons = neurons
        self.neg_slope = 0.2

        # Input widths: 92 node features, 41 edge features, 103 composition entries.
        self.embed_n = Linear(92, n_h)
        self.embed_e = Linear(41, n_h)
        self.embed_comp = Linear(103, n_h)

        self.node_att = nn.ModuleList([GAT_Crystal(n_h, n_h, n_h, self.n_heads) for i in range(nl)])
        self.batch_norm = nn.ModuleList([nn.BatchNorm1d(n_h) for i in range(nl)])

        self.comp_atten = COMPOSITION_Attention(n_h)

        self.emb_scaler = nn.parameter.Parameter(torch.tensor([1.]))
        self.pos_scaler = nn.parameter.Parameter(torch.tensor([1.]))
        self.pos_scaler_log = nn.parameter.Parameter(torch.tensor([1.]))
        self.pe = FractionalEncoder(n_h, resolution=5000, log10=False)
        self.ple = FractionalEncoder(n_h, resolution=5000, log10=True)
        self.pe_linear = nn.Linear(103, 1)
        self.ple_linear = nn.Linear(103, 1)

        if self.concat_comp:
            reg_h = n_hX2
        else:
            reg_h = n_h

        # Not used by `forward`; kept so registered parameters (and therefore
        # saved checkpoints) stay unchanged.
        self.linear1 = nn.Linear(reg_h, reg_h)
        self.linear2 = nn.Linear(reg_h, reg_h)

    def forward(self, data):
        """Encode a batch of crystal graphs into one embedding per crystal ([bs, emb_len]).

        Reads `x`, `edge_index`, `edge_attr`, `batch` and `global_feature`
        from `data`.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch, global_feat = data.batch, data.global_feature

        x = self.embed_n(x)  # [num_atom, emb_len]
        edge_attr = F.leaky_relu(self.embed_e(edge_attr), self.neg_slope)  # [num_edges, emb_len]

        for attention, norm in zip(self.node_att, self.batch_norm):
            x = attention(x, edge_index, edge_attr)  # [num_atom, emb_len]
            x = norm(x)
            x = F.softplus(x)

        ag = self.comp_atten(x, batch, global_feat)  # [num_atom, 1]
        x = x * ag  # [num_atom, emb_len]

        # Crystal-level feature aggregation.
        y = global_mean_pool(x, batch)  # [bs, emb_len]

        if self.concat_comp:
            emb_len = y.shape[1]
            half = emb_len // 2
            table_shape = [global_feat.shape[0], global_feat.shape[1], emb_len]
            # Allocate next to the activations rather than on an import-time
            # global device, so the module follows wherever the model/data live.
            pe = torch.zeros(table_shape, device=y.device)
            ple = torch.zeros(table_shape, device=y.device)
            pe_scaler = 2 ** (1 - self.pos_scaler) ** 2
            ple_scaler = 2 ** (1 - self.pos_scaler_log) ** 2
            # Linear encoding fills the first half of the channels, the
            # log-scaled encoding the second half.
            pe[:, :, :half] = self.pe(global_feat)
            ple[:, :, half:] = self.ple(global_feat)
            # Collapse the 103 element slots to one value per channel. Only the
            # trailing singleton axis is squeezed so a batch of one keeps its
            # batch dimension.
            pe = self.pe_linear(torch.transpose(pe, 1, 2)).squeeze(-1) * pe_scaler
            ple = self.ple_linear(torch.transpose(ple, 1, 2)).squeeze(-1) * ple_scaler
            y = y + pe + ple

        return y
234 |
class Mat2Spec(nn.Module):
    """Graph encoder plus paired label/feature probabilistic encoders with shared decoders.

    `forward` encodes the input graphs with `GNN`, then runs two branches that
    both decode through the same `fd_x1`/`fd_x2` layers:
      * label branch   (`label_forward`): uses the target `data.y`;
      * feature branch (`feat_forward`): uses only the graph feature.
    Both embeddings are mapped to label space with the weight of `self.W` and
    projected with `emb_proj`.
    """

    def __init__(self, args, NORMALIZER):
        """Build all layers; hyper-parameters are read from attributes of `args`."""
        super(Mat2Spec, self).__init__()
        n_heads = args.num_heads
        number_neurons = args.num_neurons
        number_layers = args.num_layers
        concat_comp = args.concat_comp
        self.graph_encoder = GNN(n_heads, neurons=number_neurons, nl=number_layers, concat_comp=concat_comp)

        self.loss_type = args.Mat2Spec_loss_type
        self.NORMALIZER = NORMALIZER
        self.input_dim = args.Mat2Spec_input_dim
        self.latent_dim = args.Mat2Spec_latent_dim
        self.emb_size = args.Mat2Spec_emb_size
        self.label_dim = args.Mat2Spec_label_dim
        self.scale_coeff = args.Mat2Spec_scale_coeff
        self.keep_prob = args.Mat2Spec_keep_prob
        self.K = args.Mat2Spec_K
        self.args = args

        # feature encoder: K Gaussian components (mu, logvar) plus K mixing logits
        self.fx1 = nn.Linear(self.input_dim, 256)
        self.fx2 = nn.Linear(256, 512)
        self.fx3 = nn.Linear(512, 256)
        self.fx_mu = nn.Linear(256, self.latent_dim*self.K)
        self.fx_logvar = nn.Linear(256, self.latent_dim*self.K)
        self.fx_mix_coeff = nn.Linear(256, self.K)

        # maps a label vector to one mixing logit per label dimension
        self.fe_mix_coeff = nn.Sequential(
            nn.Linear(self.label_dim, 128),
            nn.ReLU(),
            nn.Linear(128, self.label_dim)
        )

        # decoder; its input is the graph feature concatenated with a latent sample
        self.fd_x1 = nn.Linear(self.input_dim + self.latent_dim, 512)
        self.fd_x2 = torch.nn.Sequential(
            nn.Linear(512, self.emb_size)
        )
        self.feat_mp_mu = nn.Linear(self.emb_size, self.label_dim)

        # label layers
        self.fe0 = nn.Linear(self.label_dim, self.emb_size)
        self.fe1 = nn.Linear(self.label_dim, 512)
        self.fe2 = nn.Linear(512, 256)
        self.fe_mu = nn.Linear(256, self.latent_dim)
        self.fe_logvar = nn.Linear(256, self.latent_dim)

        # the label decoder is the very same module objects as the feature decoder (shared weights)
        self.fd1 = self.fd_x1
        self.fd2 = self.fd_x2
        #self.fd = self.fd_x
        self.label_mp_mu = self.feat_mp_mu

        self.bias = nn.Parameter(torch.zeros(self.label_dim))

        assert id(self.fd_x1) == id(self.fd1)
        assert id(self.fd_x2) == id(self.fd2)

        # NOTE(review): nn.Dropout's `p` is the probability of zeroing an element, so the value
        # named `keep_prob` is used here as a drop probability -- confirm the intended meaning.
        self.dropout = nn.Dropout(p=self.keep_prob)
        self.emb_proj = nn.Linear(args.Mat2Spec_emb_size, 1024)
        self.W = nn.Linear(args.Mat2Spec_label_dim, args.Mat2Spec_emb_size) # linear transformation for label

    def label_encode(self, x):
        """Encode `x` into a Gaussian (mu, logvar); returns a dict with 'fe_mu' and 'fe_logvar'.

        The shape comments below hold for the call made in `label_forward`,
        where `x` is the [label_dim, label_dim] identity matrix.
        """
        #h0 = self.dropout(F.relu(self.fe0(x))) # [label_dim, emb_size]
        h1 = self.dropout(F.relu(self.fe1(x))) # [label_dim, 512]
        h2 = self.dropout(F.relu(self.fe2(h1))) # [label_dim, 256]
        mu = self.fe_mu(h2) * self.scale_coeff # [label_dim, latent_dim]
        logvar = self.fe_logvar(h2) * self.scale_coeff # [label_dim, latent_dim]

        fe_output = {
            'fe_mu': mu,
            'fe_logvar': logvar
        }
        return fe_output

    def feat_encode(self, x):
        """Encode graph features; returns a dict with 'fx_mu', 'fx_logvar' and 'fx_mix_coeff'.

        When K > 1, mu and logvar are reshaped to [bs, K, latent_dim];
        'fx_mix_coeff' holds the raw (pre-softmax) mixing logits.
        """
        h1 = self.dropout(F.relu(self.fx1(x)))
        h2 = self.dropout(F.relu(self.fx2(h1)))
        h3 = self.dropout(F.relu(self.fx3(h2)))
        mu = self.fx_mu(h3) * self.scale_coeff # [bs, latent_dim]
        logvar = self.fx_logvar(h3) * self.scale_coeff
        mix_coeff = self.fx_mix_coeff(h3) # [bs, K]

        if self.K > 1:
            mu = mu.view(x.shape[0], self.K, self.args.Mat2Spec_latent_dim) # [bs, K, latent_dim]
            logvar = logvar.view(x.shape[0], self.K, self.args.Mat2Spec_latent_dim) # [bs, K, latent_dim]

        fx_output = {
            'fx_mu': mu,
            'fx_logvar': logvar,
            'fx_mix_coeff': mix_coeff
        }
        return fx_output

    def label_reparameterize(self, mu, logvar):
        """Reparameterisation trick: return mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def feat_reparameterize(self, mu, logvar, coeff=1.0):
        """Same sampling as `label_reparameterize`; `coeff` is accepted but not used."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def label_decode(self, z):
        """Decode through fd1/fd2 (the same modules as fd_x1/fd_x2)."""
        d1 = F.relu(self.fd1(z))
        d2 = F.leaky_relu(self.fd2(d1))
        return d2

    def feat_decode(self, z):
        """Decode through fd_x1/fd_x2."""
        d1 = F.relu(self.fd_x1(z))
        d2 = F.leaky_relu(self.fd_x2(d1))
        return d2

    def label_forward(self, x, feat): # x is label
        """Label branch: returns the `label_encode` dict extended with 'fe_mix_coeff' and 'label_emb'."""
        n_label = x.shape[1] # label_dim
        all_labels = torch.eye(n_label).to(x.device) # [label_dim, label_dim]
        fe_output = self.label_encode(all_labels) # map each label to a Gaussian mixture.
        mu = fe_output['fe_mu']
        logvar = fe_output['fe_logvar']
        fe_output['fe_mix_coeff'] = self.fe_mix_coeff(x)
        mix_coeff = F.softmax(fe_output['fe_mix_coeff'], dim=-1)

        # sampling is switched by the `args.train` flag, not by the module's train()/eval() mode
        if self.args.train:
            z = self.label_reparameterize(mu, logvar) # [label_dim, latent_dim]
        else:
            z = mu
        # mix the per-label latents with the softmax weights
        z = torch.matmul(mix_coeff, z)

        label_emb = self.label_decode(torch.cat((feat, z), 1))
        fe_output['label_emb'] = label_emb
        return fe_output

    def feat_forward(self, x):
        """Feature branch: returns the `feat_encode` dict extended with 'feat_emb'."""
        fx_output = self.feat_encode(x)
        mu = fx_output['fx_mu'] # [bs, latent_dim]
        logvar = fx_output['fx_logvar'] # [bs, latent_dim]

        if self.args.train:
            z = self.feat_reparameterize(mu, logvar)
        else:
            z = mu
        if self.K > 1:
            # weighted sum of the K component latents
            mix_coeff = fx_output['fx_mix_coeff'] # [bs, K]
            mix_coeff = F.softmax(mix_coeff, dim=-1)
            mix_coeff = mix_coeff.unsqueeze(-1).expand_as(z)
            z = z * mix_coeff
            z = torch.sum(z, dim=1) # [bs, latent_dim]

        feat_emb = self.feat_decode(torch.cat((x, z), 1)) # [bs, emb_size]
        fx_output['feat_emb'] = feat_emb
        return fx_output

    def forward(self, data):
        """Run both branches on a batch.

        Returns one dict merging the label- and feature-branch outputs plus
        'label_out', 'feat_out', 'label_proj' and 'feat_proj'.
        """
        label = data.y
        feature = self.graph_encoder(data)

        fe_output = self.label_forward(label, feature)
        label_emb = fe_output['label_emb'] # [bs, emb_size]
        fx_output = self.feat_forward(feature)
        feat_emb = fx_output['feat_emb'] # [bs, emb_size]
        W = self.W.weight # [emb_size, label_dim]
        label_out = torch.matmul(label_emb, W) # [bs, emb_size] * [emb_size, label_dim] = [bs, label_dim]
        feat_out = torch.matmul(feat_emb, W) # [bs, label_dim]

        label_proj = self.emb_proj(label_emb)
        feat_proj = self.emb_proj(feat_emb)
        fe_output.update(fx_output)
        output = fe_output

        if self.args.label_scaling == 'normalized_max':
            # clamp to non-negative, then divide every row by its maximum
            label_out = F.relu(label_out)
            feat_out = F.relu(feat_out)
            maxima, _ = torch.max(label_out, dim=1)
            label_out = label_out.div(maxima.unsqueeze(1)+1e-8)
            maxima, _ = torch.max(feat_out, dim=1)
            feat_out = feat_out.div(maxima.unsqueeze(1)+1e-8)

        output['label_out'] = label_out
        output['feat_out'] = feat_out
        output['label_proj'] = label_proj
        output['feat_proj'] = feat_proj
        return output
416 |
def kl(fx_mu, fe_mu, fx_logvar, fe_logvar):
    """KL divergence between diagonal Gaussians, summed over the last axis.

    The expression is KL( N(fe_mu, exp(fe_logvar)) || N(fx_mu, exp(fx_logvar)) ),
    with 1e-8 added to the `fx` variance in the denominator for stability.
    """
    log_var_ratio = fx_logvar - fe_logvar
    var_ratio = torch.exp(fe_logvar - fx_logvar)
    scaled_sq_diff = (fx_mu - fe_mu) ** 2 / (torch.exp(fx_logvar) + 1e-8)
    per_dim = log_var_ratio - 1 + var_ratio + scaled_sq_diff
    return 0.5 * torch.sum(per_dim, dim=-1)
422 |
def compute_c_loss(BX, BY, tau=1):
    """Contrastive loss between paired rows of `BX` and `BY`.

    Rows are L2-normalised; for sample i the positive is the similarity of
    BX[i] with BY[i] and the negatives are its similarities with every other
    row of BY:

        loss = -mean_i( s_ii / tau - log( sum_{j != i} exp(s_ij / tau) ) )

    This is the same quantity as -mean(log(e_ii / (sum_j e_ij - e_ii))) with
    e = exp(s / tau), but evaluated in log space so a small `tau` cannot
    overflow `exp`.

    Args:
        BX, BY: [bs, dim] tensors whose rows are paired.
        tau:    temperature.

    Returns:
        Scalar tensor. With fewer than two samples there are no negatives, so
        a zero that is still attached to the graph is returned instead of the
        non-finite value the ratio form would give.
    """
    BX = F.normalize(BX, dim=1)
    BY = F.normalize(BY, dim=1)

    n = BX.shape[0]
    if n < 2:
        # No negatives: the denominator of the ratio form would be exactly 0.
        return (BX * BY).sum() * 0.0

    logits = torch.matmul(BX, torch.transpose(BY, 0, 1)) / tau  # [bs, bs]
    positives = torch.diagonal(logits, 0)                        # [bs]
    # Exclude the positive pair from the denominator by masking the diagonal.
    diagonal = torch.eye(n, dtype=torch.bool, device=logits.device)
    negatives = torch.logsumexp(logits.masked_fill(diagonal, float('-inf')), dim=-1)  # [bs]
    return -torch.mean(positives - negatives)
433 |
def compute_loss(input_label, output, NORMALIZER, args):
    """Combine reconstruction, KL-alignment and contrastive terms into the training loss.

    Args:
        input_label: target tensor compared against both predictions.
        output:      dict with the keys read below ('label_out', 'feat_out', 'fe_mu', ...).
        NORMALIZER:  accepted but not used in this function.
        args:        reads `Mat2Spec_K`, `Mat2Spec_label_dim`, `label_scaling`,
                     `Mat2Spec_loss_type`, `ablation_LE` and `ablation_CL`.

    Returns:
        (total_loss, nll_loss, nll_loss_x, kl_loss, c_loss, pred_e, pred_x), where
        total_loss = (nll_loss + nll_loss_x) * c1 + kl_loss * c2 + c_loss * c3.
    """
    fe_out, fe_mu, fe_logvar, label_emb, label_proj = output['label_out'], output['fe_mu'], output['fe_logvar'], output['label_emb'], output['label_proj']
    fx_out, fx_mu, fx_logvar, feat_emb, feat_proj = output['feat_out'], output['fx_mu'], output['fx_logvar'], output['feat_emb'], output['feat_proj']

    fx_mix_coeff = output['fx_mix_coeff'] # [bs, K]
    fe_mix_coeff = output['fe_mix_coeff']
    fx_mix_coeff = F.softmax(fx_mix_coeff, dim=-1)
    fe_mix_coeff = F.softmax(fe_mix_coeff, dim=-1)
    # Tile both weight vectors to length K * label_dim and multiply them element-wise.
    # NOTE(review): `repeat` tiles, so position j pairs feature component (j % K) with label
    # component (j % label_dim); that enumerates every (k, label) pair only when K and label_dim
    # are coprime -- confirm this pairing is intended. The mu/logvar stacks below are tiled the same
    # way (assumes fx_mu/fx_logvar are [bs, K, latent_dim] -- TODO confirm for K == 1).
    fe_mix_coeff = fe_mix_coeff.repeat(1, args.Mat2Spec_K)
    fx_mix_coeff = fx_mix_coeff.repeat(1, args.Mat2Spec_label_dim)
    mix_coeff = fe_mix_coeff * fx_mix_coeff
    fx_mu = fx_mu.repeat(1, args.Mat2Spec_label_dim, 1)
    fx_logvar = fx_logvar.repeat(1, args.Mat2Spec_label_dim, 1)
    # broadcast the label-side parameters over the batch before tiling them
    fe_mu = fe_mu.squeeze(0).expand(fx_mu.shape[0], fe_mu.shape[0], fe_mu.shape[1])
    fe_logvar = fe_logvar.squeeze(0).expand(fx_mu.shape[0], fe_logvar.shape[0], fe_logvar.shape[1])
    fe_mu = fe_mu.repeat(1, args.Mat2Spec_K, 1)
    fe_logvar = fe_logvar.repeat(1, args.Mat2Spec_K, 1)
    # symmetrised KL: average of both directions, weighted by the mixture weights
    kl_all = kl(fx_mu, fe_mu, fx_logvar, fe_logvar)
    kl_all_inv = kl(fe_mu, fx_mu, fe_logvar, fx_logvar)
    kl_loss = torch.mean(torch.sum(mix_coeff * (0.5*kl_all + 0.5*kl_all_inv), dim=-1))
    #c_loss = torch.mean(-1 * F.cosine_similarity(label_proj, feat_proj))
    c_loss = compute_c_loss(label_proj, feat_proj)

    if args.label_scaling == 'normalized_sum':
        assert args.Mat2Spec_loss_type == 'KL' or args.Mat2Spec_loss_type == 'WD'
        #input_label_normalize = F.softmax(torch.log(input_label+1e-6), dim=1)
        # targets are rescaled so every row sums to (almost) 1; predictions go through a softmax
        input_label_normalize = input_label / (torch.sum(input_label, dim=1, keepdim=True)+1e-8)
        pred_e = F.softmax(fe_out, dim=1)
        pred_x = F.softmax(fx_out, dim=1)
        #nll_loss = kl_loss_fn(torch.log(pred_e+1e-8), input_label_normalize)
        #nll_loss_x = kl_loss_fn(torch.log(pred_x+1e-8), input_label_normalize)
        P = input_label_normalize
        Q_e = pred_e
        Q_x = pred_x
        # weights of the reconstruction, KL-alignment and contrastive terms; the ablation flags zero c2 / c3
        c1, c2, c3 = 1, 1.1, 0.1
        if args.ablation_LE:
            c2 = 0.0
        if args.ablation_CL:
            c3 = 0.0

        if args.Mat2Spec_loss_type == 'KL':
            # the trailing backslash joins each statement with the comment-only line that follows it
            nll_loss = torch.mean(torch.sum(P*(torch.log(P+1e-8)-torch.log(Q_e+1e-8)),dim=1)) \
            #+ torch.mean(torch.sum(Q_e*(torch.log(Q_e+1e-8)-torch.log(P+1e-8)),dim=1))
            nll_loss_x = torch.mean(torch.sum(P*(torch.log(P+1e-8)-torch.log(Q_x+1e-8)),dim=1)) \
            #+ torch.mean(torch.sum(Q_x*(torch.log(Q_x+1e-8)-torch.log(P+1e-8)),dim=1))
        elif args.Mat2Spec_loss_type == 'WD':
            #nll_loss, _, _ = sinkhorn(Q_e, P)
            #nll_loss_x, _, _ = sinkhorn(Q_x, P)
            nll_loss = torch_wasserstein_loss(Q_e, P)
            nll_loss_x = torch_wasserstein_loss(Q_x, P)
        total_loss = (nll_loss + nll_loss_x) * c1 + kl_loss * c2 + c_loss * c3

        return total_loss, nll_loss, nll_loss_x, kl_loss, c_loss, pred_e, pred_x

    else: # standardized or normalized_max
        assert args.Mat2Spec_loss_type == 'MAE' or args.Mat2Spec_loss_type == 'MSE'
        # raw outputs are compared with the labels directly
        pred_e = fe_out
        pred_x = fx_out
        c1, c2, c3 = 1, 1.1, 0.1
        if args.ablation_LE:
            c2 = 0.0
        if args.ablation_CL:
            c3 = 0.0

        if args.Mat2Spec_loss_type == 'MAE':
            nll_loss = torch.mean(torch.abs(pred_e-input_label))
            nll_loss_x = torch.mean(torch.abs(pred_x-input_label))
        elif args.Mat2Spec_loss_type == 'MSE':
            nll_loss = torch.mean((pred_e-input_label)**2)
            nll_loss_x = torch.mean((pred_x-input_label)**2)
        total_loss = (nll_loss + nll_loss_x) * c1 + kl_loss * c2 + c_loss * c3

        return total_loss, nll_loss, nll_loss_x, kl_loss, c_loss, pred_e, pred_x
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/Mat2Spec/SinkhornDistance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
# Compute device chosen once at import: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
5 |
6 | # Adapted from https://github.com/gpeyre/SinkhornAutoDiff
# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    r"""
    Entropy-regularised optimal-transport cost between two batches of
    histograms defined on the same 1-D grid of `N` bins.

    The ground cost between bin `i` of the first histogram and bin `j` of the
    second is :math:`|i - j|^p + 1` (see `_cost_matrix`).

    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
    Shape:
        - Input: `mu` and `nu`, both :math:`(B, N)`
        - Output: tuple `(cost, pi, C)` where `cost` is :math:`(B)` or
          :math:`()` depending on `reduction`, `pi` is the :math:`(B, N, N)`
          transport plan and `C` the :math:`(B, N, N)` cost matrix.
    """

    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, mu, nu):
        # Build the cost on the inputs' own device so the module works wherever
        # the histograms live, independent of any import-time global device.
        C = self._cost_matrix(mu.shape[0], mu.shape[1]).to(mu.device)  # [bs, N, N]

        # Dual potentials.
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        # Stop early once the mean absolute change of `u` falls below this.
        thresh = 1e-2

        # Sinkhorn iterations in the log domain.
        for _ in range(self.max_iter):
            u_prev = u
            u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u_prev).abs().sum(-1).mean()
            if err.item() < thresh:
                break

        # Transport plan pi = diag(a) * K * diag(b)
        pi = torch.exp(self.M(C, u, v))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost, pi, C

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates.

        :math:`M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon`
        """
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps  # [bs, N, N]

    @staticmethod
    def _cost_matrix(batch_size, n, p=2):
        """Return the `[batch_size, n, n]` float matrix with entries `|i - j|**p + 1`.

        The two point sets sit at `(i, 0)` and `(j, 1)`; the coordinate-wise
        `p`-th powers are summed, so the constant second coordinate always
        contributes `|0 - 1|**p == 1`.
        """
        positions = torch.arange(n, dtype=torch.float)
        pairwise = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs() ** p + 1.0
        return pairwise.unsqueeze(0).repeat(batch_size, 1, 1)

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1
--------------------------------------------------------------------------------
/Mat2Spec_Codes/Mat2Spec/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/Mat2Spec/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import functools
4 | import torch
5 | import pickle
6 | from torch.utils.data import Dataset
7 | from torch_geometric.data import Dataset as torch_Dataset
8 | from torch_geometric.data import Data, DataLoader as torch_DataLoader
9 | import sys, json, os
10 | from pymatgen.core.structure import Structure
11 | from sklearn.cluster import KMeans
12 | from sklearn.cluster import SpectralClustering as SPCL
13 | import warnings
14 | from Mat2Spec.utils import *
15 | from os import path
16 |
17 | # Note: this file for data loading is modified from https://github.com/superlouis/GATGNN/blob/master/gatgnn/data.py
18 |
19 | # gpu_id = 0
20 | # device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
# `set_device` is not defined in this file; presumably it is provided by the
# `Mat2Spec.utils` star import -- TODO confirm.
device = set_device()

# Silence all warnings unless the interpreter was started with explicit warning options.
if not sys.warnoptions:
    warnings.simplefilter("ignore")
25 |
def mkdirs(path):
    """Create directory `path` (including missing parents) if nothing exists there yet.

    An existing path of any kind is left untouched. `exist_ok=True` keeps the
    call from failing when another process creates the directory between the
    existence check and the creation.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
29 |
class ELEM_Encoder:
    """Convert between element symbols and a fixed 103-entry composition vector."""

    def __init__(self):
        # Element symbols ordered by atomic number (H = 1 ... Lr = 103).
        self.elements = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl',
                         'Ar', 'K',
                         'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br',
                         'Kr', 'Rb',
                         'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I',
                         'Xe', 'Cs',
                         'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu',
                         'Hf', 'Ta',
                         'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac',
                         'Th', 'Pa',
                         'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr']  # 103
        self.e_arr = np.array(self.elements)

    def encode(self, composition_dict):
        """Turn an element -> amount mapping into a `[1, 103]` float tensor of fractions.

        Keys are converted with `str()` and looked up in `self.elements`; an
        unknown symbol raises `ValueError`. Each amount is divided by the sum
        of all amounts.
        """
        answer = [0] * len(self.elements)
        total = sum(composition_dict.values())

        for element, count in composition_dict.items():
            ratio = count / total
            answer[self.elements.index(str(element))] = ratio
        return torch.tensor(answer).float().view(1, -1)

    def decode_pymatgen_num(self, tensor_idx):
        """Map 1-based element numbers (a tensor) to their symbols.

        The method was previously declared without `self`, so it could not be
        called on an instance.
        """
        idx = (tensor_idx - 1).cpu().tolist()
        return self.e_arr[idx]
62 |
63 |
class DATA_normalizer:
    """Column-wise standardisation helper plus a few simple target transforms.

    `mean` and `std` are computed over dim 0 of `array` at construction time
    and reused by `norm` / `denorm`.
    """

    def __init__(self, array):
        tensor = torch.tensor(array)
        self.mean = torch.mean(tensor, dim=0).float()
        self.std = torch.std(tensor, dim=0).float()

    def reg(self, x):
        """Identity transform (cast to float)."""
        return x.float()

    def log10(self, x):
        """Base-10 logarithm of `x`."""
        return torch.log10(x)

    def delog10(self, x):
        """Inverse of `log10`: returns 10 ** x (previously this computed 10 * x)."""
        return 10 ** x

    def norm(self, x):
        """Standardise `x` with the stored mean and standard deviation."""
        return (x - self.mean) / self.std

    def denorm(self, x):
        """Inverse of `norm`."""
        return x * self.std + self.mean
84 |
85 |
class METRICS:
    """Accumulate two loss measures per phase and keep their per-epoch averages.

    Measure 1 is computed with `torch_criterion`, measure 2 with `torch_func`.
    Running sums are kept as tensors on `device`; `reset_parameters` turns
    them into averages (dividing by the phase counter), stores the result in
    the history lists and clears the running state.
    """

    def __init__(self, c_property, epoch, torch_criterion, torch_func, device):
        self.c_property = c_property
        self.criterion = torch_criterion
        self.eval_func = torch_func
        self.dv = device

        # Running sums for the current epoch.
        self.training_measure1 = torch.tensor(0.0).to(device)
        self.training_measure2 = torch.tensor(0.0).to(device)
        self.valid_measure1 = torch.tensor(0.0).to(device)
        self.valid_measure2 = torch.tensor(0.0).to(device)

        # Divisors for the averages; this class only resets them.
        self.training_counter = 0
        self.valid_counter = 0

        # Per-epoch history.
        self.training_loss1 = []
        self.training_loss2 = []
        self.valid_loss1 = []
        self.valid_loss2 = []
        self.duration = []
        self.dataframe = self.to_frame()

    def __str__(self):
        return self.to_frame().to_string()

    def to_frame(self):
        """Return the history as a DataFrame with one row per completed epoch."""
        rows = list(zip(self.training_loss1, self.training_loss2,
                        self.valid_loss1, self.valid_loss2, self.duration))
        return pd.DataFrame(
            rows,
            columns=['training_1', 'training_2', 'valid_1', 'valid_2', 'time'])

    def set_label(self, which_phase, graph_data):
        """Return the label tensor (`graph_data.y`); `which_phase` is not used."""
        return graph_data.y

    def save_time(self, e_duration):
        """Record the duration of one epoch."""
        self.duration.append(e_duration)

    def __call__(self, which_phase, tensor_pred, tensor_true, measure=1):
        """Compute one loss, add it to the matching running sum and return it.

        `measure == 1` selects `criterion`, any other value `eval_func`.
        `which_phase` must be 'training' or 'validation'.
        """
        primary = measure == 1
        scorer = self.criterion if primary else self.eval_func

        if which_phase == 'training':
            loss = scorer(tensor_pred, tensor_true)
            if primary:
                self.training_measure1 += loss
            else:
                self.training_measure2 += loss
        elif which_phase == 'validation':
            loss = scorer(tensor_pred, tensor_true)
            if primary:
                self.valid_measure1 += loss
            else:
                self.valid_measure2 += loss
        return loss

    def reset_parameters(self, which_phase, epoch):
        """Store the phase averages in the history and clear the running state.

        Any `which_phase` other than 'training' is treated as validation.
        """
        if which_phase == 'training':
            avg1 = self.training_measure1 / (self.training_counter)
            avg2 = self.training_measure2 / (self.training_counter)
            self.training_loss1.append(avg1.item())
            self.training_loss2.append(avg2.item())

            self.training_measure1 = torch.tensor(0.0).to(self.dv)
            self.training_measure2 = torch.tensor(0.0).to(self.dv)
            self.training_counter = 0
            return

        avg1 = self.valid_measure1 / (self.valid_counter)
        avg2 = self.valid_measure2 / (self.valid_counter)
        self.valid_loss1.append(avg1.item())
        self.valid_loss2.append(avg2.item())

        self.valid_measure1 = torch.tensor(0.0).to(self.dv)
        self.valid_measure2 = torch.tensor(0.0).to(self.dv)
        self.valid_counter = 0

    def save_info(self):
        """Pickle this object to 'MODELS/metrics_.pickle'."""
        with open('MODELS/metrics_.pickle', 'wb') as metrics_file:
            pickle.dump(self, metrics_file)
166 |
167 |
class GaussianDistance(object):
    """Expand distances on a grid of Gaussian basis functions.

    The grid centers are ``np.arange(dmin, dmax + step, step)`` and every
    Gaussian uses the same width ``var`` (defaults to ``step``).
    """

    def __init__(self, dmin, dmax, step, var=None):
        assert dmin < dmax
        assert dmax - dmin > step
        # Roughly int((dmax - dmin) / step) + 1 centers (per the original note).
        self.filter = np.arange(dmin, dmax + step, step)
        if var is None:
            var = step
        self.var = var

    def expand(self, distances):
        """Return ``exp(-(d - center)**2 / var**2)`` for every distance and center.

        Args:
            distances: array-like of any shape; nested lists are accepted and
                converted with ``np.asarray``.

        Returns:
            ndarray with one extra trailing axis of length ``len(self.filter)``.
        """
        distances = np.asarray(distances)
        # New trailing axis broadcasts each distance against all grid centers.
        offsets = distances[..., np.newaxis] - self.filter
        return np.exp(-offsets ** 2 / self.var ** 2)
183 |
184 |
class AtomInitializer(object):
    """Lookup table from an atom type to its embedding.

    ``_embedding`` maps atom type -> embedding; ``_decodedict`` is the reverse
    mapping and is built on demand.
    """

    def __init__(self, atom_types):
        self._embedding = {}
        self.atom_types = set(atom_types)

    def _reverse_table(self):
        # Swap keys and values of the embedding table.
        return {value: key for key, value in self._embedding.items()}

    def get_atom_fea(self, atom_type):
        """Return the embedding registered for ``atom_type``."""
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        """Replace the embedding table and rebuild the derived members."""
        self._embedding = state_dict
        self.atom_types = set(state_dict.keys())
        self._decodedict = self._reverse_table()

    def state_dict(self):
        """Return the embedding table itself (not a copy)."""
        return self._embedding

    def decode(self, idx):
        """Return the atom type whose embedding equals ``idx``."""
        try:
            table = self._decodedict
        except AttributeError:
            # First use without load_state_dict: build the reverse table now.
            table = self._decodedict = self._reverse_table()
        return table[idx]
208 |
209 |
class AtomCustomJSONInitializer(AtomInitializer):
    """Atom initializer whose embeddings are read from a JSON file.

    The file is expected to hold one object whose keys are converted with
    ``int`` and whose values are turned into float arrays.
    """

    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as f:
            raw_table = json.load(f)
        # JSON object keys are strings; convert them to integers.
        table = {int(number): vector for number, vector in raw_table.items()}
        super().__init__(set(table.keys()))
        for number, vector in table.items():
            self._embedding[number] = np.array(vector, dtype=float)
220 |
221 |
class CIF_Lister(Dataset):
    """View on a subset of ``full_dataset`` that yields graph ``Data`` objects.

    Args:
        crystals_ids: indices into ``full_dataset`` that make up this subset.
        full_dataset: indexable dataset; each item is unpacked as
            ``((node_features, edge_features, edge_index), groups, enc_compo,
            coordinates, target, ...)``.
        df: optional DataFrame; when given, its first column at the rows
            ``crystals_ids`` is stored as ``material_ids``. When omitted,
            ``material_ids`` is ``None`` (this used to raise on ``None.iloc``).
    """

    def __init__(self, crystals_ids, full_dataset, df=None):
        self.crystals_ids = crystals_ids
        self.full_dataset = full_dataset
        if df is None:
            self.material_ids = None
        else:
            self.material_ids = df.iloc[crystals_ids].values[:, 0].squeeze()  # MP-xxx

    def __len__(self):
        return len(self.crystals_ids)

    def extract_ids(self, original_dataset):
        """Return the rows of ``original_dataset`` that belong to this subset."""
        return original_dataset.iloc[self.crystals_ids]

    def __getitem__(self, idx):
        i = self.crystals_ids[idx]
        material = self.full_dataset[i]

        n_features, e_features, a_matrix = material[0][0], material[0][1], material[0][2]
        # Flatten the edge features to [n_edges, n_edge_features]; the feature
        # width is read from the tensor instead of being hard-coded to 41.
        e_features = e_features.view(-1, e_features.shape[-1])

        groups = material[1]
        enc_compo = material[2]  # normalize feat
        coordinates = material[3]
        y = material[4]  # target

        return Data(x=n_features, y=y, edge_attr=e_features, edge_index=a_matrix,
                    global_feature=enc_compo, cluster=groups,
                    num_atoms=torch.tensor([len(n_features)]).float(), coords=coordinates,
                    the_idx=torch.tensor([float(i)]))
254 |
class CIF_Dataset(Dataset):
    """Featurize crystal structures into graph inputs, with an on-disk cache.

    Each item is the tuple
    ``((atom_fea, nbr_fea, nbr_fea_idx), groups, enc_compo, coordinates,
    target, cif_id, atom_id)``.

    ``args`` must provide ``data_src`` (one of the keys of ``_PROCESSED_DIRS``)
    and ``use_catached_data``. ``norm_obj`` and ``normalization`` are accepted
    but not used by this class.
    """

    # Directory holding the already-featurized samples of each data source.
    _PROCESSED_DIRS = {
        'binned_dos_128': '../Mat2Spec_DATA/materials_with_edos_processed/',
        'ph_dos_51': '../Mat2Spec_DATA/materials_with_phdos_processed/',
        'no_label_128': '../Mat2Spec_DATA/materials_without_dos_processed/',
    }

    def __init__(self, args, pd_data=None, np_data=None, norm_obj=None, normalization=None, max_num_nbr=12, radius=8,
                 dmin=0, step=0.2, cls_num=3, root_dir='DATA/'):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        self.pd_data = pd_data
        self.np_data = np_data
        self.ari = AtomCustomJSONInitializer(self.root_dir + 'atom_init.json')
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        self.clusterizer = SPCL(n_clusters=cls_num, random_state=None, assign_labels='discretize')
        self.clusterizer2 = KMeans(n_clusters=cls_num, random_state=None)
        self.encoder_elem = ELEM_Encoder()
        self.update_root = None
        self.args = args
        if self.args.data_src == 'ph_dos_51':
            # NOTE: pickle executes code on load; only use trusted data files.
            with open('../Mat2Spec_DATA/phdos/ph_structures.pkl', 'rb') as pkl_file:
                self.structures = pickle.load(pkl_file)

    def __len__(self):
        return len(self.pd_data)

    def _cache_file(self, cif_id):
        """Return the cache file path for ``cif_id`` or ``None`` if the source has none."""
        cache_dir = self._PROCESSED_DIRS.get(self.args.data_src)
        if cache_dir is None:
            return None
        return cache_dir + str(cif_id) + '.chkpt'

    def _load_structure(self, idx, cif_id):
        """Load the crystal structure of one sample for the configured data source."""
        data_src = self.args.data_src
        if data_src == 'ph_dos_51':
            return self.structures[idx]
        if data_src == 'binned_dos_128':
            json_path = os.path.join(self.root_dir + 'materials_with_edos/', 'dos_' + str(cif_id) + '.json')
        elif data_src == 'no_label_128':
            json_path = os.path.join(self.root_dir + 'materials_without_dos/', str(cif_id) + '.json')
        else:
            raise ValueError(f'unsupported data_src: {data_src!r}')
        with open(json_path) as json_file:
            data = json.load(json_file)
        return Structure.from_dict(data['structure'])

    def _neighbor_arrays(self, crystal):
        """Return (neighbor indices, neighbor distances), each [n_atom, max_num_nbr].

        Atoms with fewer than ``max_num_nbr`` neighbors inside ``radius`` are
        padded with index 0 and distance ``radius + 1``.
        """
        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)  # (site, distance, index, image)
        nbr_fea_idx, nbr_fea = [], []
        for nbrs in all_nbrs:
            nearest = sorted(nbrs, key=lambda x: x[1])[:self.max_num_nbr]
            n_pad = self.max_num_nbr - len(nearest)
            nbr_fea_idx.append([x[2] for x in nearest] + [0] * n_pad)
            nbr_fea.append([x[1] for x in nearest] + [self.radius + 1.] * n_pad)
        return np.array(nbr_fea_idx), np.array(nbr_fea)

    # @functools.lru_cache(maxsize=None)  # Cache loaded structures
    def __getitem__(self, idx):
        # First value of row ``idx``; positional access on the row avoids the
        # deprecated integer-key fallback of ``Series[0]``.
        cif_id = self.pd_data.iloc[idx].iloc[0]
        target = self.np_data[idx]

        cache_file = self._cache_file(cif_id)
        if self.args.use_catached_data and cache_file is not None and path.exists(cache_file):
            # NOTE: torch.load unpickles the file; only load caches written by this code.
            cached = torch.load(cache_file)
            return ((cached['atom_fea'], cached['nbr_fea'], cached['nbr_fea_idx']),
                    cached['groups'], cached['enc_compo'], cached['coordinates'],
                    cached['target'], cached['cif_id'], cached['atom_id'])

        crystal = self._load_structure(idx, cif_id)

        atom_fea = np.vstack([self.ari.get_atom_fea(crystal[i].specie.number) for i in range(len(crystal))])
        atom_fea = torch.Tensor(atom_fea)

        nbr_fea_idx, nbr_fea = self._neighbor_arrays(crystal)
        nbr_fea = self.gdf.expand(nbr_fea)  # adds a trailing Gaussian-basis axis

        g_coords = crystal.cart_coords  # [n_atom, 3]
        groups = [0] * len(g_coords)
        if len(g_coords) > 2:
            try:
                groups = self.clusterizer.fit_predict(g_coords)
            except Exception:
                # Fall back to the second clusterer when the first one fails;
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                groups = self.clusterizer2.fit_predict(g_coords)
        groups = torch.tensor(groups).long()  # [n_atom]

        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = self.format_adj_matrix(torch.LongTensor(nbr_fea_idx))  # [2, E]

        target = torch.Tensor(target.astype(float)).view(1, -1)

        coordinates = torch.tensor(g_coords)  # [n_atom, 3]
        enc_compo = self.encoder_elem.encode(crystal.composition)
        atom_id = [crystal[i].specie for i in range(len(crystal))]

        tmp_dist = {
            'atom_fea': atom_fea,
            'nbr_fea': nbr_fea,
            'nbr_fea_idx': nbr_fea_idx,
            'groups': groups,
            'enc_compo': enc_compo,
            'coordinates': coordinates,
            'target': target,
            'cif_id': cif_id,
            'atom_id': atom_id,
        }

        # The featurized sample is written even when cache reading is disabled.
        if cache_file is not None:
            mkdirs(self._PROCESSED_DIRS[self.args.data_src])
            torch.save(tmp_dist, cache_file)

        return (atom_fea, nbr_fea, nbr_fea_idx), groups, enc_compo, coordinates, target, cif_id, atom_id

    def format_adj_matrix(self, adj_matrix):
        """Turn an [n_atom, n_nbr] neighbor-index matrix into a [2, E] edge list.

        Row 0 repeats each source atom index once per neighbor column; row 1
        is the flattened neighbor-index matrix.
        """
        size = len(adj_matrix)
        src_list = list(range(size))
        all_src_nodes = torch.tensor([[x] * adj_matrix.shape[1] for x in src_list]).view(-1).long().unsqueeze(0)
        all_dst_nodes = adj_matrix.view(-1).unsqueeze(0)

        return torch.cat((all_src_nodes, all_dst_nodes), dim=0)
399 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/Mat2Spec/file_setter.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from shutil import copyfile
4 |
def use_property(property_name, source, do_prediction=False):
    """Write ``DATA/<cif_dir>/id_prop.csv`` for one property and data source.

    Args:
        property_name: one of the aliases listed in ``property_table`` below.
        source: ``'CGCNN'``, ``'MEGNET'`` or ``'NEW'``.
        do_prediction: when true, the final "ready" message is not printed.

    Returns:
        ``(source, num_T, CIF_dict)`` where ``num_T`` is the size associated
        with the property/source and ``CIF_dict`` holds ``radius``, ``step``
        and ``max_num_nbr``.

    Raises:
        ValueError: for an unknown ``property_name`` or ``source``. These used
            to surface as an unbound-local error after partial work.
    """
    # (aliases, reference csv, column index in megnet.csv, size)
    property_table = (
        (('band', 'bandgap', 'band-gap'), 'bandgap.csv', 1, 36720),
        (('bulk', 'bulkmodulus', 'bulk-modulus', 'bulk-moduli'), 'bulkmodulus.csv', 3, 4664),
        (('energy-1', 'formationenergy', 'formation-energy'), 'formationenergy.csv', 2, 60000),
        (('energy-2', 'fermienergy', 'fermi-energy'), 'fermienergy.csv', 2, 60000),
        (('energy-3', 'absoluteenergy', 'absolute-energy'), 'absoluteenergy.csv', 2, 60000),
        (('shear', 'shearmodulus', 'shear-modulus', 'shear-moduli'), 'shearmodulus.csv', 4, 4664),
        (('poisson', 'poissonratio', 'poisson-ratio'), 'poissonratio.csv', 4, 4664),
        (('is_metal', 'is_not_metal'), 'ismetal.csv', 2, 55391),
        (('new-property',), 'newproperty.csv', None, None),
    )
    for aliases, filename, p, num_T in property_table:
        if property_name in aliases:
            break
    else:
        raise ValueError(f'unknown property_name: {property_name!r}')

    if source not in ('CGCNN', 'MEGNET', 'NEW'):
        raise ValueError(f"unknown source: {source!r} (expected 'CGCNN', 'MEGNET' or 'NEW')")

    print('> Preparing dataset to use for Property Prediction. Please wait ...')

    df = pd.read_csv(f'DATA/properties-reference/{filename}', names=['material_id', 'value'])
    df = df.replace(to_replace='None', value=np.nan).dropna()

    if source == 'CGCNN':
        # Restrict to the id list that matches the property and set the matching size.
        cif_dir = 'CIF-DATA'
        if filename in ['bulkmodulus.csv', 'shearmodulus.csv', 'poissonratio.csv']:
            small = pd.read_csv('DATA/cgcnn-reference/mp-ids-3402.csv', names=['mp_ids']).values.squeeze()
            df = df[df.material_id.isin(small)]
            num_T = 2041
        elif filename == 'bandgap.csv':
            medium = pd.read_csv('DATA/cgcnn-reference/mp-ids-27430.csv', names=['mp_ids']).values.squeeze()
            df = df[df.material_id.isin(medium)]
            num_T = 16458
        elif filename in ['formationenergy.csv', 'fermienergy.csv', 'ismetal.csv', 'absoluteenergy.csv']:
            large = pd.read_csv('DATA/cgcnn-reference/mp-ids-46744.csv', names=['mp_ids']).values.squeeze()
            df = df[df.material_id.isin(large)]
            num_T = 28046
        CIF_dict = {'radius': 8, 'step': 0.2, 'max_num_nbr': 12}

    elif source == 'MEGNET':
        # Keep only the materials flagged with 1 in column ``p`` of megnet.csv.
        cif_dir = 'CIF-DATA'
        megnet_df = pd.read_csv('DATA/megnet-reference/megnet.csv')
        use_ids = megnet_df[megnet_df.iloc[:, p] == 1].material_id.values.squeeze()
        df = df[df.material_id.isin(use_ids)]
        CIF_dict = {'radius': 4, 'step': 0.5, 'max_num_nbr': 16}

    else:  # source == 'NEW'
        cif_dir = 'CIF-DATA_NEW'
        CIF_dict = {'radius': 8, 'step': 0.2, 'max_num_nbr': 12}
        d_src = 'DATA'
        src, dst = d_src + '/CIF-DATA/atom_init.json', d_src + '/CIF-DATA_NEW/atom_init.json'
        copyfile(src, dst)

    # ADDITIONAL CLEANING: properties with p in {3, 4} keep strictly positive values only.
    if p in [3, 4]:
        df = df[df.value > 0]

    df.to_csv(f'DATA/{cif_dir}/id_prop.csv', index=False, header=False)
    if not do_prediction:
        print(f'> Dataset for {source}---{property_name} ready !\n\n')
    return source, num_T, CIF_dict
64 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/Mat2Spec/pytorch_stats_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 |
5 | #######################################################
6 | # STATISTICAL DISTANCES(LOSSES) IN PYTORCH #
7 | #######################################################
8 |
9 | ## Statistical Distances for 1D weight distributions
10 | ## Inspired by Scipy.Stats Statistical Distances for 1D
11 | ## Pytorch Version, supporting Autograd to make a valid Loss
12 | ## Supposing Inputs are Groups of Same-Length Weight Vectors
13 | ## Instead of (Points, Weight), full-length Weight Vectors are taken as Inputs
14 | ## Code Written by E.Bao, CASIA
15 |
def torch_wasserstein_loss(tensor_a, tensor_b):
    """First Wasserstein distance between two 1D distributions (CDF loss with p=1)."""
    return torch_cdf_loss(tensor_a, tensor_b, p=1)
19 |
def torch_energy_loss(tensor_a, tensor_b):
    """Energy distance between two 1D distributions: sqrt(2) times the p=2 CDF loss."""
    scale = 2 ** 0.5
    return scale * torch_cdf_loss(tensor_a, tensor_b, p=2)
23 |
def torch_cdf_loss(tensor_a, tensor_b, p=1):
    """Mean p-norm distance between the cumulative sums of two weight tensors.

    The last dimension is treated as the weight distribution. Both inputs are
    first divided by their sum along that dimension (plus 1e-14 so an all-zero
    row does not produce 0/0), then compared through their cumulative sums.
    ``p=1`` gives the first Wasserstein distance.

    The original authors recommend combining this loss with another
    difference-based loss such as L1.
    """
    eps = 1e-14
    weights_a = tensor_a / (torch.sum(tensor_a, dim=-1, keepdim=True) + eps)
    weights_b = tensor_b / (torch.sum(tensor_b, dim=-1, keepdim=True) + eps)

    # Difference of the two cumulative distributions along the last dimension.
    cdf_gap = torch.cumsum(weights_a, dim=-1) - torch.cumsum(weights_b, dim=-1)

    if p == 1:
        per_sample = torch.sum(torch.abs(cdf_gap), dim=-1)
    elif p == 2:
        per_sample = torch.sqrt(torch.sum(torch.pow(cdf_gap, 2), dim=-1))
    else:
        per_sample = torch.pow(torch.sum(torch.pow(torch.abs(cdf_gap), p), dim=-1), 1 / p)

    return per_sample.mean()
47 |
def torch_validate_distibution(tensor_a, tensor_b):
    """Raise ``ValueError`` unless both weight tensors have the same size.

    Only the sizes are compared. Non-negative weights with a positive, finite
    sum and non-empty inputs are assumed rather than checked.
    """
    sizes_match = tensor_a.size() == tensor_b.size()
    if not sizes_match:
        raise ValueError("Input weight tensors must be of the same size")
55 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/Mat2Spec/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pandas as pd
4 | import os
5 | import shutil
6 | import argparse
7 | from operator import attrgetter
8 |
9 | from sklearn.model_selection import train_test_split
10 | from sklearn.metrics import mean_absolute_error as sk_MAE
11 | from tabulate import tabulate
12 | import random,time
13 |
def set_device(gpu_id=0):
    """Return the CUDA device ``gpu_id`` when CUDA is available, else the CPU device."""
    if torch.cuda.is_available():
        return torch.device(f'cuda:{gpu_id}')
    return torch.device('cpu')
17 |
def set_model_properties(crystal_property):
    """Return ``(norm_action, classification)`` for a crystal property name.

    The two metal flags map to a classification action; the properties listed
    in ``unscaled`` map to ``(None, None)``; everything else maps to ``'log'``.
    """
    if crystal_property == 'is_metal':
        return 'classification-1', 1
    if crystal_property == 'is_not_metal':
        return 'classification-0', 1

    unscaled = ('poisson-ratio', 'band-gap', 'absolute-energy', 'fermi-energy', 'formation-energy')
    if crystal_property in unscaled:
        return None, None
    return 'log', None
28 |
def torch_MAE(tensor1, tensor2):
    """Mean absolute difference between two tensors."""
    abs_error = torch.abs(tensor1 - tensor2)
    return torch.mean(abs_error)
31 |
def torch_accuracy(pred_tensor, true_tensor):
    """Fraction of rows whose arg-max along dim 1 equals the true label.

    Returns a float tensor: number of matches divided by the size of the
    first dimension of the predicted labels.
    """
    predicted_labels = torch.max(pred_tensor, dim=1).indices
    n_correct = (predicted_labels == true_tensor).sum().float()
    n_total = predicted_labels.size(0)
    return n_correct / n_total
38 |
def output_training(metrics_obj, epoch, estop_val, extra='---'):
    """Print and return a table with the losses recorded for ``epoch``.

    The column headers switch to cross-entropy/accuracy when the tracked
    property is one of the two metal flags.
    """
    if metrics_obj.c_property in ['is_metal', 'is_not_metal']:
        header_1, header_2 = 'Cross_E | e-stop', 'Accuracy | TIME'
    else:
        header_1, header_2 = 'MSE | e-stop', 'MAE | TIME'

    train_1 = metrics_obj.training_loss1[epoch]
    train_2 = metrics_obj.training_loss2[epoch]
    valid_1 = metrics_obj.valid_loss1[epoch]
    valid_2 = metrics_obj.valid_loss2[epoch]

    rows = [
        ['TRAINING', f'{train_1:.4f}', f'{train_2:.4f}'],
        ['VALIDATION', f'{valid_1:.4f}', f'{valid_2:.4f}'],
        ['E-STOPPING', f'{estop_val}', f'{extra}'],
    ]
    headers = [f'EPOCH # {epoch}', header_1, header_2]

    output = tabulate(rows, headers=headers, tablefmt='fancy_grid')
    print(output)
    return output
52 |
def load_metrics():
    """Load and return the object pickled at ``MODELS/metrics_.pickle``.

    Raises:
        OSError: if the file cannot be opened.
    """
    # Local import: this module's top-level imports do not include pickle,
    # so the original call failed with a NameError.
    import pickle

    # NOTE: unpickling runs arbitrary code; only load files written by a trusted run.
    with open("MODELS/metrics_.pickle", "rb", -1) as metrics_file:
        return pickle.load(metrics_file)
56 |
57 |
def freeze_params(model, params_to_freeze_list):
    """Disable gradients for the attributes of ``model`` named in the list.

    Each entry is a dotted attribute path resolved with ``attrgetter``; the
    resolved object gets ``requires_grad = False`` and ``grad = None``.
    """
    for param_path in params_to_freeze_list:
        param = attrgetter(param_path)(model)
        param.requires_grad = False
        param.grad = None
63 |
64 |
def unfreeze_params(model, params_to_unfreeze_list):
    """Re-enable gradients for the attributes of ``model`` named in the list.

    Each entry is a dotted attribute path resolved with ``attrgetter``.
    """
    for param_path in params_to_unfreeze_list:
        param = attrgetter(param_path)(model)
        param.requires_grad = True
71 |
72 |
def RobustL1(output, log_std, target):
    """
    Robust L1 loss using a lorentzian prior. Allows for estimation
    of an aleatoric uncertainty.

    Computes mean(sqrt(2) * |output - target| * exp(-log_std) + log_std).
    """
    abs_error = torch.abs(output - target)
    inv_std = torch.exp(-log_std)
    per_element = np.sqrt(2.0) * abs_error * inv_std + log_std
    return torch.mean(per_element)
81 |
82 |
def RobustL2(output, log_std, target):
    """
    Robust L2 loss using a gaussian prior. Allows for estimation
    of an aleatoric uncertainty.

    Computes mean(0.5 * (output - target)**2 * exp(-2 * log_std) + log_std).
    """
    sq_error = torch.pow(output - target, 2.0)
    inv_var = torch.exp(-2.0 * log_std)
    per_element = 0.5 * sq_error * inv_var + log_std
    return torch.mean(per_element)
91 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_dos128_norm_sum_kl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
# Runs test_Mat2Spec.py on GPU 1 for binned_dos_128 with the KL loss.
# The last option has no trailing backslash, so a line appended below is
# not silently joined onto this command.
CUDA_VISIBLE_DEVICES=1 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'KL' \
    --label_scaling 'normalized_sum' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --Mat2Spec-label-dim 128
10 |
11 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_dos128_norm_sum_wd.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python test_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'WD' \
4 | --label_scaling 'normalized_sum' \
5 | --data_src 'binned_dos_128' \
6 | --trainset_subset_ratio 1.0 \
7 | --Mat2Spec-label-dim 128
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_dos128_std_mae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
# Runs test_Mat2Spec.py on GPU 1 for binned_dos_128 with the MAE loss.
# The last option has no trailing backslash, so a line appended below is
# not silently joined onto this command.
CUDA_VISIBLE_DEVICES=1 python test_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MAE' \
    --label_scaling 'standardized' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --Mat2Spec-label-dim 128
10 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_nolabel128_norm_sum_kl.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'KL' \
4 | --label_scaling 'normalized_sum' \
5 | --data_src 'no_label_128' \
6 | --trainset_subset_ratio 1.0 \
7 | --check-point-path './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_KL_trainsize1.0.chkpt' \
8 | --Mat2Spec-label-dim 128
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_nolabel128_norm_sum_wd.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'WD' \
4 | --label_scaling 'normalized_sum' \
5 | --data_src 'no_label_128' \
6 | --trainset_subset_ratio 1.0 \
7 | --check-point-path './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_WD_trainsize1.0.chkpt' \
8 | --Mat2Spec-label-dim 128
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_nolabel128_std_mae.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'MAE' \
4 | --label_scaling 'standardized' \
5 | --data_src 'no_label_128' \
6 | --trainset_subset_ratio 1.0 \
7 | --check-point-path './TRAINED/model_Mat2Spec_binned_dos_128_standardized_MAE_trainsize1.0.chkpt' \
8 | --Mat2Spec-label-dim 128
9 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_phdos51_norm_max_mae.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'MAE' \
4 | --label_scaling 'normalized_max' \
5 | --data_src 'ph_dos_51' \
6 | --trainset_subset_ratio 1.0 \
7 | --train \
8 | --check-point-path './TRAINED/model_Mat2Spec_ph_dos_51_normalized_max_MAE_trainsize1.0.chkpt' \
9 | --Mat2Spec-label-dim 51 \
10 | --Mat2Spec-keep-prob 0.5 \
11 | --batch-size 8
12 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_phdos51_norm_max_mse.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'MSE' \
4 | --label_scaling 'normalized_max' \
5 | --data_src 'ph_dos_51' \
6 | --trainset_subset_ratio 1.0 \
7 | --train \
8 | --check-point-path './TRAINED/model_Mat2Spec_ph_dos_51_normalized_max_MSE_trainsize1.0.chkpt' \
9 | --Mat2Spec-label-dim 51 \
10 | --Mat2Spec-keep-prob 0.5 \
11 | --batch-size 8
12 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_phdos51_norm_sum_kl.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'KL' \
4 | --label_scaling 'normalized_sum' \
5 | --data_src 'ph_dos_51' \
6 | --trainset_subset_ratio 1.0 \
7 | --Mat2Spec-label-dim 51 \
8 | --Mat2Spec-keep-prob 0.5 \
9 | --batch-size 8
10 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/test_phdos51_norm_sum_wd.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'WD' \
4 | --label_scaling 'normalized_sum' \
5 | --data_src 'ph_dos_51' \
6 | --trainset_subset_ratio 1.0 \
7 | --Mat2Spec-label-dim 51 \
8 | --Mat2Spec-keep-prob 0.5 \
9 | --batch-size 8
10 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/train_dos128_norm_sum_kl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
# Runs train_Mat2Spec.py on GPU 1 for binned_dos_128 with the KL loss.
# The last option has no trailing backslash, so a line appended below is
# not silently joined onto this command.
CUDA_VISIBLE_DEVICES=1 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'KL' \
    --label_scaling 'normalized_sum' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 128
11 |
12 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/train_dos128_norm_sum_wd.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python train_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'WD' \
4 | --label_scaling 'normalized_sum' \
5 | --data_src 'binned_dos_128' \
6 | --trainset_subset_ratio 1.0 \
7 | --train \
8 | --Mat2Spec-label-dim 128
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/train_dos128_std_mae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
# Runs train_Mat2Spec.py on GPU 1 for binned_dos_128 with the MAE loss.
# The last option has no trailing backslash, so a line appended below is
# not silently joined onto this command.
CUDA_VISIBLE_DEVICES=1 python train_Mat2Spec.py \
    --concat_comp '' \
    --Mat2Spec-loss-type 'MAE' \
    --label_scaling 'standardized' \
    --data_src 'binned_dos_128' \
    --trainset_subset_ratio 1.0 \
    --train \
    --Mat2Spec-label-dim 128
11 |
12 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/train_phdos51_norm_max_mae.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'MAE' \
4 | --label_scaling 'normalized_max' \
5 | --data_src 'ph_dos_51' \
6 | --trainset_subset_ratio 1.0 \
7 | --train \
8 | --Mat2Spec-label-dim 51 \
9 | --Mat2Spec-keep-prob 0.5 \
10 | --batch-size 8
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/train_phdos51_norm_max_mse.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'MSE' \
4 | --label_scaling 'normalized_max' \
5 | --data_src 'ph_dos_51' \
6 | --trainset_subset_ratio 1.0 \
7 | --train \
8 | --Mat2Spec-label-dim 51 \
9 | --Mat2Spec-keep-prob 0.5 \
10 | --batch-size 8
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/train_phdos51_norm_sum_kl.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'KL' \
4 | --label_scaling 'normalized_sum' \
5 | --data_src 'ph_dos_51' \
6 | --trainset_subset_ratio 1.0 \
7 | --train \
8 | --Mat2Spec-label-dim 51 \
9 | --Mat2Spec-keep-prob 0.5 \
10 | --batch-size 8
--------------------------------------------------------------------------------
/Mat2Spec_Codes/SCRIPTS/train_phdos51_norm_sum_wd.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_Mat2Spec.py \
2 | --concat_comp '' \
3 | --Mat2Spec-loss-type 'WD' \
4 | --label_scaling 'normalized_sum' \
5 | --data_src 'ph_dos_51' \
6 | --trainset_subset_ratio 1.0 \
7 | --train \
8 | --Mat2Spec-label-dim 51 \
9 | --Mat2Spec-keep-prob 0.5 \
10 | --batch-size 8
--------------------------------------------------------------------------------
/Mat2Spec_Codes/test_Mat2Spec.py:
--------------------------------------------------------------------------------
1 | from Mat2Spec.data import *
2 | from Mat2Spec.Mat2Spec import *
3 | from Mat2Spec.file_setter import use_property
4 | from Mat2Spec.utils import *
5 | import matplotlib.pyplot as plt
6 | import random
7 | from tqdm import tqdm
8 | import gc
9 | import pickle
10 | from copy import copy, deepcopy
11 | from os import makedirs
12 | torch.autograd.set_detect_anomaly(True)
13 | device = set_device()
14 |
15 | # MOST CRUCIAL DATA PARAMETERS
# Command-line options of the test script.
# NOTE: options declared with ``type=bool`` apply ``bool()`` to the raw string,
# so only an empty string yields False (the SCRIPTS pass ``--concat_comp ''``).
parser = argparse.ArgumentParser(description='Mat2Spec')
parser.add_argument('--data_src', default='binned_dos_128',
                    choices=['binned_dos_128', 'binned_dos_32', 'ph_dos_51', 'no_label_32', 'no_label_128'])
parser.add_argument('--label_scaling', default='standardized',
                    choices=['standardized', 'normalized_sum', 'normalized_max'])
# MOST CRUCIAL MODEL PARAMETERS
parser.add_argument('--num_layers', default=3, type=int,
                    help='number of AGAT layers to use in model (default:3)')
# The help text previously advertised a default of 64 while the default is 128.
parser.add_argument('--num_neurons', default=128, type=int,
                    help='number of neurons to use per AGAT Layer(default:128)')
parser.add_argument('--num_heads', default=4, type=int,
                    help='number of Attention-Heads to use per AGAT Layer (default:4)')
parser.add_argument('--concat_comp', default=False, type=bool,
                    help='option to re-use vector of elemental composition after global summation of crystal feature.(default: False)')
parser.add_argument('--train_size', default=0.8, type=float, help='ratio size of the training-set (default:0.8)')
parser.add_argument('--trainset_subset_ratio', default=0.5, type=float, help='ratio size of the training-set subset (default:0.5)')
parser.add_argument('--use_catached_data', default=True, type=bool)
parser.add_argument("--train", action="store_true")  # default value is false
parser.add_argument('--num-epochs', default=200, type=int)
parser.add_argument('--batch-size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--Mat2Spec-input-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-label-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-latent-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-emb-size', default=512, type=int)
parser.add_argument('--Mat2Spec-keep-prob', default=0.5, type=float)
parser.add_argument('--Mat2Spec-scale-coeff', default=1.0, type=float)
parser.add_argument('--Mat2Spec-loss-type', default='MAE', type=str, choices=['MAE', 'KL', 'WD', 'MSE'])
parser.add_argument('--Mat2Spec-K', default=10, type=int)
parser.add_argument('--check-point-path', default=None, type=str)
parser.add_argument('--test-mpid', default='mpids.csv', type=str)
parser.add_argument("--finetune", action="store_true")  # default value is false
parser.add_argument("--finetune-dataset", default='null', type=str)
parser.add_argument("--ablation-LE", action="store_true")  # default value is false
parser.add_argument("--ablation-CL", action="store_true")  # default value is false
# parse_args() reads sys.argv[1:] itself, so this no longer relies on ``sys``
# being in scope (this file does not import it explicitly).
args = parser.parse_args()
50 |
# GNN --- parameters
data_src = args.data_src
RSM = {'radius': 8, 'step': 0.2, 'max_num_nbr': 12}

number_layers = args.num_layers
number_neurons = args.num_neurons
n_heads = args.num_heads
concat_comp = args.concat_comp

# DATA PARAMETERS
# Seed the Python, NumPy and torch generators with the same fixed value.
random_num = 1
random.seed(random_num)
np.random.seed(random_num)
torch.manual_seed(random_num)

# MODEL HYPER-PARAMETERS
num_epochs = args.num_epochs
learning_rate = args.lr
batch_size = args.batch_size

stop_patience = 150
best_epoch = 1
adj_epochs = 50
milestones = [150, 250]
train_param = {'batch_size': batch_size, 'shuffle': True}
valid_param = {'batch_size': batch_size, 'shuffle': False}

# DATALOADER/ TARGET NORMALIZATION
if args.data_src == 'binned_dos_128':
    pd_data = pd.read_csv(f'../Mat2Spec_DATA/label_edos/' + args.test_mpid)
    np_data = np.load(f'../Mat2Spec_DATA/label_edos/total_dos_128.npy')
elif args.data_src == 'ph_dos_51':
    pd_data = pd.read_csv(f'../Mat2Spec_DATA/phdos/' + args.test_mpid)
    np_data = np.load(f'../Mat2Spec_DATA/phdos/ph_dos.npy')
elif args.data_src == 'no_label_128':
    pd_data = pd.read_csv(f'../Mat2Spec_DATA/no_label/' + args.test_mpid)
    np_data = np.random.rand(len(pd_data), 128)  # dummy label
else:
    # The CLI accepts more choices than this script can load; fail with a
    # clear message instead of a NameError on pd_data further down.
    raise ValueError(f'Unsupported --data_src for testing: {args.data_src!r}')

NORMALIZER = DATA_normalizer(np_data)

if args.data_src == 'no_label_128':
    # The labels are dummies here, so replace the normalizer statistics
    # with the ones stored on disk.
    mean_tmp = torch.tensor(np.load(f'../Mat2Spec_DATA/no_label/label_mean_binned_dos_128.npy'))
    std_tmp = torch.tensor(np.load(f'../Mat2Spec_DATA/no_label/label_std_binned_dos_128.npy'))
    NORMALIZER.mean = mean_tmp
    NORMALIZER.std = std_tmp

CRYSTAL_DATA = CIF_Dataset(args, pd_data=pd_data, np_data=np_data, root_dir=f'../Mat2Spec_DATA/', **RSM)

# train_idx_all is only produced when the split is generated in this script.
train_idx_all = None
if args.data_src == 'ph_dos_51':
    # NOTE(review): pickle.load can execute arbitrary code; only use a trusted split file.
    with open('../Mat2Spec_DATA/phdos/200801_trteva_indices.pkl', 'rb') as f:
        train_idx, val_idx, test_idx = pickle.load(f)
elif args.data_src == 'no_label_128':
    test_idx = list(range(len(pd_data)))
else:
    idx_list = list(range(len(pd_data)))
    random.shuffle(idx_list)
    train_idx_all, test_val = train_test_split(idx_list, train_size=args.train_size, random_state=random_num)
    test_idx, val_idx = train_test_split(test_val, test_size=0.5, random_state=random_num)

# Sub-sample the training indices only when a generated split exists;
# the pre-defined and label-free datasets have no train_idx_all to draw from.
if train_idx_all is not None:
    if args.trainset_subset_ratio < 1.0:
        train_idx, _ = train_test_split(train_idx_all, train_size=args.trainset_subset_ratio, random_state=random_num)
    else:
        train_idx = train_idx_all

if args.finetune:
    if args.finetune_dataset == 'null':
        raise ValueError('--finetune requires --finetune-dataset to be set.')
    if args.data_src != 'binned_dos_128':
        raise ValueError('Finetuning is only supported on the binned dos 128 dataset.')
    # NOTE(review): this directory differs from the 'label_edos/materials_classes/'
    # location the training script reads its finetune indices from -- confirm which is current.
    with open(f'../Mat2Spec_DATA/20210619_binned_32_128/materials_classes/' + args.finetune_dataset + '/test_idx.json', ) as f:
        test_idx = json.load(f)

print('testing size:', len(test_idx))

testing_set = CIF_Lister(test_idx, CRYSTAL_DATA, df=pd_data)
128 |
print(f'> USING MODEL Mat2Spec!')
the_network = Mat2Spec(args, NORMALIZER)
net = the_network.to(device)

# load checkpoint
# Default location is derived from the run configuration; the ablation flags
# and an explicit --check-point-path override it, in that order.
if args.finetune:
    check_point_path = './TRAINED/finetune/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling \
                       + '_' + args.Mat2Spec_loss_type + '_finetune_' + args.finetune_dataset + '.chkpt'
else:
    check_point_path = './TRAINED/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling \
                       + '_' + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio) + '.chkpt'

if args.ablation_LE:
    check_point_path = './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_KL_trainsize1.0_ablation_LE.chkpt'

if args.ablation_CL:
    check_point_path = './TRAINED/model_Mat2Spec_binned_dos_128_normalized_sum_KL_trainsize1.0_ablation_CL.chkpt'

if args.check_point_path is not None:
    check_point_path = args.check_point_path

# map_location lets a checkpoint written on a GPU be restored on whatever
# device this run uses (e.g. a CPU-only machine).
check_point = torch.load(check_point_path, map_location=device)
net.load_state_dict(check_point['model'])

print(f'> TESTING MODEL ...')
test_loader = torch_DataLoader(dataset=testing_set, **valid_param)
def test():
    """Run one evaluation pass over ``test_loader``.

    Uses the module-level ``net``, ``test_loader``, ``args``, ``device`` and
    ``NORMALIZER``.

    Returns:
        A 7-tuple ``(prediction, prediction_x, label_gt, total_loss,
        nll_loss, nll_loss_x, kl_loss)``. The first three are NumPy arrays
        concatenated over all batches along axis 0 (``label_gt`` holds the
        un-normalized labels); the last four are the per-batch means of the
        corresponding ``compute_loss`` outputs, converted to NumPy.

    Raises:
        ValueError: if ``args.label_scaling`` is not a handled mode.
        RuntimeError: if ``test_loader`` yields no batches.
    """
    batch_count = 0
    total_loss_smooth = 0
    nll_loss_smooth = 0
    nll_loss_x_smooth = 0
    kl_loss_smooth = 0
    prediction = []
    prediction_x = []
    label_gt = []

    # TESTING-PHASE
    net.eval()
    # NOTE(review): the flag is forced to True even though this is evaluation;
    # kept as-is because compute_loss receives args -- confirm it is intended.
    args.train = True
    for data in tqdm(test_loader, mininterval=0.5, desc='(testing)', position=0, leave=True, ascii=True):
        data = data.to(device)
        valid_label = deepcopy(data.y).float().to(device)

        if args.label_scaling == 'standardized':
            valid_label_normalize = (valid_label - NORMALIZER.mean.to(device)) / NORMALIZER.std.to(device)
        elif args.label_scaling == 'normalized_max':
            valid_label_normalize = valid_label / (torch.max(valid_label, dim=1)[0].unsqueeze(1))
        elif args.label_scaling == 'normalized_sum':
            valid_label_normalize = valid_label / torch.sum(valid_label, dim=1, keepdim=True)
        else:
            raise ValueError(f'Unsupported label scaling: {args.label_scaling!r}')

        with torch.no_grad():
            predictions = net(data)
            total_loss, nll_loss, nll_loss_x, kl_loss, cpc_loss, pred_e, pred_x = \
                compute_loss(valid_label_normalize, predictions, NORMALIZER, args)

        prediction.append(pred_e.detach().cpu().numpy())
        prediction_x.append(pred_x.detach().cpu().numpy())
        label_gt.append(valid_label.detach().cpu().numpy())

        total_loss_smooth += total_loss
        nll_loss_smooth += nll_loss
        nll_loss_x_smooth += nll_loss_x
        kl_loss_smooth += kl_loss
        batch_count += 1

    if batch_count == 0:
        raise RuntimeError('test_loader produced no batches; nothing to evaluate.')

    total_loss_smooth = total_loss_smooth / batch_count
    nll_loss_smooth = nll_loss_smooth / batch_count
    nll_loss_x_smooth = nll_loss_x_smooth / batch_count
    kl_loss_smooth = kl_loss_smooth / batch_count

    prediction = np.concatenate(prediction, axis=0)
    prediction_x = np.concatenate(prediction_x, axis=0)
    label_gt = np.concatenate(label_gt, axis=0)

    return (prediction, prediction_x, label_gt,
            total_loss_smooth.cpu().numpy(), nll_loss_smooth.cpu().numpy(),
            nll_loss_x_smooth.cpu().numpy(), kl_loss_smooth.cpu().numpy())
221 |
# Evaluate three times, stack the runs on a new leading axis, then average.
prediction_list = []
prediction_x_list = []
label_gt_list = []
total_loss_smooth_list = []
nll_loss_smooth_list = []
nll_loss_x_smooth_list = []
kl_loss_smooth_list = []

for i in range(3):
    print(i)
    prediction, prediction_x, label_gt, total_loss_smooth, nll_loss_smooth, nll_loss_x_smooth, kl_loss_smooth = test()
    prediction_list.append(np.expand_dims(prediction, axis=0))
    prediction_x_list.append(np.expand_dims(prediction_x, axis=0))
    label_gt_list.append(np.expand_dims(label_gt, axis=0))
    total_loss_smooth_list.append(total_loss_smooth)
    nll_loss_smooth_list.append(nll_loss_smooth)
    nll_loss_x_smooth_list.append(nll_loss_x_smooth)
    kl_loss_smooth_list.append(kl_loss_smooth)

total_loss_smooth = np.mean(total_loss_smooth_list)
nll_loss_smooth = np.mean(nll_loss_smooth_list)
nll_loss_x_smooth = np.mean(nll_loss_x_smooth_list)
kl_loss_smooth = np.mean(kl_loss_smooth_list)

prediction = np.concatenate(prediction_list, axis=0)
prediction_x = np.concatenate(prediction_x_list, axis=0)
label_gt = np.concatenate(label_gt_list, axis=0)

# Spread of prediction_x across the runs; rescaled below but not written to disk.
prediction_x_std = np.std(prediction_x, axis=0)
prediction = np.mean(prediction, axis=0)
prediction_x = np.mean(prediction_x, axis=0)
label_gt = np.mean(label_gt, axis=0)

result_path = './RESULT/'

if args.finetune:
    result_path = result_path + '/finetune/' + args.finetune_dataset + '/'

if args.ablation_LE:
    result_path = result_path + '/ablation_LE/'

if args.ablation_CL:
    result_path = result_path + '/ablation_CL/'

makedirs(result_path, exist_ok=True)

# Suffix shared by every artifact written for this configuration.
run_tag = args.data_src + '_' + args.label_scaling + '_' \
          + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio)
has_labels = args.data_src != 'no_label_128' and args.data_src != 'no_label_32'


def _result_file(prefix, extension='.npy'):
    """Return the output path for an artifact named ``prefix`` + run tag."""
    return result_path + prefix + run_tag + extension


# NOTE(review): the source lost its indentation. It is assumed that only the
# label-derived arrays (and, for the normalized modes, the rescaling) are
# skipped for label-free data, while predictions and mpids are always written.
if args.label_scaling == 'standardized':
    print('\n > label scaling: std')
    mean = NORMALIZER.mean.detach().numpy()
    std = NORMALIZER.std.detach().numpy()
    label_gt_standardized = (label_gt - mean) / std
    mae = np.mean(np.abs((prediction) - label_gt_standardized))
    mae_x = np.mean(np.abs((prediction_x) - label_gt_standardized))
    prediction = prediction * std + mean
    prediction_x = prediction_x * std + mean
    prediction_x_std = prediction_x_std * std
    # Negative values after de-standardization are replaced by a small positive number.
    prediction[prediction < 0] = 1e-6
    prediction_x[prediction_x < 0] = 1e-6
    mae_ori = np.mean(np.abs((prediction) - label_gt))
    mae_x_ori = np.mean(np.abs((prediction_x) - label_gt))

    ## save results ##
    if has_labels:
        np.save(_result_file('label_gt_'), label_gt)
        np.save(_result_file('label_mean_'), mean)
        np.save(_result_file('label_std_'), std)
    np.save(_result_file('prediction_Mat2Spec_'), prediction_x)
    testing_mpid = pd_data.iloc[test_idx]
    # File name has no underscore after 'testing_mpids' in this mode; kept for compatibility.
    testing_mpid.to_csv(_result_file('testing_mpids', '.csv'), index=False, header=True)

elif args.label_scaling == 'normalized_max':
    print('\n > label scaling: norm max')
    label_max = np.expand_dims(np.max(label_gt, axis=1), axis=1)
    label_gt_standardized = label_gt / label_max
    mae = np.mean(np.abs((prediction) - label_gt_standardized))
    mae_x = np.mean(np.abs((prediction_x) - label_gt_standardized))
    if has_labels:
        prediction = prediction * label_max
        prediction_x = prediction_x * label_max
        prediction_x_std = prediction_x_std * label_max
    mae_ori = np.mean(np.abs((prediction) - label_gt))
    mae_x_ori = np.mean(np.abs((prediction_x) - label_gt))

    ## save results ##
    if has_labels:
        np.save(_result_file('label_gt_'), label_gt)
        np.save(_result_file('label_max_'), label_max)
    np.save(_result_file('prediction_Mat2Spec_'), prediction_x)
    testing_mpid = pd_data.iloc[test_idx]
    # Written under result_path like every other artifact (it used to land in
    # the current working directory in this mode only).
    testing_mpid.to_csv(_result_file('testing_mpids', '.csv'), index=False, header=True)

elif args.label_scaling == 'normalized_sum':
    print('\n > label scaling: norm sum')
    if args.Mat2Spec_loss_type not in ('KL', 'WD'):
        raise ValueError("label scaling 'normalized_sum' requires --Mat2Spec-loss-type KL or WD.")
    label_sum = np.sum(label_gt, axis=1, keepdims=True)
    label_gt_standardized = label_gt / label_sum
    mae = np.mean(np.abs((prediction) - label_gt_standardized))
    mae_x = np.mean(np.abs((prediction_x) - label_gt_standardized))
    if has_labels:
        prediction = prediction * label_sum
        prediction_x = prediction_x * label_sum
        prediction_x_std = prediction_x_std * label_sum
    mae_ori = np.mean(np.abs((prediction) - label_gt))
    mae_x_ori = np.mean(np.abs((prediction_x) - label_gt))

    ## save results ##
    if has_labels:
        np.save(_result_file('label_gt_'), label_gt)
        np.save(_result_file('label_sum_'), label_sum)
    np.save(_result_file('prediction_Mat2Spec_'), prediction_x)
    testing_mpid = pd_data.iloc[test_idx]
    testing_mpid.to_csv(_result_file('testing_mpids_', '.csv'), index=False, header=True)

print("\n********** TESTING STATISTIC ***********")
print("total_loss =%.6f\t nll_loss =%.6f\t nll_loss_x =%.6f\t kl_loss =%.6f\t" %
      (total_loss_smooth, nll_loss_smooth, nll_loss_x_smooth, kl_loss_smooth))
print("mae=%.6f\t mae_x=%.6f\t mae_ori=%.6f\t mae_x_ori=%.6f" % (mae, mae_x, mae_ori, mae_x_ori))
print("\n*****************************************")

print(f"> DONE TESTING !")
365 |
--------------------------------------------------------------------------------
/Mat2Spec_Codes/train_Mat2Spec.py:
--------------------------------------------------------------------------------
1 | from Mat2Spec.data import *
2 | from Mat2Spec.Mat2Spec import *
3 | from Mat2Spec.file_setter import use_property
4 | from Mat2Spec.utils import *
5 | import matplotlib.pyplot as plt
6 | import random
7 | from tqdm import tqdm
8 | import gc
9 | import pickle
10 | from copy import copy, deepcopy
11 | import json
# Anomaly detection is switched on for the whole run.
# NOTE(review): PyTorch documents this as a debugging aid that slows execution;
# consider disabling it for production training.
torch.autograd.set_detect_anomaly(True)
device = set_device()


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True,
    so every non-empty value would enable the flag.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# MOST CRUCIAL DATA PARAMETERS
parser = argparse.ArgumentParser(description='Mat2Spec')
parser.add_argument('--data_src', default='binned_dos_128',
                    choices=['binned_dos_128', 'binned_dos_32', 'ph_dos_51'])
parser.add_argument('--label_scaling', default='standardized',
                    choices=['standardized', 'normalized_sum', 'normalized_max'])
# MOST CRUCIAL MODEL PARAMETERS
parser.add_argument('--num_layers', default=3, type=int,
                    help='number of AGAT layers to use in model (default:3)')
parser.add_argument('--num_neurons', default=128, type=int,
                    help='number of neurons to use per AGAT Layer(default:128)')
parser.add_argument('--num_heads', default=4, type=int,
                    help='number of Attention-Heads to use per AGAT Layer (default:4)')
parser.add_argument('--concat_comp', default=False, type=_str2bool,
                    help='option to re-use vector of elemental composition after global summation of crystal feature.(default: False)')
parser.add_argument('--train_size', default=0.8, type=float, help='ratio size of the training-set (default:0.8)')
parser.add_argument('--trainset_subset_ratio', default=0.5, type=float, help='ratio size of the training-set subset (default:0.5)')
parser.add_argument('--use_catached_data', default=True, type=_str2bool)
parser.add_argument("--train", action="store_true")  # default value is false
parser.add_argument('--num-epochs', default=200, type=int)
parser.add_argument('--batch-size', default=256, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--Mat2Spec-input-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-label-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-latent-dim', default=128, type=int)
parser.add_argument('--Mat2Spec-emb-size', default=512, type=int)
parser.add_argument('--Mat2Spec-keep-prob', default=0.5, type=float)
parser.add_argument('--Mat2Spec-scale-coeff', default=1.0, type=float)
parser.add_argument('--Mat2Spec-loss-type', default='MAE', type=str, choices=['MAE', 'KL', 'WD', 'MSE'])
parser.add_argument('--Mat2Spec-K', default=10, type=int)
parser.add_argument("--finetune", action="store_true")  # default value is false
parser.add_argument("--ablation-LE", action="store_true")  # default value is false
parser.add_argument("--ablation-CL", action="store_true")  # default value is false
parser.add_argument("--finetune-dataset", default='null', type=str)
parser.add_argument('--check-point-path', default=None, type=str)
args = parser.parse_args(sys.argv[1:])
49 |
# GNN --- parameters
data_src = args.data_src
RSM = {'radius': 8, 'step': 0.2, 'max_num_nbr': 12}

number_layers = args.num_layers
number_neurons = args.num_neurons
n_heads = args.num_heads
concat_comp = args.concat_comp

# DATA PARAMETERS
# Seed the Python, NumPy and torch generators with the same fixed value.
random_num = 1
random.seed(random_num)
np.random.seed(random_num)
torch.manual_seed(random_num)

# MODEL HYPER-PARAMETERS
num_epochs = args.num_epochs
learning_rate = args.lr
batch_size = args.batch_size

stop_patience = 150
best_epoch = 1
adj_epochs = 50
milestones = [150, 250]
train_param = {'batch_size': batch_size, 'shuffle': True}
valid_param = {'batch_size': batch_size, 'shuffle': False}

# DATALOADER/ TARGET NORMALIZATION
if args.data_src == 'binned_dos_128':
    pd_data = pd.read_csv(f'../Mat2Spec_DATA/label_edos/mpids.csv')
    np_data = np.load(f'../Mat2Spec_DATA/label_edos/total_dos_128.npy')
elif args.data_src == 'ph_dos_51':
    pd_data = pd.read_csv(f'../Mat2Spec_DATA/phdos/mpids.csv')
    np_data = np.load(f'../Mat2Spec_DATA/phdos/ph_dos.npy')
else:
    raise ValueError(f'Unsupported --data_src for training: {args.data_src!r}')

NORMALIZER = DATA_normalizer(np_data)

CRYSTAL_DATA = CIF_Dataset(args, pd_data=pd_data, np_data=np_data, root_dir=f'../Mat2Spec_DATA/', **RSM)

# train_idx_all is only produced when the split is generated in this script.
train_idx_all = None
if args.data_src == 'ph_dos_51':
    # NOTE(review): pickle.load can execute arbitrary code; only use a trusted split file.
    with open('../Mat2Spec_DATA/phdos/200801_trteva_indices.pkl', 'rb') as f:
        train_idx, val_idx, test_idx = pickle.load(f)
else:
    idx_list = list(range(len(pd_data)))
    random.shuffle(idx_list)
    train_idx_all, test_val = train_test_split(idx_list, train_size=args.train_size, random_state=random_num)
    test_idx, val_idx = train_test_split(test_val, test_size=0.5, random_state=random_num)

# Sub-sample the training indices only when a generated split exists; the
# pre-defined ph_dos_51 split is used unchanged (it has no train_idx_all).
if train_idx_all is not None:
    if args.trainset_subset_ratio < 1.0:
        train_idx, _ = train_test_split(train_idx_all, train_size=args.trainset_subset_ratio, random_state=random_num)
    else:
        train_idx = train_idx_all

if args.finetune:
    if args.finetune_dataset == 'null':
        raise ValueError('--finetune requires --finetune-dataset to be set.')
    if args.data_src != 'binned_dos_128':
        raise ValueError('Finetuning is only supported on the binned dos 128 dataset.')

    # Finetuning replaces all three index lists with the per-class ones on disk.
    finetune_dir = f'../Mat2Spec_DATA/label_edos/materials_classes/' + args.finetune_dataset
    with open(finetune_dir + '/train_idx.json', ) as f:
        train_idx = json.load(f)

    with open(finetune_dir + '/val_idx.json', ) as f:
        val_idx = json.load(f)

    with open(finetune_dir + '/test_idx.json', ) as f:
        test_idx = json.load(f)

print('training size:', len(train_idx))
print('validation size:', len(val_idx))
print('testing size:', len(test_idx))
print('total size:', len(train_idx) + len(val_idx) + len(test_idx))

training_set = CIF_Lister(train_idx, CRYSTAL_DATA, df=pd_data)
validation_set = CIF_Lister(val_idx, CRYSTAL_DATA, df=pd_data)

print(f'> USING MODEL Mat2Spec!')
the_network = Mat2Spec(args, NORMALIZER)
net = the_network.to(device)
129 |
if args.finetune:
    # load checkpoint
    if args.check_point_path is None:
        raise ValueError('--finetune requires --check-point-path to point at a pretrained model.')
    # map_location lets a checkpoint written on a GPU be restored on the current device.
    check_point = torch.load(args.check_point_path, map_location=device)
    net.load_state_dict(check_point['model'])
    # Finetuning starts from a learning rate five times smaller.
    learning_rate = learning_rate / 5

# LOSS & OPTMIZER & SCHEDULER
optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=1e-2)

decay_times = 4
decay_ratios = 0.5
one_epoch_iter = np.ceil(len(train_idx) / batch_size)

# The scheduler is stepped once per batch, so the step size is expressed in
# optimizer steps. StepLR only decays when the step counter is an exact
# multiple of step_size, so a fractional value could skip decays entirely;
# force a positive integer.
step_size = max(1, int(one_epoch_iter * num_epochs / decay_times))
scheduler = lr_scheduler.StepLR(optimizer, step_size, decay_ratios)

print(f'> TRAINING MODEL ...')
train_loader = torch_DataLoader(dataset=training_set, **train_param)
valid_loader = torch_DataLoader(dataset=validation_set, **valid_param)
152 |
import os


def _as_float(value):
    """Return a loss term as a plain float so no autograd history is kept."""
    return value.item() if torch.is_tensor(value) else float(value)


def _score_epoch(prediction, prediction_x, label_gt):
    """Compute mean absolute errors for one epoch of collected outputs.

    Args:
        prediction: model outputs ``pred_e`` concatenated over the epoch.
        prediction_x: model outputs ``pred_x`` concatenated over the epoch.
        label_gt: un-normalized labels concatenated over the epoch.

    Returns:
        ``(mae, mae_x, mae_ori, mae_x_ori)``: errors against the normalized
        labels, then errors after mapping the predictions back to the
        original label scale.

    Raises:
        ValueError: for an unhandled label scaling, or when
            ``normalized_sum`` is combined with a loss other than KL/WD.
    """
    if args.label_scaling == 'standardized':
        mean = NORMALIZER.mean.detach().numpy()
        std = NORMALIZER.std.detach().numpy()
        label_gt_standardized = (label_gt - mean) / std
        mae = np.mean(np.abs((prediction) - label_gt_standardized))
        mae_x = np.mean(np.abs((prediction_x) - label_gt_standardized))
        prediction = prediction * std + mean
        prediction_x = prediction_x * std + mean
        # Negative values after de-standardization are clamped to zero.
        prediction[prediction < 0] = 0
        prediction_x[prediction_x < 0] = 0
    elif args.label_scaling == 'normalized_max':
        label_max = np.expand_dims(np.max(label_gt, axis=1), axis=1)
        label_gt_standardized = label_gt / label_max
        mae = np.mean(np.abs((prediction) - label_gt_standardized))
        mae_x = np.mean(np.abs((prediction_x) - label_gt_standardized))
        prediction = prediction * label_max
        prediction_x = prediction_x * label_max
    elif args.label_scaling == 'normalized_sum':
        if args.Mat2Spec_loss_type not in ('KL', 'WD'):
            raise ValueError("label scaling 'normalized_sum' requires --Mat2Spec-loss-type KL or WD.")
        label_sum = np.sum(label_gt, axis=1, keepdims=True)
        label_gt_standardized = label_gt / label_sum
        mae = np.mean(np.abs((prediction) - label_gt_standardized))
        mae_x = np.mean(np.abs((prediction_x) - label_gt_standardized))
        prediction = prediction * label_sum
        prediction_x = prediction_x * label_sum
    else:
        raise ValueError(f'Unsupported label scaling: {args.label_scaling!r}')

    mae_ori = np.mean(np.abs((prediction) - label_gt))
    mae_x_ori = np.mean(np.abs((prediction_x) - label_gt))
    return mae, mae_x, mae_ori, mae_x_ori


best_valid_loss = 1e+10
current_step = 0

start_time = time.time()
for epoch in range(num_epochs):

    # TRAINING-STAGE
    net.train()
    args.train = True
    training_counter = 0
    total_loss_smooth = 0.0
    nll_loss_smooth = 0.0
    nll_loss_x_smooth = 0.0
    kl_loss_smooth = 0.0
    cpc_loss_smooth = 0.0
    prediction = []
    prediction_x = []
    label_gt = []
    for data in tqdm(train_loader, mininterval=0.5, desc=f'(EPOCH:{epoch} TRAINING)', position=0, leave=True, ascii=True):
        current_step += 1
        data = data.to(device)
        train_label = deepcopy(data.y).to(device)
        if args.label_scaling == 'standardized':
            train_label_normalize = (train_label - NORMALIZER.mean.to(device)) / NORMALIZER.std.to(device)
        elif args.label_scaling == 'normalized_max':
            train_label_normalize = train_label / (torch.max(train_label, dim=1)[0].unsqueeze(1))
        elif args.label_scaling == 'normalized_sum':
            # NOTE(review): validation adds 1e-8 to this denominator, training does not.
            train_label_normalize = train_label / torch.sum(train_label, dim=1, keepdim=True)

        predictions = net(data)
        total_loss, nll_loss, nll_loss_x, kl_loss, c_loss, pred_e, pred_x = \
            compute_loss(train_label_normalize, predictions, NORMALIZER, args)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        # The scheduler advances once per batch, not once per epoch.
        scheduler.step()

        prediction.append(pred_e.detach().cpu().numpy())
        prediction_x.append(pred_x.detach().cpu().numpy())
        label_gt.append(train_label.detach().cpu().numpy())

        # Accumulate plain floats: summing the loss tensors themselves would
        # keep a reference to every step's autograd history for the whole epoch.
        total_loss_smooth += _as_float(total_loss)
        nll_loss_smooth += _as_float(nll_loss)
        nll_loss_x_smooth += _as_float(nll_loss_x)
        kl_loss_smooth += _as_float(kl_loss)
        cpc_loss_smooth += _as_float(c_loss)
        training_counter += 1

    total_loss_smooth = total_loss_smooth / training_counter
    nll_loss_smooth = nll_loss_smooth / training_counter
    nll_loss_x_smooth = nll_loss_x_smooth / training_counter
    kl_loss_smooth = kl_loss_smooth / training_counter
    cpc_loss_smooth = cpc_loss_smooth / training_counter

    mae, mae_x, mae_ori, mae_x_ori = _score_epoch(
        np.concatenate(prediction, axis=0),
        np.concatenate(prediction_x, axis=0),
        np.concatenate(label_gt, axis=0))

    print("\n********** TRAINING STATISTIC ***********")
    print("total_loss =%.6f\t nll_loss =%.6f\t nll_loss_x =%.6f\t kl_loss =%.6f\t cpc_loss=%.6f\t" %
          (total_loss_smooth, nll_loss_smooth, nll_loss_x_smooth, kl_loss_smooth, cpc_loss_smooth))
    print("mae=%.6f\t mae_x=%.6f\t mae_ori=%.6f\t mae_x_ori=%.6f" % (mae, mae_x, mae_ori, mae_x_ori))
    print("\n*****************************************")

    # VALIDATION-PHASE
    net.eval()
    valid_counter = 0
    total_loss_smooth = 0.0
    nll_loss_smooth = 0.0
    nll_loss_x_smooth = 0.0
    kl_loss_smooth = 0.0
    cpc_loss_smooth = 0.0
    prediction = []
    prediction_x = []
    label_gt = []
    for data in tqdm(valid_loader, mininterval=0.5, desc='(validating)', position=0, leave=True, ascii=True):
        data = data.to(device)
        valid_label = deepcopy(data.y).float().to(device)

        if args.label_scaling == 'standardized':
            valid_label_normalize = (valid_label - NORMALIZER.mean.to(device)) / NORMALIZER.std.to(device)
        elif args.label_scaling == 'normalized_max':
            valid_label_normalize = valid_label / (torch.max(valid_label, dim=1)[0].unsqueeze(1))
        elif args.label_scaling == 'normalized_sum':
            valid_label_normalize = valid_label / (torch.sum(valid_label, dim=1, keepdim=True) + 1e-8)

        with torch.no_grad():
            predictions = net(data)
            total_loss, nll_loss, nll_loss_x, kl_loss, cpc_loss, pred_e, pred_x = \
                compute_loss(valid_label_normalize, predictions, NORMALIZER, args)

        prediction.append(pred_e.detach().cpu().numpy())
        prediction_x.append(pred_x.detach().cpu().numpy())
        label_gt.append(valid_label.detach().cpu().numpy())

        total_loss_smooth += _as_float(total_loss)
        nll_loss_smooth += _as_float(nll_loss)
        nll_loss_x_smooth += _as_float(nll_loss_x)
        kl_loss_smooth += _as_float(kl_loss)
        cpc_loss_smooth += _as_float(cpc_loss)
        valid_counter += 1

    total_loss_smooth = total_loss_smooth / valid_counter
    nll_loss_smooth = nll_loss_smooth / valid_counter
    nll_loss_x_smooth = nll_loss_x_smooth / valid_counter
    kl_loss_smooth = kl_loss_smooth / valid_counter
    cpc_loss_smooth = cpc_loss_smooth / valid_counter

    mae, mae_x, mae_ori, mae_x_ori = _score_epoch(
        np.concatenate(prediction, axis=0),
        np.concatenate(prediction_x, axis=0),
        np.concatenate(label_gt, axis=0))

    print("\n********** VALIDATING STATISTIC ***********")
    print("total_loss =%.6f\t nll_loss =%.6f\t nll_loss_x =%.6f\t kl_loss =%.6f\t cpc_loss = %.6f\t" %
          (total_loss_smooth, nll_loss_smooth, nll_loss_x_smooth, kl_loss_smooth, cpc_loss_smooth))
    print("mae=%.6f\t mae_x=%.6f\t mae_ori=%.6f\t mae_x_ori=%.6f" % (mae, mae_x, mae_ori, mae_x_ori))
    print("\n*****************************************")

    # Model selection uses the validation MAE of pred_x on the original label scale.
    if best_valid_loss > mae_x_ori:
        best_valid_loss = mae_x_ori
        print("\n********** SAVING MODEL ***********")
        checkpoint = {'model': net.state_dict(), 'args': args}
        if not args.finetune:
            save_path = './TRAINED/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling + '_' \
                        + args.Mat2Spec_loss_type + '_trainsize' + str(args.trainset_subset_ratio)
        else:
            save_path = './TRAINED/finetune/model_Mat2Spec_' + args.data_src + '_' + args.label_scaling + '_' \
                        + args.Mat2Spec_loss_type + '_finetune_' + str(args.finetune_dataset)

        if args.ablation_LE:
            save_path = save_path + '_ablation_LE'

        if args.ablation_CL:
            save_path = save_path + '_ablation_CL'

        save_path = save_path + '.chkpt'
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(checkpoint, save_path)
        print("A new model has been saved to " + save_path)
        print("\n*****************************************")

    gc.collect()

end_time = time.time()
e_time = end_time - start_time
print('Best validation loss=%.6f, training time (min)=%.6f' % (best_valid_loss, e_time / 60))
print(f"> DONE TRAINING !")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Density of States Prediction for Materials Discovery via Contrastive Learning from Probabilistic Embeddings
2 |
3 | Authors: Shufeng Kong 1, Francesco Ricci 2,4, Dan Guevarra 3, Jeffrey B. Neaton 2,5,6, Carla P. Gomes 1, and John M. Gregoire 3
4 | 1) Department of Computer Science, Cornell University, Ithaca, NY, USA
5 | 2) Material Science Division, Lawrence Berkeley National Laboratory, Berkeley, CA, USA
6 | 3) Division of Engineering and Applied Science, California Institute of Technology, Pasadena, CA, USA
7 | 4) Chemical Science Division, Lawrence Berkeley National Laboratory, Berkeley, CA, USA
8 | 5) Department of Physics, University of California, Berkeley, Berkeley, CA, USA
9 | 6) Kavli Energy NanoSciences Institute at Berkeley, Berkeley, CA, USA
10 |
11 | This is a PyTorch implementation of the machine learning model "Mat2Spec" presented in this paper (https://www.nature.com/articles/s41467-022-28543-x).
12 | Please send any questions or suggestions about the code directly to sk2299@cornell.edu
13 |
14 | ### Installation
15 | Install the following packages if they are not already installed (installing all of them may take about 30 minutes on a typical machine):
16 | * Python (tested on 3.8.11)
17 | * PyTorch (tested on 1.4.0)
18 | * CUDA (tested on 10.0)
19 | * Pandas (tested on 1.3.3)
20 | * Pymatgen (tested on 2022.0.14)
21 | * PyTorch-Geometric (tested on 1.5.0)
22 |
23 | Please follow these steps to create an environment:
24 |
25 | 1) Download packages - example:
26 | https://download.pytorch.org/whl/cu100/torch-1.4.0%2Bcu100-cp38-cp38-linux_x86_64.whl
27 | https://download.pytorch.org/whl/cu100/torchvision-0.5.0%2Bcu100-cp38-cp38-linux_x86_64.whl
28 | https://data.pyg.org/whl/torch-1.4.0/torch_cluster-1.5.4%2Bcu100-cp38-cp38-linux_x86_64.whl
29 | https://data.pyg.org/whl/torch-1.4.0/torch_scatter-2.0.4%2Bcu100-cp38-cp38-linux_x86_64.whl
30 | https://data.pyg.org/whl/torch-1.4.0/torch_sparse-0.6.1%2Bcu100-cp38-cp38-linux_x86_64.whl
31 | https://data.pyg.org/whl/torch-1.4.0/torch_spline_conv-1.2.0%2Bcu100-cp38-cp38-linux_x86_64.whl
32 |
33 | 2) Install packages - example
34 |
35 | ```bash
36 | conda create --name mat2spec python=3.8
37 | conda activate mat2spec
38 | pip install torch-1.4.0+cu100-cp38-cp38-linux_x86_64.whl
39 | pip install torchvision-0.5.0+cu100-cp38-cp38-linux_x86_64.whl
40 | pip install torch_cluster-1.5.4+cu100-cp38-cp38-linux_x86_64.whl
41 | pip install torch_scatter-2.0.4+cu100-cp38-cp38-linux_x86_64.whl
42 | pip install torch_sparse-0.6.1+cu100-cp38-cp38-linux_x86_64.whl
43 | pip install torch_spline_conv-1.2.0+cu100-cp38-cp38-linux_x86_64.whl
44 | pip install torch-geometric==1.5.0
45 | pip install pandas
46 | pip install pymatgen
47 | ```
48 |
49 | When you finish using our model, you can deactivate the environment:
50 | ```bash
51 | conda deactivate
52 | ```
53 |
54 | Remember to activate the environment before using our model next time:
55 | ```bash
56 | conda activate mat2spec
57 | ```
58 |
59 | ### Datasets
60 |
61 | 1) Phonon density of states: see our data repository link below, or data can be downloaded from here https://github.com/zhantaochen/phonondos_e3nn.
62 | 2) Electronic density of states: see our data repository link below, or data can be downloaded from the Materials Project.
63 | 3) Initial element embeddings: please refer to "Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties" by Tian Xie and Jeffrey C. Grossman.
64 |
65 | These initial element embeddings include the embeddings of the following elements: 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr'.
66 |
67 | Datasets for this work are available at https://data.caltech.edu/records/8975
68 |
69 | Please download the data folder and unzip it under the main folder 'Mat2Spec'.
70 |
71 | ### Example Usage
72 |
73 | Model training typically takes 20 min for phDOS and 3 hours for eDOS on a GPU.
74 |
75 | To train the model on phDOS with maxnorm and MSE:
76 | ```bash
77 | bash SCRIPTS/train_phdos51_norm_max_mse.sh
78 | ```
79 | Note that the bash scripts manually assign the CUDA device index via the environment variable CUDA_VISIBLE_DEVICES. Adjust it to the correct index (usually '0' for single-GPU systems) before training, or else PyTorch will only use the CPU.
80 |
81 | To train the model on eDOS with std and MAE:
82 | ```bash
83 | bash SCRIPTS/train_dos128_std_mae.sh
84 | ```
85 |
86 | To train the model on eDOS with norm sum and KL:
87 | ```bash
88 | bash SCRIPTS/train_dos128_norm_sum_kl.sh
89 | ```
90 |
91 | To test the trained models:
92 | ```bash
93 | bash SCRIPTS/test_phdos51_norm_max_mse.sh
94 | bash SCRIPTS/test_dos128_std_mae.sh
95 | bash SCRIPTS/test_dos128_norm_sum_kl.sh
96 | ```
97 |
98 | To use the trained models to predict eDOS for materials without labels:
99 |
100 | 1) Place your json files under ./Mat2Spec_DATA/materials_without_dos/
101 | Each json file should include a key 'structure' which maps to a material in the pymatgen format.
102 |
103 | 2) Place a csv file named 'mpids.csv', listing the names of all your json files, under ./DATA/20210623_no_label
104 |
105 | 3) If you want to use trained models with std and MAE:
106 |
107 | ```bash
108 | bash SCRIPTS/test_nolabel128_std_mae.sh
109 | ```
110 |
111 | 4) If you want to use trained models with norm sum and KL:
112 |
113 | ```bash
114 | bash SCRIPTS/test_nolabel128_std_mae.sh
115 | bash SCRIPTS/test_nolabel128_norm_sum_kl.sh
116 | ```
117 |
118 | Then rescale the KL prediction with the std prediction:
119 | ```python
120 | x_sd = np.load('prediction_Mat2Spec_no_label_128_standardized_MAE_trainsize1.0.npy')
121 | x_kl = np.load('prediction_Mat2Spec_no_label_128_normalized_sum_KL_trainsize1.0.npy')
122 | x = x_kl*np.sum(x_sd, axis=-1, keepdims=True)
123 | ```
124 |
125 |
126 | All test results (model-predicted DOS) are placed under ./RESULT
127 |
128 |
129 | ### Disclaimer
130 | This is research code shared without support or any guarantee on its quality. However, please do raise an issue or submit a pull request if you spot something wrong or that could be improved and I will try my best to solve it.
131 |
132 | ### Acknowledgements
133 | Implementation of the GNN is modified from GATGNN: https://github.com/superlouis/GATGNN.
134 |
--------------------------------------------------------------------------------