├── LICENSE ├── README.md ├── dataset.py ├── densitymodel.py ├── evaluate_model.py ├── layer.py ├── predict_with_model.py ├── pretrained_models ├── ethylenecarbonate_painn │ ├── arguments.json │ ├── best_model.pth │ ├── commandline_args.txt │ ├── gitdetails.txt │ ├── printlog.txt │ ├── slurmlog.log │ └── submit_script.sh ├── ethylenecarbonate_schnet │ ├── arguments.json │ ├── best_model.pth │ ├── commandline_args.txt │ ├── gitdetails.txt │ ├── printlog.txt │ ├── slurmlog.log │ └── submit_script.sh ├── nmc_painn │ ├── arguments.json │ ├── best_model.pth │ ├── commandline_args.txt │ ├── gitdetails.txt │ ├── printlog.txt │ ├── slurmlog.log │ └── submit_script.sh ├── nmc_schnet │ ├── arguments.json │ ├── best_model.pth │ ├── commandline_args.txt │ ├── gitdetails.txt │ ├── printlog.txt │ ├── slurmlog.log │ └── submit_script.sh ├── qm9_painn │ ├── arguments.json │ ├── best_model.pth │ ├── commandline_args.txt │ ├── gitdetails.txt │ ├── printlog.txt │ ├── slurmlog.log │ └── submit_script.sh └── qm9_schnet │ ├── arguments.json │ ├── best_model.pth │ ├── commandline_args.txt │ ├── datasplits.json │ └── printlog.txt ├── requirements.txt ├── runner.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Peter Bjørn Jørgensen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepDFT Model Implementation 2 | 3 | This is the official Implementation of the DeepDFT model for charge density prediction. 4 | 5 | ## Setup 6 | 7 | Create and activate a virtual environment and install the requirements: 8 | 9 | $ pip install -r requirements.txt 10 | 11 | ## Data 12 | 13 | Training data is expected to be a tar file containing `.cube` (Gaussian) or `.CHGCAR` (VASP) density files. 14 | For best performance the tar files should not be compressed, but the individual files inside the tar 15 | can use `zlib` compression (add `.zz` extension) or lz4 compression (add `.lz4` extension). 16 | The data can be split up in several tar files. In that case create a text (.txt) file 17 | in the same directory as the tar files. The text file must contain the file names of the tar files, one on each line. 18 | Then the text file can then be used as a dataset. 
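As a concrete (hypothetical) example, if the density tar files `data1.tar` and `data2.tar` are stored in `datadir/`, the dataset text file and an optional train/validation split file (used by the `--split_file` argument in the training section below) could be prepared as follows; the file names and index values are placeholders:

    import json

    # One tar file name per line; pass datadir/dataset.txt as --dataset
    with open("datadir/dataset.txt", "w") as f:
        f.write("data1.tar\ndata2.tar\n")

    # Optional fixed split: lists of indices into the dataset
    with open("datadir/splits.json", "w") as f:
        json.dump({"train": [0, 1, 2, 3], "validation": [4, 5]}, f)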
19 | 20 | ## Training the model 21 | 22 | Inspect `runner.py` arguments: 23 | 24 | $ python runner.py --help 25 | 26 | Example used for training the model on QM9: 27 | 28 | $ python runner.py --dataset datadir/qm9vasp.txt --split_file datadir/splits.json --ignore_pbc --cutoff 4 --num_interactions 6 --max_steps 100000000 --node_size 128 29 | 30 | Or to train the equivariant model on the ethylene carbonate dataset: 31 | 32 | $ python runner.py --dataset datadir/ethylenecarbonate.txt --split_file datadir/splits.json --cutoff 4 --num_interactions 3 --use_painn_model --max_steps 100000000 --node_size 128 33 | 34 | The json file contains two keys "train", and "validation" each with a list of indices for the train and validation sets. If the argument is omitted the data will be randomly split. 35 | 36 | ## Running the model on new data 37 | 38 | To use a trained model to predict the electron density around a new structure use the script `predict_with_model.py`. 39 | The first argument is the output directory of the runner script in which the trained model is saved. 40 | The second argument is an ASE compatible xyz file with atom coordinates for the structure to be predicted. 41 | 42 | For example: 43 | 44 | $ python predict_with_model.py pretrained_models/qm9_schnet example_molecule.xyz 45 | 46 | For more options see the `predict_with_model.py` optional arguments: 47 | 48 | $ python predict_with_model.py --help 49 | 50 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import gzip 3 | import tarfile 4 | import tempfile 5 | import multiprocessing 6 | import queue 7 | import time 8 | import threading 9 | import logging 10 | import zlib 11 | import os 12 | import io 13 | import math 14 | import torch 15 | import torch.utils.data 16 | import lz4.frame 17 | import numpy as np 18 | import ase 19 | import ase.neighborlist 20 | import ase.io.cube 21 | import ase.units 22 | from ase.calculators.vasp import VaspChargeDensity 23 | import asap3 24 | 25 | from layer import pad_and_stack 26 | 27 | def _cell_heights(cell_object): 28 | volume = cell_object.volume 29 | crossproducts = np.cross(cell_object[[1, 2, 0]], cell_object[[2, 0, 1]]) 30 | crosslengths = np.sqrt(np.sum(np.square(crossproducts), axis=1)) 31 | heights = volume / crosslengths 32 | return heights 33 | 34 | 35 | def rotating_pool_worker(dataset, rng, queue): 36 | while True: 37 | for index in rng.permutation(len(dataset)).tolist(): 38 | queue.put(dataset[index]) 39 | 40 | 41 | def transfer_thread(queue: multiprocessing.Queue, datalist: list): 42 | while True: 43 | for index in range(len(datalist)): 44 | datalist[index] = queue.get() 45 | 46 | 47 | class RotatingPoolData(torch.utils.data.Dataset): 48 | """ 49 | Wrapper for a dataset that continously loads data into a smaller pool. 50 | The data loading is performed in a separate process and is assumed to be IO bound. 
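The pool is refreshed in the background: __getitem__ simply returns whatever sample currently occupies the requested slot, so the same index may return different samples over time.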
51 | """ 52 | 53 | def __init__(self, dataset, pool_size, **kwargs): 54 | super().__init__(**kwargs) 55 | self.pool_size = pool_size 56 | self.parent_data = dataset 57 | self.rng = np.random.default_rng() 58 | logging.debug("Filling rotating data pool of size %d" % pool_size) 59 | self.data_pool = [ 60 | self.parent_data[i] 61 | for i in self.rng.integers( 62 | 0, high=len(self.parent_data), size=self.pool_size, endpoint=False 63 | ).tolist() 64 | ] 65 | self.loader_queue = multiprocessing.Queue(2) 66 | 67 | # Start loaders 68 | self.loader_process = multiprocessing.Process( 69 | target=rotating_pool_worker, 70 | args=(self.parent_data, self.rng, self.loader_queue), 71 | ) 72 | self.transfer_thread = threading.Thread( 73 | target=transfer_thread, args=(self.loader_queue, self.data_pool) 74 | ) 75 | self.loader_process.start() 76 | self.transfer_thread.start() 77 | 78 | def __len__(self): 79 | return self.pool_size 80 | 81 | def __getitem__(self, index): 82 | return self.data_pool[index] 83 | 84 | 85 | class BufferData(torch.utils.data.Dataset): 86 | """ 87 | Wrapper for a dataset. Loads all data into memory. 88 | """ 89 | 90 | def __init__(self, dataset, **kwargs): 91 | super().__init__(**kwargs) 92 | 93 | self.data_objects = [dataset[i] for i in range(len(dataset))] 94 | 95 | def __len__(self): 96 | return len(self.data_objects) 97 | 98 | def __getitem__(self, index): 99 | return self.data_objects[index] 100 | 101 | class DensityData(torch.utils.data.Dataset): 102 | def __init__(self, datapath, **kwargs): 103 | super().__init__(**kwargs) 104 | if os.path.isfile(datapath) and datapath.endswith(".tar"): 105 | self.data = DensityDataTar(datapath) 106 | elif os.path.isdir(datapath): 107 | self.data = DensityDataDir(datapath) 108 | else: 109 | raise ValueError("Did not find dataset at path %s", datapath) 110 | def __len__(self): 111 | return len(self.data) 112 | 113 | def __getitem__(self, index): 114 | return self.data[index] 115 | 116 | class DensityDataDir(torch.utils.data.Dataset): 117 | def __init__(self, directory, **kwargs): 118 | super().__init__(**kwargs) 119 | 120 | self.directory = directory 121 | self.member_list = sorted(os.listdir(self.directory)) 122 | self.key_to_idx = {str(k): i for i,k in enumerate(self.member_list)} 123 | 124 | def __len__(self): 125 | return len(self.member_list) 126 | 127 | def extractfile(self, filename): 128 | path = os.path.join(self.directory, filename) 129 | 130 | filecontent = _decompress_file(path) 131 | if path.endswith((".cube", ".cube.gz", ".cube.zz", "cube.lz4")): 132 | density, atoms, origin = _read_cube(filecontent) 133 | else: 134 | density, atoms, origin = _read_vasp(filecontent) 135 | 136 | grid_pos = _calculate_grid_pos(density, origin, atoms.get_cell()) 137 | 138 | metadata = {"filename": filename} 139 | return { 140 | "density": density, 141 | "atoms": atoms, 142 | "origin": origin, 143 | "grid_position": grid_pos, 144 | "metadata": metadata, # Meta information 145 | } 146 | 147 | def __getitem__(self, index): 148 | if isinstance(index, str): 149 | index = self.key_to_idx[index] 150 | return self.extractfile(self.member_list[index]) 151 | 152 | 153 | class DensityDataTar(torch.utils.data.Dataset): 154 | def __init__(self, tarpath, **kwargs): 155 | super().__init__(**kwargs) 156 | 157 | self.tarpath = tarpath 158 | self.member_list = [] 159 | 160 | # Index tar file 161 | with tarfile.open(self.tarpath, "r:") as tar: 162 | for member in tar.getmembers(): 163 | self.member_list.append(member) 164 | self.key_to_idx = {str(k): i for i,k in 
enumerate(self.member_list)} 165 | 166 | def __len__(self): 167 | return len(self.member_list) 168 | 169 | def extract_member(self, tarinfo): 170 | with tarfile.open(self.tarpath, "r") as tar: 171 | filecontent = _decompress_tarmember(tar, tarinfo) 172 | if tarinfo.name.endswith((".cube", ".cube.gz", "cube.zz", "cube.lz4")): 173 | density, atoms, origin = _read_cube(filecontent) 174 | else: 175 | density, atoms, origin = _read_vasp(filecontent) 176 | 177 | grid_pos = _calculate_grid_pos(density, origin, atoms.get_cell()) 178 | 179 | metadata = {"filename": tarinfo.name} 180 | return { 181 | "density": density, 182 | "atoms": atoms, 183 | "origin": origin, 184 | "grid_position": grid_pos, 185 | "metadata": metadata, # Meta information 186 | } 187 | 188 | def __getitem__(self, index): 189 | if isinstance(index, str): 190 | index = self.key_to_idx[index] 191 | return self.extract_member(self.member_list[index]) 192 | 193 | 194 | class AseNeigborListWrapper: 195 | """ 196 | Wrapper around ASE neighborlist to have the same interface as asap3 neighborlist 197 | 198 | """ 199 | 200 | def __init__(self, cutoff, atoms): 201 | self.neighborlist = ase.neighborlist.NewPrimitiveNeighborList( 202 | cutoff, skin=0.0, self_interaction=False, bothways=True 203 | ) 204 | self.neighborlist.build( 205 | atoms.get_pbc(), atoms.get_cell(), atoms.get_positions() 206 | ) 207 | self.cutoff = cutoff 208 | self.atoms_positions = atoms.get_positions() 209 | self.atoms_cell = atoms.get_cell() 210 | 211 | def get_neighbors(self, i, cutoff): 212 | assert ( 213 | cutoff == self.cutoff 214 | ), "Cutoff must be the same as used to initialise the neighborlist" 215 | 216 | indices, offsets = self.neighborlist.get_neighbors(i) 217 | 218 | rel_positions = ( 219 | self.atoms_positions[indices] 220 | + offsets @ self.atoms_cell 221 | - self.atoms_positions[i][None] 222 | ) 223 | 224 | dist2 = np.sum(np.square(rel_positions), axis=1) 225 | 226 | return indices, rel_positions, dist2 227 | 228 | 229 | def grid_iterator_worker(atoms, meshgrid, probe_count, cutoff, slice_id_queue, result_queue): 230 | try: 231 | neighborlist = asap3.FullNeighborList(cutoff, atoms) 232 | except Exception as e: 233 | logging.info("Failed to create asap3 neighborlist, this might be very slow. 
Error: %s", e) 234 | neighborlist = None 235 | while True: 236 | try: 237 | slice_id = slice_id_queue.get(True, 1) 238 | except queue.Empty: 239 | while not result_queue.empty(): 240 | time.sleep(1) 241 | result_queue.close() 242 | return 0 243 | res = DensityGridIterator.static_get_slice(slice_id, atoms, meshgrid, probe_count, cutoff, neighborlist=neighborlist) 244 | result_queue.put((slice_id, res)) 245 | 246 | class DensityGridIterator: 247 | def __init__(self, densitydict, probe_count: int, cutoff: float, set_pbc_to: Optional[bool] = None): 248 | num_positions = np.prod(densitydict["grid_position"].shape[0:3]) 249 | self.num_slices = int(math.ceil(num_positions / probe_count)) 250 | self.probe_count = probe_count 251 | self.cutoff = cutoff 252 | self.set_pbc = set_pbc_to 253 | 254 | if self.set_pbc is not None: 255 | self.atoms = densitydict["atoms"].copy() 256 | self.atoms.set_pbc(self.set_pbc) 257 | else: 258 | self.atoms = densitydict["atoms"] 259 | 260 | self.meshgrid = densitydict["grid_position"] 261 | 262 | def get_slice(self, slice_index): 263 | return self.static_get_slice(slice_index, self.atoms, self.meshgrid, self.probe_count, self.cutoff) 264 | 265 | @staticmethod 266 | def static_get_slice(slice_index, atoms, meshgrid, probe_count, cutoff, neighborlist=None): 267 | num_positions = np.prod(meshgrid.shape[0:3]) 268 | flat_index = np.arange(slice_index*probe_count, min((slice_index+1)*probe_count, num_positions)) 269 | pos_index = np.unravel_index(flat_index, meshgrid.shape[0:3]) 270 | probe_pos = meshgrid[pos_index] 271 | probe_edges, probe_edges_displacement = probes_to_graph(atoms, probe_pos, cutoff, neighborlist) 272 | 273 | if not probe_edges: 274 | probe_edges = [np.zeros((0,2), dtype=np.int)] 275 | probe_edges_displacement = [np.zeros((0,3), dtype=np.float32)] 276 | 277 | res = { 278 | "probe_edges": np.concatenate(probe_edges, axis=0), 279 | "probe_edges_displacement": np.concatenate(probe_edges_displacement, axis=0).astype(np.float32), 280 | } 281 | res["num_probe_edges"] = res["probe_edges"].shape[0] 282 | res["num_probes"] = len(flat_index) 283 | res["probe_xyz"] = probe_pos.astype(np.float32) 284 | 285 | return res 286 | 287 | 288 | def __iter__(self): 289 | self.current_slice = 0 290 | slice_id_queue = multiprocessing.Queue() 291 | self.result_queue = multiprocessing.Queue(100) 292 | self.finished_slices = dict() 293 | for i in range(self.num_slices): 294 | slice_id_queue.put(i) 295 | self.workers = [multiprocessing.Process(target=grid_iterator_worker, args=(self.atoms, self.meshgrid, self.probe_count, self.cutoff, slice_id_queue, self.result_queue)) for _ in range(6)] 296 | for w in self.workers: 297 | w.start() 298 | return self 299 | 300 | def __next__(self): 301 | if self.current_slice < self.num_slices: 302 | this_slice = self.current_slice 303 | self.current_slice += 1 304 | 305 | # Retrieve finished slices until we get the one we are looking for 306 | while this_slice not in self.finished_slices: 307 | i, res = self.result_queue.get() 308 | res = {k: torch.tensor(v) for k,v in res.items()} # convert to torch tensor 309 | self.finished_slices[i] = res 310 | return self.finished_slices.pop(this_slice) 311 | else: 312 | for w in self.workers: 313 | w.join() 314 | raise StopIteration 315 | 316 | 317 | def atoms_and_probe_sample_to_graph_dict(density, atoms, grid_pos, cutoff, num_probes): 318 | # Sample probes on the calculated grid 319 | probe_choice_max = np.prod(grid_pos.shape[0:3]) 320 | probe_choice = np.random.randint(probe_choice_max, size=num_probes) 
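# Convert the flat indices (sampled with replacement) into 3D grid indices, so that
# probe positions and the corresponding target density values can be gathered below.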
321 | probe_choice = np.unravel_index(probe_choice, grid_pos.shape[0:3]) 322 | probe_pos = grid_pos[probe_choice] 323 | probe_target = density[probe_choice] 324 | 325 | atom_edges, atom_edges_displacement, neighborlist, inv_cell_T = atoms_to_graph(atoms, cutoff) 326 | probe_edges, probe_edges_displacement = probes_to_graph(atoms, probe_pos, cutoff, neighborlist=neighborlist, inv_cell_T=inv_cell_T) 327 | 328 | default_type = torch.get_default_dtype() 329 | 330 | if not probe_edges: 331 | probe_edges = [np.zeros((0,2), dtype=np.int)] 332 | probe_edges_displacement = [np.zeros((0,3), dtype=np.int)] 333 | # pylint: disable=E1102 334 | res = { 335 | "nodes": torch.tensor(atoms.get_atomic_numbers()), 336 | "atom_edges": torch.tensor(np.concatenate(atom_edges, axis=0)), 337 | "atom_edges_displacement": torch.tensor( 338 | np.concatenate(atom_edges_displacement, axis=0), dtype=default_type 339 | ), 340 | "probe_edges": torch.tensor(np.concatenate(probe_edges, axis=0)), 341 | "probe_edges_displacement": torch.tensor( 342 | np.concatenate(probe_edges_displacement, axis=0), dtype=default_type 343 | ), 344 | "probe_target": torch.tensor(probe_target, dtype=default_type), 345 | } 346 | res["num_nodes"] = torch.tensor(res["nodes"].shape[0]) 347 | res["num_atom_edges"] = torch.tensor(res["atom_edges"].shape[0]) 348 | res["num_probe_edges"] = torch.tensor(res["probe_edges"].shape[0]) 349 | res["num_probes"] = torch.tensor(res["probe_target"].shape[0]) 350 | res["probe_xyz"] = torch.tensor(probe_pos, dtype=default_type) 351 | res["atom_xyz"] = torch.tensor(atoms.get_positions(), dtype=default_type) 352 | res["cell"] = torch.tensor(np.array(atoms.get_cell()), dtype=default_type) 353 | 354 | return res 355 | 356 | def atoms_to_graph_dict(atoms, cutoff): 357 | atom_edges, atom_edges_displacement, _, _ = atoms_to_graph(atoms, cutoff) 358 | 359 | default_type = torch.get_default_dtype() 360 | 361 | # pylint: disable=E1102 362 | res = { 363 | "nodes": torch.tensor(atoms.get_atomic_numbers()), 364 | "atom_edges": torch.tensor(np.concatenate(atom_edges, axis=0)), 365 | "atom_edges_displacement": torch.tensor( 366 | np.concatenate(atom_edges_displacement, axis=0), dtype=default_type 367 | ), 368 | } 369 | res["num_nodes"] = torch.tensor(res["nodes"].shape[0]) 370 | res["num_atom_edges"] = torch.tensor(res["atom_edges"].shape[0]) 371 | res["atom_xyz"] = torch.tensor(atoms.get_positions(), dtype=default_type) 372 | res["cell"] = torch.tensor(np.array(atoms.get_cell()), dtype=default_type) 373 | 374 | return res 375 | 376 | def atoms_to_graph(atoms, cutoff): 377 | atom_edges = [] 378 | atom_edges_displacement = [] 379 | 380 | inv_cell_T = np.linalg.inv(atoms.get_cell().complete().T) 381 | 382 | # Compute neighborlist 383 | if ( 384 | np.any(atoms.get_cell().lengths() <= 0.0001) 385 | or ( 386 | np.any(atoms.get_pbc()) 387 | and np.any(_cell_heights(atoms.get_cell()) < cutoff) 388 | ) 389 | ): 390 | neighborlist = AseNeigborListWrapper(cutoff, atoms) 391 | else: 392 | neighborlist = asap3.FullNeighborList(cutoff, atoms) 393 | 394 | atom_positions = atoms.get_positions() 395 | 396 | for i in range(len(atoms)): 397 | neigh_idx, neigh_vec, _ = neighborlist.get_neighbors(i, cutoff) 398 | 399 | self_index = np.ones_like(neigh_idx) * i 400 | edges = np.stack((neigh_idx, self_index), axis=1) 401 | 402 | neigh_pos = atom_positions[neigh_idx] 403 | this_pos = atom_positions[i] 404 | neigh_origin = neigh_vec + this_pos - neigh_pos 405 | neigh_origin_scaled = np.round(inv_cell_T.dot(neigh_origin.T).T) 406 | 407 | 
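# Each edge is a (neighbor, receiver) index pair; the matching displacement entry is the
# integer cell offset (in lattice-vector counts) of the neighbor's periodic image.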
atom_edges.append(edges) 408 | atom_edges_displacement.append(neigh_origin_scaled) 409 | 410 | return atom_edges, atom_edges_displacement, neighborlist, inv_cell_T 411 | 412 | def probes_to_graph(atoms, probe_pos, cutoff, neighborlist=None, inv_cell_T=None): 413 | probe_edges = [] 414 | probe_edges_displacement = [] 415 | if inv_cell_T is None: 416 | inv_cell_T = np.linalg.inv(atoms.get_cell().complete().T) 417 | 418 | if hasattr(neighborlist, "get_neighbors_querypoint"): 419 | results = neighborlist.get_neighbors_querypoint(probe_pos, cutoff) 420 | atomic_numbers = atoms.get_atomic_numbers() 421 | else: 422 | # Insert probe atoms 423 | num_probes = probe_pos.shape[0] 424 | probe_atoms = ase.Atoms(numbers=[0] * num_probes, positions=probe_pos) 425 | atoms_with_probes = atoms.copy() 426 | atoms_with_probes.extend(probe_atoms) 427 | atomic_numbers = atoms_with_probes.get_atomic_numbers() 428 | 429 | if ( 430 | np.any(atoms.get_cell().lengths() <= 0.0001) 431 | or ( 432 | np.any(atoms.get_pbc()) 433 | and np.any(_cell_heights(atoms.get_cell()) < cutoff) 434 | ) 435 | ): 436 | neighborlist = AseNeigborListWrapper(cutoff, atoms_with_probes) 437 | else: 438 | neighborlist = asap3.FullNeighborList(cutoff, atoms_with_probes) 439 | 440 | results = [neighborlist.get_neighbors(i+len(atoms), cutoff) for i in range(num_probes)] 441 | 442 | atom_positions = atoms.get_positions() 443 | for i, (neigh_idx, neigh_vec, _) in enumerate(results): 444 | neigh_atomic_species = atomic_numbers[neigh_idx] 445 | 446 | neigh_is_atom = neigh_atomic_species != 0 447 | neigh_atoms = neigh_idx[neigh_is_atom] 448 | self_index = np.ones_like(neigh_atoms) * i 449 | edges = np.stack((neigh_atoms, self_index), axis=1) 450 | 451 | neigh_pos = atom_positions[neigh_atoms] 452 | this_pos = probe_pos[i] 453 | neigh_origin = neigh_vec[neigh_is_atom] + this_pos - neigh_pos 454 | neigh_origin_scaled = np.round(inv_cell_T.dot(neigh_origin.T).T) 455 | 456 | probe_edges.append(edges) 457 | probe_edges_displacement.append(neigh_origin_scaled) 458 | 459 | return probe_edges, probe_edges_displacement 460 | 461 | def collate_list_of_dicts(list_of_dicts, pin_memory=False): 462 | # Convert from "list of dicts" to "dict of lists" 463 | dict_of_lists = {k: [dic[k] for dic in list_of_dicts] for k in list_of_dicts[0]} 464 | 465 | # Convert each list of tensors to single tensor with pad and stack 466 | if pin_memory: 467 | pin = lambda x: x.pin_memory() 468 | else: 469 | pin = lambda x: x 470 | 471 | collated = {k: pin(pad_and_stack(dict_of_lists[k])) for k in dict_of_lists} 472 | return collated 473 | 474 | class CollateFuncRandomSample: 475 | def __init__(self, cutoff, num_probes, pin_memory=True, set_pbc_to=None): 476 | self.num_probes = num_probes 477 | self.cutoff = cutoff 478 | self.pin_memory = pin_memory 479 | self.set_pbc = set_pbc_to 480 | 481 | def __call__(self, input_dicts: List): 482 | graphs = [] 483 | for i in input_dicts: 484 | if self.set_pbc is not None: 485 | atoms = i["atoms"].copy() 486 | atoms.set_pbc(self.set_pbc) 487 | else: 488 | atoms = i["atoms"] 489 | 490 | graphs.append(atoms_and_probe_sample_to_graph_dict( 491 | i["density"], 492 | atoms, 493 | i["grid_position"], 494 | self.cutoff, 495 | self.num_probes, 496 | )) 497 | 498 | return collate_list_of_dicts(graphs, pin_memory=self.pin_memory) 499 | 500 | class CollateFuncAtoms: 501 | def __init__(self, cutoff, pin_memory=True, set_pbc_to=None): 502 | self.cutoff = cutoff 503 | self.pin_memory = pin_memory 504 | self.set_pbc = set_pbc_to 505 | 506 | def __call__(self, 
input_dicts: List): 507 | graphs = [] 508 | for i in input_dicts: 509 | if self.set_pbc is not None: 510 | atoms = i["atoms"].copy() 511 | atoms.set_pbc(self.set_pbc) 512 | else: 513 | atoms = i["atoms"] 514 | 515 | graphs.append(atoms_to_graph_dict( 516 | atoms, 517 | self.cutoff, 518 | )) 519 | 520 | return collate_list_of_dicts(graphs, pin_memory=self.pin_memory) 521 | 522 | 523 | def _calculate_grid_pos(density, origin, cell): 524 | # Calculate grid positions 525 | ngridpts = np.array(density.shape) # grid matrix 526 | grid_pos = np.meshgrid( 527 | np.arange(ngridpts[0]) / density.shape[0], 528 | np.arange(ngridpts[1]) / density.shape[1], 529 | np.arange(ngridpts[2]) / density.shape[2], 530 | indexing="ij", 531 | ) 532 | grid_pos = np.stack(grid_pos, 3) 533 | grid_pos = np.dot(grid_pos, cell) 534 | grid_pos = grid_pos + origin 535 | return grid_pos 536 | 537 | 538 | def _decompress_tarmember(tar, tarinfo): 539 | """Extract compressed tar file member and return a bytes object with the content""" 540 | 541 | bytesobj = tar.extractfile(tarinfo).read() 542 | if tarinfo.name.endswith(".zz"): 543 | filecontent = zlib.decompress(bytesobj) 544 | elif tarinfo.name.endswith(".lz4"): 545 | filecontent = lz4.frame.decompress(bytesobj) 546 | elif tarinfo.name.endswith(".gz"): 547 | filecontent = gzip.decompress(bytesobj) 548 | else: 549 | filecontent = bytesobj 550 | 551 | return filecontent 552 | 553 | def _decompress_file(filepath): 554 | if filepath.endswith(".zz"): 555 | with open(filepath, "rb") as fp: 556 | f_bytes = fp.read() 557 | filecontent = zlib.decompress(f_bytes) 558 | elif filepath.endswith(".lz4"): 559 | with lz4.frame.open(filepath, mode="rb") as fp: 560 | filecontent = fp.read() 561 | elif filepath.endswith(".gz"): 562 | with gzip.open(filepath, mode="rb") as fp: 563 | filecontent = fp.read() 564 | else: 565 | with open(filepath, mode="rb") as fp: 566 | filecontent = fp.read() 567 | return filecontent 568 | 569 | def _read_vasp(filecontent): 570 | # Write to tmp file and read using ASE 571 | tmpfd, tmppath = tempfile.mkstemp(prefix="tmpdeepdft") 572 | tmpfile = os.fdopen(tmpfd, "wb") 573 | tmpfile.write(filecontent) 574 | tmpfile.close() 575 | vasp_charge = VaspChargeDensity(filename=tmppath) 576 | os.remove(tmppath) 577 | density = vasp_charge.chg[-1] # separate density 578 | atoms = vasp_charge.atoms[-1] # separate atom positions 579 | 580 | return density, atoms, np.zeros(3) # TODO: Can we always assume origin at 0,0,0? 
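# Illustrative sketch (not part of the original module): how a single sample can be loaded
# and collated with the classes above. The path "datadir/data.tar" and the cutoff/probe
# settings are placeholder values.
#
#   data = DensityData("datadir/data.tar")
#   sample = data[0]  # dict with "density", "atoms", "origin", "grid_position", "metadata"
#   # sample["density"].shape == (nx, ny, nz); sample["grid_position"].shape == (nx, ny, nz, 3)
#
#   collate_fn = CollateFuncRandomSample(cutoff=4.0, num_probes=500, pin_memory=False)
#   batch = collate_fn([sample])  # padded tensors ready for DensityModel / PainnDensityModel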
581 | 582 | 583 | def _read_cube(filecontent): 584 | textbuf = io.StringIO(filecontent.decode()) 585 | cube = ase.io.cube.read_cube(textbuf) 586 | # sometimes there is an entry at index 3 587 | # denoting the number of values for each grid position 588 | origin = cube["origin"][0:3] 589 | # by convention the cube electron density is given in electrons/Bohr^3, 590 | # and ase read_cube does not convert to electrons/Å^3, so we do the conversion here 591 | cube["data"] *= 1.0 / ase.units.Bohr ** 3 592 | return cube["data"], cube["atoms"], origin 593 | -------------------------------------------------------------------------------- /densitymodel.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import math 3 | import ase 4 | import torch 5 | from torch import nn 6 | import layer 7 | from layer import ShiftedSoftplus 8 | 9 | 10 | class DensityModel(nn.Module): 11 | def __init__( 12 | self, 13 | num_interactions, 14 | hidden_state_size, 15 | cutoff, 16 | gaussian_expansion_step=0.1, 17 | **kwargs, 18 | ): 19 | super().__init__(**kwargs) 20 | 21 | self.atom_model = AtomRepresentationModel( 22 | num_interactions, 23 | hidden_state_size, 24 | cutoff, 25 | gaussian_expansion_step, 26 | ) 27 | 28 | self.probe_model = ProbeMessageModel( 29 | num_interactions, 30 | hidden_state_size, 31 | cutoff, 32 | gaussian_expansion_step, 33 | ) 34 | 35 | def forward(self, input_dict): 36 | atom_representation = self.atom_model(input_dict) 37 | probe_result = self.probe_model(input_dict, atom_representation) 38 | return probe_result 39 | 40 | class PainnDensityModel(nn.Module): 41 | def __init__( 42 | self, 43 | num_interactions, 44 | hidden_state_size, 45 | cutoff, 46 | distance_embedding_size=30, 47 | **kwargs, 48 | ): 49 | super().__init__(**kwargs) 50 | 51 | self.atom_model = PainnAtomRepresentationModel( 52 | num_interactions, 53 | hidden_state_size, 54 | cutoff, 55 | distance_embedding_size, 56 | ) 57 | 58 | self.probe_model = PainnProbeMessageModel( 59 | num_interactions, 60 | hidden_state_size, 61 | cutoff, 62 | distance_embedding_size, 63 | ) 64 | 65 | def forward(self, input_dict): 66 | atom_representation_scalar, atom_representation_vector = self.atom_model(input_dict) 67 | probe_result = self.probe_model(input_dict, atom_representation_scalar, atom_representation_vector) 68 | return probe_result 69 | 70 | 71 | class ProbeMessageModel(nn.Module): 72 | def __init__( 73 | self, 74 | num_interactions, 75 | hidden_state_size, 76 | cutoff, 77 | gaussian_expansion_step, 78 | **kwargs, 79 | ): 80 | 81 | super().__init__(**kwargs) 82 | 83 | self.num_interactions = num_interactions 84 | self.hidden_state_size = hidden_state_size 85 | self.cutoff = cutoff 86 | self.gaussian_expansion_step = gaussian_expansion_step 87 | 88 | edge_size = int(math.ceil(self.cutoff / self.gaussian_expansion_step)) 89 | 90 | # Setup interaction networks 91 | self.messagesum_layers = nn.ModuleList( 92 | [ 93 | layer.MessageSum( 94 | hidden_state_size, edge_size, self.cutoff, include_receiver=True 95 | ) 96 | for _ in range(num_interactions) 97 | ] 98 | ) 99 | 100 | # Setup transitions networks 101 | self.probe_state_gate_functions = nn.ModuleList( 102 | [ 103 | nn.Sequential( 104 | nn.Linear(hidden_state_size, hidden_state_size), 105 | ShiftedSoftplus(), 106 | nn.Linear(hidden_state_size, hidden_state_size), 107 | nn.Sigmoid(), 108 | ) 109 | for _ in range(num_interactions) 110 | ] 111 | ) 112 | self.probe_state_transition_functions = nn.ModuleList( 113 | [ 114 | 
nn.Sequential( 115 | nn.Linear(hidden_state_size, hidden_state_size), 116 | ShiftedSoftplus(), 117 | nn.Linear(hidden_state_size, hidden_state_size), 118 | ) 119 | for _ in range(num_interactions) 120 | ] 121 | ) 122 | 123 | # Setup readout function 124 | self.readout_function = nn.Sequential( 125 | nn.Linear(hidden_state_size, hidden_state_size), 126 | ShiftedSoftplus(), 127 | nn.Linear(hidden_state_size, 1), 128 | ) 129 | 130 | def forward( 131 | self, 132 | input_dict: Dict[str, torch.Tensor], 133 | atom_representation: List[torch.Tensor], 134 | compute_iri=False, 135 | compute_dori=False, 136 | compute_hessian=False, 137 | ): 138 | if compute_iri or compute_dori or compute_hessian: 139 | input_dict["probe_xyz"].requires_grad_() 140 | 141 | # Unpad and concatenate edges and features into batch (0th) dimension 142 | atom_xyz = layer.unpad_and_cat(input_dict["atom_xyz"], input_dict["num_nodes"]) 143 | probe_xyz = layer.unpad_and_cat( 144 | input_dict["probe_xyz"], input_dict["num_probes"] 145 | ) 146 | edge_offset = torch.cumsum( 147 | torch.cat( 148 | ( 149 | torch.tensor([0], device=input_dict["num_nodes"].device), 150 | input_dict["num_nodes"][:-1], 151 | ) 152 | ), 153 | dim=0, 154 | ) 155 | edge_offset = edge_offset[:, None, None] 156 | 157 | # Unpad and concatenate probe edges into batch (0th) dimension 158 | probe_edges_displacement = layer.unpad_and_cat( 159 | input_dict["probe_edges_displacement"], input_dict["num_probe_edges"] 160 | ) 161 | edge_probe_offset = torch.cumsum( 162 | torch.cat( 163 | ( 164 | torch.tensor([0], device=input_dict["num_probes"].device), 165 | input_dict["num_probes"][:-1], 166 | ) 167 | ), 168 | dim=0, 169 | ) 170 | edge_probe_offset = edge_probe_offset[:, None, None] 171 | edge_probe_offset = torch.cat((edge_offset, edge_probe_offset), dim=2) 172 | probe_edges = input_dict["probe_edges"] + edge_probe_offset 173 | probe_edges = layer.unpad_and_cat(probe_edges, input_dict["num_probe_edges"]) 174 | 175 | # Compute edge distances 176 | probe_edges_features = layer.calc_distance_to_probe( 177 | atom_xyz, 178 | probe_xyz, 179 | input_dict["cell"], 180 | probe_edges, 181 | probe_edges_displacement, 182 | input_dict["num_probe_edges"], 183 | ) 184 | 185 | # Expand edge features in Gaussian basis 186 | probe_edge_state = layer.gaussian_expansion( 187 | probe_edges_features, [(0.0, self.gaussian_expansion_step, self.cutoff)] 188 | ) 189 | 190 | # Apply interaction layers 191 | probe_state = torch.zeros( 192 | (torch.sum(input_dict["num_probes"]), self.hidden_state_size), 193 | device=atom_representation[0].device, 194 | ) 195 | for msg_layer, gate_layer, state_layer, nodes in zip( 196 | self.messagesum_layers, 197 | self.probe_state_gate_functions, 198 | self.probe_state_transition_functions, 199 | atom_representation, 200 | ): 201 | msgsum = msg_layer( 202 | nodes, 203 | probe_edges, 204 | probe_edge_state, 205 | probe_edges_features, 206 | probe_state, 207 | ) 208 | gates = gate_layer(probe_state) 209 | probe_state = probe_state * gates + (1 - gates) * state_layer(msgsum) 210 | 211 | # Restack probe states 212 | probe_output = self.readout_function(probe_state).squeeze(1) 213 | probe_output = layer.pad_and_stack( 214 | torch.split( 215 | probe_output, 216 | list(input_dict["num_probes"].detach().cpu().numpy()), 217 | dim=0, 218 | ) 219 | # torch.split(probe_output, input_dict["num_probes"], dim=0) 220 | # probe_output.reshape((-1, input_dict["num_probes"][0])) 221 | ) 222 | 223 | if compute_iri or compute_dori or compute_hessian: 224 | dp_dxyz = 
torch.autograd.grad( 225 | probe_output, 226 | input_dict["probe_xyz"], 227 | grad_outputs=torch.ones_like(probe_output), 228 | retain_graph=True, 229 | create_graph=True, 230 | )[0] 231 | 232 | grad_probe_outputs = {} 233 | 234 | if compute_iri: 235 | iri = torch.linalg.norm(dp_dxyz, dim=2)/(torch.pow(probe_output, 1.1)) 236 | grad_probe_outputs["iri"] = iri 237 | 238 | if compute_dori: 239 | ## 240 | ## DORI(r) = phi(r) / (1 + phi(r)) 241 | ## phi(r) = ||grad(||grad(rho(r))/rho||^2)||^2 / ||grad(rho(r))/rho(r)||^6 242 | ## 243 | norm_grad_2 = torch.linalg.norm(dp_dxyz/torch.unsqueeze(probe_output, 2), dim=2)**2 244 | 245 | grad_norm_grad_2 = torch.autograd.grad( 246 | norm_grad_2, 247 | input_dict["probe_xyz"], 248 | grad_outputs=torch.ones_like(norm_grad_2), 249 | only_inputs=True, 250 | retain_graph=True, 251 | create_graph=True, 252 | )[0].detach() 253 | 254 | phi_r = torch.linalg.norm(grad_norm_grad_2, dim=2)**2 / (norm_grad_2**3) 255 | 256 | dori = phi_r / (1 + phi_r) 257 | grad_probe_outputs["dori"] = dori 258 | 259 | if compute_hessian: 260 | hessian_shape = (input_dict["probe_xyz"].shape[0], input_dict["probe_xyz"].shape[1], 3, 3) 261 | hessian = torch.zeros(hessian_shape, device=probe_xyz.device, dtype=probe_xyz.dtype) 262 | for dim_idx, grad_out in enumerate(torch.unbind(dp_dxyz, dim=-1)): 263 | dp2_dxyz2 = torch.autograd.grad( 264 | grad_out, 265 | input_dict["probe_xyz"], 266 | grad_outputs=torch.ones_like(grad_out), 267 | only_inputs=True, 268 | retain_graph=True, 269 | create_graph=True, 270 | )[0] 271 | hessian[:, :, dim_idx] = dp2_dxyz2 272 | grad_probe_outputs["hessian"] = hessian 273 | 274 | 275 | if grad_probe_outputs: 276 | return probe_output, grad_probe_outputs 277 | else: 278 | return probe_output 279 | 280 | 281 | class AtomRepresentationModel(nn.Module): 282 | def __init__( 283 | self, 284 | num_interactions, 285 | hidden_state_size, 286 | cutoff, 287 | gaussian_expansion_step, 288 | **kwargs, 289 | ): 290 | 291 | super().__init__(**kwargs) 292 | 293 | self.num_interactions = num_interactions 294 | self.hidden_state_size = hidden_state_size 295 | self.cutoff = cutoff 296 | self.gaussian_expansion_step = gaussian_expansion_step 297 | 298 | edge_size = int(math.ceil(self.cutoff / self.gaussian_expansion_step)) 299 | 300 | # Setup interaction networks 301 | self.interactions = nn.ModuleList( 302 | [ 303 | layer.Interaction( 304 | hidden_state_size, edge_size, self.cutoff, include_receiver=True 305 | ) 306 | for _ in range(num_interactions) 307 | ] 308 | ) 309 | 310 | # Atom embeddings 311 | self.atom_embeddings = nn.Embedding( 312 | len(ase.data.atomic_numbers), self.hidden_state_size 313 | ) 314 | 315 | def forward(self, input_dict): 316 | # Unpad and concatenate edges and features into batch (0th) dimension 317 | edges_displacement = layer.unpad_and_cat( 318 | input_dict["atom_edges_displacement"], input_dict["num_atom_edges"] 319 | ) 320 | edge_offset = torch.cumsum( 321 | torch.cat( 322 | ( 323 | torch.tensor([0], device=input_dict["num_nodes"].device), 324 | input_dict["num_nodes"][:-1], 325 | ) 326 | ), 327 | dim=0, 328 | ) 329 | edge_offset = edge_offset[:, None, None] 330 | edges = input_dict["atom_edges"] + edge_offset 331 | edges = layer.unpad_and_cat(edges, input_dict["num_atom_edges"]) 332 | 333 | # Unpad and concatenate all nodes into batch (0th) dimension 334 | atom_xyz = layer.unpad_and_cat(input_dict["atom_xyz"], input_dict["num_nodes"]) 335 | nodes = layer.unpad_and_cat(input_dict["nodes"], input_dict["num_nodes"]) 336 | nodes = 
self.atom_embeddings(nodes) 337 | 338 | # Compute edge distances 339 | edges_features = layer.calc_distance( 340 | atom_xyz, 341 | input_dict["cell"], 342 | edges, 343 | edges_displacement, 344 | input_dict["num_atom_edges"], 345 | ) 346 | 347 | # Expand edge features in Gaussian basis 348 | edge_state = layer.gaussian_expansion( 349 | edges_features, [(0.0, self.gaussian_expansion_step, self.cutoff)] 350 | ) 351 | 352 | nodes_list = [] 353 | # Apply interaction layers 354 | for int_layer in self.interactions: 355 | nodes = int_layer(nodes, edges, edge_state, edges_features) 356 | nodes_list.append(nodes) 357 | 358 | return nodes_list 359 | 360 | 361 | class PainnAtomRepresentationModel(nn.Module): 362 | def __init__( 363 | self, 364 | num_interactions, 365 | hidden_state_size, 366 | cutoff, 367 | distance_embedding_size, 368 | **kwargs, 369 | ): 370 | 371 | super().__init__(**kwargs) 372 | 373 | self.num_interactions = num_interactions 374 | self.hidden_state_size = hidden_state_size 375 | self.cutoff = cutoff 376 | self.distance_embedding_size = distance_embedding_size 377 | 378 | # Setup interaction networks 379 | self.interactions = nn.ModuleList( 380 | [ 381 | layer.PaiNNInteraction( 382 | hidden_state_size, self.distance_embedding_size, self.cutoff 383 | ) 384 | for _ in range(num_interactions) 385 | ] 386 | ) 387 | self.scalar_vector_update = nn.ModuleList( 388 | [layer.PaiNNUpdate(hidden_state_size) for _ in range(num_interactions)] 389 | ) 390 | 391 | # Atom embeddings 392 | self.atom_embeddings = nn.Embedding( 393 | len(ase.data.atomic_numbers), self.hidden_state_size 394 | ) 395 | 396 | def forward(self, input_dict): 397 | # Unpad and concatenate edges and features into batch (0th) dimension 398 | edges_displacement = layer.unpad_and_cat( 399 | input_dict["atom_edges_displacement"], input_dict["num_atom_edges"] 400 | ) 401 | edge_offset = torch.cumsum( 402 | torch.cat( 403 | ( 404 | torch.tensor([0], device=input_dict["num_nodes"].device), 405 | input_dict["num_nodes"][:-1], 406 | ) 407 | ), 408 | dim=0, 409 | ) 410 | edge_offset = edge_offset[:, None, None] 411 | edges = input_dict["atom_edges"] + edge_offset 412 | edges = layer.unpad_and_cat(edges, input_dict["num_atom_edges"]) 413 | 414 | # Unpad and concatenate all nodes into batch (0th) dimension 415 | atom_xyz = layer.unpad_and_cat(input_dict["atom_xyz"], input_dict["num_nodes"]) 416 | nodes_scalar = layer.unpad_and_cat(input_dict["nodes"], input_dict["num_nodes"]) 417 | nodes_scalar = self.atom_embeddings(nodes_scalar) 418 | nodes_vector = torch.zeros( 419 | (nodes_scalar.shape[0], 3, self.hidden_state_size), 420 | dtype=nodes_scalar.dtype, 421 | device=nodes_scalar.device, 422 | ) 423 | 424 | # Compute edge distances 425 | edges_distance, edges_diff = layer.calc_distance( 426 | atom_xyz, 427 | input_dict["cell"], 428 | edges, 429 | edges_displacement, 430 | input_dict["num_atom_edges"], 431 | return_diff=True, 432 | ) 433 | 434 | # Expand edge features in sinc basis 435 | edge_state = layer.sinc_expansion( 436 | edges_distance, [(self.distance_embedding_size, self.cutoff)] 437 | ) 438 | 439 | nodes_list_scalar = [] 440 | nodes_list_vector = [] 441 | # Apply interaction layers 442 | for int_layer, update_layer in zip( 443 | self.interactions, self.scalar_vector_update 444 | ): 445 | nodes_scalar, nodes_vector = int_layer( 446 | nodes_scalar, 447 | nodes_vector, 448 | edge_state, 449 | edges_diff, 450 | edges_distance, 451 | edges, 452 | ) 453 | nodes_scalar, nodes_vector = update_layer(nodes_scalar, nodes_vector) 454 | 
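# Collect the scalar and vector node states after each interaction/update step; the
# probe message model later consumes one pair of these per message-passing layer.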
nodes_list_scalar.append(nodes_scalar) 455 | nodes_list_vector.append(nodes_vector) 456 | 457 | return nodes_list_scalar, nodes_list_vector 458 | 459 | 460 | class PainnProbeMessageModel(nn.Module): 461 | def __init__( 462 | self, 463 | num_interactions, 464 | hidden_state_size, 465 | cutoff, 466 | distance_embedding_size, 467 | **kwargs, 468 | ): 469 | 470 | super().__init__(**kwargs) 471 | 472 | self.num_interactions = num_interactions 473 | self.hidden_state_size = hidden_state_size 474 | self.cutoff = cutoff 475 | self.distance_embedding_size = distance_embedding_size 476 | 477 | # Setup interaction networks 478 | self.message_layers = nn.ModuleList( 479 | [ 480 | layer.PaiNNInteractionOneWay( 481 | hidden_state_size, self.distance_embedding_size, self.cutoff 482 | ) 483 | for _ in range(num_interactions) 484 | ] 485 | ) 486 | self.scalar_vector_update = nn.ModuleList( 487 | [layer.PaiNNUpdate(hidden_state_size) for _ in range(num_interactions)] 488 | ) 489 | 490 | # Setup readout function 491 | self.readout_function = nn.Sequential( 492 | nn.Linear(hidden_state_size, hidden_state_size), 493 | nn.SiLU(), 494 | nn.Linear(hidden_state_size, 1), 495 | ) 496 | 497 | def forward( 498 | self, 499 | input_dict: Dict[str, torch.Tensor], 500 | atom_representation_scalar: List[torch.Tensor], 501 | atom_representation_vector: List[torch.Tensor], 502 | compute_iri=False, 503 | compute_dori=False, 504 | compute_hessian=False, 505 | ): 506 | if compute_iri or compute_dori or compute_hessian: 507 | input_dict["probe_xyz"].requires_grad_() 508 | 509 | # Unpad and concatenate edges and features into batch (0th) dimension 510 | atom_xyz = layer.unpad_and_cat(input_dict["atom_xyz"], input_dict["num_nodes"]) 511 | probe_xyz = layer.unpad_and_cat( 512 | input_dict["probe_xyz"], input_dict["num_probes"] 513 | ) 514 | edge_offset = torch.cumsum( 515 | torch.cat( 516 | ( 517 | torch.tensor([0], device=input_dict["num_nodes"].device), 518 | input_dict["num_nodes"][:-1], 519 | ) 520 | ), 521 | dim=0, 522 | ) 523 | edge_offset = edge_offset[:, None, None] 524 | 525 | # Unpad and concatenate probe edges into batch (0th) dimension 526 | probe_edges_displacement = layer.unpad_and_cat( 527 | input_dict["probe_edges_displacement"], input_dict["num_probe_edges"] 528 | ) 529 | edge_probe_offset = torch.cumsum( 530 | torch.cat( 531 | ( 532 | torch.tensor([0], device=input_dict["num_probes"].device), 533 | input_dict["num_probes"][:-1], 534 | ) 535 | ), 536 | dim=0, 537 | ) 538 | edge_probe_offset = edge_probe_offset[:, None, None] 539 | edge_probe_offset = torch.cat((edge_offset, edge_probe_offset), dim=2) 540 | probe_edges = input_dict["probe_edges"] + edge_probe_offset 541 | probe_edges = layer.unpad_and_cat(probe_edges, input_dict["num_probe_edges"]) 542 | 543 | # Compute edge distances 544 | probe_edges_distance, probe_edges_diff = layer.calc_distance_to_probe( 545 | atom_xyz, 546 | probe_xyz, 547 | input_dict["cell"], 548 | probe_edges, 549 | probe_edges_displacement, 550 | input_dict["num_probe_edges"], 551 | return_diff=True, 552 | ) 553 | 554 | # Expand edge features in sinc basis 555 | edge_state = layer.sinc_expansion( 556 | probe_edges_distance, [(self.distance_embedding_size, self.cutoff)] 557 | ) 558 | 559 | # Apply interaction layers 560 | probe_state_scalar = torch.zeros( 561 | (torch.sum(input_dict["num_probes"]), self.hidden_state_size), 562 | device=atom_representation_scalar[0].device, 563 | ) 564 | probe_state_vector = torch.zeros( 565 | (torch.sum(input_dict["num_probes"]), 3, 
self.hidden_state_size), 566 | device=atom_representation_scalar[0].device, 567 | ) 568 | 569 | for msg_layer, update_layer, atom_nodes_scalar, atom_nodes_vector in zip( 570 | self.message_layers, 571 | self.scalar_vector_update, 572 | atom_representation_scalar, 573 | atom_representation_vector, 574 | ): 575 | probe_state_scalar, probe_state_vector = msg_layer( 576 | atom_nodes_scalar, 577 | atom_nodes_vector, 578 | probe_state_scalar, 579 | probe_state_vector, 580 | edge_state, 581 | probe_edges_diff, 582 | probe_edges_distance, 583 | probe_edges, 584 | ) 585 | probe_state_scalar, probe_state_vector = update_layer( 586 | probe_state_scalar, probe_state_vector 587 | ) 588 | 589 | # Restack probe states 590 | probe_output = self.readout_function(probe_state_scalar).squeeze(1) 591 | probe_output = layer.pad_and_stack( 592 | torch.split( 593 | probe_output, 594 | list(input_dict["num_probes"].detach().cpu().numpy()), 595 | dim=0, 596 | ) 597 | # torch.split(probe_output, input_dict["num_probes"], dim=0) 598 | # probe_output.reshape((-1, input_dict["num_probes"][0])) 599 | ) 600 | 601 | if compute_iri or compute_dori or compute_hessian: 602 | dp_dxyz = torch.autograd.grad( 603 | probe_output, 604 | input_dict["probe_xyz"], 605 | grad_outputs=torch.ones_like(probe_output), 606 | retain_graph=True, 607 | create_graph=True, 608 | )[0] 609 | 610 | grad_probe_outputs = {} 611 | 612 | if compute_iri: 613 | iri = torch.linalg.norm(dp_dxyz, dim=2)/(torch.pow(probe_output, 1.1)) 614 | grad_probe_outputs["iri"] = iri 615 | 616 | if compute_dori: 617 | ## 618 | ## DORI(r) = phi(r) / (1 + phi(r)) 619 | ## phi(r) = ||grad(||grad(rho(r))/rho||^2)||^2 / ||grad(rho(r))/rho(r)||^6 620 | ## 621 | norm_grad_2 = torch.linalg.norm(dp_dxyz/(torch.unsqueeze(probe_output, 2)), dim=2)**2 622 | 623 | grad_norm_grad_2 = torch.autograd.grad( 624 | norm_grad_2, 625 | input_dict["probe_xyz"], 626 | grad_outputs=torch.ones_like(norm_grad_2), 627 | only_inputs=True, 628 | retain_graph=True, 629 | create_graph=True, 630 | )[0].detach() 631 | 632 | phi_r = torch.linalg.norm(grad_norm_grad_2, dim=2)**2 / (norm_grad_2**3) 633 | 634 | dori = phi_r / (1 + phi_r) 635 | grad_probe_outputs["dori"] = dori 636 | 637 | if compute_hessian: 638 | hessian_shape = (input_dict["probe_xyz"].shape[0], input_dict["probe_xyz"].shape[1], 3, 3) 639 | hessian = torch.zeros(hessian_shape, device=probe_xyz.device, dtype=probe_xyz.dtype) 640 | for dim_idx, grad_out in enumerate(torch.unbind(dp_dxyz, dim=-1)): 641 | dp2_dxyz2 = torch.autograd.grad( 642 | grad_out, 643 | input_dict["probe_xyz"], 644 | grad_outputs=torch.ones_like(grad_out), 645 | only_inputs=True, 646 | retain_graph=True, 647 | create_graph=True, 648 | )[0] 649 | hessian[:, :, dim_idx] = dp2_dxyz2 650 | grad_probe_outputs["hessian"] = hessian 651 | 652 | 653 | if grad_probe_outputs: 654 | return probe_output, grad_probe_outputs 655 | else: 656 | return probe_output 657 | -------------------------------------------------------------------------------- /evaluate_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import tarfile 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import dataset 10 | import densitymodel 11 | from runner import split_data 12 | from utils import write_cube_to_tar 13 | 14 | def get_arguments(arg_list=None): 15 | parser = argparse.ArgumentParser( 16 | description="Evaluate density model", fromfile_prefix_chars="@" 17 | ) 18 | parser.add_argument("--load_model", 
type=str, default=None) 19 | parser.add_argument("--dataset", type=str, default=None) 20 | parser.add_argument("--output_dir", type=str, default=".") 21 | parser.add_argument("--cutoff", type=float, default=5.0) 22 | parser.add_argument("--num_interactions", type=int, default=3) 23 | parser.add_argument("--node_size", type=int, default=64) 24 | parser.add_argument("--split_file", type=str, default=None) 25 | parser.add_argument("--split", nargs='*', type=str) 26 | parser.add_argument("--probe_count", type=int, default=1000) 27 | parser.add_argument("--write_error_cubes", action="store_true") 28 | parser.add_argument( 29 | "--device", 30 | type=str, 31 | default="cuda", 32 | help="Set which device to use for training e.g. 'cuda' or 'cpu'", 33 | ) 34 | parser.add_argument( 35 | "--ignore_pbc", 36 | action="store_true", 37 | help="If flag is given, ignore periodic boundary conditions in atoms data", 38 | ) 39 | parser.add_argument( 40 | "--use_painn_model", 41 | action="store_true", 42 | help="Use painn model as backend", 43 | ) 44 | 45 | return parser.parse_args(arg_list) 46 | 47 | 48 | def main(): 49 | args = get_arguments() 50 | 51 | # Setup logging 52 | os.makedirs(args.output_dir, exist_ok=True) 53 | handlers = [ 54 | logging.FileHandler( 55 | os.path.join(args.output_dir, "printlog.txt"), mode="w" 56 | ), 57 | logging.StreamHandler(), 58 | ] 59 | 60 | logging.basicConfig( 61 | level=logging.DEBUG, 62 | format="%(asctime)s [%(levelname)-5.5s] %(message)s", 63 | handlers=handlers, 64 | ) 65 | 66 | # Initialise model and load model 67 | device = torch.device(args.device) 68 | if args.use_painn_model: 69 | net = densitymodel.PainnDensityModel(args.num_interactions, args.node_size, args.cutoff,) 70 | else: 71 | net = densitymodel.DensityModel(args.num_interactions, args.node_size, args.cutoff,) 72 | net = net.to(device) 73 | logging.info("loading model from %s", args.load_model) 74 | state_dict = torch.load(args.load_model) 75 | net.load_state_dict(state_dict["model"]) 76 | 77 | # Load dataset 78 | if args.dataset.endswith(".txt"): 79 | # Text file contains list of datafiles 80 | with open(args.dataset, "r") as datasetfiles: 81 | filelist = [os.path.join(os.path.dirname(args.dataset), line.strip('\n')) for line in datasetfiles] 82 | else: 83 | filelist = [args.dataset] 84 | logging.info("loading data %s", args.dataset) 85 | densitydata = torch.utils.data.ConcatDataset([dataset.DensityData(path) for path in filelist]) 86 | 87 | # Split data into train and validation sets 88 | if args.split_file: 89 | datasplits = split_data(densitydata, args) 90 | else: 91 | datasplits = {"all": densitydata} 92 | 93 | 94 | for split_name, densitydataset in datasplits.items(): 95 | if args.split and split_name not in args.split: 96 | continue 97 | dataloader = torch.utils.data.DataLoader( 98 | densitydataset, 99 | 1, 100 | num_workers=4, 101 | collate_fn=lambda x: x[0], 102 | ) 103 | 104 | if args.write_error_cubes: 105 | outname = os.path.join(args.output_dir, "eval_" + split_name + ".tar") 106 | tar = tarfile.open(outname, "w") 107 | 108 | for density_dict in dataloader: 109 | density = [] 110 | with torch.no_grad(): 111 | # Loop over all slices 112 | density_iter = dataset.DensityGridIterator(density_dict, args.ignore_pbc, args.probe_count, args.cutoff) 113 | 114 | # Make graph with no probes 115 | collate_fn = dataset.CollateFuncAtoms( 116 | cutoff=args.cutoff, 117 | pin_memory=True, 118 | disable_pbc=args.ignore_pbc, 119 | ) 120 | graph_dict = collate_fn([density_dict]) 121 | device_batch = { 122 | k: 
v.to(device=device, non_blocking=True) for k, v in graph_dict.items() 123 | } 124 | if args.use_painn_model: 125 | atom_representation_scalar, atom_representation_vector = net.atom_model(device_batch) 126 | else: 127 | atom_representation = net.atom_model(device_batch) 128 | 129 | num_positions = np.prod(density_dict["grid_position"].shape[0:3]) 130 | sum_abs_error = torch.tensor(0, dtype=torch.double, device=device) 131 | sum_squared_error = torch.tensor(0, dtype=torch.double, device=device) 132 | sum_target = torch.tensor(0, dtype=torch.double, device=device) 133 | for slice_id, probe_graph_dict in enumerate(density_iter): 134 | # Transfer target to device 135 | flat_index = np.arange(slice_id*args.probe_count, min((slice_id+1)*args.probe_count, num_positions)) 136 | pos_index = np.unravel_index(flat_index, density_dict["density"].shape[0:3]) 137 | probe_target = torch.tensor(density_dict["density"][pos_index]).to(device=device, non_blocking=True) 138 | 139 | # Transfer model input to device 140 | probe_dict = dataset.collate_list_of_dicts([probe_graph_dict]) 141 | probe_dict = { 142 | k: v.to(device=device, non_blocking=True) for k, v in probe_dict.items() 143 | } 144 | device_batch["probe_edges"] = probe_dict["probe_edges"] 145 | device_batch["probe_edges_displacement"] = probe_dict["probe_edges_displacement"] 146 | device_batch["probe_xyz"] = probe_dict["probe_xyz"] 147 | device_batch["num_probe_edges"] = probe_dict["num_probe_edges"] 148 | device_batch["num_probes"] = probe_dict["num_probes"] 149 | 150 | if args.use_painn_model: 151 | res = net.probe_model(device_batch, atom_representation_scalar, atom_representation_vector) 152 | else: 153 | res = net.probe_model(device_batch, atom_representation) 154 | 155 | # Compare result with target 156 | error = probe_target - res 157 | sum_abs_error += torch.sum(torch.abs(error)) 158 | sum_squared_error += torch.sum(torch.square(error)) 159 | sum_target += torch.sum(probe_target) 160 | 161 | if args.write_error_cubes: 162 | density.append(res.detach().cpu().numpy()) 163 | 164 | voxel_volume = density_dict["atoms"].get_volume()/np.prod(density_dict["density"].shape) 165 | rmse = torch.sqrt((sum_squared_error/num_positions)) 166 | mae = sum_abs_error/num_positions 167 | abserror_integral = sum_abs_error*voxel_volume 168 | total_integral = sum_target*voxel_volume 169 | percentage_error = 100*abserror_integral/total_integral 170 | 171 | 172 | if args.write_error_cubes: 173 | pred_density = np.concatenate(density, axis=1) 174 | target_density = density_dict["density"] 175 | pred_density = pred_density.reshape(target_density.shape) 176 | errors = pred_density-target_density 177 | 178 | fname_stripped = density_dict["metadata"]["filename"] 179 | while fname_stripped.endswith(".zz"): 180 | fname_stripped = fname_stripped[:-3] 181 | name, _ = os.path.splitext(fname_stripped) 182 | write_cube_to_tar( 183 | tar, 184 | density_dict["atoms"], 185 | pred_density, 186 | density_dict["grid_position"][0, 0, 0], 187 | name + "_prediction" + ".cube" + ".zz", 188 | ) 189 | write_cube_to_tar( 190 | tar, 191 | density_dict["atoms"], 192 | errors, 193 | density_dict["grid_position"][0, 0, 0], 194 | name + "_error" + ".cube" + ".zz", 195 | ) 196 | write_cube_to_tar( 197 | tar, 198 | density_dict["atoms"], 199 | target_density, 200 | density_dict["grid_position"][0, 0, 0], 201 | name + "_target" + ".cube" + ".zz", 202 | ) 203 | 204 | logging.info("split=%s, filename=%s, mae=%f, rmse=%f, abs_relative_error=%f%%", split_name, density_dict["metadata"]["filename"], mae, 
rmse, percentage_error) 205 | 206 | 207 | if __name__ == "__main__": 208 | main() 209 | -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | import itertools 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | 7 | 8 | def pad_and_stack(tensors: List[torch.Tensor]): 9 | """Pad list of tensors if tensors are arrays and stack if they are scalars""" 10 | if tensors[0].shape: 11 | return torch.nn.utils.rnn.pad_sequence( 12 | tensors, batch_first=True, padding_value=0 13 | ) 14 | return torch.stack(tensors) 15 | 16 | 17 | def shifted_softplus(x): 18 | """ 19 | Compute shifted soft-plus activation function. 20 | .. math:: 21 | y = \ln\left(1 + e^{-x}\right) - \ln(2) 22 | 23 | Args: 24 | x (torch.Tensor): input tensor. 25 | 26 | Returns: 27 | torch.Tensor: shifted soft-plus of input. 28 | 29 | """ 30 | return nn.functional.softplus(x) - np.log(2.0) 31 | 32 | 33 | class ShiftedSoftplus(nn.Module): 34 | def forward(self, x): 35 | return shifted_softplus(x) 36 | 37 | 38 | def unpad_and_cat(stacked_seq: torch.Tensor, seq_len: torch.Tensor): 39 | """ 40 | Unpad and concatenate by removing batch dimension 41 | 42 | Args: 43 | stacked_seq: (batch_size, max_length, *) Tensor 44 | seq_len: (batch_size) Tensor with length of each sequence 45 | 46 | Returns: 47 | (prod(seq_len), *) Tensor 48 | 49 | """ 50 | unstacked = stacked_seq.unbind(0) 51 | unpadded = [ 52 | torch.narrow(t, 0, 0, l) for (t, l) in zip(unstacked, seq_len.unbind(0)) 53 | ] 54 | return torch.cat(unpadded, dim=0) 55 | 56 | 57 | def sum_splits(values: torch.Tensor, splits: torch.Tensor): 58 | """ 59 | Sum across dimension 0 of the tensor `values` in chunks 60 | defined in `splits` 61 | 62 | Args: 63 | values: Tensor of shape (`prod(splits)`, *) 64 | splits: 1-dimensional tensor with size of each chunk 65 | 66 | Returns: 67 | Tensor of shape (`splits.shape[0]`, *) 68 | 69 | """ 70 | # prepare an index vector for summation 71 | ind = torch.zeros(splits.sum(), dtype=splits.dtype, device=splits.device) 72 | ind[torch.cumsum(splits, dim=0)[:-1]] = 1 73 | ind = torch.cumsum(ind, dim=0) 74 | # prepare the output 75 | sum_y = torch.zeros( 76 | splits.shape + values.shape[1:], dtype=values.dtype, device=values.device 77 | ) 78 | # do the actual summation 79 | sum_y.index_add_(0, ind, values) 80 | return sum_y 81 | 82 | 83 | def calc_distance( 84 | positions: torch.Tensor, 85 | cells: torch.Tensor, 86 | edges: torch.Tensor, 87 | edges_displacement: torch.Tensor, 88 | splits: torch.Tensor, 89 | return_diff=False, 90 | ): 91 | """ 92 | Calculate distance of edges 93 | 94 | Args: 95 | positions: Tensor of shape (num_nodes, 3) with xyz coordinates inside cell 96 | cells: Tensor of shape (num_splits, 3, 3) with one unit cell for each split 97 | edges: Tensor of shape (num_edges, 2) 98 | edges_displacement: Tensor of shape (num_edges, 3) with the offset (in number of cell vectors) of the sending node 99 | splits: 1-dimensional tensor with the number of edges for each separate graph 100 | return_diff: If non-zero return the also the vector corresponding to edges 101 | """ 102 | unitcell_repeat = torch.repeat_interleave(cells, splits, dim=0) # num_edges, 3, 3 103 | displacement = torch.matmul( 104 | torch.unsqueeze(edges_displacement, 1), unitcell_repeat 105 | ) # num_edges, 1, 3 106 | displacement = torch.squeeze(displacement, dim=1) 107 | neigh_pos = positions[edges[:, 0]] # num_edges, 
3 108 | neigh_abs_pos = neigh_pos + displacement # num_edges, 3 109 | this_pos = positions[edges[:, 1]] # num_edges, 3 110 | diff = this_pos - neigh_abs_pos # num_edges, 3 111 | dist = torch.sqrt( 112 | torch.sum(torch.square(diff), dim=1, keepdim=True) 113 | ) # num_edges, 1 114 | 115 | if return_diff: 116 | return dist, diff 117 | else: 118 | return dist 119 | 120 | 121 | def calc_distance_to_probe( 122 | positions: torch.Tensor, 123 | positions_probe: torch.Tensor, 124 | cells: torch.Tensor, 125 | edges: torch.Tensor, 126 | edges_displacement: torch.Tensor, 127 | splits: torch.Tensor, 128 | return_diff=False, 129 | ): 130 | """ 131 | Calculate distance of edges 132 | 133 | Args: 134 | positions: Tensor of shape (num_nodes, 3) with xyz coordinates inside cell 135 | positions_probe: Tensor of shape (num_probes, 3) with xyz coordinates of probes inside cell 136 | cells: Tensor of shape (num_splits, 3, 3) with one unit cell for each split 137 | edges: Tensor of shape (num_edges, 2) 138 | edges_displacement: Tensor of shape (num_edges, 3) with the offset (in number of cell vectors) of the sending node 139 | splits: 1-dimensional tensor with the number of edges for each separate graph 140 | """ 141 | unitcell_repeat = torch.repeat_interleave(cells, splits, dim=0) # num_edges, 3, 3 142 | displacement = torch.matmul( 143 | torch.unsqueeze(edges_displacement, 1), unitcell_repeat 144 | ) # num_edges, 1, 3 145 | displacement = torch.squeeze(displacement, dim=1) 146 | neigh_pos = positions[edges[:, 0]] # num_edges, 3 147 | neigh_abs_pos = neigh_pos + displacement # num_edges, 3 148 | this_pos = positions_probe[edges[:, 1]] # num_edges, 3 149 | diff = this_pos - neigh_abs_pos # num_edges, 3 150 | dist = torch.sqrt( 151 | torch.sum(torch.square(diff), dim=1, keepdim=True) 152 | ) # num_edges, 1 153 | if return_diff: 154 | return dist, diff 155 | else: 156 | return dist 157 | 158 | 159 | def gaussian_expansion(input_x: torch.Tensor, expand_params: List[Tuple]): 160 | """ 161 | Expand each feature in a number of Gaussian basis function. 
162 | Expand_params is a list of length input_x.shape[1] 163 | 164 | Args: 165 | input_x: (num_edges, num_features) tensor 166 | expand_params: list of None or (start, step, stop) tuples 167 | 168 | Returns: 169 | (num_edges, ``ceil((stop - start)/step)``) tensor 170 | 171 | """ 172 | feat_list = torch.unbind(input_x, dim=1) 173 | expanded_list = [] 174 | for step_tuple, feat in itertools.zip_longest(expand_params, feat_list): 175 | assert feat is not None, "Too many expansion parameters given" 176 | if step_tuple: 177 | start, step, stop = step_tuple 178 | feat_expanded = torch.unsqueeze(feat, dim=1) 179 | sigma = step 180 | basis_mu = torch.arange( 181 | start, stop, step, device=input_x.device, dtype=input_x.dtype 182 | ) 183 | expanded_list.append( 184 | torch.exp(-((feat_expanded - basis_mu) ** 2) / (2.0 * sigma ** 2)) 185 | ) 186 | else: 187 | expanded_list.append(torch.unsqueeze(feat, 1)) 188 | return torch.cat(expanded_list, dim=1) 189 | 190 | 191 | class SchnetMessageFunction(nn.Module): 192 | def __init__(self, input_size, edge_size, output_size, hard_cutoff): 193 | super().__init__() 194 | self.msg_function_edge = nn.Sequential( 195 | nn.Linear(edge_size, output_size), 196 | ShiftedSoftplus(), 197 | nn.Linear(output_size, output_size), 198 | ) 199 | self.msg_function_node = nn.Sequential( 200 | nn.Linear(input_size, input_size), 201 | ShiftedSoftplus(), 202 | nn.Linear(input_size, output_size), 203 | ) 204 | 205 | self.soft_cutoff_func = lambda x: 1.0 - torch.sigmoid( 206 | 5 * (x - (hard_cutoff - 1.5)) 207 | ) 208 | 209 | def forward(self, node_state, edge_state, edge_distance): 210 | gates = self.msg_function_edge(edge_state) * self.soft_cutoff_func( 211 | edge_distance 212 | ) 213 | nodes = self.msg_function_node(node_state) 214 | return nodes * gates 215 | 216 | 217 | class Interaction(nn.Module): 218 | def __init__(self, node_size, edge_size, cutoff, include_receiver=False): 219 | super().__init__() 220 | 221 | self.message_sum_module = MessageSum( 222 | node_size, edge_size, cutoff, include_receiver 223 | ) 224 | 225 | self.state_transition_function = nn.Sequential( 226 | nn.Linear(node_size, node_size), 227 | ShiftedSoftplus(), 228 | nn.Linear(node_size, node_size), 229 | ) 230 | 231 | def forward(self, node_state, edges, edge_state, edges_distance): 232 | 233 | # Compute sum of messages 234 | message_sum = self.message_sum_module( 235 | node_state, edges, edge_state, edges_distance 236 | ) 237 | 238 | # State transition 239 | new_state = node_state + self.state_transition_function(message_sum) 240 | 241 | return new_state 242 | 243 | 244 | class MessageSum(nn.Module): 245 | def __init__(self, node_size, edge_size, cutoff, include_receiver): 246 | super().__init__() 247 | 248 | self.include_receiver = include_receiver 249 | 250 | if include_receiver: 251 | input_size = node_size * 2 252 | else: 253 | input_size = node_size 254 | 255 | self.message_function = SchnetMessageFunction( 256 | input_size, edge_size, node_size, cutoff 257 | ) 258 | 259 | def forward( 260 | self, node_state, edges, edge_state, edges_distance, receiver_nodes=None 261 | ): 262 | """ 263 | 264 | Args: 265 | node_state: [num_nodes, n_node_features] State of input nodes 266 | edges: [num_edges, 2] array of sender and receiver indices 267 | edge_state: [num_edges, n_features] array of edge features 268 | edges_distance: [num_edges, 1] array of distances 269 | receiver_nodes: If given, use these nodes as receiver nodes instead of node_state 270 | 271 | Returns: 272 | sum of messages to each node 273 | 
274 | """ 275 | # Compute all messages 276 | if self.include_receiver: 277 | if receiver_nodes is not None: 278 | senders = node_state[edges[:, 0]] 279 | receivers = receiver_nodes[edges[:, 1]] 280 | nodes = torch.cat((senders, receivers), dim=1) 281 | else: 282 | num_edges = edges.shape[0] 283 | nodes = torch.reshape(node_state[edges], (num_edges, -1)) 284 | else: 285 | nodes = node_state[edges[:, 0]] # Only include sender in messages 286 | messages = self.message_function(nodes, edge_state, edges_distance) 287 | 288 | # Sum messages 289 | if receiver_nodes is not None: 290 | message_sum = torch.zeros_like(receiver_nodes) 291 | else: 292 | message_sum = torch.zeros_like(node_state) 293 | message_sum.index_add_(0, edges[:, 1], messages) 294 | 295 | return message_sum 296 | 297 | 298 | class EdgeUpdate(nn.Module): 299 | def __init__(self, edge_size, node_size): 300 | super().__init__() 301 | 302 | self.node_size = node_size 303 | self.edge_update_mlp = nn.Sequential( 304 | nn.Linear(2 * node_size + edge_size, 2 * edge_size), 305 | ShiftedSoftplus(), 306 | nn.Linear(2 * edge_size, edge_size), 307 | ) 308 | 309 | def forward(self, edge_state, edges, node_state): 310 | combined = torch.cat( 311 | (node_state[edges].view(-1, 2 * self.node_size), edge_state), axis=1 312 | ) 313 | return self.edge_update_mlp(combined) 314 | 315 | 316 | class PaiNNUpdate(nn.Module): 317 | """PaiNN style update network. Models the interaction between scalar and vectorial part""" 318 | 319 | def __init__(self, node_size): 320 | super().__init__() 321 | 322 | self.linearU = nn.Linear(node_size, node_size, bias=False) 323 | self.linearV = nn.Linear(node_size, node_size, bias=False) 324 | self.combined_mlp = nn.Sequential( 325 | nn.Linear(2 * node_size, node_size), 326 | nn.SiLU(), 327 | nn.Linear(node_size, 3 * node_size), 328 | ) 329 | 330 | def forward(self, node_state_scalar, node_state_vector): 331 | """ 332 | Args: 333 | node_state_scalar (tensor): Node states (num_nodes, node_size) 334 | node_state_vector (tensor): Node states (num_nodes, 3, node_size) 335 | 336 | Returns: 337 | Tuple of 2 tensors: 338 | updated_node_state_scalar (num_nodes, node_size) 339 | updated_node_state_vector (num_nodes, 3, node_size) 340 | """ 341 | 342 | Uv = self.linearU(node_state_vector) # num_nodes, 3, node_size 343 | Vv = self.linearV(node_state_vector) # num_nodes, 3, node_size 344 | 345 | Vv_norm = torch.linalg.norm(Vv, dim=1, keepdim=False) # num_nodes, node_size 346 | 347 | mlp_input = torch.cat( 348 | (node_state_scalar, Vv_norm), dim=1 349 | ) # num_nodes, node_size*2 350 | mlp_output = self.combined_mlp(mlp_input) 351 | 352 | a_ss, a_sv, a_vv = torch.split( 353 | mlp_output, node_state_scalar.shape[1], dim=1 354 | ) # num_nodes, node_size 355 | 356 | inner_prod = torch.sum(Uv * Vv, dim=1) # num_nodes, node_size 357 | 358 | delta_v = torch.unsqueeze(a_vv, 1) * Uv # num_nodes, 3, node_size 359 | 360 | delta_s = a_ss + a_sv * inner_prod # num_nodes, node_size 361 | 362 | return node_state_scalar + delta_s, node_state_vector + delta_v 363 | 364 | 365 | class PaiNNInteraction(nn.Module): 366 | """Interaction network""" 367 | 368 | def __init__(self, node_size, edge_size, cutoff): 369 | """ 370 | Args: 371 | node_size (int): Size of node state 372 | edge_size (int): Size of edge state 373 | cutoff (float): Cutoff distance 374 | """ 375 | super().__init__() 376 | 377 | self.filter_layer = nn.Linear(edge_size, 3 * node_size) 378 | 379 | self.cutoff = cutoff 380 | 381 | self.scalar_message_mlp = nn.Sequential( 382 | 
nn.Linear(node_size, node_size), 383 | nn.SiLU(), 384 | nn.Linear(node_size, 3 * node_size), 385 | ) 386 | 387 | def forward( 388 | self, 389 | node_state_scalar, 390 | node_state_vector, 391 | edge_state, 392 | edge_vector, 393 | edge_distance, 394 | edges, 395 | ): 396 | """ 397 | Args: 398 | node_state_scalar (tensor): Node states (num_nodes, node_size) 399 | node_state_vector (tensor): Node states (num_nodes, 3, node_size) 400 | edge_state (tensor): Edge states (num_edges, edge_size) 401 | edge_vector (tensor): Edge vector difference between nodes (num_edges, 3) 402 | edge_distance (tensor): l2-norm of edge_vector (num_edges, 1) 403 | edges (tensor): Directed edges with node indices (num_edges, 2) 404 | 405 | Returns: 406 | Tuple of 2 tensors: 407 | updated_node_state_scalar (num_nodes, node_size) 408 | updated_node_state_vector (num_nodes, 3, node_size) 409 | """ 410 | # Compute all messages 411 | edge_vector_normalised = edge_vector / torch.maximum( 412 | torch.linalg.norm(edge_vector, dim=1, keepdim=True), torch.tensor(1e-12) 413 | ) # num_edges, 3 414 | 415 | filter_weight = self.filter_layer(edge_state) # num_edges, 3*node_size 416 | filter_weight = filter_weight * cosine_cutoff(edge_distance, self.cutoff) 417 | 418 | scalar_output = self.scalar_message_mlp( 419 | node_state_scalar 420 | ) # num_nodes, 3*node_size 421 | scalar_output = scalar_output[edges[:, 0]] # num_edges, 3*node_size 422 | filter_output = filter_weight * scalar_output # num_edges, 3*node_size 423 | 424 | gate_state_vector, gate_edge_vector, gate_node_state = torch.split( 425 | filter_output, node_state_scalar.shape[1], dim=1 426 | ) 427 | 428 | gate_state_vector = torch.unsqueeze( 429 | gate_state_vector, 1 430 | ) # num_edges, 1, node_size 431 | gate_edge_vector = torch.unsqueeze( 432 | gate_edge_vector, 1 433 | ) # num_edges, 1, node_size 434 | 435 | # Only include sender in messages 436 | messages_scalar = node_state_scalar[edges[:, 0]] * gate_node_state 437 | messages_state_vector = node_state_vector[ 438 | edges[:, 0] 439 | ] * gate_state_vector + gate_edge_vector * torch.unsqueeze( 440 | edge_vector_normalised, 2 441 | ) 442 | 443 | # Sum messages 444 | message_sum_scalar = torch.zeros_like(node_state_scalar) 445 | message_sum_scalar.index_add_(0, edges[:, 1], messages_scalar) 446 | message_sum_vector = torch.zeros_like(node_state_vector) 447 | message_sum_vector.index_add_(0, edges[:, 1], messages_state_vector) 448 | 449 | # State transition 450 | new_state_scalar = node_state_scalar + message_sum_scalar 451 | new_state_vector = node_state_vector + message_sum_vector 452 | 453 | return new_state_scalar, new_state_vector 454 | 455 | 456 | class PaiNNInteractionOneWay(nn.Module): 457 | """Same as the Interaction network, but the receiving nodes are indexed differently from the sending nodes""" 458 | 459 | def __init__(self, node_size, edge_size, cutoff): 460 | """ 461 | Args: 462 | node_size (int): Size of node state 463 | edge_size (int): Size of edge state 464 | cutoff (float): Cutoff distance 465 | """ 466 | super().__init__() 467 | 468 | self.filter_layer = nn.Linear(edge_size, 3 * node_size) 469 | 470 | self.cutoff = cutoff 471 | 472 | self.scalar_message_mlp = nn.Sequential( 473 | nn.Linear(node_size, node_size), 474 | nn.SiLU(), 475 | nn.Linear(node_size, 3 * node_size), 476 | ) 477 | 478 | # Gate that allows the update to ignore incoming messages (not part of the original PaiNN network) 479 | self.update_gate_mlp = nn.Sequential( 480 | nn.Linear(node_size, 2 * node_size), 481 | nn.SiLU(), 482 | nn.Linear(2 * node_size, 2 * node_size), 483 
| nn.Sigmoid(), 484 | ) 485 | 486 | def forward( 487 | self, 488 | sender_node_state_scalar, 489 | sender_node_state_vector, 490 | receiver_node_state_scalar, 491 | receiver_node_state_vector, 492 | edge_state, 493 | edge_vector, 494 | edge_distance, 495 | edges, 496 | ): 497 | """ 498 | Args: 499 | sender_node_state_scalar (tensor): Node states (num_nodes, node_size) 500 | sender_node_state_vector (tensor): Node states (num_nodes, 3, node_size) 501 | receiver_node_state_scalar (tensor): Node states (num_nodes, node_size) 502 | receiver_node_state_vector (tensor): Node states (num_nodes, 3, node_size) 503 | edge_state (tensor): Edge states (num_edges, edge_size) 504 | edge_vector (tensor): Edge vector difference between nodes (num_edges, 3) 505 | edge_distance (tensor): l2-norm of edge_vector (num_edges, 1) 506 | edges (tensor): Directed edges with node indices (num_edges, 2) 507 | 508 | Returns: 509 | Tuple of 2 tensors: 510 | updated_node_state_scalar (num_nodes, node_size) 511 | updated_node_state_vector (num_nodes, 3, node_size) 512 | """ 513 | # Compute all messages 514 | edge_vector_normalised = edge_vector / torch.maximum( 515 | torch.linalg.norm(edge_vector, dim=1, keepdim=True), torch.tensor(1e-12) 516 | ) # num_edges, 3 517 | 518 | filter_weight = self.filter_layer(edge_state) # num_edges, 3*node_size 519 | filter_weight = filter_weight * cosine_cutoff(edge_distance, self.cutoff) 520 | 521 | scalar_output = self.scalar_message_mlp( 522 | sender_node_state_scalar 523 | ) # num_nodes, 3*node_size 524 | scalar_output = scalar_output[edges[:, 0]] # num_edges, 3*node_size 525 | filter_output = filter_weight * scalar_output # num_edges, 3*node_size 526 | 527 | gate_state_vector, gate_edge_vector, gate_node_state = torch.split( 528 | filter_output, sender_node_state_scalar.shape[1], dim=1 529 | ) 530 | 531 | gate_state_vector = torch.unsqueeze( 532 | gate_state_vector, 1 533 | ) # num_edges, 1, node_size 534 | gate_edge_vector = torch.unsqueeze( 535 | gate_edge_vector, 1 536 | ) # num_edges, 1, node_size 537 | 538 | # Only include sender in messages 539 | messages_scalar = sender_node_state_scalar[edges[:, 0]] * gate_node_state 540 | messages_state_vector = sender_node_state_vector[ 541 | edges[:, 0] 542 | ] * gate_state_vector + gate_edge_vector * torch.unsqueeze( 543 | edge_vector_normalised, 2 544 | ) 545 | 546 | # Sum messages 547 | message_sum_scalar = torch.zeros_like(receiver_node_state_scalar) 548 | message_sum_scalar.index_add_(0, edges[:, 1], messages_scalar) 549 | message_sum_vector = torch.zeros_like(receiver_node_state_vector) 550 | message_sum_vector.index_add_(0, edges[:, 1], messages_state_vector) 551 | 552 | # State transition 553 | update_gate_scalar, update_gate_vector = torch.split( 554 | self.update_gate_mlp(message_sum_scalar), 555 | receiver_node_state_scalar.shape[1], 556 | dim=1, 557 | ) 558 | update_gate_vector = torch.unsqueeze( 559 | update_gate_vector, 1 560 | ) # num_nodes, 1, node_size 561 | new_state_scalar = ( 562 | update_gate_scalar * receiver_node_state_scalar 563 | + (1.0 - update_gate_scalar) * message_sum_scalar 564 | ) 565 | new_state_vector = ( 566 | update_gate_vector * receiver_node_state_vector 567 | + (1.0 - update_gate_vector) * message_sum_vector 568 | ) 569 | 570 | return new_state_scalar, new_state_vector 571 | 572 | 573 | def sinc_expansion(input_x: torch.Tensor, expand_params: List[Tuple]): 574 | """ 575 | Expand each feature in a sinc-like basis function expansion. 576 | Based on [1]. 
577 | sin(n*pi*f/rcut)/f 578 | 579 | [1] arXiv:2003.03123 - Directional Message Passing for Molecular Graphs 580 | 581 | Args: 582 | input_x: (num_edges, num_features) tensor 583 | expand_params: list of None or (n, cutoff) tuples 584 | 585 | Return: 586 | (num_edges, n1+n2+...) tensor 587 | """ 588 | feat_list = torch.unbind(input_x, dim=1) 589 | expanded_list = [] 590 | for step_tuple, feat in itertools.zip_longest(expand_params, feat_list): 591 | assert feat is not None, "Too many expansion parameters given" 592 | if step_tuple: 593 | n, cutoff = step_tuple 594 | feat_expanded = torch.unsqueeze(feat, dim=1) 595 | n_range = torch.arange(n, device=input_x.device, dtype=input_x.dtype) + 1 596 | # multiplication by pi n_range / cutoff is done in original painn for some reason 597 | out = torch.sinc(n_range/cutoff*feat_expanded)*np.pi*n_range/cutoff 598 | expanded_list.append(out) 599 | else: 600 | expanded_list.append(torch.unsqueeze(feat, 1)) 601 | return torch.cat(expanded_list, dim=1) 602 | 603 | 604 | def cosine_cutoff(distance: torch.Tensor, cutoff: float): 605 | """ 606 | Calculate cutoff value based on distance. 607 | This uses the cosine Behler-Parinello cutoff function: 608 | 609 | f(d) = 0.5*(cos(pi*d/d_cut)+1) for d < d_cut and 0 otherwise 610 | """ 611 | 612 | return torch.where( 613 | distance < cutoff, 614 | 0.5 * (torch.cos(np.pi * distance / cutoff) + 1), 615 | torch.tensor(0.0, device=distance.device, dtype=distance.dtype), 616 | ) 617 | -------------------------------------------------------------------------------- /predict_with_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import argparse 5 | import math 6 | import contextlib 7 | import timeit 8 | 9 | import ase 10 | import ase.io 11 | import torch 12 | import numpy as np 13 | 14 | import dataset 15 | import densitymodel 16 | import utils 17 | 18 | def get_arguments(arg_list=None): 19 | parser = argparse.ArgumentParser( 20 | description="Predict with pretrained model", fromfile_prefix_chars="@" 21 | ) 22 | parser.add_argument("model_dir", type=str, help='Directory of pretrained model') 23 | parser.add_argument("atoms_file", type=str, help='ASE compatible atoms xyz-file') 24 | parser.add_argument("--grid_step", type=float, default=0.05, help="Step size in Ångstrøm") 25 | parser.add_argument("--vacuum", type=float, default=1.0, help="Pad simulation box with vacuum (only used when boundary conditions are not periodic)") 26 | parser.add_argument("--output_dir", type=str, default="model_prediction", help="Output directory") 27 | parser.add_argument("--iri", action="store_true", help="Also compute interaction region indicator (IRI)") 28 | parser.add_argument("--dori", action="store_true", help="Also compute density overlap region indicator (DORI)") 29 | parser.add_argument("--hessian_eig", action="store_true", help="Also compute eigenvalues of density Hessian") 30 | parser.add_argument("--probe_count", type=int, default=5000, help="How many probe points to compute per iteration") 31 | parser.add_argument( 32 | "--device", 33 | type=str, 34 | default="cuda", 35 | help="Set which device to use for inference e.g. 
'cuda' or 'cpu'", 36 | ) 37 | parser.add_argument( 38 | "--ignore_pbc", 39 | action="store_true", 40 | help="If flag is given, disable periodic boundary conditions (force to False) in atoms data", 41 | ) 42 | parser.add_argument( 43 | "--force_pbc", 44 | action="store_true", 45 | help="If flag is given, force periodic boundary conditions to True in atoms data", 46 | ) 47 | 48 | return parser.parse_args(arg_list) 49 | 50 | def load_model(model_dir, device): 51 | with open(os.path.join(model_dir, "arguments.json"), "r") as f: 52 | runner_args = argparse.Namespace(**json.load(f)) 53 | if runner_args.use_painn_model: 54 | model = densitymodel.PainnDensityModel(runner_args.num_interactions, runner_args.node_size, runner_args.cutoff) 55 | else: 56 | model = densitymodel.DensityModel(runner_args.num_interactions, runner_args.node_size, runner_args.cutoff) 57 | device = torch.device(device) 58 | model.to(device) 59 | state_dict = torch.load(os.path.join(model_dir, "best_model.pth"), map_location=device) 60 | model.load_state_dict(state_dict["model"]) 61 | return model, runner_args.cutoff 62 | 63 | class LazyMeshGrid(): 64 | def __init__(self, cell, grid_step, origin=None, adjust_grid_step=False): 65 | self.cell = cell 66 | if adjust_grid_step: 67 | n_steps = np.round(self.cell.lengths()/grid_step) 68 | self.scaled_grid_vectors = [np.arange(n)/n for n in n_steps] 69 | self.adjusted_grid_step = self.cell.lengths()/n_steps 70 | else: 71 | self.scaled_grid_vectors = [np.arange(0, l, grid_step)/l for l in self.cell.lengths()] 72 | self.shape = np.array([len(g) for g in self.scaled_grid_vectors] + [3]) 73 | if origin is None: 74 | self.origin = np.zeros(3) 75 | else: 76 | self.origin = origin 77 | 78 | self.origin = np.expand_dims(self.origin, 0) 79 | 80 | def __getitem__(self, indices): 81 | indices = np.array(indices) 82 | indices_shape = indices.shape 83 | if not (len(indices_shape) == 2 and indices_shape[0] == 3): 84 | raise NotImplementedError("Indexing must be a 3xN array-like object") 85 | gridA = self.scaled_grid_vectors[0][indices[0]] 86 | gridB = self.scaled_grid_vectors[1][indices[1]] 87 | gridC = self.scaled_grid_vectors[2][indices[2]] 88 | 89 | grid_pos = np.stack([gridA, gridB, gridC], 1) 90 | grid_pos = np.dot(grid_pos, self.cell) 91 | grid_pos += self.origin 92 | 93 | return grid_pos 94 | 95 | 96 | def ceil_float(x, step_size): 97 | # Round up to nearest step_size and subtract a small epsilon 98 | x = math.ceil(x/step_size) * step_size 99 | eps = 2*np.finfo(float).eps * x 100 | return x - eps 101 | 102 | def load_atoms(atomspath, vacuum, grid_step): 103 | atoms = ase.io.read(atomspath) 104 | 105 | if np.any(atoms.get_pbc()): 106 | atoms, grid_pos, origin = load_material(atoms, grid_step) 107 | else: 108 | atoms, grid_pos, origin = load_molecule(atoms, grid_step, vacuum) 109 | 110 | metadata = {"filename": atomspath} 111 | res = { 112 | "atoms": atoms, 113 | "origin": origin, 114 | "grid_position": grid_pos, 115 | "metadat": metadata, 116 | } 117 | 118 | return res 119 | 120 | def load_material(atoms, grid_step): 121 | atoms = atoms.copy() 122 | grid_pos = LazyMeshGrid(atoms.get_cell(), grid_step, adjust_grid_step=True) 123 | origin = np.zeros(3) 124 | 125 | return atoms, grid_pos, origin 126 | 127 | def load_molecule(atoms, grid_step, vacuum): 128 | atoms = atoms.copy() 129 | atoms.center(vacuum=vacuum) # This will create a cell around the atoms 130 | 131 | # Readjust cell lengths to be a multiple of grid_step 132 | a, b, c, ang_bc, ang_ac, ang_ab = atoms.get_cell_lengths_and_angles() 
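# The adjustment below rounds each cell length up to the nearest multiple of grid_step; ceil_float also
# subtracts a few machine epsilons so that np.arange(0, length, grid_step) inside LazyMeshGrid yields exactly
# length/grid_step points rather than one extra point caused by floating-point error. As a hypothetical
# worked example: with grid_step=0.05, a cell length of 10.702 Å is rounded up to 10.75 Å, giving 215 grid
# points along that axis.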
133 | a, b, c = ceil_float(a, grid_step), ceil_float(b, grid_step), ceil_float(c, grid_step) 134 | atoms.set_cell([a, b, c, ang_bc, ang_ac, ang_ab]) 135 | 136 | origin = np.zeros(3) 137 | 138 | grid_pos = LazyMeshGrid(atoms.get_cell(), grid_step) 139 | 140 | return atoms, grid_pos, origin 141 | def main(): 142 | args = get_arguments() 143 | 144 | # Setup logging 145 | os.makedirs(args.output_dir, exist_ok=True) 146 | logging.basicConfig( 147 | level=logging.DEBUG, 148 | format="%(asctime)s [%(levelname)-5.5s] %(message)s", 149 | handlers=[ 150 | logging.FileHandler( 151 | os.path.join(args.output_dir, "printlog.txt"), mode="w" 152 | ), 153 | logging.StreamHandler(), 154 | ], 155 | ) 156 | 157 | model, cutoff = load_model(args.model_dir, args.device) 158 | 159 | density_dict = load_atoms(args.atoms_file, args.vacuum, args.grid_step) 160 | 161 | device = torch.device(args.device) 162 | 163 | cubewriter = utils.CubeWriter( 164 | os.path.join(args.output_dir, "prediction.cube"), 165 | density_dict["atoms"], 166 | density_dict["grid_position"].shape[0:3], 167 | density_dict["origin"], 168 | "predicted by DeepDFT model", 169 | ) 170 | if args.iri: 171 | cubewriter_iri = utils.CubeWriter( 172 | os.path.join(args.output_dir, "iri.cube"), 173 | density_dict["atoms"], 174 | density_dict["grid_position"].shape[0:3], 175 | density_dict["origin"], 176 | "predicted by DeepDFT model", 177 | ) 178 | if args.dori: 179 | cubewriter_dori = utils.CubeWriter( 180 | os.path.join(args.output_dir, "dori.cube"), 181 | density_dict["atoms"], 182 | density_dict["grid_position"].shape[0:3], 183 | density_dict["origin"], 184 | "predicted by DeepDFT model", 185 | ) 186 | if args.hessian_eig: 187 | cubewriter_hessian_eig = [] 188 | for i in range(3): 189 | cubewriter_hessian_eig.append( 190 | utils.CubeWriter( 191 | os.path.join(args.output_dir, "hessian_eig_%d.cube" % i), 192 | density_dict["atoms"], 193 | density_dict["grid_position"].shape[0:3], 194 | density_dict["origin"], 195 | "predicted by DeepDFT model", 196 | ) 197 | ) 198 | 199 | if args.ignore_pbc and args.force_pbc: 200 | raise ValueError("ignore_pbc and force_pbc are mutually exclusive and can't both be set at the same time") 201 | elif args.ignore_pbc: 202 | set_pbc = False 203 | elif args.force_pbc: 204 | set_pbc = True 205 | else: 206 | set_pbc = None 207 | 208 | start_time = timeit.default_timer() 209 | 210 | if args.iri or args.dori or args.hessian_eig: 211 | contextmanager = contextlib.nullcontext() 212 | else: 213 | # No gradients needed from the model 214 | contextmanager = torch.no_grad() 215 | with contextmanager: 216 | # Make graph with no probes 217 | logging.debug("Computing atom-to-atom graph") 218 | collate_fn = dataset.CollateFuncAtoms( 219 | cutoff=cutoff, 220 | pin_memory=device.type == "cuda", 221 | set_pbc_to=set_pbc, 222 | ) 223 | graph_dict = collate_fn([density_dict]) 224 | logging.debug("Computing atom representation") 225 | device_batch = { 226 | k: v.to(device=device, non_blocking=True) for k, v in graph_dict.items() 227 | } 228 | if isinstance(model, densitymodel.PainnDensityModel): 229 | atom_representation_scalar, atom_representation_vector = model.atom_model(device_batch) 230 | else: 231 | atom_representation = model.atom_model(device_batch) 232 | logging.debug("Atom representation done") 233 | 234 | # Loop over all slices 235 | density_iter = dataset.DensityGridIterator(density_dict, args.probe_count, cutoff, set_pbc_to=set_pbc) 236 | density = [] 237 | for probe_graph_dict in density_iter: 238 | probe_dict = 
dataset.collate_list_of_dicts([probe_graph_dict]) 239 | probe_dict = { 240 | k: v.to(device=device, non_blocking=True) for k, v in probe_dict.items() 241 | } 242 | device_batch["probe_edges"] = probe_dict["probe_edges"] 243 | device_batch["probe_edges_displacement"] = probe_dict["probe_edges_displacement"] 244 | device_batch["probe_xyz"] = probe_dict["probe_xyz"] 245 | device_batch["num_probe_edges"] = probe_dict["num_probe_edges"] 246 | device_batch["num_probes"] = probe_dict["num_probes"] 247 | 248 | if isinstance(model, densitymodel.PainnDensityModel): 249 | res = model.probe_model(device_batch, atom_representation_scalar, atom_representation_vector, compute_iri=args.iri, compute_dori=args.dori, compute_hessian=args.hessian_eig) 250 | else: 251 | res = model.probe_model(device_batch, atom_representation, compute_iri=args.iri, compute_dori=args.dori, compute_hessian=args.hessian_eig) 252 | 253 | if args.iri or args.dori or args.hessian_eig: 254 | density, grad_outputs = res 255 | else: 256 | density = res 257 | if args.iri: 258 | iri = grad_outputs["iri"].cpu().detach().numpy().flatten() 259 | cubewriter_iri.write(iri) 260 | if args.dori: 261 | cubewriter_dori.write(grad_outputs["dori"].cpu().detach().numpy().flatten()) 262 | if args.hessian_eig: 263 | eigs = torch.linalg.eigvalsh(grad_outputs["hessian"]) 264 | eiglist = torch.unbind(eigs, dim=-1) 265 | for writer, val in zip(cubewriter_hessian_eig, eiglist): 266 | writer.write(val.cpu().detach().numpy().flatten()) 267 | 268 | cubewriter.write(density.cpu().detach().numpy().flatten()) 269 | logging.debug("Written %d/%d", cubewriter.numbers_written, np.prod(density_dict["grid_position"].shape[0:3])) 270 | 271 | end_time = timeit.default_timer() 272 | 273 | logging.info("done time_elapsed=%f", end_time-start_time) 274 | 275 | if __name__ == "__main__": 276 | main() 277 | -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_painn/arguments.json: -------------------------------------------------------------------------------- 1 | {"load_model": null, "cutoff": 4.0, "split_file": "/home/niflheim2/pbjo/ethylenecarbonate/splits.json", "num_interactions": 3, "node_size": 128, "output_dir": "/home/energy/pbjo/densitynet_runs/2021-06-10T18:14:58.096877", "dataset": "/home/niflheim2/pbjo/ethylenecarbonate/ethylenecarbonate.txt", "max_steps": 10000000, "device": "cuda", "use_painn_model": true, "ignore_pbc": false} -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_painn/best_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterbjorgensen/DeepDFT/a6bab4deb5cf05d9b46ae397b72253d04ea3c694/pretrained_models/ethylenecarbonate_painn/best_model.pth -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_painn/commandline_args.txt: -------------------------------------------------------------------------------- 1 | --dataset 2 | /home/niflheim2/pbjo/ethylenecarbonate/ethylenecarbonate.txt 3 | --split_file 4 | /home/niflheim2/pbjo/ethylenecarbonate/splits.json 5 | --cutoff 6 | 4 7 | --num_interactions 8 | 3 9 | --max_steps 10 | 10000000 11 | --node_size 12 | 128 13 | --use_painn_model 14 | --output_dir 15 | /home/energy/pbjo/densitynet_runs/2021-06-10T18:14:58.096877 -------------------------------------------------------------------------------- 
/pretrained_models/ethylenecarbonate_painn/gitdetails.txt: -------------------------------------------------------------------------------- 1 | 2021-06-09T17:05:43+02:00-b7b414c 2 | -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_painn/submit_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | #SBATCH --mail-type=END,FAIL 3 | #SBATCH --partition=sm3090 4 | #SBATCH -N 1 # Minimum of 1 node 5 | #SBATCH -n 8 # 7 MPI processes per node 6 | #SBATCH --time=7-00:00:00 7 | #SBATCH --mem=15G # 10 GB RAM per node 8 | #SBATCH --gres=gpu:RTX3090:1 9 | module load foss 10 | module load Python/3.8.6-GCCcore-10.2.0 11 | export MKL_NUM_THREADS=1 12 | export NUMEXPR_NUM_THREADS=1 13 | export OMP_NUM_THREADS=1 14 | export OPENBLAS_NUM_THREADS=1 15 | source ~/graphnn_env/bin/activate 16 | 17 | cd ~/densitynet_revisions/2021-06-09T17:05:43+02:00-b7b414c 18 | git fetch && git checkout b7b414c 19 | python -u runner.py --dataset /home/niflheim2/pbjo/ethylenecarbonate/ethylenecarbonate.txt --split_file /home/niflheim2/pbjo/ethylenecarbonate/splits.json --cutoff 4 --num_interactions 3 --max_steps 10000000 --node_size 128 --use_painn_model --output_dir ~/densitynet_runs/2021-06-10T18:14:58.096877 20 | -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_schnet/arguments.json: -------------------------------------------------------------------------------- 1 | {"load_model": null, "cutoff": 4.0, "split_file": "/home/niflheim2/pbjo/ethylenecarbonate/splits.json", "num_interactions": 6, "node_size": 128, "output_dir": "/home/energy/pbjo/densitynet_runs/2021-06-18T12:40:34.676146", "dataset": "/home/niflheim2/pbjo/ethylenecarbonate/ethylenecarbonate.txt", "max_steps": 10000000, "device": "cuda", "use_painn_model": false, "ignore_pbc": false} -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_schnet/best_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterbjorgensen/DeepDFT/a6bab4deb5cf05d9b46ae397b72253d04ea3c694/pretrained_models/ethylenecarbonate_schnet/best_model.pth -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_schnet/commandline_args.txt: -------------------------------------------------------------------------------- 1 | --dataset 2 | /home/niflheim2/pbjo/ethylenecarbonate/ethylenecarbonate.txt 3 | --split_file 4 | /home/niflheim2/pbjo/ethylenecarbonate/splits.json 5 | --cutoff 6 | 4 7 | --num_interactions 8 | 6 9 | --max_steps 10 | 10000000 11 | --node_size 12 | 128 13 | --output_dir 14 | /home/energy/pbjo/densitynet_runs/2021-06-18T12:40:34.676146 -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_schnet/gitdetails.txt: -------------------------------------------------------------------------------- 1 | 2021-06-16T17:34:48+02:00-996ce34 2 | -------------------------------------------------------------------------------- /pretrained_models/ethylenecarbonate_schnet/submit_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | #SBATCH --mail-type=END,FAIL 3 | #SBATCH --partition=sm3090 4 | #SBATCH -N 1 # Minimum of 1 node 5 | #SBATCH -n 8 # 7 MPI processes per node 6 | #SBATCH 
--time=7-00:00:00 7 | #SBATCH --mem=15G # 10 GB RAM per node 8 | #SBATCH --gres=gpu:RTX3090:1 9 | module load foss 10 | module load Python/3.8.6-GCCcore-10.2.0 11 | export MKL_NUM_THREADS=1 12 | export NUMEXPR_NUM_THREADS=1 13 | export OMP_NUM_THREADS=1 14 | export OPENBLAS_NUM_THREADS=1 15 | source ~/graphnn_env/bin/activate 16 | 17 | cd ~/densitynet_revisions/2021-06-16T17:34:48+02:00-996ce34 18 | git fetch && git checkout 996ce34 19 | python -u runner.py --dataset /home/niflheim2/pbjo/ethylenecarbonate/ethylenecarbonate.txt --split_file /home/niflheim2/pbjo/ethylenecarbonate/splits.json --cutoff 4 --num_interactions 6 --max_steps 10000000 --node_size 128 --output_dir ~/densitynet_runs/2021-06-18T12:40:34.676146 20 | -------------------------------------------------------------------------------- /pretrained_models/nmc_painn/arguments.json: -------------------------------------------------------------------------------- 1 | {"load_model": null, "cutoff": 4.0, "split_file": "/home/niflheim2/pbjo/nmc/split.json", "num_interactions": 3, "node_size": 128, "output_dir": "/home/energy/pbjo/densitynet_runs/2021-10-14T16:36:12.212212", "dataset": "/home/niflheim2/pbjo/nmc/cellrelax.txt", "max_steps": 10000000, "device": "cuda", "use_painn_model": true, "ignore_pbc": false} -------------------------------------------------------------------------------- /pretrained_models/nmc_painn/best_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterbjorgensen/DeepDFT/a6bab4deb5cf05d9b46ae397b72253d04ea3c694/pretrained_models/nmc_painn/best_model.pth -------------------------------------------------------------------------------- /pretrained_models/nmc_painn/commandline_args.txt: -------------------------------------------------------------------------------- 1 | --dataset 2 | /home/niflheim2/pbjo/nmc/cellrelax.txt 3 | --split_file 4 | /home/niflheim2/pbjo/nmc/split.json 5 | --cutoff 6 | 4 7 | --num_interactions 8 | 3 9 | --max_steps 10 | 10000000 11 | --node_size 12 | 128 13 | --use_painn_model 14 | --output_dir 15 | /home/energy/pbjo/densitynet_runs/2021-10-14T16:36:12.212212 -------------------------------------------------------------------------------- /pretrained_models/nmc_painn/gitdetails.txt: -------------------------------------------------------------------------------- 1 | 2021-09-08T13:40:09+02:00-3866288 2 | -------------------------------------------------------------------------------- /pretrained_models/nmc_painn/submit_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | #SBATCH --mail-type=END,FAIL 3 | #SBATCH --partition=sm3090 4 | #SBATCH -N 1 # Minimum of 1 node 5 | #SBATCH -n 8 # 7 MPI processes per node 6 | #SBATCH --time=7-00:00:00 7 | #SBATCH --mem=15G # 10 GB RAM per node 8 | #SBATCH --gres=gpu:RTX3090:1 9 | module load foss 10 | module load Python/3.8.6-GCCcore-10.2.0 11 | export MKL_NUM_THREADS=1 12 | export NUMEXPR_NUM_THREADS=1 13 | export OMP_NUM_THREADS=1 14 | export OPENBLAS_NUM_THREADS=1 15 | source ~/graphnn_env/bin/activate 16 | 17 | cd ~/densitynet_revisions/2021-09-08T13:40:09+02:00-3866288 18 | python -u runner.py --dataset /home/niflheim2/pbjo/nmc/cellrelax.txt --split_file /home/niflheim2/pbjo/nmc/split.json --cutoff 4 --num_interactions 3 --max_steps 10000000 --node_size 128 --use_painn_model --output_dir ~/densitynet_runs/2021-10-14T16:36:12.212212 19 | 
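# The commandline_args.txt stored next to each pretrained model lists one argument per line, so it can be
# replayed through argparse's fromfile mechanism (runner.py uses fromfile_prefix_chars="+"), for example
# (hypothetical invocation): python runner.py +pretrained_models/nmc_painn/commandline_args.txt
# Note that the dataset and split paths inside those files point to the original training cluster and would
# need to be adapted before re-running.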
-------------------------------------------------------------------------------- /pretrained_models/nmc_schnet/arguments.json: -------------------------------------------------------------------------------- 1 | {"load_model": null, "cutoff": 4.0, "split_file": "/home/niflheim2/pbjo/nmc/split.json", "num_interactions": 6, "node_size": 128, "output_dir": "/home/energy/pbjo/densitynet_runs/2021-10-14T16:39:56.374898", "dataset": "/home/niflheim2/pbjo/nmc/cellrelax.txt", "max_steps": 10000000, "device": "cuda", "use_painn_model": false, "ignore_pbc": false} -------------------------------------------------------------------------------- /pretrained_models/nmc_schnet/best_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterbjorgensen/DeepDFT/a6bab4deb5cf05d9b46ae397b72253d04ea3c694/pretrained_models/nmc_schnet/best_model.pth -------------------------------------------------------------------------------- /pretrained_models/nmc_schnet/commandline_args.txt: -------------------------------------------------------------------------------- 1 | --dataset 2 | /home/niflheim2/pbjo/nmc/cellrelax.txt 3 | --split_file 4 | /home/niflheim2/pbjo/nmc/split.json 5 | --cutoff 6 | 4 7 | --num_interactions 8 | 6 9 | --max_steps 10 | 10000000 11 | --node_size 12 | 128 13 | --output_dir 14 | /home/energy/pbjo/densitynet_runs/2021-10-14T16:39:56.374898 -------------------------------------------------------------------------------- /pretrained_models/nmc_schnet/gitdetails.txt: -------------------------------------------------------------------------------- 1 | 2021-09-08T13:40:09+02:00-3866288 2 | -------------------------------------------------------------------------------- /pretrained_models/nmc_schnet/submit_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | #SBATCH --mail-type=END,FAIL 3 | #SBATCH --partition=sm3090 4 | #SBATCH -N 1 # Minimum of 1 node 5 | #SBATCH -n 8 # 7 MPI processes per node 6 | #SBATCH --time=7-00:00:00 7 | #SBATCH --mem=15G # 10 GB RAM per node 8 | #SBATCH --gres=gpu:RTX3090:1 9 | module load foss 10 | module load Python/3.8.6-GCCcore-10.2.0 11 | export MKL_NUM_THREADS=1 12 | export NUMEXPR_NUM_THREADS=1 13 | export OMP_NUM_THREADS=1 14 | export OPENBLAS_NUM_THREADS=1 15 | source ~/graphnn_env/bin/activate 16 | 17 | cd ~/densitynet_revisions/2021-09-08T13:40:09+02:00-3866288 18 | python -u runner.py --dataset /home/niflheim2/pbjo/nmc/cellrelax.txt --split_file /home/niflheim2/pbjo/nmc/split.json --cutoff 4 --num_interactions 6 --max_steps 10000000 --node_size 128 --output_dir ~/densitynet_runs/2021-10-14T16:39:56.374898 19 | -------------------------------------------------------------------------------- /pretrained_models/qm9_painn/arguments.json: -------------------------------------------------------------------------------- 1 | {"load_model": null, "cutoff": 4.0, "split_file": "/home/niflheim2/pbjo/qm9vasp/datasplits.json", "num_interactions": 3, "node_size": 128, "output_dir": "/home/energy/pbjo/densitynet_runs/2021-05-18T17:52:04.873627", "dataset": "/home/niflheim2/pbjo/qm9vasp/qm9vasp.txt", "max_steps": 10000000, "device": "cuda", "use_painn_model": true, "ignore_pbc": false} -------------------------------------------------------------------------------- /pretrained_models/qm9_painn/best_model.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/peterbjorgensen/DeepDFT/a6bab4deb5cf05d9b46ae397b72253d04ea3c694/pretrained_models/qm9_painn/best_model.pth -------------------------------------------------------------------------------- /pretrained_models/qm9_painn/commandline_args.txt: -------------------------------------------------------------------------------- 1 | --dataset 2 | /home/niflheim2/pbjo/qm9vasp/qm9vasp.txt 3 | --split_file 4 | /home/niflheim2/pbjo/qm9vasp/datasplits.json 5 | --cutoff 6 | 4 7 | --num_interactions 8 | 3 9 | --max_steps 10 | 10000000 11 | --node_size 12 | 128 13 | --use_painn_model 14 | --output_dir 15 | /home/energy/pbjo/densitynet_runs/2021-05-18T17:52:04.873627 -------------------------------------------------------------------------------- /pretrained_models/qm9_painn/gitdetails.txt: -------------------------------------------------------------------------------- 1 | 2021-05-18T17:49:52+02:00-953ae0e 2 | -------------------------------------------------------------------------------- /pretrained_models/qm9_painn/submit_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | #SBATCH --mail-type=END,FAIL 3 | #SBATCH --partition=sm3090 4 | #SBATCH -N 1 # Minimum of 1 node 5 | #SBATCH -n 8 # 7 MPI processes per node 6 | #SBATCH --time=7-00:00:00 7 | #SBATCH --mem=15G # 10 GB RAM per node 8 | #SBATCH --gres=gpu:RTX3090:1 9 | module load Python 10 | module load foss 11 | export MKL_NUM_THREADS=1 12 | export NUMEXPR_NUM_THREADS=1 13 | export OMP_NUM_THREADS=1 14 | export OPENBLAS_NUM_THREADS=1 15 | source ~/graphnn_env/bin/activate 16 | 17 | cd ~/densitynet_revisions/2021-05-18T17:49:52+02:00-953ae0e 18 | git fetch && git checkout 953ae0e 19 | python -u runner.py --dataset /home/niflheim2/pbjo/qm9vasp/qm9vasp.txt --split_file /home/niflheim2/pbjo/qm9vasp/datasplits.json --cutoff 4 --num_interactions 3 --max_steps 10000000 --node_size 128 --use_painn_model --output_dir ~/densitynet_runs/2021-05-18T17:52:04.873627 20 | -------------------------------------------------------------------------------- /pretrained_models/qm9_schnet/arguments.json: -------------------------------------------------------------------------------- 1 | {"load_model": null, "cutoff": 4.0, "split_file": "/nobackup/pbjo/qm9vasp/index_splits.json", "num_interactions": 6, "node_size": 128, "output_dir": "runs/model_output", "dataset": "/nobackup/pbjo/qm9vasp/qm9vasp.txt", "max_steps": 100000000, "device": "cuda", "ignore_pbc": true, "use_painn_model": false} 2 | -------------------------------------------------------------------------------- /pretrained_models/qm9_schnet/best_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterbjorgensen/DeepDFT/a6bab4deb5cf05d9b46ae397b72253d04ea3c694/pretrained_models/qm9_schnet/best_model.pth -------------------------------------------------------------------------------- /pretrained_models/qm9_schnet/commandline_args.txt: -------------------------------------------------------------------------------- 1 | --dataset 2 | /nobackup/pbjo/qm9vasp/qm9vasp.txt 3 | --split_file 4 | /nobackup/pbjo/qm9vasp/index_splits.json 5 | --ignore_pbc 6 | --cutoff 7 | 4 8 | --num_interactions 9 | 6 10 | --max_steps 11 | 100000000 12 | --node_size 13 | 128 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | asap3 3 
| ase 4 | torch 5 | -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import math 6 | import logging 7 | import itertools 8 | import timeit 9 | 10 | import numpy as np 11 | import torch 12 | import torch.utils.data 13 | torch.set_num_threads(1) # Try to avoid thread overload on cluster 14 | 15 | import densitymodel 16 | import dataset 17 | 18 | 19 | def get_arguments(arg_list=None): 20 | parser = argparse.ArgumentParser( 21 | description="Train graph convolution network", fromfile_prefix_chars="+" 22 | ) 23 | parser.add_argument( 24 | "--load_model", 25 | type=str, 26 | default=None, 27 | help="Load model parameters from previous run", 28 | ) 29 | parser.add_argument( 30 | "--cutoff", 31 | type=float, 32 | default=5.0, 33 | help="Atomic interaction cutoff distance [Å]", 34 | ) 35 | parser.add_argument( 36 | "--split_file", 37 | type=str, 38 | default=None, 39 | help="Train/test/validation split file json", 40 | ) 41 | parser.add_argument( 42 | "--num_interactions", 43 | type=int, 44 | default=3, 45 | help="Number of interaction layers used", 46 | ) 47 | parser.add_argument( 48 | "--node_size", type=int, default=64, help="Size of hidden node states" 49 | ) 50 | parser.add_argument( 51 | "--output_dir", 52 | type=str, 53 | default="runs/model_output", 54 | help="Path to output directory", 55 | ) 56 | parser.add_argument( 57 | "--dataset", type=str, default="data/qm9.db", help="Path to ASE database", 58 | ) 59 | parser.add_argument( 60 | "--max_steps", 61 | type=int, 62 | default=int(1e6), 63 | help="Maximum number of optimisation steps", 64 | ) 65 | parser.add_argument( 66 | "--device", 67 | type=str, 68 | default="cuda", 69 | help="Set which device to use for training e.g. 
'cuda' or 'cpu'", 70 | ) 71 | 72 | parser.add_argument( 73 | "--use_painn_model", 74 | action="store_true", 75 | help="Enable equivariant message passing model (PaiNN)" 76 | ) 77 | 78 | parser.add_argument( 79 | "--ignore_pbc", 80 | action="store_true", 81 | help="If flag is given, disable periodic boundary conditions (force to False) in atoms data", 82 | ) 83 | 84 | parser.add_argument( 85 | "--force_pbc", 86 | action="store_true", 87 | help="If flag is given, force periodic boundary conditions to True in atoms data", 88 | ) 89 | 90 | return parser.parse_args(arg_list) 91 | 92 | class AverageMeter(object): 93 | """Computes and stores the average and current value""" 94 | def __init__(self, name, fmt=':f'): 95 | self.name = name 96 | self.fmt = fmt 97 | self.reset() 98 | 99 | def reset(self): 100 | self.val = 0 101 | self.avg = 0 102 | self.sum = 0 103 | self.count = 0 104 | 105 | def update(self, val, n=1): 106 | self.val = val 107 | self.sum += val * n 108 | self.count += n 109 | self.avg = self.sum / self.count 110 | 111 | def __str__(self): 112 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 113 | return fmtstr.format(**self.__dict__) 114 | 115 | def split_data(dataset, args): 116 | # Load or generate splits 117 | if args.split_file: 118 | with open(args.split_file, "r") as fp: 119 | splits = json.load(fp) 120 | else: 121 | datalen = len(dataset) 122 | num_validation = int(math.ceil(datalen * 0.05)) 123 | indices = np.random.permutation(len(dataset)) 124 | splits = { 125 | "train": indices[num_validation:].tolist(), 126 | "validation": indices[:num_validation].tolist(), 127 | } 128 | 129 | # Save split file 130 | with open(os.path.join(args.output_dir, "datasplits.json"), "w") as f: 131 | json.dump(splits, f) 132 | 133 | # Split the dataset 134 | datasplits = {} 135 | for key, indices in splits.items(): 136 | datasplits[key] = torch.utils.data.Subset(dataset, indices) 137 | return datasplits 138 | 139 | 140 | def eval_model(model, dataloader, device): 141 | with torch.no_grad(): 142 | running_ae = torch.tensor(0., device=device) 143 | running_se = torch.tensor(0., device=device) 144 | running_count = torch.tensor(0., device=device) 145 | for batch in dataloader: 146 | device_batch = { 147 | k: v.to(device=device, non_blocking=True) for k, v in batch.items() 148 | } 149 | outputs = model(device_batch) 150 | targets = device_batch["probe_target"] 151 | 152 | running_ae += torch.sum(torch.abs(targets - outputs)) 153 | running_se += torch.sum(torch.square(targets - outputs)) 154 | running_count += torch.sum(device_batch["num_probes"]) 155 | 156 | mae = (running_ae / running_count).item() 157 | rmse = (torch.sqrt(running_se / running_count)).item() 158 | 159 | return mae, rmse 160 | 161 | 162 | def get_normalization(dataset, per_atom=True): 163 | try: 164 | num_targets = len(dataset.transformer.targets) 165 | except AttributeError: 166 | num_targets = 1 167 | x_sum = torch.zeros(num_targets) 168 | x_2 = torch.zeros(num_targets) 169 | num_objects = 0 170 | for sample in dataset: 171 | x = sample["targets"] 172 | if per_atom: 173 | x = x / sample["num_nodes"] 174 | x_sum += x 175 | x_2 += x ** 2.0 176 | num_objects += 1 177 | # Var(X) = E[X^2] - E[X]^2 178 | x_mean = x_sum / num_objects 179 | x_var = x_2 / num_objects - x_mean ** 2.0 180 | 181 | return x_mean, torch.sqrt(x_var) 182 | 183 | 184 | def count_parameters(model): 185 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 186 | 187 | def main(): 188 | args = get_arguments() 189 | 190 | # Setup 
logging 191 | os.makedirs(args.output_dir, exist_ok=True) 192 | logging.basicConfig( 193 | level=logging.DEBUG, 194 | format="%(asctime)s [%(levelname)-5.5s] %(message)s", 195 | handlers=[ 196 | logging.FileHandler( 197 | os.path.join(args.output_dir, "printlog.txt"), mode="w" 198 | ), 199 | logging.StreamHandler(), 200 | ], 201 | ) 202 | 203 | # Save command line args 204 | with open(os.path.join(args.output_dir, "commandline_args.txt"), "w") as f: 205 | f.write("\n".join(sys.argv[1:])) 206 | # Save parsed command line arguments 207 | with open(os.path.join(args.output_dir, "arguments.json"), "w") as f: 208 | json.dump(vars(args), f) 209 | 210 | # Setup dataset and loader 211 | if args.dataset.endswith(".txt"): 212 | # Text file contains list of datafiles 213 | with open(args.dataset, "r") as datasetfiles: 214 | filelist = [os.path.join(os.path.dirname(args.dataset), line.strip('\n')) for line in datasetfiles] 215 | else: 216 | filelist = [args.dataset] 217 | 218 | logging.info("loading data %s", args.dataset) 219 | densitydata = torch.utils.data.ConcatDataset([dataset.DensityData(path) for path in filelist]) 220 | 221 | # Split data into train and validation sets 222 | datasplits = split_data(densitydata, args) 223 | datasplits["train"] = dataset.RotatingPoolData(datasplits["train"], 20) 224 | 225 | if args.ignore_pbc and args.force_pbc: 226 | raise ValueError("ignore_pbc and force_pbc are mutually exclusive and can't both be set at the same time") 227 | elif args.ignore_pbc: 228 | set_pbc = False 229 | elif args.force_pbc: 230 | set_pbc = True 231 | else: 232 | set_pbc = None 233 | 234 | # Setup loaders 235 | train_loader = torch.utils.data.DataLoader( 236 | datasplits["train"], 237 | 2, 238 | num_workers=4, 239 | sampler=torch.utils.data.RandomSampler(datasplits["train"]), 240 | collate_fn=dataset.CollateFuncRandomSample(args.cutoff, 1000, pin_memory=False, set_pbc_to=set_pbc), 241 | ) 242 | val_loader = torch.utils.data.DataLoader( 243 | datasplits["validation"], 244 | 2, 245 | collate_fn=dataset.CollateFuncRandomSample(args.cutoff, 5000, pin_memory=False, set_pbc_to=set_pbc), 246 | num_workers=0, 247 | ) 248 | logging.info("Preloading validation batch") 249 | val_loader = [b for b in val_loader] 250 | 251 | # Initialise model 252 | device = torch.device(args.device) 253 | if args.use_painn_model: 254 | net = densitymodel.PainnDensityModel(args.num_interactions, args.node_size, args.cutoff,) 255 | else: 256 | net = densitymodel.DensityModel(args.num_interactions, args.node_size, args.cutoff,) 257 | logging.debug("model has %d parameters", count_parameters(net)) 258 | net = net.to(device) 259 | 260 | # Setup optimizer 261 | optimizer = torch.optim.Adam(net.parameters(), lr=0.0001) 262 | criterion = torch.nn.MSELoss() 263 | scheduler_fn = lambda step: 0.96 ** (step / 100000) 264 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_fn) 265 | 266 | log_interval = 5000 267 | running_loss = torch.tensor(0.0, device=device) 268 | running_loss_count = torch.tensor(0, device=device) 269 | best_val_mae = np.inf 270 | step = 0 271 | # Restore checkpoint 272 | if args.load_model: 273 | state_dict = torch.load(args.load_model) 274 | net.load_state_dict(state_dict["model"]) 275 | step = state_dict["step"] 276 | best_val_mae = state_dict["best_val_mae"] 277 | optimizer.load_state_dict(state_dict["optimizer"]) 278 | scheduler.load_state_dict(state_dict["scheduler"]) 279 | 280 | logging.info("start training") 281 | 282 | data_timer = AverageMeter("data_timer") 283 | transfer_timer = 
AverageMeter("transfer_timer") 284 | train_timer = AverageMeter("train_timer") 285 | eval_timer = AverageMeter("eval_time") 286 | 287 | endtime = timeit.default_timer() 288 | for _ in itertools.count(): 289 | for batch_host in train_loader: 290 | data_timer.update(timeit.default_timer()-endtime) 291 | tstart = timeit.default_timer() 292 | # Transfer to 'device' 293 | batch = { 294 | k: v.to(device=device, non_blocking=True) 295 | for (k, v) in batch_host.items() 296 | } 297 | transfer_timer.update(timeit.default_timer()-tstart) 298 | 299 | tstart = timeit.default_timer() 300 | # Reset gradient 301 | optimizer.zero_grad() 302 | 303 | # Forward, backward and optimize 304 | outputs = net(batch) 305 | loss = criterion(outputs, batch["probe_target"]) 306 | loss.backward() 307 | optimizer.step() 308 | 309 | with torch.no_grad(): 310 | running_loss += loss * batch["probe_target"].shape[0] * batch["probe_target"].shape[1] 311 | running_loss_count += torch.sum(batch["num_probes"]) 312 | 313 | train_timer.update(timeit.default_timer()-tstart) 314 | 315 | # print(step, loss_value) 316 | # Validate and save model 317 | if (step % log_interval == 0) or ((step + 1) == args.max_steps): 318 | tstart = timeit.default_timer() 319 | with torch.no_grad(): 320 | train_loss = (running_loss / running_loss_count).item() 321 | running_loss = running_loss_count = 0 322 | 323 | val_mae, val_rmse = eval_model(net, val_loader, device) 324 | 325 | logging.info( 326 | "step=%d, val_mae=%g, val_rmse=%g, sqrt(train_loss)=%g", 327 | step, 328 | val_mae, 329 | val_rmse, 330 | math.sqrt(train_loss), 331 | ) 332 | 333 | # Save checkpoint 334 | if val_mae < best_val_mae: 335 | best_val_mae = val_mae 336 | torch.save( 337 | { 338 | "model": net.state_dict(), 339 | "optimizer": optimizer.state_dict(), 340 | "scheduler": scheduler.state_dict(), 341 | "step": step, 342 | "best_val_mae": best_val_mae, 343 | }, 344 | os.path.join(args.output_dir, "best_model.pth"), 345 | ) 346 | 347 | eval_timer.update(timeit.default_timer()-tstart) 348 | logging.debug( 349 | "%s %s %s %s" % (data_timer, transfer_timer, train_timer, eval_timer) 350 | ) 351 | step += 1 352 | 353 | scheduler.step() 354 | 355 | if step >= args.max_steps: 356 | logging.info("Max steps reached, exiting") 357 | sys.exit(0) 358 | 359 | endtime = timeit.default_timer() 360 | 361 | 362 | if __name__ == "__main__": 363 | main() 364 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import ase 4 | import io 5 | import zlib 6 | import tarfile 7 | 8 | class CubeWriter(): 9 | def __init__(self, filename, atoms, data_shape, origin, comment): 10 | """ 11 | Function to write a cube file. This is a copy of ase.io.cube.write_cube but supports 12 | textIO buffer 13 | 14 | filename: str object 15 | File to which output is written. 16 | atoms: Atoms object 17 | Atoms object specifying the atomic configuration. 18 | data_shape: array-like of dimension 1 19 | Shape of the data to come 20 | origin : 3-tuple 21 | Origin of the volumetric data (units: Angstrom) 22 | comment : str, optional (default = None) 23 | Comment for the first line of the cube file. 
24 | """ 25 | 26 | self.fileobj = open(filename, "w") 27 | self.data_shape = data_shape 28 | self.numbers_written = 0 29 | 30 | if comment is None: 31 | comment = 'Cube file from ASE, written on ' + time.strftime('%c') 32 | else: 33 | comment = comment.strip() 34 | self.fileobj.write(comment) 35 | 36 | self.fileobj.write('\nOUTER LOOP: X, MIDDLE LOOP: Y, INNER LOOP: Z\n') 37 | 38 | if origin is None: 39 | origin = np.zeros(3) 40 | else: 41 | origin = np.asarray(origin) / ase.units.Bohr 42 | 43 | self.fileobj.write('{0:5}{1:12.6f}{2:12.6f}{3:12.6f}\n' 44 | .format(len(atoms), *origin)) 45 | 46 | for i in range(3): 47 | n = data_shape[i] 48 | d = atoms.cell[i] / n / ase.units.Bohr 49 | self.fileobj.write('{0:5}{1:12.6f}{2:12.6f}{3:12.6f}\n'.format(n, *d)) 50 | 51 | positions = atoms.positions / ase.units.Bohr 52 | numbers = atoms.numbers 53 | for Z, (x, y, z) in zip(numbers, positions): 54 | self.fileobj.write('{0:5}{1:12.6f}{2:12.6f}{3:12.6f}{4:12.6f}\n' 55 | .format(Z, 0.0, x, y, z)) 56 | 57 | def write(self, data): 58 | for el in data: 59 | self.numbers_written += 1 60 | self.fileobj.write("%e\n" % el) 61 | 62 | if self.numbers_written >= np.prod(self.data_shape): 63 | self.fileobj.close() 64 | 65 | def write_cube(fileobj, atoms, data=None, origin=None, comment=None): 66 | """ 67 | Function to write a cube file. This is a copy of ase.io.cube.write_cube but supports 68 | textIO buffer 69 | 70 | fileobj: file object 71 | File to which output is written. 72 | atoms: Atoms object 73 | Atoms object specifying the atomic configuration. 74 | data : 3dim numpy array, optional (default = None) 75 | Array containing volumetric data as e.g. electronic density 76 | origin : 3-tuple 77 | Origin of the volumetric data (units: Angstrom) 78 | comment : str, optional (default = None) 79 | Comment for the first line of the cube file. 80 | """ 81 | 82 | if data is None: 83 | data = np.ones((2, 2, 2)) 84 | data = np.asarray(data) 85 | 86 | if data.dtype == complex: 87 | data = np.abs(data) 88 | 89 | if comment is None: 90 | comment = 'Cube file from ASE, written on ' + time.strftime('%c') 91 | else: 92 | comment = comment.strip() 93 | fileobj.write(comment) 94 | 95 | fileobj.write('\nOUTER LOOP: X, MIDDLE LOOP: Y, INNER LOOP: Z\n') 96 | 97 | if origin is None: 98 | origin = np.zeros(3) 99 | else: 100 | origin = np.asarray(origin) / ase.units.Bohr 101 | 102 | fileobj.write('{0:5}{1:12.6f}{2:12.6f}{3:12.6f}\n' 103 | .format(len(atoms), *origin)) 104 | 105 | for i in range(3): 106 | n = data.shape[i] 107 | d = atoms.cell[i] / n / ase.units.Bohr 108 | fileobj.write('{0:5}{1:12.6f}{2:12.6f}{3:12.6f}\n'.format(n, *d)) 109 | 110 | positions = atoms.positions / ase.units.Bohr 111 | numbers = atoms.numbers 112 | for Z, (x, y, z) in zip(numbers, positions): 113 | fileobj.write('{0:5}{1:12.6f}{2:12.6f}{3:12.6f}{4:12.6f}\n' 114 | .format(Z, 0.0, x, y, z)) 115 | 116 | for el in data.flat: 117 | fileobj.write("%e\n" % el) 118 | 119 | 120 | def write_cube_to_tar(tar, atoms, cubedata, origin, filename): 121 | """write_cube_to_tar 122 | Write cube file to tar archive and compress the file using zlib. 
123 | Cubedata is expected to be in electrons/A^3 and is converted to 124 | electrons/Bohr^3, which is cube file convention 125 | 126 | :param tar: 127 | :param atoms: 128 | :param cubedata: 129 | :param origin: 130 | :param filename: 131 | """ 132 | cbuf = io.StringIO() 133 | write_cube( 134 | cbuf, 135 | atoms, 136 | data=cubedata*(ase.units.Bohr**3), 137 | origin=origin, 138 | comment=filename, 139 | ) 140 | cbuf.seek(0) 141 | cube_bytes = cbuf.getvalue().encode() 142 | cbytes = zlib.compress(cube_bytes) 143 | fsize = len(cbytes) 144 | cbuf = io.BytesIO(cbytes) 145 | cbuf.seek(0) 146 | tarinfo = tarfile.TarInfo(name=filename) 147 | tarinfo.size = fsize 148 | tar.addfile(tarinfo, cbuf) 149 | --------------------------------------------------------------------------------
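A minimal usage sketch of utils.CubeWriter, assuming a toy H2O molecule, a 10x10x10 grid and the output file name "toy.cube" as placeholder values: it streams density values to a .cube file in chunks, the same way predict_with_model.py writes its predictions.

    import numpy as np
    import ase.build
    from utils import CubeWriter

    # Toy structure and grid; in predict_with_model.py these come from load_atoms()
    atoms = ase.build.molecule("H2O")
    atoms.center(vacuum=2.0)  # give the molecule a cell so the cube header has cell vectors
    shape = (10, 10, 10)
    writer = CubeWriter("toy.cube", atoms, shape, np.zeros(3), "toy density example")

    # Stand-in for predicted density values, written in chunks like the probe batches
    density = np.random.rand(*shape).flatten()
    for chunk in np.array_split(density, 4):
        writer.write(chunk)  # the writer closes the file once prod(shape) values have been written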