├── .gitignore
├── Arguments.py
├── LICENSE.md
├── README.md
├── benchmark_layers.py
├── benchmark_models.py
├── benchmark_scripts
├── DGCNN_site.sh
├── Pointnet_site.sh
├── dMaSIF_search.sh
└── dMaSIF_site.sh
├── data.py
├── data_analysis
├── analyse_atomnet.ipynb
├── analyse_descriptors.py
├── analyse_descriptors_para.py
├── analyse_output.ipynb
├── analyse_site_outputs.py
├── analyse_site_outputs_graph.ipynb
├── plot_search.ipynb
└── profiling_surface.ipynb
├── data_iteration.py
├── data_preprocessing
├── convert_pdb2npy.py
├── convert_ply2npy.py
└── download_pdb.py
├── geometry_processing.py
├── helper.py
├── lists
├── testing.txt
├── testing_ppi.txt
├── training.txt
└── training_ppi.txt
├── main_inference.py
├── main_training.py
├── model.py
├── models
└── dMaSIF_search_3layer_12A_16dim
├── overview.PNG
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | 01-benchmark_surfaces/
2 | 01-benchmark_surfaces_npy/
3 | 01-benchmark_pdbs_npy/
4 | 01-benchmark_pdbs/
6 | shape_index/
7 | masif_preds/
8 | runs/
9 | venv/
10 | preds/
11 | *.log
12 | NeurIPS_2020_benchmarks/
13 | *.out
14 | figures/
15 | timings/
16 | data_analysis/roc_curves
17 | data_analysis/.ipynb_checkpoints/
18 | .ipynb_checkpoints/
--------------------------------------------------------------------------------
/Arguments.py:
--------------------------------------------------------------------------------
import argparse


def _str2bool(value):
    """Parse a command-line string into a real boolean.

    argparse's ``type=bool`` is a well-known pitfall: ``bool("False")`` is
    True, so every non-empty value used to enable the flag regardless of
    its spelling.  This helper accepts the common true/false spellings and
    raises a clear error on anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Network parameters")

# Main parameters
parser.add_argument(
    "--experiment_name", type=str, help="Name of experiment", required=True
)
parser.add_argument(
    "--use_mesh", type=_str2bool, default=False, help="Use precomputed surfaces"
)
parser.add_argument(
    "--embedding_layer",
    type=str,
    default="dMaSIF",
    choices=["dMaSIF", "DGCNN", "PointNet++"],
    help="Which convolutional embedding layer to use",
)
parser.add_argument("--profile", type=_str2bool, default=False, help="Profile code")

# Geometric parameters
parser.add_argument(
    "--curvature_scales",
    # `type=list` would split the CLI string into single characters;
    # accept one or more floats instead (e.g. --curvature_scales 1.0 2.0).
    type=float,
    nargs="+",
    default=[1.0, 2.0, 3.0, 5.0, 10.0],
    help="Scales at which we compute the geometric features (mean and Gauss curvatures)",
)
parser.add_argument(
    "--resolution",
    type=float,
    default=1.0,
    help="Resolution of the generated point cloud",
)
parser.add_argument(
    "--distance",
    type=float,
    default=1.05,
    help="Distance parameter in surface generation",
)
parser.add_argument(
    "--variance",
    type=float,
    default=0.1,
    help="Variance parameter in surface generation",
)
parser.add_argument(
    "--sup_sampling", type=int, default=20, help="Sup-sampling ratio around atoms"
)

# Hyper-parameters for the embedding
parser.add_argument(
    "--atom_dims",
    type=int,
    default=6,
    help="Number of atom types and dimension of resulting chemical features",
)
parser.add_argument(
    "--emb_dims",
    type=int,
    default=8,
    help="Number of input features (+ 3 xyz coordinates for DGCNNs)",
)
parser.add_argument(
    "--in_channels",
    type=int,
    default=16,
    help="Number of embedding dimensions",
)
parser.add_argument(
    "--orientation_units",
    type=int,
    default=16,
    help="Number of hidden units for the orientation score MLP",
)
parser.add_argument(
    "--unet_hidden_channels",
    type=int,
    default=8,
    help="Number of hidden units for TangentConv UNet",
)
parser.add_argument(
    "--post_units",
    type=int,
    default=8,
    help="Number of hidden units for the post-processing MLP",
)
parser.add_argument(
    "--n_layers", type=int, default=1, help="Number of convolutional layers"
)
parser.add_argument(
    "--radius", type=float, default=9.0, help="Radius to use for the convolution"
)
parser.add_argument(
    "--k",
    type=int,
    default=40,
    help="Number of nearest neighbours for DGCNN and PointNet++",
)
parser.add_argument(
    "--dropout",
    type=float,
    default=0.0,
    help="Amount of Dropout for the input features",
)

# Training
parser.add_argument(
    "--n_epochs", type=int, default=50, help="Number of training epochs"
)
parser.add_argument(
    "--batch_size", type=int, default=1, help="Number of proteins in a batch"
)
parser.add_argument(
    "--device", type=str, default="cuda:0", help="Which gpu/cpu to train on"
)
parser.add_argument(
    "--restart_training",
    type=str,
    default="",
    help="Which model to restart the training from",
)
parser.add_argument(
    "--n_rocauc_samples",
    type=int,
    default=100,
    help="Number of samples for the Matching ROC-AUC",
)
parser.add_argument(
    "--validation_fraction",
    type=float,
    default=0.1,
    help="Fraction of training dataset to use for validation",
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument(
    "--random_rotation",
    type=_str2bool,
    default=False,
    help="Move proteins to center and add random rotation",
)
parser.add_argument(
    "--single_protein",
    type=_str2bool,
    default=False,
    help="Use single protein in a pair or both",
)
parser.add_argument(
    "--site", type=_str2bool, default=False, help="Predict interaction site"
)
parser.add_argument(
    "--search",
    type=_str2bool,
    default=False,
    help="Predict matching between two partners",
)
parser.add_argument(
    "--no_chem", type=_str2bool, default=False, help="Predict without chemical information"
)
parser.add_argument(
    "--no_geom", type=_str2bool, default=False, help="Predict without curvature information"
)
parser.add_argument(
    "--single_pdb",
    type=str,
    default="",
    help="Which structure to do inference on",
)
parser.add_argument(
    "--pdb_list",
    type=str,
    default="",
    help="Which structures to do inference on",
)
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial-NoDerivatives 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial-NoDerivatives 4.0
58 | International Public License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial-NoDerivatives 4.0 International Public
63 | License ("Public License"). To the extent this Public License may be
64 | interpreted as a contract, You are granted the Licensed Rights in
65 | consideration of Your acceptance of these terms and conditions, and the
66 | Licensor grants You such rights in consideration of benefits the
67 | Licensor receives from making the Licensed Material available under
68 | these terms and conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Copyright and Similar Rights means copyright and/or similar rights
84 | closely related to copyright including, without limitation,
85 | performance, broadcast, sound recording, and Sui Generis Database
86 | Rights, without regard to how the rights are labeled or
87 | categorized. For purposes of this Public License, the rights
88 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
89 | Rights.
90 |
91 | c. Effective Technological Measures means those measures that, in the
92 | absence of proper authority, may not be circumvented under laws
93 | fulfilling obligations under Article 11 of the WIPO Copyright
94 | Treaty adopted on December 20, 1996, and/or similar international
95 | agreements.
96 |
97 | d. Exceptions and Limitations means fair use, fair dealing, and/or
98 | any other exception or limitation to Copyright and Similar Rights
99 | that applies to Your use of the Licensed Material.
100 |
101 | e. Licensed Material means the artistic or literary work, database,
102 | or other material to which the Licensor applied this Public
103 | License.
104 |
105 | f. Licensed Rights means the rights granted to You subject to the
106 | terms and conditions of this Public License, which are limited to
107 | all Copyright and Similar Rights that apply to Your use of the
108 | Licensed Material and that the Licensor has authority to license.
109 |
110 | g. Licensor means the individual(s) or entity(ies) granting rights
111 | under this Public License.
112 |
113 | h. NonCommercial means not primarily intended for or directed towards
114 | commercial advantage or monetary compensation. For purposes of
115 | this Public License, the exchange of the Licensed Material for
116 | other material subject to Copyright and Similar Rights by digital
117 | file-sharing or similar means is NonCommercial provided there is
118 | no payment of monetary compensation in connection with the
119 | exchange.
120 |
121 | i. Share means to provide material to the public by any means or
122 | process that requires permission under the Licensed Rights, such
123 | as reproduction, public display, public performance, distribution,
124 | dissemination, communication, or importation, and to make material
125 | available to the public including in ways that members of the
126 | public may access the material from a place and at a time
127 | individually chosen by them.
128 |
129 | j. Sui Generis Database Rights means rights other than copyright
130 | resulting from Directive 96/9/EC of the European Parliament and of
131 | the Council of 11 March 1996 on the legal protection of databases,
132 | as amended and/or succeeded, as well as other essentially
133 | equivalent rights anywhere in the world.
134 |
135 | k. You means the individual or entity exercising the Licensed Rights
136 | under this Public License. Your has a corresponding meaning.
137 |
138 |
139 | Section 2 -- Scope.
140 |
141 | a. License grant.
142 |
143 | 1. Subject to the terms and conditions of this Public License,
144 | the Licensor hereby grants You a worldwide, royalty-free,
145 | non-sublicensable, non-exclusive, irrevocable license to
146 | exercise the Licensed Rights in the Licensed Material to:
147 |
148 | a. reproduce and Share the Licensed Material, in whole or
149 | in part, for NonCommercial purposes only; and
150 |
151 | b. produce and reproduce, but not Share, Adapted Material
152 | for NonCommercial purposes only.
153 |
154 | 2. Exceptions and Limitations. For the avoidance of doubt, where
155 | Exceptions and Limitations apply to Your use, this Public
156 | License does not apply, and You do not need to comply with
157 | its terms and conditions.
158 |
159 | 3. Term. The term of this Public License is specified in Section
160 | 6(a).
161 |
162 | 4. Media and formats; technical modifications allowed. The
163 | Licensor authorizes You to exercise the Licensed Rights in
164 | all media and formats whether now known or hereafter created,
165 | and to make technical modifications necessary to do so. The
166 | Licensor waives and/or agrees not to assert any right or
167 | authority to forbid You from making technical modifications
168 | necessary to exercise the Licensed Rights, including
169 | technical modifications necessary to circumvent Effective
170 | Technological Measures. For purposes of this Public License,
171 | simply making modifications authorized by this Section 2(a)
172 | (4) never produces Adapted Material.
173 |
174 | 5. Downstream recipients.
175 |
176 | a. Offer from the Licensor -- Licensed Material. Every
177 | recipient of the Licensed Material automatically
178 | receives an offer from the Licensor to exercise the
179 | Licensed Rights under the terms and conditions of this
180 | Public License.
181 |
182 | b. No downstream restrictions. You may not offer or impose
183 | any additional or different terms or conditions on, or
184 | apply any Effective Technological Measures to, the
185 | Licensed Material if doing so restricts exercise of the
186 | Licensed Rights by any recipient of the Licensed
187 | Material.
188 |
189 | 6. No endorsement. Nothing in this Public License constitutes or
190 | may be construed as permission to assert or imply that You
191 | are, or that Your use of the Licensed Material is, connected
192 | with, or sponsored, endorsed, or granted official status by,
193 | the Licensor or others designated to receive attribution as
194 | provided in Section 3(a)(1)(A)(i).
195 |
196 | b. Other rights.
197 |
198 | 1. Moral rights, such as the right of integrity, are not
199 | licensed under this Public License, nor are publicity,
200 | privacy, and/or other similar personality rights; however, to
201 | the extent possible, the Licensor waives and/or agrees not to
202 | assert any such rights held by the Licensor to the limited
203 | extent necessary to allow You to exercise the Licensed
204 | Rights, but not otherwise.
205 |
206 | 2. Patent and trademark rights are not licensed under this
207 | Public License.
208 |
209 | 3. To the extent possible, the Licensor waives any right to
210 | collect royalties from You for the exercise of the Licensed
211 | Rights, whether directly or through a collecting society
212 | under any voluntary or waivable statutory or compulsory
213 | licensing scheme. In all other cases the Licensor expressly
214 | reserves any right to collect such royalties, including when
215 | the Licensed Material is used other than for NonCommercial
216 | purposes.
217 |
218 |
219 | Section 3 -- License Conditions.
220 |
221 | Your exercise of the Licensed Rights is expressly made subject to the
222 | following conditions.
223 |
224 | a. Attribution.
225 |
226 | 1. If You Share the Licensed Material, You must:
227 |
228 | a. retain the following if it is supplied by the Licensor
229 | with the Licensed Material:
230 |
231 | i. identification of the creator(s) of the Licensed
232 | Material and any others designated to receive
233 | attribution, in any reasonable manner requested by
234 | the Licensor (including by pseudonym if
235 | designated);
236 |
237 | ii. a copyright notice;
238 |
239 | iii. a notice that refers to this Public License;
240 |
241 | iv. a notice that refers to the disclaimer of
242 | warranties;
243 |
244 | v. a URI or hyperlink to the Licensed Material to the
245 | extent reasonably practicable;
246 |
247 | b. indicate if You modified the Licensed Material and
248 | retain an indication of any previous modifications; and
249 |
250 | c. indicate the Licensed Material is licensed under this
251 | Public License, and include the text of, or the URI or
252 | hyperlink to, this Public License.
253 |
254 | For the avoidance of doubt, You do not have permission under
255 | this Public License to Share Adapted Material.
256 |
257 | 2. You may satisfy the conditions in Section 3(a)(1) in any
258 | reasonable manner based on the medium, means, and context in
259 | which You Share the Licensed Material. For example, it may be
260 | reasonable to satisfy the conditions by providing a URI or
261 | hyperlink to a resource that includes the required
262 | information.
263 |
264 | 3. If requested by the Licensor, You must remove any of the
265 | information required by Section 3(a)(1)(A) to the extent
266 | reasonably practicable.
267 |
268 |
269 | Section 4 -- Sui Generis Database Rights.
270 |
271 | Where the Licensed Rights include Sui Generis Database Rights that
272 | apply to Your use of the Licensed Material:
273 |
274 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
275 | to extract, reuse, reproduce, and Share all or a substantial
276 | portion of the contents of the database for NonCommercial purposes
277 | only and provided You do not Share Adapted Material;
278 |
279 | b. if You include all or a substantial portion of the database
280 | contents in a database in which You have Sui Generis Database
281 | Rights, then the database in which You have Sui Generis Database
282 | Rights (but not its individual contents) is Adapted Material; and
283 |
284 | c. You must comply with the conditions in Section 3(a) if You Share
285 | all or a substantial portion of the contents of the database.
286 |
287 | For the avoidance of doubt, this Section 4 supplements and does not
288 | replace Your obligations under this Public License where the Licensed
289 | Rights include other Copyright and Similar Rights.
290 |
291 |
292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293 |
294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304 |
305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314 |
315 | c. The disclaimer of warranties and limitation of liability provided
316 | above shall be interpreted in a manner that, to the extent
317 | possible, most closely approximates an absolute disclaimer and
318 | waiver of all liability.
319 |
320 |
321 | Section 6 -- Term and Termination.
322 |
323 | a. This Public License applies for the term of the Copyright and
324 | Similar Rights licensed here. However, if You fail to comply with
325 | this Public License, then Your rights under this Public License
326 | terminate automatically.
327 |
328 | b. Where Your right to use the Licensed Material has terminated under
329 | Section 6(a), it reinstates:
330 |
331 | 1. automatically as of the date the violation is cured, provided
332 | it is cured within 30 days of Your discovery of the
333 | violation; or
334 |
335 | 2. upon express reinstatement by the Licensor.
336 |
337 | For the avoidance of doubt, this Section 6(b) does not affect any
338 | right the Licensor may have to seek remedies for Your violations
339 | of this Public License.
340 |
341 | c. For the avoidance of doubt, the Licensor may also offer the
342 | Licensed Material under separate terms or conditions or stop
343 | distributing the Licensed Material at any time; however, doing so
344 | will not terminate this Public License.
345 |
346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347 | License.
348 |
349 |
350 | Section 7 -- Other Terms and Conditions.
351 |
352 | a. The Licensor shall not be bound by any additional or different
353 | terms or conditions communicated by You unless expressly agreed.
354 |
355 | b. Any arrangements, understandings, or agreements regarding the
356 | Licensed Material not stated herein are separate from and
357 | independent of the terms and conditions of this Public License.
358 |
359 |
360 | Section 8 -- Interpretation.
361 |
362 | a. For the avoidance of doubt, this Public License does not, and
363 | shall not be interpreted to, reduce, limit, restrict, or impose
364 | conditions on any use of the Licensed Material that could lawfully
365 | be made without permission under this Public License.
366 |
367 | b. To the extent possible, if any provision of this Public License is
368 | deemed unenforceable, it shall be automatically reformed to the
369 | minimum extent necessary to make it enforceable. If the provision
370 | cannot be reformed, it shall be severed from this Public License
371 | without affecting the enforceability of the remaining terms and
372 | conditions.
373 |
374 | c. No term or condition of this Public License will be waived and no
375 | failure to comply consented to unless expressly agreed to by the
376 | Licensor.
377 |
378 | d. Nothing in this Public License constitutes or may be interpreted
379 | as a limitation upon, or waiver of, any privileges and immunities
380 | that apply to the Licensor or You, including from the legal
381 | processes of any jurisdiction or authority.
382 |
383 | =======================================================================
384 |
385 | Creative Commons is not a party to its public
386 | licenses. Notwithstanding, Creative Commons may elect to apply one of
387 | its public licenses to material it publishes and in those instances
388 | will be considered the “Licensor.” The text of the Creative Commons
389 | public licenses is dedicated to the public domain under the CC0 Public
390 | Domain Dedication. Except for the limited purpose of indicating that
391 | material is shared under a Creative Commons public license or as
392 | otherwise permitted by the Creative Commons policies published at
393 | creativecommons.org/policies, Creative Commons does not authorize the
394 | use of the trademark "Creative Commons" or any other trademark or logo
395 | of Creative Commons without its prior written consent including,
396 | without limitation, in connection with any unauthorized modifications
397 | to any of its public licenses or any other arrangements,
398 | understandings, or agreements concerning use of licensed material. For
399 | the avoidance of doubt, this paragraph does not form part of the
400 | public licenses.
401 |
402 | Creative Commons may be contacted at creativecommons.org.
403 |
404 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## dMaSIF - Fast end-to-end learning on protein surfaces
2 | 
3 |
4 | ## Abstract
5 |
6 | Proteins’ biological functions are defined by the geometric
7 | and chemical structure of their 3D molecular surfaces.
8 | Recent works have shown that geometric deep learning can
9 | be used on mesh-based representations of proteins to identify
10 | potential functional sites, such as binding targets for
11 | potential drugs. Unfortunately though, the use of meshes as
12 | the underlying representation for protein structure has multiple
13 | drawbacks including the need to pre-compute the input
14 | features and mesh connectivities. This becomes a bottleneck
15 | for many important tasks in protein science.
16 |
17 | In this paper, we present a new framework for deep
18 | learning on protein structures that addresses these limitations.
19 | Among the key advantages of our method are the computation
20 | and sampling of the molecular surface on-the-fly
21 | from the underlying atomic point cloud and a novel efficient
22 | geometric convolutional layer. As a result, we are able to
23 | process large collections of proteins in an end-to-end fashion,
24 | taking as the sole input the raw 3D coordinates and
25 | chemical types of their atoms, eliminating the need for any
26 | hand-crafted pre-computed features.
27 |
28 | To showcase the performance of our approach, we test it
29 | on two tasks in the field of protein structural bioinformatics:
30 | the identification of interaction sites and the prediction
31 | of protein-protein interactions. On both tasks, we achieve
32 | state-of-the-art performance with much faster run times and
33 | fewer parameters than previous models. These results will
34 | considerably ease the deployment of deep learning methods
35 | in protein science and open the door for end-to-end differentiable
36 | approaches in protein modeling tasks such as function
37 | prediction and design.
38 |
39 | ## Hardware requirements
40 |
41 | Models have been trained on either a single NVIDIA RTX 2080 Ti or a single Tesla V100 GPU. Time and memory benchmarks were performed on a single Tesla V100.
42 |
43 | ## Software prerequisites
44 |
45 | Scripts have been tested using the following two sets of core dependencies:
46 |
47 | | Dependency | First Option | Second Option |
48 | | ------------- | ------------- | ------------- |
49 | | GCC | 7.5.0 | 8.4.0 |
50 | | CMAKE | 3.10.2 | 3.16.5 |
51 | | CUDA | 10.0.130 | 10.2.89 |
52 | | cuDNN | 7.6.4.38 | 7.6.5.32 |
53 | | Python | 3.6.9 | 3.7.7 |
54 | | PyTorch | 1.4.0 | 1.6.0 |
55 | | PyKeops | 1.4 | 1.4.1 |
56 | | PyTorch Geometric | 1.5.0 | 1.6.1 |
57 |
58 |
59 | ## Code overview
60 |
61 |
62 | Usage:
63 | - In order to **train models**, run `main_training.py` with the appropriate flags.
64 | Available flags and their descriptions can be found in `Arguments.py`.
65 |
66 | - The command line options needed to reproduce the **benchmarks** can be found in `benchmark_scripts/`.
67 |
68 | - To make **inference** on the testing set using pretrained models, use `main_inference.py` with the flags that were used for training the models.
69 | Note that the `--experiment_name` flag should be modified to specify the training epoch to use.
70 |
71 | Implementation:
72 | - Our **surface generation** algorithm, **curvature** estimation method and **quasi-geodesic convolutions** are implemented in `geometry_processing.py`.
73 |
74 | - The **definition of the neural network** along with surface and input features can be found in `model.py`. The convolutional layers are implemented in `benchmark_models.py`.
75 |
76 | - The scripts used to **generate the figures** of the paper can be found in `data_analysis/`.
77 |
78 |
79 | ## License
80 |
81 | 
This work is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License.
82 |
83 | ## Reference
84 |
85 | Sverrisson, F., Feydy, J., Correia, B. E., & Bronstein, M. M. (2020). Fast end-to-end learning on protein surfaces. [bioRxiv](https://www.biorxiv.org/content/10.1101/2020.12.28.424589v1).
--------------------------------------------------------------------------------
/benchmark_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 | from pykeops.torch import LazyTensor
4 | from torch_geometric.nn import EdgeConv, Reshape
5 |
6 | from torch_cluster import knn
7 |
8 | from math import ceil
9 | from torch_geometric.nn.inits import reset
10 |
11 | from torch.nn import ELU, Conv1d
12 | from torch.nn import Sequential as S, Linear as L, BatchNorm1d as BN
13 |
14 |
def ranges_slices(batch):
    """Helper function for the diagonal ranges function.

    Given a batch vector (e.g. [0, 0, 1, 1, 1]), returns the int32
    [start, end) index ranges of each batch element and the cumulative
    slice indices expected by KeOps.
    """
    counts = batch.bincount()
    ends = counts.cumsum(0)
    starts = torch.cat((torch.zeros_like(ends[:1]), ends[:-1]))
    # One (start, end) row per batch element, as contiguous int32.
    ranges = torch.stack((starts, ends), dim=1).int().contiguous().to(batch.device)
    slices = torch.arange(1, len(counts) + 1).int().to(batch.device)
    return ranges, slices
26 |
27 |
def diagonal_ranges(batch_x=None, batch_y=None):
    """Encodes the block-diagonal structure associated to a batch vector."""

    # No batch information at all: let KeOps fall back to a dense reduction.
    if batch_x is None and batch_y is None:
        return None

    x_ranges, x_slices = ranges_slices(batch_x)
    y_ranges, y_slices = ranges_slices(batch_y)

    # KeOps "ranges" sextuplet (ranges_i, slices_i, redranges_j,
    # ranges_j, slices_j, redranges_i) describing a block-diagonal mask.
    return x_ranges, x_slices, y_ranges, y_ranges, y_slices, x_ranges
38 |
39 |
@torch.jit.ignore
def keops_knn(
    x: torch.Tensor,
    y: torch.Tensor,
    k: int,
    batch_x: Optional[torch.Tensor] = None,
    batch_y: Optional[torch.Tensor] = None,
    cosine: bool = False,
) -> torch.Tensor:
    r"""Straightforward modification of PyTorch_geometric's knn method."""

    # Promote 1D inputs to (N, 1) point clouds.
    if x.dim() == 1:
        x = x.view(-1, 1)
    if y.dim() == 1:
        y = y.view(-1, 1)

    y_i = LazyTensor(y[:, None, :])
    x_j = LazyTensor(x[None, :, :])

    # Cosine mode negates the dot product so that argKmin still picks the
    # most similar points; otherwise use the squared Euclidean distance.
    D_ij = -(y_i | x_j) if cosine else ((y_i - x_j) ** 2).sum(-1)

    # Restrict the reduction to the block-diagonal batch structure.
    D_ij.ranges = diagonal_ranges(batch_y, batch_x)
    neighbors = D_ij.argKmin(k, dim=1)  # (N, K) indices into x, one row per y

    # Repeat each query index k times to pair with its k neighbours.
    queries = torch.arange(k * len(y), device=neighbors.device) // k

    return torch.stack([queries, neighbors.view(-1)], dim=0)
68 |
69 |
70 | knns = {"torch": knn, "keops": keops_knn}
71 |
72 |
@torch.jit.ignore
def knn_graph(
    x: torch.Tensor,
    k: int,
    batch: Optional[torch.Tensor] = None,
    loop: bool = False,
    flow: str = "source_to_target",
    cosine: bool = False,
    target: Optional[torch.Tensor] = None,
    batch_target: Optional[torch.Tensor] = None,
    backend: str = "torch",
) -> torch.Tensor:
    r"""k-NN graph builder, modified from PyTorch_geometric to support a
    separate `target` point cloud and a choice of backend ("torch"/"keops")."""

    assert flow in ["source_to_target", "target_to_source"]

    # By default, build a self-graph on x:
    if target is None:
        target = x
    if batch_target is None:
        batch_target = batch

    # Ask for one extra neighbor when self-loops are to be removed:
    num_neighbors = k if loop else k + 1
    row, col = knns[backend](
        x, target, num_neighbors, batch, batch_target, cosine=cosine
    )

    if flow == "source_to_target":
        row, col = col, row
    if not loop:
        keep = row != col
        row, col = row[keep], col[keep]

    return torch.stack([row, col], dim=0)
101 |
102 |
class MyDynamicEdgeConv(EdgeConv):
    r"""DynamicEdgeConv variant whose k-NN graph is built with the KeOps backend."""

    def __init__(self, nn, k, aggr="max", **kwargs):
        super(MyDynamicEdgeConv, self).__init__(nn=nn, aggr=aggr, **kwargs)
        self.k = k  # Number of neighbors per point

    def forward(self, x, batch=None):
        """Recompute the k-NN graph on the current features, then apply EdgeConv."""
        graph = knn_graph(
            x, self.k, batch, loop=False, flow=self.flow, backend="keops"
        )
        return super(MyDynamicEdgeConv, self).forward(x, graph)

    def __repr__(self):
        return "{}(nn={}, k={})".format(self.__class__.__name__, self.nn, self.k)
