├── .gitignore ├── LICENSE ├── README.md ├── ResearchIpynb.ipynb ├── __init__.py ├── baselines ├── __init__.py └── filters.py ├── default.pkl ├── expt-polyphonic-fast ├── locked_write.py ├── setup_experiments.py ├── template.q └── train_dkf.py ├── expt-polyphonic ├── README.md ├── dkf_polyphonic.py ├── hpc_uas1 │ ├── mdata_MFLR.q │ ├── mdata_STLR.q │ ├── mdata_STR.q │ ├── mdata_STR_MLP.q │ ├── mdata_STR_ar.q │ ├── mdata_STR_ar_aug.q │ ├── mdata_STR_ar_aug_nade.q │ ├── nott_MFLR.q │ ├── nott_STLR.q │ ├── nott_STR.q │ ├── nott_STR_MLP.q │ ├── nott_STR_ar.q │ ├── nott_STR_ar_aug.q │ ├── nott_STR_ar_aug_nade.q │ ├── piano_MFLR.q │ ├── piano_STLR.q │ ├── piano_STR.q │ ├── piano_STR_MLP.q │ ├── piano_STR_ar.q │ ├── piano_STR_ar_aug.q │ └── piano_STR_ar_aug_nade.q ├── jsb_expts.sh ├── musedata_expt.sh ├── nott_expt.sh ├── piano_expt.sh └── train_dkf.py ├── expt-synthetic-fast ├── create_expt.py ├── run_baselines.py └── train.py ├── expt-synthetic ├── README.md ├── create_expt.py ├── run_baselines.py └── train.py ├── expt-template ├── README.md ├── load.py ├── runme.sh └── train.py ├── images ├── ELBO.png └── dkf.png ├── ipynb └── synthetic │ ├── VisualizeSynthetic-fast.ipynb │ ├── VisualizeSynthetic.ipynb │ └── VisualizeSyntheticScaling.ipynb ├── paper.pdf ├── parse_args_dkf.py ├── polyphonic_samples ├── jsb0.mp3 ├── jsb1.mp3 ├── musedata0.mp3 ├── musedata1.mp3 ├── nott0.mp3 ├── nott1.mp3 ├── piano0.mp3 └── piano1.mp3 ├── stinfmodel ├── README.md ├── __init__.py ├── dkf.py ├── evaluate.py ├── learning.py └── testall.sh └── stinfmodel_fast ├── __init__.py ├── dkf.py ├── evaluate.py └── learning.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.h5 2 | *.amat 3 | *.pyc 4 | *.ipnyb_checkpoints/* 5 | *-checkpoint.ipynb 6 | chkpt* 7 | *.png 8 | *.pkl 9 | *.sh 10 | *.pkl.gz 11 | chkpt-*/ 12 | *.o* 13 | *.tmp 14 | *.q 15 | *indicators.txt 16 | *.aux 17 | *.brf 18 | *.log 19 | *.bbl 20 | *.blg 21 | *.dvi 22 | *.ps 23 | *.out 24 | *.fdb_latexmk 25 | *.synctex.gz 26 | *.swp 27 | *.log 28 | ipynb/synthetic/*.pdf 29 | ipynb/medical/*.pdf 30 | expt-polyphonic/*.txt 31 | *.midi 32 | *.wav 33 | *.npz 34 | *.pkl 35 | datasets/tmp_MOCAP*.mat 36 | *.txt 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Sontag Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # structuredInf 2 | Code to fully reproduce benchmark results (and to extend for your own purposes) from the paper: 3 |

Krishnan, Shalit, Sontag. Structured Inference Networks for Nonlinear State Space Models, AAAI 2017.

4 | See here for a simplified and easier-to-use version of the code. 5 | 6 | ## Goal 7 | The goal of this package is to provide a black-box inference algorithm for learning models of time-series data. 8 | Inference during learning and at test time is based on a compiled recognition or [inference network](https://arxiv.org/abs/1401.4082). 9 | 10 | ## Model 11 | The figure below describes a simple model of time-series data. 12 | 13 | This method is a good fit if: 14 | * You have an arbitrarily specified state space model whose parameters you're interested in fitting. 15 | * You would like a method for fast posterior inference at train and test time. 16 | * Your temporal generative model has Gaussian latent variables (the mean/variance may be a nonlinear function of the previous timestep's variables), as sketched below. 17 | 18 |
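Concretely, the class of models supported looks like the following (a sketch only; the functions G, S, F and the parameters mu_0, Sigma_0 are illustrative placeholders, not identifiers from this package):

```latex
z_1 \sim \mathcal{N}(\mu_0, \Sigma_0), \qquad
z_t \sim \mathcal{N}\big(G_\theta(z_{t-1}),\, S_\theta(z_{t-1})\big), \qquad
x_t \sim \Pi\big(F_\theta(z_t)\big)
```

Here G and S give the (possibly nonlinear) mean and covariance of the Gaussian transition, and Pi is the observation distribution (e.g. Bernoulli for binary data) whose parameters are produced by the emission function F.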

Deep Kalman Filter

19 | 20 | The code uses variational inference during learning to maximize a lower bound (the ELBO) on the likelihood of the observed data: 21 |

Evidence Lower Bound

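Written out, the bound being optimized takes the standard factorized form for state space models (a sketch; see the paper for the exact derivation and notation):

```latex
\mathcal{L}(x;\theta,\phi) =
    \sum_{t=1}^{T} \mathbb{E}_{q_\phi(z_t|x)}\!\left[\log p_\theta(x_t|z_t)\right]
  - \mathrm{KL}\!\left(q_\phi(z_1|x)\,\|\,p_\theta(z_1)\right)
  - \sum_{t=2}^{T} \mathbb{E}_{q_\phi(z_{t-1}|x)}\!\left[\mathrm{KL}\!\left(q_\phi(z_t|z_{t-1},x)\,\|\,p_\theta(z_t|z_{t-1})\right)\right]
```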
22 | 23 | *Generative Model* 24 | 25 | * The latent variables z1...zT and the observations x1...xT describe the generative process for the data. 26 | * The figure depicts a state space model for time-varying data. 27 | * The emission and transition functions may be pre-specified to have a fixed functional form, a parametric functional form, a function parameterized by a deep neural network, or some combination thereof. 28 | 29 | *Inference Model* 30 | 31 | The box q(z1...zT | x1...xT) represents the inference network. There are several supported inference networks within this package: 32 | * Inference implemented with a bi-directional LSTM 33 | * Inference implemented with an LSTM conditioned on observations in the future 34 | * Inference implemented with an LSTM conditioned on observations from the past 35 | 36 | ## Installation 37 | 38 | ### Requirements 39 | This package has the following requirements: 40 | 41 | python2.7 42 | 43 | [Theano](https://github.com/Theano/Theano) 44 | Used for automatic differentiation 45 | 46 | [theanomodels](https://github.com/clinicalml/theanomodels) 47 | Wrapper around Theano that takes care of bookkeeping, saving/loading models etc. Clone the github repository 48 | and add its location to the PYTHONPATH environment variable so that it is accessible by python. 49 | 50 | [pykalman](https://pykalman.github.io/) 51 | [Optional: for running baseline UKFs/KFs] 52 | 53 | An NVIDIA GPU with at least 6GB of memory is recommended. 54 | 55 | Once the requirements have been met, clone this repository and it is ready to run. 56 | 57 | ### Folder Structure 58 | The following folders contain code to reproduce the results reported in our paper: 59 | * expt-synthetic, expt-polyphonic: Contains code and instructions for reproducing results from the paper. 60 | * baselines/: Contains code to run some of the baseline algorithms on the synthetic data. 61 | * ipynb/: IPython notebooks for visualizing saved checkpoints and building plots. 62 | 63 | The main files of interest are: 64 | * parse_args_dkf.py: Arguments that the model expects to be present. Looking through it is useful to understand the different knobs available to tune the model. 65 | * stinfmodel/dkf.py: Code to construct the inference and generative model. The code is commented to enable easy modification for different scenarios. 66 | * stinfmodel/evaluate.py: Code to evaluate the Deep Kalman Filter's performance during learning. 67 | * stinfmodel/learning.py: Code for performing stochastic gradient ascent on the Evidence Lower Bound. 68 | 69 | ## Dataset 70 | 71 | We use numpy tensors to store the datasets, with binary numpy masks to allow minibatches comprising sequences of variable length. We train the models using mini-batch gradient descent on the negative ELBO. 72 | 73 | ### Format 74 | 75 | The code to run on the polyphonic and synthetic datasets has already been created in the theanomodels repository. See theanomodels/datasets/load.py for how the dataset is created and loaded. 76 | 77 | The datasets are stored in three-dimensional numpy tensors. 78 | To deal with datapoints 79 | of different lengths, we use numpy matrices of binary masks. There may be different choices 80 | for manipulating the data that you may adopt depending on your needs; this is merely a guideline (see the sketch below).
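For instance, a minimal sketch (a hypothetical helper, not part of this package) of padding variable-length sequences into a single tensor along with the accompanying binary mask:

```python
import numpy as np

def pad_sequences(seqs):
    """Pad a list of (T_i x dim) arrays into an (N x T_max x dim) tensor
    and an (N x T_max) binary mask marking the valid timesteps."""
    N     = len(seqs)
    T_max = max(s.shape[0] for s in seqs)
    dim   = seqs[0].shape[1]
    data  = np.zeros((N, T_max, dim))
    mask  = np.zeros((N, T_max))
    for n, s in enumerate(seqs):
        data[n, :s.shape[0], :] = s   # copy the sequence into the padded tensor
        mask[n, :s.shape[0]]    = 1.  # 1 where data is valid, 0 in the padding
    return data, mask
```

Batching then amounts to slicing data and mask along the first axis; masked-out timesteps contribute nothing to the objective.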
81 | 82 | ``` 83 | assert type(dataset) is dict,'Expecting dictionary' 84 | dataset['train'] # N_train x T_train_max x dim_observation : training data 85 | dataset['test'] # N_test x T_test_max x dim_observation : test data 86 | dataset['valid'] # N_valid x T_valid_max x dim_observation : validation data 87 | dataset['mask_train'] # N_train x T_train_max : training masks 88 | dataset['mask_test'] # N_test x T_test_max : test masks 89 | dataset['mask_valid'] # N_valid x T_valid_max : validation masks 90 | dataset['data_type'] # real/binary 91 | dataset['has_masks'] # true/false 92 | ``` 93 | During learning, we select a minibatch of these tensors to update the weights of the model. 94 | 95 | ### Running on different datasets 96 | 97 | **See the folder expt-template for an example of how to set up your data and run the code on it** 98 | 99 | ## References 100 | ``` 101 | @inproceedings{krishnan2016structured, 102 | title={Structured Inference Networks for Nonlinear State Space Models}, 103 | author={Krishnan, Rahul G and Shalit, Uri and Sontag, David}, 104 | booktitle={AAAI}, 105 | year={2017} 106 | } 107 | ``` 108 | This paper subsumes the work in: [Deep Kalman Filters](https://arxiv.org/abs/1511.05121) 109 | -------------------------------------------------------------------------------- /ResearchIpynb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "Couldn't import dot_parser, loading of dot files will not be possible.\n" 15 | ] 16 | }, 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "Using gpu device 0: GeForce GTX TITAN (CNMeM is disabled, cuDNN 4007)\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "import theano\n", 27 | "import theano.tensor as T\n", 28 | "import numpy as np" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "collapsed": false 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "X = theano.shared(np.zeros((4,10,20)).astype('float32'))\n", 40 | "mask = theano.shared(np.zeros((4,10)).astype('float32'))\n", 41 | "newX = T.tensor3('newX',dtype='float32')\n", 42 | "newMask=T.matrix('newMask',dtype='float32')\n", 43 | "\n", 44 | "resetX= theano.function([newX,newMask],None,updates=[(X,newX),(mask,newMask)])\n", 45 | "statX = theano.function([],[X.mean(),X.max(),X.sum()])" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": { 52 | "collapsed": false 53 | }, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "0.0 0.0 0.0\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "mnX,maX,smX = statX()\n", 65 | "print mnX,maX,smX" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": { 72 | "collapsed": false 73 | }, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "[]" 79 | ] 80 | }, 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "nX = np.ones((12,4,32)).astype('float32')\n", 88 | "nM = np.ones((17,3)).astype('float32')\n", 89 | "resetX(newX=nX,newMask=nM)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 5, 95 | "metadata": { 96 | "collapsed": false 97 | }, 98 | "outputs": [ 99 | { 100 | "name": "stdout",
101 | "output_type": "stream", 102 | "text": [ 103 | "1.0 1.0 1536.0\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "mnX,maX,smX = statX()\n", 109 | "print mnX,maX,smX" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "collapsed": true 117 | }, 118 | "outputs": [], 119 | "source": [] 120 | } 121 | ], 122 | "metadata": { 123 | "kernelspec": { 124 | "display_name": "Python 2", 125 | "language": "python", 126 | "name": "python2" 127 | }, 128 | "language_info": { 129 | "codemirror_mode": { 130 | "name": "ipython", 131 | "version": 2 132 | }, 133 | "file_extension": ".py", 134 | "mimetype": "text/x-python", 135 | "name": "python", 136 | "nbconvert_exporter": "python", 137 | "pygments_lexer": "ipython2", 138 | "version": "2.7.3" 139 | } 140 | }, 141 | "nbformat": 4, 142 | "nbformat_minor": 0 143 | } 144 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | all=['stinfmodel'] 2 | -------------------------------------------------------------------------------- /baselines/__init__.py: -------------------------------------------------------------------------------- 1 | all=['filters'] 2 | -------------------------------------------------------------------------------- /baselines/filters.py: -------------------------------------------------------------------------------- 1 | from pykalman import KalmanFilter 2 | from pykalman import UnscentedKalmanFilter 3 | from pykalman import AdditiveUnscentedKalmanFilter 4 | import numpy as np 5 | 6 | #Running UKF/KF 7 | def runFilter(observations, params, dname, filterType): 8 | """ 9 | Run Kalman Filter (UKF/KF) and return smoothed means/covariances 10 | observations : nsample x T 11 | params : {'dname'... contains all the necessary parameters for KF/UKF} 12 | filterType : 'KF' or 'UKF' 13 | """ 14 | s1 = set(params[dname].keys()) 15 | s2 = set(['trans_fxn','obs_fxn','trans_cov','obs_cov','init_mu','init_cov', 16 | 'trans_mult','obs_mult','trans_drift','obs_drift','baseline']) 17 | for k in s2: 18 | assert k in s1,k+' not found in params' 19 | #assert s1.issubset(s2) and s1.issuperset(s2),'Missing in params: '+', '.join(list(s2.difference(s1))) 20 | assert filterType=='KF' or filterType=='UKF','Expecting KF/UKF' 21 | model,mean,var = None,None,None 22 | X = observations.squeeze() 23 | #assert len(X.shape)==2,'observations must be nsamples x T' 24 | if filterType=='KF': 25 | def setupArr(arr): 26 | if type(arr) is np.ndarray: 27 | return arr 28 | else: 29 | return np.array([arr]) 30 | model=KalmanFilter( 31 | transition_matrices = setupArr(params[dname]['trans_mult']), #multiplier for z_t-1 32 | observation_matrices = setupArr(params[dname]['obs_mult']).T, #multiplier for z_t 33 | transition_covariance = setupArr(params[dname]['trans_cov_full']), #transition cov 34 | observation_covariance= setupArr(params[dname]['obs_cov_full']), #obs cov 35 | transition_offsets = setupArr(params[dname]['trans_drift']),#additive const. in trans 36 | observation_offsets = setupArr(params[dname]['obs_drift']), #additive const. 
in obs 37 | initial_state_mean = setupArr(params[dname]['init_mu']), 38 | initial_state_covariance = setupArr(params[dname]['init_cov_full'])) 39 | else: 40 | #In this case, the transition and emission function may have other parameters 41 | #Create wrapper functions that are instantiated w/ the true parameters 42 | #and pass them to the UKF 43 | def trans_fxn(z): 44 | return params[dname]['trans_fxn'](z, fxn_params = params[dname]['params']) 45 | def obs_fxn(z): 46 | return params[dname]['obs_fxn'](z, fxn_params = params[dname]['params']) 47 | 48 | model=AdditiveUnscentedKalmanFilter( 49 | transition_functions = trans_fxn, #params[dname]['trans_fxn'], 50 | observation_functions = obs_fxn, #params[dname]['obs_fxn'], 51 | transition_covariance = np.array([params[dname]['trans_cov']]), #transition cov 52 | observation_covariance= np.array([params[dname]['obs_cov']]), #obs cov 53 | initial_state_mean = np.array([params[dname]['init_mu']]), 54 | initial_state_covariance = np.array(params[dname]['init_cov'])) 55 | #Run smoothing algorithm with model 56 | dim_stoc = params[dname]['dim_stoc'] 57 | if dim_stoc>1: 58 | mus = np.zeros((X.shape[0],X.shape[1],dim_stoc)) 59 | cov = np.zeros((X.shape[0],X.shape[1],dim_stoc)) 60 | else: 61 | mus = np.zeros(X.shape) 62 | cov = np.zeros(X.shape) 63 | ll = 0 64 | for n in range(X.shape[0]): 65 | (smoothed_state_means, smoothed_state_covariances) = model.smooth(X[n,:]) 66 | if dim_stoc>1: 67 | mus[n,:] = smoothed_state_means 68 | cov[n,:] = np.concatenate([np.diag(k)[None,:] for k in smoothed_state_covariances],axis=0) 69 | else: 70 | mus[n,:] = smoothed_state_means.ravel() 71 | cov[n,:] = smoothed_state_covariances.ravel() 72 | if filterType=='KF': 73 | ll += model.loglikelihood(X[n,:]) 74 | return mus,cov,ll 75 | 76 | #Generating Data 77 | def sampleGaussian(mu,cov): 78 | """ 79 | Sample from gaussian with mu/cov 80 | 81 | returns: random sample from N(mu,cov) of shape mu 82 | mu: must be numpy array 83 | cov: can be scalar or same shape as mu 84 | """ 85 | return np.multiply(np.random.randn(*mu.shape),np.sqrt(cov))+mu 86 | 87 | def generateData(N,T,params, dname): 88 | """ 89 | Generate sequential dataset 90 | returns: N x T matrix of observations, latents 91 | N : #samples 92 | T : time steps 93 | params : {'dname'...contains necessary functions} 94 | dname : dataset name 95 | """ 96 | np.random.seed(1) 97 | assert dname in params,dname+' not found in params' 98 | Z = np.zeros((N,T)) 99 | X = np.zeros((N,T)) 100 | Z[:,0] = sampleGaussian(params[dname]['init_mu']*np.ones((N,)), 101 | params[dname]['init_cov']) 102 | for t in range(1,T): 103 | Z[:,t] = sampleGaussian(params[dname]['trans_fxn'](Z[:,t-1]),params[dname]['trans_cov']) 104 | for t in range(T): 105 | X[:,t] = sampleGaussian(params[dname]['obs_fxn'](Z[:,t]),params[dname]['obs_cov']) 106 | return X,Z 107 | 108 | #Reconstruction 109 | def reconsMus(mus_posterior, params, dname): 110 | """ 111 | Estimate the observation means using posterior means 112 | mus_posterior : N x T matrix of posterior means 113 | params : {'dname'...contains necessary functions} 114 | """ 115 | mu_rec = np.zeros(mus_posterior.shape) 116 | for t in range(mus_posterior.shape[1]): 117 | mu_rec[:,t] = params[dname]['obs_fxn'](mus_posterior[:,t]) 118 | return mu_rec 119 | -------------------------------------------------------------------------------- /default.pkl: -------------------------------------------------------------------------------- 1 | (dp1 2 | S'shuffle' 3 | p2 4 | I00 5 | sS'use_nade' 6 | p3 7 | I00 8 | 
sS'cov_explicit' 9 | p4 10 | I00 11 | sS'dataset' 12 | p5 13 | S'' 14 | sS'epochs' 15 | p6 16 | I2000 17 | sS'seed' 18 | p7 19 | I1 20 | sS'init_weight' 21 | p8 22 | F0.10000000000000001 23 | sS'reg_spec' 24 | p9 25 | S'_' 26 | sS'use_prev_input' 27 | p10 28 | I00 29 | sS'reg_value' 30 | p11 31 | F0.050000000000000003 32 | sS'reloadFile' 33 | p12 34 | S'./NOSUCHFILE' 35 | p13 36 | sS'dim_stochastic' 37 | p14 38 | I100 39 | sS'rnn_layers' 40 | p15 41 | I1 42 | sS'transition_layers' 43 | p16 44 | I2 45 | sS'lr' 46 | p17 47 | F0.00080000000000000004 48 | sS'reg_type' 49 | p18 50 | S'l2' 51 | p19 52 | sS'init_scheme' 53 | p20 54 | S'uniform' 55 | p21 56 | sS'replicate_K' 57 | p22 58 | NsS'optimizer' 59 | p23 60 | S'adam' 61 | p24 62 | sS'use_generative_prior' 63 | p25 64 | I00 65 | sS'q_mlp_layers' 66 | p26 67 | I1 68 | sS'maxout_stride' 69 | p27 70 | I4 71 | sS'batch_size' 72 | p28 73 | I20 74 | sS'savedir' 75 | p29 76 | S'./chkpt' 77 | p30 78 | sS'forget_bias' 79 | p31 80 | F-5 81 | sS'inference_model' 82 | p32 83 | S'structured' 84 | p33 85 | sS'emission_layers' 86 | p34 87 | I2 88 | sS'savefreq' 89 | p35 90 | I25 91 | sS'rnn_size' 92 | p36 93 | I600 94 | sS'paramFile' 95 | p37 96 | g13 97 | sS'emission_type' 98 | p38 99 | S'mlp' 100 | p39 101 | sS'nonlinearity' 102 | p40 103 | S'relu' 104 | p41 105 | sS'rnn_dropout' 106 | p42 107 | F0.10000000000000001 108 | sS'dim_hidden' 109 | p43 110 | I200 111 | sS'var_model' 112 | p44 113 | S'lstmr' 114 | p45 115 | sS'anneal_rate' 116 | p46 117 | F10 118 | sS'debug' 119 | p47 120 | I00 121 | sS'validate_only' 122 | p48 123 | I00 124 | sS'transition_type' 125 | p49 126 | S'simple_gated' 127 | p50 128 | sS'unique_id' 129 | p51 130 | S'DKF_lr-8_0000e-04-vm-lstmr-inf-structured-dh-200-ds-100-nl-relu-bs-20-ep-2000-rs-600-rd-1_0000e-01-ttype-simple_gated-etype-mlp-previnp-False-ar-1_0000e+01-rv-5_0000e-02-nade-False-uid' 131 | p52 132 | sS'leaky_param' 133 | p53 134 | F0 135 | s. -------------------------------------------------------------------------------- /expt-polyphonic-fast/locked_write.py: -------------------------------------------------------------------------------- 1 | import fcntl,errno,time 2 | with open('remove.me','a') as f: 3 | while True: 4 | try: 5 | fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) 6 | break 7 | except IOError as e: 8 | if e.errno != errno.EAGAIN: 9 | raise 10 | else: 11 | time.sleep(0.1) 12 | f.write('another line\n') 13 | fcntl.flock(f, fcntl.LOCK_UN) 14 | -------------------------------------------------------------------------------- /expt-polyphonic-fast/setup_experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rahul G. 
Krishnan 3 | 4 | Script to set up experiments either on HPC or individually 5 | """ 6 | import numpy as np 7 | from collections import OrderedDict 8 | import argparse,os 9 | 10 | parser = argparse.ArgumentParser(description='Setup Expts') 11 | parser.add_argument('-hpc','--onHPC',action='store_true') 12 | parser.add_argument('-dset','--dataset', default='jsb',action='store') 13 | parser.add_argument('-ngpu','--num_gpus', default=4,action='store',type=int) 14 | args = parser.parse_args() 15 | 16 | #MAIN FLAGS 17 | onHPC = args.onHPC 18 | DATASET = args.dataset 19 | THFLAGS = 'THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu<gpu>" ' 20 | 21 | #Get dataset 22 | dataset = DATASET.split('-')[0] 23 | all_datasets = ['jsb','piano','nottingham','musedata'] 24 | assert dataset in all_datasets,'Dset not found: '+dataset 25 | all_expts = OrderedDict() 26 | for dset in all_datasets: 27 | all_expts[dset] = OrderedDict() 28 | 29 | #Experiments to run for each dataset; the <dataset> and <gpu> placeholders are substituted below
30 | all_expts['jsb']['ST-R'] = 'python2.7 train_dkf.py -vm R -infm structured -dset <dataset>' 31 | all_expts['jsb']['MF-LR'] = 'python2.7 train_dkf.py -vm LR -infm mean_field -dset <dataset>' 32 | all_expts['jsb']['ST-LR'] = 'python2.7 train_dkf.py -vm LR -infm structured -dset <dataset>' 33 | all_expts['jsb']['ST-R-mlp'] = 'python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset <dataset>' 34 | all_expts['jsb']['ST-L'] = 'python2.7 train_dkf.py -vm L -infm structured -dset <dataset>' 35 | all_expts['jsb']['DKF-ar'] ='python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset <dataset>' 36 | all_expts['jsb']['DKF-aug'] ='python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset <dataset>' 37 | all_expts['jsb']['DKF-aug-nade'] ='python2.7 train_dkf.py -vm R -infm structured -etype conditional -previnp -usenade -dset <dataset>' 38 |
39 | all_expts['nottingham']['ST-R'] = 'python2.7 train_dkf.py -vm R -infm structured -dset <dataset>' 40 | all_expts['nottingham']['MF-LR'] = 'python2.7 train_dkf.py -vm LR -infm mean_field -dset <dataset>' 41 | all_expts['nottingham']['ST-LR'] = 'python2.7 train_dkf.py -vm LR -infm structured -dset <dataset>' 42 | all_expts['nottingham']['ST-R-mlp'] = 'python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset <dataset>' 43 | all_expts['nottingham']['ST-L'] = 'python2.7 train_dkf.py -vm L -infm structured -dset <dataset>' 44 | all_expts['nottingham']['DKF-ar'] ='python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset <dataset>' 45 | all_expts['nottingham']['DKF-aug'] ='python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset <dataset>' 46 | all_expts['nottingham']['DKF-aug-nade'] ='python2.7 train_dkf.py -vm R -infm structured -ar 1000 -etype conditional -previnp -usenade -dset <dataset>' 47 |
48 | all_expts['musedata']['ST-R'] = 'python2.7 train_dkf.py -vm R -infm structured -dset <dataset>' 49 | all_expts['musedata']['MF-LR'] = 'python2.7 train_dkf.py -vm LR -infm mean_field -dset <dataset>' 50 | all_expts['musedata']['ST-LR'] = 'python2.7 train_dkf.py -vm LR -infm structured -dset <dataset>' 51 | all_expts['musedata']['ST-R-mlp'] = 'python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset <dataset>' 52 | all_expts['musedata']['ST-L'] = 'python2.7 train_dkf.py -vm L -infm structured -dset <dataset>' 53 | all_expts['musedata']['DKF-ar'] ='python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset <dataset>' 54 | all_expts['musedata']['DKF-aug'] ='python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset <dataset>' 55 | all_expts['musedata']['DKF-aug-nade'] ='python2.7 train_dkf.py -vm R -infm structured -ar 1000 -etype conditional -previnp -usenade -dset <dataset> -ds 50 -dh 100 -rs 400' 56 |
57 | all_expts['piano']['ST-R'] = 'python2.7 train_dkf.py -vm R -infm structured -dset <dataset>' 58 | all_expts['piano']['MF-LR'] = 'python2.7 train_dkf.py -vm LR -infm mean_field -dset <dataset>' 59 | all_expts['piano']['ST-LR'] = 'python2.7 train_dkf.py -vm LR -infm structured -dset <dataset>' 60 | all_expts['piano']['ST-R-mlp'] = 'python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset <dataset>' 61 | all_expts['piano']['ST-L'] = 'python2.7 train_dkf.py -vm L -infm structured -dset <dataset>' 62 | all_expts['piano']['DKF-ar'] ='python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset <dataset>' 63 | all_expts['piano']['DKF-aug'] ='python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset <dataset>' 64 | all_expts['piano']['DKF-aug-nade'] ='python2.7 train_dkf.py -vm R -infm structured -ar 1000 -etype conditional -previnp -usenade -dset <dataset> -ds 50 -dh 100 -rs 400' 65 |
66 | if onHPC: 67 | DIR = './hpc_'+dataset 68 | os.system('rm -rf '+DIR) 69 | os.system('mkdir -p '+DIR) 70 | with open('template.q') as ff: 71 | template = ff.read() 72 | runallcmd = '' 73 | for name in all_expts[dataset]: 74 | runcmd = all_expts[dataset][name].replace('<dataset>',DATASET)+' -uid '+name 75 | command = THFLAGS.replace('<gpu>',str(np.random.randint(args.num_gpus)))+runcmd 76 | with open(DIR+'/'+name+'.q','w') as f: 77 | f.write(template.replace('<name>',name).replace('<command>',command)) 78 | print 'Wrote to:',DIR+'/'+name+'.q' 79 | runallcmd+= 'qsub '+name+'.q\n' 80 | with open(DIR+'/runall.sh','w') as f: 81 | f.write(runallcmd) 82 | else: 83 | for name in all_expts[dataset]: 84 | runcmd = all_expts[dataset][name].replace('<dataset>',DATASET)+' -uid '+name 85 | command = THFLAGS.replace('<gpu>',str(np.random.randint(args.num_gpus)))+runcmd 86 | print command 87 | -------------------------------------------------------------------------------- /expt-polyphonic-fast/template.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=30:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N <name> 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.5v5.1 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic-fast 19 | cd $RUNDIR 20 | 21 | <command>
-------------------------------------------------------------------------------- /expt-polyphonic-fast/train_dkf.py: -------------------------------------------------------------------------------- 1 | import os,time,sys 2 | import fcntl,errno 3 | import socket 4 | sys.path.append('../') 5 | from datasets.load import loadDataset 6 | from parse_args_dkf import params 7 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime,getLowestError 8 | 9 | if params['dataset']=='': 10 | params['dataset']='jsb' 11 | dataset = loadDataset(params['dataset']) 12 | params['savedir']+='-'+params['dataset'] 13 | createIfAbsent(params['savedir']) 14 | 15 | #Saving/loading 16 | for k in ['dim_observations','dim_actions','data_type']: 17 | params[k] = dataset[k] 18 | mapPrint('Options: ',params) 19 | 20 | if params['use_nade']: 21 | params['data_type']='binary_nade' 22 | #Setup VAE Model (or reload from existing savefile) 23 | start_time = time.time() 24 | from stinfmodel_fast.dkf import DKF 25 | import stinfmodel_fast.learning as DKF_learn 26 | import stinfmodel_fast.evaluate as DKF_evaluate 27 | displayTime('import DKF',start_time, 
time.time()) 28 | dkf = None 29 | 30 | 31 | #Remove from params 32 | start_time = time.time() 33 | removeIfExists('./NOSUCHFILE') 34 | reloadFile = params.pop('reloadFile') 35 | if os.path.exists(reloadFile): 36 | pfile=params.pop('paramFile') 37 | assert os.path.exists(pfile),pfile+' not found. Need paramfile' 38 | print 'Reloading trained model from : ',reloadFile 39 | print 'Assuming ',pfile,' corresponds to model' 40 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile) 41 | else: 42 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl' 43 | print 'Training model from scratch. Parameters in: ',pfile 44 | dkf = DKF(params, paramFile = pfile) 45 | displayTime('Building dkf',start_time, time.time()) 46 | 47 | savef = os.path.join(params['savedir'],params['unique_id']) 48 | print 'Savefile: ',savef 49 | start_time= time.time() 50 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'], 51 | epoch_start =0 , 52 | epoch_end = params['epochs'], 53 | batch_size = params['batch_size'], 54 | savefreq = params['savefreq'], 55 | savefile = savef, 56 | dataset_eval=dataset['valid'], 57 | mask_eval = dataset['mask_valid'], 58 | replicate_K= params['replicate_K'], 59 | shuffle = False 60 | ) 61 | displayTime('Running DKF',start_time, time.time() ) 62 | 63 | dkf = None 64 | """ 65 | Load the best DKF based on the validation error 66 | """ 67 | epochMin, valMin, idxMin = getLowestError(savedata['valid_bound']) 68 | reloadFile= pfile.replace('-config.pkl','')+'-EP'+str(int(epochMin))+'-params.npz' 69 | print 'Loading from : ',reloadFile 70 | params['validate_only'] = True 71 | dkf_best = DKF(params, paramFile = pfile, reloadFile = reloadFile) 72 | additional = {} 73 | savedata['bound_test_best'] = DKF_evaluate.evaluateBound(dkf_best, dataset['test'], dataset['mask_test'], S = 2, batch_size = params['batch_size'], additional =additional) 74 | savedata['bound_tsbn_test_best'] = additional['tsbn_bound'] 75 | savedata['ll_test_best'] = DKF_evaluate.impSamplingNLL(dkf_best, dataset['test'], dataset['mask_test'], S = 2000, batch_size = params['batch_size']) 76 | saveHDF5(savef+'-final.h5',savedata) 77 | print 'Experiment Name: <',params['expt_name'],'> Test Bound: ',savedata['bound_test_best'],' ',savedata['bound_tsbn_test_best'],' ',savedata['ll_test_best'] 78 | 79 | with open(params['dataset']+'-results.txt','a') as f: 80 | while True: 81 | try: 82 | fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) 83 | break 84 | except IOError as e: 85 | if e.errno != errno.EAGAIN: 86 | raise 87 | else: 88 | time.sleep(0.1) 89 | f.write('Experiment Name: <'+params['expt_name']+'> Test Bound: '+str(savedata['bound_test_best'])+' '+str(savedata['bound_tsbn_test_best'])+' '+str(savedata['ll_test_best'])+'\n') 90 | fcntl.flock(f, fcntl.LOCK_UN) 91 | 92 | if 'nyu.edu' in socket.gethostname(): 93 | import ipdb;ipdb.set_trace() 94 | -------------------------------------------------------------------------------- /expt-polyphonic/README.md: -------------------------------------------------------------------------------- 1 | ## Modeling Polyphonic Music with a Deep Kalman Filter 2 | Author: Rahul G. Krishnan (rahul@cs.nyu.edu) 3 | 4 | ## Reproducing Numbers in the Paper 5 | ``` 6 | Run [dataset]_expts.sh to run the experiments on the polyphonic datasets 7 | ``` 8 | 9 | ## Model Details 10 | DKF 11 | Standard Deep Kalman Filter. 
The default inference algorithm is set to be ST-R (see the paper for more details), although 12 | this can be modified through a variety of knobs, primarily the -inference_model and the -var_model hyperparameters. 13 | 14 | DKF NADE 15 | Use a NADE to model the data rather than a distribution that treats the dimensions of the data independently (can be used with the -usenade flag). 16 | 17 | DKF AUG 18 | Augmented DKF. 19 | The emission distribution of the generative model is parameterized as p(x_t | x_{t-1}, z_t) (toggled with the -previnp flag) 20 | and the transition distribution is parameterized as p(z_t | z_{t-1}) (can be activated with the -etype conditional flag). 21 |
-------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/mdata_MFLR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N mdata-MFLR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=0.9,scan.allow_gc=False,compiledir_format=compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-1" python train_dkf.py -vm LR -infm mean_field -dset musedata-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/mdata_STLR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N mdata-STLR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset musedata-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/mdata_STR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=30:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N mdata-ST-R 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset musedata-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/mdata_STR_MLP.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 
| #PBS -N mdata-STR-MLP 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset musedata-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/mdata_STR_ar.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N mdata-AR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset musedata-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/mdata_STR_ar_aug.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N mdata-ARAUG 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset musedata-sorted -bs 10 -dh 100 -ds 50 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/mdata_STR_ar_aug_nade.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N mdata-STRAUG 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset musedata-sorted -bs 10 -dh 100 -ds 50 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/nott_MFLR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N nott-MFLR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | 
module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=0.9,scan.allow_gc=False,compiledir_format=compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-1" python train_dkf.py -vm LR -infm mean_field -dset nottingham-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/nott_STLR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N nott-STLR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset nottingham-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/nott_STR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=30:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N nott-ST-R 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset nottingham-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/nott_STR_MLP.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N nott-STR-MLP 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset nottingham-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/nott_STR_ar.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N nott-AR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | 
module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset nottingham-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/nott_STR_ar_aug.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N nott-ARAUG 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset nottingham-sorted -bs 10 -dh 100 -ds 50 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/nott_STR_ar_aug_nade.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N nott-STRAUG 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset nottingham-sorted -bs 10 -dh 100 -ds 50 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/piano_MFLR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N piano-MFLR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=0.9,scan.allow_gc=False,compiledir_format=compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-1" python train_dkf.py -vm LR -infm mean_field -dset piano-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/piano_STLR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N piano-STLR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 
14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset piano-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/piano_STR.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=30:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N piano-ST-R 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset piano-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/piano_STR_MLP.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N piano-STR-MLP 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset piano-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/piano_STR_ar.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N piano-AR 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset piano-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/piano_STR_ar_aug.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N piano-ARAUG 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | 
THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset piano-sorted -bs 10 -dh 100 -ds 50 21 | -------------------------------------------------------------------------------- /expt-polyphonic/hpc_uas1/piano_STR_ar_aug_nade.q: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l nodes=1:ppn=2:gpus=1:k80 3 | #PBS -l walltime=24:00:00 4 | #PBS -l mem=16GB 5 | #PBS -N piano-STRAUG 6 | #PBS -M rahul@cs.nyu.edu 7 | #PBS -j oe 8 | 9 | module purge 10 | module load node 11 | module load cmake 12 | module load python/intel/2.7.6 13 | module load numpy/intel/1.9.2 14 | module load hdf5/intel/1.8.12 15 | module load cuda/7.5.18 16 | module load cudnn/7.0 17 | 18 | RUNDIR=$SCRATCH/structuredinference/expt-polyphonic 19 | cd $RUNDIR 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset piano-sorted -bs 10 -dh 100 -ds 50 21 | -------------------------------------------------------------------------------- /expt-polyphonic/jsb_expts.sh: -------------------------------------------------------------------------------- 1 | #MF-LR 2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm LR -infm mean_field -dset jsb-sorted 3 | #7.098 7.042 -6.6828 4 | #Look into this -> makes sense? Rerunning on rose2[0] 5 | 6 | #ST-R 7 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset jsb-sorted 8 | #7.0145 6.9564 -6.5863 9 | 10 | #ST-R -ttype mlp 11 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset jsb-sorted 12 | #7.0976 7.0421 -6.652 13 | 14 | #ST-LR 15 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset jsb-sorted 16 | # 17 | 18 | #DKF w/ AR 19 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset jsb-sorted 20 | #6.904 6.854 -6.4320 21 | 22 | #DKF Aug 23 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset jsb-sorted 24 | #6.867 6.7935 -6.3199 25 | 26 | #DKF Aug w/ NADE 27 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset jsb-sorted 28 | #5.398 5.325 -5.3809 29 | -------------------------------------------------------------------------------- /expt-polyphonic/musedata_expt.sh: -------------------------------------------------------------------------------- 1 | #MF-LR 2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm LR -infm mean_field -dset musedata-sorted 3 | 4 | #ST-R 5 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset musedata-sorted 6 | 7 | #ST-R -ttype mlp 8 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset musedata-sorted 9 | 10 | #ST-LR 11 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm 
structured -dset musedata-sorted 12 | 13 | #DKF w/ AR 14 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset musedata-sorted 15 | 16 | #DKF Aug 17 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset musedata-sorted 18 | 19 | #DKF Aug w/ NADE 20 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset musedata-sorted 21 | -------------------------------------------------------------------------------- /expt-polyphonic/nott_expt.sh: -------------------------------------------------------------------------------- 1 | #MF-LR 2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm LR -infm mean_field -dset nottingham-sorted 3 | #7.098 7.042 -6.6828 4 | #Look into this -> makes sense? Rerunning on rose2[0] 5 | 6 | #ST-R 7 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset nottingham-sorted 8 | #7.0145 6.9564 -6.5863 9 | 10 | #ST-R -ttype mlp 11 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset nottingham-sorted 12 | #7.0976 7.0421 -6.652 13 | 14 | #ST-LR 15 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset nottingham-sorted 16 | # 17 | 18 | #DKF w/ AR 19 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset nottingham-sorted 20 | #6.904 6.854 -6.4320 21 | 22 | #DKF Aug 23 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset nottingham-sorted 24 | #6.867 6.7935 -6.3199 25 | 26 | #DKF Aug w/ NADE 27 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset nottingham-sorted 28 | #5.398 5.325 -5.3809 29 | -------------------------------------------------------------------------------- /expt-polyphonic/piano_expt.sh: -------------------------------------------------------------------------------- 1 | #MF-LR 2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm LR -infm mean_field -dset piano-sorted 3 | #7.098 7.042 -6.6828 4 | #Look into this -> makes sense? 
Rerunning on rose2[0] 5 | 6 | #ST-R 7 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -dset piano-sorted 8 | #7.0145 6.9564 -6.5863 9 | 10 | #ST-R -ttype mlp 11 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ttype mlp -dset piano-sorted 12 | #7.0976 7.0421 -6.652 13 | 14 | #ST-LR 15 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm LR -infm structured -dset piano-sorted 16 | # 17 | 18 | #DKF w/ AR 19 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -dset piano-sorted 20 | #6.904 6.854 -6.4320 21 | 22 | #DKF Aug 23 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu1" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset piano-sorted 24 | #6.867 6.7935 -6.3199 25 | 26 | #DKF Aug w/ NADE 27 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 train_dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset piano-sorted 28 | #5.398 5.325 -5.3809 29 | -------------------------------------------------------------------------------- /expt-polyphonic/train_dkf.py: -------------------------------------------------------------------------------- 1 | import os,time,sys 2 | sys.path.append('../') 3 | from datasets.load import loadDataset 4 | from parse_args_dkf import params 5 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime,getLowestError 6 | 7 | if params['dataset']=='': 8 | params['dataset']='jsb' 9 | dataset = loadDataset(params['dataset']) 10 | params['savedir']+='-'+params['dataset'] 11 | createIfAbsent(params['savedir']) 12 | 13 | #Saving/loading 14 | for k in ['dim_observations','dim_actions','data_type']: 15 | params[k] = dataset[k] 16 | mapPrint('Options: ',params) 17 | 18 | if params['use_nade']: 19 | params['data_type']='binary_nade' 20 | #Setup VAE Model (or reload from existing savefile) 21 | start_time = time.time() 22 | from stinfmodel.dkf import DKF 23 | import stinfmodel.learning as DKF_learn 24 | import stinfmodel.evaluate as DKF_evaluate 25 | displayTime('import DKF',start_time, time.time()) 26 | dkf = None 27 | 28 | #Remove from params 29 | start_time = time.time() 30 | removeIfExists('./NOSUCHFILE') 31 | reloadFile = params.pop('reloadFile') 32 | if os.path.exists(reloadFile): 33 | pfile=params.pop('paramFile') 34 | assert os.path.exists(pfile),pfile+' not found. Need paramfile' 35 | print 'Reloading trained model from : ',reloadFile 36 | print 'Assuming ',pfile,' corresponds to model' 37 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile) 38 | else: 39 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl' 40 | print 'Training model from scratch. 
Parameters in: ',pfile
41 | dkf = DKF(params, paramFile = pfile)
42 | displayTime('Building dkf',start_time, time.time())
43 | 
44 | savef = os.path.join(params['savedir'],params['unique_id'])
45 | print 'Savefile: ',savef
46 | start_time= time.time()
47 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
48 | epoch_start =0 ,
49 | epoch_end = params['epochs'],
50 | batch_size = params['batch_size'],
51 | savefreq = params['savefreq'],
52 | savefile = savef,
53 | dataset_eval=dataset['valid'],
54 | mask_eval = dataset['mask_valid'],
55 | replicate_K= params['replicate_K'],
56 | shuffle = False
57 | )
58 | displayTime('Running DKF',start_time, time.time() )
59 | """
60 | Load the best DKF based on the validation error
61 | """
62 | epochMin, valMin, idxMin = getLowestError(savedata['valid_bound'])
63 | reloadFile= pfile.replace('-config.pkl','')+'-EP'+str(int(epochMin))+'-params.npz'
64 | print 'Loading from : ',reloadFile
65 | params['validate_only'] = True
66 | dkf_best = DKF(params, paramFile = pfile, reloadFile = reloadFile)
67 | additional = {}
68 | savedata['bound_test_best'] = DKF_evaluate.evaluateBound(dkf_best, dataset['test'], dataset['mask_test'], S = 2, batch_size = params['batch_size'], additional =additional)
69 | savedata['bound_tsbn_test_best'] = additional['tsbn_bound']
70 | savedata['ll_test_best'] = DKF_evaluate.impSamplingNLL(dkf_best, dataset['test'], dataset['mask_test'], S = 2000, batch_size = params['batch_size'])
71 | saveHDF5(savef+'-final.h5',savedata)
72 | print 'Test Bound: ',savedata['bound_test_best'],savedata['bound_tsbn_test_best'],savedata['ll_test_best']
73 | import ipdb;ipdb.set_trace()
74 | 
-------------------------------------------------------------------------------- /expt-synthetic-fast/create_expt.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | import itertools
5 | import argparse
6 | parser = argparse.ArgumentParser(description='Create experiments for synthetic data')
7 | parser.add_argument("-s",'--onScreen', action='store_true', help='create command to run on screen') #store_true: argparse's type=bool parses any non-empty string (even 'False') as True
8 | args = parser.parse_args()
9 | np.random.seed(1)
10 | ## Standard checks
11 | var_model = ['LR','L','R']
12 | inference_model = ['mean_field','structured']
13 | optimization = ['adam']
14 | lr = ['0.0008']
15 | nonlinearity = ['relu']
16 | datasets = ['synthetic9','synthetic10']
17 | dim_hidden = [40]
18 | rnn_layers = [1]
19 | rnn_size = [40]
20 | batch_size = [250]
21 | rnn_dropout = [0.00001]
22 | #Add the list and its name here
23 | takecrossprod = [inference_model,datasets,optimization, lr, var_model, nonlinearity, dim_hidden, rnn_layers, rnn_size, batch_size, rnn_dropout]
24 | names = ['infm','dset','opt','lr','vm','nl','dh','rl','rs','bs','rd']
25 | 
26 | def buildConfig(element,paramnames):
27 | config = {}
28 | for idx,paramvalue in enumerate(element):
29 | config[paramnames[idx]] = paramvalue
30 | return config
31 | def buildParamString(config):
32 | paramstr = ''
33 | for paramname in config:
34 | paramstr += '-'+paramname+' '+str(config[paramname])+' '
35 | return paramstr
36 | 
37 | tovis = {}
38 | for idx, element in enumerate(itertools.product(*takecrossprod)):
39 | config = buildConfig(element,names)
40 | #Fixed parameters
41 | config['ep'] = '1500'
42 | config['sfreq'] = '100'
43 | gpun = 1
44 | if idx%2==0:
45 | gpun = 2
46 | if 'lr' not in config:
47 | config['lr']='0.001'
48 | if 'opt' not in config:
49 | config['opt']='adam'
50 | 
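#Example of the flag string buildParamString produces for the first cross-product element
#(flag order follows Python 2 dict iteration, so it may differ in practice):
#  -infm mean_field -dset synthetic9 -opt adam -lr 0.0008 -vm LR -nl relu -dh 40 -rl 1 -rs 40 -bs 250 -rd 1e-05 -ep 1500 -sfreq 100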
paramstr= buildParamString(config) 52 | assert 'dset' in config and 'vm' in config and 'infm' in config, 'Expecting dataset, var_model and inference_model to be in config' 53 | #Savefile to look for in visualize 54 | savestr = 'dkf_'+config['dset']+'_ep'+config['ep']+'_rs20dh20ds1vm'+config['vm']+'lr'+config['lr'] 55 | savestr+= 'ep'+config['ep']+'opt'+config['opt'] 56 | savestr +='infm'+config['infm']+'synthetic' 57 | tovis[paramstr.replace(' ','')] = savestr 58 | 59 | cmd = "export CUDA_VISIBLE_DEVICES="+str(gpun-1)+";" 60 | if args.onScreen: 61 | randStr = 'THEANO_FLAGS="lib.cnmem=0.4,compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-'+str(np.random.randint(100))+'"' 62 | cmd += randStr+ " python2.7 train.py "+paramstr 63 | name = paramstr.replace(' ','').strip().replace('inference_model','').replace('var_model','').replace('optimization','') 64 | name = name.replace('-','') 65 | name = name.replace('synthetic','S').replace('nonlinearity','NL').replace('dataset','') 66 | print "screen -S "+name+" -t "+name+" -d -m" 67 | print "screen -r "+name+" -p 0 -X stuff $\'"+cmd+"\\n\'" 68 | #+"| tee ./checkpointdir_"+config['dataset']+'/'+savestr.replace('t7','log')+"\\n\'" 69 | else: 70 | cmd += "python2.7 train.py "+paramstr 71 | print cmd 72 | -------------------------------------------------------------------------------- /expt-synthetic-fast/run_baselines.py: -------------------------------------------------------------------------------- 1 | import sys,os,h5py,glob 2 | sys.path.append('../') 3 | from baselines.filters import runFilter 4 | import numpy as np 5 | from utils.misc import getPYDIR 6 | from datasets.synthp import params_synthetic 7 | 8 | def runBaselines(DIR, name): 9 | DATADIR = getPYDIR()+'/datasets/synthetic' 10 | assert os.path.exists(DATADIR),DATADIR+' not found. 
must have this to run baselines' 11 | if not os.path.exists(DIR): 12 | os.mkdir(DIR) 13 | 14 | for f in glob.glob(DATADIR+'/*.h5'): 15 | dataset = os.path.basename(f).replace('.h5','') 16 | if name not in dataset:#'synthetic' not in dataset or 'synthetic11' in dataset: 17 | continue 18 | print dataset,f 19 | if os.path.exists(DIR+'/'+dataset+'-baseline.h5'): 20 | print DIR+'/'+dataset+'-baseline.h5',' found....not rerunning baseline' 21 | continue 22 | print 'Reading from: ',f,' Saving to: ',DIR+'/'+dataset+'-baseline.h5' 23 | 24 | filterType = params_synthetic[dataset]['baseline'] 25 | h5fout = h5py.File(DIR+'/'+dataset+'-baseline.h5',mode='w') 26 | h5f = h5py.File(f,mode='r') 27 | 28 | if int(dataset.split('synthetic')[1]) in [9,10,11]: 29 | print 'Running filter: ',filterType,' on train' 30 | X = h5f['train'].value 31 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType) 32 | h5fout.create_dataset('train_mu',data = mus) 33 | h5fout.create_dataset('train_cov',data = cov) 34 | h5fout.create_dataset('train_ll',data = np.array([ll])) 35 | rmse = np.sqrt(np.square(mus-h5f['train_z'].value.squeeze()).mean()) 36 | h5fout.create_dataset('train_rmse',data = np.array([rmse])) 37 | 38 | #Always run exact inference on the validation set 39 | print 'Running filter: ',filterType,' on valid' 40 | X = h5f['valid'].value 41 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType) 42 | h5fout.create_dataset('valid_mu',data = mus) 43 | h5fout.create_dataset('valid_cov',data = cov) 44 | h5fout.create_dataset('valid_ll',data = np.array([ll])) 45 | rmse = np.sqrt(np.square(mus-h5f['valid_z'].value.squeeze()).mean()) 46 | h5fout.create_dataset('valid_rmse',data = np.array([rmse])) 47 | 48 | if int(dataset.split('synthetic')[1]) in [9,10,11]: 49 | print 'Running filter: ',filterType,' on test' 50 | X = h5f['test'].value 51 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType) 52 | h5fout.create_dataset('test_mu',data = mus) 53 | h5fout.create_dataset('test_cov',data = cov) 54 | h5fout.create_dataset('test_ll',data = np.array([ll])) 55 | rmse = np.sqrt(np.square(mus-h5f['test_z'].value.squeeze()).mean()) 56 | h5fout.create_dataset('test_rmse',data = np.array([rmse])) 57 | 58 | h5f.close() 59 | h5fout.close() 60 | if __name__=='__main__': 61 | assert len(sys.argv)==2,'expecting sname' 62 | runBaselines('./baselines',sys.argv[-1].strip()) 63 | -------------------------------------------------------------------------------- /expt-synthetic-fast/train.py: -------------------------------------------------------------------------------- 1 | import os,time,sys 2 | sys.path.append('../') 3 | import numpy as np 4 | from datasets.load import loadDataset 5 | from parse_args_dkf import params 6 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime 7 | 8 | 9 | if params['dataset']=='': 10 | params['dataset']='synthetic9' 11 | dataset = loadDataset(params['dataset']) 12 | 13 | dataset['train'] = dataset['train'][:params['ntrain']] 14 | params['savedir']+='-'+params['dataset'] 15 | createIfAbsent(params['savedir']) 16 | 17 | #Saving/loading 18 | for k in ['dim_observations','dim_actions','data_type', 'dim_stochastic']: 19 | params[k] = dataset[k] 20 | mapPrint('Options: ',params) 21 | 22 | #Setup VAE Model (or reload from existing savefile) 23 | start_time = time.time() 24 | from stinfmodel_fast.dkf import DKF 25 | import stinfmodel_fast.evaluate as DKF_evaluate 26 | import stinfmodel_fast.learning as DKF_learn 27 | displayTime('import DKF',start_time, 
time.time())
28 | dkf = None
29 | 
30 | #Remove from params
31 | start_time = time.time()
32 | removeIfExists('./NOSUCHFILE')
33 | reloadFile = params.pop('reloadFile')
34 | if os.path.exists(reloadFile):
35 | pfile=params.pop('paramFile')
36 | assert os.path.exists(pfile),pfile+' not found. Need paramfile'
37 | print 'Reloading trained model from : ',reloadFile
38 | print 'Assuming ',pfile,' corresponds to model'
39 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile)
40 | else:
41 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl'
42 | print 'Training model from scratch. Parameters in: ',pfile
43 | dkf = DKF(params, paramFile = pfile)
44 | displayTime('Building dkf',start_time, time.time())
45 | 
46 | 
47 | savef = os.path.join(params['savedir'],params['unique_id'])
48 | print 'Savefile: ',savef
49 | start_time= time.time()
50 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
51 | epoch_start =0 ,
52 | epoch_end = params['epochs'],
53 | batch_size = params['batch_size'],
54 | savefreq = params['savefreq'],
55 | savefile = savef,
56 | dataset_eval=dataset['valid'],
57 | mask_eval = dataset['mask_valid'],
58 | replicate_K = 5
59 | )
60 | displayTime('Running DKF',start_time, time.time())
61 | #Save the log file
62 | saveHDF5(savef+'-final.h5',savedata)
63 | #import ipdb;ipdb.set_trace()
64 | 
-------------------------------------------------------------------------------- /expt-synthetic/README.md: --------------------------------------------------------------------------------
1 | ## README for expt-synthetic
2 | Author: Rahul G. Krishnan (rahul@cs.nyu.edu)
3 | 
4 | ## Reproducing results from the paper
5 | This folder contains the scripts to reproduce the synthetic experiments.
6 | 
7 | Steps:
8 | 
9 | ```
10 | Run "python create_expt.py". This will yield a list of settings to run. Run them sequentially/in parallel.
11 | Run the ipython notebook in structuredinference/ipynb/synthetic to obtain the desired plots.
12 | ```
13 | ## Synthetic Datasets
14 | 
15 | The synthetic datasets are located in the theanomodels repository. See theanomodels/datasets/synthp.py
16 | * The file defines a dictionary with one entry for every synthetic dataset.
17 | * Each dataset's parameters are contained within sub-dictionaries. These include the initial mean/covariance of the transition distribution and the fixed covariance of the emission distribution, as well as the emission and transition functions. (params_synthetic['synthetic9'] is a dict with keys such as trans_fxn, emis_fxn and init_cov; trans_fxn and emis_fxn are pointers to functions.)
18 | * At model creation, this dictionary is embedded into the DKF and used to create the transition and emission functions for the generative model.
The DKF directly uses the theano implementations of the transition and emission functions in theanomodels/datasets/synthpTheano.py.
19 | 
-------------------------------------------------------------------------------- /expt-synthetic/create_expt.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | import itertools
5 | import argparse
6 | parser = argparse.ArgumentParser(description='Create experiments for synthetic data')
7 | parser.add_argument("-s",'--onScreen', action='store_true', help='create command to run on screen') #store_true: argparse's type=bool parses any non-empty string (even 'False') as True
8 | args = parser.parse_args()
9 | np.random.seed(1)
10 | ## Standard checks
11 | var_model = ['LR','L','R']
12 | inference_model = ['mean_field','structured']
13 | optimization = ['adam']
14 | lr = ['0.0008']
15 | nonlinearity = ['relu']
16 | datasets = ['synthetic9','synthetic10']
17 | dim_hidden = [40]
18 | rnn_layers = [1]
19 | rnn_size = [40]
20 | batch_size = [250]
21 | rnn_dropout = [0.00001]
22 | #Add the list and its name here
23 | takecrossprod = [inference_model,datasets,optimization, lr, var_model, nonlinearity, dim_hidden, rnn_layers, rnn_size, batch_size, rnn_dropout]
24 | names = ['infm','dset','opt','lr','vm','nl','dh','rl','rs','bs','rd']
25 | 
26 | def buildConfig(element,paramnames):
27 | config = {}
28 | for idx,paramvalue in enumerate(element):
29 | config[paramnames[idx]] = paramvalue
30 | return config
31 | def buildParamString(config):
32 | paramstr = ''
33 | for paramname in config:
34 | paramstr += '-'+paramname+' '+str(config[paramname])+' '
35 | return paramstr
36 | 
37 | tovis = {}
38 | for idx, element in enumerate(itertools.product(*takecrossprod)):
39 | config = buildConfig(element,names)
40 | #Fixed parameters
41 | config['ep'] = '1500'
42 | config['sfreq'] = '100'
43 | gpun = 1
44 | if idx%2==0:
45 | gpun = 2
46 | if 'lr' not in config:
47 | config['lr']='0.001'
48 | if 'opt' not in config:
49 | config['opt']='adam'
50 | 
51 | paramstr= buildParamString(config)
52 | assert 'dset' in config and 'vm' in config and 'infm' in config, 'Expecting dataset, var_model and inference_model to be in config'
53 | #Savefile to look for in visualize
54 | savestr = 'dkf_'+config['dset']+'_ep'+config['ep']+'_rs20dh20ds1vm'+config['vm']+'lr'+config['lr']
55 | savestr+= 'ep'+config['ep']+'opt'+config['opt']
56 | savestr +='infm'+config['infm']+'synthetic'
57 | tovis[paramstr.replace(' ','')] = savestr
58 | 
59 | cmd = "export CUDA_VISIBLE_DEVICES="+str(gpun-1)+";"
60 | if args.onScreen:
61 | randStr = 'THEANO_FLAGS="lib.cnmem=0.4,compiledir_format=compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s-'+str(np.random.randint(100))+'"'
62 | cmd += randStr+ " python2.7 train.py "+paramstr
63 | name = paramstr.replace(' ','').strip().replace('inference_model','').replace('var_model','').replace('optimization','')
64 | name = name.replace('-','')
65 | name = name.replace('synthetic','S').replace('nonlinearity','NL').replace('dataset','')
66 | print "screen -S "+name+" -t "+name+" -d -m"
67 | print "screen -r "+name+" -p 0 -X stuff $\'"+cmd+"\\n\'"
68 | #+"| tee ./checkpointdir_"+config['dataset']+'/'+savestr.replace('t7','log')+"\\n\'"
69 | else:
70 | cmd += "python2.7 train.py "+paramstr
71 | print cmd
72 | 
-------------------------------------------------------------------------------- /expt-synthetic/run_baselines.py: --------------------------------------------------------------------------------
1 | import sys,os,h5py,glob
2 | 
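#Runs the exact-inference baselines (the per-dataset filter named in params_synthetic,
#e.g. a Kalman filter for the linear datasets) on each synthetic H5 file and caches the
#posterior means/covariances, log-likelihood and RMSE against the true latent states.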
sys.path.append('../') 3 | from baselines.filters import runFilter 4 | import numpy as np 5 | from utils.misc import getPYDIR 6 | from datasets.synthp import params_synthetic 7 | 8 | def runBaselines(DIR): 9 | DATADIR = getPYDIR()+'/datasets/synthetic' 10 | assert os.path.exists(DATADIR),DATADIR+' not found. must have this to run baselines' 11 | if not os.path.exists(DIR): 12 | os.mkdir(DIR) 13 | 14 | for f in glob.glob(DATADIR+'/*.h5'): 15 | dataset = os.path.basename(f).replace('.h5','') 16 | print 'Reading from: ',f,' Saving to: ',DIR+'/'+dataset+'-baseline.h5' 17 | 18 | filterType = params_synthetic[dataset]['baseline'] 19 | h5fout = h5py.File(DIR+'/'+dataset+'-baseline.h5',mode='w') 20 | h5f = h5py.File(f,mode='r') 21 | 22 | print 'Running filter: ',filterType,' on train' 23 | X = h5f['train'].value 24 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType) 25 | h5fout.create_dataset('train_mu',data = mus) 26 | h5fout.create_dataset('train_cov',data = cov) 27 | h5fout.create_dataset('train_ll',data = np.array([ll])) 28 | rmse = np.sqrt(np.square(mus-h5f['train_z'].value.squeeze()).mean()) 29 | h5fout.create_dataset('train_rmse',data = np.array([rmse])) 30 | 31 | print 'Running filter: ',filterType,' on valid' 32 | X = h5f['valid'].value 33 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType) 34 | h5fout.create_dataset('valid_mu',data = mus) 35 | h5fout.create_dataset('valid_cov',data = cov) 36 | h5fout.create_dataset('valid_ll',data = np.array([ll])) 37 | rmse = np.sqrt(np.square(mus-h5f['valid_z'].value.squeeze()).mean()) 38 | h5fout.create_dataset('valid_rmse',data = np.array([rmse])) 39 | 40 | print 'Running filter: ',filterType,' on test' 41 | X = h5f['test'].value 42 | mus,cov,ll = runFilter(X, params_synthetic, dataset, filterType) 43 | h5fout.create_dataset('test_mu',data = mus) 44 | h5fout.create_dataset('test_cov',data = cov) 45 | h5fout.create_dataset('test_ll',data = np.array([ll])) 46 | rmse = np.sqrt(np.square(mus-h5f['test_z'].value.squeeze()).mean()) 47 | h5fout.create_dataset('test_rmse',data = np.array([rmse])) 48 | 49 | h5f.close() 50 | h5fout.close() 51 | if __name__=='__main__': 52 | runBaselines('./baselines') 53 | -------------------------------------------------------------------------------- /expt-synthetic/train.py: -------------------------------------------------------------------------------- 1 | import os,time,sys 2 | sys.path.append('../') 3 | import numpy as np 4 | from datasets.load import loadDataset 5 | from parse_args_dkf import params 6 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime 7 | 8 | params['dim_stochastic'] = 1 9 | 10 | if params['dataset']=='': 11 | params['dataset']='synthetic9' 12 | dataset = loadDataset(params['dataset']) 13 | params['savedir']+='-'+params['dataset'] 14 | createIfAbsent(params['savedir']) 15 | 16 | #Saving/loading 17 | for k in ['dim_observations','dim_actions','data_type']: 18 | params[k] = dataset[k] 19 | mapPrint('Options: ',params) 20 | 21 | 22 | #Setup VAE Model (or reload from existing savefile) 23 | start_time = time.time() 24 | from stinfmodel.dkf import DKF 25 | import stinfmodel.evaluate as DKF_evaluate 26 | import stinfmodel.learning as DKF_learn 27 | displayTime('import DKF',start_time, time.time()) 28 | dkf = None 29 | 30 | #Remove from params 31 | start_time = time.time() 32 | removeIfExists('./NOSUCHFILE') 33 | reloadFile = params.pop('reloadFile') 34 | if os.path.exists(reloadFile): 35 | pfile=params.pop('paramFile') 36 | assert 
os.path.exists(pfile),pfile+' not found. Need paramfile'
37 | print 'Reloading trained model from : ',reloadFile
38 | print 'Assuming ',pfile,' corresponds to model'
39 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile)
40 | else:
41 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl'
42 | print 'Training model from scratch. Parameters in: ',pfile
43 | dkf = DKF(params, paramFile = pfile)
44 | displayTime('Building dkf',start_time, time.time())
45 | 
46 | 
47 | savef = os.path.join(params['savedir'],params['unique_id'])
48 | print 'Savefile: ',savef
49 | start_time= time.time()
50 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
51 | epoch_start =0 ,
52 | epoch_end = params['epochs'],
53 | batch_size = params['batch_size'],
54 | savefreq = params['savefreq'],
55 | savefile = savef,
56 | dataset_eval=dataset['valid'],
57 | mask_eval = dataset['mask_valid'],
58 | replicate_K = 5
59 | )
60 | displayTime('Running DKF',start_time, time.time())
61 | #Save the log file
62 | saveHDF5(savef+'-final.h5',savedata)
63 | 
64 | #On the validation set, estimate the MSE
65 | def estimateMSE(dkf):
66 | assert 'synthetic' in dkf.params['dataset'],'Only valid for synthetic data'
67 | 
68 | allmus,alllogcov=[],[]
69 | 
70 | for s in range(50):
71 | _,mus, logcov = DKF_evaluate.infer(dkf,dataset['valid'])
72 | allmus.append(mus)
73 | alllogcov.append(logcov)
74 | mean_mus = np.concatenate([mus[:,:,:,None] for mus in allmus],axis=3).mean(3)
75 | print 'Validation MSE: ',np.square(mean_mus-dataset['valid_z']).mean()
76 | corrlist = []
77 | for n in range(mean_mus.shape[0]):
78 | corrlist.append(np.corrcoef(mean_mus[n,:].ravel(),dataset['valid_z'][n,:].ravel())[0,1])
79 | print 'Validation Correlation with True Zs: ',np.mean(corrlist)
80 | #import ipdb;ipdb.set_trace()
81 | 
-------------------------------------------------------------------------------- /expt-template/README.md: --------------------------------------------------------------------------------
1 | ## Template for running code
2 | Author: Rahul G. Krishnan (rahul@cs.nyu.edu)
3 | 
4 | This contains a simple, commented example of loading data and running the DKF on random data.
5 | The file load.py contains an example of how to load data; the file train.py is a commented
6 | example of how to load the models and run the training script.
7 | 
-------------------------------------------------------------------------------- /expt-template/load.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | """
3 | Create a fake time-series dataset
4 | """
5 | def loadDataset():
6 | dataset = {}
7 | Ntrain, Nvalid, Ntest = 3000,100,100
8 | T , dim_observations = 100,40
9 | dataset['train'] = (np.random.randn(Ntrain,T, dim_observations)>0)*1.
10 | dataset['mask_train'] = np.ones((Ntrain,T))
11 | dataset['valid'] = (np.random.randn(Nvalid,T,dim_observations)>0)*1.
12 | dataset['mask_valid'] = np.ones((Nvalid,T))
13 | dataset['test'] = (np.random.randn(Ntest,T,dim_observations)>0)*1.
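#(the mask arrays mark which timesteps of each sequence are observed; the fake
# sequences are all full length T, so every mask entry is 1)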
14 | dataset['mask_test'] = np.ones((Ntest,T)) 15 | dataset['dim_observations'] = dim_observations 16 | dataset['data_type'] = 'binary' 17 | return dataset 18 | -------------------------------------------------------------------------------- /expt-template/runme.sh: -------------------------------------------------------------------------------- 1 | python2.7 train.py -vm R -infm structured -ds 10 -dh 50 2 | -------------------------------------------------------------------------------- /expt-template/train.py: -------------------------------------------------------------------------------- 1 | import os,time,sys 2 | 3 | """ Add the higher level directory to PYTHONPATH to be able to access the models """ 4 | sys.path.append('../') 5 | 6 | """ Change this to modify the loadDataset function """ 7 | from load import loadDataset 8 | 9 | """ 10 | This will contain a hashmap where the 11 | parameters correspond to the default ones modified 12 | by any command line options given to this script 13 | """ 14 | from parse_args_dkf import params 15 | 16 | """ Some utility functions from theanomodels """ 17 | from utils.misc import removeIfExists,createIfAbsent,mapPrint,saveHDF5,displayTime 18 | 19 | 20 | """ Load the dataset into a hashmap. See load.py for details """ 21 | dataset = loadDataset() 22 | params['savedir']+='-template' 23 | createIfAbsent(params['savedir']) 24 | 25 | """ Add dataset and NADE parameters to "params" 26 | which will become part of the model 27 | """ 28 | for k in ['dim_observations','data_type']: 29 | params[k] = dataset[k] 30 | mapPrint('Options: ',params) 31 | if params['use_nade']: 32 | params['data_type']='binary_nade' 33 | 34 | """ 35 | import DKF + learn/evaluate functions 36 | """ 37 | start_time = time.time() 38 | from stinfmodel.dkf import DKF 39 | import stinfmodel.learning as DKF_learn 40 | import stinfmodel.evaluate as DKF_evaluate 41 | displayTime('import DKF',start_time, time.time()) 42 | dkf = None 43 | 44 | #Remove from params 45 | start_time = time.time() 46 | removeIfExists('./NOSUCHFILE') 47 | reloadFile = params.pop('reloadFile') 48 | """ Reload parameters if reloadFile exists otherwise setup model from scratch 49 | and initialize parameters randomly. 50 | """ 51 | if os.path.exists(reloadFile): 52 | pfile=params.pop('paramFile') 53 | """ paramFile is set inside the BaseClass in theanomodels 54 | to point to the pickle file containing params""" 55 | assert os.path.exists(pfile),pfile+' not found. Need paramfile' 56 | print 'Reloading trained model from : ',reloadFile 57 | print 'Assuming ',pfile,' corresponds to model' 58 | dkf = DKF(params, paramFile = pfile, reloadFile = reloadFile) 59 | else: 60 | pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl' 61 | print 'Training model from scratch. 
Parameters in: ',pfile 62 | dkf = DKF(params, paramFile = pfile) 63 | displayTime('Building dkf',start_time, time.time()) 64 | 65 | """Set save prefix""" 66 | savef = os.path.join(params['savedir'],params['unique_id']) 67 | print 'Savefile: ',savef 68 | start_time= time.time() 69 | 70 | """Learn the model (see stinfmodel/learning.py)""" 71 | savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'], 72 | epoch_start =0 , 73 | epoch_end = params['epochs'], 74 | batch_size = params['batch_size'], 75 | savefreq = params['savefreq'], 76 | savefile = savef, 77 | dataset_eval=dataset['valid'], 78 | mask_eval = dataset['mask_valid'], 79 | replicate_K= params['replicate_K'], 80 | shuffle = False 81 | ) 82 | displayTime('Running DKF',start_time, time.time()) 83 | 84 | """ Evaluate bound on test set (see stinfmodel/evaluate.py)""" 85 | savedata['bound_test'] = DKF_evaluate.evaluateBound(dkf, dataset['test'], dataset['mask_test'], 86 | batch_size = params['batch_size']) 87 | saveHDF5(savef+'-final.h5',savedata) 88 | print 'Test Bound: ',savedata['bound_test'] 89 | import ipdb;ipdb.set_trace() 90 | -------------------------------------------------------------------------------- /images/ELBO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/images/ELBO.png -------------------------------------------------------------------------------- /images/dkf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/images/dkf.png -------------------------------------------------------------------------------- /ipynb/synthetic/VisualizeSyntheticScaling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Visualize the scalability as a function of dimension in the linear case" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 6, 13 | "metadata": { 14 | "collapsed": false 15 | }, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Found: /data/ml2/rahul/theanomodels/datasets/synthetic//synthetic18.h5\n", 22 | "Found: /data/ml2/rahul/theanomodels/datasets/synthetic//synthetic19.h5\n", 23 | "Found: /data/ml2/rahul/theanomodels/datasets/synthetic//synthetic20.h5\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import numpy as np\n", 29 | "%matplotlib inline \n", 30 | "import glob,h5py,os,re\n", 31 | "import matplotlib.pyplot as plt\n", 32 | "import matplotlib as mpl\n", 33 | "mpl.rcParams['lines.linewidth']=2.5\n", 34 | "mpl.rcParams['lines.markersize']=8\n", 35 | "mpl.rcParams['text.usetex']=True\n", 36 | "mpl.rcParams['text.latex.unicode']=True\n", 37 | "mpl.rcParams['font.family'] = 'serif'\n", 38 | "mpl.rcParams['font.serif'] = 'Times New Roman'\n", 39 | "mpl.rcParams['text.latex.preamble']= ['\\usepackage{amsfonts}','\\usepackage{amsmath}']\n", 40 | "mpl.rcParams['font.size'] = 20\n", 41 | "mpl.rcParams['axes.labelsize']=20\n", 42 | "mpl.rcParams['legend.fontsize']=20\n", 43 | "#http://stackoverflow.com/questions/22408237/named-colors-in-matplotlib\n", 44 | "import cPickle as pickle\n", 45 | "from datasets.synthp import params_synthetic\n", 46 | "from utils.misc import getPYDIR,getConfigFile,readPickle, loadHDF5\n", 47 | "\n", 48 | 
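"#(getCode below abbreviates a run's configuration: structured->ST, mean_field->MF, joined with the variational model L/R/LR)\n",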
"#visualize synthetic results\n", 49 | "def getCode(f):\n", 50 | " params = readPickle(getConfigFile(f))[0]\n", 51 | " code = params['inference_model']+'-'+params['var_model']\n", 52 | " code = code.replace('structured','ST').replace('mean_field','MF')\n", 53 | " return code\n", 54 | "results = {}\n", 55 | "results[10] = {}\n", 56 | "results[10]['name'] = 'synthetic15'\n", 57 | "results[10]['maxEPOCH'] = 200\n", 58 | "results[10]['dim'] = 10\n", 59 | "results[100] = {}\n", 60 | "results[100]['name'] = 'synthetic16'\n", 61 | "results[100]['maxEPOCH'] = 200\n", 62 | "results[100]['dim'] = 100\n", 63 | "results[250] = {}\n", 64 | "results[250]['name'] = 'synthetic27'\n", 65 | "results[250]['maxEPOCH'] = 200\n", 66 | "results[250]['dim'] = 250\n", 67 | "\n", 68 | "ntrainlist = [10,100,1000]\n", 69 | "SAVEDIR = '../../expt-synthetic-fast/chkpt-'\n", 70 | "\n", 71 | "DATADIR = getPYDIR()+'/datasets/synthetic/'\n", 72 | "\n", 73 | "baselines = {}\n", 74 | "for dset in results:\n", 75 | " bline = '../../expt-synthetic-fast/baselines/'+results[dset]['name']+'-baseline.h5'\n", 76 | " if os.path.exists(bline):\n", 77 | " bb = loadHDF5(bline)\n", 78 | " print bb.keys()\n", 79 | " else:\n", 80 | " bb = None\n", 81 | " baselines[dset] = bb\n", 82 | "\n", 83 | "\n", 84 | "from datasets.load import loadDataset\n", 85 | "datasets = {}\n", 86 | "dset_params = {}\n", 87 | "for dset in results: \n", 88 | " datasets[dset] = loadDataset(results[dset]['name'])\n", 89 | " dset_params[dset] = params_synthetic[results[dset]['name']]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 9, 95 | "metadata": { 96 | "collapsed": false 97 | }, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "10 10 (500, 25, 10) (500, 25, 10)\n", 104 | "0.981814049597\n", 105 | "100 10 (500, 25, 10) (500, 25, 10)\n", 106 | "0.981765100916\n", 107 | "1000 10 (500, 25, 10) (500, 25, 10)\n", 108 | "0.981760031913\n", 109 | "10 100 (500, 25, 100) (500, 25, 100)\n", 110 | "0.983456427193\n", 111 | "1000 100 (500, 25, 100) (500, 25, 100)\n", 112 | "0.983420357791\n", 113 | "100 100 (500, 25, 100) (500, 25, 100)\n", 114 | "0.983437612098\n", 115 | "100 250 (500, 25, 250) (500, 25, 250)\n", 116 | "0.982943434674\n", 117 | "10 250 (500, 25, 250) (500, 25, 250)\n", 118 | "0.982955749653\n", 119 | "1000 250 (500, 25, 250) (500, 25, 250)\n", 120 | "0.982936344814\n", 121 | "9 ignored\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "#Visualizing epoch\n", 127 | "getepoch = re.compile(\"-EP(.*)-\")\n", 128 | "\n", 129 | "def estimateMSE(mu_posterior, true_z):\n", 130 | " err_sum = np.square(mu_posterior-true_z).sum()\n", 131 | " return np.sqrt(np.square(mu_posterior-true_z.squeeze()).mean())\n", 132 | "\n", 133 | "final_result = {}\n", 134 | "ignored = 0\n", 135 | "for dset in results:\n", 136 | " DIR = SAVEDIR+results[dset]['name']+'/'\n", 137 | " for f in glob.glob(DIR+'*-final.h5'):\n", 138 | " code = getCode(f)\n", 139 | " if code!='ST-R':\n", 140 | " ignored +=1\n", 141 | " continue\n", 142 | " params = readPickle(getConfigFile(f))[0]\n", 143 | " Ntrain, dimstoc = params['ntrain'],results[dset]['dim']\n", 144 | " alldata = loadHDF5(f)\n", 145 | " #print f,alldata.keys()\n", 146 | " valid_mus= alldata['mu_posterior_valid'][-1]\n", 147 | " valid_zs = datasets[dimstoc]['valid_z']\n", 148 | " print Ntrain,dimstoc, valid_mus.shape, valid_zs.shape\n", 149 | " rmse_train = estimateMSE(valid_mus,valid_zs)\n", 150 | " print rmse_train\n", 151 | " 
final_result['dim'+str(dimstoc)+'ntrain'+str(Ntrain)]= rmse_train\n",
152 | "print ignored,' ignored'"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 10,
158 | "metadata": {
159 | "collapsed": false
160 | },
161 | "outputs": [
162 | {
163 | "data": {
164 | "text/plain": [
165 | "(0, 6)"
166 | ]
167 | },
168 | "execution_count": 10,
169 | "metadata": {},
170 | "output_type": "execute_result"
171 | },
172 | {
173 | "data": {
174 | "image/png": "(base64-encoded PNG omitted: grouped bar chart of validation RMSE for N=10/100/1000 and the KF baseline at latent dimensions |z|=10, 100, 250)",
175 | "text/plain": [
176 | ""
177 | ]
178 | },
179 | "metadata": {},
180 | 
"output_type": "display_data" 181 | } 182 | ], 183 | "source": [ 184 | "N = 3\n", 185 | "ind = np.arange(N) # the x locations for the groups\n", 186 | "width = 0.9 # the width of the bars\n", 187 | "\n", 188 | "fig, ax = plt.subplots()\n", 189 | "train10 = (final_result['dim10ntrain10'], final_result['dim100ntrain10'], final_result['dim250ntrain10'])\n", 190 | "rects1 = ax.bar(ind, train10, width/4, color='r', hatch=\"/\")\n", 191 | "\n", 192 | "train100 = (final_result['dim10ntrain100'], final_result['dim100ntrain100'], final_result['dim250ntrain100'])\n", 193 | "rects2 = ax.bar(ind + width/4, train100, width/4, color='b', hatch=\"\\\\\")\n", 194 | "\n", 195 | "train1000 = (final_result['dim10ntrain1000'], final_result['dim100ntrain1000'], final_result['dim250ntrain1000'])\n", 196 | "rects3 = ax.bar(ind + 2*width/4, train1000, width/4, color='g', hatch=\"*\")\n", 197 | "\n", 198 | "if baselines[10] is None:\n", 199 | " KF = (0,0,0)\n", 200 | "else:\n", 201 | " KF = (baselines[10]['valid_rmse'].sum(), baselines[100]['valid_rmse'].sum(), baselines[250]['valid_rmse'].sum())\n", 202 | "rects4 = ax.bar(ind + 3*width/4, KF, width/4, color='y', hatch=\"//\")\n", 203 | "\n", 204 | "# add some text for labels, title and axes ticks\n", 205 | "ax.set_xlabel('Latent Dimension')\n", 206 | "ax.set_ylabel('RMSE')\n", 207 | "ax.set_xticks(ind + width/2)\n", 208 | "ax.set_xticklabels(('$|z|=10$', '$|z|=100$', '$|z|=250$'))\n", 209 | "\n", 210 | "ax.legend((rects1[0], rects2[0], rects3[0], rects4[0]), ('$N=10$', '$N=100$', '$N=1000$', 'KF'),\n", 211 | " loc='upper center', bbox_to_anchor=(0.5, 1.3),ncol=2, frameon=False)\n", 212 | "def autolabel(rects):\n", 213 | " # attach some text labels\n", 214 | " for idx, rect in enumerate(rects):\n", 215 | " height = rect.get_height()\n", 216 | " ax.text(rect.get_x() + rect.get_width()/2., 1.05*height,\n", 217 | " '%.1f' % height,\n", 218 | " ha='center', va='bottom',size=18)\n", 219 | "autolabel(rects1)\n", 220 | "autolabel(rects2)\n", 221 | "autolabel(rects3)\n", 222 | "autolabel(rects4)\n", 223 | "autolabel\n", 224 | "plt.ylim(0,6)\n", 225 | "\n", 226 | "#plt.savefig('scaling-exact-inference.pdf',bbox_inches='tight')" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "collapsed": true 234 | }, 235 | "outputs": [], 236 | "source": [] 237 | } 238 | ], 239 | "metadata": { 240 | "kernelspec": { 241 | "display_name": "Python 2", 242 | "language": "python", 243 | "name": "python2" 244 | }, 245 | "language_info": { 246 | "codemirror_mode": { 247 | "name": "ipython", 248 | "version": 2 249 | }, 250 | "file_extension": ".py", 251 | "mimetype": "text/x-python", 252 | "name": "python", 253 | "nbconvert_exporter": "python", 254 | "pygments_lexer": "ipython2", 255 | "version": "2.7.3" 256 | } 257 | }, 258 | "nbformat": 4, 259 | "nbformat_minor": 0 260 | } 261 | -------------------------------------------------------------------------------- /paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/paper.pdf -------------------------------------------------------------------------------- /parse_args_dkf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parse command line and store result in params 3 | Model : DKF 4 | """ 5 | import argparse,copy 6 | from collections import OrderedDict 7 | p = 
argparse.ArgumentParser(description="Arguments for variational autoencoder")
8 | parser = p #reuse the parser defined above rather than shadowing it with a second, undescribed one
9 | parser.add_argument('-dset','--dataset', action='store',default = '', help='Dataset', type=str)
10 | 
11 | #Recognition Model
12 | parser.add_argument('-rl','--rnn_layers', action='store',default = 1, help='Number of layers in the RNN', type=int, choices=[1,2])
13 | parser.add_argument('-rs','--rnn_size', action='store',default = 600, help='Hidden unit size in q model/RNN', type=int)
14 | parser.add_argument('-rd','--rnn_dropout', action='store',default = 0.1, help='Dropout after each RNN output layer', type=float)
15 | parser.add_argument('-vm','--var_model', action='store',default = 'R', help='Variational Model', type=str, choices=['L','LR','R'])
16 | parser.add_argument('-infm','--inference_model', action='store',default = 'structured', help='Inference Model', type=str, choices=['mean_field','structured'])
17 | parser.add_argument('-ql','--q_mlp_layers', action='store',default = 1, help='#Layers in Recognition Model', type=int)
18 | parser.add_argument('-useprior','--use_generative_prior', action='store_true', help='Use generative prior in inference network')
19 | 
20 | #Generative model
21 | parser.add_argument('-ds','--dim_stochastic', action='store',default = 100, help='Stochastic dimensions', type=int)
22 | parser.add_argument('-dh','--dim_hidden', action='store', default = 200, help='Hidden dimensions in DKF', type=int)
23 | parser.add_argument('-tl','--transition_layers', action='store', default = 2, help='Layers in transition fxn', type=int)
24 | parser.add_argument('-ttype','--transition_type', action='store', default = 'simple_gated', help='Type of transition fxn', type=str, choices=['mlp','simple_gated'])
25 | parser.add_argument('-previnp','--use_prev_input', action='store_true', help='Use previous input in transition')
26 | parser.add_argument('-usenade','--use_nade', action='store_true', help='Use NADE')
27 | 
28 | parser.add_argument('-el','--emission_layers', action='store',default = 2, help='Layers in emission fxn', type=int)
29 | parser.add_argument('-etype','--emission_type', action='store',default = 'mlp', help='Type of emission fxn', type=str, choices=['mlp','conditional','res'])
30 | 
31 | #Weights and Nonlinearity
32 | parser.add_argument('-iw','--init_weight', action='store',default = 0.1, help='Range to initialize weights during learning',type=float)
33 | parser.add_argument('-ischeme','--init_scheme', action='store',default = 'uniform', help='Type of initialization for weights', type=str, choices=['uniform','normal','xavier','he','orthogonal'])
34 | parser.add_argument('-nl','--nonlinearity', action='store',default = 'relu', help='Nonlinearity',type=str, choices=['relu','tanh','softplus','maxout','elu'])
35 | parser.add_argument('-lky','--leaky_param', action='store',default =0., help='Leaky ReLU parameter',type=float)
36 | parser.add_argument('-mstride','--maxout_stride', action='store',default = 4, help='Stride for maxout',type=int)
37 | parser.add_argument('-fg','--forget_bias', action='store',default = -5., help='Bias for forget gates', type=float)
38 | 
39 | parser.add_argument('-vonly','--validate_only', action='store_true', help='Only build fxn for validation')
40 | 
41 | #Optimization
42 | parser.add_argument('-lr','--lr', action='store',default = 8e-4, help='Learning rate', type=float)
43 | parser.add_argument('-opt','--optimizer', action='store',default = 'adam', help='Optimizer',choices=['adam','rmsprop'])
44 | 
parser.add_argument('-bs','--batch_size', action='store',default = 20, help='Batch Size',type=int) 45 | parser.add_argument('-ar','--anneal_rate', action='store',default = 10., help='Number of param. updates before anneal=1',type=float) 46 | parser.add_argument('-repK','--replicate_K', action='store',default = None, help='Number of samples used for the variational bound. Created by replicating the batch',type=int) 47 | parser.add_argument('-shuf','--shuffle', action='store_true',help='Shuffle during training') 48 | parser.add_argument('-covexp','--cov_explicit', action='store_true',help='Explicitly parameterize covariance') 49 | parser.add_argument('-nt','--ntrain', action='store',type=int,default=5000,help='number of training') 50 | 51 | #Regularization 52 | parser.add_argument('-reg','--reg_type', action='store',default = 'l2', help='Type of regularization',type=str,choices=['l1','l2']) 53 | parser.add_argument('-rv','--reg_value', action='store',default = 0.05, help='Amount of regularization',type=float) 54 | parser.add_argument('-rspec','--reg_spec', action='store',default = '_', help='String to match parameters (Default is generative model)',type=str) 55 | 56 | #Save/load 57 | parser.add_argument('-debug','--debug', action='store_true',help='Debug') 58 | parser.add_argument('-uid','--unique_id', action='store',default = 'uid',help='Unique Identifier',type=str) 59 | parser.add_argument('-seed','--seed', action='store',default = 1, help='Random Seed',type=int) 60 | parser.add_argument('-dir','--savedir', action='store',default = './chkpt', help='Prefix for savedir',type=str) 61 | parser.add_argument('-ep','--epochs', action='store',default = 2000, help='MaxEpochs',type=int) 62 | parser.add_argument('-reload','--reloadFile', action='store',default = './NOSUCHFILE', help='Reload from saved model',type=str) 63 | parser.add_argument('-params','--paramFile', action='store',default = './NOSUCHFILE', help='Reload parameters from saved model',type=str) 64 | parser.add_argument('-sfreq','--savefreq', action='store',default = 25, help='Frequency of saving',type=int) 65 | params = vars(parser.parse_args()) 66 | 67 | hmap = OrderedDict() 68 | hmap['lr']='lr' 69 | hmap['var_model']='vm' 70 | hmap['inference_model']='inf' 71 | hmap['dim_hidden']='dh' 72 | hmap['dim_stochastic']='ds' 73 | hmap['nonlinearity']='nl' 74 | hmap['batch_size']='bs' 75 | hmap['epochs']='ep' 76 | hmap['rnn_size']='rs' 77 | hmap['transition_type']='ttype' 78 | hmap['emission_type']='etype' 79 | hmap['use_prev_input']='previnp' 80 | hmap['anneal_rate']='ar' 81 | hmap['reg_value']='rv' 82 | hmap['use_nade']='nade' 83 | hmap['ntrain']='nt' 84 | combined = '' 85 | for k in hmap: 86 | if k in params: 87 | if type(params[k]) is float: 88 | combined+=hmap[k]+'-'+('%.4e')%(params[k])+'-' 89 | else: 90 | combined+=hmap[k]+'-'+str(params[k])+'-' 91 | 92 | params['expt_name'] = params['unique_id'] 93 | params['unique_id'] = combined[:-1]+'-'+params['unique_id'] 94 | params['unique_id'] = 'DKF_'+params['unique_id'].replace('.','_') 95 | """ 96 | import cPickle as pickle 97 | with open('default.pkl','wb') as f: 98 | pickle.dump(params,f) 99 | """ 100 | -------------------------------------------------------------------------------- /polyphonic_samples/jsb0.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/jsb0.mp3 
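(A note on parse_args_dkf.py above: the following minimal Python 2 sketch, which is not part of the repository, shows how the unique_id checkpoint prefix is assembled from the parsed arguments; the values shown are the CLI defaults.)

```
#Sketch: rebuild the DKF unique_id from a params dict (illustrative only)
params = {'lr': 8e-4, 'var_model': 'R', 'inference_model': 'structured', 'unique_id': 'uid'}
hmap = [('lr','lr'), ('var_model','vm'), ('inference_model','inf')] #abbreviations, in order
combined = ''
for k, abbrev in hmap:
    v = params[k]
    combined += abbrev+'-'+(('%.4e')%(v) if type(v) is float else str(v))+'-'
uid = 'DKF_'+(combined[:-1]+'-'+params['unique_id']).replace('.','_')
print uid #-> DKF_lr-8_0000e-04-vm-R-inf-structured-uid
```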
-------------------------------------------------------------------------------- /polyphonic_samples/jsb1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/jsb1.mp3 -------------------------------------------------------------------------------- /polyphonic_samples/musedata0.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/musedata0.mp3 -------------------------------------------------------------------------------- /polyphonic_samples/musedata1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/musedata1.mp3 -------------------------------------------------------------------------------- /polyphonic_samples/nott0.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/nott0.mp3 -------------------------------------------------------------------------------- /polyphonic_samples/nott1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/nott1.mp3 -------------------------------------------------------------------------------- /polyphonic_samples/piano0.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/piano0.mp3 -------------------------------------------------------------------------------- /polyphonic_samples/piano1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clinicalml/structuredinference/ad2a00b90092ba69d52fe8306ce7136d98f195c4/polyphonic_samples/piano1.mp3 -------------------------------------------------------------------------------- /stinfmodel/README.md: -------------------------------------------------------------------------------- 1 | ## Main Repository containing implementation of DKF 2 | 3 | The file dkf.py contains the class definition for the Deep Kalman Filter 4 | 5 | The file uses many of the arguments from structuredinference/parse_args_dkf.py to create the model. 
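A minimal usage sketch (mirroring expt-template/train.py; the dataset dictionary and checkpoint paths are illustrative, not fixed by this module):

```
from parse_args_dkf import params          #CLI defaults, e.g. -vm R -infm structured
from stinfmodel.dkf import DKF
import stinfmodel.learning as DKF_learn

#dataset: a dict like the one built in expt-template/load.py
params['dim_observations'], params['data_type'] = dataset['dim_observations'], dataset['data_type']
dkf = DKF(params, paramFile = './chkpt/config.pkl')   #illustrative path
savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'],
                           epoch_start = 0, epoch_end = params['epochs'],
                           batch_size = params['batch_size'], savefreq = params['savefreq'],
                           savefile = './chkpt/DKF',  #illustrative prefix
                           dataset_eval = dataset['valid'], mask_eval = dataset['mask_valid'])
```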
6 | 
7 | The main parameters are detailed below:
8 | 
9 | ```
10 | * dim_stochastic - dimensionality of the latent state z_t
11 | * dim_hidden - number of hidden units in the transition and emission functions
12 | ```
13 | 
-------------------------------------------------------------------------------- /stinfmodel/__init__.py: --------------------------------------------------------------------------------
1 | __all__ = ['dkf','evaluate','learning']
2 | 
-------------------------------------------------------------------------------- /stinfmodel/dkf.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | import six.moves.cPickle as pickle
3 | from collections import OrderedDict
4 | import numpy as np
5 | import sys, time, os, gzip, theano, math
6 | sys.path.append('../')
7 | from theano import config
8 | theano.config.compute_test_value = 'warn'
9 | from theano.printing import pydotprint
10 | import theano.tensor as T
11 | from utils.misc import saveHDF5
12 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
13 | 
14 | from utils.optimizer import adam,rmsprop
15 | from models.__init__ import BaseModel
16 | from datasets.synthp import params_synthetic
17 | from datasets.synthpTheano import updateParamsSynthetic
18 | 
19 | #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%#
20 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
21 | #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%#
22 | 
23 | """
24 | DEEP KALMAN FILTER
25 | """
26 | class DKF(BaseModel, object):
27 |     def __init__(self, params, paramFile=None, reloadFile=None):
28 |         self.scan_updates = []
29 |         super(DKF,self).__init__(params, paramFile=paramFile, reloadFile=reloadFile)
30 |         if 'synthetic' in self.params['dataset'] and not hasattr(self, 'params_synthetic'):
31 |             assert False, 'Expecting to have params_synthetic as an attribute in DKF class'
32 |         assert self.params['nonlinearity']!='maxout','Maxout nonlinearity not supported'
33 | 
34 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
35 |     def _fakeData(self):
36 |         """ Fake data used to populate theano test values (tags) """
37 |         T = 3
38 |         N = 2
39 |         mask = np.random.random((N,T)).astype(config.floatX)
40 |         small = mask<0.5
41 |         large = mask>=0.5
42 |         mask[small] = 0.
43 |         mask[large] = 1.
44 | eps = np.random.randn(N,T,self.params['dim_stochastic']).astype(config.floatX) 45 | X = np.random.randn(N,T,self.params['dim_observations']).astype(config.floatX) 46 | return X ,mask, eps 47 | 48 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 49 | def _createParams(self): 50 | """ Model parameters """ 51 | npWeights = OrderedDict() 52 | self._createInferenceParams(npWeights) 53 | self._createGenerativeParams(npWeights) 54 | return npWeights 55 | 56 | def _createGenerativeParams(self, npWeights): 57 | """ Create weights/params for generative model """ 58 | if 'synthetic' in self.params['dataset']: 59 | return 60 | DIM_HIDDEN = self.params['dim_hidden'] 61 | DIM_STOCHASTIC = self.params['dim_stochastic'] 62 | if self.params['transition_type']=='mlp': 63 | DIM_HIDDEN_TRANS = DIM_HIDDEN*2 64 | for l in range(self.params['transition_layers']): 65 | dim_input,dim_output = DIM_HIDDEN_TRANS, DIM_HIDDEN_TRANS 66 | if l==0: 67 | dim_input = self.params['dim_stochastic'] 68 | npWeights['p_trans_W_'+str(l)] = self._getWeight((dim_input, dim_output)) 69 | npWeights['p_trans_b_'+str(l)] = self._getWeight((dim_output,)) 70 | if self.params['use_prev_input']: 71 | npWeights['p_trans_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'], 72 | DIM_HIDDEN_TRANS)) 73 | npWeights['p_trans_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 74 | MU_COV_INP = DIM_HIDDEN_TRANS 75 | elif self.params['transition_type']=='simple_gated': 76 | DIM_HIDDEN_TRANS = DIM_HIDDEN*2 77 | npWeights['p_gate_embed_W_0'] = self._getWeight((DIM_STOCHASTIC, DIM_HIDDEN_TRANS)) 78 | npWeights['p_gate_embed_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 79 | npWeights['p_gate_embed_W_1'] = self._getWeight((DIM_HIDDEN_TRANS, DIM_STOCHASTIC)) 80 | npWeights['p_gate_embed_b_1'] = self._getWeight((DIM_STOCHASTIC,)) 81 | npWeights['p_z_W_0'] = self._getWeight((DIM_STOCHASTIC, DIM_HIDDEN_TRANS)) 82 | npWeights['p_z_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 83 | npWeights['p_z_W_1'] = self._getWeight((DIM_HIDDEN_TRANS, DIM_STOCHASTIC)) 84 | npWeights['p_z_b_1'] = self._getWeight((DIM_STOCHASTIC,)) 85 | if self.params['use_prev_input']: 86 | npWeights['p_z_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'], DIM_HIDDEN_TRANS)) 87 | npWeights['p_z_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 88 | npWeights['p_gate_embed_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'], 89 | DIM_HIDDEN_TRANS)) 90 | npWeights['p_gate_embed_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 91 | MU_COV_INP = DIM_STOCHASTIC 92 | else: 93 | assert False,'Invalid transition type: '+self.params['transition_type'] 94 | 95 | if self.params['transition_type']=='simple_gated': 96 | weight= np.eye(self.params['dim_stochastic']).astype(config.floatX) 97 | bias = np.zeros((self.params['dim_stochastic'],)).astype(config.floatX) 98 | #Initialize the weights to be identity 99 | npWeights['p_trans_W_mu'] = weight 100 | npWeights['p_trans_b_mu'] = bias 101 | else: 102 | npWeights['p_trans_W_mu'] = self._getWeight((MU_COV_INP, self.params['dim_stochastic'])) 103 | npWeights['p_trans_b_mu'] = self._getWeight((self.params['dim_stochastic'],)) 104 | npWeights['p_trans_W_cov'] = self._getWeight((MU_COV_INP, self.params['dim_stochastic'])) 105 | npWeights['p_trans_b_cov'] = self._getWeight((self.params['dim_stochastic'],)) 106 | 107 | #Emission Function [MLP] 108 | if self.params['emission_type'] == 'mlp': 109 | for l in range(self.params['emission_layers']): 110 | dim_input,dim_output = 
DIM_HIDDEN, DIM_HIDDEN 111 | if l==0: 112 | dim_input = self.params['dim_stochastic'] 113 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output)) 114 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,)) 115 | elif self.params['emission_type'] =='conditional': 116 | for l in range(self.params['emission_layers']): 117 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN 118 | if l==0: 119 | dim_input = self.params['dim_stochastic']+self.params['dim_observations'] 120 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output)) 121 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,)) 122 | else: 123 | assert False, 'Invalid emission type: '+str(self.params['emission_type']) 124 | if self.params['data_type']=='binary': 125 | npWeights['p_emis_W_ber'] = self._getWeight((self.params['dim_hidden'], self.params['dim_observations'])) 126 | npWeights['p_emis_b_ber'] = self._getWeight((self.params['dim_observations'],)) 127 | elif self.params['data_type']=='binary_nade': 128 | n_visible, n_hidden = self.params['dim_observations'], self.params['dim_hidden'] 129 | npWeights['p_nade_W'] = self._getWeight((n_visible, n_hidden)) 130 | npWeights['p_nade_U'] = self._getWeight((n_visible,n_hidden)) 131 | npWeights['p_nade_b'] = self._getWeight((n_visible,)) 132 | else: 133 | assert False,'Invalid datatype: '+params['data_type'] 134 | 135 | def _createInferenceParams(self, npWeights): 136 | """ Create weights/params for inference network """ 137 | 138 | #Initial embedding for the inputs 139 | DIM_INPUT = self.params['dim_observations'] 140 | RNN_SIZE = self.params['rnn_size'] 141 | 142 | DIM_HIDDEN = RNN_SIZE 143 | DIM_STOC = self.params['dim_stochastic'] 144 | 145 | #Embed the Input -> RNN_SIZE 146 | dim_input, dim_output= DIM_INPUT, RNN_SIZE 147 | npWeights['q_W_input_0'] = self._getWeight((dim_input, dim_output)) 148 | npWeights['q_b_input_0'] = self._getWeight((dim_output,)) 149 | 150 | #Setup weights for LSTM 151 | self._createLSTMWeights(npWeights) 152 | 153 | #Embedding before MF/ST inference model 154 | if self.params['inference_model']=='mean_field': 155 | pass 156 | elif self.params['inference_model']=='structured': 157 | DIM_INPUT = self.params['dim_stochastic'] 158 | if self.params['use_generative_prior']: 159 | DIM_INPUT = self.params['rnn_size'] 160 | npWeights['q_W_st_0'] = self._getWeight((DIM_INPUT, self.params['rnn_size'])) 161 | npWeights['q_b_st_0'] = self._getWeight((self.params['rnn_size'],)) 162 | else: 163 | assert False,'Invalid inference model: '+self.params['inference_model'] 164 | RNN_SIZE = self.params['rnn_size'] 165 | npWeights['q_W_mu'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic'])) 166 | npWeights['q_b_mu'] = self._getWeight((self.params['dim_stochastic'],)) 167 | npWeights['q_W_cov'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic'])) 168 | npWeights['q_b_cov'] = self._getWeight((self.params['dim_stochastic'],)) 169 | if self.params['var_model']=='LR' and self.params['inference_model']=='mean_field': 170 | npWeights['q_W_mu_r'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic'])) 171 | npWeights['q_b_mu_r'] = self._getWeight((self.params['dim_stochastic'],)) 172 | npWeights['q_W_cov_r'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic'])) 173 | npWeights['q_b_cov_r'] = self._getWeight((self.params['dim_stochastic'],)) 174 | 175 | def _createLSTMWeights(self, npWeights): 176 | #LSTM L/LR/R w/ orthogonal weight initialization 177 | suffices_to_build = [] 178 | if 
self.params['var_model']=='LR' or self.params['var_model']=='L': 179 | suffices_to_build.append('l') 180 | if self.params['var_model']=='LR' or self.params['var_model']=='R': 181 | suffices_to_build.append('r') 182 | RNN_SIZE = self.params['rnn_size'] 183 | for suffix in suffices_to_build: 184 | for l in range(self.params['rnn_layers']): 185 | npWeights['W_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE,RNN_SIZE*4)) 186 | npWeights['b_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE*4,), scheme='lstm') 187 | npWeights['U_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE,RNN_SIZE*4),scheme='lstm') 188 | 189 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 190 | def _getEmissionFxn(self, z, X=None): 191 | """ 192 | Apply emission function to zs 193 | Input: z [bs x T x dim] 194 | Output: (params, ) or (mu, cov) of size [bs x T x dim] 195 | """ 196 | if 'synthetic' in self.params['dataset']: 197 | self._p('Using emission function for '+self.params['dataset']) 198 | mu = self.params_synthetic[self.params['dataset']]['obs_fxn'](z) 199 | cov = T.ones_like(mu)*self.params_synthetic[self.params['dataset']]['obs_cov'] 200 | cov.name = 'EmissionCov' 201 | return [mu,cov] 202 | 203 | if self.params['emission_type']=='mlp': 204 | self._p('EMISSION TYPE: MLP') 205 | hid = z 206 | elif self.params['emission_type']=='conditional': 207 | self._p('EMISSION TYPE: conditional') 208 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1) 209 | hid = T.concatenate([z,X_prev],axis=2) 210 | else: 211 | assert False,'Invalid emission type' 212 | 213 | self._p('TODO: FIX THIS, SHOULD BE LINEAR FOR NADE') 214 | for l in range(self.params['emission_layers']): 215 | #Do not use a non-linearity in the last layer 216 | #if l==self.params['emission_layers']-1: 217 | # hid = T.dot(hid, self.tWeights['p_emis_W_'+str(l)]) + self.tWeights['p_emis_b_'+str(l)] 218 | hid = self._LinearNL(self.tWeights['p_emis_W_'+str(l)], self.tWeights['p_emis_b_'+str(l)], hid) 219 | if self.params['data_type']=='binary': 220 | mean_params=T.nnet.sigmoid(T.dot(hid,self.tWeights['p_emis_W_ber'])+self.tWeights['p_emis_b_ber']) 221 | return [mean_params] 222 | elif self.params['data_type']=='binary_nade': 223 | self._p('NADE observations') 224 | assert X is not None,'Need observations for NADE' 225 | x_reshaped = X.dimshuffle(2,0,1) 226 | x0 = T.ones((hid.shape[0],hid.shape[1]))#x_reshaped[0]) # bs x T 227 | a0 = hid #bs x T x nhid 228 | W = self.tWeights['p_nade_W'] 229 | V = self.tWeights['p_nade_U'] 230 | b = self.tWeights['p_nade_b'] 231 | #Use a NADE at the output 232 | def NADEDensity(x, w, v, b, a_prev, x_prev):#Estimating likelihood 233 | a = a_prev + T.dot(T.shape_padright(x_prev, 1), T.shape_padleft(w, 1)) 234 | h = T.nnet.sigmoid(a) #bs x T x nhid 235 | p_xi_is_one = T.nnet.sigmoid(T.dot(h, v) + b) 236 | return (a, x, p_xi_is_one) 237 | ([_, _, mean_params], _) = theano.scan(NADEDensity, 238 | sequences=[x_reshaped, W, V, b], 239 | outputs_info=[a0, x0,None]) 240 | #theano function to sample from NADE 241 | def NADESample(w, v, b, a_prev_s, x_prev_s): 242 | a_s = a_prev_s + T.dot(T.shape_padright(x_prev_s, 1), T.shape_padleft(w, 1)) 243 | h_s = T.nnet.sigmoid(a_s) #bs x T x nhid 244 | p_xi_is_one_s = T.nnet.sigmoid(T.dot(h_s, v) + b) 245 | x_s = T.switch(p_xi_is_one_s>0.5,1.,0.) 
246 | return (a_s, x_s, p_xi_is_one_s) 247 | ([_, _, sampled_params], _) = theano.scan(NADESample, 248 | sequences=[W, V, b], 249 | outputs_info=[a0, x0,None]) 250 | """ 251 | def NADEDensityAndSample(x, w, v, b, 252 | a_prev, x_prev, 253 | a_prev_s, x_prev_s ): 254 | a = a_prev + T.dot(T.shape_padright(x_prev, 1), T.shape_padleft(w, 1)) 255 | h = T.nnet.sigmoid(a) #bs x T x nhid 256 | p_xi_is_one = T.nnet.sigmoid(T.dot(h, v) + b) 257 | 258 | a_s = a_prev_s + T.dot(T.shape_padright(x_prev_s, 1), T.shape_padleft(w, 1)) 259 | h_s = T.nnet.sigmoid(a_s) #bs x T x nhid 260 | p_xi_is_one_s = T.nnet.sigmoid(T.dot(h_s, v) + b) 261 | x_s = T.switch(p_xi_is_one_s>0.5,1.,0.) 262 | return (a, x, a_s, x_s, p_xi_is_one, p_xi_is_one_s) 263 | 264 | ([_, _, _, _, mean_params,sampled_params], _) = theano.scan(NADEDensityAndSample, 265 | sequences=[x_reshaped, W, V, b], 266 | outputs_info=[a0, x0, a0, x0, None, None]) 267 | """ 268 | sampled_params = sampled_params.dimshuffle(1,2,0) 269 | mean_params = mean_params.dimshuffle(1,2,0) 270 | return [mean_params,sampled_params] 271 | else: 272 | assert False,'Invalid type of data' 273 | 274 | def _getTransitionFxn(self, z, X=None): 275 | """ 276 | Apply transition function to zs 277 | Input: z [bs x T x dim], u [bs x T x dim] 278 | Output: mu, cov of size [bs x T x dim] 279 | """ 280 | if 'synthetic' in self.params['dataset']: 281 | self._p('Using transition function for '+self.params['dataset']) 282 | mu = self.params_synthetic[self.params['dataset']]['trans_fxn'](z) 283 | cov = T.ones_like(mu)*self.params_synthetic[self.params['dataset']]['trans_cov'] 284 | cov.name = 'TransitionCov' 285 | return mu,cov 286 | 287 | if self.params['transition_type']=='simple_gated': 288 | def mlp(inp, W1,b1,W2,b2, X_prev=None): 289 | if X_prev is not None: 290 | h1 = self._LinearNL(W1,b1, T.concatenate([inp,X_prev],axis=2)) 291 | else: 292 | h1 = self._LinearNL(W1,b1, inp) 293 | h2 = T.dot(h1,W2)+b2 294 | return h2 295 | 296 | gateInp= z 297 | X_prev = None 298 | if self.params['use_prev_input']: 299 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1) 300 | gate = T.nnet.sigmoid(mlp(gateInp, self.tWeights['p_gate_embed_W_0'], self.tWeights['p_gate_embed_b_0'], 301 | self.tWeights['p_gate_embed_W_1'],self.tWeights['p_gate_embed_b_1'], 302 | X_prev = X_prev)) 303 | z_prop = mlp(z,self.tWeights['p_z_W_0'] ,self.tWeights['p_z_b_0'], 304 | self.tWeights['p_z_W_1'] , self.tWeights['p_z_b_1'], X_prev = X_prev) 305 | mu = gate*z_prop + (1.-gate)*(T.dot(z, self.tWeights['p_trans_W_mu'])+self.tWeights['p_trans_b_mu']) 306 | cov = T.nnet.softplus(T.dot(self._applyNL(z_prop), self.tWeights['p_trans_W_cov'])+ 307 | self.tWeights['p_trans_b_cov']) 308 | return mu,cov 309 | elif self.params['transition_type']=='mlp': 310 | hid = z 311 | if self.params['use_prev_input']: 312 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1) 313 | hid = T.concatenate([hid,X_prev],axis=2) 314 | for l in range(self.params['transition_layers']): 315 | hid = self._LinearNL(self.tWeights['p_trans_W_'+str(l)],self.tWeights['p_trans_b_'+str(l)],hid) 316 | mu = T.dot(hid, self.tWeights['p_trans_W_mu']) + self.tWeights['p_trans_b_mu'] 317 | cov = T.nnet.softplus(T.dot(hid, self.tWeights['p_trans_W_cov'])+self.tWeights['p_trans_b_cov']) 318 | return mu,cov 319 | else: 320 | assert False,'Invalid Transition type: '+str(self.params['transition_type']) 321 | 322 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 323 | def _buildLSTM(self, X, 
embedding, dropout_prob = 0.): 324 | """ 325 | Take the embedding of bs x T x dim and return T x bs x dim that is the result of the scan operation 326 | for L/LR/R 327 | Input: embedding [bs x T x dim] 328 | Output:hidden_state [T x bs x dim] 329 | """ 330 | start_time = time.time() 331 | self._p('In <_buildLSTM>') 332 | suffix = '' 333 | if self.params['var_model']=='R': 334 | suffix='r' 335 | return self._LSTMlayer(embedding, suffix, dropout_prob) 336 | elif self.params['var_model']=='L': 337 | suffix='l' 338 | return self._LSTMlayer(embedding, suffix, dropout_prob) 339 | elif self.params['var_model']=='LR': 340 | suffix='l' 341 | l2r = self._LSTMlayer(embedding, suffix, dropout_prob) 342 | suffix='r' 343 | r2l = self._LSTMlayer(embedding, suffix, dropout_prob) 344 | return [l2r,r2l] 345 | else: 346 | assert False,'Invalid variational model' 347 | self._p(('Done <_buildLSTM> [Took %.4f]')%(time.time()-start_time)) 348 | 349 | 350 | def _inferenceLayer(self, hidden_state, eps): 351 | """ 352 | Take input of T x bs x dim and return z, mu, 353 | sq each of size (bs x T x dim) 354 | Input: hidden_state [T x bs x dim], eps [bs x T x dim] 355 | Output: z [bs x T x dim], mu [bs x T x dim], cov [bs x T x dim] 356 | """ 357 | def structuredApproximation(h_t, eps_t, z_prev, 358 | q_W_st_0, q_b_st_0, 359 | q_W_mu, q_b_mu, 360 | q_W_cov,q_b_cov): 361 | #Using the prior distribution directly 362 | if self.params['use_generative_prior']: 363 | assert not self.params['use_prev_input'],'No support for using previous input' 364 | #Get mu/cov from z_prev through prior distribution 365 | mu_1,cov_1 = self._getTransitionFxn(z_prev) 366 | #Combine with estimate of mu/cov from data 367 | h_data = T.tanh(T.dot(h_t,q_W_st_0)+q_b_st_0) 368 | mu_2 = T.dot(h_data,q_W_mu)+q_b_mu 369 | cov_2 = T.nnet.softplus(T.dot(h_data,q_W_cov)+q_b_cov) 370 | mu = (mu_1*cov_2+mu_2*cov_1)/(cov_1+cov_2) 371 | cov = (cov_1*cov_2)/(cov_1+cov_2) 372 | z = mu + T.sqrt(cov)*eps_t 373 | else: 374 | h_next = T.tanh(T.dot(z_prev,q_W_st_0)+q_b_st_0) 375 | if self.params['var_model']=='LR': 376 | h_next = (1./3.)*(h_t+h_next) 377 | else: 378 | h_next = (1./2.)*(h_t+h_next) 379 | mu_t = T.dot(h_next,q_W_mu)+q_b_mu 380 | cov_t = T.nnet.softplus(T.dot(h_next,q_W_cov)+q_b_cov) 381 | z_t = mu_t+T.sqrt(cov_t)*eps_t 382 | return z_t, mu_t, cov_t 383 | 384 | if self.params['inference_model']=='structured': 385 | #Structured recognition networks 386 | if self.params['var_model']=='LR': 387 | state = hidden_state[0]+hidden_state[1] 388 | else: 389 | state = hidden_state 390 | eps_swap = eps.swapaxes(0,1) 391 | if self.params['dim_stochastic']==1: 392 | """ 393 | TODO: Write to theano authors regarding this issue. 394 | Workaround for theano issue: The result of a matrix multiply is a "matrix" 395 | even if one of the dimensions is 1. However defining a tensor with one dimension one 396 | means theano regards the resulting tensor as a matrix and consequently in the 397 | scan as a column. This results in a mismatch in tensor type in input (column) 398 | and output (matrix) and throws an error. 
This is a workaround that preserves
399 |                 type while not affecting dimensions
400 |                 """
401 |                 z0 = T.zeros((eps_swap.shape[1], self.params['rnn_size']))
402 |                 z0 = T.dot(z0,T.zeros_like(self.tWeights['q_W_mu']))
403 |             else:
404 |                 z0 = T.zeros((eps_swap.shape[1], self.params['dim_stochastic']))
405 |             rval, _ = theano.scan(structuredApproximation, 
406 |                                     sequences=[state, eps_swap],
407 |                                     outputs_info=[z0, None,None],
408 |                                     non_sequences=[self.tWeights[k] for k in 
409 |                                                    ['q_W_st_0', 'q_b_st_0']]+
410 |                                                   [self.tWeights[k] for k in 
411 |                                                    ['q_W_mu','q_b_mu','q_W_cov','q_b_cov']],
412 |                                     name='structuredApproximation')
413 |             z, mu, cov = rval[0].swapaxes(0,1), rval[1].swapaxes(0,1), rval[2].swapaxes(0,1)
414 |             return z, mu, cov
415 |         elif self.params['inference_model']=='mean_field':
416 |             if self.params['var_model']=='LR':
417 |                 l2r = hidden_state[0].swapaxes(0,1)
418 |                 r2l = hidden_state[1].swapaxes(0,1)
419 |                 hidl2r = l2r
420 |                 mu_1 = T.dot(hidl2r,self.tWeights['q_W_mu'])+self.tWeights['q_b_mu']
421 |                 cov_1 = T.nnet.softplus(T.dot(hidl2r, self.tWeights['q_W_cov'])+self.tWeights['q_b_cov'])
422 |                 hidr2l = r2l
423 |                 mu_2 = T.dot(hidr2l,self.tWeights['q_W_mu_r'])+self.tWeights['q_b_mu_r']
424 |                 cov_2 = T.nnet.softplus(T.dot(hidr2l, self.tWeights['q_W_cov_r'])+self.tWeights['q_b_cov_r'])
425 |                 mu = (mu_1*cov_2+mu_2*cov_1)/(cov_1+cov_2)
426 |                 cov = (cov_1*cov_2)/(cov_1+cov_2)
427 |                 z = mu + T.sqrt(cov)*eps
428 |             else:
429 |                 hid = hidden_state.swapaxes(0,1)
430 |                 mu = T.dot(hid,self.tWeights['q_W_mu'])+self.tWeights['q_b_mu']
431 |                 cov = T.nnet.softplus(T.dot(hid,self.tWeights['q_W_cov'])+ self.tWeights['q_b_cov'])
432 |                 z = mu + T.sqrt(cov)*eps
433 |             return z,mu,cov
434 |         else:
435 |             assert False,'Invalid recognition model'
436 | 
437 |     def _qEmbeddingLayer(self, X):
438 |         """ Embed X for the recognition network q """
439 |         return self._LinearNL(self.tWeights['q_W_input_0'],self.tWeights['q_b_input_0'], X)
440 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
441 | 
442 |     def _inferenceAndReconstruction(self, X, eps, dropout_prob = 0.):
443 |         """
444 |         Returns observation params, z_q, mu_q, cov_q, and the prior/transition means and covariances
445 |         """
446 |         self._p('Building with dropout:'+str(dropout_prob))
447 |         embedding = self._qEmbeddingLayer(X)
448 |         hidden_state = self._buildLSTM(X, embedding, dropout_prob)
449 |         z_q,mu_q,cov_q = self._inferenceLayer(hidden_state, eps)
450 | 
451 |         #Regularize z_q (for train)
452 |         #if dropout_prob>0.:
453 |         #    z_q = z_q + self.srng.normal(z_q.shape, 0.,0.0025,dtype=config.floatX)
454 |         #z_q.name = 'z_q'
455 | 
456 |         observation_params = self._getEmissionFxn(z_q,X=X)
457 |         mu_trans, cov_trans = self._getTransitionFxn(z_q, X=X)
458 |         mu_prior = T.concatenate([T.alloc(np.asarray(0.).astype(config.floatX),
459 |                                           X.shape[0],1,self.params['dim_stochastic']), mu_trans[:,:-1,:]],axis=1)
460 |         cov_prior = T.concatenate([T.alloc(np.asarray(1.).astype(config.floatX),
461 |                                           X.shape[0],1,self.params['dim_stochastic']), cov_trans[:,:-1,:]],axis=1)
462 |         return observation_params, z_q, mu_q, cov_q, mu_prior, cov_prior, mu_trans, cov_trans
463 | 
464 | 
465 |     def _getTemporalKL(self, mu_q, cov_q, mu_prior, cov_prior, M, batchVector = False):
466 |         """
467 |         Temporal KL divergence KL(q||p)
468 |         KL(q_t||p_t) = 0.5*(log|sigmasq_p| -log|sigmasq_q| -D + Tr(sigmasq_p^-1 sigmasq_q)
469 |                             + (mu_p-mu_q)^T sigmasq_p^-1 (mu_p-mu_q))
470 |         M is a mask of size bs x T that should be applied once the KL divergence for each point
471 |         across time has been estimated
472 |         """
473 |         assert np.all(cov_q.tag.test_value>0.),'should be positive'
474 |         assert
np.all(cov_prior.tag.test_value>0.),'should be positive' 475 | diff_mu = mu_prior-mu_q 476 | KL_t = T.log(cov_prior)-T.log(cov_q) - 1. + cov_q/cov_prior + diff_mu**2/cov_prior 477 | KLvec = (0.5*KL_t.sum(2)*M).sum(1,keepdims=True) 478 | if batchVector: 479 | return KLvec 480 | else: 481 | return KLvec.sum() 482 | 483 | def _getNegCLL(self, obs_params, X, M, batchVector = False): 484 | """ 485 | Estimate the negative conditional log likelihood of x|z under the generative model 486 | M: mask of size bs x T 487 | X: target of size bs x T x dim 488 | """ 489 | if self.params['data_type']=='real': 490 | mu_p = obs_params[0] 491 | cov_p = obs_params[1] 492 | std_p = T.sqrt(cov_p) 493 | negCLL_t = 0.5 * np.log(2 * np.pi) + 0.5*T.log(cov_p) + 0.5 * ((X - mu_p) / std_p)**2 494 | negCLL = (negCLL_t.sum(2)*M).sum(1,keepdims=True) 495 | else: 496 | mean_p = obs_params[0] 497 | negCLL = (T.nnet.binary_crossentropy(mean_p,X).sum(2)*M).sum(1,keepdims=True) 498 | if batchVector: 499 | return negCLL 500 | else: 501 | return negCLL.sum() 502 | 503 | def _buildModel(self): 504 | if 'synthetic' in self.params['dataset']: 505 | self.params_synthetic = params_synthetic 506 | """ High level function to build and setup theano functions """ 507 | X = T.tensor3('X', dtype=config.floatX) 508 | eps = T.tensor3('eps', dtype=config.floatX) 509 | M = T.matrix('M', dtype=config.floatX) 510 | X.tag.test_value, M.tag.test_value, eps.tag.test_value = self._fakeData() 511 | 512 | #Learning Rates and annealing objective function 513 | #Add them to npWeights/tWeights to be tracked [do not have a prefix _W or _b so wont be diff.] 514 | self._addWeights('lr', np.asarray(self.params['lr'],dtype=config.floatX),borrow=False) 515 | self._addWeights('anneal', np.asarray(0.01,dtype=config.floatX),borrow=False) 516 | self._addWeights('update_ctr', np.asarray(1.,dtype=config.floatX),borrow=False) 517 | lr = self.tWeights['lr'] 518 | anneal = self.tWeights['anneal'] 519 | iteration_t = self.tWeights['update_ctr'] 520 | 521 | anneal_div = 1000. 522 | if 'anneal_rate' in self.params: 523 | self._p('Anneal = 1 in '+str(self.params['anneal_rate'])+' param. updates') 524 | anneal_div = self.params['anneal_rate'] 525 | if 'synthetic' in self.params['dataset']: 526 | anneal_div = 100. 527 | anneal_update = [(iteration_t, iteration_t+1), 528 | (anneal,T.switch(0.01+iteration_t/anneal_div>1,1,0.01+iteration_t/anneal_div))] 529 | fxn_inputs = [X, M, eps] 530 | if not self.params['validate_only']: 531 | print '****** CREATING TRAINING FUNCTION*****' 532 | ############# Setup training functions ########### 533 | obs_params, z_q, mu_q, cov_q, mu_prior, cov_prior, _, _ = self._inferenceAndReconstruction( 534 | X, eps, 535 | dropout_prob = self.params['rnn_dropout']) 536 | negCLL = self._getNegCLL(obs_params, X, M) 537 | TemporalKL = self._getTemporalKL(mu_q, cov_q, mu_prior, cov_prior, M) 538 | train_cost = negCLL+anneal*TemporalKL 539 | 540 | #Get updates from optimizer 541 | model_params = self._getModelParams() 542 | optimizer_up, norm_list = self._setupOptimizer(train_cost, model_params,lr = lr, 543 | reg_type =self.params['reg_type'], 544 | reg_spec =self.params['reg_spec'], 545 | reg_value= self.params['reg_value'], 546 | divide_grad = T.cast(X.shape[0],dtype=config.floatX), 547 | grad_norm = 1.) 
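            #Worked example of the annealing schedule defined above (a sketch; the numbers
            #follow the defaults in parse_args_dkf.py): anneal(t) = min(1., 0.01 + t/anneal_div),
            #so with anneal_rate = 10. the KL weight starts at 0.01 and saturates at 1. after
            #about 10 parameter updates, after which train_cost = negCLL + TemporalKL.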
548 | 549 | #Add annealing updates 550 | optimizer_up +=anneal_update+self.updates 551 | self._p(str(len(self.updates))+' other updates') 552 | ############# Setup train & evaluate functions ########### 553 | self.train_debug = theano.function(fxn_inputs,[train_cost,norm_list[0],norm_list[1], 554 | norm_list[2],negCLL, TemporalKL, anneal.sum()], 555 | updates = optimizer_up, name='Train (with Debug)') 556 | #Updates ack 557 | self.updates_ack = True 558 | eval_obs_params, eval_z_q, eval_mu_q, eval_cov_q, eval_mu_prior, eval_cov_prior, \ 559 | eval_mu_trans, eval_cov_trans = self._inferenceAndReconstruction( 560 | X, eps, 561 | dropout_prob = 0.) 562 | eval_z_q.name = 'eval_z_q' 563 | eval_CNLLvec=self._getNegCLL(eval_obs_params, X, M, batchVector = True) 564 | eval_KLvec = self._getTemporalKL(eval_mu_q, eval_cov_q,eval_mu_prior, eval_cov_prior, M, batchVector = True) 565 | eval_cost = eval_CNLLvec + eval_KLvec 566 | 567 | #From here on, convert to the log covariance since we only use it for evaluation 568 | assert np.all(eval_cov_q.tag.test_value>0.),'should be positive' 569 | assert np.all(eval_cov_prior.tag.test_value>0.),'should be positive' 570 | assert np.all(eval_cov_trans.tag.test_value>0.),'should be positive' 571 | eval_logcov_q = T.log(eval_cov_q) 572 | eval_logcov_prior = T.log(eval_cov_prior) 573 | eval_logcov_trans = T.log(eval_cov_trans) 574 | 575 | ll_prior = self._llGaussian(eval_z_q, eval_mu_prior, eval_logcov_prior).sum(2)*M 576 | ll_posterior = self._llGaussian(eval_z_q, eval_mu_q, eval_logcov_q).sum(2)*M 577 | ll_estimate = -1*eval_CNLLvec+ll_prior.sum(1,keepdims=True)-ll_posterior.sum(1,keepdims=True) 578 | 579 | eval_inputs = [eval_z_q] 580 | self.likelihood = theano.function(fxn_inputs, ll_estimate, name = 'Importance Sampling based likelihood') 581 | self.evaluate = theano.function(fxn_inputs, eval_cost, name = 'Evaluate Bound') 582 | if self.params['use_prev_input']: 583 | eval_inputs.append(X) 584 | self.transition_fxn = theano.function(eval_inputs,[eval_mu_trans, eval_logcov_trans], 585 | name='Transition Function') 586 | emission_inputs = [eval_z_q] 587 | if self.params['emission_type']=='conditional': 588 | emission_inputs.append(X) 589 | if self.params['data_type']=='binary_nade': 590 | self.emission_fxn = theano.function(emission_inputs, 591 | eval_obs_params[1], name='Emission Function') 592 | else: 593 | self.emission_fxn = theano.function(emission_inputs, 594 | eval_obs_params[0], name='Emission Function') 595 | self.posterior_inference = theano.function([X, eps], 596 | [eval_z_q, eval_mu_q, eval_logcov_q], 597 | name='Posterior Inference') 598 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 599 | if __name__=='__main__': 600 | """ use this to check compilation for various options""" 601 | from parse_args_dkf import params 602 | if params['use_nade']: 603 | params['data_type'] = 'binary_nade' 604 | else: 605 | params['data_type'] = 'binary' 606 | params['dim_observations'] = 10 607 | dkf = DKF(params, paramFile = 'tmp') 608 | os.unlink('tmp') 609 | -------------------------------------------------------------------------------- /stinfmodel/evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | from theano import config 3 | import numpy as np 4 | import time 5 | """ 6 | Functions for evaluating a DKF object 7 | """ 8 | def infer(dkf, dataset): 9 | """ Posterior Inference using recognition network 10 | Returns: z,mu,logcov (each a 3D tensor) Remember to multiply each by 
the mask of the dataset before 11 | using the latent variables 12 | """ 13 | assert len(dataset.shape)==3,'Expecting 3D tensor for data' 14 | assert dataset.shape[2]==dkf.params['dim_observations'],'Data dim. not matching' 15 | eps = np.random.randn(dataset.shape[0], 16 | dataset.shape[1], 17 | dkf.params['dim_stochastic']).astype(config.floatX) 18 | return dkf.posterior_inference(X=dataset.astype(config.floatX), eps=eps) 19 | 20 | def evaluateBound(dkf, dataset, mask, batch_size,S=2, normalization = 'frame', additional={}): 21 | """ Evaluate ELBO """ 22 | bound = 0 23 | start_time = time.time() 24 | N = dataset.shape[0] 25 | 26 | tsbn_bound = 0 27 | for bnum,st_idx in enumerate(range(0,N,batch_size)): 28 | end_idx = min(st_idx+batch_size, N) 29 | X = dataset[st_idx:end_idx,:,:].astype(config.floatX) 30 | M = mask[st_idx:end_idx,:].astype(config.floatX) 31 | U = None 32 | 33 | #Reduce the dimensionality of the tensors based on the maximum size of the mask 34 | maxT = int(np.max(M.sum(1))) 35 | X = X[:,:maxT,:] 36 | M = M[:,:maxT] 37 | eps = np.random.randn(X.shape[0],maxT,dkf.params['dim_stochastic']).astype(config.floatX) 38 | maxS = S 39 | bound_sum, tsbn_bound_sum = 0, 0 40 | for s in range(S): 41 | if s>0 and s%500==0: 42 | dkf._p('Done '+str(s)) 43 | eps = np.random.randn(X.shape[0],maxT,dkf.params['dim_stochastic']).astype(config.floatX) 44 | batch_vec= dkf.evaluate(X=X, M=M, eps=eps) 45 | if np.any(np.isnan(batch_vec)) or np.any(np.isinf(batch_vec)): 46 | dkf._p('NaN detected during evaluation. Ignoring this sample') 47 | maxS -=1 48 | continue 49 | else: 50 | tsbn_bound_sum+=(batch_vec/M.sum(1,keepdims=True)).sum() 51 | bound_sum+=batch_vec.sum() 52 | tsbn_bound += tsbn_bound_sum/float(max(maxS*N,1.)) 53 | bound += bound_sum/float(max(maxS,1.)) 54 | if normalization=='frame': 55 | bound /= float(mask.sum()) 56 | elif normalization=='sequence': 57 | bound /= float(N) 58 | else: 59 | assert False,'Invalid normalization specified' 60 | end_time = time.time() 61 | dkf._p(('(Evaluate) Validation Bound: %.4f [Took %.4f seconds], TSBN Bound: %.4f')%(bound,end_time-start_time,tsbn_bound)) 62 | additional['tsbn_bound'] = tsbn_bound 63 | return bound 64 | 65 | 66 | def impSamplingNLL(dkf, dataset, mask, batch_size, S = 2, normalization = 'frame'): 67 | """ Importance sampling based log likelihood """ 68 | ll = 0 69 | start_time = time.time() 70 | N = dataset.shape[0] 71 | for bnum,st_idx in enumerate(range(0,N,batch_size)): 72 | end_idx = min(st_idx+batch_size, N) 73 | X = dataset[st_idx:end_idx,:,:].astype(config.floatX) 74 | M = mask[st_idx:end_idx,:].astype(config.floatX) 75 | U = None 76 | maxT = int(np.max(M.sum(1))) 77 | X = X[:,:maxT,:] 78 | M = M[:,:maxT] 79 | eps = np.random.randn(X.shape[0],maxT,dkf.params['dim_stochastic']).astype(config.floatX) 80 | maxS = S 81 | lllist = [] 82 | for s in range(S): 83 | if s>0 and s%500==0: 84 | dkf._p('Done '+str(s)) 85 | eps = np.random.randn(X.shape[0],maxT, 86 | dkf.params['dim_stochastic']).astype(config.floatX) 87 | batch_vec = dkf.likelihood(X=X, M=M, eps=eps) 88 | if np.any(np.isnan(batch_vec)) or np.any(np.isinf(batch_vec)): 89 | dkf._p('NaN detected during evaluation. 
Ignoring this sample') 90 | maxS -=1 91 | continue 92 | else: 93 | lllist.append(batch_vec) 94 | ll += dkf.meanSumExp(np.concatenate(lllist,axis=1), axis=1).sum() 95 | if normalization=='frame': 96 | ll /= float(mask.sum()) 97 | elif normalization=='sequence': 98 | ll /= float(N) 99 | else: 100 | assert False,'Invalid normalization specified' 101 | end_time = time.time() 102 | dkf._p(('(Evaluate w/ Imp. Sampling) Validation LL: %.4f [Took %.4f seconds]')%(ll,end_time-start_time)) 103 | return ll 104 | 105 | 106 | def sampleGaussian(dkf,mu,logcov): 107 | return mu + np.random.randn(*mu.shape)*np.exp(0.5*logcov) 108 | 109 | def sample(dkf, nsamples=100, T=10, additional = {}): 110 | """ 111 | Sample from Generative Model 112 | """ 113 | assert T>1, 'Sample atleast 2 timesteps' 114 | #Initial sample 115 | z = np.random.randn(nsamples,1,dkf.params['dim_stochastic']).astype(config.floatX) 116 | all_zs = [np.copy(z)] 117 | additional['mu'] = [] 118 | additional['logcov'] = [] 119 | for t in range(T-1): 120 | mu,logcov = dkf.transition_fxn(z) 121 | z = dkf.sampleGaussian(mu,logcov).astype(config.floatX) 122 | all_zs.append(np.copy(z)) 123 | additional['mu'].append(np.copy(mu)) 124 | additional['logcov'].append(np.copy(logcov)) 125 | zvec = np.concatenate(all_zs,axis=1) 126 | additional['mu'] = np.concatenate(additional['mu'], axis=1) 127 | additional['logcov'] = np.concatenate(additional['logcov'], axis=1) 128 | return dkf.emission_fxn(zvec), zvec 129 | 130 | -------------------------------------------------------------------------------- /stinfmodel/learning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for learning with a DKF object 3 | """ 4 | import evaluate as DKF_evaluate 5 | import numpy as np 6 | from utils.misc import saveHDF5 7 | import time 8 | from theano import config 9 | 10 | def learn(dkf, dataset, mask, epoch_start=0, epoch_end=1000, 11 | batch_size=200, shuffle=False, 12 | savefreq=None, savefile = None, 13 | dataset_eval = None, mask_eval = None, 14 | replicate_K = None, 15 | normalization = 'frame'): 16 | """ 17 | Train DKF 18 | """ 19 | assert not dkf.params['validate_only'],'cannot learn in validate only mode' 20 | assert len(dataset.shape)==3,'Expecting 3D tensor for data' 21 | assert dataset.shape[2]==dkf.params['dim_observations'],'Dim observations not valid' 22 | N = dataset.shape[0] 23 | idxlist = range(N) 24 | batchlist = np.split(idxlist, range(batch_size,N,batch_size)) 25 | 26 | bound_train_list,bound_valid_list,bound_tsbn_list,nll_valid_list = [],[],[],[] 27 | p_norm, g_norm, opt_norm = None, None, None 28 | 29 | #Lists used to track quantities for synthetic experiments 30 | mu_list_train, cov_list_train, mu_list_valid, cov_list_valid = [],[],[],[] 31 | 32 | #Start of training loop 33 | for epoch in range(epoch_start, epoch_end): 34 | #Shuffle 35 | if shuffle: 36 | np.random.shuffle(idxlist) 37 | batchlist = np.split(idxlist, range(batch_size,N,batch_size)) 38 | #Always shuffle order the batches are presented in 39 | np.random.shuffle(batchlist) 40 | 41 | start_time = time.time() 42 | bound = 0 43 | for bnum, batch_idx in enumerate(batchlist): 44 | batch_idx = batchlist[bnum] 45 | X = dataset[batch_idx,:,:].astype(config.floatX) 46 | M = mask[batch_idx,:].astype(config.floatX) 47 | U = None 48 | 49 | #Tack on 0's if the matrix size (optimization->theano doesnt have to redefine matrices) 50 | if X.shape[0]> A.result 2 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py 
-vm LR -infm mean_field -dset jsb >> A.result 3 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -dset jsb >> A.result 4 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ttype mlp -dset jsb >>A.result 5 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm LR -infm structured -dset jsb >>A.result 6 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -dset jsb >>A.result 7 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -usenade -dset jsb >>A.result 8 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -etype conditional -usenade -dset jsb >>A.result 9 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -dset jsb >>A.result 10 | THEANO_FLAGS="lib.cnmem=1.,scan.allow_gc=False,compiledir_format=gpu0" python2.7 dkf.py -vm R -infm structured -ar 5000 -etype conditional -previnp -usenade -dset jsb >> A.result 11 | grep -n "buildModel" A.result 12 | rm -rf A.result 13 | -------------------------------------------------------------------------------- /stinfmodel_fast/__init__.py: -------------------------------------------------------------------------------- 1 | all=['dkf','evaluate','learning'] 2 | -------------------------------------------------------------------------------- /stinfmodel_fast/dkf.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import six.moves.cPickle as pickle 3 | from collections import OrderedDict 4 | import numpy as np 5 | import sys, time, os, gzip, theano,math 6 | sys.path.append('../') 7 | from theano import config 8 | theano.config.compute_test_value = 'warn' 9 | from theano.printing import pydotprint 10 | import theano.tensor as T 11 | from utils.misc import saveHDF5 12 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 13 | 14 | from utils.optimizer import adam,rmsprop 15 | from models.__init__ import BaseModel 16 | from datasets.synthp import params_synthetic 17 | from datasets.synthpTheano import updateParamsSynthetic 18 | 19 | #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%# 20 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 21 | #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%# 22 | 23 | """ 24 | DEEP MARKOV MODEL [DEEP KALMAN FILTER] 25 | """ 26 | class DKF(BaseModel, object): 27 | def __init__(self, params, paramFile=None, reloadFile=None): 28 | self.scan_updates = [] 29 | super(DKF,self).__init__(params, paramFile=paramFile, reloadFile=reloadFile) 30 | if 'synthetic' in self.params['dataset'] and not hasattr(self, 'params_synthetic'): 31 | assert False, 'Expecting to have params_synthetic as an attribute in DKF class' 32 | assert self.params['nonlinearity']!='maxout','Maxout nonlinearity not supported' 33 | 34 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 35 | def _createParams(self): 36 | """ Model parameters """ 37 | npWeights = OrderedDict() 
38 | self._createInferenceParams(npWeights) 39 | self._createGenerativeParams(npWeights) 40 | return npWeights 41 | 42 | def _createGenerativeParams(self, npWeights): 43 | """ Create weights/params for generative model """ 44 | if 'synthetic' in self.params['dataset']: 45 | updateParamsSynthetic(params_synthetic) 46 | self.params_synthetic = params_synthetic 47 | for k in self.params_synthetic[self.params['dataset']]['params']: 48 | npWeights[k+'_W'] = np.array(np.random.uniform(-0.2,0.2),dtype=config.floatX) 49 | return 50 | DIM_HIDDEN = self.params['dim_hidden'] 51 | DIM_STOCHASTIC = self.params['dim_stochastic'] 52 | if self.params['transition_type']=='mlp': 53 | DIM_HIDDEN_TRANS = DIM_HIDDEN*2 54 | for l in range(self.params['transition_layers']): 55 | dim_input,dim_output = DIM_HIDDEN_TRANS, DIM_HIDDEN_TRANS 56 | if l==0: 57 | dim_input = self.params['dim_stochastic'] 58 | npWeights['p_trans_W_'+str(l)] = self._getWeight((dim_input, dim_output)) 59 | npWeights['p_trans_b_'+str(l)] = self._getWeight((dim_output,)) 60 | if self.params['use_prev_input']: 61 | npWeights['p_trans_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'], 62 | DIM_HIDDEN_TRANS)) 63 | npWeights['p_trans_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 64 | MU_COV_INP = DIM_HIDDEN_TRANS 65 | elif self.params['transition_type']=='simple_gated': 66 | DIM_HIDDEN_TRANS = DIM_HIDDEN*2 67 | npWeights['p_gate_embed_W_0'] = self._getWeight((DIM_STOCHASTIC, DIM_HIDDEN_TRANS)) 68 | npWeights['p_gate_embed_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 69 | npWeights['p_gate_embed_W_1'] = self._getWeight((DIM_HIDDEN_TRANS, DIM_STOCHASTIC)) 70 | npWeights['p_gate_embed_b_1'] = self._getWeight((DIM_STOCHASTIC,)) 71 | npWeights['p_z_W_0'] = self._getWeight((DIM_STOCHASTIC, DIM_HIDDEN_TRANS)) 72 | npWeights['p_z_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 73 | npWeights['p_z_W_1'] = self._getWeight((DIM_HIDDEN_TRANS, DIM_STOCHASTIC)) 74 | npWeights['p_z_b_1'] = self._getWeight((DIM_STOCHASTIC,)) 75 | if self.params['use_prev_input']: 76 | npWeights['p_z_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'], DIM_HIDDEN_TRANS)) 77 | npWeights['p_z_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 78 | npWeights['p_gate_embed_W_0'] = self._getWeight((DIM_STOCHASTIC+self.params['dim_observations'], 79 | DIM_HIDDEN_TRANS)) 80 | npWeights['p_gate_embed_b_0'] = self._getWeight((DIM_HIDDEN_TRANS,)) 81 | MU_COV_INP = DIM_STOCHASTIC 82 | else: 83 | assert False,'Invalid transition type: '+self.params['transition_type'] 84 | 85 | if self.params['transition_type']=='simple_gated': 86 | weight= np.eye(self.params['dim_stochastic']).astype(config.floatX) 87 | bias = np.zeros((self.params['dim_stochastic'],)).astype(config.floatX) 88 | #Initialize the weights to be identity 89 | npWeights['p_trans_W_mu'] = weight 90 | npWeights['p_trans_b_mu'] = bias 91 | else: 92 | npWeights['p_trans_W_mu'] = self._getWeight((MU_COV_INP, self.params['dim_stochastic'])) 93 | npWeights['p_trans_b_mu'] = self._getWeight((self.params['dim_stochastic'],)) 94 | npWeights['p_trans_W_cov'] = self._getWeight((MU_COV_INP, self.params['dim_stochastic'])) 95 | npWeights['p_trans_b_cov'] = self._getWeight((self.params['dim_stochastic'],)) 96 | 97 | #Emission Function [MLP] 98 | if self.params['emission_type'] == 'mlp': 99 | for l in range(self.params['emission_layers']): 100 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN 101 | if l==0: 102 | dim_input = self.params['dim_stochastic'] 103 | npWeights['p_emis_W_'+str(l)] = 
self._getWeight((dim_input, dim_output)) 104 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,)) 105 | elif self.params['emission_type'] == 'res': 106 | for l in range(self.params['emission_layers']): 107 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN 108 | if l==0: 109 | dim_input = self.params['dim_stochastic'] 110 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output)) 111 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,)) 112 | dim_res_out = self.params['dim_observations'] 113 | if self.params['data_type']=='binary_nade': 114 | dim_res_out = DIM_HIDDEN 115 | npWeights['p_res_W'] = self._getWeight((self.params['dim_stochastic'], dim_res_out)) 116 | elif self.params['emission_type'] =='conditional': 117 | for l in range(self.params['emission_layers']): 118 | dim_input,dim_output = DIM_HIDDEN, DIM_HIDDEN 119 | if l==0: 120 | dim_input = self.params['dim_stochastic']+self.params['dim_observations'] 121 | npWeights['p_emis_W_'+str(l)] = self._getWeight((dim_input, dim_output)) 122 | npWeights['p_emis_b_'+str(l)] = self._getWeight((dim_output,)) 123 | else: 124 | assert False, 'Invalid emission type: '+str(self.params['emission_type']) 125 | if self.params['data_type']=='binary': 126 | npWeights['p_emis_W_ber'] = self._getWeight((self.params['dim_hidden'], self.params['dim_observations'])) 127 | npWeights['p_emis_b_ber'] = self._getWeight((self.params['dim_observations'],)) 128 | elif self.params['data_type']=='binary_nade': 129 | n_visible, n_hidden = self.params['dim_observations'], self.params['dim_hidden'] 130 | npWeights['p_nade_W'] = self._getWeight((n_visible, n_hidden)) 131 | npWeights['p_nade_U'] = self._getWeight((n_visible,n_hidden)) 132 | npWeights['p_nade_b'] = self._getWeight((n_visible,)) 133 | else: 134 | assert False,'Invalid datatype: '+params['data_type'] 135 | 136 | def _createInferenceParams(self, npWeights): 137 | """ Create weights/params for inference network """ 138 | 139 | #Initial embedding for the inputs 140 | DIM_INPUT = self.params['dim_observations'] 141 | RNN_SIZE = self.params['rnn_size'] 142 | 143 | DIM_HIDDEN = RNN_SIZE 144 | DIM_STOC = self.params['dim_stochastic'] 145 | 146 | #Embed the Input -> RNN_SIZE 147 | dim_input, dim_output= DIM_INPUT, RNN_SIZE 148 | npWeights['q_W_input_0'] = self._getWeight((dim_input, dim_output)) 149 | npWeights['q_b_input_0'] = self._getWeight((dim_output,)) 150 | 151 | #Setup weights for LSTM 152 | self._createLSTMWeights(npWeights) 153 | 154 | #Embedding before MF/ST inference model 155 | if self.params['inference_model']=='mean_field': 156 | pass 157 | elif self.params['inference_model']=='structured': 158 | DIM_INPUT = self.params['dim_stochastic'] 159 | if self.params['use_generative_prior']: 160 | DIM_INPUT = self.params['rnn_size'] 161 | npWeights['q_W_st_0'] = self._getWeight((DIM_INPUT, self.params['rnn_size'])) 162 | npWeights['q_b_st_0'] = self._getWeight((self.params['rnn_size'],)) 163 | else: 164 | assert False,'Invalid inference model: '+self.params['inference_model'] 165 | RNN_SIZE = self.params['rnn_size'] 166 | npWeights['q_W_mu'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic'])) 167 | npWeights['q_b_mu'] = self._getWeight((self.params['dim_stochastic'],)) 168 | npWeights['q_W_cov'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic'])) 169 | npWeights['q_b_cov'] = self._getWeight((self.params['dim_stochastic'],)) 170 | if self.params['var_model']=='LR' and self.params['inference_model']=='mean_field': 171 | npWeights['q_W_mu_r'] = 
self._getWeight((RNN_SIZE, self.params['dim_stochastic'])) 172 | npWeights['q_b_mu_r'] = self._getWeight((self.params['dim_stochastic'],)) 173 | npWeights['q_W_cov_r'] = self._getWeight((RNN_SIZE, self.params['dim_stochastic'])) 174 | npWeights['q_b_cov_r'] = self._getWeight((self.params['dim_stochastic'],)) 175 | 176 | def _createLSTMWeights(self, npWeights): 177 | #LSTM L/LR/R w/ orthogonal weight initialization 178 | suffices_to_build = [] 179 | if self.params['var_model']=='LR' or self.params['var_model']=='L': 180 | suffices_to_build.append('l') 181 | if self.params['var_model']=='LR' or self.params['var_model']=='R': 182 | suffices_to_build.append('r') 183 | RNN_SIZE = self.params['rnn_size'] 184 | for suffix in suffices_to_build: 185 | for l in range(self.params['rnn_layers']): 186 | npWeights['W_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE,RNN_SIZE*4)) 187 | npWeights['b_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE*4,), scheme='lstm') 188 | npWeights['U_lstm_'+suffix+'_'+str(l)] = self._getWeight((RNN_SIZE,RNN_SIZE*4),scheme='lstm') 189 | 190 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 191 | def _getEmissionFxn(self, z, X=None): 192 | """ 193 | Apply emission function to zs 194 | Input: z [bs x T x dim] 195 | Output: (params, ) or (mu, cov) of size [bs x T x dim] 196 | """ 197 | if 'synthetic' in self.params['dataset']: 198 | self._p('Using emission function for '+self.params['dataset']) 199 | tParams = {} 200 | for k in self.params_synthetic[self.params['dataset']]['params']: 201 | tParams[k] = self.tWeights[k+'_W'] 202 | mu = self.params_synthetic[self.params['dataset']]['obs_fxn'](z, fxn_params = tParams) 203 | cov = T.ones_like(mu)*self.params_synthetic[self.params['dataset']]['obs_cov'] 204 | cov.name = 'EmissionCov' 205 | return [mu,cov] 206 | 207 | if self.params['emission_type'] in ['mlp','res']: 208 | self._p('EMISSION TYPE: MLP or RES') 209 | hid = z 210 | elif self.params['emission_type']=='conditional': 211 | self._p('EMISSION TYPE: conditional') 212 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1) 213 | hid = T.concatenate([z,X_prev],axis=2) 214 | else: 215 | assert False,'Invalid emission type' 216 | for l in range(self.params['emission_layers']): 217 | if self.params['data_type']=='binary_nade' and l==self.params['emission_layers']-1: 218 | hid = T.dot(hid, self.tWeights['p_emis_W_'+str(l)]) + self.tWeights['p_emis_b_'+str(l)] 219 | else: 220 | hid = self._LinearNL(self.tWeights['p_emis_W_'+str(l)], self.tWeights['p_emis_b_'+str(l)], hid) 221 | 222 | if self.params['data_type']=='binary': 223 | if self.params['emission_type']=='res': 224 | hid = T.dot(z,self.tWeights['p_res_W'])+T.dot(hid,self.tWeights['p_emis_W_ber'])+self.tWeights['p_emis_b_ber'] 225 | mean_params=T.nnet.sigmoid(hid) 226 | else: 227 | mean_params=T.nnet.sigmoid(T.dot(hid,self.tWeights['p_emis_W_ber'])+self.tWeights['p_emis_b_ber']) 228 | return [mean_params] 229 | elif self.params['data_type']=='binary_nade': 230 | self._p('NADE observations') 231 | assert X is not None,'Need observations for NADE' 232 | if self.params['emission_type']=='res': 233 | hid += T.dot(z,self.tWeights['p_res_W']) 234 | x_reshaped = X.dimshuffle(2,0,1) 235 | x0 = T.ones((hid.shape[0],hid.shape[1]))#x_reshaped[0]) # bs x T 236 | a0 = hid #bs x T x nhid 237 | W = self.tWeights['p_nade_W'] 238 | V = self.tWeights['p_nade_U'] 239 | b = self.tWeights['p_nade_b'] 240 | #Use a NADE at the output 241 | def NADEDensity(x, w, v, b, a_prev, 
x_prev):#Estimating likelihood 242 | a = a_prev + T.dot(T.shape_padright(x_prev, 1), T.shape_padleft(w, 1)) 243 | h = T.nnet.sigmoid(a) #Original - bs x T x nhid 244 | p_xi_is_one = T.nnet.sigmoid(T.dot(h, v) + b) 245 | return (a, x, p_xi_is_one) 246 | ([_, _, mean_params], _) = theano.scan(NADEDensity, 247 | sequences=[x_reshaped, W, V, b], 248 | outputs_info=[a0, x0,None]) 249 | #theano function to sample from NADE 250 | def NADESample(w, v, b, a_prev_s, x_prev_s): 251 | a_s = a_prev_s + T.dot(T.shape_padright(x_prev_s, 1), T.shape_padleft(w, 1)) 252 | h_s = T.nnet.sigmoid(a_s) #Original - bs x T x nhid 253 | p_xi_is_one_s = T.nnet.sigmoid(T.dot(h_s, v) + b) 254 | x_s = T.switch(p_xi_is_one_s>0.5,1.,0.) 255 | return (a_s, x_s, p_xi_is_one_s) 256 | ([_, _, sampled_params], _) = theano.scan(NADESample, 257 | sequences=[W, V, b], 258 | outputs_info=[a0, x0,None]) 259 | """ 260 | def NADEDensityAndSample(x, w, v, b, 261 | a_prev, x_prev, 262 | a_prev_s, x_prev_s ): 263 | a = a_prev + T.dot(T.shape_padright(x_prev, 1), T.shape_padleft(w, 1)) 264 | h = T.nnet.sigmoid(a) #bs x T x nhid 265 | p_xi_is_one = T.nnet.sigmoid(T.dot(h, v) + b) 266 | 267 | a_s = a_prev_s + T.dot(T.shape_padright(x_prev_s, 1), T.shape_padleft(w, 1)) 268 | h_s = T.nnet.sigmoid(a_s) #bs x T x nhid 269 | p_xi_is_one_s = T.nnet.sigmoid(T.dot(h_s, v) + b) 270 | x_s = T.switch(p_xi_is_one_s>0.5,1.,0.) 271 | return (a, x, a_s, x_s, p_xi_is_one, p_xi_is_one_s) 272 | 273 | ([_, _, _, _, mean_params,sampled_params], _) = theano.scan(NADEDensityAndSample, 274 | sequences=[x_reshaped, W, V, b], 275 | outputs_info=[a0, x0, a0, x0, None, None]) 276 | """ 277 | sampled_params = sampled_params.dimshuffle(1,2,0) 278 | mean_params = mean_params.dimshuffle(1,2,0) 279 | return [mean_params,sampled_params] 280 | else: 281 | assert False,'Invalid type of data' 282 | 283 | def _getTransitionFxn(self, z, X=None): 284 | """ 285 | Apply transition function to zs 286 | Input: z [bs x T x dim], u [bs x T x dim] 287 | Output: mu, cov of size [bs x T x dim] 288 | """ 289 | if 'synthetic' in self.params['dataset']: 290 | self._p('Using transition function for '+self.params['dataset']) 291 | tParams = {} 292 | for k in self.params_synthetic[self.params['dataset']]['params']: 293 | tParams[k] = self.tWeights[k+'_W'] 294 | mu = self.params_synthetic[self.params['dataset']]['trans_fxn'](z, fxn_params = tParams) 295 | cov = T.ones_like(mu)*self.params_synthetic[self.params['dataset']]['trans_cov'] 296 | cov.name = 'TransitionCov' 297 | return mu,cov 298 | 299 | if self.params['transition_type']=='simple_gated': 300 | def mlp(inp, W1,b1,W2,b2, X_prev=None): 301 | if X_prev is not None: 302 | h1 = self._LinearNL(W1,b1, T.concatenate([inp,X_prev],axis=2)) 303 | else: 304 | h1 = self._LinearNL(W1,b1, inp) 305 | h2 = T.dot(h1,W2)+b2 306 | return h2 307 | 308 | gateInp= z 309 | X_prev = None 310 | if self.params['use_prev_input']: 311 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1) 312 | gate = T.nnet.sigmoid(mlp(gateInp, self.tWeights['p_gate_embed_W_0'], self.tWeights['p_gate_embed_b_0'], 313 | self.tWeights['p_gate_embed_W_1'],self.tWeights['p_gate_embed_b_1'], 314 | X_prev = X_prev)) 315 | z_prop = mlp(z,self.tWeights['p_z_W_0'] ,self.tWeights['p_z_b_0'], 316 | self.tWeights['p_z_W_1'] , self.tWeights['p_z_b_1'], X_prev = X_prev) 317 | mu = gate*z_prop + (1.-gate)*(T.dot(z, self.tWeights['p_trans_W_mu'])+self.tWeights['p_trans_b_mu']) 318 | cov = T.nnet.softplus(T.dot(self._applyNL(z_prop), self.tWeights['p_trans_W_cov'])+ 319 | 
self.tWeights['p_trans_b_cov']) 320 | return mu,cov 321 | elif self.params['transition_type']=='mlp': 322 | hid = z 323 | if self.params['use_prev_input']: 324 | X_prev = T.concatenate([T.zeros_like(X[:,[0],:]),X[:,:-1,:]],axis=1) 325 | hid = T.concatenate([hid,X_prev],axis=2) 326 | for l in range(self.params['transition_layers']): 327 | hid = self._LinearNL(self.tWeights['p_trans_W_'+str(l)],self.tWeights['p_trans_b_'+str(l)],hid) 328 | mu = T.dot(hid, self.tWeights['p_trans_W_mu']) + self.tWeights['p_trans_b_mu'] 329 | cov = T.nnet.softplus(T.dot(hid, self.tWeights['p_trans_W_cov'])+self.tWeights['p_trans_b_cov']) 330 | return mu,cov 331 | else: 332 | assert False,'Invalid Transition type: '+str(self.params['transition_type']) 333 | 334 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 335 | def _buildLSTM(self, X, embedding, dropout_prob = 0.): 336 | """ 337 | Take the embedding of bs x T x dim and return T x bs x dim that is the result of the scan operation 338 | for L/LR/R 339 | Input: embedding [bs x T x dim] 340 | Output:hidden_state [T x bs x dim] 341 | """ 342 | start_time = time.time() 343 | self._p('In <_buildLSTM>') 344 | suffix = '' 345 | if self.params['var_model']=='R': 346 | suffix='r' 347 | return self._LSTMlayer(embedding, suffix, dropout_prob) 348 | elif self.params['var_model']=='L': 349 | suffix='l' 350 | return self._LSTMlayer(embedding, suffix, dropout_prob) 351 | elif self.params['var_model']=='LR': 352 | suffix='l' 353 | l2r = self._LSTMlayer(embedding, suffix, dropout_prob) 354 | suffix='r' 355 | r2l = self._LSTMlayer(embedding, suffix, dropout_prob) 356 | return [l2r,r2l] 357 | else: 358 | assert False,'Invalid variational model' 359 | self._p(('Done <_buildLSTM> [Took %.4f]')%(time.time()-start_time)) 360 | 361 | 362 | def _inferenceLayer(self, hidden_state): 363 | """ 364 | Take input of T x bs x dim and return z, mu, 365 | sq each of size (bs x T x dim) 366 | Input: hidden_state [T x bs x dim], eps [bs x T x dim] 367 | Output: z [bs x T x dim], mu [bs x T x dim], cov [bs x T x dim] 368 | """ 369 | def structuredApproximation(h_t, eps_t, z_prev, 370 | q_W_st_0, q_b_st_0, 371 | q_W_mu, q_b_mu, 372 | q_W_cov,q_b_cov): 373 | #Using the prior distribution directly 374 | if self.params['use_generative_prior']: 375 | assert not self.params['use_prev_input'],'No support for using previous input' 376 | #Get mu/cov from z_prev through prior distribution 377 | mu_1,cov_1 = self._getTransitionFxn(z_prev) 378 | #Combine with estimate of mu/cov from data 379 | h_data = T.tanh(T.dot(h_t,q_W_st_0)+q_b_st_0) 380 | mu_2 = T.dot(h_data,q_W_mu)+q_b_mu 381 | cov_2 = T.nnet.softplus(T.dot(h_data,q_W_cov)+q_b_cov) 382 | mu = (mu_1*cov_2+mu_2*cov_1)/(cov_1+cov_2) 383 | cov = (cov_1*cov_2)/(cov_1+cov_2) 384 | z = mu + T.sqrt(cov)*eps_t 385 | return z, mu, cov 386 | else: 387 | h_next = T.tanh(T.dot(z_prev,q_W_st_0)+q_b_st_0) 388 | if self.params['var_model']=='LR': 389 | h_next = (1./3.)*(h_t+h_next) 390 | else: 391 | h_next = (1./2.)*(h_t+h_next) 392 | mu_t = T.dot(h_next,q_W_mu)+q_b_mu 393 | cov_t = T.nnet.softplus(T.dot(h_next,q_W_cov)+q_b_cov) 394 | z_t = mu_t+T.sqrt(cov_t)*eps_t 395 | return z_t, mu_t, cov_t 396 | if type(hidden_state) is list: 397 | eps = self.srng.normal(size=(hidden_state[0].shape[1],hidden_state[0].shape[0],self.params['dim_stochastic'])) 398 | else: 399 | eps = self.srng.normal(size=(hidden_state.shape[1],hidden_state.shape[0],self.params['dim_stochastic'])) 400 | if self.params['inference_model']=='structured': 401 | 
402 |             if self.params['var_model']=='LR':
403 |                 state = hidden_state[0]+hidden_state[1]
404 |             else:
405 |                 state = hidden_state
406 |             eps_swap = eps.swapaxes(0,1)
407 |             if self.params['dim_stochastic']==1:
408 |                 """
409 |                 Workaround for a Theano issue (worth reporting to the Theano authors):
410 |                 the result of a matrix multiply is a "matrix" even if one of its
411 |                 dimensions is 1, whereas a tensor allocated with a dimension of size 1
412 |                 is treated inside scan as a column. This mismatch in tensor type between
413 |                 the input (column) and the output (matrix) of the scan step throws an
414 |                 error. Initializing z0 via a dot product below preserves the type that
415 |                 scan expects while not affecting the dimensions.
416 |                 """
417 |                 z0 = T.zeros((eps_swap.shape[1], self.params['rnn_size']))
418 |                 z0 = T.dot(z0,T.zeros_like(self.tWeights['q_W_mu']))
419 |             else:
420 |                 z0 = T.zeros((eps_swap.shape[1], self.params['dim_stochastic']))
421 |             rval, _ = theano.scan(structuredApproximation, 
422 |                                   sequences=[state, eps_swap],
423 |                                   outputs_info=[z0, None,None],
424 |                                   non_sequences=[self.tWeights[k] for k in 
425 |                                                  ['q_W_st_0', 'q_b_st_0']]+
426 |                                                 [self.tWeights[k] for k in 
427 |                                                  ['q_W_mu','q_b_mu','q_W_cov','q_b_cov']],
428 |                                   name='structuredApproximation')
429 |             z, mu, cov = rval[0].swapaxes(0,1), rval[1].swapaxes(0,1), rval[2].swapaxes(0,1)
430 |             return z, mu, cov
431 |         elif self.params['inference_model']=='mean_field':
432 |             if self.params['var_model']=='LR':
433 |                 l2r = hidden_state[0].swapaxes(0,1)
434 |                 r2l = hidden_state[1].swapaxes(0,1)
435 |                 #Each direction yields its own Gaussian estimate of the posterior
436 |                 mu_1 = T.dot(l2r,self.tWeights['q_W_mu'])+self.tWeights['q_b_mu']
437 |                 cov_1 = T.nnet.softplus(T.dot(l2r, self.tWeights['q_W_cov'])+self.tWeights['q_b_cov'])
438 |                 #The right-to-left direction uses its own parameters (suffix _r)
439 |                 mu_2 = T.dot(r2l,self.tWeights['q_W_mu_r'])+self.tWeights['q_b_mu_r']
440 |                 cov_2 = T.nnet.softplus(T.dot(r2l, self.tWeights['q_W_cov_r'])+self.tWeights['q_b_cov_r'])
441 |                 mu = (mu_1*cov_2+mu_2*cov_1)/(cov_1+cov_2)
442 |                 cov = (cov_1*cov_2)/(cov_1+cov_2)
443 |                 z = mu + T.sqrt(cov)*eps
444 |             else:
445 |                 hid = hidden_state.swapaxes(0,1)
446 |                 mu = T.dot(hid,self.tWeights['q_W_mu'])+self.tWeights['q_b_mu']
447 |                 cov = T.nnet.softplus(T.dot(hid,self.tWeights['q_W_cov'])+ self.tWeights['q_b_cov'])
448 |                 z = mu + T.sqrt(cov)*eps
449 |             return z,mu,cov
450 |         else:
451 |             assert False,'Invalid recognition model'
452 | 
453 |     def _qEmbeddingLayer(self, X):
454 |         """ Embed for q """
455 |         return self._LinearNL(self.tWeights['q_W_input_0'],self.tWeights['q_b_input_0'], X)
456 |     #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""#
457 | 
458 |     def _inferenceAndReconstruction(self, X, dropout_prob = 0.):
459 |         """
460 |         Returns observation params, the posterior (z_q, mu_q, cov_q) and the prior/transition (mu, cov)
461 |         """
462 |         self._p('Building with dropout:'+str(dropout_prob))
463 |         embedding = self._qEmbeddingLayer(X)
464 |         hidden_state = self._buildLSTM(X, embedding, dropout_prob)
465 |         z_q,mu_q,cov_q = self._inferenceLayer(hidden_state)
466 | 
467 |         #Regularize z_q (for train)
468 |         #if dropout_prob>0.:
469 |         #    z_q = z_q + self.srng.normal(z_q.shape, 0.,0.0025,dtype=config.floatX)
470 |         #z_q.name = 'z_q'
471 | 
472 |         observation_params = self._getEmissionFxn(z_q,X=X)
473 |         mu_trans, cov_trans= self._getTransitionFxn(z_q, X=X)
474 |         mu_prior = T.concatenate([T.alloc(np.asarray(0.).astype(config.floatX),
475 |                                    X.shape[0],1,self.params['dim_stochastic']), mu_trans[:,:-1,:]],axis=1)
476 |         cov_prior = T.concatenate([T.alloc(np.asarray(1.).astype(config.floatX),
477 |                                    X.shape[0],1,self.params['dim_stochastic']), cov_trans[:,:-1,:]],axis=1)
478 |         return observation_params, z_q, mu_q, cov_q, mu_prior, cov_prior, mu_trans, cov_trans
479 | 
480 | 
481 |     def _getTemporalKL(self, mu_q, cov_q, mu_prior, cov_prior, M, batchVector = False):
482 |         """
483 |         Temporal KL divergence KL(q||p), computed per time step for diagonal Gaussians
484 |         KL(q_t||p_t) = 0.5*(log|sigmasq_p| - log|sigmasq_q| - D + Tr(sigmasq_p^-1 sigmasq_q)
485 |                        + (mu_p-mu_q)^T sigmasq_p^-1 (mu_p-mu_q))
486 |         M is a mask of size bs x T that should be applied once the KL divergence for each point
487 |         across time has been estimated
488 |         """
489 |         assert np.all(cov_q.tag.test_value>0.),'should be positive'
490 |         assert np.all(cov_prior.tag.test_value>0.),'should be positive'
491 |         diff_mu = mu_prior-mu_q
492 |         KL_t = T.log(cov_prior)-T.log(cov_q) - 1. + cov_q/cov_prior + diff_mu**2/cov_prior
493 |         KLvec = (0.5*KL_t.sum(2)*M).sum(1,keepdims=True)
494 |         if batchVector:
495 |             return KLvec
496 |         else:
497 |             return KLvec.sum()
498 | 
499 |     def _getNegCLL(self, obs_params, X, M, batchVector = False):
500 |         """
501 |         Estimate the negative conditional log likelihood of x|z under the generative model
502 |         M: mask of size bs x T
503 |         X: target of size bs x T x dim
504 |         """
505 |         if self.params['data_type']=='real':
506 |             mu_p = obs_params[0]
507 |             cov_p = obs_params[1]
508 |             std_p = T.sqrt(cov_p)
509 |             negCLL_t = 0.5 * np.log(2 * np.pi) + 0.5*T.log(cov_p) + 0.5 * ((X - mu_p) / std_p)**2
510 |             negCLL = (negCLL_t.sum(2)*M).sum(1,keepdims=True)
511 |         else:
512 |             mean_p = obs_params[0]
513 |             negCLL = (T.nnet.binary_crossentropy(mean_p,X).sum(2)*M).sum(1,keepdims=True)
514 |         if batchVector:
515 |             return negCLL
516 |         else:
517 |             return negCLL.sum()
518 | 
519 |     def resetDataset(self, newX,newM,quiet=False):
520 |         if not quiet:
521 |             ddim,mdim = self.dimData()
522 |             self._p('Original dim:'+str(ddim)+', '+str(mdim))
523 |         self.setData(newX=newX.astype(config.floatX),newMask=newM.astype(config.floatX))
524 |         if not quiet:
525 |             ddim,mdim = self.dimData()
526 |             self._p('New dim:'+str(ddim)+', '+str(mdim))
527 |     def _buildModel(self):
528 |         """ High level function to build and setup theano functions """
529 |         if 'synthetic' in self.params['dataset']:
530 |             self.params_synthetic = params_synthetic
531 |         #X = T.tensor3('X', dtype=config.floatX)
532 |         #eps = T.tensor3('eps', dtype=config.floatX)
533 |         #M = T.matrix('M', dtype=config.floatX)
534 |         idx = T.vector('idx',dtype='int64')
535 |         idx.tag.test_value = np.array([0,1]).astype('int64')
536 |         self.dataset = theano.shared(np.random.uniform(0,1,size=(3,5,self.params['dim_observations'])).astype(config.floatX))
537 |         self.mask = theano.shared(np.array(([[1,1,1,1,0],[1,1,0,0,0],[1,1,1,0,0]])).astype(config.floatX))
538 |         X_o = self.dataset[idx]
539 |         M_o = self.mask[idx]
540 |         maxidx = T.cast(M_o.sum(1).max(),'int64')
541 |         X = X_o[:,:maxidx,:]
542 |         M = M_o[:,:maxidx]
543 |         newX,newMask = T.tensor3('newX',dtype=config.floatX),T.matrix('newMask',dtype=config.floatX)
544 |         self.setData = theano.function([newX,newMask],None,updates=[(self.dataset,newX),(self.mask,newMask)])
545 |         self.dimData = theano.function([],[self.dataset.shape,self.mask.shape])
546 | 
547 |         #Learning rates and annealing for the objective function
548 |         #Added to npWeights/tWeights to be tracked [no _W or _b prefix, so they are not differentiated]
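        #(Note) 'anneal' below scales the KL term of the training cost: it starts at 0.01
        #and increases linearly to 1.0 over anneal_div parameter updates (see anneal_update),
        #so early training emphasizes reconstruction before the full KL penalty is applied.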
549 |         self._addWeights('lr', np.asarray(self.params['lr'],dtype=config.floatX),borrow=False)
550 |         self._addWeights('anneal', np.asarray(0.01,dtype=config.floatX),borrow=False)
551 |         self._addWeights('update_ctr', np.asarray(1.,dtype=config.floatX),borrow=False)
552 |         lr = self.tWeights['lr']
553 |         anneal = self.tWeights['anneal']
554 |         iteration_t = self.tWeights['update_ctr']
555 | 
556 |         anneal_div = 1000.
557 |         if 'anneal_rate' in self.params:
558 |             self._p('Anneal = 1 in '+str(self.params['anneal_rate'])+' param. updates')
559 |             anneal_div = self.params['anneal_rate']
560 |         if 'synthetic' in self.params['dataset']:
561 |             anneal_div = 100.
562 |         anneal_update = [(iteration_t, iteration_t+1),
563 |                          (anneal,T.switch(0.01+iteration_t/anneal_div>1,1,0.01+iteration_t/anneal_div))]
564 |         fxn_inputs = [idx]
565 |         if not self.params['validate_only']:
566 |             print '****** CREATING TRAINING FUNCTION ******'
567 |             ############# Setup training functions ###########
568 |             obs_params, z_q, mu_q, cov_q, mu_prior, cov_prior, _, _ = self._inferenceAndReconstruction(
569 |                 X, dropout_prob = self.params['rnn_dropout'])
570 |             negCLL = self._getNegCLL(obs_params, X, M)
571 |             TemporalKL = self._getTemporalKL(mu_q, cov_q, mu_prior, cov_prior, M)
572 |             train_cost = negCLL+anneal*TemporalKL
573 | 
574 |             #Get updates from optimizer
575 |             model_params = self._getModelParams()
576 |             optimizer_up, norm_list = self._setupOptimizer(train_cost, model_params,lr = lr,
577 |                                                            #Turning off for synthetic
578 |                                                            #reg_type =self.params['reg_type'], 
579 |                                                            #reg_spec =self.params['reg_spec'], 
580 |                                                            #reg_value= self.params['reg_value'],
581 |                                                            divide_grad = T.cast(X.shape[0],dtype=config.floatX),
582 |                                                            grad_norm = 1.)
583 | 
584 |             #Add annealing updates
585 |             optimizer_up += anneal_update+self.updates
586 |             self._p(str(len(self.updates))+' other updates')
587 |             ############# Setup train & evaluate functions ###########
588 |             self.train_debug = theano.function(fxn_inputs,[train_cost,norm_list[0],norm_list[1],
589 |                                                            norm_list[2],negCLL, TemporalKL, anneal.sum()],
590 |                                                updates = optimizer_up, name='Train (with Debug)')
591 |         #Acknowledge that the updates in self.updates have been folded into the training function
592 |         self.updates_ack = True
593 |         eval_obs_params, eval_z_q, eval_mu_q, eval_cov_q, eval_mu_prior, eval_cov_prior, \
594 |             eval_mu_trans, eval_cov_trans = self._inferenceAndReconstruction(X,dropout_prob = 0.)
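        #(Explanation) The evaluation graph below is rebuilt with dropout_prob=0 but shares
        #tWeights with the training graph. eval_cost is the negative ELBO per sequence, and
        #ll_estimate forms a single-sample importance weight
        #    log p(x|z) + log p(z) - log q(z|x),  z ~ q(z|x),
        #which evaluate.py combines across samples (via dkf.meanSumExp) to get an
        #importance-sampling estimate of log p(x).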
595 | eval_z_q.name = 'eval_z_q' 596 | eval_CNLLvec=self._getNegCLL(eval_obs_params, X, M, batchVector = True) 597 | eval_KLvec = self._getTemporalKL(eval_mu_q, eval_cov_q,eval_mu_prior, eval_cov_prior, M, batchVector = True) 598 | eval_cost = eval_CNLLvec + eval_KLvec 599 | 600 | #From here on, convert to the log covariance since we only use it for evaluation 601 | assert np.all(eval_cov_q.tag.test_value>0.),'should be positive' 602 | assert np.all(eval_cov_prior.tag.test_value>0.),'should be positive' 603 | assert np.all(eval_cov_trans.tag.test_value>0.),'should be positive' 604 | eval_logcov_q = T.log(eval_cov_q) 605 | eval_logcov_prior = T.log(eval_cov_prior) 606 | eval_logcov_trans = T.log(eval_cov_trans) 607 | 608 | ll_prior = self._llGaussian(eval_z_q, eval_mu_prior, eval_logcov_prior).sum(2)*M 609 | ll_posterior = self._llGaussian(eval_z_q, eval_mu_q, eval_logcov_q).sum(2)*M 610 | ll_estimate = -1*eval_CNLLvec+ll_prior.sum(1,keepdims=True)-ll_posterior.sum(1,keepdims=True) 611 | 612 | eval_inputs = [eval_z_q] 613 | self.likelihood = theano.function(fxn_inputs, ll_estimate, name = 'Importance Sampling based likelihood') 614 | self.evaluate = theano.function(fxn_inputs, eval_cost, name = 'Evaluate Bound') 615 | if self.params['use_prev_input']: 616 | eval_inputs.append(X) 617 | self.transition_fxn = theano.function(eval_inputs,[eval_mu_trans, eval_logcov_trans], 618 | name='Transition Function') 619 | emission_inputs = [eval_z_q] 620 | if self.params['emission_type']=='conditional': 621 | emission_inputs.append(X) 622 | if self.params['data_type']=='binary_nade': 623 | self.emission_fxn = theano.function(emission_inputs, 624 | eval_obs_params[1], name='Emission Function') 625 | else: 626 | self.emission_fxn = theano.function(emission_inputs, 627 | eval_obs_params[0], name='Emission Function') 628 | self.posterior_inference = theano.function(fxn_inputs, 629 | [eval_z_q, eval_mu_q, eval_logcov_q], 630 | name='Posterior Inference') 631 | #"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""# 632 | if __name__=='__main__': 633 | """ use this to check compilation for various options""" 634 | from parse_args_dkf import params 635 | if params['use_nade']: 636 | params['data_type'] = 'binary_nade' 637 | else: 638 | params['data_type'] = 'binary' 639 | params['dim_observations'] = 10 640 | dkf = DKF(params, paramFile = 'tmp') 641 | os.unlink('tmp') 642 | import ipdb;ipdb.set_trace() 643 | -------------------------------------------------------------------------------- /stinfmodel_fast/evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | from theano import config 3 | import numpy as np 4 | import time 5 | """ 6 | Functions for evaluating a DKF object 7 | """ 8 | def infer(dkf, dataset, mask): 9 | """ Posterior Inference using recognition network 10 | Returns: z,mu,logcov (each a 3D tensor) Remember to multiply each by the mask of the dataset before 11 | using the latent variables 12 | """ 13 | dkf.resetDataset(dataset,mask,quiet=True) 14 | assert len(dataset.shape)==3,'Expecting 3D tensor for data' 15 | assert dataset.shape[2]==dkf.params['dim_observations'],'Data dim. 
not matching'
16 |     return dkf.posterior_inference(idx=np.arange(dataset.shape[0]))
17 | 
18 | def evaluateBound(dkf, dataset, mask, batch_size,S=2, normalization = 'frame', additional={}):
19 |     """ Evaluate ELBO """
20 |     bound = 0
21 |     start_time = time.time()
22 |     N = dataset.shape[0]
23 |     tsbn_bound = 0
24 |     dkf.resetDataset(dataset,mask)
25 |     for bnum,st_idx in enumerate(range(0,N,batch_size)):
26 |         end_idx = min(st_idx+batch_size, N)
27 |         idx_data = np.arange(st_idx,end_idx)
28 |         maxS = S
29 |         bound_sum, tsbn_bound_sum = 0, 0
30 |         for s in range(S):
31 |             if s>0 and s%500==0:
32 |                 dkf._p('Done '+str(s))
33 |             batch_vec = dkf.evaluate(idx=idx_data)
34 |             M = mask[idx_data]
35 |             if np.any(np.isnan(batch_vec)) or np.any(np.isinf(batch_vec)):
36 |                 dkf._p('NaN detected during evaluation. Ignoring this sample')
37 |                 maxS -= 1
38 |                 continue
39 |             else:
40 |                 tsbn_bound_sum += (batch_vec/M.sum(1,keepdims=True)).sum()
41 |                 bound_sum += batch_vec.sum()
42 |         tsbn_bound += tsbn_bound_sum/float(max(maxS*N,1.))  #per-frame bound averaged over samples and all N sequences
43 |         bound += bound_sum/float(max(maxS,1.))
44 |     if normalization=='frame':
45 |         bound /= float(mask.sum())
46 |     elif normalization=='sequence':
47 |         bound /= float(N)
48 |     else:
49 |         assert False,'Invalid normalization specified'
50 |     end_time = time.time()
51 |     dkf._p(('(Evaluate) Validation Bound: %.4f [Took %.4f seconds], TSBN Bound: %.4f')%(bound,end_time-start_time,tsbn_bound))
52 |     additional['tsbn_bound'] = tsbn_bound
53 |     return bound
54 | 
55 | 
56 | def impSamplingNLL(dkf, dataset, mask, batch_size, S = 2, normalization = 'frame'):
57 |     """ Importance sampling based log likelihood """
58 |     ll = 0
59 |     start_time = time.time()
60 |     N = dataset.shape[0]
61 |     dkf.resetDataset(dataset,mask)
62 |     for bnum,st_idx in enumerate(range(0,N,batch_size)):
63 |         end_idx = min(st_idx+batch_size, N)
64 |         idx_data = np.arange(st_idx,end_idx)
65 |         maxS = S
66 |         lllist = []
67 |         for s in range(S):
68 |             if s>0 and s%500==0:
69 |                 dkf._p('Done '+str(s))
70 |             batch_vec = dkf.likelihood(idx=idx_data)
71 |             if np.any(np.isnan(batch_vec)) or np.any(np.isinf(batch_vec)):
72 |                 dkf._p('NaN detected during evaluation. Ignoring this sample')
73 |                 maxS -= 1
74 |                 continue
75 |             else:
76 |                 lllist.append(batch_vec)
77 |         if len(lllist)>0: ll += dkf.meanSumExp(np.concatenate(lllist,axis=1), axis=1).sum()  #guard: every sample may have been NaN
78 |     if normalization=='frame':
79 |         ll /= float(mask.sum())
80 |     elif normalization=='sequence':
81 |         ll /= float(N)
82 |     else:
83 |         assert False,'Invalid normalization specified'
84 |     end_time = time.time()
85 |     dkf._p(('(Evaluate w/ Imp. Sampling) Validation LL: %.4f [Took %.4f seconds]')%(ll,end_time-start_time))
86 |     return ll
87 | 
88 | 
89 | def sampleGaussian(dkf,mu,logcov):
90 |     return mu + np.random.randn(*mu.shape)*np.exp(0.5*logcov)
91 | 
92 | def sample(dkf, nsamples=100, T=10, additional = {}):
93 |     """
94 |     Sample from the generative model
95 |     """
96 |     assert T>1, 'Sample at least 2 timesteps'
97 |     #Initial sample
98 |     z = np.random.randn(nsamples,1,dkf.params['dim_stochastic']).astype(config.floatX)
99 |     all_zs = [np.copy(z)]
100 |     additional['mu'] = []
101 |     additional['logcov'] = []
102 |     for t in range(T-1):
103 |         mu,logcov = dkf.transition_fxn(z)
104 |         z = sampleGaussian(dkf,mu,logcov).astype(config.floatX)  #fixed: module-level helper above; DKF has no sampleGaussian method
105 |         all_zs.append(np.copy(z))
106 |         additional['mu'].append(np.copy(mu))
107 |         additional['logcov'].append(np.copy(logcov))
108 |     zvec = np.concatenate(all_zs,axis=1)
109 |     additional['mu'] = np.concatenate(additional['mu'], axis=1)
110 |     additional['logcov'] = np.concatenate(additional['logcov'], axis=1)
111 |     return dkf.emission_fxn(zvec), zvec
112 | 
--------------------------------------------------------------------------------
/stinfmodel_fast/learning.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for learning with a DKF object
3 | """
4 | import evaluate as DKF_evaluate
5 | import numpy as np
6 | from utils.misc import saveHDF5
7 | import time
8 | from theano import config
9 | 
10 | def learn(dkf, dataset, mask, epoch_start=0, epoch_end=1000, 
11 |           batch_size=200, shuffle=True,
12 |           savefreq=None, savefile = None,
13 |           dataset_eval = None, mask_eval = None,
14 |           replicate_K = None,
15 |           normalization = 'frame'):
16 |     """
17 |     Train DKF
18 |     """
19 |     assert not dkf.params['validate_only'],'cannot learn in validate only mode'
20 |     assert len(dataset.shape)==3,'Expecting 3D tensor for data'
21 |     assert dataset.shape[2]==dkf.params['dim_observations'],'Dim observations not valid'
22 |     N = dataset.shape[0]
23 |     idxlist = range(N)
24 |     batchlist = np.split(idxlist, range(batch_size,N,batch_size))
25 | 
26 |     bound_train_list,bound_valid_list,bound_tsbn_list,nll_valid_list = [],[],[],[]
27 |     p_norm, g_norm, opt_norm = None, None, None
28 | 
29 |     #Lists used to track quantities for synthetic experiments
30 |     mu_list_train, cov_list_train, mu_list_valid, cov_list_valid = [],[],[],[]
31 |     model_params = {}
32 | 
33 |     #Display frequency (in epochs): print less often on the small synthetic datasets
34 |     if 'synthetic' in dkf.params['dataset']:
35 |         epfreq = 10
36 |     else:
37 |         epfreq = 1
38 | 
39 |     #Set data
40 |     dkf.resetDataset(dataset, mask)
41 |     for epoch in range(epoch_start, epoch_end):
42 |         #Shuffle
43 |         if shuffle:
44 |             np.random.shuffle(idxlist)
45 |             batchlist = np.split(idxlist, range(batch_size,N,batch_size))
46 |         #Always shuffle the order in which batches are presented
47 |         np.random.shuffle(batchlist)
48 | 
49 |         start_time = time.time()
50 |         bound = 0
51 |         for bnum, batch_idx in enumerate(batchlist):
52 |             #batch_idx comes directly from the enumerate above
53 |             batch_bound, p_norm, g_norm, opt_norm, negCLL, KL, anneal = dkf.train_debug(idx=batch_idx)
54 | 
55 |             #Number of frames
56 |             M_sum = mask[batch_idx].sum()
57 |             #Correction for replicating batch
58 |             if replicate_K is not None:
59 |                 batch_bound, negCLL, KL = batch_bound/replicate_K, negCLL/replicate_K, KL/replicate_K
60 |                 M_sum = M_sum/replicate_K
61 |             #Update bound
62 |             bound += batch_bound
63 |             ### Display ###
64 |             if epoch%epfreq==0 and bnum%10==0:
65 |                 if normalization=='frame':
66 |                     bval = batch_bound/float(M_sum)
67 |                 elif normalization=='sequence':
68 |                     bval = batch_bound/float(len(batch_idx))  #fixed: 'X' was undefined here; use the batch's sequence count
69 |                 else:
70 |                     assert False,'Invalid normalization'
71 |                 dkf._p(('Bnum: %d, Batch Bound: %.4f, |w|: %.4f, |dw|: %.4f, |w_opt|: %.4f')%(bnum,bval,p_norm, g_norm, opt_norm)) 
72 |                 dkf._p(('-veCLL:%.4f, KL:%.4f, anneal:%.4f')%(negCLL, KL, anneal))
73 |         if normalization=='frame':
74 |             bound /= (float(mask.sum())/(replicate_K if replicate_K is not None else 1.))  #guard: replicate_K defaults to None
75 |         elif normalization=='sequence':
76 |             bound /= float(N)
77 |         else:
78 |             assert False,'Invalid normalization'
79 |         bound_train_list.append((epoch,bound))
80 |         end_time = time.time()
81 |         if epoch%epfreq==0:
82 |             dkf._p(('(Ep %d) Bound: %.4f [Took %.4f seconds] ')%(epoch, bound, end_time-start_time))
83 | 
84 |         #Save at intermediate stages
85 |         if savefreq is not None and epoch%savefreq==0:
86 |             assert savefile is not None, 'expecting savefile'
87 |             dkf._p(('Saving at epoch %d'%epoch))
88 |             dkf._saveModel(fname = savefile+'-EP'+str(epoch))
89 |             intermediate = {}
90 |             if dataset_eval is not None and mask_eval is not None:
91 |                 tmpMap = {}
92 |                 bound_valid_list.append(
93 |                     (epoch,
94 |                      DKF_evaluate.evaluateBound(dkf, dataset_eval, mask_eval, batch_size=batch_size,
95 |                                                 additional = tmpMap, normalization=normalization)))
96 |                 bound_tsbn_list.append((epoch, tmpMap['tsbn_bound']))
97 |                 nll_valid_list.append(
98 |                     DKF_evaluate.impSamplingNLL(dkf, dataset_eval, mask_eval, batch_size,
99 |                                                 normalization=normalization))
100 |                 intermediate['valid_bound'] = np.array(bound_valid_list)
101 |                 intermediate['train_bound'] = np.array(bound_train_list)
102 |                 intermediate['tsbn_bound'] = np.array(bound_tsbn_list)
103 |                 intermediate['valid_nll'] = np.array(nll_valid_list)
104 |             if 'synthetic' in dkf.params['dataset']:
105 |                 mu_train, cov_train, mu_valid, cov_valid, learned_params = _syntheticProc(dkf, dataset,mask, dataset_eval,mask_eval)
106 |                 if dkf.params['dim_stochastic']==1:
107 |                     mu_list_train.append(mu_train)
108 |                     cov_list_train.append(cov_train)
109 |                     mu_list_valid.append(mu_valid)
110 |                     cov_list_valid.append(cov_valid)
111 |                     intermediate['mu_posterior_train'] = np.concatenate(mu_list_train, axis=2)
112 |                     intermediate['cov_posterior_train'] = np.concatenate(cov_list_train, axis=2)
113 |                     intermediate['mu_posterior_valid'] = np.concatenate(mu_list_valid, axis=2)
114 |                     intermediate['cov_posterior_valid'] = np.concatenate(cov_list_valid, axis=2)
115 |                 else:
116 |                     mu_list_train.append(mu_train[None,:])
117 |                     cov_list_train.append(cov_train[None,:])
118 |                     mu_list_valid.append(mu_valid[None,:])
119 |                     cov_list_valid.append(cov_valid[None,:])
120 |                     intermediate['mu_posterior_train'] = np.concatenate(mu_list_train, axis=0)
121 |                     intermediate['cov_posterior_train'] = np.concatenate(cov_list_train, axis=0)
122 |                     intermediate['mu_posterior_valid'] = np.concatenate(mu_list_valid, axis=0)
123 |                     intermediate['cov_posterior_valid'] = np.concatenate(cov_list_valid, axis=0)
124 |                 for k in dkf.params_synthetic[dkf.params['dataset']]['params']:
125 |                     if k in model_params:
126 |                         model_params[k].append(learned_params[k])
127 |                     else:
128 |                         model_params[k] = [learned_params[k]]
129 |                 for k in dkf.params_synthetic[dkf.params['dataset']]['params']:
130 |                     intermediate[k+'_learned'] = np.array(model_params[k]).squeeze()
131 |             saveHDF5(savefile+'-EP'+str(epoch)+'-stats.h5', intermediate)
132 |             ### Update X in the computational flow graph to point back to the training data
133 |             dkf.resetDataset(dataset, mask)
134 |     #Final information to be collected
135 |     retMap = {}
136 |     retMap['train_bound'] = np.array(bound_train_list)
137 |     retMap['valid_bound'] = np.array(bound_valid_list)
138 |     retMap['tsbn_bound'] = np.array(bound_tsbn_list)
139 |     retMap['valid_nll'] = np.array(nll_valid_list)
140 |     if 'synthetic' in dkf.params['dataset']:
141 |         if dkf.params['dim_stochastic']==1:
142 |             retMap['mu_posterior_train'] = np.concatenate(mu_list_train, axis=2)
143 |             retMap['cov_posterior_train'] = np.concatenate(cov_list_train, axis=2)
144 |             retMap['mu_posterior_valid'] = np.concatenate(mu_list_valid, axis=2)
145 |             retMap['cov_posterior_valid'] = np.concatenate(cov_list_valid, axis=2)
146 |         else:
147 |             retMap['mu_posterior_train'] = np.concatenate(mu_list_train, axis=0)
148 |             retMap['cov_posterior_train'] = np.concatenate(cov_list_train, axis=0)
149 |             retMap['mu_posterior_valid'] = np.concatenate(mu_list_valid, axis=0)
150 |             retMap['cov_posterior_valid'] = np.concatenate(cov_list_valid, axis=0)
151 |         for k in dkf.params_synthetic[dkf.params['dataset']]['params']:
152 |             retMap[k+'_learned'] = np.array(model_params[k])
153 |     return retMap
154 | 
155 | def _syntheticProc(dkf, dataset, mask, dataset_eval, mask_eval):
156 |     """
157 |     Collect statistics on the synthetic dataset
158 |     """
159 |     allmus, alllogcov = [], []
160 |     if dkf.params['dim_stochastic']==1:
161 |         for s in range(10):
162 |             _,mus, logcov = DKF_evaluate.infer(dkf,dataset,mask)
163 |             allmus.append(np.copy(mus))
164 |             alllogcov.append(np.copy(logcov))
165 |         allmus_v, alllogcov_v = [], []
166 |         for s in range(10):
167 |             _,mus, logcov = DKF_evaluate.infer(dkf,dataset_eval,mask_eval)  #fixed: use the eval mask, not the training mask
168 |             allmus_v.append(np.copy(mus))
169 |             alllogcov_v.append(np.copy(logcov))
170 |         mu_train = np.concatenate(allmus,axis=2).mean(2,keepdims=True)
171 |         cov_train= np.exp(np.concatenate(alllogcov,axis=2)).mean(2,keepdims=True)
172 |         mu_valid = np.concatenate(allmus_v,axis=2).mean(2,keepdims=True)
173 |         cov_valid= np.exp(np.concatenate(alllogcov_v,axis=2)).mean(2,keepdims=True)
174 |     else:
175 |         for s in range(10):
176 |             _,mus, logcov = DKF_evaluate.infer(dkf,dataset,mask)
177 |             allmus.append(np.copy(mus)[None,:])
178 |             alllogcov.append(np.copy(logcov)[None,:])
179 |         allmus_v, alllogcov_v = [], []
180 |         for s in range(10):
181 |             _,mus, logcov = DKF_evaluate.infer(dkf,dataset_eval,mask_eval)  #fixed: use the eval mask, not the training mask
182 |             allmus_v.append(np.copy(mus)[None,:])
183 |             alllogcov_v.append(np.copy(logcov)[None,:])
184 |         mu_train = np.concatenate(allmus,axis=0).mean(0)
185 |         cov_train= np.exp(np.concatenate(alllogcov,axis=0)).mean(0)
186 |         mu_valid = np.concatenate(allmus_v,axis=0).mean(0)
187 |         cov_valid= np.exp(np.concatenate(alllogcov_v,axis=0)).mean(0)
188 |     #Extract the learned parameters w/in the generative model
189 |     learned_params = {}
190 |     for k in dkf.params_synthetic[dkf.params['dataset']]['params']:
191 |         learned_params[k] = dkf.tWeights[k+'_W'].get_value()
192 |     return mu_train, cov_train, mu_valid, cov_valid, learned_params
193 | 
--------------------------------------------------------------------------------
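As a quick orientation to how the modules above fit together, here is a minimal, hypothetical driver (not part of the repository): it assumes a working Theano setup and the default hyperparameters from parse_args_dkf, and uses random binary data purely for illustration; real experiments should follow the experiment scripts in this repository.

```python
#Hypothetical end-to-end sketch (not in the repo); shapes and epoch counts are illustrative
import numpy as np
from theano import config
from parse_args_dkf import params                  #default hyperparameters
from stinfmodel_fast.dkf import DKF
import stinfmodel_fast.learning as DKF_learning
import stinfmodel_fast.evaluate as DKF_evaluate

params['data_type']        = 'binary'
params['dim_observations'] = 10

#Random dataset: N sequences of length T; mask[n,t]=1 marks valid time steps
N, T = 64, 12
dataset = (np.random.rand(N, T, params['dim_observations']) > 0.5).astype(config.floatX)
mask    = np.ones((N, T)).astype(config.floatX)

dkf = DKF(params, paramFile='tmp.pkl')             #builds and compiles the theano functions
DKF_learning.learn(dkf, dataset, mask, epoch_start=0, epoch_end=5,
                   batch_size=16, dataset_eval=dataset, mask_eval=mask)
print DKF_evaluate.evaluateBound(dkf, dataset, mask, batch_size=16)
print DKF_evaluate.impSamplingNLL(dkf, dataset, mask, batch_size=16)
obs, zs = DKF_evaluate.sample(dkf, nsamples=5, T=10)
```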
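Relatedly, the closed-form expression used in _getTemporalKL can be sanity-checked numerically. The following self-contained NumPy snippet (again, illustrative only, not part of the repository) compares the per-dimension formula against a Monte-Carlo estimate of E_q[log q(z) - log p(z)] for a single dimension:

```python
#Illustrative check (not in the repo) of the diagonal-Gaussian KL used by _getTemporalKL
import numpy as np

def kl_term(mu_q, cov_q, mu_p, cov_p):
    #0.5*(log cov_p - log cov_q - 1 + cov_q/cov_p + (mu_p-mu_q)**2/cov_p), per dimension
    return 0.5*(np.log(cov_p) - np.log(cov_q) - 1. + cov_q/cov_p + (mu_p - mu_q)**2/cov_p)

mu_q, cov_q, mu_p, cov_p = 0.3, 0.5, -0.1, 2.0
z     = np.random.randn(1000000)*np.sqrt(cov_q) + mu_q     #z ~ q
log_q = -0.5*np.log(2*np.pi*cov_q) - 0.5*(z - mu_q)**2/cov_q
log_p = -0.5*np.log(2*np.pi*cov_p) - 0.5*(z - mu_p)**2/cov_p
#The analytic KL and the Monte-Carlo estimate should agree closely
print kl_term(mu_q, cov_q, mu_p, cov_p), (log_q - log_p).mean()
```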