├── .env ├── LICENSE ├── README.md ├── cgcnn ├── cgcnn │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── data.cpython-38.pyc │ │ └── model.cpython-38.pyc │ ├── data.py │ └── model.py ├── data │ └── mptest │ │ ├── atom_init.json │ │ ├── test.csv │ │ ├── train.csv │ │ └── val.csv ├── main.py ├── pre-trained │ ├── mpall20_Cbg │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── mpall20_Cfm │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── mpall20_bg │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── mpall20_fm │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── mpall40_Cbg │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── mpall40_Cfm │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── mpall40_bg │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── mpall40_fm │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── oqmd_Cbg │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── oqmd_Cfm │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ ├── oqmd_bg │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar │ └── oqmd_fm │ │ ├── checkpoint.pth.tar │ │ └── model_best.pth.tar └── predict.py ├── concdvae ├── PT_train │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── training.cpython-38.pyc │ └── training.py ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc ├── common │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── data_utils.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── data_utils.py │ └── utils.py ├── pl_data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── datamodule.cpython-38.pyc │ │ └── dataset.cpython-38.pyc │ ├── datamodule.py │ └── dataset.py ├── pl_modules │ ├── ConditionModel.py │ ├── PreCondition.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── ConditionModel.cpython-38.pyc │ │ ├── PreCondition.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── decoder.cpython-38.pyc │ │ ├── gnn.cpython-38.pyc │ │ └── model.cpython-38.pyc │ ├── decoder.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── atomic_radii.cpython-38.pyc │ │ │ ├── continuous_embeddings.cpython-38.pyc │ │ │ └── khot_embeddings.cpython-38.pyc │ │ ├── atomic_radii.py │ │ ├── continuous_embeddings.py │ │ └── khot_embeddings.py │ ├── gemnet │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── gemnet.cpython-38.pyc │ │ │ ├── initializers.cpython-38.pyc │ │ │ └── utils.cpython-38.pyc │ │ ├── fit_scaling.py │ │ ├── gemnet-dT.json │ │ ├── gemnet.py │ │ ├── initializers.py │ │ ├── layers │ │ │ ├── __pycache__ │ │ │ │ ├── atom_update_block.cpython-38.pyc │ │ │ │ ├── base_layers.cpython-38.pyc │ │ │ │ ├── basis_utils.cpython-38.pyc │ │ │ │ ├── efficient.cpython-38.pyc │ │ │ │ ├── embedding_block.cpython-38.pyc │ │ │ │ ├── interaction_block.cpython-38.pyc │ │ │ │ ├── radial_basis.cpython-38.pyc │ │ │ │ ├── scaling.cpython-38.pyc │ │ │ │ └── spherical_basis.cpython-38.pyc │ │ │ ├── atom_update_block.py │ │ │ ├── base_layers.py │ │ │ ├── basis_utils.py │ │ │ ├── efficient.py │ │ │ ├── embedding_block.py │ │ │ ├── interaction_block.py │ │ │ ├── radial_basis.py │ │ │ ├── scaling.py │ │ │ └── spherical_basis.py │ │ └── utils.py │ ├── gnn.py │ └── model.py ├── pt2CS.py └── run.py ├── conf ├── conz_1.yaml ├── conz_2.yaml ├── data │ ├── mp_all20.yaml │ ├── mp_all40.yaml │ ├── mptest.yaml │ └── oqmd_20.yaml ├── default.yaml ├── logging │ └── default.yaml ├── model │ ├── 
conditionmodel │ │ ├── mp_CSclass.yaml │ │ ├── mp_FMclass_BGclass.yaml │ │ ├── mp_format.yaml │ │ ├── mp_format_gap.yaml │ │ └── mp_gap.yaml │ ├── conditionpre │ │ ├── pre_mp_CSclass.yaml │ │ ├── pre_mp_FMclass_BGclass.yaml │ │ ├── pre_mp_format.yaml │ │ ├── pre_mp_format_gap.yaml │ │ └── pre_mp_gap.yaml │ ├── decoder │ │ └── gemnet.yaml │ ├── encoder │ │ └── dimenet.yaml │ ├── supervise.yaml │ ├── vae.yaml │ ├── vae_mp_CSclass.yaml │ ├── vae_mp_FMclass_BGclass.yaml │ ├── vae_mp_format.yaml │ ├── vae_mp_format_gap.yaml │ └── vae_mp_gap.yaml ├── optim │ ├── default.yaml │ ├── less1.yaml │ └── less2.yaml └── train │ ├── default.yaml │ └── new.yaml ├── data ├── mptest │ ├── test.csv │ ├── test_data.pt │ ├── train.csv │ ├── train_data.pt │ ├── val.csv │ └── val_data.pt └── mptest4conz │ ├── atom_init.json │ ├── test.csv │ ├── test_data.pt │ ├── train.csv │ ├── train_data.pt │ ├── val.csv │ └── val_data.pt ├── environment.yml ├── output └── hydra │ └── singlerun │ └── 2024-01-25 │ └── test │ ├── conz_loss_file_ABC.xlsx │ ├── conz_model_ABC_diffu.pth │ ├── general_full.csv │ ├── general_less.csv │ ├── hparams.yaml │ ├── lattice_scaler.pt │ ├── loss_file.xlsx │ ├── model_test.pth │ ├── model_test_notbest.pth │ └── run.log └── scripts ├── __pycache__ ├── condition_diff_z.cpython-38.pyc ├── condition_z.cpython-38.pyc ├── eval_utils.cpython-38.pyc └── test_conz.cpython-38-pytest-7.4.0.pyc ├── condition_diff_z.py ├── eval_utils.py ├── evaluate_diff.py ├── extra_z.py └── pt2cif.py /.env: -------------------------------------------------------------------------------- 1 | export PROJECT_ROOT="D:/2-project/0-MaterialDesign/3-CDVAE/con-cdvae" 2 | export HYDRA_JOBS="D:/2-project/0-MaterialDesign/3-CDVAE/con-cdvae/output/hydra" 3 | export WABDB_DIR="D:/2-project/0-MaterialDesign/3-CDVAE/con-cdvae/output/wabdb" 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Caiyuan Ye, Quanshen Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Con-CDVAE 2 | 3 | This code is improved on the basis of 4 | [CDVAE](https://arxiv.org/abs/2110.06197), 5 | and implements the generation of crystals according to 6 | the target properties. 
7 | 
8 | Ref: [Cai-Yuan Ye, Hong-Ming Weng, Quan-Sheng Wu, Con-CDVAE: A method for the conditional generation of crystal structures, Computational Materials Today, 1, 100003 (2024).](https://www.sciencedirect.com/science/article/pii/S2950463524000036)
9 | 
10 | arXiv: [https://arxiv.org/abs/2403.12478](https://arxiv.org/abs/2403.12478)
11 | 
12 | ## Installation
13 | It is easy to build a Python environment using conda.
14 | Run the following command to install the environment:
15 | ```bash
16 | conda env create -f environment.yml
17 | ```
18 | 
19 | Modify the following environment variables in `.env`.
20 | 
21 | - `PROJECT_ROOT`: path to the folder that contains this repo
22 | - `HYDRA_JOBS`: path to a folder to store hydra outputs
23 | - `WABDB_DIR`: path to a folder to store wandb outputs
24 | 
25 | ## Datasets
26 | 
27 | You can find a small sample of the dataset in `data/`,
28 | including the data used for Con-CDVAE two-step training.
29 | The complete data can easily be downloaded through the APIs
30 | provided by the [Materials Project (MP)](https://next-gen.materialsproject.org/)
31 | and the [Open Quantum Materials Database (OQMD)](https://oqmd.org/),
32 | and used in the same format as the sample.
33 | 
34 | ## Training Con-CDVAE
35 | 
36 | ### Step-one training
37 | To train a Con-CDVAE, run the following command first:
38 | 
39 | ```
40 | python concdvae/run.py train=new data=mptest expname=test model=vae_mp_CSclass
41 | ```
42 | 
43 | To use another dataset, prepare the data in the same format as
44 | the sample, add a new configuration file in the `conf/data/` folder,
45 | and pass `data=your_data_conf`. To train a model for another property, use
46 | `model=vae_mp_format` or `model=vae_mp_gap`.
47 | 
48 | If you want to accelerate with a GPU, set `accelerator=gpu`
49 | on the command line. If you want to accelerate with multiple GPUs,
50 | run this command:
51 | ```
52 | torchrun --nproc_per_node 4 concdvae/run.py train=new data=mptest expname=test model=vae_mp_CSclass accelerator=ddp
53 | ```
54 | After training, model checkpoints can be found in
55 | `$HYDRA_JOBS/singlerun/YYYY-MM-DD/model_expname.pth`.
56 | 
57 | 
58 | ### Step-two training
59 | After finishing step-one training, you can train the *Prior* block
60 | with the following command:
61 | ```
62 | python scripts/condition_diff_z.py --model_path /your_path_to_model_checkpoints/ --model_file model_expname.pth --fullfea 0 --label your_label
63 | ```
64 | Then you can find the default-condition *Prior* in
65 | `/your_path_to_model_checkpoints/conz_model_your_label_diffu.pth`.
66 | 
67 | If you want to train the full-condition *Prior*, change
68 | `--fullfea 0` to `--fullfea 1` and set
69 | `--newcond /your_path_to_conf/conf/conz_2.yaml --newdata mptest4conz`.
70 | 
71 | ## Generating crystals with target properties
72 | To generate materials, you should first prepare a condition file.
73 | You can see the examples in `/output/hydra/singlerun/2024-01-25/test/`,
74 | where `general_full.csv` is for the *default* and *full* strategies,
75 | and `general_less.csv` is for the *less* strategy, as sketched below.
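For reference, a condition file is simply a CSV with one row per crystal to generate. The sketch below builds one with pandas; the column names are assumptions for illustration only, and the real columns must match the condition names your model was trained on (see the property names in `conf/model/conditionmodel/`):

```python
import pandas as pd

# Hypothetical sketch: "band_gap" and "formation_energy_per_atom" are assumed
# column names used here for illustration; use the conditions your model was
# actually trained on. One row corresponds to one crystal to generate.
conditions = pd.DataFrame({
    "band_gap": [1.5, 2.0],                     # target band gap (eV)
    "formation_energy_per_atom": [-1.2, -0.8],  # target formation energy (eV/atom)
})
conditions.to_csv("general_full.csv", index=False)
```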
76 | 
77 | Then run the following command:
78 | ```
79 | python scripts/evaluate_diff.py --model_path /your_path_to_model_checkpoints/ --model_file model_expname.pth --conz_file conz_model_your_label_diffu.pth --label your_label --prop_path general_full.csv
80 | ```
81 | 
82 | If you want to filter latent variables using the *Predictor* block, set
83 | `--down_sample 100`, which means filtering at a ratio of one hundred
84 | to one.
85 | 
86 | ## Evaluating model
87 | 
88 | To evaluate the crystal system, you can use the script `concdvae/pt2CS.py`.
89 | 
90 | To evaluate other properties, you should train a
91 | [CGCNN](https://github.com/txie-93/cgcnn) with the following command:
92 | ```
93 | python cgcnn/main.py /your_path_to_con-cdvae/cgcnn/data/mptest --prop band_gap --label your_label
94 | ```
95 | This code uses the same dataset as Con-CDVAE; you can build
96 | the required database using the methods mentioned earlier.
97 | If you want to train CGCNN on another property, you can set
98 | `--prop formation_energy_per_atom`, `--prop BG_type`, or `--prop FM_type`.
99 | Note that if you are training for a
100 | classification task, you should set `--task classification`.
101 | 
102 | After training, model checkpoints can be found in
103 | `your_labelmodel_best.pth.tar`, and ready-made trained models are provided in
104 | `cgcnn/pre-trained`.
105 | 
106 | When you have generated crystals and need to evaluate them,
107 | run the following command:
108 | ```
109 | python cgcnn/predict.py --gendatapath /your_path_to_generated_crystal/ --modelpath /your_path_to_cgcnn_model/model_best.pth.tar --file your_crystal_file.pt --label your_label
110 | ```
111 | 
--------------------------------------------------------------------------------
/cgcnn/cgcnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/cgcnn/__init__.py
--------------------------------------------------------------------------------
/cgcnn/cgcnn/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/cgcnn/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/cgcnn/cgcnn/__pycache__/data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/cgcnn/__pycache__/data.cpython-38.pyc
--------------------------------------------------------------------------------
/cgcnn/cgcnn/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/cgcnn/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/cgcnn/cgcnn/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | 
3 | import torch
4 | import torch.nn as nn
5 | 
6 | 
7 | class ConvLayer(nn.Module):
8 |     """
9 |     Convolutional operation on graphs
10 |     """
11 |     def __init__(self, atom_fea_len, nbr_fea_len):
12 |         """
13 |         Initialize ConvLayer.
14 | 15 | Parameters 16 | ---------- 17 | 18 | atom_fea_len: int 19 | Number of atom hidden features. 20 | nbr_fea_len: int 21 | Number of bond features. 22 | """ 23 | super(ConvLayer, self).__init__() 24 | self.atom_fea_len = atom_fea_len 25 | self.nbr_fea_len = nbr_fea_len 26 | self.fc_full = nn.Linear(2*self.atom_fea_len+self.nbr_fea_len, 27 | 2*self.atom_fea_len) 28 | self.sigmoid = nn.Sigmoid() 29 | self.softplus1 = nn.Softplus() 30 | self.bn1 = nn.BatchNorm1d(2*self.atom_fea_len) 31 | self.bn2 = nn.BatchNorm1d(self.atom_fea_len) 32 | self.softplus2 = nn.Softplus() 33 | 34 | def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx): 35 | """ 36 | Forward pass 37 | 38 | N: Total number of atoms in the batch 39 | M: Max number of neighbors 40 | 41 | Parameters 42 | ---------- 43 | 44 | atom_in_fea: Variable(torch.Tensor) shape (N, atom_fea_len) 45 | Atom hidden features before convolution 46 | nbr_fea: Variable(torch.Tensor) shape (N, M, nbr_fea_len) 47 | Bond features of each atom's M neighbors 48 | nbr_fea_idx: torch.LongTensor shape (N, M) 49 | Indices of M neighbors of each atom 50 | 51 | Returns 52 | ------- 53 | 54 | atom_out_fea: nn.Variable shape (N, atom_fea_len) 55 | Atom hidden features after convolution 56 | 57 | """ 58 | # TODO will there be problems with the index zero padding? 59 | N, M = nbr_fea_idx.shape 60 | # convolution 61 | atom_nbr_fea = atom_in_fea[nbr_fea_idx, :] 62 | total_nbr_fea = torch.cat( 63 | [atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len), 64 | atom_nbr_fea, nbr_fea], dim=2) 65 | total_gated_fea = self.fc_full(total_nbr_fea) 66 | total_gated_fea = self.bn1(total_gated_fea.view( 67 | -1, self.atom_fea_len*2)).view(N, M, self.atom_fea_len*2) 68 | nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2) 69 | nbr_filter = self.sigmoid(nbr_filter) 70 | nbr_core = self.softplus1(nbr_core) 71 | nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1) 72 | nbr_sumed = self.bn2(nbr_sumed) 73 | out = self.softplus2(atom_in_fea + nbr_sumed) 74 | return out 75 | 76 | 77 | class CrystalGraphConvNet(nn.Module): 78 | """ 79 | Create a crystal graph convolutional neural network for predicting total 80 | material properties. 81 | """ 82 | def __init__(self, orig_atom_fea_len, nbr_fea_len, 83 | atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1, 84 | classification=False): 85 | """ 86 | Initialize CrystalGraphConvNet. 87 | 88 | Parameters 89 | ---------- 90 | 91 | orig_atom_fea_len: int 92 | Number of atom features in the input. 93 | nbr_fea_len: int 94 | Number of bond features. 
95 | atom_fea_len: int 96 | Number of hidden atom features in the convolutional layers 97 | n_conv: int 98 | Number of convolutional layers 99 | h_fea_len: int 100 | Number of hidden features after pooling 101 | n_h: int 102 | Number of hidden layers after pooling 103 | """ 104 | super(CrystalGraphConvNet, self).__init__() 105 | self.classification = classification 106 | self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len) 107 | self.convs = nn.ModuleList([ConvLayer(atom_fea_len=atom_fea_len, 108 | nbr_fea_len=nbr_fea_len) 109 | for _ in range(n_conv)]) 110 | self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len) 111 | self.conv_to_fc_softplus = nn.Softplus() 112 | if n_h > 1: 113 | self.fcs = nn.ModuleList([nn.Linear(h_fea_len, h_fea_len) 114 | for _ in range(n_h-1)]) 115 | self.softpluses = nn.ModuleList([nn.Softplus() 116 | for _ in range(n_h-1)]) 117 | if self.classification: 118 | self.fc_out = nn.Linear(h_fea_len, 2) 119 | else: 120 | self.fc_out = nn.Linear(h_fea_len, 1) 121 | if self.classification: 122 | self.logsoftmax = nn.LogSoftmax(dim=1) 123 | self.dropout = nn.Dropout() 124 | 125 | def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx): 126 | """ 127 | Forward pass 128 | 129 | N: Total number of atoms in the batch 130 | M: Max number of neighbors 131 | N0: Total number of crystals in the batch 132 | 133 | Parameters 134 | ---------- 135 | 136 | atom_fea: Variable(torch.Tensor) shape (N, orig_atom_fea_len) 137 | Atom features from atom type 138 | nbr_fea: Variable(torch.Tensor) shape (N, M, nbr_fea_len) 139 | Bond features of each atom's M neighbors 140 | nbr_fea_idx: torch.LongTensor shape (N, M) 141 | Indices of M neighbors of each atom 142 | crystal_atom_idx: list of torch.LongTensor of length N0 143 | Mapping from the crystal idx to atom idx 144 | 145 | Returns 146 | ------- 147 | 148 | prediction: nn.Variable shape (N, ) 149 | Atom hidden features after convolution 150 | 151 | """ 152 | atom_fea = self.embedding(atom_fea) 153 | for conv_func in self.convs: 154 | atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx) 155 | crys_fea = self.pooling(atom_fea, crystal_atom_idx) 156 | crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea)) 157 | crys_fea = self.conv_to_fc_softplus(crys_fea) 158 | if self.classification: 159 | crys_fea = self.dropout(crys_fea) 160 | if hasattr(self, 'fcs') and hasattr(self, 'softpluses'): 161 | for fc, softplus in zip(self.fcs, self.softpluses): 162 | crys_fea = softplus(fc(crys_fea)) 163 | out = self.fc_out(crys_fea) 164 | if self.classification: 165 | out = self.logsoftmax(out) 166 | return out 167 | 168 | def pooling(self, atom_fea, crystal_atom_idx): 169 | """ 170 | Pooling the atom features to crystal features 171 | 172 | N: Total number of atoms in the batch 173 | N0: Total number of crystals in the batch 174 | 175 | Parameters 176 | ---------- 177 | 178 | atom_fea: Variable(torch.Tensor) shape (N, atom_fea_len) 179 | Atom feature vectors of the batch 180 | crystal_atom_idx: list of torch.LongTensor of length N0 181 | Mapping from the crystal idx to atom idx 182 | """ 183 | assert sum([len(idx_map) for idx_map in crystal_atom_idx]) ==\ 184 | atom_fea.data.shape[0] 185 | summed_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True) 186 | for idx_map in crystal_atom_idx] 187 | return torch.cat(summed_fea, dim=0) 188 | -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall20_Cbg/checkpoint.pth.tar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall20_Cbg/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall20_Cbg/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall20_Cbg/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall20_Cfm/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall20_Cfm/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall20_Cfm/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall20_Cfm/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall20_bg/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall20_bg/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall20_bg/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall20_bg/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall20_fm/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall20_fm/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall20_fm/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall20_fm/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall40_Cbg/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall40_Cbg/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall40_Cbg/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall40_Cbg/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall40_Cfm/checkpoint.pth.tar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall40_Cfm/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall40_Cfm/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall40_Cfm/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall40_bg/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall40_bg/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall40_bg/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall40_bg/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall40_fm/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall40_fm/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/mpall40_fm/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/mpall40_fm/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/oqmd_Cbg/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/oqmd_Cbg/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/oqmd_Cbg/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/oqmd_Cbg/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/oqmd_Cfm/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/oqmd_Cfm/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/oqmd_Cfm/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/oqmd_Cfm/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/oqmd_bg/checkpoint.pth.tar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/oqmd_bg/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/oqmd_bg/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/oqmd_bg/model_best.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/oqmd_fm/checkpoint.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/oqmd_fm/checkpoint.pth.tar -------------------------------------------------------------------------------- /cgcnn/pre-trained/oqmd_fm/model_best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/cgcnn/pre-trained/oqmd_fm/model_best.pth.tar -------------------------------------------------------------------------------- /concdvae/PT_train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/PT_train/__init__.py -------------------------------------------------------------------------------- /concdvae/PT_train/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/PT_train/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/PT_train/__pycache__/training.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/PT_train/__pycache__/training.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/__init__.py -------------------------------------------------------------------------------- /concdvae/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/common/__init__.py -------------------------------------------------------------------------------- /concdvae/common/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/common/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/common/__pycache__/data_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/common/__pycache__/data_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/common/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/common/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/common/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Optional
4 | 
5 | import dotenv
6 | from omegaconf import DictConfig, OmegaConf
7 | 
8 | def get_env(env_name: str, default: Optional[str] = None) -> str:
9 |     """
10 |     Safely read an environment variable.
11 |     Raises an error if it is undefined or empty and no default is given.
12 | 
13 |     :param env_name: the name of the environment variable
14 |     :param default: the default (optional) value for the environment variable
15 | 
16 |     :return: the value of the environment variable
17 |     """
18 | 
19 |     if env_name not in os.environ:
20 |         if default is None:
21 |             raise KeyError(
22 |                 f"{env_name} not defined and no default value is present!")
23 |         return default
24 | 
25 |     env_value: str = os.environ[env_name]
26 |     if not env_value:
27 |         if default is None:
28 |             raise ValueError(
29 |                 f"{env_name} is empty and no default value is present!"
30 |             )
31 |         return default
32 | 
33 |     return env_value
34 | 
35 | 
36 | def load_envs(env_file: Optional[str] = None) -> None:
37 |     """
38 |     Load all the environment variables defined in the `env_file`.
39 |     This is equivalent to `. env_file` in bash.
40 | 
41 |     It is possible to define all the system specific variables in the `env_file`.
42 | 
43 |     :param env_file: the file that defines the environment variables to use. If None
44 |     it searches for a `.env` file in the project.
45 |     """
46 |     dotenv.load_dotenv(dotenv_path=env_file, override=True)
47 | 
48 | 
49 | def param_statistics(model):
50 |     # Total params
51 |     total_params = sum(p.numel() for p in model.parameters())
52 |     print("Total params:", total_params)
53 | 
54 |     # Trainable params
55 |     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
56 |     print("Trainable params:", trainable_params)
57 | 
58 |     # Space
59 |     total_params_size = sum(p.numel() * p.element_size() for p in model.parameters())
60 |     print("Used Space(MB):", round(total_params_size/1024/1024, 2))
61 | 
62 | STATS_KEY: str = "stats"
63 | 
64 | 
65 | # Load environment variables
66 | load_envs()
67 | 
68 | # Set the cwd to the project root
69 | PROJECT_ROOT: Path = Path(get_env("PROJECT_ROOT"))
70 | assert (
71 |     PROJECT_ROOT.exists()
72 | ), "You must configure the PROJECT_ROOT environment variable in a .env file!"
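# Hypothetical usage sketch (illustrative only, not called elsewhere in this
# repo): `get_env` falls back to `default` when the variable is unset or empty.
#   hydra_jobs = get_env("HYDRA_JOBS", default=str(PROJECT_ROOT / "output" / "hydra"))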
73 | 74 | os.chdir(PROJECT_ROOT) -------------------------------------------------------------------------------- /concdvae/pl_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_data/__init__.py -------------------------------------------------------------------------------- /concdvae/pl_data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_data/__pycache__/datamodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_data/__pycache__/datamodule.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_data/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_data/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_data/datamodule.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional, Sequence 3 | from pathlib import Path 4 | 5 | import hydra 6 | import numpy as np 7 | import os 8 | import omegaconf 9 | import torch 10 | from omegaconf import DictConfig 11 | from torch.utils.data import Dataset 12 | from torch_geometric.loader import DataLoader 13 | from torch.utils.data.distributed import DistributedSampler 14 | 15 | from concdvae.common.data_utils import get_scaler_from_data_list, get_maxAmin_from_data_list,GaussianDistance 16 | 17 | 18 | def worker_init_fn(id: int): 19 | """ 20 | DataLoaders workers init function. 21 | 22 | Initialize the numpy.random seed correctly for each worker, so that 23 | random augmentations between workers and/or epochs are not identical. 24 | 25 | If a global seed is set, the augmentations are deterministic. 26 | 27 | https://pytorch.org/docs/stable/notes/randomness.html#dataloader 28 | """ 29 | uint64_seed = torch.initial_seed() 30 | ss = np.random.SeedSequence([uint64_seed]) 31 | # More than 128 bits (4 32-bit words) would be overkill. 
32 |     np.random.seed(ss.generate_state(4))
33 |     random.seed(uint64_seed)
34 | 
35 | 
36 | class CrystDataModule():
37 |     def __init__(
38 |         self,
39 |         accelerator,
40 |         n_delta,
41 |         use_prop,
42 |         datasets: DictConfig,
43 |         num_workers: DictConfig,
44 |         batch_size: DictConfig,
45 |         scaler_path=None,
46 |     ):
47 |         super().__init__()
48 |         self.datasets = datasets
49 |         self.num_workers = num_workers
50 |         self.batch_size = batch_size
51 | 
52 |         self.train_dataset: Optional[Dataset] = None
53 |         self.val_datasets: Optional[Sequence[Dataset]] = None
54 |         self.test_datasets: Optional[Sequence[Dataset]] = None
55 | 
56 |         # Load a cached, preprocessed dataset if it exists; otherwise build it
57 |         train_path = self.datasets['train']['path']
57 |         train_path = os.path.dirname(train_path)
58 |         train_path = os.path.join(train_path, 'train_data.pt')
59 |         if os.path.exists(train_path):
60 |             self.train_dataset = torch.load(train_path)
61 |         else:
62 |             self.train_dataset = hydra.utils.instantiate(self.datasets.train, _recursive_=False)
63 |             torch.save(self.train_dataset, train_path)
64 |         print('load train')
65 | 
66 |         self.lattice_scaler = get_scaler_from_data_list(
67 |             self.train_dataset.cached_data,
68 |             key='scaled_lattice')
69 | 
70 |         val_path = self.datasets['val'][0]['path']
71 |         val_path = os.path.dirname(val_path)
72 |         val_path = os.path.join(val_path, 'val_data.pt')
73 |         if os.path.exists(val_path):
74 |             self.val_datasets = [torch.load(val_path)]
75 |         else:
76 |             self.val_datasets = [hydra.utils.instantiate(dataset_cfg)
77 |                                  for dataset_cfg in self.datasets.val]
78 |             torch.save(self.val_datasets[0], val_path)
79 |         print('load val')
80 |         for val_dataset in self.val_datasets:
81 |             val_dataset.lattice_scaler = self.lattice_scaler
82 | 
83 |         test_path = self.datasets['test'][0]['path']
84 |         test_path = os.path.dirname(test_path)
85 |         test_path = os.path.join(test_path, 'test_data.pt')
86 |         if os.path.exists(test_path):
87 |             self.test_datasets = [torch.load(test_path)]
88 |         else:
89 |             self.test_datasets = [hydra.utils.instantiate(dataset_cfg)
90 |                                   for dataset_cfg in self.datasets.test]
91 |             torch.save(self.test_datasets[0], test_path)
92 |         print('load test')
93 |         for test_dataset in self.test_datasets:
94 |             test_dataset.lattice_scaler = self.lattice_scaler
95 | 
96 |         if accelerator == 'DDP':
97 |             train_shuffle = False
98 |             train_sampler = DistributedSampler(self.train_dataset)
99 |             val_samplers = [DistributedSampler(dataset) for dataset in self.val_datasets]
100 |             test_samplers = [DistributedSampler(dataset) for dataset in self.test_datasets]
101 |         else:
102 |             train_shuffle = True
103 |             train_sampler = None
104 |             val_samplers = [None for dataset in self.val_datasets]
105 |             test_samplers = [None for dataset in self.test_datasets]
106 | 
107 | 
108 |         self.train_dataloader = DataLoader(
109 |             self.train_dataset,
110 |             shuffle=train_shuffle,
111 |             batch_size=self.batch_size.train,
112 |             num_workers=self.num_workers.train,
113 |             worker_init_fn=worker_init_fn,
114 |             sampler=train_sampler,
115 |         )
116 | 
117 |         self.val_dataloaders = [
118 |             DataLoader(
119 |                 self.val_datasets[i],
120 |                 shuffle=False,
121 |                 batch_size=self.batch_size.val,
122 |                 num_workers=self.num_workers.val,
123 |                 worker_init_fn=worker_init_fn,
124 |                 sampler=val_samplers[i]
125 |             )
126 |             for i in range(len(self.val_datasets))]
127 | 
128 |         self.test_dataloaders = [
129 |             DataLoader(
130 |                 self.test_datasets[i],
131 |                 shuffle=False,
132 |                 batch_size=self.batch_size.test,
133 |                 num_workers=self.num_workers.test,
134 |                 worker_init_fn=worker_init_fn,
135 |                 sampler=test_samplers[i]
136 |             )
137 |             for i in
range(len(self.test_datasets))] 138 | 139 | def __repr__(self) -> str: 140 | return ( 141 | f"{self.__class__.__name__}(" 142 | f"{self.datasets=}, " 143 | f"{self.num_workers=}, " 144 | f"{self.batch_size=})" 145 | ) -------------------------------------------------------------------------------- /concdvae/pl_data/dataset.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import omegaconf 3 | import torch 4 | import pandas as pd 5 | import numpy as np 6 | import os 7 | import json 8 | from omegaconf import ValueNode 9 | from torch.utils.data import Dataset 10 | 11 | from torch_geometric.data import Data 12 | from pymatgen.core.structure import Structure 13 | 14 | from concdvae.common.utils import PROJECT_ROOT 15 | from concdvae.common.data_utils import ( 16 | preprocess, add_scaled_lattice_prop,chemical_symbols) 17 | 18 | class CrystDataset(Dataset): 19 | def __init__(self, name: ValueNode, path: ValueNode, 20 | prop: ValueNode, use_prop: ValueNode, niggli: ValueNode, primitive: ValueNode, 21 | graph_method: ValueNode, preprocess_workers: ValueNode, 22 | lattice_scale_method: ValueNode, 23 | **kwargs): 24 | super().__init__() 25 | self.path = path 26 | self.name = name 27 | self.df = pd.read_csv(path) 28 | self.prop = prop 29 | self.use_prop = use_prop 30 | self.niggli = niggli 31 | self.primitive = primitive 32 | self.graph_method = graph_method 33 | self.lattice_scale_method = lattice_scale_method 34 | 35 | 36 | 37 | self.cached_data = preprocess( 38 | self.path, 39 | preprocess_workers, 40 | niggli=self.niggli, 41 | primitive=self.primitive, 42 | graph_method=self.graph_method, 43 | prop_list=list(prop)) 44 | 45 | add_scaled_lattice_prop(self.cached_data, lattice_scale_method) 46 | self.lattice_scaler = None 47 | 48 | atom_init_file = os.path.dirname(self.path) 49 | atom_init_file = os.path.join(atom_init_file, 'atom_init.json') 50 | if os.path.exists(atom_init_file): 51 | self.ari = AtomCustomJSONInitializer(atom_init_file) 52 | for i in range(len(self.cached_data)): 53 | crystal = Structure.from_str(self.cached_data[i]['cif'], fmt="cif") 54 | 55 | atom_fea = np.vstack([self.ari.get_atom_fea(crystal[i].specie.number) 56 | for i in range(len(crystal))]) 57 | atom_fea = torch.Tensor(atom_fea) 58 | atom_fea = torch.mean(atom_fea, dim=0) 59 | atom_fea = atom_fea.reshape(1, 92) 60 | self.cached_data[i].update({'formula':atom_fea}) 61 | else: 62 | self.ari = None 63 | 64 | def __len__(self) -> int: 65 | return len(self.cached_data) 66 | 67 | def __getitem__(self, index): 68 | data_dict = self.cached_data[index] 69 | 70 | (frac_coords, atom_types, lengths, angles, edge_indices, 71 | to_jimages, num_atoms) = data_dict['graph_arrays'] 72 | 73 | # atom_coords are fractional coordinates 74 | # edge_index is incremented during batching 75 | # https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html 76 | data = Data( 77 | frac_coords=torch.Tensor(frac_coords), 78 | atom_types=torch.LongTensor(atom_types), 79 | lengths=torch.Tensor(lengths).view(1, -1), 80 | angles=torch.Tensor(angles).view(1, -1), 81 | edge_index=torch.LongTensor( 82 | edge_indices.T).contiguous(), # shape (2, num_edges) 83 | to_jimages=torch.LongTensor(to_jimages), 84 | num_atoms=num_atoms, 85 | num_bonds=edge_indices.shape[0], 86 | num_nodes=num_atoms, # special attribute used for batching in pytorch geometric 87 | ) 88 | 89 | exclude_keys = ['cif', 'graph_arrays', 'scaled_lattice'] 90 | filtered_data = {key: value for key, value in data_dict.items() if key not 
in exclude_keys} 91 | data.update(filtered_data) 92 | 93 | if self.ari != None: 94 | data.update({'formula': self.cached_data[index]['formula']}) 95 | 96 | return data 97 | 98 | def __repr__(self) -> str: 99 | return f"TensorCrystDataset(len: {len(self.cached_data)})" 100 | 101 | 102 | class AtomInitializer(object): 103 | """ 104 | Base class for intializing the vector representation for atoms. 105 | 106 | !!! Use one AtomInitializer per dataset !!! 107 | """ 108 | def __init__(self, atom_types): 109 | self.atom_types = set(atom_types) 110 | self._embedding = {} 111 | 112 | def get_atom_fea(self, atom_type): 113 | assert atom_type in self.atom_types 114 | return self._embedding[atom_type] 115 | 116 | def load_state_dict(self, state_dict): 117 | self._embedding = state_dict 118 | self.atom_types = set(self._embedding.keys()) 119 | self._decodedict = {idx: atom_type for atom_type, idx in 120 | self._embedding.items()} 121 | 122 | def state_dict(self): 123 | return self._embedding 124 | 125 | def decode(self, idx): 126 | if not hasattr(self, '_decodedict'): 127 | self._decodedict = {idx: atom_type for atom_type, idx in 128 | self._embedding.items()} 129 | return self._decodedict[idx] 130 | 131 | 132 | class AtomCustomJSONInitializer(AtomInitializer): 133 | """ 134 | Initialize atom feature vectors using a JSON file, which is a python 135 | dictionary mapping from element number to a list representing the 136 | feature vector of the element. 137 | 138 | Parameters 139 | ---------- 140 | 141 | elem_embedding_file: str 142 | The path to the .json file 143 | """ 144 | def __init__(self, elem_embedding_file): 145 | with open(elem_embedding_file) as f: 146 | elem_embedding = json.load(f) 147 | elem_embedding = {int(key): value for key, value 148 | in elem_embedding.items()} 149 | atom_types = set(elem_embedding.keys()) 150 | super(AtomCustomJSONInitializer, self).__init__(atom_types) 151 | for key, value in elem_embedding.items(): 152 | self._embedding[key] = np.array(value, dtype=float) 153 | 154 | 155 | def formula2atomnums(formula): 156 | elements = [] 157 | current_element = "" 158 | current_count = "" 159 | 160 | for char in formula: 161 | if char.isupper(): 162 | if current_element: 163 | elements.append((current_element, int(current_count) if current_count else 1)) 164 | current_element = char 165 | current_count = "" 166 | elif char.islower(): 167 | current_element += char 168 | elif char.isdigit(): 169 | current_count += char 170 | 171 | if current_element: 172 | elements.append((current_element, int(current_count) if current_count else 1)) 173 | 174 | ele_list = [] 175 | for data in elements: 176 | for time in range(data[1]): 177 | ele_list.append(data[0]) 178 | 179 | 180 | index_list = [] 181 | for ele in ele_list: 182 | index = chemical_symbols.index(ele) 183 | index_list.append(index) 184 | 185 | 186 | return index_list -------------------------------------------------------------------------------- /concdvae/pl_modules/ConditionModel.py: -------------------------------------------------------------------------------- 1 | ##This part of the code refers to https://github.com/atomistic-machine-learning/SchNet 2 | 3 | import torch.nn as nn 4 | from concdvae.pl_modules.model import build_mlp 5 | import hydra 6 | import torch 7 | import math 8 | import sys 9 | from typing import Dict, Optional, List, Callable, Union, Sequence 10 | from omegaconf import ListConfig 11 | 12 | class ConditioningModule(nn.Module): 13 | def __init__( 14 | self, 15 | n_features, 16 | n_layers, 17 | 
condition_embeddings, 18 | ): 19 | 20 | super(ConditioningModule, self).__init__() 21 | self.n_features = n_features 22 | self.condition_embeddings = condition_embeddings 23 | condition_embModel = [] 24 | # self.condition_embModel = condition_embeddings 25 | n_in = 0 26 | for condition_emb in self.condition_embeddings: 27 | condition_embModel.append(hydra.utils.instantiate(condition_emb)) 28 | n_in += condition_emb.n_features 29 | self.condition_embModel = nn.ModuleList(condition_embModel) 30 | 31 | self.dense_net = build_mlp( 32 | in_dim=n_in, 33 | out_dim=self.n_features, 34 | hidden_dim=self.n_features, 35 | fc_num_layers=n_layers, 36 | norm=False, 37 | ) 38 | 39 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 40 | # embed all conditions 41 | emb_features = [] 42 | for emb in self.condition_embModel: 43 | emb_features += [emb(inputs)] 44 | # concatenate the features 45 | emb_features = torch.cat(emb_features, dim=-1) 46 | # mix the concatenated features 47 | conditional_features = self.dense_net(emb_features) 48 | return conditional_features 49 | 50 | 51 | 52 | class ConditionEmbedding(nn.Module): 53 | 54 | def __init__( 55 | self, 56 | condition_name: str, 57 | n_features: int, 58 | required_data_properties: Optional[List[str]] = [], 59 | condition_type: str = "trajectory", 60 | ): 61 | 62 | super().__init__() 63 | if condition_type not in ["trajectory", "step", "atom"]: 64 | raise ValueError( 65 | f"`condition_type` is {condition_type} but needs to be `trajectory`, " 66 | f"`step`, or `atom` for trajectory-wise, step-wise, or atom-wise " 67 | f"conditions, respectively." 68 | ) 69 | self.condition_name = condition_name 70 | self.condition_type = condition_type 71 | self.n_features = n_features 72 | self.required_data_properties = required_data_properties 73 | 74 | def forward( 75 | self, 76 | inputs: Dict[str, torch.Tensor], 77 | ) -> torch.Tensor: 78 | raise NotImplementedError 79 | 80 | 81 | # class ScalarConditionEmbedding(ConditionEmbedding): 82 | class ScalarConditionEmbedding(nn.Module): 83 | def __init__( 84 | self, 85 | condition_name: str, 86 | condition_min: float, 87 | condition_max: float, 88 | grid_spacing: float, 89 | n_features: int, 90 | n_layers: int, 91 | required_data_properties: Optional[List[str]] = [], 92 | condition_type: str = "trajectory", 93 | ): 94 | super(ScalarConditionEmbedding, self).__init__() 95 | # super().__init__( 96 | # condition_name, n_features, required_data_properties, condition_type 97 | # ) 98 | self.condition_name = condition_name 99 | self.condition_type = condition_type 100 | self.n_features = n_features 101 | self.required_data_properties = required_data_properties 102 | # compute the number of rbfs 103 | n_rbf = math.ceil((condition_max - condition_min) / grid_spacing) + 1 104 | # compute the position of the last rbf 105 | _max = condition_min + grid_spacing * (n_rbf - 1) 106 | # initialize Gaussian rbf expansion network 107 | self.gaussian_expansion = GaussianRBF( 108 | n_rbf=n_rbf, cutoff=_max, start=condition_min 109 | ) 110 | # initialize fully connected network 111 | self.dense_net = build_mlp( 112 | in_dim=n_rbf, 113 | hidden_dim=n_features, 114 | fc_num_layers = n_layers, 115 | out_dim=n_features, 116 | norm=False, 117 | ) 118 | 119 | def forward( 120 | self, 121 | inputs: Dict[str, torch.Tensor], 122 | ) -> torch.Tensor: 123 | # get the scalar condition value 124 | scalar_condition = torch.Tensor(inputs[self.condition_name]).float() 125 | # expand the scalar value with Gaussian rbfs 126 | 
expanded_condition = self.gaussian_expansion(scalar_condition) 127 | # feed through fully connected network 128 | embedded_condition = self.dense_net(expanded_condition) 129 | return embedded_condition 130 | 131 | 132 | class ClassConditionEmbedding(nn.Module): 133 | def __init__( 134 | self, 135 | condition_name: str, 136 | n_type: int, 137 | n_emb: int, 138 | n_features: int, 139 | n_layers: int, 140 | required_data_properties: Optional[List[str]] = [], 141 | condition_type: str = "trajectory", 142 | ): 143 | super(ClassConditionEmbedding, self).__init__() 144 | self.condition_name = condition_name 145 | self.n_type = n_type 146 | self.embedding_layer = nn.Embedding(n_type, n_emb) 147 | 148 | self.dense_net = build_mlp( 149 | in_dim=n_emb, 150 | hidden_dim=n_features, 151 | fc_num_layers=n_layers, 152 | out_dim=n_features, 153 | norm=False, 154 | ) 155 | 156 | def forward( 157 | self, 158 | inputs: Dict[str, torch.Tensor], 159 | ) -> torch.Tensor: 160 | emb_input = inputs[self.condition_name].int() 161 | emb_condition = self.embedding_layer(emb_input) 162 | embedded_condition = self.dense_net(emb_condition) 163 | return embedded_condition 164 | 165 | 166 | class VectorialConditionEmbedding(nn.Module): 167 | """ 168 | An embedding network for vectorial conditions (e.g. a fingerprint). The vector is 169 | mapped to the final embedding with a fully connected neural network. 170 | """ 171 | 172 | def __init__( 173 | self, 174 | condition_name: str, 175 | n_in: int, 176 | n_features: int, 177 | n_layers: int, 178 | required_data_properties: Optional[List[str]] = [], 179 | condition_type: str = "trajectory", 180 | ): 181 | 182 | super(VectorialConditionEmbedding, self).__init__() 183 | self.condition_name = condition_name 184 | # initialize fully connected network 185 | self.dense_net = build_mlp( 186 | in_dim=n_in, 187 | hidden_dim=n_features, 188 | fc_num_layers=n_layers, 189 | out_dim=n_features, 190 | norm=False, 191 | ) 192 | 193 | def forward( 194 | self, 195 | inputs: Dict[str, torch.Tensor], 196 | ) -> torch.Tensor: 197 | # get the vectorial condition value 198 | vectorial_condition = inputs[self.condition_name] 199 | # feed through fully connected network 200 | embedded_condition = self.dense_net(vectorial_condition) 201 | return embedded_condition 202 | 203 | 204 | class GaussianRBF(nn.Module): 205 | r"""Gaussian radial basis functions.""" 206 | 207 | def __init__( 208 | self, n_rbf: int, cutoff: float, start: float = 0.0, trainable: bool = False 209 | ): 210 | super(GaussianRBF, self).__init__() 211 | self.n_rbf = n_rbf 212 | 213 | # compute offset and width of Gaussian functions 214 | offset = torch.linspace(start, cutoff, n_rbf) 215 | widths = torch.FloatTensor( 216 | torch.abs(offset[1] - offset[0]) * torch.ones_like(offset) 217 | ) 218 | if trainable: 219 | self.widths = nn.Parameter(widths) 220 | self.offsets = nn.Parameter(offset) 221 | else: 222 | # self.register_buffer("widths", widths) 223 | # self.register_buffer("offsets", offset) 224 | self.widths = nn.Parameter(widths) 225 | self.offsets = nn.Parameter(offset) 226 | self.widths.requires_grad = False 227 | self.offsets.requires_grad = False 228 | 229 | def forward(self, inputs: torch.Tensor): 230 | return gaussian_rbf(inputs, self.offsets, self.widths) 231 | 232 | 233 | def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor): 234 | # print('input de:', inputs.device, file=sys.stdout) 235 | # print('offset de:', offsets.device, file=sys.stdout) 236 | # print('widths de:', widths.device, 
file=sys.stdout) 237 | coeff = -0.5 / torch.pow(widths, 2) 238 | diff = inputs[..., None] - offsets 239 | y = torch.exp(coeff * torch.pow(diff, 2)) 240 | return y.float() -------------------------------------------------------------------------------- /concdvae/pl_modules/PreCondition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from concdvae.pl_modules.model import build_mlp 4 | from torch.nn import functional as F 5 | 6 | class ScalarConditionPredict(nn.Module): 7 | def __init__( 8 | self, 9 | condition_name: str, 10 | condition_min: float, 11 | condition_max: float, 12 | latent_dim: int, 13 | hidden_dim: int, 14 | out_dim: int, 15 | n_layers: int, 16 | drop: float = -1, 17 | ): 18 | super(ScalarConditionPredict, self).__init__() 19 | self.condition_name = condition_name 20 | self.condition_min = condition_min 21 | self.condition_max = condition_max 22 | self.latent_dim = latent_dim 23 | self.hidden_dim = hidden_dim 24 | self.out_dim = out_dim 25 | self.n_layers = n_layers 26 | self.drop = drop 27 | 28 | 29 | self.mlp = build_mlp(in_dim=self.latent_dim, 30 | hidden_dim=self.hidden_dim, 31 | fc_num_layers=self.n_layers, 32 | out_dim=self.out_dim, 33 | drop=self.drop) 34 | 35 | 36 | def forward(self, inputs, z): 37 | predict = self.mlp(z) 38 | loss = self.property_loss(inputs, predict) 39 | return loss 40 | 41 | 42 | def property_loss(self, inputs, predict): 43 | true = torch.Tensor(inputs[self.condition_name]).float() 44 | true = (true - self.condition_min) / (self.condition_max - self.condition_min) 45 | true = true.view(true.size(0), 1) 46 | return F.mse_loss(predict, true) 47 | 48 | 49 | class ClassConditionPredict(nn.Module): 50 | def __init__( 51 | self, 52 | condition_name: str, 53 | n_type: int, 54 | latent_dim: int, 55 | hidden_dim: int, 56 | n_layers: int, 57 | drop: float = -1, 58 | ): 59 | super(ClassConditionPredict, self).__init__() 60 | self.condition_name = condition_name 61 | self.n_type = n_type 62 | self.latent_dim = latent_dim 63 | self.hidden_dim = hidden_dim 64 | self.n_layers = n_layers 65 | self.drop = drop 66 | 67 | if drop > 0 and drop < 1: 68 | list_sqe = [nn.Dropout(p=drop), nn.Linear(self.latent_dim, self.n_type)] 69 | else: 70 | list_sqe = [nn.Linear(self.latent_dim, self.n_type)] 71 | self.mlp = nn.Sequential(*list_sqe) 72 | 73 | self.criterion = nn.CrossEntropyLoss() 74 | 75 | def forward(self, inputs, z): 76 | predict = self.mlp(z) 77 | loss = self.property_loss(inputs, predict) 78 | return loss 79 | 80 | 81 | def property_loss(self, inputs, predict): 82 | true = torch.Tensor(inputs[self.condition_name]).long() 83 | return self.criterion(predict,true) 84 | -------------------------------------------------------------------------------- /concdvae/pl_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/__init__.py -------------------------------------------------------------------------------- /concdvae/pl_modules/__pycache__/ConditionModel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/__pycache__/ConditionModel.cpython-38.pyc -------------------------------------------------------------------------------- 
/concdvae/pl_modules/__pycache__/PreCondition.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/__pycache__/PreCondition.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/__pycache__/decoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/__pycache__/decoder.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/__pycache__/gnn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/__pycache__/gnn.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from concdvae.pl_modules.embeddings import MAX_ATOMIC_NUM
6 | from concdvae.pl_modules.gemnet.gemnet import GemNetT
7 | 
8 | 
9 | def build_mlp(in_dim, hidden_dim, fc_num_layers, out_dim):
10 |     mods = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
11 |     for i in range(fc_num_layers-1):
12 |         mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
13 |     mods += [nn.Linear(hidden_dim, out_dim)]
14 |     return nn.Sequential(*mods)
15 | 
16 | 
17 | class GemNetTDecoder(nn.Module):
18 |     """Decoder with GemNetT."""
19 | 
20 |     def __init__(
21 |         self,
22 |         hidden_dim=128,
23 |         latent_dim=256,
24 |         max_neighbors=20,
25 |         time_emb_dim=64,
26 |         radius=6.,
27 |         scale_file=None,
28 |     ):
29 |         super(GemNetTDecoder, self).__init__()
30 |         self.cutoff = radius
31 |         self.max_num_neighbors = max_neighbors
32 | 
33 |         self.gemnet = GemNetT(
34 |             num_targets=1,
35 |             latent_dim=latent_dim + time_emb_dim,  # z is concatenated with the diffusion-time embedding
36 |             emb_size_atom=hidden_dim,
37 |             emb_size_edge=hidden_dim,
38 |             regress_forces=True,
39 |             cutoff=self.cutoff,
40 |             max_neighbors=self.max_num_neighbors,
41 |             otf_graph=True,
42 |             scale_file=scale_file,
43 |         )
44 |         self.fc_atom = nn.Linear(hidden_dim, MAX_ATOMIC_NUM)
45 | 
46 |     def forward(self, z, pred_frac_coords, pred_atom_types, num_atoms,
47 |                 lengths, angles):
48 |         """
49 |         args:
50 |             z: (N_cryst, num_latent)
51 |             pred_frac_coords: (N_atoms, 3)
52 |             pred_atom_types: (N_atoms, ), need to use atomic number e.g.
H = 1 53 | num_atoms: (N_cryst,) 54 | lengths: (N_cryst, 3) 55 | angles: (N_cryst, 3) 56 | returns: 57 | atom_frac_coords: (N_atoms, 3) 58 | atom_types: (N_atoms, MAX_ATOMIC_NUM) 59 | """ 60 | # (num_atoms, hidden_dim) (num_crysts, 3) 61 | h, pred_cart_coord_diff = self.gemnet( 62 | z=z, 63 | frac_coords=pred_frac_coords, 64 | atom_types=pred_atom_types, 65 | num_atoms=num_atoms, 66 | lengths=lengths, 67 | angles=angles, 68 | edge_index=None, 69 | to_jimages=None, 70 | num_bonds=None, 71 | ) 72 | pred_atom_types = self.fc_atom(h) 73 | return pred_cart_coord_diff, pred_atom_types -------------------------------------------------------------------------------- /concdvae/pl_modules/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "ATOMIC_RADII", 3 | "KHOT_EMBEDDINGS", 4 | "CONTINUOUS_EMBEDDINGS", 5 | "MAX_ATOMIC_NUM", 6 | ] 7 | 8 | from .atomic_radii import ATOMIC_RADII 9 | from .continuous_embeddings import CONTINUOUS_EMBEDDINGS 10 | from .khot_embeddings import KHOT_EMBEDDINGS 11 | 12 | MAX_ATOMIC_NUM = 100 13 | -------------------------------------------------------------------------------- /concdvae/pl_modules/embeddings/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/embeddings/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/embeddings/__pycache__/atomic_radii.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/embeddings/__pycache__/atomic_radii.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/embeddings/__pycache__/continuous_embeddings.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/embeddings/__pycache__/continuous_embeddings.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/embeddings/__pycache__/khot_embeddings.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/embeddings/__pycache__/khot_embeddings.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/embeddings/atomic_radii.py: -------------------------------------------------------------------------------- 1 | """ 2 | Atomic radii in picometers 3 | 4 | NaN stored for unavailable parameters. 
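Keys are atomic numbers from 0 to 100 (entry 0 is an unused placeholder),
matching MAX_ATOMIC_NUM = 100 defined in this package's __init__.py.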
5 | """ 6 | ATOMIC_RADII = { 7 | 0: float("NaN"), 8 | 1: 25.0, 9 | 2: 120.0, 10 | 3: 145.0, 11 | 4: 105.0, 12 | 5: 85.0, 13 | 6: 70.0, 14 | 7: 65.0, 15 | 8: 60.0, 16 | 9: 50.0, 17 | 10: 160.0, 18 | 11: 180.0, 19 | 12: 150.0, 20 | 13: 125.0, 21 | 14: 110.0, 22 | 15: 100.0, 23 | 16: 100.0, 24 | 17: 100.0, 25 | 18: 71.0, 26 | 19: 220.0, 27 | 20: 180.0, 28 | 21: 160.0, 29 | 22: 140.0, 30 | 23: 135.0, 31 | 24: 140.0, 32 | 25: 140.0, 33 | 26: 140.0, 34 | 27: 135.0, 35 | 28: 135.0, 36 | 29: 135.0, 37 | 30: 135.0, 38 | 31: 130.0, 39 | 32: 125.0, 40 | 33: 115.0, 41 | 34: 115.0, 42 | 35: 115.0, 43 | 36: float("NaN"), 44 | 37: 235.0, 45 | 38: 200.0, 46 | 39: 180.0, 47 | 40: 155.0, 48 | 41: 145.0, 49 | 42: 145.0, 50 | 43: 135.0, 51 | 44: 130.0, 52 | 45: 135.0, 53 | 46: 140.0, 54 | 47: 160.0, 55 | 48: 155.0, 56 | 49: 155.0, 57 | 50: 145.0, 58 | 51: 145.0, 59 | 52: 140.0, 60 | 53: 140.0, 61 | 54: float("NaN"), 62 | 55: 260.0, 63 | 56: 215.0, 64 | 57: 195.0, 65 | 58: 185.0, 66 | 59: 185.0, 67 | 60: 185.0, 68 | 61: 185.0, 69 | 62: 185.0, 70 | 63: 185.0, 71 | 64: 180.0, 72 | 65: 175.0, 73 | 66: 175.0, 74 | 67: 175.0, 75 | 68: 175.0, 76 | 69: 175.0, 77 | 70: 175.0, 78 | 71: 175.0, 79 | 72: 155.0, 80 | 73: 145.0, 81 | 74: 135.0, 82 | 75: 135.0, 83 | 76: 130.0, 84 | 77: 135.0, 85 | 78: 135.0, 86 | 79: 135.0, 87 | 80: 150.0, 88 | 81: 190.0, 89 | 82: 180.0, 90 | 83: 160.0, 91 | 84: 190.0, 92 | 85: float("NaN"), 93 | 86: float("NaN"), 94 | 87: float("NaN"), 95 | 88: 215.0, 96 | 89: 195.0, 97 | 90: 180.0, 98 | 91: 180.0, 99 | 92: 175.0, 100 | 93: 175.0, 101 | 94: 175.0, 102 | 95: 175.0, 103 | 96: float("NaN"), 104 | 97: float("NaN"), 105 | 98: float("NaN"), 106 | 99: float("NaN"), 107 | 100: float("NaN"), 108 | } 109 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/__init__.py -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/__pycache__/gemnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/__pycache__/gemnet.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/__pycache__/initializers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/__pycache__/initializers.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/gemnet/fit_scaling.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | 
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | 
7 | ---
8 | 
9 | Script for calculating the scaling factors used to even out GemNet activation
10 | scales. This generates the `scale_file` specified in the config, which is then
11 | read in at model initialization.
12 | This only needs to be run if the hyperparameters or model change
13 | in places where it would affect the activation scales.
14 | """
15 | 
16 | import logging
17 | import os
18 | import sys
19 | 
20 | import torch
21 | from tqdm import trange
22 | 
23 | from ocpmodels.models.gemnet.layers.scaling import AutomaticFit
24 | from ocpmodels.models.gemnet.utils import write_json
25 | from ocpmodels.common.flags import flags
26 | from ocpmodels.common.registry import registry
27 | from ocpmodels.common.utils import build_config, setup_imports, setup_logging
28 | 
29 | if __name__ == "__main__":
30 |     setup_logging()
31 | 
32 |     num_batches = 16  # number of batches to use to fit a single variable
33 | 
34 |     parser = flags.get_parser()
35 |     args, override_args = parser.parse_known_args()
36 |     config = build_config(args, override_args)
37 |     assert config["model"]["name"].startswith("gemnet")
38 |     config["logger"] = "tensorboard"
39 | 
40 |     if args.distributed:
41 |         raise ValueError(
42 |             "I don't think this works with DDP (race conditions)."
43 |         )
44 | 
45 |     setup_imports()
46 | 
47 |     scale_file = config["model"]["scale_file"]
48 | 
49 |     logging.info(f"Run fitting for model: {args.identifier}")
50 |     logging.info(f"Target scale file: {scale_file}")
51 | 
52 |     def initialize_scale_file(scale_file):
53 |         # initialize file
54 |         preset = {"comment": args.identifier}
55 |         write_json(scale_file, preset)
56 | 
57 |     if os.path.exists(scale_file):
58 |         logging.warning(f"Already found existing file: {scale_file}")
59 |         flag = input(
60 |             "Do you want to continue and overwrite the file (1), "
61 |             "only fit the variables not fitted yet (2), or exit (3)? 
" 62 | ) 63 | if str(flag) == "1": 64 | logging.info("Overwriting the current file.") 65 | initialize_scale_file(scale_file) 66 | elif str(flag) == "2": 67 | logging.info("Only fitting unfitted variables.") 68 | else: 69 | print(flag) 70 | logging.info("Exiting script") 71 | sys.exit() 72 | else: 73 | initialize_scale_file(scale_file) 74 | 75 | AutomaticFit.set2fitmode() 76 | 77 | trainer = registry.get_trainer_class(config.get("trainer", "simple"))( 78 | task=config["task"], 79 | model=config["model"], 80 | dataset=config["dataset"], 81 | optimizer=config["optim"], 82 | identifier=config["identifier"], 83 | run_dir=config.get("run_dir", "./"), 84 | is_debug=config.get("is_debug", False), 85 | is_vis=config.get("is_vis", False), 86 | print_every=config.get("print_every", 10), 87 | seed=config.get("seed", 0), 88 | logger=config.get("logger", "tensorboard"), 89 | local_rank=config["local_rank"], 90 | amp=config.get("amp", False), 91 | cpu=config.get("cpu", False), 92 | slurm=config.get("slurm", {}), 93 | ) 94 | 95 | # Fitting loop 96 | logging.info("Start fitting") 97 | 98 | if not AutomaticFit.fitting_completed(): 99 | with torch.no_grad(): 100 | trainer.model.eval() 101 | for _ in trange(len(AutomaticFit.queue) + 1): 102 | assert ( 103 | trainer.val_loader is not None 104 | ), "Val dataset is required for making predictions" 105 | 106 | for i, batch in enumerate(trainer.val_loader): 107 | with torch.cuda.amp.autocast( 108 | enabled=trainer.scaler is not None 109 | ): 110 | out = trainer._forward(batch) 111 | loss = trainer._compute_loss(out, batch) 112 | del out, loss 113 | if i == num_batches: 114 | break 115 | 116 | current_var = AutomaticFit.activeVar 117 | if current_var is not None: 118 | current_var.fit() # fit current variable 119 | else: 120 | print("Found no variable to fit. Something went wrong!") 121 | 122 | assert AutomaticFit.fitting_completed() 123 | logging.info(f"Fitting done. Results saved to: {scale_file}") 124 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/gemnet-dT.json: -------------------------------------------------------------------------------- 1 | { 2 | "comment": "tri_gaussian128", 3 | "TripInteraction_1_had_rbf": 18.873615264892578, 4 | "TripInteraction_1_sum_cbf": 7.996850490570068, 5 | "AtomUpdate_1_sum": 1.220463752746582, 6 | "TripInteraction_2_had_rbf": 16.10817527770996, 7 | "TripInteraction_2_sum_cbf": 7.614634037017822, 8 | "AtomUpdate_2_sum": 0.9690994620323181, 9 | "TripInteraction_3_had_rbf": 15.01930046081543, 10 | "TripInteraction_3_sum_cbf": 7.025179862976074, 11 | "AtomUpdate_3_sum": 0.8903237581253052, 12 | "OutBlock_0_sum": 1.6437848806381226, 13 | "OutBlock_0_had": 16.161039352416992, 14 | "OutBlock_1_sum": 1.1077653169631958, 15 | "OutBlock_1_had": 13.54678726196289, 16 | "OutBlock_2_sum": 0.9477927684783936, 17 | "OutBlock_2_had": 12.754337310791016, 18 | "OutBlock_3_sum": 0.9059251546859741, 19 | "OutBlock_3_had": 13.484951972961426 20 | } 21 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/initializers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import torch 9 | 10 | 11 | def _standardize(kernel): 12 | """ 13 | Makes sure that N*Var(W) = 1 and E[W] = 0 14 | """ 15 | eps = 1e-6 16 | 17 | if len(kernel.shape) == 3: 18 | axis = [0, 1] # last dimension is output dimension 19 | else: 20 | axis = 1 21 | 22 | var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True) 23 | kernel = (kernel - mean) / (var + eps) ** 0.5 24 | return kernel 25 | 26 | 27 | def he_orthogonal_init(tensor): 28 | """ 29 | Generate a weight matrix with variance according to He (Kaiming) initialization. 30 | Based on a random (semi-)orthogonal matrix neural networks 31 | are expected to learn better when features are decorrelated 32 | (stated by eg. "Reducing overfitting in deep networks by decorrelating representations", 33 | "Dropout: a simple way to prevent neural networks from overfitting", 34 | "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks") 35 | """ 36 | tensor = torch.nn.init.orthogonal_(tensor) 37 | 38 | if len(tensor.shape) == 3: 39 | fan_in = tensor.shape[:-1].numel() 40 | else: 41 | fan_in = tensor.shape[1] 42 | 43 | with torch.no_grad(): 44 | tensor.data = _standardize(tensor.data) 45 | tensor.data *= (1 / fan_in) ** 0.5 46 | 47 | return tensor 48 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/__pycache__/atom_update_block.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/atom_update_block.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/__pycache__/base_layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/base_layers.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/__pycache__/basis_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/basis_utils.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/__pycache__/efficient.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/efficient.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/__pycache__/embedding_block.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/embedding_block.cpython-38.pyc -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/__pycache__/interaction_block.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/interaction_block.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/gemnet/layers/__pycache__/radial_basis.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/radial_basis.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/gemnet/layers/__pycache__/scaling.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/scaling.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/gemnet/layers/__pycache__/spherical_basis.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/concdvae/pl_modules/gemnet/layers/__pycache__/spherical_basis.cpython-38.pyc
--------------------------------------------------------------------------------
/concdvae/pl_modules/gemnet/layers/atom_update_block.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | 
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 | 
8 | import torch
9 | from torch_scatter import scatter
10 | 
11 | from ..initializers import he_orthogonal_init
12 | from .base_layers import Dense, ResidualLayer
13 | from .scaling import ScalingFactor
14 | 
15 | 
16 | class AtomUpdateBlock(torch.nn.Module):
17 |     """
18 |     Aggregate the edge message embeddings onto the atoms
19 | 
20 |     Parameters
21 |     ----------
22 |     emb_size_atom: int
23 |         Embedding size of the atoms.
24 |     emb_size_edge: int
25 |         Embedding size of the edges.
26 |     nHidden: int
27 |         Number of residual blocks.
28 |     activation: callable/str
29 |         Name of the activation function to use in the dense layers.
30 |     scale_file: str
31 |         Path to the json file containing the scaling factors.
32 |     """
33 | 
34 |     def __init__(
35 |         self,
36 |         emb_size_atom: int,
37 |         emb_size_edge: int,
38 |         emb_size_rbf: int,
39 |         nHidden: int,
40 |         activation=None,
41 |         scale_file=None,
42 |         name: str = "atom_update",
43 |     ):
44 |         super().__init__()
45 |         self.name = name
46 | 
47 |         self.dense_rbf = Dense(
48 |             emb_size_rbf, emb_size_edge, activation=None, bias=False
49 |         )
50 |         self.scale_sum = ScalingFactor(
51 |             scale_file=scale_file, name=name + "_sum"
52 |         )
53 | 
54 |         self.layers = self.get_mlp(
55 |             emb_size_edge, emb_size_atom, nHidden, activation
56 |         )
57 | 
58 |     def get_mlp(self, units_in, units, nHidden, activation):
59 |         dense1 = Dense(units_in, units, activation=activation, bias=False)
60 |         mlp = [dense1]
61 |         res = [
62 |             ResidualLayer(units, nLayers=2, activation=activation)
63 |             for i in range(nHidden)
64 |         ]
65 |         mlp += res
66 |         return torch.nn.ModuleList(mlp)
67 | 
68 |     def forward(self, h, m, rbf, id_j):
69 |         """
70 |         Returns
71 |         -------
72 |         h: torch.Tensor, shape=(nAtoms, emb_size_atom)
73 |             Atom embedding.
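            Computed by gating the edge messages m with the transformed rbf,
            scatter-summing them onto their target atoms (id_j), and refining
            the result with the residual MLP.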
74 |         """
75 |         nAtoms = h.shape[0]
76 | 
77 |         mlp_rbf = self.dense_rbf(rbf)  # (nEdges, emb_size_edge)
78 |         x = m * mlp_rbf
79 | 
80 |         x2 = scatter(x, id_j, dim=0, dim_size=nAtoms, reduce="sum")
81 |         # (nAtoms, emb_size_edge)
82 |         x = self.scale_sum(m, x2)
83 | 
84 |         for layer in self.layers:
85 |             x = layer(x)  # (nAtoms, emb_size_atom)
86 | 
87 |         return x
88 | 
89 | 
90 | class OutputBlock(AtomUpdateBlock):
91 |     """
92 |     Combines the atom update block and subsequent final dense layer.
93 | 
94 |     Parameters
95 |     ----------
96 |     emb_size_atom: int
97 |         Embedding size of the atoms.
98 |     emb_size_edge: int
99 |         Embedding size of the edges.
100 |     nHidden: int
101 |         Number of residual blocks.
102 |     num_targets: int
103 |         Number of targets.
104 |     activation: str
105 |         Name of the activation function to use in the dense layers except for the final dense layer.
106 |     direct_forces: bool
107 |         If true, directly predict forces without taking the gradient of the energy potential.
108 |     output_init: str
109 |         Kernel initializer of the final dense layer.
110 |     scale_file: str
111 |         Path to the json file containing the scaling factors.
112 |     """
113 | 
114 |     def __init__(
115 |         self,
116 |         emb_size_atom: int,
117 |         emb_size_edge: int,
118 |         emb_size_rbf: int,
119 |         nHidden: int,
120 |         num_targets: int,
121 |         activation=None,
122 |         direct_forces=True,
123 |         output_init="HeOrthogonal",
124 |         scale_file=None,
125 |         name: str = "output",
126 |         **kwargs,
127 |     ):
128 | 
129 |         super().__init__(
130 |             name=name,
131 |             emb_size_atom=emb_size_atom,
132 |             emb_size_edge=emb_size_edge,
133 |             emb_size_rbf=emb_size_rbf,
134 |             nHidden=nHidden,
135 |             activation=activation,
136 |             scale_file=scale_file,
137 |             **kwargs,
138 |         )
139 | 
140 |         assert isinstance(output_init, str)
141 |         self.output_init = output_init.lower()
142 |         self.direct_forces = direct_forces
143 | 
144 |         self.seq_energy = self.layers  # inherited from parent class
145 |         self.out_energy = Dense(
146 |             emb_size_atom, num_targets, bias=False, activation=None
147 |         )
148 | 
149 |         if self.direct_forces:
150 |             self.scale_rbf_F = ScalingFactor(
151 |                 scale_file=scale_file, name=name + "_had"
152 |             )
153 |             self.seq_forces = self.get_mlp(
154 |                 emb_size_edge, emb_size_edge, nHidden, activation
155 |             )
156 |             self.out_forces = Dense(
157 |                 emb_size_edge, num_targets, bias=False, activation=None
158 |             )
159 |             self.dense_rbf_F = Dense(
160 |                 emb_size_rbf, emb_size_edge, activation=None, bias=False
161 |             )
162 | 
163 |         self.reset_parameters()
164 | 
165 |     def reset_parameters(self):
166 |         if self.output_init == "heorthogonal":
167 |             self.out_energy.reset_parameters(he_orthogonal_init)
168 |             if self.direct_forces:
169 |                 self.out_forces.reset_parameters(he_orthogonal_init)
170 |         elif self.output_init == "zeros":
171 |             self.out_energy.reset_parameters(torch.nn.init.zeros_)
172 |             if self.direct_forces:
173 |                 self.out_forces.reset_parameters(torch.nn.init.zeros_)
174 |         else:
175 |             raise ValueError(f"Unknown output_init: {self.output_init}")
176 | 
177 |     def forward(self, h, m, rbf, id_j):
178 |         """
179 |         Returns
180 |         -------
181 |         (E, F): tuple
182 |         - E: torch.Tensor, shape=(nAtoms, num_targets)
183 |         - F: torch.Tensor, shape=(nEdges, num_targets)
184 |         Energy and force prediction
185 |         """
186 |         nAtoms = h.shape[0]
187 | 
188 |         # -------------------------------------- Energy Prediction -------------------------------------- #
189 |         rbf_emb_E = self.dense_rbf(rbf)  # (nEdges, emb_size_edge)
190 |         x = m * rbf_emb_E
191 | 
192 |         x_E = scatter(x, id_j, dim=0, dim_size=nAtoms, reduce="sum")
193
| # (nAtoms, emb_size_edge) 194 | x_E = self.scale_sum(m, x_E) 195 | 196 | for layer in self.seq_energy: 197 | x_E = layer(x_E) # (nAtoms, emb_size_atom) 198 | 199 | x_E = self.out_energy(x_E) # (nAtoms, num_targets) 200 | 201 | # --------------------------------------- Force Prediction -------------------------------------- # 202 | if self.direct_forces: 203 | x_F = m 204 | for i, layer in enumerate(self.seq_forces): 205 | x_F = layer(x_F) # (nEdges, emb_size_edge) 206 | 207 | rbf_emb_F = self.dense_rbf_F(rbf) # (nEdges, emb_size_edge) 208 | x_F_rbf = x_F * rbf_emb_F 209 | x_F = self.scale_rbf_F(x_F, x_F_rbf) 210 | 211 | x_F = self.out_forces(x_F) # (nEdges, num_targets) 212 | else: 213 | x_F = 0 214 | # ----------------------------------------------------------------------------------------------- # 215 | 216 | return x_E, x_F 217 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/base_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | 12 | from ..initializers import he_orthogonal_init 13 | 14 | 15 | class Dense(torch.nn.Module): 16 | """ 17 | Combines dense layer with scaling for swish activation. 18 | 19 | Parameters 20 | ---------- 21 | units: int 22 | Output embedding size. 23 | activation: str 24 | Name of the activation function to use. 25 | bias: bool 26 | True if use bias. 27 | """ 28 | 29 | def __init__(self, in_features, out_features, bias=False, activation=None): 30 | super().__init__() 31 | 32 | self.linear = torch.nn.Linear(in_features, out_features, bias=bias) 33 | self.reset_parameters() 34 | 35 | if isinstance(activation, str): 36 | activation = activation.lower() 37 | if activation in ["swish", "silu"]: 38 | self._activation = ScaledSiLU() 39 | elif activation == "siqu": 40 | self._activation = SiQU() 41 | elif activation is None: 42 | self._activation = torch.nn.Identity() 43 | else: 44 | raise NotImplementedError( 45 | "Activation function not implemented for GemNet (yet)." 46 | ) 47 | 48 | def reset_parameters(self, initializer=he_orthogonal_init): 49 | initializer(self.linear.weight) 50 | if self.linear.bias is not None: 51 | self.linear.bias.data.fill_(0) 52 | 53 | def forward(self, x): 54 | x = self.linear(x) 55 | x = self._activation(x) 56 | return x 57 | 58 | 59 | class ScaledSiLU(torch.nn.Module): 60 | def __init__(self): 61 | super().__init__() 62 | self.scale_factor = 1 / 0.6 63 | self._activation = torch.nn.SiLU() 64 | 65 | def forward(self, x): 66 | return self._activation(x) * self.scale_factor 67 | 68 | 69 | class SiQU(torch.nn.Module): 70 | def __init__(self): 71 | super().__init__() 72 | self._activation = torch.nn.SiLU() 73 | 74 | def forward(self, x): 75 | return x * self._activation(x) 76 | 77 | 78 | class ResidualLayer(torch.nn.Module): 79 | """ 80 | Residual block with output scaled by 1/sqrt(2). 81 | 82 | Parameters 83 | ---------- 84 | units: int 85 | Output embedding size. 86 | nLayers: int 87 | Number of dense layers. 88 | layer_kwargs: str 89 | Keyword arguments for initializing the layers. 
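        i.e. y = (x + MLP(x)) / sqrt(2), where MLP is the stack of nLayers
        dense layers built below.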
90 | """ 91 | 92 | def __init__( 93 | self, units: int, nLayers: int = 2, layer=Dense, **layer_kwargs 94 | ): 95 | super().__init__() 96 | self.dense_mlp = torch.nn.Sequential( 97 | *[ 98 | layer( 99 | in_features=units, 100 | out_features=units, 101 | bias=False, 102 | **layer_kwargs 103 | ) 104 | for _ in range(nLayers) 105 | ] 106 | ) 107 | self.inv_sqrt_2 = 1 / math.sqrt(2) 108 | 109 | def forward(self, input): 110 | x = self.dense_mlp(input) 111 | x = input + x 112 | x = x * self.inv_sqrt_2 113 | return x 114 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/basis_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import sympy as sym 10 | from scipy import special as sp 11 | from scipy.optimize import brentq 12 | 13 | 14 | def Jn(r, n): 15 | """ 16 | numerical spherical bessel functions of order n 17 | """ 18 | return sp.spherical_jn(n, r) 19 | 20 | 21 | def Jn_zeros(n, k): 22 | """ 23 | Compute the first k zeros of the spherical bessel functions up to order n (excluded) 24 | """ 25 | zerosj = np.zeros((n, k), dtype="float32") 26 | zerosj[0] = np.arange(1, k + 1) * np.pi 27 | points = np.arange(1, k + n) * np.pi 28 | racines = np.zeros(k + n - 1, dtype="float32") 29 | for i in range(1, n): 30 | for j in range(k + n - 1 - i): 31 | foo = brentq(Jn, points[j], points[j + 1], (i,)) 32 | racines[j] = foo 33 | points = racines 34 | zerosj[i][:k] = racines[:k] 35 | 36 | return zerosj 37 | 38 | 39 | def spherical_bessel_formulas(n): 40 | """ 41 | Computes the sympy formulas for the spherical bessel functions up to order n (excluded) 42 | """ 43 | x = sym.symbols("x") 44 | # j_i = (-x)^i * (1/x * d/dx)^î * sin(x)/x 45 | j = [sym.sin(x) / x] # j_0 46 | a = sym.sin(x) / x 47 | for i in range(1, n): 48 | b = sym.diff(a, x) / x 49 | j += [sym.simplify(b * (-x) ** i)] 50 | a = sym.simplify(b) 51 | return j 52 | 53 | 54 | def bessel_basis(n, k): 55 | """ 56 | Compute the sympy formulas for the normalized and rescaled spherical bessel functions up to 57 | order n (excluded) and maximum frequency k (excluded). 58 | 59 | Returns: 60 | bess_basis: list 61 | Bessel basis formulas taking in a single argument x. 62 | Has length n where each element has length k. -> In total n*k many. 63 | """ 64 | zeros = Jn_zeros(n, k) 65 | normalizer = [] 66 | for order in range(n): 67 | normalizer_tmp = [] 68 | for i in range(k): 69 | normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1) ** 2] 70 | normalizer_tmp = ( 71 | 1 / np.array(normalizer_tmp) ** 0.5 72 | ) # sqrt(2/(j_l+1)**2) , sqrt(1/c**3) not taken into account yet 73 | normalizer += [normalizer_tmp] 74 | 75 | f = spherical_bessel_formulas(n) 76 | x = sym.symbols("x") 77 | bess_basis = [] 78 | for order in range(n): 79 | bess_basis_tmp = [] 80 | for i in range(k): 81 | bess_basis_tmp += [ 82 | sym.simplify( 83 | normalizer[order][i] 84 | * f[order].subs(x, zeros[order, i] * x) 85 | ) 86 | ] 87 | bess_basis += [bess_basis_tmp] 88 | return bess_basis 89 | 90 | 91 | def sph_harm_prefactor(l_degree, m_order): 92 | """Computes the constant pre-factor for the spherical harmonic of degree l and order m. 93 | 94 | Parameters 95 | ---------- 96 | l_degree: int 97 | Degree of the spherical harmonic. 
l >= 0 98 | m_order: int 99 | Order of the spherical harmonic. -l <= m <= l 100 | 101 | Returns 102 | ------- 103 | factor: float 104 | 105 | """ 106 | # sqrt((2*l+1)/4*pi * (l-m)!/(l+m)! ) 107 | return ( 108 | (2 * l_degree + 1) 109 | / (4 * np.pi) 110 | * np.math.factorial(l_degree - abs(m_order)) 111 | / np.math.factorial(l_degree + abs(m_order)) 112 | ) ** 0.5 113 | 114 | 115 | def associated_legendre_polynomials( 116 | L_maxdegree, zero_m_only=True, pos_m_only=True 117 | ): 118 | """Computes string formulas of the associated legendre polynomials up to degree L (excluded). 119 | 120 | Parameters 121 | ---------- 122 | L_maxdegree: int 123 | Degree up to which to calculate the associated legendre polynomials (degree L is excluded). 124 | zero_m_only: bool 125 | If True only calculate the polynomials for the polynomials where m=0. 126 | pos_m_only: bool 127 | If True only calculate the polynomials for the polynomials where m>=0. Overwritten by zero_m_only. 128 | 129 | Returns 130 | ------- 131 | polynomials: list 132 | Contains the sympy functions of the polynomials (in total L many if zero_m_only is True else L^2 many). 133 | """ 134 | # calculations from http://web.cmb.usc.edu/people/alber/Software/tomominer/docs/cpp/group__legendre__polynomials.html 135 | z = sym.symbols("z") 136 | P_l_m = [ 137 | [0] * (2 * l_degree + 1) for l_degree in range(L_maxdegree) 138 | ] # for order l: -l <= m <= l 139 | 140 | P_l_m[0][0] = 1 141 | if L_maxdegree > 0: 142 | if zero_m_only: 143 | # m = 0 144 | P_l_m[1][0] = z 145 | for l_degree in range(2, L_maxdegree): 146 | P_l_m[l_degree][0] = sym.simplify( 147 | ( 148 | (2 * l_degree - 1) * z * P_l_m[l_degree - 1][0] 149 | - (l_degree - 1) * P_l_m[l_degree - 2][0] 150 | ) 151 | / l_degree 152 | ) 153 | return P_l_m 154 | else: 155 | # for m >= 0 156 | for l_degree in range(1, L_maxdegree): 157 | P_l_m[l_degree][l_degree] = sym.simplify( 158 | (1 - 2 * l_degree) 159 | * (1 - z ** 2) ** 0.5 160 | * P_l_m[l_degree - 1][l_degree - 1] 161 | ) # P_00, P_11, P_22, P_33 162 | 163 | for m_order in range(0, L_maxdegree - 1): 164 | P_l_m[m_order + 1][m_order] = sym.simplify( 165 | (2 * m_order + 1) * z * P_l_m[m_order][m_order] 166 | ) # P_10, P_21, P_32, P_43 167 | 168 | for l_degree in range(2, L_maxdegree): 169 | for m_order in range(l_degree - 1): # P_20, P_30, P_31 170 | P_l_m[l_degree][m_order] = sym.simplify( 171 | ( 172 | (2 * l_degree - 1) 173 | * z 174 | * P_l_m[l_degree - 1][m_order] 175 | - (l_degree + m_order - 1) 176 | * P_l_m[l_degree - 2][m_order] 177 | ) 178 | / (l_degree - m_order) 179 | ) 180 | 181 | if not pos_m_only: 182 | # for m < 0: P_l(-m) = (-1)^m * (l-m)!/(l+m)! * P_lm 183 | for l_degree in range(1, L_maxdegree): 184 | for m_order in range( 185 | 1, l_degree + 1 186 | ): # P_1(-1), P_2(-1) P_2(-2) 187 | P_l_m[l_degree][-m_order] = sym.simplify( 188 | (-1) ** m_order 189 | * np.math.factorial(l_degree - m_order) 190 | / np.math.factorial(l_degree + m_order) 191 | * P_l_m[l_degree][m_order] 192 | ) 193 | 194 | return P_l_m 195 | 196 | 197 | def real_sph_harm(L_maxdegree, use_theta, use_phi=True, zero_m_only=True): 198 | """ 199 | Computes formula strings of the the real part of the spherical harmonics up to degree L (excluded). 200 | Variables are either spherical coordinates phi and theta (or cartesian coordinates x,y,z) on the UNIT SPHERE. 201 | 202 | Parameters 203 | ---------- 204 | L_maxdegree: int 205 | Degree up to which to calculate the spherical harmonics (degree L is excluded). 
206 | use_theta: bool 207 | - True: Expects the input of the formula strings to contain theta. 208 | - False: Expects the input of the formula strings to contain z. 209 | use_phi: bool 210 | - True: Expects the input of the formula strings to contain phi. 211 | - False: Expects the input of the formula strings to contain x and y. 212 | Does nothing if zero_m_only is True 213 | zero_m_only: bool 214 | If True only calculate the harmonics where m=0. 215 | 216 | Returns 217 | ------- 218 | Y_lm_real: list 219 | Computes formula strings of the the real part of the spherical harmonics up 220 | to degree L (where degree L is not excluded). 221 | In total L^2 many sph harm exist up to degree L (excluded). However, if zero_m_only only is True then 222 | the total count is reduced to be only L many. 223 | """ 224 | z = sym.symbols("z") 225 | P_l_m = associated_legendre_polynomials(L_maxdegree, zero_m_only) 226 | if zero_m_only: 227 | # for all m != 0: Y_lm = 0 228 | Y_l_m = [[0] for l_degree in range(L_maxdegree)] 229 | else: 230 | Y_l_m = [ 231 | [0] * (2 * l_degree + 1) for l_degree in range(L_maxdegree) 232 | ] # for order l: -l <= m <= l 233 | 234 | # convert expressions to spherical coordiantes 235 | if use_theta: 236 | # replace z by cos(theta) 237 | theta = sym.symbols("theta") 238 | for l_degree in range(L_maxdegree): 239 | for m_order in range(len(P_l_m[l_degree])): 240 | if not isinstance(P_l_m[l_degree][m_order], int): 241 | P_l_m[l_degree][m_order] = P_l_m[l_degree][m_order].subs( 242 | z, sym.cos(theta) 243 | ) 244 | 245 | ## calculate Y_lm 246 | # Y_lm = N * P_lm(cos(theta)) * exp(i*m*phi) 247 | # { sqrt(2) * (-1)^m * N * P_l|m| * sin(|m|*phi) if m < 0 248 | # Y_lm_real = { Y_lm if m = 0 249 | # { sqrt(2) * (-1)^m * N * P_lm * cos(m*phi) if m > 0 250 | 251 | for l_degree in range(L_maxdegree): 252 | Y_l_m[l_degree][0] = sym.simplify( 253 | sph_harm_prefactor(l_degree, 0) * P_l_m[l_degree][0] 254 | ) # Y_l0 255 | 256 | if not zero_m_only: 257 | phi = sym.symbols("phi") 258 | for l_degree in range(1, L_maxdegree): 259 | # m > 0 260 | for m_order in range(1, l_degree + 1): 261 | Y_l_m[l_degree][m_order] = sym.simplify( 262 | 2 ** 0.5 263 | * (-1) ** m_order 264 | * sph_harm_prefactor(l_degree, m_order) 265 | * P_l_m[l_degree][m_order] 266 | * sym.cos(m_order * phi) 267 | ) 268 | # m < 0 269 | for m_order in range(1, l_degree + 1): 270 | Y_l_m[l_degree][-m_order] = sym.simplify( 271 | 2 ** 0.5 272 | * (-1) ** m_order 273 | * sph_harm_prefactor(l_degree, -m_order) 274 | * P_l_m[l_degree][m_order] 275 | * sym.sin(m_order * phi) 276 | ) 277 | 278 | # convert expressions to cartesian coordinates 279 | if not use_phi: 280 | # replace phi by atan2(y,x) 281 | x = sym.symbols("x") 282 | y = sym.symbols("y") 283 | for l_degree in range(L_maxdegree): 284 | for m_order in range(len(Y_l_m[l_degree])): 285 | Y_l_m[l_degree][m_order] = sym.simplify( 286 | Y_l_m[l_degree][m_order].subs(phi, sym.atan2(y, x)) 287 | ) 288 | return Y_l_m 289 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/efficient.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import torch 9 | 10 | from ..initializers import he_orthogonal_init 11 | 12 | 13 | class EfficientInteractionDownProjection(torch.nn.Module): 14 | """ 15 | Down projection in the efficient reformulation. 16 | 17 | Parameters 18 | ---------- 19 | emb_size_interm: int 20 | Intermediate embedding size (down-projection size). 21 | kernel_initializer: callable 22 | Initializer of the weight matrix. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | num_spherical: int, 28 | num_radial: int, 29 | emb_size_interm: int, 30 | ): 31 | super().__init__() 32 | 33 | self.num_spherical = num_spherical 34 | self.num_radial = num_radial 35 | self.emb_size_interm = emb_size_interm 36 | 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self): 40 | self.weight = torch.nn.Parameter( 41 | torch.empty( 42 | (self.num_spherical, self.num_radial, self.emb_size_interm) 43 | ), 44 | requires_grad=True, 45 | ) 46 | he_orthogonal_init(self.weight) 47 | 48 | def forward(self, rbf, sph, id_ca, id_ragged_idx): 49 | """ 50 | 51 | Arguments 52 | --------- 53 | rbf: torch.Tensor, shape=(1, nEdges, num_radial) 54 | sph: torch.Tensor, shape=(nEdges, Kmax, num_spherical) 55 | id_ca 56 | id_ragged_idx 57 | 58 | Returns 59 | ------- 60 | rbf_W1: torch.Tensor, shape=(nEdges, emb_size_interm, num_spherical) 61 | sph: torch.Tensor, shape=(nEdges, Kmax, num_spherical) 62 | Kmax = maximum number of neighbors of the edges 63 | """ 64 | num_edges = rbf.shape[1] 65 | 66 | # MatMul: mul + sum over num_radial 67 | rbf_W1 = torch.matmul(rbf, self.weight) 68 | # (num_spherical, nEdges , emb_size_interm) 69 | rbf_W1 = rbf_W1.permute(1, 2, 0) 70 | # (nEdges, emb_size_interm, num_spherical) 71 | 72 | # Zero padded dense matrix 73 | # maximum number of neighbors, catch empty id_ca with maximum 74 | if sph.shape[0] == 0: 75 | Kmax = 0 76 | else: 77 | Kmax = torch.max( 78 | torch.max(id_ragged_idx + 1), 79 | torch.tensor(0).to(id_ragged_idx.device), 80 | ) 81 | 82 | sph2 = sph.new_zeros(num_edges, Kmax, self.num_spherical) 83 | sph2[id_ca, id_ragged_idx] = sph 84 | 85 | sph2 = torch.transpose(sph2, 1, 2) 86 | # (nEdges, num_spherical/emb_size_interm, Kmax) 87 | 88 | return rbf_W1, sph2 89 | 90 | 91 | class EfficientInteractionBilinear(torch.nn.Module): 92 | """ 93 | Efficient reformulation of the bilinear layer and subsequent summation. 94 | 95 | Parameters 96 | ---------- 97 | units_out: int 98 | Embedding output size of the bilinear layer. 99 | kernel_initializer: callable 100 | Initializer of the weight matrix. 101 | """ 102 | 103 | def __init__( 104 | self, 105 | emb_size: int, 106 | emb_size_interm: int, 107 | units_out: int, 108 | ): 109 | super().__init__() 110 | self.emb_size = emb_size 111 | self.emb_size_interm = emb_size_interm 112 | self.units_out = units_out 113 | 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | self.weight = torch.nn.Parameter( 118 | torch.empty( 119 | (self.emb_size, self.emb_size_interm, self.units_out), 120 | requires_grad=True, 121 | ) 122 | ) 123 | he_orthogonal_init(self.weight) 124 | 125 | def forward( 126 | self, 127 | basis, 128 | m, 129 | id_reduce, 130 | id_ragged_idx, 131 | ): 132 | """ 133 | 134 | Arguments 135 | --------- 136 | basis 137 | m: quadruplets: m = m_db , triplets: m = m_ba 138 | id_reduce 139 | id_ragged_idx 140 | 141 | Returns 142 | ------- 143 | m_ca: torch.Tensor, shape=(nEdges, units_out) 144 | Edge embeddings. 
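            Evaluated as two batched matrix products (sph with the zero-padded
            neighbor embeddings, then rbf_W1 with that sum) followed by a
            contraction with self.weight, so the full bilinear tensor is never
            materialized.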
145 | """ 146 | # num_spherical is actually num_spherical**2 for quadruplets 147 | (rbf_W1, sph) = basis 148 | # (nEdges, emb_size_interm, num_spherical), (nEdges, num_spherical, Kmax) 149 | nEdges = rbf_W1.shape[0] 150 | 151 | # Create (zero-padded) dense matrix of the neighboring edge embeddings. 152 | Kmax = torch.max( 153 | torch.max(id_ragged_idx) + 1, 154 | torch.tensor(0).to(id_ragged_idx.device), 155 | ) 156 | # maximum number of neighbors, catch empty id_reduce_ji with maximum 157 | m2 = m.new_zeros(nEdges, Kmax, self.emb_size) 158 | m2[id_reduce, id_ragged_idx] = m 159 | # (num_quadruplets or num_triplets, emb_size) -> (nEdges, Kmax, emb_size) 160 | 161 | sum_k = torch.matmul(sph, m2) # (nEdges, num_spherical, emb_size) 162 | 163 | # MatMul: mul + sum over num_spherical 164 | rbf_W1_sum_k = torch.matmul(rbf_W1, sum_k) 165 | # (nEdges, emb_size_interm, emb_size) 166 | 167 | # Bilinear: Sum over emb_size_interm and emb_size 168 | m_ca = torch.matmul(rbf_W1_sum_k.permute(2, 0, 1), self.weight) 169 | # (emb_size, nEdges, units_out) 170 | m_ca = torch.sum(m_ca, dim=0) 171 | # (nEdges, units_out) 172 | 173 | return m_ca 174 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/embedding_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from .base_layers import Dense 12 | 13 | from concdvae.pl_modules.embeddings import MAX_ATOMIC_NUM 14 | 15 | 16 | class AtomEmbedding(torch.nn.Module): 17 | """ 18 | Initial atom embeddings based on the atom type 19 | 20 | Parameters 21 | ---------- 22 | emb_size: int 23 | Atom embeddings size 24 | """ 25 | 26 | def __init__(self, emb_size): 27 | super().__init__() 28 | self.emb_size = emb_size 29 | 30 | # Atom embeddings: We go up to Bi (83). 31 | self.embeddings = torch.nn.Embedding(MAX_ATOMIC_NUM, emb_size) 32 | # init by uniform distribution 33 | torch.nn.init.uniform_( 34 | self.embeddings.weight, a=-np.sqrt(3), b=np.sqrt(3) 35 | ) 36 | 37 | def forward(self, Z): 38 | """ 39 | Returns 40 | ------- 41 | h: torch.Tensor, shape=(nAtoms, emb_size) 42 | Atom embeddings. 43 | """ 44 | h = self.embeddings(Z - 1) # -1 because Z.min()=1 (==Hydrogen) 45 | return h 46 | 47 | 48 | class EdgeEmbedding(torch.nn.Module): 49 | """ 50 | Edge embedding based on the concatenation of atom embeddings and subsequent dense layer. 51 | 52 | Parameters 53 | ---------- 54 | emb_size: int 55 | Embedding size after the dense layer. 56 | activation: str 57 | Activation function used in the dense layer. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | atom_features, 63 | edge_features, 64 | out_features, 65 | activation=None, 66 | ): 67 | super().__init__() 68 | in_features = 2 * atom_features + edge_features 69 | self.dense = Dense( 70 | in_features, out_features, activation=activation, bias=False 71 | ) 72 | 73 | def forward( 74 | self, 75 | h, 76 | m_rbf, 77 | idx_s, 78 | idx_t, 79 | ): 80 | """ 81 | 82 | Arguments 83 | --------- 84 | h 85 | m_rbf: shape (nEdges, nFeatures) 86 | in embedding block: m_rbf = rbf ; In interaction block: m_rbf = m_st 87 | idx_s 88 | idx_t 89 | 90 | Returns 91 | ------- 92 | m_st: torch.Tensor, shape=(nEdges, emb_size) 93 | Edge embeddings. 
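            Formed by concatenating the source and target atom embeddings with
            m_rbf and applying the bias-free dense layer.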
94 | """ 95 | h_s = h[idx_s] # shape=(nEdges, emb_size) 96 | h_t = h[idx_t] # shape=(nEdges, emb_size) 97 | 98 | m_st = torch.cat( 99 | [h_s, h_t, m_rbf], dim=-1 100 | ) # (nEdges, 2*emb_size+nFeatures) 101 | m_st = self.dense(m_st) # (nEdges, emb_size) 102 | return m_st 103 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/interaction_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | 12 | from .atom_update_block import AtomUpdateBlock 13 | from .base_layers import Dense, ResidualLayer 14 | from .efficient import ( 15 | EfficientInteractionBilinear, 16 | ) 17 | from .embedding_block import EdgeEmbedding 18 | from .scaling import ScalingFactor 19 | 20 | 21 | class InteractionBlockTripletsOnly(torch.nn.Module): 22 | """ 23 | Interaction block for GemNet-T/dT. 24 | 25 | Parameters 26 | ---------- 27 | emb_size_atom: int 28 | Embedding size of the atoms. 29 | emb_size_edge: int 30 | Embedding size of the edges. 31 | emb_size_trip: int 32 | (Down-projected) Embedding size in the triplet message passing block. 33 | emb_size_rbf: int 34 | Embedding size of the radial basis transformation. 35 | emb_size_cbf: int 36 | Embedding size of the circular basis transformation (one angle). 37 | 38 | emb_size_bil_trip: int 39 | Embedding size of the edge embeddings in the triplet-based message passing block after the bilinear layer. 40 | num_before_skip: int 41 | Number of residual blocks before the first skip connection. 42 | num_after_skip: int 43 | Number of residual blocks after the first skip connection. 44 | num_concat: int 45 | Number of residual blocks after the concatenation. 46 | num_atom: int 47 | Number of residual blocks in the atom embedding blocks. 48 | 49 | activation: str 50 | Name of the activation function to use in the dense layers except for the final dense layer. 51 | scale_file: str 52 | Path to the json file containing the scaling factors. 
53 | """ 54 | 55 | def __init__( 56 | self, 57 | emb_size_atom, 58 | emb_size_edge, 59 | emb_size_trip, 60 | emb_size_rbf, 61 | emb_size_cbf, 62 | emb_size_bil_trip, 63 | num_before_skip, 64 | num_after_skip, 65 | num_concat, 66 | num_atom, 67 | activation=None, 68 | scale_file=None, 69 | name="Interaction", 70 | ): 71 | super().__init__() 72 | self.name = name 73 | 74 | block_nr = name.split("_")[-1] 75 | 76 | ## -------------------------------------------- Message Passing ------------------------------------------- ## 77 | # Dense transformation of skip connection 78 | self.dense_ca = Dense( 79 | emb_size_edge, 80 | emb_size_edge, 81 | activation=activation, 82 | bias=False, 83 | ) 84 | 85 | # Triplet Interaction 86 | self.trip_interaction = TripletInteraction( 87 | emb_size_edge=emb_size_edge, 88 | emb_size_trip=emb_size_trip, 89 | emb_size_bilinear=emb_size_bil_trip, 90 | emb_size_rbf=emb_size_rbf, 91 | emb_size_cbf=emb_size_cbf, 92 | activation=activation, 93 | scale_file=scale_file, 94 | name=f"TripInteraction_{block_nr}", 95 | ) 96 | 97 | ## ---------------------------------------- Update Edge Embeddings ---------------------------------------- ## 98 | # Residual layers before skip connection 99 | self.layers_before_skip = torch.nn.ModuleList( 100 | [ 101 | ResidualLayer( 102 | emb_size_edge, 103 | activation=activation, 104 | ) 105 | for i in range(num_before_skip) 106 | ] 107 | ) 108 | 109 | # Residual layers after skip connection 110 | self.layers_after_skip = torch.nn.ModuleList( 111 | [ 112 | ResidualLayer( 113 | emb_size_edge, 114 | activation=activation, 115 | ) 116 | for i in range(num_after_skip) 117 | ] 118 | ) 119 | 120 | ## ---------------------------------------- Update Atom Embeddings ---------------------------------------- ## 121 | self.atom_update = AtomUpdateBlock( 122 | emb_size_atom=emb_size_atom, 123 | emb_size_edge=emb_size_edge, 124 | emb_size_rbf=emb_size_rbf, 125 | nHidden=num_atom, 126 | activation=activation, 127 | scale_file=scale_file, 128 | name=f"AtomUpdate_{block_nr}", 129 | ) 130 | 131 | ## ------------------------------ Update Edge Embeddings with Atom Embeddings ----------------------------- ## 132 | self.concat_layer = EdgeEmbedding( 133 | emb_size_atom, 134 | emb_size_edge, 135 | emb_size_edge, 136 | activation=activation, 137 | ) 138 | self.residual_m = torch.nn.ModuleList( 139 | [ 140 | ResidualLayer(emb_size_edge, activation=activation) 141 | for _ in range(num_concat) 142 | ] 143 | ) 144 | 145 | self.inv_sqrt_2 = 1 / math.sqrt(2.0) 146 | 147 | def forward( 148 | self, 149 | h, 150 | m, 151 | rbf3, 152 | cbf3, 153 | id3_ragged_idx, 154 | id_swap, 155 | id3_ba, 156 | id3_ca, 157 | rbf_h, 158 | idx_s, 159 | idx_t, 160 | ): 161 | """ 162 | Returns 163 | ------- 164 | h: torch.Tensor, shape=(nEdges, emb_size_atom) 165 | Atom embeddings. 166 | m: torch.Tensor, shape=(nEdges, emb_size_edge) 167 | Edge embeddings (c->a). 
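            Edge embeddings are updated by the triplet interaction and the
            residual stacks, atom embeddings by the atom-update block; every
            skip connection is rescaled by 1/sqrt(2).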
168 | """ 169 | 170 | # Initial transformation 171 | x_ca_skip = self.dense_ca(m) # (nEdges, emb_size_edge) 172 | 173 | x3 = self.trip_interaction( 174 | m, 175 | rbf3, 176 | cbf3, 177 | id3_ragged_idx, 178 | id_swap, 179 | id3_ba, 180 | id3_ca, 181 | ) 182 | 183 | ## ----------------------------- Merge Embeddings after Triplet Interaction ------------------------------ ## 184 | x = x_ca_skip + x3 # (nEdges, emb_size_edge) 185 | x = x * self.inv_sqrt_2 186 | 187 | ## ---------------------------------------- Update Edge Embeddings --------------------------------------- ## 188 | # Transformations before skip connection 189 | for i, layer in enumerate(self.layers_before_skip): 190 | x = layer(x) # (nEdges, emb_size_edge) 191 | 192 | # Skip connection 193 | m = m + x # (nEdges, emb_size_edge) 194 | m = m * self.inv_sqrt_2 195 | 196 | # Transformations after skip connection 197 | for i, layer in enumerate(self.layers_after_skip): 198 | m = layer(m) # (nEdges, emb_size_edge) 199 | 200 | ## ---------------------------------------- Update Atom Embeddings --------------------------------------- ## 201 | h2 = self.atom_update(h, m, rbf_h, idx_t) 202 | 203 | # Skip connection 204 | h = h + h2 # (nAtoms, emb_size_atom) 205 | h = h * self.inv_sqrt_2 206 | 207 | ## ----------------------------- Update Edge Embeddings with Atom Embeddings ----------------------------- ## 208 | m2 = self.concat_layer(h, m, idx_s, idx_t) # (nEdges, emb_size_edge) 209 | 210 | for i, layer in enumerate(self.residual_m): 211 | m2 = layer(m2) # (nEdges, emb_size_edge) 212 | 213 | # Skip connection 214 | m = m + m2 # (nEdges, emb_size_edge) 215 | m = m * self.inv_sqrt_2 216 | return h, m 217 | 218 | 219 | class TripletInteraction(torch.nn.Module): 220 | """ 221 | Triplet-based message passing block. 222 | 223 | Parameters 224 | ---------- 225 | emb_size_edge: int 226 | Embedding size of the edges. 227 | emb_size_trip: int 228 | (Down-projected) Embedding size of the edge embeddings after the hadamard product with rbf. 229 | emb_size_bilinear: int 230 | Embedding size of the edge embeddings after the bilinear layer. 231 | emb_size_rbf: int 232 | Embedding size of the radial basis transformation. 233 | emb_size_cbf: int 234 | Embedding size of the circular basis transformation (one angle). 235 | 236 | activation: str 237 | Name of the activation function to use in the dense layers except for the final dense layer. 238 | scale_file: str 239 | Path to the json file containing the scaling factors. 
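
    The block gates the edge messages with the radial basis, down-projects
    them, mixes neighboring edges through the efficient circular-basis
    bilinear layer, and up-projects the result for both edge directions
    (c->a and a->c).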
240 | """ 241 | 242 | def __init__( 243 | self, 244 | emb_size_edge, 245 | emb_size_trip, 246 | emb_size_bilinear, 247 | emb_size_rbf, 248 | emb_size_cbf, 249 | activation=None, 250 | scale_file=None, 251 | name="TripletInteraction", 252 | **kwargs, 253 | ): 254 | super().__init__() 255 | self.name = name 256 | 257 | # Dense transformation 258 | self.dense_ba = Dense( 259 | emb_size_edge, 260 | emb_size_edge, 261 | activation=activation, 262 | bias=False, 263 | ) 264 | 265 | # Up projections of basis representations, bilinear layer and scaling factors 266 | self.mlp_rbf = Dense( 267 | emb_size_rbf, 268 | emb_size_edge, 269 | activation=None, 270 | bias=False, 271 | ) 272 | self.scale_rbf = ScalingFactor( 273 | scale_file=scale_file, name=name + "_had_rbf" 274 | ) 275 | 276 | self.mlp_cbf = EfficientInteractionBilinear( 277 | emb_size_trip, emb_size_cbf, emb_size_bilinear 278 | ) 279 | self.scale_cbf_sum = ScalingFactor( 280 | scale_file=scale_file, name=name + "_sum_cbf" 281 | ) # combines scaling for bilinear layer and summation 282 | 283 | # Down and up projections 284 | self.down_projection = Dense( 285 | emb_size_edge, 286 | emb_size_trip, 287 | activation=activation, 288 | bias=False, 289 | ) 290 | self.up_projection_ca = Dense( 291 | emb_size_bilinear, 292 | emb_size_edge, 293 | activation=activation, 294 | bias=False, 295 | ) 296 | self.up_projection_ac = Dense( 297 | emb_size_bilinear, 298 | emb_size_edge, 299 | activation=activation, 300 | bias=False, 301 | ) 302 | 303 | self.inv_sqrt_2 = 1 / math.sqrt(2.0) 304 | 305 | def forward( 306 | self, 307 | m, 308 | rbf3, 309 | cbf3, 310 | id3_ragged_idx, 311 | id_swap, 312 | id3_ba, 313 | id3_ca, 314 | ): 315 | """ 316 | Returns 317 | ------- 318 | m: torch.Tensor, shape=(nEdges, emb_size_edge) 319 | Edge embeddings (c->a). 320 | """ 321 | 322 | # Dense transformation 323 | x_ba = self.dense_ba(m) # (nEdges, emb_size_edge) 324 | 325 | # Transform via radial bessel basis 326 | rbf_emb = self.mlp_rbf(rbf3) # (nEdges, emb_size_edge) 327 | x_ba2 = x_ba * rbf_emb 328 | x_ba = self.scale_rbf(x_ba, x_ba2) 329 | 330 | x_ba = self.down_projection(x_ba) # (nEdges, emb_size_trip) 331 | 332 | # Transform via circular spherical basis 333 | x_ba = x_ba[id3_ba] 334 | 335 | # Efficient bilinear layer 336 | x = self.mlp_cbf(cbf3, x_ba, id3_ca, id3_ragged_idx) 337 | # (nEdges, emb_size_quad) 338 | x = self.scale_cbf_sum(x_ba, x) 339 | 340 | # => 341 | # rbf(d_ba) 342 | # cbf(d_ca, angle_cab) 343 | 344 | # Up project embeddings 345 | x_ca = self.up_projection_ca(x) # (nEdges, emb_size_edge) 346 | x_ac = self.up_projection_ac(x) # (nEdges, emb_size_edge) 347 | 348 | # Merge interaction of c->a and a->c 349 | x_ac = x_ac[id_swap] # swap to add to edge a->c and not c->a 350 | x3 = x_ca + x_ac 351 | x3 = x3 * self.inv_sqrt_2 352 | return x3 353 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/radial_basis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | import numpy as np 11 | import torch 12 | from scipy.special import binom 13 | from torch_geometric.nn.models.schnet import GaussianSmearing 14 | 15 | 16 | class PolynomialEnvelope(torch.nn.Module): 17 | """ 18 | Polynomial envelope function that ensures a smooth cutoff. 
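    For scaled distance d < 1 the envelope is

        u(d) = 1 + a*d^p + b*d^(p+1) + c*d^(p+2),

    with a = -(p+1)(p+2)/2, b = p*(p+2), c = -p*(p+1)/2, and u(d) = 0 for
    d >= 1, so that u, u', and u'' all vanish at the cutoff.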
19 | 20 | Parameters 21 | ---------- 22 | exponent: int 23 | Exponent of the envelope function. 24 | """ 25 | 26 | def __init__(self, exponent): 27 | super().__init__() 28 | assert exponent > 0 29 | self.p = exponent 30 | self.a = -(self.p + 1) * (self.p + 2) / 2 31 | self.b = self.p * (self.p + 2) 32 | self.c = -self.p * (self.p + 1) / 2 33 | 34 | def forward(self, d_scaled): 35 | env_val = ( 36 | 1 37 | + self.a * d_scaled ** self.p 38 | + self.b * d_scaled ** (self.p + 1) 39 | + self.c * d_scaled ** (self.p + 2) 40 | ) 41 | return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) 42 | 43 | 44 | class ExponentialEnvelope(torch.nn.Module): 45 | """ 46 | Exponential envelope function that ensures a smooth cutoff, 47 | as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. 48 | SpookyNet: Learning Force Fields with Electronic Degrees of Freedom 49 | and Nonlocal Effects 50 | """ 51 | 52 | def __init__(self): 53 | super().__init__() 54 | 55 | def forward(self, d_scaled): 56 | env_val = torch.exp( 57 | -(d_scaled ** 2) / ((1 - d_scaled) * (1 + d_scaled)) 58 | ) 59 | return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) 60 | 61 | 62 | class SphericalBesselBasis(torch.nn.Module): 63 | """ 64 | 1D spherical Bessel basis 65 | 66 | Parameters 67 | ---------- 68 | num_radial: int 69 | Controls maximum frequency. 70 | cutoff: float 71 | Cutoff distance in Angstrom. 72 | """ 73 | 74 | def __init__( 75 | self, 76 | num_radial: int, 77 | cutoff: float, 78 | ): 79 | super().__init__() 80 | self.norm_const = math.sqrt(2 / (cutoff ** 3)) 81 | # cutoff ** 3 to counteract dividing by d_scaled = d / cutoff 82 | 83 | # Initialize frequencies at canonical positions 84 | self.frequencies = torch.nn.Parameter( 85 | data=torch.tensor( 86 | np.pi * np.arange(1, num_radial + 1, dtype=np.float32) 87 | ), 88 | requires_grad=True, 89 | ) 90 | 91 | def forward(self, d_scaled): 92 | return ( 93 | self.norm_const 94 | / d_scaled[:, None] 95 | * torch.sin(self.frequencies * d_scaled[:, None]) 96 | ) # (num_edges, num_radial) 97 | 98 | 99 | class BernsteinBasis(torch.nn.Module): 100 | """ 101 | Bernstein polynomial basis, 102 | as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. 103 | SpookyNet: Learning Force Fields with Electronic Degrees of Freedom 104 | and Nonlocal Effects 105 | 106 | Parameters 107 | ---------- 108 | num_radial: int 109 | Controls maximum frequency. 110 | pregamma_initial: float 111 | Initial value of exponential coefficient gamma. 
112 | Default: gamma = 0.5 * a_0**-1 = 0.94486, 113 | inverse softplus -> pregamma = log e**gamma - 1 = 0.45264 114 | """ 115 | 116 | def __init__( 117 | self, 118 | num_radial: int, 119 | pregamma_initial: float = 0.45264, 120 | ): 121 | super().__init__() 122 | prefactor = binom(num_radial - 1, np.arange(num_radial)) 123 | self.register_buffer( 124 | "prefactor", 125 | torch.tensor(prefactor, dtype=torch.float), 126 | persistent=False, 127 | ) 128 | 129 | self.pregamma = torch.nn.Parameter( 130 | data=torch.tensor(pregamma_initial, dtype=torch.float), 131 | requires_grad=True, 132 | ) 133 | self.softplus = torch.nn.Softplus() 134 | 135 | exp1 = torch.arange(num_radial) 136 | self.register_buffer("exp1", exp1[None, :], persistent=False) 137 | exp2 = num_radial - 1 - exp1 138 | self.register_buffer("exp2", exp2[None, :], persistent=False) 139 | 140 | def forward(self, d_scaled): 141 | gamma = self.softplus(self.pregamma) # constrain to positive 142 | exp_d = torch.exp(-gamma * d_scaled)[:, None] 143 | return ( 144 | self.prefactor * (exp_d ** self.exp1) * ((1 - exp_d) ** self.exp2) 145 | ) 146 | 147 | 148 | class RadialBasis(torch.nn.Module): 149 | """ 150 | 151 | Parameters 152 | ---------- 153 | num_radial: int 154 | Controls maximum frequency. 155 | cutoff: float 156 | Cutoff distance in Angstrom. 157 | rbf: dict = {"name": "gaussian"} 158 | Basis function and its hyperparameters. 159 | envelope: dict = {"name": "polynomial", "exponent": 5} 160 | Envelope function and its hyperparameters. 161 | """ 162 | 163 | def __init__( 164 | self, 165 | num_radial: int, 166 | cutoff: float, 167 | rbf: dict = {"name": "gaussian"}, 168 | envelope: dict = {"name": "polynomial", "exponent": 5}, 169 | ): 170 | super().__init__() 171 | self.inv_cutoff = 1 / cutoff 172 | 173 | env_name = envelope["name"].lower() 174 | env_hparams = envelope.copy() 175 | del env_hparams["name"] 176 | 177 | if env_name == "polynomial": 178 | self.envelope = PolynomialEnvelope(**env_hparams) 179 | elif env_name == "exponential": 180 | self.envelope = ExponentialEnvelope(**env_hparams) 181 | else: 182 | raise ValueError(f"Unknown envelope function '{env_name}'.") 183 | 184 | rbf_name = rbf["name"].lower() 185 | rbf_hparams = rbf.copy() 186 | del rbf_hparams["name"] 187 | 188 | # RBFs get distances scaled to be in [0, 1] 189 | if rbf_name == "gaussian": 190 | self.rbf = GaussianSmearing( 191 | start=0, stop=1, num_gaussians=num_radial, **rbf_hparams 192 | ) 193 | elif rbf_name == "spherical_bessel": 194 | self.rbf = SphericalBesselBasis( 195 | num_radial=num_radial, cutoff=cutoff, **rbf_hparams 196 | ) 197 | elif rbf_name == "bernstein": 198 | self.rbf = BernsteinBasis(num_radial=num_radial, **rbf_hparams) 199 | else: 200 | raise ValueError(f"Unknown radial basis function '{rbf_name}'.") 201 | 202 | def forward(self, d): 203 | d_scaled = d * self.inv_cutoff 204 | 205 | env = self.envelope(d_scaled) 206 | return env[:, None] * self.rbf(d_scaled) # (nEdges, num_radial) 207 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/scaling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import logging 9 | 10 | import torch 11 | 12 | from ..utils import read_value_json, update_json 13 | 14 | 15 | class AutomaticFit: 16 | """ 17 | All added variables are processed in the order of creation. 18 | """ 19 | 20 | activeVar = None 21 | queue = None 22 | fitting_mode = False 23 | 24 | def __init__(self, variable, scale_file, name): 25 | self.variable = variable # variable to find value for 26 | self.scale_file = scale_file 27 | self._name = name 28 | 29 | self._fitted = False 30 | self.load_maybe() 31 | 32 | # first instance created 33 | if AutomaticFit.fitting_mode and not self._fitted: 34 | 35 | # if first layer set to active 36 | if AutomaticFit.activeVar is None: 37 | AutomaticFit.activeVar = self 38 | AutomaticFit.queue = [] # initialize 39 | # else add to queue 40 | else: 41 | self._add2queue() # adding variables to list fill fail in graph mode 42 | 43 | def reset(): 44 | AutomaticFit.activeVar = None 45 | AutomaticFit.all_processed = False 46 | 47 | def fitting_completed(): 48 | return AutomaticFit.queue is None 49 | 50 | def set2fitmode(): 51 | AutomaticFit.reset() 52 | AutomaticFit.fitting_mode = True 53 | 54 | def _add2queue(self): 55 | logging.debug(f"Add {self._name} to queue.") 56 | # check that same variable is not added twice 57 | for var in AutomaticFit.queue: 58 | if self._name == var._name: 59 | raise ValueError( 60 | f"Variable with the same name ({self._name}) was already added to queue!" 61 | ) 62 | AutomaticFit.queue += [self] 63 | 64 | def set_next_active(self): 65 | """ 66 | Set the next variable in the queue that should be fitted. 67 | """ 68 | queue = AutomaticFit.queue 69 | if len(queue) == 0: 70 | logging.debug("Processed all variables.") 71 | AutomaticFit.queue = None 72 | AutomaticFit.activeVar = None # reset to None 73 | return 74 | AutomaticFit.activeVar = queue.pop(0) 75 | 76 | def load_maybe(self): 77 | """ 78 | Load variable from file or set to initial value of the variable. 79 | """ 80 | value = read_value_json(self.scale_file, self._name) 81 | if value is None: 82 | logging.debug( 83 | f"Initialize variable {self._name}' to {self.variable.numpy():.3f}" 84 | ) 85 | else: 86 | self._fitted = True 87 | logging.debug(f"Set scale factor {self._name} : {value}") 88 | with torch.no_grad(): 89 | self.variable.copy_(torch.tensor(value)) 90 | 91 | 92 | class AutoScaleFit(AutomaticFit): 93 | """ 94 | Class to automatically fit the scaling factors depending on the observed variances. 95 | 96 | Parameters 97 | ---------- 98 | variable: torch.Tensor 99 | Variable to fit. 100 | scale_file: str 101 | Path to the json file where to store/load from the scaling factors. 102 | """ 103 | 104 | def __init__(self, variable, scale_file, name): 105 | super().__init__(variable, scale_file, name) 106 | 107 | if not self._fitted: 108 | self._init_stats() 109 | 110 | def _init_stats(self): 111 | self.variance_in = 0 112 | self.variance_out = 0 113 | self.nSamples = 0 114 | 115 | @torch.no_grad() 116 | def observe(self, x, y): 117 | """ 118 | Observe variances for input x and output y. 119 | The scaling factor alpha is calculated s.t. 
Var(alpha * y) ~ Var(x) 120 | """ 121 | if self._fitted: 122 | return 123 | 124 | # only track stats for current variable 125 | if AutomaticFit.activeVar == self: 126 | nSamples = y.shape[0] 127 | self.variance_in += ( 128 | torch.mean(torch.var(x, dim=0)).to(dtype=torch.float32) 129 | * nSamples 130 | ) 131 | self.variance_out += ( 132 | torch.mean(torch.var(y, dim=0)).to(dtype=torch.float32) 133 | * nSamples 134 | ) 135 | self.nSamples += nSamples 136 | 137 | @torch.no_grad() 138 | def fit(self): 139 | """ 140 | Fit the scaling factor based on the observed variances. 141 | """ 142 | if AutomaticFit.activeVar == self: 143 | if self.variance_in == 0: 144 | raise ValueError( 145 | f"Did not track the variable {self._name}. Add observe calls to track the variance before and after." 146 | ) 147 | 148 | # calculate variance preserving scaling factor 149 | self.variance_in = self.variance_in / self.nSamples 150 | self.variance_out = self.variance_out / self.nSamples 151 | 152 | ratio = self.variance_out / self.variance_in 153 | value = torch.sqrt(1 / ratio) 154 | logging.info( 155 | f"Variable: {self._name}, " 156 | f"Var_in: {self.variance_in.item():.3f}, " 157 | f"Var_out: {self.variance_out.item():.3f}, " 158 | f"Ratio: {ratio:.3f} => Scaling factor: {value:.3f}" 159 | ) 160 | 161 | # set variable to calculated value 162 | self.variable.copy_(self.variable * value) 163 | update_json( 164 | self.scale_file, {self._name: float(self.variable.item())} 165 | ) 166 | self.set_next_active() # set next variable in queue to active 167 | 168 | 169 | class ScalingFactor(torch.nn.Module): 170 | """ 171 | Scale the output y of the layer s.t. the (mean) variance wrt. to the reference input x_ref is preserved. 172 | 173 | Parameters 174 | ---------- 175 | scale_file: str 176 | Path to the json file where to store/load from the scaling factors. 177 | name: str 178 | Name of the scaling factor 179 | """ 180 | 181 | def __init__(self, scale_file, name, device=None): 182 | super().__init__() 183 | 184 | self.scale_factor = torch.nn.Parameter( 185 | torch.tensor(1.0, device=device), requires_grad=False 186 | ) 187 | self.autofit = AutoScaleFit(self.scale_factor, scale_file, name) 188 | 189 | def forward(self, x_ref, y): 190 | y = y * self.scale_factor 191 | self.autofit.observe(x_ref, y) 192 | 193 | return y 194 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/layers/spherical_basis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import sympy as sym 9 | import torch 10 | from torch_geometric.nn.models.schnet import GaussianSmearing 11 | 12 | from .basis_utils import real_sph_harm 13 | from .radial_basis import RadialBasis 14 | 15 | 16 | class CircularBasisLayer(torch.nn.Module): 17 | """ 18 | 2D Fourier Bessel Basis 19 | 20 | Parameters 21 | ---------- 22 | num_spherical: int 23 | Controls maximum frequency. 
24 | radial_basis: RadialBasis 25 | Radial basis functions 26 | cbf: dict 27 | Name and hyperparameters of the cosine basis function 28 | efficient: bool 29 | Whether to use the "efficient" summation order 30 | """ 31 | 32 | def __init__( 33 | self, 34 | num_spherical: int, 35 | radial_basis: RadialBasis, 36 | cbf: dict, 37 | efficient: bool = False, 38 | ): 39 | super().__init__() 40 | 41 | self.radial_basis = radial_basis 42 | self.efficient = efficient 43 | 44 | cbf_name = cbf["name"].lower() 45 | cbf_hparams = cbf.copy() 46 | del cbf_hparams["name"] 47 | 48 | if cbf_name == "gaussian": 49 | self.cosφ_basis = GaussianSmearing( 50 | start=-1, stop=1, num_gaussians=num_spherical, **cbf_hparams 51 | ) 52 | elif cbf_name == "spherical_harmonics": 53 | Y_lm = real_sph_harm( 54 | num_spherical, use_theta=False, zero_m_only=True 55 | ) 56 | sph_funcs = [] # (num_spherical,) 57 | 58 | # convert sympy expressions to torch functions 59 | z = sym.symbols("z") 60 | modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt} 61 | m_order = 0 # only single angle 62 | for l_degree in range(len(Y_lm)): # num_spherical 63 | if ( 64 | l_degree == 0 65 | ): # Y_00 is only a constant -> function returns value and not tensor 66 | first_sph = sym.lambdify( 67 | [z], Y_lm[l_degree][m_order], modules 68 | ) 69 | sph_funcs.append( 70 | lambda z: torch.zeros_like(z) + first_sph(z) 71 | ) 72 | else: 73 | sph_funcs.append( 74 | sym.lambdify([z], Y_lm[l_degree][m_order], modules) 75 | ) 76 | self.cosφ_basis = lambda cosφ: torch.stack( 77 | [f(cosφ) for f in sph_funcs], dim=1 78 | ) 79 | else: 80 | raise ValueError(f"Unknown cosine basis function '{cbf_name}'.") 81 | 82 | def forward(self, D_ca, cosφ_cab, id3_ca): 83 | rbf = self.radial_basis(D_ca) # (num_edges, num_radial) 84 | cbf = self.cosφ_basis(cosφ_cab) # (num_triplets, num_spherical) 85 | 86 | if not self.efficient: 87 | rbf = rbf[id3_ca] # (num_triplets, num_radial) 88 | out = (rbf[:, None, :] * cbf[:, :, None]).view( 89 | -1, rbf.shape[-1] * cbf.shape[-1] 90 | ) 91 | return (out,) 92 | # (num_triplets, num_radial * num_spherical) 93 | else: 94 | return (rbf[None, :, :], cbf) 95 | # (1, num_edges, num_radial), (num_edges, num_spherical) 96 | -------------------------------------------------------------------------------- /concdvae/pl_modules/gemnet/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree.
6 | """ 7 | 8 | import json 9 | 10 | import torch 11 | from torch_scatter import segment_csr 12 | 13 | 14 | def read_json(path): 15 | """""" 16 | if not path.endswith(".json"): 17 | raise UserWarning(f"Path {path} is not a json-path.") 18 | 19 | with open(path, "r") as f: 20 | content = json.load(f) 21 | return content 22 | 23 | 24 | def update_json(path, data): 25 | """""" 26 | if not path.endswith(".json"): 27 | raise UserWarning(f"Path {path} is not a json-path.") 28 | 29 | content = read_json(path) 30 | content.update(data) 31 | write_json(path, content) 32 | 33 | 34 | def write_json(path, data): 35 | """""" 36 | if not path.endswith(".json"): 37 | raise UserWarning(f"Path {path} is not a json-path.") 38 | 39 | with open(path, "w", encoding="utf-8") as f: 40 | json.dump(data, f, ensure_ascii=False, indent=4) 41 | 42 | 43 | def read_value_json(path, key): 44 | """""" 45 | content = read_json(path) 46 | 47 | if key in content.keys(): 48 | return content[key] 49 | else: 50 | return None 51 | 52 | 53 | def ragged_range(sizes): 54 | """Multiple concatenated ranges. 55 | 56 | Examples 57 | -------- 58 | sizes = [1 4 2 3] 59 | Return: [0 0 1 2 3 0 1 0 1 2] 60 | """ 61 | assert sizes.dim() == 1 62 | if sizes.sum() == 0: 63 | return sizes.new_empty(0) 64 | 65 | # Remove 0 sizes 66 | sizes_nonzero = sizes > 0 67 | if not torch.all(sizes_nonzero): 68 | sizes = torch.masked_select(sizes, sizes_nonzero) 69 | 70 | # Initialize indexing array with ones as we need to setup incremental indexing 71 | # within each group when cumulatively summed at the final stage. 72 | id_steps = torch.ones(sizes.sum(), dtype=torch.long, device=sizes.device) 73 | id_steps[0] = 0 74 | insert_index = sizes[:-1].cumsum(0) 75 | insert_val = (1 - sizes)[:-1] 76 | 77 | # Assign index-offsetting values 78 | id_steps[insert_index] = insert_val 79 | 80 | # Finally index into input array for the group repeated o/p 81 | res = id_steps.cumsum(0) 82 | return res 83 | 84 | 85 | def repeat_blocks( 86 | sizes, 87 | repeats, 88 | continuous_indexing=True, 89 | start_idx=0, 90 | block_inc=0, 91 | repeat_inc=0, 92 | ): 93 | """Repeat blocks of indices. 94 | Adapted from https://stackoverflow.com/questions/51154989/numpy-vectorized-function-to-repeat-blocks-of-consecutive-elements 95 | 96 | continuous_indexing: Whether to keep increasing the index after each block 97 | start_idx: Starting index 98 | block_inc: Number to increment by after each block, 99 | either global or per block. 
Shape: len(sizes) - 1 100 | repeat_inc: Number to increment by after each repetition, 101 | either global or per block 102 | 103 | Examples 104 | -------- 105 | sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = False 106 | Return: [0 0 0 0 1 2 0 1 2 0 1 0 1 0 1] 107 | sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = True 108 | Return: [0 0 0 1 2 3 1 2 3 4 5 4 5 4 5] 109 | sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = True ; 110 | repeat_inc = 4 111 | Return: [0 4 8 1 2 3 5 6 7 4 5 8 9 12 13] 112 | sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = True ; 113 | start_idx = 5 114 | Return: [5 5 5 6 7 8 6 7 8 9 10 9 10 9 10] 115 | sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = True ; 116 | block_inc = 1 117 | Return: [0 0 0 2 3 4 2 3 4 6 7 6 7 6 7] 118 | sizes = [0,3,2] ; repeats = [3,2,3] ; continuous_indexing = True 119 | Return: [0 1 2 0 1 2 3 4 3 4 3 4] 120 | sizes = [2,3,2] ; repeats = [2,0,2] ; continuous_indexing = True 121 | Return: [0 1 0 1 5 6 5 6] 122 | """ 123 | assert sizes.dim() == 1 124 | assert all(sizes >= 0) 125 | 126 | # Remove 0 sizes 127 | sizes_nonzero = sizes > 0 128 | if not torch.all(sizes_nonzero): 129 | 130 | 131 | assert block_inc == 0 # Implementing this is not worth the effort 132 | sizes = torch.masked_select(sizes, sizes_nonzero) 133 | if isinstance(repeats, torch.Tensor): 134 | repeats = torch.masked_select(repeats, sizes_nonzero) 135 | if isinstance(repeat_inc, torch.Tensor): 136 | repeat_inc = torch.masked_select(repeat_inc, sizes_nonzero) 137 | 138 | if isinstance(repeats, torch.Tensor): 139 | assert all(repeats >= 0) 140 | insert_dummy = repeats[0] == 0 141 | if insert_dummy: 142 | one = sizes.new_ones(1) 143 | zero = sizes.new_zeros(1) 144 | sizes = torch.cat((one, sizes)) 145 | repeats = torch.cat((one, repeats)) 146 | if isinstance(block_inc, torch.Tensor): 147 | block_inc = torch.cat((zero, block_inc)) 148 | if isinstance(repeat_inc, torch.Tensor): 149 | repeat_inc = torch.cat((zero, repeat_inc)) 150 | else: 151 | assert repeats >= 0 152 | insert_dummy = False 153 | 154 | # Get repeats for each group using group lengths/sizes 155 | r1 = torch.repeat_interleave( 156 | torch.arange(len(sizes), device=sizes.device), repeats 157 | ) 158 | 159 | # Get total size of output array, as needed to initialize output indexing array 160 | N = (sizes * repeats).sum() 161 | 162 | # Initialize indexing array with ones as we need to setup incremental indexing 163 | # within each group when cumulatively summed at the final stage. 164 | # Two steps here: 165 | # 1. Within each group, we have multiple sequences, so setup the offsetting 166 | # at each sequence length by the sequence lengths preceding it.
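# As a concrete trace (first docstring example above): sizes=[1,3,2] with
# repeats=[3,2,3] and continuous_indexing=False makes id_ar all ones, then
# overwrites each position where a new repetition starts with a (1 - size)
# offset, so the final cumulative sum restarts counting there and yields
# [0 0 0 0 1 2 0 1 2 0 1 0 1 0 1].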
167 | id_ar = torch.ones(N, dtype=torch.long, device=sizes.device) 168 | id_ar[0] = 0 169 | insert_index = sizes[r1[:-1]].cumsum(0) 170 | insert_val = (1 - sizes)[r1[:-1]] 171 | 172 | if isinstance(repeats, torch.Tensor) and torch.any(repeats == 0): 173 | diffs = r1[1:] - r1[:-1] 174 | indptr = torch.cat((sizes.new_zeros(1), diffs.cumsum(0))) 175 | if continuous_indexing: 176 | # If a group was skipped (repeats=0) we need to add its size 177 | insert_val += segment_csr(sizes[: r1[-1]], indptr, reduce="sum") 178 | 179 | # Add block increments 180 | if isinstance(block_inc, torch.Tensor): 181 | insert_val += segment_csr( 182 | block_inc[: r1[-1]], indptr, reduce="sum" 183 | ) 184 | else: 185 | insert_val += block_inc * (indptr[1:] - indptr[:-1]) 186 | if insert_dummy: 187 | insert_val[0] -= block_inc 188 | else: 189 | idx = r1[1:] != r1[:-1] 190 | if continuous_indexing: 191 | # 2. For each group, make sure the indexing starts from the next group's 192 | # first element. So, simply assign 1s there. 193 | insert_val[idx] = 1 194 | 195 | # Add block increments 196 | insert_val[idx] += block_inc 197 | 198 | # Add repeat_inc within each group 199 | if isinstance(repeat_inc, torch.Tensor): 200 | insert_val += repeat_inc[r1[:-1]] 201 | if isinstance(repeats, torch.Tensor): 202 | repeat_inc_inner = repeat_inc[repeats > 0][:-1] 203 | else: 204 | repeat_inc_inner = repeat_inc[:-1] 205 | else: 206 | insert_val += repeat_inc 207 | repeat_inc_inner = repeat_inc 208 | 209 | # Subtract the increments between groups 210 | if isinstance(repeats, torch.Tensor): 211 | repeats_inner = repeats[repeats > 0][:-1] 212 | else: 213 | repeats_inner = repeats 214 | insert_val[r1[1:] != r1[:-1]] -= repeat_inc_inner * repeats_inner 215 | 216 | # Assign index-offsetting values 217 | id_ar[insert_index] = insert_val 218 | 219 | if insert_dummy: 220 | id_ar = id_ar[1:] 221 | if continuous_indexing: 222 | id_ar[0] -= 1 223 | 224 | # Set start index now, in case of insertion due to leading repeats=0 225 | id_ar[0] += start_idx 226 | 227 | # Finally index into input array for the group repeated o/p 228 | res = id_ar.cumsum(0) 229 | return res 230 | 231 | 232 | def calculate_interatomic_vectors(R, id_s, id_t, offsets_st): 233 | """ 234 | Calculate the vectors connecting the given atom pairs, 235 | considering offsets from periodic boundary conditions (PBC). 236 | 237 | Parameters 238 | ---------- 239 | R: Tensor, shape = (nAtoms, 3) 240 | Atom positions. 241 | id_s: Tensor, shape = (nEdges,) 242 | Indices of the source atom of the edges. 243 | id_t: Tensor, shape = (nEdges,) 244 | Indices of the target atom of the edges. 245 | offsets_st: Tensor, shape = (nEdges,) 246 | PBC offsets of the edges. 247 | Subtract this from the correct direction. 248 | 249 | Returns 250 | ------- 251 | (D_st, V_st): tuple 252 | D_st: Tensor, shape = (nEdges,) 253 | Distance from atom t to s. 254 | V_st: Tensor, shape = (nEdges,) 255 | Unit direction from atom t to s. 256 | """ 257 | Rs = R[id_s] 258 | Rt = R[id_t] 259 | # ReLU prevents negative numbers in sqrt 260 | if offsets_st is None: 261 | V_st = Rt - Rs # s -> t 262 | else: 263 | V_st = Rt - Rs + offsets_st # s -> t 264 | D_st = torch.sqrt(torch.sum(V_st ** 2, dim=1)) 265 | V_st = V_st / D_st[..., None] 266 | return D_st, V_st 267 | 268 | 269 | def inner_product_normalized(x, y): 270 | """ 271 | Calculate the inner product between the given normalized vectors, 272 | giving a result between -1 and 1. 
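The clamp guards against floating-point overshoot: for nearly (anti-)parallel unit vectors the raw inner product can land just outside [-1, 1], which would turn a downstream arccos into NaN.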
273 | """ 274 | return torch.sum(x * y, dim=-1).clamp(min=-1, max=1) 275 | 276 | 277 | def mask_neighbors(neighbors, edge_mask): 278 | neighbors_old_indptr = torch.cat([neighbors.new_zeros(1), neighbors]) 279 | neighbors_old_indptr = torch.cumsum(neighbors_old_indptr, dim=0) 280 | neighbors = segment_csr(edge_mask.long(), neighbors_old_indptr) 281 | return neighbors 282 | -------------------------------------------------------------------------------- /concdvae/pt2CS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pymatgen.core.lattice import Lattice 3 | from collections import Counter 4 | from pymatgen.symmetry.analyzer import SpacegroupAnalyzer 5 | from pymatgen.core.structure import Structure 6 | import os 7 | 8 | def determine_crystal_system(structure, tolerance=1.0, angle_tolerance=5): 9 | try: 10 | sga = SpacegroupAnalyzer(structure, symprec=tolerance,angle_tolerance=angle_tolerance) 11 | 12 | crystal_system = sga.get_crystal_system() 13 | 14 | return crystal_system 15 | 16 | except Exception as e: 17 | return "error" + str(e) 18 | 19 | dataroot = 'D:/2-project/0-MaterialDesign/3-CDVAE/cdvae-pre4/output/hydra/singlerun/2023-12-25/mp20_CS/' 20 | dataname = 'eval_gen_less_CS' 21 | lastname = ['0.pt','1.pt','2.pt','3.pt','4.pt','5.pt','6.pt',] 22 | for last in lastname: 23 | print('last name: ', last) 24 | datafile = dataname + last 25 | tolerance = 0.2 26 | angle_tolerance = 5 27 | datafile_read = os.path.join(dataroot, datafile) 28 | data = torch.load(datafile_read,map_location=torch.device('cpu')) 29 | lengths = data['lengths'] 30 | angles = data['angles'] 31 | num_atoms = data['num_atoms'] 32 | frac_coors = data['frac_coords'] 33 | atom_types = data['atom_types'] 34 | 35 | lengths_list = lengths.numpy().tolist() 36 | angles_list = angles.numpy().tolist() 37 | num_atoms_list = num_atoms.tolist() 38 | frac_coors_list = frac_coors.numpy().tolist() 39 | atom_types_list = atom_types.tolist() 40 | 41 | num_materal = 0 42 | CS_type = [] 43 | cs_name = [] 44 | for i in range(len(num_atoms_list)): 45 | now_atom = 0 46 | for a in range(len(num_atoms_list[i])): 47 | length = lengths_list[i][a] 48 | angle = angles_list[i][a] 49 | atom_num = num_atoms_list[i][a] 50 | 51 | atom_type = atom_types_list[i][now_atom: now_atom + atom_num] 52 | frac_coord = frac_coors_list[i][now_atom: now_atom + atom_num][:] 53 | lattice = Lattice.from_parameters(a=length[0], b=length[1], c=length[2], alpha=angle[0], 54 | beta=angle[1], gamma=angle[2]) 55 | 56 | structure = Structure(lattice, atom_type, frac_coord, to_unit_cell=True) 57 | 58 | crystal_system = determine_crystal_system(structure, tolerance, angle_tolerance) 59 | cs_name.append(crystal_system) 60 | 61 | print('result') 62 | element_counts = Counter(cs_name) 63 | 64 | for element, count in element_counts.items(): 65 | print(f"{element}: {count} time") 66 | 67 | print('-------------------------------------------------------------') 68 | 69 | -------------------------------------------------------------------------------- /concdvae/run.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import os 4 | script_dir = os.path.dirname(os.path.abspath(__file__)) 5 | parent_dir = os.path.abspath(os.path.join(script_dir, "..")) 6 | sys.path.append(parent_dir) 7 | 8 | import hydra 9 | import random 10 | import numpy as np 11 | import torch 12 | import omegaconf 13 | from hydra.core.hydra_config import HydraConfig 14 | from 
omegaconf import DictConfig, OmegaConf 15 | 16 | from concdvae.common.utils import PROJECT_ROOT, param_statistics 17 | from concdvae.PT_train.training import train 18 | 19 | 20 | 21 | def run(cfg: DictConfig) -> None: 22 | """ 23 | Generic train loop 24 | :param cfg: run configuration, defined by Hydra in /conf 25 | """ 26 | if cfg.train.deterministic: 27 | np.random.seed(cfg.train.random_seed) 28 | random.seed(cfg.train.random_seed) 29 | torch.manual_seed(cfg.train.random_seed) 30 | # torch.backends.cudnn.deterministic = True 31 | if(cfg.accelerator != 'cpu'): 32 | torch.cuda.manual_seed(cfg.train.random_seed) 33 | torch.cuda.manual_seed_all(cfg.train.random_seed) 34 | 35 | # Hydra run directory 36 | hydra_dir = Path(HydraConfig.get().run.dir) 37 | 38 | # Instantiate datamodule 39 | hydra.utils.log.info(f"Instantiating <{cfg.data.datamodule._target_}>") 40 | datamodule = hydra.utils.instantiate( 41 | cfg.data.datamodule, _recursive_=False 42 | ) 43 | 44 | # Instantiate model 45 | hydra.utils.log.info(f"Instantiating <{cfg.model._target_}>") 46 | model = hydra.utils.instantiate( 47 | cfg.model, 48 | optim=cfg.optim, 49 | data=cfg.data, 50 | logging=cfg.logging, 51 | _recursive_=False, 52 | ) 53 | param_statistics(model) 54 | 55 | best_loss_old = None 56 | if(cfg.train.PT_train.start_epochs>1): 57 | filename = 'model_' + cfg.expname + '.pth' 58 | model_root = Path(hydra_dir) / filename 59 | if os.path.exists(model_root): 60 | checkpoint = torch.load(model_root, map_location=torch.device('cpu')) 61 | model_state_dict = checkpoint['model'] 62 | model.load_state_dict(model_state_dict) 63 | cfg.train.PT_train.start_epochs = int(checkpoint['epoch']) 64 | best_loss_old = checkpoint['val_loss'] 65 | 66 | print('use old model with loss=',best_loss_old,',and epoch = ',cfg.train.PT_train.start_epochs) 67 | 68 | 69 | model.lattice_scaler = datamodule.lattice_scaler.copy() 70 | torch.save(datamodule.lattice_scaler, hydra_dir / 'lattice_scaler.pt') 71 | 72 | if cfg.accelerator == 'DDP': 73 | local_rank = torch.distributed.get_rank() 74 | torch.cuda.set_device(local_rank) 75 | device = torch.device('cuda', local_rank) 76 | model.device = device 77 | model.cuda() 78 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) 79 | elif cfg.accelerator == 'gpu': 80 | model.device = 'cuda' 81 | model.cuda() 82 | 83 | for name, param in model.named_parameters(): 84 | print(f"Parameter: {name}, Device: {param.device}", file=sys.stdout) 85 | 86 | 87 | # Store the YaML config separately into the wandb dir 88 | yaml_conf: str = OmegaConf.to_yaml(cfg=cfg) 89 | (hydra_dir / "hparams.yaml").write_text(yaml_conf) 90 | 91 | optimizer = hydra.utils.instantiate( 92 | cfg.optim.optimizer, params=model.parameters(), _convert_="partial" 93 | ) 94 | scheduler = hydra.utils.instantiate( 95 | cfg.optim.lr_scheduler, optimizer=optimizer 96 | ) 97 | 98 | hydra.utils.log.info('Start Train') 99 | test_losses, train_loss_epoch, val_loss_epoch = train(cfg, model, datamodule, optimizer, scheduler, hydra_dir, best_loss_old) 100 | 101 | 102 | hydra.utils.log.info('END') 103 | 104 | 105 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default") 106 | def main(cfg: omegaconf.DictConfig): 107 | #YCY 108 | if cfg.accelerator == 'DDP': 109 | torch.distributed.init_process_group(backend='nccl') 110 | local_rank = torch.distributed.get_rank() 111 | print(local_rank) 112 | torch.cuda.set_device(local_rank) 113 | device = torch.device('cuda', local_rank) 114 | 115 | run(cfg) 116 | 117 | 118 | if 
__name__ == "__main__": 119 | main() -------------------------------------------------------------------------------- /conf/conz_1.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.ConditionModel.ConditioningModule 2 | n_features: 128 3 | n_layers: 2 4 | condition_embeddings: 5 | - _target_: concdvae.pl_modules.ConditionModel.ClassConditionEmbedding 6 | condition_name: crystal_system 7 | n_type: 7 8 | n_emb: 64 9 | n_features: 64 10 | n_layers: 3 11 | 12 | - _target_: concdvae.pl_modules.ConditionModel.ClassConditionEmbedding 13 | condition_name: spacegroup 14 | n_type: 231 15 | n_emb: 64 16 | n_features: 64 17 | n_layers: 3 18 | 19 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 20 | condition_name: formation 21 | condition_min: -5.0 22 | condition_max: 0.5 23 | grid_spacing: 0.5 24 | n_features: 64 25 | n_layers: 3 26 | 27 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 28 | condition_name: bandgap 29 | condition_min: -0.5 30 | condition_max: 9.0 31 | grid_spacing: 0.5 32 | n_features: 64 33 | n_layers: 3 34 | 35 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 36 | condition_name: e_above_hull 37 | condition_min: -0.01 38 | condition_max: 0.09 39 | grid_spacing: 0.01 40 | n_features: 64 41 | n_layers: 3 42 | 43 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 44 | condition_name: a 45 | condition_min: 0 46 | condition_max: 40 47 | grid_spacing: 2 48 | n_features: 64 49 | n_layers: 3 50 | 51 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 52 | condition_name: b 53 | condition_min: 0 54 | condition_max: 38 55 | grid_spacing: 2 56 | n_features: 64 57 | n_layers: 3 58 | 59 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 60 | condition_name: c 61 | condition_min: 0 62 | condition_max: 44 63 | grid_spacing: 2 64 | n_features: 64 65 | n_layers: 3 66 | 67 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 68 | condition_name: alpha 69 | condition_min: 0 70 | condition_max: 180 71 | grid_spacing: 10 72 | n_features: 64 73 | n_layers: 3 74 | 75 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 76 | condition_name: beta 77 | condition_min: 0 78 | condition_max: 180 79 | grid_spacing: 10 80 | n_features: 64 81 | n_layers: 3 82 | 83 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 84 | condition_name: gamma 85 | condition_min: 0 86 | condition_max: 180 87 | grid_spacing: 10 88 | n_features: 64 89 | n_layers: 3 90 | 91 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 92 | condition_name: density 93 | condition_min: 0 94 | condition_max: 24 95 | grid_spacing: 2 96 | n_features: 64 97 | n_layers: 3 98 | 99 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 100 | condition_name: coor_number 101 | condition_min: -1 102 | condition_max: 16 103 | grid_spacing: 1 104 | n_features: 64 105 | n_layers: 3 106 | 107 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 108 | condition_name: n_atom 109 | condition_min: 0 110 | condition_max: 20 111 | grid_spacing: 1 112 | n_features: 64 113 | n_layers: 3 114 | 115 | - _target_: concdvae.pl_modules.ConditionModel.VectorialConditionEmbedding 116 | condition_name: formula 117 | n_in: 92 118 | n_features: 64 119 | n_layers: 3 -------------------------------------------------------------------------------- 
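The conz_*.yaml files here are Hydra instantiation targets: ConditioningModule aggregates one embedding per entry of condition_embeddings. A minimal sketch of how such a config becomes a live module, assuming a working checkout where the _target_ classes are importable (the load path is illustrative):

```python
import hydra
from omegaconf import OmegaConf

# Hydra recursively instantiates the _target_ tree: the ConditioningModule
# plus one ClassConditionEmbedding / ScalarConditionEmbedding /
# VectorialConditionEmbedding per list entry.
cfg = OmegaConf.load("conf/conz_1.yaml")
conditioning_module = hydra.utils.instantiate(cfg)
```

Keeping each condition as its own embedding lets the same ConditioningModule be re-composed for different property sets; compare conz_2.yaml below, which drops the lattice-parameter and density conditions.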
/conf/conz_2.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.ConditionModel.ConditioningModule 2 | n_features: 128 3 | n_layers: 2 4 | condition_embeddings: 5 | - _target_: concdvae.pl_modules.ConditionModel.ClassConditionEmbedding 6 | condition_name: crystal_system 7 | n_type: 7 8 | n_emb: 64 9 | n_features: 64 10 | n_layers: 3 11 | 12 | - _target_: concdvae.pl_modules.ConditionModel.ClassConditionEmbedding 13 | condition_name: spacegroup 14 | n_type: 231 15 | n_emb: 64 16 | n_features: 64 17 | n_layers: 3 18 | 19 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 20 | condition_name: formation 21 | condition_min: -5.0 22 | condition_max: 0.5 23 | grid_spacing: 0.5 24 | n_features: 64 25 | n_layers: 3 26 | 27 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 28 | condition_name: bandgap 29 | condition_min: -0.5 30 | condition_max: 9.0 31 | grid_spacing: 0.5 32 | n_features: 64 33 | n_layers: 3 34 | 35 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 36 | condition_name: e_above_hull 37 | condition_min: -0.01 38 | condition_max: 0.09 39 | grid_spacing: 0.01 40 | n_features: 64 41 | n_layers: 3 42 | 43 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 44 | condition_name: n_atom 45 | condition_min: 0 46 | condition_max: 20 47 | grid_spacing: 1 48 | n_features: 64 49 | n_layers: 3 50 | 51 | - _target_: concdvae.pl_modules.ConditionModel.VectorialConditionEmbedding 52 | condition_name: formula 53 | n_in: 92 54 | n_features: 64 55 | n_layers: 3 -------------------------------------------------------------------------------- /conf/data/mp_all20.yaml: -------------------------------------------------------------------------------- 1 | root_path: ${oc.env:PROJECT_ROOT}/data/mpall_20 2 | prop: ['formation_energy_per_atom', 'band_gap', 'FM_type', 'BG_type','CS_type'] 3 | use_prop: 'formation_energy_per_atom' 4 | num_targets: 1 5 | # prop: scaled_lattice 6 | # num_targets: 6 7 | niggli: true 8 | primitive: False 9 | graph_method: crystalnn 10 | lattice_scale_method: scale_length 11 | preprocess_workers: 60 12 | readout: mean 13 | max_atoms: 20 14 | otf_graph: false 15 | eval_model_name: mp20 16 | 17 | 18 | train_max_epochs: 1000 19 | early_stopping_patience: 100000 20 | teacher_forcing_max_epoch: 500 21 | 22 | n_delta: 40 23 | 24 | datamodule: 25 | _target_: concdvae.pl_data.datamodule.CrystDataModule 26 | 27 | accelerator: ${accelerator} 28 | n_delta: ${data.n_delta} 29 | use_prop: ${data.use_prop} 30 | 31 | datasets: 32 | train: 33 | _target_: concdvae.pl_data.dataset.CrystDataset 34 | name: Formation energy train 35 | path: ${data.root_path}/train.csv 36 | prop: ${data.prop} 37 | use_prop: ${data.use_prop} 38 | niggli: ${data.niggli} 39 | primitive: ${data.primitive} 40 | graph_method: ${data.graph_method} 41 | lattice_scale_method: ${data.lattice_scale_method} 42 | preprocess_workers: ${data.preprocess_workers} 43 | 44 | val: 45 | - _target_: concdvae.pl_data.dataset.CrystDataset 46 | name: Formation energy val 47 | path: ${data.root_path}/val.csv 48 | prop: ${data.prop} 49 | use_prop: ${data.use_prop} 50 | niggli: ${data.niggli} 51 | primitive: ${data.primitive} 52 | graph_method: ${data.graph_method} 53 | lattice_scale_method: ${data.lattice_scale_method} 54 | preprocess_workers: ${data.preprocess_workers} 55 | 56 | test: 57 | - _target_: concdvae.pl_data.dataset.CrystDataset 58 | name: Formation energy test 59 | path: 
${data.root_path}/test.csv 60 | prop: ${data.prop} 61 | use_prop: ${data.use_prop} 62 | niggli: ${data.niggli} 63 | primitive: ${data.primitive} 64 | graph_method: ${data.graph_method} 65 | lattice_scale_method: ${data.lattice_scale_method} 66 | preprocess_workers: ${data.preprocess_workers} 67 | 68 | num_workers: 69 | train: 0 70 | val: 0 71 | test: 0 72 | 73 | batch_size: 74 | train: 512 75 | val: 512 76 | test: 512 77 | -------------------------------------------------------------------------------- /conf/data/mp_all40.yaml: -------------------------------------------------------------------------------- 1 | root_path: ${oc.env:PROJECT_ROOT}/data/mpall_40 2 | prop: ['formation_energy_per_atom', 'band_gap', 'FM_type', 'BG_type','CS_type'] 3 | use_prop: 'formation_energy_per_atom' 4 | num_targets: 1 5 | # prop: scaled_lattice 6 | # num_targets: 6 7 | niggli: true 8 | primitive: False 9 | graph_method: crystalnn 10 | lattice_scale_method: scale_length 11 | preprocess_workers: 60 12 | readout: mean 13 | max_atoms: 40 14 | otf_graph: false 15 | eval_model_name: mp20 16 | 17 | 18 | train_max_epochs: 1000 19 | early_stopping_patience: 100000 20 | teacher_forcing_max_epoch: 500 21 | 22 | n_delta: 40 23 | 24 | datamodule: 25 | _target_: concdvae.pl_data.datamodule.CrystDataModule 26 | 27 | accelerator: ${accelerator} 28 | n_delta: ${data.n_delta} 29 | use_prop: ${data.use_prop} 30 | 31 | datasets: 32 | train: 33 | _target_: concdvae.pl_data.dataset.CrystDataset 34 | name: Formation energy train 35 | path: ${data.root_path}/train.csv 36 | prop: ${data.prop} 37 | use_prop: ${data.use_prop} 38 | niggli: ${data.niggli} 39 | primitive: ${data.primitive} 40 | graph_method: ${data.graph_method} 41 | lattice_scale_method: ${data.lattice_scale_method} 42 | preprocess_workers: ${data.preprocess_workers} 43 | 44 | val: 45 | - _target_: concdvae.pl_data.dataset.CrystDataset 46 | name: Formation energy val 47 | path: ${data.root_path}/val.csv 48 | prop: ${data.prop} 49 | use_prop: ${data.use_prop} 50 | niggli: ${data.niggli} 51 | primitive: ${data.primitive} 52 | graph_method: ${data.graph_method} 53 | lattice_scale_method: ${data.lattice_scale_method} 54 | preprocess_workers: ${data.preprocess_workers} 55 | 56 | test: 57 | - _target_: concdvae.pl_data.dataset.CrystDataset 58 | name: Formation energy test 59 | path: ${data.root_path}/test.csv 60 | prop: ${data.prop} 61 | use_prop: ${data.use_prop} 62 | niggli: ${data.niggli} 63 | primitive: ${data.primitive} 64 | graph_method: ${data.graph_method} 65 | lattice_scale_method: ${data.lattice_scale_method} 66 | preprocess_workers: ${data.preprocess_workers} 67 | 68 | num_workers: 69 | train: 0 70 | val: 0 71 | test: 0 72 | 73 | batch_size: 74 | train: 256 75 | val: 256 76 | test: 256 77 | -------------------------------------------------------------------------------- /conf/data/mptest.yaml: -------------------------------------------------------------------------------- 1 | root_path: ${oc.env:PROJECT_ROOT}/data/mptest 2 | prop: ['formation_energy_per_atom', 'band_gap', 'FM_type', 'BG_type','CS_type'] 3 | use_prop: 'formation_energy_per_atom' 4 | num_targets: 1 5 | # prop: scaled_lattice 6 | # num_targets: 6 7 | niggli: true 8 | primitive: False 9 | graph_method: crystalnn 10 | lattice_scale_method: scale_length 11 | preprocess_workers: 1 12 | readout: mean 13 | max_atoms: 20 14 | otf_graph: false 15 | eval_model_name: mp20 16 | 17 | 18 | train_max_epochs: 30 19 | early_stopping_patience: 100000 20 | teacher_forcing_max_epoch: 15 21 | 22 | n_delta: 40 23 | 24 | 
datamodule: 25 | _target_: concdvae.pl_data.datamodule.CrystDataModule 26 | 27 | accelerator: ${accelerator} 28 | n_delta: ${data.n_delta} 29 | use_prop: ${data.use_prop} 30 | 31 | datasets: 32 | train: 33 | _target_: concdvae.pl_data.dataset.CrystDataset 34 | name: Formation energy train 35 | path: ${data.root_path}/train.csv 36 | prop: ${data.prop} 37 | use_prop: ${data.use_prop} 38 | niggli: ${data.niggli} 39 | primitive: ${data.primitive} 40 | graph_method: ${data.graph_method} 41 | lattice_scale_method: ${data.lattice_scale_method} 42 | preprocess_workers: ${data.preprocess_workers} 43 | 44 | val: 45 | - _target_: concdvae.pl_data.dataset.CrystDataset 46 | name: Formation energy val 47 | path: ${data.root_path}/val.csv 48 | prop: ${data.prop} 49 | use_prop: ${data.use_prop} 50 | niggli: ${data.niggli} 51 | primitive: ${data.primitive} 52 | graph_method: ${data.graph_method} 53 | lattice_scale_method: ${data.lattice_scale_method} 54 | preprocess_workers: ${data.preprocess_workers} 55 | 56 | test: 57 | - _target_: concdvae.pl_data.dataset.CrystDataset 58 | name: Formation energy test 59 | path: ${data.root_path}/test.csv 60 | prop: ${data.prop} 61 | use_prop: ${data.use_prop} 62 | niggli: ${data.niggli} 63 | primitive: ${data.primitive} 64 | graph_method: ${data.graph_method} 65 | lattice_scale_method: ${data.lattice_scale_method} 66 | preprocess_workers: ${data.preprocess_workers} 67 | 68 | num_workers: 69 | train: 0 70 | val: 0 71 | test: 0 72 | 73 | batch_size: 74 | train: 10 75 | val: 10 76 | test: 10 77 | -------------------------------------------------------------------------------- /conf/data/oqmd_20.yaml: -------------------------------------------------------------------------------- 1 | root_path: ${oc.env:PROJECT_ROOT}/data/oqmd 2 | prop: ['formation_energy_per_atom', 'band_gap', 'FM_type', 'BG_type','CS_type'] 3 | use_prop: 'formation_energy_per_atom' 4 | num_targets: 1 5 | # prop: scaled_lattice 6 | # num_targets: 6 7 | niggli: true 8 | primitive: False 9 | graph_method: crystalnn 10 | lattice_scale_method: scale_length 11 | preprocess_workers: 60 12 | readout: mean 13 | max_atoms: 20 14 | otf_graph: false 15 | eval_model_name: mp20 16 | 17 | 18 | train_max_epochs: 1000 19 | early_stopping_patience: 200 20 | teacher_forcing_max_epoch: 500 21 | 22 | n_delta: 40 23 | 24 | datamodule: 25 | _target_: concdvae.pl_data.datamodule.CrystDataModule 26 | 27 | accelerator: ${accelerator} 28 | n_delta: ${data.n_delta} 29 | use_prop: ${data.use_prop} 30 | 31 | datasets: 32 | train: 33 | _target_: concdvae.pl_data.dataset.CrystDataset 34 | name: Formation energy train 35 | path: ${data.root_path}/train.csv 36 | prop: ${data.prop} 37 | use_prop: ${data.use_prop} 38 | niggli: ${data.niggli} 39 | primitive: ${data.primitive} 40 | graph_method: ${data.graph_method} 41 | lattice_scale_method: ${data.lattice_scale_method} 42 | preprocess_workers: ${data.preprocess_workers} 43 | 44 | val: 45 | - _target_: concdvae.pl_data.dataset.CrystDataset 46 | name: Formation energy val 47 | path: ${data.root_path}/val.csv 48 | prop: ${data.prop} 49 | use_prop: ${data.use_prop} 50 | niggli: ${data.niggli} 51 | primitive: ${data.primitive} 52 | graph_method: ${data.graph_method} 53 | lattice_scale_method: ${data.lattice_scale_method} 54 | preprocess_workers: ${data.preprocess_workers} 55 | 56 | test: 57 | - _target_: concdvae.pl_data.dataset.CrystDataset 58 | name: Formation energy test 59 | path: ${data.root_path}/test.csv 60 | prop: ${data.prop} 61 | use_prop: ${data.use_prop} 62 | niggli: ${data.niggli} 
63 | primitive: ${data.primitive} 64 | graph_method: ${data.graph_method} 65 | lattice_scale_method: ${data.lattice_scale_method} 66 | preprocess_workers: ${data.preprocess_workers} 67 | 68 | num_workers: 69 | train: 0 70 | val: 0 71 | test: 0 72 | 73 | batch_size: 74 | train: 512 75 | val: 512 76 | test: 512 77 | -------------------------------------------------------------------------------- /conf/default.yaml: -------------------------------------------------------------------------------- 1 | expname: test 2 | 3 | # metadata specialised for each experiment 4 | core: 5 | version: 0.0.1 6 | tags: 7 | - ${now:%Y-%m-%d} 8 | 9 | hydra: 10 | run: 11 | dir: ${oc.env:HYDRA_JOBS}/singlerun/${now:%Y-%m-%d}/${expname}/ 12 | 13 | sweep: 14 | dir: ${oc.env:HYDRA_JOBS}/multirun/${now:%Y-%m-%d}/${expname}/ 15 | subdir: ${hydra.job.num}_${hydra.job.id} 16 | 17 | job: 18 | env_set: 19 | WANDB_START_METHOD: thread 20 | WANDB_DIR: ${oc.env:WABDB_DIR} 21 | 22 | accelerator: 'cpu' 23 | 24 | defaults: 25 | - data: default 26 | - logging: default 27 | - model: vae 28 | - optim: default 29 | - train: default 30 | # Decomment this parameter to get parallel job running 31 | # - override hydra/launcher: joblib 32 | 33 | -------------------------------------------------------------------------------- /conf/logging/default.yaml: -------------------------------------------------------------------------------- 1 | # log frequency 2 | val_check_interval: 5 3 | progress_bar_refresh_rate: 20 4 | 5 | wandb: 6 | name: ${expname} 7 | project: crystal_generation_mit 8 | entity: null 9 | log_model: True 10 | mode: 'online' 11 | group: ${expname} 12 | 13 | wandb_watch: 14 | log: 'all' 15 | log_freq: 500 16 | 17 | lr_monitor: 18 | logging_interval: "step" 19 | log_momentum: False 20 | -------------------------------------------------------------------------------- /conf/model/conditionmodel/mp_CSclass.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.ConditionModel.ConditioningModule 2 | n_features: 128 3 | n_layers: 2 4 | condition_embeddings: 5 | - _target_: concdvae.pl_modules.ConditionModel.ClassConditionEmbedding 6 | condition_name: CS_type 7 | n_type: 7 8 | n_emb: 64 9 | n_features: 64 10 | n_layers: 3 11 | 12 | -------------------------------------------------------------------------------- /conf/model/conditionmodel/mp_FMclass_BGclass.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.ConditionModel.ConditioningModule 2 | n_features: 128 3 | n_layers: 2 4 | condition_embeddings: 5 | - _target_: concdvae.pl_modules.ConditionModel.ClassConditionEmbedding 6 | condition_name: FM_type 7 | n_type: 2 8 | n_emb: 64 9 | n_features: 64 10 | n_layers: 3 11 | 12 | - _target_: concdvae.pl_modules.ConditionModel.ClassConditionEmbedding 13 | condition_name: BG_type 14 | n_type: 2 15 | n_emb: 64 16 | n_features: 64 17 | n_layers: 3 18 | -------------------------------------------------------------------------------- /conf/model/conditionmodel/mp_format.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.ConditionModel.ConditioningModule 2 | n_features: 128 3 | n_layers: 2 4 | condition_embeddings: 5 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 6 | condition_name: formation_energy_per_atom 7 | condition_min: -6.0 8 | condition_max: 1.0 9 | grid_spacing: 0.5 10 | n_features: 64 11 | n_layers: 3 12 | 
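Judging by its fields, ScalarConditionEmbedding (defined in concdvae/pl_modules/ConditionModel.py) presumably expands the scalar property onto a uniform grid between condition_min and condition_max with step grid_spacing before passing it through an MLP. A sketch of such an expansion under that assumption; the function name is hypothetical:

```python
import torch

def gaussian_grid_expand(x: torch.Tensor, cmin: float = -6.0,
                         cmax: float = 1.0, spacing: float = 0.5) -> torch.Tensor:
    # Hypothetical sketch: soft one-hot of a scalar property onto a uniform
    # grid of Gaussians between cmin and cmax, spaced by `spacing`.
    n_grid = int(round((cmax - cmin) / spacing)) + 1  # 15 centers here
    centers = torch.linspace(cmin, cmax, n_grid)
    return torch.exp(-((x[:, None] - centers[None, :]) / spacing) ** 2)

features = gaussian_grid_expand(torch.tensor([-3.2, 0.0]))  # shape (2, 15)
```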
-------------------------------------------------------------------------------- /conf/model/conditionmodel/mp_format_gap.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.ConditionModel.ConditioningModule 2 | n_features: 128 3 | n_layers: 2 4 | condition_embeddings: 5 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 6 | condition_name: formation_energy_per_atom 7 | condition_min: -6.0 8 | condition_max: 1.0 9 | grid_spacing: 0.5 10 | n_features: 64 11 | n_layers: 3 12 | 13 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 14 | condition_name: band_gap 15 | condition_min: -1.0 16 | condition_max: 9.0 17 | grid_spacing: 0.5 18 | n_features: 64 19 | n_layers: 3 20 | -------------------------------------------------------------------------------- /conf/model/conditionmodel/mp_gap.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.ConditionModel.ConditioningModule 2 | n_features: 128 3 | n_layers: 2 4 | condition_embeddings: 5 | - _target_: concdvae.pl_modules.ConditionModel.ScalarConditionEmbedding 6 | condition_name: band_gap 7 | condition_min: -1.0 8 | condition_max: 9.0 9 | grid_spacing: 0.5 10 | n_features: 64 11 | n_layers: 3 12 | -------------------------------------------------------------------------------- /conf/model/conditionpre/pre_mp_CSclass.yaml: -------------------------------------------------------------------------------- 1 | condition_predict: 2 | - _target_: concdvae.pl_modules.PreCondition.ClassConditionPredict 3 | condition_name: CS_type 4 | n_type: 7 5 | latent_dim: ${model.latent_dim} 6 | hidden_dim: 4 7 | n_layers: 1 8 | drop: 0.2 -------------------------------------------------------------------------------- /conf/model/conditionpre/pre_mp_FMclass_BGclass.yaml: -------------------------------------------------------------------------------- 1 | condition_predict: 2 | - _target_: concdvae.pl_modules.PreCondition.ClassConditionPredict 3 | condition_name: FM_type 4 | n_type: 2 5 | latent_dim: ${model.latent_dim} 6 | hidden_dim: 256 7 | n_layers: 2 8 | drop: 0.4 9 | 10 | - _target_: concdvae.pl_modules.PreCondition.ClassConditionPredict 11 | condition_name: BG_type 12 | n_type: 2 13 | latent_dim: ${model.latent_dim} 14 | hidden_dim: 256 15 | n_layers: 2 16 | drop: 0.4 -------------------------------------------------------------------------------- /conf/model/conditionpre/pre_mp_format.yaml: -------------------------------------------------------------------------------- 1 | condition_predict: 2 | - _target_: concdvae.pl_modules.PreCondition.ScalarConditionPredict 3 | condition_name: formation_energy_per_atom 4 | condition_min: -6.0 5 | condition_max: 1.0 6 | latent_dim: ${model.latent_dim} 7 | hidden_dim: 256 8 | out_dim: 1 9 | n_layers: 2 -------------------------------------------------------------------------------- /conf/model/conditionpre/pre_mp_format_gap.yaml: -------------------------------------------------------------------------------- 1 | condition_predict: 2 | - _target_: concdvae.pl_modules.PreCondition.ScalarConditionPredict 3 | condition_name: formation_energy_per_atom 4 | condition_min: -5.0 5 | condition_max: 1.0 6 | latent_dim: ${model.latent_dim} 7 | hidden_dim: 256 8 | out_dim: 1 9 | n_layers: 2 10 | 11 | - _target_: concdvae.pl_modules.PreCondition.ScalarConditionPredict 12 | condition_name: band_gap 13 | condition_min: -1.0 14 | condition_max: 9.0 15 | 
latent_dim: ${model.latent_dim} 16 | hidden_dim: 256 17 | out_dim: 1 18 | n_layers: 2 -------------------------------------------------------------------------------- /conf/model/conditionpre/pre_mp_gap.yaml: -------------------------------------------------------------------------------- 1 | condition_predict: 2 | - _target_: concdvae.pl_modules.PreCondition.ScalarConditionPredict 3 | condition_name: band_gap 4 | condition_min: -1. 5 | condition_max: 9.0 6 | latent_dim: ${model.latent_dim} 7 | hidden_dim: 256 8 | out_dim: 1 9 | n_layers: 2 -------------------------------------------------------------------------------- /conf/model/decoder/gemnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.decoder.GemNetTDecoder 2 | hidden_dim: 128 3 | latent_dim: ${model.latent_dim} 4 | time_emb_dim: ${model.time_emb_dim} 5 | max_neighbors: ${model.max_neighbors} 6 | radius: ${model.radius} 7 | scale_file: ${oc.env:PROJECT_ROOT}/concdvae/pl_modules/gemnet/gemnet-dT.json 8 | -------------------------------------------------------------------------------- /conf/model/encoder/dimenet.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.gnn.DimeNetPlusPlusWrap 2 | num_targets: ${data.num_targets} 3 | hidden_channels: 128 4 | num_blocks: 4 5 | int_emb_size: 64 6 | basis_emb_size: 8 7 | out_emb_channels: 256 8 | num_spherical: 7 9 | num_radial: 6 10 | otf_graph: ${data.otf_graph} 11 | cutoff: 7.0 12 | max_num_neighbors: 20 13 | envelope_exponent: 5 14 | num_before_skip: 1 15 | num_after_skip: 2 16 | num_output_layers: 3 17 | readout: ${data.readout} -------------------------------------------------------------------------------- /conf/model/supervise.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.model.CrystGNN_Supervise 2 | use_orientation: false 3 | hidden_dim: 128 4 | fc_num_layers: 4 5 | use_pe: false 6 | 7 | 8 | defaults: 9 | - encoder: dimenet 10 | -------------------------------------------------------------------------------- /conf/model/vae.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.model.CDVAE 2 | hidden_dim: 256 3 | latent_dim: 256 4 | fc_num_layers: 1 5 | max_atoms: ${data.max_atoms} 6 | cost_natom: 1. 7 | cost_coord: 10. 8 | cost_type: 1. 9 | cost_lattice: 10. 10 | cost_composition: 1. 11 | cost_edge: 10. 12 | cost_property: 3. 13 | beta: 0.01 14 | teacher_forcing_lattice: true 15 | teacher_forcing_max_epoch: ${data.teacher_forcing_max_epoch} 16 | max_neighbors: 20 # maximum number of neighbors for OTF graph bulding in decoder 17 | radius: 7. # maximum search radius for OTF graph building in decoder 18 | sigma_begin: 10. 19 | sigma_end: 0.01 20 | type_sigma_begin: 5. 21 | type_sigma_end: 0.01 22 | num_noise_level: 50 23 | predict_property: False 24 | 25 | n_delta: 40 26 | 27 | defaults: 28 | - encoder: dimenet 29 | - decoder: gemnet 30 | - conditionmodel: perov_heatref_dirgap 31 | - conditionpre: pre_perov_heatref_dirgap 32 | -------------------------------------------------------------------------------- /conf/model/vae_mp_CSclass.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.model.CDVAE 2 | hidden_dim: 256 3 | latent_dim: 256 4 | time_emb_dim: 64 5 | fc_num_layers: 1 6 | max_atoms: ${data.max_atoms} 7 | cost_natom: 1. 
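# The cost_* weights below balance the loss terms: coordinate, lattice and edge losses count 10x the atom-number/type/composition terms, and the property loss is weighted 3x.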
8 | cost_coord: 10. 9 | cost_type: 1. 10 | cost_lattice: 10. 11 | cost_composition: 1. 12 | cost_edge: 10. 13 | cost_property: 3. 14 | beta: 0.01 15 | teacher_forcing_lattice: true 16 | teacher_forcing_max_epoch: ${data.teacher_forcing_max_epoch} 17 | max_neighbors: 20 # maximum number of neighbors for OTF graph bulding in decoder 18 | radius: 9. # maximum search radius for OTF graph building in decoder 19 | sigma_begin: 10. 20 | sigma_end: 0.01 21 | type_sigma_begin: 5. 22 | type_sigma_end: 0.01 23 | num_noise_level: 50 24 | predict_property: False 25 | 26 | n_delta: 40 27 | 28 | defaults: 29 | - encoder: dimenet 30 | - decoder: gemnet 31 | - conditionmodel: mp_CSclass 32 | - conditionpre: pre_mp_CSclass 33 | -------------------------------------------------------------------------------- /conf/model/vae_mp_FMclass_BGclass.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.model.CDVAE 2 | hidden_dim: 256 3 | latent_dim: 256 4 | fc_num_layers: 1 5 | time_emb_dim: 64 6 | max_atoms: ${data.max_atoms} 7 | cost_natom: 1. 8 | cost_coord: 10. 9 | cost_type: 1. 10 | cost_lattice: 10. 11 | cost_composition: 1. 12 | cost_edge: 10. 13 | cost_property: 3. 14 | beta: 0.01 15 | teacher_forcing_lattice: true 16 | teacher_forcing_max_epoch: ${data.teacher_forcing_max_epoch} 17 | max_neighbors: 20 # maximum number of neighbors for OTF graph bulding in decoder 18 | radius: 9. # maximum search radius for OTF graph building in decoder 19 | sigma_begin: 10. 20 | sigma_end: 0.01 21 | type_sigma_begin: 5. 22 | type_sigma_end: 0.01 23 | num_noise_level: 50 24 | predict_property: False 25 | 26 | n_delta: 40 27 | 28 | defaults: 29 | - encoder: dimenet 30 | - decoder: gemnet 31 | - conditionmodel: mp_FMclass_BGclass 32 | - conditionpre: pre_mp_FMclass_BGclass 33 | -------------------------------------------------------------------------------- /conf/model/vae_mp_format.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.model.CDVAE 2 | hidden_dim: 256 3 | latent_dim: 256 4 | time_emb_dim: 64 5 | fc_num_layers: 1 6 | max_atoms: ${data.max_atoms} 7 | cost_natom: 1. 8 | cost_coord: 10. 9 | cost_type: 1. 10 | cost_lattice: 10. 11 | cost_composition: 1. 12 | cost_edge: 10. 13 | cost_property: 3. 14 | beta: 0.01 15 | teacher_forcing_lattice: true 16 | teacher_forcing_max_epoch: ${data.teacher_forcing_max_epoch} 17 | max_neighbors: 20 # maximum number of neighbors for OTF graph bulding in decoder 18 | radius: 9. # maximum search radius for OTF graph building in decoder 19 | sigma_begin: 10. 20 | sigma_end: 0.01 21 | type_sigma_begin: 5. 22 | type_sigma_end: 0.01 23 | num_noise_level: 50 24 | predict_property: False 25 | 26 | n_delta: 40 27 | 28 | defaults: 29 | - encoder: dimenet 30 | - decoder: gemnet 31 | - conditionmodel: mp_format 32 | - conditionpre: pre_mp_format 33 | -------------------------------------------------------------------------------- /conf/model/vae_mp_format_gap.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.model.CDVAE 2 | hidden_dim: 256 3 | latent_dim: 256 4 | fc_num_layers: 1 5 | time_emb_dim: 64 6 | max_atoms: ${data.max_atoms} 7 | cost_natom: 1. 8 | cost_coord: 10. 9 | cost_type: 1. 10 | cost_lattice: 10. 11 | cost_composition: 1. 12 | cost_edge: 10. 13 | cost_property: 3. 
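# beta below scales the KL-divergence term of the VAE objective (as in a beta-VAE); the small value keeps reconstruction dominant.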
14 | beta: 0.01 15 | teacher_forcing_lattice: true 16 | teacher_forcing_max_epoch: ${data.teacher_forcing_max_epoch} 17 | max_neighbors: 20 # maximum number of neighbors for OTF graph building in decoder 18 | radius: 9. # maximum search radius for OTF graph building in decoder 19 | sigma_begin: 10. 20 | sigma_end: 0.01 21 | type_sigma_begin: 5. 22 | type_sigma_end: 0.01 23 | num_noise_level: 50 24 | predict_property: False 25 | 26 | n_delta: 40 27 | 28 | defaults: 29 | - encoder: dimenet 30 | - decoder: gemnet 31 | - conditionmodel: mp_format_gap 32 | - conditionpre: pre_mp_format_gap 33 | -------------------------------------------------------------------------------- /conf/model/vae_mp_gap.yaml: -------------------------------------------------------------------------------- 1 | _target_: concdvae.pl_modules.model.CDVAE 2 | hidden_dim: 256 3 | latent_dim: 256 4 | fc_num_layers: 1 5 | time_emb_dim: 64 6 | max_atoms: ${data.max_atoms} 7 | cost_natom: 1. 8 | cost_coord: 10. 9 | cost_type: 1. 10 | cost_lattice: 10. 11 | cost_composition: 1. 12 | cost_edge: 10. 13 | cost_property: 3. 14 | beta: 0.01 15 | teacher_forcing_lattice: true 16 | teacher_forcing_max_epoch: ${data.teacher_forcing_max_epoch} 17 | max_neighbors: 20 # maximum number of neighbors for OTF graph building in decoder 18 | radius: 9. # maximum search radius for OTF graph building in decoder 19 | sigma_begin: 10. 20 | sigma_end: 0.01 21 | type_sigma_begin: 5. 22 | type_sigma_end: 0.01 23 | num_noise_level: 50 24 | predict_property: False 25 | 26 | n_delta: 40 27 | 28 | defaults: 29 | - encoder: dimenet 30 | - decoder: gemnet 31 | - conditionmodel: mp_gap 32 | - conditionpre: pre_mp_gap 33 | -------------------------------------------------------------------------------- /conf/optim/default.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | # Adam-oriented deep learning 3 | _target_: torch.optim.Adam 4 | # These are all default parameters for the Adam optimizer 5 | lr: 0.001 6 | betas: [ 0.9, 0.999 ] 7 | eps: 1e-08 8 | weight_decay: 0 9 | 10 | use_lr_scheduler: True 11 | lr_scheduler: 12 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 13 | factor: 0.6 14 | patience: 30 15 | min_lr: 1e-4 -------------------------------------------------------------------------------- /conf/optim/less1.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | # Adam-oriented deep learning 3 | _target_: torch.optim.Adam 4 | # Same as default.yaml, but with a 10x smaller learning rate 5 | lr: 0.0001 6 | betas: [ 0.9, 0.999 ] 7 | eps: 1e-08 8 | weight_decay: 0 9 | 10 | use_lr_scheduler: True 11 | lr_scheduler: 12 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 13 | factor: 0.6 14 | patience: 30 15 | min_lr: 1e-5 -------------------------------------------------------------------------------- /conf/optim/less2.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | # Adam-oriented deep learning 3 | _target_: torch.optim.Adam 4 | # Same as default.yaml, but with a 100x smaller learning rate 5 | lr: 0.00001 6 | betas: [ 0.9, 0.999 ] 7 | eps: 1e-08 8 | weight_decay: 0 9 | 10 | use_lr_scheduler: True 11 | lr_scheduler: 12 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 13 | factor: 0.6 14 | patience: 30 15 | min_lr: 1e-6 -------------------------------------------------------------------------------- /conf/train/default.yaml:
-------------------------------------------------------------------------------- 1 | # reproducibility 2 | deterministic: False 3 | random_seed: 42 4 | 5 | # training 6 | 7 | pl_trainer: 8 | fast_dev_run: False # Enable this for debug purposes 9 | gpus: 1 10 | precision: 32 11 | # max_steps: 10000 12 | max_epochs: ${data.train_max_epochs} 13 | accumulate_grad_batches: 1 14 | num_sanity_val_steps: 2 15 | gradient_clip_val: 0.5 16 | gradient_clip_algorithm: value 17 | profiler: simple 18 | 19 | monitor_metric: 'val_loss' 20 | monitor_metric_mode: 'min' 21 | 22 | early_stopping: 23 | patience: ${data.early_stopping_patience} # 60 24 | verbose: False 25 | 26 | model_checkpoints: 27 | save_top_k: 1 28 | verbose: False 29 | -------------------------------------------------------------------------------- /conf/train/new.yaml: -------------------------------------------------------------------------------- 1 | # reproducibility 2 | deterministic: True 3 | random_seed: 123 4 | 5 | 6 | PT_train: 7 | start_epochs: 0 8 | max_epochs: ${data.train_max_epochs} 9 | print_freq: 20 10 | clip_grad_norm: -1 11 | clip_grad_norm_epoch: 100 -------------------------------------------------------------------------------- /data/mptest/test_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/data/mptest/test_data.pt -------------------------------------------------------------------------------- /data/mptest/train_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/data/mptest/train_data.pt -------------------------------------------------------------------------------- /data/mptest/val_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/data/mptest/val_data.pt -------------------------------------------------------------------------------- /data/mptest4conz/test_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/data/mptest4conz/test_data.pt -------------------------------------------------------------------------------- /data/mptest4conz/train_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/data/mptest4conz/train_data.pt -------------------------------------------------------------------------------- /data/mptest4conz/val_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/data/mptest4conz/val_data.pt -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: CDVAEycy38 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.05.30=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - 
ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.10=h7f8727e_0 15 | - pip=23.2.1=py38h06a4308_0 16 | - python=3.8.17=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.0.0=py38h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.38.4=py38h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - absl-py==1.4.0 26 | - aiohttp==3.8.5 27 | - aiosignal==1.3.1 28 | - altair==5.0.1 29 | - antlr4-python3-runtime==4.9.3 30 | - anyio==3.7.1 31 | - argon2-cffi==21.3.0 32 | - argon2-cffi-bindings==21.2.0 33 | - arrow==1.2.3 34 | - ase==3.22.1 35 | - astroid==2.15.6 36 | - asttokens==2.2.1 37 | - async-lru==2.0.4 38 | - async-timeout==4.0.3 39 | - attrs==23.1.0 40 | - autopep8==2.0.2 41 | - babel==2.12.1 42 | - backcall==0.2.0 43 | - backports-zoneinfo==0.2.1 44 | - beautifulsoup4==4.12.2 45 | - bleach==6.0.0 46 | - blinker==1.6.2 47 | - cachetools==5.3.1 48 | - certifi==2023.7.22 49 | - cffi==1.15.1 50 | - charset-normalizer==3.2.0 51 | - click==8.1.6 52 | - cmake==3.25.0 53 | - comm==0.1.4 54 | - contourpy==1.1.0 55 | - cycler==0.11.0 56 | - debugpy==1.6.7.post1 57 | - decorator==5.1.1 58 | - defusedxml==0.7.1 59 | - dill==0.3.7 60 | - dnspython==2.4.2 61 | - emmet-core==0.64.3 62 | - et-xmlfile==1.1.0 63 | - exceptiongroup==1.1.2 64 | - executing==1.2.0 65 | - fastjsonschema==2.18.0 66 | - filelock==3.9.0 67 | - fire==0.5.0 68 | - fonttools==4.42.0 69 | - fqdn==1.5.1 70 | - frozenlist==1.4.0 71 | - fsspec==2023.6.0 72 | - future==0.18.3 73 | - gitdb==4.0.10 74 | - gitpython==3.1.32 75 | - google-auth==2.22.0 76 | - google-auth-oauthlib==1.0.0 77 | - grpcio==1.57.0 78 | - hydra-core==1.3.2 79 | - hydra-joblib-launcher==1.2.0 80 | - idna==3.4 81 | - importlib-metadata==6.8.0 82 | - importlib-resources==6.0.1 83 | - iniconfig==2.0.0 84 | - ipykernel==6.25.1 85 | - ipython==8.12.2 86 | - ipywidgets==8.1.0 87 | - isoduration==20.11.0 88 | - isort==5.12.0 89 | - jedi==0.19.0 90 | - jinja2==3.1.2 91 | - joblib==1.3.2 92 | - json5==0.9.14 93 | - jsonpointer==2.4 94 | - jsonschema==4.19.0 95 | - jsonschema-specifications==2023.7.1 96 | - jupyter-client==8.3.0 97 | - jupyter-core==5.3.1 98 | - jupyter-events==0.7.0 99 | - jupyter-lsp==2.2.0 100 | - jupyter-server==2.7.0 101 | - jupyter-server-terminals==0.4.4 102 | - jupyterlab==4.0.4 103 | - jupyterlab-pygments==0.2.2 104 | - jupyterlab-server==2.24.0 105 | - jupyterlab-widgets==3.0.8 106 | - kiwisolver==1.4.4 107 | - latexcodec==2.0.1 108 | - lazy-object-proxy==1.9.0 109 | - lightning-lite==1.8.0 110 | - lightning-utilities==0.3.0 111 | - lit==15.0.7 112 | - markdown==3.4.4 113 | - markdown-it-py==3.0.0 114 | - markupsafe==2.1.3 115 | - matminer==0.9.0 116 | - matplotlib==3.7.2 117 | - matplotlib-inline==0.1.6 118 | - mccabe==0.7.0 119 | - mdurl==0.1.2 120 | - mistune==3.0.1 121 | - monty==2023.8.8 122 | - mp-api==0.33.3 123 | - mpmath==1.3.0 124 | - msgpack==1.0.5 125 | - multidict==6.0.4 126 | - multiprocess==0.70.15 127 | - nbclient==0.8.0 128 | - nbconvert==7.7.3 129 | - nbformat==5.9.2 130 | - nest-asyncio==1.5.7 131 | - networkx==3.1 132 | - nglview==3.0.6 133 | - notebook-shim==0.2.3 134 | - numpy==1.24.4 135 | - nvidia-ml-py3==7.352.0 136 | - oauthlib==3.2.2 137 | - omegaconf==2.3.0 138 | - openpyxl==3.1.2 139 | - overrides==7.4.0 140 | - p-tqdm==1.4.0 141 | - packaging==23.1 142 | - palettable==3.3.3 143 | - pandas==1.5.3 144 | - pandocfilters==1.5.0 145 | - parso==0.8.3 146 | - pathos==0.3.1 147 | - pexpect==4.8.0 148 | - pickleshare==0.7.5 149 | - pillow==9.5.0 150 | - 
pkgutil-resolve-name==1.3.10 151 | - platformdirs==3.10.0 152 | - plotly==5.16.0 153 | - pluggy==1.2.0 154 | - pox==0.3.3 155 | - ppft==1.7.6.7 156 | - prometheus-client==0.17.1 157 | - prompt-toolkit==3.0.39 158 | - protobuf==4.24.0 159 | - psutil==5.9.5 160 | - ptyprocess==0.7.0 161 | - pure-eval==0.2.2 162 | - pyarrow==12.0.1 163 | - pyasn1==0.5.0 164 | - pyasn1-modules==0.3.0 165 | - pybtex==0.24.0 166 | - pycodestyle==2.11.0 167 | - pycparser==2.21 168 | - pydantic==1.10.12 169 | - pydeck==0.8.0 170 | - pyg-lib==0.2.0+pt20cu118 171 | - pygments==2.16.1 172 | - pylint==2.17.5 173 | - pymatgen==2023.8.10 174 | - pymongo==4.4.1 175 | - pympler==1.0.1 176 | - pyparsing==3.0.9 177 | - pytest==7.4.0 178 | - python-dateutil==2.8.2 179 | - python-dotenv==1.0.0 180 | - python-json-logger==2.0.7 181 | - pytorch-lightning==1.8.0 182 | - pytz==2023.3 183 | - pytz-deprecation-shim==0.1.0.post0 184 | - pyyaml==6.0.1 185 | - pyzmq==25.1.1 186 | - referencing==0.30.2 187 | - requests==2.31.0 188 | - requests-oauthlib==1.3.1 189 | - rfc3339-validator==0.1.4 190 | - rfc3986-validator==0.1.1 191 | - rich==13.5.2 192 | - rpds-py==0.9.2 193 | - rsa==4.9 194 | - ruamel-yaml==0.17.32 195 | - ruamel-yaml-clib==0.2.7 196 | - scikit-learn==1.3.0 197 | - scipy==1.10.1 198 | - seaborn==0.12.2 199 | - send2trash==1.8.2 200 | - six==1.16.0 201 | - smact==2.5.2 202 | - smmap==5.0.0 203 | - sniffio==1.3.0 204 | - soupsieve==2.4.1 205 | - spglib==2.0.2 206 | - stack-data==0.6.2 207 | - streamlit==1.25.0 208 | - sympy==1.12 209 | - tabulate==0.9.0 210 | - tenacity==8.2.2 211 | - tensorboard==2.14.0 212 | - tensorboard-data-server==0.7.1 213 | - termcolor==2.3.0 214 | - terminado==0.17.1 215 | - threadpoolctl==3.2.0 216 | - tinycss2==1.2.1 217 | - toml==0.10.2 218 | - tomli==2.0.1 219 | - tomlkit==0.12.1 220 | - toolz==0.12.0 221 | - torch==2.0.1+cu118 222 | - torch-cluster==1.6.1+pt20cu118 223 | - torch-geometric==2.3.1 224 | - torch-scatter==2.1.1+pt20cu118 225 | - torch-sparse==0.6.17+pt20cu118 226 | - torch-spline-conv==1.2.2+pt20cu118 227 | - torchaudio==2.0.2+cu118 228 | - torchmetrics==0.11.4 229 | - torchvision==0.15.2+cu118 230 | - tornado==6.3.3 231 | - tqdm==4.66.1 232 | - traitlets==5.9.0 233 | - triton==2.0.0 234 | - typing-extensions==4.7.1 235 | - tzdata==2023.3 236 | - tzlocal==4.3.1 237 | - uncertainties==3.1.7 238 | - uri-template==1.3.0 239 | - urllib3==1.26.16 240 | - validators==0.21.2 241 | - watchdog==3.0.0 242 | - wcwidth==0.2.6 243 | - webcolors==1.13 244 | - webencodings==0.5.1 245 | - websocket-client==1.6.1 246 | - werkzeug==2.3.6 247 | - widgetsnbextension==4.0.8 248 | - wrapt==1.15.0 249 | - yarl==1.9.2 250 | - zipp==3.16.2 251 | prefix: /home/cyye/anaconda3/envs/CDVAEycy38 252 | -------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/conz_loss_file_ABC.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/output/hydra/singlerun/2024-01-25/test/conz_loss_file_ABC.xlsx -------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/conz_model_ABC_diffu.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/output/hydra/singlerun/2024-01-25/test/conz_model_ABC_diffu.pth 
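Note: the .xlsx loss files above open in any spreadsheet tool, while the .pth files are ordinary PyTorch checkpoints. A minimal inspection sketch (hedged: only model_test.pth is documented, by scripts/eval_utils.py further below, to keep its state dict under the 'model' key; the internal layout of conz_model_ABC_diffu.pth is an assumption to verify by printing):

import torch

# Load a checkpoint from this example run on a CPU-only machine.
ckpt = torch.load('output/hydra/singlerun/2024-01-25/test/model_test.pth',
                  map_location=torch.device('cpu'))

# eval_utils.load_model() reads the weights from ckpt['model'];
# listing the keys shows what else the file carries.
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))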
-------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/general_full.csv: -------------------------------------------------------------------------------- 1 | label,formula,bandgap,formation,e_above_hull,a,b,c,alpha,beta,gamma,density,coor_number,n_atom,spacegroup,crystal_system,CS_type,band_gap,formation_energy_per_atom 2 | CS0,MnCuF6,0,-2.231473114,0,4.99929318,5.10230773,5.36922358,60.71203721,62.24826115,60.68586764,3.881917742,3,8,2,0,0,0,-2.231473114 3 | CS1,Ho3TmMn8,0,0.024139117,0.036496121,5.13506274,5.345533,8.38938666,90,93.34363685,58.63463674,9.338432839,13.33333333,12,8,1,1,0,0.024139117 4 | CS2,LiMnO2,0.8595,-2.117313932,0.042556032,5.14729131,5.21506637,5.674995,90.00002107,89.99995961,89.99821457,4.093248674,6,16,58,2,2,0.8595,-2.117313932 5 | CS3,SmCo2B2,0,-0.517755139,0,5.49165172,5.49165172,5.49165172,142.0233682,142.0233682,54.79360107,7.72941372,6.4,5,139,3,3,0,-0.517755139 6 | CS4,SmThCN,0,-0.942487525,0.044108841,6.38429834,6.38429834,6.38429813,33.8496327,33.8496327,33.84962986,9.42594885,6,4,160,4,4,0,-0.942487525 7 | CS5,GaTe,0.898,-0.575092354,0,4.13459945,4.13459945,18.42557,90,90,119.9999912,4.804722519,3.5,8,194,5,5,0.898,-0.575092354 8 | CS6,ScAlRh2,0,-1.089973604,0,4.41289362,4.41289362,4.41289362,60,60,60,7.590062737,8,4,225,6,6,0,-1.089973604 9 | -------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/general_less.csv: -------------------------------------------------------------------------------- 1 | label,crystal_system,CS_type 2 | CS0,0,0 3 | CS1,1,1 4 | CS2,2,2 5 | CS3,3,3 6 | CS4,4,4 7 | CS5,5,5 8 | CS6,6,6 -------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/hparams.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root_path: ${oc.env:PROJECT_ROOT}/data/mptest 3 | prop: 4 | - formation_energy_per_atom 5 | - band_gap 6 | - FM_type 7 | - BG_type 8 | - CS_type 9 | use_prop: formation_energy_per_atom 10 | num_targets: 1 11 | niggli: true 12 | primitive: false 13 | graph_method: crystalnn 14 | lattice_scale_method: scale_length 15 | preprocess_workers: 1 16 | readout: mean 17 | max_atoms: 20 18 | otf_graph: false 19 | eval_model_name: mp20 20 | train_max_epochs: 30 21 | early_stopping_patience: 100000 22 | teacher_forcing_max_epoch: 15 23 | n_delta: 40 24 | datamodule: 25 | _target_: concdvae.pl_data.datamodule.CrystDataModule 26 | accelerator: ${accelerator} 27 | n_delta: ${data.n_delta} 28 | use_prop: ${data.use_prop} 29 | datasets: 30 | train: 31 | _target_: concdvae.pl_data.dataset.CrystDataset 32 | name: Formation energy train 33 | path: ${data.root_path}/train.csv 34 | prop: ${data.prop} 35 | use_prop: ${data.use_prop} 36 | niggli: ${data.niggli} 37 | primitive: ${data.primitive} 38 | graph_method: ${data.graph_method} 39 | lattice_scale_method: ${data.lattice_scale_method} 40 | preprocess_workers: ${data.preprocess_workers} 41 | val: 42 | - _target_: concdvae.pl_data.dataset.CrystDataset 43 | name: Formation energy val 44 | path: ${data.root_path}/val.csv 45 | prop: ${data.prop} 46 | use_prop: ${data.use_prop} 47 | niggli: ${data.niggli} 48 | primitive: ${data.primitive} 49 | graph_method: ${data.graph_method} 50 | lattice_scale_method: ${data.lattice_scale_method} 51 | preprocess_workers: ${data.preprocess_workers} 52 | test: 53 | - _target_: concdvae.pl_data.dataset.CrystDataset 54 | name: Formation 
energy test 55 | path: ${data.root_path}/test.csv 56 | prop: ${data.prop} 57 | use_prop: ${data.use_prop} 58 | niggli: ${data.niggli} 59 | primitive: ${data.primitive} 60 | graph_method: ${data.graph_method} 61 | lattice_scale_method: ${data.lattice_scale_method} 62 | preprocess_workers: ${data.preprocess_workers} 63 | num_workers: 64 | train: 0 65 | val: 0 66 | test: 0 67 | batch_size: 68 | train: 10 69 | val: 10 70 | test: 10 71 | logging: 72 | val_check_interval: 5 73 | progress_bar_refresh_rate: 20 74 | wandb: 75 | name: ${expname} 76 | project: crystal_generation_mit 77 | entity: null 78 | log_model: true 79 | mode: online 80 | group: ${expname} 81 | wandb_watch: 82 | log: all 83 | log_freq: 500 84 | lr_monitor: 85 | logging_interval: step 86 | log_momentum: false 87 | model: 88 | encoder: 89 | _target_: concdvae.pl_modules.gnn.DimeNetPlusPlusWrap 90 | num_targets: ${data.num_targets} 91 | hidden_channels: 128 92 | num_blocks: 4 93 | int_emb_size: 64 94 | basis_emb_size: 8 95 | out_emb_channels: 256 96 | num_spherical: 7 97 | num_radial: 6 98 | otf_graph: ${data.otf_graph} 99 | cutoff: 7.0 100 | max_num_neighbors: 20 101 | envelope_exponent: 5 102 | num_before_skip: 1 103 | num_after_skip: 2 104 | num_output_layers: 3 105 | readout: ${data.readout} 106 | decoder: 107 | _target_: concdvae.pl_modules.decoder.GemNetTDecoder 108 | hidden_dim: 128 109 | latent_dim: ${model.latent_dim} 110 | time_emb_dim: ${model.time_emb_dim} 111 | max_neighbors: ${model.max_neighbors} 112 | radius: ${model.radius} 113 | scale_file: ${oc.env:PROJECT_ROOT}/concdvae/pl_modules/gemnet/gemnet-dT.json 114 | conditionmodel: 115 | _target_: concdvae.pl_modules.ConditionModel.ConditioningModule 116 | n_features: 128 117 | n_layers: 2 118 | condition_embeddings: 119 | - _target_: concdvae.pl_modules.ConditionModel.ClassConditionEmbedding 120 | condition_name: CS_type 121 | n_type: 7 122 | n_emb: 64 123 | n_features: 64 124 | n_layers: 3 125 | conditionpre: 126 | condition_predict: 127 | - _target_: concdvae.pl_modules.PreCondition.ClassConditionPredict 128 | condition_name: CS_type 129 | n_type: 7 130 | latent_dim: ${model.latent_dim} 131 | hidden_dim: 4 132 | n_layers: 1 133 | drop: 0.2 134 | _target_: concdvae.pl_modules.model.CDVAE 135 | hidden_dim: 256 136 | latent_dim: 256 137 | time_emb_dim: 64 138 | fc_num_layers: 1 139 | max_atoms: ${data.max_atoms} 140 | cost_natom: 1.0 141 | cost_coord: 10.0 142 | cost_type: 1.0 143 | cost_lattice: 10.0 144 | cost_composition: 1.0 145 | cost_edge: 10.0 146 | cost_property: 3.0 147 | beta: 0.01 148 | teacher_forcing_lattice: true 149 | teacher_forcing_max_epoch: ${data.teacher_forcing_max_epoch} 150 | max_neighbors: 20 151 | radius: 9.0 152 | sigma_begin: 10.0 153 | sigma_end: 0.01 154 | type_sigma_begin: 5.0 155 | type_sigma_end: 0.01 156 | num_noise_level: 50 157 | predict_property: false 158 | n_delta: 40 159 | optim: 160 | optimizer: 161 | _target_: torch.optim.Adam 162 | lr: 0.001 163 | betas: 164 | - 0.9 165 | - 0.999 166 | eps: 1.0e-08 167 | weight_decay: 0 168 | use_lr_scheduler: true 169 | lr_scheduler: 170 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 171 | factor: 0.6 172 | patience: 30 173 | min_lr: 0.0001 174 | train: 175 | deterministic: true 176 | random_seed: 123 177 | PT_train: 178 | start_epochs: 0 179 | max_epochs: ${data.train_max_epochs} 180 | print_freq: 20 181 | clip_grad_norm: -1 182 | clip_grad_norm_epoch: 100 183 | expname: test 184 | core: 185 | version: 0.0.1 186 | tags: 187 | - ${now:%Y-%m-%d} 188 | accelerator: cpu 189 | 
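The hparams.yaml above is the fully resolved Hydra config that training wrote next to its checkpoints, and the evaluation scripts recompose this file rather than the original conf/ tree. A minimal sketch of that pattern, mirroring load_model() in scripts/eval_utils.py below (it assumes PROJECT_ROOT from .env is exported, because hparams.yaml interpolates ${oc.env:PROJECT_ROOT}):

import os
import hydra
from hydra.experimental import compose, initialize_config_dir

# initialize_config_dir expects an absolute path to the run directory.
run_dir = os.path.abspath('output/hydra/singlerun/2024-01-25/test')
initialize_config_dir(config_dir=run_dir)
cfg = compose(config_name='hparams')

# Rebuild the trained model from the resolved config, exactly as load_model() does.
model = hydra.utils.instantiate(
    cfg.model, optim=cfg.optim, data=cfg.data, logging=cfg.logging,
    _recursive_=False)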
-------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/lattice_scaler.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/output/hydra/singlerun/2024-01-25/test/lattice_scaler.pt -------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/loss_file.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/output/hydra/singlerun/2024-01-25/test/loss_file.xlsx -------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/model_test.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/output/hydra/singlerun/2024-01-25/test/model_test.pth -------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/model_test_notbest.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/output/hydra/singlerun/2024-01-25/test/model_test_notbest.pth -------------------------------------------------------------------------------- /output/hydra/singlerun/2024-01-25/test/run.log: -------------------------------------------------------------------------------- 1 | [2024-01-25 17:21:15,027][hydra.utils][INFO] - Instantiating 2 | [2024-01-25 17:23:53,057][hydra.utils][INFO] - Instantiating 3 | [2024-01-25 17:26:12,778][hydra.utils][INFO] - Instantiating 4 | [2024-01-25 17:26:41,343][hydra.utils][INFO] - Instantiating 5 | [2024-01-25 17:27:20,845][hydra.utils][INFO] - Start Train 6 | [2024-01-25 17:33:24,043][hydra.utils][INFO] - END 7 | [2024-01-25 17:42:06,885][hydra.utils][INFO] - Instantiating 8 | [2024-01-25 17:42:13,924][hydra.utils][INFO] - Instantiating 9 | [2024-01-25 17:42:46,193][hydra.utils][INFO] - Start Train 10 | [2024-01-25 17:48:20,748][hydra.utils][INFO] - END 11 | -------------------------------------------------------------------------------- /scripts/__pycache__/condition_diff_z.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/scripts/__pycache__/condition_diff_z.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/condition_z.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/scripts/__pycache__/condition_z.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/eval_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/scripts/__pycache__/eval_utils.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/test_conz.cpython-38-pytest-7.4.0.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyye001/Con-CDVAE/2d5fc98efbcefbc28fb8886f3b8102bff8b0e55d/scripts/__pycache__/test_conz.cpython-38-pytest-7.4.0.pyc -------------------------------------------------------------------------------- /scripts/eval_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import hydra 3 | from pathlib import Path 4 | import numpy as np 5 | import os 6 | import torch 7 | from torch_geometric.loader import DataLoader 8 | 9 | from omegaconf import DictConfig, OmegaConf 10 | from hydra.experimental import compose, initialize_config_dir 11 | 12 | import smact 13 | from smact.screening import pauling_test 14 | 15 | from concdvae.pl_data.datamodule import worker_init_fn 16 | from concdvae.common.utils import PROJECT_ROOT 17 | from concdvae.common.data_utils import chemical_symbols, GaussianDistance 18 | 19 | 20 | def load_model(model_path, model_file, load_data=False): 21 | initialize_config_dir(config_dir=model_path) 22 | cfg: DictConfig = compose(config_name="hparams") 23 | 24 | model = hydra.utils.instantiate( 25 | cfg.model, 26 | optim=cfg.optim, 27 | data=cfg.data, 28 | logging=cfg.logging, 29 | _recursive_=False, 30 | ) 31 | 32 | model_root = Path(model_path) / model_file 33 | checkpoint = torch.load(model_root, map_location=torch.device('cpu')) 34 | model_state_dict = checkpoint['model'] 35 | model.load_state_dict(model_state_dict) 36 | lattice_scaler = torch.load(Path(model_path) / 'lattice_scaler.pt') 37 | model.lattice_scaler = lattice_scaler 38 | 39 | if load_data : 40 | test_datasets = [hydra.utils.instantiate(dataset_cfg) 41 | for dataset_cfg in cfg.data.datamodule.datasets.test] 42 | for test_dataset in test_datasets: 43 | test_dataset.lattice_scaler = lattice_scaler 44 | 45 | test_dataloaders = [ 46 | DataLoader( 47 | test_datasets[i], 48 | shuffle=False, 49 | batch_size=cfg.data.datamodule.batch_size.test, 50 | num_workers=cfg.data.datamodule.num_workers.test, 51 | worker_init_fn=worker_init_fn, 52 | ) 53 | for i in range(len(test_datasets))] 54 | test_loader = test_dataloaders[0] 55 | else: 56 | test_loader = None 57 | 58 | return model, test_loader, cfg 59 | 60 | 61 | def load_data(file_path): 62 | if file_path[-3:] == 'npy': 63 | data = np.load(file_path, allow_pickle=True).item() 64 | for k, v in data.items(): 65 | if k == 'input_data_batch': 66 | for k1, v1 in data[k].items(): 67 | data[k][k1] = torch.from_numpy(v1) 68 | else: 69 | data[k] = torch.from_numpy(v).unsqueeze(0) 70 | else: 71 | data = torch.load(file_path) 72 | return data 73 | 74 | 75 | def load_config(model_path): 76 | with initialize_config_dir(str(model_path)): 77 | cfg = compose(config_name='hparams') 78 | return cfg 79 | 80 | 81 | def get_crystals_list( 82 | frac_coords, atom_types, lengths, angles, num_atoms): 83 | """ 84 | args: 85 | frac_coords: (num_atoms, 3) 86 | atom_types: (num_atoms) 87 | lengths: (num_crystals) 88 | angles: (num_crystals) 89 | num_atoms: (num_crystals) 90 | """ 91 | assert frac_coords.size(0) == atom_types.size(0) == num_atoms.sum() 92 | assert lengths.size(0) == angles.size(0) == num_atoms.size(0) 93 | 94 | start_idx = 0 95 | crystal_array_list = [] 96 | for batch_idx, num_atom in enumerate(num_atoms.tolist()): 97 | cur_frac_coords = frac_coords.narrow(0, start_idx, num_atom) 98 | cur_atom_types = atom_types.narrow(0, start_idx, num_atom) 99 | cur_lengths = lengths[batch_idx] 100 | cur_angles = angles[batch_idx] 
101 | 102 | crystal_array_list.append({ 103 | 'frac_coords': cur_frac_coords.detach().cpu().numpy(), 104 | 'atom_types': cur_atom_types.detach().cpu().numpy(), 105 | 'lengths': cur_lengths.detach().cpu().numpy(), 106 | 'angles': cur_angles.detach().cpu().numpy(), 107 | }) 108 | start_idx = start_idx + num_atom 109 | return crystal_array_list 110 | 111 | 112 | def smact_validity(comp, count, 113 | use_pauling_test=True, 114 | include_alloys=True): 115 | elem_symbols = tuple([chemical_symbols[elem] for elem in comp]) 116 | space = smact.element_dictionary(elem_symbols) 117 | smact_elems = [e[1] for e in space.items()] 118 | electronegs = [e.pauling_eneg for e in smact_elems] 119 | ox_combos = [e.oxidation_states for e in smact_elems] 120 | if len(set(elem_symbols)) == 1: 121 | return True 122 | if include_alloys: 123 | is_metal_list = [elem_s in smact.metals for elem_s in elem_symbols] 124 | if all(is_metal_list): 125 | return True 126 | 127 | threshold = np.max(count) 128 | compositions = [] 129 | for ox_states in itertools.product(*ox_combos): 130 | stoichs = [(c,) for c in count] 131 | # Test for charge balance 132 | cn_e, cn_r = smact.neutral_ratios( 133 | ox_states, stoichs=stoichs, threshold=threshold) 134 | # Electronegativity test 135 | if cn_e: 136 | if use_pauling_test: 137 | try: 138 | electroneg_OK = pauling_test(ox_states, electronegs) 139 | except TypeError: 140 | # if no electronegativity data, assume it is okay 141 | electroneg_OK = True 142 | else: 143 | electroneg_OK = True 144 | if electroneg_OK: 145 | for ratio in cn_r: 146 | compositions.append( 147 | tuple([elem_symbols, ox_states, ratio])) 148 | compositions = [(i[0], i[2]) for i in compositions] 149 | compositions = list(set(compositions)) 150 | if len(compositions) > 0: 151 | return True 152 | else: 153 | return False 154 | 155 | 156 | def structure_validity(crystal, cutoff=0.5): 157 | dist_mat = crystal.distance_matrix 158 | # Pad diagonal with a large number 159 | dist_mat = dist_mat + np.diag( 160 | np.ones(dist_mat.shape[0]) * (cutoff + 10.)) 161 | if dist_mat.min() < cutoff or crystal.volume < 0.1: 162 | return False 163 | else: 164 | return True -------------------------------------------------------------------------------- /scripts/extra_z.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import torch 4 | import hydra 5 | import random 6 | import numpy as np 7 | import sys 8 | import os 9 | script_dir = os.path.dirname(os.path.abspath(__file__)) 10 | parent_dir = os.path.abspath(os.path.join(script_dir, "..")) 11 | sys.path.append(parent_dir) 12 | 13 | from tqdm import tqdm 14 | from torch.optim import Adam 15 | from pathlib import Path 16 | from types import SimpleNamespace 17 | from torch_geometric.data import Batch 18 | 19 | 20 | from eval_utils import load_model 21 | from concdvae.common.data_utils import GaussianDistance 22 | 23 | 24 | def main(args): 25 | print('start with data:', args.data) 26 | model_path = Path(args.model_path) 27 | # load_data=True so that load_model also builds the test dataloader.
28 | model, test_loader, cfg = load_model(args.model_path, args.model_file, 29 | load_data=True) 30 | if torch.cuda.is_available(): 31 | cfg.data.datamodule.accelerator = 'gpu' 32 | else: 33 | cfg.data.datamodule.accelerator = 'cpu' 34 | 35 | if args.data != 'test': 36 | datamodule = hydra.utils.instantiate( 37 | cfg.data.datamodule, _recursive_=False 38 | ) 39 | if args.data == 'train': 40 | test_loader = datamodule.train_dataloader() # assumes CrystDataModule's Lightning-style dataloader method 41 | elif args.data == 'val': 42 | test_loader = datamodule.val_dataloader()[0] # val_dataloader() is assumed to return a list of loaders 43 | else: 44 | print('warning, unknown data option:', args.data) 45 | 46 | 47 | if torch.cuda.is_available(): 48 | model.to('cuda') 49 | model.device = 'cuda' 50 | 51 | condition_names = [] 52 | condition_list = [] 53 | id_list = [] 54 | mu_list = [] 55 | for con_emb in cfg.model.conditionmodel.condition_embeddings: 56 | condition_names.append(con_emb['condition_name']) 57 | condition_list.append([]) 58 | 59 | for idx, batch in enumerate(test_loader): 60 | print(idx, 'in', len(test_loader), file=sys.stdout) 61 | sys.stdout.flush() 62 | if torch.cuda.is_available(): 63 | batch = batch.cuda() 64 | 65 | mu, log_var, z = model.encode(batch) 66 | 67 | for i in range(len(condition_names)): 68 | condition = batch[condition_names[i]] 69 | condition = condition.cpu().detach().numpy() 70 | condition = list(condition) 71 | condition_list[i].extend(condition) 72 | 73 | id_list.extend(batch['mp_id']) 74 | mu = mu.cpu().detach().numpy() 75 | for i in range(mu.shape[0]): 76 | mu_cry = mu[i,:].tolist() 77 | mu_list.append(mu_cry) 78 | 79 | condition_dict = {k:v for k, v in zip(condition_names,condition_list)} 80 | output = {'material_id': id_list, 'material_z': mu_list} 81 | output.update(condition_dict) 82 | 83 | outputfile = model_path / args.output_file 84 | torch.save(output, outputfile) 85 | print('end') 86 | 87 | 88 | if __name__ == '__main__': 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument('--model_path', required=True) 91 | parser.add_argument('--model_file', default='model_perov.pth', type=str) 92 | parser.add_argument('--output_file', default='material_z.pt', type=str) 93 | parser.add_argument('--data', default='test', type=str) 94 | 95 | args = parser.parse_args() 96 | 97 | main(args) -------------------------------------------------------------------------------- /scripts/pt2cif.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pymatgen.core.structure import Structure 3 | from pymatgen.core.lattice import Lattice 4 | from pymatgen.core.composition import Composition 5 | import os 6 | 7 | dataroot = 'YOUR_PATH_TO_.PT' 8 | datafile = 'eval_gen_abc.pt' 9 | 10 | datafile_read = os.path.join(dataroot, datafile) 11 | data = torch.load(datafile_read, map_location=torch.device('cpu')) 12 | cif_path = os.path.join(dataroot, 'ciffile/') 13 | if not os.path.exists(cif_path): 14 | os.makedirs(cif_path) 15 | lengths = data['lengths'] 16 | angles = data['angles'] 17 | num_atoms = data['num_atoms'] 18 | frac_coords = data['frac_coords'] 19 | atom_types = data['atom_types'] 20 | 21 | lengths_list = lengths.numpy().tolist() 22 | angles_list = angles.numpy().tolist() 23 | num_atoms_list = num_atoms.tolist() 24 | frac_coords_list = frac_coords.numpy().tolist() 25 | atom_types_list = atom_types.tolist() 26 | 27 | num_material = 0 28 | for i in range(len(num_atoms_list)): # loop over batches of generated structures
29 | now_atom = 0 30 | for a in range(len(num_atoms_list[i])): # loop over the materials in batch i 31 | cif_mat_path = os.path.join(cif_path, str(num_material)) 32 | length = lengths_list[i][a] 33 | angle = angles_list[i][a] 34 | atom_num = num_atoms_list[i][a] 35 | 36 | atom_type = atom_types_list[i][now_atom: now_atom + atom_num] 37 | frac_coord = frac_coords_list[i][now_atom: now_atom + atom_num][:] 38 | lattice = Lattice.from_parameters(a=length[0], b=length[1], c=length[2], alpha=angle[0], 39 | beta=angle[1], gamma=angle[2]) 40 | 41 | structure = Structure(lattice, atom_type, frac_coord, to_unit_cell=True) 42 | filename = datafile[:-3] + '__' + str(num_material) + '.cif' 43 | file_path = os.path.join(cif_path, filename) 44 | structure.to(filename=file_path) 45 | now_atom += atom_num 46 | num_material += 1 47 | 48 | print('end') 49 | 50 | --------------------------------------------------------------------------------
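As a quick sanity check of the two validity helpers in scripts/eval_utils.py, here is a hedged usage sketch. Run it from inside scripts/ with the project root on sys.path (the same arrangement extra_z.py sets up); the toy cubic NaCl cell is illustrative only, not data from the repo:

import os
import sys
sys.path.append(os.path.abspath('..'))  # so the concdvae imports inside eval_utils resolve

from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from eval_utils import smact_validity, structure_validity

# Composition check: atomic numbers and per-element counts for NaCl (Na=11, Cl=17).
print(smact_validity(comp=(11, 17), count=(1, 1)))  # True: a charge-balanced oxidation-state combination exists

# Geometry check: a toy cubic NaCl cell whose shortest interatomic distance
# is far above the 0.5 A cutoff and whose volume is non-degenerate.
lattice = Lattice.cubic(5.64)
nacl = Structure(lattice, ['Na', 'Cl'], [[0, 0, 0], [0.5, 0.5, 0.5]])
print(structure_validity(nacl))  # True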