119 |
120 |
class MyXConv(torch.nn.Module):
    """PointCNN-style X-Convolution, adapted from torch_geometric's XConv to
    support distinct source/target point clouds and a selectable k-NN backend.

    mlp1 lifts relative positions to C_delta features, mlp2 learns a (K, K)
    transformation of each neighborhood, and `conv` is a depthwise-separable
    convolution producing the output features.
    """

    def __init__(
        self,
        in_channels=None,
        out_channels=None,
        dim=None,
        kernel_size=None,
        hidden_channels=None,
        dilation=1,
        bias=True,
        backend="torch",
    ):
        super(MyXConv, self).__init__()

        self.in_channels = in_channels
        # Default hidden width: a quarter of the input channels, at least 1.
        if hidden_channels is None:
            hidden_channels = in_channels // 4
        if hidden_channels == 0:
            hidden_channels = 1

        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.dim = dim  # Dimension of the ambient space (e.g. 3 for point clouds)
        self.kernel_size = kernel_size  # K: neighborhood size
        self.dilation = dilation  # Sub-sampling factor of the K-NN graph
        self.backend = backend

        C_in, C_delta, C_out = in_channels, hidden_channels, out_channels
        D, K = dim, kernel_size

        # Lifts a relative position (D,) to C_delta point-wise features:
        self.mlp1 = S(
            L(dim, C_delta),
            ELU(),
            BN(C_delta),
            L(C_delta, C_delta),
            ELU(),
            BN(C_delta),
            Reshape(-1, K, C_delta),
        )

        # Predicts a (K, K) neighborhood-mixing matrix from the K relative positions:
        self.mlp2 = S(
            L(D * K, K ** 2),
            ELU(),
            BN(K ** 2),
            Reshape(-1, K, K),
            Conv1d(K, K ** 2, K, groups=K),
            ELU(),
            BN(K ** 2),
            Reshape(-1, K, K),
            Conv1d(K, K ** 2, K, groups=K),
            BN(K ** 2),
            Reshape(-1, K, K),
        )

        # Final depthwise-separable convolution over the transformed neighborhood:
        C_in = C_in + C_delta
        depth_multiplier = int(ceil(C_out / C_in))
        self.conv = S(
            Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in),
            Reshape(-1, C_in * depth_multiplier),
            L(C_in * depth_multiplier, C_out, bias=bias),
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of the three sub-networks."""
        reset(self.mlp1)
        reset(self.mlp2)
        reset(self.conv)

    def forward(self, x, source, batch_source, target, batch_target):
        """Compute X-Conv features on `target` points from `source` points.

        x: (Nin, C_in) source features or None; source: (Nin, D) positions;
        target: (N, D) positions; batch_* follow PyG conventions.
        Returns an (N, C_out) feature tensor.
        """
        # Load data shapes:
        # pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
        (Nin, _), (N, D), K = source.size(), target.size(), self.kernel_size

        # Compute K-nn:
        row, col = knn_graph(
            source,
            K * self.dilation,
            batch_source,
            loop=True,
            flow="target_to_source",
            target=target,
            batch_target=batch_target,
            backend=self.backend,
        )
        # row is a vector of size N*K*dilation that indexes "target"
        # col is a vector of size N*K*dilation that indexes "source"

        # If needed, sup-sample the K-NN graph:
        if self.dilation > 1:
            # Keep K random edges out of each group of K*dilation:
            dil = self.dilation
            index = torch.randint(
                K * dil,
                (N, K),
                dtype=torch.long,
                layout=torch.strided,
                device=row.device,
            )
            arange = torch.arange(N, dtype=torch.long, device=row.device)
            arange = arange * (K * dil)
            index = (index + arange.view(-1, 1)).view(-1)  # (N*K,)
            row, col = row[index], col[index]

        # assert row.max() < N
        # assert col.max() < Nin

        # Line 1: local difference vector:
        pos = source[col] - target[row]  # (N * K, D)

        # Line 2: compute F_delta
        x_star = self.mlp1(pos.view(N * K, D))

        # Line 3: concatenate the features and reshape:
        if x is not None:
            x = x.unsqueeze(-1) if x.dim() == 1 else x
            x = x[col].view(N, K, self.in_channels)
            x_star = torch.cat([x_star, x], dim=-1)
        # NOTE(review): the view below always assumes in_channels + hidden_channels
        # feature columns — x=None looks unsupported on this path; confirm callers.
        x_star = x_star.transpose(1, 2).contiguous()
        x_star = x_star.view(N, self.in_channels + self.hidden_channels, K, 1)

        # Line 4: Compute the transformation matrix:
        transform_matrix = self.mlp2(pos.view(N, K * D))
        transform_matrix = transform_matrix.view(N, 1, K, K)

        # Line 5: Apply it to the neighborhood:
        x_transformed = torch.matmul(transform_matrix, x_star)
        x_transformed = x_transformed.view(N, -1, K)  # (N, I+H, K)

        # Line 6: Apply the convolution filter:
        out = self.conv(x_transformed)  # (N, Cout)

        return out

    def __repr__(self):
        return "{}({}, {})".format(
            self.__class__.__name__, self.in_channels, self.out_channels
        )
259 |
--------------------------------------------------------------------------------
/benchmark_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import (
5 | Sequential as Seq,
6 | Dropout,
7 | Linear as Lin,
8 | LeakyReLU,
9 | ReLU,
10 | BatchNorm1d as BN,
11 | )
12 | import torch_geometric.transforms as T
13 | from torch_geometric.data import DataLoader
14 | from torch_geometric.nn import (
15 | DynamicEdgeConv,
16 | PointConv,
17 | XConv,
18 | fps,
19 | radius,
20 | global_max_pool,
21 | knn_interpolate,
22 | )
23 | from pykeops.torch import LazyTensor
24 |
25 | from benchmark_layers import MyDynamicEdgeConv, MyXConv
26 | from geometry_processing import dMaSIFConv, mesh_normals_areas, tangent_vectors
27 | from helper import diagonal_ranges
28 |
29 | DEConv = {"torch": DynamicEdgeConv, "keops": MyDynamicEdgeConv}
30 |
31 | # Dynamic Graph CNNs ===========================================================
32 | # Adapted from the PyTorch_geometric gallery to get a close fit to
33 | # the original paper.
34 |
35 |
def MLP(channels, batch_norm=True):
    """Multi-layer perceptron: Linear -> BatchNorm (optional) -> LeakyReLU(0.2)
    for every consecutive pair of widths in `channels`."""
    blocks = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        norm = BN(c_out) if batch_norm else nn.Identity()
        blocks.append(Seq(Lin(c_in, c_out), norm, LeakyReLU(negative_slope=0.2)))
    return Seq(*blocks)
48 |
49 |
class DGCNN_seg(torch.nn.Module):
    """Dynamic Graph CNN for point segmentation.

    A TransformNet first predicts a per-point 3x3 alignment of the coordinates;
    the aligned coordinates are concatenated to the input features and pushed
    through `n_layers` residual EdgeConv blocks.
    """

    def __init__(
        self, in_channels, out_channels, n_layers, k=40, aggr="max", backend="keops"
    ):
        super(DGCNN_seg, self).__init__()

        self.name = "DGCNN_seg_" + backend
        self.I, self.O = (
            in_channels + 3,
            out_channels,
        )  # Add coordinates to input channels
        self.n_layers = n_layers

        # TransformNet sub-network (predicts the 3x3 alignment):
        self.transform_1 = DEConv[backend](MLP([2 * 3, 64, 128]), k, aggr)
        self.transform_2 = MLP([128, 1024])
        self.transform_3 = MLP([1024, 512, 256], batch_norm=False)
        self.transform_4 = Lin(256, 3 * 3)

        # Main EdgeConv stack (first layer maps I -> O, the rest O -> O):
        self.conv_layers = nn.ModuleList(
            [DEConv[backend](MLP([2 * self.I, self.O, self.O]), k, aggr)]
            + [
                DEConv[backend](MLP([2 * self.O, self.O, self.O]), k, aggr)
                for i in range(n_layers - 1)
            ]
        )

        # Per-layer point-wise MLPs applied to the EdgeConv outputs:
        self.linear_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(self.O, self.O), nn.ReLU(), nn.Linear(self.O, self.O)
                )
                for i in range(n_layers)
            ]
        )

        # Linear projections for the residual (skip) connections:
        self.linear_transform = nn.ModuleList(
            [nn.Linear(self.I, self.O)]
            + [nn.Linear(self.O, self.O) for i in range(n_layers - 1)]
        )

    def forward(self, positions, features, batch_indices):
        """Embed each point: positions (N, 3), features (N, C), batch (N,) -> (N, O)."""
        # Lab: (B,), Pos: (N, 3), Batch: (N,)
        pos, feat, batch = positions, features, batch_indices

        # TransformNet:
        x = pos  # Don't use the normals!

        x = self.transform_1(x, batch)  # (N, 3) -> (N, 128)
        x = self.transform_2(x)  # (N, 128) -> (N, 1024)
        x = global_max_pool(x, batch)  # (B, 1024)

        x = self.transform_3(x)  # (B, 256)
        x = self.transform_4(x)  # (B, 3*3)
        x = x[batch]  # (N, 3*3)
        x = x.view(-1, 3, 3)  # (N, 3, 3)

        # Apply the transform:
        x0 = torch.einsum("ni,nij->nj", pos, x)  # (N, 3)

        # Add features to coordinates
        x = torch.cat([x0, feat], dim=-1).contiguous()

        # Residual EdgeConv blocks: x <- Linear(x) + MLP(EdgeConv(x)):
        for i in range(self.n_layers):
            x_i = self.conv_layers[i](x, batch)
            x_i = self.linear_layers[i](x_i)
            x = self.linear_transform[i](x)
            x = x + x_i

        return x
119 |
120 |
121 | # Reference PointNet models, from the PyTorch_geometric gallery =========================
122 |
123 |
class SAModule(torch.nn.Module):
    """PointNet++ set-abstraction module (FPS subsampling currently disabled)."""

    def __init__(self, ratio, r, nn, max_num_neighbors=64):
        super(SAModule, self).__init__()
        self.ratio = ratio  # FPS keep-ratio (unused while FPS is off)
        self.r = r  # Ball-query radius
        self.conv = PointConv(nn)
        self.max_num_neighbors = max_num_neighbors

    def forward(self, x, pos, batch):
        # Farthest-point sampling is switched off for now: keep every point.
        # (Original: idx = fps(pos, batch, ratio=self.ratio))
        idx = torch.arange(0, len(pos), device=pos.device)

        # Ball query: up to max_num_neighbors points within radius r of each center.
        row, col = radius(
            pos,
            pos[idx],
            self.r,
            batch,
            batch[idx],
            max_num_neighbors=self.max_num_neighbors,
        )

        # PointNet++ convolution over the gathered neighborhoods:
        edge_index = torch.stack([col, row], dim=0)
        feats = self.conv(x, (pos, pos[idx]), edge_index)

        # Features and (sub-sampled) point cloud:
        return feats, pos[idx], batch[idx]
156 |
157 |
class GlobalSAModule(torch.nn.Module):
    """Global set abstraction: pools every point of a batch element into a
    single descriptor located at the origin."""

    def __init__(self, nn):
        super(GlobalSAModule, self).__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        feats = self.nn(torch.cat([x, pos], dim=1))
        feats = global_max_pool(feats, batch)  # (B, C)
        # One dummy point at the origin per batch element:
        new_pos = pos.new_zeros((feats.size(0), 3))
        new_batch = torch.arange(feats.size(0), device=batch.device)
        return feats, new_pos, new_batch
169 |
170 |
class FPModule(torch.nn.Module):
    """PointNet++ feature propagation: upsamples features onto a denser point
    set via k-NN interpolation, concatenates skip features, applies an MLP."""

    def __init__(self, k, nn):
        super(FPModule, self).__init__()
        self.k = k
        self.nn = nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        upsampled = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        if x_skip is not None:
            upsampled = torch.cat([upsampled, x_skip], dim=1)
        return self.nn(upsampled), pos_skip, batch_skip
183 |
184 |
class PointNet2_seg(torch.nn.Module):
    """PointNet++ segmentation backbone: a stack of residual set-abstraction
    blocks (no downsampling, full-radius neighborhoods)."""

    def __init__(self, args, in_channels, out_channels):
        super(PointNet2_seg, self).__init__()

        self.name = "PointNet2"
        self.I, self.O = in_channels, out_channels
        self.radius = args.radius
        self.k = 10000  # We don't restrict the number of points in a patch
        self.n_layers = args.n_layers

        # self.sa1_module = SAModule(1.0, self.radius, MLP([self.I+3, self.O, self.O]),self.k)
        # Set-abstraction layers (first maps I+3 -> O, the rest O+3 -> O):
        self.layers = nn.ModuleList(
            [SAModule(1.0, self.radius, MLP([self.I + 3, self.O, self.O]), self.k)]
            + [
                SAModule(1.0, self.radius, MLP([self.O + 3, self.O, self.O]), self.k)
                for i in range(self.n_layers - 1)
            ]
        )

        # Per-layer point-wise MLPs applied to each SA output:
        self.linear_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(self.O, self.O), nn.ReLU(), nn.Linear(self.O, self.O)
                )
                for i in range(self.n_layers)
            ]
        )

        # Linear projections for the residual (skip) connections:
        self.linear_transform = nn.ModuleList(
            [nn.Linear(self.I, self.O)]
            + [nn.Linear(self.O, self.O) for i in range(self.n_layers - 1)]
        )

    def forward(self, positions, features, batch_indices):
        """features (N, I), positions (N, 3), batch (N,) -> embeddings (N, O)."""
        # x threads a (features, positions, batch) triple through the layers:
        x = (features, positions, batch_indices)
        for i, layer in enumerate(self.layers):
            x_i, pos, b_ind = layer(*x)
            x_i = self.linear_layers[i](x_i)
            # Residual connection: Linear(features) + MLP(SA(features)):
            x = self.linear_transform[i](x[0])
            x = x + x_i
            x = (x, pos, b_ind)

        return x[0]
228 |
229 |
230 | ## TangentConv benchmark segmentation
231 |
232 |
class dMaSIFConv_seg(torch.nn.Module):
    """Stack of residual dMaSIF (quasi-geodesic) convolutions for surface
    segmentation. `load_mesh` must be called before `forward` to set the
    point cloud, local tangent frames and batch ranges."""

    def __init__(self, args, in_channels, out_channels, n_layers, radius=9.0):
        # NOTE(review): `args` is accepted but unused in this constructor.
        super(dMaSIFConv_seg, self).__init__()

        self.name = "dMaSIFConv_seg_keops"
        self.radius = radius  # Convolution support, also used to rescale coordinates
        self.I, self.O = in_channels, out_channels

        # dMaSIF convolution layers (first maps I -> O, the rest O -> O):
        self.layers = nn.ModuleList(
            [dMaSIFConv(self.I, self.O, radius, self.O)]
            + [dMaSIFConv(self.O, self.O, radius, self.O) for i in range(n_layers - 1)]
        )

        # Per-layer point-wise MLPs applied to each convolution output:
        self.linear_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(self.O, self.O), nn.ReLU(), nn.Linear(self.O, self.O)
                )
                for i in range(n_layers)
            ]
        )

        # Linear projections for the residual (skip) connections:
        self.linear_transform = nn.ModuleList(
            [nn.Linear(self.I, self.O)]
            + [nn.Linear(self.O, self.O) for i in range(n_layers - 1)]
        )

    def forward(self, features):
        """Embed per-point features (N, I) -> (N, O), using the geometry
        previously stored by `load_mesh` (points, nuv frames, ranges)."""
        # Lab: (B,), Pos: (N, 3), Batch: (N,)
        points, nuv, ranges = self.points, self.nuv, self.ranges
        x = features
        # Residual blocks: x <- Linear(x) + MLP(dMaSIFConv(x)):
        for i, layer in enumerate(self.layers):
            x_i = layer(points, nuv, x, ranges)
            x_i = self.linear_layers[i](x_i)
            x = self.linear_transform[i](x)
            x = x + x_i

        return x

    def load_mesh(self, xyz, triangles=None, normals=None, weights=None, batch=None):
        """Loads the geometry of a triangle mesh.

        Input arguments:
        - xyz, a point cloud encoded as an (N, 3) Tensor.
        - triangles, a connectivity matrix encoded as an (N, 3) integer tensor.
        - weights, importance weights for the orientation estimation, encoded as an (N, 1) Tensor.
        - radius, the scale used to estimate the local normals.
        - a batch vector, following PyTorch_Geometric's conventions.

        The routine updates the model attributes:
        - points, i.e. the point cloud itself,
        - nuv, a local oriented basis in R^3 for every point,
        - ranges, custom KeOps syntax to implement batch processing.
        """

        # 1. Save the vertices for later use in the convolutions ---------------
        self.points = xyz
        self.batch = batch
        self.ranges = diagonal_ranges(
            batch
        )  # KeOps support for heterogeneous batch processing
        self.triangles = triangles
        self.normals = normals
        self.weights = weights

        # 2. Estimate the normals and tangent frame ----------------------------
        # Normalize the scale:
        points = xyz / self.radius

        # Normals and local areas:
        if normals is None:
            # NOTE(review): `areas` is computed but never used below.
            normals, areas = mesh_normals_areas(points, triangles, 0.5, batch)
        tangent_bases = tangent_vectors(normals)  # Tangent basis (N, 2, 3)

        # 3. Steer the tangent bases according to the gradient of "weights" ----

        # 3.a) Encoding as KeOps LazyTensors:
        # Orientation scores:
        weights_j = LazyTensor(weights.view(1, -1, 1))  # (1, N, 1)
        # Vertices:
        x_i = LazyTensor(points[:, None, :])  # (N, 1, 3)
        x_j = LazyTensor(points[None, :, :])  # (1, N, 3)
        # Normals:
        n_i = LazyTensor(normals[:, None, :])  # (N, 1, 3)
        n_j = LazyTensor(normals[None, :, :])  # (1, N, 3)
        # Tangent basis:
        uv_i = LazyTensor(tangent_bases.view(-1, 1, 6))  # (N, 1, 6)

        # 3.b) Pseudo-geodesic window:
        # Pseudo-geodesic squared distance:
        rho2_ij = ((x_j - x_i) ** 2).sum(-1) * ((2 - (n_i | n_j)) ** 2)  # (N, N, 1)
        # Gaussian window:
        window_ij = (-rho2_ij).exp()  # (N, N, 1)

        # 3.c) Coordinates in the (u, v) basis - not oriented yet:
        X_ij = uv_i.matvecmult(x_j - x_i)  # (N, N, 2)

        # 3.d) Local average in the tangent plane:
        orientation_weight_ij = window_ij * weights_j  # (N, N, 1)
        orientation_vector_ij = orientation_weight_ij * X_ij  # (N, N, 2)

        # Support for heterogeneous batch processing:
        orientation_vector_ij.ranges = self.ranges  # Block-diagonal sparsity mask

        orientation_vector_i = orientation_vector_ij.sum(dim=1)  # (N, 2)
        orientation_vector_i = (
            orientation_vector_i + 1e-5
        )  # Just in case someone's alone...

        # 3.e) Normalize stuff:
        orientation_vector_i = F.normalize(orientation_vector_i, p=2, dim=-1)  # (N, 2)
        ex_i, ey_i = (
            orientation_vector_i[:, 0][:, None],
            orientation_vector_i[:, 1][:, None],
        )  # (N,1)

        # 3.f) Re-orient the (u,v) basis:
        uv_i = tangent_bases  # (N, 2, 3)
        u_i, v_i = uv_i[:, 0, :], uv_i[:, 1, :]  # (N, 3)
        tangent_bases = torch.cat(
            (ex_i * u_i + ey_i * v_i, -ey_i * u_i + ex_i * v_i), dim=1
        ).contiguous()  # (N, 6)

        # 4. Store the local 3D frame as an attribute --------------------------
        self.nuv = torch.cat(
            (normals.view(-1, 1, 3), tangent_bases.view(-1, 2, 3)), dim=1
        )
360 |
--------------------------------------------------------------------------------
/benchmark_scripts/DGCNN_site.sh:
--------------------------------------------------------------------------------
1 | # Load environment
2 | python -W ignore -u main_training.py --experiment_name DGCNN_site_1layer_k200 --batch_size 64 --embedding_layer DGCNN --site True --single_protein True --device cuda:0 --n_layers 1 --random_rotation True --k 200
3 | python -W ignore -u main_training.py --experiment_name DGCNN_site_1layer_k100 --batch_size 64 --embedding_layer DGCNN --site True --single_protein True --device cuda:0 --n_layers 1 --random_rotation True --k 100
4 |
5 | python -W ignore -u main_training.py --experiment_name DGCNN_site_3layer_k200 --batch_size 64 --embedding_layer DGCNN --site True --single_protein True --device cuda:0 --n_layers 3 --random_rotation True --k 200
6 | python -W ignore -u main_training.py --experiment_name DGCNN_site_3layer_k100 --batch_size 64 --embedding_layer DGCNN --site True --single_protein True --device cuda:0 --n_layers 3 --random_rotation True --k 100
--------------------------------------------------------------------------------
/benchmark_scripts/Pointnet_site.sh:
--------------------------------------------------------------------------------
1 | # Load environment
2 | python -W ignore -u main_training.py --experiment_name PointNet_site_3layer_15A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 15.0 --n_layers 3
3 | python -W ignore -u main_training.py --experiment_name PointNet_site_3layer_5A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 5.0 --n_layers 3
4 | python -W ignore -u main_training.py --experiment_name PointNet_site_3layer_9A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 9.0 --n_layers 3
5 |
6 | python -W ignore -u main_training.py --experiment_name PointNet_site_1layer_15A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 15.0 --n_layers 1
7 | python -W ignore -u main_training.py --experiment_name PointNet_site_1layer_5A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 5.0 --n_layers 1
8 | python -W ignore -u main_training.py --experiment_name PointNet_site_1layer_9A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 9.0 --n_layers 1
--------------------------------------------------------------------------------
/benchmark_scripts/dMaSIF_search.sh:
--------------------------------------------------------------------------------
# Load environment
# Benchmark: dMaSIF interaction-search models with a 12 A convolution radius,
# for 1 and 3 dMaSIF layers (batch size 64, random rotations on).
python -W ignore -u main_training.py --experiment_name dMaSIF_search_1layer_12A --batch_size 64 --embedding_layer dMaSIF --search True --device cuda:0 --random_rotation True --radius 12.0 --n_layers 1
python -W ignore -u main_training.py --experiment_name dMaSIF_search_3layer_12A --batch_size 64 --embedding_layer dMaSIF --search True --device cuda:0 --random_rotation True --radius 12.0 --n_layers 3

--------------------------------------------------------------------------------
/benchmark_scripts/dMaSIF_site.sh:
--------------------------------------------------------------------------------
1 | # Load environment
2 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_1layer_15A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 15.0 --n_layers 1
3 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_1layer_5A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 5.0 --n_layers 1
4 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_1layer_9A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 9.0 --n_layers 1
5 |
6 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_3layer_15A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 15.0 --n_layers 3
7 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_3layer_5A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 5.0 --n_layers 3
8 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_3layer_9A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 9.0 --n_layers 3
9 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import InMemoryDataset, Data, DataLoader
3 | from torch_geometric.transforms import Compose
4 | import numpy as np
5 | from scipy.spatial.transform import Rotation
6 | import math
7 | import urllib.request
8 | import tarfile
9 | from pathlib import Path
10 | import requests
11 | from data_preprocessing.convert_pdb2npy import convert_pdbs
12 | from data_preprocessing.convert_ply2npy import convert_plys
13 |
# Shorthand constructors used throughout this module:
# float32 CPU tensors for coordinates/features, int64 for indices.
tensor = torch.FloatTensor
inttensor = torch.LongTensor
16 |
17 |
def numpy(x):
    """Detach a torch Tensor from the graph, move it to CPU, return a NumPy array."""
    detached = x.detach()
    return detached.cpu().numpy()
20 |
21 |
def iface_valid_filter(protein_pair):
    """Keep a protein pair only if both interfaces are informative:
    fewer than 75% positive labels, more than 30 positives, and at least
    1% of the partner's surface size."""
    y1 = protein_pair.y_p1.reshape(-1)
    y2 = protein_pair.y_p2.reshape(-1)
    n1 = torch.sum(y1)
    n2 = torch.sum(y2)

    ok1 = (n1 < 0.75 * len(y1)) and (n1 > 30) and (n1 > 0.01 * y2.shape[0])
    ok2 = (n2 < 0.75 * len(y2)) and (n2 > 30) and (n2 > 0.01 * y1.shape[0])

    return ok1 and ok2
37 |
38 |
class RandomRotationPairAtoms(object):
    r"""Applies an independent uniform random 3D rotation to each protein of a pair."""

    def __call__(self, data):
        # Draw one random rotation matrix per protein:
        rot1 = tensor(Rotation.random().as_matrix())
        rot2 = tensor(Rotation.random().as_matrix())

        # Rotate atoms, surface points and normals of protein 1...
        for key in ("atom_coords_p1", "xyz_p1", "normals_p1"):
            setattr(data, key, torch.matmul(rot1, getattr(data, key).T).T)
        # ... and of protein 2:
        for key in ("atom_coords_p2", "xyz_p2", "normals_p2"):
            setattr(data, key, torch.matmul(rot2, getattr(data, key).T).T)

        # Keep the matrices around so predictions can be un-rotated later:
        data.rand_rot1 = rot1
        data.rand_rot2 = rot2
        return data

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
60 |
61 |
class CenterPairAtoms(object):
    r"""Translates each protein so that its atom centroid sits at the origin."""

    def __call__(self, data):
        for idx in ("1", "2"):
            coords = getattr(data, "atom_coords_p" + idx)
            center = coords.mean(dim=-2, keepdim=True)

            # Shift atoms and surface points by the same centroid:
            setattr(data, "atom_coords_p" + idx, coords - center)
            setattr(data, "xyz_p" + idx, getattr(data, "xyz_p" + idx) - center)

            # Remember the offset so coordinates can be restored later:
            setattr(data, "atom_center" + idx, center)
        return data

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
81 |
82 |
class NormalizeChemFeatures(object):
    r"""Rescales the chemical feature channels of both proteins:
    Poisson-Boltzmann potential is clamped to [-3, 3] and mapped to [-1, 1],
    hydrophobicity is divided by 4.5, hydrogen-bond potential is untouched."""

    def __call__(self, data):
        pb_upper = 3.0
        pb_lower = -3.0

        def _normalize(chem):
            # Channels: [Poisson-Boltzmann, H-bond, hydrophobicity]
            pb, hb, hp = chem[:, 0], chem[:, 1], chem[:, 2]

            pb = torch.clamp(pb, pb_lower, pb_upper)
            pb = (pb - pb_lower) / (pb_upper - pb_lower)
            pb = 2 * pb - 1  # -> [-1, 1]

            hp = hp / 4.5  # Kyte-Doolittle-style scale

            return torch.stack([pb, hb, hp]).T

        data.chemical_features_p1 = _normalize(data.chemical_features_p1)
        data.chemical_features_p2 = _normalize(data.chemical_features_p2)
        return data

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
120 |
121 |
def load_protein_npy(pdb_id, data_dir, center=False, single_pdb=False):
    """Loads a protein surface mesh and its features from per-protein .npy dumps.

    When `single_pdb` is True only the atom-level arrays are loaded; all
    surface-level fields are set to None.
    """

    def _load(suffix):
        return np.load(data_dir / (pdb_id + suffix))

    # Surface data is only available for full (non-single) dumps:
    triangles = None if single_pdb else inttensor(_load("_triangles.npy")).T
    points = None if single_pdb else tensor(_load("_xyz.npy"))
    center_location = None if single_pdb else torch.mean(points, axis=0, keepdims=True)

    atom_coords = tensor(_load("_atomxyz.npy"))
    atom_types = tensor(_load("_atomtypes.npy"))

    # NOTE(review): center=True together with single_pdb=True would fail here
    # (points is None) — same as the original; confirm callers never do this.
    if center:
        points = points - center_location
        atom_coords = atom_coords - center_location

    # Interface labels, chemistry and normals (surface-level):
    iface_labels = (
        None if single_pdb else tensor(_load("_iface_labels.npy").reshape((-1, 1)))
    )
    chemical_features = None if single_pdb else tensor(_load("_features.npy"))
    normals = None if single_pdb else tensor(_load("_normals.npy"))

    return Data(
        xyz=points,
        face=triangles,
        chemical_features=chemical_features,
        y=iface_labels,
        normals=normals,
        center_location=center_location,
        num_nodes=None if single_pdb else points.shape[0],
        atom_coords=atom_coords,
        atom_types=atom_types,
    )
