├── .gitignore ├── Arguments.py ├── LICENSE.md ├── README.md ├── benchmark_layers.py ├── benchmark_models.py ├── benchmark_scripts ├── DGCNN_site.sh ├── Pointnet_site.sh ├── dMaSIF_search.sh └── dMaSIF_site.sh ├── data.py ├── data_analysis ├── analyse_atomnet.ipynb ├── analyse_descriptors.py ├── analyse_descriptors_para.py ├── analyse_output.ipynb ├── analyse_site_outputs.py ├── analyse_site_outputs_graph.ipynb ├── plot_search.ipynb └── profiling_surface.ipynb ├── data_iteration.py ├── data_preprocessing ├── convert_pdb2npy.py ├── convert_ply2npy.py └── download_pdb.py ├── geometry_processing.py ├── helper.py ├── lists ├── testing.txt ├── testing_ppi.txt ├── training.txt └── training_ppi.txt ├── main_inference.py ├── main_training.py ├── model.py ├── models └── dMaSIF_search_3layer_12A_16dim ├── overview.PNG └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | 01-benchmark_surfaces/ 2 | 01-benchmark_surfaces_npy/ 3 | 01-benchmark_pdbs_npy/ 4 | 01-benchmark_pdbs/ 5 | 01-benchmark_pdbs/ 6 | shape_index/ 7 | masif_preds/ 8 | runs/ 9 | venv/ 10 | preds/ 11 | *.log 12 | NeurIPS_2020_benchmarks/ 13 | *.out 14 | figures/ 15 | timings/ 16 | data_analysis/roc_curves 17 | data_analysis/.ipynb_checkpoints/ 18 | .ipynb_checkpoints/ -------------------------------------------------------------------------------- /Arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description="Network parameters") 4 | 5 | # Main parameters 6 | parser.add_argument( 7 | "--experiment_name", type=str, help="Name of experiment", required=True 8 | ) 9 | parser.add_argument( 10 | "--use_mesh", type=bool, default=False, help="Use precomputed surfaces" 11 | ) 12 | parser.add_argument( 13 | "--embedding_layer", 14 | type=str, 15 | default="dMaSIF", 16 | choices=["dMaSIF", "DGCNN", "PointNet++"], 17 | help="Which convolutional embedding 
layer to use", 18 | ) 19 | parser.add_argument("--profile", type=bool, default=False, help="Profile code") 20 | 21 | # Geometric parameters 22 | parser.add_argument( 23 | "--curvature_scales", 24 | type=list, 25 | default=[1.0, 2.0, 3.0, 5.0, 10.0], 26 | help="Scales at which we compute the geometric features (mean and Gauss curvatures)", 27 | ) 28 | parser.add_argument( 29 | "--resolution", 30 | type=float, 31 | default=1.0, 32 | help="Resolution of the generated point cloud", 33 | ) 34 | parser.add_argument( 35 | "--distance", 36 | type=float, 37 | default=1.05, 38 | help="Distance parameter in surface generation", 39 | ) 40 | parser.add_argument( 41 | "--variance", 42 | type=float, 43 | default=0.1, 44 | help="Variance parameter in surface generation", 45 | ) 46 | parser.add_argument( 47 | "--sup_sampling", type=int, default=20, help="Sup-sampling ratio around atoms" 48 | ) 49 | 50 | # Hyper-parameters for the embedding 51 | parser.add_argument( 52 | "--atom_dims", 53 | type=int, 54 | default=6, 55 | help="Number of atom types and dimension of resulting chemical features", 56 | ) 57 | parser.add_argument( 58 | "--emb_dims", 59 | type=int, 60 | default=8, 61 | help="Number of input features (+ 3 xyz coordinates for DGCNNs)", 62 | ) 63 | parser.add_argument( 64 | "--in_channels", 65 | type=int, 66 | default=16, 67 | help="Number of embedding dimensions", 68 | ) 69 | parser.add_argument( 70 | "--orientation_units", 71 | type=int, 72 | default=16, 73 | help="Number of hidden units for the orientation score MLP", 74 | ) 75 | parser.add_argument( 76 | "--unet_hidden_channels", 77 | type=int, 78 | default=8, 79 | help="Number of hidden units for TangentConv UNet", 80 | ) 81 | parser.add_argument( 82 | "--post_units", 83 | type=int, 84 | default=8, 85 | help="Number of hidden units for the post-processing MLP", 86 | ) 87 | parser.add_argument( 88 | "--n_layers", type=int, default=1, help="Number of convolutional layers" 89 | ) 90 | parser.add_argument( 91 | "--radius", 
type=float, default=9.0, help="Radius to use for the convolution" 92 | ) 93 | parser.add_argument( 94 | "--k", 95 | type=int, 96 | default=40, 97 | help="Number of nearset neighbours for DGCNN and PointNet++", 98 | ) 99 | parser.add_argument( 100 | "--dropout", 101 | type=float, 102 | default=0.0, 103 | help="Amount of Dropout for the input features", 104 | ) 105 | 106 | # Training 107 | parser.add_argument( 108 | "--n_epochs", type=int, default=50, help="Number of training epochs" 109 | ) 110 | parser.add_argument( 111 | "--batch_size", type=int, default=1, help="Number of proteins in a batch" 112 | ) 113 | parser.add_argument( 114 | "--device", type=str, default="cuda:0", help="Which gpu/cpu to train on" 115 | ) 116 | parser.add_argument( 117 | "--restart_training", 118 | type=str, 119 | default="", 120 | help="Which model to restart the training from", 121 | ) 122 | parser.add_argument( 123 | "--n_rocauc_samples", 124 | type=int, 125 | default=100, 126 | help="Number of samples for the Matching ROC-AUC", 127 | ) 128 | parser.add_argument( 129 | "--validation_fraction", 130 | type=float, 131 | default=0.1, 132 | help="Fraction of training dataset to use for validation", 133 | ) 134 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 135 | parser.add_argument( 136 | "--random_rotation", 137 | type=bool, 138 | default=False, 139 | help="Move proteins to center and add random rotation", 140 | ) 141 | parser.add_argument( 142 | "--single_protein", 143 | type=bool, 144 | default=False, 145 | help="Use single protein in a pair or both", 146 | ) 147 | parser.add_argument("--site", type=bool, default=False, help="Predict interaction site") 148 | parser.add_argument( 149 | "--search", 150 | type=bool, 151 | default=False, 152 | help="Predict matching between two partners", 153 | ) 154 | parser.add_argument( 155 | "--no_chem", type=bool, default=False, help="Predict without chemical information" 156 | ) 157 | parser.add_argument( 158 | "--no_geom", 
type=bool, default=False, help="Predict without curvature information" 159 | ) 160 | parser.add_argument( 161 | "--single_pdb", 162 | type=str, 163 | default="", 164 | help="Which structure to do inference on", 165 | ) 166 | parser.add_argument( 167 | "--pdb_list", 168 | type=str, 169 | default="", 170 | help="Which structures to do inference on", 171 | ) 172 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-NoDerivatives 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 
23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 58 | International Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-NoDerivatives 4.0 International Public 63 | License ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Copyright and Similar Rights means copyright and/or similar rights 84 | closely related to copyright including, without limitation, 85 | performance, broadcast, sound recording, and Sui Generis Database 86 | Rights, without regard to how the rights are labeled or 87 | categorized. For purposes of this Public License, the rights 88 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 89 | Rights. 90 | 91 | c. 
Effective Technological Measures means those measures that, in the 92 | absence of proper authority, may not be circumvented under laws 93 | fulfilling obligations under Article 11 of the WIPO Copyright 94 | Treaty adopted on December 20, 1996, and/or similar international 95 | agreements. 96 | 97 | d. Exceptions and Limitations means fair use, fair dealing, and/or 98 | any other exception or limitation to Copyright and Similar Rights 99 | that applies to Your use of the Licensed Material. 100 | 101 | e. Licensed Material means the artistic or literary work, database, 102 | or other material to which the Licensor applied this Public 103 | License. 104 | 105 | f. Licensed Rights means the rights granted to You subject to the 106 | terms and conditions of this Public License, which are limited to 107 | all Copyright and Similar Rights that apply to Your use of the 108 | Licensed Material and that the Licensor has authority to license. 109 | 110 | g. Licensor means the individual(s) or entity(ies) granting rights 111 | under this Public License. 112 | 113 | h. NonCommercial means not primarily intended for or directed towards 114 | commercial advantage or monetary compensation. For purposes of 115 | this Public License, the exchange of the Licensed Material for 116 | other material subject to Copyright and Similar Rights by digital 117 | file-sharing or similar means is NonCommercial provided there is 118 | no payment of monetary compensation in connection with the 119 | exchange. 120 | 121 | i. Share means to provide material to the public by any means or 122 | process that requires permission under the Licensed Rights, such 123 | as reproduction, public display, public performance, distribution, 124 | dissemination, communication, or importation, and to make material 125 | available to the public including in ways that members of the 126 | public may access the material from a place and at a time 127 | individually chosen by them. 128 | 129 | j. 
Sui Generis Database Rights means rights other than copyright 130 | resulting from Directive 96/9/EC of the European Parliament and of 131 | the Council of 11 March 1996 on the legal protection of databases, 132 | as amended and/or succeeded, as well as other essentially 133 | equivalent rights anywhere in the world. 134 | 135 | k. You means the individual or entity exercising the Licensed Rights 136 | under this Public License. Your has a corresponding meaning. 137 | 138 | 139 | Section 2 -- Scope. 140 | 141 | a. License grant. 142 | 143 | 1. Subject to the terms and conditions of this Public License, 144 | the Licensor hereby grants You a worldwide, royalty-free, 145 | non-sublicensable, non-exclusive, irrevocable license to 146 | exercise the Licensed Rights in the Licensed Material to: 147 | 148 | a. reproduce and Share the Licensed Material, in whole or 149 | in part, for NonCommercial purposes only; and 150 | 151 | b. produce and reproduce, but not Share, Adapted Material 152 | for NonCommercial purposes only. 153 | 154 | 2. Exceptions and Limitations. For the avoidance of doubt, where 155 | Exceptions and Limitations apply to Your use, this Public 156 | License does not apply, and You do not need to comply with 157 | its terms and conditions. 158 | 159 | 3. Term. The term of this Public License is specified in Section 160 | 6(a). 161 | 162 | 4. Media and formats; technical modifications allowed. The 163 | Licensor authorizes You to exercise the Licensed Rights in 164 | all media and formats whether now known or hereafter created, 165 | and to make technical modifications necessary to do so. The 166 | Licensor waives and/or agrees not to assert any right or 167 | authority to forbid You from making technical modifications 168 | necessary to exercise the Licensed Rights, including 169 | technical modifications necessary to circumvent Effective 170 | Technological Measures. 
For purposes of this Public License, 171 | simply making modifications authorized by this Section 2(a) 172 | (4) never produces Adapted Material. 173 | 174 | 5. Downstream recipients. 175 | 176 | a. Offer from the Licensor -- Licensed Material. Every 177 | recipient of the Licensed Material automatically 178 | receives an offer from the Licensor to exercise the 179 | Licensed Rights under the terms and conditions of this 180 | Public License. 181 | 182 | b. No downstream restrictions. You may not offer or impose 183 | any additional or different terms or conditions on, or 184 | apply any Effective Technological Measures to, the 185 | Licensed Material if doing so restricts exercise of the 186 | Licensed Rights by any recipient of the Licensed 187 | Material. 188 | 189 | 6. No endorsement. Nothing in this Public License constitutes or 190 | may be construed as permission to assert or imply that You 191 | are, or that Your use of the Licensed Material is, connected 192 | with, or sponsored, endorsed, or granted official status by, 193 | the Licensor or others designated to receive attribution as 194 | provided in Section 3(a)(1)(A)(i). 195 | 196 | b. Other rights. 197 | 198 | 1. Moral rights, such as the right of integrity, are not 199 | licensed under this Public License, nor are publicity, 200 | privacy, and/or other similar personality rights; however, to 201 | the extent possible, the Licensor waives and/or agrees not to 202 | assert any such rights held by the Licensor to the limited 203 | extent necessary to allow You to exercise the Licensed 204 | Rights, but not otherwise. 205 | 206 | 2. Patent and trademark rights are not licensed under this 207 | Public License. 208 | 209 | 3. To the extent possible, the Licensor waives any right to 210 | collect royalties from You for the exercise of the Licensed 211 | Rights, whether directly or through a collecting society 212 | under any voluntary or waivable statutory or compulsory 213 | licensing scheme. 
In all other cases the Licensor expressly 214 | reserves any right to collect such royalties, including when 215 | the Licensed Material is used other than for NonCommercial 216 | purposes. 217 | 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material, You must: 227 | 228 | a. retain the following if it is supplied by the Licensor 229 | with the Licensed Material: 230 | 231 | i. identification of the creator(s) of the Licensed 232 | Material and any others designated to receive 233 | attribution, in any reasonable manner requested by 234 | the Licensor (including by pseudonym if 235 | designated); 236 | 237 | ii. a copyright notice; 238 | 239 | iii. a notice that refers to this Public License; 240 | 241 | iv. a notice that refers to the disclaimer of 242 | warranties; 243 | 244 | v. a URI or hyperlink to the Licensed Material to the 245 | extent reasonably practicable; 246 | 247 | b. indicate if You modified the Licensed Material and 248 | retain an indication of any previous modifications; and 249 | 250 | c. indicate the Licensed Material is licensed under this 251 | Public License, and include the text of, or the URI or 252 | hyperlink to, this Public License. 253 | 254 | For the avoidance of doubt, You do not have permission under 255 | this Public License to Share Adapted Material. 256 | 257 | 2. You may satisfy the conditions in Section 3(a)(1) in any 258 | reasonable manner based on the medium, means, and context in 259 | which You Share the Licensed Material. For example, it may be 260 | reasonable to satisfy the conditions by providing a URI or 261 | hyperlink to a resource that includes the required 262 | information. 263 | 264 | 3. If requested by the Licensor, You must remove any of the 265 | information required by Section 3(a)(1)(A) to the extent 266 | reasonably practicable. 
267 | 268 | 269 | Section 4 -- Sui Generis Database Rights. 270 | 271 | Where the Licensed Rights include Sui Generis Database Rights that 272 | apply to Your use of the Licensed Material: 273 | 274 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 275 | to extract, reuse, reproduce, and Share all or a substantial 276 | portion of the contents of the database for NonCommercial purposes 277 | only and provided You do not Share Adapted Material; 278 | 279 | b. if You include all or a substantial portion of the database 280 | contents in a database in which You have Sui Generis Database 281 | Rights, then the database in which You have Sui Generis Database 282 | Rights (but not its individual contents) is Adapted Material; and 283 | 284 | c. You must comply with the conditions in Section 3(a) if You Share 285 | all or a substantial portion of the contents of the database. 286 | 287 | For the avoidance of doubt, this Section 4 supplements and does not 288 | replace Your obligations under this Public License where the Licensed 289 | Rights include other Copyright and Similar Rights. 290 | 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | 350 | Section 7 -- Other Terms and Conditions. 351 | 352 | a. The Licensor shall not be bound by any additional or different 353 | terms or conditions communicated by You unless expressly agreed. 354 | 355 | b. Any arrangements, understandings, or agreements regarding the 356 | Licensed Material not stated herein are separate from and 357 | independent of the terms and conditions of this Public License. 358 | 359 | 360 | Section 8 -- Interpretation. 361 | 362 | a. For the avoidance of doubt, this Public License does not, and 363 | shall not be interpreted to, reduce, limit, restrict, or impose 364 | conditions on any use of the Licensed Material that could lawfully 365 | be made without permission under this Public License. 366 | 367 | b. To the extent possible, if any provision of this Public License is 368 | deemed unenforceable, it shall be automatically reformed to the 369 | minimum extent necessary to make it enforceable. If the provision 370 | cannot be reformed, it shall be severed from this Public License 371 | without affecting the enforceability of the remaining terms and 372 | conditions. 373 | 374 | c. No term or condition of this Public License will be waived and no 375 | failure to comply consented to unless expressly agreed to by the 376 | Licensor. 377 | 378 | d. Nothing in this Public License constitutes or may be interpreted 379 | as a limitation upon, or waiver of, any privileges and immunities 380 | that apply to the Licensor or You, including from the legal 381 | processes of any jurisdiction or authority. 382 | 383 | ======================================================================= 384 | 385 | Creative Commons is not a party to its public 386 | licenses. 
Notwithstanding, Creative Commons may elect to apply one of 387 | its public licenses to material it publishes and in those instances 388 | will be considered the “Licensor.” The text of the Creative Commons 389 | public licenses is dedicated to the public domain under the CC0 Public 390 | Domain Dedication. Except for the limited purpose of indicating that 391 | material is shared under a Creative Commons public license or as 392 | otherwise permitted by the Creative Commons policies published at 393 | creativecommons.org/policies, Creative Commons does not authorize the 394 | use of the trademark "Creative Commons" or any other trademark or logo 395 | of Creative Commons without its prior written consent including, 396 | without limitation, in connection with any unauthorized modifications 397 | to any of its public licenses or any other arrangements, 398 | understandings, or agreements concerning use of licensed material. For 399 | the avoidance of doubt, this paragraph does not form part of the 400 | public licenses. 401 | 402 | Creative Commons may be contacted at creativecommons.org. 403 | 404 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## dMaSIF - Fast end-to-end learning on protein surfaces 2 | ![Method overview](overview.PNG) 3 | 4 | ## Abstract 5 | 6 | Proteins’ biological functions are defined by the geometric 7 | and chemical structure of their 3D molecular surfaces. 8 | Recent works have shown that geometric deep learning can 9 | be used on mesh-based representations of proteins to identify 10 | potential functional sites, such as binding targets for 11 | potential drugs. Unfortunately though, the use of meshes as 12 | the underlying representation for protein structure has multiple 13 | drawbacks including the need to pre-compute the input 14 | features and mesh connectivities. 
This becomes a bottleneck 15 | for many important tasks in protein science. 16 | 17 | In this paper, we present a new framework for deep 18 | learning on protein structures that addresses these limitations. 19 | Among the key advantages of our method are the computation 20 | and sampling of the molecular surface on-the-fly 21 | from the underlying atomic point cloud and a novel efficient 22 | geometric convolutional layer. As a result, we are able to 23 | process large collections of proteins in an end-to-end fashion, 24 | taking as the sole input the raw 3D coordinates and 25 | chemical types of their atoms, eliminating the need for any 26 | hand-crafted pre-computed features. 27 | 28 | To showcase the performance of our approach, we test it 29 | on two tasks in the field of protein structural bioinformatics: 30 | the identification of interaction sites and the prediction 31 | of protein-protein interactions. On both tasks, we achieve 32 | state-of-the-art performance with much faster run times and 33 | fewer parameters than previous models. These results will 34 | considerably ease the deployment of deep learning methods 35 | in protein science and open the door for end-to-end differentiable 36 | approaches in protein modeling tasks such as function 37 | prediction and design. 38 | 39 | ## Hardware requirements 40 | 41 | Models have been trained on either a single NVIDIA RTX 2080 Ti or a single Tesla V100 GPU. Time and memory benchmarks were performed on a single Tesla V100. 
42 | 43 | ## Software prerequisites 44 | 45 | Scripts have been tested using the following two sets of core dependencies: 46 | 47 | | Dependency | First Option | Second Option | 48 | | ------------- | ------------- | ------------- | 49 | | GCC | 7.5.0 | 8.4.0 | 50 | | CMAKE | 3.10.2 | 3.16.5 | 51 | | CUDA | 10.0.130 | 10.2.89 | 52 | | cuDNN | 7.6.4.38 | 7.6.5.32 | 53 | | Python | 3.6.9 | 3.7.7 | 54 | | PyTorch | 1.4.0 | 1.6.0 | 55 | | PyKeops | 1.4 | 1.4.1 | 56 | | PyTorch Geometric | 1.5.0 | 1.6.1 | 57 | 58 | 59 | ## Code overview 60 | 61 | 62 | Usage: 63 | - In order to **train models**, run `main_training.py` with the appropriate flags. 64 | Available flags and their descriptions can be found in `Arguments.py`. 65 | 66 | - The command line options needed to reproduce the **benchmarks** can be found in `benchmark_scripts/`. 67 | 68 | - To make **inference** on the testing set using pretrained models, use `main_inference.py` with the flags that were used for training the models. 69 | Note that the `--experiment_name flag` should be modified to specify the training epoch to use. 70 | 71 | Implementation: 72 | - Our **surface generation** algorithm, **curvature** estimation method and **quasi-geodesic convolutions** are implemented in `geometry_processing.py`. 73 | 74 | - The **definition of the neural network** along with surface and input features can be found in `model.py`. The convolutional layers are implemented in `benchmark_models.py`. 75 | 76 | - The scripts used to **generate the figures** of the paper can be found in `data_analysis/`. 77 | 78 | 79 | ## License 80 | 81 | Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License. 82 | 83 | ## Reference 84 | 85 | Sverrisson, F., Feydy, J., Correia, B. E., & Bronstein, M. M. (2020). Fast end-to-end learning on protein surfaces. [bioRxiv](https://www.biorxiv.org/content/10.1101/2020.12.28.424589v1). -------------------------------------------------------------------------------- /benchmark_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | from pykeops.torch import LazyTensor 4 | from torch_geometric.nn import EdgeConv, Reshape 5 | 6 | from torch_cluster import knn 7 | 8 | from math import ceil 9 | from torch_geometric.nn.inits import reset 10 | 11 | from torch.nn import ELU, Conv1d 12 | from torch.nn import Sequential as S, Linear as L, BatchNorm1d as BN 13 | 14 | 15 | def ranges_slices(batch): 16 | """Helper function for the diagonal ranges function.""" 17 | Ns = batch.bincount() 18 | indices = Ns.cumsum(0) 19 | ranges = torch.cat((0 * indices[:1], indices)) 20 | ranges = ( 21 | torch.stack((ranges[:-1], ranges[1:])).t().int().contiguous().to(batch.device) 22 | ) 23 | slices = (1 + torch.arange(len(Ns))).int().to(batch.device) 24 | 25 | return ranges, slices 26 | 27 | 28 | def diagonal_ranges(batch_x=None, batch_y=None): 29 | """Encodes the block-diagonal structure associated to a batch vector.""" 30 | 31 | if batch_x is None and batch_y is None: 32 | return None 33 | 34 | ranges_x, slices_x = ranges_slices(batch_x) 35 | ranges_y, slices_y = ranges_slices(batch_y) 36 | 37 | return ranges_x, slices_x, ranges_y, ranges_y, slices_y, ranges_x 38 | 39 | 40 | @torch.jit.ignore 41 | def keops_knn( 42 | x: torch.Tensor, 43 | y: torch.Tensor, 44 | k: int, 45 | batch_x: Optional[torch.Tensor] = None, 46 | batch_y: Optional[torch.Tensor] = None, 47 | cosine: bool = False, 48 | ) -> torch.Tensor: 49 | r"""Straightforward modification of 
PyTorch_geometric's knn method.""" 50 | 51 | x = x.view(-1, 1) if x.dim() == 1 else x 52 | y = y.view(-1, 1) if y.dim() == 1 else y 53 | 54 | y_i = LazyTensor(y[:, None, :]) 55 | x_j = LazyTensor(x[None, :, :]) 56 | 57 | if cosine: 58 | D_ij = -(y_i | x_j) 59 | else: 60 | D_ij = ((y_i - x_j) ** 2).sum(-1) 61 | 62 | D_ij.ranges = diagonal_ranges(batch_y, batch_x) 63 | idy = D_ij.argKmin(k, dim=1) # (N, K) 64 | 65 | rows = torch.arange(k * len(y), device=idy.device) // k 66 | 67 | return torch.stack([rows, idy.view(-1)], dim=0) 68 | 69 | 70 | knns = {"torch": knn, "keops": keops_knn} 71 | 72 | 73 | @torch.jit.ignore 74 | def knn_graph( 75 | x: torch.Tensor, 76 | k: int, 77 | batch: Optional[torch.Tensor] = None, 78 | loop: bool = False, 79 | flow: str = "source_to_target", 80 | cosine: bool = False, 81 | target: Optional[torch.Tensor] = None, 82 | batch_target: Optional[torch.Tensor] = None, 83 | backend: str = "torch", 84 | ) -> torch.Tensor: 85 | r"""Straightforward modification of PyTorch_geometric's knn_graph method to allow for source/targets.""" 86 | 87 | assert flow in ["source_to_target", "target_to_source"] 88 | if target is None: 89 | target = x 90 | if batch_target is None: 91 | batch_target = batch 92 | 93 | row, col = knns[backend]( 94 | x, target, k if loop else k + 1, batch, batch_target, cosine=cosine 95 | ) 96 | row, col = (col, row) if flow == "source_to_target" else (row, col) 97 | if not loop: 98 | mask = row != col 99 | row, col = row[mask], col[mask] 100 | return torch.stack([row, col], dim=0) 101 | 102 | 103 | class MyDynamicEdgeConv(EdgeConv): 104 | r"""Straightforward modification of PyTorch_geometric's DynamicEdgeConv layer.""" 105 | 106 | def __init__(self, nn, k, aggr="max", **kwargs): 107 | super(MyDynamicEdgeConv, self).__init__(nn=nn, aggr=aggr, **kwargs) 108 | self.k = k 109 | 110 | def forward(self, x, batch=None): 111 | """""" 112 | edge_index = knn_graph( 113 | x, self.k, batch, loop=False, flow=self.flow, backend="keops" 114 | ) 
115 | return super(MyDynamicEdgeConv, self).forward(x, edge_index) 116 | 117 | def __repr__(self): 118 | return "{}(nn={}, k={})".format(self.__class__.__name__, self.nn, self.k) 119 | 120 | 121 | class MyXConv(torch.nn.Module): 122 | def __init__( 123 | self, 124 | in_channels=None, 125 | out_channels=None, 126 | dim=None, 127 | kernel_size=None, 128 | hidden_channels=None, 129 | dilation=1, 130 | bias=True, 131 | backend="torch", 132 | ): 133 | super(MyXConv, self).__init__() 134 | 135 | self.in_channels = in_channels 136 | if hidden_channels is None: 137 | hidden_channels = in_channels // 4 138 | if hidden_channels == 0: 139 | hidden_channels = 1 140 | 141 | self.hidden_channels = hidden_channels 142 | self.out_channels = out_channels 143 | self.dim = dim 144 | self.kernel_size = kernel_size 145 | self.dilation = dilation 146 | self.backend = backend 147 | 148 | C_in, C_delta, C_out = in_channels, hidden_channels, out_channels 149 | D, K = dim, kernel_size 150 | 151 | self.mlp1 = S( 152 | L(dim, C_delta), 153 | ELU(), 154 | BN(C_delta), 155 | L(C_delta, C_delta), 156 | ELU(), 157 | BN(C_delta), 158 | Reshape(-1, K, C_delta), 159 | ) 160 | 161 | self.mlp2 = S( 162 | L(D * K, K ** 2), 163 | ELU(), 164 | BN(K ** 2), 165 | Reshape(-1, K, K), 166 | Conv1d(K, K ** 2, K, groups=K), 167 | ELU(), 168 | BN(K ** 2), 169 | Reshape(-1, K, K), 170 | Conv1d(K, K ** 2, K, groups=K), 171 | BN(K ** 2), 172 | Reshape(-1, K, K), 173 | ) 174 | 175 | C_in = C_in + C_delta 176 | depth_multiplier = int(ceil(C_out / C_in)) 177 | self.conv = S( 178 | Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in), 179 | Reshape(-1, C_in * depth_multiplier), 180 | L(C_in * depth_multiplier, C_out, bias=bias), 181 | ) 182 | 183 | self.reset_parameters() 184 | 185 | def reset_parameters(self): 186 | reset(self.mlp1) 187 | reset(self.mlp2) 188 | reset(self.conv) 189 | 190 | def forward(self, x, source, batch_source, target, batch_target): 191 | """""" 192 | # Load data shapes: 193 | # pos = 
pos.unsqueeze(-1) if pos.dim() == 1 else pos 194 | (Nin, _), (N, D), K = source.size(), target.size(), self.kernel_size 195 | 196 | # Compute K-nn: 197 | row, col = knn_graph( 198 | source, 199 | K * self.dilation, 200 | batch_source, 201 | loop=True, 202 | flow="target_to_source", 203 | target=target, 204 | batch_target=batch_target, 205 | backend=self.backend, 206 | ) 207 | # row is a vector of size N*K*dilation that indexes "target" 208 | # col is a vector of size N*K*dilation that indexes "source" 209 | 210 | # If needed, sup-sample the K-NN graph: 211 | if self.dilation > 1: 212 | dil = self.dilation 213 | index = torch.randint( 214 | K * dil, 215 | (N, K), 216 | dtype=torch.long, 217 | layout=torch.strided, 218 | device=row.device, 219 | ) 220 | arange = torch.arange(N, dtype=torch.long, device=row.device) 221 | arange = arange * (K * dil) 222 | index = (index + arange.view(-1, 1)).view(-1) # (N*K,) 223 | row, col = row[index], col[index] 224 | 225 | # assert row.max() < N 226 | # assert col.max() < Nin 227 | 228 | # Line 1: local difference vector: 229 | pos = source[col] - target[row] # (N * K, D) 230 | 231 | # Line 2: compute F_delta 232 | x_star = self.mlp1(pos.view(N * K, D)) 233 | 234 | # Line 3: concatenate the features and reshape: 235 | if x is not None: 236 | x = x.unsqueeze(-1) if x.dim() == 1 else x 237 | x = x[col].view(N, K, self.in_channels) 238 | x_star = torch.cat([x_star, x], dim=-1) 239 | x_star = x_star.transpose(1, 2).contiguous() 240 | x_star = x_star.view(N, self.in_channels + self.hidden_channels, K, 1) 241 | 242 | # Line 4: Compute the transformation matrix: 243 | transform_matrix = self.mlp2(pos.view(N, K * D)) 244 | transform_matrix = transform_matrix.view(N, 1, K, K) 245 | 246 | # Line 5: Apply it to the neighborhood: 247 | x_transformed = torch.matmul(transform_matrix, x_star) 248 | x_transformed = x_transformed.view(N, -1, K) # (N, I+H, K) 249 | 250 | # Line 6: Apply the convolution filter: 251 | out = self.conv(x_transformed) # 
(N, Cout) 252 | 253 | return out 254 | 255 | def __repr__(self): 256 | return "{}({}, {})".format( 257 | self.__class__.__name__, self.in_channels, self.out_channels 258 | ) 259 | -------------------------------------------------------------------------------- /benchmark_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import ( 5 | Sequential as Seq, 6 | Dropout, 7 | Linear as Lin, 8 | LeakyReLU, 9 | ReLU, 10 | BatchNorm1d as BN, 11 | ) 12 | import torch_geometric.transforms as T 13 | from torch_geometric.data import DataLoader 14 | from torch_geometric.nn import ( 15 | DynamicEdgeConv, 16 | PointConv, 17 | XConv, 18 | fps, 19 | radius, 20 | global_max_pool, 21 | knn_interpolate, 22 | ) 23 | from pykeops.torch import LazyTensor 24 | 25 | from benchmark_layers import MyDynamicEdgeConv, MyXConv 26 | from geometry_processing import dMaSIFConv, mesh_normals_areas, tangent_vectors 27 | from helper import diagonal_ranges 28 | 29 | DEConv = {"torch": DynamicEdgeConv, "keops": MyDynamicEdgeConv} 30 | 31 | # Dynamic Graph CNNs =========================================================== 32 | # Adapted from the PyTorch_geometric gallery to get a close fit to 33 | # the original paper. 
34 | 35 | 36 | def MLP(channels, batch_norm=True): 37 | """Multi-layer perceptron, with ReLU non-linearities and batch normalization.""" 38 | return Seq( 39 | *[ 40 | Seq( 41 | Lin(channels[i - 1], channels[i]), 42 | BN(channels[i]) if batch_norm else nn.Identity(), 43 | LeakyReLU(negative_slope=0.2), 44 | ) 45 | for i in range(1, len(channels)) 46 | ] 47 | ) 48 | 49 | 50 | class DGCNN_seg(torch.nn.Module): 51 | def __init__( 52 | self, in_channels, out_channels, n_layers, k=40, aggr="max", backend="keops" 53 | ): 54 | super(DGCNN_seg, self).__init__() 55 | 56 | self.name = "DGCNN_seg_" + backend 57 | self.I, self.O = ( 58 | in_channels + 3, 59 | out_channels, 60 | ) # Add coordinates to input channels 61 | self.n_layers = n_layers 62 | 63 | self.transform_1 = DEConv[backend](MLP([2 * 3, 64, 128]), k, aggr) 64 | self.transform_2 = MLP([128, 1024]) 65 | self.transform_3 = MLP([1024, 512, 256], batch_norm=False) 66 | self.transform_4 = Lin(256, 3 * 3) 67 | 68 | self.conv_layers = nn.ModuleList( 69 | [DEConv[backend](MLP([2 * self.I, self.O, self.O]), k, aggr)] 70 | + [ 71 | DEConv[backend](MLP([2 * self.O, self.O, self.O]), k, aggr) 72 | for i in range(n_layers - 1) 73 | ] 74 | ) 75 | 76 | self.linear_layers = nn.ModuleList( 77 | [ 78 | nn.Sequential( 79 | nn.Linear(self.O, self.O), nn.ReLU(), nn.Linear(self.O, self.O) 80 | ) 81 | for i in range(n_layers) 82 | ] 83 | ) 84 | 85 | self.linear_transform = nn.ModuleList( 86 | [nn.Linear(self.I, self.O)] 87 | + [nn.Linear(self.O, self.O) for i in range(n_layers - 1)] 88 | ) 89 | 90 | def forward(self, positions, features, batch_indices): 91 | # Lab: (B,), Pos: (N, 3), Batch: (N,) 92 | pos, feat, batch = positions, features, batch_indices 93 | 94 | # TransformNet: 95 | x = pos # Don't use the normals! 
96 | 97 | x = self.transform_1(x, batch) # (N, 3) -> (N, 128) 98 | x = self.transform_2(x) # (N, 128) -> (N, 1024) 99 | x = global_max_pool(x, batch) # (B, 1024) 100 | 101 | x = self.transform_3(x) # (B, 256) 102 | x = self.transform_4(x) # (B, 3*3) 103 | x = x[batch] # (N, 3*3) 104 | x = x.view(-1, 3, 3) # (N, 3, 3) 105 | 106 | # Apply the transform: 107 | x0 = torch.einsum("ni,nij->nj", pos, x) # (N, 3) 108 | 109 | # Add features to coordinates 110 | x = torch.cat([x0, feat], dim=-1).contiguous() 111 | 112 | for i in range(self.n_layers): 113 | x_i = self.conv_layers[i](x, batch) 114 | x_i = self.linear_layers[i](x_i) 115 | x = self.linear_transform[i](x) 116 | x = x + x_i 117 | 118 | return x 119 | 120 | 121 | # Reference PointNet models, from the PyTorch_geometric gallery ========================= 122 | 123 | 124 | class SAModule(torch.nn.Module): 125 | """Set abstraction module.""" 126 | 127 | def __init__(self, ratio, r, nn, max_num_neighbors=64): 128 | super(SAModule, self).__init__() 129 | self.ratio = ratio 130 | self.r = r 131 | self.conv = PointConv(nn) 132 | self.max_num_neighbors = max_num_neighbors 133 | 134 | def forward(self, x, pos, batch): 135 | # Subsample with Farthest Point Sampling: 136 | # idx = fps(pos, batch, ratio=self.ratio) # Extract self.ratio indices TURN OFF FOR NOW 137 | idx = torch.arange(0, len(pos), device=pos.device) 138 | 139 | # For each "cluster", get the list of (up to 64) neighbors in a ball of radius r: 140 | row, col = radius( 141 | pos, 142 | pos[idx], 143 | self.r, 144 | batch, 145 | batch[idx], 146 | max_num_neighbors=self.max_num_neighbors, 147 | ) 148 | 149 | # Applies the PointNet++ Conv: 150 | edge_index = torch.stack([col, row], dim=0) 151 | x = self.conv(x, (pos, pos[idx]), edge_index) 152 | 153 | # Return the features and sub-sampled point clouds: 154 | pos, batch = pos[idx], batch[idx] 155 | return x, pos, batch 156 | 157 | 158 | class GlobalSAModule(torch.nn.Module): 159 | def __init__(self, nn): 160 | 
super(GlobalSAModule, self).__init__() 161 | self.nn = nn 162 | 163 | def forward(self, x, pos, batch): 164 | x = self.nn(torch.cat([x, pos], dim=1)) 165 | x = global_max_pool(x, batch) 166 | pos = pos.new_zeros((x.size(0), 3)) 167 | batch = torch.arange(x.size(0), device=batch.device) 168 | return x, pos, batch 169 | 170 | 171 | class FPModule(torch.nn.Module): 172 | def __init__(self, k, nn): 173 | super(FPModule, self).__init__() 174 | self.k = k 175 | self.nn = nn 176 | 177 | def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip): 178 | x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k) 179 | if x_skip is not None: 180 | x = torch.cat([x, x_skip], dim=1) 181 | x = self.nn(x) 182 | return x, pos_skip, batch_skip 183 | 184 | 185 | class PointNet2_seg(torch.nn.Module): 186 | def __init__(self, args, in_channels, out_channels): 187 | super(PointNet2_seg, self).__init__() 188 | 189 | self.name = "PointNet2" 190 | self.I, self.O = in_channels, out_channels 191 | self.radius = args.radius 192 | self.k = 10000 # We don't restrict the number of points in a patch 193 | self.n_layers = args.n_layers 194 | 195 | # self.sa1_module = SAModule(1.0, self.radius, MLP([self.I+3, self.O, self.O]),self.k) 196 | self.layers = nn.ModuleList( 197 | [SAModule(1.0, self.radius, MLP([self.I + 3, self.O, self.O]), self.k)] 198 | + [ 199 | SAModule(1.0, self.radius, MLP([self.O + 3, self.O, self.O]), self.k) 200 | for i in range(self.n_layers - 1) 201 | ] 202 | ) 203 | 204 | self.linear_layers = nn.ModuleList( 205 | [ 206 | nn.Sequential( 207 | nn.Linear(self.O, self.O), nn.ReLU(), nn.Linear(self.O, self.O) 208 | ) 209 | for i in range(self.n_layers) 210 | ] 211 | ) 212 | 213 | self.linear_transform = nn.ModuleList( 214 | [nn.Linear(self.I, self.O)] 215 | + [nn.Linear(self.O, self.O) for i in range(self.n_layers - 1)] 216 | ) 217 | 218 | def forward(self, positions, features, batch_indices): 219 | x = (features, positions, batch_indices) 220 | for i, layer in 
enumerate(self.layers): 221 | x_i, pos, b_ind = layer(*x) 222 | x_i = self.linear_layers[i](x_i) 223 | x = self.linear_transform[i](x[0]) 224 | x = x + x_i 225 | x = (x, pos, b_ind) 226 | 227 | return x[0] 228 | 229 | 230 | ## TangentConv benchmark segmentation 231 | 232 | 233 | class dMaSIFConv_seg(torch.nn.Module): 234 | def __init__(self, args, in_channels, out_channels, n_layers, radius=9.0): 235 | super(dMaSIFConv_seg, self).__init__() 236 | 237 | self.name = "dMaSIFConv_seg_keops" 238 | self.radius = radius 239 | self.I, self.O = in_channels, out_channels 240 | 241 | self.layers = nn.ModuleList( 242 | [dMaSIFConv(self.I, self.O, radius, self.O)] 243 | + [dMaSIFConv(self.O, self.O, radius, self.O) for i in range(n_layers - 1)] 244 | ) 245 | 246 | self.linear_layers = nn.ModuleList( 247 | [ 248 | nn.Sequential( 249 | nn.Linear(self.O, self.O), nn.ReLU(), nn.Linear(self.O, self.O) 250 | ) 251 | for i in range(n_layers) 252 | ] 253 | ) 254 | 255 | self.linear_transform = nn.ModuleList( 256 | [nn.Linear(self.I, self.O)] 257 | + [nn.Linear(self.O, self.O) for i in range(n_layers - 1)] 258 | ) 259 | 260 | def forward(self, features): 261 | # Lab: (B,), Pos: (N, 3), Batch: (N,) 262 | points, nuv, ranges = self.points, self.nuv, self.ranges 263 | x = features 264 | for i, layer in enumerate(self.layers): 265 | x_i = layer(points, nuv, x, ranges) 266 | x_i = self.linear_layers[i](x_i) 267 | x = self.linear_transform[i](x) 268 | x = x + x_i 269 | 270 | return x 271 | 272 | def load_mesh(self, xyz, triangles=None, normals=None, weights=None, batch=None): 273 | """Loads the geometry of a triangle mesh. 274 | 275 | Input arguments: 276 | - xyz, a point cloud encoded as an (N, 3) Tensor. 277 | - triangles, a connectivity matrix encoded as an (N, 3) integer tensor. 278 | - weights, importance weights for the orientation estimation, encoded as an (N, 1) Tensor. 279 | - radius, the scale used to estimate the local normals. 
280 | - a batch vector, following PyTorch_Geometric's conventions. 281 | 282 | The routine updates the model attributes: 283 | - points, i.e. the point cloud itself, 284 | - nuv, a local oriented basis in R^3 for every point, 285 | - ranges, custom KeOps syntax to implement batch processing. 286 | """ 287 | 288 | # 1. Save the vertices for later use in the convolutions --------------- 289 | self.points = xyz 290 | self.batch = batch 291 | self.ranges = diagonal_ranges( 292 | batch 293 | ) # KeOps support for heterogeneous batch processing 294 | self.triangles = triangles 295 | self.normals = normals 296 | self.weights = weights 297 | 298 | # 2. Estimate the normals and tangent frame ---------------------------- 299 | # Normalize the scale: 300 | points = xyz / self.radius 301 | 302 | # Normals and local areas: 303 | if normals is None: 304 | normals, areas = mesh_normals_areas(points, triangles, 0.5, batch) 305 | tangent_bases = tangent_vectors(normals) # Tangent basis (N, 2, 3) 306 | 307 | # 3. 
Steer the tangent bases according to the gradient of "weights" ---- 308 | 309 | # 3.a) Encoding as KeOps LazyTensors: 310 | # Orientation scores: 311 | weights_j = LazyTensor(weights.view(1, -1, 1)) # (1, N, 1) 312 | # Vertices: 313 | x_i = LazyTensor(points[:, None, :]) # (N, 1, 3) 314 | x_j = LazyTensor(points[None, :, :]) # (1, N, 3) 315 | # Normals: 316 | n_i = LazyTensor(normals[:, None, :]) # (N, 1, 3) 317 | n_j = LazyTensor(normals[None, :, :]) # (1, N, 3) 318 | # Tangent basis: 319 | uv_i = LazyTensor(tangent_bases.view(-1, 1, 6)) # (N, 1, 6) 320 | 321 | # 3.b) Pseudo-geodesic window: 322 | # Pseudo-geodesic squared distance: 323 | rho2_ij = ((x_j - x_i) ** 2).sum(-1) * ((2 - (n_i | n_j)) ** 2) # (N, N, 1) 324 | # Gaussian window: 325 | window_ij = (-rho2_ij).exp() # (N, N, 1) 326 | 327 | # 3.c) Coordinates in the (u, v) basis - not oriented yet: 328 | X_ij = uv_i.matvecmult(x_j - x_i) # (N, N, 2) 329 | 330 | # 3.d) Local average in the tangent plane: 331 | orientation_weight_ij = window_ij * weights_j # (N, N, 1) 332 | orientation_vector_ij = orientation_weight_ij * X_ij # (N, N, 2) 333 | 334 | # Support for heterogeneous batch processing: 335 | orientation_vector_ij.ranges = self.ranges # Block-diagonal sparsity mask 336 | 337 | orientation_vector_i = orientation_vector_ij.sum(dim=1) # (N, 2) 338 | orientation_vector_i = ( 339 | orientation_vector_i + 1e-5 340 | ) # Just in case someone's alone... 341 | 342 | # 3.e) Normalize stuff: 343 | orientation_vector_i = F.normalize(orientation_vector_i, p=2, dim=-1) #  (N, 2) 344 | ex_i, ey_i = ( 345 | orientation_vector_i[:, 0][:, None], 346 | orientation_vector_i[:, 1][:, None], 347 | ) # (N,1) 348 | 349 | # 3.f) Re-orient the (u,v) basis: 350 | uv_i = tangent_bases # (N, 2, 3) 351 | u_i, v_i = uv_i[:, 0, :], uv_i[:, 1, :] # (N, 3) 352 | tangent_bases = torch.cat( 353 | (ex_i * u_i + ey_i * v_i, -ey_i * u_i + ex_i * v_i), dim=1 354 | ).contiguous() # (N, 6) 355 | 356 | # 4. 
Store the local 3D frame as an attribute -------------------------- 357 | self.nuv = torch.cat( 358 | (normals.view(-1, 1, 3), tangent_bases.view(-1, 2, 3)), dim=1 359 | ) 360 | -------------------------------------------------------------------------------- /benchmark_scripts/DGCNN_site.sh: -------------------------------------------------------------------------------- 1 | # Load environment 2 | python -W ignore -u main_training.py --experiment_name DGCNN_site_1layer_k200 --batch_size 64 --embedding_layer DGCNN --site True --single_protein True --device cuda:0 --n_layers 1 --random_rotation True --k 200 3 | python -W ignore -u main_training.py --experiment_name DGCNN_site_1layer_k100 --batch_size 64 --embedding_layer DGCNN --site True --single_protein True --device cuda:0 --n_layers 1 --random_rotation True --k 100 4 | 5 | python -W ignore -u main_training.py --experiment_name DGCNN_site_3layer_k200 --batch_size 64 --embedding_layer DGCNN --site True --single_protein True --device cuda:0 --n_layers 3 --random_rotation True --k 200 6 | python -W ignore -u main_training.py --experiment_name DGCNN_site_3layer_k100 --batch_size 64 --embedding_layer DGCNN --site True --single_protein True --device cuda:0 --n_layers 3 --random_rotation True --k 100 -------------------------------------------------------------------------------- /benchmark_scripts/Pointnet_site.sh: -------------------------------------------------------------------------------- 1 | # Load environment 2 | python -W ignore -u main_training.py --experiment_name PointNet_site_3layer_15A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 15.0 --n_layers 3 3 | python -W ignore -u main_training.py --experiment_name PointNet_site_3layer_5A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 5.0 --n_layers 3 4 | python -W ignore -u main_training.py 
--experiment_name PointNet_site_3layer_9A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 9.0 --n_layers 3 5 | 6 | python -W ignore -u main_training.py --experiment_name PointNet_site_1layer_15A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 15.0 --n_layers 1 7 | python -W ignore -u main_training.py --experiment_name PointNet_site_1layer_5A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 5.0 --n_layers 1 8 | python -W ignore -u main_training.py --experiment_name PointNet_site_1layer_9A --batch_size 64 --embedding_layer PointNet++ --site True --single_protein True --device cuda:0 --random_rotation True --radius 9.0 --n_layers 1 -------------------------------------------------------------------------------- /benchmark_scripts/dMaSIF_search.sh: -------------------------------------------------------------------------------- 1 | # Load environment 2 | python -W ignore -u main_training.py --experiment_name dMaSIF_search_1layer_12A --batch_size 64 --embedding_layer dMaSIF --search True --device cuda:0 --random_rotation True --radius 12.0 --n_layers 1 3 | python -W ignore -u main_training.py --experiment_name dMaSIF_search_3layer_12A --batch_size 64 --embedding_layer dMaSIF --search True --device cuda:0 --random_rotation True --radius 12.0 --n_layers 3 4 | -------------------------------------------------------------------------------- /benchmark_scripts/dMaSIF_site.sh: -------------------------------------------------------------------------------- 1 | # Load environment 2 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_1layer_15A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 15.0 --n_layers 1 3 | python -W ignore -u main_training.py --experiment_name 
dMaSIF_site_1layer_5A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 5.0 --n_layers 1 4 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_1layer_9A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 9.0 --n_layers 1 5 | 6 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_3layer_15A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 15.0 --n_layers 3 7 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_3layer_5A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 5.0 --n_layers 3 8 | python -W ignore -u main_training.py --experiment_name dMaSIF_site_3layer_9A --batch_size 64 --embedding_layer dMaSIF --site True --single_protein True --random_rotation True --radius 9.0 --n_layers 3 9 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import InMemoryDataset, Data, DataLoader 3 | from torch_geometric.transforms import Compose 4 | import numpy as np 5 | from scipy.spatial.transform import Rotation 6 | import math 7 | import urllib.request 8 | import tarfile 9 | from pathlib import Path 10 | import requests 11 | from data_preprocessing.convert_pdb2npy import convert_pdbs 12 | from data_preprocessing.convert_ply2npy import convert_plys 13 | 14 | tensor = torch.FloatTensor 15 | inttensor = torch.LongTensor 16 | 17 | 18 | def numpy(x): 19 | return x.detach().cpu().numpy() 20 | 21 | 22 | def iface_valid_filter(protein_pair): 23 | labels1 = protein_pair.y_p1.reshape(-1) 24 | labels2 = protein_pair.y_p2.reshape(-1) 25 | valid1 = ( 26 | (torch.sum(labels1) < 0.75 * len(labels1)) 27 | and (torch.sum(labels1) > 30) 28 | and 
(torch.sum(labels1) > 0.01 * labels2.shape[0]) 29 | ) 30 | valid2 = ( 31 | (torch.sum(labels2) < 0.75 * len(labels2)) 32 | and (torch.sum(labels2) > 30) 33 | and (torch.sum(labels2) > 0.01 * labels1.shape[0]) 34 | ) 35 | 36 | return valid1 and valid2 37 | 38 | 39 | class RandomRotationPairAtoms(object): 40 | r"""Randomly rotate a protein""" 41 | 42 | def __call__(self, data): 43 | R1 = tensor(Rotation.random().as_matrix()) 44 | R2 = tensor(Rotation.random().as_matrix()) 45 | 46 | data.atom_coords_p1 = torch.matmul(R1, data.atom_coords_p1.T).T 47 | data.xyz_p1 = torch.matmul(R1, data.xyz_p1.T).T 48 | data.normals_p1 = torch.matmul(R1, data.normals_p1.T).T 49 | 50 | data.atom_coords_p2 = torch.matmul(R2, data.atom_coords_p2.T).T 51 | data.xyz_p2 = torch.matmul(R2, data.xyz_p2.T).T 52 | data.normals_p2 = torch.matmul(R2, data.normals_p2.T).T 53 | 54 | data.rand_rot1 = R1 55 | data.rand_rot2 = R2 56 | return data 57 | 58 | def __repr__(self): 59 | return "{}()".format(self.__class__.__name__) 60 | 61 | 62 | class CenterPairAtoms(object): 63 | r"""Centers a protein""" 64 | 65 | def __call__(self, data): 66 | atom_center1 = data.atom_coords_p1.mean(dim=-2, keepdim=True) 67 | atom_center2 = data.atom_coords_p2.mean(dim=-2, keepdim=True) 68 | 69 | data.atom_coords_p1 = data.atom_coords_p1 - atom_center1 70 | data.atom_coords_p2 = data.atom_coords_p2 - atom_center2 71 | 72 | data.xyz_p1 = data.xyz_p1 - atom_center1 73 | data.xyz_p2 = data.xyz_p2 - atom_center2 74 | 75 | data.atom_center1 = atom_center1 76 | data.atom_center2 = atom_center2 77 | return data 78 | 79 | def __repr__(self): 80 | return "{}()".format(self.__class__.__name__) 81 | 82 | 83 | class NormalizeChemFeatures(object): 84 | r"""Centers a protein""" 85 | 86 | def __call__(self, data): 87 | pb_upper = 3.0 88 | pb_lower = -3.0 89 | 90 | chem_p1 = data.chemical_features_p1 91 | chem_p2 = data.chemical_features_p2 92 | 93 | pb_p1 = chem_p1[:, 0] 94 | pb_p2 = chem_p2[:, 0] 95 | hb_p1 = chem_p1[:, 1] 96 | hb_p2 = 
chem_p2[:, 1] 97 | hp_p1 = chem_p1[:, 2] 98 | hp_p2 = chem_p2[:, 2] 99 | 100 | # Normalize PB 101 | pb_p1 = torch.clamp(pb_p1, pb_lower, pb_upper) 102 | pb_p1 = (pb_p1 - pb_lower) / (pb_upper - pb_lower) 103 | pb_p1 = 2 * pb_p1 - 1 104 | 105 | pb_p2 = torch.clamp(pb_p2, pb_lower, pb_upper) 106 | pb_p2 = (pb_p2 - pb_lower) / (pb_upper - pb_lower) 107 | pb_p2 = 2 * pb_p2 - 1 108 | 109 | # Normalize HP 110 | hp_p1 = hp_p1 / 4.5 111 | hp_p2 = hp_p2 / 4.5 112 | 113 | data.chemical_features_p1 = torch.stack([pb_p1, hb_p1, hp_p1]).T 114 | data.chemical_features_p2 = torch.stack([pb_p2, hb_p2, hp_p2]).T 115 | 116 | return data 117 | 118 | def __repr__(self): 119 | return "{}()".format(self.__class__.__name__) 120 | 121 | 122 | def load_protein_npy(pdb_id, data_dir, center=False, single_pdb=False): 123 | """Loads a protein surface mesh and its features""" 124 | 125 | # Load the data, and read the connectivity information: 126 | triangles = ( 127 | None 128 | if single_pdb 129 | else inttensor(np.load(data_dir / (pdb_id + "_triangles.npy"))).T 130 | ) 131 | # Normalize the point cloud, as specified by the user: 132 | points = None if single_pdb else tensor(np.load(data_dir / (pdb_id + "_xyz.npy"))) 133 | center_location = None if single_pdb else torch.mean(points, axis=0, keepdims=True) 134 | 135 | atom_coords = tensor(np.load(data_dir / (pdb_id + "_atomxyz.npy"))) 136 | atom_types = tensor(np.load(data_dir / (pdb_id + "_atomtypes.npy"))) 137 | 138 | if center: 139 | points = points - center_location 140 | atom_coords = atom_coords - center_location 141 | 142 | # Interface labels 143 | iface_labels = ( 144 | None 145 | if single_pdb 146 | else tensor(np.load(data_dir / (pdb_id + "_iface_labels.npy")).reshape((-1, 1))) 147 | ) 148 | 149 | # Features 150 | chemical_features = ( 151 | None if single_pdb else tensor(np.load(data_dir / (pdb_id + "_features.npy"))) 152 | ) 153 | 154 | # Normals 155 | normals = ( 156 | None if single_pdb else tensor(np.load(data_dir / (pdb_id + 
"_normals.npy"))) 157 | ) 158 | 159 | protein_data = Data( 160 | xyz=points, 161 | face=triangles, 162 | chemical_features=chemical_features, 163 | y=iface_labels, 164 | normals=normals, 165 | center_location=center_location, 166 | num_nodes=None if single_pdb else points.shape[0], 167 | atom_coords=atom_coords, 168 | atom_types=atom_types, 169 | ) 170 | return protein_data 171 | 172 | 173 | class PairData(Data): 174 | def __init__( 175 | self, 176 | xyz_p1=None, 177 | xyz_p2=None, 178 | face_p1=None, 179 | face_p2=None, 180 | chemical_features_p1=None, 181 | chemical_features_p2=None, 182 | y_p1=None, 183 | y_p2=None, 184 | normals_p1=None, 185 | normals_p2=None, 186 | center_location_p1=None, 187 | center_location_p2=None, 188 | atom_coords_p1=None, 189 | atom_coords_p2=None, 190 | atom_types_p1=None, 191 | atom_types_p2=None, 192 | atom_center1=None, 193 | atom_center2=None, 194 | rand_rot1=None, 195 | rand_rot2=None, 196 | ): 197 | super().__init__() 198 | self.xyz_p1 = xyz_p1 199 | self.xyz_p2 = xyz_p2 200 | self.face_p1 = face_p1 201 | self.face_p2 = face_p2 202 | 203 | self.chemical_features_p1 = chemical_features_p1 204 | self.chemical_features_p2 = chemical_features_p2 205 | self.y_p1 = y_p1 206 | self.y_p2 = y_p2 207 | self.normals_p1 = normals_p1 208 | self.normals_p2 = normals_p2 209 | self.center_location_p1 = center_location_p1 210 | self.center_location_p2 = center_location_p2 211 | self.atom_coords_p1 = atom_coords_p1 212 | self.atom_coords_p2 = atom_coords_p2 213 | self.atom_types_p1 = atom_types_p1 214 | self.atom_types_p2 = atom_types_p2 215 | self.atom_center1 = atom_center1 216 | self.atom_center2 = atom_center2 217 | self.rand_rot1 = rand_rot1 218 | self.rand_rot2 = rand_rot2 219 | 220 | def __inc__(self, key, value): 221 | if key == "face_p1": 222 | return self.xyz_p1.size(0) 223 | if key == "face_p2": 224 | return self.xyz_p2.size(0) 225 | else: 226 | return super(PairData, self).__inc__(key, value) 227 | 228 | def __cat_dim__(self, key, 
value): 229 | if ("index" in key) or ("face" in key): 230 | return 1 231 | else: 232 | return 0 233 | 234 | 235 | def load_protein_pair(pdb_id, data_dir,single_pdb=False): 236 | """Loads a protein surface mesh and its features""" 237 | pspl = pdb_id.split("_") 238 | p1_id = pspl[0] + "_" + pspl[1] 239 | p2_id = pspl[0] + "_" + pspl[2] 240 | 241 | p1 = load_protein_npy(p1_id, data_dir, center=False,single_pdb=single_pdb) 242 | p2 = load_protein_npy(p2_id, data_dir, center=False,single_pdb=single_pdb) 243 | # pdist = ((p1['xyz'][:,None,:]-p2['xyz'][None,:,:])**2).sum(-1).sqrt() 244 | # pdist = pdist<2.0 245 | # y_p1 = (pdist.sum(1)>0).to(torch.float).reshape(-1,1) 246 | # y_p2 = (pdist.sum(0)>0).to(torch.float).reshape(-1,1) 247 | y_p1 = p1["y"] 248 | y_p2 = p2["y"] 249 | 250 | protein_pair_data = PairData( 251 | xyz_p1=p1["xyz"], 252 | xyz_p2=p2["xyz"], 253 | face_p1=p1["face"], 254 | face_p2=p2["face"], 255 | chemical_features_p1=p1["chemical_features"], 256 | chemical_features_p2=p2["chemical_features"], 257 | y_p1=y_p1, 258 | y_p2=y_p2, 259 | normals_p1=p1["normals"], 260 | normals_p2=p2["normals"], 261 | center_location_p1=p1["center_location"], 262 | center_location_p2=p2["center_location"], 263 | atom_coords_p1=p1["atom_coords"], 264 | atom_coords_p2=p2["atom_coords"], 265 | atom_types_p1=p1["atom_types"], 266 | atom_types_p2=p2["atom_types"], 267 | ) 268 | return protein_pair_data 269 | 270 | 271 | class ProteinPairsSurfaces(InMemoryDataset): 272 | url = "" 273 | 274 | def __init__(self, root, ppi=False, train=True, transform=None, pre_transform=None): 275 | self.ppi = ppi 276 | super(ProteinPairsSurfaces, self).__init__(root, transform, pre_transform) 277 | path = self.processed_paths[0] if train else self.processed_paths[1] 278 | self.data, self.slices = torch.load(path) 279 | 280 | @property 281 | def raw_file_names(self): 282 | return "masif_site_masif_search_pdbs_and_ply_files.tar.gz" 283 | 284 | @property 285 | def processed_file_names(self): 286 | if 
not self.ppi: 287 | file_names = [ 288 | "training_pairs_data.pt", 289 | "testing_pairs_data.pt", 290 | "training_pairs_data_ids.npy", 291 | "testing_pairs_data_ids.npy", 292 | ] 293 | else: 294 | file_names = [ 295 | "training_pairs_data_ppi.pt", 296 | "testing_pairs_data_ppi.pt", 297 | "training_pairs_data_ids_ppi.npy", 298 | "testing_pairs_data_ids_ppi.npy", 299 | ] 300 | return file_names 301 | 302 | def download(self): 303 | url = 'https://zenodo.org/record/2625420/files/masif_site_masif_search_pdbs_and_ply_files.tar.gz' 304 | target_path = self.raw_paths[0] 305 | response = requests.get(url, stream=True) 306 | if response.status_code == 200: 307 | with open(target_path, 'wb') as f: 308 | f.write(response.raw.read()) 309 | 310 | #raise RuntimeError( 311 | # "Dataset not found. Please download {} from {} and move it to {}".format( 312 | # self.raw_file_names, self.url, self.raw_dir 313 | # ) 314 | #) 315 | 316 | def process(self): 317 | pdb_dir = Path(self.root) / "raw" / "01-benchmark_pdbs" 318 | surf_dir = Path(self.root) / "raw" / "01-benchmark_surfaces" 319 | protein_dir = Path(self.root) / "raw" / "01-benchmark_surfaces_npy" 320 | lists_dir = Path('./lists') 321 | 322 | # Untar surface files 323 | if not (pdb_dir.exists() and surf_dir.exists()): 324 | tar = tarfile.open(self.raw_paths[0]) 325 | tar.extractall(self.raw_dir) 326 | tar.close() 327 | 328 | if not protein_dir.exists(): 329 | protein_dir.mkdir(parents=False, exist_ok=False) 330 | convert_plys(surf_dir,protein_dir) 331 | convert_pdbs(pdb_dir,protein_dir) 332 | 333 | with open(lists_dir / "training.txt") as f_tr, open( 334 | lists_dir / "testing.txt" 335 | ) as f_ts: 336 | training_list = sorted(f_tr.read().splitlines()) 337 | testing_list = sorted(f_ts.read().splitlines()) 338 | 339 | with open(lists_dir / "training_ppi.txt") as f_tr, open( 340 | lists_dir / "testing_ppi.txt" 341 | ) as f_ts: 342 | training_pairs_list = sorted(f_tr.read().splitlines()) 343 | testing_pairs_list = 
sorted(f_ts.read().splitlines()) 344 | pairs_list = sorted(training_pairs_list + testing_pairs_list) 345 | 346 | if not self.ppi: 347 | training_pairs_list = [] 348 | for p in pairs_list: 349 | pspl = p.split("_") 350 | p1 = pspl[0] + "_" + pspl[1] 351 | p2 = pspl[0] + "_" + pspl[2] 352 | 353 | if p1 in training_list: 354 | training_pairs_list.append(p) 355 | if p2 in training_list: 356 | training_pairs_list.append(pspl[0] + "_" + pspl[2] + "_" + pspl[1]) 357 | 358 | testing_pairs_list = [] 359 | for p in pairs_list: 360 | pspl = p.split("_") 361 | p1 = pspl[0] + "_" + pspl[1] 362 | p2 = pspl[0] + "_" + pspl[2] 363 | if p1 in testing_list: 364 | testing_pairs_list.append(p) 365 | if p2 in testing_list: 366 | testing_pairs_list.append(pspl[0] + "_" + pspl[2] + "_" + pspl[1]) 367 | 368 | # # Read data into huge `Data` list. 369 | training_pairs_data = [] 370 | training_pairs_data_ids = [] 371 | for p in training_pairs_list: 372 | try: 373 | protein_pair = load_protein_pair(p, protein_dir) 374 | except FileNotFoundError: 375 | continue 376 | training_pairs_data.append(protein_pair) 377 | training_pairs_data_ids.append(p) 378 | 379 | testing_pairs_data = [] 380 | testing_pairs_data_ids = [] 381 | for p in testing_pairs_list: 382 | try: 383 | protein_pair = load_protein_pair(p, protein_dir) 384 | except FileNotFoundError: 385 | continue 386 | testing_pairs_data.append(protein_pair) 387 | testing_pairs_data_ids.append(p) 388 | 389 | if self.pre_filter is not None: 390 | training_pairs_data = [ 391 | data for data in training_pairs_data if self.pre_filter(data) 392 | ] 393 | testing_pairs_data = [ 394 | data for data in testing_pairs_data if self.pre_filter(data) 395 | ] 396 | 397 | if self.pre_transform is not None: 398 | training_pairs_data = [ 399 | self.pre_transform(data) for data in training_pairs_data 400 | ] 401 | testing_pairs_data = [ 402 | self.pre_transform(data) for data in testing_pairs_data 403 | ] 404 | 405 | training_pairs_data, training_pairs_slices = 
self.collate(training_pairs_data) 406 | torch.save( 407 | (training_pairs_data, training_pairs_slices), self.processed_paths[0] 408 | ) 409 | np.save(self.processed_paths[2], training_pairs_data_ids) 410 | testing_pairs_data, testing_pairs_slices = self.collate(testing_pairs_data) 411 | torch.save((testing_pairs_data, testing_pairs_slices), self.processed_paths[1]) 412 | np.save(self.processed_paths[3], testing_pairs_data_ids) 413 | -------------------------------------------------------------------------------- /data_analysis/analyse_descriptors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from sklearn.metrics import roc_auc_score, roc_curve 4 | from scipy.spatial.distance import cdist 5 | import matplotlib.pyplot as plt 6 | 7 | top_dir = Path('..') 8 | experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_c_restarted_epoch43', 9 | 'TangentConv_search_3L_16dim_12A_FIXED_binet_g_restarted_epoch38', 10 | 'TangentConv_search_1L_8dim_12A_FIXED_binet_gc_epoch34', 11 | 'TangentConv_search_3L_8dim_12A_FIXED_binet_gc_restarted_epoch49', 12 | 'TangentConv_search_1L_16dim_12A_FIXED_binet_gc_epoch45'] 13 | 14 | with open(top_dir/'surface_data/raw/protein_surfaces/testing_ppi.txt') as f: 15 | testing_list = f.read().splitlines() 16 | 17 | pdb_list = testing_list 18 | 19 | for experiment_name in experiment_names: 20 | print(experiment_name) 21 | desc_dir = top_dir/f'preds/{experiment_name}' 22 | all_roc_aucs = [] 23 | all_preds = [] 24 | all_labels = [] 25 | for i, pdb_id in enumerate(pdb_list): 26 | pdb_id1 = pdb_id.split('_')[0]+'_'+pdb_id.split('_')[1] 27 | pdb_id2 = pdb_id.split('_')[0]+'_'+pdb_id.split('_')[2] 28 | if i%100==0: 29 | print(i,np.mean(all_roc_aucs)) 30 | 31 | try: 32 | desc1 = np.load(desc_dir/f'{pdb_id1}_predfeatures.npy')[:,16:16+16] 33 | desc2 = np.load(desc_dir/f'{pdb_id2}_predfeatures.npy')[:,16:16+16] 34 | xyz1 = 
np.load(desc_dir/f'{pdb_id1}_predcoords.npy') 35 | xyz2 = np.load(desc_dir/f'{pdb_id2}_predcoords.npy') 36 | except FileNotFoundError: 37 | continue 38 | 39 | dists = cdist(xyz1,xyz2)<1.0 40 | if dists.sum()<1: 41 | continue 42 | 43 | iface_pos1 = dists.sum(1)>0 44 | iface_pos2 = dists.sum(0)>0 45 | 46 | pos_dists1 = dists[iface_pos1,:] 47 | pos_dists2 = dists[:,iface_pos2] 48 | 49 | desc_dists = np.matmul(desc1,desc2.T) 50 | #desc_dists = 1/cdist(desc1,desc2) 51 | 52 | pos_dists = desc_dists[dists].reshape(-1) 53 | pos_labels = np.ones_like(pos_dists) 54 | neg_dists1 = desc_dists[iface_pos1,:][pos_dists1==0].reshape(-1) 55 | neg_dists2 = desc_dists[:,iface_pos2][pos_dists2==0].reshape(-1) 56 | 57 | #neg_dists = np.concatenate([neg_dists1,neg_dists2],axis=0) 58 | neg_dists = neg_dists1 59 | neg_dists = np.random.choice(neg_dists,200,replace=False) 60 | neg_labels = np.zeros_like(neg_dists) 61 | 62 | preds = np.concatenate([pos_dists,neg_dists]) 63 | labels = np.concatenate([pos_labels,neg_labels]) 64 | 65 | roc_auc = roc_auc_score(labels,preds) 66 | all_roc_aucs.append(roc_auc) 67 | all_preds.extend(list(preds)) 68 | all_labels.extend(list(labels)) 69 | 70 | 71 | fpr, tpr, thresholds = roc_curve(all_labels,all_preds) 72 | np.save(f'roc_curves/{experiment_name}_fpr.npy',fpr) 73 | np.save(f'roc_curves/{experiment_name}_tpr.npy',tpr) 74 | np.save(f'roc_curves/{experiment_name}_all_labels.npy',all_labels) 75 | np.save(f'roc_curves/{experiment_name}_all_preds.npy',all_preds) 76 | 77 | -------------------------------------------------------------------------------- /data_analysis/analyse_descriptors_para.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from sklearn.metrics import roc_auc_score, roc_curve 4 | from scipy.spatial.distance import cdist 5 | import matplotlib.pyplot as plt 6 | import dask 7 | 8 | top_dir = Path('..') 9 | experiment_names = 
['TangentConv_search_3L_16dim_12A_FIXED_binet_c_restarted_epoch43', 10 | 'TangentConv_search_3L_16dim_12A_FIXED_binet_g_restarted_epoch38', 11 | 'TangentConv_search_1L_8dim_12A_FIXED_binet_gc_epoch34', 12 | 'TangentConv_search_3L_8dim_12A_FIXED_binet_gc_restarted_epoch49', 13 | 'TangentConv_search_1L_16dim_12A_FIXED_binet_gc_epoch45'] 14 | experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_gc_restarted_epoch47'] 15 | experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_gc_subsamp50_epoch25'] 16 | experiment_names = ['TangentConv_search_3L_16dim_12A_FIXED_binet_gc_subsamp50_restarted_restarted_restarted_epoch53'] 17 | 18 | 19 | ndims = [16,16,8,8,16] 20 | ndims = [16] 21 | 22 | 23 | with open(top_dir/'surface_data/raw/protein_surfaces/testing_ppi.txt') as f: 24 | testing_list = f.read().splitlines() 25 | 26 | pdb_list = testing_list 27 | 28 | @dask.delayed 29 | def analyse_pdb(pdb_id,D): 30 | pdb_id1 = pdb_id.split('_')[0]+'_'+pdb_id.split('_')[1] 31 | pdb_id2 = pdb_id.split('_')[0]+'_'+pdb_id.split('_')[2] 32 | 33 | try: 34 | desc1 = np.load(desc_dir/f'{pdb_id1}_predfeatures.npy')[:,16:16+D] 35 | desc2 = np.load(desc_dir/f'{pdb_id2}_predfeatures.npy')[:,16:16+D] 36 | xyz1 = np.load(desc_dir/f'{pdb_id1}_predcoords.npy') 37 | xyz2 = np.load(desc_dir/f'{pdb_id2}_predcoords.npy') 38 | except FileNotFoundError: 39 | return -1 40 | 41 | dists = cdist(xyz1,xyz2)<1.0 42 | if dists.sum()<1: 43 | return -1 44 | 45 | iface_pos1 = dists.sum(1)>0 46 | iface_pos2 = dists.sum(0)>0 47 | 48 | pos_dists1 = dists[iface_pos1,:] 49 | pos_dists2 = dists[:,iface_pos2] 50 | 51 | desc_dists = np.matmul(desc1,desc2.T) 52 | #desc_dists = 1/cdist(desc1,desc2) 53 | 54 | pos_dists = desc_dists[dists].reshape(-1) 55 | pos_labels = np.ones_like(pos_dists) 56 | neg_dists1 = desc_dists[iface_pos1,:][pos_dists1==0].reshape(-1) 57 | neg_dists2 = desc_dists[:,iface_pos2][pos_dists2==0].reshape(-1) 58 | 59 | #neg_dists = np.concatenate([neg_dists1,neg_dists2],axis=0) 60 | 
neg_dists = neg_dists1 61 | neg_dists = np.random.choice(neg_dists,400,replace=False) 62 | neg_labels = np.zeros_like(neg_dists) 63 | 64 | preds = np.concatenate([pos_dists,neg_dists]) 65 | labels = np.concatenate([pos_labels,neg_labels]) 66 | 67 | roc_auc = roc_auc_score(labels,preds) 68 | 69 | return roc_auc, preds, labels 70 | 71 | for experiment_name, D in zip(experiment_names,ndims): 72 | print(experiment_name) 73 | desc_dir = top_dir/f'preds/{experiment_name}' 74 | all_roc_aucs = [] 75 | all_preds = [] 76 | all_labels = [] 77 | all_res = [] 78 | for i, pdb_id in enumerate(pdb_list): 79 | res = analyse_pdb(pdb_id,D) 80 | all_res.append(res) 81 | 82 | all_res = dask.compute(*all_res) 83 | for res in all_res: 84 | if res==-1: 85 | continue 86 | all_roc_aucs.append(res[0]) 87 | all_preds.extend(list(res[1])) 88 | all_labels.extend(list(res[2])) 89 | 90 | print('ROC-AUC',np.mean(all_roc_aucs)) 91 | 92 | fpr, tpr, thresholds = roc_curve(all_labels,all_preds) 93 | np.save(f'roc_curves/{experiment_name}_fpr.npy',fpr) 94 | np.save(f'roc_curves/{experiment_name}_tpr.npy',tpr) 95 | np.save(f'roc_curves/{experiment_name}_all_labels.npy',all_labels) 96 | np.save(f'roc_curves/{experiment_name}_all_preds.npy',all_preds) 97 | 98 | -------------------------------------------------------------------------------- /data_analysis/analyse_site_outputs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pathlib import Path 4 | from tqdm import tqdm 5 | from scipy.spatial.distance import cdist 6 | from sklearn.metrics import roc_curve, roc_auc_score 7 | 8 | 9 | masif_preds = Path("masif_preds/") 10 | timings = Path("timings/") 11 | raw_data = Path("surface_data/raw/protein_surfaces/01-benchmark_surfaces_npy") 12 | 13 | experiment_names = [ 14 | "TangentConv_site_1layer_5A_epoch49", 15 | "TangentConv_site_1layer_9A_epoch49", 16 | "TangentConv_site_1layer_15A_epoch49", 17 | 
"TangentConv_site_3layer_15A_epoch17", 18 | "TangentConv_site_3layer_5A_epoch49", 19 | "TangentConv_site_3layer_9A_epoch46", 20 | "PointNet_site_3layer_9A_epoch37", 21 | "PointNet_site_3layer_5A_epoch46", 22 | "DGCNN_site_1layer_k100_epoch32", 23 | "PointNet_site_1layer_5A_epoch30", 24 | "PointNet_site_1layer_9A_epoch30", 25 | "DGCNN_site_1layer_k40_epoch46", 26 | "DGCNN_site_3layer_k40_epoch33", 27 | ] 28 | 29 | experiment_names = [ 30 | 'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_dist05_epoch42', 31 | 'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_dist20_epoch49', 32 | 'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_dist105_epoch44', 33 | 'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_var01_epoch43', 34 | 'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_var02_epoch49', 35 | 'Rebuttal_TangentConv_site_1L_8dim_9A_gc_subsamp20_var005_epoch37' 36 | ] 37 | 38 | for experiment_name in experiment_names: 39 | print(experiment_name) 40 | datafolder = Path(f"preds/{experiment_name}") 41 | pdb_list = [p.stem[:-5] for p in datafolder.glob("*pred.vtk")] 42 | 43 | n_meshpoints = [] 44 | n_predpoints = [] 45 | meshpoints_mindists = [] 46 | predpoints_mindists = [] 47 | for pdb_id in tqdm(pdb_list): 48 | predpoints = np.load(datafolder / (pdb_id + "_predcoords.npy")) 49 | meshpoints = np.load(datafolder / (pdb_id + "_meshpoints.npy")) 50 | n_meshpoints.append(meshpoints.shape[0]) 51 | n_predpoints.append(predpoints.shape[0]) 52 | 53 | pdists = cdist(meshpoints, predpoints) 54 | meshpoints_mindists.append(pdists.min(1)) 55 | predpoints_mindists.append(pdists.min(0)) 56 | 57 | all_meshpoints_mindists = np.concatenate(meshpoints_mindists) 58 | all_predpoints_mindists = np.concatenate(predpoints_mindists) 59 | 60 | meshpoint_percentile = np.percentile(all_meshpoints_mindists, 99) 61 | predpoint_percentile = np.percentile(all_predpoints_mindists, 99) 62 | 63 | meshpoints_masks = [] 64 | predpoints_masks = [] 65 | for pdb_id in tqdm(pdb_list): 66 | predpoints = 
np.load(datafolder / (pdb_id + "_predcoords.npy")) 67 | meshpoints = np.load(datafolder / (pdb_id + "_meshpoints.npy")) 68 | 69 | pdists = cdist(meshpoints, predpoints) 70 | meshpoints_masks.append(pdists.min(1) < meshpoint_percentile) 71 | predpoints_masks.append(pdists.min(0) < predpoint_percentile) 72 | 73 | predpoints_preds = [] 74 | predpoints_labels = [] 75 | npoints = [] 76 | for i, pdb_id in enumerate(tqdm(pdb_list)): 77 | predpoints_features = np.load(datafolder / (pdb_id + "_predfeatures.npy")) 78 | predpoints_features = predpoints_features[predpoints_masks[i]] 79 | 80 | predpoints_preds.append(predpoints_features[:, -2]) 81 | predpoints_labels.append(predpoints_features[:, -1]) 82 | npoints.append(predpoints_features.shape[0]) 83 | 84 | predpoints_labels = np.concatenate(predpoints_labels) 85 | predpoints_preds = np.concatenate(predpoints_preds) 86 | rocauc = roc_auc_score(predpoints_labels.reshape(-1), predpoints_preds.reshape(-1)) 87 | print("ROC-AUC", rocauc) 88 | 89 | np.save(timings / f"{experiment_name}_predpoints_preds", predpoints_preds) 90 | np.save(timings / f"{experiment_name}_predpoints_labels", predpoints_labels) 91 | np.save(timings / f"{experiment_name}_npoints", npoints) 92 | -------------------------------------------------------------------------------- /data_iteration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from helper import * 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.autograd.profiler as profiler 7 | from sklearn.metrics import roc_auc_score 8 | from pathlib import Path 9 | import math 10 | from tqdm import tqdm 11 | from geometry_processing import save_vtk 12 | from helper import numpy, diagonal_ranges 13 | import time 14 | 15 | 16 | def process_single(protein_pair, chain_idx=1): 17 | """Turn the PyG data object into a dict.""" 18 | 19 | P = {} 20 | with_mesh = "face_p1" in protein_pair.keys 21 | preprocessed 
= "gen_xyz_p1" in protein_pair.keys 22 | 23 | if chain_idx == 1: 24 | # Ground truth labels are available on mesh vertices: 25 | P["mesh_labels"] = protein_pair.y_p1 if with_mesh else None 26 | 27 | # N.B.: The DataLoader should use the optional argument 28 | # "follow_batch=['xyz_p1', 'xyz_p2']", as described on the PyG tutorial. 29 | P["mesh_batch"] = protein_pair.xyz_p1_batch if with_mesh else None 30 | 31 | # Surface information: 32 | P["mesh_xyz"] = protein_pair.xyz_p1 if with_mesh else None 33 | P["mesh_triangles"] = protein_pair.face_p1 if with_mesh else None 34 | 35 | # Atom information: 36 | P["atoms"] = protein_pair.atom_coords_p1 37 | P["batch_atoms"] = protein_pair.atom_coords_p1_batch 38 | 39 | # Chemical features: atom coordinates and types. 40 | P["atom_xyz"] = protein_pair.atom_coords_p1 41 | P["atomtypes"] = protein_pair.atom_types_p1 42 | 43 | P["xyz"] = protein_pair.gen_xyz_p1 if preprocessed else None 44 | P["normals"] = protein_pair.gen_normals_p1 if preprocessed else None 45 | P["batch"] = protein_pair.gen_batch_p1 if preprocessed else None 46 | P["labels"] = protein_pair.gen_labels_p1 if preprocessed else None 47 | 48 | elif chain_idx == 2: 49 | # Ground truth labels are available on mesh vertices: 50 | P["mesh_labels"] = protein_pair.y_p2 if with_mesh else None 51 | 52 | # N.B.: The DataLoader should use the optional argument 53 | # "follow_batch=['xyz_p1', 'xyz_p2']", as described on the PyG tutorial. 54 | P["mesh_batch"] = protein_pair.xyz_p2_batch if with_mesh else None 55 | 56 | # Surface information: 57 | P["mesh_xyz"] = protein_pair.xyz_p2 if with_mesh else None 58 | P["mesh_triangles"] = protein_pair.face_p2 if with_mesh else None 59 | 60 | # Atom information: 61 | P["atoms"] = protein_pair.atom_coords_p2 62 | P["batch_atoms"] = protein_pair.atom_coords_p2_batch 63 | 64 | # Chemical features: atom coordinates and types. 
65 | P["atom_xyz"] = protein_pair.atom_coords_p2 66 | P["atomtypes"] = protein_pair.atom_types_p2 67 | 68 | P["xyz"] = protein_pair.gen_xyz_p2 if preprocessed else None 69 | P["normals"] = protein_pair.gen_normals_p2 if preprocessed else None 70 | P["batch"] = protein_pair.gen_batch_p2 if preprocessed else None 71 | P["labels"] = protein_pair.gen_labels_p2 if preprocessed else None 72 | 73 | return P 74 | 75 | 76 | def save_protein_batch_single(protein_pair_id, P, save_path, pdb_idx): 77 | 78 | protein_pair_id = protein_pair_id.split("_") 79 | pdb_id = protein_pair_id[0] + "_" + protein_pair_id[pdb_idx] 80 | 81 | batch = P["batch"] 82 | 83 | xyz = P["xyz"] 84 | 85 | inputs = P["input_features"] 86 | 87 | embedding = P["embedding_1"] if pdb_idx == 1 else P["embedding_2"] 88 | emb_id = 1 if pdb_idx == 1 else 2 89 | 90 | predictions = torch.sigmoid(P["iface_preds"]) if "iface_preds" in P.keys() else 0.0*embedding[:,0].view(-1, 1) 91 | 92 | labels = P["labels"].view(-1, 1) if P["labels"] is not None else 0.0 * predictions 93 | 94 | coloring = torch.cat([inputs, embedding, predictions, labels], axis=1) 95 | 96 | save_vtk(str(save_path / pdb_id) + f"_pred_emb{emb_id}", xyz, values=coloring) 97 | np.save(str(save_path / pdb_id) + "_predcoords", numpy(xyz)) 98 | np.save(str(save_path / pdb_id) + f"_predfeatures_emb{emb_id}", numpy(coloring)) 99 | 100 | 101 | def project_iface_labels(P, threshold=2.0): 102 | 103 | queries = P["xyz"] 104 | batch_queries = P["batch"] 105 | source = P["mesh_xyz"] 106 | batch_source = P["mesh_batch"] 107 | labels = P["mesh_labels"] 108 | x_i = LazyTensor(queries[:, None, :]) # (N, 1, D) 109 | y_j = LazyTensor(source[None, :, :]) # (1, M, D) 110 | 111 | D_ij = ((x_i - y_j) ** 2).sum(-1) # (N, M) 112 | D_ij.ranges = diagonal_ranges(batch_queries, batch_source) 113 | nn_i = D_ij.argmin(dim=1).view(-1) # (N,) 114 | nn_dist_i = ( 115 | D_ij.min(dim=1).view(-1, 1) < threshold 116 | ).float() # If chain is not connected because of missing densities 
MaSIF cut out a part of the protein 117 | 118 | query_labels = labels[nn_i] * nn_dist_i 119 | 120 | P["labels"] = query_labels 121 | 122 | 123 | def process(args, protein_pair, net): 124 | P1 = process_single(protein_pair, chain_idx=1) 125 | if not "gen_xyz_p1" in protein_pair.keys: 126 | net.preprocess_surface(P1) 127 | #if P1["mesh_labels"] is not None: 128 | # project_iface_labels(P1) 129 | P2 = None 130 | if not args.single_protein: 131 | P2 = process_single(protein_pair, chain_idx=2) 132 | if not "gen_xyz_p2" in protein_pair.keys: 133 | net.preprocess_surface(P2) 134 | #if P2["mesh_labels"] is not None: 135 | # project_iface_labels(P2) 136 | 137 | return P1, P2 138 | 139 | 140 | def generate_matchinglabels(args, P1, P2): 141 | if args.random_rotation: 142 | P1["xyz"] = torch.matmul(P1["rand_rot"].T, P1["xyz"].T).T + P1["atom_center"] 143 | P2["xyz"] = torch.matmul(P2["rand_rot"].T, P2["xyz"].T).T + P2["atom_center"] 144 | xyz1_i = LazyTensor(P1["xyz"][:, None, :].contiguous()) 145 | xyz2_j = LazyTensor(P2["xyz"][None, :, :].contiguous()) 146 | 147 | xyz_dists = ((xyz1_i - xyz2_j) ** 2).sum(-1).sqrt() 148 | xyz_dists = (1.0 - xyz_dists).step() 149 | 150 | p1_iface_labels = (xyz_dists.sum(1) > 1.0).float().view(-1) 151 | p2_iface_labels = (xyz_dists.sum(0) > 1.0).float().view(-1) 152 | 153 | P1["labels"] = p1_iface_labels 154 | P2["labels"] = p2_iface_labels 155 | 156 | 157 | def compute_loss(args, P1, P2, n_points_sample=16): 158 | 159 | if args.search: 160 | pos_xyz1 = P1["xyz"][P1["labels"] == 1] 161 | pos_xyz2 = P2["xyz"][P2["labels"] == 1] 162 | pos_descs1 = P1["embedding_1"][P1["labels"] == 1] 163 | pos_descs2 = P2["embedding_2"][P2["labels"] == 1] 164 | 165 | pos_xyz_dists = ( 166 | ((pos_xyz1[:, None, :] - pos_xyz2[None, :, :]) ** 2).sum(-1).sqrt() 167 | ) 168 | pos_desc_dists = torch.matmul(pos_descs1, pos_descs2.T) 169 | 170 | pos_preds = pos_desc_dists[pos_xyz_dists < 1.0] 171 | pos_labels = torch.ones_like(pos_preds) 172 | 173 | n_desc_sample = 100 
174 | sample_desc2 = torch.randperm(len(P2["embedding_2"]))[:n_desc_sample] 175 | sample_desc2 = P2["embedding_2"][sample_desc2] 176 | neg_preds = torch.matmul(pos_descs1, sample_desc2.T).view(-1) 177 | neg_labels = torch.zeros_like(neg_preds) 178 | 179 | # For symmetry 180 | pos_descs1_2 = P1["embedding_2"][P1["labels"] == 1] 181 | pos_descs2_2 = P2["embedding_1"][P2["labels"] == 1] 182 | 183 | pos_desc_dists2 = torch.matmul(pos_descs2_2, pos_descs1_2.T) 184 | pos_preds2 = pos_desc_dists2[pos_xyz_dists.T < 1.0] 185 | pos_preds = torch.cat([pos_preds, pos_preds2], dim=0) 186 | pos_labels = torch.ones_like(pos_preds) 187 | 188 | sample_desc1_2 = torch.randperm(len(P1["embedding_2"]))[:n_desc_sample] 189 | sample_desc1_2 = P1["embedding_2"][sample_desc1_2] 190 | neg_preds_2 = torch.matmul(pos_descs2_2, sample_desc1_2.T).view(-1) 191 | 192 | neg_preds = torch.cat([neg_preds, neg_preds_2], dim=0) 193 | neg_labels = torch.zeros_like(neg_preds) 194 | 195 | else: 196 | pos_preds = P1["iface_preds"][P1["labels"] == 1] 197 | pos_labels = P1["labels"][P1["labels"] == 1] 198 | neg_preds = P1["iface_preds"][P1["labels"] == 0] 199 | neg_labels = P1["labels"][P1["labels"] == 0] 200 | 201 | n_points_sample = len(pos_labels) 202 | pos_indices = torch.randperm(len(pos_labels))[:n_points_sample] 203 | neg_indices = torch.randperm(len(neg_labels))[:n_points_sample] 204 | 205 | pos_preds = pos_preds[pos_indices] 206 | pos_labels = pos_labels[pos_indices] 207 | neg_preds = neg_preds[neg_indices] 208 | neg_labels = neg_labels[neg_indices] 209 | 210 | preds_concat = torch.cat([pos_preds, neg_preds]) 211 | labels_concat = torch.cat([pos_labels, neg_labels]) 212 | 213 | loss = F.binary_cross_entropy_with_logits(preds_concat, labels_concat) 214 | 215 | return loss, preds_concat, labels_concat 216 | 217 | 218 | def extract_single(P_batch, number): 219 | P = {} # First and second proteins 220 | batch = P_batch["batch"] == number 221 | batch_atoms = P_batch["batch_atoms"] == number 222 | 223 | 
with_mesh = P_batch["labels"] is not None 224 | # Ground truth labels are available on mesh vertices: 225 | P["labels"] = P_batch["labels"][batch] if with_mesh else None 226 | 227 | P["batch"] = P_batch["batch"][batch] 228 | 229 | # Surface information: 230 | P["xyz"] = P_batch["xyz"][batch] 231 | P["normals"] = P_batch["normals"][batch] 232 | 233 | # Atom information: 234 | P["atoms"] = P_batch["atoms"][batch_atoms] 235 | P["batch_atoms"] = P_batch["batch_atoms"][batch_atoms] 236 | 237 | # Chemical features: atom coordinates and types. 238 | P["atom_xyz"] = P_batch["atom_xyz"][batch_atoms] 239 | P["atomtypes"] = P_batch["atomtypes"][batch_atoms] 240 | 241 | return P 242 | 243 | 244 | def iterate( 245 | net, 246 | dataset, 247 | optimizer, 248 | args, 249 | test=False, 250 | save_path=None, 251 | pdb_ids=None, 252 | summary_writer=None, 253 | epoch_number=None, 254 | ): 255 | """Goes through one epoch of the dataset, returns information for Tensorboard.""" 256 | 257 | if test: 258 | net.eval() 259 | torch.set_grad_enabled(False) 260 | else: 261 | net.train() 262 | torch.set_grad_enabled(True) 263 | 264 | # Statistics and fancy graphs to summarize the epoch: 265 | info = [] 266 | total_processed_pairs = 0 267 | # Loop over one epoch: 268 | for it, protein_pair in enumerate( 269 | tqdm(dataset) 270 | ): # , desc="Test " if test else "Train")): 271 | protein_batch_size = protein_pair.atom_coords_p1_batch[-1].item() + 1 272 | if save_path is not None: 273 | batch_ids = pdb_ids[ 274 | total_processed_pairs : total_processed_pairs + protein_batch_size 275 | ] 276 | total_processed_pairs += protein_batch_size 277 | 278 | protein_pair.to(args.device) 279 | 280 | if not test: 281 | optimizer.zero_grad() 282 | 283 | # Generate the surface: 284 | torch.cuda.synchronize() 285 | surface_time = time.time() 286 | P1_batch, P2_batch = process(args, protein_pair, net) 287 | torch.cuda.synchronize() 288 | surface_time = time.time() - surface_time 289 | 290 | for protein_it in 
range(protein_batch_size): 291 | torch.cuda.synchronize() 292 | iteration_time = time.time() 293 | 294 | P1 = extract_single(P1_batch, protein_it) 295 | P2 = None if args.single_protein else extract_single(P2_batch, protein_it) 296 | 297 | 298 | if args.random_rotation: 299 | P1["rand_rot"] = protein_pair.rand_rot1.view(-1, 3, 3)[0] 300 | P1["atom_center"] = protein_pair.atom_center1.view(-1, 1, 3)[0] 301 | P1["xyz"] = P1["xyz"] - P1["atom_center"] 302 | P1["xyz"] = ( 303 | torch.matmul(P1["rand_rot"], P1["xyz"].T).T 304 | ).contiguous() 305 | P1["normals"] = ( 306 | torch.matmul(P1["rand_rot"], P1["normals"].T).T 307 | ).contiguous() 308 | if not args.single_protein: 309 | P2["rand_rot"] = protein_pair.rand_rot2.view(-1, 3, 3)[0] 310 | P2["atom_center"] = protein_pair.atom_center2.view(-1, 1, 3)[0] 311 | P2["xyz"] = P2["xyz"] - P2["atom_center"] 312 | P2["xyz"] = ( 313 | torch.matmul(P2["rand_rot"], P2["xyz"].T).T 314 | ).contiguous() 315 | P2["normals"] = ( 316 | torch.matmul(P2["rand_rot"], P2["normals"].T).T 317 | ).contiguous() 318 | else: 319 | P1["rand_rot"] = torch.eye(3, device=P1["xyz"].device) 320 | P1["atom_center"] = torch.zeros((1, 3), device=P1["xyz"].device) 321 | if not args.single_protein: 322 | P2["rand_rot"] = torch.eye(3, device=P2["xyz"].device) 323 | P2["atom_center"] = torch.zeros((1, 3), device=P2["xyz"].device) 324 | 325 | torch.cuda.synchronize() 326 | prediction_time = time.time() 327 | outputs = net(P1, P2) 328 | torch.cuda.synchronize() 329 | prediction_time = time.time() - prediction_time 330 | 331 | P1 = outputs["P1"] 332 | P2 = outputs["P2"] 333 | 334 | if args.search: 335 | generate_matchinglabels(args, P1, P2) 336 | 337 | if P1["labels"] is not None: 338 | loss, sampled_preds, sampled_labels = compute_loss(args, P1, P2) 339 | else: 340 | loss = torch.tensor(0.0) 341 | sampled_preds = None 342 | sampled_labels = None 343 | 344 | # Compute the gradient, update the model weights: 345 | if not test: 346 | torch.cuda.synchronize() 347 
| back_time = time.time() 348 | loss.backward() 349 | optimizer.step() 350 | torch.cuda.synchronize() 351 | back_time = time.time() - back_time 352 | 353 | if it == protein_it == 0 and not test: 354 | for para_it, parameter in enumerate(net.atomnet.parameters()): 355 | if parameter.requires_grad: 356 | summary_writer.add_histogram( 357 | f"Gradients/Atomnet/para_{para_it}_{parameter.shape}", 358 | parameter.grad.view(-1), 359 | epoch_number, 360 | ) 361 | for para_it, parameter in enumerate(net.conv.parameters()): 362 | if parameter.requires_grad: 363 | summary_writer.add_histogram( 364 | f"Gradients/Conv/para_{para_it}_{parameter.shape}", 365 | parameter.grad.view(-1), 366 | epoch_number, 367 | ) 368 | 369 | for d, features in enumerate(P1["input_features"].T): 370 | summary_writer.add_histogram(f"Input features/{d}", features) 371 | 372 | if save_path is not None: 373 | save_protein_batch_single( 374 | batch_ids[protein_it], P1, save_path, pdb_idx=1 375 | ) 376 | if not args.single_protein: 377 | save_protein_batch_single( 378 | batch_ids[protein_it], P2, save_path, pdb_idx=2 379 | ) 380 | 381 | try: 382 | if sampled_labels is not None: 383 | roc_auc = roc_auc_score( 384 | np.rint(numpy(sampled_labels.view(-1))), 385 | numpy(sampled_preds.view(-1)), 386 | ) 387 | else: 388 | roc_auc = 0.0 389 | except Exception as e: 390 | print("Problem with computing roc-auc") 391 | print(e) 392 | continue 393 | 394 | R_values = outputs["R_values"] 395 | 396 | info.append( 397 | dict( 398 | { 399 | "Loss": loss.item(), 400 | "ROC-AUC": roc_auc, 401 | "conv_time": outputs["conv_time"], 402 | "memory_usage": outputs["memory_usage"], 403 | }, 404 | # Merge the "R_values" dict into "info", with a prefix: 405 | **{"R_values/" + k: v for k, v in R_values.items()}, 406 | ) 407 | ) 408 | torch.cuda.synchronize() 409 | iteration_time = time.time() - iteration_time 410 | 411 | # Turn a list of dicts into a dict of lists: 412 | newdict = {} 413 | for k, v in [(key, d[key]) for d in info 
for key in d]: 414 | if k not in newdict: 415 | newdict[k] = [v] 416 | else: 417 | newdict[k].append(v) 418 | info = newdict 419 | 420 | # Final post-processing: 421 | return info 422 | 423 | def iterate_surface_precompute(dataset, net, args): 424 | processed_dataset = [] 425 | for it, protein_pair in enumerate(tqdm(dataset)): 426 | protein_pair.to(args.device) 427 | P1, P2 = process(args, protein_pair, net) 428 | if args.random_rotation: 429 | P1["rand_rot"] = protein_pair.rand_rot1 430 | P1["atom_center"] = protein_pair.atom_center1 431 | P1["xyz"] = ( 432 | torch.matmul(P1["rand_rot"].T, P1["xyz"].T).T + P1["atom_center"] 433 | ) 434 | P1["normals"] = torch.matmul(P1["rand_rot"].T, P1["normals"].T).T 435 | if not args.single_protein: 436 | P2["rand_rot"] = protein_pair.rand_rot2 437 | P2["atom_center"] = protein_pair.atom_center2 438 | P2["xyz"] = ( 439 | torch.matmul(P2["rand_rot"].T, P2["xyz"].T).T + P2["atom_center"] 440 | ) 441 | P2["normals"] = torch.matmul(P2["rand_rot"].T, P2["normals"].T).T 442 | protein_pair = protein_pair.to_data_list()[0] 443 | protein_pair.gen_xyz_p1 = P1["xyz"] 444 | protein_pair.gen_normals_p1 = P1["normals"] 445 | protein_pair.gen_batch_p1 = P1["batch"] 446 | protein_pair.gen_labels_p1 = P1["labels"] 447 | protein_pair.gen_xyz_p2 = P2["xyz"] 448 | protein_pair.gen_normals_p2 = P2["normals"] 449 | protein_pair.gen_batch_p2 = P2["batch"] 450 | protein_pair.gen_labels_p2 = P2["labels"] 451 | processed_dataset.append(protein_pair.to("cpu")) 452 | return processed_dataset 453 | -------------------------------------------------------------------------------- /data_preprocessing/convert_pdb2npy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | from Bio.PDB import * 5 | 6 | ele2num = {"C": 0, "H": 1, "O": 2, "N": 3, "S": 4, "SE": 5} 7 | 8 | 9 | def load_structure_np(fname, center): 10 | """Loads a .ply mesh to return a 
point cloud and connectivity.""" 11 | # Load the data 12 | parser = PDBParser() 13 | structure = parser.get_structure("structure", fname) 14 | atoms = structure.get_atoms() 15 | 16 | coords = [] 17 | types = [] 18 | for atom in atoms: 19 | coords.append(atom.get_coord()) 20 | types.append(ele2num[atom.element]) 21 | 22 | coords = np.stack(coords) 23 | types_array = np.zeros((len(types), len(ele2num))) 24 | for i, t in enumerate(types): 25 | types_array[i, t] = 1.0 26 | 27 | # Normalize the coordinates, as specified by the user: 28 | if center: 29 | coords = coords - np.mean(coords, axis=0, keepdims=True) 30 | 31 | return {"xyz": coords, "types": types_array} 32 | 33 | 34 | def convert_pdbs(pdb_dir, npy_dir): 35 | print("Converting PDBs") 36 | for p in tqdm(pdb_dir.glob("*.pdb")): 37 | protein = load_structure_np(p, center=False) 38 | np.save(npy_dir / (p.stem + "_atomxyz.npy"), protein["xyz"]) 39 | np.save(npy_dir / (p.stem + "_atomtypes.npy"), protein["types"]) 40 | -------------------------------------------------------------------------------- /data_preprocessing/convert_ply2npy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | from plyfile import PlyData, PlyElement 5 | 6 | 7 | def load_surface_np(fname, center): 8 | """Loads a .ply mesh to return a point cloud and connectivity.""" 9 | 10 | # Load the data, and read the connectivity information: 11 | plydata = PlyData.read(str(fname)) 12 | triangles = np.vstack(plydata["face"].data["vertex_indices"]) 13 | 14 | # Normalize the point cloud, as specified by the user: 15 | points = np.vstack([[v[0], v[1], v[2]] for v in plydata["vertex"]]) 16 | if center: 17 | points = points - np.mean(points, axis=0, keepdims=True) 18 | 19 | nx = plydata["vertex"]["nx"] 20 | ny = plydata["vertex"]["ny"] 21 | nz = plydata["vertex"]["nz"] 22 | normals = np.stack([nx, ny, nz]).T 23 | 24 | # Interface labels 25 | 
iface_labels = plydata["vertex"]["iface"] 26 | 27 | # Features 28 | charge = plydata["vertex"]["charge"] 29 | hbond = plydata["vertex"]["hbond"] 30 | hphob = plydata["vertex"]["hphob"] 31 | features = np.stack([charge, hbond, hphob]).T 32 | 33 | return { 34 | "xyz": points, 35 | "triangles": triangles, 36 | "features": features, 37 | "iface_labels": iface_labels, 38 | "normals": normals, 39 | } 40 | 41 | 42 | def convert_plys(ply_dir, npy_dir): 43 | print("Converting PLYs") 44 | for p in tqdm(ply_dir.glob("*.ply")): 45 | protein = load_surface_np(p, center=False) 46 | np.save(npy_dir / (p.stem + "_xyz.npy"), protein["xyz"]) 47 | np.save(npy_dir / (p.stem + "_triangles.npy"), protein["triangles"]) 48 | np.save(npy_dir / (p.stem + "_features.npy"), protein["features"]) 49 | np.save(npy_dir / (p.stem + "_iface_labels.npy"), protein["iface_labels"]) 50 | np.save(npy_dir / (p.stem + "_normals.npy"), protein["normals"]) 51 | 52 | -------------------------------------------------------------------------------- /data_preprocessing/download_pdb.py: -------------------------------------------------------------------------------- 1 | import Bio 2 | from Bio.PDB import * 3 | from Bio.SeqUtils import IUPACData 4 | import sys 5 | import importlib 6 | import os 7 | import numpy as np 8 | from subprocess import Popen, PIPE 9 | from pathlib import Path 10 | from convert_pdb2npy import load_structure_np 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser(description="Arguments") 14 | parser.add_argument( 15 | "--pdb", type=str,default='', help="PDB code along with chains to extract, example 1ABC_A_B", required=False 16 | ) 17 | parser.add_argument( 18 | "--pdb_list", type=str,default='', help="Path to a text file that includes a list of PDB codes along with chains, example 1ABC_A_B", required=False 19 | ) 20 | 21 | tmp_dir = Path('./tmp') 22 | pdb_dir = Path('./pdbs') 23 | npy_dir = Path('./npys') 24 | 25 | PROTEIN_LETTERS = [x.upper() for x in 
# Upper-case three-letter codes of the 20 standard amino acids:
PROTEIN_LETTERS = [x.upper() for x in IUPACData.protein_letters_3to1.keys()]


# Exclude disordered atoms.
class NotDisordered(Select):
    """Bio.PDB `Select` that keeps a single conformation per atom.

    Accepts atoms that are not disordered, plus the "A"/"1" alternate
    location of disordered ones, so the saved PDB has one position per atom.
    """

    def accept_atom(self, atom):
        return (
            not atom.is_disordered()
            or atom.get_altloc() == "A"
            or atom.get_altloc() == "1"
        )


def find_modified_amino_acids(path):
    """Returns the set of modified amino-acid codes found in a PDB file.

    Contributed by github user jomimc - find modified amino acids in the PDB
    (e.g. MSE). Scans the SEQRES records and keeps every residue name that is
    not one of the 20 standard three-letter codes.

    Args:
        path (str or Path): path to the PDB file.

    Returns:
        set: names (str) of the non-standard residues.
    """
    res_set = set()
    # Context manager so the file handle is closed even on error
    # (the original iterated over a bare `open(path)` and never closed it).
    with open(path, "r") as f:
        for line in f:
            if line[:6] == "SEQRES":
                for res in line.split()[4:]:
                    res_set.add(res)
    # Keep only the modified (non-standard) residue names:
    return {res for res in res_set if res not in PROTEIN_LETTERS}


def extractPDB(infilename, outfilename, chain_ids=None):
    """Extracts the requested chains from a PDB file and saves them.

    Builds a new structure containing, for each selected chain, the standard
    residues plus any HETATM residues whose code appears in the SEQRES
    records (modified amino acids such as MSE), then writes it with a single
    altloc per atom.

    Args:
        infilename (str): input PDB file path.
        outfilename (str): output PDB file path.
        chain_ids (iterable of str, optional): chains to keep; if None,
            every chain of the first model is kept.
    """
    parser = PDBParser(QUIET=True)
    struct = parser.get_structure(infilename, infilename)
    # Only the first model is processed:
    model = Selection.unfold_entities(struct, "M")[0]

    # Select residues to extract and build new structure
    structBuild = StructureBuilder.StructureBuilder()
    structBuild.init_structure("output")
    structBuild.init_seg(" ")
    structBuild.init_model(0)
    outputStruct = structBuild.get_structure()

    # Load a list of non-standard amino acid names -- these are
    # typically listed under HETATM, so they would be typically
    # ignored by the original algorithm
    modified_amino_acids = find_modified_amino_acids(infilename)

    for chain in model:
        # `is None` instead of `== None` (identity test for the sentinel):
        if chain_ids is None or chain.get_id() in chain_ids:
            structBuild.init_chain(chain.get_id())
            for residue in chain:
                het = residue.get_id()
                if het[0] == " ":
                    # Standard (ATOM) residue:
                    outputStruct[0][chain.get_id()].add(residue)
                elif het[0][-3:] in modified_amino_acids:
                    # HETATM residue that is a modified amino acid:
                    outputStruct[0][chain.get_id()].add(residue)

    # Output the selected residues
    pdbio = PDBIO()
    pdbio.set_structure(outputStruct)
    pdbio.save(outfilename, select=NotDisordered())


def protonate(in_pdb_file, out_pdb_file):
    """Protonates (adds hydrogens to) a PDB file using the `reduce` tool.

    Args:
        in_pdb_file: file to protonate.
        out_pdb_file: output file where to save the protonated pdb file.
    """
    # Remove protons first, in case the structure is already protonated
    args = ["reduce", "-Trim", in_pdb_file]
    p2 = Popen(args, stdout=PIPE, stderr=PIPE)
    stdout, stderr = p2.communicate()
    # `with` guarantees the intermediate file is flushed/closed before
    # the second `reduce` pass reads it back.
    with open(out_pdb_file, "w") as outfile:
        outfile.write(stdout.decode('utf-8').rstrip())
    # Now add them again.
    args = ["reduce", "-HIS", out_pdb_file]
    p2 = Popen(args, stdout=PIPE, stderr=PIPE)
    stdout, stderr = p2.communicate()
    with open(out_pdb_file, "w") as outfile:
        outfile.write(stdout.decode('utf-8'))
def get_single(pdb_id: str,chains: list):
    """Downloads, protonates and converts one PDB entry.

    Downloads `pdb_id` if "<pdb_id>.pdb" is not already in `pdb_dir`,
    protonates it with `reduce`, then extracts each requested chain and
    saves its atom coordinates/types as .npy files in `npy_dir`.

    Args:
        pdb_id (str): four-letter PDB code.
        chains (list): chain identifiers to extract, e.g. ["A", "B"].
    """
    protonated_file = pdb_dir/f"{pdb_id}.pdb"
    if not protonated_file.exists():
        # Download pdb
        pdbl = PDBList()
        pdb_filename = pdbl.retrieve_pdb_file(pdb_id, pdir=tmp_dir,file_format='pdb')

        ##### Protonate with reduce, if hydrogens included.
        # - Always protonate as this is useful for charges. If necessary ignore hydrogens later.
        protonate(pdb_filename, protonated_file)

    pdb_filename = protonated_file

    # Extract chains of interest.
    for chain in chains:
        out_filename = pdb_dir/f"{pdb_id}_{chain}.pdb"
        # A single-character chain id works as the `chain_ids` membership
        # container expected by extractPDB.
        extractPDB(pdb_filename, str(out_filename), chain)
        protein = load_structure_np(out_filename,center=False)
        np.save(npy_dir / f"{pdb_id}_{chain}_atomxyz", protein["xyz"])
        np.save(npy_dir / f"{pdb_id}_{chain}_atomtypes", protein["types"])

if __name__ == '__main__':
    args = parser.parse_args()
    # Either a single "CODE_CHAIN1_CHAIN2" argument, or a file with one
    # such entry per line; everything after the first "_" is a chain id.
    if args.pdb != '':
        pdb_id = args.pdb.split('_')
        chains = pdb_id[1:]
        pdb_id = pdb_id[0]
        get_single(pdb_id,chains)

    elif args.pdb_list != '':
        with open(args.pdb_list) as f:
            pdb_list = f.read().splitlines()
        for pdb_id in pdb_list:
            pdb_id = pdb_id.split('_')
            chains = pdb_id[1:]
            pdb_id = pdb_id[0]
            get_single(pdb_id,chains)
    else:
        raise ValueError('Must specify PDB or PDB list')

# ---- file boundary in this dump: geometry_processing.py ----
import numpy as np
from math import pi
import torch
from pykeops.torch import LazyTensor
from plyfile import PlyData, PlyElement
from helper import *
import torch.nn as nn
import torch.nn.functional as F

from pykeops.torch.cluster import grid_cluster, cluster_ranges_centroids, from_matrix
from math import pi, sqrt


# Input-Output for tests =======================================================

import os
from pyvtk import PolyData, PointData, CellData, Scalars, Vectors, VtkData, PointData


def save_vtk(
    fname, xyz, triangles=None, values=None, vectors=None, triangle_values=None
):
    """Saves a point cloud or triangle mesh as a .vtk file.

    Files can be opened with Paraview or displayed using the PyVista library.

    Args:
        fname (string): filename.
        xyz (Tensor): (N,3) point cloud or vertices.
        triangles (integer Tensor, optional): (T,3) mesh connectivity. Defaults to None.
        values (Tensor, optional): (N,D) values, supported by the vertices. Defaults to None.
        vectors (Tensor, optional): (N,3) vectors, supported by the vertices. Defaults to None.
        triangle_values (Tensor, optional): (T,D) values, supported by the triangles. Defaults to None.
    """

    # Encode the points/vertices as a VTK structure.
    # NOTE(review): `numpy(...)` comes from the star-import of helper.py —
    # presumably a Tensor -> ndarray conversion; confirm against helper.py.
    if triangles is None:  # Point cloud
        structure = PolyData(points=numpy(xyz), vertices=np.arange(len(xyz)))
    else:  # Surface mesh
        structure = PolyData(points=numpy(xyz), polygons=numpy(triangles))

    data = [structure]
    pointdata, celldata = [], []

    # Point values - one channel per column of the `values` array:
    if values is not None:
        values = numpy(values)
        if len(values.shape) == 1:
            values = values[:, None]
        features = values.T
        pointdata += [
            Scalars(f, name=f"features_{i:02d}") for i, f in enumerate(features)
        ]

    # Point vectors - one vector per point:
    if vectors is not None:
        pointdata += [Vectors(numpy(vectors), name="vectors")]

    # Store in the VTK object:
    if pointdata != []:
        pointdata = PointData(*pointdata)
        data.append(pointdata)

    # Triangle values - one channel per column of the `triangle_values` array:
    if triangle_values is not None:
        triangle_values = numpy(triangle_values)
        if len(triangle_values.shape) == 1:
            triangle_values = triangle_values[:, None]
        features = triangle_values.T
        celldata += [
            Scalars(f, name=f"features_{i:02d}") for i, f in enumerate(features)
        ]

        celldata = CellData(*celldata)
        data.append(celldata)

    # Write to hard drive:
    vtk = VtkData(*data)
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    vtk.tofile(fname)
# On-the-fly generation of the surfaces ========================================


def subsample(x, batch=None, scale=1.0):
    """Subsamples the point cloud using a grid (cubic) clustering scheme.

    The function returns one average sample per cell, as described in Fig. 3.e)
    of the paper.

    Args:
        x (Tensor): (N,3) point cloud.
        batch (integer Tensor, optional): (N,) batch vector, as in PyTorch_geometric.
            Defaults to None.
        scale (float, optional): side length of the cubic grid cells. Defaults to 1 (Angstrom).

    Returns:
        (M,3): sub-sampled point cloud, with M <= N.
        If `batch` is given, also returns the (M,) batch vector of the samples.
    """

    if batch is None:  # Single protein case:
        # Fast scatter_add_ implementation.
        # (A dead `else:` branch referencing undefined names `scatter` and
        # `weights` behind an `if True:` guard has been removed.)
        labels = grid_cluster(x, scale).long()
        C = labels.max() + 1

        # We append a "1" to the input vectors, in order to
        # compute both the numerator and denominator of the "average"
        # fraction in one pass through the data.
        x_1 = torch.cat((x, torch.ones_like(x[:, :1])), dim=1)
        D = x_1.shape[1]
        points = torch.zeros_like(x_1[:C])
        points.scatter_add_(0, labels[:, None].repeat(1, D), x_1)
        # Divide the summed coordinates by the per-cell counts:
        return (points[:, :-1] / points[:, -1:]).contiguous()

    else:  # We process proteins using a for loop.
        # This is probably sub-optimal, but I don't really know
        # how to do more elegantly (this type of computation is
        # not super well supported by PyTorch).
        batch_size = torch.max(batch).item() + 1  # Typically, =32
        points, batches = [], []
        for b in range(batch_size):
            p = subsample(x[batch == b], scale=scale)
            points.append(p)
            batches.append(b * torch.ones_like(batch[: len(p)]))

        return torch.cat(points, dim=0), torch.cat(batches, dim=0)
151 | """ 152 | # Build the (N, M, 1) symbolic matrix of squared distances: 153 | x_i = LazyTensor(x[:, None, :]) # (N, 1, 3) atoms 154 | y_j = LazyTensor(y[None, :, :]) # (1, M, 3) sampling points 155 | D_ij = ((x_i - y_j) ** 2).sum(-1) # (N, M, 1) squared distances 156 | 157 | # Use a block-diagonal sparsity mask to support heterogeneous batch processing: 158 | D_ij.ranges = diagonal_ranges(batch_x, batch_y) 159 | 160 | if atomtypes is not None: 161 | # Turn the one-hot encoding "atomtypes" into a vector of diameters "smoothness_i": 162 | # (N, 6) -> (N, 1, 1) (There are 6 atom types) 163 | atomic_radii = torch.cuda.FloatTensor( 164 | [170, 110, 152, 155, 180, 190], device=x.device 165 | ) 166 | atomic_radii = atomic_radii / atomic_radii.min() 167 | atomtype_radii = atomtypes * atomic_radii[None, :] # n_atoms, n_atomtypes 168 | # smoothness = atomtypes @ atomic_radii # (N, 6) @ (6,) = (N,) 169 | smoothness = torch.sum( 170 | smoothness * atomtype_radii, dim=1, keepdim=False 171 | ) # n_atoms, 1 172 | smoothness_i = LazyTensor(smoothness[:, None, None]) 173 | 174 | # Compute an estimation of the mean smoothness in a neighborhood 175 | # of each sampling point: 176 | # density = (-D_ij.sqrt()).exp().sum(0).view(-1) # (M,) local density of atoms 177 | # smooth = (smoothness_i * (-D_ij.sqrt()).exp()).sum(0).view(-1) # (M,) 178 | # mean_smoothness = smooth / density # (M,) 179 | 180 | # soft_dists = -mean_smoothness * ( 181 | # (-D_ij.sqrt() / smoothness_i).logsumexp(dim=0) 182 | # ).view(-1) 183 | mean_smoothness = (-D_ij.sqrt()).exp().sum(0) 184 | mean_smoothness_j = LazyTensor(mean_smoothness[None, :, :]) 185 | mean_smoothness = ( 186 | smoothness_i * (-D_ij.sqrt()).exp() / mean_smoothness_j 187 | ) # n_atoms, n_points, 1 188 | mean_smoothness = mean_smoothness.sum(0).view(-1) 189 | soft_dists = -mean_smoothness * ( 190 | (-D_ij.sqrt() / smoothness_i).logsumexp(dim=0) 191 | ).view(-1) 192 | 193 | else: 194 | soft_dists = -smoothness * ((-D_ij.sqrt() / 
def atoms_to_points_normals(
    atoms,
    batch,
    distance=1.05,
    smoothness=0.5,
    resolution=1.0,
    nits=4,
    atomtypes=None,
    sup_sampling=20,
    variance=0.1,
):
    """Turns a collection of atoms into an oriented point cloud.

    Sampling algorithm for protein surfaces, described in Fig. 3 of the paper.

    Args:
        atoms (Tensor): (N,3) coordinates of the atom centers `a_k`.
        batch (integer Tensor): (N,) batch vector, as in PyTorch_geometric.
        distance (float, optional): value of the level set to sample from
            the smooth distance function. Defaults to 1.05.
        smoothness (float, optional): radii of the atoms, if atom types are
            not provided. Defaults to 0.5.
        resolution (float, optional): side length of the cubic cells in
            the final sub-sampling pass. Defaults to 1.0.
        nits (int, optional): number of iterations . Defaults to 4.
        atomtypes (Tensor, optional): (N,6) one-hot encoding of the atom
            chemical types. Defaults to None.
        sup_sampling (int, optional): number of random seed points drawn
            per atom. Defaults to 20.
        variance (float, optional): relative tolerance around the level set
            used to keep converged samples. Defaults to 0.1.

    Returns:
        (Tensor): (M,3) coordinates for the surface points `x_i`.
        (Tensor): (M,3) unit normals `n_i`.
        (integer Tensor): (M,) batch vector, as in PyTorch_geometric.
    """
    # a) Parameters for the soft distance function and its level set:
    T = distance

    N, D = atoms.shape
    B = sup_sampling  # Sup-sampling ratio

    # Batch vectors:
    batch_atoms = batch
    # Each atom spawns B seed points that inherit its batch index:
    batch_z = batch[:, None].repeat(1, B).view(N * B)

    # b) Draw N*B points at random in the neighborhood of our atoms
    z = atoms[:, None, :] + 10 * T * torch.randn(N, B, D).type_as(atoms)
    z = z.view(-1, D)  # (N*B, D)

    # We don't want to backprop through a full network here!
    atoms = atoms.detach().contiguous()
    z = z.detach().contiguous()

    # N.B.: Test mode disables the autograd engine: we must switch it on explicitely.
    with torch.enable_grad():
        if z.is_leaf:
            z.requires_grad = True

        # c) Iterative loop: gradient descent along the potential
        # ".5 * (dist - T)^2" with respect to the positions z of our samples
        for it in range(nits):
            dists = soft_distances(
                atoms,
                z,
                batch_atoms,
                batch_z,
                smoothness=smoothness,
                atomtypes=atomtypes,
            )
            Loss = ((dists - T) ** 2).sum()
            g = torch.autograd.grad(Loss, z)[0]
            # Fixed step of 0.5 on the quadratic potential:
            z.data -= 0.5 * g

        # d) Only keep the points which are reasonably close to the level set:
        dists = soft_distances(
            atoms, z, batch_atoms, batch_z, smoothness=smoothness, atomtypes=atomtypes
        )
        margin = (dists - T).abs()
        mask = margin < variance * T

        # d') And remove the points that are trapped *inside* the protein:
        # push copies of the samples outwards along the distance gradient;
        # points that started inside stay close to the atoms and get culled.
        zz = z.detach()
        zz.requires_grad = True
        for it in range(nits):
            dists = soft_distances(
                atoms,
                zz,
                batch_atoms,
                batch_z,
                smoothness=smoothness,
                atomtypes=atomtypes,
            )
            Loss = (1.0 * dists).sum()
            g = torch.autograd.grad(Loss, zz)[0]
            normals = F.normalize(g, p=2, dim=-1)  # (N, 3)
            zz = zz + 1.0 * T * normals

        dists = soft_distances(
            atoms, zz, batch_atoms, batch_z, smoothness=smoothness, atomtypes=atomtypes
        )
        mask = mask & (dists > 1.5 * T)

        z = z[mask].contiguous().detach()
        batch_z = batch_z[mask].contiguous().detach()

        # e) Subsample the point cloud:
        points, batch_points = subsample(z, batch_z, scale=resolution)

        # f) Compute the normals on this smaller point cloud:
        p = points.detach()
        p.requires_grad = True
        dists = soft_distances(
            atoms,
            p,
            batch_atoms,
            batch_points,
            smoothness=smoothness,
            atomtypes=atomtypes,
        )
        Loss = (1.0 * dists).sum()
        g = torch.autograd.grad(Loss, p)[0]
        normals = F.normalize(g, p=2, dim=-1)  # (N, 3)
        # Small inward offset along the normal before returning:
        points = points - 0.5 * normals
    return points.detach(), normals.detach(), batch_points.detach()
