├── .gitignore ├── MEGNet ├── MEGNet_Trainer.py ├── __init__.py ├── model │ ├── Layers.py │ ├── MEGNet.py │ └── __init__.py └── utils │ ├── Struct2Graph.py │ ├── __init__.py │ └── utils.py ├── README.md ├── cuda_requirements.txt ├── experiment_scripts ├── MEGNet_Original_Trainer.py ├── experiments_configs │ ├── bulk_moduli.yaml │ ├── mp_band_gap.yaml │ ├── mp_formation_energy.yaml │ ├── tf_bulk_moduli.yaml │ ├── tf_mp_band_gap.yaml │ └── tf_mp_formation_energy.yaml └── model_configs │ ├── tf_bg.yaml │ ├── tf_bulk.yaml │ ├── tf_ef.yaml │ ├── torch_bg.yaml │ ├── torch_bulk.yaml │ └── torch_ef.yaml ├── experiments_data └── bulk_moduli.json ├── requirements.txt ├── run_experiment.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | experiments_data/mp.2018.6.1.json 2 | MEGNet_test.egg-info/ 3 | .idea/ 4 | **/__pycache__/ 5 | -------------------------------------------------------------------------------- /MEGNet/MEGNet_Trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | from tqdm import tqdm 6 | from torch_geometric.loader import DataLoader 7 | 8 | from .utils.Struct2Graph import ( 9 | SimpleCrystalConverter, 10 | FlattenGaussianDistanceConverter, 11 | GaussianDistanceConverter, 12 | AtomFeaturesExtractor, 13 | ) 14 | from .model.MEGNet import MEGNet 15 | from .utils.utils import Scaler 16 | 17 | 18 | class MEGNetTrainer: 19 | def __init__(self, trainset, testset, config): 20 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | self.config = config 22 | self.target_name = config['model']['target_name'] 23 | 24 | if self.config["data"]["add_z_bond_coord"]: 25 | bond_converter = FlattenGaussianDistanceConverter( 26 | centers=np.linspace(0, self.config['data']['cutoff'], self.config['model']['edge_embed_size']) 27 | ) 28 | else: 29 | bond_converter = GaussianDistanceConverter( 30 | centers=np.linspace(0, self.config['data']['cutoff'], self.config['model']['edge_embed_size']) 31 | ) 32 | atom_converter = AtomFeaturesExtractor(self.config["data"]["atom_features"]) 33 | 34 | self.model = MEGNet( 35 | edge_input_shape=bond_converter.get_shape(), 36 | node_input_shape=atom_converter.get_shape(), 37 | state_input_shape=self.config["model"]["state_input_shape"] 38 | ).to(self.device) 39 | self.Scaler = Scaler() 40 | 41 | self.converter = SimpleCrystalConverter( 42 | target_name=self.config['model']['target_name'], 43 | bond_converter=bond_converter, 44 | atom_converter=atom_converter, 45 | cutoff=self.config["data"]["cutoff"], 46 | add_z_bond_coord=self.config["data"]["add_z_bond_coord"] 47 | ) 48 | print("converting data") 49 | self.train_structures = [self.converter.convert(s) for s in tqdm(trainset)] 50 | self.test_structures = [self.converter.convert(s) for s in tqdm(testset)] 51 | self.Scaler.fit(self.train_structures) 52 | 53 | self.trainloader = DataLoader( 54 | self.train_structures, 55 | batch_size=self.config["model"]["train_batch_size"], 56 | shuffle=True, 57 | ) 58 | self.testloader = DataLoader( 59 | self.test_structures, 60 | batch_size=self.config["model"]["test_batch_size"], 61 | shuffle=False, 62 | ) 63 | 64 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['optim']['lr_initial']) 65 | if self.config["optim"]["scheduler"].lower() == "ReduceLROnPlateau".lower(): 66 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 67 | self.optimizer, 68 | factor=self.config["optim"]["factor"], 69 | patience=self.config["optim"]["patience"], 70 | threshold=self.config["optim"]["threshold"], 71 | min_lr=self.config["optim"]["min_lr"], 72 | verbose=True, 73 | ) 74 | else : 75 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR( 76 | self.optimizer, 77 | gamma=0.99 78 | ) 79 | 80 | def one_epoch(self): 81 | total = [] 82 | self.model.train(True) 83 | for batch in tqdm(self.trainloader): 84 | batch = batch.to(self.device) 85 | y = self.Scaler.transform(batch.y) 86 | 87 | preds = self.model( 88 | batch.x, batch.edge_index, batch.edge_attr, batch.state, batch.batch, batch.bond_batch 89 | ).squeeze() 90 | 91 | loss = F.mse_loss(y, preds) 92 | loss.backward() 93 | self.optimizer.step() 94 | self.optimizer.zero_grad() 95 | 96 | total.append( 97 | F.l1_loss(self.Scaler.inverse_transform(preds), batch.y, reduction='sum').to('cpu').data.numpy() 98 | ) 99 | return sum(total) / len(self.train_structures) 100 | 101 | def validation(self): 102 | total = [] 103 | self.model.train(False) 104 | with torch.no_grad(): 105 | for batch in self.testloader: 106 | batch = batch.to(self.device) 107 | y = batch.y 108 | 109 | preds = self.model( 110 | batch.x, batch.edge_index, batch.edge_attr, batch.state, batch.batch, batch.bond_batch 111 | ).squeeze() 112 | 113 | total.append(F.l1_loss(self.Scaler.inverse_transform(preds), y, reduction='sum').to('cpu').data.numpy()) 114 | 115 | return sum(total) / len(self.test_structures) 116 | 117 | def train(self): 118 | for epoch in range(self.config['model']['epochs']): 119 | print(f"===={epoch} out of {self.config['model']['epochs'] - 1} epochs====") 120 | print(f'target: {self.target_name} device: {self.device}') 121 | 122 | train_loss = self.one_epoch() 123 | validation_loss = self.validation() 124 | 125 | self.scheduler.step(train_loss) 126 | 127 | print(f"train loss: {train_loss}, test loss: {validation_loss}") 128 | -------------------------------------------------------------------------------- /MEGNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RomanovIgnat/MEGNet_PyTorch/75001472680115434c8c209d52f1e4f30cd2af69/MEGNet/__init__.py -------------------------------------------------------------------------------- /MEGNet/model/Layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import MessagePassing, global_mean_pool 4 | 5 | 6 | class ShiftedSoftplus(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.sp = nn.Softplus() 10 | self.shift = nn.Parameter(torch.log(torch.tensor([2.])), requires_grad=False) 11 | 12 | def forward(self, x): 13 | return self.sp(x) - self.shift 14 | 15 | 16 | class MegnetModule(MessagePassing): 17 | def __init__(self, 18 | edge_input_shape, 19 | node_input_shape, 20 | state_input_shape, 21 | inner_skip=False, 22 | embed_size=32, 23 | ): 24 | """ 25 | Parameters 26 | ---------- 27 | edge_input_shape: size of edge features' 28 | node_input_shape: size of node features' 29 | state_input_shape: size of global state features' 30 | inner_skip: use inner or outer skip connection 31 | embed_size: embedding and output size 32 | """ 33 | super().__init__(aggr="mean") 34 | self.inner_skip = inner_skip 35 | self.phi_e = nn.Sequential( 36 | nn.Linear(4 * embed_size, 2 * embed_size), 37 | ShiftedSoftplus(), 38 | nn.Linear(2 * embed_size, 2 * embed_size), 39 | ShiftedSoftplus(), 40 | nn.Linear(2 * embed_size, embed_size), 41 | ShiftedSoftplus(), 42 | ) 43 | 44 | self.phi_u = nn.Sequential( 45 | nn.Linear(3 * embed_size, 2 * embed_size), 46 | ShiftedSoftplus(), 47 | nn.Linear(2 * embed_size, 2 * embed_size), 48 | ShiftedSoftplus(), 49 | nn.Linear(2 * embed_size, embed_size), 50 | ShiftedSoftplus(), 51 | ) 52 | 53 | self.phi_v = nn.Sequential( 54 | nn.Linear(3 * embed_size, 2 * embed_size), 55 | ShiftedSoftplus(), 56 | nn.Linear(2 * embed_size, 2 * embed_size), 57 | ShiftedSoftplus(), 58 | nn.Linear(2 * embed_size, embed_size), 59 | ShiftedSoftplus(), 60 | ) 61 | 62 | self.preprocess_e = nn.Sequential( 63 | nn.Linear(edge_input_shape, 2 * embed_size), 64 | ShiftedSoftplus(), 65 | nn.Linear(2 * embed_size, embed_size), 66 | ShiftedSoftplus(), 67 | ) 68 | 69 | self.preprocess_v = nn.Sequential( 70 | nn.Linear(node_input_shape, 2 * embed_size), 71 | ShiftedSoftplus(), 72 | nn.Linear(2 * embed_size, embed_size), 73 | ShiftedSoftplus(), 74 | ) 75 | 76 | self.preprocess_u = nn.Sequential( 77 | nn.Linear(state_input_shape, 2 * embed_size), 78 | ShiftedSoftplus(), 79 | nn.Linear(2 * embed_size, embed_size), 80 | ShiftedSoftplus(), 81 | ) 82 | 83 | def forward(self, x, edge_index, edge_attr, state, batch, bond_batch): 84 | if not self.inner_skip: 85 | x_skip = x 86 | edge_attr_skip = edge_attr 87 | state_skip = state 88 | 89 | x = self.preprocess_v(x) 90 | edge_attr = self.preprocess_e(edge_attr) 91 | state = self.preprocess_u(state) 92 | else: 93 | x = self.preprocess_v(x) 94 | edge_attr = self.preprocess_e(edge_attr) 95 | state = self.preprocess_u(state) 96 | 97 | x_skip = x 98 | edge_attr_skip = edge_attr 99 | state_skip = state 100 | 101 | if torch.numel(bond_batch) > 0: 102 | edge_attr = self.edge_updater( 103 | edge_index=edge_index, x=x, edge_attr=edge_attr, state=state, bond_batch=bond_batch 104 | ) 105 | x = self.propagate( 106 | edge_index=edge_index, x=x, edge_attr=edge_attr, state=state, batch=batch 107 | ) 108 | u_v = global_mean_pool(x, batch) 109 | u_e = global_mean_pool(edge_attr, bond_batch, batch.max().item() + 1) 110 | state = self.phi_u(torch.cat((u_e, u_v, state), 1)) 111 | return x + x_skip, edge_attr + edge_attr_skip, state + state_skip 112 | 113 | def message(self, x_i, x_j, edge_attr): 114 | return edge_attr 115 | 116 | def update(self, inputs, x, state, batch): 117 | return self.phi_v(torch.cat((inputs, x, state[batch, :]), 1)) 118 | 119 | def edge_update(self, x_i, x_j, edge_attr, state, bond_batch): 120 | return self.phi_e(torch.cat((x_i, x_j, edge_attr, state[bond_batch, :]), 1)) 121 | -------------------------------------------------------------------------------- /MEGNet/model/MEGNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Layers import MegnetModule, ShiftedSoftplus 5 | from torch_geometric.nn import Set2Set 6 | 7 | 8 | ATOMIC_NUMBERS = 95 9 | 10 | 11 | class MEGNet(nn.Module): 12 | def __init__(self, 13 | edge_input_shape, 14 | node_input_shape, 15 | state_input_shape, 16 | node_embedding_size=16, 17 | embedding_size=32, 18 | n_blocks=3, 19 | ): 20 | """ 21 | Parameters 22 | ---------- 23 | edge_input_shape: size of edge features' 24 | node_input_shape: size of node features' 25 | state_input_shape: size of global state features' 26 | node_embedding_size: if using embedding layer the size of result embeddings 27 | embedding_size: size of inner embeddings 28 | n_blocks: amount of MEGNet blocks 29 | """ 30 | super().__init__() 31 | self.embedded = node_input_shape is None 32 | if self.embedded: 33 | node_input_shape = node_embedding_size 34 | self.emb = nn.Embedding(ATOMIC_NUMBERS, node_embedding_size) 35 | 36 | self.m1 = MegnetModule( 37 | edge_input_shape, node_input_shape, state_input_shape, inner_skip=True, embed_size=embedding_size 38 | ) 39 | self.blocks = nn.ModuleList() 40 | for i in range(n_blocks - 1): 41 | self.blocks.append(MegnetModule(embedding_size, embedding_size, embedding_size)) 42 | 43 | self.se = Set2Set(embedding_size, 1) 44 | self.sv = Set2Set(embedding_size, 1) 45 | self.hiddens = nn.Sequential( 46 | nn.Linear(5 * embedding_size, embedding_size), 47 | ShiftedSoftplus(), 48 | nn.Linear(embedding_size, embedding_size // 2), 49 | ShiftedSoftplus(), 50 | nn.Linear(embedding_size // 2, 1) 51 | ) 52 | 53 | def forward(self, x, edge_index, edge_attr, state, batch, bond_batch): 54 | if self.embedded: 55 | x = self.emb(x).squeeze() 56 | else: 57 | x = x.float() 58 | 59 | x, edge_attr, state = self.m1(x, edge_index, edge_attr, state, batch, bond_batch) 60 | for block in self.blocks: 61 | x, edge_attr, state = block(x, edge_index, edge_attr, state, batch, bond_batch) 62 | x = self.sv(x, batch) 63 | edge_attr = self.se(edge_attr, bond_batch) 64 | 65 | tmp_shape = x.shape[0] - edge_attr.shape[0] 66 | edge_attr = F.pad(edge_attr, (0, 0, 0, tmp_shape), value=0.0) 67 | 68 | tmp = torch.cat((x, edge_attr, state), 1) 69 | out = self.hiddens(tmp) 70 | return out 71 | -------------------------------------------------------------------------------- /MEGNet/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RomanovIgnat/MEGNet_PyTorch/75001472680115434c8c209d52f1e4f30cd2af69/MEGNet/model/__init__.py -------------------------------------------------------------------------------- /MEGNet/utils/Struct2Graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch_geometric.data import Data 4 | from pymatgen.core import Structure 5 | from pymatgen.core.periodic_table import DummySpecies 6 | from pymatgen.optimization.neighbors import find_points_in_spheres 7 | 8 | 9 | class MyTensor(torch.Tensor): 10 | """ 11 | this class is needed to work with graphs without edges 12 | """ 13 | def max(self, *args, **kwargs): 14 | if torch.numel(self) == 0: 15 | return 0 16 | else: 17 | return torch.max(self, *args, **kwargs) 18 | 19 | 20 | class SimpleCrystalConverter: 21 | def __init__( 22 | self, 23 | target_name, 24 | atom_converter=None, 25 | bond_converter=None, 26 | add_z_bond_coord=False, 27 | cutoff=5.0 28 | ): 29 | """ 30 | Parameters 31 | ---------- 32 | atom_converter: converter that converts pymatgen structure to node features 33 | bond_converter: converter that converts distances to edge features 34 | add_z_bond_coord: use z-coordinate feature or no 35 | cutoff: cutoff radius 36 | """ 37 | self.target_name = target_name 38 | self.cutoff = cutoff 39 | self.atom_converter = atom_converter if atom_converter else DummyConverter() 40 | self.bond_converter = bond_converter if bond_converter else DummyConverter() 41 | self.add_z_bond_coord = add_z_bond_coord 42 | 43 | def convert(self, d): 44 | lattice_matrix = np.ascontiguousarray(np.array(d.lattice.matrix), dtype=float) 45 | pbc = np.array([1, 1, 1], dtype=int) 46 | cart_coords = np.ascontiguousarray(np.array(d.cart_coords), dtype=float) 47 | 48 | center_indices, neighbor_indices, _, distances = find_points_in_spheres( 49 | cart_coords, cart_coords, r=self.cutoff, pbc=pbc, lattice=lattice_matrix, tol=1e-8 50 | ) 51 | 52 | exclude_self = (center_indices != neighbor_indices) 53 | 54 | edge_index = torch.Tensor(np.stack((center_indices[exclude_self], neighbor_indices[exclude_self]))).long() 55 | 56 | x = torch.Tensor(self.atom_converter.convert(d)).long() 57 | 58 | distances_preprocessed = distances[exclude_self] 59 | if self.add_z_bond_coord: 60 | z_coord_diff = np.abs(cart_coords[edge_index[0], 2] - cart_coords[edge_index[1], 2]) 61 | distances_preprocessed = np.stack( 62 | (distances_preprocessed, z_coord_diff), axis=0 63 | ) 64 | 65 | edge_attr = torch.Tensor(self.bond_converter.convert(distances_preprocessed)) 66 | state = getattr(d, "state", None) or [[0.0, 0.0]] 67 | y = getattr(d, self.target_name) if hasattr(d, self.target_name) else 0 68 | bond_batch = MyTensor(np.zeros(edge_index.shape[1])).long() 69 | 70 | return Data( 71 | x=x, edge_index=edge_index, edge_attr=edge_attr, state=torch.Tensor(state), y=y, bond_batch=bond_batch 72 | ) 73 | 74 | def __call__(self, d): 75 | return self.convert(d) 76 | 77 | 78 | class DummyConverter: 79 | def convert(self, d): 80 | return d.reshape((-1, 1)) 81 | 82 | 83 | class GaussianDistanceConverter: 84 | def __init__(self, centers=np.linspace(0, 5, 100), sigma=0.5): 85 | self.centers = centers 86 | self.sigma = sigma 87 | 88 | def convert(self, d): 89 | return np.exp( 90 | -((d.reshape((-1, 1)) - self.centers.reshape((1, -1))) / self.sigma) ** 2 91 | ) 92 | 93 | def get_shape(self): 94 | return len(self.centers) 95 | 96 | 97 | class FlattenGaussianDistanceConverter(GaussianDistanceConverter): 98 | def __init__(self, centers=np.linspace(0, 5, 100), sigma=0.5): 99 | super().__init__(centers, sigma) 100 | 101 | def convert(self, d): 102 | res = [] 103 | for arr in d: 104 | res.append(super().convert(arr)) 105 | return np.hstack(res) 106 | 107 | def get_shape(self): 108 | return 2 * len(self.centers) 109 | 110 | 111 | class AtomFeaturesExtractor: 112 | def __init__(self, atom_features): 113 | self.atom_features = atom_features 114 | 115 | def convert(self, structure: Structure): 116 | if self.atom_features == "Z": 117 | return np.array( 118 | [0 if isinstance(i, DummySpecies) else i.Z for i in structure.species] 119 | ).reshape(-1, 1) 120 | elif self.atom_features == 'werespecies': 121 | return np.array([ 122 | [ 123 | 0 if isinstance(i, DummySpecies) else i.Z, 124 | i.properties["was"], 125 | ] for i in structure.species 126 | ]) 127 | else: 128 | return np.array( 129 | [0 if isinstance(i, DummySpecies) else i.Z for i in structure.species] 130 | ).reshape(-1, 1) 131 | 132 | def get_shape(self): 133 | if self.atom_features == "Z": 134 | return 1 135 | elif self.atom_features == 'werespecies': 136 | return 2 137 | else: 138 | return None 139 | -------------------------------------------------------------------------------- /MEGNet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RomanovIgnat/MEGNet_PyTorch/75001472680115434c8c209d52f1e4f30cd2af69/MEGNet/utils/__init__.py -------------------------------------------------------------------------------- /MEGNet/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import copy 3 | import torch 4 | import random 5 | 6 | from pymatgen.io.cif import CifParser 7 | 8 | 9 | class Scaler: 10 | def __init__(self): 11 | self.mean = 0 12 | self.std = 1.0 13 | 14 | def fit(self, dataset, feature_name='y'): 15 | data = np.array([getattr(dataset[i], feature_name) for i in range(len(dataset))]) 16 | self.mean = np.mean(data) 17 | self.std = np.std(data) 18 | 19 | def transform(self, data): 20 | data_copy = copy(data) 21 | return (data_copy - self.mean) / (self.std if abs(self.std) > 1e-7 else 1.) 22 | 23 | def inverse_transform(self, data): 24 | data_copy = copy(data) 25 | std = self.std if abs(self.std) > 1e-7 else 1.0 26 | return data_copy * std + self.mean 27 | 28 | 29 | class String2StructConverter: 30 | def __init__(self, struct_target_names): 31 | self.target_names = struct_target_names 32 | 33 | def convert(self, elem): 34 | struct = CifParser.from_string(elem['structure']).get_structures()[0] 35 | for name in self.target_names: 36 | setattr(struct, name, elem[name]) 37 | return struct 38 | 39 | 40 | def set_random_seed(seed): 41 | torch.backends.cudnn.deterministic = True 42 | torch.manual_seed(seed) 43 | torch.cuda.manual_seed(seed) 44 | np.random.seed(seed) 45 | random.seed(seed) 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MEGNet_PyTorch -------------------------------------------------------------------------------- /cuda_requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | astunparse==1.6.3 3 | cachetools==5.0.0 4 | certifi==2021.10.8 5 | charset-normalizer==2.0.9 6 | click==8.0.3 7 | cycler==0.11.0 8 | flatbuffers==2.0 9 | fonttools==4.28.5 10 | future==0.18.2 11 | gast==0.5.3 12 | google-auth==2.6.2 13 | google-auth-oauthlib==0.4.6 14 | google-pasta==0.2.0 15 | googledrivedownloader==0.4 16 | grpcio==1.44.0 17 | h5py==3.6.0 18 | idna==3.3 19 | importlib-metadata==4.11.3 20 | isodate==0.6.1 21 | Jinja2==3.0.3 22 | joblib==1.1.0 23 | keras==2.8.0 24 | Keras-Preprocessing==1.1.2 25 | kiwisolver==1.3.2 26 | libclang==13.0.0 27 | Markdown==3.3.6 28 | MarkupSafe==2.0.1 29 | matplotlib==3.5.1 30 | megnet==1.3.0 31 | monty==2021.12.1 32 | mpmath==1.2.1 33 | networkx==2.6.3 34 | numpy==1.22.0 35 | oauthlib==3.2.0 36 | opt-einsum==3.3.0 37 | packaging==21.3 38 | palettable==3.3.0 39 | pandas==1.3.5 40 | Pillow==8.4.0 41 | plotly==5.5.0 42 | protobuf==3.19.4 43 | pyasn1==0.4.8 44 | pyasn1-modules==0.2.8 45 | pymatgen==2022.0.17 46 | pyparsing==3.0.6 47 | python-dateutil==2.8.2 48 | pytz==2021.3 49 | PyYAML==6.0 50 | rdflib==6.1.1 51 | requests==2.26.0 52 | requests-oauthlib==1.3.1 53 | rsa==4.8 54 | ruamel.yaml==0.17.20 55 | ruamel.yaml.clib==0.2.6 56 | scikit-learn==1.0.2 57 | scipy==1.7.3 58 | six==1.16.0 59 | spglib==1.16.3 60 | sympy==1.9 61 | tabulate==0.8.9 62 | tenacity==8.0.1 63 | tensorboard==2.8.0 64 | tensorboard-data-server==0.6.1 65 | tensorboard-plugin-wit==1.8.1 66 | tensorflow==2.8.0 67 | tensorflow-io-gcs-filesystem==0.24.0 68 | termcolor==1.1.0 69 | tf-estimator-nightly==2.8.0.dev2021122109 70 | threadpoolctl==3.0.0 71 | torch==1.10.0+cu113 72 | torch-geometric==2.0.4 73 | torch-scatter==2.0.9 74 | torch-sparse==0.6.13 75 | torchaudio==0.10.0+cu113 76 | torchvision==0.11.1+cu113 77 | tqdm==4.62.3 78 | typing_extensions==4.0.1 79 | uncertainties==3.1.6 80 | urllib3==1.26.7 81 | Werkzeug==2.1.0 82 | wrapt==1.14.0 83 | yacs==0.1.8 84 | zipp==3.7.0 85 | -------------------------------------------------------------------------------- /experiment_scripts/MEGNet_Original_Trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from megnet.models import MEGNetModel 4 | from megnet.data.graph import GaussianDistance 5 | from megnet.data.crystal import CrystalGraph 6 | from megnet.utils.preprocessing import StandardScaler 7 | 8 | 9 | class MEGNetOriginalTrainer: 10 | def __init__(self, trainset, testset, config): 11 | self.config = config 12 | 13 | self.train_structures = trainset 14 | self.test_structures = testset 15 | self.train_targets = [getattr(s, self.config['model']['target_name']) for s in trainset] 16 | self.test_targets = [getattr(s, self.config['model']['target_name']) for s in testset] 17 | 18 | self.cg = CrystalGraph( 19 | bond_converter=GaussianDistance( 20 | np.linspace(0, self.config['data']['cutoff'], self.config['model']['nfeat_edge']) 21 | ), 22 | cutoff=self.config['data']['cutoff'], 23 | ) 24 | 25 | self.Scaler = StandardScaler.from_training_data(trainset, self.train_targets) 26 | 27 | self.model = MEGNetModel( 28 | nfeat_edge=self.config['model']['nfeat_edge'], 29 | nfeat_global=self.config['model']['nfeat_global'], 30 | nfeat_node=self.config['model']['nfeat_node'], 31 | lr=self.config['model']['learning_rate'], 32 | graph_converter=self.cg, 33 | target_scaler=self.Scaler, 34 | metrics=['mae'] 35 | ) 36 | 37 | def train(self): 38 | self.model.train( 39 | self.train_structures, 40 | self.train_targets, 41 | self.test_structures, 42 | self.test_targets, 43 | epochs=self.config['model']['epochs'], 44 | save_checkpoint=False, 45 | dirname="experiment_scripts" 46 | ) 47 | -------------------------------------------------------------------------------- /experiment_scripts/experiments_configs/bulk_moduli.yaml: -------------------------------------------------------------------------------- 1 | is_str: false 2 | test_size: 0.1 3 | shuffle: true 4 | trainer: 'pytorch_trainer' -------------------------------------------------------------------------------- /experiment_scripts/experiments_configs/mp_band_gap.yaml: -------------------------------------------------------------------------------- 1 | is_str: true 2 | test_size: 0.03 3 | shuffle: true 4 | trainer: 'pytorch_trainer' -------------------------------------------------------------------------------- /experiment_scripts/experiments_configs/mp_formation_energy.yaml: -------------------------------------------------------------------------------- 1 | is_str: true 2 | test_size: 0.03 3 | shuffle: true 4 | trainer: 'pytorch_trainer' -------------------------------------------------------------------------------- /experiment_scripts/experiments_configs/tf_bulk_moduli.yaml: -------------------------------------------------------------------------------- 1 | is_str: false 2 | test_size: 0.1 3 | shuffle: true 4 | trainer: 'tf_trainer' -------------------------------------------------------------------------------- /experiment_scripts/experiments_configs/tf_mp_band_gap.yaml: -------------------------------------------------------------------------------- 1 | is_str: true 2 | test_size: 0.03 3 | shuffle: true 4 | trainer: 'tf_trainer' -------------------------------------------------------------------------------- /experiment_scripts/experiments_configs/tf_mp_formation_energy.yaml: -------------------------------------------------------------------------------- 1 | is_str: true 2 | test_size: 0.03 3 | shuffle: true 4 | trainer: 'tf_trainer' -------------------------------------------------------------------------------- /experiment_scripts/model_configs/tf_bg.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | cutoff: 5 3 | model: 4 | nfeat_edge: 10 5 | nfeat_global: 2 6 | nfeat_node: ~ 7 | target_name: 'band_gap' 8 | epochs: 1000 9 | learning_rate: 1.e-3 -------------------------------------------------------------------------------- /experiment_scripts/model_configs/tf_bulk.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | cutoff: 5 3 | model: 4 | nfeat_edge: 10 5 | nfeat_global: 2 6 | nfeat_node: ~ 7 | target_name: 'bulk_moduli' 8 | epochs: 10 9 | learning_rate: 1.e-3 -------------------------------------------------------------------------------- /experiment_scripts/model_configs/tf_ef.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | cutoff: 5 3 | model: 4 | nfeat_edge: 10 5 | nfeat_global: 2 6 | nfeat_node: ~ 7 | target_name: 'formation_energy_per_atom' 8 | epochs: 1000 9 | learning_rate: 1.e-3 -------------------------------------------------------------------------------- /experiment_scripts/model_configs/torch_bg.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | add_z_bond_coord: false 3 | cutoff: 5 4 | atom_features: 'embed' 5 | model: 6 | edge_embed_size: 10 7 | state_input_shape: 2 8 | target_name: 'band_gap' 9 | train_batch_size: 256 10 | test_batch_size: 256 11 | epochs: 1000 12 | optim: 13 | lr_initial: 1.e-3 14 | factor: 0.5 15 | patience: 100 16 | threshold: 5.e-2 17 | min_lr: 1.e-5 18 | scheduler: 'ReduceLROnPlateau' -------------------------------------------------------------------------------- /experiment_scripts/model_configs/torch_bulk.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | add_z_bond_coord: false 3 | cutoff: 5 4 | atom_features: 'embed' 5 | model: 6 | edge_embed_size: 10 7 | state_input_shape: 2 8 | target_name: 'bulk_moduli' 9 | train_batch_size: 10 10 | test_batch_size: 10 11 | epochs: 10 12 | optim: 13 | lr_initial: 1.e-3 14 | factor: 0.5 15 | patience: 100 16 | threshold: 5.e-2 17 | min_lr: 1.e-5 18 | scheduler: 'ReduceLROnPlateau' -------------------------------------------------------------------------------- /experiment_scripts/model_configs/torch_ef.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | add_z_bond_coord: false 3 | cutoff: 5 4 | atom_features: 'embed' 5 | model: 6 | edge_embed_size: 10 7 | state_input_shape: 2 8 | target_name: 'formation_energy_per_atom' 9 | train_batch_size: 256 10 | test_batch_size: 256 11 | epochs: 1000 12 | optim: 13 | lr_initial: 1.e-3 14 | factor: 0.5 15 | patience: 100 16 | threshold: 5.e-2 17 | min_lr: 1.e-5 18 | scheduler: 'ReduceLROnPlateau' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | astunparse==1.6.3 3 | cachetools==5.0.0 4 | certifi==2021.10.8 5 | charset-normalizer==2.0.9 6 | click==8.0.3 7 | cycler==0.11.0 8 | flatbuffers==2.0 9 | fonttools==4.28.5 10 | future==0.18.2 11 | gast==0.5.3 12 | google-auth==2.6.2 13 | google-auth-oauthlib==0.4.6 14 | google-pasta==0.2.0 15 | googledrivedownloader==0.4 16 | grpcio==1.44.0 17 | h5py==3.6.0 18 | idna==3.3 19 | importlib-metadata==4.11.3 20 | isodate==0.6.1 21 | Jinja2==3.0.3 22 | joblib==1.1.0 23 | keras==2.8.0 24 | Keras-Preprocessing==1.1.2 25 | kiwisolver==1.3.2 26 | libclang==13.0.0 27 | Markdown==3.3.6 28 | MarkupSafe==2.0.1 29 | matplotlib==3.5.1 30 | megnet==1.3.0 31 | monty==2021.12.1 32 | mpmath==1.2.1 33 | networkx==2.6.3 34 | numpy==1.22.0 35 | oauthlib==3.2.0 36 | opt-einsum==3.3.0 37 | packaging==21.3 38 | palettable==3.3.0 39 | pandas==1.3.5 40 | Pillow==8.4.0 41 | plotly==5.5.0 42 | protobuf==3.19.4 43 | pyasn1==0.4.8 44 | pyasn1-modules==0.2.8 45 | pymatgen==2022.0.17 46 | pyparsing==3.0.6 47 | python-dateutil==2.8.2 48 | pytz==2021.3 49 | PyYAML==6.0 50 | rdflib==6.1.1 51 | requests==2.26.0 52 | requests-oauthlib==1.3.1 53 | rsa==4.8 54 | ruamel.yaml==0.17.20 55 | ruamel.yaml.clib==0.2.6 56 | scikit-learn==1.0.2 57 | scipy==1.7.3 58 | six==1.16.0 59 | spglib==1.16.3 60 | sympy==1.9 61 | tabulate==0.8.9 62 | tenacity==8.0.1 63 | tensorboard==2.8.0 64 | tensorboard-data-server==0.6.1 65 | tensorboard-plugin-wit==1.8.1 66 | tensorflow==2.8.0 67 | tensorflow-io-gcs-filesystem==0.24.0 68 | termcolor==1.1.0 69 | tf-estimator-nightly==2.8.0.dev2021122109 70 | threadpoolctl==3.0.0 71 | torch==1.10.1+cpu 72 | torch-cluster==1.6.0 73 | torch-geometric==2.0.4 74 | torch-scatter==2.0.9 75 | torch-sparse==0.6.13 76 | torch-spline-conv==1.2.1 77 | torchaudio==0.10.1 78 | torchvision==0.11.2+cpu 79 | tqdm==4.62.3 80 | typing_extensions==4.0.1 81 | uncertainties==3.1.6 82 | urllib3==1.26.7 83 | Werkzeug==2.1.0 84 | wrapt==1.14.0 85 | yacs==0.1.8 86 | zipp==3.7.0 87 | -------------------------------------------------------------------------------- /run_experiment.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import click 4 | import numpy as np 5 | import yaml 6 | 7 | from monty.serialization import loadfn 8 | from tqdm import tqdm 9 | 10 | from MEGNet.utils.utils import set_random_seed 11 | from MEGNet.utils.utils import String2StructConverter 12 | from MEGNet.MEGNet_Trainer import MEGNetTrainer 13 | 14 | from experiment_scripts.MEGNet_Original_Trainer import MEGNetOriginalTrainer 15 | 16 | 17 | def set_return(o, name, val): 18 | setattr(o, name, val) 19 | return o 20 | 21 | 22 | def parse_data(path, is_str): 23 | print("reading data") 24 | raw_data = loadfn(path) 25 | if is_str: 26 | converter = String2StructConverter(['formation_energy_per_atom', 'band_gap']) 27 | structures_list = [converter.convert(s) for s in tqdm(raw_data)] 28 | else: 29 | structures_list = raw_data["structures"] 30 | targets = np.log10(raw_data["bulk_moduli"]) 31 | structures_list = [set_return(s, "bulk_moduli", float(t)) for s, t in tqdm(zip(structures_list, targets))] 32 | return structures_list 33 | 34 | 35 | @click.command() 36 | @click.option('--dataset_path') 37 | @click.option('--experiment_config_path') 38 | @click.option('--model_config_path') 39 | def main(dataset_path, experiment_config_path, model_config_path): 40 | set_random_seed(17) 41 | 42 | with open(experiment_config_path) as ey: 43 | experiment_config = yaml.full_load(ey) 44 | with open(model_config_path) as my: 45 | model_config = yaml.full_load(my) 46 | 47 | dataset = parse_data(dataset_path, experiment_config['is_str']) 48 | if experiment_config['shuffle']: 49 | random.shuffle(dataset) 50 | 51 | test_size = int(len(dataset) * experiment_config['test_size']) 52 | trainset = dataset[:-test_size] 53 | testset = dataset[-test_size:] 54 | print(f"len of trainset: {len(trainset)}") 55 | print(f"len of testset: {len(testset)}") 56 | 57 | if experiment_config['trainer'] == 'pytorch_trainer': 58 | trainer = MEGNetTrainer(trainset, testset, model_config) 59 | else: 60 | trainer = MEGNetOriginalTrainer(trainset, testset, model_config) 61 | 62 | trainer.train() 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = MEGNet_test 3 | 4 | [options] 5 | packages = find: 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | --------------------------------------------------------------------------------