├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── requirements.txt
├── scripts
│   └── light_train.py
├── setup.py
└── torchmdnet
    ├── __init__.py
    └── nnp
        ├── __init__.py
        ├── calculators
        │   ├── __init__.py
        │   └── torchmdcalc.py
        ├── model.py
        ├── npdataset.py
        ├── schnet_dataset.py
        └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 compscience.org

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include requirements.txt README.md LICENSE
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
ifndef version
$(error version variable is not set. Call with `make release version=XXX`)
endif

release:
	git checkout master
	git fetch
	git pull
	git tag -a $(version) -m "$(version) release"
	git push --tags origin $(version)
--------------------------------------------------------------------------------
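Releases are tagged through the Makefile above; a usage sketch with an illustrative version string:

    make release version=0.8.0

The target tags master and pushes the tag, which setup.py later reads back through `git describe` to set the package version.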
/README.md:
--------------------------------------------------------------------------------
# torchmd-net-legacy

This repository has been deprecated. Use the new repository instead:
https://github.com/torchmd/torchmd-net
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torchmd
schnetpack
--------------------------------------------------------------------------------
/scripts/light_train.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import time

import torch
from torch.nn import MSELoss, L1Loss
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau

import schnetpack as spk
import schnetpack.atomistic as atm
import schnetpack.representation as rep
from schnetpack.nn.cutoff import CosineCutoff
from schnetpack.data.loader import _collate_aseatoms
from schnetpack.environment import SimpleEnvironmentProvider

from torchmdnet.nnp.schnet_dataset import SchNetDataset
from torchmdnet.nnp.utils import LoadFromFile, LogWriter
from torchmdnet.nnp.utils import save_argparse
from torchmdnet.nnp.utils import train_val_test_split, set_batch_size
from torchmdnet.nnp.npdataset import NpysDataset, NpysDataset2
from torchmdnet.nnp.model import make_schnet_model

import argparse

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor


def get_args():
    # fmt: off
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--conf', '-c', type=open, action=LoadFromFile)  # keep first
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--batch-size', default=32, type=int, help='batch size')
    parser.add_argument('--num-epochs', default=300, type=int, help='number of epochs')
    parser.add_argument('--order', default=None, help='Npy file with order on which to split idx_train,idx_val,idx_test')
    parser.add_argument('--coords', default='coords.npy', help='Data source')
    parser.add_argument('--forces', default='forces.npy', help='Data source')
    parser.add_argument('--embeddings', default='embeddings.npy', help='Data source')
    parser.add_argument('--weights', default=None, help='Data source')
    parser.add_argument('--splits', default=None, help='Npz with splits idx_train,idx_val,idx_test')
    parser.add_argument('--gpus', default=0, help='Number of GPUs. Use CUDA_VISIBLE_DEVICES=1,2 to select specific GPUs')
    parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
    parser.add_argument('--log-dir', '-l', default='/tmp/net', help='log file')
    parser.add_argument('--label', default='forces', help='Label')
    parser.add_argument('--derivative', default='forces', help='Name of the derivative property (e.g. forces)')
    parser.add_argument('--eval-interval', type=int, default=2, help='eval interval, one eval per n updates (default: 2)')
    parser.add_argument('--save-interval', type=int, default=10, help='save interval, one save per n updates (default: 10)')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--load-model', default=None, help='Restart training using a model checkpoint')
    parser.add_argument('--progress', action='store_true', default=False, help='Progress bar during batching')
    parser.add_argument('--val-ratio', type=float, default=0.05, help='Fraction of the dataset used for validation')
    parser.add_argument('--test-ratio', type=float, default=0, help='Fraction of the dataset used for testing')
    parser.add_argument('--num-workers', type=int, default=0, help='Number of workers for data prefetch')
    parser.add_argument('--num-filters', type=int, default=128, help='Number of filters in model')
    parser.add_argument('--num-gaussians', type=int, default=50, help='Number of Gaussians in model')
    parser.add_argument('--num-interactions', type=int, default=2, help='Number of interactions in model')
    parser.add_argument('--max-z', type=int, default=100, help='Max atomic number in model')
    parser.add_argument('--cutoff', type=float, default=9, help='Cutoff in model')
    parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
    parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
    parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which the learning rate is reduced on plateau')
    parser.add_argument('--distributed-backend', default='ddp', help='Distributed backend: dp, ddp, ddp2')
    # fmt: on
    args = parser.parse_args()

    if args.val_ratio == 0:
        args.eval_interval = 0

    save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])

    return args


def make_splits(
    dataset_len, val_ratio, test_ratio, seed, filename=None, splits=None, order=None
):
    if splits is not None:
        splits = np.load(splits)
        idx_train = splits["idx_train"]
        idx_val = splits["idx_val"]
        idx_test = splits["idx_test"]
    else:
        idx_train, idx_val, idx_test = train_val_test_split(
            dataset_len, val_ratio, test_ratio, seed, order
        )

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val, idx_test=idx_test)

    return idx_train, idx_val, idx_test


class LNNP(pl.LightningModule):
    def __init__(self, hparams):
        super(LNNP, self).__init__()
        self.hparams = hparams
        if self.hparams.load_model:
            raise NotImplementedError  # TODO
        else:
            self.model = make_schnet_model(self.hparams)
        # save linear fit model with random parameters
        self.loss_fn = MSELoss()
        self.test_fn = L1Loss()

    def prepare_data(self):
        print("Preparing data...", flush=True)
        self.dataset = NpysDataset2(
            self.hparams.coords, self.hparams.forces, self.hparams.embeddings
        )
        self.dataset = SchNetDataset(
            self.dataset,
            environment_provider=SimpleEnvironmentProvider(),
            label=["forces"],
        )
        self.idx_train, self.idx_val, self.idx_test = make_splits(
            len(self.dataset),
            self.hparams.val_ratio,
            self.hparams.test_ratio,
            self.hparams.seed,
            os.path.join(self.hparams.log_dir, "splits.npz"),
            self.hparams.splits,
        )
        self.train_dataset = torch.utils.data.Subset(self.dataset, self.idx_train)
        self.val_dataset = torch.utils.data.Subset(self.dataset, self.idx_val)
        self.test_dataset = torch.utils.data.Subset(self.dataset, self.idx_test)
        print(
            "train {}, val {}, test {}".format(
                len(self.train_dataset), len(self.val_dataset), len(self.test_dataset)
            )
        )

        if self.hparams.weights is not None:
            self.weights = torch.from_numpy(np.load(self.hparams.weights))
        else:
            self.weights = torch.ones(len(self.dataset))

    def forward(self, x):
        return self.model(x)

    def train_dataloader(self):
        train_loader = DataLoader(
            self.train_dataset,
            sampler=WeightedRandomSampler(
                self.weights[self.idx_train], len(self.train_dataset)
            ),
            batch_size=set_batch_size(self.hparams.batch_size, len(self.train_dataset)),
            shuffle=False,
            collate_fn=_collate_aseatoms,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )
        return train_loader

    def training_step(self, batch, batch_idx):
        prediction = self(batch)
        loss = self.loss_fn(prediction[self.hparams.label], batch[self.hparams.label])
        self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        return loss

    def val_dataloader(self):
        val_loader = None
        if len(self.val_dataset) > 0:
            # val_loader = DataLoader(self.val_dataset, sampler=WeightedRandomSampler(self.weights[self.idx_val], len(self.val_dataset)),
            val_loader = DataLoader(
                self.val_dataset,
                batch_size=set_batch_size(
                    self.hparams.batch_size, len(self.val_dataset)
                ),
                collate_fn=_collate_aseatoms,
                num_workers=self.hparams.num_workers,
                pin_memory=True,
            )
        return val_loader

    def validation_step(self, batch, batch_idx):
        # gradients must be enabled even during evaluation: forces are
        # computed as the derivative of the energy w.r.t. the positions
        torch.set_grad_enabled(True)
        prediction = self(batch)
        torch.set_grad_enabled(False)
        loss = self.loss_fn(prediction[self.hparams.label], batch[self.hparams.label])
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack(validation_step_outputs).mean()
        self.log("val_loss", avg_loss)

    def test_dataloader(self):
        test_loader = None
        if len(self.test_dataset) > 0:
            # test_loader = DataLoader(self.test_dataset, sampler=WeightedRandomSampler(self.weights[self.idx_test], len(self.test_dataset)),
            test_loader = DataLoader(
                self.test_dataset,
                batch_size=set_batch_size(
                    self.hparams.batch_size, len(self.test_dataset)
                ),
                collate_fn=_collate_aseatoms,
                num_workers=self.hparams.num_workers,
                pin_memory=True,
            )
        return test_loader

    def test_step(self, batch, batch_idx):
        torch.set_grad_enabled(True)
        prediction = self(batch)
        torch.set_grad_enabled(False)
        loss = self.test_fn(prediction[self.hparams.label], batch[self.hparams.label])
        return loss

    def test_epoch_end(self, test_step_outputs):
        avg_loss = torch.stack(test_step_outputs).mean()
        self.log("test_loss", avg_loss)
    def configure_optimizers(self):
        # optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr)
        scheduler = ReduceLROnPlateau(
            optimizer,
            "min",
            factor=self.hparams.lr_factor,
            patience=self.hparams.lr_patience,
            min_lr=self.hparams.lr_min
        )
        lr_scheduler = {'scheduler': scheduler,
                        'monitor': 'val_loss',
                        'interval': 'epoch',
                        'frequency': 1,
                        }
        return [optimizer], [lr_scheduler]


def main():
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    args = get_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    model = LNNP(args)
    checkpoint_callback = ModelCheckpoint(
        filepath=args.log_dir,
        monitor="val_loss",
        save_top_k=8,
        period=args.eval_interval,
    )
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    tb_logger = pl.loggers.TensorBoardLogger(args.log_dir)
    trainer = pl.Trainer(
        gpus=args.gpus,
        max_epochs=args.num_epochs,
        distributed_backend=args.distributed_backend,
        num_nodes=args.num_nodes,
        default_root_dir=args.log_dir,
        auto_lr_find=False,
        resume_from_checkpoint=args.load_model,
        checkpoint_callback=checkpoint_callback,
        callbacks=[lr_monitor],
        logger=tb_logger,
        reload_dataloaders_every_epoch=False
    )

    trainer.fit(model)

    # run test set after completing the fit
    trainer.test()

    # logs = LogWriter(args.log_dir, keys=('epoch', 'train_loss', 'val_loss', 'test_mae', 'lr'))

    # logs.write_row({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss,
    #                 'test_mae': test_mae, 'lr': optimizer.param_groups[0]['lr']})
    # progress.set_postfix({'Loss': train_loss, 'lr': optimizer.param_groups[0]['lr']})

    # if optimizer.param_groups[0]['lr'] < args.lr_min:
    #     print("Early stop reached")
    #     break


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
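For reference, a sketch of how the training script is typically driven; the file names and hyperparameter values are illustrative, not part of the repository. `LoadFromFile` (see torchmdnet/nnp/utils.py) merges the YAML keys straight into the argparse namespace, so keys must match the argument destinations (underscores instead of dashes):

    $ python scripts/light_train.py --conf train.yaml --gpus 1

    # train.yaml (hypothetical values)
    lr: 0.0001
    batch_size: 64
    num_epochs: 100
    coords: data/coords.npy
    forces: data/forces.npy
    embeddings: data/embeddings.npy
    log_dir: logs/run1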
/setup.py:
--------------------------------------------------------------------------------
import setuptools
import subprocess
import os

try:
    version = (
        subprocess.check_output(["git", "describe", "--abbrev=0", "--tags"])
        .strip()
        .decode("utf-8")
    )
except Exception as e:
    print("Could not get version tag. Defaulting to version 0")
    version = "0"

with open("requirements.txt") as f:
    requirements = f.read().splitlines()

if __name__ == "__main__":
    with open("README.md", "r") as fh:
        long_description = fh.read()

    setuptools.setup(
        name="torchmdnet",
        version=version,
        author="Acellera",
        author_email="info@acellera.com",
        description="TorchMD-net. Training SchNet with PyTorch Lightning",
        long_description=long_description,
        long_description_content_type="text/markdown",
        url="https://github.com/torchmd/torchmd-net/",
        classifiers=[
            "Programming Language :: Python :: 3",
            "Operating System :: POSIX :: Linux",
            "License :: OSI Approved :: MIT License",
        ],
        packages=setuptools.find_packages(include=["torchmdnet*"], exclude=[]),
        # package_data={"torchmdnet": ["config.ini", "logging.ini"],},
        install_requires=requirements,
    )
--------------------------------------------------------------------------------
/torchmdnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torchmd/torchmd-net-legacy/322a0bc23eab22e53279cbede69d9c38dc9c5dac/torchmdnet/__init__.py
--------------------------------------------------------------------------------
/torchmdnet/nnp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torchmd/torchmd-net-legacy/322a0bc23eab22e53279cbede69d9c38dc9c5dac/torchmdnet/nnp/__init__.py
--------------------------------------------------------------------------------
/torchmdnet/nnp/calculators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torchmd/torchmd-net-legacy/322a0bc23eab22e53279cbede69d9c38dc9c5dac/torchmdnet/nnp/calculators/__init__.py
--------------------------------------------------------------------------------
/torchmdnet/nnp/calculators/torchmdcalc.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.utils.data import Dataset
from torchmdnet.nnp.schnet_dataset import SchNetDataset
from schnetpack.environment import SimpleEnvironmentProvider
from schnetpack.data.loader import _collate_aseatoms
from torch.utils.data import DataLoader
from torchmdnet.nnp.model import make_schnet_model, load_schnet_model
from schnetpack import Properties


class External:
    def __init__(self, netfile, embeddings, device="cpu"):
        self.model = load_schnet_model(
            netfile, device=device, derivative="forces", label="energy"
        )
        self.model.to(device)
        self.device = device
        self.embeddings = embeddings.to(device)

        nreplicas = self.embeddings.shape[0]
        natoms = self.embeddings.shape[1]

        self.cell_offset = torch.zeros(
            [nreplicas, natoms, natoms - 1, 3], dtype=torch.float32
        ).to(device)

        # All vs all neighbors
        self.neighbors = torch.zeros(
            (nreplicas, natoms, natoms - 1), dtype=torch.int64
        ).to(device)
        for i in range(natoms):
            self.neighbors[:, i, :i] = torch.arange(0, i, dtype=torch.int64)
            self.neighbors[:, i, i:] = torch.arange(i + 1, natoms, dtype=torch.int64)

        self.neighbor_mask = torch.ones(
            (nreplicas, natoms, natoms - 1), dtype=torch.float32
        ).to(device)
        self.atom_mask = torch.ones((nreplicas, natoms), dtype=torch.float32).to(device)

        self.model.eval()

    def calculate(self, pos, box):
        assert pos.ndim == 3
        assert box.ndim == 3

        pos = pos.to(self.device).type(torch.float32)
        box = box.to(self.device).type(torch.float32)
        batch = {
            Properties.R: pos,
            Properties.cell: box,
            Properties.Z: self.embeddings,
            Properties.cell_offset: self.cell_offset,
            Properties.neighbors: self.neighbors,
            Properties.neighbor_mask: self.neighbor_mask,
            Properties.atom_mask: self.atom_mask,
        }
        pred = self.model(batch)
        return pred["energy"].detach(), pred["forces"].detach()


if __name__ == "__main__":
    mydevice = "cuda"
    coords = np.array(
        [
            [-6.878, -0.708, 2.896],
            [-4.189, -0.302, 0.213],
            [-1.287, 1.320, 2.084],
            [0.579, 3.407, -0.513],
            [3.531, 3.694, 1.893],
            [4.684, 0.239, 0.748],
            [2.498, -0.018, -2.375],
            [0.411, -3.025, -1.274],
            [-2.598, -4.011, 0.868],
            [-1.229, -3.774, 4.431],
        ],
        dtype=np.float32,
    ).reshape(1, -1, 3)
    coords = np.repeat(coords, 2, axis=0)
    # calculate() expects one (3, 3) cell per replica, so build diagonal cells
    box = np.diag(np.array([56.3, 48.7, 24.2], dtype=np.float32)).reshape(1, 3, 3)
    box = np.repeat(box, 2, axis=0)
    z = np.load("../../tests/data/chignolin_aa.npy")
    z = z[:, 1].astype(np.int64)
    # External expects torch tensors with a leading replica dimension
    z = torch.from_numpy(z).unsqueeze(0).repeat(2, 1)
    ext = External("../../tests/data/model.ckp.30", z, mydevice)
    Epot, f = ext.calculate(torch.from_numpy(coords), torch.from_numpy(box))
    print(Epot)
    print(f)
--------------------------------------------------------------------------------
/torchmdnet/nnp/model.py:
--------------------------------------------------------------------------------
import torch
import schnetpack as spk
import schnetpack.atomistic as atm
import schnetpack.representation as rep
from schnetpack.nn.cutoff import CosineCutoff
from schnetpack.data.loader import _collate_aseatoms
from schnetpack.environment import SimpleEnvironmentProvider
from argparse import Namespace


def make_schnet_model(args):
    label = args.label
    negative_dr = args.derivative is not None
    atomrefs = None
    # if hasattr(args, 'atomrefs') and args.atomrefs is not None:
    #     atomrefs = self.read_atomrefs(args.atomrefs, args.max_z)
    reps = rep.SchNet(n_atom_basis=args.num_filters, n_filters=args.num_filters,
                      n_interactions=args.num_interactions, cutoff=args.cutoff,
                      n_gaussians=args.num_gaussians, max_z=args.max_z, cutoff_network=CosineCutoff)
    output = spk.atomistic.Atomwise(n_in=reps.n_atom_basis, aggregation_mode='sum',
                                    property=label, derivative=args.derivative, negative_dr=negative_dr,
                                    mean=None, stddev=None, atomref=atomrefs)
    model = atm.AtomisticModel(reps, output)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Number of trainable parameters {}'.format(total_params))
    return model


def load_schnet_model(model_file, **kargs):
    ckp = torch.load(model_file, map_location=kargs['device'])
    if 'info' in ckp.keys():
        args = ckp['info']
    else:
        if 'hparams' in ckp.keys():
            args = Namespace(**ckp['hparams'])
        elif 'hyper_parameters' in ckp.keys():
            args = Namespace(**ckp['hyper_parameters'])

        # Lightning checkpoints prefix parameter names with "model."
        new_state_dict = {k.replace('model.', ''): ckp['state_dict'][k] for k in ckp['state_dict'].keys()}
        ckp['state_dict'] = new_state_dict

    if not hasattr(args, 'trainable_gaussians'):
        args.trainable_gaussians = False

    # override
    for k in kargs.keys():
        setattr(args, k, kargs[k])

    model = make_schnet_model(args)
    model.load_state_dict(ckp['state_dict'])
    return model
--------------------------------------------------------------------------------
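For orientation, a minimal sketch of building the network through make_schnet_model alone; the Namespace fields mirror the argparse defaults of scripts/light_train.py and the values are illustrative:

    from argparse import Namespace
    from torchmdnet.nnp.model import make_schnet_model

    # fields mirror the defaults in scripts/light_train.py (illustrative values)
    args = Namespace(label='forces', derivative='forces', num_filters=128,
                     num_interactions=2, num_gaussians=50, max_z=100, cutoff=9.0)
    model = make_schnet_model(args)  # prints the number of trainable parameters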
/torchmdnet/nnp/npdataset.py:
--------------------------------------------------------------------------------
import torch
import glob
import numpy as np
from torch.utils.data import Dataset


class NPDataset(Dataset):
    def __init__(self, coordfile, forcefile, atomtypefile):
        coo = np.load(coordfile)
        f = np.load(forcefile)
        at = np.load(atomtypefile)
        self.coo = torch.from_numpy(coo)
        self.f = torch.from_numpy(f)
        z = at[:, 1].astype(np.int64)
        self.z = torch.from_numpy(z)  # get only the numbers
        assert self.coo.shape == self.f.shape
        assert self.coo.shape[1] == self.z.shape[0]

    def __len__(self):
        return self.coo.shape[0]

    def __getitem__(self, idx):
        return {'coordinates': self.coo[idx], 'forces': self.f[idx], 'atomic_numbers': self.z}


class NpysDataset(Dataset):
    def __init__(self, coordfiles, forcefiles, embedfiles):
        coordfiles = sorted(glob.glob(coordfiles))
        forcefiles = sorted(glob.glob(forcefiles))
        embedfiles = sorted(glob.glob(embedfiles))
        self.coords = []
        self.forces = []
        self.embeddings = []
        self.index = []
        assert len(coordfiles) == len(forcefiles) == len(embedfiles)
        print("Coordinates ", coordfiles)
        print("Forces ", forcefiles)
        print("Embeddings ", embedfiles)
        nfiles = len(coordfiles)
        for i in range(nfiles):
            cdata = torch.tensor(np.load(coordfiles[i]))
            self.coords.append(cdata)
            fdata = torch.tensor(np.load(forcefiles[i]))
            self.forces.append(fdata)
            edata = torch.tensor(np.load(embedfiles[i]).astype(np.int64))
            self.embeddings.append(edata)
            size = cdata.shape[0]
            self.index.extend(list(zip(size * [i], range(size))))
            assert cdata.shape == fdata.shape, "{} {}".format(cdata.shape, fdata.shape)
            assert cdata.shape[1] == edata.shape[0]
        print("Combined dataset size {}".format(len(self.index)))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        fileid, index = self.index[idx]

        return {'coordinates': self.coords[fileid][index],
                'forces': self.forces[fileid][index],
                'atomic_numbers': self.embeddings[fileid]}


class NpysDataset2(Dataset):
    def __init__(self, coordglob, forceglob, embedglob):
        self.coordfiles = sorted(glob.glob(coordglob))
        self.forcefiles = sorted(glob.glob(forceglob))
        self.embedfiles = sorted(glob.glob(embedglob))
        self.index = []
        assert len(self.coordfiles) == len(self.forcefiles) == len(self.embedfiles)
        print("Coordinates files: ", len(self.coordfiles))
        print("Forces files: ", len(self.forcefiles))
        print("Embeddings files: ", len(self.embedfiles))
        # make index
        nfiles = len(self.coordfiles)
        for i in range(nfiles):
            cdata = np.load(self.coordfiles[i])
            fdata = np.load(self.forcefiles[i])
            edata = np.load(self.embedfiles[i]).astype(np.int64)
            size = cdata.shape[0]
            self.index.extend(list(zip(size * [i], range(size))))
            # consistency check
            assert cdata.shape == fdata.shape, "{} {}".format(cdata.shape, fdata.shape)
            assert cdata.shape[1] == edata.shape[0]
        print("Combined dataset size {}".format(len(self.index)))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        fileid, index = self.index[idx]

        cdata = np.load(self.coordfiles[fileid], mmap_mode='r')  # only one element is needed
        fdata = np.load(self.forcefiles[fileid], mmap_mode='r')
        edata = np.load(self.embedfiles[fileid]).astype(np.int64)
        return {'coordinates': np.array(cdata[index]),
                'forces': np.array(fdata[index]),
                'atomic_numbers': edata}
--------------------------------------------------------------------------------
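A sketch of the on-disk layout these datasets expect, with illustrative shapes and file names; the asserts above require coordinates and forces to share the shape (nframes, natoms, 3), and the embeddings to hold one atomic number per atom:

    import numpy as np
    from torchmdnet.nnp.npdataset import NpysDataset2

    # hypothetical data: 1000 frames of a 10-atom molecule
    np.save('coords.npy', np.zeros((1000, 10, 3), dtype=np.float32))
    np.save('forces.npy', np.zeros((1000, 10, 3), dtype=np.float32))
    np.save('embeddings.npy', np.full(10, 6, dtype=np.int64))  # e.g. all carbon

    dset = NpysDataset2('coords.npy', 'forces.npy', 'embeddings.npy')
    sample = dset[0]  # {'coordinates': ..., 'forces': ..., 'atomic_numbers': ...}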
/torchmdnet/nnp/schnet_dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import os
from torch.utils.data import Dataset
from schnetpack import Properties
from torch.utils.data.sampler import RandomSampler


class SchNetDataset(Dataset):
    def __init__(self, dataset, environment_provider, label=['energy']):
        self.dataset = dataset
        self.environment_provider = environment_provider
        # wrap a bare string in a list instead of splitting it into characters
        label = [label] if isinstance(label, str) else label
        self.label = label

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data = self.dataset[index]
        atomic_numbers = data['atomic_numbers']
        positions = data['coordinates']

        properties = {}
        properties["_idx"] = torch.LongTensor([index])
        properties[Properties.Z] = torch.LongTensor(atomic_numbers)
        properties[Properties.R] = torch.FloatTensor(positions)
        cell = torch.zeros((3, 3), dtype=torch.float32)
        if 'box' in data:
            cell[torch.eye(3).bool()] = torch.FloatTensor(data['box'])
        properties[Properties.cell] = cell

        for l in self.label:
            properties[l] = torch.FloatTensor(data[l])

        at = FakeASE(atomic_numbers, positions, cell.numpy(), pbc=False)
        nbh_idx, offsets = self.environment_provider.get_environment(at)
        properties[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(np.int64))
        properties[Properties.cell_offset] = torch.FloatTensor(offsets.astype(np.float32))
        return properties

    # def get_label(self, label_name):
    #     labels = []
    #     natoms = []
    #     for i in range(len(self.dataset)):
    #         dset = self.dataset[i]
    #         label = dset[label_name]
    #         n_atoms = len(dset['atomic_numbers'])
    #         labels.append(label)
    #         natoms.append(n_atoms)
    #     return np.array(labels), np.array(natoms)

    # def calc_stats(self, label_name, per_atom=True):  # this is wrong, it should consider atomrefs!!!!
    #     labels, natoms = self.get_label(label_name)
    #     if per_atom:
    #         return np.mean(labels / natoms, keepdims=True, dtype='float32'), \
    #                np.std(labels / natoms, keepdims=True, dtype='float32')
    #     else:
    #         return np.mean(labels, keepdims=True, dtype='float32'), \
    #                np.std(labels, keepdims=True, dtype='float32')


class FakeASE:
    def __init__(self, numbers, positions, cell, pbc):
        self.numbers = numbers
        self.positions = positions
        self.cell = cell
        self.pbc = np.array([pbc, pbc, pbc])

    def get_number_of_atoms(self):
        return len(self.numbers)

    def get_cell(self, complete):
        return self.cell


from collections import OrderedDict


class CachedDataset(Dataset):  # TODO: UNTESTED
    def __init__(self, dataset, cache_size=200000):
        self.dataset = dataset
        self.cache_size = cache_size
        self.cache = OrderedDict()

    def __getitem__(self, idx):
        if self.cache.get(str(idx)) is not None:
            return self.cache.get(str(idx))
        else:
            if len(self.cache.keys()) >= self.cache_size:
                self.cache.popitem(last=False)  # remove the oldest entry
            self.cache[str(idx)] = self.dataset[idx]
            return self.cache[str(idx)]

    def __len__(self):
        return len(self.dataset)

    def refresh(self):
        self.cache = OrderedDict()
--------------------------------------------------------------------------------
/torchmdnet/nnp/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 30 09:45:30 2019

@author: gianni
"""
import glob
import csv
import json
import os
import time
import argparse
import torch
import yaml
import sys
import numpy as np


def update_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class LogWriter(object):
    # kind of inspired by openai.baselines.bench.monitor
    # We could add an optional TensorBoard logger here as well
    def __init__(self, path, keys, header=''):
        self.keys = tuple(keys) + ('t',)
        assert path is not None
        self._clean_log_dir(path)
        filename = os.path.join(path, 'monitor.csv')

        self.f = open(filename, "wt")
        if isinstance(header, dict):
            header = '# {} \n'.format(json.dumps(header))
        self.f.write(header)
        self.logger = csv.DictWriter(self.f, fieldnames=self.keys)
        self.logger.writeheader()
        self.f.flush()
        self.tstart = time.time()

    def write_row(self, epinfo):
        if self.logger:
            t = time.time() - self.tstart
            epinfo['t'] = t
            self.logger.writerow(epinfo)
            self.f.flush()

    def _clean_log_dir(self, log_dir):
        try:
            os.makedirs(log_dir)
        except OSError:
            files = glob.glob(os.path.join(log_dir, '*.csv'))
            for f in files:
                os.remove(f)


class LoadFromFile(argparse.Action):
    # parser.add_argument('--file', type=open, action=LoadFromFile)
    def __call__(self, parser, namespace, values, option_string=None):
        if values.name.endswith("yaml") or values.name.endswith("yml"):
            with values as f:
                namespace.__dict__.update(yaml.load(f, Loader=yaml.FullLoader))
        else:
            raise ValueError("configuration file must end with yaml or yml")
def save_argparse(args, filename, exclude=None):
    if filename.endswith('yaml') or filename.endswith('yml'):
        if isinstance(exclude, str):
            exclude = [exclude, ]
        args = args.__dict__.copy()
        for exl in exclude:
            del args[exl]
        yaml.dump(args, open(filename, 'w'))
    else:
        raise ValueError("Configuration file should end with yaml or yml")


def group_weight(module, weight_decay):
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        try:
            group_decay.append(m.weight)
        except AttributeError:
            pass
        try:
            if m.bias is not None:
                group_no_decay.append(m.bias)
        except AttributeError:
            pass

    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
    groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=weight_decay)]
    return groups


from sklearn.model_selection import train_test_split


def train_val_test_split(dset_len, val_ratio, test_ratio, seed, order=None):
    shuffle = True if order is None else False
    valtest_ratio = val_ratio + test_ratio
    idx_train = list(range(dset_len))
    idx_test = []
    idx_val = []
    if valtest_ratio > 0 and dset_len > 0:
        idx_train, idx_tmp = train_test_split(range(dset_len), test_size=valtest_ratio, random_state=seed, shuffle=shuffle)
        if test_ratio == 0:
            idx_val = idx_tmp
        elif val_ratio == 0:
            idx_test = idx_tmp
        else:
            test_val_ratio = test_ratio / (test_ratio + val_ratio)
            idx_val, idx_test = train_test_split(idx_tmp, test_size=test_val_ratio, random_state=seed, shuffle=shuffle)

    if order is not None:
        idx_train = [order[i] for i in idx_train]
        idx_val = [order[i] for i in idx_val]
        idx_test = [order[i] for i in idx_test]

    return np.array(idx_train), np.array(idx_val), np.array(idx_test)


def set_batch_size(max_batch_size, len_dataset):
    batch_size = min(int(max_batch_size * len_dataset / 32768) + 4, max_batch_size)  # minimum size is 4
    if batch_size != max_batch_size:
        print('Warning: Dataset length {}. Reducing batch_size to {}'.format(len_dataset, batch_size))
    return batch_size


if __name__ == '__main__':
    order = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    idx_train, idx_val, idx_test = train_val_test_split(10, 0.1, 0.2, 0, order)
    print(idx_train)
    print(idx_val)
    print(idx_test)
--------------------------------------------------------------------------------
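A short sketch of the LogWriter API defined above; the directory and keys are illustrative. Every row written gains an extra 't' column holding the seconds elapsed since the writer was created:

    from torchmdnet.nnp.utils import LogWriter

    logs = LogWriter('/tmp/net', keys=('epoch', 'train_loss', 'val_loss', 'lr'))
    logs.write_row({'epoch': 0, 'train_loss': 0.52, 'val_loss': 0.61, 'lr': 1e-4})
    # writes to /tmp/net/monitor.csv; note that existing *.csv files
    # in the directory are removed when the writer is constructed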