# Surface mesh -> Normals ======================================================


def mesh_normals_areas(vertices, triangles=None, scale=[1.0], batch=None, normals=None):
    """Returns a smooth field of normals, possibly at different scales.

    points, triangles or normals, scale(s) -> normals
    (N, 3), (3, T) or (N,3), (S,) -> (N, 3) or (N, S, 3)

    Simply put - if `triangles` are provided:
        1. Normals are first computed for every triangle using simple 3D geometry
           and are weighted according to surface area.
        2. The normal at any given vertex is then computed as the weighted average
           of the normals of all triangles in a neighborhood specified
           by Gaussian windows whose radii are given in the list of "scales".

    If `normals` are provided instead, we simply smooth the discrete vector
    field using Gaussian windows whose radii are given in the list of "scales".

    If more than one scale is provided, normal fields are computed in parallel
    and returned in a single 3D tensor.

    Args:
        vertices (Tensor): (N,3) coordinates of mesh vertices or 3D points.
        triangles (integer Tensor, optional): (3,T) mesh connectivity. Defaults to None.
        scale (list of floats, optional): (S,) radii of the Gaussian smoothing windows.
            Defaults to [1.]. NOTE(review): mutable default argument — harmless
            here since it is never mutated, but worth cleaning up.
        batch (integer Tensor, optional): batch vector, as in PyTorch_geometric. Defaults to None.
        normals (Tensor, optional): (N,3) raw normals vectors on the vertices. Defaults to None.

    Returns:
        (Tensor): (N,3) or (N,S,3) point normals.
        (Tensor): (N,) point areas, if triangles were provided (None otherwise).
    """

    # Single- or Multi-scale mode:
    if hasattr(scale, "__len__"):
        scales, single_scale = scale, False
    else:
        scales, single_scale = [scale], True
    scales = torch.Tensor(scales).type_as(vertices)  # (S,)

    # Compute the "raw" field of normals:
    if triangles is not None:
        # Vertices of all triangles in the mesh:
        A = vertices[triangles[0, :]]  # (N, 3)
        B = vertices[triangles[1, :]]  # (N, 3)
        C = vertices[triangles[2, :]]  # (N, 3)

        # Triangle centers and normals (length = surface area):
        centers = (A + B + C) / 3  # (N, 3)
        V = (B - A).cross(C - A)  # (N, 3)

        # Vertice areas: each triangle contributes a third of its area
        # (|cross| / 2 / 3 = |cross| / 6) to each of its three vertices.
        S = (V ** 2).sum(-1).sqrt() / 6  # (N,) 1/3 of a triangle area
        areas = torch.zeros(len(vertices)).type_as(vertices)  # (N,)
        areas.scatter_add_(0, triangles[0, :], S)  # Aggregate from "A's"
        areas.scatter_add_(0, triangles[1, :], S)  # Aggregate from "B's"
        areas.scatter_add_(0, triangles[2, :], S)  # Aggregate from "C's"

    else:  # Use "normals" instead
        areas = None
        V = normals
        centers = vertices

    # Normal of a vertex = average of all normals in a ball of size "scale":
    x_i = LazyTensor(vertices[:, None, :])  # (N, 1, 3)
    y_j = LazyTensor(centers[None, :, :])  # (1, M, 3)
    v_j = LazyTensor(V[None, :, :])  # (1, M, 3)
    s = LazyTensor(scales[None, None, :])  # (1, 1, S)

    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (N, M, 1)
    K_ij = (-D_ij / (2 * s ** 2)).exp()  # (N, M, S)

    # Support for heterogeneous batch processing:
    if batch is not None:
        batch_vertices = batch
        # Triangle centers inherit the batch index of their first vertex:
        batch_centers = batch[triangles[0, :]] if triangles is not None else batch
        K_ij.ranges = diagonal_ranges(batch_vertices, batch_centers)

    if single_scale:
        U = (K_ij * v_j).sum(dim=1)  # (N, 3)
    else:
        # One weighted sum per scale, flattened then reshaped:
        U = (K_ij.tensorprod(v_j)).sum(dim=1)  # (N, S*3)
        U = U.view(-1, len(scales), 3)  # (N, S, 3)

    normals = F.normalize(U, p=2, dim=-1)  # (N, 3) or (N, S, 3)

    return normals, areas
