├── LICENSE
├── README.md
├── data
│   ├── getCATH.sh
│   ├── ts50.json
│   └── ts50remove.txt
├── gvp
│   ├── __init__.py
│   ├── atom3d.py
│   ├── data.py
│   └── models.py
├── run_atom3d.py
├── run_cpd.py
├── schematic.png
├── setup.py
├── test_equivariance.py
└── vectors.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Bowen Jing, Stephan Eismann, Pratham Soni,
4 | Patricia Suriana, Raphael Townshend, Ron Dror
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Geometric Vector Perceptron
2 |
3 | Implementation of equivariant GVP-GNNs as described in [Learning from Protein Structure with Geometric Vector Perceptrons](https://openreview.net/forum?id=1YLJDvSx6J4) by B Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror.
4 |
5 | **UPDATE:** Also includes equivariant GNNs with vector gating as described in [Equivariant Graph Neural Networks for 3D Macromolecular Structure](https://arxiv.org/abs/2106.03843) by B Jing, S Eismann, P Soni, and RO Dror.
6 |
7 | Scripts for training / testing / sampling on protein design and training / testing on all [ATOM3D](https://arxiv.org/abs/2012.04035) tasks are provided.
8 |
9 | **Note:** This implementation is in PyTorch Geometric. The original TensorFlow code, which is not maintained, can be found [here](https://github.com/drorlab/gvp).
10 |
11 | ![schematic](schematic.png)
12 |
13 | ![vectors](vectors.png)
14 |
15 | ## Requirements
16 | * UNIX environment
17 | * python==3.6.13
18 | * torch==1.8.1
19 | * torch_geometric==1.7.0
20 | * torch_scatter==2.0.6
21 | * torch_cluster==1.5.9
22 | * tqdm==4.38.0
23 | * numpy==1.19.4
24 | * sklearn==0.24.1
25 | * atom3d==0.2.1
26 |
27 | While we have not tested with other versions, any reasonably recent versions of these requirements should work.
28 |
29 | ## General usage
30 |
31 | We provide classes in four modules:
32 | * `gvp`: core GVP modules and GVP-GNN layers
33 | * `gvp.data`: data pipelines for both general use and protein design
34 | * `gvp.models`: implementations of MQA and CPD models
35 | * `gvp.atom3d`: models and data pipelines for ATOM3D
36 |
37 | The core modules in `gvp` are meant to be as general as possible, but you will likely have to modify `gvp.data` and `gvp.models` for your specific application, with the existing classes serving as examples.
38 |
39 | **Installation:** Download this repository and run `python setup.py develop` or `pip install -e .`. Be sure to manually install `torch_geometric` first!
40 |
41 | **Tuple representation:** All inputs and outputs with both scalar and vector channels are represented as a tuple of two tensors `(s, V)`. Similarly, all dimensions should be specified as tuples `(n_scalar, n_vector)` where `n_scalar` and `n_vector` are the number of scalar and vector features, respectively. All `V` tensors must be shaped as `[..., n_vector, 3]`, not `[..., 3, n_vector]`.
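A minimal sketch of this convention (the sizes 100 and 16 are illustrative):
```
import torch

s = torch.randn(5, 100)     # scalar channels: [n_nodes, n_scalar]
V = torch.randn(5, 16, 3)   # vector channels: [n_nodes, n_vector, 3]
x = (s, V)                  # dims would be specified as (100, 16)
```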
42 |
43 | **Batching:** We adopt the `torch_geometric` convention of absorbing the batch dimension into the node dimension and keeping track of batch index in a separate tensor.
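For example, two graphs with 3 and 2 nodes are stored as a single graph of 5 nodes with batch index tensor `[0, 0, 0, 1, 1]`.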
44 |
45 | **Amino acids:** Models view sequences as int tensors and are agnostic to aa-to-int mappings. Such mappings are specified as the `letter_to_num` attribute of `gvp.data.ProteinGraphDataset`. Currently, only the 20 standard amino acids are supported.
46 |
47 | For all classes, see the docstrings for more detailed usage. If you have any questions, please contact bjing@cs.stanford.edu.
48 |
49 | ### Core GVP classes
50 |
51 | The class `gvp.GVP` implements a Geometric Vector Perceptron.
52 | ```
53 | import gvp
54 |
55 | in_dims = scalars_in, vectors_in
56 | out_dims = scalars_out, vectors_out
57 | gvp_ = gvp.GVP(in_dims, out_dims)
58 | ```
59 | To use vector gating, pass in `vector_gate=True` and the appropriate activations (here `F.relu` is `torch.nn.functional.relu`).
60 | ```
61 | gvp_ = gvp.GVP(in_dims, out_dims,
62 | activations=(F.relu, None), vector_gate=True)
63 | ```
64 | The classes `gvp.Dropout` and `gvp.LayerNorm` implement vector-channel dropout and layer norm, while using normal dropout and layer norm for scalar channels. Both expect inputs and return outputs of form `(s, V)`, but will also behave like their scalar-valued counterparts if passed a single tensor.
65 | ```
66 | dropout = gvp.Dropout(drop_rate=0.1)
67 | layernorm = gvp.LayerNorm(out_dims)
68 | ```
69 | The function `gvp.randn` returns tuples `(s, V)` drawn from a standard normal. Such tuples can be directly used in a forward pass.
70 | ```
71 | x = gvp.randn(n=5, dims=in_dims)
72 | # x = (s, V) with s.shape = [5, scalars_in] and V.shape = [5, vectors_in, 3]
73 |
74 | out = gvp_(x)
75 | out = dropout(out)
76 | out = layernorm(out)
77 | ```
78 | Finally, we provide utility functions for adding, concatenating, and indexing into such tuples.
79 | ```
80 | y = gvp.randn(n=5, dims=in_dims)
81 | z = gvp.tuple_sum(x, y)
82 | z = gvp.tuple_cat(x, y, dim=-1) # concat along channel axis
83 | z = gvp.tuple_cat(x, y, dim=-2) # concat along node / batch axis
84 |
85 | node_mask = torch.rand(5) < 0.5
86 | z = gvp.tuple_index(x, node_mask) # select half the nodes / batch at random
87 | ```
88 | ### GVP-GNN layers
89 | The class `GVPConv` is a `torch_geometric.nn.MessagePassing` module which forms messages and aggregates them at the destination node, returning new node embeddings. The original embeddings are not updated.
90 | ```
91 | nodes = gvp.randn(n=5, dims=in_dims)
92 | edges = gvp.randn(n=10, dims=edge_dims) # 10 random edges
93 | edge_index = torch.randint(0, 5, (2, 10), device=device)
94 |
95 | conv = gvp.GVPConv(in_dims, out_dims, edge_dims)
96 | out = conv(nodes, edge_index, edges)
97 | ```
98 | The class `GVPConvLayer` is a `nn.Module` that forms messages using a `GVPConv` and updates the node embeddings as described in the paper. Because the updates are residual, the dimensionality of the embeddings is not changed.
99 | ```
100 | layer = gvp.GVPConvLayer(node_dims, edge_dims)
101 | nodes = layer(nodes, edge_index, edges)
102 | ```
103 | The class also allows updates in which incoming messages where src >= dst are computed using a different set of source embeddings, as in autoregressive models.
104 | ```
105 | nodes_static = gvp.randn(n=5, dims=in_dims)
106 | layer = gvp.GVPConvLayer(node_dims, edge_dims, autoregressive=True)
107 | nodes = layer(nodes, edge_index, edges, autoregressive_x=nodes_static)
108 | ```
109 | Both `GVPConv` and `GVPConvLayer` accept arguments `activations` and `vector_gate` to use vector gating.
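For example, a gated version of the layer above:
```
layer = gvp.GVPConvLayer(node_dims, edge_dims,
                         activations=(F.relu, None), vector_gate=True)
```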
110 |
111 | ### Loading data
112 |
113 | The class `gvp.data.ProteinGraphDataset` transforms protein backbone structures into featurized graphs. Following [Ingraham et al., NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design), we use a JSON/dictionary format to specify backbone structures:
114 |
115 | ```
116 | [
117 | {
118 |         "name": "NAME",
119 | "seq": "TQDCSFQHSP...",
120 | "coords": [[[74.46, 58.25, -21.65],...],...]
121 | }
122 | ...
123 | ]
124 | ```
125 | For each structure, `coords` should be a `num_residues x 4 x 3` nested list of the positions of the backbone N, C-alpha, C, and O atoms of each residue (in that order).
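As a quick shape check (a sketch, where `structure` is one entry of the list above):
```
import numpy as np

coords = np.array(structure['coords'])
assert coords.shape == (len(structure['seq']), 4, 3)  # N, C-alpha, C, O per residue
```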
126 | ```
127 | import gvp.data
128 |
129 | # structures is a list or list-like as shown above
130 | dataset = gvp.data.ProteinGraphDataset(structures)
131 | # dataset[i] is the featurized graph corresponding to structures[i]
132 | ```
133 | The returned graphs are of type `torch_geometric.data.Data` with attributes
134 | * `x`: alpha carbon coordinates
135 | * `seq`: sequence converted to int tensor according to attribute `self.letter_to_num`
136 | * `name`, `edge_index`
137 | * `node_s`, `node_v`: node features as described in the paper with dims `(6, 3)`
138 | * `edge_s`, `edge_v`: edge features as described in the paper with dims `(32, 1)`
139 | * `mask`: false for nodes with any nan coordinates
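For example, with the dims above:
```
protein = dataset[0]
print(protein.node_s.shape, protein.node_v.shape) # [n_nodes, 6], [n_nodes, 3, 3]
print(protein.edge_s.shape, protein.edge_v.shape) # [n_edges, 32], [n_edges, 1, 3]
```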
140 |
141 | The `gvp.data.ProteinGraphDataset` can be used with a `torch.utils.data.DataLoader`. We supply a class `gvp.data.BatchSampler` which forms batches based on the total number of nodes in each batch. Use of this sampler is optional.
142 | ```
143 | node_counts = [len(s['seq']) for s in structures]
144 | sampler = gvp.data.BatchSampler(node_counts, max_nodes=3000)
145 | dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
146 | ```
147 | The dataloader will return batched graphs of type `torch_geometric.data.Batch` with an additional `batch` attribute. The attributes of the `Batch` will then need to be formed into `(s, V)` tuples before passing into a GVP-GNN layer or network.
148 | ```
149 | for batch in dataloader:
150 | batch = batch.to(device) # optional
151 | nodes = (batch.node_s, batch.node_v)
152 | edges = (batch.edge_s, batch.edge_v)
153 |
154 | out = layer(nodes, batch.edge_index, edges)
155 | ```
156 |
157 | ### Ready-to-use protein GNNs
158 | We provide two fully specified networks which take in protein graphs and output a scalar prediction for each graph (`gvp.models.MQAModel`) or a 20-dimensional feature vector for each node (`gvp.models.CPDModel`), corresponding to the two tasks in our paper. Note that if you are using the unmodified `gvp.data.ProteinGraphDataset`, `node_in_dim` and `edge_in_dim` must be `(6, 3)` and `(32, 1)`, respectively.
159 | ```
160 | import gvp.models
161 |
162 | # batch, nodes, edges as formed above
163 |
164 | mqa_model = gvp.models.MQAModel(node_in_dim, node_h_dim,
165 | edge_in_dim, edge_h_dim, seq_in=True)
166 | out = mqa_model(nodes, batch.edge_index, edges,
167 | seq=batch.seq, batch=batch.batch) # shape (n_graphs,)
168 |
169 | cpd_model = gvp.models.CPDModel(node_in_dim, node_h_dim,
170 | edge_in_dim, edge_h_dim)
171 | out = cpd_model(nodes, batch.edge_index,
172 | edges, batch.seq) # shape (n_nodes, 20)
173 | ```
174 |
175 | ## Protein design
176 | We provide a script `run_cpd.py` to train, validate, and test a `CPDModel` as specified in the paper using the CATH 4.2 dataset and TS50 dataset. If you want to use a trained model on new structures, see the section "Sampling" below.
177 |
178 | ### Fetching data
179 | Run `getCATH.sh` in `data/` to fetch the CATH 4.2 dataset. If you are interested in testing on the TS50 test set, also run `grep -Fv -f ts50remove.txt chain_set.jsonl > chain_set_ts50.jsonl` to produce a training set with no overlap with the TS50 test set.
180 |
181 | ### Training / testing
182 | To train a model, simply run `python run_cpd.py --train`. To test a trained model on both the CATH 4.2 test set and the TS50 test set, run `python run_cpd.py --test-r PATH` for sequence recovery or `python run_cpd.py --test-p PATH` for perplexity. Run `python run_cpd.py -h` for more detailed options.
183 |
184 | ```
185 | $ python run_cpd.py -h
186 |
187 | usage: run_cpd.py [-h] [--models-dir PATH] [--num-workers N] [--max-nodes N] [--epochs N] [--cath-data PATH] [--cath-splits PATH] [--ts50 PATH] [--train] [--test-r PATH] [--test-p PATH] [--n-samples N]
188 |
189 | optional arguments:
190 | -h, --help show this help message and exit
191 | --models-dir PATH directory to save trained models, default=./models/
192 | --num-workers N number of threads for loading data, default=4
193 | --max-nodes N max number of nodes per batch, default=3000
194 | --epochs N training epochs, default=100
195 | --cath-data PATH location of CATH dataset, default=./data/chain_set.jsonl
196 | --cath-splits PATH location of CATH split file, default=./data/chain_set_splits.json
197 | --ts50 PATH location of TS50 dataset, default=./data/ts50.json
198 | --train train a model
199 | --test-r PATH evaluate a trained model on recovery (without training)
200 | --test-p PATH evaluate a trained model on perplexity (without training)
201 | --n-samples N number of sequences to sample (if testing recovery), default=100
202 | ```
203 | **Confusion matrices:** Note that the values are normalized such that each row (corresponding to the true class) sums to 1000, with the actual number of residues in that class printed under the "Count" column.
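For example, a diagonal entry of 900 means that 90% of the residues of that true class were predicted correctly.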
204 |
205 | ### Sampling
206 | To sample from a `CPDModel`, prepare a `ProteinGraphDataset`, but do NOT pass it into a `DataLoader`. The sequences are not used, so placeholders can be used for the `seq` attributes of the original structure dicts.
207 |
208 | ```
209 | protein = dataset[i]
210 | nodes = (protein.node_s, protein.node_v)
211 | edges = (protein.edge_s, protein.edge_v)
212 |
213 | sample = model.sample(nodes, protein.edge_index, # shape = (n_samples, n_nodes)
214 | edges, n_samples=n_samples)
215 | ```
216 | The output will be an int tensor, using the same aa-to-int mapping that was used when training the model.
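To recover letters from the samples, one can invert the dataset's `letter_to_num` mapping (a sketch):
```
num_to_letter = {v: k for k, v in dataset.letter_to_num.items()}
seqs = [''.join(num_to_letter[int(t)] for t in row) for row in sample]
```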
217 |
218 | ## ATOM3D
219 | We provide models and dataloaders for all ATOM3D tasks in `gvp.atom3d`, as well as a training and testing script in `run_atom3d.py`. This also supports loading pretrained weights for transfer learning experiments.
220 |
221 | ### Models / data loaders
222 | The GVP-GNNs for ATOM3D are supplied in `gvp.atom3d` and are named after each task: `gvp.atom3d.MSPModel`, `gvp.atom3d.PPIModel`, etc. All of these extend the base class `gvp.atom3d.BaseModel`. These classes take no arguments at initialization, take in a `torch_geometric.data.Batch` representation of a batch of structures, and return an output corresponding to the task. Details vary based on the exact task---see the docstrings.
223 | ```
224 | psr_model = gvp.atom3d.PSRModel()
225 | ```
226 | `gvp.atom3d` also includes data loaders to produce `torch_geometric.data.Batch` objects from an underlying `atom3d.datasets.LMDBDataset`. For all tasks except PPI and RES, these are callable transform objects---`gvp.atom3d.SMPTransform`, `gvp.atom3d.RSRTransform`, etc.---which should be passed into the constructor of an `atom3d.datasets.LMDBDataset`:
227 | ```
228 | psr_dataset = atom3d.datasets.LMDBDataset(path_to_dataset,
229 | transform=gvp.atom3d.PSRTransform())
230 | ```
231 | On the other hand, `gvp.atom3d.PPIDataset` and `gvp.atom3d.RESDataset` take the place of / are wrappers around the `atom3d.datasets.LMDBDataset`:
232 | ```
233 | ppi_dataset = gvp.atom3d.PPIDataset(path_to_dataset)
234 | res_dataset = gvp.atom3d.RESDataset(path_to_dataset, path_to_split) # see docstring
235 | ```
236 | All datasets must then be wrapped in a `torch_geometric.data.DataLoader`:
237 | ```
238 | psr_dataloader = torch_geometric.data.DataLoader(psr_dataset, batch_size=batch_size)
239 | ```
240 | The dataloaders can be directly iterated over to yield `torch_geometric.data.Batch` objects, which can then be passed into the models.
241 | ```
242 | for batch in psr_dataloader:
243 | pred = psr_model(batch) # pred.shape = (batch_size,)
244 | ```
245 |
246 | ### Training / testing
247 |
248 | To run training / testing on ATOM3D, download the datasets as described [here](https://www.atom3d.ai/). Modify the function `get_datasets` in `run_atom3d.py` with the paths to the datasets. Then run:
249 | ```
250 | $ python run_atom3d.py -h
251 |
252 | usage: run_atom3d.py [-h] [--num-workers N] [--smp-idx IDX]
253 | [--lba-split SPLIT] [--batch SIZE] [--train-time MINUTES]
254 | [--val-time MINUTES] [--epochs N] [--test PATH]
255 | [--lr RATE] [--load PATH]
256 | TASK
257 |
258 | positional arguments:
259 | TASK {PSR, RSR, PPI, RES, MSP, SMP, LBA, LEP}
260 |
261 | optional arguments:
262 | -h, --help show this help message and exit
263 | --num-workers N number of threads for loading data, default=4
264 | --smp-idx IDX label index for SMP, in range 0-19
265 | --lba-split SPLIT identity cutoff for LBA, 30 (default) or 60
266 | --batch SIZE batch size, default=8
267 | --train-time MINUTES maximum time between evaluations on valset,
268 | default=120 minutes
269 | --val-time MINUTES maximum time per evaluation on valset, default=20
270 | minutes
271 | --epochs N training epochs, default=50
272 | --test PATH evaluate a trained model
273 | --lr RATE learning rate
274 | --load PATH initialize first 2 GNN layers with pretrained weights
275 | ```
276 | For example:
277 | ```
278 | # train a model
279 | python run_atom3d.py PSR
280 |
281 | # train a model with pretrained weights
282 | python run_atom3d.py PSR --load PATH
283 |
284 | # evaluate a model
285 | python run_atom3d.py PSR --test PATH
286 | ```
287 |
288 | ## Acknowledgements
289 | Portions of the input data pipeline were adapted from [Ingraham et al., NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design). We thank Pratham Soni for portions of the implementation in PyTorch.
290 |
291 | ## Citation
292 | ```
293 | @inproceedings{
294 | jing2021learning,
295 | title={Learning from Protein Structure with Geometric Vector Perceptrons},
296 | author={Bowen Jing and Stephan Eismann and Patricia Suriana and Raphael John Lamarre Townshend and Ron Dror},
297 | booktitle={International Conference on Learning Representations},
298 | year={2021},
299 | url={https://openreview.net/forum?id=1YLJDvSx6J4}
300 | }
301 |
302 | @article{jing2021equivariant,
303 | title={Equivariant Graph Neural Networks for 3D Macromolecular Structure},
304 | author={Jing, Bowen and Eismann, Stephan and Soni, Pratham N and Dror, Ron O},
305 | journal={arXiv preprint arXiv:2106.03843},
306 | year={2021}
307 | }
308 | ```
309 |
--------------------------------------------------------------------------------
/data/getCATH.sh:
--------------------------------------------------------------------------------
1 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set.jsonl
2 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set_splits.json
3 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/SPIN2/test_split_L100.json
4 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/SPIN2/test_split_sc.json
5 |
--------------------------------------------------------------------------------
/data/ts50remove.txt:
--------------------------------------------------------------------------------
1 | 1ypo.A
2 | 1y1l.A
3 | 1am2.A
4 | 3id9.A
5 | 3zid.B
6 | 3vpp.B
7 | 1gyc.A
8 | 3r79.A
9 | 3e82.A
10 | 3lqc.A
11 | 2cvi.A
12 | 2v6y.A
13 | 3nng.A
14 | 1zak.A
15 | 2nvp.A
16 | 1vef.A
17 | 2kp5.A
18 | 2b3h.A
19 | 2n8k.A
20 | 2f91.A
21 | 3e48.A
22 | 2v3t.A
23 | 4jvs.A
24 | 3fpw.A
25 | 1c0n.A
26 | 2gu3.A
27 | 1es0.A
28 | 3zhg.A
29 | 4epk.B
30 | 2avp.A
31 | 2e9x.A
32 | 4dkc.B
33 | 5b1a.M
34 | 3nbk.D
35 | 1o9i.A
36 | 4hcw.A
37 | 3mmz.A
38 | 2rbk.A
39 | 4yon.A
40 | 4nav.A
41 | 4cgw.A
42 | 3io1.A
43 | 1dxe.A
44 | 1a17.A
45 | 3cyn.B
46 | 3m6n.B
47 | 1ble.A
48 | 4qjb.B
49 | 2fml.A
50 | 4eo3.A
51 | 1r5m.A
52 | 3dao.A
53 | 3hhj.B
54 | 1e1h.B
55 | 2jty.A
56 | 3gkm.A
57 | 2wgl.A
58 | 3ajv.A
59 | 2duk.A
60 | 1dd3.A
61 | 3fhk.A
62 | 2fi7.A
63 | 2h2t.B
64 | 1zde.A
65 | 5hk8.A
66 | 2xcj.A
67 | 2gys.A
68 | 1lmj.A
69 | 2e7x.A
70 | 1kql.B
71 | 1yul.A
72 | 2l49.B
73 | 3l0g.A
74 | 3k4i.A
75 | 3l23.A
76 | 1qkr.A
77 | 2bw2.A
78 | 1vc9.A
79 | 4uqx.A
80 | 1emo.A
81 | 3oee.Y
82 | 2xdg.A
83 | 3sz7.A
84 | 2bv2.A
85 | 4jpb.W
86 | 1nps.A
87 | 5got.A
88 | 4mi7.A
89 | 1rtt.A
90 | 1ouv.A
91 | 3bzm.A
92 | 2dx6.A
93 | 1c17.M
94 | 3q4o.A
95 | 1s0w.C
96 | 4ozu.A
97 | 2kna.A
98 | 1jkv.A
99 | 4tkz.A
100 | 3dfj.A
101 | 4k2b.A
102 | 3ejf.A
103 | 3cyg.A
104 | 3qo6.A
105 | 3pt1.A
106 | 2a6t.B
107 | 3mt1.A
108 | 4fd9.A
109 | 2oga.A
110 | 1fjr.A
111 | 1im2.A
112 | 4hfs.A
113 | 3nd5.A
114 | 4lx3.A
115 | 1k1e.D
116 | 1xnf.A
117 | 1f3y.A
118 | 3fk9.A
119 | 4r42.A
120 | 3fmy.A
121 | 1b9x.C
122 | 5im4.F
123 | 3ajv.C
124 | 1kli.L
125 | 5epf.A
126 | 1a45.A
127 | 3vld.A
128 | 2qdl.A
129 | 4nfw.F
130 | 2hi6.A
131 | 1wp0.A
132 | 1mr1.D
133 | 4ar0.A
134 | 3isz.A
135 | 2km4.A
136 | 2o1e.A
137 | 1v7m.V
138 | 4mt4.A
139 | 4fz2.A
140 | 1emn.A
141 | 1kcg.B
142 | 3cng.A
143 | 1va9.A
144 | 2qkm.B
145 | 1ir6.A
146 | 3ny7.A
147 | 1rtq.A
148 | 4hku.A
149 | 3hz2.A
150 | 3piv.A
151 | 2bou.A
152 | 3ecs.A
153 | 2h1y.B
154 | 5tf5.A
155 | 3pe8.A
156 | 1coz.A
157 | 3grn.A
158 | 1kt0.A
159 | 1sqs.A
160 | 5erm.B
161 | 2rag.A
162 | 1akq.A
163 | 1dv8.A
164 | 3u37.A
165 | 5egw.A
166 | 3fm2.A
167 | 5g5g.B
168 | 5e3i.A
169 | 3mpo.A
170 | 1oki.A
171 | 1q33.A
172 | 1dx5.I
173 | 2p9j.A
174 | 4xch.A
175 | 2gpy.A
176 | 2b06.A
177 | 2ra8.A
178 | 2w86.A
179 | 2fek.A
180 | 1egg.A
181 | 3kqg.A
182 | 2dnm.A
183 | 3oq0.B
184 | 4jzs.A
185 | 3a6s.A
186 | 3f13.B
187 | 1vs3.A
188 | 1i8n.A
189 | 1g3i.W
190 | 1ddm.A
191 | 1k2e.A
192 | 1wr8.A
193 | 2va0.A
194 | 3apa.A
195 | 3k2k.A
196 | 3gg6.A
197 | 2vyi.A
198 | 4lej.A
199 | 1a7h.A
200 | 1zu4.A
201 | 3pfo.A
202 | 3uj6.A
203 | 3d68.A
204 | 1b08.C
205 | 4ezb.A
206 | 4lq6.A
207 | 3p8b.B
208 | 2cxx.A
209 | 2i15.A
210 | 1t0i.A
211 | 1ahs.A
212 | 4g4p.A
213 | 3k96.A
214 | 2q88.A
215 | 3o38.B
216 | 4um7.B
217 | 3wur.A
218 | 3cq4.A
219 | 3h95.A
220 | 3a3j.A
221 | 3n8h.A
222 | 1dqb.A
223 | 2p5t.A
224 | 3hkl.A
225 | 4gxw.B
226 | 1j71.A
227 | 3gwi.A
228 | 3q10.D
229 | 1hjz.A
230 | 1kmj.A
231 | 2qjt.B
232 | 3bil.A
233 | 3fbs.A
234 | 3fbs.B
235 | 2vvp.C
236 | 2h1y.A
237 | 3eix.A
238 | 3rui.A
239 | 1efa.A
240 | 2zag.A
241 | 1x57.A
242 | 2rrk.A
243 | 1ucv.A
244 | 3pqi.A
245 | 1qwo.A
246 | 3uk7.A
247 | 1l3a.A
248 | 3skv.A
249 | 3n77.A
250 | 3d8d.A
251 | 1w3i.A
252 | 2ra6.C
253 | 1r61.A
254 | 3qao.A
255 | 1wu4.A
256 | 2d1c.A
257 | 2w2i.A
258 | 2rdx.A
259 | 4gp6.A
260 | 2q1s.A
261 | 4mzy.A
262 | 1bvy.F
263 | 3ofn.Y
264 | 3nv7.A
265 | 3oir.A
266 | 2gwn.A
267 | 2dad.A
268 | 3on9.A
269 | 3rh0.A
270 | 3lvy.A
271 | 3lwk.A
272 | 3kol.A
273 | 1na0.A
274 | 4b6z.A
275 | 3f6a.A
276 | 3u7r.A
277 | 1nrw.A
278 | 2nly.A
279 | 1dm1.A
280 | 4nkp.A
281 | 4og1.A
282 | 5hci.C
283 | 1v7p.B
284 | 2oge.A
285 | 4o8s.A
286 | 2b8e.A
287 | 1elr.A
288 | 1u6l.A
289 | 1oqy.A
290 | 5a2q.I
291 | 1ugo.A
292 | 2a4v.A
293 | 3so6.A
294 | 2yyh.A
295 | 3tqt.A
296 | 1xxm.C
297 | 2zad.A
298 | 1jhe.A
299 | 2erx.A
300 | 1ete.A
301 | 2j4b.B
302 | 2zyz.C
303 | 1hv8.A
304 | 3iey.B
305 | 2jod.A
306 | 1tdq.B
307 | 3pqh.A
308 | 3o8s.A
309 | 2ca5.A
310 | 1zej.A
311 | 3fgh.A
312 | 4qam.B
313 | 5lba.D
314 | 5cfj.A
315 | 2efp.A
316 | 1erj.B
317 | 3pgv.A
318 | 2rav.A
319 | 1o1x.A
320 | 3bkw.A
321 | 1aw5.A
322 | 1st9.B
323 | 4lzh.A
324 | 1i0r.A
325 | 3ujp.A
326 | 3rof.A
327 | 2b9l.A
328 | 3ff7.C
329 | 4w64.B
330 | 1pdo.A
331 | 2nrr.A
332 | 1zwx.A
333 | 3aqf.B
334 | 4ptz.C
335 | 2bmx.B
336 | 3t38.A
337 | 4dev.D
338 | 1wlz.A
339 | 3t9w.A
340 | 2k13.X
341 | 3ufx.B
342 | 3c22.C
343 | 3bpt.A
344 | 2j49.A
345 | 1sjy.A
346 | 1nni.1
347 | 3e8m.A
348 | 2b0v.A
349 | 3fjy.A
350 | 1yxr.A
351 | 1fon.A
352 | 1xip.A
353 | 1x6o.A
354 | 2acf.D
355 | 3e59.B
356 | 1euv.B
357 | 1r4s.A
358 | 3cw3.A
359 | 3mxt.A
360 | 3g80.A
361 | 3aqg.B
362 | 3nhi.A
363 | 3v64.C
364 | 1a9x.B
365 | 2myu.A
366 | 2qdf.A
367 | 1mo7.A
368 | 1ksh.B
369 | 3lst.A
370 | 3l0g.B
371 | 4pmx.A
372 | 2r75.1
373 | 1bvp.1
374 | 1u8s.A
375 | 3kxy.J
376 | 4lfl.B
377 | 1g41.A
378 | 1a79.A
379 | 1i1g.A
380 | 4jbc.A
381 | 4k2m.A
382 | 2i39.B
383 | 2zai.A
384 | 3l4r.A
385 | 1y9q.A
386 | 1fsu.A
387 | 4qd4.A
388 | 2cyy.A
389 | 2cay.B
390 | 3lcm.A
391 | 2ls8.A
392 | 1rkq.A
393 | 1gvf.B
394 | 2g5f.B
395 | 2kz8.A
396 | 3k9x.C
397 | 3tqt.B
398 | 2iu8.C
399 | 5tgz.A
400 | 1lw7.A
401 | 4fo9.A
402 | 2qlr.A
403 | 3mpo.D
404 | 5t3u.B
405 | 1b1c.A
406 | 2z1d.A
407 | 4egs.A
408 | 5jci.A
409 | 2jxx.A
410 | 2cg4.A
411 | 2pqv.A
412 | 2kc7.A
413 | 3fiu.A
414 | 2ox8.A
415 | 3fvw.A
416 | 1vk6.A
417 | 4zev.A
418 | 5i3s.C
419 | 2a2l.C
420 | 3drn.B
421 | 1lpa.A
422 | 3exq.A
423 | 5jsz.A
424 | 1qze.A
425 | 3mcx.A
426 | 2w4e.A
427 | 2om6.A
428 | 4wzu.A
429 | 2j5o.A
430 | 2e0i.A
431 | 3q9s.A
432 | 3ewi.B
433 | 1or4.B
434 | 2aeg.A
435 | 4rmo.A
436 |
--------------------------------------------------------------------------------
/gvp/__init__.py:
--------------------------------------------------------------------------------
1 | import torch, functools
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import MessagePassing
5 | from torch_scatter import scatter_add
6 |
7 | def tuple_sum(*args):
8 | '''
9 | Sums any number of tuples (s, V) elementwise.
10 | '''
11 | return tuple(map(sum, zip(*args)))
12 |
13 | def tuple_cat(*args, dim=-1):
14 | '''
15 | Concatenates any number of tuples (s, V) elementwise.
16 |
17 | :param dim: dimension along which to concatenate when viewed
18 | as the `dim` index for the scalar-channel tensors.
19 | This means that `dim=-1` will be applied as
20 | `dim=-2` for the vector-channel tensors.
21 | '''
22 | dim %= len(args[0][0].shape)
23 | s_args, v_args = list(zip(*args))
24 | return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)
25 |
26 | def tuple_index(x, idx):
27 | '''
28 | Indexes into a tuple (s, V) along the first dimension.
29 |
30 | :param idx: any object which can be used to index into a `torch.Tensor`
31 | '''
32 | return x[0][idx], x[1][idx]
33 |
34 | def randn(n, dims, device="cpu"):
35 | '''
36 | Returns random tuples (s, V) drawn elementwise from a normal distribution.
37 |
38 | :param n: number of data points
39 | :param dims: tuple of dimensions (n_scalar, n_vector)
40 |
41 | :return: (s, V) with s.shape = (n, n_scalar) and
42 | V.shape = (n, n_vector, 3)
43 | '''
44 | return torch.randn(n, dims[0], device=device), \
45 | torch.randn(n, dims[1], 3, device=device)
46 |
47 | def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
48 | '''
49 | L2 norm of tensor clamped above a minimum value `eps`.
50 |
51 | :param sqrt: if `False`, returns the square of the L2 norm
52 | '''
53 | out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
54 | return torch.sqrt(out) if sqrt else out
55 |
56 | def _split(x, nv):
57 | '''
58 | Splits a merged representation of (s, V) back into a tuple.
59 | Should be used only with `_merge(s, V)` and only if the tuple
60 | representation cannot be used.
61 |
62 | :param x: the `torch.Tensor` returned from `_merge`
63 | :param nv: the number of vector channels in the input to `_merge`
64 | '''
65 | v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
66 | s = x[..., :-3*nv]
67 | return s, v
68 |
69 | def _merge(s, v):
70 | '''
71 | Merges a tuple (s, V) into a single `torch.Tensor`, where the
72 | vector channels are flattened and appended to the scalar channels.
73 | Should be used only if the tuple representation cannot be used.
74 | Use `_split(x, nv)` to reverse.
75 | '''
76 | v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
77 | return torch.cat([s, v], -1)
78 |
79 | class GVP(nn.Module):
80 | '''
81 | Geometric Vector Perceptron. See manuscript and README.md
82 | for more details.
83 |
84 | :param in_dims: tuple (n_scalar, n_vector)
85 | :param out_dims: tuple (n_scalar, n_vector)
86 | :param h_dim: intermediate number of vector channels, optional
87 | :param activations: tuple of functions (scalar_act, vector_act)
88 | :param vector_gate: whether to use vector gating.
89 | (vector_act will be used as sigma^+ in vector gating if `True`)
90 | '''
91 | def __init__(self, in_dims, out_dims, h_dim=None,
92 | activations=(F.relu, torch.sigmoid), vector_gate=False):
93 | super(GVP, self).__init__()
94 | self.si, self.vi = in_dims
95 | self.so, self.vo = out_dims
96 | self.vector_gate = vector_gate
97 | if self.vi:
98 | self.h_dim = h_dim or max(self.vi, self.vo)
99 | self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
100 | self.ws = nn.Linear(self.h_dim + self.si, self.so)
101 | if self.vo:
102 | self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
103 | if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
104 | else:
105 | self.ws = nn.Linear(self.si, self.so)
106 |
107 | self.scalar_act, self.vector_act = activations
108 | self.dummy_param = nn.Parameter(torch.empty(0))
109 |
110 | def forward(self, x):
111 | '''
112 | :param x: tuple (s, V) of `torch.Tensor`,
113 | or (if vectors_in is 0), a single `torch.Tensor`
114 | :return: tuple (s, V) of `torch.Tensor`,
115 | or (if vectors_out is 0), a single `torch.Tensor`
116 | '''
117 | if self.vi:
118 | s, v = x
119 | v = torch.transpose(v, -1, -2)
120 | vh = self.wh(v)
121 | vn = _norm_no_nan(vh, axis=-2)
122 | s = self.ws(torch.cat([s, vn], -1))
123 | if self.vo:
124 | v = self.wv(vh)
125 | v = torch.transpose(v, -1, -2)
126 | if self.vector_gate:
127 | if self.vector_act:
128 | gate = self.wsv(self.vector_act(s))
129 | else:
130 | gate = self.wsv(s)
131 | v = v * torch.sigmoid(gate).unsqueeze(-1)
132 | elif self.vector_act:
133 | v = v * self.vector_act(
134 | _norm_no_nan(v, axis=-1, keepdims=True))
135 | else:
136 | s = self.ws(x)
137 | if self.vo:
138 | v = torch.zeros(s.shape[0], self.vo, 3,
139 | device=self.dummy_param.device)
140 | if self.scalar_act:
141 | s = self.scalar_act(s)
142 |
143 | return (s, v) if self.vo else s
144 |
145 | class _VDropout(nn.Module):
146 | '''
147 | Vector channel dropout where the elements of each
148 | vector channel are dropped together.
149 | '''
150 | def __init__(self, drop_rate):
151 | super(_VDropout, self).__init__()
152 | self.drop_rate = drop_rate
153 | self.dummy_param = nn.Parameter(torch.empty(0))
154 |
155 | def forward(self, x):
156 | '''
157 | :param x: `torch.Tensor` corresponding to vector channels
158 | '''
159 | device = self.dummy_param.device
160 | if not self.training:
161 | return x
162 | mask = torch.bernoulli(
163 | (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
164 | ).unsqueeze(-1)
165 | x = mask * x / (1 - self.drop_rate)
166 | return x
167 |
168 | class Dropout(nn.Module):
169 | '''
170 | Combined dropout for tuples (s, V).
171 | Takes tuples (s, V) as input and as output.
172 | '''
173 | def __init__(self, drop_rate):
174 | super(Dropout, self).__init__()
175 | self.sdropout = nn.Dropout(drop_rate)
176 | self.vdropout = _VDropout(drop_rate)
177 |
178 | def forward(self, x):
179 | '''
180 | :param x: tuple (s, V) of `torch.Tensor`,
181 | or single `torch.Tensor`
182 | (will be assumed to be scalar channels)
183 | '''
184 | if type(x) is torch.Tensor:
185 | return self.sdropout(x)
186 | s, v = x
187 | return self.sdropout(s), self.vdropout(v)
188 |
189 | class LayerNorm(nn.Module):
190 | '''
191 | Combined LayerNorm for tuples (s, V).
192 | Takes tuples (s, V) as input and as output.
193 | '''
194 | def __init__(self, dims):
195 | super(LayerNorm, self).__init__()
196 | self.s, self.v = dims
197 | self.scalar_norm = nn.LayerNorm(self.s)
198 |
199 | def forward(self, x):
200 | '''
201 | :param x: tuple (s, V) of `torch.Tensor`,
202 | or single `torch.Tensor`
203 | (will be assumed to be scalar channels)
204 | '''
205 | if not self.v:
206 | return self.scalar_norm(x)
207 | s, v = x
208 | vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
209 | vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True))
210 | return self.scalar_norm(s), v / vn
211 |
212 | class GVPConv(MessagePassing):
213 | '''
214 | Graph convolution / message passing with Geometric Vector Perceptrons.
215 | Takes in a graph with node and edge embeddings,
216 | and returns new node embeddings.
217 |
218 | This does NOT do residual updates and pointwise feedforward layers
219 | ---see `GVPConvLayer`.
220 |
221 | :param in_dims: input node embedding dimensions (n_scalar, n_vector)
222 | :param out_dims: output node embedding dimensions (n_scalar, n_vector)
223 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
224 | :param n_layers: number of GVPs in the message function
225 | :param module_list: preconstructed message function, overrides n_layers
226 | :param aggr: should be "add" if some incoming edges are masked, as in
227 | a masked autoregressive decoder architecture, otherwise "mean"
228 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
229 | :param vector_gate: whether to use vector gating.
230 | (vector_act will be used as sigma^+ in vector gating if `True`)
231 | '''
232 | def __init__(self, in_dims, out_dims, edge_dims,
233 | n_layers=3, module_list=None, aggr="mean",
234 | activations=(F.relu, torch.sigmoid), vector_gate=False):
235 | super(GVPConv, self).__init__(aggr=aggr)
236 | self.si, self.vi = in_dims
237 | self.so, self.vo = out_dims
238 | self.se, self.ve = edge_dims
239 |
240 | GVP_ = functools.partial(GVP,
241 | activations=activations, vector_gate=vector_gate)
242 |
243 | module_list = module_list or []
244 | if not module_list:
245 | if n_layers == 1:
246 | module_list.append(
247 | GVP_((2*self.si + self.se, 2*self.vi + self.ve),
248 | (self.so, self.vo), activations=(None, None)))
249 | else:
250 | module_list.append(
251 | GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
252 | )
253 | for i in range(n_layers - 2):
254 | module_list.append(GVP_(out_dims, out_dims))
255 | module_list.append(GVP_(out_dims, out_dims,
256 | activations=(None, None)))
257 | self.message_func = nn.Sequential(*module_list)
258 |
259 | def forward(self, x, edge_index, edge_attr):
260 | '''
261 | :param x: tuple (s, V) of `torch.Tensor`
262 | :param edge_index: array of shape [2, n_edges]
263 | :param edge_attr: tuple (s, V) of `torch.Tensor`
264 | '''
265 | x_s, x_v = x
266 | message = self.propagate(edge_index,
267 | s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]),
268 | edge_attr=edge_attr)
269 | return _split(message, self.vo)
270 |
271 | def message(self, s_i, v_i, s_j, v_j, edge_attr):
272 | v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
273 | v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
274 | message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
275 | message = self.message_func(message)
276 | return _merge(*message)
277 |
278 |
279 | class GVPConvLayer(nn.Module):
280 | '''
281 | Full graph convolution / message passing layer with
282 | Geometric Vector Perceptrons. Residually updates node embeddings with
283 | aggregated incoming messages, applies a pointwise feedforward
284 | network to node embeddings, and returns updated node embeddings.
285 |
286 | To only compute the aggregated messages, see `GVPConv`.
287 |
288 | :param node_dims: node embedding dimensions (n_scalar, n_vector)
289 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
290 | :param n_message: number of GVPs to use in message function
291 | :param n_feedforward: number of GVPs to use in feedforward function
292 | :param drop_rate: drop probability in all dropout layers
293 | :param autoregressive: if `True`, this `GVPConvLayer` will be used
294 | with a different set of input node embeddings for messages
295 | where src >= dst
296 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
297 | :param vector_gate: whether to use vector gating.
298 | (vector_act will be used as sigma^+ in vector gating if `True`)
299 | '''
300 | def __init__(self, node_dims, edge_dims,
301 | n_message=3, n_feedforward=2, drop_rate=.1,
302 | autoregressive=False,
303 | activations=(F.relu, torch.sigmoid), vector_gate=False):
304 |
305 | super(GVPConvLayer, self).__init__()
306 | self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
307 | aggr="add" if autoregressive else "mean",
308 | activations=activations, vector_gate=vector_gate)
309 | GVP_ = functools.partial(GVP,
310 | activations=activations, vector_gate=vector_gate)
311 | self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
312 | self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
313 |
314 | ff_func = []
315 | if n_feedforward == 1:
316 | ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
317 | else:
318 | hid_dims = 4*node_dims[0], 2*node_dims[1]
319 | ff_func.append(GVP_(node_dims, hid_dims))
320 | for i in range(n_feedforward-2):
321 | ff_func.append(GVP_(hid_dims, hid_dims))
322 | ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
323 | self.ff_func = nn.Sequential(*ff_func)
324 |
325 | def forward(self, x, edge_index, edge_attr,
326 | autoregressive_x=None, node_mask=None):
327 | '''
328 | :param x: tuple (s, V) of `torch.Tensor`
329 | :param edge_index: array of shape [2, n_edges]
330 | :param edge_attr: tuple (s, V) of `torch.Tensor`
331 | :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
332 | If not `None`, will be used as src node embeddings
333 |             for forming messages where src >= dst. The current node
334 | embeddings `x` will still be the base of the update and the
335 | pointwise feedforward.
336 | :param node_mask: array of type `bool` to index into the first
337 | dim of node embeddings (s, V). If not `None`, only
338 | these nodes will be updated.
339 | '''
340 |
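        # In the autoregressive case, messages along edges with src < dst are
        # formed from the current embeddings x, while messages along edges with
        # src >= dst use autoregressive_x. Both passes aggregate with "add";
        # dividing by the in-degree of each dst node then recovers the mean.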
341 | if autoregressive_x is not None:
342 | src, dst = edge_index
343 | mask = src < dst
344 | edge_index_forward = edge_index[:, mask]
345 | edge_index_backward = edge_index[:, ~mask]
346 | edge_attr_forward = tuple_index(edge_attr, mask)
347 | edge_attr_backward = tuple_index(edge_attr, ~mask)
348 |
349 | dh = tuple_sum(
350 | self.conv(x, edge_index_forward, edge_attr_forward),
351 | self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
352 | )
353 |
354 | count = scatter_add(torch.ones_like(dst), dst,
355 | dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
356 |
357 | dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
358 |
359 | else:
360 | dh = self.conv(x, edge_index, edge_attr)
361 |
362 | if node_mask is not None:
363 | x_ = x
364 | x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
365 |
366 | x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
367 |
368 | dh = self.ff_func(x)
369 | x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))
370 |
371 | if node_mask is not None:
372 | x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
373 | x = x_
374 | return x
375 |
--------------------------------------------------------------------------------
/gvp/atom3d.py:
--------------------------------------------------------------------------------
1 | import torch, random, scipy, math
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import pandas as pd
5 | import numpy as np
6 | from atom3d.datasets import LMDBDataset
7 | import atom3d.datasets.ppi.neighbors as nb
8 | from torch.utils.data import IterableDataset
9 | from . import GVP, GVPConvLayer, LayerNorm
10 | import torch_cluster, torch_geometric, torch_scatter
11 | from .data import _normalize, _rbf
12 |
13 | _NUM_ATOM_TYPES = 9
14 | _element_mapping = lambda x: {
15 | 'H' : 0,
16 | 'C' : 1,
17 | 'N' : 2,
18 | 'O' : 3,
19 | 'F' : 4,
20 | 'S' : 5,
21 | 'Cl': 6, 'CL': 6,
22 | 'P' : 7
23 | }.get(x, 8)
24 | _amino_acids = lambda x: {
25 | 'ALA': 0,
26 | 'ARG': 1,
27 | 'ASN': 2,
28 | 'ASP': 3,
29 | 'CYS': 4,
30 | 'GLU': 5,
31 | 'GLN': 6,
32 | 'GLY': 7,
33 | 'HIS': 8,
34 | 'ILE': 9,
35 | 'LEU': 10,
36 | 'LYS': 11,
37 | 'MET': 12,
38 | 'PHE': 13,
39 | 'PRO': 14,
40 | 'SER': 15,
41 | 'THR': 16,
42 | 'TRP': 17,
43 | 'TYR': 18,
44 | 'VAL': 19
45 | }.get(x, 20)
46 | _DEFAULT_V_DIM = (100, 16)
47 | _DEFAULT_E_DIM = (32, 1)
48 |
49 | def _edge_features(coords, edge_index, D_max=4.5, num_rbf=16, device='cpu'):
50 |
51 | E_vectors = coords[edge_index[0]] - coords[edge_index[1]]
52 | rbf = _rbf(E_vectors.norm(dim=-1),
53 | D_max=D_max, D_count=num_rbf, device=device)
54 |
55 | edge_s = rbf
56 | edge_v = _normalize(E_vectors).unsqueeze(-2)
57 |
58 | edge_s, edge_v = map(torch.nan_to_num,
59 | (edge_s, edge_v))
60 |
61 | return edge_s, edge_v
62 |
63 | class BaseTransform:
64 | '''
65 | Implementation of an ATOM3D Transform which featurizes the atomic
66 |     coordinates in an ATOM3D dataframe into `torch_geometric.data.Data`
67 | graphs. This class should not be used directly; instead, use the
68 | task-specific transforms, which all extend BaseTransform. Node
69 | and edge features are as described in the EGNN manuscript.
70 |
71 | Returned graphs have the following attributes:
72 | -x atomic coordinates, shape [n_nodes, 3]
73 | -atoms numeric encoding of atomic identity, shape [n_nodes]
74 | -edge_index edge indices, shape [2, n_edges]
75 | -edge_s edge scalar features, shape [n_edges, 16]
76 |     -edge_v      edge vector features, shape [n_edges, 1, 3]
77 |
78 | Subclasses of BaseTransform will produce graphs with additional
79 |     attributes for the task-specific training labels, in addition
80 | to the above.
81 |
82 | All subclasses of BaseTransform directly inherit the BaseTransform
83 | constructor.
84 |
85 | :param edge_cutoff: distance cutoff to use when drawing edges
86 | :param num_rbf: number of radial bases to encode the distance on each edge
87 |     :param device: if "cuda", will do preprocessing on the GPU
88 | '''
89 | def __init__(self, edge_cutoff=4.5, num_rbf=16, device='cpu'):
90 | self.edge_cutoff = edge_cutoff
91 | self.num_rbf = num_rbf
92 | self.device = device
93 |
94 | def __call__(self, df):
95 | '''
96 | :param df: `pandas.DataFrame` of atomic coordinates
97 | in the ATOM3D format
98 |
99 | :return: `torch_geometric.data.Data` structure graph
100 | '''
101 | with torch.no_grad():
102 | coords = torch.as_tensor(df[['x', 'y', 'z']].to_numpy(),
103 | dtype=torch.float32, device=self.device)
104 | atoms = torch.as_tensor(list(map(_element_mapping, df.element)),
105 | dtype=torch.long, device=self.device)
106 |
107 | edge_index = torch_cluster.radius_graph(coords, r=self.edge_cutoff)
108 |
109 | edge_s, edge_v = _edge_features(coords, edge_index,
110 | D_max=self.edge_cutoff, num_rbf=self.num_rbf, device=self.device)
111 |
112 | return torch_geometric.data.Data(x=coords, atoms=atoms,
113 | edge_index=edge_index, edge_s=edge_s, edge_v=edge_v)
114 |
115 | class BaseModel(nn.Module):
116 | '''
117 | A base 5-layer GVP-GNN for all ATOM3D tasks, using GVPs with
118 | vector gating as described in the manuscript. Takes in atomic-level
119 | structure graphs of type `torch_geometric.data.Batch`
120 | and returns a single scalar.
121 |
122 | This class should not be used directly. Instead, please use the
123 | task-specific models which extend BaseModel. (Some of these classes
124 | may be aliases of BaseModel.)
125 |
126 | :param num_rbf: number of radial bases to use in the edge embedding
127 | '''
128 | def __init__(self, num_rbf=16):
129 |
130 | super().__init__()
131 | activations = (F.relu, None)
132 |
133 | self.embed = nn.Embedding(_NUM_ATOM_TYPES, _NUM_ATOM_TYPES)
134 |
135 | self.W_e = nn.Sequential(
136 | LayerNorm((num_rbf, 1)),
137 | GVP((num_rbf, 1), _DEFAULT_E_DIM,
138 | activations=(None, None), vector_gate=True)
139 | )
140 |
141 | self.W_v = nn.Sequential(
142 | LayerNorm((_NUM_ATOM_TYPES, 0)),
143 | GVP((_NUM_ATOM_TYPES, 0), _DEFAULT_V_DIM,
144 | activations=(None, None), vector_gate=True)
145 | )
146 |
147 | self.layers = nn.ModuleList(
148 | GVPConvLayer(_DEFAULT_V_DIM, _DEFAULT_E_DIM,
149 | activations=activations, vector_gate=True)
150 | for _ in range(5))
151 |
152 | ns, _ = _DEFAULT_V_DIM
153 | self.W_out = nn.Sequential(
154 | LayerNorm(_DEFAULT_V_DIM),
155 | GVP(_DEFAULT_V_DIM, (ns, 0),
156 | activations=activations, vector_gate=True)
157 | )
158 |
159 | self.dense = nn.Sequential(
160 | nn.Linear(ns, 2*ns), nn.ReLU(inplace=True),
161 | nn.Dropout(p=0.1),
162 | nn.Linear(2*ns, 1)
163 | )
164 |
165 | def forward(self, batch, scatter_mean=True, dense=True):
166 | '''
167 | Forward pass which can be adjusted based on task formulation.
168 |
169 | :param batch: `torch_geometric.data.Batch` with data attributes
170 | as returned from a BaseTransform
171 | :param scatter_mean: if `True`, returns mean of final node embeddings
172 |             (for each graph), else returns embeddings separately
173 | :param dense: if `True`, applies final dense layer to reduce embedding
174 | to a single scalar; else, returns the embedding
175 | '''
176 | h_V = self.embed(batch.atoms)
177 | h_E = (batch.edge_s, batch.edge_v)
178 | h_V = self.W_v(h_V)
179 | h_E = self.W_e(h_E)
180 |
181 | batch_id = batch.batch
182 |
183 | for layer in self.layers:
184 | h_V = layer(h_V, batch.edge_index, h_E)
185 |
186 | out = self.W_out(h_V)
187 | if scatter_mean: out = torch_scatter.scatter_mean(out, batch_id, dim=0)
188 | if dense: out = self.dense(out).squeeze(-1)
189 | return out
190 |
191 | ########################################################################
192 |
193 | class SMPTransform(BaseTransform):
194 | '''
195 | Transforms dict-style entries from the ATOM3D SMP dataset
196 | to featurized graphs. Returns a `torch_geometric.data.Data`
197 | graph with attribute `label` and all structural attributes
198 | as described in BaseTransform.
199 |
200 | Includes hydrogen atoms.
201 | '''
202 | def __call__(self, elem):
203 | data = super().__call__(elem['atoms'])
204 | with torch.no_grad():
205 | data.label = torch.as_tensor(elem['labels'],
206 | device=self.device, dtype=torch.float32)
207 | return data
208 |
209 | SMPModel = BaseModel
210 |
211 | ########################################################################
212 |
213 | class PPIDataset(IterableDataset):
214 | '''
215 |     A `torch.utils.data.IterableDataset` wrapper around an
216 |     ATOM3D PPI dataset. Extracts (many) individual amino acid pairs
217 | from each structure of two interacting proteins. The returned graphs
218 |     are separate and each represents a 30 angstrom radius around the
219 | selected residue's alpha carbon.
220 |
221 | On each iteration, returns a pair of `torch_geometric.data.Data`
222 | graphs with the (same) attribute `label` which is 1 if the two
223 | amino acids interact and 0 otherwise, `ca_idx` for the node index
224 | of the alpha carbon, and all structural attributes as
225 | described in BaseTransform.
226 |
227 | Modified from
228 | https://github.com/drorlab/atom3d/blob/master/examples/ppi/gnn/data.py
229 |
230 | Excludes hydrogen atoms.
231 |
232 | :param lmdb_dataset: path to ATOM3D dataset
233 | '''
234 | def __init__(self, lmdb_dataset):
235 | self.dataset = LMDBDataset(lmdb_dataset)
236 | self.transform = BaseTransform()
237 |
238 | def __iter__(self):
239 | worker_info = torch.utils.data.get_worker_info()
240 | if worker_info is None:
241 | gen = self._dataset_generator(list(range(len(self.dataset))), shuffle=True)
242 | else:
243 | per_worker = int(math.ceil(len(self.dataset) / float(worker_info.num_workers)))
244 | worker_id = worker_info.id
245 | iter_start = worker_id * per_worker
246 | iter_end = min(iter_start + per_worker, len(self.dataset))
247 | gen = self._dataset_generator(
248 | list(range(len(self.dataset)))[iter_start:iter_end],
249 | shuffle=True)
250 | return gen
251 |
252 | def _df_to_graph(self, struct_df, chain_res, label):
253 |
254 | struct_df = struct_df[struct_df.element != 'H'].reset_index(drop=True)
255 |
256 | chain, resnum = chain_res
257 | res_df = struct_df[(struct_df.chain == chain) & (struct_df.residue == resnum)]
258 | if 'CA' not in res_df.name.tolist():
259 | return None
260 | ca_pos = res_df[res_df['name']=='CA'][['x', 'y', 'z']].astype(np.float32).to_numpy()[0]
261 |
262 | kd_tree = scipy.spatial.KDTree(struct_df[['x','y','z']].to_numpy())
263 | graph_pt_idx = kd_tree.query_ball_point(ca_pos, r=30.0, p=2.0)
264 | graph_df = struct_df.iloc[graph_pt_idx].reset_index(drop=True)
265 |
266 | ca_idx = np.where((graph_df.chain == chain) & (graph_df.residue == resnum) & (graph_df.name == 'CA'))[0]
267 | if len(ca_idx) != 1:
268 | return None
269 |
270 | data = self.transform(graph_df)
271 | data.label = label
272 |
273 | data.ca_idx = int(ca_idx)
274 | data.n_nodes = data.num_nodes
275 |
276 | return data
277 |
278 | def _dataset_generator(self, indices, shuffle=True):
279 | if shuffle: random.shuffle(indices)
280 | with torch.no_grad():
281 | for idx in indices:
282 | data = self.dataset[idx]
283 |
284 | neighbors = data['atoms_neighbors']
285 | pairs = data['atoms_pairs']
286 |
287 | for i, (ensemble_name, target_df) in enumerate(pairs.groupby(['ensemble'])):
288 | sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df)
289 | positives = neighbors[neighbors.ensemble0 == ensemble_name]
290 | negatives = nb.get_negatives(positives, bound1, bound2)
291 | negatives['label'] = 0
292 | labels = self._create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1)
293 |
294 | for index, row in labels.iterrows():
295 |
296 | label = float(row['label'])
297 | chain_res1 = row[['chain0', 'residue0']].values
298 | chain_res2 = row[['chain1', 'residue1']].values
299 | graph1 = self._df_to_graph(bound1, chain_res1, label)
300 | graph2 = self._df_to_graph(bound2, chain_res2, label)
301 | if (graph1 is None) or (graph2 is None):
302 | continue
303 | yield graph1, graph2
304 |
305 | def _create_labels(self, positives, negatives, num_pos, neg_pos_ratio):
306 | frac = min(1, num_pos / positives.shape[0])
307 | positives = positives.sample(frac=frac)
308 | n = positives.shape[0] * neg_pos_ratio
309 | n = min(negatives.shape[0], n)
310 | negatives = negatives.sample(n, random_state=0, axis=0)
311 | labels = pd.concat([positives, negatives])[['chain0', 'residue0', 'chain1', 'residue1', 'label']]
312 | return labels
313 |
314 | class PPIModel(BaseModel):
315 | '''
316 | GVP-GNN for the PPI task.
317 |
318 | Extends BaseModel to accept a tuple (batch1, batch2)
319 | of `torch_geometric.data.Batch` graphs, where each graph
320 | index in a batch is paired with the same graph index in the
321 | other batch.
322 |
323 | As noted in the manuscript, PPIModel uses the final alpha
324 | carbon embeddings instead of the graph mean embedding.
325 |
326 | Returns a single scalar for each graph pair which can be used as
327 | a logit in binary classification.
328 | '''
329 | def __init__(self, **kwargs):
330 |
331 | super().__init__(**kwargs)
332 | ns, _ = _DEFAULT_V_DIM
333 | self.dense = nn.Sequential(
334 | nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True),
335 | nn.Dropout(p=0.1),
336 | nn.Linear(4*ns, 1)
337 | )
338 |
339 | def forward(self, batch):
340 | graph1, graph2 = batch
341 | out1, out2 = map(self._gnn_forward, (graph1, graph2))
342 | out = torch.cat([out1, out2], dim=-1)
343 | out = self.dense(out)
344 | return torch.sigmoid(out).squeeze(-1)
345 |
346 | def _gnn_forward(self, graph):
347 | out = super().forward(graph, scatter_mean=False, dense=False)
348 | return out[graph.ca_idx+graph.ptr[:-1]]
349 |
350 |
351 | ########################################################################
352 |
353 | class LBATransform(BaseTransform):
354 | '''
355 | Transforms dict-style entries from the ATOM3D LBA dataset
356 | to featurized graphs. Returns a `torch_geometric.data.Data`
357 | graph with attribute `label` for the neglog-affinity
358 | and all structural attributes as described in BaseTransform.
359 |
360 | The transform combines the atomic coordinates of the pocket
361 | and ligand atoms and treats them as a single structure / graph.
362 |
363 | Includes hydrogen atoms.
364 | '''
365 | def __call__(self, elem):
366 | pocket, ligand = elem['atoms_pocket'], elem['atoms_ligand']
367 | df = pd.concat([pocket, ligand], ignore_index=True)
368 |
369 | data = super().__call__(df)
370 | with torch.no_grad():
371 | data.label = elem['scores']['neglog_aff']
372 | lig_flag = torch.zeros(df.shape[0], device=self.device, dtype=torch.bool)
373 | lig_flag[-len(ligand):] = 1
374 | data.lig_flag = lig_flag
375 | return data
376 |
377 | LBAModel = BaseModel
378 |
379 | ########################################################################
380 |
381 | class LEPTransform(BaseTransform):
382 | '''
383 | Transforms dict-style entries from the ATOM3D LEP dataset
384 | to featurized graphs. Returns a tuple (active, inactive) of
385 | `torch_geometric.data.Data` graphs with the (same) attribute
386 | `label` which is equal to 1. if the ligand activates the protein
387 | and 0. otherwise, and all structural attributes as described
388 | in BaseTransform.
389 |
390 | The transform combines the atomic coordinates of the pocket
391 | and ligand atoms and treats them as a single structure / graph.
392 |
393 | Excludes hydrogen atoms.
394 | '''
395 | def __call__(self, elem):
396 | active, inactive = elem['atoms_active'], elem['atoms_inactive']
397 | with torch.no_grad():
398 | active, inactive = map(self._to_graph, (active, inactive))
399 | active.label = inactive.label = 1. if elem['label'] == 'A' else 0.
400 | return active, inactive
401 |
402 | def _to_graph(self, df):
403 | df = df[df.element != 'H'].reset_index(drop=True)
404 | return super().__call__(df)
405 |
406 | class LEPModel(BaseModel):
407 | '''
408 | GVP-GNN for the LEP task.
409 |
410 | Extends BaseModel to accept a tuple (batch1, batch2)
411 | of `torch_geometric.data.Batch` graphs, where each graph
412 | index in a batch is paired with the same graph index in the
413 | other batch.
414 |
415 | Returns a single scalar for each graph pair which can be used as
416 | a logit in binary classification.
417 | '''
418 | def __init__(self, **kwargs):
419 | super().__init__(**kwargs)
420 | ns, _ = _DEFAULT_V_DIM
421 | self.dense = nn.Sequential(
422 | nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True),
423 | nn.Dropout(p=0.1),
424 | nn.Linear(4*ns, 1)
425 | )
426 |
427 | def forward(self, batch):
428 | out1, out2 = map(self._gnn_forward, batch)
429 | out = torch.cat([out1, out2], dim=-1)
430 | out = self.dense(out)
431 | return torch.sigmoid(out).squeeze(-1)
432 |
433 | def _gnn_forward(self, graph):
434 | return super().forward(graph, dense=False)
435 |
436 | ########################################################################
437 |
438 | class MSPTransform(BaseTransform):
439 | '''
440 | Transforms dict-style entries from the ATOM3D MSP dataset
441 | to featurized graphs. Returns a tuple (original, mutated) of
442 | `torch_geometric.data.Data` graphs with the (same) attribute
443 | `label` which is equal to 1. if the mutation stabilizes the
444 | complex and 0. otherwise, and all structural attributes as
445 | described in BaseTransform.
446 |
447 |     The transform combines the atomic coordinates of the two proteins
448 | in each complex and treats them as a single structure / graph.
449 |
450 | Adapted from
451 | https://github.com/drorlab/atom3d/blob/master/examples/msp/gnn/data.py
452 |
453 | Excludes hydrogen atoms.
454 | '''
455 | def __call__(self, elem):
456 | mutation = elem['id'].split('_')[-1]
457 | orig_df = elem['original_atoms'].reset_index(drop=True)
458 | mut_df = elem['mutated_atoms'].reset_index(drop=True)
459 | with torch.no_grad():
460 | original, mutated = self._transform(orig_df, mutation), \
461 | self._transform(mut_df, mutation)
462 | original.label = mutated.label = 1. if elem['label'] == '1' else 0.
463 | return original, mutated
464 |
465 | def _transform(self, df, mutation):
466 |
467 | df = df[df.element != 'H'].reset_index(drop=True)
468 | data = super().__call__(df)
469 | data.node_mask = self._extract_node_mask(df, mutation)
470 | return data
471 |
472 | def _extract_node_mask(self, df, mutation):
473 | chain, res = mutation[1], int(mutation[2:-1])
474 | idx = df.index[(df.chain.values == chain) & (df.residue.values == res)].values
475 | mask = torch.zeros(len(df), dtype=torch.long, device=self.device)
476 | mask[idx] = 1
477 | return mask
478 |
479 | class MSPModel(BaseModel):
480 | '''
481 | GVP-GNN for the MSP task.
482 |
483 | Extends BaseModel to accept a tuple (batch1, batch2)
484 | of `torch_geometric.data.Batch` graphs, where each graph
485 | index in a batch is paired with the same graph index in the
486 | other batch.
487 |
488 | As noted in the manuscript, MSPModel uses the final embeddings
489 | averaged over the residue of interest instead of the entire graph.
490 |
491 | Returns a single scalar for each graph pair which can be used as
492 | a logit in binary classification.
493 | '''
494 | def __init__(self, **kwargs):
495 | super().__init__(**kwargs)
496 | ns, _ = _DEFAULT_V_DIM
497 | self.dense = nn.Sequential(
498 | nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True),
499 | nn.Dropout(p=0.1),
500 | nn.Linear(4*ns, 1)
501 | )
502 |
503 | def forward(self, batch):
504 | out1, out2 = map(self._gnn_forward, batch)
505 | out = torch.cat([out1, out2], dim=-1)
506 | out = self.dense(out)
507 | return torch.sigmoid(out).squeeze(-1)
508 |
509 | def _gnn_forward(self, graph):
510 | out = super().forward(graph, scatter_mean=False, dense=False)
511 | out = out * graph.node_mask.unsqueeze(-1)
512 | out = torch_scatter.scatter_add(out, graph.batch, dim=0)
513 | count = torch_scatter.scatter_add(graph.node_mask, graph.batch)
514 | return out / count.unsqueeze(-1)
515 |
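# The masked mean in MSPModel._gnn_forward can be illustrated in isolation
# (a minimal sketch, independent of the model):
#
# >>> out = torch.tensor([[1.], [2.], [3.], [4.]])  # per-node embeddings
# >>> node_mask = torch.tensor([1, 0, 1, 1])        # residue of interest
# >>> batch = torch.tensor([0, 0, 1, 1])            # graph assignment
# >>> summed = torch_scatter.scatter_add(out * node_mask.unsqueeze(-1), batch, dim=0)
# >>> count = torch_scatter.scatter_add(node_mask, batch)
# >>> (summed / count.unsqueeze(-1)).squeeze(-1)
# tensor([1.0000, 3.5000])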
516 | ########################################################################
517 |
518 | class PSRTransform(BaseTransform):
519 | '''
520 | Transforms dict-style entries from the ATOM3D PSR dataset
521 | to featurized graphs. Returns a `torch_geometric.data.Data`
522 | graph with attribute `label` for the GDT_TS, `id` for the
523 | name of the target, and all structural attributes as
524 | described in BaseTransform.
525 |
526 |     Excludes hydrogen atoms.
527 | '''
528 | def __call__(self, elem):
529 | df = elem['atoms']
530 | df = df[df.element != 'H'].reset_index(drop=True)
531 | data = super().__call__(df)
532 | data.label = elem['scores']['gdt_ts']
533 | data.id = eval(elem['id'])[0]
534 | return data
535 |
536 | PSRModel = BaseModel
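
# Usage sketch (a minimal example; the path follows run_atom3d.py and assumes
# a local copy of the ATOM3D PSR LMDB data):
#
# >>> from atom3d.datasets import LMDBDataset
# >>> trainset = LMDBDataset('atom3d-data/PSR/splits/split-by-year/data/train',
# ...                        transform=PSRTransform())
# >>> graph = trainset[0]  # torch_geometric.data.Data with .label and .id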
537 |
538 | ########################################################################
539 |
540 | class RSRTransform(BaseTransform):
541 | '''
542 | Transforms dict-style entries from the ATOM3D RSR dataset
543 | to featurized graphs. Returns a `torch_geometric.data.Data`
544 | graph with attribute `label` for the RMSD, `id` for the
545 | name of the target, and all structural attributes as
546 | described in BaseTransform.
547 |
548 |     Excludes hydrogen atoms.
549 | '''
550 | def __call__(self, elem):
551 | df = elem['atoms']
552 | df = df[df.element != 'H'].reset_index(drop=True)
553 | data = super().__call__(df)
554 | data.label = elem['scores']['rms']
555 | data.id = eval(elem['id'])[0]
556 | return data
557 |
558 | RSRModel = BaseModel
559 |
560 | ########################################################################
561 |
562 | class RESDataset(IterableDataset):
563 | '''
564 |     A `torch.utils.data.IterableDataset` wrapper around an
565 | ATOM3D RES dataset.
566 |
567 | On each iteration, returns a `torch_geometric.data.Data`
568 | graph with the attribute `label` encoding the masked residue
569 | identity, `ca_idx` for the node index of the alpha carbon,
570 | and all structural attributes as described in BaseTransform.
571 |
572 | Excludes hydrogen atoms.
573 |
574 | :param lmdb_dataset: path to ATOM3D dataset
575 | :param split_path: path to the ATOM3D split file
576 | '''
577 | def __init__(self, lmdb_dataset, split_path):
578 | self.dataset = LMDBDataset(lmdb_dataset)
579 | self.idx = list(map(int, open(split_path).read().split()))
580 | self.transform = BaseTransform()
581 |
582 | def __iter__(self):
583 | worker_info = torch.utils.data.get_worker_info()
584 | if worker_info is None:
585 | gen = self._dataset_generator(list(range(len(self.idx))),
586 | shuffle=True)
587 | else:
588 | per_worker = int(math.ceil(len(self.idx) / float(worker_info.num_workers)))
589 | worker_id = worker_info.id
590 | iter_start = worker_id * per_worker
591 | iter_end = min(iter_start + per_worker, len(self.idx))
592 | gen = self._dataset_generator(list(range(len(self.idx)))[iter_start:iter_end],
593 | shuffle=True)
594 | return gen
595 |
596 | def _dataset_generator(self, indices, shuffle=True):
597 | if shuffle: random.shuffle(indices)
598 | with torch.no_grad():
599 | for idx in indices:
600 | data = self.dataset[self.idx[idx]]
601 | atoms = data['atoms']
602 | for sub in data['labels'].itertuples():
603 | _, num, aa = sub.subunit.split('_')
604 | num, aa = int(num), _amino_acids(aa)
605 | if aa == 20: continue
606 | my_atoms = atoms.iloc[data['subunit_indices'][sub.Index]].reset_index(drop=True)
607 | ca_idx = np.where((my_atoms.residue == num) & (my_atoms.name == 'CA'))[0]
608 | if len(ca_idx) != 1: continue
609 |
610 | with torch.no_grad():
611 | graph = self.transform(my_atoms)
612 | graph.label = aa
613 | graph.ca_idx = int(ca_idx)
614 | yield graph
615 |
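# Usage sketch: because __iter__ shards self.idx across workers, RESDataset
# can be wrapped in a multi-worker loader without duplicating examples
# (paths follow run_atom3d.py and are illustrative):
#
# >>> dataset = RESDataset('atom3d-data/RES/raw/RES/data/',
# ...     split_path='atom3d-data/RES/splits/split-by-cath-topology/indices/train_indices.txt')
# >>> loader = torch_geometric.data.DataLoader(dataset, batch_size=8, num_workers=4)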
616 | class RESModel(BaseModel):
617 | '''
618 | GVP-GNN for the RES task.
619 |
620 | Extends BaseModel to output a 20-dim vector instead of a single
621 | scalar for each graph, which can be used as logits in 20-way
622 | classification.
623 |
624 | As noted in the manuscript, RESModel uses the final alpha
625 | carbon embeddings instead of the graph mean embedding.
626 | '''
627 | def __init__(self, **kwargs):
628 | super().__init__(**kwargs)
629 | ns, _ = _DEFAULT_V_DIM
630 | self.dense = nn.Sequential(
631 | nn.Linear(ns, 2*ns), nn.ReLU(inplace=True),
632 | nn.Dropout(p=0.1),
633 | nn.Linear(2*ns, 20)
634 | )
635 | def forward(self, batch):
636 | out = super().forward(batch, scatter_mean=False)
637 |         return out[batch.ca_idx + batch.ptr[:-1]]
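
# Indexing sketch for RESModel.forward: for a batch of two graphs with 5 and
# 7 nodes, batch.ptr == tensor([0, 5, 12]). If the alpha carbons sit at local
# indices batch.ca_idx == tensor([2, 4]), then batch.ca_idx + batch.ptr[:-1]
# == tensor([2, 9]), selecting each graph's alpha-carbon row from the
# flattened per-node output.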
--------------------------------------------------------------------------------
/gvp/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import tqdm, random
4 | import torch, math
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 | import torch_geometric
8 | import torch_cluster
9 |
10 | def _normalize(tensor, dim=-1):
11 | '''
12 | Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
13 | '''
14 | return torch.nan_to_num(
15 | torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
16 |
17 |
18 | def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
19 | '''
20 | From https://github.com/jingraham/neurips19-graph-protein-design
21 |
22 | Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
23 | That is, if `D` has shape [...dims], then the returned tensor will have
24 | shape [...dims, D_count].
25 | '''
26 | D_mu = torch.linspace(D_min, D_max, D_count, device=device)
27 | D_mu = D_mu.view([1, -1])
28 | D_sigma = (D_max - D_min) / D_count
29 | D_expand = torch.unsqueeze(D, -1)
30 |
31 | RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
32 | return RBF
33 |
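# Shape sketch (with the defaults above): distances of shape [n_edges] embed
# to [n_edges, 16]:
#
# >>> D = torch.rand(100) * 20
# >>> _rbf(D).shape
# torch.Size([100, 16])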
34 |
35 | class CATHDataset:
36 | '''
37 | Loader and container class for the CATH 4.2 dataset downloaded
38 | from http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/.
39 |
40 | Has attributes `self.train`, `self.val`, `self.test`, each of which are
41 | JSON/dictionary-type datasets as described in README.md.
42 |
43 | :param path: path to chain_set.jsonl
44 | :param splits_path: path to chain_set_splits.json or equivalent.
45 | '''
46 | def __init__(self, path, splits_path):
47 | with open(splits_path) as f:
48 | dataset_splits = json.load(f)
49 | train_list, val_list, test_list = dataset_splits['train'], \
50 | dataset_splits['validation'], dataset_splits['test']
51 |
52 | self.train, self.val, self.test = [], [], []
53 |
54 | with open(path) as f:
55 | lines = f.readlines()
56 |
57 | for line in tqdm.tqdm(lines):
58 | entry = json.loads(line)
59 | name = entry['name']
60 | coords = entry['coords']
61 |
62 | entry['coords'] = list(zip(
63 | coords['N'], coords['CA'], coords['C'], coords['O']
64 | ))
65 |
66 | if name in train_list:
67 | self.train.append(entry)
68 | elif name in val_list:
69 | self.val.append(entry)
70 | elif name in test_list:
71 | self.test.append(entry)
72 |
73 | class BatchSampler(data.Sampler):
74 | '''
75 | From https://github.com/jingraham/neurips19-graph-protein-design.
76 |
77 | A `torch.utils.data.Sampler` which samples batches according to a
78 | maximum number of graph nodes.
79 |
80 | :param node_counts: array of node counts in the dataset to sample from
81 | :param max_nodes: the maximum number of nodes in any batch,
82 | including batches of a single element
83 |     :param shuffle: if `True`, yield the batches in shuffled order
84 | '''
85 | def __init__(self, node_counts, max_nodes=3000, shuffle=True):
86 |
87 | self.node_counts = node_counts
88 | self.idx = [i for i in range(len(node_counts))
89 | if node_counts[i] <= max_nodes]
90 | self.shuffle = shuffle
91 | self.max_nodes = max_nodes
92 | self._form_batches()
93 |
94 | def _form_batches(self):
95 | self.batches = []
96 | if self.shuffle: random.shuffle(self.idx)
97 | idx = self.idx
98 | while idx:
99 | batch = []
100 | n_nodes = 0
101 | while idx and n_nodes + self.node_counts[idx[0]] <= self.max_nodes:
102 | next_idx, idx = idx[0], idx[1:]
103 | n_nodes += self.node_counts[next_idx]
104 | batch.append(next_idx)
105 | self.batches.append(batch)
106 |
107 | def __len__(self):
108 | if not self.batches: self._form_batches()
109 | return len(self.batches)
110 |
111 | def __iter__(self):
112 | if not self.batches: self._form_batches()
113 | for batch in self.batches: yield batch
114 |
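# Usage sketch (mirrors run_cpd.py): BatchSampler yields lists of dataset
# indices, so it plugs into a DataLoader via `batch_sampler`:
#
# >>> sampler = BatchSampler(dataset.node_counts, max_nodes=3000)
# >>> loader = torch_geometric.data.DataLoader(dataset, batch_sampler=sampler)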
115 | class ProteinGraphDataset(data.Dataset):
116 | '''
117 |     A map-style `torch.utils.data.Dataset` which transforms JSON/dictionary-style
118 | protein structures into featurized protein graphs as described in the
119 | manuscript.
120 |
121 | Returned graphs are of type `torch_geometric.data.Data` with attributes
122 | -x alpha carbon coordinates, shape [n_nodes, 3]
123 | -seq sequence converted to int tensor according to `self.letter_to_num`, shape [n_nodes]
124 | -name name of the protein structure, string
125 | -node_s node scalar features, shape [n_nodes, 6]
126 | -node_v node vector features, shape [n_nodes, 3, 3]
127 | -edge_s edge scalar features, shape [n_edges, 32]
128 |     -edge_v         edge vector features, shape [n_edges, 1, 3]
129 | -edge_index edge indices, shape [2, n_edges]
130 | -mask node mask, `False` for nodes with missing data that are excluded from message passing
131 |
132 | Portions from https://github.com/jingraham/neurips19-graph-protein-design.
133 |
134 | :param data_list: JSON/dictionary-style protein dataset as described in README.md.
135 | :param num_positional_embeddings: number of positional embeddings
136 | :param top_k: number of edges to draw per node (as destination node)
137 | :param device: if "cuda", will do preprocessing on the GPU
138 | '''
139 | def __init__(self, data_list,
140 | num_positional_embeddings=16,
141 | top_k=30, num_rbf=16, device="cpu"):
142 |
143 | super(ProteinGraphDataset, self).__init__()
144 |
145 | self.data_list = data_list
146 | self.top_k = top_k
147 | self.num_rbf = num_rbf
148 | self.num_positional_embeddings = num_positional_embeddings
149 | self.device = device
150 | self.node_counts = [len(e['seq']) for e in data_list]
151 |
152 | self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9,
153 | 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8,
154 | 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19,
155 | 'N': 2, 'Y': 18, 'M': 12}
156 | self.num_to_letter = {v:k for k, v in self.letter_to_num.items()}
157 |
158 | def __len__(self): return len(self.data_list)
159 |
160 | def __getitem__(self, i): return self._featurize_as_graph(self.data_list[i])
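
    # Example sketch (assuming `cath` is a loaded CATHDataset):
    #
    # >>> dataset = ProteinGraphDataset(cath.train)
    # >>> graph = dataset[0]
    # >>> graph.node_s.shape, graph.node_v.shape   # for a 121-residue chain
    # (torch.Size([121, 6]), torch.Size([121, 3, 3]))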
161 |
162 | def _featurize_as_graph(self, protein):
163 | name = protein['name']
164 | with torch.no_grad():
165 | coords = torch.as_tensor(protein['coords'],
166 | device=self.device, dtype=torch.float32)
167 | seq = torch.as_tensor([self.letter_to_num[a] for a in protein['seq']],
168 | device=self.device, dtype=torch.long)
169 |
170 | mask = torch.isfinite(coords.sum(dim=(1,2)))
171 | coords[~mask] = np.inf
172 |
173 | X_ca = coords[:, 1]
174 | edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k)
175 |
176 | pos_embeddings = self._positional_embeddings(edge_index)
177 | E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
178 | rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf, device=self.device)
179 |
180 | dihedrals = self._dihedrals(coords)
181 | orientations = self._orientations(X_ca)
182 | sidechains = self._sidechains(coords)
183 |
184 | node_s = dihedrals
185 | node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
186 | edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
187 | edge_v = _normalize(E_vectors).unsqueeze(-2)
188 |
189 | node_s, node_v, edge_s, edge_v = map(torch.nan_to_num,
190 | (node_s, node_v, edge_s, edge_v))
191 |
192 | data = torch_geometric.data.Data(x=X_ca, seq=seq, name=name,
193 | node_s=node_s, node_v=node_v,
194 | edge_s=edge_s, edge_v=edge_v,
195 | edge_index=edge_index, mask=mask)
196 | return data
197 |
198 | def _dihedrals(self, X, eps=1e-7):
199 | # From https://github.com/jingraham/neurips19-graph-protein-design
200 |
201 | X = torch.reshape(X[:, :3], [3*X.shape[0], 3])
202 | dX = X[1:] - X[:-1]
203 | U = _normalize(dX, dim=-1)
204 | u_2 = U[:-2]
205 | u_1 = U[1:-1]
206 | u_0 = U[2:]
207 |
208 | # Backbone normals
209 | n_2 = _normalize(torch.cross(u_2, u_1), dim=-1)
210 | n_1 = _normalize(torch.cross(u_1, u_0), dim=-1)
211 |
212 | # Angle between normals
213 | cosD = torch.sum(n_2 * n_1, -1)
214 | cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
215 | D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
216 |
217 | # This scheme will remove phi[0], psi[-1], omega[-1]
218 | D = F.pad(D, [1, 2])
219 | D = torch.reshape(D, [-1, 3])
220 | # Lift angle representations to the circle
221 | D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
222 | return D_features
223 |
224 |
225 | def _positional_embeddings(self, edge_index,
226 | num_embeddings=None,
227 | period_range=[2, 1000]):
228 | # From https://github.com/jingraham/neurips19-graph-protein-design
229 | num_embeddings = num_embeddings or self.num_positional_embeddings
230 | d = edge_index[0] - edge_index[1]
231 |
232 | frequency = torch.exp(
233 | torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=self.device)
234 | * -(np.log(10000.0) / num_embeddings)
235 | )
236 | angles = d.unsqueeze(-1) * frequency
237 | E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
238 | return E
239 |
240 | def _orientations(self, X):
241 | forward = _normalize(X[1:] - X[:-1])
242 | backward = _normalize(X[:-1] - X[1:])
243 | forward = F.pad(forward, [0, 0, 0, 1])
244 | backward = F.pad(backward, [0, 0, 1, 0])
245 | return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)
246 |
247 | def _sidechains(self, X):
248 | n, origin, c = X[:, 0], X[:, 1], X[:, 2]
249 | c, n = _normalize(c - origin), _normalize(n - origin)
250 | bisector = _normalize(c + n)
251 | perp = _normalize(torch.cross(c, n))
252 | vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
253 | return vec
--------------------------------------------------------------------------------
/gvp/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from . import GVP, GVPConvLayer, LayerNorm, tuple_index
5 | from torch.distributions import Categorical
6 | from torch_scatter import scatter_mean
7 |
8 | class CPDModel(torch.nn.Module):
9 | '''
10 | GVP-GNN for structure-conditioned autoregressive
11 | protein design as described in manuscript.
12 |
13 | Takes in protein structure graphs of type `torch_geometric.data.Data`
14 | or `torch_geometric.data.Batch` and returns a categorical distribution
15 | over 20 amino acids at each position in a `torch.Tensor` of
16 | shape [n_nodes, 20].
17 |
18 | Should be used with `gvp.data.ProteinGraphDataset`, or with generators
19 | of `torch_geometric.data.Batch` objects with the same attributes.
20 |
21 | The standard forward pass requires sequence information as input
22 | and should be used for training or evaluating likelihood.
23 | For sampling or design, use `self.sample`.
24 |
25 | :param node_in_dim: node dimensions in input graph, should be
26 | (6, 3) if using original features
27 | :param node_h_dim: node dimensions to use in GVP-GNN layers
28 |     :param edge_in_dim: edge dimensions in input graph, should be
29 | (32, 1) if using original features
30 | :param edge_h_dim: edge dimensions to embed to before use
31 | in GVP-GNN layers
32 | :param num_layers: number of GVP-GNN layers in each of the encoder
33 | and decoder modules
34 | :param drop_rate: rate to use in all dropout layers
35 | '''
36 | def __init__(self, node_in_dim, node_h_dim,
37 | edge_in_dim, edge_h_dim,
38 | num_layers=3, drop_rate=0.1):
39 |
40 | super(CPDModel, self).__init__()
41 |
42 | self.W_v = nn.Sequential(
43 | GVP(node_in_dim, node_h_dim, activations=(None, None)),
44 | LayerNorm(node_h_dim)
45 | )
46 | self.W_e = nn.Sequential(
47 | GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
48 | LayerNorm(edge_h_dim)
49 | )
50 |
51 | self.encoder_layers = nn.ModuleList(
52 | GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
53 | for _ in range(num_layers))
54 |
55 | self.W_s = nn.Embedding(20, 20)
56 | edge_h_dim = (edge_h_dim[0] + 20, edge_h_dim[1])
57 |
58 | self.decoder_layers = nn.ModuleList(
59 | GVPConvLayer(node_h_dim, edge_h_dim,
60 | drop_rate=drop_rate, autoregressive=True)
61 | for _ in range(num_layers))
62 |
63 | self.W_out = GVP(node_h_dim, (20, 0), activations=(None, None))
64 |
65 | def forward(self, h_V, edge_index, h_E, seq):
66 | '''
67 | Forward pass to be used at train-time, or evaluating likelihood.
68 |
69 | :param h_V: tuple (s, V) of node embeddings
70 | :param edge_index: `torch.Tensor` of shape [2, num_edges]
71 | :param h_E: tuple (s, V) of edge embeddings
72 | :param seq: int `torch.Tensor` of shape [num_nodes]
73 | '''
74 | h_V = self.W_v(h_V)
75 | h_E = self.W_e(h_E)
76 |
77 | for layer in self.encoder_layers:
78 | h_V = layer(h_V, edge_index, h_E)
79 |
80 | encoder_embeddings = h_V
81 |
82 | h_S = self.W_s(seq)
83 | h_S = h_S[edge_index[0]]
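        # Zero out sequence embeddings on edges whose source position is not
        # strictly earlier than the destination, so the decoder conditions
        # each residue only on already-generated (preceding) positions.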
84 | h_S[edge_index[0] >= edge_index[1]] = 0
85 | h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])
86 |
87 | for layer in self.decoder_layers:
88 | h_V = layer(h_V, edge_index, h_E, autoregressive_x = encoder_embeddings)
89 |
90 | logits = self.W_out(h_V)
91 |
92 | return logits
93 |
94 | def sample(self, h_V, edge_index, h_E, n_samples, temperature=0.1):
95 | '''
96 | Samples sequences autoregressively from the distribution
97 | learned by the model.
98 |
99 | :param h_V: tuple (s, V) of node embeddings
100 | :param edge_index: `torch.Tensor` of shape [2, num_edges]
101 | :param h_E: tuple (s, V) of edge embeddings
102 | :param n_samples: number of samples
103 | :param temperature: temperature to use in softmax
104 | over the categorical distribution
105 |
106 | :return: int `torch.Tensor` of shape [n_samples, n_nodes] based on the
107 | residue-to-int mapping of the original training data
108 | '''
109 |
110 | with torch.no_grad():
111 |
112 | device = edge_index.device
113 | L = h_V[0].shape[0]
114 |
115 | h_V = self.W_v(h_V)
116 | h_E = self.W_e(h_E)
117 |
118 | for layer in self.encoder_layers:
119 | h_V = layer(h_V, edge_index, h_E)
120 |
121 | h_V = (h_V[0].repeat(n_samples, 1),
122 | h_V[1].repeat(n_samples, 1, 1))
123 |
124 | h_E = (h_E[0].repeat(n_samples, 1),
125 | h_E[1].repeat(n_samples, 1, 1))
126 |
127 | edge_index = edge_index.expand(n_samples, -1, -1)
128 | offset = L * torch.arange(n_samples, device=device).view(-1, 1, 1)
129 | edge_index = torch.cat(tuple(edge_index + offset), dim=-1)
130 |
131 | seq = torch.zeros(n_samples * L, device=device, dtype=torch.int)
132 | h_S = torch.zeros(n_samples * L, 20, device=device)
133 |
134 | h_V_cache = [(h_V[0].clone(), h_V[1].clone()) for _ in self.decoder_layers]
135 |
136 | for i in range(L):
137 |
138 | h_S_ = h_S[edge_index[0]]
139 | h_S_[edge_index[0] >= edge_index[1]] = 0
140 | h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1])
141 |
142 | edge_mask = edge_index[1] % L == i
143 | edge_index_ = edge_index[:, edge_mask]
144 | h_E_ = tuple_index(h_E_, edge_mask)
145 | node_mask = torch.zeros(n_samples * L, device=device, dtype=torch.bool)
146 | node_mask[i::L] = True
147 |
148 | for j, layer in enumerate(self.decoder_layers):
149 | out = layer(h_V_cache[j], edge_index_, h_E_,
150 | autoregressive_x=h_V_cache[0], node_mask=node_mask)
151 |
152 | out = tuple_index(out, node_mask)
153 |
154 | if j < len(self.decoder_layers)-1:
155 | h_V_cache[j+1][0][i::L] = out[0]
156 | h_V_cache[j+1][1][i::L] = out[1]
157 |
158 | logits = self.W_out(out)
159 | seq[i::L] = Categorical(logits=logits / temperature).sample()
160 | h_S[i::L] = self.W_s(seq[i::L])
161 |
162 | return seq.view(n_samples, L)
163 |
164 | class MQAModel(nn.Module):
165 | '''
166 | GVP-GNN for Model Quality Assessment as described in manuscript.
167 |
168 | Takes in protein structure graphs of type `torch_geometric.data.Data`
169 | or `torch_geometric.data.Batch` and returns a scalar score for
170 |     each graph in the batch in a `torch.Tensor` of shape [n_graphs].
171 |
172 | Should be used with `gvp.data.ProteinGraphDataset`, or with generators
173 | of `torch_geometric.data.Batch` objects with the same attributes.
174 |
175 | :param node_in_dim: node dimensions in input graph, should be
176 | (6, 3) if using original features
177 | :param node_h_dim: node dimensions to use in GVP-GNN layers
178 |     :param edge_in_dim: edge dimensions in input graph, should be
179 | (32, 1) if using original features
180 | :param edge_h_dim: edge dimensions to embed to before use
181 | in GVP-GNN layers
182 |     :param seq_in: if `True`, sequences will also be passed in with
183 | the forward pass; otherwise, sequence information
184 | is assumed to be part of input node embeddings
185 | :param num_layers: number of GVP-GNN layers
186 | :param drop_rate: rate to use in all dropout layers
187 | '''
188 | def __init__(self, node_in_dim, node_h_dim,
189 | edge_in_dim, edge_h_dim,
190 | seq_in=False, num_layers=3, drop_rate=0.1):
191 |
192 | super(MQAModel, self).__init__()
193 |
194 | if seq_in:
195 | self.W_s = nn.Embedding(20, 20)
196 | node_in_dim = (node_in_dim[0] + 20, node_in_dim[1])
197 |
198 | self.W_v = nn.Sequential(
199 | LayerNorm(node_in_dim),
200 | GVP(node_in_dim, node_h_dim, activations=(None, None))
201 | )
202 | self.W_e = nn.Sequential(
203 | LayerNorm(edge_in_dim),
204 | GVP(edge_in_dim, edge_h_dim, activations=(None, None))
205 | )
206 |
207 | self.layers = nn.ModuleList(
208 | GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
209 | for _ in range(num_layers))
210 |
211 | ns, _ = node_h_dim
212 | self.W_out = nn.Sequential(
213 | LayerNorm(node_h_dim),
214 | GVP(node_h_dim, (ns, 0)))
215 |
216 | self.dense = nn.Sequential(
217 | nn.Linear(ns, 2*ns), nn.ReLU(inplace=True),
218 | nn.Dropout(p=drop_rate),
219 | nn.Linear(2*ns, 1)
220 | )
221 |
222 | def forward(self, h_V, edge_index, h_E, seq=None, batch=None):
223 | '''
224 | :param h_V: tuple (s, V) of node embeddings
225 | :param edge_index: `torch.Tensor` of shape [2, num_edges]
226 | :param h_E: tuple (s, V) of edge embeddings
227 | :param seq: if not `None`, int `torch.Tensor` of shape [num_nodes]
228 | to be embedded and appended to `h_V`
229 | '''
230 | if seq is not None:
231 | seq = self.W_s(seq)
232 | h_V = (torch.cat([h_V[0], seq], dim=-1), h_V[1])
233 | h_V = self.W_v(h_V)
234 | h_E = self.W_e(h_E)
235 | for layer in self.layers:
236 | h_V = layer(h_V, edge_index, h_E)
237 | out = self.W_out(h_V)
238 |
239 |         if batch is None: out = out.mean(dim=0, keepdim=True)
240 | else: out = scatter_mean(out, batch, dim=0)
241 |
242 | return self.dense(out).squeeze(-1) + 0.5
--------------------------------------------------------------------------------
/run_atom3d.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('task', metavar='TASK', choices=[
5 | 'PSR', 'RSR', 'PPI', 'RES', 'MSP', 'SMP', 'LBA', 'LEP'
6 | ], help="{PSR, RSR, PPI, RES, MSP, SMP, LBA, LEP}")
7 | parser.add_argument('--num-workers', metavar='N', type=int, default=4,
8 | help='number of threads for loading data, default=4')
9 | parser.add_argument('--smp-idx', metavar='IDX', type=int, default=None,
10 | choices=list(range(20)),
11 | help='label index for SMP, in range 0-19')
12 | parser.add_argument('--lba-split', metavar='SPLIT', type=int, choices=[30, 60],
13 | help='identity cutoff for LBA, 30 (default) or 60', default=30)
14 | parser.add_argument('--batch', metavar='SIZE', type=int, default=8,
15 | help='batch size, default=8')
16 | parser.add_argument('--train-time', metavar='MINUTES', type=int, default=120,
17 | help='maximum time between evaluations on valset, default=120 minutes')
18 | parser.add_argument('--val-time', metavar='MINUTES', type=int, default=20,
19 | help='maximum time per evaluation on valset, default=20 minutes')
20 | parser.add_argument('--epochs', metavar='N', type=int, default=50,
21 | help='training epochs, default=50')
22 | parser.add_argument('--test', metavar='PATH', default=None,
23 | help='evaluate a trained model')
24 | parser.add_argument('--lr', metavar='RATE', default=1e-4, type=float,
25 | help='learning rate')
26 | parser.add_argument('--load', metavar='PATH', default=None,
27 | help='initialize first 2 GNN layers with pretrained weights')
28 |
29 | args = parser.parse_args()
30 |
31 | import gvp
32 | from atom3d.datasets import LMDBDataset
33 | import torch_geometric
34 | from functools import partial
35 | import gvp.atom3d
36 | import torch.nn as nn
37 | import tqdm, torch, time, os
38 | import numpy as np
39 | from atom3d.util import metrics
40 | import sklearn.metrics as sk_metrics
41 | from collections import defaultdict
42 | import scipy.stats as stats
43 | print = partial(print, flush=True)
44 |
45 | models_dir = 'models'
46 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
47 | model_id = float(time.time())
48 |
49 | def main():
50 | datasets = get_datasets(args.task, args.lba_split)
51 | dataloader = partial(torch_geometric.data.DataLoader,
52 | num_workers=args.num_workers, batch_size=args.batch)
53 | if args.task not in ['PPI', 'RES']:
54 | dataloader = partial(dataloader, shuffle=True)
55 |
56 | trainset, valset, testset = map(dataloader, datasets)
57 | model = get_model(args.task).to(device)
58 |
59 | if args.test:
60 | test(model, testset)
61 |
62 | else:
63 | if args.load:
64 | load(model, args.load)
65 | train(model, trainset, valset)
66 |
67 | def test(model, testset):
68 | model.load_state_dict(torch.load(args.test))
69 | model.eval()
70 | t = tqdm.tqdm(testset)
71 | metrics = get_metrics(args.task)
72 | targets, predicts, ids = [], [], []
73 | with torch.no_grad():
74 | for batch in t:
75 | pred = forward(model, batch, device)
76 | label = get_label(batch, args.task, args.smp_idx)
77 | if args.task == 'RES':
78 | pred = pred.argmax(dim=-1)
79 | if args.task in ['PSR', 'RSR']:
80 | ids.extend(batch.id)
81 | targets.extend(list(label.cpu().numpy()))
82 | predicts.extend(list(pred.cpu().numpy()))
83 |
84 | for name, func in metrics.items():
85 | if args.task in ['PSR', 'RSR']:
86 | func = partial(func, ids=ids)
87 | value = func(targets, predicts)
88 | print(f"{name}: {value}")
89 |
90 | def train(model, trainset, valset):
91 |
92 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
93 |
94 | best_path, best_val = None, np.inf
95 |
96 | for epoch in range(args.epochs):
97 | model.train()
98 | loss = loop(trainset, model, optimizer=optimizer, max_time=args.train_time)
99 | path = f"{models_dir}/{args.task}_{model_id}_{epoch}.pt"
100 | torch.save(model.state_dict(), path)
101 | print(f'\nEPOCH {epoch} TRAIN loss: {loss:.8f}')
102 | model.eval()
103 | with torch.no_grad():
104 | loss = loop(valset, model, max_time=args.val_time)
105 | print(f'\nEPOCH {epoch} VAL loss: {loss:.8f}')
106 | if loss < best_val:
107 | best_path, best_val = path, loss
108 | print(f'BEST {best_path} VAL loss: {best_val:.8f}')
109 |
110 | def loop(dataset, model, optimizer=None, max_time=None):
111 | start = time.time()
112 |
113 | loss_fn = get_loss(args.task)
114 | t = tqdm.tqdm(dataset)
115 | total_loss, total_count = 0, 0
116 |
117 | for batch in t:
118 | if max_time and (time.time() - start) > 60*max_time: break
119 | if optimizer: optimizer.zero_grad()
120 | try:
121 | out = forward(model, batch, device)
122 | except RuntimeError as e:
123 | if "CUDA out of memory" not in str(e): raise(e)
124 | torch.cuda.empty_cache()
125 | print('Skipped batch due to OOM', flush=True)
126 | continue
127 |
128 | label = get_label(batch, args.task, args.smp_idx)
129 | loss_value = loss_fn(out, label)
130 | total_loss += float(loss_value)
131 | total_count += 1
132 |
133 | if optimizer:
134 | try:
135 | loss_value.backward()
136 | optimizer.step()
137 | except RuntimeError as e:
138 | if "CUDA out of memory" not in str(e): raise(e)
139 | torch.cuda.empty_cache()
140 | print('Skipped batch due to OOM', flush=True)
141 | continue
142 |
143 | t.set_description(f"{total_loss/total_count:.8f}")
144 |
145 | return total_loss / total_count
146 |
147 | def load(model, path):
148 | params = torch.load(path)
149 | state_dict = model.state_dict()
150 | for name, p in params.items():
151 | if name in state_dict and \
152 | name[:8] in ['layers.0', 'layers.1'] and \
153 | state_dict[name].shape == p.shape:
154 | print("Loading", name)
155 |             state_dict[name].copy_(p)
156 |
157 | #######################################################################
158 |
159 | def get_label(batch, task, smp_idx=None):
160 | if type(batch) in [list, tuple]: batch = batch[0]
161 | if task == 'SMP':
162 | assert smp_idx is not None
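        # SMP labels appear to be stored flattened, 20 targets per graph;
        # take every 20th entry starting at the requested index.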
163 | return batch.label[smp_idx::20]
164 | return batch.label
165 |
166 | def get_metrics(task):
167 | def _correlation(metric, targets, predict, ids=None, glob=True):
168 | if glob: return metric(targets, predict)
169 | _targets, _predict = defaultdict(list), defaultdict(list)
170 | for _t, _p, _id in zip(targets, predict, ids):
171 | _targets[_id].append(_t)
172 | _predict[_id].append(_p)
173 | return np.mean([metric(_targets[_id], _predict[_id]) for _id in _targets])
174 |
175 | correlations = {
176 | 'pearson': partial(_correlation, metrics.pearson),
177 | 'kendall': partial(_correlation, metrics.kendall),
178 | 'spearman': partial(_correlation, metrics.spearman)
179 | }
180 | mean_correlations = {f'mean {k}' : partial(v, glob=False) \
181 | for k, v in correlations.items()}
182 |
183 | return {
184 | 'RSR' : {**correlations, **mean_correlations},
185 | 'PSR' : {**correlations, **mean_correlations},
186 | 'PPI' : {'auroc': metrics.auroc},
187 | 'RES' : {'accuracy': metrics.accuracy},
188 | 'MSP' : {'auroc': metrics.auroc, 'auprc': metrics.auprc},
189 | 'LEP' : {'auroc': metrics.auroc, 'auprc': metrics.auprc},
190 | 'LBA' : {**correlations, 'rmse': partial(sk_metrics.mean_squared_error, squared=False)},
191 | 'SMP' : {'mae': sk_metrics.mean_absolute_error}
192 | }[task]
193 |
194 | def get_loss(task):
195 | if task in ['PSR', 'RSR', 'SMP', 'LBA']: return nn.MSELoss() # regression
196 | elif task in ['PPI', 'MSP', 'LEP']: return nn.BCELoss() # binary classification
197 | elif task in ['RES']: return nn.CrossEntropyLoss() # multiclass classification
198 |
199 | def forward(model, batch, device):
200 | if type(batch) in [list, tuple]:
201 | batch = batch[0].to(device), batch[1].to(device)
202 | else:
203 | batch = batch.to(device)
204 | return model(batch)
205 |
206 | def get_datasets(task, lba_split=30):
207 | data_path = {
208 | 'RES' : 'atom3d-data/RES/raw/RES/data/',
209 | 'PPI' : 'atom3d-data/PPI/splits/DIPS-split/data/',
210 | 'RSR' : 'atom3d-data/RSR/splits/candidates-split-by-time/data/',
211 | 'PSR' : 'atom3d-data/PSR/splits/split-by-year/data/',
212 | 'MSP' : 'atom3d-data/MSP/splits/split-by-sequence-identity-30/data/',
213 | 'LEP' : 'atom3d-data/LEP/splits/split-by-protein/data/',
214 | 'LBA' : f'atom3d-data/LBA/splits/split-by-sequence-identity-{lba_split}/data/',
215 | 'SMP' : 'atom3d-data/SMP/splits/random/data/'
216 | }[task]
217 |
218 | if task == 'RES':
219 | split_path = 'atom3d-data/RES/splits/split-by-cath-topology/indices/'
220 | dataset = partial(gvp.atom3d.RESDataset, data_path)
221 | trainset = dataset(split_path=split_path+'train_indices.txt')
222 | valset = dataset(split_path=split_path+'val_indices.txt')
223 | testset = dataset(split_path=split_path+'test_indices.txt')
224 |
225 | elif task == 'PPI':
226 | trainset = gvp.atom3d.PPIDataset(data_path+'train')
227 | valset = gvp.atom3d.PPIDataset(data_path+'val')
228 | testset = gvp.atom3d.PPIDataset(data_path+'test')
229 |
230 | else:
231 | transform = {
232 | 'RSR' : gvp.atom3d.RSRTransform,
233 | 'PSR' : gvp.atom3d.PSRTransform,
234 | 'MSP' : gvp.atom3d.MSPTransform,
235 | 'LEP' : gvp.atom3d.LEPTransform,
236 | 'LBA' : gvp.atom3d.LBATransform,
237 | 'SMP' : gvp.atom3d.SMPTransform,
238 | }[task]()
239 |
240 | trainset = LMDBDataset(data_path+'train', transform=transform)
241 | valset = LMDBDataset(data_path+'val', transform=transform)
242 | testset = LMDBDataset(data_path+'test', transform=transform)
243 |
244 | return trainset, valset, testset
245 |
246 | def get_model(task):
247 | return {
248 | 'RES' : gvp.atom3d.RESModel,
249 | 'PPI' : gvp.atom3d.PPIModel,
250 | 'RSR' : gvp.atom3d.RSRModel,
251 | 'PSR' : gvp.atom3d.PSRModel,
252 | 'MSP' : gvp.atom3d.MSPModel,
253 | 'LEP' : gvp.atom3d.LEPModel,
254 | 'LBA' : gvp.atom3d.LBAModel,
255 | 'SMP' : gvp.atom3d.SMPModel
256 | }[task]()
257 |
258 | if __name__ == "__main__":
259 | main()
--------------------------------------------------------------------------------
/run_cpd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--models-dir', metavar='PATH', default='./models/',
5 | help='directory to save trained models, default=./models/')
6 | parser.add_argument('--num-workers', metavar='N', type=int, default=4,
7 | help='number of threads for loading data, default=4')
8 | parser.add_argument('--max-nodes', metavar='N', type=int, default=3000,
9 | help='max number of nodes per batch, default=3000')
10 | parser.add_argument('--epochs', metavar='N', type=int, default=100,
11 | help='training epochs, default=100')
12 | parser.add_argument('--cath-data', metavar='PATH', default='./data/chain_set.jsonl',
13 | help='location of CATH dataset, default=./data/chain_set.jsonl')
14 | parser.add_argument('--cath-splits', metavar='PATH', default='./data/chain_set_splits.json',
15 | help='location of CATH split file, default=./data/chain_set_splits.json')
16 | parser.add_argument('--ts50', metavar='PATH', default='./data/ts50.json',
17 | help='location of TS50 dataset, default=./data/ts50.json')
18 | parser.add_argument('--train', action="store_true", help="train a model")
19 | parser.add_argument('--test-r', metavar='PATH', default=None,
20 | help='evaluate a trained model on recovery (without training)')
21 | parser.add_argument('--test-p', metavar='PATH', default=None,
22 | help='evaluate a trained model on perplexity (without training)')
23 | parser.add_argument('--n-samples', metavar='N', type=int, default=100,
24 | help='number of sequences to sample (if testing recovery), default=100')
25 |
26 | args = parser.parse_args()
27 | assert sum(map(bool, [args.train, args.test_p, args.test_r])) == 1, \
28 |     "Specify exactly one of --train, --test-r, --test-p"
29 |
30 | import torch
31 | import torch.nn as nn
32 | import gvp.data, gvp.models
33 | from datetime import datetime
34 | import tqdm, os, json
35 | import numpy as np
36 | from sklearn.metrics import confusion_matrix
37 | import torch_geometric
38 | from functools import partial
39 | print = partial(print, flush=True)
40 |
41 | node_dim = (100, 16)
42 | edge_dim = (32, 1)
43 | device = "cuda" if torch.cuda.is_available() else "cpu"
44 |
45 | if not os.path.exists(args.models_dir): os.makedirs(args.models_dir)
46 | model_id = int(datetime.timestamp(datetime.now()))
47 | dataloader = lambda x: torch_geometric.data.DataLoader(x,
48 | num_workers=args.num_workers,
49 | batch_sampler=gvp.data.BatchSampler(
50 | x.node_counts, max_nodes=args.max_nodes))
51 |
52 | def main():
53 |
54 | model = gvp.models.CPDModel((6, 3), node_dim, (32, 1), edge_dim).to(device)
55 |
56 | print("Loading CATH dataset")
57 |     cath = gvp.data.CATHDataset(path=args.cath_data,
58 |                                 splits_path=args.cath_splits)
59 |
60 | trainset, valset, testset = map(gvp.data.ProteinGraphDataset,
61 | (cath.train, cath.val, cath.test))
62 |
63 | if args.test_r or args.test_p:
64 | ts50set = gvp.data.ProteinGraphDataset(json.load(open(args.ts50)))
65 | model.load_state_dict(torch.load(args.test_r or args.test_p))
66 |
67 | if args.test_r:
68 | print("Testing on CATH testset"); test_recovery(model, testset)
69 | print("Testing on TS50 set"); test_recovery(model, ts50set)
70 |
71 | elif args.test_p:
72 | print("Testing on CATH testset"); test_perplexity(model, testset)
73 | print("Testing on TS50 set"); test_perplexity(model, ts50set)
74 |
75 | elif args.train:
76 | train(model, trainset, valset, testset)
77 |
78 |
79 | def train(model, trainset, valset, testset):
80 | train_loader, val_loader, test_loader = map(dataloader,
81 | (trainset, valset, testset))
82 | optimizer = torch.optim.Adam(model.parameters())
83 | best_path, best_val = None, np.inf
84 | lookup = train_loader.dataset.num_to_letter
85 | for epoch in range(args.epochs):
86 | model.train()
87 | loss, acc, confusion = loop(model, train_loader, optimizer=optimizer)
88 | path = f"{args.models_dir}/{model_id}_{epoch}.pt"
89 | torch.save(model.state_dict(), path)
90 | print(f'EPOCH {epoch} TRAIN loss: {loss:.4f} acc: {acc:.4f}')
91 | print_confusion(confusion, lookup=lookup)
92 |
93 | model.eval()
94 | with torch.no_grad():
95 | loss, acc, confusion = loop(model, val_loader)
96 | print(f'EPOCH {epoch} VAL loss: {loss:.4f} acc: {acc:.4f}')
97 | print_confusion(confusion, lookup=lookup)
98 |
99 | if loss < best_val:
100 | best_path, best_val = path, loss
101 | print(f'BEST {best_path} VAL loss: {best_val:.4f}')
102 |
103 | print(f"TESTING: loading from {best_path}")
104 | model.load_state_dict(torch.load(best_path))
105 |
106 | model.eval()
107 | with torch.no_grad():
108 | loss, acc, confusion = loop(model, test_loader)
109 | print(f'TEST loss: {loss:.4f} acc: {acc:.4f}')
110 |     print_confusion(confusion, lookup=lookup)
111 |
112 | def test_perplexity(model, dataset):
113 | model.eval()
114 | with torch.no_grad():
115 | loss, acc, confusion = loop(model, dataloader(dataset))
116 | print(f'TEST perplexity: {np.exp(loss):.4f}')
117 | print_confusion(confusion, lookup=dataset.num_to_letter)
118 |
119 | def test_recovery(model, dataset):
120 | recovery = []
121 |
122 | for protein in tqdm.tqdm(dataset):
123 | protein = protein.to(device)
124 | h_V = (protein.node_s, protein.node_v)
125 | h_E = (protein.edge_s, protein.edge_v)
126 | sample = model.sample(h_V, protein.edge_index,
127 | h_E, n_samples=args.n_samples)
128 |
129 | recovery_ = sample.eq(protein.seq).float().mean().cpu().numpy()
130 | recovery.append(recovery_)
131 | print(protein.name, recovery_, flush=True)
132 |
133 | recovery = np.median(recovery)
134 | print(f'TEST recovery: {recovery:.4f}')
135 |
136 | def loop(model, dataloader, optimizer=None):
137 |
138 | confusion = np.zeros((20, 20))
139 | t = tqdm.tqdm(dataloader)
140 | loss_fn = nn.CrossEntropyLoss()
141 | total_loss, total_correct, total_count = 0, 0, 0
142 |
143 | for batch in t:
144 | if optimizer: optimizer.zero_grad()
145 |
146 | batch = batch.to(device)
147 | h_V = (batch.node_s, batch.node_v)
148 | h_E = (batch.edge_s, batch.edge_v)
149 |
150 | logits = model(h_V, batch.edge_index, h_E, seq=batch.seq)
151 | logits, seq = logits[batch.mask], batch.seq[batch.mask]
152 | loss_value = loss_fn(logits, seq)
153 |
154 | if optimizer:
155 | loss_value.backward()
156 | optimizer.step()
157 |
158 | num_nodes = int(batch.mask.sum())
159 | total_loss += float(loss_value) * num_nodes
160 | total_count += num_nodes
161 | pred = torch.argmax(logits, dim=-1).detach().cpu().numpy()
162 | true = seq.detach().cpu().numpy()
163 | total_correct += (pred == true).sum()
164 | confusion += confusion_matrix(true, pred, labels=range(20))
165 | t.set_description("%.5f" % float(total_loss/total_count))
166 |
167 | torch.cuda.empty_cache()
168 |
169 | return total_loss / total_count, total_correct / total_count, confusion
170 |
171 | def print_confusion(mat, lookup):
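    # Row-normalize counts to per-mille frequencies (each row of the printed
    # table shows, out of 1000, how a true residue type was predicted).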
172 | counts = mat.astype(np.int32)
173 | mat = (counts.T / counts.sum(axis=-1, keepdims=True).T).T
174 | mat = np.round(mat * 1000).astype(np.int32)
175 | res = '\n'
176 | for i in range(20):
177 | res += '\t{}'.format(lookup[i])
178 | res += '\tCount\n'
179 | for i in range(20):
180 | res += '{}\t'.format(lookup[i])
181 | res += '\t'.join('{}'.format(n) for n in mat[i])
182 | res += '\t{}\n'.format(sum(counts[i]))
183 | print(res)
184 |
185 | if __name__ == "__main__":
186 | main()
--------------------------------------------------------------------------------
/schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drorlab/gvp-pytorch/82af6b22eaf8311c15733117b0071408d24ed877/schematic.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 | long_description = fh.read()
5 |
6 | setup(
7 | name='gvp',
8 | packages=find_packages(include=[
9 | 'gvp',
10 | 'gvp.data',
11 | 'gvp.models'
12 | ]),
13 | version='0.1.1',
14 | description='Geometric Vector Perceptron',
15 | license='MIT',
16 | long_description=long_description,
17 | long_description_content_type="text/markdown",
18 | install_requires=[
19 | 'torch',
20 | 'torch_geometric',
21 | 'torch_scatter',
22 | 'torch_cluster',
23 | 'tqdm',
24 | 'numpy',
25 |         'scikit-learn',
26 | 'atom3d'
27 | ]
28 | )
--------------------------------------------------------------------------------
/test_equivariance.py:
--------------------------------------------------------------------------------
1 | import gvp
2 | import gvp.models
3 | import gvp.data
4 | import torch
5 | from torch import nn
6 | from scipy.spatial.transform import Rotation
7 | import unittest
8 |
9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10 |
11 | node_dim = (100, 16)
12 | edge_dim = (32, 1)
13 | n_nodes = 300
14 | n_edges = 10000
15 |
16 | nodes = gvp.randn(n_nodes, node_dim, device=device)
17 | edges = gvp.randn(n_edges, edge_dim, device=device)
18 | edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device)
19 | batch_idx = torch.randint(0, 5, (n_nodes,), device=device)
20 | seq = torch.randint(0, 20, (n_nodes,), device=device)
21 |
22 | class EquivarianceTest(unittest.TestCase):
23 |
24 | def test_gvp(self):
25 | model = gvp.GVP(node_dim, node_dim).to(device).eval()
26 | model_fn = lambda h_V, h_E: model(h_V)
27 | test_equivariance(model_fn, nodes, edges)
28 |
29 | def test_gvp_vector_gate(self):
30 | model = gvp.GVP(node_dim, node_dim, vector_gate=True).to(device).eval()
31 | model_fn = lambda h_V, h_E: model(h_V)
32 | test_equivariance(model_fn, nodes, edges)
33 |
34 | def test_gvp_sequence(self):
35 | model = nn.Sequential(
36 | gvp.GVP(node_dim, node_dim),
37 | gvp.Dropout(0.1),
38 | gvp.LayerNorm(node_dim)
39 | ).to(device).eval()
40 | model_fn = lambda h_V, h_E: model(h_V)
41 | test_equivariance(model_fn, nodes, edges)
42 |
43 | def test_gvp_sequence_vector_gate(self):
44 | model = nn.Sequential(
45 | gvp.GVP(node_dim, node_dim, vector_gate=True),
46 | gvp.Dropout(0.1),
47 | gvp.LayerNorm(node_dim)
48 | ).to(device).eval()
49 | model_fn = lambda h_V, h_E: model(h_V)
50 | test_equivariance(model_fn, nodes, edges)
51 |
52 | def test_gvp_conv(self):
53 | model = gvp.GVPConv(node_dim, node_dim, edge_dim).to(device).eval()
54 | model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E)
55 | test_equivariance(model_fn, nodes, edges)
56 |
57 | def test_gvp_conv_vector_gate(self):
58 | model = gvp.GVPConv(node_dim, node_dim, edge_dim, vector_gate=True).to(device).eval()
59 | model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E)
60 | test_equivariance(model_fn, nodes, edges)
61 |
62 | def test_gvp_conv_layer(self):
63 | model = gvp.GVPConvLayer(node_dim, edge_dim).to(device).eval()
64 | model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E,
65 | autoregressive_x=h_V)
66 | test_equivariance(model_fn, nodes, edges)
67 |
68 | def test_gvp_conv_layer_vector_gate(self):
69 | model = gvp.GVPConvLayer(node_dim, edge_dim, vector_gate=True).to(device).eval()
70 | model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E,
71 | autoregressive_x=h_V)
72 | test_equivariance(model_fn, nodes, edges)
73 |
74 | def test_mqa_model(self):
75 | model = gvp.models.MQAModel(node_dim, node_dim,
76 | edge_dim, edge_dim).to(device).eval()
77 | model_fn = lambda h_V, h_E: (model(h_V, edge_index, h_E, batch=batch_idx), \
78 | torch.zeros_like(nodes[1]))
79 | test_equivariance(model_fn, nodes, edges)
80 |
81 | def test_cpd_model(self):
82 | model = gvp.models.CPDModel(node_dim, node_dim,
83 | edge_dim, edge_dim).to(device).eval()
84 | model_fn = lambda h_V, h_E: (model(h_V, edge_index, h_E, seq=seq), \
85 | torch.zeros_like(nodes[1]))
86 | test_equivariance(model_fn, nodes, edges)
87 |
88 |
89 | def test_equivariance(model, nodes, edges):
90 |
91 | random = torch.as_tensor(Rotation.random().as_matrix(),
92 | dtype=torch.float32, device=device)
93 |
94 | with torch.no_grad():
95 |
96 | out_s, out_v = model(nodes, edges)
97 | n_v_rot, e_v_rot = nodes[1] @ random, edges[1] @ random
98 | out_v_rot = out_v @ random
99 | out_s_prime, out_v_prime = model((nodes[0], n_v_rot), (edges[0], e_v_rot))
100 |
101 | assert torch.allclose(out_s, out_s_prime, atol=1e-5, rtol=1e-4)
102 | assert torch.allclose(out_v_rot, out_v_prime, atol=1e-5, rtol=1e-4)
103 |
104 |
105 | if __name__ == "__main__":
106 | unittest.main()
--------------------------------------------------------------------------------
/vectors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drorlab/gvp-pytorch/82af6b22eaf8311c15733117b0071408d24ed877/vectors.png
--------------------------------------------------------------------------------