├── LICENSE
├── README.md
├── data
│   ├── getCATH.sh
│   ├── ts50.json
│   └── ts50remove.txt
├── gvp
│   ├── __init__.py
│   ├── atom3d.py
│   ├── data.py
│   └── models.py
├── run_atom3d.py
├── run_cpd.py
├── schematic.png
├── setup.py
├── test_equivariance.py
└── vectors.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Bowen Jing, Stephan Eismann, Pratham Soni,
4 | Patricia Suriana, Raphael Townshend, Ron Dror
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Geometric Vector Perceptron
2 |
3 | Implementation of equivariant GVP-GNNs as described in [Learning from Protein Structure with Geometric Vector Perceptrons](https://openreview.net/forum?id=1YLJDvSx6J4) by B Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror.
4 |
5 | **UPDATE:** Also includes equivariant GNNs with vector gating as described in [Equivariant Graph Neural Networks for 3D Macromolecular Structure](https://arxiv.org/abs/2106.03843) by B Jing, S Eismann, P Soni, and RO Dror.
6 |
7 | Scripts for training / testing / sampling on protein design and training / testing on all [ATOM3D](https://arxiv.org/abs/2012.04035) tasks are provided.
8 |
9 | **Note:** This implementation is in PyTorch Geometric. The original TensorFlow code, which is not maintained, can be found [here](https://github.com/drorlab/gvp).
10 |
11 | ![schematic](schematic.png)
12 |
13 | ![vectors](vectors.png)
14 |
15 | ## Requirements
16 | * UNIX environment
17 | * python==3.6.13
18 | * torch==1.8.1
19 | * torch_geometric==1.7.0
20 | * torch_scatter==2.0.6
21 | * torch_cluster==1.5.9
22 | * tqdm==4.38.0
23 | * numpy==1.19.4
24 | * sklearn==0.24.1
25 | * atom3d==0.2.1
26 |
27 | While we have not tested with other versions, any reasonably recent versions of these requirements should work.
28 |
29 | ## General usage
30 |
31 | We provide classes in four modules:
32 | * `gvp`: core GVP modules and GVP-GNN layers
33 | * `gvp.data`: data pipelines for both general use and protein design
34 | * `gvp.models`: implementations of MQA and CPD models
35 | * `gvp.atom3d`: models and data pipelines for ATOM3D
36 |
37 | The core modules in `gvp` are meant to be as general as possible, but you will likely have to modify `gvp.data` and `gvp.models` for your specific application, with the existing classes serving as examples.
38 |
39 | **Installation:** Download this repository and run `python setup.py develop` or `pip install -e .`. Be sure to manually install `torch_geometric` first!
40 |
41 | **Tuple representation:** All inputs and outputs with both scalar and vector channels are represented as a tuple of two tensors `(s, V)`. Similarly, all dimensions should be specified as tuples `(n_scalar, n_vector)` where `n_scalar` and `n_vector` are the number of scalar and vector features, respectively. All `V` tensors must be shaped as `[..., n_vector, 3]`, not `[..., 3, n_vector]`.
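A minimal sketch of this convention (the sizes 100 and 16 are illustrative):
```
import torch

s = torch.randn(5, 100)     # scalar channels: [n_nodes, n_scalar]
V = torch.randn(5, 16, 3)   # vector channels: [n_nodes, n_vector, 3]
x = (s, V)                  # dims would be specified as (100, 16)
```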
42 |
43 | **Batching:** We adopt the `torch_geometric` convention of absorbing the batch dimension into the node dimension and keeping track of batch index in a separate tensor.
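For example, two graphs with 3 and 2 nodes are stored as a single graph of 5 nodes with batch index tensor `[0, 0, 0, 1, 1]`.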
44 |
45 | **Amino acids:** Models view sequences as int tensors and are agnostic to aa-to-int mappings. Such mappings are specified as the `letter_to_num` attribute of `gvp.data.ProteinGraphDataset`. Currently, only the 20 standard amino acids are supported.
46 |
47 | For all classes, see the docstrings for more detailed usage. If you have any questions, please contact bjing@cs.stanford.edu.
48 |
49 | ### Core GVP classes
50 |
51 | The class `gvp.GVP` implements a Geometric Vector Perceptron.
52 | ```
53 | import gvp
54 |
55 | in_dims = scalars_in, vectors_in
56 | out_dims = scalars_out, vectors_out
57 | gvp_ = gvp.GVP(in_dims, out_dims)
58 | ```
59 | To use vector gating, pass in `vector_gate=True` and the appropriate activations (here `F.relu` is `torch.nn.functional.relu`).
60 | ```
61 | gvp_ = gvp.GVP(in_dims, out_dims,
62 | activations=(F.relu, None), vector_gate=True)
63 | ```
64 | The classes `gvp.Dropout` and `gvp.LayerNorm` implement vector-channel dropout and layer norm, while using normal dropout and layer norm for scalar channels. Both expect inputs and return outputs of form `(s, V)`, but will also behave like their scalar-valued counterparts if passed a single tensor.
65 | ```
66 | dropout = gvp.Dropout(drop_rate=0.1)
67 | layernorm = gvp.LayerNorm(out_dims)
68 | ```
69 | The function `gvp.randn` returns tuples `(s, V)` drawn from a standard normal. Such tuples can be directly used in a forward pass.
70 | ```
71 | x = gvp.randn(n=5, dims=in_dims)
72 | # x = (s, V) with s.shape = [5, scalars_in] and V.shape = [5, vectors_in, 3]
73 |
74 | out = gvp_(x)
75 | out = dropout(out)
76 | out = layernorm(out)
77 | ```
78 | Finally, we provide utility functions for adding, concatenating, and indexing into such tuples.
79 | ```
80 | y = gvp.randn(n=5, dims=in_dims)
81 | z = gvp.tuple_sum(x, y)
82 | z = gvp.tuple_cat(x, y, dim=-1) # concat along channel axis
83 | z = gvp.tuple_cat(x, y, dim=-2) # concat along node / batch axis
84 |
85 | node_mask = torch.rand(5) < 0.5
86 | z = gvp.tuple_index(x, node_mask) # select half the nodes / batch at random
87 | ```
88 | ### GVP-GNN layers
89 | The class `GVPConv` is a `torch_geometric.nn.MessagePassing` module which forms messages and aggregates them at the destination node, returning new node embeddings. The original embeddings are not updated.
90 | ```
91 | nodes = gvp.randn(n=5, dims=in_dims)
92 | edges = gvp.randn(n=10, dims=edge_dims) # 10 random edges
93 | edge_index = torch.randint(0, 5, (2, 10), device=device)
94 |
95 | conv = gvp.GVPConv(in_dims, out_dims, edge_dims)
96 | out = conv(nodes, edge_index, edges)
97 | ```
98 | The class `GVPConvLayer` is a `nn.Module` that forms messages using a `GVPConv` and updates the node embeddings as described in the paper. Because the updates are residual, the dimensionality of the embeddings is not changed.
99 | ```
100 | layer = gvp.GVPConvLayer(node_dims, edge_dims)
101 | nodes = layer(nodes, edge_index, edges)
102 | ```
103 | The class also allows updates in which incoming messages where src >= dst are computed using a different set of source embeddings, as in autoregressive models.
104 | ```
105 | nodes_static = gvp.randn(n=5, dims=in_dims)
106 | layer = gvp.GVPConvLayer(node_dims, edge_dims, autoregressive=True)
107 | nodes = layer(nodes, edge_index, edges, autoregressive_x=nodes_static)
108 | ```
109 | Both `GVPConv` and `GVPConvLayer` accept arguments `activations` and `vector_gate` to use vector gating.
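For example, a gated version of the layer above:
```
layer = gvp.GVPConvLayer(node_dims, edge_dims,
                         activations=(F.relu, None), vector_gate=True)
```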
110 |
111 | ### Loading data
112 |
113 | The class `gvp.data.ProteinGraphDataset` transforms protein backbone structures into featurized graphs. Following [Ingraham et al., NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design), we use a JSON/dictionary format to specify backbone structures:
114 |
115 | ```
116 | [
117 | {
118 |         "name": "NAME",
119 | "seq": "TQDCSFQHSP...",
120 | "coords": [[[74.46, 58.25, -21.65],...],...]
121 | }
122 | ...
123 | ]
124 | ```
125 | For each structure, `coords` should be a `num_residues x 4 x 3` nested list of the positions of the backbone N, C-alpha, C, and O atoms of each residue (in that order).
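As a quick shape check (a sketch, where `structure` is one entry of the list above):
```
import numpy as np

coords = np.array(structure['coords'])
assert coords.shape == (len(structure['seq']), 4, 3)  # N, C-alpha, C, O per residue
```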
126 | ```
127 | import gvp.data
128 |
129 | # structures is a list or list-like as shown above
130 | dataset = gvp.data.ProteinGraphDataset(structures)
131 | # dataset[i] is the featurized graph corresponding to structures[i]
132 | ```
133 | The returned graphs are of type `torch_geometric.data.Data` with attributes
134 | * `x`: alpha carbon coordinates
135 | * `seq`: sequence converted to int tensor according to attribute `self.letter_to_num`
136 | * `name`, `edge_index`
137 | * `node_s`, `node_v`: node features as described in the paper with dims `(6, 3)`
138 | * `edge_s`, `edge_v`: edge features as described in the paper with dims `(32, 1)`
139 | * `mask`: false for nodes with any nan coordinates
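For example, with the dims above:
```
protein = dataset[0]
print(protein.node_s.shape, protein.node_v.shape) # [n_nodes, 6], [n_nodes, 3, 3]
print(protein.edge_s.shape, protein.edge_v.shape) # [n_edges, 32], [n_edges, 1, 3]
```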
140 |
141 | The `gvp.data.ProteinGraphDataset` can be used with a `torch.utils.data.DataLoader`. We supply a class `gvp.data.BatchSampler` which forms batches based on the total number of nodes in each batch. Use of this sampler is optional.
142 | ```
143 | node_counts = [len(s['seq']) for s in structures]
144 | sampler = gvp.data.BatchSampler(node_counts, max_nodes=3000)
145 | dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
146 | ```
147 | The dataloader will return batched graphs of type `torch_geometric.data.Batch` with an additional `batch` attribute. The attributes of the `Batch` will then need to be formed into `(s, V)` tuples before passing into a GVP-GNN layer or network.
148 | ```
149 | for batch in dataloader:
150 | batch = batch.to(device) # optional
151 | nodes = (batch.node_s, batch.node_v)
152 | edges = (batch.edge_s, batch.edge_v)
153 |
154 | out = layer(nodes, batch.edge_index, edges)
155 | ```
156 |
157 | ### Ready-to-use protein GNNs
158 | We provide two fully specified networks which take in protein graphs and output a scalar prediction for each graph (`gvp.models.MQAModel`) or a 20-dimensional feature vector for each node (`gvp.models.CPDModel`), corresponding to the two tasks in our paper. Note that if you are using the unmodified `gvp.data.ProteinGraphDataset`, `node_in_dim` and `edge_in_dim` must be `(6, 3)` and `(32, 1)`, respectively.
159 | ```
160 | import gvp.models
161 |
162 | # batch, nodes, edges as formed above
163 |
164 | mqa_model = gvp.models.MQAModel(node_in_dim, node_h_dim,
165 | edge_in_dim, edge_h_dim, seq_in=True)
166 | out = mqa_model(nodes, batch.edge_index, edges,
167 | seq=batch.seq, batch=batch.batch) # shape (n_graphs,)
168 |
169 | cpd_model = gvp.models.CPDModel(node_in_dim, node_h_dim,
170 | edge_in_dim, edge_h_dim)
171 | out = cpd_model(nodes, batch.edge_index,
172 | edges, batch.seq) # shape (n_nodes, 20)
173 | ```
174 |
175 | ## Protein design
176 | We provide a script `run_cpd.py` to train, validate, and test a `CPDModel` as specified in the paper using the CATH 4.2 dataset and TS50 dataset. If you want to use a trained model on new structures, see the section "Sampling" below.
177 |
178 | ### Fetching data
179 | Run `getCATH.sh` in `data/` to fetch the CATH 4.2 dataset. If you are interested in testing on the TS50 test set, also run `grep -Fv -f ts50remove.txt chain_set.jsonl > chain_set_ts50.jsonl` to produce a training set with no overlap with the TS50 test set.
180 |
181 | ### Training / testing
182 | To train a model, simply run `python run_cpd.py --train`. To test a trained model on both the CATH 4.2 test set and the TS50 test set, run `python run_cpd.py --test-r PATH` for sequence recovery or `python run_cpd.py --test-p PATH` for perplexity. Run `python run_cpd.py -h` for more detailed options.
183 |
184 | ```
185 | $ python run_cpd.py -h
186 |
187 | usage: run_cpd.py [-h] [--models-dir PATH] [--num-workers N] [--max-nodes N] [--epochs N] [--cath-data PATH] [--cath-splits PATH] [--ts50 PATH] [--train] [--test-r PATH] [--test-p PATH] [--n-samples N]
188 |
189 | optional arguments:
190 | -h, --help show this help message and exit
191 | --models-dir PATH directory to save trained models, default=./models/
192 | --num-workers N number of threads for loading data, default=4
193 | --max-nodes N max number of nodes per batch, default=3000
194 | --epochs N training epochs, default=100
195 | --cath-data PATH location of CATH dataset, default=./data/chain_set.jsonl
196 | --cath-splits PATH location of CATH split file, default=./data/chain_set_splits.json
197 | --ts50 PATH location of TS50 dataset, default=./data/ts50.json
198 | --train train a model
199 | --test-r PATH evaluate a trained model on recovery (without training)
200 | --test-p PATH evaluate a trained model on perplexity (without training)
201 | --n-samples N number of sequences to sample (if testing recovery), default=100
202 | ```
203 | **Confusion matrices:** Note that the values are normalized such that each row (corresponding to the true class) sums to 1000, with the actual number of residues in that class printed under the "Count" column.
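For example, a diagonal entry of 900 means that 90% of the residues of that true class were predicted correctly.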
204 |
205 | ### Sampling
206 | To sample from a `CPDModel`, prepare a `ProteinGraphDataset`, but do NOT pass it into a `DataLoader`. The sequences are not used, so placeholders can be used for the `seq` attributes of the original structure dicts.
207 |
208 | ```
209 | protein = dataset[i]
210 | nodes = (protein.node_s, protein.node_v)
211 | edges = (protein.edge_s, protein.edge_v)
212 |
213 | sample = model.sample(nodes, protein.edge_index, # shape = (n_samples, n_nodes)
214 | edges, n_samples=n_samples)
215 | ```
216 | The output will be an int tensor, using the same aa-to-int mapping that was used when training the model.
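To recover letters from the samples, one can invert the dataset's `letter_to_num` mapping (a sketch):
```
num_to_letter = {v: k for k, v in dataset.letter_to_num.items()}
seqs = [''.join(num_to_letter[int(t)] for t in row) for row in sample]
```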
217 |
218 | ## ATOM3D
219 | We provide models and dataloaders for all ATOM3D tasks in `gvp.atom3d`, as well as a training and testing script in `run_atom3d.py`. This also supports loading pretrained weights for transfer learning experiments.
220 |
221 | ### Models / data loaders
222 | The GVP-GNNs for ATOM3D are supplied in `gvp.atom3d` and are named after each task: `gvp.atom3d.MSPModel`, `gvp.atom3d.PPIModel`, etc. All of these extend the base class `gvp.atom3d.BaseModel`. These classes take no arguments at initialization, take in a `torch_geometric.data.Batch` representation of a batch of structures, and return an output corresponding to the task. Details vary based on the exact task---see the docstrings.
223 | ```
224 | psr_model = gvp.atom3d.PSRModel()
225 | ```
226 | `gvp.atom3d` also includes data loaders to produce `torch_geometric.data.Batch` objects from an underlying `atom3d.datasets.LMDBDataset`. For all tasks except PPI and RES, these are callable transform objects---`gvp.atom3d.SMPTransform`, `gvp.atom3d.RSRTransform`, etc.---which should be passed into the constructor of an `atom3d.datasets.LMDBDataset`:
227 | ```
228 | psr_dataset = atom3d.datasets.LMDBDataset(path_to_dataset,
229 | transform=gvp.atom3d.PSRTransform())
230 | ```
231 | On the other hand, `gvp.atom3d.PPIDataset` and `gvp.atom3d.RESDataset` take the place of / are wrappers around the `atom3d.datasets.LMDBDataset`:
232 | ```
233 | ppi_dataset = gvp.atom3d.PPIDataset(path_to_dataset)
234 | res_dataset = gvp.atom3d.RESDataset(path_to_dataset, path_to_split) # see docstring
235 | ```
236 | All datasets must then be wrapped in a `torch_geometric.data.DataLoader`:
237 | ```
238 | psr_dataloader = torch_geometric.data.DataLoader(psr_dataset, batch_size=batch_size)
239 | ```
240 | The dataloaders can be directly iterated over to yield `torch_geometric.data.Batch` objects, which can then be passed into the models.
241 | ```
242 | for batch in psr_dataloader:
243 | pred = psr_model(batch) # pred.shape = (batch_size,)
244 | ```
245 |
246 | ### Training / testing
247 |
248 | To run training / testing on ATOM3D, download the datasets as described [here](https://www.atom3d.ai/). Modify the function `get_datasets` in `run_atom3d.py` with the paths to the datasets. Then run:
249 | ```
250 | $ python run_atom3d.py -h
251 |
252 | usage: run_atom3d.py [-h] [--num-workers N] [--smp-idx IDX]
253 | [--lba-split SPLIT] [--batch SIZE] [--train-time MINUTES]
254 | [--val-time MINUTES] [--epochs N] [--test PATH]
255 | [--lr RATE] [--load PATH]
256 | TASK
257 |
258 | positional arguments:
259 | TASK {PSR, RSR, PPI, RES, MSP, SMP, LBA, LEP}
260 |
261 | optional arguments:
262 | -h, --help show this help message and exit
263 | --num-workers N number of threads for loading data, default=4
264 | --smp-idx IDX label index for SMP, in range 0-19
265 | --lba-split SPLIT identity cutoff for LBA, 30 (default) or 60
266 | --batch SIZE batch size, default=8
267 | --train-time MINUTES maximum time between evaluations on valset,
268 | default=120 minutes
269 | --val-time MINUTES maximum time per evaluation on valset, default=20
270 | minutes
271 | --epochs N training epochs, default=50
272 | --test PATH evaluate a trained model
273 | --lr RATE learning rate
274 | --load PATH initialize first 2 GNN layers with pretrained weights
275 | ```
276 | For example:
277 | ```
278 | # train a model
279 | python run_atom3d.py PSR
280 |
281 | # train a model with pretrained weights
282 | python run_atom3d.py PSR --load PATH
283 |
284 | # evaluate a model
285 | python run_atom3d.py PSR --test PATH
286 | ```
287 |
288 | ## Acknowledgements
289 | Portions of the input data pipeline were adapted from [Ingraham et al., NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design). We thank Pratham Soni for portions of the implementation in PyTorch.
290 |
291 | ## Citation
292 | ```
293 | @inproceedings{
294 | jing2021learning,
295 | title={Learning from Protein Structure with Geometric Vector Perceptrons},
296 | author={Bowen Jing and Stephan Eismann and Patricia Suriana and Raphael John Lamarre Townshend and Ron Dror},
297 | booktitle={International Conference on Learning Representations},
298 | year={2021},
299 | url={https://openreview.net/forum?id=1YLJDvSx6J4}
300 | }
301 |
302 | @article{jing2021equivariant,
303 | title={Equivariant Graph Neural Networks for 3D Macromolecular Structure},
304 | author={Jing, Bowen and Eismann, Stephan and Soni, Pratham N and Dror, Ron O},
305 | journal={arXiv preprint arXiv:2106.03843},
306 | year={2021}
307 | }
308 | ```
309 |
--------------------------------------------------------------------------------
/data/getCATH.sh:
--------------------------------------------------------------------------------
1 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set.jsonl
2 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set_splits.json
3 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/SPIN2/test_split_L100.json
4 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/SPIN2/test_split_sc.json
5 |
--------------------------------------------------------------------------------
/data/ts50remove.txt:
--------------------------------------------------------------------------------
1 | 1ypo.A
2 | 1y1l.A
3 | 1am2.A
4 | 3id9.A
5 | 3zid.B
6 | 3vpp.B
7 | 1gyc.A
8 | 3r79.A
9 | 3e82.A
10 | 3lqc.A
11 | 2cvi.A
12 | 2v6y.A
13 | 3nng.A
14 | 1zak.A
15 | 2nvp.A
16 | 1vef.A
17 | 2kp5.A
18 | 2b3h.A
19 | 2n8k.A
20 | 2f91.A
21 | 3e48.A
22 | 2v3t.A
23 | 4jvs.A
24 | 3fpw.A
25 | 1c0n.A
26 | 2gu3.A
27 | 1es0.A
28 | 3zhg.A
29 | 4epk.B
30 | 2avp.A
31 | 2e9x.A
32 | 4dkc.B
33 | 5b1a.M
34 | 3nbk.D
35 | 1o9i.A
36 | 4hcw.A
37 | 3mmz.A
38 | 2rbk.A
39 | 4yon.A
40 | 4nav.A
41 | 4cgw.A
42 | 3io1.A
43 | 1dxe.A
44 | 1a17.A
45 | 3cyn.B
46 | 3m6n.B
47 | 1ble.A
48 | 4qjb.B
49 | 2fml.A
50 | 4eo3.A
51 | 1r5m.A
52 | 3dao.A
53 | 3hhj.B
54 | 1e1h.B
55 | 2jty.A
56 | 3gkm.A
57 | 2wgl.A
58 | 3ajv.A
59 | 2duk.A
60 | 1dd3.A
61 | 3fhk.A
62 | 2fi7.A
63 | 2h2t.B
64 | 1zde.A
65 | 5hk8.A
66 | 2xcj.A
67 | 2gys.A
68 | 1lmj.A
69 | 2e7x.A
70 | 1kql.B
71 | 1yul.A
72 | 2l49.B
73 | 3l0g.A
74 | 3k4i.A
75 | 3l23.A
76 | 1qkr.A
77 | 2bw2.A
78 | 1vc9.A
79 | 4uqx.A
80 | 1emo.A
81 | 3oee.Y
82 | 2xdg.A
83 | 3sz7.A
84 | 2bv2.A
85 | 4jpb.W
86 | 1nps.A
87 | 5got.A
88 | 4mi7.A
89 | 1rtt.A
90 | 1ouv.A
91 | 3bzm.A
92 | 2dx6.A
93 | 1c17.M
94 | 3q4o.A
95 | 1s0w.C
96 | 4ozu.A
97 | 2kna.A
98 | 1jkv.A
99 | 4tkz.A
100 | 3dfj.A
101 | 4k2b.A
102 | 3ejf.A
103 | 3cyg.A
104 | 3qo6.A
105 | 3pt1.A
106 | 2a6t.B
107 | 3mt1.A
108 | 4fd9.A
109 | 2oga.A
110 | 1fjr.A
111 | 1im2.A
112 | 4hfs.A
113 | 3nd5.A
114 | 4lx3.A
115 | 1k1e.D
116 | 1xnf.A
117 | 1f3y.A
118 | 3fk9.A
119 | 4r42.A
120 | 3fmy.A
121 | 1b9x.C
122 | 5im4.F
123 | 3ajv.C
124 | 1kli.L
125 | 5epf.A
126 | 1a45.A
127 | 3vld.A
128 | 2qdl.A
129 | 4nfw.F
130 | 2hi6.A
131 | 1wp0.A
132 | 1mr1.D
133 | 4ar0.A
134 | 3isz.A
135 | 2km4.A
136 | 2o1e.A
137 | 1v7m.V
138 | 4mt4.A
139 | 4fz2.A
140 | 1emn.A
141 | 1kcg.B
142 | 3cng.A
143 | 1va9.A
144 | 2qkm.B
145 | 1ir6.A
146 | 3ny7.A
147 | 1rtq.A
148 | 4hku.A
149 | 3hz2.A
150 | 3piv.A
151 | 2bou.A
152 | 3ecs.A
153 | 2h1y.B
154 | 5tf5.A
155 | 3pe8.A
156 | 1coz.A
157 | 3grn.A
158 | 1kt0.A
159 | 1sqs.A
160 | 5erm.B
161 | 2rag.A
162 | 1akq.A
163 | 1dv8.A
164 | 3u37.A
165 | 5egw.A
166 | 3fm2.A
167 | 5g5g.B
168 | 5e3i.A
169 | 3mpo.A
170 | 1oki.A
171 | 1q33.A
172 | 1dx5.I
173 | 2p9j.A
174 | 4xch.A
175 | 2gpy.A
176 | 2b06.A
177 | 2ra8.A
178 | 2w86.A
179 | 2fek.A
180 | 1egg.A
181 | 3kqg.A
182 | 2dnm.A
183 | 3oq0.B
184 | 4jzs.A
185 | 3a6s.A
186 | 3f13.B
187 | 1vs3.A
188 | 1i8n.A
189 | 1g3i.W
190 | 1ddm.A
191 | 1k2e.A
192 | 1wr8.A
193 | 2va0.A
194 | 3apa.A
195 | 3k2k.A
196 | 3gg6.A
197 | 2vyi.A
198 | 4lej.A
199 | 1a7h.A
200 | 1zu4.A
201 | 3pfo.A
202 | 3uj6.A
203 | 3d68.A
204 | 1b08.C
205 | 4ezb.A
206 | 4lq6.A
207 | 3p8b.B
208 | 2cxx.A
209 | 2i15.A
210 | 1t0i.A
211 | 1ahs.A
212 | 4g4p.A
213 | 3k96.A
214 | 2q88.A
215 | 3o38.B
216 | 4um7.B
217 | 3wur.A
218 | 3cq4.A
219 | 3h95.A
220 | 3a3j.A
221 | 3n8h.A
222 | 1dqb.A
223 | 2p5t.A
224 | 3hkl.A
225 | 4gxw.B
226 | 1j71.A
227 | 3gwi.A
228 | 3q10.D
229 | 1hjz.A
230 | 1kmj.A
231 | 2qjt.B
232 | 3bil.A
233 | 3fbs.A
234 | 3fbs.B
235 | 2vvp.C
236 | 2h1y.A
237 | 3eix.A
238 | 3rui.A
239 | 1efa.A
240 | 2zag.A
241 | 1x57.A
242 | 2rrk.A
243 | 1ucv.A
244 | 3pqi.A
245 | 1qwo.A
246 | 3uk7.A
247 | 1l3a.A
248 | 3skv.A
249 | 3n77.A
250 | 3d8d.A
251 | 1w3i.A
252 | 2ra6.C
253 | 1r61.A
254 | 3qao.A
255 | 1wu4.A
256 | 2d1c.A
257 | 2w2i.A
258 | 2rdx.A
259 | 4gp6.A
260 | 2q1s.A
261 | 4mzy.A
262 | 1bvy.F
263 | 3ofn.Y
264 | 3nv7.A
265 | 3oir.A
266 | 2gwn.A
267 | 2dad.A
268 | 3on9.A
269 | 3rh0.A
270 | 3lvy.A
271 | 3lwk.A
272 | 3kol.A
273 | 1na0.A
274 | 4b6z.A
275 | 3f6a.A
276 | 3u7r.A
277 | 1nrw.A
278 | 2nly.A
279 | 1dm1.A
280 | 4nkp.A
281 | 4og1.A
282 | 5hci.C
283 | 1v7p.B
284 | 2oge.A
285 | 4o8s.A
286 | 2b8e.A
287 | 1elr.A
288 | 1u6l.A
289 | 1oqy.A
290 | 5a2q.I
291 | 1ugo.A
292 | 2a4v.A
293 | 3so6.A
294 | 2yyh.A
295 | 3tqt.A
296 | 1xxm.C
297 | 2zad.A
298 | 1jhe.A
299 | 2erx.A
300 | 1ete.A
301 | 2j4b.B
302 | 2zyz.C
303 | 1hv8.A
304 | 3iey.B
305 | 2jod.A
306 | 1tdq.B
307 | 3pqh.A
308 | 3o8s.A
309 | 2ca5.A
310 | 1zej.A
311 | 3fgh.A
312 | 4qam.B
313 | 5lba.D
314 | 5cfj.A
315 | 2efp.A
316 | 1erj.B
317 | 3pgv.A
318 | 2rav.A
319 | 1o1x.A
320 | 3bkw.A
321 | 1aw5.A
322 | 1st9.B
323 | 4lzh.A
324 | 1i0r.A
325 | 3ujp.A
326 | 3rof.A
327 | 2b9l.A
328 | 3ff7.C
329 | 4w64.B
330 | 1pdo.A
331 | 2nrr.A
332 | 1zwx.A
333 | 3aqf.B
334 | 4ptz.C
335 | 2bmx.B
336 | 3t38.A
337 | 4dev.D
338 | 1wlz.A
339 | 3t9w.A
340 | 2k13.X
341 | 3ufx.B
342 | 3c22.C
343 | 3bpt.A
344 | 2j49.A
345 | 1sjy.A
346 | 1nni.1
347 | 3e8m.A
348 | 2b0v.A
349 | 3fjy.A
350 | 1yxr.A
351 | 1fon.A
352 | 1xip.A
353 | 1x6o.A
354 | 2acf.D
355 | 3e59.B
356 | 1euv.B
357 | 1r4s.A
358 | 3cw3.A
359 | 3mxt.A
360 | 3g80.A
361 | 3aqg.B
362 | 3nhi.A
363 | 3v64.C
364 | 1a9x.B
365 | 2myu.A
366 | 2qdf.A
367 | 1mo7.A
368 | 1ksh.B
369 | 3lst.A
370 | 3l0g.B
371 | 4pmx.A
372 | 2r75.1
373 | 1bvp.1
374 | 1u8s.A
375 | 3kxy.J
376 | 4lfl.B
377 | 1g41.A
378 | 1a79.A
379 | 1i1g.A
380 | 4jbc.A
381 | 4k2m.A
382 | 2i39.B
383 | 2zai.A
384 | 3l4r.A
385 | 1y9q.A
386 | 1fsu.A
387 | 4qd4.A
388 | 2cyy.A
389 | 2cay.B
390 | 3lcm.A
391 | 2ls8.A
392 | 1rkq.A
393 | 1gvf.B
394 | 2g5f.B
395 | 2kz8.A
396 | 3k9x.C
397 | 3tqt.B
398 | 2iu8.C
399 | 5tgz.A
400 | 1lw7.A
401 | 4fo9.A
402 | 2qlr.A
403 | 3mpo.D
404 | 5t3u.B
405 | 1b1c.A
406 | 2z1d.A
407 | 4egs.A
408 | 5jci.A
409 | 2jxx.A
410 | 2cg4.A
411 | 2pqv.A
412 | 2kc7.A
413 | 3fiu.A
414 | 2ox8.A
415 | 3fvw.A
416 | 1vk6.A
417 | 4zev.A
418 | 5i3s.C
419 | 2a2l.C
420 | 3drn.B
421 | 1lpa.A
422 | 3exq.A
423 | 5jsz.A
424 | 1qze.A
425 | 3mcx.A
426 | 2w4e.A
427 | 2om6.A
428 | 4wzu.A
429 | 2j5o.A
430 | 2e0i.A
431 | 3q9s.A
432 | 3ewi.B
433 | 1or4.B
434 | 2aeg.A
435 | 4rmo.A
436 |
--------------------------------------------------------------------------------
/gvp/__init__.py:
--------------------------------------------------------------------------------
1 | import torch, functools
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import MessagePassing
5 | from torch_scatter import scatter_add
6 |
7 | def tuple_sum(*args):
8 | '''
9 | Sums any number of tuples (s, V) elementwise.
10 | '''
11 | return tuple(map(sum, zip(*args)))
12 |
13 | def tuple_cat(*args, dim=-1):
14 | '''
15 | Concatenates any number of tuples (s, V) elementwise.
16 |
17 | :param dim: dimension along which to concatenate when viewed
18 | as the `dim` index for the scalar-channel tensors.
19 | This means that `dim=-1` will be applied as
20 | `dim=-2` for the vector-channel tensors.
21 | '''
22 | dim %= len(args[0][0].shape)
23 | s_args, v_args = list(zip(*args))
24 | return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)
25 |
26 | def tuple_index(x, idx):
27 | '''
28 | Indexes into a tuple (s, V) along the first dimension.
29 |
30 | :param idx: any object which can be used to index into a `torch.Tensor`
31 | '''
32 | return x[0][idx], x[1][idx]
33 |
34 | def randn(n, dims, device="cpu"):
35 | '''
36 | Returns random tuples (s, V) drawn elementwise from a normal distribution.
37 |
38 | :param n: number of data points
39 | :param dims: tuple of dimensions (n_scalar, n_vector)
40 |
41 | :return: (s, V) with s.shape = (n, n_scalar) and
42 | V.shape = (n, n_vector, 3)
43 | '''
44 | return torch.randn(n, dims[0], device=device), \
45 | torch.randn(n, dims[1], 3, device=device)
46 |
47 | def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
48 | '''
49 | L2 norm of tensor clamped above a minimum value `eps`.
50 |
51 | :param sqrt: if `False`, returns the square of the L2 norm
52 | '''
53 | out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
54 | return torch.sqrt(out) if sqrt else out
55 |
56 | def _split(x, nv):
57 | '''
58 | Splits a merged representation of (s, V) back into a tuple.
59 | Should be used only with `_merge(s, V)` and only if the tuple
60 | representation cannot be used.
61 |
62 | :param x: the `torch.Tensor` returned from `_merge`
63 | :param nv: the number of vector channels in the input to `_merge`
64 | '''
65 | v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
66 | s = x[..., :-3*nv]
67 | return s, v
68 |
69 | def _merge(s, v):
70 | '''
71 | Merges a tuple (s, V) into a single `torch.Tensor`, where the
72 | vector channels are flattened and appended to the scalar channels.
73 | Should be used only if the tuple representation cannot be used.
74 | Use `_split(x, nv)` to reverse.
75 | '''
76 | v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
77 | return torch.cat([s, v], -1)
78 |
79 | class GVP(nn.Module):
80 | '''
81 | Geometric Vector Perceptron. See manuscript and README.md
82 | for more details.
83 |
84 | :param in_dims: tuple (n_scalar, n_vector)
85 | :param out_dims: tuple (n_scalar, n_vector)
86 | :param h_dim: intermediate number of vector channels, optional
87 | :param activations: tuple of functions (scalar_act, vector_act)
88 | :param vector_gate: whether to use vector gating.
89 | (vector_act will be used as sigma^+ in vector gating if `True`)
90 | '''
91 | def __init__(self, in_dims, out_dims, h_dim=None,
92 | activations=(F.relu, torch.sigmoid), vector_gate=False):
93 | super(GVP, self).__init__()
94 | self.si, self.vi = in_dims
95 | self.so, self.vo = out_dims
96 | self.vector_gate = vector_gate
97 | if self.vi:
98 | self.h_dim = h_dim or max(self.vi, self.vo)
99 | self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
100 | self.ws = nn.Linear(self.h_dim + self.si, self.so)
101 | if self.vo:
102 | self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
103 | if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
104 | else:
105 | self.ws = nn.Linear(self.si, self.so)
106 |
107 | self.scalar_act, self.vector_act = activations
108 | self.dummy_param = nn.Parameter(torch.empty(0))
109 |
110 | def forward(self, x):
111 | '''
112 | :param x: tuple (s, V) of `torch.Tensor`,
113 | or (if vectors_in is 0), a single `torch.Tensor`
114 | :return: tuple (s, V) of `torch.Tensor`,
115 | or (if vectors_out is 0), a single `torch.Tensor`
116 | '''
117 | if self.vi:
118 | s, v = x
119 | v = torch.transpose(v, -1, -2)
120 | vh = self.wh(v)
121 | vn = _norm_no_nan(vh, axis=-2)
122 | s = self.ws(torch.cat([s, vn], -1))
123 | if self.vo:
124 | v = self.wv(vh)
125 | v = torch.transpose(v, -1, -2)
126 | if self.vector_gate:
127 | if self.vector_act:
128 | gate = self.wsv(self.vector_act(s))
129 | else:
130 | gate = self.wsv(s)
131 | v = v * torch.sigmoid(gate).unsqueeze(-1)
132 | elif self.vector_act:
133 | v = v * self.vector_act(
134 | _norm_no_nan(v, axis=-1, keepdims=True))
135 | else:
136 | s = self.ws(x)
137 | if self.vo:
138 | v = torch.zeros(s.shape[0], self.vo, 3,
139 | device=self.dummy_param.device)
140 | if self.scalar_act:
141 | s = self.scalar_act(s)
142 |
143 | return (s, v) if self.vo else s
144 |
145 | class _VDropout(nn.Module):
146 | '''
147 | Vector channel dropout where the elements of each
148 | vector channel are dropped together.
149 | '''
150 | def __init__(self, drop_rate):
151 | super(_VDropout, self).__init__()
152 | self.drop_rate = drop_rate
153 | self.dummy_param = nn.Parameter(torch.empty(0))
154 |
155 | def forward(self, x):
156 | '''
157 | :param x: `torch.Tensor` corresponding to vector channels
158 | '''
159 | device = self.dummy_param.device
160 | if not self.training:
161 | return x
162 | mask = torch.bernoulli(
163 | (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
164 | ).unsqueeze(-1)
165 | x = mask * x / (1 - self.drop_rate)
166 | return x
167 |
168 | class Dropout(nn.Module):
169 | '''
170 | Combined dropout for tuples (s, V).
171 | Takes tuples (s, V) as input and as output.
172 | '''
173 | def __init__(self, drop_rate):
174 | super(Dropout, self).__init__()
175 | self.sdropout = nn.Dropout(drop_rate)
176 | self.vdropout = _VDropout(drop_rate)
177 |
178 | def forward(self, x):
179 | '''
180 | :param x: tuple (s, V) of `torch.Tensor`,
181 | or single `torch.Tensor`
182 | (will be assumed to be scalar channels)
183 | '''
184 | if type(x) is torch.Tensor:
185 | return self.sdropout(x)
186 | s, v = x
187 | return self.sdropout(s), self.vdropout(v)
188 |
189 | class LayerNorm(nn.Module):
190 | '''
191 | Combined LayerNorm for tuples (s, V).
192 | Takes tuples (s, V) as input and as output.
193 | '''
194 | def __init__(self, dims):
195 | super(LayerNorm, self).__init__()
196 | self.s, self.v = dims
197 | self.scalar_norm = nn.LayerNorm(self.s)
198 |
199 | def forward(self, x):
200 | '''
201 | :param x: tuple (s, V) of `torch.Tensor`,
202 | or single `torch.Tensor`
203 | (will be assumed to be scalar channels)
204 | '''
205 | if not self.v:
206 | return self.scalar_norm(x)
207 | s, v = x
208 | vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
209 | vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True))
210 | return self.scalar_norm(s), v / vn
211 |
212 | class GVPConv(MessagePassing):
213 | '''
214 | Graph convolution / message passing with Geometric Vector Perceptrons.
215 | Takes in a graph with node and edge embeddings,
216 | and returns new node embeddings.
217 |
218 | This does NOT do residual updates and pointwise feedforward layers
219 | ---see `GVPConvLayer`.
220 |
221 | :param in_dims: input node embedding dimensions (n_scalar, n_vector)
222 | :param out_dims: output node embedding dimensions (n_scalar, n_vector)
223 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
224 | :param n_layers: number of GVPs in the message function
225 | :param module_list: preconstructed message function, overrides n_layers
226 | :param aggr: should be "add" if some incoming edges are masked, as in
227 | a masked autoregressive decoder architecture, otherwise "mean"
228 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
229 | :param vector_gate: whether to use vector gating.
230 | (vector_act will be used as sigma^+ in vector gating if `True`)
231 | '''
232 | def __init__(self, in_dims, out_dims, edge_dims,
233 | n_layers=3, module_list=None, aggr="mean",
234 | activations=(F.relu, torch.sigmoid), vector_gate=False):
235 | super(GVPConv, self).__init__(aggr=aggr)
236 | self.si, self.vi = in_dims
237 | self.so, self.vo = out_dims
238 | self.se, self.ve = edge_dims
239 |
240 | GVP_ = functools.partial(GVP,
241 | activations=activations, vector_gate=vector_gate)
242 |
243 | module_list = module_list or []
244 | if not module_list:
245 | if n_layers == 1:
246 | module_list.append(
247 | GVP_((2*self.si + self.se, 2*self.vi + self.ve),
248 | (self.so, self.vo), activations=(None, None)))
249 | else:
250 | module_list.append(
251 | GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
252 | )
253 | for i in range(n_layers - 2):
254 | module_list.append(GVP_(out_dims, out_dims))
255 | module_list.append(GVP_(out_dims, out_dims,
256 | activations=(None, None)))
257 | self.message_func = nn.Sequential(*module_list)
258 |
259 | def forward(self, x, edge_index, edge_attr):
260 | '''
261 | :param x: tuple (s, V) of `torch.Tensor`
262 | :param edge_index: array of shape [2, n_edges]
263 | :param edge_attr: tuple (s, V) of `torch.Tensor`
264 | '''
265 | x_s, x_v = x
266 | message = self.propagate(edge_index,
267 | s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]),
268 | edge_attr=edge_attr)
269 | return _split(message, self.vo)
270 |
271 | def message(self, s_i, v_i, s_j, v_j, edge_attr):
272 | v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
273 | v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
274 | message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
275 | message = self.message_func(message)
276 | return _merge(*message)
277 |
278 |
279 | class GVPConvLayer(nn.Module):
280 | '''
281 | Full graph convolution / message passing layer with
282 | Geometric Vector Perceptrons. Residually updates node embeddings with
283 | aggregated incoming messages, applies a pointwise feedforward
284 | network to node embeddings, and returns updated node embeddings.
285 |
286 | To only compute the aggregated messages, see `GVPConv`.
287 |
288 | :param node_dims: node embedding dimensions (n_scalar, n_vector)
289 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
290 | :param n_message: number of GVPs to use in message function
291 | :param n_feedforward: number of GVPs to use in feedforward function
292 | :param drop_rate: drop probability in all dropout layers
293 | :param autoregressive: if `True`, this `GVPConvLayer` will be used
294 | with a different set of input node embeddings for messages
295 | where src >= dst
296 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
297 | :param vector_gate: whether to use vector gating.
298 | (vector_act will be used as sigma^+ in vector gating if `True`)
299 | '''
300 | def __init__(self, node_dims, edge_dims,
301 | n_message=3, n_feedforward=2, drop_rate=.1,
302 | autoregressive=False,
303 | activations=(F.relu, torch.sigmoid), vector_gate=False):
304 |
305 | super(GVPConvLayer, self).__init__()
306 | self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
307 | aggr="add" if autoregressive else "mean",
308 | activations=activations, vector_gate=vector_gate)
309 | GVP_ = functools.partial(GVP,
310 | activations=activations, vector_gate=vector_gate)
311 | self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
312 | self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
313 |
314 | ff_func = []
315 | if n_feedforward == 1:
316 | ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
317 | else:
318 | hid_dims = 4*node_dims[0], 2*node_dims[1]
319 | ff_func.append(GVP_(node_dims, hid_dims))
320 | for i in range(n_feedforward-2):
321 | ff_func.append(GVP_(hid_dims, hid_dims))
322 | ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
323 | self.ff_func = nn.Sequential(*ff_func)
324 |
325 | def forward(self, x, edge_index, edge_attr,
326 | autoregressive_x=None, node_mask=None):
327 | '''
328 | :param x: tuple (s, V) of `torch.Tensor`
329 | :param edge_index: array of shape [2, n_edges]
330 | :param edge_attr: tuple (s, V) of `torch.Tensor`
331 | :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
332 | If not `None`, will be used as src node embeddings
333 |             for forming messages where src >= dst. The current node
334 | embeddings `x` will still be the base of the update and the
335 | pointwise feedforward.
336 | :param node_mask: array of type `bool` to index into the first
337 | dim of node embeddings (s, V). If not `None`, only
338 | these nodes will be updated.
339 | '''
340 |
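        # In the autoregressive case, messages along edges with src < dst are
        # formed from the current embeddings x, while messages along edges with
        # src >= dst use autoregressive_x. Both passes aggregate with "add";
        # dividing by the in-degree of each dst node then recovers the mean.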
341 | if autoregressive_x is not None:
342 | src, dst = edge_index
343 | mask = src < dst
344 | edge_index_forward = edge_index[:, mask]
345 | edge_index_backward = edge_index[:, ~mask]
346 | edge_attr_forward = tuple_index(edge_attr, mask)
347 | edge_attr_backward = tuple_index(edge_attr, ~mask)
348 |
349 | dh = tuple_sum(
350 | self.conv(x, edge_index_forward, edge_attr_forward),
351 | self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
352 | )
353 |
354 | count = scatter_add(torch.ones_like(dst), dst,
355 | dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
356 |
357 | dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
358 |
359 | else:
360 | dh = self.conv(x, edge_index, edge_attr)
361 |
362 | if node_mask is not None:
363 | x_ = x
364 | x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
365 |
366 | x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
367 |
368 | dh = self.ff_func(x)
369 | x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))
370 |
371 | if node_mask is not None:
372 | x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
373 | x = x_
374 | return x
375 |
--------------------------------------------------------------------------------
/gvp/atom3d.py:
--------------------------------------------------------------------------------
1 | import torch, random, scipy, math
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import pandas as pd
5 | import numpy as np
6 | from atom3d.datasets import LMDBDataset
7 | import atom3d.datasets.ppi.neighbors as nb
8 | from torch.utils.data import IterableDataset
9 | from . import GVP, GVPConvLayer, LayerNorm
10 | import torch_cluster, torch_geometric, torch_scatter
11 | from .data import _normalize, _rbf
12 |
13 | _NUM_ATOM_TYPES = 9
14 | _element_mapping = lambda x: {
15 | 'H' : 0,
16 | 'C' : 1,
17 | 'N' : 2,
18 | 'O' : 3,
19 | 'F' : 4,
20 | 'S' : 5,
21 | 'Cl': 6, 'CL': 6,
22 | 'P' : 7
23 | }.get(x, 8)
24 | _amino_acids = lambda x: {
25 | 'ALA': 0,
26 | 'ARG': 1,
27 | 'ASN': 2,
28 | 'ASP': 3,
29 | 'CYS': 4,
30 | 'GLU': 5,
31 | 'GLN': 6,
32 | 'GLY': 7,
33 | 'HIS': 8,
34 | 'ILE': 9,
35 | 'LEU': 10,
36 | 'LYS': 11,
37 | 'MET': 12,
38 | 'PHE': 13,
39 | 'PRO': 14,
40 | 'SER': 15,
41 | 'THR': 16,
42 | 'TRP': 17,
43 | 'TYR': 18,
44 | 'VAL': 19
45 | }.get(x, 20)
46 | _DEFAULT_V_DIM = (100, 16)
47 | _DEFAULT_E_DIM = (32, 1)
48 |
49 | def _edge_features(coords, edge_index, D_max=4.5, num_rbf=16, device='cpu'):
50 |
51 | E_vectors = coords[edge_index[0]] - coords[edge_index[1]]
52 | rbf = _rbf(E_vectors.norm(dim=-1),
53 | D_max=D_max, D_count=num_rbf, device=device)
54 |
55 | edge_s = rbf
56 | edge_v = _normalize(E_vectors).unsqueeze(-2)
57 |
58 | edge_s, edge_v = map(torch.nan_to_num,
59 | (edge_s, edge_v))
60 |
61 | return edge_s, edge_v
62 |
63 | class BaseTransform:
64 | '''
65 | Implementation of an ATOM3D Transform which featurizes the atomic
66 |     coordinates in an ATOM3D dataframe into `torch_geometric.data.Data`
67 | graphs. This class should not be used directly; instead, use the
68 | task-specific transforms, which all extend BaseTransform. Node
69 | and edge features are as described in the EGNN manuscript.
70 |
71 | Returned graphs have the following attributes:
72 | -x atomic coordinates, shape [n_nodes, 3]
73 | -atoms numeric encoding of atomic identity, shape [n_nodes]
74 | -edge_index edge indices, shape [2, n_edges]
75 | -edge_s edge scalar features, shape [n_edges, 16]
76 |     -edge_v      edge vector features, shape [n_edges, 1, 3]
77 |
78 | Subclasses of BaseTransform will produce graphs with additional
79 |     attributes for the task-specific training labels, in addition
80 | to the above.
81 |
82 | All subclasses of BaseTransform directly inherit the BaseTransform
83 | constructor.
84 |
85 | :param edge_cutoff: distance cutoff to use when drawing edges
86 | :param num_rbf: number of radial bases to encode the distance on each edge
87 |     :param device: if "cuda", will do preprocessing on the GPU
88 | '''
89 | def __init__(self, edge_cutoff=4.5, num_rbf=16, device='cpu'):
90 | self.edge_cutoff = edge_cutoff
91 | self.num_rbf = num_rbf
92 | self.device = device
93 |
94 | def __call__(self, df):
95 | '''
96 | :param df: `pandas.DataFrame` of atomic coordinates
97 | in the ATOM3D format
98 |
99 | :return: `torch_geometric.data.Data` structure graph
100 | '''
101 | with torch.no_grad():
102 | coords = torch.as_tensor(df[['x', 'y', 'z']].to_numpy(),
103 | dtype=torch.float32, device=self.device)
104 | atoms = torch.as_tensor(list(map(_element_mapping, df.element)),
105 | dtype=torch.long, device=self.device)
106 |
107 | edge_index = torch_cluster.radius_graph(coords, r=self.edge_cutoff)
108 |
109 | edge_s, edge_v = _edge_features(coords, edge_index,
110 | D_max=self.edge_cutoff, num_rbf=self.num_rbf, device=self.device)
111 |
112 | return torch_geometric.data.Data(x=coords, atoms=atoms,
113 | edge_index=edge_index, edge_s=edge_s, edge_v=edge_v)
114 |
115 | class BaseModel(nn.Module):
116 | '''
117 | A base 5-layer GVP-GNN for all ATOM3D tasks, using GVPs with
118 | vector gating as described in the manuscript. Takes in atomic-level
119 | structure graphs of type `torch_geometric.data.Batch`
120 | and returns a single scalar.
121 |
122 | This class should not be used directly. Instead, please use the
123 | task-specific models which extend BaseModel. (Some of these classes
124 | may be aliases of BaseModel.)
125 |
126 | :param num_rbf: number of radial bases to use in the edge embedding
127 | '''
128 | def __init__(self, num_rbf=16):
129 |
130 | super().__init__()
131 | activations = (F.relu, None)
132 |
133 | self.embed = nn.Embedding(_NUM_ATOM_TYPES, _NUM_ATOM_TYPES)
134 |
135 | self.W_e = nn.Sequential(
136 | LayerNorm((num_rbf, 1)),
137 | GVP((num_rbf, 1), _DEFAULT_E_DIM,
138 | activations=(None, None), vector_gate=True)
139 | )
140 |
141 | self.W_v = nn.Sequential(
142 | LayerNorm((_NUM_ATOM_TYPES, 0)),
143 | GVP((_NUM_ATOM_TYPES, 0), _DEFAULT_V_DIM,
144 | activations=(None, None), vector_gate=True)
145 | )
146 |
147 | self.layers = nn.ModuleList(
148 | GVPConvLayer(_DEFAULT_V_DIM, _DEFAULT_E_DIM,
149 | activations=activations, vector_gate=True)
150 | for _ in range(5))
151 |
152 | ns, _ = _DEFAULT_V_DIM
153 | self.W_out = nn.Sequential(
154 | LayerNorm(_DEFAULT_V_DIM),
155 | GVP(_DEFAULT_V_DIM, (ns, 0),
156 | activations=activations, vector_gate=True)
157 | )
158 |
159 | self.dense = nn.Sequential(
160 | nn.Linear(ns, 2*ns), nn.ReLU(inplace=True),
161 | nn.Dropout(p=0.1),
162 | nn.Linear(2*ns, 1)
163 | )
164 |
165 | def forward(self, batch, scatter_mean=True, dense=True):
166 | '''
167 | Forward pass which can be adjusted based on task formulation.
168 |
169 | :param batch: `torch_geometric.data.Batch` with data attributes
170 | as returned from a BaseTransform
171 | :param scatter_mean: if `True`, returns mean of final node embeddings
172 |             (for each graph), else returns embeddings separately
173 | :param dense: if `True`, applies final dense layer to reduce embedding
174 | to a single scalar; else, returns the embedding
175 | '''
176 | h_V = self.embed(batch.atoms)
177 | h_E = (batch.edge_s, batch.edge_v)
178 | h_V = self.W_v(h_V)
179 | h_E = self.W_e(h_E)
180 |
181 | batch_id = batch.batch
182 |
183 | for layer in self.layers:
184 | h_V = layer(h_V, batch.edge_index, h_E)
185 |
186 | out = self.W_out(h_V)
187 | if scatter_mean: out = torch_scatter.scatter_mean(out, batch_id, dim=0)
188 | if dense: out = self.dense(out).squeeze(-1)
189 | return out
190 |
191 | ########################################################################
192 |
193 | class SMPTransform(BaseTransform):
194 | '''
195 | Transforms dict-style entries from the ATOM3D SMP dataset
196 | to featurized graphs. Returns a `torch_geometric.data.Data`
197 | graph with attribute `label` and all structural attributes
198 | as described in BaseTransform.
199 |
200 | Includes hydrogen atoms.
201 | '''
202 | def __call__(self, elem):
203 | data = super().__call__(elem['atoms'])
204 | with torch.no_grad():
205 | data.label = torch.as_tensor(elem['labels'],
206 | device=self.device, dtype=torch.float32)
207 | return data
208 |
209 | SMPModel = BaseModel
210 |
211 | ########################################################################
212 |
213 | class PPIDataset(IterableDataset):
214 | '''
215 |     A `torch.utils.data.IterableDataset` wrapper around an
216 |     ATOM3D PPI dataset. Extracts (many) individual amino acid pairs
217 | from each structure of two interacting proteins. The returned graphs
218 |     are separate and each represents a 30 angstrom radius around the
219 | selected residue's alpha carbon.
220 |
221 | On each iteration, returns a pair of `torch_geometric.data.Data`
222 | graphs with the (same) attribute `label` which is 1 if the two
223 | amino acids interact and 0 otherwise, `ca_idx` for the node index
224 | of the alpha carbon, and all structural attributes as
225 | described in BaseTransform.
226 |
227 | Modified from
228 | https://github.com/drorlab/atom3d/blob/master/examples/ppi/gnn/data.py
229 |
230 | Excludes hydrogen atoms.
231 |
232 | :param lmdb_dataset: path to ATOM3D dataset
233 | '''
234 | def __init__(self, lmdb_dataset):
235 | self.dataset = LMDBDataset(lmdb_dataset)
236 | self.transform = BaseTransform()
237 |
238 | def __iter__(self):
239 | worker_info = torch.utils.data.get_worker_info()
240 | if worker_info is None:
241 | gen = self._dataset_generator(list(range(len(self.dataset))), shuffle=True)
242 | else:
243 | per_worker = int(math.ceil(len(self.dataset) / float(worker_info.num_workers)))
244 | worker_id = worker_info.id
245 | iter_start = worker_id * per_worker
246 | iter_end = min(iter_start + per_worker, len(self.dataset))
247 | gen = self._dataset_generator(
248 | list(range(len(self.dataset)))[iter_start:iter_end],
249 | shuffle=True)
250 | return gen
251 |
252 | def _df_to_graph(self, struct_df, chain_res, label):
253 |
254 | struct_df = struct_df[struct_df.element != 'H'].reset_index(drop=True)
255 |
256 | chain, resnum = chain_res
257 | res_df = struct_df[(struct_df.chain == chain) & (struct_df.residue == resnum)]
258 | if 'CA' not in res_df.name.tolist():
259 | return None
260 | ca_pos = res_df[res_df['name']=='CA'][['x', 'y', 'z']].astype(np.float32).to_numpy()[0]
261 |
262 | kd_tree = scipy.spatial.KDTree(struct_df[['x','y','z']].to_numpy())
263 | graph_pt_idx = kd_tree.query_ball_point(ca_pos, r=30.0, p=2.0)
264 | graph_df = struct_df.iloc[graph_pt_idx].reset_index(drop=True)
265 |
266 | ca_idx = np.where((graph_df.chain == chain) & (graph_df.residue == resnum) & (graph_df.name == 'CA'))[0]
267 | if len(ca_idx) != 1:
268 | return None
269 |
270 | data = self.transform(graph_df)
271 | data.label = label
272 |
273 | data.ca_idx = int(ca_idx)
274 | data.n_nodes = data.num_nodes
275 |
276 | return data
277 |
278 | def _dataset_generator(self, indices, shuffle=True):
279 | if shuffle: random.shuffle(indices)
280 | with torch.no_grad():
281 | for idx in indices:
282 | data = self.dataset[idx]
283 |
284 | neighbors = data['atoms_neighbors']
285 | pairs = data['atoms_pairs']
286 |
287 | for i, (ensemble_name, target_df) in enumerate(pairs.groupby(['ensemble'])):
288 | sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df)
289 | positives = neighbors[neighbors.ensemble0 == ensemble_name]
290 | negatives = nb.get_negatives(positives, bound1, bound2)
291 | negatives['label'] = 0
292 | labels = self._create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1)
293 |
294 | for index, row in labels.iterrows():
295 |
296 | label = float(row['label'])
297 | chain_res1 = row[['chain0', 'residue0']].values
298 | chain_res2 = row[['chain1', 'residue1']].values
299 | graph1 = self._df_to_graph(bound1, chain_res1, label)
300 | graph2 = self._df_to_graph(bound2, chain_res2, label)
301 | if (graph1 is None) or (graph2 is None):
302 | continue
303 | yield graph1, graph2
304 |
305 | def _create_labels(self, positives, negatives, num_pos, neg_pos_ratio):
306 | frac = min(1, num_pos / positives.shape[0])
307 | positives = positives.sample(frac=frac)
308 | n = positives.shape[0] * neg_pos_ratio
309 | n = min(negatives.shape[0], n)
310 | negatives = negatives.sample(n, random_state=0, axis=0)
311 | labels = pd.concat([positives, negatives])[['chain0', 'residue0', 'chain1', 'residue1', 'label']]
312 | return labels
313 |
314 | class PPIModel(BaseModel):
315 | '''
316 | GVP-GNN for the PPI task.
317 |
318 | Extends BaseModel to accept a tuple (batch1, batch2)
319 | of `torch_geometric.data.Batch` graphs, where each graph
320 | index in a batch is paired with the same graph index in the
321 | other batch.
322 |
323 | As noted in the manuscript, PPIModel uses the final alpha
324 | carbon embeddings instead of the graph mean embedding.
325 |
326 | Returns a single scalar for each graph pair which can be used as
327 | a logit in binary classification.
328 | '''
329 | def __init__(self, **kwargs):
330 |
331 | super().__init__(**kwargs)
332 | ns, _ = _DEFAULT_V_DIM
333 | self.dense = nn.Sequential(
334 | nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True),
335 | nn.Dropout(p=0.1),
336 | nn.Linear(4*ns, 1)
337 | )
338 |
339 | def forward(self, batch):
340 | graph1, graph2 = batch
341 | out1, out2 = map(self._gnn_forward, (graph1, graph2))
342 | out = torch.cat([out1, out2], dim=-1)
343 | out = self.dense(out)
344 | return torch.sigmoid(out).squeeze(-1)
345 |
346 | def _gnn_forward(self, graph):
347 | out = super().forward(graph, scatter_mean=False, dense=False)
348 | return out[graph.ca_idx+graph.ptr[:-1]]
349 |
350 |
351 | ########################################################################
352 |
353 | class LBATransform(BaseTransform):
354 | '''
355 | Transforms dict-style entries from the ATOM3D LBA dataset
356 | to featurized graphs. Returns a `torch_geometric.data.Data`
357 | graph with attribute `label` for the neglog-affinity
358 | and all structural attributes as described in BaseTransform.
359 |
360 | The transform combines the atomic coordinates of the pocket
361 | and ligand atoms and treats them as a single structure / graph.
362 |
363 | Includes hydrogen atoms.
364 | '''
365 | def __call__(self, elem):
366 | pocket, ligand = elem['atoms_pocket'], elem['atoms_ligand']
367 | df = pd.concat([pocket, ligand], ignore_index=True)
368 |
369 | data = super().__call__(df)
370 | with torch.no_grad():
371 | data.label = elem['scores']['neglog_aff']
372 | lig_flag = torch.zeros(df.shape[0], device=self.device, dtype=torch.bool)
373 | lig_flag[-len(ligand):] = 1
374 | data.lig_flag = lig_flag
375 | return data
376 |
377 | LBAModel = BaseModel
378 |
379 | ########################################################################
380 |
381 | class LEPTransform(BaseTransform):
382 | '''
383 | Transforms dict-style entries from the ATOM3D LEP dataset
384 | to featurized graphs. Returns a tuple (active, inactive) of
385 | `torch_geometric.data.Data` graphs with the (same) attribute
386 | `label` which is equal to 1. if the ligand activates the protein
387 | and 0. otherwise, and all structural attributes as described
388 | in BaseTransform.
389 |
390 | The transform combines the atomic coordinates of the pocket
391 | and ligand atoms and treats them as a single structure / graph.
392 |
393 | Excludes hydrogen atoms.
394 | '''
395 | def __call__(self, elem):
396 | active, inactive = elem['atoms_active'], elem['atoms_inactive']
397 | with torch.no_grad():
398 | active, inactive = map(self._to_graph, (active, inactive))
399 | active.label = inactive.label = 1. if elem['label'] == 'A' else 0.
400 | return active, inactive
401 |
402 | def _to_graph(self, df):
403 | df = df[df.element != 'H'].reset_index(drop=True)
404 | return super().__call__(df)
405 |
406 | class LEPModel(BaseModel):
407 | '''
408 | GVP-GNN for the LEP task.
409 |
410 | Extends BaseModel to accept a tuple (batch1, batch2)
411 | of `torch_geometric.data.Batch` graphs, where each graph
412 | index in a batch is paired with the same graph index in the
413 | other batch.
414 |
415 | Returns a single scalar for each graph pair which can be used as
416 | a logit in binary classification.
417 | '''
418 | def __init__(self, **kwargs):
419 | super().__init__(**kwargs)
420 | ns, _ = _DEFAULT_V_DIM
421 | self.dense = nn.Sequential(
422 | nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True),
423 | nn.Dropout(p=0.1),
424 | nn.Linear(4*ns, 1)
425 | )
426 |
427 | def forward(self, batch):
428 | out1, out2 = map(self._gnn_forward, batch)
429 | out = torch.cat([out1, out2], dim=-1)
430 | out = self.dense(out)
431 | return torch.sigmoid(out).squeeze(-1)
432 |
433 | def _gnn_forward(self, graph):
434 | return super().forward(graph, dense=False)
435 |
436 | ########################################################################
437 |
438 | class MSPTransform(BaseTransform):
439 | '''
440 | Transforms dict-style entries from the ATOM3D MSP dataset
441 | to featurized graphs. Returns a tuple (original, mutated) of
442 | `torch_geometric.data.Data` graphs with the (same) attribute
443 | `label` which is equal to 1. if the mutation stabilizes the
444 | complex and 0. otherwise, and all structural attributes as
445 | described in BaseTransform.
446 |
447 |     The transform combines the atomic coordinates of the two proteins
448 | in each complex and treats them as a single structure / graph.
449 |
450 | Adapted from
451 | https://github.com/drorlab/atom3d/blob/master/examples/msp/gnn/data.py
452 |
453 | Excludes hydrogen atoms.
454 | '''
455 | def __call__(self, elem):
456 | mutation = elem['id'].split('_')[-1]
457 | orig_df = elem['original_atoms'].reset_index(drop=True)
458 | mut_df = elem['mutated_atoms'].reset_index(drop=True)
459 | with torch.no_grad():
460 | original, mutated = self._transform(orig_df, mutation), \
461 | self._transform(mut_df, mutation)
462 | original.label = mutated.label = 1. if elem['label'] == '1' else 0.
463 | return original, mutated
464 |
465 | def _transform(self, df, mutation):
466 |
467 | df = df[df.element != 'H'].reset_index(drop=True)
468 | data = super().__call__(df)
469 | data.node_mask = self._extract_node_mask(df, mutation)
470 | return data
471 |
472 | def _extract_node_mask(self, df, mutation):
473 | chain, res = mutation[1], int(mutation[2:-1])
474 | idx = df.index[(df.chain.values == chain) & (df.residue.values == res)].values
475 | mask = torch.zeros(len(df), dtype=torch.long, device=self.device)
476 | mask[idx] = 1
477 | return mask
478 |
479 | class MSPModel(BaseModel):
480 | '''
481 | GVP-GNN for the MSP task.
482 |
483 | Extends BaseModel to accept a tuple (batch1, batch2)
484 | of `torch_geometric.data.Batch` graphs, where each graph
485 | index in a batch is paired with the same graph index in the
486 | other batch.
487 |
488 | As noted in the manuscript, MSPModel uses the final embeddings
489 | averaged over the residue of interest instead of the entire graph.
490 |
491 | Returns a single scalar for each graph pair which can be used as
492 | a logit in binary classification.
493 | '''
494 | def __init__(self, **kwargs):
495 | super().__init__(**kwargs)
496 | ns, _ = _DEFAULT_V_DIM
497 | self.dense = nn.Sequential(
498 | nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True),
499 | nn.Dropout(p=0.1),
500 | nn.Linear(4*ns, 1)
501 | )
502 |
503 | def forward(self, batch):
504 | out1, out2 = map(self._gnn_forward, batch)
505 | out = torch.cat([out1, out2], dim=-1)
506 | out = self.dense(out)
507 | return torch.sigmoid(out).squeeze(-1)
508 |
509 | def _gnn_forward(self, graph):
510 | out = super().forward(graph, scatter_mean=False, dense=False)
511 | out = out * graph.node_mask.unsqueeze(-1)
512 | out = torch_scatter.scatter_add(out, graph.batch, dim=0)
513 | count = torch_scatter.scatter_add(graph.node_mask, graph.batch)
514 | return out / count.unsqueeze(-1)
515 |
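# The masked mean in MSPModel._gnn_forward can be illustrated in isolation
# (a minimal sketch, independent of the model):
#
# >>> out = torch.tensor([[1.], [2.], [3.], [4.]])  # per-node embeddings
# >>> node_mask = torch.tensor([1, 0, 1, 1])        # residue of interest
# >>> batch = torch.tensor([0, 0, 1, 1])            # graph assignment
# >>> summed = torch_scatter.scatter_add(out * node_mask.unsqueeze(-1), batch, dim=0)
# >>> count = torch_scatter.scatter_add(node_mask, batch)
# >>> (summed / count.unsqueeze(-1)).squeeze(-1)
# tensor([1.0000, 3.5000])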
516 | ########################################################################
517 |
518 | class PSRTransform(BaseTransform):
519 | '''
520 | Transforms dict-style entries from the ATOM3D PSR dataset
521 | to featurized graphs. Returns a `torch_geometric.data.Data`
522 | graph with attribute `label` for the GDT_TS, `id` for the
523 | name of the target, and all structural attributes as
524 | described in BaseTransform.
525 |
526 |     Excludes hydrogen atoms.
527 | '''
528 | def __call__(self, elem):
529 | df = elem['atoms']
530 | df = df[df.element != 'H'].reset_index(drop=True)
531 | data = super().__call__(df)
532 | data.label = elem['scores']['gdt_ts']
533 | data.id = eval(elem['id'])[0]
534 | return data
535 |
536 | PSRModel = BaseModel
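
# Usage sketch (a minimal example; the path follows run_atom3d.py and assumes
# a local copy of the ATOM3D PSR LMDB data):
#
# >>> from atom3d.datasets import LMDBDataset
# >>> trainset = LMDBDataset('atom3d-data/PSR/splits/split-by-year/data/train',
# ...                        transform=PSRTransform())
# >>> graph = trainset[0]  # torch_geometric.data.Data with .label and .id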
537 |
538 | ########################################################################
539 |
540 | class RSRTransform(BaseTransform):
541 | '''
542 | Transforms dict-style entries from the ATOM3D RSR dataset
543 | to featurized graphs. Returns a `torch_geometric.data.Data`
544 | graph with attribute `label` for the RMSD, `id` for the
545 | name of the target, and all structural attributes as
546 | described in BaseTransform.
547 |
548 |     Excludes hydrogen atoms.
549 | '''
550 | def __call__(self, elem):
551 | df = elem['atoms']
552 | df = df[df.element != 'H'].reset_index(drop=True)
553 | data = super().__call__(df)
554 | data.label = elem['scores']['rms']
555 | data.id = eval(elem['id'])[0]
556 | return data
557 |
558 | RSRModel = BaseModel
559 |
560 | ########################################################################
561 |
562 | class RESDataset(IterableDataset):
563 | '''
564 |     A `torch.utils.data.IterableDataset` wrapper around an
565 | ATOM3D RES dataset.
566 |
567 | On each iteration, returns a `torch_geometric.data.Data`
568 | graph with the attribute `label` encoding the masked residue
569 | identity, `ca_idx` for the node index of the alpha carbon,
570 | and all structural attributes as described in BaseTransform.
571 |
572 | Excludes hydrogen atoms.
573 |
574 | :param lmdb_dataset: path to ATOM3D dataset
575 | :param split_path: path to the ATOM3D split file
576 | '''
577 | def __init__(self, lmdb_dataset, split_path):
578 | self.dataset = LMDBDataset(lmdb_dataset)
579 | self.idx = list(map(int, open(split_path).read().split()))
580 | self.transform = BaseTransform()
581 |
582 | def __iter__(self):
583 | worker_info = torch.utils.data.get_worker_info()
584 | if worker_info is None:
585 | gen = self._dataset_generator(list(range(len(self.idx))),
586 | shuffle=True)
587 | else:
588 | per_worker = int(math.ceil(len(self.idx) / float(worker_info.num_workers)))
589 | worker_id = worker_info.id
590 | iter_start = worker_id * per_worker
591 | iter_end = min(iter_start + per_worker, len(self.idx))
592 | gen = self._dataset_generator(list(range(len(self.idx)))[iter_start:iter_end],
593 | shuffle=True)
594 | return gen
595 |
596 | def _dataset_generator(self, indices, shuffle=True):
597 | if shuffle: random.shuffle(indices)
598 | with torch.no_grad():
599 | for idx in indices:
600 | data = self.dataset[self.idx[idx]]
601 | atoms = data['atoms']
602 | for sub in data['labels'].itertuples():
603 | _, num, aa = sub.subunit.split('_')
604 | num, aa = int(num), _amino_acids(aa)
605 | if aa == 20: continue
606 | my_atoms = atoms.iloc[data['subunit_indices'][sub.Index]].reset_index(drop=True)
607 | ca_idx = np.where((my_atoms.residue == num) & (my_atoms.name == 'CA'))[0]
608 | if len(ca_idx) != 1: continue
609 |
610 | with torch.no_grad():
611 | graph = self.transform(my_atoms)
612 | graph.label = aa
613 | graph.ca_idx = int(ca_idx)
614 | yield graph
615 |
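# Usage sketch: because __iter__ shards self.idx across workers, RESDataset
# can be wrapped in a multi-worker loader without duplicating examples
# (paths follow run_atom3d.py and are illustrative):
#
# >>> dataset = RESDataset('atom3d-data/RES/raw/RES/data/',
# ...     split_path='atom3d-data/RES/splits/split-by-cath-topology/indices/train_indices.txt')
# >>> loader = torch_geometric.data.DataLoader(dataset, batch_size=8, num_workers=4)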
616 | class RESModel(BaseModel):
617 | '''
618 | GVP-GNN for the RES task.
619 |
620 | Extends BaseModel to output a 20-dim vector instead of a single
621 | scalar for each graph, which can be used as logits in 20-way
622 | classification.
623 |
624 | As noted in the manuscript, RESModel uses the final alpha
625 | carbon embeddings instead of the graph mean embedding.
626 | '''
627 | def __init__(self, **kwargs):
628 | super().__init__(**kwargs)
629 | ns, _ = _DEFAULT_V_DIM
630 | self.dense = nn.Sequential(
631 | nn.Linear(ns, 2*ns), nn.ReLU(inplace=True),
632 | nn.Dropout(p=0.1),
633 | nn.Linear(2*ns, 20)
634 | )
635 | def forward(self, batch):
636 | out = super().forward(batch, scatter_mean=False)
637 |         return out[batch.ca_idx + batch.ptr[:-1]]
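
# Indexing sketch for RESModel.forward: for a batch of two graphs with 5 and
# 7 nodes, batch.ptr == tensor([0, 5, 12]). If the alpha carbons sit at local
# indices batch.ca_idx == tensor([2, 4]), then batch.ca_idx + batch.ptr[:-1]
# == tensor([2, 9]), selecting each graph's alpha-carbon row from the
# flattened per-node output.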
--------------------------------------------------------------------------------
/gvp/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import tqdm, random
4 | import torch, math
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 | import torch_geometric
8 | import torch_cluster
9 |
10 | def _normalize(tensor, dim=-1):
11 | '''
12 | Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
13 | '''
14 | return torch.nan_to_num(
15 | torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
16 |
17 |
18 | def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
19 | '''
20 | From https://github.com/jingraham/neurips19-graph-protein-design
21 |
22 | Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
23 | That is, if `D` has shape [...dims], then the returned tensor will have
24 | shape [...dims, D_count].
25 | '''
26 | D_mu = torch.linspace(D_min, D_max, D_count, device=device)
27 | D_mu = D_mu.view([1, -1])
28 | D_sigma = (D_max - D_min) / D_count
29 | D_expand = torch.unsqueeze(D, -1)
30 |
31 | RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
32 | return RBF
33 |
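# Shape sketch (with the defaults above): distances of shape [n_edges] embed
# to [n_edges, 16]:
#
# >>> D = torch.rand(100) * 20
# >>> _rbf(D).shape
# torch.Size([100, 16])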
34 |
35 | class CATHDataset:
36 | '''
37 | Loader and container class for the CATH 4.2 dataset downloaded
38 | from http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/.
39 |
40 | Has attributes `self.train`, `self.val`, `self.test`, each of which are
41 | JSON/dictionary-type datasets as described in README.md.
42 |
43 | :param path: path to chain_set.jsonl
44 | :param splits_path: path to chain_set_splits.json or equivalent.
45 | '''
46 | def __init__(self, path, splits_path):
47 | with open(splits_path) as f:
48 | dataset_splits = json.load(f)
49 | train_list, val_list, test_list = dataset_splits['train'], \
50 | dataset_splits['validation'], dataset_splits['test']
51 |
52 | self.train, self.val, self.test = [], [], []
53 |
54 | with open(path) as f:
55 | lines = f.readlines()
56 |
57 | for line in tqdm.tqdm(lines):
58 | entry = json.loads(line)
59 | name = entry['name']
60 | coords = entry['coords']
61 |
62 | entry['coords'] = list(zip(
63 | coords['N'], coords['CA'], coords['C'], coords['O']
64 | ))
65 |
66 | if name in train_list:
67 | self.train.append(entry)
68 | elif name in val_list:
69 | self.val.append(entry)
70 | elif name in test_list:
71 | self.test.append(entry)
72 |
73 | class BatchSampler(data.Sampler):
74 | '''
75 | From https://github.com/jingraham/neurips19-graph-protein-design.
76 |
77 | A `torch.utils.data.Sampler` which samples batches according to a
78 | maximum number of graph nodes.
79 |
80 | :param node_counts: array of node counts in the dataset to sample from
81 | :param max_nodes: the maximum number of nodes in any batch,
82 | including batches of a single element
83 |     :param shuffle: if `True`, yield the batches in shuffled order
84 | '''
85 | def __init__(self, node_counts, max_nodes=3000, shuffle=True):
86 |
87 | self.node_counts = node_counts
88 | self.idx = [i for i in range(len(node_counts))
89 | if node_counts[i] <= max_nodes]
90 | self.shuffle = shuffle
91 | self.max_nodes = max_nodes
92 | self._form_batches()
93 |
94 | def _form_batches(self):
95 | self.batches = []
96 | if self.shuffle: random.shuffle(self.idx)
97 | idx = self.idx
98 | while idx:
99 | batch = []
100 | n_nodes = 0
101 | while idx and n_nodes + self.node_counts[idx[0]] <= self.max_nodes:
102 | next_idx, idx = idx[0], idx[1:]
103 | n_nodes += self.node_counts[next_idx]
104 | batch.append(next_idx)
105 | self.batches.append(batch)
106 |
107 | def __len__(self):
108 | if not self.batches: self._form_batches()
109 | return len(self.batches)
110 |
111 | def __iter__(self):
112 | if not self.batches: self._form_batches()
113 | for batch in self.batches: yield batch
114 |
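# Usage sketch (mirrors run_cpd.py): BatchSampler yields lists of dataset
# indices, so it plugs into a DataLoader via `batch_sampler`:
#
# >>> sampler = BatchSampler(dataset.node_counts, max_nodes=3000)
# >>> loader = torch_geometric.data.DataLoader(dataset, batch_sampler=sampler)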
115 | class ProteinGraphDataset(data.Dataset):
116 | '''
117 |     A map-style `torch.utils.data.Dataset` which transforms JSON/dictionary-style
118 | protein structures into featurized protein graphs as described in the
119 | manuscript.
120 |
121 | Returned graphs are of type `torch_geometric.data.Data` with attributes
122 | -x alpha carbon coordinates, shape [n_nodes, 3]
123 | -seq sequence converted to int tensor according to `self.letter_to_num`, shape [n_nodes]
124 | -name name of the protein structure, string
125 | -node_s node scalar features, shape [n_nodes, 6]
126 | -node_v node vector features, shape [n_nodes, 3, 3]
127 | -edge_s edge scalar features, shape [n_edges, 32]
128 |     -edge_v         edge vector features, shape [n_edges, 1, 3]
129 | -edge_index edge indices, shape [2, n_edges]
130 | -mask node mask, `False` for nodes with missing data that are excluded from message passing
131 |
132 | Portions from https://github.com/jingraham/neurips19-graph-protein-design.
133 |
134 | :param data_list: JSON/dictionary-style protein dataset as described in README.md.
135 | :param num_positional_embeddings: number of positional embeddings
136 | :param top_k: number of edges to draw per node (as destination node)
137 | :param device: if "cuda", will do preprocessing on the GPU
138 | '''
139 | def __init__(self, data_list,
140 | num_positional_embeddings=16,
141 | top_k=30, num_rbf=16, device="cpu"):
142 |
143 | super(ProteinGraphDataset, self).__init__()
144 |
145 | self.data_list = data_list
146 | self.top_k = top_k
147 | self.num_rbf = num_rbf
148 | self.num_positional_embeddings = num_positional_embeddings
149 | self.device = device
150 | self.node_counts = [len(e['seq']) for e in data_list]
151 |
152 | self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9,
153 | 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8,
154 | 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19,
155 | 'N': 2, 'Y': 18, 'M': 12}
156 | self.num_to_letter = {v:k for k, v in self.letter_to_num.items()}
157 |
158 | def __len__(self): return len(self.data_list)
159 |
160 | def __getitem__(self, i): return self._featurize_as_graph(self.data_list[i])
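
    # Example sketch (assuming `cath` is a loaded CATHDataset):
    #
    # >>> dataset = ProteinGraphDataset(cath.train)
    # >>> graph = dataset[0]
    # >>> graph.node_s.shape, graph.node_v.shape   # for a 121-residue chain
    # (torch.Size([121, 6]), torch.Size([121, 3, 3]))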
161 |
162 | def _featurize_as_graph(self, protein):
163 | name = protein['name']
164 | with torch.no_grad():
165 | coords = torch.as_tensor(protein['coords'],
166 | device=self.device, dtype=torch.float32)
167 | seq = torch.as_tensor([self.letter_to_num[a] for a in protein['seq']],
168 | device=self.device, dtype=torch.long)
169 |
170 | mask = torch.isfinite(coords.sum(dim=(1,2)))
171 | coords[~mask] = np.inf
172 |
173 | X_ca = coords[:, 1]
174 | edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k)
175 |
176 | pos_embeddings = self._positional_embeddings(edge_index)
177 | E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
178 | rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf, device=self.device)
179 |
180 | dihedrals = self._dihedrals(coords)
181 | orientations = self._orientations(X_ca)
182 | sidechains = self._sidechains(coords)
183 |
184 | node_s = dihedrals
185 | node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
186 | edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
187 | edge_v = _normalize(E_vectors).unsqueeze(-2)
188 |
189 | node_s, node_v, edge_s, edge_v = map(torch.nan_to_num,
190 | (node_s, node_v, edge_s, edge_v))
191 |
192 | data = torch_geometric.data.Data(x=X_ca, seq=seq, name=name,
193 | node_s=node_s, node_v=node_v,
194 | edge_s=edge_s, edge_v=edge_v,
195 | edge_index=edge_index, mask=mask)
196 | return data
197 |
198 | def _dihedrals(self, X, eps=1e-7):
199 | # From https://github.com/jingraham/neurips19-graph-protein-design
200 |
201 | X = torch.reshape(X[:, :3], [3*X.shape[0], 3])
202 | dX = X[1:] - X[:-1]
203 | U = _normalize(dX, dim=-1)
204 | u_2 = U[:-2]
205 | u_1 = U[1:-1]
206 | u_0 = U[2:]
207 |
208 | # Backbone normals
209 | n_2 = _normalize(torch.cross(u_2, u_1), dim=-1)
210 | n_1 = _normalize(torch.cross(u_1, u_0), dim=-1)
211 |
212 | # Angle between normals
213 | cosD = torch.sum(n_2 * n_1, -1)
214 | cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
215 | D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
216 |
217 | # This scheme will remove phi[0], psi[-1], omega[-1]
218 | D = F.pad(D, [1, 2])
219 | D = torch.reshape(D, [-1, 3])
220 | # Lift angle representations to the circle
221 | D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
222 | return D_features
223 |
224 |
225 | def _positional_embeddings(self, edge_index,
226 | num_embeddings=None,
227 | period_range=[2, 1000]):
228 | # From https://github.com/jingraham/neurips19-graph-protein-design
229 | num_embeddings = num_embeddings or self.num_positional_embeddings
230 | d = edge_index[0] - edge_index[1]
231 |
232 | frequency = torch.exp(
233 | torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=self.device)
234 | * -(np.log(10000.0) / num_embeddings)
235 | )
236 | angles = d.unsqueeze(-1) * frequency
237 | E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
238 | return E
239 |
240 | def _orientations(self, X):
241 | forward = _normalize(X[1:] - X[:-1])
242 | backward = _normalize(X[:-1] - X[1:])
243 | forward = F.pad(forward, [0, 0, 0, 1])
244 | backward = F.pad(backward, [0, 0, 1, 0])
245 | return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)
246 |
247 | def _sidechains(self, X):
248 | n, origin, c = X[:, 0], X[:, 1], X[:, 2]
249 | c, n = _normalize(c - origin), _normalize(n - origin)
250 | bisector = _normalize(c + n)
251 | perp = _normalize(torch.cross(c, n))
252 | vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
253 | return vec
--------------------------------------------------------------------------------
/gvp/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from . import GVP, GVPConvLayer, LayerNorm, tuple_index
5 | from torch.distributions import Categorical
6 | from torch_scatter import scatter_mean
7 |
8 | class CPDModel(torch.nn.Module):
9 | '''
10 | GVP-GNN for structure-conditioned autoregressive
11 | protein design as described in manuscript.
12 |
13 | Takes in protein structure graphs of type `torch_geometric.data.Data`
14 | or `torch_geometric.data.Batch` and returns a categorical distribution
15 | over 20 amino acids at each position in a `torch.Tensor` of
16 | shape [n_nodes, 20].
17 |
18 | Should be used with `gvp.data.ProteinGraphDataset`, or with generators
19 | of `torch_geometric.data.Batch` objects with the same attributes.
20 |
21 | The standard forward pass requires sequence information as input
22 | and should be used for training or evaluating likelihood.
23 | For sampling or design, use `self.sample`.
24 |
25 | :param node_in_dim: node dimensions in input graph, should be
26 | (6, 3) if using original features
27 | :param node_h_dim: node dimensions to use in GVP-GNN layers
28 |     :param edge_in_dim: edge dimensions in input graph, should be
29 | (32, 1) if using original features
30 | :param edge_h_dim: edge dimensions to embed to before use
31 | in GVP-GNN layers
32 | :param num_layers: number of GVP-GNN layers in each of the encoder
33 | and decoder modules
34 | :param drop_rate: rate to use in all dropout layers
35 | '''
36 | def __init__(self, node_in_dim, node_h_dim,
37 | edge_in_dim, edge_h_dim,
38 | num_layers=3, drop_rate=0.1):
39 |
40 | super(CPDModel, self).__init__()
41 |
42 | self.W_v = nn.Sequential(
43 | GVP(node_in_dim, node_h_dim, activations=(None, None)),
44 | LayerNorm(node_h_dim)
45 | )
46 | self.W_e = nn.Sequential(
47 | GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
48 | LayerNorm(edge_h_dim)
49 | )
50 |
51 | self.encoder_layers = nn.ModuleList(
52 | GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
53 | for _ in range(num_layers))
54 |
55 | self.W_s = nn.Embedding(20, 20)
56 | edge_h_dim = (edge_h_dim[0] + 20, edge_h_dim[1])
57 |
58 | self.decoder_layers = nn.ModuleList(
59 | GVPConvLayer(node_h_dim, edge_h_dim,
60 | drop_rate=drop_rate, autoregressive=True)
61 | for _ in range(num_layers))
62 |
63 | self.W_out = GVP(node_h_dim, (20, 0), activations=(None, None))
64 |
65 | def forward(self, h_V, edge_index, h_E, seq):
66 | '''
67 | Forward pass to be used at train-time, or evaluating likelihood.
68 |
69 | :param h_V: tuple (s, V) of node embeddings
70 | :param edge_index: `torch.Tensor` of shape [2, num_edges]
71 | :param h_E: tuple (s, V) of edge embeddings
72 | :param seq: int `torch.Tensor` of shape [num_nodes]
73 | '''
74 | h_V = self.W_v(h_V)
75 | h_E = self.W_e(h_E)
76 |
77 | for layer in self.encoder_layers:
78 | h_V = layer(h_V, edge_index, h_E)
79 |
80 | encoder_embeddings = h_V
81 |
82 | h_S = self.W_s(seq)
83 | h_S = h_S[edge_index[0]]
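        # Zero out sequence embeddings on edges whose source position is not
        # strictly earlier than the destination, so the decoder conditions
        # each residue only on already-generated (preceding) positions.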
84 | h_S[edge_index[0] >= edge_index[1]] = 0
85 | h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])
86 |
87 | for layer in self.decoder_layers:
88 | h_V = layer(h_V, edge_index, h_E, autoregressive_x = encoder_embeddings)
89 |
90 | logits = self.W_out(h_V)
91 |
92 | return logits
93 |
94 | def sample(self, h_V, edge_index, h_E, n_samples, temperature=0.1):
95 | '''
96 | Samples sequences autoregressively from the distribution
97 | learned by the model.
98 |
99 | :param h_V: tuple (s, V) of node embeddings
100 | :param edge_index: `torch.Tensor` of shape [2, num_edges]
101 | :param h_E: tuple (s, V) of edge embeddings
102 | :param n_samples: number of samples
103 | :param temperature: temperature to use in softmax
104 | over the categorical distribution
105 |
106 | :return: int `torch.Tensor` of shape [n_samples, n_nodes] based on the
107 | residue-to-int mapping of the original training data
108 | '''
109 |
110 | with torch.no_grad():
111 |
112 | device = edge_index.device
113 | L = h_V[0].shape[0]
114 |
115 | h_V = self.W_v(h_V)
116 | h_E = self.W_e(h_E)
117 |
118 | for layer in self.encoder_layers:
119 | h_V = layer(h_V, edge_index, h_E)
120 |
121 | h_V = (h_V[0].repeat(n_samples, 1),
122 | h_V[1].repeat(n_samples, 1, 1))
123 |
124 | h_E = (h_E[0].repeat(n_samples, 1),
125 | h_E[1].repeat(n_samples, 1, 1))
126 |
127 | edge_index = edge_index.expand(n_samples, -1, -1)
128 | offset = L * torch.arange(n_samples, device=device).view(-1, 1, 1)
129 | edge_index = torch.cat(tuple(edge_index + offset), dim=-1)
130 |
131 | seq = torch.zeros(n_samples * L, device=device, dtype=torch.int)
132 | h_S = torch.zeros(n_samples * L, 20, device=device)
133 |
134 | h_V_cache = [(h_V[0].clone(), h_V[1].clone()) for _ in self.decoder_layers]
135 |
136 | for i in range(L):
137 |
138 | h_S_ = h_S[edge_index[0]]
139 | h_S_[edge_index[0] >= edge_index[1]] = 0
140 | h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1])
141 |
142 | edge_mask = edge_index[1] % L == i
143 | edge_index_ = edge_index[:, edge_mask]
144 | h_E_ = tuple_index(h_E_, edge_mask)
145 | node_mask = torch.zeros(n_samples * L, device=device, dtype=torch.bool)
146 | node_mask[i::L] = True
147 |
148 | for j, layer in enumerate(self.decoder_layers):
149 | out = layer(h_V_cache[j], edge_index_, h_E_,
150 | autoregressive_x=h_V_cache[0], node_mask=node_mask)
151 |
152 | out = tuple_index(out, node_mask)
153 |
154 | if j < len(self.decoder_layers)-1:
155 | h_V_cache[j+1][0][i::L] = out[0]
156 | h_V_cache[j+1][1][i::L] = out[1]
157 |
158 | logits = self.W_out(out)
159 | seq[i::L] = Categorical(logits=logits / temperature).sample()
160 | h_S[i::L] = self.W_s(seq[i::L])
161 |
162 | return seq.view(n_samples, L)
163 |
164 | class MQAModel(nn.Module):
165 | '''
166 | GVP-GNN for Model Quality Assessment as described in manuscript.
167 |
168 | Takes in protein structure graphs of type `torch_geometric.data.Data`
169 | or `torch_geometric.data.Batch` and returns a scalar score for
170 |     each graph in the batch in a `torch.Tensor` of shape [n_graphs].
171 |
172 | Should be used with `gvp.data.ProteinGraphDataset`, or with generators
173 | of `torch_geometric.data.Batch` objects with the same attributes.
174 |
175 | :param node_in_dim: node dimensions in input graph, should be
176 | (6, 3) if using original features
177 | :param node_h_dim: node dimensions to use in GVP-GNN layers
178 |     :param edge_in_dim: edge dimensions in input graph, should be
179 | (32, 1) if using original features
180 | :param edge_h_dim: edge dimensions to embed to before use
181 | in GVP-GNN layers
182 |     :param seq_in: if `True`, sequences will also be passed in with
183 | the forward pass; otherwise, sequence information
184 | is assumed to be part of input node embeddings
185 | :param num_layers: number of GVP-GNN layers
186 | :param drop_rate: rate to use in all dropout layers
187 | '''
188 | def __init__(self, node_in_dim, node_h_dim,
189 | edge_in_dim, edge_h_dim,
190 | seq_in=False, num_layers=3, drop_rate=0.1):
191 |
192 | super(MQAModel, self).__init__()
193 |
194 | if seq_in:
195 | self.W_s = nn.Embedding(20, 20)
196 | node_in_dim = (node_in_dim[0] + 20, node_in_dim[1])
197 |
198 | self.W_v = nn.Sequential(
199 | LayerNorm(node_in_dim),
200 | GVP(node_in_dim, node_h_dim, activations=(None, None))
201 | )
202 | self.W_e = nn.Sequential(
203 | LayerNorm(edge_in_dim),
204 | GVP(edge_in_dim, edge_h_dim, activations=(None, None))
205 | )
206 |
207 | self.layers = nn.ModuleList(
208 | GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
209 | for _ in range(num_layers))
210 |
211 | ns, _ = node_h_dim
212 | self.W_out = nn.Sequential(
213 | LayerNorm(node_h_dim),
214 | GVP(node_h_dim, (ns, 0)))
215 |
216 | self.dense = nn.Sequential(
217 | nn.Linear(ns, 2*ns), nn.ReLU(inplace=True),
218 | nn.Dropout(p=drop_rate),
219 | nn.Linear(2*ns, 1)
220 | )
221 |
222 | def forward(self, h_V, edge_index, h_E, seq=None, batch=None):
223 | '''
224 | :param h_V: tuple (s, V) of node embeddings
225 | :param edge_index: `torch.Tensor` of shape [2, num_edges]
226 | :param h_E: tuple (s, V) of edge embeddings
227 | :param seq: if not `None`, int `torch.Tensor` of shape [num_nodes]
228 | to be embedded and appended to `h_V`
229 | '''
230 | if seq is not None:
231 | seq = self.W_s(seq)
232 | h_V = (torch.cat([h_V[0], seq], dim=-1), h_V[1])
233 | h_V = self.W_v(h_V)
234 | h_E = self.W_e(h_E)
235 | for layer in self.layers:
236 | h_V = layer(h_V, edge_index, h_E)
237 | out = self.W_out(h_V)
238 |
239 |         if batch is None: out = out.mean(dim=0, keepdim=True)
240 | else: out = scatter_mean(out, batch, dim=0)
241 |
242 | return self.dense(out).squeeze(-1) + 0.5
--------------------------------------------------------------------------------
/run_atom3d.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('task', metavar='TASK', choices=[
5 | 'PSR', 'RSR', 'PPI', 'RES', 'MSP', 'SMP', 'LBA', 'LEP'
6 | ], help="{PSR, RSR, PPI, RES, MSP, SMP, LBA, LEP}")
7 | parser.add_argument('--num-workers', metavar='N', type=int, default=4,
8 | help='number of threads for loading data, default=4')
9 | parser.add_argument('--smp-idx', metavar='IDX', type=int, default=None,
10 | choices=list(range(20)),
11 | help='label index for SMP, in range 0-19')
12 | parser.add_argument('--lba-split', metavar='SPLIT', type=int, choices=[30, 60],
13 | help='identity cutoff for LBA, 30 (default) or 60', default=30)
14 | parser.add_argument('--batch', metavar='SIZE', type=int, default=8,
15 | help='batch size, default=8')
16 | parser.add_argument('--train-time', metavar='MINUTES', type=int, default=120,
17 | help='maximum time between evaluations on valset, default=120 minutes')
18 | parser.add_argument('--val-time', metavar='MINUTES', type=int, default=20,
19 | help='maximum time per evaluation on valset, default=20 minutes')
20 | parser.add_argument('--epochs', metavar='N', type=int, default=50,
21 | help='training epochs, default=50')
22 | parser.add_argument('--test', metavar='PATH', default=None,
23 | help='evaluate a trained model')
24 | parser.add_argument('--lr', metavar='RATE', default=1e-4, type=float,
25 | help='learning rate')
26 | parser.add_argument('--load', metavar='PATH', default=None,
27 | help='initialize first 2 GNN layers with pretrained weights')
28 |
29 | args = parser.parse_args()
30 |
31 | import gvp
32 | from atom3d.datasets import LMDBDataset
33 | import torch_geometric
34 | from functools import partial
35 | import gvp.atom3d
36 | import torch.nn as nn
37 | import tqdm, torch, time, os
38 | import numpy as np
39 | from atom3d.util import metrics
40 | import sklearn.metrics as sk_metrics
41 | from collections import defaultdict
42 | import scipy.stats as stats
43 | print = partial(print, flush=True)
44 |
45 | models_dir = 'models'
46 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
47 | model_id = float(time.time())
48 |
49 | def main():
50 | datasets = get_datasets(args.task, args.lba_split)
51 | dataloader = partial(torch_geometric.data.DataLoader,
52 | num_workers=args.num_workers, batch_size=args.batch)
53 | if args.task not in ['PPI', 'RES']:
54 | dataloader = partial(dataloader, shuffle=True)
55 |
56 | trainset, valset, testset = map(dataloader, datasets)
57 | model = get_model(args.task).to(device)
58 |
59 | if args.test:
60 | test(model, testset)
61 |
62 | else:
63 | if args.load:
64 | load(model, args.load)
65 | train(model, trainset, valset)
66 |
67 | def test(model, testset):
68 | model.load_state_dict(torch.load(args.test))
69 | model.eval()
70 | t = tqdm.tqdm(testset)
71 | metrics = get_metrics(args.task)
72 | targets, predicts, ids = [], [], []
73 | with torch.no_grad():
74 | for batch in t:
75 | pred = forward(model, batch, device)
76 | label = get_label(batch, args.task, args.smp_idx)
77 | if args.task == 'RES':
78 | pred = pred.argmax(dim=-1)
79 | if args.task in ['PSR', 'RSR']:
80 | ids.extend(batch.id)
81 | targets.extend(list(label.cpu().numpy()))
82 | predicts.extend(list(pred.cpu().numpy()))
83 |
84 | for name, func in metrics.items():
85 | if args.task in ['PSR', 'RSR']:
86 | func = partial(func, ids=ids)
87 | value = func(targets, predicts)
88 | print(f"{name}: {value}")
89 |
90 | def train(model, trainset, valset):
91 |
92 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
93 |
94 | best_path, best_val = None, np.inf
95 |
96 | for epoch in range(args.epochs):
97 | model.train()
98 | loss = loop(trainset, model, optimizer=optimizer, max_time=args.train_time)
99 | path = f"{models_dir}/{args.task}_{model_id}_{epoch}.pt"
100 | torch.save(model.state_dict(), path)
101 | print(f'\nEPOCH {epoch} TRAIN loss: {loss:.8f}')
102 | model.eval()
103 | with torch.no_grad():
104 | loss = loop(valset, model, max_time=args.val_time)
105 | print(f'\nEPOCH {epoch} VAL loss: {loss:.8f}')
106 | if loss < best_val:
107 | best_path, best_val = path, loss
108 | print(f'BEST {best_path} VAL loss: {best_val:.8f}')
109 |
110 | def loop(dataset, model, optimizer=None, max_time=None):
111 | start = time.time()
112 |
113 | loss_fn = get_loss(args.task)
114 | t = tqdm.tqdm(dataset)
115 | total_loss, total_count = 0, 0
116 |
117 | for batch in t:
118 | if max_time and (time.time() - start) > 60*max_time: break
119 | if optimizer: optimizer.zero_grad()
120 | try:
121 | out = forward(model, batch, device)
122 | except RuntimeError as e:
123 | if "CUDA out of memory" not in str(e): raise(e)
124 | torch.cuda.empty_cache()
125 | print('Skipped batch due to OOM', flush=True)
126 | continue
127 |
128 | label = get_label(batch, args.task, args.smp_idx)
129 | loss_value = loss_fn(out, label)
130 | total_loss += float(loss_value)
131 | total_count += 1
132 |
133 | if optimizer:
134 | try:
135 | loss_value.backward()
136 | optimizer.step()
137 | except RuntimeError as e:
138 | if "CUDA out of memory" not in str(e): raise(e)
139 | torch.cuda.empty_cache()
140 | print('Skipped batch due to OOM', flush=True)
141 | continue
142 |
143 | t.set_description(f"{total_loss/total_count:.8f}")
144 |
145 | return total_loss / total_count
146 |
147 | def load(model, path):
148 | params = torch.load(path)
149 | state_dict = model.state_dict()
150 | for name, p in params.items():
151 | if name in state_dict and \
152 | name[:8] in ['layers.0', 'layers.1'] and \
153 | state_dict[name].shape == p.shape:
154 | print("Loading", name)
155 |             state_dict[name].copy_(p)
156 |
157 | #######################################################################
158 |
159 | def get_label(batch, task, smp_idx=None):
160 | if type(batch) in [list, tuple]: batch = batch[0]
161 | if task == 'SMP':
162 | assert smp_idx is not None
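        # SMP labels appear to be stored flattened, 20 targets per graph;
        # take every 20th entry starting at the requested index.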
163 | return batch.label[smp_idx::20]
164 | return batch.label
165 |
166 | def get_metrics(task):
167 | def _correlation(metric, targets, predict, ids=None, glob=True):
168 | if glob: return metric(targets, predict)
169 | _targets, _predict = defaultdict(list), defaultdict(list)
170 | for _t, _p, _id in zip(targets, predict, ids):
171 | _targets[_id].append(_t)
172 | _predict[_id].append(_p)
173 | return np.mean([metric(_targets[_id], _predict[_id]) for _id in _targets])
174 |
175 | correlations = {
176 | 'pearson': partial(_correlation, metrics.pearson),
177 | 'kendall': partial(_correlation, metrics.kendall),
178 | 'spearman': partial(_correlation, metrics.spearman)
179 | }
180 | mean_correlations = {f'mean {k}' : partial(v, glob=False) \
181 | for k, v in correlations.items()}
182 |
183 | return {
184 | 'RSR' : {**correlations, **mean_correlations},
185 | 'PSR' : {**correlations, **mean_correlations},
186 | 'PPI' : {'auroc': metrics.auroc},
187 | 'RES' : {'accuracy': metrics.accuracy},
188 | 'MSP' : {'auroc': metrics.auroc, 'auprc': metrics.auprc},
189 | 'LEP' : {'auroc': metrics.auroc, 'auprc': metrics.auprc},
190 | 'LBA' : {**correlations, 'rmse': partial(sk_metrics.mean_squared_error, squared=False)},
191 | 'SMP' : {'mae': sk_metrics.mean_absolute_error}
192 | }[task]
193 |
194 | def get_loss(task):
195 | if task in ['PSR', 'RSR', 'SMP', 'LBA']: return nn.MSELoss() # regression
196 | elif task in ['PPI', 'MSP', 'LEP']: return nn.BCELoss() # binary classification
197 | elif task in ['RES']: return nn.CrossEntropyLoss() # multiclass classification
198 |
199 | def forward(model, batch, device):
200 | if type(batch) in [list, tuple]:
201 | batch = batch[0].to(device), batch[1].to(device)
202 | else:
203 | batch = batch.to(device)
204 | return model(batch)
205 |
206 | def get_datasets(task, lba_split=30):
207 | data_path = {
208 | 'RES' : 'atom3d-data/RES/raw/RES/data/',
209 | 'PPI' : 'atom3d-data/PPI/splits/DIPS-split/data/',
210 | 'RSR' : 'atom3d-data/RSR/splits/candidates-split-by-time/data/',
211 | 'PSR' : 'atom3d-data/PSR/splits/split-by-year/data/',
212 | 'MSP' : 'atom3d-data/MSP/splits/split-by-sequence-identity-30/data/',
213 | 'LEP' : 'atom3d-data/LEP/splits/split-by-protein/data/',
214 | 'LBA' : f'atom3d-data/LBA/splits/split-by-sequence-identity-{lba_split}/data/',
215 | 'SMP' : 'atom3d-data/SMP/splits/random/data/'
216 | }[task]
217 |
218 | if task == 'RES':
219 | split_path = 'atom3d-data/RES/splits/split-by-cath-topology/indices/'
220 | dataset = partial(gvp.atom3d.RESDataset, data_path)
221 | trainset = dataset(split_path=split_path+'train_indices.txt')
222 | valset = dataset(split_path=split_path+'val_indices.txt')
223 | testset = dataset(split_path=split_path+'test_indices.txt')
224 |
225 | elif task == 'PPI':
226 | trainset = gvp.atom3d.PPIDataset(data_path+'train')
227 | valset = gvp.atom3d.PPIDataset(data_path+'val')
228 | testset = gvp.atom3d.PPIDataset(data_path+'test')
229 |
230 | else:
231 | transform = {
232 | 'RSR' : gvp.atom3d.RSRTransform,
233 | 'PSR' : gvp.atom3d.PSRTransform,
234 | 'MSP' : gvp.atom3d.MSPTransform,
235 | 'LEP' : gvp.atom3d.LEPTransform,
236 | 'LBA' : gvp.atom3d.LBATransform,
237 | 'SMP' : gvp.atom3d.SMPTransform,
238 | }[task]()
239 |
240 | trainset = LMDBDataset(data_path+'train', transform=transform)
241 | valset = LMDBDataset(data_path+'val', transform=transform)
242 | testset = LMDBDataset(data_path+'test', transform=transform)
243 |
244 | return trainset, valset, testset
245 |
246 | def get_model(task):
247 | return {
248 | 'RES' : gvp.atom3d.RESModel,
249 | 'PPI' : gvp.atom3d.PPIModel,
250 | 'RSR' : gvp.atom3d.RSRModel,
251 | 'PSR' : gvp.atom3d.PSRModel,
252 | 'MSP' : gvp.atom3d.MSPModel,
253 | 'LEP' : gvp.atom3d.LEPModel,
254 | 'LBA' : gvp.atom3d.LBAModel,
255 | 'SMP' : gvp.atom3d.SMPModel
256 | }[task]()
257 |
258 | if __name__ == "__main__":
259 | main()
--------------------------------------------------------------------------------
/run_cpd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--models-dir', metavar='PATH', default='./models/',
5 | help='directory to save trained models, default=./models/')
6 | parser.add_argument('--num-workers', metavar='N', type=int, default=4,
7 | help='number of threads for loading data, default=4')
8 | parser.add_argument('--max-nodes', metavar='N', type=int, default=3000,
9 | help='max number of nodes per batch, default=3000')
10 | parser.add_argument('--epochs', metavar='N', type=int, default=100,
11 | help='training epochs, default=100')
12 | parser.add_argument('--cath-data', metavar='PATH', default='./data/chain_set.jsonl',
13 | help='location of CATH dataset, default=./data/chain_set.jsonl')
14 | parser.add_argument('--cath-splits', metavar='PATH', default='./data/chain_set_splits.json',
15 | help='location of CATH split file, default=./data/chain_set_splits.json')
16 | parser.add_argument('--ts50', metavar='PATH', default='./data/ts50.json',
17 | help='location of TS50 dataset, default=./data/ts50.json')
18 | parser.add_argument('--train', action="store_true", help="train a model")
19 | parser.add_argument('--test-r', metavar='PATH', default=None,
20 | help='evaluate a trained model on recovery (without training)')
21 | parser.add_argument('--test-p', metavar='PATH', default=None,
22 | help='evaluate a trained model on perplexity (without training)')
23 | parser.add_argument('--n-samples', metavar='N', type=int, default=100,
24 | help='number of sequences to sample (if testing recovery), default=100')
25 |
26 | args = parser.parse_args()
27 | assert sum(map(bool, [args.train, args.test_p, args.test_r])) == 1, \
28 |     "Specify exactly one of --train, --test-r, --test-p"
29 |
30 | import torch
31 | import torch.nn as nn
32 | import gvp.data, gvp.models
33 | from datetime import datetime
34 | import tqdm, os, json
35 | import numpy as np
36 | from sklearn.metrics import confusion_matrix
37 | import torch_geometric
38 | from functools import partial
39 | print = partial(print, flush=True)
40 |
41 | node_dim = (100, 16)
42 | edge_dim = (32, 1)
43 | device = "cuda" if torch.cuda.is_available() else "cpu"
44 |
45 | if not os.path.exists(args.models_dir): os.makedirs(args.models_dir)
46 | model_id = int(datetime.timestamp(datetime.now()))
47 | dataloader = lambda x: torch_geometric.data.DataLoader(x,
48 | num_workers=args.num_workers,
49 | batch_sampler=gvp.data.BatchSampler(
50 | x.node_counts, max_nodes=args.max_nodes))
51 |
52 | def main():
53 |
54 | model = gvp.models.CPDModel((6, 3), node_dim, (32, 1), edge_dim).to(device)
55 |
56 | print("Loading CATH dataset")
57 |     cath = gvp.data.CATHDataset(path=args.cath_data,
58 |                                 splits_path=args.cath_splits)
59 |
60 | trainset, valset, testset = map(gvp.data.ProteinGraphDataset,
61 | (cath.train, cath.val, cath.test))
62 |
63 | if args.test_r or args.test_p:
64 | ts50set = gvp.data.ProteinGraphDataset(json.load(open(args.ts50)))
65 | model.load_state_dict(torch.load(args.test_r or args.test_p))
66 |
67 | if args.test_r:
68 | print("Testing on CATH testset"); test_recovery(model, testset)
69 | print("Testing on TS50 set"); test_recovery(model, ts50set)
70 |
71 | elif args.test_p:
72 | print("Testing on CATH testset"); test_perplexity(model, testset)
73 | print("Testing on TS50 set"); test_perplexity(model, ts50set)
74 |
75 | elif args.train:
76 | train(model, trainset, valset, testset)
77 |
78 |
79 | def train(model, trainset, valset, testset):
80 | train_loader, val_loader, test_loader = map(dataloader,
81 | (trainset, valset, testset))
82 | optimizer = torch.optim.Adam(model.parameters())
83 | best_path, best_val = None, np.inf
84 | lookup = train_loader.dataset.num_to_letter
85 | for epoch in range(args.epochs):
86 | model.train()
87 | loss, acc, confusion = loop(model, train_loader, optimizer=optimizer)
88 | path = f"{args.models_dir}/{model_id}_{epoch}.pt"
89 | torch.save(model.state_dict(), path)
90 | print(f'EPOCH {epoch} TRAIN loss: {loss:.4f} acc: {acc:.4f}')
91 | print_confusion(confusion, lookup=lookup)
92 |
93 | model.eval()
94 | with torch.no_grad():
95 | loss, acc, confusion = loop(model, val_loader)
96 | print(f'EPOCH {epoch} VAL loss: {loss:.4f} acc: {acc:.4f}')
97 | print_confusion(confusion, lookup=lookup)
98 |
99 | if loss < best_val:
100 | best_path, best_val = path, loss
101 | print(f'BEST {best_path} VAL loss: {best_val:.4f}')
102 |
103 | print(f"TESTING: loading from {best_path}")
104 | model.load_state_dict(torch.load(best_path))
105 |
106 | model.eval()
107 | with torch.no_grad():
108 | loss, acc, confusion = loop(model, test_loader)
109 | print(f'TEST loss: {loss:.4f} acc: {acc:.4f}')
110 |     print_confusion(confusion, lookup=lookup)
111 |
112 | def test_perplexity(model, dataset):
113 | model.eval()
114 | with torch.no_grad():
115 | loss, acc, confusion = loop(model, dataloader(dataset))
116 | print(f'TEST perplexity: {np.exp(loss):.4f}')
117 | print_confusion(confusion, lookup=dataset.num_to_letter)
118 |
119 | def test_recovery(model, dataset):
120 | recovery = []
121 |
122 | for protein in tqdm.tqdm(dataset):
123 | protein = protein.to(device)
124 | h_V = (protein.node_s, protein.node_v)
125 | h_E = (protein.edge_s, protein.edge_v)
126 | sample = model.sample(h_V, protein.edge_index,
127 | h_E, n_samples=args.n_samples)
128 |
129 | recovery_ = sample.eq(protein.seq).float().mean().cpu().numpy()
130 | recovery.append(recovery_)
131 | print(protein.name, recovery_, flush=True)
132 |
133 | recovery = np.median(recovery)
134 | print(f'TEST recovery: {recovery:.4f}')
135 |
136 | def loop(model, dataloader, optimizer=None):
137 |
138 | confusion = np.zeros((20, 20))
139 | t = tqdm.tqdm(dataloader)
140 | loss_fn = nn.CrossEntropyLoss()
141 | total_loss, total_correct, total_count = 0, 0, 0
142 |
143 | for batch in t:
144 | if optimizer: optimizer.zero_grad()
145 |
146 | batch = batch.to(device)
147 | h_V = (batch.node_s, batch.node_v)
148 | h_E = (batch.edge_s, batch.edge_v)
149 |
150 | logits = model(h_V, batch.edge_index, h_E, seq=batch.seq)
151 | logits, seq = logits[batch.mask], batch.seq[batch.mask]
152 | loss_value = loss_fn(logits, seq)
153 |
154 | if optimizer:
155 | loss_value.backward()
156 | optimizer.step()
157 |
158 | num_nodes = int(batch.mask.sum())
159 | total_loss += float(loss_value) * num_nodes
160 | total_count += num_nodes
161 | pred = torch.argmax(logits, dim=-1).detach().cpu().numpy()
162 | true = seq.detach().cpu().numpy()
163 | total_correct += (pred == true).sum()
164 | confusion += confusion_matrix(true, pred, labels=range(20))
165 | t.set_description("%.5f" % float(total_loss/total_count))
166 |
167 | torch.cuda.empty_cache()
168 |
169 | return total_loss / total_count, total_correct / total_count, confusion
170 |
171 | def print_confusion(mat, lookup):
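    # Row-normalize counts to per-mille frequencies (each row of the printed
    # table shows, out of 1000, how a true residue type was predicted).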
172 | counts = mat.astype(np.int32)
173 | mat = (counts.T / counts.sum(axis=-1, keepdims=True).T).T
174 | mat = np.round(mat * 1000).astype(np.int32)
175 | res = '\n'
176 | for i in range(20):
177 | res += '\t{}'.format(lookup[i])
178 | res += '\tCount\n'
179 | for i in range(20):
180 | res += '{}\t'.format(lookup[i])
181 | res += '\t'.join('{}'.format(n) for n in mat[i])
182 | res += '\t{}\n'.format(sum(counts[i]))
183 | print(res)
184 |
185 | if __name__ == "__main__":
186 | main()
--------------------------------------------------------------------------------
/schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drorlab/gvp-pytorch/82af6b22eaf8311c15733117b0071408d24ed877/schematic.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 | long_description = fh.read()
5 |
6 | setup(
7 | name='gvp',
8 | packages=find_packages(include=[
9 | 'gvp',
10 | 'gvp.data',
11 | 'gvp.models'
12 | ]),
13 | version='0.1.1',
14 | description='Geometric Vector Perceptron',
15 | license='MIT',
16 | long_description=long_description,
17 | long_description_content_type="text/markdown",
18 | install_requires=[
19 | 'torch',
20 | 'torch_geometric',
21 | 'torch_scatter',
22 | 'torch_cluster',
23 | 'tqdm',
24 | 'numpy',
25 |         'scikit-learn',
26 | 'atom3d'
27 | ]
28 | )
--------------------------------------------------------------------------------
/test_equivariance.py:
--------------------------------------------------------------------------------
1 | import gvp
2 | import gvp.models
3 | import gvp.data
4 | import torch
5 | from torch import nn
6 | from scipy.spatial.transform import Rotation
7 | import unittest
8 |
9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10 |
11 | node_dim = (100, 16)
12 | edge_dim = (32, 1)
13 | n_nodes = 300
14 | n_edges = 10000
15 |
16 | nodes = gvp.randn(n_nodes, node_dim, device=device)
17 | edges = gvp.randn(n_edges, edge_dim, device=device)
18 | edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device)
19 | batch_idx = torch.randint(0, 5, (n_nodes,), device=device)
20 | seq = torch.randint(0, 20, (n_nodes,), device=device)
21 |
22 | class EquivarianceTest(unittest.TestCase):
23 |
24 | def test_gvp(self):
25 | model = gvp.GVP(node_dim, node_dim).to(device).eval()
26 | model_fn = lambda h_V, h_E: model(h_V)
27 | test_equivariance(model_fn, nodes, edges)
28 |
29 | def test_gvp_vector_gate(self):
30 | model = gvp.GVP(node_dim, node_dim, vector_gate=True).to(device).eval()
31 | model_fn = lambda h_V, h_E: model(h_V)
32 | test_equivariance(model_fn, nodes, edges)
33 |
34 | def test_gvp_sequence(self):
35 | model = nn.Sequential(
36 | gvp.GVP(node_dim, node_dim),
37 | gvp.Dropout(0.1),
38 | gvp.LayerNorm(node_dim)
39 | ).to(device).eval()
40 | model_fn = lambda h_V, h_E: model(h_V)
41 | test_equivariance(model_fn, nodes, edges)
42 |
43 | def test_gvp_sequence_vector_gate(self):
44 | model = nn.Sequential(
45 | gvp.GVP(node_dim, node_dim, vector_gate=True),
46 | gvp.Dropout(0.1),
47 | gvp.LayerNorm(node_dim)
48 | ).to(device).eval()
49 | model_fn = lambda h_V, h_E: model(h_V)
50 | test_equivariance(model_fn, nodes, edges)
51 |
52 | def test_gvp_conv(self):
53 | model = gvp.GVPConv(node_dim, node_dim, edge_dim).to(device).eval()
54 | model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E)
55 | test_equivariance(model_fn, nodes, edges)
56 |
57 | def test_gvp_conv_vector_gate(self):
58 | model = gvp.GVPConv(node_dim, node_dim, edge_dim, vector_gate=True).to(device).eval()
59 | model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E)
60 | test_equivariance(model_fn, nodes, edges)
61 |
62 | def test_gvp_conv_layer(self):
63 | model = gvp.GVPConvLayer(node_dim, edge_dim).to(device).eval()
64 | model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E,
65 | autoregressive_x=h_V)
66 | test_equivariance(model_fn, nodes, edges)
67 |
68 | def test_gvp_conv_layer_vector_gate(self):
69 | model = gvp.GVPConvLayer(node_dim, edge_dim, vector_gate=True).to(device).eval()
70 | model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E,
71 | autoregressive_x=h_V)
72 | test_equivariance(model_fn, nodes, edges)
73 |
74 | def test_mqa_model(self):
75 | model = gvp.models.MQAModel(node_dim, node_dim,
76 | edge_dim, edge_dim).to(device).eval()
77 | model_fn = lambda h_V, h_E: (model(h_V, edge_index, h_E, batch=batch_idx), \
78 | torch.zeros_like(nodes[1]))
79 | test_equivariance(model_fn, nodes, edges)
80 |
81 | def test_cpd_model(self):
82 | model = gvp.models.CPDModel(node_dim, node_dim,
83 | edge_dim, edge_dim).to(device).eval()
84 | model_fn = lambda h_V, h_E: (model(h_V, edge_index, h_E, seq=seq), \
85 | torch.zeros_like(nodes[1]))
86 | test_equivariance(model_fn, nodes, edges)
87 |
88 |
89 | def test_equivariance(model, nodes, edges):
90 |
91 | random = torch.as_tensor(Rotation.random().as_matrix(),
92 | dtype=torch.float32, device=device)
93 |
94 | with torch.no_grad():
95 |
96 | out_s, out_v = model(nodes, edges)
97 | n_v_rot, e_v_rot = nodes[1] @ random, edges[1] @ random
98 | out_v_rot = out_v @ random
99 | out_s_prime, out_v_prime = model((nodes[0], n_v_rot), (edges[0], e_v_rot))
100 |
101 | assert torch.allclose(out_s, out_s_prime, atol=1e-5, rtol=1e-4)
102 | assert torch.allclose(out_v_rot, out_v_prime, atol=1e-5, rtol=1e-4)
103 |
104 |
105 | if __name__ == "__main__":
106 | unittest.main()
--------------------------------------------------------------------------------
/vectors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drorlab/gvp-pytorch/82af6b22eaf8311c15733117b0071408d24ed877/vectors.png
--------------------------------------------------------------------------------