# Compute tangent planes and curvatures ========================================


def tangent_vectors(normals):
    """Returns a pair of vector fields u and v to complete the orthonormal basis [n,u,v].

    normals -> uv
    (N, 3) or (N, S, 3) -> (N, 2, 3) or (N, S, 2, 3)

    This routine assumes that the 3D "normal" vectors are normalized.
    It is based on the 2017 paper from Pixar, "Building an orthonormal basis, revisited".

    Args:
        normals (Tensor): (N,3) or (N,S,3) normals `n_i`, i.e. unit-norm 3D vectors.

    Returns:
        (Tensor): (N,2,3) or (N,S,2,3) unit vectors `u_i` and `v_i` to complete
            the tangent coordinate systems `[n_i,u_i,v_i]`.
    """
    # Split the normals into their three coordinates:
    nx, ny, nz = normals.unbind(dim=-1)

    # Branch-free sign of nz (maps nz == 0 to +1, unlike nz.sign()):
    sign = (2 * (nz >= 0)) - 1.0
    denom = -1 / (sign + nz)
    cross_term = nx * ny * denom

    # First and second tangent vectors, stacked along a new (..., 2, 3) axis:
    u = torch.stack(
        (1 + sign * nx * nx * denom, sign * cross_term, -sign * nx), dim=-1
    )
    v = torch.stack((cross_term, sign + ny * ny * denom, -ny), dim=-1)
    return torch.stack((u, v), dim=-2)
