├── .gitignore
├── CGAT
│   ├── CGAT.py
│   ├── Hypernetworksmp.py
│   ├── __init__.py
│   ├── add_volume_target.py
│   ├── data.py
│   ├── gaussian_process.py
│   ├── lambs.py
│   ├── lightning_module.py
│   ├── message_changed.py
│   ├── predict.py
│   ├── prepare_data.py
│   ├── roost_message.py
│   ├── test.py
│   ├── test_prepare_data.py
│   ├── train.py
│   └── utils.py
├── LICENSE
├── README.md
├── Utilities
│   ├── adjust_data.py
│   ├── calculate_embeddings.py
│   ├── calculate_errors.py
│   ├── element_correlation.py
│   ├── errors_of_additional_data.py
│   ├── filter_embeddings.py
│   ├── get_additional_data.py
│   ├── get_highest_errors.py
│   ├── gp_predict.py
│   ├── metropolis.py
│   ├── prediction.py
│   ├── prepare.sh
│   ├── prepare_active_learning.py
│   ├── sample.py
│   ├── train.sh
│   └── tsne.py
├── embeddings
│   └── matscholar-embedding.json
├── prepare_data.py
├── pyproject.toml
├── requirements.txt
├── runs
│   └── plot.py
├── setup.cfg
├── setup.py
├── test.py
└── training_scripts
    ├── train.sh
    ├── transfer_full.sh
    └── transfer_only_residual.sh

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# PyCharm
.idea/

# Data and tensorboard logs
original_data/
tb_logs/
*.pickle.gz

--------------------------------------------------------------------------------
/CGAT/CGAT.py:
--------------------------------------------------------------------------------
import itertools

from CGAT.roost_message import Roost
import torch.nn.functional as F
import torch
from torch_scatter import scatter_max, scatter_add
from torch_geometric.nn import MessagePassing
from CGAT.message_changed import SimpleNetwork, ResidualNetwork
from torch_geometric.utils import softmax
import torch.nn as nn
from CGAT.Hypernetworksmp import H_Net, H_Net_0


class MHAttention(nn.Module):
    """
    Multihead attention that uses fully connected networks to combine the global
    features of the composition with the node representations.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            heads=1,
            vector_attention=False):
        """
        Inputs
        ----------
        in_channels (int): Size of node embeddings and composition embeddings
        out_channels (int): Size of output node embedding
        heads (int): Number of attention heads
        vector_attention (bool): If set to True, vectorized attention coefficients are used
        """
        super(MHAttention, self).__init__()
        self.heads = heads
        self.out_channels = out_channels
        if vector_attention:
            self.MH_A = MultiHeadNetwork(
                2 * in_channels,
                out_channels,
                in_channels,
                heads,
                view=False)
        else:
            self.MH_A = MultiHeadNetwork(
                2 * in_channels, 1, in_channels, heads, view=False)
        self.MH_M = MultiHeadNetwork(
            in_channels, out_channels, in_channels, heads)

    def forward(self, fea, cry_fea, index, size=None):
        """ forward pass """
        size = index[-1].item() + 1 if size is None else size
        m = self.MH_M(fea)
        # concatenate atomic and global features for the corresponding system
        fea = torch.stack([fea, cry_fea[index]])
        # switch axes to get the correct reshaping in the MultiHeadNetworks
        fea = fea.transpose(1, 0)
        alpha = self.MH_A(fea)
        alpha = softmax(alpha, index, None, size)
        out = scatter_add((alpha * m).view(-1, self.heads * self.out_channels),
                          index, dim=0, dim_size=size)
        return out


class MultiHeadNetwork(nn.Module):
    """
    nb_heads parallel feed-forward networks to be used in MHAttention
    """

    def __init__(
            self,
            input_dim,
            output_dim,
            hidden_layer_dim,
            nb_heads,
            view=True):
        """
        Inputs
        ----------
        input_dim (int): Input size
        output_dim (int): Output size
        hidden_layer_dim (int): Hidden layer dimension
        nb_heads (int): Number of attention heads/fully connected networks
        view (bool): Set to False if the fea tensor is not contiguous in memory
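
        A minimal shape sketch (values are illustrative, not from the repo)::

            import torch
            net = MultiHeadNetwork(input_dim=64, output_dim=32,
                                   hidden_layer_dim=64, nb_heads=4)
            fea = torch.randn(10, 64)   # 10 nodes with 64 features each
            out = net(fea)              # -> (10, 4, 32), one slice per head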
85 | """ 86 | super(MultiHeadNetwork, self).__init__() 87 | 88 | self.input_dim = input_dim 89 | self.nb_heads = nb_heads 90 | self.output_dim = output_dim 91 | self.fc_in = nn.Conv1d(in_channels=input_dim * nb_heads, 92 | out_channels= hidden_layer_dim * nb_heads, 93 | kernel_size=1, 94 | groups=nb_heads) 95 | self.acts = nn.LeakyReLU() 96 | self.fc_out = nn.Conv1d( 97 | in_channels=hidden_layer_dim * nb_heads, 98 | out_channels=output_dim * nb_heads, 99 | kernel_size=1, 100 | groups=nb_heads) 101 | self.view = view 102 | 103 | def forward(self, fea): 104 | if self.view: 105 | fea = self.acts(self.fc_in(fea.view(-1, self.input_dim, 1).repeat(1, self.nb_heads, 1))) 106 | else: 107 | fea = self.acts(self.fc_in(fea.reshape(-1, self.input_dim, 1).repeat(1, self.nb_heads, 1))) 108 | 109 | return self.fc_out(fea).view(-1, self.nb_heads, self.output_dim) 110 | 111 | def __repr__(self): 112 | return '{}'.format(self.__class__.__name__) 113 | 114 | 115 | class GATConvEdges(nn.Module): 116 | """ graph attentional operator for edges combines node and edge information 117 | and updates the edge embedding through a multihead attention mechanism 118 | Args: 119 | in_channels (int): Size of node embedding. 120 | out_channels (int): Size of output embedding. 121 | nbr_channels (int): Size of edge embeddings 122 | heads (int, optional): Number of multi-head-attentions. 123 | (default: :obj:`1`) 124 | concat (bool, optional): If set to :obj:`False`, the multi-head 125 | attentions are averaged instead of concatenated. 126 | (default: :obj:`True`) 127 | negative_slope (float, optional): LeakyReLU angle of the negative 128 | slope. (default: :obj:`0.2`) 129 | dropout (float, optional): Dropout probability of the normalized 130 | attention coefficients which exposes each node to a stochastically 131 | sampled neighborhood during training. (default: :obj:`0`) 132 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 133 | an additive bias. (default: :obj:`True`) 134 | vector_attention (bool, optional): If set to :obj:`False`, the attention coefficients will be scalar else they 135 | will be vectors (default: :obj:`False`) 136 | first (bool, optional): Ignore (default: :obj:`True`) 137 | no_hyper (bool, optional): If set to False will use hypernetworks for pooling_NN (default :obj'False') 138 | **kwargs (optional): Additional arguments of 139 | :class:`torch_geometric.nn.conv.MessagePassing`. 
140 | """ 141 | 142 | def __init__( 143 | self, 144 | in_channels, 145 | out_channels, 146 | nbr_channels, 147 | heads=1, 148 | concat=True, 149 | negative_slope=0.2, 150 | dropout=0, 151 | bias=True, 152 | vector_attention=False, 153 | first=False, 154 | no_hyper=True, 155 | **kwargs): 156 | super(GATConvEdges, self).__init__(**kwargs) 157 | 158 | self.in_channels = in_channels 159 | self.out_channels = out_channels 160 | self.nbr_channels = nbr_channels 161 | self.heads = heads 162 | self.concat = concat 163 | self.negative_slope = negative_slope 164 | self.dropout = dropout 165 | self.vector_attention = vector_attention 166 | 167 | if(vector_attention): 168 | self.MH_A = MultiHeadNetwork(2 * in_channels + nbr_channels, 169 | out_channels, 170 | int((2 * in_channels + nbr_channels) / 1.5), 171 | heads) 172 | else: 173 | self.MH_A = MultiHeadNetwork(2 * in_channels + nbr_channels, 174 | 1, 175 | int((2 * in_channels + nbr_channels) / 1.5), 176 | heads) 177 | 178 | self.MH_M = MultiHeadNetwork(2 * in_channels + nbr_channels, 179 | out_channels, 180 | int((2 * in_channels + nbr_channels) / 1.5), 181 | heads) 182 | 183 | if no_hyper: 184 | self.Pooling_NN = SimpleNetwork(out_channels, 185 | out_channels, 186 | [out_channels]) 187 | elif first: 188 | self.Pooling_NN = H_Net_0( 189 | out_channels, 190 | 3, 191 | out_channels, 192 | out_channels, 193 | 2, 194 | out_channels, 195 | out_channels) 196 | else: 197 | self.Pooling_NN = H_Net( 198 | out_channels, 199 | 3, 200 | out_channels, 201 | out_channels, 202 | 2, 203 | out_channels, 204 | out_channels) 205 | self.first = first 206 | self.no_hyper = no_hyper 207 | 208 | def forward(self, x, edge_index, edge_attr, x_0, size=None): 209 | x_i = x[edge_index[0]] 210 | x_j = x[edge_index[1]] 211 | m = torch.cat([x_i, edge_attr, x_j], dim=-1) 212 | alpha = self.MH_A(m) 213 | m = self.MH_M(m) 214 | alpha = alpha.exp() 215 | 216 | if(not self.vector_attention): 217 | alpha = alpha / alpha.sum(dim=1).view(-1, 1, 1) 218 | else: 219 | alpha = alpha / alpha.sum(dim=1).view(-1, 1, self.out_channels) 220 | 221 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 222 | aggr_out = m.view(-1, self.heads, self.out_channels) * alpha 223 | aggr_out = aggr_out.mean(dim=1) 224 | if self.no_hyper: 225 | aggr_out = self.Pooling_NN(edge_attr) 226 | elif self.first: 227 | aggr_out = self.Pooling_NN(edge_attr, aggr_out) 228 | else: 229 | aggr_out = self.Pooling_NN(x_0, edge_attr, aggr_out) 230 | return aggr_out 231 | 232 | 233 | class GATConvNodes(MessagePassing): 234 | """Graph attentional operator adapted from `"Graph Attention Networks" 235 | `_ paper 236 | Args: 237 | in_channels (int): Size of node embedding. 238 | out_channels (int): Size of output embedding. 239 | nbr_channels (int): Size of edge embeddings 240 | heads (int, optional): Number of multi-head-attentions. 241 | (default: :obj:`1`) 242 | concat (bool, optional): If set to :obj:`False`, the multi-head 243 | attentions are averaged instead of concatenated. 244 | (default: :obj:`True`) 245 | negative_slope (float, optional): LeakyReLU angle of the negative 246 | slope. (default: :obj:`0.2`) 247 | dropout (float, optional): Dropout probability of the normalized 248 | attention coefficients which exposes each node to a stochastically 249 | sampled neighborhood during training. (default: :obj:`0`) 250 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 251 | an additive bias. 
            (default: :obj:`True`)
        final (bool, optional): Should be set to :obj:`True` for the last message passing layer,
            in which case no pooling network is applied to the aggregated messages
            (default: :obj:`False`)
        vector_attention (bool, optional): If set to :obj:`False`, the attention coefficients will be scalar, else
            they will be vectors (default: :obj:`False`)
        first (bool, optional): Should be set to :obj:`True` for the first message passing layer; this
            selects H_Net_0 instead of H_Net for pooling (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            nbr_channels,
            heads=1,
            concat=False,
            negative_slope=0.2,
            dropout=0,
            bias=True,
            final=False,
            vector_attention=False,
            first=False,
            **kwargs):
        super(GATConvNodes, self).__init__(aggr='add', **kwargs)
        self.node_dim = 0  # propagation axis
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nbr_channels = nbr_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.final = final
        self.first = first
        if vector_attention:
            self.MH_A = MultiHeadNetwork(2 * in_channels + nbr_channels,
                                         out_channels,
                                         int((2 * in_channels + nbr_channels) / 1.5),
                                         heads)
        else:
            self.MH_A = MultiHeadNetwork(2 * in_channels + nbr_channels,
                                         1,
                                         int((2 * in_channels + nbr_channels) / 1.5),
                                         heads)
        self.MH_M = MultiHeadNetwork(2 * in_channels + nbr_channels,
                                     out_channels,
                                     int((2 * in_channels + nbr_channels) / 1.5),
                                     heads)
        if not final and first:
            self.Pooling_NN = H_Net_0(out_channels, 3, out_channels, out_channels,
                                      2, out_channels, out_channels)
        elif not final:
            self.Pooling_NN = H_Net(out_channels, 3, out_channels, out_channels,
                                    2, out_channels, out_channels)

    def forward(self, x, edge_index, edge_attr, x_0, size=None):
        if torch.is_tensor(x):
            pass
        else:
            x = (None if x[0] is None else x[0],
                 None if x[1] is None else x[1])
        return self.propagate(
            edge_index,
            x=x,
            edge_attr=edge_attr,
            x_0=x_0)

    def message(self, x_i, x_j, edge_attr, edge_index_i):
        m = torch.cat([x_i, edge_attr, x_j], dim=-1)
        alpha = self.MH_A(m)
        m = self.MH_M(m)
        alpha = softmax(alpha, edge_index_i)  # , size_i)
        # Sample attention coefficients stochastically.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return m * alpha

    def update(self, aggr_out, x_0, x):
        aggr_out = aggr_out.mean(dim=1)
        if not self.final and self.first:
            return self.Pooling_NN(x, aggr_out)
        elif not self.final:
            return self.Pooling_NN(x_0, x, aggr_out)
        else:
            return aggr_out

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)


class CGAtNet(nn.Module):
    """
    Create a neural network for predicting total material properties.

    The CGAtNet comprises a fully connected output network,
    message passing graph layers used on a crystal graph, and a composition based
    roost model (see https://github.com/CompRhys/roost and
    https://doi.org/10.1038/s41467-020-19964-7 for the roost model) that is used
    in the final pooling step.

    The message passing layers are used to determine an embedding for the whole
    material that is used in the fully connected network. Critically, the graphs
    are used to represent (crystalline) materials in a volume-agnostic manner.
    The model is also agnostic to changes in the structure that do not cause a
    change in the neighbor list.
    """

    def __init__(
            self,
            orig_elem_fea_len,
            elem_fea_len,
            n_graph,
            nbr_embedding_size=128,
            neighbor_number=12,
            mean_pooling=True,
            rezero=False,
            msg_heads=3,
            update_edges=False,
            vector_attention=False,
            global_vector_attention=False,
            n_graph_roost=3,
            no_hyper=True):
        """
        Args:
            orig_elem_fea_len (int): Size of the pretrained species embedding.
            elem_fea_len (int): Size of the species embedding used during message passing
            n_graph (int): Number of message passing steps in the CGAT model
            nbr_embedding_size (int): Size of edge embeddings
            neighbor_number (int): Number of neighbors considered during each message passing step
            mean_pooling (bool): If set to False the material embeddings returned by the pooling layer
                following the message passing will be concatenated instead of averaged
                (default: :obj:`True`)
            rezero (bool, optional): If set to :obj:`True`, the output network uses ReZero-style
                residual weighting (default: :obj:`False`)
            msg_heads (int, optional): Number of multi-head-attentions.
                (default: :obj:`3`)
            update_edges (bool): If set to True edge embeddings will be updated (default: :obj:`False`)
            vector_attention (bool, optional): If set to :obj:`False`, the attention coefficients during the
                message passing phase will be scalar, else they will be vectors (default: :obj:`False`)
            global_vector_attention (bool, optional): If set to :obj:`False`, the attention coefficients of the
                pooling layer after the message passing phase will be scalar, else they will be vectors
                (default: :obj:`False`)
            n_graph_roost (int): Number of message passing steps in the roost model
            no_hyper (bool, optional): If set to :obj:`False`, hypernetworks will be used in the message
                passing of the edges (default: :obj:`True`)
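
        Example (sizes are illustrative; the matscholar embedding used in this
        repo is 200-dimensional)::

            model = CGAtNet(orig_elem_fea_len=200, elem_fea_len=128, n_graph=3)
            # `batch` is a torch_geometric Batch built from CompositionData items,
            # `roost` is the matching tuple of composition inputs (see forward)
            out = model(batch, roost)   # -> (C, 2): prediction and log-std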
393 | """ 394 | super(CGAtNet, self).__init__() 395 | self.mean_pooling = mean_pooling 396 | self.update_edges = update_edges 397 | # apply linear transform to the input to get a trainable embedding 398 | self.embedding = nn.Linear(orig_elem_fea_len, elem_fea_len, bias=False) 399 | self.nbr_embedding = nn.Embedding( 400 | num_embeddings=neighbor_number + 1, 401 | embedding_dim=nbr_embedding_size) 402 | self.no_hyper = no_hyper 403 | # create a list of Message passing layers 404 | msg_heads = msg_heads 405 | if not self.update_edges: 406 | self.graphs = nn.ModuleList( 407 | [ 408 | GATConvNodes(elem_fea_len, 409 | nbr_embedding_size, 410 | msg_heads, 411 | concat=True, 412 | vector_attention=vector_attention, 413 | first=True) 414 | ]) 415 | self.graphs.extend( 416 | [ 417 | GATConvNodes(elem_fea_len, 418 | nbr_embedding_size, 419 | msg_heads, 420 | concat=True, 421 | vector_attention=vector_attention) 422 | for i in range( 423 | n_graph - 1) 424 | ] 425 | ) 426 | 427 | elif no_hyper: #no hyper networks for edge updates 428 | self.graphs = nn.ModuleList( 429 | [ 430 | nn.ModuleDict( 431 | { 432 | 'Node': GATConvNodes(elem_fea_len, 433 | elem_fea_len, 434 | nbr_embedding_size, 435 | msg_heads, 436 | concat=True, 437 | vector_attention=vector_attention, 438 | first=True), 439 | 'Edge': GATConvEdges(elem_fea_len, 440 | nbr_embedding_size, 441 | nbr_embedding_size, 442 | msg_heads, 443 | concat=True, 444 | vector_attention=vector_attention, 445 | first=True) 446 | } 447 | ) 448 | ] 449 | ) 450 | self.graphs.extend( 451 | [ 452 | nn.ModuleDict( 453 | { 454 | 'Node': GATConvNodes(elem_fea_len, 455 | elem_fea_len, 456 | nbr_embedding_size, 457 | msg_heads, 458 | concat=True, 459 | vector_attention=vector_attention), 460 | 461 | 'Edge': GATConvEdges(elem_fea_len, 462 | nbr_embedding_size, 463 | nbr_embedding_size, 464 | msg_heads, 465 | concat=True, 466 | vector_attention=vector_attention) 467 | } 468 | ) for i in range(n_graph - 1)]) 469 | else: 470 | self.graphs = nn.ModuleList( 471 | [ 472 | nn.ModuleDict( 473 | { 474 | 'Node': GATConvNodes(elem_fea_len, 475 | elem_fea_len, 476 | nbr_embedding_size, 477 | msg_heads, 478 | concat=True, 479 | vector_attention=vector_attention, 480 | first=True), 481 | 482 | 'Edge': GATConvEdges(elem_fea_len, 483 | nbr_embedding_size, 484 | nbr_embedding_size, 485 | msg_heads, 486 | concat=True, 487 | vector_attention=vector_attention, 488 | first=True, 489 | no_hyper=False)})]) 490 | self.graphs.extend( 491 | [ 492 | nn.ModuleDict( 493 | { 494 | 'Node': GATConvNodes(elem_fea_len, 495 | elem_fea_len, 496 | nbr_embedding_size, 497 | msg_heads, 498 | concat=True, 499 | vector_attention=vector_attention), 500 | 501 | 'Edge': GATConvEdges(elem_fea_len, 502 | nbr_embedding_size, 503 | nbr_embedding_size, 504 | msg_heads, 505 | concat=True, 506 | vector_attention=vector_attention, 507 | no_hyper=False) 508 | } 509 | ) 510 | for i in range(n_graph - 1)] 511 | ) 512 | 513 | # Add a roost model for a composition based pooling 514 | self.roost = Roost(orig_elem_fea_len, elem_fea_len, n_graph_roost) 515 | 516 | # define a global pooling function for materials 517 | mat_heads = msg_heads 518 | self.cry_pool = MHAttention(in_channels=elem_fea_len, 519 | out_channels=elem_fea_len, 520 | heads=mat_heads, 521 | vector_attention=global_vector_attention) 522 | 523 | # define an output neural network 524 | self.msg_heads = msg_heads 525 | self.elem_fea_len = elem_fea_len 526 | out_hidden = [1024, 1024, 512, 512, 256, 256, 128] 527 | 528 | if mean_pooling: 529 | self.output_nn = 
            self.output_nn = ResidualNetwork(elem_fea_len,
                                             2,
                                             out_hidden,
                                             if_rezero=rezero)
        else:
            self.output_nn = ResidualNetwork(elem_fea_len * msg_heads,
                                             2,
                                             out_hidden,
                                             if_rezero=rezero)

    def forward(self, batch, roost, *, last_layer=True, return_graph_embedding=False):
        """
        Forward pass

        Parameters
        ----------
        batch:
            Batch of torch_geometric graph objects
        roost:
            Tuple of Roost inputs consisting of
            orig_elem_fea: Variable(torch.Tensor) shape (N, orig_elem_fea_len)
                Atom features of each of the N elems in the batch
            self_fea_idx: torch.Tensor shape (M,)
                Indices of the elem each of the M bonds corresponds to
            nbr_fea_idx: torch.Tensor shape (M,)
                Indices of the neighbors each of the M bonds connects to
            elem_bond_idx: list of torch.LongTensor of length C
                Mapping from the bond idx to elem idx
            crystal_elem_idx: list of torch.LongTensor of length C
                Mapping from the elem idx to crystal idx

        Returns
        -------
        crys_fea: torch.Tensor
            Output of the network for each of the C crystals in the batch (or the
            crystal embedding if return_graph_embedding is set)
        """
        edge_index = batch.edge_index
        crystal_elem_idx = batch.batch
        size = (batch.num_nodes, batch.num_nodes)
        edge_attr = self.nbr_embedding(batch.edge_attr)
        elem_fea = self.embedding(batch.x)
        elem_fea_0 = elem_fea.clone()

        if self.update_edges:
            edge_attr_0 = edge_attr.clone()
        if not self.update_edges:
            for graph_func in self.graphs:
                elem_fea = elem_fea + \
                    graph_func(elem_fea, edge_index, edge_attr, elem_fea_0, size)
        else:
            for graph_func in self.graphs:
                node_update = graph_func['Node'](
                    elem_fea, edge_index, edge_attr, elem_fea_0, size)
                edge_attr = edge_attr + \
                    graph_func['Edge'](elem_fea, edge_index, edge_attr, edge_attr_0, size)
                elem_fea = elem_fea + node_update

        crys_fea = self.roost(*roost)
        crys_fea = self.cry_pool(elem_fea, crys_fea, crystal_elem_idx)

        if self.mean_pooling:
            crys_fea = crys_fea.view(-1, self.msg_heads, self.elem_fea_len)
            crys_fea = torch.mean(crys_fea, dim=1)
            if return_graph_embedding:
                return crys_fea
            crys_fea = self.output_nn(crys_fea, last_layer=last_layer)
        else:
            if return_graph_embedding:
                return crys_fea
            crys_fea = self.output_nn(crys_fea, last_layer=last_layer)
        return crys_fea

    def __repr__(self):
        return '{}'.format(self.__class__.__name__)

    def get_output_parameters(self):
        return self.output_nn.parameters()

    def get_hidden_parameters(self):
        return itertools.chain(self.embedding.parameters(),
                               self.nbr_embedding.parameters(),
                               self.graphs.parameters(),
                               self.roost.parameters(),
                               self.cry_pool.parameters())

--------------------------------------------------------------------------------
/CGAT/Hypernetworksmp.py:
--------------------------------------------------------------------------------
# Uses code from https://github.com/vsitzmann/scene-representation-networks
'''Pytorch implementations of hyper-network modules.'''
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision.utils

import numpy as np

import math
import numbers

import functools


def partialclass(cls, *args, **kwds):

    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args,
**kwds) 20 | 21 | return NewCls 22 | 23 | 24 | class FCLayer(nn.Module): 25 | def __init__(self, in_features, out_features): 26 | super().__init__() 27 | self.net = nn.Sequential( 28 | nn.Linear(in_features, out_features), 29 | nn.Tanh(), 30 | ) 31 | 32 | def forward(self, input): 33 | return self.net(input) 34 | 35 | 36 | class FCBlock(nn.Module): 37 | def __init__(self, 38 | hidden_ch, 39 | num_hidden_layers, 40 | in_features, 41 | out_features, 42 | outermost_linear=False): 43 | super().__init__() 44 | 45 | self.net = [] 46 | self.net.append( 47 | FCLayer( 48 | in_features=in_features, 49 | out_features=hidden_ch)) 50 | 51 | for i in range(num_hidden_layers): 52 | self.net.append( 53 | FCLayer( 54 | in_features=hidden_ch, 55 | out_features=hidden_ch)) 56 | 57 | if outermost_linear: 58 | self.net.append( 59 | nn.Linear( 60 | in_features=hidden_ch, 61 | out_features=out_features)) 62 | else: 63 | self.net.append( 64 | FCLayer( 65 | in_features=hidden_ch, 66 | out_features=out_features)) 67 | 68 | self.net = nn.Sequential(*self.net) 69 | self.net.apply(self.init_weights) 70 | 71 | def __getitem__(self, item): 72 | return self.net[item] 73 | 74 | def init_weights(self, m): 75 | if isinstance(m, nn.Linear): 76 | nn.init.kaiming_normal_( 77 | m.weight, 78 | a=0.0, 79 | nonlinearity='leaky_relu', 80 | mode='fan_in') 81 | 82 | def forward(self, input): 83 | return self.net(input) 84 | 85 | 86 | class HyperLayer(nn.Module): 87 | '''A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU.''' 88 | 89 | def __init__(self, 90 | in_ch, 91 | out_ch, 92 | hyper_in_ch, 93 | hyper_num_hidden_layers, 94 | hyper_hidden_ch): 95 | super().__init__() 96 | 97 | self.hyper_linear = HyperLinear( 98 | in_ch=in_ch, 99 | out_ch=out_ch, 100 | hyper_in_ch=hyper_in_ch, 101 | hyper_num_hidden_layers=hyper_num_hidden_layers, 102 | hyper_hidden_ch=hyper_hidden_ch) 103 | self.norm_nl = nn.Sequential( 104 | nn.LayerNorm([out_ch], elementwise_affine=False), 105 | # nn.ReLU(inplace=True) 106 | nn.Tanh() 107 | ) 108 | 109 | def forward(self, hyper_input): 110 | ''' 111 | :param hyper_input: input to hypernetwork. 112 | :return: nn.Module; predicted fully connected network. 113 | ''' 114 | return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl) 115 | 116 | 117 | class HyperFC(nn.Module): 118 | '''Builds a hypernetwork that predicts a fully connected neural network. 
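
    A minimal shape sketch (sizes are illustrative)::

        import torch
        hyper = HyperFC(hyper_in_ch=64, hyper_num_hidden_layers=2,
                        hyper_hidden_ch=64, hidden_ch=32,
                        num_hidden_layers=1, in_ch=16, out_ch=8)
        net = hyper(torch.randn(5, 64))   # one predicted network per batch row
        y = net(torch.randn(5, 1, 16))    # -> (5, 1, 8)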
119 | ''' 120 | 121 | def __init__(self, 122 | hyper_in_ch, 123 | hyper_num_hidden_layers, 124 | hyper_hidden_ch, 125 | hidden_ch, 126 | num_hidden_layers, 127 | in_ch, 128 | out_ch, 129 | outermost_linear=False): 130 | super().__init__() 131 | 132 | # PreconfHyperLinear = partialclass(HyperLinear, 133 | # hyper_in_ch=hyper_in_ch, 134 | # hyper_num_hidden_layers=hyper_num_hidden_layers, 135 | # hyper_hidden_ch=hyper_hidden_ch) 136 | # PreconfHyperLayer = partialclass(HyperLayer, 137 | # hyper_in_ch=hyper_in_ch, 138 | # hyper_num_hidden_layers=hyper_num_hidden_layers, 139 | # hyper_hidden_ch=hyper_hidden_ch) 140 | 141 | self.layers = nn.ModuleList() 142 | self.layers.append( 143 | HyperLayer( 144 | in_ch=in_ch, 145 | out_ch=hidden_ch, 146 | hyper_in_ch=hyper_in_ch, 147 | hyper_num_hidden_layers=hyper_num_hidden_layers, 148 | hyper_hidden_ch=hyper_hidden_ch)) 149 | 150 | for i in range(num_hidden_layers): 151 | self.layers.append( 152 | HyperLayer( 153 | in_ch=hidden_ch, 154 | out_ch=hidden_ch, 155 | hyper_in_ch=hyper_in_ch, 156 | hyper_num_hidden_layers=hyper_num_hidden_layers, 157 | hyper_hidden_ch=hyper_hidden_ch)) 158 | 159 | if outermost_linear: 160 | self.layers.append( 161 | HyperLinear( 162 | in_ch=hidden_ch, 163 | out_ch=out_ch, 164 | hyper_in_ch=hyper_in_ch, 165 | hyper_num_hidden_layers=hyper_num_hidden_layers, 166 | hyper_hidden_ch=hyper_hidden_ch)) 167 | else: 168 | self.layers.append( 169 | HyperLayer( 170 | in_ch=hidden_ch, 171 | out_ch=out_ch, 172 | hyper_in_ch=hyper_in_ch, 173 | hyper_num_hidden_layers=hyper_num_hidden_layers, 174 | hyper_hidden_ch=hyper_hidden_ch)) 175 | 176 | def forward(self, hyper_input): 177 | ''' 178 | :param hyper_input: Input to hypernetwork. 179 | :return: nn.Module; Predicted fully connected neural network. 180 | ''' 181 | net = [] 182 | for i in range(len(self.layers)): 183 | net.append(self.layers[i](hyper_input)) 184 | 185 | return nn.Sequential(*net) 186 | 187 | 188 | class BatchLinear(nn.Module): 189 | def __init__(self, 190 | weights, 191 | biases): 192 | '''Implements a batch linear layer. 
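
        Example (illustrative shapes)::

            import torch
            layer = BatchLinear(weights=torch.randn(5, 8, 16),
                                biases=torch.randn(5, 1, 8))
            y = layer(torch.randn(5, 1, 16))   # -> (5, 1, 8)
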
193 | :param weights: Shape: (batch, out_ch, in_ch) 194 | :param biases: Shape: (batch, 1, out_ch) 195 | ''' 196 | super().__init__() 197 | 198 | self.weights = weights 199 | self.biases = biases 200 | 201 | def __repr__(self): 202 | return "BatchLinear(in_ch=%d, out_ch=%d)" % ( 203 | self.weights.shape[-1], self.weights.shape[-2]) 204 | 205 | def forward(self, input): 206 | output = input.matmul(self.weights.permute( 207 | *[i for i in range(len(self.weights.shape) - 2)], -1, -2)) 208 | output += self.biases 209 | return output 210 | 211 | 212 | def last_hyper_layer_init(m): 213 | if isinstance(m, nn.Linear): 214 | nn.init.kaiming_normal_( 215 | m.weight, 216 | a=0.0, 217 | nonlinearity='leaky_relu', 218 | mode='fan_in') 219 | m.weight.data *= 1e-1 220 | 221 | 222 | class HyperLinear(nn.Module): 223 | '''A hypernetwork that predicts a single linear layer (weights & biases).''' 224 | 225 | def __init__(self, 226 | in_ch, 227 | out_ch, 228 | hyper_in_ch, 229 | hyper_num_hidden_layers, 230 | hyper_hidden_ch): 231 | 232 | super().__init__() 233 | self.in_ch = in_ch 234 | self.out_ch = out_ch 235 | 236 | self.hypo_params = FCBlock(in_features=hyper_in_ch, 237 | hidden_ch=hyper_hidden_ch, 238 | num_hidden_layers=hyper_num_hidden_layers, 239 | out_features=(in_ch * out_ch) + out_ch, 240 | outermost_linear=True) 241 | self.hypo_params[-1].apply(last_hyper_layer_init) 242 | 243 | def forward(self, hyper_input): 244 | hypo_params = self.hypo_params(hyper_input) 245 | 246 | # Indices explicit to catch erros in shape of output layer 247 | weights = hypo_params[..., :self.in_ch * self.out_ch] 248 | biases = hypo_params[..., self.in_ch * 249 | self.out_ch:(self.in_ch * self.out_ch) + self.out_ch] 250 | 251 | biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) 252 | weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) 253 | 254 | return BatchLinear(weights=weights, biases=biases) 255 | 256 | 257 | class H_Net_0(nn.Module): 258 | def __init__(self, hyper_in_ch, 259 | hyper_num_hidden_layers, 260 | hyper_hidden_ch, 261 | hidden_ch, 262 | num_hidden_layers, 263 | in_ch, 264 | out_ch, 265 | outermost_linear=True): 266 | super(H_Net_0, self).__init__() 267 | self.Hyper = HyperFC(hyper_in_ch, 268 | hyper_num_hidden_layers, 269 | hyper_hidden_ch, 270 | hidden_ch, 271 | num_hidden_layers, 272 | in_ch, 273 | out_ch, 274 | outermost_linear=True) 275 | self.out_ch = out_ch 276 | 277 | def forward(self, h_0, x): 278 | NN = self.Hyper(h_0) 279 | return NN( 280 | x.view( 281 | x.shape[0], 282 | 1, 283 | x.shape[1])).view( 284 | x.shape[0], 285 | self.out_ch) 286 | 287 | 288 | class H_Net(nn.Module): 289 | def __init__(self, hyper_in_ch, 290 | hyper_num_hidden_layers, 291 | hyper_hidden_ch, 292 | hidden_ch, 293 | num_hidden_layers, 294 | in_ch, 295 | out_ch, 296 | outermost_linear=True): 297 | super(H_Net, self).__init__() 298 | self.Hyper = HyperFC(hyper_in_ch, 299 | hyper_num_hidden_layers, 300 | hyper_hidden_ch, 301 | hidden_ch, 302 | num_hidden_layers, 303 | in_ch, 304 | out_ch, 305 | outermost_linear=True) 306 | self.damping = nn.Parameter(torch.rand(1)) 307 | self.out_ch = out_ch 308 | 309 | def forward(self, h_0, h_t, x): 310 | with torch.no_grad(): 311 | self.damping.data = self.damping.data.clamp(0.0, 1.0) 312 | NN = self.Hyper(self.damping * h_0 + (1 - self.damping) * x) 313 | return NN(x.view(x.shape[0], 1, x.shape[1])).view(x.shape[0], self.out_ch) 314 | -------------------------------------------------------------------------------- /CGAT/__init__.py: 
--------------------------------------------------------------------------------
from .CGAT import CGAtNet
--------------------------------------------------------------------------------
/CGAT/add_volume_target.py:
--------------------------------------------------------------------------------
import gzip as gz
import pickle
from pymatgen.entries.computed_entries import ComputedStructureEntry
from typing import List
import re
from tqdm import tqdm


def main():
    id_ = 0
    pattern = re.compile(r'spg(\d{1,3})')
    for i in tqdm(range(0, 2830000, 10000)):
        path = f'data_{i}_{i + 10000}.pickle.gz'
        data_list: List[ComputedStructureEntry] = pickle.load(gz.open(f'../original_data/{path}', 'rb'))
        to_remove = []
        for data in data_list:
            data.data['volume'] = data.structure.volume / data.structure.num_sites
            try:
                # use the spg from data if it is included
                data.data['id'] = f"{id_},{data.data['spg']}"
            except KeyError:
                # else try to extract it from the id
                try:
                    match = pattern.search(data.data['id'])
                    spg = int(match.group(1))
                except AttributeError:
                    # calculate the spg as a last resort
                    # (get_space_group_info returns a (symbol, number) tuple)
                    spg = data.structure.get_space_group_info()[1]
                data.data['id'] = f"{id_},{spg}"
            # remove entries that contain only a single element type
            if len(set(data.structure.atomic_numbers)) == 1:
                to_remove.append(data)
            else:
                id_ += 1
        pickle.dump([data for data in data_list if data not in to_remove],
                    gz.open(f'../unprepared_volume_data/{path}', 'wb'))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/CGAT/data.py:
--------------------------------------------------------------------------------
from torch_geometric.data import Data
import gzip as gz
import os
import sys

# import functools
import numpy as np
import torch
from torch.utils.data import Dataset
import pickle
from .roost_message import LoadFeaturiser

import re


class CompositionData(Dataset):
    """
    Dataset wrapper for precomputed graph and composition data.
    """

    def __init__(self, data, fea_path, radius=8.0, max_neighbor_number=12, target='e_above_hull'):
        """
        Constructs the dataset
        Args:
            data: expects either a gzipped pickle of the dictionary or a dictionary
                with the keys 'batch_comp', 'comps', 'target', 'input'
            fea_path:
                path to the file containing the element embedding information
            radius:
                cutoff radius
            max_neighbor_number:
                maximum number of neighbors used during message passing
            target:
                name of the training/validation/testing target
        """

        if isinstance(data, str):
            assert os.path.exists(data), \
                "{} does not exist!".format(data)
            self.data = pickle.load(gz.open(data, "rb"))
        else:
            self.data = data

        self.radius = radius
        self.max_num_nbr = max_neighbor_number
        # two storage layouts exist for 'input'; detect which one is used
        if self.data['input'].shape[0] > 3:
            self.format = 1
        else:
            self.format = 0
        assert os.path.exists(fea_path), "{} does not exist!".format(fea_path)
        self.atom_features = LoadFeaturiser(fea_path)
        self.atom_fea_dim = self.atom_features.embedding_size
        self.target = target

    def __len__(self):
        """Returns the length of the dataset"""
        return len(self.data['target'][self.target])

    # @functools.lru_cache(maxsize=None)  # Cache loaded structures
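    # Hedged usage sketch (the file path is illustrative): each item pairs a
    # torch_geometric ``Data`` graph with the Roost composition inputs, e.g.
    #
    #     dataset = CompositionData('example.pickle.gz',
    #                               fea_path='embeddings/matscholar-embedding.json')
    #     graph, (weights_c, fea_c, self_idx_c, nbr_idx_c) = dataset[0]
    #     graph.y  # target scaled by the number of atoms (except target='volume')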
    def __getitem__(self, idx):
        composition = self.data['batch_comp'][idx]
        elements = self.data['comps'][idx]
        if isinstance(elements, str):
            pattern = re.compile(r'([a-z]+)(\d+)', re.IGNORECASE)
            try:
                matches = pattern.findall(self.data['batch_comp'][idx])
            except TypeError:
                matches = pattern.findall(self.data['batch_comp'][idx][0])
            elements = []
            for el, count in matches:
                for _ in range(int(count)):
                    elements.append(el)
        try:
            elements = elements.tolist()
        except BaseException:
            pass
        if isinstance(elements[0], (list, tuple)):
            elements = [el[0] for el in elements]
        N = len(elements)
        comp = {}
        weights = []
        elements2 = []
        for el in elements:
            comp[el] = elements.count(el)

        for k, v in comp.items():
            weights.append(v / len(elements))
            elements2.append(k)
        env_idx = list(range(len(elements2)))
        self_fea_idx_c = []
        nbr_fea_idx_c = []
        nbrs = len(elements2) - 1
        for i, _ in enumerate(elements2):
            self_fea_idx_c += [i] * nbrs
            nbr_fea_idx_c += env_idx[:i] + env_idx[i + 1:]

        atom_fea_c = np.vstack([self.atom_features.get_fea(element)
                                for element in elements2])
        atom_weights_c = torch.Tensor(weights)
        atom_fea_c = torch.Tensor(atom_fea_c)
        self_fea_idx_c = torch.LongTensor(self_fea_idx_c)
        nbr_fea_idx_c = torch.LongTensor(nbr_fea_idx_c)

        if self.format == 0:
            try:
                atom_fea = np.vstack([self.atom_features.get_fea(element)
                                      for element in elements])
            except AssertionError:
                print(composition)
                sys.exit()

            target = self.data['target'][self.target][idx]
            atom_fea = torch.Tensor(atom_fea)
            nbr_fea = torch.LongTensor(
                self.data['input'][0][idx][:, 0:self.max_num_nbr].flatten().astype(int))
            nbr_fea_idx = torch.LongTensor(
                self.data['input'][2][idx][:, 0:self.max_num_nbr].flatten().astype(int))
            self_fea_idx = torch.LongTensor(
                self.data['input'][1][idx][:, 0:self.max_num_nbr].flatten().astype(int))
            target = torch.Tensor([target])
        else:
            try:
                atom_fea = np.vstack([self.atom_features.get_fea(
                    elements[i]) for i in range(len(elements))])
            except AssertionError:
                print(composition)
                sys.exit()

            target = self.data['target'][self.target][idx]
            atom_fea = torch.Tensor(atom_fea)
            nbr_fea = torch.LongTensor(
                self.data['input'][idx][0][:, 0:self.max_num_nbr].flatten())
            nbr_fea_idx = torch.LongTensor(
                self.data['input'][idx][2][:, 0:self.max_num_nbr].flatten())
            self_fea_idx = torch.LongTensor(
                self.data['input'][idx][1][:, 0:self.max_num_nbr].flatten())
            target = torch.Tensor([target])
        if self.target != 'volume':
            return Data(x=atom_fea, edge_index=torch.stack((self_fea_idx, nbr_fea_idx)), edge_attr=nbr_fea,
                        y=target * N), (atom_weights_c, atom_fea_c, self_fea_idx_c, nbr_fea_idx_c)
        else:
            return Data(x=atom_fea, edge_index=torch.stack((self_fea_idx, nbr_fea_idx)), edge_attr=nbr_fea,
                        y=target), (atom_weights_c, atom_fea_c, self_fea_idx_c, nbr_fea_idx_c)

--------------------------------------------------------------------------------
/CGAT/lambs.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # MIT License 16 | # 17 | # Copyright (c) 2019 cybertronai 18 | # 19 | # Permission is hereby granted, free of charge, to any person obtaining a copy 20 | # of this software and associated documentation files (the "Software"), to deal 21 | # in the Software without restriction, including without limitation the rights 22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 | # copies of the Software, and to permit persons to whom the Software is 24 | # furnished to do so, subject to the following conditions: 25 | # 26 | # The above copyright notice and this permission notice shall be included in all 27 | # copies or substantial portions of the Software. 28 | # 29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 | # SOFTWARE. 36 | 37 | """Lamb optimizer.""" 38 | 39 | import torch 40 | from torch.optim import Optimizer 41 | 42 | 43 | class Lamb(Optimizer): 44 | r"""Implements Lamb algorithm. 45 | 46 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 47 | 48 | Arguments: 49 | params (iterable): iterable of parameters to optimize or dicts defining 50 | parameter groups 51 | lr (float, optional): learning rate (default: 1e-3) 52 | betas (Tuple[float, float], optional): coefficients used for computing 53 | running averages of gradient and its square (default: (0.9, 0.999)) 54 | eps (float, optional): term added to the denominator to improve 55 | numerical stability (default: 1e-8) 56 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 57 | adam (bool, optional): always use trust ratio = 1, which turns this into 58 | Adam. Useful for comparison purposes. 59 | 60 | .. 
_Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 61 | https://arxiv.org/abs/1904.00962 62 | """ 63 | 64 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 65 | weight_decay=0, adam=False): 66 | if not 0.0 <= lr: 67 | raise ValueError("Invalid learning rate: {}".format(lr)) 68 | if not 0.0 <= eps: 69 | raise ValueError("Invalid epsilon value: {}".format(eps)) 70 | if not 0.0 <= betas[0] < 1.0: 71 | raise ValueError( 72 | "Invalid beta parameter at index 0: {}".format( 73 | betas[0])) 74 | if not 0.0 <= betas[1] < 1.0: 75 | raise ValueError( 76 | "Invalid beta parameter at index 1: {}".format( 77 | betas[1])) 78 | defaults = dict(lr=lr, betas=betas, eps=eps, 79 | weight_decay=weight_decay) 80 | self.adam = adam 81 | super(Lamb, self).__init__(params, defaults) 82 | 83 | def step(self, closure=None): 84 | """Performs a single optimization step. 85 | 86 | Arguments: 87 | closure (callable, optional): A closure that reevaluates the model 88 | and returns the loss. 89 | """ 90 | loss = None 91 | if closure is not None: 92 | loss = closure() 93 | 94 | for group in self.param_groups: 95 | for p in group['params']: 96 | if p.grad is None: 97 | continue 98 | grad = p.grad.data 99 | if grad.is_sparse: 100 | raise RuntimeError( 101 | 'Lamb does not support sparse gradients.') 102 | 103 | state = self.state[p] 104 | 105 | # State initialization 106 | if len(state) == 0: 107 | state['step'] = 0 108 | # Exponential moving average of gradient values 109 | state['exp_avg'] = torch.zeros_like(p.data) 110 | # Exponential moving average of squared gradient values 111 | state['exp_avg_sq'] = torch.zeros_like(p.data) 112 | 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | state['step'] += 1 117 | 118 | # Decay the first and second moment running average coefficient 119 | # m_t 120 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 121 | # v_t 122 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 123 | 124 | # Paper v3 does not use debiasing. 125 | # bias_correction1 = 1 - beta1 ** state['step'] 126 | # bias_correction2 = 1 - beta2 ** state['step'] 127 | # Apply bias to lr to avoid broadcast. 
128 | # * math.sqrt(bias_correction2) / bias_correction1 129 | step_size = group['lr'] 130 | 131 | weight_norm = p.data.norm(p=2).clamp_(0, 10) 132 | 133 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 134 | if group['weight_decay'] != 0: 135 | adam_step.add_(group['weight_decay'], p.data) 136 | 137 | adam_norm = adam_step.norm(p=2) 138 | 139 | if weight_norm == 0.0 or adam_norm == 0.0: 140 | trust_ratio = 1 141 | else: 142 | trust_ratio = weight_norm / (adam_norm + group['eps']) 143 | 144 | state['weight_norm'] = weight_norm 145 | state['adam_norm'] = adam_norm 146 | state['trust_ratio'] = trust_ratio 147 | if self.adam: 148 | trust_ratio = 1 149 | 150 | p.data.add_(-step_size * trust_ratio, adam_step) 151 | 152 | return loss 153 | 154 | 155 | @torch.jit.script 156 | def lamb_kernel( 157 | param, 158 | grad, 159 | exp_avg, 160 | exp_avg_sq, 161 | beta1: float, 162 | beta2: float, 163 | step_size: float, 164 | eps: float, 165 | weight_decay: float): 166 | exp_avg = exp_avg * beta1 + (1 - beta1) * grad 167 | exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2) * (grad * grad) 168 | 169 | adam_step = exp_avg / (exp_avg_sq.sqrt() + eps) 170 | adam_step = adam_step + weight_decay * param 171 | 172 | weight_norm = param.norm(p=2).clamp_(0, 10) 173 | adam_norm = adam_step.norm(p=2) 174 | 175 | trust_ratio = weight_norm / (adam_norm + eps) 176 | trust_ratio = (weight_norm == 0.0) * 1.0 + \ 177 | (weight_norm != 0.0) * trust_ratio 178 | trust_ratio = (adam_norm == 0.0) * 1.0 + (adam_norm != 0.0) * trust_ratio 179 | 180 | param = param - step_size * trust_ratio * adam_step 181 | return param, exp_avg, exp_avg_sq 182 | 183 | 184 | class JITLamb(Optimizer): 185 | r"""Implements Lamb algorithm. 186 | 187 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 188 | 189 | Arguments: 190 | params (iterable): iterable of parameters to optimize or dicts defining 191 | parameter groups 192 | lr (float, optional): learning rate (default: 1e-3) 193 | betas (Tuple[float, float], optional): coefficients used for computing 194 | running averages of gradient and its square (default: (0.9, 0.999)) 195 | eps (float, optional): term added to the denominator to improve 196 | numerical stability (default: 1e-8) 197 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 198 | adam (bool, optional): always use trust ratio = 1, which turns this into 199 | Adam. Useful for comparison purposes. 200 | 201 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 202 | https://arxiv.org/abs/1904.00962 203 | """ 204 | 205 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 206 | weight_decay=0, adam=False): 207 | if not 0.0 <= lr: 208 | raise ValueError("Invalid learning rate: {}".format(lr)) 209 | if not 0.0 <= eps: 210 | raise ValueError("Invalid epsilon value: {}".format(eps)) 211 | if not 0.0 <= betas[0] < 1.0: 212 | raise ValueError( 213 | "Invalid beta parameter at index 0: {}".format( 214 | betas[0])) 215 | if not 0.0 <= betas[1] < 1.0: 216 | raise ValueError( 217 | "Invalid beta parameter at index 1: {}".format( 218 | betas[1])) 219 | defaults = dict(lr=lr, betas=betas, eps=eps, 220 | weight_decay=weight_decay) 221 | self.adam = adam 222 | super().__init__(params, defaults) 223 | 224 | def step(self, closure=None): 225 | """Performs a single optimization step. 226 | 227 | Arguments: 228 | closure (callable, optional): A closure that reevaluates the model 229 | and returns the loss. 
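
        In summary, the update computed by ``lamb_kernel`` below is (per
        parameter tensor ``p`` with gradient ``g``)::

            m = beta1 * m + (1 - beta1) * g
            v = beta2 * v + (1 - beta2) * g * g
            u = m / (v.sqrt() + eps) + weight_decay * p
            p = p - lr * (||p|| / (||u|| + eps)) * u   # ||p|| clamped to [0, 10]

        A minimal usage sketch (shapes are illustrative)::

            import torch
            model = torch.nn.Linear(4, 1)
            opt = JITLamb(model.parameters(), lr=1e-3, weight_decay=0.01)
            loss = model(torch.randn(8, 4)).pow(2).mean()
            loss.backward()
            opt.step()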
230 | """ 231 | loss = None 232 | if closure is not None: 233 | loss = closure() 234 | 235 | for group in self.param_groups: 236 | for p in group['params']: 237 | if p.grad is None: 238 | continue 239 | grad = p.grad.data 240 | if grad.is_sparse: 241 | raise RuntimeError( 242 | 'Lamb does not support sparse gradients.') 243 | 244 | state = self.state[p] 245 | 246 | # State initialization 247 | if len(state) == 0: 248 | state['step'] = 0 249 | # Exponential moving average of gradient values 250 | state['exp_avg'] = torch.zeros_like(p.data) 251 | # Exponential moving average of squared gradient values 252 | state['exp_avg_sq'] = torch.zeros_like(p.data) 253 | 254 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 255 | beta1, beta2 = group['betas'] 256 | 257 | state['step'] += 1 258 | step_size = group['lr'] 259 | 260 | param, exp_avg, exp_avg_sq = lamb_kernel(p.data, grad, exp_avg, 261 | exp_avg_sq, beta1, 262 | beta2, step_size, 263 | group['eps'], 264 | group['weight_decay'], 265 | ) 266 | state['exp_avg'] = exp_avg 267 | state['exp_avg_sq'] = exp_avg_sq 268 | p.data = param 269 | 270 | return loss 271 | -------------------------------------------------------------------------------- /CGAT/lightning_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example template for defining a system 3 | """ 4 | from CGAT.roost_message import collate_batch 5 | import importlib 6 | from CGAT.lambs import JITLamb 7 | import numpy as np 8 | 9 | from argparse import ArgumentParser, Namespace 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from torch.utils.data import DataLoader 15 | from torch.nn.functional import l1_loss as mae 16 | from torch.nn.functional import mse_loss as mse 17 | 18 | from sklearn.model_selection import train_test_split as split 19 | 20 | from CGAT.utils import RobustL1, RobustL2, cyclical_lr 21 | from torch_geometric.data import Batch 22 | from CGAT.data import CompositionData 23 | 24 | from pytorch_lightning.core import LightningModule 25 | import os, glob 26 | 27 | 28 | def collate_fn(datalist): 29 | return datalist 30 | 31 | 32 | class LightningModel(LightningModule): 33 | """ 34 | Lightning model for CGAtNet defined in hparams.version 35 | """ 36 | 37 | def __init__(self, hparams): 38 | """ 39 | Pass in parsed HyperOptArgumentParser to the model 40 | Args: 41 | hparams: Namespace passed from the argument parser including all the hyperparameters 42 | """ 43 | super().__init__() 44 | # initialization of mean and standard deviation of the target data (needed for reloading without recalculation) 45 | self.mean = torch.nn.parameter.Parameter(torch.zeros(1), requires_grad=False) 46 | self.std = torch.nn.parameter.Parameter(torch.zeros(1), requires_grad=False) 47 | 48 | # self.hparams = hparams 49 | self.save_hyperparameters(hparams) 50 | # datasets are loaded for training or testing not needed in production 51 | if self.hparams.train: 52 | datasets = [] 53 | # used for single file 54 | try: 55 | dataset = CompositionData( 56 | data=self.hparams.data_path, 57 | fea_path=self.hparams.fea_path, 58 | max_neighbor_number=self.hparams.max_nbr, 59 | target=self.hparams.target) 60 | print(self.hparams.data_path + ' loaded') 61 | # used for folder of dataset files 62 | except: 63 | f_n = sorted([file for file in glob.glob(os.path.join(self.hparams.data_path, "*.pickle.gz"))]) 64 | print("{} files to load".format(len(f_n))) 65 | for file in f_n: 66 | try: 67 | 
datasets.append(CompositionData( 68 | data=file, 69 | fea_path=self.hparams.fea_path, 70 | max_neighbor_number=self.hparams.max_nbr, 71 | target=self.hparams.target)) 72 | print(file + ' loaded') 73 | except: 74 | print(file + ' could not be loaded') 75 | print("{} files succesfully loaded".format(len(datasets))) 76 | dataset = torch.utils.data.ConcatDataset(datasets) 77 | 78 | if self.hparams.test_path is None or self.hparams.val_path is None: 79 | indices = list(range(len(dataset))) 80 | train_idx, test_idx = split(indices, random_state=self.hparams.seed, 81 | test_size=self.hparams.test_size) 82 | train_set = torch.utils.data.Subset(dataset, train_idx) 83 | self.test_set = torch.utils.data.Subset(dataset, test_idx) 84 | indices = list(range(len(train_set))) 85 | train_idx, val_idx = split(indices, random_state=self.hparams.seed, 86 | test_size=self.hparams.val_size / (1 - self.hparams.test_size)) 87 | train_set_2 = torch.utils.data.Subset(train_set, train_idx) 88 | self.val_subset = torch.utils.data.Subset(train_set, val_idx) 89 | else: 90 | test_data = torch.utils.data.ConcatDataset([CompositionData(data=file, 91 | fea_path=self.hparams.fea_path, 92 | max_neighbor_number=self.hparams.max_nbr, 93 | target=self.hparams.target) 94 | for file in glob.glob( 95 | os.path.join(self.hparams.test_path, "*.pickle.gz"))]) 96 | val_data = torch.utils.data.ConcatDataset([CompositionData(data=file, 97 | fea_path=self.hparams.fea_path, 98 | max_neighbor_number=self.hparams.max_nbr, 99 | target=self.hparams.target) 100 | for file in glob.glob( 101 | os.path.join(self.hparams.val_path, "*.pickle.gz"))]) 102 | 103 | train_set = dataset 104 | self.test_set = test_data 105 | train_set_2 = train_set 106 | self.val_subset = val_data 107 | 108 | # Use train_percentage to get errors for different training set sizes 109 | # but same test and validation sets 110 | if self.hparams.train_percentage != 0.0: 111 | indices = list(range(len(train_set_2))) 112 | train_idx, rest_idx = split( 113 | indices, random_state=self.hparams.seed, test_size=1.0 - self.hparams.train_percentage / ( 114 | 1 - self.hparams.val_size - self.hparams.test_size)) 115 | self.train_subset = torch.utils.data.Subset(train_set_2, train_idx) 116 | else: 117 | self.train_subset = train_set_2 118 | print('Normalization started') 119 | 120 | def collate_fn2(data_list): 121 | return [el[0].y for el in data_list] 122 | 123 | sample_target = torch.cat(collate_fn2(self.train_subset)) 124 | self.mean = torch.nn.parameter.Parameter(torch.mean(sample_target, dim=0, keepdim=False), 125 | requires_grad=False) 126 | self.std = torch.nn.parameter.Parameter(torch.std(sample_target, dim=0, keepdim=False), requires_grad=False) 127 | print('mean: ', self.mean.item(), 'std: ', self.std.item()) 128 | print('normalization ended') 129 | 130 | # select loss function 131 | if not self.hparams.std_loss: 132 | print('robust loss') 133 | if self.hparams.loss == "L1": 134 | self.criterion = RobustL1 135 | elif self.hparams.loss == "L2": 136 | self.criterion = RobustL2 137 | elif self.hparams.std_loss: 138 | print('No robust loss function') 139 | if self.hparams.loss == 'L1': 140 | self.criterion = nn.L1Loss() 141 | else: 142 | self.criterion = nn.MSELoss() 143 | # build model 144 | self.__build_model() 145 | 146 | # --------------------- 147 | # MODEL SETUP 148 | # --------------------- 149 | def norm(self, tensor): 150 | """ 151 | normalizes tensor 152 | """ 153 | return (tensor - self.mean) / self.std 154 | 155 | def denorm(self, normed_tensor): 156 | """ 157 | return 
normalized tensor to original form 158 | """ 159 | return normed_tensor * self.std + self.mean 160 | 161 | def __build_model(self): 162 | """ 163 | builds model from hparams 164 | """ 165 | gat = importlib.import_module(self.hparams.version) 166 | self.model = gat.CGAtNet(200, 167 | elem_fea_len=self.hparams.atom_fea_len, 168 | n_graph=self.hparams.n_graph, 169 | rezero=self.hparams.rezero, 170 | mean_pooling=not self.hparams.mean_pooling, 171 | neighbor_number=self.hparams.max_nbr, 172 | msg_heads=self.hparams.msg_heads, 173 | update_edges=self.hparams.update_edges, 174 | vector_attention=self.hparams.vector_attention, 175 | global_vector_attention=self.hparams.global_vector_attention, 176 | n_graph_roost=self.hparams.n_graph_roost) # , self.hparams.dropout) 177 | 178 | params = sum([np.prod(p.size()) for p in self.model.parameters()]) 179 | print('this model has {0:1d} parameters '.format(params)) 180 | 181 | # --------------------- 182 | # TRAINING 183 | # --------------------- 184 | 185 | def evaluate(self, batch, *, last_layer=True, return_graph_embedding=False): 186 | """ 187 | calculates normalized and unnormalized output of the network. 188 | Batch object should include input for CGAT and Roost 189 | Args: 190 | batch: Tuple of graph object from pytorch geometric and input for Roost 191 | Returns: 192 | output: Normalized output 193 | log_std: log std for uncertainty estimation/loss function 194 | pred: Denormalized output 195 | target: target value for error calculations 196 | norm of target: normalized target for training 197 | """ 198 | device = next(self.model.parameters()).device 199 | b_comp, batch = [el[1] for el in batch], [el[0] for el in batch] 200 | batch = (Batch.from_data_list(batch)).to(device) 201 | b_comp = collate_batch(b_comp) 202 | b_comp = (tensor.to(device) for tensor in b_comp) 203 | if return_graph_embedding: 204 | return self.model(batch, b_comp, return_graph_embedding=True) 205 | if last_layer: 206 | output, log_std = self.model(batch, b_comp).chunk(2, dim=1) 207 | target = batch.y.view(len(batch.y), 1) 208 | target_norm = self.norm(target) 209 | pred = self.denorm(output.data) 210 | return output, log_std, pred, target, target_norm 211 | else: 212 | output = self.model(batch, b_comp, last_layer=last_layer) 213 | return output 214 | 215 | def forward(self, batch, batch_idx=None): 216 | """ 217 | Use for prediction with a dataloader 218 | Args: 219 | batch: Tuple of graph object from pytorch geometric and input for Roost 220 | batch_idx: identifiers of batch elements 221 | Returns: 222 | denormalized prediction of the network 223 | """ 224 | _, log_std, pred, _, _ = self.evaluate(batch) 225 | return pred 226 | 227 | def training_step(self, batch, batch_idx): 228 | """ 229 | Calculates loss and various error metrics 230 | Args: 231 | batch: Tuple of graph object from pytorch geometric and input for Roost 232 | batch_idx: identifiers of batch elements 233 | Returns: loss 234 | """ 235 | output, log_std, pred, target, target_norm = self.evaluate(batch) 236 | # calculate loss 237 | if not self.hparams.std_loss: 238 | loss = self.criterion(output, log_std, target_norm) 239 | else: 240 | loss = self.criterion(output, target_norm) 241 | 242 | mae_error = mae(pred, target) 243 | rmse_error = mse(pred, target).sqrt_() 244 | self.log('train_loss', 245 | loss, 246 | on_step=False, 247 | on_epoch=True, 248 | sync_dist=True) 249 | self.log('train_mae', 250 | mae_error, 251 | on_step=False, 252 | on_epoch=True, 253 | sync_dist=True) 254 | self.log('train_rmse', 255 | 
rmse_error,
256 | on_step=False,
257 | on_epoch=True,
258 | sync_dist=True)
259 | return loss
260 |
261 | def validation_step(self, batch, batch_idx):
262 | """
263 | Calculates various error metrics for validation
264 | Args:
265 | batch: Tuple of graph object from pytorch geometric and input for Roost
266 | batch_idx: identifiers of batch elements
267 | Returns:
268 | """
269 | output, log_std, pred, target, target_norm = self.evaluate(batch)
270 |
271 | if not self.hparams.std_loss:
272 | val_loss = self.criterion(output, log_std, target_norm)
273 | else:
274 | val_loss = self.criterion(output, target_norm)
275 |
276 | val_mae = mae(pred, target)
277 | val_rmse = mse(pred, target).sqrt_()
278 | self.log('val_loss', val_loss, on_epoch=True, sync_dist=True)
279 | self.log('val_mae', val_mae, on_epoch=True, sync_dist=True)
280 | self.log('val_rmse', val_rmse, on_epoch=True, sync_dist=True)
281 |
282 | def test_step(self, batch, batch_idx):
283 | """
284 | Calculates various error metrics for testing
285 | Args:
286 | batch: Tuple of graph object from pytorch geometric and input for Roost
287 | batch_idx: identifiers of batch elements
288 | Returns:
289 | """
290 | output, log_std, pred, target, target_norm = self.evaluate(batch)
291 |
292 | if not self.hparams.std_loss:
293 | test_loss = self.criterion(output, log_std, target_norm)
294 | else:
295 | test_loss = self.criterion(output, target_norm)
296 |
297 | test_mae = mae(pred, target)
298 | test_rmse = mse(pred, target).sqrt_()
299 | self.log('test_loss', test_loss, on_epoch=True, sync_dist=True)
300 | self.log('test_mae', test_mae, on_epoch=True, sync_dist=True)
301 | self.log('test_rmse', test_rmse, on_epoch=True, sync_dist=True)
302 |
303 | # ---------------------
304 | # TRAINING SETUP
305 | # ---------------------
306 | def configure_optimizers(self):
307 | """
308 | Creates optimizers for training according to the hyperparameter settings
309 | Args:
310 | Returns:
311 | [optimizer], [scheduler]: Tuple of list of optimizers and list of learning rate schedulers
312 | """
313 | # Select the parameters that should be trained
314 | if self.hparams.only_residual:
315 | parameters = self.model.get_output_parameters()
316 | else:
317 | parameters = self.model.parameters()
318 | # Select optimizer
319 | if self.hparams.optim == "SGD":
320 | optimizer = optim.SGD(parameters,
321 | lr=self.hparams.learning_rate,
322 | weight_decay=self.hparams.weight_decay,
323 | momentum=self.hparams.momentum)
324 | elif self.hparams.optim == "Adam":
325 | optimizer = optim.Adam(parameters,
326 | lr=self.hparams.learning_rate,
327 | weight_decay=self.hparams.weight_decay)
328 | elif self.hparams.optim == "AdamW":
329 | optimizer = optim.AdamW(parameters,
330 | lr=self.hparams.learning_rate,
331 | weight_decay=self.hparams.weight_decay)
332 | elif self.hparams.optim == "LAMB":
333 | optimizer = JITLamb(parameters,
334 | lr=self.hparams.learning_rate,
335 | weight_decay=self.hparams.weight_decay)
336 | else:
337 | raise NameError(
338 | "Only SGD, Adam, AdamW and LAMB are allowed as --optim")
339 |
340 | if self.hparams.clr:
341 | clr = cyclical_lr(period=self.hparams.clr_period,
342 | cycle_mul=0.1,
343 | tune_mul=0.05, )
344 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, [clr])
345 | else:
346 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
347 | mode='min',
348 | factor=0.1,
349 | patience=5,
350 | verbose=False,
351 | threshold=0.0002,
352 | threshold_mode='rel',
353 | cooldown=0,
354 | eps=1e-08)
355 | return [optimizer], [scheduler]
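# A rough sketch of the epoch -> multiplier schedule that cyclical_lr (from
# the CGAT utilities) is expected to return for LambdaLR; the real
# implementation may differ, this only illustrates the mechanism:
#
#     def example_cyclical_lr(period=130, cycle_mul=0.1, tune_mul=0.05):
#         def lr_lambda(epoch):
#             cycle, pos = divmod(epoch, period)
#             # shrink the amplitude each cycle, anneal linearly within a cycle
#             return max(tune_mul, (cycle_mul ** cycle) * (1.0 - pos / period))
#         return lr_lambda
#
# optim.lr_scheduler.LambdaLR(optimizer, [lr_lambda]) then multiplies the base
# learning rate by the returned factor once per epoch.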
356 |
357 | def train_dataloader(self):
358 | """
359 | creates dataloader for training according to the hyperparameters
360 | Args:
361 | Returns:
362 | train_generator: Dataloader for training dataset
363 | """
364 | params = {"batch_size": self.hparams.batch_size,
365 | "num_workers": self.hparams.workers,
366 | "pin_memory": False,
367 | "shuffle": True,
368 | "drop_last": True
369 | }
370 | print('length of train_subset: {}'.format(len(self.train_subset)))
371 | train_generator = DataLoader(
372 | self.train_subset, collate_fn=collate_fn, **params)
373 | return train_generator
374 |
375 | def val_dataloader(self):
376 | """
377 | creates dataloader for validation according to the hyperparameters
378 | Args:
379 | Returns:
380 | val_generator: Dataloader for validation dataset
381 | """
382 | params = {"batch_size": self.hparams.batch_size,
383 | # "num_workers": self.hparams.workers,
384 | "pin_memory": False,
385 | "drop_last": True,  # note: the final incomplete batch is skipped, so a few samples are not evaluated
386 | "shuffle": False}
387 | val_generator = DataLoader(
388 | self.val_subset,
389 | collate_fn=collate_fn,
390 | **params)
391 | print('length of val_subset: {}'.format(len(self.val_subset)))
392 | return val_generator
393 |
394 | def test_dataloader(self):
395 | """
396 | creates dataloader for testing according to the hyperparameters
397 | Args:
398 | Returns:
399 | test_generator: Dataloader for testing dataset
400 | """
401 | params = {"batch_size": self.hparams.batch_size,
402 | # "num_workers": self.hparams.workers,
403 | "pin_memory": False,
404 | "drop_last": True,
405 | "shuffle": False}
406 | test_generator = DataLoader(
407 | self.test_set,
408 | collate_fn=collate_fn,
409 | **params)
410 | print('length of test_set: {}'.format(len(self.test_set)))
411 | return test_generator
412 |
413 | @classmethod
414 | def load(cls, path_to_checkpoint: str, train: bool = False):
"""Rebuilds the model from a checkpoint; with train=False no datasets are loaded and the model is put into eval mode."""
415 | checkpoint = torch.load(path_to_checkpoint)
416 | hparams = Namespace()
417 | hparams.__dict__ = checkpoint['hyper_parameters']
418 | hparams.train = train
419 | model = cls(hparams)
420 | model.load_state_dict(checkpoint['state_dict'])
421 | if not train:
422 | model.eval()
423 |
424 | return model
425 |
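# Usage sketch for load() (the checkpoint path is hypothetical):
#
#     model = LightningModel.load('tb_logs/example.ckpt')  # eval mode, no datasets loaded
#     predictions = model(batch)                           # denormalized predictions
#
# Note: torch.load above relies on the default map_location, so a checkpoint
# written on GPU may need torch.load(..., map_location='cpu') on a
# CPU-only machine.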
426 | @staticmethod
427 | def add_model_specific_args(parent_parser: ArgumentParser = None) -> ArgumentParser:  # pragma: no-cover
428 | """
429 | Parameters defined here will be available through self.hparams
430 | Args:
431 | parent_parser: ArgumentParser from e.g. the training script that adds gpu settings and Trainer settings
432 | Returns:
433 | parser: ArgumentParser for all hyperparameters and training/test settings
434 | """
435 | if parent_parser is not None:
436 | parser = ArgumentParser(parents=[parent_parser])
437 | else:
438 | parser = ArgumentParser()
439 |
440 | parser.add_argument("--data-path",
441 | type=str,
442 | default="data/",
443 | metavar="PATH",
444 | help="path to folder/file that contains dataset files, tries to load all "
445 | "*.pickle.gz in folder")
446 | parser.add_argument("--fea-path",
447 | type=str,
448 | default="../embeddings/matscholar-embedding.json",
449 | metavar="PATH",
450 | help="atom feature path")
451 | parser.add_argument("--version",
452 | type=str,
453 | default="CGAT",
454 | help="module from which to load CGAtNet class")
455 | parser.add_argument("--nbr-embedding-size",
456 | default=512,
457 | type=int,
458 | help="size of edge embedding")
459 | parser.add_argument("--msg-heads",
460 | default=5,
461 | type=int,
462 | help="number of attention-heads in message passing/final pooling layer")
463 | parser.add_argument("--workers",
464 | default=0,
465 | type=int,
466 | metavar="N",
467 | help="number of data loading workers (default: 0), crashes on some machines if used")
468 | parser.add_argument("--batch-size",
469 | default=64,
470 | type=int,
471 | metavar="N",
472 | help="mini-batch size (default: 64), when using multiple gpus the actual batch-size"
473 | " is --batch-size*n_gpus")
474 | parser.add_argument("--val-size",
475 | default=0.1,
476 | type=float,
477 | metavar="N",
478 | help="proportion of data used for validation")
479 | parser.add_argument("--test-size",
480 | default=0.1,
481 | type=float,
482 | metavar="N",
483 | help="proportion of data used for testing")
484 | parser.add_argument("--max-nbr",
485 | default=24,
486 | type=int,
487 | metavar="max_N",
488 | help="maximum number of neighbors; depends on the number set during the feature calculation")
489 | parser.add_argument("--epochs",
490 | default=390,
491 | type=int,
492 | metavar="N",
493 | help="number of total epochs to run")
494 | parser.add_argument("--loss",
495 | default="L1",
496 | type=str,
497 | metavar="str",
498 | help="choose a loss function, L1 or L2 (a robust variant is used when std_loss is False)")
499 | parser.add_argument("--optim",
500 | default="AdamW",
501 | type=str,
502 | metavar="str",
503 | help="choose an optimizer; SGD, Adam, AdamW or LAMB")
504 | parser.add_argument("--learning-rate", "--lr",
505 | default=0.000125,
506 | type=float,
507 | metavar="float",
508 | help="initial learning rate (default: 1.25e-4)")
509 | parser.add_argument("--momentum",
510 | default=0.9,
511 | type=float,
512 | metavar="float [0,1]",
513 | help="momentum (default: 0.9)")
514 | parser.add_argument("--weight-decay",
515 | default=1e-6,
516 | type=float,
517 | metavar="float [0,1]",
518 | help="weight decay (default: 1e-6)")
519 | parser.add_argument("--atom-fea-len",
520 | default=128,
521 | type=int,
522 | metavar="N",
523 | help="size of node embedding")
524 | parser.add_argument("--n-graph",
525 | default=5,
526 | type=int,
527 | metavar="N",
528 | help="number of graph layers in CGAT model")
529 | parser.add_argument("--n-graph-roost",
530 | default=3,
531 | type=int,
532 | metavar="N",
533 | help="number of graph layers in roost module")
# note: the switches below use action="store_false", i.e. they are True by
# default and passing the flag turns the corresponding feature off
534 | parser.add_argument("--global_vector_attention",
535 | action="store_false",
536 | help="whether vector or scalar attention is used in the global pooling")
537 | parser.add_argument("--update_edges",
538 | action="store_false",
539 | help="whether edges are updated")
540 | parser.add_argument("--vector_attention",
541 | action="store_false",
542 | help="whether vector or scalar attention is used in message passing")
543 | parser.add_argument("--clr",
544 | action="store_false",
545 | help="use a cyclical learning rate schedule")
546 | parser.add_argument("--rezero",
547 | action="store_false",
548 | help="start residual layers with 0 as prefactor")
549 | parser.add_argument("--mean-pooling",
550 | action="store_false",
551 | help="chooses pooling variant for attention heads (mean or concatenation)")
552 | parser.add_argument("--std-loss",
553 | action="store_false",
554 | help="whether to choose a loss function that considers uncertainty")
555 | parser.add_argument("--clr-period",
556 | default=130,
557 | type=int,
558 | help="how many epochs per learning rate cycle to perform")
559 | parser.add_argument("--train-percentage",
560 | default=0.0,
561 | type=float,
562 | help="percentage of the training data that is used for training (only use to get"
563 | " a training set size vs test error curve)")
564 | parser.add_argument("--seed",
565 | default=0,
566 | type=int,
567 | metavar="N",
568 | help="seed for random number generator")
569 | parser.add_argument("--smoke-test",
570 | action="store_true",
571 | help="finish quickly for testing")
572 | parser.add_argument("--train",
573 | action="store_false",
574 | help="pass this flag to set train=False, so datasets are not loaded and loading of the model is sped up")
575 | parser.add_argument("--target",
576 | default="e_above_hull_new",
577 | type=str,
578 | metavar="str",
579 | help="choose the target variable, the dataset dictionary should have a corresponding "
580 | "dictionary structure data['target'][target]")
581 | parser.add_argument("--test-path",
582 | default=None,
583 | type=str,
584 | help="path to data set with the test set (only used in combination with --val-path)")
585 | parser.add_argument("--val-path",
586 | default=None,
587 | type=str,
588 | help="path to data set with the validation set (only used in combination with --test-path)")
589 | parser.add_argument("--only-residual",
590 | action="store_true",
591 | help="train only the residual network for transfer learning.")
592 |
593 | return parser
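# Hedged usage sketch (flag names as defined above): the store_false switches
# default to True, and passing one of those flags turns its feature off.
if __name__ == '__main__':
    example_parser = LightningModel.add_model_specific_args()
    example_args = example_parser.parse_args(['--batch-size', '32', '--clr'])
    assert example_args.batch_size == 32
    assert example_args.clr is False    # flag given, store_false flips it to False
    assert example_args.rezero is True  # flag absent, default stays True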
| help="whether edges are updated") 540 | parser.add_argument("--vector_attention", 541 | action="store_false", 542 | help="whether vector attention or scalar attention is used") 543 | parser.add_argument("--clr", 544 | action="store_false", 545 | help="use a cyclical learning rate schedule") 546 | parser.add_argument("--rezero", 547 | action="store_false", 548 | help="start residual layers with 0 as prefactor") 549 | parser.add_argument("--mean-pooling", 550 | action="store_false", 551 | help="chooses pooling variant for attention heads (mean or concatenation)") 552 | parser.add_argument("--std-loss", 553 | action="store_false", 554 | help="whether to choose a loss function that considers uncertainty") 555 | parser.add_argument("--clr-period", 556 | default=130, 557 | type=int, 558 | help="how many epochs per learning rate cycle to perform") 559 | parser.add_argument("--train-percentage", 560 | default=0.0, 561 | type=float, 562 | help="Percentage of the training data that is used for training (only use to get" 563 | " a training set size vs test error curve") 564 | parser.add_argument("--seed", 565 | default=0, 566 | type=int, 567 | metavar="N", 568 | help="seed for random number generator") 569 | parser.add_argument("--smoke-test", 570 | action="store_true", 571 | help="Finish quickly for testing") 572 | parser.add_argument("--train", 573 | action="store_false", 574 | help="if set to True datasets will not be loaded to speed up loading of the model") 575 | parser.add_argument("--target", 576 | default="e_above_hull_new", 577 | type=str, 578 | metavar="str", 579 | help="choose the target variable, the dataset dictionary should have a corresponding" 580 | "dictionary structure data['target'][target]") 581 | parser.add_argument("--test-path", 582 | default=None, 583 | type=str, 584 | help="path to data set with the test set (only used in combination with --val-path)") 585 | parser.add_argument("--val-path", 586 | default=None, 587 | type=str, 588 | help="path to data set with the validation set (only used in combination with --val-path)") 589 | parser.add_argument("--only-residual", 590 | action="store_true", 591 | help="Train only the residual network for transfer learning.") 592 | 593 | return parser 594 | -------------------------------------------------------------------------------- /CGAT/message_changed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch_scatter import scatter_max, scatter_add, \ 5 | scatter_mean 6 | 7 | """ 8 | MIT License 9 | Copyright (c) 2019-2020 Rhys Goodall 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in all 19 | copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27 | SOFTWARE.
28 | """
29 |
30 |
31 | class SimpleNetwork(nn.Module):
32 | """
33 | Simple Feed Forward Neural Network
34 | """
35 |
36 | def __init__(self, input_dim, output_dim, hidden_layer_dims):
37 | """
38 | Inputs
39 | ----------
40 | input_dim: int
41 | output_dim: int
42 | hidden_layer_dims: list(int)
43 |
44 | """
45 | super(SimpleNetwork, self).__init__()
46 |
47 | dims = [input_dim] + hidden_layer_dims
48 | # print(dims, output_dim)
49 | # print(dims)
50 |
51 | self.fcs = nn.ModuleList([nn.Linear(dims[i], dims[i + 1])
52 | for i in range(len(dims) - 1)])
53 | self.acts = nn.ModuleList([nn.LeakyReLU()
54 | for _ in range(len(dims) - 1)])
55 |
56 | self.fc_out = nn.Linear(dims[-1], output_dim)
57 |
58 | def forward(self, fea):
59 | for fc, act in zip(self.fcs, self.acts):
60 | # print('fea',fea.shape)
61 | fea = act(fc(fea))
62 |
63 | return self.fc_out(fea)
64 |
65 | def __repr__(self):
66 | return '{}'.format(self.__class__.__name__)
67 |
68 |
69 | class Rezero(nn.Module):
# learnable scalar gate, initialized at zero, so a gated branch starts at
# zero and the surrounding residual block starts as the identity
70 | def __init__(self):
71 | super().__init__()
72 | self.alpha = nn.Parameter(torch.zeros(1))
73 |
74 | def forward(self, x):
75 | return self.alpha * x
76 |
77 | def __repr__(self):
78 | return '{}'.format(self.__class__.__name__)
79 |
80 |
81 | class ResidualNetwork(nn.Module):
82 | """
83 | Feed forward Residual Neural Network
84 | """
85 |
86 | def __init__(
87 | self,
88 | input_dim,
89 | output_dim,
90 | hidden_layer_dims,
91 | if_rezero=False):
92 | """
93 | Inputs
94 | ----------
95 | input_dim: int
96 | output_dim: int
97 | hidden_layer_dims: list(int)
98 |
99 | """
100 | super(ResidualNetwork, self).__init__()
101 |
102 | dims = [input_dim] + hidden_layer_dims
103 |
104 | self.fcs = nn.ModuleList([nn.Linear(dims[i], dims[i + 1])
105 | for i in range(len(dims) - 1)])
106 | # self.bns = nn.ModuleList([nn.BatchNorm1d(dims[i+1])
107 | # for i in range(len(dims)-1)])
108 | self.res_fcs = nn.ModuleList([nn.Linear(dims[i], dims[i + 1], bias=False)
109 | if (dims[i] != dims[i + 1])
110 | else nn.Identity()
111 | for i in range(len(dims) - 1)])
112 | self.acts = nn.ModuleList([nn.ReLU() for _ in range(len(dims) - 1)])
113 |
114 | self.fc_out = nn.Linear(dims[-1], output_dim)
115 | self.if_rezero = if_rezero
116 | if (self.if_rezero):
117 | self.rezeros = nn.ModuleList(
118 | [Rezero() for _ in range(len(dims) - 1)])
119 |
120 | def forward(self, fea, *, last_layer=True):
121 | # for fc, bn, res_fc, act in zip(self.fcs, self.bns,
122 | # self.res_fcs, self.acts):
123 | # fea = act(bn(fc(fea)))+res_fc(fea)
124 | if (not self.if_rezero):
125 | for fc, res_fc, act in zip(self.fcs, self.res_fcs, self.acts):
126 | fea = act(fc(fea)) + res_fc(fea)
127 | else:
128 | for fc, res_fc, act, rez in zip(
129 | self.fcs, self.res_fcs, self.acts, self.rezeros):
130 | fea = rez(act(fc(fea))) + res_fc(fea)
131 |
132 | if last_layer:
133 | return self.fc_out(fea)
134 | else:
135 | return fea
136 |
137 | def __repr__(self):
138 | return '{}'.format(self.__class__.__name__)
139 |
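# Hedged sketch of how ResidualNetwork behaves with if_rezero=True: every
# residual branch is scaled by a zero-initialized Rezero alpha, so right after
# construction the blocks reduce to their skip connections. Dimensions are
# illustrative only.
if __name__ == '__main__':
    net = ResidualNetwork(input_dim=8, output_dim=1,
                          hidden_layer_dims=[8, 8], if_rezero=True)
    x = torch.randn(4, 8)
    hidden = net(x, last_layer=False)
    # all alphas start at zero and the dims match, so res_fcs are nn.Identity
    assert torch.allclose(hidden, x)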
-------------------------------------------------------------------------------- /CGAT/predict.py: --------------------------------------------------------------------------------
1 | from lightning_module import LightningModel
2 | from data import CompositionData
3 | import torch
4 | from torch.utils.data import DataLoader
5 | import pickle, gzip as gz
6 |
7 | # example script for predicting data with a trained model
8 |
9 |
10 | model_path = 'example.ckpt'
11 | model = LightningModel.load_from_checkpoint(model_path, train=False)
12 | # set train=False to avoid reloading of large training datasets
13 |
14 | data_path = 'example_test_data.pickle.gz'
15 | dataset = CompositionData(
16 | data=data_path,
17 | fea_path=model.hparams.fea_path,
18 | max_neighbor_number=model.hparams.max_nbr)
19 |
20 | params = {"batch_size": 5000,
21 | "pin_memory": False,
22 | "shuffle": False,
23 | "drop_last": False
24 | }
25 |
26 |
27 | def collate_fn(datalist):
28 | return datalist
29 |
30 |
31 | model = model.cuda()
32 | dataloader = DataLoader(dataset, collate_fn=collate_fn, **params)
33 |
34 | prediction_list = []
35 | for i, batch in enumerate(dataloader):
36 | with torch.no_grad():
37 | print(i)
38 | prediction_list.append(model.evaluate(batch)[2])  # evaluate returns a 5-tuple: output, log_std, pred, target, target_norm
39 |
40 | pickle.dump(torch.cat(prediction_list), gz.open('predictions.pickle.gz', 'wb'))
41 |
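# Hedged reload sketch: the saved object is a single concatenated tensor of
# denormalized predictions, one row per structure in dataloader order.
reloaded = pickle.load(gz.open('predictions.pickle.gz', 'rb'))
print(reloaded.shape)  # e.g. torch.Size([n_structures, 1])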
-------------------------------------------------------------------------------- /CGAT/prepare_data.py: --------------------------------------------------------------------------------
1 | import gzip as gz
2 | import os
3 | import argparse
4 |
5 | import numpy as np
6 | import warnings
7 | import torch
8 | import pickle
9 | from torch.utils.data import Dataset, DataLoader
10 | from tqdm import tqdm
11 | from .roost_message import LoadFeaturiser
12 |
13 |
14 | def build_dataset_prepare(data,
15 | target_property=["e_above_hull", 'e_form'],
16 | radius=18.0,
17 | fea_path="../embeddings/matscholar-embedding.json",
18 | max_neighbor_number=24):
19 | """Calculates features for lists of pickled and gzipped ComputedEntry objects (either a path to the file or the file itself) and returns a dictionary with all necessary inputs. If the data has no target values, the targets are set to -1e8.
20 | target_property should always be a list of target properties."""
21 |
22 | def tensor2numpy(l):
23 | """recursively convert torch Tensors into numpy arrays"""
24 | if isinstance(l, torch.Tensor):
25 | return l.numpy()
26 | elif isinstance(l, str) or isinstance(l, int) or isinstance(l, float):
27 | return l
28 | elif isinstance(l, list) or isinstance(l, tuple):
29 | return np.asarray([tensor2numpy(i) for i in l], dtype=object)
30 | elif isinstance(l, dict):
31 | npdict = {}
32 | for name, val in l.items():
33 | npdict[name] = tensor2numpy(val)
34 | return npdict
35 | else:
36 | return None  # unsupported type; this will raise an error later on
37 |
38 | d = CompositionDataPrepare(data,
39 | fea_path=fea_path,
40 | target_property=target_property,
41 | max_neighbor_number=max_neighbor_number,
42 | radius=radius)
43 |
44 | loader = DataLoader(d, batch_size=1)
45 |
46 | input1_ = []
47 | input2_ = []
48 | input3_ = []
49 | comps_ = []
50 | batch_comp_ = []
51 | if type(target_property) == list:
52 | target_ = {}
53 | for name in target_property:
54 | target_[name] = []
55 | else:
56 | target_ = []
57 | batch_ids_ = []
58 |
59 | for input_, target, batch_comp, batch_ids in tqdm(loader):
60 | if len(input_) == 1:  # remove compounds with too few neighbors
61 | continue
62 | input1_.append(input_[0])
63 | comps_.append(input_[1])
64 | input2_.append(input_[2])
65 | input3_.append(input_[3])
66 | if isinstance(target_property, list):
67 | for name in target_property:
68 | target_[name].append(target[name])
69 | else:
70 | target_.append(target)
71 |
72 | batch_comp_.append(batch_comp)
73 | batch_ids_.append(batch_ids)
74 |
75 | input1_ = tensor2numpy(input1_)
76 | input2_ = tensor2numpy(input2_)
77 | input3_ = tensor2numpy(input3_)
78 |
79 | n = input1_[0].shape[0]
80 | shape = input1_.shape
81 | if len(shape) > 2:
82 | i1 = np.empty(shape=(1, shape[0]), dtype=object)
83 | i2 = np.empty(shape=(1, shape[0]), dtype=object)
84 | i3 = np.empty(shape=(1, shape[0]), dtype=object)
85 | i1[:, :, ] = [[input1_[l] for l in range(shape[0])]]
86 | input1_ = i1
87 | i2[:, :, ] = [[input2_[l] for l in range(shape[0])]]
88 | input2_ = i2
89 | i3[:, :, ] = [[input3_[l] for l in range(shape[0])]]
90 | input3_ = i3
91 |
92 | inputs_ = np.vstack((input1_, input2_, input3_))
93 |
94 | return {'input': inputs_,
95 | 'batch_ids': batch_ids_,
96 | 'batch_comp': tensor2numpy(batch_comp_),
97 | 'target': tensor2numpy(target_),
98 | 'comps': tensor2numpy(comps_)}
99 |
100 |
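# Hedged sketch of the returned structure (see main() below for the CLI
# wrapper that writes it to disk); the file name is illustrative:
#   features = build_dataset_prepare('example.pickle.gz',
#                                    target_property=['e_above_hull', 'e_form'])
#   features['input']    # stacked nbr_fea / self_fea_idx / nbr_fea_idx arrays
#   features['target']   # dict: property name -> per-atom target values
#   features['comps']    # element symbols per structure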
105 | """ 106 | 107 | def __init__(self, data, fea_path, target_property=['e-form'], radius=18.0, max_neighbor_number=24): 108 | """ 109 | """ 110 | if isinstance(data, str): 111 | self.data = pickle.load(gz.open(data, 'rb')) 112 | else: 113 | self.data = data 114 | self.radius = radius 115 | self.max_num_nbr = max_neighbor_number 116 | self.target_property = target_property 117 | assert os.path.exists(fea_path), "{} does not exist!".format(fea_path) 118 | self.atom_features = LoadFeaturiser(fea_path) 119 | self.atom_fea_dim = self.atom_features.embedding_size 120 | 121 | def __len__(self): 122 | return len(self.data) 123 | 124 | def __getitem__(self, idx): 125 | try: 126 | cry_id = self.data[idx].data['id'] 127 | except KeyError: 128 | cry_id = 'unknown' 129 | composition = self.data[idx].composition.formula 130 | try: 131 | crystal = self.data[idx].structure 132 | except: 133 | crystal = self.data[idx] 134 | 135 | elements = [element.specie.symbol for element in crystal] 136 | try: 137 | target = {} 138 | for name in self.target_property: 139 | target[name] = self.data[idx].data[name] / len(crystal.sites) 140 | except KeyError: 141 | target = {} 142 | warnings.warn('no target property') 143 | for name in self.target_property: 144 | target[name] = -1e8 145 | 146 | all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True) 147 | all_nbrs = [sorted(nbrs, key=lambda x: x[1])[0:self.max_num_nbr] for nbrs in all_nbrs] 148 | 149 | nbr_fea_idx, nbr_fea, self_fea_idx = [], [], [] 150 | for site, nbr in enumerate(all_nbrs): 151 | nbr_fea_idx_sub, nbr_fea_sub, self_fea_idx_sub = [], [], [] 152 | if len(nbr) < self.max_num_nbr: 153 | warnings.warn('{} does not contain enough neighbors in the cutoff to build the full graph. ' 154 | 'If it happens frequently, consider increase ' 155 | 'radius. Compound is not added to the feature set'.format(cry_id)) 156 | return (torch.ones(1)), torch.ones(1), torch.ones(1), torch.ones( 157 | 1) # fake input will be removed in build_dataset_prepare 158 | else: 159 | for n in range(self.max_num_nbr): 160 | self_fea_idx_sub.append(site) 161 | for j in range(self.max_num_nbr): 162 | nbr_fea_idx_sub.append(nbr[j][2]) 163 | index = 1 164 | dist = nbr[0][1] 165 | for j in range(self.max_num_nbr): 166 | if (nbr[j][1] > dist + 1e-8): 167 | dist = nbr[j][1] 168 | index += 1 169 | nbr_fea_sub.append(index) 170 | nbr_fea_idx.append(nbr_fea_idx_sub) 171 | nbr_fea.append(nbr_fea_sub) 172 | self_fea_idx.append(self_fea_idx_sub) 173 | return (nbr_fea, elements, self_fea_idx, nbr_fea_idx), \ 174 | target, composition, cry_id 175 | 176 | def get_targets(self, idx1, idx2): 177 | target = [] 178 | l = [] 179 | for el in idx2: 180 | l.append(self.data[el][self.target_property]) 181 | for el in idx1: 182 | target.append(l[el]) 183 | del l 184 | return torch.tensor(target).reshape(len(idx1), 1) 185 | 186 | 187 | def collate_batch(dataset_list): 188 | """ 189 | Collate a list of data and return a batch for predicting crystal 190 | properties. 191 | Parameters 192 | ---------- 193 | dataset_list: list of tuples for each data point. 
194 | (atom_fea, nbr_fea, nbr_fea_idx, target) 195 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 196 | nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len) 197 | nbr_fea_idx: torch.LongTensor shape (n_i, M) 198 | target: torch.Tensor shape (1, ) 199 | cif_id: str or int 200 | Returns 201 | ------- 202 | N = sum(n_i); N0 = sum(i) 203 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 204 | Atom features from atom type 205 | batch_nbr_fea: torch.Tensor shape (N, M, nbr_fea_len) 206 | Bond features of each atom"s M neighbors 207 | batch_nbr_fea_idx: torch.LongTensor shape (N, M) 208 | Indices of M neighbors of each atom 209 | crystal_atom_idx: list of torch.LongTensor of length N0 210 | Mapping from the crystal idx to atom idx 211 | target: torch.Tensor shape (N, 1) 212 | Target value for prediction 213 | batch_cif_ids: list 214 | """ 215 | # define the lists 216 | batch_atom_weights = [] 217 | batch_atom_fea = [] 218 | batch_nbr_fea = [] 219 | batch_self_fea_idx = [] 220 | batch_nbr_fea_idx = [] 221 | crystal_atom_idx = [] 222 | batch_target = [] 223 | batch_comp = [] 224 | batch_cry_ids = [] 225 | 226 | cry_base_idx = 0 227 | for i, ((atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx, _), 228 | target, comp, cry_id) in enumerate(dataset_list): 229 | # number of atoms for this crystal 230 | n_i = atom_fea.shape[0] 231 | # batch the features together 232 | # batch_atom_weights.append(atom_weights) 233 | batch_atom_fea.append(atom_fea) 234 | batch_nbr_fea.append(nbr_fea) 235 | # mappings from bonds to atoms 236 | batch_self_fea_idx.append(self_fea_idx + cry_base_idx) 237 | batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx) 238 | 239 | # mapping from atoms to crystals 240 | crystal_atom_idx.append(torch.tensor([i] * n_i)) 241 | 242 | # batch the targets and ids 243 | batch_target.append(target) 244 | batch_comp.append(comp) 245 | batch_cry_ids.append(cry_id) 246 | 247 | # increment the id counter 248 | cry_base_idx += n_i 249 | return (torch.cat(batch_atom_fea, dim=0), torch.cat(batch_nbr_fea, dim=0), torch.cat(batch_self_fea_idx, dim=0), 250 | torch.cat(batch_nbr_fea_idx, dim=0), torch.cat(crystal_atom_idx)), \ 251 | torch.cat(batch_target, dim=0), \ 252 | batch_comp, \ 253 | batch_cry_ids 254 | 255 | 256 | def collate_batch2(dataset_list): 257 | """ 258 | Collate a list of data and return a batch for predicting crystal 259 | properties. 260 | Parameters 261 | ---------- 262 | dataset_list: list of tuples for each data point. 
263 | (atom_fea, nbr_fea, nbr_fea_idx, target) 264 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 265 | nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len) 266 | nbr_fea_idx: torch.LongTensor shape (n_i, M) 267 | target: torch.Tensor shape (1, ) 268 | cif_id: str or int 269 | Returns 270 | ------- 271 | N = sum(n_i); N0 = sum(i) 272 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 273 | Atom features from atom type 274 | batch_nbr_fea: torch.Tensor shape (N, M, nbr_fea_len) 275 | Bond features of each atom"s M neighbors 276 | batch_nbr_fea_idx: torch.LongTensor shape (N, M) 277 | Indices of M neighbors of each atom 278 | crystal_atom_idx: list of torch.LongTensor of length N0 279 | Mapping from the crystal idx to atom idx 280 | target: torch.Tensor shape (N, 1) 281 | Target value for prediction 282 | batch_cif_ids: list 283 | """ 284 | # define the lists 285 | batch_atom_weights = [] 286 | batch_atom_fea = [] 287 | batch_nbr_fea = [] 288 | batch_self_fea_idx = [] 289 | batch_nbr_fea_idx = [] 290 | crystal_atom_idx = [] 291 | batch_target = [] 292 | batch_comp = [] 293 | batch_cry_ids = [] 294 | 295 | cry_base_idx = 0 296 | for i, ((nbr_fea, atom_fea, self_fea_idx, nbr_fea_idx), 297 | target, comp, cry_id) in enumerate(dataset_list): 298 | # number of atoms for this crystal 299 | n_i = atom_fea.shape[0] 300 | # batch the features together 301 | # batch_atom_weights.append(atom_weights) 302 | batch_atom_fea.append(atom_fea) 303 | batch_nbr_fea.append(nbr_fea) 304 | # mappings from bonds to atoms 305 | batch_self_fea_idx.append(self_fea_idx + cry_base_idx) 306 | batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx) 307 | 308 | # mapping from atoms to crystals 309 | crystal_atom_idx.append(torch.tensor([i] * n_i)) 310 | 311 | # batch the targets and ids 312 | batch_target.append(target) 313 | batch_comp.append(comp) 314 | batch_cry_ids.append(cry_id) 315 | 316 | # increment the id counter 317 | cry_base_idx += n_i 318 | return (torch.cat(batch_atom_fea, dim=0), torch.cat(batch_nbr_fea, dim=0), torch.cat(batch_self_fea_idx, dim=0), 319 | torch.cat(batch_nbr_fea_idx, dim=0), torch.cat(crystal_atom_idx)), \ 320 | torch.cat(batch_target, dim=0), \ 321 | batch_comp, \ 322 | batch_cry_ids 323 | 324 | 325 | class AverageMeter(object): 326 | """Computes and stores the average and current value""" 327 | 328 | def __init__(self): 329 | self.reset() 330 | 331 | def reset(self): 332 | self.val = 0 333 | self.avg = 0 334 | self.sum = 0 335 | self.count = 0 336 | 337 | def update(self, val, n=1): 338 | self.val = val 339 | self.sum += val * n 340 | self.count += n 341 | self.avg = self.sum / self.count 342 | 343 | 344 | class Normalizer(object): 345 | """Normalize a Tensor and restore it later. 
""" 346 | 347 | def __init__(self, log=False): 348 | """tensor is taken as a sample to calculate the mean and std""" 349 | self.mean = torch.tensor((0)) 350 | self.std = torch.tensor((1)) 351 | 352 | def fit(self, tensor, dim=0, keepdim=False): 353 | """tensor is taken as a sample to calculate the mean and std""" 354 | self.mean = torch.mean(tensor, dim, keepdim) 355 | self.std = torch.std(tensor, dim, keepdim) 356 | 357 | def norm(self, tensor): 358 | return (tensor - self.mean) / self.std 359 | 360 | def denorm(self, normed_tensor): 361 | return normed_tensor * self.std + self.mean 362 | 363 | def state_dict(self): 364 | return {"mean": self.mean, 365 | "std": self.std} 366 | 367 | def load_state_dict(self, state_dict): 368 | self.mean = state_dict["mean"].cpu() 369 | self.std = state_dict["std"].cpu() 370 | 371 | 372 | def main(): 373 | parser = argparse.ArgumentParser() 374 | parser.add_argument('--file', default='dcgat_1_000.pickle.gz') 375 | parser.add_argument('--source-dir', default='./') 376 | parser.add_argument('--target-dir', default='./') 377 | parser.add_argument('--target-file', default='dcgat_1_000_features.pickle.gz') 378 | args = parser.parse_args() 379 | test = build_dataset_prepare(os.path.join(args.source_dir, args.file)) 380 | if args.target_file is None: 381 | pickle.dump(test, gz.open(os.path.join(args.target_dir, os.path.basename(args.file)), 'wb')) 382 | else: 383 | pickle.dump(test, gz.open(os.path.join(args.target_dir, args.target_file), 'wb')) 384 | 385 | 386 | if __name__ == '__main__': 387 | main() 388 | -------------------------------------------------------------------------------- /CGAT/roost_message.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | Copyright (c) 2019-2020 Rhys Goodall 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | """ 23 | 24 | import torch 25 | import torch.nn as nn 26 | import numpy as np 27 | from torch_scatter import scatter_max, scatter_add, \ 28 | scatter_mean 29 | import json 30 | 31 | 32 | 33 | class Featuriser(object): 34 | """ 35 | Base class for featurising nodes and edges. 
36 | """ 37 | 38 | def __init__(self, allowed_types): 39 | self.allowed_types = set(allowed_types) 40 | self._embedding = {} 41 | 42 | def get_fea(self, key): 43 | assert key in self.allowed_types, "{} is not an allowed atom type".format( 44 | key) 45 | return self._embedding[key] 46 | 47 | def load_state_dict(self, state_dict): 48 | self._embedding = state_dict 49 | self.allowed_types = set(self._embedding.keys()) 50 | 51 | def get_state_dict(self): 52 | return self._embedding 53 | 54 | def embedding_size(self): 55 | return len(self._embedding[list(self._embedding.keys())[0]]) 56 | 57 | 58 | class LoadFeaturiser(Featuriser): 59 | """ 60 | Initialize atom feature vectors using a JSON file, which is a python 61 | dictionary mapping from element number to a list representing the 62 | feature vector of the element. 63 | 64 | Notes 65 | --------- 66 | For the specific composition net application the keys are concatenated 67 | strings of the form "NaCl" where the order of concatenation matters. 68 | This is done because the bond "ClNa" has the opposite dipole to "NaCl" 69 | so for a general representation we need to be able to asign different 70 | bond features for different directions on the multigraph. 71 | 72 | Parameters 73 | ---------- 74 | elem_embedding_file: str 75 | The path to the .json file 76 | """ 77 | 78 | def __init__(self, embedding_file): 79 | with open(embedding_file) as f: 80 | embedding = json.load(f) 81 | allowed_types = set(embedding.keys()) 82 | super(LoadFeaturiser, self).__init__(allowed_types) 83 | for key, value in embedding.items(): 84 | self._embedding[key] = np.array(value, dtype=float) 85 | 86 | 87 | 88 | class MessageLayer(nn.Module): 89 | """ 90 | Class defining the message passing operation on the composition graph 91 | """ 92 | 93 | def __init__(self, fea_len, num_heads=1): 94 | """ 95 | Inputs 96 | ---------- 97 | fea_len: int 98 | Number of elem hidden features. 
99 | """ 100 | super(MessageLayer, self).__init__() 101 | 102 | # Pooling and Output 103 | hidden_ele = [256] 104 | hidden_msg = [256] 105 | self.pooling = nn.ModuleList([WeightedAttention( 106 | gate_nn=SimpleNetwork(2 * fea_len, 1, hidden_ele), 107 | message_nn=SimpleNetwork(2 * fea_len, fea_len, hidden_msg), 108 | # message_nn=nn.Linear(2*fea_len, fea_len), 109 | # message_nn=nn.Identity(), 110 | ) for _ in range(num_heads)]) 111 | 112 | def forward(self, elem_weights, elem_in_fea, 113 | self_fea_idx, nbr_fea_idx): 114 | """ 115 | Forward pass 116 | Parameters 117 | ---------- 118 | N: Total number of elems (nodes) in the batch 119 | M: Total number of bonds (edges) in the batch 120 | C: Total number of crystals (graphs) in the batch 121 | Inputs 122 | ---------- 123 | elem_weights: Variable(torch.Tensor) shape (N,) 124 | The fractional weights of elems in their materials 125 | elem_in_fea: Variable(torch.Tensor) shape (N, elem_fea_len) 126 | Atom hidden features before message passing 127 | self_fea_idx: torch.Tensor shape (M,) 128 | Indices of M neighbours of each elem 129 | nbr_fea_idx: torch.Tensor shape (M,) 130 | Indices of M neighbours of each elem 131 | Returns 132 | ------- 133 | elem_out_fea: nn.Variable shape (N, elem_fea_len) 134 | Atom hidden features after message passing 135 | """ 136 | # construct the total features for passing 137 | elem_nbr_weights = elem_weights[nbr_fea_idx, :] 138 | elem_nbr_fea = elem_in_fea[nbr_fea_idx, :] 139 | elem_self_fea = elem_in_fea[self_fea_idx, :] 140 | fea = torch.cat([elem_self_fea, elem_nbr_fea], dim=1) 141 | #print('fea shape after cat',fea.shape) 142 | # sum selectivity over the neighbours to get elems 143 | head_fea = [] 144 | for attnhead in self.pooling: 145 | head_fea.append(attnhead(fea=fea, 146 | index=self_fea_idx, 147 | weights=elem_nbr_weights)) 148 | 149 | # # Concatenate 150 | # fea = torch.cat(head_fea, dim=1) 151 | fea = torch.mean(torch.stack(head_fea), dim=0) 152 | #print(fea.shape, elem_in_fea.shape) 153 | return fea + elem_in_fea 154 | 155 | def __repr__(self): 156 | return '{}'.format(self.__class__.__name__) 157 | 158 | 159 | class Roost(nn.Module): 160 | """ 161 | Create a neural network for predicting total material properties. 162 | The Roost model is comprised of a fully connected network 163 | and message passing graph layers. 164 | The message passing layers are used to determine a descriptor set 165 | for the fully connected network. Critically the graphs are used to 166 | represent (crystalline) materials in a structure agnostic manner 167 | but contain trainable parameters unlike other structure agnostic 168 | approaches. 169 | """ 170 | 171 | def __init__(self, orig_elem_fea_len, elem_fea_len, n_graph): 172 | """ 173 | Initialize CompositionNet. 174 | Parameters 175 | ---------- 176 | n_h: Number of hidden layers after pooling 177 | Inputs 178 | ---------- 179 | orig_elem_fea_len: int 180 | Number of elem features in the input. 
181 | elem_fea_len: int 182 | Number of hidden elem features in the graph layers 183 | n_graph: int 184 | Number of graph layers 185 | """ 186 | super(Roost, self).__init__() 187 | 188 | # apply linear transform to the input to get a trainable embedding 189 | self.embedding = nn.Linear(orig_elem_fea_len, elem_fea_len - 1) 190 | 191 | # create a list of Message passing layers 192 | 193 | msg_heads = 1 194 | self.graphs = nn.ModuleList( 195 | [MessageLayer(elem_fea_len, msg_heads) 196 | for i in range(n_graph)]) 197 | 198 | # define a global pooling function for materials 199 | mat_heads = 1 200 | mat_hidden = [256] 201 | # msg_hidden = [256] 202 | self.cry_pool = nn.ModuleList([WeightedAttention( 203 | gate_nn=SimpleNetwork(elem_fea_len, 1, mat_hidden), 204 | # message_nn=SimpleNetwork(elem_fea_len, 20, msg_hidden), 205 | # message_nn=nn.Linear(elem_fea_len, elem_fea_len), 206 | message_nn=nn.Identity(), 207 | ) for _ in range(mat_heads)]) 208 | 209 | # define an output neural network 210 | # out_hidden = [512, 256, 128, 64] 211 | 212 | def forward(self, elem_weights, orig_elem_fea, self_fea_idx, 213 | nbr_fea_idx, crystal_elem_idx): 214 | """ 215 | Forward pass 216 | Parameters 217 | ---------- 218 | N: Total number of elems (nodes) in the batch 219 | M: Total number of bonds (edges) in the batch 220 | C: Total number of crystals (graphs) in the batch 221 | Inputs 222 | ---------- 223 | orig_elem_fea: Variable(torch.Tensor) shape (N, orig_elem_fea_len) 224 | Atom features of each of the N elems in the batch 225 | self_fea_idx: torch.Tensor shape (M,) 226 | Indices of the elem each of the M bonds correspond to 227 | nbr_fea_idx: torch.Tensor shape (M,) 228 | Indices of of the neighbours of the M bonds connect to 229 | elem_bond_idx: list of torch.LongTensor of length C 230 | Mapping from the bond idx to elem idx 231 | crystal_elem_idx: list of torch.LongTensor of length C 232 | Mapping from the elem idx to crystal idx 233 | Returns 234 | ------- 235 | out: nn.Variable shape (C,) 236 | Atom hidden features after message passing 237 | """ 238 | 239 | # embed the original features into the graph layer description 240 | elem_fea = self.embedding(orig_elem_fea) 241 | 242 | # do this so that we can examine the embeddings without 243 | # influence of the weights 244 | #print(elem_fea.shape, elem_weights.shape) 245 | elem_fea = torch.cat([elem_fea, elem_weights], dim=1) 246 | 247 | # apply the graph message passing functions 248 | for graph_func in self.graphs: 249 | elem_fea = graph_func(elem_weights, elem_fea, 250 | self_fea_idx, nbr_fea_idx) 251 | 252 | # generate crystal features by pooling the elemental features 253 | head_fea = [] 254 | for attnhead in self.cry_pool: 255 | head_fea.append(attnhead(fea=elem_fea, 256 | index=crystal_elem_idx, 257 | weights=elem_weights)) 258 | 259 | crys_fea = torch.mean(torch.stack(head_fea), dim=0) 260 | # crys_fea = torch.cat(head_fea, dim=1) 261 | 262 | # apply neural network to map from learned features to target 263 | 264 | return crys_fea 265 | 266 | def __repr__(self): 267 | return '{}'.format(self.__class__.__name__) 268 | 269 | 270 | class WeightedMeanPooling(torch.nn.Module): 271 | """ 272 | mean pooling 273 | """ 274 | 275 | def __init__(self): 276 | super(WeightedMeanPooling, self).__init__() 277 | 278 | def forward(self, fea, index, weights): 279 | fea = weights * fea 280 | return scatter_mean(fea, index, dim=0) 281 | 282 | def __repr__(self): 283 | return '{}'.format(self.__class__.__name__) 284 | 285 | 286 | class WeightedAttention(nn.Module): 
287 | """ 288 | Weighted softmax attention layer 289 | """ 290 | 291 | def __init__(self, gate_nn, message_nn, num_heads=1): 292 | """ 293 | Inputs 294 | ---------- 295 | gate_nn: Variable(nn.Module) 296 | """ 297 | super(WeightedAttention, self).__init__() 298 | self.gate_nn = gate_nn 299 | self.message_nn = message_nn 300 | self.pow = torch.nn.Parameter(torch.randn((1))) 301 | 302 | def forward(self, fea, index, weights): 303 | """ forward pass """ 304 | 305 | gate = self.gate_nn(fea) 306 | 307 | gate = gate - scatter_max(gate, index, dim=0)[0][index] 308 | gate = (weights ** self.pow) * gate.exp() 309 | # gate = weights * gate.exp() 310 | # gate = gate.exp() 311 | gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-13) 312 | 313 | fea = self.message_nn(fea) 314 | # print(fea.shape) 315 | out = scatter_add(gate * fea, index, dim=0) 316 | # print(out.shape) 317 | return out 318 | 319 | def __repr__(self): 320 | return '{}(gate_nn={})'.format(self.__class__.__name__, 321 | self.gate_nn) 322 | 323 | 324 | class SimpleNetwork(nn.Module): 325 | """ 326 | Simple Feed Forward Neural Network 327 | """ 328 | 329 | def __init__(self, input_dim, output_dim, hidden_layer_dims): 330 | """ 331 | Inputs 332 | ---------- 333 | input_dim: int 334 | output_dim: int 335 | hidden_layer_dims: list(int) 336 | """ 337 | super(SimpleNetwork, self).__init__() 338 | 339 | dims = [input_dim] + hidden_layer_dims 340 | 341 | self.fcs = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) 342 | for i in range(len(dims) - 1)]) 343 | self.acts = nn.ModuleList([nn.LeakyReLU() 344 | for _ in range(len(dims) - 1)]) 345 | 346 | self.fc_out = nn.Linear(dims[-1], output_dim) 347 | 348 | def forward(self, fea): 349 | for fc, act in zip(self.fcs, self.acts): 350 | fea = act(fc(fea)) 351 | 352 | return self.fc_out(fea) 353 | 354 | def __repr__(self): 355 | return '{}'.format(self.__class__.__name__) 356 | 357 | 358 | class ResidualNetwork(nn.Module): 359 | """ 360 | Feed forward Residual Neural Network 361 | """ 362 | 363 | def __init__(self, input_dim, output_dim, hidden_layer_dims): 364 | """ 365 | Inputs 366 | ---------- 367 | input_dim: int 368 | output_dim: int 369 | hidden_layer_dims: list(int) 370 | """ 371 | super(ResidualNetwork, self).__init__() 372 | 373 | dims = [input_dim] + hidden_layer_dims 374 | 375 | self.fcs = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) 376 | for i in range(len(dims) - 1)]) 377 | # self.bns = nn.ModuleList([nn.BatchNorm1d(dims[i+1]) 378 | # for i in range(len(dims)-1)]) 379 | self.res_fcs = nn.ModuleList([nn.Linear(dims[i], dims[i + 1], bias=False) 380 | if (dims[i] != dims[i + 1]) 381 | else nn.Identity() 382 | for i in range(len(dims) - 1)]) 383 | self.acts = nn.ModuleList([nn.ReLU() for _ in range(len(dims) - 1)]) 384 | 385 | self.fc_out = nn.Linear(dims[-1], output_dim) 386 | 387 | def forward(self, fea): 388 | # for fc, bn, res_fc, act in zip(self.fcs, self.bns, 389 | # self.res_fcs, self.acts): 390 | # fea = act(bn(fc(fea)))+res_fc(fea) 391 | for fc, res_fc, act in zip(self.fcs, self.res_fcs, self.acts): 392 | fea = act(fc(fea)) + res_fc(fea) 393 | 394 | return self.fc_out(fea) 395 | 396 | def __repr__(self): 397 | return '{}'.format(self.__class__.__name__) 398 | 399 | 400 | def collate_batch(dataset_list): 401 | """ 402 | Collate a list of data and return a batch for predicting crystal 403 | properties. 404 | Parameters 405 | ---------- 406 | dataset_list: list of tuples for each data point. 
407 | (atom_fea, nbr_fea, nbr_fea_idx, target) 408 | atom_fea: torch.Tensor shape (n_i, atom_fea_len) 409 | nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len) 410 | nbr_fea_idx: torch.LongTensor shape (n_i, M) 411 | target: torch.Tensor shape (1, ) 412 | cif_id: str or int 413 | Returns 414 | ------- 415 | N = sum(n_i); N0 = sum(i) 416 | batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) 417 | Atom features from atom type 418 | batch_nbr_fea: torch.Tensor shape (N, M, nbr_fea_len) 419 | Bond features of each atom"s M neighbors 420 | batch_nbr_fea_idx: torch.LongTensor shape (N, M) 421 | Indices of M neighbors of each atom 422 | crystal_atom_idx: list of torch.LongTensor of length N0 423 | Mapping from the crystal idx to atom idx 424 | target: torch.Tensor shape (N, 1) 425 | Target value for prediction 426 | batch_cif_ids: list 427 | """ 428 | # define the lists 429 | batch_atom_weights = [] 430 | batch_atom_fea = [] 431 | batch_self_fea_idx = [] 432 | batch_nbr_fea_idx = [] 433 | crystal_atom_idx = [] 434 | cry_base_idx = 0 435 | for i, (atom_weights, atom_fea, self_fea_idx, 436 | nbr_fea_idx) in enumerate(dataset_list): 437 | # number of atoms for this crystal 438 | n_i = atom_fea.shape[0] 439 | 440 | # batch the features together 441 | batch_atom_weights.append(atom_weights) 442 | batch_atom_fea.append(atom_fea) 443 | 444 | # mappings from bonds to atoms 445 | batch_self_fea_idx.append(self_fea_idx + cry_base_idx) 446 | batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx) 447 | 448 | # mapping from atoms to crystals 449 | crystal_atom_idx.append(torch.tensor([i] * n_i)) 450 | 451 | # batch the targets and ids 452 | cry_base_idx += n_i 453 | # print('hal',torch.cat(batch_atom_weights, dim=0).shape,torch.cat(batch_atom_fea, dim=0).shape ) 454 | return (torch.cat(batch_atom_weights, dim=0).view(-1, 1), 455 | torch.cat(batch_atom_fea, dim=0), 456 | torch.cat(batch_self_fea_idx, dim=0), 457 | torch.cat(batch_nbr_fea_idx, dim=0), 458 | torch.cat(crystal_atom_idx)) 459 | -------------------------------------------------------------------------------- /CGAT/test.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import ModelCheckpoint 2 | import os 3 | from argparse import ArgumentParser 4 | import os 5 | import gc 6 | import datetime 7 | import numpy as np 8 | import pandas as pd 9 | 10 | import numpy as np 11 | import torch 12 | 13 | import pytorch_lightning as pl 14 | from lightning_module import LightningModel 15 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 16 | SEED = 1 17 | torch.manual_seed(SEED) 18 | np.random.seed(SEED) 19 | 20 | 21 | def main(hparams): 22 | """ 23 | testing routine 24 | Args: 25 | hparams: checkpoint of the model to be tested and gpu, parallel backend etc., 26 | defined in the argument parser in if __name__ == '__main__': 27 | Returns: 28 | """ 29 | checkpoint_path=hparams.ckp 30 | model = LightningModel.load_from_checkpoint( 31 | checkpoint_path=checkpoint_path, 32 | tags_csv= hparams.hparams, 33 | ) 34 | 35 | trainer = pl.Trainer( 36 | gpus=[hparams.first_gpu+el for el in range(hparams.gpus)], 37 | distributed_backend=hparams.distributed_backend, 38 | ) 39 | 40 | trainer.test(model) 41 | 42 | if __name__ == '__main__': 43 | 44 | root_dir = os.path.dirname(os.path.realpath(__file__)) 45 | parent_parser = ArgumentParser(add_help=False) 46 | 47 | parent_parser.add_argument( 48 | '--gpus', 49 | type=int, 50 | default=4, 51 | help='how many gpus' 52 | ) 53 | 
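# The flags below select the checkpoint and its hyperparameter file; a typical
# invocation might look like (sketch, paths are hypothetical):
#   python test.py --ckp tb_logs/example.ckpt --hparams tb_logs/meta_tags.csv --gpus 1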
parent_parser.add_argument( 54 | '--distributed_backend', 55 | type=str, 56 | default='ddp', 57 | help='supports three options dp, ddp, ddp2' 58 | ) 59 | parent_parser.add_argument( 60 | '--amp_optimization', 61 | type=str, 62 | default='00', 63 | help="mixed precision format, default 00 (32), 01 mixed, 02 closer to 16, should not be used during testing" 64 | ) 65 | parent_parser.add_argument( 66 | '--first-gpu', 67 | type=int, 68 | default=0, 69 | help='gpu number to use [first_gpu, ..., first_gpu+gpus]' 70 | ) 71 | parent_parser.add_argument( 72 | '--ckp', 73 | type=str, 74 | default='', 75 | help='ckp path, if left empty no checkpoint is used' 76 | ) 77 | parent_parser.add_argument( 78 | '--hparams', 79 | type=str, 80 | default='', 81 | help='path for hparams of ckp if left empty no checkpoint is used' 82 | ) 83 | parent_parser.add_argument("--test", 84 | action="store_true", 85 | help="whether to train or test" 86 | ) 87 | 88 | 89 | # each LightningModule defines arguments relevant to it 90 | parser = LightningModel.add_model_specific_args(parent_parser) 91 | hyperparams = parser.parse_args() 92 | 93 | print(hyperparams) 94 | main(hyperparams) 95 | -------------------------------------------------------------------------------- /CGAT/test_prepare_data.py: -------------------------------------------------------------------------------- 1 | 2 | import gzip as gz 3 | import sys, os 4 | import argparse 5 | import functools 6 | import json 7 | import numpy as np 8 | import pandas as pd 9 | import warnings 10 | import torch 11 | import pickle 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | 15 | class Featuriser(object): 16 | """ 17 | Base class for featurising nodes and edges. 18 | """ 19 | 20 | def __init__(self, allowed_types): 21 | self.allowed_types = set(allowed_types) 22 | self._embedding = {} 23 | 24 | def get_fea(self, key): 25 | assert key in self.allowed_types, "{} is not an allowed atom type".format( 26 | key) 27 | return self._embedding[key] 28 | 29 | def load_state_dict(self, state_dict): 30 | self._embedding = state_dict 31 | self.allowed_types = set(self._embedding.keys()) 32 | 33 | def get_state_dict(self): 34 | return self._embedding 35 | 36 | def embedding_size(self): 37 | return len(self._embedding[list(self._embedding.keys())[0]]) 38 | 39 | 40 | class LoadFeaturiser(Featuriser): 41 | """ 42 | Initialize atom feature vectors using a JSON file, which is a python 43 | dictionary mapping from element number to a list representing the 44 | feature vector of the element. 45 | 46 | Notes 47 | --------- 48 | For the specific composition net application the keys are concatenated 49 | strings of the form "NaCl" where the order of concatenation matters. 50 | This is done because the bond "ClNa" has the opposite dipole to "NaCl" 51 | so for a general representation we need to be able to asign different 52 | bond features for different directions on the multigraph. 
53 | 54 | Parameters 55 | ---------- 56 | elem_embedding_file: str 57 | The path to the .json file 58 | """ 59 | 60 | def __init__(self, embedding_file): 61 | with open(embedding_file) as f: 62 | embedding = json.load(f) 63 | allowed_types = set(embedding.keys()) 64 | super(LoadFeaturiser, self).__init__(allowed_types) 65 | for key, value in embedding.items(): 66 | self._embedding[key] = np.array(value, dtype=float) 67 | 68 | 69 | 70 | def build_dataset_prepare(data, 71 | target_property=['e_above_hull_new','e-form'], 72 | fea_path = "../embeddings/matscholar-embedding.json", 73 | id='layered_perovskites'): 74 | """Use to calculate features for lists of pickle and gzipped ComputedEntry pickles, 75 | returns dictionary with all necessary inputs. Use for lists with all materials having 76 | the same number of atoms""" 77 | 78 | def tensor2numpy(l): 79 | """recursively convert torch Tensors into numpy arrays""" 80 | if isinstance(l, torch.Tensor): 81 | return l.numpy() 82 | elif isinstance(l, str) or isinstance(l, int) or isinstance(l, float): 83 | return l 84 | elif isinstance(l, list) or isinstance(l, tuple): 85 | return np.asarray([tensor2numpy(i) for i in l]) 86 | elif isinstance(l, dict): 87 | npdict = {} 88 | for name, val in l.items(): 89 | npdict[name]= tensor2numpy(val) 90 | return npdict 91 | else: 92 | return None # this will give an error later on 93 | 94 | d = CompositionDataPrepare(data=data, 95 | fea_path=fea_path, 96 | target_property=target_property) 97 | loader = DataLoader(d, batch_size=1) 98 | 99 | input1_ = [] 100 | input2_ = [] 101 | input3_ = [] 102 | comps_ = [] 103 | batch_comp_ = [] 104 | if type(target_property)==list: 105 | target_ = {} 106 | for name in target_property: 107 | target_[name] = [] 108 | else: 109 | target_ = [] 110 | batch_ids_=[] 111 | 112 | for input_, target, batch_comp, batch_ids in loader: 113 | input1_.append(input_[0]) 114 | comps_.append(input_[1]) 115 | input2_.append(input_[2]) 116 | input3_.append(input_[3]) 117 | if isinstance(target_property,list): 118 | for name in target_property: 119 | target_[name].append(target[name]) 120 | else: 121 | target_.append(target) 122 | batch_comp_.append(batch_comp) 123 | batch_ids_.append(batch_ids) 124 | 125 | input1_ = tensor2numpy(input1_) 126 | input2_ = tensor2numpy(input2_) 127 | input3_ = tensor2numpy(input3_) 128 | 129 | n = input1_[0].shape[0] 130 | shape = input1_.shape 131 | try: 132 | input1_ = np.reshape(input1_,(1, shape[0], n, 24)) 133 | input2_ = np.reshape(input2_,(1, shape[0], n, 24)) 134 | input3_ = np.reshape(input3_,(1, shape[0], n, 24)) 135 | 136 | except: 137 | input1_=np.asarray(input1_) 138 | input2_=np.asarray(input2_) 139 | input3_=np.asarray(input3_) 140 | 141 | inputs_ = np.vstack((input1_, input2_, input3_)) 142 | 143 | return {'input': inputs_, 144 | 'batch_ids': batch_ids_, 145 | 'batch_comp': tensor2numpy(batch_comp_), 146 | 'target': tensor2numpy(target_), 147 | 'comps': tensor2numpy(comps_)} 148 | 149 | class CompositionDataPrepare(Dataset): 150 | """ 151 | The CompositionData dataset is a wrapper for a dataset data points are 152 | automatically constructed from composition strings. 
153 | """ 154 | def __init__(self, data, fea_path, target_property='e-form', radius=18.0, max_neighbor_number = 24): 155 | """ 156 | """ 157 | self.data = pickle.load(gz.open(data,'rb')) 158 | print(len(self.data)) 159 | self.radius =radius 160 | self.max_num_nbr = max_neighbor_number 161 | self.target_property = target_property 162 | assert os.path.exists(fea_path), "{} does not exist!".format(fea_path) 163 | self.atom_features = LoadFeaturiser(fea_path) 164 | self.atom_fea_dim = self.atom_features.embedding_size 165 | 166 | def __len__(self): 167 | return len(self.data) 168 | 169 | def __getitem__(self, idx): 170 | cry_id ='ml_perovskites' 171 | composition = self.data[idx].composition.formula 172 | try: 173 | crystal = self.data[idx].structure 174 | except: 175 | crystal = self.data[idx] 176 | elements = [element.specie.symbol for element in crystal] 177 | if isinstance(self.target_property,tuple): 178 | target = self.data[idx].as_dict()[self.target_property[0]][self.target_property[1]] 179 | elif isinstance(self.target_property,list): 180 | target = {} 181 | for name in self.target_property: 182 | target[name] = self.data[idx].data[name]/len(crystal.sites) 183 | else: 184 | target = self.data[idx].data[self.target_property]/len(crystal.sites) 185 | #target = target/len(crystal.sites) 186 | 187 | all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True) 188 | all_nbrs = [sorted(nbrs, key=lambda x: x[1])[0:self.max_num_nbr] for nbrs in all_nbrs] 189 | 190 | nbr_fea_idx, nbr_fea, self_fea_idx = [], [], [] 191 | for site, nbr in enumerate(all_nbrs): 192 | nbr_fea_idx_sub, nbr_fea_sub, self_fea_idx_sub= [],[],[] 193 | if len(nbr) < self.max_num_nbr: 194 | warnings.warn('{} not find enough neighbors to build graph. ' 195 | 'If it happens frequently, consider increase ' 196 | 'radius.'.format(cry_id)) 197 | for n in range(len(nbr)): 198 | self_fea_idx_sub.append(site) 199 | for j in range(len(nbr)): 200 | nbr_fea_idx_sub.append(nbr[j][2]) 201 | index = 1 202 | try: 203 | dist = nbr[0][1] 204 | except: 205 | print('no neighbor', cry_id) 206 | for el in nbr: 207 | if(el[1]>dist+1e-8): 208 | dist = el[1] 209 | index+=1 210 | nbr_fea_sub.append(index) 211 | else: 212 | for n in range(self.max_num_nbr): 213 | self_fea_idx_sub.append(site) 214 | for j in range(self.max_num_nbr): 215 | nbr_fea_idx_sub.append(nbr[j][2]) 216 | index = 1 217 | dist = nbr[0][1] 218 | for j in range(self.max_num_nbr): 219 | if(nbr[j][1]>dist+1e-8): 220 | dist = nbr[j][1] 221 | index+=1 222 | nbr_fea_sub.append(index) 223 | nbr_fea_idx.append(nbr_fea_idx_sub) 224 | nbr_fea.append(nbr_fea_sub) 225 | self_fea_idx.append(self_fea_idx_sub) 226 | return (nbr_fea, elements, self_fea_idx, nbr_fea_idx), \ 227 | target, composition, cry_id 228 | 229 | # def get_targets(self, idx1, idx2): 230 | # target = [] 231 | # l=[] 232 | # for el in idx2: 233 | # l.append(self.data[el][self.target_property]) 234 | # for el in idx1: 235 | # target.append(l[el]) 236 | # del l 237 | # return torch.tensor(target).reshape(len(idx1),1) 238 | 239 | 240 | 241 | def collate_batch(dataset_list): 242 | """ 243 | Collate a list of data and return a batch for predicting crystal 244 | properties. 245 | Parameters 246 | ---------- 247 | dataset_list: list of tuples for each data point. 
248 |         ((atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx, _), target, comp, cry_id)
249 |         atom_fea: torch.Tensor shape (n_i, atom_fea_len)
250 |         nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len)
251 |         self_fea_idx, nbr_fea_idx: torch.LongTensor shape (n_i, M)
252 |         target: torch.Tensor shape (1, )
253 |         cry_id: str or int
254 |     Returns
255 |     -------
256 |     N = sum(n_i); N0 = number of crystals in the batch
257 |     batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
258 |         Atom features from atom type
259 |     batch_nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
260 |         Bond features of each atom's M neighbors
261 |     batch_nbr_fea_idx: torch.LongTensor shape (N, M)
262 |         Indices of M neighbors of each atom
263 |     crystal_atom_idx: torch.LongTensor shape (N, )
264 |         Mapping from the crystal idx to atom idx
265 |     target: torch.Tensor shape (N0, ...)
266 |         Target value for prediction
267 |     batch_cry_ids: list
268 |     """
269 |     # define the lists
270 |     batch_atom_weights = []
271 |     batch_atom_fea = []
272 |     batch_nbr_fea = []
273 |     batch_self_fea_idx = []
274 |     batch_nbr_fea_idx = []
275 |     crystal_atom_idx = []
276 |     batch_target = []
277 |     batch_comp = []
278 |     batch_cry_ids = []
279 | 
280 |     cry_base_idx = 0
281 |     for i, ((atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx, _),
282 |             target, comp, cry_id) in enumerate(dataset_list):
283 |         # number of atoms for this crystal
284 |         n_i = atom_fea.shape[0]
285 |         # batch the features together
286 |         # batch_atom_weights.append(atom_weights)
287 |         batch_atom_fea.append(atom_fea)
288 |         batch_nbr_fea.append(nbr_fea)
289 |         # mappings from bonds to atoms
290 |         batch_self_fea_idx.append(self_fea_idx + cry_base_idx)
291 |         batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx)
292 | 
293 |         # mapping from atoms to crystals
294 |         crystal_atom_idx.append(torch.tensor([i] * n_i))
295 | 
296 |         # batch the targets and ids
297 |         batch_target.append(target)
298 |         batch_comp.append(comp)
299 |         batch_cry_ids.append(cry_id)
300 | 
301 |         # increment the id counter
302 |         cry_base_idx += n_i
303 |     return (torch.cat(batch_atom_fea, dim=0), torch.cat(batch_nbr_fea, dim=0), torch.cat(batch_self_fea_idx, dim=0), torch.cat(batch_nbr_fea_idx, dim=0), torch.cat(crystal_atom_idx)), \
304 |         torch.cat(batch_target, dim=0), \
305 |         batch_comp, \
306 |         batch_cry_ids
307 | 
308 | 
309 | def collate_batch2(dataset_list):
310 |     """
311 |     Collate a list of data and return a batch for predicting crystal
312 |     properties. Identical to collate_batch except for the tuple order.
313 |     Parameters
314 |     ----------
315 |     dataset_list: list of tuples for each data point.
316 |         ((nbr_fea, atom_fea, self_fea_idx, nbr_fea_idx), target, comp, cry_id)
317 |         atom_fea: torch.Tensor shape (n_i, atom_fea_len)
318 |         nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len)
319 |         self_fea_idx, nbr_fea_idx: torch.LongTensor shape (n_i, M)
320 |         target: torch.Tensor shape (1, )
321 |         cry_id: str or int
322 |     Returns
323 |     -------
324 |     N = sum(n_i); N0 = number of crystals in the batch
325 |     batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
326 |         Atom features from atom type
327 |     batch_nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
328 |         Bond features of each atom's M neighbors
329 |     batch_nbr_fea_idx: torch.LongTensor shape (N, M)
330 |         Indices of M neighbors of each atom
331 |     crystal_atom_idx: torch.LongTensor shape (N, )
332 |         Mapping from the crystal idx to atom idx
333 |     target: torch.Tensor shape (N0, ...)
334 |         Target value for prediction
335 |     batch_cry_ids: list
336 |     """
337 |     # define the lists
338 |     batch_atom_weights = []
339 |     batch_atom_fea = []
340 |     batch_nbr_fea = []
341 |     batch_self_fea_idx = []
342 |     batch_nbr_fea_idx = []
343 |     crystal_atom_idx = []
344 |     batch_target = []
345 |     batch_comp = []
346 |     batch_cry_ids = []
347 | 
348 |     cry_base_idx = 0
349 |     for i, ((nbr_fea, atom_fea, self_fea_idx, nbr_fea_idx),
350 |             target, comp, cry_id) in enumerate(dataset_list):
351 |         # number of atoms for this crystal
352 |         n_i = atom_fea.shape[0]
353 |         # batch the features together
354 |         # batch_atom_weights.append(atom_weights)
355 |         batch_atom_fea.append(atom_fea)
356 |         batch_nbr_fea.append(nbr_fea)
357 |         # mappings from bonds to atoms
358 |         batch_self_fea_idx.append(self_fea_idx + cry_base_idx)
359 |         batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx)
360 | 
361 |         # mapping from atoms to crystals
362 |         crystal_atom_idx.append(torch.tensor([i] * n_i))
363 | 
364 |         # batch the targets and ids
365 |         batch_target.append(target)
366 |         batch_comp.append(comp)
367 |         batch_cry_ids.append(cry_id)
368 | 
369 |         # increment the id counter
370 |         cry_base_idx += n_i
371 |     return (torch.cat(batch_atom_fea, dim=0), torch.cat(batch_nbr_fea, dim=0), torch.cat(batch_self_fea_idx, dim=0), torch.cat(batch_nbr_fea_idx, dim=0), torch.cat(crystal_atom_idx)), \
372 |         torch.cat(batch_target, dim=0), \
373 |         batch_comp, \
374 |         batch_cry_ids
375 | 
376 | class AverageMeter(object):
377 |     """Computes and stores the average and current value"""
378 |     def __init__(self):
379 |         self.reset()
380 | 
381 |     def reset(self):
382 |         self.val = 0
383 |         self.avg = 0
384 |         self.sum = 0
385 |         self.count = 0
386 | 
387 |     def update(self, val, n=1):
388 |         self.val = val
389 |         self.sum += val * n
390 |         self.count += n
391 |         self.avg = self.sum / self.count
392 | 
393 | 
394 | class Normalizer(object):
395 |     """Normalize a Tensor and restore it later.
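    A usage sketch (editor addition): fit on the training targets, train on the
    normalized values, and map predictions back afterwards:

        normalizer = Normalizer()
        normalizer.fit(train_targets)
        normed = normalizer.norm(train_targets)
        predictions = normalizer.denorm(normed_predictions)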
""" 396 | def __init__(self, log=False): 397 | """tensor is taken as a sample to calculate the mean and std""" 398 | self.mean = torch.tensor((0)) 399 | self.std = torch.tensor((1)) 400 | 401 | def fit(self, tensor, dim=0, keepdim=False): 402 | """tensor is taken as a sample to calculate the mean and std""" 403 | self.mean = torch.mean(tensor, dim, keepdim) 404 | self.std = torch.std(tensor, dim, keepdim) 405 | 406 | def norm(self, tensor): 407 | return (tensor - self.mean) / self.std 408 | 409 | def denorm(self, normed_tensor): 410 | return normed_tensor * self.std + self.mean 411 | 412 | def state_dict(self): 413 | return {"mean": self.mean, 414 | "std": self.std} 415 | 416 | def load_state_dict(self, state_dict): 417 | self.mean = state_dict["mean"].cpu() 418 | self.std = state_dict["std"].cpu() 419 | 420 | file = sys.argv[1] 421 | test = build_dataset_prepare(file) 422 | pickle.dump(test, gz.open('features/features_'+file.replace('../../data_0921/',''),'wb')) 423 | -------------------------------------------------------------------------------- /CGAT/train.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import ModelCheckpoint 2 | import os 3 | from argparse import ArgumentParser 4 | import os 5 | import gc 6 | import datetime 7 | import numpy as np 8 | import pandas as pd 9 | 10 | import numpy as np 11 | import torch 12 | 13 | import pytorch_lightning as pl 14 | from .lightning_module import LightningModel 15 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 16 | 17 | SEED = 1 18 | torch.manual_seed(SEED) 19 | np.random.seed(SEED) 20 | 21 | 22 | def main(hparams): 23 | """ 24 | Main training routine specific for this project 25 | :param hparams: 26 | """ 27 | # initialize model 28 | if hparams.pretrained_model is None: 29 | model = LightningModel(hparams) 30 | else: 31 | assert os.path.isfile(hparams.pretrained_model), f"Checkpoint file {hparams.pretrained_model} does not exist!" 
32 |         # load the model from the checkpoint and override the old hyperparameters
33 |         model = LightningModel.load_from_checkpoint(hparams.pretrained_model, **vars(hparams))
34 |     # define the path for model checkpoints and tensorboard logs
35 |     name = "runs/f-{s}_t-{date:%Y-%m-%d_%H-%M-%S}".format(
36 |         date=datetime.datetime.now(),
37 |         s=hparams.seed)
38 | 
39 |     # initialize logger
40 |     logger = TensorBoardLogger("tb_logs", name=name)
41 |     # define checkpoint callback
42 |     checkpoint_callback = ModelCheckpoint(
43 |         filename='{epoch}-{val_mae:.3f}',
44 |         dirpath=os.path.join(os.getcwd(), 'tb_logs/', name),
45 |         save_top_k=1,
46 |         verbose=True,
47 |         monitor='val_mae',
48 |         mode='min')
49 | 
50 | 
51 |     print('the model will train on the following gpus:', [hparams.first_gpu + el for el in range(hparams.gpus)])
52 |     if hparams.ckp == '':
53 |         trainer = pl.Trainer(
54 |             max_epochs=hparams.epochs,
55 |             gpus=[hparams.first_gpu + el for el in range(hparams.gpus)],
56 |             strategy=hparams.distributed_backend,
57 |             amp_backend='apex',
58 |             amp_level=hparams.amp_optimization,
59 |             callbacks=[checkpoint_callback],
60 |             logger=logger,
61 |             check_val_every_n_epoch=2,
62 |             accumulate_grad_batches=hparams.acc_batches,
63 |         )
64 |     else:
65 |         trainer = pl.Trainer(
66 |             max_epochs=hparams.epochs,
67 |             gpus=hparams.gpus,  # note: unlike above, this passes a GPU count and ignores first_gpu
68 |             strategy=hparams.distributed_backend,
69 |             amp_backend='apex',
70 |             amp_level=hparams.amp_optimization,
71 |             callbacks=[checkpoint_callback],
72 |             logger=logger,
73 |             check_val_every_n_epoch=2,
74 |             resume_from_checkpoint=hparams.ckp,
75 |             accumulate_grad_batches=hparams.acc_batches,
76 |         )
77 | 
78 |     # START TRAINING
79 |     trainer.fit(model)
80 | 
81 | 
82 | def run():
83 |     root_dir = os.path.dirname(os.path.realpath(__file__))
84 |     parent_parser = ArgumentParser(add_help=False)
85 | 
86 |     # argument parser for the training process
87 |     parent_parser.add_argument(
88 |         '--gpus',
89 |         type=int,
90 |         default=4,
91 |         help='number of gpus to use'
92 |     )
93 |     parent_parser.add_argument(
94 |         '--acc_batches',
95 |         type=int,
96 |         default=1,
97 |         help='number of batches to accumulate'
98 |     )
99 |     parent_parser.add_argument(
100 |         '--distributed_backend',
101 |         type=str,
102 |         default='ddp',
103 |         help='supports three options: dp, ddp, ddp2'
104 |     )
105 |     parent_parser.add_argument(
106 |         '--amp_optimization',
107 |         type=str,
108 |         default='00',
109 |         help='mixed precision format, default 00 (fp32), 01 mixed, 02 closer to fp16'
110 |     )
111 |     parent_parser.add_argument(
112 |         '--first-gpu',
113 |         type=int,
114 |         default=0,
115 |         help='id of the first gpu to use; gpus [first_gpu, ..., first_gpu + gpus - 1] are used'
116 |     )
117 |     parent_parser.add_argument(
118 |         '--ckp',
119 |         type=str,
120 |         default='',
121 |         help='checkpoint path; if left empty no checkpoint is used'
122 |     )
123 |     parent_parser.add_argument("--test",
124 |                                action="store_true",
125 |                                help="whether to train or test"
126 |                                )
127 |     parent_parser.add_argument("--pretrained-model",
128 |                                type=str,
129 |                                default=None,
130 |                                help='path to checkpoint of pretrained model for transfer learning')
131 | 
132 |     # each LightningModule defines arguments relevant to it
133 |     parser = LightningModel.add_model_specific_args(parent_parser)
134 |     hyperparams = parser.parse_args()
135 | 
136 |     # ---------------------
137 |     # RUN TRAINING
138 |     # ---------------------
139 |     print(hyperparams)
140 |     main(hyperparams)
141 | 
142 | 
143 | if __name__ == '__main__':
144 |     run()
145 | 
-------------------------------------------------------------------------------- /CGAT/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 | Copyright (c) 2019-2020 Rhys Goodall
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | """
23 | 
24 | import torch
25 | import math
26 | import numpy as np
27 | from torch.optim.lr_scheduler import _LRScheduler
28 | 
29 | 
30 | def RobustL1(output, log_std, target):
31 |     """
32 |     Robust L1 loss using a Lorentzian prior. Allows for estimation
33 |     of an aleatoric uncertainty.
34 |     """
35 |     loss = np.sqrt(2.0) * torch.abs(output - target) * \
36 |         torch.exp(- log_std) + log_std
37 |     return torch.mean(loss)
38 | 
39 | 
40 | def RobustL2(output, log_std, target):
41 |     """
42 |     Robust L2 loss using a Gaussian prior. Allows for estimation
43 |     of an aleatoric uncertainty.
44 |     """
45 |     loss = 0.5 * torch.pow(output - target, 2.0) * \
46 |         torch.exp(- 2.0 * log_std) + log_std
47 |     return torch.mean(loss)
48 | 
49 | 
50 | def cyclical_lr(period=100, cycle_mul=0.2, tune_mul=0.05):
51 |     # scaler: adapt this for a non-triangular CLR shape (tune_mul is currently unused)
52 |     def scaler(x): return 1.
53 | 
54 |     # lambda function to calculate the LR
55 |     def lr_lambda(it): return cycle_mul + \
56 |         (1. - cycle_mul) * relative(it, period)
57 | 
58 |     # additional function to see where on the cycle we are
59 |     def relative(it, stepsize):
60 |         cycle = math.floor(1 + it / period)
61 |         x = abs(2 * (it / period - cycle) + 1)
62 |         return max(0, (1 - x)) * scaler(cycle)
63 | 
64 |     return lr_lambda
65 | 
66 | 
67 | class LinearLR(_LRScheduler):
68 |     """Linearly increases the learning rate between two boundaries over a number of
69 |     iterations.
70 | 
71 |     Arguments:
72 |         optimizer (torch.optim.Optimizer): wrapped optimizer.
73 |         end_lr (float): the final learning rate, i.e. the upper boundary of
74 |             the range test.
75 |         num_iter (int): the number of iterations over which the test
76 |             occurs.
77 |         last_epoch (int): the index of last epoch. Default: -1.
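    A usage sketch (editor addition) of a learning-rate range test;
    `train_one_batch` is a placeholder for the caller's own step function:

        scheduler = LinearLR(optimizer, end_lr=1.0, num_iter=100)
        for _ in range(100):
            train_one_batch()
            scheduler.step()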
78 | 
79 |     """
80 | 
81 |     def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
82 |         self.end_lr = end_lr
83 |         self.num_iter = num_iter
84 |         super(LinearLR, self).__init__(optimizer, last_epoch)
85 | 
86 |     def get_lr(self):
87 |         curr_iter = self.last_epoch + 1
88 |         r = curr_iter / self.num_iter
89 |         return [base_lr + r * (self.end_lr - base_lr)
90 |                 for base_lr in self.base_lrs]
91 | 
92 | 
93 | class ExponentialLR(_LRScheduler):
94 |     """Exponentially increases the learning rate between two boundaries over a number of
95 |     iterations.
96 | 
97 |     Arguments:
98 |         optimizer (torch.optim.Optimizer): wrapped optimizer.
99 |         end_lr (float): the final learning rate, i.e. the upper boundary of
100 |             the range test.
101 |         num_iter (int): the number of iterations over which the test
102 |             occurs.
103 |         last_epoch (int): the index of last epoch. Default: -1.
104 | 
105 |     """
106 | 
107 |     def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
108 |         self.end_lr = end_lr
109 |         self.num_iter = num_iter
110 |         super(ExponentialLR, self).__init__(optimizer, last_epoch)
111 | 
112 |     def get_lr(self):
113 |         curr_iter = self.last_epoch + 1
114 |         r = curr_iter / self.num_iter
115 |         return [base_lr * (self.end_lr / base_lr) **
116 |                 r for base_lr in self.base_lrs]
117 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 hyllios
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CGAT
2 | Crystal graph attention neural networks for materials prediction
3 | 
4 | The code requires the following external packages:
5 | * torch 1.10.0+cu111
6 | * torch-cluster 1.5.9
7 | * torch-geometric 2.0.3
8 | * torch-scatter 2.0.9
9 | * torch-sparse 0.6.12
10 | * torch-spline-conv 1.2.1
11 | * torchaudio 0.10.0
12 | * torchvision 0.11.1
13 | * pytorch-lightning 1.5.8
14 | * pymatgen 2022.2.25
15 | * tqdm
16 | * numpy
17 | * gpytorch 1.6.0
18 | 
19 | Newer package versions might work.
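A quick way to verify that the core dependencies are importable (an editor sketch, not part of the original README):

```
import torch, torch_geometric, pytorch_lightning, pymatgen.core
print(torch.__version__, torch.cuda.is_available())
print(torch_geometric.__version__, pytorch_lightning.__version__)
```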
20 | 
21 | The following pip commands were tested for Python 3.8 (note that they install older versions than those listed above):
22 | 
23 | `pip install torch==1.8.0 torchvision==0.9.0+cu101 -f https://download.pytorch.org/whl/cu101/torch_stable.html`
24 | 
25 | `pip install torch-scatter==2.0.6 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html`
26 | 
27 | `pip install torch-sparse==0.6.9 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html`
28 | 
29 | `pip install torch-cluster==1.5.9 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html`
30 | 
31 | `pip install torch-spline-conv==1.2.1 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html`
32 | 
33 | `pip install torch-geometric==1.6.3`
34 | 
35 | `pip install --upgrade-strategy only-if-needed pytorch-lightning==1.5.8 torch==1.8.0+cu101`
36 | 
37 | `pip install pymatgen==2022.0.5`
38 | 
39 | 
40 | The dataset used in this work can be found at https://archive.materialscloud.org/record/2021.128. It differs slightly from earlier versions: most AFLOW materials marked as possible outliers on the hull were recalculated, and some Materials Project systems were updated. For the non-mixed perovskite systems the distance to the hull was recalculated with this updated dataset. Data for the paper "Large-scale machine-learning-assisted exploration of the whole materials space" can be found at https://archive.materialscloud.org/record/2022.126.
41 | 
42 | # Usage
43 | The package can be installed by cloning the repository and running
44 | ```shell
45 | pip install .
46 | ```
47 | in the repository.
48 | 
49 | (If one wants to edit the source code, installing with `pip install -e .` is advised.)
50 | 
51 | After installing, one can make use of the following console scripts:
52 | * `train-CGAT` to train a Crystal Graph Network,
53 | * `prepare` to prepare training data for use with CGAT,
54 | * `train-GP` to train Gaussian Processes.
55 | 
56 | (A full list of command line arguments can be found by running the command with `-h`.)
57 | 
58 | To test the package, one can download some of the data from Materials Cloud, e.g., https://archive.materialscloud.org/record/file?filename=dcgat_1_000.json.bz2&record_id=1485, and convert and save it with the following script:
59 | 
60 | ```
61 | import json, bz2, pickle, gzip as gz
62 | from pymatgen.entries.computed_entries import ComputedStructureEntry
63 | 
64 | with bz2.open("dcgat_1_000.json.bz2") as fh:
65 |     data = json.loads(fh.read().decode('utf-8'))
66 | 
67 | entries = [ComputedStructureEntry.from_dict(i) for i in data["entries"][:1000]]
68 | 
69 | print("Found " + str(len(entries)) + " entries")
70 | print("\nEntry:\n", entries[0])
71 | print("\nStructure:\n", entries[0].structure)
72 | # only using the first 1000 entries to save time
73 | pickle.dump(entries, gz.open('dcgat_1_000.pickle.gz', 'wb'))
74 | ```
75 | 
76 | Convert the ComputedStructureEntries to features:
77 | 
78 | `python prepare_data.py --source-dir ../ --file dcgat_1_000.pickle.gz --target-file dcgat_1_000_features.pickle.gz --target-dir ../`
79 | 
80 | Run the training script (if necessary, make it executable with chmod +x ./training_scripts/train.sh). The training script assumes 2 GPUs right now; a sketch of the kind of command it wraps is shown below.
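(Editor sketch, since the contents of train.sh are not reproduced here and the exact flags are an assumption: the script essentially wraps the `train-CGAT` entry point with the flags defined in CGAT/train.py, e.g. `train-CGAT --gpus 2 --acc_batches 1`, together with the model-specific arguments added by `LightningModel.add_model_specific_args`, such as the paths to the prepared features and the element embeddings.)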
If only one is available strategy=hparams.distributed_backend needs to be removed from CGAT/train.py and --gpus set to 1.: 81 | 82 | `./training_scripts/train.sh` 83 | 84 | Test the model: 85 | 86 | `python test.py --ckp tb_logs/runs/your_checkpoint.ckpt --data-path dcgat_1_000_features.pickle.gz --fea-path embeddings/matscholar-embedding.json` 87 | -------------------------------------------------------------------------------- /Utilities/adjust_data.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import pickle 3 | import gzip as gz 4 | from pymatgen.entries.computed_entries import ComputedStructureEntry 5 | import numpy as np 6 | from tqdm import tqdm, trange 7 | import os 8 | 9 | 10 | def get_batch_ids(path: Union[str, list[str]]) -> set: 11 | if isinstance(path, str): 12 | data = load(path) 13 | batch_ids = [batch_id[0] for batch_id in data['batch_ids']] 14 | elif isinstance(path, list): 15 | batch_ids = [] 16 | for p in tqdm(path): 17 | data = load(p) 18 | batch_ids.extend((batch_id[0] for batch_id in data['batch_ids'])) 19 | else: 20 | raise TypeError("Argument 'path' has to be either a string or list of strings.") 21 | return set(batch_ids) 22 | 23 | 24 | def remove_batch_ids(data: dict, batch_ids: set, inplace: bool = True, modify_batch_ids: bool = True) -> dict: 25 | if len(batch_ids) == 0: 26 | return data 27 | if not modify_batch_ids: 28 | batch_ids = batch_ids.copy() 29 | # create list of indices which have to be removed 30 | indices_to_remove = [] 31 | for i, (batch_id,) in enumerate(data['batch_ids']): 32 | if batch_id in batch_ids: 33 | indices_to_remove.append(i) 34 | batch_ids.remove(batch_id) 35 | # reverse list of indices to enable easy removing of items of a list by consecutive pops 36 | indices_to_remove.reverse() 37 | if inplace: 38 | new_data = data 39 | else: 40 | new_data = {} 41 | new_data['input'] = np.delete(data['input'], indices_to_remove, axis=1) 42 | ids: list = data['batch_ids'].copy() 43 | for i in indices_to_remove: 44 | ids.pop(i) 45 | new_data['batch_ids'] = ids 46 | new_data['batch_comp'] = np.delete(data['batch_comp'], indices_to_remove, axis=0) 47 | if not inplace: 48 | new_data['target'] = {} 49 | for target in data['target']: 50 | new_data['target'][target] = np.delete(data['target'][target], indices_to_remove, axis=0) 51 | new_data['comps'] = np.delete(data['comps'], indices_to_remove, axis=0) 52 | 53 | return new_data 54 | 55 | 56 | def get_samples_from_unprepared_data(batch_ids: set, unprepared_files: list[str], modify_batch_ids: bool = True) \ 57 | -> list[ComputedStructureEntry]: 58 | if not modify_batch_ids: 59 | batch_ids = batch_ids.copy() 60 | sample = [] 61 | for file in tqdm(unprepared_files): 62 | data: list[ComputedStructureEntry] = load(file) 63 | for entry in data: 64 | if entry.data['id'] in batch_ids: 65 | sample.append(entry) 66 | batch_ids.remove(entry.data['id']) 67 | return sample 68 | 69 | 70 | def getfile(i: int, dir: str = 'data'): 71 | return os.path.join(dir, f'data_{i * 10000}_{(i + 1) * 10000}.pickle.gz') 72 | 73 | 74 | def load(path: str): 75 | return pickle.load(gz.open(path, 'rb')) 76 | 77 | 78 | def save(data, path): 79 | pickle.dump(data, gz.open(path, 'wb')) 80 | 81 | 82 | def main(): 83 | paths = ['./active_learning/sample/random_sample.pickle.gz'] 84 | paths.extend((f'./active_learning/val/val_data_{i * 10000}_{(i + 1) * 10000}.pickle.gz' for i in range(29))) 85 | paths.extend((f'./active_learning/test/test_data_{i * 10000}_{(i + 1) * 
10000}.pickle.gz' for i in range(29))) 86 | batch_ids = get_batch_ids(paths) 87 | data_path = 'data' 88 | dest_path = 'active_learning' 89 | 90 | for i in trange(283): 91 | data = load(getfile(i, data_path)) 92 | remove_batch_ids(data, batch_ids) 93 | save(data, getfile(i, dest_path)) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /Utilities/calculate_embeddings.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import gzip as gz 3 | from argparse import ArgumentParser 4 | from CGAT.lightning_module import LightningModel, collate_fn 5 | from CGAT.data import CompositionData 6 | from torch.utils.data import DataLoader 7 | import os 8 | from glob import glob 9 | import torch 10 | from tqdm import tqdm 11 | 12 | 13 | def load(file): 14 | return pickle.load(gz.open(file)) 15 | 16 | 17 | def save(data, file): 18 | pickle.dump(data, gz.open(file, 'wb')) 19 | 20 | 21 | def main(): 22 | parser = ArgumentParser() 23 | parser.add_argument('--data-path', '-d', 24 | type=str, 25 | required=True, 26 | nargs='+') 27 | parser.add_argument('--target-path', '-t', 28 | type=str, 29 | required=True) 30 | parser.add_argument('--model-path', '-m', 31 | type=str, 32 | required=True) 33 | parser.add_argument('--fea-path', '-f', 34 | type=str, 35 | default=None) 36 | parser.add_argument('--batch-size', '-b', 37 | type=int, 38 | default=100) 39 | args = parser.parse_args() 40 | 41 | model = LightningModel.load_from_checkpoint(args.model_path, train=False) 42 | model.cuda() 43 | 44 | for data_path in tqdm(args.data_path): 45 | if os.path.isdir(data_path): 46 | files = glob(os.path.join(data_path, '*.pickle.gz')) 47 | else: 48 | files = [data_path] 49 | 50 | if os.path.isfile(args.target_path): 51 | raise ValueError("'target-path' must not be a directory and not an existing file!") 52 | 53 | if not os.path.isdir(args.target_path): 54 | os.makedirs(args.target_path) 55 | 56 | for file in tqdm(files): 57 | data = load(file) 58 | dataset = CompositionData( 59 | data=data, 60 | fea_path=args.fea_path if args.fea_path else model.hparams.fea_path, 61 | max_neighbor_number=model.hparams.max_nbr, 62 | target=model.hparams.target 63 | ) 64 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) 65 | 66 | embedding_list = [] 67 | for batch in loader: 68 | with torch.no_grad(): 69 | embedding_list.append(model.evaluate(batch, return_graph_embedding=True).cpu()) 70 | embedding = torch.cat(embedding_list).numpy() 71 | data['input'] = embedding 72 | if len(args.data_path) == 1: 73 | save(data, os.path.join(args.target_path, os.path.basename(file))) 74 | else: 75 | save(data, os.path.join(args.target_path, os.path.basename(data_path), os.path.basename(file))) 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /Utilities/calculate_errors.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import pickle 3 | 4 | from CGAT.lightning_module import LightningModel, collate_fn 5 | from tqdm import trange 6 | from CGAT.data import CompositionData 7 | from torch.utils.data import DataLoader 8 | import pandas as pd 9 | import os 10 | from sklearn.metrics import mean_absolute_error 11 | from pytorch_lightning import Trainer 12 | 13 | 14 | def get_file(i: int, path: str): 15 | return os.path.join(path, f'data_{i * 
10000}_{(i + 1) * 10000}.pickle.gz') 16 | 17 | 18 | def main(): 19 | parser = LightningModel.add_model_specific_args() 20 | 21 | parser.add_argument( 22 | '--ckp', 23 | type=str, 24 | default='', 25 | help='ckp path', 26 | required=True 27 | ) 28 | 29 | parser.add_argument( 30 | '--gpus', 31 | type=int, 32 | default=2, 33 | help='number of gpus to use' 34 | ) 35 | parser.add_argument( 36 | '--acc_batches', 37 | type=int, 38 | default=1, 39 | help='number of batches to accumulate' 40 | ) 41 | parser.add_argument( 42 | '--distributed_backend', 43 | type=str, 44 | default='ddp', 45 | help='supports three options dp, ddp, ddp2' 46 | ) 47 | parser.add_argument( 48 | '--amp_optimization', 49 | type=str, 50 | default='00', 51 | help='mixed precision format, default 00 (32), 01 mixed, 02 closer to 16' 52 | ) 53 | 54 | hparams = parser.parse_args() 55 | # Disable training for faster loading 56 | hparams.train = False 57 | 58 | # load model 59 | model = LightningModel.load(hparams.ckp) 60 | 61 | trainer = Trainer( 62 | gpus=hparams.gpus, 63 | strategy=hparams.distributed_backend, 64 | amp_backend='apex', 65 | amp_level=hparams.amp_optimization, 66 | accumulate_grad_batches=hparams.acc_batches, 67 | ) 68 | 69 | PATH = 'active_learning' 70 | # iterate over unused data and evaluate the error 71 | for i in trange(283): 72 | # declare dataframe for saving errors 73 | errors = pd.DataFrame(columns=['batch_ids', 'errors']) 74 | dataset = CompositionData( 75 | data=get_file(i, PATH), 76 | fea_path=hparams.fea_path, 77 | max_neighbor_number=hparams.max_nbr, 78 | target=hparams.target 79 | ) 80 | data = pickle.load(gzip.open(get_file(i, PATH), 'rb')) 81 | targets = data['target'][hparams.target].reshape((-1, 1, 1)) 82 | loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn) 83 | # TODO get prediction from other GPU! 
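        # (editor note) With strategy='ddp' each process only sees its own shard of
        # `loader`, so `predictions` below covers just a subset of the dataset.  A
        # hedged sketch for collecting everything on one process (assuming
        # torch.distributed has been initialized) would be:
        #     gathered = [None] * torch.distributed.get_world_size()
        #     torch.distributed.all_gather_object(gathered, predictions)
        # Alternatively, running with --gpus 1 sidesteps the problem.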
84 | predictions = trainer.predict(model=model, dataloaders=loader) 85 | for j, batch in enumerate(predictions): 86 | row = {'errors': mean_absolute_error(targets[j], predictions[j].cpu().numpy()), 87 | 'batch_ids': data['batch_ids'][j][0]} 88 | errors = errors.append(row, ignore_index=True) 89 | 90 | errors.to_csv(get_file(i, PATH + '/temp').replace('data', 'errors').replace('pickle.gz', 'csv'), index=False) 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /Utilities/element_correlation.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import gzip as gz 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from sample import getfile 6 | from tqdm import trange, tqdm 7 | from pymatgen.core.periodic_table import Element 8 | import re 9 | 10 | 11 | def get_distribution(hist: np.ndarray): 12 | def f(x): 13 | x = int(x) 14 | return hist[x] 15 | 16 | return f 17 | 18 | 19 | def main(): 20 | elements: list[list[int]] = [] 21 | pattern = re.compile(r'([a-zA-Z]+)\d+') 22 | for i in trange(283): 23 | data = pickle.load(gz.open(getfile(i), 'rb')) 24 | for comp, in data['batch_comp']: 25 | elements.append([Element(el).Z for el in pattern.findall(comp)]) 26 | biggest_element = max([max(els) for els in elements]) 27 | correlation_matrix = np.zeros((biggest_element, biggest_element)) 28 | for els in tqdm(elements): 29 | for i in els: 30 | for j in els: 31 | correlation_matrix[i - 1, j - 1] += 1 32 | 33 | correlation_matrix = ( 34 | correlation_matrix.T / np.where(correlation_matrix.diagonal() != 0, correlation_matrix.diagonal(), 35 | np.ones(biggest_element))).T 36 | 37 | for i in range(biggest_element): 38 | correlation_matrix[i, i] = 0 39 | plt.matshow(correlation_matrix) 40 | plt.colorbar() 41 | print(np.sort(correlation_matrix.flatten())[:-10:-1]) 42 | 43 | plt.figure() 44 | x = np.linspace(0, biggest_element, biggest_element * 100, endpoint=False) 45 | # plt.hist(list(range(biggest_element)), bins=biggest_element, weights=correlation_matrix.mean(axis=0)) 46 | f = get_distribution(correlation_matrix.mean(axis=0)) 47 | y = np.array([f(i) for i in x]) 48 | plt.plot(x, np.where(y != 0, y, np.zeros_like(y))) 49 | plt.figure() 50 | plt.plot(x, [min(i, 200) for i in np.where(y > 1e-3, 1 / y, np.zeros_like(y))]) 51 | 52 | plt.show() 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /Utilities/errors_of_additional_data.py: -------------------------------------------------------------------------------- 1 | from CGAT.lightning_module import LightningModel, collate_fn 2 | from CGAT.data import CompositionData 3 | from torch.utils.data import DataLoader 4 | import pandas as pd 5 | import os 6 | from sklearn.metrics import mean_absolute_error 7 | import glob 8 | from get_additional_data import get_composition 9 | from tqdm import tqdm 10 | import numpy as np 11 | import re 12 | 13 | 14 | def get_seed(path): 15 | pattern = re.compile(r'f-(\d+)_') 16 | return int(pattern.search(path).group(1)) 17 | 18 | 19 | def main(): 20 | data_paths = glob.glob(os.path.join("additional_data", "*", "*.pickle.gz")) 21 | assert len(data_paths) > 0 22 | print(f"Found {len(data_paths)} datasets") 23 | # sizes = [50_000, 75_000, 100_000, 125_000, 150_000, 200_000, 250_000] 24 | # runs = ["f-0_t-2022-01-03_15-07-47", 25 | # "f-0_t-2022-01-09_14-33-04", 26 | # "f-0_t-2022-01-14_14-50-14", 27 
| # "f-0_t-2022-01-18_10-00-20", 28 | # "f-0_t-2022-01-22_17-40-56", 29 | # "f-0_t-2022-01-26_22-42-02", 30 | # "f-0_t-2022-02-02_12-39-57"] 31 | # model_paths = [glob.glob(os.path.join("tb_logs", 32 | # "runs", 33 | # "{run}", 34 | # "*.ckpt").format(run=run))[0] for run in runs] 35 | model_paths = sorted(glob.glob(os.path.join('new_active_learning', 'checkpoints', '*', '*.ckpt')), key=get_seed) 36 | seeds = list(map(get_seed, model_paths)) 37 | df = pd.DataFrame(columns=['comp', 'seed', 'mae']) 38 | for i, model_path in zip(seeds, tqdm(model_paths)): 39 | model = LightningModel.load_from_checkpoint(model_path, train=False) 40 | model = model.cuda() 41 | 42 | for path in tqdm(data_paths): 43 | dataset = CompositionData( 44 | data=path, 45 | fea_path="embeddings/matscholar-embedding.json", 46 | max_neighbor_number=model.hparams.max_nbr, 47 | target=model.hparams.target 48 | ) 49 | loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn) 50 | comp = get_composition(path) 51 | errors = [] 52 | for batch in loader: 53 | _, _, pred, target, _ = model.evaluate(batch) 54 | errors.append(mean_absolute_error(target.cpu().numpy(), pred.cpu().numpy())) 55 | df.loc[len(df)] = [comp, i, np.mean(errors)] 56 | df.to_csv('new_active_learning/errors.csv', index=False) 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /Utilities/filter_embeddings.py: -------------------------------------------------------------------------------- 1 | from adjust_data import save, load 2 | from tqdm import tqdm 3 | from glob import glob 4 | import os 5 | import numpy as np 6 | 7 | 8 | def remove_batch_ids(data: dict, batch_ids: set, inplace: bool = True, modify_batch_ids: bool = True) -> dict: 9 | if len(batch_ids) == 0: 10 | return data 11 | if not modify_batch_ids: 12 | batch_ids = batch_ids.copy() 13 | # create list of indices which have to be removed 14 | indices_to_remove = [] 15 | for i, (batch_id,) in enumerate(data['batch_ids']): 16 | if batch_id in batch_ids: 17 | indices_to_remove.append(i) 18 | batch_ids.remove(batch_id) 19 | # reverse list of indices to enable easy removing of items of a list by consecutive pops 20 | indices_to_remove.reverse() 21 | if inplace: 22 | new_data = data 23 | else: 24 | new_data = {} 25 | new_data['input'] = np.delete(data['input'], indices_to_remove, axis=0) 26 | ids: list = data['batch_ids'].copy() 27 | for i in indices_to_remove: 28 | ids.pop(i) 29 | new_data['batch_ids'] = ids 30 | new_data['batch_comp'] = np.delete(data['batch_comp'], indices_to_remove, axis=0) 31 | if not inplace: 32 | new_data['target'] = {} 33 | for target in data['target']: 34 | new_data['target'][target] = np.delete(data['target'][target], indices_to_remove, axis=0) 35 | new_data['comps'] = np.delete(data['comps'], indices_to_remove, axis=0) 36 | 37 | return new_data 38 | 39 | 40 | def get_ids(file): 41 | data = load(file) 42 | return set([batch_id for batch_id, in data['batch_ids']]) 43 | 44 | 45 | def get_test_and_val_ids(path_to_dir): 46 | files = glob(os.path.join(path_to_dir, 'val', '*.pickle.gz')) + \ 47 | glob(os.path.join(path_to_dir, 'test', '*.pickle.gz')) 48 | ids = set() 49 | for file in tqdm(files): 50 | ids |= get_ids(file) 51 | return ids 52 | 53 | 54 | def main(): 55 | path = 'graph_embeddings' 56 | target_dir = os.path.join(path, 'train') 57 | print('Gathering ids...') 58 | test_and_val_ids = get_test_and_val_ids(path) 59 | 60 | print('Deleting test and validation entries form 
training data...') 61 | files = glob(os.path.join(path, '*.pickle.gz')) 62 | if not os.path.isdir(target_dir): 63 | os.makedirs(target_dir) 64 | for file in tqdm(files): 65 | data = load(file) 66 | data = remove_batch_ids(data, test_and_val_ids) 67 | save(data, os.path.join(target_dir, os.path.basename(file))) 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /Utilities/get_additional_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import pickle 4 | import re 5 | import bz2 6 | import gzip as gz 7 | import json 8 | from pymatgen.entries.computed_entries import ComputedStructureEntry 9 | from CGAT.prepare_data import build_dataset_prepare 10 | from tqdm import tqdm 11 | 12 | 13 | def get_composition(file: str): 14 | pattern = re.compile(r'(?:/|\\)' + r'([A-Z]\d*)' + r'([A-Z]\d*)?' * 10 + r'(?:/|\\)') 15 | return "".join(filter(None, pattern.search(file).groups())) 16 | 17 | 18 | def get_file_name(file: str): 19 | pattern = re.compile(r'([\w-]*)\.json\.bz2') 20 | return pattern.search(file)[1] 21 | 22 | 23 | def main(): 24 | PATH = "/nfs/data-019/marques/data/material_prediction_CGAT/{comp}" 25 | files = glob.glob(os.path.join(PATH.format(comp='binaries'), '*', 'annotated', '*.json.bz2')) + \ 26 | glob.glob(os.path.join(PATH.format(comp='ternaries'), '*', 'annotated', '*.json.bz2')) 27 | print(f"Found {len(files)} files.") 28 | new_dir = "additional_data" 29 | if not os.path.exists(new_dir): 30 | os.mkdir(new_dir) 31 | for file in tqdm(files): 32 | dir = os.path.join(new_dir, get_composition(file)) 33 | if not os.path.exists(dir): 34 | os.mkdir(dir) 35 | with bz2.open(file, 'rb') as f: 36 | json_data = json.load(f) 37 | data = list(map(ComputedStructureEntry.from_dict, json_data['entries'])) 38 | with gz.open(os.path.join(dir, f'{get_file_name(file)}.pickle.gz'), 'wb') as f: 39 | pickle.dump(build_dataset_prepare(data, target_property=['e_above_hull_new', 'e-form']), f) 40 | 41 | 42 | def test_get_composition(): 43 | cases = [ 44 | "/nfs/data-019/marques/data/material_prediction_CGAT/binaries/A2B13/annotated/batch-000.json.bz2", 45 | "/nfs/data-019/marques/data/material_prediction_CGAT/binaries/A2B3/annotated/batch-000.json.bz2", 46 | "/nfs/data-019/marques/data/material_prediction_CGAT/binaries/AB12/annotated/batch-000.json.bz2", 47 | "/nfs/data-019/marques/data/material_prediction_CGAT/binaries/AB2/annotated/batch-000.json.bz2", 48 | "/nfs/data-019/marques/data/material_prediction_CGAT/binaries/AB/annotated/batch-000.json.bz2", 49 | "/nfs/data-019/marques/data/material_prediction_CGAT/ternaries/A2B2C5/annotated/batch-000.json.bz2", 50 | "/nfs/data-019/marques/data/material_prediction_CGAT/ternaries/A3B4C12/annotated/batch-000.json.bz2", 51 | ] 52 | results = [ 53 | "A2B13", 54 | "A2B3", 55 | "AB12", 56 | "AB2", 57 | "AB", 58 | "A2B2C5", 59 | "A3B4C12" 60 | ] 61 | for case, result in zip(cases, results): 62 | try: 63 | assert get_composition(case) == result 64 | except AssertionError: 65 | print("Fails for: ", repr(case)) 66 | print("Expected: ", repr(result)) 67 | print("Found:", repr(get_composition(case))) 68 | 69 | 70 | if __name__ == '__main__': 71 | # test_get_composition() 72 | main() 73 | -------------------------------------------------------------------------------- /Utilities/get_highest_errors.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import 
pickle 3 | import gzip 4 | from calculate_errors import get_file 5 | from tqdm import trange 6 | import numpy as np 7 | from sample import get_id, search 8 | 9 | 10 | def get_csv(i: int, path: str): 11 | return get_file(i, path + '/temp').replace('data', 'errors').replace('pickle.gz', 'csv') 12 | 13 | 14 | def main(): 15 | PATH = 'active_learning' 16 | UNPREPARED_PATH = 'unprepared_volume_data' 17 | 18 | # Start by loading the errors file 19 | errors = pd.DataFrame(columns=['batch_ids', 'errors']) 20 | 21 | print('Reading error files...') 22 | errors = pd.concat([pd.read_csv(get_csv(i, PATH)) for i in trange(283)], ignore_index=True) 23 | 24 | N = 25000 25 | # find the first N samples with the largest errors 26 | print('Sorting...') 27 | errors = errors.sort_values(by='errors', ascending=False, ignore_index=True).head(N) 28 | # convert batch_ids to a set for faster 'in' searching 29 | batch_ids = set(errors['batch_ids'].to_list()) 30 | 31 | # create list for saving those samples 32 | new_sample = [] 33 | print('Saving samples with highest errors...') 34 | for i in trange(283): 35 | data = pickle.load(gzip.open(get_file(i, PATH), 'rb')) 36 | unprepared_data = pickle.load(gzip.open(get_file(i, UNPREPARED_PATH), 'rb')) 37 | indices_to_remove = [] 38 | current_batch_ids = [] 39 | # find all batch_ids for the new sample in the current file 40 | for j, batch_id in enumerate(data['batch_ids']): 41 | batch_id = batch_id[0] 42 | if batch_id in batch_ids: 43 | indices_to_remove.append(j) 44 | current_batch_ids.append(get_id(batch_id)) 45 | batch_ids.remove(batch_id) 46 | 47 | if len(indices_to_remove) > 0: 48 | # reverse order of indices for easy popping 49 | indices_to_remove.reverse() 50 | # remove those entries from data 51 | data['input'] = np.delete(data['input'], indices_to_remove, axis=1) 52 | for j in indices_to_remove: 53 | data['batch_ids'].pop(j) 54 | data['batch_comp'] = np.delete(data['batch_comp'], indices_to_remove) 55 | data['comps'] = np.delete(data['comps'], indices_to_remove) 56 | for target in data['target']: 57 | data['target'][target] = np.delete(data['target'][target], indices_to_remove) 58 | # overwrite the data with the removed entries 59 | pickle.dump(data, gzip.open(get_file(i, PATH), 'wb')) 60 | 61 | # find batch_ids in unprepared data and append to new_sample 62 | for batch_id in current_batch_ids: 63 | new_sample.append(unprepared_data.pop(search(unprepared_data, batch_id))) 64 | # save new sample 65 | pickle.dump(new_sample, gzip.open('./active_learning/sample/unprepared_sample_1.pickle.gz', 'wb')) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /Utilities/gp_predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from CGAT.gaussian_process import GLightningModel, EmbeddingData 3 | from glob import glob 4 | from tqdm import tqdm 5 | from torch.utils.data import DataLoader 6 | import os 7 | import pandas as pd 8 | import numpy as np 9 | 10 | 11 | def main(): 12 | data_paths = glob(os.path.join("new_active_learning", "A*B*", "*.pickle.gz")) 13 | assert len(data_paths) > 0 14 | print(f"Found {len(data_paths)} datasets.") 15 | 16 | model_path = os.path.join("new_active_learning", "gp.ckpt") 17 | model = GLightningModel.load_from_checkpoint(model_path, train=False) 18 | 19 | for path in tqdm(data_paths): 20 | dataset = EmbeddingData(path, model.hparams.target) 21 | loader = DataLoader(dataset, batch_size=500, shuffle=False) 22 
| predictions = [] 23 | uncertainties = [] 24 | errors = [] 25 | for batch in loader: 26 | with torch.no_grad(): 27 | _, (_, upper), pred, target, _ = model.evaluate(batch) 28 | predictions.append(pred.numpy()) 29 | uncertainties.append(upper.numpy() - pred.numpy()) 30 | errors.append(np.abs(pred.numpy() - target.numpy())) 31 | df = pd.DataFrame() 32 | columns = ['prediction', 'uncertainty', 'absolute error'] 33 | lists = [predictions, uncertainties, errors] 34 | for column, list in zip(columns, lists): 35 | df[column] = np.concatenate(list) 36 | df.to_csv(os.path.join(os.path.dirname(path), 'gp_results.csv'), index=False) 37 | 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /Utilities/metropolis.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from scipy.special import erf 5 | 6 | 7 | class MarkovChain: 8 | 9 | def __init__(self, distribution, generator, start=None, *args, **kwargs): 10 | self.distribution = distribution 11 | self.generator = generator 12 | self.args = args 13 | self.kwargs = kwargs 14 | self.chain = [] 15 | if start is None: 16 | x = generator(*args, **kwargs) 17 | p = distribution(x) 18 | while p <= 0: 19 | x = generator(*args, **kwargs) 20 | p = distribution(x) 21 | self.chain.append(x) 22 | else: 23 | self.chain.append(start) 24 | 25 | def __getitem__(self, item): 26 | return self.chain[item] 27 | 28 | def __iter__(self): 29 | return self.chain.__iter__() 30 | 31 | def __len__(self): 32 | return self.chain.__len__() 33 | 34 | def step(self, n: int = 1): 35 | for _ in range(n): 36 | y = self.generator(*self.args, **self.kwargs) 37 | p = min(1, self.distribution(y) / self.distribution(self[-1])) 38 | if random.random() <= p: 39 | self.chain.append(y) 40 | else: 41 | self.chain.append(self[-1]) 42 | 43 | 44 | def main(): 45 | def distribution(x): 46 | return 1 if np.abs(x) > .5 else 0 47 | 48 | def distribution2(skew): 49 | def f(x): 50 | return 2 / np.sqrt(2 * np.pi) * np.exp(-np.square(x) / 2) * (1 + erf(skew * x / np.sqrt(2))) / 2 51 | 52 | return f 53 | 54 | def generator(): 55 | return random.random() * 6 - 3 56 | 57 | def normal_distribution(x): 58 | return np.exp(-np.square(x) / 2) / np.sqrt(2 * np.pi) 59 | 60 | def cum_distribution(x): 61 | return (1 + erf(x / np.sqrt(2))) / 2 62 | 63 | def get_skewd_gaussion(skew, location, scale): 64 | def f(x): 65 | return 2 / scale * normal_distribution((x - location) / scale) * cum_distribution( 66 | skew * (x - location) / scale) 67 | 68 | return f 69 | 70 | chain = MarkovChain(get_skewd_gaussion(100, -0.02, 0.7), lambda: random.random() * 3.5) 71 | n = 50_000 72 | chain.step(n) 73 | plt.hist(chain, weights=[1 / len(chain)] * len(chain), bins=100) 74 | plt.show() 75 | # 76 | # x = np.linspace(-3, 3, 600) 77 | # skew = 4 78 | # y = 2 / np.sqrt(2 * np.pi) * np.exp(-np.square(x) / 2) * (1 + erf(skew * x / np.sqrt(2))) / 2 79 | # plt.plot(x, y) 80 | # plt.show() 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /Utilities/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from CGAT.lightning_module import LightningModel, collate_fn 4 | from CGAT.data import CompositionData 5 | from torch.utils.data import DataLoader 6 | import numpy as np 7 | import os 8 | import 
glob 9 | from get_additional_data import get_composition 10 | from tqdm import tqdm 11 | import re 12 | 13 | 14 | def get_seed(path): 15 | pattern = re.compile(r'f-(\d+)_') 16 | return int(pattern.search(path).group(1)) 17 | 18 | 19 | def main(): 20 | data_paths = glob.glob(os.path.join("additional_data", "*", "*.pickle.gz")) 21 | assert len(data_paths) > 0 22 | print(f"Found {len(data_paths)} datasets") 23 | # model_paths = sorted(glob.glob(os.path.join('new_active_learning', 'checkpoints', 'e_hull', '350_000', '*', '*.ckpt')), 24 | # key=get_seed) 25 | # seeds = list(map(get_seed, model_paths)) 26 | model_paths = glob.glob(os.path.join('new_active_learning', 'checkpoints', 'old_checkpoints', '*.ckpt')) 27 | seeds = ['old'] 28 | # df = pd.DataFrame(columns=['comp', 'seed', 'entry', 'prediction']) 29 | get_embeddings = True 30 | for seed, model_path in zip(seeds, tqdm(model_paths)): 31 | model = LightningModel.load_from_checkpoint(model_path, train=False) 32 | model = model.cuda() 33 | 34 | for path in tqdm(data_paths): 35 | dataset = CompositionData( 36 | data=path, 37 | fea_path="embeddings/matscholar-embedding.json", 38 | max_neighbor_number=model.hparams.max_nbr, 39 | target=model.hparams.target 40 | ) 41 | loader = DataLoader(dataset, batch_size=1000, shuffle=False, collate_fn=collate_fn) 42 | comp = get_composition(path) 43 | if not get_embeddings: 44 | predictions = [] 45 | targets = [] 46 | log_stds = [] 47 | for batch in loader: 48 | with torch.no_grad(): 49 | _, log_std, pred, target, _ = model.evaluate(batch) 50 | predictions.append(pred) 51 | targets.append(target) 52 | log_stds.append(log_std) 53 | dir = os.path.join('new_active_learning', comp) 54 | if not os.path.isdir(dir): 55 | os.makedirs(dir) 56 | np.savetxt(os.path.join(dir, f'{seed}.txt'), torch.cat(predictions).cpu().numpy().reshape((-1,))) 57 | np.savetxt(os.path.join(dir, f'target.txt'), torch.cat(targets).cpu().numpy().reshape((-1,))) 58 | np.savetxt(os.path.join(dir, f'log_std_{seed}.txt'), torch.cat(log_stds).cpu().detach().numpy().reshape((-1,))) 59 | else: 60 | embeddings = [] 61 | for batch in loader: 62 | with torch.no_grad(): 63 | embeddings.append(model.evaluate(batch, return_graph_embedding=True)) 64 | dir = os.path.join('new_active_learning', comp) 65 | if not os.path.isdir(dir): 66 | os.makedirs(dir) 67 | np.savetxt(os.path.join(dir, f'graph_embeddings.txt'), torch.cat(embeddings).cpu().numpy()) 68 | 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /Utilities/prepare.sh: -------------------------------------------------------------------------------- 1 | for ((i=0; i < 2830000; i += 10000)) do 2 | python CGAT/prepare_volume_data.py --file data_"$i"_`expr $i + 10000`.pickle.gz & 3 | done 4 | -------------------------------------------------------------------------------- /Utilities/prepare_active_learning.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import gzip as gz 3 | import os 4 | from tqdm import tqdm 5 | from glob import glob 6 | from adjust_data import remove_batch_ids 7 | 8 | 9 | def load(path): 10 | return pickle.load(gz.open(path)) 11 | 12 | 13 | def save(data, path): 14 | pickle.dump(data, gz.open(path, 'wb')) 15 | 16 | 17 | def main(): 18 | data_dir = 'data' 19 | used_data_path = os.path.join('active_learning', 'sample', 'randoms', 'random_sample_150000.pickle.gz') 20 | target_dir = os.path.join('new_active_learning', 'remaining') 21 | 22 | if 
not os.path.isdir(target_dir): 23 | os.makedirs(target_dir) 24 | 25 | print("Loading test and validation data...") 26 | test_val_data = load(os.path.join(data_dir, 'indices', 'test_and_val_idx.pickle.gz')) 27 | test_ids = set([batch_id for batch_id, in test_val_data['test_batch_ids']]) 28 | val_ids = set([batch_id for batch_id in test_val_data['val_batch_ids']]) 29 | 30 | print("Loading training data...") 31 | used_data = load(used_data_path) 32 | used_ids = set([batch_id for batch_id, in used_data['batch_ids']]) 33 | 34 | ids = test_ids | val_ids | used_ids 35 | 36 | print("Gathering remaining data...") 37 | for file in tqdm(glob(os.path.join(data_dir, '*.pickle.gz'))): 38 | data = load(file) 39 | to_remove = set() 40 | 41 | for batch_id, in data['batch_ids']: 42 | if batch_id in ids: 43 | to_remove.add(batch_id) 44 | ids.remove(batch_id) 45 | 46 | remove_batch_ids(data, to_remove) 47 | save(data, os.path.join(target_dir, os.path.basename(file))) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /Utilities/sample.py: -------------------------------------------------------------------------------- 1 | import gzip as gz 2 | import pickle 3 | from pymatgen.core.periodic_table import Element 4 | from tqdm import trange, tqdm 5 | import matplotlib.pyplot as plt 6 | import random 7 | import numpy as np 8 | from typing import Union 9 | from metropolis import MarkovChain 10 | from pymatgen.entries.computed_entries import ComputedStructureEntry 11 | 12 | DIR = 'data' 13 | 14 | 15 | def getfile(i: int): 16 | return f'{DIR}/data_{i * 10000}_{(i + 1) * 10000}.pickle.gz' 17 | 18 | 19 | def find_closest(sample: Union[list, np.ndarray], target): 20 | # assume sorted list 21 | if target < sample[0]: 22 | return 0 23 | elif target > sample[-1]: 24 | return len(sample) - 1 25 | 26 | lower_bound = 0 27 | upper_bound = len(sample) - 1 28 | 29 | while lower_bound <= upper_bound: 30 | mid = (lower_bound + upper_bound) // 2 31 | value = sample[mid] 32 | 33 | if target < value: 34 | upper_bound = mid - 1 35 | elif target > value: 36 | lower_bound = mid + 1 37 | else: 38 | return mid 39 | 40 | if (sample[lower_bound] - target) < (target - sample[upper_bound]): 41 | return lower_bound 42 | else: 43 | return upper_bound 44 | 45 | 46 | def find_element(sample: list[set], el: int): 47 | for i, s in enumerate(sample): 48 | if el in s: 49 | return i 50 | 51 | 52 | def get_distribution(hist: Union[np.ndarray, list]): 53 | def f(x): 54 | x = int(x) 55 | return hist[x] 56 | 57 | return f 58 | 59 | 60 | def get_id(entry: Union[ComputedStructureEntry, str]) -> int: 61 | if isinstance(entry, ComputedStructureEntry): 62 | return int(entry.data['id'].split(',')[0]) 63 | elif isinstance(entry, str): 64 | return int(entry.split(',')[0]) 65 | 66 | 67 | def search(data: list[ComputedStructureEntry], batch_id: int): 68 | low = 0 69 | high = len(data) - 1 70 | 71 | while low <= high: 72 | mid = (high + low) // 2 73 | curr_id = get_id(data[mid]) 74 | if curr_id < batch_id: 75 | low = mid + 1 76 | elif curr_id > batch_id: 77 | high = mid - 1 78 | else: 79 | return mid 80 | raise ValueError 81 | 82 | 83 | def main(): 84 | batch_ids = [] 85 | elements = [] 86 | stoichiometries = [] 87 | 88 | test_val_data = pickle.load(gz.open('data/test_and_val_idx.pickle.gz', 'rb')) 89 | test_batch_ids = set([batch_id[0] for batch_id in test_val_data['test_batch_ids']]) 90 | val_batch_ids = set([batch_id[0] for batch_id in test_val_data['val_batch_ids']]) 91 | 
test_val_batch_ids = test_batch_ids | val_batch_ids 92 | 93 | # files = sorted([getfile(i) for i in range(283)]) 94 | 95 | for i in trange(283): 96 | data = pickle.load(gz.open(getfile(i), 'rb')) 97 | for j, d in enumerate(data['batch_ids']): 98 | if d[0] not in test_val_batch_ids: 99 | split = d[0].split(',') 100 | # batch_id = int(split[0]) 101 | _elements = set([Element(el.rstrip('0123456789')).Z for el in data['batch_comp'][j][0].split()]) 102 | batch_ids.append(d[0]) 103 | elements.append(_elements) 104 | stoichiometries.append(data['batch_comp'][j][0]) 105 | 106 | print(f'{len(set(stoichiometries)) / len(stoichiometries):.2%} unique stoichiometries') 107 | 108 | print('Calculating correlation matrix') 109 | biggest_element = max([max(els) for els in elements]) 110 | correlation_matrix = np.zeros((biggest_element, biggest_element)) 111 | for els in tqdm(elements): 112 | for i in els: 113 | for j in els: 114 | correlation_matrix[i - 1, j - 1] += 1 115 | 116 | correlation_matrix = ( 117 | correlation_matrix.T / np.where(correlation_matrix.diagonal() != 0, correlation_matrix.diagonal(), 118 | np.ones(biggest_element))).T 119 | 120 | for i in range(biggest_element): 121 | correlation_matrix[i, i] = 0 122 | 123 | y = correlation_matrix.mean(axis=0) 124 | distribution = get_distribution([min(150, i) for i in np.where(y > 1e-3, 1 / y, np.zeros_like(y))]) 125 | 126 | # indices = [i for i in range(len(spgs))] 127 | N = 50000 128 | random.seed(1) 129 | # sample = random.sample(indices, N) 130 | 131 | # all_elements = [el for l in elements for el in l] 132 | # 133 | # plt.figure() 134 | # plt.hist(all_elements, bins=max(all_elements)) 135 | # plt.title('Elements') 136 | # plt.savefig('elements.pdf') 137 | # plt.show() 138 | 139 | batch_ids = np.array(batch_ids) 140 | elements = np.array(elements) 141 | stoichiometries = np.array(stoichiometries) 142 | 143 | np.random.seed(0) 144 | args = [i for i in range(len(batch_ids))] 145 | np.random.shuffle(args) 146 | batch_ids = list(batch_ids[args]) 147 | elements = list(elements[args]) 148 | stoichiometries = list(stoichiometries[args]) 149 | 150 | # sample for elements 151 | # sample_batch_ids = set() 152 | # elements_sample = [] 153 | # stoichiometries_sample = set() 154 | # 155 | # chain = MarkovChain(distribution, lambda: random.randint(0, biggest_element - 1)) 156 | # chain.step(N) 157 | # 158 | # # element_list = list(set(all_elements)) 159 | # # for el in tqdm(random.choices(element_list, k=N)): 160 | # bar = tqdm(total=N) 161 | # while len(sample_batch_ids) < N: 162 | # chain.step(1) 163 | # el = chain[-1] + 1 164 | # while True: 165 | # i = find_element(elements, el) 166 | # if i is None: 167 | # break 168 | # stoichiometry = stoichiometries.pop(i) 169 | # if stoichiometry not in stoichiometries_sample: 170 | # sample_batch_ids.add(batch_ids.pop(i)) 171 | # elements_sample.append(elements.pop(i)) 172 | # stoichiometries_sample.add(stoichiometry) 173 | # bar.update(1) 174 | # break 175 | # else: 176 | # elements.pop(i) 177 | # batch_ids.pop(i) 178 | # bar.close() 179 | 180 | # random sample 181 | sample_batch_ids = set(random.sample(batch_ids, N)) 182 | 183 | # all_elements = [el for l in elements_sample for el in l] 184 | # 185 | # plt.figure() 186 | # plt.hist(all_elements, bins=max(all_elements)) 187 | # plt.title('sampled Elements') 188 | # plt.savefig('sampled_elements.pdf') 189 | # plt.show() 190 | sample_data = [] 191 | test_data = [] 192 | val_data = [] 193 | for i in trange(283): 194 | data = pickle.load(gz.open(getfile(i), 'rb')) 195 | 
        sample_indices = []
        test_val_indices = []
        curr_sample_batch_ids = []
        curr_test_batch_ids = []
        curr_val_batch_ids = []
        for j, batch_id in enumerate(data['batch_ids']):
            batch_id = batch_id[0]
            if batch_id in sample_batch_ids:
                sample_indices.append(j)
                sample_batch_ids.remove(batch_id)
                curr_sample_batch_ids.append(get_id(batch_id))
            elif batch_id in test_batch_ids:
                test_val_indices.append(j)
                test_val_batch_ids.remove(batch_id)
                test_batch_ids.remove(batch_id)
                curr_test_batch_ids.append(get_id(batch_id))
            elif batch_id in val_batch_ids:
                test_val_indices.append(j)
                test_val_batch_ids.remove(batch_id)
                val_batch_ids.remove(batch_id)
                curr_val_batch_ids.append(get_id(batch_id))
        if len(sample_indices) > 0 or len(test_val_indices) > 0:
            # fetch the corresponding raw entries from the unprepared data
            unprepared_data: list[ComputedStructureEntry] = pickle.load(
                gz.open(getfile(i).replace(f'{DIR}/', 'unprepared_volume_data/')))
            for batch_id in curr_sample_batch_ids:
                j = search(unprepared_data, batch_id)
                sample_data.append(unprepared_data.pop(j))
            for batch_id in curr_test_batch_ids:
                j = search(unprepared_data, batch_id)
                test_data.append(unprepared_data.pop(j))
            for batch_id in curr_val_batch_ids:
                j = search(unprepared_data, batch_id)
                val_data.append(unprepared_data.pop(j))
        # delete the extracted entries from the prepared chunk (highest index first)
        # and write the remainder back out for the next active-learning round
        all_used_indices = sorted(sample_indices + test_val_indices, reverse=True)
        data['input'] = np.delete(data['input'], all_used_indices, axis=1)
        for j in all_used_indices:
            data['batch_ids'].pop(j)
        data['batch_comp'] = np.delete(data['batch_comp'], all_used_indices)
        data['comps'] = np.delete(data['comps'], all_used_indices)
        for target in data['target']:
            data['target'][target] = np.delete(data['target'][target], all_used_indices)
        pickle.dump(data, gz.open(getfile(i).replace(f'{DIR}/', 'active_learning/'), 'wb'))

    pickle.dump(sample_data, gz.open('active_learning/unprepared_random_sample.pickle.gz', 'wb'))
    pickle.dump(test_data, gz.open('active_learning/unprepared_test_data.pickle.gz', 'wb'))
    pickle.dump(val_data, gz.open('active_learning/unprepared_val_data.pickle.gz', 'wb'))


if __name__ == '__main__':
    main()
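# Sanity-check sketch for the three pickles written above (illustrative only,
# to be run after the script has finished):
def _inspect_sample_output():
    sample = pickle.load(gz.open('active_learning/unprepared_random_sample.pickle.gz', 'rb'))
    print(len(sample))                    # should equal N (50000) if every sampled id was found
    print(sample[0].composition.formula)  # entries are pymatgen ComputedStructureEntry objects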
-------------------------------------------------------------------------------- /Utilities/train.sh: --------------------------------------------------------------------------------

python CGAT/train.py --gpus 1 --fea-path embeddings/matscholar-embedding.json --epochs 2 --target volume --data-path test-data/

-------------------------------------------------------------------------------- /Utilities/tsne.py: --------------------------------------------------------------------------------

import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import mean_absolute_error
from openTSNE import TSNE
from tqdm import tqdm


def cap(array, upper_limit):
    """Clip `array` element-wise from above at `upper_limit`."""
    return np.minimum(array, upper_limit)


def plot(embedding, value, **kwargs):
    plt.scatter(embedding[:, 0], embedding[:, 1], c=value, **kwargs)
    plt.colorbar()


def main():
    comps = Path('new_active_learning').glob('A*B*')

    embeddings = []
    errors = []
    colors = []
    targets = []
    comp_id = []

    cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

    for i, comp in enumerate(tqdm(list(comps))):
        files = comp.glob('*.txt')
        predictions = pd.DataFrame()
        for file in files:
            if file.stem == 'target':
                targets.append(np.loadtxt(file))
            elif 'log_std' in file.stem:
                pass
            elif file.stem == 'graph_embeddings':
                embeddings.append(np.loadtxt(file))
            else:
                try:
                    predictions[int(file.stem)] = np.loadtxt(file)
                except ValueError:
                    pass
        # mean absolute deviation of the ensemble predictions per structure
        errors.append(np.array([mean_absolute_error([targets[-1][i]] * len(row.keys()), row)
                                for i, row in predictions.iterrows()]))
        colors += [cycle[i % len(cycle)]] * len(errors[-1])
        comp_id += [i] * len(errors[-1])

    embeddings = np.concatenate(embeddings)
    errors = np.concatenate(errors)
    targets = np.concatenate(targets)

    errors = cap(errors, .4)
    targets = cap(targets, .5)

    titles = ['Error', 'Prototypes', 'distance to convex hull']
    values = [errors, colors, targets]

    embedding_list = []
    for metric in ['euclidean', 'cosine']:
        tsne = TSNE(n_jobs=6, perplexity=500, metric=metric, exaggeration=2)
        embedding = tsne.fit(embeddings)
        embedding_list.append(embedding)

        for title, value in zip(titles, values):
            plt.figure()
            plt.title(f'{title} -- {metric}')
            plot(embedding, value, s=.5)
            plt.show()

    # combine the two metrics by taking the element-wise product of the embeddings
    embedding = embedding_list[0]
    for e in embedding_list[1:]:
        embedding *= e
    for title, value in zip(titles, values):
        plt.figure()
        plt.title(f'{title} -- product')
        plot(embedding, value, s=.5)
        plt.show()


if __name__ == '__main__':
    main()
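# For reference, the per-structure error computed in main() above is just the
# mean absolute deviation of the ensemble predictions from the target; an
# equivalent vectorised sketch (file names hypothetical, illustrative only):
def _ensemble_errors_example():
    preds = np.stack([np.loadtxt(f) for f in ('100.txt', '200.txt')])
    target = np.loadtxt('target.txt')
    return np.abs(preds - target).mean(axis=0)  # same values as the MAE loop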
-------------------------------------------------------------------------------- /prepare_data.py: --------------------------------------------------------------------------------

import gzip as gz
import os
import argparse

import numpy as np
import warnings
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from CGAT.roost_message import LoadFeaturiser


def build_dataset_prepare(data,
                          target_property=["e_above_hull", "e_form"],
                          radius=18.0,
                          fea_path="embeddings/matscholar-embedding.json",
                          max_neighbor_number=24):
    """Calculate the input features for a pickled and gzipped list of ComputedEntry
    objects (either a path to the file or the file object itself) and return a
    dictionary with all inputs the model needs. If the data has no target values,
    the targets are set to -1e8. `target_property` should always be a list of
    property names."""

    def tensor2numpy(l):
        """Recursively convert torch Tensors into numpy arrays."""
        if isinstance(l, torch.Tensor):
            return l.numpy()
        elif isinstance(l, (str, int, float)):
            return l
        elif isinstance(l, (list, tuple)):
            return np.asarray([tensor2numpy(i) for i in l], dtype=object)
        elif isinstance(l, dict):
            return {name: tensor2numpy(val) for name, val in l.items()}
        else:
            return None  # unsupported types surface as errors later on

    d = CompositionDataPrepare(data,
                               fea_path=fea_path,
                               target_property=target_property,
                               max_neighbor_number=max_neighbor_number,
                               radius=radius)
    loader = DataLoader(d, batch_size=1)

    input1_ = []
    input2_ = []
    input3_ = []
    comps_ = []
    batch_comp_ = []
    if isinstance(target_property, list):
        target_ = {name: [] for name in target_property}
    else:
        target_ = []
    batch_ids_ = []

    for input_, target, batch_comp, batch_ids in tqdm(loader):
        if len(input_) == 1:  # skip entries flagged with the fake length-1 input
            continue
        input1_.append(input_[0])
        comps_.append(input_[1])
        input2_.append(input_[2])
        input3_.append(input_[3])
        if isinstance(target_property, list):
            for name in target_property:
                target_[name].append(target[name])
        else:
            target_.append(target)

        batch_comp_.append(batch_comp)
        batch_ids_.append(batch_ids)

    input1_ = tensor2numpy(input1_)
    input2_ = tensor2numpy(input2_)
    input3_ = tensor2numpy(input3_)

    # pack the per-crystal arrays into 1 x B object arrays so they can be stacked below
    shape = input1_.shape
    if len(shape) > 2:
        i1 = np.empty(shape=(1, shape[0]), dtype=object)
        i2 = np.empty(shape=(1, shape[0]), dtype=object)
        i3 = np.empty(shape=(1, shape[0]), dtype=object)
        i1[:, :] = [[input1_[l] for l in range(shape[0])]]
        input1_ = i1
        i2[:, :] = [[input2_[l] for l in range(shape[0])]]
        input2_ = i2
        i3[:, :] = [[input3_[l] for l in range(shape[0])]]
        input3_ = i3

    inputs_ = np.vstack((input1_, input2_, input3_))

    return {'input': inputs_,
            'batch_ids': batch_ids_,
            'batch_comp': tensor2numpy(batch_comp_),
            'target': tensor2numpy(target_),
            'comps': tensor2numpy(comps_)}
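# A minimal usage sketch (input path hypothetical, illustrative only): the
# returned dictionary is exactly what the *_features.pickle.gz training files
# contain.
def _example_build_dataset_prepare():
    features = build_dataset_prepare('dcgat_1_000.pickle.gz',
                                     target_property=['e_above_hull', 'e_form'])
    print(features['input'].shape)  # (3, n_crystals): nbr_fea, self_fea_idx and nbr_fea_idx rows
    pickle.dump(features, gz.open('dcgat_1_000_features.pickle.gz', 'wb'))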


class CompositionDataPrepare(Dataset):
    """Dataset wrapper that builds CGAT graph inputs from a list of pymatgen
    ComputedEntry / ComputedStructureEntry objects."""

    def __init__(self, data, fea_path, target_property=['e_form'], radius=18.0, max_neighbor_number=24):
        """
        Inputs
        ----------
        data (str or list): path to a gzipped pickle or the already-loaded list of entries
        fea_path (str): path to the element-embedding json (e.g. matscholar)
        target_property (list): names of the per-atom targets stored in entry.data
        radius (float): neighbor-search cutoff in Angstrom
        max_neighbor_number (int): number of neighbors kept per site
        """
        if isinstance(data, str):
            self.data = pickle.load(gz.open(data, 'rb'))
        else:
            self.data = data
        self.radius = radius
        self.max_num_nbr = max_neighbor_number
        self.target_property = target_property
        assert os.path.exists(fea_path), "{} does not exist!".format(fea_path)
        self.atom_features = LoadFeaturiser(fea_path)
        self.atom_fea_dim = self.atom_features.embedding_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            cry_id = self.data[idx].data['id']
        except KeyError:
            cry_id = 'unknown'
        composition = self.data[idx].composition.formula
        try:
            crystal = self.data[idx].structure
        except AttributeError:
            crystal = self.data[idx]  # the entry may already be a bare Structure

        elements = [element.specie.symbol for element in crystal]
        if len(set(elements)) == 1:
            # single-element systems are skipped; the fake length-1 input
            # is filtered out again in build_dataset_prepare
            return torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1)
        try:
            target = {}
            for name in self.target_property:
                target[name] = self.data[idx].data[name] / len(crystal.sites)
        except KeyError:
            target = {}
            warnings.warn('no target property')
            for name in self.target_property:
                target[name] = -1e8

        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        all_nbrs = [sorted(nbrs, key=lambda x: x[1])[0:self.max_num_nbr] for nbrs in all_nbrs]

        nbr_fea_idx, nbr_fea, self_fea_idx = [], [], []
        for site, nbr in enumerate(all_nbrs):
            nbr_fea_idx_sub, nbr_fea_sub, self_fea_idx_sub = [], [], []
            if len(nbr) < self.max_num_nbr:
                warnings.warn('{} does not contain enough neighbors in the cutoff to build the full graph. '
                              'If this happens frequently, consider increasing the radius. '
                              'The compound is not added to the feature set'.format(cry_id))
                # fake input, removed again in build_dataset_prepare
                return torch.ones(1), torch.ones(1), torch.ones(1), torch.ones(1)
            for n in range(self.max_num_nbr):
                self_fea_idx_sub.append(site)
            for j in range(self.max_num_nbr):
                nbr_fea_idx_sub.append(nbr[j][2])
            # rank the neighbors into distance shells: the nearest shell gets
            # index 1 and the index increases whenever the distance jumps
            index = 1
            dist = nbr[0][1]
            for j in range(self.max_num_nbr):
                if nbr[j][1] > dist + 1e-8:
                    dist = nbr[j][1]
                    index += 1
                nbr_fea_sub.append(index)
            nbr_fea_idx.append(nbr_fea_idx_sub)
            nbr_fea.append(nbr_fea_sub)
            self_fea_idx.append(self_fea_idx_sub)
        return (nbr_fea, elements, self_fea_idx, nbr_fea_idx), \
            target, composition, cry_id

    def get_targets(self, idx1, idx2):
        target = []
        l = []
        for el in idx2:
            l.append(self.data[el][self.target_property])
        for el in idx1:
            target.append(l[el])
        del l
        return torch.tensor(target).reshape(len(idx1), 1)
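# To make the distance-shell ranking in __getitem__ concrete, a toy
# re-implementation (illustrative only, not used by the module):
def _shell_indices_example(distances, tol=1e-8):
    """E.g. [2.0, 2.0, 2.0, 2.8, 2.8, 3.1] -> [1, 1, 1, 2, 2, 3]."""
    index, last = 1, distances[0]
    shells = []
    for d in distances:
        if d > last + tol:
            last = d
            index += 1
        shells.append(index)
    return shells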


def collate_batch(dataset_list):
    """
    Collate a list of data points and return a batch for predicting crystal
    properties.

    Parameters
    ----------
    dataset_list: list of tuples for each data point, each containing
        (atom_fea, nbr_fea, nbr_fea_idx, target)

        atom_fea: torch.Tensor shape (n_i, atom_fea_len)
        nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len)
        nbr_fea_idx: torch.LongTensor shape (n_i, M)
        target: torch.Tensor shape (1, )
        cif_id: str or int

    Returns
    -------
    N = sum(n_i); N0 = number of crystals in the batch

    batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
        Atom features from atom type
    batch_nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
        Bond features of each atom's M neighbors
    batch_nbr_fea_idx: torch.LongTensor shape (N, M)
        Indices of the M neighbors of each atom
    crystal_atom_idx: list of torch.LongTensor of length N0
        Mapping from the crystal idx to atom idx
    target: torch.Tensor shape (N, 1)
        Target value for prediction
    batch_cif_ids: list
    """
    # define the lists
    batch_atom_fea = []
    batch_nbr_fea = []
    batch_self_fea_idx = []
    batch_nbr_fea_idx = []
    crystal_atom_idx = []
    batch_target = []
    batch_comp = []
    batch_cry_ids = []

    cry_base_idx = 0
    for i, ((atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx, _),
            target, comp, cry_id) in enumerate(dataset_list):
        # number of atoms for this crystal
        n_i = atom_fea.shape[0]
        # batch the features together
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        # mappings from bonds to atoms, offset by the atoms already in the batch
        batch_self_fea_idx.append(self_fea_idx + cry_base_idx)
        batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx)

        # mapping from atoms to crystals
        crystal_atom_idx.append(torch.tensor([i] * n_i))

        # batch the targets and ids
        batch_target.append(target)
        batch_comp.append(comp)
        batch_cry_ids.append(cry_id)

        # increment the id counter
        cry_base_idx += n_i
    return (torch.cat(batch_atom_fea, dim=0), torch.cat(batch_nbr_fea, dim=0), torch.cat(batch_self_fea_idx, dim=0),
            torch.cat(batch_nbr_fea_idx, dim=0), torch.cat(crystal_atom_idx)), \
        torch.cat(batch_target, dim=0), \
        batch_comp, \
        batch_cry_ids
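# collate_batch is meant to be passed to a torch DataLoader as collate_fn;
# a minimal wiring sketch (`dataset` hypothetical, illustrative only):
def _example_collate_usage(dataset):
    loader = DataLoader(dataset, batch_size=128, collate_fn=collate_batch)
    inputs, targets, comps, cry_ids = next(iter(loader))
    return inputs, targets, comps, cry_ids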


def collate_batch2(dataset_list):
    """
    Same as collate_batch, but for data points whose input tuple is ordered
    (nbr_fea, atom_fea, self_fea_idx, nbr_fea_idx) instead of
    (atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx, _); the returned batch has
    the same structure as the one produced by collate_batch.
    """
    # define the lists
    batch_atom_fea = []
    batch_nbr_fea = []
    batch_self_fea_idx = []
    batch_nbr_fea_idx = []
    crystal_atom_idx = []
    batch_target = []
    batch_comp = []
    batch_cry_ids = []

    cry_base_idx = 0
    for i, ((nbr_fea, atom_fea, self_fea_idx, nbr_fea_idx),
            target, comp, cry_id) in enumerate(dataset_list):
        # number of atoms for this crystal
        n_i = atom_fea.shape[0]
        # batch the features together
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        # mappings from bonds to atoms, offset by the atoms already in the batch
        batch_self_fea_idx.append(self_fea_idx + cry_base_idx)
        batch_nbr_fea_idx.append(nbr_fea_idx + cry_base_idx)

        # mapping from atoms to crystals
        crystal_atom_idx.append(torch.tensor([i] * n_i))

        # batch the targets and ids
        batch_target.append(target)
        batch_comp.append(comp)
        batch_cry_ids.append(cry_id)

        # increment the id counter
        cry_base_idx += n_i
    return (torch.cat(batch_atom_fea, dim=0), torch.cat(batch_nbr_fea, dim=0), torch.cat(batch_self_fea_idx, dim=0),
            torch.cat(batch_nbr_fea_idx, dim=0), torch.cat(crystal_atom_idx)), \
        torch.cat(batch_target, dim=0), \
        batch_comp, \
        batch_cry_ids
""" 349 | 350 | def __init__(self, log=False): 351 | """tensor is taken as a sample to calculate the mean and std""" 352 | self.mean = torch.tensor((0)) 353 | self.std = torch.tensor((1)) 354 | 355 | def fit(self, tensor, dim=0, keepdim=False): 356 | """tensor is taken as a sample to calculate the mean and std""" 357 | self.mean = torch.mean(tensor, dim, keepdim) 358 | self.std = torch.std(tensor, dim, keepdim) 359 | 360 | def norm(self, tensor): 361 | return (tensor - self.mean) / self.std 362 | 363 | def denorm(self, normed_tensor): 364 | return normed_tensor * self.std + self.mean 365 | 366 | def state_dict(self): 367 | return {"mean": self.mean, 368 | "std": self.std} 369 | 370 | def load_state_dict(self, state_dict): 371 | self.mean = state_dict["mean"].cpu() 372 | self.std = state_dict["std"].cpu() 373 | 374 | 375 | def main(): 376 | parser = argparse.ArgumentParser() 377 | parser.add_argument('--file', default='dcgat_1_000.pickle.gz') 378 | parser.add_argument('--source-dir', default='./') 379 | parser.add_argument('--target-dir', default='./') 380 | parser.add_argument('--target-file', default='dcgat_1_000_features.pickle.gz') 381 | args = parser.parse_args() 382 | test = build_dataset_prepare(os.path.join(args.source_dir, args.file)) 383 | if args.target_file is None: 384 | pickle.dump(test, gz.open(os.path.join(args.target_dir, os.path.basename(args.file)), 'wb')) 385 | else: 386 | pickle.dump(test, gz.open(os.path.join(args.target_dir, args.target_file), 'wb')) 387 | 388 | 389 | if __name__ == '__main__': 390 | main() 391 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.21.4 2 | pandas>=1.3.4 3 | scikit-learn>=1.0.1 4 | pymatgen>=2022.0.16 5 | tqdm>=4.62.3 6 | setuptools>=58.0.4 7 | matplotlib>=3.4.3 8 | scipy>=1.7.1 9 | pytorch-lightning==1.5.8 -------------------------------------------------------------------------------- /runs/plot.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyllios/CGAT/0a5d34057f7ec131293a0b1b0c269cfb8ac63ef1/runs/plot.py -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = CGAT 3 | version = 0.1 4 | url = https://github.com/hyllios/CGAT 5 | license = MIT 6 | author = Jonathan Schmidt 7 | description = Crystal graph attention neural networks for materials prediction 8 | 9 | [options] 10 | packages = CGAT 11 | package_dir = =. 


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', default='dcgat_1_000.pickle.gz')
    parser.add_argument('--source-dir', default='./')
    parser.add_argument('--target-dir', default='./')
    parser.add_argument('--target-file', default='dcgat_1_000_features.pickle.gz')
    args = parser.parse_args()
    features = build_dataset_prepare(os.path.join(args.source_dir, args.file))
    if args.target_file is None:
        # fall back to the input file name inside the target directory
        pickle.dump(features, gz.open(os.path.join(args.target_dir, os.path.basename(args.file)), 'wb'))
    else:
        pickle.dump(features, gz.open(os.path.join(args.target_dir, args.target_file), 'wb'))


if __name__ == '__main__':
    main()

-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------

[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------

numpy>=1.21.4
pandas>=1.3.4
scikit-learn>=1.0.1
pymatgen>=2022.0.16
tqdm>=4.62.3
setuptools>=58.0.4
matplotlib>=3.4.3
scipy>=1.7.1
pytorch-lightning==1.5.8

-------------------------------------------------------------------------------- /runs/plot.py: --------------------------------------------------------------------------------

https://raw.githubusercontent.com/hyllios/CGAT/0a5d34057f7ec131293a0b1b0c269cfb8ac63ef1/runs/plot.py

-------------------------------------------------------------------------------- /setup.cfg: --------------------------------------------------------------------------------

[metadata]
name = CGAT
version = 0.1
url = https://github.com/hyllios/CGAT
license = MIT
author = Jonathan Schmidt
description = Crystal graph attention neural networks for materials prediction

[options]
packages = CGAT
package_dir = =.

[options.entry_points]
console_scripts =
    train-CGAT = CGAT.train:run
    prepare = CGAT.prepare_data:main
    train-GP = CGAT.gaussian_process:main

-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------

from setuptools import setup

if __name__ == '__main__':
    setup()

-------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------

"""Evaluate a trained CGAT checkpoint, e.g.
python test.py --ckp path/to/checkpoint.ckpt --gpus 1 --test"""
import os
from argparse import ArgumentParser

import numpy as np
import torch
import pytorch_lightning as pl

from CGAT.lightning_module import LightningModel

SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)


def main(hparams):
    """
    Testing routine.

    Args:
        hparams: parsed arguments holding the checkpoint of the model to be
            tested plus the gpu settings etc., defined in the argument parser
            under `if __name__ == '__main__':`
    """
    checkpoint_path = hparams.ckp
    model = LightningModel.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        train=hparams.train,
        test=hparams.test,
        test_path=hparams.test_path,
        val_path=hparams.val_path,
        fea_path=hparams.fea_path,
    )

    trainer = pl.Trainer(
        gpus=[hparams.first_gpu + el for el in range(hparams.gpus)],
    )

    trainer.test(model)


if __name__ == '__main__':
    root_dir = os.path.dirname(os.path.realpath(__file__))
    parent_parser = ArgumentParser(add_help=False)

    parent_parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help='how many gpus to use'
    )
    parent_parser.add_argument(
        '--amp_optimization',
        type=str,
        default='00',
        help='mixed-precision level: 00 (fp32), 01 (mixed), 02 (mostly fp16); '
             'should not be used during testing'
    )
    parent_parser.add_argument(
        '--first-gpu',
        type=int,
        default=0,
        help='first gpu to use; gpus [first_gpu, ..., first_gpu + gpus - 1] are used'
    )
    parent_parser.add_argument(
        '--ckp',
        type=str,
        default='',
        help='checkpoint path; if left empty no checkpoint is used'
    )
    parent_parser.add_argument(
        '--hparams',
        type=str,
        default='',
        help='path to the hparams of the checkpoint; if left empty no checkpoint is used'
    )
    parent_parser.add_argument(
        '--test',
        action='store_true',
        help='run in test mode instead of training'
    )

    # each LightningModule defines arguments relevant to it
    parser = LightningModel.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()

    print(hyperparams)
    main(hyperparams)

-------------------------------------------------------------------------------- /training_scripts/train.sh: --------------------------------------------------------------------------------

embedding_path="embeddings/matscholar-embedding.json"

# Training loop
for target in e_above_hull
do
    echo Training target "$target"
    train-CGAT --gpus 2 --target "$target" --fea-path "$embedding_path" --epochs 280 --clr-period 70 --data-path dcgat_1_000_features.pickle.gz --batch-size 2
done
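# Note: `train-CGAT` used above is the console-script entry point declared in
# setup.cfg, so the package has to be installed first; a minimal sketch:
#   pip install -r requirements.txt
#   pip install -e .    # registers train-CGAT, prepare and train-GP
#   train-CGAT --help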
-------------------------------------------------------------------------------- /training_scripts/transfer_full.sh: --------------------------------------------------------------------------------

# Fine-tune a pretrained model on a new dataset.
# Usage: transfer_full.sh <data-dir> <target> <pretrained-checkpoint>
val_data="$1/val"
test_data="$1/test"
embeddings="embeddings/matscholar-embedding.json"

train-CGAT --gpus 2 --target "$2" --data-path "$1" --val-path "$val_data" --test-path "$test_data" --fea-path "$embeddings" --epochs 390 --clr-period 70 --pretrained-model "$3"

-------------------------------------------------------------------------------- /training_scripts/transfer_only_residual.sh: --------------------------------------------------------------------------------

# Same as transfer_full.sh, but with the additional --only-residual flag.
# Usage: transfer_only_residual.sh <data-dir> <target> <pretrained-checkpoint>
val_data="$1/val"
test_data="$1/test"
embeddings="embeddings/matscholar-embedding.json"

train-CGAT --gpus 2 --target "$2" --data-path "$1" --val-path "$val_data" --test-path "$test_data" --fea-path "$embeddings" --epochs 390 --clr-period 70 --pretrained-model "$3" --only-residual

--------------------------------------------------------------------------------
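# Example invocation of the two transfer scripts (paths and target hypothetical):
#   bash training_scripts/transfer_full.sh my_dataset e_form runs/pretrained.ckpt
#   bash training_scripts/transfer_only_residual.sh my_dataset e_form runs/pretrained.ckpt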