├── .gitattributes
├── .gitignore
├── PSICHIC
│   ├── LICENSE
│   ├── config.json
│   ├── models
│   │   ├── .DS_Store
│   │   ├── drug_pool.py
│   │   ├── layers.py
│   │   ├── mini_net.py
│   │   ├── net.py
│   │   ├── pna.py
│   │   ├── protein_pool.py
│   │   └── scaler.py
│   ├── psichic_utils
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   ├── dataset.py
│   │   ├── interpretation.py
│   │   ├── ligand_init.py
│   │   ├── metrics.py
│   │   └── protein_init.py
│   ├── runtime_config.py
│   ├── trained_weights
│   │   ├── .DS_Store
│   │   ├── PDBv2020_PSICHIC
│   │   │   ├── config.json
│   │   │   ├── degree copy.pt
│   │   │   ├── degree.pt
│   │   │   └── model.pt
│   │   ├── TREAT1
│   │   │   ├── config.json
│   │   │   ├── degree.pt
│   │   │   └── model.pt
│   │   └── multitask_PSICHIC
│   │       ├── config.json
│   │       ├── degree.pt
│   │       └── model.pt
│   └── wrapper.py
├── README.md
├── auto_updater.py
├── btdr.py
├── config
│   ├── config.yaml
│   └── config_loader.py
├── example.env
├── install_deps_cpu.sh
├── install_deps_cu124.sh
├── my_utils.py
├── neurons
│   ├── miner.py
│   ├── set_weight_to_uid.py
│   └── validator.py
└── requirements
    └── requirements.txt
/.gitattributes:
--------------------------------------------------------------------------------
1 | # .gitattributes 2 | 3 | # Treat all text files as text and ensure LF 4 | * text eol=lf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/ 2 | .* 3 | !/.gitignore 4 | venv 5 | *.log 6 | *.log.* 7 |
--------------------------------------------------------------------------------
/PSICHIC/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 - Huan Yee Koh 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /PSICHIC/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "tasks": { 3 | "regression_task": false, 4 | "classification_task": false, 5 | "mclassification_task": 0 6 | }, 7 | "optimizer": { 8 | "lrate": 0.0001, 9 | "weight_decay": 0.0001, 10 | "clip": 1, 11 | "betas": [ 12 | 0.9, 13 | 0.999 14 | ], 15 | "eps": 1e-08, 16 | "schedule_lr": false, 17 | "min_lrate": 0, 18 | "warmup_iters": 0, 19 | "lr_decay_iters": 0, 20 | "amsgrad": false 21 | }, 22 | "params": { 23 | "mol_in_channels": 43, 24 | "prot_in_channels": 33, 25 | "prot_evo_channels": 1280, 26 | "hidden_channels": 200, 27 | "aggregators": [ 28 | "mean", 29 | "min", 30 | "max", 31 | "std" 32 | ], 33 | "scalers": [ 34 | "identity", 35 | "amplification", 36 | "linear" 37 | ], 38 | "pre_layers": 2, 39 | "post_layers": 1, 40 | "total_layer": 3, 41 | "K": [ 42 | 5, 43 | 10, 44 | 20 45 | ], 46 | "dropout": 0, 47 | "dropout_attn_score": 0.2, 48 | "heads": 5 49 | } 50 | } -------------------------------------------------------------------------------- /PSICHIC/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/models/.DS_Store -------------------------------------------------------------------------------- /PSICHIC/models/drug_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.utils import softmax 4 | 5 | from torch_scatter import scatter 6 | from torch_geometric.nn import global_add_pool 7 | from .layers import MLP, dropout_node 8 | 9 | class MotifPool(torch.nn.Module): 10 | def __init__(self, hidden_dim, heads, dropout_attn_score=0, dropout_node_proba=0): 11 | super().__init__() 12 | assert hidden_dim % heads == 0 13 | 14 | self.lin_proj = torch.nn.Linear(hidden_dim, hidden_dim) 15 | hidden_dim = hidden_dim // heads 16 | 17 | self.score_proj = torch.nn.ModuleList() 18 | for _ in range(heads): 19 | self.score_proj.append( MLP([ hidden_dim, hidden_dim*2, 1]) ) 20 | 21 | self.heads = heads 22 | self.hidden_dim = hidden_dim 23 | self.dropout_node_proba = dropout_node_proba 24 | self.dropout_attn_score = dropout_attn_score 25 | 26 | def reset_parameters(self): 27 | self.lin_proj.reset_parameters() 28 | for m in self.score_proj: 29 | m.reset_parameters() 30 | 31 | def forward(self, x, x_clique, atom2clique_index, clique_batch, clique_edge_index): 32 | row, col = atom2clique_index 33 | H = self.heads 34 | C = self.hidden_dim 35 | ## residual connection + atom2clique 36 | hx_clique = scatter(x[row], col, dim=0, dim_size=x_clique.size(0), reduce='mean') 37 | x_clique = x_clique + F.relu(self.lin_proj(hx_clique)) 38 | ## GNN scoring 39 | score_clique = x_clique.view(-1, H, C) 40 | score = torch.cat([ mlp(score_clique[:, i]) for i, mlp in enumerate(self.score_proj) ], dim=-1) 41 | score = F.dropout(score, p=self.dropout_attn_score, training=self.training) 42 | alpha = softmax(score, clique_batch) 43 | 44 | ## multihead aggregation of drug feature 45 | scaling_factor = 1. 46 | _, _, clique_drop_mask = dropout_node(clique_edge_index, self.dropout_node_proba, x_clique.size(0), clique_batch, self.training) 47 | scaling_factor = 1. / (1. 
- self.dropout_node_proba) 48 | 49 | drug_feat = x_clique.view(-1, H, C) * alpha.view(-1, H, 1) 50 | drug_feat = drug_feat.view(-1, H * C) * clique_drop_mask.view(-1,1) 51 | drug_feat = global_add_pool(drug_feat, clique_batch) * scaling_factor 52 | 53 | return drug_feat, x_clique, alpha 54 | -------------------------------------------------------------------------------- /PSICHIC/models/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .pna import PNAConv 7 | import torch 8 | 9 | from torch import Tensor 10 | from torch_sparse import SparseTensor 11 | from torch_geometric.nn import global_add_pool 12 | from torch_geometric.nn.conv import MessagePassing, GCNConv, SAGEConv, APPNP, SGConv 13 | from torch_geometric.nn.dense.linear import Linear 14 | from torch_geometric.typing import Adj, OptTensor, PairTensor 15 | from torch_geometric.utils import softmax, degree, subgraph, to_scipy_sparse_matrix, segregate_self_loops, add_remaining_self_loops 16 | import numpy as np 17 | import scipy.sparse as sp 18 | 19 | 20 | 21 | class SGCluster(torch.nn.Module): 22 | def __init__(self, in_dim, out_dim, K, in_norm=False): #L=nb_hidden_layers 23 | super().__init__() 24 | self.sgc = SGConv(in_dim, out_dim, K=K) 25 | self.in_norm = in_norm 26 | if self.in_norm: 27 | self.in_ln = nn.LayerNorm(in_dim) 28 | 29 | def reset_parameters(self): 30 | self.sgc.reset_parameters() 31 | if self.in_norm: 32 | self.in_ln.reset_parameters() 33 | 34 | def forward(self, x, edge_index): 35 | y = x 36 | # Input Layer Norm 37 | if self.in_norm: 38 | y = self.in_ln(y) 39 | 40 | y = self.sgc(y, edge_index) 41 | 42 | return y 43 | 44 | class APPNPCluster(torch.nn.Module): 45 | def __init__(self, in_dim, out_dim, a, K, in_norm=False): #L=nb_hidden_layers 46 | super().__init__() 47 | self.lin = torch.nn.Linear(in_dim, out_dim) 48 | self.propagate = APPNP(alpha=a, K=K, dropout=0) 49 | self.in_norm = in_norm 50 | if self.in_norm: 51 | self.in_ln = nn.LayerNorm(in_dim) 52 | 53 | def reset_parameters(self): 54 | self.lin.reset_parameters() 55 | if self.in_norm: 56 | self.in_ln.reset_parameters() 57 | 58 | def forward(self, x, edge_index): 59 | y = x 60 | # Input Layer Norm 61 | if self.in_norm: 62 | y = self.in_ln(y) 63 | y = self.lin(y) 64 | 65 | y = self.propagate(y, edge_index) 66 | 67 | return y 68 | 69 | class GCNCluster(torch.nn.Module): 70 | def __init__(self, dims, out_norm=False, in_norm=False): #L=nb_hidden_layers 71 | super().__init__() 72 | list_Conv_layers = [ GCNConv(dims[idx-1], dims[idx]) for idx in range(1,len(dims)) ] 73 | self.Conv_layers = nn.ModuleList(list_Conv_layers) 74 | self.hidden_layers = len(dims) - 2 75 | 76 | self.out_norm = out_norm 77 | self.in_norm = in_norm 78 | 79 | if self.out_norm: 80 | self.out_ln = nn.LayerNorm(dims[-1]) 81 | if self.in_norm: 82 | self.in_ln = nn.LayerNorm(dims[0]) 83 | 84 | def reset_parameters(self): 85 | for idx in range(self.hidden_layers+1): 86 | self.Conv_layers[idx].reset_parameters() 87 | if self.out_norm: 88 | self.out_ln.reset_parameters() 89 | if self.in_norm: 90 | self.in_ln.reset_parameters() 91 | 92 | def forward(self, x, edge_index): 93 | y = x 94 | # Input Layer Norm 95 | if self.in_norm: 96 | y = self.in_ln(y) 97 | 98 | for idx in range(self.hidden_layers): 99 | y = self.Conv_layers[idx](y, edge_index) 100 | y = F.relu(y) 101 | y = self.Conv_layers[-1](y, edge_index) 102 | 103 | if 
self.out_norm: 104 | y = self.out_ln(y) 105 | 106 | return y 107 | 108 | class SAGECluster(torch.nn.Module): 109 | def __init__(self, dims, in_norm=False, add_self_loops=True, root_weight=False, 110 | normalize=False, temperature=False): #L=nb_hidden_layers 111 | super().__init__() 112 | list_Conv_layers = [ SAGEConv(dims[idx-1], dims[idx], root_weight=root_weight) for idx in range(1,len(dims)) ] 113 | self.Conv_layers = nn.ModuleList(list_Conv_layers) 114 | self.hidden_layers = len(dims) - 2 115 | 116 | self.in_norm = in_norm 117 | self.temperature = temperature 118 | self.normalize = normalize 119 | 120 | if self.temperature: 121 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 122 | 123 | if self.in_norm: 124 | self.in_ln = nn.LayerNorm(dims[0]) 125 | 126 | self.add_self_loops = add_self_loops 127 | 128 | def reset_parameters(self): 129 | for idx in range(self.hidden_layers+1): 130 | self.Conv_layers[idx].reset_parameters() 131 | if self.in_norm: 132 | self.in_ln.reset_parameters() 133 | 134 | def forward(self, x, edge_index): 135 | if self.add_self_loops: 136 | edge_index, _ = add_remaining_self_loops(edge_index=edge_index, num_nodes=x.size(0)) 137 | y = x 138 | # Input Layer Norm 139 | if self.in_norm: 140 | y = self.in_ln(y) 141 | 142 | for idx in range(self.hidden_layers): 143 | y = self.Conv_layers[idx](y, edge_index) 144 | y = F.relu(y) 145 | y = self.Conv_layers[-1](y, edge_index) 146 | 147 | if self.normalize: 148 | y = F.normalize(y, p=2., dim=-1) 149 | 150 | if self.temperature: 151 | logit_scale = self.logit_scale.exp() 152 | y = y * logit_scale 153 | 154 | return y 155 | 156 | class AtomEncoder(torch.nn.Module): 157 | def __init__(self, hidden_channels): 158 | super(AtomEncoder, self).__init__() 159 | 160 | self.embeddings = torch.nn.ModuleList() 161 | 162 | for i in range(9): 163 | self.embeddings.append(torch.nn.Embedding(100, hidden_channels)) 164 | 165 | def reset_parameters(self): 166 | for embedding in self.embeddings: 167 | embedding.reset_parameters() 168 | 169 | def forward(self, x): 170 | if x.dim() == 1: 171 | x = x.unsqueeze(1) 172 | 173 | out = 0 174 | for i in range(x.size(1)): 175 | out += self.embeddings[i](x[:, i]) 176 | return out 177 | 178 | 179 | class BondEncoder(torch.nn.Module): 180 | def __init__(self, hidden_channels): 181 | super(BondEncoder, self).__init__() 182 | 183 | self.embeddings = torch.nn.ModuleList() 184 | 185 | for i in range(3): 186 | self.embeddings.append(torch.nn.Embedding(10, hidden_channels)) 187 | 188 | def reset_parameters(self): 189 | for embedding in self.embeddings: 190 | embedding.reset_parameters() 191 | 192 | def forward(self, edge_attr): 193 | if edge_attr.dim() == 1: 194 | edge_attr = edge_attr.unsqueeze(1) 195 | 196 | out = 0 197 | for i in range(edge_attr.size(1)): 198 | out += self.embeddings[i](edge_attr[:, i]) 199 | return out 200 | 201 | 202 | class PosLinear(nn.Module): 203 | __constants__ = ['in_features', 'out_features'] 204 | in_features: int 205 | out_features: int 206 | weight: Tensor 207 | 208 | def __init__(self, in_features: int, out_features: int, bias: bool = True, init_value=0.2, 209 | device=None, dtype=None) -> None: 210 | factory_kwargs = {'device': device, 'dtype': dtype} 211 | super(PosLinear, self).__init__() 212 | self.in_features = in_features 213 | self.out_features = out_features 214 | # center_value = init_value 215 | # lower_bound = center_value - center_value/10 216 | # upper_bound = center_value + center_value/10 217 | 218 | lower_bound = init_value/2 219 | upper_bound = 
init_value 220 | weight = nn.init.uniform_(torch.empty((out_features, in_features),**factory_kwargs), a=lower_bound, b=upper_bound) 221 | # weight = nn.init.kaiming_uniform_(torch.empty((out_features, in_features),**factory_kwargs), a=math.sqrt(5)) 222 | weight = torch.abs(weight) 223 | self.weight = nn.Parameter(weight.log()) 224 | # self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) 225 | if bias: 226 | self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) 227 | else: 228 | self.register_parameter('bias', None) 229 | self.reset_parameters() 230 | 231 | 232 | 233 | def reset_parameters(self) -> None: 234 | # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 235 | # self.weight = torch.abs(self.weight).log() 236 | if self.bias is not None: 237 | nn.init.uniform_(self.bias) 238 | 239 | def forward(self, input: Tensor) -> Tensor: 240 | return F.linear(input, self.weight.exp(), self.bias) 241 | 242 | def extra_repr(self) -> str: 243 | return 'in_features={}, out_features={}, bias={}'.format( 244 | self.in_features, self.out_features, self.bias is not None 245 | ) 246 | 247 | 248 | class MLP(nn.Module): 249 | 250 | def __init__(self, dims, out_norm=False, in_norm=False, bias=True): #L=nb_hidden_layers 251 | super().__init__() 252 | list_FC_layers = [ nn.Linear(dims[idx-1], dims[idx], bias=bias) for idx in range(1,len(dims)) ] 253 | self.FC_layers = nn.ModuleList(list_FC_layers) 254 | self.hidden_layers = len(dims) - 2 255 | 256 | self.out_norm = out_norm 257 | self.in_norm = in_norm 258 | 259 | if self.out_norm: 260 | self.out_ln = nn.LayerNorm(dims[-1]) 261 | if self.in_norm: 262 | self.in_ln = nn.LayerNorm(dims[0]) 263 | 264 | def reset_parameters(self): 265 | for idx in range(self.hidden_layers+1): 266 | self.FC_layers[idx].reset_parameters() 267 | if self.out_norm: 268 | self.out_ln.reset_parameters() 269 | if self.in_norm: 270 | self.in_ln.reset_parameters() 271 | 272 | def forward(self, x): 273 | y = x 274 | # Input Layer Norm 275 | if self.in_norm: 276 | y = self.in_ln(y) 277 | 278 | for idx in range(self.hidden_layers): 279 | y = self.FC_layers[idx](y) 280 | y = F.relu(y) 281 | y = self.FC_layers[-1](y) 282 | 283 | if self.out_norm: 284 | y = self.out_ln(y) 285 | 286 | return y 287 | 288 | class Drug_PNAConv(nn.Module): 289 | def __init__(self, mol_deg, hidden_channels, edge_channels, 290 | pre_layers=2, post_layers=2, 291 | aggregators=['sum', 'mean', 'min', 'max', 'std'], 292 | scalers=['identity', 'amplification', 'attenuation'], 293 | num_towers=4, 294 | dropout=0.1): 295 | super(Drug_PNAConv, self).__init__() 296 | 297 | self.bond_encoder = torch.nn.Embedding(5, hidden_channels) 298 | 299 | self.atom_conv = PNAConv( 300 | in_channels=hidden_channels, out_channels=hidden_channels, 301 | edge_dim=edge_channels, aggregators=aggregators, 302 | scalers=scalers, deg=mol_deg, pre_layers=pre_layers, 303 | post_layers=post_layers,towers=num_towers,divide_input=True, 304 | ) 305 | self.atom_norm = torch.nn.LayerNorm(hidden_channels) 306 | 307 | self.dropout = dropout 308 | 309 | def reset_parameters(self): 310 | self.atom_conv.reset_parameters() 311 | self.atom_norm.reset_parameters() 312 | 313 | 314 | def forward(self, atom_x, bond_x, atom_edge_index): 315 | atom_in = atom_x 316 | bond_x = self.bond_encoder(bond_x.squeeze()) 317 | atom_x = atom_in + F.relu(self.atom_norm(self.atom_conv(atom_x, atom_edge_index, bond_x))) 318 | atom_x = F.dropout(atom_x, self.dropout, training=self.training) 319 | 320 | return atom_x 321 | 322 | 323 | 
class Protein_PNAConv(nn.Module): 324 | def __init__(self, prot_deg, hidden_channels, edge_channels, 325 | pre_layers=2, post_layers=2, 326 | aggregators=['sum', 'mean', 'min', 'max', 'std'], 327 | scalers=['identity', 'amplification', 'attenuation'], 328 | num_towers=4, 329 | dropout=0.1): 330 | super(Protein_PNAConv, self).__init__() 331 | 332 | self.conv = PNAConv(in_channels=hidden_channels, 333 | out_channels=hidden_channels, 334 | edge_dim=edge_channels, 335 | aggregators=aggregators, 336 | scalers=scalers, 337 | deg=prot_deg, 338 | pre_layers=pre_layers, 339 | post_layers=post_layers, 340 | towers=num_towers, 341 | divide_input=True, 342 | ) 343 | 344 | self.norm = torch.nn.LayerNorm(hidden_channels) 345 | self.dropout = dropout 346 | 347 | def reset_parameters(self): 348 | self.conv.reset_parameters() 349 | self.norm.reset_parameters() 350 | 351 | def forward(self, x, prot_edge_index, prot_edge_attr): 352 | x_in = x 353 | x = x_in + F.relu(self.norm(self.conv(x, prot_edge_index, prot_edge_attr))) 354 | x = F.dropout(x, self.dropout, training=self.training) 355 | 356 | return x 357 | 358 | 359 | class DrugProteinConv(MessagePassing): 360 | 361 | _alpha: OptTensor 362 | 363 | def __init__( 364 | self, 365 | atom_channels: int, 366 | residue_channels: int, 367 | heads: int = 1, 368 | t = 0.2, 369 | dropout_attn_score = 0.2, 370 | edge_dim: Optional[int] = None, 371 | **kwargs, 372 | ): 373 | kwargs.setdefault('aggr', 'add') 374 | super(DrugProteinConv, self).__init__(node_dim=0, **kwargs) 375 | 376 | assert residue_channels%heads == 0 377 | assert atom_channels%heads == 0 378 | 379 | self.residue_out_channels = residue_channels//heads 380 | self.atom_out_channels = atom_channels//heads 381 | self.heads = heads 382 | self.edge_dim = edge_dim 383 | self._alpha = None 384 | 385 | ## Protein Residue -> Drug Atom 386 | self.lin_key = nn.Linear(residue_channels, heads * self.atom_out_channels, bias=False) 387 | self.lin_query = nn.Linear(atom_channels, heads * self.atom_out_channels, bias=False) 388 | self.lin_value = nn.Linear(residue_channels, heads * self.atom_out_channels, bias=False) 389 | if edge_dim is not None: 390 | self.lin_edge = nn.Linear(edge_dim, heads * self.atom_out_channels, bias=False) 391 | else: 392 | self.lin_edge = self.register_parameter('lin_edge', None) 393 | 394 | ## Drug Atom -> Protein Residue 395 | self.lin_atom_value = nn.Linear(atom_channels, heads * self.residue_out_channels, bias=False) 396 | 397 | ## Normalization 398 | self.drug_in_norm = torch.nn.LayerNorm(atom_channels) 399 | self.residue_in_norm = torch.nn.LayerNorm(residue_channels) 400 | 401 | self.drug_out_norm = torch.nn.LayerNorm(heads * self.atom_out_channels) 402 | self.residue_out_norm = torch.nn.LayerNorm(heads * self.residue_out_channels) 403 | ## MLP 404 | self.clique_mlp = MLP([atom_channels*2, atom_channels*2, atom_channels], out_norm=True) 405 | self.residue_mlp = MLP([residue_channels*2, residue_channels*2, residue_channels], out_norm=True) 406 | ## temperature 407 | self.t = t 408 | # self.logit_scale = nn.Parameter(torch.ones([])) # * np.log(1 / 0.07)) 409 | 410 | ## masking attention rate 411 | self.dropout_attn_score = dropout_attn_score 412 | 413 | def reset_parameters(self): 414 | self.lin_key.reset_parameters() 415 | self.lin_query.reset_parameters() 416 | self.lin_value.reset_parameters() 417 | if self.edge_dim: 418 | self.lin_edge.reset_parameters() 419 | # Drug -> Protein 420 | self.lin_atom_value.reset_parameters() 421 | ### normalization 422 | 
self.drug_in_norm.reset_parameters() 423 | self.residue_in_norm.reset_parameters() 424 | self.drug_out_norm.reset_parameters() 425 | self.residue_out_norm.reset_parameters() 426 | 427 | # MLP update 428 | self.clique_mlp.reset_parameters() 429 | self.residue_mlp.reset_parameters() 430 | 431 | def forward(self, drug_x, clique_x, clique_batch, residue_x, edge_index: Adj): 432 | 433 | # Protein Residue -> Drug Atom 434 | H, aC = self.heads, self.atom_out_channels 435 | residue_hx = self.residue_in_norm(residue_x) ## normalization 436 | query = self.lin_query(drug_x).view(-1, H, aC) 437 | key = self.lin_key(residue_hx).view(-1, H, aC) 438 | value = self.lin_value(residue_hx).view(-1, H, aC) 439 | 440 | # propagate_type: (query: Tensor, key:Tensor, value: Tensor, edge_attr: OptTensor) # noqa 441 | drug_out = self.propagate(edge_index, query=query, key=key, value=value, 442 | edge_attr=None, size=None) 443 | alpha = self._alpha 444 | self._alpha = None 445 | 446 | drug_out = drug_out.view(-1, H * aC) 447 | drug_out = self.drug_out_norm(drug_out) 448 | clique_out = torch.cat([clique_x, drug_out[clique_batch]], dim=-1) 449 | clique_out = self.clique_mlp(clique_out) 450 | 451 | # Drug Atom -> Protein Residue 452 | H, rC = self.heads, self.residue_out_channels 453 | drug_hx = self.drug_in_norm(drug_x) ## normalization 454 | residue_value = self.lin_atom_value(drug_hx).view(-1, H, rC)[edge_index[1]] 455 | residue_out = residue_value * alpha.view(-1, H, 1) 456 | residue_out = residue_out.view(-1, H * rC) 457 | residue_out = self.residue_out_norm(residue_out) 458 | residue_out = torch.cat([residue_out, residue_x], dim=-1) 459 | residue_out = self.residue_mlp(residue_out) 460 | 461 | return clique_out, residue_out, (edge_index, alpha) 462 | 463 | 464 | def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor, 465 | edge_attr: OptTensor, index: Tensor, ptr: OptTensor, 466 | size_i: Optional[int]) -> Tensor: 467 | 468 | alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.atom_out_channels) 469 | alpha = alpha / self.t ## temperature 470 | # logit_scale = self.logit_scale.exp() 471 | # alpha = alpha * logit_scale 472 | 473 | alpha = F.dropout(alpha, p=self.dropout_attn_score, training=self.training) 474 | alpha = softmax(alpha , index, ptr, size_i) 475 | self._alpha = alpha 476 | 477 | out = value_j 478 | out = out * alpha.view(-1, self.heads, 1) 479 | 480 | return out 481 | 482 | 483 | def unbatch(src, batch, dim: int = 0): 484 | r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension 485 | :obj:`dim`. 486 | 487 | Args: 488 | src (Tensor): The source tensor. 489 | batch (LongTensor): The batch vector 490 | :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each 491 | entry in :obj:`src` to a specific example. Must be ordered. 492 | dim (int, optional): The dimension along which to split the :obj:`src` 493 | tensor. (default: :obj:`0`) 494 | 495 | :rtype: :class:`List[Tensor]` 496 | 497 | Example: 498 | 499 | >>> src = torch.arange(7) 500 | >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2]) 501 | >>> unbatch(src, batch) 502 | (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) 503 | """ 504 | sizes = degree(batch, dtype=torch.long).tolist() 505 | return src.split(sizes, dim) 506 | 507 | 508 | 509 | def unbatch_edge_index(edge_index, batch): 510 | r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector. 511 | 512 | Args: 513 | edge_index (Tensor): The edge_index tensor. Must be ordered. 
514 | batch (LongTensor): The batch vector 515 | :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each 516 | node to a specific example. Must be ordered. 517 | 518 | :rtype: :class:`List[Tensor]` 519 | 520 | Example: 521 | 522 | >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6], 523 | ... [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]]) 524 | >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1]) 525 | >>> unbatch_edge_index(edge_index, batch) 526 | (tensor([[0, 1, 1, 2, 2, 3], 527 | [1, 0, 2, 1, 3, 2]]), 528 | tensor([[0, 1, 1, 2], 529 | [1, 0, 2, 1]])) 530 | """ 531 | deg = degree(batch, dtype=torch.int64) 532 | ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0) 533 | 534 | edge_batch = batch[edge_index[0]] 535 | edge_index = edge_index - ptr[edge_batch] 536 | sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist() 537 | return edge_index.split(sizes, dim=1) 538 | 539 | 540 | def compute_connectivity(edge_index, batch): ## for numerical stability (i.e. we cap inv_con at 100) 541 | 542 | edges_by_batch = unbatch_edge_index(edge_index, batch) 543 | 544 | nodes_counts = torch.unique(batch, return_counts=True)[1] 545 | 546 | connectivity = torch.tensor([nodes_in_largest_graph(e, n) for e, n in zip(edges_by_batch, nodes_counts)]) 547 | isolation = torch.tensor([isolated_nodes(e, n) for e, n in zip(edges_by_batch, nodes_counts)]) 548 | 549 | return connectivity, isolation 550 | 551 | 552 | def nodes_in_largest_graph(edge_index, num_nodes): 553 | adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes) 554 | 555 | num_components, component = sp.csgraph.connected_components(adj) 556 | 557 | _, count = np.unique(component, return_counts=True) 558 | subset = np.in1d(component, count.argsort()[-1:]) 559 | 560 | return subset.sum() / num_nodes 561 | 562 | 563 | def isolated_nodes(edge_index, num_nodes): 564 | r"""Find isolate nodes """ 565 | edge_attr = None 566 | 567 | out = segregate_self_loops(edge_index, edge_attr) 568 | edge_index, edge_attr, loop_edge_index, loop_edge_attr = out 569 | 570 | mask = torch.ones(num_nodes, dtype=torch.bool, device=edge_index.device) 571 | mask[edge_index.view(-1)] = 0 572 | 573 | return mask.sum() / num_nodes 574 | 575 | def dropout_node(edge_index, p, num_nodes, batch, training): 576 | r"""Randomly drops nodes from the adjacency matrix 577 | :obj:`edge_index` with probability :obj:`p` using samples from 578 | a Bernoulli distribution. 579 | 580 | The method returns (1) the retained :obj:`edge_index`, (2) the edge mask 581 | indicating which edges were retained. (3) the node mask indicating 582 | which nodes were retained. 583 | 584 | Args: 585 | edge_index (LongTensor): The edge indices. 586 | p (float, optional): Dropout probability. (default: :obj:`0.5`) 587 | num_nodes (int, optional): The number of nodes, *i.e.* 588 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 589 | training (bool, optional): If set to :obj:`False`, this operation is a 590 | no-op. (default: :obj:`True`) 591 | 592 | :rtype: (:class:`LongTensor`, :class:`BoolTensor`, :class:`BoolTensor`) 593 | 594 | Examples: 595 | 596 | >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], 597 | ... [1, 0, 2, 1, 3, 2]]) 598 | >>> edge_index, edge_mask, node_mask = dropout_node(edge_index) 599 | >>> edge_index 600 | tensor([[0, 1], 601 | [1, 0]]) 602 | >>> edge_mask 603 | tensor([ True, True, False, False, False, False]) 604 | >>> node_mask 605 | tensor([ True, True, False, False]) 606 | """ 607 | if p < 0. 
or p > 1.: 608 | raise ValueError(f'Dropout probability has to be between 0 and 1 ' 609 | f'(got {p}') 610 | 611 | if not training or p == 0.0: 612 | node_mask = edge_index.new_ones(num_nodes, dtype=torch.bool) 613 | edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool) 614 | return edge_index, edge_mask, node_mask 615 | 616 | prob = torch.rand(num_nodes, device=edge_index.device) 617 | node_mask = prob > p 618 | 619 | ## ensure no graph is totally dropped out 620 | batch_tf = global_add_pool(node_mask.view(-1,1),batch).flatten() 621 | unbatched_node_mask = unbatch(node_mask, batch) 622 | node_mask_list = [] 623 | 624 | for true_false, sub_node_mask in zip(batch_tf, unbatched_node_mask): 625 | if true_false.item(): 626 | node_mask_list.append(sub_node_mask) 627 | else: 628 | perm = torch.randperm(sub_node_mask.size(0)) 629 | idx = perm[:1] 630 | sub_node_mask[idx] = True 631 | node_mask_list.append(sub_node_mask) 632 | 633 | node_mask = torch.cat(node_mask_list) 634 | 635 | edge_index, _, edge_mask = subgraph(node_mask, edge_index, 636 | num_nodes=num_nodes, 637 | return_edge_mask=True) 638 | return edge_index, edge_mask, node_mask 639 | 640 | def dropout_edge(edge_index: Tensor, p: float = 0.5, 641 | force_undirected: bool = False, 642 | training: bool = True) -> Tuple[Tensor, Tensor]: 643 | r"""Randomly drops edges from the adjacency matrix 644 | :obj:`edge_index` with probability :obj:`p` using samples from 645 | a Bernoulli distribution. 646 | 647 | The method returns (1) the retained :obj:`edge_index`, (2) the edge mask 648 | or index indicating which edges were retained, depending on the argument 649 | :obj:`force_undirected`. 650 | 651 | Args: 652 | edge_index (LongTensor): The edge indices. 653 | p (float, optional): Dropout probability. (default: :obj:`0.5`) 654 | force_undirected (bool, optional): If set to :obj:`True`, will either 655 | drop or keep both edges of an undirected edge. 656 | (default: :obj:`False`) 657 | training (bool, optional): If set to :obj:`False`, this operation is a 658 | no-op. (default: :obj:`True`) 659 | 660 | :rtype: (:class:`LongTensor`, :class:`BoolTensor` or :class:`LongTensor`) 661 | 662 | Examples: 663 | 664 | >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], 665 | ... [1, 0, 2, 1, 3, 2]]) 666 | >>> edge_index, edge_mask = dropout_edge(edge_index) 667 | >>> edge_index 668 | tensor([[0, 1, 2, 2], 669 | [1, 2, 1, 3]]) 670 | >>> edge_mask # masks indicating which edges are retained 671 | tensor([ True, False, True, True, True, False]) 672 | 673 | >>> edge_index, edge_id = dropout_edge(edge_index, 674 | ... force_undirected=True) 675 | >>> edge_index 676 | tensor([[0, 1, 2, 1, 2, 3], 677 | [1, 2, 3, 0, 1, 2]]) 678 | >>> edge_id # indices indicating which edges are retained 679 | tensor([0, 2, 4, 0, 2, 4]) 680 | """ 681 | if p < 0. 
or p > 1.: 682 | raise ValueError(f'Dropout probability has to be between 0 and 1 ' 683 | f'(got {p}') 684 | 685 | if not training or p == 0.0: 686 | edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool) 687 | return edge_index, edge_mask 688 | 689 | row, col = edge_index 690 | 691 | edge_mask = torch.rand(row.size(0), device=edge_index.device) >= p 692 | 693 | if force_undirected: 694 | edge_mask[row > col] = False 695 | 696 | edge_index = edge_index[:, edge_mask] 697 | 698 | if force_undirected: 699 | edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) 700 | edge_mask = edge_mask.nonzero().repeat((2, 1)).squeeze() 701 | 702 | return edge_index, edge_mask 703 | -------------------------------------------------------------------------------- /PSICHIC/models/mini_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import global_add_pool, global_mean_pool 3 | from torch.nn import Embedding, Linear 4 | from torch_geometric.utils import degree, to_scipy_sparse_matrix, segregate_self_loops 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter 7 | import numpy as np 8 | import scipy.sparse as sp 9 | from models.layers import MLP, AtomEncoder, Drug_PNAConv, Protein_PNAConv, DrugProteinConv, PosLinear, GCNCluster, SAGECluster, SGCluster, APPNPCluster, dropout_edge 10 | from copy import deepcopy 11 | ## for drug pooling 12 | from models.drug_pool import MotifPool 13 | ## for cluster 14 | from torch_geometric.utils import dense_to_sparse, to_dense_adj, to_dense_batch, dropout_adj, degree, subgraph, softmax, add_remaining_self_loops 15 | from models.protein_pool import dense_mincut_pool , dense_dmon_pool, simplify_pool 16 | ## for cluster 17 | from torch_geometric.nn.norm import GraphNorm 18 | import torch_geometric 19 | 20 | 21 | EPS = 1e-15 22 | import math 23 | 24 | class net(torch.nn.Module): 25 | def __init__(self, mol_deg, prot_deg, 26 | # MOLECULE 27 | mol_in_channels=43, prot_in_channels=33, prot_evo_channels=1280, 28 | hidden_channels=512, 29 | pre_layers=2, post_layers=2, 30 | aggregators=['mean', 'min', 'max', 'std'], 31 | scalers=['identity', 'amplification', 'attenuation'], 32 | # interaction 33 | total_layer=3, 34 | K = 10, 35 | t = 0.1, 36 | # training 37 | heads=2, 38 | dropout=0.1, 39 | drop_atom=0.1, 40 | drop_residue=0.1, 41 | dropout_attn_score=0.1, 42 | dropout_cluster_edge=0.1, 43 | gaussian_noise=0.1, 44 | # objective 45 | regression_head = True, 46 | classification_head = False, 47 | device='cuda:0'): 48 | super(net, self).__init__() 49 | self.total_layer = total_layer 50 | 51 | if (regression_head or classification_head) == False: 52 | raise Exception('must have one objective') 53 | 54 | self.regression_head = regression_head 55 | self.classification_head = classification_head 56 | 57 | if isinstance(K, int): 58 | K = [K]*total_layer 59 | 60 | # MOLECULE IN FEAT 61 | self.atom_type_encoder = Embedding(20,hidden_channels) 62 | self.atom_feat_encoder = MLP([mol_in_channels, hidden_channels * 2, hidden_channels], out_norm=True) 63 | 64 | ### MOLECULE and PROTEIN 65 | self.mol_convs = torch.nn.ModuleList() 66 | 67 | self.mol_gn2 = torch.nn.ModuleList() 68 | 69 | for idx in range(total_layer): 70 | self.mol_convs.append(Drug_PNAConv( 71 | mol_deg, hidden_channels,edge_channels=hidden_channels, 72 | pre_layers=pre_layers, post_layers=post_layers, 73 | aggregators=aggregators, 74 | scalers=scalers, 75 | num_towers=heads, 76 | dropout=dropout 77 | )) 78 | 
79 | self.mol_gn2.append(GraphNorm(hidden_channels)) 80 | 81 | 82 | if self.regression_head: 83 | self.reg_out = MLP([hidden_channels, hidden_channels, 1]) 84 | if self.classification_head: 85 | self.cls_out = MLP([hidden_channels, hidden_channels, 1]) 86 | 87 | self.device = device 88 | 89 | def reset_parameters(self): 90 | self.atom_feat_encoder.reset_parameters() 91 | 92 | for idx in range(self.total_layer): 93 | self.mol_convs[idx].reset_parameters() 94 | 95 | self.mol_gn2[idx].reset_parameters() 96 | 97 | if self.regression_head: 98 | self.reg_out.reset_parameters() 99 | if self.classification_head: 100 | self.cls_out.reset_parameters() 101 | 102 | def forward(self, 103 | # Molecule 104 | mol_x, mol_x_feat, bond_x, atom_edge_index, 105 | clique_x, clique_edge_index, atom2clique_index, # drug cliques 106 | # Protein 107 | residue_x, residue_evo_x, residue_edge_index, residue_edge_weight, 108 | # Mol-Protein Interaction batch 109 | mol_batch=None, prot_batch=None, clique_batch=None): 110 | # Init variables 111 | reg_pred = None 112 | cls_pred = None 113 | # MOLECULE Featurize 114 | atom_x = self.atom_type_encoder(mol_x.squeeze()) + self.atom_feat_encoder(mol_x_feat) 115 | 116 | spectral_loss = torch.tensor(0.).to(self.device) 117 | ortho_loss = torch.tensor(0.).to(self.device) 118 | cluster_loss = torch.tensor(0.).to(self.device) 119 | attention_dict = {} 120 | # MOLECULE-PROTEIN Layers 121 | for idx in range(self.total_layer): 122 | atom_x = self.mol_convs[idx](atom_x, bond_x, atom_edge_index) 123 | 124 | ## Graph Normalization 125 | atom_x = self.mol_gn2[idx](atom_x, mol_batch) 126 | 127 | mol_pool_feat = global_mean_pool(atom_x, mol_batch) 128 | 129 | if self.regression_head: 130 | reg_pred = self.reg_out(mol_pool_feat) 131 | if self.classification_head: 132 | cls_pred = self.cls_out(mol_pool_feat) 133 | 134 | return reg_pred, cls_pred, spectral_loss, ortho_loss, cluster_loss, attention_dict 135 | 136 | def temperature_clamp(self): 137 | pass 138 | 139 | 140 | def connect_mol_prot(self, mol_batch, prot_batch): 141 | mol_num_nodes = mol_batch.size(0) 142 | prot_num_nodes = prot_batch.size(0) 143 | mol_adj = mol_batch.reshape(-1, 1).repeat(1, prot_num_nodes) 144 | pro_adj = prot_batch.repeat(mol_num_nodes, 1) 145 | 146 | m2p_edge_index = (mol_adj == pro_adj).nonzero(as_tuple=False).t().contiguous() 147 | 148 | return m2p_edge_index 149 | 150 | def configure_optimizers(self, weight_decay, learning_rate, betas, eps, amsgrad): 151 | """ 152 | This long function is unfortunately doing something very simple and is being very defensive: 153 | We are separating out all parameters of the model into two buckets: those that will experience 154 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 155 | We are then returning the PyTorch optimizer object. 156 | """ 157 | 158 | # separate out all parameters to those that will and won't experience regularizing weight decay 159 | decay = set() 160 | no_decay = set() 161 | whitelist_weight_modules = (torch.nn.Linear, torch_geometric.nn.dense.linear.Linear) 162 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, GraphNorm, PosLinear) 163 | for mn, m in self.named_modules(): 164 | for pn, p in m.named_parameters(): 165 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 166 | # random note: because named_modules and named_parameters are recursive 167 | # we will see the same tensors p many many times. 
but doing it this way 168 | # allows us to know which parent module any tensor p belongs to... 169 | if pn.endswith('bias') or pn.endswith('mean_scale'):# or pn.endswith('logit_scale'): 170 | # all biases will not be decayed 171 | no_decay.add(fpn) 172 | # if mn.startswith('cluster'): 173 | # print(mn, 'not decayed!') 174 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 175 | # weights of whitelist modules will be weight decayed 176 | decay.add(fpn) 177 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 178 | # weights of blacklist modules will NOT be weight decayed 179 | no_decay.add(fpn) 180 | 181 | # validate that we considered every parameter 182 | param_dict = {pn: p for pn, p in self.named_parameters()} 183 | inter_params = decay & no_decay 184 | union_params = decay | no_decay 185 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 186 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 187 | % (str(param_dict.keys() - union_params), ) 188 | 189 | # create the pytorch optimizer object 190 | optim_groups = [ 191 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, 192 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 193 | ] 194 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, eps=eps, amsgrad=amsgrad) 195 | return optimizer 196 | 197 | 198 | def _rbf(D, D_min=0., D_max=1., D_count=16, device='cpu'): 199 | ''' 200 | From https://github.com/jingraham/neurips19-graph-protein-design 201 | 202 | Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1. 203 | That is, if `D` has shape [...dims], then the returned tensor will have 204 | shape [...dims, D_count]. 205 | ''' 206 | D = torch.where(D < D_max, D, torch.tensor(D_max).float().to(device) ) 207 | D_mu = torch.linspace(D_min, D_max, D_count, device=device) 208 | D_mu = D_mu.view([1, -1]) 209 | D_sigma = (D_max - D_min) / D_count 210 | D_expand = torch.unsqueeze(D, -1) 211 | 212 | RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2) 213 | return RBF 214 | 215 | 216 | def unbatch(src, batch, dim: int = 0): 217 | r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension 218 | :obj:`dim`. 219 | 220 | Args: 221 | src (Tensor): The source tensor. 222 | batch (LongTensor): The batch vector 223 | :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each 224 | entry in :obj:`src` to a specific example. Must be ordered. 225 | dim (int, optional): The dimension along which to split the :obj:`src` 226 | tensor. (default: :obj:`0`) 227 | 228 | :rtype: :class:`List[Tensor]` 229 | 230 | Example: 231 | 232 | >>> src = torch.arange(7) 233 | >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2]) 234 | >>> unbatch(src, batch) 235 | (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) 236 | """ 237 | sizes = degree(batch, dtype=torch.long).tolist() 238 | return src.split(sizes, dim) 239 | 240 | 241 | 242 | def unbatch_edge_index(edge_index, batch): 243 | r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector. 244 | 245 | Args: 246 | edge_index (Tensor): The edge_index tensor. Must be ordered. 247 | batch (LongTensor): The batch vector 248 | :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each 249 | node to a specific example. Must be ordered. 
250 | 251 | :rtype: :class:`List[Tensor]` 252 | 253 | Example: 254 | 255 | >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6], 256 | ... [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]]) 257 | >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1]) 258 | >>> unbatch_edge_index(edge_index, batch) 259 | (tensor([[0, 1, 1, 2, 2, 3], 260 | [1, 0, 2, 1, 3, 2]]), 261 | tensor([[0, 1, 1, 2], 262 | [1, 0, 2, 1]])) 263 | """ 264 | deg = degree(batch, dtype=torch.int64) 265 | ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0) 266 | 267 | edge_batch = batch[edge_index[0]] 268 | edge_index = edge_index - ptr[edge_batch] 269 | sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist() 270 | return edge_index.split(sizes, dim=1) 271 | 272 | 273 | def compute_connectivity(edge_index, batch): ## for numerical stability (i.e. we cap inv_con at 100) 274 | 275 | edges_by_batch = unbatch_edge_index(edge_index, batch) 276 | 277 | nodes_counts = torch.unique(batch, return_counts=True)[1] 278 | 279 | connectivity = torch.tensor([nodes_in_largest_graph(e, n) for e, n in zip(edges_by_batch, nodes_counts)]) 280 | isolation = torch.tensor([isolated_nodes(e, n) for e, n in zip(edges_by_batch, nodes_counts)]) 281 | 282 | return connectivity, isolation 283 | 284 | 285 | def nodes_in_largest_graph(edge_index, num_nodes): 286 | adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes) 287 | 288 | num_components, component = sp.csgraph.connected_components(adj) 289 | 290 | _, count = np.unique(component, return_counts=True) 291 | subset = np.in1d(component, count.argsort()[-1:]) 292 | 293 | return subset.sum() / num_nodes 294 | 295 | 296 | def isolated_nodes(edge_index, num_nodes): 297 | r"""Find isolate nodes """ 298 | edge_attr = None 299 | 300 | out = segregate_self_loops(edge_index, edge_attr) 301 | edge_index, edge_attr, loop_edge_index, loop_edge_attr = out 302 | 303 | mask = torch.ones(num_nodes, dtype=torch.bool, device=edge_index.device) 304 | mask[edge_index.view(-1)] = 0 305 | 306 | return mask.sum() / num_nodes 307 | 308 | def dropout_node(edge_index, p, num_nodes, batch, training): 309 | r"""Randomly drops nodes from the adjacency matrix 310 | :obj:`edge_index` with probability :obj:`p` using samples from 311 | a Bernoulli distribution. 312 | 313 | The method returns (1) the retained :obj:`edge_index`, (2) the edge mask 314 | indicating which edges were retained. (3) the node mask indicating 315 | which nodes were retained. 316 | 317 | Args: 318 | edge_index (LongTensor): The edge indices. 319 | p (float, optional): Dropout probability. (default: :obj:`0.5`) 320 | num_nodes (int, optional): The number of nodes, *i.e.* 321 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 322 | training (bool, optional): If set to :obj:`False`, this operation is a 323 | no-op. (default: :obj:`True`) 324 | 325 | :rtype: (:class:`LongTensor`, :class:`BoolTensor`, :class:`BoolTensor`) 326 | 327 | Examples: 328 | 329 | >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], 330 | ... [1, 0, 2, 1, 3, 2]]) 331 | >>> edge_index, edge_mask, node_mask = dropout_node(edge_index) 332 | >>> edge_index 333 | tensor([[0, 1], 334 | [1, 0]]) 335 | >>> edge_mask 336 | tensor([ True, True, False, False, False, False]) 337 | >>> node_mask 338 | tensor([ True, True, False, False]) 339 | """ 340 | if p < 0. 
or p > 1.: 341 | raise ValueError(f'Dropout probability has to be between 0 and 1 ' 342 | f'(got {p}') 343 | 344 | if not training or p == 0.0: 345 | node_mask = edge_index.new_ones(num_nodes, dtype=torch.bool) 346 | edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool) 347 | return edge_index, edge_mask, node_mask 348 | 349 | prob = torch.rand(num_nodes, device=edge_index.device) 350 | node_mask = prob > p 351 | 352 | ## ensure no graph is totally dropped out 353 | batch_tf = global_add_pool(node_mask.view(-1,1),batch).flatten() 354 | unbatched_node_mask = unbatch(node_mask, batch) 355 | node_mask_list = [] 356 | 357 | for true_false, sub_node_mask in zip(batch_tf, unbatched_node_mask): 358 | if true_false.item(): 359 | node_mask_list.append(sub_node_mask) 360 | else: 361 | perm = torch.randperm(sub_node_mask.size(0)) 362 | idx = perm[:1] 363 | sub_node_mask[idx] = True 364 | node_mask_list.append(sub_node_mask) 365 | 366 | node_mask = torch.cat(node_mask_list) 367 | 368 | edge_index, _, edge_mask = subgraph(node_mask, edge_index, 369 | num_nodes=num_nodes, 370 | return_edge_mask=True) 371 | return edge_index, edge_mask, node_mask 372 | 373 | 374 | -------------------------------------------------------------------------------- /PSICHIC/models/pna.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import ModuleList, Sequential 6 | 7 | from .scaler import DegreeScalerAggregation 8 | # from models.pna_scaler import DegreeScalerAggregation 9 | from torch_geometric.nn.conv import MessagePassing 10 | from torch_geometric.nn.dense.linear import Linear 11 | from torch_geometric.typing import Adj, OptTensor 12 | from torch_geometric.utils import degree 13 | from torch_geometric.nn.resolver import activation_resolver 14 | 15 | from torch_geometric.nn.inits import reset 16 | 17 | 18 | class PNAConv(MessagePassing): 19 | r"""The Principal Neighbourhood Aggregation graph convolution operator 20 | from the `"Principal Neighbourhood Aggregation for Graph Nets" 21 | `_ paper 22 | 23 | .. math:: 24 | \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( 25 | \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus} 26 | h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right) 27 | \right) 28 | 29 | with 30 | 31 | .. math:: 32 | \bigoplus = \underbrace{\begin{bmatrix} 33 | 1 \\ 34 | S(\mathbf{D}, \alpha=1) \\ 35 | S(\mathbf{D}, \alpha=-1) 36 | \end{bmatrix} }_{\text{scalers}} 37 | \otimes \underbrace{\begin{bmatrix} 38 | \mu \\ 39 | \sigma \\ 40 | \max \\ 41 | \min 42 | \end{bmatrix}}_{\text{aggregators}}, 43 | 44 | where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` 45 | denote MLPs. 46 | 47 | .. note:: 48 | 49 | For an example of using :obj:`PNAConv`, see `examples/pna.py 50 | `_. 52 | 53 | Args: 54 | in_channels (int): Size of each input sample, or :obj:`-1` to derive 55 | the size from the first input(s) to the forward method. 56 | out_channels (int): Size of each output sample. 57 | aggregators (list of str): Set of aggregation function identifiers, 58 | namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, 59 | :obj:`"var"` and :obj:`"std"`. 60 | scalers (list of str): Set of scaling function identifiers, namely 61 | :obj:`"identity"`, :obj:`"amplification"`, 62 | :obj:`"attenuation"`, :obj:`"linear"` and 63 | :obj:`"inverse_linear"`. 
64 | deg (Tensor): Histogram of in-degrees of nodes in the training set, 65 | used by scalers to normalize. 66 | edge_dim (int, optional): Edge feature dimensionality (in case 67 | there are any). (default :obj:`None`) 68 | towers (int, optional): Number of towers (default: :obj:`1`). 69 | pre_layers (int, optional): Number of transformation layers before 70 | aggregation (default: :obj:`1`). 71 | post_layers (int, optional): Number of transformation layers after 72 | aggregation (default: :obj:`1`). 73 | divide_input (bool, optional): Whether the input features should 74 | be split between towers or not (default: :obj:`False`). 75 | **kwargs (optional): Additional arguments of 76 | :class:`torch_geometric.nn.conv.MessagePassing`. 77 | 78 | Shapes: 79 | - **input:** 80 | node features :math:`(|\mathcal{V}|, F_{in})`, 81 | edge indices :math:`(2, |\mathcal{E}|)`, 82 | edge features :math:`(|\mathcal{E}|, D)` *(optional)* 83 | - **output:** node features :math:`(|\mathcal{V}|, F_{out})` 84 | """ 85 | def __init__(self, in_channels: int, out_channels: int, 86 | aggregators: List[str], scalers: List[str], deg: Tensor, 87 | edge_dim: Optional[int] = None, towers: int = 1, 88 | pre_layers: int = 1, post_layers: int = 1, 89 | act: Union[str, Callable, None] = "relu", 90 | act_kwargs: Optional[Dict[str, Any]] = None, 91 | divide_input: bool = False, **kwargs): 92 | 93 | aggr = DegreeScalerAggregation(aggregators, scalers, deg) 94 | super().__init__(aggr=aggr, node_dim=0, **kwargs) 95 | 96 | if divide_input: 97 | assert in_channels % towers == 0 98 | assert out_channels % towers == 0 99 | 100 | self.in_channels = in_channels 101 | self.out_channels = out_channels 102 | self.edge_dim = edge_dim 103 | self.towers = towers 104 | self.divide_input = divide_input 105 | 106 | self.F_in = in_channels // towers if divide_input else in_channels 107 | self.F_out = self.out_channels // towers 108 | 109 | if self.edge_dim is not None: 110 | self.edge_encoder = Linear(edge_dim, self.F_in) 111 | 112 | self.pre_nns = ModuleList() 113 | self.post_nns = ModuleList() 114 | for _ in range(towers): 115 | modules = [Linear((3 if edge_dim else 2) * self.F_in, self.F_in)] 116 | for _ in range(pre_layers - 1): 117 | modules += [activation_resolver(act, **(act_kwargs or {}))] 118 | modules += [Linear(self.F_in, self.F_in)] 119 | self.pre_nns.append(Sequential(*modules)) 120 | 121 | in_channels = (len(aggregators) * len(scalers) + 1) * self.F_in 122 | modules = [Linear(in_channels, self.F_out)] 123 | for _ in range(post_layers - 1): 124 | modules += [activation_resolver(act, **(act_kwargs or {}))] 125 | modules += [Linear(self.F_out, self.F_out)] 126 | self.post_nns.append(Sequential(*modules)) 127 | 128 | self.lin = Linear(out_channels, out_channels) 129 | 130 | self.reset_parameters() 131 | 132 | def reset_parameters(self): 133 | if self.edge_dim is not None: 134 | self.edge_encoder.reset_parameters() 135 | for nn in self.pre_nns: 136 | reset(nn) 137 | for nn in self.post_nns: 138 | reset(nn) 139 | self.lin.reset_parameters() 140 | 141 | def forward(self, x: Tensor, edge_index: Adj, 142 | edge_attr: OptTensor = None) -> Tensor: 143 | """""" 144 | if self.divide_input: 145 | x = x.view(-1, self.towers, self.F_in) 146 | else: 147 | x = x.view(-1, 1, self.F_in).repeat(1, self.towers, 1) 148 | 149 | # propagate_type: (x: Tensor, edge_attr: OptTensor) 150 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None) 151 | 152 | out = torch.cat([x, out], dim=-1) 153 | outs = [nn(out[:, i]) for i, nn in 
enumerate(self.post_nns)] 154 | out = torch.cat(outs, dim=1) 155 | 156 | return self.lin(out) 157 | 158 | def message(self, x_i: Tensor, x_j: Tensor, 159 | edge_attr: OptTensor) -> Tensor: 160 | 161 | h: Tensor = x_i # Dummy. 162 | if edge_attr is not None: 163 | edge_attr = self.edge_encoder(edge_attr) 164 | edge_attr = edge_attr.view(-1, 1, self.F_in) 165 | edge_attr = edge_attr.repeat(1, self.towers, 1) 166 | h = torch.cat([x_i, x_j, edge_attr], dim=-1) 167 | else: 168 | h = torch.cat([x_i, x_j], dim=-1) 169 | 170 | hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)] 171 | return torch.stack(hs, dim=1) 172 | 173 | def __repr__(self): 174 | return (f'{self.__class__.__name__}({self.in_channels}, ' 175 | f'{self.out_channels}, towers={self.towers}, ' 176 | f'edge_dim={self.edge_dim})') 177 | 178 | @staticmethod 179 | def get_degree_histogram(loader) -> Tensor: 180 | max_degree = 0 181 | for data in loader: 182 | d = degree(data.edge_index[1], num_nodes=data.num_nodes, 183 | dtype=torch.long) 184 | max_degree = max(max_degree, int(d.max())) 185 | # Compute the in-degree histogram tensor 186 | deg_histogram = torch.zeros(max_degree + 1, dtype=torch.long) 187 | for data in loader: 188 | d = degree(data.edge_index[1], num_nodes=data.num_nodes, 189 | dtype=torch.long) 190 | deg_histogram += torch.bincount(d, minlength=deg_histogram.numel()) 191 | 192 | return deg_histogram 193 | -------------------------------------------------------------------------------- /PSICHIC/models/protein_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | EPS = 1e-15 4 | from .layers import * 5 | 6 | def dense_mincut_pool(x, adj, s, mask=None, cluster_drop_node=None): 7 | r"""The MinCut pooling operator from the `"Spectral Clustering in Graph 8 | Neural Networks for Graph Pooling" `_ 9 | paper 10 | 11 | .. math:: 12 | \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot 13 | \mathbf{X} 14 | 15 | \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot 16 | \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S}) 17 | 18 | based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B 19 | \times N \times C}`. 20 | Returns the pooled node feature matrix, the coarsened and symmetrically 21 | normalized adjacency matrix and two auxiliary objectives: (1) The MinCut 22 | loss 23 | 24 | .. math:: 25 | \mathcal{L}_c = - \frac{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{A} 26 | \mathbf{S})} {\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{D} 27 | \mathbf{S})} 28 | 29 | where :math:`\mathbf{D}` is the degree matrix, and (2) the orthogonality 30 | loss 31 | 32 | .. math:: 33 | \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} 34 | {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} 35 | \right\|}_F. 36 | 37 | Args: 38 | x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B 39 | \times N \times F}` with batch-size :math:`B`, (maximum) 40 | number of nodes :math:`N` for each graph, and feature dimension 41 | :math:`F`. 42 | adj (Tensor): Symmetrically normalized adjacency tensor 43 | :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. 44 | s (Tensor): Assignment tensor :math:`\mathbf{S} \in \mathbb{R}^{B 45 | \times N \times C}` with number of clusters :math:`C`. The softmax 46 | does not have to be applied beforehand, since it is executed 47 | within this method. 48 | mask (BoolTensor, optional): Mask matrix 49 | :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating 50 | the valid nodes for each graph. 
(default: :obj:`None`) 51 | 52 | :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`, 53 | :class:`Tensor`) 54 | """ 55 | 56 | x = x.unsqueeze(0) if x.dim() == 2 else x 57 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj 58 | s = s.unsqueeze(0) if s.dim() == 2 else s 59 | 60 | (batch_size, num_nodes, _), k = x.size(), s.size(-1) 61 | 62 | s = torch.softmax(s, dim=-1) 63 | 64 | if mask is not None: 65 | s = s * mask.view(batch_size, num_nodes, 1).to(x.dtype) 66 | x_mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) 67 | 68 | if cluster_drop_node is not None: 69 | x_mask = cluster_drop_node.view(batch_size, num_nodes, 1).to(x.dtype) 70 | 71 | x = x * x_mask 72 | 73 | out = torch.matmul(s.transpose(1, 2), x) 74 | out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) 75 | 76 | # MinCut regularization. 77 | mincut_num = _rank3_trace(out_adj) 78 | d_flat = torch.einsum('ijk->ij', adj) 79 | d = _rank3_diag(d_flat) 80 | mincut_den = _rank3_trace( 81 | torch.matmul(torch.matmul(s.transpose(1, 2), d), s)) 82 | mincut_loss = -(mincut_num / mincut_den) 83 | mincut_loss = torch.mean(mincut_loss) 84 | 85 | # Orthogonality regularization. 86 | ss = torch.matmul(s.transpose(1, 2), s) 87 | i_s = torch.eye(k).type_as(ss) 88 | ortho_loss = torch.norm( 89 | ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - 90 | i_s / torch.norm(i_s), dim=(-1, -2)) 91 | ortho_loss = torch.mean(ortho_loss) 92 | 93 | # Fix and normalize coarsened adjacency matrix. 94 | ind = torch.arange(k, device=out_adj.device) 95 | out_adj[:, ind, ind] = 0 96 | d = torch.einsum('ijk->ij', out_adj) 97 | d = torch.sqrt(d)[:, None] + EPS 98 | out_adj = (out_adj / d) / d.transpose(1, 2) 99 | 100 | # out_loss = mincut_loss + ortho_loss 101 | 102 | return s, out, out_adj, mincut_loss, ortho_loss 103 | 104 | def _rank3_trace(x): 105 | return torch.einsum('ijj->i', x) 106 | 107 | 108 | def _rank3_diag(x): 109 | eye = torch.eye(x.size(1)).type_as(x) 110 | out = eye * x.unsqueeze(2).expand(*x.size(), x.size(1)) 111 | return out 112 | 113 | 114 | from typing import List, Optional, Tuple, Union 115 | 116 | import torch 117 | import torch.nn.functional as F 118 | from torch import Tensor 119 | 120 | from torch_geometric.nn.dense.mincut_pool import _rank3_trace 121 | 122 | EPS = 1e-15 123 | 124 | 125 | def dense_dmon_pool(x, adj, s, mask=None): 126 | r""" 127 | Args: 128 | x (Tensor): Node feature tensor :math:`\mathbf{X} \in 129 | \mathbb{R}^{B \times N \times F}` with batch-size 130 | :math:`B`, (maximum) number of nodes :math:`N` for each graph, 131 | and feature dimension :math:`F`. 132 | Note that the cluster assignment matrix 133 | :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` is 134 | being created within this method. 135 | adj (Tensor): Adjacency tensor 136 | :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. 137 | mask (BoolTensor, optional): Mask matrix 138 | :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating 139 | the valid nodes for each graph. 
(default: :obj:`None`) 140 | 141 | :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`, 142 | :class:`Tensor`, :class:`Tensor`, :class:`Tensor`) 143 | """ 144 | 145 | # x = x.unsqueeze(0) if x.dim() == 2 else x 146 | # adj = adj.unsqueeze(0) if adj.dim() == 2 else adj 147 | # s = s.unsqueeze(0) if s.dim() == 2 else s 148 | 149 | # (batch_size, num_nodes, _), k = x.size(), s.size(-1) 150 | 151 | # s_out = torch.softmax(s, dim=-1) 152 | 153 | # if mask is not None: 154 | # mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) 155 | # x, s = x * mask, s_out * mask 156 | 157 | # out = torch.matmul(s.transpose(1, 2), x) 158 | # out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) 159 | x = x.unsqueeze(0) if x.dim() == 2 else x 160 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj 161 | s = s.unsqueeze(0) if s.dim() == 2 else s 162 | 163 | (batch_size, num_nodes, _), k = x.size(), s.size(-1) 164 | s = torch.softmax(s, dim=-1) 165 | s_out = s 166 | 167 | if mask is not None: 168 | mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) 169 | x, s = x * mask, s * mask 170 | 171 | out = torch.matmul(s.transpose(1, 2), x) 172 | out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) 173 | 174 | # Spectral loss: 175 | degrees = torch.einsum('ijk->ik', adj).transpose(0, 1) 176 | m = torch.einsum('ij->', degrees) 177 | 178 | ca = torch.matmul(s.transpose(1, 2), degrees) 179 | cb = torch.matmul(degrees.transpose(0, 1), s) 180 | 181 | normalizer = torch.matmul(ca, cb) / 2 / m 182 | decompose = out_adj - normalizer 183 | spectral_loss = -_rank3_trace(decompose) / 2 / m 184 | spectral_loss = torch.mean(spectral_loss) 185 | 186 | # Orthogonality regularization: 187 | ss = torch.matmul(s.transpose(1, 2), s) 188 | i_s = torch.eye(k).type_as(ss) 189 | ortho_loss = torch.norm( 190 | ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - 191 | i_s / torch.norm(i_s), dim=(-1, -2)) 192 | ortho_loss = torch.mean(ortho_loss) 193 | 194 | # Cluster loss: 195 | cluster_loss = torch.norm(torch.einsum( 196 | 'ijk->ij', ss)) / adj.size(1) * torch.norm(i_s) - 1 197 | 198 | # Fix and normalize coarsened adjacency matrix: 199 | ind = torch.arange(k, device=out_adj.device) 200 | out_adj[:, ind, ind] = 0 201 | d = torch.einsum('ijk->ij', out_adj) 202 | d = torch.sqrt(d)[:, None] + EPS 203 | out_adj = (out_adj / d) / d.transpose(1, 2) 204 | 205 | 206 | return s_out, out, out_adj, spectral_loss, ortho_loss, cluster_loss 207 | 208 | 209 | import torch 210 | 211 | EPS = 1e-15 212 | 213 | 214 | def simplify_pool(x, adj, s, mask=None, normalize=True): 215 | r"""The Just Balance pooling operator from the `"Simplifying Clustering with 216 | Graph Neural Networks" `_ paper 217 | 218 | .. math:: 219 | \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot 220 | \mathbf{X} 221 | \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot 222 | \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S}) 223 | based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B 224 | \times N \times C}`. 225 | Returns the pooled node feature matrix, the coarsened and symmetrically 226 | normalized adjacency matrix and the following auxiliary objective: 227 | .. math:: 228 | \mathcal{L} = - {\mathrm{Tr}(\sqrt{\mathbf{S}^{\top} \mathbf{S}})} 229 | Args: 230 | x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B 231 | \times N \times F}` with batch-size :math:`B`, (maximum) 232 | number of nodes :math:`N` for each graph, and feature dimension 233 | :math:`F`. 
234 | adj (Tensor): Symmetrically normalized adjacency tensor 235 | :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. 236 | s (Tensor): Assignment tensor :math:`\mathbf{S} \in \mathbb{R}^{B 237 | \times N \times C}` with number of clusters :math:`C`. The softmax 238 | does not have to be applied beforehand, since it is executed 239 | within this method. 240 | mask (BoolTensor, optional): Mask matrix 241 | :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating 242 | the valid nodes for each graph. (default: :obj:`None`) 243 | :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`, 244 | :class:`Tensor`) 245 | """ 246 | 247 | x = x.unsqueeze(0) if x.dim() == 2 else x 248 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj 249 | s = s.unsqueeze(0) if s.dim() == 2 else s 250 | 251 | (batch_size, num_nodes, _), k = x.size(), s.size(-1) 252 | 253 | s = torch.softmax(s, dim=-1) 254 | s_out = s 255 | 256 | if mask is not None: 257 | mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) 258 | x, s = x * mask, s * mask 259 | 260 | out = torch.matmul(s.transpose(1, 2), x) 261 | out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) 262 | 263 | # Loss 264 | ss = torch.matmul(s.transpose(1, 2), s) 265 | ss_sqrt = torch.sqrt(ss + EPS) 266 | loss = torch.mean(-_rank3_trace(ss_sqrt)) 267 | if normalize: 268 | loss = loss / torch.sqrt(torch.tensor(num_nodes * k)) 269 | 270 | # Fix and normalize coarsened adjacency matrix. 271 | ind = torch.arange(k, device=out_adj.device) 272 | out_adj[:, ind, ind] = 0 273 | d = torch.einsum('ijk->ij', out_adj) 274 | d = torch.sqrt(d)[:, None] + EPS 275 | out_adj = (out_adj / d) / d.transpose(1, 2) 276 | 277 | return s_out, out, out_adj, loss 278 | 279 | 280 | def _rank3_trace(x): 281 | return torch.einsum('ijj->i', x) 282 | -------------------------------------------------------------------------------- /PSICHIC/models/scaler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from torch_geometric.nn.aggr import Aggregation, MultiAggregation 7 | from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver 8 | from torch_geometric.utils import degree 9 | 10 | 11 | class DegreeScalerAggregation(Aggregation): 12 | r"""Combines one or more aggregators and transforms its output with one or 13 | more scalers as introduced in the `"Principal Neighbourhood Aggregation for 14 | Graph Nets" `_ paper. 15 | The scalers are normalised by the in-degree of the training set and so must 16 | be provided at time of construction. 17 | See :class:`torch_geometric.nn.conv.PNAConv` for more information. 18 | 19 | Args: 20 | aggr (string or list or Aggregation): The aggregation scheme to use. 21 | See :class:`~torch_geometric.nn.conv.MessagePassing` for more 22 | information. 23 | scaler (str or list): Set of scaling function identifiers, namely one 24 | or more of :obj:`"identity"`, :obj:`"amplification"`, 25 | :obj:`"attenuation"`, :obj:`"linear"` and :obj:`"inverse_linear"`. 26 | deg (Tensor): Histogram of in-degrees of nodes in the training set, 27 | used by scalers to normalize. 28 | aggr_kwargs (Dict[str, Any], optional): Arguments passed to the 29 | respective aggregation function in case it gets automatically 30 | resolved. 
(default: :obj:`None`) 31 | """ 32 | def __init__( 33 | self, 34 | aggr: Union[str, List[str], Aggregation], 35 | scaler: Union[str, List[str]], 36 | deg: Tensor, 37 | aggr_kwargs: Optional[List[Dict[str, Any]]] = None, 38 | ): 39 | super().__init__() 40 | 41 | if isinstance(aggr, (str, Aggregation)): 42 | self.aggr = aggr_resolver(aggr, **(aggr_kwargs or {})) 43 | elif isinstance(aggr, (tuple, list)): 44 | self.aggr = MultiAggregation(aggr, aggr_kwargs) 45 | else: 46 | raise ValueError(f"Only strings, list, tuples and instances of" 47 | f"`torch_geometric.nn.aggr.Aggregation` are " 48 | f"valid aggregation schemes (got '{type(aggr)}')") 49 | 50 | self.scaler = [scaler] if isinstance(aggr, str) else scaler 51 | 52 | deg = deg.to(torch.float) 53 | num_nodes = int(deg.sum()) 54 | bin_degrees = torch.arange(deg.numel(), device=deg.device) 55 | self.avg_deg: Dict[str, float] = { 56 | 'lin': float((bin_degrees * deg).sum()) / num_nodes, 57 | 'log': float(((bin_degrees + 1).log() * deg).sum()) / num_nodes, 58 | 'exp': float((bin_degrees.exp() * deg).sum()) / num_nodes, 59 | } 60 | 61 | def forward(self, x: Tensor, index: Optional[Tensor] = None, 62 | ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, 63 | dim: int = -2) -> Tensor: 64 | 65 | # TODO Currently, `degree` can only operate on `index`: 66 | self.assert_index_present(index) 67 | 68 | out = self.aggr(x, index, ptr, dim_size, dim) 69 | 70 | assert index is not None 71 | deg = degree(index, num_nodes=dim_size, dtype=out.dtype).clamp_(1) 72 | size = [1] * len(out.size()) 73 | size[dim] = -1 74 | deg = deg.view(size) 75 | 76 | outs = [] 77 | for scaler in self.scaler: 78 | if scaler == 'identity': 79 | out_scaler = out 80 | elif scaler == 'amplification': 81 | out_scaler = out * (torch.log(deg + 1) / self.avg_deg['log']) 82 | elif scaler == 'attenuation': 83 | out_scaler = out * (self.avg_deg['log'] / torch.log(deg + 1)) 84 | elif scaler == 'exponential': 85 | out_scaler = out * (torch.exp(deg) / self.avg_deg['exp']) 86 | elif scaler == 'linear': 87 | out_scaler = out * (deg / self.avg_deg['lin']) 88 | elif scaler == 'inverse_linear': 89 | out_scaler = out * (self.avg_deg['lin'] / deg) 90 | else: 91 | raise ValueError(f"Unknown scaler '{scaler}'") 92 | outs.append(out_scaler) 93 | 94 | return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0] 95 | -------------------------------------------------------------------------------- /PSICHIC/psichic_utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/psichic_utils/.DS_Store -------------------------------------------------------------------------------- /PSICHIC/psichic_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import * 2 | from .dataset import * 3 | from .metrics import * 4 | from .protein_init import protein_init 5 | from .ligand_init import ligand_init, smiles2graph 6 | -------------------------------------------------------------------------------- /PSICHIC/psichic_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | 5 | # Check if the code is running in a Jupyter notebook 6 | if 'ipykernel' in sys.modules: 7 | from tqdm.notebook import tqdm 8 | else: 9 | from tqdm import tqdm 10 | 11 | 12 | from itertools import repeat 13 | import pandas as pd 14 | 15 | 
16 | import torch 17 | 18 | from torch_geometric.loader import DataLoader 19 | from torch_geometric.utils import degree 20 | 21 | BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 22 | sys.path.append(BASE_DIR) 23 | 24 | from runtime_config import RuntimeConfig 25 | device = RuntimeConfig.DEVICE 26 | 27 | class InfiniteDataLoader(DataLoader): 28 | def __init__(self, *args, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | # Initialize an iterator over the dataset. 31 | self.dataset_iterator = super().__iter__() 32 | 33 | def __iter__(self): 34 | return self 35 | 36 | def __next__(self): 37 | try: 38 | batch = next(self.dataset_iterator) 39 | except StopIteration: 40 | # Dataset exhausted, use a new fresh iterator. 41 | self.dataset_iterator = super().__iter__() 42 | batch = next(self.dataset_iterator) 43 | return batch 44 | 45 | def create_custom_loader(type='epoch'): 46 | if type == 'epoch': 47 | return DataLoader 48 | elif type =='infinite': 49 | return InfiniteDataLoader 50 | else: 51 | raise Exception('Not Implemented') 52 | 53 | class CustomWeightedRandomSampler(torch.utils.data.WeightedRandomSampler): 54 | """WeightedRandomSampler except allows for more than 2^24 samples to be sampled""" 55 | def __init__(self, *args, **kwargs): 56 | super().__init__(*args, **kwargs) 57 | 58 | def __iter__(self): 59 | rand_tensor = np.random.choice(range(0, len(self.weights)), 60 | size=self.num_samples, 61 | p=self.weights.numpy() / torch.sum(self.weights).numpy(), 62 | replace=self.replacement) 63 | rand_tensor = torch.from_numpy(rand_tensor) 64 | return iter(rand_tensor.tolist()) 65 | 66 | def sampler_from_weights(weights): 67 | sampler = CustomWeightedRandomSampler(weights, len(weights), replacement=True) 68 | 69 | return sampler 70 | def create_custom_sampler(class_list, specified_weight={}): 71 | assert isinstance(specified_weight,dict) 72 | class_list = np.array(class_list) 73 | class_weight = { 74 | t: 1./len(np.where(class_list == t)[0]) for t in np.unique(class_list) 75 | } 76 | 77 | samples_weight = np.array([class_weight[t] for t in class_list]) 78 | 79 | if specified_weight: 80 | specified_weight = np.array([specified_weight[i] for i in class_list]) 81 | samples_weight *= specified_weight 82 | 83 | sampler = CustomWeightedRandomSampler(samples_weight, len(samples_weight)) 84 | 85 | return sampler 86 | 87 | def compute_pna_degrees(train_loader): 88 | mol_max_degree = -1 89 | clique_max_degree = -1 90 | prot_max_degree = -1 91 | 92 | for data in tqdm(train_loader): 93 | # mol 94 | mol_d = degree(data.mol_edge_index[1], num_nodes=data.mol_x.shape[0], dtype=torch.long) 95 | mol_max_degree = max(mol_max_degree, int(mol_d.max())) 96 | # clique 97 | try: 98 | clique_d = degree(data.clique_edge_index[1], num_nodes=data.clique_x.shape[0], dtype=torch.long) 99 | except RuntimeError: 100 | print(data.clique_edge_index[1]) 101 | print(data.clique_x) 102 | print('clique shape',data.clique_x.shape) 103 | print('atom shape',data.mol_x.shape[0]) 104 | break 105 | clique_max_degree = max(clique_max_degree, int(clique_d.max())) 106 | # protein 107 | prot_d = degree(data.prot_edge_index[1], num_nodes=data.prot_node_aa.shape[0], dtype=torch.long) 108 | prot_max_degree = max(prot_max_degree, int(prot_d.max())) 109 | 110 | # Compute the in-degree histogram tensor 111 | mol_deg = torch.zeros(mol_max_degree + 1, dtype=torch.long) 112 | clique_deg = torch.zeros(clique_max_degree + 1, dtype=torch.long) 113 | prot_deg = torch.zeros(prot_max_degree + 1, dtype=torch.long) 114 | 115 | for 
data in tqdm(train_loader): 116 | # mol 117 | mol_d = degree(data.mol_edge_index[1], num_nodes=data.mol_x.shape[0], dtype=torch.long) 118 | mol_deg += torch.bincount(mol_d, minlength=mol_deg.numel()) 119 | 120 | # clique 121 | clique_d = degree(data.clique_edge_index[1], num_nodes=data.clique_x.shape[0], dtype=torch.long) 122 | clique_deg += torch.bincount(clique_d, minlength=clique_deg.numel()) 123 | 124 | # Protein 125 | prot_d = degree(data.prot_edge_index[1], num_nodes=data.prot_node_aa.shape[0], dtype=torch.long) 126 | prot_deg += torch.bincount(prot_d, minlength=prot_deg.numel()) 127 | 128 | return mol_deg, clique_deg, prot_deg 129 | 130 | 131 | def unbatch(src, batch, dim: int = 0): 132 | r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension 133 | :obj:`dim`. 134 | 135 | Args: 136 | src (Tensor): The source tensor. 137 | batch (LongTensor): The batch vector 138 | :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each 139 | entry in :obj:`src` to a specific example. Must be ordered. 140 | dim (int, optional): The dimension along which to split the :obj:`src` 141 | tensor. (default: :obj:`0`) 142 | 143 | :rtype: :class:`List[Tensor]` 144 | 145 | Example: 146 | 147 | >>> src = torch.arange(7) 148 | >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2]) 149 | >>> unbatch(src, batch) 150 | (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) 151 | """ 152 | sizes = degree(batch, dtype=torch.long).tolist() 153 | return src.split(sizes, dim) 154 | 155 | 156 | def unbatch_nodes(data_tensor, index_tensor): 157 | """ 158 | Unbatch a data tensor based on an index tensor. 159 | 160 | Args: 161 | data_tensor (torch.Tensor): The tensor to be unbatched. 162 | index_tensor (torch.Tensor): A tensor of the same length as data_tensor's first dimension, 163 | indicating the batch index for each element in data_tensor. 164 | 165 | Returns: 166 | list[torch.Tensor]: A list of tensors, where each tensor corresponds to a separate batch. 
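
    Example (illustrative sketch; the expected output simply follows from the
    list comprehension below, assuming a 1-D ``data_tensor``):

        >>> data_tensor = torch.arange(5)
        >>> index_tensor = torch.tensor([0, 0, 1, 1, 1])
        >>> unbatch_nodes(data_tensor, index_tensor)
        [tensor([0, 1]), tensor([2, 3, 4])]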
167 | """ 168 | return [data_tensor[index_tensor == i] for i in index_tensor.unique()] 169 | 170 | 171 | def repeater(data_loader): 172 | for loader in repeat(data_loader): 173 | for data in loader: 174 | yield data 175 | 176 | def printline(line): 177 | sys.stdout.write(line + "\x1b[K\r") 178 | sys.stdout.flush() 179 | 180 | 181 | def protein_degree_from_dict(protein_dict): 182 | protein_max_degree = -1 183 | for k, v in protein_dict.items(): 184 | node_num = len(v['seq']) 185 | edge_index = v['edge_index'] 186 | protein_degree = degree(edge_index[1], num_nodes=node_num, dtype=torch.long) 187 | protein_max_degree = max(protein_max_degree, protein_degree.max()) 188 | 189 | protein_deg = torch.zeros(protein_max_degree + 1, dtype=torch.long) 190 | for k, v in protein_dict.items(): 191 | node_num = len(v['seq']) 192 | edge_index = v['edge_index'] 193 | protein_degree = degree(edge_index[1], num_nodes=node_num, dtype=torch.long) 194 | protein_deg += torch.bincount(protein_degree, minlength=protein_deg.numel()) 195 | 196 | return protein_deg 197 | 198 | 199 | def ligand_degree_from_dict(ligand_dict): 200 | mol_max_degree = -1 201 | clique_max_degree = -1 202 | 203 | for k, v in tqdm(ligand_dict.items()): 204 | # mol 205 | mol_x = v['atom_feature'] 206 | adj = v['bond_feature'] 207 | mol_edge_index = adj.nonzero(as_tuple=False).t().contiguous() 208 | mol_d = degree(mol_edge_index[1], num_nodes=mol_x.shape[0], dtype=torch.long) 209 | mol_max_degree = max(mol_max_degree, int(mol_d.max())) 210 | # clique 211 | clique_x = v['x_clique'] 212 | clique_edge_index = v['tree_edge_index'].long() 213 | clique_d = degree(clique_edge_index[1], num_nodes=clique_x.shape[0], dtype=torch.long) 214 | clique_max_degree = max(clique_max_degree, int(clique_d.max())) 215 | 216 | mol_deg = torch.zeros(mol_max_degree + 1, dtype=torch.long) 217 | clique_deg = torch.zeros(clique_max_degree + 1, dtype=torch.long) 218 | 219 | for k, v in tqdm(ligand_dict.items()): 220 | # mol 221 | mol_x = v['atom_feature'] 222 | adj = v['bond_feature'] 223 | mol_edge_index = adj.nonzero(as_tuple=False).t().contiguous() 224 | mol_d = degree(mol_edge_index[1], num_nodes=mol_x.shape[0], dtype=torch.long) 225 | 226 | mol_deg += torch.bincount(mol_d, minlength=mol_deg.numel()) 227 | # clique 228 | clique_x = v['x_clique'] 229 | clique_edge_index = v['tree_edge_index'].long() 230 | clique_d = degree(clique_edge_index[1], num_nodes=clique_x.shape[0], dtype=torch.long) 231 | clique_deg += torch.bincount(clique_d, minlength=clique_deg.numel()) 232 | 233 | return mol_deg, clique_deg 234 | 235 | 236 | def minmax_norm(arr): 237 | return (arr - arr.min())/(arr.max() - arr.min()) 238 | 239 | def percentile_rank(arr): 240 | return np.argsort(np.argsort(arr)) / (len(arr)-1) 241 | 242 | from rdkit import Chem 243 | from rdkit.Chem import PropertyPickleOptions 244 | import pickle 245 | 246 | def store_ligand_score(ligand_smiles, atom_types, atom_scores, ligand_path): 247 | # Create a molecule from a SMILES string 248 | mol = Chem.MolFromSmiles(ligand_smiles) 249 | # Add an atom-level property to the first atom 250 | for i, atom in enumerate(mol.GetAtoms()): 251 | 252 | if atom_types[i] == atom.GetSymbol(): 253 | atom.SetProp("PSICHIC_Atom_Score", str(atom_scores[i])) 254 | else: 255 | return False 256 | # Configure RDKit to pickle all properties 257 | Chem.SetDefaultPickleProperties(PropertyPickleOptions.AllProps) 258 | 259 | # Serialize molecule to a pickle file 260 | with open(ligand_path, 'wb') as f: 261 | pickle.dump(mol, f) 262 | 263 | return True 264 
| 265 | def store_result(df, attention_dict, interaction_keys, ligand_dict, 266 | reg_pred=None, cls_pred=None, mcls_pred=None, 267 | result_path='', save_interpret=True): 268 | if save_interpret: 269 | unbatched_residue_score = unbatch(attention_dict['residue_final_score'],attention_dict['protein_residue_index']) 270 | unbatched_atom_score = unbatch(attention_dict['atom_final_score'], attention_dict['drug_atom_index']) 271 | unbatched_residue_layer_score = unbatch(attention_dict['residue_layer_scores'],attention_dict['protein_residue_index']) 272 | unbatched_clique_layer_score = unbatch(attention_dict['clique_layer_scores'], attention_dict['drug_clique_index']) 273 | 274 | for idx, key in enumerate(interaction_keys): 275 | matching_row = (df['Protein'] == key[0]) & (df['Ligand'] == key[1]) 276 | if reg_pred is not None: 277 | if 'predicted_binding_affinity' in df.columns: 278 | df.loc[matching_row, 'predicted_binding_affinity'] = reg_pred[idx] 279 | else: 280 | df['predicted_binding_affinity'] = None 281 | df.loc[matching_row, 'predicted_binding_affinity'] = reg_pred[idx] 282 | if cls_pred is not None: 283 | if 'predicted_binary_interaction' in df.columns: 284 | df.loc[matching_row, 'predicted_binary_interaction'] = cls_pred[idx] 285 | else: 286 | df['predicted_binary_interaction'] = None 287 | df.loc[matching_row, 'predicted_binary_interaction'] = cls_pred[idx] 288 | 289 | if mcls_pred is not None: 290 | if 'predicted_antagonist' in df.columns and 'predicted_nonbinder' in df.columns and 'predicted_agonist' in df.columns: 291 | df.loc[matching_row, ['predicted_antagonist','predicted_nonbinder','predicted_agonist']] = mcls_pred[idx].tolist() 292 | else: 293 | df['predicted_antagonist'] = None 294 | df['predicted_nonbinder'] = None 295 | df['predicted_agonist'] = None 296 | df.loc[matching_row, ['predicted_antagonist','predicted_nonbinder','predicted_agonist']] = mcls_pred[idx].tolist() 297 | 298 | 299 | if all([idx in attention_dict['cluster_s'] for idx in range(3)]): 300 | unbatched_cluster_s0 = unbatch_nodes(attention_dict['cluster_s'][0].softmax(dim=-1), attention_dict['protein_residue_index']) 301 | unbatched_cluster_s1 = unbatch_nodes(attention_dict['cluster_s'][1].softmax(dim=-1), attention_dict['protein_residue_index']) 302 | unbatched_cluster_s2 = unbatch_nodes(attention_dict['cluster_s'][2].softmax(dim=-1), attention_dict['protein_residue_index']) 303 | 304 | if save_interpret: 305 | for pair_id in df[matching_row]['ID']: 306 | pair_path = os.path.join(result_path,pair_id) 307 | if not os.path.exists(pair_path): 308 | os.makedirs(pair_path) 309 | ## STORE Protein Interpretation 310 | protein_interpret = pd.DataFrame({ 311 | 'Residue_Type':list(key[0]), 312 | 'PSICHIC_Residue_Score':minmax_norm(unbatched_residue_score[idx].cpu().flatten().numpy()) 313 | }) 314 | 315 | protein_interpret['Residue_ID'] = protein_interpret.index + 1 316 | protein_interpret['PSICHIC_Residue_Percentile'] = percentile_rank(protein_interpret['PSICHIC_Residue_Score']) 317 | protein_interpret = protein_interpret[['Residue_ID','Residue_Type','PSICHIC_Residue_Score','PSICHIC_Residue_Percentile']] 318 | 319 | if all([id_ in attention_dict['cluster_s'] for id_ in range(3)]): 320 | for ci in range(5): 321 | protein_interpret['Layer0_Cluster'+str(ci)] = unbatched_cluster_s0[idx][:,ci].cpu().flatten().numpy() 322 | 323 | for ci in range(10): 324 | protein_interpret['Layer1_Cluster'+str(ci)] = unbatched_cluster_s1[idx][:,ci].cpu().flatten().numpy() 325 | 326 | for ci in range(20): 327 | 
protein_interpret['Layer2_Cluster'+str(ci)] = unbatched_cluster_s2[idx][:,ci].cpu().flatten().numpy() 328 | 329 | protein_interpret.to_csv(os.path.join(pair_path,'protein.csv'),index=False) 330 | 331 | ## STORE Ligand Interpretation 332 | ligand_path = os.path.join(pair_path,'ligand.pkl') 333 | 334 | successful_ligand = store_ligand_score(key[1], ligand_dict[key[1]]['atom_types'].split('|'), 335 | minmax_norm(unbatched_atom_score[idx].cpu().flatten().numpy()), 336 | ligand_path) 337 | if not successful_ligand: 338 | print('Ligand Intepretation for {} failed due to not matching atom order.'.format(pair_id)) 339 | ## STORE Fingerprint 340 | np.save(os.path.join(pair_path,'fingerprint.npy'), 341 | attention_dict['interaction_fingerprint'][idx].detach().cpu().numpy() 342 | ) 343 | return df 344 | 345 | def virtual_screening(screen_df, model, data_loader, result_path, save_interpret=True, ligand_dict=None, device=device, 346 | save_cluster=False): 347 | if "ID" in screen_df.columns: 348 | # Iterate through the DataFrame check any empty pairs 349 | for i, row in screen_df.iterrows(): 350 | if pd.isna(row['ID']): 351 | screen_df.at[i, 'ID'] = f"PAIR_{i}" 352 | else: 353 | screen_df['ID'] = 'PAIR_' 354 | screen_df['ID'] += screen_df.index.astype(str) 355 | reg_preds = [] 356 | cls_preds = [] 357 | mcls_preds = [] 358 | 359 | model.eval() 360 | 361 | with torch.no_grad(): 362 | for data in tqdm(data_loader): 363 | data = data.to(device) 364 | reg_pred, cls_pred, mcls_pred, sp_loss, o_loss, cl_loss, attention_dict = model( 365 | # Molecule 366 | mol_x=data.mol_x, mol_x_feat=data.mol_x_feat, bond_x=data.mol_edge_attr, 367 | atom_edge_index=data.mol_edge_index, clique_x=data.clique_x, 368 | clique_edge_index=data.clique_edge_index, atom2clique_index=data.atom2clique_index, 369 | # Protein 370 | residue_x=data.prot_node_aa, residue_evo_x=data.prot_node_evo, 371 | residue_edge_index=data.prot_edge_index, 372 | residue_edge_weight=data.prot_edge_weight, 373 | # Mol-Protein Interaction batch 374 | mol_batch=data.mol_x_batch, prot_batch=data.prot_node_aa_batch, clique_batch=data.clique_x_batch, 375 | # save_cluster 376 | save_cluster=save_cluster 377 | ) 378 | interaction_keys = list(zip(data.prot_key, data.mol_key)) 379 | 380 | if reg_pred is not None: 381 | reg_pred = reg_pred.squeeze().reshape(-1).cpu().numpy() 382 | reg_preds.append(reg_pred) 383 | 384 | if cls_pred is not None: 385 | cls_pred = torch.sigmoid(cls_pred).squeeze().reshape(-1).cpu().numpy() 386 | cls_preds.append(cls_pred) 387 | 388 | if mcls_pred is not None: 389 | mcls_pred = torch.softmax(mcls_pred,dim=-1).cpu().numpy() 390 | mcls_preds.append(mcls_pred) 391 | 392 | screen_df = store_result(screen_df, attention_dict, interaction_keys, ligand_dict, 393 | reg_pred, cls_pred, mcls_pred, 394 | result_path=result_path, save_interpret = save_interpret) 395 | 396 | return screen_df 397 | -------------------------------------------------------------------------------- /PSICHIC/psichic_utils/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch_geometric.data import Dataset 3 | # from torch.utils.data import Dataset 4 | import torch 5 | import pandas as pd 6 | from torch_geometric.data import Data 7 | import pickle 8 | import torch.utils.data 9 | from copy import deepcopy 10 | import numpy as np 11 | import sys 12 | import os 13 | 14 | BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 15 | sys.path.append(BASE_DIR) 16 | 17 | from 
runtime_config import RuntimeConfig 18 | device = RuntimeConfig.DEVICE 19 | 20 | 21 | class ProteinMoleculeDataset(Dataset): 22 | def __init__(self, sequence_data, mol_obj, prot_obj, device=device, cache_transform=True): 23 | super(ProteinMoleculeDataset, self).__init__() 24 | 25 | if isinstance(sequence_data,pd.core.frame.DataFrame): 26 | self.pairs = sequence_data 27 | elif isinstance(sequence_data,str): 28 | self.pairs = pd.read_csv(sequence_data) 29 | else: 30 | raise Exception("provide dataframe object or csv path") 31 | 32 | ## MOLECULES 33 | if isinstance(mol_obj, dict): 34 | self.mols = mol_obj 35 | elif isinstance(mol_obj, str): 36 | with open(mol_obj, 'rb') as f: 37 | self.mols = pickle.load(f) 38 | else: 39 | raise Exception("provide dict mol object or pickle path") 40 | 41 | 42 | ## PROTEINS 43 | if isinstance(prot_obj, dict): 44 | self.prots = prot_obj 45 | elif isinstance(prot_obj, str): 46 | self.prots = torch.load(prot_obj) 47 | else: 48 | raise Exception("provide dict mol object or pickle path") 49 | 50 | self.device = device 51 | self.cache_transform = cache_transform 52 | 53 | if self.cache_transform: 54 | for _, v in self.mols.items(): 55 | v['atom_idx'] = v['atom_idx'].long().view(-1, 1) 56 | v['atom_feature'] = v['atom_feature'].float() 57 | adj = v['bond_feature'].long() 58 | mol_edge_index = adj.nonzero(as_tuple=False).t().contiguous() 59 | v['atom_edge_index'] = mol_edge_index 60 | v['atom_edge_attr'] = adj[mol_edge_index[0], mol_edge_index[1]].long() 61 | v['atom_num_nodes'] = v['atom_idx'].shape[0] 62 | 63 | ## Clique 64 | v['x_clique'] = v['x_clique'].long().view(-1, 1) 65 | v['clique_num_nodes'] = v['x_clique'].shape[0] 66 | v['tree_edge_index'] = v['tree_edge_index'].long() 67 | v['atom2clique_index'] = v['atom2clique_index'].long() 68 | 69 | for _, v in self.prots.items(): 70 | v['seq_feat'] = v['seq_feat'].float() 71 | v['token_representation'] = v['token_representation'].float() 72 | v['num_nodes'] = len(v['seq']) 73 | v['node_pos'] = torch.arange(len(v['seq'])).reshape(-1,1) 74 | v['edge_weight'] = v['edge_weight'].float() 75 | 76 | def get(self, index): 77 | return self.__getitem__(index) 78 | 79 | def len(self): 80 | return self.__len__() 81 | def __len__(self): 82 | return len(self.pairs) 83 | 84 | 85 | def __getitem__(self, idx): 86 | # Extract data 87 | mol_key = self.pairs.loc[idx,'Ligand'] 88 | prot_key = self.pairs.loc[idx,'Protein'] 89 | try: 90 | reg_y = self.pairs.loc[idx,'regression_label'] 91 | reg_y = torch.tensor(reg_y).float() 92 | except KeyError: 93 | reg_y = None 94 | 95 | 96 | try: 97 | cls_y = self.pairs.loc[idx,'classification_label'] 98 | cls_y = torch.tensor(cls_y).float() 99 | except KeyError: 100 | cls_y = None 101 | 102 | try: 103 | mcls_y = self.pairs.loc[idx,'multiclass_label'] 104 | mcls_y = torch.tensor(mcls_y + 1).float() 105 | except KeyError: 106 | mcls_y = None 107 | 108 | mol = self.mols[mol_key] 109 | prot = self.prots[prot_key] 110 | 111 | ## PROT 112 | if self.cache_transform: 113 | ## atom 114 | mol_x = mol['atom_idx'] 115 | mol_x_feat = mol['atom_feature'] 116 | mol_edge_index = mol['atom_edge_index'] 117 | mol_edge_attr = mol['atom_edge_attr'] 118 | mol_num_nodes = mol['atom_num_nodes'] 119 | 120 | ## Clique 121 | mol_x_clique = mol['x_clique'] 122 | clique_num_nodes = mol['clique_num_nodes'] 123 | clique_edge_index = mol['tree_edge_index'] 124 | atom2clique_index = mol['atom2clique_index'] 125 | ## Prot 126 | prot_seq = prot['seq'] 127 | prot_node_aa = prot['seq_feat'] 128 | prot_node_evo = 
prot['token_representation'] 129 | prot_num_nodes = prot['num_nodes'] 130 | prot_node_pos = prot['node_pos'] 131 | prot_edge_index = prot['edge_index'] 132 | prot_edge_weight = prot['edge_weight'] 133 | else: 134 | # MOL 135 | mol_x = mol['atom_idx'].long().view(-1, 1) 136 | mol_x_feat = mol['atom_feature'].float() 137 | adj = mol['bond_feature'].long() 138 | mol_edge_index = adj.nonzero(as_tuple=False).t().contiguous() 139 | mol_edge_attr = adj[mol_edge_index[0], mol_edge_index[1]].long() 140 | mol_num_nodes = mol_x.shape[0] 141 | 142 | ## Clique 143 | mol_x_clique = mol['x_clique'].long().view(-1, 1) 144 | clique_num_nodes = mol_x_clique.shape[0] 145 | clique_edge_index = mol['tree_edge_index'].long() 146 | atom2clique_index = mol['atom2clique_index'].long() 147 | 148 | 149 | prot_seq = prot['seq'] 150 | prot_node_aa = prot['seq_feat'].float() 151 | prot_node_evo = prot['token_representation'].float() 152 | prot_num_nodes = len(prot['seq']) 153 | prot_node_pos = torch.arange(len(prot['seq'])).reshape(-1,1) 154 | prot_edge_index = prot['edge_index'] 155 | prot_edge_weight = prot['edge_weight'].float() 156 | 157 | out = MultiGraphData( 158 | ## MOLECULE 159 | mol_x=mol_x, mol_x_feat=mol_x_feat, mol_edge_index=mol_edge_index, 160 | mol_edge_attr=mol_edge_attr, mol_num_nodes= mol_num_nodes, 161 | clique_x=mol_x_clique, clique_edge_index=clique_edge_index, atom2clique_index=atom2clique_index, 162 | clique_num_nodes=clique_num_nodes, 163 | ## PROTEIN 164 | prot_node_aa=prot_node_aa, prot_node_evo=prot_node_evo, 165 | prot_node_pos=prot_node_pos, prot_seq=prot_seq, 166 | prot_edge_index=prot_edge_index, prot_edge_weight=prot_edge_weight, 167 | prot_num_nodes=prot_num_nodes, 168 | ## Y output 169 | reg_y=reg_y, cls_y=cls_y, mcls_y=mcls_y, 170 | ## keys 171 | mol_key = mol_key, prot_key = prot_key 172 | ) 173 | 174 | return out 175 | 176 | def maybe_num_nodes(index, num_nodes=None): 177 | # NOTE(WMF): I find out a problem here, 178 | # index.max().item() -> int 179 | # num_nodes -> tensor 180 | # need type conversion. 181 | # return index.max().item() + 1 if num_nodes is None else num_nodes 182 | return index.max().item() + 1 if num_nodes is None else int(num_nodes) 183 | 184 | def get_self_loop_attr(edge_index, edge_attr, num_nodes): 185 | r"""Returns the edge features or weights of self-loops 186 | :math:`(i, i)` of every node :math:`i \in \mathcal{V}` in the 187 | graph given by :attr:`edge_index`. Edge features of missing self-loops not 188 | present in :attr:`edge_index` will be filled with zeros. If 189 | :attr:`edge_attr` is not given, it will be the vector of ones. 190 | 191 | .. note:: 192 | This operation is analogous to getting the diagonal elements of the 193 | dense adjacency matrix. 194 | 195 | Args: 196 | edge_index (LongTensor): The edge indices. 197 | edge_attr (Tensor, optional): Edge weights or multi-dimensional edge 198 | features. (default: :obj:`None`) 199 | num_nodes (int, optional): The number of nodes, *i.e.* 200 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 201 | 202 | :rtype: :class:`Tensor` 203 | 204 | Examples: 205 | 206 | >>> edge_index = torch.tensor([[0, 1, 0], 207 | ... 
[1, 0, 0]]) 208 | >>> edge_weight = torch.tensor([0.2, 0.3, 0.5]) 209 | >>> get_self_loop_attr(edge_index, edge_weight) 210 | tensor([0.5000, 0.0000]) 211 | 212 | >>> get_self_loop_attr(edge_index, edge_weight, num_nodes=4) 213 | tensor([0.5000, 0.0000, 0.0000, 0.0000]) 214 | """ 215 | loop_mask = edge_index[0] == edge_index[1] 216 | loop_index = edge_index[0][loop_mask] 217 | 218 | if edge_attr is not None: 219 | loop_attr = edge_attr[loop_mask] 220 | else: # A vector of ones: 221 | loop_attr = torch.ones_like(loop_index, dtype=torch.float) 222 | 223 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 224 | full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:]) 225 | full_loop_attr[loop_index] = loop_attr 226 | 227 | return full_loop_attr 228 | 229 | 230 | 231 | class MultiGraphData(Data): 232 | def __inc__(self, key, item, *args): 233 | if key == 'mol_edge_index': 234 | return self.mol_x.size(0) 235 | elif key == 'clique_edge_index': 236 | return self.clique_x.size(0) 237 | elif key == 'atom2clique_index': 238 | return torch.tensor([[self.mol_x.size(0)], [self.clique_x.size(0)]]) 239 | elif key == 'prot_edge_index': 240 | return self.prot_node_aa.size(0) 241 | elif key == 'prot_struc_edge_index': 242 | return self.prot_node_aa.size(0) 243 | elif key == 'm2p_edge_index': 244 | return torch.tensor([[self.mol_x.size(0)], [self.prot_node_aa.size(0)]]) 245 | # elif key == 'edge_index_p2m': 246 | # return torch.tensor([[self.prot_node_s.size(0)],[self.mol_x.size(0)]]) 247 | else: 248 | return super(MultiGraphData, self).__inc__(key, item, *args) 249 | 250 | -------------------------------------------------------------------------------- /PSICHIC/psichic_utils/interpretation.py: -------------------------------------------------------------------------------- 1 | 2 | # def read_molecule 3 | # # Deserialize molecule from the pickle file 4 | # with open('molecule.pkl', 'rb') as f: 5 | # m2 = pickle.load(f) 6 | 7 | # # Print the properties to verify they are preserved 8 | # print('Molecule-level properties:', m2.GetPropsAsDict()) 9 | # print('Atom-level properties for the first atom:', m2.GetAtomWithIdx(0).GetPropsAsDict()) -------------------------------------------------------------------------------- /PSICHIC/psichic_utils/ligand_init.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from rdkit.Chem.rdchem import BondType 3 | 4 | from rdkit.Chem import ChemicalFeatures 5 | from rdkit import RDConfig 6 | import os 7 | import numpy as np 8 | 9 | import torch 10 | 11 | fdefName = os.path.join(RDConfig.RDDataDir,'BaseFeatures.fdef') 12 | factory = ChemicalFeatures.BuildFeatureFactory(fdefName) 13 | import sys 14 | 15 | # Check if the code is running in a Jupyter notebook 16 | if 'ipykernel' in sys.modules: 17 | from tqdm.notebook import tqdm 18 | else: 19 | from tqdm import tqdm 20 | 21 | 22 | 23 | 24 | def one_of_k_encoding(x, allowable_set): 25 | if x not in allowable_set: 26 | raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set)) 27 | return list(map(lambda s: x == s, allowable_set)) 28 | 29 | def one_of_k_encoding_unk(x, allowable_set): 30 | """Maps inputs not in the allowable set to the last element.""" 31 | if x not in allowable_set: 32 | x = allowable_set[-1] 33 | return list(map(lambda s: x == s, allowable_set)) 34 | 35 | 36 | def atom_features(atom): 37 | #encoding = one_of_k_encoding_unk(atom.GetSymbol(),['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na','Ca', 'Fe', 
'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb','Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H','Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr','Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) 38 | encoding = one_of_k_encoding(atom.GetDegree(), [0,1,2,3,4,5,6,7,8,9,10]) + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0,1,2,3,4,5,6,7,8,9,10]) 39 | encoding += one_of_k_encoding_unk(atom.GetImplicitValence(), [0,1,2,3,4,5,6,7,8,9,10]) 40 | encoding += one_of_k_encoding_unk(atom.GetHybridization(), [ 41 | Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, 42 | Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, 43 | Chem.rdchem.HybridizationType.SP3D2, 'other']) 44 | # encoding += one_of_k_encoding_unk(atom.GetFormalCharge(), [0,-1,1,2,-100]) 45 | # encoding += one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), [0,1,2,-100]) 46 | encoding += [atom.GetIsAromatic()] 47 | # encoding += [atom.IsInRing()] 48 | 49 | try: 50 | encoding += one_of_k_encoding_unk( 51 | atom.GetProp('_CIPCode'), 52 | ['R', 'S']) + [atom.HasProp('_ChiralityPossible')] 53 | except: 54 | encoding += [0, 0] + [atom.HasProp('_ChiralityPossible')] 55 | 56 | return np.array(encoding) 57 | 58 | 59 | 60 | class MoleculeGraphDataset(): 61 | def __init__(self,atom_classes=None, halogen_detail=False, save_path=None): 62 | ## ATOM CLASSES ## 63 | self.ATOM_CODES = {} 64 | if atom_classes is None: 65 | metals = ([3, 4, 11, 12, 13] + list(range(19, 32)) 66 | + list(range(37, 51)) + list(range(55, 84)) 67 | + list(range(87, 104))) 68 | 69 | self.FEATURE_NAMES = [] 70 | if halogen_detail: 71 | atom_classes = [ 72 | (5, 'B'), 73 | (6, 'C'), 74 | (7, 'N'), 75 | (8, 'O'), 76 | (15, 'P'), 77 | (16, 'S'), 78 | (34, 'Se'), 79 | ## halogen 80 | (9, 'F'), 81 | (17, 'Cl'), 82 | (35, 'Br'), 83 | (53, 'I'), 84 | ## halogen 85 | (metals, 'metal') 86 | ] 87 | else: 88 | atom_classes = [ 89 | (5, 'B'), 90 | (6, 'C'), 91 | (7, 'N'), 92 | (8, 'O'), 93 | (15, 'P'), 94 | (16, 'S'), 95 | (34, 'Se'), 96 | ## halogen 97 | ([9, 17, 35, 53], 'halogen'), 98 | ## halogen 99 | (metals, 'metal') 100 | ] 101 | 102 | 103 | self.NUM_ATOM_CLASSES = len(atom_classes) 104 | for code, (atom, name) in enumerate(atom_classes): 105 | if type(atom) is list: 106 | for a in atom: 107 | self.ATOM_CODES[a] = code 108 | else: 109 | self.ATOM_CODES[atom] = code 110 | self.FEATURE_NAMES.append(name) 111 | 112 | ## Extra atom feature to extract 113 | self.feat_types = ['Donor', 'Acceptor', 'Hydrophobe', 'LumpedHydrophobe'] 114 | 115 | ## Bond feature 116 | self.edge_dict = {BondType.SINGLE: 1, BondType.DOUBLE: 2, 117 | BondType.TRIPLE: 3, BondType.AROMATIC: 4, 118 | BondType.UNSPECIFIED: 1} 119 | ## File Paths 120 | self.save_path = save_path 121 | 122 | def hybridization_onehot(self,hybrid_type): 123 | hybrid_type = str(hybrid_type) 124 | types = {'S': 0, 'SP': 1, 'SP2': 2, 'SP3': 3, 'SP3D': 4, 'SP3D2': 5} 125 | 126 | encoding = np.zeros(len(types)) 127 | try: 128 | encoding[types[hybrid_type]] = 1.0 129 | except: 130 | pass 131 | return encoding 132 | 133 | def encode_num(self,atomic_num): 134 | """Encode atom type with a binary vector. If atom type is not included in 135 | the `atom_classes`, its encoding is an all-zeros vector. 136 | 137 | Parameters 138 | ---------- 139 | atomic_num: int 140 | Atomic number 141 | 142 | Returns 143 | ------- 144 | encoding: np.ndarray 145 | Binary vector encoding atom type (one-hot or null). 
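
        Examples
        --------
        Illustrative sketch assuming the default atom classes built with
        ``halogen_detail=False``, under which carbon (atomic number 6) is
        class index 1 of the 9 classes:

        >>> mgd = MoleculeGraphDataset(halogen_detail=False)
        >>> mgd.encode_num(6)
        array([0., 1., 0., 0., 0., 0., 0., 0., 0.])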
146 | """ 147 | 148 | if not isinstance(atomic_num, int): 149 | raise TypeError('Atomic number must be int, %s was given' 150 | % type(atomic_num)) 151 | 152 | encoding = np.zeros(self.NUM_ATOM_CLASSES) 153 | try: 154 | encoding[self.ATOM_CODES[atomic_num]] = 1.0 155 | except: 156 | pass 157 | return encoding 158 | 159 | def atom_feature_extract(self,atom): 160 | ''' 161 | Atom Feature Extraction: 162 | 0 - Degree 163 | 1 - Total Valency 164 | 2 to 7 - Hybridization Type One-hot Encoding 165 | 8 - Number of Radical Electrons 166 | 9 - Number of Formal Charge 167 | 10 - Aromatic 168 | 11 - Belongs to a Ring 169 | 12 - Final to X belongs to Atom Classes 170 | ''' 171 | 172 | feat = [] 173 | 174 | feat.append(atom.GetDegree()) 175 | feat.append(atom.GetTotalValence()) 176 | feat += self.hybridization_onehot(atom.GetHybridization()).tolist() 177 | feat.append(atom.GetNumRadicalElectrons()) 178 | feat.append(atom.GetFormalCharge()) 179 | feat.append(int(atom.GetIsAromatic())) 180 | feat.append(int(atom.IsInRing())) 181 | # Atom class 182 | #feat += self.encode_num(atom.GetAtomicNum()).tolist() 183 | 184 | return feat 185 | 186 | def mol_feature(self,mol): 187 | atom_ids = [] 188 | atom_feats = [] 189 | for atom in mol.GetAtoms(): 190 | atom_ids.append(atom.GetIdx()) 191 | feat = self.atom_feature_extract(atom) 192 | atom_feats.append(feat) 193 | 194 | feature = np.array(list(zip(*sorted(zip(atom_ids, atom_feats))))[-1]) 195 | 196 | return feature 197 | 198 | def mol_extra_feature(self, mol): 199 | atom_num = len(mol.GetAtoms()) 200 | feature = np.zeros((atom_num, len(self.feat_types))) 201 | 202 | fact_feats = factory.GetFeaturesForMol(mol) 203 | for f in fact_feats: 204 | f_type = f.GetFamily() 205 | if f_type in self.feat_types: 206 | f_index = self.feat_types.index(f_type) 207 | atom_ids = f.GetAtomIds() 208 | feature[atom_ids, f_index] = 1 209 | 210 | return feature 211 | 212 | def mol_simplified_feature(self,mol): 213 | atom_ids = [] 214 | atom_feats = [] 215 | for atom in mol.GetAtoms(): 216 | atom_ids.append(atom.GetIdx()) 217 | atomic_num = atom.GetAtomicNum() 218 | 219 | if atomic_num in self.ATOM_CODES.keys(): 220 | atom_feats.append([self.ATOM_CODES[atomic_num] + 1]) 221 | else: 222 | atom_feats.append([0]) 223 | feature = np.array(list(zip(*sorted(zip(atom_ids, atom_feats))))[-1]) 224 | 225 | return feature 226 | 227 | def mol_sequence_simplified_feature(self,mol): 228 | 229 | atom_ids = [] 230 | atom_feats = [] 231 | for atom in mol.GetAtoms(): 232 | 233 | atom_ids.append(atom.GetIdx()) 234 | onehot_label = one_of_k_encoding_unk(atom.GetSymbol(), 235 | ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 236 | 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 237 | 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 238 | 'Unknown']) 239 | out = np.array(onehot_label).nonzero()[0] 240 | atom_feats.append(out) 241 | 242 | feature = np.array(list(zip(*sorted(zip(atom_ids, atom_feats))))[-1]) 243 | 244 | return feature 245 | 246 | 247 | 248 | 249 | def mol_full_feature(self, mol): 250 | atom_ids = [] 251 | atom_feats = [] 252 | for atom in mol.GetAtoms(): 253 | atom_ids.append(atom.GetIdx()) 254 | feature = atom_features(atom) 255 | atom_feats.append(feature) 256 | feature = np.array(list(zip(*sorted(zip(atom_ids, atom_feats))))[-1]) 257 | 258 | return feature 259 | 260 | def bond_feature(self,mol): 261 | atom_num = len(mol.GetAtoms()) 262 | adj = np.zeros((atom_num,atom_num)) 263 | 264 | for 
b in mol.GetBonds(): 265 | v1 = b.GetBeginAtomIdx() 266 | v2 = b.GetEndAtomIdx() 267 | b_type = self.edge_dict[b.GetBondType()] 268 | adj[v1 - 1, v2 - 1] = b_type 269 | adj[v2 - 1, v1 - 1] = b_type 270 | 271 | return adj 272 | 273 | def junction_tree(self,mol): 274 | tree_edge_index, atom2clique_index, num_cliques, x_clique = tree_decomposition(mol,return_vocab=True) 275 | ## if weird compounds => each assign the separate cluster 276 | if atom2clique_index.nelement() == 0: 277 | num_cliques = len(mol.GetAtoms()) 278 | x_clique = torch.tensor([3]*num_cliques) 279 | atom2clique_index = torch.stack([torch.arange(num_cliques), 280 | torch.arange(num_cliques)]) 281 | tree = dict(tree_edge_index=tree_edge_index, 282 | atom2clique_index=atom2clique_index, 283 | num_cliques=num_cliques, 284 | x_clique=x_clique) 285 | 286 | return tree 287 | 288 | 289 | def featurize(self,mol,type='atom_type'): 290 | if type=='atom_type': 291 | atom_feature = self.mol_simplified_feature(mol) 292 | elif type =='detailed_atom_type': 293 | atom_feature = self.mol_sequence_simplified_feature(mol) 294 | elif type=='atom_feature': 295 | base_feat = self.mol_feature(mol) 296 | extra_feat = self.mol_extra_feature(mol) 297 | atom_feature = np.concatenate((base_feat, extra_feat), axis=1) 298 | 299 | elif type=='atom_full_feature': 300 | atom_feature = self.mol_full_feature(mol) 301 | #extra_feat = self.mol_extra_feature(mol) 302 | #atom_feature = np.concatenate((base_feat, extra_feat), axis=1) 303 | else: 304 | raise Exception('atom_type or atom_feature') 305 | bond_feature = self.bond_feature(mol) 306 | 307 | return atom_feature, bond_feature 308 | 309 | 310 | 311 | 312 | ## FILE from pytorch geometric version 2. 313 | from itertools import chain 314 | from typing import Any, Tuple, Union 315 | 316 | import torch 317 | from scipy.sparse.csgraph import minimum_spanning_tree 318 | from torch import Tensor 319 | 320 | from torch_geometric.utils import ( 321 | from_scipy_sparse_matrix, 322 | to_scipy_sparse_matrix, 323 | to_undirected, 324 | ) 325 | 326 | 327 | def tree_decomposition( 328 | mol: Any, 329 | return_vocab: bool = False, 330 | ) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, int, Tensor]]: 331 | r"""The tree decomposition algorithm of molecules from the 332 | `"Junction Tree Variational Autoencoder for Molecular Graph Generation" 333 | `_ paper. 334 | Returns the graph connectivity of the junction tree, the assignment 335 | mapping of each atom to the clique in the junction tree, and the number 336 | of cliques. 337 | 338 | Args: 339 | mol (rdkit.Chem.Mol): An :obj:`rdkit` molecule. 340 | return_vocab (bool, optional): If set to :obj:`True`, will return an 341 | identifier for each clique (ring, bond, bridged compounds, single). 342 | (default: :obj:`False`) 343 | 344 | :rtype: :obj:`(LongTensor, LongTensor, int)` if :obj:`return_vocab` is 345 | :obj:`False`, else :obj:`(LongTensor, LongTensor, int, LongTensor)` 346 | """ 347 | import rdkit.Chem as Chem 348 | 349 | # Cliques = rings and bonds. 350 | cliques = [list(x) for x in Chem.GetSymmSSSR(mol)] 351 | xs = [0] * len(cliques) 352 | for bond in mol.GetBonds(): 353 | if not bond.IsInRing(): 354 | cliques.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) 355 | xs.append(1) 356 | 357 | # Generate `atom2clique` mappings. 
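    # For example (illustrative only): ethanol, Chem.MolFromSmiles('CCO'),
    # has no rings, so the cliques above are just the two bonds
    # [[0, 1], [1, 2]], and the mapping built below is
    # atom2clique == [[0], [0, 1], [1]]; the shared carbon lies in both cliques.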
358 | atom2clique = [[] for i in range(mol.GetNumAtoms())] 359 | for c in range(len(cliques)): 360 | for atom in cliques[c]: 361 | atom2clique[atom].append(c) 362 | 363 | # Merge rings that share more than 2 atoms as they form bridged compounds. 364 | for c1 in range(len(cliques)): 365 | for atom in cliques[c1]: 366 | for c2 in atom2clique[atom]: 367 | if c1 >= c2 or len(cliques[c1]) <= 2 or len(cliques[c2]) <= 2: 368 | continue 369 | if len(set(cliques[c1]) & set(cliques[c2])) > 2: 370 | cliques[c1] = set(cliques[c1]) | set(cliques[c2]) 371 | xs[c1] = 2 372 | cliques[c2] = [] 373 | xs[c2] = -1 374 | cliques = [c for c in cliques if len(c) > 0] 375 | xs = [x for x in xs if x >= 0] 376 | 377 | # Update `atom2clique` mappings. 378 | atom2clique = [[] for i in range(mol.GetNumAtoms())] 379 | for c in range(len(cliques)): 380 | for atom in cliques[c]: 381 | atom2clique[atom].append(c) 382 | 383 | # Add singleton cliques in case there are more than 2 intersecting 384 | # cliques. We further compute the "initial" clique graph. 385 | edges = {} 386 | for atom in range(mol.GetNumAtoms()): 387 | cs = atom2clique[atom] 388 | if len(cs) <= 1: 389 | continue 390 | 391 | # Number of bond clusters that the atom lies in. 392 | bonds = [c for c in cs if len(cliques[c]) == 2] 393 | # Number of ring clusters that the atom lies in. 394 | rings = [c for c in cs if len(cliques[c]) > 4] 395 | 396 | if len(bonds) > 2 or (len(bonds) == 2 and len(cs) > 2): 397 | cliques.append([atom]) 398 | xs.append(3) 399 | c2 = len(cliques) - 1 400 | for c1 in cs: 401 | edges[(c1, c2)] = 1 402 | 403 | elif len(rings) > 2: 404 | cliques.append([atom]) 405 | xs.append(3) 406 | c2 = len(cliques) - 1 407 | for c1 in cs: 408 | edges[(c1, c2)] = 99 409 | 410 | else: 411 | for i in range(len(cs)): 412 | for j in range(i + 1, len(cs)): 413 | c1, c2 = cs[i], cs[j] 414 | count = len(set(cliques[c1]) & set(cliques[c2])) 415 | edges[(c1, c2)] = min(count, edges.get((c1, c2), 99)) 416 | 417 | # Update `atom2clique` mappings. 
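    # The singleton cliques appended above change the clique indices again, so the atom-to-clique
    # mapping is rebuilt one final time before the junction-tree tensors are assembled.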
418 | atom2clique = [[] for i in range(mol.GetNumAtoms())] 419 | for c in range(len(cliques)): 420 | for atom in cliques[c]: 421 | atom2clique[atom].append(c) 422 | 423 | if len(edges) > 0: 424 | edge_index_T, weight = zip(*edges.items()) 425 | edge_index = torch.tensor(edge_index_T).t() 426 | inv_weight = 100 - torch.tensor(weight) 427 | graph = to_scipy_sparse_matrix(edge_index, inv_weight, len(cliques)) 428 | junc_tree = minimum_spanning_tree(graph) 429 | edge_index, _ = from_scipy_sparse_matrix(junc_tree) 430 | edge_index = to_undirected(edge_index, num_nodes=len(cliques)) 431 | else: 432 | edge_index = torch.empty((2, 0), dtype=torch.long) 433 | 434 | rows = [[i] * len(atom2clique[i]) for i in range(mol.GetNumAtoms())] 435 | row = torch.tensor(list(chain.from_iterable(rows))) 436 | col = torch.tensor(list(chain.from_iterable(atom2clique))) 437 | atom2clique = torch.stack([row, col], dim=0).to(torch.long) 438 | 439 | if return_vocab: 440 | vocab = torch.tensor(xs, dtype=torch.long) 441 | return edge_index, atom2clique, len(cliques), vocab 442 | else: 443 | return edge_index, atom2clique, len(cliques) 444 | 445 | 446 | ### 447 | 448 | def smiles2graph(m_str): 449 | mgd = MoleculeGraphDataset(halogen_detail=False) 450 | mol = Chem.MolFromSmiles(m_str) 451 | if mol is None: 452 | return None 453 | #mol = get_mol(m_str) 454 | try: 455 | atom_feature, bond_feature = mgd.featurize(mol,'atom_full_feature') 456 | atom_idx, _ = mgd.featurize(mol,'atom_type') 457 | tree = mgd.junction_tree(mol) 458 | 459 | out_dict = { 460 | 'smiles':m_str, 461 | 'atom_feature':torch.tensor(atom_feature),#.to(torch.int8), 462 | 'atom_types':'|'.join([i.GetSymbol() for i in mol.GetAtoms()]), 463 | 'atom_idx':torch.tensor(atom_idx),#.to(torch.int8), 464 | 'bond_feature':torch.tensor(bond_feature),#.to(torch.int8), 465 | 466 | } 467 | tree['tree_edge_index'] = tree['tree_edge_index']#.to(torch.int8) 468 | tree['atom2clique_index'] = tree['atom2clique_index']#.to(torch.int8) 469 | tree['x_clique'] = tree['x_clique']#.to(torch.int8) 470 | 471 | out_dict.update(tree) 472 | 473 | return out_dict 474 | except Exception as e: 475 | return None 476 | #### 477 | 478 | 479 | def ligand_init(smiles_list): 480 | ligand_dict = {} 481 | for smiles in tqdm(smiles_list): 482 | graph = smiles2graph(smiles) 483 | if graph is None: 484 | print(f"Error: {smiles} is an invalid SMILES string") 485 | continue 486 | ligand_dict[smiles] = graph 487 | return ligand_dict -------------------------------------------------------------------------------- /PSICHIC/psichic_utils/metrics.py: -------------------------------------------------------------------------------- 1 | from lifelines.utils import concordance_index 2 | from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, recall_score, precision_score, accuracy_score 3 | import numpy as np 4 | from math import sqrt 5 | from sklearn.linear_model import LinearRegression 6 | from scipy import stats 7 | 8 | def get_cindex(Y, P): 9 | return concordance_index(Y, P) 10 | 11 | 12 | def get_mse(Y, P): 13 | Y = np.array(Y) 14 | P = np.array(P) 15 | return np.average((Y - P) ** 2) 16 | 17 | 18 | # Prepare for rm2 19 | def get_k(y_obs, y_pred): 20 | y_obs = np.array(y_obs) 21 | y_pred = np.array(y_pred) 22 | return sum(y_obs * y_pred) / sum(y_pred ** 2) 23 | 24 | 25 | # Prepare for rm2 26 | def squared_error_zero(y_obs, y_pred): 27 | k = get_k(y_obs, y_pred) 28 | y_obs = np.array(y_obs) 29 | y_pred = np.array(y_pred) 30 | y_obs_mean = np.mean(y_obs) 31 | upp = sum((y_obs - k * 
y_pred) ** 2) 32 | down = sum((y_obs - y_obs_mean) ** 2) 33 | 34 | return 1 - (upp / down) 35 | 36 | 37 | # Prepare for rm2 38 | def r_squared_error(y_obs, y_pred): 39 | y_obs = np.array(y_obs) 40 | y_pred = np.array(y_pred) 41 | y_obs_mean = np.mean(y_obs) 42 | y_pred_mean = np.mean(y_pred) 43 | mult = sum((y_obs - y_obs_mean) * (y_pred - y_pred_mean)) ** 2 44 | y_obs_sq = sum((y_obs - y_obs_mean) ** 2) 45 | y_pred_sq = sum((y_pred - y_pred_mean) ** 2) 46 | return mult / (y_obs_sq * y_pred_sq) 47 | 48 | 49 | def get_rm2(Y, P): 50 | r2 = r_squared_error(Y, P) 51 | r02 = squared_error_zero(Y, P) 52 | 53 | return r2 * (1 - np.sqrt(np.absolute(r2 ** 2 - r02 ** 2))) 54 | 55 | 56 | def cos_formula(a, b, c): 57 | ''' formula to calculate the angle between two edges 58 | a and b are the edge lengths, c is the angle length. 59 | ''' 60 | res = (a**2 + b**2 - c**2) / (2 * a * b) 61 | # sanity check 62 | res = -1. if res < -1. else res 63 | res = 1. if res > 1. else res 64 | return np.arccos(res) 65 | 66 | def get_rmse(y,f): 67 | rmse = sqrt(((y - f)**2).mean(axis=0)) 68 | return rmse 69 | 70 | def get_mae(y,f): 71 | mae = (np.abs(y-f)).mean() 72 | return mae 73 | 74 | def get_sd(y,f): 75 | f,y = f.reshape(-1,1),y.reshape(-1,1) 76 | lr = LinearRegression() 77 | lr.fit(f,y) 78 | y_ = lr.predict(f) 79 | sd = (((y - y_) ** 2).sum() / (len(y) - 1)) ** 0.5 80 | return sd 81 | 82 | def get_pearson(y,f): 83 | rp = np.corrcoef(y, f)[0,1] 84 | return rp 85 | 86 | def get_spearman(y,f): 87 | sp = stats.spearmanr(y,f)[0] 88 | 89 | return sp 90 | 91 | def evaluate_reg(Y, F): 92 | not_nan_indices = ~np.isnan(Y) 93 | Y = Y[not_nan_indices] 94 | F = F[not_nan_indices] 95 | 96 | return { 97 | 'mse': float(get_mse(Y,F)), 98 | 'rmse': float(get_rmse(Y,F)), 99 | 'mae': float(get_mae(Y,F)), 100 | 'sd': float(get_sd(Y,F)), 101 | 'pearson': float(get_pearson(Y,F)), 102 | 'spearman': float(get_spearman(Y,F)), 103 | 'rm2': float(get_rm2(Y,F)), 104 | 'ci': float(get_cindex(Y,F)) 105 | } 106 | 107 | def evaluate_cls(Y,P,threshold=0.5): 108 | predicted_label = P > threshold 109 | 110 | return { 111 | 'roc': float(roc_auc_score(Y,P)), 112 | 'prc': float(average_precision_score(Y,P)), 113 | 'f1': float(f1_score(Y,predicted_label)), 114 | 'recall':float(recall_score(Y, predicted_label)), 115 | 'precision': float(precision_score(Y, predicted_label)) 116 | } 117 | 118 | def indices_to_one_hot(data, nb_classes): 119 | """Convert an iterable of indices to one-hot encoded labels.""" 120 | targets = np.array(data).reshape(-1) 121 | return np.eye(nb_classes)[targets] 122 | 123 | # from sklearn.metrics import precision_recall_curve 124 | # from sklearn.metrics import average_precision_score 125 | 126 | def multiclass_ap(Y_test, y_score, n_classes): 127 | # For each class 128 | # precision = dict() 129 | # recall = dict() 130 | average_precision = dict() 131 | for i in range(n_classes): 132 | num_labels = Y_test[:, i].sum() 133 | if num_labels == 0: continue 134 | average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i]) 135 | return sum(average_precision.values()) / len(average_precision) 136 | 137 | def evaluate_mcls(Y,P): 138 | nclass = P.shape[-1] 139 | # Filter Y and P based on n_classes 140 | valid_indices = np.isin(Y, np.arange(nclass)) 141 | Y = Y[valid_indices] 142 | P = P[valid_indices] 143 | 144 | onehot_y = indices_to_one_hot(Y,nclass) 145 | try: 146 | roc = roc_auc_score(onehot_y, P, average='macro',multi_class='ovo') 147 | prc = multiclass_ap(onehot_y, P, n_classes=nclass) 148 | except: 149 | roc = -999 
150 | prc = -999 151 | pred_class = np.argmax(P,axis=-1) 152 | acc = accuracy_score(Y,pred_class) 153 | multi_result = { 154 | 'multiclass_roc': float(roc), 155 | 'multiclass_prc': float(prc), 156 | 'multiclass_accuracy':float(acc), 157 | 'macro_f1':float(f1_score(Y,pred_class,average='macro')), 158 | } 159 | 160 | return multi_result 161 | 162 | 163 | if __name__ == '__main__': 164 | G = [5.0, 7.251812, 5.0, 7.2676063, 5.0, 8.2218485, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.7212462, 5.0, 5.0, 5.0, 5.0, 5.0, 6.4436975, 5.0, 5.0, 5.60206, 5.0, 5.0, 5.1426673, 5.0, 5.0, 6.387216, 5.0, 5.0, 5.0, 6.251812, 5.0, 5.0, 5.0, 5.0, 5.0, 6.958607, 5.0, 5.0, 5.0, 5.0, 7.1739254, 5.0, 5.0, 5.0, 6.207608, 5.0, 5.5850267, 5.0, 6.481486, 5.0, 6.455932, 5.0, 5.0, 6.853872, 5.7212462, 5.0, 5.6575775, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.29243, 5.6382723, 5.0, 5.0, 5.0, 5.0, 5.0, 5.4317985, 5.0, 6.6777806, 5.0, 5.0, 5.0, 5.0, 5.5086384, 5.0, 5.0, 5.4436975, 5.0, 5.0, 5.6777806, 5.0, 5.075721, 5.0, 5.0, 8.327902, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0] 165 | P = [5.022873, 7.0781856, 4.9978094, 6.7880363, 5.0082135, 8.301622, 5.199977, 5.031757, 5.282739, 5.1505866, 5.0371256, 5.0158253, 7.235809, 5.0488424, 5.0158954, 5.014982, 5.0353045, 5.0385847, 6.210839, 5.0246162, 5.040341, 5.9972534, 5.022253, 5.024069, 5.0325136, 5.858346, 5.1466026, 7.353938, 5.041976, 5.010902, 5.0101852, 5.7545958, 5.0263815, 5.0000725, 4.985109, 5.055313, 5.0001907, 6.8203254, 5.0954485, 5.1212735, 5.0224247, 5.0497823, 6.8255396, 5.0044026, 4.9908457, 5.0110598, 6.855809, 5.297818, 6.2044125, 5.0267057, 6.1194935, 5.005172, 5.6843953, 5.0014734, 5.0232143, 7.3333316, 5.8368444, 5.2844615, 5.8721313, 5.040511, 5.057362, 5.0058765, 5.018214, 5.0278683, 4.995488, 6.170251, 5.2143936, 5.0082054, 5.0141716, 5.560684, 5.0162783, 5.022541, 5.4540567, 5.023486, 5.0640993, 4.9965744, 5.0399494, 5.0136223, 5.1999803, 6.3908367, 5.022854, 5.0350113, 5.002722, 5.0313835, 5.175599, 5.1362724, 5.137325, 5.6480265, 5.03323, 5.054763, 8.333924, 5.0164843, 5.2512374, 5.02013, 5.023677, 5.0309353, 5.031672, 6.3660593, 5.035504, 5.0222054] 166 | # print('regression result:', evaluate_reg(np.array(G), np.array(P))) 167 | # print('cls result:',evaluate_cls(np.array(G) >= 6, np.array(P),6)) 168 | 169 | t = np.array([0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 2, 0, 2, 2, 2, 2]) 170 | p = np.array([[ 0.8746, -0.9049, 0.4639], 171 | [ 0.8708, -0.8453, 0.4591], 172 | [ 0.8843, -0.9211, 0.5301], 173 | [ 0.7957, -0.9277, 0.5806], 174 | [ 0.8791, -0.8414, 0.4515], 175 | [-0.0475, -0.2103, 0.5898], 176 | [-0.0173, -0.0968, 0.4454], 177 | [ 0.8182, -1.0273, 0.6466], 178 | [ 0.9068, -0.9533, 0.5292], 179 | [-0.8911, 1.6034, -0.4516], 180 | [-1.0852, 1.1397, 0.3087], 181 | [-0.6816, 0.7026, 0.3080], 182 | [ 0.3641, -0.2075, 0.2326], 183 | [ 0.5106, -0.1516, 0.0526], 184 | [-0.2555, -0.6679, 1.3138], 185 | [ 0.0850, 0.0502, 0.1387], 186 | [-0.6787, 0.4709, 0.4442], 187 | [-0.5439, 0.4989, 0.2730], 188 | [-0.1568, -1.2820, 1.9857], 189 | [-0.0165, -1.2909, 1.7805]]) 190 | print(evaluate_mcls(t,p)) 191 | 192 | -------------------------------------------------------------------------------- /PSICHIC/psichic_utils/protein_init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import sys 4 | 5 | # Check if the code is running in a Jupyter notebook 6 | if 'ipykernel' in sys.modules: 7 | from tqdm.notebook import tqdm 8 | else: 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import esm 
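# NOTE: protein_init() below loads the pretrained ESM-2 650M model ("esm2_t33_650M_UR50D") via
# esm.pretrained.load_model_and_alphabet; the first call downloads the checkpoint (roughly 2.5 GB)
# into the local torch hub cache, so expect a one-off download and the corresponding disk usage.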
13 | from torch_geometric.utils import degree, add_self_loops, subgraph, to_undirected, remove_self_loops, coalesce 14 | 15 | import math 16 | 17 | def protein_init(seqs): 18 | result_dict = {} 19 | model_location = "esm2_t33_650M_UR50D" 20 | model, alphabet = esm.pretrained.load_model_and_alphabet(model_location) 21 | model.eval() 22 | if torch.cuda.is_available(): 23 | model = model.cuda() 24 | batch_converter = alphabet.get_batch_converter() 25 | 26 | for seq in tqdm(seqs): 27 | seq_feat = seq_feature(seq) 28 | token_repr, contact_map_proba, logits = esm_extract(model, batch_converter, seq, layer=33, approach='last',dim=1280) 29 | 30 | assert len(contact_map_proba) == len(seq) 31 | edge_index, edge_weight = contact_map(contact_map_proba) 32 | 33 | result_dict[seq] = { 34 | 'seq':seq, 35 | 'seq_feat':torch.from_numpy(seq_feat), 36 | 'token_representation':token_repr.half(), 37 | 'num_nodes': len(seq), 38 | 'num_pos':torch.arange(len(seq)).reshape(-1,1), 39 | 'edge_index': edge_index, 40 | 'edge_weight': edge_weight, 41 | } 42 | 43 | return result_dict 44 | 45 | 46 | 47 | # normalize 48 | def dic_normalize(dic): 49 | # print(dic) 50 | max_value = dic[max(dic, key=dic.get)] 51 | min_value = dic[min(dic, key=dic.get)] 52 | # print(max_value) 53 | interval = float(max_value) - float(min_value) 54 | for key in dic.keys(): 55 | dic[key] = (dic[key] - min_value) / interval 56 | dic['X'] = (max_value + min_value) / 2.0 57 | return dic 58 | 59 | 60 | pro_res_table = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 61 | 'X'] 62 | 63 | pro_res_aliphatic_table = ['A', 'I', 'L', 'M', 'V'] 64 | pro_res_aromatic_table = ['F', 'W', 'Y'] 65 | pro_res_polar_neutral_table = ['C', 'N', 'Q', 'S', 'T'] 66 | pro_res_acidic_charged_table = ['D', 'E'] 67 | pro_res_basic_charged_table = ['H', 'K', 'R'] 68 | 69 | res_weight_table = {'A': 71.08, 'C': 103.15, 'D': 115.09, 'E': 129.12, 'F': 147.18, 'G': 57.05, 'H': 137.14, 70 | 'I': 113.16, 'K': 128.18, 'L': 113.16, 'M': 131.20, 'N': 114.11, 'P': 97.12, 'Q': 128.13, 71 | 'R': 156.19, 'S': 87.08, 'T': 101.11, 'V': 99.13, 'W': 186.22, 'Y': 163.18} 72 | 73 | res_pka_table = {'A': 2.34, 'C': 1.96, 'D': 1.88, 'E': 2.19, 'F': 1.83, 'G': 2.34, 'H': 1.82, 'I': 2.36, 74 | 'K': 2.18, 'L': 2.36, 'M': 2.28, 'N': 2.02, 'P': 1.99, 'Q': 2.17, 'R': 2.17, 'S': 2.21, 75 | 'T': 2.09, 'V': 2.32, 'W': 2.83, 'Y': 2.32} 76 | 77 | res_pkb_table = {'A': 9.69, 'C': 10.28, 'D': 9.60, 'E': 9.67, 'F': 9.13, 'G': 9.60, 'H': 9.17, 78 | 'I': 9.60, 'K': 8.95, 'L': 9.60, 'M': 9.21, 'N': 8.80, 'P': 10.60, 'Q': 9.13, 79 | 'R': 9.04, 'S': 9.15, 'T': 9.10, 'V': 9.62, 'W': 9.39, 'Y': 9.62} 80 | 81 | res_pkx_table = {'A': 0.00, 'C': 8.18, 'D': 3.65, 'E': 4.25, 'F': 0.00, 'G': 0, 'H': 6.00, 82 | 'I': 0.00, 'K': 10.53, 'L': 0.00, 'M': 0.00, 'N': 0.00, 'P': 0.00, 'Q': 0.00, 83 | 'R': 12.48, 'S': 0.00, 'T': 0.00, 'V': 0.00, 'W': 0.00, 'Y': 0.00} 84 | 85 | res_pl_table = {'A': 6.00, 'C': 5.07, 'D': 2.77, 'E': 3.22, 'F': 5.48, 'G': 5.97, 'H': 7.59, 86 | 'I': 6.02, 'K': 9.74, 'L': 5.98, 'M': 5.74, 'N': 5.41, 'P': 6.30, 'Q': 5.65, 87 | 'R': 10.76, 'S': 5.68, 'T': 5.60, 'V': 5.96, 'W': 5.89, 'Y': 5.96} 88 | 89 | res_hydrophobic_ph2_table = {'A': 47, 'C': 52, 'D': -18, 'E': 8, 'F': 92, 'G': 0, 'H': -42, 'I': 100, 90 | 'K': -37, 'L': 100, 'M': 74, 'N': -41, 'P': -46, 'Q': -18, 'R': -26, 'S': -7, 91 | 'T': 13, 'V': 79, 'W': 84, 'Y': 49} 92 | 93 | res_hydrophobic_ph7_table = {'A': 41, 'C': 49, 'D': -55, 'E': -31, 'F': 100, 'G': 0, 'H': 8, 'I': 99, 94 | 'K': -23, 'L': 
97, 'M': 74, 'N': -28, 'P': -46, 'Q': -10, 'R': -14, 'S': -5, 95 | 'T': 13, 'V': 76, 'W': 97, 'Y': 63} 96 | 97 | res_weight_table = dic_normalize(res_weight_table) 98 | res_pka_table = dic_normalize(res_pka_table) 99 | res_pkb_table = dic_normalize(res_pkb_table) 100 | res_pkx_table = dic_normalize(res_pkx_table) 101 | res_pl_table = dic_normalize(res_pl_table) 102 | res_hydrophobic_ph2_table = dic_normalize(res_hydrophobic_ph2_table) 103 | res_hydrophobic_ph7_table = dic_normalize(res_hydrophobic_ph7_table) 104 | 105 | 106 | def residue_features(residue): 107 | res_property1 = [1 if residue in pro_res_aliphatic_table else 0, 1 if residue in pro_res_aromatic_table else 0, 108 | 1 if residue in pro_res_polar_neutral_table else 0, 109 | 1 if residue in pro_res_acidic_charged_table else 0, 110 | 1 if residue in pro_res_basic_charged_table else 0] 111 | res_property2 = [res_weight_table[residue], res_pka_table[residue], res_pkb_table[residue], res_pkx_table[residue], 112 | res_pl_table[residue], res_hydrophobic_ph2_table[residue], res_hydrophobic_ph7_table[residue]] 113 | # print(np.array(res_property1 + res_property2).shape) 114 | return np.array(res_property1 + res_property2) 115 | 116 | 117 | # one ont encoding 118 | def one_of_k_encoding(x, allowable_set): 119 | if x not in allowable_set: 120 | # print(x) 121 | raise Exception('input {0} not in allowable set{1}:'.format(x, allowable_set)) 122 | return list(map(lambda s: x == s, allowable_set)) 123 | 124 | 125 | def one_of_k_encoding_unk(x, allowable_set): 126 | '''Maps inputs not in the allowable set to the last element.''' 127 | if x not in allowable_set: 128 | x = allowable_set[-1] 129 | return list(map(lambda s: x == s, allowable_set)) 130 | 131 | 132 | 133 | def seq_feature(pro_seq): 134 | if 'U' in pro_seq or 'B' in pro_seq: 135 | print('U or B in Sequence') 136 | pro_seq = pro_seq.replace('U','X').replace('B','X') 137 | pro_hot = np.zeros((len(pro_seq), len(pro_res_table))) 138 | pro_property = np.zeros((len(pro_seq), 12)) 139 | for i in range(len(pro_seq)): 140 | # if 'X' in pro_seq: 141 | # print(pro_seq) 142 | pro_hot[i,] = one_of_k_encoding(pro_seq[i], pro_res_table) 143 | pro_property[i,] = residue_features(pro_seq[i]) 144 | return np.concatenate((pro_hot, pro_property), axis=1) 145 | 146 | def contact_map(contact_map_proba, contact_threshold=0.5): 147 | num_residues = contact_map_proba.shape[0] 148 | prot_contact_adj = (contact_map_proba >= contact_threshold).long() 149 | edge_index = prot_contact_adj.nonzero(as_tuple=False).t().contiguous() 150 | row, col = edge_index 151 | edge_weight = contact_map_proba[row, col].float() 152 | ############### CONNECT ISOLATED NODES - Prevent Disconnected Residues ###################### 153 | seq_edge_head1 = torch.stack([torch.arange(num_residues)[:-1],(torch.arange(num_residues)+1)[:-1]]) 154 | seq_edge_tail1 = torch.stack([(torch.arange(num_residues))[1:],(torch.arange(num_residues)-1)[1:]]) 155 | seq_edge_weight1 = torch.ones(seq_edge_head1.size(1) + seq_edge_tail1.size(1)) * contact_threshold 156 | edge_index = torch.cat([edge_index, seq_edge_head1, seq_edge_tail1],dim=-1) 157 | edge_weight = torch.cat([edge_weight, seq_edge_weight1],dim=-1) 158 | 159 | seq_edge_head2 = torch.stack([torch.arange(num_residues)[:-2],(torch.arange(num_residues)+2)[:-2]]) 160 | seq_edge_tail2 = torch.stack([(torch.arange(num_residues))[2:],(torch.arange(num_residues)-2)[2:]]) 161 | seq_edge_weight2 = torch.ones(seq_edge_head2.size(1) + seq_edge_tail2.size(1)) *contact_threshold 162 | edge_index = 
torch.cat([edge_index, seq_edge_head2, seq_edge_tail2],dim=-1) 163 | edge_weight = torch.cat([edge_weight, seq_edge_weight2],dim=-1) 164 | ############### CONNECT ISOLATED NODES - Prevent Disconnected Residues ###################### 165 | 166 | edge_index, edge_weight = coalesce(edge_index, edge_weight, reduce='max') 167 | edge_index, edge_weight = to_undirected(edge_index, edge_weight, reduce='max') 168 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 169 | edge_index, edge_weight = add_self_loops(edge_index, edge_weight,fill_value=1) 170 | 171 | return edge_index, edge_weight 172 | 173 | 174 | 175 | def esm_extract(model, batch_converter, seq, layer=36, approach='mean',dim=2560): 176 | pro_id = 'A' 177 | if len(seq) <= 700: 178 | data = [] 179 | data.append((pro_id, seq)) 180 | batch_labels, batch_strs, batch_tokens = batch_converter(data) 181 | batch_tokens = batch_tokens.to(next(model.parameters()).device, non_blocking=True) 182 | 183 | with torch.no_grad(): 184 | results = model(batch_tokens, repr_layers=[i for i in range(1, layer + 1)], return_contacts=True) 185 | 186 | logits = results["logits"][0].cpu().numpy()[1: len(seq) + 1] 187 | contact_prob_map = results["contacts"][0].cpu().numpy() 188 | token_representation = torch.cat([results['representations'][i] for i in range(1, layer + 1)]) 189 | assert token_representation.size(0) == layer 190 | 191 | if approach == 'last': 192 | token_representation = token_representation[-1] 193 | elif approach == 'sum': 194 | token_representation = token_representation.sum(dim=0) 195 | elif approach == 'mean': 196 | token_representation = token_representation.mean(dim=0) 197 | 198 | token_representation = token_representation.cpu().numpy() 199 | token_representation = token_representation[1: len(seq) + 1] 200 | else: 201 | contact_prob_map = np.zeros((len(seq), len(seq))) # global contact map prediction 202 | token_representation = np.zeros((len(seq), dim)) 203 | logits = np.zeros((len(seq),layer)) 204 | interval = 350 205 | i = math.ceil(len(seq) / interval) 206 | # ====================== 207 | # = = 208 | # = = 209 | # = ====================== 210 | # = =*********= = 211 | # = =*********= = 212 | # ====================== = 213 | # = = 214 | # = = 215 | # ====================== 216 | # where * is the overlapping area 217 | # subsection seq contact map prediction 218 | for s in range(i): 219 | start = s * interval # sub seq predict start 220 | end = min((s + 2) * interval, len(seq)) # sub seq predict end 221 | sub_seq_len = end - start 222 | 223 | # prediction 224 | temp_seq = seq[start:end] 225 | temp_data = [] 226 | temp_data.append((pro_id, temp_seq)) 227 | batch_labels, batch_strs, batch_tokens = batch_converter(temp_data) 228 | batch_tokens = batch_tokens.to(next(model.parameters()).device, non_blocking=True) 229 | with torch.no_grad(): 230 | results = model(batch_tokens, repr_layers=[i for i in range(1, layer + 1)], return_contacts=True) 231 | 232 | # insert into the global contact map 233 | row, col = np.where(contact_prob_map[start:end, start:end] != 0) 234 | row = row + start 235 | col = col + start 236 | contact_prob_map[start:end, start:end] = contact_prob_map[start:end, start:end] + results["contacts"][ 237 | 0].cpu().numpy() 238 | contact_prob_map[row, col] = contact_prob_map[row, col] / 2.0 239 | 240 | logits[start:end] += results['logits'][0].cpu().numpy()[1: len(temp_seq) + 1] 241 | logits[row] = logits[row]/2.0 242 | 243 | ## TOKEN 244 | subtoken_repr = torch.cat([results['representations'][i] for i in 
range(1, layer + 1)]) 245 | assert subtoken_repr.size(0) == layer 246 | if approach == 'last': 247 | subtoken_repr = subtoken_repr[-1] 248 | elif approach == 'sum': 249 | subtoken_repr = subtoken_repr.sum(dim=0) 250 | elif approach == 'mean': 251 | subtoken_repr = subtoken_repr.mean(dim=0) 252 | 253 | subtoken_repr = subtoken_repr.cpu().numpy() 254 | subtoken_repr = subtoken_repr[1: len(temp_seq) + 1] 255 | 256 | trow = np.where(token_representation[start:end].sum(axis=-1) != 0)[0] 257 | trow = trow + start 258 | token_representation[start:end] = token_representation[start:end] + subtoken_repr 259 | token_representation[trow] = token_representation[trow] / 2.0 260 | 261 | if end == len(seq): 262 | break 263 | 264 | return torch.from_numpy(token_representation), torch.from_numpy(contact_prob_map), torch.from_numpy(logits) 265 | 266 | 267 | 268 | 269 | def generate_ESM_structure(model, filename, sequence): 270 | model.set_chunk_size(256) 271 | chunk_size = 256 272 | output = None 273 | 274 | while output is None: 275 | try: 276 | with torch.no_grad(): 277 | output = model.infer_pdb(sequence) 278 | 279 | with open(filename, "w") as f: 280 | f.write(output) 281 | print("saved", filename) 282 | except RuntimeError as e: 283 | if 'out of memory' in str(e): 284 | print('| WARNING: ran out of memory on chunk_size', chunk_size) 285 | for p in model.parameters(): 286 | if p.grad is not None: 287 | del p.grad # free some memory 288 | torch.cuda.empty_cache() 289 | chunk_size = chunk_size // 2 290 | if chunk_size > 2: 291 | model.set_chunk_size(chunk_size) 292 | else: 293 | print("Not enough memory for ESMFold") 294 | break 295 | else: 296 | raise e 297 | return output is not None 298 | 299 | 300 | from Bio.PDB import PDBParser 301 | biopython_parser = PDBParser() 302 | 303 | one_to_three = {"A" : "ALA", 304 | "C" : "CYS", 305 | "D" : "ASP", 306 | "E" : "GLU", 307 | "F" : "PHE", 308 | "G" : "GLY", 309 | "H" : "HIS", 310 | "I" : "ILE", 311 | "K" : "LYS", 312 | "L" : "LEU", 313 | "M" : "MET", 314 | "N" : "ASN", 315 | "P" : "PRO", 316 | "Q" : "GLN", 317 | "R" : "ARG", 318 | "S" : "SER", 319 | "T" : "THR", 320 | "V" : "VAL", 321 | "W" : "TRP", 322 | "Y" : "TYR", 323 | "B" : "ASX", 324 | "Z" : "GLX", 325 | "X" : "UNK", 326 | "*" : " * "} 327 | 328 | three_to_one = {} 329 | for _key, _value in one_to_three.items(): 330 | three_to_one[_value] = _key 331 | three_to_one["SEC"] = "C" 332 | three_to_one["MSE"] = "M" 333 | 334 | 335 | def extract_pdb_seq(protein_path): 336 | 337 | structure = biopython_parser.get_structure('random_id', protein_path)[0] 338 | seq = '' 339 | chain_str = '' 340 | for i, chain in enumerate(structure): 341 | for res_idx, residue in enumerate(chain): 342 | if residue.get_resname() == 'HOH': 343 | continue 344 | residue_coords = [] 345 | c_alpha, n, c = None, None, None 346 | for atom in residue: 347 | if atom.name == 'CA': 348 | c_alpha = list(atom.get_vector()) 349 | if atom.name == 'N': 350 | n = list(atom.get_vector()) 351 | if atom.name == 'C': 352 | c = list(atom.get_vector()) 353 | if c_alpha != None and n != None and c != None: # only append residue if it is an amino acid and not 354 | try: 355 | seq += three_to_one[residue.get_resname()] 356 | chain_str += str(chain.id) 357 | except Exception as e: 358 | seq += 'X' 359 | chain_str += str(chain.id) 360 | print("encountered unknown AA: ", residue.get_resname(), ' in the complex. 
Replacing it with a dash X.') 361 | 362 | return seq, chain_str -------------------------------------------------------------------------------- /PSICHIC/runtime_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from dotenv import load_dotenv 4 | 5 | load_dotenv() 6 | 7 | class RuntimeConfig: 8 | PSICHIC_PATH = os.path.dirname(os.path.abspath(__file__)) 9 | device = os.environ.get("DEVICE_OVERRIDE") 10 | DEVICE = ["cpu" if device=="cpu" else "cuda:0"][0] 11 | MODEL_PATH = os.path.join(PSICHIC_PATH, 'trained_weights', 'TREAT1') 12 | BATCH_SIZE = 128 13 | 14 | -------------------------------------------------------------------------------- /PSICHIC/trained_weights/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/trained_weights/.DS_Store -------------------------------------------------------------------------------- /PSICHIC/trained_weights/PDBv2020_PSICHIC/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "optimizer": { 3 | "lrate": 0.0001, 4 | "weight_decay": 0.0001, 5 | "clip": 1, 6 | "betas": [ 7 | 0.9, 8 | 0.999 9 | ], 10 | "eps": 1e-08, 11 | "schedule_lr": false, 12 | "min_lrate": 0, 13 | "warmup_iters": 0, 14 | "lr_decay_iters": 0, 15 | "amsgrad": false 16 | }, 17 | "params": { 18 | "mol_in_channels": 43, 19 | "prot_in_channels": 33, 20 | "prot_evo_channels": 1280, 21 | "hidden_channels": 200, 22 | "aggregators": [ 23 | "mean", 24 | "min", 25 | "max", 26 | "std" 27 | ], 28 | "scalers": [ 29 | "identity", 30 | "amplification", 31 | "linear" 32 | ], 33 | "pre_layers": 2, 34 | "post_layers": 1, 35 | "total_layer": 3, 36 | "K": [ 37 | 5, 38 | 10, 39 | 20 40 | ], 41 | "dropout": 0, 42 | "dropout_attn_score": 0.2, 43 | "heads": 5 44 | }, 45 | "tasks": { 46 | "regression_task": true, 47 | "classification_task": false, 48 | "mclassification_task": 0 49 | } 50 | } -------------------------------------------------------------------------------- /PSICHIC/trained_weights/PDBv2020_PSICHIC/degree copy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/trained_weights/PDBv2020_PSICHIC/degree copy.pt -------------------------------------------------------------------------------- /PSICHIC/trained_weights/PDBv2020_PSICHIC/degree.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/trained_weights/PDBv2020_PSICHIC/degree.pt -------------------------------------------------------------------------------- /PSICHIC/trained_weights/PDBv2020_PSICHIC/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/trained_weights/PDBv2020_PSICHIC/model.pt -------------------------------------------------------------------------------- /PSICHIC/trained_weights/TREAT1/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "tasks": { 3 | "regression_task": true, 4 | "classification_task": null, 5 | "mclassification_task": null 6 | }, 7 | "optimizer": { 8 | "lrate": 1e-05, 9 | 
"weight_decay": 0.0001, 10 | "clip": 1, 11 | "betas": [ 12 | 0.9, 13 | 0.999 14 | ], 15 | "eps": 1e-08, 16 | "schedule_lr": false, 17 | "min_lrate": 0, 18 | "warmup_iters": 0, 19 | "lr_decay_iters": 0, 20 | "amsgrad": false 21 | }, 22 | "params": { 23 | "mol_in_channels": 43, 24 | "prot_in_channels": 33, 25 | "prot_evo_channels": 1280, 26 | "hidden_channels": 200, 27 | "aggregators": [ 28 | "mean", 29 | "min", 30 | "max", 31 | "std" 32 | ], 33 | "scalers": [ 34 | "identity", 35 | "amplification", 36 | "linear" 37 | ], 38 | "pre_layers": 2, 39 | "post_layers": 1, 40 | "total_layer": 3, 41 | "K": [ 42 | 5, 43 | 10, 44 | 20 45 | ], 46 | "dropout": 0, 47 | "dropout_attn_score": 0.2, 48 | "heads": 5 49 | } 50 | } -------------------------------------------------------------------------------- /PSICHIC/trained_weights/TREAT1/degree.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/trained_weights/TREAT1/degree.pt -------------------------------------------------------------------------------- /PSICHIC/trained_weights/TREAT1/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/trained_weights/TREAT1/model.pt -------------------------------------------------------------------------------- /PSICHIC/trained_weights/multitask_PSICHIC/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "tasks": { 3 | "regression_task": true, 4 | "classification_task": false, 5 | "mclassification_task": 3 6 | }, 7 | "optimizer": { 8 | "lrate": 0.0001, 9 | "weight_decay": 0.0001, 10 | "clip": 1, 11 | "betas": [ 12 | 0.9, 13 | 0.999 14 | ], 15 | "eps": 1e-08, 16 | "schedule_lr": false, 17 | "min_lrate": 0, 18 | "warmup_iters": 0, 19 | "lr_decay_iters": 0, 20 | "amsgrad": false 21 | }, 22 | "params": { 23 | "mol_in_channels": 43, 24 | "prot_in_channels": 33, 25 | "prot_evo_channels": 1280, 26 | "hidden_channels": 200, 27 | "aggregators": [ 28 | "mean", 29 | "min", 30 | "max", 31 | "std" 32 | ], 33 | "scalers": [ 34 | "identity", 35 | "amplification", 36 | "linear" 37 | ], 38 | "pre_layers": 2, 39 | "post_layers": 1, 40 | "total_layer": 3, 41 | "K": [ 42 | 5, 43 | 10, 44 | 20 45 | ], 46 | "dropout": 0, 47 | "dropout_attn_score": 0.2, 48 | "heads": 5 49 | } 50 | } -------------------------------------------------------------------------------- /PSICHIC/trained_weights/multitask_PSICHIC/degree.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/trained_weights/multitask_PSICHIC/degree.pt -------------------------------------------------------------------------------- /PSICHIC/trained_weights/multitask_PSICHIC/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metanova-labs/nova/bd3dbf28c45a8b089424939a1c39c90cc291dcd0/PSICHIC/trained_weights/multitask_PSICHIC/model.pt -------------------------------------------------------------------------------- /PSICHIC/wrapper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | import os 5 | import pandas as pd 6 | import torch 7 | 8 | from .psichic_utils.dataset import 
ProteinMoleculeDataset 9 | from .psichic_utils.data_utils import DataLoader, virtual_screening 10 | from .psichic_utils import protein_init, ligand_init 11 | from .models.net import net 12 | 13 | from .runtime_config import RuntimeConfig 14 | 15 | class PsichicWrapper: 16 | def __init__(self): 17 | self.runtime_config = RuntimeConfig() 18 | self.device = self.runtime_config.DEVICE 19 | 20 | with open(os.path.join(self.runtime_config.MODEL_PATH, 'config.json'), 'r') as f: 21 | self.model_config = json.load(f) 22 | 23 | def load_model(self): 24 | degree_dict = torch.load(os.path.join(self.runtime_config.MODEL_PATH, 25 | 'degree.pt'), 26 | weights_only=True 27 | ) 28 | param_dict = os.path.join(self.runtime_config.MODEL_PATH, 'model.pt') 29 | mol_deg, prot_deg = degree_dict['ligand_deg'], degree_dict['protein_deg'] 30 | 31 | self.model = net(mol_deg, prot_deg, 32 | # MOLECULE 33 | mol_in_channels=self.model_config['params']['mol_in_channels'], 34 | prot_in_channels=self.model_config['params']['prot_in_channels'], 35 | prot_evo_channels=self.model_config['params']['prot_evo_channels'], 36 | hidden_channels=self.model_config['params']['hidden_channels'], 37 | pre_layers=self.model_config['params']['pre_layers'], 38 | post_layers=self.model_config['params']['post_layers'], 39 | aggregators=self.model_config['params']['aggregators'], 40 | scalers=self.model_config['params']['scalers'], 41 | total_layer=self.model_config['params']['total_layer'], 42 | K=self.model_config['params']['K'], 43 | heads=self.model_config['params']['heads'], 44 | dropout=self.model_config['params']['dropout'], 45 | dropout_attn_score=self.model_config['params']['dropout_attn_score'], 46 | # output 47 | regression_head=self.model_config['tasks']['regression_task'], 48 | classification_head=self.model_config['tasks']['classification_task'] , 49 | multiclassification_head=self.model_config['tasks']['mclassification_task'], 50 | device=self.device).to(self.device) 51 | self.model.reset_parameters() 52 | self.model.load_state_dict(torch.load(param_dict, 53 | map_location=self.device, 54 | weights_only=True 55 | ) 56 | ) 57 | 58 | def initialize_protein(self, protein_seq:str) -> dict: 59 | allowed_chars = set(['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'X']) 60 | sanitized_protein_seq = ''.join([aa if aa in allowed_chars else 'X' for aa in protein_seq]) 61 | self.protein_seq = [sanitized_protein_seq] 62 | 63 | protein_dict = protein_init(self.protein_seq) 64 | return protein_dict 65 | 66 | def initialize_smiles(self, smiles_list:list) -> dict: 67 | self.smiles_list = smiles_list 68 | smiles_dict = ligand_init(smiles_list) 69 | return smiles_dict 70 | 71 | def create_screen_loader(self, protein_dict, smiles_dict): 72 | self.screen_df = pd.DataFrame({'Protein': [k for k in self.protein_seq for _ in self.smiles_list], 73 | 'Ligand': [l for l in self.smiles_list for _ in self.protein_seq], 74 | }) 75 | 76 | dataset = ProteinMoleculeDataset(self.screen_df, 77 | smiles_dict, 78 | protein_dict, 79 | device=self.device 80 | ) 81 | 82 | self.screen_loader = DataLoader(dataset, 83 | batch_size=self.runtime_config.BATCH_SIZE, 84 | shuffle=False, 85 | follow_batch=['mol_x', 'clique_x', 'prot_node_aa'] 86 | ) 87 | 88 | def run_challenge_start(self, protein_seq:str): 89 | torch.cuda.empty_cache() 90 | self.load_model() 91 | self.protein_dict = self.initialize_protein(protein_seq) 92 | 93 | def run_validation(self, smiles_list:list) -> pd.DataFrame: 94 | self.smiles_dict = 
self.initialize_smiles(smiles_list) 95 | torch.cuda.empty_cache() 96 | self.create_screen_loader(self.protein_dict, self.smiles_dict) 97 | self.screen_df = virtual_screening(self.screen_df, 98 | self.model, 99 | self.screen_loader, 100 | os.getcwd(), 101 | save_interpret=False, 102 | ligand_dict=self.smiles_dict, 103 | device=self.device, 104 | save_cluster=False, 105 | ) 106 | return self.screen_df 107 | 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NOVA - SN68 2 | 3 | ## High-throughput ML-driven drug screening. 4 | 5 | ### Accelerating drug discovery, powered by Bittensor. 6 | 7 | NOVA harnesses global compute and collective intelligence to navigate huge unexplored chemical spaces, uncovering breakthrough compounds at a fraction of the cost and time. 8 | 9 | ## System Requirements 10 | 11 | - Ubuntu 24.04 LTS (recommended) 12 | - Python 3.12 13 | - CUDA 12.4 (for GPU support) 14 | - Sufficient RAM for ML model operations 15 | - Internet connection for network participation 16 | 17 | ## Installation and Running 18 | 19 | 1. Clone the repository: 20 | ```bash 21 | git clone 22 | cd nova 23 | ``` 24 | 25 | 2. Prepare your .env file as in example.env: 26 | ``` 27 | # General configs 28 | SUBTENSOR_NETWORK="ws://localhost:9944" # or your chosen node 29 | DEVICE_OVERRIDE="cpu" # None to run on GPU 30 | 31 | # Github configs - FOR MINERS 32 | GITHUB_REPO_NAME="repo-name" 33 | GITHUB_REPO_BRANCH="repo-branch" 34 | GITHUB_TOKEN="your_token" 35 | GITHUB_REPO_OWNER="repo-owner" 36 | GITHUB_REPO_PATH="" # path within repo or "" 37 | 38 | # For validators 39 | VALIDATOR_API_KEY="your_api_key" 40 | ``` 41 | 42 | 3. Install dependencies: 43 | - For CPU: 44 | ```bash 45 | ./install_deps_cpu.sh 46 | ``` 47 | - For CUDA 12.4: 48 | ```bash 49 | ./install_deps_cu124.sh 50 | ``` 51 | 52 | 4. Run: 53 | ```bash 54 | # Activate your virtual environment: 55 | source .venv/bin/activate 56 | 57 | # Run your script: 58 | # miner: 59 | python3 neurons/miner.py --wallet.name --wallet.hotkey --logging.info 60 | 61 | # validator: 62 | python3 neurons/validator.py --wallet.name --wallet.hotkey --logging.debug 63 | ``` 64 | 65 | ## Configuration 66 | 67 | The project uses several configuration files: 68 | - `.env`: Environment variables and API keys 69 | - `requirements/`: Dependency specifications for different environments 70 | - Command-line arguments for runtime configuration 71 | - `PSICHIC/runtime_config.py`: runtime configurations for PSICHIC model 72 | 73 | 74 | ## For Validators 75 | 76 | DM the NOVA team to obtain an API key. 77 | 78 | 79 | ## Support 80 | 81 | For support, please open an issue in the repository or contact the NOVA team. 82 | -------------------------------------------------------------------------------- /auto_updater.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import subprocess 4 | import sys 5 | import time 6 | 7 | class AutoUpdater: 8 | """Auto-updater that pulls from metanova-labs/nova main branch and restarts.""" 9 | 10 | UPDATE_INTERVAL = 3600 # Check every hour 11 | REMOTE_URL = "https://github.com/metanova-labs/nova.git" 12 | BRANCH = "main" 13 | REPO_PATH = "." # Current directory 14 | 15 | def __init__(self, logger): 16 | """Initialize with logger. 
Uses current directory as repo path.""" 17 | self.logger = logger 18 | self._setup_remote() 19 | 20 | def _setup_remote(self): 21 | """Ensure remote URL is correct.""" 22 | self.logger.info(f"Setting up remote URL: {self.REMOTE_URL}") 23 | 24 | returncode, stdout, stderr = self._run_git_command('remote', '-v') 25 | if 'origin' in stdout: 26 | returncode, stdout, stderr = self._run_git_command('remote', 'set-url', 'origin', self.REMOTE_URL) 27 | else: 28 | returncode, stdout, stderr = self._run_git_command('remote', 'add', 'origin', self.REMOTE_URL) 29 | 30 | if returncode != 0: 31 | self.logger.error(f"Failed to set up remote URL: {stderr}") 32 | 33 | def _run_git_command(self, *args): 34 | """Run git command and return results.""" 35 | cmd = ['git'] + list(args) 36 | process = subprocess.run( 37 | cmd, 38 | cwd=self.REPO_PATH, 39 | capture_output=True, 40 | text=True 41 | ) 42 | return process.returncode, process.stdout.strip(), process.stderr.strip() 43 | 44 | def _reset_local_changes(self): 45 | """Reset local changes to HEAD.""" 46 | self.logger.info("Resetting local changes before update") 47 | returncode, stdout, stderr = self._run_git_command('reset', '--hard', 'HEAD') 48 | if returncode != 0: 49 | self.logger.error(f"Failed to reset changes: {stderr}") 50 | return False 51 | return True 52 | 53 | def _check_for_updates(self): 54 | """Check if updates are available.""" 55 | returncode, stdout, stderr = self._run_git_command('fetch', 'origin', self.BRANCH) 56 | if returncode != 0: 57 | self.logger.error(f"Failed to fetch updates: {stderr}") 58 | return False 59 | 60 | returncode, stdout, stderr = self._run_git_command('diff', f'HEAD..origin/{self.BRANCH}') 61 | if returncode != 0: 62 | self.logger.error(f"Failed to check if updates are available: {stderr}") 63 | return False 64 | 65 | return bool(stdout.strip()) 66 | 67 | def _pull_updates(self): 68 | """Pull updates from remote branch.""" 69 | self.logger.info(f"Pulling updates from origin/{self.BRANCH}") 70 | returncode, stdout, stderr = self._run_git_command('pull', 'origin', self.BRANCH) 71 | if returncode != 0: 72 | self.logger.error(f"Failed to pull updates: {stderr}") 73 | return False 74 | return True 75 | 76 | def _restart_process(self): 77 | """Restart the process with same arguments.""" 78 | self.logger.info(f"Restarting process with command: {' '.join(sys.argv)}") 79 | try: 80 | subprocess.Popen([sys.executable] + sys.argv) 81 | time.sleep(1) 82 | os._exit(0) 83 | except Exception as e: 84 | self.logger.error(f"Failed to restart process: {e}") 85 | 86 | async def start_update_loop(self): 87 | """Run update loop checking for and applying updates.""" 88 | while True: 89 | try: 90 | self.logger.info(f"Checking for updates from {self.REMOTE_URL} ({self.BRANCH} branch)") 91 | 92 | if not self._reset_local_changes(): 93 | await asyncio.sleep(self.UPDATE_INTERVAL) 94 | continue 95 | 96 | if self._check_for_updates(): 97 | self.logger.info("Updates available, pulling changes") 98 | 99 | if self._pull_updates(): 100 | self.logger.info("Updates successfully applied, restarting") 101 | self._restart_process() 102 | self.logger.error("Failed to restart after update") 103 | else: 104 | self.logger.info("No updates available") 105 | 106 | except Exception as e: 107 | self.logger.error(f"Error in update loop: {e}") 108 | 109 | self.logger.info(f"Next update check in {self.UPDATE_INTERVAL} seconds") 110 | await asyncio.sleep(self.UPDATE_INTERVAL) -------------------------------------------------------------------------------- /btdr.py: 
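The module below implements Drand-based timelock encryption and decryption. As a rough usage sketch (not part of the original file; the import path, block number, and message are illustrative assumptions, and decryption only succeeds once the target round's randomness has been published):

from btdr import QuicknetBittensorDrandTimelock

bdt = QuicknetBittensorDrandTimelock()
# Seal a payload so it can only be opened near the end of the current epoch.
target_round, ciphertext = bdt.encrypt(uid=1, message="<encoded submission>", current_block=4_000_000)
# Later, once the target Drand round has passed, anyone can recover the plaintext.
decrypted = bdt.decrypt_dict({1: (target_round, ciphertext)})
# decrypted[1] holds the original message, or None if the UID prefix does not match
# or the round is not yet available.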
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # credits: Rhef 3 | from concurrent.futures import ThreadPoolExecutor 4 | from typing import Dict, Tuple, Optional 5 | import asyncio 6 | import base64 7 | import hashlib 8 | import logging 9 | import secrets 10 | import time 11 | 12 | from cryptography.fernet import Fernet 13 | import requests 14 | import timelock 15 | import bittensor as bt 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class TooEarly(RuntimeError): 21 | pass 22 | 23 | 24 | class DrandClient: 25 | """Class for Drand-based timelock encryption and decryption.""" 26 | 27 | RETRY_LIMIT = 30 28 | RETRY_BACKOFF_S = 2 29 | 30 | def __init__(self, url): 31 | """Initialize a requests session for better performance.""" 32 | self.session: requests.Session = requests.Session() 33 | self.url = url 34 | 35 | def get(self, round_number: int, retry_if_too_early=False) -> str: 36 | """Fetch the randomness for a given round, using cache to prevent duplicate requests.""" 37 | a = 0 38 | while a <= self.RETRY_LIMIT: 39 | a += 1 40 | response: requests.Response = self.session.get(f"{self.url}/public/{round_number}") 41 | if response.status_code == 200: 42 | break 43 | elif response.status_code in (404, 425): 44 | bt.logging.debug(f"Randomness for round {round_number} is not yet available.") 45 | if not retry_if_too_early: 46 | try: 47 | response.raise_for_status() 48 | except Exception as e: 49 | raise TooEarly() from e 50 | elif response.status_code == 500: 51 | bt.logging.debug(f'{response.status_code} {response} {response.headers} {response.text}') 52 | time.sleep(self.RETRY_BACKOFF_S) 53 | continue 54 | response.raise_for_status() 55 | bt.logging.debug(f"Got randomness for round {round_number} successfully.") 56 | 57 | return response.json() 58 | 59 | 60 | class AbstractBittensorDrandTimelock: 61 | """Class for Drand-based timelock encryption and decryption using the timelock library.""" 62 | #DRAND_URL: str = "https://api.drand.sh" # more 500 than 200 63 | DRAND_URL: str = "https://drand.cloudflare.com" 64 | EPOCH_LENGTH = 361 # Number of blocks per epoch 65 | 66 | def __init__(self) -> None: 67 | """Initialize the Timelock client.""" 68 | self.tl = timelock.Timelock(self.PK_HEX) 69 | self.drand_client = DrandClient(f'{self.DRAND_URL}/{self.CHAIN}') 70 | 71 | def _get_drand_round_info(self, round_number: int, cache: Dict[int, str]): 72 | """Fetch the randomness for a given round, using a cache to prevent duplicate requests.""" 73 | if not (round_info := cache.get(round_number)): 74 | try: 75 | round_info = cache[round_number] = self.drand_client.get(round_number) 76 | except ValueError: 77 | raise RuntimeError(f"Randomness for round {round_number} is not yet available.") 78 | return round_info 79 | 80 | def _get_drand_signature(self, round_number: int, cache: Dict[int, str]) -> str: 81 | return bytearray.fromhex( 82 | self._get_drand_round_info(round_number, cache)['signature'] 83 | ) 84 | 85 | def get_current_round(self) -> int: 86 | return int(time.time()- self.NET_START) // self.ROUND_DURATION 87 | 88 | def encrypt(self, uid: int, message: str, current_block: int) -> Tuple[int, bytes]: 89 | """ 90 | Encrypt a message with a future Drand round key, prefixing it with the UID. 91 | The target round is calculated to be within the last 10 blocks of the competition. 
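        For example (illustrative numbers only): with EPOCH_LENGTH = 361 and current_block = 1000,
        the next epoch boundary is block 1083, the target block is 1073, and the target round is the
        current Drand round plus (1073 - 1000) * 4 = 292, since one ~12 s block spans four 3 s Quicknet rounds.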
92 | 93 | Args: 94 | uid: The user ID 95 | message: The message to encrypt 96 | current_block: The current block number 97 | 98 | Returns: 99 | A tuple of (target_round, encrypted_message) 100 | """ 101 | # Calculate the next epoch boundary 102 | next_epoch_boundary = ((current_block // self.EPOCH_LENGTH) + 1) * self.EPOCH_LENGTH 103 | # Target round should be 10 blocks before the epoch boundary 104 | target_block = next_epoch_boundary - 10 105 | 106 | # Convert block number to Drand round 107 | # Each block is roughly 12 seconds, and Drand round is 3 seconds 108 | target_round: int = self.get_current_round() + ((target_block - current_block) * 4) 109 | 110 | bt.logging.info(f"Encrypting message for UID {uid}... Unlockable at round {target_round} (block {target_block})") 111 | 112 | prefixed_message: str = f"{uid}:{message}" 113 | sk = secrets.token_bytes(32) # an ephemeral secret key 114 | ciphertext: bytes = self.tl.tle(target_round, prefixed_message, sk) 115 | 116 | return target_round, ciphertext 117 | 118 | def decrypt(self, uid: int, ciphertext: bytes, target_round: int, signature: Optional[str] = None) -> Optional[str]: 119 | """ 120 | Attempt to decrypt a single message, verifying the UID prefix. 121 | If the decrypted message doesn't start with the expected UID prefix, return None. 122 | """ 123 | if not signature: 124 | try: 125 | signature: bytes = self._get_drand_signature(target_round, {}) 126 | except RuntimeError as e: 127 | bt.logging.error(e) 128 | raise 129 | 130 | bt.logging.info(f"Decrypting message for UID {uid} at round {target_round}...") 131 | 132 | # key: bytes = self._derive_key(randomness) 133 | # cipher: Fernet = Fernet(key) 134 | # decrypted_message: str = cipher.decrypt(encrypted_message).decode() 135 | #print(repr(ciphertext)) 136 | plaintext = self.tl.tld(ciphertext, signature).decode() 137 | 138 | expected_prefix = f"{uid}:" 139 | if not plaintext.startswith(expected_prefix): 140 | bt.logging.warning(f"UID mismatch: Expected {expected_prefix} but got {plaintext}") 141 | return None 142 | 143 | return plaintext[len(expected_prefix):] 144 | 145 | def decrypt_dict(self, encrypted_dict: Dict[int, Tuple[int, bytes]]) -> Dict[int, Optional[str]]: 146 | """ 147 | Decrypt a dictionary of {uid: (target_round, encrypted_payload)}, caching signatures for this function call. 
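        Entries whose target round is not yet available, or whose ciphertext is invalid, are
        returned as None instead of raising.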
148 | """ 149 | decrypted_dict: Dict[int, Optional[bytes]] = {} 150 | cache: Dict[int, str] = {} 151 | 152 | for uid, (target_round, ciphertext) in encrypted_dict.items(): 153 | try: 154 | signature = self._get_drand_signature(target_round, cache) 155 | decrypted_dict[uid] = self.decrypt(uid, ciphertext, target_round, signature) 156 | except RuntimeError: 157 | current_round = self.get_current_round() 158 | bt.logging.warning(f"Skipping UID {uid}: Too early to decrypt: {target_round=}, {current_round=}") 159 | decrypted_dict[uid] = None 160 | continue 161 | except ValueError: 162 | bt.logging.warning(f"Skipping UID {uid}: Invalid ciphertext") 163 | decrypted_dict[uid] = None 164 | continue 165 | #print(repr(ciphertext)) 166 | return decrypted_dict 167 | 168 | 169 | #class BittensorDrandTimelock(AbstractBittensorDrandTimelock): 170 | # ROUND_DURATION = 30 171 | # PK_HEX = '868f005eb8e6e4ca0a47c8a77ceaa5309a47978a7c71bc5cce96366b5d7a569937c529eeda66c7293784a9402801af31' 172 | # CHAIN = '8990e7a9aaed2ffed73dbd7092123d6f289930540d7651336225dc172e51b2ce' 173 | # NET_START = 1595431050 174 | 175 | 176 | class QuicknetBittensorDrandTimelock(AbstractBittensorDrandTimelock): 177 | ROUND_DURATION = 3 178 | PK_HEX = "83cf0f2896adee7eb8b5f01fcad3912212c437e0073e911fb90022d3e760183c8c4b450b6a0a6c3ac6a5776a2d1064510d1fec758c921cc22b0e17e63aaf4bcb5ed66304de9cf809bd274ca73bab4af5a6e9c76a4bc09e76eae8991ef5ece45a" 179 | CHAIN = '52db9ba70e0cc0f6eaf7803dd07447a1f5477735fd3f661792ba94600c84e971' 180 | NET_START = 1692803367 181 | 182 | 183 | def _prepare_test(bdt): 184 | msg1: str = "Secret message #1" 185 | msg2: str = "Secret message #2" 186 | 187 | encrypted_dict: Dict[int, Tuple[int, bytes]] = { 188 | 1: bdt.encrypt(1, msg1, rounds=1), 189 | 2: bdt.encrypt(2, msg2, rounds=15), 190 | } 191 | bt.logging.info(f"Encrypted Dictionary: {encrypted_dict}") 192 | return encrypted_dict 193 | 194 | 195 | def sync_decrypt_example(encrypted_dict, bdt) -> None: 196 | """Synchronous example of encryption and decryption.""" 197 | try: 198 | decrypted_dict: Dict[int, Optional[str]] = bdt.decrypt_dict(encrypted_dict) 199 | logger.info(f"Decrypted Dictionary: {decrypted_dict}") 200 | except RuntimeError: 201 | logger.error("Decryption failed for one or more entries.") 202 | 203 | 204 | async def async_decrypt_example(encrypted_dict, bdt) -> None: 205 | """Example of using BittensorDrandTimelock in async code via ThreadPoolExecutor.""" 206 | loop = asyncio.get_running_loop() 207 | with ThreadPoolExecutor() as executor: 208 | try: 209 | decrypted_dict = await loop.run_in_executor(executor, bdt.decrypt_dict, encrypted_dict) 210 | logger.info(f"Decrypted Dictionary: {decrypted_dict}") 211 | except RuntimeError: 212 | logger.error("Decryption failed for one or more entries.") 213 | 214 | 215 | if __name__ == "__main__": 216 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") 217 | for bdt in ( 218 | QuicknetBittensorDrandTimelock(), 219 | #BittensorDrandTimelock(), # something wrong with the public key? 
220 | ): 221 | encrypted_dict = _prepare_test(bdt) 222 | time.sleep(bdt.ROUND_DURATION + 1) 223 | print('='*50) 224 | sync_decrypt_example(encrypted_dict, bdt) 225 | print('='*50) 226 | asyncio.run( 227 | async_decrypt_example( 228 | encrypted_dict, 229 | bdt, 230 | ) 231 | ) 232 | print('#'*75) 233 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # Number of targets and antitargets to select 2 | protein_selection: 3 | weekly_target: "P31652" 4 | num_antitargets: 4 5 | 6 | # Parameters for validating and scoring molecules 7 | molecule_validation: 8 | # Scoring weights for target/antitarget binding 9 | antitarget_weight: 0.75 10 | # Scoring weights for entropy 11 | entropy_bonus_threshold: 0 12 | entropy_start_weight: 0.3 13 | entropy_start_epoch: 15775 14 | entropy_step_size: 0.007142857 # 1/140 to increase by 1.0 over about 1 week 15 | # Scoring weights for molecule repetition 16 | molecule_repetition_weight: 0 17 | molecule_repetition_threshold: 0 18 | # Molecular property requirements 19 | min_heavy_atoms: 20 20 | min_rotatable_bonds: 1 21 | max_rotatable_bonds: 10 22 | # Number of molecules to validate 23 | num_molecules: 100 24 | 25 | # Competition parameters 26 | competition: 27 | # No submission blocks 28 | no_submission_blocks: 10 29 | -------------------------------------------------------------------------------- /config/config_loader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | def load_config(path: str = "config/config.yaml"): 5 | """ 6 | Loads configuration from a YAML file. 7 | """ 8 | if not os.path.exists(path): 9 | raise FileNotFoundError(f"Could not find config file at '{path}'") 10 | 11 | with open(path, "r", encoding="utf-8") as f: 12 | config = yaml.safe_load(f) 13 | 14 | # Load configuration options 15 | weekly_target = config["protein_selection"]["weekly_target"] 16 | num_antitargets = config["protein_selection"]["num_antitargets"] 17 | 18 | no_submission_blocks = config["competition"]["no_submission_blocks"] 19 | 20 | validation_config = config["molecule_validation"] 21 | antitarget_weight = validation_config["antitarget_weight"] 22 | min_heavy_atoms = validation_config["min_heavy_atoms"] 23 | min_rotatable_bonds = validation_config["min_rotatable_bonds"] 24 | max_rotatable_bonds = validation_config["max_rotatable_bonds"] 25 | num_molecules = validation_config["num_molecules"] 26 | entropy_bonus_threshold = validation_config["entropy_bonus_threshold"] 27 | entropy_start_weight = validation_config["entropy_start_weight"] 28 | entropy_start_epoch = validation_config["entropy_start_epoch"] 29 | entropy_step_size = validation_config["entropy_step_size"] 30 | molecule_repetition_weight = validation_config["molecule_repetition_weight"] 31 | molecule_repetition_threshold = validation_config["molecule_repetition_threshold"] 32 | 33 | return { 34 | 'weekly_target': weekly_target, 35 | 'num_antitargets': num_antitargets, 36 | 'no_submission_blocks': no_submission_blocks, 37 | 'antitarget_weight': antitarget_weight, 38 | 'min_heavy_atoms': min_heavy_atoms, 39 | 'min_rotatable_bonds': min_rotatable_bonds, 40 | 'max_rotatable_bonds': max_rotatable_bonds, 41 | 'num_molecules': num_molecules, 42 | 'entropy_bonus_threshold': entropy_bonus_threshold, 43 | 'entropy_start_weight': entropy_start_weight, 44 | 'entropy_start_epoch': entropy_start_epoch, 45 | 
'entropy_step_size': entropy_step_size, 46 | 'molecule_repetition_weight': molecule_repetition_weight, 47 | 'molecule_repetition_threshold': molecule_repetition_threshold 48 | } -------------------------------------------------------------------------------- /example.env: -------------------------------------------------------------------------------- 1 | # General configs 2 | SUBTENSOR_NETWORK="ws://localhost:9944" # or your chosen node 3 | DEVICE_OVERRIDE="cpu" # None to run on GPU 4 | 5 | # Github configs - FOR MINERS 6 | GITHUB_REPO_NAME="repo-name" 7 | GITHUB_REPO_BRANCH="repo-branch" 8 | GITHUB_TOKEN="your_token" 9 | GITHUB_REPO_OWNER="repo-owner" 10 | GITHUB_REPO_PATH="" # path within repo or "" 11 | 12 | # For validators 13 | VALIDATOR_API_KEY="your_api_key" 14 | AUTO_UPDATE="0" # Set to "1" to enable auto-updates 15 | -------------------------------------------------------------------------------- /install_deps_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -Eeuo pipefail 3 | 4 | # Install uv: 5 | wget -qO- https://astral.sh/uv/install.sh | sh 6 | 7 | # Install Rust (cargo) with auto-confirmation: 8 | wget -qO- https://sh.rustup.rs | sh -s -- -y 9 | source "$HOME/.cargo/env" 10 | 11 | # Install system build/env tools (Ubuntu/Debian): 12 | sudo apt update && sudo apt install -y build-essential 13 | sudo apt install python3.12-venv 14 | 15 | # Clone timelock at specific commit: 16 | git clone https://github.com/ideal-lab5/timelock.git 17 | cd timelock 18 | git checkout 23fe963f17175e413b7434180d2d0d0776722f1f 19 | cd .. 20 | 21 | # Create and activate virtual environment: 22 | uv venv 23 | source .venv/bin/activate 24 | uv pip install -r requirements/requirements.txt 25 | uv pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu 26 | uv pip install torch-geometric==2.6.1 27 | uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cpu.html 28 | uv pip install patchelf 29 | uv pip install maturin==1.8.3 30 | 31 | # Build timelock Python bindings (WASM) 32 | export PYO3_CROSS_PYTHON_VERSION="3.12" && cd timelock/wasm && ./wasm_build_py.sh && cd ../.. 33 | 34 | # Build timelock Python package: 35 | cd timelock/py && uv pip install --upgrade build && python3.12 -m build 36 | uv pip install timelock 37 | 38 | echo "Installation complete." 39 | 40 | -------------------------------------------------------------------------------- /install_deps_cu124.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -Eeuo pipefail 3 | 4 | # Install uv: 5 | wget -qO- https://astral.sh/uv/install.sh | sh 6 | 7 | # Install Rust (cargo) with auto-confirmation: 8 | wget -qO- https://sh.rustup.rs | sh -s -- -y 9 | source "$HOME/.cargo/env" 10 | 11 | # Install system build/env tools (Ubuntu/Debian): 12 | sudo apt update && sudo apt install -y build-essential 13 | sudo apt install python3.12-venv 14 | 15 | # Clone timelock at specific commit: 16 | git clone https://github.com/ideal-lab5/timelock.git 17 | cd timelock 18 | git checkout 23fe963f17175e413b7434180d2d0d0776722f1f 19 | cd .. 
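# From this point the steps mirror install_deps_cpu.sh, except that the torch and PyG wheels are pulled from the CUDA 12.4 (cu124) indexes.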
20 | 21 | 22 | # Create and activate virtual environment 23 | uv venv && source .venv/bin/activate 24 | uv pip install -r requirements/requirements.txt 25 | uv pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124 26 | uv pip install torch-geometric==2.6.1 27 | uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu124.html 28 | uv pip install patchelf 29 | uv pip install maturin==1.8.3 30 | 31 | # Build timelock Python bindings (WASM) 32 | export PYO3_CROSS_PYTHON_VERSION="3.12" && cd timelock/wasm && ./wasm_build_py.sh && cd ../.. 33 | 34 | # Build timelock Python package: 35 | cd timelock/py && uv pip install --upgrade build && python3.12 -m build 36 | uv pip install timelock 37 | 38 | echo "Installation complete." 39 | -------------------------------------------------------------------------------- /my_utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | import json 4 | from dotenv import load_dotenv 5 | import bittensor as bt 6 | from datasets import load_dataset 7 | import random 8 | from rdkit import Chem 9 | from rdkit.Chem import MACCSkeys 10 | import numpy as np 11 | import math 12 | import pandas as pd 13 | from huggingface_hub import hf_hub_download, hf_hub_url, get_hf_file_metadata 14 | import time 15 | import datetime 16 | 17 | load_dotenv(override=True) 18 | 19 | def upload_file_to_github(filename: str, encoded_content: str): 20 | # Github configs 21 | github_repo_name = os.environ.get('GITHUB_REPO_NAME') # example: nova 22 | github_repo_branch = os.environ.get('GITHUB_REPO_BRANCH') # example: main 23 | github_token = os.environ.get('GITHUB_TOKEN') 24 | github_repo_owner = os.environ.get('GITHUB_REPO_OWNER') # example: metanova-labs 25 | github_repo_path = os.environ.get('GITHUB_REPO_PATH') # example: /data/results or "" 26 | 27 | if not github_repo_name or not github_repo_branch or not github_token or not github_repo_owner: 28 | raise ValueError("Github environment variables not set.
Please set them in your .env file.") 29 | 30 | target_file_path = os.path.join(github_repo_path, f'{filename}.txt') 31 | url = f"https://api.github.com/repos/{github_repo_owner}/{github_repo_name}/contents/{target_file_path}" 32 | headers = { 33 | "Authorization": f"Bearer {github_token}", 34 | "Accept": "application/vnd.github+json", 35 | } 36 | 37 | # Check if the file already exists (need its SHA to update) 38 | existing_file = requests.get(url, headers=headers, params={"ref": github_repo_branch}) 39 | sha = existing_file.json().get("sha") if existing_file.status_code == 200 else None 40 | 41 | payload = { 42 | "message": f"Encrypted response for {filename}", 43 | "content": encoded_content, 44 | "branch": github_repo_branch, 45 | } 46 | if sha: 47 | payload["sha"] = sha # updating existing file 48 | 49 | response = requests.put(url, headers=headers, json=payload) 50 | if response.status_code in [200, 201]: 51 | return True 52 | else: 53 | bt.logging.error(f"Failed to upload file for {filename}: {response.status_code} {response.text}") 54 | return False 55 | 56 | 57 | def get_smiles(product_name): 58 | # Remove single and double quotes from product_name if they exist 59 | if product_name: 60 | product_name = product_name.replace("'", "").replace('"', "") 61 | else: 62 | bt.logging.error("Product name is empty.") 63 | return None 64 | 65 | api_key = os.environ.get("VALIDATOR_API_KEY") 66 | if not api_key: 67 | raise ValueError("validator_api_key environment variable not set.") 68 | 69 | url = f"https://8vzqr9wt22.execute-api.us-east-1.amazonaws.com/dev/smiles/{product_name}" 70 | 71 | headers = {"x-api-key": api_key} 72 | 73 | response = requests.get(url, headers=headers) 74 | 75 | data = response.json() 76 | 77 | return data.get("smiles") 78 | 79 | def get_sequence_from_protein_code(protein_code:str) -> str: 80 | """ 81 | Get the amino acid sequence for a protein code. 82 | First tries to fetch from UniProt API, and if that fails, 83 | falls back to searching the Hugging Face dataset. 84 | """ 85 | url = f"https://rest.uniprot.org/uniprotkb/{protein_code}.fasta" 86 | response = requests.get(url) 87 | 88 | if response.status_code == 200: 89 | lines = response.text.splitlines() 90 | sequence_lines = [line.strip() for line in lines if not line.startswith('>')] 91 | amino_acid_sequence = ''.join(sequence_lines) 92 | # Check if the sequence is empty 93 | if not amino_acid_sequence: 94 | bt.logging.warning(f"Retrieved empty sequence for {protein_code} from UniProt API") 95 | else: 96 | return amino_acid_sequence 97 | 98 | bt.logging.info(f"Failed to retrieve sequence for {protein_code} from UniProt API. Trying Hugging Face dataset.") 99 | try: 100 | dataset = load_dataset("Metanova/Proteins", split="train") 101 | 102 | for i in range(len(dataset)): 103 | if dataset[i]["Entry"] == protein_code: 104 | sequence = dataset[i]["Sequence"] 105 | bt.logging.info(f"Found sequence for {protein_code} in Hugging Face dataset") 106 | return sequence 107 | 108 | bt.logging.error(f"Could not find protein {protein_code} in Hugging Face dataset") 109 | return None 110 | 111 | except Exception as e: 112 | bt.logging.error(f"Error accessing Hugging Face dataset: {e}") 113 | return None 114 | 115 | def get_challenge_proteins_from_blockhash(block_hash: str, weekly_target: str, num_antitargets: int) -> dict: 116 | """ 117 | Use block_hash as a seed to pick 'num_targets' and 'num_antitargets' random entries 118 | from the 'Metanova/Proteins' dataset. Returns {'targets': [...], 'antitargets': [...]}. 
119 | """ 120 | if not (isinstance(block_hash, str) and block_hash.startswith("0x")): 121 | raise ValueError("block_hash must start with '0x'.") 122 | if not weekly_target or num_antitargets < 0: 123 | raise ValueError("weekly_target must exist and num_antitargets must be non-negative.") 124 | 125 | # Convert block hash to an integer seed 126 | try: 127 | seed = int(block_hash[2:], 16) 128 | except ValueError: 129 | raise ValueError(f"Invalid hex in block_hash: {block_hash}") 130 | 131 | # Initialize random number generator 132 | rng = random.Random(seed) 133 | 134 | # Load huggingface protein dataset 135 | try: 136 | dataset = load_dataset("Metanova/Proteins", split="train") 137 | except Exception as e: 138 | raise RuntimeError("Could not load the 'Metanova/Proteins' dataset.") from e 139 | 140 | dataset_size = len(dataset) 141 | if dataset_size == 0: 142 | raise ValueError("Dataset is empty; cannot pick random entries.") 143 | 144 | # Grab all required indices at once, ensure uniqueness 145 | unique_indices = rng.sample(range(dataset_size), k=(num_antitargets)) 146 | 147 | # Split indices for antitargets 148 | antitarget_indices = unique_indices[:num_antitargets] 149 | 150 | # Convert indices to protein codes 151 | targets = [weekly_target] 152 | antitargets = [dataset[i]["Entry"] for i in antitarget_indices] 153 | 154 | return { 155 | "targets": targets, 156 | "antitargets": antitargets 157 | } 158 | 159 | def get_heavy_atom_count(smiles: str) -> int: 160 | """ 161 | Calculate the number of heavy atoms in a molecule from its SMILES string. 162 | """ 163 | count = 0 164 | i = 0 165 | while i < len(smiles): 166 | c = smiles[i] 167 | 168 | if c.isalpha() and c.isupper(): 169 | elem_symbol = c 170 | 171 | # If the next character is a lowercase letter, include it (e.g., 'Cl', 'Br') 172 | if i + 1 < len(smiles) and smiles[i + 1].islower(): 173 | elem_symbol += smiles[i + 1] 174 | i += 1 175 | 176 | # If it's not 'H', count it as a heavy atom 177 | if elem_symbol != 'H': 178 | count += 1 179 | 180 | i += 1 181 | 182 | return count 183 | 184 | def compute_maccs_entropy(smiles_list: list[str]) -> float: 185 | """ 186 | Computes fingerprint entropy from MACCS keys for a list of SMILES. 187 | 188 | Parameters: 189 | smiles_list (list of str): Molecules in SMILES format. 190 | 191 | Returns: 192 | avg_entropy (float): Average entropy per bit. 193 | """ 194 | n_bits = 167 # RDKit uses 167 bits (index 0 is always 0) 195 | bit_counts = np.zeros(n_bits) 196 | valid_mols = 0 197 | 198 | for smi in smiles_list: 199 | mol = Chem.MolFromSmiles(smi) 200 | if mol: 201 | fp = MACCSkeys.GenMACCSKeys(mol) 202 | arr = np.array(fp) 203 | bit_counts += arr 204 | valid_mols += 1 205 | 206 | if valid_mols == 0: 207 | raise ValueError("No valid molecules found.") 208 | 209 | probs = bit_counts / valid_mols 210 | entropy_per_bit = np.array([ 211 | -p * math.log2(p) - (1 - p) * math.log2(1 - p) if 0 < p < 1 else 0 212 | for p in probs 213 | ]) 214 | 215 | avg_entropy = np.mean(entropy_per_bit) 216 | 217 | return avg_entropy 218 | 219 | def molecule_unique_for_protein_api(protein: str, molecule: str) -> bool: 220 | """ 221 | Check if a molecule has been previously submitted for the same target protein in any competition. 
222 | """ 223 | api_key = os.environ.get("VALIDATOR_API_KEY") 224 | if not api_key: 225 | raise ValueError("validator_api_key environment variable not set.") 226 | 227 | url = f"https://dashboard-backend-multitarget.up.railway.app/api/molecule_seen/{molecule}/{protein}" 228 | 229 | headers = { 230 | "Authorization": f"Bearer {api_key}" 231 | } 232 | 233 | try: 234 | response = requests.get(url, headers=headers) 235 | 236 | if response.status_code != 200: 237 | bt.logging.error(f"Failed to check molecule uniqueness: {response.status_code} {response.text}") 238 | return True 239 | 240 | data = response.json() 241 | return not data.get("seen", False) 242 | 243 | except Exception as e: 244 | bt.logging.error(f"Error checking molecule uniqueness: {e}") 245 | return True 246 | 247 | def molecule_unique_for_protein_hf(protein: str, smiles: str) -> bool: 248 | """ 249 | Check if molecule exists in Hugging Face Submission-Archive dataset by comparing InChIKeys. 250 | Returns True if unique (not found), False if found. 251 | """ 252 | if not hasattr(molecule_unique_for_protein_hf, "_CACHE"): 253 | molecule_unique_for_protein_hf._CACHE = (None, None, None, 0) 254 | 255 | try: 256 | cached_protein, cached_sha, inchikeys_set, last_check_time = molecule_unique_for_protein_hf._CACHE 257 | current_time = time.time() 258 | metadata_ttl = 60 259 | 260 | if protein != cached_protein: 261 | bt.logging.debug(f"Switching from protein {cached_protein} to {protein}") 262 | cached_sha = None 263 | 264 | filename = f"{protein}_molecules.csv" 265 | 266 | if cached_sha is None or (current_time - last_check_time > metadata_ttl): 267 | url = hf_hub_url( 268 | repo_id="Metanova/Submission-Archive", 269 | filename=filename, 270 | repo_type="dataset" 271 | ) 272 | 273 | metadata = get_hf_file_metadata(url) 274 | current_sha = metadata.commit_hash 275 | last_check_time = current_time 276 | 277 | if cached_sha != current_sha: 278 | file_path = hf_hub_download( 279 | repo_id="Metanova/Submission-Archive", 280 | filename=filename, 281 | repo_type="dataset", 282 | revision=current_sha 283 | ) 284 | 285 | df = pd.read_csv(file_path, usecols=["InChI_Key"]) 286 | inchikeys_set = set(df["InChI_Key"]) 287 | bt.logging.debug(f"Loaded {len(inchikeys_set)} InChI Keys into lookup set for {protein} (commit {current_sha[:7]})") 288 | 289 | molecule_unique_for_protein_hf._CACHE = (protein, current_sha, inchikeys_set, last_check_time) 290 | else: 291 | molecule_unique_for_protein_hf._CACHE = molecule_unique_for_protein_hf._CACHE[:3] + (last_check_time,) 292 | 293 | mol = Chem.MolFromSmiles(smiles) 294 | if mol is None: 295 | bt.logging.warning(f"Could not parse SMILES string: {smiles}") 296 | return True # Assume unique if we can't parse the SMILES 297 | 298 | inchikey = Chem.MolToInchiKey(mol) 299 | 300 | return inchikey not in inchikeys_set 301 | 302 | except Exception as e: 303 | # Assume molecule is unique if there's an error 304 | bt.logging.warning(f"Error checking molecule in HF dataset: {e}") 305 | return True 306 | 307 | def find_chemically_identical(smiles_list: list[str]) -> dict: 308 | """ 309 | Check for identical molecules in a list of SMILES strings by converting to InChIKeys. 
310 | """ 311 | inchikey_to_indices = {} 312 | 313 | for i, smiles in enumerate(smiles_list): 314 | try: 315 | mol = Chem.MolFromSmiles(smiles) 316 | if mol is not None: 317 | inchikey = Chem.MolToInchiKey(mol) 318 | if inchikey not in inchikey_to_indices: 319 | inchikey_to_indices[inchikey] = [] 320 | inchikey_to_indices[inchikey].append(i) 321 | except Exception as e: 322 | bt.logging.warning(f"Error processing SMILES {smiles}: {e}") 323 | 324 | duplicates = {k: v for k, v in inchikey_to_indices.items() if len(v) > 1} 325 | 326 | return duplicates 327 | 328 | def calculate_dynamic_entropy(starting_weight: float, step_size: float, start_epoch: int, current_epoch: int) -> float: 329 | """ 330 | Calculate entropy weight based on epochs elapsed since start epoch. 331 | """ 332 | epochs_elapsed = current_epoch - start_epoch 333 | 334 | entropy_weight = starting_weight + (epochs_elapsed * step_size) 335 | entropy_weight = max(0, entropy_weight) 336 | 337 | bt.logging.info(f"Epochs elapsed: {epochs_elapsed}, entropy weight: {entropy_weight}") 338 | return entropy_weight -------------------------------------------------------------------------------- /neurons/set_weight_to_uid.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import bittensor as bt 4 | import os 5 | from dotenv import load_dotenv 6 | import time 7 | 8 | def main(): 9 | load_dotenv() 10 | 11 | # 1) Parse the single argument for target_uid 12 | parser = argparse.ArgumentParser( 13 | description="Set weights on netuid=68 so that only target_uid has weight=1." 14 | ) 15 | parser.add_argument('--target_uid', type=int, required=True, 16 | help="The UID that will receive weight=1.0. Others = 0.0") 17 | parser.add_argument('--wallet_name', type=str, required=True, 18 | help="The name of the wallet to use.") 19 | parser.add_argument('--wallet_hotkey', type=str, required=True, 20 | help="The hotkey to use for the wallet.") 21 | 22 | args = parser.parse_args() 23 | 24 | NETUID = 68 25 | 26 | wallet = bt.wallet( 27 | name=args.wallet_name, 28 | hotkey=args.wallet_hotkey, 29 | ) 30 | 31 | # Create Subtensor connection using network from .env 32 | subtensor_network = os.getenv('SUBTENSOR_NETWORK') 33 | subtensor = bt.subtensor(network=subtensor_network) 34 | 35 | 36 | # Download the metagraph for netuid=68 37 | metagraph = subtensor.metagraph(NETUID) 38 | 39 | # Check registration 40 | hotkey_ss58 = wallet.hotkey.ss58_address 41 | if hotkey_ss58 not in metagraph.hotkeys: 42 | print(f"Hotkey {hotkey_ss58} is not registered on netuid {NETUID}. Exiting.") 43 | sys.exit(1) 44 | 45 | # 2) Build the weight vector 46 | n = len(metagraph.uids) 47 | weights = [0.0] * n 48 | 49 | # Validate the user-provided target UID 50 | if not (0 <= args.target_uid < n): 51 | print(f"Error: target_uid {args.target_uid} out of range [0, {n-1}]. Exiting.") 52 | sys.exit(1) 53 | 54 | # Set the single weight 55 | weights[args.target_uid] = 1.0 56 | 57 | # 3) Send the weights to the chain with retry logic 58 | max_retries = 10 59 | delay_between_retries = 12 # seconds 60 | for attempt in range(max_retries): 61 | try: 62 | print(f"Attempt {attempt + 1} to set weights.") 63 | result = subtensor.set_weights( 64 | netuid=NETUID, 65 | wallet=wallet, 66 | uids=metagraph.uids, 67 | weights=weights, 68 | wait_for_inclusion=True 69 | ) 70 | print(f"Result from set_weights: {result}") 71 | 72 | # Only break if result indicates success (result[0] == True). 
73 | if result[0] is True: 74 | print("Weights set successfully. Exiting retry loop.") 75 | break 76 | else: 77 | print("set_weights returned a non-success response. Will retry if attempts remain.") 78 | if attempt < max_retries - 1: 79 | print(f"Retrying in {delay_between_retries} seconds...") 80 | time.sleep(delay_between_retries) 81 | 82 | except Exception as e: 83 | print(f"Error setting weights: {e}") 84 | 85 | if attempt < max_retries - 1: 86 | print(f"Retrying in {delay_between_retries} seconds...") 87 | time.sleep(delay_between_retries) 88 | else: 89 | print("Failed to set weights after multiple attempts. Exiting.") 90 | sys.exit(1) 91 | 92 | print("Done.") 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | biopython==1.85 2 | bittensor==9.0.1 3 | datasets==3.3.2 4 | fair-esm==2.0.0 5 | huggingface-hub==0.29.1 6 | lifelines==0.30.0 7 | pandas==2.2.3 8 | python-dotenv==1.0.1 9 | rdkit==2024.9.4 10 | requests==2.32.3 11 | substrate-interface==1.7.11 12 | scikit-learn==1.6.1 13 | --------------------------------------------------------------------------------
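Usage sketch (illustrative, not part of the repository). The snippet below is a minimal sketch of how the configuration and helper utilities above fit together: it loads the scoring parameters from config/config.yaml via load_config, derives the weekly target plus block-hash-seeded antitargets with get_challenge_proteins_from_blockhash, computes the epoch-dependent entropy weight with calculate_dynamic_entropy, and measures the diversity of a molecule batch with compute_maccs_entropy. It assumes it is run from the repository root inside the installed virtual environment with network access (protein selection downloads the Metanova/Proteins dataset); the script name, block hash, current epoch, and SMILES strings are hypothetical placeholders, and how the validator ultimately combines these quantities is not shown here.

# usage_sketch.py -- illustrative only; not part of the repository.
from config.config_loader import load_config
from my_utils import (
    get_challenge_proteins_from_blockhash,
    calculate_dynamic_entropy,
    compute_maccs_entropy,
)

# Load scoring parameters defined in config/config.yaml.
config = load_config("config/config.yaml")

# Deterministically derive this round's proteins from a (hypothetical) block hash.
block_hash = "0x" + "ab" * 32
proteins = get_challenge_proteins_from_blockhash(
    block_hash=block_hash,
    weekly_target=config["weekly_target"],
    num_antitargets=config["num_antitargets"],
)
print("targets:", proteins["targets"])
print("antitargets:", proteins["antitargets"])

# Entropy weight ramps linearly by entropy_step_size per epoch after entropy_start_epoch.
current_epoch = config["entropy_start_epoch"] + 10  # hypothetical current epoch
entropy_weight = calculate_dynamic_entropy(
    starting_weight=config["entropy_start_weight"],
    step_size=config["entropy_step_size"],
    start_epoch=config["entropy_start_epoch"],
    current_epoch=current_epoch,
)

# Average per-bit entropy of MACCS fingerprints over a (hypothetical) molecule batch.
smiles_batch = ["CCO", "c1ccccc1O", "CC(=O)Nc1ccc(O)cc1"]
avg_bit_entropy = compute_maccs_entropy(smiles_batch)
print(f"entropy weight: {entropy_weight:.3f}, avg MACCS bit entropy: {avg_bit_entropy:.3f}")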