def curvatures(
    vertices, triangles=None, scales=[1.0], batch=None, normals=None, reg=0.01
):
    """Returns a collection of mean (H) and Gauss (K) curvatures at different scales.

    points, faces, scales -> (H_1, K_1, ..., H_S, K_S)
    (N, 3), (3, N), (S,) -> (N, S*2)

    We rely on a very simple linear regression method, for all vertices:

        1. Estimate normals and surface areas.
        2. Compute a local tangent frame.
        3. In a pseudo-geodesic Gaussian neighborhood at scale s,
           compute the two (2, 2) covariance matrices PPt and PQt
           between the displacement vectors "P = x_i - x_j" and
           the normals "Q = n_i - n_j", projected on the local tangent plane.
        4. Up to the sign, the shape operator S at scale s is then approximated
           as "S = (reg**2 * I_2 + PPt)^-1 @ PQt".
        5. The mean and Gauss curvatures are the trace and determinant of
           this (2, 2) matrix.

    As of today, this implementation does not weigh points by surface areas:
    this could make a sizeable difference if protein surfaces were not
    sub-sampled to ensure uniform sampling density.

    For convergence analysis, see for instance
    "Efficient curvature estimation for oriented point clouds",
    Cao, Li, Sun, Assadi, Zhang, 2019.

    Args:
        vertices (Tensor): (N,3) coordinates of the points or mesh vertices.
        triangles (integer Tensor, optional): (3,T) mesh connectivity. Defaults to None.
        scales (list of floats, optional): list of (S,) smoothing scales. Defaults to [1.].
        batch (integer Tensor, optional): batch vector, as in PyTorch_geometric. Defaults to None.
        normals (Tensor, optional): (N,3) field of "raw" unit normals. Defaults to None.
        reg (float, optional): small amount of Tikhonov/ridge regularization
            in the estimation of the shape operator. Defaults to .01.

    Returns:
        (Tensor): (N, S*2) tensor of mean and Gauss curvatures computed for
            every point at the required scales.
    """
    # Number of points, number of scales:
    N, S = vertices.shape[0], len(scales)
    ranges = diagonal_ranges(batch)

    # Compute the normals at different scales + vertice areas:
    normals_s, _ = mesh_normals_areas(
        vertices, triangles=triangles, normals=normals, scale=scales, batch=batch
    )  # (N, S, 3), (N,)

    # Local tangent bases:
    uv_s = tangent_vectors(normals_s)  # (N, S, 2, 3)

    features = []

    for s, scale in enumerate(scales):
        # Extract the relevant descriptors at the current scale:
        normals = normals_s[:, s, :].contiguous()  # (N, 3)
        uv = uv_s[:, s, :, :].contiguous()  # (N, 2, 3)

        # Encode as symbolic tensors:
        # Points:
        x_i = LazyTensor(vertices.view(N, 1, 3))
        x_j = LazyTensor(vertices.view(1, N, 3))
        # Normals:
        n_i = LazyTensor(normals.view(N, 1, 3))
        n_j = LazyTensor(normals.view(1, N, 3))
        # Tangent bases:
        uv_i = LazyTensor(uv.view(N, 1, 6))

        # Pseudo-geodesic squared distance:
        d2_ij = ((x_j - x_i) ** 2).sum(-1) * ((2 - (n_i | n_j)) ** 2)  # (N, N, 1)
        # Gaussian window:
        window_ij = (-d2_ij / (2 * (scale ** 2))).exp()  # (N, N, 1)

        # Project on the tangent plane:
        P_ij = uv_i.matvecmult(x_j - x_i)  # (N, N, 2)
        Q_ij = uv_i.matvecmult(n_j - n_i)  # (N, N, 2)
        # Concatenate:
        PQ_ij = P_ij.concat(Q_ij)  # (N, N, 2+2)

        # Covariances, with a scale-dependent weight:
        PPt_PQt_ij = P_ij.tensorprod(PQ_ij)  # (N, N, 2*(2+2))
        PPt_PQt_ij = window_ij * PPt_PQt_ij  # (N, N, 2*(2+2))

        # Reduction - with batch support:
        PPt_PQt_ij.ranges = ranges
        PPt_PQt = PPt_PQt_ij.sum(1)  # (N, 2*(2+2))

        # Reshape to get the two covariance matrices:
        PPt_PQt = PPt_PQt.view(N, 2, 2, 2)
        PPt, PQt = PPt_PQt[:, :, 0, :], PPt_PQt[:, :, 1, :]  # (N, 2, 2), (N, 2, 2)

        # Add a small ridge regression:
        PPt[:, 0, 0] += reg
        PPt[:, 1, 1] += reg

        # (minus) Shape operator, i.e. the differential of the Gauss map:
        # = (PPt^-1 @ PQt) : simple estimation through linear regression.
        # FIX: `torch.solve(PQt, PPt).solution` was deprecated in PyTorch 1.8
        # and removed in 1.13; `torch.linalg.solve(PPt, PQt)` solves the same
        # batched system "PPt @ X = PQt".
        S = torch.linalg.solve(PPt, PQt)
        a, b, c, d = S[:, 0, 0], S[:, 0, 1], S[:, 1, 0], S[:, 1, 1]  # (N,)

        # Normalization: trace = mean curvature, determinant = Gauss curvature.
        mean_curvature = a + d
        gauss_curvature = a * d - b * c
        features += [mean_curvature.clamp(-1, 1), gauss_curvature.clamp(-1, 1)]

    features = torch.stack(features, dim=-1)
    return features