171 |
172 |
class PairData(Data):
    """A PyTorch_Geometric Data object holding a *pair* of protein surfaces.

    Every per-protein field is duplicated with a `_p1` / `_p2` suffix so the
    pair can be batched as a single graph object.
    """

    def __init__(
        self,
        xyz_p1=None,
        xyz_p2=None,
        face_p1=None,
        face_p2=None,
        chemical_features_p1=None,
        chemical_features_p2=None,
        y_p1=None,
        y_p2=None,
        normals_p1=None,
        normals_p2=None,
        center_location_p1=None,
        center_location_p2=None,
        atom_coords_p1=None,
        atom_coords_p2=None,
        atom_types_p1=None,
        atom_types_p2=None,
        atom_center1=None,
        atom_center2=None,
        rand_rot1=None,
        rand_rot2=None,
    ):
        super().__init__()
        self.xyz_p1 = xyz_p1
        self.xyz_p2 = xyz_p2
        self.face_p1 = face_p1
        self.face_p2 = face_p2

        self.chemical_features_p1 = chemical_features_p1
        self.chemical_features_p2 = chemical_features_p2
        self.y_p1 = y_p1
        self.y_p2 = y_p2
        self.normals_p1 = normals_p1
        self.normals_p2 = normals_p2
        self.center_location_p1 = center_location_p1
        self.center_location_p2 = center_location_p2
        self.atom_coords_p1 = atom_coords_p1
        self.atom_coords_p2 = atom_coords_p2
        self.atom_types_p1 = atom_types_p1
        self.atom_types_p2 = atom_types_p2
        self.atom_center1 = atom_center1
        self.atom_center2 = atom_center2
        self.rand_rot1 = rand_rot1
        self.rand_rot2 = rand_rot2

    def __inc__(self, key, value):
        # When batching, face indices must be offset by the number of vertices
        # of the matching protein (not the default num_nodes):
        if key == "face_p1":
            return self.xyz_p1.size(0)
        if key == "face_p2":
            return self.xyz_p2.size(0)
        else:
            return super(PairData, self).__inc__(key, value)

    def __cat_dim__(self, key, value):
        # Index-like tensors (edge indices, faces) are concatenated column-wise:
        if ("index" in key) or ("face" in key):
            return 1
        else:
            return 0
233 |
234 |
def load_protein_pair(pdb_id, data_dir, single_pdb=False):
    """Loads both halves of a protein pair ("PDB_chainA_chainB") and bundles
    them into a single PairData object."""
    parts = pdb_id.split("_")
    p1_id = parts[0] + "_" + parts[1]
    p2_id = parts[0] + "_" + parts[2]

    p1 = load_protein_npy(p1_id, data_dir, center=False, single_pdb=single_pdb)
    p2 = load_protein_npy(p2_id, data_dir, center=False, single_pdb=single_pdb)

    # Interface labels come straight from the pre-computed .npy files:
    return PairData(
        xyz_p1=p1["xyz"],
        xyz_p2=p2["xyz"],
        face_p1=p1["face"],
        face_p2=p2["face"],
        chemical_features_p1=p1["chemical_features"],
        chemical_features_p2=p2["chemical_features"],
        y_p1=p1["y"],
        y_p2=p2["y"],
        normals_p1=p1["normals"],
        normals_p2=p2["normals"],
        center_location_p1=p1["center_location"],
        center_location_p2=p2["center_location"],
        atom_coords_p1=p1["atom_coords"],
        atom_coords_p2=p2["atom_coords"],
        atom_types_p1=p1["atom_types"],
        atom_types_p2=p2["atom_types"],
    )
269 |
270 |
class ProteinPairsSurfaces(InMemoryDataset):
    """In-memory PyG dataset of interacting protein surface pairs.

    Builds PairData objects (see load_protein_pair) from the MaSIF benchmark
    archive.  With ppi=True the raw training/testing pair lists are used
    as-is; otherwise the pair splits are re-derived from the per-chain
    "site" lists, inserting a pair once per chain that belongs to the split
    (with the chains swapped for the second occurrence).
    """

    # Kept for the InMemoryDataset API; the actual URL is hard-coded in download().
    url = ""

    def __init__(self, root, ppi=False, train=True, transform=None, pre_transform=None):
        # `ppi` must be set before super().__init__(), which may trigger
        # download()/process() and reads processed_file_names (depends on ppi).
        self.ppi = ppi
        super(ProteinPairsSurfaces, self).__init__(root, transform, pre_transform)
        # processed_paths[0] holds the training split, [1] the testing split.
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self):
        # Single raw archive expected under <root>/raw/.
        return "masif_site_masif_search_pdbs_and_ply_files.tar.gz"

    @property
    def processed_file_names(self):
        # Separate caches for the site-derived pairs and the raw PPI pairs.
        if not self.ppi:
            file_names = [
                "training_pairs_data.pt",
                "testing_pairs_data.pt",
                "training_pairs_data_ids.npy",
                "testing_pairs_data_ids.npy",
            ]
        else:
            file_names = [
                "training_pairs_data_ppi.pt",
                "testing_pairs_data_ppi.pt",
                "training_pairs_data_ids_ppi.npy",
                "testing_pairs_data_ids_ppi.npy",
            ]
        return file_names

    def download(self):
        """Fetch the benchmark tarball from Zenodo into the raw directory."""
        url = 'https://zenodo.org/record/2625420/files/masif_site_masif_search_pdbs_and_ply_files.tar.gz'
        target_path = self.raw_paths[0]
        response = requests.get(url, stream=True)
        # NOTE(review): a non-200 response is silently ignored, leaving no
        # raw file behind; process() would then fail later when untarring.
        if response.status_code == 200:
            with open(target_path, 'wb') as f:
                # response.raw is the undecoded byte stream of the archive.
                f.write(response.raw.read())

        #raise RuntimeError(
        #    "Dataset not found. Please download {} from {} and move it to {}".format(
        #        self.raw_file_names, self.url, self.raw_dir
        #    )
        #)

    def process(self):
        """Convert the raw archive into collated train/test PairData tensors."""
        pdb_dir = Path(self.root) / "raw" / "01-benchmark_pdbs"
        surf_dir = Path(self.root) / "raw" / "01-benchmark_surfaces"
        protein_dir = Path(self.root) / "raw" / "01-benchmark_surfaces_npy"
        lists_dir = Path('./lists')

        # Untar surface files
        if not (pdb_dir.exists() and surf_dir.exists()):
            tar = tarfile.open(self.raw_paths[0])
            tar.extractall(self.raw_dir)
            tar.close()

        # One-off conversion of .ply surfaces and .pdb structures to .npy.
        if not protein_dir.exists():
            protein_dir.mkdir(parents=False, exist_ok=False)
            convert_plys(surf_dir,protein_dir)
            convert_pdbs(pdb_dir,protein_dir)

        # Per-chain (site task) splits:
        with open(lists_dir / "training.txt") as f_tr, open(
            lists_dir / "testing.txt"
        ) as f_ts:
            training_list = sorted(f_tr.read().splitlines())
            testing_list = sorted(f_ts.read().splitlines())

        # Pair (search/PPI task) splits:
        with open(lists_dir / "training_ppi.txt") as f_tr, open(
            lists_dir / "testing_ppi.txt"
        ) as f_ts:
            training_pairs_list = sorted(f_tr.read().splitlines())
            testing_pairs_list = sorted(f_ts.read().splitlines())
            pairs_list = sorted(training_pairs_list + testing_pairs_list)

        # Rebuild the pair splits from the per-chain splits: a pair
        # "<pdb>_<A>_<B>" joins a split once per member chain, with the
        # chains swapped when the second chain is the member.
        if not self.ppi:
            training_pairs_list = []
            for p in pairs_list:
                pspl = p.split("_")
                p1 = pspl[0] + "_" + pspl[1]
                p2 = pspl[0] + "_" + pspl[2]

                if p1 in training_list:
                    training_pairs_list.append(p)
                if p2 in training_list:
                    training_pairs_list.append(pspl[0] + "_" + pspl[2] + "_" + pspl[1])

            testing_pairs_list = []
            for p in pairs_list:
                pspl = p.split("_")
                p1 = pspl[0] + "_" + pspl[1]
                p2 = pspl[0] + "_" + pspl[2]
                if p1 in testing_list:
                    testing_pairs_list.append(p)
                if p2 in testing_list:
                    testing_pairs_list.append(pspl[0] + "_" + pspl[2] + "_" + pspl[1])

        # # Read data into huge `Data` list.
        training_pairs_data = []
        training_pairs_data_ids = []
        for p in training_pairs_list:
            try:
                protein_pair = load_protein_pair(p, protein_dir)
            except FileNotFoundError:
                # Pairs with missing .npy files are silently dropped.
                continue
            training_pairs_data.append(protein_pair)
            training_pairs_data_ids.append(p)

        testing_pairs_data = []
        testing_pairs_data_ids = []
        for p in testing_pairs_list:
            try:
                protein_pair = load_protein_pair(p, protein_dir)
            except FileNotFoundError:
                continue
            testing_pairs_data.append(protein_pair)
            testing_pairs_data_ids.append(p)

        # Optional user-supplied filtering / preprocessing hooks (PyG API):
        if self.pre_filter is not None:
            training_pairs_data = [
                data for data in training_pairs_data if self.pre_filter(data)
            ]
            testing_pairs_data = [
                data for data in testing_pairs_data if self.pre_filter(data)
            ]

        if self.pre_transform is not None:
            training_pairs_data = [
                self.pre_transform(data) for data in training_pairs_data
            ]
            testing_pairs_data = [
                self.pre_transform(data) for data in testing_pairs_data
            ]

        # Collate into one big tensor per attribute + slicing metadata,
        # then cache to the processed paths (ids saved alongside as .npy).
        training_pairs_data, training_pairs_slices = self.collate(training_pairs_data)
        torch.save(
            (training_pairs_data, training_pairs_slices), self.processed_paths[0]
        )
        np.save(self.processed_paths[2], training_pairs_data_ids)
        testing_pairs_data, testing_pairs_slices = self.collate(testing_pairs_data)
        torch.save((testing_pairs_data, testing_pairs_slices), self.processed_paths[1])
        np.save(self.processed_paths[3], testing_pairs_data_ids)
413 |
--------------------------------------------------------------------------------
/data_analysis/analyse_descriptors.py:
--------------------------------------------------------------------------------
"""Evaluate learned surface descriptors on the test PPI pairs.

For every pair, descriptor dot-products between contacting points (< 1 A
apart) are positives, and dot-products between interface points and sampled
non-contact points are negatives; a per-pair ROC-AUC is computed and the
pooled predictions are saved for plotting ROC curves.
"""
import numpy as np
from pathlib import Path
from sklearn.metrics import roc_auc_score, roc_curve
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt

# Prediction folders live one level up, under ../preds/<experiment>/.
top_dir = Path('..')
experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_c_restarted_epoch43',
                    'TangentConv_search_3L_16dim_12A_FIXED_binet_g_restarted_epoch38',
                    'TangentConv_search_1L_8dim_12A_FIXED_binet_gc_epoch34',
                    'TangentConv_search_3L_8dim_12A_FIXED_binet_gc_restarted_epoch49',
                    'TangentConv_search_1L_16dim_12A_FIXED_binet_gc_epoch45']

# Test-split pair ids, formatted "<pdb>_<chainA>_<chainB>".
with open(top_dir/'surface_data/raw/protein_surfaces/testing_ppi.txt') as f:
    testing_list = f.read().splitlines()

pdb_list = testing_list

# Fix: np.save below does not create directories — ensure the output folder.
out_dir = Path('roc_curves')
out_dir.mkdir(exist_ok=True)

for experiment_name in experiment_names:
    print(experiment_name)
    desc_dir = top_dir/f'preds/{experiment_name}'
    all_roc_aucs = []
    all_preds = []
    all_labels = []
    for i, pdb_id in enumerate(pdb_list):
        pspl = pdb_id.split('_')
        pdb_id1 = pspl[0] + '_' + pspl[1]
        pdb_id2 = pspl[0] + '_' + pspl[2]
        if i % 100 == 0:
            # Fix: np.mean([]) warns and returns NaN on the first report.
            running_mean = np.mean(all_roc_aucs) if all_roc_aucs else float('nan')
            print(i, running_mean)

        try:
            # Columns [16, 32) of the saved features hold the 16-dim descriptors.
            desc1 = np.load(desc_dir/f'{pdb_id1}_predfeatures.npy')[:, 16:16 + 16]
            desc2 = np.load(desc_dir/f'{pdb_id2}_predfeatures.npy')[:, 16:16 + 16]
            xyz1 = np.load(desc_dir/f'{pdb_id1}_predcoords.npy')
            xyz2 = np.load(desc_dir/f'{pdb_id2}_predcoords.npy')
        except FileNotFoundError:
            continue

        # Boolean contact map: point pairs closer than 1 A.
        dists = cdist(xyz1, xyz2) < 1.0
        if dists.sum() < 1:
            continue

        # Interface points: those with at least one contact on the partner.
        iface_pos1 = dists.sum(1) > 0
        iface_pos2 = dists.sum(0) > 0

        pos_dists1 = dists[iface_pos1, :]
        pos_dists2 = dists[:, iface_pos2]

        # Descriptor similarity: plain dot products.
        desc_dists = np.matmul(desc1, desc2.T)
        #desc_dists = 1/cdist(desc1,desc2)

        pos_dists = desc_dists[dists].reshape(-1)
        pos_labels = np.ones_like(pos_dists)
        # Negatives: interface points paired with non-contact points.
        neg_dists1 = desc_dists[iface_pos1, :][pos_dists1 == 0].reshape(-1)
        neg_dists2 = desc_dists[:, iface_pos2][pos_dists2 == 0].reshape(-1)

        #neg_dists = np.concatenate([neg_dists1,neg_dists2],axis=0)
        neg_dists = neg_dists1
        # Fix: np.random.choice(..., 200, replace=False) raises ValueError
        # whenever fewer than 200 negatives exist — sample what we have.
        n_neg = min(200, neg_dists.shape[0])
        if n_neg == 0:
            # ROC-AUC is undefined without negatives.
            continue
        neg_dists = np.random.choice(neg_dists, n_neg, replace=False)
        neg_labels = np.zeros_like(neg_dists)

        preds = np.concatenate([pos_dists, neg_dists])
        labels = np.concatenate([pos_labels, neg_labels])

        roc_auc = roc_auc_score(labels, preds)
        all_roc_aucs.append(roc_auc)
        all_preds.extend(list(preds))
        all_labels.extend(list(labels))

    # Pool every pair's predictions into one ROC curve per experiment.
    fpr, tpr, thresholds = roc_curve(all_labels, all_preds)
    np.save(out_dir/f'{experiment_name}_fpr.npy', fpr)
    np.save(out_dir/f'{experiment_name}_tpr.npy', tpr)
    np.save(out_dir/f'{experiment_name}_all_labels.npy', all_labels)
    np.save(out_dir/f'{experiment_name}_all_preds.npy', all_preds)
76 |
77 |
--------------------------------------------------------------------------------
/data_analysis/analyse_descriptors_para.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from sklearn.metrics import roc_auc_score, roc_curve
4 | from scipy.spatial.distance import cdist
5 | import matplotlib.pyplot as plt
6 | import dask
7 |
# Prediction folders live one level up, under ../preds/<experiment>/.
top_dir = Path('..')
experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_c_restarted_epoch43',
                    'TangentConv_search_3L_16dim_12A_FIXED_binet_g_restarted_epoch38',
                    'TangentConv_search_1L_8dim_12A_FIXED_binet_gc_epoch34',
                    'TangentConv_search_3L_8dim_12A_FIXED_binet_gc_restarted_epoch49',
                    'TangentConv_search_1L_16dim_12A_FIXED_binet_gc_epoch45']
# NOTE(review): each reassignment below overrides the previous list — only
# the last one is effective.  Kept as a record of earlier runs.
experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_gc_restarted_epoch47']
experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_gc_subsamp50_epoch25']
experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_gc_subsamp50_restarted_restarted_restarted_epoch53']


# Descriptor dimensionality per experiment (zipped with experiment_names);
# overridden to match the single experiment kept above.
ndims = [16,16,8,8,16]
ndims = [16]


# Test-split pair ids, formatted "<pdb>_<chainA>_<chainB>".
with open(top_dir/'surface_data/raw/protein_surfaces/testing_ppi.txt') as f:
    testing_list = f.read().splitlines()

pdb_list = testing_list
27 |
@dask.delayed
def analyse_pdb(pdb_id, D):
    """Score one protein pair's descriptors (lazily, via dask).

    Positives are descriptor dot-products between points of the two chains
    that lie < 1 A apart; negatives are dot-products between interface
    points and sampled non-contact points.

    Args:
        pdb_id: pair identifier "<pdb>_<chainA>_<chainB>".
        D: descriptor dimensionality (columns [16, 16+D) of the features).

    Returns:
        (roc_auc, preds, labels), or -1 when the pair must be skipped
        (missing files, no contacts, or no negative candidates).

    NOTE: reads the module-level global `desc_dir` set by the driver loop.
    """
    pspl = pdb_id.split('_')
    pdb_id1 = pspl[0] + '_' + pspl[1]
    pdb_id2 = pspl[0] + '_' + pspl[2]

    try:
        desc1 = np.load(desc_dir/f'{pdb_id1}_predfeatures.npy')[:, 16:16 + D]
        desc2 = np.load(desc_dir/f'{pdb_id2}_predfeatures.npy')[:, 16:16 + D]
        xyz1 = np.load(desc_dir/f'{pdb_id1}_predcoords.npy')
        xyz2 = np.load(desc_dir/f'{pdb_id2}_predcoords.npy')
    except FileNotFoundError:
        return -1

    # Boolean contact map: point pairs closer than 1 A.
    dists = cdist(xyz1, xyz2) < 1.0
    if dists.sum() < 1:
        return -1

    # Interface points: those with at least one contact on the partner.
    iface_pos1 = dists.sum(1) > 0
    iface_pos2 = dists.sum(0) > 0

    pos_dists1 = dists[iface_pos1, :]
    pos_dists2 = dists[:, iface_pos2]

    # Descriptor similarity: plain dot products.
    desc_dists = np.matmul(desc1, desc2.T)
    #desc_dists = 1/cdist(desc1,desc2)

    pos_dists = desc_dists[dists].reshape(-1)
    pos_labels = np.ones_like(pos_dists)
    # Negatives: interface points paired with non-contact points.
    neg_dists1 = desc_dists[iface_pos1, :][pos_dists1 == 0].reshape(-1)
    neg_dists2 = desc_dists[:, iface_pos2][pos_dists2 == 0].reshape(-1)

    #neg_dists = np.concatenate([neg_dists1,neg_dists2],axis=0)
    neg_dists = neg_dists1
    # Fix: np.random.choice(..., 400, replace=False) raises ValueError
    # whenever fewer than 400 negatives exist — sample what we have.
    n_neg = min(400, neg_dists.shape[0])
    if n_neg == 0:
        return -1
    neg_dists = np.random.choice(neg_dists, n_neg, replace=False)
    neg_labels = np.zeros_like(neg_dists)

    preds = np.concatenate([pos_dists, neg_dists])
    labels = np.concatenate([pos_labels, neg_labels])

    roc_auc = roc_auc_score(labels, preds)

    return roc_auc, preds, labels
70 |
for experiment_name, D in zip(experiment_names,ndims):
    print(experiment_name)
    # Module-level global read by the delayed analyse_pdb tasks.
    desc_dir = top_dir/f'preds/{experiment_name}'
    all_roc_aucs = []
    all_preds = []
    all_labels = []
    all_res = []
    # Build the lazy dask task graph first...
    for i, pdb_id in enumerate(pdb_list):
        res = analyse_pdb(pdb_id,D)
        all_res.append(res)

    # ...then evaluate all pairs in parallel.
    all_res = dask.compute(*all_res)
    for res in all_res:
        # -1 marks skipped pairs (missing files / no contacts); tuple
        # results compare unequal to -1, so this filter is safe.
        if res==-1:
            continue
        all_roc_aucs.append(res[0])
        all_preds.extend(list(res[1]))
        all_labels.extend(list(res[2]))

    print('ROC-AUC',np.mean(all_roc_aucs))

    # Pool per-pair predictions into one ROC curve per experiment.
    # NOTE(review): assumes ./roc_curves already exists — np.save does not
    # create directories.
    fpr, tpr, thresholds = roc_curve(all_labels,all_preds)
    np.save(f'roc_curves/{experiment_name}_fpr.npy',fpr)
    np.save(f'roc_curves/{experiment_name}_tpr.npy',tpr)
    np.save(f'roc_curves/{experiment_name}_all_labels.npy',all_labels)
    np.save(f'roc_curves/{experiment_name}_all_preds.npy',all_preds)
97 |
98 |
--------------------------------------------------------------------------------
/data_analysis/analyse_site_outputs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from pathlib import Path
4 | from tqdm import tqdm
5 | from scipy.spatial.distance import cdist
6 | from sklearn.metrics import roc_curve, roc_auc_score
7 |
8 |
# Input/output folders, relative to the repo root.
masif_preds = Path("masif_preds/")  # NOTE(review): not used below — verify
timings = Path("timings/")
raw_data = Path("surface_data/raw/protein_surfaces/01-benchmark_surfaces_npy")  # NOTE(review): not used below

experiment_names = [
    "TangentConv_site_1layer_5A_epoch49",
    "TangentConv_site_1layer_9A_epoch49",
    "TangentConv_site_1layer_15A_epoch49",
    "TangentConv_site_3layer_15A_epoch17",
    "TangentConv_site_3layer_5A_epoch49",
    "TangentConv_site_3layer_9A_epoch46",
    "PointNet_site_3layer_9A_epoch37",
    "PointNet_site_3layer_5A_epoch46",
    "DGCNN_site_1layer_k100_epoch32",
    "PointNet_site_1layer_5A_epoch30",
    "PointNet_site_1layer_9A_epoch30",
    "DGCNN_site_1layer_k40_epoch46",
    "DGCNN_site_3layer_k40_epoch33",
]

# NOTE(review): this reassignment overrides the list above; only the
# rebuttal experiments below are actually evaluated.
experiment_names = [
    'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_dist05_epoch42',
    'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_dist20_epoch49',
    'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_dist105_epoch44',
    'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_var01_epoch43',
    'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_var02_epoch49',
    'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_var005_epoch37'
]

for experiment_name in experiment_names:
    print(experiment_name)
    datafolder = Path(f"preds/{experiment_name}")
    # "<id>_pred.vtk" -> "<id>": stem[:-5] strips the "_pred" suffix.
    pdb_list = [p.stem[:-5] for p in datafolder.glob("*pred.vtk")]

    # Pass 1: for every protein, nearest-neighbour distances between the
    # generated point cloud and the reference mesh vertices.
    n_meshpoints = []
    n_predpoints = []
    meshpoints_mindists = []
    predpoints_mindists = []
    for pdb_id in tqdm(pdb_list):
        predpoints = np.load(datafolder / (pdb_id + "_predcoords.npy"))
        meshpoints = np.load(datafolder / (pdb_id + "_meshpoints.npy"))
        n_meshpoints.append(meshpoints.shape[0])
        n_predpoints.append(predpoints.shape[0])

        pdists = cdist(meshpoints, predpoints)
        meshpoints_mindists.append(pdists.min(1))
        predpoints_mindists.append(pdists.min(0))

    all_meshpoints_mindists = np.concatenate(meshpoints_mindists)
    all_predpoints_mindists = np.concatenate(predpoints_mindists)

    # 99th-percentile distance thresholds, computed over ALL proteins.
    meshpoint_percentile = np.percentile(all_meshpoints_mindists, 99)
    predpoint_percentile = np.percentile(all_predpoints_mindists, 99)

    # Pass 2: mask out the 1% of points farthest from the other
    # representation (outliers where mesh and point cloud disagree).
    meshpoints_masks = []
    predpoints_masks = []
    for pdb_id in tqdm(pdb_list):
        predpoints = np.load(datafolder / (pdb_id + "_predcoords.npy"))
        meshpoints = np.load(datafolder / (pdb_id + "_meshpoints.npy"))

        pdists = cdist(meshpoints, predpoints)
        meshpoints_masks.append(pdists.min(1) < meshpoint_percentile)
        predpoints_masks.append(pdists.min(0) < predpoint_percentile)

    # Pass 3: gather predictions/labels from the saved feature matrices.
    # By construction the second-to-last column is the prediction and the
    # last column the label (see the saved "predfeatures" layout).
    predpoints_preds = []
    predpoints_labels = []
    npoints = []
    for i, pdb_id in enumerate(tqdm(pdb_list)):
        predpoints_features = np.load(datafolder / (pdb_id + "_predfeatures.npy"))
        predpoints_features = predpoints_features[predpoints_masks[i]]

        predpoints_preds.append(predpoints_features[:, -2])
        predpoints_labels.append(predpoints_features[:, -1])
        npoints.append(predpoints_features.shape[0])

    # Pooled ROC-AUC over all proteins of the experiment.
    predpoints_labels = np.concatenate(predpoints_labels)
    predpoints_preds = np.concatenate(predpoints_preds)
    rocauc = roc_auc_score(predpoints_labels.reshape(-1), predpoints_preds.reshape(-1))
    print("ROC-AUC", rocauc)

    np.save(timings / f"{experiment_name}_predpoints_preds", predpoints_preds)
    np.save(timings / f"{experiment_name}_predpoints_labels", predpoints_labels)
    np.save(timings / f"{experiment_name}_npoints", npoints)
92 |
--------------------------------------------------------------------------------
/data_iteration.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from helper import *
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.autograd.profiler as profiler
7 | from sklearn.metrics import roc_auc_score
8 | from pathlib import Path
9 | import math
10 | from tqdm import tqdm
11 | from geometry_processing import save_vtk
12 | from helper import numpy, diagonal_ranges
13 | import time
14 |
15 |
def process_single(protein_pair, chain_idx=1):
    """Turn the PyG data object into a dict for one chain (1 or 2)."""

    P = {}
    # Both flags are probed on the *_p1 attributes, whichever chain is asked.
    with_mesh = "face_p1" in protein_pair.keys
    preprocessed = "gen_xyz_p1" in protein_pair.keys

    if chain_idx in (1, 2):
        suffix = f"p{chain_idx}"

        def field(template):
            # Resolve e.g. "xyz_{}" to protein_pair.xyz_p1 / .xyz_p2.
            return getattr(protein_pair, template.format(suffix))

        # Ground-truth labels are available on mesh vertices:
        P["mesh_labels"] = field("y_{}") if with_mesh else None

        # N.B.: The DataLoader should use the optional argument
        # "follow_batch=['xyz_p1', 'xyz_p2']", as described on the PyG tutorial.
        P["mesh_batch"] = field("xyz_{}_batch") if with_mesh else None

        # Surface information:
        P["mesh_xyz"] = field("xyz_{}") if with_mesh else None
        P["mesh_triangles"] = field("face_{}") if with_mesh else None

        # Atom information:
        P["atoms"] = field("atom_coords_{}")
        P["batch_atoms"] = field("atom_coords_{}_batch")

        # Chemical features: atom coordinates and types.
        P["atom_xyz"] = field("atom_coords_{}")
        P["atomtypes"] = field("atom_types_{}")

        # Precomputed surface point cloud, if the pair was preprocessed:
        P["xyz"] = field("gen_xyz_{}") if preprocessed else None
        P["normals"] = field("gen_normals_{}") if preprocessed else None
        P["batch"] = field("gen_batch_{}") if preprocessed else None
        P["labels"] = field("gen_labels_{}") if preprocessed else None

    return P
74 |
75 |
def save_protein_batch_single(protein_pair_id, P, save_path, pdb_idx):
    """Write one protein's point cloud and features to .vtk / .npy files.

    The saved feature matrix concatenates, column-wise: the input features,
    the selected embedding, the sigmoid interface predictions (zeros when
    absent) and the labels (zeros when absent).
    """
    tokens = protein_pair_id.split("_")
    pdb_id = tokens[0] + "_" + tokens[pdb_idx]

    # Kept even though unused below: preserves the KeyError when the batch
    # vector is missing from P.
    batch = P["batch"]

    xyz = P["xyz"]
    inputs = P["input_features"]

    if pdb_idx == 1:
        embedding = P["embedding_1"]
        emb_id = 1
    else:
        embedding = P["embedding_2"]
        emb_id = 2

    if "iface_preds" in P.keys():
        predictions = torch.sigmoid(P["iface_preds"])
    else:
        # Placeholder column of zeros with the right length.
        predictions = 0.0 * embedding[:, 0].view(-1, 1)

    if P["labels"] is not None:
        labels = P["labels"].view(-1, 1)
    else:
        labels = 0.0 * predictions

    coloring = torch.cat([inputs, embedding, predictions, labels], axis=1)

    save_vtk(str(save_path / pdb_id) + f"_pred_emb{emb_id}", xyz, values=coloring)
    np.save(str(save_path / pdb_id) + "_predcoords", numpy(xyz))
    np.save(str(save_path / pdb_id) + f"_predfeatures_emb{emb_id}", numpy(coloring))
