├── AlaDi
│   └── deep_msm_ala_heavyatoms_withoutKoop_filter_states.ipynb
├── Prinz
│   ├── deep_ed_0.py
│   └── deep_ml_0.py
└── README.md
--------------------------------------------------------------------------------
/Prinz/deep_ed_0.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable, grad, backward
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import torch.utils.data as Data
from math import pi, inf, log
import copy

from pyemma.plots import scatter_contour
from pyemma.msm import MSM, markov_model
from scipy import linalg
from approximate_diffusion_models import OneDimensionalModel

# NOTE: this script targets the pre-0.4 PyTorch API (Variable, tensor.data[0]).

all_trajs = np.load('data/traj.npy')
all_trajs_val = np.load('data/traj_val.npy')

beta = 1.

# Prinz-type four-well potential on [-1, 1]
def potential_function(x):
    return 4*(x**8 + 0.8*np.exp(-80*x*x) + 0.2*np.exp(-80*(x-0.5)**2) + 0.5*np.exp(-40*(x+0.5)**2))

lb = -1.
ub = 1.
grid_num = 100
delta_t = 0.01
diffusion_model = OneDimensionalModel(potential_function, beta, lb, ub, grid_num, delta_t)

tau = 5

class EarlyStopping:
    def __init__(self, p=0):
        self.patience = p
        self.j = 0
        self.v = inf
        self.other_parameters = None

    def reset(self):
        self.j = 0
        self.v = inf
        self.other_parameters = None

    def read_validation_result(self, model, validation_cost, other_parameters=None):
        # Returns True when validation has not improved for `patience` checks.
        if validation_cost < self.v:
            self.j = 0
            self.v = validation_cost
            self.model = copy.deepcopy(model)
            self.other_parameters = other_parameters
        else:
            self.j += 1
            if self.j >= self.patience:
                return True
        return False

    def get_best_model(self):
        return copy.deepcopy(self.model)

    def get_best_other_parameters(self):
        return self.other_parameters

class Net_P(nn.Module):
    def __init__(self, input_dim, state_num, net_width=64, n_hidden_layer=4):
        super(Net_P, self).__init__()
        self.input_dim = input_dim
        self.state_num = state_num
        self.net_width = net_width
        self.n_hidden_layer = n_hidden_layer

        self.hidden_layer_list = nn.ModuleList([nn.Linear(input_dim, net_width)] + [nn.Linear(net_width, net_width) for i in range(n_hidden_layer-1)])
        self.output_layer = nn.Linear(net_width, state_num)
        self.bn_input = nn.BatchNorm1d(input_dim)
        self.bn_hidden_list = nn.ModuleList([nn.BatchNorm1d(net_width) for i in range(n_hidden_layer)])

    def forward(self, x):
        x = self.bn_input(x)
        for i in range(self.n_hidden_layer):
            x = self.hidden_layer_list[i](x)
            x = self.bn_hidden_list[i](x)
            x = F.relu(x)
        x = self.output_layer(x)
        return x

class Net_G(nn.Module):
    def __init__(self, data_dim, state_num, noise_dim, eps=0, net_width=64, n_hidden_layer=4):
        super(Net_G, self).__init__()
        self.data_dim = data_dim
        self.noise_dim = noise_dim
        self.state_num = state_num
        self.net_width = net_width
        self.n_hidden_layer = n_hidden_layer
        self.eps = eps

        self.hidden_layer_list = nn.ModuleList([nn.Linear(state_num+noise_dim, net_width)] + [nn.Linear(net_width, net_width) for i in range(n_hidden_layer-1)])
        self.output_layer = nn.Linear(net_width, data_dim)
        self.bn_input = nn.BatchNorm1d(state_num+noise_dim)
        self.bn_hidden_list = nn.ModuleList([nn.BatchNorm1d(net_width) for i in range(n_hidden_layer)])

    def forward(self, x):
        x = self.bn_input(x)
        for i in range(self.n_hidden_layer):
            x = self.hidden_layer_list[i](x)
            x = self.bn_hidden_list[i](x)
            x = F.relu(x)
        x = self.output_layer(x)
        return x

state_num = 4
noise_dim = 4
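# Added note: P is the probabilistic encoder, mapping a configuration x to soft
# memberships over the `state_num` metastable states; G is the generator, mapping
# a one-hot state plus `noise_dim` Gaussian noise coordinates back to configuration
# space, i.e. it samples from the learned conditional distribution of y given the
# state. The loops below draw a hard one-hot state from the soft memberships M by
# inverse-CDF sampling with one uniform draw u per frame: the sampled state index
# equals the number of entries of cumsum(M) that fall below u, written compactly as
#     torch.eye(state_num)[(state_num - torch.sum(torch.cumsum(M, dim=1) >= u, 1)).long()]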
partition_mem = np.empty([3, diffusion_model.center_list.shape[0], state_num])
K_0_mem = np.empty([3, state_num, state_num])
its_0_mem = np.empty([3, 3])
transition_density_0_mem = np.empty([3, diffusion_model.center_list.shape[0], diffusion_model.center_list.shape[0]])
stationary_density_0_mem = np.empty([3, diffusion_model.center_list.shape[0]])

for kk in range(3):
    traj = all_trajs[kk]
    traj_val = all_trajs_val[kk]

    P = Net_P(1, state_num)
    G = Net_G(1, state_num, noise_dim)

    P.train()
    G.train()

    batch_size = 100
    LR = 1e-3  # learning rate

    X_mem = torch.from_numpy(traj[:-tau]).float()
    Y_mem = torch.from_numpy(traj[tau:]).float()
    X_val = Variable(torch.from_numpy(traj_val[:-tau]).float())
    Y_val = Variable(torch.from_numpy(traj_val[tau:]).float())
    data_size = X_mem.shape[0]
    data_size_val = traj_val.shape[0]-tau

    # Phase 1: train encoder P and generator G jointly.
    opt_P = torch.optim.Adam(P.parameters(), lr=LR)
    opt_G = torch.optim.Adam(G.parameters(), lr=LR)
    stopper = EarlyStopping(5)
    for epoch in range(200):
        idx_mem_0 = torch.randperm(data_size)
        idx = 0
        while True:
            actual_batch_size = min(batch_size, data_size-idx)
            if actual_batch_size <= 0:
                break
            X = Variable(X_mem[idx_mem_0[idx:idx+actual_batch_size]])
            Y = Variable(Y_mem[idx_mem_0[idx:idx+actual_batch_size]])
            idx += actual_batch_size
            O = P(X)
            M = F.softmax(O, dim=1)
            # two independent one-hot state samples per frame
            B0 = torch.eye(state_num)[(state_num-torch.sum(torch.cumsum(M, dim=1) >= Variable(torch.rand(actual_batch_size)).unsqueeze(1), 1).long()).data].float()
            B1 = torch.eye(state_num)[(state_num-torch.sum(torch.cumsum(M, dim=1) >= Variable(torch.rand(actual_batch_size)).unsqueeze(1), 1).long()).data].float()
            R0 = Variable(torch.cat((B0, torch.randn(actual_batch_size, noise_dim)), 1))
            R1 = Variable(torch.cat((B1, torch.randn(actual_batch_size, noise_dim)), 1))
            Y0 = G(R0)
            Y1 = G(R1)
            D = torch.abs(Y0-Y)+torch.abs(Y1-Y)-torch.abs(Y0-Y1)

            G_loss = torch.mean(D)
            opt_G.zero_grad()
            G_loss.backward()
            opt_G.step()

            # score-function (REINFORCE-style) gradient for P through the discrete samples
            opt_P.zero_grad()
            O.backward((B0+B1-2*M.data)*D.data/(actual_batch_size+0.))
            opt_P.step()

        # validation loss for early stopping
        P.eval()
        G.eval()
        O = P(X_val)
        M = F.softmax(O, dim=1)
        B0 = torch.eye(state_num)[(state_num-torch.sum(torch.cumsum(M, dim=1) >= Variable(torch.rand(data_size_val)).unsqueeze(1), 1).long()).data].float()
        B1 = torch.eye(state_num)[(state_num-torch.sum(torch.cumsum(M, dim=1) >= Variable(torch.rand(data_size_val)).unsqueeze(1), 1).long()).data].float()
        R0 = Variable(torch.cat((B0, torch.randn(data_size_val, noise_dim)), 1))
        R1 = Variable(torch.cat((B1, torch.randn(data_size_val, noise_dim)), 1))
        Y0 = G(R0)
        Y1 = G(R1)
        D = torch.abs(Y0-Y_val)+torch.abs(Y1-Y_val)-torch.abs(Y0-Y1)
        loss_val = (torch.mean(D)).data[0]
        P.train()
        G.train()
        print(epoch, loss_val)
        if stopper.read_validation_result([P, G], loss_val):
            break

    P, G = stopper.get_best_model()
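    # Added note: mean(D) above is, up to the parameter-independent constant
    # E|Y - Y'| over pairs of data samples, a sample estimate of the conditional
    # energy distance between generated and observed transitions,
    #     D_E = 2*E|Y0 - Y| - E|Y0 - Y1| - E|Y - Y'|,
    # so minimizing it matches the generator's transition density to the data.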
    # Phase 2: freeze the encoder P and fine-tune the generator G alone.
    LR = 1e-5
    P.eval()
    opt_G = torch.optim.Adam(G.parameters(), lr=LR)
    stopper = EarlyStopping(5)
    stopper.read_validation_result(G, loss_val)  # seed with the current validation loss
    M_mem = F.softmax(P(Variable(X_mem)), dim=1).data
    M_val = F.softmax(P(X_val), dim=1)
    for epoch in range(200):
        idx_mem_0 = torch.randperm(data_size)
        idx = 0
        while True:
            actual_batch_size = min(batch_size, data_size-idx)
            if actual_batch_size <= 0:
                break
            M = Variable(M_mem[idx_mem_0[idx:idx+actual_batch_size]])
            Y = Variable(Y_mem[idx_mem_0[idx:idx+actual_batch_size]])
            idx += actual_batch_size
            B0 = torch.eye(state_num)[(state_num-torch.sum(torch.cumsum(M, dim=1) >= Variable(torch.rand(actual_batch_size)).unsqueeze(1), 1).long()).data].float()
            B1 = torch.eye(state_num)[(state_num-torch.sum(torch.cumsum(M, dim=1) >= Variable(torch.rand(actual_batch_size)).unsqueeze(1), 1).long()).data].float()
            R0 = Variable(torch.cat((B0, torch.randn(actual_batch_size, noise_dim)), 1))
            R1 = Variable(torch.cat((B1, torch.randn(actual_batch_size, noise_dim)), 1))
            Y0 = G(R0)
            Y1 = G(R1)
            D = torch.abs(Y0-Y)+torch.abs(Y1-Y)-torch.abs(Y0-Y1)

            G_loss = torch.mean(D)
            opt_G.zero_grad()
            G_loss.backward()
            opt_G.step()

        G.eval()
        B0 = torch.eye(state_num)[(state_num-torch.sum(torch.cumsum(M_val, dim=1) >= Variable(torch.rand(data_size_val)).unsqueeze(1), 1).long()).data].float()
        B1 = torch.eye(state_num)[(state_num-torch.sum(torch.cumsum(M_val, dim=1) >= Variable(torch.rand(data_size_val)).unsqueeze(1), 1).long()).data].float()
        R0 = Variable(torch.cat((B0, torch.randn(data_size_val, noise_dim)), 1))
        R1 = Variable(torch.cat((B1, torch.randn(data_size_val, noise_dim)), 1))
        Y0 = G(R0)
        Y1 = G(R1)
        D = torch.abs(Y0-Y_val)+torch.abs(Y1-Y_val)-torch.abs(Y0-Y1)
        loss_val = (torch.mean(D)).data[0]
        G.train()
        print(epoch, loss_val)
        if stopper.read_validation_result(G, loss_val):
            break

    G = stopper.get_best_model()

    torch.save(P.state_dict(), 'data/ed/P_params_traj_'+str(kk)+'_tau_'+str(tau)+'.pkl')
    torch.save(G.state_dict(), 'data/ed/G_params_traj_'+str(kk)+'_tau_'+str(tau)+'.pkl')

    P.load_state_dict(torch.load('data/ed/P_params_traj_'+str(kk)+'_tau_'+str(tau)+'.pkl'))
    G.load_state_dict(torch.load('data/ed/G_params_traj_'+str(kk)+'_tau_'+str(tau)+'.pkl'))

    P.eval()
    G.eval()

    # membership functions on the grid
    xx = Variable(torch.from_numpy(diffusion_model.center_list.reshape(-1, 1)).float())
    pp = (F.softmax(P(xx), 1)).data.numpy()
    partition_mem[kk] = pp

    # draw samples from the generator, conditioned on each state
    TEST_BATCH_SIZE = 10000
    sample_mem = np.empty([TEST_BATCH_SIZE, state_num])
    for idx in range(state_num):
        B = torch.zeros([TEST_BATCH_SIZE, state_num])
        B[:, idx] = 1
        R = Variable(torch.cat((B, torch.randn(TEST_BATCH_SIZE, noise_dim)), 1))
        sample_mem[:, idx] = G(R).data.numpy().reshape(-1)

    # transition matrix: push generated samples back through the encoder
    K = np.empty([state_num, state_num])
    for idx in range(state_num):
        GR = Variable(torch.from_numpy(sample_mem[:, idx].reshape(-1, 1)).float())
        K[idx, :] = torch.mean(F.softmax(P(GR), 1), 0).data.numpy()
    K = K/K.sum(1)[:, np.newaxis]
    K_0_mem[kk] = K  # fix: K_0_mem was allocated and saved below but never filled

    its = -tau*delta_t/np.log(sorted(np.absolute(np.linalg.eigvals(K)), key=lambda x: np.absolute(x), reverse=True)[1:4])
    its_0_mem[kk] = its

    print(its)
    print(diffusion_model.its[1:4])
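    # Added note: the implied timescales above follow from the MSM eigenvalue
    # relation lambda_i(tau) = exp(-tau*delta_t/t_i), i.e.
    #     t_i = -tau*delta_t / ln|lambda_i(K)|,
    # and are compared against the reference timescales of the discretized
    # diffusion model (diffusion_model.its).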
    # per-state histograms of generated samples -> densities on the grid
    hist_mem = np.empty([diffusion_model.center_list.shape[0], state_num])
    for i in range(state_num):
        hist_mem[:, i] = np.histogram(sample_mem[:, i], bins=grid_num, range=(lb, ub), density=True)[0]
        hist_mem[:, i] /= hist_mem[:, i].sum()

    transition_density = pp.dot(hist_mem.T)
    model = markov_model(K)
    stationary_density = model.stationary_distribution.dot(hist_mem.T)

    transition_density_0_mem[kk] = transition_density
    stationary_density_0_mem[kk] = stationary_density

np.save('data/ed/partition_mem', partition_mem)
np.save('data/ed/K_0_mem', K_0_mem)
np.save('data/ed/its_0_mem', its_0_mem)
np.save('data/ed/transition_density_0_mem', transition_density_0_mem)
np.save('data/ed/stationary_density_0_mem', stationary_density_0_mem)

for kk in range(3):
    plt.figure()
    plt.plot(partition_mem[kk])
    plt.figure()
    plt.plot(stationary_density_0_mem[kk])
    plt.figure()
    plt.contourf(transition_density_0_mem[kk])
--------------------------------------------------------------------------------
/Prinz/deep_ml_0.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable, grad, backward
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import torch.utils.data as Data
from math import pi, inf, log
import copy

from pyemma.plots import scatter_contour
from pyemma.msm import MSM, markov_model
from scipy import linalg
from approximate_diffusion_models import OneDimensionalModel

all_trajs = np.load('data/traj.npy')
all_trajs_val = np.load('data/traj_val.npy')

beta = 1.

# Prinz-type four-well potential on [-1, 1]
def potential_function(x):
    return 4*(x**8 + 0.8*np.exp(-80*x*x) + 0.2*np.exp(-80*(x-0.5)**2) + 0.5*np.exp(-40*(x+0.5)**2))

lb = -1.
ub = 1.
grid_num = 100
delta_t = 0.01
diffusion_model = OneDimensionalModel(potential_function, beta, lb, ub, grid_num, delta_t)

tau = 5

def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation

    value.exp().sum(dim, keepdim).log()
    """
    # TODO: torch.max(value, dim=None) threw an error at time of writing
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
        value0 = value - m
        if keepdim is False:
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0),
                                       dim=dim, keepdim=keepdim))
    else:
        m = torch.max(value)
        sum_exp = torch.sum(torch.exp(value - m))
        return m + torch.log(sum_exp)
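# Added note: subtracting the per-slice maximum m before exponentiating is the
# standard log-sum-exp trick,
#     log(sum_i exp(v_i)) = m + log(sum_i exp(v_i - m)),  m = max_i v_i,
# which avoids overflow when the log-domain values v_i are large in magnitude.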
class EarlyStopping:
    def __init__(self, p=0):
        self.patience = p
        self.j = 0
        self.v = inf
        self.other_parameters = None

    def reset(self):
        self.j = 0
        self.v = inf
        self.other_parameters = None

    def read_validation_result(self, model, validation_cost, other_parameters=None):
        # Returns True when validation has not improved for `patience` checks.
        if validation_cost < self.v:
            self.j = 0
            self.v = validation_cost
            self.model = copy.deepcopy(model)
            self.other_parameters = other_parameters
        else:
            self.j += 1
            if self.j >= self.patience:
                return True
        return False

    def get_best_model(self):
        return copy.deepcopy(self.model)

    def get_best_other_parameters(self):
        return self.other_parameters

class Net_P(nn.Module):
    def __init__(self, input_dim, state_num, net_width=64, n_hidden_layer=4):
        super(Net_P, self).__init__()
        self.input_dim = input_dim
        self.state_num = state_num
        self.net_width = net_width
        self.n_hidden_layer = n_hidden_layer

        self.hidden_layer_list = nn.ModuleList([nn.Linear(input_dim, net_width)] + [nn.Linear(net_width, net_width) for i in range(n_hidden_layer-1)])
        self.output_layer = nn.Linear(net_width, state_num)
        self.bn_input = nn.BatchNorm1d(input_dim)
        self.bn_hidden_list = nn.ModuleList([nn.BatchNorm1d(net_width) for i in range(n_hidden_layer)])
        self.bn_output = nn.BatchNorm1d(state_num)

    def forward(self, x):
        # returns log-memberships over the metastable states
        x = self.bn_input(x)
        for i in range(self.n_hidden_layer):
            x = self.hidden_layer_list[i](x)
            x = self.bn_hidden_list[i](x)
            x = F.relu(x)
        x = self.output_layer(x)
        x = self.bn_output(x)
        x = F.log_softmax(x, dim=1)
        return x

class Net_G(nn.Module):
    def __init__(self, input_dim, state_num, eps=0, net_width=64, n_hidden_layer=4):
        super(Net_G, self).__init__()
        self.input_dim = input_dim
        self.state_num = state_num
        self.net_width = net_width
        self.n_hidden_layer = n_hidden_layer
        self.eps = eps

        self.hidden_layer_list = nn.ModuleList([nn.Linear(input_dim, net_width)] + [nn.Linear(net_width, net_width) for i in range(n_hidden_layer-1)])
        self.output_layer = nn.Linear(net_width, state_num)
        self.bn_input = nn.BatchNorm1d(input_dim)
        self.bn_hidden_list = nn.ModuleList([nn.BatchNorm1d(net_width) for i in range(n_hidden_layer)])

    def forward(self, x):
        x = self.bn_input(x)
        for i in range(self.n_hidden_layer):
            x = self.hidden_layer_list[i](x)
            x = self.bn_hidden_list[i](x)
            x = F.relu(x)
        x = self.output_layer(x)
        return x

state_num = 4

partition_mem = np.empty([3, diffusion_model.center_list.shape[0], state_num])
K_0_mem = np.empty([3, state_num, state_num])
its_0_mem = np.empty([3, 3])
transition_density_0_mem = np.empty([3, diffusion_model.center_list.shape[0], diffusion_model.center_list.shape[0]])
stationary_density_0_mem = np.empty([3, diffusion_model.center_list.shape[0]])

for kk in range(3):
    traj = all_trajs[kk]
    traj_val = all_trajs_val[kk]

    P = Net_P(1, state_num)
    G = Net_G(1, state_num)

    P.train()
    G.train()

    batch_size = 100
    LR = 1e-3  # learning rate

    X_mem = torch.from_numpy(traj[:-tau]).float()
    Y_mem = torch.from_numpy(traj[tau:]).float()
    X_val = Variable(torch.from_numpy(traj_val[:-tau]).float())
    Y_val = Variable(torch.from_numpy(traj_val[tau:]).float())
    data_size = X_mem.shape[0]
    data_size_val = traj_val.shape[0]-tau
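    # Added note: the (commented-out) training block below maximizes the
    # likelihood of observed transitions under the mixture model
    #     p(y | x) = sum_i chi_i(x) * gamma_i(y),
    # where chi = exp(P(x)) are the state memberships and gamma_i is the density
    # modeled by G, normalized so that each column of exp(G(y)) has batch mean 1;
    # the per-frame log-likelihood is ll = log_sum_exp(log_Chi + log_Gamma, 1).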
    '''
    # Phase 1: train encoder P and generator G jointly by maximum likelihood.
    opt = torch.optim.Adam(list(P.parameters())+list(G.parameters()), lr=LR)
    stopper = EarlyStopping(5)
    for epoch in range(200):
        idx_mem_0 = torch.randperm(data_size)
        idx = 0
        while True:
            actual_batch_size = min(batch_size, data_size-idx)
            if actual_batch_size <= 0:
                break
            X_0 = Variable(X_mem[idx_mem_0[idx:idx+actual_batch_size]])
            Y_0 = Variable(Y_mem[idx_mem_0[idx:idx+actual_batch_size]])
            idx += actual_batch_size
            log_Chi_0 = P(X_0)
            log_Gamma_0 = G(Y_0)
            log_Gamma_0 = log_Gamma_0-log_sum_exp(log_Gamma_0, 0)+log(actual_batch_size+0.)
            ll = log_sum_exp(log_Chi_0+log_Gamma_0, 1)
            loss = -torch.mean(ll)
            opt.zero_grad()
            backward(loss)
            opt.step()

        P.eval()
        G.eval()
        log_Chi_val = P(X_val)
        log_Gamma_val = G(Y_val)
        log_Gamma_val = log_Gamma_val-log_sum_exp(log_Gamma_val, 0)+log(data_size_val+0.)
        ll = log_sum_exp(log_Chi_val+log_Gamma_val, 1)
        loss_val = -torch.sum(ll).data[0]
        print(epoch, loss_val)
        P.train()
        G.train()
        if stopper.read_validation_result([P, G], loss_val):
            break

    P, G = stopper.get_best_model()

    # Phase 2: freeze the encoder P and fine-tune the generator G alone.
    LR = 1e-5
    opt = torch.optim.Adam(list(G.parameters()), lr=LR)
    stopper = EarlyStopping(5)
    stopper.read_validation_result(G, loss_val)
    P.eval()
    G.train()
    log_Chi = P(Variable(X_mem)).data
    log_Chi_val = P(X_val)
    for epoch in range(200):
        idx_mem_0 = torch.randperm(data_size)
        idx = 0
        print(epoch)
        while True:
            actual_batch_size = min(batch_size, data_size-idx)
            if actual_batch_size <= 0:
                break
            Y_0 = Variable(Y_mem[idx_mem_0[idx:idx+actual_batch_size]])
            log_Chi_0 = Variable(log_Chi[idx_mem_0[idx:idx+actual_batch_size]])
            idx += actual_batch_size
            log_Gamma_0 = G(Y_0)
            log_Gamma_0 = log_Gamma_0-log_sum_exp(log_Gamma_0, 0)+log(actual_batch_size+0.)
            ll = log_sum_exp(log_Chi_0+log_Gamma_0, 1)
            loss = -torch.mean(ll)
            opt.zero_grad()
            backward(loss)
            opt.step()

        G.eval()
        log_Gamma_val = G(Y_val)
        log_Gamma_val = log_Gamma_val-log_sum_exp(log_Gamma_val, 0)+log(data_size_val+0.)
        ll = log_sum_exp(log_Chi_val+log_Gamma_val, 1)
        loss_val = -torch.sum(ll).data[0]
        G.train()
        print(epoch, loss_val)
        if stopper.read_validation_result(G, loss_val):
            break
    G = stopper.get_best_model()

    torch.save(P.state_dict(), 'data/ml/P_params_traj_'+str(kk)+'_tau_'+str(tau)+'.pkl')
    torch.save(G.state_dict(), 'data/ml/G_params_traj_'+str(kk)+'_tau_'+str(tau)+'.pkl')
    '''
    # Load the previously trained networks (the training block above is disabled).
    P.load_state_dict(torch.load('data/ml/P_params_traj_'+str(kk)+'_tau_'+str(tau)+'.pkl'))
    G.load_state_dict(torch.load('data/ml/G_params_traj_'+str(kk)+'_tau_'+str(tau)+'.pkl'))

    P.eval()
    G.eval()

    # membership functions on the grid
    xx = Variable(torch.from_numpy(diffusion_model.center_list.reshape(-1, 1)).float())
    pp = (torch.exp(P(xx))).data.numpy()
    partition_mem[kk] = pp

    # transition matrix K = Gamma^T Chi_1 / N, row-normalized
    Chi_1 = torch.exp(P(Variable(Y_mem)))
    log_Gamma = G(Variable(Y_mem))
    Gamma = torch.exp(log_Gamma-log_sum_exp(log_Gamma))
    Gamma = Gamma/torch.mean(Gamma, 0)
    K = torch.mm(torch.t(Gamma), Chi_1).data.numpy()/data_size
    K = K/K.sum(1)[:, np.newaxis]
    K_0_mem[kk] = K
    its = -tau*delta_t/np.log(sorted(np.absolute(np.linalg.eigvals(K)), key=lambda x: np.absolute(x), reverse=True)[1:4])
    its_0_mem[kk] = its

    print(its)
    print(diffusion_model.its[1:4])

    # per-state densities: histogram of the data, weighted by Gamma
    hist_mem = np.empty([diffusion_model.center_list.shape[0], state_num])
    for i in range(state_num):
        hist_mem[:, i] = np.histogram(traj[tau:].reshape(-1), bins=grid_num, range=(lb, ub), density=True, weights=Gamma[:, i].data.numpy().reshape(-1))[0]
        hist_mem[:, i] /= hist_mem[:, i].sum()

    transition_density = pp.dot(hist_mem.T)
    model = markov_model(K)
    stationary_density = model.stationary_distribution.dot(hist_mem.T)

    transition_density_0_mem[kk] = transition_density
    stationary_density_0_mem[kk] = stationary_density

np.save('data/ml/partition_mem', partition_mem)
np.save('data/ml/K_0_mem', K_0_mem)
np.save('data/ml/its_0_mem', its_0_mem)
np.save('data/ml/transition_density_0_mem', transition_density_0_mem)
np.save('data/ml/stationary_density_0_mem', stationary_density_0_mem)
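# Added usage note: the saved arrays can be inspected later without rerunning
# the analysis, e.g.
#     its = np.load('data/ml/its_0_mem.npy')   # implied timescales, shape (3, 3)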
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepGenMSM

Here we collect the code necessary to reproduce the results of our project "Deep generative Markov State models for dynamical systems".

The folder `Prinz` contains the code for the 1D four-well Prinz potential, using PyTorch.

The folder `AlaDi` contains the code for the alanine dipeptide molecule and the generation of new structures, using TensorFlow.
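For reference, the reduced potential used in the Prinz example, as defined in `Prinz/deep_ed_0.py` and `Prinz/deep_ml_0.py`, is:

```python
import numpy as np

# four-well potential on [-1, 1] (beta = 1, dt = 0.01, lag time tau = 5 steps)
def potential_function(x):
    return 4*(x**8 + 0.8*np.exp(-80*x*x) + 0.2*np.exp(-80*(x-0.5)**2) + 0.5*np.exp(-40*(x+0.5)**2))
```

`deep_ed_0.py` fits the networks with an energy-distance-style loss, while `deep_ml_0.py` uses a maximum-likelihood objective; both write trained parameters and analysis arrays under `data/ed/` and `data/ml/` (relative to the `Prinz` folder).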