# Fast tangent convolution layer ===============================================
class ContiguousBackward(torch.autograd.Function):
    """Identity operation that forces a contiguous gradient in the backward pass.

    To be applied after a PyKeOps reduction.
    N.B.: This workaround fixes a bug that will be fixed in ulterior KeOp releases.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Plain identity in the forward direction:
        return tensor

    @staticmethod
    def backward(ctx, grad):
        # Only the memory layout of the incoming gradient is changed:
        return grad.contiguous()
    def __init__(
        self, in_channels=1, out_channels=1, radius=1.0, hidden_units=None, cheap=False
    ):
        """Creates the KeOps convolution layer.

        I = in_channels is the dimension of the input features
        O = out_channels is the dimension of the output features
        H = hidden_units is the dimension of the intermediate representation
        radius is the size of the pseudo-geodesic Gaussian window w_ij = W(d_ij)


        This affordable layer implements an elementary "convolution" operator
        on a cloud of N points (x_i) in dimension 3 that we decompose in three steps:

          1. Apply the MLP "net_in" on the input features "f_i". (N, I) -> (N, H)

          2. Compute H interaction terms in parallel with:
                  f_i = sum_j [ w_ij * conv(P_ij) * f_j ]
             In the equation above:
               - w_ij is a pseudo-geodesic window with a set radius.
               - P_ij is a vector of dimension 3, equal to "x_j-x_i"
                 in the local oriented basis at x_i.
               - "conv" is an MLP from R^3 to R^H:
                 - with 1 linear layer if "cheap" is True;
                 - with 2 linear layers and C=8 intermediate "cuts" otherwise.
               - "*" is coordinate-wise product.
               - f_j is the vector of transformed features.

          3. Apply the MLP "net_out" on the output features. (N, H) -> (N, O)


        A more general layer would have implemented conv(P_ij) as a full
        (H, H) matrix instead of a mere (H,) vector... At a much higher
        computational cost. The reasoning behind the code below is that
        a given time budget is better spent on using a larger architecture
        and more channels than on a very complex convolution operator.
        Interactions between channels happen at steps 1. and 3.,
        whereas the (costly) point-to-point interaction step 2.
        lets the network aggregate information in spatial neighborhoods.

        Args:
            in_channels (int, optional): numper of input features per point. Defaults to 1.
            out_channels (int, optional): number of output features per point. Defaults to 1.
            radius (float, optional): deviation of the Gaussian window on the
                quasi-geodesic distance `d_ij`. Defaults to 1..
            hidden_units (int, optional): number of hidden features per point.
                Defaults to out_channels.
            cheap (bool, optional): shall we use a 1-layer deep Filter,
                instead of a 2-layer deep MLP? Defaults to False.

        Raises:
            ValueError: if the hidden dimension is not a multiple of the
                heads dimension (8, or the hidden dimension itself if smaller).
        """

        super(dMaSIFConv, self).__init__()

        self.Input = in_channels
        self.Output = out_channels
        self.Radius = radius
        self.Hidden = self.Output if hidden_units is None else hidden_units
        self.Cuts = 8  # Number of hidden units for the 3D MLP Filter.
        self.cheap = cheap

        # For performance reasons, we cut our "hidden" vectors
        # in n_heads "independent heads" of dimension 8.
        self.heads_dim = 8  # 4 is probably too small; 16 is certainly too big

        # We accept "Hidden" dimensions of size 1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, ...
        if self.Hidden < self.heads_dim:
            self.heads_dim = self.Hidden

        if self.Hidden % self.heads_dim != 0:
            # NOTE(review): the concatenated message is missing a space
            # between ")" and "should" — runtime string, left untouched here.
            raise ValueError(f"The dimension of the hidden units ({self.Hidden})"\
                + f"should be a multiple of the heads dimension ({self.heads_dim}).")
        else:
            self.n_heads = self.Hidden // self.heads_dim


        # Transformation of the input features:
        self.net_in = nn.Sequential(
            nn.Linear(self.Input, self.Hidden),  # (H, I) + (H,)
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(self.Hidden, self.Hidden),  # (H, H) + (H,)
            nn.LeakyReLU(negative_slope=0.2),
        )  # (H,)
        # Channel-wise normalization of the hidden features, 4 groups:
        self.norm_in = nn.GroupNorm(4, self.Hidden)

        # 3D convolution filters, encoded as an MLP:
        if cheap:
            self.conv = nn.Sequential(
                nn.Linear(3, self.Hidden), nn.ReLU()  # (H, 3) + (H,)
            )  # KeOps does not support well LeakyReLu
        else:
            self.conv = nn.Sequential(
                nn.Linear(3, self.Cuts),  # (C, 3) + (C,)
                nn.ReLU(),  # KeOps does not support well LeakyReLu
                nn.Linear(self.Cuts, self.Hidden),
            )  # (H, C) + (H,)

        # Transformation of the output features:
        self.net_out = nn.Sequential(
            nn.Linear(self.Hidden, self.Output),  # (O, H) + (O,)
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(self.Output, self.Output),  # (O, O) + (O,)
            nn.LeakyReLU(negative_slope=0.2),
        )  # (O,)

        self.norm_out = nn.GroupNorm(4, self.Output)

        # Custom initialization for the MLP convolution filters:
        # we get interesting piecewise affine cuts on a normalized neighborhood.
        with torch.no_grad():
            nn.init.normal_(self.conv[0].weight)
            nn.init.uniform_(self.conv[0].bias)
            # Scale each bias by the norm of the matching weight row:
            self.conv[0].bias *= 0.8 * (self.conv[0].weight ** 2).sum(-1).sqrt()

            if not cheap:
                nn.init.uniform_(
                    self.conv[2].weight,
                    a=-1 / np.sqrt(self.Cuts),
                    b=1 / np.sqrt(self.Cuts),
                )
                nn.init.normal_(self.conv[2].bias)
                self.conv[2].bias *= 0.5 * (self.conv[2].weight ** 2).sum(-1).sqrt()
719 | 720 | Returns: 721 | (Tensor): (N,O) output feature vectors `f'_i`. 722 | """ 723 | 724 | # 1. Transform the input features: ------------------------------------- 725 | features = self.net_in(features) # (N, I) -> (N, H) 726 | features = features.transpose(1, 0)[None, :, :] # (1,H,N) 727 | features = self.norm_in(features) 728 | features = features[0].transpose(1, 0).contiguous() # (1, H, N) -> (N, H) 729 | 730 | # 2. Compute the local "shape contexts": ------------------------------- 731 | 732 | # 2.a Normalize the kernel radius: 733 | points = points / (sqrt(2.0) * self.Radius) # (N, 3) 734 | 735 | # 2.b Encode the variables as KeOps LazyTensors 736 | 737 | # Vertices: 738 | x_i = LazyTensor(points[:, None, :]) # (N, 1, 3) 739 | x_j = LazyTensor(points[None, :, :]) # (1, N, 3) 740 | 741 | # WARNING - Here, we assume that the normals are fixed: 742 | normals = ( 743 | nuv[:, 0, :].contiguous().detach() 744 | ) # (N, 3) - remove the .detach() if needed 745 | 746 | # Local bases: 747 | nuv_i = LazyTensor(nuv.view(-1, 1, 9)) # (N, 1, 9) 748 | # Normals: 749 | n_i = nuv_i[:3] # (N, 1, 3) 750 | 751 | n_j = LazyTensor(normals[None, :, :]) # (1, N, 3) 752 | 753 | # To avoid register spilling when using large embeddings, we perform our KeOps reduction 754 | # over the vector of length "self.Hidden = self.n_heads * self.heads_dim" 755 | # as self.n_heads reduction over vectors of length self.heads_dim (= "Hd" in the comments). 
756 | head_out_features = [] 757 | for head in range(self.n_heads): 758 | 759 | # Extract a slice of width Hd from the feature array 760 | head_start = head * self.heads_dim 761 | head_end = head_start + self.heads_dim 762 | head_features = features[:, head_start:head_end].contiguous() # (N, H) -> (N, Hd) 763 | 764 | # Features: 765 | f_j = LazyTensor(head_features[None, :, :]) # (1, N, Hd) 766 | 767 | # Convolution parameters: 768 | if self.cheap: 769 | # Extract a slice of Hd lines: (H, 3) -> (Hd, 3) 770 | A = self.conv[0].weight[head_start:head_end, :].contiguous() 771 | # Extract a slice of Hd coefficients: (H,) -> (Hd,) 772 | B = self.conv[0].bias[head_start:head_end].contiguous() 773 | AB = torch.cat((A, B[:, None]), dim=1) # (Hd, 4) 774 | ab = LazyTensor(AB.view(1, 1, -1)) # (1, 1, Hd*4) 775 | else: 776 | A_1, B_1 = self.conv[0].weight, self.conv[0].bias # (C, 3), (C,) 777 | # Extract a slice of Hd lines: (H, C) -> (Hd, C) 778 | A_2 = self.conv[2].weight[head_start:head_end, :].contiguous() 779 | # Extract a slice of Hd coefficients: (H,) -> (Hd,) 780 | B_2 = self.conv[2].bias[head_start:head_end].contiguous() 781 | a_1 = LazyTensor(A_1.view(1, 1, -1)) # (1, 1, C*3) 782 | b_1 = LazyTensor(B_1.view(1, 1, -1)) # (1, 1, C) 783 | a_2 = LazyTensor(A_2.view(1, 1, -1)) # (1, 1, Hd*C) 784 | b_2 = LazyTensor(B_2.view(1, 1, -1)) # (1, 1, Hd) 785 | 786 | # 2.c Pseudo-geodesic window: 787 | # Pseudo-geodesic squared distance: 788 | d2_ij = ((x_j - x_i) ** 2).sum(-1) * ((2 - (n_i | n_j)) ** 2) # (N, N, 1) 789 | # Gaussian window: 790 | window_ij = (-d2_ij).exp() # (N, N, 1) 791 | 792 | # 2.d Local MLP: 793 | # Local coordinates: 794 | X_ij = nuv_i.matvecmult(x_j - x_i) # (N, N, 9) "@" (N, N, 3) = (N, N, 3) 795 | # MLP: 796 | if self.cheap: 797 | X_ij = ab.matvecmult( 798 | X_ij.concat(LazyTensor(1)) 799 | ) # (N, N, Hd*4) @ (N, N, 3+1) = (N, N, Hd) 800 | X_ij = X_ij.relu() # (N, N, Hd) 801 | else: 802 | X_ij = a_1.matvecmult(X_ij) + b_1 # (N, N, C) 803 | X_ij = 
X_ij.relu() # (N, N, C) 804 | X_ij = a_2.matvecmult(X_ij) + b_2 # (N, N, Hd) 805 | X_ij = X_ij.relu() 806 | 807 | # 2.e Actual computation: 808 | F_ij = window_ij * X_ij * f_j # (N, N, Hd) 809 | F_ij.ranges = ranges # Support for batches and/or block-sparsity 810 | 811 | head_out_features.append(ContiguousBackward().apply(F_ij.sum(dim=1))) # (N, Hd) 812 | 813 | # Concatenate the result of our n_heads "attention heads": 814 | features = torch.cat(head_out_features, dim=1) # n_heads * (N, Hd) -> (N, H) 815 | 816 | # 3. Transform the output features: ------------------------------------ 817 | features = self.net_out(features) # (N, H) -> (N, O) 818 | features = features.transpose(1, 0)[None, :, :] # (1,O,N) 819 | features = self.norm_out(features) 820 | features = features[0].transpose(1, 0).contiguous() 821 | 822 | return features 823 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | 3 | import numpy as np 4 | import torch 5 | from pykeops.torch import LazyTensor 6 | from plyfile import PlyData, PlyElement 7 | from pathlib import Path 8 | 9 | 10 | tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor 11 | inttensor = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor 12 | numpy = lambda x: x.detach().cpu().numpy() 13 | 14 | 15 | def ranges_slices(batch): 16 | """Helper function for the diagonal ranges function.""" 17 | Ns = batch.bincount() 18 | indices = Ns.cumsum(0) 19 | ranges = torch.cat((0 * indices[:1], indices)) 20 | ranges = ( 21 | torch.stack((ranges[:-1], ranges[1:])).t().int().contiguous().to(batch.device) 22 | ) 23 | slices = (1 + torch.arange(len(Ns))).int().to(batch.device) 24 | 25 | return ranges, slices 26 | 27 | 28 | def diagonal_ranges(batch_x=None, batch_y=None): 29 | """Encodes the block-diagonal structure associated to a batch vector.""" 30 | 31 | 
def soft_dimension(features):
    r"""Continuous approximation of the rank of a (N, D) sample.

    Let "s" denote the (D,) vector of eigenvalues of Cov,
    the (D, D) covariance matrix of the sample "features".
    Then,
        R(features) = \sum_i sqrt(s_i) / \max_i sqrt(s_i)

    This quantity encodes the number of PCA components that would be
    required to describe the sample with a good precision.
    It is equal to D if the sample is isotropic, but is generally much lower.

    Up to the re-normalization by the largest eigenvalue,
    this continuous pseudo-rank is equal to the nuclear norm of the sample.

    Args:
        features (Tensor): (..., D) sample, flattened internally to (N, D).

    Returns:
        float: the continuous pseudo-rank R, or -1 if the SVD of the
            covariance matrix fails to converge.
    """

    nfeat = features.shape[-1]
    features = features.view(-1, nfeat)
    # Center the sample before estimating the covariance matrix:
    x = features - torch.mean(features, dim=0, keepdim=True)
    cov = x.T @ x
    try:
        # torch.svd raises a RuntimeError when the iterative solver fails
        # to converge on ill-conditioned inputs. We only catch that case:
        # the previous bare `except:` also swallowed KeyboardInterrupt
        # and genuine programming errors.
        _, s, _ = torch.svd(cov)
        R = s.sqrt().sum() / s.sqrt().max()
    except RuntimeError:
        return -1
    return R.item()