99 |
100 |
def project_iface_labels(P, threshold=2.0):
    """Project mesh-vertex interface labels onto the generated point cloud.

    Each query point in P["xyz"] inherits the label of its nearest mesh
    vertex; queries whose nearest vertex is farther than `threshold` are
    forced to label 0.  Writes the result into P["labels"].
    """

    queries = P["xyz"]
    batch_queries = P["batch"]
    source = P["mesh_xyz"]
    batch_source = P["mesh_batch"]
    labels = P["mesh_labels"]
    # Symbolic KeOps tensors: pairwise terms are never materialized.
    x_i = LazyTensor(queries[:, None, :])  # (N, 1, D)
    y_j = LazyTensor(source[None, :, :])  # (1, M, D)

    # Squared distances, restricted block-diagonally to same-batch pairs.
    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (N, M)
    D_ij.ranges = diagonal_ranges(batch_queries, batch_source)
    nn_i = D_ij.argmin(dim=1).view(-1)  # (N,) index of the nearest mesh vertex
    # NOTE(review): D_ij holds *squared* distances, so threshold=2.0 is an
    # effective radius of sqrt(2) ~ 1.4 A — confirm this is intended.
    nn_dist_i = (
        D_ij.min(dim=1).view(-1, 1) < threshold
    ).float()  # If chain is not connected because of missing densities MaSIF cut out a part of the protein

    # Far-away queries are multiplied by 0, i.e. labelled non-interface.
    query_labels = labels[nn_i] * nn_dist_i

    P["labels"] = query_labels
121 |
122 |
def process(args, protein_pair, net):
    """Build the per-chain dicts, generating surfaces on the fly if needed.

    Returns (P1, P2); P2 is None when args.single_protein is set.
    """

    def _prepare(chain_idx):
        P = process_single(protein_pair, chain_idx=chain_idx)
        # Only generate a surface when no precomputed one is attached.
        if f"gen_xyz_p{chain_idx}" not in protein_pair.keys:
            net.preprocess_surface(P)
        return P

    P1 = _prepare(1)
    P2 = None if args.single_protein else _prepare(2)
    return P1, P2
138 |
139 |
def generate_matchinglabels(args, P1, P2):
    """Assign interface labels (0/1) by spatial proximity between surfaces.

    If random rotations were applied, both point clouds are first mapped
    back to the original frame (inverse rotation + atom-center translation)
    so that cross-protein distances are geometrically meaningful.
    Writes the labels into P1["labels"] / P2["labels"].
    """
    if args.random_rotation:
        P1["xyz"] = torch.matmul(P1["rand_rot"].T, P1["xyz"].T).T + P1["atom_center"]
        P2["xyz"] = torch.matmul(P2["rand_rot"].T, P2["xyz"].T).T + P2["atom_center"]
    # Symbolic KeOps tensors: the N x M distance matrix is never materialized.
    xyz1_i = LazyTensor(P1["xyz"][:, None, :].contiguous())
    xyz2_j = LazyTensor(P2["xyz"][None, :, :].contiguous())

    xyz_dists = ((xyz1_i - xyz2_j) ** 2).sum(-1).sqrt()
    # Heaviside step: 1 where the distance is <= 1 A, 0 beyond.
    xyz_dists = (1.0 - xyz_dists).step()

    # A point is an interface point when it has more than one partner point
    # within 1 A on the other protein.
    p1_iface_labels = (xyz_dists.sum(1) > 1.0).float().view(-1)
    p2_iface_labels = (xyz_dists.sum(0) > 1.0).float().view(-1)

    P1["labels"] = p1_iface_labels
    P2["labels"] = p2_iface_labels
155 |
156 |
def compute_loss(args, P1, P2, n_points_sample=16):
    """Binary cross-entropy loss for the search or the site task.

    Search task (args.search): positives are descriptor dot-products between
    points of P1 and P2 that lie < 1 A apart; negatives pair those interface
    descriptors with 100 randomly sampled descriptors of the partner.  The
    same construction is repeated with the roles of embedding_1/embedding_2
    swapped, for symmetry.

    Site task: logits in P1["iface_preds"] are split by label and randomly
    subsampled to a balanced positive/negative set.

    Note: the `n_points_sample` parameter is overwritten in the site branch
    (set to the number of positives) and unused in the search branch, so its
    default value has no effect.

    Returns (loss, sampled_predictions, sampled_labels).
    """

    if args.search:
        pos_xyz1 = P1["xyz"][P1["labels"] == 1]
        pos_xyz2 = P2["xyz"][P2["labels"] == 1]
        pos_descs1 = P1["embedding_1"][P1["labels"] == 1]
        pos_descs2 = P2["embedding_2"][P2["labels"] == 1]

        pos_xyz_dists = (
            ((pos_xyz1[:, None, :] - pos_xyz2[None, :, :]) ** 2).sum(-1).sqrt()
        )
        # Descriptor similarity = dot product between the two embeddings.
        pos_desc_dists = torch.matmul(pos_descs1, pos_descs2.T)

        # Positives: similarities of point pairs closer than 1 A.
        pos_preds = pos_desc_dists[pos_xyz_dists < 1.0]
        pos_labels = torch.ones_like(pos_preds)

        # Negatives: interface descriptors vs 100 random partner descriptors.
        n_desc_sample = 100
        sample_desc2 = torch.randperm(len(P2["embedding_2"]))[:n_desc_sample]
        sample_desc2 = P2["embedding_2"][sample_desc2]
        neg_preds = torch.matmul(pos_descs1, sample_desc2.T).view(-1)
        neg_labels = torch.zeros_like(neg_preds)

        # For symmetry
        pos_descs1_2 = P1["embedding_2"][P1["labels"] == 1]
        pos_descs2_2 = P2["embedding_1"][P2["labels"] == 1]

        pos_desc_dists2 = torch.matmul(pos_descs2_2, pos_descs1_2.T)
        pos_preds2 = pos_desc_dists2[pos_xyz_dists.T < 1.0]
        pos_preds = torch.cat([pos_preds, pos_preds2], dim=0)
        # (Re-built after concatenation; the first pos_labels above is dead.)
        pos_labels = torch.ones_like(pos_preds)

        sample_desc1_2 = torch.randperm(len(P1["embedding_2"]))[:n_desc_sample]
        sample_desc1_2 = P1["embedding_2"][sample_desc1_2]
        neg_preds_2 = torch.matmul(pos_descs2_2, sample_desc1_2.T).view(-1)

        neg_preds = torch.cat([neg_preds, neg_preds_2], dim=0)
        # (Likewise re-built; the first neg_labels above is dead.)
        neg_labels = torch.zeros_like(neg_preds)

    else:
        pos_preds = P1["iface_preds"][P1["labels"] == 1]
        pos_labels = P1["labels"][P1["labels"] == 1]
        neg_preds = P1["iface_preds"][P1["labels"] == 0]
        neg_labels = P1["labels"][P1["labels"] == 0]

        # Balance the classes: keep as many negatives as positives
        # (shadows the function parameter — see docstring).
        n_points_sample = len(pos_labels)
        pos_indices = torch.randperm(len(pos_labels))[:n_points_sample]
        neg_indices = torch.randperm(len(neg_labels))[:n_points_sample]

        pos_preds = pos_preds[pos_indices]
        pos_labels = pos_labels[pos_indices]
        neg_preds = neg_preds[neg_indices]
        neg_labels = neg_labels[neg_indices]

    preds_concat = torch.cat([pos_preds, neg_preds])
    labels_concat = torch.cat([pos_labels, neg_labels])

    # Predictions are raw logits; BCE-with-logits applies the sigmoid.
    loss = F.binary_cross_entropy_with_logits(preds_concat, labels_concat)

    return loss, preds_concat, labels_concat
216 |
217 |
def extract_single(P_batch, number):
    """Slice protein `number` out of a batched dict of surface/atom tensors."""
    surf_mask = P_batch["batch"] == number
    atom_mask = P_batch["batch_atoms"] == number

    P = {}

    # Ground-truth labels, only present when meshes were available:
    batch_labels = P_batch["labels"]
    P["labels"] = batch_labels[surf_mask] if batch_labels is not None else None

    # Surface point cloud (batch vector, coordinates, normals):
    for key in ("batch", "xyz", "normals"):
        P[key] = P_batch[key][surf_mask]

    # Atom-level tensors (coordinates, batch vector, one-hot types):
    for key in ("atoms", "batch_atoms", "atom_xyz", "atomtypes"):
        P[key] = P_batch[key][atom_mask]

    return P
242 |
243 |
def iterate(
    net,
    dataset,
    optimizer,
    args,
    test=False,
    save_path=None,
    pdb_ids=None,
    summary_writer=None,
    epoch_number=None,
):
    """Goes through one epoch of the dataset, returns information for Tensorboard.

    Runs surface generation and the network forward pass (plus backward and
    optimizer step when `test` is False) for every protein pair, handling the
    pairs of each loaded batch one at a time.  When `save_path` is given,
    per-protein predictions are written to disk under the identifiers from
    `pdb_ids`.  Gradient and input-feature histograms are logged to
    `summary_writer` on the first training pair of the epoch.  Returns a
    dict of per-pair lists (Loss, ROC-AUC, timings, R_values/...).
    """

    # Train/eval mode also toggles autograd globally.
    if test:
        net.eval()
        torch.set_grad_enabled(False)
    else:
        net.train()
        torch.set_grad_enabled(True)

    # Statistics and fancy graphs to summarize the epoch:
    info = []
    total_processed_pairs = 0
    # Loop over one epoch:
    for it, protein_pair in enumerate(
        tqdm(dataset)
    ):  # , desc="Test " if test else "Train")):
        # Number of pairs in this loaded batch (batch indices are 0-based).
        protein_batch_size = protein_pair.atom_coords_p1_batch[-1].item() + 1
        if save_path is not None:
            # Identifiers for the pairs of this batch, in loading order.
            batch_ids = pdb_ids[
                total_processed_pairs : total_processed_pairs + protein_batch_size
            ]
            total_processed_pairs += protein_batch_size

        protein_pair.to(args.device)

        if not test:
            optimizer.zero_grad()

        # Generate the surface:
        # (cuda synchronize calls bracket the timers so GPU work is counted)
        torch.cuda.synchronize()
        surface_time = time.time()
        P1_batch, P2_batch = process(args, protein_pair, net)
        torch.cuda.synchronize()
        surface_time = time.time() - surface_time

        # Process the batch one pair at a time:
        for protein_it in range(protein_batch_size):
            torch.cuda.synchronize()
            iteration_time = time.time()

            P1 = extract_single(P1_batch, protein_it)
            P2 = None if args.single_protein else extract_single(P2_batch, protein_it)


            # Data augmentation: center each protein, then apply its random
            # rotation to coordinates and normals.
            if args.random_rotation:
                P1["rand_rot"] = protein_pair.rand_rot1.view(-1, 3, 3)[0]
                P1["atom_center"] = protein_pair.atom_center1.view(-1, 1, 3)[0]
                P1["xyz"] = P1["xyz"] - P1["atom_center"]
                P1["xyz"] = (
                    torch.matmul(P1["rand_rot"], P1["xyz"].T).T
                ).contiguous()
                P1["normals"] = (
                    torch.matmul(P1["rand_rot"], P1["normals"].T).T
                ).contiguous()
                if not args.single_protein:
                    P2["rand_rot"] = protein_pair.rand_rot2.view(-1, 3, 3)[0]
                    P2["atom_center"] = protein_pair.atom_center2.view(-1, 1, 3)[0]
                    P2["xyz"] = P2["xyz"] - P2["atom_center"]
                    P2["xyz"] = (
                        torch.matmul(P2["rand_rot"], P2["xyz"].T).T
                    ).contiguous()
                    P2["normals"] = (
                        torch.matmul(P2["rand_rot"], P2["normals"].T).T
                    ).contiguous()
            else:
                # Identity transform so downstream code can always undo it.
                P1["rand_rot"] = torch.eye(3, device=P1["xyz"].device)
                P1["atom_center"] = torch.zeros((1, 3), device=P1["xyz"].device)
                if not args.single_protein:
                    P2["rand_rot"] = torch.eye(3, device=P2["xyz"].device)
                    P2["atom_center"] = torch.zeros((1, 3), device=P2["xyz"].device)

            torch.cuda.synchronize()
            prediction_time = time.time()
            outputs = net(P1, P2)
            torch.cuda.synchronize()
            prediction_time = time.time() - prediction_time

            P1 = outputs["P1"]
            P2 = outputs["P2"]

            # Search task: interface labels are derived from proximity.
            if args.search:
                generate_matchinglabels(args, P1, P2)

            if P1["labels"] is not None:
                loss, sampled_preds, sampled_labels = compute_loss(args, P1, P2)
            else:
                # No supervision available for this pair.
                loss = torch.tensor(0.0)
                sampled_preds = None
                sampled_labels = None

            # Compute the gradient, update the model weights:
            if not test:
                torch.cuda.synchronize()
                back_time = time.time()
                loss.backward()
                optimizer.step()
                torch.cuda.synchronize()
                back_time = time.time() - back_time

            # Log gradient/input histograms once per epoch (first training pair).
            if it == protein_it == 0 and not test:
                for para_it, parameter in enumerate(net.atomnet.parameters()):
                    if parameter.requires_grad:
                        summary_writer.add_histogram(
                            f"Gradients/Atomnet/para_{para_it}_{parameter.shape}",
                            parameter.grad.view(-1),
                            epoch_number,
                        )
                for para_it, parameter in enumerate(net.conv.parameters()):
                    if parameter.requires_grad:
                        summary_writer.add_histogram(
                            f"Gradients/Conv/para_{para_it}_{parameter.shape}",
                            parameter.grad.view(-1),
                            epoch_number,
                        )

                for d, features in enumerate(P1["input_features"].T):
                    summary_writer.add_histogram(f"Input features/{d}", features)

            if save_path is not None:
                save_protein_batch_single(
                    batch_ids[protein_it], P1, save_path, pdb_idx=1
                )
                if not args.single_protein:
                    save_protein_batch_single(
                        batch_ids[protein_it], P2, save_path, pdb_idx=2
                    )

            # ROC-AUC can fail (e.g. a single class in the sample); on
            # failure this pair is skipped entirely — note that in training
            # mode the optimizer step above has already been taken.
            try:
                if sampled_labels is not None:
                    roc_auc = roc_auc_score(
                        np.rint(numpy(sampled_labels.view(-1))),
                        numpy(sampled_preds.view(-1)),
                    )
                else:
                    roc_auc = 0.0
            except Exception as e:
                print("Problem with computing roc-auc")
                print(e)
                continue

            R_values = outputs["R_values"]

            info.append(
                dict(
                    {
                        "Loss": loss.item(),
                        "ROC-AUC": roc_auc,
                        "conv_time": outputs["conv_time"],
                        "memory_usage": outputs["memory_usage"],
                    },
                    # Merge the "R_values" dict into "info", with a prefix:
                    **{"R_values/" + k: v for k, v in R_values.items()},
                )
            )
            torch.cuda.synchronize()
            # NOTE(review): iteration_time is measured but never reported.
            iteration_time = time.time() - iteration_time

    # Turn a list of dicts into a dict of lists:
    newdict = {}
    for k, v in [(key, d[key]) for d in info for key in d]:
        if k not in newdict:
            newdict[k] = [v]
        else:
            newdict[k].append(v)
    info = newdict

    # Final post-processing:
    return info
422 |
def iterate_surface_precompute(dataset, net, args):
    """Precompute surface point clouds for every pair of the dataset.

    Runs the network's surface generation once per pair, undoes any random
    rotation (so the cached surfaces are stored in the original frame), and
    attaches the generated xyz/normals/batch/labels to the PyG data object
    as gen_* attributes.  Results are returned as a list of CPU objects.

    NOTE(review): to_data_list()[0] keeps only the first pair of each loaded
    batch, so the loader is expected to use batch size 1 here — confirm.
    NOTE(review): the gen_*_p2 assignments run unconditionally; with
    args.single_protein, P2 is None and they would raise — confirm this flag
    is never set on this code path.
    """
    processed_dataset = []
    for it, protein_pair in enumerate(tqdm(dataset)):
        protein_pair.to(args.device)
        P1, P2 = process(args, protein_pair, net)
        if args.random_rotation:
            # Map the surface back to the original frame: inverse rotation
            # (R.T) followed by re-adding the atom center.
            P1["rand_rot"] = protein_pair.rand_rot1
            P1["atom_center"] = protein_pair.atom_center1
            P1["xyz"] = (
                torch.matmul(P1["rand_rot"].T, P1["xyz"].T).T + P1["atom_center"]
            )
            P1["normals"] = torch.matmul(P1["rand_rot"].T, P1["normals"].T).T
            if not args.single_protein:
                P2["rand_rot"] = protein_pair.rand_rot2
                P2["atom_center"] = protein_pair.atom_center2
                P2["xyz"] = (
                    torch.matmul(P2["rand_rot"].T, P2["xyz"].T).T + P2["atom_center"]
                )
                P2["normals"] = torch.matmul(P2["rand_rot"].T, P2["normals"].T).T
        protein_pair = protein_pair.to_data_list()[0]
        protein_pair.gen_xyz_p1 = P1["xyz"]
        protein_pair.gen_normals_p1 = P1["normals"]
        protein_pair.gen_batch_p1 = P1["batch"]
        protein_pair.gen_labels_p1 = P1["labels"]
        protein_pair.gen_xyz_p2 = P2["xyz"]
        protein_pair.gen_normals_p2 = P2["normals"]
        protein_pair.gen_batch_p2 = P2["batch"]
        protein_pair.gen_labels_p2 = P2["labels"]
        processed_dataset.append(protein_pair.to("cpu"))
    return processed_dataset
453 |
--------------------------------------------------------------------------------
/data_preprocessing/convert_pdb2npy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from tqdm import tqdm
4 | from Bio.PDB import *
5 |
6 | ele2num = {"C": 0, "H": 1, "O": 2, "N": 3, "S": 4, "SE": 5}
7 |
8 |
def load_structure_np(fname, center):
    """Loads a .pdb structure and returns atom coordinates and one-hot types.

    (The previous docstring, "Loads a .ply mesh...", was a copy-paste error:
    this function parses a PDB file with Bio.PDB.)
    """
    # Load the data
    parser = PDBParser()
    structure = parser.get_structure("structure", fname)
    atoms = structure.get_atoms()

    coords = []
    types = []
    for atom in atoms:
        coords.append(atom.get_coord())
        # NOTE(review): elements outside ele2num (C, H, O, N, S, SE) raise
        # KeyError here — confirm inputs only contain these elements.
        types.append(ele2num[atom.element])

    coords = np.stack(coords)
    # One-hot encode the atom types, columns ordered as in ele2num.
    types_array = np.zeros((len(types), len(ele2num)))
    for i, t in enumerate(types):
        types_array[i, t] = 1.0

    # Normalize the coordinates, as specified by the user:
    if center:
        coords = coords - np.mean(coords, axis=0, keepdims=True)

    return {"xyz": coords, "types": types_array}
32 |
33 |
def convert_pdbs(pdb_dir, npy_dir):
    """Convert every .pdb file in `pdb_dir` into `_atomxyz` / `_atomtypes` .npy arrays."""
    print("Converting PDBs")
    for pdb_path in tqdm(pdb_dir.glob("*.pdb")):
        data = load_structure_np(pdb_path, center=False)
        stem = pdb_path.stem
        np.save(npy_dir / (stem + "_atomxyz.npy"), data["xyz"])
        np.save(npy_dir / (stem + "_atomtypes.npy"), data["types"])
40 |
--------------------------------------------------------------------------------
/data_preprocessing/convert_ply2npy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from tqdm import tqdm
4 | from plyfile import PlyData, PlyElement
5 |
6 |
def load_surface_np(fname, center):
    """Loads a .ply mesh to return a point cloud and connectivity."""

    # Read the mesh and its face connectivity:
    plydata = PlyData.read(str(fname))
    triangles = np.vstack(plydata["face"].data["vertex_indices"])

    vertex_data = plydata["vertex"]

    # Vertex positions, optionally centered on their mean:
    points = np.vstack([[v[0], v[1], v[2]] for v in vertex_data])
    if center:
        points = points - np.mean(points, axis=0, keepdims=True)

    # Per-vertex unit normals:
    normals = np.stack(
        [vertex_data["nx"], vertex_data["ny"], vertex_data["nz"]]
    ).T

    # Per-vertex chemical features (charge, hydrogen bond, hydrophobicity):
    features = np.stack(
        [vertex_data["charge"], vertex_data["hbond"], vertex_data["hphob"]]
    ).T

    return {
        "xyz": points,
        "triangles": triangles,
        "features": features,
        "iface_labels": vertex_data["iface"],
        "normals": normals,
    }
40 |
41 |
def convert_plys(ply_dir, npy_dir):
    """Convert every .ply surface in `ply_dir` into a set of .npy arrays in `npy_dir`."""
    print("Converting PLYs")
    for ply_path in tqdm(ply_dir.glob("*.ply")):
        surface = load_surface_np(ply_path, center=False)
        # One output array per stored field, suffixed with the field name.
        for key, suffix in (
            ("xyz", "_xyz.npy"),
            ("triangles", "_triangles.npy"),
            ("features", "_features.npy"),
            ("iface_labels", "_iface_labels.npy"),
            ("normals", "_normals.npy"),
        ):
            np.save(npy_dir / (ply_path.stem + suffix), surface[key])
52 |
--------------------------------------------------------------------------------
/data_preprocessing/download_pdb.py:
--------------------------------------------------------------------------------
1 | import Bio
2 | from Bio.PDB import *
3 | from Bio.SeqUtils import IUPACData
4 | import sys
5 | import importlib
6 | import os
7 | import numpy as np
8 | from subprocess import Popen, PIPE
9 | from pathlib import Path
10 | from convert_pdb2npy import load_structure_np
11 | import argparse
12 |
# Command-line interface: either a single PDB code ("1ABC_A_B") or a
# text file listing many such codes, one per line.
parser = argparse.ArgumentParser(description="Arguments")
parser.add_argument(
    "--pdb", type=str,default='', help="PDB code along with chains to extract, example 1ABC_A_B", required=False
)
parser.add_argument(
    "--pdb_list", type=str,default='', help="Path to a text file that includes a list of PDB codes along with chains, example 1ABC_A_B", required=False
)

# Working directories: raw downloads, extracted/protonated chains, numpy outputs.
tmp_dir = Path('./tmp')
pdb_dir = Path('./pdbs')
npy_dir = Path('./npys')

# Upper-case three-letter codes of the standard amino acids.
PROTEIN_LETTERS = [x.upper() for x in IUPACData.protein_letters_3to1.keys()]
26 |
# Exclude disordered atoms.
class NotDisordered(Select):
    """Biopython PDBIO filter: keep ordered atoms plus altloc "A"/"1" copies."""

    def accept_atom(self, atom):
        # Ordered atoms always pass; disordered copies pass only for altloc A or 1.
        if not atom.is_disordered():
            return True
        return atom.get_altloc() in ("A", "1")
31 |
32 |
def find_modified_amino_acids(path):
    """
    Contributed by github user jomimc - find modified amino acids in the PDB (e.g. MSE)

    Scans the SEQRES records of a PDB file and returns the residue names
    that are not standard amino acids.

    Args:
        path (str or Path): path to the PDB file.

    Returns:
        set: 3-letter residue codes not present in PROTEIN_LETTERS.
    """
    res_set = set()
    # `with` closes the handle deterministically (the original leaked it).
    with open(path, "r") as pdb_file:
        for line in pdb_file:
            if line[:6] == "SEQRES":
                # Residue names start at the 5th whitespace-separated field.
                res_set.update(line.split()[4:])
    # Keep only the non-standard residues (single set difference instead of
    # removing entries one by one).
    return res_set - set(PROTEIN_LETTERS)
46 |
47 |
def extractPDB(
    infilename, outfilename, chain_ids=None
):
    """Extract the given chains from `infilename` and save them to `outfilename`.

    Args:
        infilename (str): path to the input PDB file.
        outfilename (str): path where the extracted structure is written.
        chain_ids (iterable, optional): chain identifiers to keep.
            If None, every chain of the first model is kept.
    """
    parser = PDBParser(QUIET=True)
    struct = parser.get_structure(infilename, infilename)
    # Only the first model of the structure is considered.
    model = Selection.unfold_entities(struct, "M")[0]

    # Build a fresh structure that will receive the selected residues.
    structBuild = StructureBuilder.StructureBuilder()
    structBuild.init_structure("output")
    structBuild.init_seg(" ")
    structBuild.init_model(0)
    outputStruct = structBuild.get_structure()

    # Load a list of non-standard amino acid names -- these are
    # typically listed under HETATM, so they would be typically
    # ignored by the orginal algorithm
    modified_amino_acids = find_modified_amino_acids(infilename)

    for chain in model:
        if chain_ids is None or chain.get_id() in chain_ids:
            structBuild.init_chain(chain.get_id())
            for residue in chain:
                het = residue.get_id()
                # Keep standard residues (hetero-flag " ") and known modified
                # amino acids (e.g. MSE) that appear as HETATM records.
                if het[0] == " " or het[0][-3:] in modified_amino_acids:
                    outputStruct[0][chain.get_id()].add(residue)

    # Output the selected residues, dropping disordered duplicates.
    pdbio = PDBIO()
    pdbio.set_structure(outputStruct)
    pdbio.save(outfilename, select=NotDisordered())
85 |
def protonate(in_pdb_file, out_pdb_file):
    """Protonate (i.e. add hydrogens to) a PDB file using the `reduce` tool.

    Args:
        in_pdb_file (str): file to protonate.
        out_pdb_file (str): output file where to save the protonated pdb file.
    """
    # Remove protons first, in case the structure is already protonated.
    p2 = Popen(["reduce", "-Trim", in_pdb_file], stdout=PIPE, stderr=PIPE)
    stdout, _ = p2.communicate()
    # `with` guarantees the handle is closed even if decoding fails
    # (the original used bare open()/close() pairs).
    with open(out_pdb_file, "w") as outfile:
        outfile.write(stdout.decode("utf-8").rstrip())

    # Now add them again.
    p2 = Popen(["reduce", "-HIS", out_pdb_file], stdout=PIPE, stderr=PIPE)
    stdout, _ = p2.communicate()
    with open(out_pdb_file, "w") as outfile:
        outfile.write(stdout.decode("utf-8"))
105 |
106 |
107 |
def get_single(pdb_id: str, chains: list):
    """Download, protonate and extract the requested chains of one PDB entry."""
    protonated_file = pdb_dir / f"{pdb_id}.pdb"
    if not protonated_file.exists():
        # Fetch the raw structure, then protonate it with reduce.
        # Always protonate as this is useful for charges; hydrogens
        # can be ignored later if necessary.
        raw_file = PDBList().retrieve_pdb_file(pdb_id, pdir=tmp_dir, file_format='pdb')
        protonate(raw_file, protonated_file)

    pdb_filename = protonated_file

    # Extract each chain of interest and dump it as numpy arrays.
    for chain in chains:
        out_filename = pdb_dir / f"{pdb_id}_{chain}.pdb"
        extractPDB(pdb_filename, str(out_filename), chain)
        protein = load_structure_np(out_filename, center=False)
        np.save(npy_dir / f"{pdb_id}_{chain}_atomxyz", protein["xyz"])
        np.save(npy_dir / f"{pdb_id}_{chain}_atomtypes", protein["types"])
128 |
if __name__ == '__main__':
    args = parser.parse_args()

    def _process_entry(entry):
        # "1ABC_A_B" -> pdb id "1ABC", chains ["A", "B"].
        code, *chain_list = entry.split('_')
        get_single(code, chain_list)

    if args.pdb != '':
        _process_entry(args.pdb)
    elif args.pdb_list != '':
        with open(args.pdb_list) as f:
            for entry in f.read().splitlines():
                _process_entry(entry)
    else:
        raise ValueError('Must specify PDB or PDB list')
