├── README.md
├── SynaptotagminC2A_dataPreparation.ipynb
├── masks.py
├── 10Cube.ipynb
├── Toymodel_2Systems.ipynb
├── SynaptotagminC2A.ipynb
├── examples.py
└── ivampnets.py

/README.md:
--------------------------------------------------------------------------------
1 | # iVAMPnets
2 | 
3 | Codebase for the iVAMPnets estimator and model, including the classes for constructing the masks for toy models and real protein applications.
4 | The implemented methods allow one to decompose a possibly high-dimensional system into its weakly coupled or independent subsystems. This makes the downstream estimation of the kinetic models much more data efficient than estimating a global kinetic model, which might not even be feasible. The whole pipeline is an end-to-end deep learning framework that lets you define your own network architectures for the kinetics estimation of each subsystem.
5 | The data for the synaptotagmin C2A system is available upon request. The code is designed to reproduce the results of our paper "Deep learning to decompose macromolecules into independent Markovian domains" (https://www.biorxiv.org/content/10.1101/2022.03.30.486366v1) and is based on the deeptime package (see https://deeptime-ml.github.io/latest/index.html).
6 | 
7 | The code includes:
8 | 1. (ivampnets.py) The definition of the iVAMPnets estimator class, which allows a given model to be fitted to simulation data, and of the iVAMPnets model class - the resulting model - which can then be used to estimate transition matrices, implied timescales, eigenfunctions, etc. A condensed usage sketch is given at the end of this README.
9 | 2. (masks.py) The definition of the mask modules, which can be used to give the modeler an intuition of which part of the global system is assigned to which subsystem.
10 | 3. (examples.py) Helper functions to generate the data for the toy systems and to plot some results.
11 | 4. (Toymodel_2Systems.ipynb) Notebook to reproduce the results for a simple, truly independent 2D system. Typical runtime (cpu): 2 min
12 | 5. (10Cube.ipynb) Notebook to reproduce the results for the 10-Cube example. Typical runtime (cpu): 5 min
13 | 6. (SynaptotagminC2A.ipynb) Notebook to reproduce the results for a protein example. The data of the synaptotagmin C2A domain is available upon request. Typical runtime (cuda): 1.5 hours
14 | 
15 | The code was executed using the following package versions on a Linux computer (Debian bullseye):
16 | 
17 | ```
18 | python=3.6 or higher
19 | jupyterlab=3.2.0 or jupyter=1.0.0
20 | 
21 | pytorch=1.8.0
22 | deeptime=0.2.9
23 | numpy=1.19.5
24 | matplotlib=3.1.3
25 | ```
26 | optional:
27 | ```
28 | tensorboard=2.6.0
29 | h5py=1.10.4
30 | ```
31 | 
32 | ## Installation instructions
33 | 
34 | The software dependencies can be installed with anaconda / miniconda. If you do not have miniconda or anaconda, please follow the instructions here: https://conda.io/miniconda.html
35 | 
36 | The following command can be used to create a new conda environment and install all dependencies for the ivampnets scripts:
37 | ```bash
38 | conda create -n ivampnets pytorch=1.8.0 deeptime=0.2.9 numpy=1.19.5 matplotlib=3.1.3 jupyter h5py -c conda-forge
39 | ```
40 | The new conda environment can be activated with
41 | ```bash
42 | conda activate ivampnets
43 | ```
44 | 
45 | 
46 | If you are already a conda and Jupyter notebook user with several environments, you can register a Python kernel for this environment via
47 | ```bash
48 | python -m ipykernel install --user --name ivampnets
49 | ```
50 | This repository, including the Python scripts and Jupyter notebooks, can be downloaded with
51 | ```bash
52 | git clone git@github.com:markovmodel/ivampnets.git
53 | ```
54 | 
55 | The following command will start the Jupyter notebook server:
56 | ```bash
57 | jupyter notebook
58 | ```
59 | 
60 | Your browser should open with a list of notebooks once you navigate into the repository directory. If it opens in the wrong browser, add for example `--browser=firefox`, or copy and paste the URL into the browser of your choice.
61 | 
62 | The typical install time ranges from 5 minutes for conda users to 20 minutes if conda has to be set up from scratch.
63 | 
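64 | ## Example usage
65 | 
66 | The following condensed sketch (adapted from Toymodel_2Systems.ipynb and SynaptotagminC2A.ipynb) illustrates how the mask, the subsystem lobes and the iVAMPnet estimator fit together. The data arrays and all hyperparameter values below are placeholders and examples only - see the notebooks for the settings used to reproduce the paper results.
67 | 
68 | ```python
69 | import numpy as np
70 | import torch
71 | import torch.nn as nn
72 | from torch.utils.data import DataLoader
73 | from deeptime.util.data import TrajectoryDataset
74 | from masks import Mask
75 | from ivampnets import iVAMPnet
76 | 
77 | tau = 1                  # lag time in frames
78 | output_sizes = [3, 2]    # number of Markov states per subsystem
79 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
80 | 
81 | # placeholder data - replace with your own featurized trajectories of shape (n_frames, n_features)
82 | observable_traj = np.random.randn(100000, 20).astype('float32')
83 | observable_traj_valid = np.random.randn(10000, 20).astype('float32')
84 | 
85 | # time-lagged training and validation sets
86 | train_data = TrajectoryDataset(lagtime=tau, trajectory=observable_traj)
87 | val_data = TrajectoryDataset(lagtime=tau, trajectory=observable_traj_valid)
88 | loader_train = DataLoader(train_data, batch_size=10000, shuffle=True)
89 | loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)
90 | 
91 | # attention mask that assigns the input features to the subsystems
92 | input_size = observable_traj.shape[1]
93 | mask = Mask(input_size, len(output_sizes), mean=torch.Tensor(train_data.data.mean(0)),
94 |             std=torch.Tensor(train_data.data.std(0)), noise=1., cutoff=0.9, device=device)
95 | mask.to(device=device)
96 | 
97 | # one VAMPnet lobe (chi network) per subsystem, each ending in a softmax over its states
98 | lobes = [nn.Sequential(nn.Linear(input_size, 30), nn.ELU(),
99 |                        nn.Linear(30, 30), nn.ELU(),
100 |                        nn.Linear(30, size), nn.Softmax(dim=1)).to(device=device)
101 |          for size in output_sizes]
102 | 
103 | # estimator; fit() trains mask and lobes, fetch_model() returns the fitted iVAMPnets model
104 | ivampnet = iVAMPnet(lobes, mask, device, learning_rate=0.005, epsilon=1e-8, score_mode='regularize')
105 | model = ivampnet.fit(loader_train, n_epochs=10, validation_loader=loader_val, mask=True,
106 |                      lam_decomp=2., lam_trace=1., start_mask=0, end_trace=2,
107 |                      tb_writer=None, clip=False).fetch_model()
108 | 
109 | mask.noise = 0  # switch off the mask regularization noise before analysing the model
110 | T_list = model.get_transition_matrix(val_data.data, val_data.data_lagged)  # one transition matrix per subsystem
111 | its = model.timescales(val_data.data, val_data.data_lagged, tau)           # implied timescales per subsystem
112 | ```
113 | 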
--------------------------------------------------------------------------------
/SynaptotagminC2A_dataPreparation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2022-07-26T15:51:59.813087Z",
9 | "start_time": "2022-07-26T15:51:43.109613Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "import numpy as np\n",
15 | "import pyemma\n",
16 | "from tqdm.notebook import tqdm\n",
17 | "import mdtraj\n",
18 | "import itertools\n",
19 | "import h5py\n",
20 | "from glob import glob"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {
27 | "ExecuteTime": {
28 | "end_time": "2022-07-26T15:52:01.029155Z",
29 | "start_time": "2022-07-26T15:52:00.993261Z"
30 | }
31 | },
32 | "outputs": [],
33 | "source": [
34 | "# file paths\n",
35 | "topfile = 'setup/hsynapto.pdb'\n",
36 | "syt_files = glob('0cal_dyn*.1/hsynapto-protein-stride10.xtc')\n",
37 | "\n",
38 | "outfile = 'syt_0cal_internal1by1_stride100.hdf5'"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {
45 | "ExecuteTime": {
46 | "end_time": "2022-07-26T15:52:02.464546Z",
47 | "start_time": "2022-07-26T15:52:01.770482Z"
48 | }
49 | },
50 | "outputs": [],
51 | "source": [
52 | "# define pyemma featurizer\n",
53 | "feat = pyemma.coordinates.featurizer(topfile)\n",
54 | "\n",
55 | "# add pairs of residues, exclude first and last 3 residues\n",
56 | "pairs = feat.pairs(np.arange(3, feat.topology.n_residues - 3), excluded_neighbors=5)\n",
57 | "feat.add_residue_mindist(residue_pairs=pairs)"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {
64 | "ExecuteTime": {
65 | "end_time": "2022-07-26T15:53:38.939695Z",
66 | "start_time": "2022-07-26T15:53:38.788757Z"
67 | }
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# create iterator\n",
72 | "data_source = pyemma.coordinates.source(syt_files, feat)"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {
79 | "ExecuteTime": {
80 | "end_time": "2022-07-26T15:54:11.197537Z",
81 | "start_time": "2022-07-26T15:53:39.612383Z"
82 | }
83 | },
84 | "outputs": [],
85 | "source": [
86 | "### process data with featurizer and write to disk\n",
87 | "\n",
88 | "# note that stride parameter here must be multiplied by the stride on the\n",
89 | "# trajectories that we're loading (which is 10),\n",
90 | "# i.e., loading with stride 10 here is a total stride of 100. compare `outfile`\n",
91 | "\n",
92 | "it = data_source.iterator(stride=10, chunk=1000)\n",
93 | "\n",
94 | "with h5py.File(outfile, \"w\") as f:\n",
95 | " last_trajid = -1\n",
96 | " for trajid, chunk in tqdm(it, total=it.n_chunks):\n",
97 | " \n",
98 | " if last_trajid < trajid:\n",
99 | " if last_trajid != -1:\n",
100 | " dset.flush()\n",
101 | " dset = f.create_dataset(syt_files[trajid].split('/')[-2], \n",
102 | " shape=(it.trajectory_length(), feat.dimension()), \n",
103 | " dtype=np.float32)\n",
104 | " start = 0\n",
105 | " last_trajid = trajid\n",
106 | " dset[it.pos:it.pos + it.chunksize if it.pos + it.chunksize < it.trajectory_length() else None] = chunk\n",
107 | " start += 1"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": []
116 | }
117 | ],
118 | "metadata": {
119 | "kernelspec": {
120 | "display_name": "py37_mar20",
121 | "language": "python",
122 | "name": "py37_mar20"
123 | },
124 | "language_info": {
125 | "codemirror_mode": {
126 | "name": "ipython",
127 | "version": 3
128 | },
129 | "file_extension": ".py",
130 | "mimetype": "text/x-python",
131 | "name": "python",
132 | "nbconvert_exporter": "python",
133 | "pygments_lexer": "ipython3",
134 | "version": "3.7.6"
135 | }
136 | },
137 | "nbformat": 4,
138 | "nbformat_minor": 2
139 | }
140 | 
--------------------------------------------------------------------------------
/masks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | import matplotlib.pyplot as plt
6 | 
7 | class Mean_std_layer(nn.Module):
8 |     """ Custom Linear layer for subtracting the mean and dividing by the std
9 | 
10 |     Parameters
11 |     ----------
12 |     input_size: int
13 |         The input size.
14 |     mean: torch.Tensor
15 |         The mean values of all training points of the input features. Should have the size (1, input_size)
16 |     std: torch.Tensor
17 |         The std values of all training points of the input features. Should have the size (1, input_size)
18 |     """
19 |     def __init__(self, input_size, mean=None, std=None):
20 |         super().__init__()
21 |         self.input_size = input_size
22 |         if mean is None:
23 |             mean = torch.zeros((1, input_size))
24 |         self.weights_mean = nn.Parameter(mean, requires_grad=False) # nn.Parameter is a Tensor that's a module parameter.
25 |         if std is None:
26 |             std = torch.ones((1, input_size))
27 |         self.weights_std = nn.Parameter(std, requires_grad=False)
28 | 
29 |     def forward(self, x):
30 |         y = (x-self.weights_mean)/self.weights_std
31 |         return y
32 | 
33 |     def set_both(self, mean, std):
34 |         new_params = [mean, std]
35 |         with torch.no_grad():
36 |             for i, param in enumerate(self.parameters()):
37 |                 new_param = new_params[i]
38 |                 param.copy_(torch.Tensor(new_param[None,:]))
39 | 
40 | 
41 | class Mask(torch.nn.Module):
42 |     ''' Mask, which acts directly on the input features.
43 | 
44 |     Parameters
45 |     ----------
46 |     input_size: int
47 |         Feature input size.
48 |     N: int
49 |         Number of subsystems.
50 |     factor_fake: float, default=1.
51 |         Factor for how strongly the fake subsystem competes for the input space. Makes the mask sparser for the real subsystems.
52 |     noise: float, default=0.
53 |         Regularize the mask by adding noise to the input, so that the downstream lobes cannot recover inputs with low importance weights.
54 |         The larger the noise, the sharper the weight assignment of the mask will become.
55 |     cutoff: float, must be between 0 and 1
56 |         Cutoff below which an attention weight is set to zero. A totally uninformative weight would be one, which is how
57 |         the mask is initialized.
58 |     mean: torch.Tensor
59 |         The mean values of all training points of the input features. Should have the size (1, input_size)
60 |     std: torch.Tensor
61 |         The std values of all training points of the input features. Should have the size (1, input_size)
62 |     '''
63 |     def __init__(self, input_size, N, factor_fake=1., noise=0., cutoff=0.9,
64 |                  mean=None, std=None, device='cpu'):
65 |         super(Mask, self).__init__()
66 |         self.input_size = input_size
67 |         self.normalizer = Mean_std_layer(input_size, mean, std)
68 |         self.factor_fake = factor_fake
69 |         self.N = N
70 |         list_weights = []
71 |         for n in range(self.N):
72 |             alpha = torch.ones((1, input_size, 1))
73 |             weight = torch.nn.Parameter(data=alpha, requires_grad=True)
74 |             list_weights.append(weight)
75 |         self.list_weights = nn.ParameterList(list_weights)
76 |         self.noise = noise
77 |         self.cutoff = cutoff
78 |         self.device = device
79 |     def forward(self, x):
80 |         ''' Applies the attention weights to all inputs and adds the defined noise. Furthermore, it
81 |         normalizes the input to be approximately Gaussian.
82 |         '''
83 |         weight_sf = self.get_softmax()
84 |         prod = self.N + 1
85 |         # first remove mean and std
86 |         x = self.normalizer(x)
87 | 
88 |         masked_x = x[:,:,None] * weight_sf * prod  # include factor
89 | 
90 |         if self.noise > 0.:
91 |             max_attention_value = torch.max(weight_sf, dim=1, keepdim=True)[0].detach()
92 |             shape = weight_sf.shape
93 |             # shape = (x.shape[0], weight_sf.shape[1], weight_sf.shape[2])
94 |             random_numbers = torch.randn(shape, device=self.device) * self.noise
95 |             masked_x += (1 - weight_sf/max_attention_value) * random_numbers
96 | 
97 |         # split them for each subsystem
98 |         masked_list = torch.split(masked_x, 1, dim=2)
99 |         return masked_list
100 | 
101 |     def get_softmax(self):
102 |         ''' Estimates the attention weight for each input and subsystem.
103 |         '''
104 |         weights_all = []
105 |         for param in self.list_weights:
106 |             # first take a softmax over the input feature dimension to make all weights positive
107 |             weights_all.append(F.softmax(param, dim=1)*self.input_size) # the factor makes them on average around 1
108 |         weights_per_N = torch.cat(weights_all, dim=2) # dim: 1 x input_size x N
109 |         # add a fake subsystem
110 |         fake_axis = torch.ones_like(self.list_weights[0])*self.factor_fake
111 |         weights_per_N_fake = torch.cat([weights_per_N, fake_axis], dim=2)
112 | 
113 |         # normalize them along the subsystem axis
114 |         weights_per_N_fake = torch.relu(weights_per_N_fake-self.cutoff) # set all weights below the cutoff to zero
115 |         weights_per_N_fake = weights_per_N_fake / torch.sum(weights_per_N_fake, dim=2, keepdims=True) # normalize them to 1
116 |         # remove fake axis
117 |         weight_sf = weights_per_N_fake[:,:,:self.N]
118 | 
119 |         return weight_sf
120 | 
121 | 
122 | class Mask_proteins(torch.nn.Module):
123 |     ''' Mask, which acts on protein residue distances.
124 | 
125 |     Parameters
126 |     ----------
127 |     input_size: int
128 |         Feature input size.
129 |     N: int
130 |         Number of subsystems.
131 |     skip_res: int
132 |         How many residues at the ends of the amino acid chain are neglected for the distance calculation.
133 |     patchsize: int
134 |         Size of the window which slides over the amino acid chain.
135 |     skip: int
136 |         How many residues the window is advanced in each step. As a consequence, skip consecutive residues share the same
137 |         attention weight.
138 |     factor_fake: float, default=3.
139 |         Factor for how strongly the fake subsystem competes for the input space. Makes the mask sparser for the real subsystems.
140 |     noise: float, default=0.
141 |         Regularize the mask by adding noise to the input, so that the downstream lobes cannot recover inputs with low importance weights.
142 |         The larger the noise, the sharper the weight assignment of the mask will become.
143 |     cutoff: float, must be between 0 and 1
144 |         Cutoff below which an attention weight is set to zero. A totally uninformative weight would be one, which is how
145 |         the mask is initialized.
146 |     mean: torch.Tensor
147 |         The mean values of all training points of the input features. Should have the size (1, input_size)
148 |     std: torch.Tensor
149 |         The std values of all training points of the input features. Should have the size (1, input_size)
150 |     '''
151 |     def __init__(self, input_size, N, skip_res, patchsize, skip, factor_fake=3.,
152 |                  noise=0., cutoff=0.5, mean=None, std=None, device='cpu'):
153 |         super(Mask_proteins, self).__init__()
154 | 
155 |         self.device = device
156 |         self.normalizer = Mean_std_layer(input_size, mean, std)
157 |         self.noise = noise
158 |         self.patchsize = patchsize
159 |         self.skip = skip
160 |         self.factor_fake = factor_fake
161 |         self.N = N
162 |         self.cutoff = cutoff
163 |         self.skip_res = skip_res
164 |         self.n_residues = int(-1/2 + np.sqrt(1/4+input_size*2) + self.skip_res)
165 |         print('Number of residues is: {}'.format(self.n_residues))
166 |         self.residues_1 = []
167 |         self.residues_2 = []
168 | 
169 |         self.nb_per_res = int(np.ceil(self.patchsize/self.skip)) # number of windows covering each residue
170 |         self.bs_per_res = np.empty((self.n_residues, self.nb_per_res), dtype=int)
171 | 
172 |         self.balance = (self.n_residues%self.skip)//2 # how much to shift the windows to make them symmetric at the ends
173 |         for i in range(self.n_residues):
174 |             start = (i+self.balance)//self.skip # residues within the same skip get the same value
175 |             self.bs_per_res[i] = np.arange(start, start+self.nb_per_res)
176 | 
177 |         self.number_weights = self.bs_per_res[-1,-1]+1
178 |         # get the indexes of the residues which are part of the distances in the input
179 |         for n1 in range(self.n_residues-self.skip_res):
180 |             for n2 in range(n1+self.skip_res, self.n_residues):
181 |                 self.residues_1.append(n1)
182 |                 self.residues_2.append(n2)
183 | 
184 |         # initialize the weights needed for the windows.
185 |         list_weights = []
186 |         for n in range(self.N):
187 |             alpha = torch.ones((1, self.number_weights, 1))
188 | 
189 |             weight = torch.nn.Parameter(data=alpha, requires_grad=True)
190 |             list_weights.append(weight)
191 |         self.list_weights = nn.ParameterList(list_weights)
192 | 
193 | 
194 |     def forward(self, x):
195 |         ''' Applies the attention weights of each residue to all distances and adds the defined noise. Furthermore, it
196 |         normalizes the input to be approximately Gaussian.
197 |         '''
198 |         # first remove mean and std
199 |         x = self.normalizer(x)
200 |         # get the weights for each residue
201 |         weights_for_res = self.get_softmax()
202 | 
203 |         prod = self.N + 1 # plus one because of the fake subsystem
204 |         # each input feature is a distance, so its weight combines the weights of its two residues
205 |         weight_1 = weights_for_res[self.residues_1] * prod
206 |         weight_2 = weights_for_res[self.residues_2] * prod
207 |         alpha = weight_1[None,:,:] * weight_2[None,:,:]
208 | 
209 | 
210 |         masked_x = x[:,:,None] * alpha
211 |         if self.noise > 0.:
212 |             max_attention_value = torch.max(alpha, dim=1, keepdim=True)[0].detach()
213 |             shape = (x.shape[0], alpha.shape[1], alpha.shape[2]) # noise is drawn per frame; alpha.shape would share it across the batch
214 |             # shape = alpha.shape
215 |             random_numbers = torch.randn(shape, device=self.device) * self.noise
216 |             masked_x += (1 - alpha/max_attention_value) * random_numbers
217 |         # split them for each subsystem
218 |         masked_list = torch.split(masked_x, 1, dim=2)
219 |         return masked_list
220 | 
221 |     def get_softmax(self):
222 |         ''' Estimates the attention weights for each residue for all subsystems.
223 |         '''
224 |         weights_per_N = []
225 |         weights_all = []
226 |         for param in self.list_weights:
227 |             weights_all.append(param)
228 | 
229 |         for param in self.list_weights:
230 |             # make them positive
231 |             param = F.softmax(param, dim=1)*self.number_weights # this way on average 1
232 |             weights_for_res = []
233 |             for i in range(self.nb_per_res): # get all weights b for each residue
234 | 
235 |                 weights_for_res.append(param[:,self.bs_per_res[:,i],:])
236 | 
237 |             # take the product of all windows involved for the same residue
238 |             weights_for_res = torch.prod(torch.cat(weights_for_res, dim=0), dim=0) # take the product of the b factors
239 | 
240 |             weights_per_N.append(weights_for_res)
241 |         # Add the fake subsystem
242 |         fake_axis = torch.ones_like(weights_per_N[0])*self.factor_fake
243 |         weights_per_N = torch.cat(weights_per_N, dim=1)
244 |         weights_per_N_fake = torch.cat([weights_per_N, fake_axis], dim=1)
245 |         # normalize them along the subsystem axis
246 |         weights_per_N_fake = torch.relu(weights_per_N_fake-self.cutoff) # set all weights below the cutoff to zero
247 |         weights_per_N_fake = weights_per_N_fake / torch.sum(weights_per_N_fake, dim=1, keepdims=True) # normalize them to 1
248 |         # remove the fake system
249 |         weights_for_res = weights_per_N_fake[:,:self.N]
250 |         return weights_for_res
--------------------------------------------------------------------------------
/10Cube.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "99ac4f17",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "%matplotlib inline\n",
13 | "\n",
14 | "import torch.nn as nn\n",
15 | "import torch.nn.functional as F\n",
16 | "import torch\n",
17 | "from torch.utils import data\n",
18 | "import matplotlib.gridspec as gridspec"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "8ad358a6",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "tau = 1\n",
29 | "number_subsystems = 10\n",
30 | "fake_dims = 10\n",
31 | "output_sizes = [2 for _ in range(number_subsystems)]\n",
32 | "# tau list for timescales estimation\n",
33 | "msmlags = np.arange(1, 10)\n",
34 | "\n",
35 | "# Batch size for Stochastic Gradient descent\n",
36 | "batch_size = 10000\n",
37 | "\n",
38 | "# 
How many hidden layers the network chi has\n", 39 | "network_depth = 3\n", 40 | "\n", 41 | "# Width of every layer of chi\n", 42 | "layer_width = 30\n", 43 | "\n", 44 | "# Learning rate used for the ADAM optimizer\n", 45 | "\n", 46 | "# create a list with the number of nodes for each layer\n", 47 | "nodes = [layer_width]*network_depth\n", 48 | "\n", 49 | "# Definition of the hidden Markov transition matrices\n", 50 | "eps_list = np.linspace(0.,.1, number_subsystems+1)[1:]\n", 51 | "lam = 0 #0.04\n", 52 | "# Number of unformative noise dimensions\n", 53 | "dim_noise = 10\n", 54 | "\n", 55 | "# How strong the fake subsystem is\n", 56 | "factor_fake = 3.\n", 57 | "# How large the noise in the mask for regularization is\n", 58 | "noise = 1.\n", 59 | "# Threshold after which the attention weight is set to zero\n", 60 | "cutoff=0.7\n", 61 | "\n", 62 | "# Learning rate\n", 63 | "learning_rate=0.001\n", 64 | "# epsilon for inversion of symmetric matrices\n", 65 | "epsilon=1e-8\n", 66 | "# score method\n", 67 | "score_mode='regularize' # one of ('trunc', 'regularize', 'clamp', 'old')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "d92c9347", 73 | "metadata": {}, 74 | "source": [ 75 | "### Create toymodel" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "6226e55e", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "from examples import HyperCube" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "41a326b8", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "toymodel = HyperCube(eps_list, lam=lam)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "3d52c5f8", 101 | "metadata": {}, 102 | "source": [ 103 | "### Sample hidden and observable trajectory" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "a6544310", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "angles = np.pi / 4 * np.ones(number_subsystems//2)\n", 114 | "# training data with 100000 steps\n", 115 | "hidden_state_traj, observable_traj = toymodel.generate_traj(100000, angles=angles, dim_noise=dim_noise)\n", 116 | "\n", 117 | "# validation data with 10000 steps\n", 118 | "hidden_state_traj_valid, observable_traj_valid = toymodel.generate_traj(10000, angles=angles, dim_noise=dim_noise)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "caba8e29", 124 | "metadata": {}, 125 | "source": [ 126 | "### Define training and validation set" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "d81d68e9", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "from deeptime.util.data import TrajectoryDataset\n", 137 | "\n", 138 | "train_data = TrajectoryDataset(lagtime=tau, trajectory=observable_traj.astype('float32'))\n", 139 | "val_data = TrajectoryDataset(lagtime=tau, trajectory=observable_traj_valid.astype('float32'))" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "8985ac87", 145 | "metadata": {}, 146 | "source": [ 147 | "### Define networks" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "26eb4a80", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "from masks import Mask\n", 158 | "from collections import OrderedDict\n", 159 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 160 | "\n", 161 | "# Assuming that we are on a CUDA machine, this should print a CUDA 
device:\n", 162 | "\n", 163 | "print(device)\n", 164 | "input_size = observable_traj.shape[1] \n", 165 | "mask = Mask(input_size, number_subsystems, mean=torch.Tensor(train_data.data.mean(0)),\n", 166 | " std=torch.Tensor(train_data.data.std(0)), factor_fake=factor_fake, noise=noise, \n", 167 | " device=device, cutoff=cutoff)\n", 168 | "mask.to(device=device)\n", 169 | "lobes = []\n", 170 | "for output_size in output_sizes:\n", 171 | " lobe_dict = OrderedDict([('Layer_input', nn.Linear(input_size, layer_width)),\n", 172 | " ('Elu_input', nn.ELU())])\n", 173 | " for d in range(network_depth):\n", 174 | " lobe_dict['Layer'+str(d)]=nn.Linear(layer_width, layer_width)\n", 175 | " lobe_dict['Elu'+str(d)]=nn.ELU()\n", 176 | " lobe_dict['Layer_output']=nn.Linear(layer_width, output_size)\n", 177 | " lobe_dict['Softmax']=nn.Softmax(dim=1) # obtain fuzzy probability distribution over output states\n", 178 | " \n", 179 | " lobe = nn.Sequential(\n", 180 | " lobe_dict \n", 181 | " )\n", 182 | " lobes.append(lobe.to(device=device))\n", 183 | "\n", 184 | "print(mask)\n", 185 | "print(lobes)\n", 186 | " " 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "8772de4d", 192 | "metadata": {}, 193 | "source": [ 194 | "### Create iVAMPnets estimator" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "efacd90f", 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "from ivampnets import iVAMPnet" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "83916f55", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "ivampnet = iVAMPnet(lobes, mask, device, learning_rate=learning_rate, epsilon=epsilon, score_mode=score_mode, learning_rate_mask=learning_rate/2)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "id": "dca38304", 220 | "metadata": {}, 221 | "source": [ 222 | "### Plot mask before training" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "80c7aacb", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "from examples import plot_mask\n", 233 | "plot_mask(mask, skip=2)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "id": "e78895a4", 239 | "metadata": {}, 240 | "source": [ 241 | "### Create data loader" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "27dc64ae", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "from torch.utils.data import DataLoader\n", 252 | "\n", 253 | "loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n", 254 | "loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "id": "5a7361b8", 260 | "metadata": {}, 261 | "source": [ 262 | "### Create a tensorboard writer to observe performance during training" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "id": "1857cb58", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "tensorboard_installed = False\n", 273 | "if tensorboard_installed:\n", 274 | " from torch.utils.tensorboard import SummaryWriter\n", 275 | " writer = SummaryWriter('./runs/Cube10/')\n", 276 | " input_model, _ = next(iter(loader_train))\n", 277 | " # writer.add_graph(lobe, input_to_model=input_model.to(device))\n", 278 | "else:\n", 279 | " writer=None" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "id": "4e92b2ff", 
285 | "metadata": {}, 286 | "source": [ 287 | "### Fit the model on the training data" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "id": "e713fb95", 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "epochs = 50" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "id": "002b703d", 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "# model = ivampnet.fit(loader_train, n_epochs=epochs, validation_loader=loader_val, mask=False, lam_decomp=2, \n", 308 | "# lam_trace=0.5, start_mask=0, end_trace=1, tb_writer=writer, clip=False).fetch_model()\n", 309 | "# plot_mask(mask, skip=2)\n", 310 | "# mask.noise = 5\n", 311 | "# model = ivampnet.fit(loader_train, n_epochs=epochs, validation_loader=loader_val, mask=True, lam_decomp=2, \n", 312 | "# lam_trace=0, start_mask=0, end_trace=0, tb_writer=writer, clip=False).fetch_model()\n", 313 | "# plot_mask(mask, skip=2)\n", 314 | "# mask.noise = 10\n", 315 | "# model = ivampnet.fit(loader_train, n_epochs=epochs, validation_loader=loader_val, mask=True, lam_decomp=2, \n", 316 | "# lam_trace=0, start_mask=0, end_trace=0, tb_writer=writer, clip=False).fetch_model()\n", 317 | "lam_pen_perc=0.05\n", 318 | "lam_pen_C00=0. \n", 319 | "lam_pen_C11=0. \n", 320 | "lam_pen_C01=0.\n", 321 | "model = ivampnet.fit(loader_train, n_epochs=epochs, validation_loader=loader_val, mask=False, lam_decomp=0, \n", 322 | " lam_trace=0., start_mask=0, end_trace=10, tb_writer=writer, clip=False, \n", 323 | " lam_pen_perc=lam_pen_perc, lam_pen_C00=lam_pen_C00, lam_pen_C11=lam_pen_C11,\n", 324 | " lam_pen_C01=lam_pen_C01).fetch_model()\n", 325 | "plot_mask(mask, skip=2)\n", 326 | "# mask.noise = 5\n", 327 | "lam_pen_perc=0.04\n", 328 | "model = ivampnet.fit(loader_train, n_epochs=epochs, validation_loader=loader_val, mask=True, lam_decomp=0, \n", 329 | " lam_trace=0, start_mask=0, end_trace=0, tb_writer=writer, clip=False,\n", 330 | " lam_pen_perc=lam_pen_perc, lam_pen_C00=lam_pen_C00, lam_pen_C11=lam_pen_C11,\n", 331 | " lam_pen_C01=lam_pen_C01).fetch_model()\n", 332 | "plot_mask(mask, skip=2)\n", 333 | "mask.noise = 2\n", 334 | "lam_pen_perc=0.02\n", 335 | "model = ivampnet.fit(loader_train, n_epochs=epochs, validation_loader=loader_val, mask=True, lam_decomp=0, \n", 336 | " lam_trace=0, start_mask=0, end_trace=0, tb_writer=writer, clip=False,\n", 337 | " lam_pen_perc=lam_pen_perc, lam_pen_C00=lam_pen_C00, lam_pen_C11=lam_pen_C11,\n", 338 | " lam_pen_C01=lam_pen_C01).fetch_model()" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "id": "95548c29", 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "# execution time (on cpu): ~ 4.5 min" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "id": "3984d128", 354 | "metadata": {}, 355 | "source": [ 356 | "### Plot the training and validation scores" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "id": "1e8d7f7e", 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "plt.loglog(*ivampnet.train_scores.T, label='training')\n", 367 | "plt.loglog(*ivampnet.validation_scores.T, label='validation')\n", 368 | "plt.xlabel('step')\n", 369 | "plt.ylabel('score')\n", 370 | "plt.legend();" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "id": "0f836100", 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "plt.loglog(*ivampnet.train_pen_C01.T, 
label='training')\n", 381 | "plt.loglog(*ivampnet.validation_pen_C01.T, label='validation')\n", 382 | "plt.xlabel('step')\n", 383 | "plt.ylabel('score')\n", 384 | "plt.legend();" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "id": "144270b5", 390 | "metadata": {}, 391 | "source": [ 392 | "### Plot the mask after training" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "id": "641d5348", 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [ 402 | "plot_mask(mask, vmax=0.5, skip=2)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "c0d74aeb", 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "# reproduces Fig. 4c (or a permutation with respect to ivampnet state assignments)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "id": "143348c9", 418 | "metadata": {}, 419 | "source": [ 420 | "### Estimate implied timescales" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "id": "8b6d36de", 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "mask.noise=0 # set the noise to zero. \n", 431 | "its = []\n", 432 | "for tau_i in msmlags:\n", 433 | " val_data_temp = TrajectoryDataset(lagtime=tau_i, trajectory=observable_traj_valid.astype('float32'))\n", 434 | " its.append(model.timescales(val_data_temp.data, val_data_temp.data_lagged, tau_i))\n", 435 | "# Convert to array\n", 436 | "its = np.array(its)\n", 437 | "# Change the shape\n", 438 | "its = np.transpose(its, axes=[1,0,2])\n", 439 | "# Estimate the true timescales of the hidden Markov Chain\n", 440 | "eigvals_true = np.array(toymodel.eigvals_list_coupled).flatten()\n", 441 | "its_true = -1/np.log(eigvals_true)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "id": "e1749ded", 448 | "metadata": {}, 449 | "outputs": [], 450 | "source": [ 451 | "from examples import plot_hypercube_its" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "id": "39413b00", 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "plot_hypercube_its(its, msmlags, its_true, ylog=True)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "id": "90807a5c", 468 | "metadata": {}, 469 | "outputs": [], 470 | "source": [ 471 | "# reproduces Fig. 
4d" 472 | ] 473 | } 474 | ], 475 | "metadata": { 476 | "kernelspec": { 477 | "display_name": "rseed", 478 | "language": "python", 479 | "name": "python3" 480 | }, 481 | "language_info": { 482 | "codemirror_mode": { 483 | "name": "ipython", 484 | "version": 3 485 | }, 486 | "file_extension": ".py", 487 | "mimetype": "text/x-python", 488 | "name": "python", 489 | "nbconvert_exporter": "python", 490 | "pygments_lexer": "ipython3", 491 | "version": "3.8.15" 492 | }, 493 | "vscode": { 494 | "interpreter": { 495 | "hash": "ae7d9bbee64574db456ae08990e902edc98caa89a93077ee255475653bb8dd96" 496 | } 497 | } 498 | }, 499 | "nbformat": 4, 500 | "nbformat_minor": 5 501 | } 502 | -------------------------------------------------------------------------------- /Toymodel_2Systems.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "6882cf61", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "%matplotlib inline\n", 13 | "\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "import torch\n", 17 | "from torch.utils import data\n", 18 | "import matplotlib.gridspec as gridspec" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "b23855bb", 24 | "metadata": {}, 25 | "source": [ 26 | "### Hyperparameters" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "f8089031", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "tau = 1\n", 37 | "output_sizes = [3,2]\n", 38 | "\n", 39 | "# Batch size for Stochastic Gradient descent\n", 40 | "batch_size = 10000\n", 41 | "\n", 42 | "# How many hidden layers the network chi has\n", 43 | "network_depth = 3\n", 44 | "\n", 45 | "# Width of every layer of chi\n", 46 | "layer_width = 30\n", 47 | "\n", 48 | "# Learning rate used for the ADAM optimizer\n", 49 | "\n", 50 | "# create a list with the number of nodes for each layer\n", 51 | "nodes = [layer_width]*network_depth\n", 52 | "\n", 53 | "number_subsystems = len(output_sizes)\n", 54 | "\n", 55 | "# How strong the fake subsystem is\n", 56 | "factor_fake = 0.\n", 57 | "# How large the noise in the mask for regularization is\n", 58 | "noise = 1.\n", 59 | "# Threshold after which the attention weight is set to zero\n", 60 | "cutoff=0.9\n", 61 | "# Learning rate\n", 62 | "learning_rate=0.005\n", 63 | "# epsilon\n", 64 | "epsilon=1e-8\n", 65 | "# score method\n", 66 | "score_mode='regularize' # one of ('trunc', 'regularize', 'clamp', 'old')" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "b6eea92f", 72 | "metadata": {}, 73 | "source": [ 74 | "### Create toymodel" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "7365b8cb", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "from examples import Toymodel_2Systems" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "7deff63d", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "eps_list = [.025, .125, .05, .1]\n", 95 | "toymodel = Toymodel_2Systems(eps_list)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "11585e56", 101 | "metadata": {}, 102 | "source": [ 103 | "### Sample hidden and observable trajectory" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "511eeeb8", 110 | "metadata": {}, 111 | "outputs": [], 
112 | "source": [ 113 | "# training data with 100000 steps\n", 114 | "hidden_state_traj, observable_traj = toymodel.generate_traj(100000)\n", 115 | "\n", 116 | "# validation data with 10000 steps\n", 117 | "hidden_state_traj_valid, observable_traj_valid = toymodel.generate_traj(10000)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "1d05ab91", 123 | "metadata": {}, 124 | "source": [ 125 | "### Plot trajectory and true global eigenfunctions" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "e2f05cbc", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "toymodel.plot_toymodel(hidden_state_traj_valid, observable_traj_valid)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "id": "4dac631b", 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "toymodel.plot_eigfunc(hidden_state_traj_valid, observable_traj_valid)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "ebf9daa7", 151 | "metadata": {}, 152 | "source": [ 153 | "### Define training and validation set" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "a93fd590", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "from deeptime.util.data import TrajectoryDataset\n", 164 | "\n", 165 | "train_data = TrajectoryDataset(lagtime=tau, trajectory=observable_traj.astype('float32'))\n", 166 | "val_data = TrajectoryDataset(lagtime=tau, trajectory=observable_traj_valid.astype('float32'))" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "id": "433c8d8f", 172 | "metadata": {}, 173 | "source": [ 174 | "### Define networks" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "a7702a34", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "from masks import Mask\n", 185 | "from collections import OrderedDict\n", 186 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 187 | "\n", 188 | "# Assuming that we are on a CUDA machine, this should print a CUDA device:\n", 189 | "\n", 190 | "print(device)\n", 191 | "input_size = observable_traj.shape[1] \n", 192 | "mask = Mask(input_size, number_subsystems, mean=torch.Tensor(train_data.data.mean(0)),\n", 193 | " std=torch.Tensor(train_data.data.std(0)), factor_fake=factor_fake, noise=noise, \n", 194 | " device=device, cutoff=cutoff)\n", 195 | "mask.to(device=device)\n", 196 | "lobes = []\n", 197 | "for output_size in output_sizes:\n", 198 | " lobe_dict = OrderedDict([('Layer_input', nn.Linear(input_size, layer_width)),\n", 199 | " ('Elu_input', nn.ELU())])\n", 200 | " for d in range(network_depth):\n", 201 | " lobe_dict['Layer'+str(d)]=nn.Linear(layer_width, layer_width)\n", 202 | " lobe_dict['Elu'+str(d)]=nn.ELU()\n", 203 | " lobe_dict['Layer_output']=nn.Linear(layer_width, output_size)\n", 204 | " lobe_dict['Softmax']=nn.Softmax(dim=1) # obtain fuzzy probability distribution over output states\n", 205 | " \n", 206 | " lobe = nn.Sequential(\n", 207 | " lobe_dict \n", 208 | " )\n", 209 | " lobes.append(lobe.to(device=device))\n", 210 | "\n", 211 | "print(mask)\n", 212 | "print(lobes)\n", 213 | " " 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "id": "6533c8b0", 219 | "metadata": {}, 220 | "source": [ 221 | "### Create iVAMPnets estimator" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "9f55db7a", 228 | "metadata": {}, 229 | "outputs": 
[], 230 | "source": [ 231 | "from ivampnets import iVAMPnet" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "f6c568ac", 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "ivampnet = iVAMPnet(lobes, mask, device, learning_rate=learning_rate, epsilon=epsilon, score_mode=score_mode, learning_rate_mask=learning_rate/2)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "id": "60929530", 247 | "metadata": {}, 248 | "source": [ 249 | "### Plot mask before training" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "e14a6358", 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "from examples import plot_mask\n", 260 | "plot_mask(mask)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "id": "b7875dc2", 266 | "metadata": {}, 267 | "source": [ 268 | "### Create data loader" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "id": "f7ccc0bf", 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "from torch.utils.data import DataLoader\n", 279 | "\n", 280 | "loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n", 281 | "loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "id": "8f98d914", 287 | "metadata": {}, 288 | "source": [ 289 | "### Create a tensorboard writer to observe performance during training" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "3e100e00", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "tensorboard_installed = False\n", 300 | "if tensorboard_installed:\n", 301 | " from torch.utils.tensorboard import SummaryWriter\n", 302 | " writer = SummaryWriter('./runs/Toy2/')\n", 303 | " input_model, _ = next(iter(loader_train))\n", 304 | "else:\n", 305 | " writer=None" 306 | ] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "id": "26c4993c", 311 | "metadata": {}, 312 | "source": [ 313 | "### Fit the model on the training data" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "id": "0d2532ae", 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "# model = ivampnet.fit(loader_train, n_epochs=10, validation_loader=loader_val, mask=False, lam_decomp=2., \n", 324 | "# lam_trace=1., start_mask=0, end_trace=2, tb_writer=writer, clip=False).fetch_model()\n", 325 | "# model = ivampnet.fit(loader_train, n_epochs=10, validation_loader=loader_val, mask=True, lam_decomp=0., \n", 326 | "# lam_trace=1., start_mask=0, end_trace=2, tb_writer=writer, clip=False).fetch_model()\n", 327 | "lam_pen_perc=0.05\n", 328 | "lam_pen_C00=0. \n", 329 | "lam_pen_C11=0. 
\n", 330 | "lam_pen_C01=0.\n", 331 | "model = ivampnet.fit(loader_train, n_epochs=10, validation_loader=loader_val, mask=False, lam_decomp=2., \n", 332 | " lam_trace=0., start_mask=0, end_trace=2, tb_writer=writer, clip=False,\n", 333 | " lam_pen_perc=lam_pen_perc, lam_pen_C00=lam_pen_C00, lam_pen_C11=lam_pen_C11,\n", 334 | " lam_pen_C01=lam_pen_C01).fetch_model()\n", 335 | "model = ivampnet.fit(loader_train, n_epochs=10, validation_loader=loader_val, mask=True, lam_decomp=0., \n", 336 | " lam_trace=0., start_mask=0, end_trace=2, tb_writer=writer, clip=False,\n", 337 | " lam_pen_perc=lam_pen_perc, lam_pen_C00=lam_pen_C00, lam_pen_C11=lam_pen_C11,\n", 338 | " lam_pen_C01=lam_pen_C01).fetch_model()\n", 339 | "if ivampnet.train_pen_scores[-1,1]>0.02:\n", 340 | " print('The model does not seem to be converged to an independent solution!')" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "id": "edcbc766", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "# execution time (cpu): ~ 30 sec" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "id": "375d23fc", 356 | "metadata": {}, 357 | "source": [ 358 | "### Plot the training and validation scores" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "id": "41e73e2f", 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "plt.loglog(*ivampnet.train_scores.T, label='training')\n", 369 | "plt.loglog(*ivampnet.validation_scores.T, label='validation')\n", 370 | "plt.xlabel('step')\n", 371 | "plt.ylabel('score')\n", 372 | "plt.legend();" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "id": "dbbed897", 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "plt.loglog(*ivampnet.train_pen_C01.T, label='training')\n", 383 | "plt.loglog(*ivampnet.validation_pen_C01.T, label='validation')\n", 384 | "plt.xlabel('step')\n", 385 | "plt.ylabel('score')\n", 386 | "plt.legend();" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "id": "c19cc279", 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "plt.loglog(*ivampnet.train_pen_scores.T, label='training')\n", 397 | "plt.loglog(*ivampnet.validation_pen_scores.T, label='validation')\n", 398 | "plt.xlabel('step')\n", 399 | "plt.ylabel('score')\n", 400 | "plt.legend();" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "id": "a20f6c9c", 406 | "metadata": {}, 407 | "source": [ 408 | "### Plot the mask after the training" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "id": "8b1e2ed6", 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "plot_mask(mask)" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "id": "3ed8483a", 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "# reproduces Fig. 3b (or permutation of it)" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "id": "f0935fa5", 434 | "metadata": {}, 435 | "source": [ 436 | "### Compare the eigenvalues from the true and the estimated transition matrix" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "id": "f8b68c6a", 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [ 446 | "mask.noise=0 # set the noise to zero. 
\n", 447 | "T_list = model.get_transition_matrix(val_data.data, val_data.data_lagged)" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "id": "f891e262", 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "# Estimated eigenvalues\n", 458 | "for T in T_list:\n", 459 | " print(np.linalg.eigvals(T))" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "id": "5181f619", 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "# True eigenvalues\n", 470 | "print(np.linalg.eigvals(toymodel.T1)), np.linalg.eigvals(toymodel.T2)" 471 | ] 472 | }, 473 | { 474 | "cell_type": "markdown", 475 | "id": "7bd6fa2a", 476 | "metadata": {}, 477 | "source": [ 478 | "### Plot the state assignment" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "id": "d91b1e36", 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "from examples import plot_states\n", 489 | "plot_states(model, val_data.data)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "id": "91ae1a08", 495 | "metadata": {}, 496 | "source": [ 497 | "### Estimate the eigenfunctions" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "id": "01a49a12", 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "from examples import plot_eigfuncs\n", 508 | "plot_eigfuncs(model, val_data)" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": null, 514 | "id": "3d8049c4", 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "# reproduces Fig. 3c (possibly with permutation of state assignments)" 519 | ] 520 | } 521 | ], 522 | "metadata": { 523 | "kernelspec": { 524 | "display_name": "rseed", 525 | "language": "python", 526 | "name": "python3" 527 | }, 528 | "language_info": { 529 | "codemirror_mode": { 530 | "name": "ipython", 531 | "version": 3 532 | }, 533 | "file_extension": ".py", 534 | "mimetype": "text/x-python", 535 | "name": "python", 536 | "nbconvert_exporter": "python", 537 | "pygments_lexer": "ipython3", 538 | "version": "3.8.15" 539 | }, 540 | "vscode": { 541 | "interpreter": { 542 | "hash": "ae7d9bbee64574db456ae08990e902edc98caa89a93077ee255475653bb8dd96" 543 | } 544 | } 545 | }, 546 | "nbformat": 4, 547 | "nbformat_minor": 5 548 | } 549 | -------------------------------------------------------------------------------- /SynaptotagminC2A.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "42847660", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "%matplotlib inline\n", 13 | "\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "import torch\n", 17 | "from torch.utils import data\n", 18 | "import torch.optim as optim\n", 19 | "import matplotlib.gridspec as gridspec\n", 20 | "\n", 21 | "import h5py" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "43117557", 27 | "metadata": {}, 28 | "source": [ 29 | "### Hyperparameters" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "6f87cbf8", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "stride = 1\n", 40 | "\n", 41 | "tau = 100//stride \n", 42 | "\n", 43 | "output_sizes = [8,8]\n", 44 | "number_subsystems = len(output_sizes)\n", 45 | "# tau list for 
timescales estimation\n", 46 | "tau_list = [1,2,4,8]\n", 47 | "\n", 48 | "# Batch size for Stochastic Gradient descent\n", 49 | "batch_size = 20000\n", 50 | "# Which trajectory points percentage is used as validation and testing, the rest is for training\n", 51 | "valid_ratio = 0.3\n", 52 | "test_ratio = 0.0001\n", 53 | "# How many hidden layers the network chi has\n", 54 | "network_depth = 3\n", 55 | "\n", 56 | "# Width of every layer of chi\n", 57 | "layer_width = 100\n", 58 | "# create a list with the number of nodes for each layer\n", 59 | "nodes = [layer_width]*network_depth\n", 60 | "# data preparation\n", 61 | "# how many residues are skipped for distance calculation\n", 62 | "skip_res = 6\n", 63 | "# Size of the windows for attention mechanism\n", 64 | "patchsize = 8\n", 65 | "# How many residues are skipped before defining a new window\n", 66 | "skip_over = 4\n", 67 | "\n", 68 | "# How strong the fake subsystem is\n", 69 | "factor_fake = 2.\n", 70 | "# How large the noise in the mask for regularization is\n", 71 | "noise = 2.\n", 72 | "# Threshold after which the attention weight is set to zero\n", 73 | "cutoff=0.9\n", 74 | "# Learning rate\n", 75 | "learning_rate=0.001\n", 76 | "# epsilon\n", 77 | "epsilon=1e-5\n", 78 | "# score method\n", 79 | "score_mode='regularize' # one of ('trunc', 'regularize', 'clamp', 'old')" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "d10b4cf1", 85 | "metadata": {}, 86 | "source": [ 87 | "### Load data" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "588c7ec9", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# data set has a total length of 184 µs with a 1 ns resolution (total of 184000 frames)\n", 98 | "\n", 99 | "data_trajs = []\n", 100 | "hdf5_names = []\n", 101 | "loaded_data_stride = 100\n", 102 | "exclude_list = []\n", 103 | "with h5py.File(f\"/group/ag_cmb/scratch/deeptime_data/syt/syt_0cal_internal1by1_stride{loaded_data_stride}.hdf5\", \"r\") as f: # 1 frame = 1 ns\n", 104 | " #print(\"datasets:\", f.keys())\n", 105 | " for n, name in enumerate(f.keys()):\n", 106 | " if n not in exclude_list:\n", 107 | " hdf5_names.append(name)\n", 108 | " dset = f[name]\n", 109 | " dat = dset[...].astype('float32')\n", 110 | "\n", 111 | " data_trajs.append(1/np.exp(dat[::stride]))" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "51a0e9fd", 117 | "metadata": {}, 118 | "source": [ 119 | "### Define dataset" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "f45ee7fb", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "from deeptime.util.data import TrajectoriesDataset\n", 130 | "\n", 131 | "dataset = TrajectoriesDataset.from_numpy(lagtime=tau, data=data_trajs)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "22c7b97e", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "# split into training/validation/test set\n", 142 | "n_val = int(len(dataset)*valid_ratio)\n", 143 | "n_test = int(len(dataset)*test_ratio)\n", 144 | "train_data, val_data, test_data = torch.utils.data.random_split(dataset, [len(dataset) - n_val - n_test, n_val, n_test])" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "9d6acabf", 150 | "metadata": {}, 151 | "source": [ 152 | "### Define networks" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "634d4b2b", 159 | "metadata": {}, 160 | "outputs": [], 161 | 
"source": [ 162 | "from masks import Mask_proteins\n", 163 | "from collections import OrderedDict\n", 164 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 165 | "\n", 166 | "# Assuming that we are on a CUDA machine, this should print a CUDA device:\n", 167 | "train_mean = np.concatenate(train_data.dataset.trajectories, axis=0).mean(0)\n", 168 | "train_std = np.concatenate(train_data.dataset.trajectories, axis=0).std(0)\n", 169 | "print(device)\n", 170 | "input_size = train_data.dataset.trajectories[0].shape[-1]\n", 171 | "mask = Mask_proteins(input_size, number_subsystems, skip_res=skip_res, patchsize=patchsize, skip=skip_over, mean=torch.Tensor(train_mean),\n", 172 | " std=torch.Tensor(train_std), factor_fake=factor_fake, noise=noise, device=device, cutoff=cutoff)\n", 173 | "mask.to(device=device)\n", 174 | "lobes = []\n", 175 | "for output_size in output_sizes:\n", 176 | " lobe_dict = OrderedDict([('Layer_input', nn.Linear(input_size, layer_width)),\n", 177 | " ('Elu_input', nn.ELU())])\n", 178 | " for d in range(network_depth):\n", 179 | " lobe_dict['Layer'+str(d)]=nn.Linear(layer_width, layer_width)\n", 180 | " lobe_dict['Elu'+str(d)]=nn.ELU()\n", 181 | " lobe_dict['Layer_output']=nn.Linear(layer_width, output_size)\n", 182 | " lobe_dict['Softmax']=nn.Softmax(dim=1) # obtain fuzzy probability distribution over output states\n", 183 | " \n", 184 | " lobe = nn.Sequential(\n", 185 | " lobe_dict \n", 186 | " )\n", 187 | " lobes.append(lobe.to(device=device))\n", 188 | "\n", 189 | "print(mask)\n", 190 | "print(lobes) " 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "id": "ea28cbbb", 196 | "metadata": {}, 197 | "source": [ 198 | "### Create iVAMPnets estimator" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "05aabb2d", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "from ivampnets import iVAMPnet" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "id": "0325caf5", 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "ivampnet = iVAMPnet(lobes, mask, device, learning_rate=learning_rate, epsilon=epsilon, score_mode=score_mode)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "id": "967887b4", 224 | "metadata": {}, 225 | "source": [ 226 | "### Plot mask before training" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "dba6b102", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "from examples import plot_mask\n", 237 | "plot_mask(mask, skip=10)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "f03dd31d", 243 | "metadata": {}, 244 | "source": [ 245 | "### Create data loader" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "06a08456", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "from torch.utils.data import DataLoader\n", 256 | "\n", 257 | "loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n", 258 | "loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "id": "c8a820a0", 264 | "metadata": {}, 265 | "source": [ 266 | "### Create a tensorboard writer to observe performance during training" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "id": "7038f8fc", 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | 
"tensorboard_installed = False\n", 277 | "if tensorboard_installed:\n", 278 | " from torch.utils.tensorboard import SummaryWriter\n", 279 | " writer = SummaryWriter(log_dir='./runs/Syt/')\n", 280 | " input_model, _ = next(iter(loader_train))\n", 281 | " # writer.add_graph(lobe, input_to_model=input_model.to(device))\n", 282 | "else:\n", 283 | " writer=None" 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "id": "229a10ff", 289 | "metadata": {}, 290 | "source": [ 291 | "### Fit the model on the training data" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "id": "bfaff88e", 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "model = ivampnet.fit(loader_train, n_epochs=50, validation_loader=loader_val, mask=True, lam_decomp=20., \n", 302 | " lam_trace=1., start_mask=0, end_trace=20, tb_writer=writer, clip=False).fetch_model()\n", 303 | "\n", 304 | "plot_mask(mask, skip=10)\n", 305 | "mask.noise=5.\n", 306 | "model = ivampnet.fit(loader_train, n_epochs=150, validation_loader=loader_val, mask=True, lam_decomp=50., \n", 307 | " lam_trace=0., start_mask=0, end_trace=0, tb_writer=writer, clip=False).fetch_model()\n", 308 | "plot_mask(mask, skip=10)\n", 309 | "mask.noise=10.\n", 310 | "model = ivampnet.fit(loader_train, n_epochs=150, validation_loader=loader_val, mask=True, lam_decomp=100., \n", 311 | " lam_trace=0., start_mask=0, end_trace=0, tb_writer=writer, clip=False).fetch_model()" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "id": "474445e3", 317 | "metadata": {}, 318 | "source": [ 319 | "### Plot training and validation scores" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "id": "8a16599e", 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "plt.loglog(*ivampnet.train_scores.T, label='training')\n", 330 | "plt.loglog(*ivampnet.validation_scores.T, label='validation')\n", 331 | "plt.xlabel('step')\n", 332 | "plt.ylabel('score')\n", 333 | "plt.legend();" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "id": "2e65c886", 339 | "metadata": {}, 340 | "source": [ 341 | "### Plot the mask after training" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "id": "66a61989", 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [ 351 | "plot_mask(mask, skip=10)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "id": "8437a367", 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "from examples import plot_protein_its, plot_protein_mask\n", 362 | "plot_protein_mask(mask, skip_start=4)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "id": "f01dfcd4", 368 | "metadata": {}, 369 | "source": [ 370 | "### Finally train without noise" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "id": "92c01b2a", 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "# the noise is only important to make the training of the mask meaningfull.\n", 381 | "# Here, the mask should be well trained, so we disable the mask training from here on.\n", 382 | "mask.noise=0.\n", 383 | "model = ivampnet.fit(loader_train, n_epochs=300, validation_loader=loader_val, mask=False, lam_decomp=100., \n", 384 | " lam_trace=0., start_mask=0, end_trace=0, tb_writer=writer, clip=False).fetch_model()" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "id": "ddc0b3ab", 391 | "metadata": 
{}, 392 | "outputs": [], 393 | "source": [ 394 | "# In principle you can then also train the model without enforcing the decomposition score anymore\n", 395 | "# However, you should observe if the independence score rise significantly, then you need to reverse the progress\n", 396 | "# You can use the save_criteria parameter to control it.\n", 397 | "model = ivampnet.fit(loader_train, n_epochs=300, validation_loader=loader_val, mask=False, lam_decomp=0., \n", 398 | " lam_trace=0., start_mask=0, end_trace=0, tb_writer=writer, clip=False, save_criteria=0.012).fetch_model()" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "29e2875c", 404 | "metadata": {}, 405 | "source": [ 406 | "### Estimate implied timescales" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "id": "036bd72b", 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "runs = 5\n", 417 | "its = [[] for _ in range(runs)]\n", 418 | "# cheap error estimation, instead of retraining chi, evaluate the model on different trajectories\n", 419 | "percentage = 0.9\n", 420 | "N_trajs = len(dataset.trajectories)\n", 421 | "indexes_traj = np.arange(N_trajs)\n", 422 | "n_val = int(N_trajs * percentage)\n", 423 | "msmlags=np.array([1,2,4,6,10,15,20,25])*10\n", 424 | "for run in range(runs):\n", 425 | " for tau_i in msmlags:\n", 426 | " np.random.shuffle(indexes_traj)\n", 427 | " indexes_used = indexes_traj[:n_val]\n", 428 | " data_t = np.concatenate([dataset.trajectories[a][:-tau_i] for a in indexes_used], axis=0)\n", 429 | " data_tau = np.concatenate([dataset.trajectories[a][tau_i:] for a in indexes_used], axis=0)\n", 430 | " its[run].append(model.timescales(data_t, data_tau, tau_i, batchsize=10000))" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "id": "0e16ad8b", 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "# reorder its, subsystems can have different outputsizes!\n", 441 | "its_reorder = [np.zeros((runs,len(msmlags), output_sizes[n]-1)) for n in range(number_subsystems)]\n", 442 | "for n in range(number_subsystems):\n", 443 | " for run in range(runs):\n", 444 | " for lag in range(len(msmlags)):\n", 445 | " its_reorder[n][run,lag] = its[run][lag][n]" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "id": "57c8d5a3", 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [ 455 | "axes, fig = plot_protein_its(its_reorder, msmlags, ylog=True, multiple_runs=True, percent=0.9)\n", 456 | "x_ticks = np.array([1,5,10,20,40])*10\n", 457 | "x_ticks_labels = x_ticks*stride # for estimating the right units!\n", 458 | "y_ticks = np.array([1000,10000, 100000])/stride\n", 459 | "y_ticks_labels = y_ticks*stride/1000\n", 460 | "for n in range(number_subsystems):\n", 461 | " ax=axes[n]\n", 462 | " ax.plot(msmlags,msmlags, 'k')\n", 463 | " ax.fill_between(msmlags, msmlags[0], msmlags, color = 'k', alpha = 0.2)\n", 464 | " ax.set_xlabel('Lagtime [ns]', fontsize=16)\n", 465 | " if n==0:\n", 466 | " ax.set_ylabel('Implied Timescales [$\\mu$s]', fontsize=16)\n", 467 | " ax.legend(fontsize=14, loc='lower right')\n", 468 | " ax.set_xticks(x_ticks)\n", 469 | " ax.set_xticklabels(x_ticks_labels, fontsize=14)\n", 470 | " ax.set_yticks(y_ticks)\n", 471 | " ax.set_yticklabels(y_ticks_labels, fontsize=14)\n", 472 | " ax.tick_params(direction='out', length=6, width=2, colors='k',\n", 473 | " grid_color='k', grid_alpha=0.5)\n", 474 | " ax.set_xlim(10,250)\n", 475 | " ax.set_ylim(0.01*1000, 
200*1000)\n", 476 | " # fig.savefig('./Syt_its.pdf', bbox_inches='tight')\n", 477 | "\n", 478 | "plt.show()" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "id": "9df5595e", 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "# reproduces Fig. 5b" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "id": "3e6ec4ad", 495 | "metadata": {}, 496 | "outputs": [], 497 | "source": [ 498 | "# ivampnet.save_params('./Syt_params')" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "id": "f51b21f6", 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [] 508 | } 509 | ], 510 | "metadata": { 511 | "kernelspec": { 512 | "display_name": "Python 3", 513 | "language": "python", 514 | "name": "python3" 515 | }, 516 | "language_info": { 517 | "codemirror_mode": { 518 | "name": "ipython", 519 | "version": 3 520 | }, 521 | "file_extension": ".py", 522 | "mimetype": "text/x-python", 523 | "name": "python", 524 | "nbconvert_exporter": "python", 525 | "pygments_lexer": "ipython3", 526 | "version": "3.8.8" 527 | } 528 | }, 529 | "nbformat": 4, 530 | "nbformat_minor": 5 531 | } 532 | -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import deeptime 3 | import matplotlib.pyplot as plt 4 | 5 | class Toymodel_2Systems(): 6 | ''' Class for generating the data for the Toymodel with two subsystems sampled from a hidden Markov Chain 7 | ''' 8 | def __init__(self, eps_list, mean=None, cov=None): 9 | super().__init__() 10 | self.eps_list = eps_list 11 | 12 | self.T, self.T1, self.T2 = self.generate_hidden_matrix() 13 | self.msm = msm = deeptime.markov.msm.MarkovStateModel(self.T) 14 | if mean is None: 15 | mean_per_state = np.array([[2, 2], 16 | [2, -2], 17 | [0, 2], 18 | [0, -2], 19 | [-2, 2], 20 | [-2, -2]]) 21 | if cov is None: 22 | cov = .1 * np.eye(2) 23 | self.mean = mean_per_state 24 | self.cov = cov 25 | 26 | def generate_traj(self, steps): 27 | ''' Generate a trajectory with the defined hidden Markov Chain 28 | 29 | Parameters 30 | ---------- 31 | steps: int 32 | Number of timesteps 33 | 34 | Returns 35 | ------- 36 | hidden_state_traj: np.array 37 | The hidden Markov Chain. 38 | observable_traj: np.array 39 | The observable trajectory. 40 | 41 | ''' 42 | hidden_state_traj = self.msm.simulate(steps) 43 | observable_traj = np.zeros((hidden_state_traj.shape[0], 2)) - 1 44 | n_hidden = self.T.shape[0] 45 | for state in range(n_hidden): 46 | ix = np.where(hidden_state_traj == state)[0] 47 | observable_traj[ix] = np.random.multivariate_normal(self.mean[state], self.cov, size=ix.shape[0]) 48 | 49 | return hidden_state_traj, observable_traj 50 | 51 | 52 | def generate_hidden_matrix(self): 53 | """ 54 | Generates hidden transition matrix. 
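A short illustration of what is constructed here (the eps values below are purely illustrative): the two subsystem matrices T1 (3 hidden states) and T2 (2 hidden states) are built from eps_list, and the global 6-state matrix is their Kronecker product T = np.kron(T1, T2).

>>> toy = Toymodel_2Systems(eps_list=[0.01, 0.02, 0.02, 0.05])
>>> T, T1, T2 = toy.generate_hidden_matrix()
>>> T.shape
(6, 6)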
55 | """ 56 | eps0, eps1, eps2, eps3 = self.eps_list 57 | X1 = np.array([[1-eps0-eps0, eps0, eps0], 58 | [eps1, 1-eps1-eps1, eps1], 59 | [eps0, eps1, 1-eps0-eps1]]) 60 | X1 = X1/np.sum(X1, keepdims=True) 61 | pi = np.sum(X1,1, keepdims=True) 62 | 63 | T1 = X1 / pi 64 | # X2 = np.array([[1-eps2-eps2, eps2, eps2], 65 | # [eps3, 1-eps3-eps3, eps3], 66 | # [eps2, eps3, 1-eps2-eps3]]) 67 | X2 = np.array([[1-eps2, eps2], 68 | [eps2,1-eps2]]) 69 | X2 = X2/np.sum(X2, keepdims=True) 70 | pi = np.sum(X2,1, keepdims=True) 71 | 72 | T2 = X2 / pi 73 | 74 | T = np.kron(T1, T2) 75 | 76 | assert deeptime.markov.tools.analysis.is_transition_matrix(T) 77 | return T, T1, T2 78 | 79 | def plot_toymodel(self, hidden_state_traj, observable_traj): 80 | ''' Plots the toymodel given a hidden trajectory and the corresponding observable coordinates. 81 | 82 | Parameters 83 | --------- 84 | hidden_state_traj: nd.array 85 | The hidden trajectory of size (T,), where T is the number of frames. 86 | observable_traj: nd.array 87 | The observable array of size (T, n), where n is the size of the observable space. 88 | ''' 89 | plt.scatter(*observable_traj.T, c=hidden_state_traj, alpha=.5) 90 | plt.show() 91 | 92 | def plot_eigfunc(self, hidden_state_traj, observable_traj, save=None): 93 | ''' Plots the true eigenfunctions. 94 | 95 | Parameters 96 | ---------- 97 | hidden_state_traj: nd.array 98 | The hidden trajectory of size (T,), where T is the number of frames. 99 | observable_traj: nd.array 100 | The observable array of size (T, n), where n is the size of the observable space. 101 | save: default=None 102 | If save is not None, the figure will be saved. 103 | ''' 104 | eigv, eigvec = np.linalg.eig(self.T) 105 | ind_sort = np.argsort(eigv)[::-1] 106 | eigv = eigv[ind_sort] 107 | eigvec = eigvec[:,ind_sort] 108 | 109 | x_size = 3 110 | y_size = 2 111 | factor=2 112 | factor_x=1.5 113 | factor_y=2 114 | fig, ax = plt.subplots(x_size, y_size, sharex=True, sharey=True, figsize=(6*factor_x,4*factor_y)) 115 | i_state = 0 116 | skip=1 117 | ax[0,0].text(0.8,6,'Global eigenfunctions', fontsize=10*factor) 118 | for i in range(self.T.shape[0]): 119 | # print(output_i, system_i) 120 | output_i = i//y_size 121 | system_i = i%y_size 122 | # print(output_i, system_i) 123 | eigv_i = eigvec[:,i] 124 | if i ==0: 125 | c=np.ones_like(eigv_i[hidden_state_traj[::skip]]) 126 | else: 127 | c=eigv_i[hidden_state_traj[::skip]] 128 | ax[output_i, system_i].scatter( 129 | *observable_traj[::skip].T, c=c, 130 | ) 131 | ax[output_i, system_i].set_title(r'$\lambda_{}={:.3}$'.format(i,eigv[i]), fontsize=10*factor) 132 | if output_i==(x_size-1): 133 | ax[output_i, system_i].set_xlabel('x', fontsize=10*factor) 134 | ax[output_i, system_i].set_xticks([-2,0,2]) 135 | ax[output_i, system_i].set_xticklabels([-2,0,2], fontsize=8*factor) 136 | if system_i ==0: 137 | ax[output_i, system_i].set_ylabel('y', fontsize=10*factor) 138 | ax[output_i, system_i].set_yticks([-2,0,2]) 139 | ax[output_i, system_i].set_yticklabels([-2,0,2], fontsize=8*factor) 140 | if save is not None: 141 | fig.savefig('./3x2_mix_T_hidden_eigvec.png', bbox_inches='tight', dpi=900) 142 | fig.show() 143 | 144 | def plot_mask(mask, vmax=1., save=False, skip=1): 145 | ''' Plots the mask of the toymodels. 146 | 147 | Parameters 148 | ---------- 149 | mask: masks.Mask 150 | The mask defined in masks.py 151 | vmax: float 152 | The maximal value of the scale which will be used. 153 | save: bool 154 | If True, the figure will be saved. 
155 | skip: int 156 | Number of input features which will be skipped for the yticks. 157 | 158 | ''' 159 | attention = mask.get_softmax() 160 | attention_np = np.squeeze(attention.detach().to('cpu').numpy()) 161 | plt.imshow(attention_np, vmin=0, vmax=vmax, cmap=plt.cm.binary, aspect='auto') 162 | plt.xlabel('Subsystem', fontsize=18) 163 | plt.ylabel('Input', fontsize=18) 164 | input_size, number_subsystems = attention_np.shape 165 | plt.xticks(np.arange(number_subsystems),['{}'.format(i) for i in range(number_subsystems)], fontsize=16) 166 | plt.yticks(np.arange(0,input_size,skip),['x{}'.format(i) for i in range(0,input_size,skip)], fontsize=16) 167 | plt.show() 168 | if save: 169 | plt.savefig('./Mask.pdf', bbox_inches='tight') 170 | 171 | def plot_states(model, data, save=False): 172 | ''' Plots the state probability vector of all subsystems. 173 | 174 | Parameters 175 | ---------- 176 | model: ivampnets.iVAMPnetModel 177 | The model which transforms the input data. 178 | data: torch.Tensor or nd.array 179 | Input data which should be plotted. Has to be transformabel by the model. 180 | save: bool 181 | If True, the figure will be saved. 182 | ''' 183 | pred_list = model.transform(data) 184 | number_subsystems = len(pred_list) 185 | transformed_data = [] 186 | output_sizes = [] 187 | for n in range(model._N): 188 | transformed_data.append(np.concatenate(pred_list[n], axis=0)) 189 | output_sizes.append(transformed_data[-1].shape[-1]) 190 | transformed_data = np.concatenate(transformed_data, axis=1) 191 | subsysteme = ['I', 'II'] 192 | 193 | max_output_size = max(output_sizes) 194 | x_size = output_sizes[0] 195 | y_size = output_sizes[1] 196 | factor=2 197 | factor_x=1.5 198 | factor_y=2 199 | fig, ax = plt.subplots(x_size, y_size, sharex=True, sharey=True, figsize=(6*factor_x,4*factor_y)) 200 | ax[0,0].text(1.,6,'State assignment', fontsize=10*factor) 201 | state_real = 0 202 | for i_state in range(number_subsystems * max_output_size): 203 | output_i = i_state%max_output_size 204 | system_i = i_state//max_output_size 205 | if output_i < output_sizes[system_i]: 206 | z = transformed_data[:,state_real] 207 | # print(z.shape) 208 | ax[output_i, system_i].scatter( 209 | x=data[:, 0], y=data[:, 1], c=z, 210 | ) 211 | if output_i ==0: 212 | ax[output_i, system_i].set_title(f"Subsystem {subsysteme[system_i]}", fontsize=10*factor) 213 | if system_i ==0: 214 | if output_i==(y_size): 215 | ax[output_i, system_i].set_xlabel('x', fontsize=10*factor) 216 | ax[output_i, system_i].set_xticks([-2,0,2]) 217 | ax[output_i, system_i].set_xticklabels([-2,0,2], fontsize=8*factor) 218 | else: 219 | if output_i==(y_size-1): 220 | ax[output_i, system_i].set_xlabel('x', fontsize=10*factor) 221 | ax[output_i, system_i].set_xticks([-2,0,2]) 222 | ax[output_i, system_i].set_xticklabels([-2,0,2], fontsize=8*factor) 223 | if system_i ==0: 224 | ax[output_i, system_i].set_ylabel('y', fontsize=10*factor) 225 | ax[output_i, system_i].set_yticks([-2,0,2]) 226 | ax[output_i, system_i].set_yticklabels([-2,0,2], fontsize=8*factor) 227 | state_real+=1 228 | else: 229 | ax[output_i, system_i].axis('off') 230 | if save: 231 | fig.savefig('3x2_mix_state_assignment.png', bbox_inches='tight', dpi=900) 232 | plt.show() 233 | 234 | def plot_eigfuncs(model, dataset): 235 | ''' Plots the eigenfunctions of the approximation of the model given the dataset. 236 | 237 | Parameters 238 | ---------- 239 | model: ivampnets.iVAMPnetModel 240 | The model which transforms the input data. 
241 | dataset: TrajectoryDataset 242 | Dataset with data and data_lagged. 243 | ''' 244 | T_list = model.get_transition_matrix(dataset.data, dataset.data_lagged) 245 | pred_list = model.transform(dataset.data) 246 | number_subsystems = len(pred_list) 247 | transformed_data = [] 248 | output_sizes = [] 249 | for n in range(model._N): 250 | transformed_data.append(np.concatenate(pred_list[n], axis=0)) 251 | output_sizes.append(transformed_data[-1].shape[-1]) 252 | 253 | x_size = output_sizes[0] 254 | y_size = 2 255 | factor=2 256 | factor_x=1.5 257 | factor_y=2 258 | fig, ax = plt.subplots(x_size, y_size, sharex=True, sharey=True, figsize=(6*factor_x,4*factor_y)) 259 | ax[0,0].text(-2,6,'Subsystem I', fontsize=10*factor) 260 | ax[0,1].text(-2,6,'Subsystem II', fontsize=10*factor) 261 | i_state = 0 262 | for n in range(number_subsystems): 263 | K=T_list[n] 264 | eigv, eigvec = np.linalg.eig(K) 265 | ind_sort = np.argsort(eigv)[::-1] 266 | eigv = eigv[ind_sort] 267 | eigvec = eigvec[:,ind_sort] 268 | for i in range(K.shape[0]): 269 | output_i=i 270 | system_i=n 271 | eigv_i = eigvec[:,i] 272 | if i ==0: 273 | c=np.ones_like(transformed_data[n]@eigv_i) 274 | else: 275 | c=transformed_data[n]@eigv_i 276 | if output_i < output_sizes[system_i]: 277 | 278 | ax[output_i, system_i].scatter( 279 | *dataset.data.T, c=c, 280 | ) 281 | ax[output_i, system_i].set_title(r'$\lambda_{}={:.3}$'.format(output_i,eigv[i]), fontsize=10*factor) 282 | if output_i==(output_sizes[n]-1): 283 | ax[output_i, system_i].set_xlabel('x', fontsize=10*factor) 284 | ax[output_i, system_i].set_xticks([-2,0,2]) 285 | ax[output_i, system_i].set_xticklabels([-2,0,2], fontsize=8*factor) 286 | if system_i ==0: 287 | ax[output_i, system_i].set_ylabel('y', fontsize=10*factor) 288 | ax[output_i, system_i].set_yticks([-2,0,2]) 289 | ax[output_i, system_i].set_yticklabels([-2,0,2], fontsize=8*factor) 290 | 291 | ax[2, 1].axis('off') 292 | # fig.savefig('./Figs/3x2_mix_T_hidden_eigvec_estimated.png', bbox_inches='tight', dpi=900) 293 | plt.show() 294 | 295 | 296 | class HyperCube(): 297 | ''' Class for generating the data for the Hyper Cube sampled from a hidden Markov Chain. 298 | 299 | Parameters 300 | ----------- 301 | eps_list: list. 302 | List of the probability for each independent subsystem to stay in the same state 303 | lam: float. 304 | Coupling of the subsystems. If zero no coupling is active. 305 | mean: np.array 306 | Defines the mean values of the multivariant Gaussians, when generating a trajectory in the observable space. 307 | If None, predefined values are taken 308 | std: np.array 309 | Defines the std values of the same multivariant Gaussian. 
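A minimal usage sketch (the parameter values are only illustrative; eps_list needs an even number of entries because the subsystems are coupled pairwise):

>>> cube = HyperCube(eps_list=[0.1] * 10, lam=0.0)
>>> hidden_traj, obs_traj = cube.generate_traj(steps=10000, dim_noise=2)
>>> obs_traj.shape  # 10 informative dimensions plus 2 pure-noise dimensions
(10000, 12)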
310 | ''' 311 | def __init__(self, eps_list, lam=0.0, mean=None, cov=None): 312 | super().__init__() 313 | self.eps_list = eps_list 314 | self.lam = lam 315 | self.T_total, self.T_list, self.T_coupled_list = self.generate_hidden_matrix() 316 | self.msm = msm = deeptime.markov.msm.MarkovStateModel(self.T_total) 317 | self.N = len(eps_list) 318 | output_size = [2 for _ in range(self.N)] 319 | if mean is None: 320 | indices_fullsys = np.arange(2**self.N) 321 | indices_subsystems = np.unravel_index(indices_fullsys, output_size) 322 | indices_fullsys, indices_subsystems 323 | mean_per_state = [] 324 | for i in range(len(indices_fullsys)): 325 | list_ind = [indices_subsystems[n][i] for n in range(self.N)] 326 | mean_per_state.append(list_ind) 327 | mean_per_state = 2*np.array(mean_per_state) 328 | if cov is None: 329 | cov = .1 * np.eye(self.N) 330 | self.mean = mean_per_state 331 | self.cov = cov 332 | 333 | self.eigvals_list = [] 334 | self.eigvals_list_coupled = [] 335 | for i in range(self.N): 336 | Ti = self.T_list[i] 337 | eigv, eigvec = np.linalg.eig(Ti) 338 | ind_sort = np.argsort(eigv)[::-1] 339 | eigv = eigv[ind_sort] 340 | self.eigvals_list.append(eigv[1:]) 341 | if i<(self.N//2): 342 | Ti = self.T_coupled_list[i] 343 | eigv, eigvec = np.linalg.eig(Ti) 344 | ind_sort = np.argsort(eigv)[::-1] 345 | eigv = eigv[ind_sort] 346 | self.eigvals_list_coupled.append(eigv[1:-1]) 347 | 348 | def generate_traj(self, steps, angles=None, dim_noise=0): 349 | ''' Generate a trajectory with the defined hidden Markov Chain 350 | 351 | Parameters 352 | ---------- 353 | steps: int 354 | Number of timesteps 355 | angles: np.array 356 | Rotate the observable space by specified angles. 357 | dim_noise: int 358 | Number of noise dimensions 359 | Returns 360 | ------- 361 | hidden_state_traj: np.array 362 | The hidden Markov Chain. 363 | observable_traj: np.array 364 | The observable trajectory. 365 | 366 | ''' 367 | 368 | hidden_state_traj = self.msm.simulate(steps) 369 | observable_traj = np.zeros((hidden_state_traj.shape[0], self.N)) - 1 370 | n_hidden = self.T_total.shape[0] 371 | for state in range(n_hidden): 372 | ix = np.where(hidden_state_traj == state)[0] 373 | observable_traj[ix] = np.random.multivariate_normal(self.mean[state], self.cov, size=ix.shape[0]) 374 | 375 | if angles is not None: 376 | rot_matrix = self._get_rotation_matrix(angles) 377 | observable_traj = observable_traj @ rot_matrix 378 | if dim_noise>0: 379 | observable_traj = np.concatenate((observable_traj, np.random.randn(steps, dim_noise)), axis=1) 380 | 381 | return hidden_state_traj, observable_traj 382 | 383 | 384 | def generate_hidden_matrix(self): 385 | """ 386 | Generates hidden transition matrix. 
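For every pair of subsystems (2i, 2i+1) a coupled 4x4 block is built from eps_2i, eps_2i+1 and the coupling strength lam; for lam=0 this block reduces to the Kronecker product of the two independent 2x2 matrices. The full hidden transition matrix is then the Kronecker product of all coupled blocks, T_total = kron(T_pair_0, T_pair_1, ...).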
387 | """ 388 | T_list = [] 389 | T_coupled_list = [] 390 | lam = self.lam 391 | for i in range(len(self.eps_list)): 392 | epsi = self.eps_list[i] 393 | Ti = np.array([[1-epsi, epsi], 394 | [epsi, 1-epsi]]) 395 | T_list.append(Ti) 396 | if (i%2)==0: 397 | eps1 = self.eps_list[i] 398 | eps2 = self.eps_list[i+1] 399 | Tij = np.array([[(1 - eps2) * (1 - eps1) - lam, eps2 * (1 - eps1) - lam, (1 - eps2) * eps1+lam, eps2 * eps1+lam], 400 | [eps2 * (1 - eps1) - lam, (1 - eps2) * (1 - eps1) - lam, eps2 * eps1+lam, (1 - eps2) * eps1+lam], 401 | [(1 - eps2) * eps1 + lam, eps2 * eps1 + lam, (1 - eps2) * (1 - eps1) - lam, eps2 * (1 - eps1) - lam], 402 | [eps2 * eps1 + lam, (1 - eps2) * eps1 + lam, eps2 * (1 - eps1) - lam, (1 - eps2) * (1 - eps1) - lam]]) 403 | T_coupled_list.append(Tij) 404 | 405 | T_total = np.array([[1]]) 406 | for Ti in T_coupled_list: 407 | T_total = np.kron(T_total, Ti) 408 | return T_total, T_list, T_coupled_list 409 | 410 | def _get_rotation_matrix(self, angles=None): 411 | '''Goal is to create a rotation matrix which just rotates within a coupled 2D system, 412 | so each subsystem just needs information from two input features''' 413 | 414 | if type(angles)==type(None): 415 | angles = 2 * np.pi * np.random.random(self.N//2) 416 | 417 | rot_total = np.eye(self.N) 418 | for i in range(self.N//2): 419 | rot_temp = np.eye(self.N) 420 | start = i*2 421 | end = start+2 422 | rot = np.array([[ np.cos(angles[i]), np.sin(angles[i])], 423 | [-np.sin(angles[i]), np.cos(angles[i])]]) 424 | rot_temp[start:end, start:end] = rot 425 | # print(rot_temp) 426 | rot_total = rot_total @ rot_temp 427 | 428 | return rot_total 429 | 430 | 431 | def plot_its(its, lag, ylog=False, multiple_runs = False): 432 | '''Plots the provided implied timescales.' 433 | 434 | Parameters 435 | ---------- 436 | its: numpy array 437 | the its array returned by the function get_its 438 | lag: numpy array 439 | lag times array used to estimate the implied timescales 440 | ylog: Boolean, optional, default = False 441 | if true, the plot will be a logarithmic plot, otherwise it 442 | will be a semilogy plot 443 | multiple_runs: bool 444 | If True the provided its are expected to have a first dimension with number of runs which should be used to 445 | estimate a mean and an error estimate. 446 | 447 | ''' 448 | fig, ax = plt.subplots() 449 | 450 | func = ax.loglog if ylog else ax.semilogy 451 | if not multiple_runs: 452 | its = np.sort(its, axis=0) 453 | for i in range(np.shape(its)[0]): 454 | j=i+1 455 | if i==0: 456 | label='Model' 457 | else: 458 | label='' 459 | func(lag, its[-j] ,'o',lw=2, ms=7, label=label) 460 | else: 461 | its_mean = np.mean(its, 0)[::-1] 462 | its_std = np.std(its, 0)[::-1] 463 | for index_its, m, s in zip(range(len(its)), its_mean, its_std): 464 | func(lag, m, color = 'C{}'.format(index_its)) 465 | ax.fill_between(lag, m+s, m-s, color = 'C{}'.format(index_its), alpha = 0.2) 466 | 467 | func(lag,lag, 'k') 468 | ax.fill_between(lag, lag, 0.99, alpha=0.2, color='k'); 469 | return ax, fig 470 | 471 | def plot_hypercube_its(its, msmlags, its_true, ylog=False, save=None): 472 | '''Plots the provided implied timescales of the hypercube toy example.' 
473 | 474 | Parameters 475 | ---------- 476 | its: numpy array 477 | array of implied timescales to be plotted (e.g. as returned by iVAMPnetModel.timescales) 478 | msmlags: numpy array 479 | lag times array used to estimate the implied timescales 480 | ylog: Boolean, optional, default = False 481 | if True, the plot will be a log-log plot, otherwise it 482 | will be a semilogy plot 483 | its_true: numpy array 484 | true implied timescales of the hidden model, drawn as horizontal reference lines 485 | save: optional, default = None; if not None, the figure will be saved 486 | 487 | ''' 488 | ax, fig = plot_its(its, msmlags, ylog=ylog) 489 | 490 | for i, _its in enumerate(its_true.T): 491 | if i==0: 492 | label='True' 493 | else: 494 | label='' 495 | ax.hlines(_its, 1,msmlags.max(), color='C{}'.format(i), label=label) 496 | # ax.plot(msmlags, _its, 'x',ms=8, c='C{}'.format(i), label=label) 497 | ax.set_xlabel('Lagtime [a.u.]', fontsize=16) 498 | ax.set_ylabel('Implied Timescales [a.u.]', fontsize=16) 499 | ax.legend(fontsize=14, loc='lower right') 500 | ax.set_xticks([1,3,6,9]) 501 | ax.set_xlim(0.95,8.5) 502 | ax.set_ylim(1,60) 503 | ax.set_xticklabels([1,3,6,9], fontsize=14) 504 | ax.set_xticklabels([], fontsize=14, minor=True) 505 | # ax.set_xticks([2,4,5,7,8]) 506 | # ax.set_xticklabels(['','','','',''], fontsize=14) 507 | 508 | ax.set_yticks([1,10,50]) 509 | ax.set_yticklabels([1,10,50], fontsize=14) 510 | ax.tick_params(which='major', direction='out', length=6, width=2, colors='k', 511 | grid_color='k', grid_alpha=0.5) 512 | if save is not None: 513 | fig.savefig('./Hypercube_10_ITS.pdf', bbox_inches='tight') 514 | plt.show() 515 | 516 | def plot_protein_mask(mask, skip_start=4, save=None): 517 | ''' Helper function to plot the mask of a protein. 518 | Parameters 519 | ---------- 520 | mask: masks.Mask_proteins 521 | A Mask_proteins object from the masks.py file. 522 | skip_start: int 523 | How many residues were skipped from the beginning of the chain before including them in the distance 524 | calculation. 525 | save: bool 526 | If True, the plot will be saved.
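Example (as used in the SynaptotagminC2A notebook, where mask is the trained Mask_proteins module):

>>> plot_protein_mask(mask, skip_start=4)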
527 | ''' 528 | import matplotlib.lines as mlines 529 | attention = mask.get_softmax() 530 | values = np.squeeze(attention.detach().to('cpu').numpy()) 531 | plt.plot(np.arange(skip_start,attention.shape[0]+skip_start), values, linewidth=2) 532 | plt.xticks(fontsize=14) 533 | plt.xlabel('Residue', fontsize=16) 534 | plt.yticks(fontsize=14) 535 | plt.ylabel('Importance weight', fontsize=16) 536 | patch1 = mlines.Line2D([], [], color='C0',linewidth=3, 537 | label='Subsystem I') 538 | patch2 = mlines.Line2D([], [], color='C1',linewidth=3, 539 | label='Subsystem II') 540 | plt.legend(handles=[patch1, patch2], fontsize=14) 541 | if save: 542 | plt.savefig('./Syt_attention.pdf', bbox_inches='tight') 543 | plt.show() 544 | 545 | def plot_protein_its(its, lag, ylog=False, multiple_runs = False, percent=0.68): 546 | '''Plots the implied timescales calculated by the function 547 | 'get_its' 548 | 549 | Parameters 550 | ---------- 551 | its: numpy array 552 | the its array returned by the function get_its 553 | lag: numpy array 554 | lag times array used to estimate the implied timescales 555 | ylog: Boolean, optional, default = False 556 | if true, the plot will be a logarithmic plot, otherwise it 557 | will be a semilogy plot 558 | 559 | ''' 560 | fig, ax = plt.subplots(1,2, sharey=True, figsize=(12,4)) 561 | number_subsystems = len(its) 562 | 563 | labels=['Subsystem I', 'Subsystem II', 'Subsystem III', 'Subsystem IV', 'Subsystem V'] 564 | style = '-o' 565 | if not multiple_runs: 566 | for n, its_s in enumerate(its): 567 | func = ax[n].loglog if ylog else ax[n].semilogy 568 | 569 | for i in range(np.shape(its_s)[1]): 570 | 571 | if i==0: 572 | label=labels[n] 573 | else: 574 | label='' 575 | func(lag, its_s[:,-(i+1)], style, lw=2, ms=7,label=label) 576 | else: 577 | for n in range(number_subsystems): 578 | func = ax[n].loglog if ylog else ax[n].semilogy 579 | its_n = its[n] 580 | for index_its in range(its_n.shape[-1]): 581 | if index_its==0: 582 | label=labels[n] 583 | else: 584 | label='' 585 | its_all = its_n[:,:,index_its] 586 | sort_its = np.sort(its_all,axis=0) 587 | runs=its_all.shape[0] 588 | ind_upper_lower = int(runs/2- percent * runs/2)+1 589 | lower = sort_its[ind_upper_lower] 590 | upper = sort_its[-ind_upper_lower] 591 | m = its_all.mean(0) 592 | func(lag, m, style, lw=2, ms=7,label=label, color = 'C{}'.format(index_its)) 593 | ax[n].fill_between(lag, upper, lower, color = 'C{}'.format(index_its), alpha = 0.2) 594 | 595 | return ax, fig -------------------------------------------------------------------------------- /ivampnets.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Callable, Tuple, List 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import Dataset 7 | from itertools import chain 8 | 9 | 10 | from deeptime.base import Model, Transformer 11 | from deeptime.base_torch import DLEstimatorMixin 12 | from deeptime.util.torch import map_data 13 | from deeptime.markov.tools.analysis import pcca_memberships 14 | 15 | CLIP_VALUE = 1. 16 | 17 | def _inv(x, return_sqrt=False, epsilon=1e-6): 18 | '''Utility function that returns the inverse of a matrix, with the 19 | option to return the square root of the inverse matrix. 
20 | Parameters 21 | ---------- 22 | x: numpy array with shape [m,m] 23 | matrix to be inverted 24 | 25 | ret_sqrt: bool, optional, default = False 26 | if True, the square root of the inverse matrix is returned instead 27 | Returns 28 | ------- 29 | x_inv: numpy array with shape [m,m] 30 | inverse of the original matrix 31 | ''' 32 | 33 | # Calculate eigvalues and eigvectors 34 | eigval_all, eigvec_all = torch.symeig(x, eigenvectors=True) 35 | # eigval_all, eigvec_all = torch.linalg.eigh(x, UPLO='U') 36 | # Filter out eigvalues below threshold and corresponding eigvectors 37 | # eig_th = torch.Tensor(epsilon) 38 | index_eig = eigval_all > epsilon 39 | # print(index_eig) 40 | eigval = eigval_all[index_eig] 41 | eigvec = eigvec_all[:,index_eig] 42 | 43 | # Build the diagonal matrix with the filtered eigenvalues or square 44 | # root of the filtered eigenvalues according to the parameter 45 | if return_sqrt: 46 | diag = torch.diag(torch.sqrt(1/eigval)) 47 | else: 48 | diag = torch.diag(1/eigval) 49 | # print(diag.shape, eigvec.shape) 50 | # Rebuild the square root of the inverse matrix 51 | x_inv = torch.matmul(eigvec, torch.matmul(diag, eigvec.T)) 52 | 53 | return x_inv 54 | 55 | 56 | def symeig_reg(mat, epsilon: float = 1e-6, mode='regularize', eigenvectors=True) \ 57 | -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 58 | r""" Solves a eigenvector/eigenvalue decomposition for a hermetian matrix also if it is rank deficient. 59 | 60 | Parameters 61 | ---------- 62 | mat : torch.Tensor 63 | the hermetian matrix 64 | epsilon : float, default=1e-6 65 | Cutoff for eigenvalues. 66 | mode : str, default='regularize' 67 | Whether to truncate eigenvalues if they are too small or to regularize them by taking the absolute value 68 | and adding a small positive constant. :code:`trunc` leads to truncation, :code:`regularize` leads to epsilon 69 | being added to the eigenvalues after taking the absolute value 70 | eigenvectors : bool, default=True 71 | Whether to compute eigenvectors. 72 | 73 | Returns 74 | ------- 75 | (eigval, eigvec) : Tuple[torch.Tensor, Optional[torch.Tensor]] 76 | Eigenvalues and -vectors. 77 | """ 78 | assert mode in sym_inverse.valid_modes, f"Invalid mode {mode}, supported are {sym_inverse.valid_modes}" 79 | 80 | if mode == 'regularize': 81 | identity = torch.eye(mat.shape[0], dtype=mat.dtype, device=mat.device) 82 | mat = mat + epsilon * identity 83 | 84 | # Calculate eigvalues and potentially eigvectors 85 | eigval, eigvec = torch.symeig(mat, eigenvectors=True) 86 | # eigval, eigvec = torch.linalg.eigh(mat, UPLO='U') 87 | 88 | if eigenvectors: 89 | eigvec = eigvec.transpose(0, 1) 90 | 91 | if mode == 'trunc': 92 | # Filter out Eigenvalues below threshold and corresponding Eigenvectors 93 | mask = eigval > epsilon 94 | eigval = eigval[mask] 95 | if eigenvectors: 96 | eigvec = eigvec[mask] 97 | elif mode == 'regularize': 98 | # Calculate eigvalues and eigvectors 99 | eigval = torch.abs(eigval) 100 | elif mode == 'clamp': 101 | eigval = torch.clamp_min(eigval, min=epsilon) 102 | 103 | else: 104 | raise RuntimeError("Invalid mode! Should have been caught by the assertion.") 105 | 106 | if eigenvectors: 107 | return eigval, eigvec 108 | else: 109 | return eigval, eigvec 110 | 111 | 112 | def sym_inverse(mat, epsilon: float = 1e-6, return_sqrt=False, mode='regularize', return_both=False): 113 | """ Utility function that returns the inverse of a matrix, with the 114 | option to return the square root of the inverse matrix. 
115 | 116 | Parameters 117 | ---------- 118 | mat: numpy array with shape [m,m] 119 | Matrix to be inverted. 120 | epsilon : float 121 | Cutoff for eigenvalues. 122 | return_sqrt: bool, optional, default = False 123 | if True, the square root of the inverse matrix is returned instead 124 | mode: str, default='trunc' 125 | Whether to truncate eigenvalues if they are too small or to regularize them by taking the absolute value 126 | and adding a small positive constant. :code:`trunc` leads to truncation, :code:`regularize` leads to epsilon 127 | being added to the eigenvalues after taking the absolute value 128 | return_both: bool, default=False 129 | Whether to return the sqrt and its inverse or simply the inverse 130 | Returns 131 | ------- 132 | x_inv: numpy array with shape [m,m] 133 | inverse of the original matrix 134 | """ 135 | if mode=='old': 136 | return _inv(mat, epsilon=epsilon, return_sqrt=return_sqrt) 137 | eigval, eigvec = symeig_reg(mat, epsilon, mode) 138 | 139 | # Build the diagonal matrix with the filtered eigenvalues or square 140 | # root of the filtered eigenvalues according to the parameter 141 | if return_sqrt: 142 | diag_inv = torch.diag(torch.sqrt(1. / eigval)) 143 | if return_both: 144 | diag = torch.diag(torch.sqrt(eigval)) 145 | else: 146 | diag_inv = torch.diag(1. / eigval) 147 | if return_both: 148 | diag = torch.diag(eigval) 149 | if not return_both: 150 | return eigvec.t() @ diag_inv @ eigvec 151 | else: 152 | return eigvec.t() @ diag_inv @ eigvec, eigvec.t() @ diag @ eigvec 153 | 154 | 155 | sym_inverse.valid_modes = ('trunc', 'regularize', 'clamp', 'old') 156 | 157 | 158 | def covariances(x: torch.Tensor, y: torch.Tensor, remove_mean: bool = True): 159 | """Computes instantaneous and time-lagged covariances matrices. 160 | 161 | Parameters 162 | ---------- 163 | x : (T, n) torch.Tensor 164 | Instantaneous data. 165 | y : (T, n) torch.Tensor 166 | Time-lagged data. 167 | remove_mean: bool, default=True 168 | Whether to remove the mean of x and y. 169 | 170 | Returns 171 | ------- 172 | cov_00 : (n, n) torch.Tensor 173 | Auto-covariance matrix of x. 174 | cov_0t : (n, n) torch.Tensor 175 | Cross-covariance matrix of x and y. 176 | cov_tt : (n, n) torch.Tensor 177 | Auto-covariance matrix of y. 178 | 179 | See Also 180 | -------- 181 | deeptime.covariance.Covariance : Estimator yielding these kind of covariance matrices based on raw numpy arrays 182 | using an online estimation procedure. 183 | """ 184 | 185 | assert x.shape == y.shape, "x and y must be of same shape" 186 | batch_size = x.shape[0] 187 | 188 | if remove_mean: 189 | x = x - x.mean(dim=0, keepdim=True) 190 | y = y - y.mean(dim=0, keepdim=True) 191 | 192 | # Calculate the cross-covariance 193 | y_t = y.transpose(0, 1) 194 | x_t = x.transpose(0, 1) 195 | cov_01 = 1 / (batch_size - 1) * torch.matmul(x_t, y) 196 | # Calculate the auto-correlations 197 | cov_00 = 1 / (batch_size - 1) * torch.matmul(x_t, x) 198 | cov_11 = 1 / (batch_size - 1) * torch.matmul(y_t, y) 199 | 200 | return cov_00, cov_01, cov_11 201 | 202 | 203 | valid_score_methods = ('VAMP1', 'VAMP2', 'VAMPE') 204 | 205 | def VAMPE_score(chi_t, chi_tau, epsilon=1e-6, mode='regularize'): 206 | '''Calculates the VAMPE score for an individual VAMPnet model. Furthermore, it returns 207 | the singular functions and singular values to construct the global operator. 208 | Parameters 209 | ---------- 210 | chi_t: (T, n) torch.Tensor 211 | Instantaneous data with shape batchsize x outputsize. 
212 | chi_tau: (T, n) torch.Tensor 213 | Time-lagged data with shape batchsize x outputsize. 214 | 215 | Returns 216 | ------- 217 | score: (1) torch.Tensor 218 | VAMPE score. 219 | S: (n, n) torch.Tensor 220 | Singular values on the diagonal of the matrix. 221 | u: (n, n) torch.Tensor 222 | Left singular functions. 223 | v: (n, n) torch.Tensor 224 | Right singular functions. 225 | trace_cov: (1) torch.Tensor 226 | Sum of eigenvalues of the covariance matrix. 227 | ''' 228 | shape = chi_t.shape 229 | 230 | batch_size = shape[0] 231 | 232 | x, y = chi_t, chi_tau 233 | 234 | # Calculate the covariance matrices 235 | cov_00 = 1/(batch_size) * torch.matmul(x.T, x) 236 | cov_11 = 1/(batch_size) * torch.matmul(y.T, y) 237 | cov_01 = 1/(batch_size) * torch.matmul(x.T, y) 238 | 239 | # Calculate the inverse of the self-covariance matrices 240 | cov_00_inv = sym_inverse(cov_00, return_sqrt = True, epsilon=epsilon, mode=mode) 241 | cov_11_inv = sym_inverse(cov_11, return_sqrt = True, epsilon=epsilon, mode=mode) 242 | 243 | # Estimate Vamp-matrix 244 | K = torch.matmul(cov_00_inv, torch.matmul(cov_01, cov_11_inv)) 245 | # Estimate the singular value decomposition 246 | a, sing_values, b = torch.svd(K, compute_uv=True) 247 | # Estimate the singular functions 248 | u = cov_00_inv @ a 249 | v = cov_11_inv @ b 250 | S = torch.diag(sing_values) 251 | # Estimate the VAMPE score 252 | term1 = 2* S @ u.T @ cov_01 @ v 253 | term2 = S @ u.T @ cov_00 @ u @ S @ v.T @ cov_11 @ v 254 | 255 | score = torch.trace(term1 - term2) 256 | # expand zero dimension for summation 257 | score = torch.unsqueeze(score, dim=0) 258 | 259 | # estimate sum of eigenvalues of the covariance matrix to enforce harder assignment 260 | trace_cov = torch.unsqueeze(torch.trace(cov_00), dim=0) 261 | 262 | return score, S, u, v, trace_cov 263 | 264 | 265 | def VAMPE_score_pair(chi1_t, chi1_tau, chi2_t, chi2_tau, S1, S2, u1, u2, v1, v2, device='cpu'): 266 | '''Calculates the VAMPE score for a pair of individual VAMPnet models. The operator is constructed 267 | by the two individual operators and evaluate on the outer space constructed by the individual 268 | feature functions. 269 | Parameters 270 | ---------- 271 | chi1_t: (T, n) torch.Tensor 272 | Instantaneous data with shape batchsize x outputsize of VAMPnet 1. 273 | chi1_tau: (T, n) torch.Tensor 274 | Time-lagged data with shape batchsize x outputsize of VAMPnet 1. 275 | chi2_t: (T, m) torch.Tensor 276 | Instantaneous data with shape batchsize x outputsize of VAMPnet 2. 277 | chi2_tau: (T, m) torch.Tensor 278 | Time-lagged data with shape batchsize x outputsize of VAMPnet 2. 279 | S1: (n, n) torch.Tensor 280 | Singular values of VAMPnet 1. 281 | S2: (m, m) torch.Tensor 282 | Singular values of VAMPnet 2. 283 | u1: (n, n) torch.Tensor 284 | Left singular functions of VAMPnet 1. 285 | u2: (m, m) torch.Tensor 286 | Left singular functions of VAMPnet 2. 287 | v1: (n, n) torch.Tensor 288 | Right singular functions of VAMPnet 1. 289 | v2: (m, m) torch.Tensor 290 | Right singular functions of VAMPnet 2. 291 | 292 | Returns 293 | ------- 294 | score: (1) torch.Tensor 295 | VAMPE score for the performance of the constructed operator on the global features. 296 | pen_C00: (1) torch.Tensor 297 | Error of the left singular functions. 298 | pen_C11: (1) torch.Tensor 299 | Error of the right singular functions. 300 | pen_C01: (1) torch.Tensor 301 | Error of the correlation of the two singular functions. 
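Notes
-----
The pair operator is assembled from the individual singular components as outer (Kronecker) products, U = u1 (x) u2, V = v1 (x) v2, K = S1 (x) S2, and is evaluated on the outer-product feature space chi_t = chi1_t (x) chi2_t. If the two subsystems were truly independent and well converged, one would have U^T C00 U = I, V^T C11 V = I and U^T C01 V = K; pen_C00, pen_C11 and pen_C01 measure the normalized absolute deviation from these identities.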
302 | ''' 303 | 304 | 305 | shape1 = chi1_t.shape 306 | shape2 = chi2_t.shape 307 | new_shape = shape1[1] * shape2[1] 308 | batch_size = shape1[0] 309 | 310 | # construct the singular functions for the global model from both individual subsystems 311 | U_train = torch.reshape(u1[:,None,:,None] * u2[None,:,None,:], (new_shape, new_shape)) 312 | V_train = torch.reshape(v1[:,None,:,None] * v2[None,:,None,:], (new_shape, new_shape)) 313 | K_train = torch.reshape(S1[:,None,:,None] * S2[None,:,None,:], (new_shape, new_shape)) 314 | # construct the global feature space as the outer product of the individual spaces 315 | chi_t_outer = torch.reshape(chi1_t[:,:,None] * chi2_t[:,None,:], (batch_size,new_shape)) 316 | chi_tau_outer = torch.reshape(chi1_tau[:,:,None] * chi2_tau[:,None,:], (batch_size,new_shape)) 317 | 318 | x, y = chi_t_outer, chi_tau_outer 319 | # Calculate the covariance matrices 320 | cov_00 = 1/(batch_size) * torch.matmul(x.T, x) 321 | cov_11 = 1/(batch_size) * torch.matmul(y.T, y) 322 | cov_01 = 1/(batch_size) * torch.matmul(x.T, y) 323 | 324 | # map the matrices on the singular functions 325 | 326 | C00_map = U_train.T @ cov_00 @ U_train 327 | C11_map = V_train.T @ cov_11 @ V_train 328 | C01_map = U_train.T @ cov_01 @ V_train 329 | # helper function to estimate errors from optimal solution 330 | unit_matrix = torch.eye(new_shape, device=device) 331 | # Estimate the deviation from the optimal behaviour if the two system would be truly independent 332 | pen_C00 = torch.unsqueeze(torch.sum(torch.abs(unit_matrix - C00_map)), dim=0) / (new_shape-1)**2 333 | pen_C11 = torch.unsqueeze(torch.sum(torch.abs(unit_matrix - C11_map)), dim=0) / (new_shape-1)**2 334 | pen_C01 = torch.unsqueeze(torch.sum(torch.abs(C01_map - K_train)), dim=0) / (new_shape-1)**2 335 | # Estimate the VAMPE score of how well the constructed operator predicts the dynamic of the global feature space 336 | term1 = 2 * K_train @ C01_map 337 | term2 = K_train @ C00_map @ K_train @ C11_map 338 | 339 | score = torch.trace(term1 - term2) 340 | # add zero dimension for summation 341 | score = torch.unsqueeze(score, dim=0) 342 | 343 | return score, pen_C00, pen_C11, pen_C01 344 | 345 | def score_loss(score1, score2, score12): 346 | ''' Estimates the discrepancy of the global score and the two individual VAMPE scores. 347 | Parameters 348 | ---------- 349 | score1: (1) torch.Tensor 350 | Score of VAMPnets 1. 351 | score2: (1) torch.Tensor 352 | Score of VAMPnets 2. 353 | score12: (1) torch.Tensor 354 | Score of the constructed global operator. 355 | 356 | Returns 357 | ------- 358 | pen_scores: (1) torch.Tensor 359 | Error of the scores due to non independent behavior. 360 | 361 | ''' 362 | prod_score = score1 * score2 363 | # Estimate normalizer to rescale them but not use them for gradient updates. 364 | norm1 = torch.abs(prod_score.detach()) 365 | norm2 = torch.abs(score12.detach()) 366 | diff = torch.abs(score12 - prod_score) 367 | score_diff = torch.unsqueeze(diff / norm1, dim=0) 368 | score_diff2 = torch.unsqueeze(diff / norm2, dim=0) 369 | pen_scores = (score_diff + score_diff2) / 2. 370 | 371 | return pen_scores 372 | 373 | def score_all_systems(chi_t_list, chi_tau_list, epsilon=1e-6, mode='regularize'): 374 | ''' Estimates all scores and singular functions/values for all VAMPnets. 375 | Parameters 376 | ---------- 377 | chi_t_list: list of length number subsystems 378 | List of the feature functions of all VAMPnets for the instantaneous data. 
379 | chi_tau_list: list of length number subsystems 380 | List of the feature functions of all VAMPnets for the time-lagged data. 381 | 382 | Returns 383 | ------- 384 | scores_single: list of length number subsystems 385 | List of the individual scores of all subsystems. 386 | S_single: list of length number subsystems 387 | List of the individual singular values of all subsystems. 388 | u_single: list of length number subsystems 389 | List of the individual left singular functions of all subsystems. 390 | v_single: list of length number subsystems 391 | List of the individual right singular functions of all subsystems. 392 | trace_single: list of length number subsystems 393 | List of the traces of covariance matrices of all subsystems. 394 | ''' 395 | scores_single = [] 396 | u_single = [] 397 | v_single = [] 398 | S_single = [] 399 | trace_single = [] 400 | N = len(chi_t_list) 401 | for i in range(N): 402 | 403 | chi_i_t = chi_t_list[i] 404 | chi_i_tau = chi_tau_list[i] 405 | 406 | score_i, S_i , u_i, v_i, trace_i = VAMPE_score(chi_i_t, chi_i_tau, epsilon=epsilon, mode=mode) 407 | scores_single.append(score_i) 408 | trace_single.append(trace_i) 409 | u_single.append(u_i) 410 | v_single.append(v_i) 411 | S_single.append(S_i) 412 | 413 | return scores_single, S_single, u_single, v_single, trace_single 414 | 415 | def score_all_outer_systems(chi_t_list, chi_tau_list, S_list, u_list, v_list, device='cpu'): 416 | ''' Estimates the global scores and all penalties. 417 | Parameters 418 | ---------- 419 | chi_t_list: list of length number subsystems 420 | List of the feature functions of all VAMPnets for the instantaneous data. 421 | chi_tau_list: list of length number subsystems 422 | List of the feature functions of all VAMPnets for the time-lagged data. 423 | S_list: list of length number subsystems 424 | List of the individual singular values of all subsystems. 425 | u_list: list of length number subsystems 426 | List of the individual left singular functions of all subsystems. 427 | v_list: list of length number subsystems 428 | List of the individual right singular functions of all subsystems. 429 | 430 | Returns 431 | ------- 432 | scores_outer: list of length N*(N-1)/2, where N is the number of subsystems 433 | List of global scores of all pairs of VAMPnets. 434 | pen_C00_map: list of length number pairs. 435 | List of penalties of singular left functions. 436 | pen_C11_map: list of length number pairs. 437 | List of penalties of singular right functions. 438 | pen_C01_map: list of length number pairs. 439 | List of penalties of the correlation of the singular functions. 440 | ''' 441 | scores_outer = [] 442 | pen_C00_map = [] 443 | pen_C11_map = [] 444 | pen_C01_map = [] 445 | # N_pair = len(scores_outer) 446 | # N = int(0.5+np.sqrt(0.25+2*N_pair)) 447 | N = len(chi_t_list) 448 | for i in range(N): 449 | 450 | for j in range(i+1,N): 451 | 452 | score_ij, pen_C00_ij, pen_C11_ij, pen_C01_ij = VAMPE_score_pair( 453 | chi_t_list[i], chi_tau_list[i], 454 | chi_t_list[j], chi_tau_list[j], 455 | S_list[i], S_list[j], 456 | u_list[i], u_list[j], 457 | v_list[i], v_list[j], 458 | device=device) 459 | scores_outer.append(score_ij) 460 | pen_C00_map.append(pen_C00_ij) 461 | pen_C11_map.append(pen_C11_ij) 462 | pen_C01_map.append(pen_C01_ij) 463 | 464 | return scores_outer, pen_C00_map, pen_C11_map, pen_C01_map 465 | 466 | 467 | def pen_all_scores(scores, score_pairs): 468 | ''' Estimate all penalties of the individual and global scores. 
469 | Parameters 470 | ---------- 471 | scores: list of length number subsystems 472 | List of the VAMPE scores of all VAMPnets. 473 | score_pairs: list of length number of pairs 474 | List of the pairwise global scores of all combinations of VAMPnets. 475 | 476 | Returns 477 | ------- 478 | pen_scores: list of length number of pairs 479 | List of the pairwise penalties of all combinations of VAMPnets. 480 | ''' 481 | pen_scores = [] 482 | counter = 0 483 | N = len(scores) 484 | for i in range(N): 485 | 486 | for j in range(i+1, N): 487 | pen_score_ij = score_loss(scores[i], scores[j], score_pairs[counter]) 488 | pen_scores.append(pen_score_ij) 489 | counter+=1 490 | 491 | return pen_scores 492 | 493 | def estimate_transition_matrix(chi_t, chi_tau, mode='regularize', epsilon=1e-6): 494 | ''' Estimates the transition matrix T = C00^-1 C01 from the feature vectors at time t and time t+tau. 495 | 496 | Returns the transition matrix as a numpy array on the CPU. 497 | ''' 498 | shape = chi_t.shape 499 | 500 | batch_size = shape[0] 501 | 502 | x, y = chi_t, chi_tau 503 | 504 | # Calculate the covariance matrices 505 | cov_00 = 1/(batch_size) * torch.matmul(x.T, x) 506 | cov_11 = 1/(batch_size) * torch.matmul(y.T, y) 507 | cov_01 = 1/(batch_size) * torch.matmul(x.T, y) 508 | 509 | # Calculate the inverse of the self-covariance matrices 510 | cov_00_inv = sym_inverse(cov_00, return_sqrt = False, epsilon=epsilon, mode=mode) 511 | 512 | T = cov_00_inv @ cov_01 513 | 514 | return T.detach().to('cpu').numpy() 515 | 516 | class iVAMPnetModel(Transformer, Model): 517 | r""" 518 | An iVAMPNet model which can be fit to data optimizing for one of the implemented VAMP scores. 519 | 520 | Parameters 521 | ---------- 522 | lobes : list of torch.nn.Module 523 | List of the lobes of each VAMPNet. See also :class:`deeptime.util.torch.MLP`. 524 | mask : layer 525 | Layer which masks the inputs to the different iVAMPnets. This makes it possible to interpret which 526 | part of the input is important for each lobe. 527 | dtype : data type, default=np.float32 528 | The data type for which operations should be performed. Leads to an appropriate cast within fit and 529 | transform methods. 530 | device : device, default=None 531 | The device for the lobes. Can be None which defaults to CPU. 532 | 533 | See Also 534 | -------- 535 | iVAMPNet : The corresponding estimator. 536 | """ 537 | 538 | def __init__(self, lobes: list, mask, 539 | dtype=np.float32, device=None, epsilon=1e-6, mode='regularize'): 540 | super().__init__() 541 | self._lobes = lobes 542 | self._mask = mask 543 | self._N = len(lobes) 544 | self._dtype = dtype 545 | if self._dtype == np.float32: 546 | for n in range(self._N): 547 | self._lobes[n] = self._lobes[n].float() 548 | elif self._dtype == np.float64: 549 | for n in range(self._N): 550 | self._lobes[n] = self._lobes[n].double() 551 | self._device = device 552 | self._epsilon = epsilon 553 | self._mode = mode 554 | 555 | def transform(self, data, numpy=True, batchsize=0, **kwargs): 556 | '''Transforms the supplied data with the model. It outputs the fuzzy state assignment for each 557 | subsystem. 558 | 559 | Parameters 560 | ---------- 561 | data: nd.array or torch.Tensor 562 | Data which should be transformed to a fuzzy state assignment. 563 | numpy: bool, default=True 564 | Whether the output should be converted to a numpy array. 565 | batchsize: int, default=0 566 | The batch size used to transform the data in chunks, which is useful if the 567 | data does not fit into memory. If batchsize<=0 the whole dataset is 568 | transformed at once.
569 | 570 | Returns 571 | ------- 572 | out: nd.array or torch.Tensor 573 | The transformed data. If numpy=True the output will be a nd.array otherwise a torch.Tensor. 574 | ''' 575 | for lobe in self._lobes: 576 | lobe.eval() 577 | 578 | if batchsize>0: 579 | batches = data.shape[0]//batchsize + (data.shape[0]%batchsize>0) 580 | if isinstance(data, torch.Tensor): 581 | torch.split(data, batches) 582 | else: 583 | data = np.array_split(data, batches) 584 | 585 | out = [[] for _ in range(self._N)] 586 | with torch.no_grad(): 587 | for data_tensor in map_data(data, device=self._device, dtype=self._dtype): 588 | mask_data = self._mask(data_tensor) 589 | for n in range(self._N): 590 | mask_n = torch.squeeze(mask_data[n], dim=2) 591 | if numpy: 592 | out[n].append(self._lobes[n](mask_n).cpu().numpy()) 593 | else: 594 | out[n].append(self._lobes[n](mask_n)) 595 | return out if len(out) > 1 else out[0] 596 | 597 | 598 | def get_transition_matrix(self, data_0, data_t, batchsize=0): 599 | ''' Estimates the transition matrix based on the two provided datasets, where each frame 600 | should be lagtime apart. 601 | 602 | Parameters 603 | ---------- 604 | data_0: nd.array or torch.Tensor 605 | The instantaneous data. 606 | data_t: nd.array or torch.Tensor 607 | The time-lagged data. 608 | batchsize: int, default=0 609 | The batchsize which should be used to predict one chunk of the data, which is useful, if 610 | data does not fit into the memory. If batchsize<=0 the whole dataset will be simultaneously 611 | transformed. 612 | 613 | Returns 614 | ------- 615 | T_list: list 616 | The list of the transition matrices of all subsystems. 617 | 618 | ''' 619 | 620 | chi_t_list = self.transform(data_0, numpy=False, batchsize=batchsize) 621 | chi_tau_list = self.transform(data_t, numpy=False, batchsize=batchsize) 622 | T_list = [] 623 | for n in range(self._N): 624 | chi_t, chi_tau = torch.cat(chi_t_list[n], dim=0), torch.cat(chi_tau_list[n], dim=0) 625 | K = estimate_transition_matrix(chi_t, chi_tau, mode=self._mode, epsilon=self._epsilon).astype('float64') 626 | # Converting to double precision destroys the normalization 627 | T = K / K.sum(axis=1)[:, None] 628 | T_list.append(T) 629 | return T_list 630 | 631 | def timescales(self, data_0, data_t, tau, batchsize=0): 632 | ''' Estimates the timescales of the model given the provided data. 633 | 634 | Parameters 635 | ---------- 636 | data_0: nd.array or torch.Tensor 637 | The instantaneous data. 638 | data_t: nd.array or torch.Tensor 639 | The time-lagged data. 640 | tau: int 641 | The time-lagged used for the data. 642 | batchsize: int, default=0 643 | The batchsize which should be used to predict one chunk of the data, which is useful, if 644 | data does not fit into the memory. If batchsize<=0 the whole dataset will be simultaneously 645 | transformed. 646 | 647 | Returns 648 | ------- 649 | its: list 650 | The list of the implied timescales of all subsystems. 
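Example (a sketch mirroring the SynaptotagminC2A notebook; data_t and data_tau stand for arrays of instantaneous and time-lagged frames that are tau steps apart):

>>> its = model.timescales(data_t, data_tau, tau=10, batchsize=10000)

Each entry its[n] contains the implied timescales of subsystem n, i.e. output_size_n - 1 values, since the stationary eigenvalue 1 is removed.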
651 | 652 | ''' 653 | 654 | T_list = self.get_transition_matrix(data_0, data_t, batchsize=batchsize) 655 | its = [] 656 | for T in T_list: 657 | eigvals = np.linalg.eigvals(T) 658 | eigvals_sort = np.sort(eigvals)[:-1] # remove eigenvalue 1 659 | its.append( - tau/np.log(np.abs(eigvals_sort[::-1]))) 660 | 661 | return its 662 | 663 | 664 | class iVAMPnet(DLEstimatorMixin, Transformer): 665 | r""" Implementation of iVAMPNets :cite:`vnet-mardt2018vampnets` which try to find an optimal featurization of 666 | data based on a VAMPE score :cite:`vnet-wu2020variational` by using neural networks as featurizing transforms 667 | which are sought to be independent. This estimator is also a transformer 668 | and can be used to transform data into the optimized space. From there it can either be used to estimate 669 | Markov state models via making assignment probabilities crisp (in case of softmax output distributions) or 670 | to estimate the Koopman operator using the :class:`VAMP ` estimator. 671 | 672 | Parameters 673 | ---------- 674 | lobes : list of torch.nn.Module 675 | List of the lobes of each VAMPNet. See also :class:`deeptime.util.torch.MLP`. 676 | mask : torch.nn.module 677 | Module which masks the input features to assign them to a specific subsystem. 678 | device : torch device, default=None 679 | The device on which the torch modules are executed. 680 | optimizer : str or Callable, default='Adam' 681 | An optimizer which can either be provided in terms of a class reference (like `torch.optim.Adam`) or 682 | a string (like `'Adam'`). Defaults to Adam. 683 | learning_rate : float, default=5e-4 684 | The learning rate of the optimizer. 685 | score_mode : str, default='regularize' 686 | The mode under which inverses of positive semi-definite matrices are estimated. Per default, the matrices 687 | are perturbed by a small constant added to the diagonal. This makes sure that eigenvalues are not too 688 | small. For a complete list of modes, see :meth:`sym_inverse`. 689 | epsilon : float, default=1e-6 690 | The strength of the regularization under which matrices are inverted. Meaning depends on the score_mode, 691 | see :meth:`sym_inverse`. 692 | dtype : dtype, default=np.float32 693 | The data type of the modules and incoming data. 694 | shuffle : bool, default=True 695 | Whether to shuffle data during training after each epoch. 696 | 697 | See Also 698 | -------- 699 | deeptime.decomposition.VAMP 700 | 701 | References 702 | ---------- 703 | .. 
bibliography:: /references.bib 704 | :style: unsrt 705 | :filter: docname in docnames 706 | :keyprefix: vnet- 707 | """ 708 | _MUTABLE_INPUT_DATA = True 709 | 710 | def __init__(self, lobes: list, mask: nn.Module, 711 | device=None, optimizer: Union[str, Callable] = 'Adam', learning_rate: float = 5e-4, learning_rate_mask: float = 5e-4, 712 | score_mode: str = 'regularize', epsilon: float = 1e-6, 713 | dtype=np.float32, shuffle: bool = True): 714 | super().__init__() 715 | self.N = len(lobes) 716 | self.lobes = lobes 717 | self.mask = mask 718 | self.score_mode = score_mode 719 | self._step = 0 720 | self.shuffle = shuffle 721 | self._epsilon = epsilon 722 | self.device = device 723 | self.learning_rate = learning_rate 724 | self.learning_rate_mask = learning_rate_mask 725 | self.dtype = dtype 726 | self.optimizer_lobes = [torch.optim.Adam(lobe.parameters(), lr=self.learning_rate) for lobe in self.lobes] 727 | self.optimizer_mask = torch.optim.Adam(self.mask.parameters(), lr=self.learning_rate_mask) 728 | self._train_scores = [] 729 | self._validation_scores = [] 730 | self._train_vampe = [] 731 | self._train_pen_C00 = [] 732 | self._train_pen_C11 = [] 733 | self._train_pen_C01 = [] 734 | self._train_pen_scores = [] 735 | self._train_trace = [] 736 | self._validation_vampe = [] 737 | self._validation_pen_C00 = [] 738 | self._validation_pen_C11 = [] 739 | self._validation_pen_C01 = [] 740 | self._validation_pen_scores = [] 741 | self._validation_trace = [] 742 | 743 | @property 744 | def train_scores(self) -> np.ndarray: 745 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 746 | 747 | :type: (T, 2) ndarray 748 | """ 749 | return np.array(self._train_scores) 750 | @property 751 | def train_vampe(self) -> np.ndarray: 752 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 753 | 754 | :type: (T, 2) ndarray 755 | """ 756 | return np.array(self._train_vampe) 757 | @property 758 | def train_pen_C00(self) -> np.ndarray: 759 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 760 | 761 | :type: (T, 2) ndarray 762 | """ 763 | return np.array(self._train_pen_C00) 764 | @property 765 | def train_pen_C11(self) -> np.ndarray: 766 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 767 | 768 | :type: (T, 2) ndarray 769 | """ 770 | return np.array(self._train_pen_C11) 771 | @property 772 | def train_pen_C01(self) -> np.ndarray: 773 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 774 | 775 | :type: (T, 2) ndarray 776 | """ 777 | return np.array(self._train_pen_C01) 778 | @property 779 | def train_pen_scores(self) -> np.ndarray: 780 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 781 | 782 | :type: (T, 2) ndarray 783 | """ 784 | return np.array(self._train_pen_scores) 785 | @property 786 | def train_trace(self) -> np.ndarray: 787 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 788 | 789 | :type: (T, 2) ndarray 790 | """ 791 | return np.array(self._train_trace) 792 | 793 | @property 794 | def validation_scores(self) -> np.ndarray: 795 | r""" The collected validation scores. First dimension contains the step, second dimension the score. 796 | Initially empty. 
797 | 798 | :type: (T, 2) ndarray 799 | """ 800 | return np.array(self._validation_scores) 801 | @property 802 | def validation_vampe(self) -> np.ndarray: 803 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 804 | 805 | :type: (T, 2) ndarray 806 | """ 807 | return np.array(self._validation_vampe) 808 | @property 809 | def validation_pen_C00(self) -> np.ndarray: 810 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 811 | 812 | :type: (T, 2) ndarray 813 | """ 814 | return np.array(self._validation_pen_C00) 815 | @property 816 | def validation_pen_C11(self) -> np.ndarray: 817 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 818 | 819 | :type: (T, 2) ndarray 820 | """ 821 | return np.array(self._validation_pen_C11) 822 | @property 823 | def validation_pen_C01(self) -> np.ndarray: 824 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 825 | 826 | :type: (T, 2) ndarray 827 | """ 828 | return np.array(self._validation_pen_C01) 829 | @property 830 | def validation_pen_scores(self) -> np.ndarray: 831 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 832 | 833 | :type: (T, 2) ndarray 834 | """ 835 | return np.array(self._validation_pen_scores) 836 | @property 837 | def validation_trace(self) -> np.ndarray: 838 | r""" The collected train scores. First dimension contains the step, second dimension the score. Initially empty. 839 | 840 | :type: (T, 2) ndarray 841 | """ 842 | return np.array(self._validation_trace) 843 | @property 844 | def epsilon(self) -> float: 845 | r""" Regularization parameter for matrix inverses. 846 | 847 | :getter: Gets the currently set parameter. 848 | :setter: Sets a new parameter. Must be non-negative. 849 | :type: float 850 | """ 851 | return self._epsilon 852 | 853 | @epsilon.setter 854 | def epsilon(self, value: float): 855 | assert value >= 0 856 | self._epsilon = value 857 | 858 | @property 859 | def score_method(self) -> str: 860 | r""" Property which steers the scoring behavior of this estimator. 861 | 862 | :getter: Gets the current score. 863 | :setter: Sets the score to use. 864 | :type: str 865 | """ 866 | return self._score_method 867 | 868 | @score_method.setter 869 | def score_method(self, value: str): 870 | if value not in valid_score_methods: 871 | raise ValueError(f"Tried setting an unsupported scoring method '{value}', " 872 | f"available are {valid_score_methods}.") 873 | self._score_method = value 874 | 875 | # @property 876 | # def lobes(self) -> nn.Module: 877 | # r""" The instantaneous lobe of the VAMPNet. 878 | 879 | # :getter: Gets the instantaneous lobe. 880 | # :setter: Sets a new lobe. 881 | # :type: torch.nn.Module 882 | # """ 883 | # return self.lobes 884 | 885 | # @lobes.setter 886 | # def lobes(self, value: list): 887 | # assert len(value)==self.N, 'You must provide as many lobes as independent subsystems!' 
888 |     #     for n in range(self.N):
889 |     #         self.lobes[n] = value[n]
890 |     #         if self.dtype == np.float32:
891 |     #             self.lobes[n] = self.lobes[n].float()
892 |     #         else:
893 |     #             self.lobes[n] = self.lobes[n].double()
894 |     #         self.lobes[n] = self.lobes[n].to(device=self.device)
895 |     def forward(self, data):
896 | 
897 |         if data.device != self.device:  # move the batch to the model device if necessary
898 |             data = data.to(device=self.device)
899 |         masked_data = self.mask(data)
900 |         chi_data_list = []
901 |         for n in range(self.N):
902 |             lobe = self.lobes[n]
903 |             data_n = torch.squeeze(masked_data[n], dim=2)
904 |             chi_data_list.append(lobe(data_n))
905 |         return chi_data_list
906 | 
907 |     def reset_scores(self):
908 |         self._train_scores = []
909 |         self._validation_scores = []
910 |         self._train_vampe = []
911 |         self._train_pen_C00 = []
912 |         self._train_pen_C11 = []
913 |         self._train_pen_C01 = []
914 |         self._train_pen_scores = []
915 |         self._train_trace = []
916 |         self._validation_vampe = []
917 |         self._validation_pen_C00 = []
918 |         self._validation_pen_C11 = []
919 |         self._validation_pen_C01 = []
920 |         self._validation_pen_scores = []
921 |         self._validation_trace = []
922 |         self._step = 0
923 | 
924 |     def partial_fit(self, data, lam_decomp: float = 1., mask: bool = False, lam_trace: float = 0.,
925 |                     train_score_callback: Callable[[int, torch.Tensor], None] = None,
926 |                     tb_writer=None, clip=False, lam_pen_perc=None, lam_pen_C00=0., lam_pen_C11=0., lam_pen_C01=0.):
927 |         r""" Performs a partial fit on data. This does not perform any batching.
928 | 
929 |         Parameters
930 |         ----------
931 |         data : tuple or list of length 2, containing instantaneous and time-lagged data
932 |             The data to train the lobe(s) on.
933 |         lam_decomp : float
934 |             The weighting factor that controls how strongly the dependency score is weighted in the loss.
935 |         mask : bool, default=False
936 |             Whether the mask should be trained or not.
937 |         lam_trace : float
938 |             The weighting factor that controls how strongly the trace is weighted in the loss.
939 |         train_score_callback : callable, optional, default=None
940 |             An optional callback function which is evaluated after partial fit, containing the current step
941 |             of the training (only meaningful during a :meth:`fit`) and the current score as torch Tensor.
942 |         tb_writer : tensorboard writer
943 |             If given, scores will be recorded in the tensorboard log file.
944 |         clip : bool, default=False
945 |             If True, the gradients of the weights are clipped by norm before applying them for the update.
946 |         Returns
947 |         -------
948 |         self : iVAMPNet
949 |             Reference to self.
950 |         """
951 | 
952 |         if self.dtype == np.float32:
953 |             for n in range(self.N):
954 |                 self.lobes[n] = self.lobes[n].float()
955 |         elif self.dtype == np.float64:
956 |             for n in range(self.N):
957 |                 self.lobes[n] = self.lobes[n].double()
958 |         for n in range(self.N):
959 |             self.lobes[n].train()
960 |         self.mask.train()
961 | 
962 |         assert isinstance(data, (list, tuple)) and len(data) == 2, \
963 |             "Data must be a list or tuple of batches belonging to instantaneous " \
964 |             "and respective time-lagged data."
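        # Overview of the training step below:
        # 1. convert numpy batches to tensors on the model device and zero all gradients,
        # 2. pass the instantaneous and time-lagged batches through the mask and the subsystem lobes,
        # 3. score every subsystem individually (VAMP-E) and every pair of subsystems jointly,
        # 4. combine the terms into the loss, by default
        #        loss = -mean(pairwise VAMP-E) + lam_decomp * mean(score penalties) - lam_trace * mean(traces),
        #    or, if lam_pen_perc is given, with the penalty weights rescaled relative to the current VAMP-E score,
        # 5. backpropagate, optionally clip the gradients, and step the lobe (and optionally mask) optimizers.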
965 | 966 | batch_0, batch_t = data[0], data[1] 967 | 968 | if isinstance(data[0], np.ndarray): 969 | batch_0 = torch.from_numpy(data[0].astype(self.dtype)).to(device=self.device) 970 | if isinstance(data[1], np.ndarray): 971 | batch_t = torch.from_numpy(data[1].astype(self.dtype)).to(device=self.device) 972 | for opt in self.optimizer_lobes: 973 | opt.zero_grad() 974 | if mask: 975 | self.optimizer_mask.zero_grad() 976 | chi_t_list = self.forward(batch_0) # returns list of feature vectors 977 | chi_tau_list = self.forward(batch_t) 978 | # Estimate all individual scores and singular functions 979 | scores_single, S_single, u_single, v_single, trace_single = score_all_systems(chi_t_list, chi_tau_list, 980 | epsilon=self._epsilon, mode=self.score_mode) 981 | # Estimate all pairwise scores and independent penalties 982 | score_pairs, pen_C00_map, pen_C11_map, pen_C01_map = score_all_outer_systems(chi_t_list, chi_tau_list, S_single, 983 | u_single, v_single, device=self.device) 984 | # Estimate the penalty of the scores 985 | pen_scores = pen_all_scores(scores_single, score_pairs) 986 | # Take the mean over all pairs 987 | pen_scores_all = torch.mean(torch.cat(pen_scores, dim=0)) 988 | pen_C00_map_all = torch.mean(torch.cat(pen_C00_map, dim=0)) 989 | pen_C11_map_all = torch.mean(torch.cat(pen_C11_map, dim=0)) 990 | pen_C01_map_all = torch.mean(torch.cat(pen_C01_map, dim=0)) 991 | trace_all = torch.mean(torch.cat(trace_single, dim=0)) 992 | # Estimate the sum of scores, !!! Check if mean is correct 993 | vamp_sum_score = torch.mean(torch.cat(scores_single, dim=0)) 994 | vamp_score_pairs = torch.mean(torch.cat(score_pairs, dim=0)) 995 | if lam_pen_perc is not None: 996 | vamp_score_item, pen_score_item, pen_c00_item, pen_c11_item, pen_c01_item = vamp_score_pairs.item(), pen_scores_all.item(), pen_C00_map_all.item(), pen_C11_map_all.item(), pen_C01_map_all.item() 997 | fac_pen_score, fac_pen_c00, fac_pen_c11, fac_pen_c01 = lam_pen_perc * vamp_score_item/pen_score_item, lam_pen_C00 * vamp_score_item/pen_c00_item, lam_pen_C11 * vamp_score_item/pen_c11_item, lam_pen_C01 * vamp_score_item/pen_c01_item 998 | loss_value = - vamp_score_pairs - lam_trace * trace_all + fac_pen_score * pen_scores_all + fac_pen_c00 * pen_C00_map_all + fac_pen_c11 * pen_C11_map_all + fac_pen_c01 * pen_C01_map_all 999 | else: 1000 | loss_value = - vamp_score_pairs + lam_decomp * pen_scores_all - lam_trace * trace_all 1001 | loss_value.backward() 1002 | if clip: 1003 | # clip the gradients 1004 | for lobe in self.lobes: 1005 | torch.nn.utils.clip_grad_norm_(lobe.parameters(), CLIP_VALUE) 1006 | 1007 | if mask: 1008 | if clip: 1009 | torch.nn.utils.clip_grad_norm_(self.mask.parameters(), CLIP_VALUE) 1010 | self.optimizer_mask.step() 1011 | for opt in self.optimizer_lobes: 1012 | opt.step() 1013 | if train_score_callback is not None: 1014 | lval_detached = loss_value.detach() 1015 | train_score_callback(self._step, -lval_detached) 1016 | if tb_writer is not None: 1017 | tb_writer.add_scalars('Loss', {'train': loss_value.item()}, self._step) 1018 | tb_writer.add_scalars('VAMPE', {'train': vamp_score_pairs.item()}, self._step) 1019 | tb_writer.add_scalars('Pen_C00', {'train': pen_C00_map_all.item()}, self._step) 1020 | tb_writer.add_scalars('Pen_C11', {'train': pen_C11_map_all.item()}, self._step) 1021 | tb_writer.add_scalars('Pen_C01', {'train': pen_C01_map_all.item()}, self._step) 1022 | tb_writer.add_scalars('Pen_scores', {'train': pen_scores_all.item()}, self._step) 1023 | tb_writer.add_scalars('Trace_all', {'train': 
                trace_all.item()}, self._step)
1024 |         self._train_scores.append((self._step, (-loss_value).item()))
1025 |         self._train_vampe.append((self._step, (vamp_score_pairs).item()))
1026 |         self._train_pen_C00.append((self._step, (pen_C00_map_all).item()))
1027 |         self._train_pen_C11.append((self._step, (pen_C11_map_all).item()))
1028 |         self._train_pen_C01.append((self._step, (pen_C01_map_all).item()))
1029 |         self._train_pen_scores.append((self._step, (pen_scores_all).item()))
1030 |         self._train_trace.append((self._step, (trace_all).item()))
1031 |         self._step += 1
1032 | 
1033 |         return self
1034 | 
1035 |     def validate(self, validation_data: Tuple[torch.Tensor], lam_decomp: float = 1., lam_trace: float = 0.,
1036 |                  lam_pen_perc=None, lam_pen_C00=0., lam_pen_C11=0., lam_pen_C01=0.) -> Tuple[torch.Tensor, ...]:
1037 |         r""" Evaluates the currently set lobes and mask on validation data and returns the loss and its components.
1038 | 
1039 |         Parameters
1040 |         ----------
1041 |         validation_data : tuple of torch.Tensor containing instantaneous and time-lagged data
1042 |             The validation data.
1043 |         lam_decomp : float
1044 |             The weighting factor that controls how strongly the dependency score is weighted in the loss.
1045 |         lam_trace : float
1046 |             The weighting factor that controls how strongly the trace is weighted in the loss.
1047 | 
1048 |         Returns
1049 |         -------
1050 |         ret : tuple of torch.Tensor
1051 |             The loss and its components: (loss, VAMP-E score, score penalty, C00 penalty, C11 penalty, C01 penalty, trace).
1052 |         """
1053 |         for lobe in self.lobes:
1054 |             lobe.eval()
1055 |         self.mask.eval()
1056 | 
1057 |         with torch.no_grad():
1058 |             chi_t_list = self.forward(validation_data[0])
1059 |             chi_tau_list = self.forward(validation_data[1])
1060 | 
1061 |             # Estimate all individual scores and singular functions
1062 |             scores_single, S_single, u_single, v_single, trace_single = score_all_systems(chi_t_list, chi_tau_list,
1063 |                                                                         epsilon=self._epsilon, mode=self.score_mode)
1064 |             # Estimate all pairwise scores and independent penalties
1065 |             score_pairs, pen_C00_map, pen_C11_map, pen_C01_map = score_all_outer_systems(chi_t_list, chi_tau_list,
1066 |                                                                   S_single, u_single, v_single, device=self.device)
1067 |             # Estimate the penalty of the scores
1068 |             pen_scores = pen_all_scores(scores_single, score_pairs)
1069 |             # Take the mean over all pairs
1070 |             pen_scores_all = torch.mean(torch.cat(pen_scores, dim=0))
1071 |             pen_C00_map_all = torch.mean(torch.cat(pen_C00_map, dim=0))
1072 |             pen_C11_map_all = torch.mean(torch.cat(pen_C11_map, dim=0))
1073 |             pen_C01_map_all = torch.mean(torch.cat(pen_C01_map, dim=0))
1074 |             trace_all = torch.mean(torch.cat(trace_single, dim=0))
1075 |             # Estimate the sum of scores, !!!
Check if mean is correct 1076 | vamp_sum_score = torch.mean(torch.cat(scores_single, dim=0)) 1077 | vamp_score_pairs = torch.mean(torch.cat(score_pairs, dim=0)) 1078 | 1079 | if lam_pen_perc is not None: 1080 | vamp_score_item, pen_score_item, pen_c00_item, pen_c11_item, pen_c01_item = vamp_score_pairs.item(), pen_scores_all.item(), pen_C00_map_all.item(), pen_C11_map_all.item(), pen_C01_map_all.item() 1081 | fac_pen_score, fac_pen_c00, fac_pen_c11, fac_pen_c01 = lam_pen_perc * vamp_score_item/pen_score_item, lam_pen_C00 * vamp_score_item/pen_c00_item, lam_pen_C11 * vamp_score_item/pen_c11_item, lam_pen_C01 * vamp_score_item/pen_c01_item 1082 | loss_value = - vamp_score_pairs - lam_trace * trace_all + fac_pen_score * pen_scores_all + fac_pen_c00 * pen_C00_map_all + fac_pen_c11 * pen_C11_map_all + fac_pen_c01 * pen_C01_map_all 1083 | else: 1084 | loss_value = - vamp_score_pairs + lam_decomp * pen_scores_all - lam_trace * trace_all 1085 | 1086 | return loss_value, vamp_score_pairs, pen_scores_all, pen_C00_map_all, pen_C11_map_all, pen_C01_map_all, trace_all 1087 | 1088 | 1089 | def fit(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validation_loader=None, 1090 | mask=False, lam_decomp: float = 1., lam_trace: float = 0., 1091 | start_mask: int = 0, end_trace: int = 0, 1092 | train_score_callback: Callable[[int, torch.Tensor], None] = None, 1093 | validation_score_callback: Callable[[int, torch.Tensor], None] = None, 1094 | tb_writer=None, reset_step=False, clip=False, save_criteria=None, 1095 | lam_pen_perc=None, lam_pen_C00=0., lam_pen_C11=0., lam_pen_C01=0., **kwargs): 1096 | r""" Fits iVAMPnet on data. 1097 | 1098 | Parameters 1099 | ---------- 1100 | data_loader : torch.utils.data.DataLoader 1101 | The data to use for training. Should yield a tuple of batches representing 1102 | instantaneous and time-lagged samples. 1103 | n_epochs : int, default=1 1104 | The number of epochs (i.e., passes through the training data) to use for training. 1105 | validation_loader : torch.utils.data.DataLoader, optional, default=None 1106 | Validation data, should also be yielded as a two-element tuple. 1107 | mask : bool, default=False 1108 | Bool to decide if the mask should be trained or not 1109 | lam_decomp : float 1110 | The weighting factor how much the dependency score should be weighted in the loss. 1111 | lam_trace : float 1112 | The weighting factor how much the trace should be weighted in the loss. 1113 | start_mask : int, default=0 1114 | The epoch after which the mask should be trained. 1115 | end_trace : int, default=0 1116 | The epoch from which on the trace should not be included in the loss anymore. 1117 | train_score_callback : callable, optional, default=None 1118 | Callback function which is invoked after each batch and gets as arguments the current training step 1119 | as well as the score (as torch Tensor). 1120 | validation_score_callback : callable, optional, default=None 1121 | Callback function for validation data. Is invoked after each epoch if validation data is given 1122 | and the callback function is not None. Same as the train callback, this gets the 'step' as well as 1123 | the score. 1124 | tb_writer : tensorboard writer 1125 | If given, scores will be recorded in the tensorboard log file. 1126 | clip : bool default=False 1127 | If True the gradients of the weights will be clipped by norm before applying them for the update. 1128 | save_criteria : float 1129 | If the validation value of pen_C01 is lower than save_criteria the weights are saved. 
1130 | At the end of the training loop the weights will be set to the last saved weights. 1131 | **kwargs 1132 | Optional keyword arguments for scikit-learn compatibility 1133 | 1134 | Returns 1135 | ------- 1136 | self : iVAMPNet 1137 | Reference to self. 1138 | """ 1139 | if reset_step: # if all statistics should be recollected from scratch 1140 | self.reset_scores() 1141 | if save_criteria is not None: 1142 | weights_temp = self.state_dict() 1143 | # and train 1144 | train_mask=False 1145 | for epoch in range(n_epochs): 1146 | if (epoch >= start_mask) and mask: 1147 | train_mask=True 1148 | if epoch >= end_trace: 1149 | lam_trace = 0. 1150 | for batch_0, batch_t in data_loader: 1151 | self.partial_fit((batch_0, batch_t), lam_decomp=lam_decomp, mask=train_mask, 1152 | lam_trace=lam_trace, 1153 | train_score_callback=train_score_callback, tb_writer=tb_writer, 1154 | clip=clip, lam_pen_perc=lam_pen_perc, lam_pen_C00=lam_pen_C00, lam_pen_C11=lam_pen_C11, lam_pen_C01=lam_pen_C01) 1155 | 1156 | if validation_loader is not None: 1157 | with torch.no_grad(): 1158 | val_scores = [] 1159 | val_vamp_scores = [] 1160 | val_pen_scores = [] 1161 | val_pen_C00 = [] 1162 | val_pen_C11 = [] 1163 | val_pen_C01 = [] 1164 | val_trace = [] 1165 | for val_batch in validation_loader: 1166 | ret = self.validate((val_batch[0], val_batch[1]), lam_decomp=lam_decomp, lam_trace=lam_trace, lam_pen_perc=lam_pen_perc, lam_pen_C00=lam_pen_C00, lam_pen_C11=lam_pen_C11, lam_pen_C01=lam_pen_C01) 1167 | loss_value, vamp_score_pairs, pen_scores_all, pen_C00_map_all, pen_C11_map_all, pen_C01_map_all, trace_all = ret 1168 | val_scores.append(-loss_value) 1169 | val_vamp_scores.append(vamp_score_pairs) 1170 | val_pen_scores.append(pen_scores_all) 1171 | val_pen_C00.append(pen_C00_map_all) 1172 | val_pen_C11.append(pen_C11_map_all) 1173 | val_pen_C01.append(pen_C01_map_all) 1174 | val_trace.append(trace_all) 1175 | 1176 | mean_score = torch.mean(torch.stack(val_scores)) 1177 | mean_vamp_score = torch.mean(torch.stack(val_vamp_scores)) 1178 | mean_pen_score = torch.mean(torch.stack(val_pen_scores)) 1179 | mean_pen_C00 = torch.mean(torch.stack(val_pen_C00)) 1180 | mean_pen_C11 = torch.mean(torch.stack(val_pen_C11)) 1181 | mean_pen_C01 = torch.mean(torch.stack(val_pen_C01)) 1182 | mean_trace = torch.mean(torch.stack(val_trace)) 1183 | 1184 | if validation_score_callback is not None: 1185 | validation_score_callback(self._step, mean_score.detach()) 1186 | if tb_writer is not None: 1187 | tb_writer.add_scalars('Loss', {'valid': -mean_score.item()}, self._step) 1188 | tb_writer.add_scalars('VAMPE', {'valid': mean_vamp_score.item()}, self._step) 1189 | tb_writer.add_scalars('Pen_C00', {'valid': mean_pen_C00.item()}, self._step) 1190 | tb_writer.add_scalars('Pen_C11', {'valid': mean_pen_C11.item()}, self._step) 1191 | tb_writer.add_scalars('Pen_C01', {'valid': mean_pen_C01.item()}, self._step) 1192 | tb_writer.add_scalars('Pen_scores', {'valid': mean_pen_score.item()}, self._step) 1193 | tb_writer.add_scalars('Trace_all', {'valid': mean_trace.item()}, self._step) 1194 | self._validation_scores.append((self._step, (mean_score).item())) 1195 | self._validation_vampe.append((self._step, (mean_vamp_score).item())) 1196 | self._validation_pen_C00.append((self._step, (mean_pen_C00).item())) 1197 | self._validation_pen_C11.append((self._step, (mean_pen_C11).item())) 1198 | self._validation_pen_C01.append((self._step, (mean_pen_C01).item())) 1199 | self._validation_pen_scores.append((self._step, (mean_pen_score).item())) 1200 | 
                    self._validation_trace.append((self._step, (mean_trace).item()))
1201 | 
1202 |                     if save_criteria is not None:
1203 |                         if mean_pen_C01 < save_criteria:
1204 |                             # if the criterion is met, save the weights
1205 |                             weights_temp = self.state_dict()
1206 | 
1207 |         if save_criteria is not None:
1208 |             # At the end of the loop, load the weights that last fulfilled the save criterion
1209 |             self.load_state_dict(weights_temp)
1210 | 
1211 |         return self
1212 | 
1213 |     def transform(self, data, instantaneous: bool = True, **kwargs):
1214 |         r""" Transforms data through the subsystem lobes of the current model.
1215 | 
1216 |         Parameters
1217 |         ----------
1218 |         data : numpy array or torch tensor
1219 |             The data to transform.
1220 |         instantaneous : bool, default=True
1221 |             Ignored; the same lobes are used for instantaneous and time-lagged data. Kept for API compatibility.
1222 |         **kwargs
1223 |             Ignored kwargs for api compatibility.
1224 | 
1225 |         Returns
1226 |         -------
1227 |         transform : list
1228 |             List of numpy arrays or torch Tensors containing the transformed data of each subsystem.
1229 |         """
1230 |         model = self.fetch_model()
1231 |         return model.transform(data, **kwargs)
1232 | 
1233 |     def fetch_model(self) -> iVAMPnetModel:
1234 |         r""" Yields the current model. """
1235 |         return iVAMPnetModel(self.lobes, self.mask, dtype=self.dtype, device=self.device, epsilon=self.epsilon,
1236 |                              mode=self.score_mode)
1237 | 
1238 |     def state_dict(self):
1239 |         ''' Returns the state_dict of all lobes and the mask.
1240 | 
1241 |         Returns
1242 |         -------
1243 |         ret: list of state_dicts
1244 |         '''
1245 |         dicts_lobe = []
1246 |         for lobe in self.lobes:
1247 |             dicts_lobe.append(lobe.state_dict())
1248 |         mask_dict = self.mask.state_dict()
1249 |         ret = [dicts_lobe, mask_dict]
1250 |         return ret
1251 | 
1252 | 
1253 |     def load_state_dict(self, state_dict):
1254 |         ''' Loads the provided state_dict into the estimator. Useful to load a saved training instance.
1255 | 
1256 |         Parameters
1257 |         ----------
1258 |         state_dict: list
1259 |             Should be of the form returned by self.state_dict: a list containing the list of all lobe state_dicts and the state_dict of the mask.
1260 |         '''
1261 | 
1262 |         dict_lobes, mask_dict = state_dict
1263 |         for n in range(self.N):
1264 |             self.lobes[n].load_state_dict(dict_lobes[n])
1265 |         self.mask.load_state_dict(mask_dict)
1266 |         return
1267 | 
1268 |     def save_params(self, path: str):
1269 |         ''' Saves the state_dicts at the specified path.
1270 | 
1271 |         Parameters
1272 |         ----------
1273 |         path: str
1274 |             The path where the state_dict should be saved.
1275 |         '''
1276 |         dicts_lobe, mask_dict = self.state_dict()
1277 |         savez_dict = dict()
1278 |         for n in range(self.N):
1279 |             savez_dict['lobe_'+str(n)] = dicts_lobe[n]
1280 |         savez_dict['mask_dict'] = mask_dict
1281 | 
1282 |         np.savez(path, **savez_dict)
1283 | 
1284 |         print('Saved parameters at: ' + path)
1285 | 
1286 |     def load_params(self, path: str):
1287 |         ''' Loads the state_dicts from the specified path.
1288 | 
1289 |         Parameters
1290 |         ----------
1291 |         path: str
1292 |             The path where the state_dict should be loaded from.
1293 |         '''
1294 |         dicts = np.load(path, allow_pickle=True)
1295 | 
1296 |         dict_lobes = []
1297 |         mask_dict = dicts['mask_dict'].item()
1298 |         for n in range(self.N):
1299 |             dict_lobes.append(dicts['lobe_'+str(n)].item())
1300 |         state_dict = [dict_lobes, mask_dict]
1301 |         self.load_state_dict(state_dict)
1302 | 
1303 |         return
--------------------------------------------------------------------------------
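Usage sketch: the following minimal example shows how the pieces defined in ivampnets.py fit together, namely a list of subsystem lobes, a mask module, the `iVAMPnet` estimator and the fitted `iVAMPnetModel`. It runs on random placeholder data with a hypothetical `IdentityMask` stand-in (in practice one of the trainable mask classes from masks.py should be passed), so the resulting numbers are meaningless and all hyperparameters are illustrative only.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ivampnets import iVAMPnet

# toy setup: 10 input features, 2 subsystems with 3 states each, lag time of 5 frames
n_feat, n_sub, n_states, tau = 10, 2, 3, 5
traj = np.random.randn(10000, n_feat).astype(np.float32)   # placeholder; use featurized MD data here
data_0, data_t = traj[:-tau], traj[tau:]                    # instantaneous / time-lagged frame pairs

loader = DataLoader(TensorDataset(torch.from_numpy(data_0), torch.from_numpy(data_t)),
                    batch_size=1000, shuffle=True)

# one softmax-output lobe per subsystem (deeptime.util.torch.MLP is an alternative)
lobes = [nn.Sequential(nn.Linear(n_feat, 32), nn.ELU(),
                       nn.Linear(32, n_states), nn.Softmax(dim=1)) for _ in range(n_sub)]

class IdentityMask(nn.Module):
    """Hypothetical stand-in for the trainable masks in masks.py: every subsystem sees all features."""
    def __init__(self, n_subsystems):
        super().__init__()
        self.n_subsystems = n_subsystems
        self.weight = nn.Parameter(torch.zeros(1))  # dummy parameter so the mask optimizer is not empty
    def forward(self, x):
        # return one tensor of shape (batch, n_feat, 1) per subsystem, as expected by the lobes
        return [x.unsqueeze(2) for _ in range(self.n_subsystems)]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lobes = [lobe.to(device) for lobe in lobes]
mask = IdentityMask(n_sub).to(device)

estimator = iVAMPnet(lobes, mask, device=device, learning_rate=5e-4, epsilon=1e-6)
estimator.fit(loader, n_epochs=10, lam_decomp=1., lam_trace=0., mask=False)

model = estimator.fetch_model()
its = model.timescales(data_0, data_t, tau=tau)        # implied timescales, one array per subsystem
T_list = model.get_transition_matrix(data_0, data_t)   # transition matrices, one per subsystem
estimator.save_params('ivampnet_params.npz')           # can later be restored with load_params
```

See the notebooks for the full workflows, including training of the masks.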