46 | 1OMZ_B 47 | 1ONW_B 48 | 1P6X_A 49 | 1P9E_A 50 | 1PK8_A 51 | 1POI_A 52 | 1R0K_D 53 | 1R8J_A 54 | 1RQD_A 55 | 1RYI_C 56 | 1SOJ_J 57 | 1ST0_A 58 | 1SUW_D 59 | 1SXJ_D 60 | 1SXJ_E 61 | 1SZ2_B 62 | 1T8U_B 63 | 1TO6_A 64 | 1U0R_C 65 | 1U2V_C 66 | 1U6S_A 67 | 1UL1_X 68 | 1V7C_A 69 | 1VHK_C 70 | 1W1W_B 71 | 1W23_A 72 | 1WDW_H 73 | 1WKH_B 74 | 1WKV_B 75 | 1WPP_B 76 | 1WX1_B 77 | 1X7O_A 78 | 1X9J_G 79 | 1XF9_A 80 | 1XG2_A 81 | 1XI8_A 82 | 1XQS_A 83 | 1XXI_C 84 | 1XXI_E 85 | 1Y56_B 86 | 1Y8Q_A 87 | 1YKJ_A 88 | 1Z85_B 89 | 1ZCT_B 90 | 1ZH8_A 91 | 1ZHH_A 92 | 1ZXO_C 93 | 2AF4_C 94 | 2AUN_A 95 | 2AW6_A 96 | 2AYO_A 97 | 2B3Z_D 98 | 2BP7_B 99 | 2BWN_E 100 | 2C0L_A 101 | 2CDB_A 102 | 2CE8_A 103 | 2CH5_A 104 | 2CVO_B 105 | 2DG0_E 106 | 2E2P_A 107 | 2E5F_A 108 | 2E7J_B 109 | 2E89_A 110 | 2EG5_E 111 | 2EJW_A 112 | 2EP5_B 113 | 2F4M_A 114 | 2F4N_C 115 | 2FV2_C 116 | 2GSZ_F 117 | 2GVQ_C 118 | 2GZA_A 119 | 2HYX_C 120 | 2HZG_A 121 | 2HZK_D 122 | 2I3T_A 123 | 2I7N_B 124 | 2IJZ_G 125 | 2IP2_B 126 | 2J0Q_B 127 | 2J5T_G 128 | 2J6X_G 129 | 2MTA_HL 130 | 2NQL_A 131 | 2OBN_D 132 | 2OGJ_E 133 | 2OIZ_B 134 | 2OOR_AB 135 | 2OZK_B 136 | 2PBI_B 137 | 2PMS_A 138 | 2PP1_A 139 | 2PVP_A 140 | 2Q0J_A 141 | 2QDH_A 142 | 2QFC_A 143 | 2QGI_A 144 | 2QXV_A 145 | 2QYO_A 146 | 2R87_A 147 | 2R8Q_B 148 | 2UX8_C 149 | 2V3B_A 150 | 2V7X_C 151 | 2V9P_G 152 | 2V9T_B 153 | 2VCG_D 154 | 2VHI_C 155 | 2VHW_A 156 | 2VN8_A 157 | 2VUN_C 158 | 2WPX_B 159 | 2WUS_A 160 | 2WVM_A 161 | 2WYR_G 162 | 2X0D_A 163 | 2X2E_D 164 | 2X5D_D 165 | 2X65_B 166 | 2XT2_A 167 | 2XWT_C 168 | 2Y0M_A 169 | 2Y5B_A 170 | 2YCH_A 171 | 2Z4R_B 172 | 2Z50_A 173 | 2Z71_C 174 | 2Z9V_B 175 | 2ZBK_A 176 | 2ZIU_A 177 | 2ZIU_B 178 | 2ZUC_B 179 | 2ZZX_A 180 | 3AB1_B 181 | 3AEO_C 182 | 3AFO_A 183 | 3AP2_A 184 | 3AUY_A 185 | 3B5U_J 186 | 3BH6_B 187 | 3BM5_A 188 | 3BP8_AB 189 | 3BT1_U 190 | 3BTV_B 191 | 3BV6_D 192 | 3BWO_D 193 | 3C0B_C 194 | 3C0K_B 195 | 3C3J_F 196 | 3C48_A 197 | 3CE9_A 198 | 3CEA_B 199 | 3CQ6_A 200 | 3CYJ_B 201 | 3D6K_B 202 | 3DDM_B 203 | 
3DHW_D 204 | 3DP7_A 205 | 3DUG_C 206 | 3DZ2_A 207 | 3DZC_B 208 | 3E18_A 209 | 3E38_B 210 | 3E5P_C 211 | 3E9M_A 212 | 3EIQ_C 213 | 3EPW_B 214 | 3ES8_H 215 | 3EUA_H 216 | 3EZ6_B 217 | 3EZY_A 218 | 3FGT_B 219 | 3FHC_A 220 | 3FZ0_B 221 | 3G0I_A 222 | 3G8Q_C 223 | 3G8R_A 224 | 3GJZ_A 225 | 3GL1_A 226 | 3GZT_O 227 | 3H6G_B 228 | 3H77_B 229 | 3H8L_B 230 | 3H9G_A 231 | 3HE3_F 232 | 3HGU_A 233 | 3HLI_A 234 | 3HMK_B 235 | 3HPV_B 236 | 3HWS_D 237 | 3HXJ_A 238 | 3IAU_A 239 | 3IF8_A 240 | 3IGF_B 241 | 3IO1_A 242 | 3ISL_B 243 | 3IX1_B 244 | 3JSK_G 245 | 3JTX_A 246 | 3K5H_A 247 | 3KKI_A 248 | 3KL5_B 249 | 3KL9_J 250 | 3L31_A 251 | 3L9W_B 252 | 3LED_A 253 | 3LEE_A 254 | 3LJQ_C 255 | 3LKU_E 256 | 3LMA_C 257 | 3LOU_C 258 | 3LVK_AC 259 | 3M2T_A 260 | 3MCA_B 261 | 3MCZ_A 262 | 3MEN_B 263 | 3MGC_A 264 | 3MKR_A 265 | 3MMY_G 266 | 3MWE_B 267 | 3MZK_D 268 | 3N29_A 269 | 3N3D_A 270 | 3NND_C 271 | 3NTQ_A 272 | 3NVN_A 273 | 3NVV_B 274 | 3O3M_D 275 | 3O3P_A 276 | 3O5T_A 277 | 3OBK_F 278 | 3OM1_A 279 | 3ON5_A 280 | 3OOQ_H 281 | 3OQB_D 282 | 3OV3_B 283 | 3P6K_B 284 | 3P72_A 285 | 3P9I_C 286 | 3PG9_H 287 | 3PGA_1 288 | 3PND_D 289 | 3PNK_B 290 | 3PUZ_B 291 | 3PWS_A 292 | 3QBX_B 293 | 3QE9_Y 294 | 3QKW_C 295 | 3QML_D 296 | 3QW2_B 297 | 3R0Q_A 298 | 3R1X_A 299 | 3R5X_D 300 | 3R9A_AC 301 | 3RAM_D 302 | 3RCY_C 303 | 3RF6_B 304 | 3RFH_A 305 | 3RHF_B 306 | 3RMR_A 307 | 3RT0_A 308 | 3S5U_E 309 | 3SF5_D 310 | 3SJA_I 311 | 3SN6_A 312 | 3SSO_B 313 | 3SYL_A 314 | 3SZP_A 315 | 3T5P_A 316 | 3T8E_A 317 | 3THO_A 318 | 3THO_B 319 | 3TII_B 320 | 3TK1_B 321 | 3TQC_B 322 | 3TWO_B 323 | 3U5Z_B 324 | 3UGV_C 325 | 3UI2_A 326 | 3UK7_A 327 | 3UVN_C 328 | 3V5N_C 329 | 3VGK_E 330 | 3VH0_D 331 | 3VH3_A 332 | 3VV2_A 333 | 3VYR_B 334 | 3WN7_A 335 | 3ZWL_B 336 | 4BKX_B 337 | 4C9B_B 338 | 4DVG_B 339 | 4ETP_A 340 | 4ETP_B 341 | 4FZV_A 342 | 4HDO_A 343 | 4LVN_A 344 | 4M0W_A 345 | 4V0O_F 346 | 4X33_B 347 | 4XL5_C 348 | 4Y61_B 349 | 4YC7_B 350 | 4YEB_A 351 | 4ZGY_A 352 | 4ZRJ_A 353 | 5BV7_A 354 | 5CXB_B 355 | 5J57_A 356 | 5TIH_A 
357 | 5XIM_A 358 | 7MDH_B 359 | 4JLR_S 360 | -------------------------------------------------------------------------------- /lists/testing_ppi.txt: -------------------------------------------------------------------------------- 1 | 1A2K_C_AB 2 | 1A2W_A_B 3 | 1A79_C_B 4 | 1A99_C_D 5 | 1ACB_E_I 6 | 1AGQ_C_D 7 | 1AHS_C_B 8 | 1AK4_A_D 9 | 1AN1_E_I 10 | 1ARZ_A_C 11 | 1ATN_A_D 12 | 1AVX_A_B 13 | 1AY7_A_B 14 | 1B27_A_D 15 | 1B2S_A_D 16 | 1B2U_A_D 17 | 1B3S_A_D 18 | 1B3T_A_B 19 | 1B65_A_B 20 | 1B6C_A_B 21 | 1BJR_I_E 22 | 1BO4_A_B 23 | 1BPO_C_B 24 | 1BRS_A_D 25 | 1C3X_A_C 26 | 1C8N_A_C 27 | 1C9P_A_B 28 | 1C9S_G_H 29 | 1C9T_A_G 30 | 1CBW_ABC_D 31 | 1CGI_E_I 32 | 1CL7_I_H 33 | 1CQ3_A_B 34 | 1D6R_A_I 35 | 1DB2_A_B 36 | 1DEV_C_D 37 | 1DFJ_E_I 38 | 1DJS_A_B 39 | 1DLE_A_B 40 | 1DML_A_B 41 | 1DN2_B_F 42 | 1E5Q_C_D 43 | 1E8N_A_I 44 | 1EAI_A_C 45 | 1EAW_A_B 46 | 1EJA_A_B 47 | 1EM8_A_B 48 | 1EPT_A_B 49 | 1ERN_A_B 50 | 1EWJ_C_D 51 | 1EWY_A_C 52 | 1EZI_A_B 53 | 1EZU_C_AB 54 | 1F2U_C_D 55 | 1F37_A_B 56 | 1F45_A_B 57 | 1F5R_A_I 58 | 1F7Z_A_I 59 | 1F9S_A_D 60 | 1FCC_AB_C 61 | 1FFV_A_C 62 | 1FGL_A_B 63 | 1FIW_A_L 64 | 1FLE_E_I 65 | 1FU5_A_B 66 | 1FY8_E_I 67 | 1G31_C_B 68 | 1G60_A_B 69 | 1G9I_E_I 70 | 1GCQ_C_B 71 | 1GGP_A_B 72 | 1GL0_E_I 73 | 1GL1_A_I 74 | 1GO4_D_H 75 | 1GT7_O_P 76 | 1GUS_C_E 77 | 1GXD_A_C 78 | 1H1V_A_G 79 | 1H6D_I_J 80 | 1H9R_A_B 81 | 1HAA_A_B 82 | 1HBT_I_H 83 | 1HCF_AB_X 84 | 1HEZ_BA_E 85 | 1HIA_AB_I 86 | 1HPU_C_B 87 | 1HX5_A_B 88 | 1HYR_A_C 89 | 1I07_A_B 90 | 1I4E_A_B 91 | 1I4O_B_D 92 | 1I9C_C_D 93 | 1ICF_A_I 94 | 1ID5_H_L 95 | 1IGU_A_B 96 | 1IJX_B_E 97 | 1INN_A_B 98 | 1IYJ_C_D 99 | 1J3R_A_B 100 | 1J4U_A_B 101 | 1JBU_H_X 102 | 1JFM_E_D 103 | 1JIW_P_I 104 | 1JK9_B_A 105 | 1JK9_B_D 106 | 1JK9_C_D 107 | 1JKG_A_B 108 | 1JNP_A_B 109 | 1JTD_B_A 110 | 1JXQ_C_D 111 | 1JYI_A_P 112 | 1JZO_A_B 113 | 1K88_A_B 114 | 1KAC_A_B 115 | 1KCA_G_H 116 | 1KL8_A_B 117 | 1KXJ_A_B 118 | 1KXP_A_D 119 | 1L0A_A_B 120 | 1L2W_C_J 121 | 1L4D_A_B 122 | 1L4I_A_B 123 | 1L4Z_A_B 124 | 1LDT_T_L 125 | 
1LK6_I_C 126 | 1LQM_E_F 127 | 1LW6_E_I 128 | 1M1F_A_B 129 | 1MAS_A_B 130 | 1MBY_A_B 131 | 1MCV_A_I 132 | 1MK9_G_F 133 | 1ML0_AB_D 134 | 1MR1_A_D 135 | 1MZW_A_B 136 | 1N0L_C_D 137 | 1NB5_AP_I 138 | 1NP6_A_B 139 | 1NPO_A_C 140 | 1NQ9_I_L 141 | 1NQL_A_B 142 | 1NR7_A_E 143 | 1NR9_C_D 144 | 1NU9_A_C 145 | 1O9A_A_B 146 | 1O9Y_A_D 147 | 1OMO_A_B 148 | 1OS2_E_F 149 | 1OSM_C_B 150 | 1OX9_D_L 151 | 1OYV_B_I 152 | 1P3H_I_H 153 | 1P69_A_B 154 | 1P6A_A_B 155 | 1P9U_F_H 156 | 1PBI_A_B 157 | 1PGL_1_2 158 | 1POI_A_D 159 | 1PPE_E_I 160 | 1PVH_A_B 161 | 1PXV_A_C 162 | 1PXV_B_D 163 | 1Q1L_B_D 164 | 1Q5H_A_B 165 | 1Q8M_A_B 166 | 1Q9U_A_B 167 | 1QB3_A_C 168 | 1QI1_C_B 169 | 1QJS_A_B 170 | 1QOL_G_F 171 | 1R0K_C_D 172 | 1R0R_E_I 173 | 1R7A_A_B 174 | 1R9N_D_H 175 | 1RY7_A_B 176 | 1RZJ_C_G 177 | 1RZP_C_B 178 | 1S1Q_A_B 179 | 1S4C_C_B 180 | 1S98_A_B 181 | 1SCE_A_C 182 | 1SCE_B_D 183 | 1SHS_A_C 184 | 1SHS_A_E 185 | 1SHW_A_B 186 | 1SHY_A_B 187 | 1SMF_E_I 188 | 1SMO_A_B 189 | 1SOT_A_C 190 | 1SUW_C_D 191 | 1T0F_A_B 192 | 1T0H_A_B 193 | 1T0P_A_B 194 | 1T6B_X_Y 195 | 1T7P_A_B 196 | 1T8U_A_B 197 | 1TAW_A_B 198 | 1TE1_A_B 199 | 1TM1_E_I 200 | 1TM3_E_I 201 | 1TM4_E_I 202 | 1TM5_E_I 203 | 1TM7_E_I 204 | 1TO1_E_I 205 | 1TQ9_A_B 206 | 1TZI_BA_V 207 | 1TZS_A_P 208 | 1U20_A_B 209 | 1UAN_A_B 210 | 1UDI_E_I 211 | 1UE7_C_D 212 | 1UGH_E_I 213 | 1UHE_A_B 214 | 1UK4_A_G 215 | 1UM2_A_C 216 | 1UMF_B_D 217 | 1UNN_A_B 218 | 1UP6_A_D 219 | 1UUG_A_B 220 | 1UWG_L_X 221 | 1V8H_A_B 222 | 1VG9_E_G 223 | 1VGC_C_B 224 | 1VGO_A_B 225 | 1VH4_A_B 226 | 1VHJ_C_F 227 | 1VS3_A_B 228 | 1VZY_A_B 229 | 1W1I_B_E 230 | 1W4R_A_D 231 | 1WDX_C_B 232 | 1WLP_A_B 233 | 1WOQ_A_B 234 | 1WQJ_B_I 235 | 1WZ3_A_B 236 | 1X1U_A_D 237 | 1X1W_A_D 238 | 1X1X_A_D 239 | 1X1Y_A_D 240 | 1XD3_A_B 241 | 1XD3_C_D 242 | 1XDT_T_R 243 | 1XFS_A_B 244 | 1XG2_A_B 245 | 1XPJ_A_D 246 | 1XSQ_A_B 247 | 1XT9_A_B 248 | 1XUA_A_B 249 | 1XV2_A_B 250 | 1XWD_A_B 251 | 1Y07_A_B 252 | 1Y0G_A_B 253 | 1Y1K_E_I 254 | 1Y1O_A_B 255 | 1Y33_E_I 256 | 1Y34_E_I 257 | 1Y3B_E_I 258 | 
1Y3C_E_I 259 | 1Y3D_E_I 260 | 1Y43_A_B 261 | 1Y48_E_I 262 | 1Y4A_E_I 263 | 1Y4D_E_I 264 | 1Y96_A_B 265 | 1YBG_A_B 266 | 1YC0_A_I 267 | 1YC6_3_E 268 | 1YCS_A_B 269 | 1YFN_D_H 270 | 1YL7_A_D 271 | 1YLQ_A_B 272 | 1YOX_C_B 273 | 1YUK_A_B 274 | 1YVB_A_I 275 | 1YY9_A_D 276 | 1Z0K_A_C 277 | 1Z3G_A_H 278 | 1Z3G_B_I 279 | 1Z5Y_D_E 280 | 1ZCP_A_B 281 | 1ZH8_A_B 282 | 1ZJD_A_B 283 | 1ZLI_A_B 284 | 1ZR0_A_B 285 | 1ZUD_3_4 286 | 1ZVN_A_B 287 | 2A0S_A_B 288 | 2A2L_C_B 289 | 2A5Z_A_C 290 | 2A6P_A_B 291 | 2A74_E_F 292 | 2ABZ_B_E 293 | 2AF6_B_D 294 | 2AFF_A_B 295 | 2ANE_C_D 296 | 2AOB_C_D 297 | 2AQX_A_B 298 | 2AVF_C_B 299 | 2AXW_A_B 300 | 2AZN_A_D 301 | 2B3Z_C_D 302 | 2B42_B_A 303 | 2BBA_A_P 304 | 2BCM_C_B 305 | 2BE1_A_B 306 | 2BMA_C_B 307 | 2BTF_A_P 308 | 2BUJ_A_B 309 | 2C1W_A_B 310 | 2C35_D_H 311 | 2C7E_C_D 312 | 2C9P_A_B 313 | 2CCL_A_C 314 | 2CE8_A_B 315 | 2CH8_A_D 316 | 2CJR_A_B 317 | 2CO6_A_B 318 | 2D10_A_E 319 | 2D1P_D_G 320 | 2D2A_A_B 321 | 2DG0_E_H 322 | 2DOH_X_C 323 | 2DOI_X_C 324 | 2DP4_I_E 325 | 2DPF_C_D 326 | 2DSP_B_I 327 | 2DUP_A_B 328 | 2DYM_G_H 329 | 2E4M_A_C 330 | 2EIL_A_F 331 | 2EJG_B_D 332 | 2EP5_B_D 333 | 2ETE_A_B 334 | 2EVV_C_D 335 | 2EWN_A_B 336 | 2F2F_C_B 337 | 2FB8_A_B 338 | 2FBE_C_D 339 | 2FDB_M_P 340 | 2FE8_A_C 341 | 2FJU_B_A 342 | 2FKD_G_J 343 | 2FP7_A_B 344 | 2FPE_A_B 345 | 2FTL_E_I 346 | 2FTM_A_B 347 | 2FU5_A_D 348 | 2G2U_A_B 349 | 2G2W_A_B 350 | 2G45_A_B 351 | 2G6V_A_B 352 | 2G81_E_I 353 | 2GBK_B_D 354 | 2GD4_C_B 355 | 2GEC_A_B 356 | 2GEF_A_B 357 | 2GHV_C_E 358 | 2GHW_C_D 359 | 2GJV_C_D 360 | 2GKW_A_B 361 | 2GQS_A_B 362 | 2GS7_A_B 363 | 2GT2_A_B 364 | 2H1T_A_B 365 | 2H3N_A_B 366 | 2H5K_A_B 367 | 2HAX_A_B 368 | 2HD3_C_D 369 | 2HD3_E_F 370 | 2HDP_A_B 371 | 2HEK_A_B 372 | 2HEY_T_G 373 | 2HJ1_A_B 374 | 2HL3_A_B 375 | 2HLE_A_B 376 | 2HQH_B_F 377 | 2HQL_A_E 378 | 2HTB_C_D 379 | 2HVB_A_B 380 | 2HVY_A_C 381 | 2HWJ_A_C 382 | 2HZM_G_H 383 | 2HZS_B_H 384 | 2HZS_F_K 385 | 2I04_A_B 386 | 2I0B_A_B 387 | 2I32_A_E 388 | 2I5G_A_B 389 | 2I79_C_B 390 | 2I7R_A_B 391 | 
2I9B_C_G 392 | 2IA9_A_B 393 | 2IDO_A_B 394 | 2IHS_A_C 395 | 2IJ0_C_B 396 | 2IO1_A_B 397 | 2IQH_A_C 398 | 2IWO_A_B 399 | 2IWP_A_B 400 | 2IY1_A_B 401 | 2J12_A_B 402 | 2J7Q_C_D 403 | 2J8X_A_B 404 | 2JG8_D_F 405 | 2JI1_C_D 406 | 2JJS_A_C 407 | 2JJS_B_D 408 | 2JJT_A_C 409 | 2JOD_A_B 410 | 2K2S_A_B 411 | 2K6D_A_B 412 | 2KAI_A_B 413 | 2KWJ_A_B 414 | 2L0F_A_B 415 | 2L29_A_B 416 | 2LBU_E_D 417 | 2MCN_A_B 418 | 2MNU_A_B 419 | 2MTA_HL_A 420 | 2NBV_A_B 421 | 2NM1_A_B 422 | 2NN3_C_D 423 | 2NQD_A_B 424 | 2NU1_I_E 425 | 2NUU_K_L 426 | 2NXM_A_B 427 | 2NZ1_X_Y 428 | 2O8Q_A_B 429 | 2O95_A_B 430 | 2O9Q_A_C 431 | 2OGJ_E_F 432 | 2OIN_A_C 433 | 2OKQ_A_B 434 | 2OPI_A_B 435 | 2OS5_A_D 436 | 2OS7_C_F 437 | 2OTL_A_Z 438 | 2OUL_A_B 439 | 2OVI_A_D 440 | 2OYA_A_B 441 | 2OZK_B_D 442 | 2P04_A_B 443 | 2P35_A_B 444 | 2P42_A_B 445 | 2P43_A_B 446 | 2P44_A_B 447 | 2P45_A_B 448 | 2P46_A_B 449 | 2P47_A_B 450 | 2P48_A_B 451 | 2P49_A_B 452 | 2P4A_A_B 453 | 2P4R_A_T 454 | 2P4Z_A_B 455 | 2P5X_A_B 456 | 2P6B_A_E 457 | 2PKG_A_C 458 | 2PMV_A_B 459 | 2PNH_A_B 460 | 2PO2_A_B 461 | 2PQ2_A_B 462 | 2PQS_A_B 463 | 2PQV_A_B 464 | 2PTC_E_I 465 | 2PUY_B_E 466 | 2PZD_A_B 467 | 2Q17_C_B 468 | 2Q7N_A_B 469 | 2Q81_A_D 470 | 2QBW_A_B 471 | 2QBX_A_P 472 | 2QC1_A_B 473 | 2QF3_A_C 474 | 2QKI_A_G 475 | 2QLC_C_B 476 | 2QLP_A_C 477 | 2QW7_I_H 478 | 2QYI_C_D 479 | 2R0K_A_H 480 | 2R2C_A_B 481 | 2R5O_A_B 482 | 2R9P_A_E 483 | 2RA3_A_C 484 | 2RL7_C_D 485 | 2SIC_E_I 486 | 2SNI_E_I 487 | 2TGP_Z_I 488 | 2UUY_A_B 489 | 2V0R_A_B 490 | 2V3B_A_B 491 | 2V52_B_M 492 | 2VE6_A_D 493 | 2VER_A_N 494 | 2VIF_A_P 495 | 2VJF_C_D 496 | 2VPM_A_B 497 | 2VSC_A_B 498 | 2W0C_C_B 499 | 2W1T_A_B 500 | 2W2N_A_E 501 | 2W80_A_D 502 | 2W80_G_H 503 | 2W81_A_D 504 | 2WAM_A_C 505 | 2WC4_A_C 506 | 2WFX_A_B 507 | 2WG3_A_C 508 | 2WG4_A_B 509 | 2WLG_C_B 510 | 2WO2_A_B 511 | 2WO3_A_B 512 | 2WQ4_A_C 513 | 2WQZ_A_D 514 | 2WVT_A_B 515 | 2WWX_A_B 516 | 2X36_C_D 517 | 2X53_W_V 518 | 2X5Q_A_B 519 | 2X89_F_G 520 | 2X8K_A_B 521 | 2X9A_A_D 522 | 2X9A_C_D 523 | 2X9M_B_D 524 | 
2XB6_A_C 525 | 2XBB_A_C 526 | 2XCE_E_D 527 | 2XFG_A_B 528 | 2XJZ_C_K 529 | 2XTJ_A_D 530 | 2Y32_B_D 531 | 2Y9X_B_F 532 | 2YC2_A_D 533 | 2YCH_A_B 534 | 2YH9_A_C 535 | 2YVJ_A_B 536 | 2YVL_B_D 537 | 2YVS_A_B 538 | 2YYS_A_B 539 | 2YZJ_A_C 540 | 2Z0E_A_B 541 | 2Z0P_C_D 542 | 2Z29_A_B 543 | 2Z2M_C_B 544 | 2Z7F_E_I 545 | 2Z7X_A_B 546 | 2Z8M_A_B 547 | 2ZA4_A_B 548 | 2ZCK_P_S 549 | 2ZDC_A_B 550 | 2ZG6_A_B 551 | 2ZME_A_C 552 | 2ZNV_E_D 553 | 2ZSU_A_B 554 | 2ZVW_C_E 555 | 2ZVW_C_K 556 | 2ZXW_O_U 557 | 3A1P_C_D 558 | 3AD8_B_D 559 | 3AEH_A_B 560 | 3AFF_A_B 561 | 3AHS_A_C 562 | 3AJY_A_C 563 | 3ALZ_A_B 564 | 3AOG_H_L 565 | 3AXY_B_D 566 | 3B01_A_C 567 | 3B08_A_B 568 | 3B5U_B_D 569 | 3B5U_J_L 570 | 3B6P_B_D 571 | 3B76_A_B 572 | 3B93_A_C 573 | 3B9I_A_B 574 | 3BAL_A_B 575 | 3BCP_A_B 576 | 3BCW_A_B 577 | 3BFW_C_D 578 | 3BGL_C_B 579 | 3BHD_A_B 580 | 3BIW_A_E 581 | 3BN3_A_B 582 | 3BPD_G_F 583 | 3BQB_A_X 584 | 3BRC_A_B 585 | 3BRD_A_D 586 | 3BT1_B_U 587 | 3BTV_A_B 588 | 3BWU_C_D 589 | 3BX1_A_C 590 | 3BX7_A_C 591 | 3C0B_C_D 592 | 3C4O_A_B 593 | 3C4P_A_B 594 | 3C7T_A_C 595 | 3C8I_A_B 596 | 3CAM_A_B 597 | 3CDW_A_H 598 | 3CE9_A_D 599 | 3CEW_C_D 600 | 3CG8_C_B 601 | 3CGY_A_B 602 | 3CHW_A_P 603 | 3CJX_G_H 604 | 3CO2_A_D 605 | 3CQ9_C_D 606 | 3D1E_A_P 607 | 3D1M_A_D 608 | 3D4G_C_D 609 | 3D4R_A_B 610 | 3D5N_F_I 611 | 3DA7_A_D 612 | 3DAW_A_B 613 | 3DAX_A_B 614 | 3DCA_C_D 615 | 3DCL_A_B 616 | 3DGP_A_B 617 | 3DJP_A_B 618 | 3DKU_A_B 619 | 3DQQ_A_B 620 | 3DSN_C_B 621 | 3E05_B_H 622 | 3E05_C_B 623 | 3E1Z_A_B 624 | 3E2K_A_D 625 | 3E2L_A_C 626 | 3E2U_A_E 627 | 3E38_A_B 628 | 3E9M_A_D 629 | 3ECY_A_B 630 | 3EDP_A_B 631 | 3EHU_A_C 632 | 3EMJ_K_L 633 | 3EN0_A_B 634 | 3ENT_A_B 635 | 3EPZ_A_B 636 | 3EUK_C_E 637 | 3EYD_C_D 638 | 3F5N_A_D 639 | 3F74_A_B 640 | 3F75_A_P 641 | 3FCG_A_B 642 | 3FD4_A_B 643 | 3FEF_C_D 644 | 3FFU_A_B 645 | 3FG8_C_E 646 | 3FHC_A_B 647 | 3FJS_C_D 648 | 3FJU_A_B 649 | 3FK9_A_B 650 | 3FL1_A_B 651 | 3FLP_M_N 652 | 3FP6_E_I 653 | 3FPR_A_D 654 | 3FPU_A_B 655 | 3FPV_A_F 656 | 3FSN_A_B 657 | 
3FUY_A_B 658 | 3FYF_A_B 659 | 3G9V_A_B 660 | 3GBU_A_D 661 | 3GFU_A_B 662 | 3GMW_A_B 663 | 3GNJ_C_D 664 | 3GQH_A_B 665 | 3GRW_A_H 666 | 3GWY_A_B 667 | 3GXU_A_B 668 | 3GZ8_A_B 669 | 3GZE_A_X 670 | 3GZE_C_Y 671 | 3GZR_A_B 672 | 3H11_BC_A 673 | 3H35_C_B 674 | 3H3B_A_C 675 | 3H6S_A_E 676 | 3H8D_B_F 677 | 3H8G_B_E 678 | 3H9G_A_E 679 | 3HCG_A_C 680 | 3HF5_A_D 681 | 3HHJ_A_B 682 | 3HLN_O_U 683 | 3HM8_A_C 684 | 3HMK_A_B 685 | 3HN6_B_D 686 | 3HO5_B_H 687 | 3HPN_E_F 688 | 3HQR_A_S 689 | 3HRD_E_H 690 | 3HT2_A_C 691 | 3HTR_A_B 692 | 3I2B_E_H 693 | 3I5V_A_D 694 | 3I84_A_B 695 | 3IAS_G_F 696 | 3IBM_A_B 697 | 3ISM_A_B 698 | 3JRQ_A_B 699 | 3JUY_C_B 700 | 3JVC_A_C 701 | 3JVZ_B_D 702 | 3JZA_A_B 703 | 3K1R_A_B 704 | 3K25_A_B 705 | 3K3C_A_B 706 | 3K4W_D_F 707 | 3K6S_A_E 708 | 3K9M_A_C 709 | 3K9M_B_D 710 | 3KDG_A_B 711 | 3KL9_I_J 712 | 3KLQ_A_B 713 | 3KMH_A_B 714 | 3KMT_A_B 715 | 3KQZ_K_L 716 | 3KTM_C_F 717 | 3KTS_A_B 718 | 3KW5_A_B 719 | 3KWV_D_F 720 | 3KY8_A_B 721 | 3KZH_A_B 722 | 3L2H_A_D 723 | 3L33_A_E 724 | 3L9J_C_T 725 | 3LAQ_A_U 726 | 3LHX_A_B 727 | 3LM1_E_F 728 | 3LMS_A_B 729 | 3LQV_B_Q 730 | 3LRJ_C_D 731 | 3LU9_E_F 732 | 3M5O_A_C 733 | 3M5R_A_G 734 | 3M85_B_E 735 | 3M85_B_G 736 | 3MAL_A_B 737 | 3ME4_A_B 738 | 3MJ9_A_H 739 | 3ML6_B_F 740 | 3MQW_C_D 741 | 3MZW_A_B 742 | 3N2B_A_C 743 | 3N4I_A_B 744 | 3N6Q_E_H 745 | 3NCT_B_D 746 | 3NEK_A_B 747 | 3NFG_G_H 748 | 3NGB_E_G 749 | 3NPG_C_D 750 | 3NRJ_A_H 751 | 3NS1_A_B 752 | 3NTQ_A_B 753 | 3O2X_B_D 754 | 3O34_A_B 755 | 3O9L_A_B 756 | 3OEU_2_M 757 | 3OGF_A_B 758 | 3OGO_A_E 759 | 3OJ2_A_C 760 | 3OJM_A_B 761 | 3OKJ_C_D 762 | 3OLM_A_D 763 | 3OSL_A_B 764 | 3OZB_A_C 765 | 3P71_C_T 766 | 3P83_B_F 767 | 3P8B_C_D 768 | 3P92_A_E 769 | 3P95_A_E 770 | 3P9W_A_B 771 | 3PCQ_C_D 772 | 3PGA_1_4 773 | 3PIG_A_B 774 | 3PIM_A_B 775 | 3PNR_A_B 776 | 3PPE_A_B 777 | 3PRP_A_B 778 | 3PS4_C_B 779 | 3PYY_A_B 780 | 3Q0Y_C_B 781 | 3Q7H_K_M 782 | 3Q87_A_B 783 | 3Q9N_A_C 784 | 3Q9U_A_C 785 | 3QC8_A_B 786 | 3QDZ_B_E 787 | 3QFM_A_B 788 | 3QHY_A_B 789 | 3QJ7_A_D 790 | 
3QNA_A_D 791 | 3QPB_K_J 792 | 3QQ8_A_B 793 | 3QSK_A_B 794 | 3QWN_I_J 795 | 3QWQ_A_B 796 | 3RBQ_B_H 797 | 3RDZ_A_C 798 | 3RT0_A_C 799 | 3S5B_A_B 800 | 3S8V_A_X 801 | 3S9C_A_B 802 | 3SGB_E_I 803 | 3SGQ_E_I 804 | 3SJ9_A_B 805 | 3SLH_A_B 806 | 3SM1_A_B 807 | 3SOQ_A_Z 808 | 3T3A_A_B 809 | 3TDM_A_B 810 | 3TG9_A_B 811 | 3TGK_E_I 812 | 3THT_A_B 813 | 3TII_A_B 814 | 3TIW_B_D 815 | 3TL8_A_B 816 | 3TND_B_D 817 | 3TQY_A_B 818 | 3TSR_A_E 819 | 3TSZ_A_B 820 | 3TU3_A_B 821 | 3U02_B_D 822 | 3U1O_A_B 823 | 3U4J_C_B 824 | 3UI2_A_B 825 | 3UIR_A_C 826 | 3UVN_C_D 827 | 3UZP_A_B 828 | 3UZV_A_B 829 | 3V3K_A_B 830 | 3V4P_B_H 831 | 3V5N_C_D 832 | 3V96_A_B 833 | 3VPJ_A_E 834 | 3VYR_A_B 835 | 3WA5_A_B 836 | 3WDG_A_B 837 | 3WN7_A_B 838 | 3ZRZ_A_C 839 | 3ZWL_B_E 840 | 4A94_A_D 841 | 4AFQ_A_C 842 | 4AFZ_A_C 843 | 4AG2_A_C 844 | 4AN7_A_B 845 | 4AOQ_A_D 846 | 4AOR_A_D 847 | 4APF_A_B 848 | 4AYD_A_D 849 | 4AYE_A_D 850 | 4AYI_A_D 851 | 4B1V_B_N 852 | 4B1X_B_M 853 | 4B1Y_B_M 854 | 4BD9_A_B 855 | 4BQD_A_C 856 | 4BWQ_E_F 857 | 4CDK_A_E 858 | 4CJ0_A_B 859 | 4CJ1_A_B 860 | 4CMM_A_B 861 | 4CPA_A_I 862 | 4DG4_A_E 863 | 4DGE_A_C 864 | 4DOQ_A_B 865 | 4EIG_A_B 866 | 4EQA_A_C 867 | 4F0A_A_B 868 | 4FT4_B_Q 869 | 4FZA_A_B 870 | 4G6U_A_B 871 | 4GH7_A_B 872 | 4GI3_A_C 873 | 4HDO_A_B 874 | 4I6L_A_B 875 | 4ILW_A_D 876 | 4IOP_A_B 877 | 4J2Y_A_B 878 | 4JRA_A_D 879 | 4JW3_A_C 880 | 4K1R_A_B 881 | 4K24_A_U 882 | 4KBB_A_C 883 | 4KDI_A_C 884 | 4KFZ_A_C 885 | 4KGG_C_A 886 | 4KR0_A_B 887 | 4KRL_B_A 888 | 4KSD_A_B 889 | 4KV5_C_D 890 | 4L0P_A_B 891 | 4LAD_A_B 892 | 4LLO_A_B 893 | 4LQW_A_D 894 | 4LYL_A_B 895 | 4M0W_A_B 896 | 4M5F_A_B 897 | 4MSM_A_B 898 | 4NOO_A_B 899 | 4NSO_A_B 900 | 4NZL_A_B 901 | 4NZW_A_B 902 | 4PEQ_A_B 903 | 4PJ2_A_D 904 | 4POU_A_B 905 | 4PQT_A_B 906 | 4QT8_A_C 907 | 4QZV_A_B 908 | 4RS1_B_A 909 | 4TQ0_A_B 910 | 4TQ1_A_B 911 | 4U30_A_X 912 | 4U97_A_B 913 | 4UDM_B_A 914 | 4V0M_F_E 915 | 4V0N_D_C 916 | 4V0O_F_E 917 | 4W6X_A_B 918 | 4WEM_A_B 919 | 4WEN_A_B 920 | 4XL1_A_B 921 | 4XL5_A_C 922 | 4XLW_A_B 923 | 
# Inference script: loads a trained dMaSIF checkpoint and runs a single
# prediction pass over a test set (single PDB, a list of PDBs, or the
# full benchmark), saving the outputs under "preds/<experiment_name>".

