├── .github └── workflows │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── crystalformer ├── __init__.py ├── cli │ ├── __init__.py │ ├── classifier.py │ ├── cond_gen.py │ ├── dataset.py │ ├── spg_sample.py │ ├── train_dpo.py │ └── train_ppo.py ├── data │ ├── .gitkeep │ ├── wyckoff_list.csv │ └── wyckoff_symbols.csv ├── extension │ ├── __init__.py │ ├── experimental.py │ ├── loss.py │ ├── mcmc.py │ ├── model.py │ ├── train.py │ └── transformer.py ├── reinforce │ ├── __init__.py │ ├── dpo.py │ ├── ehull.py │ ├── potential.py │ ├── ppo.py │ ├── reward.py │ ├── sample.py │ └── vanilla.py └── src │ ├── __init__.py │ ├── attention.py │ ├── checkpoint.py │ ├── elements.py │ ├── lattice.py │ ├── loss.py │ ├── mcmc.py │ ├── rope.py │ ├── sample.py │ ├── train.py │ ├── transformer.py │ ├── utils.py │ ├── von_mises.py │ └── wyckoff.py ├── data ├── atoms.json └── mini.csv ├── imgs ├── crystalformer.png └── output.gif ├── main.py ├── model ├── README.md └── config.yaml ├── requirements.txt ├── scripts ├── README.md ├── awl2struct.py ├── check_sun_materials.py ├── compute_metrics.py ├── compute_metrics_matbench.py ├── config.py ├── e_above_hull.py ├── e_above_hull_alex.py ├── element_substition.py ├── eval_utils.py ├── mlff_relax.py ├── plot_embedding.py ├── process_alex.py └── structure_visualization.ipynb ├── setup.py └── tests ├── config.py ├── test_fc_mask.py ├── test_lattice.py ├── test_sampling.py ├── test_transformer.py ├── test_utils.py └── test_wyckoff.py /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.10"] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: set PYTHONPATH 20 | run: | 21 | echo "PYTHONPATH=/home/runner/work/crystal_gpt/crystal_gpt" >> $GITHUB_ENV 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install pytest 26 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 27 | pip install -U jax 28 | pip install . 
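# note: pip downloads could optionally be cached between runs, e.g. by passing `cache: 'pip'` to the actions/setup-python step above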
29 | - name: Test with pytest 30 | run: | 31 | pytest 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | job* 3 | *.out 4 | __pycache__/ 5 | data/ 6 | experimental/ 7 | *.ipynb 8 | *.egg-info/ -------------------------------------------------------------------------------- /crystalformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/CrystalFormer/a2ae7982234b84134950c5c7a4f49d592c83bddf/crystalformer/__init__.py -------------------------------------------------------------------------------- /crystalformer/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/CrystalFormer/a2ae7982234b84134950c5c7a4f49d592c83bddf/crystalformer/cli/__init__.py -------------------------------------------------------------------------------- /crystalformer/cli/classifier.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pandas as pd 4 | import os 5 | import optax 6 | from jax.flatten_util import ravel_pytree 7 | 8 | import crystalformer.src.checkpoint as checkpoint 9 | from crystalformer.src.utils import GLXYZAW_from_file 10 | 11 | from crystalformer.extension.model import make_classifier 12 | from crystalformer.extension.transformer import make_transformer 13 | from crystalformer.extension.train import train 14 | from crystalformer.extension.loss import make_classifier_loss 15 | 16 | 17 | def get_labels(csv_file, label_col): 18 | data = pd.read_csv(csv_file) 19 | labels = data[label_col].values 20 | labels = jnp.array(labels, dtype=float) 21 | return labels 22 | 23 | def GLXYZAW_from_sample(spg, test_path): 24 | ### read from generated data 25 | from ast import literal_eval 26 | from crystalformer.src.wyckoff import mult_table 27 | 28 | test_data = pd.read_csv(test_path) 29 | L, XYZ, A, W = test_data['L'], test_data['X'], test_data['A'], test_data['W'] 30 | L = L.apply(lambda x: literal_eval(x)) 31 | XYZ = XYZ.apply(lambda x: literal_eval(x)) 32 | A = A.apply(lambda x: literal_eval(x)) 33 | W = W.apply(lambda x: literal_eval(x)) 34 | 35 | # convert array of list to numpy ndarray 36 | G = jnp.array([spg]*len(L)) 37 | L = jnp.array(L.tolist()) 38 | XYZ = jnp.array(XYZ.tolist()) 39 | A = jnp.array(A.tolist()) 40 | W = jnp.array(W.tolist()) 41 | 42 | M = jax.vmap(lambda g, w: mult_table[g-1, w], in_axes=(0, 0))(G, W) # (batchsize, n_max) 43 | num_atoms = jnp.sum(M, axis=1) 44 | length, angle = jnp.split(L, 2, axis=-1) 45 | length = length/num_atoms[:, None]**(1/3) 46 | angle = angle * (jnp.pi / 180) # to rad 47 | L = jnp.concatenate([length, angle], axis=-1) 48 | 49 | return G, L, XYZ, A, W 50 | 51 | 52 | def main(): 53 | 54 | import argparse 55 | parser = argparse.ArgumentParser(description='') 56 | 57 | group = parser.add_argument_group('dataset') 58 | group.add_argument('--train_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/train.csv', help='') 59 | group.add_argument('--valid_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/val.csv', help='') 60 | group.add_argument('--test_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/test.csv', help='') 61 | group.add_argument('--spacegroup', type=int, default=None, help='The space group number') 62 | group.add_argument('--property', 
default='band_gap', help='The property to predict') 63 | group.add_argument('--num_io_process', type=int, default=40, help='number of io processes') 64 | 65 | group = parser.add_argument_group('predict dataset') 66 | group.add_argument('--output_path', type=str, default='./predict.npy', help='The path to save the prediction result') 67 | 68 | group = parser.add_argument_group('physics parameters') 69 | group.add_argument('--n_max', type=int, default=21, help='The maximum number of atoms in the cell') 70 | group.add_argument('--atom_types', type=int, default=119, help='Atom types including the padded atoms') 71 | group.add_argument('--wyck_types', type=int, default=28, help='Number of possible multiplicities including 0') 72 | 73 | group = parser.add_argument_group('transformer parameters') 74 | group.add_argument('--Nf', type=int, default=5, help='number of frequencies for fc') 75 | group.add_argument('--Kx', type=int, default=16, help='number of modes in x') 76 | group.add_argument('--Kl', type=int, default=4, help='number of modes in lattice') 77 | group.add_argument('--h0_size', type=int, default=256, help='hidden layer dimension for the first atom, 0 means we simply use a table for first aw_logit') 78 | group.add_argument('--transformer_layers', type=int, default=4, help='The number of layers in transformer') 79 | group.add_argument('--num_heads', type=int, default=8, help='The number of heads') 80 | group.add_argument('--key_size', type=int, default=32, help='The key size') 81 | group.add_argument('--model_size', type=int, default=64, help='The model size') 82 | group.add_argument('--embed_size', type=int, default=32, help='The embedding size') 83 | group.add_argument('--dropout_rate', type=float, default=0.3, help='The dropout rate') 84 | 85 | group = parser.add_argument_group('classifier parameters') 86 | group.add_argument('--sequence_length', type=int, default=105, help='The sequence length') 87 | group.add_argument('--outputs_size', type=int, default=64, help='The outputs size') 88 | group.add_argument('--hidden_sizes', type=str, default='128,128,64', help='The hidden sizes') 89 | group.add_argument('--num_classes', type=int, default=1, help='The number of classes') 90 | group.add_argument('--restore_path', type=str, default="/data/zdcao/crystal_gpt/classifier/", help='The restore path') 91 | 92 | group = parser.add_argument_group('training parameters') 93 | group.add_argument('--lr', type=float, default=1e-4, help='The learning rate') 94 | group.add_argument('--epochs', type=int, default=1000, help='The number of epochs') 95 | group.add_argument('--batchsize', type=int, default=256, help='The batch size') 96 | group.add_argument('--optimizer', type=str, default='adam', choices=["none", "adam"], help='The optimizer') 97 | 98 | args = parser.parse_args() 99 | key = jax.random.PRNGKey(42) 100 | 101 | if args.optimizer != "none": 102 | train_data = GLXYZAW_from_file(args.train_path, args.atom_types, 103 | args.wyck_types, args.n_max, args.num_io_process) 104 | valid_data = GLXYZAW_from_file(args.valid_path, args.atom_types, 105 | args.wyck_types, args.n_max, args.num_io_process) 106 | 107 | train_labels = get_labels(args.train_path, args.property) 108 | valid_labels = get_labels(args.valid_path, args.property) 109 | 110 | train_data = (*train_data, train_labels) 111 | valid_data = (*valid_data, valid_labels) 112 | 113 | else: 114 | if args.spacegroup is None: 115 | G, L, XYZ, A, W = GLXYZAW_from_file(args.test_path, args.atom_types, 116 | args.wyck_types, args.n_max,
args.num_io_process) 117 | test_labels = get_labels(args.test_path, args.property) 118 | 119 | else: 120 | G, L, XYZ, A, W = GLXYZAW_from_sample(args.spacegroup, args.test_path) 121 | 122 | ################### Model ############################# 123 | transformer_params, state, transformer = make_transformer(key, args.Nf, args.Kx, args.Kl, args.n_max, 124 | args.h0_size, 125 | args.transformer_layers, args.num_heads, 126 | args.key_size, args.model_size, args.embed_size, 127 | args.atom_types, args.wyck_types, 128 | args.dropout_rate) 129 | print ("# of transformer params", ravel_pytree(transformer_params)[0].size) 130 | 131 | 132 | key, subkey = jax.random.split(key) 133 | classifier_params, classifier = make_classifier(subkey, 134 | n_max=args.n_max, 135 | embed_size=args.embed_size, 136 | sequence_length=args.sequence_length, 137 | outputs_size=args.outputs_size, 138 | hidden_sizes=[int(x) for x in args.hidden_sizes.split(',')], 139 | num_classes=args.num_classes) 140 | 141 | print ("# of classifier params", ravel_pytree(classifier_params)[0].size) 142 | 143 | 144 | params = (transformer_params, classifier_params) 145 | 146 | print("\n========== Prepare logs ==========") 147 | output_path = os.path.dirname(args.restore_path) 148 | print("Will output samples to: %s" % output_path) 149 | 150 | print("\n========== Load checkpoint==========") 151 | ckpt_filename, epoch_finished = checkpoint.find_ckpt_filename(args.restore_path) 152 | if ckpt_filename is not None: 153 | print("Load checkpoint file: %s, epoch finished: %g" %(ckpt_filename, epoch_finished)) 154 | ckpt = checkpoint.load_data(ckpt_filename) 155 | _params = ckpt["params"] 156 | if len(_params) == len(params): 157 | params = _params 158 | else: 159 | params = (_params, params[1]) # only restore transformer params 160 | print("only restore transformer params") 161 | else: 162 | print("No checkpoint file found. Start from scratch.") 163 | 164 | 165 | loss_fn, forward_fn = make_classifier_loss(transformer, classifier) 166 | 167 | if args.optimizer == 'adam': 168 | 169 | param_labels = ('transformer', 'classifier') 170 | optimizer = optax.multi_transform({'transformer': optax.adam(args.lr*0.1), 171 | 'classifier': optax.adam(args.lr)}, 172 | param_labels) 173 | opt_state = optimizer.init(params) 174 | 175 | print("\n========== Start training ==========") 176 | key, subkey = jax.random.split(key) 177 | params, opt_state = train(subkey, optimizer, opt_state, loss_fn, params, state, epoch_finished, args.epochs, args.batchsize, train_data, valid_data, output_path) 178 | 179 | elif args.optimizer == 'none': 180 | 181 | y = jax.vmap(forward_fn, 182 | in_axes=(None, None, None, 0, 0, 0, 0, 0, None) 183 | )(params, state, key, G, L, XYZ, A, W, False) 184 | 185 | jnp.save(args.output_path, y) 186 | 187 | else: 188 | raise NotImplementedError(f"Optimizer {args.optimizer} not implemented") 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /crystalformer/cli/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lmdb 3 | import pickle 4 | import numpy as np 5 | from crystalformer.src.utils import GLXYZAW_from_file 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | 9 | 10 | def csv_to_lmdb(csv_file, lmdb_file, args): 11 | if os.path.exists(lmdb_file): 12 | os.remove(lmdb_file) 13 | print(f"Removed existing {lmdb_file}") 14 | 15 | values = GLXYZAW_from_file(csv_file, 16 |
atom_types=args.atom_types, 17 | wyck_types=args.wyck_types, 18 | n_max=args.n_max, 19 | num_workers=args.num_workers) 20 | keys = np.arange(len(values[0])) 21 | 22 | env = lmdb.open( 23 | lmdb_file, 24 | subdir=False, 25 | readonly=False, 26 | lock=False, 27 | readahead=False, 28 | meminit=False, 29 | max_readers=1, 30 | map_size=int(100e9), 31 | ) 32 | 33 | with env.begin(write=True) as txn: 34 | for key, value in zip(keys, zip(*values)): # one (G, L, XYZ, A, W) record per key 35 | txn.put(str(key).encode("utf-8"), pickle.dumps(value)) 36 | 37 | print(f"Successfully converted {csv_file} to {lmdb_file}") 38 | 39 | 40 | def main(): 41 | import argparse 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--n_max', type=int, default=21, help='The maximum number of atoms in the cell') 44 | parser.add_argument('--atom_types', type=int, default=119, help='Atom types including the padded atoms') 45 | parser.add_argument('--wyck_types', type=int, default=28, help='Number of possible multiplicities including 0') 46 | 47 | parser.add_argument("--path", type=str, required=True) 48 | parser.add_argument("--num_workers", type=int, default=40) 49 | args = parser.parse_args() 50 | 51 | for i in ["test", "val", "train"]: 52 | csv_to_lmdb( 53 | os.path.join(args.path, f"{i}.csv"), 54 | os.path.join(args.path, f"{i}.lmdb"), 55 | args 56 | ) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /crystalformer/cli/train_dpo.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.flatten_util import ravel_pytree 4 | import os 5 | import optax 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | 9 | from crystalformer.src.utils import GLXYZAW_from_file 10 | from crystalformer.src.loss import make_loss_fn 11 | from crystalformer.src.transformer import make_transformer 12 | import crystalformer.src.checkpoint as checkpoint 13 | 14 | from crystalformer.reinforce.dpo import make_dpo_loss, train 15 | 16 | 17 | def main(): 18 | import argparse 19 | parser = argparse.ArgumentParser(description='') 20 | 21 | group = parser.add_argument_group('training parameters') 22 | group.add_argument('--epochs', type=int, default=100, help='') 23 | group.add_argument('--batchsize', type=int, default=100, help='') 24 | group.add_argument('--lr', type=float, default=1e-5, help='learning rate') 25 | group.add_argument('--lr_decay', type=float, default=0.0, help='lr decay') 26 | group.add_argument('--weight_decay', type=float, default=0.0, help='weight decay') 27 | group.add_argument('--clip_grad', type=float, default=1.0, help='clip gradient') 28 | group.add_argument("--optimizer", type=str, default="adam", choices=["none", "adam", "adamw"], help="optimizer type") 29 | 30 | group.add_argument("--folder", default="./data/", help="the folder to save data") 31 | group.add_argument("--restore_path", default=None, help="checkpoint path or file") 32 | 33 | group = parser.add_argument_group('dataset') 34 | group.add_argument('--chosen_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/val.csv', help='') 35 | group.add_argument('--rejected_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/val.csv', help='') 36 | group.add_argument("--val_ratio", type=float, default=0.2, help="validation ratio") 37 | group.add_argument('--num_io_process', type=int, default=40, help='number of processes used in multiprocessing io') 38 | 39 | group = parser.add_argument_group('transformer parameters') 40 |
group.add_argument('--Nf', type=int, default=5, help='number of frequencies for fc') 41 | group.add_argument('--Kx', type=int, default=16, help='number of modes in x') 42 | group.add_argument('--Kl', type=int, default=4, help='number of modes in lattice') 43 | group.add_argument('--h0_size', type=int, default=256, help='hidden layer dimension for the first atom, 0 means we simply use a table for first aw_logit') 44 | group.add_argument('--transformer_layers', type=int, default=16, help='The number of layers in transformer') 45 | group.add_argument('--num_heads', type=int, default=16, help='The number of heads') 46 | group.add_argument('--key_size', type=int, default=64, help='The key size') 47 | group.add_argument('--model_size', type=int, default=64, help='The model size') 48 | group.add_argument('--embed_size', type=int, default=32, help='The embedding size') 49 | group.add_argument('--dropout_rate', type=float, default=0.1, help='The dropout rate for MLP') 50 | group.add_argument('--attn_dropout', type=float, default=0.1, help='The dropout rate for attention') 51 | 52 | group = parser.add_argument_group('physics parameters') 53 | group.add_argument('--n_max', type=int, default=21, help='The maximum number of atoms in the cell') 54 | group.add_argument('--atom_types', type=int, default=119, help='Atom types including the padded atoms') 55 | group.add_argument('--wyck_types', type=int, default=28, help='Number of possible multiplicities including 0') 56 | 57 | group = parser.add_argument_group('reinforcement learning parameters') 58 | group.add_argument('--beta', type=float, default=0.1, help='beta for DPO loss') 59 | group.add_argument('--label_smoothing', type=float, default=0.0, help='label smoothing for DPO loss') 60 | group.add_argument('--gamma', type=float, default=0.0, help='logp regularization coefficient for DPO loss') 61 | group.add_argument('--ipo', action='store_true', help='use IPO loss instead of DPO loss') 62 | 63 | args = parser.parse_args() 64 | 65 | print("\n========== Load dataset ==========") 66 | chosen_data = GLXYZAW_from_file(args.chosen_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process) 67 | rejected_data = GLXYZAW_from_file(args.rejected_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process) 68 | 69 | print("================ parameters ================") 70 | # print all the parameters 71 | for arg in vars(args): 72 | print(f"{arg}: {getattr(args, arg)}") 73 | 74 | print("\n========== Prepare transformer ==========") 75 | ################### Model ############################# 76 | key = jax.random.PRNGKey(42) 77 | params, transformer = make_transformer(key, args.Nf, args.Kx, args.Kl, args.n_max, 78 | args.h0_size, 79 | args.transformer_layers, args.num_heads, 80 | args.key_size, args.model_size, args.embed_size, 81 | args.atom_types, args.wyck_types, 82 | args.dropout_rate, args.attn_dropout) 83 | 84 | transformer_name = 'Nf_%d_Kx_%d_Kl_%d_h0_%d_l_%d_H_%d_k_%d_m_%d_e_%d_drop_%g'%(args.Nf, args.Kx, args.Kl, args.h0_size, args.transformer_layers, args.num_heads, args.key_size, args.model_size, args.embed_size, args.dropout_rate) 85 | 86 | print ("# of transformer params", ravel_pytree(params)[0].size) 87 | 88 | ################### Train ############################# 89 | 90 | loss_fn, logp_fn = make_loss_fn(args.n_max, args.atom_types, args.wyck_types, args.Kx, args.Kl, transformer) 91 | 92 | print("\n========== Prepare logs ==========") 93 | if args.optimizer != "none" or args.restore_path is None: 94 | output_path =
args.folder \ 95 | + "beta_%g_label_%g_gamma_%g_"%(args.beta, args.label_smoothing, args.gamma) \ 96 | + args.optimizer+"_bs_%d_lr_%g_decay_%g_clip_%g" % (args.batchsize, args.lr, args.lr_decay, args.clip_grad) \ 97 | + '_A_%g_W_%g_N_%g'%(args.atom_types, args.wyck_types, args.n_max) \ 98 | + ("_wd_%g"%(args.weight_decay) if args.optimizer == "adamw" else "") \ 99 | + "_" + transformer_name 100 | 101 | os.makedirs(output_path, exist_ok=True) 102 | print("Create directory for output: %s" % output_path) 103 | else: 104 | output_path = os.path.dirname(args.restore_path) 105 | print("Will output samples to: %s" % output_path) 106 | 107 | 108 | print("\n========== Load checkpoint==========") 109 | ckpt_filename, epoch_finished = checkpoint.find_ckpt_filename(args.restore_path or output_path) 110 | if ckpt_filename is not None: 111 | print("Load checkpoint file: %s, epoch finished: %g" %(ckpt_filename, epoch_finished)) 112 | ckpt = checkpoint.load_data(ckpt_filename) 113 | params = ckpt["params"] 114 | else: 115 | print("No checkpoint file found. Start from scratch.") 116 | 117 | if args.optimizer != "none": 118 | 119 | schedule = lambda t: args.lr/(1+args.lr_decay*t) 120 | 121 | if args.optimizer == "adam": 122 | optimizer = optax.chain(optax.clip_by_global_norm(args.clip_grad), 123 | optax.scale_by_adam(), 124 | optax.scale_by_schedule(schedule), 125 | optax.scale(-1.)) 126 | elif args.optimizer == 'adamw': 127 | optimizer = optax.chain(optax.clip(args.clip_grad), 128 | optax.adamw(learning_rate=schedule, weight_decay=args.weight_decay) 129 | ) 130 | 131 | opt_state = optimizer.init(params) 132 | try: 133 | opt_state.update(ckpt["opt_state"]) 134 | except Exception: 135 | print ("failed to update opt_state from checkpoint") 136 | pass 137 | 138 | print("\n========== Start RL training ==========") 139 | dpo_loss_fn = make_dpo_loss(logp_fn, 140 | beta=args.beta, 141 | label_smoothing=args.label_smoothing, 142 | gamma=args.gamma, 143 | ipo=args.ipo) 144 | 145 | # DPO training 146 | params, opt_state = train(key, optimizer, opt_state, dpo_loss_fn, logp_fn, params, epoch_finished, 147 | args.epochs, args.batchsize, chosen_data, rejected_data, output_path, args.val_ratio) 148 | 149 | else: 150 | raise NotImplementedError("No optimizer specified.
Please specify an optimizer in the config file.") 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /crystalformer/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/CrystalFormer/a2ae7982234b84134950c5c7a4f49d592c83bddf/crystalformer/data/.gitkeep -------------------------------------------------------------------------------- /crystalformer/extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/CrystalFormer/a2ae7982234b84134950c5c7a4f49d592c83bddf/crystalformer/extension/__init__.py -------------------------------------------------------------------------------- /crystalformer/extension/experimental.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | 5 | from crystalformer.src.sample import project_xyz 6 | from crystalformer.src.von_mises import sample_von_mises 7 | from crystalformer.src.lattice import symmetrize_lattice 8 | 9 | 10 | def make_cond_logp(logp_fn, forward_fn, target, alpha): 11 | ''' 12 | logp_fn: function to calculate log p(x) 13 | forward_fn: function to calculate log p(y|x), x is G, L, XYZ, A, W 14 | target: target label 15 | alpha: hyperparameter to control the trade-off between log p(x) and log p(y|x) 16 | NOTE that the logp_fn and forward_fn should be vmapped before passing to this function 17 | ''' 18 | 19 | def forward(G, L, XYZ, A, W, target): 20 | y = forward_fn(G, L, XYZ, A, W, target) 21 | return y 22 | 23 | def callback_forward(G, L, XYZ, A, W, target): 24 | result_shape = jax.ShapeDtypeStruct(G.shape, jnp.float32) 25 | return jax.experimental.io_callback(forward, result_shape, G, L, XYZ, A, W, target) 26 | 27 | def cond_logp_fn(params, key, G, L, XYZ, A, W, is_training): 28 | ''' 29 | params: base model parameters 30 | ''' 31 | # calculate log p(x) 32 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(params, key, G, L, XYZ, A, W, is_training) 33 | logp_base = logp_xyz + logp_w + logp_a + logp_l 34 | 35 | # calculate p(y|x) 36 | logp_cond = callback_forward(G, L, XYZ, A, W, target) 37 | 38 | # trade-off between log p(x) and p(y|x) 39 | logp = logp_base - alpha * logp_cond.squeeze() 40 | return logp 41 | 42 | return cond_logp_fn 43 | 44 | 45 | def make_mcmc_step(base_params, n_max, atom_types, atom_mask=None, constraints=None): 46 | 47 | if atom_mask is None or jnp.all(atom_mask == 0): 48 | atom_mask = jnp.ones((n_max, atom_types)) 49 | 50 | if constraints is None: 51 | constraints = jnp.arange(0, n_max, 1) 52 | 53 | def update_A(i, A, a, constraints): 54 | def body_fn(j, A): 55 | A = jax.lax.cond(constraints[j] == constraints[i], 56 | lambda _: A.at[:, j].set(a), 57 | lambda _: A, 58 | None) 59 | return A 60 | 61 | A = jax.lax.fori_loop(0, A.shape[1], body_fn, A) 62 | return A 63 | 64 | @partial(jax.jit, static_argnums=0) 65 | def mcmc(logp_fn, x_init, key, mc_steps, mc_width, temp): 66 | """ 67 | Markov Chain Monte Carlo sampling algorithm. 68 | 69 | INPUT: 70 | logp_fn: callable that evaluate log-probability of a batch of configuration x. 71 | The signature is logp_fn(x), where x has shape (batch, n, dim). 72 | x_init: initial value of x, with shape (batch, n, dim). 73 | key: initial PRNG key. 74 | mc_steps: total number of Monte Carlo steps. 
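(each cycle of n_max+2 steps proposes n_max atom type/position updates, then one lattice-length and one lattice-angle update)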
75 | mc_width: size of the Monte Carlo proposal. 76 | temp: temperature in the simulated annealing. 77 | 78 | OUTPUT: 79 | x: resulting batch samples, with the same shape as `x_init`. 80 | """ 81 | 82 | def update_lattice(i, state): 83 | def update_length(key, L): 84 | length, angle = jnp.split(L, 2, axis=-1) 85 | length += jax.random.normal(key, length.shape) * mc_width 86 | return jnp.concatenate([length, angle], axis=-1) 87 | 88 | def update_angle(key, L): 89 | length, angle = jnp.split(L, 2, axis=-1) 90 | angle += jax.random.normal(key, angle.shape) * mc_width 91 | return jnp.concatenate([length, angle], axis=-1) 92 | 93 | x, logp, key, num_accepts, temp = state 94 | G, L, XYZ, A, W = x 95 | key, key_proposal_L, key_accept, key_logp = jax.random.split(key, 4) 96 | 97 | L_proposal = jax.lax.cond(i % (n_max+2) % n_max == 0, 98 | lambda _: update_length(key_proposal_L, L), 99 | lambda _: update_angle(key_proposal_L, L), 100 | None) 101 | 102 | length, angle = jnp.split(L_proposal, 2, axis=-1) 103 | angle = jnp.rad2deg(angle) # change the unit to degree 104 | L_proposal = jnp.concatenate([length, angle], axis=-1) 105 | L_proposal = jax.vmap(symmetrize_lattice, (0, 0))(G, L_proposal) 106 | 107 | length, angle = jnp.split(L_proposal, 2, axis=-1) 108 | angle = jnp.deg2rad(angle) # change the unit to rad 109 | L_proposal = jnp.concatenate([length, angle], axis=-1) 110 | 111 | x_proposal = (G, L_proposal, XYZ, A, W) 112 | logp_proposal = logp_fn(base_params, key_logp, *x_proposal, False) 113 | 114 | ratio = jnp.exp((logp_proposal - logp)/ temp) 115 | accept = jax.random.uniform(key_accept, ratio.shape) < ratio 116 | 117 | L_new = jnp.where(accept[:, None], L_proposal, L) # update lattice 118 | x_new = (G, L_new, XYZ, A, W) 119 | logp_new = jnp.where(accept, logp_proposal, logp) 120 | num_accepts += jnp.sum(accept) 121 | 122 | jax.debug.print("logp {x} {y}", 123 | x=logp_new.mean(), 124 | y=jnp.std(logp_new)/jnp.sqrt(logp_new.shape[0]) 125 | ) 126 | return x_new, logp_new, key, num_accepts, temp 127 | 128 | 129 | def update_a_xyz(i, state): 130 | def true_func(i, state): 131 | x, logp, key, num_accepts, temp = state 132 | G, L, XYZ, A, W = x 133 | key, key_proposal_A, key_proposal_XYZ, key_accept, key_logp = jax.random.split(key, 5) 134 | 135 | p_normalized = atom_mask[i%n_max] / jnp.sum(atom_mask[i%n_max]) # only propose atom types that are allowed 136 | _a = jax.random.choice(key_proposal_A, a=atom_types, p=p_normalized, shape=(A.shape[0], )) 137 | # _A = A.at[:, i%n_max].set(_a) 138 | _A = update_A(i%n_max, A, _a, constraints) 139 | A_proposal = jnp.where(A == 0, A, _A) 140 | 141 | _xyz = XYZ[:, i%n_max] + sample_von_mises(key_proposal_XYZ, 0, 1/mc_width**2, XYZ[:, i%n_max].shape) 142 | _xyz = jax.vmap(project_xyz, in_axes=(0, 0, 0, None))(G, W[:, i%n_max], _xyz, 0) 143 | _XYZ = XYZ.at[:, i%n_max].set(_xyz) 144 | _XYZ -= jnp.floor(_XYZ) # wrap to [0, 1) 145 | XYZ_proposal = _XYZ 146 | x_proposal = (G, L, XYZ_proposal, A_proposal, W) 147 | 148 | logp_proposal = logp_fn(base_params, key_logp, *x_proposal, False) 149 | 150 | ratio = jnp.exp((logp_proposal - logp)/ temp) 151 | accept = jax.random.uniform(key_accept, ratio.shape) < ratio 152 | 153 | A_new = jnp.where(accept[:, None], A_proposal, A) # update atom types 154 | XYZ_new = jnp.where(accept[:, None, None], XYZ_proposal, XYZ) # update atom positions 155 | x_new = (G, L, XYZ_new, A_new, W) 156 | logp_new = jnp.where(accept, logp_proposal, logp) 157 | num_accepts += jnp.sum(accept*jnp.where(A[:, i%n_max]==0, 0, 1)) 158 | 159 |
jax.debug.print("logp {x} {y}", 160 | x=logp_new.mean(), 161 | y=jnp.std(logp_new)/jnp.sqrt(logp_new.shape[0]) 162 | ) 163 | return x_new, logp_new, key, num_accepts, temp 164 | 165 | def false_func(i, state): 166 | return state 167 | 168 | x, logp, key, num_accepts, temp = state 169 | A = x[3] 170 | x, logp, key, num_accepts, temp = jax.lax.cond(A[:, i%(n_max+2)%n_max].sum() != 0, 171 | lambda _: true_func(i, state), 172 | lambda _: false_func(i, state), 173 | None) 174 | return x, logp, key, num_accepts, temp 175 | 176 | def step(i, state): 177 | x, logp, key, num_accepts, temp = jax.lax.cond(i % (n_max+2) < n_max, 178 | lambda _: update_a_xyz(i, state), 179 | lambda _: update_lattice(i, state), 180 | None) 181 | return x, logp, key, num_accepts, temp 182 | 183 | key, subkey = jax.random.split(key) 184 | logp_init = logp_fn(base_params, subkey, *x_init, False) 185 | jax.debug.print("logp {x} {y}", 186 | x=logp_init.mean(), 187 | y=jnp.std(logp_init)/jnp.sqrt(logp_init.shape[0]), 188 | ) 189 | 190 | x, logp, key, num_accepts, temp = jax.lax.fori_loop(0, mc_steps, step, (x_init, logp_init, key, 0., temp)) 191 | A = x[3] 192 | scale = jnp.sum(A != 0)/(A.shape[0]*n_max) 193 | # accept_rate = num_accepts / (scale * mc_steps * x[0].shape[0]) 194 | accept_rate = num_accepts / (scale*mc_steps*n_max/(n_max+2) + mc_steps*2/(n_max+2)) 195 | accept_rate = accept_rate / x[0].shape[0] 196 | return x, accept_rate 197 | 198 | return mcmc 199 | -------------------------------------------------------------------------------- /crystalformer/extension/loss.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | 5 | from crystalformer.src.wyckoff import mult_table 6 | 7 | 8 | def make_classifier_loss(transformer, classifier): 9 | 10 | def forward_fn(params, state, key, G, L, XYZ, A, W, is_train): 11 | M = mult_table[G-1, W] # (n_max,) multplicities 12 | transformer_params, classifier_params = params 13 | _, state = transformer(transformer_params, state, key, G, XYZ, A, W, M, is_train) 14 | 15 | h = state['~']['last_hidden_state'] 16 | g = state['~']['_g_embeddings'] 17 | 18 | key, subkey = jax.random.split(key) 19 | y = classifier(classifier_params, subkey, g, L, W, h, is_train) 20 | return y 21 | 22 | @partial(jax.vmap, in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, None)) 23 | def mae_loss(params, state, key, G, L, XYZ, A, W, labels, is_training): 24 | y = forward_fn(params, state, key, G, L, XYZ, A, W, is_training) 25 | return jnp.abs(y - labels) 26 | 27 | @partial(jax.vmap, in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, None)) 28 | def mse_loss(params, state, key, G, L, XYZ, A, W, labels, is_training): 29 | y = forward_fn(params, state, key, G, L, XYZ, A, W, is_training) 30 | return jnp.square(y - labels) 31 | 32 | def loss_fn(params, state, key, G, L, XYZ, A, W, labels, is_training): 33 | loss = jnp.mean(mae_loss(params, state, key, G, L, XYZ, A, W, labels, is_training)) 34 | return loss 35 | 36 | return loss_fn, forward_fn 37 | 38 | 39 | def make_cond_logp(logp_fn, forward_fn, target, alpha): 40 | ''' 41 | logp_fn: function to calculate log p(x) 42 | forward_fn: function to calculate p(y|x) 43 | target: target label 44 | alpha: hyperparameter to control the trade-off between log p(x) and log p(y|x) 45 | NOTE that the logp_fn and forward_fn should be vmapped before passing to this function 46 | ''' 47 | def cond_logp_fn(base_params, cond_params, state, key, G, L, XYZ, A, W, is_training): 48 | ''' 49 
| base_params: base model parameters 50 | cond_params: conditional model parameters 51 | ''' 52 | # calculate log p(x) 53 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(base_params, key, G, L, XYZ, A, W, is_training) 54 | logp_base = logp_xyz + logp_w + logp_a + logp_l 55 | 56 | # calculate p(y|x) 57 | y = forward_fn(cond_params, state, key, G, L, XYZ, A, W, is_training) # f(x) 58 | logp_cond = jnp.abs(target - y) # |y - f(x)| 59 | 60 | # trade-off between log p(x) and p(y|x) 61 | logp = logp_base - alpha * logp_cond.squeeze() 62 | return logp 63 | 64 | return cond_logp_fn 65 | 66 | 67 | def make_multi_cond_logp(logp_fn, forward_fns, targets, alphas): 68 | ''' 69 | logp_fn: function to calculate log p(x) 70 | forward_fns: functions to calculate p(y|x) 71 | targets: target labels 72 | alphas: hyperparameters to control the trade-off between log p(x) and log p(y|x) 73 | 74 | NOTE that the logp_fn and forward_fns should be vmapped before passing to this function 75 | ''' 76 | 77 | num_conditions = len(forward_fns) 78 | assert len(forward_fns) == len(targets) == len(alphas), "The number of forward functions, targets, and alphas should be the same" 79 | print (num_conditions) 80 | 81 | def cond_logp_fn(base_params, cond_params, state, key, G, L, XYZ, A, W, is_training): 82 | ''' 83 | base_params: base model parameters 84 | cond_params: conditional model parameters 85 | ''' 86 | # calculate log p(x) 87 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(base_params, key, G, L, XYZ, A, W, is_training) 88 | logp_base = logp_xyz + logp_w + logp_a + logp_l 89 | 90 | # calculate multiple p(y|x) 91 | key, *subkeys = jax.random.split(key, num_conditions+1) 92 | logp_cond = 0 93 | for i in range(num_conditions): 94 | y = forward_fns[i](cond_params[i], state, subkeys[i], G, L, XYZ, A, W, is_training) 95 | logp_cond += jnp.abs(targets[i] - y).squeeze() * alphas[i] 96 | 97 | # trade-off between log p(x) and p(y|x) 98 | logp = logp_base - logp_cond 99 | 100 | return logp 101 | 102 | return cond_logp_fn 103 | 104 | 105 | if __name__ == "__main__": 106 | from crystalformer.src.utils import GLXYZAW_from_file 107 | 108 | from model import make_classifier 109 | from transformer import make_transformer 110 | 111 | atom_types = 119 112 | n_max = 21 113 | wyck_types = 28 114 | Nf = 5 115 | Kx = 16 116 | Kl = 4 117 | dropout_rate = 0.1 118 | 119 | csv_file = '../data/mini.csv' 120 | G, L, XYZ, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max) 121 | 122 | key = jax.random.PRNGKey(42) 123 | 124 | transformer_params, state, transformer = make_transformer(key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16, atom_types, wyck_types, dropout_rate) 125 | classifier_params, classifier = make_classifier(key, 126 | n_max=n_max, 127 | embed_size=16, 128 | sequence_length=105, 129 | outputs_size=16, 130 | hidden_sizes=[16, 16], 131 | num_classes=1) 132 | 133 | params = (transformer_params, classifier_params) 134 | loss_fn, forward_fn = make_classifier_loss(transformer, classifier) 135 | 136 | # test loss_fn for classifier 137 | labels = jnp.ones(G.shape) 138 | value = jax.jit(loss_fn, static_argnums=9)(params, state, key, G[:1], L[:1], XYZ[:1], A[:1], W[:1], labels[:1], True) 139 | print (value) 140 | 141 | value = jax.jit(loss_fn, static_argnums=9)(params, state, key, G[:1], L[:1], XYZ[:1]+1.0, A[:1], W[:1], labels[:1], True) 142 | print (value) 143 | 144 | 145 | ############### test cond_loss_fn ################ 146 | from loss import make_loss_fn 147 | from transformer import make_transformer as make_transformer_base 148 | 
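# the guided score built by make_cond_logp below is logp(x) - alpha * |target - f(x)|, i.e. the base log-likelihood penalised by the classifier error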
149 | base_params, base_transformer = make_transformer_base(key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16, atom_types, wyck_types, dropout_rate) 150 | 151 | loss_fn, logp_fn = make_loss_fn(n_max, atom_types, wyck_types, Kx, Kl, base_transformer) 152 | 153 | # test_cond_loss 154 | forward = jax.vmap(forward_fn, in_axes=(None, None, None, 0, 0, 0, 0, 0, None)) 155 | cond_logp_fn = make_cond_logp(logp_fn, forward, 156 | target=1.0, 157 | alpha=0.1) 158 | value = jax.jit(cond_logp_fn, static_argnums=9)(base_params, params, state, key, G, L, XYZ, A, W, False) 159 | print(value) 160 | print(value.shape) 161 | 162 | # test_multi_cond_loss 163 | forward_fns = (forward, forward) 164 | targets = (1.0, 1.0) 165 | alphas = (0.1, 0.1) 166 | multi_cond_logp_fn = make_multi_cond_logp(logp_fn, forward_fns, targets, alphas) 167 | value = jax.jit(multi_cond_logp_fn, static_argnums=9)(base_params, (params, params), state, key, G, L, XYZ, A, W, False) 168 | print(value) 169 | print(value.shape) 170 | -------------------------------------------------------------------------------- /crystalformer/extension/mcmc.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | 5 | from crystalformer.src.sample import project_xyz 6 | from crystalformer.src.von_mises import sample_von_mises 7 | from crystalformer.src.lattice import symmetrize_lattice 8 | 9 | 10 | def make_mcmc_step(base_params, cond_params, model_state, n_max, atom_types, atom_mask=None, constraints=None): 11 | 12 | if atom_mask is None or jnp.all(atom_mask == 0): 13 | atom_mask = jnp.ones((n_max, atom_types)) 14 | 15 | if constraints is None: 16 | constraints = jnp.arange(0, n_max, 1) 17 | 18 | def update_A(i, A, a, constraints): 19 | def body_fn(j, A): 20 | A = jax.lax.cond(constraints[j] == constraints[i], 21 | lambda _: A.at[:, j].set(a), 22 | lambda _: A, 23 | None) 24 | return A 25 | 26 | A = jax.lax.fori_loop(0, A.shape[1], body_fn, A) 27 | return A 28 | 29 | @partial(jax.jit, static_argnums=0) 30 | def mcmc(logp_fn, x_init, key, mc_steps, mc_width, temp): 31 | """ 32 | Markov Chain Monte Carlo sampling algorithm. 33 | 34 | INPUT: 35 | logp_fn: callable that evaluates the log-probability of a batch of configurations x. 36 | The signature is logp_fn(base_params, cond_params, model_state, key, G, L, XYZ, A, W, is_training). 37 | x_init: initial value of x = (G, L, XYZ, A, W), each with a leading batch dimension. 38 | key: initial PRNG key. 39 | mc_steps: total number of Monte Carlo steps. 40 | mc_width: size of the Monte Carlo proposal. 41 | temp: temperature in the simulated annealing. 42 | 43 | OUTPUT: 44 | x: resulting batch samples, with the same shape as `x_init`.
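accept_rate: average fraction of accepted proposals over the run (normalised by batch size).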
45 | """ 46 | 47 | def update_lattice(i, state): 48 | def update_length(key, L): 49 | length, angle = jnp.split(L, 2, axis=-1) 50 | length += jax.random.normal(key, length.shape) * mc_width 51 | return jnp.concatenate([length, angle], axis=-1) 52 | 53 | def update_angle(key, L): 54 | length, angle = jnp.split(L, 2, axis=-1) 55 | angle += jax.random.normal(key, angle.shape) * mc_width 56 | return jnp.concatenate([length, angle], axis=-1) 57 | 58 | x, logp, key, num_accepts, temp = state 59 | G, L, XYZ, A, W = x 60 | key, key_proposal_L, key_accept, key_logp = jax.random.split(key, 4) 61 | 62 | L_proposal = jax.lax.cond(i % (n_max+2) % n_max == 0, 63 | lambda _: update_length(key_proposal_L, L), 64 | lambda _: update_angle(key_proposal_L, L), 65 | None) 66 | 67 | length, angle = jnp.split(L_proposal, 2, axis=-1) 68 | angle = jnp.rad2deg(angle) # change the unit to degree 69 | L_proposal = jnp.concatenate([length, angle], axis=-1) 70 | L_proposal = jax.vmap(symmetrize_lattice, (0, 0))(G, L_proposal) 71 | 72 | length, angle = jnp.split(L_proposal, 2, axis=-1) 73 | angle = jnp.deg2rad(angle) # change the unit to rad 74 | L_proposal = jnp.concatenate([length, angle], axis=-1) 75 | 76 | x_proposal = (G, L_proposal, XYZ, A, W) 77 | logp_proposal = logp_fn(base_params, cond_params, model_state, key_logp, *x_proposal, False) 78 | 79 | ratio = jnp.exp((logp_proposal - logp)/ temp) 80 | accept = jax.random.uniform(key_accept, ratio.shape) < ratio 81 | 82 | L_new = jnp.where(accept[:, None], L_proposal, L) # update lattice 83 | x_new = (G, L_new, XYZ, A, W) 84 | logp_new = jnp.where(accept, logp_proposal, logp) 85 | num_accepts += jnp.sum(accept) 86 | 87 | jax.debug.print("logp {x} {y}", 88 | x=logp_new.mean(), 89 | y=jnp.std(logp_new)/jnp.sqrt(logp_new.shape[0]) 90 | ) 91 | return x_new, logp_new, key, num_accepts, temp 92 | 93 | 94 | def update_a_xyz(i, state): 95 | def true_func(i, state): 96 | x, logp, key, num_accepts, temp = state 97 | G, L, XYZ, A, W = x 98 | key, key_proposal_A, key_proposal_XYZ, key_accept, key_logp = jax.random.split(key, 5) 99 | 100 | p_normalized = atom_mask[i%n_max] / jnp.sum(atom_mask[i%n_max]) # only propose atom types that are allowed 101 | _a = jax.random.choice(key_proposal_A, a=atom_types, p=p_normalized, shape=(A.shape[0], )) 102 | # _A = A.at[:, i%n_max].set(_a) 103 | _A = update_A(i%n_max, A, _a, constraints) 104 | A_proposal = jnp.where(A == 0, A, _A) 105 | 106 | _xyz = XYZ[:, i%n_max] + sample_von_mises(key_proposal_XYZ, 0, 1/mc_width**2, XYZ[:, i%n_max].shape) 107 | _xyz = jax.vmap(project_xyz, in_axes=(0, 0, 0, None))(G, W[:, i%n_max], _xyz, 0) 108 | _XYZ = XYZ.at[:, i%n_max].set(_xyz) 109 | _XYZ -= jnp.floor(_XYZ) # wrap to [0, 1) 110 | XYZ_proposal = _XYZ 111 | x_proposal = (G, L, XYZ_proposal, A_proposal, W) 112 | 113 | logp_proposal = logp_fn(base_params, cond_params, model_state, key_logp, *x_proposal, False) 114 | 115 | ratio = jnp.exp((logp_proposal - logp)/ temp) 116 | accept = jax.random.uniform(key_accept, ratio.shape) < ratio 117 | 118 | A_new = jnp.where(accept[:, None], A_proposal, A) # update atom types 119 | XYZ_new = jnp.where(accept[:, None, None], XYZ_proposal, XYZ) # update atom positions 120 | x_new = (G, L, XYZ_new, A_new, W) 121 | logp_new = jnp.where(accept, logp_proposal, logp) 122 | num_accepts += jnp.sum(accept*jnp.where(A[:, i%n_max]==0, 0, 1)) 123 | 124 | jax.debug.print("logp {x} {y}", 125 | x=logp_new.mean(), 126 | y=jnp.std(logp_new)/jnp.sqrt(logp_new.shape[0]) 127 | ) 128 | return x_new, logp_new, key, num_accepts, temp 129 | 
130 | def false_func(i, state): 131 | return state 132 | 133 | x, logp, key, num_accepts, temp = state 134 | A = x[3] 135 | x, logp, key, num_accepts, temp = jax.lax.cond(A[:, i%(n_max+2)%n_max].sum() != 0, 136 | lambda _: true_func(i, state), 137 | lambda _: false_func(i, state), 138 | None) 139 | return x, logp, key, num_accepts, temp 140 | 141 | def step(i, state): 142 | x, logp, key, num_accepts, temp = jax.lax.cond(i % (n_max+2) < n_max, 143 | lambda _: update_a_xyz(i, state), 144 | lambda _: update_lattice(i, state), 145 | None) 146 | return x, logp, key, num_accepts, temp 147 | 148 | key, subkey = jax.random.split(key) 149 | logp_init = logp_fn(base_params, cond_params, model_state, subkey, *x_init, False) 150 | jax.debug.print("logp {x} {y}", 151 | x=logp_init.mean(), 152 | y=jnp.std(logp_init)/jnp.sqrt(logp_init.shape[0]), 153 | ) 154 | 155 | x, logp, key, num_accepts, temp = jax.lax.fori_loop(0, mc_steps, step, (x_init, logp_init, key, 0., temp)) 156 | A = x[3] 157 | scale = jnp.sum(A != 0)/(A.shape[0]*n_max) 158 | # accept_rate = num_accepts / (scale * mc_steps * x[0].shape[0]) 159 | accept_rate = num_accepts / (scale*mc_steps*n_max/(n_max+2) + mc_steps*2/(n_max+2)) 160 | accept_rate = accept_rate / x[0].shape[0] 161 | return x, accept_rate 162 | 163 | return mcmc 164 | 165 | 166 | if __name__ == "__main__": 167 | from crystalformer.src.utils import GLXYZAW_from_file 168 | from crystalformer.src.loss import make_loss_fn 169 | from crystalformer.src.transformer import make_transformer 170 | atom_types = 119 171 | n_max = 21 172 | wyck_types = 28 173 | Nf = 5 174 | Kx = 16 175 | Kl = 4 176 | dropout_rate = 0.3 177 | 178 | csv_file = '../data/mini.csv' 179 | G, L, XYZ, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max) 180 | 181 | key = jax.random.PRNGKey(42) 182 | 183 | params, transformer = make_transformer(key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16, atom_types, wyck_types, dropout_rate) 184 | 185 | loss_fn, logp_fn = make_loss_fn(n_max, atom_types, wyck_types, Kx, Kl, transformer) 186 | 187 | # MCMC sampling test 188 | mc_steps = 23 189 | mc_width = 0.1 190 | x_init = (G[:5], L[:5], XYZ[:5], A[:5], W[:5]) 191 | 192 | value = jax.jit(logp_fn, static_argnums=7)(params, key, *x_init, False) 193 | 194 | # wrap the unconditional logp_fn to match the (base_params, cond_params, model_state, key, ...) signature that mcmc expects; a real guided run would pass the conditional logp built by extension.loss.make_cond_logp instead 195 | def test_logp_fn(base_params, cond_params, model_state, key, G, L, XYZ, A, W, is_training): 196 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(base_params, key, G, L, XYZ, A, W, is_training) 197 | return logp_w + logp_xyz + logp_a + logp_l 198 | 199 | jnp.set_printoptions(threshold=jnp.inf) 200 | mcmc = make_mcmc_step(params, None, None, n_max=n_max, atom_types=atom_types) 201 | 202 | for i in range(5): 203 | key, subkey = jax.random.split(key) 204 | x, acc = mcmc(test_logp_fn, x_init=x_init, key=subkey, mc_steps=mc_steps, mc_width=mc_width, temp=1.0) 205 | print(i, acc) 206 | 207 | print("check if the lattice is changed") 208 | print(x_init[1]) 209 | print(x[1]) 210 | 211 | print("check if the atom position is changed") 212 | print(x_init[2]) 213 | print(x[2]) -------------------------------------------------------------------------------- /crystalformer/extension/model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import haiku as hk 4 | 5 | 6 | def make_classifier(key, 7 | n_max = 21, 8 | embed_size=32, 9 | sequence_length=105, 10 | outputs_size=64, 11 | hidden_sizes=[128, 128], 12 | num_classes=1, 13 | dropout_rate=0.3): 14 | 15 | @hk.transform 16 | def network(g, l, w, h, is_training): 17 | """ 18 | sequence_length = n_max * 5 19 | g : (embed_size, ) 20 | l : (6, ) 21 | w : (n_max,) 22 | h : (sequence_length, outputs_size) 23 | """ 24 | mask = jnp.where(w > 0, 1, 0) 25 | mask = jnp.repeat(mask, 5, axis=-1) 26 | # mask = hk.Reshape((sequence_length,
))(mask) 27 | h = h * mask[:, None] 28 | 29 | w = jnp.mean(h[0::5, :], axis=-2) 30 | a = jnp.mean(h[1::5, :], axis=-2) 31 | xyz = jnp.mean(h[2::5, :], axis=-2) + jnp.mean(h[3::5, :], axis=-2) + jnp.mean(h[4::5, :], axis=-2) 32 | 33 | h = jnp.concatenate([w, a, xyz], axis=0) 34 | h = hk.Flatten()(h) 35 | 36 | h = jnp.concatenate([g, h, l], axis=0) 37 | 38 | h = jax.nn.relu(hk.Linear(hidden_sizes[0])(h)) 39 | h = hk.dropout(hk.next_rng_key(), dropout_rate, h) if is_training else h # Dropout after the first ReLU 40 | 41 | for hidden_size in hidden_sizes[1: -1]: 42 | h_dense = jax.nn.relu(hk.Linear(hidden_size)(h)) 43 | h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense) if is_training else h_dense 44 | h = h + h_dense 45 | 46 | h = hk.Linear(hidden_sizes[-1])(h) 47 | h = jax.nn.relu(h) 48 | h = hk.Linear(num_classes)(h) 49 | 50 | return h 51 | 52 | g = jnp.ones(embed_size) 53 | w = jnp.ones(n_max) 54 | l = jnp.ones(6) 55 | h = jnp.zeros((sequence_length, outputs_size)) 56 | 57 | params = network.init(key, g, l, w, h, True) 58 | return params, network.apply -------------------------------------------------------------------------------- /crystalformer/extension/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jax 3 | import optax 4 | import math 5 | 6 | import crystalformer.src.checkpoint as checkpoint 7 | 8 | 9 | def train(key, optimizer, opt_state, loss_fn, params, state, epoch_finished, epochs, batchsize, train_data, valid_data, path): 10 | 11 | @jax.jit 12 | def update(params, state, key, opt_state, data): 13 | G, L, X, A, W, labels = data 14 | value, grad = jax.value_and_grad(loss_fn)(params, state, key, G, L, X, A, W, labels, True) 15 | updates, opt_state = optimizer.update(grad, opt_state, params) 16 | params = optax.apply_updates(params, updates) 17 | return params, opt_state, value 18 | 19 | log_filename = os.path.join(path, "data.txt") 20 | f = open(log_filename, "w" if epoch_finished == 0 else "a", buffering=1, newline="\n") 21 | if os.path.getsize(log_filename) == 0: 22 | f.write("epoch t_loss v_loss\n") 23 | 24 | for epoch in range(epoch_finished+1, epochs): 25 | key, subkey = jax.random.split(key) 26 | train_data = jax.tree_util.tree_map(lambda x: jax.random.permutation(subkey, x), train_data) 27 | 28 | train_G, train_L, train_X, train_A, train_W, train_labels = train_data 29 | 30 | train_loss = 0.0 31 | num_samples = len(train_labels) 32 | num_batches = math.ceil(num_samples / batchsize) 33 | for batch_idx in range(num_batches): 34 | start_idx = batch_idx * batchsize 35 | end_idx = min(start_idx + batchsize, num_samples) 36 | data = train_G[start_idx:end_idx], \ 37 | train_L[start_idx:end_idx], \ 38 | train_X[start_idx:end_idx], \ 39 | train_A[start_idx:end_idx], \ 40 | train_W[start_idx:end_idx], \ 41 | train_labels[start_idx:end_idx] 42 | 43 | key, subkey = jax.random.split(key) 44 | params, opt_state, loss = update(params, state, subkey, opt_state, data) 45 | train_loss = train_loss + loss 46 | 47 | train_loss = train_loss / num_batches 48 | 49 | if epoch % 10 == 0: 50 | valid_G, valid_L, valid_X, valid_A, valid_W, valid_labels = valid_data 51 | valid_loss = 0.0 52 | num_samples = len(valid_labels) 53 | num_batches = math.ceil(num_samples / batchsize) 54 | for batch_idx in range(num_batches): 55 | start_idx = batch_idx * batchsize 56 | end_idx = min(start_idx + batchsize, num_samples) 57 | G, L, X, A, W, labels = valid_G[start_idx:end_idx], \ 58 | valid_L[start_idx:end_idx], \ 59 | 
valid_X[start_idx:end_idx], \ 60 | valid_A[start_idx:end_idx], \ 61 | valid_W[start_idx:end_idx], \ 62 | valid_labels[start_idx:end_idx] 63 | 64 | key, subkey = jax.random.split(key) 65 | loss = loss_fn(params, state, subkey, G, L, X, A, W, labels, False) 66 | valid_loss = valid_loss + loss 67 | 68 | valid_loss = valid_loss / num_batches 69 | 70 | f.write( ("%6d" + 2*" %.6f" + "\n") % (epoch, 71 | train_loss, valid_loss 72 | )) 73 | 74 | ckpt = {"params": params, 75 | "opt_state" : opt_state 76 | } 77 | ckpt_filename = os.path.join(path, "epoch_%06d.pkl" %(epoch)) 78 | checkpoint.save_data(ckpt, ckpt_filename) 79 | print("Save checkpoint file: %s" % ckpt_filename) 80 | 81 | f.close() 82 | return params, opt_state 83 | -------------------------------------------------------------------------------- /crystalformer/reinforce/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/CrystalFormer/a2ae7982234b84134950c5c7a4f49d592c83bddf/crystalformer/reinforce/__init__.py -------------------------------------------------------------------------------- /crystalformer/reinforce/dpo.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import optax 4 | import os 5 | import math 6 | 7 | from crystalformer.src.utils import shuffle 8 | import crystalformer.src.checkpoint as checkpoint 9 | 10 | 11 | def make_dpo_loss(logp_fn, beta, label_smoothing=0.0, gamma=0.0, ipo=False): 12 | 13 | # https://github.com/eric-mitchell/direct-preference-optimization/blob/f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L45-L87 14 | def dpo_logp_fn(policy_chosen_logps, 15 | policy_rejected_logps, 16 | ref_chosen_logps, 17 | ref_rejected_logps): 18 | 19 | pi_logratios = policy_chosen_logps - policy_rejected_logps 20 | ref_logratios = ref_chosen_logps - ref_rejected_logps 21 | 22 | logits = pi_logratios - ref_logratios 23 | 24 | if ipo: 25 | losses = (logits - 1/(2 * beta)) ** 2 # Eq. 
17 of https://arxiv.org/pdf/2310.12036v2.pdf 26 | else: 27 | # label_smoothing=0 gives original DPO 28 | losses = -jax.nn.log_sigmoid(beta * logits) * (1 - label_smoothing) - jax.nn.log_sigmoid(-beta * logits) * label_smoothing 29 | return jnp.mean(losses) 30 | 31 | def loss_fn(params, key, x_w, x_l, ref_chosen_logps, ref_rejected_logps): 32 | key, subkey = jax.random.split(key) 33 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(params, subkey, *x_w, False) 34 | policy_chosen_logps = logp_w + logp_xyz + logp_a + logp_l 35 | 36 | key, subkey = jax.random.split(key) 37 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(params, subkey, *x_l, False) 38 | policy_rejected_logps = logp_w + logp_xyz + logp_a + logp_l 39 | 40 | dpo_loss = dpo_logp_fn(policy_chosen_logps, 41 | policy_rejected_logps, 42 | ref_chosen_logps, 43 | ref_rejected_logps) 44 | loss = dpo_loss - gamma * jnp.mean(policy_chosen_logps) 45 | 46 | return loss, (dpo_loss, jnp.mean(policy_chosen_logps), jnp.mean(policy_rejected_logps)) 47 | 48 | return loss_fn 49 | 50 | 51 | def train(key, optimizer, opt_state, dpo_loss_fn, logp_fn, params, epoch_finished, epochs, batchsize, chosen_data, rejected_data, path, val_ratio=0.2): 52 | 53 | @jax.jit 54 | def step(params, key, opt_state, x_w, x_l, ref_chosen_logps, ref_rejected_logps): 55 | value, grad = jax.value_and_grad(dpo_loss_fn, has_aux=True)(params, key, x_w, x_l, ref_chosen_logps, ref_rejected_logps) 56 | updates, opt_state = optimizer.update(grad, opt_state, params) 57 | params = optax.apply_updates(params, updates) 58 | return params, opt_state, value 59 | 60 | log_filename = os.path.join(path, "data.txt") 61 | f = open(log_filename, "w" if epoch_finished == 0 else "a", buffering=1, newline="\n") 62 | if os.path.getsize(log_filename) == 0: 63 | f.write("epoch loss dpo_loss chosen_logp rejected_logp v_loss v_dpo_loss v_chosen_logp v_rejected_logp\n") 64 | ref_params = params 65 | logp_fn = jax.jit(logp_fn, static_argnums=7) 66 | 67 | ref_chosen_logps = jnp.array([]) 68 | ref_rejected_logps = jnp.array([]) 69 | _, chosen_L, _, _, _ = chosen_data 70 | num_samples = len(chosen_L) 71 | num_batches = math.ceil(num_samples / batchsize) 72 | for batch_idx in range(num_batches): 73 | start_idx = batch_idx * batchsize 74 | end_idx = min(start_idx + batchsize, num_samples) 75 | key, subkey1, subkey2 = jax.random.split(key, 3) 76 | 77 | data = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], chosen_data) 78 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(ref_params, subkey1, *data, False) 79 | logp = logp_w + logp_xyz + logp_a + logp_l 80 | ref_chosen_logps = jnp.append(ref_chosen_logps, logp, axis=0) 81 | 82 | data = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], rejected_data) 83 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(ref_params, subkey2, *data, False) 84 | logp = logp_w + logp_xyz + logp_a + logp_l 85 | ref_rejected_logps = jnp.append(ref_rejected_logps, logp, axis=0) 86 | 87 | print(ref_chosen_logps.shape, ref_rejected_logps.shape) 88 | print(f"ref_chosen_logps: {jnp.mean(ref_chosen_logps)}, ref_rejected_logps: {jnp.mean(ref_rejected_logps)}") 89 | print("Finished calculating reference logp") 90 | 91 | # Shuffle the data 92 | key, subkey = jax.random.split(key) 93 | idx = jax.random.permutation(subkey, jnp.arange(num_samples)) 94 | chosen_data = jax.tree_util.tree_map(lambda x: x[idx], chosen_data) 95 | rejected_data = jax.tree_util.tree_map(lambda x: x[idx], rejected_data) 96 | ref_chosen_logps = ref_chosen_logps[idx] 97 | ref_rejected_logps = ref_rejected_logps[idx] 98 | 99 | # 
Split the data into training and validation 100 | num_val_samples = int(num_samples * val_ratio) 101 | num_train_samples = num_samples - num_val_samples 102 | print("num_train_samples: %d, num_val_samples: %d" % (num_train_samples, num_val_samples)) 103 | 104 | train_chosen_data = jax.tree_util.tree_map(lambda x: x[:num_train_samples], chosen_data) 105 | train_rejected_data = jax.tree_util.tree_map(lambda x: x[:num_train_samples], rejected_data) 106 | train_ref_chosen_logps = ref_chosen_logps[:num_train_samples] 107 | train_ref_rejected_logps = ref_rejected_logps[:num_train_samples] 108 | 109 | val_chosen_data = jax.tree_util.tree_map(lambda x: x[num_train_samples:], chosen_data) 110 | val_rejected_data = jax.tree_util.tree_map(lambda x: x[num_train_samples:], rejected_data) 111 | val_ref_chosen_logps = ref_chosen_logps[num_train_samples:] 112 | val_ref_rejected_logps = ref_rejected_logps[num_train_samples:] 113 | 114 | 115 | for epoch in range(epoch_finished+1, epochs+1): 116 | key, subkey = jax.random.split(key) 117 | train_chosen_data = shuffle(subkey, train_chosen_data) 118 | train_rejected_data = shuffle(subkey, train_rejected_data) 119 | 120 | idx = jax.random.permutation(subkey, jnp.arange(len(train_ref_chosen_logps))) 121 | train_ref_chosen_logps = train_ref_chosen_logps[idx] 122 | train_ref_rejected_logps = train_ref_rejected_logps[idx] 123 | 124 | train_loss = 0.0 125 | train_dpo_loss = 0.0 126 | train_policy_chosen_logps = 0.0 127 | train_policy_rejected_logps = 0.0 128 | _, chosen_L, _, _, _ = train_chosen_data 129 | num_samples = chosen_L.shape[0] 130 | num_batches = math.ceil(num_samples / batchsize) 131 | for batch_idx in range(num_batches): 132 | start_idx = batch_idx * batchsize 133 | end_idx = min(start_idx + batchsize, num_samples) 134 | x_w = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], train_chosen_data) 135 | x_l = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], train_rejected_data) 136 | ref_chosen_logps_batch = train_ref_chosen_logps[start_idx:end_idx] 137 | ref_rejected_logps_batch = train_ref_rejected_logps[start_idx:end_idx] 138 | 139 | key, subkey = jax.random.split(key) 140 | params, opt_state, value = step(params, subkey, opt_state, x_w, x_l, ref_chosen_logps_batch, ref_rejected_logps_batch) 141 | loss, (dpo_loss, policy_chosen_logps, policy_rejected_logps) = value 142 | train_loss += loss 143 | train_dpo_loss += dpo_loss 144 | train_policy_chosen_logps += policy_chosen_logps 145 | train_policy_rejected_logps += policy_rejected_logps 146 | 147 | train_loss /= num_batches 148 | train_dpo_loss /= num_batches 149 | train_policy_chosen_logps /= num_batches 150 | train_policy_rejected_logps /= num_batches 151 | f.write( ("%6d" + 4*" %.6f") % (epoch, train_loss, train_dpo_loss, train_policy_chosen_logps, train_policy_rejected_logps)) 152 | 153 | # Validation 154 | val_loss = 0.0 155 | val_dpo_loss = 0.0 156 | val_policy_chosen_logps = 0.0 157 | val_policy_rejected_logps = 0.0 158 | num_val_samples = len(val_ref_chosen_logps) 159 | num_batches = math.ceil(num_val_samples / batchsize) 160 | for batch_idx in range(num_batches): 161 | start_idx = batch_idx * batchsize 162 | end_idx = min(start_idx + batchsize, num_val_samples) 163 | x_w = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], val_chosen_data) 164 | x_l = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], val_rejected_data) 165 | ref_chosen_logps_batch = val_ref_chosen_logps[start_idx:end_idx] 166 | ref_rejected_logps_batch = val_ref_rejected_logps[start_idx:end_idx] 167 | 168 | key, 
subkey = jax.random.split(key) 169 | loss, (dpo_loss, policy_chosen_logps, policy_rejected_logps) = jax.jit(dpo_loss_fn)(params, subkey, x_w, x_l, ref_chosen_logps_batch, ref_rejected_logps_batch) 170 | val_loss += loss 171 | val_dpo_loss += dpo_loss 172 | val_policy_chosen_logps += policy_chosen_logps 173 | val_policy_rejected_logps += policy_rejected_logps 174 | 175 | val_loss /= num_batches 176 | val_dpo_loss /= num_batches 177 | val_policy_chosen_logps /= num_batches 178 | val_policy_rejected_logps /= num_batches 179 | f.write( (4*" %.6f" + "\n") % (val_loss, val_dpo_loss, val_policy_chosen_logps, val_policy_rejected_logps)) 180 | 181 | 182 | if epoch % 1 == 0: 183 | ckpt = {"params": params, 184 | "opt_state" : opt_state 185 | } 186 | ckpt_filename = os.path.join(path, "epoch_%06d.pkl" %(epoch)) 187 | checkpoint.save_data(ckpt, ckpt_filename) 188 | print("Save checkpoint file: %s" % ckpt_filename) 189 | 190 | f.close() 191 | 192 | return params, opt_state 193 | -------------------------------------------------------------------------------- /crystalformer/reinforce/ehull.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pymatgen.core import Structure 3 | from pymatgen.entries.computed_entries import ComputedStructureEntry, ComputedEntry 4 | from pymatgen.analysis.phase_diagram import PhaseDiagram 5 | from pymatgen.entries.compatibility import MaterialsProject2020Compatibility 6 | from pymatgen.io.vasp.sets import MPRelaxSet 7 | from pymatgen.io.vasp.inputs import Incar, Poscar 8 | 9 | 10 | def generate_CSE(structure, m3gnet_energy): 11 | # Write VASP inputs files as if we were going to do a standard MP run 12 | # this is mainly necessary to get the right U values / etc 13 | b = MPRelaxSet(structure) 14 | with tempfile.TemporaryDirectory() as tmpdirname: 15 | b.write_input(f"{tmpdirname}/", potcar_spec=True) 16 | poscar = Poscar.from_file(f"{tmpdirname}/POSCAR") 17 | incar = Incar.from_file(f"{tmpdirname}/INCAR") 18 | clean_structure = Structure.from_file(f"{tmpdirname}/POSCAR") 19 | 20 | # Get the U values and figure out if we should have run a GGA+U calc 21 | param = {"hubbards": {}} 22 | if "LDAUU" in incar: 23 | param["hubbards"] = dict(zip(poscar.site_symbols, incar["LDAUU"])) 24 | param["is_hubbard"] = ( 25 | incar.get("LDAU", True) and sum(param["hubbards"].values()) > 0 26 | ) 27 | if param["is_hubbard"]: 28 | param["run_type"] = "GGA+U" 29 | 30 | # Make a ComputedStructureEntry without the correction 31 | cse_d = { 32 | "structure": clean_structure, 33 | "energy": m3gnet_energy, 34 | "correction": 0.0, 35 | "parameters": param, 36 | } 37 | 38 | # Apply the MP 2020 correction scheme (anion/+U/etc) 39 | cse = ComputedStructureEntry.from_dict(cse_d) 40 | _ = MaterialsProject2020Compatibility(check_potcar=False).process_entries( 41 | cse, 42 | clean=True, 43 | ) 44 | 45 | # Return the final CSE (notice that the composition/etc is also clean, not things like Fe3+)! 
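    # A minimal usage sketch (hypothetical energy value; assumes `structure`
    # is a pymatgen Structure and `m3gnet_energy` is an MLFF total energy in eV):
    #   cse = generate_CSE(structure, m3gnet_energy=-42.0)
    #   print(cse.energy_per_atom, cse.parameters.get("run_type"))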
46 | return cse 47 | 48 | 49 | def calculate_hull(structure, energy, entries): 50 | entries = [ComputedEntry.from_dict(i) for i in entries] 51 | pd = PhaseDiagram(entries) 52 | 53 | entry = generate_CSE(structure, energy) 54 | ehull = pd.get_e_above_hull(entry, allow_negative=True) 55 | 56 | return ehull 57 | 58 | 59 | def forward_fn(structure, energy, ref_data): 60 | 61 | comp = structure.composition 62 | elements = set(ii.name for ii in comp.elements) 63 | 64 | # filter entries by elements 65 | entries = [entry for entry in ref_data['entries'] if set(entry['data']['elements']) <= elements] 66 | ehull = calculate_hull(structure, energy, entries) 67 | 68 | return ehull 69 | -------------------------------------------------------------------------------- /crystalformer/reinforce/potential.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ase.calculators.calculator import Calculator, all_changes 4 | from ase.neighborlist import NeighborList 5 | from ase.stress import full_3x3_to_voigt_6_stress 6 | 7 | 8 | class ExponentialPotential(Calculator): 9 | """ 10 | Exponential potential for ASE. 11 | u(r) = exp(-alpha * r) 12 | 13 | https://gitlab.com/ase/ase/-/blob/master/ase/calculators/lj.py?ref_type=heads 14 | """ 15 | 16 | implemented_properties = ['energy', 'energies', 'forces', 'free_energy'] 17 | implemented_properties += ['stress', 'stresses'] # bulk properties 18 | default_parameters = { 19 | 'alpha': 1.0, 20 | 'rc': None, 21 | 'ro': None, 22 | 'smooth': False, 23 | } 24 | nolabel = True 25 | 26 | def __init__(self, **kwargs): 27 | """ 28 | Parameters 29 | ---------- 30 | alpha: float 31 | The decay constant of the exponential potential, default 1.0 32 | rc: float, None 33 | Cut-off for the NeighborList. The energy is upshifted to be continuous at rc. 34 | Default None 35 | ro: float, None 36 | Onset of cutoff function in 'smooth' mode. Defaults to 0.66 * rc. 37 | smooth: bool, False 38 | Cutoff mode. False means that the pairwise energy is simply shifted 39 | to be 0 at r = rc, leading to the energy going to 0 continuously, 40 | but the forces jumping to zero discontinuously at the cutoff. 41 | True means that a smooth cutoff function is multiplied to the pairwise 42 | energy that smoothly goes to 0 between ro and rc. Both energy and 43 | forces are continuous in that case. 44 | If smooth=True, make sure to check the tail of the 45 | forces for kinks, ro might have to be adjusted to avoid distorting 46 | the potential too much. 
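        In other words, with smooth=True the pair energy used in `calculate`
        is effectively u(r) * fc(r^2), where fc is the C^1 cutoff implemented
        by `cutoff_function` below; with smooth=False it is u(r) - u(rc)
        inside the cutoff.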
47 | 
48 |         """
49 | 
50 |         Calculator.__init__(self, **kwargs)
51 | 
52 |         if self.parameters.rc is None:
53 |             self.parameters.rc = 10.0  # Choose an appropriate rc for your system
54 | 
55 |         if self.parameters.ro is None:
56 |             self.parameters.ro = 0.66 * self.parameters.rc
57 | 
58 |         self.nl = None
59 | 
60 |     def calculate(
61 |         self,
62 |         atoms=None,
63 |         properties=None,
64 |         system_changes=all_changes,
65 |     ):
66 |         if properties is None:
67 |             properties = self.implemented_properties
68 | 
69 |         Calculator.calculate(self, atoms, properties, system_changes)
70 | 
71 |         natoms = len(self.atoms)
72 | 
73 |         alpha = self.parameters.alpha
74 |         rc = self.parameters.rc
75 |         ro = self.parameters.ro
76 |         smooth = self.parameters.smooth
77 | 
78 |         if self.nl is None or 'numbers' in system_changes:
79 |             self.nl = NeighborList(
80 |                 [rc / 2] * natoms, self_interaction=False, bothways=True
81 |             )
82 | 
83 |         self.nl.update(self.atoms)
84 | 
85 |         positions = self.atoms.positions
86 |         cell = self.atoms.cell
87 | 
88 |         # potential value at rc
89 |         e0 = np.exp(-alpha * rc)
90 | 
91 |         energies = np.zeros(natoms)
92 |         forces = np.zeros((natoms, 3))
93 |         stresses = np.zeros((natoms, 3, 3))
94 | 
95 |         for ii in range(natoms):
96 |             neighbors, offsets = self.nl.get_neighbors(ii)
97 |             cells = np.dot(offsets, cell)
98 | 
99 |             # pointing *towards* neighbours
100 |             distance_vectors = positions[neighbors] + cells - positions[ii]
101 | 
102 |             r = np.sqrt((distance_vectors ** 2).sum(1))
103 |             r[r > rc] = np.inf  # Exclude pairs beyond cutoff
104 | 
105 |             if smooth:
106 |                 cutoff_fn = cutoff_function(r ** 2, rc ** 2, ro ** 2)
107 |                 d_cutoff_fn = d_cutoff_function(r ** 2, rc ** 2, ro ** 2)
108 | 
109 |             pairwise_energies = np.exp(-alpha * r)
110 |             pairwise_forces = -alpha * np.exp(-alpha * r) / r  # (du_ij / dr) / r_ij
111 | 
112 |             if smooth:
113 |                 pairwise_forces = (
114 |                     cutoff_fn * pairwise_forces + 2 * d_cutoff_fn
115 |                     * pairwise_energies
116 |                 )
117 |                 pairwise_energies *= cutoff_fn
118 |             else:
119 |                 pairwise_energies -= e0 * np.isfinite(r)  # shift only pairs inside the cutoff; excluded pairs have r = inf
120 | 
121 |             pairwise_forces = pairwise_forces[:, np.newaxis] * distance_vectors
122 | 
123 |             energies[ii] += 0.5 * pairwise_energies.sum()  # atomic energies
124 |             forces[ii] += pairwise_forces.sum(axis=0)
125 | 
126 |             stresses[ii] += 0.5 * np.dot(
127 |                 pairwise_forces.T, distance_vectors
128 |             )  # equivalent to outer product
129 | 
130 |         # no lattice, no stress
131 |         if self.atoms.cell.rank == 3:
132 |             stresses = full_3x3_to_voigt_6_stress(stresses)
133 |             self.results['stress'] = stresses.sum(
134 |                 axis=0) / self.atoms.get_volume()
135 |             self.results['stresses'] = stresses / self.atoms.get_volume()
136 | 
137 |         energy = energies.sum()
138 |         self.results['energy'] = energy
139 |         self.results['energies'] = energies
140 | 
141 |         self.results['free_energy'] = energy
142 | 
143 |         self.results['forces'] = forces
144 | 
145 | 
146 | def cutoff_function(r, rc, ro):
147 |     """Smooth cutoff function.
148 | 
149 |     Goes from 1 to 0 between ro and rc, ensuring
150 |     that u(r) = exp(-alpha * r) * cutoff_function(r) is C^1.
151 | 
152 |     Defined as 1 below ro, 0 above rc.
153 | 
154 |     Note that r, rc, ro are all expected to be squared,
155 |     i.e. `r = r_ij^2`, etc.
156 | 
157 |     """
158 | 
159 |     return np.where(
160 |         r < ro,
161 |         1.0,
162 |         np.where(r < rc, (rc - r) ** 2 * (rc + 2 *
163 |                                           r - 3 * ro) / (rc - ro) ** 3, 0.0),
164 |     )
165 | 
166 | 
167 | def d_cutoff_function(r, rc, ro):
168 |     """Derivative of smooth cutoff function wrt r.
169 | 
170 |     Note that `r = r_ij^2`, so for the derivative wrt `r_ij`,
171 |     we need to multiply `2*r_ij`.
This gives rise to the factor 2 172 | above, the `r_ij` is cancelled out by the remaining derivative 173 | `d r_ij / d d_ij`, i.e. going from scalar distance to distance vector. 174 | """ 175 | 176 | return np.where( 177 | r < ro, 178 | 0.0, 179 | np.where(r < rc, 6 * (rc - r) * (ro - r) / (rc - ro) ** 3, 0.0), 180 | ) 181 | -------------------------------------------------------------------------------- /crystalformer/reinforce/ppo.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import os 4 | import optax 5 | import math 6 | from functools import partial 7 | 8 | import crystalformer.src.checkpoint as checkpoint 9 | from crystalformer.src.lattice import norm_lattice 10 | 11 | 12 | def make_ppo_loss_fn(logp_fn, eps_clip, beta=0.1): 13 | 14 | """ 15 | PPO clipped objective function with KL divergence regularization 16 | PPO_loss = PPO-clip + beta * KL(P || P_pretrain) 17 | 18 | Note that we only consider the logp_xyz and logp_l in the logp_fn 19 | """ 20 | 21 | def ppo_loss_fn(params, key, x, old_logp, pretrain_logp, advantages): 22 | 23 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(params, key, *x, False) 24 | logp = logp_w + logp_xyz + logp_a + logp_l 25 | 26 | kl_loss = logp - pretrain_logp 27 | advantages = advantages - beta * kl_loss 28 | 29 | # Finding the ratio (pi_theta / pi_theta__old) 30 | ratios = jnp.exp(logp - old_logp) 31 | 32 | # Finding Surrogate Loss 33 | surr1 = ratios * advantages 34 | surr2 = jax.lax.clamp(1-eps_clip, ratios, 1+eps_clip) * advantages 35 | 36 | # Final loss of clipped objective PPO 37 | ppo_loss = jnp.mean(jnp.minimum(surr1, surr2)) 38 | 39 | return ppo_loss, (jnp.mean(kl_loss)) 40 | 41 | return ppo_loss_fn 42 | 43 | 44 | def train(key, optimizer, opt_state, spg_mask, loss_fn, logp_fn, batch_reward_fn, ppo_loss_fn, sample_crystal, params, epoch_finished, epochs, ppo_epochs, batchsize, valid_data, path): 45 | 46 | num_devices = jax.local_device_count() 47 | batch_per_device = batchsize // num_devices 48 | shape_prefix = (num_devices, batch_per_device) 49 | print("num_devices: ", num_devices) 50 | print("batch_per_device: ", batch_per_device) 51 | print("shape_prefix: ", shape_prefix) 52 | 53 | @partial(jax.pmap, axis_name="p", in_axes=(None, None, None, 0, 0, 0, 0), out_axes=(None, None, 0),) 54 | def step(params, key, opt_state, x, old_logp, pretrain_logp, advantages): 55 | value, grad = jax.value_and_grad(ppo_loss_fn, has_aux=True)(params, key, x, old_logp, pretrain_logp, advantages) 56 | grad = jax.lax.pmean(grad, axis_name="p") 57 | value = jax.lax.pmean(value, axis_name="p") 58 | grad = jax.tree_util.tree_map(lambda g_: g_ * -1.0, grad) # invert gradient for maximization 59 | updates, opt_state = optimizer.update(grad, opt_state, params) 60 | params = optax.apply_updates(params, updates) 61 | return params, opt_state, value 62 | 63 | log_filename = os.path.join(path, "data.txt") 64 | f = open(log_filename, "w" if epoch_finished == 0 else "a", buffering=1, newline="\n") 65 | if os.path.getsize(log_filename) == 0: 66 | f.write("epoch f_mean f_err v_loss v_loss_w v_loss_a v_loss_xyz v_loss_l\n") 67 | pretrain_params = params 68 | logp_fn = jax.jit(logp_fn, static_argnums=7) 69 | loss_fn = jax.jit(loss_fn, static_argnums=7) 70 | 71 | for epoch in range(epoch_finished+1, epochs+1): 72 | 73 | key, subkey1, subkey2 = jax.random.split(key, 3) 74 | G = jax.random.choice(subkey1, 75 | a=jnp.arange(1, 231, 1), 76 | p=spg_mask, 77 | shape=(batchsize, )) 78 | XYZ, A, W, _, L = 
sample_crystal(subkey2, params, G) 79 | 80 | x = (G, L, XYZ, A, W) 81 | rewards = - batch_reward_fn(x) # inverse reward 82 | f_mean = jnp.mean(rewards) 83 | f_err = jnp.std(rewards) / jnp.sqrt(batchsize) 84 | 85 | # running average baseline 86 | baseline = f_mean if epoch == epoch_finished+1 else 0.95 * baseline + 0.05 * f_mean 87 | advantages = rewards - baseline 88 | 89 | f.write( ("%6d" + 2*" %.6f") % (epoch, f_mean, f_err)) 90 | 91 | G, L, XYZ, A, W = x 92 | L = norm_lattice(G, W, L) 93 | x = (G, L, XYZ, A, W) 94 | 95 | key, subkey1, subkey2 = jax.random.split(key, 3) 96 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(params, subkey1, *x, False) 97 | old_logp = logp_w + logp_xyz + logp_a + logp_l 98 | 99 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(pretrain_params, subkey2, *x, False) 100 | pretrain_logp = logp_w + logp_xyz + logp_a + logp_l 101 | 102 | x = jax.tree_util.tree_map(lambda _x: _x.reshape(shape_prefix + _x.shape[1:]), x) 103 | old_logp = old_logp.reshape(shape_prefix + old_logp.shape[1:]) 104 | pretrain_logp = pretrain_logp.reshape(shape_prefix + pretrain_logp.shape[1:]) 105 | advantages = advantages.reshape(shape_prefix + advantages.shape[1:]) 106 | 107 | for _ in range(ppo_epochs): 108 | key, subkey = jax.random.split(key) 109 | params, opt_state, value = step(params, subkey, opt_state, x, old_logp, pretrain_logp, advantages) 110 | ppo_loss, (kl_loss) = value 111 | print(f"epoch {epoch}, loss {jnp.mean(ppo_loss):.6f} {jnp.mean(kl_loss):.6f}") 112 | 113 | valid_loss = 0.0 114 | valid_aux = 0.0, 0.0, 0.0, 0.0 115 | num_samples = len(valid_data[0]) 116 | num_batches = math.ceil(num_samples / batchsize) 117 | for batch_idx in range(num_batches): 118 | start_idx = batch_idx * batchsize 119 | end_idx = min(start_idx + batchsize, num_samples) 120 | batch_data = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], valid_data) 121 | 122 | key, subkey = jax.random.split(key) 123 | loss, aux = loss_fn(params, subkey, *batch_data, False) 124 | valid_loss, valid_aux = jax.tree_util.tree_map( 125 | lambda acc, i: acc + i, 126 | (valid_loss, valid_aux), 127 | (loss, aux) 128 | ) 129 | 130 | valid_loss, valid_aux = jax.tree_util.tree_map( 131 | lambda x: x/num_batches, 132 | (valid_loss, valid_aux) 133 | ) 134 | valid_loss_w, valid_loss_a, valid_loss_xyz, valid_loss_l = valid_aux 135 | f.write( (5*" %.6f" + "\n") % (valid_loss, 136 | valid_loss_w, 137 | valid_loss_a, 138 | valid_loss_xyz, 139 | valid_loss_l)) 140 | 141 | if epoch % 5 == 0: 142 | ckpt = {"params": params, 143 | "opt_state" : opt_state 144 | } 145 | ckpt_filename = os.path.join(path, "epoch_%06d.pkl" %(epoch)) 146 | checkpoint.save_data(ckpt, ckpt_filename) 147 | print("Save checkpoint file: %s" % ckpt_filename) 148 | 149 | f.close() 150 | 151 | return params, opt_state 152 | -------------------------------------------------------------------------------- /crystalformer/reinforce/sample.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | 5 | from crystalformer.src.lattice import symmetrize_lattice 6 | from crystalformer.src.wyckoff import mult_table 7 | from crystalformer.src.sample import project_xyz, sample_top_p, sample_x 8 | 9 | 10 | @partial(jax.vmap, in_axes=(None, None, 0, 0, 0, 0, 0, 0), out_axes=0) # batch 11 | def inference(model, params, G, W, A, X, Y, Z): 12 | XYZ = jnp.concatenate([X[:, None], 13 | Y[:, None], 14 | Z[:, None] 15 | ], 16 | axis=-1) 17 | M = mult_table[G-1, W] 18 | return 
model(params, None, G, XYZ, A, W, M, False)
19 | 
20 | 
21 | def make_sample_crystal(transformer, n_max, atom_types, wyck_types, Kx, Kl):
22 |     """
23 |     sampling function for different space groups
24 |     """
25 | 
26 |     @partial(jax.jit, static_argnums=(4, 5))
27 |     def sample_crystal(key, params, G, atom_mask, top_p, temperature):
28 | 
29 |         def body_fn(i, state):
30 |             key, W, A, X, Y, Z, L = state
31 | 
32 |             # (1) W
33 |             w_logit = inference(transformer, params, G, W, A, X, Y, Z)[:, 5*i] # (batchsize, output_size)
34 |             w_logit = w_logit[:, :wyck_types]
35 | 
36 |             key, subkey = jax.random.split(key)
37 |             w = sample_top_p(subkey, w_logit, top_p, temperature)
38 |             W = W.at[:, i].set(w)
39 | 
40 |             # (2) A
41 |             h_al = inference(transformer, params, G, W, A, X, Y, Z)[:, 5*i+1] # (batchsize, output_size)
42 |             a_logit = h_al[:, :atom_types]
43 | 
44 |             key, subkey = jax.random.split(key)
45 |             a_logit = a_logit + jnp.where(atom_mask[i, :], 0.0, -1e10) # suppress the logits of disallowed atom types (no need to normalize since we only use it for sampling, not computing logp)
46 |             a = sample_top_p(subkey, a_logit, top_p, temperature) # use T1 for the first atom type
47 |             A = A.at[:, i].set(a)
48 | 
49 |             lattice_params = h_al[:, atom_types:atom_types+Kl+2*6*Kl]
50 |             L = L.at[:, i].set(lattice_params)
51 | 
52 |             # (3) X
53 |             h_x = inference(transformer, params, G, W, A, X, Y, Z)[:, 5*i+2] # (batchsize, output_size)
54 |             key, x = sample_x(key, h_x, Kx, top_p, temperature, batchsize)
55 | 
56 |             # project to the first WP
57 |             xyz = jnp.concatenate([x[:, None],
58 |                                    jnp.zeros((batchsize, 1)),
59 |                                    jnp.zeros((batchsize, 1)),
60 |                                    ], axis=-1)
61 |             xyz = jax.vmap(project_xyz, in_axes=(0, 0, 0, None), out_axes=0)(G, w, xyz, 0)
62 |             x = xyz[:, 0]
63 |             X = X.at[:, i].set(x)
64 | 
65 |             # (4) Y
66 |             h_y = inference(transformer, params, G, W, A, X, Y, Z)[:, 5*i+3] # (batchsize, output_size)
67 |             key, y = sample_x(key, h_y, Kx, top_p, temperature, batchsize)
68 | 
69 |             # project to the first WP
70 |             xyz = jnp.concatenate([X[:, i][:, None],
71 |                                    y[:, None],
72 |                                    jnp.zeros((batchsize, 1)),
73 |                                    ], axis=-1)
74 |             xyz = jax.vmap(project_xyz, in_axes=(0, 0, 0, None), out_axes=0)(G, w, xyz, 0)
75 |             y = xyz[:, 1]
76 |             Y = Y.at[:, i].set(y)
77 | 
78 |             # (5) Z
79 |             h_z = inference(transformer, params, G, W, A, X, Y, Z)[:, 5*i+4] # (batchsize, output_size)
80 |             key, z = sample_x(key, h_z, Kx, top_p, temperature, batchsize)
81 | 
82 |             # project to the first WP
83 |             xyz = jnp.concatenate([X[:, i][:, None],
84 |                                    Y[:, i][:, None],
85 |                                    z[:, None],
86 |                                    ], axis=-1)
87 |             xyz = jax.vmap(project_xyz, in_axes=(0, 0, 0, None), out_axes=0)(G, w, xyz, 0)
88 |             z = xyz[:, 2]
89 |             Z = Z.at[:, i].set(z)
90 | 
91 |             return key, W, A, X, Y, Z, L
92 | 
93 |         # we waste computation time by always working with the maximum length sequence, but we save compilation time
94 |         batchsize = G.shape[0]
95 |         W = jnp.zeros((batchsize, n_max), dtype=int)
96 |         A = jnp.zeros((batchsize, n_max), dtype=int)
97 |         X = jnp.zeros((batchsize, n_max))
98 |         Y = jnp.zeros((batchsize, n_max))
99 |         Z = jnp.zeros((batchsize, n_max))
100 |         L = jnp.zeros((batchsize, n_max, Kl+2*6*Kl)) # we accumulate lattice params and sample lattice after
101 | 
102 |         key, W, A, X, Y, Z, L = jax.lax.fori_loop(0, n_max, body_fn, (key, W, A, X, Y, Z, L))
103 | 
104 |         M = jax.vmap(lambda g, w: mult_table[g-1, w], in_axes=(0, 0))(G, W)
105 |         num_sites = jnp.sum(A!=0, axis=1)
106 |         num_atoms = jnp.sum(M, axis=1)
107 | 
108 |         l_logit, mu, sigma = jnp.split(L[jnp.arange(batchsize), num_sites, :], [Kl, Kl+6*Kl], axis=-1)
109 | 
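        # The lattice is modeled as a Kl-component Gaussian mixture: below we
        # sample a component k from the mixture logits l_logit, then draw the
        # 6 lattice parameters from N(mu_k, (sigma_k * sqrt(temperature))^2).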
110 | key, key_k, key_l = jax.random.split(key, 3) 111 | # k is (batchsize, ) integer array whose value in [0, Kl) 112 | k = sample_top_p(key_k, l_logit, top_p, temperature) 113 | 114 | mu = mu.reshape(batchsize, Kl, 6) 115 | mu = mu[jnp.arange(batchsize), k] # (batchsize, 6) 116 | sigma = sigma.reshape(batchsize, Kl, 6) 117 | sigma = sigma[jnp.arange(batchsize), k] # (batchsize, 6) 118 | L = jax.random.normal(key_l, (batchsize, 6)) * sigma*jnp.sqrt(temperature) + mu # (batchsize, 6) 119 | 120 | #scale length according to atom number since we did reverse of that when loading data 121 | length, angle = jnp.split(L, 2, axis=-1) 122 | length = length*num_atoms[:, None]**(1/3) 123 | angle = angle * (180.0 / jnp.pi) # to deg 124 | L = jnp.concatenate([length, angle], axis=-1) 125 | 126 | #impose space group constraint to lattice params 127 | L = jax.vmap(symmetrize_lattice, (0, 0))(G, L) 128 | 129 | XYZ = jnp.concatenate([X[..., None], 130 | Y[..., None], 131 | Z[..., None] 132 | ], 133 | axis=-1) 134 | 135 | return XYZ, A, W, M, L 136 | 137 | return sample_crystal 138 | 139 | 140 | if __name__ == "__main__": 141 | from crystalformer.src.transformer import make_transformer 142 | atom_types = 119 143 | n_max = 21 144 | wyck_types = 28 145 | Nf = 5 146 | Kx = 16 147 | Kl = 4 148 | dropout_rate = 0.1 149 | 150 | key = jax.random.PRNGKey(42) 151 | params, transformer = make_transformer(key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16, atom_types, wyck_types, dropout_rate) 152 | sample_crystal = make_sample_crystal(transformer, n_max, atom_types, wyck_types, Kx, Kl) 153 | atom_mask = jnp.zeros((n_max, atom_types)) 154 | 155 | G = jnp.array([2, 12, 62, 139, 166, 194, 225]) 156 | XYZ, A, W, M, L = sample_crystal(key, params, G, atom_mask, 1.0, 1.0) 157 | print(XYZ.shape, A.shape, W.shape, M.shape, L.shape) 158 | print ("G:\n", G) # space group 159 | print ("XYZ:\n", XYZ) # fractional coordinate 160 | print ("A:\n", A) # element type 161 | print ("W:\n", W) # Wyckoff positions 162 | print ("M:\n", M) # multiplicity 163 | print ("N:\n", M.sum(axis=-1)) # total number of atoms 164 | print ("L:\n", L) # lattice 165 | -------------------------------------------------------------------------------- /crystalformer/reinforce/vanilla.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import os 4 | import optax 5 | 6 | import crystalformer.src.checkpoint as checkpoint 7 | from crystalformer.src.lattice import norm_lattice 8 | 9 | 10 | def make_reinforce_loss(batch_logp, batch_reward_fn): 11 | 12 | def loss(params, key, x, is_train): 13 | f = batch_reward_fn(x) 14 | f = jax.lax.stop_gradient(f) 15 | 16 | f_mean = jnp.mean(f) 17 | f_std = jnp.std(f)/jnp.sqrt(f.shape[0]) 18 | 19 | G, L, XYZ, A, W = x 20 | L = norm_lattice(G, W, L) 21 | x = (G, L, XYZ, A, W) 22 | 23 | # TODO: now only support for crystalformer logp 24 | logp_w, logp_xyz, logp_a, logp_l = jax.jit(batch_logp, static_argnums=7)(params, key, *x, is_train) 25 | entropy = logp_w + logp_xyz + logp_a + logp_l 26 | 27 | return -jnp.mean((f - f_mean) * entropy), (-f_mean, f_std) 28 | 29 | return loss 30 | 31 | 32 | def train(key, optimizer, opt_state, loss_fn, sample_crystal, params, epoch_finished, epochs, batchsize, path): 33 | 34 | def update(params, key, opt_state, spacegroup): 35 | @jax.jit 36 | def apply_update(grad, params, opt_state): 37 | grad = jax.tree_util.tree_map(lambda g_: g_ * -1.0, grad) # invert gradient for maximization 38 | updates, opt_state = 
optimizer.update(grad, opt_state, params) 39 | params = optax.apply_updates(params, updates) 40 | return params, opt_state 41 | 42 | key, sample_key, loss_key = jax.random.split(key, 3) 43 | XYZ, A, W, M, L = sample_crystal(sample_key, params=params, g=spacegroup, batchsize=batchsize) 44 | G = spacegroup * jnp.ones((batchsize), dtype=int) 45 | x = (G, L, XYZ, A, W) 46 | value, grad = jax.value_and_grad(loss_fn, has_aux=True)(params, loss_key, x, True) 47 | params, opt_state = apply_update(grad, params, opt_state) 48 | return params, opt_state, value 49 | 50 | log_filename = os.path.join(path, "data.txt") 51 | f = open(log_filename, "w" if epoch_finished == 0 else "a", buffering=1, newline="\n") 52 | if os.path.getsize(log_filename) == 0: 53 | f.write("epoch f_mean f_err\n") 54 | 55 | for epoch in range(epoch_finished+1, epochs): 56 | key, subkey = jax.random.split(key) 57 | params, opt_state, value = update(params, subkey, opt_state, spacegroup=1) # TODO: only for P1 for now 58 | _, (f_mean, f_err) = value 59 | 60 | f.write( ("%6d" + 2*" %.6f" + "\n") % (epoch, f_mean, f_err)) 61 | 62 | if epoch % 5 == 0: 63 | ckpt = {"params": params, 64 | "opt_state" : opt_state 65 | } 66 | ckpt_filename = os.path.join(path, "epoch_%06d.pkl" %(epoch)) 67 | checkpoint.save_data(ckpt, ckpt_filename) 68 | print("Save checkpoint file: %s" % ckpt_filename) 69 | 70 | f.close() 71 | return params, opt_state 72 | -------------------------------------------------------------------------------- /crystalformer/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/CrystalFormer/a2ae7982234b84134950c5c7a4f49d592c83bddf/crystalformer/src/__init__.py -------------------------------------------------------------------------------- /crystalformer/src/attention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/attention.py 2 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """(Multi-Head) Attention module for use in Transformer architectures.""" 17 | 18 | from typing import Optional 19 | import warnings 20 | 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | 25 | import haiku as hk 26 | 27 | from crystalformer.src.rope import sine_table, apply_rotary_embedding 28 | # from crystalformer.src.rope import RelativePosition 29 | 30 | 31 | class MultiHeadAttention(hk.Module): 32 | """Multi-headed attention (MHA) module. 33 | 34 | This module is intended for attending over sequences of vectors. 35 | 36 | Rough sketch: 37 | - Compute keys (K), queries (Q), and values (V) as projections of inputs. 38 | - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)). 39 | - Output is another projection of WV^T. 
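  - The per-head results are concatenated and linearly projected back to
    the output embedding size D'.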
40 | 
41 |   For more detail, see the original Transformer paper:
42 |   "Attention is all you need" https://arxiv.org/abs/1706.03762.
43 | 
44 |   Glossary of shapes:
45 |   - T: Sequence length.
46 |   - D: Vector (embedding) size.
47 |   - H: Number of attention heads.
48 |   """
49 | 
50 |   def __init__(
51 |       self,
52 |       num_heads: int,
53 |       key_size: int,
54 |       w_init: Optional[hk.initializers.Initializer] = None,
55 |       with_bias: bool = True,
56 |       b_init: Optional[hk.initializers.Initializer] = None,
57 |       value_size: Optional[int] = None,
58 |       model_size: Optional[int] = None,
59 |       dropout_rate: Optional[float] = 0.0,
60 |       name: Optional[str] = None,
61 |   ):
62 |     """Initialises the module.
63 | 
64 |     Args:
65 |       num_heads: Number of independent attention heads (H).
66 |       key_size: The size of keys (K) and queries used for attention.
67 |       w_init: Initialiser for weights in the linear map. If `None`, a
68 |         default variance-scaling initialiser is used (the deprecated
69 |         `w_init_scale` argument is no longer supported).
70 |       with_bias: Whether to add a bias when computing various linear
71 |         projections.
72 |       b_init: Optional initializer for bias. By default, zero.
73 |       value_size: Optional size of the value projection (V). If None, defaults
74 |         to the key size (K).
75 |       model_size: Optional size of the output embedding (D'). If None, defaults
76 |         to the key size multiplied by the number of heads (K * H).
77 |       name: Optional name for this module.
78 |     """
79 |     super().__init__(name=name)
80 |     self.num_heads = num_heads
81 |     self.key_size = key_size
82 |     self.value_size = value_size or key_size
83 |     self.model_size = model_size or key_size * num_heads
84 |     self.dropout_rate = dropout_rate
85 | 
86 |     if w_init is None:
87 |       w_init = hk.initializers.VarianceScaling(1.0)  # default scale; `w_init_scale` was removed
88 |     self.w_init = w_init
89 |     self.with_bias = with_bias
90 |     self.b_init = b_init
91 | 
92 |   def __call__(
93 |       self,
94 |       query: jax.Array,
95 |       key: jax.Array,
96 |       value: jax.Array,
97 |       mask: Optional[jax.Array] = None,
98 |       is_train: Optional[bool] = False,
99 |   ) -> jax.Array:
100 |     """Computes (optionally masked) MHA with queries, keys & values.
101 | 
102 |     This module broadcasts over zero or more 'batch-like' leading dimensions.
103 | 
104 |     Args:
105 |       query: Embeddings sequence used to compute queries; shape [..., T', D_q].
106 |       key: Embeddings sequence used to compute keys; shape [..., T, D_k].
107 |       value: Embeddings sequence used to compute values; shape [..., T, D_v].
108 |       mask: Optional mask applied to attention weights; shape [..., H=1, T', T].
109 | 
110 |     Returns:
111 |       A new sequence of embeddings, consisting of a projection of the
112 |       attention-weighted value projections; shape [..., T', D'].
113 |     """
114 | 
115 |     # In shape hints below, we suppress the leading dims [...] for brevity.
116 |     # Hence e.g. [A, B] should be read in every case as [..., A, B].
117 |     *leading_dims, sequence_length, _ = query.shape
118 |     projection = self._linear_projection
119 | 
120 |     # Compute key/query/values (overload K/Q/V to denote the respective sizes).
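    # Rotary position embeddings (RoPE) are applied to the query and key heads
    # right after these projections, so the model needs no additive positional
    # encoding elsewhere.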
121 | query_heads = projection(query, self.key_size, "query") # [T', H, Q=K] 122 | key_heads = projection(key, self.key_size, "key") # [T, H, K] 123 | value_heads = projection(value, self.value_size, "value") # [T, H, V] 124 | 125 | # Rotary Positional Embeddings 126 | sin, cos = sine_table(features=self.key_size, length=sequence_length) 127 | query_heads, key_heads = apply_rotary_embedding(query_heads, key_heads, cos, sin) 128 | 129 | # relative_position = RelativePosition(max_relative_position=sequence_length) 130 | 131 | # Compute attention weights. 132 | attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads) 133 | attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype) 134 | 135 | # # Add relative position embeddings 136 | # attn_logits += relative_position(sequence_length, sequence_length) 137 | 138 | if mask is not None: 139 | if mask.ndim != attn_logits.ndim: 140 | raise ValueError( 141 | f"Mask dimensionality {mask.ndim} must match logits dimensionality " 142 | f"{attn_logits.ndim}." 143 | ) 144 | attn_logits = jnp.where(mask, attn_logits, -1e30) 145 | attn_weights = jax.nn.softmax(attn_logits) # [H, T', T] 146 | 147 | if is_train: 148 | attn_weights = hk.dropout(hk.next_rng_key(), self.dropout_rate, attn_weights) 149 | 150 | # Weight the values by the attention and flatten the head vectors. 151 | attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads) 152 | attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1)) # [T', H*V] 153 | 154 | # Apply another projection to get the final embeddings. 155 | final_projection = hk.Linear(self.model_size, w_init=self.w_init, 156 | with_bias=self.with_bias, b_init=self.b_init) 157 | return final_projection(attn) # [T', D'] 158 | 159 | @hk.transparent 160 | def _linear_projection( 161 | self, 162 | x: jax.Array, 163 | head_size: int, 164 | name: Optional[str] = None, 165 | ) -> jax.Array: 166 | y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, 167 | with_bias=self.with_bias, b_init=self.b_init, name=name)(x) 168 | *leading_dims, _ = x.shape 169 | return y.reshape((*leading_dims, self.num_heads, head_size)) 170 | -------------------------------------------------------------------------------- /crystalformer/src/checkpoint.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import re 4 | 5 | def find_ckpt_filename(path_or_file): 6 | """ 7 | Find the latest checkpoint file in the given directory or the given file. 8 | If path_or_file is a file, it should be a checkpoint file. 9 | If path_or_file is a directory, it should contain checkpoint files. 10 | Returns the filename of the latest checkpoint file and the epoch number. 11 | 12 | Args: 13 | path_or_file (str): The directory containing checkpoint files or a checkpoint file. 14 | 15 | Returns: 16 | fname: The filename of the latest checkpoint file. 17 | epoch: The epoch number of the latest checkpoint file. 18 | """ 19 | if os.path.isfile(path_or_file): 20 | epoch = int(re.search('epoch_([0-9]*).pkl', path_or_file).group(1)) 21 | return path_or_file, epoch 22 | files = [f for f in os.listdir(path_or_file) if ('pkl' in f)] 23 | for f in sorted(files, reverse=True): 24 | fname = os.path.join(path_or_file, f) 25 | try: 26 | with open(fname, "rb") as f: 27 | pickle.load(f) 28 | epoch = int(re.search('epoch_([0-9]*).pkl', fname).group(1)) 29 | return fname, epoch 30 | except (OSError, EOFError): 31 | print('Error loading checkpoint. 
Trying next checkpoint...', fname) 32 | return None, 0 33 | 34 | def load_data(filename): 35 | with open(filename, "rb") as f: 36 | data = pickle.load(f) 37 | return data 38 | 39 | def save_data(data, filename): 40 | with open(filename, "wb") as f: 41 | pickle.dump(data, f) 42 | 43 | 44 | -------------------------------------------------------------------------------- /crystalformer/src/elements.py: -------------------------------------------------------------------------------- 1 | element_list = [ 2 | # 0 3 | 'X', 4 | # 1 5 | 'H', 'He', 6 | # 2 7 | 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 8 | # 3 9 | 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 10 | # 4 11 | 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 12 | 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 13 | # 5 14 | 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 15 | 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 16 | # 6 17 | 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 18 | 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 19 | 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 20 | 'Po', 'At', 'Rn', 21 | # 7 22 | 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 23 | 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 24 | 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 25 | 'Lv', 'Ts', 'Og'] 26 | 27 | element_dict = {value: index for index, value in enumerate(element_list)} 28 | 29 | # radioactive elements 30 | radioactive_elements = [ 'Tc', 'Pm', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 31 | 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 32 | 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'] 33 | radioactive_elements_dict = {e: element_dict[e] for e in radioactive_elements} 34 | 35 | # noble gas elements 36 | noble_gas = ['He', 'Ne', 'Ar', 'Kr', 'Xe', 'Rn', 'Og'] 37 | noble_gas_dict = {e: element_dict[e] for e in noble_gas} 38 | 39 | 40 | if __name__=="__main__": 41 | print (len(element_list)) 42 | print (element_dict["H"]) 43 | 44 | atom_types = 119 45 | wyck_types = 3 46 | aw_types = (atom_types -1)*(wyck_types -1) + 1 47 | print (aw_types) 48 | idx = [element_dict[e] for e in ['H', 'C', 'O']] 49 | aw_mask = [1] + [1 if ((i-1)%(atom_types-1)+1 in idx) else 0 for i in range(1, aw_types)] # 1 for possible elements 50 | print (idx ) 51 | print (aw_mask) 52 | print(radioactive_elements_dict) 53 | print(noble_gas_dict) 54 | atom_mask = [1] + [1 if i not in radioactive_elements_dict.values() and i not in noble_gas_dict.values() else 0 for i in range(1, atom_types)] 55 | print('sampling structure formed by non-radioactive elements and non-noble gas') 56 | print(atom_mask) 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /crystalformer/src/lattice.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from crystalformer.src.wyckoff import mult_table 4 | 5 | def make_lattice_mask(): 6 | ''' 7 | return mask for independent lattice params 8 | ''' 9 | # 1-2 10 | # 3-15 11 | # 16-74 12 | # 75-142 13 | # 143-194 14 | # 195-230 15 | mask = [1, 1, 1, 1, 1, 1] * 2 +\ 16 | [1, 1, 1, 0, 1, 0] * 13+\ 17 | [1, 1, 1, 0, 0, 0] * 59+\ 18 | [1, 0, 1, 0, 0, 0] * 68+\ 19 | [1, 0, 1, 0, 0, 0] * 52+\ 20 | [1, 0, 0, 0, 0, 0] * 36 21 | 22 | return jnp.array(mask).reshape(230, 6) 23 | 24 | def symmetrize_lattice(spacegroup, lattice): 25 | ''' 26 | place lattice params into lattice 
according to the space group
27 |     '''
28 | 
29 |     a, b, c, alpha, beta, gamma = lattice
30 | 
31 |     L = lattice
32 |     L = jnp.where(spacegroup <= 2, L, jnp.array([a, b, c, 90., beta, 90.]))
33 |     L = jnp.where(spacegroup <= 15, L, jnp.array([a, b, c, 90., 90., 90.]))
34 |     L = jnp.where(spacegroup <= 74, L, jnp.array([a, a, c, 90., 90., 90.]))
35 |     L = jnp.where(spacegroup <= 142, L, jnp.array([a, a, c, 90., 90., 120.]))
36 |     L = jnp.where(spacegroup <= 194, L, jnp.array([a, a, a, 90., 90., 90.]))
37 | 
38 |     return L
39 | 
40 | 
41 | def norm_lattice(G, W, L):
42 |     """
43 |     normalize the lattice lengths by the number of atoms in the unit cell,
44 |     change the lattice angles to radian
45 |     a -> a/n_atoms^(1/3)
46 |     angle -> angle * pi/180
47 |     """
48 |     M = jax.vmap(lambda g, w: mult_table[g-1, w], in_axes=(0, 0))(G, W) # (batchsize, n_max)
49 |     num_atoms = jnp.sum(M, axis=1)
50 |     length, angle = jnp.split(L, 2, axis=-1)
51 |     length = length/num_atoms[:, None]**(1/3)
52 |     angle = angle * (jnp.pi / 180) # to rad
53 |     L = jnp.concatenate([length, angle], axis=-1)
54 | 
55 |     return L
56 | 
57 | 
58 | if __name__ == '__main__':
59 | 
60 |     mask = make_lattice_mask()
61 |     print (mask)
62 | 
63 |     key = jax.random.PRNGKey(42)
64 |     lattice = jax.random.normal(key, (6,))
65 |     lattice = lattice.reshape([1, 6]).repeat(3, axis=0)
66 | 
67 |     G = jnp.array([25, 99, 221])
68 |     L = jax.vmap(symmetrize_lattice)(G, lattice)
69 |     print (L)
70 | 
71 | 
--------------------------------------------------------------------------------
/crystalformer/src/loss.py:
--------------------------------------------------------------------------------
1 | import jax
2 | #jax.config.update("jax_enable_x64", True)
3 | import jax.numpy as jnp
4 | from functools import partial
5 | 
6 | from crystalformer.src.von_mises import von_mises_logpdf
7 | from crystalformer.src.lattice import make_lattice_mask
8 | from crystalformer.src.wyckoff import mult_table, fc_mask_table
9 | 
10 | 
11 | def make_loss_fn(n_max, atom_types, wyck_types, Kx, Kl, transformer, lamb_a=1.0, lamb_w=1.0, lamb_l=1.0):
12 |     """
13 |     Args:
14 |         n_max: maximum number of atoms in the unit cell
15 |         atom_types: number of atom types
16 |         wyck_types: number of wyckoff types
17 |         Kx: number of von Mises components for x, y, z
18 |         Kl: number of Gaussian mixture components for lattice parameters
19 |         transformer: model
20 |         lamb_a: weight for atom type loss
21 |         lamb_w: weight for wyckoff position loss
22 |         lamb_l: weight for lattice parameter loss
23 | 
24 |     Returns:
25 |         loss_fn: loss function
26 |         logp_fn: log probability function
27 |     """
28 | 
29 |     coord_types = 3*Kx
30 |     lattice_mask = make_lattice_mask()
31 | 
32 |     def compute_logp_x(h_x, X, fc_mask_x):
33 |         x_logit, loc, kappa = jnp.split(h_x, [Kx, 2*Kx], axis=-1)
34 |         loc = loc.reshape(n_max, Kx)
35 |         kappa = kappa.reshape(n_max, Kx)
36 |         logp_x = jax.vmap(von_mises_logpdf, (None, 1, 1), 1)((X-0.5)*2*jnp.pi, loc, kappa) # (n_max, Kx)
37 |         logp_x = jax.scipy.special.logsumexp(x_logit + logp_x, axis=1) # (n_max, )
38 |         logp_x = jnp.sum(jnp.where(fc_mask_x, logp_x, jnp.zeros_like(logp_x)))
39 | 
40 |         return logp_x
41 | 
42 |     @partial(jax.vmap, in_axes=(None, None, 0, 0, 0, 0, 0, None), out_axes=0) # batch
43 |     def logp_fn(params, key, G, L, XYZ, A, W, is_train):
44 |         '''
45 |         G: scalar
46 |         L: (6,) [a, b, c, alpha, beta, gamma]
47 |         XYZ: (n_max, 3)
48 |         A: (n_max,)
49 |         W: (n_max,)
50 |         '''
51 | 
52 |         num_sites = jnp.sum(A!=0)
53 |         M = mult_table[G-1, W] # (n_max,) multiplicities
54 |         #num_atoms = jnp.sum(M)
55 | 
56 |         h = 
transformer(params, key, G, XYZ, A, W, M, is_train) # (5*n_max+1, ...) 57 | w_logit = h[0::5, :wyck_types] # (n_max+1, wyck_types) 58 | w_logit = w_logit[:-1] # (n_max, wyck_types) 59 | a_logit = h[1::5, :atom_types] 60 | h_x = h[2::5, :coord_types] 61 | h_y = h[3::5, :coord_types] 62 | h_z = h[4::5, :coord_types] 63 | 64 | logp_w = jnp.sum(w_logit[jnp.arange(n_max), W.astype(int)]) 65 | logp_a = jnp.sum(a_logit[jnp.arange(n_max), A.astype(int)]) 66 | 67 | X, Y, Z = XYZ[:, 0], XYZ[:, 1], XYZ[:,2] 68 | 69 | fc_mask = jnp.logical_and((W>0)[:, None], fc_mask_table[G-1, W]) # (n_max, 3) 70 | logp_x = compute_logp_x(h_x, X, fc_mask[:, 0]) 71 | logp_y = compute_logp_x(h_y, Y, fc_mask[:, 1]) 72 | logp_z = compute_logp_x(h_z, Z, fc_mask[:, 2]) 73 | 74 | logp_xyz = logp_x + logp_y + logp_z 75 | 76 | l_logit, mu, sigma = jnp.split(h[1::5][num_sites, 77 | atom_types:atom_types+Kl+2*6*Kl], [Kl, Kl+Kl*6], axis=-1) 78 | mu = mu.reshape(Kl, 6) 79 | sigma = sigma.reshape(Kl, 6) 80 | logp_l = jax.vmap(jax.scipy.stats.norm.logpdf, (None, 0, 0))(L,mu,sigma) #(Kl, 6) 81 | logp_l = jax.scipy.special.logsumexp(l_logit[:, None] + logp_l, axis=0) # (6,) 82 | logp_l = jnp.sum(jnp.where((lattice_mask[G-1]>0), logp_l, jnp.zeros_like(logp_l))) 83 | 84 | return logp_w, logp_xyz, logp_a, logp_l 85 | 86 | # https://github.com/google/jax/blob/cd6eeea9e3e8652e17fdbb1575c9a63fcd558d6b/jax/_src/ad_checkpoint.py#L73 87 | # This is a useful heuristic for transformers. 88 | # @partial(jax.checkpoint, policy=jax.checkpoint_policies.offload_dot_with_no_batch_dims, static_argnums=(7,)) 89 | # @partial(jax.checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable, static_argnums=(7,)) 90 | def loss_fn(params, key, G, L, XYZ, A, W, is_train): 91 | logp_w, logp_xyz, logp_a, logp_l = logp_fn(params, key, G, L, XYZ, A, W, is_train) 92 | loss_w = -jnp.mean(logp_w) 93 | loss_xyz = -jnp.mean(logp_xyz) 94 | loss_a = -jnp.mean(logp_a) 95 | loss_l = -jnp.mean(logp_l) 96 | 97 | return loss_xyz + lamb_a* loss_a + lamb_w*loss_w + lamb_l*loss_l, (loss_w, loss_a, loss_xyz, loss_l) 98 | 99 | return loss_fn, logp_fn 100 | 101 | if __name__=='__main__': 102 | from utils import GLXYZAW_from_file 103 | from transformer import make_transformer 104 | atom_types = 119 105 | n_max = 20 106 | wyck_types = 20 107 | Nf = 5 108 | Kx = 16 109 | Kl = 4 110 | dropout_rate = 0.1 111 | 112 | csv_file = '../data/mini.csv' 113 | G, L, XYZ, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max) 114 | 115 | key = jax.random.PRNGKey(42) 116 | 117 | params, transformer = make_transformer(key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16, atom_types, wyck_types, dropout_rate) 118 | 119 | loss_fn, _ = make_loss_fn(n_max, atom_types, wyck_types, Kx, Kl, transformer) 120 | 121 | value = jax.jit(loss_fn, static_argnums=7)(params, key, G[:1], L[:1], XYZ[:1], A[:1], W[:1], True) 122 | print (value) 123 | 124 | value = jax.jit(loss_fn, static_argnums=7)(params, key, G[:1], L[:1], XYZ[:1]+1.0, A[:1], W[:1], True) 125 | print (value) 126 | -------------------------------------------------------------------------------- /crystalformer/src/mcmc.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | 5 | from crystalformer.src.von_mises import sample_von_mises 6 | from crystalformer.src.sample import project_xyz 7 | 8 | 9 | def make_mcmc_step(params, n_max, atom_types, atom_mask=None, constraints=None): 10 | 11 | if atom_mask is None or 
jnp.all(atom_mask == 0): 12 | atom_mask = jnp.ones((n_max, atom_types)) 13 | 14 | if constraints is None: 15 | constraints = jnp.arange(0, n_max, 1) 16 | 17 | def update_A(i, A, a, constraints): 18 | def body_fn(j, A): 19 | A = jax.lax.cond(constraints[j] == constraints[i], 20 | lambda _: A.at[:, j].set(a), 21 | lambda _: A, 22 | None) 23 | return A 24 | 25 | A = jax.lax.fori_loop(0, A.shape[1], body_fn, A) 26 | return A 27 | 28 | @partial(jax.jit, static_argnums=0) 29 | def mcmc(logp_fn, x_init, key, mc_steps, mc_width): 30 | """ 31 | Markov Chain Monte Carlo sampling algorithm. 32 | 33 | INPUT: 34 | logp_fn: callable that evaluate log-probability of a batch of configuration x. 35 | The signature is logp_fn(x), where x has shape (batch, n, dim). 36 | x_init: initial value of x, with shape (batch, n, dim). 37 | key: initial PRNG key. 38 | mc_steps: total number of Monte Carlo steps. 39 | mc_width: size of the Monte Carlo proposal. 40 | 41 | OUTPUT: 42 | x: resulting batch samples, with the same shape as `x_init`. 43 | """ 44 | def step(i, state): 45 | 46 | def true_func(i, state): 47 | x, logp, key, num_accepts = state 48 | G, L, XYZ, A, W = x 49 | key, key_proposal_A, key_proposal_XYZ, key_accept, key_logp = jax.random.split(key, 5) 50 | 51 | p_normalized = atom_mask[i%n_max] / jnp.sum(atom_mask[i%n_max]) # only propose atom types that are allowed 52 | _a = jax.random.choice(key_proposal_A, a=atom_types, p=p_normalized, shape=(A.shape[0], )) 53 | # _A = A.at[:, i%n_max].set(_a) 54 | _A = update_A(i%n_max, A, _a, constraints) 55 | A_proposal = jnp.where(A == 0, A, _A) 56 | 57 | _xyz = XYZ[:, i%n_max] + sample_von_mises(key_proposal_XYZ, 0, 1/mc_width**2, XYZ[:, i%n_max].shape) 58 | _xyz = jax.vmap(project_xyz, in_axes=(0, 0, 0, None))(G, W[:, i%n_max], _xyz, 0) 59 | _XYZ = XYZ.at[:, i%n_max].set(_xyz) 60 | _XYZ -= jnp.floor(_XYZ) # wrap to [0, 1) 61 | XYZ_proposal = _XYZ 62 | x_proposal = (G, L, XYZ_proposal, A_proposal, W) 63 | 64 | logp_w, logp_xyz, logp_a, _ = logp_fn(params, key_logp, *x_proposal, False) 65 | logp_proposal = logp_w + logp_xyz + logp_a 66 | 67 | ratio = jnp.exp((logp_proposal - logp)) 68 | accept = jax.random.uniform(key_accept, ratio.shape) < ratio 69 | 70 | A_new = jnp.where(accept[:, None], A_proposal, A) # update atom types 71 | XYZ_new = jnp.where(accept[:, None, None], XYZ_proposal, XYZ) # update atom positions 72 | x_new = (G, L, XYZ_new, A_new, W) 73 | logp_new = jnp.where(accept, logp_proposal, logp) 74 | num_accepts += jnp.sum(accept*jnp.where(A[:, i%n_max]==0, 0, 1)) 75 | jax.debug.print("logp {x} {y}", 76 | x=logp_new.mean(), 77 | y=jnp.std(logp_new)/jnp.sqrt(logp_new.shape[0]) 78 | ) 79 | return x_new, logp_new, key, num_accepts 80 | 81 | def false_func(i, state): 82 | x, logp, key, num_accepts = state 83 | return x, logp, key, num_accepts 84 | 85 | x, logp, key, num_accepts = state 86 | A = x[3] 87 | x, logp, key, num_accepts = jax.lax.cond(A[:, i%n_max].sum() != 0, 88 | lambda _: true_func(i, state), 89 | lambda _: false_func(i, state), 90 | None) 91 | return x, logp, key, num_accepts 92 | 93 | key, subkey = jax.random.split(key) 94 | logp_w, logp_xyz, logp_a, _ = logp_fn(params, subkey, *x_init, False) 95 | logp_init = logp_w + logp_xyz + logp_a 96 | jax.debug.print("logp {x} {y}", 97 | x=logp_init.mean(), 98 | y=jnp.std(logp_init)/jnp.sqrt(logp_init.shape[0]) 99 | ) 100 | 101 | x, logp, key, num_accepts = jax.lax.fori_loop(0, mc_steps, step, (x_init, logp_init, key, 0.)) 102 | # print("logp", logp) 103 | A = x[3] 104 | scale = jnp.sum(A != 
0)/(A.shape[0]*n_max) 105 | accept_rate = num_accepts / (scale * mc_steps * x[0].shape[0]) 106 | return x, accept_rate 107 | 108 | return mcmc 109 | 110 | 111 | if __name__ == "__main__": 112 | from utils import GLXYZAW_from_file 113 | from loss import make_loss_fn 114 | from transformer import make_transformer 115 | atom_types = 119 116 | n_max = 21 117 | wyck_types = 28 118 | Nf = 5 119 | Kx = 16 120 | Kl = 4 121 | dropout_rate = 0.3 122 | 123 | csv_file = '../../data/mini.csv' 124 | G, L, XYZ, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max) 125 | 126 | key = jax.random.PRNGKey(42) 127 | 128 | params, transformer = make_transformer(key, Nf, Kx, Kl, n_max, 128, 4, 4, 8, 16, 16, atom_types, wyck_types, dropout_rate) 129 | 130 | loss_fn, logp_fn = make_loss_fn(n_max, atom_types, wyck_types, Kx, Kl, transformer) 131 | 132 | # MCMC sampling test 133 | mc_steps = 21 134 | mc_width = 0.1 135 | x_init = (G[:5], L[:5], XYZ[:5], A[:5], W[:5]) 136 | 137 | value = jax.jit(logp_fn, static_argnums=7)(params, key, *x_init, False) 138 | 139 | jnp.set_printoptions(threshold=jnp.inf) 140 | mcmc = make_mcmc_step(params, n_max=n_max, atom_types=atom_types) 141 | 142 | for i in range(5): 143 | key, subkey = jax.random.split(key) 144 | x, acc = mcmc(logp_fn, x_init=x_init, key=subkey, mc_steps=mc_steps, mc_width=mc_width) 145 | print(i, acc) 146 | 147 | print("check if the atom type is changed") 148 | print(x_init[3]) 149 | print(x[3]) 150 | 151 | print("check if the atom position is changed") 152 | print(x_init[2]) 153 | print(x[2]) -------------------------------------------------------------------------------- /crystalformer/src/rope.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import haiku as hk 4 | 5 | 6 | def sine_table(features, length, min_timescale=1.0, max_timescale=10000.0): 7 | fraction = jnp.arange(0, features, 2, dtype=jnp.float32) / features 8 | timescale = min_timescale * (max_timescale / min_timescale) ** fraction 9 | rotational_frequency = 1.0 / timescale 10 | # Must use high precision einsum here, bfloat16 rounding is catastrophic. 11 | sinusoid_inp = jnp.einsum( 12 | 'i,j->ij', 13 | jnp.arange(length), 14 | rotational_frequency, 15 | precision=jax.lax.Precision.HIGHEST, 16 | ) 17 | sinusoid_inp = jnp.concatenate([sinusoid_inp, sinusoid_inp], axis=-1) 18 | return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp) 19 | 20 | 21 | def rotate_half(x): 22 | x1, x2 = jnp.split(x, 2, axis=-1) 23 | x = jnp.concatenate([-x2, x1], axis=-1) 24 | return x 25 | 26 | 27 | # https://github.com/google/flax/blob/nnx/flax/experimental/nnx/examples/07_transformer.py#L131-L157 28 | def apply_rotary_embedding(q, k, cos, sin, index=None): 29 | """ 30 | Helper function to apply Rotary Embeddings. 31 | 32 | The implementation is different from the original Rotary position embeddings, 33 | more details can be found in F.2. 
section of https://arxiv.org/abs/2202.07765 34 | """ 35 | qlen, qheads, d = q.shape 36 | klen, kheads, kd = k.shape 37 | if index is not None: 38 | qcos = jax.lax.broadcast_in_dim( 39 | cos[index, :], (qlen, qheads, d), (2,) 40 | ) 41 | qsin = jax.lax.broadcast_in_dim( 42 | sin[index, :], (qlen, qheads, d), (2,) 43 | ) 44 | else: 45 | qcos = jax.lax.broadcast_in_dim( 46 | cos[:qlen, :], (qlen, qheads, d), (0, 2) 47 | ) 48 | qsin = jax.lax.broadcast_in_dim( 49 | sin[:qlen, :], (qlen, qheads, d), (0, 2) 50 | ) 51 | kcos = jax.lax.broadcast_in_dim( 52 | cos[:klen, :], (klen, kheads, d), (0, 2) 53 | ) 54 | ksin = jax.lax.broadcast_in_dim( 55 | sin[:klen, :], (klen, kheads, d), (0, 2) 56 | ) 57 | out_q = (q * qcos) + (rotate_half(q) * qsin) 58 | out_k = (k * kcos) + (rotate_half(k) * ksin) 59 | return out_q, out_k 60 | 61 | 62 | class RelativePosition(hk.Module): 63 | """ 64 | Relative Positional Embeddings 65 | 66 | e_ij = (x_i * W^Q) * (x_j * W^K)^T / sqrt(d) + d_ij 67 | d_ij is the relative position embedding 68 | """ 69 | def __init__(self, max_relative_position): 70 | """ 71 | max_relative_position: maximum relative position 72 | """ 73 | 74 | super().__init__() 75 | self.max_relative_position = max_relative_position 76 | self.embeddings_table = hk.get_parameter( 77 | "embeddings_table", 78 | shape=(max_relative_position * 2 + 1, ), 79 | init=hk.initializers.TruncatedNormal(0.01) 80 | ) 81 | 82 | def __call__(self, length_q, length_k): 83 | range_vec_q = jnp.arange(length_q) 84 | range_vec_k = jnp.arange(length_k) 85 | distance_mat = range_vec_k[None, :] - range_vec_q[:, None] 86 | distance_mat_clipped = jax.lax.clamp(-self.max_relative_position, distance_mat, self.max_relative_position) 87 | final_mat = distance_mat_clipped + self.max_relative_position 88 | final_mat = final_mat.astype(int) 89 | embeddings = self.embeddings_table[final_mat] 90 | 91 | return embeddings 92 | -------------------------------------------------------------------------------- /crystalformer/src/train.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | import os 5 | import optax 6 | import math 7 | 8 | from crystalformer.src.utils import shuffle 9 | import crystalformer.src.checkpoint as checkpoint 10 | 11 | 12 | shard = jax.pmap(lambda x: x) 13 | p_split = jax.pmap(lambda key: tuple(jax.random.split(key))) 14 | 15 | 16 | def scatter(x: jnp.ndarray, retain_axis=False) -> jnp.ndarray: 17 | num_devices = jax.local_device_count() 18 | if x.shape[0] % num_devices != 0: 19 | raise ValueError("The first dimension of x must be divisible by the total number of GPU devices. " 20 | "Got x.shape[0] = %d for %d devices now." 
% (x.shape[0], num_devices)) 21 | dim_per_device = x.shape[0] // num_devices 22 | x = x.reshape( 23 | (num_devices,) + 24 | (() if dim_per_device == 1 and not retain_axis else (dim_per_device,)) + 25 | x.shape[1:] 26 | ) 27 | return shard(x) 28 | 29 | 30 | def train(key, optimizer, opt_state, loss_fn, params, epoch_finished, epochs, batchsize, train_data, valid_data, path, val_interval): 31 | 32 | num_devices = jax.local_device_count() 33 | batch_per_device = batchsize // num_devices 34 | shape_prefix = (num_devices, batch_per_device) 35 | print("num_devices: ", num_devices) 36 | print("batch_per_device: ", batch_per_device) 37 | print("shape_prefix: ", shape_prefix) 38 | 39 | key = jax.random.fold_in(key, jax.process_index()) # make different key for different process 40 | key, *keys = jax.random.split(key, num_devices + 1) 41 | keys = scatter(jnp.array(keys)) 42 | 43 | @partial(jax.pmap, axis_name="p", in_axes=(None, 0, None, 0), out_axes=(None, None, 0),) 44 | def update(params, key, opt_state, data): 45 | G, L, X, A, W = data 46 | value, grad = jax.value_and_grad(loss_fn, has_aux=True)(params, key, G, L, X, A, W, True) 47 | grad = jax.lax.pmean(grad, axis_name="p") 48 | value = jax.lax.pmean(value, axis_name="p") 49 | updates, opt_state = optimizer.update(grad, opt_state, params) 50 | params = optax.apply_updates(params, updates) 51 | return params, opt_state, value 52 | 53 | log_filename = os.path.join(path, "data.txt") 54 | f = open(log_filename, "w" if epoch_finished == 0 else "a", buffering=1, newline="\n") 55 | if os.path.getsize(log_filename) == 0: 56 | f.write("epoch t_loss v_loss t_loss_w v_loss_w t_loss_a v_loss_a t_loss_xyz v_loss_xyz t_loss_l v_loss_l\n") 57 | 58 | for epoch in range(epoch_finished+1, epochs+1): 59 | key, subkey = jax.random.split(key) 60 | train_data = shuffle(subkey, train_data) 61 | 62 | _, train_L, _, _, _ = train_data 63 | 64 | train_loss = 0.0 65 | train_aux = 0.0, 0.0, 0.0, 0.0 66 | num_samples = train_L.shape[0] 67 | if num_samples % batchsize == 0: 68 | num_batches = math.ceil(num_samples / batchsize) 69 | else: 70 | num_batches = math.ceil(num_samples / batchsize) - 1 71 | for batch_idx in range(num_batches): 72 | start_idx = batch_idx * batchsize 73 | end_idx = min(start_idx + batchsize, num_samples) 74 | data = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], train_data) 75 | data = jax.tree_util.tree_map(lambda x: x.reshape(shape_prefix + x.shape[1:]), data) 76 | 77 | keys, subkeys = p_split(keys) 78 | params, opt_state, (loss, aux) = update(params, subkeys, opt_state, data) 79 | train_loss, train_aux = jax.tree_util.tree_map( 80 | lambda acc, i: acc + jnp.mean(i), 81 | (train_loss, train_aux), 82 | (loss, aux) 83 | ) 84 | 85 | train_loss, train_aux = jax.tree_util.tree_map( 86 | lambda x: x/num_batches, 87 | (train_loss, train_aux) 88 | ) 89 | 90 | if epoch % val_interval == 0: 91 | _, valid_L, _, _, _ = valid_data 92 | valid_loss = 0.0 93 | valid_aux = 0.0, 0.0, 0.0, 0.0 94 | num_samples = valid_L.shape[0] 95 | if num_samples % batchsize == 0: 96 | num_batches = math.ceil(num_samples / batchsize) 97 | else: 98 | num_batches = math.ceil(num_samples / batchsize) - 1 99 | for batch_idx in range(num_batches): 100 | start_idx = batch_idx * batchsize 101 | end_idx = min(start_idx + batchsize, num_samples) 102 | data = jax.tree_util.tree_map(lambda x: x[start_idx:end_idx], valid_data) 103 | data = jax.tree_util.tree_map(lambda x: x.reshape(shape_prefix + x.shape[1:]), data) 104 | 105 | keys, subkeys = p_split(keys) 106 | loss, aux = 
jax.pmap(loss_fn, in_axes=(None, 0, 0, 0, 0, 0, 0), 107 | static_broadcasted_argnums=7)(params, subkeys, *data, False) 108 | valid_loss, valid_aux = jax.tree_util.tree_map( 109 | lambda acc, i: acc + jnp.mean(i), 110 | (valid_loss, valid_aux), 111 | (loss, aux) 112 | ) 113 | 114 | valid_loss, valid_aux = jax.tree_util.tree_map( 115 | lambda x: x/num_batches, 116 | (valid_loss, valid_aux) 117 | ) 118 | 119 | train_loss_w, train_loss_a, train_loss_xyz, train_loss_l = train_aux 120 | valid_loss_w, valid_loss_a, valid_loss_xyz, valid_loss_l = valid_aux 121 | 122 | f.write( ("%6d" + 10*" %.6f" + "\n") % (epoch, 123 | train_loss, valid_loss, 124 | train_loss_w, valid_loss_w, 125 | train_loss_a, valid_loss_a, 126 | train_loss_xyz, valid_loss_xyz, 127 | train_loss_l, valid_loss_l 128 | )) 129 | 130 | ckpt = {"params": params, 131 | "opt_state" : opt_state 132 | } 133 | ckpt_filename = os.path.join(path, "epoch_%06d.pkl" %(epoch)) 134 | if jax.process_index() == 0: 135 | checkpoint.save_data(ckpt, ckpt_filename) 136 | print("Save checkpoint file: %s" % ckpt_filename) 137 | 138 | f.close() 139 | return params, opt_state 140 | -------------------------------------------------------------------------------- /crystalformer/src/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import pandas as pd 5 | from pyxtal import pyxtal 6 | from pymatgen.core import Structure, Lattice 7 | from pymatgen.symmetry.analyzer import SpacegroupAnalyzer 8 | from functools import partial 9 | import multiprocessing 10 | import os 11 | 12 | from crystalformer.src.wyckoff import mult_table 13 | from crystalformer.src.elements import element_list 14 | 15 | @jax.vmap 16 | def sort_atoms(W, A, X): 17 | """ 18 | lex sort atoms according W, X, Y, Z 19 | 20 | W: (n, ) 21 | A: (n, ) 22 | X: (n, dim) int 23 | """ 24 | W_temp = jnp.where(W>0, W, 9999) # change 0 to 9999 so they remain in the end after sort 25 | 26 | X -= jnp.floor(X) 27 | idx = jnp.lexsort((X[:,2], X[:,1], X[:,0], W_temp)) 28 | 29 | #assert jnp.allclose(W, W[idx]) 30 | A = A[idx] 31 | X = X[idx] 32 | return A, X 33 | 34 | def letter_to_number(letter): 35 | """ 36 | 'a' to 1 , 'b' to 2 , 'z' to 26, and 'A' to 27 37 | """ 38 | return ord(letter) - ord('a') + 1 if 'a' <= letter <= 'z' else 27 if letter == 'A' else None 39 | 40 | def shuffle(key, data): 41 | """ 42 | shuffle data along batch dimension 43 | """ 44 | G, L, XYZ, A, W = data 45 | idx = jax.random.permutation(key, jnp.arange(len(L))) 46 | return G[idx], L[idx], XYZ[idx], A[idx], W[idx] 47 | 48 | def process_one(cif, atom_types, wyck_types, n_max, tol=0.01): 49 | """ 50 | # taken from https://anonymous.4open.science/r/DiffCSP-PP-8F0D/diffcsp/common/data_utils.py 51 | Process one cif string to get G, L, XYZ, A, W 52 | 53 | Args: 54 | cif: cif string 55 | atom_types: number of atom types 56 | wyck_types: number of wyckoff types 57 | n_max: maximum number of atoms in the unit cell 58 | tol: tolerance for pyxtal 59 | 60 | Returns: 61 | G: space group number 62 | L: lattice parameters 63 | XYZ: fractional coordinates 64 | A: atom types 65 | W: wyckoff letters 66 | """ 67 | try: crystal = Structure.from_str(cif, fmt='cif') 68 | except: crystal = Structure.from_dict(eval(cif)) 69 | spga = SpacegroupAnalyzer(crystal, symprec=tol) 70 | crystal = spga.get_refined_structure() 71 | c = pyxtal() 72 | try: 73 | c.from_seed(crystal, tol=0.01) 74 | except: 75 | c.from_seed(crystal, tol=0.0001) 76 | 77 | g = 
c.group.number 78 | num_sites = len(c.atom_sites) 79 | assert (n_max > num_sites) # we will need at least one empty site for output of L params 80 | 81 | print (g, c.group.symbol, num_sites) 82 | natoms = 0 83 | ww = [] 84 | aa = [] 85 | fc = [] 86 | ws = [] 87 | for site in c.atom_sites: 88 | a = element_list.index(site.specie) 89 | x = site.position 90 | m = site.wp.multiplicity 91 | w = letter_to_number(site.wp.letter) 92 | symbol = str(m) + site.wp.letter 93 | natoms += site.wp.multiplicity 94 | assert (a < atom_types) 95 | assert (w < wyck_types) 96 | assert (np.allclose(x, site.wp[0].operate(x))) 97 | aa.append( a ) 98 | ww.append( w ) 99 | fc.append( x ) # the generator of the orbit 100 | ws.append( symbol ) 101 | print ('g, a, w, m, symbol, x:', g, a, w, m, symbol, x) 102 | idx = np.argsort(ww) 103 | ww = np.array(ww)[idx] 104 | aa = np.array(aa)[idx] 105 | fc = np.array(fc)[idx].reshape(num_sites, 3) 106 | ws = np.array(ws)[idx] 107 | print (ws, aa, ww, natoms) 108 | 109 | aa = np.concatenate([aa, 110 | np.full((n_max - num_sites, ), 0)], 111 | axis=0) 112 | 113 | ww = np.concatenate([ww, 114 | np.full((n_max - num_sites, ), 0)], 115 | axis=0) 116 | fc = np.concatenate([fc, 117 | np.full((n_max - num_sites, 3), 1e10)], 118 | axis=0) 119 | 120 | abc = np.array([c.lattice.a, c.lattice.b, c.lattice.c])/natoms**(1./3.) 121 | angles = np.array([c.lattice.alpha, c.lattice.beta, c.lattice.gamma]) 122 | l = np.concatenate([abc, angles]) 123 | 124 | print ('===================================') 125 | 126 | return g, l, fc, aa, ww 127 | 128 | def GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max, num_workers=1): 129 | """ 130 | Read cif strings from csv file and convert them to G, L, XYZ, A, W 131 | Note that cif strings must be in the column 'cif' 132 | 133 | Args: 134 | csv_file: csv file containing cif strings 135 | atom_types: number of atom types 136 | wyck_types: number of wyckoff types 137 | n_max: maximum number of atoms in the unit cell 138 | num_workers: number of workers for multiprocessing 139 | 140 | Returns: 141 | G: space group number 142 | L: lattice parameters 143 | XYZ: fractional coordinates 144 | A: atom types 145 | W: wyckoff letters 146 | """ 147 | if csv_file.endswith('.lmdb'): 148 | import lmdb 149 | import pickle 150 | # read from lmdb 151 | env = lmdb.open( 152 | csv_file, 153 | subdir=False, 154 | readonly=True, 155 | lock=False, 156 | readahead=False, 157 | meminit=False, 158 | ) 159 | 160 | contents = env.begin().cursor().iternext() 161 | data = tuple([pickle.loads(value) for _, value in contents]) 162 | G, L, XYZ, A, W = data 163 | print('G:', G.shape) 164 | print('L:', L.shape) 165 | print('XYZ:', XYZ.shape) 166 | print('A:', A.shape) 167 | print('W:', W.shape) 168 | return G, L, XYZ, A, W 169 | 170 | data = pd.read_csv(csv_file) 171 | try: cif_strings = data['cif'] 172 | except: cif_strings = data['structure'] 173 | 174 | p = multiprocessing.Pool(num_workers) 175 | partial_process_one = partial(process_one, atom_types=atom_types, wyck_types=wyck_types, n_max=n_max) 176 | results = p.map_async(partial_process_one, cif_strings).get() 177 | p.close() 178 | p.join() 179 | 180 | G, L, XYZ, A, W = zip(*results) 181 | 182 | G = jnp.array(G) 183 | A = jnp.array(A).reshape(-1, n_max) 184 | W = jnp.array(W).reshape(-1, n_max) 185 | XYZ = jnp.array(XYZ).reshape(-1, n_max, 3) 186 | L = jnp.array(L).reshape(-1, 6) 187 | 188 | A, XYZ = sort_atoms(W, A, XYZ) 189 | 190 | return G, L, XYZ, A, W 191 | 192 | def GLXA_to_structure_single(G, L, X, A): 193 | """ 194 | 
Convert G, L, X, A to pymatgen structure. Do not use this function due to the bug in pymatgen. 195 | 196 | Args: 197 | G: space group number 198 | L: lattice parameters 199 | X: fractional coordinates 200 | A: atom types 201 | 202 | Returns: 203 | structure: pymatgen structure 204 | """ 205 | lattice = Lattice.from_parameters(*L) 206 | # filter out padding atoms 207 | idx = np.where(A > 0) 208 | A = A[idx] 209 | X = X[idx] 210 | structure = Structure.from_spacegroup(sg=G, lattice=lattice, species=A, coords=X).as_dict() 211 | 212 | return structure 213 | 214 | def GLXA_to_csv(G, L, X, A, num_worker=1, filename='out_structure.csv'): 215 | 216 | L = np.array(L) 217 | X = np.array(X) 218 | A = np.array(A) 219 | p = multiprocessing.Pool(num_worker) 220 | if isinstance(G, int): 221 | G = np.array([G] * len(L)) 222 | structures = p.starmap_async(GLXA_to_structure_single, zip(G, L, X, A)).get() 223 | p.close() 224 | p.join() 225 | 226 | data = pd.DataFrame() 227 | data['cif'] = structures 228 | header = not os.path.exists(filename) 229 | data.to_csv(filename, mode='a', index=False, header=header) 230 | 231 | 232 | if __name__=='__main__': 233 | atom_types = 119 234 | wyck_types = 28 235 | n_max = 24 236 | 237 | import numpy as np 238 | np.set_printoptions(threshold=np.inf) 239 | 240 | #csv_file = '../data/mini.csv' 241 | #csv_file = '/home/wanglei/cdvae/data/carbon_24/val.csv' 242 | #csv_file = '/home/wanglei/cdvae/data/perov_5/val.csv' 243 | csv_file = '/home/wanglei/cdvae/data/mp_20/train.csv' 244 | 245 | G, L, XYZ, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max) 246 | 247 | print (G.shape) 248 | print (L.shape) 249 | print (XYZ.shape) 250 | print (A.shape) 251 | print (W.shape) 252 | 253 | print ('L:\n',L) 254 | print ('XYZ:\n',XYZ) 255 | 256 | 257 | @jax.vmap 258 | def lookup(G, W): 259 | return mult_table[G-1, W] # (n_max, ) 260 | M = lookup(G, W) # (batchsize, n_max) 261 | print ('N:\n', M.sum(axis=-1)) 262 | -------------------------------------------------------------------------------- /crystalformer/src/von_mises.py: -------------------------------------------------------------------------------- 1 | # https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/util.py 2 | # Copyright Contributors to the Pyro project. 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import jax 6 | from jax import jit, lax, random 7 | import jax.numpy as jnp 8 | from functools import partial 9 | 10 | def sample_von_mises(key, loc, concentration, shape): 11 | """Generate sample from von Mises distribution 12 | 13 | :param key: random number generator key 14 | :param shape: shape of samples 15 | :return: samples from von Mises 16 | """ 17 | samples = von_mises_centered( 18 | key, concentration, shape 19 | ) 20 | samples = samples + loc # VM(0, concentration) -> VM(loc,concentration) 21 | samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi 22 | 23 | return samples 24 | 25 | def von_mises_centered(key, concentration, shape, dtype=jnp.float64): 26 | """Compute centered von Mises samples using rejection sampling from [1] with wrapped Cauchy proposal. 27 | *** References *** 28 | [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986; 29 | Chapter 9, p. 473-476. 
http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf 30 | :param key: random number generator key 31 | :param concentration: concentration of distribution 32 | :param shape: shape of samples 33 | :param dtype: float precision used to choose the correct s cutoff 34 | :return: centered samples from von Mises 35 | """ 36 | shape = shape or jnp.shape(concentration) 37 | dtype = jnp.result_type(dtype) 38 | concentration = lax.convert_element_type(concentration, dtype) 39 | concentration = jnp.broadcast_to(concentration, shape) 40 | return _von_mises_centered(key, concentration, shape, dtype) 41 | 42 | 43 | @partial(jit, static_argnums=(2, 3)) 44 | def _von_mises_centered(key, concentration, shape, dtype): 45 | # Cutoff from TensorFlow probability 46 | # (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570) 47 | s_cutoff_map = { 48 | jnp.dtype(jnp.float16): 1.8e-1, 49 | jnp.dtype(jnp.float32): 2e-2, 50 | jnp.dtype(jnp.float64): 1.2e-4, 51 | } 52 | s_cutoff = s_cutoff_map.get(dtype) 53 | 54 | r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration**2) 55 | rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration) 56 | s_exact = (1.0 + rho**2) / (2.0 * rho) 57 | 58 | s_approximate = 1.0 / concentration 59 | 60 | s = jnp.where(concentration > s_cutoff, s_exact, s_approximate) 61 | 62 | def cond_fn(*args): 63 | """check if all are done or reached max number of iterations""" 64 | i, _, done, _, _ = args[0] 65 | return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done))) 66 | 67 | def body_fn(*args): 68 | i, key, done, _, w = args[0] 69 | uni_ukey, uni_vkey, key = random.split(key, 3) 70 | 71 | u = random.uniform( 72 | key=uni_ukey, 73 | shape=shape, 74 | dtype=concentration.dtype, 75 | minval=-1.0, 76 | maxval=1.0, 77 | ) 78 | z = jnp.cos(jnp.pi * u) 79 | w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done 80 | 81 | y = concentration * (s - w) 82 | v = random.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype) 83 | 84 | accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y) 85 | 86 | return i + 1, key, accept | done, u, w 87 | 88 | init_done = jnp.zeros(shape, dtype=bool) 89 | init_u = jnp.zeros(shape) 90 | init_w = jnp.zeros(shape) 91 | 92 | _, _, done, u, w = lax.while_loop( 93 | cond_fun=cond_fn, 94 | body_fun=body_fn, 95 | init_val=(jnp.array(0), key, init_done, init_u, init_w), 96 | ) 97 | 98 | return jnp.sign(u) * jnp.arccos(jnp.clip(w, -1.0, 1.0)) 99 | 100 | def von_mises_logpdf(x, loc, concentration): 101 | ''' 102 | kappa is the concentration. kappa = 0 means uniform distribution 103 | ''' 104 | return -(jnp.log(2 * jnp.pi) + jnp.log(jax.scipy.special.i0e(concentration)) 105 | ) + concentration * (jnp.cos((x - loc) % (2 * jnp.pi)) - 1) 106 | 107 | if __name__=='__main__': 108 | key = jax.random.PRNGKey(42) 109 | loc = jnp.array([-1.0, 1.0, 0.0]) 110 | kappa = jnp.array([10.0, 10.0, 100.0]) 111 | x = sample_von_mises(key, loc, kappa, (3, )) 112 | print (x) 113 | 114 | -------------------------------------------------------------------------------- /crystalformer/src/wyckoff.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import numpy as np 4 | import re 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | def from_xyz_str(xyz_str: str): 9 | """ 10 | Args: 11 | xyz_str: string of the form 'x, y, z', '-x, -y, z', '-2y+1/2, 3x+1/2, z-y+1/2', etc. 
12 | Returns: 13 | affine operator as a 3x4 array 14 | """ 15 | rot_matrix = np.zeros((3, 3)) 16 | trans = np.zeros(3) 17 | tokens = xyz_str.strip().replace(" ", "").lower().split(",") 18 | re_rot = re.compile(r"([+-]?)([\d\.]*)/?([\d\.]*)([x-z])") 19 | re_trans = re.compile(r"([+-]?)([\d\.]+)/?([\d\.]*)(?![x-z])") 20 | for i, tok in enumerate(tokens): 21 | # build the rotation matrix 22 | for m in re_rot.finditer(tok): 23 | factor = -1.0 if m.group(1) == "-" else 1.0 24 | if m.group(2) != "": 25 | factor *= float(m.group(2)) / float(m.group(3)) if m.group(3) != "" else float(m.group(2)) 26 | j = ord(m.group(4)) - 120 27 | rot_matrix[i, j] = factor 28 | # build the translation vector 29 | for m in re_trans.finditer(tok): 30 | factor = -1 if m.group(1) == "-" else 1 31 | num = float(m.group(2)) / float(m.group(3)) if m.group(3) != "" else float(m.group(2)) 32 | trans[i] = num * factor 33 | return np.concatenate( [rot_matrix, trans[:, None]], axis=1) # (3, 4) 34 | 35 | 36 | df = pd.read_csv(os.path.join(os.path.dirname(__file__), '../data/wyckoff_list.csv')) 37 | df['Wyckoff Positions'] = df['Wyckoff Positions'].apply(eval) # convert string to list 38 | wyckoff_positions = df['Wyckoff Positions'].tolist() 39 | 40 | symops = np.zeros((230, 28, 576, 3, 4)) # 576 is the least common multiple for all possible mult 41 | mult_table = np.zeros((230, 28), dtype=int) # mult_table[g-1, w] = multiplicity , 28 because we pad with 0 42 | wmax_table = np.zeros((230,), dtype=int) # wmax_table[g-1] = number of possible wyckoff letters for g 43 | dof0_table = np.ones((230, 28), dtype=bool) # dof0_table[g-1, w] = True for those wyckoff points with dof = 0 (no continuous dof) 44 | fc_mask_table = np.zeros((230, 28, 3), dtype=bool) # fc_mask_table[g-1, w] = True for continuous fc 45 | 46 | def build_g_code(): 47 | #use general wyckoff position as the code for space groups 48 | xyz_table = [] 49 | g_table = [] 50 | for g in range(230): 51 | wp0 = wyckoff_positions[g][0] 52 | g_table.append([]) 53 | for xyz in wp0: 54 | if xyz not in xyz_table: 55 | xyz_table.append(xyz) 56 | g_table[-1].append(xyz_table.index(xyz)) 57 | assert len(g_table[-1]) == len(set(g_table[-1])) 58 | 59 | g_code = [] 60 | for g in range(230): 61 | g_code.append( [1 if i in g_table[g] else 0 for i in range(len(xyz_table))] ) 62 | del xyz_table 63 | del g_table 64 | g_code = jnp.array(g_code) 65 | return g_code 66 | 67 | for g in range(230): 68 | wyckoffs = [] 69 | for x in wyckoff_positions[g]: 70 | wyckoffs.append([]) 71 | for y in x: 72 | wyckoffs[-1].append(from_xyz_str(y)) 73 | wyckoffs = wyckoffs[::-1] # a-z,A 74 | 75 | mult = [len(w) for w in wyckoffs] 76 | mult_table[g, 1:len(mult)+1] = mult 77 | wmax_table[g] = len(mult) 78 | 79 | # print (g+1, [len(w) for w in wyckoffs]) 80 | for w, wyckoff in enumerate(wyckoffs): 81 | wyckoff = np.array(wyckoff) 82 | repeats = symops.shape[2] // wyckoff.shape[0] 83 | symops[g, w+1, :, :, :] = np.tile(wyckoff, (repeats, 1, 1)) 84 | dof0_table[g, w+1] = np.linalg.matrix_rank(wyckoff[0, :3, :3]) == 0 85 | fc_mask_table[g, w+1] = jnp.abs(wyckoff[0, :3, :3]).sum(axis=1)!=0 86 | 87 | symops = jnp.array(symops) 88 | mult_table = jnp.array(mult_table) 89 | wmax_table = jnp.array(wmax_table) 90 | dof0_table = jnp.array(dof0_table) 91 | fc_mask_table = jnp.array(fc_mask_table) 92 | 93 | def symmetrize_atoms(g, w, x): 94 | ''' 95 | symmetrize atoms by applying all space group symmetry ops, finding the generator, and lastly applying the symops of the given Wyckoff position 96 | we need to do that because the sampled atom might not be at the first WP 97 | 
Args: 98 | g: int 99 | w: int 100 | x: (3,) 101 | Returns: 102 | xs: (m, 3) symmetrized atom positions 103 | ''' 104 | 105 | # (1) apply all space group symmetry op to the x 106 | w_max = wmax_table[g-1].item() 107 | m_max = mult_table[g-1, w_max].item() 108 | ops = symops[g-1, w_max, :m_max] # (m_max, 3, 4) 109 | affine_point = jnp.array([*x, 1]) # (4, ) 110 | coords = ops@affine_point # (m_max, 3) 111 | coords -= jnp.floor(coords) 112 | 113 | # (2) search for the generator which satisfies op0(x) = x , i.e. the first Wyckoff position 114 | # here we solve it in a jit friendly way by looking for the minimal distance solution for the lhs and rhs 115 | #https://github.com/qzhu2017/PyXtal/blob/82e7d0eac1965c2713179eeda26a60cace06afc8/pyxtal/wyckoff_site.py#L115 116 | def dist_to_op0x(coord): 117 | diff = jnp.dot(symops[g-1, w, 0], jnp.array([*coord, 1])) - coord 118 | diff -= jnp.rint(diff) 119 | return jnp.sum(diff**2) 120 | loc = jnp.argmin(jax.vmap(dist_to_op0x)(coords)) 121 | x = coords[loc].reshape(3,) 122 | 123 | # (3) lastly, apply the given symmetry op to x 124 | m = mult_table[g-1, w] 125 | ops = symops[g-1, w, :m] # (m, 3, 4) 126 | affine_point = jnp.array([*x, 1]) # (4, ) 127 | xs = ops@affine_point # (m, 3) 128 | xs -= jnp.floor(xs) # wrap back to 0-1 129 | return xs 130 | 131 | if __name__=='__main__': 132 | print (symops.shape) 133 | print (symops.size*symops.dtype.itemsize//(1024*1024)) 134 | import sys 135 | import numpy as np 136 | np.set_printoptions(threshold=np.inf) 137 | 138 | print (symops[166-1,3, :6]) 139 | op = symops[166-1, 3, 0] 140 | print (op) 141 | 142 | w_max = wmax_table[225-1] 143 | m_max = mult_table[225-1, w_max] 144 | print ('w_max, m_max', w_max, m_max) 145 | 146 | print (fc_mask_table[225-1, 6]) 147 | sys.exit(0) 148 | 149 | print ('mult_table') 150 | print (mult_table[25-1]) # space group id -> multiplicity table 151 | print (mult_table[42-1]) 152 | print (mult_table[47-1]) 153 | print (mult_table[99-1]) 154 | print (mult_table[123-1]) 155 | print (mult_table[221-1]) 156 | print (mult_table[166-1]) 157 | 158 | print ('dof0_table') 159 | print (dof0_table[25-1]) 160 | print (dof0_table[42-1]) 161 | print (dof0_table[47-1]) 162 | print (dof0_table[225-1]) 163 | print (dof0_table[166-1]) 164 | 165 | print ('wmax_table') 166 | print (wmax_table[47-1]) 167 | print (wmax_table[123-1]) 168 | print (wmax_table[166-1]) 169 | 170 | print ('wmax_table', wmax_table) 171 | 172 | atom_types = 119 173 | aw_max = wmax_table*(atom_types-1) # the maximum value of aw 174 | print ( (aw_max-1)%(atom_types-1)+1 ) # = 118 175 | print ( (aw_max-1)//(atom_types-1)+1 ) # = wmax 176 | -------------------------------------------------------------------------------- /data/atoms.json: -------------------------------------------------------------------------------- 1 | { 2 | "atom_mask":[ 3 | ["V", "Cr", "Mn", "Fe", "Co", "Ni"], 4 | ["V", "Cr", "Mn", "Fe", "Co", "Ni"], 5 | ["O", "S", "F", "Cl", "Br", "I"], 6 | ["O", "S", "F", "Cl", "Br", "I"] 7 | ], 8 | "constraints":[ 9 | [2, 3] 10 | ] 11 | } -------------------------------------------------------------------------------- /imgs/crystalformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/CrystalFormer/a2ae7982234b84134950c5c7a4f49d592c83bddf/imgs/crystalformer.png -------------------------------------------------------------------------------- /imgs/output.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/CrystalFormer/a2ae7982234b84134950c5c7a4f49d592c83bddf/imgs/output.gif -------------------------------------------------------------------------------- /model/README.md: -------------------------------------------------------------------------------- 1 | # Model Card 2 | 3 | ## Alex-20 4 | 5 | The pre-trained model is available on [Google Drive](https://drive.google.com/file/d/1Fjt3bXzouAb-GX3ScAuejDA6eggtOYe4/view?usp=sharing) and [Hugging Face Model Hub](https://huggingface.co/zdcao/CrystalFormer/blob/main/alex20/PBE/epoch_005500.pkl). 6 | 7 | ### Model Parameters 8 | 9 | ```python 10 | params, transformer = make_transformer( 11 | key=jax.random.PRNGKey(42), 12 | Nf=5, 13 | Kx=16, 14 | Kl=4, 15 | n_max=21, 16 | h0_size=256, 17 | num_layers=16, 18 | num_heads=16, 19 | key_size=64, 20 | model_size=64, 21 | embed_size=32, 22 | atom_types=119, 23 | wyck_types=28, 24 | dropout_rate=0.1, 25 | attn_rate=0.1, 26 | widening_factor=4, 27 | sigmamin=1e-3 28 | ) 29 | ``` 30 | 31 | ### Training dataset 32 | 33 | Alex-20: contains ~1.3M general inorganic materials curated from the [Alexandria database](https://alexandria.icams.rub.de/), with $E_{hull} < 0.1$ eV/atom and no more than 20 atoms in the unit cell. The dataset can be found on [Google Drive](https://drive.google.com/drive/folders/1QeYz9lQX9Lk-OxhKBOwvuyKBecznPVlX?usp=drive_link) or [Hugging Face Datasets](https://huggingface.co/datasets/zdcao/alex-20). 34 | 35 | 36 | ## Alex-20 RL 37 | 38 | - $E_{hull}$ reward: The checkpoint is available on [Google Drive](https://drive.google.com/file/d/1LlrpWj1GWUBZb-Ix_D3DfXxPd6EVsY6e/view?usp=sharing) and [Hugging Face Model Hub](https://huggingface.co/zdcao/CrystalFormer/blob/main/alex20/RL-ehull/epoch_000195.pkl). The reward is chosen to be the negative energy above the hull, which is calculated by the [Orb model](https://github.com/orbital-materials/orb-models) based on the Alexandria convex hull. 39 | 40 | - Dielectric FoM reward: The checkpoint is available on [Google Drive](https://drive.google.com/file/d/1Jsa5uHa_Eu3cULqBDZxyia7CBgqe7Hg4/view?usp=sharing) and [Hugging Face Model Hub](https://huggingface.co/zdcao/CrystalFormer/blob/main/alex20/RL-dielectric/epoch_000100.pkl). The reward is chosen to be the dielectric figure of merit (FoM), which is the product of the total dielectric constant and the band gap. We use the pretrained [MEGNet](https://github.com/materialsvirtuallab/matgl/tree/main/pretrained_models/MEGNet-MP-2019.4.1-BandGap-mfi) to predict the band gap. The checkpoint of the total dielectric constant prediction model can be found on [Google Drive](https://drive.google.com/drive/folders/1hQJD5R0dMJVC3nA1YkSHkCG9s-IAVNnA?usp=sharing). You can load the model using the [matgl](https://github.com/materialsvirtuallab/matgl/tree/main) package. 41 | 42 | 43 | ## MP-20 44 | 45 | > [!IMPORTANT] 46 | > To load the MP-20 checkpoint, you need to switch `CrystalFormer` to version 0.3. The current version of the model is not compatible with the MP-20 checkpoint. 47 | 48 | ### Checkpoint 49 | 50 | The pre-trained model is available on [Google Drive](https://drive.google.com/file/d/1koHC6n38BqsY2_z3xHTi40HcFbVesUKd/view?usp=sharing) and [Hugging Face Model Hub](https://huggingface.co/zdcao/CrystalFormer/blob/main/mp20/epoch_003800.pkl).
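As a rough sketch, a downloaded checkpoint can be restored and unpacked as follows. This assumes a `load_data` counterpart to the `checkpoint.save_data` call used by the training loop (which pickles a dict with `"params"` and `"opt_state"` keys); the exact helper name may differ in your version:

```python
# Hypothetical loading sketch -- `load_data` is an assumed counterpart of
# `checkpoint.save_data`, which stores {"params": ..., "opt_state": ...}.
import crystalformer.src.checkpoint as checkpoint

ckpt = checkpoint.load_data("epoch_003800.pkl")
params, opt_state = ckpt["params"], ckpt["opt_state"]
```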
51 | 52 | ### Model Parameters 53 | 54 | ```python 55 | params, transformer = make_transformer( 56 | key=jax.random.PRNGKey(42), 57 | Nf=5, 58 | Kx=16, 59 | Kl=4, 60 | n_max=21, 61 | h0_size=256, 62 | num_layers=16, 63 | num_heads=16, 64 | key_size=64, 65 | model_size=64, 66 | embed_size=32, 67 | atom_types=119, 68 | wyck_types=28, 69 | dropout_rate=0.5, 70 | widening_factor=4, 71 | sigmamin=1e-3 72 | ) 73 | ``` 74 | 75 | ### Training dataset 76 | 77 | MP-20 (Jain et al., 2013): contains 45k general inorganic materials, including most experimentally known materials with no more than 20 atoms in the unit cell. 78 | More details can be found in the [CDVAE repository](https://github.com/txie-93/cdvae/tree/main/data/mp_20). -------------------------------------------------------------------------------- /model/config.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: "outputs/${training_name}_${physics_name}_${loss_name}_${transformer_name}" 4 | sweep: 5 | dir: "outputs/${training_name}_${physics_name}_${loss_name}_${transformer_name}" 6 | 7 | # training_parameters: 8 | epochs: 10000 9 | batchsize: 100 10 | lr: 0.0001 11 | lr_decay: 0.0 12 | weight_decay: 0.0 13 | clip_grad: 1.0 14 | optimizer: adam 15 | folder: "./" 16 | restore_path: null 17 | training_name: "${folder}${optimizer}_bs_${batchsize}_\ 18 | lr_${lr}_decay_${lr_decay}_clip_${clip_grad}" 19 | 20 | # dataset: 21 | train_path: "/data/zdcao/crystal_gpt/dataset/mp_20/train.csv" 22 | valid_path: "/data/zdcao/crystal_gpt/dataset/mp_20/val.csv" 23 | test_path: "/data/zdcao/crystal_gpt/dataset/mp_20/test.csv" 24 | 25 | # transformer_parameters: 26 | Nf: 5 # number of frequencies for fc 27 | Kx: 16 # number of modes in x 28 | Kl: 4 # number of modes in lattice 29 | h0_size: 256 # hidden layer dimension for the first atom, 0 means we simply use a table for the first aw_logit 30 | transformer_layers: 16 # the number of layers in the transformer 31 | num_heads: 16 # the number of attention heads 32 | key_size: 64 # the key size 33 | model_size: 64 # the model size 34 | embed_size: 32 # the embedding size 35 | dropout_rate: 0.5 # the dropout rate 36 | transformer_name: "Nf_${Nf}_Kx_${Kx}_Kl_${Kl}_\ 37 | h0_${h0_size}_l_${transformer_layers}_H_${num_heads}_\ 38 | k_${key_size}_m_${model_size}_e_${embed_size}_drop_${dropout_rate}" 39 | 40 | # loss_parameters: 41 | lamb_a: 1.0 # weight for the a part relative to fc 42 | lamb_w: 1.0 # weight for the w part relative to fc 43 | lamb_l: 1.0 # weight for the lattice part relative to fc 44 | loss_name: "a_${lamb_a}_w_${lamb_w}_l_${lamb_l}" 45 | 46 | # physics_parameters: 47 | n_max: 21 # the maximum number of atoms in the cell 48 | atom_types: 119 # atom types including the padded atoms 49 | wyck_types: 28 # number of possible multiplicities including 0 50 | physics_name: "A_${atom_types}_W_${wyck_types}_N_${n_max}" 51 | 52 | # sampling_parameters: 53 | spacegroup: null # space group number for sampling; null means it is not specified here 54 | elements: null # use list format when specifying elements, e.g., [Bi, Ti, O] 55 | top_p: 1.0 # 1.0 means un-modified logits; smaller values of p give less diverse samples 56 | temperature: 1.0 # temperature used for sampling 57 | num_io_process: 40 # number of processes used in multiprocessing io 58 | num_samples: 1000 # number of test samples 59 | use_foriloop: true # whether to use the fori_loop in sampling 60 | output_filename: "output.csv" # output file to save sampled structures 61 | 
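# A minimal usage sketch (assumed hydra-style CLI overrides, since this config carries a hydra block;
# the exact entry point and key spelling may differ in your checkout):
#   python main.py batchsize=256 lr=0.0003 spacegroup=225 num_samples=500
# Any key in this file should be overridable from the command line in the same way.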
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dm-haiku==0.0.14 2 | optax==0.2.4 3 | pyxtal==1.0.7 -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | ## Post-Processing Scripts 2 | 3 | ### Contents 4 | - [Post-Processing Scripts](#post-processing-scripts) 5 | - [Contents](#contents) 6 | - [Transform](#transform) 7 | - [Structure and Composition Validity](#structure-and-composition-validity) 8 | - [Novelty and Uniqueness](#novelty-and-uniqueness) 9 | - [Relaxation](#relaxation) 10 | - [Energy Above the Hull](#energy-above-the-hull) 11 | - [Embedding Visualization](#embedding-visualization) 12 | - [Stable, Unique and Novel Structures](#stable-unique-and-novel-structures) 13 | - [Structure Visualization](#structure-visualization) 14 | 15 | ### Transform 16 | `awl2struct.py` is a script to transform the generated `L, W, A, X` to the `cif` format. 17 | 18 | ```bash 19 | python awl2struct.py --output_path YOUR_PATH --label SPACE_GROUP --num_io_process 40 20 | ``` 21 | - `output_path`: the path to read the generated `L, W, A, X` and save the `cif` files 22 | - `label`: the label to save the `cif` files, which is the space group number 23 | - `num_io_process`: the number of processes 24 | 25 | 26 | ### Structure and Composition Validity 27 | `compute_metrics.py` is a script to calculate the structure and composition validity of the generated structures. 28 | 29 | ```bash 30 | python ../scripts/compute_metrics.py --root_path YOUR_PATH --filename YOUR_FILE --num_io_process 40 31 | ``` 32 | - `root_path`: the path to the dataset 33 | - `filename`: the filename of the generated structures 34 | - `num_io_process`: the number of processes 35 | 36 | ### Novelty and Uniqueness 37 | `compute_metrics_matbench.py` is a script to calculate the novelty and uniqueness of the generated structures. 38 | ```bash 39 | python ../scripts/compute_metrics_matbench.py --train_path TRAIN_PATH --test_path TEST_PATH --gen_path GEN_PATH --output_path OUTPUT_PATH --label SPACE_GROUP --num_io_process 40 40 | ``` 41 | - `train_path`: the path to the training dataset 42 | - `test_path`: the path to the test dataset 43 | - `gen_path`: the path to the generated dataset 44 | - `output_path`: the path to save the metrics results 45 | - `label`: the label to save the metrics results, which is the space group number `g` 46 | - `num_io_process`: the number of processes 47 | 48 | Note that the training, test, and generated datasets should contain the structures within the **same** space group `g` which is specified in the command `--label`. 49 | 50 | 51 | ### Relaxation 52 | `mlff_relax.py` is a script to relax the generated structures using a pretrained machine learning force field. We currently support the [`orb`](https://github.com/orbital-materials/orb-models), [`MACE`](https://github.com/ACEsuit/mace), [`matgl`](https://github.com/materialsvirtuallab/matgl) and [`deepmd-kit`](https://github.com/deepmodeling/deepmd-kit) models. Please install the corresponding packages before running the script. 
53 | 54 | ```bash 55 | python mlff_relax.py --restore_path RESTORE_PATH --filename FILENAME --relaxation --model orb --model_path MODEL_PATH 56 | ``` 57 | - `restore_path`: the path to the generated structures 58 | - `filename`: the filename of the generated structures 59 | - `relaxation`: whether to relax the structures; if not specified, the script only predicts the energy of the structures without relaxation 60 | - `model`: the model to use for relaxation, which can be `orb`, `mace`, `matgl` or `dp` 61 | - `model_path`: the path to the machine learning force field checkpoint 62 | - `primitive`: whether to convert the structures to primitive cells; if not specified, the script relaxes the structures without converting them to primitive cells. Converting can reduce the number of atoms in the structures and speed up the relaxation process 63 | - `fixsymmetry`: whether to fix the space group symmetry of the structures during the relaxation process 64 | 65 | ### Energy Above the Hull 66 | `e_above_hull.py` is a script to calculate the energy above the hull of the generated structures based on the Materials Project database. To calculate the energy above the hull, an API key for the Materials Project is required, which can be obtained from the [Materials Project website](https://next-gen.materialsproject.org/). Furthermore, the `mp_api` package should be installed. 67 | 68 | ```bash 69 | python e_above_hull.py --restore_path RESTORE_PATH --filename FILENAME --api_key API_KEY --label LABEL --relaxation 70 | ``` 71 | - `restore_path`: the path to the structures 72 | - `filename`: the filename of the structures 73 | - `api_key`: the API key of the Materials Project 74 | - `label`: the label to save the energy above the hull file 75 | - `relaxation`: whether to calculate the energy above the hull based on the relaxed structures 76 | 77 | `e_above_hull_alex.py` is a script to calculate the energy above the hull of the generated structures based on the Alexandria database. To calculate the energy above the hull, the Alexandria convex hull data is required, which can be obtained from the [Alexandria website](https://alexandria.icams.rub.de/). 78 | 79 | ```bash 80 | python e_above_hull_alex.py --convex_path CONVEX_PATH --restore_path RESTORE_PATH --filename FILENAME --label LABEL --relaxation 81 | ``` 82 | - `convex_path`: the path to the Alexandria convex hull data 83 | - `restore_path`: the path to the structures 84 | - `filename`: the filename of the structures 85 | - `num_io_process`: the number of processes 86 | - `label`: the label to save the energy above the hull file 87 | - `relaxation`: whether to calculate the energy above the hull based on the relaxed structures 88 | 89 | ### Embedding Visualization 90 | `plot_embedding.py` is a script to visualize the correlation of the learned embedding vectors of different elements. 91 | 92 | ```bash 93 | python plot_embedding.py --restore_path RESTORE_PATH 94 | ``` 95 | 96 | - `restore_path`: the path to the model checkpoint 97 | 98 | ### Stable, Unique and Novel Structures 99 | `check_sun_materials.py` is a script to check the stable, unique and novel structures based on the given reference dataset. 
100 | 101 | ```bash 102 | python check_sun_materials.py --restore_path RESTORE_PATH --filename FILENAME --ref_path REF_PATH 103 | ``` 104 | 105 | - `restore_path`: the path to the generated structures 106 | - `filename`: the filename of the generated structures 107 | - `ref_path`: the path to the reference dataset 108 | 109 | ### Structure Visualization 110 | `structure_visualization.ipynb` is a notebook to visualize the generated structures. 111 | -------------------------------------------------------------------------------- /scripts/awl2struct.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./crystalformer/src/') 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from ast import literal_eval 7 | import multiprocessing 8 | import itertools 9 | import argparse 10 | 11 | from pymatgen.core import Structure, Lattice 12 | from wyckoff import wmax_table, mult_table, symops 13 | 14 | symops = np.array(symops) 15 | mult_table = np.array(mult_table) 16 | wmax_table = np.array(wmax_table) 17 | 18 | 19 | def symmetrize_atoms(g, w, x): 20 | ''' 21 | symmetrize atoms by applying all space group symmetry ops, finding the generator, and lastly applying the symops of the given Wyckoff position 22 | we need to do that because the sampled atom might not be at the first WP 23 | Args: 24 | g: int 25 | w: int 26 | x: (3,) 27 | Returns: 28 | xs: (m, 3) symmetrized atom positions 29 | ''' 30 | 31 | # (1) apply all space group symmetry op to the x 32 | w_max = wmax_table[g-1].item() 33 | m_max = mult_table[g-1, w_max].item() 34 | ops = symops[g-1, w_max, :m_max] # (m_max, 3, 4) 35 | affine_point = np.array([*x, 1]) # (4, ) 36 | coords = ops@affine_point # (m_max, 3) 37 | coords -= np.floor(coords) 38 | 39 | # (2) search for the generator which satisfies op0(x) = x , i.e. 
the first Wyckoff position 40 | # here we solve it in a jit friendly way by looking for the minimal distance solution for the lhs and rhs 41 | #https://github.com/qzhu2017/PyXtal/blob/82e7d0eac1965c2713179eeda26a60cace06afc8/pyxtal/wyckoff_site.py#L115 42 | def dist_to_op0x(coord): 43 | diff = np.dot(symops[g-1, w, 0], np.array([*coord, 1])) - coord 44 | diff -= np.rint(diff) 45 | return np.sum(diff**2) 46 | # loc = np.argmin(jax.vmap(dist_to_op0x)(coords)) 47 | loc = np.argmin([dist_to_op0x(coord) for coord in coords]) 48 | x = coords[loc].reshape(3,) 49 | 50 | # (3) lastly, apply the given symmetry op to x 51 | m = mult_table[g-1, w] 52 | ops = symops[g-1, w, :m] # (m, 3, 4) 53 | affine_point = np.array([*x, 1]) # (4, ) 54 | xs = ops@affine_point # (m, 3) 55 | xs -= np.floor(xs) # wrap back to 0-1 56 | return xs 57 | 58 | def get_struct_from_lawx(G, L, A, W, X): 59 | """ 60 | Get the pymatgen.Structure object from the input data 61 | 62 | Args: 63 | G: space group number 64 | L: lattice parameters 65 | A: element number list 66 | W: wyckoff letter list 67 | X: fractional coordinates list 68 | 69 | Returns: 70 | struct: pymatgen.Structure object 71 | """ 72 | mask = np.nonzero(A) # compute the mask once, before filtering A, so that A, X and W are filtered with the same indices 73 | X, W = X[mask], W[mask] 74 | A = A[mask] 75 | 76 | lattice = Lattice.from_parameters(*L) 77 | xs_list = [symmetrize_atoms(G, w, x) for w, x in zip(W, X)] 78 | as_list = [[A[idx] for _ in range(len(xs))] for idx, xs in enumerate(xs_list)] 79 | A_list = list(itertools.chain.from_iterable(as_list)) 80 | X_list = list(itertools.chain.from_iterable(xs_list)) 81 | struct = Structure(lattice, A_list, X_list) 82 | return struct.as_dict() 83 | 84 | 85 | def main(args): 86 | if args.label is not None: 87 | input_path = args.output_path + f'output_{args.label}.csv' 88 | output_path = args.output_path + f'output_{args.label}_struct.csv' 89 | else: 90 | input_path = args.output_path + 'output.csv' 91 | output_path = args.output_path + 'output_struct.csv' 92 | 93 | origin_data = pd.read_csv(input_path) 94 | 95 | L,X,A,W = origin_data['L'],origin_data['X'],origin_data['A'],origin_data['W'] 96 | L = L.apply(lambda x: literal_eval(x)) 97 | X = X.apply(lambda x: literal_eval(x)) 98 | A = A.apply(lambda x: literal_eval(x)) 99 | W = W.apply(lambda x: literal_eval(x)) 100 | # M = M.apply(lambda x: literal_eval(x)) 101 | 102 | # convert array of list to numpy ndarray 103 | L = np.array(L.tolist()) 104 | X = np.array(X.tolist()) 105 | A = np.array(A.tolist()) 106 | W = np.array(W.tolist()) 107 | print(L.shape,X.shape,A.shape,W.shape) 108 | 109 | if args.label is None: 110 | G = origin_data['G'] 111 | G = np.array(G.tolist()) 112 | else: 113 | G = np.array([int(args.label) for _ in range(len(L))]) 114 | 115 | ### Multiprocessing. 
Use it only when running on CPU 116 | p = multiprocessing.Pool(args.num_io_process) 117 | structures = p.starmap_async(get_struct_from_lawx, zip(G, L, A, W, X)).get() 118 | p.close() 119 | p.join() 120 | 121 | data = pd.DataFrame() 122 | data['cif'] = structures 123 | data.to_csv(output_path, mode='a', index=False, header=True) 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser(description='') 128 | parser.add_argument('--output_path', default='./', help='filepath of the output and input file') 129 | parser.add_argument('--label', default=None, help='output file label') 130 | parser.add_argument('--num_io_process', type=int, default=40, help='number of processes used in multiprocessing io') 131 | args = parser.parse_args() 132 | main(args) 133 | -------------------------------------------------------------------------------- /scripts/check_sun_materials.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | from pymatgen.core import Structure, Composition 5 | from pymatgen.analysis.structure_matcher import StructureMatcher 6 | from pymatgen.symmetry.analyzer import SpacegroupAnalyzer 7 | 8 | 9 | def make_compare_structures(matcher): 10 | """ 11 | Args: 12 | matcher: pymatgen.analysis.structure_matcher.StructureMatcher instance 13 | 14 | Returns: 15 | compare_structures: function, compare two structures 16 | """ 17 | 18 | def compare_structures(s1, s2): 19 | """ 20 | Args: 21 | s1: pymatgen Structure 22 | s2: pymatgen Structure 23 | 24 | Returns: 25 | bool, True if the two structures are the same 26 | """ 27 | 28 | if s1.composition.reduced_composition != s2.composition.reduced_composition: 29 | return False 30 | else: 31 | return matcher.fit(s1, s2) 32 | 33 | return compare_structures 34 | 35 | 36 | def make_search_duplicate(ref_data, matcher, spg_search=False): 37 | """ 38 | Args: 39 | ref_data: pd.DataFrame, reference data 40 | matcher: pymatgen.analysis.structure_matcher.StructureMatcher instance 41 | spg_search: bool, whether to filter the reference data by space group number 42 | 43 | Returns: 44 | search_duplicate: function, search for duplicates in the reference data 45 | """ 46 | 47 | def search_duplicate(s): 48 | """ 49 | Args: 50 | s: pymatgen Structure 51 | 52 | Returns: 53 | duplicate: bool, True if the structure is a duplicate 54 | 55 | sometimes the matching of the space group number is not accurate, 56 | so we do not use it to filter the reference data by default 57 | """ 58 | 59 | if spg_search: 60 | try: 61 | spg_analyzer = SpacegroupAnalyzer(s) 62 | spg = spg_analyzer.get_space_group_number() 63 | 64 | except Exception as e: 65 | spg = None 66 | print(e) 67 | print(f"Error with structure {s}") 68 | pass 69 | 70 | if spg is not None: 71 | sub_data = ref_data[ref_data['spg'] == spg] 72 | else: 73 | sub_data = ref_data 74 | 75 | else: 76 | sub_data = ref_data 77 | 78 | # pick all structures with the same composition 79 | sub_data = sub_data[sub_data['composition'] == s.composition.reduced_composition] 80 | 81 | duplicate = False 82 | # compare the structure with all structures with the same composition 83 | for s2 in sub_data['structure']: 84 | s2 = Structure.from_dict(eval(s2)) 85 | if matcher.fit(s, s2): 86 | duplicate = True 87 | break 88 | 89 | return duplicate 90 | 91 | return search_duplicate 92 | 93 | 94 | def main(args): 95 | 96 | # print all the parameters 97 | for arg in vars(args): 98 | print(f"{arg}: {getattr(args, arg)}") 99 | 100 | data
= pd.read_csv(os.path.join(args.restore_path, args.filename)) 101 | ref_data = pd.read_csv(args.ref_path) 102 | 103 | # only keep the necessary columns 104 | ref_data = ref_data[['formula', 'elements', 'structure', 'spg']] 105 | 106 | if args.spg_search and args.spg is not None: 107 | ref_data = ref_data[ref_data['spg'] == args.spg] # filter by space group 108 | print(f"Number of structures in the reference data with space group {args.spg}: {ref_data.shape[0]}") 109 | else: 110 | print(f"Number of structures in the reference data: {ref_data.shape[0]}") 111 | 112 | sm = StructureMatcher() 113 | compare_structures = make_compare_structures(sm) 114 | 115 | # remove unstable structures 116 | data = data[data['relaxed_ehull'] <= 0.1] 117 | structures = [Structure.from_dict(eval(crys_dict)) for crys_dict in data['relaxed_cif']] 118 | print(f"Number of stable structures: {len(structures)}") 119 | 120 | # remove duplicates (Uniqueness) 121 | idx_list = [] 122 | unique_structures = [] 123 | for idx, s in enumerate(structures): 124 | if not any([compare_structures(s, us) for us in unique_structures]): 125 | unique_structures.append(s) 126 | idx_list.append(idx) 127 | 128 | data = data.iloc[idx_list] 129 | print(f"Number of stable and unique structures: {len(unique_structures)}") 130 | 131 | # remove structures that are already in the reference data (Novelty) 132 | comp_list = [] 133 | for idx, formula in enumerate(ref_data['formula']): 134 | try: 135 | comp = Composition(formula) 136 | comp_list.append(comp) 137 | except Exception as e: 138 | # Composition parsing fails when pandas has read the formula string "NaN" (sodium nitride) as a missing value 139 | print(e) 140 | print(f"Error with formula {formula}") 141 | if ref_data.iloc[idx]['elements'] == "['Na', 'N']": 142 | comp_list.append(Composition("NaN")) 143 | 144 | print(len(comp_list)) 145 | comp_list = [comp.reduced_composition for comp in comp_list] 146 | ref_data['composition'] = comp_list 147 | 148 | search_duplicate = make_search_duplicate(ref_data, sm, args.spg_search) 149 | duplicate_list = list(map(search_duplicate, unique_structures)) 150 | 151 | # keep the indices of the non-duplicate structures 152 | idx_list = [idx for idx, duplicate in enumerate(duplicate_list) if not duplicate] 153 | data = data.iloc[idx_list] 154 | print(f"Number of stable, unique and novel structures: {data.shape[0]}") 155 | 156 | if args.spg is not None: 157 | data.to_csv(os.path.join(args.restore_path, f"sun_structures_{args.spg}.csv"), index=False) 158 | else: 159 | data.to_csv(os.path.join(args.restore_path, "sun_structures.csv"), index=False) 160 | 161 | 162 | if __name__ == "__main__": 163 | import argparse 164 | parser = argparse.ArgumentParser("Check the stable, unique and novel structures") 165 | parser.add_argument("--spg", type=int, default=None, help="Space group number") 166 | parser.add_argument("--spg_search", action="store_true", help="Whether to filter the reference data by space group number") 167 | parser.add_argument("--restore_path", type=str, default=None, help="Path to the restored data") 168 | parser.add_argument("--filename", type=str, default="relaxed_structures_ehull.csv", help="Filename of the restored data") 169 | parser.add_argument("--ref_path", type=str, default="/data/zdcao/crystal_gpt/dataset/alex/PBE/alex20/alex20.csv", help="Path to the reference data") 170 | args = parser.parse_args() 171 | main(args) 172 | -------------------------------------------------------------------------------- /scripts/compute_metrics.py: -------------------------------------------------------------------------------- 1 | 
# Taken from: https://github.com/txie-93/cdvae/blob/main/scripts/compute_metrics.py 2 | from collections import Counter 3 | import argparse 4 | import json 5 | import os 6 | import pandas as pd 7 | from ast import literal_eval 8 | 9 | import numpy as np 10 | import multiprocessing 11 | from pathlib import Path 12 | 13 | from pymatgen.core.structure import Structure, Composition 14 | from matminer.featurizers.site.fingerprint import CrystalNNFingerprint 15 | from matminer.featurizers.composition.composite import ElementProperty 16 | 17 | from eval_utils import ( 18 | smact_validity, structure_validity) 19 | 20 | # TODO: AttributeError in CrystalNNFP 21 | CrystalNNFP = CrystalNNFingerprint.from_preset("ops") 22 | CompFP = ElementProperty.from_preset('magpie') 23 | 24 | 25 | class Crystal(object): 26 | 27 | def __init__(self, crys_dict): 28 | self.crys_dict = crys_dict 29 | 30 | self.get_structure() 31 | self.get_composition() 32 | self.get_validity() 33 | # self.get_fingerprints() 34 | 35 | def get_structure(self): 36 | try: 37 | self.structure = Structure.from_dict(self.crys_dict) 38 | self.atom_types = [s.specie.number for s in self.structure] 39 | self.constructed = True 40 | if self.structure.volume < 0.1: 41 | self.constructed = False 42 | self.invalid_reason = 'unrealistically_small_lattice' 43 | except Exception: 44 | self.constructed = False 45 | self.invalid_reason = 'construction_raises_exception' 46 | 47 | def get_composition(self): 48 | elem_counter = Counter(self.atom_types) 49 | composition = [(elem, elem_counter[elem]) 50 | for elem in sorted(elem_counter.keys())] 51 | elems, counts = list(zip(*composition)) 52 | counts = np.array(counts) 53 | counts = counts / np.gcd.reduce(counts) 54 | self.elems = elems 55 | self.comps = tuple(counts.astype('int').tolist()) 56 | 57 | def get_validity(self): 58 | self.comp_valid = smact_validity(self.elems, self.comps) 59 | if self.constructed: 60 | self.struct_valid = structure_validity(self.structure) 61 | else: 62 | self.struct_valid = False 63 | self.valid = self.comp_valid and self.struct_valid 64 | 65 | def get_fingerprints(self): 66 | elem_counter = Counter(self.atom_types) 67 | comp = Composition(elem_counter) 68 | self.comp_fp = CompFP.featurize(comp) 69 | try: 70 | site_fps = [CrystalNNFP.featurize( 71 | self.structure, i) for i in range(len(self.structure))] 72 | except Exception: 73 | # counts crystal as invalid if fingerprint cannot be constructed. 
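            # (per the TODO near the imports, CrystalNNFP can raise an AttributeError
            # for some structures; such failures land here and the crystal is marked invalid)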
74 | self.valid = False 75 | self.comp_fp = None 76 | self.struct_fp = None 77 | return 78 | self.struct_fp = np.array(site_fps).mean(axis=0) 79 | 80 | 81 | def get_validity(crys): 82 | comp_valid = np.array([c.comp_valid for c in crys]).mean() 83 | struct_valid = np.array([c.struct_valid for c in crys]).mean() 84 | valid = np.array([c.valid for c in crys]).mean() 85 | return {'comp_valid': comp_valid, 86 | 'struct_valid': struct_valid, 87 | 'valid': valid} 88 | 89 | def get_crystal(cif_dict): 90 | try: return Crystal(cif_dict) 91 | except: 92 | print("Crystal construction failed") 93 | # print(cif_dict) 94 | struct = Structure.from_dict(cif_dict) 95 | print(struct) 96 | return None # return None if Crystal construction fails 97 | 98 | def main(args): 99 | all_metrics = {} 100 | 101 | csv_path = os.path.join(args.root_path, args.filename) 102 | data = pd.read_csv(csv_path) 103 | cif_strings = data['cif'] 104 | 105 | p = multiprocessing.Pool(args.num_io_process) 106 | crys_dict = p.map_async(literal_eval, cif_strings).get() 107 | # crys = p.map_async(Crystal, crys_dict).get() 108 | crys = p.map_async(get_crystal, crys_dict).get() 109 | crys = [c for c in crys if c is not None] 110 | print(f"Number of valid crystals: {len(crys)}") 111 | p.close() 112 | p.join() 113 | 114 | all_metrics['validity'] = get_validity(crys) 115 | print(all_metrics) 116 | 117 | if args.label == '': 118 | metrics_out_file = 'eval_metrics.json' 119 | else: 120 | metrics_out_file = f'eval_metrics_{args.label}.json' 121 | metrics_out_file = os.path.join(args.root_path, metrics_out_file) 122 | print("output path:", metrics_out_file) 123 | 124 | # only overwrite metrics computed in the new run. 125 | if Path(metrics_out_file).exists(): 126 | with open(metrics_out_file, 'r') as f: 127 | written_metrics = json.load(f) 128 | if isinstance(written_metrics, dict): 129 | written_metrics.update(all_metrics) 130 | else: 131 | with open(metrics_out_file, 'w') as f: 132 | json.dump(all_metrics, f) 133 | if isinstance(written_metrics, dict): 134 | with open(metrics_out_file, 'w') as f: 135 | json.dump(written_metrics, f) 136 | else: 137 | with open(metrics_out_file, 'w') as f: 138 | json.dump(all_metrics, f) 139 | 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--root_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/symm_data/') 144 | parser.add_argument('--filename', default='out_structure.csv') 145 | parser.add_argument('--label', default='') 146 | parser.add_argument('--num_io_process', type=int, default=40, help='number of process used in multiprocessing io') 147 | args = parser.parse_args() 148 | main(args) 149 | 150 | -------------------------------------------------------------------------------- /scripts/compute_metrics_matbench.py: -------------------------------------------------------------------------------- 1 | # Function: compute the metrics of the generated structures 2 | # It takes about 20 min to compute the metrics of 9000 generated structures for MP-20 dataset 3 | import pandas as pd 4 | from pymatgen.core import Structure 5 | import multiprocessing 6 | import argparse 7 | from ast import literal_eval 8 | import json 9 | from time import time 10 | 11 | from matbench_genmetrics.core.metrics import GenMetrics 12 | 13 | 14 | def get_structure(cif): 15 | try: 16 | return Structure.from_str(cif, fmt='cif') 17 | except: 18 | return Structure.from_dict(literal_eval(cif)) 19 | 20 | def main(args): 21 | train_df = pd.read_csv(args.train_path) 22 | test_df = 
pd.read_csv(args.test_path) 23 | gen_df = pd.read_csv(args.gen_path) 24 | 25 | p = multiprocessing.Pool(args.num_io_process) 26 | train_structures = p.map_async(get_structure, train_df['cif']).get() 27 | test_structures = p.map_async(get_structure, test_df['cif']).get() 28 | gen_structures = p.map_async(get_structure, gen_df['cif']).get() 29 | p.close() 30 | p.join() 31 | 32 | start_time = time() 33 | all_metrics = {} 34 | gen_metrics = GenMetrics(train_structures=train_structures, 35 | test_structures=test_structures, 36 | gen_structures=gen_structures, 37 | ) 38 | 39 | # all_metrics = gen_metrics.metrics 40 | # all_metrics['validity'] = gen_metrics.validity 41 | all_metrics['novelty'] = gen_metrics.novelty 42 | all_metrics['uniqueness'] = gen_metrics.uniqueness 43 | 44 | end_time = time() 45 | print('Time used: {:.2f} s'.format(end_time - start_time)) 46 | print(all_metrics) 47 | with open(args.output_path + f'metrics_{args.label}.json', 'w') as f: 48 | json.dump(all_metrics, f, indent=4) 49 | 50 | if __name__ == '__main__': 51 | parser = argparse.ArgumentParser(description='') 52 | parser.add_argument('--train_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/sg_225/train_sg_225.csv', help='') 53 | parser.add_argument('--test_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/sg_225/test_sg_225.csv', help='') 54 | parser.add_argument('--gen_path', default='/data/zdcao/crystal_gpt/data/adam_bs_100_lr_0.0001_decay_0_clip_1_A_119_W_28_N_21_Nf_5_K_48_16_h0_256_l_4_H_8_k_16_m_32_drop_0.3/temp_1.0/output_225.csv', help='') 55 | parser.add_argument('--output_path', default='//data/zdcao/crystal_gpt/data/adam_bs_100_lr_0.0001_decay_0_clip_1_A_119_W_28_N_21_Nf_5_K_48_16_h0_256_l_4_H_8_k_16_m_32_drop_0.3/temp_1.0/', help='filepath of the metrics output file') 56 | parser.add_argument('--label', default='225', help='output file label') 57 | parser.add_argument('--num_io_process', type=int, default=40, help='number of process used in multiprocessing io') 58 | args = parser.parse_args() 59 | 60 | main(args) 61 | -------------------------------------------------------------------------------- /scripts/config.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | testdir = os.path.dirname(os.path.abspath(__file__)) 3 | sys.path.append(os.path.join(testdir, "../crystalformer/src")) 4 | -------------------------------------------------------------------------------- /scripts/e_above_hull.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import tempfile 5 | from mp_api.client import MPRester 6 | from pymatgen.core import Structure 7 | from pymatgen.analysis.phase_diagram import PhaseDiagram 8 | from pymatgen.entries.computed_entries import ComputedStructureEntry 9 | from pymatgen.entries.compatibility import MaterialsProject2020Compatibility 10 | from pymatgen.io.vasp.sets import MPRelaxSet 11 | from pymatgen.io.vasp.inputs import Incar, Poscar 12 | 13 | 14 | # Taken from https://github.com/facebookresearch/crystal-llm/blob/main/e_above_hull.py 15 | def generate_CSE(structure, m3gnet_energy): 16 | # Write VASP inputs files as if we were going to do a standard MP run 17 | # this is mainly necessary to get the right U values / etc 18 | b = MPRelaxSet(structure) 19 | with tempfile.TemporaryDirectory() as tmpdirname: 20 | b.write_input(f"{tmpdirname}/", potcar_spec=True) 21 | poscar = Poscar.from_file(f"{tmpdirname}/POSCAR") 22 | incar = 
Incar.from_file(f"{tmpdirname}/INCAR") 23 | clean_structure = Structure.from_file(f"{tmpdirname}/POSCAR") 24 | 25 | # Get the U values and figure out if we should have run a GGA+U calc 26 | param = {"hubbards": {}} 27 | if "LDAUU" in incar: 28 | param["hubbards"] = dict(zip(poscar.site_symbols, incar["LDAUU"])) 29 | param["is_hubbard"] = ( 30 | incar.get("LDAU", True) and sum(param["hubbards"].values()) > 0 31 | ) 32 | if param["is_hubbard"]: 33 | param["run_type"] = "GGA+U" 34 | 35 | # Make a ComputedStructureEntry without the correction 36 | cse_d = { 37 | "structure": clean_structure, 38 | "energy": m3gnet_energy, 39 | "correction": 0.0, 40 | "parameters": param, 41 | } 42 | 43 | # Apply the MP 2020 correction scheme (anion/+U/etc) 44 | cse = ComputedStructureEntry.from_dict(cse_d) 45 | _ = MaterialsProject2020Compatibility(check_potcar=False).process_entries( 46 | cse, 47 | clean=True, 48 | ) 49 | 50 | # Return the final CSE (notice that the composition/etc is also clean, not things like Fe3+)! 51 | return cse 52 | 53 | 54 | def get_structures_ehull(mpr, structures, energies): 55 | """ 56 | Get the e_above_hull for a list of structures 57 | 58 | Args: 59 | mpr: MPRester object 60 | structures: list of pymatgen.Structure objects 61 | energies: list of energies of the structures 62 | 63 | Returns: 64 | ehull_list: list of e_above_hull values 65 | """ 66 | ehull_list = [] 67 | for s, e in zip(structures, energies): 68 | # entry = PDEntry(s.composition, e) 69 | entry = generate_CSE(s, e) 70 | elements = [el.name for el in entry.composition.elements] 71 | 72 | # Obtain only corrected GGA and GGA+U ComputedStructureEntry objects 73 | entries = mpr.get_entries_in_chemsys(elements=elements, 74 | additional_criteria={"thermo_types": ["GGA_GGA+U"], 75 | "is_stable": True} # Only stable entries 76 | ) 77 | pdiag = PhaseDiagram(entries) # named pdiag to avoid shadowing the pandas import pd 78 | try: 79 | ehull = pdiag.get_e_above_hull(entry, allow_negative=True) 80 | ehull_list.append(ehull) 81 | print(f"Structure: {s.formula}, E_hull: {ehull:.3f} eV/atom") 82 | except: 83 | print(f"Structure: {s.formula}, E_hull: N/A") 84 | ehull_list.append(np.nan) 85 | 86 | return ehull_list 87 | 88 | 89 | def main(args): 90 | data = pd.read_csv(os.path.join(args.restore_path, args.filename)) 91 | cif_strings = data["relaxed_cif"] 92 | try: structures = [Structure.from_str(cif, fmt="cif") for cif in cif_strings] 93 | except: structures = [Structure.from_dict(eval(cif)) for cif in cif_strings] 94 | mpr = MPRester(args.api_key) 95 | 96 | unrelaxed_ehull_list = get_structures_ehull(mpr, structures, data["initial_energy"]) 97 | if args.relaxation: 98 | relaxed_ehull_list = get_structures_ehull(mpr, structures, data["final_energy"]) 99 | else: 100 | relaxed_ehull_list = [np.nan] * len(structures) # Fill with NaNs 101 | 102 | output_data = pd.DataFrame() 103 | output_data["relaxed_cif"] = cif_strings 104 | output_data["relaxed_ehull"] = relaxed_ehull_list 105 | output_data["unrelaxed_ehull"] = unrelaxed_ehull_list 106 | if args.label: 107 | output_data.to_csv(f"ehull_{args.label}.csv", index=False) 108 | else: 109 | output_data.to_csv("ehull.csv", index=False) 110 | 111 | 112 | if __name__ == "__main__": 113 | import argparse 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--restore_path", type=str, default="/data/zdcao/crystal_gpt/dataset/mp_20/") 116 | parser.add_argument('--filename', default='relaxed_structures_testdata.csv') 117 | parser.add_argument('--relaxation', action='store_true') 118 | parser.add_argument('--api_key', 
default='9zBRHS6Zp94KE28PeMdSk5gCyteIm6Ks') 119 | parser.add_argument('--label', default='testdata') 120 | args = parser.parse_args() 121 | main(args) 122 | -------------------------------------------------------------------------------- /scripts/e_above_hull_alex.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json, bz2 3 | import tempfile 4 | import pandas as pd 5 | import multiprocessing as mp 6 | from functools import partial 7 | from pymatgen.core import Structure 8 | from pymatgen.entries.computed_entries import ComputedStructureEntry 9 | from pymatgen.analysis.phase_diagram import PhaseDiagram 10 | from pymatgen.entries.compatibility import MaterialsProject2020Compatibility 11 | from pymatgen.io.vasp.sets import MPRelaxSet 12 | from pymatgen.io.vasp.inputs import Incar, Poscar 13 | 14 | 15 | # Taken from https://github.com/facebookresearch/crystal-llm/blob/main/e_above_hull.py 16 | def generate_CSE(structure, m3gnet_energy): 17 | # Write VASP inputs files as if we were going to do a standard MP run 18 | # this is mainly necessary to get the right U values / etc 19 | b = MPRelaxSet(structure) 20 | with tempfile.TemporaryDirectory() as tmpdirname: 21 | b.write_input(f"{tmpdirname}/", potcar_spec=True) 22 | poscar = Poscar.from_file(f"{tmpdirname}/POSCAR") 23 | incar = Incar.from_file(f"{tmpdirname}/INCAR") 24 | clean_structure = Structure.from_file(f"{tmpdirname}/POSCAR") 25 | 26 | # Get the U values and figure out if we should have run a GGA+U calc 27 | param = {"hubbards": {}} 28 | if "LDAUU" in incar: 29 | param["hubbards"] = dict(zip(poscar.site_symbols, incar["LDAUU"])) 30 | param["is_hubbard"] = ( 31 | incar.get("LDAU", True) and sum(param["hubbards"].values()) > 0 32 | ) 33 | if param["is_hubbard"]: 34 | param["run_type"] = "GGA+U" 35 | 36 | # Make a ComputedStructureEntry without the correction 37 | cse_d = { 38 | "structure": clean_structure, 39 | "energy": m3gnet_energy, 40 | "correction": 0.0, 41 | "parameters": param, 42 | } 43 | 44 | # Apply the MP 2020 correction scheme (anion/+U/etc) 45 | cse = ComputedStructureEntry.from_dict(cse_d) 46 | _ = MaterialsProject2020Compatibility(check_potcar=False).process_entries( 47 | cse, 48 | clean=True, 49 | ) 50 | 51 | # Return the final CSE (notice that the composition/etc is also clean, not things like Fe3+)! 
52 |         return cse
53 | 
54 | 
55 | def calculate_hull(structure, energy, entries):
56 |     entries = [ComputedStructureEntry.from_dict(i) for i in entries]
57 |     phase_diagram = PhaseDiagram(entries)  # renamed from `pd` to avoid shadowing pandas
58 | 
59 |     try:
60 |         entry = generate_CSE(structure, energy)
61 |         ehull = phase_diagram.get_e_above_hull(entry, allow_negative=True)
62 |         print(f"Structure: {structure.formula}, E_hull: {ehull:.3f} eV/atom")
63 |     except Exception as e:
64 |         print(f"Structure: {structure.formula}, E_hull: Error: {e}")
65 |         ehull = None
66 | 
67 |     return ehull
68 | 
69 | 
70 | def forward_fn(structure, energy, ref_data):
71 | 
72 |     comp = structure.composition
73 |     elements = set(ii.name for ii in comp.elements)
74 | 
75 |     # filter the reference entries down to those whose elements are a subset of this composition
76 |     entries = [entry for entry in ref_data['entries'] if set(entry['data']['elements']) <= elements]
77 |     ehull = calculate_hull(structure, energy, entries)
78 | 
79 |     return ehull
80 | 
81 | 
82 | def main(args):
83 |     with bz2.open(args.convex_path) as fh:
84 |         ref_data = json.loads(fh.read().decode('utf-8'))
85 |     partial_forward_fn = partial(forward_fn, ref_data=ref_data)
86 | 
87 |     data = pd.read_csv(os.path.join(args.restore_path, args.filename))
88 |     try: structures = [Structure.from_dict(eval(cif)) for cif in data['relaxed_cif']]
89 |     except Exception: structures = [Structure.from_str(cif, fmt="cif") for cif in data['relaxed_cif']]
90 | 
91 |     # with mp.Pool(args.num_io_process) as p:
92 |     #     unrelaxed_ehull_list = p.map_async(partial_forward_fn, zip(structures, data['initial_energy'])).get()
93 |     unrelaxed_ehull_list = list(map(partial_forward_fn, structures, data['initial_energy']))
94 |     data['unrelaxed_ehull'] = unrelaxed_ehull_list
95 | 
96 |     if args.relaxation:
97 |         # with mp.Pool(args.num_io_process) as p:
98 |         #     relaxed_ehull_list = p.map_async(partial_forward_fn, zip(structures, data['final_energy'])).get()
99 |         relaxed_ehull_list = list(map(partial_forward_fn, structures, data['final_energy']))
100 |         data['relaxed_ehull'] = relaxed_ehull_list
101 | 
102 |     else:
103 |         data['relaxed_ehull'] = unrelaxed_ehull_list  # same as unrelaxed
104 | 
105 |     if args.label:
106 |         data.to_csv(f"{args.restore_path}/relaxed_structures_{args.label}_ehull.csv", index=False)
107 |     else:
108 |         data.to_csv(f"{args.restore_path}/relaxed_structures_ehull.csv", index=False)
109 | 
110 | 
111 | if __name__ == '__main__':
112 |     import argparse
113 |     parser = argparse.ArgumentParser(description="Calculate e_above_hull for relaxed structures")
114 |     parser.add_argument("--convex_path", type=str, default="/data/zdcao/crystal_gpt/dataset/alex/PBE/convex_hull_pbe_2023.12.29.json.bz2")
115 |     parser.add_argument("--restore_path", type=str, default="./experimental/")
116 |     parser.add_argument('--filename', default='relaxed_structures.csv')
117 |     parser.add_argument('--relaxation', action='store_true')
118 |     parser.add_argument('--label', default=None)
119 |     parser.add_argument('--num_io_process', type=int, default=4)
120 |     args = parser.parse_args()
121 |     main(args)
122 | 
--------------------------------------------------------------------------------
/scripts/element_substition.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import argparse
3 | 
4 | from pymatgen.core import Structure, Composition
5 | from pymatgen.analysis.bond_valence import BVAnalyzer
6 | from pymatgen.analysis.structure_prediction.volume_predictor import DLSVolumePredictor
7 | from pymatgen.transformations.advanced_transformations import SubstitutionPredictorTransformation
8 | 
9 | 
10 | def main(args):
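    # Pipeline: decorate structures with oxidation states (BVAnalyzer),
    # propose substitutions ranked by probability (SubstitutionPredictorTransformation),
    # rescale volumes with DLSVolumePredictor, then keep the top-probability candidates.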
11 |     # initialize pymatgen objects
12 |     bv_analyzer = BVAnalyzer(symm_tol=0)
13 |     volume_predictor = DLSVolumePredictor()
14 |     elem_sub = SubstitutionPredictorTransformation()
15 | 
16 |     data = pd.read_csv(args.input_path)
17 | 
18 |     # select data with the requested spacegroup
19 |     data = data[data['spacegroup.number'] == args.spacegroup]
20 | 
21 |     # select data with the requested anonymized formula
22 |     cif_strings = []
23 |     for formula, cif in zip(data['pretty_formula'], data['cif']):
24 |         comp = Composition(formula)
25 |         if comp.anonymized_formula == args.anonymized_formula:
26 |             # print(formula)
27 |             cif_strings.append(cif)
28 |     print(len(cif_strings))
29 | 
30 |     structures = [Structure.from_str(cif, fmt='cif') for cif in cif_strings]
31 | 
32 |     oxi_structs = []
33 | 
34 |     for struct in structures:
35 |         try:
36 |             struct = bv_analyzer.get_oxi_state_decorated_structure(struct)
37 |             oxi_structs.append(struct)
38 |         except Exception:  # skip structures whose oxidation states cannot be assigned
39 |             print(struct.composition.reduced_formula)
40 | 
41 |     print(len(oxi_structs))
42 | 
43 |     sub_scale_structs = []
44 |     prob_list = []
45 |     for idx, struct in enumerate(oxi_structs):
46 |         print(idx)
47 |         try:
48 |             sub_structs = elem_sub.apply_transformation(struct, return_ranked_list=10)  # return the top 10 ranked substitutions
49 |             print(f"there are {len(sub_structs)} sub structures")
50 |             for _sub_struct in sub_structs:
51 | 
52 |                 if _sub_struct['probability'] < float(args.prob_threshold):
53 |                     continue
54 | 
55 |                 sub_struct = _sub_struct['structure']
56 |                 if sub_struct.matches(struct):
57 |                     continue
58 | 
59 |                 scale_struct = volume_predictor.get_predicted_structure(sub_struct)
60 |                 sub_scale_structs.append(scale_struct)
61 |                 prob_list.append(_sub_struct['probability'])
62 |         except Exception as e:
63 |             print(e)
64 |             # print(struct)
65 |             continue
66 | 
67 |     print(len(sub_scale_structs))
68 | 
69 |     # remove duplicate structures in the list
70 |     last_struct = []
71 |     last_prob = []
72 |     for idx, (struct, prob) in enumerate(zip(sub_scale_structs, prob_list)):
73 |         if idx == 0:
74 |             last_struct.append(struct)
75 |             last_prob.append(prob)
76 |             continue
77 | 
78 |         if struct in last_struct:
79 |             continue
80 |         last_struct.append(struct)
81 |         last_prob.append(prob)
82 |     print(len(last_struct))
83 | 
84 |     # convert to conventional cell
85 |     last_struct = [s.to_conventional() for s in last_struct]
86 | 
87 |     output_data = pd.DataFrame()
88 |     output_data['cif'] = [struct.as_dict() for struct in last_struct]
89 |     output_data['probability'] = last_prob
90 |     output_data = output_data.sort_values(by='probability', ascending=False)
91 |     # only keep the top_num structures
92 |     output_data = output_data.head(int(args.top_num))
93 |     output_data.to_csv(args.output_path, index=False)
94 | 
95 | 
96 | if __name__ == '__main__':
97 |     parser = argparse.ArgumentParser(description='')
98 |     parser.add_argument('--input_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/train.csv', help='filepath of the input file')
99 |     parser.add_argument('--spacegroup', default=216, help='spacegroup number')
100 |     parser.add_argument('--anonymized_formula', default='AB', help='anonymized formula')
101 |     parser.add_argument('--prob_threshold', default=0.05, help='probability threshold')
102 |     parser.add_argument('--top_num', default=100, help='top number of the output')
103 |     parser.add_argument('--output_path', default='./test.csv', help='filepath of the output file')
104 |     args = parser.parse_args()
105 |     main(args)
106 | 
--------------------------------------------------------------------------------
/scripts/eval_utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | 
4 | import smact
5 | from smact.screening import pauling_test
6 | 
7 | from config import *
8 | from elements import element_list
9 | 
10 | # Taken from the original CDVAE repo (https://github.com/txie-93/cdvae)
11 | # without modification, except that unused functions were deleted
12 | 
13 | def smact_validity(comp, count,
14 |                    use_pauling_test=True,
15 |                    include_alloys=True):
16 |     elem_symbols = tuple([element_list[elem] for elem in comp])
17 |     space = smact.element_dictionary(elem_symbols)
18 |     smact_elems = [e[1] for e in space.items()]
19 |     electronegs = [e.pauling_eneg for e in smact_elems]
20 |     ox_combos = [e.oxidation_states for e in smact_elems]
21 |     if len(set(elem_symbols)) == 1:
22 |         return True
23 |     if include_alloys:
24 |         is_metal_list = [elem_s in smact.metals for elem_s in elem_symbols]
25 |         if all(is_metal_list):
26 |             return True
27 | 
28 |     threshold = np.max(count)
29 |     compositions = []
30 |     for ox_states in itertools.product(*ox_combos):
31 |         stoichs = [(c,) for c in count]
32 |         # Test for charge balance
33 |         cn_e, cn_r = smact.neutral_ratios(
34 |             ox_states, stoichs=stoichs, threshold=threshold)
35 |         # Electronegativity test
36 |         if cn_e:
37 |             if use_pauling_test:
38 |                 try:
39 |                     electroneg_OK = pauling_test(ox_states, electronegs)
40 |                 except TypeError:
41 |                     # if no electronegativity data, assume it is okay
42 |                     electroneg_OK = True
43 |             else:
44 |                 electroneg_OK = True
45 |             if electroneg_OK:
46 |                 for ratio in cn_r:
47 |                     compositions.append(
48 |                         tuple([elem_symbols, ox_states, ratio]))
49 |     compositions = [(i[0], i[2]) for i in compositions]
50 |     compositions = list(set(compositions))
51 |     if len(compositions) > 0:
52 |         return True
53 |     else:
54 |         return False
55 | 
56 | 
57 | def structure_validity(crystal, cutoff=0.5):
58 |     dist_mat = crystal.distance_matrix
59 |     # Pad diagonal with a large number
60 |     dist_mat = dist_mat + np.diag(
61 |         np.ones(dist_mat.shape[0]) * (cutoff + 10.))
62 |     if dist_mat.min() < cutoff or crystal.volume < 0.1:
63 |         return False
64 |     else:
65 |         return True
--------------------------------------------------------------------------------
/scripts/mlff_relax.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | # To suppress warnings for clearer output
3 | warnings.simplefilter("ignore")
4 | 
5 | import pandas as pd
6 | import os
7 | from time import time
8 | from ast import literal_eval
9 | from tqdm import tqdm
10 | 
11 | from pymatgen.core import Structure
12 | from pymatgen.io.ase import AseAtomsAdaptor
13 | 
14 | from ase.optimize import FIRE
15 | from ase.filters import FrechetCellFilter
16 | from ase.constraints import FixSymmetry
17 | 
18 | 
19 | def make_orb_calc(model_path, device="cuda"):
20 |     from orb_models.forcefield import pretrained
21 |     from orb_models.forcefield.calculator import ORBCalculator
22 | 
23 |     # Load the ORB forcefield model
24 |     orbff = pretrained.orb_v2(model_path, device=device)
25 |     calc = ORBCalculator(orbff, device=device)
26 | 
27 |     return calc
28 | 
29 | 
30 | def make_matgl_calc(model_path, device="cuda"):
31 |     import matgl
32 |     from matgl.ext.ase import PESCalculator
33 | 
34 |     pot = matgl.load_model(model_path, device=device)
35 |     calc = PESCalculator(pot)
36 | 
37 |     return calc
38 | 
39 | 
40 | def make_mace_calc(model_path, device="cuda", default_dtype="float64"):
41 |     from mace.calculators import mace_mp
42 | 
43 |     calc = mace_mp(model=model_path,
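                   # mace_mp loads a MACE-MP foundation model checkpoint;
                   # float64 is slower but generally more robust for relaxations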
44 |                    dispersion=False,
45 |                    default_dtype=default_dtype,
46 |                    device=device)
47 |     return calc
48 | 
49 | 
50 | def make_dp_calc(model_path):
51 |     from deepmd.calculator import DP
52 | 
53 |     calc = DP(model_path)
54 |     return calc
55 | 
56 | 
57 | def relax_structures(calc, structures, relaxation, fmax, steps, fixsymmetry):
58 |     """
59 |     Args:
60 |         calc: ASE calculator object
61 |         structures: List of pymatgen Structure objects
62 |         relaxation: Boolean, whether to relax the structures
63 |         fmax: Maximum force tolerance for relaxation
64 |         steps: Maximum number of steps for relaxation
65 |         fixsymmetry: Boolean, whether to fix the space group symmetry during relaxation
66 |     Returns:
67 |         initial_energies: List of initial energies
68 |         final_energies: List of final energies
69 |         relaxed_cif_strings: List of relaxed structures in CIF format
70 |         formula_list: List of formulas of the structures
71 | 
72 |     If relaxation is False, the final energies will be the same as the initial energies.
73 |     """
74 | 
75 |     ase_adaptor = AseAtomsAdaptor()
76 | 
77 |     initial_energies = []
78 |     final_energies = []
79 |     relaxed_cif_strings = []
80 | 
81 |     for i in tqdm(range(len(structures))):
82 |         struct = structures[i]
83 |         atoms = ase_adaptor.get_atoms(struct)
84 |         atoms.calc = calc
85 | 
86 |         initial_energy = atoms.get_potential_energy()
87 |         initial_energies.append(initial_energy)
88 | 
89 |         if relaxation:
90 |             if fixsymmetry:
91 |                 try:
92 |                     # Fix the space group symmetry of the structure
93 |                     c = FixSymmetry(atoms)
94 |                     atoms.set_constraint(c)
95 |                 except Exception as e:
96 |                     # sometimes the FixSymmetry constraint may not work if atoms are too close
97 |                     print(f"Error fixing symmetry: {e}")
98 | 
99 |             # The following code is adapted from the matgl repo
100 |             # https://github.com/materialsvirtuallab/matgl/blob/824c1c4cefa9129c0af7066523d1665515f42899/src/matgl/ext/ase.py#L218-L304
101 |             # Relax cell and positions together with the FIRE optimizer
102 |             cell_filter = FrechetCellFilter(atoms)
103 |             FIRE(cell_filter).run(fmax=fmax, steps=steps)  # run FIRE for at most `steps` steps
104 | 
105 |             final_energies.append(atoms.get_potential_energy())
106 |             relaxed_cif_strings.append(ase_adaptor.get_structure(atoms).as_dict())
107 | 
108 |         else:
109 |             final_energies.append(initial_energy)
110 |             relaxed_cif_strings.append(struct.as_dict())
111 | 
112 | 
113 |     formula_list = [struct.composition.formula for struct in structures]
114 | 
115 |     return initial_energies, final_energies, relaxed_cif_strings, formula_list
116 | 
117 | 
118 | def main(args):
119 |     csv_file = os.path.join(args.restore_path, args.filename)
120 | 
121 |     data = pd.read_csv(csv_file)
122 |     cif_strings = data['cif']
123 | 
124 |     try: structures = [Structure.from_dict(literal_eval(cif)) for cif in cif_strings]
125 |     except Exception: structures = [Structure.from_str(cif, fmt="cif") for cif in cif_strings]
126 | 
127 |     if args.primitive:
128 |         print("Converting structures to primitive form...")
129 |         structures = [struct.get_primitive_structure() for struct in structures]
130 | 
131 |     print("Relaxing structures...")
132 |     if args.relaxation:
133 |         print("Relaxation is enabled. This may take a while.")
134 |     else:
135 |         print("Relaxation is disabled. Only initial energies will be calculated.")
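    # With --fixsymmetry, the FIRE relaxation above preserves the input space
    # group, so symmetry-breaking relaxation paths are deliberately excluded.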
136 | 
137 |     if args.fixsymmetry:
138 |         print("Fixing space group symmetry of the structures.")
139 | 
140 |     print(f"Using {args.model} model at {args.model_path}")
141 |     if args.model == "orb":
142 |         calc = make_orb_calc(args.model_path, args.device)
143 |     elif args.model == "matgl":
144 |         calc = make_matgl_calc(args.model_path, args.device)
145 |     elif args.model == "mace":
146 |         calc = make_mace_calc(args.model_path, args.device)
147 |     elif args.model == "dp":
148 |         calc = make_dp_calc(args.model_path)
149 |     else:
150 |         raise ValueError("Invalid model type. Please choose from 'orb', 'matgl', 'mace' or 'dp'.")
151 | 
152 |     print("Calculating energies...")
153 |     start_time = time()
154 |     results = relax_structures(calc, structures, args.relaxation, args.fmax, args.steps, args.fixsymmetry)
155 |     end_time = time()
156 |     print(f"Relaxation took {end_time - start_time:.2f} seconds")
157 | 
158 |     initial_energies, final_energies, relaxed_cif_strings, formula_list = results
159 |     output_data = pd.DataFrame()
160 |     output_data['initial_energy'] = initial_energies
161 |     output_data['final_energy'] = final_energies
162 |     output_data['relaxed_cif'] = relaxed_cif_strings
163 |     output_data['formula'] = formula_list
164 | 
165 |     if args.label:
166 |         output_data.to_csv(os.path.join(args.restore_path, f"relaxed_structures_{args.label}.csv"),
167 |                            index=False)
168 |     else:
169 |         output_data.to_csv(os.path.join(args.restore_path, "relaxed_structures.csv"),
170 |                            index=False)
171 | 
172 | 
173 | if __name__ == "__main__":
174 | 
175 |     import argparse
176 |     parser = argparse.ArgumentParser()
177 |     parser.add_argument("--model", type=str, choices=["orb", "matgl", "mace", "dp"], default="orb", help="choose the MLFF model")
178 |     parser.add_argument("--device", type=str, default="cuda", help="choose the device to run the model on")
179 |     parser.add_argument("--model_path", type=str, default="./data/orb-v2-20241011.ckpt", help="path to the model checkpoint")
180 |     parser.add_argument("--restore_path", type=str, default="./experimental/", help="directory containing the input csv file")
181 |     parser.add_argument('--filename', default='output_struct.csv', help='filename of the csv file containing the structures')
182 |     parser.add_argument('--relaxation', action='store_true', help='whether to relax the structures')
183 |     parser.add_argument('--fmax', type=float, default=0.1, help='maximum force tolerance for relaxation')
184 |     parser.add_argument('--steps', type=int, default=500, help='max number of steps for relaxation')
185 |     parser.add_argument('--label', default=None, help='label for the output file')
186 |     parser.add_argument('--primitive', action='store_true', help='convert structures to primitive form')
187 |     parser.add_argument('--fixsymmetry', action='store_true', help='fix space group symmetry of the structures')
188 | 
189 |     args = parser.parse_args()
190 |     main(args)
191 | 
--------------------------------------------------------------------------------
/scripts/plot_embedding.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import os
6 | from functools import partial
7 | 
8 | import sys
9 | sys.path.append("./crystalformer/src/")
10 | import checkpoint
11 | from elements import element_list
12 | 
13 | @partial(jax.vmap, in_axes=(0, None), out_axes=0)
14 | @partial(jax.vmap, in_axes=(None, 0), out_axes=0)
15 | def cosine_similarity(vector1, vector2):
16 |     dot_product = jnp.dot(vector1, vector2)
17 |     norm_a = jnp.linalg.norm(vector1)
18 |     norm_b = jnp.linalg.norm(vector2)
19 |     return dot_product / (norm_a * norm_b)
20 | 
21 | import argparse
22 | parser = argparse.ArgumentParser(description="plot element and space group embedding similarities")
23 | parser.add_argument("--restore_path", default="/data/wanglei/crystalgpt/mp-mpsort-xyz-embed/w-a-x-y-z-periodic-fixed-size-embed-eb630/adam_bs_100_lr_0.0001_decay_0_clip_1_A_119_W_28_N_21_a_1_w_1_l_1_Nf_5_Kx_16_Kl_4_h0_256_l_8_H_8_k_32_m_64_e_32_drop_0.3/", help="checkpoint directory")
24 | args = parser.parse_args()
25 | 
26 | path = os.path.dirname(args.restore_path)
27 | 
28 | 
29 | ckpt_filename, epoch_finished = checkpoint.find_ckpt_filename(args.restore_path)
30 | print("Load checkpoint file: %s, epoch finished: %g" % (ckpt_filename, epoch_finished))
31 | ckpt = checkpoint.load_data(ckpt_filename)
32 | 
33 | a_embeddings = ckpt["params"]["~"]["a_embeddings"]
34 | a_a = cosine_similarity(a_embeddings, a_embeddings)
35 | 
36 | g_embeddings = ckpt["params"]["~"]["g_embeddings"]
37 | g_g = cosine_similarity(g_embeddings, g_embeddings)
38 | 
39 | print (a_a.shape, g_g.shape)
40 | 
41 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
42 | 
43 | ax = axes[0]
44 | a_max = 90
45 | ticks = np.arange(a_max)
46 | element_ticks = [element_list[i+1] for i in ticks]
47 | ax.set_xticks(ticks, labels=element_ticks, fontsize=8, rotation=90)
48 | ax.set_yticks(ticks, labels=element_ticks, fontsize=8)
49 | cax = ax.imshow(a_a[1:a_max+1, 1:a_max+1], cmap='coolwarm', interpolation='none')
50 | fig.colorbar(cax, ax=ax)
51 | 
52 | ax = axes[1]
53 | cax = ax.imshow(g_g[:100, :100], cmap='coolwarm', interpolation='none')
54 | fig.colorbar(cax, ax=ax)
55 | 
56 | plt.show()
57 | 
--------------------------------------------------------------------------------
/scripts/process_alex.py:
--------------------------------------------------------------------------------
1 | # alexandria dataset: https://alexandria.icams.rub.de/, https://archive.materialscloud.org/record/2022.126
2 | # This script processes the raw data from the alexandria dataset and saves it into csv files
3 | import os
4 | import json
5 | import bz2
6 | import pandas as pd
7 | import multiprocessing as mp
8 | from sklearn.model_selection import train_test_split
9 | 
10 | 
11 | def get_file_name(filepath):
12 |     filename_list = []
13 |     for root, _, files in os.walk(filepath):
14 |         for file in files:
15 |             if file.endswith(".json.bz2"):
16 |                 filename_list.append(os.path.join(root, file))
17 |     return filename_list
18 | 
19 | 
20 | def get_data_from_file(filename):
21 |     with bz2.open(filename) as fh:
22 |         data = json.loads(fh.read().decode('utf-8'))
23 |     # no explicit fh.close() needed; the with-block closes the file
24 | 
25 |     print(len(data["entries"]))
26 |     entries = data["entries"]
27 |     # save the key information to a dataframe
28 |     df = pd.DataFrame([{'e_above_hull': entry['data']['e_above_hull'],
29 |                         'e_form': entry['data']['e_form'],
30 |                         'mat_id': entry['data']['mat_id'],
31 |                         'formula': entry['data']['formula'],
32 |                         'elements': entry['data']['elements'],
33 |                         'spg': entry['data']['spg'],
34 |                         'band_gap_dir': entry['data']['band_gap_dir'],
35 |                         'band_gap_ind': entry['data']['band_gap_ind'],
36 |                         'nsites': entry['data']['nsites'],
37 |                         'structure': entry['structure']} for entry in entries])
38 |     # screen the data
39 |     df = df[(df['e_above_hull'] <= 0.1) & (df['nsites'] <= 20)]
40 | 
41 |     return df
42 | 
43 | 
44 | def main(args):
45 |     filename_list = get_file_name(args.input_path)
46 |     print(len(filename_list))
47 |     with mp.Pool(args.num_io_process) as pool:
48 |         df_list = pool.map_async(get_data_from_file, filename_list).get()
49 |         df_total = pd.concat(df_list, axis=0)
50 | 
51 |     print("total data: ", df_total.shape)
52 |     if args.ratio < 1.0:
53 |         df_total = df_total.sample(frac=args.ratio, random_state=42)
54 |         print("random sampled data: ", df_total.shape)
55 | 
56 |     ########### split the data into train, val, test (80/10/10) ###########
57 |     train_data, val_test_data = train_test_split(df_total, test_size=0.2, random_state=42)
58 |     val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=42)
59 | 
60 |     print("train data: ", train_data.shape)
61 |     print("val data: ", val_data.shape)
62 |     print("test data: ", test_data.shape)
63 |     print(f"will output the data to {args.output_path}")
64 |     train_data.to_csv(f"{args.output_path}/train.csv", index=False)
65 |     val_data.to_csv(f"{args.output_path}/val.csv", index=False)
66 |     test_data.to_csv(f"{args.output_path}/test.csv", index=False)
67 | 
68 | 
69 | if __name__ == "__main__":
70 |     import argparse
71 |     parser = argparse.ArgumentParser(description="Process ALEX data")
72 |     parser.add_argument("--input_path", type=str, default="/data/zdcao/crystal_gpt/dataset/alex/origin/", help="path to the input data")
73 |     parser.add_argument("--output_path", type=str, default="/data/zdcao/crystal_gpt/dataset/alex/alex20_811/", help="path to the output data")
74 |     parser.add_argument("--ratio", type=float, default=1.0, help="ratio of the data to be used")
75 |     parser.add_argument('--num_io_process', type=int, default=20, help='number of processes used in multiprocessing io')
76 | 
77 |     args = parser.parse_args()
78 |     main(args)
79 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name = 'crystalformer',
5 |     version = '0.4.2',
6 |     keywords = 'Crystal Generation',
7 |     description = 'CrystalFormer is a transformer-based autoregressive model specifically designed for space group-controlled generation of crystalline materials.',
8 |     license = 'Apache License',
9 |     url = 'https://github.com/deepmodeling/CrystalFormer',
10 |     author = 'iopcompphys',
11 |     author_email = 'zdcao@iphy.ac.cn, wanglei@iphy.ac.cn',
12 |     packages = find_packages(),
13 |     include_package_data = True,
14 |     package_data = {
15 |         'crystalformer': ['data/*.csv'],
16 |     },
17 |     platforms = 'linux',
18 |     install_requires = [],
19 |     entry_points = {
20 |         'console_scripts': [
21 |             "train_ppo=crystalformer.cli.train_ppo:main",
22 |             "train_dpo=crystalformer.cli.train_dpo:main",
23 |             "classifier=crystalformer.cli.classifier:main",
24 |             "cond_gen=crystalformer.cli.cond_gen:main",
25 |             "dataset=crystalformer.cli.dataset:main",
26 |             "spg_sample=crystalformer.cli.spg_sample:main",
27 |         ]
28 |     }
29 | )
30 | 
--------------------------------------------------------------------------------
/tests/config.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | testdir = os.path.dirname(os.path.abspath(__file__))
3 | rootdir = os.path.dirname(testdir)
4 | datadir = os.path.join(rootdir, "crystalformer/data")
5 | # sys.path.append(os.path.join(testdir, "../crystalformer/src"))
6 | 
7 | import jax
8 | import jax.numpy as jnp
9 | import numpy as np
10 | import haiku as hk
11 | 
--------------------------------------------------------------------------------
/tests/test_fc_mask.py:
--------------------------------------------------------------------------------
1 | from config import *
2 | import pandas as pd
3 | import os
4 | import numpy as np
5 | import jax.numpy as jnp
6 | 
7 | df = pd.read_csv(os.path.join(datadir, 'wyckoff_list.csv'))
8 | df['Wyckoff Positions'] = df['Wyckoff Positions'].apply(eval)  # convert string to list
9 | wyckoff_positions = df['Wyckoff Positions'].tolist()
10 | 
11 | def convert_to_binary_list(s):
12 |     """
13 |     Convert a coordinate string into a binary list, based on the presence of 'x', 'y', or 'z' in each component.
14 |     """
15 |     components = s.split(',')
16 |     # TODO: a better translation could map xxx -> 100 rather than 111
17 |     return [1 if any(char in comp for char in ['x', 'y', 'z']) else 0 for comp in components]
18 | 
19 | fc_mask_list = []
20 | for g, wp_list in enumerate(wyckoff_positions):
21 |     sub_list = []
22 |     for wp in wp_list[::-1]:
23 |         sub_list.append(convert_to_binary_list(wp[0]))
24 |     fc_mask_list.append(sub_list)
25 | 
26 | max_len = max(len(sub_list) for sub_list in fc_mask_list)
27 | 
28 | fc_mask_table = np.zeros((len(fc_mask_list), max_len+1, 3), dtype=int)  # (230, 28, 3)
29 | for i, sub_list in enumerate(fc_mask_list):
30 |     for j, l in enumerate(sub_list):
31 |         fc_mask_table[i, j+1, :] = l  # we have added a padding of W=0
32 | fc_mask_table = jnp.array(fc_mask_table)  # a 1 in fc_mask_table selects the active fractional coordinates
33 | 
34 | 
35 | 
36 | def test_fc_mask():
37 |     from crystalformer.src.wyckoff import symops, wmax_table
38 |     from crystalformer.src.wyckoff import fc_mask_table as fc_mask_table_test
39 | 
40 |     for g in range(1, 231):
41 |         for w in range(1, wmax_table[g]+1):
42 |             op = symops[g-1, w, 0]  # 0 since we only consider the first Wyckoff point in the equivalence class when building fc_mask_table
43 |             fc_mask = (op[:3, :3].sum(axis=1) != 0)
44 |             assert jnp.allclose(fc_mask, fc_mask_table[g-1, w])
45 |             assert jnp.allclose(fc_mask, fc_mask_table_test[g-1, w])
46 | 
47 | test_fc_mask()
48 | 
--------------------------------------------------------------------------------
/tests/test_lattice.py:
--------------------------------------------------------------------------------
1 | from config import *
2 | 
3 | from crystalformer.src.lattice import symmetrize_lattice, make_lattice_mask
4 | 
5 | def test_symmetrize_lattice():
6 |     key = jax.random.PRNGKey(42)
7 | 
8 |     G = jnp.arange(230) + 1
9 |     L = jax.random.uniform(key, (6,))
10 |     L = L.reshape([1, 6]).repeat(230, axis=0)
11 | 
12 |     lattice = jax.jit(jax.vmap(symmetrize_lattice))(G, L)
13 |     print (lattice)
14 | 
15 |     a, b, c, alpha, beta, gamma = lattice[99-1]
16 |     assert (alpha==beta==gamma==90)
17 |     assert (a==b)
18 | 
19 | def test_make_mask():
20 | 
21 |     def make_spacegroup_mask(spacegroup):
22 |         '''
23 |         return mask for independent lattice params
24 |         '''
25 | 
26 |         mask = jnp.array([1, 1, 1, 1, 1, 1])
27 | 
28 |         mask = jnp.where(spacegroup <= 2, mask, jnp.array([1, 1, 1, 0, 1, 0]))
29 |         mask = jnp.where(spacegroup <= 15, mask, jnp.array([1, 1, 1, 0, 0, 0]))
30 |         mask = jnp.where(spacegroup <= 74, mask, jnp.array([1, 0, 1, 0, 0, 0]))
31 |         mask = jnp.where(spacegroup <= 142, mask, jnp.array([1, 0, 1, 0, 0, 0]))
32 |         mask = jnp.where(spacegroup <= 194, mask, jnp.array([1, 0, 0, 0, 0, 0]))
33 |         return mask
34 | 
35 |     mask = make_lattice_mask()
36 | 
37 |     for g in range(1, 231):
38 |         assert jnp.allclose(mask[g-1], make_spacegroup_mask(g))
39 | 
40 | test_symmetrize_lattice()
41 | test_make_mask()
42 | 
43 | 
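A minimal usage sketch of the lattice mask tested above (not a file in the repo; the space group and parameter values are illustrative), assuming make_lattice_mask() returns a (230, 6) array of 0/1 flags as test_make_mask implies:

import jax.numpy as jnp
from crystalformer.src.lattice import make_lattice_mask

g = 99  # a tetragonal space group, where only a and c are independent
lattice_params = jnp.array([4.0, 4.0, 6.7, 90.0, 90.0, 90.0])  # a, b, c, alpha, beta, gamma
mask = make_lattice_mask()[g - 1]    # expected [1, 0, 1, 0, 0, 0] per make_spacegroup_mask above
independent = lattice_params * mask  # zero out the symmetry-determined entries
print(independent)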
--------------------------------------------------------------------------------
/tests/test_sampling.py:
--------------------------------------------------------------------------------
1 | from config import *
2 | from crystalformer.src.wyckoff import symops
3 | 
4 | def test_symops():
5 |     from crystalformer.src.wyckoff import wmax_table, mult_table
6 |     def project_x(g, w, x, idx):
7 |         '''
8 |         One wants to project a randomly sampled fc to the nearest Wyckoff point.
9 |         Alternatively, we randomly select a Wyckoff point and then project fc to that point.
10 |         To achieve that, we do the following 3 steps
11 |         '''
12 |         w_max = wmax_table[g-1].item()
13 |         m_max = mult_table[g-1, w_max].item()
14 | 
15 |         # (1) apply all space group symmetry ops to the fc to get x
16 |         ops = symops[g-1, w_max, :m_max]  # (m_max, 3, 4)
17 |         affine_point = jnp.array([*x, 1])  # (4, )
18 |         coords = ops@affine_point  # (m_max, 3)
19 |         coords -= jnp.floor(coords)
20 | 
21 |         # (2) search for the generator which satisfies op0(x) = x, i.e. the first Wyckoff position
22 |         # here we solve it in a jit-friendly way by looking for the minimal distance solution for the lhs and rhs
23 |         # https://github.com/qzhu2017/PyXtal/blob/82e7d0eac1965c2713179eeda26a60cace06afc8/pyxtal/wyckoff_site.py#L115
24 |         def dist_to_op0x(coord):
25 |             diff = jnp.dot(symops[g-1, w, 0], jnp.array([*coord, 1])) - coord
26 |             diff -= jnp.floor(diff)
27 |             return jnp.sum(diff**2)
28 |         loc = jnp.argmin(jax.vmap(dist_to_op0x)(coords))
29 |         x = coords[loc].reshape(3,)
30 | 
31 |         # (3) lastly, apply the given randomly sampled Wyckoff symmetry op to x
32 |         op = symops[g-1, w, idx].reshape(3, 4)
33 |         affine_point = jnp.array([*x, 1])  # (4, )
34 |         x = jnp.dot(op, affine_point)  # (3, )
35 |         x -= jnp.floor(x)
36 |         return x
37 | 
38 |     # these two tests show that, depending on the z coordinate (which is supposed to be rational),
39 |     # the WP can be recognized differently, resulting in a different x
40 |     # this motivates either predicting idx in [1, m], or predicting all fc once there is a continuous dof
41 |     g = 167
42 |     w = jnp.array(5)
43 |     idx = jnp.array(5)
44 |     x = jnp.array([0.123, 0.123, 0.75])
45 |     y = project_x(g, w, x, idx)
46 |     assert jnp.allclose(y, jnp.array([0.123, 0.123, 0.75]))
47 | 
48 |     x = jnp.array([0.123, 0.123, 0.25])
49 |     y = project_x(g, w, x, idx)
50 |     assert jnp.allclose(y, jnp.array([0.877, 0.877, 0.75]))
51 | 
52 |     g = 225
53 |     w = jnp.array(5)
54 |     x = jnp.array([0., 0., 0.7334])
55 | 
56 |     idx = jnp.array(0)
57 |     y = project_x(g, w, x, idx)
58 |     assert jnp.allclose(y, jnp.array([0.7334, 0., 0.]))
59 | 
60 |     idx = jnp.array(3)
61 |     y = project_x(g, w, x, idx)
62 |     assert jnp.allclose(y, jnp.array([0., 1.0-0.7334, 0.]))
63 | 
64 |     g = 166
65 |     w = jnp.array(8)
66 |     x = jnp.array([0.1, 0.2, 0.3])
67 | 
68 |     idx = jnp.array(5)
69 |     y = project_x(g, w, x, idx)
70 |     assert jnp.allclose(y, jnp.array([1-0.1, 1-0.2, 1-0.3]))
71 | 
72 | def test_sample_top_p():
73 |     from crystalformer.src.sample import sample_top_p
74 |     key = jax.random.PRNGKey(42)
75 |     logits = jnp.array([[1.0, 1.0, 2.0, 2.0, 3.0],
76 |                         [-1.0, 1.0, 4.0, 1.0, 0.0]
77 |                         ]
78 |                        )
79 |     p = 0.8
80 |     temperature = 1.0
81 |     k = jax.jit(sample_top_p, static_argnums=2)(key, logits, p, temperature)
82 |     print (k)
83 | 
84 | test_sample_top_p()
85 | test_symops()
86 | 
--------------------------------------------------------------------------------
/tests/test_transformer.py:
--------------------------------------------------------------------------------
1 | from config import *
2 | 
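# The jacobian check below probes the autoregressive structure of the
# transformer: an entry of a printed dependency matrix is nonzero only where
# an output slot is allowed to depend on that input slot, so the patterns
# should be (block) lower triangular for a causal model.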
3 | from crystalformer.src.utils import GLXYZAW_from_file
4 | from crystalformer.src.wyckoff import mult_table
5 | from crystalformer.src.transformer import make_transformer
6 | 
7 | def test_autoregressive():
8 |     atom_types = 119
9 |     wyck_types = 28
10 |     Nf = 8
11 |     n_max = 21
12 |     Kx = 16
13 |     Kl = 8
14 |     dim = 3
15 |     dropout_rate = 0.0
16 | 
17 |     csv_file = os.path.join(datadir, '../../data/mini.csv')
18 |     G, L, X, A, W = GLXYZAW_from_file(csv_file, atom_types, wyck_types, n_max, dim)
19 | 
20 |     @jax.vmap
21 |     def lookup(G, W):
22 |         return mult_table[G-1, W]  # (n_max, )
23 |     M = lookup(G, W)  # (batchsize, n_max)
24 |     num_sites = jnp.sum(A!=0, axis=1)
25 | 
26 |     key = jax.random.PRNGKey(42)
27 |     params, transformer = make_transformer(key, Nf, Kx, Kl, n_max, dim, 128, 4, 4, 8, 16, atom_types, wyck_types, dropout_rate)
28 | 
29 |     def test_fn(X, M):
30 |         output = transformer(params, None, G[0], X, A[0], W[0], M, False)
31 |         print (output.shape)
32 |         return output.sum(axis=-1)
33 | 
34 |     jac_x = jax.jacfwd(test_fn, argnums=0)(X[0], M[0])
35 |     jac_m = jax.jacfwd(test_fn, argnums=1)(X[0], M[0].astype(jnp.float32))[:, :, None]
36 | 
37 |     print(jac_x.shape, jac_m.shape)
38 | 
39 |     def print_dependency(jac):
40 |         dependency = jnp.linalg.norm(jac, axis=-1)
41 |         for row in (dependency != 0.).astype(int):
42 |             print(" ".join(str(val) for val in row))
43 | 
44 |     print ("jac_a_x")
45 |     print_dependency(jac_x[::2])
46 |     print ("jac_x_x")
47 |     print_dependency(jac_x[1::2])
48 |     print ("jac_a_a")
49 |     print_dependency(jac_m[::2])
50 |     print ("jac_x_a")
51 |     print_dependency(jac_m[1::2])
52 | 
53 | 
54 | def test_perm():
55 | 
56 |     key = jax.random.PRNGKey(42)
57 | 
58 |     #W = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0])
59 |     W = jnp.array([1, 2, 2, 2, 5, 0, 0, 0])
60 |     n = len(W)
61 | 
62 | 
63 |     temp = jnp.where(W>0, W, 9999)
64 |     idx_perm = jax.random.permutation(key, jnp.arange(n))
65 |     temp = temp[idx_perm]
66 |     idx_sort = jnp.argsort(temp)
67 |     idx = idx_perm[idx_sort]
68 | 
69 |     print (idx)
70 |     print (W)
71 |     assert jnp.allclose(W, W[idx])
72 | 
73 | test_autoregressive()
74 | 
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from config import *
2 | 
3 | from crystalformer.src.utils import GLXYZAW_from_file
4 | from crystalformer.src.wyckoff import mult_table
5 | 
6 | def calc_n(G, W):
7 |     @jax.vmap
8 |     def lookup(G, W):
9 |         return mult_table[G-1, W]  # (n_max, )
10 |     M = lookup(G, W)  # (batchsize, n_max)
11 |     N = M.sum(axis=-1)
12 |     return N
13 | 
14 | def test_utils():
15 | 
16 |     atom_types = 119
17 |     mult_types = 10
18 |     n_max = 10
19 |     dim = 3
20 |     csv_file = os.path.join(datadir, '../../data/mini.csv')
21 | 
22 |     G, L, X, A, W = GLXYZAW_from_file(csv_file, atom_types, mult_types, n_max, dim)
23 | 
24 |     assert G.ndim == 1
25 |     assert L.ndim == 2
26 |     assert L.shape[-1] == 6
27 | 
28 |     import numpy as np
29 |     np.set_printoptions(threshold=np.inf)
30 | 
31 |     print ("A:\n", A)
32 |     N = calc_n(G, W)
33 | 
34 |     assert jnp.all(N==5)
35 | 
36 | if __name__ == '__main__':
37 | 
38 |     test_utils()
39 | 
--------------------------------------------------------------------------------
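A short sketch of the multiplicity bookkeeping that test_transformer.py and test_utils.py both rely on (the batch values here are made up; Wyckoff index 0 is padding, which the tests treat as contributing zero atoms):

import jax
import jax.numpy as jnp
from crystalformer.src.wyckoff import mult_table

# mult_table[g-1, w] is the multiplicity of Wyckoff index w in space group g
G = jnp.array([225, 166])              # space group numbers, shape (batchsize,)
W = jnp.array([[1, 2, 0], [1, 0, 0]])  # padded Wyckoff indices, shape (batchsize, n_max)

M = jax.vmap(lambda g, w: mult_table[g - 1, w])(G, W)  # (batchsize, n_max)
num_atoms = M.sum(axis=-1)             # padding rows add 0, as in calc_n above
print(num_atoms)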