--------------------------------------------------------------------------------
/geometry_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import pi
3 | import torch
4 | from pykeops.torch import LazyTensor
5 | from plyfile import PlyData, PlyElement
6 | from helper import *
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # from matplotlib import pyplot as plt
11 | from pykeops.torch.cluster import grid_cluster, cluster_ranges_centroids, from_matrix
12 | from math import pi, sqrt
13 |
14 |
15 | # Input-Output for tests =======================================================
16 |
17 | import os
18 | from pyvtk import PolyData, PointData, CellData, Scalars, Vectors, VtkData, PointData
19 |
20 |
def save_vtk(
    fname, xyz, triangles=None, values=None, vectors=None, triangle_values=None
):
    """Saves a point cloud or triangle mesh as a .vtk file.

    Files can be opened with Paraview or displayed using the PyVista library.

    Args:
        fname (string): filename.
        xyz (Tensor): (N,3) point cloud or vertices.
        triangles (integer Tensor, optional): (T,3) mesh connectivity. Defaults to None.
        values (Tensor, optional): (N,D) values, supported by the vertices. Defaults to None.
        vectors (Tensor, optional): (N,3) vectors, supported by the vertices. Defaults to None.
        triangle_values (Tensor, optional): (T,D) values, supported by the triangles. Defaults to None.
    """

    # Encode the points/vertices as a VTK structure:
    if triangles is None:  # Point cloud
        structure = PolyData(points=numpy(xyz), vertices=np.arange(len(xyz)))
    else:  # Surface mesh
        structure = PolyData(points=numpy(xyz), polygons=numpy(triangles))

    data = [structure]
    pointdata, celldata = [], []

    # Point values - one channel per column of the `values` array:
    if values is not None:
        values = numpy(values)
        if len(values.shape) == 1:
            values = values[:, None]
        features = values.T
        pointdata += [
            Scalars(f, name=f"features_{i:02d}") for i, f in enumerate(features)
        ]

    # Point vectors - one vector per point:
    if vectors is not None:
        pointdata += [Vectors(numpy(vectors), name="vectors")]

    # Store in the VTK object:
    if pointdata != []:
        data.append(PointData(*pointdata))

    # Triangle values - one channel per column of the `triangle_values` array:
    if triangle_values is not None:
        triangle_values = numpy(triangle_values)
        if len(triangle_values.shape) == 1:
            triangle_values = triangle_values[:, None]
        features = triangle_values.T
        celldata += [
            Scalars(f, name=f"features_{i:02d}") for i, f in enumerate(features)
        ]

    # Only attach cell data when some was actually provided; the original
    # appended an empty CellData() unconditionally, unlike the guarded
    # pointdata path above.
    if celldata != []:
        data.append(CellData(*celldata))

    # Write to hard drive:
    vtk = VtkData(*data)
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    vtk.tofile(fname)
82 |
83 |
84 | # On-the-fly generation of the surfaces ========================================
85 |
86 |
def subsample(x, batch=None, scale=1.0):
    """Subsamples the point cloud using a grid (cubic) clustering scheme.

    The function returns one average sample per cell, as described in Fig. 3.e)
    of the paper.

    Args:
        x (Tensor): (N,3) point cloud.
        batch (integer Tensor, optional): (N,) batch vector, as in PyTorch_geometric.
            Defaults to None.
        scale (float, optional): side length of the cubic grid cells. Defaults to 1 (Angstrom).

    Returns:
        (M,3): sub-sampled point cloud, with M <= N.
        If `batch` is given, also returns the (M,) batch vector of the samples.
    """

    if batch is None:  # Single protein case:
        # Fast scatter_add_ implementation. (A dead `else` branch guarded by
        # `if True:` that referenced undefined names has been removed here.)
        labels = grid_cluster(x, scale).long()
        C = labels.max() + 1

        # We append a "1" to the input vectors, in order to
        # compute both the numerator and denominator of the "average"
        # fraction in one pass through the data.
        x_1 = torch.cat((x, torch.ones_like(x[:, :1])), dim=1)
        D = x_1.shape[1]
        points = torch.zeros_like(x_1[:C])
        points.scatter_add_(0, labels[:, None].repeat(1, D), x_1)
        return (points[:, :-1] / points[:, -1:]).contiguous()

    else:  # We process proteins using a for loop.
        # This is probably sub-optimal, but I don't really know
        # how to do more elegantly (this type of computation is
        # not super well supported by PyTorch).
        batch_size = torch.max(batch).item() + 1  # Typically, =32
        points, batches = [], []
        for b in range(batch_size):
            p = subsample(x[batch == b], scale=scale)
            points.append(p)
            batches.append(b * torch.ones_like(batch[: len(p)]))

        return torch.cat(points, dim=0), torch.cat(batches, dim=0)
134 |
135 |
def soft_distances(x, y, batch_x, batch_y, smoothness=0.01, atomtypes=None):
    """Computes a soft distance function to the atom centers of a protein.

    Implements Eq. (1) of the paper in a fast and numerically stable way.

    Args:
        x (Tensor): (N,3) atom centers.
        y (Tensor): (M,3) sampling locations.
        batch_x (integer Tensor): (N,) batch vector for x, as in PyTorch_geometric.
        batch_y (integer Tensor): (M,) batch vector for y, as in PyTorch_geometric.
        smoothness (float, optional): atom radii if atom types are not provided. Defaults to .01.
        atomtypes (integer Tensor, optional): (N,6) one-hot encoding of the atom chemical types. Defaults to None.

    Returns:
        Tensor: (M,) values of the soft distance function on the points `y`.
    """
    # Build the (N, M, 1) symbolic matrix of squared distances:
    x_i = LazyTensor(x[:, None, :])  # (N, 1, 3) atoms
    y_j = LazyTensor(y[None, :, :])  # (1, M, 3) sampling points
    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (N, M, 1) squared distances

    # Use a block-diagonal sparsity mask to support heterogeneous batch processing:
    D_ij.ranges = diagonal_ranges(batch_x, batch_y)

    if atomtypes is not None:
        # Turn the one-hot encoding "atomtypes" into a vector of diameters "smoothness_i":
        # (N, 6) -> (N, 1, 1)  (There are 6 atom types.)
        # N.B.: `torch.tensor(..., device=x.device)` works on CPU and GPU alike;
        # the original `torch.cuda.FloatTensor([...], device=x.device)` crashed
        # on CPU inputs and misused the constructor's `device` argument.
        atomic_radii = torch.tensor(
            [170.0, 110.0, 152.0, 155.0, 180.0, 190.0], device=x.device
        )
        atomic_radii = atomic_radii / atomic_radii.min()
        atomtype_radii = atomtypes * atomic_radii[None, :]  # (n_atoms, n_atomtypes)
        smoothness = torch.sum(
            smoothness * atomtype_radii, dim=1, keepdim=False
        )  # (n_atoms,)
        smoothness_i = LazyTensor(smoothness[:, None, None])

        # Estimate the mean atomic smoothness in a neighborhood
        # of each sampling point, weighted by exp(-distance):
        mean_smoothness = (-D_ij.sqrt()).exp().sum(0)
        mean_smoothness_j = LazyTensor(mean_smoothness[None, :, :])
        mean_smoothness = (
            smoothness_i * (-D_ij.sqrt()).exp() / mean_smoothness_j
        )  # (n_atoms, n_points, 1)
        mean_smoothness = mean_smoothness.sum(0).view(-1)
        # Stable SoftMin computed through a logsumexp reduction:
        soft_dists = -mean_smoothness * (
            (-D_ij.sqrt() / smoothness_i).logsumexp(dim=0)
        ).view(-1)

    else:
        soft_dists = -smoothness * ((-D_ij.sqrt() / smoothness).logsumexp(dim=0)).view(
            -1
        )

    return soft_dists
199 |
200 |
def atoms_to_points_normals(
    atoms,
    batch,
    distance=1.05,
    smoothness=0.5,
    resolution=1.0,
    nits=4,
    atomtypes=None,
    sup_sampling=20,
    variance=0.1,
):
    """Turns a collection of atoms into an oriented point cloud.

    Sampling algorithm for protein surfaces, described in Fig. 3 of the paper.

    Args:
        atoms (Tensor): (N,3) coordinates of the atom centers `a_k`.
        batch (integer Tensor): (N,) batch vector, as in PyTorch_geometric.
        distance (float, optional): value of the level set to sample from
            the smooth distance function. Defaults to 1.05.
        smoothness (float, optional): radii of the atoms, if atom types are
            not provided. Defaults to 0.5.
        resolution (float, optional): side length of the cubic cells in
            the final sub-sampling pass. Defaults to 1.0.
        nits (int, optional): number of iterations . Defaults to 4.
        atomtypes (Tensor, optional): (N,6) one-hot encoding of the atom
            chemical types. Defaults to None.
        sup_sampling (int, optional): number of candidate samples drawn
            per atom before filtering. Defaults to 20.
        variance (float, optional): relative tolerance |dist - T| < variance * T
            used to keep samples near the level set. Defaults to 0.1.

    Returns:
        (Tensor): (M,3) coordinates for the surface points `x_i`.
        (Tensor): (M,3) unit normals `n_i`.
        (integer Tensor): (M,) batch vector, as in PyTorch_geometric.
    """
    # a) Parameters for the soft distance function and its level set:
    T = distance

    N, D = atoms.shape
    B = sup_sampling  # Sup-sampling ratio

    # Batch vectors:
    batch_atoms = batch
    # Each atom spawns B candidate samples that inherit its batch index.
    batch_z = batch[:, None].repeat(1, B).view(N * B)

    # b) Draw N*B points at random in the neighborhood of our atoms
    z = atoms[:, None, :] + 10 * T * torch.randn(N, B, D).type_as(atoms)
    z = z.view(-1, D)  # (N*B, D)

    # We don't want to backprop through a full network here!
    atoms = atoms.detach().contiguous()
    z = z.detach().contiguous()

    # N.B.: Test mode disables the autograd engine: we must switch it on explicitely.
    with torch.enable_grad():
        if z.is_leaf:
            z.requires_grad = True

        # c) Iterative loop: gradient descent along the potential
        # ".5 * (dist - T)^2" with respect to the positions z of our samples
        for it in range(nits):
            dists = soft_distances(
                atoms,
                z,
                batch_atoms,
                batch_z,
                smoothness=smoothness,
                atomtypes=atomtypes,
            )
            Loss = ((dists - T) ** 2).sum()
            g = torch.autograd.grad(Loss, z)[0]
            # Plain gradient step with a fixed 0.5 step size, applied in-place
            # on .data so that z stays a leaf tensor across iterations.
            z.data -= 0.5 * g

        # d) Only keep the points which are reasonably close to the level set:
        dists = soft_distances(
            atoms, z, batch_atoms, batch_z, smoothness=smoothness, atomtypes=atomtypes
        )
        margin = (dists - T).abs()
        mask = margin < variance * T

        # d') And remove the points that are trapped *inside* the protein:
        # march each sample along the gradient of the distance field; points
        # on the outside escape far from the surface, interior ones cannot.
        zz = z.detach()
        zz.requires_grad = True
        for it in range(nits):
            dists = soft_distances(
                atoms,
                zz,
                batch_atoms,
                batch_z,
                smoothness=smoothness,
                atomtypes=atomtypes,
            )
            Loss = (1.0 * dists).sum()
            g = torch.autograd.grad(Loss, zz)[0]
            normals = F.normalize(g, p=2, dim=-1)  # (N, 3)
            zz = zz + 1.0 * T * normals

        dists = soft_distances(
            atoms, zz, batch_atoms, batch_z, smoothness=smoothness, atomtypes=atomtypes
        )
        # Samples that marched outward should now be well beyond the level set.
        mask = mask & (dists > 1.5 * T)

        z = z[mask].contiguous().detach()
        batch_z = batch_z[mask].contiguous().detach()

        # e) Subsample the point cloud:
        points, batch_points = subsample(z, batch_z, scale=resolution)

        # f) Compute the normals on this smaller point cloud:
        p = points.detach()
        p.requires_grad = True
        dists = soft_distances(
            atoms,
            p,
            batch_atoms,
            batch_points,
            smoothness=smoothness,
            atomtypes=atomtypes,
        )
        Loss = (1.0 * dists).sum()
        g = torch.autograd.grad(Loss, p)[0]
        normals = F.normalize(g, p=2, dim=-1)  # (N, 3)
        # Nudge the points slightly inward along the normal before returning.
        points = points - 0.5 * normals
    return points.detach(), normals.detach(), batch_points.detach()
323 |
324 |
325 | # Surface mesh -> Normals ======================================================
326 |
327 |
def mesh_normals_areas(vertices, triangles=None, scale=[1.0], batch=None, normals=None):
    """Returns a smooth field of normals, possibly at different scales.

    points, triangles or normals, scale(s)  ->      normals
    (N, 3), (3, T) or (N,3),      (S,)   ->  (N, 3) or (N, S, 3)

    Simply put - if `triangles` are provided:
      1. Normals are first computed for every triangle using simple 3D geometry
         and are weighted according to surface area.
      2. The normal at any given vertex is then computed as the weighted average
         of the normals of all triangles in a neighborhood specified
         by Gaussian windows whose radii are given in the list of "scales".

    If `normals` are provided instead, we simply smooth the discrete vector
    field using Gaussian windows whose radii are given in the list of "scales".

    If more than one scale is provided, normal fields are computed in parallel
    and returned in a single 3D tensor.

    Args:
        vertices (Tensor): (N,3) coordinates of mesh vertices or 3D points.
        triangles (integer Tensor, optional): (3,T) mesh connectivity. Defaults to None.
        scale (list of floats, optional): (S,) radii of the Gaussian smoothing windows. Defaults to [1.].
        batch (integer Tensor, optional): batch vector, as in PyTorch_geometric. Defaults to None.
        normals (Tensor, optional): (N,3) raw normals vectors on the vertices. Defaults to None.

    Returns:
        (Tensor): (N,3) or (N,S,3) point normals.
        (Tensor): (N,) point areas, if triangles were provided.
    """

    # Single- or Multi-scale mode:
    # N.B.: the `scale=[1.0]` mutable default is never mutated here, so it is safe.
    if hasattr(scale, "__len__"):
        scales, single_scale = scale, False
    else:
        scales, single_scale = [scale], True
    scales = torch.Tensor(scales).type_as(vertices)  # (S,)

    # Compute the "raw" field of normals:
    if triangles is not None:
        # Vertices of all triangles in the mesh:
        A = vertices[triangles[0, :]]  # (T, 3)
        B = vertices[triangles[1, :]]  # (T, 3)
        C = vertices[triangles[2, :]]  # (T, 3)

        # Triangle centers and normals (length = surface area):
        centers = (A + B + C) / 3  # (T, 3)
        V = (B - A).cross(C - A)  # (T, 3), |V| = 2 * triangle area

        # Vertice areas:
        S = (V ** 2).sum(-1).sqrt() / 6  # (T,) 1/3 of a triangle area
        areas = torch.zeros(len(vertices)).type_as(vertices)  # (N,)
        areas.scatter_add_(0, triangles[0, :], S)  # Aggregate from "A's"
        areas.scatter_add_(0, triangles[1, :], S)  # Aggregate from "B's"
        areas.scatter_add_(0, triangles[2, :], S)  # Aggregate from "C's"

    else:  # Use "normals" instead
        areas = None
        V = normals
        centers = vertices

    # Normal of a vertex = average of all normals in a ball of size "scale":
    x_i = LazyTensor(vertices[:, None, :])  # (N, 1, 3)
    y_j = LazyTensor(centers[None, :, :])  # (1, M, 3)
    v_j = LazyTensor(V[None, :, :])  # (1, M, 3)
    s = LazyTensor(scales[None, None, :])  # (1, 1, S)

    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (N, M, 1) squared distances
    K_ij = (-D_ij / (2 * s ** 2)).exp()  # (N, M, S) Gaussian window per scale

    # Support for heterogeneous batch processing:
    if batch is not None:
        batch_vertices = batch
        # Triangle centers inherit the batch index of their first vertex.
        batch_centers = batch[triangles[0, :]] if triangles is not None else batch
        K_ij.ranges = diagonal_ranges(batch_vertices, batch_centers)

    if single_scale:
        U = (K_ij * v_j).sum(dim=1)  # (N, 3)
    else:
        # Outer product trick to smooth at all S scales in one reduction:
        U = (K_ij.tensorprod(v_j)).sum(dim=1)  # (N, S*3)
        U = U.view(-1, len(scales), 3)  # (N, S, 3)

    normals = F.normalize(U, p=2, dim=-1)  # (N, 3) or (N, S, 3)

    return normals, areas
413 |
414 |
415 | # Compute tangent planes and curvatures ========================================
416 |
417 |
def tangent_vectors(normals):
    """Returns a pair of vector fields u and v to complete the orthonormal basis [n,u,v].

    normals -> uv
    (N, 3) or (N, S, 3) -> (N, 2, 3) or (N, S, 2, 3)

    This routine assumes that the 3D "normal" vectors are normalized.
    It is based on the 2017 paper from Pixar, "Building an orthonormal basis, revisited".

    Args:
        normals (Tensor): (N,3) or (N,S,3) normals `n_i`, i.e. unit-norm 3D vectors.

    Returns:
        (Tensor): (N,2,3) or (N,S,2,3) unit vectors `u_i` and `v_i` to complete
            the tangent coordinate systems `[n_i,u_i,v_i]`.
    """
    nx, ny, nz = normals[..., 0], normals[..., 1], normals[..., 2]
    # Branch-free sign of nz, taken as +1 when nz == 0 to avoid the singularity.
    sign = (2 * (nz >= 0)) - 1.0
    alpha = -1 / (sign + nz)
    beta = nx * ny * alpha
    # Six components of u and v, stacked then reshaped into the (2, 3) basis.
    basis = torch.stack(
        (
            1 + sign * nx * nx * alpha,
            sign * beta,
            -sign * nx,
            beta,
            sign + ny * ny * alpha,
            -ny,
        ),
        dim=-1,
    )
    return basis.view(basis.shape[:-1] + (2, 3))
442 |
443 |
def curvatures(
    vertices, triangles=None, scales=[1.0], batch=None, normals=None, reg=0.01
):
    """Returns a collection of mean (H) and Gauss (K) curvatures at different scales.

    points, faces, scales  ->  (H_1, K_1, ..., H_S, K_S)
    (N, 3), (3, N), (S,)   ->  (N, S*2)

    We rely on a very simple linear regression method, for all vertices:

      1. Estimate normals and surface areas.
      2. Compute a local tangent frame.
      3. In a pseudo-geodesic Gaussian neighborhood at scale s,
         compute the two (2, 2) covariance matrices PPt and PQt
         between the displacement vectors "P = x_i - x_j" and
         the normals "Q = n_i - n_j", projected on the local tangent plane.
      4. Up to the sign, the shape operator S at scale s is then approximated
         as "S = (reg**2 * I_2 + PPt)^-1 @ PQt".
      5. The mean and Gauss curvatures are the trace and determinant of
         this (2, 2) matrix.

    As of today, this implementation does not weigh points by surface areas:
    this could make a sizeable difference if protein surfaces were not
    sub-sampled to ensure uniform sampling density.

    For convergence analysis, see for instance
    "Efficient curvature estimation for oriented point clouds",
    Cao, Li, Sun, Assadi, Zhang, 2019.

    Args:
        vertices (Tensor): (N,3) coordinates of the points or mesh vertices.
        triangles (integer Tensor, optional): (3,T) mesh connectivity. Defaults to None.
        scales (list of floats, optional): list of (S,) smoothing scales. Defaults to [1.].
        batch (integer Tensor, optional): batch vector, as in PyTorch_geometric. Defaults to None.
        normals (Tensor, optional): (N,3) field of "raw" unit normals. Defaults to None.
        reg (float, optional): small amount of Tikhonov/ridge regularization
            in the estimation of the shape operator. Defaults to .01.

    Returns:
        (Tensor): (N, S*2) tensor of mean and Gauss curvatures computed for
            every point at the required scales.
    """
    # Number of points, number of scales:
    N, S = vertices.shape[0], len(scales)
    ranges = diagonal_ranges(batch)

    # Compute the normals at different scales + vertice areas:
    normals_s, _ = mesh_normals_areas(
        vertices, triangles=triangles, normals=normals, scale=scales, batch=batch
    )  # (N, S, 3), (N,)

    # Local tangent bases:
    uv_s = tangent_vectors(normals_s)  # (N, S, 2, 3)

    features = []

    for s, scale in enumerate(scales):
        # Extract the relevant descriptors at the current scale:
        normals = normals_s[:, s, :].contiguous()  # (N, 3)
        uv = uv_s[:, s, :, :].contiguous()  # (N, 2, 3)

        # Encode as symbolic tensors:
        # Points:
        x_i = LazyTensor(vertices.view(N, 1, 3))
        x_j = LazyTensor(vertices.view(1, N, 3))
        # Normals:
        n_i = LazyTensor(normals.view(N, 1, 3))
        n_j = LazyTensor(normals.view(1, N, 3))
        # Tangent bases:
        uv_i = LazyTensor(uv.view(N, 1, 6))

        # Pseudo-geodesic squared distance:
        d2_ij = ((x_j - x_i) ** 2).sum(-1) * ((2 - (n_i | n_j)) ** 2)  # (N, N, 1)
        # Gaussian window:
        window_ij = (-d2_ij / (2 * (scale ** 2))).exp()  # (N, N, 1)

        # Project on the tangent plane:
        P_ij = uv_i.matvecmult(x_j - x_i)  # (N, N, 2)
        Q_ij = uv_i.matvecmult(n_j - n_i)  # (N, N, 2)
        # Concatenate:
        PQ_ij = P_ij.concat(Q_ij)  # (N, N, 2+2)

        # Covariances, with a scale-dependent weight:
        PPt_PQt_ij = P_ij.tensorprod(PQ_ij)  # (N, N, 2*(2+2))
        PPt_PQt_ij = window_ij * PPt_PQt_ij  # (N, N, 2*(2+2))

        # Reduction - with batch support:
        PPt_PQt_ij.ranges = ranges
        PPt_PQt = PPt_PQt_ij.sum(1)  # (N, 2*(2+2))

        # Reshape to get the two covariance matrices:
        PPt_PQt = PPt_PQt.view(N, 2, 2, 2)
        PPt, PQt = PPt_PQt[:, :, 0, :], PPt_PQt[:, :, 1, :]  # (N, 2, 2), (N, 2, 2)

        # Add a small ridge regression:
        PPt[:, 0, 0] += reg
        PPt[:, 1, 1] += reg

        # (minus) Shape operator, i.e. the differential of the Gauss map:
        # = (PPt^-1 @ PQt) : simple estimation through linear regression.
        # N.B.: `torch.solve(B, A)` was removed in PyTorch 2.0; the modern
        # equivalent of `torch.solve(PQt, PPt).solution` is:
        S = torch.linalg.solve(PPt, PQt)
        a, b, c, d = S[:, 0, 0], S[:, 0, 1], S[:, 1, 0], S[:, 1, 1]  # (N,)

        # Normalization
        mean_curvature = a + d
        gauss_curvature = a * d - b * c
        features += [mean_curvature.clamp(-1, 1), gauss_curvature.clamp(-1, 1)]

    features = torch.stack(features, dim=-1)
    return features