# Standard imports:
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
from torch_geometric.data import DataLoader
from torch_geometric.transforms import Compose
from pathlib import Path

# Custom data loader and model:
from data import ProteinPairsSurfaces, PairData, CenterPairAtoms, load_protein_pair
from data import RandomRotationPairAtoms, NormalizeChemFeatures, iface_valid_filter
from model import dMaSIF
from data_iteration import iterate
from helper import *
from Arguments import parser

# Command-line arguments select the checkpoint to load and where
# the predictions are written:
args = parser.parse_args()
model_path = "models/" + args.experiment_name
save_predictions_path = Path("preds/" + args.experiment_name)

# Ensure reproducibility:
torch.backends.cudnn.deterministic = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)


# Load the train and test datasets:
# random rotations are only applied when benchmarking rotation invariance.
transformations = (
    Compose([NormalizeChemFeatures(), CenterPairAtoms(), RandomRotationPairAtoms()])
    if args.random_rotation
    else Compose([NormalizeChemFeatures()])
)

# Three mutually exclusive ways of selecting the test structures:
# a single PDB id, a text file with one id per line, or the benchmark set.
if args.single_pdb != "":
    single_data_dir = Path("./data_preprocessing/npys/")
    test_dataset = [load_protein_pair(args.single_pdb, single_data_dir,single_pdb=True)]
    test_pdb_ids = [args.single_pdb]
elif args.pdb_list != "":
    with open(args.pdb_list) as f:
        pdb_list = f.read().splitlines()
    single_data_dir = Path("./data_preprocessing/npys/")
    test_dataset = [load_protein_pair(pdb, single_data_dir,single_pdb=True) for pdb in pdb_list]
    test_pdb_ids = [pdb for pdb in pdb_list]
else:
    test_dataset = ProteinPairsSurfaces(
        "surface_data", train=False, ppi=args.search, transform=transformations
    )
    test_pdb_ids = (
        np.load("surface_data/processed/testing_pairs_data_ids.npy")
        if args.site
        else np.load("surface_data/processed/testing_pairs_data_ids_ppi.npy")
    )

# Discard pairs whose interface is rejected by iface_valid_filter,
# keeping the (data, pdb_id) pairing consistent:
test_dataset = [
    (data, pdb_id)
    for data, pdb_id in zip(test_dataset, test_pdb_ids)
    if iface_valid_filter(data)
]
test_dataset, test_pdb_ids = list(zip(*test_dataset))


# PyTorch geometric expects an explicit list of "batched variables":
batch_vars = ["xyz_p1", "xyz_p2", "atom_coords_p1", "atom_coords_p2"]
test_loader = DataLoader(
    test_dataset, batch_size=args.batch_size, follow_batch=batch_vars
)

# Build the network and restore the trained weights from the
# checkpoint dictionary saved by main_training.py:
net = dMaSIF(args)
# net.load_state_dict(torch.load(model_path, map_location=args.device))
net.load_state_dict(
    torch.load(model_path, map_location=args.device)["model_state_dict"]
)
net = net.to(args.device)

# Perform one pass through the data:
# optimizer is None since we only run inference (test=True).
info = iterate(
    net,
    test_loader,
    None,
    args,
    test=True,
    save_path=save_predictions_path,
    pdb_ids=test_pdb_ids,
)

