├── .github └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── LICENSE ├── README.md ├── cgnet ├── __init__.py ├── feature │ ├── __init__.py │ ├── combiner.py │ ├── dataset.py │ ├── feature.py │ ├── geometry.py │ ├── schnet_utils.py │ ├── statistics.py │ └── utils.py ├── molecule │ ├── __init__.py │ ├── aminoacids.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_aminoacids.py │ │ └── test_trajectory.py │ └── trajectory.py ├── network │ ├── __init__.py │ ├── nnet.py │ ├── priors.py │ ├── simulation.py │ └── utils.py └── tests │ ├── __init__.py │ ├── test_divergences.py │ ├── test_feature_combiner.py │ ├── test_feature_utils.py │ ├── test_geometry_core.py │ ├── test_geometry_features.py │ ├── test_geometry_statistics.py │ ├── test_gpu.py │ ├── test_molecule_dataset.py │ ├── test_nnet.py │ ├── test_nnet_utils.py │ ├── test_schnet_features.py │ └── test_simulation.py ├── devtools ├── README.md └── changelog.md ├── examples ├── CG-Force-Fields-With-SchNet-Embeddings.ipynb ├── Training-A-Coarse-Grained-Force-Field.ipynb ├── Variable-Sized-Data-With-CGSchNet.ipynb ├── data │ ├── README.md │ ├── ala2_coordinates.npy │ └── ala2_forces.npy └── figs │ ├── CGnet.png │ └── README.md ├── requirements.txt └── setup.py /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Development: 2 | - [x] Implement feature / fix bug 3 | - [ ] Add documentation 4 | - [ ] Add tests 5 | 6 | Checks: 7 | - [ ] Run `nosetests` 8 | - [ ] Check pep8 compliance 9 | 10 | [Describe changes here] 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.DS_Store 3 | *.egg-info 4 | .ipynb_checkpoints 5 | *.pyc 6 | __pycache__ 7 | build 8 | dist 9 | *.egg 10 | debug/* 11 | *.swp 12 | *.pt 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, coarse-graining 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | cgnet 2 | ===== 3 | 4 | ```diff 5 | - This code is no longer being routinely maintained. For more modern implementations of CGSchNet, we recommend a PyTorch Geometric-based implementation here: 6 | https://github.com/torchmd/torchmd-net 7 | ``` 8 | 9 | Coarse graining for molecular dynamics ([preprint](https://arxiv.org/abs/2007.11412)) 10 | 11 | Dependencies 12 | ------------ 13 | Required: 14 | + `numpy` 15 | + `pytorch` (1.2 or higher) 16 | + `scipy` 17 | 18 | Optional: 19 | + `mdtraj` (for `cgnet.molecule` only) 20 | + `pandas` (for `cgnet.molecule` only) 21 | + `sklearn` (for testing) 22 | + `Jupyter` (for `examples`) 23 | + `matplotlib` (for `examples`) 24 | 25 | Usage 26 | ----- 27 | Clone the repository: 28 | ``` 29 | git clone git@github.com:coarse-graining/cgnet.git 30 | ``` 31 | 32 | Install any missing dependencies, and then run: 33 | ``` 34 | cd cgnet 35 | python setup.py install 36 | ``` 37 | 38 | Notes 39 | ----- 40 | For compatibility with `pytorch==1.1`, please use the `pytorch-1.1` branch. This branch currently does not include the updates for variable size and Langevin dynamics, nor some normalization options. 41 | ```diff 42 | - CGnet models can display high variance between different training runs. For more stable models, we recommend using CGSchNet instead. 43 | ``` 44 | 45 | Cite 46 | ---- 47 | Please cite our [paper](https://doi.org/10.1063/5.0026133) in J Chem Phys: 48 | 49 | ```bibtex 50 | @article{husic2020coarse, 51 | title={Coarse graining molecular dynamics with graph neural networks}, 52 | author={Husic, Brooke E and Charron, Nicholas E and Lemm, Dominik and Wang, Jiang and P{\'e}rez, Adri{\`a} and Majewski, Maciej and Kr{\"a}mer, Andreas and Chen, Yaoyi and Olsson, Simon and de Fabritiis, Gianni and No{\'e}, Frank and Clementi, Cecilia}, 53 | journal={The Journal of Chemical Physics}, 54 | volume={153}, 55 | number={19}, 56 | pages={194101}, 57 | year={2020}, 58 | publisher={AIP Publishing LLC} 59 | } 60 | ``` 61 | 62 | Various methods are based on the following papers. 
CGnet: 63 | 64 | ```bibtex 65 | @article{wang2019machine, 66 | title={Machine learning of coarse-grained molecular dynamics force fields}, 67 | author={Wang, Jiang and Olsson, Simon and Wehmeyer, Christoph and Pérez, Adrià and Charron, Nicholas E and de Fabritiis, Gianni and Noé, Frank and Clementi, Cecilia}, 68 | journal={ACS Central Science}, 69 | year={2019}, 70 | publisher={ACS Publications}, 71 | doi={10.1021/acscentsci.8b00913} 72 | } 73 | ``` 74 | 75 | SchNet: 76 | 77 | ```bibtex 78 | @article{schutt2018schnetpack, 79 | title={SchNetPack: A deep learning toolbox for atomistic systems}, 80 | author={Schütt, KT and Kessel, Pan and Gastegger, Michael and Nicoli, KA and Tkatchenko, Alexandre and Müller, K-R}, 81 | journal={Journal of Chemical Theory and Computation}, 82 | volume={15}, 83 | number={1}, 84 | pages={448--455}, 85 | year={2018}, 86 | publisher={ACS Publications} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /cgnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coarse-graining/cgnet/a3e0e8ddc06f4b6a9f48f4886b73b4cf372ff481/cgnet/__init__.py -------------------------------------------------------------------------------- /cgnet/feature/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .feature import * 3 | from .geometry import * 4 | from .statistics import * 5 | from .utils import * 6 | from .schnet_utils import * 7 | from .combiner import * 8 | -------------------------------------------------------------------------------- /cgnet/feature/combiner.py: -------------------------------------------------------------------------------- 1 | # Authors: Nick Charron 2 | 3 | import torch.nn as nn 4 | from cgnet.feature import (GeometryFeature, Geometry, SchnetFeature, 5 | GaussianRBF) 6 | import warnings 7 | g = Geometry(method='torch') 8 | 9 | 10 | class FeatureCombiner(nn.Module): 11 | """Class for combining GeometryFeatures and SchnetFeatures 12 | 13 | 14 | Parameters 15 | ---------- 16 | layer_list : list of nn.Module objects 17 | feature layers with which data is transformed before being passed to 18 | densely/fully connected layers prior to sum pooling and energy 19 | prediction/force generation. 20 | save_geometry : boolean (default=True) 21 | specifies whether or not to save the output of GeometryFeature 22 | layers. It is important to set this to True if CGnet priors 23 | are to be used, since they need to call back to GeometryFeature outputs. 24 | propagate_geometry : boolean (default=False) 25 | specifies whether or not to concatenate geometry features (i.e., 26 | distances, angles, and/or dihedrals) to the feature that is 27 | propagated through the neural network. This is designed to be 28 | used ONLY when the layer list is [GeometryFeature, SchnetFeature]. 29 | (default=False) 30 | distance_indices : list or np.ndarray of int (default=None) 31 | Indices of distances output from a GeometryFeature layer, used 32 | to isolate distances for redundant re-indexing for Schnet utilities 33 | 34 | 35 | Attributes 36 | ---------- 37 | layer_list : nn.ModuleList 38 | feature layers with which data is transformed before being passed to 39 | densely/fully connected layers prior to sum pooling and energy 40 | prediction/force generation. The length of layer_list is the number 41 | of layers. 
42 | interfeature_transforms : list of None or method types 43 | inter-feature transforms that may be needed during the forward 44 | method. These functions take the output of a previous feature layer and 45 | transform/reindex it so that it may be used as input for the next 46 | feature layer. The length of this list is equal to the length of the 47 | layer_list. For example, SchnetFeature tools require a redundant form 48 | for distances, so outputs from a previous GeometryFeature layer must be 49 | re-indexed. 50 | save_geometry : boolean (default=True) 51 | specifies whether or not to save the output of GeometryFeature 52 | layers. It is important to set this to True if CGnet priors 53 | are to be used, since they need to call back to GeometryFeature outputs. 54 | transform_dictionary : dictionary of strings 55 | dictionary of mappings to provide for specified inter-feature 56 | transforms. Keys are strings which describe the mapping, and values 57 | are mapping objects. For example, a redundant distance mapping may be 58 | represented as: 59 | 60 | {'redundant_distance_mapping' : self.redundant_distance_mapping} 61 | 62 | Notes 63 | ----- 64 | There are several cases for combinations of GeometryFeature and 65 | SchnetFeature. 66 | 67 | (1) GeometryFeature alone 68 | (2) GeometryFeature followed by SchnetFeature 69 | (3) SchnetFeature alone 70 | 71 | (1) corresponds to classic CGnet architecture, as proposed by 72 | Wang et al. (2019). 73 | 74 | (2) is a general combination that allows for prior callbacks to 75 | non-distance features to add functional energy constraints (e.g., 76 | angle/bond/repulsion constraints) that supplement the classic SchNet 77 | architecture as proposed by Schütt et al. (2018). In this case, 78 | SchnetFeature should be initialized with calculate_geometry=False, 79 | as the preceding GeometryFeature layer already computes a geometrical 80 | featurization. In this case, propagate_geometry can be set to True 81 | or False depending on whether the geometry features should or should 82 | not be propagated through the neural network, respectively. 83 | 84 | (3) corresponds to classic pairwise distance-based SchNet. In this case, 85 | the SchnetFeature must be initialized with calculate_geometry=True 86 | so that it can use Geometry() tools to calculate distances on the fly. 87 | If calculate_geometry=False, the input to the network must be pairwise 88 | distances of size [n_frames, n_beads, n_neighbors], and the terminal 89 | CGnet autograd function will compute derivatives with respect to pairwise 90 | distances instead of cartesian coordinates. 91 | 92 | 93 | References 94 | ---------- 95 | Wang, J., Olsson, S., Wehmeyer, C., Pérez, A., Charron, N. E., 96 | de Fabritiis, G., Noé, F., Clementi, C. (2019). Machine Learning 97 | of Coarse-Grained Molecular Dynamics Force Fields. ACS Central Science. 98 | https://doi.org/10.1021/acscentsci.8b00913 99 | K.T. Schütt. P.-J. Kindermans, H. E. Sauceda, S. Chmiela, 100 | A. Tkatchenko, K.-R. Müller. (2018) 101 | SchNet - a deep learning architecture for molecules and materials. 102 | The Journal of Chemical Physics. 
103 | https://doi.org/10.1063/1.5019779 104 | """ 105 | 106 | def __init__(self, layer_list, save_geometry=True, propagate_geometry=False, 107 | distance_indices=None): 108 | super(FeatureCombiner, self).__init__() 109 | self.layer_list = nn.ModuleList(layer_list) 110 | self.save_geometry = save_geometry 111 | self.propagate_geometry = propagate_geometry 112 | self.interfeature_transforms = [] 113 | self.transform_dictionary = {} 114 | self.distance_indices = distance_indices 115 | _has_schnet = False 116 | for layer in self.layer_list: 117 | if isinstance(layer, SchnetFeature): 118 | _has_schnet = True 119 | if (layer.calculate_geometry and any(isinstance(layer, 120 | GeometryFeature) for layer in self.layer_list)): 121 | warnings.warn("This SchnetFeature has been set to " 122 | "calculate pairwise distances. Set " 123 | "SchnetFeature.calculate_geometry=False if you are " 124 | "preceding this SchnetFeature with a GeometryFeature " 125 | "in order to prevent unnecessarily repeated pairwise " 126 | "distance calculations") 127 | self.interfeature_transforms.append(None) 128 | if (not layer.calculate_geometry and not any(isinstance(layer, 129 | GeometryFeature) for layer in self.layer_list)): 130 | warnings.warn("This SchnetFeature has not been designated " 131 | "to calculate pairwise distances, but no " 132 | "GeometryFeature was found in the layer " 133 | "list. Please ensure that network input is " 134 | "formulated as pairwise distances.") 135 | elif layer.calculate_geometry: 136 | self.interfeature_transforms.append(None) 137 | else: 138 | if self.distance_indices is None: 139 | raise RuntimeError(("Distance indices must be " 140 | "supplied to FeatureCombiner " 141 | "for redundant re-indexing.")) 142 | self.transform_dictionary['redundant_distance_mapping'] = ( 143 | g.get_redundant_distance_mapping(layer._distance_pairs)) 144 | self.interfeature_transforms.append([self.distance_reindex]) 145 | if isinstance(layer, GeometryFeature): 146 | if _has_schnet: 147 | raise RuntimeError( 148 | "A GeometryFeature should never come after a SchnetFeature" 149 | ) 150 | else: 151 | self.interfeature_transforms.append(None) 152 | 153 | # The following just checks whether the layer_list is 154 | # [GeometryFeature, SchnetFeature] if propagate_geometry 155 | # is set to True. 156 | if self.propagate_geometry: 157 | if len(self.layer_list) != 2: 158 | raise RuntimeError( 159 | "propagate_geometry is only designed for a layer " \ 160 | "list of [GeometryFeature, SchnetFeature]" 161 | ) 162 | elif not (isinstance(self.layer_list[0], GeometryFeature) 163 | and isinstance(self.layer_list[1], SchnetFeature)): 164 | raise RuntimeError( 165 | "propagate_geometry is only designed for a layer " \ 166 | "list of [GeometryFeature, SchnetFeature]" 167 | ) 168 | 169 | def distance_reindex(self, geometry_output): 170 | """Reindexes GeometryFeature distance outputs to redundant form for 171 | SchnetFeatures and related tools. See 172 | Geometry.get_redundant_distance_mapping 173 | 174 | Parameters 175 | ---------- 176 | geometry_output : torch.Tensor 177 | geometrical feature output from a GeometryFeature layer, of size 178 | [n_frames, n_features]. 179 | 180 | Returns 181 | ------- 182 | redundant_distances : torch.Tensor 183 | pairwise distances transformed to shape 184 | [n_frames, n_beads, n_beads-1]. 
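
        Example
        -------
        A minimal sketch (shapes hypothetical): for a molecule with 4 beads,
        a preceding GeometryFeature layer yields 6 unique pairwise distances
        per frame, which this method scatters into redundant form:

            # geometry_output : [n_frames, n_features] from a GeometryFeature
            redundant_distances = feature_combiner.distance_reindex(geometry_output)
            # redundant_distances : [n_frames, 4, 3]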
185 | """ 186 | distances = geometry_output[:, self.distance_indices] 187 | return distances[:, self.transform_dictionary['redundant_distance_mapping']] 188 | 189 | def forward(self, coords, embedding_property=None): 190 | """Forward method through specified feature layers. The forward 191 | operation proceeds through self.layer_list in that same order 192 | that was passed to the FeatureCombiner. 193 | 194 | Parameters 195 | ---------- 196 | coords : torch.Tensor 197 | Input cartesian coordinates of size [n_frames, n_beads, 3] 198 | embedding_property : torch.Tensor (default=None) 199 | Some property that should be embedded. Can be nuclear charge 200 | or maybe an arbitrary number assigned for amino-acids. 201 | Size [n_frames, n_properties]. 202 | 203 | Returns 204 | ------- 205 | feature_ouput : torch.Tensor 206 | output tensor, of shape [n_frames, n_features] after featurization 207 | through the layers contained in self.layer_list. 208 | geometry_features : torch.Tensor (default=None) 209 | if save_geometry is True and the layer list is not just a single 210 | GeometryFeature layer, the output of the last GeometryFeature 211 | layer is returned alongside the terminal features for prior energy 212 | callback access. Else, None is returned. 213 | """ 214 | feature_output = coords 215 | geometry_features = None 216 | for num, (layer, transform) in enumerate(zip(self.layer_list, 217 | self.interfeature_transforms)): 218 | if transform != None: 219 | # apply transform(s) before the layer if specified 220 | for sub_transform in transform: 221 | feature_output = sub_transform(feature_output) 222 | if isinstance(layer, SchnetFeature): 223 | feature_output = layer(feature_output, embedding_property) 224 | else: 225 | feature_output = layer(feature_output) 226 | if isinstance(layer, GeometryFeature) and self.save_geometry: 227 | geometry_features = feature_output 228 | return feature_output, geometry_features 229 | -------------------------------------------------------------------------------- /cgnet/feature/dataset.py: -------------------------------------------------------------------------------- 1 | # Author: Brooke Husic, Nick Charron 2 | # Contributors: Jiang Wang 3 | 4 | 5 | import numpy as np 6 | import torch 7 | import scipy.spatial 8 | 9 | from torch.utils.data import Dataset, DataLoader 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | def multi_molecule_collate(input_dictionaries, device=torch.device('cpu')): 13 | """This function is used to construct padded batches for datasets 14 | that consist of molecules of different bead numbers. This must be 15 | done because tensors passed through neural networks must all 16 | be the same size. This method must be passed to the 'collate_fn' 17 | keyword argument in a PyTorch DataLoader object when working 18 | with variable size inputs to a network (see example below). 19 | 20 | Parameters 21 | ---------- 22 | input_dictionaries : list of dictionaries 23 | This is the input list of *unpadded* input data. Each example in the 24 | list is a dictionary with the following key/value pairs: 25 | 26 | 'coords' : np.array of shape (1, num_beads, 3) 27 | 'forces' : np.array of shape (1, num_beads, 3) 28 | 'embed' : np.array of shape (num_beads) 29 | 30 | Embeddings must be specified for this function to work correctly. 31 | A KeyError will be raised if they are not. 32 | 33 | Returns 34 | ------- 35 | batch : tuple of torch.tensors 36 | All the data in the batch, padded according to the largest system 37 | in the batch. 
The order of tensors in the tuple is the following: 38 | 39 | coords, forces, embedding_property = batch 40 | 41 | All examples are right-padded with zeros. For example, if the 42 | maximum bead size in the list of examples is 8, the embedding array 43 | for an example from a molecule composed of 3 beads will be padded 44 | from: 45 | 46 | unpadded_embedding = [1, 2, 5] 47 | 48 | to: 49 | 50 | padded_embedding = [1, 2, 5, 0, 0, 0, 0, 0] 51 | 52 | An analogous right-padding is done for forces and 53 | coordinates. 54 | 55 | Notes 56 | ----- 57 | See docs in MultiMoleculeDataset. While this function pads the inputs 58 | to the model, it is important to properly mask padded portions of tensors 59 | that are passed to the model. If these padded portions are not masked, 60 | then their artificial contribution carries through to the 61 | calculation of forces from the energy and the evaluation of the 62 | model loss. In particular, for MSE-style losses, there is a 63 | backpropagation instability associated with square root operations 64 | evaluated at 0. 65 | 66 | Example 67 | ------- 68 | my_loader = torch.utils.data.DataLoader(my_dataset, batch_size=512, 69 | collate_fn=multi_molecule_collate, 70 | shuffle=True) 71 | """ 72 | 73 | coordinates = pad_sequence([torch.tensor(example['coords'], 74 | requires_grad=True, device=device) 75 | for example in input_dictionaries], 76 | batch_first=True) 77 | forces = pad_sequence([torch.tensor(example['forces'], device=device) 78 | for example in input_dictionaries], 79 | batch_first=True) 80 | embeddings = pad_sequence([torch.tensor(example['embeddings'], device=device) 81 | for example in input_dictionaries], 82 | batch_first=True) 83 | return coordinates, forces, embeddings 84 | 85 | 86 | class MoleculeDataset(Dataset): 87 | """Creates dataset for coordinates and forces. 88 | 89 | Parameters 90 | ---------- 91 | coordinates : np.array 92 | Coordinate data of dimension [n_frames, n_beads, n_dimensions] 93 | forces : np.array 94 | Force data of dimension [n_frames, n_beads, n_dimensions] 95 | embeddings : np.array 96 | Embedding data of dimension [n_frames, n_beads] 97 | Embeddings must be positive integers. 98 | selection : np.array (default=None) 99 | Array of frame indices to select from the coordinates and forces. 100 | If None, all are used. 101 | stride : int (default=1) 102 | Subsample the data by 1 / stride. 103 | device : torch.device (default=torch.device('cpu')) 104 | CUDA device/GPU on which to mount tensors drawn from __getitem__(). 105 | Default device is the local CPU. 106 | """ 107 | 108 | def __init__(self, coordinates, forces, embeddings=None, selection=None, 109 | stride=1, device=torch.device('cpu')): 110 | self.stride = stride 111 | 112 | self.coordinates = self._make_array(coordinates, selection) 113 | self.forces = self._make_array(forces, selection) 114 | if embeddings is not None: 115 | if (np.any(embeddings < 1) or 116 | not np.all(embeddings.astype(int) == embeddings)): 117 | raise ValueError("Embeddings must be positive integers.") 118 | self.embeddings = self._make_array(embeddings, selection) 119 | else: 120 | self.embeddings = None 121 | 122 | self._check_inputs() 123 | 124 | self.len = len(self.coordinates) 125 | self.device = device 126 | 127 | def __getitem__(self, index): 128 | """This will always return 3 items: coordinates, forces, embeddings. 129 | If embeddings are not given, then the third object returned will 130 | be an empty tensor. 
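
        Example
        -------
        A minimal sketch of drawing one example (names hypothetical):

            coords, forces, embeddings = my_dataset[0]
            # coords has requires_grad=True; embeddings is torch.tensor([])
            # when no embeddings were supplied at construction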
131 | """ 132 | if self.embeddings is None: 133 | # Still returns three objects, but the third is an empty tensor 134 | return ( 135 | torch.tensor(self.coordinates[index], 136 | requires_grad=True, device=self.device), 137 | torch.tensor(self.forces[index], 138 | device=self.device), 139 | torch.tensor([]) 140 | ) 141 | else: 142 | return ( 143 | torch.tensor(self.coordinates[index], 144 | requires_grad=True, device=self.device), 145 | torch.tensor(self.forces[index], 146 | device=self.device), 147 | torch.tensor(self.embeddings[index], 148 | device=self.device) 149 | ) 150 | 151 | def __len__(self): 152 | return self.len 153 | 154 | def _make_array(self, data, selection=None): 155 | """Returns an array that contains a selection of data 156 | if specified, at the stride provided. 157 | """ 158 | if selection is not None: 159 | return np.array(data[selection][::self.stride]) 160 | else: 161 | return data[::self.stride] 162 | 163 | def add_data(self, coordinates, forces, embeddings=None, selection=None): 164 | """We add data to the dataset with a custom selection and the stride 165 | specified upon object instantiation, ensuring that the embeddings 166 | have a shape length of 2, and that everything has the same number 167 | of frames. 168 | """ 169 | new_coords = self._make_array(coordinates, selection) 170 | new_forces = self._make_array(forces, selection) 171 | if embeddings is not None: 172 | new_embeddings = self._make_array(embeddings, selection) 173 | 174 | self.coordinates = np.concatenate( 175 | [self.coordinates, new_coords], axis=0) 176 | self.forces = np.concatenate([self.forces, new_forces], axis=0) 177 | 178 | if self.embeddings is not None: 179 | self.embeddings = np.concatenate([self.embeddings, new_embeddings], 180 | axis=0) 181 | 182 | self._check_inputs() 183 | 184 | self.len = len(self.coordinates) 185 | 186 | def _check_inputs(self): 187 | """When we create or add data, we need to make sure that everything 188 | has the same number of frames. 189 | """ 190 | if self.coordinates.shape != self.forces.shape: 191 | raise ValueError("Coordinates and forces must have equal shapes") 192 | 193 | if len(self.coordinates.shape) != 3: 194 | raise ValueError("Coordinates and forces must have three dimensions") 195 | 196 | if self.embeddings is not None: 197 | if len(self.embeddings.shape) != 2: 198 | raise ValueError("Embeddings must have two dimensions") 199 | 200 | if self.coordinates.shape[0] != self.embeddings.shape[0]: 201 | raise ValueError("Embeddings must have the same number of examples " 202 | "as coordinates/forces") 203 | 204 | if self.coordinates.shape[1] != self.embeddings.shape[1]: 205 | raise ValueError("Embeddings must have the same number of beads " 206 | "as the coordinates/forces") 207 | 208 | 209 | class MultiMoleculeDataset(Dataset): 210 | """Dataset object for organizing data from molecules of differing sizes. 211 | It is meant to be paired with multi_molecule_collate function for use in 212 | a PyTorch DataLoader object. With this collating function, the inputs to 213 | the model will be padded on an example-by-example basis so that batches 214 | of tensors all have a single aggregated shape before being passed into 215 | the model. 216 | 217 | Note that unlike MoleculeDataset, MultiMoleculeDataset takes a *list* 218 | of numpy arrays. 219 | 220 | Parameters 221 | ---------- 222 | coordinates_list: list of numpy.arrays 223 | List of coordinate data. 
Each item i in the list must be a numpy 224 | array of shape [n_beads_i, 3], containing the cartesian coordinates of 225 | a single frame for molecule i 226 | forces_list: list of numpy.arrays 227 | List of force data. Each item i in the list must be a numpy 228 | array of shape [n_beads_i, 3], containing the cartesian forces of a 229 | single frame for molecule i 230 | embeddings_list: list of numpy.arrays 231 | List of embeddings. Each item i in the list must be a numpy array 232 | of shape [n_beads_i], containing the bead embeddings of a 233 | single frame for molecule i. The embeddings_list may not be None; 234 | MultiMoleculeDataset is only compatible with SchnetFeatures. 235 | 236 | Attributes 237 | ---------- 238 | data: list of dictionaries 239 | List of individual examples for molecules of different sizes. Each 240 | example is a dictionary with the following key/value pairs: 241 | 242 | 'coords' : np.array of size [n_beads_i, 3] 243 | 'forces' : np.array of size [n_beads_i, 3] 244 | 'embeddings' : np.array of size [n_beads_i] 245 | 246 | Example 247 | ------- 248 | my_multi_dataset = MultiMoleculeDataset(list_of_coords, list_of_forces, 249 | list_of_embeddings) 250 | my_loader = torch.utils.data.DataLoader(my_multi_dataset, batch_size=512, 251 | collate_fn=multi_molecule_collate, 252 | shuffle=True) 253 | 254 | """ 255 | 256 | def __init__(self, coordinates_list, forces_list, embeddings_list, 257 | selection=None, stride=1, device=torch.device('cpu')): 258 | self._check_inputs(coordinates_list, forces_list, 259 | embeddings_list=embeddings_list) 260 | self.stride = stride 261 | self.data = None 262 | 263 | self._make_array_data(coordinates_list, forces_list, 264 | embeddings_list=embeddings_list, 265 | selection=selection) 266 | self.len = len(self.data) 267 | 268 | def __getitem__(self, indices): 269 | """Returns the list of examples corresponding to the supplied indices. It 270 | is meant to be paired with the collating function multi_molecule_collate() 271 | """ 272 | if isinstance(indices, int): 273 | return self.data[indices] 274 | else: 275 | return [self.data[i] for i in indices] 276 | 277 | def __len__(self): 278 | return self.len 279 | 280 | def _make_array_data(self, coordinates_list, forces_list, 281 | embeddings_list, selection=None): 282 | """Assemble the NumPy arrays into a list of individual dictionaries for 283 | use with the multi_molecule_collate function. 284 | """ 285 | 286 | if self.data is None: 287 | self.data = [] 288 | if selection is not None: 289 | coordinates = [coordinates_list[i] for i in selection] 290 | forces = [forces_list[i] for i in selection] 291 | embeddings = [embeddings_list[i] for i in selection] 292 | for coord, force, embed in zip(coordinates[::self.stride], 293 | forces[::self.stride], 294 | embeddings[::self.stride]): 295 | self.data.append({ 296 | "coords" : coord, "forces" : force, "embeddings" : embed}) 297 | else: 298 | for coord, force, embed in zip(coordinates_list[::self.stride], 299 | forces_list[::self.stride], 300 | embeddings_list[::self.stride]): 301 | self.data.append({ 302 | "coords" : coord, "forces" : force, "embeddings" : embed}) 303 | 304 | 305 | def add_data(self, coordinates_list, forces_list, embeddings_list, 306 | selection=None): 307 | """We add data to the dataset with a custom selection and the stride 308 | specified upon object instantiation, ensuring that the embeddings 309 | have a shape length of 1, and that everything has the same number 310 | of frames. 
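
        Example
        -------
        A minimal sketch, assuming the new lists follow the constructor
        conventions (each item i: coords/forces of shape [n_beads_i, 3],
        embeddings of shape [n_beads_i]; all names hypothetical):

            my_multi_dataset.add_data(more_coords, more_forces,
                                      more_embeddings, selection=None)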
311 | """ 312 | self._check_inputs(coordinates_list, forces_list, 313 | embeddings_list=embeddings_list) 314 | self._make_array_data(coordinates_list, forces_list, 315 | embeddings_list=embeddings_list, selection=selection) 316 | self.len = len(self.data) 317 | 318 | def _check_inputs(self, coordinates_list, forces_list, embeddings_list): 319 | """Helper function for ensuring data has the correct shape when 320 | adding examples to a MultiMoleculeDataset. This function also checks to 321 | to make sure that no embeddings are 0. 322 | """ 323 | 324 | if embeddings_list is None: 325 | raise ValueError("Embeddings must be supplied, as MultiMoleculeDataset" 326 | " is intended to be used only with SchNet utilities.") 327 | else: 328 | for embedding in embeddings_list: 329 | if np.any(embedding < 1): 330 | raise ValueError("Embeddings must be positive integers.") 331 | 332 | if not (len(coordinates_list) == len(forces_list) == len(embeddings_list)): 333 | raise ValueError("Coordinates, forces, and embeddings lists must " 334 | " contain the same number of examples") 335 | 336 | for idx, (coord, force, embed) in enumerate(zip(coordinates_list, forces_list, 337 | embeddings_list)): 338 | if coord.shape != force.shape: 339 | raise ValueError("Coordinates and forces must have equal shapes at example", idx) 340 | 341 | if len(coord.shape) != 2: 342 | raise ValueError("Coordinates and forces must have two dimensions at example", idx) 343 | 344 | if len(embed.shape) != 1: 345 | raise ValueError("Embeddings must have one dimension at example", idx) 346 | 347 | if coord.shape[0] != embed.shape[0]: 348 | raise ValueError("Embeddings must have the same number of beads " 349 | "as the coordinates/forces at example", idx) 350 | -------------------------------------------------------------------------------- /cgnet/feature/geometry.py: -------------------------------------------------------------------------------- 1 | # Author: Brooke Husic 2 | # Contributors: Dominik Lemm 3 | 4 | import numpy as np 5 | import scipy 6 | import torch 7 | 8 | 9 | class Geometry(): 10 | """Helper class to calculate distances, angles, and dihedrals 11 | with a unified, vectorized framework depending on whether pytorch 12 | or numpy is used. 13 | 14 | Parameters 15 | ---------- 16 | method : 'torch' or 'numpy' (default='torch') 17 | Library used for compuations 18 | device : torch.device (default=torch.device('cpu')) 19 | Device upon which geometrical calculations will take place. When 20 | embedded as an attribute for a feature class, the device will inherit 21 | from the feature device attribute 22 | """ 23 | 24 | def __init__(self, method='torch', device=torch.device('cpu')): 25 | self.device = device 26 | if method not in ['torch', 'numpy']: 27 | raise RuntimeError("Allowed methods are 'torch' and 'numpy'") 28 | self.method = method 29 | 30 | # # # # # # # # # # # # # 31 | # Define any types here # 32 | # # # # # # # # # # # # # 33 | if method == 'torch': 34 | self.bool = torch.bool 35 | self.float32 = torch.float32 36 | 37 | elif self.method == 'numpy': 38 | self.bool = np.bool 39 | self.float32 = np.float32 40 | 41 | def check_for_nans(self, object, name=None): 42 | """This method checks an object for the presence of nans and 43 | returns an error if any nans are found. 44 | """ 45 | if name is None: 46 | name = '' 47 | 48 | if self.isnan(object).any(): 49 | raise ValueError( 50 | "Nan found in {}. 
Check your coordinates!".format( 51 | name) 52 | ) 53 | 54 | def check_array_vs_tensor(self, object, name=None): 55 | """This method checks whether the object (i.e., numpy array or torch 56 | tensor) is consistent with the method chosen for the Geometry 57 | instance (i.e., 'numpy' or 'torch', respectively). 58 | """ 59 | if name is None: 60 | name = '' 61 | 62 | if self.method == 'numpy' and type(object) is not np.ndarray: 63 | raise ValueError( 64 | "Input argument {} must be type np.ndarray for Geometry(method='numpy')".format( 65 | name) 66 | ) 67 | if self.method == 'torch' and type(object) is not torch.Tensor: 68 | raise ValueError( 69 | "Input argument {} must be type torch.Tensor for Geometry(method='torch')".format( 70 | name) 71 | ) 72 | 73 | def get_distance_indices(self, n_beads, backbone_inds=[], backbone_map=None): 74 | """Determines indices of pairwise distance features. 75 | """ 76 | pair_order = [] 77 | adj_backbone_pairs = [] 78 | for increment in range(1, n_beads): 79 | for i in range(n_beads - increment): 80 | pair_order.append((i, i+increment)) 81 | if len(backbone_inds) > 0: 82 | if (backbone_map[i+increment] 83 | - backbone_map[i] == 1): 84 | adj_backbone_pairs.append((i, i+increment)) 85 | 86 | return pair_order, adj_backbone_pairs 87 | 88 | def get_redundant_distance_mapping(self, pair_order): 89 | """Reformulates pairwise distances from shape [n_frames, n_dist] 90 | to shape [n_frames, n_beads, n_neighbors] 91 | 92 | This is done by finding the index mapping between non-redundant and 93 | redundant representations of the pairwise distances. This mapping can 94 | then be supplied to Schnet-related features, such as a 95 | RadialBasisFunction() layer, which use redundant pairwise distance 96 | representations. 97 | 98 | """ 99 | pairwise_dist_inds = [zipped_pair[1] for zipped_pair in sorted( 100 | [z for z in zip(pair_order, 101 | np.arange(len(pair_order))) 102 | ]) 103 | ] 104 | map_matrix = scipy.spatial.distance.squareform(pairwise_dist_inds) 105 | map_matrix = map_matrix[~np.eye(map_matrix.shape[0], 106 | dtype=bool)].reshape( 107 | map_matrix.shape[0], -1) 108 | return map_matrix 109 | 110 | def get_vectorize_inputs(self, inds, data): 111 | """Helper function to obtain indices for vectorized calculations. 112 | """ 113 | if len(np.unique([len(feat) for feat in inds])) > 1: 114 | raise ValueError( 115 | "All features must be the same length." 116 | ) 117 | feat_length = len(inds[0]) 118 | 119 | ind_list = [[feat[i] for feat in inds] 120 | for i in range(feat_length)] 121 | 122 | dist_list = [data[:, ind_list[i+1], :] 123 | - data[:, ind_list[i], :] 124 | for i in range(feat_length - 1)] 125 | 126 | if len(dist_list) == 1: 127 | dist_list = dist_list[0] 128 | 129 | return dist_list 130 | 131 | def get_distances(self, distance_inds, data, norm=True): 132 | """Calculates distances in a vectorized fashion. 133 | """ 134 | self.check_array_vs_tensor(data, 'data') 135 | distances = self.get_vectorize_inputs(distance_inds, data) 136 | if norm: 137 | distances = self.norm(distances, axis=2) 138 | self.check_for_nans(distances, 'distances') 139 | return distances 140 | 141 | def get_angles(self, angle_inds, data, clip=True): 142 | """Calculates angles in a vectorized fashion. 143 | 144 | If clip is True (default), then the angle cosines are clipped 145 | to be between -1 and 1 to account for numerical error. 
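For example, a cosine that numerically evaluates to 1.0000001 due to floating-point error would make arccos return NaN; clipping it to 1.0 yields an angle of 0.0 instead.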
146 | 147 | """ 148 | self.check_array_vs_tensor(data, 'data') 149 | 150 | base, offset = self.get_vectorize_inputs(angle_inds, data) 151 | # This convention assumes that the middle index of the angle triplet 152 | # is the angle vertex. Scalar multiplication of the first vector 153 | # of the angle triplet by -1 means that the vertex point is 154 | # subtracted from the non-vertex point for the first vector. 155 | # This ensures that the arccos operation returns the acute angle 156 | # at the vertex. See test_geometry_features for a non-parallel 157 | # formulation. 158 | base *= -1 159 | 160 | angles = self.sum(base * offset, axis=2) / self.norm(base, 161 | axis=2) / self.norm( 162 | offset, axis=2) 163 | 164 | if clip: 165 | # Clipping to prevent the arccos from being NaN 166 | angles = self.arccos(self.clip(angles, 167 | lower_bound=-1., 168 | upper_bound=1.)) 169 | 170 | self.check_for_nans(angles, 'angles') 171 | 172 | return angles 173 | 174 | def get_dihedrals(self, dihed_inds, data): 175 | """Calculates dihedrals in a vectorized fashion. 176 | 177 | Note 178 | ---- 179 | This is implemented in a hacky/bad way. It calculates twice as many 180 | dihedrals as needed and removes every other one. There is a better 181 | way to do this, I think using two lists of angles, but for now 182 | this has the correct functionality. 183 | """ 184 | self.check_array_vs_tensor(data, 'data') 185 | 186 | angle_inds = np.concatenate([[(f[i], f[i+1], f[i+2]) 187 | for i in range(2)] for f in dihed_inds]) 188 | base, offset = self.get_vectorize_inputs(angle_inds, data) 189 | offset_2 = base[:, 1:] 190 | 191 | cross_product_adj = self.cross(base, offset, axis=2) 192 | cp_base = cross_product_adj[:, :-1, :] 193 | cp_offset = cross_product_adj[:, 1:, :] 194 | 195 | plane_vector = self.cross(cp_offset, offset_2, axis=2) 196 | 197 | dihedral_cosines = self.sum(cp_base[:, ::2]*cp_offset[:, ::2], 198 | axis=2)/self.norm( 199 | cp_base[:, ::2], axis=2)/self.norm(cp_offset[:, ::2], axis=2) 200 | 201 | dihedral_sines = self.sum(cp_base[:, ::2]*plane_vector[:, ::2], 202 | axis=2)/self.norm( 203 | cp_base[:, ::2], axis=2)/self.norm(plane_vector[:, ::2], axis=2) 204 | 205 | 206 | self.check_for_nans(dihedral_cosines, 'dihedral cosines') 207 | self.check_for_nans(dihedral_sines, 'dihedral sines') 208 | 209 | return dihedral_cosines, dihedral_sines 210 | 211 | def get_neighbors(self, distances, cutoff=None): 212 | """Calculates a simple neighbor list in which every bead sees 213 | every other bead. If a cutoff is specified, only beads inside that distance 214 | cutoff are considered as neighbors. 215 | 216 | Parameters 217 | ---------- 218 | distances: torch.Tensor or np.array 219 | Redundant distance matrix of shape (n_frames, n_beads, n_neighbors). 220 | cutoff: float (default=None) 221 | Distance cutoff in Angstrom within which beads are considered neighbors. 222 | 223 | Returns 224 | ------- 225 | neighbors: torch.Tensor or np.array 226 | Indices of all neighbors of each bead. This is not affected by the 227 | mask. 228 | Shape [n_frames, n_beads, n_neighbors] 229 | neighbor_mask: torch.Tensor or np.array 230 | Index mask to filter out non-existing neighbors that were 231 | introduced due to distance cutoffs. 
232 | Shape [n_frames, n_beads, n_neighbors] 233 | 234 | """ 235 | 236 | self.check_array_vs_tensor(distances, 'distances') 237 | 238 | n_frames, n_beads, n_neighbors = distances.shape 239 | 240 | # Create a simple neighbor list of shape [n_frames, n_beads, n_neighbors] 241 | # in which every bead sees every other bead but not itself. 242 | # First, create a matrix that contains all indices. 243 | neighbors = self.tile(self.arange(n_beads), (n_frames, n_beads, 1)) 244 | # To remove the self interaction of beads, an inverted identity matrix 245 | # is used to exclude the respective indices in the neighbor list. 246 | neighbors = neighbors[:, ~self.eye(n_beads, dtype=self.bool)].reshape( 247 | n_frames, 248 | n_beads, 249 | n_neighbors) 250 | 251 | if cutoff is not None: 252 | # Create an index mask for neighbors that are inside the cutoff 253 | neighbor_mask = distances < cutoff 254 | neighbor_mask = self.to_type(neighbor_mask, self.float32) 255 | else: 256 | neighbor_mask = self.ones((n_frames, n_beads, n_neighbors), 257 | dtype=self.float32) 258 | 259 | return neighbors, neighbor_mask 260 | 261 | def _torch_eye(self, n, dtype): 262 | if dtype == torch.bool: 263 | # Only in pytorch>=1.2! 264 | return torch.BoolTensor(np.eye(n, dtype=np.bool)) 265 | else: 266 | return torch.eye(n, dtype=dtype) 267 | 268 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 269 | # # # # # # # # # # # # # # Versatile Methods # # # # # # # # # # # # # # 270 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 271 | 272 | # The methods implemented below should modify the originals as little as 273 | # possible, such that the documentation for the respective method on the 274 | # numpy and pytorch websites should be sufficient. 275 | 276 | # Methods defined: arccos, cross, norm, sum, arange, tile, eye, ones, 277 | # to_type, clip, isnan 278 | 279 | def arccos(self, x): 280 | if self.method == 'torch': 281 | return torch.acos(x) 282 | elif self.method == 'numpy': 283 | return np.arccos(x) 284 | 285 | def cross(self, x, y, axis): 286 | if self.method == 'torch': 287 | return torch.cross(x, y, dim=axis) 288 | elif self.method == 'numpy': 289 | return np.cross(x, y, axis=axis) 290 | 291 | def norm(self, x, axis): 292 | if self.method == 'torch': 293 | return torch.norm(x, dim=axis) 294 | elif self.method == 'numpy': 295 | return np.linalg.norm(x, axis=axis) 296 | 297 | def sum(self, x, axis): 298 | if self.method == 'torch': 299 | return torch.sum(x, dim=axis) 300 | elif self.method == 'numpy': 301 | return np.sum(x, axis=axis) 302 | 303 | def arange(self, n): 304 | if self.method == 'torch': 305 | return torch.arange(n) 306 | elif self.method == 'numpy': 307 | return np.arange(n) 308 | 309 | def tile(self, x, shape): 310 | if self.method == 'torch': 311 | return x.repeat(*shape) 312 | elif self.method == 'numpy': 313 | return np.tile(x, shape) 314 | 315 | def eye(self, n, dtype): 316 | # As of pytorch 1.2.0, BoolTensors are implemented. However, 317 | # torch.eye does not take dtype=torch.bool on CPU devices yet. 
318 | # Watch pytorch PR #24148 for the implementation, which would 319 | # allow us to return torch.eye(n, dtype=dtype) 320 | # For now, we do this: 321 | if self.method == 'torch': 322 | return self._torch_eye(n, dtype).to(self.device) 323 | elif self.method == 'numpy': 324 | return np.eye(n, dtype=dtype) 325 | 326 | def ones(self, shape, dtype): 327 | if self.method == 'torch': 328 | return torch.ones(*shape, dtype=dtype).to(self.device) 329 | elif self.method == 'numpy': 330 | return np.ones(shape, dtype=dtype) 331 | 332 | def to_type(self, x, dtype): 333 | if self.method == 'torch': 334 | return x.type(dtype) 335 | elif self.method == 'numpy': 336 | return x.astype(dtype) 337 | 338 | def clip(self, x, lower_bound, upper_bound, out=None): 339 | if self.method == 'torch': 340 | return torch.clamp(x, min=lower_bound, max=upper_bound, out=out) 341 | elif self.method == 'numpy': 342 | return np.clip(x, a_min=lower_bound, a_max=upper_bound, out=out) 343 | 344 | def isnan(self, x): 345 | if self.method == 'torch': 346 | return torch.isnan(x) 347 | elif self.method == 'numpy': 348 | return np.isnan(x) 349 | -------------------------------------------------------------------------------- /cgnet/feature/schnet_utils.py: -------------------------------------------------------------------------------- 1 | # Author: Dominik Lemm 2 | # Contributors: Nick Charron, Brooke Husic 3 | 4 | import torch 5 | import torch.nn as nn 6 | from cgnet.feature.utils import ShiftedSoftplus, LinearLayer 7 | 8 | 9 | class SimpleNormLayer(nn.Module): 10 | """Simple normalization layer that divides the output of a 11 | preceding layer by a specified number 12 | 13 | Parameters 14 | ---------- 15 | normalization_strength: float 16 | The number by which the input is divided 17 | """ 18 | 19 | def __init__(self, normalization_strength): 20 | super(SimpleNormLayer, self).__init__() 21 | self.normalization_strength = normalization_strength 22 | 23 | def forward(self, input_features): 24 | """Computes normalized output 25 | 26 | Parameters 27 | ---------- 28 | input_features: torch.Tensor 29 | Input tensor of features of any shape 30 | 31 | Returns 32 | ------- 33 | normalized_features: torch.Tensor 34 | Normalized input features 35 | """ 36 | return input_features / self.normalization_strength 37 | 38 | 39 | class NeighborNormLayer(nn.Module): 40 | """Normalization layer that divides the output of a 41 | preceding layer by the number of neighbor features. 42 | Unlike the SimpleNormLayer, this layer allows for 43 | dynamically changing number of neighbors during training. 44 | """ 45 | 46 | def __init__(self): 47 | super(NeighborNormLayer, self).__init__() 48 | 49 | def forward(self, input_features, n_neighbors): 50 | """Computes normalized output 51 | 52 | Parameters 53 | ---------- 54 | input_features: torch.Tensor 55 | Input tensor of features of shape 56 | (n_frames, n_beads, n_feats) 57 | n_neighbors: int 58 | the number of neighbors 59 | 60 | Returns 61 | ------- 62 | normalized_features: torch.Tensor 63 | Normalized input features 64 | """ 65 | return input_features / n_neighbors 66 | 67 | 68 | class CGBeadEmbedding(nn.Module): 69 | """Simple embedding class for coarse-grain beads. 70 | Serves as a lookup table that returns a fixed size embedding. 71 | 72 | Parameters 73 | ---------- 74 | n_embeddings: int 75 | Maximum number of different properties/amino_acids/elements, 76 | i.e., the dictionary size. Note: when specifying 77 | n_embeddings, you must input the total number of physical 78 | embeddings + 1. 
This is because the 0 embedding used for padding 79 | is included by default (see example below). 80 | embedding_dim: int 81 | Size of the embedding vector. 82 | 83 | Example 84 | ------- 85 | If you have 10 unique beads, their labels will be 1, 2, ..., 10 inclusive, 86 | because the 0 index is reserved for padding. Therefore, to specify 87 | embeddings with an output dimension of 128, you would instantiate 88 | CGBeadEmbedding as follows: 89 | 90 | n_embeddings = 11 # (10 + 1) 91 | embedding_dim = 128 92 | my_embedding_layer = CGBeadEmbedding(n_embeddings, embedding_dim) 93 | """ 94 | 95 | def __init__(self, n_embeddings, embedding_dim): 96 | super(CGBeadEmbedding, self).__init__() 97 | self.embedding = nn.Embedding(num_embeddings=n_embeddings, 98 | embedding_dim=embedding_dim, 99 | padding_idx=0) 100 | 101 | def forward(self, embedding_property): 102 | """ 103 | 104 | Parameters 105 | ---------- 106 | embedding_property: torch.Tensor 107 | Some property that should be embedded. Can be nuclear charge 108 | or maybe an arbitrary number assigned for amino-acids. Passing a 109 | zero will produce an embedding vector filled with zeroes (necessary 110 | in the case of zero padded batches). The properties to be embedded 111 | should be integers (torch type long). 112 | Size [n_frames, n_beads] 113 | 114 | Returns 115 | ------- 116 | embedding_vector: torch.Tensor 117 | Corresponding embedding vector to the passed indices. 118 | Size [n_frames, n_beads, embedding_dim] 119 | """ 120 | return self.embedding(embedding_property) 121 | 122 | 123 | class ContinuousFilterConvolution(nn.Module): 124 | r""" 125 | Continuous-filter convolution block as described by Schütt et al. (2018). 126 | 127 | Unlike conventional convolutional layers that utilize discrete filter tensors, 128 | a continuous-filter convolutional layer evaluates the convolution at discrete 129 | locations in space using continuous radial filters (Schütt et al. 2018). 130 | 131 | x_i^{l+1} = (X^l * W^l)_i = \sum_{j=0}^{n_{atoms}} x_j^l \circ W^l (r_j - r_i) 132 | 133 | with feature representation X^l=(x^l_1, ..., x^l_n), filter-generating 134 | network W^l, positions R=(r_1, ..., r_n) and the current layer l. 135 | 136 | A continuous-filter convolution block consists of a filter generating network 137 | as follows: 138 | 139 | Filter Generator: 140 | 1. Featurization of cartesian positions into distances 141 | (which are roto-translationally invariant) 142 | (already precomputed so will be parsed as arguments) 143 | 2. Atom-wise/Linear layer with shifted-softplus activation function 144 | 3. Atom-wise/Linear layer with shifted-softplus activation function 145 | (see Notes) 146 | 147 | The filter generator output is then multiplied element-wise with the 148 | continuous convolution filter as part of the interaction block. 149 | 150 | Parameters 151 | ---------- 152 | n_gaussians: int 153 | Number of Gaussians that has been used in the radial basis function. 154 | Needed to determine the input feature size of the first dense layer. 155 | n_filters: int 156 | Number of filters that will be created. Also determines the output size. 157 | Needs to be the same size as the features of the residual connection in 158 | the interaction block. 159 | activation: nn.Module (default=ShiftedSoftplus()) 160 | Activation function for the filter generating network. Following 161 | Schütt et al., the default value is ShiftedSoftplus, but any 162 | differentiable activation function can be used (see Notes). 
163 | normalization_layer: nn.Module (default=None) 164 | Normalization layer to be applied to the output of the 165 | ContinuousFilterConvolution 166 | 167 | Notes 168 | ----- 169 | Following the current implementation in SchNetPack, the last linear layer of 170 | the filter generator does not contain an activation function. 171 | This allows the filter generator to contain negative values. 172 | 173 | In practice, we have observed that ShiftedSoftplus as an activation 174 | function for a SchnetFeature (i.e., within its ContinuousFilterConvolution) 175 | that is used for a CGnet will lead to simulation instabilities when using 176 | that CGnet to generate new data. We have experienced more success with 177 | nn.Tanh(). 178 | 179 | References 180 | ---------- 181 | K.T. Schütt. P.-J. Kindermans, H. E. Sauceda, S. Chmiela, 182 | A. Tkatchenko, K.-R. Müller. (2018) 183 | SchNet - a deep learning architecture for molecules and materials. 184 | The Journal of Chemical Physics. 185 | https://doi.org/10.1063/1.5019779 186 | """ 187 | 188 | def __init__(self, n_gaussians, n_filters, activation=ShiftedSoftplus(), 189 | normalization_layer=None): 190 | super(ContinuousFilterConvolution, self).__init__() 191 | filter_layers = LinearLayer(n_gaussians, n_filters, bias=True, 192 | activation=activation) 193 | # No activation function in the last layer allows the filter generator 194 | # to contain negative values. 195 | filter_layers += LinearLayer(n_filters, n_filters, bias=True) 196 | self.filter_generator = nn.Sequential(*filter_layers) 197 | 198 | if normalization_layer: 199 | self.normalization_layer = normalization_layer 200 | else: 201 | self.normalization_layer = None 202 | 203 | def forward(self, features, rbf_expansion, neighbor_list, neighbor_mask, bead_mask=None): 204 | """ Compute convolutional block 205 | 206 | Parameters 207 | ---------- 208 | features: torch.Tensor 209 | Feature vector of size [n_frames, n_beads, n_features]. 210 | rbf_expansion: torch.Tensor 211 | Gaussian expansion of bead distances of size 212 | [n_frames, n_beads, n_neighbors, n_gaussians]. 213 | neighbor_list: torch.Tensor 214 | Indices of all neighbors of each bead. 215 | Size [n_frames, n_beads, n_neighbors] 216 | neighbor_mask: torch.Tensor 217 | Index mask to filter out non-existing neighbors that were 218 | introduced due to distance cutoffs or padding. 219 | Size [n_frames, n_beads, n_neighbors] 220 | bead_mask: torch.Tensor (default=None) 221 | Mask used to filter out non-existing beads that may be 222 | present in datasets with molecules of different sizes 223 | Size [n_frames, n_beads] 224 | 225 | Returns 226 | ------- 227 | aggregated_features: torch.Tensor 228 | Residual features of shape [n_frames, n_beads, n_features] 229 | 230 | """ 231 | 232 | # Generate the convolutional filter 233 | # Size (n_frames, n_beads, n_neighbors, n_features) 234 | conv_filter = self.filter_generator(rbf_expansion) 235 | 236 | # Feature tensor needs to be transformed from 237 | # (n_frames, n_beads, n_features) 238 | # to 239 | # (n_frames, n_beads, n_neighbors, n_features) 240 | # This can be done by feeding the features of a respective bead into 241 | # its position in the neighbor_list. 
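        # (Illustrative note, hypothetical sizes) With n_beads=3, the
        # neighbor_list row for bead 0 is [1, 2], so the gather below copies
        # the feature vectors of beads 1 and 2 into bead 0's neighbor slots,
        # and likewise for every other bead.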
242 | n_batch, n_beads, n_neighbors = neighbor_list.size() 243 | 244 | # Size (n_frames, n_beads * n_neighbors, 1) 245 | neighbor_list = neighbor_list.reshape(-1, n_beads * n_neighbors, 1) 246 | # Size (n_frames, n_beads * n_neighbors, n_features) 247 | neighbor_list = neighbor_list.expand(-1, -1, features.size(2)) 248 | 249 | # Gather the features into the respective places in the neighbor list 250 | neighbor_features = torch.gather(features, 1, neighbor_list) 251 | # Reshape back to (n_frames, n_beads, n_neighbors, n_features) for 252 | # element-wise multiplication with the filter 253 | neighbor_features = neighbor_features.reshape(n_batch, n_beads, 254 | n_neighbors, -1) 255 | # Element-wise multiplication of the features with 256 | # the convolutional filter 257 | conv_features = neighbor_features * conv_filter 258 | 259 | # Remove features from non-existing neighbors outside the cutoff 260 | conv_features = conv_features * neighbor_mask[:, :, :, None] 261 | # Aggregate/pool the features from (n_frames, n_beads, n_neighs, n_feats) 262 | # to (n_frames, n_beads, n_features) 263 | aggregated_features = torch.sum(conv_features, dim=2) 264 | 265 | # Filter out contributions from non-existent beads introduced by padding 266 | 267 | if bead_mask is not None: 268 | aggregated_features = aggregated_features * bead_mask[:, :, None] 269 | 270 | if self.normalization_layer is not None: 271 | if isinstance(self.normalization_layer, NeighborNormLayer): 272 | return self.normalization_layer(aggregated_features, n_neighbors) 273 | else: 274 | return self.normalization_layer(aggregated_features) 275 | else: 276 | return aggregated_features 277 | 278 | 279 | class InteractionBlock(nn.Module): 280 | """ 281 | SchNet interaction block as described by Schütt et al. (2018). 282 | 283 | An interaction block consists of: 284 | 1. Atom-wise/Linear layer without activation function 285 | 2. Continuous filter convolution, which is a filter-generator multiplied 286 | element-wise with the output of the previous layer 287 | 3. Atom-wise/Linear layer with activation 288 | 4. Atom-wise/Linear layer without activation 289 | 290 | The output of an interaction block will then be used to form an additive 291 | residual connection with the original input features, (x'_1, ... , x'_n), 292 | see Notes. 293 | 294 | Parameters 295 | ---------- 296 | n_inputs: int 297 | Number of input features. Determines input size for the initial linear 298 | layer. 299 | n_gaussians: int 300 | Number of Gaussians that has been used in the radial basis function. 301 | Needed to determine the input size of the continuous filter 302 | convolution. 303 | n_filters: int 304 | Number of filters that will be created in the continuous filter convolution. 305 | The same feature size will be used for the output linear layers of the 306 | interaction block. 307 | activation: nn.Module (default=ShiftedSoftplus()) 308 | Activation function for the atom-wise layers. Following Schütt et al., 309 | the default value is ShiftedSoftplus, but any differentiable activation 310 | function can be used (see Notes). 311 | normalization_layer: nn.Module (default=None) 312 | Normalization layer to be applied to the output of the 313 | ContinuousFilterConvolution 314 | 315 | Notes 316 | ----- 317 | The additive residual connection between interaction blocks is not 318 | included in the output of this forward pass. The residual connection 319 | will be computed separately outside of this class. 
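
    As a sketch, the caller (e.g., a SchnetFeature) typically applies the
    residual update itself (interaction_blocks is a hypothetical name):

        for interaction_block in interaction_blocks:
            features = features + interaction_block(features, rbf_expansion,
                                                    neighbor_list, neighbor_mask)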
320 | 321 | In practice, we have observed that ShiftedSoftplus as an activation 322 | function for a SchnetFeature (i.e., within its InteractionBlock) 323 | that is used for a CGnet will lead to simulation instabilities when using 324 | that CGnet to generate new data. We have experienced more success with 325 | nn.Tanh(). 326 | 327 | References 328 | ---------- 329 | K.T. Schütt. P.-J. Kindermans, H. E. Sauceda, S. Chmiela, 330 | A. Tkatchenko, K.-R. Müller. (2018) 331 | SchNet - a deep learning architecture for molecules and materials. 332 | The Journal of Chemical Physics. 333 | https://doi.org/10.1063/1.5019779 334 | """ 335 | 336 | def __init__(self, n_inputs, n_gaussians, n_filters, 337 | activation=ShiftedSoftplus(), normalization_layer=None): 338 | super(InteractionBlock, self).__init__() 339 | 340 | self.initial_dense = nn.Sequential( 341 | *LinearLayer(n_inputs, n_filters, bias=False, 342 | activation=None)) 343 | # backwards compatibility for spelling error in initial dense 344 | # layer attribute. 345 | # WARNING : This will be removed in the future! 346 | self.inital_dense = self.initial_dense 347 | 348 | self.cfconv = ContinuousFilterConvolution(n_gaussians=n_gaussians, 349 | n_filters=n_filters, 350 | activation=activation, 351 | normalization_layer=normalization_layer) 352 | output_layers = LinearLayer(n_filters, n_filters, bias=True, 353 | activation=activation) 354 | output_layers += LinearLayer(n_filters, n_filters, bias=True, 355 | activation=None) 356 | self.output_dense = nn.Sequential(*output_layers) 357 | 358 | def forward(self, features, rbf_expansion, neighbor_list, neighbor_mask, bead_mask=None): 359 | """ Compute interaction block 360 | 361 | Parameters 362 | ---------- 363 | features: torch.Tensor 364 | Input features from an embedding or interaction layer. 365 | Size [n_frames, n_beads, n_features] 366 | rbf_expansion: torch.Tensor 367 | Radial basis function expansion of inter-bead distances. 368 | Size [n_frames, n_beads, n_neighbors, n_gaussians] 369 | neighbor_list: torch.Tensor 370 | Indices of all neighbors of each bead. 371 | Size [n_frames, n_beads, n_neighbors] 372 | neighbor_mask: torch.Tensor 373 | Index mask to filter out non-existing neighbors that were 374 | introduced due to distance cutoffs. 375 | Size [n_frames, n_beads, n_neighbors] 376 | bead_mask: torch.Tensor (default=None) 377 | Mask used to filter out non-existing beads that may be 378 | present in datasets with molecules of different sizes 379 | Size [n_frames, n_beads] 380 | 381 | Returns 382 | ------- 383 | output_features: torch.Tensor 384 | Output of an interaction block. This output can be used to form 385 | a residual connection with the output of a prior embedding/interaction 386 | layer. 
387 | Size [n_frames, n_beads, n_filters] 388 | 389 | """ 390 | init_feature_output = self.initial_dense(features) 391 | conv_output = self.cfconv(init_feature_output, rbf_expansion, 392 | neighbor_list, neighbor_mask, bead_mask=bead_mask) 393 | output_features = self.output_dense(conv_output) 394 | return output_features 395 | -------------------------------------------------------------------------------- /cgnet/feature/utils.py: -------------------------------------------------------------------------------- 1 | # Authors: Nick Charron, Dominik Lemm 2 | # Contributors: Brooke Husic 3 | 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class ShiftedSoftplus(nn.Module): 11 | r""" Shifted softplus (SSP) activation function 12 | 13 | SSP originates from the softplus function: 14 | 15 | y = \ln\left(1 + e^{x}\right) 16 | 17 | Schütt et al. (2018) introduced a shifting factor to the function in order 18 | to ensure that SSP(0) = 0 while having infinite order of continuity: 19 | 20 | y = \ln\left(1 + e^{x}\right) - \ln(2) 21 | 22 | SSP allows one to obtain smooth potential energy surfaces and second derivatives 23 | that are required for training with forces as well as the calculation of 24 | vibrational modes (Schütt et al. 2018). 25 | 26 | References 27 | ---------- 28 | K.T. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, 29 | A. Tkatchenko, K.-R. Müller. (2018) 30 | SchNet - a deep learning architecture for molecules and materials. 31 | The Journal of Chemical Physics. 32 | https://doi.org/10.1063/1.5019779 33 | 34 | """ 35 | 36 | def __init__(self): 37 | super(ShiftedSoftplus, self).__init__() 38 | 39 | def forward(self, input_tensor): 40 | """ Applies the shifted softplus function element-wise 41 | 42 | Parameters 43 | ---------- 44 | input_tensor: torch.Tensor 45 | Input tensor of size (n_examples, *) where `*` means any number of 46 | additional dimensions. 47 | 48 | Returns 49 | ------- 50 | Output: torch.Tensor 51 | Same size (n_examples, *) as the input. 52 | """ 53 | return nn.functional.softplus(input_tensor) - np.log(2.0) 54 | 55 | 56 | class _AbstractRBFLayer(nn.Module): 57 | """Abstract layer for definition of radial basis function layers""" 58 | 59 | def __init__(self): 60 | super(_AbstractRBFLayer, self).__init__() 61 | 62 | def __len__(self): 63 | """Method to get the size of the basis used for distance expansions. 64 | 65 | Notes 66 | ----- 67 | This method must be implemented explicitly in a child class. If not, 68 | a NotImplementedError will be raised 69 | """ 70 | raise NotImplementedError() 71 | 72 | def forward(self, distances): 73 | """Forward method to compute expansions of distances into basis 74 | functions. 75 | 76 | Notes 77 | ----- 78 | This method must be explicitly implemented in a child class. 79 | If not, a NotImplementedError will be raised. 80 | """ 81 | raise NotImplementedError() 82 | 83 | 84 | class GaussianRBF(_AbstractRBFLayer): 85 | r"""Radial basis function (RBF) layer 86 | 87 | This layer serves as a distance expansion using radial basis functions with 88 | the following form: 89 | 90 | e_k(r_j - r_i) = exp(-(\left\| r_j - r_i \right\| - \mu_k)^2 / (2 * var)) 91 | 92 | with centers mu_k calculated on a uniform grid between 93 | the low and high distance cutoffs and var as the variance. 94 | The radial basis function has the effect of decorrelating the 95 | convolutional filter, which improves the training time. All distances are 96 | assumed, by default, to have units of Angstroms.
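    As a quick sketch of the expansion shapes (hypothetical tensors; see
    the Example section below for use inside a SchnetFeature):

        rbf = GaussianRBF()                # 50 centers on [0.0, 5.0]
        distances = torch.rand(8, 5, 4)    # [n_examples, n_beads, n_neighbors]
        expansion = rbf(distances)         # [8, 5, 4, 50]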
97 | 98 | Parameters 99 | ---------- 100 | low_cutoff : float (default=0.0) 101 | Minimum distance cutoff for the Gaussian basis. This cutoff represents the 102 | center of the first basis function. 103 | high_cutoff : float (default=5.0) 104 | Maximum distance cutoff for the Gaussian basis. This cutoff represents the 105 | center of the last basis function. 106 | n_gaussians : int (default=50) 107 | Total number of Gaussian functions to calculate. Number will be used to 108 | create a uniform grid from low_cutoff to high_cutoff. The number of Gaussians 109 | will also decide the output size of the RBF layer output 110 | ([n_examples, n_beads, n_neighbors, n_gauss]). The default number of 111 | gaussians is the same as that in SchnetPack (Schutt et al., 2019). 112 | variance : float (default=1.0) 113 | The variance (standard deviation squared) of the Gaussian functions. 114 | normalize_output : bool (default=False) 115 | If True, the output of the GaussianRBF layer will be normalized by the sum 116 | over the outputs from every basis function. 117 | 118 | Example 119 | ------- 120 | To instantiate a SchnetFeature using a GaussianRBF layer with 50 centers, a 121 | low cutoff of 1 distance unit, a high cutoff of 50 distance units, a 122 | variance of 0.8, and no output normalization, the following procedure can 123 | be used: 124 | 125 | rbf_layer = GaussianRBF(low_cutoff=1.0, high_cutoff=50.0, 126 | n_gaussians=50, variance=0.8) 127 | schnet_feature = SchnetFeature(feature_size = ..., 128 | embedding_layer = ..., 129 | rbf_layer=rbf_layer, 130 | n_interaction_blocks = ..., 131 | calculate_geometry = ..., 132 | n_beads = ..., 133 | neighbor_cutoff = ..., 134 | device = ...) 135 | 136 | where the ellipses represent the other parameters of the SchnetFeature that 137 | are specific to your needs (see cgnet.feature.SchnetFeature for more 138 | details). 139 | 140 | Notes 141 | ----- 142 | The units of the variance and cutoffs are fixed by the units of the 143 | input distances. 144 | 145 | References 146 | ---------- 147 | Schutt, K. T., Kessel, P., Gastegger, M., Nicoli, K. A., Tkatchenko, A., 148 | & Müller, K.-R. (2019). SchNetPack: A Deep Learning Toolbox For Atomistic 149 | Systems. Journal of Chemical Theory and Computation, 15(1), 448–455.
150 | https://doi.org/10.1021/acs.jctc.8b00908 151 | """ 152 | 153 | def __init__(self, low_cutoff=0.0, high_cutoff=5.0, n_gaussians=50, 154 | variance=1.0, normalize_output=False): 155 | super(GaussianRBF, self).__init__() 156 | self.register_buffer('centers', torch.linspace(low_cutoff, 157 | high_cutoff, n_gaussians)) 158 | self.variance = variance 159 | self.normalize_output = normalize_output 160 | 161 | def __len__(self): 162 | """Method to return basis size""" 163 | return len(self.centers) 164 | 165 | def forward(self, distances, distance_mask=None): 166 | """Calculate Gaussian expansion 167 | 168 | Parameters 169 | ---------- 170 | distances : torch.Tensor 171 | Interatomic distances of size [n_examples, n_beads, n_neighbors] 172 | distance_mask : torch.Tensor 173 | Mask of shape [n_examples, n_beads, n_neighbors] to filter out 174 | contributions from non-physical beads introduced from padding 175 | examples from molecules with varying sizes 176 | 177 | Returns 178 | ------- 179 | gaussian_exp: torch.Tensor 180 | Gaussian expansions of size [n_examples, n_beads, n_neighbors, 181 | n_gauss] 182 | """ 183 | dist_centered_squared = torch.pow(distances.unsqueeze(dim=3) - 184 | self.centers, 2) 185 | gaussian_exp = torch.exp(-(0.5 / self.variance) 186 | * dist_centered_squared) 187 | 188 | # If specified, normalize output by sum over all basis function outputs 189 | if self.normalize_output: 190 | basis_sum = torch.sum(gaussian_exp, dim=3) 191 | gaussian_exp = gaussian_exp / basis_sum[:, :, :, None] 192 | 193 | # Mask the output of the radial distribution with the distance mask 194 | if distance_mask is not None: 195 | gaussian_exp = gaussian_exp * distance_mask[:, :, :, None] 196 | return gaussian_exp 197 | 198 | 199 | class PolynomialCutoffRBF(_AbstractRBFLayer): 200 | r"""Radial basis function (RBF) layer 201 | This layer serves as a distance expansion using modulated radial 202 | basis functions with the following form: 203 | 204 | g_k(r_{ij}) = \phi(r_{ij}, cutoff) * 205 | exp(-\beta * (\exp(-\alpha * r_{ij}) - \mu_k)^2) 206 | 207 | where \phi(r_{ij}, cutoff) is a piecewise polynomial modulation 208 | function of the following form, 209 | 210 | \phi(r_{ij}, cutoff) = 1 - 6*(r_{ij}/cutoff)^5 211 | + 15*(r_{ij}/cutoff)^4 212 | - 10*(r_{ij}/cutoff)^3 213 | for r_{ij} < cutoff 214 | 215 | \phi(r_{ij}, cutoff) = 0.0 216 | for r_{ij} >= cutoff 217 | 218 | with the centers mu_k calculated on a uniform grid between 219 | exp(-low_cutoff) and exp(-high_cutoff), and beta as a scaling 220 | parameter defined as: 221 | 222 | \beta = ((2/n_gaussians) * (1 - exp(-cutoff)))^-2 223 | 224 | The radial basis function has the effect of decorrelating the 225 | convolutional filter, which improves the training time. All distances 226 | are assumed, by default, to have units of Angstroms. We suggest that 227 | users visually inspect their basis before use in order to make sure 228 | that they are satisfied with the distribution and cutoffs of the 229 | functions. 230 | 231 | Parameters 232 | ---------- 233 | low_cutoff : float (default=0.0) 234 | Low distance cutoff for the modulation. This parameter, 235 | along with high_cutoff, determines the distribution of the centers of 236 | each basis function. 237 | high_cutoff : float (default=10.0) 238 | Distance cutoff for the modulation. This parameter, 239 | along with low_cutoff, determines the distribution of centers of 240 | each basis function.
241 | alpha : float (default=1.0) 242 | This parameter is a prefactor to r_{ij} inside the exponential term: 243 | 244 | exp(-alpha * r_{ij}) 245 | 246 | Lower values of this parameter result in a slower transition between 247 | sharply peaked gaussian functions at smaller distances and broadly 248 | peaked gaussian functions at larger distances 249 | with slowly decaying tails. 250 | n_gaussians : int (default=64) 251 | Total number of gaussian functions to calculate. Number will be used to 252 | create a uniform grid from exp(-high_cutoff) to exp(-low_cutoff). The 253 | number of gaussians will also decide the output size of the RBF layer output 254 | ([n_examples, n_beads, n_neighbors, n_gauss]). The default value of 255 | 64 gaussians is taken from Unke & Meuwly (2019). 256 | normalize_output : bool (default=False) 257 | If True, the output of the PolynomialCutoffRBF layer will be normalized 258 | by the sum over the outputs from every basis function. 259 | tolerance : float (default=1e-10) 260 | When expanding the modulated gaussians, values below the tolerance 261 | will be set to zero. 262 | device : torch.device (default=torch.device('cpu')) 263 | Device upon which tensors are mounted 264 | 265 | Attributes 266 | ---------- 267 | beta : float 268 | Gaussian decay parameter, defined as: 269 | \beta = ((2/n_gaussians) * (1 - exp(-cutoff)))^-2 270 | 271 | Example 272 | ------- 273 | To instantiate a SchnetFeature using a PolynomialCutoffRBF layer with 50 centers, 274 | a low cutoff of 1 distance unit, a high cutoff of 50 distance units, an 275 | alpha value of 0.8, and no output normalization, the following procedure can 276 | be used: 277 | 278 | rbf_layer = PolynomialCutoffRBF(low_cutoff=1.0, high_cutoff=50.0, 279 | n_gaussians=50, alpha=0.8) 280 | schnet_feature = SchnetFeature(feature_size = ..., 281 | embedding_layer = ..., 282 | rbf_layer=rbf_layer, 283 | n_interaction_blocks = ..., 284 | calculate_geometry = ..., 285 | n_beads = ..., 286 | neighbor_cutoff = ..., 287 | device = ...) 288 | 289 | where the ellipses represent the other parameters of the SchnetFeature that 290 | are specific to your needs (see cgnet.feature.SchnetFeature for more 291 | details). 292 | 293 | 294 | Notes 295 | ----- 296 | These basis functions were originally introduced as part of the PhysNet 297 | architecture (Unke & Meuwly, 2019). Though the basis function centers are 298 | scattered uniformly, the modulation function has the effect of broadening 299 | those functions closer to the specified cutoff. The overall result is a set 300 | of basis functions with high resolution at small distances that 301 | smoothly morphs into basis functions with lower resolution at larger 302 | distances. 303 | 304 | The units of the cutoffs, alpha, and beta are fixed by the units 305 | of the input distances. 306 | 307 | References 308 | ---------- 309 | Unke, O. T., & Meuwly, M. (2019). PhysNet: A Neural Network for Predicting 310 | Energies, Forces, Dipole Moments and Partial Charges. Journal of 311 | Chemical Theory and Computation, 15(6), 3678–3693.
312 | https://doi.org/10.1021/acs.jctc.9b00181 313 | 314 | """ 315 | 316 | def __init__(self, low_cutoff=0.0, high_cutoff=10.0, alpha=1.0, 317 | n_gaussians=64, normalize_output=False, tolerance=1e-10, 318 | device=torch.device('cpu')): 319 | super(PolynomialCutoffRBF, self).__init__() 320 | self.tolerance = tolerance 321 | self.device = device 322 | self.register_buffer('centers', torch.linspace(np.exp(-high_cutoff), 323 | np.exp(-low_cutoff), n_gaussians)) 324 | self.high_cutoff = high_cutoff 325 | self.low_cutoff = low_cutoff 326 | self.beta = np.power(((2/n_gaussians) * 327 | (1-np.exp(-self.high_cutoff))), -2) 328 | self.alpha = alpha 329 | self.normalize_output = normalize_output 330 | 331 | def __len__(self): 332 | """Method to return basis size""" 333 | return len(self.centers) 334 | 335 | def modulation(self, distances): 336 | """PhysNet cutoff modulation function 337 | 338 | Parameters 339 | ---------- 340 | distances : torch.Tensor 341 | Interatomic distances of size [n_examples, n_beads, n_neighbors] 342 | 343 | Returns 344 | ------- 345 | mod : torch.Tensor 346 | The modulation envelope of the radial basis functions. Shape 347 | [n_examples, n_beads, n_neighbors] 348 | 349 | """ 350 | zeros = torch.zeros_like(distances).to(self.device) 351 | modulation_envelope = torch.where(distances < self.high_cutoff, 352 | 1 - 6 * 353 | torch.pow((distances/self.high_cutoff), 354 | 5) 355 | + 15 * 356 | torch.pow((distances/self.high_cutoff), 357 | 4) 358 | - 10 * 359 | torch.pow( 360 | (distances/self.high_cutoff), 3), 361 | zeros) 362 | return modulation_envelope 363 | 364 | def forward(self, distances, distance_mask=None): 365 | """Calculate modulated gaussian expansion 366 | 367 | Parameters 368 | ---------- 369 | distances : torch.Tensor 370 | Interatomic distances of size [n_examples, n_beads, n_neighbors] 371 | distance_mask : torch.Tensor 372 | Mask of shape [n_examples, n_beads, n_neighbors] to filter out 373 | contributions from non-physical beads introduced from padding 374 | examples from molecules with varying sizes 375 | 376 | Returns 377 | ------- 378 | expansions : torch.Tensor 379 | Modulated gaussian expansions of size 380 | [n_examples, n_beads, n_neighbors, n_gauss] 381 | 382 | Notes 383 | ----- 384 | The gaussian portion of the basis function is a function of 385 | exp(-r_{ij}), not r_{ij} 386 | 387 | """ 388 | dist_centered_squared = torch.pow(torch.exp(self.alpha * 389 | - distances.unsqueeze(dim=3)) 390 | - self.centers, 2) 391 | gaussian_exp = torch.exp(-self.beta 392 | * dist_centered_squared) 393 | modulation_envelope = self.modulation(distances).unsqueeze(dim=3) 394 | 395 | expansions = modulation_envelope * gaussian_exp 396 | 397 | # In practice, this gives really tiny numbers. For numbers below the 398 | # tolerance, we just set them to zero. 
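        # (torch.where below keeps each expansion value where
        # |expansion| > tolerance and substitutes an exact zero elsewhere,
        # acting as a hard numerical floor on the basis outputs)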
399 | expansions = torch.where(torch.abs(expansions) > self.tolerance, 400 | expansions, 401 | torch.zeros_like(expansions)) 402 | 403 | # If specified, normalize output by sum over all basis function outputs 404 | if self.normalize_output: 405 | basis_sum = torch.sum(expansions, dim=3) 406 | expansions = expansions / basis_sum[:, :, :, None] 407 | 408 | if distance_mask is not None: 409 | expansions = expansions * distance_mask[:, :, :, None] 410 | return expansions 411 | 412 | 413 | def LinearLayer( 414 | d_in, 415 | d_out, 416 | bias=True, 417 | activation=None, 418 | dropout=0, 419 | weight_init='xavier', 420 | weight_init_args=None, 421 | weight_init_kwargs=None): 422 | r"""Linear layer function 423 | 424 | Parameters 425 | ---------- 426 | d_in : int 427 | input dimension 428 | d_out : int 429 | output dimension 430 | bias : bool (default=True) 431 | specifies whether or not to add a bias node 432 | activation : torch.nn.Module() (default=None) 433 | activation function for the layer 434 | dropout : float (default=0) 435 | if > 0, a dropout layer with the specified dropout frequency is 436 | added after the activation. 437 | weight_init : str, float, or nn.init function (default=\'xavier\') 438 | specifies the initialization of the layer weights. For the built-in 439 | options ('xavier' or 'identity' initialization), a string may be used 440 | for simplicity. If a float or int is passed, a constant initialization 441 | is used. For more complicated initializations, a torch.nn.init function 442 | object can be passed in. 443 | weight_init_args : list or tuple (default=None) 444 | arguments (excluding the layer.weight argument) for a torch.nn.init 445 | function. 446 | weight_init_kwargs : dict (default=None) 447 | keyword arguments for a torch.nn.init function 448 | 449 | Returns 450 | ------- 451 | seq : list of torch.nn.Module() instances 452 | the full linear layer, including activation and optional dropout. 453 | 454 | Example 455 | ------- 456 | MyLayer = LinearLayer(5, 10, bias=True, activation=nn.Softplus(beta=2), 457 | weight_init=nn.init.kaiming_uniform_, 458 | weight_init_kwargs={"a":0, "mode":"fan_out", 459 | "nonlinearity":"leaky_relu"}) 460 | 461 | Produces a linear layer with input dimension 5, output dimension 10, bias 462 | inclusive, followed by a beta=2 softplus activation, with the layer weights 463 | initialized according to the Kaiming uniform procedure with preservation of 464 | weight variance magnitudes during backpropagation.
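    A simpler sketch (hypothetical dimensions) using a constant
    initialization of 0.1 for every weight, a ReLU activation, and dropout:

        layers = LinearLayer(10, 10, activation=nn.ReLU(), dropout=0.5,
                             weight_init=0.1)
        net = nn.Sequential(*layers)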
465 | 466 | """ 467 | 468 | seq = [nn.Linear(d_in, d_out, bias=bias)] 469 | if activation: 470 | if isinstance(activation, nn.Module): 471 | seq += [activation] 472 | else: 473 | raise TypeError( 474 | 'Activation {} is not a valid torch.nn.Module'.format( 475 | str(activation)) 476 | ) 477 | if dropout: 478 | seq += [nn.Dropout(dropout)] 479 | 480 | with torch.no_grad(): 481 | if weight_init == 'xavier': 482 | torch.nn.init.xavier_uniform_(seq[0].weight) 483 | if weight_init == 'identity': 484 | torch.nn.init.eye_(seq[0].weight) 485 | if weight_init not in ['xavier', 'identity', None]: 486 | if isinstance(weight_init, int) or isinstance(weight_init, float): 487 | torch.nn.init.constant_(seq[0].weight, weight_init) 488 | elif callable(weight_init): 489 | if weight_init_args is None: 490 | weight_init_args = [] 491 | if weight_init_kwargs is None: 492 | weight_init_kwargs = {} 493 | weight_init(seq[0].weight, *weight_init_args, 494 | **weight_init_kwargs) 495 | else: 496 | raise RuntimeError( 497 | 'Unknown weight initialization \"{}\"'.format( 498 | str(weight_init)) 499 | ) 500 | return seq 501 | -------------------------------------------------------------------------------- /cgnet/molecule/__init__.py: -------------------------------------------------------------------------------- 1 | from .aminoacids import * 2 | from .trajectory import * 3 | -------------------------------------------------------------------------------- /cgnet/molecule/aminoacids.py: -------------------------------------------------------------------------------- 1 | # Author: Brooke Husic 2 | 3 | 4 | import numpy as np 5 | import warnings 6 | 7 | # These radii and masses were obtained from the following repository: 8 | # https://github.com/ZiZ1/model_builder/blob/master/models/mappings/atom_types.py 9 | 10 | # The radii were calculated by assuming a sphere and solving for the radius 11 | # using the molar volumes at 25 Celsius reported in Table 6, column 1, of: 12 | # Haeckel, M., Hinz, H.-J., Hedwig, G. (1999). Partial molar volumes of 13 | # proteins: amino acid side-chain contributions derived from the partial 14 | # molar volumes of some tripeptides over the temperature range 10-90 C. 15 | # Biophysical Chemistry. https://doi.org/10.1016/S0301-4622(99)00104-0 16 | 17 | # radii are reported in NANOMETERS 18 | RESIDUE_RADII = { 19 | 'ALA': 0.1845, 'ARG': 0.3134, 20 | 'ASN': 0.2478, 'ASP': 0.2335, 21 | 'CYS': 0.2276, 'GLN': 0.2733, 22 | 'GLU': 0.2639, 'GLY': 0.0000, 23 | 'HIS': 0.2836, 'ILE': 0.2890, 24 | 'LEU': 0.2887, 'LYS': 0.2938, 25 | 'MET': 0.2916, 'PHE': 0.3140, 26 | 'PRO': 0.2419, 'SER': 0.1936, 27 | 'THR': 0.2376, 'TRP': 0.3422, 28 | 'TYR': 0.3169, 'VAL': 0.2620 29 | } 30 | 31 | # masses are reported in AMUS 32 | RESIDUE_MASSES = { 33 | 'ALA': 89.0935, 'ARG': 174.2017, 34 | 'ASN': 132.1184, 'ASP': 133.1032, 35 | 'CYS': 121.1590, 'GLN': 146.1451, 36 | 'GLU': 147.1299, 'GLY': 75.0669, 37 | 'HIS': 155.1552, 'ILE': 131.1736, 38 | 'LEU': 131.1736, 'LYS': 146.1882, 39 | 'MET': 149.2124, 'PHE': 165.1900, 40 | 'PRO': 115.1310, 'SER': 105.0930, 41 | 'THR': 119.1197, 'TRP': 204.2262, 42 | 'TYR': 181.1894, 'VAL': 117.1469 43 | } 44 | 45 | 46 | def calculate_hard_sphere_minima(bead_pairs, cgmolecule, units='Angstroms', 47 | prefactor=0.7): 48 | """This function uses amino acid radii to calculate a minimum contact 49 | distance between atoms in a CGMolecule in either Angstroms or nanometers.
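    For two beads in different residues, the returned minimum is simply
    prefactor * (radius_1 + radius_2), where each radius is looked up in
    RESIDUE_RADII by residue identity (and scaled by a factor of 10 when
    Angstroms are requested).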
50 | Both glycine-glycine pairs and atoms in the same residue will return 51 | a distance of zero (the latter will also raise a warning). See also Notes, 52 | below. 53 | 54 | Parameters 55 | ---------- 56 | bead_pairs : list of two-element tuples 57 | Each tuple contains the two atom indices in the coarse-grained molecule 58 | for which a minimum distance should be calculated. 59 | cgmolecule : cgnet.molecule.CGMolecule instance 60 | An initialized CGMolecule object. 61 | units : 'Angstroms' or 'nanometers' (default='Angstroms') 62 | The unit in which the minimum distances should be returned 63 | prefactor : float (default=0.7) 64 | Factor by which each atomic radius should be multiplied. 65 | The default of 0.7 is inspired by reference [1]. 66 | 67 | Returns 68 | ------- 69 | hard_sphere_minima : list of floats 70 | Each element contains the minimum hard sphere distance corresponding 71 | to the same index in the input list of bead_pairs 72 | 73 | Notes 74 | ----- 75 | This method does NOT take into account the identity of the atom in the 76 | residue. In other words, the CA-CA, CA-CB, CB-CB, etc. distances will all 77 | be identical between two residues. In the example provided below, the 78 | hard_sphere_minima output will be a list of two identical distances. 79 | 80 | This method does NOT account for distances between atoms within the same 81 | residue (i.e., the same residue index). These distances will be returned 82 | as zero and a warning will be raised. If you are using this method to 83 | populate a repulsion prior, consider applying harmonic priors to such 84 | intra-residue distances. 85 | 86 | References 87 | ---------- 88 | [1] Cheung, M. S., Finke, J. M., Callahan, B., Onuchic, J. N. (2003). 89 | Exploring the interplay between topology and secondary structure 90 | formation in the protein folding problem. J. Phys. Chem. B. 91 | https://doi.org/10.1021/jp034441r 92 | 93 | Example 94 | ------- 95 | names = ['CA', 'CB', 'CA', 'CB'] 96 | resseq = [1, 1, 2, 2] 97 | resmap = {1 : 'ALA', 2 : 'PHE'} 98 | 99 | dipeptide = CGMolecule(names, resseq, resmap) 100 | 101 | # Our CA-CA distance is (0, 2), and our CB-CB distance is (1, 3) 102 | hard_sphere_minima = calculate_hard_sphere_minima([(0, 2), (1, 3)], 103 | dipeptide) 104 | 105 | # Note that in this example, hard_sphere_minima will have two entries 106 | # with the same distance 107 | """ 108 | if units.lower() not in ['angstroms', 'nanometers']: 109 | raise ValueError("units must be Angstroms or nanometers") 110 | 111 | resmap = cgmolecule.resmap 112 | resseq = cgmolecule.resseq 113 | if units.lower() == 'angstroms': 114 | residue_radii = {k : 10*v for k, v in RESIDUE_RADII.items()} 115 | else: 116 | residue_radii = RESIDUE_RADII 117 | 118 | # Calculate the distance unless the residue indices are the same, 119 | # in which case use a nan instead. We go through nans because we 120 | # want to provide the user with the problematic indices, and zeros 121 | # aren't unique because a GLY-GLY pair would also return a zero 122 | # even for different residue indices. 123 | hard_sphere_minima = np.array( 124 | [(prefactor*residue_radii[resmap[resseq[b1]]] + 125 | prefactor*residue_radii[resmap[resseq[b2]]]) 126 | if resseq[b1] != resseq[b2] else np.nan 127 | for b1, b2 in bead_pairs] 128 | ) 129 | 130 | nan_indices = np.where(np.isnan(hard_sphere_minima))[0] 131 | if len(nan_indices) > 0: 132 | warnings.warn("The following bead pairs were in the same residue.
Their " 133 | "minima were set to zero: {}".format( 134 | [bead_pairs[ni] for ni in nan_indices])) 135 | hard_sphere_minima[nan_indices] = 0. 136 | 137 | hard_sphere_minima = [np.round(dist, 4) for dist in hard_sphere_minima] 138 | 139 | return hard_sphere_minima 140 | -------------------------------------------------------------------------------- /cgnet/molecule/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coarse-graining/cgnet/a3e0e8ddc06f4b6a9f48f4886b73b4cf372ff481/cgnet/molecule/tests/__init__.py -------------------------------------------------------------------------------- /cgnet/molecule/tests/test_aminoacids.py: -------------------------------------------------------------------------------- 1 | # Author: Brooke Husic 2 | 3 | import numpy as np 4 | import itertools 5 | 6 | from cgnet.molecule import (CGMolecule, RESIDUE_RADII, 7 | calculate_hard_sphere_minima) 8 | 9 | 10 | def test_angstrom_conversion(): 11 | # This tests in a somewhat roundabout way whether the angstrom 12 | # conversion from the master dictionary is correct 13 | 14 | # Create a CG model using alpha carbons only 15 | all_residues = list(RESIDUE_RADII.keys()) 16 | doubled_res_list = np.concatenate(np.vstack([all_residues, 17 | all_residues]).T) 18 | names = ['CA'] * len(doubled_res_list) 19 | resseq = np.arange(1, len(doubled_res_list)+1) 20 | resmap = {i+1: doubled_res_list[i] for i in range(len(doubled_res_list))} 21 | mypeptide = CGMolecule(names, resseq, resmap) 22 | 23 | # Enumerate only (CA, CA) pairs when each CA corresponds to the same 24 | # type of residue 25 | same_res_pairs = [(i, i+1) for i in range(len(doubled_res_list)) 26 | if i % 2 == 0] 27 | 28 | # Calculate the minima with a prefactor of 1.0 29 | same_res_minima = calculate_hard_sphere_minima(same_res_pairs, 30 | mypeptide, 31 | units='Angstroms', 32 | prefactor=1.0) 33 | 34 | # The minima should be the radii in nanometers after a factor of 1/20 35 | single_nm_radii = [i/20 for i in same_res_minima] 36 | 37 | values_from_dict = [RESIDUE_RADII[res] for i, res in 38 | enumerate(doubled_res_list) if i % 2 == 0] 39 | 40 | np.testing.assert_allclose(values_from_dict, single_nm_radii) 41 | 42 | 43 | def test_minima_calculation_values(): 44 | # This is a manual test of the minima calculations 45 | 46 | # Shuffle the twenty amino acids. We'll use the first entries of the 47 | # shuffled list to make a random peptide.
48 | possible_residues = list(RESIDUE_RADII.keys()) 49 | np.random.shuffle(possible_residues) 50 | 51 | # Make a CA only CGMolecule object with a random number of residues 52 | num_residues = np.random.randint(3, 10) 53 | names = ['CA'] * num_residues 54 | resseq = np.arange(1, num_residues+1) 55 | resmap = {i+1: possible_residues[i] for i in range(num_residues)} 56 | mypeptide = CGMolecule(names, resseq, resmap) 57 | 58 | # Enumerate all the residue-residue pairs 59 | pairs = list(itertools.combinations(np.arange(num_residues), 2)) 60 | 61 | # Designate a random prefactor (i.e., scaling factor for each radius 62 | # in the calculation) 63 | prefactor = np.random.uniform(0.5, 1.3) 64 | 65 | # Perform the manual calculation 66 | manual_distances = [] 67 | for pair in pairs: 68 | # The +1 is needed because the resmap isn't zero-indexed 69 | rad1 = RESIDUE_RADII[resmap[pair[0] + 1]] 70 | rad2 = RESIDUE_RADII[resmap[pair[1] + 1]] 71 | # The *10 converts to angstroms 72 | manual_distances.append(prefactor*rad1*10 + prefactor*rad2*10) 73 | 74 | # Perform the automatic calculation 75 | distances = calculate_hard_sphere_minima(pairs, mypeptide, 76 | prefactor=prefactor) 77 | 78 | # The high tolerance is due to the significant figures in the 79 | # master list 80 | np.testing.assert_allclose(manual_distances, distances, rtol=1e-4) 81 | 82 | 83 | def test_CA_vs_CB_minima_correspondence(): 84 | # This tests that CA-CA distances are the same as CB-CB for the same 85 | # residue pair 86 | 87 | # Shuffle the twenty amino acids. We'll use the first entries of the 88 | # shuffled list to make a random peptide. 89 | possible_residues = list(RESIDUE_RADII.keys()) 90 | np.random.shuffle(possible_residues) 91 | 92 | # Make a CA+CB CGMolecule object with a random number of residues 93 | # Note that this might involve a GLY having a CB - this is fine 94 | num_residues = np.random.randint(3, 10) 95 | names = ['CA', 'CB'] * num_residues 96 | resseq = list(np.concatenate([np.repeat(i+1, 2) 97 | for i in range(num_residues)])) 98 | resmap = {i+1: possible_residues[i] for i in range(num_residues)} 99 | mypeptide = CGMolecule(names, resseq, resmap) 100 | 101 | # Enumerate each set of inds 102 | CA_inds = [i for i in range(num_residues*2) if i % 2 == 0] 103 | CB_inds = [i for i in range(num_residues*2) if i % 2 == 1] 104 | 105 | # Enumerate one set of all CA-CA pairs and one set of all CB-CB pairs 106 | CA_CA_pairs = list(itertools.combinations(CA_inds, 2)) 107 | CB_CB_pairs = list(itertools.combinations(CB_inds, 2)) 108 | 109 | # Calculate each set of minima 110 | CA_CA_minima = calculate_hard_sphere_minima(CA_CA_pairs, mypeptide) 111 | CB_CB_minima = calculate_hard_sphere_minima(CB_CB_pairs, mypeptide) 112 | 113 | # Ensure equality 114 | np.testing.assert_array_equal(CA_CA_minima, CB_CB_minima) 115 | 116 | 117 | def test_intra_residue_zeros(): 118 | # This tests that the minimum distance between atoms within the same 119 | # residue returns zero 120 | 121 | # Shuffle the twenty amino acids. We'll use the first entries of the 122 | # shuffled list to make a random peptide.
123 | possible_residues = list(RESIDUE_RADII.keys()) 124 | np.random.shuffle(possible_residues) 125 | 126 | # Make a CA+CB CGMolecule object with a random number of residues 127 | num_residues = np.random.randint(3, 10) 128 | names = ['CA', 'CB'] * num_residues 129 | resseq = list(np.concatenate([np.repeat(i+1, 2) 130 | for i in range(num_residues)])) 131 | resmap = {i+1: possible_residues[i] for i in range(num_residues)} 132 | mypeptide = CGMolecule(names, resseq, resmap) 133 | 134 | # Enumerate the intraresidue CA-CB pairs 135 | intra_res_pairs = [(i, i+1) for i in range(num_residues - 1) if i % 2 == 0] 136 | 137 | should_be_zeros = calculate_hard_sphere_minima(intra_res_pairs, mypeptide) 138 | 139 | # Test that a zero is returned for each residue 140 | np.testing.assert_array_equal(should_be_zeros, 141 | np.zeros(num_residues // 2)) 142 | -------------------------------------------------------------------------------- /cgnet/molecule/tests/test_trajectory.py: -------------------------------------------------------------------------------- 1 | # Author: Brooke Husic 2 | 3 | import numpy as np 4 | import torch 5 | 6 | import mdtraj as md 7 | 8 | from cgnet.feature import GeometryFeature 9 | from cgnet.molecule import CGMolecule 10 | 11 | 12 | # make a peptide backbone with 3 to 6 non-cap residues 13 | # We need at least three for the angle test so we have 14 | # three CA atoms. 15 | residues = np.random.randint(3, 6) 16 | 17 | # each non-cap residue will have 3 backbone atoms 18 | names = ['C'] + ['N', 'CA', 'C'] * residues + ['N'] 19 | beads = len(names) 20 | 21 | # resseq is the same length as names, where the integer corresponds 22 | # to the residue assignment of that bead, so it looks like 23 | # [1, 2, 2, 2, 3, 3, 3, ..., n] for n beads 24 | resseq = [1] + list(np.concatenate([np.repeat(i+2, 3) 25 | for i in range(residues)])) + [2+residues] 26 | 27 | # resmap maps the residue indices to amino acid identities; here we use 28 | # all alanines - it doesn't matter 29 | resmap = {1: 'ACE', (2+residues): 'NME'} 30 | for r in range(residues): 31 | resmap[r+2] = 'ALA' 32 | 33 | # we manually specify the bonds for some tests later 34 | # in the format that mdtraj requires 35 | bonds = np.zeros([beads-1, 4]) 36 | for b in range(beads-1): 37 | bonds[b] = [b, b+1, 0., 0.] 38 | 39 | # create a pseudo-dataset with three dimensions 40 | frames = np.random.randint(1, 10) 41 | dims = 3 42 | 43 | data = np.random.randn(frames, beads, dims) 44 | data_tensor = torch.Tensor(data) 45 | 46 | 47 | def test_cg_topology_standard(): 48 | # Make sure topology works like an mdtraj topology (auto bond version) 49 | 50 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap, 51 | bonds='standard') 52 | 53 | # Here we just make sure mdtraj.topology attributes have the right values 54 | assert molecule.top.n_atoms == beads 55 | assert molecule.top.n_bonds == beads-1 # 'standard' fills in the bonds 56 | assert molecule.top.n_chains == 1 57 | 58 | 59 | def test_cg_topology_no_bonds(): 60 | # Make sure topology works like an mdtraj topology (no bond version) 61 | 62 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap, 63 | bonds=None) 64 | 65 | # Here we just make sure mdtraj.topology attributes have the right values 66 | assert molecule.top.n_atoms == beads 67 | assert molecule.top.n_bonds == 0 # no bonds! 
68 | assert molecule.top.n_chains == 1 69 | 70 | 71 | def test_cg_topology_custom_bonds(): 72 | # Make sure topology works like an mdtraj topology (custom bond version) 73 | 74 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap, 75 | bonds=bonds) 76 | 77 | assert molecule.top.n_atoms == beads 78 | assert molecule.top.n_bonds == beads-1 # manual number of bonds 79 | assert molecule.top.n_chains == 1 80 | 81 | 82 | def test_cg_trajectory(): 83 | # Make sure trajectory works like an mdtraj trajectory 84 | 85 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap) 86 | traj = molecule.make_trajectory(data) 87 | 88 | # here we test that some mdtraj trajectory attributes are right 89 | assert traj.n_frames == frames 90 | assert traj.top.n_atoms == beads 91 | assert traj.n_residues == residues + 2 92 | 93 | 94 | def test_backbone_phi_dihedrals(): 95 | # Make sure backbone phi dihedrals are correct 96 | 97 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap) 98 | traj = molecule.make_trajectory(data) 99 | _, mdtraj_phis = md.compute_phi(traj) 100 | mdtraj_phis = np.abs(mdtraj_phis) 101 | 102 | # manual calculation of phi angles 103 | phis = [] 104 | for frame_data in data: 105 | dihed_list = [] 106 | for i in range(residues): 107 | # we get the phi's by starting at the 'C' of the previous 108 | # residue (or the ACE cap) 109 | a = frame_data[i*3] 110 | b = frame_data[i*3+1] 111 | c = frame_data[i*3+2] 112 | # the last bead in the phi dihedral is the 'C' of the same residue 113 | d = frame_data[i*3+3] 114 | 115 | ba = b-a 116 | cb = c-b 117 | dc = d-c 118 | 119 | c1 = np.cross(ba, cb) 120 | c2 = np.cross(cb, dc) 121 | temp = np.cross(c2, c1) 122 | term1 = np.dot(temp, cb)/np.sqrt(np.dot(cb, cb)) 123 | term2 = np.dot(c2, c1) 124 | dihed_list.append(np.arctan2(term1, term2)) 125 | phis.append(dihed_list) 126 | 127 | phis = np.abs(phis) 128 | 129 | np.testing.assert_allclose(mdtraj_phis, phis, rtol=1e-4) 130 | 131 | 132 | def test_backbone_psi_dihedrals(): 133 | # Make sure backbone psi dihedrals are correct 134 | 135 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap) 136 | traj = molecule.make_trajectory(data) 137 | _, mdtraj_psis = md.compute_psi(traj) 138 | mdtraj_psis = np.abs(mdtraj_psis) 139 | 140 | # manual calculation of psi angles 141 | psis = [] 142 | for frame_data in data: 143 | dihed_list = [] 144 | for i in range(residues): 145 | # we get the psi's by starting at the 'N', which is the first 146 | # bead of every residue 147 | a = frame_data[i*3+1] 148 | b = frame_data[i*3+2] 149 | # the last two beads in the psi dihedral are the 'C' of the same 150 | # residue and the 'N' of the next residue 151 | c = frame_data[i*3+3] 152 | d = frame_data[i*3+4] 153 | 154 | ba = b-a 155 | cb = c-b 156 | dc = d-c 157 | 158 | c1 = np.cross(ba, cb) 159 | c2 = np.cross(cb, dc) 160 | temp = np.cross(c2, c1) 161 | term1 = np.dot(temp, cb)/np.sqrt(np.dot(cb, cb)) 162 | term2 = np.dot(c2, c1) 163 | dihed_list.append(np.arctan2(term1, term2)) 164 | psis.append(dihed_list) 165 | 166 | psis = np.abs(psis) 167 | 168 | np.testing.assert_allclose(mdtraj_psis, psis, rtol=1e-4) 169 | 170 | 171 | def test_equality_with_cgnet_dihedrals(): 172 | # Make sure dihedrals are consistent with GeometryFeature 173 | 174 | geom_feature = GeometryFeature(feature_tuples='all_backbone', 175 | n_beads=beads) 176 | out = geom_feature.forward(data_tensor) 177 | 178 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap) 179 | traj = molecule.make_trajectory(data) 180 | 181 | mdtraj_phis = 
md.compute_phi(traj)[1] 182 | mdtraj_psis = md.compute_psi(traj)[1] 183 | 184 | mdtraj_phi_cosines = np.cos(mdtraj_phis) 185 | mdtraj_phi_sines = np.sin(mdtraj_phis) 186 | 187 | mdtraj_psi_cosines = np.cos(mdtraj_psis) 188 | mdtraj_psi_sines = np.sin(mdtraj_psis) 189 | 190 | # To get phi's and psi's out of cgnet, we need to specify which 191 | # indices they correspond to along the backbone 192 | # ['C', 'N', 'CA', 'C'] dihedrals 193 | phi_inds = [i*3 for i in range(residues)] 194 | # ['N', 'CA', 'C', 'N'] dihedrals 195 | psi_inds = [i*3+1 for i in range(residues)] 196 | 197 | cgnet_phi_cosines = geom_feature.dihedral_cosines.numpy()[:, phi_inds] 198 | cgnet_phi_sines = geom_feature.dihedral_sines.numpy()[:, phi_inds] 199 | 200 | cgnet_psi_cosines = geom_feature.dihedral_cosines.numpy()[:, psi_inds] 201 | cgnet_psi_sines = geom_feature.dihedral_sines.numpy()[:, psi_inds] 202 | 203 | np.testing.assert_allclose(mdtraj_phi_cosines, cgnet_phi_cosines, 204 | rtol=1e-4) 205 | np.testing.assert_allclose(mdtraj_phi_sines, cgnet_phi_sines, 206 | rtol=1e-4) 207 | np.testing.assert_allclose(mdtraj_psi_cosines, cgnet_psi_cosines, 208 | rtol=1e-4) 209 | np.testing.assert_allclose(mdtraj_psi_sines, cgnet_psi_sines, 210 | rtol=1e-4) 211 | 212 | 213 | def test_equality_with_cgnet_distances(): 214 | # Make sure CA distances are consistent with GeometryFeature 215 | 216 | geom_feature = GeometryFeature(feature_tuples='all_backbone', 217 | n_beads=beads) 218 | out = geom_feature.forward(data_tensor) 219 | 220 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap) 221 | traj = molecule.make_trajectory(data) 222 | 223 | # Calculate all pairs of CA distances 224 | CA_inds = [i for i, name in enumerate(names) if name == 'CA'] 225 | CA_pairs = [] # these are feature tuples 226 | for i, ind1 in enumerate(CA_inds[:-1]): 227 | for j, ind2 in enumerate(CA_inds[i+1:]): 228 | CA_pairs.append((ind1, ind2)) 229 | mdtraj_CA_dists = md.compute_distances(traj, CA_pairs) 230 | 231 | # map each CA distance feature tuple to the integer index 232 | CA_feature_tuple_dict = {key: i for i, key in 233 | enumerate(geom_feature.descriptions['Distances']) 234 | if key in CA_pairs} 235 | 236 | # retrieve CA distances only from the feature object 237 | cgnet_CA_dists = geom_feature.distances.numpy()[:, [CA_feature_tuple_dict[key] 238 | for key in CA_pairs]] 239 | 240 | np.testing.assert_allclose(mdtraj_CA_dists, cgnet_CA_dists, rtol=1e-6) 241 | 242 | 243 | def test_equality_with_cgnet_angles(): 244 | # Make sure CA angles calculated internally are consistent with mdtraj. 245 | # This test appears here because it requires an mdtraj dependency.
246 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap) 247 | traj = molecule.make_trajectory(data) 248 | 249 | # Grab the CA inds only to get the backbone angles and compute them 250 | # with mdtraj 251 | CA_inds = [i for i, name in enumerate(names) if name == 'CA'] 252 | backbone_angles = [(CA_inds[i], CA_inds[i+1], CA_inds[i+2]) 253 | for i in range(len(CA_inds)-2)] 254 | mdtraj_angles = md.compute_angles(traj, backbone_angles) 255 | 256 | # Get the GeometryFeature for just the backbone angles 257 | geom_feature = GeometryFeature(feature_tuples=backbone_angles) 258 | out = geom_feature.forward(data_tensor) 259 | 260 | cgnet_angles = geom_feature.angles 261 | 262 | np.testing.assert_allclose(mdtraj_angles, cgnet_angles, rtol=1e-4) 263 | -------------------------------------------------------------------------------- /cgnet/molecule/trajectory.py: -------------------------------------------------------------------------------- 1 | # Author: Brooke Husic 2 | 3 | 4 | import torch 5 | import numpy as np 6 | import mdtraj as md 7 | 8 | 9 | class CGMolecule(): 10 | """Casting of a coarse-grained (CG) molecule as an mdtraj-compatible 11 | topology with the option to input trajectory coordinates to create 12 | an mdtraj Trajectory object that can be used for standard analyses 13 | such as the computation of dihedral angles, contact distances, etc. 14 | 15 | Please refer to the mdtraj documentation at http://www.mdtraj.org 16 | or the code at https://github.com/mdtraj/mdtraj. 17 | 18 | Parameters 19 | ---------- 20 | names : list 21 | List of atom names in the CG molecule 22 | resseq : list 23 | List of residue assignments of each atom in the CG molecule 24 | resmap: dictionary 25 | Dictionary of residue indices (keys) and corresponding residue names (values) 26 | elements : list or None 27 | List of elements corresponding to each atom in names. If None, 28 | then the first character of the corresponding name string is used. 29 | bonds : np.array, 'standard', or None (default='standard') 30 | If None, no bonds. If 'standard', applies mdtraj.create_standard_bonds 31 | after topology is constructed. If np.array, bonds are given manually 32 | with a np.array of dimensions (n_bonds, 4) with zeroes everywhere 33 | except the ZERO-INDEXED indices of the bonded atoms. 34 | starting_index : int (default=0) 35 | Index for first atom in CG molecule if something other than a 0-index 36 | is desired 37 | 38 | Example 39 | ------- 40 | # Alanine dipeptide backbone example 41 | # coordinates is an np.array of dimension [n_frames, n_atoms, 3] 42 | 43 | names = ['C', 'N', 'CA', 'C', 'N'] 44 | resseq = [1, 2, 2, 2, 3] 45 | resmap = {1 : 'ACE', 2 : 'ALA', 3 : 'NME'} 46 | 47 | # bonds are not necessary in this case, since setting 48 | # bonds='standard' gives the desired result 49 | bonds = np.array( 50 | [[0., 1., 0., 0.], 51 | [2., 3., 0., 0.], 52 | [1., 2., 0., 0.], 53 | [3., 4., 0., 0.]]) 54 | 55 | molecule = CGMolecule(names=names, resseq=resseq, resmap=resmap, 56 | bonds=bonds) 57 | traj = molecule.make_trajectory(coordinates) 58 | 59 | Notes 60 | ----- 61 | Currently there is no option to have more than one chain. 62 | Unitcells are not implemented. 63 | 64 | References 65 | ---------- 66 | McGibbon, R. T., Beauchamp, K. A., Harrigan, M. P., Klein, C., 67 | Swails, J. M., Hernández, C. X., Schwantes, C. R., Wang, L.-P., 68 | Lane, T. J., and Pande, V. S. (2015). MDTraj: A Modern Open Library 69 | for the Analysis of Molecular Dynamics Trajectories. Biophys J.
70 | http://dx.doi.org/10.1016/j.bpj.2015.08.015 71 | """ 72 | 73 | def __init__(self, names, resseq, resmap, elements=None, 74 | bonds='standard', starting_index=0): 75 | if len(names) != len(resseq): 76 | raise ValueError( 77 | 'Names and resseq must be lists of the same length') 78 | self.names = names 79 | self.resseq = resseq 80 | 81 | if elements is None: 82 | # this may not be a good idea 83 | elements = [name[0] for name in self.names] 84 | self.elements = elements 85 | 86 | if not np.array_equal(sorted(resmap.keys()), np.unique(resseq)): 87 | raise ValueError( 88 | 'resmap dictionary must have a key for each index in resseq' 89 | ) 90 | self.resmap = resmap 91 | self.bonds = bonds 92 | self.starting_index = starting_index 93 | 94 | self.make_topology() 95 | 96 | def make_topology(self): 97 | """Generates an mdtraj.Topology object. 98 | 99 | Notes 100 | ----- 101 | Currently only implemented for a single chain. 102 | """ 103 | pd = md.utils.import_('pandas') 104 | data = [] 105 | for i, name in enumerate(self.names): 106 | row = (i + self.starting_index, name, self.elements[i], self.resseq[i], 107 | self.resmap[self.resseq[i]], 0, '') 108 | data.append(row) 109 | atoms = pd.DataFrame(data, 110 | columns=["serial", "name", "element", "resSeq", 111 | "resName", "chainID", "segmentID"]) 112 | if type(self.bonds) is str: 113 | if self.bonds == 'standard': 114 | top = md.Topology.from_dataframe(atoms, None) 115 | top.create_standard_bonds() 116 | else: 117 | raise ValueError( 118 | '{} is not an accepted option for bonds'.format(self.bonds) 119 | ) 120 | else: 121 | top = md.Topology.from_dataframe(atoms, self.bonds) 122 | 123 | self.top = top 124 | self.topology = top 125 | 126 | def make_trajectory(self, coordinates): 127 | """Generates an mdtraj.Trajectory object. 128 | 129 | Parameters 130 | ---------- 131 | coordinates : np.array 132 | Coordinate data of dimension [n_frames, n_atoms, n_dimensions], 133 | where n_dimensions must be 3. 134 | 135 | Notes 136 | ----- 137 | This is a bit of a hack, and the user is responsible for using 138 | care with this method and ensuring the resulting trajectory 139 | is the intended output. 140 | 141 | No unit cell information is specified.
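        Example
        -------
        # a minimal sketch, assuming `molecule` is an initialized
        # CGMolecule and `coordinates` is an np.array of shape
        # [n_frames, n_atoms, 3]
        traj = molecule.make_trajectory(coordinates)
        distances = md.compute_distances(traj, [(0, 2)])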
142 | """ 143 | if type(coordinates) is torch.Tensor: 144 | coordinates = coordinates.detach().numpy() 145 | 146 | if len(coordinates.shape) != 3: 147 | raise ValueError( 148 | 'coordinates shape must be [frames, atoms, dimensions]' 149 | ) 150 | if coordinates.shape[1] != self.top.n_atoms: 151 | raise ValueError( 152 | 'coordinates dimension 1 must be the number of atoms' 153 | ) 154 | if coordinates.shape[2] != 3: 155 | raise ValueError('coordinates must have 3 dimensions') 156 | 157 | # this is a hack; NOT recommended for actual use of mdtraj 158 | return md.core.trajectory.Trajectory(coordinates, self.top) 159 | -------------------------------------------------------------------------------- /cgnet/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .nnet import * 2 | from .priors import * 3 | from .simulation import * 4 | from .utils import * 5 | -------------------------------------------------------------------------------- /cgnet/network/nnet.py: -------------------------------------------------------------------------------- 1 | # Author: Nick Charron 2 | # Contributors: Brooke Husic, Dominik Lemm, Jiang Wang 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from .priors import ZscoreLayer, HarmonicLayer, RepulsionLayer 8 | from cgnet.feature import FeatureCombiner, SchnetFeature, GeometryFeature 9 | 10 | 11 | class ForceLoss(torch.nn.Module): 12 | """Loss function for force matching scheme.""" 13 | 14 | def __init__(self): 15 | super(ForceLoss, self).__init__() 16 | 17 | def forward(self, force, labels): 18 | """Returns force matching loss averaged over all examples. 19 | 20 | Parameters 21 | ---------- 22 | force : torch.Tensor (grad enabled) 23 | forces calculated from the CGnet energy via autograd. 24 | Size [n_frames, n_degrees_freedom]. 25 | labels : torch.Tensor 26 | forces to compute the loss against. Size [n_frames, 27 | n_degrees_of_freedom]. 28 | 29 | Returns 30 | ------- 31 | loss : torch.Variable 32 | example-averaged Frobenius loss from force matching. Size [1, 1]. 33 | 34 | """ 35 | loss = ((force - labels)**2).mean() 36 | return loss 37 | 38 | 39 | class CGnet(nn.Module): 40 | """CGnet neural network class 41 | 42 | Parameters 43 | ---------- 44 | arch : list of nn.Module() instances 45 | underlying sequential network architecture. 46 | criterion : nn.Module() instances 47 | loss function to be used for network. 48 | feature : nn.Module() instance 49 | feature layer to transform cartesian coordinates into roto- 50 | translationally invariant features. 51 | priors : list of nn.Module() instances (default=None) 52 | list of prior layers that provide energy contributions external to 53 | the hidden architecture of the CGnet. 54 | 55 | Notes 56 | ----- 57 | CGnets are a class of feedforward neural networks introduced by Wang et 58 | al. (2019) which are used to predict coarse-grained molecular force fields 59 | from Cartesian coordinate data. They are characterized by an autograd layer 60 | with respect to input coordinates implemented before the loss function, 61 | which directs the network to learn a representation of the coarse-grained 62 | potential of mean force (PMF) associated with a conservative coarse-grained 63 | force field via a gradient operation as prescribed by classical mechanics.
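    Concretely, the network outputs a scalar free energy U(x), and the
    force prediction is recovered through the autograd layer as

        F(x) = -\nabla_x U(x)

    so that the learned force field is conservative by construction.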
64 | CGnets may also contain featurization layers, which transform Cartesian 65 | inputs into roto-translationally invariant features, thereby yielding a PMF 66 | that respects these invariances. CGnets may additionally be supplied with 67 | external prior functions, which are useful for regularizing network behavior 68 | in sparsely sampled, unphysical regions of molecular configuration space. 69 | 70 | The inputs to CGnet (coordinates and forces) determine the units 71 | that are learned/used by CGnet. It is important to make sure that the units 72 | between the input coordinates and force labels are consistent with one 73 | another. These units must also be consistent with the interaction 74 | parameters for any specified priors. If one desires to use CGnet to make 75 | predictions in a different unit system, the predictions must be made using 76 | the original unit system, and then converted to the desired unit system 77 | outside of the CGnet. Otherwise, a new CGnet model must be trained using the 78 | desired units. 79 | 80 | Examples 81 | -------- 82 | From Wang et al. (2019), the optimal architecture for a 5-bead coarse 83 | grain model of alanine dipeptide, featurized into bonds, angles, pairwise 84 | distances, and backbone torsions, was found to be: 85 | 86 | CGnet( 87 | (input): in_features=30 88 | (arch): Sequential( 89 | (0): GeometryFeature(in_features=30, out_features=17) 90 | (1): Linear(in_features=17, out_features=160, bias=True) 91 | (2): Tanh() 92 | (3): Linear(in_features=160, out_features=160, bias=True) 93 | (4): Tanh() 94 | (5): Linear(in_features=160, out_features=160, bias=True) 95 | (6): Tanh() 96 | (7): Linear(in_features=160, out_features=160, bias=True) 97 | (8): Tanh() 98 | (9): Linear(in_features=160, out_features=160, bias=True) 99 | (10): Tanh() 100 | (11): Linear(in_features=160, out_features=1, bias=True) 101 | (12): HarmonicLayer(bonds) 102 | (13): HarmonicLayer(angles) 103 | (14): torch.autograd.grad(-((11) + (12) + (13)), input, 104 | create_graph=True, retain_graph=True) 105 | ) 106 | (criterion): ForceLoss() 107 | ) 108 | 109 | Mounting to GPU can be accomplished using the 'mount' method. For example, 110 | given an instance of CGnet and a torch.device, the model may be mounted in 111 | the following way: 112 | 113 | my_cuda = torch.device('cuda') 114 | model.mount(my_cuda) 115 | 116 | References 117 | ---------- 118 | Wang, J., Olsson, S., Wehmeyer, C., Pérez, A., Charron, N. E., 119 | de Fabritiis, G., Noé, F., Clementi, C. (2019). Machine Learning 120 | of Coarse-Grained Molecular Dynamics Force Fields. ACS Central Science.
121 | https://doi.org/10.1021/acscentsci.8b00913 122 | 123 | """ 124 | 125 | def __init__(self, arch, criterion, feature=None, priors=None): 126 | super(CGnet, self).__init__() 127 | zscore_idx = 1 128 | for layer in arch: 129 | if isinstance(layer, ZscoreLayer): 130 | self.register_buffer('zscores_{}'.format(zscore_idx), 131 | layer.zscores) 132 | 133 | zscore_idx += 1 134 | self.arch = nn.Sequential(*arch) 135 | if priors: 136 | self.priors = nn.Sequential(*priors) 137 | harm_idx = 1 138 | repul_idx = 1 139 | for layer in self.priors: 140 | if isinstance(layer, HarmonicLayer): 141 | self.register_buffer('harmonic_params_{}'.format(harm_idx), 142 | layer.harmonic_parameters) 143 | harm_idx += 1 144 | if isinstance(layer, RepulsionLayer): 145 | self.register_buffer('repulsion_params_{}'.format(repul_idx), 146 | layer.repulsion_parameters) 147 | repul_idx += 1 148 | else: 149 | self.priors = None 150 | self.criterion = criterion 151 | self.feature = feature 152 | 153 | def forward(self, coordinates, embedding_property=None): 154 | """Forward pass through the network ending with autograd layer. 155 | 156 | Parameters 157 | ---------- 158 | coordinates : torch.Tensor (grad enabled) 159 | input trajectory/data of size [n_frames, n_degrees_of_freedom]. 160 | embedding_property: torch.Tensor (default=None) 161 | Some property that should be embedded. Can be nuclear charge 162 | or maybe an arbitrary number assigned for amino-acids. 163 | Size [n_frames, n_properties] 164 | 165 | Returns 166 | ------- 167 | energy : torch.Tensor 168 | scalar potential energy of size [n_frames, 1]. If priors are 169 | supplied to the CGnet, then this energy is the sum of network 170 | and prior energies. 171 | force : torch.Tensor 172 | vector forces of size [n_frames, n_degrees_of_freedom]. 173 | 174 | Notes 175 | ----- 176 | If a dataset with variable molecule sizes is being used, it is 177 | important to mask the contributions from padded portions of 178 | the input into the neural network. This is done using the batchwise 179 | variable 'bead_mask' (shape [n_frames, n_beads]). 180 | This mask is used to set energy contributions from non-physical beads 181 | to zero through elementwise multiplication with the CGnet output for 182 | models using SchnetFeatures. 183 | """ 184 | 185 | if self.feature: 186 | # The below code adheres to the following logic: 187 | # 1. The feature_output is always what is passed to the model architecture 188 | # 2. The geom_feature is always what is passed to the priors 189 | # There will always be a feature_output, but sometimes it will 190 | # be the same as the geom_feature. There may be no geom_feature. 191 | if isinstance(self.feature, FeatureCombiner): 192 | # Right now, the only case we have is that a FeatureCombiner 193 | # with two Features will be a SchnetFeature followed by a 194 | # GeometryFeature. 195 | feature_output, geom_feature = self.feature(coordinates, 196 | embedding_property=embedding_property) 197 | if self.feature.propagate_geometry: 198 | # We only can use propagate_geometry if the feature_output is a 199 | # SchnetFeature 200 | schnet_feature = feature_output 201 | if geom_feature is None: 202 | raise RuntimeError( 203 | "There is no GeometryFeature to propagate. Was " \ 204 | "your FeatureCombiner a SchnetFeature only?"
205 | ) 206 | n_frames = coordinates.shape[0] 207 | schnet_feature = schnet_feature.reshape(n_frames, -1) 208 | concatenated_feature = torch.cat((schnet_feature, geom_feature), dim=1) 209 | energy = self.arch(concatenated_feature) 210 | else: 211 | energy = self.arch(feature_output) 212 | if len(energy.size()) == 3: 213 | # sum energy over beads 214 | energy = torch.sum(energy, axis=1) 215 | if not isinstance(self.feature, FeatureCombiner): 216 | if embedding_property is not None: 217 | # This assumes the only feature with an embedding_property 218 | # is a SchnetFeature. If other features can take embeddings, 219 | # this needs to be revisited. 220 | feature_output = self.feature( 221 | coordinates, embedding_property) 222 | geom_feature = None 223 | else: 224 | feature_output = self.feature(coordinates) 225 | geom_feature = feature_output 226 | energy = self.arch(feature_output) 227 | else: 228 | # Finally, if we pass only the coordinates with no pre-computed 229 | # Feature, then we call those coordinates the feature. We will 230 | # name this geom_feature because there may be priors on it. 231 | feature_output = coordinates 232 | geom_feature = coordinates 233 | energy = self.arch(feature_output) 234 | if self.priors: 235 | if geom_feature is None: 236 | raise RuntimeError( 237 | "Priors may only be used with GeometryFeatures or coordinates." 238 | ) 239 | for prior in self.priors: 240 | energy = energy + prior(geom_feature[:, prior.callback_indices]) 241 | # Sum up energies along bead axis for Schnet outputs and mask out 242 | # nonexisting beads 243 | if len(energy.size()) == 3 and isinstance(self.feature, SchnetFeature): 244 | # Make sure to mask those beads which are not physical. 245 | # Their contribution to the predicted energy and forces 246 | # should be zero 247 | bead_mask = torch.clamp(embedding_property, min=0, max=1).float() 248 | masked_energy = energy * bead_mask[:, :, None] 249 | energy = torch.sum(masked_energy, axis=-2) 250 | # Perform autograd to learn potential of conservative force field 251 | force = torch.autograd.grad(-torch.sum(energy), 252 | coordinates, 253 | create_graph=True, 254 | retain_graph=True) 255 | return energy, force[0] 256 | 257 | def mount(self, device): 258 | """Wrapper for device mounting 259 | 260 | Parameters 261 | ---------- 262 | device : torch.device 263 | Device upon which model can be mounted for computation/training 264 | """ 265 | 266 | # Buffers and parameters 267 | self.to(device) 268 | # Non parameters/buffers 269 | if self.feature: 270 | if isinstance(self.feature, FeatureCombiner): 271 | for layer in self.feature.layer_list: 272 | if isinstance(layer, (GeometryFeature, SchnetFeature)): 273 | layer.device = device 274 | layer.geometry.device = device 275 | if isinstance(layer, ZscoreLayer): 276 | layer.to(device) 277 | if isinstance(self.feature, (GeometryFeature, SchnetFeature)): 278 | self.feature.device = device 279 | self.feature.geometry.device = device 280 | 281 | def predict(self, coord, force_labels, embedding_property=None): 282 | """Prediction over test/validation batch. 283 | 284 | Parameters 285 | ---------- 286 | coord: torch.Tensor (grad enabled) 287 | input trajectory/data of size [n_frames, n_degrees_of_freedom] 288 | force_labels: torch.Tensor 289 | force labels of size [n_frames, n_degrees_of_freedom] 290 | embedding_property: torch.Tensor (default=None) 291 | Some property that should be embedded. Can be nuclear charge 292 | or maybe an arbitrary number assigned for amino-acids. 
293 |             Size [n_frames, n_properties]
294 |         Returns
295 |         -------
296 |         loss.data : torch.Tensor
297 |             loss over prediction inputs.
298 | 
299 |         """
300 | 
301 |         self.eval()  # set model to eval mode
302 |         energy, force = self.forward(coord,
303 |                                      embedding_property=embedding_property)
304 |         loss = self.criterion.forward(force, force_labels)
305 |         self.train()  # set model to train mode
306 |         return loss.data
307 | 
--------------------------------------------------------------------------------
/cgnet/network/priors.py:
--------------------------------------------------------------------------------
1 | # Author: Nick Charron
2 | # Contributors: Brooke Husic, Dominik Lemm, Jiang Wang, Simon Olsson
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | class _AbstractPriorLayer(nn.Module):
8 |     """Abstract layer for the definition of priors, which imposes only the
9 |     minimal functional constraints needed for model estimation and inference.
10 |     """
11 |     def __init__(self):
12 |         super(_AbstractPriorLayer, self).__init__()
13 |         self.callback_indices = slice(None, None)
14 | 
15 |     def forward(self, x):
16 |         """Forward method to compute the prior energy contribution.
17 | 
18 |         Notes
19 |         -----
20 |         This must be explicitly implemented in a child class that inherits from
21 |         _AbstractPriorLayer(). The details of this method should encompass the
22 |         mathematical steps to form each specific energy contribution to the
23 |         potential energy.
24 |         """
25 |         raise NotImplementedError(
26 |             'forward() method must be overridden in '
27 |             'custom classes inheriting from _AbstractPriorLayer()'
28 |         )
29 | 
30 | class _PriorLayer(_AbstractPriorLayer):
31 |     """Layer for adding prior energy computations external to CGnet hidden
32 |     output
33 | 
34 |     Parameters
35 |     ----------
36 |     callback_indices : list of int
37 |         indices used to access a specified subset of outputs from the feature
38 |         layer through a residual connection
39 | 
40 |     interaction_parameters : list of python dictionaries
41 |         list of dictionaries that specify the constants characterizing
42 |         interactions between beads. Each list element corresponds to a single
43 |         interaction using a dictionary of parameters keyed to corresponding
44 |         numerical values. The order of these dictionaries follows the same order
45 |         as the callback indices specifying which outputs from the feature layer
46 |         should pass through the prior. The structure of interaction_parameters
47 |         is the following:
48 | 
49 |         [ {'parameter_1' : 1.24, 'parameter_2' : 2.21, ... },
50 |           {'parameter_1' : 1.24, 'parameter_2' : 2.21, ... },
51 |           .
52 |           .
53 |           .
54 |           {'parameter_1' : 1.24, 'parameter_2' : 2.21, ... }]
55 | 
56 |         In this way, _PriorLayer may be subclassed to make arbitrary prior
57 |         layers based on arbitrary interactions between bead tuples.
58 | 
59 |     Attributes
60 |     ----------
61 |     interaction_parameters : list of dict
62 |         each list element contains a dictionary of physical parameters that
63 |         characterize the interaction of the associated beads. The order of
64 |         this list proceeds in the same order as self.callback_indices
65 |     callback_indices : list of int
66 |         indices used to access a specified subset of outputs from the feature
67 |         layer through a residual connection
68 | 
69 |     Examples
70 |     --------
71 |     To assemble the interaction_parameters input for a HarmonicLayer prior
72 |     for bonds from an instance of stats = GeometryStatistics():
73 | 
74 |     bonds_interactions, _ = stats.get_prior_statistics('Bonds', as_list=True)
75 |     bonds_idx = stats.return_indices('Bonds')
76 |     bond_layer = HarmonicLayer(bonds_idx, bonds_interactions)
77 | 
78 |     Notes
79 |     -----
80 |     callback_indices and interaction_parameters MUST share the same order for
81 |     the prior layer to produce correct energies. Using
82 |     GeometryStatistics.get_prior_statistics() with as_list=True together with
83 |     GeometryStatistics.return_indices() will ensure this is true for the same
84 |     list of features.
85 | 
86 |     The units of the interaction parameters for priors must correspond with the
87 |     units of the input coordinates and force labels used to train the CGnet.
88 | 
89 |     """
90 | 
91 |     def __init__(self, callback_indices, interaction_parameters):
92 |         super(_PriorLayer, self).__init__()
93 |         if len(callback_indices) != len(interaction_parameters):
94 |             raise ValueError(
95 |                 "callback_indices and interaction_parameters must have the same length"
96 |             )
97 |         self.interaction_parameters = interaction_parameters
98 |         self.callback_indices = callback_indices
99 | 
100 | class RepulsionLayer(_PriorLayer):
101 |     """Layer for calculating pairwise repulsion energy prior. Pairwise repulsion
102 |     energies are calculated using the following formula:
103 | 
104 |         U_repulsion_ij = (sigma_{ij} / r_{ij}) ^ exp_{ij}
105 | 
106 |     where U_repulsion_ij is the repulsion energy contribution from
107 |     coarse grain beads i and j, sigma_ij is the excluded volume parameter
108 |     between the pair (in units of distance), r_ij is the pairwise distance
109 |     (in units of distance) between coarse grain beads i and j, and exp_ij
110 |     is the repulsion exponent (dimensionless) that characterizes the
111 |     asymptotics of the interaction.
112 | 
113 |     Parameters
114 |     ----------
115 |     callback_indices : list of int
116 |         indices used to access a specified subset of outputs from the feature
117 |         layer through a residual connection
118 | 
119 |     interaction_parameters : list of python dictionaries
120 |         list of dictionaries that specify the constants characterizing
121 |         interactions between beads. Each list element corresponds to a single
122 |         interaction using a dictionary of parameters keyed to corresponding
123 |         numerical values. The order of these dictionaries follows the same order
124 |         as the callback indices specifying which outputs from the feature layer
125 |         should pass through the prior. The parameters for RepulsionLayer
126 |         dictionaries are 'ex_vol', the excluded volume (in length units), and
127 |         'exp', the (positive) exponent characterizing the repulsion strength
128 |         decay with distance.
129 | 
130 |     Attributes
131 |     ----------
132 |     repulsion_parameters : torch.Tensor
133 |         tensor of shape [2, num_interactions]. The first row contains the
134 |         excluded volumes, the second row contains the exponents, and each
135 |         column corresponds to a single interaction in the order determined
136 |         by self.callback_indices
137 | 
138 |     Notes
139 |     -----
140 |     This prior energy should be used for longer molecules that may possess
141 |     metastable states in which portions of the molecule that are separated by
142 |     many CG beads in sequence may nonetheless adopt close physical proximities.
143 |     Without this prior, it is possible for the CGnet to learn energies that do
144 |     not respect proper physical pairwise repulsions. The interaction is modeled
145 |     after the VDW interaction term from the classic Lennard-Jones potential.
146 | 
147 |     References
148 |     ----------
149 |     Wang, J., Olsson, S., Wehmeyer, C., Pérez, A., Charron, N. E.,
150 |     de Fabritiis, G., Noé, F., Clementi, C. (2019). Machine Learning
151 |     of Coarse-Grained Molecular Dynamics Force Fields. ACS Central Science.
152 |     https://doi.org/10.1021/acscentsci.8b00913
153 |     """
154 | 
155 |     def __init__(self, callback_indices, interaction_parameters):
156 |         super(RepulsionLayer, self).__init__(
157 |             callback_indices, interaction_parameters)
158 |         for param_dict in self.interaction_parameters:
159 |             if all(key in param_dict for key in ('ex_vol', 'exp')):
160 |                 pass
161 |             else:
162 |                 raise KeyError(
163 |                     'Missing or incorrect key for repulsion parameters'
164 |                 )
165 |         repulsion_parameters = torch.tensor([])
166 |         for param_dict in self.interaction_parameters:
167 |             repulsion_parameters = torch.cat((
168 |                 repulsion_parameters,
169 |                 torch.tensor([[param_dict['ex_vol']],
170 |                               [param_dict['exp']]])), dim=1)
171 |         self.register_buffer('repulsion_parameters', repulsion_parameters)
172 | 
173 |     def forward(self, in_feat):
174 |         """Calculates repulsion interaction contributions to energy
175 | 
176 |         Parameters
177 |         ----------
178 |         in_feat : torch.Tensor
179 |             input features, such as pairwise distances, of size (n, k), for
180 |             n examples and k features.
181 |         Returns
182 |         -------
183 |         energy : torch.Tensor
184 |             output energy of size (n, 1) for n examples.
185 |         """
186 | 
187 |         n = len(in_feat)
188 |         energy = torch.sum((self.repulsion_parameters[0, :]/in_feat)
189 |                            ** self.repulsion_parameters[1, :],
190 |                            1).reshape(n, 1) / 2
191 |         return energy
192 | 
193 | 
194 | class HarmonicLayer(_PriorLayer):
195 |     """Layer for calculating bond/angle harmonic energy prior. Harmonic energy
196 |     contributions have the following form:
197 | 
198 |         U_harmonic_{ij} = 0.5 * k_{ij} * ((r_{ij} - r_0_{ij}) ^ 2)
199 | 
200 |     where U_harmonic_ij is the harmonic energy contribution from
201 |     coarse grain beads i and j, k_ij is the harmonic spring constant
202 |     (in energy/distance**2) that characterizes the strength of the harmonic
203 |     interaction between coarse grain beads i and j, r_{ij} is the pairwise
204 |     distance (in distance units) between coarse grain beads i and j, and r_0_ij
205 |     is the equilibrium/average pairwise distance (in distance units) between
206 |     coarse grain beads i and j.
207 | 
208 |     Parameters
209 |     ----------
210 |     callback_indices : list of int
211 |         indices used to access a specified subset of outputs from the feature
212 |         layer through a residual connection
213 | 
214 |     interaction_parameters : list of python dictionaries
215 |         list of dictionaries that specify the constants characterizing
216 |         interactions between beads. Each list element corresponds to a single
217 |         interaction using a dictionary of parameters keyed to corresponding
218 |         numerical values. The order of these dictionaries follows the same order
219 |         as the callback indices specifying which outputs from the feature layer
220 |         should pass through the prior. The parameters for HarmonicLayer
221 |         dictionaries are 'mean', the center of the harmonic interaction
222 |         (in length or angle units), and 'k', the (positive) harmonic spring
223 |         constant (in units of energy / length**2 for bonds or energy / angle**2 for angles).
224 | 
225 |     Attributes
226 |     ----------
227 |     harmonic_parameters : torch.Tensor
228 |         tensor of shape [2, num_interactions]. The first row contains the
229 |         harmonic spring constants, the second row contains the mean positions,
230 |         and each column corresponds to a single interaction in the order
231 |         determined by self.callback_indices
232 | 
233 |     Notes
234 |     -----
235 |     This prior energy is useful for constraining the CGnet potential in regions
236 |     of configuration space in which sampling is normally precluded by physical
237 |     harmonic constraints associated with the structural integrity of the protein
238 |     along its backbone. The harmonic parameters are also easily estimated from
239 |     all atom simulation data because bond and angle distributions typically have
240 |     Gaussian structure, which is easily interpretable as a harmonic energy
241 |     contribution via the Boltzmann distribution.
242 | 
243 |     References
244 |     ----------
245 |     Wang, J., Olsson, S., Wehmeyer, C., Pérez, A., Charron, N. E.,
246 |     de Fabritiis, G., Noé, F., Clementi, C. (2019). Machine Learning
247 |     of Coarse-Grained Molecular Dynamics Force Fields. ACS Central Science.
248 |     https://doi.org/10.1021/acscentsci.8b00913
249 |     """
250 | 
251 |     def __init__(self, callback_indices, interaction_parameters):
252 |         super(HarmonicLayer, self).__init__(
253 |             callback_indices, interaction_parameters)
254 |         for param_dict in self.interaction_parameters:
255 |             if all(key in param_dict for key in ('k', 'mean')):
256 |                 if torch.isnan(param_dict['k']).any():
257 |                     raise ValueError(
258 |                         'Harmonic spring constant "k" contains NaNs. '
259 |                         'Check your parameters.'
260 |                     )
261 |                 if torch.isnan(param_dict['mean']).any():
262 |                     raise ValueError(
263 |                         'Center of the harmonic interaction "mean" contains '
264 |                         'NaNs. Check your parameters.'
265 |                     )
266 |             else:
267 |                 raise KeyError('Missing or incorrect key for harmonic parameters')
268 |         harmonic_parameters = torch.tensor([])
269 |         for param_dict in self.interaction_parameters:
270 |             harmonic_parameters = torch.cat((harmonic_parameters,
271 |                                              torch.tensor([[param_dict['k']],
272 |                                                            [param_dict['mean']]])),
273 |                                             dim=1)
274 |         self.register_buffer('harmonic_parameters', harmonic_parameters)
275 | 
276 |     def forward(self, in_feat):
277 |         """Calculates harmonic contribution of bond/angle interactions to energy
278 | 
279 |         Parameters
280 |         ----------
281 |         in_feat : torch.Tensor
282 |             input features, such as bond distances or angles of size (n, k), for
283 |             n examples and k features.
284 | 
285 |         Returns
286 |         -------
287 |         energy : torch.Tensor
288 |             output energy of size (n, 1) for n examples.
289 | 
290 |         """
291 | 
292 |         n = len(in_feat)
293 |         energy = torch.sum(self.harmonic_parameters[0, :] *
294 |                            (in_feat - self.harmonic_parameters[1, :]) ** 2,
295 |                            1).reshape(n, 1) / 2
296 |         return energy
297 | 
298 | 
299 | class ZscoreLayer(nn.Module):
300 |     """Layer for Zscore normalization. Zscore normalization involves
301 |     scaling features by their mean and standard deviation in the following
302 |     way:
303 | 
304 |         X_normalized = (X - X_avg) / sigma_X
305 | 
306 |     where X_normalized is the zscore-normalized feature, X is the original
307 |     feature, X_avg is the average value of the original feature, and sigma_X
308 |     is the standard deviation of the original feature.
309 | 
310 |     Parameters
311 |     ----------
312 |     zscores : torch.Tensor
313 |         [2, n_features] tensor, where the first row contains the means
314 |         and the second row contains the standard deviations of each
315 |         feature
316 | 
317 |     Notes
318 |     -----
319 |     Zscore normalization can accelerate training convergence if placed
320 |     after a GeometryFeature() layer, especially if the input features
321 |     span different orders of magnitude, such as the combination of angles
322 |     and distances.
323 | 
324 |     For more information, see the documentation for
325 |     sklearn.preprocessing.StandardScaler
326 | 
327 |     """
328 | 
329 |     def __init__(self, zscores):
330 |         super(ZscoreLayer, self).__init__()
331 |         self.register_buffer('zscores', zscores)
332 | 
333 |     def forward(self, in_feat):
334 |         """Normalizes each feature by subtracting its mean and dividing by
335 |         its standard deviation.
336 | 
337 |         Parameters
338 |         ----------
339 |         in_feat : torch.Tensor
340 |             input data of shape [n_frames, n_features]
341 | 
342 |         Returns
343 |         -------
344 |         rescaled_feat : torch.Tensor
345 |             Zscore normalized features. Shape [n_frames, n_features]
346 | 
347 |         """
348 |         rescaled_feat = (in_feat - self.zscores[0, :])/self.zscores[1, :]
349 |         return rescaled_feat
350 | 
--------------------------------------------------------------------------------
/cgnet/network/utils.py:
--------------------------------------------------------------------------------
1 | # Authors: Nick Charron, Brooke Husic, Jiang Wang
2 | # Contributors: Dominik Lemm
3 | 
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.data import DataLoader, Dataset
7 | import numpy as np
8 | import warnings
9 | 
10 | from cgnet.feature import (GeometryFeature, SchnetFeature, FeatureCombiner,
11 |                            MultiMoleculeDataset)
12 | 
13 | 
14 | def _schnet_feature_linear_extractor(schnet_feature, return_weight_data_only=False):
15 |     """Helper function to extract instances of nn.Linear from a SchnetFeature
16 | 
17 |     Parameters
18 |     ----------
19 |     schnet_feature : SchnetFeature instance
20 |         The SchnetFeature instance from which nn.Linear instances will be
21 |         extracted.
22 |     return_weight_data_only : bool (default=False)
23 |         If True, the function returns the torch tensor for each weight
24 |         layer rather than the nn.Linear instance.
25 | 
26 |     Returns
27 |     -------
28 |     linear_list : list of nn.Linear instances or np.arrays
29 |         The list of nn.Linear layers extracted from the supplied
30 |         SchnetFeature. See notes below for the order of nn.Linear instances
31 |         in this list.
32 |     weight_data : list of torch.Tensors
33 |         If 'return_weight_data_only=True', the function instead returns the
34 |         torch tensors of each nn.Linear instance weight. See notes below for
35 |         the order of tensors in this list.
36 | 
37 |     Notes
38 |     -----
39 |     Each InteractionBlock contains nn.Linear instances in the following order:
40 | 
41 |     1. initial_dense layer
42 |     2. cfconv.filter_generator layer 1
43 |     3. cfconv.filter_generator layer 2
44 |     4. output layer 1
45 |     5. output layer 2
46 | 
47 |     This gives five linear layers in total per InteractionBlock. The nn.Linear
48 |     instances are returned by _schnet_feature_linear_extractor() in this order.
49 |     This is a hardcoded choice, because we assume that the architectural structure
50 |     of all InteractionBlocks is exactly the same (i.e., 1-5 above).
51 |     """
52 | 
53 |     linear_list = []
54 |     for block in schnet_feature.interaction_blocks:
55 |         for block_layer in [block.initial_dense,
56 |                             block.cfconv.filter_generator,
57 |                             block.output_dense]:
58 |             linear_list += [layer for layer in block_layer
59 |                             if isinstance(layer, nn.Linear)]
60 |     if return_weight_data_only:
61 |         weight_data = [layer.weight.data for layer in linear_list]
62 |         return weight_data
63 |     else:
64 |         return linear_list
65 | 
66 | 
67 | def lipschitz_projection(model, strength=10.0, network_mask=None, schnet_mask=None):
68 |     """Performs L2 Lipschitz Projection via spectral normalization
69 | 
70 |     Parameters
71 |     ----------
72 |     model : cgnet.network.CGnet() instance
73 |         model to perform Lipschitz projection upon
74 |     strength : float (default=10.0)
75 |         Strength of L2 lipschitz projection via spectral normalization.
76 |         The magnitude of {dominant weight matrix eigenvalue / strength}
77 |         is compared to unity, and the weight matrix is rescaled by the max
78 |         of this comparison
79 |     network_mask : None, 'all', or list of bool (default=None)
80 |         mask used to exclude certain terminal network layers from lipschitz
81 |         projection. If an element is False, the corresponding weight layer
82 |         is exempt from a lipschitz projection. If set to 'all', a False mask
83 |         is used for all terminal network weights. If None, all terminal network
84 |         weight layers are subject to Lipschitz constraint.
85 |     schnet_mask : None, 'all', or list of bool (default=None)
86 |         mask used to exclude certain SchnetFeature layers from lipschitz projection.
87 |         If an element is False, the corresponding weight layer is exempt from a
88 |         lipschitz projection. The linear layers of a SchnetFeature InteractionBlock
89 |         have the following arrangement:
90 | 
91 |         1. initial_dense layer
92 |         2. cfconv.filter_generator layer 1
93 |         3. cfconv.filter_generator layer 2
94 |         4. output layer 1
95 |         5. output layer 2
96 | 
97 |         that is, each InteractionBlock contains 5 nn.Linear instances. If set
98 |         to 'all', a False mask is used for all weight layers in every
99 |         InteractionBlock. If None, all weight layers are subject to Lipschitz
100 |         constraint.
101 | 
102 |     Notes
103 |     -----
104 |     L2 Lipschitz regularization is a per-layer regularization that constrains
105 |     the Lipschitz constant of each mapping from one linear layer to the next.
106 |     As formulated by Gouk et al. (2018), this constraint can be enforced by
107 |     comparing the magnitudes between the weighted dominant singular value of
108 |     the linear layer weight matrix and unity, taking the maximum, and
109 |     normalizing the weight matrix by this result:
110 | 
111 |         W = W / max( s_dom / lambda, 1.0 )
112 | 
113 |     for weight matrix W, dominant singular value s_dom, and regularization
114 |     strength lambda. In this form, a strong regularization is achieved for
115 |     lambda -> 0, and a weak regularization is achieved for lambda -> inf.
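
    As a worked example with illustrative numbers (not taken from any real
    training run): for strength lambda = 10, a weight matrix whose dominant
    singular value is s_dom = 25 is rescaled by max(25 / 10, 1.0) = 2.5,
    while a weight matrix with s_dom = 4 is left unchanged, since
    max(4 / 10, 1.0) = 1.0.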
116 | 
117 |     For nn.Linear weights that exist in SchnetFeatures (in the form of dense
118 |     layers in InteractionBlocks and dense layers in the continuous filter
119 |     convolutions), we assume that the architectural structure of all
120 |     InteractionBlocks (and the continuous filter convolutions therein) is
121 |     fixed to be the same - that is, the nn.Linear instances always appear
122 |     in SchnetFeatures in the following fixed order:
123 | 
124 |     1. initial_dense layer
125 |     2. cfconv.filter_generator layer 1
126 |     3. cfconv.filter_generator layer 2
127 |     4. output layer 1
128 |     5. output layer 2
129 | 
130 |     References
131 |     ----------
132 |     Gouk, H., Frank, E., Pfahringer, B., & Cree, M. (2018). Regularisation
133 |     of Neural Networks by Enforcing Lipschitz Continuity. arXiv:1804.04368
134 |     [Cs, Stat]. Retrieved from http://arxiv.org/abs/1804.04368
135 |     """
136 | 
137 |     # Grab all instances of nn.Linear in the model, including those
138 |     # that are part of SchnetFeatures
139 |     # First, we grab the instances of nn.Linear from model.arch
140 |     network_weight_layers = [layer for layer in model.arch
141 |                              if isinstance(layer, nn.Linear)]
142 |     # Next, we grab the nn.Linear instances from the SchnetFeature
143 |     schnet_weight_layers = []
144 |     # if it is part of a FeatureCombiner instance
145 |     if isinstance(model.feature, FeatureCombiner):
146 |         for feature in model.feature.layer_list:
147 |             if isinstance(feature, SchnetFeature):
148 |                 schnet_weight_layers += _schnet_feature_linear_extractor(
149 |                     feature)
150 |     # Lastly, we handle the case of SchnetFeatures that are not part of
151 |     # a FeatureCombiner instance
152 |     elif isinstance(model.feature, SchnetFeature):
153 |         schnet_weight_layers += _schnet_feature_linear_extractor(model.feature)
154 | 
155 |     # Next, we assemble a (possibly combined from terminal network and
156 |     # SchnetFeature) mask
157 |     if network_mask is None:
158 |         network_mask = [True for _ in network_weight_layers]
159 |     elif network_mask == 'all':
160 |         network_mask = [False for _ in network_weight_layers]
161 |     if network_mask is not None:
162 |         if not isinstance(network_mask, list):
163 |             raise ValueError("Lipschitz network mask must be list of booleans")
164 |         if len(network_weight_layers) != len(network_mask):
165 |             raise ValueError("Lipschitz network mask must have the same number "
166 |                              "of elements as the number of nn.Linear "
167 |                              "modules in the model.arch attribute.")
168 | 
169 |     if schnet_mask is None:
170 |         schnet_mask = [True for _ in schnet_weight_layers]
171 |     elif schnet_mask == 'all':
172 |         schnet_mask = [False for _ in schnet_weight_layers]
173 |     if schnet_mask is not None:
174 |         if not isinstance(schnet_mask, list):
175 |             raise ValueError("Lipschitz schnet mask must be list of booleans")
176 |         if len(schnet_weight_layers) != len(schnet_mask):
177 |             raise ValueError("Lipschitz schnet mask must have the same number "
178 |                              "of elements as the number of nn.Linear "
179 |                              "modules in the model SchnetFeature.")
180 | 
181 |     full_mask = network_mask + schnet_mask
182 |     full_weight_layers = network_weight_layers + schnet_weight_layers
183 |     for mask_element, layer in zip(full_mask, full_weight_layers):
184 |         if mask_element:
185 |             weight = layer.weight.data
186 |             u, s, v = torch.svd(weight)
187 |             if next(model.parameters()).is_cuda:
188 |                 device = weight.device
189 |                 lip_reg = torch.max(((s[0]) / strength),
190 |                                     torch.tensor([1.0]).to(device))
191 |             else:
192 |                 lip_reg = torch.max(((s[0]) / strength),
193 |                                     torch.tensor([1.0]))
194 |             layer.weight.data = weight / lip_reg
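
# A minimal usage sketch (hypothetical `model`, assumed to be a CGnet whose
# arch contains three nn.Linear layers and whose feature holds a single
# InteractionBlock):
#
#     lipschitz_projection(model, strength=10.0,
#                          network_mask=[True, True, False],
#                          schnet_mask='all')
#
# Here the first two terminal weight matrices are rescaled in place, the
# third terminal layer is exempted, and every SchnetFeature dense layer is
# exempted via the 'all' schnet mask.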
195 | 
196 | 
197 | def dataset_loss(model, loader, optimizer=None,
198 |                  regularization_function=None,
199 |                  train_mode=True,
200 |                  verbose_interval=None,
201 |                  print_function=None):
202 |     r"""Compute average loss over arbitrary data loader.
203 |     This can be used during testing, in which `optimizer` and
204 |     `regularization_function` will remain None, or it can be used
205 |     during training, in which an optimizer and (optional)
206 |     regularization_function are provided.
207 | 
208 |     Parameters
209 |     ----------
210 |     model : cgnet.network.CGnet() instance
211 |         model to calculate loss
212 |     loader : torch.utils.data.DataLoader() instance
213 |         loader (with associated dataset)
214 |     optimizer : torch.optim method or None (default=None)
215 |         If not None, the optimizer will be zeroed and stepped for each batch.
216 |     regularization_function : in-place function or None (default=None)
217 |         If not None, the regularization function will be applied after
218 |         stepping the optimizer. It must take only "model" as its input
219 |         and operate in-place.
220 |     train_mode : bool (default=True)
221 |         Specifies whether to put the model into train mode for training/learning
222 |         or eval mode for testing/inference. See Notes about the important
223 |         distinction between these two modes. The model will always be reverted
224 |         back to training mode.
225 |     verbose_interval : integer or None (default=None)
226 |         If not None, a printout of the batch number and loss will be provided
227 |         at the specified interval (with respect to batch number).
228 |     print_function : python function or None (default=None)
229 |         Print function that takes (batch_number, batch_loss) as its only
230 |         two arguments, to print updates with our default or the style of
231 |         your choice when verbose_interval is not None.
232 | 
233 |     Returns
234 |     -------
235 |     loss : float
236 |         loss computed over the entire dataset. If the last batch consists of a
237 |         smaller set of leftover examples, its contribution to the loss is
238 |         weighted by the ratio of the number of elements in its MSE matrix to
239 |         the normal number of elements associated with the loader's batch size
240 |         before summation to a scalar.
241 | 
242 |     Example
243 |     -------
244 |     from torch.utils.data import DataLoader
245 | 
246 |     # assume 'net' is a CGnet object
247 | 
248 |     # For test data, no optimizer or regularization are needed
249 |     test_data_loader = DataLoader(test_data, batch_size=batch_size)
250 |     test_loss = dataset_loss(net, test_data_loader)
251 | 
252 |     # For training data, an optimizer is needed. Regularization may
253 |     # be used, too
254 |     training_data_loader = DataLoader(training_data, batch_size=batch_size)
255 |     optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)
256 | 
257 |     # Regularization must be in place
258 |     def my_reg_fxn(model, strength=lipschitz_strength):
259 |         lipschitz_projection(model, strength=strength)
260 | 
261 |     def my_print_fxn(batch_num, batch_loss):
262 |         print("--> Batch #{}, loss = {}".format(batch_num, batch_loss))
263 | 
264 |     training_loss = dataset_loss(net, training_data_loader,
265 |                                  optimizer=optimizer,
266 |                                  regularization_function=my_reg_fxn,
267 |                                  train_mode=True,
268 |                                  verbose_interval=128,
269 |                                  print_function=my_print_fxn)
270 | 
271 |     Notes
272 |     -----
273 |     This method assumes that if there is a smaller batch, it will be at the
274 |     end: namely, we assume that the size of the first batch is the largest
275 |     batch size.
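
    As an illustration with made-up numbers: for a dataset of 5-bead
    coordinates, a full batch of 32 frames has coords.numel() = 32 * 5 * 3
    = 480, while a final leftover batch of 8 frames has numel 120. The
    final batch therefore enters the average with batch_weight
    = 120 / 480 = 0.25.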
276 | 
277 |     It is important to use train_mode=False when performing inference/assessing
278 |     a model on test data because certain PyTorch layer types, such as
279 |     BatchNorm1d and Dropout, behave differently in 'eval' and 'train' modes.
280 |     For more information, please see
281 | 
282 |     https://pytorch.org/docs/stable/nn.html#torch.nn.Module.eval
283 | 
284 |     """
285 |     if optimizer is None:
286 |         if regularization_function is not None:
287 |             raise RuntimeError(
288 |                 "regularization_function is only used when there is an optimizer, "
289 |                 "but you have optimizer=None."
290 |             )
291 |         if train_mode:
292 |             raise RuntimeError(
293 |                 "Without an optimizer, you probably wanted train_mode=False"
294 |             )
295 | 
296 |     if train_mode:
297 |         model.train()
298 |     else:
299 |         model.eval()
300 | 
301 |     loss = 0
302 |     effective_number_of_batches = 0
303 | 
304 |     for batch_num, batch_data in enumerate(loader):
305 |         if optimizer is not None:
306 |             optimizer.zero_grad()
307 | 
308 |         coords, force, embedding_property = batch_data
309 |         if batch_num == 0:
310 |             reference_batch_size = coords.numel()
311 | 
312 |         batch_weight = coords.numel() / reference_batch_size
313 |         if batch_weight > 1:
314 |             raise ValueError(
315 |                 "The first batch was not the largest batch, so you cannot use "
316 |                 "dataset loss."
317 |             )
318 |         if (isinstance(loader.dataset, MultiMoleculeDataset) or
319 |                 loader.dataset.embeddings is not None):
320 |             potential, predicted_force = model.forward(coords,
321 |                                                        embedding_property=embedding_property)
322 |         else:
323 |             potential, predicted_force = model.forward(coords)
324 | 
325 |         batch_loss = model.criterion(predicted_force, force)
326 | 
327 |         if optimizer is not None:
328 |             batch_loss.backward()
329 |             optimizer.step()
330 | 
331 |             if regularization_function is not None:
332 |                 regularization_function(model)
333 | 
334 |         if verbose_interval is not None:
335 |             if (batch_num + 1) % verbose_interval == 0:
336 |                 if print_function is None:
337 |                     print("Batch: {}, Loss: {:.2f}".format(batch_num + 1,
338 |                                                            batch_loss))
339 |                 else:
340 |                     print_function(batch_num + 1, batch_loss)
341 | 
342 |         loss += batch_loss.cpu().detach().numpy() * batch_weight
343 | 
344 |         effective_number_of_batches += batch_weight
345 | 
346 |     loss /= effective_number_of_batches
347 | 
348 |     # If the model was put into eval mode, revert it to training mode
349 |     if not model.training:
350 |         model.train()
351 | 
352 |     return loss
353 | 
--------------------------------------------------------------------------------
/cgnet/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coarse-graining/cgnet/a3e0e8ddc06f4b6a9f48f4886b73b4cf372ff481/cgnet/tests/__init__.py
--------------------------------------------------------------------------------
/cgnet/tests/test_divergences.py:
--------------------------------------------------------------------------------
1 | # Authors: Brooke Husic and Nick Charron
2 | # Contributors: Dominik Lemm
3 | 
4 | import numpy as np
5 | 
6 | from cgnet.feature import (kl_divergence, js_divergence,
7 |                            discrete_distribution_intersection)
8 | 
9 | 
10 | def _get_random_distr():
11 |     # This function produces two random distributions upon which
12 |     # comparisons, overlaps, and divergences can be calculated
13 | 
14 |     length = np.random.randint(1, 50)  # Range of distribution
15 |     n_zeros = np.random.randint(1, 10)  # Number of bins with zero counts
16 |     zeros = np.zeros(n_zeros)  # corresponding array of zero count bins
17 | 
18 |     # Here, we create
two distributions, and then shuffle the bins 19 | # so that the zero count bins are distributed randomly along the 20 | # distribution extent 21 | distribution_1 = np.abs(np.concatenate( 22 | [np.random.randn(length).astype(np.float64), zeros])) 23 | distribution_2 = np.abs(np.concatenate( 24 | [np.random.randn(length).astype(np.float64), zeros])) 25 | np.random.shuffle(distribution_1) 26 | np.random.shuffle(distribution_2) 27 | return distribution_1, distribution_2 28 | 29 | 30 | def _get_uniform_histograms(): 31 | # This function produces two histograms sampled from uniform 32 | # distributions, returning the corresponding bins as well 33 | nbins = np.random.randint(2, high=50) # Random number of bins 34 | _bins = np.linspace(0, 1, nbins) # Equally space bins 35 | 36 | # Here, we produce the two histogram/bin pairs. We explicitly use 37 | # the same _bins, so we check that bins_1 and bins_2 are both equal 38 | # to our specified _bins post histogram creation. 39 | histogram_1, bins_1 = np.histogram(np.random.uniform(size=nbins).astype(np.float64), 40 | bins=_bins, 41 | density=True) 42 | histogram_2, bins_2 = np.histogram(np.random.uniform(size=nbins).astype(np.float64), 43 | bins=_bins, 44 | density=True) 45 | 46 | # We verify that that the two bin arrays are the same as each other and 47 | # as our input _bins. This is necessary for proper histogram comparison, 48 | # which is done bin-wise 49 | np.testing.assert_array_equal(bins_1, bins_2) 50 | np.testing.assert_array_equal(bins_1, _bins) 51 | 52 | # Normalize histograms to create discrete distributions 53 | histogram_1 /= np.sum(histogram_1) 54 | histogram_2 /= np.sum(histogram_2) 55 | 56 | # Create bins spaced by 1 57 | bins_1 *= len(histogram_1) 58 | 59 | return histogram_1, histogram_2, bins_1 60 | 61 | 62 | # We use the above two functions to generate random distributions 63 | # and histogram/bin pairs suitable for comparison using CGnet feature tools 64 | dist1, dist2 = _get_random_distr() 65 | hist1, hist2, bins = _get_uniform_histograms() 66 | 67 | 68 | def test_zero_kl_divergence(): 69 | # Tests the calculation of KL divergence for a random distribution from 70 | # zeros with itself. The KL divergence of a distribution with itself 71 | # is equal to zero 72 | div = kl_divergence(dist1, dist1) 73 | np.testing.assert_allclose(div, 0.) 74 | 75 | 76 | def test_kl_divergence(): 77 | # Tests the calculation of KL divergence for two random distributions with 78 | # zeros using a manual calculation 79 | manual_div = 0. # accumulator for the KL divergence 80 | 81 | # Loop through the bins of distribution 1 and accumulate the KL divergence 82 | for i, entry in enumerate(dist1): 83 | if dist1[i] > 0 and dist2[i] > 0: 84 | manual_div += entry * np.log(entry / dist2[i]) 85 | 86 | # Here we verify that the manual calculation above matches the same produced 87 | # by the kl_divergence function 88 | cgnet_div = kl_divergence(dist1, dist2) 89 | np.testing.assert_allclose(manual_div, cgnet_div) 90 | 91 | 92 | def test_zero_js_divergence(): 93 | # Tests the calculation of JS divergence for a random distribution from 94 | # zeros with itself. The JS divergence of a distribution with itself is 95 | # equal to zero 96 | div = js_divergence(dist1, dist1) 97 | np.testing.assert_allclose(div, 0.) 
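
# Reference note for the JS tests below: the Jensen-Shannon divergence is
# built from the KL divergence as
#
#     JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M),  where M = 0.5 * (P + Q)
#
# which is exactly what the manual calculations in the next two tests compute.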
98 | 
99 | 
100 | def test_js_divergence():
101 |     # Tests the calculation of JS divergence for two random distributions with
102 |     # zeros using a manual calculation
103 | 
104 |     # Here, we mask those mutual bins where the count multiplication
105 |     # is null, as these do not contribute to the JS divergence
106 |     dist1_masked = np.ma.masked_where(dist1 * dist2 == 0, dist1)
107 |     dist2_masked = np.ma.masked_where(dist1 * dist2 == 0, dist2)
108 | 
109 |     # Here we produce the elementwise mean of the masked distributions
110 |     # for calculating the JS divergence
111 |     elementwise_mean = 0.5 * (dist1_masked + dist2_masked)
112 | 
113 |     manual_div_1 = 0.  # accumulator for the first divergence
114 |     # Here, we loop through the bins of the first distribution and calculate
115 |     # the divergence
116 |     for i, entry in enumerate(dist1):
117 |         if dist1[i] > 0 and elementwise_mean[i] > 0:
118 |             manual_div_1 += entry * np.log(entry / elementwise_mean[i])
119 | 
120 |     manual_div_2 = 0.  # accumulator for the second divergence
121 |     # Here, we loop through the bins of the second distribution and calculate
122 |     # the divergence
123 |     for i, entry in enumerate(dist2):
124 |         if dist2[i] > 0 and elementwise_mean[i] > 0:
125 |             manual_div_2 += entry * np.log(entry / elementwise_mean[i])
126 | 
127 |     # Manual calculation of the JS divergence
128 |     manual_div = np.mean([manual_div_1, manual_div_2])
129 | 
130 |     # Here, we verify that the manual calculation matches the
131 |     # output of the js_divergence function
132 |     cgnet_div = js_divergence(dist1, dist2)
133 |     np.testing.assert_allclose(manual_div, cgnet_div)
134 | 
135 | 
136 | def test_js_divergence_2():
137 |     # Tests the calculation of JS divergence for two random distributions with
138 |     # zeros using masked arrays
139 | 
140 |     # This is the same test as test_js_divergence, just done using numpy
141 |     # operations rather than loops
142 |     dist1_masked = np.ma.masked_where(dist1 == 0, dist1)
143 |     dist2_masked = np.ma.masked_where(dist2 == 0, dist2)
144 |     elementwise_mean = 0.5 * (dist1_masked + dist2_masked)
145 |     summand = 0.5 * (dist1_masked * np.ma.log(dist1_masked / elementwise_mean))
146 |     summand += 0.5 * \
147 |         ((dist2_masked * np.ma.log(dist2_masked / elementwise_mean)))
148 |     manual_div = np.ma.sum(summand)
149 | 
150 |     cgnet_div = js_divergence(dist1, dist2)
151 |     np.testing.assert_allclose(manual_div, cgnet_div)
152 | 
153 | 
154 | def test_full_discrete_distr_intersection():
155 |     # Tests the intersection of a uniform histogram with itself
156 |     # The intersection of any histogram with itself should be unity
157 | 
158 |     cgnet_intersection = discrete_distribution_intersection(hist1, hist1)
159 |     np.testing.assert_allclose(cgnet_intersection, 1.)
160 | 
161 |     cgnet_intersection_bins = discrete_distribution_intersection(hist1, hist1,
162 |                                                                  bins)
163 |     np.testing.assert_allclose(cgnet_intersection_bins, 1.)
164 | 
165 | 
166 | def test_discrete_distr_intersection():
167 |     # Tests the calculation of intersection for histograms drawn from
168 |     # uniform distributions
169 | 
170 |     manual_intersection = 0.  # intersection accumulator
171 |     intervals = np.diff(bins)  # intervals between histogram bins
172 | 
173 |     # Here we loop through the common histogram intervals and accumulate
174 |     # the intersection of the two histograms
175 |     for i in range(len(intervals)):
176 |         manual_intersection += min(intervals[i] * hist1[i],
177 |                                    intervals[i] * hist2[i])
178 | 
179 |     # Here we verify that the manual calculation matches the output of
180 |     # the discrete_distribution_intersection function
181 |     cgnet_intersection = discrete_distribution_intersection(hist1, hist2)
182 |     np.testing.assert_allclose(manual_intersection, cgnet_intersection)
183 | 
184 | 
185 | def test_discrete_distr_intersection_no_bins():
186 |     # Tests the calculation of intersection for histograms drawn from
187 |     # uniform distributions. The histogram intersection should fill in
188 |     # the bins uniformly if none are supplied.
189 | 
190 |     hist1n = hist1 / np.sum(hist1)
191 |     hist2n = hist2 / np.sum(hist2)
192 | 
193 |     cgnet_intersection = discrete_distribution_intersection(
194 |         hist1n, hist2n, bins)
195 |     nobins_intersection = discrete_distribution_intersection(hist1, hist2,
196 |                                                              bin_edges=None)
197 |     np.testing.assert_allclose(cgnet_intersection, nobins_intersection)
198 | 
--------------------------------------------------------------------------------
/cgnet/tests/test_feature_utils.py:
--------------------------------------------------------------------------------
1 | # Author: Dominik Lemm
2 | # Contributors: Brooke Husic, Nick Charron
3 | 
4 | import numpy as np
5 | import torch
6 | from nose.tools import raises
7 | 
8 | from cgnet.feature.utils import (GaussianRBF, PolynomialCutoffRBF,
9 |                                  ShiftedSoftplus, _AbstractRBFLayer)
10 | from cgnet.feature.statistics import GeometryStatistics
11 | from cgnet.feature.feature import GeometryFeature, Geometry
12 | 
13 | 
14 | # Define sizes for a pseudo-dataset
15 | frames = np.random.randint(10, 30)
16 | beads = np.random.randint(5, 10)
17 | g = Geometry(method='torch')
18 | 
19 | 
20 | @raises(NotImplementedError)
21 | def test_radial_basis_function_len():
22 |     # Make sure that a NotImplementedError is raised if an RBF layer
23 |     # does not have a __len__() method
24 | 
25 |     # Here, we use the _AbstractRBFLayer base class as our RBF
26 |     abstract_RBF = _AbstractRBFLayer()
27 | 
28 |     # Next, we check to see if the NotImplementedError is raised
29 |     # This is done using the decorator above, because we cannot
30 |     # use nose.tools.assert_raises directly on special methods
31 |     len(abstract_RBF)
32 | 
33 | 
34 | def test_radial_basis_function():
35 |     # Make sure radial basis functions are consistent with manual calculation
36 | 
37 |     # Distances need to have shape (n_batch, n_beads, n_neighbors)
38 |     distances = torch.randn((frames, beads, beads - 1), dtype=torch.float64)
39 |     # Define random parameters for the RBF
40 |     variance = np.random.random() + 1
41 |     n_gaussians = np.random.randint(5, 10)
42 |     high_cutoff = np.random.uniform(5.0, 10.0)
43 |     low_cutoff = np.random.uniform(0.0, 4.0)
44 | 
45 |     # Calculate Gaussian expansion using the implemented layer
46 |     rbf = GaussianRBF(high_cutoff=high_cutoff, low_cutoff=low_cutoff,
47 |                       n_gaussians=n_gaussians, variance=variance)
48 |     gauss_layer = rbf.forward(distances)
49 | 
50 |     # Manually calculate expansion with numpy
51 |     # according to the following formula:
52 |     # e_k (r_j - r_i) = exp(- \gamma (\left \| r_j - r_i \right \| - \mu_k)^2)
53 |     # with centers mu_k calculated on a uniform grid between
54 |     # zero and the distance cutoff and gamma as a scaling
parameter. 55 | centers = np.linspace(low_cutoff, high_cutoff, 56 | n_gaussians).astype(np.float64) 57 | gamma = -0.5 / variance 58 | distances = np.expand_dims(distances, axis=3) 59 | magnitude_squared = (distances - centers)**2 60 | gauss_manual = np.exp(gamma * magnitude_squared) 61 | 62 | # Shapes and values need to be the same 63 | np.testing.assert_equal(centers.shape, rbf.centers.shape) 64 | np.testing.assert_allclose(gauss_layer.numpy(), gauss_manual, rtol=1e-5) 65 | 66 | 67 | def test_radial_basis_function_distance_masking(): 68 | # Makes sure that if a distance mask is used, the corresponding 69 | # expanded distances returned by GaussianRBF are zero 70 | 71 | # Distances need to have shape (n_batch, n_beads, n_neighbors) 72 | distances = torch.randn((frames, beads, beads - 1), dtype=torch.float64) 73 | # Define random parameters for the RBF 74 | variance = np.random.random() + 1 75 | high_cutoff = np.random.uniform(5.0, 10.0) 76 | low_cutoff = np.random.uniform(0.0, 4.0) 77 | n_gaussians = np.random.randint(5, 10) 78 | neighbor_cutoff = np.abs(np.random.rand()) 79 | neighbors, neighbor_mask = g.get_neighbors(distances, 80 | cutoff=neighbor_cutoff) 81 | 82 | # Calculate Gaussian expansion using the implemented layer 83 | rbf = GaussianRBF(high_cutoff=high_cutoff, low_cutoff=low_cutoff, 84 | n_gaussians=n_gaussians, variance=variance) 85 | gauss_layer = rbf.forward(distances, distance_mask=neighbor_mask) 86 | 87 | # Lastly, we check to see that the application of the mask is correct 88 | # against a manual calculation and masking 89 | centers = np.linspace(low_cutoff, high_cutoff, n_gaussians) 90 | gamma = -0.5 / variance 91 | distances = np.expand_dims(distances, axis=3) 92 | magnitude_squared = (distances - centers)**2 93 | gauss_manual = torch.tensor(np.exp(gamma * magnitude_squared)) 94 | gauss_manual = gauss_manual * neighbor_mask[:, :, :, None].double() 95 | 96 | np.testing.assert_array_almost_equal(gauss_layer.numpy(), 97 | gauss_manual.numpy()) 98 | 99 | 100 | def test_radial_basis_function_normalize(): 101 | # Tests to make sure that the output of GaussianRBF is properly 102 | # normalized if 'normalize_output' is specified as True 103 | 104 | # Distances need to have shape (n_batch, n_beads, n_neighbors) 105 | distances = torch.randn((frames, beads, beads - 1), dtype=torch.float64) 106 | # Define random parameters for the RBF 107 | variance = np.random.random() + 1 108 | n_gaussians = np.random.randint(5, 10) 109 | high_cutoff = np.random.uniform(5.0, 10.0) 110 | low_cutoff = np.random.uniform(0.0, 4.0) 111 | 112 | # Calculate Gaussian expansion using the implemented layer 113 | rbf = GaussianRBF(high_cutoff=high_cutoff, low_cutoff=low_cutoff, 114 | n_gaussians=n_gaussians, variance=variance, 115 | normalize_output=True) 116 | gauss_layer = rbf.forward(distances) 117 | 118 | # Manually calculate expansion with numpy 119 | # according to the following formula: 120 | # e_k (r_j - r_i) = exp(- \gamma (\left \| r_j - r_i \right \| - \mu_k)^2) 121 | # with centers mu_k calculated on a uniform grid between 122 | # zero and the distance cutoff and gamma as a scaling parameter. 
123 | centers = np.linspace(low_cutoff, high_cutoff, 124 | n_gaussians).astype(np.float64) 125 | gamma = -0.5 / variance 126 | distances = np.expand_dims(distances, axis=3) 127 | magnitude_squared = (distances - centers)**2 128 | gauss_manual = np.exp(gamma * magnitude_squared) 129 | 130 | # manual output normalization 131 | gauss_manual = gauss_manual / np.sum(gauss_manual, axis=3)[:, :, :, None] 132 | 133 | # Shapes and values need to be the same 134 | np.testing.assert_equal(centers.shape, rbf.centers.shape) 135 | np.testing.assert_allclose(gauss_layer.numpy(), gauss_manual, rtol=1e-5) 136 | 137 | 138 | def test_polynomial_cutoff_rbf(): 139 | # Make sure the polynomial_cutoff radial basis functions are consistent with 140 | # manual calculations 141 | 142 | # Distances need to have shape (n_batch, n_beads, n_neighbors) 143 | distances = np.random.randn(frames, beads, beads - 1).astype(np.float64) 144 | # Define random parameters for the polynomial_cutoff RBF 145 | n_gaussians = np.random.randint(5, 10) 146 | high_cutoff = np.random.uniform(5.0, 10.0) 147 | low_cutoff = np.random.uniform(0.0, 4.0) 148 | alpha = np.random.uniform(0.1, 1.0) 149 | 150 | # Calculate Gaussian expansion using the implemented layer 151 | polynomial_cutoff_rbf = PolynomialCutoffRBF(high_cutoff=high_cutoff, 152 | low_cutoff=low_cutoff, 153 | n_gaussians=n_gaussians, 154 | alpha=alpha, 155 | tolerance=1e-8) 156 | polynomial_cutoff_rbf_layer = polynomial_cutoff_rbf.forward( 157 | torch.tensor(distances)) 158 | 159 | # Manually calculate expansion with numpy 160 | # First, we compute the centers and the scaling factors 161 | centers = np.linspace(np.exp(-high_cutoff), np.exp(-low_cutoff), 162 | n_gaussians).astype(np.float64) 163 | beta = np.power(((2/n_gaussians) * (1-np.exp(-high_cutoff))), -2) 164 | 165 | # Next, we compute the gaussian portion 166 | exp_distances = np.exp(-alpha * np.expand_dims(distances, axis=3)) 167 | magnitude_squared = np.power(exp_distances - centers, 2) 168 | gauss_manual = np.exp(-beta * magnitude_squared) 169 | 170 | # Next, we compute the polynomial modulation 171 | zeros = np.zeros_like(distances) 172 | modulation = np.where(distances < high_cutoff, 173 | 1 - 6.0 * np.power((distances/high_cutoff), 5) 174 | + 15.0 * np.power((distances/high_cutoff), 4) 175 | - 10.0 * np.power((distances/high_cutoff), 3), 176 | zeros) 177 | modulation = np.expand_dims(modulation, axis=3) 178 | 179 | polynomial_cutoff_rbf_manual = modulation * gauss_manual 180 | 181 | # Map tiny values to zero 182 | polynomial_cutoff_rbf_manual = np.where( 183 | np.abs(polynomial_cutoff_rbf_manual) > polynomial_cutoff_rbf.tolerance, 184 | polynomial_cutoff_rbf_manual, 185 | np.zeros_like(polynomial_cutoff_rbf_manual) 186 | ) 187 | 188 | # centers and output values need to be the same 189 | np.testing.assert_allclose(centers, 190 | polynomial_cutoff_rbf.centers, rtol=1e-5) 191 | np.testing.assert_allclose(polynomial_cutoff_rbf_layer.numpy(), 192 | polynomial_cutoff_rbf_manual, rtol=1e-5) 193 | 194 | 195 | def test_polynomial_cutoff_rbf_distance_masking(): 196 | # Makes sure that if a distance mask is used, the corresponding 197 | # expanded distances returned by PolynomialCutoffRBF are zero 198 | 199 | # Distances need to have shape (n_batch, n_beads, n_neighbors) 200 | distances = torch.randn((frames, beads, beads - 1), dtype=torch.float64) 201 | # Define random parameters for the RBF 202 | n_gaussians = np.random.randint(5, 10) 203 | high_cutoff = np.random.uniform(5.0, 10.0) 204 | low_cutoff = np.random.uniform(0.0, 
4.0) 205 | alpha = np.random.uniform(0.1, 1.0) 206 | 207 | neighbor_cutoff = np.abs(np.random.rand()) 208 | neighbors, neighbor_mask = g.get_neighbors(distances, 209 | cutoff=neighbor_cutoff) 210 | 211 | # Calculate Gaussian expansion using the implemented layer 212 | polynomial_cutoff_rbf = PolynomialCutoffRBF(high_cutoff=high_cutoff, 213 | low_cutoff=low_cutoff, 214 | n_gaussians=n_gaussians, 215 | alpha=alpha, 216 | tolerance=1e-8) 217 | polynomial_cutoff_rbf_layer = polynomial_cutoff_rbf.forward( 218 | torch.tensor(distances), 219 | distance_mask=neighbor_mask) 220 | 221 | # Manually calculate expansion with numpy 222 | # First, we compute the centers and the scaling factors 223 | centers = np.linspace(np.exp(-high_cutoff), np.exp(-low_cutoff), 224 | n_gaussians).astype(np.float64) 225 | beta = np.power(((2/n_gaussians) * (1-np.exp(-high_cutoff))), -2) 226 | 227 | # Next, we compute the gaussian portion 228 | exp_distances = np.exp(-alpha * np.expand_dims(distances, axis=3)) 229 | magnitude_squared = np.power(exp_distances - centers, 2) 230 | gauss_manual = np.exp(-beta * magnitude_squared) 231 | 232 | # Next, we compute the polynomial modulation 233 | zeros = np.zeros_like(distances) 234 | modulation = np.where(distances < high_cutoff, 235 | 1 - 6.0 * np.power((distances/high_cutoff), 5) 236 | + 15.0 * np.power((distances/high_cutoff), 4) 237 | - 10.0 * np.power((distances/high_cutoff), 3), 238 | zeros) 239 | modulation = np.expand_dims(modulation, axis=3) 240 | 241 | polynomial_cutoff_rbf_manual = modulation * gauss_manual 242 | 243 | # Map tiny values to zero 244 | polynomial_cutoff_rbf_manual = np.where( 245 | np.abs(polynomial_cutoff_rbf_manual) > polynomial_cutoff_rbf.tolerance, 246 | polynomial_cutoff_rbf_manual, 247 | np.zeros_like(polynomial_cutoff_rbf_manual) 248 | ) 249 | polynomial_cutoff_rbf_manual = torch.tensor( 250 | polynomial_cutoff_rbf_manual) * neighbor_mask[:, :, :, None].double() 251 | 252 | np.testing.assert_array_almost_equal(polynomial_cutoff_rbf_layer.numpy(), 253 | polynomial_cutoff_rbf_manual.numpy()) 254 | 255 | 256 | def test_polynomial_cutoff_rbf_normalize(): 257 | # Tests to make sure that the output of PolynomialCutoffRBF is properly 258 | # normalized if 'normalize_output' is specified as True 259 | 260 | # Distances need to have shape (n_batch, n_beads, n_neighbors) 261 | distances = np.random.randn(frames, beads, beads - 1).astype(np.float64) 262 | # Define random parameters for the polynomial_cutoff RBF 263 | n_gaussians = np.random.randint(5, 10) 264 | high_cutoff = np.random.uniform(5.0, 10.0) 265 | low_cutoff = np.random.uniform(0.0, 4.0) 266 | alpha = np.random.uniform(0.1, 1.0) 267 | 268 | # Calculate Gaussian expansion using the implemented layer 269 | polynomial_cutoff_rbf = PolynomialCutoffRBF(high_cutoff=high_cutoff, 270 | low_cutoff=low_cutoff, 271 | n_gaussians=n_gaussians, 272 | alpha=alpha, 273 | normalize_output=True, 274 | tolerance=1e-8) 275 | polynomial_cutoff_rbf_layer = polynomial_cutoff_rbf.forward( 276 | torch.tensor(distances)) 277 | 278 | # Manually calculate expansion with numpy 279 | # First, we compute the centers and the scaling factors 280 | centers = np.linspace(np.exp(-high_cutoff), np.exp(-low_cutoff), 281 | n_gaussians).astype(np.float64) 282 | beta = np.power(((2/n_gaussians) * (1-np.exp(-high_cutoff))), -2) 283 | 284 | # Next, we compute the gaussian portion 285 | exp_distances = np.exp(-alpha * np.expand_dims(distances, axis=3)) 286 | magnitude_squared = np.power(exp_distances - centers, 2) 287 | gauss_manual = 
np.exp(-beta * magnitude_squared)
288 | 
289 |     # Next, we compute the polynomial modulation
290 |     zeros = np.zeros_like(distances)
291 |     modulation = np.where(distances < high_cutoff,
292 |                           1 - 6.0 * np.power((distances/high_cutoff), 5)
293 |                           + 15.0 * np.power((distances/high_cutoff), 4)
294 |                           - 10.0 * np.power((distances/high_cutoff), 3),
295 |                           zeros)
296 |     modulation = np.expand_dims(modulation, axis=3)
297 | 
298 |     polynomial_cutoff_rbf_manual = modulation * gauss_manual
299 | 
300 |     # Map tiny values to zero
301 |     polynomial_cutoff_rbf_manual = np.where(
302 |         np.abs(polynomial_cutoff_rbf_manual) > polynomial_cutoff_rbf.tolerance,
303 |         polynomial_cutoff_rbf_manual,
304 |         np.zeros_like(polynomial_cutoff_rbf_manual)
305 |     )
306 | 
307 |     # manually normalize the output
308 |     polynomial_cutoff_rbf_manual /= np.sum(polynomial_cutoff_rbf_manual,
309 |                                            axis=3)[:, :, :, None]
310 | 
311 |     # centers and output values need to be the same
312 |     np.testing.assert_allclose(centers,
313 |                                polynomial_cutoff_rbf.centers, rtol=1e-5)
314 |     np.testing.assert_allclose(polynomial_cutoff_rbf_layer.numpy(),
315 |                                polynomial_cutoff_rbf_manual, rtol=1e-5)
316 | 
317 | 
318 | def test_polynomial_cutoff_rbf_zero_cutoff():
319 |     # This test ensures that a choice of zero cutoff produces
320 |     # a set of basis functions that all occupy the same center
321 | 
322 |     # First, we generate a polynomial_cutoff RBF layer with a random number
323 |     # of gaussians and a cutoff of zero
324 |     n_gaussians = np.random.randint(5, 10)
325 |     cutoff = 0.0
326 |     polynomial_cutoff_rbf = PolynomialCutoffRBF(n_gaussians=n_gaussians,
327 |                                                 high_cutoff=cutoff, low_cutoff=cutoff)
328 |     # First we test to see that \beta is infinite
329 |     np.testing.assert_equal(np.inf, polynomial_cutoff_rbf.beta)
330 | 
331 |     # Next we make a mock array of centers at 1.0
332 |     centers = torch.linspace(
333 |         np.exp(-cutoff), np.exp(-cutoff), n_gaussians, dtype=torch.float64)
334 | 
335 |     # Here, we test to see that centers are equal in this corner case
336 |     np.testing.assert_equal(centers.numpy(),
337 |                             polynomial_cutoff_rbf.centers.numpy())
338 | 
339 | 
340 | def test_shifted_softplus():
341 |     # Make sure shifted softplus activation is consistent with
342 |     # manual calculation
343 | 
344 |     # Initialize random feature vector
345 |     feature = torch.randn((frames, beads), dtype=torch.float64)
346 | 
347 |     ssplus = ShiftedSoftplus()
348 |     # Shifted softplus has the following form:
349 |     # y = \ln\left(1 + e^{x}\right) - \ln(2)
350 |     manual_output = np.log(1.0 + np.exp(feature.numpy())) - np.log(2.0)
351 | 
352 |     np.testing.assert_allclose(manual_output, ssplus(feature).numpy())
353 | 
--------------------------------------------------------------------------------
/cgnet/tests/test_geometry_core.py:
--------------------------------------------------------------------------------
1 | # Author: Brooke Husic
2 | # Contributors: Dominik Lemm
3 | 
4 | import numpy as np
5 | import torch
6 | 
7 | from cgnet.feature import Geometry
8 | 
9 | g_numpy = Geometry(method='numpy')
10 | g_torch = Geometry(method='torch')
11 | 
12 | # Define sizes for a pseudo-dataset
13 | frames = np.random.randint(10, 30)
14 | beads = np.random.randint(5, 10)
15 | 
16 | # create random linear protein data
17 | coords = np.random.randn(frames, beads, 3).astype(np.float64)
18 | 
19 | # Calculate redundant distances and create a simple neighbor list in which all
20 | # beads see each other (shape [n_frames, n_beads, n_beads - 1]).
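# (For example, with beads = 4 the upper-triangle set of pairs (i, j), i < j,
# contains 6 unique distances; the redundant mapping re-indexes those 6 values
# into a [4, 3] layout in which every bead row lists its distances to all
# other beads. Illustrative numbers only.)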
21 | _distance_pairs, _ = g_numpy.get_distance_indices(beads, [], []) 22 | redundant_distance_mapping = g_numpy.get_redundant_distance_mapping( 23 | _distance_pairs) 24 | 25 | neighbor_cutoff = np.random.uniform(0, 1) 26 | 27 | 28 | def test_tile_methods_numpy_vs_torch(): 29 | # Test to make sure geometry.tile is still equivalent between numpy 30 | # and pytorch 31 | 32 | # Create inputs for a 3d array that will have 24 elements 33 | A = np.array([np.random.randint(10) for _ in range(24)]) 34 | 35 | # Make two likely different shapes for the array and the tiling 36 | # with friendly factors 37 | shape_one = [2, 3, 4] 38 | np.random.shuffle(shape_one) 39 | 40 | shape_two = [2, 3, 4] 41 | np.random.shuffle(shape_two) 42 | 43 | # Reshape A with the first shape 44 | A = A.reshape(*shape_one).astype(np.float32) 45 | 46 | # Test whether the tiling is equivalent to the second shape 47 | # Add in the standard check for fun 48 | g_numpy.check_array_vs_tensor(A) 49 | tile_numpy = g_numpy.tile(A, shape_two) 50 | 51 | g_torch.check_array_vs_tensor(torch.Tensor(A)) 52 | tile_torch = g_torch.tile(torch.Tensor(A), shape_two) 53 | 54 | np.testing.assert_array_equal(tile_numpy, tile_torch) 55 | 56 | 57 | def test_distances_and_neighbors_numpy_vs_torch(): 58 | # Comparison of numpy and torch outputs for getting geometry.get_distances 59 | # and geometry.get_neighbors 60 | 61 | # Calculate distances, neighbors, and neighbor mask using the numpy 62 | # version of Geometry 63 | distances_numpy = g_numpy.get_distances(_distance_pairs, 64 | coords, 65 | norm=True) 66 | distances_numpy = distances_numpy[:, redundant_distance_mapping] 67 | neighbors_numpy, neighbors_mask_numpy = g_numpy.get_neighbors( 68 | distances_numpy, 69 | cutoff=neighbor_cutoff) 70 | 71 | # Calculate distances, neighbors, and neighbor mask using the torch 72 | # version of Geometry 73 | distances_torch = g_torch.get_distances(_distance_pairs, 74 | torch.from_numpy(coords), 75 | norm=True) 76 | distances_torch = distances_torch[:, redundant_distance_mapping] 77 | neighbors_torch, neighbors_mask_torch = g_torch.get_neighbors( 78 | distances_torch, 79 | cutoff=neighbor_cutoff) 80 | 81 | np.testing.assert_allclose(distances_numpy, distances_torch, rtol=1e-9) 82 | np.testing.assert_allclose(neighbors_numpy, neighbors_torch, rtol=1e-9) 83 | np.testing.assert_allclose( 84 | neighbors_mask_numpy, neighbors_mask_torch, rtol=1e-9) 85 | 86 | 87 | def test_nan_check(): 88 | # Test if an assert is raised during the computation of distances, angles 89 | # and dihedrals 90 | 91 | # Calculate angle and dihedral pair indices 92 | angle_pairs = [(i, i+1, i+2) for i in range(beads - 2)] 93 | dihedral_pairs = [(i, i+1, i+2, i+3) for i in range(beads - 3)] 94 | 95 | # Select random frame and bead to set NaN 96 | random_frame = np.random.randint(0, frames) 97 | random_bead = np.random.randint(0, beads) 98 | 99 | # Create test coordinates that contain NaN 100 | nan_coords = coords.copy() 101 | nan_coords[random_frame][random_bead] = np.nan 102 | torch_nan_coords = torch.from_numpy(nan_coords) 103 | 104 | # Check if an assert is raised 105 | np.testing.assert_raises(ValueError, 106 | g_numpy.get_distances, _distance_pairs, nan_coords) 107 | np.testing.assert_raises(ValueError, 108 | g_numpy.get_angles, angle_pairs, nan_coords) 109 | np.testing.assert_raises(ValueError, 110 | g_numpy.get_dihedrals, dihedral_pairs, nan_coords) 111 | 112 | np.testing.assert_raises(ValueError, 113 | g_torch.get_distances, _distance_pairs, 114 | torch_nan_coords) 115 | 
118 |     np.testing.assert_raises(ValueError,
119 |                              g_torch.get_angles, angle_pairs,
120 |                              torch_nan_coords)
121 |     np.testing.assert_raises(ValueError,
122 |                              g_torch.get_dihedrals, dihedral_pairs,
123 |                              torch_nan_coords)
124 | 
--------------------------------------------------------------------------------
/cgnet/tests/test_geometry_features.py:
--------------------------------------------------------------------------------
1 | # Author: Brooke Husic
2 | # Contributors: Dominik Lemm
3 | 
4 | import numpy as np
5 | import scipy.spatial
6 | import torch
7 | 
8 | from cgnet.feature import GeometryFeature, Geometry
9 | g = Geometry(method='torch')
10 | 
11 | # The following sets up our pseudo-simulation data
12 | 
13 | # Number of frames
14 | frames = np.random.randint(1, 10)
15 | 
16 | # Number of coarse-grained beads. We need at least 4 so we can do dihedrals.
17 | beads = np.random.randint(8, 10)
18 | 
19 | # Number of dimensions; for now geometry only handles 3
20 | dims = 3
21 | 
22 | # Create a pseudo simulation dataset
23 | data = np.random.randn(frames, beads, dims).astype(np.float64)
24 | data_tensor = torch.Tensor(data)
25 | 
26 | # Note: currently get_distance_indices is not directly tested.
27 | # Possibly add a test here?
28 | distance_inds, _ = g.get_distance_indices(beads)
29 | 
30 | angle_inds = [(i, i+1, i+2) for i in range(beads-2)]
31 | dihedral_inds = [(i, i+1, i+2, i+3) for i in range(beads-3)]
32 | 
33 | 
34 | def test_distance_features():
35 |     # Make sure pairwise distance features are consistent with scipy
36 | 
37 |     geom_feature = GeometryFeature(feature_tuples='all_backbone',
38 |                                    n_beads=beads)
39 |     # Forward pass calculates features (distances, angles, dihedrals)
40 |     # and makes them accessible as attributes
41 |     _ = geom_feature.forward(data_tensor)
42 | 
43 |     # Test each frame x_i
44 |     for frame_ind in range(frames):
45 |         Dmat_xi = scipy.spatial.distance.squareform(
46 |             scipy.spatial.distance.pdist(data[frame_ind]))
47 | 
48 |         xi_feature_distances = list(geom_feature.distances[frame_ind].numpy())
49 |         feature_descriptions = geom_feature.descriptions['Distances']
50 | 
51 |         # Arrange the scipy distances in the right order for comparing
52 |         # to the GeometryFeature distances
53 |         xi_scipy_distances = [Dmat_xi[feature_descriptions[i]]
54 |                               for i in range(len(feature_descriptions))]
55 | 
56 |         np.testing.assert_allclose(xi_feature_distances,
57 |                                    xi_scipy_distances, rtol=1e-6)
58 | 
59 | 
60 | def test_backbone_angle_features():
61 |     # Make sure backbone angle features are consistent with manual calculation
62 | 
63 |     # For spatial coordinates a, b, c, the angle \theta describing a-b-c
64 |     # is calculated using the following formula:
65 |     #
66 |     # \overline{ba} = a - b
67 |     # \overline{cb} = c - b
68 |     # \cos(\theta) = \frac{\overline{ba} \cdot \overline{cb}}
69 |     #                     {||\overline{ba}|| ||\overline{cb}||}
70 |     # \theta = \arccos(\cos(\theta))
71 | 
72 |     geom_feature = GeometryFeature(feature_tuples='all_backbone',
73 |                                    n_beads=beads)
74 |     # Forward pass calculates features (distances, angles, dihedrals)
75 |     # and makes them accessible as attributes
76 |     _ = geom_feature.forward(data_tensor)
77 | 
78 |     # Manually calculate the angles one frame at a time
79 |     angles = []
80 |     for frame_data in data:
81 |         angle_list = []
82 |         for i in range(data.shape[1] - 2):
83 |             a = frame_data[i]
84 |             b = frame_data[i+1]
85 |             c = frame_data[i+2]
86 | 
87 |             ba = a - b
88 |             cb = c - b
89 | 
90 |             cos_angle = np.dot(ba, cb) / (np.linalg.norm(ba)
91 |                                           * np.linalg.norm(cb))
92 |             angle = np.arccos(cos_angle)
93 |             angle_list.append(angle)
94 |         angles.append(angle_list)
95 | 
96 |     np.testing.assert_allclose(geom_feature.angles, angles, rtol=1e-4)
97 | 
98 | 
99 | def test_dihedral_features():
100 |     # Make sure backbone dihedral features are consistent with manual calculation
101 | 
102 |     # For spatial coordinates a, b, c, d, the dihedral \alpha describing
103 |     # a-b-c-d (i.e., the angle between the planes of a-b-c and b-c-d) is
104 |     # calculated using the following formula:
105 |     #
106 |     # \overline{ba} = b - a
107 |     # \overline{cb} = c - b
108 |     # \overline{dc} = d - c
109 |     #
110 |     # % normal vectors to the planes of the first and second angles, respectively
111 |     # n_1 = \overline{ba} \times \overline{cb}
112 |     # n_2 = \overline{cb} \times \overline{dc}
113 |     #
114 |     # m_1 = n_2 \times n_1
115 |     #
116 |     # \sin(\alpha) = \frac{m_1 \cdot \overline{cb}}
117 |     #                     {\sqrt{\overline{cb} \cdot \overline{cb}}}
118 |     # \cos(\alpha) = n_2 \cdot n_1
119 |     # \alpha = \arctan\left(\frac{\sin(\alpha)}{\cos(\alpha)}\right)
120 | 
121 |     geom_feature = GeometryFeature(feature_tuples='all_backbone',
122 |                                    n_beads=beads)
123 |     # Forward pass calculates features (distances, angles, dihedrals)
124 |     # and makes them accessible as attributes
125 |     _ = geom_feature.forward(data_tensor)
126 | 
127 |     # Manually calculate the dihedrals one frame at a time
128 |     diheds = []
129 |     for frame_data in data:
130 |         dihed_list = []
131 |         for i in range(data.shape[1] - 3):
132 |             a = frame_data[i]
133 |             b = frame_data[i+1]
134 |             c = frame_data[i+2]
135 |             d = frame_data[i+3]
136 | 
137 |             ba = b-a
138 |             cb = c-b
139 |             dc = d-c
140 | 
141 |             n1 = np.cross(ba, cb)
142 |             n2 = np.cross(cb, dc)
143 |             m1 = np.cross(n2, n1)
144 |             term1 = np.dot(m1, cb)/np.sqrt(np.dot(cb, cb))
145 |             term2 = np.dot(n2, n1)
146 |             dihed_list.append(np.arctan2(term1, term2))
147 |         diheds.append(dihed_list)
148 | 
149 |     # Instead of comparing the sines and cosines, compare the arctans
150 |     feature_diheds = [np.arctan2(geom_feature.dihedral_sines[i].numpy(),
151 |                                  geom_feature.dihedral_cosines[i].numpy())
152 |                       for i in range(len(geom_feature.dihedral_sines))]
153 |     np.testing.assert_allclose(np.abs(feature_diheds),
154 |                                np.abs(diheds), rtol=1e-4)
155 | 
156 | 
157 | def test_distance_index_shuffling():
158 |     # Make sure shuffled distances return the right results
159 | 
160 |     # Create a dataset with one frame, 10 beads, 3 dimensions
161 |     data_to_shuffle = np.random.randn(1, 10, 3)
162 |     data_to_shuffle_tensor = torch.Tensor(data_to_shuffle)
163 | 
164 |     y_dist_inds, _ = g.get_distance_indices(10)
165 | 
166 |     geom_feature = GeometryFeature(feature_tuples=y_dist_inds)
167 |     # Forward pass calculates features (distances, angles, dihedrals)
168 |     # and makes them accessible as attributes
169 |     _ = geom_feature.forward(data_to_shuffle_tensor)
170 | 
171 |     # Shuffle the distance indices
172 |     inds = np.arange(len(y_dist_inds))
173 |     np.random.shuffle(inds)
174 | 
175 |     shuffled_inds = [tuple(i) for i in np.array(y_dist_inds)[inds]]
176 |     geom_feature_shuffle = GeometryFeature(feature_tuples=shuffled_inds)
177 |     _ = geom_feature_shuffle.forward(data_to_shuffle_tensor)
178 | 
179 |     # See if the non-shuffled distances are the same when indexed according
180 |     # to the shuffling
181 |     np.testing.assert_array_equal(geom_feature_shuffle.distances[0],
182 |                                   geom_feature.distances[0][inds])
183 | 
184 | 
185 | def test_angle_index_shuffling():
186 |     # Make sure shuffled angles return the right results
187 | 
188 |     # Create a dataset with one frame, 100 beads, 3 dimensions
189 |     data_to_shuffle = np.random.randn(1, 100, 3)
190 |     data_to_shuffle_tensor = torch.Tensor(data_to_shuffle)
191 | 
192 |     y_angle_inds = [(i, i+1, i+2) for i in range(100-2)]
193 | 
194 |     geom_feature = GeometryFeature(feature_tuples=y_angle_inds)
195 |     # Forward pass calculates features (distances, angles, dihedrals)
196 |     # and makes them accessible as attributes
197 |     _ = geom_feature.forward(data_to_shuffle_tensor)
198 | 
199 |     # Shuffle all the inds that can serve as an angle start
200 |     inds = np.arange(100-2)
201 |     np.random.shuffle(inds)
202 | 
203 |     shuffled_inds = [tuple(i) for i in np.array(y_angle_inds)[inds]]
204 |     geom_feature_shuffle = GeometryFeature(feature_tuples=shuffled_inds)
205 |     _ = geom_feature_shuffle.forward(data_to_shuffle_tensor)
206 | 
207 |     # See if the non-shuffled angles are the same when indexed according
208 |     # to the shuffling
209 |     np.testing.assert_array_equal(geom_feature_shuffle.angles[0],
210 |                                   geom_feature.angles[0][inds])
211 | 
212 | 
213 | def test_dihedral_index_shuffling():
214 |     # Make sure shuffled dihedrals return the right results
215 | 
216 |     # Create a dataset with one frame, 100 beads, 3 dimensions
217 |     data_to_shuffle = np.random.randn(1, 100, 3)
218 |     data_to_shuffle_tensor = torch.Tensor(data_to_shuffle)
219 | 
220 |     y_dihed_inds = [(i, i+1, i+2, i+3) for i in range(100-3)]
221 | 
222 |     geom_feature = GeometryFeature(feature_tuples=y_dihed_inds)
223 |     # Forward pass calculates features (distances, angles, dihedrals)
224 |     # and makes them accessible as attributes
225 |     _ = geom_feature.forward(data_to_shuffle_tensor)
226 | 
227 |     # Shuffle all the inds that can serve as a dihedral start
228 |     inds = np.arange(100-3)
229 |     np.random.shuffle(inds)
230 | 
231 |     shuffled_inds = [tuple(i) for i in np.array(y_dihed_inds)[inds]]
232 |     geom_feature_shuffle = GeometryFeature(feature_tuples=shuffled_inds)
233 |     _ = geom_feature_shuffle.forward(data_to_shuffle_tensor)
234 | 
235 |     # See if the non-shuffled dihedral sines and cosines are the same when
236 |     # indexed according to the shuffling
237 |     np.testing.assert_allclose(geom_feature_shuffle.dihedral_cosines[0],
238 |                                geom_feature.dihedral_cosines[0][inds], rtol=1e-5)
239 | 
240 |     np.testing.assert_allclose(geom_feature_shuffle.dihedral_sines[0],
241 |                                geom_feature.dihedral_sines[0][inds], rtol=1e-5)
242 | 
--------------------------------------------------------------------------------
/cgnet/tests/test_gpu.py:
--------------------------------------------------------------------------------
1 | # Author: Nick Charron
2 | # Contributors: Brooke Husic, Dominik Lemm
3 | 
4 | import numpy as np
5 | import tempfile
6 | import torch
7 | import torch.nn as nn
8 | from sklearn.linear_model import LinearRegression
9 | from sklearn.metrics import mean_squared_error as mse
10 | from cgnet.network import (CGnet, ForceLoss,
11 |                            RepulsionLayer, HarmonicLayer, ZscoreLayer)
12 | from cgnet.feature import (GeometryStatistics, GeometryFeature,
13 |                            MoleculeDataset, LinearLayer, SchnetFeature,
14 |                            FeatureCombiner, CGBeadEmbedding, GaussianRBF)
15 | from torch.utils.data import DataLoader
16 | from nose.exc import SkipTest
17 | 
18 | 
19 | def generate_model():
20 |     # Generate a random CGnet model and coordinates
21 |     n_frames = np.random.randint(10, 30)
22 |     n_beads = np.random.randint(5, 10)
23 |     width = np.random.randint(2, high=10)
24 | 
25 |     # First we create a random data set of a mock linear protein
26 |     coords = np.random.randn(n_frames, n_beads, 3).astype('float32')
27 | 
28 |     # Next, we gather the statistics for Bond/Repulsion priors
29 |     stats = GeometryStatistics(coords, backbone_inds='all',
30 |                                get_all_distances=True, get_backbone_angles=True,
31 |                                get_backbone_dihedrals=True)
32 | 
33 |     bonds_list, _ = stats.get_prior_statistics('Bonds', as_list=True)
34 |     bonds_idx = stats.return_indices('Bonds')
35 | 
36 |     repul_distances = [i for i in stats.descriptions['Distances']
37 |                        if abs(i[0]-i[1]) > 2]
38 |     repul_idx = stats.return_indices(features=repul_distances)
39 |     ex_vols = np.random.uniform(2, 8, len(repul_distances))
40 |     exps = np.random.randint(1, 6, len(repul_distances))
41 |     repul_list = [{'ex_vol': ex_vol, 'exp': exp}
42 |                   for ex_vol, exp in zip(ex_vols, exps)]
43 |     # Next, we also grab the Zscores
44 |     zscores, _ = stats.get_zscore_array()
45 | 
46 |     # Here, we assemble the priors list
47 |     priors = [HarmonicLayer(bonds_idx, bonds_list)]
48 |     priors += [RepulsionLayer(repul_idx, repul_list)]
49 | 
50 |     # Next, we assemble a SchnetFeature with random initialization arguments
51 |     feature_size = np.random.randint(5, high=10)  # random feature size
52 |     n_embeddings = np.random.randint(3, high=5)  # random embedding number
53 |     embedding_dim = feature_size  # embedding property size
54 |     n_interaction_blocks = np.random.randint(
55 |         1, 3)  # random number of interactions
56 |     neighbor_cutoff = np.random.uniform(0, 1)  # random neighbor cutoff
57 |     # random embedding property
58 |     embedding_property = torch.randint(low=1, high=n_embeddings,
59 |                                        size=(n_frames, n_beads))
60 |     embedding_layer = CGBeadEmbedding(n_embeddings=n_embeddings,
61 |                                       embedding_dim=embedding_dim)
62 | 
63 |     # gaussian radial basis function layer
64 |     rbf_layer = GaussianRBF()
65 | 
66 |     schnet_feature = SchnetFeature(feature_size=feature_size,
67 |                                    embedding_layer=embedding_layer,
68 |                                    rbf_layer=rbf_layer,
69 |                                    n_interaction_blocks=n_interaction_blocks,
70 |                                    calculate_geometry=False,
71 |                                    n_beads=n_beads,
72 |                                    neighbor_cutoff=neighbor_cutoff)
73 |     # Here we create a GeometryFeature, and we assemble our features and
74 |     # ZscoreLayer into a FeatureCombiner
75 |     distance_idx = stats.return_indices("Distances")
76 |     geometry_feature = GeometryFeature(feature_tuples=stats.feature_tuples)
77 |     features = [geometry_feature, ZscoreLayer(zscores), schnet_feature]
78 |     combined_features = FeatureCombiner(
79 |         features, distance_indices=distance_idx)
80 | 
81 |     # Next, we create the hidden architecture of CGnet
82 |     arch = LinearLayer(feature_size, width)
83 |     arch += LinearLayer(width, 1)
84 | 
85 |     # Finally, we assemble the model
86 |     model = CGnet(arch, ForceLoss(), feature=combined_features,
87 |                   priors=priors)
88 |     return model, coords, embedding_property
89 | 
90 | 
91 | def test_cgnet_mount():
92 |     if not torch.cuda.is_available():
93 |         raise SkipTest("GPU not available for testing.")
94 |     device = torch.device('cuda')
95 | 
96 |     # This test assesses CUDA mounting for an entire CGnet model
97 |     # First we create a random model with random protein data
98 |     model, coords, embedding_property = generate_model()
99 | 
100 |     # Next, we mount the model to GPU
101 |     model.mount(device)
102 | 
103 |     # Next, we check to see if each layer is mounted correctly
104 |     # This is done by checking if parameters/buffers are mapped to the correct
105 |     # device, or that feature classes are imbued with the appropriate device
106 |     # First, we check features
107 |     for layer in model.feature.layer_list:
108 |         if isinstance(layer, (GeometryFeature, SchnetFeature)):
109 |             assert layer.device == device
110 |         if isinstance(layer, ZscoreLayer):
111 |             assert layer.zscores.device.type == device.type
112 |     # Next, we check priors
113 |     for prior in model.priors:
114 |         if isinstance(prior, HarmonicLayer):
115 |             assert prior.harmonic_parameters.device.type == device.type
116 |         if isinstance(prior, RepulsionLayer):
117 |             assert prior.repulsion_parameters.device.type == device.type
118 |     # Finally, we check the arch layers
119 |     for param in model.parameters():
120 |         assert param.device.type == device.type
121 | 
122 |     # Lastly, we perform a forward pass over the data and check the output devices
123 |     coords = torch.tensor(coords, requires_grad=True).to(device)
124 |     embedding_property = embedding_property.to(device)
125 |     pot, pred_force = model.forward(coords, embedding_property)
126 |     assert pot.device.type == device.type
127 |     assert pred_force.device.type == device.type
128 | 
129 | 
130 | def test_cgnet_dismount():
131 |     if not torch.cuda.is_available():
132 |         raise SkipTest("GPU not available for testing.")
133 |     device = torch.device('cuda')
134 | 
135 |     # This test assesses the ability of an entire CGnet to dismount from GPU
136 |     # First we create a random model with random protein data
137 |     model, coords, embedding_property = generate_model()
138 | 
139 |     # First, we mount the model to GPU
140 |     model.mount(device)
141 | 
142 |     # Here we dismount the model from GPU
143 |     device = torch.device('cpu')
144 |     model.mount(device)
145 | 
146 |     # Next, we check to see if each layer is dismounted correctly
147 |     # This is done by checking if parameters/buffers are mapped to the correct
148 |     # device, or that feature classes are imbued with the appropriate device
149 |     # First, we check features
150 |     for layer in model.feature.layer_list:
151 |         if isinstance(layer, (GeometryFeature, SchnetFeature)):
152 |             assert layer.device.type == device.type
153 |         if isinstance(layer, ZscoreLayer):
154 |             assert layer.zscores.device.type == device.type
155 |     # Next, we check priors
156 |     for prior in model.priors:
157 |         if isinstance(prior, HarmonicLayer):
158 |             assert prior.harmonic_parameters.device.type == device.type
159 |         if isinstance(prior, RepulsionLayer):
160 |             assert prior.repulsion_parameters.device.type == device.type
161 |     # Finally, we check the arch layers
162 |     for param in model.parameters():
163 |         assert param.device.type == device.type
164 |     # Lastly, we perform a forward pass over the data and check the output devices
165 |     coords = torch.tensor(coords, requires_grad=True).to(device)
166 |     pot, pred_force = model.forward(coords, embedding_property)
167 |     assert pot.device.type == device.type
168 |     assert pred_force.device.type == device.type
169 | 
170 | 
171 | def test_save_load_model():
172 |     # This test assesses the ability to dismount models from GPU that are
173 |     # loaded from a saved .pt file
174 |     if not torch.cuda.is_available():
175 |         raise SkipTest("GPU not available for testing.")
176 |     with tempfile.TemporaryDirectory() as tmp:
177 |         device = torch.device('cuda')
178 | 
179 |         # This test assesses the ability of an entire CGnet to dismount from GPU
180 |         # First we create a random model with random protein data
181 |         model, coords, embedding_property = generate_model()
182 | 
183 |         # First, we mount the model to GPU
184 |         model.mount(device)
185 | 
186 |         # Next we save the model to the temporary directory and load it again
187 |         # to check whether it can be dismounted from the GPU
188 |         torch.save(model, tmp+"/cgnet_gpu_test.pt")
189 |         del model
190 |         loaded_model = torch.load(tmp+"/cgnet_gpu_test.pt")
191 |         device = torch.device('cpu')
192 |         loaded_model.mount(torch.device('cpu'))
193 |         # First we check features
194 |         for layer in loaded_model.feature.layer_list:
195 |             if isinstance(layer, (GeometryFeature, SchnetFeature)):
196 |                 assert layer.device.type == device.type
197 |             if isinstance(layer, ZscoreLayer):
198 |                 assert layer.zscores.device.type == device.type
199 |         # Next, we check priors
200 |         for prior in loaded_model.priors:
201 |             if isinstance(prior, HarmonicLayer):
202 |                 assert prior.harmonic_parameters.device.type == device.type
203 |             if isinstance(prior, RepulsionLayer):
204 |                 assert prior.repulsion_parameters.device.type == device.type
205 |         # Finally, we check the arch layers
206 |         for param in loaded_model.parameters():
207 |             assert param.device.type == device.type
208 |         # Lastly, we perform a forward pass over the data and check the output devices
209 |         coords = torch.tensor(coords, requires_grad=True).to(device)
210 |         pot, pred_force = loaded_model.forward(coords, embedding_property)
211 |         assert pot.device.type == device.type
212 |         assert pred_force.device.type == device.type
213 | 
--------------------------------------------------------------------------------
/cgnet/tests/test_molecule_dataset.py:
--------------------------------------------------------------------------------
1 | # Author: Brooke Husic
2 | 
3 | import numpy as np
4 | import torch
5 | 
6 | from cgnet.feature import (MoleculeDataset, MultiMoleculeDataset,
7 |                            multi_molecule_collate)
8 | 
9 | # We create an artificial dataset with a random number of
10 | # frames, beads, and dimensions. Since we aren't actually
11 | # doing any featurization, we can use an arbitrary number
12 | # of dimensions
13 | 
14 | # For some tests we want an even number of frames
15 | frames = np.random.randint(1, 10)*2
16 | 
17 | beads = np.random.randint(1, 10)
18 | dims = np.random.randint(1, 5)
19 | 
20 | coords = np.random.randn(frames, beads, dims)  # e.g. coords
21 | forces = np.random.randn(frames, beads, dims)  # e.g. forces
22 | 
23 | # This data is used to test MultiMoleculeDataset methods.
24 | # It consists of random data for n_frames number of molecules
25 | # with varying bead numbers per frame/example
26 | 
27 | # random largest molecule size of the variable dataset entries
28 | max_beads = np.random.randint(18, 23)
29 | variable_beads = np.random.randint(3,
30 |                                    max_beads,
31 |                                    size=frames)  # random protein sizes
32 | variable_coords = [np.random.randn(bead, 3)
33 |                    for bead in variable_beads]  # random coords for each size
34 | variable_forces = [np.random.randn(bead, 3)
35 |                    for bead in variable_beads]  # random forces for each size
36 | 
37 | # random embeddings for each size
38 | variable_embeddings = [np.random.randint(1,
39 |                                          high=max_beads, size=bead)
40 |                        for bead in variable_beads]
41 | 
42 | 
43 | def test_adding_data():
44 |     # Make sure data is added correctly to a MoleculeDataset
45 | 
46 |     # Build a dataset with all the data
47 |     ds1 = MoleculeDataset(coords, forces)
48 | 
49 |     # Build a dataset with the first half of the data...
50 |     ds2 = MoleculeDataset(coords, forces, selection=np.arange(frames//2))
51 |     # ... then add the second half afterward
52 |     ds2.add_data(coords, forces, selection=np.arange(frames//2, frames))
53 | 
54 |     # Make sure they're the same
55 |     np.testing.assert_array_equal(ds1.coordinates, ds2.coordinates)
56 |     np.testing.assert_array_equal(ds1.forces, ds2.forces)
57 | 
58 | 
59 | def test_adding_variable_selection():
60 |     # Make sure data is added correctly to a MultiMoleculeDataset
61 | 
62 |     # Build a dataset with all the data
63 |     ds1 = MultiMoleculeDataset(variable_coords, variable_forces,
64 |                                variable_embeddings)
65 | 
66 |     # Build a dataset with the first half of the data...
67 |     ds2 = MultiMoleculeDataset(variable_coords, variable_forces,
68 |                                variable_embeddings, selection=np.arange(frames//2))
69 |     # ... then add the second half afterward
70 |     ds2.add_data(variable_coords, variable_forces,
71 |                  embeddings_list=variable_embeddings,
72 |                  selection=np.arange(frames//2, frames))
73 | 
74 |     # Make sure they're the same
75 |     np.testing.assert_array_equal(ds1.data, ds2.data)
76 | 
77 | 
78 | def test_stride():
79 |     # Make sure MoleculeDataset stride returns correct results
80 | 
81 |     stride = np.random.randint(2, 5)
82 |     ds = MoleculeDataset(coords, forces, stride=stride)
83 | 
84 |     strided_coords = coords[::stride]
85 |     strided_forces = forces[::stride]
86 | 
87 |     np.testing.assert_array_equal(ds.coordinates, strided_coords)
88 |     np.testing.assert_array_equal(ds.forces, strided_forces)
89 | 
90 | 
91 | def test_variable_stride():
92 |     # Make sure MultiMoleculeDataset stride returns correct results
93 | 
94 |     stride = np.random.randint(1, 4)
95 |     ds = MultiMoleculeDataset(variable_coords, variable_forces,
96 |                               variable_embeddings, stride=stride)
97 | 
98 |     strided_coords = variable_coords[::stride]
99 |     strided_forces = variable_forces[::stride]
100 |     strided_embeddings = variable_embeddings[::stride]
101 |     strided_data = [{'coords': strided_coords[i],
102 |                      'forces': strided_forces[i],
103 |                      'embeddings': strided_embeddings[i]}
104 |                     for i in range(len(strided_coords))]
105 | 
106 |     np.testing.assert_array_equal(ds.data, strided_data)
107 | 
108 | 
109 | def test_indexing():
110 |     # Make sure MoleculeDataset indexing works (no embeddings)
111 | 
112 |     # Make a random slice with possible repeats
113 |     selection = [np.random.randint(frames)
114 |                  for _ in range(np.random.randint(frames))]
115 |     ds = MoleculeDataset(coords, forces)
116 | 
117 |     coords_tensor_from_numpy = torch.from_numpy(coords[selection])
118 |     forces_tensor_from_numpy = torch.from_numpy(forces[selection])
119 |     # The third argument is an empty tensor because no embeddings have been
120 |     # specified
121 |     coords_tensor_from_ds, forces_tensor_from_ds, empty_tensor = ds[selection]
122 | 
123 |     assert coords_tensor_from_ds.requires_grad
124 |     np.testing.assert_array_equal(coords_tensor_from_numpy,
125 |                                   coords_tensor_from_ds.detach().numpy())
126 |     np.testing.assert_array_equal(forces_tensor_from_numpy,
127 |                                   forces_tensor_from_ds.detach().numpy())
128 |     assert len(empty_tensor) == 0
129 | 
130 | 
131 | def test_variable_indexing():
132 |     # Make sure MultiMoleculeDataset indexing works
133 | 
134 |     # Make a random slice with possible repeats
135 |     selection = [np.random.randint(frames)
136 |                  for _ in range(np.random.randint(frames))]
137 |     ds = MultiMoleculeDataset(variable_coords, variable_forces,
138 |                               variable_embeddings)
139 |     manual_data = [{'coords': variable_coords[i],
140 |                     'forces': variable_forces[i],
141 |                     'embeddings': variable_embeddings[i]}
142 |                    for i in selection]
143 | 
144 |     data = ds[selection]
145 |     np.testing.assert_array_equal(manual_data, data)
146 | 
147 | 
148 | def test_embedding_shape():
149 |     # Test shape of multidimensional embeddings
150 |     embeddings = np.random.randint(1, 10, size=(frames, beads))
151 | 
152 |     ds = MoleculeDataset(coords, forces, embeddings)
153 | 
154 |     assert ds[:][2].shape == (frames, beads)
155 |     np.testing.assert_array_equal(ds.embeddings, embeddings)
156 | 
157 | 
158 | def test_multi_molecule_collate():
159 |     # Tests the output of the collating function for variable-size input
160 |     # to make sure that the collating results in a single padded tensor,
161 |     # in which every example with fewer beads than the maximum bead size
162 |     # in the dataset is padded on the right (i.e., trailing beads) with
163 |     # zeros
164 | 
165 |     ds = MultiMoleculeDataset(variable_coords, variable_forces,
166 |                               variable_embeddings)
167 | 
168 |     # get all data in list-of-dictionaries format
169 |     data = ds[np.arange(frames)]
170 | 
171 |     # get maximum bead number in the dataset
172 |     dataset_max_bead = max([coord.shape[0] for coord in variable_coords])
173 | 
174 |     # make manually padded data tensors
175 |     padded_coord_list = []
176 |     padded_force_list = []
177 |     padded_embedding_list = []
178 |     for data_dict in data:
179 |         pads_needed = dataset_max_bead - data_dict['coords'].shape[0]
180 |         padded_coords = np.vstack((data_dict['coords'],
181 |                                    np.zeros((pads_needed, 3))))
182 |         padded_forces = np.vstack((data_dict['forces'],
183 |                                    np.zeros((pads_needed, 3))))
184 |         padded_embeddings = np.hstack((data_dict['embeddings'],
185 |                                        np.zeros(pads_needed)))
186 |         padded_coord_list.append(padded_coords)
187 |         padded_force_list.append(padded_forces)
188 |         padded_embedding_list.append(padded_embeddings)
189 | 
190 |     # assemble the padded data into complete tensors of shape
191 |     # [frames, max_beads, 3] for coords/forces, and [frames, max_beads]
192 |     # for embeddings
193 |     manual_coords = torch.tensor(padded_coord_list, requires_grad=True)
194 |     manual_forces = torch.tensor(padded_force_list)
195 |     manual_embeddings = torch.tensor(padded_embedding_list)
196 | 
197 |     # get tensors output from multi_molecule_collate()
198 |     coords, forces, embeddings = multi_molecule_collate(data)
199 | 
200 |     # test manual padding against padding performed by collating
201 |     np.testing.assert_array_equal(coords.detach().numpy(),
202 |                                   manual_coords.detach().numpy())
203 |     np.testing.assert_array_equal(forces.numpy(), manual_forces.numpy())
204 |     np.testing.assert_array_equal(embeddings.numpy(),
205 |                                   manual_embeddings.numpy())
206 | 
--------------------------------------------------------------------------------
/devtools/README.md:
--------------------------------------------------------------------------------
1 | Developer best practices
2 | ==
3 | 
4 | Thank you for contributing to `cgnet`! Here are our recommended practices for contributing, in no particular order of importance.
5 | 
6 | Developer and merging info
7 | --
8 | The developers (@coarse-graining/developers) are Brooke ([@brookehus](https://github.com/brookehus)), Nick ([@nec4](https://github.com/nec4)), and Dominik ([@Dom1L](https://github.com/Dom1L)). Only developers have merge permissions to master.
9 | 
10 | - PRs from non-developers require 2/3 approving developer reviews.
11 | - Major PRs from developers require 2/2 approving reviews from the other developers.
12 | - Minor PRs from developers require 1/2 approving reviews from the other developers.
13 | - For now, if it's not obvious, we'll just talk within each PR about whether it's "major" or "minor". If this seems to be a hassle, we can add labels or come up with another method.
14 | 
15 | 
16 | PR best practices
17 | --
18 | - Since the repo cannot be forked, always develop in your own branch.
19 | - Make sure that you are always pulling the latest code from `master` as you develop.
20 | - Default to PRing to other people's branches, even if it seems obnoxious. Never commit to someone else's branch unless you have cleared it with them first!
21 | - *Never merge your own PR*: this applies especially to `master`, but in general to any branch.
22 | - Never commit any kind of non-code file without discussing first.
23 | - Run `nosetests` before requesting final reviews, and run `nosetests` whenever you review.
24 | - Use the labels and milestones that we've made to categorize your PRs (and issues).
25 | - Be nice and constructive with your feedback!
26 | - Always give a passing review before merging, even if it's just "LGTM"!
27 | 
28 | Dependency best practices
29 | --
30 | - Discuss with the @coarse-graining/developers about how and whether to incorporate a new dependency.
31 | - As of now, we are not adding plotting utilities to the main code; this may change in the future!
32 | 
33 | Coding best practices (just a small subset!)
34 | --
35 | - *Always* use python 3 and `pytorch >= 1.2`!
36 | - Use `pep8` formatting! Packages like `autopep8` can help with this.
37 | - Add yourself to the contributors list at the top of the file.
38 | - Classes are `CamelCase`, and functions `use_underscores`.
39 | - Intra-code dependencies should be one-directional: e.g., `cgnet.network` can import from `cgnet.feature`, but not the other way around.
40 | - Use descriptive variable names, documentation, and comments! You will thank yourself later.
41 | - Don't hide problems! Be transparent and add notes about anything that comes up but remains unaddressed.
42 | 
43 | Testing best practices
44 | --
45 | - All tests go in `cgnet.tests` and are not imported in `cgnet.tests.__init__.py`. The exception for soft dependencies is that a `tests` folder should be in the relevant directory; see `cgnet.molecule` for an example.
46 | - The function must start with `test_`. Use `#` for comments instead of `"""`. A minimal example skeleton following these conventions appears at the end of this document.
47 | - The purpose of tests is so that future development doesn't break existing methods. Write tests so that if someone down the road breaks your method, your test will tell them!
48 | - Each test should only test one aspect of the code, so if it breaks, you know what is implicated.
49 | - It's okay, even recommended, to copy and paste code between tests!
50 | - Tests shouldn't replicate the code of the main package, but should obtain values in other ways.
51 | 
52 | Example notebook best practices
53 | --
54 | - Example notebooks should be limited in scope and cover just one topic in a tutorial way. If the notebook becomes long-winded, break it into multiple notebooks.
55 | - Never commit a notebook without the output cleared (if you do this by accident, [squash the commits](https://github.com/wprig/wprig/wiki/How-to-squash-commits)).
56 | 
57 | Other
58 | --
59 | - Keep in mind that this repository will eventually become public - please consider whether this is appropriate for your comments and contributions with respect to ongoing research.
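60 | 
61 | Example test skeleton
62 | --
63 | The following is a minimal sketch of the testing conventions above (a `test_` name, `#` comments instead of docstrings, and an expected value worked out by hand rather than taken from the package). The dataset sizes and the feature class used here are illustrative only:
64 | 
65 | ```python
66 | import numpy as np
67 | import torch
68 | 
69 | from cgnet.feature import GeometryFeature
70 | 
71 | 
72 | def test_distance_count():
73 |     # Create a small random dataset of 5 frames, 4 beads, and 3 dimensions
74 |     coords = np.random.randn(5, 4, 3).astype(np.float64)
75 | 
76 |     # Compute all backbone features for the 4 beads
77 |     geom_feature = GeometryFeature(feature_tuples='all_backbone', n_beads=4)
78 |     _ = geom_feature.forward(torch.Tensor(coords))
79 | 
80 |     # Expected value obtained by hand, not from the package:
81 |     # 4 beads have 4 * 3 / 2 = 6 unique pairwise distances
82 |     np.testing.assert_array_equal(geom_feature.distances.shape, (5, 6))
83 | ```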
84 | 
--------------------------------------------------------------------------------
/devtools/changelog.md:
--------------------------------------------------------------------------------
1 | Changelog
2 | =========
3 | 
4 | v0.2 (development)
5 | ------------------
6 | 
7 | v0.1 (July 23, 2020)
8 | --------------------
9 | 
10 | Hello world! cgnet is described in our [preprint](https://arxiv.org/abs/2007.11412).
11 | 
--------------------------------------------------------------------------------
/examples/data/README.md:
--------------------------------------------------------------------------------
1 | The alanine dipeptide coordinates and forces are a subset of the ones used in the CGnet paper [1]. The simulation was performed with the AMBER ff99SB-ILDN force field [2] at 300 K.
2 | 
3 | The subset included in this repository contains 10,000 data points at 10 ps intervals, whereas the dataset in the CGnet paper has 1,000,000 data points at 1 ps intervals.
4 | 
5 | The file `ala2_coordinates.npy` is a np.ndarray of shape `(10000, 5, 3)`, for 10,000 frames, five coarse-grained beads corresponding to the five backbone atoms C (ACE-1), N (ALA-2), CA (ALA-2), C (ALA-2), N (NME-3), and three dimensions. The file `ala2_forces.npy` is a np.ndarray of the same shape containing the corresponding forces for each frame, bead, and dimension. A minimal loading snippet is given at the end of this file.
6 | 
7 | If you use this data, please cite [1]:
8 | 
9 | ```bibtex
10 | @article{wang2019machine,
11 |   title={Machine learning of coarse-grained molecular dynamics force fields},
12 |   author={Wang, Jiang and Olsson, Simon and Wehmeyer, Christoph and Pérez, Adrià and Charron, Nicholas E and de Fabritiis, Gianni and Noé, Frank and Clementi, Cecilia},
13 |   journal={ACS Central Science},
14 |   year={2019},
15 |   publisher={ACS Publications},
16 |   doi={10.1021/acscentsci.8b00913}
17 | }
18 | ```
19 | 
20 | [1] Wang, J., Olsson, S., Wehmeyer, C., Pérez, A., Charron, N. E., de Fabritiis, G., Noé, F., Clementi, C. (2019). Machine Learning of Coarse-Grained Molecular Dynamics Force Fields. _ACS Central Science._ https://doi.org/10.1021/acscentsci.8b00913
21 | 
22 | [2] Lindorff-Larsen, K., Piana, S., Palmo, K., Maragakis, P., Klepeis, J. L., Dror, R. O., Shaw, D. E. (2010). Improved side-chain torsion potentials for the Amber ff99SB protein force field. _Proteins: Struct., Funct., Bioinf._ 78, 1950. http://dx.doi.org/10.1002/prot.22711
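23 | 
24 | For quick reference, the arrays can be loaded with numpy as in the minimal sketch below (paths are relative to this directory; `np.load` returns the arrays with the shapes described above):
25 | 
26 | ```python
27 | import numpy as np
28 | 
29 | coords = np.load('ala2_coordinates.npy')  # shape (10000, 5, 3)
30 | forces = np.load('ala2_forces.npy')  # shape (10000, 5, 3)
31 | ```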
32 | 
--------------------------------------------------------------------------------
/examples/data/ala2_coordinates.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coarse-graining/cgnet/a3e0e8ddc06f4b6a9f48f4886b73b4cf372ff481/examples/data/ala2_coordinates.npy
--------------------------------------------------------------------------------
/examples/data/ala2_forces.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coarse-graining/cgnet/a3e0e8ddc06f4b6a9f48f4886b73b4cf372ff481/examples/data/ala2_forces.npy
--------------------------------------------------------------------------------
/examples/figs/CGnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coarse-graining/cgnet/a3e0e8ddc06f4b6a9f48f4886b73b4cf372ff481/examples/figs/CGnet.png
--------------------------------------------------------------------------------
/examples/figs/README.md:
--------------------------------------------------------------------------------
1 | This figure appeared in the following article:
2 | 
3 | ```bibtex
4 | @article{wang2019machine,
5 |   title={Machine learning of coarse-grained molecular dynamics force fields},
6 |   author={Wang, Jiang and Olsson, Simon and Wehmeyer, Christoph and Pérez, Adrià and Charron, Nicholas E and de Fabritiis, Gianni and Noé, Frank and Clementi, Cecilia},
7 |   journal={ACS Central Science},
8 |   year={2019},
9 |   publisher={ACS Publications},
10 |   doi={10.1021/acscentsci.8b00913}
11 | }
12 | ```
13 | 
14 | In addition to citing the article, if you wish to use this figure you must request permission from ACS using [this link](https://pubs.acs.org/page/rightslinkno.jsp).
15 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch>=1.2
3 | scipy
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | NAME = 'cgnet'
4 | VERSION = '0.1'
5 | 
6 | 
7 | def read(filename):
8 |     import os
9 |     BASE_DIR = os.path.dirname(__file__)
10 |     filename = os.path.join(BASE_DIR, filename)
11 |     with open(filename, 'r') as fi:
12 |         return fi.read()
13 | 
14 | def readlist(filename):
15 |     rows = read(filename).split("\n")
16 |     rows = [x.strip() for x in rows if x.strip()]
17 |     return list(rows)
18 | 
19 | setup(
20 |     name=NAME,
21 |     version=VERSION,
22 |     author="Nick Charron, Brooke Husic, Dominik Lemm, Jiang Wang",
23 |     author_email="husic@zedat.fu-berlin.de",
24 |     url='https://github.com/coarse-graining/cgnet',
25 |     #download_url='https://github.com/coarse-graining/cgnet/tarball/master',
26 |     #long_description=read('README.md'),
27 |     license='BSD-3-Clause',
28 |     packages=find_packages(),
29 |     zip_safe=True,
30 |     # Install the runtime dependencies declared in requirements.txt
31 |     install_requires=readlist('requirements.txt'),
32 |     entry_points={
33 |         'console_scripts': [
34 |             '%s = %s.cli.main:main' % (NAME, NAME),
35 |         ],
36 |     },
37 | )
38 | 
--------------------------------------------------------------------------------