554 |
555 |
556 | # Fast tangent convolution layer ===============================================
557 | class ContiguousBackward(torch.autograd.Function):
558 | """
559 | Function to ensure contiguous gradient in backward pass. To be applied after PyKeOps reduction.
560 | N.B.: This workaround fixes a bug that will be fixed in ulterior KeOp releases.
561 | """
562 | @staticmethod
563 | def forward(ctx, input):
564 | return input
565 |
566 | @staticmethod
567 | def backward(ctx, grad_output):
568 | return grad_output.contiguous()
569 |
class dMaSIFConv(nn.Module):
    def __init__(
        self, in_channels=1, out_channels=1, radius=1.0, hidden_units=None, cheap=False
    ):
        """Creates the KeOps convolution layer.

        I = in_channels is the dimension of the input features
        O = out_channels is the dimension of the output features
        H = hidden_units is the dimension of the intermediate representation
        radius is the size of the pseudo-geodesic Gaussian window w_ij = W(d_ij)


        This affordable layer implements an elementary "convolution" operator
        on a cloud of N points (x_i) in dimension 3 that we decompose in three steps:

          1. Apply the MLP "net_in" on the input features "f_i". (N, I) -> (N, H)

          2. Compute H interaction terms in parallel with:
                  f_i = sum_j [ w_ij * conv(P_ij) * f_j ]
             In the equation above:
               - w_ij is a pseudo-geodesic window with a set radius.
               - P_ij is a vector of dimension 3, equal to "x_j-x_i"
                 in the local oriented basis at x_i.
               - "conv" is an MLP from R^3 to R^H:
                   - with 1 linear layer if "cheap" is True;
                   - with 2 linear layers and C=8 intermediate "cuts" otherwise.
               - "*" is coordinate-wise product.
               - f_j is the vector of transformed features.

          3. Apply the MLP "net_out" on the output features. (N, H) -> (N, O)


        A more general layer would have implemented conv(P_ij) as a full
        (H, H) matrix instead of a mere (H,) vector... At a much higher
        computational cost. The reasoning behind the code below is that
        a given time budget is better spent on using a larger architecture
        and more channels than on a very complex convolution operator.
        Interactions between channels happen at steps 1. and 3.,
        whereas the (costly) point-to-point interaction step 2.
        lets the network aggregate information in spatial neighborhoods.

        Args:
            in_channels (int, optional): number of input features per point. Defaults to 1.
            out_channels (int, optional): number of output features per point. Defaults to 1.
            radius (float, optional): deviation of the Gaussian window on the
                quasi-geodesic distance `d_ij`. Defaults to 1..
            hidden_units (int, optional): number of hidden features per point.
                Defaults to out_channels.
            cheap (bool, optional): shall we use a 1-layer deep Filter,
                instead of a 2-layer deep MLP? Defaults to False.

        Raises:
            ValueError: if the hidden dimension is not a multiple of the
                heads dimension (8, or the hidden dimension itself if smaller).
        """

        super(dMaSIFConv, self).__init__()

        self.Input = in_channels
        self.Output = out_channels
        self.Radius = radius
        self.Hidden = self.Output if hidden_units is None else hidden_units
        self.Cuts = 8  # Number of hidden units for the 3D MLP Filter.
        self.cheap = cheap

        # For performance reasons, we cut our "hidden" vectors
        # in n_heads "independent heads" of dimension 8.
        self.heads_dim = 8  # 4 is probably too small; 16 is certainly too big

        # We accept "Hidden" dimensions of size 1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, ...
        if self.Hidden < self.heads_dim:
            self.heads_dim = self.Hidden

        if self.Hidden % self.heads_dim != 0:
            # Bug fix: the two f-strings used to be concatenated without a
            # separating space, yielding messages like "(12)should be a multiple".
            raise ValueError(
                f"The dimension of the hidden units ({self.Hidden}) "
                + f"should be a multiple of the heads dimension ({self.heads_dim})."
            )
        else:
            self.n_heads = self.Hidden // self.heads_dim


        # Transformation of the input features:
        self.net_in = nn.Sequential(
            nn.Linear(self.Input, self.Hidden),  # (H, I) + (H,)
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(self.Hidden, self.Hidden),  # (H, H) + (H,)
            # nn.LayerNorm(self.Hidden),#nn.BatchNorm1d(self.Hidden),
            nn.LeakyReLU(negative_slope=0.2),
        )  # (H,)
        # NOTE(review): nn.GroupNorm(4, C) requires C to be a multiple of 4,
        # so in practice Hidden and Output should be multiples of 4.
        self.norm_in = nn.GroupNorm(4, self.Hidden)
        # self.norm_in = nn.LayerNorm(self.Hidden)
        # self.norm_in = nn.Identity()

        # 3D convolution filters, encoded as an MLP:
        if cheap:
            self.conv = nn.Sequential(
                nn.Linear(3, self.Hidden), nn.ReLU()  # (H, 3) + (H,)
            )  # KeOps does not support well LeakyReLu
        else:
            self.conv = nn.Sequential(
                nn.Linear(3, self.Cuts),  # (C, 3) + (C,)
                nn.ReLU(),  # KeOps does not support well LeakyReLu
                nn.Linear(self.Cuts, self.Hidden),
            )  # (H, C) + (H,)

        # Transformation of the output features:
        self.net_out = nn.Sequential(
            nn.Linear(self.Hidden, self.Output),  # (O, H) + (O,)
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(self.Output, self.Output),  # (O, O) + (O,)
            # nn.LayerNorm(self.Output),#nn.BatchNorm1d(self.Output),
            nn.LeakyReLU(negative_slope=0.2),
        )  # (O,)

        self.norm_out = nn.GroupNorm(4, self.Output)
        # self.norm_out = nn.LayerNorm(self.Output)
        # self.norm_out = nn.Identity()

        # Custom initialization for the MLP convolution filters:
        # we get interesting piecewise affine cuts on a normalized neighborhood.
        with torch.no_grad():
            nn.init.normal_(self.conv[0].weight)
            nn.init.uniform_(self.conv[0].bias)
            self.conv[0].bias *= 0.8 * (self.conv[0].weight ** 2).sum(-1).sqrt()

            if not cheap:
                nn.init.uniform_(
                    self.conv[2].weight,
                    a=-1 / np.sqrt(self.Cuts),
                    b=1 / np.sqrt(self.Cuts),
                )
                nn.init.normal_(self.conv[2].bias)
                self.conv[2].bias *= 0.5 * (self.conv[2].weight ** 2).sum(-1).sqrt()


    def forward(self, points, nuv, features, ranges=None):
        """Performs a quasi-geodesic interaction step.

        points, local basis, in features -> out features
        (N, 3), (N, 3, 3), (N, I) -> (N, O)

        This layer computes the interaction step of Eq. (7) in the paper,
        in-between the application of two MLP networks independently on all
        feature vectors.

        Args:
            points (Tensor): (N,3) point coordinates `x_i`.
            nuv (Tensor): (N,3,3) local coordinate systems `[n_i,u_i,v_i]`.
            features (Tensor): (N,I) input feature vectors `f_i`.
            ranges (6-uple of integer Tensors, optional): low-level format
                to support batch processing, as described in the KeOps documentation.
                In practice, this will be built by a higher-level object
                to encode the relevant "batch vectors" in a way that is convenient
                for the KeOps CUDA engine. Defaults to None.

        Returns:
            (Tensor): (N,O) output feature vectors `f'_i`.
        """

        # 1. Transform the input features: -------------------------------------
        features = self.net_in(features)  # (N, I) -> (N, H)
        # GroupNorm expects a (B, C, N) layout, hence the transpose round-trip:
        features = features.transpose(1, 0)[None, :, :]  # (1,H,N)
        features = self.norm_in(features)
        features = features[0].transpose(1, 0).contiguous()  # (1, H, N) -> (N, H)

        # 2. Compute the local "shape contexts": -------------------------------

        # 2.a Normalize the kernel radius:
        points = points / (sqrt(2.0) * self.Radius)  # (N, 3)

        # 2.b Encode the variables as KeOps LazyTensors

        # Vertices:
        x_i = LazyTensor(points[:, None, :])  # (N, 1, 3)
        x_j = LazyTensor(points[None, :, :])  # (1, N, 3)

        # WARNING - Here, we assume that the normals are fixed:
        normals = (
            nuv[:, 0, :].contiguous().detach()
        )  # (N, 3) - remove the .detach() if needed

        # Local bases:
        nuv_i = LazyTensor(nuv.view(-1, 1, 9))  # (N, 1, 9)
        # Normals:
        n_i = nuv_i[:3]  # (N, 1, 3)

        n_j = LazyTensor(normals[None, :, :])  # (1, N, 3)

        # To avoid register spilling when using large embeddings, we perform our KeOps reduction
        # over the vector of length "self.Hidden = self.n_heads * self.heads_dim"
        # as self.n_heads reduction over vectors of length self.heads_dim (= "Hd" in the comments).
        head_out_features = []
        for head in range(self.n_heads):

            # Extract a slice of width Hd from the feature array
            head_start = head * self.heads_dim
            head_end = head_start + self.heads_dim
            head_features = features[:, head_start:head_end].contiguous()  # (N, H) -> (N, Hd)

            # Features:
            f_j = LazyTensor(head_features[None, :, :])  # (1, N, Hd)

            # Convolution parameters:
            if self.cheap:
                # Extract a slice of Hd lines: (H, 3) -> (Hd, 3)
                A = self.conv[0].weight[head_start:head_end, :].contiguous()
                # Extract a slice of Hd coefficients: (H,) -> (Hd,)
                B = self.conv[0].bias[head_start:head_end].contiguous()
                AB = torch.cat((A, B[:, None]), dim=1)  # (Hd, 4)
                ab = LazyTensor(AB.view(1, 1, -1))  # (1, 1, Hd*4)
            else:
                A_1, B_1 = self.conv[0].weight, self.conv[0].bias  # (C, 3), (C,)
                # Extract a slice of Hd lines: (H, C) -> (Hd, C)
                A_2 = self.conv[2].weight[head_start:head_end, :].contiguous()
                # Extract a slice of Hd coefficients: (H,) -> (Hd,)
                B_2 = self.conv[2].bias[head_start:head_end].contiguous()
                a_1 = LazyTensor(A_1.view(1, 1, -1))  # (1, 1, C*3)
                b_1 = LazyTensor(B_1.view(1, 1, -1))  # (1, 1, C)
                a_2 = LazyTensor(A_2.view(1, 1, -1))  # (1, 1, Hd*C)
                b_2 = LazyTensor(B_2.view(1, 1, -1))  # (1, 1, Hd)

            # 2.c Pseudo-geodesic window:
            # Pseudo-geodesic squared distance:
            d2_ij = ((x_j - x_i) ** 2).sum(-1) * ((2 - (n_i | n_j)) ** 2)  # (N, N, 1)
            # Gaussian window:
            window_ij = (-d2_ij).exp()  # (N, N, 1)

            # 2.d Local MLP:
            # Local coordinates:
            X_ij = nuv_i.matvecmult(x_j - x_i)  # (N, N, 9) "@" (N, N, 3) = (N, N, 3)
            # MLP:
            if self.cheap:
                X_ij = ab.matvecmult(
                    X_ij.concat(LazyTensor(1))
                )  # (N, N, Hd*4) @ (N, N, 3+1) = (N, N, Hd)
                X_ij = X_ij.relu()  # (N, N, Hd)
            else:
                X_ij = a_1.matvecmult(X_ij) + b_1  # (N, N, C)
                X_ij = X_ij.relu()  # (N, N, C)
                X_ij = a_2.matvecmult(X_ij) + b_2  # (N, N, Hd)
                X_ij = X_ij.relu()

            # 2.e Actual computation:
            F_ij = window_ij * X_ij * f_j  # (N, N, Hd)
            F_ij.ranges = ranges  # Support for batches and/or block-sparsity

            head_out_features.append(ContiguousBackward().apply(F_ij.sum(dim=1)))  # (N, Hd)

        # Concatenate the result of our n_heads "attention heads":
        features = torch.cat(head_out_features, dim=1)  # n_heads * (N, Hd) -> (N, H)

        # 3. Transform the output features: ------------------------------------
        features = self.net_out(features)  # (N, H) -> (N, O)
        features = features.transpose(1, 0)[None, :, :]  # (1,O,N)
        features = self.norm_out(features)
        features = features[0].transpose(1, 0).contiguous()

        return features
823 |
--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 |
3 | import numpy as np
4 | import torch
5 | from pykeops.torch import LazyTensor
6 | from plyfile import PlyData, PlyElement
7 | from pathlib import Path
8 |
9 |
# Default tensor constructors: place data on the GPU when one is available.
tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
inttensor = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor


def numpy(x):
    """Converts a (possibly GPU-based, differentiable) torch Tensor to a NumPy array.

    Replaces the former ``numpy = lambda x: ...`` assignment (PEP 8 E731)
    with an equivalent named function.
    """
    return x.detach().cpu().numpy()
13 |
14 |
def ranges_slices(batch):
    """Helper function for the diagonal ranges function.

    Given an (N,) batch vector (e.g. [0, 0, 1, 1, 1]), returns the
    per-batch-item [start, end) index ranges and the associated
    "slices" array, both as int32 tensors on the batch's device.
    """
    counts = batch.bincount()  # Number of points in each batch item
    ends = counts.cumsum(0)  # Exclusive end index of each batch item
    bounds = torch.cat((0 * ends[:1], ends))  # [0, end_0, end_1, ...]
    starts_ends = torch.stack((bounds[:-1], bounds[1:])).t()
    ranges = starts_ends.int().contiguous().to(batch.device)
    slices = (1 + torch.arange(len(counts))).int().to(batch.device)

    return ranges, slices
26 |
27 |
def diagonal_ranges(batch_x=None, batch_y=None):
    """Encodes the block-diagonal structure associated to a batch vector."""

    # No batch vectors at all: dense, un-batched processing.
    if batch_x is None and batch_y is None:
        return None
    # "Symmetric" case: a single batch vector shared by both sides.
    if batch_y is None:
        batch_y = batch_x

    x_ranges, x_slices = ranges_slices(batch_x)
    y_ranges, y_slices = ranges_slices(batch_y)

    # Six-tuple layout expected by the KeOps engine:
    # (ranges_i, slices_i, redranges_j, ranges_j, slices_j, redranges_i).
    return x_ranges, x_slices, y_ranges, y_ranges, y_slices, x_ranges
40 |
41 |
def soft_dimension(features):
    r"""Continuous approximation of the rank of a (N, D) sample.

    Let "s" denote the (D,) vector of eigenvalues of Cov,
    the (D, D) covariance matrix of the sample "features".
    Then,
        R(features) = \sum_i sqrt(s_i) / \max_i sqrt(s_i)

    This quantity encodes the number of PCA components that would be
    required to describe the sample with a good precision.
    It is equal to D if the sample is isotropic, but is generally much lower.

    Up to the re-normalization by the largest eigenvalue,
    this continuous pseudo-rank is equal to the nuclear norm of the sample.

    Args:
        features (Tensor): (..., D) sample; all leading dimensions are
            flattened into a single batch of D-dimensional points.

    Returns:
        float: the pseudo-rank, or -1 if the SVD did not converge.
    """

    nfeat = features.shape[-1]
    features = features.view(-1, nfeat)
    # Center the sample and compute its (un-normalized) covariance matrix:
    x = features - torch.mean(features, dim=0, keepdim=True)
    cov = x.T @ x
    try:
        u, s, v = torch.svd(cov)
    except RuntimeError:
        # torch.svd raises RuntimeError when the algorithm fails to converge;
        # fall back to an explicit error value. (The previous bare "except:"
        # would also have swallowed KeyboardInterrupt and SystemExit.)
        return -1
    R = s.sqrt().sum() / s.sqrt().max()
    return R.item()
