├── .gitignore
├── EVE
│   ├── VAE_decoder.py
│   ├── VAE_encoder.py
│   ├── VAE_model.py
│   └── default_model_params.json
├── LICENSE
├── README.md
├── compute_evol_indices.py
├── data
│   ├── MSA
│   │   ├── P53_HUMAN_b0.1.a2m
│   │   ├── PTEN_HUMAN_b1.0.a2m
│   │   ├── RASH_HUMAN_b03.a2m
│   │   └── SCN5A_HUMAN_b1.0.a2m
│   ├── labels
│   │   ├── ClinVar_labels_P53_PTEN_RASH_SCN5A.csv
│   │   └── PTEN_ClinVar_labels.csv
│   ├── mappings
│   │   └── example_mapping.csv
│   ├── mutations
│   │   └── .gitkeep
│   └── weights
│       └── .gitkeep
├── examples
│   ├── Step1_train_VAE.sh
│   ├── Step2_compute_evol_indices_all_singles.sh
│   └── Step3_train_GMM_and_compute_EVE_scores_all_singles.sh
├── logs
│   └── .gitkeep
├── protein_env.yml
├── results
│   ├── EVE_scores
│   │   └── .gitkeep
│   ├── GMM_parameters
│   │   └── Default_GMM_parameters
│   │       ├── GMM_model_dictionary_default
│   │       └── GMM_pathogenic_cluster_index_dictionary_default
│   ├── VAE_parameters
│   │   └── .gitkeep
│   ├── evol_indices
│   │   └── .gitkeep
│   ├── plots_histograms
│   │   └── .gitkeep
│   └── plots_scores_vs_labels
│       └── .gitkeep
├── train_GMM_and_compute_EVE_scores.py
├── train_VAE.py
└── utils
    ├── data_utils.py
    ├── default_uncertainty_threshold.json
    ├── performance_helpers.py
    └── plot_helpers.py

/.gitignore:
--------------------------------------------------------------------------------
EVE/__pycache__/
utils/__pycache__/
results/VAE_parameters/*
!results/VAE_parameters/.gitkeep
--------------------------------------------------------------------------------
/EVE/VAE_decoder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE_Bayesian_MLP_decoder(nn.Module):
    """
    Bayesian MLP decoder class for the VAE model.
    """
    def __init__(self, params):
        """
        Required input parameters:
        - seq_len: (Int) Sequence length of sequence alignment
        - alphabet_size: (Int) Alphabet size of sequence alignment (will be driven by the data helper object)
        - hidden_layers_sizes: (List) List of the sizes of the hidden layers (all DNNs)
        - z_dim: (Int) Dimension of latent space
        - first_hidden_nonlinearity: (Str) Type of non-linear activation applied on the first (set of) hidden layer(s)
        - last_hidden_nonlinearity: (Str) Type of non-linear activation applied on the very last hidden layer (pre-sparsity)
        - dropout_proba: (Float) Dropout probability applied on all hidden layers. If 0.0 then no dropout applied
        - convolve_output: (Bool) Whether to perform 1d convolution on output (kernel size 1, stride 1)
        - convolution_depth: (Int) Size of the 1D-convolution on output
        - include_temperature_scaler: (Bool) Whether we apply the global temperature scaler
        - include_sparsity: (Bool) Whether we use the sparsity inducing scheme on the output from the last hidden layer
        - num_tiles_sparsity: (Int) Number of tiles to use in the sparsity inducing scheme (the more the tiles, the stronger the sparsity)
        - bayesian_decoder: (Bool) Whether the decoder is bayesian or not
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.seq_len = params['seq_len']
        self.alphabet_size = params['alphabet_size']
        self.hidden_layers_sizes = params['hidden_layers_sizes']
        self.z_dim = params['z_dim']
        self.bayesian_decoder = True
        self.dropout_proba = params['dropout_proba']
        self.convolve_output = params['convolve_output']
        self.convolution_depth = params['convolution_output_depth']
        self.include_temperature_scaler = params['include_temperature_scaler']
        self.include_sparsity = params['include_sparsity']
        self.num_tiles_sparsity = params['num_tiles_sparsity']

        self.mu_bias_init = 0.1
        self.logvar_init = -10.0
        self.logit_scale_p = 0.001

        self.hidden_layers_mean=nn.ModuleDict()
        self.hidden_layers_log_var=nn.ModuleDict()
        for layer_index in range(len(self.hidden_layers_sizes)):
            if layer_index==0:
                self.hidden_layers_mean[str(layer_index)] = nn.Linear(self.z_dim, self.hidden_layers_sizes[layer_index])
                self.hidden_layers_log_var[str(layer_index)] = nn.Linear(self.z_dim, self.hidden_layers_sizes[layer_index])
                nn.init.constant_(self.hidden_layers_mean[str(layer_index)].bias, self.mu_bias_init)
                nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].weight, self.logvar_init)
                nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].bias, self.logvar_init)
            else:
                self.hidden_layers_mean[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index])
                self.hidden_layers_log_var[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index])
                nn.init.constant_(self.hidden_layers_mean[str(layer_index)].bias, self.mu_bias_init)
                nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].weight, self.logvar_init)
                nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].bias, self.logvar_init)

        if params['first_hidden_nonlinearity'] == 'relu':
            self.first_hidden_nonlinearity = nn.ReLU()
        elif params['first_hidden_nonlinearity'] == 'tanh':
            self.first_hidden_nonlinearity = nn.Tanh()
        elif params['first_hidden_nonlinearity'] == 'sigmoid':
            self.first_hidden_nonlinearity = nn.Sigmoid()
        elif params['first_hidden_nonlinearity'] == 'elu':
            self.first_hidden_nonlinearity = nn.ELU()
        elif params['first_hidden_nonlinearity'] == 'linear':
            self.first_hidden_nonlinearity = nn.Identity()

        if params['last_hidden_nonlinearity'] == 'relu':
            self.last_hidden_nonlinearity = nn.ReLU()
        elif params['last_hidden_nonlinearity'] == 'tanh':
            self.last_hidden_nonlinearity = nn.Tanh()
        elif params['last_hidden_nonlinearity'] == 'sigmoid':
            self.last_hidden_nonlinearity = nn.Sigmoid()
        elif params['last_hidden_nonlinearity'] == 
'elu': 78 | self.last_hidden_nonlinearity = nn.ELU() 79 | elif params['last_hidden_nonlinearity'] == 'linear': 80 | self.last_hidden_nonlinearity = nn.Identity() 81 | 82 | if self.dropout_proba > 0.0: 83 | self.dropout_layer = nn.Dropout(p=self.dropout_proba) 84 | 85 | if self.convolve_output: 86 | self.output_convolution_mean = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False) 87 | self.output_convolution_log_var = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False) 88 | nn.init.constant_(self.output_convolution_log_var.weight, self.logvar_init) 89 | self.channel_size = self.convolution_depth 90 | else: 91 | self.channel_size = self.alphabet_size 92 | 93 | if self.include_sparsity: 94 | self.sparsity_weight_mean = nn.Parameter(torch.zeros(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len)) 95 | self.sparsity_weight_log_var = nn.Parameter(torch.ones(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len)) 96 | nn.init.constant_(self.sparsity_weight_log_var, self.logvar_init) 97 | 98 | self.last_hidden_layer_weight_mean = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1])) 99 | self.last_hidden_layer_weight_log_var = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1])) 100 | nn.init.xavier_normal_(self.last_hidden_layer_weight_mean) #Glorot initialization 101 | nn.init.constant_(self.last_hidden_layer_weight_log_var, self.logvar_init) 102 | 103 | self.last_hidden_layer_bias_mean = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len)) 104 | self.last_hidden_layer_bias_log_var = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len)) 105 | nn.init.constant_(self.last_hidden_layer_bias_mean, self.mu_bias_init) 106 | nn.init.constant_(self.last_hidden_layer_bias_log_var, self.logvar_init) 107 | 108 | if self.include_temperature_scaler: 109 | self.temperature_scaler_mean = nn.Parameter(torch.ones(1)) 110 | self.temperature_scaler_log_var = nn.Parameter(torch.ones(1) * self.logvar_init) 111 | 112 | def sampler(self, mean, log_var): 113 | """ 114 | Samples a latent vector via reparametrization trick 115 | """ 116 | eps = torch.randn_like(mean).to(self.device) 117 | z = torch.exp(0.5*log_var) * eps + mean 118 | return z 119 | 120 | def forward(self, z): 121 | batch_size = z.shape[0] 122 | if self.dropout_proba > 0.0: 123 | x = self.dropout_layer(z) 124 | else: 125 | x = z 126 | 127 | for layer_index in range(len(self.hidden_layers_sizes)-1): 128 | layer_i_weight = self.sampler(self.hidden_layers_mean[str(layer_index)].weight, self.hidden_layers_log_var[str(layer_index)].weight) 129 | layer_i_bias = self.sampler(self.hidden_layers_mean[str(layer_index)].bias, self.hidden_layers_log_var[str(layer_index)].bias) 130 | x = self.first_hidden_nonlinearity(F.linear(x, weight=layer_i_weight, bias=layer_i_bias)) 131 | if self.dropout_proba > 0.0: 132 | x = self.dropout_layer(x) 133 | 134 | last_index = len(self.hidden_layers_sizes)-1 135 | last_layer_weight = self.sampler(self.hidden_layers_mean[str(last_index)].weight, self.hidden_layers_log_var[str(last_index)].weight) 136 | last_layer_bias = self.sampler(self.hidden_layers_mean[str(last_index)].bias, self.hidden_layers_log_var[str(last_index)].bias) 137 | x = self.last_hidden_nonlinearity(F.linear(x, weight=last_layer_weight, bias=last_layer_bias)) 138 | if self.dropout_proba > 0.0: 139 | x = self.dropout_layer(x) 
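        # The output-layer weights and biases below are random variables too: a fresh Monte Carlo
        # sample is drawn from each variational posterior on every forward pass (mean-field VI).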
        W_out = self.sampler(self.last_hidden_layer_weight_mean, self.last_hidden_layer_weight_log_var)
        b_out = self.sampler(self.last_hidden_layer_bias_mean, self.last_hidden_layer_bias_log_var)

        if self.convolve_output:
            output_convolution_weight = self.sampler(self.output_convolution_mean.weight, self.output_convolution_log_var.weight)
            W_out = torch.mm(W_out.view(self.seq_len * self.hidden_layers_sizes[-1], self.channel_size),
                        output_convolution_weight.view(self.channel_size,self.alphabet_size)) #product of size (H * seq_len, alphabet)

        if self.include_sparsity:
            sparsity_weights = self.sampler(self.sparsity_weight_mean,self.sparsity_weight_log_var)
            sparsity_tiled = sparsity_weights.repeat(self.num_tiles_sparsity,1)
            sparsity_tiled = nn.Sigmoid()(sparsity_tiled).unsqueeze(2)

            W_out = W_out.view(self.hidden_layers_sizes[-1], self.seq_len, self.alphabet_size) * sparsity_tiled

        W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1])

        x = F.linear(x, weight=W_out, bias=b_out)

        if self.include_temperature_scaler:
            temperature_scaler = self.sampler(self.temperature_scaler_mean,self.temperature_scaler_log_var)
            x = torch.log(1.0+torch.exp(temperature_scaler)) * x

        x = x.view(batch_size, self.seq_len, self.alphabet_size)
        x_recon_log = F.log_softmax(x, dim=-1) #of shape (batch_size, seq_len, alphabet)

        return x_recon_log

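# Non-Bayesian counterpart of the decoder above: same architecture, but the decoder weights are
# point estimates, so there is no weight sampling and no KL term over decoder parameters in the
# loss (see VAE_model.loss_function, which sets that term to 0.0 when bayesian_decoder is False).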
class VAE_Standard_MLP_decoder(nn.Module):
    """
    Standard MLP decoder class for the VAE model.
    """
    def __init__(self, params):
        """
        Required input parameters:
        - seq_len: (Int) Sequence length of sequence alignment
        - alphabet_size: (Int) Alphabet size of sequence alignment (will be driven by the data helper object)
        - hidden_layers_sizes: (List) List of the sizes of the hidden layers (all DNNs)
        - z_dim: (Int) Dimension of latent space
        - first_hidden_nonlinearity: (Str) Type of non-linear activation applied on the first (set of) hidden layer(s)
        - last_hidden_nonlinearity: (Str) Type of non-linear activation applied on the very last hidden layer (pre-sparsity)
        - dropout_proba: (Float) Dropout probability applied on all hidden layers. If 0.0 then no dropout applied
        - convolve_output: (Bool) Whether to perform 1d convolution on output (kernel size 1, stride 1)
        - convolution_depth: (Int) Size of the 1D-convolution on output
        - include_temperature_scaler: (Bool) Whether we apply the global temperature scaler
        - include_sparsity: (Bool) Whether we use the sparsity inducing scheme on the output from the last hidden layer
        - num_tiles_sparsity: (Int) Number of tiles to use in the sparsity inducing scheme (the more the tiles, the stronger the sparsity)
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.seq_len = params['seq_len']
        self.alphabet_size = params['alphabet_size']
        self.hidden_layers_sizes = params['hidden_layers_sizes']
        self.z_dim = params['z_dim']
        self.bayesian_decoder = False
        self.dropout_proba = params['dropout_proba']
        self.convolve_output = params['convolve_output']
        self.convolution_depth = params['convolution_output_depth'] #same key as in default_model_params.json
        self.include_temperature_scaler = params['include_temperature_scaler']
        self.include_sparsity = params['include_sparsity']
        self.num_tiles_sparsity = params['num_tiles_sparsity']

        self.mu_bias_init = 0.1

        self.hidden_layers=nn.ModuleDict()
        for layer_index in range(len(self.hidden_layers_sizes)):
            if layer_index==0:
                self.hidden_layers[str(layer_index)] = nn.Linear(self.z_dim, self.hidden_layers_sizes[layer_index])
                nn.init.constant_(self.hidden_layers[str(layer_index)].bias, self.mu_bias_init)
            else:
                self.hidden_layers[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index])
                nn.init.constant_(self.hidden_layers[str(layer_index)].bias, self.mu_bias_init)

        if params['first_hidden_nonlinearity'] == 'relu':
            self.first_hidden_nonlinearity = nn.ReLU()
        elif params['first_hidden_nonlinearity'] == 'tanh':
            self.first_hidden_nonlinearity = nn.Tanh()
        elif params['first_hidden_nonlinearity'] == 'sigmoid':
            self.first_hidden_nonlinearity = nn.Sigmoid()
        elif params['first_hidden_nonlinearity'] == 'elu':
            self.first_hidden_nonlinearity = nn.ELU()
        elif params['first_hidden_nonlinearity'] == 'linear':
            self.first_hidden_nonlinearity = nn.Identity()

        if params['last_hidden_nonlinearity'] == 'relu':
            self.last_hidden_nonlinearity = nn.ReLU()
        elif params['last_hidden_nonlinearity'] == 'tanh':
            self.last_hidden_nonlinearity = nn.Tanh()
        elif params['last_hidden_nonlinearity'] == 'sigmoid':
            self.last_hidden_nonlinearity = nn.Sigmoid()
        elif params['last_hidden_nonlinearity'] == 'elu':
            self.last_hidden_nonlinearity = nn.ELU()
        elif params['last_hidden_nonlinearity'] == 'linear':
            self.last_hidden_nonlinearity = nn.Identity()

        if self.dropout_proba > 0.0:
            self.dropout_layer = nn.Dropout(p=self.dropout_proba)

        if self.convolve_output:
            self.output_convolution = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False)
            self.channel_size = self.convolution_depth
        else:
            self.channel_size = self.alphabet_size

        if self.include_sparsity:
            self.sparsity_weight = nn.Parameter(torch.randn(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len))
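            # Sparsity logits have shape (H/num_tiles_sparsity, seq_len); forward() tiles them
            # num_tiles_sparsity times along the hidden dimension and applies a sigmoid to gate
            # each (hidden unit, position) pair in W_out.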

        self.W_out = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1]))
        nn.init.xavier_normal_(self.W_out) #Initialize weights with Glorot initialization
        self.b_out = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len))
        nn.init.constant_(self.b_out, self.mu_bias_init)

        if self.include_temperature_scaler:
            self.temperature_scaler = nn.Parameter(torch.ones(1))

    def forward(self, z):
        batch_size = z.shape[0]
        if self.dropout_proba > 0.0:
            x = self.dropout_layer(z)
        else:
            x = z

        for layer_index in range(len(self.hidden_layers_sizes)-1):
            x = self.first_hidden_nonlinearity(self.hidden_layers[str(layer_index)](x))
            if self.dropout_proba > 0.0:
                x = self.dropout_layer(x)

        x = self.last_hidden_nonlinearity(self.hidden_layers[str(len(self.hidden_layers_sizes)-1)](x)) #of size (batch_size,H)
        if self.dropout_proba > 0.0:
            x = self.dropout_layer(x)

        W_out = self.W_out #use the Parameter directly so gradients can flow to W_out

        if self.convolve_output:
            W_out = torch.mm(W_out.view(self.seq_len * self.hidden_layers_sizes[-1], self.channel_size),
                        self.output_convolution.weight.view(self.channel_size,self.alphabet_size))

        if self.include_sparsity:
            sparsity_tiled = self.sparsity_weight.repeat(self.num_tiles_sparsity,1) #of size (H,seq_len)
            sparsity_tiled = nn.Sigmoid()(sparsity_tiled).unsqueeze(2) #of size (H,seq_len,1)
            W_out = W_out.view(self.hidden_layers_sizes[-1], self.seq_len, self.alphabet_size) * sparsity_tiled

        W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1])

        x = F.linear(x, weight=W_out, bias=self.b_out)

        if self.include_temperature_scaler:
            x = torch.log(1.0+torch.exp(self.temperature_scaler)) * x

        x = x.view(batch_size, self.seq_len, self.alphabet_size)
        x_recon_log = F.log_softmax(x, dim=-1) #of shape (batch_size, seq_len, alphabet)

        return x_recon_log
--------------------------------------------------------------------------------
/EVE/VAE_encoder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class VAE_MLP_encoder(nn.Module):
    """
    MLP encoder class for the VAE model.
    """
    def __init__(self,params):
        """
        Required input parameters:
        - seq_len: (Int) Sequence length of sequence alignment
        - alphabet_size: (Int) Alphabet size of sequence alignment (will be driven by the data helper object)
        - hidden_layers_sizes: (List) List of sizes of DNN linear layers
        - z_dim: (Int) Size of latent space
        - convolve_input: (Bool) Whether to perform 1d convolution on input (kernel size 1, stride 1)
        - convolution_depth: (Int) Size of the 1D-convolution on input
        - nonlinear_activation: (Str) Type of non-linear activation to apply on each hidden layer
        - dropout_proba: (Float) Dropout probability applied on all hidden layers. 
If 0.0 then no dropout applied 19 | """ 20 | super().__init__() 21 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | self.seq_len = params['seq_len'] 23 | self.alphabet_size = params['alphabet_size'] 24 | self.hidden_layers_sizes = params['hidden_layers_sizes'] 25 | self.z_dim = params['z_dim'] 26 | self.convolve_input = params['convolve_input'] 27 | self.convolution_depth = params['convolution_input_depth'] 28 | self.dropout_proba = params['dropout_proba'] 29 | 30 | self.mu_bias_init = 0.1 31 | self.log_var_bias_init = -10.0 32 | 33 | #Convolving input with kernels of size 1 to capture potential similarities across amino acids when encoding sequences 34 | if self.convolve_input: 35 | self.input_convolution = nn.Conv1d(in_channels=self.alphabet_size,out_channels=self.convolution_depth,kernel_size=1,stride=1,bias=False) 36 | self.channel_size = self.convolution_depth 37 | else: 38 | self.channel_size = self.alphabet_size 39 | 40 | self.hidden_layers=torch.nn.ModuleDict() 41 | for layer_index in range(len(self.hidden_layers_sizes)): 42 | if layer_index==0: 43 | self.hidden_layers[str(layer_index)] = nn.Linear((self.channel_size*self.seq_len),self.hidden_layers_sizes[layer_index]) 44 | nn.init.constant_(self.hidden_layers[str(layer_index)].bias, self.mu_bias_init) 45 | else: 46 | self.hidden_layers[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index]) 47 | nn.init.constant_(self.hidden_layers[str(layer_index)].bias, self.mu_bias_init) 48 | 49 | self.fc_mean = nn.Linear(self.hidden_layers_sizes[-1],self.z_dim) 50 | nn.init.constant_(self.fc_mean.bias, self.mu_bias_init) 51 | self.fc_log_var = nn.Linear(self.hidden_layers_sizes[-1],self.z_dim) 52 | nn.init.constant_(self.fc_log_var.bias, self.log_var_bias_init) 53 | 54 | # set up non-linearity 55 | if params['nonlinear_activation'] == 'relu': 56 | self.nonlinear_activation = nn.ReLU() 57 | elif params['nonlinear_activation'] == 'tanh': 58 | self.nonlinear_activation = nn.Tanh() 59 | elif params['nonlinear_activation'] == 'sigmoid': 60 | self.nonlinear_activation = nn.Sigmoid() 61 | elif params['nonlinear_activation'] == 'elu': 62 | self.nonlinear_activation = nn.ELU() 63 | elif params['nonlinear_activation'] == 'linear': 64 | self.nonlinear_activation = nn.Identity() 65 | 66 | if self.dropout_proba > 0.0: 67 | self.dropout_layer = nn.Dropout(p=self.dropout_proba) 68 | 69 | def forward(self, x): 70 | if self.dropout_proba > 0.0: 71 | x = self.dropout_layer(x) 72 | 73 | if self.convolve_input: 74 | x = x.permute(0,2,1) 75 | x = self.input_convolution(x) 76 | x = x.view(-1,self.seq_len*self.channel_size) 77 | else: 78 | x = x.view(-1,self.seq_len*self.channel_size) 79 | 80 | for layer_index in range(len(self.hidden_layers_sizes)): 81 | x = self.nonlinear_activation(self.hidden_layers[str(layer_index)](x)) 82 | if self.dropout_proba > 0.0: 83 | x = self.dropout_layer(x) 84 | 85 | z_mean = self.fc_mean(x) 86 | z_log_var = self.fc_log_var(x) 87 | 88 | return z_mean, z_log_var -------------------------------------------------------------------------------- /EVE/VAE_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import time 5 | import tqdm 6 | from scipy.special import erfinv 7 | from sklearn.model_selection import train_test_split 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import 
torch.backends.cudnn as cudnn

from . import VAE_encoder, VAE_decoder

class VAE_model(nn.Module):
    """
    Class for the VAE model with estimation of weights distribution parameters via Mean-Field VI.
    """
    def __init__(self,
                    model_name,
                    data,
                    encoder_parameters,
                    decoder_parameters,
                    random_seed
                    ):

        super().__init__()

        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32
        self.random_seed = random_seed
        torch.manual_seed(random_seed)

        self.seq_len = data.seq_len
        self.alphabet_size = data.alphabet_size
        self.Neff = data.Neff

        encoder_parameters['seq_len'] = self.seq_len
        encoder_parameters['alphabet_size'] = self.alphabet_size
        decoder_parameters['seq_len'] = self.seq_len
        decoder_parameters['alphabet_size'] = self.alphabet_size
        #Keep the hyperparameter dictionaries on the model so checkpoints saved in train_model can embed them
        self.encoder_parameters = encoder_parameters
        self.decoder_parameters = decoder_parameters

        self.encoder = VAE_encoder.VAE_MLP_encoder(params=encoder_parameters)
        if decoder_parameters['bayesian_decoder']:
            self.decoder = VAE_decoder.VAE_Bayesian_MLP_decoder(params=decoder_parameters)
        else:
            self.decoder = VAE_decoder.VAE_Standard_MLP_decoder(params=decoder_parameters)
        self.logit_sparsity_p = decoder_parameters['logit_sparsity_p']

    def sample_latent(self, mu, log_var):
        """
        Samples a latent vector via reparametrization trick
        """
        eps = torch.randn_like(mu).to(self.device)
        z = torch.exp(0.5*log_var) * eps + mu
        return z

    def KLD_diag_gaussians(self, mu, logvar, p_mu, p_logvar):
        """
        KL divergence between a diagonal Gaussian and a diagonal Gaussian prior.
        """
        KLD = 0.5 * (p_logvar - logvar) + 0.5 * (torch.exp(logvar) + torch.pow(mu-p_mu,2)) / (torch.exp(p_logvar)+1e-20) - 0.5

        return torch.sum(KLD)

    def annealing_factor(self, annealing_warm_up, training_step):
        """
        Annealing schedule of KL to focus on reconstruction error in early stages of training
        """
        if training_step < annealing_warm_up:
            return training_step/annealing_warm_up
        else:
            return 1

    def KLD_global_parameters(self):
        """
        KL divergence between the variational distributions and the priors (for the decoder weights). 
81 | """ 82 | KLD_decoder_params = 0.0 83 | zero_tensor = torch.tensor(0.0).to(self.device) 84 | 85 | for layer_index in range(len(self.decoder.hidden_layers_sizes)): 86 | for param_type in ['weight','bias']: 87 | KLD_decoder_params += self.KLD_diag_gaussians( 88 | self.decoder.state_dict(keep_vars=True)['hidden_layers_mean.'+str(layer_index)+'.'+param_type].flatten(), 89 | self.decoder.state_dict(keep_vars=True)['hidden_layers_log_var.'+str(layer_index)+'.'+param_type].flatten(), 90 | zero_tensor, 91 | zero_tensor 92 | ) 93 | 94 | for param_type in ['weight','bias']: 95 | KLD_decoder_params += self.KLD_diag_gaussians( 96 | self.decoder.state_dict(keep_vars=True)['last_hidden_layer_'+param_type+'_mean'].flatten(), 97 | self.decoder.state_dict(keep_vars=True)['last_hidden_layer_'+param_type+'_log_var'].flatten(), 98 | zero_tensor, 99 | zero_tensor 100 | ) 101 | 102 | if self.decoder.include_sparsity: 103 | self.logit_scale_sigma = 4.0 104 | self.logit_scale_mu = 2.0**0.5 * self.logit_scale_sigma * erfinv(2.0 * self.logit_sparsity_p - 1.0) 105 | 106 | sparsity_mu = torch.tensor(self.logit_scale_mu).to(self.device) 107 | sparsity_log_var = torch.log(torch.tensor(self.logit_scale_sigma**2)).to(self.device) 108 | KLD_decoder_params += self.KLD_diag_gaussians( 109 | self.decoder.state_dict(keep_vars=True)['sparsity_weight_mean'].flatten(), 110 | self.decoder.state_dict(keep_vars=True)['sparsity_weight_log_var'].flatten(), 111 | sparsity_mu, 112 | sparsity_log_var 113 | ) 114 | 115 | if self.decoder.convolve_output: 116 | for param_type in ['weight']: 117 | KLD_decoder_params += self.KLD_diag_gaussians( 118 | self.decoder.state_dict(keep_vars=True)['output_convolution_mean.'+param_type].flatten(), 119 | self.decoder.state_dict(keep_vars=True)['output_convolution_log_var.'+param_type].flatten(), 120 | zero_tensor, 121 | zero_tensor 122 | ) 123 | 124 | if self.decoder.include_temperature_scaler: 125 | KLD_decoder_params += self.KLD_diag_gaussians( 126 | self.decoder.state_dict(keep_vars=True)['temperature_scaler_mean'].flatten(), 127 | self.decoder.state_dict(keep_vars=True)['temperature_scaler_log_var'].flatten(), 128 | zero_tensor, 129 | zero_tensor 130 | ) 131 | return KLD_decoder_params 132 | 133 | def loss_function(self, x_recon_log, x, mu, log_var, kl_latent_scale, kl_global_params_scale, annealing_warm_up, training_step, Neff): 134 | """ 135 | Returns mean of negative ELBO, reconstruction loss and KL divergence across batch x. 136 | """ 137 | BCE = F.binary_cross_entropy_with_logits(x_recon_log, x, reduction='sum') / x.shape[0] 138 | KLD_latent = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())) / x.shape[0] 139 | if self.decoder.bayesian_decoder: 140 | KLD_decoder_params_normalized = self.KLD_global_parameters() / Neff 141 | else: 142 | KLD_decoder_params_normalized = 0.0 143 | warm_up_scale = self.annealing_factor(annealing_warm_up,training_step) 144 | neg_ELBO = BCE + warm_up_scale * (kl_latent_scale * KLD_latent + kl_global_params_scale * KLD_decoder_params_normalized) 145 | return neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized 146 | 147 | def all_likelihood_components(self, x): 148 | """ 149 | Returns tensors of ELBO, reconstruction loss and KL divergence for each point in batch x. 
150 | """ 151 | mu, log_var = self.encoder(x) 152 | z = self.sample_latent(mu, log_var) 153 | recon_x_log = self.decoder(z) 154 | 155 | recon_x_log = recon_x_log.view(-1,self.alphabet_size*self.seq_len) 156 | x = x.view(-1,self.alphabet_size*self.seq_len) 157 | 158 | BCE_batch_tensor = torch.sum(F.binary_cross_entropy_with_logits(recon_x_log, x, reduction='none'),dim=1) 159 | KLD_batch_tensor = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(),dim=1)) 160 | 161 | ELBO_batch_tensor = -(BCE_batch_tensor + KLD_batch_tensor) 162 | 163 | return ELBO_batch_tensor, BCE_batch_tensor, KLD_batch_tensor 164 | 165 | def train_model(self, data, training_parameters): 166 | """ 167 | Training procedure for the VAE model. 168 | If use_validation_set is True then: 169 | - we split the alignment data in train/val sets. 170 | - we train up to num_training_steps steps but store the version of the model with lowest loss on validation set across training 171 | If not, then we train the model for num_training_steps and save the model at the end of training 172 | """ 173 | if torch.cuda.is_available(): 174 | cudnn.benchmark = True 175 | self.train() 176 | 177 | if training_parameters['log_training_info']: 178 | filename = training_parameters['training_logs_location']+os.sep+self.model_name+"_losses.csv" 179 | with open(filename, "a") as logs: 180 | logs.write("Number of sequences in alignment file:\t"+str(data.num_sequences)+"\n") 181 | logs.write("Neff:\t"+str(self.Neff)+"\n") 182 | logs.write("Alignment sequence length:\t"+str(data.seq_len)+"\n") 183 | 184 | optimizer = optim.Adam(self.parameters(), lr=training_parameters['learning_rate'], weight_decay = training_parameters['l2_regularization']) 185 | 186 | if training_parameters['use_lr_scheduler']: 187 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=training_parameters['lr_scheduler_step_size'], gamma=training_parameters['lr_scheduler_gamma']) 188 | 189 | if training_parameters['use_validation_set']: 190 | x_train, x_val, weights_train, weights_val = train_test_split(data.one_hot_encoding, data.weights, test_size=training_parameters['validation_set_pct'], random_state=self.random_seed) 191 | best_val_loss = float('inf') 192 | best_model_step_index=0 193 | else: 194 | x_train = data.one_hot_encoding 195 | weights_train = data.weights 196 | best_val_loss = None 197 | best_model_step_index = training_parameters['num_training_steps'] 198 | 199 | batch_order = np.arange(x_train.shape[0]) 200 | seq_sample_probs = weights_train / np.sum(weights_train) 201 | 202 | self.Neff_training = np.sum(weights_train) 203 | N_training = x_train.shape[0] 204 | 205 | start = time.time() 206 | train_loss = 0 207 | 208 | for training_step in tqdm.tqdm(range(1,training_parameters['num_training_steps']+1), desc="Training model"): 209 | 210 | batch_index = np.random.choice(batch_order, training_parameters['batch_size'], p=seq_sample_probs).tolist() 211 | x = torch.tensor(x_train[batch_index], dtype=self.dtype).to(self.device) 212 | optimizer.zero_grad() 213 | 214 | mu, log_var = self.encoder(x) 215 | z = self.sample_latent(mu, log_var) 216 | recon_x_log = self.decoder(z) 217 | 218 | neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized = self.loss_function(recon_x_log, x, mu, log_var, training_parameters['kl_latent_scale'], training_parameters['kl_global_params_scale'], training_parameters['annealing_warm_up'], training_step, self.Neff_training) 219 | 220 | neg_ELBO.backward() 221 | optimizer.step() 222 | 223 | if training_parameters['use_lr_scheduler']: 224 | 
scheduler.step()

            if training_step % training_parameters['log_training_freq'] == 0:
                progress = "|Train : Update {0}. Negative ELBO : {1:.3f}, BCE: {2:.3f}, KLD_latent: {3:.3f}, KLD_decoder_params_norm: {4:.3f}, Time: {5:.2f} |".format(training_step, neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized, time.time() - start)
                print(progress)

                if training_parameters['log_training_info']:
                    with open(filename, "a") as logs:
                        logs.write(progress+"\n")

            if training_step % training_parameters['save_model_params_freq']==0:
                self.save(model_checkpoint=training_parameters['model_checkpoint_location']+os.sep+self.model_name+"_step_"+str(training_step),
                            encoder_parameters=self.encoder_parameters, #hyperparameter dicts stored on the model in __init__
                            decoder_parameters=self.decoder_parameters,
                            training_parameters=training_parameters)

            if training_parameters['use_validation_set'] and training_step % training_parameters['validation_freq'] == 0:
                x_val = torch.tensor(x_val, dtype=self.dtype).to(self.device)
                val_neg_ELBO, val_BCE, val_KLD_latent, val_KLD_global_parameters = self.test_model(x_val, weights_val, training_parameters['batch_size'])

                progress_val = "\t\t\t|Val : Update {0}. Negative ELBO : {1:.3f}, BCE: {2:.3f}, KLD_latent: {3:.3f}, KLD_decoder_params_norm: {4:.3f}, Time: {5:.2f} |".format(training_step, val_neg_ELBO, val_BCE, val_KLD_latent, val_KLD_global_parameters, time.time() - start)
                print(progress_val)
                if training_parameters['log_training_info']:
                    with open(filename, "a") as logs:
                        logs.write(progress_val+"\n")

                if val_neg_ELBO < best_val_loss:
                    best_val_loss = val_neg_ELBO
                    best_model_step_index = training_step
                    self.save(model_checkpoint=training_parameters['model_checkpoint_location']+os.sep+self.model_name+"_best",
                                encoder_parameters=self.encoder_parameters,
                                decoder_parameters=self.decoder_parameters,
                                training_parameters=training_parameters)
                self.train()

    def test_model(self, x_val, weights_val, batch_size):
        self.eval()

        with torch.no_grad():
            val_batch_order = np.arange(x_val.shape[0])
            val_seq_sample_probs = weights_val / np.sum(weights_val)

            val_batch_index = np.random.choice(val_batch_order, batch_size, p=val_seq_sample_probs).tolist()
            x = torch.tensor(x_val[val_batch_index], dtype=self.dtype).to(self.device)
            mu, log_var = self.encoder(x)
            z = self.sample_latent(mu, log_var)
            recon_x_log = self.decoder(z)

            neg_ELBO, BCE, KLD_latent, KLD_global_parameters = self.loss_function(recon_x_log, x, mu, log_var, kl_latent_scale=1.0, kl_global_params_scale=1.0, annealing_warm_up=0, training_step=1, Neff = self.Neff_training) #set annealing factor to 1

        return neg_ELBO.item(), BCE.item(), KLD_latent.item(), KLD_global_parameters.item()


    def save(self, model_checkpoint, encoder_parameters, decoder_parameters, training_parameters, batch_size=256):
        torch.save({
            'model_state_dict':self.state_dict(),
            'encoder_parameters':encoder_parameters,
            'decoder_parameters':decoder_parameters,
            'training_parameters':training_parameters,
            }, model_checkpoint)

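    # Expected format for the mutations file read below: a CSV with a single header column named
    # "mutations", one row per variant, e.g. R273H for a single substitution and a colon-separated
    # string such as R273H:R282W for a multi-mutant (rows here are illustrative P53 variants).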
    def compute_evol_indices(self, msa_data, list_mutations_location, num_samples, batch_size=256):
        """
        The column in the list_mutations dataframe that contains the mutant(s) for a given variant should be called "mutations"
        """
        #Multiple mutations are to be passed colon-separated
        list_mutations=pd.read_csv(list_mutations_location, header=0)

        #Remove (multiple) mutations that are invalid
        list_valid_mutations = ['wt']
        list_valid_mutated_sequences = {}
        list_valid_mutated_sequences['wt'] = msa_data.focus_seq_trimmed # first sequence in the list is the wild_type
        for mutation in list_mutations['mutations']:
            individual_substitutions = mutation.split(':')
            mutated_sequence = list(msa_data.focus_seq_trimmed)[:]
            fully_valid_mutation = True
            for mut in individual_substitutions:
                wt_aa, pos, mut_aa = mut[0], int(mut[1:-1]), mut[-1]
                if pos not in msa_data.uniprot_focus_col_to_wt_aa_dict or msa_data.uniprot_focus_col_to_wt_aa_dict[pos] != wt_aa or mut not in msa_data.mutant_to_letter_pos_idx_focus_list:
                    print("Not a valid mutant: "+mutation)
                    fully_valid_mutation = False
                    break
                else:
                    wt_aa,pos,idx_focus = msa_data.mutant_to_letter_pos_idx_focus_list[mut]
                    mutated_sequence[idx_focus] = mut_aa #perform the corresponding AA substitution

            if fully_valid_mutation:
                list_valid_mutations.append(mutation)
                list_valid_mutated_sequences[mutation] = ''.join(mutated_sequence)

        #One-hot encoding of mutated sequences
        mutated_sequences_one_hot = np.zeros((len(list_valid_mutations),len(msa_data.focus_cols),len(msa_data.alphabet)))
        for i,mutation in enumerate(list_valid_mutations):
            sequence = list_valid_mutated_sequences[mutation]
            for j,letter in enumerate(sequence):
                if letter in msa_data.aa_dict:
                    k = msa_data.aa_dict[letter]
                    mutated_sequences_one_hot[i,j,k] = 1.0

        mutated_sequences_one_hot = torch.tensor(mutated_sequences_one_hot)
        dataloader = torch.utils.data.DataLoader(mutated_sequences_one_hot, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
        prediction_matrix = torch.zeros((len(list_valid_mutations),num_samples))

        with torch.no_grad():
            for i, batch in enumerate(tqdm.tqdm(dataloader, 'Looping through mutation batches')):
                x = batch.type(self.dtype).to(self.device)
                for j in tqdm.tqdm(range(num_samples), 'Looping through number of samples for batch #: '+str(i+1)):
                    seq_predictions, _, _ = self.all_likelihood_components(x)
                    prediction_matrix[i*batch_size:i*batch_size+len(x),j] = seq_predictions
                tqdm.tqdm.write('\n')
        mean_predictions = prediction_matrix.mean(dim=1, keepdim=False)
        std_predictions = prediction_matrix.std(dim=1, keepdim=False)
        delta_elbos = mean_predictions - mean_predictions[0]
        evol_indices = - delta_elbos.detach().cpu().numpy()

        return list_valid_mutations, evol_indices, mean_predictions[0].detach().cpu().numpy(), std_predictions.detach().cpu().numpy()
--------------------------------------------------------------------------------
/EVE/default_model_params.json:
--------------------------------------------------------------------------------
{ "encoder_parameters": {
        "hidden_layers_sizes"       : [2000,1000,300],
        "z_dim"                     : 50,
        "convolve_input"            : false,
        "convolution_input_depth"   : 40,
        "nonlinear_activation"      : "relu",
        "dropout_proba"             : 0.0
    },
    "decoder_parameters": {
        "hidden_layers_sizes"           : [300,1000,2000],
        "z_dim"                         : 50,
        "bayesian_decoder"              : true,
        "first_hidden_nonlinearity"     : "relu",
        "last_hidden_nonlinearity"      : "relu",
        "dropout_proba"                 : 0.1,
        "convolve_output"               : true,
        "convolution_output_depth"      : 40,
        "include_temperature_scaler"    : true,
        "include_sparsity"              : false,
        "num_tiles_sparsity"            : 0,
        "logit_sparsity_p"              : 0
    },
    "training_parameters": {
        "num_training_steps"    : 400000,
        "learning_rate"         : 1e-4,
        "batch_size"            : 256,
        "annealing_warm_up"     : 0,
        "kl_latent_scale"       : 1.0,
        "kl_global_params_scale": 1.0,
        "l2_regularization"     : 0.0,
        "use_lr_scheduler"      : false,
        "use_validation_set"    : false,
        "validation_set_pct"    : 0.2,
        "validation_freq"       : 1000,
        "log_training_info"     : true,
        "log_training_freq"     : 1000,
        "save_model_params_freq": 500000
    }
}
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Pascal Notin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Evolutionary model of Variant Effects (EVE)

Please note that we have migrated the official repo to the following address: https://github.com/OATML-Markslab/EVE.

## Overview
EVE is a set of protein-specific models that provide, for any single amino acid mutation of interest, a score reflecting the propensity of the resulting protein variant to be pathogenic. For each protein family, a Bayesian VAE learns a distribution over amino acid sequences from evolutionary data. It enables the computation of an evolutionary index for each mutant, which approximates the log-likelihood ratio of the mutant vs the wild type. A global-local mixture of Gaussian Mixture Models separates variants into benign and pathogenic clusters based on that index. The EVE scores reflect probabilistic assignments to the pathogenic cluster.

## Usage
The end-to-end process to compute EVE scores consists of three consecutive steps:
1. Train the Bayesian VAE on a re-weighted multiple sequence alignment (MSA) for the protein of interest => train_VAE.py
2. Compute the evolutionary indices for all single amino acid mutations => compute_evol_indices.py
3. Train a GMM to cluster variants on the basis of the evol indices, then output scores and uncertainties on the class assignments => train_GMM_and_compute_EVE_scores.py
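
As an illustration, step 2 may be launched as follows once a trained checkpoint is available (the flag names below are those defined in compute_evol_indices.py; the paths mirror the repository layout, while the number of samples and batch size are placeholder values to adjust; see examples/Step2_compute_evol_indices_all_singles.sh for the reference script):
```
python compute_evol_indices.py \
    --MSA_data_folder ./data/MSA \
    --MSA_list ./data/mappings/example_mapping.csv \
    --protein_index 0 \
    --MSA_weights_location ./data/weights \
    --VAE_checkpoint_location ./results/VAE_parameters \
    --model_name_suffix Jan1 \
    --model_parameters_location ./EVE/default_model_params.json \
    --computation_mode all_singles \
    --all_singles_mutations_folder ./data/mutations \
    --output_evol_indices_location ./results/evol_indices \
    --num_samples_compute_evol_indices 20000 \
    --batch_size 2048
```
Note that compute_evol_indices.py loads the checkpoint named `<protein_name>_<model_name_suffix>_final` from the folder passed via --VAE_checkpoint_location.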
We also provide all EVE scores for all single amino acid mutations for thousands of proteins at the following address: http://evemodel.org/.

## Example scripts
The "examples" folder contains sample bash scripts to obtain EVE scores for a protein of interest (using PTEN as an example).
MSAs and ClinVar labels are provided for 4 proteins (P53, PTEN, RASH and SCN5A) in the data folder.

## Data requirements
The only data required to train EVE models and obtain EVE scores from scratch are the multiple sequence alignments (MSAs) for the corresponding proteins.

### MSA creation
We built multiple sequence alignments for each protein family by performing five search iterations of the profile HMM homology search tool Jackhmmer against the UniRef100 database of non-redundant protein sequences (downloaded on April 20th 2020). Please refer to the supplementary notes of the EVE paper (section 3.1.1) for a detailed description of the MSA creation process.
Our GitHub repo provides the MSAs for 4 proteins: P53, PTEN, RASH & SCN5A (see data/MSA). MSAs for all proteins may be accessed on our website (https://evemodel.org/).

### MSA pre-processing
The EVE codebase provides basic functionalities to pre-process MSAs for modelling (see the MSA_processing class in utils/data_utils.py). By default, sequences with 50% or more gaps in the alignment and/or positions with less than 70% residue occupancy will be removed. These parameters may be adjusted as needed by the end user.

### ClinVar labels
The script "train_GMM_and_compute_EVE_scores.py" provides functionalities to compare EVE scores with reference labels (e.g., ClinVar). Our GitHub repo provides labels for 4 proteins: P53, PTEN, RASH & SCN5A (see data/labels). ClinVar labels for all proteins may be accessed on our website (https://evemodel.org/).

## Software requirements
The entire codebase is written in Python. Package requirements are as follows:
- python=3.7
- pytorch=1.7
- cudatoolkit=11.0
- scikit-learn=0.24.1
- numpy=1.20.1
- pandas=1.2.4
- scipy=1.6.2
- tqdm
- matplotlib
- seaborn

The corresponding environment may be created via conda and the provided protein_env.yml file as follows:
```
conda env create -f protein_env.yml
conda activate protein_env
```

## License
This project is available under the MIT license.

## Reference
If you use this code, please cite the following paper:
```bibtex
@article{Frazer2021DiseaseVP,
  title={Disease variant prediction with deep generative models of evolutionary data.},
  author={Jonathan Frazer and Pascal Notin and Mafalda Dias and Aidan Gomez and Joseph K Min and Kelly P. Brock and Yarin Gal and Debora S. 
Marks}, 60 | journal={Nature}, 61 | year={2021} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /compute_evol_indices.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import json 3 | import argparse 4 | import pandas as pd 5 | import torch 6 | 7 | from EVE import VAE_model 8 | from utils import data_utils 9 | 10 | if __name__=='__main__': 11 | 12 | parser = argparse.ArgumentParser(description='Evol indices') 13 | parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored') 14 | parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name') 15 | parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file') 16 | parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored') 17 | parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting') 18 | parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored') 19 | parser.add_argument('--model_name_suffix', default='Jan1', type=str, help='model checkpoint name is the protein name followed by this suffix') 20 | parser.add_argument('--model_parameters_location', type=str, help='Location of VAE model parameters') 21 | parser.add_argument('--computation_mode', type=str, help='Computes evol indices for all single AA mutations or for a passed in list of mutations (singles or multiples) [all_singles,input_mutations_list]') 22 | parser.add_argument('--all_singles_mutations_folder', type=str, help='Location for the list of generated single AA mutations') 23 | parser.add_argument('--mutations_location', type=str, help='Location of all mutations to compute the evol indices for') 24 | parser.add_argument('--output_evol_indices_location', type=str, help='Output location of computed evol indices') 25 | parser.add_argument('--output_evol_indices_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') 26 | parser.add_argument('--num_samples_compute_evol_indices', type=int, help='Num of samples to approximate delta elbo when computing evol indices') 27 | parser.add_argument('--batch_size', default=256, type=int, help='Batch size when computing evol indices') 28 | args = parser.parse_args() 29 | 30 | mapping_file = pd.read_csv(args.MSA_list) 31 | protein_name = mapping_file['protein_name'][args.protein_index] 32 | msa_location = args.MSA_data_folder + os.sep + mapping_file['msa_location'][args.protein_index] 33 | print("Protein name: "+str(protein_name)) 34 | print("MSA file: "+str(msa_location)) 35 | 36 | if args.theta_reweighting is not None: 37 | theta = args.theta_reweighting 38 | else: 39 | try: 40 | theta = float(mapping_file['theta'][args.protein_index]) 41 | except: 42 | theta = 0.2 43 | print("Theta MSA re-weighting: "+str(theta)) 44 | 45 | data = data_utils.MSA_processing( 46 | MSA_location=msa_location, 47 | theta=theta, 48 | use_weights=True, 49 | weights_location=args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy' 50 | ) 51 | 52 | if args.computation_mode=="all_singles": 53 | data.save_all_singles(output_filename=args.all_singles_mutations_folder + os.sep + protein_name + "_all_singles.csv") 54 | args.mutations_location = args.all_singles_mutations_folder + os.sep + protein_name + "_all_singles.csv" 55 | 
else: 56 | args.mutations_location = args.mutations_location + os.sep + protein_name + ".csv" 57 | 58 | model_name = protein_name + "_" + args.model_name_suffix 59 | print("Model name: "+str(model_name)) 60 | 61 | model_params = json.load(open(args.model_parameters_location)) 62 | 63 | model = VAE_model.VAE_model( 64 | model_name=model_name, 65 | data=data, 66 | encoder_parameters=model_params["encoder_parameters"], 67 | decoder_parameters=model_params["decoder_parameters"], 68 | random_seed=42 69 | ) 70 | model = model.to(model.device) 71 | 72 | try: 73 | checkpoint_name = str(args.VAE_checkpoint_location) + os.sep + model_name + "_final" 74 | checkpoint = torch.load(checkpoint_name) 75 | model.load_state_dict(checkpoint['model_state_dict']) 76 | print("Initialized VAE with checkpoint '{}' ".format(checkpoint_name)) 77 | except: 78 | print("Unable to locate VAE model checkpoint") 79 | sys.exit(0) 80 | 81 | list_valid_mutations, evol_indices, _, _ = model.compute_evol_indices(msa_data=data, 82 | list_mutations_location=args.mutations_location, 83 | num_samples=args.num_samples_compute_evol_indices, 84 | batch_size=args.batch_size) 85 | 86 | df = {} 87 | df['protein_name'] = protein_name 88 | df['mutations'] = list_valid_mutations 89 | df['evol_indices'] = evol_indices 90 | df = pd.DataFrame(df) 91 | 92 | evol_indices_output_filename = args.output_evol_indices_location+os.sep+protein_name+'_'+str(args.num_samples_compute_evol_indices)+'_samples'+args.output_evol_indices_filename_suffix+'.csv' 93 | try: 94 | keep_header = os.stat(evol_indices_output_filename).st_size == 0 95 | except: 96 | keep_header=True 97 | df.to_csv(path_or_buf=evol_indices_output_filename, index=False, mode='a', header=keep_header) -------------------------------------------------------------------------------- /data/labels/ClinVar_labels_P53_PTEN_RASH_SCN5A.csv: -------------------------------------------------------------------------------- 1 | protein_name,mutations,ClinVar_labels_category,ClinVar_labels 2 | P53_HUMAN,D7E,benign,0.0 3 | P53_HUMAN,P8L,benign,0.0 4 | P53_HUMAN,V10I,benign,0.0 5 | P53_HUMAN,N29D,benign,0.0 6 | P53_HUMAN,P36Q,benign,0.0 7 | P53_HUMAN,P47S,benign,0.0 8 | P53_HUMAN,D48E,benign,0.0 9 | P53_HUMAN,D49N,benign,0.0 10 | P53_HUMAN,F54L,benign,0.0 11 | P53_HUMAN,R65G,benign,0.0 12 | P53_HUMAN,V73M,benign,0.0 13 | P53_HUMAN,A74V,benign,0.0 14 | P53_HUMAN,P82L,benign,0.0 15 | P53_HUMAN,A84T,benign,0.0 16 | P53_HUMAN,P89T,benign,0.0 17 | P53_HUMAN,G105R,pathogenic,1.0 18 | P53_HUMAN,G105S,pathogenic,1.0 19 | P53_HUMAN,Y107H,benign,0.0 20 | P53_HUMAN,F109V,pathogenic,1.0 21 | P53_HUMAN,R110H,benign,0.0 22 | P53_HUMAN,R110L,pathogenic,1.0 23 | P53_HUMAN,R110P,pathogenic,1.0 24 | P53_HUMAN,T125K,pathogenic,1.0 25 | P53_HUMAN,T125R,pathogenic,1.0 26 | P53_HUMAN,A129S,benign,0.0 27 | P53_HUMAN,L130P,pathogenic,1.0 28 | P53_HUMAN,N131I,pathogenic,1.0 29 | P53_HUMAN,N131Y,pathogenic,1.0 30 | P53_HUMAN,K132N,pathogenic,1.0 31 | P53_HUMAN,M133T,pathogenic,1.0 32 | P53_HUMAN,F134C,pathogenic,1.0 33 | P53_HUMAN,C135G,pathogenic,1.0 34 | P53_HUMAN,A138P,pathogenic,1.0 35 | P53_HUMAN,C141Y,pathogenic,1.0 36 | P53_HUMAN,P151S,pathogenic,1.0 37 | P53_HUMAN,P151T,pathogenic,1.0 38 | P53_HUMAN,P152L,pathogenic,1.0 39 | P53_HUMAN,P152R,pathogenic,1.0 40 | P53_HUMAN,V157A,pathogenic,1.0 41 | P53_HUMAN,R158G,pathogenic,1.0 42 | P53_HUMAN,R158H,pathogenic,1.0 43 | P53_HUMAN,R158L,pathogenic,1.0 44 | P53_HUMAN,R158P,pathogenic,1.0 45 | P53_HUMAN,Y163C,pathogenic,1.0 46 | P53_HUMAN,Y163D,pathogenic,1.0 47 | 
P53_HUMAN,Q165K,benign,0.0 48 | P53_HUMAN,V172F,pathogenic,1.0 49 | P53_HUMAN,V173A,pathogenic,1.0 50 | P53_HUMAN,V173E,pathogenic,1.0 51 | P53_HUMAN,V173L,pathogenic,1.0 52 | P53_HUMAN,V173M,pathogenic,1.0 53 | P53_HUMAN,R175G,pathogenic,1.0 54 | P53_HUMAN,R175H,pathogenic,1.0 55 | P53_HUMAN,R175L,pathogenic,1.0 56 | P53_HUMAN,C176F,pathogenic,1.0 57 | P53_HUMAN,C176Y,pathogenic,1.0 58 | P53_HUMAN,P177R,pathogenic,1.0 59 | P53_HUMAN,H178D,pathogenic,1.0 60 | P53_HUMAN,H179P,pathogenic,1.0 61 | P53_HUMAN,H179Q,pathogenic,1.0 62 | P53_HUMAN,H179Y,pathogenic,1.0 63 | P53_HUMAN,E180K,pathogenic,1.0 64 | P53_HUMAN,R181H,pathogenic,1.0 65 | P53_HUMAN,R181P,pathogenic,1.0 66 | P53_HUMAN,S185N,benign,0.0 67 | P53_HUMAN,H193P,pathogenic,1.0 68 | P53_HUMAN,H193Q,pathogenic,1.0 69 | P53_HUMAN,H193R,pathogenic,1.0 70 | P53_HUMAN,H193Y,pathogenic,1.0 71 | P53_HUMAN,L194F,pathogenic,1.0 72 | P53_HUMAN,I195T,pathogenic,1.0 73 | P53_HUMAN,G199V,pathogenic,1.0 74 | P53_HUMAN,Y205C,pathogenic,1.0 75 | P53_HUMAN,R213P,pathogenic,1.0 76 | P53_HUMAN,R213Q,pathogenic,1.0 77 | P53_HUMAN,H214R,pathogenic,1.0 78 | P53_HUMAN,S215I,pathogenic,1.0 79 | P53_HUMAN,V218M,pathogenic,1.0 80 | P53_HUMAN,Y220C,pathogenic,1.0 81 | P53_HUMAN,Y220S,pathogenic,1.0 82 | P53_HUMAN,I232N,pathogenic,1.0 83 | P53_HUMAN,I232T,pathogenic,1.0 84 | P53_HUMAN,Y234C,pathogenic,1.0 85 | P53_HUMAN,Y234H,pathogenic,1.0 86 | P53_HUMAN,N235D,pathogenic,1.0 87 | P53_HUMAN,N235S,benign,0.0 88 | P53_HUMAN,M237I,pathogenic,1.0 89 | P53_HUMAN,C238F,pathogenic,1.0 90 | P53_HUMAN,C238R,pathogenic,1.0 91 | P53_HUMAN,C238S,pathogenic,1.0 92 | P53_HUMAN,C238W,pathogenic,1.0 93 | P53_HUMAN,C238Y,pathogenic,1.0 94 | P53_HUMAN,N239S,pathogenic,1.0 95 | P53_HUMAN,S240R,pathogenic,1.0 96 | P53_HUMAN,S241C,pathogenic,1.0 97 | P53_HUMAN,S241F,pathogenic,1.0 98 | P53_HUMAN,S241Y,pathogenic,1.0 99 | P53_HUMAN,C242R,pathogenic,1.0 100 | P53_HUMAN,C242Y,pathogenic,1.0 101 | P53_HUMAN,G244D,pathogenic,1.0 102 | P53_HUMAN,G244S,pathogenic,1.0 103 | P53_HUMAN,G245A,pathogenic,1.0 104 | P53_HUMAN,G245C,pathogenic,1.0 105 | P53_HUMAN,G245D,pathogenic,1.0 106 | P53_HUMAN,G245S,pathogenic,1.0 107 | P53_HUMAN,G245V,pathogenic,1.0 108 | P53_HUMAN,M246K,pathogenic,1.0 109 | P53_HUMAN,M246V,pathogenic,1.0 110 | P53_HUMAN,R248L,pathogenic,1.0 111 | P53_HUMAN,R248P,pathogenic,1.0 112 | P53_HUMAN,R248Q,pathogenic,1.0 113 | P53_HUMAN,R248W,pathogenic,1.0 114 | P53_HUMAN,R249T,pathogenic,1.0 115 | P53_HUMAN,I251L,pathogenic,1.0 116 | P53_HUMAN,I251S,pathogenic,1.0 117 | P53_HUMAN,I254N,pathogenic,1.0 118 | P53_HUMAN,E258K,pathogenic,1.0 119 | P53_HUMAN,N263S,benign,0.0 120 | P53_HUMAN,L265P,pathogenic,1.0 121 | P53_HUMAN,G266E,pathogenic,1.0 122 | P53_HUMAN,R267W,pathogenic,1.0 123 | P53_HUMAN,F270S,pathogenic,1.0 124 | P53_HUMAN,R273C,pathogenic,1.0 125 | P53_HUMAN,R273G,pathogenic,1.0 126 | P53_HUMAN,R273H,pathogenic,1.0 127 | P53_HUMAN,R273L,pathogenic,1.0 128 | P53_HUMAN,R273P,pathogenic,1.0 129 | P53_HUMAN,R273S,pathogenic,1.0 130 | P53_HUMAN,C275R,pathogenic,1.0 131 | P53_HUMAN,C275W,pathogenic,1.0 132 | P53_HUMAN,C275Y,pathogenic,1.0 133 | P53_HUMAN,A276D,pathogenic,1.0 134 | P53_HUMAN,A276G,pathogenic,1.0 135 | P53_HUMAN,G279E,pathogenic,1.0 136 | P53_HUMAN,R280S,pathogenic,1.0 137 | P53_HUMAN,D281G,pathogenic,1.0 138 | P53_HUMAN,D281N,pathogenic,1.0 139 | P53_HUMAN,D281V,pathogenic,1.0 140 | P53_HUMAN,D281Y,pathogenic,1.0 141 | P53_HUMAN,R282G,pathogenic,1.0 142 | P53_HUMAN,R282P,pathogenic,1.0 143 | P53_HUMAN,R282W,pathogenic,1.0 144 | P53_HUMAN,E285K,pathogenic,1.0 
145 | P53_HUMAN,E285V,pathogenic,1.0 146 | P53_HUMAN,E286K,pathogenic,1.0 147 | P53_HUMAN,R290H,benign,0.0 148 | P53_HUMAN,G293W,benign,0.0 149 | P53_HUMAN,H296Y,benign,0.0 150 | P53_HUMAN,E298K,benign,0.0 151 | P53_HUMAN,T312S,benign,0.0 152 | P53_HUMAN,G334W,pathogenic,1.0 153 | P53_HUMAN,R337C,pathogenic,1.0 154 | P53_HUMAN,R337H,pathogenic,1.0 155 | P53_HUMAN,R337L,pathogenic,1.0 156 | P53_HUMAN,R337P,pathogenic,1.0 157 | P53_HUMAN,R342P,pathogenic,1.0 158 | P53_HUMAN,E346D,benign,0.0 159 | P53_HUMAN,A347D,pathogenic,1.0 160 | P53_HUMAN,G360A,benign,0.0 161 | P53_HUMAN,H365Y,benign,0.0 162 | P53_HUMAN,S366A,benign,0.0 163 | P53_HUMAN,G374R,benign,0.0 164 | P53_HUMAN,E388A,benign,0.0 165 | PTEN_HUMAN,M1V,pathogenic,1.0 166 | PTEN_HUMAN,N12I,pathogenic,1.0 167 | PTEN_HUMAN,N12T,pathogenic,1.0 168 | PTEN_HUMAN,K13E,pathogenic,1.0 169 | PTEN_HUMAN,R14G,pathogenic,1.0 170 | PTEN_HUMAN,R15K,pathogenic,1.0 171 | PTEN_HUMAN,R15S,pathogenic,1.0 172 | PTEN_HUMAN,Y16D,pathogenic,1.0 173 | PTEN_HUMAN,Y16H,pathogenic,1.0 174 | PTEN_HUMAN,D24G,pathogenic,1.0 175 | PTEN_HUMAN,D24N,pathogenic,1.0 176 | PTEN_HUMAN,D24Y,pathogenic,1.0 177 | PTEN_HUMAN,T26P,pathogenic,1.0 178 | PTEN_HUMAN,Y27C,pathogenic,1.0 179 | PTEN_HUMAN,Y27S,pathogenic,1.0 180 | PTEN_HUMAN,M35L,pathogenic,1.0 181 | PTEN_HUMAN,M35T,pathogenic,1.0 182 | PTEN_HUMAN,M35V,pathogenic,1.0 183 | PTEN_HUMAN,G36R,pathogenic,1.0 184 | PTEN_HUMAN,P38R,pathogenic,1.0 185 | PTEN_HUMAN,G44D,pathogenic,1.0 186 | PTEN_HUMAN,R47G,pathogenic,1.0 187 | PTEN_HUMAN,R47K,pathogenic,1.0 188 | PTEN_HUMAN,N48K,pathogenic,1.0 189 | PTEN_HUMAN,V53A,pathogenic,1.0 190 | PTEN_HUMAN,L57W,pathogenic,1.0 191 | PTEN_HUMAN,Y68C,pathogenic,1.0 192 | PTEN_HUMAN,Y68D,pathogenic,1.0 193 | PTEN_HUMAN,Y68H,pathogenic,1.0 194 | PTEN_HUMAN,Y68S,pathogenic,1.0 195 | PTEN_HUMAN,A79T,benign,0.0 196 | PTEN_HUMAN,F81C,pathogenic,1.0 197 | PTEN_HUMAN,F90S,pathogenic,1.0 198 | PTEN_HUMAN,D92A,pathogenic,1.0 199 | PTEN_HUMAN,D92G,pathogenic,1.0 200 | PTEN_HUMAN,H93N,pathogenic,1.0 201 | PTEN_HUMAN,H93R,pathogenic,1.0 202 | PTEN_HUMAN,H93Y,pathogenic,1.0 203 | PTEN_HUMAN,P95L,pathogenic,1.0 204 | PTEN_HUMAN,P95T,pathogenic,1.0 205 | PTEN_HUMAN,P96A,pathogenic,1.0 206 | PTEN_HUMAN,P96L,pathogenic,1.0 207 | PTEN_HUMAN,P96S,pathogenic,1.0 208 | PTEN_HUMAN,I101T,pathogenic,1.0 209 | PTEN_HUMAN,C105Y,pathogenic,1.0 210 | PTEN_HUMAN,D107V,pathogenic,1.0 211 | PTEN_HUMAN,L108P,pathogenic,1.0 212 | PTEN_HUMAN,L112P,pathogenic,1.0 213 | PTEN_HUMAN,V119L,pathogenic,1.0 214 | PTEN_HUMAN,A120E,pathogenic,1.0 215 | PTEN_HUMAN,H123D,pathogenic,1.0 216 | PTEN_HUMAN,H123Q,pathogenic,1.0 217 | PTEN_HUMAN,H123R,pathogenic,1.0 218 | PTEN_HUMAN,H123Y,pathogenic,1.0 219 | PTEN_HUMAN,C124G,pathogenic,1.0 220 | PTEN_HUMAN,C124R,pathogenic,1.0 221 | PTEN_HUMAN,C124S,pathogenic,1.0 222 | PTEN_HUMAN,C124Y,pathogenic,1.0 223 | PTEN_HUMAN,G127E,pathogenic,1.0 224 | PTEN_HUMAN,G127R,pathogenic,1.0 225 | PTEN_HUMAN,K128N,pathogenic,1.0 226 | PTEN_HUMAN,G129E,pathogenic,1.0 227 | PTEN_HUMAN,G129R,pathogenic,1.0 228 | PTEN_HUMAN,G129V,pathogenic,1.0 229 | PTEN_HUMAN,R130P,pathogenic,1.0 230 | PTEN_HUMAN,R130Q,pathogenic,1.0 231 | PTEN_HUMAN,G132D,pathogenic,1.0 232 | PTEN_HUMAN,G132S,pathogenic,1.0 233 | PTEN_HUMAN,G132V,pathogenic,1.0 234 | PTEN_HUMAN,M134I,pathogenic,1.0 235 | PTEN_HUMAN,M134R,pathogenic,1.0 236 | PTEN_HUMAN,M134T,pathogenic,1.0 237 | PTEN_HUMAN,I135K,pathogenic,1.0 238 | PTEN_HUMAN,I135V,pathogenic,1.0 239 | PTEN_HUMAN,C136R,pathogenic,1.0 240 | PTEN_HUMAN,C136W,pathogenic,1.0 241 | 
PTEN_HUMAN,C136Y,pathogenic,1.0 242 | PTEN_HUMAN,Y155C,pathogenic,1.0 243 | PTEN_HUMAN,Y155H,pathogenic,1.0 244 | PTEN_HUMAN,Y155S,pathogenic,1.0 245 | PTEN_HUMAN,R159G,pathogenic,1.0 246 | PTEN_HUMAN,R159M,pathogenic,1.0 247 | PTEN_HUMAN,T160I,pathogenic,1.0 248 | PTEN_HUMAN,D162E,pathogenic,1.0 249 | PTEN_HUMAN,G165R,pathogenic,1.0 250 | PTEN_HUMAN,V166E,pathogenic,1.0 251 | PTEN_HUMAN,S170I,pathogenic,1.0 252 | PTEN_HUMAN,S170R,pathogenic,1.0 253 | PTEN_HUMAN,Q171R,pathogenic,1.0 254 | PTEN_HUMAN,R173C,pathogenic,1.0 255 | PTEN_HUMAN,R173H,pathogenic,1.0 256 | PTEN_HUMAN,R173L,pathogenic,1.0 257 | PTEN_HUMAN,R173P,pathogenic,1.0 258 | PTEN_HUMAN,Y174C,pathogenic,1.0 259 | PTEN_HUMAN,Y174N,pathogenic,1.0 260 | PTEN_HUMAN,Y177H,pathogenic,1.0 261 | PTEN_HUMAN,F200V,pathogenic,1.0 262 | PTEN_HUMAN,T202I,pathogenic,1.0 263 | PTEN_HUMAN,P204A,pathogenic,1.0 264 | PTEN_HUMAN,F241L,pathogenic,1.0 265 | PTEN_HUMAN,F241S,pathogenic,1.0 266 | PTEN_HUMAN,P246L,pathogenic,1.0 267 | PTEN_HUMAN,L247S,pathogenic,1.0 268 | PTEN_HUMAN,D252G,pathogenic,1.0 269 | PTEN_HUMAN,K254T,pathogenic,1.0 270 | PTEN_HUMAN,V255E,pathogenic,1.0 271 | PTEN_HUMAN,D268E,benign,0.0 272 | PTEN_HUMAN,H272P,pathogenic,1.0 273 | PTEN_HUMAN,W274L,pathogenic,1.0 274 | PTEN_HUMAN,N276I,pathogenic,1.0 275 | PTEN_HUMAN,T277R,pathogenic,1.0 276 | PTEN_HUMAN,D326G,pathogenic,1.0 277 | PTEN_HUMAN,R335Q,pathogenic,1.0 278 | RASH_HUMAN,G12A,pathogenic,1.0 279 | RASH_HUMAN,G12C,pathogenic,1.0 280 | RASH_HUMAN,G12D,pathogenic,1.0 281 | RASH_HUMAN,G12S,pathogenic,1.0 282 | RASH_HUMAN,G12V,pathogenic,1.0 283 | RASH_HUMAN,G13C,pathogenic,1.0 284 | RASH_HUMAN,G13D,pathogenic,1.0 285 | RASH_HUMAN,G13V,pathogenic,1.0 286 | RASH_HUMAN,Q22K,pathogenic,1.0 287 | RASH_HUMAN,I46T,pathogenic,1.0 288 | RASH_HUMAN,T58I,pathogenic,1.0 289 | RASH_HUMAN,A59T,pathogenic,1.0 290 | RASH_HUMAN,G60D,pathogenic,1.0 291 | RASH_HUMAN,G60V,pathogenic,1.0 292 | RASH_HUMAN,Q61H,pathogenic,1.0 293 | RASH_HUMAN,Q61K,pathogenic,1.0 294 | RASH_HUMAN,Q61R,pathogenic,1.0 295 | RASH_HUMAN,E63K,pathogenic,1.0 296 | RASH_HUMAN,N86T,benign,0.0 297 | RASH_HUMAN,S89C,pathogenic,1.0 298 | RASH_HUMAN,K117R,pathogenic,1.0 299 | RASH_HUMAN,A146V,pathogenic,1.0 300 | RASH_HUMAN,P174S,benign,0.0 301 | RASH_HUMAN,P179A,benign,0.0 302 | SCN5A_HUMAN,M1L,pathogenic,1.0 303 | SCN5A_HUMAN,R34C,benign,0.0 304 | SCN5A_HUMAN,Y87C,pathogenic,1.0 305 | SCN5A_HUMAN,Q90K,benign,0.0 306 | SCN5A_HUMAN,R104Q,pathogenic,1.0 307 | SCN5A_HUMAN,R104W,pathogenic,1.0 308 | SCN5A_HUMAN,R121Q,pathogenic,1.0 309 | SCN5A_HUMAN,R121W,pathogenic,1.0 310 | SCN5A_HUMAN,N134S,benign,0.0 311 | SCN5A_HUMAN,G180V,pathogenic,1.0 312 | SCN5A_HUMAN,D197G,pathogenic,1.0 313 | SCN5A_HUMAN,D197H,pathogenic,1.0 314 | SCN5A_HUMAN,A204E,pathogenic,1.0 315 | SCN5A_HUMAN,R222G,pathogenic,1.0 316 | SCN5A_HUMAN,R222Q,pathogenic,1.0 317 | SCN5A_HUMAN,R225W,pathogenic,1.0 318 | SCN5A_HUMAN,L250V,pathogenic,1.0 319 | SCN5A_HUMAN,S262R,pathogenic,1.0 320 | SCN5A_HUMAN,R282H,pathogenic,1.0 321 | SCN5A_HUMAN,A286S,benign,0.0 322 | SCN5A_HUMAN,A286V,benign,0.0 323 | SCN5A_HUMAN,L299M,benign,0.0 324 | SCN5A_HUMAN,D356N,pathogenic,1.0 325 | SCN5A_HUMAN,D356Y,pathogenic,1.0 326 | SCN5A_HUMAN,R367C,pathogenic,1.0 327 | SCN5A_HUMAN,R367H,pathogenic,1.0 328 | SCN5A_HUMAN,R376H,pathogenic,1.0 329 | SCN5A_HUMAN,G400W,pathogenic,1.0 330 | SCN5A_HUMAN,N406K,pathogenic,1.0 331 | SCN5A_HUMAN,V411M,pathogenic,1.0 332 | SCN5A_HUMAN,T455A,benign,0.0 333 | SCN5A_HUMAN,L461V,benign,0.0 334 | SCN5A_HUMAN,L494F,benign,0.0 335 | 
SCN5A_HUMAN,S524Y,benign,0.0 336 | SCN5A_HUMAN,H558R,benign,0.0 337 | SCN5A_HUMAN,A572D,benign,0.0 338 | SCN5A_HUMAN,L618F,benign,0.0 339 | SCN5A_HUMAN,P656L,benign,0.0 340 | SCN5A_HUMAN,R680H,benign,0.0 341 | SCN5A_HUMAN,A735E,pathogenic,1.0 342 | SCN5A_HUMAN,A735V,pathogenic,1.0 343 | SCN5A_HUMAN,G752R,pathogenic,1.0 344 | SCN5A_HUMAN,M764T,pathogenic,1.0 345 | SCN5A_HUMAN,R814W,pathogenic,1.0 346 | SCN5A_HUMAN,T843A,pathogenic,1.0 347 | SCN5A_HUMAN,R878C,pathogenic,1.0 348 | SCN5A_HUMAN,H886P,pathogenic,1.0 349 | SCN5A_HUMAN,S910L,pathogenic,1.0 350 | SCN5A_HUMAN,F919S,pathogenic,1.0 351 | SCN5A_HUMAN,L939P,pathogenic,1.0 352 | SCN5A_HUMAN,A997S,pathogenic,1.0 353 | SCN5A_HUMAN,T1016M,benign,0.0 354 | SCN5A_HUMAN,P1089L,benign,0.0 355 | SCN5A_HUMAN,L1217R,pathogenic,1.0 356 | SCN5A_HUMAN,D1275N,pathogenic,1.0 357 | SCN5A_HUMAN,D1275Y,pathogenic,1.0 358 | SCN5A_HUMAN,R1303Q,pathogenic,1.0 359 | SCN5A_HUMAN,N1325S,pathogenic,1.0 360 | SCN5A_HUMAN,A1326S,pathogenic,1.0 361 | SCN5A_HUMAN,A1330T,pathogenic,1.0 362 | SCN5A_HUMAN,P1332L,pathogenic,1.0 363 | SCN5A_HUMAN,P1332R,pathogenic,1.0 364 | SCN5A_HUMAN,P1332S,pathogenic,1.0 365 | SCN5A_HUMAN,V1378M,pathogenic,1.0 366 | SCN5A_HUMAN,G1408R,pathogenic,1.0 367 | SCN5A_HUMAN,P1438S,pathogenic,1.0 368 | SCN5A_HUMAN,E1441Q,pathogenic,1.0 369 | SCN5A_HUMAN,M1446I,pathogenic,1.0 370 | SCN5A_HUMAN,Y1447H,pathogenic,1.0 371 | SCN5A_HUMAN,Y1449C,pathogenic,1.0 372 | SCN5A_HUMAN,S1458Y,pathogenic,1.0 373 | SCN5A_HUMAN,F1473L,pathogenic,1.0 374 | SCN5A_HUMAN,Q1475L,pathogenic,1.0 375 | SCN5A_HUMAN,T1488K,pathogenic,1.0 376 | SCN5A_HUMAN,D1595H,pathogenic,1.0 377 | SCN5A_HUMAN,D1595N,pathogenic,1.0 378 | SCN5A_HUMAN,G1605R,pathogenic,1.0 379 | SCN5A_HUMAN,T1620M,pathogenic,1.0 380 | SCN5A_HUMAN,R1623L,pathogenic,1.0 381 | SCN5A_HUMAN,R1623Q,pathogenic,1.0 382 | SCN5A_HUMAN,R1632C,pathogenic,1.0 383 | SCN5A_HUMAN,R1638P,pathogenic,1.0 384 | SCN5A_HUMAN,R1644H,pathogenic,1.0 385 | SCN5A_HUMAN,M1651V,pathogenic,1.0 386 | SCN5A_HUMAN,M1676T,pathogenic,1.0 387 | SCN5A_HUMAN,M1676V,pathogenic,1.0 388 | SCN5A_HUMAN,M1691T,pathogenic,1.0 389 | SCN5A_HUMAN,C1703Y,pathogenic,1.0 390 | SCN5A_HUMAN,S1710L,pathogenic,1.0 391 | SCN5A_HUMAN,D1714G,pathogenic,1.0 392 | SCN5A_HUMAN,D1714N,pathogenic,1.0 393 | SCN5A_HUMAN,G1743R,pathogenic,1.0 394 | SCN5A_HUMAN,S1744I,pathogenic,1.0 395 | SCN5A_HUMAN,F1760C,pathogenic,1.0 396 | SCN5A_HUMAN,V1763M,pathogenic,1.0 397 | SCN5A_HUMAN,M1766L,pathogenic,1.0 398 | SCN5A_HUMAN,I1768V,pathogenic,1.0 399 | SCN5A_HUMAN,Y1795C,pathogenic,1.0 400 | SCN5A_HUMAN,D1802G,pathogenic,1.0 401 | SCN5A_HUMAN,M1875T,pathogenic,1.0 402 | SCN5A_HUMAN,E1876K,pathogenic,1.0 403 | SCN5A_HUMAN,S1965C,pathogenic,1.0 404 | SCN5A_HUMAN,S1979C,pathogenic,1.0 405 | SCN5A_HUMAN,G1992A,benign,0.0 406 | SCN5A_HUMAN,P2006L,benign,0.0 407 | -------------------------------------------------------------------------------- /data/mappings/example_mapping.csv: -------------------------------------------------------------------------------- 1 | protein_name,msa_location,theta 2 | PTEN_HUMAN,PTEN_HUMAN_b1.0.a2m,0.2 -------------------------------------------------------------------------------- /data/mutations/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/data/mutations/.gitkeep -------------------------------------------------------------------------------- /data/weights/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/data/weights/.gitkeep -------------------------------------------------------------------------------- /examples/Step1_train_VAE.sh: -------------------------------------------------------------------------------- 1 | export MSA_data_folder='./data/MSA' 2 | export MSA_list='./data/mappings/example_mapping.csv' 3 | export MSA_weights_location='./data/weights' 4 | export VAE_checkpoint_location='./results/VAE_parameters' 5 | export model_name_suffix='Jan1_PTEN_example' 6 | export model_parameters_location='./EVE/default_model_params.json' 7 | export training_logs_location='./logs/' 8 | export protein_index=0 9 | 10 | python train_VAE.py \ 11 | --MSA_data_folder ${MSA_data_folder} \ 12 | --MSA_list ${MSA_list} \ 13 | --protein_index ${protein_index} \ 14 | --MSA_weights_location ${MSA_weights_location} \ 15 | --VAE_checkpoint_location ${VAE_checkpoint_location} \ 16 | --model_name_suffix ${model_name_suffix} \ 17 | --model_parameters_location ${model_parameters_location} \ 18 | --training_logs_location ${training_logs_location} -------------------------------------------------------------------------------- /examples/Step2_compute_evol_indices_all_singles.sh: -------------------------------------------------------------------------------- 1 | export MSA_data_folder='./data/MSA' 2 | export MSA_list='./data/mappings/example_mapping.csv' 3 | export MSA_weights_location='./data/weights' 4 | export VAE_checkpoint_location='./results/VAE_parameters' 5 | export model_name_suffix='Jan1_PTEN_example' 6 | export model_parameters_location='./EVE/default_model_params.json' 7 | export training_logs_location='./logs/' 8 | export protein_index=0 9 | 10 | export computation_mode='all_singles' 11 | export all_singles_mutations_folder='./data/mutations' 12 | export output_evol_indices_location='./results/evol_indices' 13 | export num_samples_compute_evol_indices=20000 14 | export batch_size=2048 15 | 16 | python compute_evol_indices.py \ 17 | --MSA_data_folder ${MSA_data_folder} \ 18 | --MSA_list ${MSA_list} \ 19 | --protein_index ${protein_index} \ 20 | --MSA_weights_location ${MSA_weights_location} \ 21 | --VAE_checkpoint_location ${VAE_checkpoint_location} \ 22 | --model_name_suffix ${model_name_suffix} \ 23 | --model_parameters_location ${model_parameters_location} \ 24 | --computation_mode ${computation_mode} \ 25 | --all_singles_mutations_folder ${all_singles_mutations_folder} \ 26 | --output_evol_indices_location ${output_evol_indices_location} \ 27 | --num_samples_compute_evol_indices ${num_samples_compute_evol_indices} \ 28 | --batch_size ${batch_size} -------------------------------------------------------------------------------- /examples/Step3_train_GMM_and_compute_EVE_scores_all_singles.sh: -------------------------------------------------------------------------------- 1 | export input_evol_indices_location='./results/evol_indices' 2 | export input_evol_indices_filename_suffix='_20000_samples' 3 | export protein_list='./data/mappings/example_mapping.csv' 4 | export output_eve_scores_location='./results/EVE_scores' 5 | export output_eve_scores_filename_suffix='Jan1_PTEN_example' 6 | 7 | export GMM_parameter_location='./results/GMM_parameters/Default_GMM_parameters' 8 | export GMM_parameter_filename_suffix='default' 9 | export protein_GMM_weight=0.3 10 | export plot_location='./results' 11 | export 
labels_file_location='./data/labels/PTEN_ClinVar_labels.csv' 12 | export default_uncertainty_threshold_file_location='./utils/default_uncertainty_threshold.json' 13 | 14 | python train_GMM_and_compute_EVE_scores.py \ 15 | --input_evol_indices_location ${input_evol_indices_location} \ 16 | --input_evol_indices_filename_suffix ${input_evol_indices_filename_suffix} \ 17 | --protein_list ${protein_list} \ 18 | --output_eve_scores_location ${output_eve_scores_location} \ 19 | --output_eve_scores_filename_suffix ${output_eve_scores_filename_suffix} \ 20 | --load_GMM_models \ 21 | --GMM_parameter_location ${GMM_parameter_location} \ 22 | --GMM_parameter_filename_suffix ${GMM_parameter_filename_suffix} \ 23 | --compute_EVE_scores \ 24 | --protein_GMM_weight ${protein_GMM_weight} \ 25 | --plot_histograms \ 26 | --plot_scores_vs_labels \ 27 | --plot_location ${plot_location} \ 28 | --labels_file_location ${labels_file_location} \ 29 | --default_uncertainty_threshold_file_location ${default_uncertainty_threshold_file_location} \ 30 | --verbose 31 | 32 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/logs/.gitkeep -------------------------------------------------------------------------------- /protein_env.yml: -------------------------------------------------------------------------------- 1 | name: protein_env 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - python=3.7 8 | - pytorch=1.7 9 | - cudatoolkit=11.0 10 | - scikit-learn=0.24.1 11 | - numpy=1.20.1 12 | - pandas=1.2.4 13 | - scipy=1.6.2 14 | - tqdm 15 | - matplotlib 16 | - seaborn -------------------------------------------------------------------------------- /results/EVE_scores/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/results/EVE_scores/.gitkeep -------------------------------------------------------------------------------- /results/GMM_parameters/Default_GMM_parameters/GMM_model_dictionary_default: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/results/GMM_parameters/Default_GMM_parameters/GMM_model_dictionary_default -------------------------------------------------------------------------------- /results/GMM_parameters/Default_GMM_parameters/GMM_pathogenic_cluster_index_dictionary_default: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/results/GMM_parameters/Default_GMM_parameters/GMM_pathogenic_cluster_index_dictionary_default -------------------------------------------------------------------------------- /results/VAE_parameters/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/results/VAE_parameters/.gitkeep -------------------------------------------------------------------------------- /results/evol_indices/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/results/evol_indices/.gitkeep -------------------------------------------------------------------------------- /results/plots_histograms/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/results/plots_histograms/.gitkeep -------------------------------------------------------------------------------- /results/plots_scores_vs_labels/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/EVE/740b0a72e9e40102c629bda07da9228baaef6844/results/plots_scores_vs_labels/.gitkeep -------------------------------------------------------------------------------- /train_GMM_and_compute_EVE_scores.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import argparse 5 | import pickle 6 | import tqdm 7 | import json 8 | from sklearn import mixture, linear_model, svm, gaussian_process 9 | 10 | from utils import performance_helpers as ph, plot_helpers 11 | 12 | if __name__=='__main__': 13 | parser = argparse.ArgumentParser(description='GMM fit and EVE scores computation') 14 | parser.add_argument('--input_evol_indices_location', type=str, help='Folder where all individual files with evolutionary indices are stored') 15 | parser.add_argument('--input_evol_indices_filename_suffix', type=str, default='', help='Suffix that was added when generating the evol indices files') 16 | parser.add_argument('--protein_list', type=str, help='List of proteins to be included (one per row)') 17 | parser.add_argument('--output_eve_scores_location', type=str, help='Folder where all EVE scores are stored') 18 | parser.add_argument('--output_eve_scores_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') 19 | 20 | parser.add_argument('--load_GMM_models', default=False, action='store_true', help='If True, load GMM model parameters. If False, train GMMs from evol indices files') 21 | parser.add_argument('--GMM_parameter_location', default=None, type=str, help='Folder where GMM objects are stored if loading / to be stored if we are re-training') 22 | parser.add_argument('--GMM_parameter_filename_suffix', default=None, type=str, help='Suffix of GMMs model files to load') 23 | parser.add_argument('--protein_GMM_weight', default=0.3, type=float, help='Value of global-local GMM mixing parameter') 24 | 25 | parser.add_argument('--compute_EVE_scores', default=False, action='store_true', help='Computes EVE scores and uncertainty metrics for all input protein mutations') 26 | parser.add_argument('--recompute_uncertainty_threshold', default=False, action='store_true', help='Recompute uncertainty thresholds based on all evol indices in file. 
Otherwise loads default threshold.') 27 | parser.add_argument('--default_uncertainty_threshold_file_location', default='./utils/default_uncertainty_threshold.json', type=str, help='Location of default uncertainty thresholds.') 28 | 29 | parser.add_argument('--plot_histograms', default=False, action='store_true', help='Plots all evol indices histograms with GMM fits') 30 | parser.add_argument('--plot_scores_vs_labels', default=False, action='store_true', help='Plots EVE scores Vs labels at each protein position') 31 | parser.add_argument('--labels_file_location', default=None, type=str, help='File with ground truth labels for all proteins of interest (e.g., ClinVar)') 32 | parser.add_argument('--plot_location', default=None, type=str, help='Location of the different plots') 33 | parser.add_argument('--verbose', action='store_true', help='Print detailed information during run') 34 | args = parser.parse_args() 35 | 36 | mapping_file = pd.read_csv(args.protein_list,low_memory=False) 37 | protein_list = np.unique(mapping_file['protein_name']) 38 | list_variables_to_keep=['protein_name','mutations','evol_indices'] 39 | all_evol_indices = pd.concat([pd.read_csv(args.input_evol_indices_location+os.sep+protein+args.input_evol_indices_filename_suffix+'.csv',low_memory=False)[list_variables_to_keep] \ 40 | for protein in protein_list if os.path.exists(args.input_evol_indices_location+os.sep+protein+args.input_evol_indices_filename_suffix+'.csv')], ignore_index=True) 41 | 42 | all_evol_indices = all_evol_indices.drop_duplicates() 43 | X_train = np.array(all_evol_indices['evol_indices']).reshape(-1, 1) 44 | if args.verbose: 45 | print("Training data size: "+str(len(X_train))) 46 | print("Number of distinct proteins in protein_list: "+str(len(np.unique(all_evol_indices['protein_name'])))) 47 | 48 | if args.load_GMM_models: 49 | dict_models = pickle.load( open( args.GMM_parameter_location+os.sep+'GMM_model_dictionary_'+args.GMM_parameter_filename_suffix, "rb" ) ) 50 | dict_pathogenic_cluster_index = pickle.load( open( args.GMM_parameter_location+os.sep+'GMM_pathogenic_cluster_index_dictionary_'+args.GMM_parameter_filename_suffix, "rb" ) ) 51 | else: 52 | dict_models = {} 53 | dict_pathogenic_cluster_index = {} 54 | if not os.path.exists(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix): 55 | os.makedirs(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix) 56 | GMM_stats_log_location=args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_stats_'+args.output_eve_scores_filename_suffix+'.csv' 57 | with open(GMM_stats_log_location, "a") as logs: 58 | logs.write("protein_name,weight_pathogenic,mean_pathogenic,mean_benign,std_dev_pathogenic,std_dev_benign\n") 59 | 60 | main_GMM = mixture.GaussianMixture(n_components=2, covariance_type='full',max_iter=1000,n_init=30,tol=1e-4) 61 | main_GMM.fit(X_train) 62 | 63 | dict_models['main'] = main_GMM 64 | pathogenic_cluster_index = np.argmax(np.array(main_GMM.means_).flatten()) #The pathogenic cluster is the cluster with higher mean value 65 | dict_pathogenic_cluster_index['main'] = pathogenic_cluster_index 66 | if args.verbose: 67 | inferred_params = main_GMM.get_params() 68 | print("Index of mixture component with highest mean: "+str(pathogenic_cluster_index)) 69 | print("Model parameters: "+str(inferred_params)) 70 | print("Mixture component weights: "+str(main_GMM.weights_)) 71 | print("Mixture component means: "+str(main_GMM.means_)) 72 | print("Cluster component cov: 
"+str(main_GMM.covariances_)) 73 | with open(GMM_stats_log_location, "a") as logs: 74 | logs.write(",".join(str(x) for x in [ 75 | 'main', np.array(main_GMM.weights_).flatten()[dict_pathogenic_cluster_index['main']], np.array(main_GMM.means_).flatten()[dict_pathogenic_cluster_index['main']], 76 | np.array(main_GMM.means_).flatten()[1 - dict_pathogenic_cluster_index['main']], np.sqrt(np.array(main_GMM.covariances_).flatten()[dict_pathogenic_cluster_index['main']]), 77 | np.sqrt(np.array(main_GMM.covariances_).flatten()[1 - dict_pathogenic_cluster_index['main']]) 78 | ])+"\n") 79 | 80 | if args.protein_GMM_weight > 0.0: 81 | for protein in tqdm.tqdm(protein_list, "Training all protein GMMs"): 82 | X_train_protein = np.array(all_evol_indices['evol_indices'][all_evol_indices.protein_name==protein]).reshape(-1, 1) 83 | if len(X_train_protein) > 0: #We have evol indices computed for protein on file 84 | protein_GMM = mixture.GaussianMixture(n_components=2,covariance_type='full',max_iter=1000,tol=1e-4,weights_init=main_GMM.weights_,means_init=main_GMM.means_,precisions_init=main_GMM.precisions_) 85 | protein_GMM.fit(X_train_protein) 86 | dict_models[protein] = protein_GMM 87 | dict_pathogenic_cluster_index[protein] = np.argmax(np.array(protein_GMM.means_).flatten()) 88 | with open(GMM_stats_log_location, "a") as logs: 89 | logs.write(",".join(str(x) for x in [ 90 | protein, np.array(protein_GMM.weights_).flatten()[dict_pathogenic_cluster_index[protein]], np.array(protein_GMM.means_).flatten()[dict_pathogenic_cluster_index[protein]], 91 | np.array(protein_GMM.means_).flatten()[1 - dict_pathogenic_cluster_index[protein]], np.sqrt(np.array(protein_GMM.covariances_).flatten()[dict_pathogenic_cluster_index[protein]]), 92 | np.sqrt(np.array(protein_GMM.covariances_).flatten()[1 - dict_pathogenic_cluster_index[protein]]) 93 | ])+"\n") 94 | else: 95 | if args.verbose: 96 | print("No evol indices for the protein: "+str(protein)+". 
Skipping.") 97 | 98 | pickle.dump(dict_models, open(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_model_dictionary_'+args.output_eve_scores_filename_suffix, 'wb')) 99 | pickle.dump(dict_pathogenic_cluster_index, open(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_pathogenic_cluster_index_dictionary_'+args.output_eve_scores_filename_suffix, 'wb')) 100 | 101 | if args.plot_histograms: 102 | if not os.path.exists(args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix): 103 | os.makedirs(args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix) 104 | plot_helpers.plot_histograms(all_evol_indices, dict_models, dict_pathogenic_cluster_index, args.protein_GMM_weight, args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix, args.output_eve_scores_filename_suffix, protein_list) 105 | 106 | if args.compute_EVE_scores: 107 | if args.protein_GMM_weight > 0.0: 108 | all_scores = all_evol_indices.copy() 109 | all_scores['EVE_scores'] = np.nan 110 | all_scores['EVE_classes_100_pct_retained'] = "" 111 | for protein in tqdm.tqdm(protein_list,"Scoring all protein mutations"): 112 | try: 113 | test_data_protein = all_scores[all_scores.protein_name==protein] 114 | X_test_protein = np.array(test_data_protein['evol_indices']).reshape(-1, 1) 115 | mutation_scores_protein = ph.compute_weighted_score_two_GMMs(X_pred=X_test_protein, 116 | main_model = dict_models['main'], 117 | protein_model=dict_models[protein], 118 | cluster_index_main = dict_pathogenic_cluster_index['main'], 119 | cluster_index_protein = dict_pathogenic_cluster_index[protein], 120 | protein_weight = args.protein_GMM_weight) 121 | gmm_class_protein = ph.compute_weighted_class_two_GMMs(X_pred=X_test_protein, 122 | main_model = dict_models['main'], 123 | protein_model=dict_models[protein], 124 | cluster_index_main = dict_pathogenic_cluster_index['main'], 125 | cluster_index_protein = dict_pathogenic_cluster_index[protein], 126 | protein_weight = args.protein_GMM_weight) 127 | gmm_class_label_protein = pd.Series(gmm_class_protein).map(lambda x: 'Pathogenic' if x == 1 else 'Benign') 128 | 129 | all_scores.loc[all_scores.protein_name==protein, 'EVE_scores'] = np.array(mutation_scores_protein) 130 | all_scores.loc[all_scores.protein_name==protein, 'EVE_classes_100_pct_retained'] = np.array(gmm_class_label_protein) 131 | except: 132 | print("Issues with protein: "+str(protein)+". 
Skipping.") 133 | else: 134 | all_scores = all_evol_indices.copy() 135 | mutation_scores = dict_models['main'].predict_proba(np.array(all_scores['evol_indices']).reshape(-1, 1)) 136 | all_scores['EVE_scores'] = mutation_scores[:,dict_pathogenic_cluster_index['main']] 137 | gmm_class = dict_models['main'].predict(np.array(all_scores['evol_indices']).reshape(-1, 1)) 138 | all_scores['EVE_classes_100_pct_retained'] = np.array(pd.Series(gmm_class).map(lambda x: 'Pathogenic' if x == dict_pathogenic_cluster_index['main'] else 'Benign')) 139 | 140 | len_before_drop_na = len(all_scores) 141 | all_scores = all_scores.dropna(subset=['EVE_scores']) 142 | len_after_drop_na = len(all_scores) 143 | 144 | if args.verbose: 145 | scores_stats = ph.compute_stats(all_scores['EVE_scores']) 146 | print("Score stats: "+str(scores_stats)) 147 | print("Dropped mutations due to missing EVE scores: "+str(len_after_drop_na-len_before_drop_na)) 148 | all_scores['uncertainty'] = ph.predictive_entropy_binary_classifier(all_scores['EVE_scores']) 149 | 150 | if args.recompute_uncertainty_threshold: 151 | uncertainty_cutoffs_deciles, _, _ = ph.compute_uncertainty_deciles(all_scores) 152 | uncertainty_cutoffs_quartiles, _, _ = ph.compute_uncertainty_quartiles(all_scores) 153 | if args.verbose: 154 | print("Uncertainty cutoffs deciles: "+str(uncertainty_cutoffs_deciles)) 155 | print("Uncertainty cutoffs quartiles: "+str(uncertainty_cutoffs_quartiles)) 156 | else: 157 | uncertainty_thresholds = json.load(open(args.default_uncertainty_threshold_file_location)) 158 | uncertainty_cutoffs_deciles = uncertainty_thresholds["deciles"] 159 | uncertainty_cutoffs_quartiles = uncertainty_thresholds["quartiles"] 160 | 161 | for decile in range(1,10): 162 | classification_name = 'EVE_classes_'+str((decile)*10)+"_pct_retained" 163 | all_scores[classification_name] = all_scores['EVE_classes_100_pct_retained'] 164 | all_scores.loc[all_scores['uncertainty'] > uncertainty_cutoffs_deciles[str(decile)], classification_name] = 'Uncertain' 165 | if args.verbose: 166 | print("Stats classification by uncertainty for decile #:"+str(decile)) 167 | print(all_scores[classification_name].value_counts(normalize=True)) 168 | if args.verbose: 169 | print("Stats classification by uncertainty for decile #:"+str(10)) 170 | print(all_scores['EVE_classes_100_pct_retained'].value_counts(normalize=True)) 171 | 172 | for quartile in [1,3]: 173 | classification_name = 'EVE_classes_'+str((quartile)*25)+"_pct_retained" 174 | all_scores[classification_name] = all_scores['EVE_classes_100_pct_retained'] 175 | all_scores.loc[all_scores['uncertainty'] > uncertainty_cutoffs_quartiles[str(quartile)], classification_name] = 'Uncertain' 176 | if args.verbose: 177 | print("Stats classification by uncertainty for quartile #:"+str(quartile)) 178 | print(all_scores[classification_name].value_counts(normalize=True)) 179 | 180 | all_scores.to_csv(args.output_eve_scores_location+os.sep+'all_EVE_scores_'+args.output_eve_scores_filename_suffix+'.csv', index=False) 181 | 182 | if args.plot_scores_vs_labels: 183 | labels_dataset=pd.read_csv(args.labels_file_location,low_memory=False) 184 | all_scores_mutations_with_labels = pd.merge(all_scores, labels_dataset[['protein_name','mutations','ClinVar_labels']], how='inner', on=['protein_name','mutations']) 185 | all_PB_scores = all_scores_mutations_with_labels[all_scores_mutations_with_labels.ClinVar_labels!=0.5].copy() 186 | if not 
os.path.exists(args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix): 187 | os.makedirs(args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix) 188 | for protein in tqdm.tqdm(protein_list,"Plot scores Vs labels"): 189 | plot_helpers.plot_scores_vs_labels(score_df=all_PB_scores[all_PB_scores.protein_name==protein], 190 | plot_location=args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix, 191 | output_eve_scores_filename_suffix=args.output_eve_scores_filename_suffix+'_'+protein, 192 | mutation_name='mutations', score_name="EVE_scores", label_name='ClinVar_labels') -------------------------------------------------------------------------------- /train_VAE.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import pandas as pd 4 | import json 5 | 6 | from EVE import VAE_model 7 | from utils import data_utils 8 | 9 | if __name__=='__main__': 10 | parser = argparse.ArgumentParser(description='VAE') 11 | parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored') 12 | parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name') 13 | parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file') 14 | parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored') 15 | parser.add_argument('--theta_reweighting', type=float, help='Parameter for MSA sequence re-weighting') 16 | parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored') 17 | parser.add_argument('--model_name_suffix', default='Jan1', type=str, help='model checkpoint name will be the protein name followed by this suffix') 18 | parser.add_argument('--model_parameters_location', type=str, help='Location of VAE model parameters') 19 | parser.add_argument('--training_logs_location', type=str, help='Location where training logs will be stored') 20 | args = parser.parse_args() 21 | 22 | mapping_file = pd.read_csv(args.MSA_list) 23 | protein_name = mapping_file['protein_name'][args.protein_index] 24 | msa_location = args.MSA_data_folder + os.sep + mapping_file['msa_location'][args.protein_index] 25 | print("Protein name: "+str(protein_name)) 26 | print("MSA file: "+str(msa_location)) 27 | 28 | if args.theta_reweighting is not None: 29 | theta = args.theta_reweighting 30 | else: 31 | try: 32 | theta = float(mapping_file['theta'][args.protein_index]) 33 | except: 34 | theta = 0.2 35 | print("Theta MSA re-weighting: "+str(theta)) 36 | 37 | data = data_utils.MSA_processing( 38 | MSA_location=msa_location, 39 | theta=theta, 40 | use_weights=True, 41 | weights_location=args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy' 42 | ) 43 | 44 | model_name = protein_name + "_" + args.model_name_suffix 45 | print("Model name: "+str(model_name)) 46 | 47 | model_params = json.load(open(args.model_parameters_location)) 48 | 49 | model = VAE_model.VAE_model( 50 | model_name=model_name, 51 | data=data, 52 | encoder_parameters=model_params["encoder_parameters"], 53 | decoder_parameters=model_params["decoder_parameters"], 54 | random_seed=42 55 | ) 56 | model = model.to(model.device) 57 | 58 | model_params["training_parameters"]['training_logs_location'] = args.training_logs_location 59 |
model_params["training_parameters"]['model_checkpoint_location'] = args.VAE_checkpoint_location 60 | 61 | print("Starting to train model: " + model_name) 62 | model.train_model(data=data, training_parameters=model_params["training_parameters"]) 63 | 64 | print("Saving model: " + model_name) 65 | model.save(model_checkpoint=model_params["training_parameters"]['model_checkpoint_location']+os.sep+model_name+"_final", 66 | encoder_parameters=model_params["encoder_parameters"], 67 | decoder_parameters=model_params["decoder_parameters"], 68 | training_parameters=model_params["training_parameters"] 69 | ) -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from collections import defaultdict 4 | import os 5 | import torch 6 | import tqdm 7 | 8 | class MSA_processing: 9 | def __init__(self, 10 | MSA_location="", 11 | theta=0.2, 12 | use_weights=True, 13 | weights_location="./data/weights", 14 | preprocess_MSA=True, 15 | threshold_sequence_frac_gaps=0.5, 16 | threshold_focus_cols_frac_gaps=0.3, 17 | remove_sequences_with_indeterminate_AA_in_focus_cols=True 18 | ): 19 | 20 | """ 21 | Parameters: 22 | - msa_location: (path) Location of the MSA data. Constraints on input MSA format: 23 | - focus_sequence is the first one in the MSA data 24 | - first line is structured as follows: ">focus_seq_name/start_pos-end_pos" (e.g., >SPIKE_SARS2/310-550) 25 | - corespondding sequence data located on following line(s) 26 | - then all other sequences follow with ">name" on first line, corresponding data on subsequent lines 27 | - theta: (float) Sequence weighting hyperparameter. Generally: Prokaryotic and eukaryotic families = 0.2; Viruses = 0.01 28 | - use_weights: (bool) If False, sets all sequence weights to 1. If True, checks weights_location -- if non empty uses that; 29 | otherwise compute weights from scratch and store them at weights_location 30 | - weights_location: (path) Location to load from/save to the sequence weights 31 | - preprocess_MSA: (bool) performs pre-processing of MSA to remove short fragments and positions that are not well covered. 
32 | - threshold_sequence_frac_gaps: (float, between 0 and 1) Threshold value to define fragments 33 | - sequences with a fraction of gap characters above threshold_sequence_frac_gaps are removed 34 | - default is set to 0.5 (i.e., fragments with 50% or more gaps are removed) 35 | - threshold_focus_cols_frac_gaps: (float, between 0 and 1) Threshold value to define focus columns 36 | - positions with a fraction of gap characters above threshold_focus_cols_frac_gaps will be set to lower case (and not included in the focus_cols) 37 | - default is set to 0.3 (i.e., focus positions are the ones with 30% of gaps or less, i.e., 70% or more residue occupancy) 38 | - remove_sequences_with_indeterminate_AA_in_focus_cols: (bool) Remove all sequences that have indeterminate AA (e.g., B, J, X, Z) at focus positions of the wild type 39 | """ 40 | np.random.seed(2021) 41 | self.MSA_location = MSA_location 42 | self.weights_location = weights_location 43 | self.theta = theta 44 | self.alphabet = "ACDEFGHIKLMNPQRSTVWY" 45 | self.use_weights = use_weights 46 | self.preprocess_MSA = preprocess_MSA 47 | self.threshold_sequence_frac_gaps = threshold_sequence_frac_gaps 48 | self.threshold_focus_cols_frac_gaps = threshold_focus_cols_frac_gaps 49 | self.remove_sequences_with_indeterminate_AA_in_focus_cols = remove_sequences_with_indeterminate_AA_in_focus_cols 50 | 51 | self.gen_alignment() 52 | self.create_all_singles() 53 | 54 | def gen_alignment(self): 55 | """ Read training alignment and store basics in class instance """ 56 | self.aa_dict = {} 57 | for i,aa in enumerate(self.alphabet): 58 | self.aa_dict[aa] = i 59 | 60 | self.seq_name_to_sequence = defaultdict(str) 61 | name = "" 62 | with open(self.MSA_location, "r") as msa_data: 63 | for i, line in enumerate(msa_data): 64 | line = line.rstrip() 65 | if line.startswith(">"): 66 | name = line 67 | if i==0: 68 | self.focus_seq_name = name 69 | else: 70 | self.seq_name_to_sequence[name] += line 71 | 72 | 73 | ## MSA pre-processing to remove inadequate columns and sequences 74 | if self.preprocess_MSA: 75 | msa_df = pd.DataFrame.from_dict(self.seq_name_to_sequence, orient='index', columns=['sequence']) 76 | # Data clean up 77 | msa_df.sequence = msa_df.sequence.apply(lambda x: x.replace(".","-")).apply(lambda x: ''.join([aa.upper() for aa in x])) 78 | # Remove columns that would be gaps in the wild type 79 | non_gap_wt_cols = [aa!='-' for aa in msa_df.sequence[self.focus_seq_name]] 80 | msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa for aa,non_gap_ind in zip(x, non_gap_wt_cols) if non_gap_ind])) 81 | assert 0.0 <= self.threshold_sequence_frac_gaps <= 1.0,"Invalid fragment filtering parameter" 82 | assert 0.0 <= self.threshold_focus_cols_frac_gaps <= 1.0,"Invalid focus position filtering parameter" 83 | msa_array = np.array([list(seq) for seq in msa_df.sequence]) 84 | gaps_array = np.array(list(map(lambda seq: [aa=='-' for aa in seq], msa_array))) 85 | # Identify fragments with too many gaps 86 | seq_gaps_frac = gaps_array.mean(axis=1) 87 | seq_below_threshold = seq_gaps_frac <= self.threshold_sequence_frac_gaps 88 | print("Proportion of sequences dropped due to fraction of gaps: "+str(round(float(1 - seq_below_threshold.sum()/seq_below_threshold.shape)*100,2))+"%") 89 | # Identify focus columns 90 | columns_gaps_frac = gaps_array[seq_below_threshold].mean(axis=0) 91 | index_cols_below_threshold = columns_gaps_frac <= self.threshold_focus_cols_frac_gaps 92 | print("Proportion of non-focus columns removed: "+str(round(float(1 -
index_cols_below_threshold.sum()/index_cols_below_threshold.shape)*100,2))+"%") 93 | # Lower case non focus cols and filter fragment sequences 94 | msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa.upper() if upper_case_ind else aa.lower() for aa, upper_case_ind in zip(x, index_cols_below_threshold)])) 95 | msa_df = msa_df[seq_below_threshold] 96 | # Overwrite seq_name_to_sequence with clean version 97 | self.seq_name_to_sequence = defaultdict(str) 98 | for seq_idx in range(len(msa_df['sequence'])): 99 | self.seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx] 100 | 101 | self.focus_seq = self.seq_name_to_sequence[self.focus_seq_name] 102 | self.focus_cols = [ix for ix, s in enumerate(self.focus_seq) if s == s.upper() and s!='-'] 103 | self.focus_seq_trimmed = [self.focus_seq[ix] for ix in self.focus_cols] 104 | self.seq_len = len(self.focus_cols) 105 | self.alphabet_size = len(self.alphabet) 106 | 107 | # Connect local sequence index with uniprot index (index shift inferred from 1st row of MSA) 108 | focus_loc = self.focus_seq_name.split("/")[-1] 109 | start,stop = focus_loc.split("-") 110 | self.focus_start_loc = int(start) 111 | self.focus_stop_loc = int(stop) 112 | self.uniprot_focus_col_to_wt_aa_dict \ 113 | = {idx_col+int(start):self.focus_seq[idx_col] for idx_col in self.focus_cols} 114 | self.uniprot_focus_col_to_focus_idx \ 115 | = {idx_col+int(start):idx_col for idx_col in self.focus_cols} 116 | 117 | # Move all letters to CAPS; keeps focus columns only 118 | for seq_name,sequence in self.seq_name_to_sequence.items(): 119 | sequence = sequence.replace(".","-") 120 | self.seq_name_to_sequence[seq_name] = [sequence[ix].upper() for ix in self.focus_cols] 121 | 122 | # Remove sequences that have indeterminate AA (e.g., B, J, X, Z) in the focus columns 123 | if self.remove_sequences_with_indeterminate_AA_in_focus_cols: 124 | alphabet_set = set(list(self.alphabet)) 125 | seq_names_to_remove = [] 126 | for seq_name,sequence in self.seq_name_to_sequence.items(): 127 | for letter in sequence: 128 | if letter not in alphabet_set and letter != "-": 129 | seq_names_to_remove.append(seq_name) 130 | continue 131 | seq_names_to_remove = list(set(seq_names_to_remove)) 132 | for seq_name in seq_names_to_remove: 133 | del self.seq_name_to_sequence[seq_name] 134 | 135 | # Encode the sequences 136 | print ("Encoding sequences") 137 | self.one_hot_encoding = np.zeros((len(self.seq_name_to_sequence.keys()),len(self.focus_cols),len(self.alphabet))) 138 | for i,seq_name in enumerate(self.seq_name_to_sequence.keys()): 139 | sequence = self.seq_name_to_sequence[seq_name] 140 | for j,letter in enumerate(sequence): 141 | if letter in self.aa_dict: 142 | k = self.aa_dict[letter] 143 | self.one_hot_encoding[i,j,k] = 1.0 144 | 145 | if self.use_weights: 146 | try: 147 | self.weights = np.load(file=self.weights_location) 148 | print("Loaded sequence weights from disk") 149 | except: 150 | print ("Computing sequence weights") 151 | list_seq = self.one_hot_encoding 152 | list_seq = list_seq.reshape((list_seq.shape[0], list_seq.shape[1] * list_seq.shape[2])) 153 | def compute_weight(seq): 154 | number_non_empty_positions = np.dot(seq,seq) 155 | if number_non_empty_positions>0: 156 | denom = np.dot(list_seq,seq) / np.dot(seq,seq) 157 | denom = np.sum(denom > 1 - self.theta) 158 | return 1/denom 159 | else: 160 | return 0.0 #return 0 weight if sequence is fully empty 161 | self.weights = np.array(list(map(compute_weight,list_seq))) 162 | np.save(file=self.weights_location, 
arr=self.weights) 163 | else: 164 | # If not using weights, use an isotropic weight matrix 165 | print("Not weighting sequence data") 166 | self.weights = np.ones(self.one_hot_encoding.shape[0]) 167 | 168 | self.Neff = np.sum(self.weights) 169 | self.num_sequences = self.one_hot_encoding.shape[0] 170 | 171 | print ("Neff =",str(self.Neff)) 172 | print ("Data Shape =",self.one_hot_encoding.shape) 173 | 174 | def create_all_singles(self): 175 | start_idx = self.focus_start_loc 176 | focus_seq_index = 0 177 | self.mutant_to_letter_pos_idx_focus_list = {} 178 | list_valid_mutations = [] 179 | # find all possible valid mutations that can be run with this alignment 180 | alphabet_set = set(list(self.alphabet)) 181 | for i,letter in enumerate(self.focus_seq): 182 | if letter in alphabet_set and letter != "-": 183 | for mut in self.alphabet: 184 | pos = start_idx+i 185 | if mut != letter: 186 | mutant = letter+str(pos)+mut 187 | self.mutant_to_letter_pos_idx_focus_list[mutant] = [letter, pos, focus_seq_index] 188 | list_valid_mutations.append(mutant) 189 | focus_seq_index += 1 190 | self.all_single_mutations = list_valid_mutations 191 | 192 | def save_all_singles(self, output_filename): 193 | with open(output_filename, "w") as output: 194 | output.write('mutations') 195 | for mutation in self.all_single_mutations: 196 | output.write('\n') 197 | output.write(mutation) -------------------------------------------------------------------------------- /utils/default_uncertainty_threshold.json: -------------------------------------------------------------------------------- 1 | { 2 | "deciles": { 3 | "1":0.2197550519787074, 4 | "2":0.3073271315004903, 5 | "3":0.380894592583295, 6 | "4":0.452448591753785, 7 | "5":0.5199203860577154, 8 | "6":0.5610673789803645, 9 | "7":0.6028917246467301, 10 | "8":0.6483370062758973, 11 | "9":0.678595588420331, 12 | "10":0.6931471766074847 13 | }, 14 | "quartiles": { 15 | "1":0.3443212858392093, 16 | "2":0.5199203860577154, 17 | "3":0.6272364669515783, 18 | "4":0.6931471766074847 19 | } 20 | } -------------------------------------------------------------------------------- /utils/performance_helpers.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_auc_score 2 | import numpy as np 3 | import pandas as pd 4 | 5 | def compute_stats(input_array): 6 | return { 7 | 'mean':input_array.mean(), 8 | 'std':input_array.std(), 9 | 'min':input_array.min(), 10 | 'max':input_array.max(), 11 | 'P1':np.percentile(input_array,1), 12 | 'P5':np.percentile(input_array,5), 13 | 'P10':np.percentile(input_array,10), 14 | 'P25':np.percentile(input_array,25), 15 | 'P33':np.percentile(input_array,33), 16 | 'P40':np.percentile(input_array,40), 17 | 'P45':np.percentile(input_array,45), 18 | 'P50':np.median(input_array), 19 | 'P55':np.percentile(input_array,55), 20 | 'P60':np.percentile(input_array,60), 21 | 'P66':np.percentile(input_array,66), 22 | 'P75':np.percentile(input_array,75), 23 | 'P90':np.percentile(input_array,90), 24 | 'P95':np.percentile(input_array,95), 25 | 'P99':np.percentile(input_array,99) 26 | } 27 | 28 | def compute_accuracy_with_uncertain(class_pred, labels): 29 | temp_df = pd.DataFrame({'class_pred': class_pred.copy(),'labels': labels.copy()}) 30 | initial_num_obs = len(temp_df['labels']) 31 | temp_df=temp_df[temp_df['class_pred'] != 'Uncertain'] 32 | filtered_num_obs = len(temp_df['labels']) 33 | temp_df['class_pred_bin'] = temp_df['class_pred'].map(lambda x: 1 if x == 'Pathogenic' else 0) 34 | 
correct_classification = (temp_df['class_pred_bin'] == temp_df['labels']).astype(int) 35 | accuracy = round(correct_classification.mean()*100,1) 36 | pct_mutations_kept = round(filtered_num_obs/float(initial_num_obs)*100,1) 37 | return accuracy, pct_mutations_kept 38 | 39 | def compute_AUC_overall_with_uncertain(scores, class_pred, labels): 40 | temp_df = pd.DataFrame({'class_pred': class_pred.copy(),'labels': labels.copy(), 'scores': scores.copy()}) 41 | temp_df=temp_df[temp_df['class_pred'] != 'Uncertain'] 42 | AUC = roc_auc_score(y_true=temp_df['labels'], y_score=temp_df['scores']) 43 | return round(AUC*100,1) 44 | 45 | def compute_avg_protein_level_AUC_with_uncertain(scores, class_pred, labels, protein_ID): 46 | temp_df = pd.DataFrame({'class_pred': class_pred.copy(),'labels': labels.copy(), 'scores': scores.copy(), 'protein_ID': protein_ID.copy()}) 47 | temp_df=temp_df[temp_df['class_pred'] != 'Uncertain'] 48 | def compute_auc_group(group): 49 | protein_scores = group['scores'] 50 | protein_labels = group['labels'] 51 | try: 52 | result = roc_auc_score(y_true=protein_labels, y_score=protein_scores) 53 | except: 54 | result = np.nan 55 | return result 56 | protein_level_AUC = temp_df.groupby('protein_ID').apply(compute_auc_group) 57 | avg_AUC = protein_level_AUC.mean(skipna=True) 58 | return round(avg_AUC*100,1) 59 | 60 | def compute_pathogenic_rate_with_uncertain(class_pred, labels): 61 | temp_df = pd.DataFrame({'class_pred': class_pred.copy(),'labels': labels.copy()}) 62 | temp_df=temp_df[temp_df['class_pred'] != 'Uncertain'] 63 | rate = len(temp_df[temp_df['class_pred'] == 'Pathogenic']) / float(len(temp_df)) 64 | return round(rate*100,1) 65 | 66 | def compute_uncertainty_deciles(score_dataframe, score_name="EVE_scores", uncertainty_name='uncertainty', suffix=''): 67 | uncertainty_deciles_name='uncertainty_deciles'+suffix 68 | score_dataframe[uncertainty_deciles_name] = pd.qcut(score_dataframe[uncertainty_name], q=10, labels=range(1,11)).astype(int) 69 | uncertainty_cutoffs_deciles={} 70 | scores_at_uncertainty_deciles_cuttoffs_UB_lower_part={} 71 | scores_at_uncertainty_deciles_cuttoffs_LB_upper_part={} 72 | for decile in range(1,11): 73 | uncertainty_cutoffs_deciles[str(decile)]= np.max(score_dataframe[uncertainty_name][score_dataframe[uncertainty_deciles_name] == decile]) 74 | scores_at_uncertainty_deciles_cuttoffs_UB_lower_part[str(decile)]= np.max(score_dataframe[score_name][(score_dataframe[uncertainty_deciles_name] == decile) & (score_dataframe[score_name] < 0.5)]) 75 | scores_at_uncertainty_deciles_cuttoffs_LB_upper_part[str(decile)]= np.min(score_dataframe[score_name][(score_dataframe[uncertainty_deciles_name] == decile) & (score_dataframe[score_name] > 0.5)]) 76 | return uncertainty_cutoffs_deciles, scores_at_uncertainty_deciles_cuttoffs_UB_lower_part, scores_at_uncertainty_deciles_cuttoffs_LB_upper_part 77 | 78 | def compute_uncertainty_quartiles(score_dataframe, score_name="EVE_scores", uncertainty_name='uncertainty', suffix=''): 79 | uncertainty_deciles_name='uncertainty_quartiles'+suffix 80 | score_dataframe[uncertainty_deciles_name] = pd.qcut(score_dataframe[uncertainty_name], q=4, labels=range(1,5)).astype(int) 81 | uncertainty_cutoffs_quartiles={} 82 | scores_at_uncertainty_quartiles_cuttoffs_UB_lower_part={} 83 | scores_at_uncertainty_quartiles_cuttoffs_LB_upper_part={} 84 | for quartile in range(1,5): 85 | uncertainty_cutoffs_quartiles[str(quartile)]= np.max(score_dataframe[uncertainty_name][score_dataframe[uncertainty_deciles_name] == quartile]) 86 | 
scores_at_uncertainty_quartiles_cuttoffs_UB_lower_part[str(quartile)]= np.max(score_dataframe[score_name][(score_dataframe[uncertainty_deciles_name] == quartile) & (score_dataframe[score_name] < 0.5)]) 87 | scores_at_uncertainty_quartiles_cuttoffs_LB_upper_part[str(quartile)]= np.min(score_dataframe[score_name][(score_dataframe[uncertainty_deciles_name] == quartile) & (score_dataframe[score_name] > 0.5)]) 88 | return uncertainty_cutoffs_quartiles, scores_at_uncertainty_quartiles_cuttoffs_UB_lower_part, scores_at_uncertainty_quartiles_cuttoffs_LB_upper_part 89 | 90 | def compute_performance_by_uncertainty_decile(score_dataframe, metric="Accuracy", verbose=False, score_name="EVE_scores", uncertainty_name="uncertainty", label_name='ClinVar_labels', protein_name='protein_name', class_100pct_retained_name='EVE_classes_100_pct_retained', suffix=''): 91 | uncertainty_cutoffs_deciles, scores_at_uncertainty_deciles_cuttoffs_UB_lower_part, scores_at_uncertainty_deciles_cuttoffs_LB_upper_part = compute_uncertainty_deciles(score_dataframe, score_name, uncertainty_name, suffix) 92 | performance_by_uncertainty_deciles = {} 93 | pathogenic_rate_by_uncertainty_deciles = {} 94 | for decile in range(1,11): 95 | classification_name = 'class_pred_removing_'+str((10-decile)*10)+"_pct_most_uncertain"+suffix 96 | score_dataframe[classification_name] = score_dataframe[class_100pct_retained_name] 97 | score_dataframe.loc[score_dataframe['uncertainty_deciles'+suffix] > decile, classification_name] = 'Uncertain' 98 | if metric=="Accuracy": 99 | performance_decile = compute_accuracy_with_uncertain(score_dataframe[classification_name], score_dataframe[label_name])[0] 100 | elif metric =="Avg_AUC": 101 | performance_decile = compute_avg_protein_level_AUC_with_uncertain(scores=score_dataframe[score_name], class_pred=score_dataframe[classification_name], labels=score_dataframe[label_name], protein_ID=score_dataframe[protein_name]) 102 | performance_by_uncertainty_deciles[decile] = performance_decile 103 | pathogenic_rate_by_uncertainty_deciles[decile] = compute_pathogenic_rate_with_uncertain(class_pred=score_dataframe[classification_name], labels=score_dataframe[label_name]) 104 | if verbose: 105 | print(str(metric)+" when dropping the "+str((10-decile)*10)+"% of cases with highest uncertainty:\t"+str(performance_by_uncertainty_deciles[decile])+"% \t with pathogenic rate of "+str(pathogenic_rate_by_uncertainty_deciles[decile])+"%\n") 106 | print("Uncertainty decile #"+str(decile)+" cutoff: "+str(uncertainty_cutoffs_deciles[str(decile)])+"\n") 107 | print("Score upper bound for lower part in uncertainty decile: "+str(scores_at_uncertainty_deciles_cuttoffs_UB_lower_part[str(decile)])+"\n") 108 | print("Score lower bound for higher part in uncertainty decile: "+str(scores_at_uncertainty_deciles_cuttoffs_LB_upper_part[str(decile)])+"\n") 109 | return performance_by_uncertainty_deciles, pathogenic_rate_by_uncertainty_deciles 110 | 111 | def compute_performance_by_uncertainty_quartile(score_dataframe, metric="Accuracy", verbose=False, score_name="EVE_scores", uncertainty_name="uncertainty", label_name='ClinVar_labels', protein_name='protein_name', class_100pct_retained_name='EVE_classes_100_pct_retained', suffix=''): 112 | uncertainty_cutoffs_quartiles, scores_at_uncertainty_quartiles_cuttoffs_UB_lower_part, scores_at_uncertainty_quartiles_cuttoffs_LB_upper_part = compute_uncertainty_quartiles(score_dataframe, score_name, uncertainty_name, suffix) 113 | performance_by_uncertainty_quartiles = {} 114 | 
pathogenic_rate_by_uncertainty_quartiles = {} 115 | for quartile in range(1,5): 116 | classification_name = 'class_pred_removing_'+str((4-quartile)*25)+"_pct_most_uncertain"+suffix 117 | score_dataframe[classification_name] = score_dataframe[class_100pct_retained_name] 118 | score_dataframe.loc[score_dataframe['uncertainty_quartiles'+suffix] > quartile, classification_name] = 'Uncertain' 119 | if metric=="Accuracy": 120 | performance_quartile = compute_accuracy_with_uncertain(score_dataframe[classification_name], score_dataframe[label_name])[0] 121 | elif metric =="Avg_AUC": 122 | performance_quartile = compute_avg_protein_level_AUC_with_uncertain(scores=score_dataframe[score_name], class_pred=score_dataframe[classification_name], labels=score_dataframe[label_name], protein_ID=score_dataframe[protein_name]) 123 | performance_by_uncertainty_quartiles[quartile] = performance_quartile 124 | pathogenic_rate_by_uncertainty_quartiles[quartile] = compute_pathogenic_rate_with_uncertain(class_pred=score_dataframe[classification_name], labels=score_dataframe[label_name]) 125 | if verbose: 126 | print(str(metric)+" when dropping the "+str((4-quartile)*25)+"% of cases with highest uncertainty:\t"+str(performance_by_uncertainty_quartiles[quartile])+"% \t with pathogenic rate of "+str(pathogenic_rate_by_uncertainty_quartiles[quartile])+"%\n") 127 | print("Uncertainty quartile #"+str(quartile)+" cutoff: "+str(uncertainty_cutoffs_quartiles[str(quartile)])+"\n") 128 | print("Score upper bound for lower part in uncertainty quartile: "+str(scores_at_uncertainty_quartiles_cuttoffs_UB_lower_part[str(quartile)])+"\n") 129 | print("Score lower bound for higher part in uncertainty quartile: "+str(scores_at_uncertainty_quartiles_cuttoffs_LB_upper_part[str(quartile)])+"\n") 130 | return performance_by_uncertainty_quartiles, pathogenic_rate_by_uncertainty_quartiles 131 | 132 | def predictive_entropy_binary_classifier(class1_scores, eps=1e-8): 133 | class1_scores = pd.Series(class1_scores).map(lambda x: x - eps if x==1.0 else x + eps if x==0 else x) 134 | class0_scores = 1 - class1_scores 135 | return - np.array((np.log(class1_scores) * class1_scores + np.log(class0_scores) * class0_scores)) 136 | 137 | def compute_weighted_score_two_GMMs(X_pred, main_model, protein_model, cluster_index_main, cluster_index_protein, protein_weight): 138 | return protein_model.predict_proba(X_pred)[:,cluster_index_protein] * protein_weight + (main_model.predict_proba(X_pred)[:,cluster_index_main]) * (1 - protein_weight) 139 | 140 | def compute_weighted_class_two_GMMs(X_pred, main_model, protein_model, cluster_index_main, cluster_index_protein, protein_weight): 141 | """By construct, 1 is always index of pathogenic, 0 always that of benign""" 142 | proba_pathogenic = protein_model.predict_proba(X_pred)[:,cluster_index_protein] * protein_weight + (main_model.predict_proba(X_pred)[:,cluster_index_main]) * (1 - protein_weight) 143 | return (proba_pathogenic > 0.5).astype(int) 144 | -------------------------------------------------------------------------------- /utils/plot_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | def plot_histograms(all_evol_indices, dict_models, dict_pathogenic_cluster_index, protein_GMM_weight, plot_location, output_eve_scores_filename_suffix, protein_list): 8 | x = np.linspace(-10, 20, 2000) 9 | logprob = 
dict_models['main'].score_samples(x.reshape(-1,1)) 10 | pdf = np.exp(logprob) 11 | component_share = dict_models['main'].predict_proba(x.reshape(-1, 1)) 12 | pdf_pathogenic = component_share[:,dict_pathogenic_cluster_index['main']] * pdf 13 | pdf_benign = component_share[:,1 - dict_pathogenic_cluster_index['main']] * pdf 14 | plt.plot(x,pdf, '--k', color='black') 15 | plt.plot(x,pdf_pathogenic, '--k', color = 'xkcd:red',linewidth=4) 16 | plt.plot(x,pdf_benign, '--k', color = 'xkcd:sky blue',linewidth=4) 17 | plt.hist(all_evol_indices['evol_indices'], color = 'xkcd:grey', bins = 80, histtype='stepfilled', alpha=0.4, density=True) 18 | plt.xlabel("Evolutionary index", fontsize=13) 19 | plt.ylabel("% of variants", fontsize=13) 20 | plt.xticks(fontsize=10) 21 | plt.yticks(fontsize=10) 22 | plt.savefig(plot_location+os.sep+'histogram_random_samples_'+str(output_eve_scores_filename_suffix)+"_all.png", dpi=800, bbox_inches='tight') 23 | plt.clf() 24 | if protein_GMM_weight > 0.0: 25 | for protein in tqdm.tqdm(protein_list,"Plot protein histograms"): 26 | x = np.linspace(-10, 20, 2000) 27 | logprob = dict_models[protein].score_samples(x.reshape(-1,1)) 28 | pdf = np.exp(logprob) 29 | component_share = dict_models[protein].predict_proba(x.reshape(-1, 1)) 30 | pdf_pathogenic = component_share[:,dict_pathogenic_cluster_index[protein]] * pdf 31 | pdf_benign = component_share[:, 1 - dict_pathogenic_cluster_index[protein]] * pdf 32 | plt.plot(x,pdf, '--k', color='black') 33 | plt.plot(x,pdf_pathogenic, '--k', color = 'xkcd:red',linewidth=4) 34 | plt.plot(x,pdf_benign, '--k', color = 'xkcd:sky blue',linewidth=4) 35 | plt.hist(all_evol_indices['evol_indices'][all_evol_indices['protein_name']==protein], color = 'xkcd:grey', bins = 80, histtype='stepfilled', alpha=0.4, density=True) 36 | plt.xlabel("Evolutionary index", fontsize=13) 37 | plt.ylabel("% of variants", fontsize=13) 38 | plt.xticks(fontsize=10) 39 | plt.yticks(fontsize=10) 40 | plt.savefig(plot_location+os.sep+'histogram_random_samples_'+str(output_eve_scores_filename_suffix)+"_"+str(protein)+".png", dpi=800, bbox_inches='tight') 41 | plt.clf() 42 | 43 | def plot_scores_vs_labels(score_df, plot_location, output_eve_scores_filename_suffix, mutation_name='mutations', score_name="EVE_scores", label_name='labels'): 44 | score_df_local = score_df.copy() 45 | score_df_local = score_df_local[score_df_local[mutation_name] !='w-1t'] #Remove wild type sequence 46 | score_df_local['mutation_position'] = score_df[mutation_name].map(lambda x: int(x[1:-1])) 47 | labels = score_df_local[label_name] 48 | pathogenic = plt.scatter(x=score_df_local['mutation_position'][labels==1], y=score_df_local[score_name][labels==1], color='xkcd:red') 49 | benign = plt.scatter(x=score_df_local['mutation_position'][labels==0], y=score_df_local[score_name][labels==0], color='xkcd:sky blue') 50 | plt.legend([pathogenic,benign],['pathogenic','benign']) 51 | plt.savefig(plot_location+os.sep+'scores_vs_labels_plots_'+str(output_eve_scores_filename_suffix)+".png", dpi=400, bbox_inches='tight') 52 | plt.clf() --------------------------------------------------------------------------------
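
A minimal, self-contained sketch of the global GMM step in train_GMM_and_compute_EVE_scores.py above, assuming synthetic evolutionary indices in place of real VAE outputs (the two modes, their parameters, and the query indices below are fabricated for illustration; the estimator settings and the higher-mean-is-pathogenic rule follow the script):

import numpy as np
from sklearn import mixture

rng = np.random.default_rng(0)
# Synthetic stand-in for the pooled 'evol_indices' column: a benign-like
# mode near 0 and a pathogenic-like mode near 7.
X_train = np.concatenate([rng.normal(0.0, 1.5, 5000),
                          rng.normal(7.0, 2.0, 5000)]).reshape(-1, 1)

main_GMM = mixture.GaussianMixture(n_components=2, covariance_type='full',
                                   max_iter=1000, n_init=30, tol=1e-4)
main_GMM.fit(X_train)

# As in the script, the pathogenic cluster is the component with the
# higher mean evolutionary index.
pathogenic_cluster_index = np.argmax(np.array(main_GMM.means_).flatten())

# An EVE-style score is the posterior probability of the pathogenic
# component at a variant's evolutionary index.
new_indices = np.array([[-1.0], [3.5], [9.0]])
print(main_GMM.predict_proba(new_indices)[:, pathogenic_cluster_index])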
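
When --protein_GMM_weight is positive, the script warm-starts one GMM per protein from the global fit and blends the two posteriors. A sketch of that logic under the same synthetic-data assumption (X_global, X_protein and X_pred are fabricated; the warm start and the 0.3 default weight mirror the script):

import numpy as np
from sklearn import mixture

rng = np.random.default_rng(1)
X_global = np.concatenate([rng.normal(0.0, 1.5, 4000),
                           rng.normal(7.0, 2.0, 4000)]).reshape(-1, 1)
X_protein = np.concatenate([rng.normal(1.0, 1.5, 300),
                            rng.normal(6.0, 2.0, 300)]).reshape(-1, 1)

main_GMM = mixture.GaussianMixture(n_components=2, covariance_type='full',
                                   max_iter=1000, n_init=30, tol=1e-4).fit(X_global)
# Protein-level GMM initialised from the global fit, as in the script.
protein_GMM = mixture.GaussianMixture(n_components=2, covariance_type='full',
                                      max_iter=1000, tol=1e-4,
                                      weights_init=main_GMM.weights_,
                                      means_init=main_GMM.means_,
                                      precisions_init=main_GMM.precisions_).fit(X_protein)

idx_main = int(np.argmax(main_GMM.means_.flatten()))
idx_protein = int(np.argmax(protein_GMM.means_.flatten()))

# Convex combination of the two pathogenic posteriors, as in
# utils/performance_helpers.compute_weighted_score_two_GMMs.
protein_weight = 0.3
X_pred = np.array([[2.0], [8.0]])
scores = (protein_weight * protein_GMM.predict_proba(X_pred)[:, idx_protein]
          + (1 - protein_weight) * main_GMM.predict_proba(X_pred)[:, idx_main])
print(scores)

The blend trades protein-specific resolution (the local fit) against statistical stability (the global fit pooled over many proteins), with protein_GMM_weight setting the balance.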
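
Uncertainty is the binary predictive entropy of the EVE score, and classifications above a decile cutoff are relabelled 'Uncertain'. A sketch with three made-up scores, assuming the weighted-class rule (score above 0.5 means pathogenic); the cutoff is the ninth-decile value shipped in utils/default_uncertainty_threshold.json, and np.clip stands in for the per-element eps shift of performance_helpers.predictive_entropy_binary_classifier (equivalent for scores strictly between 0 and 1):

import numpy as np

def predictive_entropy(p, eps=1e-8):
    # Binary predictive entropy; the clip guards exact 0/1 scores.
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    return -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))

scores = np.array([0.02, 0.48, 0.97])  # made-up EVE scores
uncertainty = predictive_entropy(scores)

# 'EVE_classes_90_pct_retained': drop the 10% most uncertain calls using
# the ninth-decile cutoff from default_uncertainty_threshold.json.
cutoff_90_pct_retained = 0.678595588420331
classes = np.where(scores > 0.5, 'Pathogenic', 'Benign')
classes = np.where(uncertainty > cutoff_90_pct_retained, 'Uncertain', classes)
print(list(zip(scores, uncertainty.round(3), classes)))
# The middling 0.48 score has entropy ~0.692 and is marked 'Uncertain'.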
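
The sequence weights behind Neff in utils/data_utils.py give each alignment member a weight of 1 over the number of sequences (itself included) whose normalized similarity to it exceeds 1 - theta. A toy sketch of compute_weight, assuming a fabricated one-hot alignment of four sequences over three positions:

import numpy as np

theta = 0.2  # default re-weighting hyperparameter for non-viral families
# Fabricated one-hot alignment, flattened to (num_seqs, seq_len * alphabet
# size) as in MSA_processing.gen_alignment.
list_seq = np.array([
    [1, 0, 0,  0, 1, 0,  0, 0, 1],
    [1, 0, 0,  0, 1, 0,  0, 0, 1],  # identical to the first sequence
    [1, 0, 0,  0, 1, 0,  0, 1, 0],  # differs at one of three positions
    [0, 1, 0,  1, 0, 0,  0, 1, 0],  # mostly different
], dtype=float)

def compute_weight(seq):
    # A sequence is a neighbour when its fraction of identical positions
    # exceeds 1 - theta; the weight is the reciprocal neighbour count.
    number_non_empty_positions = np.dot(seq, seq)
    if number_non_empty_positions > 0:
        similarity = np.dot(list_seq, seq) / number_non_empty_positions
        return 1.0 / np.sum(similarity > 1 - theta)
    return 0.0  # fully-gapped sequences get zero weight

weights = np.array([compute_weight(seq) for seq in list_seq])
print(weights, "Neff =", weights.sum())  # [0.5 0.5 1. 1.] Neff = 3.0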
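
Finally, create_all_singles in utils/data_utils.py enumerates every single amino-acid substitution over the focus columns, naming mutants as wild-type letter, UniProt position, substitute letter. A compact sketch of that enumeration, assuming a made-up three-residue focus sequence starting at UniProt position 1:

alphabet = "ACDEFGHIKLMNPQRSTVWY"

def all_single_mutations(focus_seq, start_idx):
    # One mutant code per (position, substitute) pair, e.g. 'M1V' for
    # M -> V at UniProt position 1, skipping gap/non-focus characters.
    muts = []
    for i, wt in enumerate(focus_seq):
        if wt in alphabet:
            muts.extend(wt + str(start_idx + i) + mut
                        for mut in alphabet if mut != wt)
    return muts

print(all_single_mutations("MKT", start_idx=1)[:5])
# ['M1A', 'M1C', 'M1D', 'M1E', 'M1F']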