#np.save(f"timings/{args.experiment_name}_convtime.npy", info["conv_time"])
#np.save(f"timings/{args.experiment_name}_memoryusage.npy", info["memory_usage"])
-------------------------------------------------------------------------------- /main_training.py: -------------------------------------------------------------------------------- 1 | # Standard imports: 2 | import numpy as np 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | from torch.utils.data import random_split 6 | from torch_geometric.data import DataLoader 7 | from torch_geometric.transforms import Compose 8 | from pathlib import Path 9 | 10 | # Custom data loader and model: 11 | from data import ProteinPairsSurfaces, PairData, CenterPairAtoms 12 | from data import RandomRotationPairAtoms, NormalizeChemFeatures, iface_valid_filter 13 | from model import dMaSIF 14 | from data_iteration import iterate, iterate_surface_precompute 15 | from helper import * 16 | from Arguments import parser 17 | 18 | # Parse the arguments, prepare the TensorBoard writer: 19 | args = parser.parse_args() 20 | writer = SummaryWriter("runs/{}".format(args.experiment_name)) 21 | model_path = "models/" + args.experiment_name 22 | 23 | if not Path("models/").exists(): 24 | Path("models/").mkdir(exist_ok=False) 25 | 26 | # Ensure reproducibility: 27 | torch.backends.cudnn.deterministic = True 28 | torch.manual_seed(args.seed) 29 | torch.cuda.manual_seed_all(args.seed) 30 | np.random.seed(args.seed) 31 | 32 | # Create the model, with a warm restart if applicable: 33 | net = dMaSIF(args) 34 | net = net.to(args.device) 35 | 36 | # We load the train and test datasets. 
37 | # Random transforms, to ensure that no network/baseline overfits on pose parameters: 38 | transformations = ( 39 | Compose([NormalizeChemFeatures(), CenterPairAtoms(), RandomRotationPairAtoms()]) 40 | if args.random_rotation 41 | else Compose([NormalizeChemFeatures()]) 42 | ) 43 | 44 | # PyTorch geometric expects an explicit list of "batched variables": 45 | batch_vars = ["xyz_p1", "xyz_p2", "atom_coords_p1", "atom_coords_p2"] 46 | # Load the train dataset: 47 | train_dataset = ProteinPairsSurfaces( 48 | "surface_data", ppi=args.search, train=True, transform=transformations 49 | ) 50 | train_dataset = [data for data in train_dataset if iface_valid_filter(data)] 51 | train_loader = DataLoader( 52 | train_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True 53 | ) 54 | print("Preprocessing training dataset") 55 | train_dataset = iterate_surface_precompute(train_loader, net, args) 56 | 57 | # Train/Validation split: 58 | train_nsamples = len(train_dataset) 59 | val_nsamples = int(train_nsamples * args.validation_fraction) 60 | train_nsamples = train_nsamples - val_nsamples 61 | train_dataset, val_dataset = random_split( 62 | train_dataset, [train_nsamples, val_nsamples] 63 | ) 64 | 65 | # Load the test dataset: 66 | test_dataset = ProteinPairsSurfaces( 67 | "surface_data", ppi=args.search, train=False, transform=transformations 68 | ) 69 | test_dataset = [data for data in test_dataset if iface_valid_filter(data)] 70 | test_loader = DataLoader( 71 | test_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True 72 | ) 73 | print("Preprocessing testing dataset") 74 | test_dataset = iterate_surface_precompute(test_loader, net, args) 75 | 76 | 77 | # PyTorch_geometric data loaders: 78 | train_loader = DataLoader( 79 | train_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True 80 | ) 81 | val_loader = DataLoader(val_dataset, batch_size=1, follow_batch=batch_vars) 82 | test_loader = DataLoader(test_dataset, batch_size=1, follow_batch=batch_vars) 83 | 
84 | 
85 | # Baseline optimizer:
86 | optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, amsgrad=True)
87 | best_loss = 1e10  # We save the "best model so far"
88 | 
89 | starting_epoch = 0
90 | if args.restart_training != "":  # "" (default) means train from scratch
91 |     checkpoint = torch.load("models/" + args.restart_training)
92 |     net.load_state_dict(checkpoint["model_state_dict"])
93 |     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
94 |     starting_epoch = checkpoint["epoch"]
95 |     best_loss = checkpoint["best_loss"]
96 | 
97 | # Training loop (~100 times) over the dataset:
98 | for i in range(starting_epoch, args.n_epochs):
99 |     # Train first, Test second:
100 |     for dataset_type in ["Train", "Validation", "Test"]:
101 |         if dataset_type == "Train":
102 |             test = False
103 |         else:
104 |             test = True  # evaluation pass — presumably no optimizer steps inside iterate(); confirm in data_iteration.py
105 | 
106 |         suffix = dataset_type  # TensorBoard tag suffix
107 |         if dataset_type == "Train":
108 |             dataloader = train_loader
109 |         elif dataset_type == "Validation":
110 |             dataloader = val_loader
111 |         elif dataset_type == "Test":
112 |             dataloader = test_loader
113 | 
114 |         # Perform one pass through the data:
115 |         info = iterate(
116 |             net,
117 |             dataloader,
118 |             optimizer,
119 |             args,
120 |             test=test,
121 |             summary_writer=writer,
122 |             epoch_number=i,
123 |         )
124 | 
125 |         # Write down the results using a TensorBoard writer:
126 |         for key, val in info.items():
127 |             if key in [
128 |                 "Loss",
129 |                 "ROC-AUC",
130 |                 "Distance/Positives",
131 |                 "Distance/Negatives",
132 |                 "Matching ROC-AUC",
133 |             ]:
134 |                 writer.add_scalar(f"{key}/{suffix}", np.mean(val), i)
135 | 
136 |             if "R_values/" in key:
137 |                 val = np.array(val)
138 |                 writer.add_scalar(f"{key}/{suffix}", np.mean(val[val > 0]), i)  # average only over strictly positive entries
139 | 
140 |         if dataset_type == "Validation":  # Store validation loss for saving the model
141 |             val_loss = np.mean(info["Loss"])  # read below in the same epoch iteration; relies on "Validation" preceding the save
142 | 
143 |     if True:  # Additional saves  -- NOTE(review): "if True" is a no-op guard left from debugging
144 |         if val_loss < best_loss:
145 |             print("Validation loss {}, saving model".format(val_loss))
146 |             torch.save(
147 |                 {
148 |                     "epoch": i,
149 |                     "model_state_dict":
net.state_dict(), 150 | "optimizer_state_dict": optimizer.state_dict(), 151 | "best_loss": best_loss, 152 | }, 153 | model_path + "_epoch{}".format(i), 154 | ) 155 | 156 | best_loss = val_loss 157 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.autograd.profiler as profiler 7 | from pykeops.torch import LazyTensor 8 | 9 | from geometry_processing import ( 10 | curvatures, 11 | mesh_normals_areas, 12 | tangent_vectors, 13 | atoms_to_points_normals, 14 | ) 15 | from helper import soft_dimension, diagonal_ranges 16 | from benchmark_models import DGCNN_seg, PointNet2_seg, dMaSIFConv_seg 17 | 18 | 19 | def knn_atoms(x, y, x_batch, y_batch, k): 20 | N, D = x.shape 21 | x_i = LazyTensor(x[:, None, :]) 22 | y_j = LazyTensor(y[None, :, :]) 23 | 24 | pairwise_distance_ij = ((x_i - y_j) ** 2).sum(-1) 25 | pairwise_distance_ij.ranges = diagonal_ranges(x_batch, y_batch) 26 | 27 | # N.B.: KeOps doesn't yet support backprop through Kmin reductions... 
28 |     # dists, idx = pairwise_distance_ij.Kmin_argKmin(K=k,axis=1)
29 |     # So we have to re-compute the values ourselves:
30 |     idx = pairwise_distance_ij.argKmin(K=k, axis=1)  # (N, K)
31 |     x_ik = y[idx.view(-1)].view(N, k, D)
32 |     dists = ((x[:, None, :] - x_ik) ** 2).sum(-1)  # squared distances, recomputed in plain torch so autograd works
33 | 
34 |     return idx, dists
35 | 
36 | 
37 | def get_atom_features(x, y, x_batch, y_batch, y_atomtype, k=16):  # (N, k, num_dims + 1): neighbor atom types + inverse squared distance
38 | 
39 |     idx, dists = knn_atoms(x, y, x_batch, y_batch, k=k)  # (num_points, k)
40 |     num_points, _ = idx.size()
41 | 
42 |     idx = idx.view(-1)
43 |     dists = 1 / dists.view(-1, 1)  # NOTE(review): divides by zero if a query point coincides with an atom — TODO confirm upstream sampling prevents this
44 |     _, num_dims = y_atomtype.size()
45 | 
46 |     feature = y_atomtype[idx, :]
47 |     feature = torch.cat([feature, dists], dim=1)
48 |     feature = feature.view(num_points, k, num_dims + 1)
49 | 
50 |     return feature
51 | 
52 | 
53 | class Atom_embedding(nn.Module):
54 |     def __init__(self, args):
55 |         super(Atom_embedding, self).__init__()
56 |         self.D = args.atom_dims
57 |         self.k = 16
58 |         self.conv1 = nn.Linear(self.D + 1, self.D)
59 |         self.conv2 = nn.Linear(self.D, self.D)
60 |         self.conv3 = nn.Linear(2 * self.D, self.D)
61 |         self.bn1 = nn.BatchNorm1d(self.D)
62 |         self.bn2 = nn.BatchNorm1d(self.D)
63 |         self.relu = nn.LeakyReLU(negative_slope=0.2)
64 | 
65 |     def forward(self, x, y, y_atomtypes, x_batch, y_batch):
66 |         fx = get_atom_features(x, y, x_batch, y_batch, y_atomtypes, k=self.k)  # (N, k, D + 1)
67 |         fx = self.conv1(fx)
68 |         fx = fx.view(-1, self.D)  # flatten to (N * k, D) so BatchNorm1d normalizes per-feature
69 |         fx = self.bn1(self.relu(fx))
70 |         fx = fx.view(-1, self.k, self.D)
71 |         fx1 = fx.sum(dim=1, keepdim=False)  # sum-pool over the k neighbors
72 | 
73 |         fx = self.conv2(fx)  # second layer consumes the normalized (N, k, D) tensor
74 |         fx = fx.view(-1, self.D)
75 |         fx = self.bn2(self.relu(fx))
76 |         fx = fx.view(-1, self.k, self.D)
77 |         fx2 = fx.sum(dim=1, keepdim=False)
78 |         fx = torch.cat((fx1, fx2), dim=-1)  # concatenate both pooled layers: (N, 2 * D)
79 |         fx = self.conv3(fx)
80 | 
81 |         return fx
82 | 
83 | 
84 | class AtomNet(nn.Module):
85 |     def __init__(self, args):
86 |         super(AtomNet, self).__init__()
87 |         self.args = args
88 | 
89 |         self.transform_types = nn.Sequential(  # pointwise MLP on raw atom-type vectors
90 |             nn.Linear(args.atom_dims, args.atom_dims),
91
 |             nn.LeakyReLU(negative_slope=0.2),
92 |             nn.Linear(args.atom_dims, args.atom_dims),
93 |             nn.LeakyReLU(negative_slope=0.2),
94 |             nn.Linear(args.atom_dims, args.atom_dims),
95 |             nn.LeakyReLU(negative_slope=0.2),
96 |         )
97 |         self.embed = Atom_embedding(args)
98 | 
99 |     def forward(self, xyz, atom_xyz, atomtypes, batch, atom_batch):
100 |         # Run a DGCNN on the available information:
101 |         atomtypes = self.transform_types(atomtypes)
102 |         return self.embed(xyz, atom_xyz, atomtypes, batch, atom_batch)
103 | 
104 | class Atom_embedding_MP(nn.Module):
105 |     def __init__(self, args):
106 |         super(Atom_embedding_MP, self).__init__()
107 |         self.D = args.atom_dims
108 |         self.k = 16
109 |         self.n_layers = 3
110 |         self.mlp = nn.ModuleList(  # one message MLP per layer: (2D + 1) -> D
111 |             [
112 |                 nn.Sequential(
113 |                     nn.Linear(2 * self.D + 1, 2 * self.D + 1),
114 |                     nn.LeakyReLU(negative_slope=0.2),
115 |                     nn.Linear(2 * self.D + 1, self.D),
116 |                 )
117 |                 for i in range(self.n_layers)
118 |             ]
119 |         )
120 |         self.norm = nn.ModuleList(
121 |             [nn.GroupNorm(2, self.D) for i in range(self.n_layers)]
122 |         )
123 |         self.relu = nn.LeakyReLU(negative_slope=0.2)
124 | 
125 |     def forward(self, x, y, y_atomtypes, x_batch, y_batch):
126 |         idx, dists = knn_atoms(x, y, x_batch, y_batch, k=self.k)  # N, 9, 7 -- dists are *squared* distances from knn_atoms
127 |         num_points = x.shape[0]
128 |         num_dims = y_atomtypes.shape[-1]
129 | 
130 |         point_emb = torch.ones_like(x[:, 0])[:, None].repeat(1, num_dims)  # all-ones init, (num_points, D)
131 |         for i in range(self.n_layers):
132 |             features = y_atomtypes[idx.reshape(-1), :]
133 |             features = torch.cat([features, dists.reshape(-1, 1)], dim=1)  # raw squared distance here (not inverted as in get_atom_features)
134 |             features = features.view(num_points, self.k, num_dims + 1)
135 |             features = torch.cat(
136 |                 [point_emb[:, None, :].repeat(1, self.k, 1), features], dim=-1
137 |             )  # N, 8, 13
138 | 
139 |             messages = self.mlp[i](features)  # N,8,6
140 |             messages = messages.sum(1)  # N,6
141 |             point_emb = point_emb + self.relu(self.norm[i](messages))  # residual update
142 | 
143 |         return point_emb
144 | 
145 | class Atom_Atom_embedding_MP(nn.Module):
146 |     def __init__(self, args):
147 |         super(Atom_Atom_embedding_MP, self).__init__()
148 |         self.D = args.atom_dims
149 |         self.k = 17  # 16 neighbours + the query atom itself, which is dropped in forward()
150 |         self.n_layers = 3
151 | 
152 |         self.mlp = nn.ModuleList(
153 |             [
154 |                 nn.Sequential(
155 |                     nn.Linear(2 * self.D + 1, 2 * self.D + 1),
156 |                     nn.LeakyReLU(negative_slope=0.2),
157 |                     nn.Linear(2 * self.D + 1, self.D),
158 |                 )
159 |                 for i in range(self.n_layers)
160 |             ]
161 |         )
162 | 
163 |         self.norm = nn.ModuleList(
164 |             [nn.GroupNorm(2, self.D) for i in range(self.n_layers)]
165 |         )
166 |         self.relu = nn.LeakyReLU(negative_slope=0.2)
167 | 
168 |     def forward(self, x, y, y_atomtypes, x_batch, y_batch):
169 |         idx, dists = knn_atoms(x, y, x_batch, y_batch, k=self.k)  # N, 9, 7
170 |         idx = idx[:, 1:]  # Remove self -- assumes each atom's nearest neighbour (distance 0) is itself; TODO confirm argKmin ties break that way
171 |         dists = dists[:, 1:]
172 |         k = self.k - 1
173 |         num_points = y_atomtypes.shape[0]
174 | 
175 |         out = y_atomtypes
176 |         for i in range(self.n_layers):
177 |             _, num_dims = out.size()
178 |             features = out[idx.reshape(-1), :]
179 |             features = torch.cat([features, dists.reshape(-1, 1)], dim=1)
180 |             features = features.view(num_points, k, num_dims + 1)
181 |             features = torch.cat(
182 |                 [out[:, None, :].repeat(1, k, 1), features], dim=-1
183 |             )  # N, 8, 13
184 | 
185 |             messages = self.mlp[i](features)  # N,8,6
186 |             messages = messages.sum(1)  # N,6
187 |             out = out + self.relu(self.norm[i](messages))  # residual update on atom embeddings
188 | 
189 |         return out
190 | 
191 | class AtomNet_MP(nn.Module):
192 |     def __init__(self, args):
193 |         super(AtomNet_MP, self).__init__()
194 |         self.args = args
195 | 
196 |         self.transform_types = nn.Sequential(
197 |             nn.Linear(args.atom_dims, args.atom_dims),
198 |             nn.LeakyReLU(negative_slope=0.2),
199 |             nn.Linear(args.atom_dims, args.atom_dims),
200 |         )
201 | 
202 |         self.embed = Atom_embedding_MP(args)
203 |         self.atom_atom = Atom_Atom_embedding_MP(args)
204 | 
205 |     def forward(self, xyz, atom_xyz, atomtypes, batch, atom_batch):
206 |         # Run a DGCNN on the available information:
207 |         atomtypes = self.transform_types(atomtypes)
208 |         atomtypes = self.atom_atom(
209 |
 atom_xyz, atom_xyz, atomtypes, atom_batch, atom_batch
210 |         )
211 |         atomtypes = self.embed(xyz, atom_xyz, atomtypes, batch, atom_batch)
212 |         return atomtypes
213 | 
214 | 
215 | def combine_pair(P1, P2):  # merge two protein dicts into one batched dict (inverse of split_pair)
216 |     P1P2 = {}
217 |     for key in P1:
218 |         v1 = P1[key]
219 |         v2 = P2[key]
220 |         if v1 is None:
221 |             continue
222 | 
223 |         if key == "batch" or key == "batch_atoms":
224 |             v1v2 = torch.cat([v1, v2 + v1[-1] + 1], dim=0)  # shift P2's batch ids past P1's last id (assumes batch ids sorted ascending)
225 |         elif key == "triangles":
226 |             # v1v2 = torch.cat([v1,v2],dim=1)
227 |             continue
228 |         else:
229 |             v1v2 = torch.cat([v1, v2], dim=0)
230 |         P1P2[key] = v1v2
231 | 
232 |     return P1P2
233 | 
234 | 
235 | def split_pair(P1P2):
236 |     batch_size = P1P2["batch_atoms"][-1] + 1  # total number of proteins in the combined batch
237 |     p1_indices = P1P2["batch"] < batch_size // 2  # assumes P1 and P2 each contribute batch_size // 2 proteins
238 |     p2_indices = P1P2["batch"] >= batch_size // 2
239 | 
240 |     p1_atom_indices = P1P2["batch_atoms"] < batch_size // 2
241 |     p2_atom_indices = P1P2["batch_atoms"] >= batch_size // 2
242 | 
243 |     P1 = {}
244 |     P2 = {}
245 |     for key in P1P2:
246 |         v1v2 = P1P2[key]
247 | 
248 |         if (key == "rand_rot") or (key == "atom_center"):
249 |             n = v1v2.shape[0] // 2  # these are stacked halves, not batch-indexed
250 |             P1[key] = v1v2[:n].view(-1, 3)
251 |             P2[key] = v1v2[n:].view(-1, 3)
252 |         elif "atom" in key:
253 |             P1[key] = v1v2[p1_atom_indices]
254 |             P2[key] = v1v2[p2_atom_indices]
255 |         elif key == "triangles":
256 |             continue
257 |             # P1[key] = v1v2[:,p1_atom_indices]
258 |             # P2[key] = v1v2[:,p2_atom_indices]
259 |         else:
260 |             P1[key] = v1v2[p1_indices]
261 |             P2[key] = v1v2[p2_indices]
262 | 
263 |     P2["batch"] = P2["batch"] - batch_size + 1  # NOTE(review): correct only when batch_size == 2 (one pair, as with batch_size=1 loaders); the general re-indexing would subtract batch_size // 2 — TODO confirm
264 |     P2["batch_atoms"] = P2["batch_atoms"] - batch_size + 1
265 | 
266 |     return P1, P2
267 | 
268 | 
269 | 
270 | def project_iface_labels(P, threshold=2.0):  # copy mesh interface labels onto sampled points within `threshold` of the mesh
271 | 
272 |     queries = P["xyz"]
273 |     batch_queries = P["batch"]
274 |     source = P["mesh_xyz"]
275 |     batch_source = P["mesh_batch"]
276 |     labels = P["mesh_labels"]
277 |     x_i = LazyTensor(queries[:, None, :])  # (N, 1, D)
278 |     y_j = LazyTensor(source[None, :, :])  # (1, M, D)
279 | 
280 |     D_ij = ((x_i - y_j)
 ** 2).sum(-1).sqrt()  # (N, M)
281 |     D_ij.ranges = diagonal_ranges(batch_queries, batch_source)  # restrict to same-protein comparisons
282 |     nn_i = D_ij.argmin(dim=1).view(-1)  # (N,)
283 |     nn_dist_i = (
284 |         D_ij.min(dim=1).view(-1, 1) < threshold
285 |     ).float()  # If chain is not connected because of missing densities MaSIF cut out a part of the protein
286 |     query_labels = labels[nn_i] * nn_dist_i  # zero out labels of points too far from any mesh vertex
287 |     P["labels"] = query_labels  # mutates P in place; no return value
288 | 
289 | class dMaSIF(nn.Module):
290 |     def __init__(self, args):
291 |         super(dMaSIF, self).__init__()
292 |         # Additional geometric features: mean and Gauss curvatures computed at different scales.
293 |         self.curvature_scales = args.curvature_scales
294 |         self.args = args
295 | 
296 |         I = args.in_channels  # input feature dim (curvatures + chemical features)
297 |         O = args.orientation_units  # hidden units of the orientation-score MLP
298 |         E = args.emb_dims  # embedding dim
299 |         H = args.post_units  # hidden units of the site-prediction head
300 | 
301 |         # Computes chemical features
302 |         self.atomnet = AtomNet_MP(args)
303 |         self.dropout = nn.Dropout(args.dropout)
304 | 
305 |         if args.embedding_layer == "dMaSIF":
306 |             # Post-processing, without batch norm:
307 |             self.orientation_scores = nn.Sequential(
308 |                 nn.Linear(I, O),
309 |                 nn.LeakyReLU(negative_slope=0.2),
310 |                 nn.Linear(O, 1),
311 |             )
312 | 
313 |             # Segmentation network:
314 |             self.conv = dMaSIFConv_seg(
315 |                 args,
316 |                 in_channels=I,
317 |                 out_channels=E,
318 |                 n_layers=args.n_layers,
319 |                 radius=args.radius,
320 |             )
321 | 
322 |             # Asymmetric embedding
323 |             if args.search:  # search task uses a second, independent tower for the partner protein
324 |                 self.orientation_scores2 = nn.Sequential(
325 |                     nn.Linear(I, O),
326 |                     nn.LeakyReLU(negative_slope=0.2),
327 |                     nn.Linear(O, 1),
328 |                 )
329 | 
330 |                 self.conv2 = dMaSIFConv_seg(
331 |                     args,
332 |                     in_channels=I,
333 |                     out_channels=E,
334 |                     n_layers=args.n_layers,
335 |                     radius=args.radius,
336 |                 )
337 | 
338 |         elif args.embedding_layer == "DGCNN":
339 |             self.conv = DGCNN_seg(I + 3, E,self.args.n_layers,self.args.k)
340 |             if args.search:
341 |                 self.conv2 = DGCNN_seg(I + 3, E,self.args.n_layers,self.args.k)
342 | 
343 |         elif args.embedding_layer == "PointNet++":
344 |             self.conv =
 PointNet2_seg(args, I, E)
345 |             if args.search:
346 |                 self.conv2 = PointNet2_seg(args, I, E)
347 | 
348 |         if args.site:
349 |             # Post-processing, without batch norm:
350 |             self.net_out = nn.Sequential(
351 |                 nn.Linear(E, H),
352 |                 nn.LeakyReLU(negative_slope=0.2),
353 |                 nn.Linear(H, H),
354 |                 nn.LeakyReLU(negative_slope=0.2),
355 |                 nn.Linear(H, 1),
356 |             )
357 | 
358 |     def features(self, P, i=1):  # NOTE(review): parameter `i` is unused in this body
359 |         """Estimates geometric and chemical features from a protein surface or a cloud of atoms."""
360 |         if (
361 |             not self.args.use_mesh and "xyz" not in P
362 |         ):  # Compute the pseudo-surface directly from the atoms
363 |             # (Note that we use the fact that dicts are "passed by reference" here)
364 |             P["xyz"], P["normals"], P["batch"] = atoms_to_points_normals(
365 |                 P["atoms"],
366 |                 P["batch_atoms"],
367 |                 atomtypes=P["atomtypes"],
368 |                 resolution=self.args.resolution,
369 |                 sup_sampling=self.args.sup_sampling,
370 |             )
371 | 
372 |         # Estimate the curvatures using the triangles or the estimated normals:
373 |         P_curvatures = curvatures(
374 |             P["xyz"],
375 |             triangles=P["triangles"] if self.args.use_mesh else None,
376 |             normals=None if self.args.use_mesh else P["normals"],
377 |             scales=self.curvature_scales,
378 |             batch=P["batch"],
379 |         )
380 | 
381 |         # Compute chemical features on-the-fly:
382 |         chemfeats = self.atomnet(
383 |             P["xyz"], P["atom_xyz"], P["atomtypes"], P["batch"], P["batch_atoms"]
384 |         )
385 | 
386 |         if self.args.no_chem:  # ablation: zero out chemical features
387 |             chemfeats = 0.0 * chemfeats
388 |         if self.args.no_geom:  # ablation: zero out geometric features
389 |             P_curvatures = 0.0 * P_curvatures
390 | 
391 |         # Concatenate our features:
392 |         return torch.cat([P_curvatures, chemfeats], dim=1).contiguous()
393 | 
394 |     def embed(self, P):
395 |         """Embeds all points of a protein in a high-dimensional vector space."""
396 | 
397 |         features = self.dropout(self.features(P))
398 |         P["input_features"] = features
399 | 
400 |         torch.cuda.synchronize(device=features.device)  # make the timing below meaningful on GPU
401 |         torch.cuda.reset_max_memory_allocated(device=P["atoms"].device)  # NOTE(review): deprecated in newer PyTorch in favor of reset_peak_memory_stats
402
 |         begin = time.time()
403 | 
404 |         # Ours:
405 |         if self.args.embedding_layer == "dMaSIF":
406 |             self.conv.load_mesh(  # bind geometry + learned orientation weights before convolving
407 |                 P["xyz"],
408 |                 triangles=P["triangles"] if self.args.use_mesh else None,
409 |                 normals=None if self.args.use_mesh else P["normals"],
410 |                 weights=self.orientation_scores(features),
411 |                 batch=P["batch"],
412 |             )
413 |             P["embedding_1"] = self.conv(features)
414 |             if self.args.search:
415 |                 self.conv2.load_mesh(
416 |                     P["xyz"],
417 |                     triangles=P["triangles"] if self.args.use_mesh else None,
418 |                     normals=None if self.args.use_mesh else P["normals"],
419 |                     weights=self.orientation_scores2(features),
420 |                     batch=P["batch"],
421 |                 )
422 |                 P["embedding_2"] = self.conv2(features)
423 | 
424 |         # First baseline:
425 |         elif self.args.embedding_layer == "DGCNN":
426 |             features = torch.cat([features, P["xyz"]], dim=-1).contiguous()  # DGCNN also consumes raw coordinates
427 |             P["embedding_1"] = self.conv(P["xyz"], features, P["batch"])
428 |             if self.args.search:
429 |                 P["embedding_2"] = self.conv2(
430 |                     P["xyz"], features, P["batch"]
431 |                 )
432 | 
433 |         # Second baseline
434 |         elif self.args.embedding_layer == "PointNet++":
435 |             P["embedding_1"] = self.conv(P["xyz"], features, P["batch"])
436 |             if self.args.search:
437 |                 P["embedding_2"] = self.conv2(P["xyz"], features, P["batch"])
438 | 
439 |         torch.cuda.synchronize(device=features.device)
440 |         end = time.time()
441 |         memory_usage = torch.cuda.max_memory_allocated(device=P["atoms"].device)
442 |         conv_time = end - begin  # wall-clock time of the convolutional backbone only
443 | 
444 |         return conv_time, memory_usage
445 | 
446 |     def preprocess_surface(self, P):
447 |         P["xyz"], P["normals"], P["batch"] = atoms_to_points_normals(
448 |             P["atoms"],
449 |             P["batch_atoms"],
450 |             atomtypes=P["atomtypes"],
451 |             resolution=self.args.resolution,
452 |             sup_sampling=self.args.sup_sampling,
453 |             distance=self.args.distance,
454 |         )
455 |         if P['mesh_labels'] is not None:  # only site-labelled data carries mesh labels
456 |             project_iface_labels(P)
457 | 
458 |     def forward(self, P1, P2=None):
459 |         # Compute embeddings of the point clouds:
460 |         if P2 is
 not None:  # search task: embed both partners as one combined batch
461 |             P1P2 = combine_pair(P1, P2)
462 |         else:
463 |             P1P2 = P1
464 | 
465 |         conv_time, memory_usage = self.embed(P1P2)  # mutates P1P2: adds input_features / embedding_1 (/ embedding_2)
466 | 
467 |         # Monitor the approximate rank of our representations:
468 |         R_values = {}
469 |         R_values["input"] = soft_dimension(P1P2["input_features"])
470 |         R_values["conv"] = soft_dimension(P1P2["embedding_1"])
471 | 
472 |         if self.args.site:
473 |             P1P2["iface_preds"] = self.net_out(P1P2["embedding_1"])  # per-point interface logits from the first embedding only
474 | 
475 |         if P2 is not None:
476 |             P1, P2 = split_pair(P1P2)
477 |         else:
478 |             P1 = P1P2
479 | 
480 |         return {
481 |             "P1": P1,
482 |             "P2": P2,
483 |             "R_values": R_values,
484 |             "conv_time": conv_time,
485 |             "memory_usage": memory_usage,
486 |         }
487 | 
--------------------------------------------------------------------------------
/models/dMaSIF_search_3layer_12A_16dim:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreyrS/dMaSIF/0dcc26c3c218a39d5fe26beb2e788b95fb028896/models/dMaSIF_search_3layer_12A_16dim
--------------------------------------------------------------------------------
/overview.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreyrS/dMaSIF/0dcc26c3c218a39d5fe26beb2e788b95fb028896/overview.PNG
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.11.0
2 | appdirs==1.4.4
3 | argon2-cffi==20.1.0
4 | ase==3.20.1
5 | async-generator==1.10
6 | attrs==20.3.0
7 | backcall==0.2.0
8 | biopython==1.78
9 | black==20.8b1
10 | bleach==3.2.1
11 | cached-property==1.5.2
12 | cachetools==4.1.1
13 | certifi==2020.6.20
14 | cffi==1.14.3
15 | chardet==3.0.4
16 | click==7.1.2
17 | cloudpickle==1.6.0
18 | cycler==0.10.0
19 | dask==2020.12.0
20 | decorator==4.4.2
21 | defusedxml==0.6.0
22 | entrypoints==0.3
23 | future==0.18.2
24 | google-auth==1.23.0
25 |
google-auth-oauthlib==0.4.2 26 | googledrivedownloader==0.4 27 | GPUtil==1.4.0 28 | grpcio==1.33.2 29 | h5py==3.0.0 30 | idna==2.10 31 | importlib-metadata==2.0.0 32 | ipykernel==5.3.4 33 | ipython==7.19.0 34 | ipython-genutils==0.2.0 35 | ipywidgets==7.5.1 36 | isodate==0.6.0 37 | jedi==0.17.2 38 | Jinja2==2.11.2 39 | joblib==0.17.0 40 | jsonschema==3.2.0 41 | jupyter==1.0.0 42 | jupyter-client==6.1.7 43 | jupyter-console==6.2.0 44 | jupyter-core==4.6.3 45 | jupyterlab-pygments==0.1.2 46 | kaleido==0.0.3.post1 47 | kiwisolver==1.3.1 48 | llvmlite==0.34.0 49 | Markdown==3.3.3 50 | MarkupSafe==1.1.1 51 | matplotlib==3.3.2 52 | mistune==0.8.4 53 | mypy-extensions==0.4.3 54 | nbclient==0.5.1 55 | nbconvert==6.0.7 56 | nbformat==5.0.8 57 | nest-asyncio==1.4.3 58 | networkx==2.5 59 | notebook==6.1.5 60 | numba==0.51.2 61 | numpy==1.19.3 62 | oauthlib==3.1.0 63 | packaging==20.4 64 | pandas==1.1.4 65 | pandocfilters==1.4.3 66 | parso==0.7.1 67 | pathspec==0.8.1 68 | pexpect==4.8.0 69 | pickleshare==0.7.5 70 | Pillow==8.0.1 71 | plotly==4.13.0 72 | plyfile==0.7.2 73 | prometheus-client==0.8.0 74 | prompt-toolkit==3.0.8 75 | protobuf==3.13.0 76 | ptyprocess==0.6.0 77 | pyasn1==0.4.8 78 | pyasn1-modules==0.2.8 79 | pycparser==2.20 80 | Pygments==2.7.2 81 | pykeops==1.4.1 82 | pyparsing==2.4.7 83 | pyrsistent==0.17.3 84 | python-dateutil==2.8.1 85 | pytz==2020.4 86 | PyVTK==0.5.18 87 | PyYAML==5.3.1 88 | pyzmq==19.0.2 89 | qtconsole==4.7.7 90 | QtPy==1.9.0 91 | rdflib==5.0.0 92 | regex==2020.11.13 93 | requests==2.24.0 94 | requests-oauthlib==1.3.0 95 | retrying==1.3.3 96 | rope==0.18.0 97 | rsa==4.6 98 | scikit-learn==0.23.2 99 | scipy==1.5.3 100 | seaborn==0.11.0 101 | Send2Trash==1.5.0 102 | six==1.15.0 103 | tensorboard==2.3.0 104 | tensorboard-plugin-wit==1.7.0 105 | terminado==0.9.1 106 | testpath==0.4.4 107 | threadpoolctl==2.1.0 108 | toml==0.10.2 109 | toolz==0.11.1 110 | torch==1.6.0 111 | torch-cluster==1.5.8 112 | torch-geometric==1.6.1 113 | torch-scatter==2.0.5 
114 | torch-sparse==0.6.8 115 | torch-spline-conv==1.2.0 116 | tornado==6.1 117 | tqdm==4.51.0 118 | traitlets==5.0.5 119 | typed-ast==1.4.1 120 | typing-extensions==3.7.4.3 121 | urllib3==1.25.11 122 | wcwidth==0.2.5 123 | webencodings==0.5.1 124 | Werkzeug==1.0.1 125 | widgetsnbextension==3.5.1 126 | zipp==3.4.0 127 | --------------------------------------------------------------------------------