68 |
--------------------------------------------------------------------------------
/lists/testing.txt:
--------------------------------------------------------------------------------
1 | 1A99_D
2 | 1AUI_A
3 | 1BO6_A
4 | 1BOU_B
5 | 1BPL_B
6 | 1C7N_D
7 | 1CS0_B
8 | 1DM5_F
9 | 1DOS_B
10 | 1EFV_A
11 | 1EHI_A
12 | 1EV7_A
13 | 1EWY_A
14 | 1EZ1_B
15 | 1F06_A
16 | 1F6M_A
17 | 1FFV_C
18 | 1FR8_A
19 | 1FT8_C
20 | 1FZW_B
21 | 1GP2_BG
22 | 1GX7_A
23 | 1H6D_A
24 | 1HF2_C
25 | 1I2M_B
26 | 1IUG_A
27 | 1IXS_B
28 | 1IZ1_B
29 | 1J0E_B
30 | 1JG8_D
31 | 1JR3_D
32 | 1JTD_B
33 | 1K3R_B
34 | 1KIZ_B
35 | 1KYQ_B
36 | 1LEH_B
37 | 1M1T_B
38 | 1M32_D
39 | 1MB2_F
40 | 1NCA_N
41 | 1ND6_B
42 | 1O4U_B
43 | 1O57_D
44 | 1O61_A
45 | 1OJ7_D
46 | 1OMZ_B
47 | 1ONW_B
48 | 1P6X_A
49 | 1P9E_A
50 | 1PK8_A
51 | 1POI_A
52 | 1R0K_D
53 | 1R8J_A
54 | 1RQD_A
55 | 1RYI_C
56 | 1SOJ_J
57 | 1ST0_A
58 | 1SUW_D
59 | 1SXJ_D
60 | 1SXJ_E
61 | 1SZ2_B
62 | 1T8U_B
63 | 1TO6_A
64 | 1U0R_C
65 | 1U2V_C
66 | 1U6S_A
67 | 1UL1_X
68 | 1V7C_A
69 | 1VHK_C
70 | 1W1W_B
71 | 1W23_A
72 | 1WDW_H
73 | 1WKH_B
74 | 1WKV_B
75 | 1WPP_B
76 | 1WX1_B
77 | 1X7O_A
78 | 1X9J_G
79 | 1XF9_A
80 | 1XG2_A
81 | 1XI8_A
82 | 1XQS_A
83 | 1XXI_C
84 | 1XXI_E
85 | 1Y56_B
86 | 1Y8Q_A
87 | 1YKJ_A
88 | 1Z85_B
89 | 1ZCT_B
90 | 1ZH8_A
91 | 1ZHH_A
92 | 1ZXO_C
93 | 2AF4_C
94 | 2AUN_A
95 | 2AW6_A
96 | 2AYO_A
97 | 2B3Z_D
98 | 2BP7_B
99 | 2BWN_E
100 | 2C0L_A
101 | 2CDB_A
102 | 2CE8_A
103 | 2CH5_A
104 | 2CVO_B
105 | 2DG0_E
106 | 2E2P_A
107 | 2E5F_A
108 | 2E7J_B
109 | 2E89_A
110 | 2EG5_E
111 | 2EJW_A
112 | 2EP5_B
113 | 2F4M_A
114 | 2F4N_C
115 | 2FV2_C
116 | 2GSZ_F
117 | 2GVQ_C
118 | 2GZA_A
119 | 2HYX_C
120 | 2HZG_A
121 | 2HZK_D
122 | 2I3T_A
123 | 2I7N_B
124 | 2IJZ_G
125 | 2IP2_B
126 | 2J0Q_B
127 | 2J5T_G
128 | 2J6X_G
129 | 2MTA_HL
130 | 2NQL_A
131 | 2OBN_D
132 | 2OGJ_E
133 | 2OIZ_B
134 | 2OOR_AB
135 | 2OZK_B
136 | 2PBI_B
137 | 2PMS_A
138 | 2PP1_A
139 | 2PVP_A
140 | 2Q0J_A
141 | 2QDH_A
142 | 2QFC_A
143 | 2QGI_A
144 | 2QXV_A
145 | 2QYO_A
146 | 2R87_A
147 | 2R8Q_B
148 | 2UX8_C
149 | 2V3B_A
150 | 2V7X_C
151 | 2V9P_G
152 | 2V9T_B
153 | 2VCG_D
154 | 2VHI_C
155 | 2VHW_A
156 | 2VN8_A
157 | 2VUN_C
158 | 2WPX_B
159 | 2WUS_A
160 | 2WVM_A
161 | 2WYR_G
162 | 2X0D_A
163 | 2X2E_D
164 | 2X5D_D
165 | 2X65_B
166 | 2XT2_A
167 | 2XWT_C
168 | 2Y0M_A
169 | 2Y5B_A
170 | 2YCH_A
171 | 2Z4R_B
172 | 2Z50_A
173 | 2Z71_C
174 | 2Z9V_B
175 | 2ZBK_A
176 | 2ZIU_A
177 | 2ZIU_B
178 | 2ZUC_B
179 | 2ZZX_A
180 | 3AB1_B
181 | 3AEO_C
182 | 3AFO_A
183 | 3AP2_A
184 | 3AUY_A
185 | 3B5U_J
186 | 3BH6_B
187 | 3BM5_A
188 | 3BP8_AB
189 | 3BT1_U
190 | 3BTV_B
191 | 3BV6_D
192 | 3BWO_D
193 | 3C0B_C
194 | 3C0K_B
195 | 3C3J_F
196 | 3C48_A
197 | 3CE9_A
198 | 3CEA_B
199 | 3CQ6_A
200 | 3CYJ_B
201 | 3D6K_B
202 | 3DDM_B
203 | 3DHW_D
204 | 3DP7_A
205 | 3DUG_C
206 | 3DZ2_A
207 | 3DZC_B
208 | 3E18_A
209 | 3E38_B
210 | 3E5P_C
211 | 3E9M_A
212 | 3EIQ_C
213 | 3EPW_B
214 | 3ES8_H
215 | 3EUA_H
216 | 3EZ6_B
217 | 3EZY_A
218 | 3FGT_B
219 | 3FHC_A
220 | 3FZ0_B
221 | 3G0I_A
222 | 3G8Q_C
223 | 3G8R_A
224 | 3GJZ_A
225 | 3GL1_A
226 | 3GZT_O
227 | 3H6G_B
228 | 3H77_B
229 | 3H8L_B
230 | 3H9G_A
231 | 3HE3_F
232 | 3HGU_A
233 | 3HLI_A
234 | 3HMK_B
235 | 3HPV_B
236 | 3HWS_D
237 | 3HXJ_A
238 | 3IAU_A
239 | 3IF8_A
240 | 3IGF_B
241 | 3IO1_A
242 | 3ISL_B
243 | 3IX1_B
244 | 3JSK_G
245 | 3JTX_A
246 | 3K5H_A
247 | 3KKI_A
248 | 3KL5_B
249 | 3KL9_J
250 | 3L31_A
251 | 3L9W_B
252 | 3LED_A
253 | 3LEE_A
254 | 3LJQ_C
255 | 3LKU_E
256 | 3LMA_C
257 | 3LOU_C
258 | 3LVK_AC
259 | 3M2T_A
260 | 3MCA_B
261 | 3MCZ_A
262 | 3MEN_B
263 | 3MGC_A
264 | 3MKR_A
265 | 3MMY_G
266 | 3MWE_B
267 | 3MZK_D
268 | 3N29_A
269 | 3N3D_A
270 | 3NND_C
271 | 3NTQ_A
272 | 3NVN_A
273 | 3NVV_B
274 | 3O3M_D
275 | 3O3P_A
276 | 3O5T_A
277 | 3OBK_F
278 | 3OM1_A
279 | 3ON5_A
280 | 3OOQ_H
281 | 3OQB_D
282 | 3OV3_B
283 | 3P6K_B
284 | 3P72_A
285 | 3P9I_C
286 | 3PG9_H
287 | 3PGA_1
288 | 3PND_D
289 | 3PNK_B
290 | 3PUZ_B
291 | 3PWS_A
292 | 3QBX_B
293 | 3QE9_Y
294 | 3QKW_C
295 | 3QML_D
296 | 3QW2_B
297 | 3R0Q_A
298 | 3R1X_A
299 | 3R5X_D
300 | 3R9A_AC
301 | 3RAM_D
302 | 3RCY_C
303 | 3RF6_B
304 | 3RFH_A
305 | 3RHF_B
306 | 3RMR_A
307 | 3RT0_A
308 | 3S5U_E
309 | 3SF5_D
310 | 3SJA_I
311 | 3SN6_A
312 | 3SSO_B
313 | 3SYL_A
314 | 3SZP_A
315 | 3T5P_A
316 | 3T8E_A
317 | 3THO_A
318 | 3THO_B
319 | 3TII_B
320 | 3TK1_B
321 | 3TQC_B
322 | 3TWO_B
323 | 3U5Z_B
324 | 3UGV_C
325 | 3UI2_A
326 | 3UK7_A
327 | 3UVN_C
328 | 3V5N_C
329 | 3VGK_E
330 | 3VH0_D
331 | 3VH3_A
332 | 3VV2_A
333 | 3VYR_B
334 | 3WN7_A
335 | 3ZWL_B
336 | 4BKX_B
337 | 4C9B_B
338 | 4DVG_B
339 | 4ETP_A
340 | 4ETP_B
341 | 4FZV_A
342 | 4HDO_A
343 | 4LVN_A
344 | 4M0W_A
345 | 4V0O_F
346 | 4X33_B
347 | 4XL5_C
348 | 4Y61_B
349 | 4YC7_B
350 | 4YEB_A
351 | 4ZGY_A
352 | 4ZRJ_A
353 | 5BV7_A
354 | 5CXB_B
355 | 5J57_A
356 | 5TIH_A
357 | 5XIM_A
358 | 7MDH_B
359 | 4JLR_S
360 |
--------------------------------------------------------------------------------
/lists/testing_ppi.txt:
--------------------------------------------------------------------------------
1 | 1A2K_C_AB
2 | 1A2W_A_B
3 | 1A79_C_B
4 | 1A99_C_D
5 | 1ACB_E_I
6 | 1AGQ_C_D
7 | 1AHS_C_B
8 | 1AK4_A_D
9 | 1AN1_E_I
10 | 1ARZ_A_C
11 | 1ATN_A_D
12 | 1AVX_A_B
13 | 1AY7_A_B
14 | 1B27_A_D
15 | 1B2S_A_D
16 | 1B2U_A_D
17 | 1B3S_A_D
18 | 1B3T_A_B
19 | 1B65_A_B
20 | 1B6C_A_B
21 | 1BJR_I_E
22 | 1BO4_A_B
23 | 1BPO_C_B
24 | 1BRS_A_D
25 | 1C3X_A_C
26 | 1C8N_A_C
27 | 1C9P_A_B
28 | 1C9S_G_H
29 | 1C9T_A_G
30 | 1CBW_ABC_D
31 | 1CGI_E_I
32 | 1CL7_I_H
33 | 1CQ3_A_B
34 | 1D6R_A_I
35 | 1DB2_A_B
36 | 1DEV_C_D
37 | 1DFJ_E_I
38 | 1DJS_A_B
39 | 1DLE_A_B
40 | 1DML_A_B
41 | 1DN2_B_F
42 | 1E5Q_C_D
43 | 1E8N_A_I
44 | 1EAI_A_C
45 | 1EAW_A_B
46 | 1EJA_A_B
47 | 1EM8_A_B
48 | 1EPT_A_B
49 | 1ERN_A_B
50 | 1EWJ_C_D
51 | 1EWY_A_C
52 | 1EZI_A_B
53 | 1EZU_C_AB
54 | 1F2U_C_D
55 | 1F37_A_B
56 | 1F45_A_B
57 | 1F5R_A_I
58 | 1F7Z_A_I
59 | 1F9S_A_D
60 | 1FCC_AB_C
61 | 1FFV_A_C
62 | 1FGL_A_B
63 | 1FIW_A_L
64 | 1FLE_E_I
65 | 1FU5_A_B
66 | 1FY8_E_I
67 | 1G31_C_B
68 | 1G60_A_B
69 | 1G9I_E_I
70 | 1GCQ_C_B
71 | 1GGP_A_B
72 | 1GL0_E_I
73 | 1GL1_A_I
74 | 1GO4_D_H
75 | 1GT7_O_P
76 | 1GUS_C_E
77 | 1GXD_A_C
78 | 1H1V_A_G
79 | 1H6D_I_J
80 | 1H9R_A_B
81 | 1HAA_A_B
82 | 1HBT_I_H
83 | 1HCF_AB_X
84 | 1HEZ_BA_E
85 | 1HIA_AB_I
86 | 1HPU_C_B
87 | 1HX5_A_B
88 | 1HYR_A_C
89 | 1I07_A_B
90 | 1I4E_A_B
91 | 1I4O_B_D
92 | 1I9C_C_D
93 | 1ICF_A_I
94 | 1ID5_H_L
95 | 1IGU_A_B
96 | 1IJX_B_E
97 | 1INN_A_B
98 | 1IYJ_C_D
99 | 1J3R_A_B
100 | 1J4U_A_B
101 | 1JBU_H_X
102 | 1JFM_E_D
103 | 1JIW_P_I
104 | 1JK9_B_A
105 | 1JK9_B_D
106 | 1JK9_C_D
107 | 1JKG_A_B
108 | 1JNP_A_B
109 | 1JTD_B_A
110 | 1JXQ_C_D
111 | 1JYI_A_P
112 | 1JZO_A_B
113 | 1K88_A_B
114 | 1KAC_A_B
115 | 1KCA_G_H
116 | 1KL8_A_B
117 | 1KXJ_A_B
118 | 1KXP_A_D
119 | 1L0A_A_B
120 | 1L2W_C_J
121 | 1L4D_A_B
122 | 1L4I_A_B
123 | 1L4Z_A_B
124 | 1LDT_T_L
125 | 1LK6_I_C
126 | 1LQM_E_F
127 | 1LW6_E_I
128 | 1M1F_A_B
129 | 1MAS_A_B
130 | 1MBY_A_B
131 | 1MCV_A_I
132 | 1MK9_G_F
133 | 1ML0_AB_D
134 | 1MR1_A_D
135 | 1MZW_A_B
136 | 1N0L_C_D
137 | 1NB5_AP_I
138 | 1NP6_A_B
139 | 1NPO_A_C
140 | 1NQ9_I_L
141 | 1NQL_A_B
142 | 1NR7_A_E
143 | 1NR9_C_D
144 | 1NU9_A_C
145 | 1O9A_A_B
146 | 1O9Y_A_D
147 | 1OMO_A_B
148 | 1OS2_E_F
149 | 1OSM_C_B
150 | 1OX9_D_L
151 | 1OYV_B_I
152 | 1P3H_I_H
153 | 1P69_A_B
154 | 1P6A_A_B
155 | 1P9U_F_H
156 | 1PBI_A_B
157 | 1PGL_1_2
158 | 1POI_A_D
159 | 1PPE_E_I
160 | 1PVH_A_B
161 | 1PXV_A_C
162 | 1PXV_B_D
163 | 1Q1L_B_D
164 | 1Q5H_A_B
165 | 1Q8M_A_B
166 | 1Q9U_A_B
167 | 1QB3_A_C
168 | 1QI1_C_B
169 | 1QJS_A_B
170 | 1QOL_G_F
171 | 1R0K_C_D
172 | 1R0R_E_I
173 | 1R7A_A_B
174 | 1R9N_D_H
175 | 1RY7_A_B
176 | 1RZJ_C_G
177 | 1RZP_C_B
178 | 1S1Q_A_B
179 | 1S4C_C_B
180 | 1S98_A_B
181 | 1SCE_A_C
182 | 1SCE_B_D
183 | 1SHS_A_C
184 | 1SHS_A_E
185 | 1SHW_A_B
186 | 1SHY_A_B
187 | 1SMF_E_I
188 | 1SMO_A_B
189 | 1SOT_A_C
190 | 1SUW_C_D
191 | 1T0F_A_B
192 | 1T0H_A_B
193 | 1T0P_A_B
194 | 1T6B_X_Y
195 | 1T7P_A_B
196 | 1T8U_A_B
197 | 1TAW_A_B
198 | 1TE1_A_B
199 | 1TM1_E_I
200 | 1TM3_E_I
201 | 1TM4_E_I
202 | 1TM5_E_I
203 | 1TM7_E_I
204 | 1TO1_E_I
205 | 1TQ9_A_B
206 | 1TZI_BA_V
207 | 1TZS_A_P
208 | 1U20_A_B
209 | 1UAN_A_B
210 | 1UDI_E_I
211 | 1UE7_C_D
212 | 1UGH_E_I
213 | 1UHE_A_B
214 | 1UK4_A_G
215 | 1UM2_A_C
216 | 1UMF_B_D
217 | 1UNN_A_B
218 | 1UP6_A_D
219 | 1UUG_A_B
220 | 1UWG_L_X
221 | 1V8H_A_B
222 | 1VG9_E_G
223 | 1VGC_C_B
224 | 1VGO_A_B
225 | 1VH4_A_B
226 | 1VHJ_C_F
227 | 1VS3_A_B
228 | 1VZY_A_B
229 | 1W1I_B_E
230 | 1W4R_A_D
231 | 1WDX_C_B
232 | 1WLP_A_B
233 | 1WOQ_A_B
234 | 1WQJ_B_I
235 | 1WZ3_A_B
236 | 1X1U_A_D
237 | 1X1W_A_D
238 | 1X1X_A_D
239 | 1X1Y_A_D
240 | 1XD3_A_B
241 | 1XD3_C_D
242 | 1XDT_T_R
243 | 1XFS_A_B
244 | 1XG2_A_B
245 | 1XPJ_A_D
246 | 1XSQ_A_B
247 | 1XT9_A_B
248 | 1XUA_A_B
249 | 1XV2_A_B
250 | 1XWD_A_B
251 | 1Y07_A_B
252 | 1Y0G_A_B
253 | 1Y1K_E_I
254 | 1Y1O_A_B
255 | 1Y33_E_I
256 | 1Y34_E_I
257 | 1Y3B_E_I
258 | 1Y3C_E_I
259 | 1Y3D_E_I
260 | 1Y43_A_B
261 | 1Y48_E_I
262 | 1Y4A_E_I
263 | 1Y4D_E_I
264 | 1Y96_A_B
265 | 1YBG_A_B
266 | 1YC0_A_I
267 | 1YC6_3_E
268 | 1YCS_A_B
269 | 1YFN_D_H
270 | 1YL7_A_D
271 | 1YLQ_A_B
272 | 1YOX_C_B
273 | 1YUK_A_B
274 | 1YVB_A_I
275 | 1YY9_A_D
276 | 1Z0K_A_C
277 | 1Z3G_A_H
278 | 1Z3G_B_I
279 | 1Z5Y_D_E
280 | 1ZCP_A_B
281 | 1ZH8_A_B
282 | 1ZJD_A_B
283 | 1ZLI_A_B
284 | 1ZR0_A_B
285 | 1ZUD_3_4
286 | 1ZVN_A_B
287 | 2A0S_A_B
288 | 2A2L_C_B
289 | 2A5Z_A_C
290 | 2A6P_A_B
291 | 2A74_E_F
292 | 2ABZ_B_E
293 | 2AF6_B_D
294 | 2AFF_A_B
295 | 2ANE_C_D
296 | 2AOB_C_D
297 | 2AQX_A_B
298 | 2AVF_C_B
299 | 2AXW_A_B
300 | 2AZN_A_D
301 | 2B3Z_C_D
302 | 2B42_B_A
303 | 2BBA_A_P
304 | 2BCM_C_B
305 | 2BE1_A_B
306 | 2BMA_C_B
307 | 2BTF_A_P
308 | 2BUJ_A_B
309 | 2C1W_A_B
310 | 2C35_D_H
311 | 2C7E_C_D
312 | 2C9P_A_B
313 | 2CCL_A_C
314 | 2CE8_A_B
315 | 2CH8_A_D
316 | 2CJR_A_B
317 | 2CO6_A_B
318 | 2D10_A_E
319 | 2D1P_D_G
320 | 2D2A_A_B
321 | 2DG0_E_H
322 | 2DOH_X_C
323 | 2DOI_X_C
324 | 2DP4_I_E
325 | 2DPF_C_D
326 | 2DSP_B_I
327 | 2DUP_A_B
328 | 2DYM_G_H
329 | 2E4M_A_C
330 | 2EIL_A_F
331 | 2EJG_B_D
332 | 2EP5_B_D
333 | 2ETE_A_B
334 | 2EVV_C_D
335 | 2EWN_A_B
336 | 2F2F_C_B
337 | 2FB8_A_B
338 | 2FBE_C_D
339 | 2FDB_M_P
340 | 2FE8_A_C
341 | 2FJU_B_A
342 | 2FKD_G_J
343 | 2FP7_A_B
344 | 2FPE_A_B
345 | 2FTL_E_I
346 | 2FTM_A_B
347 | 2FU5_A_D
348 | 2G2U_A_B
349 | 2G2W_A_B
350 | 2G45_A_B
351 | 2G6V_A_B
352 | 2G81_E_I
353 | 2GBK_B_D
354 | 2GD4_C_B
355 | 2GEC_A_B
356 | 2GEF_A_B
357 | 2GHV_C_E
358 | 2GHW_C_D
359 | 2GJV_C_D
360 | 2GKW_A_B
361 | 2GQS_A_B
362 | 2GS7_A_B
363 | 2GT2_A_B
364 | 2H1T_A_B
365 | 2H3N_A_B
366 | 2H5K_A_B
367 | 2HAX_A_B
368 | 2HD3_C_D
369 | 2HD3_E_F
370 | 2HDP_A_B
371 | 2HEK_A_B
372 | 2HEY_T_G
373 | 2HJ1_A_B
374 | 2HL3_A_B
375 | 2HLE_A_B
376 | 2HQH_B_F
377 | 2HQL_A_E
378 | 2HTB_C_D
379 | 2HVB_A_B
380 | 2HVY_A_C
381 | 2HWJ_A_C
382 | 2HZM_G_H
383 | 2HZS_B_H
384 | 2HZS_F_K
385 | 2I04_A_B
386 | 2I0B_A_B
387 | 2I32_A_E
388 | 2I5G_A_B
389 | 2I79_C_B
390 | 2I7R_A_B
391 | 2I9B_C_G
392 | 2IA9_A_B
393 | 2IDO_A_B
394 | 2IHS_A_C
395 | 2IJ0_C_B
396 | 2IO1_A_B
397 | 2IQH_A_C
398 | 2IWO_A_B
399 | 2IWP_A_B
400 | 2IY1_A_B
401 | 2J12_A_B
402 | 2J7Q_C_D
403 | 2J8X_A_B
404 | 2JG8_D_F
405 | 2JI1_C_D
406 | 2JJS_A_C
407 | 2JJS_B_D
408 | 2JJT_A_C
409 | 2JOD_A_B
410 | 2K2S_A_B
411 | 2K6D_A_B
412 | 2KAI_A_B
413 | 2KWJ_A_B
414 | 2L0F_A_B
415 | 2L29_A_B
416 | 2LBU_E_D
417 | 2MCN_A_B
418 | 2MNU_A_B
419 | 2MTA_HL_A
420 | 2NBV_A_B
421 | 2NM1_A_B
422 | 2NN3_C_D
423 | 2NQD_A_B
424 | 2NU1_I_E
425 | 2NUU_K_L
426 | 2NXM_A_B
427 | 2NZ1_X_Y
428 | 2O8Q_A_B
429 | 2O95_A_B
430 | 2O9Q_A_C
431 | 2OGJ_E_F
432 | 2OIN_A_C
433 | 2OKQ_A_B
434 | 2OPI_A_B
435 | 2OS5_A_D
436 | 2OS7_C_F
437 | 2OTL_A_Z
438 | 2OUL_A_B
439 | 2OVI_A_D
440 | 2OYA_A_B
441 | 2OZK_B_D
442 | 2P04_A_B
443 | 2P35_A_B
444 | 2P42_A_B
445 | 2P43_A_B
446 | 2P44_A_B
447 | 2P45_A_B
448 | 2P46_A_B
449 | 2P47_A_B
450 | 2P48_A_B
451 | 2P49_A_B
452 | 2P4A_A_B
453 | 2P4R_A_T
454 | 2P4Z_A_B
455 | 2P5X_A_B
456 | 2P6B_A_E
457 | 2PKG_A_C
458 | 2PMV_A_B
459 | 2PNH_A_B
460 | 2PO2_A_B
461 | 2PQ2_A_B
462 | 2PQS_A_B
463 | 2PQV_A_B
464 | 2PTC_E_I
465 | 2PUY_B_E
466 | 2PZD_A_B
467 | 2Q17_C_B
468 | 2Q7N_A_B
469 | 2Q81_A_D
470 | 2QBW_A_B
471 | 2QBX_A_P
472 | 2QC1_A_B
473 | 2QF3_A_C
474 | 2QKI_A_G
475 | 2QLC_C_B
476 | 2QLP_A_C
477 | 2QW7_I_H
478 | 2QYI_C_D
479 | 2R0K_A_H
480 | 2R2C_A_B
481 | 2R5O_A_B
482 | 2R9P_A_E
483 | 2RA3_A_C
484 | 2RL7_C_D
485 | 2SIC_E_I
486 | 2SNI_E_I
487 | 2TGP_Z_I
488 | 2UUY_A_B
489 | 2V0R_A_B
490 | 2V3B_A_B
491 | 2V52_B_M
492 | 2VE6_A_D
493 | 2VER_A_N
494 | 2VIF_A_P
495 | 2VJF_C_D
496 | 2VPM_A_B
497 | 2VSC_A_B
498 | 2W0C_C_B
499 | 2W1T_A_B
500 | 2W2N_A_E
501 | 2W80_A_D
502 | 2W80_G_H
503 | 2W81_A_D
504 | 2WAM_A_C
505 | 2WC4_A_C
506 | 2WFX_A_B
507 | 2WG3_A_C
508 | 2WG4_A_B
509 | 2WLG_C_B
510 | 2WO2_A_B
511 | 2WO3_A_B
512 | 2WQ4_A_C
513 | 2WQZ_A_D
514 | 2WVT_A_B
515 | 2WWX_A_B
516 | 2X36_C_D
517 | 2X53_W_V
518 | 2X5Q_A_B
519 | 2X89_F_G
520 | 2X8K_A_B
521 | 2X9A_A_D
522 | 2X9A_C_D
523 | 2X9M_B_D
524 | 2XB6_A_C
525 | 2XBB_A_C
526 | 2XCE_E_D
527 | 2XFG_A_B
528 | 2XJZ_C_K
529 | 2XTJ_A_D
530 | 2Y32_B_D
531 | 2Y9X_B_F
532 | 2YC2_A_D
533 | 2YCH_A_B
534 | 2YH9_A_C
535 | 2YVJ_A_B
536 | 2YVL_B_D
537 | 2YVS_A_B
538 | 2YYS_A_B
539 | 2YZJ_A_C
540 | 2Z0E_A_B
541 | 2Z0P_C_D
542 | 2Z29_A_B
543 | 2Z2M_C_B
544 | 2Z7F_E_I
545 | 2Z7X_A_B
546 | 2Z8M_A_B
547 | 2ZA4_A_B
548 | 2ZCK_P_S
549 | 2ZDC_A_B
550 | 2ZG6_A_B
551 | 2ZME_A_C
552 | 2ZNV_E_D
553 | 2ZSU_A_B
554 | 2ZVW_C_E
555 | 2ZVW_C_K
556 | 2ZXW_O_U
557 | 3A1P_C_D
558 | 3AD8_B_D
559 | 3AEH_A_B
560 | 3AFF_A_B
561 | 3AHS_A_C
562 | 3AJY_A_C
563 | 3ALZ_A_B
564 | 3AOG_H_L
565 | 3AXY_B_D
566 | 3B01_A_C
567 | 3B08_A_B
568 | 3B5U_B_D
569 | 3B5U_J_L
570 | 3B6P_B_D
571 | 3B76_A_B
572 | 3B93_A_C
573 | 3B9I_A_B
574 | 3BAL_A_B
575 | 3BCP_A_B
576 | 3BCW_A_B
577 | 3BFW_C_D
578 | 3BGL_C_B
579 | 3BHD_A_B
580 | 3BIW_A_E
581 | 3BN3_A_B
582 | 3BPD_G_F
583 | 3BQB_A_X
584 | 3BRC_A_B
585 | 3BRD_A_D
586 | 3BT1_B_U
587 | 3BTV_A_B
588 | 3BWU_C_D
589 | 3BX1_A_C
590 | 3BX7_A_C
591 | 3C0B_C_D
592 | 3C4O_A_B
593 | 3C4P_A_B
594 | 3C7T_A_C
595 | 3C8I_A_B
596 | 3CAM_A_B
597 | 3CDW_A_H
598 | 3CE9_A_D
599 | 3CEW_C_D
600 | 3CG8_C_B
601 | 3CGY_A_B
602 | 3CHW_A_P
603 | 3CJX_G_H
604 | 3CO2_A_D
605 | 3CQ9_C_D
606 | 3D1E_A_P
607 | 3D1M_A_D
608 | 3D4G_C_D
609 | 3D4R_A_B
610 | 3D5N_F_I
611 | 3DA7_A_D
612 | 3DAW_A_B
613 | 3DAX_A_B
614 | 3DCA_C_D
615 | 3DCL_A_B
616 | 3DGP_A_B
617 | 3DJP_A_B
618 | 3DKU_A_B
619 | 3DQQ_A_B
620 | 3DSN_C_B
621 | 3E05_B_H
622 | 3E05_C_B
623 | 3E1Z_A_B
624 | 3E2K_A_D
625 | 3E2L_A_C
626 | 3E2U_A_E
627 | 3E38_A_B
628 | 3E9M_A_D
629 | 3ECY_A_B
630 | 3EDP_A_B
631 | 3EHU_A_C
632 | 3EMJ_K_L
633 | 3EN0_A_B
634 | 3ENT_A_B
635 | 3EPZ_A_B
636 | 3EUK_C_E
637 | 3EYD_C_D
638 | 3F5N_A_D
639 | 3F74_A_B
640 | 3F75_A_P
641 | 3FCG_A_B
642 | 3FD4_A_B
643 | 3FEF_C_D
644 | 3FFU_A_B
645 | 3FG8_C_E
646 | 3FHC_A_B
647 | 3FJS_C_D
648 | 3FJU_A_B
649 | 3FK9_A_B
650 | 3FL1_A_B
651 | 3FLP_M_N
652 | 3FP6_E_I
653 | 3FPR_A_D
654 | 3FPU_A_B
655 | 3FPV_A_F
656 | 3FSN_A_B
657 | 3FUY_A_B
658 | 3FYF_A_B
659 | 3G9V_A_B
660 | 3GBU_A_D
661 | 3GFU_A_B
662 | 3GMW_A_B
663 | 3GNJ_C_D
664 | 3GQH_A_B
665 | 3GRW_A_H
666 | 3GWY_A_B
667 | 3GXU_A_B
668 | 3GZ8_A_B
669 | 3GZE_A_X
670 | 3GZE_C_Y
671 | 3GZR_A_B
672 | 3H11_BC_A
673 | 3H35_C_B
674 | 3H3B_A_C
675 | 3H6S_A_E
676 | 3H8D_B_F
677 | 3H8G_B_E
678 | 3H9G_A_E
679 | 3HCG_A_C
680 | 3HF5_A_D
681 | 3HHJ_A_B
682 | 3HLN_O_U
683 | 3HM8_A_C
684 | 3HMK_A_B
685 | 3HN6_B_D
686 | 3HO5_B_H
687 | 3HPN_E_F
688 | 3HQR_A_S
689 | 3HRD_E_H
690 | 3HT2_A_C
691 | 3HTR_A_B
692 | 3I2B_E_H
693 | 3I5V_A_D
694 | 3I84_A_B
695 | 3IAS_G_F
696 | 3IBM_A_B
697 | 3ISM_A_B
698 | 3JRQ_A_B
699 | 3JUY_C_B
700 | 3JVC_A_C
701 | 3JVZ_B_D
702 | 3JZA_A_B
703 | 3K1R_A_B
704 | 3K25_A_B
705 | 3K3C_A_B
706 | 3K4W_D_F
707 | 3K6S_A_E
708 | 3K9M_A_C
709 | 3K9M_B_D
710 | 3KDG_A_B
711 | 3KL9_I_J
712 | 3KLQ_A_B
713 | 3KMH_A_B
714 | 3KMT_A_B
715 | 3KQZ_K_L
716 | 3KTM_C_F
717 | 3KTS_A_B
718 | 3KW5_A_B
719 | 3KWV_D_F
720 | 3KY8_A_B
721 | 3KZH_A_B
722 | 3L2H_A_D
723 | 3L33_A_E
724 | 3L9J_C_T
725 | 3LAQ_A_U
726 | 3LHX_A_B
727 | 3LM1_E_F
728 | 3LMS_A_B
729 | 3LQV_B_Q
730 | 3LRJ_C_D
731 | 3LU9_E_F
732 | 3M5O_A_C
733 | 3M5R_A_G
734 | 3M85_B_E
735 | 3M85_B_G
736 | 3MAL_A_B
737 | 3ME4_A_B
738 | 3MJ9_A_H
739 | 3ML6_B_F
740 | 3MQW_C_D
741 | 3MZW_A_B
742 | 3N2B_A_C
743 | 3N4I_A_B
744 | 3N6Q_E_H
745 | 3NCT_B_D
746 | 3NEK_A_B
747 | 3NFG_G_H
748 | 3NGB_E_G
749 | 3NPG_C_D
750 | 3NRJ_A_H
751 | 3NS1_A_B
752 | 3NTQ_A_B
753 | 3O2X_B_D
754 | 3O34_A_B
755 | 3O9L_A_B
756 | 3OEU_2_M
757 | 3OGF_A_B
758 | 3OGO_A_E
759 | 3OJ2_A_C
760 | 3OJM_A_B
761 | 3OKJ_C_D
762 | 3OLM_A_D
763 | 3OSL_A_B
764 | 3OZB_A_C
765 | 3P71_C_T
766 | 3P83_B_F
767 | 3P8B_C_D
768 | 3P92_A_E
769 | 3P95_A_E
770 | 3P9W_A_B
771 | 3PCQ_C_D
772 | 3PGA_1_4
773 | 3PIG_A_B
774 | 3PIM_A_B
775 | 3PNR_A_B
776 | 3PPE_A_B
777 | 3PRP_A_B
778 | 3PS4_C_B
779 | 3PYY_A_B
780 | 3Q0Y_C_B
781 | 3Q7H_K_M
782 | 3Q87_A_B
783 | 3Q9N_A_C
784 | 3Q9U_A_C
785 | 3QC8_A_B
786 | 3QDZ_B_E
787 | 3QFM_A_B
788 | 3QHY_A_B
789 | 3QJ7_A_D
790 | 3QNA_A_D
791 | 3QPB_K_J
792 | 3QQ8_A_B
793 | 3QSK_A_B
794 | 3QWN_I_J
795 | 3QWQ_A_B
796 | 3RBQ_B_H
797 | 3RDZ_A_C
798 | 3RT0_A_C
799 | 3S5B_A_B
800 | 3S8V_A_X
801 | 3S9C_A_B
802 | 3SGB_E_I
803 | 3SGQ_E_I
804 | 3SJ9_A_B
805 | 3SLH_A_B
806 | 3SM1_A_B
807 | 3SOQ_A_Z
808 | 3T3A_A_B
809 | 3TDM_A_B
810 | 3TG9_A_B
811 | 3TGK_E_I
812 | 3THT_A_B
813 | 3TII_A_B
814 | 3TIW_B_D
815 | 3TL8_A_B
816 | 3TND_B_D
817 | 3TQY_A_B
818 | 3TSR_A_E
819 | 3TSZ_A_B
820 | 3TU3_A_B
821 | 3U02_B_D
822 | 3U1O_A_B
823 | 3U4J_C_B
824 | 3UI2_A_B
825 | 3UIR_A_C
826 | 3UVN_C_D
827 | 3UZP_A_B
828 | 3UZV_A_B
829 | 3V3K_A_B
830 | 3V4P_B_H
831 | 3V5N_C_D
832 | 3V96_A_B
833 | 3VPJ_A_E
834 | 3VYR_A_B
835 | 3WA5_A_B
836 | 3WDG_A_B
837 | 3WN7_A_B
838 | 3ZRZ_A_C
839 | 3ZWL_B_E
840 | 4A94_A_D
841 | 4AFQ_A_C
842 | 4AFZ_A_C
843 | 4AG2_A_C
844 | 4AN7_A_B
845 | 4AOQ_A_D
846 | 4AOR_A_D
847 | 4APF_A_B
848 | 4AYD_A_D
849 | 4AYE_A_D
850 | 4AYI_A_D
851 | 4B1V_B_N
852 | 4B1X_B_M
853 | 4B1Y_B_M
854 | 4BD9_A_B
855 | 4BQD_A_C
856 | 4BWQ_E_F
857 | 4CDK_A_E
858 | 4CJ0_A_B
859 | 4CJ1_A_B
860 | 4CMM_A_B
861 | 4CPA_A_I
862 | 4DG4_A_E
863 | 4DGE_A_C
864 | 4DOQ_A_B
865 | 4EIG_A_B
866 | 4EQA_A_C
867 | 4F0A_A_B
868 | 4FT4_B_Q
869 | 4FZA_A_B
870 | 4G6U_A_B
871 | 4GH7_A_B
872 | 4GI3_A_C
873 | 4HDO_A_B
874 | 4I6L_A_B
875 | 4ILW_A_D
876 | 4IOP_A_B
877 | 4J2Y_A_B
878 | 4JRA_A_D
879 | 4JW3_A_C
880 | 4K1R_A_B
881 | 4K24_A_U
882 | 4KBB_A_C
883 | 4KDI_A_C
884 | 4KFZ_A_C
885 | 4KGG_C_A
886 | 4KR0_A_B
887 | 4KRL_B_A
888 | 4KSD_A_B
889 | 4KV5_C_D
890 | 4L0P_A_B
891 | 4LAD_A_B
892 | 4LLO_A_B
893 | 4LQW_A_D
894 | 4LYL_A_B
895 | 4M0W_A_B
896 | 4M5F_A_B
897 | 4MSM_A_B
898 | 4NOO_A_B
899 | 4NSO_A_B
900 | 4NZL_A_B
901 | 4NZW_A_B
902 | 4PEQ_A_B
903 | 4PJ2_A_D
904 | 4POU_A_B
905 | 4PQT_A_B
906 | 4QT8_A_C
907 | 4QZV_A_B
908 | 4RS1_B_A
909 | 4TQ0_A_B
910 | 4TQ1_A_B
911 | 4U30_A_X
912 | 4U97_A_B
913 | 4UDM_B_A
914 | 4V0M_F_E
915 | 4V0N_D_C
916 | 4V0O_F_E
917 | 4W6X_A_B
918 | 4WEM_A_B
919 | 4WEN_A_B
920 | 4XL1_A_B
921 | 4XL5_A_C
922 | 4XLW_A_B
923 | 4XXB_A_B
924 | 4YDJ_HL_G
925 | 4YEB_A_B
926 | 4YN0_A_B
927 | 4YWC_A_C
928 | 4ZK9_A_B
929 | 4ZKC_A_B
930 | 4ZQU_A_B
931 | 5AYR_A_B
932 | 5AYS_A_C
933 | 5B75_A_B
934 | 5B76_A_B
935 | 5B77_A_B
936 | 5B78_A_B
937 | 5CJO_HL_A
938 | 5D1K_A_B
939 | 5D1L_A_B
940 | 5D1M_A_B
941 | 5D3I_A_B
942 | 5DJT_A_B
943 | 5DMJ_A_B
944 | 5EB1_A_B
945 | 5F3X_A_B
946 | 5F4E_A_B
947 | 5G1X_A_B
948 | 5GPG_A_B
949 | 5HPK_A_B
950 | 5INB_A_B
951 | 5IOH_A_B
952 | 5J28_A_C
953 | 5JKE_A_B
954 | 5JLV_A_C
955 | 5JMC_A_B
956 | 5JYL_A_B
957 | 6CMG_CB_A
958 | 4ZQK_A_B
959 | 3BIK_A_B
960 |
--------------------------------------------------------------------------------
/main_inference.py:
--------------------------------------------------------------------------------
1 | # Standard imports:
2 | import numpy as np
3 | import torch
4 | from torch.utils.tensorboard import SummaryWriter
5 | from torch.utils.data import random_split
6 | from torch_geometric.data import DataLoader
7 | from torch_geometric.transforms import Compose
8 | from pathlib import Path
9 |
10 | # Custom data loader and model:
11 | from data import ProteinPairsSurfaces, PairData, CenterPairAtoms, load_protein_pair
12 | from data import RandomRotationPairAtoms, NormalizeChemFeatures, iface_valid_filter
13 | from model import dMaSIF
14 | from data_iteration import iterate
15 | from helper import *
16 | from Arguments import parser
17 |
# Parse the command-line arguments declared in Arguments.py:
args = parser.parse_args()
model_path = "models/" + args.experiment_name
save_predictions_path = Path("preds/" + args.experiment_name)

# Ensure reproducibility by seeding every RNG in use:
torch.backends.cudnn.deterministic = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)


# Pre-processing transforms applied to every protein pair
# (chemical-feature normalization, plus optional centering + random rotation):
transformations = (
    Compose([NormalizeChemFeatures(), CenterPairAtoms(), RandomRotationPairAtoms()])
    if args.random_rotation
    else Compose([NormalizeChemFeatures()])
)

# Three mutually exclusive ways of specifying the test proteins:
# a single PDB id, a text file with one id per line,
# or the full pre-processed benchmark dataset.
if args.single_pdb != "":
    single_data_dir = Path("./data_preprocessing/npys/")
    test_dataset = [load_protein_pair(args.single_pdb, single_data_dir,single_pdb=True)]
    test_pdb_ids = [args.single_pdb]
elif args.pdb_list != "":
    with open(args.pdb_list) as f:
        pdb_list = f.read().splitlines()
    single_data_dir = Path("./data_preprocessing/npys/")
    test_dataset = [load_protein_pair(pdb, single_data_dir,single_pdb=True) for pdb in pdb_list]
    test_pdb_ids = [pdb for pdb in pdb_list]
else:
    test_dataset = ProteinPairsSurfaces(
        "surface_data", train=False, ppi=args.search, transform=transformations
    )
    # NOTE(review): the id list is selected with args.site while the dataset
    # above is built from args.search -- presumably the two flags are mutually
    # exclusive; confirm against Arguments.py before relying on this branch.
    test_pdb_ids = (
        np.load("surface_data/processed/testing_pairs_data_ids.npy")
        if args.site
        else np.load("surface_data/processed/testing_pairs_data_ids_ppi.npy")
    )

# Keep only the pairs that pass the interface validity filter,
# preserving the (data, pdb_id) alignment:
test_dataset = [
    (data, pdb_id)
    for data, pdb_id in zip(test_dataset, test_pdb_ids)
    if iface_valid_filter(data)
]
test_dataset, test_pdb_ids = list(zip(*test_dataset))


# PyTorch geometric expects an explicit list of "batched variables":
batch_vars = ["xyz_p1", "xyz_p2", "atom_coords_p1", "atom_coords_p2"]
test_loader = DataLoader(
    test_dataset, batch_size=args.batch_size, follow_batch=batch_vars
)

# Build the model and restore the pre-trained weights.
# The checkpoint is a dict; the weights live under "model_state_dict".
net = dMaSIF(args)
# net.load_state_dict(torch.load(model_path, map_location=args.device))
net.load_state_dict(
    torch.load(model_path, map_location=args.device)["model_state_dict"]
)
net = net.to(args.device)

# Perform one evaluation pass through the data,
# writing per-protein predictions to save_predictions_path:
info = iterate(
    net,
    test_loader,
    None,
    args,
    test=True,
    save_path=save_predictions_path,
    pdb_ids=test_pdb_ids,
)

#np.save(f"timings/{args.experiment_name}_convtime.npy", info["conv_time"])
#np.save(f"timings/{args.experiment_name}_memoryusage.npy", info["memory_usage"])
90 |
--------------------------------------------------------------------------------
/main_training.py:
--------------------------------------------------------------------------------
# Standard imports:
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
from torch_geometric.data import DataLoader
from torch_geometric.transforms import Compose
from pathlib import Path

# Custom data loader and model:
from data import ProteinPairsSurfaces, PairData, CenterPairAtoms
from data import RandomRotationPairAtoms, NormalizeChemFeatures, iface_valid_filter
from model import dMaSIF
from data_iteration import iterate, iterate_surface_precompute
from helper import *
from Arguments import parser

# Parse the arguments, prepare the TensorBoard writer:
args = parser.parse_args()
writer = SummaryWriter("runs/{}".format(args.experiment_name))
model_path = "models/" + args.experiment_name

if not Path("models/").exists():
    Path("models/").mkdir(exist_ok=False)

# Ensure reproducibility:
torch.backends.cudnn.deterministic = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)

# Create the model, with a warm restart if applicable:
net = dMaSIF(args)
net = net.to(args.device)

# We load the train and test datasets.
# Random transforms, to ensure that no network/baseline overfits on pose parameters:
transformations = (
    Compose([NormalizeChemFeatures(), CenterPairAtoms(), RandomRotationPairAtoms()])
    if args.random_rotation
    else Compose([NormalizeChemFeatures()])
)

