├── .gitignore
├── LICENSE
├── README.md
├── ResearchIpynb.ipynb
├── __init__.py
├── baselines
│   ├── __init__.py
│   └── filters.py
├── default.pkl
├── expt-polyphonic-fast
│   ├── locked_write.py
│   ├── setup_experiments.py
│   ├── template.q
│   └── train_dkf.py
├── expt-polyphonic
│   ├── README.md
│   ├── dkf_polyphonic.py
│   ├── hpc_uas1
│   │   ├── mdata_MFLR.q
│   │   ├── mdata_STLR.q
│   │   ├── mdata_STR.q
│   │   ├── mdata_STR_MLP.q
│   │   ├── mdata_STR_ar.q
│   │   ├── mdata_STR_ar_aug.q
│   │   ├── mdata_STR_ar_aug_nade.q
│   │   ├── nott_MFLR.q
│   │   ├── nott_STLR.q
│   │   ├── nott_STR.q
│   │   ├── nott_STR_MLP.q
│   │   ├── nott_STR_ar.q
│   │   ├── nott_STR_ar_aug.q
│   │   ├── nott_STR_ar_aug_nade.q
│   │   ├── piano_MFLR.q
│   │   ├── piano_STLR.q
│   │   ├── piano_STR.q
│   │   ├── piano_STR_MLP.q
│   │   ├── piano_STR_ar.q
│   │   ├── piano_STR_ar_aug.q
│   │   └── piano_STR_ar_aug_nade.q
│   ├── jsb_expts.sh
│   ├── musedata_expt.sh
│   ├── nott_expt.sh
│   ├── piano_expt.sh
│   └── train_dkf.py
├── expt-synthetic-fast
│   ├── create_expt.py
│   ├── run_baselines.py
│   └── train.py
├── expt-synthetic
│   ├── README.md
│   ├── create_expt.py
│   ├── run_baselines.py
│   └── train.py
├── expt-template
│   ├── README.md
│   ├── load.py
│   ├── runme.sh
│   └── train.py
├── images
│   ├── ELBO.png
│   └── dkf.png
├── ipynb
│   └── synthetic
│       ├── VisualizeSynthetic-fast.ipynb
│       ├── VisualizeSynthetic.ipynb
│       └── VisualizeSyntheticScaling.ipynb
├── paper.pdf
├── parse_args_dkf.py
├── polyphonic_samples
│   ├── jsb0.mp3
│   ├── jsb1.mp3
│   ├── musedata0.mp3
│   ├── musedata1.mp3
│   ├── nott0.mp3
│   ├── nott1.mp3
│   ├── piano0.mp3
│   └── piano1.mp3
├── stinfmodel
│   ├── README.md
│   ├── __init__.py
│   ├── dkf.py
│   ├── evaluate.py
│   ├── learning.py
│   └── testall.sh
└── stinfmodel_fast
    ├── __init__.py
    ├── dkf.py
    ├── evaluate.py
    └── learning.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.h5
2 | *.amat
3 | *.pyc
4 | *.ipynb_checkpoints/*
5 | *-checkpoint.ipynb
6 | chkpt*
7 | *.png
8 | *.pkl
9 | *.sh
10 | *.pkl.gz
11 | chkpt-*/
12 | *.o*
13 | *.tmp
14 | *.q
15 | *indicators.txt
16 | *.aux
17 | *.brf
18 | *.log
19 | *.bbl
20 | *.blg
21 | *.dvi
22 | *.ps
23 | *.out
24 | *.fdb_latexmk
25 | *.synctex.gz
26 | *.swp
27 | *.log
28 | ipynb/synthetic/*.pdf
29 | ipynb/medical/*.pdf
30 | expt-polyphonic/*.txt
31 | *.midi
32 | *.wav
33 | *.npz
34 | *.pkl
35 | datasets/tmp_MOCAP*.mat
36 | *.txt
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Sontag Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # structuredInf
2 | Code to fully reproduce benchmark results (and to extend for your own purposes) from the paper:
3 | Krishnan, Shalit, Sontag. Structured Inference Networks for Nonlinear State Space Models, AAAI 2017.
4 | See [here](https://github.com/clinicalml/dmm) for a simplified and easier-to-use version of the code.
5 |
6 | ## Goal
7 | The goal of this package is to provide a black box inference algorithm for learning models of time-series data.
8 | Inference during learning and at test time is performed with a compiled recognition or [inference network](https://arxiv.org/abs/1401.4082).
9 |
10 | ## Model
11 | The figure below describes a simple model of time-series data.
12 |
13 | This method is a good fit if:
14 | * You have an arbitrarily specified state space model whose parameters you're interested in fitting.
15 | * You would like a method for fast posterior inference at train and test time.
16 | * Your temporal generative model has Gaussian latent variables (whose mean/variance can be a nonlinear function of the previous timestep's variables).
17 |
18 | 
19 |
20 | The code uses variational inference during learning to maximize a lower bound (the ELBO) on the likelihood of the observed data:
21 | 
22 |
23 | *Generative Model*
24 |
25 | * The latent variables z1...zT and the observations x1...xT describe the generative process for the data.
26 | * The figure depicts a state space model for time-varying data.
27 | * The emission and transition functions may be pre-specified to have a fixed functional form, a parametric functional form, a function parameterized by a deep neural network, or some combination thereof.
28 |
29 | *Inference Model*
30 |
31 | The box q(z1...zT | x1...xT) represents the inference network. Several inference networks are supported within this package:
32 | * Inference implemented with a bi-directional LSTM
33 | * Inference implemented with an LSTM conditioned on observations in the future
34 | * Inference implemented with an LSTM conditioned on observations from the past
35 |
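As a concrete illustration, the generative process above can be simulated by ancestral sampling. The following is a minimal numpy sketch (not part of the package); `trans_fxn` and `obs_fxn` stand in for whatever transition/emission functions you specify:

```
import numpy as np

def sample_ssm(N, T, trans_fxn, obs_fxn, trans_cov=0.1, obs_cov=0.1):
    # z_t ~ N(trans_fxn(z_{t-1}), trans_cov),  x_t ~ N(obs_fxn(z_t), obs_cov)
    Z = np.zeros((N, T))
    X = np.zeros((N, T))
    Z[:, 0] = np.sqrt(trans_cov) * np.random.randn(N)  # z_1 drawn from the initial distribution
    for t in range(1, T):
        Z[:, t] = trans_fxn(Z[:, t - 1]) + np.sqrt(trans_cov) * np.random.randn(N)
    for t in range(T):
        X[:, t] = obs_fxn(Z[:, t]) + np.sqrt(obs_cov) * np.random.randn(N)
    return X, Z

X, Z = sample_ssm(N=5, T=25, trans_fxn=np.tanh, obs_fxn=lambda z: 2. * z)
```
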
36 | ## Installation
37 |
38 | ### Requirements
39 | This package has the following requirements:
40 |
41 | python2.7
42 |
43 | [Theano](https://github.com/Theano/Theano)
44 | Used for automatic differentiation.
45 |
46 | [theanomodels](https://github.com/clinicalml/theanomodels)
47 | Wrapper around theano that takes care of bookkeeping, saving/loading models etc. Clone the github repository
48 | and add its location to the PYTHONPATH environment variable so that it is accessible by python.
49 |
50 | [pykalman](https://pykalman.github.io/)
51 | [Optional: For running baseline UKFs/KFs]
52 |
53 | An NVIDIA GPU with at least 6GB of memory is recommended.
54 |
55 | Once the requirements have been met, clone this repository and it's ready to run.
56 |
57 | ### Folder Structure
58 | The following folders contain code to reproduce the results reported in our paper:
59 | * expt-synthetic, expt-polyphonic: Contains code and instructions for reproducing results from the paper.
60 | * baselines/: Contains code to run some of the baseline algorithms on the synthetic data.
61 | * ipynb/: Ipython notebooks for visualizing saved checkpoints and building plots
62 |
63 | The main files of interest are:
64 | * parse_args_dkf.py: Arguments that the model expects to be present. Looking through it is useful to understand the different knobs available to tune the model.
65 | * stinfmodel/dkf.py: Code to construct the inference and generative model. The code is commented to enable easy modification for different scenarios.
66 | * stinfmodel/evaluate.py: Code to evaluate the Deep Kalman Filter's performance during learning.
67 | * stinfmodel/learning.py: Code for performing stochastic gradient ascent on the Evidence Lower Bound.
68 |
69 | ## Dataset
70 |
71 | We use numpy tensors to store the datasets, with binary numpy masks to support minibatches comprising sequences of variable length. We train the models using mini-batch gradient descent on the negative ELBO.
72 |
73 | ### Format
74 |
75 | The loaders for the polyphonic and synthetic datasets are already provided in the theanomodels repository; see theanomodels/datasets/load.py for how each dataset is created and loaded.
76 |
77 | The datasets are stored in three-dimensional numpy tensors.
78 | To deal with datapoints
79 | of different lengths, we use numpy matrices of binary masks. There may be other ways
80 | to organize your data depending on your needs; this is merely a guideline.
81 |
82 | ```
83 | assert type(dataset) is dict,'Expecting dictionary'
84 | dataset['train'] # N_train x T_train_max x dim_observation : training data
85 | dataset['test'] # N_test x T_test_max x dim_observation : test data
86 | dataset['valid'] # N_valid x T_valid_max x dim_observation : validation data
87 | dataset['mask_train'] # N_train x T_train_max : training masks
88 | dataset['mask_test'] # N_test x T_test_max : test masks
89 | dataset['mask_valid'] # N_valid x T_valid_max : validation masks
90 | dataset['data_type'] # real/binary
91 | dataset['has_masks'] # true/false
92 | ```
93 | During learning, we select a minibatch of these tensors to update the weights of the model.
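A minimal sketch of such an update step (illustrative only; `np` is numpy and `batch_size` is a hypothetical value -- the actual batching logic lives in stinfmodel/learning.py):

```
idx     = np.random.permutation(dataset['train'].shape[0])[:batch_size]
X_batch = dataset['train'][idx]       # batch_size x T_train_max x dim_observation
M_batch = dataset['mask_train'][idx]  # batch_size x T_train_max
# Per-timestep terms of the objective are multiplied by the mask so that
# padded timesteps beyond a sequence's true length contribute nothing.
```
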
94 |
95 | ### Running on different datasets
96 |
97 | **See the folder expt-template for an example of how to setup your data and run the code on your data**
98 |
99 | ## References:
100 | ```
101 | @inproceedings{krishnan2016structured,
102 | title={Structured Inference Networks for Nonlinear State Space Models},
103 | author={Krishnan, Rahul G and Shalit, Uri and Sontag, David},
104 | booktitle={AAAI},
105 | year={2017}
106 | }
107 | ```
108 | This paper subsumes the work in: [Deep Kalman Filters](https://arxiv.org/abs/1511.05121)
109 |
--------------------------------------------------------------------------------
/ResearchIpynb.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "outputs": [
10 | {
11 | "name": "stdout",
12 | "output_type": "stream",
13 | "text": [
14 | "Couldn't import dot_parser, loading of dot files will not be possible.\n"
15 | ]
16 | },
17 | {
18 | "name": "stderr",
19 | "output_type": "stream",
20 | "text": [
21 | "Using gpu device 0: GeForce GTX TITAN (CNMeM is disabled, cuDNN 4007)\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "import theano\n",
27 | "import theano.tensor as T\n",
28 | "import numpy as np"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {
35 | "collapsed": false
36 | },
37 | "outputs": [],
38 | "source": [
39 | "X = theano.shared(np.zeros((4,10,20)).astype('float32'))\n",
40 | "mask = theano.shared(np.zeros((4,10)).astype('float32'))\n",
41 | "newX = T.tensor3('newX',dtype='float32')\n",
42 | "newMask=T.matrix('newMask',dtype='float32')\n",
43 | "\n",
44 | "resetX= theano.function([newX,newMask],None,updates=[(X,newX),(mask,newMask)])\n",
45 | "statX = theano.function([],[X.mean(),X.max(),X.sum()])"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 3,
51 | "metadata": {
52 | "collapsed": false
53 | },
54 | "outputs": [
55 | {
56 | "name": "stdout",
57 | "output_type": "stream",
58 | "text": [
59 | "0.0 0.0 0.0\n"
60 | ]
61 | }
62 | ],
63 | "source": [
64 | "mnX,maX,smX = statX()\n",
65 | "print mnX,maX,smX"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 4,
71 | "metadata": {
72 | "collapsed": false
73 | },
74 | "outputs": [
75 | {
76 | "data": {
77 | "text/plain": [
78 | "[]"
79 | ]
80 | },
81 | "execution_count": 4,
82 | "metadata": {},
83 | "output_type": "execute_result"
84 | }
85 | ],
86 | "source": [
87 | "nX = np.ones((12,4,32)).astype('float32')\n",
88 | "nM = np.ones((17,3)).astype('float32')\n",
89 | "resetX(newX=nX,newMask=nM)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 5,
95 | "metadata": {
96 | "collapsed": false
97 | },
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "1.0 1.0 1536.0\n"
104 | ]
105 | }
106 | ],
107 | "source": [
108 | "mnX,maX,smX = statX()\n",
109 | "print mnX,maX,smX"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "collapsed": true
117 | },
118 | "outputs": [],
119 | "source": []
120 | }
121 | ],
122 | "metadata": {
123 | "kernelspec": {
124 | "display_name": "Python 2",
125 | "language": "python",
126 | "name": "python2"
127 | },
128 | "language_info": {
129 | "codemirror_mode": {
130 | "name": "ipython",
131 | "version": 2
132 | },
133 | "file_extension": ".py",
134 | "mimetype": "text/x-python",
135 | "name": "python",
136 | "nbconvert_exporter": "python",
137 | "pygments_lexer": "ipython2",
138 | "version": "2.7.3"
139 | }
140 | },
141 | "nbformat": 4,
142 | "nbformat_minor": 0
143 | }
144 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['stinfmodel']
2 |
--------------------------------------------------------------------------------
/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['filters']
2 |
--------------------------------------------------------------------------------
/baselines/filters.py:
--------------------------------------------------------------------------------
1 | from pykalman import KalmanFilter
2 | from pykalman import UnscentedKalmanFilter
3 | from pykalman import AdditiveUnscentedKalmanFilter
4 | import numpy as np
5 |
6 | #Running UKF/KF
7 | def runFilter(observations, params, dname, filterType):
8 |     """
9 |     Run Kalman Filter (UKF/KF) and return smoothed means/covariances
10 |     observations : nsample x T
11 |     params : {'dname'... contains all the necessary parameters for KF/UKF}
12 |     filterType : 'KF' or 'UKF'
13 |     """
14 |     s1 = set(params[dname].keys())
15 |     s2 = set(['trans_fxn','obs_fxn','trans_cov','obs_cov','init_mu','init_cov',
16 |               'trans_mult','obs_mult','trans_drift','obs_drift','baseline'])
17 |     for k in s2:
18 |         assert k in s1,k+' not found in params'
19 |     #assert s1.issubset(s2) and s1.issuperset(s2),'Missing in params: '+', '.join(list(s2.difference(s1)))
20 |     assert filterType=='KF' or filterType=='UKF','Expecting KF/UKF'
21 |     model,mean,var = None,None,None
22 |     X = observations.squeeze()
23 |     #assert len(X.shape)==2,'observations must be nsamples x T'
24 |     if filterType=='KF':
25 |         def setupArr(arr):
26 |             if type(arr) is np.ndarray:
27 |                 return arr
28 |             else:
29 |                 return np.array([arr])
30 |         model=KalmanFilter(
31 |             transition_matrices    = setupArr(params[dname]['trans_mult']),     #multiplier for z_t-1
32 |             observation_matrices   = setupArr(params[dname]['obs_mult']).T,     #multiplier for z_t
33 |             transition_covariance  = setupArr(params[dname]['trans_cov_full']), #transition cov
34 |             observation_covariance = setupArr(params[dname]['obs_cov_full']),   #obs cov
35 |             transition_offsets     = setupArr(params[dname]['trans_drift']),    #additive const. in trans
36 |             observation_offsets    = setupArr(params[dname]['obs_drift']),      #additive const. in obs
37 |             initial_state_mean     = setupArr(params[dname]['init_mu']),
38 |             initial_state_covariance = setupArr(params[dname]['init_cov_full']))
39 |     else:
40 |         #In this case, the transition and emission function may have other parameters
41 |         #Create wrapper functions that are instantiated w/ the true parameters
42 |         #and pass them to the UKF
43 |         def trans_fxn(z):
44 |             return params[dname]['trans_fxn'](z, fxn_params = params[dname]['params'])
45 |         def obs_fxn(z):
46 |             return params[dname]['obs_fxn'](z, fxn_params = params[dname]['params'])
47 |
48 |         model=AdditiveUnscentedKalmanFilter(
49 |             transition_functions   = trans_fxn, #params[dname]['trans_fxn'],
50 |             observation_functions  = obs_fxn,   #params[dname]['obs_fxn'],
51 |             transition_covariance  = np.array([params[dname]['trans_cov']]), #transition cov
52 |             observation_covariance = np.array([params[dname]['obs_cov']]),   #obs cov
53 |             initial_state_mean     = np.array([params[dname]['init_mu']]),
54 |             initial_state_covariance = np.array(params[dname]['init_cov']))
55 |     #Run smoothing algorithm with model
56 |     dim_stoc = params[dname]['dim_stoc']
57 |     if dim_stoc>1:
58 |         mus = np.zeros((X.shape[0],X.shape[1],dim_stoc))
59 |         cov = np.zeros((X.shape[0],X.shape[1],dim_stoc))
60 |     else:
61 |         mus = np.zeros(X.shape)
62 |         cov = np.zeros(X.shape)
63 |     ll = 0
64 |     for n in range(X.shape[0]):
65 |         (smoothed_state_means, smoothed_state_covariances) = model.smooth(X[n,:])
66 |         if dim_stoc>1:
67 |             mus[n,:] = smoothed_state_means
68 |             cov[n,:] = np.concatenate([np.diag(k)[None,:] for k in smoothed_state_covariances],axis=0)
69 |         else:
70 |             mus[n,:] = smoothed_state_means.ravel()
71 |             cov[n,:] = smoothed_state_covariances.ravel()
72 |         if filterType=='KF':
73 |             ll += model.loglikelihood(X[n,:])
74 |     return mus,cov,ll
75 |
76 | #Generating Data
77 | def sampleGaussian(mu,cov):
78 |     """
79 |     Sample from gaussian with mu/cov
80 |
81 |     returns: random sample from N(mu,cov) of shape mu
82 |     mu: must be numpy array
83 |     cov: can be scalar or same shape as mu
84 |     """
85 |     return np.multiply(np.random.randn(*mu.shape),np.sqrt(cov))+mu
86 |
87 | def generateData(N,T,params, dname):
88 |     """
89 |     Generate sequential dataset
90 |     returns: N x T matrix of observations, latents
91 |     N : #samples
92 |     T : time steps
93 |     params : {'dname'...contains necessary functions}
94 |     dname : dataset name
95 |     """
96 |     np.random.seed(1)
97 |     assert dname in params,dname+' not found in params'
98 |     Z = np.zeros((N,T))
99 |     X = np.zeros((N,T))
100 |     Z[:,0] = sampleGaussian(params[dname]['init_mu']*np.ones((N,)),
101 |                             params[dname]['init_cov'])
102 |     for t in range(1,T):
103 |         Z[:,t] = sampleGaussian(params[dname]['trans_fxn'](Z[:,t-1]),params[dname]['trans_cov'])
104 |     for t in range(T):
105 |         X[:,t] = sampleGaussian(params[dname]['obs_fxn'](Z[:,t]),params[dname]['obs_cov'])
106 |     return X,Z
107 |
108 | #Reconstruction
109 | def reconsMus(mus_posterior, params, dname):
110 |     """
111 |     Estimate the observation means using posterior means
112 |     mus_posterior : N x T matrix of posterior means
113 |     params : {'dname'...contains necessary functions}
114 |     """
115 |     mu_rec = np.zeros(mus_posterior.shape)
116 |     for t in range(mus_posterior.shape[1]):
117 |         mu_rec[:,t] = params[dname]['obs_fxn'](mus_posterior[:,t])
118 |     return mu_rec
--------------------------------------------------------------------------------
/default.pkl:
--------------------------------------------------------------------------------
1 | (dp1
2 | S'shuffle'
3 | p2
4 | I00
5 | sS'use_nade'
6 | p3
7 | I00
8 | sS'cov_explicit'
9 | p4
10 | I00
11 | sS'dataset'
12 | p5
13 | S''
14 | sS'epochs'
15 | p6
16 | I2000
17 | sS'seed'
18 | p7
19 | I1
20 | sS'init_weight'
21 | p8
22 | F0.10000000000000001
23 | sS'reg_spec'
24 | p9
25 | S'_'
26 | sS'use_prev_input'
27 | p10
28 | I00
29 | sS'reg_value'
30 | p11
31 | F0.050000000000000003
32 | sS'reloadFile'
33 | p12
34 | S'./NOSUCHFILE'
35 | p13
36 | sS'dim_stochastic'
37 | p14
38 | I100
39 | sS'rnn_layers'
40 | p15
41 | I1
42 | sS'transition_layers'
43 | p16
44 | I2
45 | sS'lr'
46 | p17
47 | F0.00080000000000000004
48 | sS'reg_type'
49 | p18
50 | S'l2'
51 | p19
52 | sS'init_scheme'
53 | p20
54 | S'uniform'
55 | p21
56 | sS'replicate_K'
57 | p22
58 | NsS'optimizer'
59 | p23
60 | S'adam'
61 | p24
62 | sS'use_generative_prior'
63 | p25
64 | I00
65 | sS'q_mlp_layers'
66 | p26
67 | I1
68 | sS'maxout_stride'
69 | p27
70 | I4
71 | sS'batch_size'
72 | p28
73 | I20
74 | sS'savedir'
75 | p29
76 | S'./chkpt'
77 | p30
78 | sS'forget_bias'
79 | p31
80 | F-5
81 | sS'inference_model'
82 | p32
83 | S'structured'
84 | p33
85 | sS'emission_layers'
86 | p34
87 | I2
88 | sS'savefreq'
89 | p35
90 | I25
91 | sS'rnn_size'
92 | p36
93 | I600
94 | sS'paramFile'
95 | p37
96 | g13
97 | sS'emission_type'
98 | p38
99 | S'mlp'
100 | p39
101 | sS'nonlinearity'
102 | p40
103 | S'relu'
104 | p41
105 | sS'rnn_dropout'
106 | p42
107 | F0.10000000000000001
108 | sS'dim_hidden'
109 | p43
110 | I200
111 | sS'var_model'
112 | p44
113 | S'lstmr'
114 | p45
115 | sS'anneal_rate'
116 | p46
117 | F10
118 | sS'debug'
119 | p47
120 | I00
121 | sS'validate_only'
122 | p48
123 | I00
124 | sS'transition_type'
125 | p49
126 | S'simple_gated'
127 | p50
128 | sS'unique_id'
129 | p51
130 | S'DKF_lr-8_0000e-04-vm-lstmr-inf-structured-dh-200-ds-100-nl-relu-bs-20-ep-2000-rs-600-rd-1_0000e-01-ttype-simple_gated-etype-mlp-previnp-False-ar-1_0000e+01-rv-5_0000e-02-nade-False-uid'
131 | p52
132 | sS'leaky_param'
133 | p53
134 | F0
135 | s.
--------------------------------------------------------------------------------
/expt-polyphonic-fast/locked_write.py:
--------------------------------------------------------------------------------
1 | import fcntl,errno,time
2 | with open('remove.me','a') as f:
3 |     while True:
4 |         try:
5 |             fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) #non-blocking exclusive lock; raises IOError(EAGAIN) if another process holds it
6 |             break
7 |         except IOError as e:
8 |             if e.errno != errno.EAGAIN:
9 |                 raise
10 |             else:
11 |                 time.sleep(0.1) #lock is held elsewhere; poll until released
12 |     f.write('another line\n')
13 |     fcntl.flock(f, fcntl.LOCK_UN)
14 |
--------------------------------------------------------------------------------
/expt-polyphonic-fast/setup_experiments.py:
--------------------------------------------------------------------------------
1 | """
2 | Rahul G. Krishnan
3 |
4 | Script to setup experiments either on HPC or individually
5 | """
6 | import numpy as np
7 | from collections import OrderedDict
8 | import argparse,os
9 |
10 | parser = argparse.ArgumentParser(description='Setup Expts')
11 | parser.add_argument('-hpc','--onHPC',action='store_true')
12 | parser.add_argument('-dset','--dataset', default='jsb',action='store')
13 | parser.add_argument('-ngpu','--num_gpus', default=4,action='store',type=int)
14 | args = parser.parse_args()
15 |
16 | #MAIN FLAGS
17 | onHPC = args.onHPC
18 | DATASET = args.dataset
19 | THFLAGS = 'THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu<gpu>" ' #<gpu> is a placeholder token substituted below
20 |
21 | #Get dataset
22 | dataset = DATASET.split('-')[0]
23 | all_datasets = ['jsb','piano','nottingham','musedata']
24 | assert dataset in all_datasets,'Dset not found: '+dataset
25 | all_expts = OrderedDict()
26 | for dset in all_datasets:
27 |     all_expts[dset] = OrderedDict()
28 |
29 | #Experiments to run for each dataset; the <dataset> token is substituted with the dataset name below
30 | all_expts['jsb']['ST-R']     = 'python2.7 train_dkf.py -vm R -infm structured -dset <dataset>'
31 | all_expts['jsb']['MF-LR']    = 'python2.7 train_dkf.py -vm LR -infm mean_field -dset <dataset>'
32 | all_expts['jsb']['ST-LR']    = 'python2.7 train_dkf.py -vm LR -infm structured -dset <dataset>'
33 | all_expts['jsb']['ST-R-mlp'] = 'python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset <dataset>'
34 | all_expts['jsb']['ST-L']     = 'python2.7 train_dkf.py -vm L -infm structured -dset <dataset>'
35 | all_expts['jsb']['DKF-ar']   = 'python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset <dataset>'
36 | all_expts['jsb']['DKF-aug']  = 'python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset <dataset>'
37 | all_expts['jsb']['DKF-aug-nade'] = 'python2.7 train_dkf.py -vm R -infm structured -etype conditional -previnp -usenade -dset <dataset>'
38 |
39 | all_expts['nottingham']['ST-R']     = 'python2.7 train_dkf.py -vm R -infm structured -dset <dataset>'
40 | all_expts['nottingham']['MF-LR']    = 'python2.7 train_dkf.py -vm LR -infm mean_field -dset <dataset>'
41 | all_expts['nottingham']['ST-LR']    = 'python2.7 train_dkf.py -vm LR -infm structured -dset <dataset>'
42 | all_expts['nottingham']['ST-R-mlp'] = 'python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset <dataset>'
43 | all_expts['nottingham']['ST-L']     = 'python2.7 train_dkf.py -vm L -infm structured -dset <dataset>'
44 | all_expts['nottingham']['DKF-ar']   = 'python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset <dataset>'
45 | all_expts['nottingham']['DKF-aug']  = 'python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset <dataset>'
46 | all_expts['nottingham']['DKF-aug-nade'] = 'python2.7 train_dkf.py -vm R -infm structured -ar 1000 -etype conditional -previnp -usenade -dset <dataset>'
47 |
48 | all_expts['musedata']['ST-R']     = 'python2.7 train_dkf.py -vm R -infm structured -dset <dataset>'
49 | all_expts['musedata']['MF-LR']    = 'python2.7 train_dkf.py -vm LR -infm mean_field -dset <dataset>'
50 | all_expts['musedata']['ST-LR']    = 'python2.7 train_dkf.py -vm LR -infm structured -dset <dataset>'
51 | all_expts['musedata']['ST-R-mlp'] = 'python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset <dataset>'
52 | all_expts['musedata']['ST-L']     = 'python2.7 train_dkf.py -vm L -infm structured -dset <dataset>'
53 | all_expts['musedata']['DKF-ar']   = 'python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset <dataset>'
54 | all_expts['musedata']['DKF-aug']  = 'python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset <dataset>'
55 | all_expts['musedata']['DKF-aug-nade'] = 'python2.7 train_dkf.py -vm R -infm structured -ar 1000 -etype conditional -previnp -usenade -dset <dataset> -ds 50 -dh 100 -rs 400'
56 |
57 | all_expts['piano']['ST-R']     = 'python2.7 train_dkf.py -vm R -infm structured -dset <dataset>'
58 | all_expts['piano']['MF-LR']    = 'python2.7 train_dkf.py -vm LR -infm mean_field -dset <dataset>'
59 | all_expts['piano']['ST-LR']    = 'python2.7 train_dkf.py -vm LR -infm structured -dset <dataset>'
60 | all_expts['piano']['ST-R-mlp'] = 'python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset <dataset>'
61 | all_expts['piano']['ST-L']     = 'python2.7 train_dkf.py -vm L -infm structured -dset <dataset>'
62 | all_expts['piano']['DKF-ar']   = 'python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset <dataset>'
63 | all_expts['piano']['DKF-aug']  = 'python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset <dataset>'
64 | all_expts['piano']['DKF-aug-nade'] = 'python2.7 train_dkf.py -vm R -infm structured -ar 1000 -etype conditional -previnp -usenade -dset <dataset> -ds 50 -dh 100 -rs 400'
65 |
66 | if onHPC:
67 |     DIR = './hpc_'+dataset
68 |     os.system('rm -rf '+DIR)
69 |     os.system('mkdir -p '+DIR)
70 |     with open('template.q') as ff:
71 |         template = ff.read()
72 |     runallcmd = ''
73 |     for name in all_expts[dataset]:
74 |         runcmd = all_expts[dataset][name].replace('<dataset>',DATASET)+' -uid '+name
75 |         command = THFLAGS.replace('<gpu>',str(np.random.randint(args.num_gpus)))+runcmd
76 |         with open(DIR+'/'+name+'.q','w') as f:
77 |             f.write(template.replace('<name>',name).replace('<command>',command))
78 |         print 'Wrote to:',DIR+'/'+name+'.q'
79 |         runallcmd+= 'qsub '+name+'.q\n'
80 |     with open(DIR+'/runall.sh','w') as f:
81 |         f.write(runallcmd)
82 | else:
83 |     for name in all_expts[dataset]:
84 |         runcmd = all_expts[dataset][name].replace('<dataset>',DATASET)+' -uid '+name
85 |         command = THFLAGS.replace('<gpu>',str(np.random.randint(args.num_gpus)))+runcmd
86 |         print command
87 |
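88 | #Usage sketch (assumed invocations, inferred from the argparse flags above):
89 | #   python2.7 setup_experiments.py -dset jsb          -> print the commands to run locally
90 | #   python2.7 setup_experiments.py -hpc -dset jsb     -> write hpc_jsb/*.q and a runall.sh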
--------------------------------------------------------------------------------
/expt-polyphonic-fast/template.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=30:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N <name>
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.5v5.1
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic-fast
19 | cd $RUNDIR
20 |
21 |
--------------------------------------------------------------------------------
/expt-polyphonic-fast/train_dkf.py:
--------------------------------------------------------------------------------
1 | import os,time,sys
2 | import fcntl,errno
3 | import socket
4 | sys.path.append('../')
5 | from datasets.load import loadDataset
6 | from parse_args_dkf import params
7 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime,getLowestError
8 |
9 | if params['dataset']=='':
10 |     params['dataset']='jsb'
11 | dataset = loadDataset(params['dataset'])
12 | params['savedir']+='-'+params['dataset']
13 | createIfAbsent(params['savedir'])
14 |
15 | #Saving/loading
16 | for k in ['dim_observations','dim_actions','data_type']:
17 |     params[k] = dataset[k]
18 | mapPrint('Options: ',params)
19 |
20 | if params['use_nade']:
21 |     params['data_type']='binary_nade'
22 | #Setup VAE Model (or reload from existing savefile)
23 | start_time = time.time()
24 | from stinfmodel_fast.dkf import DKF
25 | import stinfmodel_fast.learning as DKF_learn
26 | import stinfmodel_fast.evaluate as DKF_evaluate
27 | displayTime('import DKF',start_time, time.time())
28 | dkf = None
29 |
30 |
31 | #Remove from params
32 | start_time = time.time()
33 | removeIfExists('./NOSUCHFILE')
34 | reloadFile = params.pop('reloadFile')
35 | if os.path.exists(reloadFile):
36 |     pfile=params.pop('paramFile')
37 |     assert os.path.exists(pfile),pfile+' not found. Need paramfile'
38 |     print 'Reloading trained model from : ',reloadFile
39 |     print 'Assuming ',pfile,' corresponds to model'
40 |     dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile)
41 | else:
42 |     pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl'
43 |     print 'Training model from scratch. Parameters in: ',pfile
44 |     dkf = DKF(params, paramFile = pfile)
45 | displayTime('Building dkf',start_time, time.time())
46 |
47 | savef = os.path.join(params['savedir'],params['unique_id'])
48 | print 'Savefile: ',savef
49 | start_time= time.time()
50 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
51 |                            epoch_start = 0,
52 |                            epoch_end   = params['epochs'],
53 |                            batch_size  = params['batch_size'],
54 |                            savefreq    = params['savefreq'],
55 |                            savefile    = savef,
56 |                            dataset_eval= dataset['valid'],
57 |                            mask_eval   = dataset['mask_valid'],
58 |                            replicate_K = params['replicate_K'],
59 |                            shuffle     = False
60 |                            )
61 | displayTime('Running DKF',start_time, time.time() )
62 |
63 | dkf = None
64 | """
65 | Load the best DKF based on the validation error
66 | """
67 | epochMin, valMin, idxMin = getLowestError(savedata['valid_bound'])
68 | reloadFile= pfile.replace('-config.pkl','')+'-EP'+str(int(epochMin))+'-params.npz'
69 | print 'Loading from : ',reloadFile
70 | params['validate_only'] = True
71 | dkf_best = DKF(params, paramFile = pfile, reloadFile = reloadFile)
72 | additional = {}
73 | savedata['bound_test_best'] = DKF_evaluate.evaluateBound(dkf_best, dataset['test'], dataset['mask_test'], S = 2, batch_size = params['batch_size'], additional =additional)
74 | savedata['bound_tsbn_test_best'] = additional['tsbn_bound']
75 | savedata['ll_test_best'] = DKF_evaluate.impSamplingNLL(dkf_best, dataset['test'], dataset['mask_test'], S = 2000, batch_size = params['batch_size'])
76 | saveHDF5(savef+'-final.h5',savedata)
77 | print 'Experiment Name: <',params['expt_name'],'> Test Bound: ',savedata['bound_test_best'],' ',savedata['bound_tsbn_test_best'],' ',savedata['ll_test_best']
78 |
79 | with open(params['dataset']+'-results.txt','a') as f:
80 |     while True:
81 |         try:
82 |             fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
83 |             break
84 |         except IOError as e:
85 |             if e.errno != errno.EAGAIN:
86 |                 raise
87 |             else:
88 |                 time.sleep(0.1)
89 |     f.write('Experiment Name: <'+params['expt_name']+'> Test Bound: '+str(savedata['bound_test_best'])+' '+str(savedata['bound_tsbn_test_best'])+' '+str(savedata['ll_test_best'])+'\n')
90 |     fcntl.flock(f, fcntl.LOCK_UN)
91 |
92 | if 'nyu.edu' in socket.gethostname():
93 |     import ipdb;ipdb.set_trace()
94 |
--------------------------------------------------------------------------------
/expt-polyphonic/README.md:
--------------------------------------------------------------------------------
1 | ## Modeling Polyphonic Music with a Deep Kalman Filter
2 | Author: Rahul G. Krishnan (rahul@cs.nyu.edu)
3 |
4 | ## Reproducing Numbers in the Paper
5 | ```
6 | Run the shell script for the corresponding dataset (jsb_expts.sh, musedata_expt.sh, nott_expt.sh, piano_expt.sh) to run the experiments on the polyphonic datasets
7 | ```
8 |
9 | ## Model Details
10 | DKF
11 | Standard Deep Kalman Filter. The default inference algorithm is ST-R (see the paper for details), although
12 | this can be modified through a variety of knobs, primarily the -inference_model and -var_model hyperparameters.
13 |
14 | DKF NADE
15 | Uses a NADE to model the data rather than a distribution that treats the dimensions of the data independently (enabled with the -usenade flag).
16 |
17 | DKF AUG
18 | Augmented DKF.
19 | The emission distribution of the generative model is parameterized as p(x_t | x_{t-1}, z_t) (activated with the -etype conditional flag,
20 | with the previous observation fed in via the -previnp flag); the transition distribution remains p(z_t | z_{t-1}).
21 |
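22 | ## Example
23 | For instance, the DKF AUG w/ NADE variant on the JSB chorales (taken verbatim from jsb_expts.sh):
24 | ```
25 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset jsb-sorted
26 | ```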
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/mdata_MFLR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N mdata-MFLR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=0.9,scan.allow_gc=False,compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-1" python train_dkf.py -vm LR -infm mean_field -dset musedata-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/mdata_STLR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N mdata-STLR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset musedata-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/mdata_STR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=30:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N mdata-ST-R
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset musedata-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/mdata_STR_MLP.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N mdata-STR-MLP
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset musedata-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/mdata_STR_ar.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N mdata-AR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset musedata-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/mdata_STR_ar_aug.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N mdata-ARAUG
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset musedata-sorted -bs 10 -dh 100 -ds 50
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/mdata_STR_ar_aug_nade.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N mdata-STRAUG
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset musedata-sorted -bs 10 -dh 100 -ds 50
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/nott_MFLR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N nott-MFLR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=0.9,scan.allow_gc=False,compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-1" python train_dkf.py -vm LR -infm mean_field -dset nottingham-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/nott_STLR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N nott-STLR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset nottingham-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/nott_STR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=30:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N nott-ST-R
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset nottingham-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/nott_STR_MLP.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N nott-STR-MLP
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset nottingham-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/nott_STR_ar.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N nott-AR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset nottingham-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/nott_STR_ar_aug.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N nott-ARAUG
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset nottingham-sorted -bs 10 -dh 100 -ds 50
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/nott_STR_ar_aug_nade.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N nott-STRAUG
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset nottingham-sorted -bs 10 -dh 100 -ds 50
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/piano_MFLR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N piano-MFLR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=0.9,scan.allow_gc=False,compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-1" python train_dkf.py -vm LR -infm mean_field -dset piano-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/piano_STLR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N piano-STLR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset piano-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/piano_STR.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=30:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N piano-ST-R
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset piano-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/piano_STR_MLP.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N piano-STR-MLP
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset piano-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/piano_STR_ar.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N piano-AR
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset piano-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/piano_STR_ar_aug.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N piano-ARAUG
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset piano-sorted -bs 10 -dh 100 -ds 50
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/hpc_uas1/piano_STR_ar_aug_nade.q:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l nodes=1:ppn=2:gpus=1:k80
3 | #PBS -l walltime=24:00:00
4 | #PBS -l mem=16GB
5 | #PBS -N piano-STRAUG
6 | #PBS -M rahul@cs.nyu.edu
7 | #PBS -j oe
8 |
9 | module purge
10 | module load node
11 | module load cmake
12 | module load python/intel/2.7.6
13 | module load numpy/intel/1.9.2
14 | module load hdf5/intel/1.8.12
15 | module load cuda/7.5.18
16 | module load cudnn/7.0
17 |
18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic
19 | cd $RUNDIR
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset piano-sorted -bs 10 -dh 100 -ds 50
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/jsb_expts.sh:
--------------------------------------------------------------------------------
1 | #MF-LR
2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm LR -infm mean_field -dset jsb-sorted
3 | #7.098 7.042 -6.6828
4 | #Look into this -> makes sense? Rerunning on rose2[0]
5 |
6 | #ST-R
7 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset jsb-sorted
8 | #7.0145 6.9564 -6.5863
9 |
10 | #ST-R -ttype mlp
11 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset jsb-sorted
12 | #7.0976 7.0421 -6.652
13 |
14 | #ST-LR
15 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset jsb-sorted
16 | #
17 |
18 | #DKF w/ AR
19 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset jsb-sorted
20 | #6.904 6.854 -6.4320
21 |
22 | #DKF Aug
23 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset jsb-sorted
24 | #6.867 6.7935 -6.3199
25 |
26 | #DKF Aug w/ NADE
27 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset jsb-sorted
28 | #5.398 5.325 -5.3809
29 |
--------------------------------------------------------------------------------
/expt-polyphonic/musedata_expt.sh:
--------------------------------------------------------------------------------
1 | #MF-LR
2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm LR -infm mean_field -dset musedata-sorted
3 |
4 | #ST-R
5 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset musedata-sorted
6 |
7 | #ST-R -ttype mlp
8 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset musedata-sorted
9 |
10 | #ST-LR
11 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset musedata-sorted
12 |
13 | #DKF w/ AR
14 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset musedata-sorted
15 |
16 | #DKF Aug
17 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset musedata-sorted
18 |
19 | #DKF Aug w/ NADE
20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset musedata-sorted
21 |
--------------------------------------------------------------------------------
/expt-polyphonic/nott_expt.sh:
--------------------------------------------------------------------------------
1 | #MF-LR
2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm LR -infm mean_field -dset nottingham-sorted
3 | #7.098 7.042 -6.6828
4 | #Look into this -> makes sense? Rerunning on rose2[0]
5 |
6 | #ST-R
7 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset nottingham-sorted
8 | #7.0145 6.9564 -6.5863
9 |
10 | #ST-R -ttype mlp
11 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset nottingham-sorted
12 | #7.0976 7.0421 -6.652
13 |
14 | #ST-LR
15 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset nottingham-sorted
16 | #
17 |
18 | #DKF w/ AR
19 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset nottingham-sorted
20 | #6.904 6.854 -6.4320
21 |
22 | #DKF Aug
23 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset nottingham-sorted
24 | #6.867 6.7935 -6.3199
25 |
26 | #DKF Aug w/ NADE
27 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset nottingham-sorted
28 | #5.398 5.325 -5.3809
29 |
--------------------------------------------------------------------------------
/expt-polyphonic/piano_expt.sh:
--------------------------------------------------------------------------------
1 | #MF-LR
2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm LR -infm mean_field -dset piano-sorted
3 | #7.098 7.042 -6.6828
4 | #Look into this -> makes sense? Rerunning on rose2[0]
5 |
6 | #ST-R
7 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset piano-sorted
8 | #7.0145 6.9564 -6.5863
9 |
10 | #ST-R -ttype mlp
11 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset piano-sorted
12 | #7.0976 7.0421 -6.652
13 |
14 | #ST-LR
15 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset piano-sorted
16 | #
17 |
18 | #DKF w/ AR
19 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset piano-sorted
20 | #6.904 6.854 -6.4320
21 |
22 | #DKF Aug
23 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset piano-sorted
24 | #6.867 6.7935 -6.3199
25 |
26 | #DKF Aug w/ NADE
27 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset piano-sorted
28 | #5.398 5.325 -5.3809
29 |
--------------------------------------------------------------------------------
/expt-polyphonic/train_dkf.py:
--------------------------------------------------------------------------------
1 | import os,time,sys
2 | sys.path.append('../')
3 | from datasets.load import loadDataset
4 | from parse_args_dkf import params
5 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime,getLowestError
6 |
7 | if params['dataset']=='':
8 | params['dataset']='jsb'
9 | dataset = loadDataset(params['dataset'])
10 | params['savedir']+='-'+params['dataset']
11 | createIfAbsent(params['savedir'])
12 |
13 | #Add dataset parameters (dimensions, data type) to params
14 | for k in ['dim_observations','dim_actions','data_type']:
15 | params[k] = dataset[k]
16 | mapPrint('Options: ',params)
17 |
18 | if params['use_nade']:
19 | params['data_type']='binary_nade'
20 | #Setup VAE Model (or reload from existing savefile)
21 | start_time = time.time()
22 | from stinfmodel.dkf import DKF
23 | import stinfmodel.learning as DKF_learn
24 | import stinfmodel.evaluate as DKF_evaluate
25 | displayTime('import DKF',start_time, time.time())
26 | dkf = None
27 |
28 | #Pop the reload file path from params before building the model
29 | start_time = time.time()
30 | removeIfExists('./NOSUCHFILE')
31 | reloadFile = params.pop('reloadFile')
32 | if os.path.exists(reloadFile):
33 | pfile=params.pop('paramFile')
34 | assert os.path.exists(pfile),pfile+' not found. Need paramfile'
35 | print 'Reloading trained model from : ',reloadFile
36 | print 'Assuming ',pfile,' corresponds to model'
37 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile)
38 | else:
39 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl'
40 | print 'Training model from scratch. Parameters in: ',pfile
41 | dkf = DKF(params, paramFile = pfile)
42 | displayTime('Building dkf',start_time, time.time())
43 |
44 | savef = os.path.join(params['savedir'],params['unique_id'])
45 | print 'Savefile: ',savef
46 | start_time= time.time()
47 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
48 | epoch_start =0 ,
49 | epoch_end = params['epochs'],
50 | batch_size = params['batch_size'],
51 | savefreq = params['savefreq'],
52 | savefile = savef,
53 | dataset_eval=dataset['valid'],
54 | mask_eval = dataset['mask_valid'],
55 | replicate_K= params['replicate_K'],
56 | shuffle = False
57 | )
58 | displayTime('Running DKF',start_time, time.time() )
59 | """
60 | Load the best DKF based on the validation error
61 | """
62 | epochMin, valMin, idxMin = getLowestError(savedata['valid_bound'])
63 | reloadFile= pfile.replace('-config.pkl','')+'-EP'+str(int(epochMin))+'-params.npz'
64 | print 'Loading from : ',reloadFile
65 | params['validate_only'] = True
66 | dkf_best = DKF(params, paramFile = pfile, reloadFile = reloadFile)
67 | additional = {}
68 | savedata['bound_test_best'] = DKF_evaluate.evaluateBound(dkf_best, dataset['test'], dataset['mask_test'], S = 2, batch_size = params['batch_size'], additional =additional)
69 | savedata['bound_tsbn_test_best'] = additional['tsbn_bound']
70 | savedata['ll_test_best'] = DKF_evaluate.impSamplingNLL(dkf_best, dataset['test'], dataset['mask_test'], S = 2000, batch_size = params['batch_size'])
71 | saveHDF5(savef+'-final.h5',savedata)
72 | print 'Test Bound: ',savedata['bound_test_best'],savedata['bound_tsbn_test_best'],savedata['ll_test_best']
73 | import ipdb;ipdb.set_trace() #drop into a debugger for interactive inspection of the results
74 |
--------------------------------------------------------------------------------
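After training, train_dkf.py reloads the checkpoint whose validation bound was lowest and re-evaluates it on the test set. The real getLowestError comes from theanomodels' utils.misc; the sketch below only illustrates the selection step, under the assumption that valid_bound holds one (epoch, bound) row per save:

```
import numpy as np

def get_lowest_error(valid_bound):
    #valid_bound: array of shape (num_saves, 2), one [epoch, bound] row per save
    vb = np.asarray(valid_bound)
    idx_min = int(vb[:, 1].argmin())   #save with the best (lowest) validation bound
    return vb[idx_min, 0], vb[idx_min, 1], idx_min

#The chosen checkpoint is then reloaded from '<prefix>-EP<epoch>-params.npz',
#mirroring how train_dkf.py builds reloadFile from pfile and epochMin.
```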
/expt-synthetic-fast/create_expt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | import itertools
5 | import argparse
6 | parser = argparse.ArgumentParser(description='Create experiments for synthetic data')
7 | parser.add_argument("-s",'--onScreen', type=bool, default=False,help ='create command to run on screen')
8 | args = parser.parse_args()
9 | np.random.seed(1)
10 | ## Standard checks
11 | var_model = ['LR','L','R']
12 | inference_model = ['mean_field','structured']
13 | optimization = ['adam']
14 | lr = ['0.0008']
15 | nonlinearity = ['relu']
16 | datasets = ['synthetic9','synthetic10']
17 | dim_hidden = [40]
18 | rnn_layers = [1]
19 | rnn_size = [40]
20 | batch_size = [250]
21 | rnn_dropout = [0.00001]
22 | #Add the list and its name here
23 | takecrossprod = [inference_model,datasets,optimization, lr, var_model, nonlinearity, dim_hidden, rnn_layers, rnn_size, batch_size, rnn_dropout]
24 | names = ['infm','dset','opt','lr','vm','nl','dh','rl','rs','bs','rd']
25 |
26 | def buildConfig(element,paramnames):
27 | config = {}
28 | for idx,paramvalue in enumerate(element):
29 | config[paramnames[idx]] = paramvalue
30 | return config
31 | def buildParamString(config):
32 | paramstr = ''
33 | for paramname in config:
34 | paramstr += '-'+paramname+' '+str(config[paramname])+' '
35 | return paramstr
36 |
37 | tovis = {}
38 | for idx, element in enumerate(itertools.product(*takecrossprod)):
39 | config = buildConfig(element,names)
40 | #Fixed parameters
41 | config['ep'] = '1500'
42 | config['sfreq'] = '100'
43 | gpun = 1
44 | if idx%2==0:
45 | gpun = 2
46 | if 'lr' not in config:
47 | config['lr']='0.001'
48 | if 'opt' not in config:
49 | config['opt']='adam'
50 |
51 | paramstr= buildParamString(config)
52 | assert 'dset' in config and 'vm' in config and 'infm' in config, 'Expecting dataset, var_model and inference_model to be in config'
53 | #Savefile to look for in visualize
54 | savestr = 'dkf_'+config['dset']+'_ep'+config['ep']+'_rs20dh20ds1vm'+config['vm']+'lr'+config['lr']
55 | savestr+= 'ep'+config['ep']+'opt'+config['opt']
56 | savestr +='infm'+config['infm']+'synthetic'
57 | tovis[paramstr.replace(' ','')] = savestr
58 |
59 | cmd = "export CUDA_VISIBLE_DEVICES="+str(gpun-1)+";"
60 | if args.onScreen:
61 | randStr = 'THEANO_FLAGS="lib.cnmem=0.4,compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-'+str(np.random.randint(100))+'"'
62 | cmd += randStr+ " python2.7 train.py "+paramstr
63 | name = paramstr.replace(' ','').strip().replace('inference_model','').replace('var_model','').replace('optimization','')
64 | name = name.replace('-','')
65 | name = name.replace('synthetic','S').replace('nonlinearity','NL').replace('dataset','')
66 | print "screen -S "+name+" -t "+name+" -d -m"
67 | print "screen -r "+name+" -p 0 -X stuff $\'"+cmd+"\\n\'"
68 | #+"| tee ./checkpointdir_"+config['dataset']+'/'+savestr.replace('t7','log')+"\\n\'"
69 | else:
70 | cmd += "python2.7 train.py "+paramstr
71 | print cmd
72 |
--------------------------------------------------------------------------------
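create_expt.py expands the cross product of every hyperparameter list into one command line per configuration. A self-contained miniature of that expansion, using a reduced grid for illustration:

```
import itertools

#Reduced grid for illustration; the real script crosses eleven lists.
grid = [('infm', ['mean_field', 'structured']),
        ('vm',   ['LR', 'L', 'R']),
        ('dset', ['synthetic9', 'synthetic10'])]
names  = [n for n, _ in grid]
values = [v for _, v in grid]

for element in itertools.product(*values):
    config = dict(zip(names, element))
    paramstr = ' '.join('-{0} {1}'.format(k, config[k]) for k in names)
    print('python2.7 train.py ' + paramstr)
#2 x 3 x 2 = 12 commands; the real script additionally alternates GPUs
#and emits screen(1) launch lines when --onScreen is set.
```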
/expt-synthetic-fast/run_baselines.py:
--------------------------------------------------------------------------------
1 | import sys,os,h5py,glob
2 | sys.path.append('../')
3 | from baselines.filters import runFilter
4 | import numpy as np
5 | from utils.misc import getPYDIR
6 | from datasets.synthp import params_synthetic
7 |
8 | def runBaselines(DIR, name):
9 | DATADIR = getPYDIR()+'/datasets/synthetic'
10 | assert os.path.exists(DATADIR),DATADIR+' not found. must have this to run baselines'
11 | if not os.path.exists(DIR):
12 | os.mkdir(DIR)
13 |
14 | for f in glob.glob(DATADIR+'/*.h5'):
15 | dataset = os.path.basename(f).replace('.h5','')
16 | if name not in dataset: #skip files whose name does not contain the requested dataset substring
17 | continue
18 | print dataset,f
19 | if os.path.exists(DIR+'/'+dataset+'-baseline.h5'):
20 | print DIR+'/'+dataset+'-baseline.h5',' found....not rerunning baseline'
21 | continue
22 | print 'Reading from: ',f,' Saving to: ',DIR+'/'+dataset+'-baseline.h5'
23 |
24 | filterType = params_synthetic[dataset]['baseline']
25 | h5fout = h5py.File(DIR+'/'+dataset+'-baseline.h5',mode='w')
26 | h5f = h5py.File(f,mode='r')
27 |
28 | if int(dataset.split('synthetic')[1]) in [9,10,11]:
29 | print 'Running filter: ',filterType,' on train'
30 | X = h5f['train'].value
31 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType)
32 | h5fout.create_dataset('train_mu',data = mus)
33 | h5fout.create_dataset('train_cov',data = cov)
34 | h5fout.create_dataset('train_ll',data = np.array([ll]))
35 | rmse = np.sqrt(np.square(mus-h5f['train_z'].value.squeeze()).mean())
36 | h5fout.create_dataset('train_rmse',data = np.array([rmse]))
37 |
38 | #Always run exact inference on the validation set
39 | print 'Running filter: ',filterType,' on valid'
40 | X = h5f['valid'].value
41 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType)
42 | h5fout.create_dataset('valid_mu',data = mus)
43 | h5fout.create_dataset('valid_cov',data = cov)
44 | h5fout.create_dataset('valid_ll',data = np.array([ll]))
45 | rmse = np.sqrt(np.square(mus-h5f['valid_z'].value.squeeze()).mean())
46 | h5fout.create_dataset('valid_rmse',data = np.array([rmse]))
47 |
48 | if int(dataset.split('synthetic')[1]) in [9,10,11]:
49 | print 'Running filter: ',filterType,' on test'
50 | X = h5f['test'].value
51 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType)
52 | h5fout.create_dataset('test_mu',data = mus)
53 | h5fout.create_dataset('test_cov',data = cov)
54 | h5fout.create_dataset('test_ll',data = np.array([ll]))
55 | rmse = np.sqrt(np.square(mus-h5f['test_z'].value.squeeze()).mean())
56 | h5fout.create_dataset('test_rmse',data = np.array([rmse]))
57 |
58 | h5f.close()
59 | h5fout.close()
60 | if __name__=='__main__':
61 | assert len(sys.argv)==2,'expecting a dataset name substring as the only argument'
62 | runBaselines('./baselines',sys.argv[-1].strip())
63 |
--------------------------------------------------------------------------------
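For every split, run_baselines.py stores the exact filter's posterior means, covariances and log-likelihood, together with the RMSE between the inferred means and the ground-truth latents stored in the synthetic h5 files. The metric on its own, as a sketch with illustrative shapes and names:

```
import numpy as np

def posterior_rmse(mus, true_z):
    #mus:    (N, T) posterior means produced by the filter
    #true_z: (N, T, 1) ground-truth latents from the synthetic h5 file
    return np.sqrt(np.square(mus - true_z.squeeze()).mean())

#Toy check: zero means against standard-normal latents give RMSE near 1.
mus    = np.zeros((500, 25))
true_z = np.random.randn(500, 25, 1)
print(posterior_rmse(mus, true_z))
```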
/expt-synthetic-fast/train.py:
--------------------------------------------------------------------------------
1 | import os,time,sys
2 | sys.path.append('../')
3 | import numpy as np
4 | from datasets.load import loadDataset
5 | from parse_args_dkf import params
6 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime
7 |
8 |
9 | if params['dataset']=='':
10 | params['dataset']='synthetic9'
11 | dataset = loadDataset(params['dataset'])
12 |
13 | dataset['train'] = dataset['train'][:params['ntrain']]
14 | params['savedir']+='-'+params['dataset']
15 | createIfAbsent(params['savedir'])
16 |
17 | #Add dataset parameters (dimensions, data type) to params
18 | for k in ['dim_observations','dim_actions','data_type', 'dim_stochastic']:
19 | params[k] = dataset[k]
20 | mapPrint('Options: ',params)
21 |
22 | #Setup VAE Model (or reload from existing savefile)
23 | start_time = time.time()
24 | from stinfmodel_fast.dkf import DKF
25 | import stinfmodel_fast.evaluate as DKF_evaluate
26 | import stinfmodel_fast.learning as DKF_learn
27 | displayTime('import DKF',start_time, time.time())
28 | dkf = None
29 |
30 | #Pop the reload file path from params before building the model
31 | start_time = time.time()
32 | removeIfExists('./NOSUCHFILE')
33 | reloadFile = params.pop('reloadFile')
34 | if os.path.exists(reloadFile):
35 | pfile=params.pop('paramFile')
36 | assert os.path.exists(pfile),pfile+' not found. Need paramfile'
37 | print 'Reloading trained model from : ',reloadFile
38 | print 'Assuming ',pfile,' corresponds to model'
39 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile)
40 | else:
41 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl'
42 | print 'Training model from scratch. Parameters in: ',pfile
43 | dkf = DKF(params, paramFile = pfile)
44 | displayTime('Building dkf',start_time, time.time())
45 |
46 |
47 | savef = os.path.join(params['savedir'],params['unique_id'])
48 | print 'Savefile: ',savef
49 | start_time= time.time()
50 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
51 | epoch_start =0 ,
52 | epoch_end = params['epochs'],
53 | batch_size = params['batch_size'],
54 | savefreq = params['savefreq'],
55 | savefile = savef,
56 | dataset_eval=dataset['valid'],
57 | mask_eval = dataset['mask_valid'],
58 | replicate_K = 5
59 | )
60 | displayTime('Running DKF',start_time, time.time())
61 | #Save the training log to file
62 | saveHDF5(savef+'-final.h5',savedata)
63 | #import ipdb;ipdb.set_trace()
64 |
--------------------------------------------------------------------------------
/expt-synthetic/README.md:
--------------------------------------------------------------------------------
1 | ## README for expt-synthetic
2 | Author: Rahul G. Krishnan (rahul@cs.nyu.edu)
3 |
4 | ## Reproducing results from the paper
5 | This folder contains the scripts to reproduce the synthetic experiments.
6 |
7 | Steps:
8 |
9 | ```
10 | Run "python create_expt.py". This will yield a list of settings to run. Run them sequentially/in parallel
11 | Run the ipython notebook in structuredinference/ipynb/synthetic to obtain the desired plots
12 | ```
13 | ## Synthetic Datasets
14 |
15 | The synthetic datasets are located in the theanomodels repository. See theanomodels/datasets/synthp.py
16 | * The file defines a dictionary, one entry for every synthetic dataset
17 | * Each dataset's parameters are contained within sub-dictionaries. These include the initial mean/covariance for the transition distribution, the fixed covariance for the emission distribution, as well as the emission and transition functions. (params_synthetic['synthetic9'] is a dict with keys such as trans_fxn, emis_fxn and init_cov; trans_fxn and emis_fxn are pointers to functions.)
18 | * At model creation, this dictionary is embedded into the DKF and used to create the transition and emission functions for the generative model. The DKF directly uses the Theano implementations of the transition and emission functions in theanomodels/datasets/synthpTheano.py (a sketch of the dictionary's structure is given below).
19 |
--------------------------------------------------------------------------------
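A sketch of the dictionary structure the README describes, with placeholder linear functions. Only trans_fxn, emis_fxn, init_cov and baseline are confirmed key names (the last is read by run_baselines.py); the remaining keys and all values here are illustrative, and the real entries live in theanomodels/datasets/synthp.py with Theano counterparts in synthpTheano.py:

```
#Hypothetical entry mirroring the structure of params_synthetic['synthetic9']
params_synthetic = {
    'synthetic9': {
        'init_mu':   0.0,                #mean of p(z_1)            (key name assumed)
        'init_cov':  1.0,                #covariance of p(z_1)
        'obs_cov':   0.1,                #fixed emission covariance (key name assumed)
        'trans_fxn': lambda z: 0.9 * z,  #z_t = trans_fxn(z_{t-1}) + noise
        'emis_fxn':  lambda z: 2.0 * z,  #x_t = emis_fxn(z_t) + noise
        'baseline':  'KF',               #exact filter used by run_baselines.py
    }
}
```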
/expt-synthetic/create_expt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | import itertools
5 | import argparse
6 | parser = argparse.ArgumentParser(description='Create experiments for synthetic data')
7 | parser.add_argument("-s",'--onScreen', type=bool, default=False,help ='create command to run on screen')
8 | args = parser.parse_args()
9 | np.random.seed(1)
10 | ## Standard checks
11 | var_model = ['LR','L','R']
12 | inference_model = ['mean_field','structured']
13 | optimization = ['adam']
14 | lr = ['0.0008']
15 | nonlinearity = ['relu']
16 | datasets = ['synthetic9','synthetic10']
17 | dim_hidden = [40]
18 | rnn_layers = [1]
19 | rnn_size = [40]
20 | batch_size = [250]
21 | rnn_dropout = [0.00001]
22 | #Add the list and its name here
23 | takecrossprod = [inference_model,datasets,optimization, lr, var_model, nonlinearity, dim_hidden, rnn_layers, rnn_size, batch_size, rnn_dropout]
24 | names = ['infm','dset','opt','lr','vm','nl','dh','rl','rs','bs','rd']
25 |
26 | def buildConfig(element,paramnames):
27 | config = {}
28 | for idx,paramvalue in enumerate(element):
29 | config[paramnames[idx]] = paramvalue
30 | return config
31 | def buildParamString(config):
32 | paramstr = ''
33 | for paramname in config:
34 | paramstr += '-'+paramname+' '+str(config[paramname])+' '
35 | return paramstr
36 |
37 | tovis = {}
38 | for idx, element in enumerate(itertools.product(*takecrossprod)):
39 | config = buildConfig(element,names)
40 | #Fixed parameters
41 | config['ep'] = '1500'
42 | config['sfreq'] = '100'
43 | gpun = 1
44 | if idx%2==0:
45 | gpun = 2
46 | if 'lr' not in config:
47 | config['lr']='0.001'
48 | if 'opt' not in config:
49 | config['opt']='adam'
50 |
51 | paramstr= buildParamString(config)
52 | assert 'dset' in config and 'vm' in config and 'infm' in config, 'Expecting dataset, var_model and inference_model to be in config'
53 | #Savefile to look for in visualize
54 | savestr = 'dkf_'+config['dset']+'_ep'+config['ep']+'_rs20dh20ds1vm'+config['vm']+'lr'+config['lr']
55 | savestr+= 'ep'+config['ep']+'opt'+config['opt']
56 | savestr +='infm'+config['infm']+'synthetic'
57 | tovis[paramstr.replace(' ','')] = savestr
58 |
59 | cmd = "export CUDA_VISIBLE_DEVICES="+str(gpun-1)+";"
60 | if args.onScreen:
61 | randStr = 'THEANO_FLAGS="lib.cnmem=0.4,compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-'+str(np.random.randint(100))+'"'
62 | cmd += randStr+ " python2.7 train.py "+paramstr
63 | name = paramstr.replace(' ','').strip().replace('inference_model','').replace('var_model','').replace('optimization','')
64 | name = name.replace('-','')
65 | name = name.replace('synthetic','S').replace('nonlinearity','NL').replace('dataset','')
66 | print "screen -S "+name+" -t "+name+" -d -m"
67 | print "screen -r "+name+" -p 0 -X stuff $\'"+cmd+"\\n\'"
68 | #+"| tee ./checkpointdir_"+config['dataset']+'/'+savestr.replace('t7','log')+"\\n\'"
69 | else:
70 | cmd += "python2.7 train.py "+paramstr
71 | print cmd
72 |
--------------------------------------------------------------------------------
/expt-synthetic/run_baselines.py:
--------------------------------------------------------------------------------
1 | import sys,os,h5py,glob
2 | sys.path.append('../')
3 | from baselines.filters import runFilter
4 | import numpy as np
5 | from utils.misc import getPYDIR
6 | from datasets.synthp import params_synthetic
7 |
8 | def runBaselines(DIR):
9 | DATADIR = getPYDIR()+'/datasets/synthetic'
10 | assert os.path.exists(DATADIR),DATADIR+' not found. must have this to run baselines'
11 | if not os.path.exists(DIR):
12 | os.mkdir(DIR)
13 |
14 | for f in glob.glob(DATADIR+'/*.h5'):
15 | dataset = os.path.basename(f).replace('.h5','')
16 | print 'Reading from: ',f,' Saving to: ',DIR+'/'+dataset+'-baseline.h5'
17 |
18 | filterType = params_synthetic[dataset]['baseline']
19 | h5fout = h5py.File(DIR+'/'+dataset+'-baseline.h5',mode='w')
20 | h5f = h5py.File(f,mode='r')
21 |
22 | print 'Running filter: ',filterType,' on train'
23 | X = h5f['train'].value
24 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType)
25 | h5fout.create_dataset('train_mu',data = mus)
26 | h5fout.create_dataset('train_cov',data = cov)
27 | h5fout.create_dataset('train_ll',data = np.array([ll]))
28 | rmse = np.sqrt(np.square(mus-h5f['train_z'].value.squeeze()).mean())
29 | h5fout.create_dataset('train_rmse',data = np.array([rmse]))
30 |
31 | print 'Running filter: ',filterType,' on valid'
32 | X = h5f['valid'].value
33 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType)
34 | h5fout.create_dataset('valid_mu',data = mus)
35 | h5fout.create_dataset('valid_cov',data = cov)
36 | h5fout.create_dataset('valid_ll',data = np.array([ll]))
37 | rmse = np.sqrt(np.square(mus-h5f['valid_z'].value.squeeze()).mean())
38 | h5fout.create_dataset('valid_rmse',data = np.array([rmse]))
39 |
40 | print 'Running filter: ',filterType,' on test'
41 | X = h5f['test'].value
42 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType)
43 | h5fout.create_dataset('test_mu',data = mus)
44 | h5fout.create_dataset('test_cov',data = cov)
45 | h5fout.create_dataset('test_ll',data = np.array([ll]))
46 | rmse = np.sqrt(np.square(mus-h5f['test_z'].value.squeeze()).mean())
47 | h5fout.create_dataset('test_rmse',data = np.array([rmse]))
48 |
49 | h5f.close()
50 | h5fout.close()
51 | if __name__=='__main__':
52 | runBaselines('./baselines')
53 |
--------------------------------------------------------------------------------
/expt-synthetic/train.py:
--------------------------------------------------------------------------------
1 | import os,time,sys
2 | sys.path.append('../')
3 | import numpy as np
4 | from datasets.load import loadDataset
5 | from parse_args_dkf import params
6 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime
7 |
8 | params['dim_stochastic'] = 1
9 |
10 | if params['dataset']=='':
11 | params['dataset']='synthetic9'
12 | dataset = loadDataset(params['dataset'])
13 | params['savedir']+='-'+params['dataset']
14 | createIfAbsent(params['savedir'])
15 |
16 | #Add dataset parameters (dimensions, data type) to params
17 | for k in ['dim_observations','dim_actions','data_type']:
18 | params[k] = dataset[k]
19 | mapPrint('Options: ',params)
20 |
21 |
22 | #Setup VAE Model (or reload from existing savefile)
23 | start_time = time.time()
24 | from stinfmodel.dkf import DKF
25 | import stinfmodel.evaluate as DKF_evaluate
26 | import stinfmodel.learning as DKF_learn
27 | displayTime('import DKF',start_time, time.time())
28 | dkf = None
29 |
30 | #Remove from params
31 | start_time = time.time()
32 | removeIfExists('./NOSUCHFILE')
33 | reloadFile = params.pop('reloadFile')
34 | if os.path.exists(reloadFile):
35 | pfile=params.pop('paramFile')
36 | assert os.path.exists(pfile),pfile+' not found. Need paramfile'
37 | print 'Reloading trained model from : ',reloadFile
38 | print 'Assuming ',pfile,' corresponds to model'
39 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile)
40 | else:
41 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl'
42 | print 'Training model from scratch. Parameters in: ',pfile
43 | dkf = DKF(params, paramFile = pfile)
44 | displayTime('Building dkf',start_time, time.time())
45 |
46 |
47 | savef = os.path.join(params['savedir'],params['unique_id'])
48 | print 'Savefile: ',savef
49 | start_time= time.time()
50 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
51 | epoch_start =0 ,
52 | epoch_end = params['epochs'],
53 | batch_size = params['batch_size'],
54 | savefreq = params['savefreq'],
55 | savefile = savef,
56 | dataset_eval=dataset['valid'],
57 | mask_eval = dataset['mask_valid'],
58 | replicate_K = 5
59 | )
60 | displayTime('Running DKF',start_time, time.time())
61 | #Save the training log to file
62 | saveHDF5(savef+'-final.h5',savedata)
63 |
64 | #On the validation set, estimate the MSE (call estimateMSE() after training to report it)
65 | def estimateMSE():
66 | assert 'synthetic' in params['dataset'],'Only valid for synthetic data'
67 |
68 | allmus,alllogcov=[],[]
69 |
70 | for s in range(50):
71 | _,mus, logcov = DKF_evaluate.infer(dkf,dataset['valid'])
72 | allmus.append(mus)
73 | alllogcov.append(logcov)
74 | mean_mus = np.concatenate([mus[:,:,:,None] for mus in allmus],axis=3).mean(3)
75 | print 'Validation MSE: ',np.square(mean_mus-dataset['valid_z']).mean()
76 | corrlist = []
77 | for n in range(mean_mus.shape[0]):
78 | corrlist.append(np.corrcoef(mean_mus[n,:].ravel(),dataset['valid_z'][n,:].ravel())[0,1])
79 | print 'Validation Correlation with True Zs: ',np.mean(corrlist)
80 | #import ipdb;ipdb.set_trace()
81 |
--------------------------------------------------------------------------------
/expt-template/README.md:
--------------------------------------------------------------------------------
1 | ## Template for running code
2 | Author: Rahul G. Krishnan (rahul@cs.nyu.edu)
3 |
4 | This contains a simple, commented example of loading data and running the DKF on random data.
5 | The file load.py contains an example of how to load data; the file train.py is a commented
6 | example of how to load the model and run the training script.
7 |
--------------------------------------------------------------------------------
/expt-template/load.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | """
3 | Create a fake time-series dataset
4 | """
5 | def loadDataset():
6 | dataset = {}
7 | Ntrain, Nvalid, Ntest = 3000,100,100
8 | T , dim_observations = 100,40
9 | dataset['train'] = (np.random.randn(Ntrain,T, dim_observations)>0)*1.
10 | dataset['mask_train'] = np.ones((Ntrain,T))
11 | dataset['valid'] = (np.random.randn(Nvalid,T,dim_observations)>0)*1.
12 | dataset['mask_valid'] = np.ones((Nvalid,T))
13 | dataset['test'] = (np.random.randn(Ntest,T,dim_observations)>0)*1.
14 | dataset['mask_test'] = np.ones((Ntest,T))
15 | dataset['dim_observations'] = dim_observations
16 | dataset['data_type'] = 'binary'
17 | return dataset
18 |
--------------------------------------------------------------------------------
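Any replacement for this loadDataset only needs to honor the same keys and shapes. A quick consistency check over that contract, as a sketch:

```
from load import loadDataset

dataset = loadDataset()
for split in ['train', 'valid', 'test']:
    data, mask = dataset[split], dataset['mask_' + split]
    assert data.shape[:2] == mask.shape, (split, data.shape, mask.shape)
    assert data.shape[2] == dataset['dim_observations']
print('data_type: ' + dataset['data_type'])  #'binary' here; train.py switches it to 'binary_nade' under -usenade
```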
/expt-template/runme.sh:
--------------------------------------------------------------------------------
1 | python2.7 train.py -vm R -infm structured -ds 10 -dh 50
2 |
--------------------------------------------------------------------------------
/expt-template/train.py:
--------------------------------------------------------------------------------
1 | import os,time,sys
2 |
3 | """ Add the higher level directory to PYTHONPATH to be able to access the models """
4 | sys.path.append('../')
5 |
6 | """ Change this to modify the loadDataset function """
7 | from load import loadDataset
8 |
9 | """
10 | This will contain a hashmap where the
11 | parameters correspond to the default ones modified
12 | by any command line options given to this script
13 | """
14 | from parse_args_dkf import params
15 |
16 | """ Some utility functions from theanomodels """
17 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime
18 |
19 |
20 | """ Load the dataset into a hashmap. See load.py for details """
21 | dataset = loadDataset()
22 | params['savedir']+='-template'
23 | createIfAbsent(params['savedir'])
24 |
25 | """ Add dataset and NADE parameters to "params"
26 | which will become part of the model
27 | """
28 | for k in ['dim_observations','data_type']:
29 | params[k] = dataset[k]
30 | mapPrint('Options: ',params)
31 | if params['use_nade']:
32 | params['data_type']='binary_nade'
33 |
34 | """
35 | import DKF + learn/evaluate functions
36 | """
37 | start_time = time.time()
38 | from stinfmodel.dkf import DKF
39 | import stinfmodel.learning as DKF_learn
40 | import stinfmodel.evaluate as DKF_evaluate
41 | displayTime('import DKF',start_time, time.time())
42 | dkf = None
43 |
44 | #Remove from params
45 | start_time = time.time()
46 | removeIfExists('./NOSUCHFILE')
47 | reloadFile = params.pop('reloadFile')
48 | """ Reload parameters if reloadFile exists otherwise setup model from scratch
49 | and initialize parameters randomly.
50 | """
51 | if os.path.exists(reloadFile):
52 | pfile=params.pop('paramFile')
53 | """ paramFile is set inside the BaseClass in theanomodels
54 | to point to the pickle file containing params"""
55 | assert os.path.exists(pfile),pfile+' not found. Need paramfile'
56 | print 'Reloading trained model from : ',reloadFile
57 | print 'Assuming ',pfile,' corresponds to model'
58 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile)
59 | else:
60 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl'
61 | print 'Training model from scratch. Parameters in: ',pfile
62 | dkf = DKF(params, paramFile = pfile)
63 | displayTime('Building dkf',start_time, time.time())
64 |
65 | """Set save prefix"""
66 | savef = os.path.join(params['savedir'],params['unique_id'])
67 | print 'Savefile: ',savef
68 | start_time= time.time()
69 |
70 | """Learn the model (see stinfmodel/learning.py)"""
71 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
72 | epoch_start =0 ,
73 | epoch_end = params['epochs'],
74 | batch_size = params['batch_size'],
75 | savefreq = params['savefreq'],
76 | savefile = savef,
77 | dataset_eval=dataset['valid'],
78 | mask_eval = dataset['mask_valid'],
79 | replicate_K= params['replicate_K'],
80 | shuffle = False
81 | )
82 | displayTime('Running DKF',start_time, time.time())
83 |
84 | """ Evaluate bound on test set (see stinfmodel/evaluate.py)"""
85 | savedata['bound_test'] = DKF_evaluate.evaluateBound(dkf, dataset['test'], dataset['mask_test'],
86 | batch_size = params['batch_size'])
87 | saveHDF5(savef+'-final.h5',savedata)
88 | print 'Test Bound: ',savedata['bound_test']
89 | import ipdb;ipdb.set_trace() #drop into a debugger for interactive inspection of the results
90 |
--------------------------------------------------------------------------------
/images/ELBO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/images/ELBO.png
--------------------------------------------------------------------------------
/images/dkf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/images/dkf.png
--------------------------------------------------------------------------------
/ipynb/synthetic/VisualizeSyntheticScaling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Visualize the scalability as a function of dimension in the linear case"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 6,
13 | "metadata": {
14 | "collapsed": false
15 | },
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "Found: /data/ml2/rahul/theanomodels/datasets/synthetic//synthetic18.h5\n",
22 | "Found: /data/ml2/rahul/theanomodels/datasets/synthetic//synthetic19.h5\n",
23 | "Found: /data/ml2/rahul/theanomodels/datasets/synthetic//synthetic20.h5\n"
24 | ]
25 | }
26 | ],
27 | "source": [
28 | "import numpy as np\n",
29 | "%matplotlib inline \n",
30 | "import glob,h5py,os,re\n",
31 | "import matplotlib.pyplot as plt\n",
32 | "import matplotlib as mpl\n",
33 | "mpl.rcParams['lines.linewidth']=2.5\n",
34 | "mpl.rcParams['lines.markersize']=8\n",
35 | "mpl.rcParams['text.usetex']=True\n",
36 | "mpl.rcParams['text.latex.unicode']=True\n",
37 | "mpl.rcParams['font.family'] = 'serif'\n",
38 | "mpl.rcParams['font.serif'] = 'Times New Roman'\n",
39 | "mpl.rcParams['text.latex.preamble']= ['\\usepackage{amsfonts}','\\usepackage{amsmath}']\n",
40 | "mpl.rcParams['font.size'] = 20\n",
41 | "mpl.rcParams['axes.labelsize']=20\n",
42 | "mpl.rcParams['legend.fontsize']=20\n",
43 | "#http://stackoverflow.com/questions/22408237/named-colors-in-matplotlib\n",
44 | "import cPickle as pickle\n",
45 | "from datasets.synthp import params_synthetic\n",
46 | "from utils.misc import getPYDIR,getConfigFile,readPickle, loadHDF5\n",
47 | "\n",
48 | "#visualize synthetic results\n",
49 | "def getCode(f):\n",
50 | " params = readPickle(getConfigFile(f))[0]\n",
51 | " code = params['inference_model']+'-'+params['var_model']\n",
52 | " code = code.replace('structured','ST').replace('mean_field','MF')\n",
53 | " return code\n",
54 | "results = {}\n",
55 | "results[10] = {}\n",
56 | "results[10]['name'] = 'synthetic15'\n",
57 | "results[10]['maxEPOCH'] = 200\n",
58 | "results[10]['dim'] = 10\n",
59 | "results[100] = {}\n",
60 | "results[100]['name'] = 'synthetic16'\n",
61 | "results[100]['maxEPOCH'] = 200\n",
62 | "results[100]['dim'] = 100\n",
63 | "results[250] = {}\n",
64 | "results[250]['name'] = 'synthetic27'\n",
65 | "results[250]['maxEPOCH'] = 200\n",
66 | "results[250]['dim'] = 250\n",
67 | "\n",
68 | "ntrainlist = [10,100,1000]\n",
69 | "SAVEDIR = '../../expt-synthetic-fast/chkpt-'\n",
70 | "\n",
71 | "DATADIR = getPYDIR()+'/datasets/synthetic/'\n",
72 | "\n",
73 | "baselines = {}\n",
74 | "for dset in results:\n",
75 | " bline = '../../expt-synthetic-fast/baselines/'+results[dset]['name']+'-baseline.h5'\n",
76 | " if os.path.exists(bline):\n",
77 | " bb = loadHDF5(bline)\n",
78 | " print bb.keys()\n",
79 | " else:\n",
80 | " bb = None\n",
81 | " baselines[dset] = bb\n",
82 | "\n",
83 | "\n",
84 | "from datasets.load import loadDataset\n",
85 | "datasets = {}\n",
86 | "dset_params = {}\n",
87 | "for dset in results: \n",
88 | " datasets[dset] = loadDataset(results[dset]['name'])\n",
89 | " dset_params[dset] = params_synthetic[results[dset]['name']]"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 9,
95 | "metadata": {
96 | "collapsed": false
97 | },
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "10 10 (500, 25, 10) (500, 25, 10)\n",
104 | "0.981814049597\n",
105 | "100 10 (500, 25, 10) (500, 25, 10)\n",
106 | "0.981765100916\n",
107 | "1000 10 (500, 25, 10) (500, 25, 10)\n",
108 | "0.981760031913\n",
109 | "10 100 (500, 25, 100) (500, 25, 100)\n",
110 | "0.983456427193\n",
111 | "1000 100 (500, 25, 100) (500, 25, 100)\n",
112 | "0.983420357791\n",
113 | "100 100 (500, 25, 100) (500, 25, 100)\n",
114 | "0.983437612098\n",
115 | "100 250 (500, 25, 250) (500, 25, 250)\n",
116 | "0.982943434674\n",
117 | "10 250 (500, 25, 250) (500, 25, 250)\n",
118 | "0.982955749653\n",
119 | "1000 250 (500, 25, 250) (500, 25, 250)\n",
120 | "0.982936344814\n",
121 | "9 ignored\n"
122 | ]
123 | }
124 | ],
125 | "source": [
126 | "#Visualizing epoch\n",
127 | "getepoch = re.compile(\"-EP(.*)-\")\n",
128 | "\n",
129 | "def estimateMSE(mu_posterior, true_z):\n",
130 | " err_sum = np.square(mu_posterior-true_z).sum()\n",
131 | " return np.sqrt(np.square(mu_posterior-true_z.squeeze()).mean())\n",
132 | "\n",
133 | "final_result = {}\n",
134 | "ignored = 0\n",
135 | "for dset in results:\n",
136 | " DIR = SAVEDIR+results[dset]['name']+'/'\n",
137 | " for f in glob.glob(DIR+'*-final.h5'):\n",
138 | " code = getCode(f)\n",
139 | " if code!='ST-R':\n",
140 | " ignored +=1\n",
141 | " continue\n",
142 | " params = readPickle(getConfigFile(f))[0]\n",
143 | " Ntrain, dimstoc = params['ntrain'],results[dset]['dim']\n",
144 | " alldata = loadHDF5(f)\n",
145 | " #print f,alldata.keys()\n",
146 | " valid_mus= alldata['mu_posterior_valid'][-1]\n",
147 | " valid_zs = datasets[dimstoc]['valid_z']\n",
148 | " print Ntrain,dimstoc, valid_mus.shape, valid_zs.shape\n",
149 | " rmse_train = estimateMSE(valid_mus,valid_zs)\n",
150 | " print rmse_train\n",
151 | " final_result['dim'+str(dimstoc)+'ntrain'+str(Ntrain)]= rmse_train\n",
152 | "print ignored,' ignored'"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 10,
158 | "metadata": {
159 | "collapsed": false
160 | },
161 | "outputs": [
162 | {
163 | "data": {
164 | "text/plain": [
165 | "(0, 6)"
166 | ]
167 | },
168 | "execution_count": 10,
169 | "metadata": {},
170 | "output_type": "execute_result"
171 | },
172 | {
173 | "data": {
174 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAFVCAYAAAAJ9lMdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3W9s3Hae3/G3bG96SBtrIqe43UMfRKMtcP2DJFIc7FNm\nJedRHhwi23m2xQG17D5Z9MHKf1AU0W7RPVu655XlfXRA0Y0s7/PGUswDrgW2a8sO0GIXiD0K0F76\nJJZlJ9nibhNNH/xID4dDzpAz5PBH8vMCxmNxfkP+hl+SX/JH8kcQERERERERERERERERERERERER\nERERERERERERERERERERERERERERERERERFLTRRdgXE6Bs++gZeKrkecY/DlN3A8w1EuA28Bp4Fd\n4GTo80ngF8AicADcBt7PcPppLABtYCfiszlgHmgBTcxviSpXXkd4xqG9yyZH+JJDLZsRnyVZNrMq\nk4tj45iILb6Bl9oZjcsFzgA3ASejcU5kn6DWgFmggVmQZ4H7gc+fYn7GOnAh42mnMQdsYjYIYU3g\nMnA2MGwT2Kf7t5TbIS+xEvi7DRwCR0PlvgWO0Lv7lnf5FS2bEZ8lWTazKpObI3lPoIpcsk8AOVoA\nznv/Px9T5tGY6hI2jVnJ5/uUuQT8MjTsOnAtr0pZYYLeDTTesKjj97zL56MOy2ZWZXKjJJCSS6kS\nAMAJYA/YApYiPm9gDreLsIfZy1vD7PVEOYM5NA66h9mASLnVYdnMqkxulARScCldAgB47L1f997P\nhT6fB7bHV51UGt7rs9Bwf8Pwxlhrk7c73qus5dOr+rKZVZlcKQkk5JJvAnBzGCemndXfw9jBnHQK\nH3Y3MXs9NmoO+HxqLLUYl7/2XklOXLUtK59eHZbNrMrkqlYnhoflkn8COJPDeDFXXGwG/vbbGacZ\nbuX6KGX5S+R7YquR47jH7wXv/XfAnxLfNt/2yoyzfPbqsGwOaspKWiZXSgIDuIwnAdwE3s5+9A3M\nVRa+DcyKdh5zNULaNtd3sqtaItXa0x/k7733LcwG+mXgB3QunrwL/Bp4gtlQfzvG8tmrw7KZVZlc\nqTmoD5fxJYA8xh/hKaaN1T8JdxKz6tsq7oScr6iThvnyN75/BLwWGP6aNyxYZpzl81XFZTOrMrlS\nEojhUvoE0MS0s4Zdw+xlLdJ7bbZt/PrH3aQ0aAUqpwnge8CP6DTJ4P3/R95nE2Mun626LJtZlcmV\nmoMiuJQ+AYC5smIzYvgOZu/iPOYuzDTG3e564L2mgGeB4X476YMRxm2vNvBd4DsRn30H+GPgf4+5\nfLbqtGxmVSY3SgIhLpVIANDb5hq0AVwk/Yoz7nZXME0Eb9J9Cd1J0m8k7Oev9s+Av+1T7nPMMby/\n7ziO8od9yqdXp2UzqzK5UXNQgEtlEgDATJ/P/OuybbsGO6rh4RK9lw4uecOr5cfe68+IvpvXd8Qr\nM87y2arTsplVmdzUqgM5oB13ZOtSfAKY6Hob2jrmTsOmN7klove6PqS4Drl8k8AVzJ7hEuYuyR3g\nN8CtQLlZzG/yO9e6B3w81prmr93Vd1BRfQTFlTd107I53LKZVZlc1CoJ1LAXUSmL+vUiKiIiIiIi\nIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIilroKPML0UbqOebBM2G3v8w8xD5wH87Sx\n4PeuBl6b3vCHeVZcKuXYM0xnuZa+jgWfLJSFZTorStTzWicxXfoeYh5j92HG009jgegNA8Ac5rcs\neu9R5ZKUsdbRo1i9bHr1G9U5+j8ucRl4I+X3JlESGEnNniz2zUv9n5XnUuxTBSay7kp4DdNPeQOz\nkQ0/t/WpV6F14ELG005jDpOsTkd81gQuA2cDwzYxG4X7KcpY7dtveenOnc7fDx7Ayop5vRG1WQzJ\nu/zbb2fSzXW/rusXgS1gL+X3ntJ5EI0MQU8W6+JgNtBnMBvsso0/0gKdpxaFn17kezSmuoRNE980\n4LsE/DI07DqmmSBNmdKwLQE8yP9JzouYh6lEJYAkduk0H0lKSgI9HCqWCE5gVq4tzBOSwhqYB10X\nYQ9zBLJG/OH+GcxKHnQPk9zSlCkFGxPAysrgciOYxySAtEdsk4H/t+g8oVlSUhKI5FChRPDYe/cP\nmc+FPp/Hvue5+hre67PQcD9pvZGwTCnYmgByTAKnMU2BwzTZLdBJBHtDjkNQEujDoQKJYJbOHvIO\nZo8p3CTUZPjD8Lw1B3w+lbBMKdiaAJKUH8IC5hzOFbr36pM6lW116qtmJ4bTcuhsqPM4WRwef+ZO\nYk6Q+vx28mmG2/B/lLL8JfLdQ0vSlFWaZgJbNuhjSACzmCs0zmKuWrtB90n9KA3MZaFgjh4WgIu5\n1K5mlAQGchhfIshcA3P1hG8DkwTOY66mSXs+4J3sqpZIkr340uzpD2LDBn0MCaABvIw5MgXTPHkP\n0yy5E/clzHJ6OfD3VfpfNSQJqTkoEYfxNA3l7imm/d8/QXyS6PsHbNHvmnIwG4YkZSqhAgkATDw+\nDvx9H1gl/WWetzOrUc0pCSTmkH8iyFQTcw4g7Bpmb2yR3vsGbOPX/3jM5/sJy5ReRRJAnMuYI7qr\ngwoG7NB9lCtDUnNQKg75Ng1lap7u8wG+Hcze2HnS702N+5zAgfeagq47Vv12/geBcoPKlFbFE4Dv\nDGZ5vI69FypUkpJAag4lSQTh8wFBG5iTamk36uM+JwCm+epNui8BPUl3AktSppRqkgDA7JzsYFaq\nk4XWpGZsaQ5awNw5uow5JJwttjqDOBRw529aM30+89tfbbs/IOpE3yV6L2td8oanKVM6FUsAVzE7\nHpNE3yU+h7lqbRazc/JDb/i10PeWc69pzdhwdn0JE/wrgWHXie/iYBTt/n0HpeWS7RHBxPN/RrCO\nSapNTMWWiD4i+BB4f8RpjWoSE/cGpp73MHuDvwFuBcrNYn5TC/O77tF9cjFpGZu1Le87COzYXkjF\nNOntAXCJzl2uGatdL6JSErb3InrkCF/n+POlxq4DPwkNm6REt/qLiMjw9um0/YmIyJgVeXVQE9MW\n/ARzzTp07v68UUiNRERkbOYwT7QKNwddpbenSxERqRg/CbwXGj5NRe7yFBGReH4SeDXis0N0clhE\nJHdFnhOI6tcmqEnolv/XX3+9/cknn+RXIxGRavqEmB3rIu8Y9vuFievvvSdJfPLJJ7Tb7dK/Pvjg\ng8LroJfiV8dXXWMHvB63IS6624hNep8QNIe5Yqj0HX+JiNiu6CRwid6uCy6jq4NERMai6F5En2I2\n+OvAI+AE8EvgV0VWKm+O4xRdBRmB4ldeil2vsnUI1fbat0REJKGJifjOKYtuDhIRkQIpCYiI1JiS\ngIhIjSkJiIjUmJKAiEiNKQmIiNSYkoCISI0pCYiI1JiSgIhIjSkJiIjUmJKAiEiNKQmIiNSYkoCI\nSI0pCYiI1JiSgIhIjSkJiIjUm
JKAiEiNKQmIiNSYkoCISI0pCYiI1JiSgIhIjSkJiIjUWNFJoAnc\nBGYDfy8Di4XVSESkRopOAmA2+PeAQ+Au8AS4VWiNRERq4ljB028DC0ALaAAPiq2OiEi9FJ0EACaA\nz4quhIhIHdnQHCQiIgWx4UigiWkKApjy3m8UVBcRkVopOgnse6/gieB1712JQEQkZxNFVyDCLLBD\n56ggqN1ut8dcHRGRcpuYmICY7b2N5wT2MM1DrxZcDxGRyiu6OegicB14Ghh24L03ibhqaGVl5fn/\nHcfBcZzcKiciUkau6+K6bqKyRTYHNYGHwBzd9wc0MOcJopKAmoNERFKytTmoBZyn9waxBeARundA\nRCR3NpwTmA79fRk4U0RFRETqxoarg87RuU9gBnOJaFz3EWoOEhFJqV9zkA1JIA0lARGRlGw9JyAi\nIgVTEhARqTElARGRGlMSEBGpMSUBEZEaUxIQEakxJQERkRpTEhARqTElARGRGlMSEBGpMSUBEZEa\nUxIQEakxJQERkRpTEhARqTElARGRGlMSEBGpMSUBEZEaUxIQEakxJQERkRpTEhARqTElARGRGlMS\nEBGpMSUBEZEasy0JNID1oishIlIXx4quQMg14OWiKyEiUhc2HQk0MQmgXXRFRETqwqYkMA/cBiaK\nroiISF3YkgTmge2iKyEiUje2JIEmsIeOAkRExsqGJLAI3Ci6EiIidVR0EmgUPH0RkVor+hLRM3Qf\nBQy8MmhlZeX5/x3HwXGczCslIlJmruvium6iskW2wc967/cDw5aABeBszHfa7bauIBURSWNiYgJi\ntvdFJoFzwExo2BzmJPEW8BvgVuhzJQERkZRsTQJRloGTwPsxnysJiIik1C8JFH1iOGwC+xKTiEhl\n2bLBnQbOA6cxXUfcAD6k+3wB6EhARCS1MjUHDaIkICKSUpmag0REZIyUBEREakxJQESkxpQERERq\nTElARKTG+vUddBxzR+8U5i7ebUx3z75JTBcPU3Tu/H0I/CL7aoqISB76XSJ6CNwELtO98Y9zDfgJ\ncDSDesXRJaIiIikNe5/AQ+D7Kac1zHfSUBIQEUmpXxLo1xwUftzjtPfeDowsfISgR0SKiJRIvxPD\nB6G/G5j+/1uY7h2iHggT/o6IiFgszUNl7nuvk8BaPtUREZFxGuYS0VbmtRARkULoPgERkRobdIlo\nVBt/I2Y4mHsHdImoiIhFhr06aA+4GvfFGBdTlBURkYL1SwI3MQ93SaM5Ql1ERGTM9FAZEZGKy/Oh\nMq96LxERKaF+zUEXMR3Egbks9B6d5qFZTHPRlPdZA3P+QJ3HiYiUyKDmoJvAOrATGDYNPAI2gAuB\n4ReBffJNBGoOEhFJadgO5M4Bm8DT0PCrmG4jojqKW6c7MWRNSUBEJKVhzwk06E0AYJqItmK+o76D\nRERKJE3fQWASwxxwKcM6LNK5tHQG09SkvolERMagXxI4ETHsjPe+E/HZME4DT+je6N/13pUIRERy\n1q856C7mSWG+eczTw1Zjyv+F90rj/Yhh28D5lOMREZEh9EsCW8AzzNPC9oHbmCuCLgfKzGNOID/E\nHCWcIZ1JzNFA0ATmwTUiIpKzpHcMx3UaNxkaT5vok8lp3AP+C/CXEZ/p6iARkZSG7UAuKO6qn1E3\n+GGngS+ITgAiIpKxfs1B54YY3zDf8b93FdOcdHbIcYiISEr9moM+Il3X0BOYE8fvjFCfScyVR5eI\nvgJJzUEiIikNe8fw4RDTajP6Q2UWMX0UTdPb3KQkICKS0rDnBN7BnBB+QvL7AtZT1SzafW+6C8Ct\n8IcrKyvP/+84Do7jZDBJEZHqcF0X13UTlU1yddAknd5EHwEP+pSdJ3nCmMPcizAXGmcDc0nqEr2d\n0elIQEQkpVGvDnpKZ498GtNcA+amrnBzTZo7idvALqYr6qCTgfGLiEiORnmy2Cymz58DzB79MJeL\nLmNuQAt+9zbm5rN/E1FeRwIiIikNe2I4jUXMpZ0fAr9K+d3lwP/fAv4H8fcJKAmIiKSUxc1iUY5j\n+v45j2nX346byADqKE5EpCDDJIFFzIZ/AdOmfx3zBDI9S0BEpGSS7rnPY+7mXcKcyL2OOVkcPqmb\nNzUHiYikNGxz0CymuWfJ+3sTc+XObp/vLBJxbb+IiNhp0B3DW5iTvUk27H6XDycHFRyBjgRERFIa\n9kjgPiYBQOfegH5O0HlMpIiIlEC/JHCb9E07MyPURURExiyr+wTGRc1BIiIp9WsO6vc8gbQWMfcO\niIhISWSZBG4R/eB4ERGxVJIkcBx4I++KiIjI+A06J7CIuRsYzHMF5ul0+/wqpruIJuaE8FuYm8fy\nfDykzgmIiKQ0bAdys5gEcA3Yw2zwz2LuA1j2hvsOMH0HnSP7h88HKQmIiKQ07H0C54FTmAQAZiO/\njXkgfBN4mXw3+CIikrNB5wT2Qn/v0jkiUAIQESm5fkkgrlfQ63lURERExi/LS0TBnEcQEZGSyDoJ\nLA0uIiIithjUi2jaB8VMAkeHr85AujpIRCSlUXoRXY/7YgwdCYiIlMigXkRvpByfupIWESkR9SIq\nIlJx4+pFVERESqZfc9C4LGL6HWpgmpNukr4ZSkREhlB0ElgEHtF5gtkkcA+TENaKqpSISF0U3RzU\npNMrKZiuKC7R3TmdiIjkpMgk0MQ8hGY6NPy+965nGIiI5KzIJLCP6YwunARERGRMijwncEB0Eprz\n3h9EfCYiIhkq+pxAlCvonICIyFjYlgROY/osulJ0RURE6sCmJNDA9D00X3RFRETqouj7BII2MEcC\nz/oVWllZef5/x3FwHCfXSomIlI3rurium6isLX0HXQV+zoAEgPoOEhFJzfa+g85huqwOJoBZdOmo\niEjuim4OOu29T3kv31ng8virIyJSL0U2BzUwN4xFeQT804jhag4SEUmpX3OQLecEklISEBFJyfZz\nAiIiUhAlARGRGlMSEBGpMSUBEZEaUxIQEakxJQERkRpTEhARqTElARGRGlMSEBGpMSUBEZEaUxIQ\nEakxJQERkRpTEhARqTElARGRGlMSkMK0Wi0uXLiQqOyNGze4desWt27dYm1trXbTsE2V5msd4xdU\n9JPFRtJqtVhdXWV9fX1g2Rs3bjA1NfX8e8vLy7Wahk3u37/P9vY2jx8/5u7duwPLb2xs8Morr/De\ne+8BsLe3x4ULF/rOr6pMwzZVmq91jF8VtNvtdnt3d7e9urravnTpUvvNN99sD3L9+vX2rVu3nv/d\narXa58+f7/udqkzDZru7u4l+98zMTOSwg4OD2kzDNlWar3WIHxD7NK5SNgfNzs6yvLzM+++/n6j8\n6urq8+wNMD09zfb2Nk+fPq38NGzWTvCUuFarxf5+71NIm80m29vbtZmGbao0X+sYv6BSJgFfVRbE\nui+E/bRarefNX0GNRoNWq6VpWKxK87XK8St1EkiiKgtilRfCYT1+/FjTKKkqzdeyx6/ySa
CfqiyI\nZV8I+zk4OIj9zHt4tqZhqSrN1yrHr/JJoCoLYpUXwn4ajUbsZydOnNA0LFal+Vrl+NmSBBaA+TxG\nXJUFscoLYT/NZjPyXMjBwQHNZlPTsFiV5muV42dDEpgDNulzCdMoqrIgVnkh7Mf/beEroFqtFgsL\nC5qGxao0X6scvyKTwDSwTk5HAL6qLIhVXgjDWq0WW1tbz/++du0am5ubz//e3d3l1KlTHD9+vPbT\nsE2V5msd41ekh8APE5TrugHi3r17kTd5PHr0qH3z5s3nf29sbLQ3Nja6vnfhwoVEN1lUZRo2abVa\n7dXV1fapU6faR44caV+6dKnrd12/fr39/e9/v+s7q6ur7a2trfbW1lZ7dXW1NtOwTZXma53iR5+W\nFlvOKD4EloCPB5Rrt9tt9vb22Nra4vbt2+zs7LC8vMzMzAznzp0DzO3da2trfPrpp8+/uLa29nxv\nOkl3C1WZhoiId/FI5Pa+lElARESS65cEbDgxLCIiBVESEBGpMSUBEZEaK93zBFZWVp7/33EcHMcp\nrC4iIjZyXRfXdROVLd2J4fyr8h3gD/lO4ghwmO8kXpp8iWcHz/KdSEovv/QSB199lfNUjgHf5DuJ\nMcTvxX/0Il9/+XW+E0lh6vhxnnz5Zc5Tqca6B/DiP3yRr7+yJ379TgyX7khg2CzgAmeAm4DTp9wE\nf8h/KocTsBIa9gfgr4D/E5j8BPBPgB9h1o8U5b/8j3mvsOkdfPUVd+g//4flYub8F3wDeU/l8Av4\nd2Qar3D536/8PvPaj+LJl1/2XStckq1f/Qxe9zKYymFgO5hhvMLlf/+1XfHrx6ZzArkdlbiMvoAm\n43hTOeNNNYXvAH9M9zrQBr5L7wI3THlLODmM06UTX2OI+Z+I05lKTeKVhEsJ1q8oWcYrqnxJFJkE\nJoGrmK4jmt7/rwKLWU7EZVwLqM9h6AX1c0xEGt7rCPC3GZavIJeo+Ga4oejRmYriVbL1y5dXvILl\nS6TI5qCnwGXv/xfymIDLuBdQn0NnQU0x9SPAnwH/0vv7fwK/ybB8xbjEzWGHoeZ/WjWPl0vJ1i/f\nj733rOMVLP+zdFUqki0nhpNKfL+wy3CLiJkhWR3TxdUi4pxAG3PC6mho+LeYhSscqUHlf5rssZXj\nNDExMY45Syd+caVGNQEfkG28wuVX7IpfOHYu2c/Z9OveMLUIrXtZxStcfsW++FGnO4ZditpDCXNI\nfOg6Qe8ChzcsKnRpy1eIS9L4OuTWNFTjeLmUcP2Kk1W84sqXQOWSgIstC6jPId826npxSRtfB83/\n7Lho/aqaSiUBF9sWUJ9D5IJ6x3sllbZ8xbgMG1+HQjYUFYuvS8nWr7QqFq+kKpMEXLJcQN2Rx9DL\nobOgYtoQ/9p7JWk6TFu+YlxGja/DWBNBBeObZwJwRx6Dw0jxrWC8kqpEEnDJegEdw3XmvwNe8F6/\no/+C1B6ifIW4ZBVfh8wSQdbxCpe3UJ4J4EwmY3IYKr5ZxKvE61zJTkn1Xh3kkm0CMDPkTsZjjZjK\nUcyVBtA52fQy8APgpDf8LvBr4AlmIUtT/hu7rk6A4a4OckkXiWRXmKQda8RUjpFtvMLlv7Urflle\n2RXk0onE20A2W9PgWJ3QZ4Grg7KMV7j8in3xo4pXB7nktal2yL3p4NuI//8R8Fpg+GvesGHKV4CL\n5fHNMl5R5SvOpeD4Zh2vcPmSKG0ScMn7JJVDrolgIvT/72H6JHkhMPwFb9j3hihfci6WxzfreIXL\nV5yLBfHNMl5R5UuilEnApaR9lQTl1eeMX77EXEoQ3zz7nLGnFSEXLhbFV31AlS8JuJSwr5IoefY5\n83lGdSyAS0nim1efM375inKxML417wOqdF1JF3OZmkPmfdHk2edMSTciLiXqi6bf3aFZ9Cn0q8FV\nKBuXJHPY7fvpcBy6Ls8Oq3kfUGVrfWzn3x89xB+P+6VG3UxF9B2UdR8m/8GuqxOg/xUmLpnNWcbx\n1An+Pfn0OeNbsSt+o14d5DJ4zppZ88qAUqNwgbe7172s++zyy6/YFz+qcnWQk8M4XcL90febek7n\nCGrch4mLLXeiOiSOb43jlZZLmviOqRtwX437gPKVLglkzUV90RTJxZYE4HNQfLPjovXLdrVOAi6W\n9kVTkz5MXGxLAD6HVPGtSbzScqnp+lWy2JbuxHBWXLLsiybDzZjfJ4k/iUGHmAnL37hxg6mpKQBa\nrRbLy8t9R5u2fFoutiYAn0Oi+OYUrzDb4jeIi6XrF+Qbr2D5AJvjV8sk4JJHXzQZLKjBPknw/v+n\nxC94CctvbGzwyiuv8N577wGwt7fHhQsXWF9fjxxt2vJpudieAHwOfeObU7zCbIvfIC6Wrl++vOIV\nLP/3ncFli5/t2u0RX3eg/Yr3HvU5MMRo77ThFe89SXnavEubFe/1Lm3+MW2O0eaoXwfv/8e8z9KW\nh7ZvZmamHTYzM9M+ODjoGT5M+aRIMP9HfQ0Xv7TxzSFe4fKWxS/pujdKfONjl3b9GrDu5RGviPK2\nxS9uo1qrcwIuFvVFk2efM4H+g1qtFvv7+z2TbzabbG9v9wxPWz4tu7sjjuPQE988+5wJ/N+2+PXj\nYtH6lUQefXaF+n0qQ/xqkwRcLOirJCivPmf88p5Wq/W8bTGo0WjQarV6hqctn9Z4uiN2c5iCQ9cN\nR3n2ORP4v23xi+Ni2fqVRB59doWaisoQv1okAReL+iqJU0AfJo8fP05VxbTlozgjj6GXS/g+jzE8\nD8KCPmeKiF8UlxKsX1GyjFdU+QFsiZ8NSWAOWAYWvff5LEfuYmFfJXFy6MPk4OAg9uveXYQjlS+a\nS1R8x3TD0Rj6nLE9fi4lWr98efbZ5Zf32B4/KP7qoCZwGTgbGLYJ7AP3Rx25S4n6ooFc+jBpNBrE\nOXHiRM+wtOWL5BI3hx1yu7wwaAx9ztgcP5eSrV++H3vvefTZ5Zf/mRlkc/x8RSeBS8AvQ8OuA9eA\nd0YZsUvRlyE6pFpQ28Cf031L+mvAv/A+i+rDpF95T7PZjDzRdHBwQLPZ7BmetnxRXAbNWYdcE0HW\n8YrZybM1fi4lW7+C/PaPLOMVVR574xdUdHPQGWA3NOwesDDKSF2KXkB9DokPXXPqw8RfcJ4+fdpV\nrNVqsbDQO5vTli+CS9L4OuTWNDSmPmdsjJ9LCdevODn3AWVj/MKKTAJ+S9tnoeF+o9gbw4zUxZYF\n1OdQdF8o165dY3Nz8/nfu7u7nDp1iuPHjwNmAdva2kpcvkgu9euLxqb4uWj9Ssum+EUp8kzfHObR\nzVGJ6BBzNPBxaHi738l3l9EX0NG6Iu7HpeuR2it0+hh5O+Eo0pRfgXZgbq2trT3fywjfhr6xscHa\n2hqffvppovLDGl93xFFTSfLtpCK6Ao8ySnxX7IrfxMQEd8g3AYytG/CVmI+yXB9X7IsfMdt7m5PA\naXofrRGbBFyy7I/+zohjiePy/KkFHwA/9QZ/w
OBItFOWX+leCG0wShJwSbyKE78hSTqWQRIkgbTx\nCpf/qV3xm5iYGEdP/4y2A+aSaDdhJWLwqPEKl1+xL35Y+DyB3jsihuSS9R7KGK4z9/sYecH7f7/l\npT1E+QpxyaMvGnekMWUer3B5C43nRr9ROAwV3yziVeJ1zuYjgUTNQS7ZJoDOkUDOB75H6dxi7p9s\nehn4AXDSG34X+DXwBLOQpSn/jV17IjDckYBLukgka1JIO9aIqRwj23iFy39rV/xGbcqL49LVSEo2\nW9PgWJ3QZ4EjgSzjFS6/Yl/8sLA5qIG5H6ABPAt9dohJEg9Cwx8Ar+dfNRGRSvmEIS+2yds+8Gpo\nWAOTBEREJGdF3yewDbwZGnYSuF1AXUREZMymgY9Cwzax9LBFRESyN0t3B3I/HMM0r2ES0Li/K+kp\nVuWm+Fmu6L6DwHQUN3JncSlNM/xlCKN8NysLXh12Ij6bw/TE2sJ00LcbU64s6h6rssezrPFbBN7C\nnKNsYi41uhH4vIlJUj/HbL+a3ndawK1AubLHr7I2GX4PY5TvZmEOc0I96oipialf0CbmaKus6hyr\nKsSzjPFbpLtJehJ4iGmp8DUxF7D4r33gX4fGU4r4FX1iWJKbBtbp/7yFfr2yyvhkFSvFsxhNui9P\nf4qJRXBTbdiNAAAGkElEQVS+tzFHeU1Msp8CfhEaj+JnsTLunQQ9JHrvsoqX3NY5VlWIZ9ni18Tc\nFhaerr/n7x8hTDP4AViliJ+OBKojl15ZJRdJYqV4FmMfs2c/avIpTfyUBHotYW4Mv4tpC4TutkBb\nDXriRGZ9NVmkyrGqQzxtjN8BZrsY7rJmznsPNhP5J4MXgXPeK/hZP9bEz4arg2wyjwn2IjADXPH+\nXh9yfOF7IAa5RH5XSsU/t66cqh6r+IfNdsqUWdnid4Xutvx97xW8Esive/AqojjWxE9JoNtdOpdv\nfYxZQL+gt0vrpEZ6RGZK1uxZjEnVY1X1eJYpfqcx7fhXAsOe0p0AwJz03cEkgdLET0mgW/CZbqcx\n7YJrBdUlrd4Hk3YbtGdZNlWP1aB4lT2eZYlfA9NsNegkMMCeV/5VSrQ+6pxAtCVMMG1cKOO0vPe4\nZ9ANWijLqqqxqks8bY/fBiZJhXs6vkjnPIbP37A3KVH8dCTQaxl4xPCHpUHjbGf29x6n6F5g/bbH\ncLfcVVD1WFU9nrbH7yrmBrBwAmh6n31Edxz82LQo0fqoJNDtIr0LpX+5WLj9L4lxtjNDp1fWzwLD\nqtorax1iVeV42h6/c5gTvcEN+Cxmw94CztO7IV/A/KbPvL9LET8lgY4F4CwmcE3MXsKcN3zcG4gk\noh4IdAlzciq4Ei15w6ukLrGqajxtj99p732K7hO8Z4HLgb+nMecBfJfpflJmVeNXCVF3Il4NfX4I\n/Fd62/SKugt1ElPHdUzdfuP9vRgqV0SvrHmqe6zKHs+yxc+/ozfq9Wmo7DlMTJYxsY66Aazs8aus\nst3KXmeKVbkpfpbT1UEiIjWmJCAiUmNKAiIiNVbXJLDP8E8sGuW7kp5iVW6Kn4iIiIiIiIiIiIiI\niEjJXMX0jXKIuUOy6KdA5eUqcA/zO68GXuuYG5jifvcSFvUOmVAZ6ywiBTpHthuN2QzHleU0+v3O\nc3Q/EjE4nf80xLSKVMY6i0iBst5zHMfRxDDTGPQ7/a6ERURqJeskcDfDcWU5jSS/85BkT5wSsVZd\nbxYTO1zEdCFc1mlsYfqV981RvqRQxjpLhvQ8AcmT33XyFDAD/AWdZ8uew/QlD52uhR9hHtLtW8Lc\nMboPvIV5GMcOph17FbPxehnTDz3AKcwJXX8cSaYxirt0+pefBN7HND35O1dJ6unPoyZwgu7+6mH0\neYA3Df/Rhw2vjN+nfbjOvtPeePcx8dun0y9+mmmLSMUkbQ7yHx7um6a3WWYS06QS5SLwXmjYw9A4\nDyPK7NPdL32/afST5HcueeMOTi9qWlH1fIJJUkF36T6JncU8OE3vnv56xDiCloCfhIb5fecHJZn/\nYjk1B0lezmI2Nr49zN5ucAMR9cQtMHurl+l99my4+QU6D/T27dPZ++83jXEL1/Mxvf3itDCPH4Ts\n5sE03fMD4GafejYwR01/GRp+A7hC7xVRg+a/WE7NQZK1SUyTj9/cMI/ZKDzBbGDCD96OctZ7Dz+J\n67E3nqDwRmic/IeGD/o9EF3PqBPWfmLIah5sYZpoTtFpStrpU36B+COgllevG6FhUmJKApK1BUzb\n8QTmxqqLdDYaB3FfCmnQ3QbdT5INcF5mSP7Q8LT1zGoe7GGOBq5gjiCuYxLD2ZjyTeLjtI/5zUmn\nLSWg5iDJmv9g7tvAz4GPI8qEmxR8s95nj+h+wHeW/Glk4Qxmo5qHFtnMg3nMkdllTFPTEUyCCZ+P\n8PlHbFFOAF9kUCexiJKAZKmJ2YA0MXuf4fZsf+NyErMhDu9x+hu9Ha9s1PNlw80jg8RNY1QXMckq\n/Buzsk028+BUxDjOA2/GlL+JiV9UopzFHEVIhSgJSJauY/YkH3t/Bzc+C5iNpn+i1r9UtBUo1/CG\nH2A2VNdC458HdhPUI3wyOGoao1jCHAWMen19uJ4TgWFZzoPwieQZ4u92PsCcz7kSGn7Rq8tnKact\nljtadAWkVK5imhG+C7yI2bAvYC5D/Cvgn3nvD4C/wTRBHAL/HPgcc4Ly32Iuc/ytN87/hdmo/gPM\nBmjPG77rffdfYfZK/XH8FrNBv4a50el73vD/B/wMeBezJ/t/A+OKm0a/33kGeDXid/4I0yTy58Df\nBb4TrtPfAH+SoJ6fe/W55o2/iWl7/21G86AB/Dev3J8E3v9zTJ3/Dvjv3vTe9Ya/i0nuazG/ddD8\nFxERERERERERERERERERERERERERERERERERERERERERERERkYz8f8J4ka++IKFVAAAAAElFTkSu\nQmCC\n",
175 | "text/plain": [
176 | ""
177 | ]
178 | },
179 | "metadata": {},
180 | "output_type": "display_data"
181 | }
182 | ],
183 | "source": [
184 | "N = 3\n",
185 | "ind = np.arange(N) # the x locations for the groups\n",
186 | "width = 0.9 # the width of the bars\n",
187 | "\n",
188 | "fig, ax = plt.subplots()\n",
189 | "train10 = (final_result['dim10ntrain10'], final_result['dim100ntrain10'], final_result['dim250ntrain10'])\n",
190 | "rects1 = ax.bar(ind, train10, width/4, color='r', hatch=\"/\")\n",
191 | "\n",
192 | "train100 = (final_result['dim10ntrain100'], final_result['dim100ntrain100'], final_result['dim250ntrain100'])\n",
193 | "rects2 = ax.bar(ind + width/4, train100, width/4, color='b', hatch=\"\\\\\")\n",
194 | "\n",
195 | "train1000 = (final_result['dim10ntrain1000'], final_result['dim100ntrain1000'], final_result['dim250ntrain1000'])\n",
196 | "rects3 = ax.bar(ind + 2*width/4, train1000, width/4, color='g', hatch=\"*\")\n",
197 | "\n",
198 | "if baselines[10] is None:\n",
199 | " KF = (0,0,0)\n",
200 | "else:\n",
201 | " KF = (baselines[10]['valid_rmse'].sum(), baselines[100]['valid_rmse'].sum(), baselines[250]['valid_rmse'].sum())\n",
202 | "rects4 = ax.bar(ind + 3*width/4, KF, width/4, color='y', hatch=\"//\")\n",
203 | "\n",
204 | "# add some text for labels, title and axes ticks\n",
205 | "ax.set_xlabel('Latent Dimension')\n",
206 | "ax.set_ylabel('RMSE')\n",
207 | "ax.set_xticks(ind + width/2)\n",
208 | "ax.set_xticklabels(('$|z|=10$', '$|z|=100$', '$|z|=250$'))\n",
209 | "\n",
210 | "ax.legend((rects1[0], rects2[0], rects3[0], rects4[0]), ('$N=10$', '$N=100$', '$N=1000$', 'KF'),\n",
211 | " loc='upper center', bbox_to_anchor=(0.5, 1.3),ncol=2, frameon=False)\n",
212 | "def autolabel(rects):\n",
213 | " # attach some text labels\n",
214 | " for idx, rect in enumerate(rects):\n",
215 | " height = rect.get_height()\n",
216 | " ax.text(rect.get_x() + rect.get_width()/2., 1.05*height,\n",
217 | " '%.1f' % height,\n",
218 | " ha='center', va='bottom',size=18)\n",
219 | "autolabel(rects1)\n",
220 | "autolabel(rects2)\n",
221 | "autolabel(rects3)\n",
222 | "autolabel(rects4)\n",
223 | "autolabel\n",
224 | "plt.ylim(0,6)\n",
225 | "\n",
226 | "#plt.savefig('scaling-exact-inference.pdf',bbox_inches='tight')"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "metadata": {
233 | "collapsed": true
234 | },
235 | "outputs": [],
236 | "source": []
237 | }
238 | ],
239 | "metadata": {
240 | "kernelspec": {
241 | "display_name": "Python 2",
242 | "language": "python",
243 | "name": "python2"
244 | },
245 | "language_info": {
246 | "codemirror_mode": {
247 | "name": "ipython",
248 | "version": 2
249 | },
250 | "file_extension": ".py",
251 | "mimetype": "text/x-python",
252 | "name": "python",
253 | "nbconvert_exporter": "python",
254 | "pygments_lexer": "ipython2",
255 | "version": "2.7.3"
256 | }
257 | },
258 | "nbformat": 4,
259 | "nbformat_minor": 0
260 | }
261 |
--------------------------------------------------------------------------------
/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/paper.pdf
--------------------------------------------------------------------------------
/parse_args_dkf.py:
--------------------------------------------------------------------------------
1 | """
2 | Parse command line and store result in params
3 | Model : DKF
4 | """
5 | import argparse,copy
6 | from collections import OrderedDict
7 | #A single parser is used; the description appears in --help output
8 | parser = argparse.ArgumentParser(description="Arguments for variational autoencoder")
9 | parser.add_argument('-dset','--dataset', action='store',default = '', help='Dataset', type=str)
10 |
11 | #Recognition Model
12 | parser.add_argument('-rl','--rnn_layers', action='store',default = 1, help='Number of layers in the RNN', type=int, choices=[1,2])
13 | parser.add_argument('-rs','--rnn_size', action='store',default = 600, help='Hidden unit size in q model/RNN', type=int)
14 | parser.add_argument('-rd','--rnn_dropout', action='store',default = 0.1, help='Dropout after each RNN output layer', type=float)
15 | parser.add_argument('-vm','--var_model', action='store',default = 'R', help='Variational Model', type=str, choices=['L','LR','R'])
16 | parser.add_argument('-infm','--inference_model', action='store',default = 'structured', help='Inference Model', type=str, choices=['mean_field','structured'])
17 | parser.add_argument('-ql','--q_mlp_layers', action='store',default = 1, help='#Layers in Recognition Model', type=int)
18 | parser.add_argument('-useprior','--use_generative_prior', action='store_true', help='Use generative prior in inference network')
19 |
20 | #Generative model
21 | parser.add_argument('-ds','--dim_stochastic', action='store',default = 100, help='Stochastic dimensions', type=int)
22 | parser.add_argument('-dh','--dim_hidden', action='store', default = 200, help='Hidden dimensions in DKF', type=int)
23 | parser.add_argument('-tl','--transition_layers', action='store', default = 2, help='Layers in transition fxn', type=int)
24 | parser.add_argument('-ttype','--transition_type', action='store', default = 'simple_gated', help='Type of transition fxn', type=str, choices=['mlp','simple_gated'])
25 | parser.add_argument('-previnp','--use_prev_input', action='store_true', help='Use previous input in transition')
26 | parser.add_argument('-usenade','--use_nade', action='store_true', help='Use NADE ')
27 |
28 | parser.add_argument('-el','--emission_layers', action='store',default = 2, help='Layers in emission fxn', type=int)
29 | parser.add_argument('-etype','--emission_type', action='store',default = 'mlp', help='Type of emission fxn', type=str, choices=['mlp','conditional','res'])
30 |
31 | #Weights and Nonlinearity
32 | parser.add_argument('-iw','--init_weight', action='store',default = 0.1, help='Range to initialize weights during learning',type=float)
33 | parser.add_argument('-ischeme','--init_scheme', action='store',default = 'uniform', help='Type of initialization for weights', type=str, choices=['uniform','normal','xavier','he','orthogonal'])
34 | parser.add_argument('-nl','--nonlinearity', action='store',default = 'relu', help='Nonlinearity',type=str, choices=['relu','tanh','softplus','maxout','elu'])
35 | parser.add_argument('-lky','--leaky_param', action='store',default =0., help='Leaky ReLU parameter',type=float)
36 | parser.add_argument('-mstride','--maxout_stride', action='store',default = 4, help='Stride for maxout',type=int)
37 | parser.add_argument('-fg','--forget_bias', action='store',default = -5., help='Bias for forget gates', type=float)
38 |
39 | parser.add_argument('-vonly','--validate_only', action='store_true', help='Only build fxn for validation')
40 |
41 | #Optimization
42 | parser.add_argument('-lr','--lr', action='store',default = 8e-4, help='Learning rate', type=float)
43 | parser.add_argument('-opt','--optimizer', action='store',default = 'adam', help='Optimizer',choices=['adam','rmsprop'])
44 | parser.add_argument('-bs','--batch_size', action='store',default = 20, help='Batch Size',type=int)
45 | parser.add_argument('-ar','--anneal_rate', action='store',default = 10., help='Number of param. updates before anneal=1',type=float)
46 | parser.add_argument('-repK','--replicate_K', action='store',default = None, help='Number of samples used for the variational bound. Created by replicating the batch',type=int)
47 | parser.add_argument('-shuf','--shuffle', action='store_true',help='Shuffle during training')
48 | parser.add_argument('-covexp','--cov_explicit', action='store_true',help='Explicitly parameterize covariance')
49 | parser.add_argument('-nt','--ntrain', action='store',type=int,default=5000,help='number of training')
50 |
51 | #Regularization
52 | parser.add_argument('-reg','--reg_type', action='store',default = 'l2', help='Type of regularization',type=str,choices=['l1','l2'])
53 | parser.add_argument('-rv','--reg_value', action='store',default = 0.05, help='Amount of regularization',type=float)
54 | parser.add_argument('-rspec','--reg_spec', action='store',default = '_', help='String to match parameters (Default is generative model)',type=str)
55 |
56 | #Save/load
57 | parser.add_argument('-debug','--debug', action='store_true',help='Debug')
58 | parser.add_argument('-uid','--unique_id', action='store',default = 'uid',help='Unique Identifier',type=str)
59 | parser.add_argument('-seed','--seed', action='store',default = 1, help='Random Seed',type=int)
60 | parser.add_argument('-dir','--savedir', action='store',default = './chkpt', help='Prefix for savedir',type=str)
61 | parser.add_argument('-ep','--epochs', action='store',default = 2000, help='MaxEpochs',type=int)
62 | parser.add_argument('-reload','--reloadFile', action='store',default = './NOSUCHFILE', help='Reload from saved model',type=str)
63 | parser.add_argument('-params','--paramFile', action='store',default = './NOSUCHFILE', help='Reload parameters from saved model',type=str)
64 | parser.add_argument('-sfreq','--savefreq', action='store',default = 25, help='Frequency of saving',type=int)
65 | params = vars(parser.parse_args())
66 |
67 | hmap = OrderedDict()
68 | hmap['lr']='lr'
69 | hmap['var_model']='vm'
70 | hmap['inference_model']='inf'
71 | hmap['dim_hidden']='dh'
72 | hmap['dim_stochastic']='ds'
73 | hmap['nonlinearity']='nl'
74 | hmap['batch_size']='bs'
75 | hmap['epochs']='ep'
76 | hmap['rnn_size']='rs'
77 | hmap['transition_type']='ttype'
78 | hmap['emission_type']='etype'
79 | hmap['use_prev_input']='previnp'
80 | hmap['anneal_rate']='ar'
81 | hmap['reg_value']='rv'
82 | hmap['use_nade']='nade'
83 | hmap['ntrain']='nt'
84 | combined = ''
85 | for k in hmap:
86 | if k in params:
87 | if type(params[k]) is float:
88 | combined+=hmap[k]+'-'+('%.4e')%(params[k])+'-'
89 | else:
90 | combined+=hmap[k]+'-'+str(params[k])+'-'
91 |
92 | params['expt_name'] = params['unique_id']
93 | params['unique_id'] = combined[:-1]+'-'+params['unique_id']
94 | params['unique_id'] = 'DKF_'+params['unique_id'].replace('.','_')
95 | """
96 | import cPickle as pickle
97 | with open('default.pkl','wb') as f:
98 | pickle.dump(params,f)
99 | """
100 |
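101 | #A hypothetical example of the identifier assembled above: with the default
102 | #lr=8e-4 and '-uid run1', the loop would produce something like
103 | #  DKF_lr-8_0000e-04-vm-...-nt-5000-run1
104 | #(floats are rendered with '%.4e' and '.' becomes '_' in the final string).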
--------------------------------------------------------------------------------
/polyphonic_samples/jsb0.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/jsb0.mp3
--------------------------------------------------------------------------------
/polyphonic_samples/jsb1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/jsb1.mp3
--------------------------------------------------------------------------------
/polyphonic_samples/musedata0.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/musedata0.mp3
--------------------------------------------------------------------------------
/polyphonic_samples/musedata1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/musedata1.mp3
--------------------------------------------------------------------------------
/polyphonic_samples/nott0.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/nott0.mp3
--------------------------------------------------------------------------------
/polyphonic_samples/nott1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/nott1.mp3
--------------------------------------------------------------------------------
/polyphonic_samples/piano0.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/piano0.mp3
--------------------------------------------------------------------------------
/polyphonic_samples/piano1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/piano1.mp3
--------------------------------------------------------------------------------
/stinfmodel/README.md:
--------------------------------------------------------------------------------
1 | ## Main directory containing the implementation of the DKF
2 |
3 | The file dkf.py contains the class definition for the Deep Kalman Filter.
4 |
5 | It uses many of the arguments from structuredinference/parse_args_dkf.py to create the model.
6 |
7 | The main parameters are listed below:
8 |
9 | * `dim_stochastic`: dimensionality of the latent state
10 | * `dim_hidden`: number of hidden units in the generative model's
11 | emission and transition functions
12 |
13 |
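14 | A minimal sketch of how the pieces fit together (hypothetical variable names, not one of the repository's scripts; see expt-polyphonic/train_dkf.py for the real training drivers):
15 |
16 | ```
17 | import sys; sys.path.append('../')         #utils/, models/, datasets/ assumed importable
18 | from parse_args_dkf import params          #parsed command-line arguments
19 | from stinfmodel.dkf import DKF
20 | import stinfmodel.learning as DKF_learn
21 | import stinfmodel.evaluate as DKF_evaluate
22 |
23 | params['data_type']        = 'binary'
24 | params['dim_observations'] = 88            #e.g. a piano-roll
25 | dkf = DKF(params, paramFile='config.pkl')  #paramFile stores the argument dict
26 | DKF_learn.learn(dkf, train_data, train_mask,       #train_data: N x T x dim
27 |                 epoch_end=params['epochs'],        #train_mask: N x T
28 |                 batch_size=params['batch_size'])
29 | bound = DKF_evaluate.evaluateBound(dkf, valid_data, valid_mask,
30 |                                    batch_size=params['batch_size'])
31 | ```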
--------------------------------------------------------------------------------
/stinfmodel/__init__.py:
--------------------------------------------------------------------------------
1 | __all__=['dkf','evaluate','learning']
2 |
--------------------------------------------------------------------------------
/stinfmodel/dkf.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import six.moves.cPickle as pickle
3 | from collections import OrderedDict
4 | import numpy as np
5 | import sys, time, os, gzip, theano,math
6 | sys.path.append('../')
7 | from theano import config
8 | theano.config.compute_test_value = 'warn'
9 | from theano.printing import pydotprint
10 | import theano.tensor as T
11 | from utils.misc import saveHDF5
12 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
13 |
14 | from utils.optimizer import adam,rmsprop
15 | from models.__init__ import BaseModel
16 | from datasets.synthp import params_synthetic
17 | from datasets.synthpTheano import updateParamsSynthetic
18 |
19 | #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%#
20 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
21 | #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%#
22 |
23 | """
24 | DEEP KALMAN FILTER
25 | """
26 | class DKF(BaseModel, object):
27 | def __init__(self, params, paramFile=None, reloadFile=None):
28 | self.scan_updates = []
29 | super(DKF,self).__init__(params, paramFile=paramFile, reloadFile=reloadFile)
30 | if 'synthetic' in self.params['dataset'] and not hasattr(self, 'params_synthetic'):
31 | assert False, 'Expecting to have params_synthetic as an attribute in DKF class'
32 | assert self.params['nonlinearity']!='maxout','Maxout nonlinearity not supported'
33 |
34 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
35 | def _fakeData(self):
36 | """ Fake data for tag testing """
37 | T = 3
38 | N = 2
39 | mask = np.random.random((N,T)).astype(config.floatX)
40 | small = mask<0.5
41 | large = mask>=0.5
42 | mask[small] = 0.
43 | mask[large]= 1.
44 | eps = np.random.randn(N,T,self.params['dim_stochastic']).astype(config.floatX)
45 | X = np.random.randn(N,T,self.params['dim_observations']).astype(config.floatX)
46 | return X ,mask, eps
47 |
48 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
49 | def _createParams(self):
50 | """ Model parameters """
51 | npWeights = OrderedDict()
52 | self._createInferenceParams(npWeights)
53 | self._createGenerativeParams(npWeights)
54 | return npWeights
55 |
56 | def _createGenerativeParams(self, npWeights):
57 | """ Create weights/params for generative model """
58 | if 'synthetic' in self.params['dataset']:
59 | return
60 | DIM_HIDDEN = self.params['dim_hidden']
61 | DIM_STOCHASTIC = self.params['dim_stochastic']
62 | if self.params['transition_type']=='mlp':
63 | DIM_HIDDEN_TRANS = DIM_HIDDEN*2
64 | for l in range(self.params['transition_layers']):
65 | dim_input,dim_output = DIM_HIDDEN_TRANS, DIM_HIDDEN_TRANS
66 | if l==0:
67 | dim_input = self.params['dim_stochastic']
68 | npWeights['p_trans_W_'+str(l)] = self._getWeight((dim_input, dim_output))
69 | npWeights['p_trans_b_'+str(l)] = self._getWeight((dim_output,))
70 | if self.params['use_prev_input']:
71 | npWeights['p_trans_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'],
72 | DIM_HIDDEN_TRANS))
73 | npWeights['p_trans_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
74 | MU_COV_INP = DIM_HIDDEN_TRANS
75 | elif self.params['transition_type']=='simple_gated':
76 | DIM_HIDDEN_TRANS = DIM_HIDDEN*2
77 | npWeights['p_gate_embed_W_0'] = self._getWeight((DIM_STOCHASTIC, DIM_HIDDEN_TRANS))
78 | npWeights['p_gate_embed_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
79 | npWeights['p_gate_embed_W_1'] = self._getWeight((DIM_HIDDEN_TRANS, DIM_STOCHASTIC))
80 | npWeights['p_gate_embed_b_1'] = self._getWeight((DIM_STOCHASTIC,))
81 | npWeights['p_z_W_0'] = self._getWeight((DIM_STOCHASTIC, DIM_HIDDEN_TRANS))
82 | npWeights['p_z_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
83 | npWeights['p_z_W_1'] = self._getWeight((DIM_HIDDEN_TRANS, DIM_STOCHASTIC))
84 | npWeights['p_z_b_1'] = self._getWeight((DIM_STOCHASTIC,))
85 | if self.params['use_prev_input']:
86 | npWeights['p_z_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'], DIM_HIDDEN_TRANS))
87 | npWeights['p_z_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
88 | npWeights['p_gate_embed_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'],
89 | DIM_HIDDEN_TRANS))
90 | npWeights['p_gate_embed_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
91 | MU_COV_INP = DIM_STOCHASTIC
92 | else:
93 | assert False,'Invalid transition type: '+self.params['transition_type']
94 |
95 | if self.params['transition_type']=='simple_gated':
96 | weight= np.eye(self.params['dim_stochastic']).astype(config.floatX)
97 | bias = np.zeros((self.params['dim_stochastic'],)).astype(config.floatX)
98 | #Initialize the weights to be identity
99 | npWeights['p_trans_W_mu'] = weight
100 | npWeights['p_trans_b_mu'] = bias
101 | else:
102 | npWeights['p_trans_W_mu'] = self._getWeight((MU_COV_INP, self.params['dim_stochastic']))
103 | npWeights['p_trans_b_mu'] = self._getWeight((self.params['dim_stochastic'],))
104 | npWeights['p_trans_W_cov'] = self._getWeight((MU_COV_INP, self.params['dim_stochastic']))
105 | npWeights['p_trans_b_cov'] = self._getWeight((self.params['dim_stochastic'],))
106 |
107 | #Emission Function [MLP]
108 | if self.params['emission_type'] == 'mlp':
109 | for l in range(self.params['emission_layers']):
110 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN
111 | if l==0:
112 | dim_input = self.params['dim_stochastic']
113 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output))
114 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,))
115 | elif self.params['emission_type'] =='conditional':
116 | for l in range(self.params['emission_layers']):
117 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN
118 | if l==0:
119 | dim_input = self.params['dim_stochastic']+self.params['dim_observations']
120 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output))
121 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,))
122 | else:
123 | assert False, 'Invalid emission type: '+str(self.params['emission_type'])
124 | if self.params['data_type']=='binary':
125 | npWeights['p_emis_W_ber'] = self._getWeight((self.params['dim_hidden'], self.params['dim_observations']))
126 | npWeights['p_emis_b_ber'] = self._getWeight((self.params['dim_observations'],))
127 | elif self.params['data_type']=='binary_nade':
128 | n_visible, n_hidden = self.params['dim_observations'], self.params['dim_hidden']
129 | npWeights['p_nade_W'] = self._getWeight((n_visible, n_hidden))
130 | npWeights['p_nade_U'] = self._getWeight((n_visible,n_hidden))
131 | npWeights['p_nade_b'] = self._getWeight((n_visible,))
132 | else:
133 | assert False,'Invalid datatype: '+self.params['data_type']
134 |
135 | def _createInferenceParams(self, npWeights):
136 | """ Create weights/params for inference network """
137 |
138 | #Initial embedding for the inputs
139 | DIM_INPUT = self.params['dim_observations']
140 | RNN_SIZE = self.params['rnn_size']
141 |
142 | DIM_HIDDEN = RNN_SIZE
143 | DIM_STOC = self.params['dim_stochastic']
144 |
145 | #Embed the Input -> RNN_SIZE
146 | dim_input, dim_output= DIM_INPUT, RNN_SIZE
147 | npWeights['q_W_input_0'] = self._getWeight((dim_input, dim_output))
148 | npWeights['q_b_input_0'] = self._getWeight((dim_output,))
149 |
150 | #Setup weights for LSTM
151 | self._createLSTMWeights(npWeights)
152 |
153 | #Embedding before MF/ST inference model
154 | if self.params['inference_model']=='mean_field':
155 | pass
156 | elif self.params['inference_model']=='structured':
157 | DIM_INPUT = self.params['dim_stochastic']
158 | if self.params['use_generative_prior']:
159 | DIM_INPUT = self.params['rnn_size']
160 | npWeights['q_W_st_0'] = self._getWeight((DIM_INPUT, self.params['rnn_size']))
161 | npWeights['q_b_st_0'] = self._getWeight((self.params['rnn_size'],))
162 | else:
163 | assert False,'Invalid inference model: '+self.params['inference_model']
164 | RNN_SIZE = self.params['rnn_size']
165 | npWeights['q_W_mu'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic']))
166 | npWeights['q_b_mu'] = self._getWeight((self.params['dim_stochastic'],))
167 | npWeights['q_W_cov'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic']))
168 | npWeights['q_b_cov'] = self._getWeight((self.params['dim_stochastic'],))
169 | if self.params['var_model']=='LR' and self.params['inference_model']=='mean_field':
170 | npWeights['q_W_mu_r'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic']))
171 | npWeights['q_b_mu_r'] = self._getWeight((self.params['dim_stochastic'],))
172 | npWeights['q_W_cov_r'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic']))
173 | npWeights['q_b_cov_r'] = self._getWeight((self.params['dim_stochastic'],))
174 |
175 | def _createLSTMWeights(self, npWeights):
176 | #LSTM L/LR/R w/ orthogonal weight initialization
177 | suffices_to_build = []
178 | if self.params['var_model']=='LR' or self.params['var_model']=='L':
179 | suffices_to_build.append('l')
180 | if self.params['var_model']=='LR' or self.params['var_model']=='R':
181 | suffices_to_build.append('r')
182 | RNN_SIZE = self.params['rnn_size']
183 | for suffix in suffices_to_build:
184 | for l in range(self.params['rnn_layers']):
185 | npWeights['W_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE,RNN_SIZE*4))
186 | npWeights['b_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE*4,), scheme='lstm')
187 | npWeights['U_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE,RNN_SIZE*4),scheme='lstm')
188 |
189 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
190 | def _getEmissionFxn(self, z, X=None):
191 | """
192 | Apply emission function to zs
193 | Input: z [bs x T x dim]
194 | Output: (params, ) or (mu, cov) of size [bs x T x dim]
195 | """
196 | if 'synthetic' in self.params['dataset']:
197 | self._p('Using emission function for '+self.params['dataset'])
198 | mu = self.params_synthetic[self.params['dataset']]['obs_fxn'](z)
199 | cov = T.ones_like(mu)*self.params_synthetic[self.params['dataset']]['obs_cov']
200 | cov.name = 'EmissionCov'
201 | return [mu,cov]
202 |
203 | if self.params['emission_type']=='mlp':
204 | self._p('EMISSION TYPE: MLP')
205 | hid = z
206 | elif self.params['emission_type']=='conditional':
207 | self._p('EMISSION TYPE: conditional')
208 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1)
209 | hid = T.concatenate([z,X_prev],axis=2)
210 | else:
211 | assert False,'Invalid emission type'
212 |
213 | self._p('TODO: FIX THIS, SHOULD BE LINEAR FOR NADE')
214 | for l in range(self.params['emission_layers']):
215 | #Do not use a non-linearity in the last layer
216 | #if l==self.params['emission_layers']-1:
217 | # hid = T.dot(hid, self.tWeights['p_emis_W_'+str(l)]) + self.tWeights['p_emis_b_'+str(l)]
218 | hid = self._LinearNL(self.tWeights['p_emis_W_'+str(l)], self.tWeights['p_emis_b_'+str(l)], hid)
219 | if self.params['data_type']=='binary':
220 | mean_params=T.nnet.sigmoid(T.dot(hid,self.tWeights['p_emis_W_ber'])+self.tWeights['p_emis_b_ber'])
221 | return [mean_params]
222 | elif self.params['data_type']=='binary_nade':
223 | self._p('NADE observations')
224 | assert X is not None,'Need observations for NADE'
225 | x_reshaped = X.dimshuffle(2,0,1)
226 | x0 = T.ones((hid.shape[0],hid.shape[1])) #instead of x_reshaped[0]; bs x T
227 | a0 = hid #bs x T x nhid
228 | W = self.tWeights['p_nade_W']
229 | V = self.tWeights['p_nade_U']
230 | b = self.tWeights['p_nade_b']
231 | #Use a NADE at the output
232 | def NADEDensity(x, w, v, b, a_prev, x_prev):#Estimating likelihood
233 | a = a_prev + T.dot(T.shape_padright(x_prev, 1), T.shape_padleft(w, 1))
234 | h = T.nnet.sigmoid(a) #bs x T x nhid
235 | p_xi_is_one = T.nnet.sigmoid(T.dot(h, v) + b)
236 | return (a, x, p_xi_is_one)
237 | ([_, _, mean_params], _) = theano.scan(NADEDensity,
238 | sequences=[x_reshaped, W, V, b],
239 | outputs_info=[a0, x0,None])
240 | #theano function to sample from NADE
241 | def NADESample(w, v, b, a_prev_s, x_prev_s):
242 | a_s = a_prev_s + T.dot(T.shape_padright(x_prev_s, 1), T.shape_padleft(w, 1))
243 | h_s = T.nnet.sigmoid(a_s) #bs x T x nhid
244 | p_xi_is_one_s = T.nnet.sigmoid(T.dot(h_s, v) + b)
245 | x_s = T.switch(p_xi_is_one_s>0.5,1.,0.)
246 | return (a_s, x_s, p_xi_is_one_s)
247 | ([_, _, sampled_params], _) = theano.scan(NADESample,
248 | sequences=[W, V, b],
249 | outputs_info=[a0, x0,None])
250 | """
251 | def NADEDensityAndSample(x, w, v, b,
252 | a_prev, x_prev,
253 | a_prev_s, x_prev_s ):
254 | a = a_prev + T.dot(T.shape_padright(x_prev, 1), T.shape_padleft(w, 1))
255 | h = T.nnet.sigmoid(a) #bs x T x nhid
256 | p_xi_is_one = T.nnet.sigmoid(T.dot(h, v) + b)
257 |
258 | a_s = a_prev_s + T.dot(T.shape_padright(x_prev_s, 1), T.shape_padleft(w, 1))
259 | h_s = T.nnet.sigmoid(a_s) #bs x T x nhid
260 | p_xi_is_one_s = T.nnet.sigmoid(T.dot(h_s, v) + b)
261 | x_s = T.switch(p_xi_is_one_s>0.5,1.,0.)
262 | return (a, x, a_s, x_s, p_xi_is_one, p_xi_is_one_s)
263 |
264 | ([_, _, _, _, mean_params,sampled_params], _) = theano.scan(NADEDensityAndSample,
265 | sequences=[x_reshaped, W, V, b],
266 | outputs_info=[a0, x0, a0, x0, None, None])
267 | """
268 | sampled_params = sampled_params.dimshuffle(1,2,0)
269 | mean_params = mean_params.dimshuffle(1,2,0)
270 | return [mean_params,sampled_params]
271 | else:
272 | assert False,'Invalid type of data'
273 |
274 | def _getTransitionFxn(self, z, X=None):
275 | """
276 | Apply transition function to zs
277 | Input: z [bs x T x dim], u [bs x T x dim]
278 | Output: mu, cov of size [bs x T x dim]
279 | """
280 | if 'synthetic' in self.params['dataset']:
281 | self._p('Using transition function for '+self.params['dataset'])
282 | mu = self.params_synthetic[self.params['dataset']]['trans_fxn'](z)
283 | cov = T.ones_like(mu)*self.params_synthetic[self.params['dataset']]['trans_cov']
284 | cov.name = 'TransitionCov'
285 | return mu,cov
286 |
287 | if self.params['transition_type']=='simple_gated':
288 | def mlp(inp, W1,b1,W2,b2, X_prev=None):
289 | if X_prev is not None:
290 | h1 = self._LinearNL(W1,b1, T.concatenate([inp,X_prev],axis=2))
291 | else:
292 | h1 = self._LinearNL(W1,b1, inp)
293 | h2 = T.dot(h1,W2)+b2
294 | return h2
295 |
296 | gateInp= z
297 | X_prev = None
298 | if self.params['use_prev_input']:
299 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1)
300 | gate = T.nnet.sigmoid(mlp(gateInp, self.tWeights['p_gate_embed_W_0'], self.tWeights['p_gate_embed_b_0'],
301 | self.tWeights['p_gate_embed_W_1'],self.tWeights['p_gate_embed_b_1'],
302 | X_prev = X_prev))
303 | z_prop = mlp(z,self.tWeights['p_z_W_0'] ,self.tWeights['p_z_b_0'],
304 | self.tWeights['p_z_W_1'] , self.tWeights['p_z_b_1'], X_prev = X_prev)
305 | mu = gate*z_prop + (1.-gate)*(T.dot(z, self.tWeights['p_trans_W_mu'])+self.tWeights['p_trans_b_mu'])
306 | cov = T.nnet.softplus(T.dot(self._applyNL(z_prop), self.tWeights['p_trans_W_cov'])+
307 | self.tWeights['p_trans_b_cov'])
308 | return mu,cov
309 | elif self.params['transition_type']=='mlp':
310 | hid = z
311 | if self.params['use_prev_input']:
312 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1)
313 | hid = T.concatenate([hid,X_prev],axis=2)
314 | for l in range(self.params['transition_layers']):
315 | hid = self._LinearNL(self.tWeights['p_trans_W_'+str(l)],self.tWeights['p_trans_b_'+str(l)],hid)
316 | mu = T.dot(hid, self.tWeights['p_trans_W_mu']) + self.tWeights['p_trans_b_mu']
317 | cov = T.nnet.softplus(T.dot(hid, self.tWeights['p_trans_W_cov'])+self.tWeights['p_trans_b_cov'])
318 | return mu,cov
319 | else:
320 | assert False,'Invalid Transition type: '+str(self.params['transition_type'])
321 |
322 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
323 | def _buildLSTM(self, X, embedding, dropout_prob = 0.):
324 | """
325 | Take the embedding of bs x T x dim and return T x bs x dim that is the result of the scan operation
326 | for L/LR/R
327 | Input: embedding [bs x T x dim]
328 | Output:hidden_state [T x bs x dim]
329 | """
330 | start_time = time.time()
331 | self._p('In <_buildLSTM>')
332 | suffix = ''
333 | if self.params['var_model']=='R':
334 | suffix='r'
335 | return self._LSTMlayer(embedding, suffix, dropout_prob)
336 | elif self.params['var_model']=='L':
337 | suffix='l'
338 | return self._LSTMlayer(embedding, suffix, dropout_prob)
339 | elif self.params['var_model']=='LR':
340 | suffix='l'
341 | l2r = self._LSTMlayer(embedding, suffix, dropout_prob)
342 | suffix='r'
343 | r2l = self._LSTMlayer(embedding, suffix, dropout_prob)
344 | return [l2r,r2l]
345 | else:
346 | assert False,'Invalid variational model'
347 | self._p(('Done <_buildLSTM> [Took %.4f]')%(time.time()-start_time))
348 |
349 |
350 | def _inferenceLayer(self, hidden_state, eps):
351 | """
352 | Take input of T x bs x dim and return z, mu,
353 | sq each of size (bs x T x dim)
354 | Input: hidden_state [T x bs x dim], eps [bs x T x dim]
355 | Output: z [bs x T x dim], mu [bs x T x dim], cov [bs x T x dim]
356 | """
357 | def structuredApproximation(h_t, eps_t, z_prev,
358 | q_W_st_0, q_b_st_0,
359 | q_W_mu, q_b_mu,
360 | q_W_cov,q_b_cov):
361 | #Using the prior distribution directly
362 | if self.params['use_generative_prior']:
363 | assert not self.params['use_prev_input'],'No support for using previous input'
364 | #Get mu/cov from z_prev through prior distribution
365 | mu_1,cov_1 = self._getTransitionFxn(z_prev)
366 | #Combine with estimate of mu/cov from data
367 | h_data = T.tanh(T.dot(h_t,q_W_st_0)+q_b_st_0)
368 | mu_2 = T.dot(h_data,q_W_mu)+q_b_mu
369 | cov_2 = T.nnet.softplus(T.dot(h_data,q_W_cov)+q_b_cov)
370 | mu = (mu_1*cov_2+mu_2*cov_1)/(cov_1+cov_2)
371 | cov = (cov_1*cov_2)/(cov_1+cov_2)
372 | z = mu + T.sqrt(cov)*eps_t
373 | else:
374 | h_next = T.tanh(T.dot(z_prev,q_W_st_0)+q_b_st_0)
375 | if self.params['var_model']=='LR':
376 | h_next = (1./3.)*(h_t+h_next)
377 | else:
378 | h_next = (1./2.)*(h_t+h_next)
379 | mu_t = T.dot(h_next,q_W_mu)+q_b_mu
380 | cov_t = T.nnet.softplus(T.dot(h_next,q_W_cov)+q_b_cov)
381 | z_t = mu_t+T.sqrt(cov_t)*eps_t
382 | return z_t, mu_t, cov_t
383 |
384 | if self.params['inference_model']=='structured':
385 | #Structured recognition networks
386 | if self.params['var_model']=='LR':
387 | state = hidden_state[0]+hidden_state[1]
388 | else:
389 | state = hidden_state
390 | eps_swap = eps.swapaxes(0,1)
391 | if self.params['dim_stochastic']==1:
392 | """
393 | TODO: Write to theano authors regarding this issue.
394 | Workaround for theano issue: the result of a matrix multiply is a "matrix"
395 | even if one of its dimensions is 1, so a tensor declared with a dimension of
396 | size 1 is regarded by theano as a matrix and, inside the scan, as a column.
397 | This mismatch in tensor type between input (column) and output (matrix)
398 | throws an error. The dot with a zero matrix below preserves the tensor type
399 | while leaving the initial state all zeros, as intended.
400 | """
401 | z0 = T.zeros((eps_swap.shape[1], self.params['rnn_size']))
402 | z0 = T.dot(z0,T.zeros_like(self.tWeights['q_W_mu']))
403 | else:
404 | z0 = T.zeros((eps_swap.shape[1], self.params['dim_stochastic']))
405 | rval, _ = theano.scan(structuredApproximation,
406 | sequences=[state, eps_swap],
407 | outputs_info=[z0, None,None],
408 | non_sequences=[self.tWeights[k] for k in
409 | ['q_W_st_0', 'q_b_st_0']]+
410 | [self.tWeights[k] for k in
411 | ['q_W_mu','q_b_mu','q_W_cov','q_b_cov']],
412 | name='structuredApproximation')
413 | z, mu, cov = rval[0].swapaxes(0,1), rval[1].swapaxes(0,1), rval[2].swapaxes(0,1)
414 | return z, mu, cov
415 | elif self.params['inference_model']=='mean_field':
416 | if self.params['var_model']=='LR':
417 | l2r = hidden_state[0].swapaxes(0,1)
418 | r2l = hidden_state[1].swapaxes(0,1)
419 | hidl2r = l2r
420 | mu_1 = T.dot(hidl2r,self.tWeights['q_W_mu'])+self.tWeights['q_b_mu']
421 | cov_1 = T.nnet.softplus(T.dot(hidl2r, self.tWeights['q_W_cov'])+self.tWeights['q_b_cov'])
422 | hidr2l = r2l
423 | mu_2 = T.dot(hidr2l,self.tWeights['q_W_mu_r'])+self.tWeights['q_b_mu_r']
424 | cov_2 = T.nnet.softplus(T.dot(hidr2l, self.tWeights['q_W_cov_r'])+self.tWeights['q_b_cov_r'])
425 | mu = (mu_1*cov_2+mu_2*cov_1)/(cov_1+cov_2)
426 | cov= (cov_1*cov_2)/(cov_1+cov_2)
427 | z = mu + T.sqrt(cov)*eps
428 | else:
429 | hid = hidden_state.swapaxes(0,1)
430 | mu = T.dot(hid,self.tWeights['q_W_mu'])+self.tWeights['q_b_mu']
431 | cov = T.nnet.softplus(T.dot(hid,self.tWeights['q_W_cov'])+ self.tWeights['q_b_cov'])
432 | z = mu + T.sqrt(cov)*eps
433 | return z,mu,cov
434 | else:
435 | assert False,'Invalid recognition model'
436 |
437 | def _qEmbeddingLayer(self, X):
438 | """ Embed for q """
439 | return self._LinearNL(self.tWeights['q_W_input_0'],self.tWeights['q_b_input_0'], X)
440 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
441 |
442 | def _inferenceAndReconstruction(self, X, eps, dropout_prob = 0.):
443 | """
444 | Returns z_q, mu_q and cov_q
445 | """
446 | self._p('Building with dropout:'+str(dropout_prob))
447 | embedding = self._qEmbeddingLayer(X)
448 | hidden_state = self._buildLSTM(X, embedding, dropout_prob)
449 | z_q,mu_q,cov_q = self._inferenceLayer(hidden_state, eps)
450 |
451 | #Regularize z_q (for train)
452 | #if dropout_prob>0.:
453 | # z_q = z_q + self.srng.normal(z_q.shape, 0.,0.0025,dtype=config.floatX)
454 | #z_q.name = 'z_q'
455 |
456 | observation_params = self._getEmissionFxn(z_q,X=X)
457 | mu_trans, cov_trans= self._getTransitionFxn(z_q, X=X)
458 | mu_prior = T.concatenate([T.alloc(np.asarray(0.).astype(config.floatX),
459 | X.shape[0],1,self.params['dim_stochastic']), mu_trans[:,:-1,:]],axis=1)
460 | cov_prior = T.concatenate([T.alloc(np.asarray(1.).astype(config.floatX),
461 | X.shape[0],1,self.params['dim_stochastic']), cov_trans[:,:-1,:]],axis=1)
462 | return observation_params, z_q, mu_q, cov_q, mu_prior, cov_prior, mu_trans, cov_trans
463 |
464 |
465 | def _getTemporalKL(self, mu_q, cov_q, mu_prior, cov_prior, M, batchVector = False):
466 | """
467 | TemporalKL divergence KL (q||p)
468 | KL(q_t||p_t) = 0.5*(log|sigmasq_p| -log|sigmasq_q| -D + Tr(sigmasq_p^-1 sigmasq_q)
469 | + (mu_p-mu_q)^T sigmasq_p^-1 (mu_p-mu_q))
470 | M is a mask of size bs x T that should be applied once the KL divergence for each point
471 | across time has been estimated
472 | """
473 | assert np.all(cov_q.tag.test_value>0.),'should be positive'
474 | assert np.all(cov_prior.tag.test_value>0.),'should be positive'
475 | diff_mu = mu_prior-mu_q
476 | KL_t = T.log(cov_prior)-T.log(cov_q) - 1. + cov_q/cov_prior + diff_mu**2/cov_prior
477 | KLvec = (0.5*KL_t.sum(2)*M).sum(1,keepdims=True)
478 | if batchVector:
479 | return KLvec
480 | else:
481 | return KLvec.sum()
482 |
483 | def _getNegCLL(self, obs_params, X, M, batchVector = False):
484 | """
485 | Estimate the negative conditional log likelihood of x|z under the generative model
486 | M: mask of size bs x T
487 | X: target of size bs x T x dim
488 | """
489 | if self.params['data_type']=='real':
490 | mu_p = obs_params[0]
491 | cov_p = obs_params[1]
492 | std_p = T.sqrt(cov_p)
493 | negCLL_t = 0.5 * np.log(2 * np.pi) + 0.5*T.log(cov_p) + 0.5 * ((X - mu_p) / std_p)**2
494 | negCLL = (negCLL_t.sum(2)*M).sum(1,keepdims=True)
495 | else:
496 | mean_p = obs_params[0]
497 | negCLL = (T.nnet.binary_crossentropy(mean_p,X).sum(2)*M).sum(1,keepdims=True)
498 | if batchVector:
499 | return negCLL
500 | else:
501 | return negCLL.sum()
502 |
503 | def _buildModel(self):
504 | if 'synthetic' in self.params['dataset']:
505 | self.params_synthetic = params_synthetic
506 | """ High level function to build and setup theano functions """
507 | X = T.tensor3('X', dtype=config.floatX)
508 | eps = T.tensor3('eps', dtype=config.floatX)
509 | M = T.matrix('M', dtype=config.floatX)
510 | X.tag.test_value, M.tag.test_value, eps.tag.test_value = self._fakeData()
511 |
512 | #Learning Rates and annealing objective function
513 | #Add them to npWeights/tWeights to be tracked [do not have a prefix _W or _b so wont be diff.]
514 | self._addWeights('lr', np.asarray(self.params['lr'],dtype=config.floatX),borrow=False)
515 | self._addWeights('anneal', np.asarray(0.01,dtype=config.floatX),borrow=False)
516 | self._addWeights('update_ctr', np.asarray(1.,dtype=config.floatX),borrow=False)
517 | lr = self.tWeights['lr']
518 | anneal = self.tWeights['anneal']
519 | iteration_t = self.tWeights['update_ctr']
520 |
521 | anneal_div = 1000.
522 | if 'anneal_rate' in self.params:
523 | self._p('Anneal = 1 in '+str(self.params['anneal_rate'])+' param. updates')
524 | anneal_div = self.params['anneal_rate']
525 | if 'synthetic' in self.params['dataset']:
526 | anneal_div = 100.
527 | anneal_update = [(iteration_t, iteration_t+1),
528 | (anneal,T.switch(0.01+iteration_t/anneal_div>1,1,0.01+iteration_t/anneal_div))]
529 | fxn_inputs = [X, M, eps]
530 | if not self.params['validate_only']:
531 | print '****** CREATING TRAINING FUNCTION*****'
532 | ############# Setup training functions ###########
533 | obs_params, z_q, mu_q, cov_q, mu_prior, cov_prior, _, _ = self._inferenceAndReconstruction(
534 | X, eps,
535 | dropout_prob = self.params['rnn_dropout'])
536 | negCLL = self._getNegCLL(obs_params, X, M)
537 | TemporalKL = self._getTemporalKL(mu_q, cov_q, mu_prior, cov_prior, M)
538 | train_cost = negCLL+anneal*TemporalKL
539 |
540 | #Get updates from optimizer
541 | model_params = self._getModelParams()
542 | optimizer_up, norm_list = self._setupOptimizer(train_cost, model_params,lr = lr,
543 | reg_type =self.params['reg_type'],
544 | reg_spec =self.params['reg_spec'],
545 | reg_value= self.params['reg_value'],
546 | divide_grad = T.cast(X.shape[0],dtype=config.floatX),
547 | grad_norm = 1.)
548 |
549 | #Add annealing updates
550 | optimizer_up +=anneal_update+self.updates
551 | self._p(str(len(self.updates))+' other updates')
552 | ############# Setup train & evaluate functions ###########
553 | self.train_debug = theano.function(fxn_inputs,[train_cost,norm_list[0],norm_list[1],
554 | norm_list[2],negCLL, TemporalKL, anneal.sum()],
555 | updates = optimizer_up, name='Train (with Debug)')
556 | #Updates ack
557 | self.updates_ack = True
558 | eval_obs_params, eval_z_q, eval_mu_q, eval_cov_q, eval_mu_prior, eval_cov_prior, \
559 | eval_mu_trans, eval_cov_trans = self._inferenceAndReconstruction(
560 | X, eps,
561 | dropout_prob = 0.)
562 | eval_z_q.name = 'eval_z_q'
563 | eval_CNLLvec=self._getNegCLL(eval_obs_params, X, M, batchVector = True)
564 | eval_KLvec = self._getTemporalKL(eval_mu_q, eval_cov_q,eval_mu_prior, eval_cov_prior, M, batchVector = True)
565 | eval_cost = eval_CNLLvec + eval_KLvec
566 |
567 | #From here on, convert to the log covariance since we only use it for evaluation
568 | assert np.all(eval_cov_q.tag.test_value>0.),'should be positive'
569 | assert np.all(eval_cov_prior.tag.test_value>0.),'should be positive'
570 | assert np.all(eval_cov_trans.tag.test_value>0.),'should be positive'
571 | eval_logcov_q = T.log(eval_cov_q)
572 | eval_logcov_prior = T.log(eval_cov_prior)
573 | eval_logcov_trans = T.log(eval_cov_trans)
574 |
575 | ll_prior = self._llGaussian(eval_z_q, eval_mu_prior, eval_logcov_prior).sum(2)*M
576 | ll_posterior = self._llGaussian(eval_z_q, eval_mu_q, eval_logcov_q).sum(2)*M
577 | ll_estimate = -1*eval_CNLLvec+ll_prior.sum(1,keepdims=True)-ll_posterior.sum(1,keepdims=True)
578 |
579 | eval_inputs = [eval_z_q]
580 | self.likelihood = theano.function(fxn_inputs, ll_estimate, name = 'Importance Sampling based likelihood')
581 | self.evaluate = theano.function(fxn_inputs, eval_cost, name = 'Evaluate Bound')
582 | if self.params['use_prev_input']:
583 | eval_inputs.append(X)
584 | self.transition_fxn = theano.function(eval_inputs,[eval_mu_trans, eval_logcov_trans],
585 | name='Transition Function')
586 | emission_inputs = [eval_z_q]
587 | if self.params['emission_type']=='conditional':
588 | emission_inputs.append(X)
589 | if self.params['data_type']=='binary_nade':
590 | self.emission_fxn = theano.function(emission_inputs,
591 | eval_obs_params[1], name='Emission Function')
592 | else:
593 | self.emission_fxn = theano.function(emission_inputs,
594 | eval_obs_params[0], name='Emission Function')
595 | self.posterior_inference = theano.function([X, eps],
596 | [eval_z_q, eval_mu_q, eval_logcov_q],
597 | name='Posterior Inference')
598 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
599 | if __name__=='__main__':
600 | """ use this to check compilation for various options"""
601 | from parse_args_dkf import params
602 | if params['use_nade']:
603 | params['data_type'] = 'binary_nade'
604 | else:
605 | params['data_type'] = 'binary'
606 | params['dim_observations'] = 10
607 | dkf = DKF(params, paramFile = 'tmp')
608 | os.unlink('tmp')
609 |
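610 | #Hypothetical compile check, mirroring testall.sh:
611 | #  THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False" python2.7 dkf.py -vm R -infm structured -dset jsb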
--------------------------------------------------------------------------------
/stinfmodel/evaluate.py:
--------------------------------------------------------------------------------
1 |
2 | from theano import config
3 | import numpy as np
4 | import time
5 | """
6 | Functions for evaluating a DKF object
7 | """
8 | def infer(dkf, dataset):
9 | """ Posterior Inference using recognition network
10 | Returns: z, mu, logcov (each a 3D tensor). Remember to multiply each by the dataset's mask
11 | before using the latent variables.
12 | """
13 | assert len(dataset.shape)==3,'Expecting 3D tensor for data'
14 | assert dataset.shape[2]==dkf.params['dim_observations'],'Data dim. not matching'
15 | eps = np.random.randn(dataset.shape[0],
16 | dataset.shape[1],
17 | dkf.params['dim_stochastic']).astype(config.floatX)
18 | return dkf.posterior_inference(X=dataset.astype(config.floatX), eps=eps)
19 |
20 | def evaluateBound(dkf, dataset, mask, batch_size,S=2, normalization = 'frame', additional={}):
21 | """ Evaluate ELBO """
22 | bound = 0
23 | start_time = time.time()
24 | N = dataset.shape[0]
25 |
26 | tsbn_bound = 0
27 | for bnum,st_idx in enumerate(range(0,N,batch_size)):
28 | end_idx = min(st_idx+batch_size, N)
29 | X = dataset[st_idx:end_idx,:,:].astype(config.floatX)
30 | M = mask[st_idx:end_idx,:].astype(config.floatX)
31 | U = None
32 |
33 | #Reduce the dimensionality of the tensors based on the maximum size of the mask
34 | maxT = int(np.max(M.sum(1)))
35 | X = X[:,:maxT,:]
36 | M = M[:,:maxT]
37 | eps = np.random.randn(X.shape[0],maxT,dkf.params['dim_stochastic']).astype(config.floatX)
38 | maxS = S
39 | bound_sum, tsbn_bound_sum = 0, 0
40 | for s in range(S):
41 | if s>0 and s%500==0:
42 | dkf._p('Done '+str(s))
43 | eps = np.random.randn(X.shape[0],maxT,dkf.params['dim_stochastic']).astype(config.floatX)
44 | batch_vec= dkf.evaluate(X=X, M=M, eps=eps)
45 | if np.any(np.isnan(batch_vec)) or np.any(np.isinf(batch_vec)):
46 | dkf._p('NaN detected during evaluation. Ignoring this sample')
47 | maxS -=1
48 | continue
49 | else:
50 | tsbn_bound_sum+=(batch_vec/M.sum(1,keepdims=True)).sum()
51 | bound_sum+=batch_vec.sum()
52 | tsbn_bound += tsbn_bound_sum/float(max(maxS*N,1.))
53 | bound += bound_sum/float(max(maxS,1.))
54 | if normalization=='frame':
55 | bound /= float(mask.sum())
56 | elif normalization=='sequence':
57 | bound /= float(N)
58 | else:
59 | assert False,'Invalid normalization specified'
60 | end_time = time.time()
61 | dkf._p(('(Evaluate) Validation Bound: %.4f [Took %.4f seconds], TSBN Bound: %.4f')%(bound,end_time-start_time,tsbn_bound))
62 | additional['tsbn_bound'] = tsbn_bound
63 | return bound
64 |
65 |
66 | def impSamplingNLL(dkf, dataset, mask, batch_size, S = 2, normalization = 'frame'):
67 | """ Importance sampling based log likelihood """
68 | ll = 0
69 | start_time = time.time()
70 | N = dataset.shape[0]
71 | for bnum,st_idx in enumerate(range(0,N,batch_size)):
72 | end_idx = min(st_idx+batch_size, N)
73 | X = dataset[st_idx:end_idx,:,:].astype(config.floatX)
74 | M = mask[st_idx:end_idx,:].astype(config.floatX)
75 | U = None
76 | maxT = int(np.max(M.sum(1)))
77 | X = X[:,:maxT,:]
78 | M = M[:,:maxT]
79 | eps = np.random.randn(X.shape[0],maxT,dkf.params['dim_stochastic']).astype(config.floatX)
80 | maxS = S
81 | lllist = []
82 | for s in range(S):
83 | if s>0 and s%500==0:
84 | dkf._p('Done '+str(s))
85 | eps = np.random.randn(X.shape[0],maxT,
86 | dkf.params['dim_stochastic']).astype(config.floatX)
87 | batch_vec = dkf.likelihood(X=X, M=M, eps=eps)
88 | if np.any(np.isnan(batch_vec)) or np.any(np.isinf(batch_vec)):
89 | dkf._p('NaN detected during evaluation. Ignoring this sample')
90 | maxS -=1
91 | continue
92 | else:
93 | lllist.append(batch_vec)
94 | ll += dkf.meanSumExp(np.concatenate(lllist,axis=1), axis=1).sum()
95 | if normalization=='frame':
96 | ll /= float(mask.sum())
97 | elif normalization=='sequence':
98 | ll /= float(N)
99 | else:
100 | assert False,'Invalid normalization specified'
101 | end_time = time.time()
102 | dkf._p(('(Evaluate w/ Imp. Sampling) Validation LL: %.4f [Took %.4f seconds]')%(ll,end_time-start_time))
103 | return ll
104 |
105 |
106 | def sampleGaussian(dkf,mu,logcov):
107 | return mu + np.random.randn(*mu.shape)*np.exp(0.5*logcov)
108 |
109 | def sample(dkf, nsamples=100, T=10, additional = {}):
110 | """
111 | Sample from Generative Model
112 | """
113 | assert T>1, 'Sample at least 2 timesteps'
114 | #Initial sample
115 | z = np.random.randn(nsamples,1,dkf.params['dim_stochastic']).astype(config.floatX)
116 | all_zs = [np.copy(z)]
117 | additional['mu'] = []
118 | additional['logcov'] = []
119 | for t in range(T-1):
120 | mu,logcov = dkf.transition_fxn(z)
121 | z = dkf.sampleGaussian(mu,logcov).astype(config.floatX)
122 | all_zs.append(np.copy(z))
123 | additional['mu'].append(np.copy(mu))
124 | additional['logcov'].append(np.copy(logcov))
125 | zvec = np.concatenate(all_zs,axis=1)
126 | additional['mu'] = np.concatenate(additional['mu'], axis=1)
127 | additional['logcov'] = np.concatenate(additional['logcov'], axis=1)
128 | return dkf.emission_fxn(zvec), zvec
129 |
130 |
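131 | #Illustrative usage with a trained DKF `dkf` and hypothetical arrays
132 | #(dataset: N x T x dim, mask: N x T):
133 | #  z, mu, logcov = infer(dkf, dataset)
134 | #  bound         = evaluateBound(dkf, dataset, mask, batch_size=20)
135 | #  nll           = impSamplingNLL(dkf, dataset, mask, batch_size=20, S=25)
136 | #  samples, zvec = sample(dkf, nsamples=100, T=10)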
--------------------------------------------------------------------------------
/stinfmodel/learning.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for learning with a DKF object
3 | """
4 | import evaluate as DKF_evaluate
5 | import numpy as np
6 | from utils.misc import saveHDF5
7 | import time
8 | from theano import config
9 |
10 | def learn(dkf, dataset, mask, epoch_start=0, epoch_end=1000,
11 | batch_size=200, shuffle=False,
12 | savefreq=None, savefile = None,
13 | dataset_eval = None, mask_eval = None,
14 | replicate_K = None,
15 | normalization = 'frame'):
16 | """
17 | Train DKF
18 | """
19 | assert not dkf.params['validate_only'],'cannot learn in validate only mode'
20 | assert len(dataset.shape)==3,'Expecting 3D tensor for data'
21 | assert dataset.shape[2]==dkf.params['dim_observations'],'Dim observations not valid'
22 | N = dataset.shape[0]
23 | idxlist = range(N)
24 | batchlist = np.split(idxlist, range(batch_size,N,batch_size))
25 |
26 | bound_train_list,bound_valid_list,bound_tsbn_list,nll_valid_list = [],[],[],[]
27 | p_norm, g_norm, opt_norm = None, None, None
28 |
29 | #Lists used to track quantities for synthetic experiments
30 | mu_list_train, cov_list_train, mu_list_valid, cov_list_valid = [],[],[],[]
31 |
32 | #Start of training loop
33 | for epoch in range(epoch_start, epoch_end):
34 | #Shuffle
35 | if shuffle:
36 | np.random.shuffle(idxlist)
37 | batchlist = np.split(idxlist, range(batch_size,N,batch_size))
38 | #Always shuffle the order in which batches are presented
39 | np.random.shuffle(batchlist)
40 |
41 | start_time = time.time()
42 | bound = 0
43 | for bnum, batch_idx in enumerate(batchlist):
44 | batch_idx = batchlist[bnum]
45 | X = dataset[batch_idx,:,:].astype(config.floatX)
46 | M = mask[batch_idx,:].astype(config.floatX)
47 | U = None
48 |
49 | #Tack on 0's if the matrix size is not batch_size (optimization -> theano doesn't have to redefine matrices)
50 | if X.shape[0]<batch_size:
--------------------------------------------------------------------------------
/stinfmodel/testall.sh:
--------------------------------------------------------------------------------
1 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm L -infm mean_field -dset jsb > A.result
2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm LR -infm mean_field -dset jsb >> A.result
3 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -dset jsb >> A.result
4 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ttype mlp -dset jsb >>A.result
5 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm LR -infm structured -dset jsb >>A.result
6 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -dset jsb >>A.result
7 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -usenade -dset jsb >>A.result
8 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -etype conditional -usenade -dset jsb >>A.result
9 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset jsb >>A.result
10 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset jsb >> A.result
11 | grep -n "buildModel" A.result
12 | rm -rf A.result
13 |
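14 | #the grep above counts "buildModel" lines in A.result as a quick check that
15 | #every configuration above compiled and built its model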
--------------------------------------------------------------------------------
/stinfmodel_fast/__init__.py:
--------------------------------------------------------------------------------
1 | __all__=['dkf','evaluate','learning']
2 |
--------------------------------------------------------------------------------
/stinfmodel_fast/dkf.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import six.moves.cPickle as pickle
3 | from collections import OrderedDict
4 | import numpy as np
5 | import sys, time, os, gzip, theano,math
6 | sys.path.append('../')
7 | from theano import config
8 | theano.config.compute_test_value = 'warn'
9 | from theano.printing import pydotprint
10 | import theano.tensor as T
11 | from utils.misc import saveHDF5
12 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
13 |
14 | from utils.optimizer import adam,rmsprop
15 | from models.__init__ import BaseModel
16 | from datasets.synthp import params_synthetic
17 | from datasets.synthpTheano import updateParamsSynthetic
18 |
19 | #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%#
20 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
21 | #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%#
22 |
23 | """
24 | DEEP MARKOV MODEL [DEEP KALMAN FILTER]
25 | """
26 | class DKF(BaseModel, object):
27 | def __init__(self, params, paramFile=None, reloadFile=None):
28 | self.scan_updates = []
29 | super(DKF,self).__init__(params, paramFile=paramFile, reloadFile=reloadFile)
30 | if 'synthetic' in self.params['dataset'] and not hasattr(self, 'params_synthetic'):
31 | assert False, 'Expecting to have params_synthetic as an attribute in DKF class'
32 | assert self.params['nonlinearity']!='maxout','Maxout nonlinearity not supported'
33 |
34 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
35 | def _createParams(self):
36 | """ Model parameters """
37 | npWeights = OrderedDict()
38 | self._createInferenceParams(npWeights)
39 | self._createGenerativeParams(npWeights)
40 | return npWeights
41 |
42 | def _createGenerativeParams(self, npWeights):
43 | """ Create weights/params for generative model """
44 | if 'synthetic' in self.params['dataset']:
45 | updateParamsSynthetic(params_synthetic)
46 | self.params_synthetic = params_synthetic
47 | for k in self.params_synthetic[self.params['dataset']]['params']:
48 | npWeights[k+'_W'] = np.array(np.random.uniform(-0.2,0.2),dtype=config.floatX)
49 | return
50 | DIM_HIDDEN = self.params['dim_hidden']
51 | DIM_STOCHASTIC = self.params['dim_stochastic']
52 | if self.params['transition_type']=='mlp':
53 | DIM_HIDDEN_TRANS = DIM_HIDDEN*2
54 | for l in range(self.params['transition_layers']):
55 | dim_input,dim_output = DIM_HIDDEN_TRANS, DIM_HIDDEN_TRANS
56 | if l==0:
57 | dim_input = self.params['dim_stochastic']
58 | npWeights['p_trans_W_'+str(l)] = self._getWeight((dim_input, dim_output))
59 | npWeights['p_trans_b_'+str(l)] = self._getWeight((dim_output,))
60 | if self.params['use_prev_input']:
61 | npWeights['p_trans_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'],
62 | DIM_HIDDEN_TRANS))
63 | npWeights['p_trans_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
64 | MU_COV_INP = DIM_HIDDEN_TRANS
65 | elif self.params['transition_type']=='simple_gated':
66 | DIM_HIDDEN_TRANS = DIM_HIDDEN*2
67 | npWeights['p_gate_embed_W_0'] = self._getWeight((DIM_STOCHASTIC, DIM_HIDDEN_TRANS))
68 | npWeights['p_gate_embed_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
69 | npWeights['p_gate_embed_W_1'] = self._getWeight((DIM_HIDDEN_TRANS, DIM_STOCHASTIC))
70 | npWeights['p_gate_embed_b_1'] = self._getWeight((DIM_STOCHASTIC,))
71 | npWeights['p_z_W_0'] = self._getWeight((DIM_STOCHASTIC, DIM_HIDDEN_TRANS))
72 | npWeights['p_z_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
73 | npWeights['p_z_W_1'] = self._getWeight((DIM_HIDDEN_TRANS, DIM_STOCHASTIC))
74 | npWeights['p_z_b_1'] = self._getWeight((DIM_STOCHASTIC,))
75 | if self.params['use_prev_input']:
76 | npWeights['p_z_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'], DIM_HIDDEN_TRANS))
77 | npWeights['p_z_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
78 | npWeights['p_gate_embed_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'],
79 | DIM_HIDDEN_TRANS))
80 | npWeights['p_gate_embed_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,))
81 | MU_COV_INP = DIM_STOCHASTIC
82 | else:
83 | assert False,'Invalid transition type: '+self.params['transition_type']
84 |
85 | if self.params['transition_type']=='simple_gated':
86 | weight= np.eye(self.params['dim_stochastic']).astype(config.floatX)
87 | bias = np.zeros((self.params['dim_stochastic'],)).astype(config.floatX)
88 | #Initialize the weights to be identity
89 | npWeights['p_trans_W_mu'] = weight
90 | npWeights['p_trans_b_mu'] = bias
91 | else:
92 | npWeights['p_trans_W_mu'] = self._getWeight((MU_COV_INP, self.params['dim_stochastic']))
93 | npWeights['p_trans_b_mu'] = self._getWeight((self.params['dim_stochastic'],))
94 | npWeights['p_trans_W_cov'] = self._getWeight((MU_COV_INP, self.params['dim_stochastic']))
95 | npWeights['p_trans_b_cov'] = self._getWeight((self.params['dim_stochastic'],))
96 |
97 | #Emission Function [MLP]
98 | if self.params['emission_type'] == 'mlp':
99 | for l in range(self.params['emission_layers']):
100 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN
101 | if l==0:
102 | dim_input = self.params['dim_stochastic']
103 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output))
104 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,))
105 | elif self.params['emission_type'] == 'res':
106 | for l in range(self.params['emission_layers']):
107 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN
108 | if l==0:
109 | dim_input = self.params['dim_stochastic']
110 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output))
111 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,))
112 | dim_res_out = self.params['dim_observations']
113 | if self.params['data_type']=='binary_nade':
114 | dim_res_out = DIM_HIDDEN
115 | npWeights['p_res_W'] = self._getWeight((self.params['dim_stochastic'], dim_res_out))
116 | elif self.params['emission_type'] =='conditional':
117 | for l in range(self.params['emission_layers']):
118 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN
119 | if l==0:
120 | dim_input = self.params['dim_stochastic']+self.params['dim_observations']
121 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output))
122 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,))
123 | else:
124 | assert False, 'Invalid emission type: '+str(self.params['emission_type'])
125 | if self.params['data_type']=='binary':
126 | npWeights['p_emis_W_ber'] = self._getWeight((self.params['dim_hidden'], self.params['dim_observations']))
127 | npWeights['p_emis_b_ber'] = self._getWeight((self.params['dim_observations'],))
128 | elif self.params['data_type']=='binary_nade':
129 | n_visible, n_hidden = self.params['dim_observations'], self.params['dim_hidden']
130 | npWeights['p_nade_W'] = self._getWeight((n_visible, n_hidden))
131 | npWeights['p_nade_U'] = self._getWeight((n_visible,n_hidden))
132 | npWeights['p_nade_b'] = self._getWeight((n_visible,))
133 | else:
134 | assert False,'Invalid datatype: '+self.params['data_type']
135 |
136 | def _createInferenceParams(self, npWeights):
137 | """ Create weights/params for inference network """
138 |
139 | #Initial embedding for the inputs
140 | DIM_INPUT = self.params['dim_observations']
141 | RNN_SIZE = self.params['rnn_size']
142 |
143 | DIM_HIDDEN = RNN_SIZE
144 | DIM_STOC = self.params['dim_stochastic']
145 |
146 | #Embed the Input -> RNN_SIZE
147 | dim_input, dim_output= DIM_INPUT, RNN_SIZE
148 | npWeights['q_W_input_0'] = self._getWeight((dim_input, dim_output))
149 | npWeights['q_b_input_0'] = self._getWeight((dim_output,))
150 |
151 | #Setup weights for LSTM
152 | self._createLSTMWeights(npWeights)
153 |
154 | #Embedding before MF/ST inference model
155 | if self.params['inference_model']=='mean_field':
156 | pass
157 | elif self.params['inference_model']=='structured':
158 | DIM_INPUT = self.params['dim_stochastic']
159 | if self.params['use_generative_prior']:
160 | DIM_INPUT = self.params['rnn_size']
161 | npWeights['q_W_st_0'] = self._getWeight((DIM_INPUT, self.params['rnn_size']))
162 | npWeights['q_b_st_0'] = self._getWeight((self.params['rnn_size'],))
163 | else:
164 | assert False,'Invalid inference model: '+self.params['inference_model']
165 | RNN_SIZE = self.params['rnn_size']
166 | npWeights['q_W_mu'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic']))
167 | npWeights['q_b_mu'] = self._getWeight((self.params['dim_stochastic'],))
168 | npWeights['q_W_cov'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic']))
169 | npWeights['q_b_cov'] = self._getWeight((self.params['dim_stochastic'],))
170 | if self.params['var_model']=='LR' and self.params['inference_model']=='mean_field':
171 | npWeights['q_W_mu_r'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic']))
172 | npWeights['q_b_mu_r'] = self._getWeight((self.params['dim_stochastic'],))
173 | npWeights['q_W_cov_r'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic']))
174 | npWeights['q_b_cov_r'] = self._getWeight((self.params['dim_stochastic'],))
175 |
176 | def _createLSTMWeights(self, npWeights):
177 | #LSTM L/LR/R w/ orthogonal weight initialization
178 | suffices_to_build = []
179 | if self.params['var_model']=='LR' or self.params['var_model']=='L':
180 | suffices_to_build.append('l')
181 | if self.params['var_model']=='LR' or self.params['var_model']=='R':
182 | suffices_to_build.append('r')
183 | RNN_SIZE = self.params['rnn_size']
184 | for suffix in suffices_to_build:
185 | for l in range(self.params['rnn_layers']):
186 | npWeights['W_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE,RNN_SIZE*4))
187 | npWeights['b_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE*4,), scheme='lstm')
188 | npWeights['U_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE,RNN_SIZE*4),scheme='lstm')
189 |
190 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
191 | def _getEmissionFxn(self, z, X=None):
192 | """
193 | Apply emission function to zs
194 | Input: z [bs x T x dim]
195 | Output: (params, ) or (mu, cov) of size [bs x T x dim]
196 | """
197 | if 'synthetic' in self.params['dataset']:
198 | self._p('Using emission function for '+self.params['dataset'])
199 | tParams = {}
200 | for k in self.params_synthetic[self.params['dataset']]['params']:
201 | tParams[k] = self.tWeights[k+'_W']
202 | mu = self.params_synthetic[self.params['dataset']]['obs_fxn'](z, fxn_params = tParams)
203 | cov = T.ones_like(mu)*self.params_synthetic[self.params['dataset']]['obs_cov']
204 | cov.name = 'EmissionCov'
205 | return [mu,cov]
206 |
207 | if self.params['emission_type'] in ['mlp','res']:
208 | self._p('EMISSION TYPE: MLP or RES')
209 | hid = z
210 | elif self.params['emission_type']=='conditional':
211 | self._p('EMISSION TYPE: conditional')
212 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1)
213 | hid = T.concatenate([z,X_prev],axis=2)
214 | else:
215 | assert False,'Invalid emission type'
216 | for l in range(self.params['emission_layers']):
217 | if self.params['data_type']=='binary_nade' and l==self.params['emission_layers']-1:
218 | hid = T.dot(hid, self.tWeights['p_emis_W_'+str(l)]) + self.tWeights['p_emis_b_'+str(l)]
219 | else:
220 | hid = self._LinearNL(self.tWeights['p_emis_W_'+str(l)], self.tWeights['p_emis_b_'+str(l)], hid)
221 |
222 | if self.params['data_type']=='binary':
223 | if self.params['emission_type']=='res':
224 | hid = T.dot(z,self.tWeights['p_res_W'])+T.dot(hid,self.tWeights['p_emis_W_ber'])+self.tWeights['p_emis_b_ber']
225 | mean_params=T.nnet.sigmoid(hid)
226 | else:
227 | mean_params=T.nnet.sigmoid(T.dot(hid,self.tWeights['p_emis_W_ber'])+self.tWeights['p_emis_b_ber'])
228 | return [mean_params]
229 | elif self.params['data_type']=='binary_nade':
230 | self._p('NADE observations')
231 | assert X is not None,'Need observations for NADE'
232 | if self.params['emission_type']=='res':
233 | hid += T.dot(z,self.tWeights['p_res_W'])
234 | x_reshaped = X.dimshuffle(2,0,1)
235 | x0 = T.ones((hid.shape[0],hid.shape[1])) #instead of x_reshaped[0]; bs x T
236 | a0 = hid #bs x T x nhid
237 | W = self.tWeights['p_nade_W']
238 | V = self.tWeights['p_nade_U']
239 | b = self.tWeights['p_nade_b']
240 | #Use a NADE at the output
241 | def NADEDensity(x, w, v, b, a_prev, x_prev):#Estimating likelihood
242 | a = a_prev + T.dot(T.shape_padright(x_prev, 1), T.shape_padleft(w, 1))
243 | h = T.nnet.sigmoid(a) #Original - bs x T x nhid
244 | p_xi_is_one = T.nnet.sigmoid(T.dot(h, v) + b)
245 | return (a, x, p_xi_is_one)
246 | ([_, _, mean_params], _) = theano.scan(NADEDensity,
247 | sequences=[x_reshaped, W, V, b],
248 | outputs_info=[a0, x0,None])
249 | #theano function to sample from NADE
250 | def NADESample(w, v, b, a_prev_s, x_prev_s):
251 | a_s = a_prev_s + T.dot(T.shape_padright(x_prev_s, 1), T.shape_padleft(w, 1))
252 | h_s = T.nnet.sigmoid(a_s) #Original - bs x T x nhid
253 | p_xi_is_one_s = T.nnet.sigmoid(T.dot(h_s, v) + b)
254 | x_s = T.switch(p_xi_is_one_s>0.5,1.,0.)
255 | return (a_s, x_s, p_xi_is_one_s)
256 | ([_, _, sampled_params], _) = theano.scan(NADESample,
257 | sequences=[W, V, b],
258 | outputs_info=[a0, x0,None])
259 | """
260 | def NADEDensityAndSample(x, w, v, b,
261 | a_prev, x_prev,
262 | a_prev_s, x_prev_s ):
263 | a = a_prev + T.dot(T.shape_padright(x_prev, 1), T.shape_padleft(w, 1))
264 | h = T.nnet.sigmoid(a) #bs x T x nhid
265 | p_xi_is_one = T.nnet.sigmoid(T.dot(h, v) + b)
266 |
267 | a_s = a_prev_s + T.dot(T.shape_padright(x_prev_s, 1), T.shape_padleft(w, 1))
268 | h_s = T.nnet.sigmoid(a_s) #bs x T x nhid
269 | p_xi_is_one_s = T.nnet.sigmoid(T.dot(h_s, v) + b)
270 | x_s = T.switch(p_xi_is_one_s>0.5,1.,0.)
271 | return (a, x, a_s, x_s, p_xi_is_one, p_xi_is_one_s)
272 |
273 | ([_, _, _, _, mean_params,sampled_params], _) = theano.scan(NADEDensityAndSample,
274 | sequences=[x_reshaped, W, V, b],
275 | outputs_info=[a0, x0, a0, x0, None, None])
276 | """
277 | sampled_params = sampled_params.dimshuffle(1,2,0)
278 | mean_params = mean_params.dimshuffle(1,2,0)
279 | return [mean_params,sampled_params]
280 | else:
281 | assert False,'Invalid type of data'
282 |
283 | def _getTransitionFxn(self, z, X=None):
284 | """
285 | Apply transition function to zs
286 | Input: z [bs x T x dim], u [bs x T x dim]
287 | Output: mu, cov of size [bs x T x dim]
288 | """
289 | if 'synthetic' in self.params['dataset']:
290 | self._p('Using transition function for '+self.params['dataset'])
291 | tParams = {}
292 | for k in self.params_synthetic[self.params['dataset']]['params']:
293 | tParams[k] = self.tWeights[k+'_W']
294 | mu = self.params_synthetic[self.params['dataset']]['trans_fxn'](z, fxn_params = tParams)
295 | cov = T.ones_like(mu)*self.params_synthetic[self.params['dataset']]['trans_cov']
296 | cov.name = 'TransitionCov'
297 | return mu,cov
298 |
299 | if self.params['transition_type']=='simple_gated':
300 | def mlp(inp, W1,b1,W2,b2, X_prev=None):
301 | if X_prev is not None:
302 | h1 = self._LinearNL(W1,b1, T.concatenate([inp,X_prev],axis=2))
303 | else:
304 | h1 = self._LinearNL(W1,b1, inp)
305 | h2 = T.dot(h1,W2)+b2
306 | return h2
307 |
308 | gateInp= z
309 | X_prev = None
310 | if self.params['use_prev_input']:
311 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1)
312 | gate = T.nnet.sigmoid(mlp(gateInp, self.tWeights['p_gate_embed_W_0'], self.tWeights['p_gate_embed_b_0'],
313 | self.tWeights['p_gate_embed_W_1'],self.tWeights['p_gate_embed_b_1'],
314 | X_prev = X_prev))
315 | z_prop = mlp(z,self.tWeights['p_z_W_0'] ,self.tWeights['p_z_b_0'],
316 | self.tWeights['p_z_W_1'] , self.tWeights['p_z_b_1'], X_prev = X_prev)
317 | mu = gate*z_prop + (1.-gate)*(T.dot(z, self.tWeights['p_trans_W_mu'])+self.tWeights['p_trans_b_mu'])
318 | cov = T.nnet.softplus(T.dot(self._applyNL(z_prop), self.tWeights['p_trans_W_cov'])+
319 | self.tWeights['p_trans_b_cov'])
320 | return mu,cov
321 | elif self.params['transition_type']=='mlp':
322 | hid = z
323 | if self.params['use_prev_input']:
324 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1)
325 | hid = T.concatenate([hid,X_prev],axis=2)
326 | for l in range(self.params['transition_layers']):
327 | hid = self._LinearNL(self.tWeights['p_trans_W_'+str(l)],self.tWeights['p_trans_b_'+str(l)],hid)
328 | mu = T.dot(hid, self.tWeights['p_trans_W_mu']) + self.tWeights['p_trans_b_mu']
329 | cov = T.nnet.softplus(T.dot(hid, self.tWeights['p_trans_W_cov'])+self.tWeights['p_trans_b_cov'])
330 | return mu,cov
331 | else:
332 | assert False,'Invalid Transition type: '+str(self.params['transition_type'])
333 |
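# [Sketch] A minimal NumPy version of the gated transition above for one step,
# under assumed 2D shapes; the weight names are hypothetical stand-ins and the
# hidden nonlinearity is assumed to be ReLU (the code uses self._applyNL).
import numpy as np
def gated_transition(z, Wg1, bg1, Wg2, bg2, Wp1, bp1, Wp2, bp2, Wmu, bmu, Wcov, bcov):
    relu     = lambda u: np.maximum(u, 0.)
    sigmoid  = lambda u: 1.0/(1.0 + np.exp(-u))
    softplus = lambda u: np.log1p(np.exp(u))
    gate   = sigmoid(np.dot(relu(np.dot(z, Wg1) + bg1), Wg2) + bg2)  # how much of the MLP proposal to use
    z_prop = np.dot(relu(np.dot(z, Wp1) + bp1), Wp2) + bp2           # nonlinear proposal
    mu     = gate*z_prop + (1. - gate)*(np.dot(z, Wmu) + bmu)        # blend with linear dynamics
    cov    = softplus(np.dot(relu(z_prop), Wcov) + bcov)             # strictly positive covariance
    return mu, cov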
334 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
335 | def _buildLSTM(self, X, embedding, dropout_prob = 0.):
336 |         """
337 |         Run the recognition RNN(s) over the embedding; the variational model
338 |         ('L', 'R' or 'LR') determines the direction(s) of the scan
339 |         Input:  embedding [bs x T x dim]
340 |         Output: hidden_state [T x bs x dim] (a list of two such tensors for 'LR')
341 |         """
342 |         start_time = time.time()
343 |         self._p('In <_buildLSTM>')
344 |         if self.params['var_model']=='R':
345 |             rval = self._LSTMlayer(embedding, 'r', dropout_prob)
346 |         elif self.params['var_model']=='L':
347 |             rval = self._LSTMlayer(embedding, 'l', dropout_prob)
348 |         elif self.params['var_model']=='LR':
349 |             l2r = self._LSTMlayer(embedding, 'l', dropout_prob)
350 |             r2l = self._LSTMlayer(embedding, 'r', dropout_prob)
351 |             rval = [l2r,r2l]
352 |         else:
353 |             assert False,'Invalid variational model'
354 |         self._p(('Done <_buildLSTM> [Took %.4f]')%(time.time()-start_time))
355 |         return rval
360 |
361 |
362 | def _inferenceLayer(self, hidden_state):
363 |         """
364 |         Take the hidden state of the recognition RNN and return samples from,
365 |         and parameters of, the approximate posterior
366 |         Input:  hidden_state [T x bs x dim] (list of two for 'LR'); eps is drawn internally
367 |         Output: z [bs x T x dim], mu [bs x T x dim], cov [bs x T x dim]
368 |         """
369 | def structuredApproximation(h_t, eps_t, z_prev,
370 | q_W_st_0, q_b_st_0,
371 | q_W_mu, q_b_mu,
372 | q_W_cov,q_b_cov):
373 | #Using the prior distribution directly
374 | if self.params['use_generative_prior']:
375 | assert not self.params['use_prev_input'],'No support for using previous input'
376 | #Get mu/cov from z_prev through prior distribution
377 | mu_1,cov_1 = self._getTransitionFxn(z_prev)
378 | #Combine with estimate of mu/cov from data
379 | h_data = T.tanh(T.dot(h_t,q_W_st_0)+q_b_st_0)
380 | mu_2 = T.dot(h_data,q_W_mu)+q_b_mu
381 | cov_2 = T.nnet.softplus(T.dot(h_data,q_W_cov)+q_b_cov)
382 | mu = (mu_1*cov_2+mu_2*cov_1)/(cov_1+cov_2)
383 | cov = (cov_1*cov_2)/(cov_1+cov_2)
384 | z = mu + T.sqrt(cov)*eps_t
385 | return z, mu, cov
386 | else:
387 | h_next = T.tanh(T.dot(z_prev,q_W_st_0)+q_b_st_0)
388 | if self.params['var_model']=='LR':
389 | h_next = (1./3.)*(h_t+h_next)
390 | else:
391 | h_next = (1./2.)*(h_t+h_next)
392 | mu_t = T.dot(h_next,q_W_mu)+q_b_mu
393 | cov_t = T.nnet.softplus(T.dot(h_next,q_W_cov)+q_b_cov)
394 | z_t = mu_t+T.sqrt(cov_t)*eps_t
395 | return z_t, mu_t, cov_t
396 |         if isinstance(hidden_state, list):
397 | eps = self.srng.normal(size=(hidden_state[0].shape[1],hidden_state[0].shape[0],self.params['dim_stochastic']))
398 | else:
399 | eps = self.srng.normal(size=(hidden_state.shape[1],hidden_state.shape[0],self.params['dim_stochastic']))
400 | if self.params['inference_model']=='structured':
401 | #Structured recognition networks
402 | if self.params['var_model']=='LR':
403 | state = hidden_state[0]+hidden_state[1]
404 | else:
405 | state = hidden_state
406 | eps_swap = eps.swapaxes(0,1)
407 | if self.params['dim_stochastic']==1:
408 |                 """
409 |                 TODO: Report this issue to the theano authors.
410 |                 Workaround for a theano issue: the result of a matrix multiply is a
411 |                 "matrix" even if one of its dimensions is 1, but a tensor declared with
412 |                 a dimension of size 1 is treated inside the scan as a column. The
413 |                 resulting mismatch between input (column) and output (matrix) tensor
414 |                 types throws an error. This workaround preserves the type while leaving
415 |                 the dimensions unaffected.
416 |                 """
417 | z0 = T.zeros((eps_swap.shape[1], self.params['rnn_size']))
418 | z0 = T.dot(z0,T.zeros_like(self.tWeights['q_W_mu']))
419 | else:
420 | z0 = T.zeros((eps_swap.shape[1], self.params['dim_stochastic']))
421 | rval, _ = theano.scan(structuredApproximation,
422 | sequences=[state, eps_swap],
423 | outputs_info=[z0, None,None],
424 | non_sequences=[self.tWeights[k] for k in
425 | ['q_W_st_0', 'q_b_st_0']]+
426 | [self.tWeights[k] for k in
427 | ['q_W_mu','q_b_mu','q_W_cov','q_b_cov']],
428 | name='structuredApproximation')
429 | z, mu, cov = rval[0].swapaxes(0,1), rval[1].swapaxes(0,1), rval[2].swapaxes(0,1)
430 | return z, mu, cov
431 | elif self.params['inference_model']=='mean_field':
432 | if self.params['var_model']=='LR':
433 | l2r = hidden_state[0].swapaxes(0,1)
434 | r2l = hidden_state[1].swapaxes(0,1)
435 | hidl2r = l2r
436 | mu_1 = T.dot(hidl2r,self.tWeights['q_W_mu'])+self.tWeights['q_b_mu']
437 | cov_1 = T.nnet.softplus(T.dot(hidl2r, self.tWeights['q_W_cov'])+self.tWeights['q_b_cov'])
438 | hidr2l = r2l
439 | mu_2 = T.dot(hidr2l,self.tWeights['q_W_mu_r'])+self.tWeights['q_b_mu_r']
440 | cov_2 = T.nnet.softplus(T.dot(hidr2l, self.tWeights['q_W_cov_r'])+self.tWeights['q_b_cov_r'])
441 | mu = (mu_1*cov_2+mu_2*cov_1)/(cov_1+cov_2)
442 | cov= (cov_1*cov_2)/(cov_1+cov_2)
443 | z = mu + T.sqrt(cov)*eps
444 | else:
445 | hid = hidden_state.swapaxes(0,1)
446 | mu = T.dot(hid,self.tWeights['q_W_mu'])+self.tWeights['q_b_mu']
447 | cov = T.nnet.softplus(T.dot(hid,self.tWeights['q_W_cov'])+ self.tWeights['q_b_cov'])
448 | z = mu + T.sqrt(cov)*eps
449 | return z,mu,cov
450 | else:
451 | assert False,'Invalid recognition model'
452 |
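# [Sketch] Both the LR mean-field branch and the generative-prior branch above
# merge two Gaussian estimates by precision weighting, i.e. the renormalized
# product of the two densities. A NumPy sketch of that combination:
import numpy as np
def combine_gaussians(mu_1, cov_1, mu_2, cov_2):
    mu  = (mu_1*cov_2 + mu_2*cov_1)/(cov_1 + cov_2)  # inverse-variance weighted mean
    cov = (cov_1*cov_2)/(cov_1 + cov_2)              # combined variance (harmonic form)
    return mu, cov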
453 | def _qEmbeddingLayer(self, X):
454 | """ Embed for q """
455 | return self._LinearNL(self.tWeights['q_W_input_0'],self.tWeights['q_b_input_0'], X)
456 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
457 |
458 | def _inferenceAndReconstruction(self, X, dropout_prob = 0.):
459 | """
460 | Returns z_q, mu_q and cov_q
461 | """
462 | self._p('Building with dropout:'+str(dropout_prob))
463 | embedding = self._qEmbeddingLayer(X)
464 | hidden_state = self._buildLSTM(X, embedding, dropout_prob)
465 | z_q,mu_q,cov_q = self._inferenceLayer(hidden_state)
466 |
467 | #Regularize z_q (for train)
468 | #if dropout_prob>0.:
469 | # z_q = z_q + self.srng.normal(z_q.shape, 0.,0.0025,dtype=config.floatX)
470 | #z_q.name = 'z_q'
471 |
472 | observation_params = self._getEmissionFxn(z_q,X=X)
473 | mu_trans, cov_trans= self._getTransitionFxn(z_q, X=X)
474 | mu_prior = T.concatenate([T.alloc(np.asarray(0.).astype(config.floatX),
475 | X.shape[0],1,self.params['dim_stochastic']), mu_trans[:,:-1,:]],axis=1)
476 | cov_prior = T.concatenate([T.alloc(np.asarray(1.).astype(config.floatX),
477 | X.shape[0],1,self.params['dim_stochastic']), cov_trans[:,:-1,:]],axis=1)
478 | return observation_params, z_q, mu_q, cov_q, mu_prior, cov_prior, mu_trans, cov_trans
479 |
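# [Sketch] The prior parameters at time t come from applying the transition to
# z_{t-1}: shift mu_trans/cov_trans right by one step and pad with the N(0, I)
# prior on z_1. A NumPy equivalent of the concatenation above:
import numpy as np
def shift_for_prior(mu_trans, cov_trans):
    # mu_trans, cov_trans: (bs, T, dim)
    bs, T, dim = mu_trans.shape
    mu_prior  = np.concatenate([np.zeros((bs,1,dim)), mu_trans[:,:-1,:]], axis=1)
    cov_prior = np.concatenate([np.ones((bs,1,dim)),  cov_trans[:,:-1,:]], axis=1)
    return mu_prior, cov_prior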
480 |
481 | def _getTemporalKL(self, mu_q, cov_q, mu_prior, cov_prior, M, batchVector = False):
482 | """
483 | TemporalKL divergence KL (q||p)
484 | KL(q_t||p_t) = 0.5*(log|sigmasq_p| -log|sigmasq_q| -D + Tr(sigmasq_p^-1 sigmasq_q)
485 | + (mu_p-mu_q)^T sigmasq_p^-1 (mu_p-mu_q))
486 | M is a mask of size bs x T that should be applied once the KL divergence for each point
487 | across time has been estimated
488 | """
489 | assert np.all(cov_q.tag.test_value>0.),'should be positive'
490 | assert np.all(cov_prior.tag.test_value>0.),'should be positive'
491 | diff_mu = mu_prior-mu_q
492 | KL_t = T.log(cov_prior)-T.log(cov_q) - 1. + cov_q/cov_prior + diff_mu**2/cov_prior
493 | KLvec = (0.5*KL_t.sum(2)*M).sum(1,keepdims=True)
494 | if batchVector:
495 | return KLvec
496 | else:
497 | return KLvec.sum()
498 |
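# [Sketch] The expression above is the closed-form KL between two diagonal
# Gaussians, summed over dimensions and masked over time. Per dimension:
# KL(q||p) = 0.5*(log cov_p - log cov_q - 1 + cov_q/cov_p + (mu_p - mu_q)**2/cov_p)
import numpy as np
def kl_gaussian(mu_q, cov_q, mu_p, cov_p):
    return 0.5*(np.log(cov_p) - np.log(cov_q) - 1. + cov_q/cov_p + (mu_p - mu_q)**2/cov_p)
print(kl_gaussian(0., 1., 0., 1.))  # 0.0 for identical distributions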
499 | def _getNegCLL(self, obs_params, X, M, batchVector = False):
500 | """
501 | Estimate the negative conditional log likelihood of x|z under the generative model
502 | M: mask of size bs x T
503 | X: target of size bs x T x dim
504 | """
505 | if self.params['data_type']=='real':
506 | mu_p = obs_params[0]
507 | cov_p = obs_params[1]
508 | std_p = T.sqrt(cov_p)
509 | negCLL_t = 0.5 * np.log(2 * np.pi) + 0.5*T.log(cov_p) + 0.5 * ((X - mu_p) / std_p)**2
510 | negCLL = (negCLL_t.sum(2)*M).sum(1,keepdims=True)
511 | else:
512 | mean_p = obs_params[0]
513 | negCLL = (T.nnet.binary_crossentropy(mean_p,X).sum(2)*M).sum(1,keepdims=True)
514 | if batchVector:
515 | return negCLL
516 | else:
517 | return negCLL.sum()
518 |
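# [Sketch] Per-element terms of the negative conditional log likelihood in
# NumPy, matching the two data types handled above (helper names are
# hypothetical, chosen for illustration):
import numpy as np
def gaussian_nll(x, mu, cov):
    return 0.5*np.log(2*np.pi) + 0.5*np.log(cov) + 0.5*((x - mu)**2)/cov
def bernoulli_nll(x, p):
    return -(x*np.log(p) + (1. - x)*np.log(1. - p))  # binary cross-entropy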
519 | def resetDataset(self, newX,newM,quiet=False):
520 | if not quiet:
521 | ddim,mdim = self.dimData()
522 | self._p('Original dim:'+str(ddim)+', '+str(mdim))
523 | self.setData(newX=newX.astype(config.floatX),newMask=newM.astype(config.floatX))
524 | if not quiet:
525 | ddim,mdim = self.dimData()
526 | self._p('New dim:'+str(ddim)+', '+str(mdim))
527 |     def _buildModel(self):
528 |         """ High level function to build and set up theano functions """
529 |         if 'synthetic' in self.params['dataset']:
530 |             self.params_synthetic = params_synthetic
531 | #X = T.tensor3('X', dtype=config.floatX)
532 | #eps = T.tensor3('eps', dtype=config.floatX)
533 | #M = T.matrix('M', dtype=config.floatX)
534 | idx = T.vector('idx',dtype='int64')
535 | idx.tag.test_value = np.array([0,1]).astype('int64')
536 | self.dataset = theano.shared(np.random.uniform(0,1,size=(3,5,self.params['dim_observations'])).astype(config.floatX))
537 | self.mask = theano.shared(np.array(([[1,1,1,1,0],[1,1,0,0,0],[1,1,1,0,0]])).astype(config.floatX))
538 | X_o = self.dataset[idx]
539 | M_o = self.mask[idx]
540 | maxidx = T.cast(M_o.sum(1).max(),'int64')
541 | X = X_o[:,:maxidx,:]
542 | M = M_o[:,:maxidx]
543 | newX,newMask = T.tensor3('newX',dtype=config.floatX),T.matrix('newMask',dtype=config.floatX)
544 | self.setData = theano.function([newX,newMask],None,updates=[(self.dataset,newX),(self.mask,newMask)])
545 | self.dimData = theano.function([],[self.dataset.shape,self.mask.shape])
546 |
547 | #Learning Rates and annealing objective function
548 | #Add them to npWeights/tWeights to be tracked [do not have a prefix _W or _b so wont be diff.]
549 | self._addWeights('lr', np.asarray(self.params['lr'],dtype=config.floatX),borrow=False)
550 | self._addWeights('anneal', np.asarray(0.01,dtype=config.floatX),borrow=False)
551 | self._addWeights('update_ctr', np.asarray(1.,dtype=config.floatX),borrow=False)
552 | lr = self.tWeights['lr']
553 | anneal = self.tWeights['anneal']
554 | iteration_t = self.tWeights['update_ctr']
555 |
556 | anneal_div = 1000.
557 | if 'anneal_rate' in self.params:
558 | self._p('Anneal = 1 in '+str(self.params['anneal_rate'])+' param. updates')
559 | anneal_div = self.params['anneal_rate']
560 | if 'synthetic' in self.params['dataset']:
561 | anneal_div = 100.
562 | anneal_update = [(iteration_t, iteration_t+1),
563 | (anneal,T.switch(0.01+iteration_t/anneal_div>1,1,0.01+iteration_t/anneal_div))]
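        # i.e. anneal_t = min(1., 0.01 + t/anneal_div): the KL weight ramps
        # linearly from 0.01 up to 1 over roughly anneal_div parameter updates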
564 | fxn_inputs = [idx]
565 | if not self.params['validate_only']:
566 |             print '****** CREATING TRAINING FUNCTION ******'
567 | ############# Setup training functions ###########
568 | obs_params, z_q, mu_q, cov_q, mu_prior, cov_prior, _, _ = self._inferenceAndReconstruction(
569 | X, dropout_prob = self.params['rnn_dropout'])
570 | negCLL = self._getNegCLL(obs_params, X, M)
571 | TemporalKL = self._getTemporalKL(mu_q, cov_q, mu_prior, cov_prior, M)
572 | train_cost = negCLL+anneal*TemporalKL
573 |
574 | #Get updates from optimizer
575 | model_params = self._getModelParams()
576 | optimizer_up, norm_list = self._setupOptimizer(train_cost, model_params,lr = lr,
577 | #Turning off for synthetic
578 | #reg_type =self.params['reg_type'],
579 | #reg_spec =self.params['reg_spec'],
580 | #reg_value= self.params['reg_value'],
581 | divide_grad = T.cast(X.shape[0],dtype=config.floatX),
582 | grad_norm = 1.)
583 |
584 | #Add annealing updates
585 | optimizer_up +=anneal_update+self.updates
586 | self._p(str(len(self.updates))+' other updates')
587 | ############# Setup train & evaluate functions ###########
588 | self.train_debug = theano.function(fxn_inputs,[train_cost,norm_list[0],norm_list[1],
589 | norm_list[2],negCLL, TemporalKL, anneal.sum()],
590 | updates = optimizer_up, name='Train (with Debug)')
591 |             #Acknowledge that self.updates have been folded into the training function
592 | self.updates_ack = True
593 | eval_obs_params, eval_z_q, eval_mu_q, eval_cov_q, eval_mu_prior, eval_cov_prior, \
594 | eval_mu_trans, eval_cov_trans = self._inferenceAndReconstruction(X,dropout_prob = 0.)
595 | eval_z_q.name = 'eval_z_q'
596 | eval_CNLLvec=self._getNegCLL(eval_obs_params, X, M, batchVector = True)
597 | eval_KLvec = self._getTemporalKL(eval_mu_q, eval_cov_q,eval_mu_prior, eval_cov_prior, M, batchVector = True)
598 | eval_cost = eval_CNLLvec + eval_KLvec
599 |
600 | #From here on, convert to the log covariance since we only use it for evaluation
601 | assert np.all(eval_cov_q.tag.test_value>0.),'should be positive'
602 | assert np.all(eval_cov_prior.tag.test_value>0.),'should be positive'
603 | assert np.all(eval_cov_trans.tag.test_value>0.),'should be positive'
604 | eval_logcov_q = T.log(eval_cov_q)
605 | eval_logcov_prior = T.log(eval_cov_prior)
606 | eval_logcov_trans = T.log(eval_cov_trans)
607 |
608 | ll_prior = self._llGaussian(eval_z_q, eval_mu_prior, eval_logcov_prior).sum(2)*M
609 | ll_posterior = self._llGaussian(eval_z_q, eval_mu_q, eval_logcov_q).sum(2)*M
610 | ll_estimate = -1*eval_CNLLvec+ll_prior.sum(1,keepdims=True)-ll_posterior.sum(1,keepdims=True)
611 |
612 | eval_inputs = [eval_z_q]
613 | self.likelihood = theano.function(fxn_inputs, ll_estimate, name = 'Importance Sampling based likelihood')
614 | self.evaluate = theano.function(fxn_inputs, eval_cost, name = 'Evaluate Bound')
615 | if self.params['use_prev_input']:
616 | eval_inputs.append(X)
617 | self.transition_fxn = theano.function(eval_inputs,[eval_mu_trans, eval_logcov_trans],
618 | name='Transition Function')
619 | emission_inputs = [eval_z_q]
620 | if self.params['emission_type']=='conditional':
621 | emission_inputs.append(X)
622 | if self.params['data_type']=='binary_nade':
623 | self.emission_fxn = theano.function(emission_inputs,
624 | eval_obs_params[1], name='Emission Function')
625 | else:
626 | self.emission_fxn = theano.function(emission_inputs,
627 | eval_obs_params[0], name='Emission Function')
628 | self.posterior_inference = theano.function(fxn_inputs,
629 | [eval_z_q, eval_mu_q, eval_logcov_q],
630 | name='Posterior Inference')
631 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
632 | if __name__=='__main__':
633 | """ use this to check compilation for various options"""
634 | from parse_args_dkf import params
635 | if params['use_nade']:
636 | params['data_type'] = 'binary_nade'
637 | else:
638 | params['data_type'] = 'binary'
639 | params['dim_observations'] = 10
640 | dkf = DKF(params, paramFile = 'tmp')
641 | os.unlink('tmp')
642 | import ipdb;ipdb.set_trace()
643 |
--------------------------------------------------------------------------------
/stinfmodel_fast/evaluate.py:
--------------------------------------------------------------------------------
1 |
2 | from theano import config
3 | import numpy as np
4 | import time
5 | """
6 | Functions for evaluating a DKF object
7 | """
8 | def infer(dkf, dataset, mask):
9 |     """ Posterior inference using the recognition network
10 |     Returns: z, mu, logcov (each a 3D tensor). Remember to multiply each by the
11 |     mask of the dataset before using the latent variables
12 |     """
13 |     assert len(dataset.shape)==3,'Expecting 3D tensor for data'
14 |     assert dataset.shape[2]==dkf.params['dim_observations'],'Data dim. not matching'
15 |     dkf.resetDataset(dataset,mask,quiet=True)
16 | return dkf.posterior_inference(idx=np.arange(dataset.shape[0]))
17 |
18 | def evaluateBound(dkf, dataset, mask, batch_size,S=2, normalization = 'frame', additional={}):
19 | """ Evaluate ELBO """
20 | bound = 0
21 | start_time = time.time()
22 | N = dataset.shape[0]
23 | tsbn_bound = 0
24 | dkf.resetDataset(dataset,mask)
25 | for bnum,st_idx in enumerate(range(0,N,batch_size)):
26 | end_idx = min(st_idx+batch_size, N)
27 | idx_data= np.arange(st_idx,end_idx)
28 | maxS = S
29 | bound_sum, tsbn_bound_sum = 0, 0
30 | for s in range(S):
31 | if s>0 and s%500==0:
32 | dkf._p('Done '+str(s))
33 | batch_vec= dkf.evaluate(idx=idx_data)
34 | M = mask[idx_data]
35 | if np.any(np.isnan(batch_vec)) or np.any(np.isinf(batch_vec)):
36 |                 dkf._p('NaN/Inf detected during evaluation. Ignoring this sample')
37 | maxS -=1
38 | continue
39 | else:
40 | tsbn_bound_sum+=(batch_vec/M.sum(1,keepdims=True)).sum()
41 | bound_sum+=batch_vec.sum()
42 | tsbn_bound += tsbn_bound_sum/float(max(maxS*N,1.))
43 | bound += bound_sum/float(max(maxS,1.))
44 | if normalization=='frame':
45 | bound /= float(mask.sum())
46 | elif normalization=='sequence':
47 | bound /= float(N)
48 | else:
49 | assert False,'Invalid normalization specified'
50 | end_time = time.time()
51 | dkf._p(('(Evaluate) Validation Bound: %.4f [Took %.4f seconds], TSBN Bound: %.4f')%(bound,end_time-start_time,tsbn_bound))
52 | additional['tsbn_bound'] = tsbn_bound
53 | return bound
54 |
55 |
56 | def impSamplingNLL(dkf, dataset, mask, batch_size, S = 2, normalization = 'frame'):
57 | """ Importance sampling based log likelihood """
58 | ll = 0
59 | start_time = time.time()
60 | N = dataset.shape[0]
61 | dkf.resetDataset(dataset,mask)
62 | for bnum,st_idx in enumerate(range(0,N,batch_size)):
63 | end_idx = min(st_idx+batch_size, N)
64 | idx_data= np.arange(st_idx,end_idx)
65 | maxS = S
66 | lllist = []
67 | for s in range(S):
68 | if s>0 and s%500==0:
69 | dkf._p('Done '+str(s))
70 | batch_vec = dkf.likelihood(idx=idx_data)
71 | if np.any(np.isnan(batch_vec)) or np.any(np.isinf(batch_vec)):
72 |                 dkf._p('NaN/Inf detected during evaluation. Ignoring this sample')
73 | maxS -=1
74 | continue
75 | else:
76 | lllist.append(batch_vec)
77 | ll += dkf.meanSumExp(np.concatenate(lllist,axis=1), axis=1).sum()
78 | if normalization=='frame':
79 | ll /= float(mask.sum())
80 | elif normalization=='sequence':
81 | ll /= float(N)
82 | else:
83 | assert False,'Invalid normalization specified'
84 | end_time = time.time()
85 | dkf._p(('(Evaluate w/ Imp. Sampling) Validation LL: %.4f [Took %.4f seconds]')%(ll,end_time-start_time))
86 | return ll
87 |
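# [Sketch] dkf.meanSumExp (defined in the model base class, not shown here) is
# assumed to be a numerically stable log-mean-exp over the S samples, i.e. the
# importance sampling estimate log(1/S * sum_s exp(ll_s)). A NumPy equivalent:
import numpy as np
def log_mean_exp(ll, axis=1):
    m = ll.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    return m + np.log(np.exp(ll - m).mean(axis=axis, keepdims=True))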
88 |
89 | def sampleGaussian(dkf,mu,logcov):
90 | return mu + np.random.randn(*mu.shape)*np.exp(0.5*logcov)
91 |
92 | def sample(dkf, nsamples=100, T=10, additional = {}):
93 | """
94 | Sample from Generative Model
95 | """
96 |     assert T>1, 'Sample at least 2 timesteps'
97 | #Initial sample
98 | z = np.random.randn(nsamples,1,dkf.params['dim_stochastic']).astype(config.floatX)
99 | all_zs = [np.copy(z)]
100 | additional['mu'] = []
101 | additional['logcov'] = []
102 | for t in range(T-1):
103 | mu,logcov = dkf.transition_fxn(z)
104 |         z = sampleGaussian(dkf, mu, logcov).astype(config.floatX)
105 | all_zs.append(np.copy(z))
106 | additional['mu'].append(np.copy(mu))
107 | additional['logcov'].append(np.copy(logcov))
108 | zvec = np.concatenate(all_zs,axis=1)
109 | additional['mu'] = np.concatenate(additional['mu'], axis=1)
110 | additional['logcov'] = np.concatenate(additional['logcov'], axis=1)
111 | return dkf.emission_fxn(zvec), zvec
112 |
--------------------------------------------------------------------------------
/stinfmodel_fast/learning.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for learning with a DKF object
3 | """
4 | import evaluate as DKF_evaluate
5 | import numpy as np
6 | from utils.misc import saveHDF5
7 | import time
8 | from theano import config
9 |
10 | def learn(dkf, dataset, mask, epoch_start=0, epoch_end=1000,
11 | batch_size=200, shuffle=True,
12 | savefreq=None, savefile = None,
13 | dataset_eval = None, mask_eval = None,
14 | replicate_K = None,
15 | normalization = 'frame'):
16 | """
17 | Train DKF
18 | """
19 | assert not dkf.params['validate_only'],'cannot learn in validate only mode'
20 | assert len(dataset.shape)==3,'Expecting 3D tensor for data'
21 | assert dataset.shape[2]==dkf.params['dim_observations'],'Dim observations not valid'
22 | N = dataset.shape[0]
23 | idxlist = range(N)
24 | batchlist = np.split(idxlist, range(batch_size,N,batch_size))
25 |
26 | bound_train_list,bound_valid_list,bound_tsbn_list,nll_valid_list = [],[],[],[]
27 | p_norm, g_norm, opt_norm = None, None, None
28 |
29 | #Lists used to track quantities for synthetic experiments
30 | mu_list_train, cov_list_train, mu_list_valid, cov_list_valid = [],[],[],[]
31 | model_params = {}
32 |
33 | #Start of training loop
34 | if 'synthetic' in dkf.params['dataset']:
35 | epfreq = 10
36 | else:
37 | epfreq = 1
38 |
39 | #Set data
40 | dkf.resetDataset(dataset, mask)
41 | for epoch in range(epoch_start, epoch_end):
42 | #Shuffle
43 | if shuffle:
44 | np.random.shuffle(idxlist)
45 | batchlist = np.split(idxlist, range(batch_size,N,batch_size))
46 |         #Always shuffle the order in which batches are presented
47 | np.random.shuffle(batchlist)
48 |
49 | start_time = time.time()
50 | bound = 0
51 |         for bnum, batch_idx in enumerate(batchlist):
53 | batch_bound, p_norm, g_norm, opt_norm, negCLL, KL, anneal = dkf.train_debug(idx=batch_idx)
54 |
55 | #Number of frames
56 | M_sum = mask[batch_idx].sum()
57 | #Correction for replicating batch
58 | if replicate_K is not None:
59 |                 batch_bound, negCLL, KL = batch_bound/replicate_K, negCLL/replicate_K, KL/replicate_K
60 | M_sum = M_sum/replicate_K
61 | #Update bound
62 | bound += batch_bound
63 | ### Display ###
64 | if epoch%epfreq==0 and bnum%10==0:
65 | if normalization=='frame':
66 | bval = batch_bound/float(M_sum)
67 | elif normalization=='sequence':
68 |                     bval = batch_bound/float(len(batch_idx))
69 | else:
70 | assert False,'Invalid normalization'
71 | dkf._p(('Bnum: %d, Batch Bound: %.4f, |w|: %.4f, |dw|: %.4f, |w_opt|: %.4f')%(bnum,bval,p_norm, g_norm, opt_norm))
72 | dkf._p(('-veCLL:%.4f, KL:%.4f, anneal:%.4f')%(negCLL, KL, anneal))
73 | if normalization=='frame':
74 |             bound /= (float(mask.sum())/float(replicate_K if replicate_K is not None else 1))
75 | elif normalization=='sequence':
76 | bound /= float(N)
77 | else:
78 | assert False,'Invalid normalization'
79 | bound_train_list.append((epoch,bound))
80 | end_time = time.time()
81 | if epoch%epfreq==0:
82 | dkf._p(('(Ep %d) Bound: %.4f [Took %.4f seconds] ')%(epoch, bound, end_time-start_time))
83 |
84 | #Save at intermediate stages
85 | if savefreq is not None and epoch%savefreq==0:
86 | assert savefile is not None, 'expecting savefile'
87 | dkf._p(('Saving at epoch %d'%epoch))
88 | dkf._saveModel(fname = savefile+'-EP'+str(epoch))
89 | intermediate = {}
90 | if dataset_eval is not None and mask_eval is not None:
91 | tmpMap = {}
92 | bound_valid_list.append(
93 | (epoch,
94 | DKF_evaluate.evaluateBound(dkf, dataset_eval, mask_eval, batch_size=batch_size,
95 | additional = tmpMap, normalization=normalization)))
96 | bound_tsbn_list.append((epoch, tmpMap['tsbn_bound']))
97 | nll_valid_list.append(
98 | DKF_evaluate.impSamplingNLL(dkf, dataset_eval, mask_eval, batch_size,
99 | normalization=normalization))
100 | intermediate['valid_bound'] = np.array(bound_valid_list)
101 | intermediate['train_bound'] = np.array(bound_train_list)
102 | intermediate['tsbn_bound'] = np.array(bound_tsbn_list)
103 | intermediate['valid_nll'] = np.array(nll_valid_list)
104 | if 'synthetic' in dkf.params['dataset']:
105 | mu_train, cov_train, mu_valid, cov_valid, learned_params = _syntheticProc(dkf, dataset,mask, dataset_eval,mask_eval)
106 | if dkf.params['dim_stochastic']==1:
107 | mu_list_train.append(mu_train)
108 | cov_list_train.append(cov_train)
109 | mu_list_valid.append(mu_valid)
110 | cov_list_valid.append(cov_valid)
111 | intermediate['mu_posterior_train'] = np.concatenate(mu_list_train, axis=2)
112 | intermediate['cov_posterior_train'] = np.concatenate(cov_list_train, axis=2)
113 | intermediate['mu_posterior_valid'] = np.concatenate(mu_list_valid, axis=2)
114 | intermediate['cov_posterior_valid'] = np.concatenate(cov_list_valid, axis=2)
115 | else:
116 | mu_list_train.append(mu_train[None,:])
117 | cov_list_train.append(cov_train[None,:])
118 | mu_list_valid.append(mu_valid[None,:])
119 | cov_list_valid.append(cov_valid[None,:])
120 | intermediate['mu_posterior_train'] = np.concatenate(mu_list_train, axis=0)
121 | intermediate['cov_posterior_train'] = np.concatenate(cov_list_train, axis=0)
122 | intermediate['mu_posterior_valid'] = np.concatenate(mu_list_valid, axis=0)
123 | intermediate['cov_posterior_valid'] = np.concatenate(cov_list_valid, axis=0)
124 | for k in dkf.params_synthetic[dkf.params['dataset']]['params']:
125 | if k in model_params:
126 | model_params[k].append(learned_params[k])
127 | else:
128 | model_params[k] = [learned_params[k]]
129 | for k in dkf.params_synthetic[dkf.params['dataset']]['params']:
130 | intermediate[k+'_learned'] = np.array(model_params[k]).squeeze()
131 | saveHDF5(savefile+'-EP'+str(epoch)+'-stats.h5', intermediate)
132 |             ### Reset the shared dataset in the computational graph to point back to the training data
133 | dkf.resetDataset(dataset, mask)
134 | #Final information to be collected
135 | retMap = {}
136 | retMap['train_bound'] = np.array(bound_train_list)
137 | retMap['valid_bound'] = np.array(bound_valid_list)
138 | retMap['tsbn_bound'] = np.array(bound_tsbn_list)
139 | retMap['valid_nll'] = np.array(nll_valid_list)
140 | if 'synthetic' in dkf.params['dataset']:
141 | if dkf.params['dim_stochastic']==1:
142 | retMap['mu_posterior_train'] = np.concatenate(mu_list_train, axis=2)
143 | retMap['cov_posterior_train'] = np.concatenate(cov_list_train, axis=2)
144 | retMap['mu_posterior_valid'] = np.concatenate(mu_list_valid, axis=2)
145 | retMap['cov_posterior_valid'] = np.concatenate(cov_list_valid, axis=2)
146 | else:
147 | retMap['mu_posterior_train'] = np.concatenate(mu_list_train, axis=0)
148 | retMap['cov_posterior_train'] = np.concatenate(cov_list_train, axis=0)
149 | retMap['mu_posterior_valid'] = np.concatenate(mu_list_valid, axis=0)
150 | retMap['cov_posterior_valid'] = np.concatenate(cov_list_valid, axis=0)
151 | for k in dkf.params_synthetic[dkf.params['dataset']]['params']:
152 | retMap[k+'_learned'] = np.array(model_params[k])
153 | return retMap
154 |
155 | def _syntheticProc(dkf, dataset, mask, dataset_eval, mask_eval):
156 | """
157 | Collect statistics on the synthetic dataset
158 | """
159 | allmus, alllogcov = [], []
160 | if dkf.params['dim_stochastic']==1:
161 | for s in range(10):
162 | _,mus, logcov = DKF_evaluate.infer(dkf,dataset,mask)
163 | allmus.append(np.copy(mus))
164 | alllogcov.append(np.copy(logcov))
165 | allmus_v, alllogcov_v = [], []
166 | for s in range(10):
167 |             _,mus, logcov = DKF_evaluate.infer(dkf,dataset_eval,mask_eval)
168 | allmus_v.append(np.copy(mus))
169 | alllogcov_v.append(np.copy(logcov))
170 | mu_train = np.concatenate(allmus,axis=2).mean(2,keepdims=True)
171 | cov_train= np.exp(np.concatenate(alllogcov,axis=2)).mean(2,keepdims=True)
172 | mu_valid = np.concatenate(allmus_v,axis=2).mean(2,keepdims=True)
173 | cov_valid= np.exp(np.concatenate(alllogcov_v,axis=2)).mean(2,keepdims=True)
174 | else:
175 | for s in range(10):
176 | _,mus, logcov = DKF_evaluate.infer(dkf,dataset,mask)
177 | allmus.append(np.copy(mus)[None,:])
178 | alllogcov.append(np.copy(logcov)[None,:])
179 | allmus_v, alllogcov_v = [], []
180 | for s in range(10):
181 |             _,mus, logcov = DKF_evaluate.infer(dkf,dataset_eval,mask_eval)
182 | allmus_v.append(np.copy(mus)[None,:])
183 | alllogcov_v.append(np.copy(logcov)[None,:])
184 | mu_train = np.concatenate(allmus,axis=0).mean(0)
185 | cov_train= np.exp(np.concatenate(alllogcov,axis=0)).mean(0)
186 | mu_valid = np.concatenate(allmus_v,axis=0).mean(0)
187 | cov_valid= np.exp(np.concatenate(alllogcov_v,axis=0)).mean(0)
188 | #Extract the learned parameters w/in the generative model
189 | learned_params = {}
190 | for k in dkf.params_synthetic[dkf.params['dataset']]['params']:
191 | learned_params[k] = dkf.tWeights[k+'_W'].get_value()
192 | return mu_train, cov_train, mu_valid, cov_valid, learned_params
193 |
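# [Sketch] Typical usage of this module, assuming a compiled DKF object `dkf`
# and 3D data tensors with matching masks (the names below are illustrative;
# see the train_dkf.py driver scripts in the experiment folders for actual usage):
#
#   import stinfmodel_fast.learning as DKF_learn
#   import stinfmodel_fast.evaluate as DKF_evaluate
#   stats = DKF_learn.learn(dkf, train_data, train_mask,
#                           epoch_start=0, epoch_end=100, batch_size=20,
#                           savefreq=10, savefile='chkpt/dkf',
#                           dataset_eval=valid_data, mask_eval=valid_mask,
#                           normalization='frame')
#   nll = DKF_evaluate.impSamplingNLL(dkf, valid_data, valid_mask, batch_size=20)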
--------------------------------------------------------------------------------