├── figure_1.png ├── src ├── gpu_1.sh ├── args.py ├── utils.py ├── utils_vamp.py ├── train.py ├── model.py └── layers.py ├── README.md └── Example_TrpCage_2dEmbedding.ipynb /figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghorbanimahdi73/GraphVampNet/HEAD/figure_1.png -------------------------------------------------------------------------------- /src/gpu_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ##SBATCH --ntasks=28 3 | #SBATCH --job-name=trp_1 4 | #SBATCH --time=24:0:0 5 | #SBATCH -N 1 6 | #SBATCH --partition=v100 7 | #SBATCH --gres=gpu:v100:1 8 | #SBATCH --ntasks-per-core=1 9 | ##SBATCH --mem=200g 10 | # load necessary module 11 | 12 | #module load cuDNN/7.6.5/CUDA-10.1 13 | module laod cuda/10.2 14 | 15 | pwd=$PWD 16 | source ~/.bashrc 17 | cd $pwd 18 | 19 | conda activate koopnet 20 | 21 | for i in {1..10};do 22 | python train.py --epochs 100 --batch-size 1000 --num-atoms 20 --num-classes 5 --save-folder logs_$i --h_a 16 --num_neighbors 7 --n_conv 4 --h_g 2 --conv_type SchNet --dmin 0. --dmax 8. --step 0.5 --tau 5 --train --dist-data ../dists_trpcage_ca_7nbrs_1ns.npz --nbr-data ../inds_trpcage_ca_7nbrs_1ns.npz --residual 23 | done 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphVampNet 2 | 3 | ![figure](figure_1.png) 4 | 5 | This repo contains the code for **GraphVAMPNet** 6 | 7 | 8 | ## GraphVAMPNet code 9 | 10 | ## Usage 11 | 12 | ### training 13 | ``` 14 | python train.py --epochs 100 --batch-size 1000 --lr 0.0005 --hidden 16 15 | --num-atoms 20 --num-classes 5 --num_neighbors 7 --conv_type SchNet --dmin 0 16 | --dmax 8. 
--step 0.5 --dist-data dists.dat --nbr-data nbrs.dat --residual --train 17 | ``` 18 | ### testing 19 | 20 | ``` 21 | python train.py --epochs 100 --batch-size 1000 --lr 0.0005 --hidden 16 22 | --num-atoms 20 --num-classes 5 --num-neighbors 7 --conv_type SchNet --dmin 0 23 | --dmax 8. --step 0.5 --dist-data dists.dat --nbr-data nbrs.dat --residual --trained-model 24 | logs/logs_99.pt 25 | ``` 26 | 27 | ## Requirements 28 | - pytorch 29 | - deeptime 30 | - torch_scatter 31 | 32 | 33 | 34 | ## Sources: 35 | - VAMPNet code is based on deeptime package [deeptime](https://deeptime-ml.github.io/latest/index.html) 36 | - SchNet code is based on the [cgnet](https://github.com/brookehus/cgnet) 37 | 38 | 39 | ## Cite 40 | If you use this code please cite the following paper: 41 | 42 | ``` 43 | ``` 44 | -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def buildParser(): 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training') 7 | parser.add_argument('--seed', type=int, default=42, help='random seed') 8 | parser.add_argument('--epochs', type=int, default=10, help='number of training epochs') 9 | parser.add_argument('--batch-size', type=int, default=5000, help='batch-size for training') 10 | parser.add_argument('--lr', type=float, default=0.0005, help='Initial learning rate') 11 | parser.add_argument('--hidden', type=int, default=16, help='number of hidden neurons') 12 | parser.add_argument('--num-atoms', type=int, default=10, help='Number of atoms') 13 | parser.add_argument('--num-classes', type=int, default=6, help='number of coarse-grained classes') 14 | parser.add_argument('--save-folder', type=str, default='logs', help='Where to save the trained model') 15 | parser.add_argument('--dropout', type=float, default=0.4, help='Dropout 
rate') 16 | parser.add_argument('--atom_init', type=str, default=None, help='inital embedding for atoms file') 17 | parser.add_argument('--h_a', type=int, default=16, help='Atom hidden embedding dimension') 18 | parser.add_argument('--num_neighbors', type=int, default=5, help='Number of neighbors for each atom in the graph') 19 | parser.add_argument('--n_conv', type=int, default=4, help='Number of convolution layers') 20 | parser.add_argument('--save_checkpoints', default=True, action='store_true', help='If True, stores checkpoints') 21 | parser.add_argument('--conv_type', default='', type=str, help='the type of convolution layer, one of \ 22 | [GraphConvLayer, NeighborMultiHeadAttention, SchNet]') 23 | parser.add_argument('--dmin', default=0., type=float, help='Minimum distance for the gaussian filter') 24 | parser.add_argument('--dmax', default=3., type=float, help='maximum distance for the gaussian filter') 25 | parser.add_argument('--step', default=0.2, type=float, help='step for the gaussian filter') 26 | parser.add_argument('--tau', default=1, type=int, help='lag time for the model') 27 | parser.add_argument('--val-frac', default=0.3, type=float, help='fraction of dataset for validation') 28 | parser.add_argument('--num_heads', default=2, type=int, help='number of heads in multihead attention') 29 | parser.add_argument('--trained-model', default=None, type=str, help='path to the trained model for loading') 30 | parser.add_argument('--train', default=False, action='store_true', help='Whether to train the model or not') 31 | parser.add_argument('--use_backbone_atoms', default=False, action='store_true', help='Whether to use all the back bone atoms for training') 32 | parser.add_argument('--dont-pool-backbone', default=False, action='store_true', help='Whether not to pool backbone atoms') 33 | parser.add_argument('--h_g', type=int, default=8, help='Number of embedding dimension after backbone pooling') 34 | parser.add_argument('--seq_file', type=str, 
default=None, help='Sequence file to initialize a one-hot encoding based on amino types') 35 | parser.add_argument('--dist-data', type=str, default='dists_BBA_7nbrs_1ns.npz', help='the distnace data file') 36 | parser.add_argument('--nbr-data', type=str, default='inds_BBA_7nbrs_1ns.npz', help='the neighbors data file') 37 | parser.add_argument('--score-method', type=str, default='VAMP2', help='the scoring method of VAMPNet') 38 | parser.add_argument('--residual', action='store_true', default=False, help='Whether to use residual connections') 39 | parser.add_argument('--attention-pool', action='store_true', default=False, help= 'Whether to perform attention before global pooling') 40 | parser.add_argument('--return-emb', action='store_true', default=False, help='Whether return the learned graph embeddings') 41 | parser.add_argument('--return-attn', action='store_true', default=False, help='Whether to return the attention probs (only for NeighborMultiHeadAttention)') 42 | return parser 43 | 44 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import deeptime 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import mdshare 9 | from torch.utils.data import DataLoader 10 | import json 11 | from sklearn.neighbors import BallTree 12 | import mdtraj as md 13 | import argparse 14 | import sys 15 | 16 | from deeptime.decomposition._koopman import KoopmanChapmanKolmogorovValidator 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--num-neighbors', type=int, default=5, help='number of neighbors') 21 | parser.add_argument('--traj-folder', type=str, default=None, help='the path to the trajectory folder') 22 | parser.add_argument('--stride', type=int, default=5, help='stride for trajectory') 23 | 
#parser.add_argument('--use-backbone', action='store_true', default=False, help='Whether to use produce the data for backbone atoms') 24 | 25 | args = parser.parse_args() 26 | ########## for loading the BBA trajectory #################################### 27 | 28 | 29 | traj_1 = ['../2JOF-0-protein/2JOF-0-protein-'+str(i).zfill(3)+'.dcd' for i in range(105)] 30 | 31 | crys = md.load_pdb('2jof.pdb') 32 | top = crys.topology 33 | inds = top.select('backbone') 34 | 35 | t1 = md.load_dcd(traj_1[0], top=top, stride=args.stride, atom_indices=inds) 36 | coor_t1 = t1.xyz 37 | for i in range(1,len(traj_1)): 38 | t1 = md.load_dcd(traj_1[i], top=top, stride=args.stride, atom_indices=inds) 39 | coor_t1 = np.concatenate((coor_t1, t1.xyz), axis=0) 40 | 41 | print(coor_t1.shape) 42 | 43 | data = list([coor_t1]) 44 | np.savez('pos_trpcage_bb.npz', data[0]) 45 | 46 | 47 | if torch.cuda.is_available(): 48 | device = torch.device('cpu') 49 | print('cuda is is available') 50 | else: 51 | print('Using CPU') 52 | device = torch.device('cpu') 53 | 54 | #ala_coords_file = mdshare.fetch( 55 | # "alanine-dipeptide-3x250ns-heavy-atom-positions.npz", working_directory="data" 56 | #) 57 | #with np.load(ala_coords_file) as fh: 58 | # data = [fh[f"arr_{i}"].astype(np.float32) for i in range(3)] 59 | # 60 | #dihedral_file = mdshare.fetch( 61 | # "alanine-dipeptide-3x250ns-backbone-dihedrals.npz", working_directory="data" 62 | #) 63 | # 64 | #with np.load(dihedral_file) as fh: 65 | # dihedral = [fh[f"arr_{i}"] for i in range(3)] 66 | 67 | 68 | 69 | # reshape the data to be in share list of [N,num_atoms,3] 70 | #data_reshaped = [] 71 | #for i in range(len(data)): 72 | # temp = data[i].reshape(data[0].shape[0], 3, 10).swapaxes(1,2) 73 | # data_reshaped.append(temp) 74 | #----------------------------------------------------- 75 | 76 | def get_nbrs(all_coords, num_neighbors=args.num_neighbors): 77 | ''' 78 | inputs: a trajectory or list of trajectories with shape [T, num_atoms, dim] 79 | T: number of 
steps 80 | dim: number of dimensions (3 coordinates) 81 | 82 | Returns: 83 | if all_coords is a list: 84 | list of trajectories of ditances and indices 85 | else: 86 | trajectory of distances and indices 87 | 88 | [N, num_atoms, num_neighbors] 89 | ''' 90 | k_nbr=num_neighbors+1 91 | if type(all_coords) == list: 92 | all_dists = [] 93 | all_inds = [] 94 | for i in range(len(all_coords)): 95 | dists = [] 96 | inds = [] 97 | tmp_coords = all_coords[i] 98 | for j in tqdm(range(len(tmp_coords))): 99 | tree = BallTree(tmp_coords[j], leaf_size=3) 100 | dist, ind = tree.query(tmp_coords[j], k=k_nbr) 101 | dists.append(dist[:,1:]) 102 | inds.append(ind[:,1:]) 103 | 104 | dists = np.array(dists) 105 | inds = np.array(inds) 106 | all_dists.append(dists) 107 | all_inds.append(inds) 108 | else: 109 | all_inds = [] 110 | all_dists = [] 111 | for i in range(len(all_coords)): 112 | tree = BallTree(all_coords[i], leaf_size=3) 113 | dist , ind = tree.query(all_coords[i], k=k_nbr) 114 | dists.append(dist[:,1:]) 115 | inds.append(ind[:,1:]) 116 | all_dists = np.array(dists) 117 | all_inds = np.array(inds) 118 | 119 | return all_dists, all_inds 120 | 121 | ns = int(args.stride*0.2) # 0.2 ns is the timestep of trajectories 122 | dists, inds = get_nbrs(data, args.num_neighbors) 123 | np.savez('dists_trpcage_bb_'+str(args.num_neighbors)+'nbrs_'+ str(ns)+'ns'+'.npz', dists[0]) 124 | np.savez('inds_trpcage_bb_'+str(args.num_neighbors)+'nbrs_'+ str(ns)+'ns'+'.npz', inds[0]) 125 | 126 | 127 | def chunks(data, chunk_size=5000): 128 | ''' 129 | splitting the trajectory into chunks for passing into analysis part 130 | data: list of trajectories 131 | chunk_size: the size of each chunk 132 | ''' 133 | if type(data) == list: 134 | 135 | for data_tmp in data: 136 | for j in range(0, len(data_tmp),chunk_size): 137 | print(data_tmp[j:j+chunk_size,...].shape) 138 | yield data_tmp[j:j+chunk_size,...] 
139 | 140 | else: 141 | 142 | for j in range(0, len(data), chunk_size): 143 | yield data[j:j+chunk_size,...] 144 | -------------------------------------------------------------------------------- /src/utils_vamp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import deeptime 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | import json 10 | from deeptime.decomposition.deep import VAMPNet 11 | from deeptime.decomposition import VAMP 12 | import os 13 | import pickle 14 | 15 | 16 | plt.set_cmap('jet') 17 | 18 | def estimate_koopman_op(trajs, tau): 19 | if type(trajs) == list: 20 | traj = np.concatenate([t[:-tau] for t in trajs], axis=0) 21 | traj_lag = np.concatenate([t[tau:] for t in trajs], axis=0) 22 | else: 23 | traj = trajs[:-tau] 24 | traj_lag = trajs[tau:] 25 | c_0 = np.transpose(traj)@traj 26 | c_tau = np.transpose(traj)@traj_lag 27 | 28 | eigv, eigvec = np.linalg.eig(c_0) 29 | include = eigv > 1e-7 30 | eigv = eigv[include] 31 | eigvec = eigvec[:, include] 32 | c0_inv = eigvec @ np.diag(1/eigv)@np.transpose(eigvec) 33 | 34 | koopman_op = c0_inv @ c_tau 35 | return koopman_op 36 | 37 | def get_ck_test(traj, steps, tau): 38 | if type(traj) == list: 39 | n_states = traj[0].shape[1] 40 | else: 41 | n_states = traj.shape[1] 42 | 43 | predicted = np.zeros((n_states, n_states, steps)) 44 | estimated = np.zeros((n_states, n_states, steps)) 45 | 46 | predicted[:, :, 0] = np.identity(n_states) 47 | estimated[:, :, 0] = np.identity(n_states) 48 | 49 | for vector, i in zip(np.identity(n_states), range(n_states)): 50 | for n in range(1, steps): 51 | koop = estimate_koopman_op(traj, tau) 52 | koop_pred = np.linalg.matrix_power(koop, n) 53 | koop_est = estimate_koopman_op(traj, tau*n) 54 | 55 | predicted[i,:,n] = vector@koop_pred 56 | estimated[i,:,n] = vector@koop_est 57 | 
58 | return [predicted, estimated] 59 | 60 | def plot_ck_test(pred, est, n_states, steps, tau, save_folder): 61 | fig, ax = plt.subplots(n_states, n_states, sharex=True, sharey=True) 62 | for index_i in range(n_states): 63 | for index_j in range(n_states): 64 | 65 | ax[index_i][index_j].plot(range(0, steps*tau, tau), pred[index_i, index_j], color='b') 66 | ax[index_i][index_j].plot(range(0, steps*tau, tau), est[index_i, index_j], color='r', linestyle='--') 67 | ax[index_i][index_j].set_title(str(index_i+1)+'->'+str(index_j+1), fontsize='small') 68 | 69 | ax[0][0].set_ylim((-0.1, 1.1)) 70 | ax[0][0].set_xlim((0, steps*tau)) 71 | ax[0][0].axes.get_xaxis().set_ticks(np.round(np.linspace(0, steps*tau, 3))) 72 | plt.tight_layout() 73 | plt.show() 74 | plt.savefig(save_folder+'/ck_test.png') 75 | 76 | def get_its(traj, lags): 77 | ''' 78 | implied timescales from a trajectory estiamted at a series of lag times 79 | 80 | parameters: 81 | --------------------- 82 | traj: numpy array [traj_timesteps, traj_dimension] traj or list of trajs 83 | lags: numpy array with size [lagtimes] series of lag times at which the implied timescales are estimated 84 | 85 | Returns: 86 | --------------------- 87 | its: numpy array with size [traj_dimensions-1, lag_times] implied timescales estimated for the trajectory 88 | ''' 89 | 90 | if type(traj) == list: 91 | outputsize = traj[0].shape[1] 92 | else: 93 | outputsize = traj.shape[1] 94 | its = np.zeros((outputsize-1, len(lags))) 95 | 96 | for t, tau_lag in enumerate(lags): 97 | koopman_op = estimate_koopman_op(traj, tau_lag) 98 | k_eigvals, k_eigvec = np.linalg.eig(np.real(koopman_op)) 99 | k_eigvals = np.sort(np.absolute(k_eigvals)) 100 | k_eigvals = k_eigvals[:-1] 101 | its[:, t] = (-tau_lag/np.log(k_eigvals)) 102 | 103 | return its 104 | 105 | def plot_its(its, lag, save_folder,ylog=False): 106 | ''' 107 | plots the implied timescales calculated by the function 108 | 109 | get_its: 110 | parameters: 111 | ------------------------ 112 | 
its: numpy array 113 | the its array returned by the function get_its 114 | lag: numpy array 115 | lag times array used to estimated the implied timescales 116 | ylog: Boolean, optional, default=False 117 | if true, the plot will be a logarithmic plot, otherwize it 118 | will be a semilogy plot 119 | ''' 120 | 121 | if ylog: 122 | plt.loglog(lag, its.T[:,::-1]) 123 | plt.loglog(lag, lag, 'k') 124 | plt.fill_between(lag, lag, 0.99, alpha=0.2, color='k') 125 | else: 126 | plt.semilogy(lag, its.T[:,::-1]) 127 | plt.semilogy(lag,lag, 'k') 128 | plt.fill_between(lag, lag, 0.99, alpha=0.2, color='k') 129 | plt.show() 130 | plt.savefig(save_folder+'/its.png') 131 | 132 | 133 | def plot_scores(train_scores, validation_scores, save_folder): 134 | plt.loglog(train_scores, label='training') 135 | plt.loglog(validation_scores, label='validation') 136 | plt.xlabel('step') 137 | plt.ylabel('score') 138 | plt.legend() 139 | plt.savefig(save_folder+'/scores.png') 140 | 141 | 142 | 143 | def chapman_kolmogorov_validator(model, mlags, n_observables=None, 144 | observables='phi', statistics='psi'): 145 | ''' returns a chapman-kolmogrov validator based on this estimator and a test model 146 | 147 | parameters: 148 | ----------------- 149 | model: VAMP model 150 | mlags: int or int-array 151 | multiple of lagtimes of the test_model to test against 152 | test_model: CovarianceKoopmanModel, optional, default=None, 153 | The model that is tested, if not provided uses this estimator's encapsulated model. 154 | n_observables: int, optional, default=None, 155 | limit the number of default observables to this number. only used if 'observables' are Nonr or 'statistics' are None. 
156 | observables: (input_dimension, n_observables) ndarray 157 | coefficents that express one or multiple observables in the basis of the input features 158 | statistics: (input_dim, n_statistics) ndarray 159 | coefficents that express one or more statistics in the basis of the input features 160 | 161 | Returns: 162 | ------------------ 163 | validator: KoopmanChapmanKolmogrovValidator 164 | the validator 165 | ''' 166 | test_model = model.fetch_model() 167 | assert test_model is not None, 'We need a test model via argument or an estimator which was already fit to data' 168 | 169 | lagtime = model.lagtime 170 | if n_observables is not None: 171 | if n_observables > test_model.dim: 172 | import warnings 173 | warnings.warn('selected singular functgions as observables but dimension is lower thanthe requested number of observables') 174 | n_observables = test_model.dim 175 | 176 | else: 177 | n_observables = test_model.dim 178 | 179 | if isinstance(observables, str) and observables == 'phi': 180 | observables = test_model.singular_vectors_right[:, :n_observables] 181 | observables_mean_free = True 182 | else: 183 | observables_mean_free = False 184 | 185 | if isinstance(statistics, str) and statistics == 'psi': 186 | statistics = test_model.singular_vectors_left[:, :n_observables] 187 | statistics_mean_free = True 188 | else: 189 | statistics_mean_free = False 190 | 191 | return VAMPKoopmanCKValidator(test_model, model, lagtime, mlags, observables, statistics, 192 | observables_mean_free, statistics_mean_free) 193 | 194 | def _vamp_estimate_model_for_lag(estimator: VAMP, model, data, lagtime): 195 | est = VAMP(lagtime=lagtime, dim=estimator.dim, var_cutoff=estimator.var_cutoff, scaling=estimator.scaling, 196 | epsilon=estimator.epsilon, observable_transform=estimator.observable_transform) 197 | 198 | whole_dataset = TrajectoryDataset.from_trajectories(lagtime=lagtime, data=data) 199 | whole_dataloder = DataLoader(whole_dataset, batch_size=10000, shuffle=False) 200 | 
for batch_0, batch_t in whole_dataloder: 201 | est.partial_fit((batch_0.numpy(), batch_t.numpy())) 202 | 203 | return est.fetch_model() -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import deeptime 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import mdshare 9 | from torch.utils.data import DataLoader 10 | import json 11 | from args import buildParser 12 | from layers import GaussianDistance, NeighborMultiHeadAttention, InteractionBlock, GraphConvLayer 13 | from model import GraphVampNet 14 | from deeptime.util.data import TrajectoryDataset 15 | from deeptime.decomposition.deep import VAMPNet 16 | from deeptime.decomposition import VAMP 17 | from copy import deepcopy 18 | import os 19 | import pickle 20 | import warnings 21 | from deeptime.decomposition._koopman import KoopmanChapmanKolmogorovValidator 22 | from utils_vamp import * 23 | 24 | if torch.cuda.is_available(): 25 | device = torch.device('cuda') 26 | print('cuda is available') 27 | else: 28 | print('Using CPU') 29 | device = torch.device('cpu') 30 | 31 | # ignore deprecation warnings 32 | warnings.filterwarnings('ignore',category=DeprecationWarning) 33 | 34 | args = buildParser().parse_args() 35 | 36 | if not os.path.exists(args.save_folder): 37 | print('making the folder for saving checkpoints') 38 | os.makedirs(args.save_folder) 39 | 40 | with open(args.save_folder+'/args.txt','w') as f: 41 | f.write(str(args)) 42 | 43 | meta_file = os.path.join(args.save_folder, 'metadata.pkl') 44 | pickle.dump({'args': args}, open(meta_file, 'wb')) 45 | 46 | #------------------- data as a list of trajectories --------------------------- 47 | 48 | dists1, inds1 = np.load(args.dist_data)['arr_0'], np.load(args.nbr_data)['arr_0'] 49 | 50 | mydists1 = 
torch.from_numpy(dists1).to(device) 51 | myinds1 = torch.from_numpy(inds1).to(device) 52 | 53 | data = [] 54 | data.append(torch.cat((mydists1,myinds1), axis=-1)) 55 | 56 | dataset = TrajectoryDataset.from_trajectories(lagtime=args.tau, data=data) 57 | 58 | 59 | n_val = int(len(dataset)*args.val_frac) 60 | train_data, val_data = torch.utils.data.random_split(dataset, [len(dataset)-n_val, n_val]) 61 | 62 | loader_train = DataLoader(train_data, batch_size=args.batch_size, shuffle=True) 63 | loader_val = DataLoader(val_data, batch_size=args.batch_size, shuffle=False) 64 | 65 | # data is a list of trajectories [T,N,M+M] 66 | #--------------------------------------------------------------------------------- 67 | 68 | lobe = GraphVampNet() 69 | lobe_timelagged = deepcopy(lobe).to(device=device) 70 | lobe = lobe.to(device) 71 | 72 | vampnet = VAMPNet(lobe=lobe, lobe_timelagged=lobe_timelagged, learning_rate=args.lr, device=device, optimizer='Adam', score_method=args.score_method) 73 | 74 | def count_parameters(model): 75 | ''' 76 | count the number of parameters in the model 77 | ''' 78 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 79 | 80 | print('number of parameters', count_parameters(lobe)) 81 | 82 | def train(train_loader , n_epochs, validation_loader=None): 83 | ''' 84 | Parameters: 85 | ----------------- 86 | train_loader: torch.utils.data.DataLoader 87 | The data to use for training, should yield a tuple of batches representing instantaneous and time-lagged samples 88 | n_epochs: int, the number of epochs for training 89 | validation_loader: torch.utils.data.DataLoader: 90 | The validation data should also be yielded as a two-element tuple. 
91 | 92 | Returns: 93 | ----------------- 94 | model: VAMPNet 95 | ''' 96 | 97 | for epoch in tqdm(range(n_epochs)): 98 | ''' 99 | perform batches of data here 100 | ''' 101 | for batch_0, batch_t in train_loader: 102 | vampnet.partial_fit((batch_0.to(device),batch_t.to(device))) 103 | 104 | if validation_loader is not None: 105 | with torch.no_grad(): 106 | scores = [] 107 | for val_batch in validation_loader: 108 | scores.append(vampnet.validate((val_batch[0].to(device), val_batch[1].to(device)))) 109 | 110 | mean_score = torch.mean(torch.stack(scores)) 111 | vampnet._validation_scores.append((vampnet._step, mean_score.item())) 112 | 113 | if args.save_checkpoints: 114 | torch.save({ 115 | 'epoch' : epoch, 116 | 'state_dict': lobe.state_dict(), 117 | }, args.save_folder+'/logs_'+str(epoch)+'.pt') 118 | 119 | return vampnet.fetch_model() 120 | 121 | 122 | plt.set_cmap('jet') 123 | 124 | if not args.train and os.path.isfile(args.trained_model): 125 | print('Loading model') 126 | checkpoint = torch.load(args.trained_model) 127 | lobe.load_state_dict(checkpoint['state_dict']) 128 | lobe_timelagged = deepcopy(lobe).to(device=device) 129 | lobe = lobe.to(device) 130 | lobe.eval() 131 | lobe_timelagged.eval() 132 | vampnet = VAMPNet(lobe=lobe, lobe_timelagged=lobe_timelagged, learning_rate=args.lr, device=device) 133 | model = vampnet.fetch_model() 134 | 135 | 136 | elif args.train: 137 | model = train(train_loader=loader_train, n_epochs=args.epochs, validation_loader=loader_val) 138 | 139 | # save the training and validation scores 140 | with open(args.save_folder+'/train_scores.npy','wb') as f: 141 | np.save(f, vampnet.train_scores) 142 | 143 | with open(args.save_folder+'/validation_scores.npy','wb') as f: 144 | np.save(f, vampnet.validation_scores) 145 | 146 | # plotting the training and validation scores of the model 147 | plt.loglog(*vampnet.train_scores.T, label='training') 148 | plt.loglog(*vampnet.validation_scores.T, label='validation') 149 | 
plt.xlabel('step') 150 | plt.ylabel('score') 151 | plt.legend() 152 | plt.savefig(args.save_folder+'/scores.png') 153 | 154 | # making a numpy array of data for analysis 155 | data_np = [] 156 | for i in range(len(data)): 157 | data_np.append(data[i].cpu().numpy()) 158 | 159 | 160 | # for the analysis part create an iterator for the whole dataset to feed in batches 161 | whole_dataset = TrajectoryDataset.from_trajectories(lagtime=args.tau, data=data_np) 162 | whole_dataloder = DataLoader(whole_dataset, batch_size=args.batch_size, shuffle=False) 163 | 164 | 165 | 166 | # for plotting the implied timescales 167 | lagtimes = np.arange(1,201,2, dtype=np.int32)timescales = [] 168 | for lag in tqdm(lagtimes): 169 | vamp = VAMP(lagtime=lag, observable_transform=model) 170 | whole_dataset = TrajectoryDataset.from_trajectories(lagtime=lag, data=data_np) 171 | whole_dataloder = DataLoader(whole_dataset, batch_size=10000, shuffle=False) 172 | for batch_0, batch_t in whole_dataloder: 173 | vamp.partial_fit((batch_0.numpy(), batch_t.numpy())) 174 | # 175 | covariances = vamp._covariance_estimator.fetch_model() 176 | ts = vamp.fit_from_covariances(covariances).fetch_model().timescales(k=5) 177 | timescales.append(ts) 178 | 179 | 180 | 181 | 182 | f, ax = plt.subplots(1, 1) 183 | ax.semilogy(lagtimes, timescales) 184 | ax.set_xlabel('lagtime') 185 | ax.set_ylabel('timescale / step') 186 | ax.fill_between(lagtimes, ax.get_ylim()[0]*np.ones(len(lagtimes)), lagtimes, alpha=0.5, color='grey'); 187 | f.savefig(args.save_folder+'/ITS.png') 188 | 189 | 190 | 191 | def chunks(data, chunk_size=5000): 192 | ''' 193 | splitting the trajectory into chunks for passing into analysis part 194 | data: list of trajectories 195 | chunk_size: the size of each chunk 196 | ''' 197 | if type(data) == list: 198 | 199 | for data_tmp in data: 200 | for j in range(0, len(data_tmp),chunk_size): 201 | print(data_tmp[j:j+chunk_size,...].shape) 202 | yield data_tmp[j:j+chunk_size,...] 
203 | 204 | else: 205 | 206 | for j in range(0, len(data), chunk_size): 207 | yield data[j:j+chunk_size,...] 208 | 209 | n_classes = int(args.num_classes) 210 | 211 | probs = [] 212 | total_emb= [] 213 | total_attn = [] 214 | for data_tmp in data_np: 215 | # transforming the data into the vampnet for modeling the dynamics 216 | mydata = chunks(data_tmp, chunk_size=5000) 217 | state_probs = np.zeros((data_tmp.shape[0], n_classes)) 218 | emb_tmp = np.zeros((data_tmp.shape[0], args.h_a)) 219 | attn_tmp = np.zeros((data_tmp.shape[0], args.num_atoms, args.num_neighbors)) 220 | 221 | n_iter = 0 222 | for i,batch in enumerate(mydata): 223 | batch_size = len(batch) 224 | state_probs[n_iter:n_iter+batch_size] = model.transform(batch) 225 | if args.return_emb and not args.return_attn: 226 | emb_1 = model.lobe(torch.tensor(batch), return_emb=True, return_attn=False) 227 | emb_tmp[n_iter:n_iter+batch_size] = emb_1.cpu().detach().numpy() 228 | elif args.return_emb and args.return_attn: 229 | emb_1, attn_1 = model.lobe(torch.tensor(batch), return_emb=True, return_attn=True) 230 | emb_tmp[n_iter:n_iter+batch_size], attn_tmp[n_iter:n_iter+batch_size] = emb_1.cpu().detach().numpy(),attn_1.cpu().detach().numpy() 231 | n_iter = n_iter + batch_size 232 | probs.append(state_probs) 233 | if args.return_emb: 234 | total_emb.append(emb_tmp) 235 | if args.return_attn: 236 | total_attn.append(attn_tmp) 237 | 238 | # problem here 239 | np.savez(args.save_folder+'/transformed.npz', probs[0]) 240 | 241 | if args.return_emb: 242 | np.savez(args.save_folder+'/embeddings.npz', total_emb[0]) 243 | if args.return_attn: 244 | np.savez(args.save_folder+'/total_attn.npz', total_attn[0]) 245 | 246 | quit() 247 | 248 | max_tau = 200 249 | lags = np.arange(1, max_tau, 1) 250 | 251 | #its = get_its(probs, lags) 252 | #plot_its(its, lags, ylog=False, save_folder=args.save_folder) 253 | 254 | steps = 6 255 | tau_msm = 200 256 | predicted, estimated = get_ck_test(probs, steps, tau_msm) 257 | 258 | 
plot_ck_test(predicted, estimated, n_classes, steps, tau_msm, args.save_folder) 259 | 260 | np.save(args.save_folder+'/ITS.npy', np.array(timescales)) 261 | np.savez(args.save_folder+'/ck.npz', list((predicted, estimated))) 262 | 263 | 264 | quit() 265 | 266 | # for plotting the CK test 267 | vamp = VAMP(lagtime=lag, observable_transform=model) 268 | whole_dataset = TrajectoryDataset.from_trajectories(lagtime=200, data=data_np) 269 | whole_dataloder = DataLoader(whole_dataset, batch_size=10000, shuffle=False) 270 | for batch_0, batch_t in whole_dataloder: 271 | vamp.partial_fit((batch_0.numpy(), batch_t.numpy())) 272 | 273 | validator = chapman_kolmogorov_validator(model=vamp, mlags=10) 274 | 275 | #validator = vamp.chapman_kolmogorov_validator(mlags=5) 276 | 277 | cktest = validator.fit(data_np, n_jobs=1, progress=tqdm).fetch_model() 278 | n_states = args.num_classes - 1 279 | 280 | tau = cktest.lagtimes[1] 281 | steps = len(cktest.lagtimes) 282 | fig, ax = plt.subplots(n_states, n_states, sharex=True, sharey=True, constrained_layout=True) 283 | for i in range(n_states): 284 | for j in range(n_states): 285 | pred = ax[i][j].plot(cktest.lagtimes, cktest.predictions[:, i, j], color='b') 286 | est = ax[i][j].plot(cktest.lagtimes, cktest.estimates[:, i, j], color='r', linestyle='--') 287 | ax[i][j].set_title(str(i+1)+ '->' +str(j+1), 288 | fontsize='small') 289 | ax[0][0].set_ylim((-0.1,1.1)); 290 | ax[0][0].set_xlim((0, steps*tau)); 291 | ax[0][0].axes.get_xaxis().set_ticks(np.round(np.linspace(0, steps*tau, 3))); 292 | fig.legend([pred[0], est[0]], ["Predictions", "Estimates"], 'lower center', ncol=2, 293 | bbox_to_anchor=(0.5, -0.1)); 294 | 295 | fig.savefig(args.save_folder+'/cktest.png') 296 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # Author: Mahdi Ghorbani 2 | 3 | import torch 4 | import numpy as np 5 | import deeptime 6 | 
from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | import json 12 | from sklearn.neighbors import BallTree 13 | from args import buildParser 14 | from layers import GaussianDistance, GraphConvLayer, NeighborMultiHeadAttention, LinearLayer, ContinuousFilterConv, InteractionBlock 15 | from layers import GATLayer 16 | from layers import GraphAttentionLayer 17 | #from torch_scatter import scatter_mean 18 | import time 19 | 20 | args = buildParser().parse_args() # NOTE: parses sys.argv when this module is imported; the GraphVampNet constructor defaults are read from it 21 | 22 | if torch.cuda.is_available(): 23 | device = torch.device('cuda') 24 | print('cuda is is available') 25 | else: 26 | print('Using CPU') 27 | device = torch.device('cpu') 28 | 29 | 30 | class GraphVampNet(nn.Module): 31 | ''' wrapper class for different types of graph convolutions: ['GraphConvLayer', 'NeighborMultiHeadAttention', 'SchNet'] 32 | 33 | parameters: 34 | ----------------- 35 | seq_file: text file, 36 | the sequence file for initializing a one-hot encoding for embedding of different atoms 37 | If not provided, a random initializer will be used for atom embeddings (use_pre_trained, when set, takes precedence over seq_file) 38 | num_atoms: int, 39 | number of atoms 40 | num_neighbors: int, 41 | number of neighbors of each atom to be considered.
42 | tau: int, 43 | lagtime to be considered for the model (NOTE: not an argument of __init__) 44 | n_classes: int, 45 | number of classes to cluster the output graphs 46 | n_conv: int, 47 | number of convolutional layers on graphs 48 | dmin: float, 49 | the minimum distance for performing the gaussian basis expansion 50 | dmax: float, 51 | the maximum distance for performing the gaussian basis expansion 52 | step: float, 53 | the step for gaussian basis expansion 54 | h_a: int, 55 | number of dimensions for atom embeddings 56 | h_g: int, 57 | number of dimensions of the graph embedding after pooling (if None, the pooled h_a features feed the classifier directly) 58 | atom_embedding_init: str, 59 | the type of initialization for a randomly initialized atom embedding ('normal', 'xavier_normal' or 'uniform') 60 | use_pre_trained: boolean, 61 | Whether to use a pretrained embedding for atoms 62 | activation: 63 | the non-linear activation function to be used in the model (NOTE: forward applies self.conv_activation, an nn.ReLU, not this argument) 64 | pre_trained_weights_file: str, 65 | the filename for pretrained embedding 66 | conv_type: the type of convolutional layer for graphs ['GraphConvLayer', 'NeighborMultiHeadAttention', 'SchNet'] 67 | num_heads: int, 68 | the number of heads for multi-head attention in NeighborMultiHeadAttention convolution type 69 | residual: boolean, 70 | Whether to add a residual connection between different convolutions 71 | use_backbone_atoms: boolean, 72 | Whether to use backbone atoms for the protein graph model (if False will use the Ca atoms only) 73 | dont_pool_backbone: boolean, stored as self.dont_pool_backbone (not read by forward) 74 | attention_pool: bool 75 | Whether to add Graph Attention layer between embedding learned before performing graph pooling (NOTE: the call in forward is currently commented out) 76 | ''' 77 | def __init__(self, seq_file=args.seq_file, num_atoms=args.num_atoms, num_neighbors=args.num_neighbors, 78 | n_classes=args.num_classes, n_conv=args.n_conv, dmin=args.dmin, dmax=args.dmax, step=args.step, 79 | h_a=args.h_a, h_g=args.h_g, atom_embedding_init='normal', use_pre_trained=False, activation=nn.ReLU(), 80 | pre_trained_weights_file=None, conv_type=args.conv_type, num_heads=args.num_heads, residual=args.residual, 81 | use_backbone_atoms=args.use_backbone_atoms,
dont_pool_backbone=args.dont_pool_backbone, 82 | attention_pool=args.attention_pool): 83 | 84 | super(GraphVampNet, self).__init__() 85 | self.seq_file = seq_file 86 | self.num_atoms = num_atoms 87 | self.num_neighbors = num_neighbors 88 | self.n_classes= n_classes 89 | self.n_conv = n_conv 90 | self.dmin = dmin 91 | self.dmax = dmax 92 | self.step = step 93 | self.h_a = h_a 94 | self.h_g = h_g 95 | self.gauss = GaussianDistance(dmin, dmax, step) 96 | self.h_b = self.gauss.num_features # number of gaussians 97 | self.num_heads = num_heads 98 | self.activation = nn.ReLU() 99 | self.use_backbone_atoms = use_backbone_atoms 100 | self.residual = residual 101 | #self.atom_emb = nn.Embedding(num_embeddings=self.num_atoms, embedding_dim=self.h_a) 102 | #self.atom_emb.weight.data.normal_() 103 | self.conv_type = conv_type 104 | if self.conv_type == 'GraphConvLayer': 105 | self.convs = nn.ModuleList([GraphConvLayer(self.h_a, 106 | self.h_b) for _ in range(self.n_conv)]) 107 | 108 | elif self.conv_type == 'NeighborMultiHeadAttention': 109 | self.convs = nn.ModuleList([NeighborMultiHeadAttention(self.h_a, 110 | self.h_b, 111 | self.num_heads) for _ in range(self.n_conv)]) 112 | 113 | elif self.conv_type == 'SchNet': 114 | self.convs = nn.ModuleList([InteractionBlock(n_inputs=self.h_a, 115 | n_gaussians=self.h_b, 116 | n_filters=self.h_a, 117 | activation=nn.Tanh()) for _ in range(self.n_conv)]) 118 | 119 | self.conv_activation = nn.ReLU() 120 | if self.h_g is not None: 121 | self.fc_classes = nn.Linear(self.h_g, n_classes) 122 | else: 123 | self.fc_classes = nn.Linear(self.h_a, n_classes) 124 | self.init = atom_embedding_init 125 | self.use_pre_trained = use_pre_trained 126 | self.dont_pool_backbone = dont_pool_backbone 127 | self.attention_pool = attention_pool 128 | 129 | #elf.weight = nn.Parameter(torch.Tensor(self.h_a, 1)) 130 | 131 | #if args.use_backbone_atoms: 132 | if args.h_g is not None: 133 | self.amino_emb = nn.Linear(self.h_a, self.h_g) 134 | 135 | if 
use_pre_trained: 136 | self.pre_trained_emb(pre_trained_weights_file) 137 | 138 | elif seq_file is not None: 139 | atom_emb = self.onehot_encode_amino(seq_file) 140 | self.atom_embeddings = torch.tensor(atom_emb, dtype=torch.float32).to(device) 141 | self.h_init = atom_emb.shape[-1] # dimension of atom embedding [20] 142 | emb = nn.Embedding.from_pretrained(self.atom_embeddings, freeze=False) 143 | self.atom_emb = nn.Linear(self.h_init, self.h_a) # linear layer for atom features 144 | 145 | else: 146 | # initialize the atom embeddings randomly 147 | self.atom_emb = nn.Embedding(num_embeddings=self.num_atoms, embedding_dim=self.h_a) 148 | self.init_emb() 149 | 150 | if self.attention_pool: 151 | self.attn_pool_model = GATLayer(self.h_a, self.h_a, concat_heads=True, alpha=0.2) 152 | 153 | def pre_trained_emb(self, file): 154 | ''' 155 | loads the pre-trained node embedings from a file 156 | For now we are not freezing the pre-trained embeddings since 157 | we are going to update it in the graph convolution 158 | ''' 159 | 160 | with open(self.pre_trained_weights_file) as f: 161 | loaded_emb = json.load(f) 162 | 163 | embed_list = [torch.tensor(value, dtype=torch.float32) for value in loaded_emb.values()] 164 | self.atom_embeddings = torch.stack(embed_list, dim=0) 165 | self.h_init = self.atom_embeddings.shape[-1] # dimension atom embedding init 166 | self.atom_emb = nn.Embedding.from_pretrained(self, atom_embeddings, freeze=False) 167 | self.embedding = nn.Linear(self.h_init, self.h_a) 168 | 169 | def init_emb(self): 170 | ''' 171 | Initialize random embedding for the atoms 172 | ''' 173 | #--------------initialization for the embedding-------------- 174 | if self.init == 'normal': 175 | self.atom_emb.weight.data.normal_() 176 | 177 | elif self.init == 'xavier_normal': 178 | self.atom_emb.weight.data._xavier_normal() 179 | 180 | elif self.init == 'uniform': 181 | self.atom_emb.weight.data._uniform() 182 | 183 | 184 | def onehot_encode_amino(self, seq_file): 185 | ''' 
186 | one-hot encoding of amino types for initializing the embedding 187 | ''' 188 | with open(seq_file, 'r') as f: 189 | sequence = open(seq_file) 190 | seq = sequence.readlines()[0].strip() 191 | 192 | amino_dict = {'A':0, 193 | 'R':1, 194 | 'N':2, 195 | 'D':3, 196 | 'C':4, 197 | 'Q':5, 198 | 'E':6, 199 | 'G':7, 200 | 'H':8, 201 | 'I':9, 202 | 'L':10, 203 | 'K':11, 204 | 'M':12, 205 | 'F':13, 206 | 'P':14, 207 | 'S':15, 208 | 'T':16, 209 | 'W':17, 210 | 'Y':18, 211 | 'Z':19} 212 | 213 | if args.use_backbone_atoms: 214 | s_encoded = np.zeros((20, 3*len(seq))) 215 | for i, n in enumerate(s): 216 | s_encoded[amino_dict[n],i*3:i*3+3] = 1 217 | 218 | else: 219 | s_encoded = np.zeros((20, len(seq))) 220 | for i, n in enumerate(s): 221 | s_encoded[amino_dict[n],i] = 1 222 | 223 | return s_encoded.T 224 | 225 | #------------------------------------------------------------ 226 | 227 | @staticmethod 228 | def convert_adj(mat): 229 | # convert the nbr_adj_list matrix to an adjacency matrix 230 | mat = mat.to(torch.long) 231 | adj = torch.zeros((mat.shape[0], mat.shape[1], mat.shape[1])) 232 | for i in range(mat.shape[0]): 233 | adj[i][torch.arange(adj.shape[1])[:,None],mat[i]] = 1 234 | adj[i] += torch.eye(mat.shape[1]) 235 | return adj.to(device) 236 | 237 | def pooling(self, atom_emb): 238 | # global pooling layer by averaging the embedding of nodes to get embedding of graph 239 | 240 | summed = torch.sum(atom_emb, dim=1) 241 | return summed / self.num_atoms 242 | 243 | 244 | def pool_amino(self, atom_emb): 245 | ''' 246 | pooling the features of atoms in each amino acid to get a feature vector for each residue 247 | parameters: 248 | -------------------------- 249 | atom_emb: embedding of atoms [B,N,h_a] 250 | residue_atom_idx: mapping between every atom and every residue in the protein 251 | size: [N] example [0,0,0,1,1,1,2,2,2] for N=6 and NA=3 252 | 253 | Returns: 254 | -------------------------- 255 | pooled features of amino acids in the graph 256 | [B, Na, h_a] 257 
| ''' 258 | 259 | B = atom_emb.shape[0] 260 | N = atom_emb.shape[1] 261 | h_a = atom_emb.shape[2] 262 | 263 | residue_atom_idx = torch.arange(N).repeat(1,3) 264 | residue_atom_idx = residue_atom_idx.view(3,N).T.reshape(-1,1).squeeze(-1) 265 | Na = torch.max(residue_atom_idx)+1 # number of residues 266 | pooled = scatter_mean(atom_emb, residue_atom_idx, out=atom_emb.new_zeros(B,Na,h_a), dim=1) 267 | return pooled 268 | 269 | def return_emb(self): 270 | ''' 271 | returns the embedding learned for each amino acid (or Ca atom) after training the model 272 | ''' 273 | return self.emb 274 | 275 | 276 | 277 | def forward(self, data, return_emb=False, return_attn=False): 278 | ''' Graph neural net computation to get features of protein at each timestep of simulation 279 | 280 | data: has shape [batch-size, num_atoms, num_neighbors*2] 281 | the first half of last index contains the nbr_adj_dist -> distance between every two atoms 282 | the second half of data contains the the nbr_adj_list -> the index of neighbors of each atom 283 | 284 | This model: 285 | 1. expands the nbr_adj_dist into gaussian basis function [batch-size, num-atoms, num-neighbors, n_gaussians] 286 | 2. initalize the embedding and get the initial embedding of each atom 287 | 3. perform graph convolution to propagate messages along nodes and edges multiple times 288 | 4. pooling to get the features for the graph from node features 289 | 5. Linear layer to get coarse grain for number of classes 290 | 6. 
Apply a softmax activation for getting the class assignment probabilities 291 | 292 | ''' 293 | M = data.shape[-1] # num_neighbors*2 294 | nbr_adj_dist = data[:,:,:M//2] # [batch-size, num_atoms, num_neighbors] 295 | nbr_adj_list = data[:,:,M//2:] # [batch-size, num_atoms, num_neighbors] 296 | N = nbr_adj_list.shape[1] 297 | B = nbr_adj_list.shape[0] 298 | 299 | nbr_emb = self.gauss.expand(nbr_adj_dist) # [batch-size, num_atoms, num_neighbors, n_gaussians] 300 | # this is the edge embedding 301 | 302 | atom_emb_idx = torch.arange(N).repeat(B,1).to(device) 303 | atom_emb = self.atom_emb(atom_emb_idx) 304 | # atom_emb [B,N,h_a] 305 | 306 | if args.conv_type == 'GraphConvLayer': 307 | for idx in range(self.n_conv): 308 | tmp_conv = self.convs[idx](atom_emb=atom_emb, 309 | nbr_emb=nbr_emb, 310 | nbr_adj_list=nbr_adj_list) 311 | if self.residual: 312 | atom_emb = atom_emb + tmp_conv 313 | else: 314 | atom_emb = tmp_conv 315 | 316 | elif args.conv_type == 'NeighborMultiHeadAttention': 317 | for idx in range(self.n_conv): 318 | tmp_conv = self.convs[idx](h_V=atom_emb, 319 | h_E=nbr_emb, 320 | mask_attend=nbr_adj_list) 321 | if self.residual: 322 | atom_emb = atom_emb + tmp_conv 323 | else: 324 | atom_emb = tmp_conv 325 | 326 | elif args.conv_type == 'SchNet': 327 | for idx in range(self.n_conv): 328 | tmp_conv, attn_probs = self.convs[idx](features=atom_emb, 329 | rbf_expansion=nbr_emb, 330 | neighbor_list=nbr_adj_list) 331 | 332 | if self.residual: 333 | atom_emb = atom_emb + tmp_conv 334 | else: 335 | atom_emb = tmp_conv 336 | 337 | emb = self.conv_activation(atom_emb) 338 | # [batch-size, num-atoms, h_a] 339 | 340 | 341 | if args.use_backbone_atoms: 342 | # apply a pooling layer for backbone atoms to amino acid features 343 | emb = self.pool_amino(emb) 344 | 345 | #t1 = time.time() 346 | #adj = self.convert_adj(nbr_adj_list) # convert nbr_adj_list to an adjacency matrix 347 | #t2 = time.time() 348 | #print('time: ' + str(t2-t1)) 349 | #if self.attention_pool: 350 | # 
emb, attn_probs = self.attn_pool_model(emb, adj) 351 | 352 | 353 | #print(attn_probs.shape) 354 | 355 | # embedding for each atom or amino acid learned after training 356 | #print(emb.shape) 357 | self.emb = emb # [batch, N, h_a] 358 | 359 | 360 | # the last pooling layer for getting graph features 361 | #attn_logits = torch.matmul(self.emb, self.weight).reshape(B,N) 362 | #attn_probs = F.softmax(attn_logits, dim=-1) 363 | 364 | #h_a = self.emb.shape[2] 365 | #self.emb = self.emb.reshape((B,h_a,N)) 366 | #self.prot_emb = torch.bmm(self.emb, attn_probs.unsqueeze(-1)).squeeze(-1) 367 | 368 | self.prot_emb = self.pooling(self.emb) 369 | if self.h_g is not None: 370 | self.prot_emb = self.amino_emb(self.prot_emb) 371 | # [B, h_a] or [B, h_g] 372 | #print(self.prot_emb.shape) 373 | self.class_logits = self.fc_classes(self.prot_emb) 374 | # [B, n_classes] 375 | self.class_probs = F.softmax(self.class_logits, dim=-1) 376 | if return_emb: 377 | return self.prot_emb, attn_probs 378 | else: 379 | return self.class_probs 380 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | # Author: Mahdi Ghorbani 2 | 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | if torch.cuda.is_available(): 11 | device = torch.device('cuda') 12 | print('cuda is is available') 13 | else: 14 | print('Using CPU') 15 | device = torch.device('cpu') 16 | 17 | def LinearLayer(d_in, d_out, bias=True, activation=None, dropout=0, weight_init='xavier'): 18 | ''' 19 | Linear layer function 20 | 21 | Parameters: 22 | --------------------- 23 | d_in: int, input dimension 24 | d_out: int, output dimension 25 | bias: bool, (default=True) Whether or not to add bias 26 | activation: torch.nn.Module() (default=None) 27 | activation function of the layer 28 | dropout: float (default=0) 
29 | the dropout to be added (an nn.Dropout is appended whenever dropout is not None, including the default 0) 30 | weight_init: str, int or float (default='xavier') 31 | 'xavier' -> xavier_uniform_, 'identity' -> eye_, None -> PyTorch default; an int or float gives a constant initialization. 32 | Any other value (e.g. a callable) leaves the default initialization unchanged. 33 | 34 | Returns: 35 | ---------------------- 36 | seq: list of torch.nn.Module instances 37 | the full linear layer including activations and optional dropout 38 | ''' 39 | seq = [nn.Linear(d_in, d_out, bias=bias)] 40 | if activation is not None: 41 | if isinstance(activation, nn.Module): 42 | seq += [activation] 43 | else: 44 | raise TypeError('Activation {} is not a valid torch.nn.Module'.format(str(activation))) 45 | # NOTE: dropout=0 still appends an identity nn.Dropout(0); ContinuousFilterConv wraps these lists in nn.Sequential, so dropping it would shift the module indices (and saved state_dict keys) 46 | if dropout is not None: 47 | seq += [nn.Dropout(dropout)] 48 | 49 | with torch.no_grad(): 50 | if weight_init == 'xavier': 51 | torch.nn.init.xavier_uniform_(seq[0].weight) 52 | if weight_init == 'identity': 53 | torch.nn.init.eye_(seq[0].weight) 54 | if weight_init not in ['xavier', 'identity', None]: 55 | if isinstance(weight_init, int) or isinstance(weight_init, float): 56 | torch.nn.init.constant_(seq[0].weight, weight_init) 57 | 58 | return seq 59 | 60 | class NeighborNormLayer(nn.Module): 61 | ''' Normalization layer that divides the output of a preceding layer by the number of neighbors (n_neighbors).
62 | ''' 63 | def __init__(self): 64 | super(NeighborNormLayer, self).__init__() 65 | 66 | def forward(self, input_features, n_neighbors): 67 | ''' Computes normalized output 68 | 69 | Parameters: 70 | ---------------- 71 | input_features: torch.tensor 72 | input tensor of features [n_frames, n_atoms, n_feats] 73 | n_neighbors: int, number of neighbors 74 | 75 | Returns: 76 | ---------------- 77 | normalized_features: torch.tensor, normalized input features 78 | ''' 79 | return input_features / n_neighbors 80 | 81 | 82 | def gather_nodes(nodes, neighbor_idx): 83 | ''' 84 | given node-features [batch-size, num-atoms, num-features] and neighbor_dix [batch-size, num-atoms, num-neighbors] 85 | this find the neighbors of each node and concatentates their features together: 86 | 87 | Returns: 88 | [batch-szie, num-atoms, num-neighbors, num-features] 89 | ''' 90 | neighbors_flat = neighbor_idx.view((neighbor_idx.shape[0], -1)) 91 | # [batch-size, num_nodes*num_neighbors] 92 | neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1,-1,nodes.size(2)) 93 | # [batch-size, num_nodes*num_neighbors, num_features] 94 | 95 | neighbor_features = torch.gather(nodes, 1, neighbors_flat) 96 | neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3]+[-1]) 97 | return neighbor_features 98 | 99 | def cat_neighbor_nodes(h_nodes, h_neighbors, E_idx): 100 | ''' 101 | given node embeddings [h_nodes] and neighbor embeddings [h_neighbors] and indexes this concatentates the features 102 | params: 103 | h_nodes: [B,N,h_nodes] 104 | h_neighbors: [B,N,M,h_edges] 105 | E_idx: [B,N,M] 106 | returns: 107 | concatenated features [B,N,M,h_nodes+h_edges] 108 | ''' 109 | h_nodes = gather_nodes(h_nodes, E_idx) 110 | h_nn = torch.cat([h_neighbors, h_nodes], -1) 111 | return h_nn 112 | 113 | 114 | 115 | class GraphAttentionLayer(nn.Module): 116 | def __init__(self, in_features, out_features, alpha, concat=True,): 117 | super(GraphAttentionLayer, self).__init__() 118 | self.in_features = 
in_features 119 | self.out_features = out_features 120 | self.alpha = alpha 121 | 122 | self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) 123 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 124 | self.a = nn.Parameter(torch.empty(size=(2*out_features, 1))) 125 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 126 | 127 | self.leakyrelu = nn.LeakyReLU(self.alpha) 128 | 129 | def forward(self, h, adj): 130 | Wh = torch.mm(h, self.W) 131 | Wh1 = torch.matmul(Wh, self.a[:self.out_features, :]) 132 | Wh2 = torch.matmul(Wh, self.a[:,self.out_features:,:]) 133 | # broadcast add 134 | e = Wh1 + Wh2 135 | e = self.leakyrelu(e) 136 | zero_vec = -9e15 * torch.ones_like(e) 137 | attention = torch.where(adj>0, e, zero_vec) 138 | attention = F.softmax(attention, dim=1) 139 | h_prime = torch.matmul(attention, Wh) 140 | 141 | return h_prime, attention 142 | 143 | 144 | class GATLayer(nn.Module): 145 | def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2): 146 | ''' 147 | inputs: 148 | c_in: dim of input features 149 | c_out: dim of output features 150 | num_head: number of attention heads. 
here we only use 1 151 | concat_heads: Whether to concat attention heads 152 | alpha: negative slope of LeakyRelu activation 153 | ''' 154 | super().__init__() 155 | self.num_heads = num_heads 156 | self.concat_heads = concat_heads 157 | if self.concat_heads: 158 | assert c_out % num_heads == 0, 'Number of outupt features must be multiple of the count of heads' 159 | c_out = c_out // num_heads 160 | 161 | self.projection = nn.Linear(c_in, c_out*num_heads) 162 | self.a = nn.Parameter(torch.Tensor(num_heads, 2*c_out)) # one per head 163 | self.leakyrelu = nn.LeakyReLU(alpha) 164 | 165 | # parameter initialization 166 | nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414) 167 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 168 | 169 | def forward(self, node_feats, adj_matrix, return_attn_probs=True): 170 | ''' 171 | Inputs: 172 | node_feats: input features of nodes [batch-size, num_nodes, c_in] 173 | adj_matrix: adjacency matrix [batch-size, num_nodes, num_nodes] with self connections 174 | return_attn_probs: If True, attention weights will be returned 175 | ''' 176 | batch_size, num_nodes = node_feats.size(0), node_feats.size(1) 177 | 178 | node_feats = self.projection(node_feats) 179 | node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1) 180 | 181 | # concatenate the node feature of all the nodes that are connected 182 | edges = adj_matrix.nonzero(as_tuple=False) # return edges 183 | node_feats_flat = node_feats.view(batch_size*num_nodes, self.num_heads, -1) 184 | 185 | edge_indices_row = edges[:,0] * num_nodes + edges[:,1] 186 | edge_indices_col = edges[:,0] * num_nodes + edges[:,2] 187 | a_input = torch.cat([ 188 | torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0), 189 | torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0)], dim=-1) 190 | 191 | # calculate attention MLP output (independent for each head) 192 | attn_logits = torch.einsum('bhc,hc->bh', a_input, self.a) 193 | attn_logits = 
self.leakyrelu(attn_logits) 194 | 195 | # map list of attention values back into a matrix; nonzero() and the boolean mask below both enumerate entries in row-major (b, i, j) order, which assumes adj_matrix holds only 0/1 values 196 | attn_matrix = attn_logits.new_zeros(adj_matrix.shape+(self.num_heads,)).fill_(-9e15) 197 | attn_matrix[adj_matrix[...,None].repeat(1,1,1,self.num_heads)==1] =attn_logits.reshape(-1) 198 | 199 | # softmax over the neighbor axis j (dim=2); non-edges hold -9e15 and get ~0 weight 200 | attn_probs = F.softmax(attn_matrix, dim=2) 201 | 202 | node_feats = torch.einsum('bijh,bjhc->bihc', attn_probs, node_feats) 203 | 204 | # If heads should be concatenated : 205 | if self.concat_heads: 206 | node_feats = node_feats.reshape(batch_size, num_nodes, -1) 207 | else: 208 | node_feats = node_feats.mean(dim=2) 209 | 210 | if return_attn_probs: 211 | return node_feats, attn_probs 212 | else: 213 | return node_feats 214 | 215 | 216 | class NeighborMultiHeadAttention(nn.Module): 217 | ''' MultiHeadAttention class for neighbors of every node in the graph representation 218 | 219 | multihead attention is defined on the edges between nodes in the neighborhood of every node 220 | 221 | queries: current node embedding 222 | keys: relational information r(i,j)=(h_j, e_ij) j from N(i,k) (NOTE: forward currently builds keys and values from the edge features h_E only) 223 | values: relational information same as keys 224 | 225 | 226 | parameters: 227 | ------------------ 228 | num_hidden: int, number of features of atoms 229 | num_in: int, number of gaussians for edge 230 | num_heads: int, number of heads for multi-head attention (num_hidden should be divisible by num_heads) 231 | 232 | Returns: 233 | ------------------ 234 | updated node embedding after multihead attention of neighbors (plus the head-averaged attention weights when return_attn_probs=True, the default) 235 | ''' 236 | def __init__(self, num_hidden, num_in, num_heads=4): 237 | super(NeighborMultiHeadAttention, self).__init__() 238 | self.num_heads = num_heads 239 | self.num_hidden = num_hidden 240 | 241 | # self-attention layers: {queries, keys, values, output} 242 | self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False) 243 | self.W_K = nn.Linear(num_in, num_hidden, bias=False) # make the edges the same size as node embedding size 244 | self.W_V = nn.Linear(num_in, num_hidden,
bias=False) # make the edges the same size as node embedding size 245 | self.W_O = nn.Linear(num_hidden, num_hidden, bias=False) # linear layer for the output 246 | return 247 | 248 | def _masked_softmax(self, attend_logits, mask_attend, dim=-1): 249 | ''' numercially stable masked softmax ''' 250 | # attend_logits: [batch-size, num_nodes, num_heads, num_neighbors] 251 | # mask_attend: [batch-size, num_nodes, num_heads, num_neighbors] 252 | negative_inf = np.finfo(np.float32).min 253 | attend_logits = torch.where(mask_attend, attend_logits, torch.tensor(negative_inf).to(device)) 254 | attend = F.softmax(attend_logits, dim) 255 | attend = mask_attend * attend 256 | return attend 257 | 258 | def forward(self, h_V, h_E, mask_attend=None, return_attn_probs=True): 259 | ''' self attention graph structure O(NK) 260 | 261 | parameters: 262 | h_V: node features [batch_size, num_nodes, n_features] 263 | h_E: neighbor featues [batch_size, num_nodes, num_neighbors, n_gaussians] 264 | mask_attend: mask for attention [batch_size, num_nodes, num_neighbors] 265 | Returns: 266 | h_V_t: node update after neighbor attention [batch-size, num_nodes, n_features] 267 | ''' 268 | # dimensions 269 | n_batch, n_nodes, n_neighbors = h_E.shape[:3] 270 | n_heads = self.num_heads 271 | d = int(self.num_hidden/n_heads) 272 | 273 | Q = self.W_Q(h_V.to(torch.float32)).view([n_batch, n_nodes, 1, n_heads, 1, d]) # [B,N,1,n_heads,1,d] 274 | # [batch-size, num-nodes, 1, num_heads, 1, d] 275 | K = self.W_K(h_E.to(torch.float32)).view([n_batch, n_nodes, n_neighbors, n_heads, d, 1]) 276 | # [batch-size, num-nodes, num_neighbors, num_heads, d, 1] 277 | V = self.W_V(h_E.to(torch.float32)).view([n_batch, n_nodes, n_neighbors, n_heads, d]) 278 | # [batch-size, num-nodes, num_neighbors, num_heads, d] 279 | 280 | # attention with scaled inner product between keys and queries 281 | attend_logits = torch.matmul(Q,K).view([n_batch, n_nodes, n_neighbors, n_heads]).transpose(-2,-1) 282 | # [batch-size, num-nodes, 
n_heads, num_nbrs] 283 | attend_logits = attend_logits / np.sqrt(d) # normalize 284 | #if mask_attend is not None: 285 | # masked softmax 286 | # mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1).to(device) 287 | # [batch-size, num_nodes, num_heads, num_neighbors] 288 | # attend = self._masked_softmax(attend_logits, mask) 289 | #else: 290 | attend = F.softmax(attend_logits, -1).to(device) 291 | # [batch-size, num-nodes, num-heads, num-neighbors] 292 | # NOTE: masking above is disabled, so mask_attend is ignored and every neighbor slot is attended 293 | h_V_update = torch.matmul(attend.unsqueeze(-2).to(torch.float32), V.transpose(2,3)) 294 | # multiply attention scores with the values 295 | h_V_update = h_V_update.view([n_batch, n_nodes, self.num_hidden]) 296 | h_V_update = self.W_O(h_V_update) 297 | if return_attn_probs: # attention averaged over heads -> [batch-size, num-nodes, num-neighbors] 298 | return h_V_update, attend.mean(dim=2) 299 | else: 300 | return h_V_update 301 | 302 | class GraphConvLayer(nn.Module): 303 | r''' Graph convolutional layer as introduced by Tian et al in 2019 for materials and reformulated for proteins. 304 | (forward additionally applies batch normalization to the concatenated gate/core pre-activations and to the aggregated message, then a ReLU on v_i + message) 305 | 306 | formulation: 307 | v_{i}^{k+1} = v_{i}^{k} + \sum_{j} w_{i,j}^k \circ g(z_{i,j}^k W_{c}^k + b_{c}^k) 308 | where: 309 | w_{i,j}^k = sigmoid(z_{i,j}^k W_{g}^k + b_{g}^k) and z_{i,j}^k = concat([v_{i}^k, v_{j}^k, u_{i,j}^k]) 310 | 311 | v_{i}^k is the node embedding of atom i at layer k, z_{i,j}^k is the concatenation of neighboring atoms and their bond features. 312 | g() denotes a non-linear activation function, w_{i,j}^k is an edge-gating mechanism to incorporate different interaction strengths 313 | among neighbors of a node.
314 | 315 | 316 | Parameters: 317 | --------------------- 318 | atom_emb_dim : int, atom embedding dimension 319 | bond_emb_dim: int, bond embedding dimension (number of gaussians) 320 | 321 | 322 | References: 323 | ---------------------- 324 | {Sanyal2020.04.06.028266, 325 | author = {Sanyal, Soumya and Anishchenko, Ivan and Dagar, Anirudh and Baker, David and Talukdar, Partha}, 326 | title = {ProteinGCN: Protein model quality assessment using Graph Convolutional Networks}, 327 | year = {2020}, 328 | doi = {10.1101/2020.04.06.028266}, 329 | publisher = {Cold Spring Harbor Laboratory}, 330 | URL = {https://www.biorxiv.org/content/early/2020/04/07/2020.04.06.028266}, 331 | journal = {bioRxiv} 332 | ''' 333 | def __init__(self, atom_emb_dim, bond_emb_dim): 334 | super(GraphConvLayer, self).__init__() 335 | self.h_a = atom_emb_dim 336 | self.h_b = bond_emb_dim 337 | self.fc_full = nn.Linear(2*self.h_a+self.h_b, 2*self.h_a) 338 | self.sigmoid = nn.Sigmoid() 339 | self.activation_hidden =nn.ReLU() 340 | self.bn_hidden = nn.BatchNorm1d(2*self.h_a) 341 | self.bn_output = nn.BatchNorm1d(self.h_a) 342 | self.activation_output = nn.ReLU() 343 | 344 | def forward(self, atom_emb, nbr_emb, nbr_adj_list): 345 | ''' Compute the GraphConvLayer 346 | 347 | Parameters: 348 | ------------------ 349 | atom_emb: torch.tensor, atom embeddings [batch-size, num_atoms, num_features] 350 | nbr_emb: torch.tensor, bond embedding of atom neighbors [batch-size, num_atom, num_neighbor, n_gaussians] 351 | nbr_adj_list: torch.tensor, indices for neighbors of a node [batch-size, num_atoms, num_neighbors] 352 | 353 | Returns: 354 | ------------------- 355 | out: atom embeddings after applying GraphConvLayer [batch-size, num_atom, num_features] 356 | ''' 357 | N, M = nbr_adj_list.shape[1:] 358 | # N is number of atoms and M is number of neighbors 359 | B = atom_emb.shape[0] # batch-size 360 | 361 | # gather the feature of neighbors into size [batch-size, num_atoms, num_neighbors, num_features] 362 | 
atom_nbr_emb = atom_emb[torch.arange(B).unsqueeze(-1), nbr_adj_list.to(torch.long).view(B,-1)].view(B,N,M,self.h_a).to(device) 363 | # concatenate the embedding of neighboring atoms with the bond embedding connecting the two 364 | # shape [batch-size, num-atoms, num_neighbors, 2*num_features + num_gaussians] 365 | total_nbr_emb = torch.cat([atom_emb.unsqueeze(2).expand(B,N,M,self.h_a), atom_nbr_emb, nbr_emb],dim=-1).to(torch.float32) 366 | 367 | # apply a linear layer 368 | total_gated_emb = self.fc_full(total_nbr_emb) 369 | total_gated_emb = self.bn_hidden(total_gated_emb.view(-1,self.h_a*2)).view(B,N,M,self.h_a*2) 370 | nbr_filter, nbr_core = total_gated_emb.chunk(2, dim=3) 371 | # Sigmoid function for edge-gating mechanism 372 | nbr_filter = self.sigmoid(nbr_filter) 373 | nbr_core = self.activation_hidden(nbr_core) 374 | # element-wise multiplication and aggregation of neighbors 375 | nbr_sumed = torch.sum(nbr_filter*nbr_core, dim=2) 376 | # apply batch-norm to output 377 | nbr_sumed = self.bn_output(nbr_sumed.view(-1, self.h_a)).view(B,N,self.h_a) 378 | # apply non-linear activation to output 379 | out = self.activation_output(atom_emb+nbr_sumed) 380 | return out 381 | 382 | class GaussianDistance(object): 383 | def __init__(self, dmin, dmax, step, var=None): 384 | 385 | ''' Expands ditsnces by gaussian basis functions 386 | parameters: 387 | ------------------- 388 | dmin: float, minimum distance between atoms to be considered for gaussian basis 389 | dmax: float, maximum distance between atoms to be considered for gaussian basis 390 | step: float, step size for the gaussian filter 391 | ''' 392 | assert dmin < dmax 393 | assert dmax - dmin > step 394 | self.filter = torch.arange(dmin, dmax+step, step) 395 | self.num_features = len(self.filter) 396 | if var is None: 397 | var = step 398 | self.var = var 399 | 400 | def expand(self, distance): 401 | ''' 402 | apply gaussian distance filter to a numpy array distance 403 | parameters: 404 | ----------------- 405 
| N: number of atoms 406 | M: number of neighbors 407 | B: batch-size 408 | distance: shape [B, N, M] 409 | 410 | returns: 411 | expanded distance with shape [B, N, M, bond_fea_len] 412 | ''' 413 | return torch.exp(-(torch.unsqueeze(distance,-1).to(device)-self.filter.to(device))**2/self.var**2) 414 | 415 | 416 | class ContinuousFilterConv(nn.Module): 417 | ''' 418 | Continuous filter convolution layer for SchNet as described by Schütt et al.(2018) 419 | A continuous-filter convolutional layer uses continuous radial basis functions for discrete data. (Schütt et al. (2018)) 420 | Continuous-filter convolution block consists of a filter generating network as follows: 421 | 422 | Filter generator: 423 | 1. get distances betwee nodes. 424 | 2. atom-wise/Linear layer with non-linear activation 425 | 3. atom-wise/Linear layer with non-linear activation 426 | 427 | The filter generator output is then multiplied element-wise with the continuous convolution filter as part of the interaction block. 428 | 429 | Parameters: 430 | ----------------- 431 | n_gaussians: int 432 | number of gaussian used in hte radial basis function. needed to determine input feature size of first dense layer 433 | n_filters: int 434 | number of filters that will be created. Also determines the output size. Needs to be the same size as the features of residual 435 | connection in the interaction block. 436 | activation: nn.Module 437 | Activation function for the filter generating network. 438 | 439 | Notes: 440 | ------------------ 441 | Following current implementation in SchNetPack, the last linear layer of the filter generator does not contain an activation 442 | function. This allows the filter generator to contain negative values. 
443 | 444 | References: 445 | 446 | ''' 447 | def __init__(self, n_gaussians, n_filters, activation=nn.Tanh(), normalization_layer=None): 448 | 449 | super(ContinuousFilterConv, self).__init__() 450 | filter_layers = LinearLayer(n_gaussians, n_filters, bias=True, activation=activation) 451 | filter_layers += LinearLayer(n_filters, n_filters, bias=True) # no activation here 452 | self.filter_generator = nn.Sequential(*filter_layers) 453 | 454 | self.nbr_filter = nn.Parameter(torch.Tensor(n_filters, 1)) 455 | nn.init.xavier_uniform_(self.nbr_filter.data, gain=1.414) 456 | 457 | if normalization_layer is not None: 458 | self.normalization_layer = normalization_layer 459 | else: 460 | self.normalization_layer = None 461 | 462 | def forward(self, features, rbf_expansion, neighbor_list): 463 | ''' Compute convolutional block 464 | Parameters: 465 | ------------- 466 | features: torch.Tensor 467 | Feature vector of size [n_frames, n_atoms, n_features] 468 | rbf_expansion: torch.Tensor 469 | Gaussian expansion of bead distances of size, [n_frames, n_atoms, n_neighbors, n_gaussians] 470 | neighbor_list: torch.Tensor 471 | indices of all neighbors of each bead size [n_frames, n_atoms, n_neighbors] 472 | 473 | Returns: 474 | ------------- 475 | aggregated features: torch.Tensor 476 | Residual features of shape [n_frames, n_atoms, n_features] 477 | ''' 478 | 479 | # generate convolutional filter of size [n_frames, n_atoms, n_neighbors, n_features] 480 | 481 | conv_filter = self.filter_generator(rbf_expansion.to(torch.float32)) 482 | 483 | # Feature tensor needs to also be transformed from [n_frames, n_atoms, n_features] 484 | # to [n_frames, n_atoms, n_neighbors, n_features] 485 | n_batch, n_atoms, n_neighbors = neighbor_list.size() 486 | 487 | # size [n_frames, n_atoms*n_neighbors, 1] 488 | neighbor_list = neighbor_list.reshape(-1, n_atoms*n_neighbors, 1) 489 | 490 | # size [n_frames, natoms*n_neighbors, n_features] 491 | neighbor_list = neighbor_list.expand(-1, -1, 
features.size(2)) 492 | 493 | # Gather the features into the respective places in the neighbor list 494 | neighbor_features = torch.gather(features, 1, neighbor_list.to(torch.int64)) 495 | # Reshape back to [n_frames, n_atoms, n_neighbors, n_features] for element-wise multiplication 496 | 497 | neighbor_features = neighbor_features.reshape(n_batch, n_atoms, n_neighbors, -1) 498 | 499 | # element-wise multiplication of the features with the convolutional filter 500 | conv_features = neighbor_features * conv_filter 501 | # [B, N, M, n_features] 502 | #nbr_filter = self.nbr_filter(conv_features).view([n_batch, n_atoms, 1, n_neighbors]) 503 | nbr_filter = torch.matmul(conv_features, self.nbr_filter).view([n_batch, n_atoms, n_neighbors]) 504 | # [B, N, 1, M] 505 | nbr_filter = F.softmax(nbr_filter, -1).to(device) 506 | nbr_filter = nbr_filter.view([n_batch, n_atoms, n_neighbors]) 507 | # [B, N, M] 508 | aggregated_features = torch.einsum('bij,bijc->bic',nbr_filter, conv_features) 509 | 510 | 511 | # attention for pooling 512 | 513 | # aggregate/pool the features from [n_frames, n_atoms, n_neighbors, n_features] to [n_frames, n_atoms, n_features] 514 | #aggregated_features = torch.sum(conv_features, dim=2) 515 | 516 | ####### later using attention on the connectivities so we need this now ######### 517 | #aggregated_features = conv_features 518 | 519 | if self.normalization_layer is not None: 520 | if isinstance(self.normalization_layer, NeighborNormLayer): 521 | return self.normalization_layer(aggregated_features, n_neighbors) 522 | else: 523 | return self.normalization_layer(aggregated_features) 524 | else: 525 | return aggregated_features, nbr_filter 526 | 527 | 528 | class InteractionBlock(nn.Module): 529 | ''' 530 | SchNet interaction block as described by Schütt et al. (2018). 531 | 532 | An interaction block consists of : 533 | 1. Atom-wise/Linear layer without activation function 534 | 2. 
Continuous filter convolution, which is a filter-generator multiplied element-wise with the output of previous layer 535 | 3. Atom-wise/Linear layer with the activation 536 | 4. Atom-wise/Linear layer without activation 537 | 538 | The output of an interaction block will then be used to form an additive residual connection with the original input features, 539 | [x'1, ..., x'n] 540 | 541 | Parameters: 542 | --------------- 543 | n_inputs: int, number of input features, determines input size for the initial linear layer 544 | n_gaussians: int, number of gaussians that has been used in the radial basis function. needed to determine the input size of the continuous 545 | filter convolution. 546 | n_filters: int, number of filters that will be created in the continuous filter convolution. The same feature size will be used for the output 547 | linear layers of the interaction block. 548 | activation: nn.Module activation function for the atom-wise layers. 549 | normalization_layer: nn.Module (default=None) 550 | normalization layer to be applied to the output of the ContinuousFilterConvolution 551 | 552 | The residul connection will be added later in model module between the interaction blocks 553 | 554 | References: 555 | ---------------- 556 | K.T. Schütt. P.-J. Kindermans, H. E. Sauceda, S. Chmiela, 557 | A. Tkatchenko, K.-R. Müller. (2018) 558 | SchNet - a deep learning architecture for molecules and materials. 559 | The Journal of Chemical Physics. 
560 | https://doi.org/10.1063/1.5019779 561 | ''' 562 | 563 | def __init__(self, n_inputs, n_gaussians, n_filters, activation=nn.Tanh(), normalization_layer=None): 564 | super(InteractionBlock, self).__init__() 565 | 566 | self.initial_dense = nn.Sequential(*LinearLayer(n_inputs, n_filters, bias=False, activation=None)) 567 | 568 | self.cfconv = ContinuousFilterConv(n_gaussians=n_gaussians, 569 | n_filters=n_filters, 570 | activation=activation, 571 | normalization_layer=normalization_layer) 572 | 573 | #self.attn = AttnBlock(n_filters, n_filters) 574 | 575 | 576 | # look here if adding layers here will give any problems 577 | output_layers = LinearLayer(n_filters, n_filters, bias=True, activation=activation) 578 | output_layers += LinearLayer(n_filters, n_filters, bias=True) 579 | self.output_dense = nn.Sequential(*output_layers) 580 | 581 | @staticmethod 582 | def convert_adj(mat): 583 | # convert the nbr_adj_list matrix to an adjacency matrix 584 | mat = mat.to(torch.long) 585 | adj = torch.zeros((mat.shape[0], mat.shape[1], mat.shape[1])) 586 | for i in range(mat.shape[0]): 587 | adj[i][torch.arange(adj.shape[1])[:,None],mat[i]] = 1 588 | adj[i] += torch.eye(mat.shape[1]) 589 | return adj.to(device) 590 | 591 | 592 | def forward(self, features, rbf_expansion, neighbor_list): 593 | ''' Compute interaction block 594 | 595 | Parameters: 596 | ----------------- 597 | features: torch.Tensor 598 | Input features from an embedding or ineteraction layer. [n_frames, n_atom, n_features] 599 | rbf_expansion: torch.Tensor, 600 | Radial basis function expansion of distances [n_frames, n_atoms, n_neighbors, n_gaussians] 601 | neighbor_list: torch.Tensor 602 | Indices of all neighbors of each atom [n_frames, n_atoms, n_neighbors] 603 | 604 | Returns: 605 | ----------------- 606 | output_features: torch.Tensor 607 | Output of an interaction block. This output can be used to form a residual connection with the output of 608 | a prior embedding/interaction layer. 
[n_frames, n_atoms, n_filters] 609 | ''' 610 | init_feature_output = self.initial_dense(features) 611 | conv_output, attn = self.cfconv(init_feature_output.to(torch.float32), rbf_expansion.to(torch.float32), neighbor_list) 612 | #adj = self.convert_adj(neighbor_list) 613 | #conv_output_t, attn = self.attn(conv_output, neighbor_list) 614 | output_features = self.output_dense(conv_output).to(torch.float32) 615 | return output_features, attn 616 | -------------------------------------------------------------------------------- /Example_TrpCage_2dEmbedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "b513f3f0", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import pyemma" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "8c27e196", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "embd = np.load('embeddings.npz')['arr_0']" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "id": "70b72457", 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "data": { 33 | "image/png": "\n", 34 | "text/plain": [ 35 | "
" 36 | ] 37 | }, 38 | "metadata": { 39 | "needs_background": "light" 40 | }, 41 | "output_type": "display_data" 42 | } 43 | ], 44 | "source": [ 45 | "from pylab import *\n", 46 | "rc('axes', linewidth=2)\n", 47 | "f, ax = plt.subplots(1,1, figsize=(8,5), dpi=100)\n", 48 | "pyemma.plots.plot_free_energy(embd[:,0], embd[:,1],ax=ax)\n", 49 | "\n", 50 | "fontsize = 10\n", 51 | "for tick in ax.xaxis.get_major_ticks():\n", 52 | " tick.label1.set_fontsize(fontsize)\n", 53 | " tick.label1.set_fontweight('bold')\n", 54 | "for tick in ax.yaxis.get_major_ticks():\n", 55 | " tick.label1.set_fontsize(fontsize)\n", 56 | " tick.label1.set_fontweight('bold')\n", 57 | " \n", 58 | "plt.savefig('trpcage_free_energy.png')" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "id": "18b7134a", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "transformed = []\n", 69 | "for i in range(1,2):\n", 70 | " trans_temp = []\n", 71 | " trans_temp.append(np.load('trans_'+str(i)+'.npz')['arr_0'])\n", 72 | " transformed.append(trans_temp) " 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 5, 78 | "id": "feb034fc", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "y_m_t = []\n", 83 | "inds = []\n", 84 | "thresh = 0.95\n", 85 | "exp = 0\n", 86 | "for i in range(1):\n", 87 | " tmp = []\n", 88 | " tmp_ind = []\n", 89 | " for j in range(len(transformed[exp][i])):\n", 90 | " if transformed[exp][i][j].max()>thresh:\n", 91 | " tmp.append(np.argmax(transformed[exp][i][j]))\n", 92 | " tmp_ind.append(j)\n", 93 | " tmp = np.array(tmp)\n", 94 | " tmp_ind = np.array(tmp_ind)\n", 95 | " \n", 96 | " y_m_t.append(tmp)\n", 97 | " inds.append(tmp_ind)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 6, 103 | "id": "d686d6a3", 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "[array([ 0, 1, 2, ..., 208797, 208798, 208799])]" 110 | ] 111 | }, 112 | "execution_count": 6, 
113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "inds" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 7, 124 | "id": "88d8402e", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "my_dict = {0:0, 1:2, 2:1, 3:3, 4:4}" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 8, 134 | "id": "2d85b2b0", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "y_m_t_t = np.vectorize(my_dict.get)(y_m_t[0])" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 10, 144 | "id": "5b0a6528", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "text/plain": [ 150 | "
" 151 | ] 152 | }, 153 | "metadata": {}, 154 | "output_type": "display_data" 155 | }, 156 | { 157 | "data": { 158 | "image/png": "\n", 159 | "text/plain": [ 160 | "
" 161 | ] 162 | }, 163 | "metadata": { 164 | "needs_background": "light" 165 | }, 166 | "output_type": "display_data" 167 | } 168 | ], 169 | "source": [ 170 | "from pylab import *\n", 171 | "plt.set_cmap('jet')\n", 172 | "rc('axes', linewidth=2)\n", 173 | "f, ax = plt.subplots(1,1, figsize=(8,5))\n", 174 | "pyemma.plots.plot_state_map(embd[inds[0],0], embd[inds[0],1], y_m_t_t, ax=ax, ncontours=100)\n", 175 | "\n", 176 | "fontsize = 10\n", 177 | "for tick in ax.xaxis.get_major_ticks():\n", 178 | " tick.label1.set_fontsize(fontsize)\n", 179 | " tick.label1.set_fontweight('bold')\n", 180 | "for tick in ax.yaxis.get_major_ticks():\n", 181 | " tick.label1.set_fontsize(fontsize)\n", 182 | " tick.label1.set_fontweight('bold')\n", 183 | " \n", 184 | "\n", 185 | "pyemma.plots.plot_state_map(embd[:,0], embd[:,1], np.zeros(embd[:,0].shape[0]), ax=ax, alpha=0.1, mask=True, cbar=False,cmap='Greys')\n", 186 | "plt.savefig('trpcage_states.png')" 187 | ] 188 | } 189 | ], 190 | "metadata": { 191 | "kernelspec": { 192 | "display_name": "Python 3 (ipykernel)", 193 | "language": "python", 194 | "name": "python3" 195 | }, 196 | "language_info": { 197 | "codemirror_mode": { 198 | "name": "ipython", 199 | "version": 3 200 | }, 201 | "file_extension": ".py", 202 | "mimetype": "text/x-python", 203 | "name": "python", 204 | "nbconvert_exporter": "python", 205 | "pygments_lexer": "ipython3", 206 | "version": "3.8.12" 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 5 211 | } 212 | --------------------------------------------------------------------------------