# PyTorch geometric expects an explicit list of "batched variables":
batch_vars = ["xyz_p1", "xyz_p2", "atom_coords_p1", "atom_coords_p2"]
# Load the train dataset:
train_dataset = ProteinPairsSurfaces(
    "surface_data", ppi=args.search, train=True, transform=transformations
)
train_dataset = [data for data in train_dataset if iface_valid_filter(data)]
train_loader = DataLoader(
    train_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
print("Preprocessing training dataset")
# Pre-compute the surface point clouds once, so that the epochs run faster:
train_dataset = iterate_surface_precompute(train_loader, net, args)

# Train/Validation split:
train_nsamples = len(train_dataset)
val_nsamples = int(train_nsamples * args.validation_fraction)
train_nsamples = train_nsamples - val_nsamples
train_dataset, val_dataset = random_split(
    train_dataset, [train_nsamples, val_nsamples]
)

# Load the test dataset:
test_dataset = ProteinPairsSurfaces(
    "surface_data", ppi=args.search, train=False, transform=transformations
)
test_dataset = [data for data in test_dataset if iface_valid_filter(data)]
test_loader = DataLoader(
    test_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
print("Preprocessing testing dataset")
test_dataset = iterate_surface_precompute(test_loader, net, args)


# PyTorch_geometric data loaders:
train_loader = DataLoader(
    train_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
val_loader = DataLoader(val_dataset, batch_size=1, follow_batch=batch_vars)
test_loader = DataLoader(test_dataset, batch_size=1, follow_batch=batch_vars)


# Baseline optimizer:
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, amsgrad=True)
best_loss = 1e10  # We save the "best model so far"

# Optionally resume from a saved checkpoint:
starting_epoch = 0
if args.restart_training != "":
    checkpoint = torch.load("models/" + args.restart_training)
    net.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # NOTE(review): the stored "epoch" is reused as-is (not incremented),
    # so the checkpointed epoch is repeated on resume — confirm intended.
    starting_epoch = checkpoint["epoch"]
    best_loss = checkpoint["best_loss"]

# Training loop (~100 times) over the dataset:
for i in range(starting_epoch, args.n_epochs):
    # Train first, Test second:
    for dataset_type in ["Train", "Validation", "Test"]:
        if dataset_type == "Train":
            test = False
        else:
            test = True

        suffix = dataset_type
        if dataset_type == "Train":
            dataloader = train_loader
        elif dataset_type == "Validation":
            dataloader = val_loader
        elif dataset_type == "Test":
            dataloader = test_loader

        # Perform one pass through the data:
        info = iterate(
            net,
            dataloader,
            optimizer,
            args,
            test=test,
            summary_writer=writer,
            epoch_number=i,
        )

        # Write down the results using a TensorBoard writer:
        for key, val in info.items():
            if key in [
                "Loss",
                "ROC-AUC",
                "Distance/Positives",
                "Distance/Negatives",
                "Matching ROC-AUC",
            ]:
                writer.add_scalar(f"{key}/{suffix}", np.mean(val), i)

            if "R_values/" in key:
                val = np.array(val)
                # Average only over strictly positive R-values:
                writer.add_scalar(f"{key}/{suffix}", np.mean(val[val > 0]), i)

        if dataset_type == "Validation":  # Store validation loss for saving the model
            val_loss = np.mean(info["Loss"])

        # NOTE(review): `val_loss` is only assigned during the "Validation"
        # pass, yet this save block sits inside the dataset-type loop — it
        # would run (and reference `val_loss`) after the "Train" pass too.
        # Verify the intended indentation/ordering of this block.
        if True:  # Additional saves
            if val_loss < best_loss:
                print("Validation loss {}, saving model".format(val_loss))
                torch.save(
                    {
                        "epoch": i,
                        "model_state_dict": net.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        # NOTE(review): this stores the *previous* best loss,
                        # not `val_loss` — confirm when reloading checkpoints.
                        "best_loss": best_loss,
                    },
                    model_path + "_epoch{}".format(i),
                )

                best_loss = val_loss
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.autograd.profiler as profiler
7 | from pykeops.torch import LazyTensor
8 |
9 | from geometry_processing import (
10 | curvatures,
11 | mesh_normals_areas,
12 | tangent_vectors,
13 | atoms_to_points_normals,
14 | )
15 | from helper import soft_dimension, diagonal_ranges
16 | from benchmark_models import DGCNN_seg, PointNet2_seg, dMaSIFConv_seg
17 |
18 |
def knn_atoms(x, y, x_batch, y_batch, k):
    """Find the k nearest atoms in `y` for every query point of `x`.

    Batches are kept separate through KeOps block-diagonal `ranges`.
    Returns the neighbor indices (N, k) and the squared distances (N, k).
    """
    num_queries, dim = x.shape
    queries_i = LazyTensor(x[:, None, :])
    atoms_j = LazyTensor(y[None, :, :])

    sq_dist_ij = ((queries_i - atoms_j) ** 2).sum(-1)
    sq_dist_ij.ranges = diagonal_ranges(x_batch, y_batch)

    # KeOps Kmin reductions are not differentiable yet, so we fetch the
    # indices with argKmin and recompute the distances in plain PyTorch:
    idx = sq_dist_ij.argKmin(K=k, axis=1)  # (N, K)
    neighbors = y[idx.view(-1)].view(num_queries, k, dim)
    dists = ((x[:, None, :] - neighbors) ** 2).sum(-1)

    return idx, dists
35 |
36 |
def get_atom_features(x, y, x_batch, y_batch, y_atomtype, k=16):
    """Gather, for each point of `x`, the type vectors of its k nearest
    atoms in `y`, augmented with the inverse squared distance.

    Returns a tensor of shape (num_points, k, num_dims + 1).
    """
    idx, sq_dists = knn_atoms(x, y, x_batch, y_batch, k=k)  # (num_points, k)
    num_points = idx.shape[0]
    num_dims = y_atomtype.shape[1]

    flat_idx = idx.view(-1)
    inv_dists = 1 / sq_dists.view(-1, 1)

    gathered = torch.cat([y_atomtype[flat_idx, :], inv_dists], dim=1)
    return gathered.view(num_points, k, num_dims + 1)
51 |
52 |
class Atom_embedding(nn.Module):
    """Embed surface points from their atomic neighborhoods: two rounds of
    linear mixing + sum-pooling over the k nearest atoms, fused by a final
    linear layer."""

    def __init__(self, args):
        super(Atom_embedding, self).__init__()
        self.D = args.atom_dims
        self.k = 16
        self.conv1 = nn.Linear(self.D + 1, self.D)
        self.conv2 = nn.Linear(self.D, self.D)
        self.conv3 = nn.Linear(2 * self.D, self.D)
        self.bn1 = nn.BatchNorm1d(self.D)
        self.bn2 = nn.BatchNorm1d(self.D)
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x, y, y_atomtypes, x_batch, y_batch):
        """Return a (num_points, D) chemical embedding for the points `x`."""
        neighborhood = get_atom_features(x, y, x_batch, y_batch, y_atomtypes, k=self.k)

        # Round 1: linear map, then LeakyReLU followed by BatchNorm
        # (flattened to 2D for the norm), then sum-pool over neighbors.
        h = self.conv1(neighborhood)
        h = self.bn1(self.relu(h.view(-1, self.D))).view(-1, self.k, self.D)
        pooled_1 = h.sum(dim=1, keepdim=False)

        # Round 2: same scheme on the normalized neighborhood features.
        h = self.conv2(h)
        h = self.bn2(self.relu(h.view(-1, self.D))).view(-1, self.k, self.D)
        pooled_2 = h.sum(dim=1, keepdim=False)

        # Fuse both pooled representations:
        return self.conv3(torch.cat((pooled_1, pooled_2), dim=-1))
82 |
83 |
class AtomNet(nn.Module):
    """Chemical feature extractor: a pointwise MLP on the atom types,
    followed by an atom-to-point embedding."""

    def __init__(self, args):
        super(AtomNet, self).__init__()
        self.args = args

        # Pointwise 3-layer MLP applied to every atom-type vector:
        self.transform_types = nn.Sequential(
            nn.Linear(args.atom_dims, args.atom_dims),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(args.atom_dims, args.atom_dims),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(args.atom_dims, args.atom_dims),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.embed = Atom_embedding(args)

    def forward(self, xyz, atom_xyz, atomtypes, batch, atom_batch):
        """Return chemical features sampled at the surface points `xyz`."""
        transformed = self.transform_types(atomtypes)
        return self.embed(xyz, atom_xyz, transformed, batch, atom_batch)
103 |
class Atom_embedding_MP(nn.Module):
    """Message passing from atoms to surface points: each point aggregates
    MLP-transformed messages from its k nearest atoms over `n_layers`
    rounds, with a residual update."""

    def __init__(self, args):
        super(Atom_embedding_MP, self).__init__()
        self.D = args.atom_dims
        self.k = 16
        self.n_layers = 3
        self.mlp = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(2 * self.D + 1, 2 * self.D + 1),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.Linear(2 * self.D + 1, self.D),
                )
                for i in range(self.n_layers)
            ]
        )
        self.norm = nn.ModuleList(
            [nn.GroupNorm(2, self.D) for i in range(self.n_layers)]
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x, y, y_atomtypes, x_batch, y_batch):
        """Return a (num_points, D) embedding of the points `x`."""
        idx, dists = knn_atoms(x, y, x_batch, y_batch, k=self.k)
        num_points = x.shape[0]
        num_dims = y_atomtypes.shape[-1]

        # Point embeddings start as all-ones vectors:
        point_emb = torch.ones_like(x[:, 0])[:, None].repeat(1, num_dims)
        for mlp, norm in zip(self.mlp, self.norm):
            # Per-edge features: neighbor atom type + squared distance.
            neigh = y_atomtypes[idx.reshape(-1), :]
            neigh = torch.cat([neigh, dists.reshape(-1, 1)], dim=1)
            neigh = neigh.view(num_points, self.k, num_dims + 1)
            # Prepend the receiving point's current embedding to every edge:
            center = point_emb[:, None, :].repeat(1, self.k, 1)
            features = torch.cat([center, neigh], dim=-1)

            # Sum-aggregate the messages and apply a residual update:
            messages = mlp(features).sum(1)
            point_emb = point_emb + self.relu(norm(messages))

        return point_emb
144 |
class Atom_Atom_embedding_MP(nn.Module):
    """Atom-to-atom message passing over the k-NN graph, excluding
    self-edges, with residual updates across `n_layers` rounds."""

    def __init__(self, args):
        super(Atom_Atom_embedding_MP, self).__init__()
        self.D = args.atom_dims
        self.k = 17
        self.n_layers = 3

        self.mlp = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(2 * self.D + 1, 2 * self.D + 1),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.Linear(2 * self.D + 1, self.D),
                )
                for i in range(self.n_layers)
            ]
        )

        self.norm = nn.ModuleList(
            [nn.GroupNorm(2, self.D) for i in range(self.n_layers)]
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x, y, y_atomtypes, x_batch, y_batch):
        """Return updated (num_atoms, D) atom embeddings."""
        # k = 17 so that 16 neighbors remain after dropping the self-edge
        # (the nearest neighbor when x and y are the same point cloud).
        idx, dists = knn_atoms(x, y, x_batch, y_batch, k=self.k)
        idx, dists = idx[:, 1:], dists[:, 1:]  # drop the self-neighbor
        k = self.k - 1
        num_points = y_atomtypes.shape[0]

        out = y_atomtypes
        for mlp, norm in zip(self.mlp, self.norm):
            num_dims = out.shape[1]
            # Per-edge features: neighbor embedding + squared distance.
            neigh = out[idx.reshape(-1), :]
            neigh = torch.cat([neigh, dists.reshape(-1, 1)], dim=1)
            neigh = neigh.view(num_points, k, num_dims + 1)
            # Prepend the receiving atom's current embedding to every edge:
            features = torch.cat([out[:, None, :].repeat(1, k, 1), neigh], dim=-1)

            # Sum-aggregate the messages and apply a residual update:
            messages = mlp(features).sum(1)
            out = out + self.relu(norm(messages))

        return out
190 |
class AtomNet_MP(nn.Module):
    """Chemical feature extractor used by dMaSIF: a pointwise MLP on the
    atom types, atom-to-atom message passing, then atom-to-point message
    passing onto the surface samples."""

    def __init__(self, args):
        super(AtomNet_MP, self).__init__()
        self.args = args

        # Pointwise 2-layer MLP applied to every atom-type vector:
        self.transform_types = nn.Sequential(
            nn.Linear(args.atom_dims, args.atom_dims),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(args.atom_dims, args.atom_dims),
        )

        self.embed = Atom_embedding_MP(args)
        self.atom_atom = Atom_Atom_embedding_MP(args)

    def forward(self, xyz, atom_xyz, atomtypes, batch, atom_batch):
        """Return chemical features sampled at the surface points `xyz`."""
        features = self.transform_types(atomtypes)
        # Refine atom embeddings on the atom-atom k-NN graph:
        features = self.atom_atom(
            atom_xyz, atom_xyz, features, atom_batch, atom_batch
        )
        # Project the refined atom features onto the surface points:
        return self.embed(xyz, atom_xyz, features, batch, atom_batch)
213 |
214 |
def combine_pair(P1, P2):
    """Merge the dictionaries of two proteins into one batched dictionary.

    Tensors are concatenated along dim 0; batch-index vectors from the
    second protein are shifted past the first protein's last index.
    `None` entries and the non-batchable "triangles" key are dropped.
    """
    merged = {}
    for key, v1 in P1.items():
        v2 = P2[key]
        if v1 is None:
            continue

        if key in ("batch", "batch_atoms"):
            # Shift P2's batch indices so they follow P1's:
            merged[key] = torch.cat([v1, v2 + v1[-1] + 1], dim=0)
        elif key == "triangles":
            # Triangle indices are not batch-aware; skip them.
            # merged[key] = torch.cat([v1, v2], dim=1)
            continue
        else:
            merged[key] = torch.cat([v1, v2], dim=0)

    return merged
233 |
234 |
def split_pair(P1P2):
    """Undo `combine_pair`: split a batched pair dictionary back into the
    two protein dictionaries (P1, P2).

    The first half of the batch indices belongs to P1, the second half to
    P2. P2's batch indices are renormalized to start at zero again.

    NOTE(review): the original code renormalized with ``- batch_size + 1``,
    which is only correct when the batch holds a single pair
    (batch_size == 2). The second half of the batch starts at
    ``batch_size // 2``, so we subtract that instead — identical behavior
    for single pairs, correct for larger batches.
    """
    batch_size = P1P2["batch_atoms"][-1] + 1  # total protein count (2 per pair)
    p1_indices = P1P2["batch"] < batch_size // 2
    p2_indices = P1P2["batch"] >= batch_size // 2

    p1_atom_indices = P1P2["batch_atoms"] < batch_size // 2
    p2_atom_indices = P1P2["batch_atoms"] >= batch_size // 2

    P1 = {}
    P2 = {}
    for key in P1P2:
        v1v2 = P1P2[key]

        if (key == "rand_rot") or (key == "atom_center"):
            # Stored stacked along dim 0: the first half belongs to P1.
            n = v1v2.shape[0] // 2
            P1[key] = v1v2[:n].view(-1, 3)
            P2[key] = v1v2[n:].view(-1, 3)
        elif "atom" in key:
            # Atom-level tensors are split with the atom batch mask:
            P1[key] = v1v2[p1_atom_indices]
            P2[key] = v1v2[p2_atom_indices]
        elif key == "triangles":
            continue
            # P1[key] = v1v2[:,p1_atom_indices]
            # P2[key] = v1v2[:,p2_atom_indices]
        else:
            # Point-level tensors are split with the point batch mask:
            P1[key] = v1v2[p1_indices]
            P2[key] = v1v2[p2_indices]

    # Renormalize P2's batch indices so they start at zero:
    P2["batch"] = P2["batch"] - batch_size // 2
    P2["batch_atoms"] = P2["batch_atoms"] - batch_size // 2

    return P1, P2
267 |
268 |
269 |
def project_iface_labels(P, threshold=2.0):
    """Transfer the mesh interface labels onto the sampled point cloud.

    Each point of P["xyz"] takes the label of its nearest mesh vertex,
    zeroed out when that vertex is farther than `threshold`. Writes the
    result into P["labels"].
    """
    queries, batch_queries = P["xyz"], P["batch"]
    source, batch_source = P["mesh_xyz"], P["mesh_batch"]
    labels = P["mesh_labels"]

    x_i = LazyTensor(queries[:, None, :])  # (N, 1, D)
    y_j = LazyTensor(source[None, :, :])  # (1, M, D)

    D_ij = ((x_i - y_j) ** 2).sum(-1).sqrt()  # (N, M)
    D_ij.ranges = diagonal_ranges(batch_queries, batch_source)

    nearest = D_ij.argmin(dim=1).view(-1)  # (N,)
    # Points too far from every mesh vertex get label 0 — MaSIF sometimes
    # cut out parts of a chain because of missing densities.
    is_close = (D_ij.min(dim=1).view(-1, 1) < threshold).float()
    P["labels"] = labels[nearest] * is_close
288 |
class dMaSIF(nn.Module):
    """End-to-end dMaSIF model.

    Builds a point-cloud surface from the atoms, computes geometric
    (multi-scale curvatures) and chemical (AtomNet_MP) input features,
    embeds them with the selected convolutional backbone, and optionally
    predicts interface sites (``args.site``) or a second embedding for
    search (``args.search``).
    """

    def __init__(self, args):
        super(dMaSIF, self).__init__()
        # Additional geometric features: mean and Gauss curvatures computed at different scales.
        self.curvature_scales = args.curvature_scales
        self.args = args

        I = args.in_channels  # input feature channels (curvatures + chemistry)
        O = args.orientation_units
        E = args.emb_dims  # output embedding dimension
        H = args.post_units  # hidden units of the site-prediction head

        # Computes chemical features
        self.atomnet = AtomNet_MP(args)
        self.dropout = nn.Dropout(args.dropout)

        if args.embedding_layer == "dMaSIF":
            # Scalar scores used to orient the local tangent frames
            # (post-processing, without batch norm):
            self.orientation_scores = nn.Sequential(
                nn.Linear(I, O),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Linear(O, 1),
            )

            # Segmentation network:
            self.conv = dMaSIFConv_seg(
                args,
                in_channels=I,
                out_channels=E,
                n_layers=args.n_layers,
                radius=args.radius,
            )

            # Asymmetric embedding: a second, independent branch for search mode
            if args.search:
                self.orientation_scores2 = nn.Sequential(
                    nn.Linear(I, O),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.Linear(O, 1),
                )

                self.conv2 = dMaSIFConv_seg(
                    args,
                    in_channels=I,
                    out_channels=E,
                    n_layers=args.n_layers,
                    radius=args.radius,
                )

        elif args.embedding_layer == "DGCNN":
            # First baseline backbone; xyz coordinates (+3) are appended in embed():
            self.conv = DGCNN_seg(I + 3, E,self.args.n_layers,self.args.k)
            if args.search:
                self.conv2 = DGCNN_seg(I + 3, E,self.args.n_layers,self.args.k)

        elif args.embedding_layer == "PointNet++":
            # Second baseline backbone:
            self.conv = PointNet2_seg(args, I, E)
            if args.search:
                self.conv2 = PointNet2_seg(args, I, E)

        if args.site:
            # Interface-site prediction head (post-processing, without batch norm):
            self.net_out = nn.Sequential(
                nn.Linear(E, H),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Linear(H, H),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Linear(H, 1),
            )

    def features(self, P, i=1):
        """Estimates geometric and chemical features from a protein surface or a cloud of atoms."""
        if (
            not self.args.use_mesh and "xyz" not in P
        ):  # Compute the pseudo-surface directly from the atoms
            # (Note that we use the fact that dicts are "passed by reference" here)
            P["xyz"], P["normals"], P["batch"] = atoms_to_points_normals(
                P["atoms"],
                P["batch_atoms"],
                atomtypes=P["atomtypes"],
                resolution=self.args.resolution,
                sup_sampling=self.args.sup_sampling,
            )

        # Estimate the curvatures using the triangles or the estimated normals:
        P_curvatures = curvatures(
            P["xyz"],
            triangles=P["triangles"] if self.args.use_mesh else None,
            normals=None if self.args.use_mesh else P["normals"],
            scales=self.curvature_scales,
            batch=P["batch"],
        )

        # Compute chemical features on-the-fly:
        chemfeats = self.atomnet(
            P["xyz"], P["atom_xyz"], P["atomtypes"], P["batch"], P["batch_atoms"]
        )

        # Ablations: zero out one feature family while keeping shapes intact.
        if self.args.no_chem:
            chemfeats = 0.0 * chemfeats
        if self.args.no_geom:
            P_curvatures = 0.0 * P_curvatures

        # Concatenate our features:
        return torch.cat([P_curvatures, chemfeats], dim=1).contiguous()

    def embed(self, P):
        """Embeds all points of a protein in a high-dimensional vector space."""

        features = self.dropout(self.features(P))
        P["input_features"] = features

        # Time and profile the convolutional backbone only:
        torch.cuda.synchronize(device=features.device)
        # NOTE(review): reset_max_memory_allocated is deprecated in recent
        # PyTorch in favor of reset_peak_memory_stats — confirm the target
        # torch version before changing.
        torch.cuda.reset_max_memory_allocated(device=P["atoms"].device)
        begin = time.time()

        # Ours:
        if self.args.embedding_layer == "dMaSIF":
            self.conv.load_mesh(
                P["xyz"],
                triangles=P["triangles"] if self.args.use_mesh else None,
                normals=None if self.args.use_mesh else P["normals"],
                weights=self.orientation_scores(features),
                batch=P["batch"],
            )
            P["embedding_1"] = self.conv(features)
            if self.args.search:
                self.conv2.load_mesh(
                    P["xyz"],
                    triangles=P["triangles"] if self.args.use_mesh else None,
                    normals=None if self.args.use_mesh else P["normals"],
                    weights=self.orientation_scores2(features),
                    batch=P["batch"],
                )
                P["embedding_2"] = self.conv2(features)

        # First baseline:
        elif self.args.embedding_layer == "DGCNN":
            # DGCNN also consumes the raw coordinates as input features:
            features = torch.cat([features, P["xyz"]], dim=-1).contiguous()
            P["embedding_1"] = self.conv(P["xyz"], features, P["batch"])
            if self.args.search:
                P["embedding_2"] = self.conv2(
                    P["xyz"], features, P["batch"]
                )

        # Second baseline
        elif self.args.embedding_layer == "PointNet++":
            P["embedding_1"] = self.conv(P["xyz"], features, P["batch"])
            if self.args.search:
                P["embedding_2"] = self.conv2(P["xyz"], features, P["batch"])

        torch.cuda.synchronize(device=features.device)
        end = time.time()
        memory_usage = torch.cuda.max_memory_allocated(device=P["atoms"].device)
        conv_time = end - begin

        return conv_time, memory_usage

    def preprocess_surface(self, P):
        """Sample the surface point cloud + normals from the atoms, and
        project mesh interface labels onto it when available."""
        P["xyz"], P["normals"], P["batch"] = atoms_to_points_normals(
            P["atoms"],
            P["batch_atoms"],
            atomtypes=P["atomtypes"],
            resolution=self.args.resolution,
            sup_sampling=self.args.sup_sampling,
            distance=self.args.distance,
        )
        if P['mesh_labels'] is not None:
            project_iface_labels(P)

    def forward(self, P1, P2=None):
        """Embed one protein (P1) or a batched pair (P1, P2); returns the
        updated dictionaries plus rank estimates and timing/memory stats."""
        # Compute embeddings of the point clouds:
        if P2 is not None:
            P1P2 = combine_pair(P1, P2)
        else:
            P1P2 = P1

        conv_time, memory_usage = self.embed(P1P2)

        # Monitor the approximate rank of our representations:
        R_values = {}
        R_values["input"] = soft_dimension(P1P2["input_features"])
        R_values["conv"] = soft_dimension(P1P2["embedding_1"])

        if self.args.site:
            # Per-point interface logits:
            P1P2["iface_preds"] = self.net_out(P1P2["embedding_1"])

        if P2 is not None:
            P1, P2 = split_pair(P1P2)
        else:
            P1 = P1P2

        return {
            "P1": P1,
            "P2": P2,
            "R_values": R_values,
            "conv_time": conv_time,
            "memory_usage": memory_usage,
        }
487 |
--------------------------------------------------------------------------------
/models/dMaSIF_search_3layer_12A_16dim:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreyrS/dMaSIF/0dcc26c3c218a39d5fe26beb2e788b95fb028896/models/dMaSIF_search_3layer_12A_16dim
--------------------------------------------------------------------------------
/overview.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreyrS/dMaSIF/0dcc26c3c218a39d5fe26beb2e788b95fb028896/overview.PNG
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.11.0
2 | appdirs==1.4.4
3 | argon2-cffi==20.1.0
4 | ase==3.20.1
5 | async-generator==1.10
6 | attrs==20.3.0
7 | backcall==0.2.0
8 | biopython==1.78
9 | black==20.8b1
10 | bleach==3.2.1
11 | cached-property==1.5.2
12 | cachetools==4.1.1
13 | certifi==2020.6.20
14 | cffi==1.14.3
15 | chardet==3.0.4
16 | click==7.1.2
17 | cloudpickle==1.6.0
18 | cycler==0.10.0
19 | dask==2020.12.0
20 | decorator==4.4.2
21 | defusedxml==0.6.0
22 | entrypoints==0.3
23 | future==0.18.2
24 | google-auth==1.23.0
25 | google-auth-oauthlib==0.4.2
26 | googledrivedownloader==0.4
27 | GPUtil==1.4.0
28 | grpcio==1.33.2
29 | h5py==3.0.0
30 | idna==2.10
31 | importlib-metadata==2.0.0
32 | ipykernel==5.3.4
33 | ipython==7.19.0
34 | ipython-genutils==0.2.0
35 | ipywidgets==7.5.1
36 | isodate==0.6.0
37 | jedi==0.17.2
38 | Jinja2==2.11.2
39 | joblib==0.17.0
40 | jsonschema==3.2.0
41 | jupyter==1.0.0
42 | jupyter-client==6.1.7
43 | jupyter-console==6.2.0
44 | jupyter-core==4.6.3
45 | jupyterlab-pygments==0.1.2
46 | kaleido==0.0.3.post1
47 | kiwisolver==1.3.1
48 | llvmlite==0.34.0
49 | Markdown==3.3.3
50 | MarkupSafe==1.1.1
51 | matplotlib==3.3.2
52 | mistune==0.8.4
53 | mypy-extensions==0.4.3
54 | nbclient==0.5.1
55 | nbconvert==6.0.7
56 | nbformat==5.0.8
57 | nest-asyncio==1.4.3
58 | networkx==2.5
59 | notebook==6.1.5
60 | numba==0.51.2
61 | numpy==1.19.3
62 | oauthlib==3.1.0
63 | packaging==20.4
64 | pandas==1.1.4
65 | pandocfilters==1.4.3
66 | parso==0.7.1
67 | pathspec==0.8.1
68 | pexpect==4.8.0
69 | pickleshare==0.7.5
70 | Pillow==8.0.1
71 | plotly==4.13.0
72 | plyfile==0.7.2
73 | prometheus-client==0.8.0
74 | prompt-toolkit==3.0.8
75 | protobuf==3.13.0
76 | ptyprocess==0.6.0
77 | pyasn1==0.4.8
78 | pyasn1-modules==0.2.8
79 | pycparser==2.20
80 | Pygments==2.7.2
81 | pykeops==1.4.1
82 | pyparsing==2.4.7
83 | pyrsistent==0.17.3
84 | python-dateutil==2.8.1
85 | pytz==2020.4
86 | PyVTK==0.5.18
87 | PyYAML==5.3.1
88 | pyzmq==19.0.2
89 | qtconsole==4.7.7
90 | QtPy==1.9.0
91 | rdflib==5.0.0
92 | regex==2020.11.13
93 | requests==2.24.0
94 | requests-oauthlib==1.3.0
95 | retrying==1.3.3
96 | rope==0.18.0
97 | rsa==4.6
98 | scikit-learn==0.23.2
99 | scipy==1.5.3
100 | seaborn==0.11.0
101 | Send2Trash==1.5.0
102 | six==1.15.0
103 | tensorboard==2.3.0
104 | tensorboard-plugin-wit==1.7.0
105 | terminado==0.9.1
106 | testpath==0.4.4
107 | threadpoolctl==2.1.0
108 | toml==0.10.2
109 | toolz==0.11.1
110 | torch==1.6.0
111 | torch-cluster==1.5.8
112 | torch-geometric==1.6.1
113 | torch-scatter==2.0.5
114 | torch-sparse==0.6.8
115 | torch-spline-conv==1.2.0
116 | tornado==6.1
117 | tqdm==4.51.0
118 | traitlets==5.0.5
119 | typed-ast==1.4.1
120 | typing-extensions==3.7.4.3
121 | urllib3==1.25.11
122 | wcwidth==0.2.5
123 | webencodings==0.5.1
124 | Werkzeug==1.0.1
125 | widgetsnbextension==3.5.1
126 | zipp==3.4.0
127 |
--------------------------------------------------------------------------------