├── Artificial_Life_example.py ├── Calcium_Imaging_example.py ├── DeepDMBD ├── DynamicMarkovBlanketDiscovery.py ├── Flame_example.py ├── Flocking_example.py ├── LICENSE.md ├── Life_as_we_know_it_test.py ├── Lorenz_example.py ├── NewtonsCradle_example.py ├── Presentations ├── Dynamic Markov Blanket Discovery.pptx └── temp.txt ├── README.md ├── __init__.py ├── cradle.mp4 ├── cradle2.mp4 ├── data ├── flame_even_smaller.pt ├── lorenz_data.pt └── rotor.pt ├── dmbd_demo.ipynb ├── flame_pc_scores.png ├── models ├── ARHMM.py ├── BayesNet.py ├── BayesianFactorAnalysis.py ├── BayesianTransformer.py ├── BlockFactorAnalysis.py ├── DynamicMarkovBlanketDiscovery.py ├── GaussianMixtureModel.py ├── HMM.py ├── IsotropicGaussianMixtureModel.py ├── LDS.py ├── LDS_px.py ├── MVN_ard.py ├── MixLDS.py ├── MixtureofLinearTransforms.py ├── MixtureofMatrixNormalGammas.py ├── MultiNomialLogisticRegression.py ├── NLRegression.py ├── NLRegression_Multinomial.py ├── NormalSparse.py ├── PoissonMixtureModel.py ├── ReducedRankRegression.py ├── __init__.py ├── dHMM.py ├── dMixture.py ├── dMixtureofLinearTransforms.py ├── dists │ ├── ConjugatePrior.py │ ├── Delta.py │ ├── DiagonalWishart.py │ ├── Dirichlet.py │ ├── Gamma.py │ ├── LDS_unused.py │ ├── MVN_ard.py │ ├── MatrixNormalGamma.py │ ├── MatrixNormalWishart.py │ ├── Mixture.py │ ├── MultivariateNormal.py │ ├── MultivariateNormal_vector_format.py │ ├── NormalGamma.py │ ├── NormalInverseWishart.py │ ├── TensorNormalWishart.py │ ├── Wishart.py │ ├── __init__.py │ └── utils │ │ ├── __init__.py │ │ └── matrix_utils.py ├── rHMM.py └── vbem_test.py ├── simulations ├── Forager.py ├── Lorenz.py ├── NewtonsCradle.py ├── __init__ .py ├── cartthingy.py ├── flame.py └── forager_temp.py ├── test_bayes_net.py ├── test_dmbd.py └── vbem_test.py /Artificial_Life_example.py: -------------------------------------------------------------------------------- 1 | 2 | from models.DynamicMarkovBlanketDiscovery import * 3 | from matplotlib import pyplot as plt 4 | from matplotlib.animation import FuncAnimation, FFMpegWriter 5 | from matplotlib import cm 6 | start_time=time.time() 7 | 8 | print('Test on Artificial Life Data') 9 | print('Loading data....') 10 | y=np.genfromtxt('./data/rotor_story_y.txt') 11 | x=np.genfromtxt('./data/rotor_story_x.txt') 12 | print('....Done.') 13 | y=torch.tensor(y,requires_grad=False).float() 14 | x=torch.tensor(x,requires_grad=False).float() 15 | y=y.unsqueeze(-1) 16 | x=x.unsqueeze(-1) 17 | 18 | T = 100 19 | data = torch.cat((y,x),dim=-1) 20 | data = data[::9] 21 | v_data = torch.diff(data,dim=0) 22 | v_data = v_data/v_data.std() 23 | data = data[1:] 24 | data = data/data.std() 25 | 26 | data = torch.cat((data,v_data),dim=-1) 27 | del v_data 28 | del x 29 | del y 30 | T = data.shape[0] 31 | T = T//2 32 | data = data[:T] 33 | data = data.unsqueeze(1).clone().detach() 34 | 35 | # print('Initializing V model....') 36 | # v_model = DMBD(obs_shape=v_data.shape[-2:],role_dims=(16,16,16),hidden_dims=(5,5,5)) 37 | # print('Updating model V....') 38 | # v_model.update(v_data,None,None,iters=100,latent_iters=1,lr=0.25) 39 | # v_model.update(v_data,None,None,iters=10,latent_iters=1,lr=1) 40 | # print('Making Movie') 41 | # f = r"c://Users/brain/Desktop/rotator_movie_v.mp4" 42 | # ar = animate_results('sbz',f) 43 | # ar.make_movie(v_model, data, list(range(10))) 44 | # len_v_data = v_data 45 | # len_v_model = v_model 46 | 47 | print('Initializing X + V model....') 48 | model = DMBD(obs_shape=data.shape[-2:],role_dims=(0,1,0),hidden_dims=(12,4,0),regression_dim = 0, 
control_dim = 0, number_of_objects=11) 49 | 50 | print('Updating model X+V....') 51 | model.update(data,None,None,iters=40,latent_iters=1,lr=0.5,verbose=True) 52 | #model.px = None 53 | #model.update(data,None,None,iters=50,latent_iters=1,lr=1.0,verbose=True) 54 | #model.update(data,None,None,iters=10,latent_iters=1,lr=1) 55 | print('Making Movie') 56 | f = r"./rotator_movie.mp4" 57 | ar = animate_results('particular',f).make_movie(model, data, (0,)) 58 | 59 | sbz=model.px.mean() 60 | B = model.obs_model.obs_dist.mean() 61 | if model.regression_dim==1: 62 | roles = B[...,:-1]@sbz + B[...,-1:] 63 | else: 64 | roles = B@sbz 65 | sbz = sbz.squeeze(-3).squeeze(-1) 66 | roles = roles.squeeze(-1)[...,0:2] 67 | 68 | batch_num = 0 69 | temp1 = data[:,batch_num,:,0] 70 | temp2 = data[:,batch_num,:,1] 71 | rtemp1 = roles[:,batch_num,:,0] 72 | rtemp2 = roles[:,batch_num,:,1] 73 | 74 | idx = (model.assignment()[:,batch_num,:]==0) 75 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.5) 76 | ev_dim = model.role_dims[0] 77 | ob_dim = np.sum(model.role_dims[1:]) 78 | 79 | for i in range(ev_dim): 80 | idx = (model.obs_model.assignment()[:,batch_num,:]==i) 81 | plt.scatter(rtemp1[:,i],rtemp2[:,i]) 82 | plt.title('Environment + Roles') 83 | plt.show() 84 | 85 | ctemp = model.role_dims[1]*('b',) + model.role_dims[2]*('r',) 86 | 87 | for j in range(model.number_of_objects): 88 | idx = (model.assignment()[:,batch_num,:]==0) 89 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.2) 90 | for i in range(1+2*j,1+2*(j+1)): 91 | idx = (model.assignment()[:,batch_num,:]==i) 92 | plt.scatter(temp1[idx],temp2[idx]) 93 | plt.title('Object '+str(j+1) + ' (yellow is environment)') 94 | plt.show() 95 | 96 | idx = (model.assignment()[:,batch_num,:]==0) 97 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.2) 98 | k=0 99 | for i in range(ev_dim+ob_dim*j,ev_dim+ob_dim*(j+1)): 100 | idx = (model.obs_model.assignment()[:,batch_num,:]==i) 101 | plt.scatter(rtemp1[:,i],rtemp2[:,i],color=ctemp[k]) 102 | k=k+1 103 | plt.title('Object '+str(j+1) + ' roles') 104 | plt.show() 105 | 106 | 107 | len_data = data 108 | len_model = model 109 | 110 | run_time = time.time()-start_time 111 | print('Total Run Time: ',run_time) 112 | 113 | 114 | 115 | # # make frame by frame movie using particular designations 116 | # assignments = model.particular_assignment()/model.number_of_objects 117 | # confidence = model.assignment_pr().max(-1)[0] 118 | 119 | # fig = plt.figure(figsize=(7,7)) 120 | # ax = plt.axes(xlim=(-2.5,2.5),ylim=(-2.5,2.5)) 121 | # scatter=ax.scatter([], [], cmap = cm.rainbow, c=[], vmin=0.0, vmax=1.0) 122 | 123 | # T = data.shape[0] 124 | # fn = 0 125 | # scatter.set_offsets(data[fn%T, fn//T,:,:].numpy()) 126 | # scatter.set_array(assignments[fn%T, fn//T,:].numpy()) 127 | # scatter.set_alpha(confidence[fn%T, fn//T,:].numpy()) 128 | 129 | # plt.plot(model.ELBO_save) 130 | -------------------------------------------------------------------------------- /Calcium_Imaging_example.py: -------------------------------------------------------------------------------- 1 | 2 | from models.DynamicMarkovBlanketDiscovery import * 3 | from matplotlib import pyplot as plt 4 | from matplotlib.animation import FuncAnimation, FFMpegWriter 5 | from matplotlib import cm 6 | 7 | 8 | print('Test on Calcium Imaging data') 9 | 10 | data = torch.tensor(np.load('data\calciumForJeff.npy')).float().unsqueeze(-1) 11 | data = data/data.std() 12 | v_data = data.diff(dim=0,n=1) 13 | v_data = v_data/v_data.std() 14 | data = 
torch.cat((data[1:],v_data),dim=-1) 15 | data = data[:3600] 16 | data = data.reshape(12,300,41,2).swapaxes(0,1).clone().detach() 17 | model = DMBD(obs_shape=data.shape[-2:],role_dims=(1,1,0),hidden_dims=(4,2,0),batch_shape=(),regression_dim = -1, control_dim=0,number_of_objects=5) 18 | model.update(data,None,None,iters=50,lr=0.5,verbose=True) 19 | 20 | batch_num = 0 21 | t = torch.arange(0,data.shape[0]).view((data.shape[0],)+(1,)*(data.ndim-1)).expand(data.shape) 22 | plt.scatter(t[:,batch_num,:,0],data[:,batch_num,:,0],c=model.particular_assignment()[:,batch_num,:]) 23 | plt.show() 24 | dbar = torch.zeros(data.shape[0:2]+(model.number_of_objects+1,),requires_grad=False) 25 | ass = model.particular_assignment() 26 | for n in range(model.number_of_objects+1): 27 | temp = (data*(ass==n).unsqueeze(-1)).sum(-2)[...,0] 28 | temp = temp/temp.std() 29 | temp.unsqueeze(-1) 30 | dbar[:,:,n]=temp.clone().detach() 31 | 32 | 33 | -------------------------------------------------------------------------------- /Flame_example.py: -------------------------------------------------------------------------------- 1 | 2 | print('Test on Flame data set') 3 | 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from models.DynamicMarkovBlanketDiscovery import * 8 | 9 | data = torch.load('./data/flame_even_smaller.pt').clone().detach() 10 | 11 | model = DMBD(obs_shape=data.shape[-2:],role_dims=(3,3,3),hidden_dims=(4,4,4),batch_shape=(),regression_dim = -1, control_dim=0,number_of_objects=1) 12 | 13 | from matplotlib.colors import ListedColormap, Normalize 14 | cmap = ListedColormap(['red', 'green', 'blue']) 15 | vmin = 0 # Minimum value of the color scale 16 | vmax = 2 # Maximum value of the color scale 17 | norm = Normalize(vmin=vmin, vmax=vmax) 18 | 19 | for i in range(10): 20 | model.update(data,None,None,iters=2,latent_iters=1,lr=0.5) 21 | 22 | sbz=model.px.mean().squeeze() 23 | r1 = model.role_dims[0] 24 | r2 = r1+model.role_dims[1] 25 | r3 = r2+ model.role_dims[2] 26 | h1 = model.hidden_dims[0] 27 | h2 = h1+model.hidden_dims[1] 28 | h3 = h2+ model.hidden_dims[2] 29 | 30 | 31 | p = model.assignment_pr() 32 | a = 2-model.assignment() 33 | plt.imshow(a[:,0,:].transpose(-2,-1),cmap=cmap, norm=norm, origin='lower') 34 | plt.xlabel('Time') 35 | plt.ylabel('Location') 36 | plt.savefig('flame_assignments.png') 37 | 38 | p = p.sum(-2) 39 | print('Show PC scores') 40 | s = sbz[:,:,0:h1] 41 | s = s-s.mean(0).mean(0) 42 | b = sbz[:,:,h1:h2] 43 | b = b-b.mean(0).mean(0) 44 | z = sbz[:,:,h2:h3] 45 | z = z-z.mean(0).mean(0) 46 | 47 | cs = (s.unsqueeze(-1)*s.unsqueeze(-2)).mean(0).mean(0) 48 | cb = (b.unsqueeze(-1)*b.unsqueeze(-2)).mean(0).mean(0) 49 | cz = (z.unsqueeze(-1)*z.unsqueeze(-2)).mean(0).mean(0) 50 | 51 | d,v=torch.linalg.eigh(cs) 52 | ss = v.transpose(-2,-1)@s.unsqueeze(-1) 53 | d,v=torch.linalg.eigh(cb) 54 | bb = v.transpose(-2,-1)@b.unsqueeze(-1) 55 | d,v=torch.linalg.eigh(cz) 56 | zz = v.transpose(-2,-1)@z.unsqueeze(-1) 57 | 58 | ss = ss.squeeze(-1)[...,-2:] 59 | bb = bb.squeeze(-1)[...,-2:] 60 | zz = zz.squeeze(-1)[...,-2:] 61 | 62 | ss = ss/ss.std() 63 | bb = bb/bb.std() 64 | zz = zz/zz.std() 65 | 66 | batch_num = 0 67 | fig, axs = plt.subplots(2, 1, sharex=True) 68 | 69 | axs[0].plot(zz[:,batch_num,-1:],'r',label='s') 70 | axs[0].plot(bb[:,batch_num,-1:],'g',label='b') 71 | axs[0].plot(ss[:,batch_num,-1:],'b',label='z') 72 | axs[0].set_title('Top PC Scores') 73 | # handles, labels = axs[0].get_legend_handles_labels() 74 | # selected_handles = [handles[0], handles[2], 
handles[4]] 75 | # selected_labels = [labels[0], labels[2], labels[4]] 76 | # axs[0].legend(selected_handles, selected_labels) 77 | axs[0].legend() 78 | 79 | axs[1].plot(p[:,batch_num,2],'r') 80 | axs[1].plot(p[:,batch_num,1],'g') 81 | axs[1].plot(p[:,batch_num,0],'b') 82 | axs[1].set_title('Number of Assigned Nodes') 83 | axs[1].set_xlabel('Time') 84 | plt.savefig('flame_pc_scores.png') 85 | plt.show() 86 | 87 | 88 | -------------------------------------------------------------------------------- /Flocking_example.py: -------------------------------------------------------------------------------- 1 | 2 | from models.DynamicMarkovBlanketDiscovery import * 3 | from matplotlib import pyplot as plt 4 | from matplotlib.animation import FuncAnimation, FFMpegWriter 5 | from matplotlib import cm 6 | start_time=time.time() 7 | 8 | # print("Test on Flocking Data") 9 | # with np.load("data\couzin2zone_sim_hist_key1_100runs.npz") as data: 10 | # r = data["r"] 11 | # v = r[:,1:]-r[:,:-1] 12 | # r = r[:,:-1] 13 | # r = torch.tensor(r).float().swapaxes(0,1) 14 | # v = torch.tensor(v).float().swapaxes(0,1) 15 | 16 | # data = torch.cat((r,v),dim=-1) 17 | 18 | def smoothe(data,n): 19 | temp = data[0:-n] 20 | for i in range(1,n): 21 | temp = temp+data[i:-(n-i)] 22 | return temp[::n]/n 23 | # data = smoothe(data,10) 24 | # data = data/data.std((0,1,2),True) 25 | # torch.save(data,'./data/flocking.pt') 26 | 27 | # print("Preprocessing Complete") 28 | 29 | data = torch.load('./data/flocking.pt') 30 | data = smoothe(data,4) 31 | data = data[:100,:20] 32 | 33 | #data = data[...,2:4] 34 | model = DMBD(obs_shape=data.shape[-2:],role_dims=(1,2,2),hidden_dims=(4,4,4),regression_dim = -1, control_dim = 0, number_of_objects=6, unique_obs=False) 35 | 36 | #model.A.mu[...,-1]=torch.randn(model.A.mu[...,-1].shape) 37 | model.update(data,None,None,iters=20,latent_iters=1,lr=1,verbose=True) 38 | 39 | #model.update(data_v[:,0:4],None,None,iters=2,latent_iters=4,lr=0.001,verbose=True) 40 | sbz=model.px.mean() 41 | B = model.obs_model.obs_dist.mean() 42 | if model.regression_dim==1: 43 | roles = B[...,:-1]@sbz + B[...,-1:] 44 | else: 45 | roles = B@sbz 46 | sbz = sbz.squeeze(-3).squeeze(-1) 47 | roles = roles.squeeze(-1)[...,0:2] 48 | 49 | batch_num = 1 50 | temp1 = data[:,batch_num,:,0] 51 | temp2 = data[:,batch_num,:,1] 52 | rtemp1 = roles[:,batch_num,:,0] 53 | rtemp2 = roles[:,batch_num,:,1] 54 | 55 | idx = (model.assignment()[:,batch_num,:]==0) 56 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.5) 57 | ev_dim = model.role_dims[0] 58 | ob_dim = np.sum(model.role_dims[1:]) 59 | 60 | for i in range(ev_dim): 61 | idx = (model.obs_model.assignment()[:,batch_num,:]==i) 62 | plt.scatter(rtemp1[:,i],rtemp2[:,i]) 63 | plt.title('Environment + Roles') 64 | plt.show() 65 | 66 | ctemp = model.role_dims[1]*('b',) + model.role_dims[2]*('r',) 67 | 68 | for j in range(model.number_of_objects): 69 | idx = (model.assignment()[:,batch_num,:]==0) 70 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.2) 71 | for i in range(1+2*j,1+2*(j+1)): 72 | idx = (model.assignment()[:,batch_num,:]==i) 73 | plt.scatter(temp1[idx],temp2[idx]) 74 | plt.title('Object '+str(j+1) + ' (yellow is environment)') 75 | plt.show() 76 | 77 | idx = (model.assignment()[:,batch_num,:]==0) 78 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.2) 79 | k=0 80 | for i in range(ev_dim+ob_dim*j,ev_dim+ob_dim*(j+1)): 81 | idx = (model.obs_model.assignment()[:,batch_num,:]==i) 82 | plt.scatter(rtemp1[:,i],rtemp2[:,i],color=ctemp[k]) 83 | k=k+1 84 | 
plt.title('Object '+str(j+1) + ' roles') 85 | plt.show() 86 | 87 | print('Making Movie') 88 | f = r"flock.mp4" 89 | ar = animate_results('particular',f, xlim = (-1,2), ylim = (-1,3), fps=20).make_movie(model, data, (0,1,2,3,4,5,6,7,8,9)) 90 | print('Done') 91 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining 2 | a copy of this software and associated documentation files (the 3 | "Software"), to deal in the Software without restriction, including 4 | without limitation the rights to use, copy, modify, merge, publish, 5 | distribute, sublicense, and/or sell copies of the Software, and to 6 | permit persons to whom the Software is furnished to do so, subject to 7 | the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be 10 | included in all copies or substantial portions of the Software. 11 | 12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 13 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 14 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 16 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 17 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 18 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /Life_as_we_know_it_test.py: -------------------------------------------------------------------------------- 1 | 2 | from models.DynamicMarkovBlanketDiscovery import * 3 | from matplotlib import pyplot as plt 4 | from matplotlib.animation import FuncAnimation, FFMpegWriter 5 | from matplotlib import cm 6 | start_time=time.time() 7 | 8 | 9 | print('Test on life as we know it data set') 10 | print('Loading Data...') 11 | y=np.genfromtxt('./data/ly.txt') 12 | x=np.genfromtxt('./data/lx.txt') 13 | print('Done.') 14 | y=torch.tensor(y,requires_grad=False).float().transpose(-2,-1) 15 | x=torch.tensor(x,requires_grad=False).float().transpose(-2,-1) 16 | y=y.unsqueeze(-1) 17 | x=x.unsqueeze(-1) 18 | data = torch.cat((x,y),dim=-1) 19 | data = data/data.std() 20 | data = data[847:].clone().detach() 21 | v_data = torch.diff(data,dim=0) 22 | v_data = v_data/v_data.std() 23 | data = data[1:] 24 | data = torch.cat((data,v_data),dim=-1) 25 | del x 26 | del y 27 | 28 | #data = data.reshape(12,100,128,2).transpose(0,1) 29 | #v_data = v_data.reshape(12,100,128,2).transpose(0,1) 30 | data = data.reshape(6,200,128,4).transpose(0,1) 31 | 32 | print('Initializing X + V model....') 33 | model = DMBD(obs_shape=data.shape[-2:],role_dims=(0,1,1),hidden_dims=(12,4,4),regression_dim = 0, control_dim=0,number_of_objects=6) 34 | print('Updating model X+V....') 35 | model.update(data,None,None,iters=40,latent_iters=1,lr=0.5,verbose=True) 36 | 37 | print('Making Movie') 38 | #f = r"c://Users/brain/OneDrive/Desktop/wil.mp4" 39 | f = r"wil.mp4" 40 | animate_results('particular',f).make_movie(model, data, list(range(data.shape[1]))) 41 | sbz=model.px.mean() 42 | B = model.obs_model.obs_dist.mean() 43 | if model.regression_dim==1: 44 | roles = B[...,:-1]@sbz + B[...,-1:] 45 | else: 46 | roles = B@sbz 47 | sbz = sbz.squeeze() 48 | roles = roles.squeeze()[...,0:2] 49 | 50 | batch_num = 0 51 | temp1 = data[:,batch_num,:,0] 52 | temp2 
= data[:,batch_num,:,1] 53 | rtemp1 = roles[:,batch_num,:,0] 54 | rtemp2 = roles[:,batch_num,:,1] 55 | 56 | idx = (model.assignment()[:,batch_num,:]==0) 57 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.5) 58 | ev_dim = model.role_dims[0] 59 | ob_dim = np.sum(model.role_dims[1:]) 60 | 61 | for i in range(ev_dim): 62 | idx = (model.obs_model.assignment()[:,batch_num,:]==i) 63 | plt.scatter(rtemp1[:,i],rtemp2[:,i]) 64 | plt.title('Environment + Roles') 65 | plt.show() 66 | 67 | ctemp = model.role_dims[1]*('b',) + model.role_dims[2]*('r',) 68 | 69 | for j in range(model.number_of_objects): 70 | idx = (model.assignment()[:,batch_num,:]==0) 71 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.2) 72 | for i in range(1+2*j,1+2*(j+1)): 73 | idx = (model.assignment()[:,batch_num,:]==i) 74 | plt.scatter(temp1[idx],temp2[idx]) 75 | plt.title('Object '+str(j+1) + ' (yellow is environment)') 76 | plt.show() 77 | 78 | idx = (model.assignment()[:,batch_num,:]==0) 79 | plt.scatter(temp1[idx],temp2[idx],color='y',alpha=0.2) 80 | k=0 81 | for i in range(ev_dim+ob_dim*j,ev_dim+ob_dim*(j+1)): 82 | idx = (model.obs_model.assignment()[:,batch_num,:]==i) 83 | plt.scatter(rtemp1[:,i],rtemp2[:,i],color=ctemp[k]) 84 | k=k+1 85 | plt.title('Object '+str(j+1) + ' roles') 86 | plt.show() 87 | 88 | 89 | run_time = time.time()-start_time 90 | print('Total Run Time: ',run_time) 91 | 92 | 93 | 94 | # # make frame by frame movie using particular designations 95 | # assignments = model.particular_assignment()/model.number_of_objects 96 | # confidence = model.assignment_pr().max(-1)[0] 97 | 98 | # fig = plt.figure(figsize=(7,7)) 99 | # ax = plt.axes(xlim=(-2.5,2.5),ylim=(-2.5,2.5)) 100 | # scatter=ax.scatter([], [], cmap = cm.rainbow, c=[], vmin=0.0, vmax=1.0) 101 | 102 | # T = data.shape[0] 103 | # fn = 0 104 | # scatter.set_offsets(data[fn%T, fn//T,:,:].numpy()) 105 | # scatter.set_array(assignments[fn%T, fn//T,:].numpy()) 106 | # scatter.set_alpha(confidence[fn%T, fn//T,:].numpy()) 107 | # plt.savefig('lenia0.png') 108 | 109 | -------------------------------------------------------------------------------- /Lorenz_example.py: -------------------------------------------------------------------------------- 1 | 2 | from models.DynamicMarkovBlanketDiscovery import * 3 | from matplotlib import pyplot as plt 4 | from matplotlib.animation import FuncAnimation, FFMpegWriter 5 | from matplotlib import cm 6 | 7 | ############################################################################### 8 | print('Test on Lorenz attractor') 9 | from simulations import Lorenz 10 | from matplotlib import pyplot as plt 11 | from matplotlib.colors import ListedColormap, Normalize 12 | cmap = ListedColormap(['red', 'green', 'blue']) 13 | vmin = 0 # Minimum value of the color scale 14 | vmax = 2 # Maximum value of the color scale 15 | norm = Normalize(vmin=vmin, vmax=vmax) 16 | 17 | sim = Lorenz.Lorenz() 18 | data = sim.simulate(100) 19 | 20 | data = torch.cat((data[...,0,:],data[...,1,:],data[...,2,:]),dim=-1).unsqueeze(-2) 21 | data = data - data.mean((0,1,2),True) 22 | data = data/data.std() 23 | 24 | model = DMBD(obs_shape=data.shape[-2:],role_dims=(4,4,4),hidden_dims=(3,3,3),batch_shape=(),regression_dim = 0, control_dim=0,number_of_objects=1) 25 | model.obs_model.ptemp = 6.0 26 | #model.update(data,None,None,iters=1,lr=1) 27 | iters = 10 28 | loc1 = torch.tensor((-0.5,-0.6,1.6)) 29 | loc2 = torch.tensor((0.5,0.6,1.6)) 30 | for i in range(iters): 31 | model.update(data,None,None,iters=2,latent_iters=1,lr=0.5,verbose=True) 32 | 33 | 
34 | sbz=model.px.mean().squeeze() 35 | r1 = model.role_dims[0] 36 | r2 = r1+model.role_dims[1] 37 | r3 = r2+ model.role_dims[2] 38 | h1 = model.hidden_dims[0] 39 | h2 = h1+model.hidden_dims[1] 40 | h3 = h2+ model.hidden_dims[2] 41 | 42 | cmap = ListedColormap(['blue', 'green', 'red']) 43 | 44 | p = model.assignment_pr() 45 | a = model.assignment() 46 | batch_num = 0 47 | fig = plt.figure() 48 | ax = fig.add_subplot(111, projection='3d') 49 | ax.scatter(data[:,batch_num,0,0],data[:,batch_num,0,2],data[:,batch_num,0,4],cmap=cmap,norm=norm,c=a[:,batch_num,0]) 50 | ax.xticklabels = [] 51 | ax.yticklabels = [] 52 | ax.zticklabels = [] 53 | ax.xlable = 'x' 54 | ax.ylable = 'y' 55 | ax.zlable = 'z' 56 | plt.savefig('lorenz3d.png') 57 | plt.show() 58 | 59 | fig = plt.figure() 60 | ax = fig.add_subplot(111) 61 | 62 | # plt.scatter(data[:,batch_num,:,0],data[:,batch_num,:,2],c=a[:,batch_num,:],cmap=cmap,norm=norm) 63 | # plt.show() 64 | ax.scatter(data[:,batch_num,:,0],data[:,batch_num,:,4],c=a[:,batch_num,:],cmap=cmap,norm=norm) 65 | ax.xticklabels = [] 66 | ax.yticklabels = [] 67 | ax.xlabel = 'x' 68 | ax.ylabel = 'z' 69 | plt.savefig('lorenz2d.png') 70 | plt.show() 71 | # plt.scatter(data[:,batch_num,:,2],data[:,batch_num,:,4],c=a[:,batch_num,:],cmap=cmap,norm=norm) 72 | # plt.show() 73 | 74 | d1 = (data[...,0::2] - loc1).pow(2).sum(-1).sqrt() 75 | d2 = (data[...,0::2] - loc2).pow(2).sum(-1).sqrt() 76 | 77 | plt.scatter(d1[:,batch_num],d2[:,batch_num],c=a[:,batch_num,:],cmap=cmap,norm=norm) 78 | plt.show() 79 | 80 | p = p.sum(-2) 81 | print('Show PC scores') 82 | s = sbz[:,:,0:h1] 83 | s = s-s.mean(0).mean(0) 84 | b = sbz[:,:,h1:h2] 85 | b = b-b.mean(0).mean(0) 86 | z = sbz[:,:,h2:h3] 87 | z = z-z.mean(0).mean(0) 88 | 89 | cs = (s.unsqueeze(-1)*s.unsqueeze(-2)).mean(0).mean(0) 90 | cb = (b.unsqueeze(-1)*b.unsqueeze(-2)).mean(0).mean(0) 91 | cz = (z.unsqueeze(-1)*z.unsqueeze(-2)).mean(0).mean(0) 92 | 93 | d,v=torch.linalg.eigh(cs) 94 | ss = v.transpose(-2,-1)@s.unsqueeze(-1) 95 | d,v=torch.linalg.eigh(cb) 96 | bb = v.transpose(-2,-1)@b.unsqueeze(-1) 97 | d,v=torch.linalg.eigh(cz) 98 | zz = v.transpose(-2,-1)@z.unsqueeze(-1) 99 | 100 | ss = ss.squeeze(-1)[...,-2:] 101 | bb = bb.squeeze(-1)[...,-2:] 102 | zz = zz.squeeze(-1)[...,-2:] 103 | 104 | ss = ss/ss.std() 105 | bb = bb/bb.std() 106 | zz = zz/zz.std() 107 | 108 | batch_num = 0 109 | fig, axs = plt.subplots(2, 1, sharex=True) 110 | 111 | axs[0].plot(zz[:,batch_num,-1:],'r',label='s') 112 | axs[0].plot(bb[:,batch_num,-1:],'g',label='b') 113 | axs[0].plot(ss[:,batch_num,-1:],'b',label='z') 114 | axs[0].set_title('Top PC Score') 115 | # handles, labels = axs[0].get_legend_handles_labels() 116 | # selected_handles = [handles[0], handles[2], handles[4]] 117 | # selected_labels = [labels[0], labels[2], labels[4]] 118 | # axs[0].legend(selected_handles, selected_labels) 119 | axs[0].legend() 120 | 121 | axs[1].plot(p[:,batch_num,2],'r') 122 | axs[1].plot(p[:,batch_num,1],'g') 123 | axs[1].plot(p[:,batch_num,0],'b') 124 | axs[1].set_title('Number of Assigned Nodes') 125 | axs[1].set_xlabel('Time') 126 | plt.savefig('lorenz_pc_scores.png') 127 | plt.show() 128 | 129 | 130 | -------------------------------------------------------------------------------- /NewtonsCradle_example.py: -------------------------------------------------------------------------------- 1 | 2 | from models.DynamicMarkovBlanketDiscovery import * 3 | from matplotlib import pyplot as plt 4 | from matplotlib.animation import FuncAnimation, FFMpegWriter 5 | from matplotlib import cm 6 
| from simulations.NewtonsCradle import NewtonsCradle 7 | 8 | dmodel = NewtonsCradle(n_balls=5,ball_size=0.2,Tmax=500,batch_size=40,g=1,leak=0.05/8,dt=0.05,include_string=False) 9 | 10 | data_temp = dmodel.generate_data('random')[0] 11 | data_0 = data_temp[0::5] 12 | data_temp = dmodel.generate_data('1 ball object')[0] 13 | data_1 = data_temp[0::5] 14 | data_temp = dmodel.generate_data('2 ball object')[0] 15 | data_2 = data_temp[0::5] 16 | # data_temp = dmodel.generate_data('3 ball object')[0] 17 | # data_3 = data_temp[0::5] 18 | # data_temp = dmodel.generate_data('4 ball object')[0] 19 | # data_4 = data_temp[0::5] 20 | data_temp = dmodel.generate_data('1 + 1 ball object')[0] 21 | data_11 = data_temp[0::5] 22 | data_temp = dmodel.generate_data('2 + 2 ball object')[0] 23 | data_22 = data_temp[0::5] 24 | data_temp = dmodel.generate_data('1 + 2 ball object')[0] 25 | data_12 = data_temp[0::5] 26 | # data_temp = dmodel.generate_data('2 + 3 ball object')[0] 27 | # data_23 = data_temp[0::5] 28 | # data_temp = dmodel.generate_data('1 + 3 ball object')[0] 29 | # data_13 = data_temp[0::5] 30 | 31 | datas = (data_0,data_1,data_2,data_11,data_12,data_22)#,data_3,data_4,data_11,data_12,data_13,data_22,data_23) 32 | dy = torch.zeros(2) 33 | delta = 0.5 34 | xlim = (-1.5,1.5) 35 | ylim = (-1.2+delta,0.2+delta) 36 | dy[1] = delta 37 | new_datas = () 38 | for k,data in enumerate(datas): 39 | data = data + dy 40 | v_data = torch.diff(data,dim=0) 41 | v_data = v_data/v_data.std() 42 | new_datas = new_datas + (torch.cat((data[1:],v_data),dim=-1),) 43 | datas = new_datas 44 | 45 | 46 | # num_mixtures = 5 47 | # batch_shape = () 48 | # hidden_dims = (4,4,4) 49 | # role_dims = (2,2,2) 50 | # iters = 40 51 | # model0 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 52 | # model1 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 53 | # model2 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 54 | # model3 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 55 | # model4 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 56 | # model11 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 57 | # model12 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 58 | # model13 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 59 | # model22 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 60 | # model23 = DMBD(obs_shape=(5,4),role_dims=role_dims,hidden_dims=hidden_dims,batch_shape=batch_shape,regression_dim = -1, control_dim=-1) 61 | 62 | # models = [] 63 | # data = torch.cat(datas[0:3],dim=1) 64 | # for i in range(10): 65 | # model = DMBD(obs_shape=data.shape[-2:],role_dims=(8,4,8),hidden_dims=(4,2,4),batch_shape=(),regression_dim = -1, control_dim=0) 66 | # models.append(model) 67 | 68 | # ELBO = [] 69 | # for k, model in enumerate(models): 70 | # model.update(data,None,None,iters=20,latent_iters=1,lr=0.5) 71 | # ELBO.append(model.ELBO()) 
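# (sketch) assuming each entry appended to ELBO above is a scalar tensor, the usual
# variational model-selection step would keep the highest-bound run, e.g.:
# best_model = models[int(torch.stack(ELBO).argmax())]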
72 | 73 | data = torch.cat(datas[3:5],dim=1).clone().detach() 74 | data = torch.cat((datas[0],data),dim=1).clone().detach() 75 | print('simulations complete') 76 | 77 | model = DMBD(obs_shape=data.shape[-2:],role_dims=(8,8,8),hidden_dims=(4,4,4),batch_shape=(),regression_dim = -1, control_dim=0) 78 | # model0 = DMBD(obs_shape=data.shape[-2:],role_dims=(20,0,0),hidden_dims=(10,0,0),batch_shape=(),regression_dim = -1, control_dim=0) 79 | # model0.update(data,None,None,iters=2*iters,latent_iters=1,lr=0.5,verbose=True) 80 | # f = r"./cradle0.mp4" 81 | # ar = animate_results('role',f, xlim = (-1.5,1.5), ylim = (-0.2,1.2), fps=10) 82 | # ar.make_movie(model0, data, (0,40,60,80))#,120))#,60,61,80,81)) 83 | iters = 80 84 | r1 = model.role_dims[0] 85 | r2 = r1+model.role_dims[1] 86 | r3 = r2+ model.role_dims[2] 87 | h1 = model.hidden_dims[0] 88 | h2 = h1+model.hidden_dims[1] 89 | h3 = h2+ model.hidden_dims[2] 90 | batch_num = 50 91 | 92 | for i in range(iters): 93 | model.update(data,None,None,iters=1,latent_iters=1,lr=0.5,verbose=True) 94 | 95 | # batch_num = torch.randint(0,data.shape[1],(1,)).item() 96 | sbz=model.px.mean() 97 | B = model.obs_model.obs_dist.mean() 98 | if model.regression_dim==0: 99 | roles = B@sbz 100 | else: 101 | roles = B[...,:-1]@sbz + B[...,-1:] 102 | sbz = sbz.squeeze() 103 | roles = roles.squeeze() 104 | idx = model.obs_model.NA/model.obs_model.NA.sum()>0.01 105 | 106 | r1 = model.role_dims[0] 107 | r2 = r1+model.role_dims[1] 108 | r3 = r2+ model.role_dims[2] 109 | 110 | pbar = model.obs_model.NA/model.obs_model.NA.sum() 111 | pbar = pbar/pbar.max() 112 | p1=model.obs_model.p[:,batch_num,:,list(range(0,r1))].mean(-2) 113 | p2=model.obs_model.p[:,batch_num,:,list(range(r1,r2))].mean(-2) 114 | p3=model.obs_model.p[:,batch_num,:,list(range(r2,r3))].mean(-2) 115 | 116 | plt.scatter(roles[:,batch_num,list(range(0,r1)),0],roles[:,batch_num,list(range(0,r1)),1],color='r',alpha=0.25) 117 | plt.scatter(roles[:,batch_num,list(range(r1,r2)),0],roles[:,batch_num,list(range(r1,r2)),1],color='g',alpha=0.25) 118 | plt.scatter(roles[:,batch_num,list(range(r2,r3)),0],roles[:,batch_num,list(range(r2,r3)),1],color='b',alpha=0.25) 119 | plt.xlim(xlim) 120 | plt.ylim(ylim) 121 | plt.show() 122 | # plt.plot(roles[:,batch_num,list(range(0,r1)),2],roles[:,batch_num,list(range(0,r1)),3],color='b') 123 | # plt.plot(roles[:,batch_num,list(range(r1,r2)),2],roles[:,batch_num,list(range(r1,r2)),3],color='g') 124 | # plt.plot(roles[:,batch_num,list(range(r2,r3)),2],roles[:,batch_num,list(range(r2,r3)),3],color='r') 125 | # plt.show() 126 | 127 | p = model.assignment_pr() 128 | p = p.sum(-2) 129 | print('Show PC scores') 130 | s = sbz[:,:,0:h1] 131 | s = s-s.mean(0).mean(0) 132 | b = sbz[:,:,h1:h2] 133 | b = b-b.mean(0).mean(0) 134 | z = sbz[:,:,h2:h3] 135 | z = z-z.mean(0).mean(0) 136 | 137 | cs = (s.unsqueeze(-1)*s.unsqueeze(-2)).mean(0).mean(0) 138 | cb = (b.unsqueeze(-1)*b.unsqueeze(-2)).mean(0).mean(0) 139 | cz = (z.unsqueeze(-1)*z.unsqueeze(-2)).mean(0).mean(0) 140 | 141 | d,v=torch.linalg.eigh(cs) 142 | ss = v.transpose(-2,-1)@s.unsqueeze(-1) 143 | print('Normalized Eigenvalues of s',d/d.sum()) 144 | d,v=torch.linalg.eigh(cb) 145 | print('Normalized Eigenvalues of b',d/d.sum()) 146 | bb = v.transpose(-2,-1)@b.unsqueeze(-1) 147 | d,v=torch.linalg.eigh(cz) 148 | print('Normalized Eigenvalues of z',d/d.sum()) 149 | zz = v.transpose(-2,-1)@z.unsqueeze(-1) 150 | 151 | ss = ss.squeeze(-1)[...,-2:] 152 | bb = bb.squeeze(-1)[...,-2:] 153 | zz = zz.squeeze(-1)[...,-2:] 154 | 155 | ss = ss/ss.std() 156 | 
bb = bb/bb.std() 157 | zz = zz/zz.std() 158 | 159 | fig, axs = plt.subplots(2, 1, sharex=True) 160 | 161 | axs[0].plot(ss[:,batch_num,-1:],'r',label='s') 162 | axs[0].plot(bb[:,batch_num,-1:],'g',label='b') 163 | axs[0].plot(zz[:,batch_num,-1:],'b',label='z') 164 | axs[0].set_title('Top PC Score') 165 | # handles, labels = axs[0].get_legend_handles_labels() 166 | # selected_handles = [handles[0], handles[2], handles[4]] 167 | # selected_labels = [labels[0], labels[2], labels[4]] 168 | # axs[0].legend(selected_handles, selected_labels) 169 | axs[0].legend() 170 | 171 | axs[1].plot(p[:,batch_num,0],'r') 172 | axs[1].plot(p[:,batch_num,1],'g') 173 | axs[1].plot(p[:,batch_num,2],'b') 174 | axs[1].set_title('Number of Assigned Objects') 175 | axs[1].set_xlabel('Time') 176 | #plt.savefig('C://Users/brain/Desktop/cradlePCs1.png') 177 | plt.show() 178 | 179 | print('Generating Movie...') 180 | # f = r"C://Users/brain/OneDrive/Desktop/cradle.mp4" 181 | # f = r"C://Users/brain/Desktop/cradle.mp4" 182 | f = r"./cradle.mp4" 183 | ar = animate_results('sbz',f, xlim = xlim, ylim = ylim, fps=10) 184 | ar.make_movie(model, data, (0,20,40,60,80,100))#,120))#,60,61,80,81)) 185 | 186 | batch_num = 40 187 | plt.scatter(data[:,batch_num,:,0],data[:,batch_num,:,1],cmap='rainbow_r',c=model.obs_model.p.argmax(-1)[:,batch_num,:]) 188 | -------------------------------------------------------------------------------- /Presentations/Dynamic Markov Blanket Discovery.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianempirimancer/pyDMBD/e058e9cca8709016c7f7c4eb01b0d432b9f1a067/Presentations/Dynamic Markov Blanket Discovery.pptx -------------------------------------------------------------------------------- /Presentations/temp.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .simulations import * 3 | -------------------------------------------------------------------------------- /cradle.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianempirimancer/pyDMBD/e058e9cca8709016c7f7c4eb01b0d432b9f1a067/cradle.mp4 -------------------------------------------------------------------------------- /cradle2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianempirimancer/pyDMBD/e058e9cca8709016c7f7c4eb01b0d432b9f1a067/cradle2.mp4 -------------------------------------------------------------------------------- /data/flame_even_smaller.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianempirimancer/pyDMBD/e058e9cca8709016c7f7c4eb01b0d432b9f1a067/data/flame_even_smaller.pt -------------------------------------------------------------------------------- /data/lorenz_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianempirimancer/pyDMBD/e058e9cca8709016c7f7c4eb01b0d432b9f1a067/data/lorenz_data.pt -------------------------------------------------------------------------------- /data/rotor.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bayesianempirimancer/pyDMBD/e058e9cca8709016c7f7c4eb01b0d432b9f1a067/data/rotor.pt -------------------------------------------------------------------------------- /flame_pc_scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianempirimancer/pyDMBD/e058e9cca8709016c7f7c4eb01b0d432b9f1a067/flame_pc_scores.png -------------------------------------------------------------------------------- /models/ARHMM.py: -------------------------------------------------------------------------------- 1 | # Variational Bayesian Expectation Maximization Autoregressive HMM. This is a subclass of HMM. 2 | # It assumes a generative model of the form: 3 | # p(y_t|x^t,z_t) = N(y_t|A_z^t x^t + b_z_t, Sigma_z_t) 4 | # where z_t is HMM. 5 | 6 | import torch 7 | from .dists import MatrixNormalWishart 8 | from .dists import MultivariateNormal_vector_format 9 | from .dists.utils import matrix_utils 10 | from .HMM import HMM 11 | from .dists import Delta 12 | 13 | class ARHMM(HMM): 14 | def __init__(self,dim,n,p,batch_shape = (),pad_X=True,X_mask = None, mask=None, transition_mask=None): 15 | dist = MatrixNormalWishart(torch.zeros(batch_shape + (dim,n,p),requires_grad=False),pad_X=pad_X,X_mask=X_mask,mask=mask) 16 | super().__init__(dist,transition_mask=transition_mask) 17 | 18 | def obs_logits(self,XY,t=None): 19 | if t is not None: 20 | return self.obs_dist.Elog_like(XY[0][t],XY[1][t]) 21 | else: 22 | return self.obs_dist.Elog_like(XY[0],XY[1]) 23 | 24 | def update_obs_parms(self,XY,lr): 25 | self.obs_dist.raw_update(XY[0],XY[1],self.p,lr) 26 | 27 | # def update_states(self,XY): 28 | # T = XY[0].shape[0] 29 | # super().update_states(XY,T) 30 | 31 | def Elog_like_X_given_Y(self,Y): 32 | invSigma_x_x, invSigmamu_x, Residual = self.obs_dist.Elog_like_X_given_Y(Y) 33 | if self.p is not None: 34 | invSigma_x_x = (invSigma_x_x*self.p.unsqueeze(-1).unsqueeze(-2)).sum(-3) 35 | invSigmamu_x = (invSigmamu_x*self.p.unsqueeze(-1).unsqueeze(-2)).sum(-3) 36 | Residual = (Residual*self.p).sum(-1) 37 | return invSigma_x_x, invSigmamu_x, Residual 38 | 39 | class ARHMM_prXY(HMM): 40 | def __init__(self,dim,n,p,batch_shape = (),X_mask = None, mask=None,pad_X=True, transition_mask = None): 41 | dist = MatrixNormalWishart(torch.zeros(batch_shape + (dim,n,p),requires_grad=False),mask=mask,X_mask=X_mask,pad_X=pad_X) 42 | super().__init__(dist,transition_mask = transition_mask) 43 | 44 | def obs_logits(self,XY): 45 | return self.obs_dist.Elog_like_given_pX_pY(XY[0],XY[1]) 46 | 47 | def update_obs_parms(self,XY,lr): 48 | self.obs_dist.update(XY[0],XY[1],self.p,lr) 49 | 50 | def Elog_like_X_given_pY(self,pY): 51 | invSigma_x_x, invSigmamu_x, Residual = self.obs_dist.Elog_like_X_given_pY(pY) 52 | if self.p is not None: 53 | invSigma_x_x = (invSigma_x_x*self.p.view(self.p.shape + (1,)*2)).sum(-3) 54 | invSigmamu_x = (invSigmamu_x*self.p.view(self.p.shape + (1,)*2)).sum(-3) 55 | Residual = (Residual*self.p).sum(-1) 56 | return invSigma_x_x, invSigmamu_x, Residual 57 | 58 | 59 | class ARHMM_prXRY(HMM): # Assumes that R and Y are observed 60 | def __init__(self,dim,n,p1,p2,batch_shape=(),mask=None,X_mask = None, transition_mask = None, pad_X=False): 61 | self.p1 = p1 62 | self.p2 = p2 63 | dist = MatrixNormalWishart(torch.zeros(batch_shape + (dim,n,p1+p2),requires_grad=False),mask=mask,X_mask=X_mask,pad_X=pad_X) 64 | super().__init__(dist,transition_mask=transition_mask) 65 | 66 | def Elog_like(self,XRY): 67 | return 
(self.obs_logits(XRY)*self.p).sum(-1) 68 | 69 | def obs_logits(self,XRY): 70 | # Elog_like_given_pX_pY only uses EXX and EX so just need to update Sigma and mu!!!! 71 | # This assumes that XRY[0] is in vector format and sizes are compatible 72 | 73 | Sigma = matrix_utils.block_diag_matrix_builder(XRY[0].ESigma(),torch.zeros(XRY[0].shape[:-2]+(self.p2,self.p2),requires_grad=False)) 74 | mu = torch.cat((XRY[0].mean(),XRY[1]),dim=-2) 75 | return self.obs_dist.Elog_like_given_pX_pY(MultivariateNormal_vector_format(mu=mu,Sigma=Sigma),Delta(XRY[2])) 76 | 77 | def update_obs_parms(self,XRY,lr): #only uses expectations 78 | Sigma = matrix_utils.block_diag_matrix_builder(XRY[0].ESigma(),torch.zeros(XRY[0].shape[:-2]+(self.p2,self.p2),requires_grad=False)) 79 | mu = torch.cat((XRY[0].mean(),XRY[1]),dim=-2) 80 | prXR = MultivariateNormal_vector_format(mu=mu,Sigma=Sigma) 81 | self.obs_dist.update(prXR,Delta(XRY[2]),self.p,lr) 82 | 83 | def Elog_like_X(self,YR): 84 | invSigma_xr_xr, invSigmamu_xr, Residual = self.obs_dist.Elog_like_X(YR[0]) 85 | invSigma_x_x = invSigma_xr_xr[...,:self.p1,:self.p1] 86 | invSigmamu_x = invSigmamu_xr[...,:self.p1,:] - invSigma_xr_xr[...,:self.p1,self.p1:]@YR[1] 87 | Residual = Residual - 0.5*(invSigma_xr_xr[...,self.p1:,self.p1:]*(YR[1]*YR[1].transpose(-2,-1))).sum(-1).sum(-1) 88 | Residual = Residual + (invSigmamu_xr[...,self.p1:,:]*YR[1]).sum(-1).sum(-1) 89 | 90 | if self.p is not None: 91 | invSigma_x_x = (invSigma_x_x*self.p.view(self.p.shape + (1,)*2)).sum(-3) 92 | invSigmamu_x = (invSigmamu_x*self.p.view(self.p.shape + (1,)*2)).sum(-3) 93 | Residual = (Residual*self.p).sum(-1) 94 | 95 | return invSigma_x_x, invSigmamu_x, Residual 96 | 97 | 98 | -------------------------------------------------------------------------------- /models/BayesNet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .dMixtureofLinearTransforms import dMixtureofLinearTransforms 4 | from .dists import MatrixNormalWishart, MatrixNormalGamma, MultivariateNormal_vector_format, Delta 5 | from .dists.utils import matrix_utils 6 | # implements a sequential mixture of linear transforms 7 | # n is the output dimension 8 | # p is the input dimension 9 | # hidden_dims is a list of the widths of the hidden layers, i.e. hidden_dims.shape = (n_layers) 10 | # Layer 0 is hidden_dims[0] x p, Layer 1 is hidden_dims[1] x hidden_dims[0], ..., Layer n_layers-1 is hidden_dims[n_layers-1] x hidden_dims[n_layers-2], and the output layer is n x hidden_dims[-1] 11 | # mixture_dims[0] is the number of mixture components for the first layer, mixture_dims[1] is the number of mixture components for the second layer, etc.
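# A minimal usage sketch (hypothetical tensors X_train, Y_train, X_test; relies only on the __init__/update/predict methods defined below):
#   net = BayesNet(n=2, p=5, hidden_dims=(8, 8), mixture_dims=(4, 4))  # hidden_dims and mixture_dims: one entry per hidden layer
#   net.update(X_train, Y_train, iters=20, lr=1.0, verbose=True)       # X_train: (T, 5), Y_train: (T, 2)
#   Y_hat = net.predict(X_test).squeeze(-1)                            # predictive mean with shape (T_test, 2)
# Note that in the current code the dMixtureofLinearTransforms layers are commented out, so mixture_dims only sets the number of layers.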
12 | # Total number of layers is len(hidden_dims) + 1 13 | # Total number of messages is len(hidden_dims) 14 | 15 | class BayesNet(): 16 | 17 | def __init__(self, n, p, hidden_dims, mixture_dims, batch_shape=(),pad_X=True): 18 | self.num_layers = len(mixture_dims) 19 | self.mixture_dims = mixture_dims 20 | self.batch_shape = batch_shape 21 | Input_Dist = MatrixNormalWishart 22 | Layer_Dist = MatrixNormalWishart 23 | Output_Dist = MatrixNormalWishart 24 | # self.layers = [dMixtureofLinearTransforms(hidden_dims[0],p,mixture_dims[0],batch_shape=batch_shape,pad_X=True)] 25 | self.layers = [Input_Dist(mu_0 = torch.zeros(batch_shape + (hidden_dims[0],p),requires_grad=False), 26 | U_0 = torch.eye(hidden_dims[0],requires_grad=False), 27 | pad_X=True)] 28 | for i in range(1,self.num_layers): 29 | # self.layers.append(dMixtureofLinearTransforms(hidden_dims[i],hidden_dims[i-1],mixture_dims[i],batch_shape=batch_shape,pad_X=True)) 30 | self.layers.append(Layer_Dist( 31 | mu_0 = torch.zeros(batch_shape + (hidden_dims[i],hidden_dims[i-1]),requires_grad=False), 32 | U_0 = torch.eye(hidden_dims[i],requires_grad=False), 33 | pad_X=True)) 34 | self.layers.append(Output_Dist(mu_0 = torch.zeros(batch_shape + (n,hidden_dims[-1]),requires_grad=False), 35 | U_0 = torch.eye(n,requires_grad=False), 36 | pad_X=True)) 37 | self.MSE = [] 38 | self.ELBO_save = [] 39 | self.ELBO_last = -torch.tensor(-torch.inf) 40 | 41 | for layer in self.layers: 42 | layer.mu = torch.randn_like(layer.mu)/np.sqrt(np.sum(hidden_dims)/len(hidden_dims)) 43 | 44 | def update(self,X,Y,iters=1.0,lr=1.0,verbose=False,FBI=False): 45 | # X is sample x batch x p 46 | # returns sample x batch x n 47 | for i in range(iters): 48 | 49 | mu, Sigma, invSigma, invSigmamu = self.layers[0].predict(X.unsqueeze(-1)) 50 | pX_forward = [MultivariateNormal_vector_format(mu=mu, Sigma=Sigma,invSigma=invSigma,invSigmamu=invSigmamu)] # forward message into layer 1 51 | for n in range(1,self.num_layers): 52 | pX_forward.append(self.layers[n].forward(pX_forward[n-1])) 53 | Y_pred = self.layers[-1].forward(pX_forward[-1]).mean().squeeze(-1) 54 | 55 | pX_backward = [None]*self.num_layers 56 | pX = [None]*self.num_layers 57 | invSigma, invSigmamu, Res = self.layers[-1].Elog_like_X(Y.unsqueeze(-1)) 58 | pX_backward[-1] = MultivariateNormal_vector_format(invSigmamu=invSigmamu,invSigma=invSigma) 59 | pX[-1] = MultivariateNormal_vector_format(invSigma = pX_forward[-1].EinvSigma() + pX_backward[-1].EinvSigma(), 60 | invSigmamu = pX_forward[-1].EinvSigmamu() + pX_backward[-1].EinvSigmamu()) 61 | # FBI step -1 HERE 62 | if FBI is True: 63 | self.layers[-1].update(pX[-1],Delta(Y.unsqueeze(-1)),lr=lr) 64 | invSigma, invSigmamu, Res = self.layers[-1].Elog_like_X(Y.unsqueeze(-1)) 65 | pX_backward[-1] = MultivariateNormal_vector_format(invSigmamu=invSigmamu,invSigma=invSigma) 66 | pX[-1] = MultivariateNormal_vector_format(invSigma = pX_forward[-1].EinvSigma() + pX_backward[-1].EinvSigma(), 67 | invSigmamu = pX_forward[-1].EinvSigmamu() + pX_backward[-1].EinvSigmamu()) 68 | 69 | for n in range(self.num_layers-1,0,-1): 70 | pX_backward[n-1]=self.layers[n].backward(pX_backward[n])[0] 71 | pX[n-1] = MultivariateNormal_vector_format(invSigma = pX_forward[n-1].EinvSigma() + pX_backward[n-1].EinvSigma(), 72 | invSigmamu = pX_forward[n-1].EinvSigmamu() + pX_backward[n-1].EinvSigmamu()) 73 | # FBI ALGORITHM STEP n HERE 74 | if FBI is True: 75 | self.layers[n].update(pX[n-1],pX[n],lr=lr) 76 | pX_backward[n-1]=self.layers[n].backward(pX_backward[n])[0] 77 | pX[n-1] = 
MultivariateNormal_vector_format(invSigma = pX_forward[n-1].EinvSigma() + pX_backward[n-1].EinvSigma(), 78 | invSigmamu = pX_forward[n-1].EinvSigmamu() + pX_backward[n-1].EinvSigmamu()) 79 | # FBI ALGORITHM STEP 0 HERE 80 | if FBI is True: 81 | self.layers[0].update(Delta(X.unsqueeze(-1)),pX[0],lr=lr) 82 | 83 | self.ELBO = self.Elog_like(X,Y,pX).sum(0) - self.KLqprior() 84 | # self.pX = pX_forward 85 | # n = i%len(self.layers) 86 | # if n == self.num_layers: 87 | # self.layers[-1].update(pX_forward[-1],Delta(Y.unsqueeze(-1)),lr=1.0) 88 | # elif n == 0: 89 | # self.layers[0].update(Delta(X.unsqueeze(-1)),pX_forward[0],lr=lr) 90 | # else: 91 | # self.layers[n].update(pX_forward[n-1],pX_forward[n],lr=lr) 92 | 93 | if FBI is not True: 94 | self.layers[-1].update(pX[-1],Delta(Y.unsqueeze(-1)),lr=lr) 95 | self.layers[0].update(Delta(X.unsqueeze(-1)),pX[0],lr=lr) 96 | for n in range(1,len(self.layers)-1): 97 | self.layers[n].update(pX[n-1],pX[n],lr=lr) 98 | # SExx = pX[n-1].EXXT().sum(0) 99 | # SEyy = pX[n].EXXT().sum(0) 100 | # PJyy = pX_backward[n].EinvSigma()+self.layers[n].EinvSigma() 101 | # PJyx = -self.layers[n].EinvUX() 102 | # PJxx = pX_forward[n-1].EinvSigma() + self.layers[n].EXTinvUX() 103 | # A,B = matrix_utils.block_matrix_inverse(PJyy,PJyx,PJyx.transpose(-1,-2),PJxx,block_form='left')[0:2] 104 | # SEyx = (pX[n].mean()@pX[n-1].mean().transpose(-1,-2)).sum(0) #+ (A@B).sum(0) 105 | # self.layers[n].ss_update(SExx,SEyx,SEyy,torch.tensor(Y.shape[0]),lr=lr) 106 | 107 | MSE = ((Y_pred-Y)**2).mean() 108 | self.MSE.append(MSE) 109 | self.ELBO_save.append(self.ELBO) 110 | 111 | self.pX = pX 112 | self.pX_forward = pX_forward 113 | self.pX_backward = pX_backward 114 | 115 | if verbose: 116 | print('Percent Change in ELBO = ',(self.ELBO-self.ELBO_last)/self.ELBO_last.abs(),' MSE = ',MSE) 117 | self.ELBO_last = self.ELBO 118 | 119 | def KLqprior(self): 120 | KL = 0.0 121 | for layer in self.layers: 122 | KL = KL + layer.KLqprior() 123 | return KL 124 | 125 | def Elog_like(self,X,Y,qX): 126 | Res = self.layers[0].Elog_like_given_pX_pY(Delta(X.unsqueeze(-1)),qX[0]) 127 | for i in range(1,self.num_layers): 128 | Res = Res + self.layers[i].Elog_like_given_pX_pY(qX[i-1],qX[i]) 129 | Res = Res + self.layers[-1].Elog_like_given_pX_pY(qX[-1],Delta(Y.unsqueeze(-1))) 130 | for q in qX: 131 | Res = Res - q.Res() 132 | return Res 133 | 134 | def predict(self,X): 135 | mu, Sigma, invSigma, invSigmamu = self.layers[0].predict(X.unsqueeze(-1)) 136 | pX_forward = MultivariateNormal_vector_format(mu=mu, Sigma=Sigma,invSigmamu=invSigmamu,invSigma=invSigma) # forward message into layer 1 137 | for n in range(1,self.num_layers+1): 138 | pX_forward = self.layers[n].forward(pX_forward) 139 | return pX_forward.mean() 140 | 141 | -------------------------------------------------------------------------------- /models/BayesianFactorAnalysis.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries 2 | import torch 3 | import numpy as np 4 | from .dists import MatrixNormalGamma, MultivariateNormal_vector_format 5 | 6 | # This class represents a Bayesian factor analysis model 7 | class BayesianFactorAnalysis(): 8 | # Constructor method 9 | def __init__(self, obs_dim, latent_dim, batch_shape=(), pad_X=True): 10 | # Initialize the model's parameters 11 | self.batch_shape = batch_shape 12 | self.batch_dim = len(batch_shape) 13 | self.event_dim = 2 14 | self.obs_dim = obs_dim 15 | self.latent_dim = latent_dim 16 | self.A = 
MatrixNormalGamma(mu_0=torch.zeros(batch_shape + (obs_dim, latent_dim), requires_grad=False)) 17 | 18 | # This method updates the latent variables of the model 19 | def update_latents(self, Y): 20 | # Compute the expected log-likelihood of the data given the latent variables 21 | invSigma, invSigmamu, Res = self.A.Elog_like_X(Y.unsqueeze(-1)) 22 | # Update the prior distribution over the latent variables 23 | self.pz = MultivariateNormal_vector_format(invSigma=invSigma + torch.eye(self.latent_dim, requires_grad=False), invSigmamu=invSigmamu) # sum over roles 24 | self.logZ = Res - self.pz.Res() 25 | 26 | # This method updates the model's parameters 27 | def update_parms(self, Y, lr=1.0): 28 | # Reshape the data 29 | Y = Y.view(Y.shape + (1,)) 30 | # Compute the expected sufficient statistics 31 | SEzz = self.pz.EXXT().sum(0) 32 | SEyy = (Y @ Y.transpose(-2,-1)).sum(0) 33 | SEyz = (Y @ self.pz.mean().transpose(-2, -1)).sum(0) 34 | N = torch.tensor(Y.shape[0]) 35 | # Update the parameters of the model 36 | self.A.ss_update(SEzz, SEyz, SEyy, N, lr=lr) 37 | 38 | # This method updates the model's latent variables and parameters 39 | def raw_update(self, Y, iters=1, lr=1.0, verbose=False): 40 | ELBO = -torch.tensor(torch.inf) 41 | # Iterate over the specified number of iterations 42 | for i in range(iters): 43 | # Update the latent variables 44 | self.update_latents(Y) 45 | # Update the parameters 46 | self.update_parms(Y, lr) 47 | # Compute the ELBO 48 | ELBO_new = self.ELBO() 49 | if verbose: 50 | print('Percent change in ELBO: ', (ELBO_new - ELBO) / ELBO.abs()) 51 | ELBO = ELBO_new 52 | 53 | # This method predicts the output of the model given the prior distribution over the latent variables 54 | def forward(self, pz): 55 | # Compute the mean and covariance of the posterior distribution over Y 56 | B = self.A.EinvUX() 57 | invD = (pz.EinvSigma()+self.A.EXTinvUX()).inverse() 58 | invSigma_yy = self.A.EinvSigma() - B@invD@B.transpose(-2,-1) 59 | invSigmamu_y = B@invD@pz.EinvSigmamu() 60 | Res = 0.5*self.A.ElogdetinvSigma() - 0.5*self.obs_dim*np.log(2*np.pi) + self.pz.Res() 61 | return MultivariateNormal_vector_format(invSigmamu=invSigmamu_y, invSigma=invSigma_yy), Res 62 | 63 | def backward(self,pY): 64 | invSigma, invSigmamu, Res = self.A.Elog_like_X_given_pY(pY) 65 | pz = MultivariateNormal_vector_format(invSigma=invSigma + torch.eye(self.latent_dim, requires_grad=False), invSigmamu=invSigmamu) # sum over roles 66 | return pz, Res-self.pz.Res() 67 | 68 | # This method computes the evidence lower bound (ELBO) of the model 69 | def ELBO(self): 70 | return self.logZ.sum() - self.KLqprior() 71 | 72 | # This method computes the Kullback-Leibler divergence between the prior distribution over the latent variables and the true prior 73 | def KLqprior(self): 74 | return self.A.KLqprior() # + self.alpha.KLqprior() 75 | 76 | 77 | -------------------------------------------------------------------------------- /models/BayesianTransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .dists import Dirichlet 5 | from .dists import MatrixNormalWishart 6 | from .MixtureofLinearTransforms import MixtureofLinearTransforms 7 | 8 | class BayesianTransformer(): 9 | # The logic of the Bayesian Transformer is that observations, Y with size (num_obs,obs_dim), are probabilistically 10 | # clustered into mixture_dim groups that have different relationships to the latent, X. 
In generative modeling 11 | # terms, p(y_i|x,z_i) gives the probability of observation x_i given the output y and the latent assignment 12 | # z_i \in {1,...,mixture_dim}. Here i is the number of observations (each of length obs_dim) which can be vary from 13 | # sample to sample. To account for this variability with a tensor of fixed size, we account for the possibility that 14 | # x_i can contain nan values. If a nan is present, then the corresponding observation is not used in the calculation 15 | # of p(y|x). 16 | 17 | def __init__(self, mixture_dim, latent_dim, obs_dim, batch_shape = (), pad_X=True): 18 | pass 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /models/BlockFactorAnalysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .dists.MatrixNormalWishart import MatrixNormalWishart 4 | from .dists.MatrixNormalGamma import MatrixNormalGamma 5 | from .dists.Dirichlet import Dirichlet 6 | from .dists import MultivariateNormal_vector_format 7 | from .dists import Delta 8 | 9 | class BlockFactorAnalysis(): 10 | 11 | def __init__(self, num_obs, n, p, num_blocks, batch_shape=(), pad_X=True): 12 | print('Not Working and Probably Fundamentally Flawed ') 13 | self.batch_shape = batch_shape 14 | self.batch_dim = len(batch_shape) 15 | self.event_dim = 2 16 | self.latent_dim = p 17 | self.num_blocks = num_blocks 18 | self.num_obs = num_obs 19 | self.A = MatrixNormalWishart(mu_0=torch.zeros(batch_shape + (num_obs, num_blocks, n, p), requires_grad=False), 20 | U_0=torch.zeros(batch_shape +(num_obs, num_blocks, n, n),requires_grad=False) + 100*torch.eye(n, requires_grad=False), 21 | pad_X=False) 22 | self.pi = Dirichlet(0.5*torch.ones(batch_shape + (num_blocks,), requires_grad=False)) 23 | self.pX = None 24 | self.p = None 25 | self.ELBO_last = -torch.tensor(torch.inf) 26 | 27 | def update_assignments(self,Y): 28 | if self.pX is None: 29 | self.p = torch.tensor(1.0) 30 | self.update_latents(Y) 31 | 32 | log_p = self.pi.loggeomean() + self.A.Elog_like_given_pX_pY(self.pX,Delta(Y.unsqueeze(-1).unsqueeze(-3))) 33 | log_p = log_p.sum(0,True) 34 | logZ = log_p.logsumexp(-1,keepdim=True) 35 | log_p = log_p - logZ 36 | self.logZ = logZ.squeeze(-1) 37 | self.p = log_p.exp() 38 | 39 | def update_latents(self,Y): 40 | if self.p is None: 41 | self.p=torch.tensor(1.0/self.num_blocks) 42 | invSigma, invSigmamu, Res = self.A.Elog_like_X(Y.unsqueeze(-1).unsqueeze(-3)) 43 | invSigma = (invSigma*self.p.view(self.p.shape + (1,1))).sum(-4,True) 44 | invSigmamu = (invSigmamu*self.p.view(self.p.shape + (1,1))).sum(-4,True) 45 | invSigma = invSigma + torch.eye(invSigma.shape[-1],requires_grad=False) 46 | Res = (Res*self.p).sum(-2,True) 47 | self.pX = MultivariateNormal_vector_format(invSigma=invSigma, invSigmamu=invSigmamu) 48 | Res = Res - self.pX.Res() 49 | return Res 50 | 51 | def update_parms(self,Y,lr=1.0): 52 | self.pX.invSigmamu = self.pX.invSigmamu.expand(self.pX.invSigmamu.shape[:-4] + (self.num_obs,) + self.pX.invSigmamu.shape[-3:]) 53 | self.A.update(self.pX,Delta(Y.unsqueeze(-1).unsqueeze(-3)),p=self.p,lr=lr) 54 | self.pi.raw_update(self.p,lr=lr) 55 | 56 | def update(self,Y,iters=1,lr=1.0,verbose=False): 57 | for i in range(iters): 58 | self.update_assignments(Y) 59 | Res = self.update_latents(Y) 60 | idx = self.p>0.00001 61 | ELBO = Res.sum() - (self.p[idx]*self.p[idx].log()).sum() + (self.p*self.pi.loggeomean()).sum() - self.KLqprior() 62 | if verbose: 63 | 
print('Percent change in ELBO: ', (ELBO - self.ELBO_last) / self.ELBO_last.abs()) 64 | self.ELBO_last = ELBO 65 | self.update_parms(Y,lr=lr) 66 | 67 | def KLqprior(self): 68 | return self.A.KLqprior().sum(-1).sum(-1) + self.pi.KLqprior() 69 | 70 | -------------------------------------------------------------------------------- /models/GaussianMixtureModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .dists import NormalInverseWishart 4 | from .dists import NormalGamma 5 | from .dists import Mixture 6 | 7 | class GaussianMixtureModel(Mixture): 8 | def __init__(self,dim,n): 9 | dist = NormalInverseWishart(torch.ones(dim,requires_grad=False), 10 | torch.zeros(dim,n,requires_grad=False), 11 | torch.ones(dim,requires_grad=False)*(n+2), 12 | torch.zeros(dim,n,n,requires_grad=False)+torch.eye(n,requires_grad=False)) 13 | super().__init__(dist) 14 | 15 | ###################################THIS IS A MORE FULLY FEATURE VERSION, BUT NEEDS TESTING 16 | # class GaussianMixtureModel(): 17 | # def __init__(self,alpha_0,n): 18 | # self.event_dim = 1 19 | # self.batch_dim = (alpha_0.ndim-1) 20 | 21 | # self.dim = n 22 | # self.num_clusters = alpha_0.shape[-1] 23 | # self.pi = Dirichlet(alpha_0) 24 | 25 | # invsigma = 1.0 #self.num_clusters**(2.0/self.dim) 26 | # self.niw = NormalinverseWishart(torch.ones(alpha_0.shape,requires_grad=False), 27 | # torch.zeros(alpha_0.shape + (self.dim,),requires_grad=False), 28 | # (self.dim+2)*torch.ones(alpha_0.shape,requires_grad=False), 29 | # torch.zeros(alpha_0.shape + (self.dim,self.dim),requires_grad=False)+torch.eye(self.dim,requires_grad=False)*invsigma) 30 | 31 | # def update(self,data,iters=1,lr=1.0,verbose=False): 32 | # for i in range(iters): 33 | # # E step 34 | # log_p = self.Elog_like(data) 35 | # shift = log_p.max(-1,True)[0] 36 | # log_p = log_p - shift 37 | # self.logZ = (log_p).logsumexp(-1,True) + shift 38 | # self.p = log_p.exp() 39 | # self.p = self.p/self.p.sum(-1,True) 40 | 41 | # self.NA = self.p 42 | # while self.NA.ndim > self.event_dim + self.batch_dim: 43 | # self.NA = self.NA.sum(0) 44 | 45 | # # M step 46 | # self.KLqprior_last = self.KLqprior() 47 | # self.pi.ss_update(self.NA,lr) 48 | # self.niw.raw_update(data,self.p,lr) 49 | # if verbose: 50 | # print('ELBO: ',self.ELBO()) 51 | 52 | # def ELBO(self,data=None): 53 | # if data is None: 54 | # return self.logZ.sum() - self.KLqprior_last 55 | # else: 56 | # self.log_p = self.niw.Elog_like(data) + self.pi.loggeomean() 57 | # shift = self.log_p.max(-1,True)[0] 58 | # self.logZ = (self.log_p - shift).logsumexp(-1,True) + shift 59 | # self.p = self.log_p.exp()/self.logZ.exp() 60 | # self.NA = self.p.sum(0) 61 | # return self.logZ.sum() - self.KLqprior() 62 | 63 | # def Elog_like(self,X): 64 | # return self.niw.Elog_like(X) + self.pi.loggeomean() 65 | 66 | 67 | # def KLqprior(self): 68 | # return self.pi.KLqprior() + self.niw.KLqprior().sum(-1) # this is because the last batch dimension is the cluster dimension 69 | 70 | # def assignment_pr(self): 71 | # return self.p 72 | 73 | # def assignment(self): 74 | # return self.p.argmax(-1) 75 | 76 | # def means(self): 77 | # return self.niw.mu 78 | 79 | # def initialize(self,data,lr=0.5): 80 | # data_mat = data.reshape(-1,self.dim) 81 | # self.pi.alpha = self.pi.alpha_0 82 | # ind = torch.randint(data_mat.size(0),[self.num_clusters]) 83 | # self.niw.mu = data_mat[ind] 84 | # self.update(data_mat,1,lr) 85 | # self.pi.alpha = self.pi.alpha_0 86 | # # 
self.fill_unused(data_mat) 87 | # self.update(data_mat,1,lr) 88 | # # self.fill_unused(data_mat) 89 | # self.update(data_mat,1,lr) 90 | 91 | # def fill_unused(self,data): 92 | # data = data.reshape(-1,self.dim) 93 | # m,loc = self.niw.Elog_like(data).max(-1)[0].sort() # find least likely data points 94 | # k=0 95 | # invV_bar = self.niw.invU.mean().mean(0) 96 | # nu_bar = self.niw.nu.mean(0) 97 | # lambda_bar = self.niw.lambda_mu.mean(0) 98 | # for i in range(self.num_clusters): 99 | # if(self.NA[i]<1/self.num_clusters): 100 | # self.niw.mu[i] = data[loc[k]] 101 | # self.niw.invV[i] = invV_bar 102 | # self.niw.nu[i] = nu_bar 103 | # self.niw.lambda_mu[i] = lambda_bar 104 | # k=k+1 105 | # self.niw.update_expectations() 106 | # self.pi.alpha = self.pi.alpha_0 107 | 108 | # def prune_unused(self): 109 | # # removes unused components 110 | # ind = self.NA>1/self.num_clusters 111 | # self.niw.mu_0 = self.niw.mu_0[ind] 112 | # self.niw.invV_0 = self.niw.invV_0[ind] 113 | # self.niw.mu = self.niw.mu[ind] 114 | # self.niw.invV = self.niw.invV[ind] 115 | # self.niw.nu = self.niw.nu[ind] 116 | # self.niw.lambda_mu = self.niw.lambda_mu[ind] 117 | # alpha = self.pi.alpha 118 | # alpha = alpha[ind] 119 | # self.pi = Dirichlet(torch.ones(alpha[ind].shape)/alpha[ind].shape[0]) 120 | # self.pi.alpha = alpha[ind] 121 | # self.num_clusters = ind.sum() 122 | # self.pi_alpha_0 = torch.ones(self.num_clusters)/self.num_clusters 123 | 124 | 125 | # gmm = GaussianMixtureModel(torch.ones(16,4)/2.0,torch.zeros(16,4,2)) 126 | # #gmm.initialize(data) 127 | # gmm.update(data,10,0.5) 128 | # #gmm.fill_unused(data) 129 | # gmm.update(data,10,0.5) 130 | # #gmm.fill_unused(data) 131 | # gmm.update(data,20,1) 132 | # #print((gmm.NA>1)) 133 | # #print(gmm.NA) 134 | # #print(mu) 135 | # #print(gmm.get_means()[(gmm.NA>1),:]) 136 | # import matplotlib.pyplot as plt 137 | 138 | # fig, ax = plt.subplots(nrows=4, ncols=4) 139 | # i=0 140 | # for row in ax: 141 | # for col in row: 142 | # col.scatter(data[:,0],data[:,1],c=gmm.assignment()[:,i]) 143 | # i=i+1 144 | # plt.show() 145 | 146 | # from matplotlib import pyplot as plt 147 | 148 | # loc = gmm.ELBO().argmax() 149 | -------------------------------------------------------------------------------- /models/IsotropicGaussianMixtureModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .dists import NormalGamma 3 | from .dists import Mixture 4 | 5 | class IsotropicGaussianMixtureModel(Mixture): 6 | def __init__(self,dim,n): 7 | dist = NormalGamma(torch.ones(dim,requires_grad=False), 8 | torch.zeros(dim,n,requires_grad=False), 9 | torch.ones(dim,n,requires_grad=False), 10 | torch.ones(dim,n,requires_grad=False)) 11 | super().__init__(dist) 12 | 13 | ###################################THIS IS A MORE FULLY FEATURE VERSION, BUT NEEDS TESTING 14 | # class GaussianMixtureModel(): 15 | # def __init__(self,alpha_0,n): 16 | # self.event_dim = 1 17 | # self.batch_dim = (alpha_0.ndim-1) 18 | 19 | # self.dim = n 20 | # self.num_clusters = alpha_0.shape[-1] 21 | # self.pi = Dirichlet(alpha_0) 22 | 23 | # invsigma = 1.0 #self.num_clusters**(2.0/self.dim) 24 | # self.niw = NormalinverseWishart(torch.ones(alpha_0.shape,requires_grad=False), 25 | # torch.zeros(alpha_0.shape + (self.dim,),requires_grad=False), 26 | # (self.dim+2)*torch.ones(alpha_0.shape,requires_grad=False), 27 | # torch.zeros(alpha_0.shape + (self.dim,self.dim),requires_grad=False)+torch.eye(self.dim,requires_grad=False)*invsigma) 28 | 29 | # def 
update(self,data,iters=1,lr=1.0,verbose=False): 30 | # for i in range(iters): 31 | # # E step 32 | # log_p = self.Elog_like(data) 33 | # shift = log_p.max(-1,True)[0] 34 | # log_p = log_p - shift 35 | # self.logZ = (log_p).logsumexp(-1,True) + shift 36 | # self.p = log_p.exp() 37 | # self.p = self.p/self.p.sum(-1,True) 38 | 39 | # self.NA = self.p 40 | # while self.NA.ndim > self.event_dim + self.batch_dim: 41 | # self.NA = self.NA.sum(0) 42 | 43 | # # M step 44 | # self.KLqprior_last = self.KLqprior() 45 | # self.pi.ss_update(self.NA,lr) 46 | # self.niw.raw_update(data,self.p,lr) 47 | # if verbose: 48 | # print('ELBO: ',self.ELBO()) 49 | 50 | # def ELBO(self,data=None): 51 | # if data is None: 52 | # return self.logZ.sum() - self.KLqprior_last 53 | # else: 54 | # self.log_p = self.niw.Elog_like(data) + self.pi.loggeomean() 55 | # shift = self.log_p.max(-1,True)[0] 56 | # self.logZ = (self.log_p - shift).logsumexp(-1,True) + shift 57 | # self.p = self.log_p.exp()/self.logZ.exp() 58 | # self.NA = self.p.sum(0) 59 | # return self.logZ.sum() - self.KLqprior() 60 | 61 | # def Elog_like(self,X): 62 | # return self.niw.Elog_like(X) + self.pi.loggeomean() 63 | 64 | 65 | # def KLqprior(self): 66 | # return self.pi.KLqprior() + self.niw.KLqprior().sum(-1) # this is because the last batch dimension is the cluster dimension 67 | 68 | # def assignment_pr(self): 69 | # return self.p 70 | 71 | # def assignment(self): 72 | # return self.p.argmax(-1) 73 | 74 | # def means(self): 75 | # return self.niw.mu 76 | 77 | # def initialize(self,data,lr=0.5): 78 | # data_mat = data.reshape(-1,self.dim) 79 | # self.pi.alpha = self.pi.alpha_0 80 | # ind = torch.randint(data_mat.size(0),[self.num_clusters]) 81 | # self.niw.mu = data_mat[ind] 82 | # self.update(data_mat,1,lr) 83 | # self.pi.alpha = self.pi.alpha_0 84 | # # self.fill_unused(data_mat) 85 | # self.update(data_mat,1,lr) 86 | # # self.fill_unused(data_mat) 87 | # self.update(data_mat,1,lr) 88 | 89 | # def fill_unused(self,data): 90 | # data = data.reshape(-1,self.dim) 91 | # m,loc = self.niw.Elog_like(data).max(-1)[0].sort() # find least likely data points 92 | # k=0 93 | # invV_bar = self.niw.invU.mean().mean(0) 94 | # nu_bar = self.niw.nu.mean(0) 95 | # lambda_bar = self.niw.lambda_mu.mean(0) 96 | # for i in range(self.num_clusters): 97 | # if(self.NA[i]<1/self.num_clusters): 98 | # self.niw.mu[i] = data[loc[k]] 99 | # self.niw.invV[i] = invV_bar 100 | # self.niw.nu[i] = nu_bar 101 | # self.niw.lambda_mu[i] = lambda_bar 102 | # k=k+1 103 | # self.niw.update_expectations() 104 | # self.pi.alpha = self.pi.alpha_0 105 | 106 | # def prune_unused(self): 107 | # # removes unused components 108 | # ind = self.NA>1/self.num_clusters 109 | # self.niw.mu_0 = self.niw.mu_0[ind] 110 | # self.niw.invV_0 = self.niw.invV_0[ind] 111 | # self.niw.mu = self.niw.mu[ind] 112 | # self.niw.invV = self.niw.invV[ind] 113 | # self.niw.nu = self.niw.nu[ind] 114 | # self.niw.lambda_mu = self.niw.lambda_mu[ind] 115 | # alpha = self.pi.alpha 116 | # alpha = alpha[ind] 117 | # self.pi = Dirichlet(torch.ones(alpha[ind].shape)/alpha[ind].shape[0]) 118 | # self.pi.alpha = alpha[ind] 119 | # self.num_clusters = ind.sum() 120 | # self.pi_alpha_0 = torch.ones(self.num_clusters)/self.num_clusters 121 | 122 | 123 | # gmm = GaussianMixtureModel(torch.ones(16,4)/2.0,torch.zeros(16,4,2)) 124 | # #gmm.initialize(data) 125 | # gmm.update(data,10,0.5) 126 | # #gmm.fill_unused(data) 127 | # gmm.update(data,10,0.5) 128 | # #gmm.fill_unused(data) 129 | # gmm.update(data,20,1) 130 | # 
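# A minimal usage sketch for the (active) IsotropicGaussianMixtureModel defined at the
# top of this file. The constructor takes the number of components `dim` and the
# observation dimension `n`; the update()/assignment() calls mirror the commented-out
# example below and the Mixture base class, and the data here is purely illustrative:
#
#     data = torch.randn(500, 2)                        # 500 two-dimensional observations
#     igmm = IsotropicGaussianMixtureModel(dim=8, n=2)  # 8 components, 2-d observations
#     igmm.update(data, 20, 1.0)                        # 20 VB iterations, lr = 1.0
#     labels = igmm.assignment()                        # hard cluster label per data point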
#print((gmm.NA>1)) 131 | # #print(gmm.NA) 132 | # #print(mu) 133 | # #print(gmm.get_means()[(gmm.NA>1),:]) 134 | # import matplotlib.pyplot as plt 135 | 136 | # fig, ax = plt.subplots(nrows=4, ncols=4) 137 | # i=0 138 | # for row in ax: 139 | # for col in row: 140 | # col.scatter(data[:,0],data[:,1],c=gmm.assignment()[:,i]) 141 | # i=i+1 142 | # plt.show() 143 | 144 | # from matplotlib import pyplot as plt 145 | 146 | # loc = gmm.ELBO().argmax() 147 | -------------------------------------------------------------------------------- /models/MVN_ard.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .dists import Gamma 4 | 5 | class MVN_ard(): 6 | def __init__(self,dim,batch_shape=(),scale=1): 7 | 8 | self.dim = dim 9 | self.event_dim = 2 10 | self.event_dim_0 = 2 11 | self.event_shape = (dim,1) 12 | self.batch_shape = batch_shape 13 | self.batch_dim = len(self.batch_shape) 14 | self.mu = torch.randn(batch_shape + (dim,1),requires_grad=False)*scale 15 | self.invSigma = torch.zeros(batch_shape + (dim,dim),requires_grad=False) + torch.eye(dim,requires_grad=False) 16 | self.Sigma = self.invSigma 17 | self.logdetinvSigma = self.invSigma.logdet() 18 | self.invSigmamu = self.invSigma@self.mu 19 | self.alpha = Gamma(torch.ones(batch_shape+(dim,),requires_grad=False),torch.ones(batch_shape+(dim,),requires_grad=False)) 20 | 21 | 22 | def to_event(self,n): 23 | if n == 0: 24 | return self 25 | self.event_dim = self.event_dim + n 26 | self.batch_dim = self.batch_dim - n 27 | self.event_shape = self.batch_shape[-n:] + self.event_shape 28 | self.batch_shape = self.batch_shape[:-n] 29 | return self 30 | 31 | def ss_update(self,SExx,SEx, iters = 1, lr=1.0): 32 | 33 | for i in range(iters): 34 | invSigma = SExx + self.alpha.mean().unsqueeze(-1)*torch.eye(self.dim,requires_grad=False) 35 | invSigmamu = SEx 36 | self.invSigma = (1-lr)*self.invSigma + lr*invSigma 37 | self.Sigma = self.invSigma.inverse() 38 | self.invSigmamu = (1-lr)*self.invSigmamu + lr*invSigmamu 39 | self.mu = self.Sigma@self.invSigmamu 40 | self.alpha.ss_update(0.5,0.5*self.EXXT().diagonal(dim1=-1,dim2=-2),lr) 41 | 42 | self.logdetinvSigma = self.invSigma.logdet() 43 | 44 | def raw_update(self,X,p=None,lr=1.0): # assumes X is a vector and p is sample x batch 45 | 46 | if p is None: 47 | SEx = X 48 | SExx = X@X.transpose(-2,-1) 49 | sample_shape = X.shape[:-self.event_dim-self.batch_dim] 50 | n = torch.tensor(np.prod(sample_shape),requires_grad=False) 51 | n = n.expand(self.batch_shape + self.event_shape[:-2]) 52 | while SEx.ndim>self.event_dim + self.batch_dim: 53 | SExx = SExx.sum(0) 54 | SEx = SEx.sum(0) 55 | self.ss_update(SExx,SEx,n,lr) # inputs to ss_update must be batch + event consistent 56 | 57 | else: # data is shape sample_shape x batch_shape x event_shape with the first batch dimension having size 1 58 | 59 | for i in range(self.event_dim): 60 | p=p.unsqueeze(-1) 61 | SExx = X@X.transpose(-2,-1)*p 62 | SEx = X*p 63 | while SEx.ndim>self.event_dim + self.batch_dim: 64 | SExx = SExx.sum(0) 65 | SEx = SEx.sum(0) 66 | p = p.sum(0) 67 | self.ss_update(SExx,SEx,p.squeeze(-1).squeeze(-1),lr) # inputs to ss_update must be batch + event consistent 68 | # p now has shape batch_shape + event_shape so it must be squeezed by the default event_shape which is 1 69 | 70 | def KLqprior(self): 71 | KL = 0.5*(self.mu.pow(2).squeeze(-1)*self.alpha.mean()).sum(-1) - 0.5*self.alpha.loggeomean().sum(-1) + 0.5*self.ElogdetinvSigma() 72 | KL = KL + 
self.alpha.KLqprior().sum(-1) 73 | for i in range(self.event_dim-2): 74 | KL = KL.sum(-1) 75 | return KL 76 | 77 | def mean(self): 78 | return self.mu 79 | 80 | def ESigma(self): 81 | return self.Sigma 82 | 83 | def EinvSigma(self): 84 | return self.invSigma 85 | 86 | def EinvSigmamu(self): 87 | return self.invSigmamu 88 | 89 | def ElogdetinvSigma(self): 90 | return self.logdetinvSigma 91 | 92 | def EX(self): 93 | return self.mean() 94 | 95 | def EXXT(self): 96 | return self.ESigma() + self.mean()@self.mean().transpose(-2,-1) 97 | 98 | def EXTX(self): 99 | return self.ESigma().sum(-1).sum(-1) + self.mean().pow(2).sum(-2).squeeze(-1) 100 | 101 | def EXTinvUX(self): 102 | return (self.mean().transpose(-2,-1)@self.EinvSigma()@self.mean()).squeeze(-1).squeeze(-1) 103 | 104 | def Res(self): 105 | return - 0.5*(self.mean()*self.EinvSigmamu()).sum(-1).sum(-1) + 0.5*self.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 106 | 107 | 108 | -------------------------------------------------------------------------------- /models/MixLDS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .dists import Dirichlet 4 | from .LDS import LinearDynamicalSystems 5 | 6 | class MixtureofLinearDynamicalSystems(): 7 | def __init__(self,num_systems, obs_shape, hidden_dim, control_dim, regression_dim): 8 | self.num_systems = num_systems 9 | self.lds = LinearDynamicalSystems(obs_shape, hidden_dim, control_dim, regression_dim, latent_noise='independent', batch_shape= (num_systems,)) 10 | self.lds.expand_to_batch = True 11 | self.pi = Dirichlet(0.5*torch.ones(num_systems)) 12 | 13 | def update(self, y, u, r,iters=1,lr=1): 14 | y,u,r = self.lds.reshape_inputs(y,u,r) 15 | ELBO = -torch.tensor(torch.inf) 16 | for i in range(iters): 17 | ELBO_last = ELBO 18 | self.lds.update_latents(y,u,r) 19 | log_p = self.lds.logZ 20 | # self.log_p = self.lds.logZ/y.shape[0] % this is wrong but gets better mixing when lr = 1 21 | log_p = log_p + self.pi.loggeomean() 22 | 23 | shift = log_p.max(-1,True)[0] 24 | log_p = log_p - shift 25 | self.logZ = (log_p.logsumexp(-1,True)+shift).squeeze(-1) # has shape sample x batch_shape 26 | self.p = torch.exp(log_p) 27 | self.p = self.p/self.p.sum(-1,True) 28 | self.NA = self.p.sum(0) 29 | 30 | ELBO = self.logZ.sum() - self.KLqprior() 31 | self.pi.ss_update(self.NA,lr=lr) 32 | self.lds.ss_update(p=self.p,lr=lr) # Note that this takes care of the p averages for input to obs_model.ss_update 33 | self.lds.obs_model.ss_update(self.lds.SE_xr_xr,self.lds.SE_y_xr,self.lds.SE_y_y,self.lds.T,lr) 34 | 35 | print('Percent Change in ELBO = %f' % (((ELBO-ELBO_last)/ELBO_last.abs()).data*100)) 36 | 37 | def KLqprior(self): 38 | return self.pi.KLqprior() + self.lds.KLqprior().sum(-1) 39 | 40 | def ELBO(self): 41 | self.KL_last - self.logZ 42 | 43 | def assignment_pr(self): 44 | return self.p 45 | 46 | def assignment(self): 47 | return self.p.argmax(-1) 48 | 49 | 50 | -------------------------------------------------------------------------------- /models/MixtureofLinearTransforms.py: -------------------------------------------------------------------------------- 1 | # Variational Bayesian Expectation Maximization for linear regression and mixtures of linear models 2 | # with Gaussian observations 3 | 4 | import torch 5 | import numpy as np 6 | from .dists.Dirichlet import * 7 | from .dists.MatrixNormalWishart import * 8 | from .dists.MatrixNormalGamma import * 9 | 10 | class MixtureofLinearTransforms(): 11 | 12 | def 
__init__(self,n,p,dim,batch_shape = (), pad_X=True, type = 'Wishart'): 13 | self.n = n 14 | self.p = p 15 | self.dim = dim # here dim is the number of experts 16 | self.event_dim = 1 17 | self.event_shape = (dim,) 18 | self.batch_dim = len(batch_shape) 19 | self.batch_shape = batch_shape 20 | 21 | if type == 'Wishart': 22 | self.W = MatrixNormalWishart(mu_0 = torch.zeros(batch_shape + (dim,n,p),requires_grad=False), 23 | U_0=torch.zeros(batch_shape + (dim,n,n),requires_grad=False)+torch.eye(n,requires_grad=False)*dim**2, 24 | pad_X=pad_X) 25 | elif type == 'Gamma': 26 | self.W = MatrixNormalGamma(mu_0 = torch.zeros(batch_shape + (dim,n,p),requires_grad=False), 27 | U_0=torch.zeros(batch_shape + (dim,n,n),requires_grad=False)+torch.eye(n,requires_grad=False)*dim**2, 28 | pad_X=pad_X) 29 | else: 30 | raise ValueError('type must be either Wishart (default) or Gamma') 31 | 32 | self.pi = Dirichlet(0.5*torch.ones(batch_shape + (dim,))) 33 | self.KL_last = self.KLqprior() 34 | self.ELBO_last = -torch.tensor(torch.inf) 35 | 36 | def update_assignments(self,X,Y): 37 | log_p = self.W.Elog_like(X.unsqueeze(-3),Y.unsqueeze(-3)) + self.pi.loggeomean() 38 | shift = log_p.max(-1,True)[0] 39 | log_p = log_p - shift 40 | self.logZ = (log_p.logsumexp(-1,True)+shift).squeeze(-1).sum(0) 41 | self.p = log_p.exp() 42 | self.p = self.p/self.p.sum(-1,True) 43 | 44 | def Elog_like(self,X,Y): 45 | ELL = (self.W.Elog_like(X.unsqueeze(-3),Y.unsqueeze(-3))*self.p).sum(-1) 46 | for i in range(self.event_dim-1): 47 | ELL = ELL.sum(-1) 48 | return ELL 49 | 50 | def raw_update(self,X,Y,iters=1,lr=1.0,verbose=False): 51 | for i in range(iters): 52 | # E-Step 53 | self.update_assignments(X,Y) 54 | self.KL_last = self.KLqprior() 55 | ELBO = self.ELBO() 56 | 57 | # M-Step 58 | self.pi.ss_update(self.p.sum(0),lr=lr) 59 | self.W.raw_update(X.unsqueeze(-3),Y.unsqueeze(-3),p=self.p,lr=lr) 60 | if verbose: 61 | print('Iteration %d: Percent Change in ELBO = %f' % (i,(((ELBO-self.ELBO_last)/self.ELBO_last.abs()).data*100))) 62 | self.ELBO_last = ELBO 63 | 64 | def update_assignments_given_pX_pY(self,pX,pY): 65 | log_p = self.W.Elog_like_given_pX_pY(pX.unsqueeze(-3),pY.unsqueeze(-3)) + self.pi.loggeomean() 66 | shift = log_p.max(-1,True)[0] 67 | log_p = log_p - shift 68 | self.logZ = (log_p.logsumexp(-1,True)+shift).squeeze(-1).sum(0) # has shape sample 69 | self.p = log_p.exp() 70 | self.p = self.p/self.p.sum(-1,True) 71 | 72 | def Elog_like_given_pX_pY(self,pX,pY): 73 | ELL = (self.W.Elog_like(pX.unsqueeze(-3),pY.unsqueeze(-3))*self.p).sum(-1) 74 | for i in range(self.event_dim-1): 75 | ELL = ELL.sum(-1) 76 | return ELL 77 | 78 | def update(self,pX,pY,iters=1,lr=1,verbose=False): 79 | for i in range(iters): 80 | # E-Step 81 | self.update_assignments_given_pX_pY(pX,pY) 82 | self.KL_last = self.KLqprior() 83 | ELBO = self.ELBO() 84 | 85 | # M-Step 86 | self.pi.ss_update(self.p.sum(0),lr=lr) 87 | self.W.update(pX.unsqueeze(-3),pY.unsqueeze(-3),p=self.p,lr=lr) 88 | if verbose: 89 | print('Iteration %d: Percent Change in ELBO = %f' % (i,(((ELBO-self.ELBO_last)/self.ELBO_last.abs()).data*100))) 90 | self.ELBO_last = ELBO 91 | 92 | def predict(self,X): 93 | # mu_y, Sigma_y_y = self.W.predict(X.unsqueeze(-3))[0:2] 94 | # p = self.pi.mean().unsqueeze(-1).unsqueeze(-1) 95 | # mu = (mu_y*p).sum(-3) 96 | # Sigma = ((Sigma_y_y+mu_y@mu_y.transpose(-2,-1))*p).sum(-3) - mu@mu.transpose(-2,-1) 97 | # return mu, Sigma 98 | mu, Sigma, invSigma, invSigmamu, Res = self.W.predict(X.unsqueeze(-3)) 99 | 100 | log_p = Res + self.pi.loggeomean() 101 | 
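        # The lines below implement a numerically stable softmax over the mixture
        # components: subtracting the per-row maximum before exponentiating prevents
        # overflow, and the normalized result is the responsibility of each expert.
        # An equivalent sketch (illustrative only):
        #     p = torch.softmax(log_p, dim=-1)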
log_p = log_p - log_p.max(-1,True)[0] 102 | p = log_p.exp() 103 | p = p/p.sum(-1,True) 104 | p = p.unsqueeze(-1).unsqueeze(-1) 105 | # invSigmamu = (invSigmamu*p).sum(-3) 106 | # invSigma = (invSigma*p).sum(-3) 107 | 108 | Sigma = ((Sigma+mu@mu.transpose(-2,-1))*p).sum(-3) 109 | mu = (mu*p).sum(-3) 110 | Sigma = Sigma - mu@mu.transpose(-2,-1) 111 | 112 | return mu, Sigma, p.squeeze(-1).squeeze(-1) 113 | 114 | def forward(self,pX): 115 | pass 116 | 117 | def Elog_like_X(self,Y): 118 | pass 119 | 120 | def backward(self,pY): 121 | pass 122 | 123 | def KLqprior(self): 124 | return self.pi.KLqprior() + self.W.KLqprior().sum(-1) 125 | 126 | def ELBO(self): 127 | return self.logZ - self.KL_last 128 | 129 | def assignment_pr(self): 130 | return self.p 131 | 132 | def assignment(self): 133 | return self.p.argmax(-1) 134 | 135 | def mean(self): 136 | return self.p 137 | 138 | ### Compute special expectations used for VB inference 139 | def event_average(self,A): # returns sample_shape + W.event_shape 140 | # A is mix_batch_shape + mix_event_shape + event_shape 141 | p=self.p 142 | for i in range(self.W.event_dim): 143 | p = p.unsqueeze(-1) 144 | out = (A*p) 145 | for i in range(self.event_dim): 146 | out = out.sum(-self.W.event_dim-1) 147 | return out 148 | 149 | def EinvUX(self): 150 | return self.event_average(self.W.EinvUX()) 151 | 152 | def EXTinvU(self): 153 | return self.event_average(self.W.EXTinvU()) 154 | 155 | def EXTAX(self,A): # X is n x p, A is n x n 156 | return self.event_average(self.W.EXTAX(A)) 157 | 158 | def EXAXT(self,A): # A is p x p 159 | return self.event_average(self.W.EXAXT(A)) 160 | 161 | def EXTinvUX(self): 162 | return self.event_average(self.W.EXTinvUX()) 163 | 164 | def EXinvVXT(self): 165 | return self.event_average(self.W.EXinvVXT()) 166 | 167 | def EXmMUTinvUXmMU(self): # X minus mu 168 | return self.event_average(self.W.EXmMUTinvUXmMU()) 169 | 170 | def EXmMUinvVXmMUT(self): 171 | return self.event_average(self.W.EXmMUinvVXmMUT()) 172 | 173 | def EXTX(self): 174 | return self.event_average(self.W.EXTX()) 175 | 176 | def EXXT(self): 177 | return self.event_average(self.W.EXXT()) 178 | 179 | def EinvSigma(self): 180 | return self.event_average(self.W.EinvSigma()) 181 | 182 | def ESigma(self): 183 | return self.event_average(self.W.ESigma()) 184 | 185 | def average(self,A): 186 | out=self.p*A 187 | for i in range(self.event_dim): 188 | out = out.sum(-1) 189 | return out 190 | 191 | def ElogdetinvU(self): 192 | return self.average(self.W.invU.ElogdetinvSigma()) 193 | 194 | def ElogdetinvSigma(self): 195 | return self.average(self.W.ElogdetinvSigma()) 196 | 197 | def weights(self): 198 | if self.padX: 199 | return self.W.mu[...,:-1] 200 | else: 201 | return self.W.mu 202 | 203 | def bias(self): 204 | if self.padX: 205 | return self.W.mu[...,-1] 206 | else: 207 | return None 208 | 209 | def means(self): 210 | return self.mu 211 | 212 | 213 | -------------------------------------------------------------------------------- /models/MixtureofMatrixNormalGammas.py: -------------------------------------------------------------------------------- 1 | # Variational Bayesian Expectation Maximization for linear regression and mixtures of linear models 2 | # with Gaussian observations 3 | 4 | import torch 5 | import numpy as np 6 | from .dists.MatrixNormalGamma import MatrixNormalGamma 7 | from .dists.Dirichlet import Dirichlet 8 | 9 | class MixtureofMatrixNormalGammas(): 10 | 11 | def __init__(self,mu_0,alpha_0,padX=True): 12 | n = mu_0.shape[-2] 13 | p = mu_0.shape[-1] 14 | 
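        # padX convention used throughout this repo: when padX is True the inputs X are
        # augmented with a trailing entry of ones, so the last column of the regression
        # matrix mu absorbs a bias/intercept term. Illustrative sketch of the augmentation
        # applied later in raw_update():
        #     X_aug = torch.cat((X, torch.ones(X.shape[:-2] + (1, 1))), dim=-2)
        # weights() then returns mu[..., :-1] and bias() returns mu[..., -1].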
self.padX = padX 15 | if self.padX: 16 | p = p+1 17 | mu_0 = torch.cat((mu_0,torch.zeros(mu_0.shape[:-1],requires_grad=False).unsqueeze(-1)),-1) 18 | self.n = n 19 | self.p = p 20 | self.dim = alpha_0.shape[-1] # here dim is the number of experts 21 | self.event_dim = 1 22 | self.event_shape = alpha_0.shape[-1:] 23 | self.batch_dim = alpha_0.ndim - 1 24 | self.batch_shape = alpha_0.shape[:-1] 25 | 26 | w_event_length = mu_0.ndim - 2 27 | mu_0 = mu_0.expand(self.batch_shape + self.event_shape + mu_0.shape) 28 | self.W = MatrixNormalGamma(mu_0).to_event(w_event_length) 29 | self.pi = Dirichlet(alpha_0) 30 | self.KL_last = self.KLqprior() 31 | 32 | def to_event(self,n): 33 | if n == 0: 34 | return self 35 | self.event_dim = self.event_dim + n 36 | self.batch_dim = self.batch_dim - n 37 | self.event_shape = self.batch_shape[-n:] + self.event_shape 38 | self.batch_shape = self.batch_shape[:-n] 39 | self.pi.to_event(n) 40 | self.W.to_event(n) 41 | return self 42 | 43 | def Elog_like(self,X,Y): 44 | ELL = (self.W.Elog_like(X,Y)*self.pi.mean()).sum(-1) 45 | for i in range(self.event_dim-1): 46 | ELL = ELL.sum(-1) 47 | return ELL 48 | 49 | def raw_update(self,X,Y,iters=1,lr=1.0,verbose=False): 50 | if self.padX: 51 | X = torch.cat((X,torch.ones(X.shape[:-2]+(1,1),requires_grad=False)),-2) 52 | ELBO = -torch.tensor(torch.inf) 53 | for i in range(iters): 54 | ELBO_last = ELBO 55 | # E-Step 56 | self.log_p = self.W.Elog_like(X,Y) + self.pi.loggeomean() 57 | shift = self.log_p.max(-1,True)[0] 58 | self.log_p = self.log_p - shift 59 | self.logZ = (self.log_p.logsumexp(-1,True)+shift).squeeze(-1) # has shape sample 60 | self.p = torch.exp(self.log_p) 61 | self.p = self.p/self.p.sum(-1,True) 62 | self.NA = self.p.sum(0) 63 | self.KL_last = self.KLqprior() 64 | ELBO = self.ELBO() 65 | 66 | # M-Step 67 | self.pi.ss_update(self.NA) 68 | self.W.raw_update(X,Y,self.p,lr) 69 | if verbose: 70 | print('Iteration %d: Percent Change in ELBO = %f' % (i,(((ELBO-ELBO_last)/ELBO_last.abs()).data*100))) 71 | 72 | def KLqprior(self): 73 | return self.pi.KLqprior() + self.W.KLqprior().sum(-1) 74 | 75 | def ELBO(self): 76 | return self.logZ.sum() - self.KL_last 77 | 78 | def assignment_pr(self): 79 | return self.p 80 | 81 | def assignment(self): 82 | return self.p.argmax(-1) 83 | 84 | def mean(self): 85 | return self.p 86 | 87 | ### Compute special expectations used for VB inference 88 | def event_average(self,A): # returns sample_shape + W.event_shape 89 | # A is mix_batch_shape + mix_event_shape + event_shape 90 | p=self.p 91 | for i in range(self.W.event_dim): 92 | p = p.unsqueeze(-1) 93 | out = (A*p) 94 | for i in range(self.event_dim): 95 | out = out.sum(-self.W.event_dim-1) 96 | return out 97 | 98 | def EinvUX(self): 99 | return self.event_average(self.W.EinvUX()) 100 | 101 | def EXTinvU(self): 102 | return self.event_average(self.W.EXTinvU()) 103 | 104 | def EXTAX(self,A): # X is n x p, A is n x n 105 | return self.event_average(self.W.EXTAX(A)) 106 | 107 | def EXAXT(self,A): # A is p x p 108 | return self.event_average(self.W.EXAXT(A)) 109 | 110 | def EXTinvUX(self): 111 | return self.event_average(self.W.EXTinvUX()) 112 | 113 | def EXinvVXT(self): 114 | return self.event_average(self.W.EXinvVXT()) 115 | 116 | def EXmMUTinvUXmMU(self): # X minus mu 117 | return self.event_average(self.W.EXmMUTinvUXmMU()) 118 | 119 | def EXmMUinvVXmMUT(self): 120 | return self.event_average(self.W.EXmMUinvVXmMUT()) 121 | 122 | def EXTX(self): 123 | return self.event_average(self.W.EXTX()) 124 | 125 | def EXXT(self): 126 | return 
self.event_average(self.W.EXXT()) 127 | 128 | def EinvSigma(self): 129 | return self.event_average(self.W.EinvSigma()) 130 | 131 | def ESigma(self): 132 | return self.event_average(self.W.ESigma()) 133 | 134 | def average(self,A): 135 | out=self.p*A 136 | for i in range(self.event_dim): 137 | out = out.sum(-1) 138 | return out 139 | 140 | def ElogdetinvU(self): 141 | return self.average(self.W.invU.ElogdetinvSigma()) 142 | 143 | def ElogdetinvSigma(self): 144 | return self.average(self.W.ElogdetinvSigma()) 145 | 146 | def weights(self): 147 | if self.padX: 148 | return self.W.mu[...,:-1] 149 | else: 150 | return self.W.mu 151 | 152 | def bias(self): 153 | if self.padX: 154 | return self.W.mu[...,-1] 155 | else: 156 | return None 157 | 158 | def means(self): 159 | return self.mu 160 | 161 | 162 | # from matplotlib import pyplot as plt 163 | # dim = 2 164 | # p = 3 165 | # n_samples = 400 166 | # print('TEST MIXTURE model') 167 | # nc=3 168 | # n = 2*dim 169 | # w_true = torch.randn(nc,n,p) 170 | # b_true = torch.randn(nc,n,1) 171 | # X=torch.randn(n_samples,p) 172 | # Y=torch.zeros(n_samples,n) 173 | # for i in range(n_samples): 174 | # Y[i,:] = X[i:i+1,:]@w_true[i%nc,:,:].transpose(-1,-2) + b_true[i%nc,:,:].transpose(-2,-1) + torch.randn(1)/4.0 175 | # nc=5 176 | # mu_0 = torch.zeros(n,p) 177 | # model = MixtureofMatrixNormalGammas(mu_0,torch.ones(nc)*0.5,True) 178 | # model.raw_update(X.unsqueeze(-2).unsqueeze(-1),Y.unsqueeze(-2).unsqueeze(-1),iters=20,verbose=True) 179 | # xidx = (w_true[0,0,:]**2).argmax() 180 | # plt.scatter(X[:,xidx].data,Y[:,0].data,c=model.assignment()) 181 | # plt.show() 182 | 183 | 184 | # print('TEST MIXTURE with non-trivial observation shape') 185 | # nc=3 186 | # n = 2*dim 187 | # w_true = torch.randn(nc,n,p) 188 | # b_true = torch.randn(nc,n,1) 189 | # X=torch.randn(n_samples,p) 190 | # Y=torch.zeros(n_samples,n) 191 | # for i in range(n_samples): 192 | # Y[i,:] = X[i:i+1,:]@w_true[i%nc,:,:].transpose(-1,-2) + b_true[i%nc,:,:].transpose(-2,-1) + torch.randn(1)/4.0 193 | # nc=5 194 | # n = 2 195 | # X = X.unsqueeze(-2) 196 | # Y = Y.reshape(n_samples,dim,n) 197 | # mu_0 = torch.zeros(dim,n,p) 198 | # model2 = MixtureofMatrixNormalGammas(mu_0,torch.ones(nc)*0.5,True) 199 | # model2.raw_update(X.unsqueeze(-3).unsqueeze(-1),Y.unsqueeze(-3).unsqueeze(-1),iters=20,verbose=True) 200 | # xidx = (w_true[0,0,:]**2).argmax() 201 | # plt.scatter(X[:,0,xidx].data,Y[:,0,0].data,c=model.assignment()) 202 | # plt.show() 203 | 204 | -------------------------------------------------------------------------------- /models/NLRegression_Multinomial.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from .dists import MatrixNormalWishart 5 | from .MultiNomialLogisticRegression import MultiNomialLogisticRegression 6 | 7 | class NLRegression_Multinomial(): 8 | # Generative model of NL regression. 
Generative model is: 9 | # z_t ~ MNRL(x_t) 10 | # y_t|z_t,x_t ~ MatrixNormalWishart 11 | print("NLRegression has no forward/backward methods, Use dMixtureofLinearTransofrorms instead") 12 | def __init__(self,n,p,mixture_dim,batch_shape=(),pad_X=True): 13 | 14 | self.batch_shape = batch_shape 15 | self.batch_dim = len(batch_shape) 16 | self.event_dim = 2 17 | self.n = n 18 | self.p = p 19 | self.mixture_dim = mixture_dim 20 | self.ELBO_last = -torch.tensor(torch.inf) 21 | 22 | self.A = MatrixNormalWishart(torch.zeros(batch_shape + (mixture_dim,n,p),requires_grad=False), 23 | U_0=torch.zeros(batch_shape + (mixture_dim,n,n),requires_grad=False)+torch.eye(n,requires_grad=False)*mixture_dim**2, 24 | pad_X=pad_X) 25 | self.Z = MultiNomialLogisticRegression(mixture_dim, p, batch_shape = (), pad_X=pad_X) 26 | 27 | def raw_update(self,X,Y,iters=1.0,lr=1.0,verbose=False): 28 | AX = X.view(X.shape + (1,)) # make vector 29 | AY = Y.view(Y.shape + (1,)) 30 | AX = AX.view(AX.shape[:-2] + (self.batch_dim+1)*(1,) + AX.shape[-2:]) # add z dim and batch_dim 31 | AY = AY.view(AY.shape[:-2] + (self.batch_dim+1)*(1,) + AY.shape[-2:]) 32 | 33 | for i in range(int(iters)): 34 | log_p = self.A.Elog_like(AX,AY) + self.Z.log_predict(X) 35 | shift = log_p.max(-1,True)[0] 36 | log_p = log_p - shift 37 | self.logZ = shift.squeeze(-1) + log_p.logsumexp(-1) 38 | p = log_p.exp() 39 | p = p/p.sum(-1,True) 40 | self.NA = p.sum(0) 41 | 42 | ELBO = self.logZ - self.KLqprior() 43 | if verbose: print("Percent Change in ELBO = ",((ELBO-self.ELBO_last)/self.ELBO_last.abs()).data*100) 44 | ELBO_last = ELBO 45 | 46 | self.A.raw_update(AX,AY,p=p,lr=lr) 47 | self.Z.raw_update(X,p,lr=lr,verbose=False) 48 | 49 | def Elog_like_X(self,Y): 50 | AY = Y.view(Y.shape + (1,)) 51 | AY = AY.view(AY.shape[:-2] + (self.batch_dim+1)*(1,) + AY.shape[-2:]) 52 | invSigma,invSigmamu,Res = self.A.Elog_like_X(AY) 53 | 54 | def forward(self,X): 55 | pass 56 | 57 | def backward(self,Y): 58 | pass 59 | 60 | def predict_full(self,X): 61 | log_p = self.Z.log_predict(X) 62 | log_p = log_p - log_p.max(-1,keepdim=True)[0] 63 | p = log_p.exp() 64 | p = p/p.sum(-1,True) 65 | p = p.view(p.shape+(1,1)) 66 | return self.A.predict(X.unsqueeze(-2).unsqueeze(-1)) + (p,) 67 | 68 | def predict(self,X): 69 | p=self.Z.predict(X) 70 | p = p.view(p.shape+(1,1)) 71 | 72 | mu_y, Sigma_y_y = self.A.predict(X.unsqueeze(-2).unsqueeze(-1))[0:2] 73 | mu = (mu_y*p).sum(-3) 74 | Sigma = ((Sigma_y_y + mu_y@mu_y.transpose(-2,-1))*p).sum(-3) - mu@mu.transpose(-2,-1) 75 | 76 | return mu, Sigma 77 | 78 | def ELBO(self): 79 | return self.logZ - self.KLqprior() 80 | 81 | def KLqprior(self): 82 | return self.A.KLqprior().sum(-1) + self.Z.KLqprior() 83 | 84 | -------------------------------------------------------------------------------- /models/NormalSparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .dists import Gamma 4 | 5 | class NormalSparse(): 6 | # bunch of independent normal distributions with zero mean and a gamma prior on variance 7 | def __init__(self,event_shape,batch_shape=(),scale=1): 8 | 9 | self.event_dim_0 = 0 10 | self.event_shape = event_shape 11 | self.batch_shape = batch_shape 12 | self.batch_dim = len(self.batch_shape) 13 | self.alpha = Gamma(0.5*torch.ones(batch_shape+event_shape,requires_grad=False),0.5/scale*torch.ones(batch_shape+event_shape,requires_grad=False)) 14 | 15 | 16 | def to_event(self,n): 17 | if n == 0: 18 | return self 19 | self.event_dim = self.event_dim + n 20 
| self.batch_dim = self.batch_dim - n 21 | self.event_shape = self.batch_shape[-n:] + self.event_shape 22 | self.batch_shape = self.batch_shape[:-n] 23 | return self 24 | 25 | def ss_update(self,SExx, N, iters = 1, lr=1.0): 26 | 27 | for i in range(iters): 28 | self.alpha.ss_update(0.5*N,0.5*SExx) 29 | 30 | def raw_update(self,X,p=None,lr=1.0): # assumes X is a vector and p is sample x batch 31 | 32 | if p is None: 33 | SExx = X.pow(2) 34 | sample_shape = X.shape[:-self.event_dim-self.batch_dim] 35 | N = torch.tensor(np.prod(sample_shape),requires_grad=False) 36 | N = N.expand(self.batch_shape + self.event_shape) 37 | while SExx.ndim>self.event_dim + self.batch_dim: 38 | SExx = SExx.sum(0) 39 | self.ss_update(SExx,N,lr) # inputs to ss_update must be batch + event consistent 40 | 41 | else: # data is shape sample_shape x batch_shape x event_shape with the first batch dimension having size 1 42 | 43 | for i in range(self.event_dim): 44 | p=p.unsqueeze(-1) 45 | SExx = X.pow(2)*p 46 | while SExx.ndim>self.event_dim + self.batch_dim: 47 | SExx = SExx.sum(0) 48 | p = p.sum(0) 49 | self.ss_update(SExx,p,lr) # inputs to ss_update must be batch + event consistent 50 | # p now has shape batch_shape + event_shape so it must be squeezed by the default event_shape which is 1 51 | 52 | def KLqprior(self): 53 | KL = KL + self.alpha.KLqprior() 54 | for i in range(self.event_dim): 55 | KL = KL.sum(-1) 56 | return KL 57 | 58 | def mean(self): 59 | return self.zeros(self.batch_shape+self.event_shape) 60 | 61 | def ESigma(self): 62 | return self.alpha.meaninv() 63 | 64 | def EinvSigma(self): 65 | return self.alpha.mean() 66 | 67 | def EinvSigmamu(self): 68 | return self.mean() 69 | 70 | def ElogdetinvSigma(self): 71 | return self.alpha.loggeomean() 72 | 73 | def EX(self): 74 | return self.mean() 75 | 76 | def EXX(self): 77 | return self.ESigma() + self.mean().pow(2) 78 | 79 | def EXinvUX(self): 80 | return self.mean().pow(2)*self.EinvSigma() 81 | 82 | def Res(self): 83 | return + 0.5*self.ElogdetinvSigma() - 0.5*np.log(2*np.pi) 84 | 85 | 86 | -------------------------------------------------------------------------------- /models/PoissonMixtureModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .dists import Mixture, Gamma 3 | class PoissonMixtureModel(Mixture): 4 | def __init__(self,alpha_0,beta_0): 5 | dist = Gamma(alpha_0,beta_0).to_event(1) 6 | super().__init__(dist) 7 | 8 | -------------------------------------------------------------------------------- /models/ReducedRankRegression.py: -------------------------------------------------------------------------------- 1 | # VBEM for Reduced Rank Regression. Unlike typical approaches, this is based upon a Bayesian 2 | # canonical correlation analysis with a pre-specified dimension for the latent space. 
3 | # Generative model is: 4 | # y_t = A u_t + noise 5 | # x_t = B u_t + noise 6 | # u_t ~ N(0,I) 7 | # 8 | # Priors and posteriors over A and B are Matrix Normal Wishart 9 | # and the effective Regression coefficients are @^T 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | 15 | from .dists import MatrixNormalWishart, MatrixNormalGamma 16 | from .dists import NormalGamma, NormalInverseWishart 17 | from .dists import MultivariateNormal_vector_format 18 | from .dists import MVN_ard 19 | from .dists import Delta 20 | 21 | print('Reduced Rank Regression: need to marginalize over U instead of using VB for prediction') 22 | 23 | class ReducedRankRegression(): 24 | def __init__(self,n,p,dim,batch_shape = (),pad_X=False,independent = False): 25 | self.n=n 26 | self.p=p 27 | self.dim=dim 28 | self.event_dim=2 29 | self.batch_shape = batch_shape 30 | self.batch_dim = len(batch_shape) 31 | self.event_shape = (dim,1) 32 | 33 | if independent is True: 34 | self.A = MatrixNormalGamma(torch.zeros(batch_shape + (n,dim),requires_grad=False),pad_X=pad_X) 35 | self.B = MatrixNormalGamma(torch.zeros(batch_shape + (p,dim),requires_grad=False),pad_X=pad_X) 36 | else: 37 | self.A = MatrixNormalWishart(torch.zeros(batch_shape + (n,dim),requires_grad=False),pad_X=pad_X) 38 | self.B = MatrixNormalWishart(torch.zeros(batch_shape + (p,dim),requires_grad=False),pad_X=pad_X) 39 | # self.U = MVN_ard(dim,batch_shape=batch_shape) 40 | self.U = NormalGamma(torch.ones(batch_shape), 41 | torch.zeros(batch_shape + (dim,)), 42 | 0.5*torch.ones(batch_shape+(dim,)), 43 | 0.5*torch.ones(batch_shape+(dim,))) 44 | # self.U = NormalInverseWishart(mu_0 = torch.zeros(batch_shape + (dim,))) 45 | self.ELBO_last = -torch.tensor(torch.inf) 46 | 47 | def raw_update(self,X,Y,iters=1,lr=1.0,verbose=False): 48 | sample_shape = X.shape[:1-self.event_dim-self.batch_dim] 49 | 50 | X=X.unsqueeze(-1) 51 | Y=Y.unsqueeze(-1) 52 | ELBO = self.ELBO_last 53 | for i in range(iters): 54 | invSigma, invSigmamu, Residual = self.B.Elog_like_X(X) 55 | invSigma_bw, invSigmamu_bw, Residual_bw = self.A.Elog_like_X(Y) 56 | 57 | invSigma = invSigma_bw + invSigma + self.U.EinvSigma() #torch.eye(self.dim) 58 | invSigmamu = invSigmamu_bw + invSigmamu + self.U.EinvSigmamu().unsqueeze(-1) # unsqueeze is for NIW 59 | Residual = Residual + Residual_bw + 0.5*self.U.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 60 | 61 | Sigma = invSigma.inverse() 62 | mu = Sigma@invSigmamu 63 | Residual_u = -0.5*(mu*invSigmamu).sum(-1).sum(-1) + 0.5*invSigma.logdet() - 0.5*np.log(2*np.pi) 64 | Residual = Residual - Residual_u 65 | 66 | self.logZ = Residual.sum(0) 67 | pu = MultivariateNormal_vector_format(mu=mu,Sigma=Sigma,invSigma=invSigma,invSigmamu=invSigmamu,Residual=Residual_u) 68 | self.pu = pu 69 | if verbose is True: 70 | self.ELBO_last = ELBO 71 | ELBO = self.logZ.sum() - self.KLqprior().sum() 72 | print('Percent change in ELBO = ',(ELBO-self.ELBO_last)/self.ELBO_last.abs()*100) 73 | 74 | self.A.update(pu,Delta(Y),lr=lr) 75 | self.B.update(pu,Delta(X),lr=lr) 76 | SExx = pu.EXXT().sum(0) 77 | SEx = pu.EX().sum(0).squeeze(-1) 78 | N=torch.ones(sample_shape[1:])*sample_shape[0] 79 | while SExx.ndim > self.event_dim + self.batch_dim: 80 | SExx = SExx.sum(0) 81 | SEx = SEx.sum(0) 82 | N=N.sum(0) 83 | 84 | self.U.ss_update(SExx.diagonal(dim1=-1,dim2=-2),SEx,N,lr=lr) # This is for NG 85 | # self.U.ss_update(SExx,SEx,N,lr=lr) # This is for NIW 86 | # self.U.ss_update(SExx,SEx,iters=2,lr=lr) # This is for ARD 87 | 88 | def Elog_like(self,X,Y): # also updates pu 89 | 
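        # Elog_like combines the evidence X carries about the shared latent u (through B)
        # with the evidence Y carries (through A): the natural parameters of q(u) are the
        # sums of the two precision-weighted messages plus the prior on u, roughly
        #     invSigma_u   = B' E[inv(Sigma_x)] B + A' E[inv(Sigma_y)] A + E[inv(Sigma_u0)]
        #     invSigmamu_u = B' E[inv(Sigma_x)] x + A' E[inv(Sigma_y)] y
        # (schematic only; the exact expectations are supplied by Elog_like_X below).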
X=X.unsqueeze(-1) 90 | Y=Y.unsqueeze(-1) 91 | invSigma, invSigmamu, Residual = self.B.Elog_like_X(X) 92 | invSigma_bw, invSigmamu_bw, Residual_bw = self.A.Elog_like_X(Y) 93 | 94 | invSigma = invSigma_bw + invSigma + self.U.EinvSigma() #torch.eye(self.dim) 95 | invSigmamu = invSigmamu_bw + invSigmamu + self.U.EinvSigmamu().unsqueeze(-1) # unsqueeze is for NIW 96 | Residual = Residual + Residual_bw + 0.5*self.U.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 97 | 98 | Sigma = invSigma.inverse() 99 | mu = Sigma@invSigmamu 100 | Residual_u = -0.5*(mu*invSigmamu).sum(-1).sum(-1) + 0.5*invSigma.logdet() - 0.5*np.log(2*np.pi) 101 | Residual = Residual - Residual_u 102 | self.logZ = Residual.sum(0) 103 | self.pu = MultivariateNormal_vector_format(mu=mu,Sigma=Sigma,invSigma=invSigma,invSigmamu=invSigmamu,Residual=Residual_u) 104 | return Residual 105 | 106 | def update_parms(self,X,Y,p=None,lr=1): 107 | sample_shape = X.shape[:1-self.event_dim-self.batch_dim] 108 | self.A.update(self.pu,Delta(Y.unsqueeze(-1)),p=p,lr=lr) 109 | self.B.update(self.pu,Delta(X.unsqueeze(-1)),p=p,lr=lr) 110 | if p is None: 111 | SExx = self.pu.EXXT().sum(0) 112 | SEx = self.pu.EX().sum(0).squeeze(-1) 113 | N=torch.ones(sample_shape[1:])*sample_shape[0] 114 | else: 115 | SExx = (self.pi.EXXT()*p.view(p.shape+(1,1))).sum(0) 116 | SEx = (self.pu.EX()*p.view(p.shape+(1,1))).sum(0).squeeze(-1) 117 | N = p.sum(0) 118 | while SExx.ndim > self.event_dim + self.batch_dim: 119 | SExx = SExx.sum(0) 120 | SEx = SEx.sum(0) 121 | N=N.sum(0) 122 | self.U.ss_update(SExx.diagonal(dim1=-1,dim2=-2),SEx,N,lr=lr) # This is for NG 123 | 124 | 125 | def KLqprior(self): 126 | return self.A.KLqprior() + self.B.KLqprior() + self.U.KLqprior() 127 | 128 | def EW(self): 129 | return self.A.mean()@self.B.EXTinvU().transpose(-2,-1) 130 | 131 | def predict(self,X): 132 | invSigma, invSigmamu, Residual = self.B.Elog_like_X(X) 133 | invSigma = invSigma + self.U.EinvSigma() 134 | invSigmamu = invSigmamu + self.U.EinvSigmamu().unsqueeze(-1) 135 | Residual = Residual + 0.5*self.U.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 136 | return self.A.predict_given_pX(MultivariateNormal_vector_format(invSigma=invSigma,invSigmamu=invSigmamu)) 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dists import * 2 | from .ARHMM import * 3 | from .BayesianFactorAnalysis import * 4 | from .BayesianTransformer import * 5 | from .BlockFactorAnalysis import * 6 | from .dHMM import * 7 | from .dMixture import * 8 | from .dMixtureofLinearTransforms import * 9 | from .GaussianMixtureModel import * 10 | from .HMM import * 11 | from .IsotropicGaussianMixtureModel import * 12 | from .LDS_px import * 13 | from .LDS import * 14 | from .MixLDS import * 15 | from .MixtureofLinearTransforms import * 16 | from .MultiNomialLogisticRegression import * 17 | from .NLRegression_Multinomial import * 18 | from .NLRegression import * 19 | from .NormalSparse import * 20 | from .PoissonMixtureModel import * 21 | from .ReducedRankRegression import * 22 | from .rHMM import * 23 | -------------------------------------------------------------------------------- /models/dMixture.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from .MultiNomialLogisticRegression import MultiNomialLogisticRegression 4 | class dMixture(): 5 | 6 | def __init__(self,dist,p): 7 | 
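        # dMixture is a mixture model whose component weights depend on the input:
        # z ~ MNLR(x) gates which component of `dist` generates y, i.e.
        #     p(y, z | x) = p(z | x) * p(y | z)
        # with p(z|x) given by MultiNomialLogisticRegression and p(y|z) by the supplied
        # distribution `dist` (one component per trailing batch entry of dist).
        # A hypothetical construction, assuming a NormalInverseWishart observation model
        # with 4 components over 2-d data and 3-d inputs (mirroring the constructor call
        # used in GaussianMixtureModel.py):
        #     niw = NormalInverseWishart(torch.ones(4), torch.zeros(4, 2),
        #                                torch.ones(4) * 4.0,
        #                                torch.zeros(4, 2, 2) + torch.eye(2))
        #     model = dMixture(niw, p=3)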
self.event_dim = 1 8 | self.batch_dim = dist.batch_dim - 1 9 | self.event_shape = dist.batch_shape[-1:] 10 | self.batch_shape = dist.batch_shape[:-1] 11 | self.pi = MultiNomialLogisticRegression(self.event_shape[-1],p,batch_shape = self.batch_shape,pad_X=True) 12 | self.dist = dist 13 | self.logZ = torch.tensor(-torch.inf,requires_grad=False) 14 | print('Untested') 15 | 16 | def update_assignments(self,X,Y): 17 | log_p = self.dist.Elog_like(Y.unsqueeze(-self.dist.event_dim-1)) + self.pi.log_predict(X) 18 | shift = log_p.max(-1,True)[0] 19 | log_p = log_p - shift 20 | self.logZ = ((log_p).logsumexp(-1,True) + shift).squeeze(-1) 21 | self.p = log_p.exp() 22 | self.p = self.p/self.p.sum(-1,True) 23 | self.NA = self.p 24 | while self.NA.ndim > self.event_dim + self.batch_dim: 25 | self.logZ = self.logZ.sum(0) 26 | self.NA = self.NA.sum(0) 27 | 28 | def update_parms(self,X,Y,lr=1.0): 29 | self.pi.raw_update(X,self.p,lr=lr) 30 | self.dist.raw_update(Y.unsqueeze(-self.dist.event_dim-1),self.p,lr) 31 | 32 | def raw_update(self,X,Y,iters=1,lr=1.0,verbose=False): 33 | ELBO = torch.tensor(-torch.inf) 34 | for i in range(iters): 35 | # E-Step 36 | ELBO_last = ELBO 37 | self.update_assignments(X,Y) 38 | ELBO = self.ELBO() 39 | self.update_parms(X,Y,lr) 40 | if verbose: 41 | print('Percent Change in ELBO: ',(ELBO-ELBO_last)/ELBO_last.abs()*100.0) 42 | 43 | def Elog_like(self,X,Y): 44 | #broken for non trivial batch shape because of incompatibility in dist.batch_shape with data shape 45 | log_p = self.dist.Elog_like(Y.unsqueeze(-self.dist.event_dim-1)) + self.pi.loggeomean(X) 46 | shift = log_p.max(-1,True)[0] 47 | return ((log_p - shift).exp().sum(-1,True) + shift).squeeze(-1) #logZ 48 | 49 | def KLqprior(self): 50 | KL = self.pi.KLqprior() + self.dist.KLqprior().sum(-1) 51 | for i in range(self.event_dim-1): 52 | KL = KL.sum(-1) 53 | return KL 54 | 55 | def ELBO(self): 56 | return self.logZ - self.KLqprior() 57 | 58 | def assignment_pr(self): 59 | return self.p 60 | 61 | def assignment(self): 62 | return self.p.argmax(-1) 63 | 64 | def means(self): 65 | return self.dist.mean() 66 | 67 | def event_average_f(self,function_string,A=None,keepdim=False): 68 | if A is None: 69 | return self.event_average(eval('self.dist.'+function_string)(),keepdim=keepdim) 70 | else: 71 | return self.event_average(eval('self.dist.'+function_string)(A),keepdim=keepdim) 72 | 73 | def average_f(self,function_string,A=None,keepdim=False): 74 | if A is None: 75 | return self.average(eval('self.dist.'+function_string)(),keepdim=keepdim) 76 | else: 77 | return self.average(eval('self.dist.'+function_string)(A),keepdim=keepdim) 78 | 79 | def average(self,A,keepdim=False): 80 | return (A*self.p).sum(-1,keepdim) 81 | 82 | ### Compute special expectations used for VB inference 83 | def event_average(self,A,keepdim=False): # returns sample_shape + W.event_shape 84 | # A is mix_batch_shape + mix_event_shape + event_shape 85 | out = (A*self.p.view(self.p.shape+(1,)*self.dist.event_dim)).sum(-1-self.dist.event_dim,keepdim) 86 | for i in range(self.event_dim-1): 87 | out = out.sum(-self.dist.event_dim-1,keepdim) 88 | return out 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /models/dMixtureofLinearTransforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .MultiNomialLogisticRegression import MultiNomialLogisticRegression 4 | from .dists import MatrixNormalGamma, 
MatrixNormalWishart, MultivariateNormal_vector_format, MVN_ard 5 | 6 | class dMixtureofLinearTransforms(): 7 | # This basically a mxiture of linear transforms, p(y|x,z) with a mixture components driven by 8 | # z ~ p(z|x) which is MNLR. Component number give the number of different z's, latent_dim gives the dimension of x, and obs_dim gives the dimension 9 | # of y. 10 | 11 | def __init__(self, n, p, mixture_dim, batch_shape=(),pad_X=True, type = 'Wishart'): 12 | print('Backward method no implemented yet') 13 | self.batch_shape = batch_shape 14 | self.batch_dim = len(batch_shape) 15 | self.event_dim = 2 16 | self.n = n 17 | self.p = p 18 | self.mix_dim = mixture_dim 19 | self.ELBO_last = -torch.tensor(torch.inf) 20 | 21 | if type == 'Wishart': 22 | self.A = MatrixNormalWishart(mu_0 = torch.zeros(batch_shape + (mixture_dim,n,p),requires_grad=False), 23 | U_0=torch.zeros(batch_shape + (mixture_dim,n,n),requires_grad=False)+torch.eye(n,requires_grad=False)*mixture_dim**2, 24 | pad_X=pad_X) 25 | elif type == 'Gamma': 26 | self.A = MatrixNormalGamma(mu_0 = torch.zeros(batch_shape + (mixture_dim,n,p),requires_grad=False), 27 | U_0=torch.zeros(batch_shape + (mixture_dim,n),requires_grad=False)+torch.ones(n,requires_grad=False)*mixture_dim**2, 28 | pad_X=pad_X) 29 | elif type == 'MVN_ard': 30 | raise NotImplementedError 31 | else: 32 | raise ValueError('type must be either Wishart (default) or Gamma') 33 | self.pi = MultiNomialLogisticRegression(mixture_dim,p,batch_shape = batch_shape,pad_X=True) 34 | 35 | def raw_update(self,X,Y,iters=1,lr=1.0,verbose=False): 36 | AX = X.unsqueeze(-1) # make vector 37 | AY = Y.unsqueeze(-1) 38 | AX = AX.view(AX.shape[:-2] + (self.batch_dim+1)*(1,) + AX.shape[-2:]) # add z dim and batch_dim 39 | AY = AY.view(AY.shape[:-2] + (self.batch_dim+1)*(1,) + AY.shape[-2:]) 40 | 41 | for i in range(iters): 42 | log_p = self.A.Elog_like(AX,AY) + self.pi.log_predict(X) # A.Elog_like is sample x batch x component 43 | shift = log_p.max(-1,True)[0] 44 | log_p = log_p - shift 45 | logZ = (shift.squeeze(-1) + log_p.logsumexp(-1)).sum(0) 46 | p = log_p.exp() 47 | p = p/p.sum(-1,True) 48 | 49 | ELBO = logZ - self.KLqprior() 50 | if verbose: print("Percent Change in ELBO = ",((ELBO-self.ELBO_last)/self.ELBO_last.abs()).data*100) 51 | self.ELBO_last = ELBO 52 | 53 | self.A.raw_update(AX,AY,p=p,lr=lr) 54 | self.pi.raw_update(X,p,lr=lr,verbose=False) 55 | 56 | def postdict(self, Y): 57 | 58 | invSigma, invSigmamu, Res = self.A.Elog_like_X(Y.unsqueeze(-2).unsqueeze(-1)) # Res is sample x batch x component 59 | # invSigma, invSigmamu, Res = self.A.Elog_like_X_given_pY(pY.unsqueeze(-3)) # Res is sample x batch x component 60 | like_X = MultivariateNormal_vector_format(invSigma = invSigma.unsqueeze(0).movedim(-3,-3-self.batch_dim), invSigmamu = invSigmamu.movedim(-3,-3-self.batch_dim)) 61 | Res = Res.movedim(-1,-1-self.batch_dim) # This res is just from the A, does not include like_X contribution 62 | 63 | Z = torch.eye(self.mix_dim) 64 | for i in range(self.batch_dim): 65 | Z = Z.unsqueeze(-2) 66 | invSigma, invSigmamu, Sigma, mu, Res_z = self.pi.Elog_like_X(like_X,Z,iters=4) # Res_z includes input like_X contrib, but not output like_X contrib 67 | Res = Res + Res_z + 0.5*(mu*invSigmamu).sum(-2).squeeze(-1) - 0.5*invSigma.logdet() + like_X.dim/2.0*np.log(2*np.pi) 68 | logZ = Res.logsumexp(-1-self.batch_dim,True) 69 | logp = Res - logZ 70 | logZ = logZ.squeeze(-1) 71 | p = logp.exp() 72 | 73 | pv = p.view(p.shape+(1,1)) 74 | invSigma = (invSigma*pv).sum(-3-self.batch_dim) 75 | invSigmamu = 
(invSigmamu*pv).sum(-3-self.batch_dim) 76 | return MultivariateNormal_vector_format(invSigma = invSigma, invSigmamu = invSigmamu), logZ.squeeze(-1-self.batch_dim), p 77 | 78 | # Sigma = ((Sigma+mu@mu.transpose(-2,-1))*pv).sum(-3-self.batch_dim) 79 | # mu = (mu*pv).sum(-3-self.batch_dim) 80 | # Sigma = Sigma - mu@mu.transpose(-2,-1) 81 | # return MultivariateNormal_vector_format(Sigma = Sigma, mu = mu), logZ.squeeze(-1-self.batch_dim), p 82 | # return MultivariateNormal_vector_format(invSigma = invSigma, invSigmamu = invSigmamu), logZ.squeeze(-1-self.batch_dim) 83 | 84 | def predict(self,X): # update to handle batching 85 | p=self.pi.predict(X) 86 | pv=p.view(p.shape+(1,1)) 87 | Xv = X.view(X.shape[:-1]+(1,) + X.shape[-1:] + (1,)) 88 | mu_y, Sigma_y_y = self.A.predict(Xv)[0:2] 89 | 90 | # invSigma = (invSigma_y_y*pv).sum(-3) 91 | # invSigmamu = (invSigmamu_y*pv).sum(-3) 92 | # Sigma = invSigma.inverse() 93 | # mu = Sigma@invSigmamu 94 | 95 | Sigma = ((Sigma_y_y + mu_y@mu_y.transpose(-1,-2))*pv).sum(-3) 96 | mu = (mu_y*pv).sum(-3) 97 | Sigma = Sigma - mu@mu.transpose(-2,-1) 98 | return mu, Sigma, p 99 | 100 | def update(self,pX,pY,iters=1,lr=1.0,verbose=False): 101 | # Expects X and Y to be batch consistent, i.e. X is sample x batch x p 102 | # Y is sample x batch x n 103 | pAX = pX.unsqueeze(-3) 104 | pAY = pY.unsqueeze(-3) 105 | for i in range(iters): 106 | log_p = self.A.Elog_like_given_pX_pY(pAX,pAY) + self.pi.log_forward(pX) 107 | shift = log_p.max(-1,True)[0] 108 | log_p = log_p - shift 109 | self.logZ = shift.squeeze(-1) + log_p.logsumexp(-1) 110 | p = log_p.exp() 111 | p = p/p.sum(-1,True) 112 | self.A.update(pAX,pAY,p=p,lr=lr) 113 | self.pi.update(pX,p,lr=lr,verbose=False) 114 | self.NA = p.sum(0) 115 | 116 | ELBO = self.logZ.sum() - self.KLqprior().sum() 117 | if verbose: 118 | print('Percent Change in ELBO: ', (ELBO-self.ELBO_last)/self.ELBO_last.abs()) 119 | self.ELBO_last = ELBO 120 | 121 | def forward(self,pX): 122 | p = self.pi.forward(pX) 123 | pY = self.A.forward(pX.unsqueeze(-3)) 124 | mu = (pY.mean()*p.view(p.shape+(1,1))).sum(-3) 125 | Sigma = (pY.EXXT()*p.view(p.shape+(1,1))).sum(-3)-mu@mu.transpose(-2,-1) 126 | return MultivariateNormal_vector_format(Sigma = Sigma, mu = mu) 127 | 128 | def forward_mix(self,pX): 129 | return self.A.forward(pX.unsqueeze(-3)), self.pi.forward(pX) 130 | 131 | def backward(self,pY): 132 | pX, ResA = self.A.backward(pY.unsqueeze(-3)) 133 | invSigma, invSigmamu, Sigma, mu, Res = self.pi.backward(pX,torch.eye(self.mix_dim)) 134 | log_p = Res + ResA 135 | p = log_p - log_p.max(-1,True)[0] 136 | p = p.exp() 137 | p = p/p.sum(-1,True) 138 | p = p.unsqueeze(-1).unsqueeze(-1) 139 | 140 | invSigma = (invSigma*p).sum(-3) 141 | invSigmamu = (invSigmamu*p).sum(-3) 142 | 143 | return MultivariateNormal_vector_format(invSigma = invSigma, invSigmamu = invSigmamu) 144 | 145 | def backward_mix(self,pY): 146 | pX, ResA = self.A.backward(pY.unsqueeze(-3)) 147 | invSigma, invSigmamu, Sigma, mu, Res = self.pi.backward(pX,torch.eye(self.mix_dim)) 148 | log_p = Res + ResA 149 | shift = log_p.max(-1,True)[0] 150 | log_p = log_p - shift 151 | Res = (shift.squeeze(-1) + log_p.logsumexp(-1)) 152 | p = p.exp() 153 | p = p/p.sum(-1,True) 154 | p = p.unsqueeze(-1).unsqueeze(-1) 155 | pX = MultivariateNormal_vector_format(invSigma = invSigma, invSigmamu= invSigmamu, mu = mu, Sigma = Sigma) 156 | Res = Res - pX.Res() 157 | 158 | return MultivariateNormal_vector_format(invSigma = invSigma, invSigmamu= invSigmamu, mu = mu, Sigma = Sigma), p, Res 159 | 160 | def 
KLqprior(self): 161 | return self.A.KLqprior().sum(-1) + self.pi.KLqprior() 162 | 163 | 164 | -------------------------------------------------------------------------------- /models/dists/ConjugatePrior.py: -------------------------------------------------------------------------------- 1 | 2 | class ConjugatePrior(): 3 | def __init__(self): 4 | self.event_dim_0 = 0 # smallest possible event dimension 5 | self.event_dim = 0 6 | self.event_shape = () 7 | self.batch_dim = 0 8 | self.batch_shape = () 9 | self.nat_dim = 0 10 | self.nat_parms_0 = 0 11 | self.nat_parms = 0 12 | 13 | def to_event(self,n): 14 | if n < 1: 15 | return self 16 | self.event_dim = self.event_dim + n 17 | self.batch_dim = self.batch_dim - n 18 | self.event_shape = self.batch_shape[-n:] + self.event_shape 19 | self.batch_shape = self.batch_shape[:-n] 20 | return self 21 | 22 | def T(self,X): # evaluate the sufficient statistic 23 | pass 24 | 25 | def ET(self): # expected value of the sufficient statistic given the natural parameters, self.nat_parms 26 | pass 27 | 28 | def logZ(self): # log partition function of the natural parameters often called A(\eta) 29 | pass 30 | 31 | def logZ_ub(self): # upper bound on the log partition function 32 | pass 33 | 34 | def ss_update(self,ET,lr=1.0): 35 | self.nat_parms = ET + self.nat_parms_0 36 | while ET.ndim > self.event_dim + self.batch_dim: 37 | ET = ET.sum(0) 38 | self.nat_parms = self.nat_parms*(1-lr) + lr*(ET+self.nat_parms_0) 39 | 40 | def raw_update(self,X,p=None,lr=1.0): 41 | if p is None: 42 | EmpT = self.T(X) 43 | else: # assumes p is sample by batch 44 | if(self.batch_dim==0): 45 | sample_shape = p.shape 46 | else: 47 | sample_shape = p.shape[:-self.batch_dim] 48 | EmpT = self.T(X.view(sample_shape+self.batch_dim*(1,)+self.event_shape))*p.view(p.shape + self.nat_dim*(1,)) 49 | while EmpT.ndim > self.event_dim + self.batch_dim: 50 | EmpT = EmpT.sum(0) 51 | self.ss_update(EmpT,lr) 52 | 53 | def KL_qprior_event(self): # returns the KL divergence between prior (nat_parms_0) and posterior (nat_parms) 54 | pass 55 | 56 | def KL_qprior(self): 57 | KL = self.KL_qprior_event() 58 | for i in range(self.event_dim - self.event_dim_0): 59 | KL = KL.sum(-1) 60 | 61 | def Elog_like_0(self,X): # reuturns the likelihood of X under the default event_shape 62 | pass 63 | 64 | def Elog_like(self,X): 65 | ELL = self.Elog_like_0(self,X) 66 | for i in range(self.event_dim - self.event_dim_0): 67 | ELL = ELL.sum(-1) 68 | return ELL 69 | 70 | def sample(self,sample_shape=()): 71 | pass 72 | 73 | 74 | -------------------------------------------------------------------------------- /models/dists/Delta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | class Delta(): 3 | def __init__(self,X): 4 | self.X = X 5 | 6 | def unsqueeze(self,dim): # only appliles to batch 7 | self.X = self.X.unsqueeze(dim) 8 | return self 9 | 10 | def squeeze(self,dim): # only appliles to batch 11 | self.X = self.X.squeeze(dim) 12 | return self 13 | 14 | # def Elog_like(self): 15 | # torch.ones(self.X.shape[:-self.event_dim],requires_grad=False) 16 | 17 | # def KLqprior(self): 18 | # return torch.zeros(self.X.shape[:-self.event_dim],requires_grad=False) 19 | 20 | # def ELBO(self): 21 | # return torch.zeros(self.X.shape[:-self.event_dim],requires_grad=False) 22 | @property 23 | def shape(self): 24 | return self.X.shape 25 | 26 | def mean(self): 27 | return self.X 28 | 29 | def EX(self): 30 | return self.X 31 | 32 | def EXXT(self): 33 | return 
self.X@self.X.transpose(-1,-2) 34 | 35 | def EXTX(self): 36 | return self.X.transpose(-1,-2)@self.X 37 | 38 | def EXTAX(self,A): 39 | return self.X.transpose(-1,-2)@A@self.X 40 | 41 | def EXX(self): 42 | return self.X**2 43 | 44 | def ElogX(self): 45 | return torch.log(self.X) 46 | 47 | def E(self,f): 48 | return f(self.X) 49 | 50 | def logZ(self): 51 | return torch.zeros(self.batch_shape) 52 | 53 | 54 | -------------------------------------------------------------------------------- /models/dists/DiagonalWishart.py: -------------------------------------------------------------------------------- 1 | # Variational Bayesian Expectation Maximization for linear regression and mixtures of linear models 2 | # with Gaussian observations 3 | 4 | import torch 5 | import numpy as np 6 | from .Gamma import Gamma 7 | 8 | class DiagonalWishart(): 9 | 10 | def __init__(self,nu_0,U_0): # best to set nu_0 >= 2 11 | # here nu_0 and invU are same shape 12 | self.dim = U_0.shape[-1] 13 | self.event_dim = 1 14 | self.batch_dim = U_0.ndim - 1 15 | self.event_shape = U_0.shape[-1:] 16 | self.batch_shape = U_0.shape[:-1] 17 | self.gamma = Gamma(nu_0,1.0/U_0).to_event(1) 18 | 19 | def to_event(self,n): 20 | if n==0: 21 | return self 22 | self.event_dim = self.event_dim + n 23 | self.batch_dim = self.batch_dim - n 24 | self.event_shape = self.batch_shape[-n:] + self.event_shape 25 | self.batch_shape = self.batch_shape[:-n] 26 | self.gamma.to_event(n) 27 | return self 28 | 29 | def ss_update(self,SExx,n,lr=1.0): 30 | idx = n>1 31 | SExx = SExx*(idx) 32 | self.gamma.ss_update(n/2.0,SExx/2.0,lr) 33 | 34 | def KLqprior(self): 35 | return self.gamma.KLqprior() 36 | 37 | def logZ(self): 38 | return self.gamma.logZ() 39 | 40 | # These expectations return Matrices with diagonal elements 41 | # generally one should avoid using these function and instead 42 | # use self.gamma.mean(), self.gamma.meaninv(), self.gamma.loggeomean() 43 | def ESigma(self): 44 | return self.tensor_diag(self.gamma.meaninv()) 45 | 46 | def EinvSigma(self): 47 | return self.tensor_diag(self.gamma.mean()) 48 | 49 | def ElogdetinvSigma(self): 50 | return self.gamma.loggeomean().sum(-1) 51 | 52 | def mean(self): 53 | return self.tensor_diag(self.gamma.mean()) 54 | 55 | def tensor_diag(self,A): 56 | return A.unsqueeze(-1)*torch.eye(A.shape[-1],requires_grad=False) 57 | 58 | def tensor_extract_diag(self,A): 59 | return A.diagonal(dim=-2,dim1=-1) 60 | 61 | class DiagonalWishart_UnitTrace(DiagonalWishart): 62 | 63 | def suminv_d_plus_x(self,x): 64 | return (self.gamma.alpha/(self.gamma.beta+x)).sum(-1,True) 65 | 66 | def suminv_d_plus_x_prime(self,x): 67 | return -(self.gamma.alpha/(self.gamma.beta+x)**2).sum(-1,True) 68 | 69 | def ss_update(self,SExx,n,lr=1.0,iters=10): 70 | super().ss_update(SExx,n,lr=lr) 71 | # x=self.gamma.alpha.sum(-1,True) 72 | x = torch.zeros(self.gamma.beta.shape[:-1]+(1,),requires_grad=False) 73 | for i in range(iters): 74 | x = x + (10*self.dim-self.suminv_d_plus_x(x))/self.suminv_d_plus_x_prime(x) 75 | idx = x<-self.gamma.beta.min(-1,True)[0] 76 | x = x*(~idx) + (-self.gamma.beta.min(-1,True)[0]+1e-4)*idx # ensure positive definite 77 | 78 | self.rescale = 1+x/self.gamma.beta 79 | self.gamma.beta = self.gamma.beta+x 80 | 81 | 82 | -------------------------------------------------------------------------------- /models/dists/Dirichlet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Dirichlet(): 4 | def __init__(self,alpha_0): 5 | self.event_dim = 1 6 | self.dim 
= alpha_0.shape[-1] 7 | self.batch_dim = alpha_0.ndim - 1 8 | self.event_shape = alpha_0.shape[-1:] 9 | self.batch_shape = alpha_0.shape[:-1] 10 | self.alpha_0 = alpha_0 11 | self.alpha = self.alpha_0 + 2.0*torch.rand(self.alpha_0.shape,requires_grad=False)*self.alpha_0 # type: ignore 12 | 13 | def to_event(self,n): 14 | if n == 0: 15 | return self 16 | self.event_dim = self.event_dim + n 17 | self.batch_dim = self.batch_dim - n 18 | self.event_shape = self.batch_shape[-n:] + self.event_shape 19 | self.batch_shape = self.batch_shape[:-n] 20 | return self 21 | 22 | def ss_update(self,NA,lr=1.0): 23 | alpha = NA + self.alpha_0 24 | self.alpha = (alpha-self.alpha)*lr + self.alpha 25 | 26 | def raw_update(self,X,p=None,lr=1.0): 27 | if p is None: 28 | # assumes X is sample x batch x event 29 | NA = X 30 | else: 31 | # assumes X is sample by event and p is sample x batch 32 | for i in range(self.event_dim): 33 | p=p.unsqueeze(-1) 34 | for i in range(self.batch_dim): 35 | X=X.unsqueeze(-self.event_dim-1) 36 | NA = X*p 37 | while NA.ndim > self.event_dim + self.batch_dim: 38 | NA = NA.sum(0) 39 | self.ss_update(NA,lr) 40 | 41 | def mean(self): 42 | return self.alpha/self.alpha.sum(-1,keepdim=True) 43 | 44 | def loggeomean(self): 45 | return self.alpha.digamma() - self.alpha.sum(-1,keepdim=True).digamma() 46 | 47 | def ElogX(self): 48 | return self.alpha.digamma() - self.alpha.sum(-1,keepdim=True).digamma() 49 | 50 | def var(self): 51 | alpha_sum = self.alpha.sum(-1,keepdim=True) 52 | mean = self.mean() 53 | return mean*(1-mean)/(alpha_sum+1) 54 | 55 | def covariance(self): 56 | alpha_sum = self.alpha.sum(-1,keepdim=True) 57 | mean = self.mean() 58 | return (mean/(alpha_sum+1)).unsqueeze(-1)*torch.eye(self.dim,requires_grad=False)-(mean/alpha_sum+1).unsqueeze(-1)*(1-mean.unsqueeze(-2)) 59 | 60 | def EXXT(self): 61 | return self.mean().unsqueeze(-1)*self.mean().unsqueeze(-2) + self.covariance() 62 | 63 | def KL_lgamma(self,x): 64 | out = x.lgamma() 65 | out[out== torch.inf]=0 66 | return out 67 | 68 | def KL_digamma(self,x): 69 | out = x.digamma() 70 | out[out== -torch.inf]=0 71 | return out 72 | 73 | def KLqprior(self): 74 | alpha_sum = self.alpha.sum(-1) 75 | alpha_0_sum = self.alpha_0.sum(-1) 76 | 77 | KL = alpha_sum.lgamma() - self.KL_lgamma(self.alpha).sum(-1) 78 | KL = KL - alpha_0_sum.lgamma() + self.KL_lgamma(self.alpha_0).sum(-1) 79 | KL = KL + ((self.alpha-self.alpha_0)*(self.KL_digamma(self.alpha)-alpha_sum.digamma().unsqueeze(-1))).sum(-1) 80 | 81 | while KL.ndim > self.batch_dim: 82 | KL = KL.sum(-1) 83 | return KL 84 | 85 | def logZ(self): 86 | return self.alpha.lgamma().sum(-1) - self.alpha.sum(-1).lgamma() 87 | 88 | def Elog_like(self,X): 89 | # assumes multinomial observations with data.shape = samples x batch_shape* x event_shape 90 | # returns sample shape x batch shape 91 | ELL = (X*self.loggeomean()).sum(-1) + (1+X.sum(-1)).lgamma() - (1+X).lgamma().sum(-1) 92 | for i in range(self.event_dim-1): 93 | ELL = ELL.sum(-1) 94 | return ELL 95 | -------------------------------------------------------------------------------- /models/dists/Gamma.py: -------------------------------------------------------------------------------- 1 | # Gamma distribution as conjugate prior for Poisson distribution 2 | # raw update assumes Poisson observation model 3 | 4 | import torch 5 | import numpy as np 6 | 7 | class Gamma(): 8 | def __init__(self,alpha,beta): 9 | self.event_dim = 0 10 | self.event_shape = () 11 | self.batch_dim = alpha.ndim 12 | self.batch_shape = alpha.shape 13 | self.alpha_0 = 
alpha 14 | self.beta_0 = beta 15 | self.alpha = alpha + torch.rand(alpha.shape,requires_grad=False) 16 | self.beta = beta + torch.rand(alpha.shape,requires_grad=False) 17 | 18 | def to_event(self,n): 19 | if n == 0: 20 | return self 21 | self.event_dim = self.event_dim + n 22 | self.batch_dim = self.batch_dim - n 23 | self.event_shape = self.batch_shape[-n:] + self.event_shape 24 | self.batch_shape = self.batch_shape[:-n] 25 | return self 26 | 27 | def ss_update(self,SElogx,SEx,lr=1.0): 28 | alpha = self.alpha_0 + SElogx 29 | beta = self.beta_0 + SEx 30 | self.alpha = (alpha-self.alpha)*lr + self.alpha 31 | self.beta = (beta-self.beta)*lr + self.beta 32 | 33 | def raw_update(self,X,p=None,lr=1.0): 34 | 35 | if p is None: 36 | # assumes X is sample x batch x event 37 | sample_shape = X.shape[:-self.event_dim-self.batch_dim] 38 | n = torch.tensor(np.prod(sample_shape),requires_grad=False) 39 | n = n.expand(self.batch_shape) 40 | SEx=X 41 | for i in range(len(sample_shape)): 42 | SEx = SEx.sum(0) 43 | 44 | else: 45 | n=p 46 | for i in range(self.event_dim): 47 | n=n.unsqueeze(-1) # now p is sample x batch x event 48 | for i in range(self.batch_dim): 49 | X=X.unsqueeze(-self.event_dim-1) 50 | SEx = X*n 51 | 52 | while SEx.ndim>self.event_dim + self.batch_dim: 53 | SEx = SEx.sum(0) 54 | n = n.sum(0) 55 | 56 | self.ss_update(SEx,n,lr) 57 | 58 | def Elog_like(self,X): # ASSUMES POISSON OBSERVATION MODEL 59 | for i in range(self.batch_dim): 60 | X=X.unsqueeze(-self.event_dim-1) 61 | ELL = X*self.loggeomean()- (X+1).lgamma() - self.mean() 62 | for i in range(self.event_dim): 63 | ELL = ELL.sum(-1) 64 | return ELL 65 | 66 | def mean(self): 67 | return self.alpha/self.beta 68 | 69 | def var(self): 70 | return self.alpha/self.beta**2 71 | 72 | def meaninv(self): 73 | return self.beta/(self.alpha-1) 74 | 75 | def ElogX(self): 76 | return self.alpha.digamma() - self.beta.log() 77 | 78 | def loggeomean(self): 79 | return self.alpha.log() - self.beta.log() 80 | 81 | def entropy(self): 82 | return self.alpha.log() - self.beta.log() + self.alpha.lgamma() + (1-self.alpha)*self.alpha.digamma() 83 | 84 | def logZ(self): 85 | return -self.alpha*self.beta.log() + self.alpha.lgamma() 86 | 87 | def logZprior(self): 88 | return -self.alpha_0*self.beta_0.log() + self.alpha_0.lgamma() 89 | 90 | def KLqprior(self): 91 | KL = (self.alpha-self.alpha_0)*self.alpha.digamma() - self.alpha.lgamma() + self.alpha_0.lgamma() + self.alpha_0*(self.beta.log()-self.beta_0.log()) + self.alpha*(self.beta_0/self.beta-1) 92 | for i in range(self.event_dim): 93 | KL = KL.sum(-1) 94 | return KL 95 | 96 | 97 | -------------------------------------------------------------------------------- /models/dists/MVN_ard.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .Gamma import Gamma 4 | 5 | class MVN_ard(): 6 | def __init__(self,dim,batch_shape=(),scale=1): 7 | 8 | self.dim = dim 9 | self.event_dim = 2 10 | self.event_dim_0 = 2 11 | self.event_shape = (dim,1) 12 | self.batch_shape = batch_shape 13 | self.batch_dim = len(self.batch_shape) 14 | self.mu = torch.randn(batch_shape + (dim,1),requires_grad=False)*scale 15 | self.invSigma = torch.zeros(batch_shape + (dim,dim),requires_grad=False) + torch.eye(dim,requires_grad=False) 16 | self.Sigma = self.invSigma 17 | self.logdetinvSigma = self.invSigma.logdet() 18 | self.invSigmamu = self.invSigma@self.mu 19 | self.alpha = 
Gamma(torch.ones(batch_shape+(dim,),requires_grad=False),torch.ones(batch_shape+(dim,),requires_grad=False)) 20 | 21 | 22 | def to_event(self,n): 23 | if n == 0: 24 | return self 25 | self.event_dim = self.event_dim + n 26 | self.batch_dim = self.batch_dim - n 27 | self.event_shape = self.batch_shape[-n:] + self.event_shape 28 | self.batch_shape = self.batch_shape[:-n] 29 | return self 30 | 31 | def ss_update(self,SExx,SEx, iters = 1, lr=1.0): 32 | 33 | for i in range(iters): 34 | invSigma = SExx + self.alpha.mean().unsqueeze(-1)*torch.eye(self.dim,requires_grad=False) 35 | invSigmamu = SEx 36 | self.invSigma = (1-lr)*self.invSigma + lr*invSigma 37 | self.Sigma = self.invSigma.inverse() 38 | self.invSigmamu = (1-lr)*self.invSigmamu + lr*invSigmamu 39 | self.mu = self.Sigma@self.invSigmamu 40 | self.alpha.ss_update(0.5,0.5*self.EXXT().diagonal(dim1=-1,dim2=-2),lr) 41 | 42 | self.logdetinvSigma = self.invSigma.logdet() 43 | 44 | def raw_update(self,X,p=None,lr=1.0): # assumes X is a vector and p is sample x batch 45 | 46 | if p is None: 47 | SEx = X 48 | SExx = X@X.transpose(-2,-1) 49 | sample_shape = X.shape[:-self.event_dim-self.batch_dim] 50 | n = torch.tensor(np.prod(sample_shape),requires_grad=False) 51 | n = n.expand(self.batch_shape + self.event_shape[:-2]) 52 | while SEx.ndim>self.event_dim + self.batch_dim: 53 | SExx = SExx.sum(0) 54 | SEx = SEx.sum(0) 55 | self.ss_update(SExx,SEx,n,lr) # inputs to ss_update must be batch + event consistent 56 | 57 | else: # data is shape sample_shape x batch_shape x event_shape with the first batch dimension having size 1 58 | 59 | for i in range(self.event_dim): 60 | p=p.unsqueeze(-1) 61 | SExx = X@X.transpose(-2,-1)*p 62 | SEx = X*p 63 | while SEx.ndim>self.event_dim + self.batch_dim: 64 | SExx = SExx.sum(0) 65 | SEx = SEx.sum(0) 66 | p = p.sum(0) 67 | self.ss_update(SExx,SEx,p.squeeze(-1).squeeze(-1),lr) # inputs to ss_update must be batch + event consistent 68 | # p now has shape batch_shape + event_shape so it must be squeezed by the default event_shape which is 1 69 | 70 | def KLqprior(self): 71 | KL = 0.5*(self.mu.pow(2).squeeze(-1)*self.alpha.mean()).sum(-1) - 0.5*self.alpha.loggeomean().sum(-1) + 0.5*self.ElogdetinvSigma() 72 | KL = KL + self.alpha.KLqprior().sum(-1) 73 | for i in range(self.event_dim-2): 74 | KL = KL.sum(-1) 75 | return KL 76 | 77 | def mean(self): 78 | return self.mu 79 | 80 | def ESigma(self): 81 | return self.Sigma 82 | 83 | def EinvSigma(self): 84 | return self.invSigma 85 | 86 | def EinvSigmamu(self): 87 | return self.invSigmamu 88 | 89 | def ElogdetinvSigma(self): 90 | return self.logdetinvSigma 91 | 92 | def EX(self): 93 | return self.mean() 94 | 95 | def EXXT(self): 96 | return self.ESigma() + self.mean()@self.mean().transpose(-2,-1) 97 | 98 | def EXTX(self): 99 | return self.ESigma().sum(-1).sum(-1) + self.mean().pow(2).sum(-2).squeeze(-1) 100 | 101 | def EXTinvUX(self): 102 | return (self.mean().transpose(-2,-1)@self.EinvSigma()@self.mean()).squeeze(-1).squeeze(-1) 103 | 104 | def Res(self): 105 | return - 0.5*(self.mean()*self.EinvSigmamu()).sum(-1).sum(-1) + 0.5*self.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 106 | 107 | 108 | -------------------------------------------------------------------------------- /models/dists/Mixture.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | #from .Dirichlet import Dirichlet 4 | from .Dirichlet import Dirichlet 5 | class Mixture(): 6 | # This class takes takes in a distribution with non trivial batch shape and 
7 | # produces a mixture distribution with the number of mixture components equal 8 | # to the terminal dimension of the batch shape. The mixture distribution 9 | # has batch shape equal to the batch shape of the input distribution minus the final dimension 10 | # 11 | # IMPORTANT: This routine expects data to be sample_shape + dist.batch_shape[:-1] + (1,) + dist.event_shape 12 | # or if running VB batches in parallel: sample_shape + (1,)*mix.batch_dim + (1,) + dist.event_shape 13 | # when this is the case the observations will not need to be reshaped at any time. Only p will be reshaped for raw_updates 14 | 15 | 16 | def __init__(self,dist): 17 | self.event_dim = 1 18 | self.batch_dim = dist.batch_dim - 1 19 | self.event_shape = dist.batch_shape[-1:] 20 | self.batch_shape = dist.batch_shape[:-1] 21 | self.pi = Dirichlet(0.5*torch.ones(self.batch_shape+self.event_shape,requires_grad=False)) 22 | self.dist = dist 23 | self.logZ = torch.tensor(-torch.inf,requires_grad=False) 24 | 25 | def to_event(self,n): 26 | if n == 0: 27 | return self 28 | self.event_dim = self.event_dim + n 29 | self.event_shape = self.batch_shape[-n:] + self.event_shape 30 | self.batch_shape = self.batch_shape[:-n] 31 | self.pi.to_event(n) 32 | self.dist.to_event(n) 33 | return self 34 | 35 | def update_assignments(self,X): 36 | log_p = self.Elog_like(X) 37 | shift = log_p.max(-1,True)[0] 38 | log_p = log_p - shift 39 | self.logZ = ((log_p).logsumexp(-1,True) + shift).squeeze(-1) 40 | self.p = log_p.exp() 41 | self.p = self.p/self.p.sum(-1,True) 42 | self.NA = self.p 43 | while self.NA.ndim > self.event_dim + self.batch_dim: 44 | self.logZ = self.logZ.sum(0) 45 | self.NA = self.NA.sum(0) 46 | 47 | def update_parms(self,X,lr=1.0): 48 | self.pi.ss_update(self.NA,lr=lr) 49 | self.update_dist(X,lr=lr) 50 | 51 | def raw_update(self,X,iters=1,lr=1.0,verbose=False): 52 | self.update(X,iters=iters,lr=lr,verbose=verbose) 53 | 54 | def update(self,X,iters=1,lr=1.0,verbose=False): 55 | # Expects X to be sample_shape + dist.batch_shape[:-1] + (1,) + dist.event_shape 56 | ELBO = torch.tensor(-torch.inf) 57 | for i in range(iters): 58 | # E-Step 59 | ELBO_last = ELBO 60 | self.update_assignments(X) 61 | ELBO = self.ELBO() 62 | self.update_parms(X,lr) 63 | if verbose: 64 | print('Percent Change in ELBO: ',(ELBO-ELBO_last)/ELBO_last.abs()*100.0) 65 | 66 | def update_dist(self,X,lr): 67 | self.dist.raw_update(X,self.p,lr) 68 | 69 | def Elog_like(self,X): 70 | #broken for non trivial batch shape because of incompatibility in dist.batch_shape with data shape 71 | return self.dist.Elog_like(X) + self.pi.loggeomean() 72 | 73 | def KLqprior(self): 74 | KL = self.pi.KLqprior() + self.dist.KLqprior().sum(-1) 75 | for i in range(self.event_dim-1): 76 | KL = KL.sum(-1) 77 | return KL 78 | 79 | def ELBO(self): 80 | return self.logZ - self.KLqprior() 81 | 82 | def assignment_pr(self): 83 | return self.p 84 | 85 | def assignment(self): 86 | return self.p.argmax(-1) 87 | 88 | def means(self): 89 | return self.dist.mean() 90 | 91 | def event_average_f(self,function_string,A=None,keepdim=False): 92 | if A is None: 93 | return self.event_average(eval('self.dist.'+function_string)(),keepdim=keepdim) 94 | else: 95 | return self.event_average(eval('self.dist.'+function_string)(A),keepdim=keepdim) 96 | 97 | def average_f(self,function_string,A=None,keepdim=False): 98 | if A is None: 99 | return self.average(eval('self.dist.'+function_string)(),keepdim=keepdim) 100 | else: 101 | return self.average(eval('self.dist.'+function_string)(A),keepdim=keepdim) 102 
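    # Illustrative usage sketch (added commentary, not part of the original file): a
    # Gaussian mixture can be built by wrapping a NormalInverseWishart whose trailing
    # batch dimension is the number of components. The names K, D and X below are
    # placeholders, and the import assumes the repository root is on the Python path.
    #
    #   import torch
    #   from models.dists import Mixture, NormalInverseWishart
    #   K, D = 5, 2                                          # components, data dimension
    #   comp = NormalInverseWishart(mu_0=torch.zeros(K, D))  # batch_shape (K,), event_shape (D,)
    #   gmm = Mixture(comp)
    #   X = torch.randn(1000, 1, D)                          # sample_shape + (1,) + event_shape, per the comment above
    #   gmm.update(X, iters=20, lr=1.0, verbose=True)
    #   labels = gmm.assignment()                            # (1000,) hard assignments from the responsibilities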
| 103 | def average(self,A,keepdim=False): 104 | return (A*self.p).sum(-1,keepdim) 105 | 106 | ### Compute special expectations used for VB inference 107 | def event_average(self,A,keepdim=False): # returns sample_shape + W.event_shape 108 | # A is mix_batch_shape + mix_event_shape + event_shape 109 | out = (A*self.p.view(self.p.shape+(1,)*self.dist.event_dim)).sum(-1-self.dist.event_dim,keepdim) 110 | for i in range(self.event_dim-1): 111 | out = out.sum(-self.dist.event_dim-1,keepdim) 112 | return out 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /models/dists/MultivariateNormal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class MultivariateNormal(): 5 | def __init__(self,mu=None,Sigma=None,invSigmamu=None,invSigma=None): 6 | 7 | self.mu = mu 8 | self.Sigma = Sigma 9 | self.invSigmamu = invSigmamu 10 | self.invSigma = invSigma 11 | 12 | self.event_dim = 1 # final dimension is the dimension of the distribution 13 | if self.mu is not None: 14 | self.dim = mu.shape[-1] 15 | self.event_shape = mu.shape[-1:] 16 | self.batch_shape= mu.shape[:-1] 17 | elif self.invSigmamu is not None: 18 | self.dim = invSigmamu.shape[-1] 19 | self.event_shape = invSigmamu.shape[-1:] 20 | self.batch_shape = invSigmamu.shape[:-1] 21 | else: 22 | print('mu and invSigmamu are both None: cannont initialize MultivariateNormal') 23 | return None 24 | 25 | self.batch_dim = len(self.batch_shape) 26 | self.event_dim = len(self.event_shape) 27 | 28 | def to_event(self,n): 29 | if n==0: 30 | return self 31 | self.event_dim = self.event_dim + n 32 | self.batch_dim = self.batch_dim - n 33 | self.event_shape = self.batch_shape[-n:] + self.event_shape 34 | self.batch_shape = self.batch_shape[:-n] 35 | 36 | def mean(self): 37 | if self.mu is None: 38 | self.mu = (self.invSigma.inverse()*self.invSigmamu.unsqueeze(-2)).sum(-1) 39 | return self.mu 40 | 41 | def ESigma(self): 42 | if self.Sigma is None: 43 | self.Sigma = self.invSigma.inverse() 44 | return self.Sigma 45 | 46 | def EinvSigma(self): 47 | if self.invSigma is None: 48 | self.invSigma = self.Sigma.inverse() 49 | return self.invSigma 50 | 51 | def EinvSigmamu(self): 52 | if self.invSigmamu is None: 53 | self.invSigmamu = (self.EinvSigma().inverse()*self.mean().unsqueeze(-2)).sum(-1) 54 | return self.invSigmamu 55 | 56 | def ElogdetinvSigma(self): 57 | if self.Sigma is None: 58 | return self.invSigma.logdet() 59 | else: 60 | return -self.Sigma.logdet() 61 | 62 | def EX(self): 63 | return self.mean() 64 | 65 | def EXXT(self): 66 | return self.ESigma() + self.mean().unsqueeze(-1)*self.mean().unsqueeze(-2) 67 | 68 | def EXTX(self): 69 | return self.EXXT().sum(-1).sum(-1) 70 | 71 | def ss_update(self,SExx,SEx,n, lr=1.0): 72 | self.mu = SEx/n.unsqueeze(-1) 73 | self.Sigma = SExx/n.unsqueeze(-1).unsqueeze(-1) - self.mu.unsqueeze(-1)*self.mu.unsqueeze(-2) 74 | self.invSigma = None 75 | self.invSigmamu = None 76 | 77 | def raw_update(self,X,p=None,lr=1.0): # assumes X is a vector i.e. 
78 | 79 | if p is None: 80 | SEx = X 81 | SExx = X.unsqueeze(-1)*X.unsqueeze(-2) 82 | sample_shape = X.shape[:-self.event_dim-self.batch_dim] 83 | n = torch.tensor(np.prod(sample_shape),requires_grad=False) 84 | n = n.expand(self.batch_shape + self.event_shape[:-1]) 85 | while SEx.ndim>self.event_dim + self.batch_dim: 86 | SExx = SExx.sum(0) 87 | SEx = SEx.sum(0) 88 | self.ss_update(SExx,SEx,n,lr) # inputs to ss_update must be batch + event consistent 89 | 90 | else: # data is shape sample_shape x batch_shape x event_shape with the first batch dimension having size 1 91 | 92 | for i in range(self.event_dim): 93 | p=p.unsqueeze(-1) 94 | SEx = X*p 95 | SExx = X.unsqueeze(-1)*X.unsqueeze(-2)*p.unsqueeze(-1) 96 | while SEx.ndim>self.event_dim + self.batch_dim: 97 | SExx = SExx.sum(0) 98 | SEx = SEx.sum(0) 99 | p = p.sum(0) 100 | self.ss_update(SExx,SEx,p.squeeze(-1),lr) # inputs to ss_update must be batch + event consistent 101 | # p now has shape batch_shape + event_shape so it must be squeezed by the default event_shape which is 1 102 | 103 | 104 | def Elog_like(self,X): 105 | # X is num_samples x num_dists x dim 106 | # returns num_samples x num_dists 107 | # output should be num_samples 108 | 109 | out = -0.5*((X - self.mu).unsqueeze(-1)*(X-self.mu).unsqueeze(-2)*self.EinvSigma()).sum(-1).sum(-1) 110 | out = out - 0.5*self.dim*np.log(2*np.pi) + 0.5*self.ElogdetinvSigma() 111 | for i in range(self.event_dim-2): 112 | out = out.sum(-1) 113 | return out 114 | 115 | def KLqprior(self): 116 | return torch.tensor(0.0) 117 | 118 | # class MixtureofMultivariateNormals(Mixture): 119 | # def __init__(self,mu_0,Sigma_0): 120 | # dist = MultivariateNormal(mu = torch.randn(mu_0.shape)+mu_0,Sigma = Sigma_0) 121 | # super().__init__(dist) 122 | 123 | -------------------------------------------------------------------------------- /models/dists/MultivariateNormal_vector_format.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class MultivariateNormal_vector_format(): 5 | def __init__(self,mu=None,Sigma=None,invSigmamu=None,invSigma=None,Residual=None): 6 | 7 | self.mu = mu 8 | self.Sigma = Sigma 9 | self.invSigmamu = invSigmamu 10 | self.invSigma = invSigma 11 | self.Residual = Residual 12 | 13 | self.event_dim = 2 # This is because we assue that this is a distribution over vectors that are dim x 1 matrices 14 | if self.mu is not None: 15 | self.dim = mu.shape[-2] 16 | self.event_shape = mu.shape[-2:] 17 | self.batch_shape= mu.shape[:-2] 18 | elif self.invSigmamu is not None: 19 | self.dim = invSigmamu.shape[-2] 20 | self.event_shape = invSigmamu.shape[-2:] 21 | self.batch_shape = invSigmamu.shape[:-2] 22 | else: 23 | print('mu and invSigmamu are both None: cannont initialize MultivariateNormal') 24 | return None 25 | 26 | self.batch_dim = len(self.batch_shape) 27 | self.event_dim = len(self.event_shape) 28 | 29 | @property 30 | def shape(self): 31 | return self.batch_shape + self.event_shape 32 | 33 | def to_event(self,n): 34 | if n == 0: 35 | return self 36 | self.event_dim = self.event_dim + n 37 | self.batch_dim = self.batch_dim - n 38 | self.event_shape = self.batch_shape[-n:] + self.event_shape 39 | self.batch_shape = self.batch_shape[:-n] 40 | return self 41 | 42 | def unsqueeze(self,dim): # only appliles to batch 43 | assert(dim + self.event_dim < 0) 44 | if self.mu is not None: 45 | mu = self.mu.unsqueeze(dim) 46 | else: mu = None 47 | if self.Sigma is not None: 48 | Sigma = self.Sigma.unsqueeze(dim) 49 | else: 
Sigma = None 50 | if self.invSigmamu is not None: 51 | invSigmamu = self.invSigmamu.unsqueeze(dim) 52 | else: invSigmamu = None 53 | if self.invSigma is not None: 54 | invSigma = self.invSigma.unsqueeze(dim) 55 | else: invSigma = None 56 | event_dim = self.event_dim - 2 57 | return MultivariateNormal_vector_format(mu,Sigma,invSigmamu,invSigma).to_event(event_dim) 58 | 59 | def combiner(self,other): 60 | self.invSigma = self.EinvSigma()+other.EinvSigma() 61 | self.invSigmamu = self.EinvSigmamu()+other.EinvSigmamu() 62 | self.Sigma = None 63 | self.mu = None 64 | 65 | def nat_combiner(self,invSigma,invSigmamu): 66 | self.invSigma = self.EinvSigma()+invSigma 67 | self.invSigmamu = self.EinvSigmamu()+invSigmamu 68 | self.Sigma = None 69 | self.mu = None 70 | 71 | def mean(self): 72 | if self.mu is None: 73 | self.mu = self.invSigma.inverse()@self.invSigmamu 74 | return self.mu 75 | 76 | def ESigma(self): 77 | if self.Sigma is None: 78 | self.Sigma = self.invSigma.inverse() 79 | return self.Sigma 80 | 81 | def EinvSigma(self): 82 | if self.invSigma is None: 83 | self.invSigma = self.Sigma.inverse() 84 | return self.invSigma 85 | 86 | def EinvSigmamu(self): 87 | if self.invSigmamu is None: 88 | self.invSigmamu = self.EinvSigma()@self.mean() 89 | return self.invSigmamu 90 | 91 | def EResidual(self): 92 | if self.Residual is None: 93 | self.Residual = - 0.5*(self.mean()*self.EinvSigmamu()).sum(-1).sum(-1) + 0.5*self.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 94 | return self.Residual 95 | 96 | def ElogdetinvSigma(self): 97 | if self.Sigma is None: 98 | return self.invSigma.logdet() 99 | else: 100 | return -self.Sigma.logdet() 101 | 102 | def EX(self): 103 | return self.mean() 104 | 105 | def EXXT(self): 106 | return self.ESigma() + self.mean()@self.mean().transpose(-2,-1) 107 | 108 | def EXTX(self): 109 | return self.ESigma().sum(-1).sum(-1) + (self.mean().transpose(-2,-1)@self.mean()).squeeze(-1).squeeze(-1) 110 | 111 | def Res(self): 112 | return - 0.5*(self.mean()*self.EinvSigmamu()).sum(-1).sum(-1) + 0.5*self.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 113 | 114 | def ss_update(self,SExx,SEx,n, lr=1.0): 115 | n=n.unsqueeze(-1).unsqueeze(-1) 116 | self.mu = SEx/n 117 | self.Sigma = SExx/n - self.mu@self.mu.transpose(-2,-1) 118 | self.invSigma = None 119 | self.invSigmamu = None 120 | 121 | def raw_update(self,X,p=None,lr=1.0): # assumes X is a vector i.e. 
122 | 123 | 124 | if p is None: 125 | SEx = X 126 | SExx = X@X.transpose(-2,-1) 127 | sample_shape = X.shape[:-self.event_dim-self.batch_dim] 128 | n = torch.tensor(np.prod(sample_shape),requires_grad=False) 129 | n = n.expand(self.batch_shape + self.event_shape[:-2]) 130 | while SEx.ndim>self.event_dim + self.batch_dim: 131 | SExx = SExx.sum(0) 132 | SEx = SEx.sum(0) 133 | self.ss_update(SExx,SEx,n,lr) # inputs to ss_update must be batch + event consistent 134 | 135 | else: # data is shape sample_shape x batch_shape x event_shape with the first batch dimension having size 1 136 | 137 | for i in range(self.event_dim): 138 | p=p.unsqueeze(-1) 139 | SExx = X@X.transpose(-2,-1)*p 140 | SEx = X*p 141 | while SEx.ndim>self.event_dim + self.batch_dim: 142 | SExx = SExx.sum(0) 143 | SEx = SEx.sum(0) 144 | p = p.sum(0) 145 | self.ss_update(SExx,SEx,p.squeeze(-1).squeeze(-1),lr) # inputs to ss_update must be batch + event consistent 146 | # p now has shape batch_shape + event_shape so it must be squeezed by the default event_shape which is 1 147 | 148 | 149 | def Elog_like(self,X): 150 | # X is num_samples x num_dists x dim 151 | # returns num_samples x num_dists 152 | # output should be num_samples 153 | 154 | out = -0.5*((X - self.mu).transpose(-2,-1)@self.EinvSigma()@(X - self.mu)).squeeze(-1).squeeze(-1) 155 | out = out - 0.5*self.dim*np.log(2*np.pi) + 0.5*self.ElogdetinvSigma() 156 | for i in range(self.event_dim-2): 157 | out = out.sum(-1) 158 | return out 159 | 160 | def KLqprior(self): 161 | return torch.tensor(0.0,requires_grad=False) 162 | 163 | # from .Mixture import Mixture 164 | 165 | # class MixtureofMultivariateNormals_vector_format(Mixture): 166 | # def __init__(self,mu_0,Sigma_0): 167 | # dist = MultivariateNormal_vector_format(mu = torch.randn(mu_0.shape,requires_grad=False)+mu_0,Sigma = Sigma_0) 168 | # super().__init__(dist) 169 | 170 | 171 | -------------------------------------------------------------------------------- /models/dists/NormalGamma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .Gamma import Gamma 4 | 5 | class NormalGamma(): 6 | def __init__(self,lambda_mu_0,mu_0,alpha_0,beta_0): 7 | 8 | self.dim = mu_0.shape[-1] 9 | self.event_dim = 1 10 | self.event_shape = mu_0.shape[-1:] 11 | self.batch_dim = mu_0.ndim - 1 12 | self.batch_shape = mu_0.shape[:-1] 13 | 14 | self.lambda_mu_0 = lambda_mu_0 15 | self.lambda_mu = self.lambda_mu_0 16 | self.mu_0 = mu_0 17 | self.gamma = Gamma(alpha_0,beta_0).to_event(1) 18 | self.mu = mu_0 + torch.randn(mu_0.shape,requires_grad=False)*self.gamma.mean().sqrt() 19 | 20 | def mean(self): 21 | return self.mu 22 | 23 | def Emumu(self): 24 | return self.mu.unsqueeze(-2)*self.mu.unsqueeze(-1) + self.ESigma()/self.lambda_mu.unsqueeze(-1).unsqueeze(-1) 25 | 26 | def ElogdetinvSigma(self): 27 | return self.gamma.loggeomean().sum(-1) 28 | 29 | def EmuTinvSigmamu(self): 30 | return (self.mu**2*self.gamma.mean()).sum(-1) + self.dim/self.lambda_mu 31 | 32 | def EXTinvUX(self): 33 | return (self.mu**2*self.gamma.mean()).sum(-1) + self.dim/self.lambda_mu 34 | 35 | def EinvSigma(self): 36 | return self.gamma.mean().unsqueeze(-1)*torch.eye(self.dim,requires_grad=False) 37 | 38 | def ESigma(self): 39 | return self.gamma.meaninv().unsqueeze(-1)*torch.eye(self.dim,requires_grad=False) 40 | 41 | def Res(self): 42 | return -0.5*self.EXTinvUX() + 0.5*self.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 43 | 44 | def EinvSigmamu(self): 45 | return 
self.gamma.mean()*self.mu 46 | 47 | def to_event(self,n): 48 | if n == 0: 49 | return self 50 | self.event_dim = self.event_dim + n 51 | self.batch_dim = self.batch_dim - n 52 | self.event_shape = self.batch_shape[-n:] + self.event_shape 53 | self.batch_shape = self.batch_shape[:-n] 54 | self.gamma.to_event(n) 55 | return self 56 | 57 | def ss_update(self,SExx,SEx,n, lr=1.0): 58 | 59 | lambda_mu = self.lambda_mu_0 + n 60 | mu = (self.lambda_mu_0.unsqueeze(-1)*self.mu_0 + SEx)/lambda_mu.unsqueeze(-1) 61 | SExx = SExx + self.lambda_mu_0.unsqueeze(-1)*self.mu_0**2 - lambda_mu.unsqueeze(-1)*mu**2 62 | 63 | self.lambda_mu = (lambda_mu-self.lambda_mu)*lr + self.lambda_mu 64 | self.mu = (mu-self.mu)*lr + self.mu 65 | 66 | self.gamma.ss_update(0.5*n.unsqueeze(-1),0.5*SExx) 67 | 68 | def raw_update(self,X,p=None,lr=1.0): 69 | 70 | if p is None: # data is sample_shape + batch_shape + event_event_shape 71 | SEx = X 72 | SExx = X**2 73 | sample_shape = X.shape[:-self.event_dim-self.batch_dim] 74 | n = torch.tensor(np.prod(sample_shape),requires_grad=False) 75 | n = n.expand(self.batch_shape + self.event_shape[:-1]) 76 | while SEx.ndim>self.event_dim + self.batch_dim: 77 | SExx = SExx.sum(0) 78 | SEx = SEx.sum(0) 79 | self.ss_update(SExx,SEx,n,lr) # inputs to ss_update must be batch + event consistent 80 | 81 | else: # data is sample_shape + batch_shape* + event_shape and p is num_samples x batch_shape 82 | # batch_shape* can be (1,)*batch_dim 83 | for i in range(self.event_dim): 84 | p=p.unsqueeze(-1) 85 | SEx = X*p 86 | SExx = (X**2*p) 87 | while SEx.ndim>self.event_dim + self.batch_dim: 88 | SExx = SExx.sum(0) 89 | SEx = SEx.sum(0) 90 | p = p.sum(0) 91 | self.ss_update(SExx,SEx,p.squeeze(-1),lr) # inputs to ss_update must be batch + event consistent 92 | 93 | 94 | def Elog_like(self,X): 95 | # X is num_samples x num_dists x dim 96 | # returns num_samples x num_dists 97 | # output should be num_samples 98 | 99 | out = -0.5*(X.pow(2)*self.gamma.mean()).sum(-1) + (X*self.EinvSigmamu()).sum(-1) - 0.5*(self.EXTinvUX()) 100 | out = out + 0.5*self.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 101 | 102 | 103 | 104 | out = -0.5*((X - self.mu)**2*self.gamma.mean()).sum(-1) + 0.5*self.gamma.loggeomean().sum(-1) - 0.5*self.dim*np.log(2*np.pi) 105 | for i in range(self.event_dim-1): 106 | out = out.sum(-1) 107 | return out 108 | 109 | def KLqprior(self): 110 | 111 | out = self.lambda_mu_0/2.0*((self.mu-self.mu_0)**2*self.gamma.mean()).sum(-1) 112 | out = out + self.dim/2.0*(self.lambda_mu_0/self.lambda_mu - (self.lambda_mu_0/self.lambda_mu).log() -1) 113 | for i in range(self.event_dim-1): 114 | out = out.sum(-1) 115 | return out + self.gamma.KLqprior().sum(-1) 116 | 117 | # from .Mixture import Mixture 118 | # class MixtureofNormalGammas(Mixture): 119 | # def __init__(self,dim,n): 120 | # dist = NormalGamma(torch.ones(dim,requires_grad=False), 121 | # torch.zeros(dim,n,requires_grad=False), 122 | # torch.ones(dim,n,requires_grad=False), 123 | # torch.ones(dim,n,requires_grad=False), 124 | # ) 125 | # super().__init__(dist) 126 | -------------------------------------------------------------------------------- /models/dists/NormalInverseWishart.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .Wishart import Wishart 4 | 5 | class NormalInverseWishart(): 6 | 7 | def __init__(self,lambda_mu_0=None,mu_0=None,nu_0=None,invV_0=None): 8 | 9 | self.event_dim = 1 10 | self.event_shape = mu_0.shape[-1:] 11 | self.batch_shape = 
mu_0.shape[:-1] 12 | self.batch_dim = mu_0.ndim - self.event_dim 13 | self.dim = mu_0.shape[-1] 14 | if lambda_mu_0 is None: 15 | self.lambda_mu_0 = torch.ones(mu_0.shape[:-1],requires_grad=False) 16 | else: 17 | self.lambda_mu_0 = lambda_mu_0 18 | if nu_0 is None: 19 | nu_0 = torch.ones(mu_0.shape[:-1],requires_grad=False)*(self.dim+2) 20 | if invV_0 is None: 21 | invV_0 = torch.zeros(mu_0.shape+(self.dim,),requires_grad=False)+torch.eye(self.dim,requires_grad=False) 22 | 23 | self.lambda_mu = self.lambda_mu_0 24 | self.mu_0 = mu_0 25 | self.mu = mu_0 + torch.randn(mu_0.shape,requires_grad=False) 26 | 27 | self.invU = Wishart(nu_0,invV_0) 28 | 29 | def mean(self): 30 | return self.mu 31 | 32 | def EX(self): 33 | return self.mu 34 | 35 | def EXXT(self): 36 | return self.mu.unsqueeze(-1)*self.mu.unsqueeze(-2) + self.invU.ESigma()/self.lambda_mu.unsqueeze(-1).unsqueeze(-1) 37 | 38 | def ESigma(self): 39 | return self.invU.ESigma() 40 | 41 | def ElogdetinvSigma(self): 42 | return self.invU.ElogdetinvSigma() 43 | 44 | def EinvSigmamu(self): 45 | return (self.invU.EinvSigma()*self.mu.unsqueeze(-2)).sum(-1) 46 | 47 | def EinvSigma(self): 48 | return self.invU.EinvSigma() 49 | 50 | def EinvUX(self): 51 | return (self.invU.EinvSigma()*self.mu.unsqueeze(-2)).sum(-1) 52 | 53 | def EXTinvUX(self): 54 | return (self.mu.unsqueeze(-1)*self.invU.EinvSigma()*self.mu.unsqueeze(-2)).sum(-1).sum(-1) + self.dim/self.lambda_mu 55 | 56 | def to_event(self,n): 57 | if n ==0: 58 | return self 59 | self.event_dim = self.event_dim + n 60 | self.batch_dim = self.batch_dim - n 61 | self.event_shape = self.batch_shape[-n:] + self.event_shape 62 | self.batch_shape = self.batch_shape[:-n] 63 | self.invU.to_event(n) 64 | return self 65 | 66 | def ss_update(self,SExx,SEx,n, lr=1.0): 67 | # SExx is batch_shape + event_shape + (dim,) 68 | # SEx is batch_shape + event_shape 69 | # n is batch_shape + event_shape[:-1] 70 | 71 | lambda_mu = self.lambda_mu_0 + n 72 | mu = (self.lambda_mu_0.unsqueeze(-1)*self.mu_0 + SEx)/lambda_mu.unsqueeze(-1) 73 | # invV = SExx + self.lambda_mu_0.unsqueeze(-1).unsqueeze(-1)*self.mu_0.unsqueeze(-1)*self.mu_0.unsqueeze(-2) - lambda_mu.unsqueeze(-1).unsqueeze(-1)*mu.unsqueeze(-1)*mu.unsqueeze(-2) 74 | invV = SExx + self.lambda_mu_0.unsqueeze(-1).unsqueeze(-1)*self.mu_0.unsqueeze(-1)*self.mu_0.unsqueeze(-2) - n.unsqueeze(-1).unsqueeze(-1)*mu.unsqueeze(-1)*mu.unsqueeze(-2) 75 | 76 | self.lambda_mu = (lambda_mu-self.lambda_mu)*lr + self.lambda_mu 77 | self.mu = (mu-self.mu)*lr + self.mu 78 | self.invU.ss_update(invV,n,lr) 79 | 80 | def raw_update(self,X,p=None,lr=1.0): 81 | # assumes data is num_samples (Times) x batch_shape x evevnt_dim 82 | # if specified p has shape num_samples x batch_shape 83 | # the critical manipulation here is that p averages over the batch dimension 84 | 85 | if p is None: 86 | SEx = X 87 | SExx = X.unsqueeze(-1)*X.unsqueeze(-2) 88 | sample_shape = X.shape[:-self.event_dim-self.batch_dim] 89 | n = torch.tensor(np.prod(sample_shape),requires_grad=False) 90 | n = n.expand(self.batch_shape + self.event_shape[:-1]) 91 | while SEx.ndim>self.event_dim + self.batch_dim: 92 | SExx = SExx.sum(0) 93 | SEx = SEx.sum(0) 94 | self.ss_update(SExx,SEx,n,lr) # inputs to ss_update must be batch + event consistent 95 | 96 | else: # data is shape sample_shape x batch_shape x event_shape with the first batch dimension having size 1 97 | 98 | for i in range(self.event_dim): 99 | p=p.unsqueeze(-1) 100 | SExx = X.unsqueeze(-1)*X.unsqueeze(-2)*p.unsqueeze(-1) 101 | SEx = X*p 102 | while 
SEx.ndim>self.event_dim + self.batch_dim: 103 | SExx = SExx.sum(0) 104 | SEx = SEx.sum(0) 105 | p = p.sum(0) 106 | # p now has shape batch_shape + event_shape so it must be squeezed by the default event_shape which is 1 107 | self.ss_update(SExx,SEx,p.squeeze(-1),lr) # inputs to ss_update must be batch + event consistent 108 | 109 | def Elog_like(self,X): 110 | # X is num_samples x batch_shape x event_shape OR num_samples x (1,)*batch_dim x event_shape 111 | 112 | out = -0.5*((X.unsqueeze(-1)*self.EinvSigma()).sum(-2)*X).sum(-1) + (X*self.EinvSigmamu()).sum(-1) - 0.5*(self.EXTinvUX()) 113 | out = out + 0.5*self.ElogdetinvSigma() - 0.5*self.dim*np.log(2*np.pi) 114 | 115 | for i in range(self.event_dim-1): 116 | out = out.sum(-1) 117 | return out 118 | 119 | def KLqprior(self): 120 | KL = 0.5*(self.lambda_mu_0/self.lambda_mu - 1 + (self.lambda_mu/self.lambda_mu_0).log())*self.dim 121 | KL = KL + 0.5*self.lambda_mu_0*((self.mu-self.mu_0).unsqueeze(-1)*(self.mu-self.mu_0).unsqueeze(-2)*self.invU.mean()).sum(-1).sum(-1) 122 | for i in range(self.event_dim-1): 123 | KL = KL.sum(-1) 124 | KL = KL + self.invU.KLqprior() 125 | return KL 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /models/dists/Wishart.py: -------------------------------------------------------------------------------- 1 | # Implements Wishart distribution and associated natural parameter updates. This could be made more memory efficient by 2 | # using the eigenvalue decomposition for all calculation instead so simultaneously storing invU and U. Which is to say that 3 | # currently it uses 3x more memory than is really needed. We could fix this by replacing invU and U with @property methods 4 | # that compute them using invU = self.v@(self.d.unsqueeze(-1)*self.v.transpose(-2,-1)) and U = self.v@(1.0/self.d.unsqueeze(-1)*self.v.transpose(-2,-1)) 5 | 6 | import numpy as np 7 | import torch 8 | 9 | class Wishart(): 10 | 11 | def __init__(self,nu,U): #nu, invU are natural parameters, nu*U is expected value of precision matrix 12 | 13 | self.dim = U.shape[-1] 14 | self.event_dim = 2 15 | self.batch_dim = U.ndim-2 16 | self.event_shape = U.shape[-2:] 17 | self.batch_shape = U.shape[:-2] 18 | 19 | self.invU_0 = U.inverse() 20 | self.logdet_invU_0 = self.invU_0.logdet() 21 | self.invU = self.invU_0 22 | self.U = U 23 | self.nu_0 = nu 24 | self.nu = nu 25 | self.logdet_invU = self.invU.logdet() 26 | 27 | def to_event(self,n): 28 | if n ==0: 29 | return self 30 | self.event_dim = self.event_dim + n 31 | self.batch_dim = self.batch_dim - n 32 | self.event_shape = self.batch_shape[-n:] + self.event_shape 33 | self.batch_shape = self.batch_shape[:-n] 34 | return self 35 | 36 | def log_mvgamma(self,nu): 37 | return (nu.unsqueeze(-1) - torch.arange(self.dim)/2.0).lgamma().sum(-1) 38 | 39 | def log_mvdigamma(self,nu): 40 | return (nu.unsqueeze(-1) - torch.arange(self.dim)/2.0).digamma().sum(-1) 41 | 42 | def ss_update(self,SExx,n,lr=1.0): 43 | idx = n>1 44 | SExx = SExx*(idx).unsqueeze(-1).unsqueeze(-1) 45 | self.invU = (self.invU_0 + SExx)*lr + (1-lr)*self.invU 46 | self.nu = (self.nu_0 + n)*lr + (1-lr)*self.nu 47 | self.U = self.invU.inverse() 48 | self.logdet_invU = self.invU.logdet() 49 | 50 | 51 | # idx = ~(self.logdet_invU>self.logdet_invU_0) 52 | # if idx.sum()>0: 53 | # print('Wishart ss_update hack triggered at',idx.sum(),'locations') 54 | # print(idx) 55 | # self.invU[idx] = self.invU_0[idx] 56 | # self.U[idx] = self.invU_0[idx].inverse() 57 | # self.nu[idx] = self.nu_0[idx] 58 
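        # Note (added commentary): ss_update above applies the conjugate natural-parameter
        # update with a learning-rate relaxation,
        #   invU <- lr*(invU_0 + SExx) + (1-lr)*invU,   nu <- lr*(nu_0 + n) + (1-lr)*nu,
        # after zeroing SExx wherever n <= 1. The expectations used elsewhere in the code
        # follow directly: EinvSigma() = nu*U, ESigma() = invU/(nu - dim - 1), and
        # ElogdetinvSigma() = dim*log(2) - logdet(invU) + sum_i digamma((nu - i)/2).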
| # self.logdet_invU[idx] = self.logdet_invU_0[idx] 59 | 60 | def mean(self): 61 | return self.U*self.nu.unsqueeze(-1).unsqueeze(-1) 62 | 63 | def meaninv(self): 64 | return self.invU/(self.nu.unsqueeze(-1).unsqueeze(-1) - self.dim - 1) 65 | 66 | def ESigma(self): 67 | return self.invU/(self.nu.unsqueeze(-1).unsqueeze(-1) - self.dim - 1) 68 | 69 | def EinvSigma(self): 70 | return self.U*self.nu.unsqueeze(-1).unsqueeze(-1) 71 | 72 | def ElogdetinvSigma(self): 73 | return self.dim*np.log(2) - self.logdet_invU + ((self.nu.unsqueeze(-1) - torch.arange(self.dim))/2.0).digamma().sum(-1) 74 | 75 | def KLqprior(self): 76 | out = self.nu_0/2.0*(self.logdet_invU-self.logdet_invU_0) + self.nu/2.0*(self.invU_0*self.U).sum(-1).sum(-1) - self.nu*self.dim/2.0 77 | out = out + self.log_mvgamma(self.nu_0/2.0) - self.log_mvgamma(self.nu/2.0) + (self.nu - self.nu_0)/2.0*self.log_mvdigamma(self.nu/2.0) 78 | 79 | for i in range(self.event_dim -2): 80 | out = out.sum(-1) 81 | return out 82 | 83 | def logZ(self): 84 | return self.log_mvgamma(self.nu/2.0) + 0.5*self.nu*self.dim*np.log(2) - 0.5*self.nu*self.logdet_invU 85 | 86 | 87 | class Wishart_eigh(): 88 | 89 | def __init__(self,nu,U): #nu, invU are natural parameters, nu*U is expected value of precision matrix 90 | 91 | self.dim = U.shape[-1] 92 | self.event_dim = 2 93 | self.batch_dim = U.ndim-2 94 | self.event_shape = U.shape[-2:] 95 | self.batch_shape = U.shape[:-2] 96 | 97 | self.d, self.v = torch.linalg.eigh(U) 98 | self.d = 1.0/self.d 99 | self.invU_0 = self.v@(self.d.unsqueeze(-1)*self.v.transpose(-2,-1)) 100 | self.logdet_invU_0 = self.d.log().sum(-1) 101 | self.nu_0 = nu 102 | self.nu = self.nu_0 103 | 104 | @property 105 | def U(self): 106 | return self.v@(1.0/self.d.unsqueeze(-1)*self.v.transpose(-2,-1)) 107 | 108 | @property 109 | def invU(self): 110 | return self.v@(self.d.unsqueeze(-1)*self.v.transpose(-2,-1)) 111 | 112 | @property 113 | def logdet_invU(self): 114 | return self.d.log().sum(-1) 115 | 116 | def to_event(self,n): 117 | if n ==0: 118 | return self 119 | self.event_dim = self.event_dim + n 120 | self.batch_dim = self.batch_dim - n 121 | self.event_shape = self.batch_shape[-n:] + self.event_shape 122 | self.batch_shape = self.batch_shape[:-n] 123 | return self 124 | 125 | def log_mvgamma(self,nu): 126 | return (nu.unsqueeze(-1) - torch.arange(self.dim)/2.0).lgamma().sum(-1) 127 | 128 | def log_mvdigamma(self,nu): 129 | return (nu.unsqueeze(-1) - torch.arange(self.dim)/2.0).digamma().sum(-1) 130 | 131 | def ss_update(self,SExx,n,lr=1.0): 132 | idx = n>1 133 | SExx = SExx*(idx).unsqueeze(-1).unsqueeze(-1) 134 | invU = (self.invU_0 + SExx)*lr + (1-lr)*self.invU 135 | self.nu = (self.nu_0 + n)*lr + (1-lr)*self.nu 136 | self.d, self.v = torch.linalg.eigh(0.5*invU+0.5*invU.transpose(-2,-1)) # recall v@d@v.transpose(-2,-1) = invU 137 | 138 | def nat_update(self,nu,invU): 139 | self.nu = nu 140 | self.d, self.v = torch.linalg.eigh(0.5*invU+0.5*invU.transpose(-2,-1)) # recall v@d@v.transpose(-2,-1) = invU 141 | 142 | def mean(self): 143 | return self.U*self.nu.unsqueeze(-1).unsqueeze(-1) 144 | 145 | def meaninv(self): 146 | return self.invU/(self.nu.unsqueeze(-1).unsqueeze(-1) - self.dim - 1) 147 | 148 | def ESigma(self): 149 | return self.invU/(self.nu.unsqueeze(-1).unsqueeze(-1) - self.dim - 1) 150 | 151 | def EinvSigma(self): 152 | return self.U*self.nu.unsqueeze(-1).unsqueeze(-1) 153 | 154 | def ElogdetinvSigma(self): 155 | return self.dim*np.log(2) - self.logdet_invU + ((self.nu.unsqueeze(-1) - 
torch.arange(self.dim))/2.0).digamma().sum(-1) 156 | 157 | def KLqprior(self): 158 | out = self.nu_0/2.0*(self.logdet_invU-self.logdet_invU_0) + self.nu/2.0*(self.invU_0*self.U).sum(-1).sum(-1) - self.nu*self.dim/2.0 159 | out = out + self.log_mvgamma(self.nu_0/2.0) - self.log_mvgamma(self.nu/2.0) + (self.nu - self.nu_0)/2.0*self.log_mvdigamma(self.nu/2.0) 160 | 161 | for i in range(self.event_dim -2): 162 | out = out.sum(-1) 163 | return out 164 | 165 | def logZ(self): 166 | return self.log_mvgamma(self.nu/2.0) + 0.5*self.nu*self.dim*np.log(2) - 0.5*self.nu*self.logdet_invU 167 | 168 | 169 | class Wishart_UnitTrace(Wishart_eigh): 170 | 171 | def suminv_d_plus_x(self,x): 172 | return self.nu*(1.0/(self.d+x)).sum(-1) 173 | 174 | def suminv_d_plus_x_prime(self,x): 175 | return -self.nu*(1.0/(self.d+x)**2).sum(-1) 176 | 177 | def ss_update(self,SExx,n,lr=1.0,iters=8): 178 | super().ss_update(SExx,n,lr=lr) 179 | x=self.d.mean(-1) 180 | for i in range(iters): 181 | x = x + (self.dim-self.suminv_d_plus_x(x))/self.suminv_d_plus_x_prime(x) 182 | x[x<-self.d.min()] = -self.d.min()+1e-6 # ensure positive definite 183 | self.d = self.d+x 184 | 185 | 186 | class Wishart_UnitDet(Wishart_eigh): 187 | 188 | def log_mvdigamma_prime(self,nu): 189 | return (nu.unsqueeze(-1) - torch.arange(self.dim)/2.0).polygamma(1).sum(-1) 190 | 191 | def ss_update(self,SExx,n,lr=1.0,iters=4): 192 | super().ss_update(SExx,n,lr=lr) 193 | log_mvdigamma_target = -self.dim*np.log(2) + self.logdet_invU 194 | lognu = (log_mvdigamma_target/self.dim) 195 | for k in range(iters): 196 | lognu = lognu + (log_mvdigamma_target-self.log_mvdigamma(lognu.exp()))/self.log_mvdigamma_prime(lognu.exp())*(-lognu).exp() 197 | self.nu = 2.0*lognu.exp() 198 | -------------------------------------------------------------------------------- /models/dists/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .Dirichlet import Dirichlet 4 | from .Mixture import Mixture 5 | from .Delta import Delta 6 | from .Gamma import Gamma 7 | from .Wishart import Wishart 8 | from .Wishart import Wishart_UnitTrace 9 | from .Wishart import Wishart_UnitDet 10 | from .MultivariateNormal import MultivariateNormal 11 | from .MultivariateNormal_vector_format import MultivariateNormal_vector_format 12 | from .MVN_ard import MVN_ard 13 | from .NormalGamma import NormalGamma 14 | from .NormalInverseWishart import NormalInverseWishart 15 | from .DiagonalWishart import DiagonalWishart 16 | from .DiagonalWishart import DiagonalWishart_UnitTrace 17 | from .MatrixNormalGamma import MatrixNormalGamma 18 | from .MatrixNormalGamma import MatrixNormalGamma_UnitTrace 19 | from .MatrixNormalWishart import MatrixNormalWishart 20 | from .TensorNormalWishart import TensorNormalWishart 21 | -------------------------------------------------------------------------------- /models/dists/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .matrix_utils import matrix_utils -------------------------------------------------------------------------------- /models/dists/utils/matrix_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | class matrix_utils(): 3 | 4 | def block_diag_matrix_builder(A,B): 5 | # builds a block matrix [[A,B],[C,D]] out of compatible tensors 6 | n1 = A.shape[-1] 7 | n2 = B.shape[-1] 8 | t_shape = A.shape[:-2] 9 | return torch.cat((torch.cat((A, 
torch.zeros(t_shape + (n1,n2),requires_grad=False)),-1),torch.cat((torch.zeros(t_shape + (n2,n1),requires_grad=False), B),-1)),-2) 10 | 11 | def block_matrix_inverse(A,B,C,D,block_form=True): 12 | # inverts a block matrix of the form [A B; C D] and returns the blocks [Ainv Binv; Cinv Dinv] 13 | invA = A.inverse() 14 | invD = D.inverse() 15 | Ainv = (A - B@invD@C).inverse() 16 | Dinv = (D - C@invA@B).inverse() 17 | 18 | if(block_form == 'left'): # left decomposed returns abcd.inverse = [A 0; 0 D] @ [eye B; C eye] 19 | return Ainv, -B@invD, -C@invA, Dinv 20 | elif(block_form == 'right'): # right decomposed returns abcd.inverse = [eye B; C eye] @ [A 0; 0 D] 21 | return Ainv, -invA@B, -invD@C, Dinv 22 | elif(block_form == 'True'): 23 | return Ainv, -Ainv@B@Dinv, -invD@C@invA, Dinv 24 | else: 25 | return torch.cat((torch.cat((Ainv, -invA@B@Dinv),-1),torch.cat((-invD@C@Ainv, Dinv),-1)),-2) 26 | 27 | def block_matrix_builder(A,B,C,D): 28 | # builds a block matrix [[A,B],[C,D]] out of compatible tensors 29 | return torch.cat((torch.cat((A, B),-1),torch.cat((C, D),-1)),-2) 30 | 31 | def block_precision_marginalizer(A,B,C,D): 32 | # When computing the precision of marginals, A - B@invD@C, does not need to be inverted 33 | # This is because (A - B@invD@C).inverse is the marginal covariance, the inverse of which is precsion 34 | # As a result in many applications we can save on computation by returning the inverse of Joint Precision 35 | # in the form [A_prec 0; 0 D_prec] @ [eye B; C eye]. This is particularly useful when computing 36 | # marginal invSigma and invSigmamu since invSigma_A = A_prec 37 | # invSigmamu_A = invSigmamu_J_A - B@invD@invSigmamu_J_B 38 | # invSigma_D = D_prec 39 | # invSigmamu_D = invSigmamu_J_D - C@invA@invSigmamu_J_A 40 | 41 | invA = A.inverse() 42 | invD = D.inverse() 43 | A_prec = (A - B@invD@C) 44 | D_prec = (D - C@invA@B) 45 | 46 | return A_prec, -B@invD, -C@invA, D_prec 47 | 48 | 49 | def block_matrix_logdet(A,B,C,D,singular=False): 50 | if(singular == 'A'): 51 | return D.logdet() + (A - B@D.inverse()@C).logdet() 52 | elif(singular == 'D'): 53 | return A.logdet() + (D - C@A.inverse()@B).logdet() 54 | else: 55 | return D.logdet() + (A - B@D.inverse()@C).logdet() 56 | 57 | -------------------------------------------------------------------------------- /models/rHMM.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from .dists import Dirichlet 5 | from .MultiNomialLogisticRegression import MultiNomialLogisticRegression as MNLR 6 | class rHMM(): 7 | 8 | def __init__(self, obs_dist, p, transition_mask=None): 9 | print('work in progress: fix transition_loggoemean and add mask') 10 | self.obs_dist = obs_dist 11 | # assume that the first dimension the batch_shape is the dimension of the HMM 12 | self.hidden_dim = obs_dist.batch_shape[-1] 13 | self.event_dim = 1 14 | self.event_shape = (self.hidden_dim,) 15 | self.batch_shape = obs_dist.batch_shape[:-1] 16 | self.batch_dim = len(self.batch_shape) 17 | self.transition_mask = transition_mask 18 | 19 | self.transition = MNLR(n,p,batch_shape = (n,),pad_X=True) 20 | self.initial = Dirichlet(0.5*torch.ones(self.batch_shape+(self.hidden_dim,),requires_grad=False)) 21 | self.initial.alpha = self.initial.alpha_0 22 | self.sumlogZ = -torch.inf 23 | self.p = None 24 | 25 | def to_event(self,n): 26 | if n < 1: 27 | return self 28 | self.event_dim = self.event_dim + n 29 | self.batch_dim = self.batch_dim - n 30 | self.event_shape = self.batch_shape[-n:] + 
self.event_shape 31 | self.batch_shape = self.batch_shape[:-n] 32 | return self 33 | 34 | def transition_loggeomean(self): 35 | if self.transition_mask is None: 36 | return self.transition.loggeomean() 37 | else: 38 | return self.transition.loggeomean() + self.transition_mask.log() 39 | 40 | def logmatmulexp(self,x,y): 41 | 42 | x_shift = x.max(-1, keepdim=True)[0] 43 | y_shift = y.max(-2, keepdim=True)[0] 44 | xy = torch.matmul((x - x_shift).exp(), (y - y_shift).exp()).log() 45 | return xy + x_shift + y_shift 46 | 47 | def forward_step(self,logits,observation_logits): 48 | return (logits.unsqueeze(-1) + observation_logits.unsqueeze(-2) + self.transition_loggeomean()).logsumexp(-2) 49 | 50 | def backward_step(self,logits,observation_logits): 51 | return (logits.unsqueeze(-2) + observation_logits.unsqueeze(-2) + self.transition_loggeomean()).logsumexp(-1) 52 | 53 | def forward_backward_logits(self,fw_logits): 54 | # Assumes that time is in the first dimension of the observation 55 | # On input fw_logits = observation_logits. 56 | # T = observation_logits.shape[0] 57 | T = fw_logits.shape[0] 58 | 59 | # logits = self.transition_loggeomean() + observation_logits.unsqueeze(-2) 60 | # fw_logits = torch.zeros(observation_logits.shape,requires_grad=False) 61 | # fw_logits[0] = (logits[0] + self.initial.loggeomean().unsqueeze(-1)).logsumexp(-2) 62 | 63 | fw_logits[0] = (fw_logits[0].unsqueeze(-2) + self.initial.loggeomean().unsqueeze(-1)).logsumexp(-2) 64 | 65 | for t in range(1,T): 66 | # fw_logits[t] = (fw_logits[t-1].unsqueeze(-1) + logits[t]).logsumexp(-2) 67 | fw_logits[t] = (fw_logits[t-1].unsqueeze(-1) + fw_logits[t].unsqueeze(-2) + self.transition_loggeomean()).logsumexp(-2) 68 | logZ = fw_logits[-1].logsumexp(-1,True) 69 | fw_logits = fw_logits - logZ 70 | logZ = logZ.squeeze(-1) 71 | SEzz = torch.zeros(fw_logits.shape[1:]+(self.hidden_dim,),requires_grad=False) 72 | for t in range(T-2,-1,-1): 73 | ### Backward Smoothing 74 | temp = fw_logits[t].unsqueeze(-1) + self.transition_loggeomean() 75 | xi_logits = (temp - temp.logsumexp(-2,keepdim=True)) + fw_logits[t+1].unsqueeze(-2) 76 | fw_logits[t] = xi_logits.logsumexp(-1) 77 | xi_logits = (xi_logits - xi_logits.logsumexp([-1,-2], keepdim=True)) 78 | SEzz = SEzz + xi_logits.exp() 79 | 80 | # Now do the initial step 81 | # Backward Smoothing 82 | temp = self.initial.loggeomean().unsqueeze(-1) + self.transition_loggeomean() 83 | xi_logits = (temp - temp.logsumexp(-2,keepdim=True)) + fw_logits[0].unsqueeze(-2) 84 | SEz0 = xi_logits.logsumexp(-1) 85 | SEz0 = (SEz0-SEz0.logsumexp(-1,True)).exp() 86 | xi_logits = (xi_logits - xi_logits.logsumexp([-1,-2], keepdim=True)) 87 | SEzz = SEzz + xi_logits.exp() 88 | # Backward inference 89 | # bw_logits = bw_logits.unsqueeze(-2) + logits[0] 90 | # xi_logits = self.initial.loggeomean().unsqueeze(-1) + bw_logits 91 | # xi_logits = (xi_logits - xi_logits.logsumexp([-1,-2], keepdim=True)) 92 | # SEzz = SEzz + xi_logits.exp() 93 | # bw_logits = self.initial.loggeomean() + bw_logits.logsumexp(-1) 94 | # SEz0 = (bw_logits - bw_logits.max(-1,keepdim=True)[0]).exp() 95 | # SEz0 = SEz0/SEz0.sum(-1,True) 96 | 97 | self.p = (fw_logits - fw_logits.max(-1,keepdim=True)[0]).exp() 98 | self.p = self.p/self.p.sum(-1,keepdim=True) 99 | 100 | return SEzz, SEz0, logZ # Note that only Time has been integrated out of sufficient statistics 101 | # and the despite the name fw_logits is posterior probability of states 102 | def assignment_pr(self): 103 | return self.p 104 | 105 | def assignment(self): 106 | return 
self.p.argmax(-1) 107 | 108 | def obs_logits(self,X): 109 | return self.obs_dist.Elog_like(X) 110 | 111 | def update_states(self,X): 112 | # updates states and stores in self.p 113 | # also updates sufficient statistics of Markov process (self.SEzz, self.SEz0) and self.logZ and self.sumlogZ 114 | SEzz, SEz0, logZ = self.forward_backward_logits(self.obs_logits(X)) # recall that time has been integrated out except for p. 115 | NA = self.p.sum(0) # also integrate out time for NA 116 | self.logZ = logZ 117 | while NA.ndim > self.batch_dim + self.event_dim: # sum out the sample shape 118 | NA = NA.sum(0) 119 | SEzz = SEzz.sum(0) 120 | SEz0 = SEz0.sum(0) 121 | logZ = logZ.sum(0) 122 | self.SEzz = SEzz 123 | self.SEz0 = SEz0 124 | self.NA=NA 125 | self.sumlogZ = logZ 126 | 127 | def update_markov_parms(self,lr=1.0): 128 | self.transition.ss_update(self.SEzz,lr) 129 | self.initial.ss_update(self.SEz0,lr) 130 | 131 | def update_obs_parms(self,X,lr=1.0): 132 | self.obs_dist.raw_update(X,self.p,lr) 133 | 134 | def update_parms(self,X,lr=1.0): 135 | self.transition.ss_update(self.SEzz,lr) 136 | self.initial.ss_update(self.SEz0,lr) 137 | self.update_obs_parms(X,self.p,lr) 138 | 139 | def update(self,X,iters=1,lr=1.0,verbose=False): 140 | 141 | ELBO = -np.inf 142 | for i in range(iters): 143 | ELBO_last = ELBO 144 | self.update_states(X) 145 | self.KLqprior_last = self.KLqprior() 146 | self.update_markov_parms(lr) 147 | self.update_obs_parms(X,lr) 148 | 149 | ELBO = self.ELBO().sum() 150 | if verbose: 151 | print('Percent Change in ELBO = %f' % ((ELBO-ELBO_last)/np.abs(ELBO_last)*100)) 152 | 153 | def Elog_like(self,X): # assumes that p is up to date 154 | ELL = (self.obs_dist.Elog_like(X)*self.p).sum(-1) 155 | for i in range(self.event_dim - 1): 156 | ELL = ELL.sum(-1) 157 | return ELL 158 | 159 | def KLqprior(self): 160 | KL = self.obs_dist.KLqprior().sum(-1) + self.transition.KLqprior() + self.initial.KLqprior() # assumes default event_dim = 1 161 | for i in range(self.event_dim - 1): 162 | KL = KL.sum(-1) 163 | return KL 164 | 165 | def ELBO(self): 166 | return self.sumlogZ - self.KLqprior() 167 | 168 | def event_average_f(self,function_string,keepdim=False): 169 | return self.event_average(eval('self.obs_dist.'+function_string)(),keepdim) 170 | 171 | def average_f(self,function_string,keepdim=False): 172 | return self.average(eval('self.obs_dist.'+function_string)(),keepdim) 173 | 174 | def average(self,A,keepdim=False): # returns sample_shape 175 | # A is mix_batch_shape + mix_event_shape 176 | return (A*self.p).sum(-1,keepdim) 177 | 178 | ### Compute special expectations used for VB inference 179 | def event_average(self,A,keepdim=False): # returns sample_shape + W.event_shape 180 | # A is mix_batch_shape + mix_event_shape + event_shape 181 | 182 | out = (A*self.p.view(self.p.shape + (1,)*self.obs_dist.event_dim)).sum(-self.obs_dist.event_dim-1,keepdim) 183 | for i in range(self.event_dim-1): 184 | out = out.sum(-self.obs_dist.event_dim-1,keepdim) 185 | return out 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /simulations/Forager.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | class Forager(): 8 | 9 | def __init__(self): 10 | # Set the number of foods and their properties 11 | self.num_foods = 10 12 | self.food_range = 100 13 | self.forager_speed = 1 14 | self.vision_range = 20 15 | self.max_food_items = 
3 16 | self.d_max = 75 # Maximum distance from home 17 | self.num_steps = 2000 18 | self.noise = 0.5 19 | 20 | def simulate(self): 21 | num_foods = self.num_foods 22 | food_range = self.food_range 23 | forager_speed = self.forager_speed 24 | vision_range = self.vision_range 25 | max_food_items = self.max_food_items 26 | d_max = self.d_max 27 | num_steps = self.num_steps 28 | noise = self.noise 29 | # Create a list to store the foods and their initial locations 30 | foods = [(random.uniform(-food_range, food_range), random.uniform(-food_range, food_range)) 31 | for _ in range(num_foods)] 32 | 33 | # Set the initial location of the forager and its memory 34 | forager_location = (0, 0) 35 | forager_memory = [] 36 | food_collected = 0 37 | food_in_memory = [0] * self.num_foods 38 | 39 | # Store the positions of forager, food items, and consumed food at each time point 40 | forager_positions = [forager_location] 41 | food_positions = [foods[:]] 42 | food_memory = [food_in_memory[:]] 43 | 44 | rand_direction = 2*math.pi*random.uniform(0,1) 45 | 46 | # Simulate the foraging process 47 | for step in range(num_steps): 48 | 49 | new_food_items = [food for food in foods if food not in forager_memory and 50 | math.sqrt((forager_location[0] - food[0]) ** 2 + 51 | (forager_location[1] - food[1]) ** 2) <= vision_range] 52 | 53 | if new_food_items: 54 | # Add the new food items to forager's memory 55 | forager_memory.extend(new_food_items) 56 | for food in new_food_items: 57 | food_in_memory[foods.index(food)] = 1 58 | 59 | if food_collected == max_food_items: 60 | # Move back towards (0, 0) at normal speed 61 | angle = math.atan2(-forager_location[1], -forager_location[0]) 62 | forager_location = ( 63 | forager_location[0] + forager_speed * math.cos(angle) + random.normalvariate(0, noise), 64 | forager_location[1] + forager_speed * math.sin(angle) + random.normalvariate(0, noise) 65 | ) 66 | 67 | # Check if the forager has reached the origin 68 | if math.sqrt(forager_location[0] ** 2 + forager_location[1] ** 2) <= forager_speed: 69 | # Reset food_collected and forager_memory 70 | food_collected = 0 71 | rand_direction = 2*math.pi*random.uniform(0,1) 72 | 73 | # Check if the forager sees any new food items 74 | 75 | if food_collected < max_food_items: 76 | if forager_memory: 77 | # Find the nearest food item 78 | nearest_food = min(forager_memory, key=lambda f: math.sqrt((forager_location[0] - f[0]) ** 2 + 79 | (forager_location[1] - f[1]) ** 2)) 80 | 81 | # Calculate the angle between forager and nearest food item 82 | angle = math.atan2(nearest_food[1] - forager_location[1], nearest_food[0] - forager_location[0]) 83 | 84 | # Move towards the nearest food item with maximum forager speed 85 | forager_location = ( 86 | forager_location[0] + forager_speed * math.cos(angle) + random.normalvariate(0, noise), 87 | forager_location[1] + forager_speed * math.sin(angle) + random.normalvariate(0, noise) 88 | ) 89 | 90 | # Check if the forager has reached the nearest food item 91 | if math.sqrt((forager_location[0] - nearest_food[0]) ** 2 + 92 | (forager_location[1] - nearest_food[1]) ** 2) <= forager_speed: 93 | # Move the food item to a new random location 94 | food_in_memory[foods.index(nearest_food)] = 0 95 | foods[foods.index(nearest_food)] = (random.uniform(-food_range, food_range), random.uniform(-food_range, food_range)) 96 | 97 | # Remove the food item from forager's memory 98 | forager_memory.remove(nearest_food) 99 | food_collected += 1 100 | 101 | else: 102 | # Randomly move away from home until 
103 |                     if math.sqrt(forager_location[0] ** 2 + forager_location[1] ** 2) <= d_max:
104 |                         angle = rand_direction
105 |                         forager_location = (
106 |                             forager_location[0] + forager_speed * math.cos(angle) + random.normalvariate(0, noise),
107 |                             forager_location[1] + forager_speed * math.sin(angle) + random.normalvariate(0, noise)
108 |                         )
109 |                     else:
110 |                         # Move clockwise around home until seeing a new food item
111 |                         angle = math.atan2(forager_location[1], forager_location[0]) + math.pi / 2
112 |                         forager_location = (
113 |                             forager_location[0] + forager_speed * math.cos(angle) + random.normalvariate(0, noise),
114 |                             forager_location[1] + forager_speed * math.sin(angle) + random.normalvariate(0, noise)
115 |                         )
116 | 
117 |             # Store the positions of forager, food items, and consumed food at each time point
118 |             forager_positions.append(forager_location)
119 |             food_positions.append(foods[:])
120 |             food_memory.append(food_in_memory[:])
121 | 
122 |         return torch.tensor(forager_positions), torch.tensor(food_positions), torch.tensor(food_memory)
123 | 
124 |     def plot(self, forager_positions, food_positions):
125 | 
126 |         # Plot the trajectory of the forager, food locations, and consumed food locations
127 |         plt.figure(figsize=(8, 6))
128 |         plt.plot(forager_positions[:,0], forager_positions[:,1], label="Forager Trajectory")
129 |         plt.scatter(food_positions[:,:,0], food_positions[:,:,1], marker='x', color='green', label="Consumed Food Locations")
130 |         plt.scatter(food_positions[-1,:,0], food_positions[-1,:,1], marker='o', s=80, color='red', label="Remaining Food Locations")
131 |         plt.xlabel("X")
132 |         plt.ylabel("Y")
133 |         plt.title("Forager Trajectory and Food Locations")
134 |         plt.grid(True)
135 |         plt.show()
136 | 
137 |     def simulate_batches(self, batch_num):
138 |         forager_positions = torch.zeros(self.num_steps + 1,batch_num,2)
139 |         food_positions = torch.zeros(self.num_steps + 1,batch_num,self.num_foods,2)
140 |         food_memory = torch.zeros(self.num_steps + 1,batch_num,self.num_foods)
141 | 
142 |         for i in range(0,batch_num):
143 |             forager_positions[:,i,:], food_positions[:,i,:,:], food_memory[:,i,:] = self.simulate()
144 | 
145 |         data = torch.cat((forager_positions.unsqueeze(-2),food_positions),-2)
146 | 
147 |         return data, food_memory
148 | 
149 | 
150 | sim = Forager()
151 | forager_positions, food_positions, food_memory = sim.simulate()
152 | sim.plot(forager_positions,food_positions)
153 | 
154 | # batch_num = 100
155 | # data = sim.simulate_batches(batch_num)
156 | 
157 | # v_data = data.diff(n=1,dim=0)
158 | # data = data[1:]
159 | 
160 | # v_data = torch.cat((data,v_data),-1)
161 | # v_data = v_data + torch.randn(v_data.shape)*0.1
162 | # v_data = v_data/v_data.std((0,1),True)
--------------------------------------------------------------------------------
/simulations/Lorenz.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | 
4 | 
5 | class Lorenz():
6 |     def __init__(self):
7 |         # Constants
8 |         self.sigma = 10.0
9 |         self.rho = 28.0
10 |         self.beta = 8.0 / 3.0
11 | 
12 |         # Time step and number of iterations
13 |         self.dt = 0.01
14 |         self.num_steps = 2000
15 | 
16 |     def simulate(self, batch_num):
17 | 
18 |         param_noise_level = 0.02
19 | 
20 |         sigma = self.sigma + 2*(torch.rand(batch_num)-0.5)*self.sigma*param_noise_level
21 |         rho = self.rho + 2*(torch.rand(batch_num)-0.5)*self.rho*param_noise_level
22 |         beta = self.beta + 2*(torch.rand(batch_num)-0.5)*self.beta*param_noise_level
23 | 
24 |         # Initial conditions
25 | 26 | x=torch.randn(batch_num) 27 | y=torch.randn(batch_num) 28 | z=torch.randn(batch_num) 29 | 30 | # Empty lists to store the trajectory 31 | data = torch.zeros(self.num_steps,batch_num,3) 32 | 33 | # Simulation loop 34 | for t in range(self.num_steps): 35 | # Compute derivatives 36 | dx_dt = sigma * (y - x) 37 | dy_dt = x * (rho - z) - y 38 | dz_dt = x * y - beta * z 39 | 40 | # Update variables using Euler's method 41 | x = x + dx_dt * self.dt 42 | y = y + dy_dt * self.dt 43 | z = z + dz_dt * self.dt 44 | 45 | # Append current values to the trajectory 46 | data[t,:,0]=x 47 | data[t,:,1]=y 48 | data[t,:,2]=z 49 | 50 | 51 | n_smoothe = 5 52 | v_data = (data[1:]-data[:-1])/self.dt 53 | data = data[1:] 54 | data = torch.cat((data.unsqueeze(-1),v_data.unsqueeze(-1)),dim=-1) 55 | data = self.smoothe(data,n_smoothe)[::n_smoothe] 56 | data = data/data.std(dim=(0,1,2),keepdim=True) 57 | torch.save(data,'lorenz_data.pt') 58 | return data 59 | 60 | def plot(self,data,batch_num=0): 61 | # Plot the attractor 62 | fig = plt.figure() 63 | ax = fig.add_subplot(111, projection='3d') 64 | ax.plot(data[:,batch_num,0,0], data[:,batch_num,1,0], data[:,batch_num,2,0], lw=0.5) 65 | ax.set_xlabel('X') 66 | ax.set_ylabel('Y') 67 | ax.set_zlabel('Z') 68 | plt.show() 69 | 70 | batch = 0 71 | fig = plt.figure() 72 | ax = fig.add_subplot(111, projection='3d') 73 | ax.plot(data[:,batch_num,0,1], data[:,batch_num,1,1], data[:,batch_num,2,1], lw=0.5) 74 | ax.set_xlabel('VX') 75 | ax.set_ylabel('VY') 76 | ax.set_zlabel('VZ') 77 | plt.show() 78 | 79 | def smoothe(self,data,n): 80 | temp = torch.zeros((data.shape[0]-n,)+data.shape[1:]) 81 | for i in range(n): 82 | temp = temp + data[i:-n+i] 83 | return temp/n 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /simulations/NewtonsCradle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class NewtonsCradle(): 5 | def __init__(self,n_balls,ball_size,Tmax,batch_size,g,leak,dt,include_string=False): 6 | self.n_balls = n_balls 7 | self.Tmax = Tmax 8 | self.batch_size = batch_size 9 | self.dt = dt 10 | self.ball_size = ball_size 11 | self.x_loc = (torch.arange(n_balls) - (n_balls-1)/2)*ball_size 12 | self.g = g 13 | self.leak = leak 14 | self.include_string = include_string 15 | 16 | def initialize(self,init_type='random'): 17 | self.init_type = init_type 18 | if(init_type=='random'): 19 | theta_0 = torch.rand(self.batch_size,self.n_balls)*2*np.pi- np.pi 20 | theta_0 = theta_0.sort(-1)[0] 21 | theta_0 = theta_0/20.0 22 | if(init_type=='1 ball object'): 23 | theta = 2*np.pi*(torch.rand(self.batch_size,1) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 24 | other_thetas = 2*np.pi*(torch.rand(self.batch_size,self.n_balls-1)-0.5) 25 | other_thetas = other_thetas.sort(-1)[0]/100.0 26 | theta_0 = torch.cat((theta,other_thetas),-1) 27 | if(init_type=='2 ball object'): 28 | theta = 2*np.pi*(torch.rand(self.batch_size,2) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 29 | theta = theta.sort(-1)[0] 30 | other_thetas = 2*np.pi*(torch.rand(self.batch_size,self.n_balls-2) - 0.5) 31 | other_thetas = other_thetas.sort(-1)[0]/100.0 32 | theta_0 = torch.cat((theta,other_thetas),-1) 33 | if(init_type=='3 ball object'): 34 | theta = 2*np.pi*(torch.rand(self.batch_size,3) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 35 | theta = theta.sort(-1)[0] 36 | other_thetas = 2*np.pi*(torch.rand(self.batch_size,self.n_balls-3) - 0.5) 37 | 
other_thetas = other_thetas.sort(-1)[0]/100.0 38 | theta_0 = torch.cat((theta,other_thetas),-1) 39 | if(init_type=='4 ball object'): 40 | theta = 2*np.pi*(torch.rand(self.batch_size,4) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 41 | theta = theta.sort(-1)[0] 42 | other_thetas = 2*np.pi*(torch.rand(self.batch_size,self.n_balls-4) - 0.5) 43 | other_thetas = other_thetas.sort(-1)[0]/100.0 44 | theta_0 = torch.cat((theta,other_thetas),-1) 45 | 46 | if(init_type == '1 + 1 ball object'): 47 | thetaL = 2*np.pi*(torch.rand(self.batch_size,1) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 48 | thetaR = 2*np.pi*(torch.rand(self.batch_size,1) - 0.5)/100 + np.pi/2*(torch.rand(self.batch_size,1)+1)/2 49 | thetaL = thetaL.sort(-1)[0] 50 | thetaR = thetaR.sort(-1)[0] 51 | other_thetas = 2*np.pi*(torch.rand(self.batch_size,self.n_balls-2) - 0.5) 52 | other_thetas = other_thetas.sort(-1)[0]/100.0 53 | theta_0 = torch.cat((thetaL,other_thetas,thetaR),-1) 54 | 55 | if(init_type == '1 + 2 ball object'): 56 | thetaL = 2*np.pi*(torch.rand(self.batch_size,1) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 57 | thetaR = 2*np.pi*(torch.rand(self.batch_size,2) - 0.5)/100 + np.pi/2*(torch.rand(self.batch_size,1)+1)/2 58 | thetaL = thetaL.sort(-1)[0] 59 | thetaR = thetaR.sort(-1)[0] 60 | other_thetas = 2*np.pi*(torch.rand(self.batch_size,self.n_balls-3) - 0.5) 61 | other_thetas = other_thetas.sort(-1)[0]/100.0 62 | theta_0 = torch.cat((thetaL,other_thetas,thetaR),-1) 63 | 64 | if(init_type == '1 + 3 ball object'): 65 | thetaL = 2*np.pi*(torch.rand(self.batch_size,1) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 66 | thetaR = 2*np.pi*(torch.rand(self.batch_size,3) - 0.5)/100 + np.pi/2*(torch.rand(self.batch_size,1)+1)/2 67 | thetaL = thetaL.sort(-1)[0] 68 | thetaR = thetaR.sort(-1)[0] 69 | other_thetas = 2*np.pi*(torch.rand(self.batch_size,self.n_balls-4) - 0.5) 70 | other_thetas = other_thetas.sort(-1)[0]/100.0 71 | theta_0 = torch.cat((thetaL,other_thetas,thetaR),-1) 72 | 73 | if(init_type == '2 + 3 ball object'): 74 | thetaL = 2*np.pi*(torch.rand(self.batch_size,2) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 75 | thetaR = 2*np.pi*(torch.rand(self.batch_size,3) - 0.5)/100 + np.pi/2*(torch.rand(self.batch_size,1)+1)/2 76 | thetaL = thetaL.sort(-1)[0] 77 | thetaR = thetaR.sort(-1)[0] 78 | theta_0 = torch.cat((thetaL,thetaR),-1) 79 | 80 | if(init_type == '2 + 2 ball object'): 81 | thetaL = 2*np.pi*(torch.rand(self.batch_size,2) - 0.5)/100 - np.pi/2*(torch.rand(self.batch_size,1)+1)/2 82 | thetaR = 2*np.pi*(torch.rand(self.batch_size,2) - 0.5)/100 + np.pi/2*(torch.rand(self.batch_size,1)+1)/2 83 | thetaL = thetaL.sort(-1)[0] 84 | thetaR = thetaR.sort(-1)[0] 85 | other_thetas = 2*np.pi*(torch.rand(self.batch_size,self.n_balls-4) - 0.5) 86 | other_thetas = other_thetas.sort(-1)[0]/100.0 87 | theta_0 = torch.cat((thetaL,other_thetas,thetaR),-1) 88 | 89 | return theta_0 90 | 91 | def generate_data(self,init_type='random'): 92 | self.init_type = init_type 93 | theta = torch.zeros(self.Tmax,self.batch_size,self.n_balls) 94 | theta[0] = self.initialize(self.init_type) 95 | v_theta = torch.zeros(self.Tmax,self.batch_size,self.n_balls) 96 | hit = torch.zeros(self.batch_size,self.n_balls) 97 | for t in range(1,self.Tmax): 98 | v_theta[t] = v_theta[t-1] - self.dt*self.g*theta[t-1].sin() - self.leak*self.dt*v_theta[t-1] 99 | theta[t] = theta[t-1] + self.dt*v_theta[t] 100 | X = theta[t].sin() + self.x_loc 101 | Y = -theta[t].cos() 102 | for k in range(1,self.n_balls): 103 | 
dist = (X[:,k]-X[:,k-1])**2 + (Y[:,k]-Y[:,k-1])**2 104 | hit[:,k] = (dist < self.ball_size**2).float() 105 | # temp = theta[t,:,k] 106 | # theta[t,:,k] = theta[t,:,k-1]*hit[:,k] + theta[t,:,k]*(1-hit[:,k]) 107 | # theta[t,:,k-1] = temp*hit[:,k] + theta[t,:,k-1]*(1-hit[:,k]) 108 | v_temp = v_theta[t,:,k-1].clone() 109 | v_theta[t,:,k-1]=v_theta[t,:,k]*hit[:,k] + v_theta[t,:,k-1]*(1-hit[:,k]) 110 | v_theta[t,:,k] = v_temp*hit[:,k] + v_theta[t,:,k]*(1-hit[:,k]) 111 | theta[t,:,k-1] = theta[t-1,:,k-1] + self.dt*v_theta[t,:,k-1] 112 | theta[t,:,k] = theta[t-1,:,k] + self.dt*v_theta[t,:,k] 113 | 114 | theta[t],idx=theta[t].sort(-1) 115 | # for k in range(self.n_balls-1,0,-1): 116 | # dist = (X[:,k]-X[:,k-1])**2 + (Y[:,k]-Y[:,k-1])**2 117 | # hit[:,k] = (dist < self.ball_size**2).float() 118 | # # temp = theta[t,:,k] 119 | # # theta[t,:,k] = theta[t,:,k-1]*hit[:,k] + theta[t,:,k]*(1-hit[:,k]) 120 | # # theta[t,:,k-1] = temp*hit[:,k] + theta[t,:,k-1]*(1-hit[:,k]) 121 | # v_temp = v_theta[t,:,k-1].clone() 122 | # v_theta[t,:,k-1]=v_theta[t,:,k]*hit[:,k] + v_theta[t,:,k-1]*(1-hit[:,k]) 123 | # v_theta[t,:,k] = v_temp*hit[:,k] + v_theta[t,:,k]*(1-hit[:,k]) 124 | # theta[t,:,k-1] = theta[t-1,:,k-1] + self.dt*v_theta[t,:,k-1] 125 | # theta[t,:,k] = theta[t-1,:,k] + self.dt*v_theta[t,:,k] 126 | 127 | X = theta.sin() + self.x_loc 128 | Y = -theta.cos() 129 | if isinstance(self.include_string,int): 130 | for k in range(1,self.include_string): 131 | R = 1-k/(self.include_string) 132 | X = torch.cat((X,theta.sin()*R + self.x_loc),-1) 133 | Y = torch.cat((Y,-theta.cos()*R),-1) 134 | 135 | X = X.unsqueeze(-1) 136 | Y = Y.unsqueeze(-1) 137 | return torch.cat((X,Y),-1), theta 138 | 139 | 140 | 141 | # model = NewtonsCradle(n_balls=5,ball_size=0.2,Tmax=1000,batch_size=1,g=1,leak=0.02/2,dt=0.05) 142 | # data = model.generate_data('1 ball object')[0] 143 | 144 | # import matplotlib.pyplot as plt 145 | 146 | # X = data[...,0] 147 | # Y = data[...,1] 148 | 149 | # plt.plot(data[:,0,0,0]) 150 | # plt.plot(data[:,0,1,0]) 151 | # plt.plot(data[:,0,2,0]) 152 | # plt.plot(data[:,0,3,0]) 153 | # plt.plot(data[:,0,4,0]) 154 | # plt.show() 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /simulations/__init__ .py: -------------------------------------------------------------------------------- 1 | from .NewtonsCradle import NewtonsCradle 2 | from .cartthingy import cartthingy 3 | from .Lorenz import Lorenz 4 | from .Forager import Forager 5 | -------------------------------------------------------------------------------- /simulations/cartthingy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | class cartthingy(): 6 | def __init__(self,): 7 | pass 8 | 9 | def simulate(batch_num=1): 10 | # System parameters 11 | m_c = 1.0 # Cart mass 12 | m_p1 = 0.5 # Mass of pendulum 1 13 | m_p2 = 0.5 # Mass of pendulum 2 14 | l1 = 1 # Length of pendulum 1 15 | l2 = 1 # Length of pendulum 2 16 | g = 1 # Gravity 17 | attractor = 0.1 18 | drag = 0.2 19 | # Simulation parameters 20 | dt = 0.02 # Time step 21 | T = 50.0 # Total simulation time 22 | N = int(T / dt) # Number of time steps 23 | 24 | # Initial conditions 25 | x0 = torch.randn(batch_num,1) # Initial cart position 26 | theta1_0 = np.pi/2 - 2*np.pi/2*torch.rand(batch_num,1) # Initial angle of pendulum 1 27 | theta2_0 = np.pi/2 - 2*np.pi/2*torch.rand(batch_num,1) # Initial angle of pendulum 2 28 | x_dot0 = 
torch.zeros(batch_num,1) # Initial cart velocity 29 | theta1_dot0 = torch.zeros(batch_num,1) # Initial angular velocity of pendulum 1 30 | theta2_dot0 = torch.zeros(batch_num,1) # Initial angular velocity of pendulum 2 31 | 32 | # Initialize arrays to store the trajectory 33 | trajectory = torch.zeros((N,batch_num, 6)) 34 | trajectory[0] = torch.cat((x0, theta1_0, theta2_0, x_dot0, theta1_dot0, theta2_dot0),dim=-1) 35 | 36 | # Simulate the system 37 | for i in range(1, N): 38 | # Unpack the state variables 39 | x = trajectory[i-1,:,0] 40 | theta1 = trajectory[i-1,:,1] 41 | theta2 = trajectory[i-1,:,2] 42 | x_dot = trajectory[i-1,:,3] 43 | theta1_dot = trajectory[i-1,:,4] 44 | theta2_dot = trajectory[i-1,:,5] 45 | 46 | # Compute the control input (e.g., based on a controller) 47 | control = -attractor*x #- drag*x_dot # Placeholder control input 48 | 49 | # Compute the derivatives of the state variables 50 | 51 | denom = m_c + m_p1*np.sin(theta1)**2 + m_p2*np.sin(theta2)**2 52 | x_ddot = control + np.sin(theta1) * (m_p1 * l1 * theta1_dot ** 2) + np.sin(theta2)*(m_p2 * l2 * theta2_dot ** 2) + m_p1*g*np.sin(theta1)*np.cos(theta1) + m_p2*g*np.sin(theta2)*np.cos(theta2) 53 | x_ddot = x_ddot/denom 54 | 55 | theta1_ddot = -g*l1*np.sin(theta1) - np.cos(theta1)*x_ddot/l1 56 | theta2_ddot = -g*l2*np.sin(theta2) - np.cos(theta2)*x_ddot/l2 57 | 58 | # Update the state variables using Euler integration 59 | x_new = x + x_dot * dt 60 | theta1_new = theta1 + theta1_dot * dt 61 | theta2_new = theta2 + theta2_dot * dt 62 | x_dot_new = x_dot + x_ddot * dt 63 | theta1_dot_new = theta1_dot + theta1_ddot * dt 64 | theta2_dot_new = theta2_dot + theta2_ddot * dt 65 | 66 | # Store the updated state variables in the trajectory 67 | trajectory[i] = torch.cat((x_new.unsqueeze(-1), theta1_new.unsqueeze(-1), theta2_new.unsqueeze(-1), x_dot_new.unsqueeze(-1), theta1_dot_new.unsqueeze(-1), theta2_dot_new.unsqueeze(-1)),dim=-1) 68 | return trajectory[::5] 69 | 70 | 71 | # # Plotting the trajectory 72 | # batch_num = 0 73 | 74 | # t = np.linspace(0, T, N) 75 | # x = trajectory[:, batch_num,0] 76 | # theta1 = trajectory[:, batch_num, 1] 77 | # theta2 = trajectory[:, batch_num, 2] 78 | # x_p1 = x + l1 * np.sin(theta1) 79 | # y_p1 = -l1 * np.cos(theta1) # Negative sign to flip the y-axis 80 | # x_p2 = x - l2 * np.sin(theta2) 81 | # y_p2 = -l2 * np.cos(theta2) # Negative sign to flip the y-axis 82 | 83 | 84 | # plt.figure() 85 | # plt.plot(t, x, label='Cart position') 86 | # plt.plot(t, x_p1, label='Pendulum 1 x') 87 | # plt.plot(t, x_p2, label='Pendulum 2 x') 88 | # plt.plot(t, y_p1, label='Pendulum 1 y') 89 | # plt.plot(t, y_p2, label='Pendulum 2 y') 90 | # plt.xlabel('Time') 91 | # plt.ylabel('Magnitude') 92 | # plt.title('Cart with Two Pendulums') 93 | # plt.legend() 94 | # plt.grid(True) 95 | # plt.show() 96 | 97 | # plt.plot(x_p1[::5],x_p2[::5]) 98 | # plt.show() 99 | -------------------------------------------------------------------------------- /simulations/forager_temp.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | class Forager: 8 | def __init__(self): 9 | # Set the number of foods and their properties 10 | self.num_foods = 20 11 | self.food_range = 100 12 | self.forager_speed = 1 13 | self.vision_range = 25 14 | self.max_food_items = 3 15 | self.d_max = 75 # Maximum distance from home 16 | self.num_steps = 20000 17 | 18 | def simulate(self): 19 | num_foods = self.num_foods 20 | 
food_range = self.food_range 21 | forager_speed = self.forager_speed 22 | vision_range = self.vision_range 23 | max_food_items = self.max_food_items 24 | d_max = self.d_max 25 | num_steps = self.num_steps 26 | 27 | # Create a list to store the foods and their initial locations 28 | foods = [(random.uniform(-food_range, food_range), random.uniform(-food_range, food_range)) 29 | for _ in range(num_foods)] 30 | 31 | # Set the initial location of the forager and its memory 32 | forager_location = (0, 0) 33 | forager_memory = [] 34 | food_collected = 0 35 | 36 | # Store the positions of forager, food items, consumed food, and food memory at each time point 37 | forager_positions = [forager_location] 38 | food_positions = [foods[:]] 39 | consumed_food_positions = [[]] 40 | food_memory = [] 41 | 42 | # Simulate the foraging process 43 | for step in range(num_steps): 44 | if food_collected == max_food_items: 45 | forager_location = (0, 0) 46 | food_collected = 0 47 | forager_memory.clear() 48 | food_memory.clear() 49 | 50 | if food_collected < max_food_items: 51 | # Check if the forager sees any new food items 52 | new_food_items = [food for food in foods if food not in forager_memory and 53 | math.sqrt((forager_location[0] - food[0]) ** 2 + 54 | (forager_location[1] - food[1]) ** 2) <= vision_range] 55 | 56 | if new_food_items: 57 | # Add the new food items to the forager's memory and mark them as present in memory 58 | forager_memory.extend(new_food_items) 59 | food_memory.extend([1] * len(new_food_items)) 60 | 61 | if forager_memory: 62 | # Find the nearest food item 63 | nearest_food = min(forager_memory, key=lambda f: math.sqrt((forager_location[0] - f[0]) ** 2 + 64 | (forager_location[1] - f[1]) ** 2)) 65 | 66 | # Calculate the angle between the forager and the nearest food item 67 | angle = math.atan2(nearest_food[1] - forager_location[1], nearest_food[0] - forager_location[0]) 68 | 69 | # Move towards the nearest food item with the maximum forager speed 70 | forager_location = ( 71 | forager_location[0] + forager_speed * math.cos(angle), 72 | forager_location[1] + forager_speed * math.sin(angle) 73 | ) 74 | 75 | # Check if the forager has reached the nearest food item 76 | if math.sqrt((forager_location[0] - nearest_food[0]) ** 2 + 77 | (forager_location[1] - nearest_food[1]) ** 2) <= forager_speed: 78 | # Move the food item to a new random location 79 | foods.remove(nearest_food) 80 | new_location = (random.uniform(-food_range, food_range), 81 | random.uniform(-food_range, food_range)) 82 | foods.append(new_location) 83 | 84 | # Remove the food item from the forager's memory and mark it as absent in memory 85 | index = forager_memory.index(nearest_food) 86 | forager_memory.pop(index) 87 | food_memory.pop(index) 88 | food_collected += 1 89 | consumed_food_positions[-1].append(nearest_food) 90 | else: 91 | # Randomly move away from home until reaching d_max distance 92 | if math.sqrt(forager_location[0] ** 2 + forager_location[1] ** 2) <= d_max: 93 | angle = random.uniform(0, 2 * math.pi) / 4.0 94 | forager_location = ( 95 | forager_location[0] + forager_speed * math.cos(angle), 96 | forager_location[1] + forager_speed * math.sin(angle) 97 | ) 98 | else: 99 | # Move clockwise around home until seeing a new food item 100 | angle = math.atan2(forager_location[1], forager_location[0]) + math.pi / 2 101 | forager_location = ( 102 | forager_location[0] + forager_speed * math.cos(angle), 103 | forager_location[1] + forager_speed * math.sin(angle) 104 | ) 105 | 106 | # Store the positions of 
forager, food items, consumed food, and food memory at each time point 107 | forager_positions.append(forager_location) 108 | food_positions.append(foods[:]) 109 | consumed_food_positions.append([]) 110 | food_memory.append([0] * len(foods)) 111 | 112 | return forager_positions, food_positions, consumed_food_positions, food_memory 113 | 114 | def plot(self, forager_positions, food_positions, consumed_food_positions): 115 | # Extract x and y coordinates for plotting 116 | forager_x, forager_y = zip(*forager_positions) 117 | food_x, food_y = zip(*food_positions) 118 | consumed_food_x, consumed_food_y = zip(*consumed_food_positions) 119 | 120 | # Plot the trajectory of the forager, food locations, and consumed food locations 121 | plt.figure(figsize=(8, 6)) 122 | plt.plot(forager_x, forager_y, label="Forager Trajectory") 123 | plt.scatter(food_x, food_y, marker='o', color='red', label="Food Locations") 124 | plt.scatter(consumed_food_x, consumed_food_y, marker='x', color='green', label="Consumed Food Locations") 125 | plt.xlabel("X") 126 | plt.ylabel("Y") 127 | plt.title("Forager Trajectory and Food Locations") 128 | plt.legend() 129 | plt.grid(True) 130 | plt.show() 131 | 132 | 133 | sim = Forager() 134 | forager_positions, food_positions, consumed_food_positions, food_memory = sim.simulate() 135 | sim.plot(forager_positions, food_positions, consumed_food_positions) 136 | -------------------------------------------------------------------------------- /test_bayes_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | import time 5 | from models.BayesNet import * 6 | from models.dists import MatrixNormalWishart, MultivariateNormal_vector_format 7 | 8 | 9 | n=4 10 | p=10 11 | 12 | num_samples = 500 13 | iters = 100 14 | 15 | X = torch.randn(num_samples,p) 16 | X = X-X.mean(0,True) 17 | W = 2*torch.randn(p,n)/np.sqrt(10) 18 | Y = X@W + torch.randn(num_samples,n)/100.0 19 | Y = Y + 0.5 20 | hidden_dims = (10,10,10) 21 | latent_dims = (2,2,2) 22 | 23 | W=W.transpose(-2,-1) 24 | W_hat = MatrixNormalWishart(mu_0 = torch.zeros(n,p),pad_X=True) 25 | t=time.time() 26 | W_hat.raw_update(X.unsqueeze(-1),Y.unsqueeze(-1)) 27 | W_hat.raw_update(X.unsqueeze(-1),Y.unsqueeze(-1)) 28 | W_hat_runtime=time.time()-t 29 | pY = MultivariateNormal_vector_format(mu = Y.unsqueeze(-1),invSigma=1000*torch.eye(n)) 30 | px,Res = W_hat.backward(pY) 31 | invSigma_x_x, invSigmamu_x, Residual = W_hat.Elog_like_X(Y.unsqueeze(-1)) 32 | mu_x = (invSigma_x_x.inverse()@invSigmamu_x) 33 | 34 | 35 | # plt.scatter(mu_x,px.mean().squeeze(-1)) 36 | # plt.show() 37 | 38 | Y_hat = W_hat.predict(X.unsqueeze(-1))[0] 39 | MSE = ((Y-Y_hat.squeeze(-1))**2).mean() 40 | Y_hat2 = W_hat.forward(MultivariateNormal_vector_format(mu = X.unsqueeze(-1),Sigma=torch.eye(p)/1000.0)).mean().squeeze(-1) 41 | 42 | fig, axs = plt.subplots(3, 1, figsize=(6, 6)) 43 | if W_hat.pad_X: 44 | axs[0].scatter(W, W_hat.mean()[:,:-1]) 45 | else: 46 | axs[0].scatter(W, W_hat.mean()) 47 | axs[0].plot([W.min(), W.max()], [W.min(), W.max()]) 48 | axs[0].set_title('W_hat vs W') 49 | axs[1].scatter(X, px.mean().squeeze(-1)) 50 | axs[1].scatter(X, mu_x.squeeze(-1)) 51 | axs[1].plot([X.min(), X.max()], [X.min(), X.max()]) 52 | axs[1].set_title('Backward Prediction') 53 | axs[2].scatter(Y, Y_hat.squeeze(-1)) 54 | axs[2].scatter(Y, Y_hat2.squeeze(-1)) 55 | axs[2].plot([Y.min(), Y.max()], [Y.min(), Y.max()]) 56 | axs[2].set_title('Forward Prediction') 57 | 
plt.tight_layout() 58 | plt.show() 59 | print('MSE: ',MSE, ' Time: ',W_hat_runtime) 60 | 61 | 62 | model = BayesNet(n,p,hidden_dims,latent_dims) 63 | t=time.time() 64 | model.update(X,Y,lr=1,iters = iters,verbose=False,FBI=False) 65 | model_run_time=time.time()-t 66 | 67 | set_model = BayesNet(n,p,hidden_dims,latent_dims) 68 | # for k, layer in enumerate(model.layers): 69 | # layer.mu = torch.randn_like(layer.mu)/np.sqrt(p)*0.1 70 | # set_model.layers[k].mu = layer.mu.clone().detach() 71 | t=time.time() 72 | set_model.update(X,Y,lr=1,iters = iters,verbose=False,FBI=True) 73 | set_model_run_time=time.time()-t 74 | 75 | Yhat = model.predict(X) 76 | set_Yhat = set_model.predict(X) 77 | 78 | W_net = torch.eye(X.shape[-1]) 79 | for k, layer in enumerate(model.layers): 80 | W_net = layer.weights()@W_net 81 | set_W_net = torch.eye(X.shape[-1]) 82 | for k, layer in enumerate(set_model.layers): 83 | set_W_net = layer.weights()@set_W_net 84 | 85 | fig, axs = plt.subplots(4, 1, figsize=(6, 6)) 86 | axs[0].scatter(Y[:,0], Yhat.squeeze(-1)[:,0],c='b') 87 | axs[0].scatter(Y[:,1], Yhat.squeeze(-1)[:,1],c='b') 88 | axs[0].plot([Y.min(), Y.max()], [Y.min(), Y.max()]) 89 | axs[0].set_title('Prediction') 90 | axs[1].plot(torch.tensor(model.ELBO_save[2:]).diff()) 91 | axs[1].set_title('Change in ELBO') 92 | axs[2].plot(model.MSE[2:]) 93 | axs[2].set_title('MSE') 94 | axs[3].scatter(W, W_net) 95 | axs[3].plot([W.min(), W.max()], [W.min(), W.max()]) 96 | axs[3].set_title('Weights') 97 | 98 | # plt.tight_layout() 99 | # plt.show() 100 | print('MSE: ',model.MSE[-1], ' Time: ',model_run_time) 101 | 102 | # fig, axs = plt.subplots(4, 1, figsize=(6, 6)) 103 | axs[0].scatter(Y[:,0], set_Yhat.squeeze(-1)[:,0],c='orange') 104 | axs[0].scatter(Y[:,1], set_Yhat.squeeze(-1)[:,1],c='orange') 105 | axs[0].plot([Y.min(), Y.max()], [Y.min(), Y.max()]) 106 | axs[0].set_title('Prediction') 107 | axs[1].plot(torch.tensor(set_model.ELBO_save[2:]).diff()) 108 | axs[1].set_title('Change in ELBO') 109 | axs[2].plot(set_model.MSE[2:]) 110 | axs[2].set_title('MSE') 111 | axs[3].scatter(W, set_W_net) 112 | axs[3].plot([W.min(), W.max()], [W.min(), W.max()]) 113 | axs[3].set_title('Weights') 114 | 115 | plt.tight_layout() 116 | plt.show() 117 | print('set_MSE: ',set_model.MSE[-1], ' Time: ',set_model_run_time) 118 | 119 | --------------------------------------------------------------------------------
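
The simulation classes in this listing are exercised only through commented-out snippets inside their own files. The following is a minimal, hypothetical usage sketch (it is not a file in the repository, and the name usage_sketch.py is made up) showing one way the simulators might be driven from the repository root. It assumes the repo root is on sys.path; because simulations/__init__ .py has a stray space in its filename, the submodules are imported directly, and note that importing simulations.Forager runs its module-level demo (a full simulation plus a plot window) as a side effect.

# usage_sketch.py -- hypothetical helper, not part of the repository.
# Minimal sketch of driving the simulators above; run from the repo root.
from simulations.Lorenz import Lorenz
from simulations.NewtonsCradle import NewtonsCradle

# Lorenz.simulate returns smoothed (position, velocity) pairs with shape
# (T, batch, 3, 2); it also writes 'lorenz_data.pt' to the working directory.
lorenz_data = Lorenz().simulate(batch_num=4)
print('Lorenz data shape:', tuple(lorenz_data.shape))

# NewtonsCradle.generate_data returns (x, y) ball positions with shape
# (Tmax, batch, n_balls, 2) together with the underlying pendulum angles.
cradle = NewtonsCradle(n_balls=5, ball_size=0.2, Tmax=1000, batch_size=2,
                       g=1, leak=0.01, dt=0.05)
cradle_xy, cradle_theta = cradle.generate_data('1 ball object')
print('Cradle data shape:', tuple(cradle_xy.shape))

# Forager.py runs a demo simulation and opens a plot window at import time,
# so import it only if that side effect is acceptable.
from simulations.Forager import Forager
positions, food_positions, food_memory = Forager().simulate()
print('Forager trajectory shape:', tuple(positions.shape))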