├── LICENSE ├── README.md ├── datasets ├── Dataset.py ├── Dataset2.py └── __init__.py ├── extra.sh ├── networks ├── layers.py ├── network_rl.py ├── toy_mlp.py └── toy_vae.py ├── paper_grid.sh ├── paper_grid_rl.sh ├── paper_toy.sh ├── plotter ├── Plotter2.py └── __init__.py ├── pythonutils ├── __init__.py └── helpers.py ├── rlenv ├── __init__.py ├── chicken.py └── grid.py ├── rlutils ├── __init__.py ├── helpers.py ├── policies.py └── rollout.py ├── tfutils ├── distributions.py └── helpers.py ├── vae_grid.py └── vae_main.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Thomas Moerland 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Multimodal Transition Dynamics for Model-Based Reinforcement Learning 2 | 3 | Code for reproducing key results in the paper [Learning Multimodal Transition Dynamics for Model-Based Reinforcement Learning](https://arxiv.org/pdf/1705.00470.pdf) by Thomas M. Moerland, Joost Broekens and Catholijn M. Jonker. 4 | 5 | ## Prerequisites 6 | 1. Install recent versions of: 7 | - Python 3 8 | - TensorFlow 1.x (not 2.x: the code relies on `tf.contrib`, which TensorFlow 2 removed) 9 | - NumPy (e.g. `pip install numpy`) 10 | - Matplotlib 11 | 12 | 2. Clone this repository: 13 | ```sh 14 | git clone https://github.com/tmoer/multimodal_varinf.git 15 | ``` 16 | ## Syntax 17 | Example (replace `<logdir>` with the output directory to use): 18 | ```sh 19 | python3 vae_main.py --logdir <logdir> --hpconfig network=1,n_rep=10,var_type='discrete',K=3,N=3,verbose=False 20 | python3 vae_grid.py --logdir <logdir> --hpconfig network=1,n_epochs=75000,n_rep=5,var_type='continuous',z_size=8,n_flow=0,artificial_data=False,use_target_net=True,test_on_policy=True,verbose=False 21 | ``` 22 | For default hyper-parameters, look at the `get_hps()` function in the `vae_grid.py` and `vae_main.py` scripts. 23 | 24 | ## Reproducing Paper Results 25 | Run: 26 | ```sh 27 | bash paper_toy.sh # Sec 4.1 28 | bash paper_grid.sh # Sec 4.2 29 | bash paper_grid_rl.sh # Sec 4.2 30 | ``` 31 | 32 | ## Citation 33 | ``` 34 | @proceedings{moerland2017learning, 35 | author = "Moerland, Thomas M.
and Broekens, Joost and Jonker, Catholijn M.", 36 | note = "arXiv preprint arXiv:1705.00470", 37 | journal = "Scaling Up Reinforcement Learning (SURL) Workshop @ European Machine Learning Conference (ECML)", 38 | title = "{Learning Multimodal Transition Dynamics for Model-Based Reinforcement Learning}", 39 | year = "2017" 40 | } 41 | ``` 42 | 43 | -------------------------------------------------------------------------------- /datasets/Dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Dataset class 4 | @author: thomas 5 | """ 6 | import numpy as np 7 | import random 8 | 9 | class Dataset: 10 | ''' Bimodal, linear data ''' 11 | 12 | def __init__(self,datasize): 13 | self.datasize = datasize 14 | N = int(datasize/2) 15 | X1 = self.new_x_data(N) 16 | X2 = self.new_x_data(N) 17 | Y1 = -4 * (self.X1) + 5 + np.random.normal(0,0.1,N) 18 | Y2 = 4 * self.X2 + 1.6 + np.random.normal(0,0.1,N) 19 | 20 | self.X = np.append(X1,X2) 21 | self.Y = np.append(Y1,Y2) 22 | self.order = random.sample(range(datasize),datasize) 23 | 24 | def new_x_data(self,M): 25 | return np.random.uniform(-1,1,M) 26 | 27 | def next_batch_epoch(self,M): 28 | ''' epoch batch ''' 29 | while len(self.order) < M: 30 | self.order.extend(random.sample(range(self.datasize),self.datasize)) 31 | ind = self.order[0:M] 32 | del self.order[0:M] 33 | return self.X[ind], self.Y[ind] 34 | 35 | def next_batch_random(self,M): 36 | ''' random batch ''' 37 | ind = random.sample(range(self.datasize),M) 38 | return self.X[ind], self.Y[ind] 39 | 40 | 41 | -------------------------------------------------------------------------------- /datasets/Dataset2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Dataset class 4 | @author: thomas 5 | """ 6 | import numpy as np 7 | import random 8 | 9 | class Dataset: 10 | ''' Trimodal, non-linear data ''' 11 | 12 | def 
__init__(self,datasize): 13 | self.datasize = datasize 14 | self.X = self.new_x_data(datasize) 15 | 16 | def data_fn(x): 17 | if x < -0.3 : 18 | y = 2.5 + np.random.normal(0,0.1,1) 19 | elif x < 0.3: 20 | # bimodal 21 | if np.random.binomial(1,0.2): 22 | y = 4 * x + 4 + np.random.normal(0,0.1,1) 23 | else: 24 | y = -4 * x + 1.5 + np.random.normal(0,0.1,1) 25 | else: 26 | # trimodal 27 | if np.random.binomial(1,0.3): 28 | y = np.log(x+1)+5 + np.random.normal(0,0.1,1) 29 | elif np.random.binomial(1,0.5): 30 | y = -1 * x + 0.2 + np.random.normal(0,0.1,1) 31 | else: 32 | y = 5*x**2 + np.random.normal(0,0.1,1) 33 | return y 34 | 35 | self.Y = np.zeros(datasize) 36 | for i in range(datasize): 37 | self.Y[i] = data_fn(self.X[i]) 38 | self.order = random.sample(range(datasize),datasize) 39 | 40 | def new_x_data(self,M): 41 | return np.random.uniform(-1,1,M) 42 | 43 | def next_batch_epoch(self,M): 44 | ''' epoch batch ''' 45 | while len(self.order) < M: 46 | self.order.extend(random.sample(range(self.datasize),self.datasize)) 47 | ind = self.order[0:M] 48 | del self.order[0:M] 49 | return self.X[ind], self.Y[ind] 50 | 51 | def next_batch_random(self,M): 52 | ''' random batch ''' 53 | ind = random.sample(range(self.datasize),M) 54 | return self.X[ind], self.Y[ind] 55 | 56 | 57 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 2 14:48:24 2017 4 | 5 | @author: thomas 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /extra.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Starting extra simulations, with multiple layer VAEs" 4 | python3 vae_main.py --hpconfig network=1,n_rep=10,var_type='continuous-discrete',depth=2,K=3,N=3,z_size=3,n_flow=5,verbose=False && 5 | python3 
vae_grid.py --hpconfig artificial_data=True,network=1,n_epochs=75000,n_rep=5,var_type='continuous-discrete',depth=2,K=8,N=4,z_size=8,n_flow=6,verbose=False && 6 | python3 vae_grid.py --hpconfig artificial_data=False,use_target_net=True,test_on_policy=True,network=1,n_epochs=75000,n_rep=5,var_type='continuous-discrete',depth=2,z_size=8,n_flow=6,ar=False,verbose=False && 7 | 8 | echo "Starting extra simulations, with some hyperloop examples" 9 | python3 vae_main.py --hpconfig network=1,n_rep=10,var_type='continuous-discrete',depth=2,K=3,N=3,z_size=3,n_flow=5,verbose=False && 10 | python3 vae_main.py --hpconfig network=1,n_rep=10,var_type='discrete',K=3,N=3,loop_hyper=True,item1='lr',seq1='0.005-0.0005-0.00005',item2='kl_min',seq2='0-0.05-0.10-0.20',verbose=False && 11 | python3 vae_grid.py --hpconfig var_type='discrete',K=4,N=8,loop_hyper=True,n_epochs=50000,n_rep=3,item1='lr_init',seq1='0.005-0.0005-0.00005',item2='kl_min',seq2='0-0.05-0.10-0.30-1.0',verbose=False && 12 | python3 vae_grid.py --hpconfig var_type='discrete',K=4,N=8,loop_hyper=True,n_rep=3,lr=0.0005,kl_min=0.30,item1='out_lik',seq1='discrete-discretized_logistic',item2='ignore_sigma_outcome',seq2='True-False',verbose=False 13 | # the last command must not end in '&&': a dangling '&&' before EOF is a bash syntax error, so none of the chained python runs would start 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /networks/layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Tensorflow variational inference layers 4 | @author: thomas 5 | """ 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | import numpy as np 9 | from tfutils.helpers import split 10 | from tfutils.distributions import gumbel_softmax, DiagonalGaussian 11 | 12 | class Latent_Layer(object): 13 | ''' Latent Layer object: one stochastic layer of the ladder; up() stores the inference-side parameters, down() computes the prior, samples z and the KL terms ''' 14 | 15 | def __init__(self,hps,var_type,depth,is_top=True): 16 | self.hps = hps 17 | self.var_type = var_type # 'discrete' or 'continuous' 18 | self.depth = depth 19 | self.is_top = is_top # top layer: down() uses the bottom-up q(z) parameters without a top-down correction 20 | 21 | def
up(self,h): 22 | h_up = slim.fully_connected(h,self.hps.h_size,activation_fn=tf.nn.relu) 23 | 24 | if self.var_type == 'discrete': 25 | # q_z 26 | self.K = K = self.hps.K 27 | self.N = N = self.hps.N 28 | h_up = slim.fully_connected(h_up,K*N,activation_fn=None) 29 | self.logits_q = tf.reshape(h_up,[-1,K]) # unnormalized logits for N separate K-categorical distributions (shape=(batch_size*N,K)) 30 | 31 | h_out = slim.fully_connected(h_up,self.hps.h_size,activation_fn=None) 32 | 33 | elif self.var_type == 'continuous': 34 | hps = self.hps 35 | z_size = hps.z_size 36 | h_size = hps.h_size 37 | 38 | h_up = slim.fully_connected(h_up,h_size,activation_fn=None) 39 | h_up = slim.fully_connected(h,z_size*2 + h_size,activation_fn=None) 40 | self.qz_mean, self.qz_logsd, h_out = split(h_up, 1, [z_size, z_size, h_size]) 41 | 42 | if self.hps.resnet: 43 | return h + 0.2 * h_out 44 | else: 45 | return h_out 46 | 47 | def down(self,h,is_training,temp,lamb): 48 | 49 | h_down = slim.fully_connected(h,self.hps.h_size,tf.nn.relu) 50 | 51 | if self.var_type == 'discrete': 52 | self.K = K = self.hps.K 53 | self.N = N = self.hps.N 54 | if self.is_top: 55 | h_down = slim.fully_connected(h_down,K*N + self.hps.h_size,activation_fn=None) 56 | logits_p, h_det = split(h_down,1,[K*N] + [self.hps.h_size]) 57 | logits_q = self.logits_q 58 | else: 59 | # top down inference 60 | h_down = slim.fully_connected(h_down,2*K*N + self.hps.h_size,activation_fn=None) 61 | logits_p,logits_r, h_det = split(h_down,1,[K*N]*2 + [self.hps.h_size]) 62 | logits_q = self.logits_q + tf.reshape(logits_r,[-1,K]) 63 | 64 | self.logits_p = logits_p = tf.reshape(logits_p,[-1,K]) 65 | self.p_z = tf.nn.softmax(logits_p) 66 | self.q_z = tf.nn.softmax(logits_q) 67 | 68 | # Sample z 69 | z = tf.cond(is_training, 70 | lambda: tf.reshape(gumbel_softmax(logits_q,temp,hard=False),[-1,N,K]), 71 | lambda: tf.reshape(gumbel_softmax(logits_p,temp,hard=False),[-1,N,K]) 72 | ) 73 | 74 | # KL divergence 75 | kl_discrete = 
tf.reshape(self.q_z*(tf.log(self.q_z+1e-20)-tf.log(self.p_z+1e-20)),[-1,N,K]) # elementwise q*(log q - log p) per latent and category 76 | kl = tf.reduce_sum(kl_discrete,[2]) # sum over number of categories 77 | 78 | # pass on 79 | h_down = tf.concat([slim.flatten(z),h_det],1) 80 | h_out = slim.fully_connected(h_down,self.hps.h_size,activation_fn=tf.nn.relu) 81 | 82 | elif self.var_type == 'continuous': 83 | hps = self.hps 84 | z_size = hps.z_size 85 | h_size = hps.h_size 86 | 87 | h_down = slim.fully_connected(h_down,self.hps.h_size,activation_fn=None) 88 | 89 | if self.is_top: 90 | h_down = slim.fully_connected(h_down, 2 * z_size + h_size,activation_fn=None) 91 | pz_mean, pz_logsd, h_det = split(h_down, 1, [z_size] * 2 + [h_size] * 1) 92 | qz_mean = self.qz_mean 93 | qz_logsd = self.qz_logsd 94 | else: 95 | # top down inference 96 | h_down = slim.fully_connected(h_down, 4 * z_size + h_size,activation_fn=None) 97 | pz_mean, pz_logsd, rz_mean, rz_logsd, h_det = split(h_down, 1, [z_size] * 4 + [h_size] * 1) 98 | qz_mean = self.qz_mean + rz_mean 99 | qz_logsd = self.qz_logsd + rz_logsd # bottom-up q parameters corrected by the top-down ones 100 | 101 | # identify distributions 102 | if self.hps.ignore_sigma_latent: # NOTE(review): second DiagonalGaussian argument is 2*logsd below, so it looks like a log-variance; confirm in tfutils.distributions 103 | prior = DiagonalGaussian(pz_mean,tf.zeros(tf.shape(pz_mean))) 104 | posterior = DiagonalGaussian(qz_mean,tf.zeros(tf.shape(qz_mean))) 105 | else: 106 | prior = DiagonalGaussian(pz_mean,2*pz_logsd) 107 | posterior = DiagonalGaussian(qz_mean,2*qz_logsd) 108 | 109 | # sample z 110 | z = tf.cond(is_training, # posterior sample while training, prior sample otherwise 111 | lambda: posterior.sample, 112 | lambda: prior.sample) 113 | 114 | # KL Divergence with flow 115 | z, kl = tf.cond(is_training, # flow and KL only while training; kl_test passes z through with a zero KL 116 | lambda: kl_train(z,prior,posterior,hps), 117 | lambda: kl_test(z)) 118 | 119 | # output # pass on 120 | h_down = tf.concat([slim.flatten(z),h_det],1) 121 | h_out = slim.fully_connected(h_down,self.hps.h_size,activation_fn=tf.nn.relu) 122 | 123 | # Manipulate KL divergence 124 | kl_sample = kl # keep the unmodified KL before free-bits / warm-up adjust kl 125 | if self.hps.kl_min > 0: # use free-bits/nats (Kingma, 2016) 126 | kl_ave = tf.reduce_mean(kl,[0],keep_dims=True) # average over
mini-batch 127 | kl_max = tf.maximum(kl_ave,self.hps.kl_min) 128 | kl = tf.tile(kl_max,[tf.shape(kl)[0],1]) # shape: [batch_size * k,latent_size] 129 | if self.hps.use_lamb: # use warm-up 130 | kl = lamb * kl 131 | 132 | kl_sum = tf.reduce_sum(kl,[1]) # shape [batch_size*k,] 133 | kl_sample = tf.reduce_sum(kl_sample,[1]) 134 | 135 | if self.hps.resnet: 136 | return h + 0.2 * h_out, kl_sum 137 | else: 138 | return h_out, kl_sum, kl_sample 139 | 140 | def kl_test(z): 141 | kl = tf.zeros([1,1]) 142 | return z, kl 143 | 144 | def kl_train(z,prior,posterior,hps): 145 | # push prior through AR layer 146 | logqs = posterior.logps(z) 147 | if hps.n_flow > 0: 148 | nice_layers = [] 149 | print('Does this print') 150 | for i in range(hps.n_flow): 151 | nice_layers.append(nice_layer(tf.shape(z),hps,'nice{}'.format(i),ar=hps.ar)) 152 | 153 | for i,layer in enumerate(nice_layers): 154 | z,log_det = layer.forward(z) 155 | logqs += log_det 156 | 157 | # track the KL divergence after transformation 158 | logps = prior.logps(z) 159 | kl = logqs - logps 160 | return z, kl 161 | 162 | ### Autoregressive layers 163 | class nice_layer: 164 | ''' Autoregressive layer with easy inverse (real-nvp layer) ''' 165 | 166 | def __init__(self,input_shape,hps,name,n_hidden=20,pattern=2,ar=False): 167 | self.name = name # variable scope 168 | self.batch_size = input_shape[0] 169 | self.latent_size = input_shape[1] 170 | self.n_hidden = n_hidden 171 | self.hps = hps 172 | self.ar = ar 173 | # calculate mask 174 | self.mask = self._get_mask() 175 | 176 | def _get_mask(self): 177 | numbers = np.random.binomial(1,0.5,[1,self.hps.z_size]) 178 | mask = tf.tile(tf.constant(numbers,dtype='bool'),[self.batch_size,1]) 179 | return mask 180 | 181 | def _get_mu_and_sigma(self,z): 182 | # predict mu and sigma 183 | z_mask = z * tf.to_float(self.mask) 184 | mid = slim.fully_connected(z_mask,self.n_hidden,activation_fn=tf.nn.relu) 185 | mu_pred = slim.fully_connected(mid,self.hps.z_size,activation_fn=None) 186 
| log_sigma_pred = slim.fully_connected(mid,self.hps.z_size,activation_fn=None) 187 | 188 | # inverse mask the outcome 189 | mu_out = mu_pred * tf.to_float(tf.logical_not(self.mask)) 190 | log_sigma_out = log_sigma_pred * tf.to_float(tf.logical_not(self.mask)) 191 | return mu_out, log_sigma_out 192 | 193 | def forward(self,z): 194 | if not self.ar: 195 | mu,log_sigma = self._get_mu_and_sigma(z) 196 | else: 197 | # permute z 198 | z = tf.reshape(z,[-1]+[1]*self.hps.z_size) 199 | perm = np.random.permutation(self.hps.z_size)+1 200 | z = tf.transpose(z,np.append([0],perm)) 201 | z = tf.reshape(z,[-1,self.hps.z_size]) 202 | mu,log_sigma = ar_layer(z,self.hps,n_hidden=self.n_hidden) 203 | log_sigma = tf.clip_by_value(log_sigma,-5,5) 204 | if not self.hps.ignore_sigma_flow: 205 | y = z * tf.exp(log_sigma) + mu 206 | log_det = -1 * log_sigma 207 | else: 208 | y = z + mu 209 | log_det = 0.0 210 | return y,log_det 211 | 212 | def backward(self,y): 213 | mu,log_sigma = self._get_mu_and_sigma(y) 214 | log_sigma = tf.clip_by_value(log_sigma,-5,5) 215 | if not self.hps.ignore_sigma_flow: 216 | z = (y - mu)/tf.exp(log_sigma) 217 | log_det = log_sigma 218 | else: 219 | z = y - mu 220 | log_det = 0.0 221 | return z,log_det 222 | 223 | def ar_layer(z0,hps,n_hidden=10): 224 | ''' old iaf layer ''' 225 | # Repeat input 226 | z_rep = tf.reshape(tf.tile(z0,[1,hps.z_size]),[-1,hps.z_size]) 227 | 228 | # make mask 229 | mask = tf.sequence_mask(tf.range(hps.z_size),hps.z_size)[None,:,:] 230 | mask = tf.reshape(tf.tile(mask,[tf.shape(z0)[0],1,1]),[-1,hps.z_size]) 231 | 232 | # predict mu and sigma 233 | z_mask = z_rep * tf.to_float(mask) 234 | mid = slim.fully_connected(z_mask,n_hidden,activation_fn=tf.nn.relu) 235 | pars = slim.fully_connected(mid,2,activation_fn=None) 236 | pars = tf.reshape(pars,[-1,hps.z_size,2]) 237 | mu, log_sigma = tf.unstack(pars,axis=2) 238 | return mu, log_sigma -------------------------------------------------------------------------------- 
/networks/network_rl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 2 14:48:24 2017 4 | 5 | @author: thomas 6 | """ 7 | #from layers import Latent_Layer 8 | import sys 9 | import os.path 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) # parent directory, so tfutils is importable 11 | sys.path.append(os.path.abspath(os.path.dirname(__file__))) # this directory, so layers is importable 12 | 13 | import tensorflow as tf 14 | import tensorflow.contrib.slim as slim 15 | import numpy as np 16 | from tfutils.helpers import repeat_v2 17 | from tfutils.distributions import logsumexp, discretized_logistic 18 | from layers import Latent_Layer 19 | 20 | class Network(object): 21 | ''' VAE & RL template: a conditional VAE for y given (x, a) plus a Q-network (and optional target network) in one graph ''' 22 | 23 | def __init__(self,hps,state_dim,binsize=6): 24 | # binsize = the number of discrete categories per dimension of the state (x and y below). 25 | # Input and output are normalized over this quantity. 26 | 27 | # placeholders 28 | self.x = x = tf.placeholder("float32", shape=[None,state_dim]) 29 | self.y = y = tf.placeholder("float32", shape=[None,state_dim]) 30 | self.a = a = tf.placeholder("float32", shape=[None,1]) 31 | self.Qtarget = Qtarget = tf.placeholder("float32", shape=[None,1]) # regression target for Q(x,a), one value per row 32 | 33 | self.is_training = is_training = tf.placeholder("bool") # if True: sample from q, else sample from p 34 | self.k = k = tf.placeholder('int32') # number of importance samples 35 | self.temp = temp = tf.Variable(5.0,name='temperature',trainable=False) # Temperature for discrete latents 36 | self.lamb = lamb = tf.Variable(1.0,name="lambda",trainable=False) # Lambda for KL annealing 37 | 38 | xa = tf.concat([x/binsize,a],axis=1) # state scaled by binsize; action appended unscaled 39 | # Importance sampling: repeats along second dimension 40 | xa_rep = repeat_v2(xa,k) 41 | y_rep = repeat_v2(y/binsize,k) 42 | 43 | # RL part of the graph 44 | with tf.variable_scope('q_net'): # Q-network on the unscaled state x 45 | rl1 = slim.fully_connected(x,50,tf.nn.relu) 46 | rl2 = slim.fully_connected(rl1,50,tf.nn.relu) 47 | rl3 =
slim.fully_connected(rl2,50,activation_fn=None) 48 | self.Qsa = Qsa = slim.fully_connected(rl3,4,activation_fn=None) 49 | 50 | if hps.use_target_net: 51 | 52 | with tf.variable_scope('target_net'): 53 | rl1_t = slim.fully_connected(x,50,tf.nn.relu) 54 | rl2_t = slim.fully_connected(rl1_t,50,tf.nn.relu) 55 | rl3_t = slim.fully_connected(rl2_t,50,activation_fn=None) 56 | self.Qsa_t = slim.fully_connected(rl3_t,4,activation_fn=None) 57 | 58 | copy_ops = [] 59 | q_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='q_net') 60 | tar_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net') 61 | 62 | for tar,q in zip(q_var,tar_var): 63 | copy_op = q.assign(tar) 64 | copy_ops.append(copy_op) 65 | self.copy_op = tf.group(*copy_ops, name='copy_op') 66 | 67 | a_onehot = tf.one_hot(tf.to_int32(tf.squeeze(a,axis=1)),4,1.0,0.0) 68 | Qs = tf.reduce_sum(a_onehot*Qsa,reduction_indices=1) ## identify Qsa based on a 69 | self.rl_cost = rl_cost = tf.nn.l2_loss(Qs - Qtarget) 70 | # Batch norm: skip for now 71 | 72 | # Encoder x,y --> h 73 | xy = tf.concat([xa_rep,y_rep],1) # concatenate along last dim 74 | h_up = slim.fully_connected(xy,hps.h_size,tf.nn.relu) 75 | 76 | # Initialize ladders 77 | layers = [] 78 | for i in range(hps.depth): 79 | layers.append(Latent_Layer(hps,hps.var_type[i],i)) 80 | 81 | # Ladder up 82 | for i,layer in enumerate(layers): 83 | h_up = layer.up(h_up) 84 | 85 | # Ladder down 86 | # Prior x --> p_z 87 | h_down = slim.fully_connected(xa_rep,hps.h_size,tf.nn.relu) 88 | kl_sum = 0.0 89 | kl_sample = 0.0 90 | for i,layer in reversed(list(enumerate(layers))): 91 | h_down, kl_cur, kl_sam = layer.down(h_down,is_training,temp,lamb) 92 | kl_sum += kl_cur 93 | kl_sample += kl_sam 94 | 95 | # Decoder: x,z --> y 96 | xz = tf.concat([slim.flatten(h_down),xa_rep],1) 97 | dec1 = slim.fully_connected(xz,250,tf.nn.relu) 98 | dec2 = slim.fully_connected(dec1,250,tf.nn.relu) 99 | dec3 = slim.fully_connected(dec2,250,activation_fn=None) 100 | mu_y = 
slim.fully_connected(dec3,state_dim,activation_fn=None) 101 | 102 | if hps.ignore_sigma_outcome: 103 | log_dec_noise = tf.zeros(tf.shape(mu_y)) 104 | else: 105 | log_dec_noise = slim.fully_connected(dec3,1,activation_fn=None) 106 | 107 | # p(y|x,z) 108 | if hps.out_lik == 'normal': 109 | dec_noise = tf.exp(tf.clip_by_value(log_dec_noise,-10,10)) 110 | outdist = tf.contrib.distributions.Normal(mu_y,dec_noise) 111 | self.log_py_x = log_py_x = tf.reduce_sum(outdist.log_prob(y_rep),axis=1) 112 | self.nats = -1*tf.reduce_mean(logsumexp(tf.reshape(log_py_x - kl_sample,[-1,k])) - tf.log(tf.to_float(k))) 113 | y_sample = outdist.sample() if not hps.ignore_sigma_outcome else mu_y 114 | self.y_sample = tf.to_int32(tf.round(tf.clip_by_value(y_sample,0,1)*binsize)) 115 | elif hps.out_lik == 'discretized_logistic': 116 | self.log_py_x = log_py_x = tf.reduce_sum(discretized_logistic(mu_y,log_dec_noise,binsize=1,sample=y_rep),axis=1) 117 | outdist = tf.contrib.distributions.Logistic(loc=mu_y,scale = tf.exp(log_dec_noise)) 118 | self.nats = -1*tf.reduce_mean(logsumexp(tf.reshape(tf.reduce_sum(outdist.log_prob(y_rep),axis=1) - kl_sample,[-1,k]))- tf.log(tf.to_float(k))) 119 | y_sample = outdist.sample() if not hps.ignore_sigma_outcome else mu_y 120 | self.y_sample = tf.to_int32(tf.round(tf.clip_by_value(y_sample,0,1)*binsize)) 121 | elif hps.out_lik == 'discrete': 122 | logits_y = slim.fully_connected(dec3,state_dim*(binsize+1),activation_fn=None) 123 | logits_y = tf.reshape(logits_y,[-1,state_dim,binsize+1]) 124 | disc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_y,labels=tf.to_int32(tf.round(y_rep*6))) 125 | self.log_py_x = log_py_x = -tf.reduce_sum(disc_loss,[1]) 126 | self.nats = -1*tf.reduce_mean(logsumexp(tf.reshape(log_py_x - kl_sample,[-1,k])) - tf.log(tf.to_float(k))) 127 | outdist = tf.contrib.distributions.Categorical(logits=logits_y) 128 | self.y_sample = outdist.sample() if not hps.ignore_sigma_outcome else tf.argmax(logits_y,axis=2) 129 | 130 | 
# To display 131 | self.kl = tf.reduce_mean(kl_sum) 132 | 133 | # ELBO 134 | log_divergence = tf.reshape(log_py_x - kl_sum,[-1,k]) # shape [batch_size,k] 135 | if np.abs(hps.alpha-1.0)>1e-3: # use Renyi alpha-divergence 136 | log_divergence = log_divergence * (1-hps.alpha) 137 | logF = logsumexp(log_divergence) 138 | self.elbo = elbo = tf.reduce_mean(logF - tf.log(tf.to_float(k)))/ (1-hps.alpha) 139 | else: 140 | # use KL divergence: plain average of log p(y|x,z) - KL over the batch and the k samples 141 | self.elbo = elbo = tf.reduce_mean(log_divergence) 142 | self.loss = loss = -elbo 143 | 144 | ### Optimizer 145 | self.lr = lr = tf.Variable(0.001,name="learning_rate",trainable=False) 146 | global_step = tf.Variable(0,name='global_step',trainable=False) 147 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 148 | 149 | if hps.max_grad is not None: # clip each gradient to norm hps.max_grad 150 | grads_and_vars = optimizer.compute_gradients(loss) 151 | for idx, (grad, var) in enumerate(grads_and_vars): 152 | if grad is not None: 153 | grads_and_vars[idx] = (tf.clip_by_norm(grad, hps.max_grad), var) 154 | self.train_op = optimizer.apply_gradients(grads_and_vars,global_step=global_step) # advance global_step as the minimize() branch does 155 | self.grads_and_vars = grads_and_vars 156 | else: 157 | self.train_op = optimizer.minimize(loss,global_step=global_step) 158 | self.grads_and_vars = tf.constant(0) 159 | 160 | self.train_op_rl = optimizer.minimize(rl_cost) # Q-network update; reuses the same Adam optimizer and learning-rate variable 161 | self.init_op=tf.global_variables_initializer() -------------------------------------------------------------------------------- /networks/toy_mlp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 2 14:48:24 2017 4 | 5 | @author: thomas 6 | """ 7 | #from layers import Latent_Layer 8 | import sys 9 | import os.path 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 11 | sys.path.append(os.path.abspath(os.path.dirname(__file__))) 12 | 13 | import tensorflow as tf 14 | import tensorflow.contrib.slim as slim 15 | from tfutils.helpers import repeat_v2 16 | from
tfutils.distributions import logsumexp, DiagonalGaussian 17 | 18 | class Network(object): 19 | ''' Generic VAE template: toy MLP baseline mapping x (plus Gaussian noise z unless hps.deterministic) to y; there is no encoder, so the KL term is zero ''' 20 | 21 | def __init__(self,hps): 22 | 23 | # placeholders 24 | self.x = x = tf.placeholder("float32", shape=[None,1]) 25 | self.y = y = tf.placeholder("float32", shape=[None,1]) 26 | self.is_training = is_training = tf.placeholder("bool") # if True: sample from q, else sample from p 27 | self.k = k = tf.placeholder('int32') # number of importance samples 28 | self.temp = temp = tf.Variable(5.0,name='temperature',trainable=False) # Temperature for discrete latents 29 | self.lamb = lamb = tf.Variable(1.0,name="lambda",trainable=False) # Lambda for KL annealing 30 | 31 | # Importance sampling: repeats along second dimension 32 | x_rep = repeat_v2(x,k) 33 | y_rep = repeat_v2(y,k) 34 | 35 | if not hps.deterministic: 36 | # stochastic noise input to network 37 | z_dist = DiagonalGaussian(tf.zeros([tf.shape(x_rep)[0],hps.z_size]),tf.zeros([tf.shape(x_rep)[0],hps.z_size])) # zero mean; NOTE(review): second argument presumably a log-variance (zeros -> unit variance), confirm in tfutils.distributions 38 | z = z_dist.sample 39 | logpz = tf.reduce_sum(z_dist.logps(z),axis=1) 40 | xz = tf.concat([z,x_rep],1) 41 | else: 42 | xz = x_rep 43 | logpz = 0.0 44 | hps.ignore_sigma_outcome = True # NOTE(review): mutates the caller's hps object, not just this network 45 | 46 | # Decoder: x,z --> y 47 | dec1 = slim.fully_connected(xz,50,tf.nn.relu) 48 | dec2 = slim.fully_connected(dec1,50,tf.nn.relu) 49 | dec3 = slim.fully_connected(dec2,50,activation_fn=None) 50 | mu_y = slim.fully_connected(dec3,1,activation_fn=None) 51 | 52 | if hps.ignore_sigma_outcome: 53 | log_dec_noise = tf.zeros(tf.shape(mu_y)) 54 | else: 55 | log_dec_noise = slim.fully_connected(dec3,1,activation_fn=None) 56 | 57 | kl_sum = tf.zeros(tf.shape(mu_y)) # no approximate posterior here, so the reported KL is zero 58 | # p(y|x,z) 59 | dec_noise = tf.exp(tf.clip_by_value(log_dec_noise,-10,10)) 60 | outdist = tf.contrib.distributions.Normal(mu_y,dec_noise) 61 | self.log_py_x = log_py_x = tf.reduce_sum(outdist.log_prob(y_rep),axis=1) 62 | self.nats = -1*tf.reduce_mean(logsumexp(tf.reshape(log_py_x + logpz,[-1,k])) - tf.log(tf.to_float(k))) # NOTE(review): z is sampled from p(z) itself, so the importance weight would be p(y|x,z) alone; adding logpz looks like it biases this estimate, confirm intent 63 | 64 | if
hps.ignore_sigma_outcome: 65 | self.y_sample = mu_y 66 | else: 67 | self.y_sample = outdist.sample() 68 | # To display 69 | self.kl = tf.reduce_mean(kl_sum) 70 | 71 | # ELBO 72 | log_divergence = tf.reshape(log_py_x + logpz,[-1,k]) # shape [batch_size,k]; NOTE(review): z ~ p(z) here, so including logpz in the importance weight looks questionable, confirm intent 73 | logF = logsumexp(log_divergence) 74 | self.elbo = elbo = tf.reduce_mean(logF - tf.log(tf.to_float(k))) # log-mean-exp over the k samples, averaged over the batch 75 | self.loss = loss = -elbo 76 | 77 | ### Optimizer 78 | self.lr = lr = tf.Variable(0.001,name="learning_rate",trainable=False) 79 | global_step = tf.Variable(0,name='global_step',trainable=False) 80 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 81 | self.train_op = optimizer.minimize(loss,global_step=global_step) 82 | self.init_op=tf.global_variables_initializer() -------------------------------------------------------------------------------- /networks/toy_vae.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 2 14:48:24 2017 4 | 5 | @author: thomas 6 | """ 7 | #from layers import Latent_Layer 8 | import sys 9 | import os.path 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 11 | sys.path.append(os.path.abspath(os.path.dirname(__file__))) 12 | 13 | import tensorflow as tf 14 | import tensorflow.contrib.slim as slim 15 | import numpy as np 16 | from tfutils.helpers import repeat_v2 17 | from tfutils.distributions import logsumexp, discretized_logistic 18 | from layers import Latent_Layer 19 | 20 | class Network(object): 21 | ''' VAE template: conditional VAE with an (x,y) encoder, a ladder of Latent_Layer objects and an (x,z) -> y decoder ''' 22 | 23 | def __init__(self,hps): 24 | # placeholders 25 | self.x = x = tf.placeholder("float32", shape=[None,1]) 26 | self.y = y = tf.placeholder("float32", shape=[None,1]) 27 | self.is_training = is_training = tf.placeholder("bool") # if True: sample from q, else sample from p 28 | self.k = k = tf.placeholder('int32') # number of importance samples 29 | self.temp = temp = tf.Variable(5.0,name='temperature',trainable=False) #
Temperature for discrete latents 30 | self.lamb = lamb = tf.Variable(1.0,name="lambda",trainable=False) # Lambda for KL annealing 31 | 32 | # Importance sampling: repeats along second dimension 33 | x_rep = repeat_v2(x,k) 34 | y_rep = repeat_v2(y,k) 35 | 36 | # Encoder x,y --> h 37 | xy = tf.concat([x_rep,y_rep],1) # concatenate along last dim 38 | h_up = slim.fully_connected(xy,hps.h_size,tf.nn.relu) 39 | 40 | # Initialize ladders 41 | layers = [] 42 | for i in range(hps.depth): 43 | layers.append(Latent_Layer(hps=hps,var_type=hps.var_type[i],depth=i,is_top=(i==(hps.depth-1)))) # only the last layer is the top of the ladder 44 | 45 | # Ladder up 46 | for i,layer in enumerate(layers): 47 | h_up = layer.up(h_up) 48 | 49 | # Prior x --> p_z 50 | h_down = slim.fully_connected(x_rep,hps.h_size,tf.nn.relu) 51 | kl_sum = 0.0 52 | kl_sample = 0.0 53 | # Ladder down 54 | for i,layer in reversed(list(enumerate(layers))): # top (last) layer first 55 | h_down, kl_cur, kl_sam = layer.down(h_down,is_training,temp,lamb) 56 | kl_sum += kl_cur # KL as used in the loss (after free-bits / warm-up, if enabled) 57 | kl_sample += kl_sam # unmodified KL, used for the nats estimate 58 | 59 | # Decoder: x,z --> y 60 | xz = tf.concat([slim.flatten(h_down),x_rep],1) 61 | dec1 = slim.fully_connected(xz,50,tf.nn.relu) 62 | dec2 = slim.fully_connected(dec1,50,tf.nn.relu) 63 | dec3 = slim.fully_connected(dec2,50,activation_fn=None) 64 | mu_y = slim.fully_connected(dec3,1,activation_fn=None) 65 | if hps.ignore_sigma_outcome: 66 | log_dec_noise = tf.zeros(tf.shape(mu_y)) 67 | else: 68 | log_dec_noise = slim.fully_connected(dec3,1,activation_fn=None) 69 | 70 | # p(y|x,z) 71 | if hps.out_lik == 'normal': 72 | dec_noise = tf.exp(tf.clip_by_value(log_dec_noise,-10,10)) 73 | outdist = tf.contrib.distributions.Normal(mu_y,dec_noise) 74 | self.log_py_x = log_py_x = tf.reduce_sum(outdist.log_prob(y_rep),axis=1) 75 | self.nats = -1*tf.reduce_mean(logsumexp(tf.reshape(log_py_x - kl_sample,[-1,k])) - tf.log(tf.to_float(k))) # negative log-mean-exp over the k importance samples, averaged over the batch 76 | elif hps.out_lik == 'discretized_logistic': 77 | self.log_py_x = log_py_x =
tf.reduce_sum(discretized_logistic(mu_y,log_dec_noise,binsize=1,sample=y_rep),axis=1) 78 | outdist = tf.contrib.distributions.Logistic(loc=mu_y,scale = tf.exp(log_dec_noise)) 79 | self.nats = -1*tf.reduce_mean(logsumexp(tf.reshape(tf.reduce_sum(outdist.log_prob(y_rep),axis=1) - kl_sample,[-1,k]))- tf.log(tf.to_float(k))) 80 | elif hps.out_lik == 'squared_error': 81 | hps.ignore_sigma_outcome = True # forces y_sample = mu_y below (this branch defines no outdist); NOTE(review): also mutates the caller's hps 82 | self.log_py_x = log_py_x = -tf.reduce_sum(tf.pow(mu_y - y_rep, 2),axis=1) # Gaussian loglik has minus in front 83 | self.nats = tf.zeros([1]) # no likelihood estimate for squared error 84 | 85 | # sample y 86 | if hps.ignore_sigma_outcome: 87 | self.y_sample = mu_y 88 | else: 89 | self.y_sample = outdist.sample() 90 | 91 | # To display KL 92 | self.kl = tf.reduce_mean(kl_sum) 93 | 94 | # ELBO 95 | log_divergence = tf.reshape(log_py_x - kl_sum,[-1,k]) # shape [batch_size,k] 96 | if np.abs(hps.alpha-1.0)>1e-3: # prevent zero division: alpha == 1 takes the else branch 97 | log_divergence = log_divergence * (1-hps.alpha) 98 | logF = logsumexp(log_divergence) 99 | self.elbo = elbo = tf.reduce_mean(logF - tf.log(tf.to_float(k)))/ (1-hps.alpha) 100 | else: 101 | logF = logsumexp(log_divergence) # alpha == 1: log-mean-exp over the k importance samples 102 | self.elbo = elbo = tf.reduce_mean(logF - tf.log(tf.to_float(k))) 103 | 104 | self.loss = loss = -elbo 105 | 106 | ### Optimizer 107 | self.lr = lr = tf.Variable(0.001,name="learning_rate",trainable=False) 108 | global_step = tf.Variable(0,name='global_step',trainable=False) 109 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 110 | self.train_op = optimizer.minimize(loss,global_step=global_step) 111 | self.init_op=tf.global_variables_initializer() 112 | 113 | #gvs = optimizer.compute_gradients(loss) 114 | #if hps.grad_clip > 0: # gradient clipping 115 | # gvs = [(tf.clip_by_value(grad, -hps.grad_clip, hps.grad_clip), var) for grad, var in gvs] 116 | #self.train_op = optimizer.apply_gradients(gvs) 117 | 118 | 119 | -------------------------------------------------------------------------------- /paper_grid.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Starting grid-world simulations on decorrelated data (randomly sampled transitions across state-space)" 4 | 5 | python3 vae_grid.py --hpconfig artificial_data=True,network=1,n_epochs=75000,n_rep=5,var_type='discrete',K=4,N=8,verbose=False && 6 | python3 vae_grid.py --hpconfig artificial_data=True,network=1,n_epochs=75000,n_rep=5,var_type='continuous',z_size=8,n_flow=0,verbose=False && 7 | python3 vae_grid.py --hpconfig artificial_data=True,network=1,n_epochs=75000,n_rep=5,var_type='continuous',z_size=8,n_flow=6,verbose=False && 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /paper_grid_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Starting simulations on gridworld with RL agent" 4 | 5 | python3 vae_grid.py --hpconfig artificial_data=False,use_target_net=True,test_on_policy=True,network=1,n_epochs=75000,n_rep=5,var_type='discrete',K=4,N=8,verbose=False && 6 | python3 vae_grid.py --hpconfig artificial_data=False,use_target_net=True,test_on_policy=True,network=1,n_epochs=75000,n_rep=5,var_type='continuous',z_size=8,n_flow=0,verbose=False && 7 | python3 vae_grid.py --hpconfig artificial_data=False,use_target_net=True,test_on_policy=True,network=1,n_epochs=75000,n_rep=5,var_type='continuous',z_size=8,n_flow=6,ar=False,verbose=False && 8 | python3 vae_grid.py --hpconfig artificial_data=False,use_target_net=True,test_on_policy=True,network=1,n_epochs=75000,n_rep=5,var_type='continuous',z_size=8,n_flow=3,ar=True,verbose=False && 9 | 10 | -------------------------------------------------------------------------------- /paper_toy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Start running simulations to reproduce paper results on toy domain" 4 | 5 | python3 
vae_main.py --hpconfig network=2,n_rep=10,deterministic=True,verbose=False && 6 | python3 vae_main.py --hpconfig network=2,n_rep=10,deterministic=False,z_size=3,verbose=False && 7 | python3 vae_main.py --hpconfig network=1,n_rep=10,var_type='discrete',K=3,N=3,verbose=False && 8 | python3 vae_main.py --hpconfig network=1,n_rep=10,var_type='continuous',z_size=3,n_flow=0,verbose=False && 9 | python3 vae_main.py --hpconfig network=1,n_rep=10,var_type='continuous',z_size=3,n_flow=5,verbose=False && 10 | -------------------------------------------------------------------------------- /plotter/Plotter2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Plotting class 4 | @author: thomas 5 | """ 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | class Plotter: 10 | ''' Manage plotting ''' 11 | 12 | def __init__(self): 13 | pass 14 | 15 | def plot_data(self,Data): 16 | 17 | self.fig = fig = plt.figure() 18 | self.ax1 = ax1 = fig.add_subplot(311) 19 | self.ax2 = ax2 = fig.add_subplot(312) 20 | self.ax3 = ax3 = fig.add_subplot(313) 21 | 22 | ax1.scatter(Data.X,Data.Y,color='b') 23 | ax1.axis([-1, 1, -1, 7]) 24 | plt.xlabel('S') 25 | plt.ylabel('S\'') 26 | plt.title('True data') 27 | 28 | new_dat, = ax2.plot([],[],'ko') 29 | self.new_dat = new_dat 30 | ax2.axis([-1, 1, -1, 7]) 31 | plt.xlabel('S') 32 | plt.ylabel('S\'') 33 | 34 | plt.subplot(313) 35 | new_dat_2, = ax3.plot([],[],'ro') 36 | self.new_dat_2 = new_dat_2 37 | self.t = np.zeros(300) 38 | self.lr = np.zeros(300) 39 | ax3.axis([0, 100000, 0, 0.01]) 40 | plt.ylabel('learning rate') 41 | 42 | #plt.draw() 43 | #plt.show(block=False) 44 | 45 | def plot_samples(self,x,y): 46 | self.new_dat.set_xdata(x) 47 | self.new_dat.set_ydata(y) 48 | self.fig.canvas.draw() 49 | #plt.pause(0.001) 50 | #plt.show(block=False) 51 | 52 | def plot_lr(self,t,lr): 53 | #plt.plot(t,lr) 54 | self.t[:len(t)] = t 55 | self.lr[:len(t)] = lr 56 | 
#self.ax3.clear() 57 | #self.ax3.plot(t,lr) 58 | self.new_dat_2.set_xdata(self.t) 59 | self.new_dat_2.set_ydata(self.lr) 60 | self.fig.canvas.draw() 61 | #plt.pause(0.001) 62 | #plt.show(block=False) 63 | 64 | -------------------------------------------------------------------------------- /plotter/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 2 14:48:24 2017 4 | 5 | @author: thomas 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /pythonutils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 2 14:48:24 2017 4 | 5 | @author: thomas 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /pythonutils/helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Python helpers 4 | @author: thomas 5 | """ 6 | import os 7 | import matplotlib.pyplot as plt 8 | 9 | def save(path, ext='png', close=True, verbose=True): 10 | """Save a figure from pyplot. 11 | Parameters 12 | ---------- 13 | path : string 14 | The path (and filename, without the extension) to save the 15 | figure to. 16 | ext : string (default='png') 17 | The file extension. This must be supported by the active 18 | matplotlib backend (see matplotlib.backends module). Most 19 | backends support 'png', 'pdf', 'ps', 'eps', and 'svg'. 20 | close : boolean (default=True) 21 | Whether to close the figure after saving. If you want to save 22 | the figure multiple times (e.g., to multiple formats), you 23 | should NOT close it in between saves or you will have to 24 | re-plot it. 25 | verbose : boolean (default=True) 26 | Whether to print information about when and where the image 27 | has been saved. 
28 | """ 29 | 30 | # Extract the directory and filename from the given path 31 | directory = os.path.split(path)[0] 32 | filename = "%s.%s" % (os.path.split(path)[1], ext) 33 | if directory == '': 34 | directory = '.' 35 | 36 | # If the directory does not exist, create it 37 | if not os.path.exists(directory): 38 | os.makedirs(directory) 39 | 40 | # The final path to save to 41 | savepath = os.path.join(directory, filename) 42 | 43 | if verbose: 44 | print("Saving figure to '%s'..." % savepath), 45 | 46 | # Actually save the figure 47 | plt.savefig(savepath) 48 | 49 | # Close it 50 | if close: 51 | plt.close() 52 | 53 | if verbose: 54 | print("Done") 55 | 56 | def nested_list(n1,n2,n3): 57 | results=[] 58 | for i in range(n1): 59 | results.append([]) 60 | for j in range(n2): 61 | results[-1].append([]) 62 | for k in range(n3): 63 | results[-1][-1].append([]) 64 | return results 65 | 66 | def make_name(hps): 67 | ''' structures output folders based on hps ''' 68 | name = '' 69 | if hasattr(hps,'artificial_data'): 70 | if not hps.artificial_data: 71 | if hps.use_target_net: 72 | name += 'RL_target_network/' 73 | else: 74 | name += 'RL/' 75 | else: 76 | name +='decorrelated/' 77 | if hps.loop_hyper: 78 | name += 'hyper_{}_{}/'.format(hps.item1,hps.item2) 79 | elif hps.network == 2: 80 | if hps.deterministic: 81 | name += 'deterministic' 82 | else: 83 | name += 'mlp_z{}'.format(hps.z_size) 84 | elif len(hps.var_type) > 1: 85 | name += '{}_z{}_nf{}_n{}_K{}'.format('_'.join(hps.var_type),hps.z_size,hps.n_flow,hps.N,hps.K) 86 | elif hps.var_type == ['discrete']: 87 | name += '{}_n{}_K{}'.format(hps.var_type[0],hps.N,hps.K) 88 | elif hps.var_type == ['continuous']: 89 | name += '{}_z{}_nf{}{}'.format(hps.var_type[0],hps.z_size,hps.n_flow,'ar' if hps.ar else '') 90 | return name -------------------------------------------------------------------------------- /rlenv/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 2 14:48:24 2017 4 | 5 | @author: thomas 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /rlenv/chicken.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Wet-Chicken benchmark 4 | @author: thomas 5 | """ 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | class chicken_env(object): 11 | ''' Wet Chicken Benchmark ''' 12 | 13 | def __init__(self,to_plot = True): 14 | self.state = np.array([0,0]) 15 | self.observation_shape = np.shape(self.get_state())[0] 16 | 17 | if to_plot: 18 | plt.ion() 19 | fig = plt.figure() 20 | ax1 = fig.add_subplot(111,aspect='equal') 21 | #ax1.axis('off') 22 | plt.xlim([-0.5,5.5]) 23 | plt.ylim([-0.5,5.5]) 24 | 25 | self.g1 = ax1.add_artist(plt.Circle((self.state[0],self.state[1]),0.1,color='red')) 26 | self.fig = fig 27 | self.ax1 = ax1 28 | self.fig.canvas.draw() 29 | self.fig.canvas.flush_events() 30 | 31 | def reset(self): 32 | self.state = np.array([0,0]) 33 | return self.get_state() 34 | 35 | def get_state(self): 36 | return self.state/5 37 | 38 | def set_state(self,state): 39 | self.state = state 40 | 41 | def step(self,a): 42 | x = self.state[0] 43 | y = self.state[1] 44 | ax = a[0] 45 | ay = a[1] 46 | tau = np.random.uniform(-1,1) 47 | w=5.0 48 | l=5.0 49 | 50 | v = 3 * x * (1/w) 51 | s = 3.5 - v 52 | yhat = y + ay - 1 + v + s*tau 53 | 54 | # change x 55 | if x + ax < 0: 56 | x = 0 57 | elif yhat > l: 58 | x = 0 59 | elif x + ax > w: 60 | x = w 61 | else: 62 | x = x + ax 63 | 64 | # change y 65 | if yhat < 0: 66 | y = 0 67 | elif yhat > l: 68 | y = 0 69 | else: 70 | y = yhat 71 | 72 | self.state = np.array([x,y]).flatten() 73 | 74 | r = - (l - y) 75 | 76 | return self.state,r,yhat>l 77 | 78 | def plot(self): 79 | self.g1.remove() 80 | self.g1 = self.ax1.add_artist(plt.Circle((self.state[0],self.state[1]),0.1,color='red')) 81 | 
self.fig.canvas.draw() 82 | 83 | # Test 84 | if __name__ == '__main__': 85 | Env = chicken_env(True) 86 | s = Env.get_state() 87 | for i in range(500): 88 | a = np.random.uniform(-1,1,2) 89 | s,r,dead = Env.step(a) 90 | if not dead: 91 | Env.plot() 92 | else: 93 | print('Died in step',i,', restarting') 94 | s = Env.reset() 95 | print(Env.get_state()) 96 | print('Finished') -------------------------------------------------------------------------------- /rlenv/grid.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Grid-world Environment wit stochastic ghost 4 | @author: thomas 5 | """ 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import matplotlib.patches as patches 10 | import random 11 | 12 | class grid_env(object): 13 | ''' Grid world with stochastic ghosts ''' 14 | 15 | def __init__(self,to_plot=True,grid=False): 16 | world = np.zeros([7,7],dtype='int32') 17 | world[1:6,1] = 1 18 | world[1:3,4] = 1 19 | world[4:6,4] = 1 20 | self.world = world 21 | self.grid = grid 22 | self.reset() 23 | self.observation_shape = np.shape(self.get_state())[0] 24 | 25 | if to_plot: 26 | plt.ion() 27 | fig = plt.figure() 28 | ax1 = fig.add_subplot(111,aspect='equal') 29 | ax1.axis('off') 30 | plt.xlim([-1,8]) 31 | plt.ylim([-1,8]) 32 | 33 | #colors = matplotlib.colors.ListerColormap() 34 | for i in range(7): 35 | for j in range(7): 36 | if world[i,j]==1: 37 | col = "black" 38 | else: 39 | col = "white" 40 | ax1.add_patch( 41 | patches.Rectangle( 42 | (i,j),1,1, 43 | #fill=False, 44 | edgecolor='black', 45 | linewidth = 2, 46 | facecolor = col,), 47 | ) 48 | if np.all([i,j] == self.ghost1): 49 | self.g1 = ax1.add_artist(plt.Circle((i+0.5,j+0.5),0.3,color='red')) 50 | if np.all([i,j] == self.ghost2): 51 | self.g2 = ax1.add_artist(plt.Circle((i+0.5,j+0.5),0.3,color='blue')) 52 | if np.all([i,j] == self.pacman): 53 | self.p = ax1.add_artist(plt.Circle((i+0.5,j+0.5),0.3,color='yellow')) 54 | 
self.fig = fig 55 | self.ax1 = ax1 56 | self.fig.canvas.draw() 57 | 58 | def reset(self): 59 | self.pacman = np.array([0,0]) 60 | self.ghost1 = np.array([1,3]) 61 | self.ghost2 = np.array([5,3]) 62 | return self.get_state() 63 | 64 | def set_state(self,state): 65 | self.pacman = np.array(state[0:2]) 66 | self.ghost1 = np.array(state[2:4]) 67 | self.ghost2 = np.array(state[4:6]) 68 | 69 | def step(self,a): 70 | # move pacman 71 | self._move(self.pacman,a) 72 | 73 | # check collision 74 | dead = self._check_dead() 75 | if dead: 76 | r = -1 77 | return self.get_state(),r,dead 78 | 79 | # move ghosts 80 | wall = True 81 | while wall: 82 | a1 = random.sample(range(4),1) # random ghost 83 | wall = self._move(self.ghost1,a1) 84 | 85 | # move ghosts 86 | wall = True 87 | while wall: 88 | a2 = np.where(np.random.multinomial(1,[0.1,0.1,0.4,0.4]))[0] # probabilistic ghost 89 | wall = self._move(self.ghost2,a2) 90 | 91 | # check collision again 92 | dead = self._check_dead() 93 | if dead: 94 | r = -1 95 | else: 96 | if np.all(self.pacman == np.array([6,6])): 97 | r = 10 98 | dead = True 99 | #print('Reached the goal') 100 | else: 101 | r = 0 102 | return self.get_state(),r,dead 103 | 104 | def get_state(self): 105 | if not self.grid: 106 | state = np.concatenate((self.pacman,self.ghost1,self.ghost2)) 107 | else: 108 | state = np.copy(self.world) 109 | state = np.stack(state,np.zeros(7,7),np.zeros(7,7),np.zeros(7,7),axis=2) 110 | state[self.pacman[0],self.pacman[1],1] = 1 111 | state[self.ghost1[0],self.ghost1[1],2] = 1 112 | state[self.ghost2[0],self.ghost2[1],3] = 1 113 | return state 114 | 115 | def plot(self): 116 | self.g1.remove() 117 | self.g2.remove() 118 | self.p.remove() 119 | 120 | # replot 121 | self.g1 = self.ax1.add_artist(plt.Circle(self.ghost1+0.5,0.3,color='red')) 122 | self.g2 = self.ax1.add_artist(plt.Circle(self.ghost2+0.5,0.3,color='blue')) 123 | self.p = self.ax1.add_artist(plt.Circle(self.pacman +0.5,0.3,color='yellow')) 124 | self.fig.canvas.draw() 125 | 
126 | def plot_predictions(self,world): 127 | for i in range(7): 128 | for j in range(7): 129 | for k in range(3): 130 | if k==1: 131 | col = "yellow" 132 | elif k == 2: 133 | col = "red" 134 | elif k == 3: 135 | col = 'blue' 136 | if world[i,j,k]>0.0: 137 | self.ax1.add_patch(patches.Rectangle( 138 | (i,j),1,1, 139 | #fill=False, 140 | edgecolor='black', 141 | linewidth = 2, 142 | facecolor = col, 143 | alpha=world[i,j,k]), 144 | ) 145 | 146 | def _move(self,s,a): 147 | s_old = np.copy(s) 148 | 149 | # move 150 | if int(a[0]) == 0: #up 151 | s[1] +=1 152 | elif int(a[0]) == 1: #down 153 | s[1] -=1 154 | elif int(a[0])== 2: #right 155 | s[0] +=1 156 | elif int(a[0])==3: #left 157 | s[0] -=1 158 | else: 159 | raise ValueError('move not possible') 160 | 161 | # check if move is possible 162 | if s[0]<0 or s[0]>6 or s[1]<0 or s[1]>6: # out of grid 163 | wall = True 164 | elif np.all(self.world[s[0],s[1]] == 1): # wall 165 | wall = True 166 | else: 167 | wall = False 168 | 169 | if wall: 170 | # Need to repeat, put back old values 171 | s[0] = s_old[0] 172 | s[1] = s_old[1] 173 | return wall 174 | else: 175 | # Move to new state 176 | return wall 177 | 178 | def _check_dead(self): 179 | if np.all(self.pacman == self.ghost1) or np.all(self.pacman == self.ghost2): 180 | return True 181 | else: 182 | return False 183 | 184 | 185 | # Test 186 | if __name__ == '__main__': 187 | grid = grid_env(True) 188 | s = grid.get_state() 189 | for i in range(200): 190 | a = random.sample(range(4),1) 191 | s,r,dead = grid.step(a) 192 | if not dead: 193 | grid.plot() 194 | else: 195 | print('Died in step',i,', restarting') 196 | s = grid.reset() 197 | print(grid.get_state()) 198 | print('Finished') 199 | plt.show(block=True) 200 | -------------------------------------------------------------------------------- /rlutils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 2 14:48:24 2017 4 
| 5 | @author: thomas 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /rlutils/helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Grid-world environment 4 | @author: thomas 5 | """ 6 | import sys 7 | import os.path 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 9 | sys.path.append(os.path.abspath(os.path.dirname(__file__))) 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | import matplotlib.patches as patches 14 | from rlenv.grid import grid_env as grid_env 15 | from rlutils.policies import egreedy 16 | from scipy.linalg import norm 17 | 18 | ### other stuff 19 | def make_rl_data(Env,batch_size): 20 | world = np.zeros([7,7],dtype='int32') 21 | world[1:6,1] = 1 22 | world[1:3,4] = 1 23 | world[4:6,4] = 1 24 | 25 | s_data = np.zeros([batch_size,Env.observation_shape],dtype='int32') 26 | s1_data = np.zeros([batch_size,Env.observation_shape],dtype='int32') 27 | a_data = np.zeros([batch_size,1],dtype='int32') 28 | r_data = np.zeros([batch_size,1],dtype='int32') 29 | term_data = np.zeros([batch_size,1],dtype='int32') 30 | count = 0 31 | while count < batch_size: 32 | i,j,k,l,m,n = np.random.randint(0,7,1),np.random.randint(0,7,1),np.random.randint(0,7,1),np.random.randint(0,7,1),np.random.randint(0,7,1),np.random.randint(0,7,1) 33 | a = np.random.randint(0,4,1)[0] 34 | do = not bool(world[i,j]) and (not bool(world[k,l])) and (not bool(world[m,n])) 35 | if do: 36 | s = np.array([i,j,k,l,m,n]).flatten() 37 | Env.set_state(s) 38 | s1 , r , dead = Env.step([a]) 39 | s_data[count,] = s 40 | s1_data[count,] = s1 41 | a_data[count] = a 42 | r_data[count] = r 43 | term_data[count] = dead 44 | count += 1 45 | return s_data, a_data, s1_data, r_data, term_data 46 | 47 | def make_test_data(sess,model,Env,batch_size,epsilon=0.05): 48 | ''' on policy test data ''' 49 | s_data = 
np.zeros([batch_size,Env.observation_shape],dtype='int32') 50 | s1_data = np.zeros([batch_size,Env.observation_shape],dtype='int32') 51 | a_data = np.zeros([batch_size,1],dtype='int32') 52 | r_data = np.zeros([batch_size,1],dtype='int32') 53 | term_data = np.zeros([batch_size,1],dtype='int32') 54 | count = 0 55 | s = Env.reset() 56 | while count < batch_size: 57 | Qsa = sess.run(model.Qsa, feed_dict = {model.x :s[None,:], 58 | model.k : 1, 59 | }) 60 | a = np.array([egreedy(Qsa[0],epsilon)]) 61 | s1 = sess.run(model.y_sample,{ model.x : s[None,:], 62 | model.y : np.zeros(np.shape(s))[None,:], 63 | model.a : a[:,None], 64 | model.lamb : 1, 65 | model.temp : 0.0001, 66 | model.is_training : False, 67 | model.k: 1}) 68 | s_data[count,] = s 69 | a_data[count] = a 70 | s1_data[count,] = s1 71 | Env.set_state(s1) 72 | term_data[count,] = dead = Env._check_dead() 73 | if dead: 74 | s = Env.reset() 75 | else: 76 | s = s1 77 | count += 1 78 | return s_data, a_data, s1_data, r_data, term_data 79 | 80 | def plot_predictions(model,sess,n_row,n_col,run_id,hps,on_policy=False,s=np.array([0,0,1,3,5,3])): 81 | world = np.zeros([7,7],dtype='int32') 82 | world[1:6,1] = 1 83 | world[1:3,4] = 1 84 | world[4:6,4] = 1 85 | 86 | fig = plt.figure()#figsize=(7,7),dpi=600,aspect='auto') 87 | for row in range(n_row): 88 | for col in range(n_col): 89 | ax1 = fig.add_subplot(n_row,n_col,((n_col*col) + row + 1),aspect='equal') 90 | # plot the environment 91 | ax1.axis('off') 92 | plt.xlim([-1,8]) 93 | plt.ylim([-1,8]) 94 | plot_predictions 95 | for i in range(7): 96 | for j in range(7): 97 | if world[i,j]==1: 98 | col = "black" 99 | else: 100 | col = "white" 101 | ax1.add_patch( 102 | patches.Rectangle( 103 | (i,j),1,1, 104 | #fill=False, 105 | edgecolor='black', 106 | linewidth = 2, 107 | facecolor = col,), 108 | ) 109 | 110 | # sample some state 111 | do = False 112 | if not on_policy: 113 | while not do: 114 | i,j,k,l,m,n = 
np.random.randint(0,7,1),np.random.randint(0,7,1),np.random.randint(0,7,1),np.random.randint(0,7,1),np.random.randint(0,7,1),np.random.randint(0,7,1) 115 | a = np.random.randint(0,4,1) 116 | do = not bool(world[i,j]) and (not bool(world[k,l])) and (not bool(world[m,n])) 117 | s = np.array([i,j,k,l,m,n]).flatten() 118 | else: 119 | i,j,k,l,m,n = s 120 | Qsa = sess.run(model.Qsa, feed_dict = {model.x :s[None,:], 121 | model.k : 1, 122 | }) 123 | a = np.array([egreedy(Qsa[0],0.01)]) 124 | 125 | # add the start 126 | ax1.add_artist(plt.Circle((m+0.5,n+0.5),0.3,color='blue')) 127 | ax1.add_artist(plt.Circle((k+0.5,l+0.5),0.3,color='red')) 128 | ax1.add_artist(plt.Circle((i+0.5,j+0.5),0.3,color='green')) 129 | if a == 0: 130 | action = 'up' 131 | elif a == 1: 132 | action = 'down' 133 | elif a == 2: 134 | action = 'right' 135 | elif a == 3: 136 | action = 'left' 137 | ax1.set_title(action) 138 | 139 | trans = predict(model,sess,s,a) 140 | for agent in range(3): 141 | for i in range(7): 142 | for j in range(7): 143 | if trans[i,j,agent]>0: 144 | if agent == 0: 145 | col = 'green' 146 | elif agent == 1: 147 | col = 'red' 148 | elif agent == 2: 149 | col = 'blue' 150 | 151 | ax1.add_patch( 152 | patches.Rectangle( 153 | (i,j),1,1, 154 | fill=True, 155 | alpha = trans[i,j,agent], 156 | edgecolor='black', 157 | linewidth = 2, 158 | facecolor = col,), 159 | ) 160 | if on_policy: 161 | s1 = sess.run(model.y_sample,{ model.x : s[None,:], 162 | model.y : np.zeros(np.shape(s[None,:])), 163 | model.a : a[:,None], 164 | model.lamb : 1, 165 | model.temp : 0.0001, 166 | model.is_training : False, 167 | model.k: 1}) 168 | s = s1[0] 169 | return s 170 | 171 | def predict(model,sess,some_s,some_a,n_test_samp=200): 172 | freq = np.zeros([7,7,3]) 173 | for m in range(n_test_samp): 174 | y_sample = sess.run(model.y_sample,{ model.x : some_s[None,:], 175 | model.y : np.zeros(np.shape(some_s))[None,:], 176 | model.a : some_a[:,None], 177 | model.lamb : 1, 178 | model.temp : 0.0001, 179 | 
model.is_training : False, 180 | model.k: 1}) 181 | y_sample = y_sample.flatten() 182 | freq[y_sample[0],y_sample[1],0] += 1 183 | freq[y_sample[2],y_sample[3],1] += 1 184 | freq[y_sample[4],y_sample[5],2] += 1 185 | 186 | trans = freq/n_test_samp 187 | return trans 188 | 189 | def kl_preds_v2(model,sess,s_test,a_test,n_rep_per_item=200): 190 | ## Compare sample distribution to ground truth 191 | Env = grid_env(False) 192 | n_test_items,state_size = s_test.shape 193 | distances = np.empty([n_test_items,3]) 194 | 195 | for i in range(n_test_items): 196 | state = s_test[i,:].astype('int32') 197 | action = np.round(a_test[i,:]).astype('int32') 198 | 199 | # ground truth 200 | state_truth = np.empty([n_rep_per_item,s_test.shape[1]]) 201 | for o in range(n_rep_per_item): 202 | Env.set_state(state.flatten()) 203 | s1,r,dead = Env.step(action.flatten()) 204 | state_truth[o,:] = s1 205 | truth_count,bins = np.histogramdd(state_truth,bins=[np.arange(8)-0.5]*state_size) 206 | truth_prob = truth_count/n_rep_per_item 207 | 208 | # predictions of model 209 | y_sample = sess.run(model.y_sample,{ model.x : state[None,:].repeat(n_rep_per_item,axis=0), 210 | model.y : np.zeros(np.shape(state[None,:])).repeat(n_rep_per_item,axis=0), 211 | model.a : action[None,:].repeat(n_rep_per_item,axis=0), 212 | model.Qtarget : np.zeros(np.shape(action[None,:])).repeat(n_rep_per_item,axis=0), 213 | model.lr : 0, 214 | model.lamb : 1, 215 | model.temp : 0.00001, 216 | model.is_training : False, 217 | model.k: 1}) 218 | sample_count,bins = np.histogramdd(y_sample,bins=[np.arange(8)-0.5]*state_size) 219 | sample_prob = sample_count/n_rep_per_item 220 | 221 | distances[i,0]= np.sum(truth_prob*(np.log(truth_prob+1e-5)-np.log(sample_prob+1e-5))) # KL(p|p_tilde) 222 | distances[i,1]= np.sum(sample_prob*(np.log(sample_prob+1e-5)-np.log(truth_prob+1e-5))) # Inverse KL(p_tilde|p) 223 | distances[i,2]= norm(np.sqrt(truth_prob) - np.sqrt(sample_prob))/np.sqrt(2) 224 | 225 | return np.mean(distances,axis=0) 
-------------------------------------------------------------------------------- /rlutils/policies.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 16 16:33:17 2017 4 | 5 | @author: thomas 6 | """ 7 | import numpy as np 8 | ### policies 9 | def egreedy(Qs, epsilon=0.05): 10 | ''' e-greedy policy on Q values ''' 11 | if Qs.ndim == 1: 12 | a = np.argmax(Qs) 13 | if np.random.rand() < epsilon: 14 | a = np.random.randint(np.size(Qs)) 15 | return a 16 | else: 17 | raise ValueError('Qs.ndim should be 1') 18 | 19 | def softmax(Qs, temp = 1): 20 | ''' Boltzmann policy on Q values ''' 21 | if Qs.ndim == 1: 22 | x = Qs * temp 23 | e_x = np.exp(x - np.max(x)) 24 | probs = e_x / e_x.sum() 25 | return np.where(np.random.multinomial(1,probs))[0] 26 | else: 27 | raise ValueError('Qs.ndim should be 1') 28 | -------------------------------------------------------------------------------- /rlutils/rollout.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Apr 20 09:26:00 2017 4 | 5 | @author: thomas 6 | """ 7 | import numpy as np 8 | from policies import egreedy 9 | 10 | def rollout(Env,hps,model,sess,s,epsilon=0.05): 11 | ''' Q-learning + e-greedy roll-out ''' 12 | s_batch = np.empty(np.append(hps.batch_size,Env.observation_shape),dtype='float32') 13 | a_batch = np.empty([hps.batch_size,1],dtype='float32') 14 | r_batch = np.empty([hps.batch_size,1],dtype='float32') 15 | term_batch = np.empty([hps.batch_size,1],dtype='float32') 16 | s1_batch = np.empty(np.append(hps.batch_size,Env.observation_shape),dtype='float32') 17 | for _ in range(hps.batch_size): 18 | Qsa = sess.run(model.Qsa, feed_dict = {model.x : s[None,:], 19 | model.k : 1}) 20 | a = egreedy(Qsa[0],epsilon) 21 | s1,r,dead = Env.step([a]) 22 | s_batch[_,],a_batch[_,],r_batch[_,],s1_batch[_,],term_batch[_,] = s,a,r,s1,dead 23 | s = s1 24 | if 
dead: 25 | s = Env.reset() 26 | 27 | # Calculate targets 28 | if hps.use_target_net: 29 | Qsa1 = sess.run(model.Qsa_t, feed_dict = {model.x : s1_batch,model.k : 1}) 30 | else: 31 | Qsa1 = sess.run(model.Qsa, feed_dict = {model.x : s1_batch,model.k : 1}) 32 | 33 | Qmax = np.max(Qsa1,axis=1)[:,None] 34 | Qmax *= (1. - term_batch) 35 | Qtarget_batch = r_batch + hps.gamma * Qmax 36 | 37 | return s_batch, a_batch, s1_batch, r_batch, term_batch, Qtarget_batch, s, Env -------------------------------------------------------------------------------- /tfutils/distributions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Distribution functions 4 | Copyright (c) 2016 openai 5 | https://github.com/openai/iaf/blob/master/tf_utils 6 | 7 | Gumbell-softmax functions 8 | Copyright (c) 2016 Eric Jang 9 | https://github.com/ericjang/gumbel-softmax 10 | """ 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | ##### Probabilistic functions (partially from OpenAI github) 15 | class DiagonalGaussian(object): 16 | def __init__(self, mean, logvar, sample=None): 17 | self.mean = mean 18 | self.logvar = logvar 19 | 20 | if sample is None: 21 | noise = tf.random_normal(tf.shape(mean)) 22 | sample = mean + tf.exp(0.5 * logvar) * noise 23 | self.sample = sample 24 | 25 | def logps(self, sample): 26 | return gaussian_diag_logps(self.mean, self.logvar, sample) 27 | 28 | def gaussian_diag_logps(mean, logvar, sample=None): 29 | if sample is None: 30 | noise = tf.random_normal(tf.shape(mean)) 31 | sample = mean + tf.exp(0.5 * logvar) * noise 32 | return tf.clip_by_value(-0.5 * (np.log(2 * np.pi) + logvar + tf.square(sample - mean) / tf.exp(logvar)),-(10e10),10e10) 33 | 34 | def discretized_logistic(mean, logscale=None, binsize=1/6.0, sample=None): 35 | if logscale is None: 36 | logscale = tf.zeros(tf.shape(mean)) 37 | logscale = tf.clip_by_value(logscale,-10,10) 38 | scale = tf.exp(logscale) 39 | x = (sample - 
mean) * (1/binsize) # stretch back 40 | logp = tf.log(tf.sigmoid((x + 0.5)/scale) - tf.sigmoid((x - 0.5)/scale) + 1e-20) 41 | return logp 42 | 43 | def logsumexp(x): 44 | x_max = tf.reduce_max(x, [1], keep_dims=True) 45 | return tf.reshape(x_max, [-1]) + tf.log(tf.reduce_sum(tf.exp(x - x_max), [1])) 46 | 47 | def compute_lowerbound(log_pxz, sum_kl_costs, k=1, alpha=0.5): 48 | if k == 1: 49 | return sum_kl_costs - log_pxz 50 | # log 1/k \sum p(x | z) * p(z) / q(z | x) = -log(k) + logsumexp(log p(x|z) + log p(z) - log q(z|x)) 51 | log_pxz = tf.reshape(log_pxz, [-1, k]) 52 | sum_kl_costs = tf.reshape(sum_kl_costs, [-1, k]) 53 | diff = (log_pxz - sum_kl_costs)*(1-alpha) 54 | elbo = tf.reduce_mean(- tf.log(float(k)) + logsumexp(diff)) / (1-alpha) 55 | return -elbo 56 | 57 | 58 | ###### Gumbell-softmax ####### (from E. Jiang) 59 | def sample_gumbel(shape, eps=1e-20): 60 | """Sample from Gumbel(0, 1)""" 61 | U = tf.random_uniform(shape,minval=0,maxval=1) 62 | return -tf.log(-tf.log(U + eps) + eps) 63 | 64 | def gumbel_softmax_sample(logits, temperature): 65 | """ Draw a sample from the Gumbel-Softmax distribution""" 66 | y = tf.add(logits,sample_gumbel(tf.shape(logits))) 67 | return tf.nn.softmax( tf.div(y, temperature)) 68 | 69 | def gumbel_softmax(logits, temperature, hard=False): 70 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 71 | Args: 72 | logits: [batch_size, n_class] unnormalized log-probs 73 | temperature: non-negative scalar 74 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 75 | Returns: 76 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 
77 | If hard=True, then the returned sample will be one-hot, otherwise it will 78 | be a probabilitiy distribution that sums to 1 across classes 79 | """ 80 | y = gumbel_softmax_sample(logits, temperature) 81 | #if hard: 82 | # k = tf.shape(logits)[-1] 83 | # #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype) 84 | # y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype) 85 | # y = tf.stop_gradient(y_hard - y) + y 86 | return y 87 | 88 | -------------------------------------------------------------------------------- /tfutils/helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Tensorflow helper functionality 4 | Partially from: https://github.com/openai/iaf/blob/master/tf_utils/common.py 5 | Copyright (c) 2016 openai 6 | """ 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | def anneal_linear(t, n, e_final,e_init=1): 11 | ''' Linear anneals between e_init and e_final ''' 12 | if t >= n: 13 | return e_final 14 | else: 15 | return e_init - ( (t/n) * (e_init - e_final) ) 16 | 17 | class HParams(object): 18 | ''' Hyperparameter object, from https://github.com/openai/iaf/blob/master/tf_utils ''' 19 | 20 | def __init__(self, **kwargs): 21 | self._items = {} 22 | for k, v in kwargs.items(): 23 | self._set(k, v) 24 | 25 | def _set(self, k, v): 26 | self._items[k] = v 27 | setattr(self, k, v) 28 | 29 | def parse(self, str_value): 30 | hps = HParams(**self._items) 31 | for entry in str_value.strip().split(","): 32 | entry = entry.strip() 33 | if not entry: 34 | continue 35 | key, sep, value = entry.partition("=") 36 | if not sep: 37 | raise ValueError("Unable to parse: %s" % entry) 38 | default_value = hps._items[key] 39 | if isinstance(default_value, bool): 40 | hps._set(key, value.lower() == "true") 41 | elif isinstance(default_value, int): 42 | hps._set(key, int(value)) 43 | elif isinstance(default_value, float): 44 | hps._set(key, float(value)) 45 | elif 
isinstance(default_value, list): 46 | value = value.split('-') 47 | default_inlist = hps._items[key][0] 48 | if key == 'seq1': 49 | default_inlist = hps._items[hps._items['item1']] 50 | if key == 'seq2': 51 | default_inlist = hps._items[hps._items['item2']] 52 | if isinstance(default_inlist, bool): 53 | hps._set(key, [i.lower() == "true" for i in value]) 54 | elif isinstance(default_inlist, int): 55 | hps._set(key, [int(i) for i in value]) 56 | elif isinstance(default_inlist, float): 57 | hps._set(key, [float(i) for i in value]) 58 | else: 59 | hps._set(key,value) # string 60 | else: 61 | hps._set(key, value) 62 | return hps 63 | 64 | def split(x, split_dim, split_sizes): 65 | n = len(list(x.get_shape())) 66 | dim_size = np.sum(split_sizes) 67 | assert int(x.get_shape()[split_dim]) == dim_size 68 | ids = np.cumsum([0] + split_sizes) 69 | ids[-1] = -1 70 | begin_ids = ids[:-1] 71 | 72 | ret = [] 73 | for i in range(len(split_sizes)): 74 | cur_begin = np.zeros([n], dtype=np.int32) 75 | cur_begin[split_dim] = begin_ids[i] 76 | cur_end = np.zeros([n], dtype=np.int32) - 1 77 | cur_end[split_dim] = split_sizes[i] 78 | ret += [tf.slice(x, cur_begin, cur_end)] 79 | return ret 80 | 81 | def repeat_v2(x,k): 82 | ''' repeat k times along first dimension ''' 83 | def change(x,k): 84 | shape = x.get_shape().as_list()[1:] 85 | x_1 = tf.expand_dims(x,1) 86 | tile_shape = tf.concat([tf.ones(1,dtype='int32'),[k],tf.ones([tf.rank(x)-1],dtype='int32')],axis=0) 87 | x_rep = tf.tile(x_1,tile_shape) 88 | new_shape = np.insert(shape,0,-1) 89 | x_out = tf.reshape(x_rep,new_shape) 90 | return x_out 91 | 92 | return tf.cond(tf.equal(k,1), 93 | lambda: x, 94 | lambda: change(x,k)) 95 | -------------------------------------------------------------------------------- /vae_grid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Variational Inference on Gridworld 5 | @author: thomas 6 | """ 7 | 
import tensorflow as tf
import numpy as np
from tfutils.helpers import anneal_linear, HParams
import matplotlib.pyplot as plt
from rlenv.grid import grid_env as grid_env
from rlutils.policies import egreedy
from rlutils.helpers import make_rl_data, make_test_data, plot_predictions, kl_preds_v2
from pythonutils.helpers import save, make_name, nested_list
import os
import logging

# Tensorflow parser
flags = tf.app.flags
flags.DEFINE_string("hpconfig", "", "Overrides default hyper-parameters.")
flags.DEFINE_string("save_dir", "results/grid", "Results directory.")
flags.DEFINE_string("check_dir", "/tmp/best_model", "Checkpoint directory.")
FLAGS = flags.FLAGS


def get_hps():
    ''' Default hyperparameter settings (override via --hpconfig). '''
    return HParams(
        # General learning set-up
        network = 1,                  # which network to run
        n_epochs = 75000,             # number of batches
        batch_size = 32,              # batch size
        eval_freq = 500,              # Evaluation frequency
        test_size = 1500,             # Test set size
        debug = False,                # tf debugging
        lr_init = 0.0005,             # Initial learning rate
        lr_final_frac = 0.2,          # lr_final = lr_init * lr_final_frac
        anneal_frac_lr = 0.7,         # percentage of n_epochs to anneal over
        max_grad = None,              # gradient clip size, None is no clipping
        verbose = True,               # print verbosity

        # q(z|x,y) (Variational approximation)
        var_type = ['continuous'],    # Distribution for q(z|x,y), either ['continuous'] for Gaussian, or ['discrete'] for categorical
                                      # repeated if depth>1, for multiple layers use: var_type = ['continuous','discrete'],
        depth = 1,                    # depth of stochastic layers
        h_size = 100,                 # dimensionality in variational ladder (deterministic part)
        resnet = False,               # Whether to use Resnet like architecture

        # Discrete latent variables
        K = 3,                        # categories per discrete latent variable
        N = 3,                        # number of variables
        tau_init = 2.0,               # softmax temperature initial
        tau_final = 0.001,            # softmax temperature final
        anneal_frac_tau = 0.7,        # anneal fraction

        # Continuous latent variables
        z_size = 3,                   # number of continuous latent variables
        n_flow = 5,                   # depth of flow (if n_flow=0 --> no flow)
        ar = False,                   # type of flow, if False: affine coupling layer (Real NVP), if True: inverse autoregressive flow (IAF)
        ignore_sigma_latent = False,  # use sigma in variational approximation
        ignore_sigma_flow = False,    # use sigma in flow transformations

        # p(y|x,z) (Decoder distribution)
        out_lik = 'discrete',         # distribution for p(y|x,z), can be 'normal', 'discrete' or 'discretized logistic'
        ignore_sigma_outcome = True,  # Whether to learn the SD of p(y|x,z).
                                      # For discrete, whether to sample for categorical or deterministically argmax over the predicted class probabilities.

        # VAE objective
        k = 3,                        # number of importance samples
        alpha = 0.5,                  # alpha in Renyi alpha divergence
        kl_min = 0.07,                # Number of "free bits/nats", only used if kl_min>0
        use_lamb = False,             # KL annealing (alternative to kl_min)
        lamb_init = 0.1,              # Initial contribution of KL to loss : L = p(y|x,z) + lambda*KL(q|p)
        lamb_final = 1.0,             # Final lambda
        anneal_frac_lamb = 0.3,       # anneal iteration fraction

        # Reinforcement learning settings
        artificial_data = True,       # if True: no RL but sample data across state-space (decorrelated)
        use_target_net = False,       # if True: use target net in DQN
        eps_init = 1.0,               # initial epsilon in e-greedy action selection
        eps_final = 0.10,             # final epsilon
        anneal_frac_eps = 0.6,        # fraction of n_epochs to anneal over
        gamma = 0.99,                 # discount factor
        test_on_policy = False,       # if True: plot evaluations while following policy (only useful with artificial_data=False)

        # Hyperparameter looping
        n_rep = 10,                   # number of repetitions per setting
        loop_hyper = False,           # If False, no looping (ignores other settings below)
        item1 = 'kl_min',             # First hyperparameter to loop over (should appear in settings above)
        seq1 = [0,0.04,0.07,0.10,0.20],  # Values to loop over
        item2 = 'use_lamb',           # Second hyperparameter
        seq2 = [False, True],         # Second loop values
        )


def run(hps):
    ''' Main function: run training and evaluation on the gridworld.

    For every (item1, item2) setting in (seq1, seq2) and each of n_rep
    repetitions: builds a fresh graph, trains for n_epochs batches (sampled
    across the state space when artificial_data, otherwise collected with an
    e-greedy policy), checkpoints the model with the lowest validation nats,
    restores it, saves learning-curve / prediction plots to hps.my_dir, and
    finally logs averaged results.
    '''
    Env = grid_env(False)
    Test_env = grid_env(False)
    # Fixed validation/test data is read by the evaluation step whenever it
    # is NOT on-policy, so build it in exactly that case.
    if hps.artificial_data or not hps.test_on_policy:
        s_valid_pre, a_valid_pre, s1_valid_pre, r_valid_pre, term_valid_pre = make_rl_data(Test_env, int(hps.test_size/2))
        s_test_pre, a_test_pre, s1_test_pre, r_test_pre, term_test_pre = make_rl_data(Test_env, hps.test_size)

    # Set-up hyperparameter loop
    n_rep = hps.n_rep
    seq1 = hps.seq1
    seq2 = hps.seq2
    results = np.empty([len(seq1), len(seq2), n_rep])
    results_elbo = np.empty([len(seq1), len(seq2), n_rep])
    results_distances = np.empty([len(seq1), len(seq2), 3, n_rep])
    av_rewards = nested_list(len(seq1), len(seq2), n_rep)

    for j, item1 in enumerate(seq1):
        hps._set(hps.item1, item1)
        for l, item2 in enumerate(seq2):
            hps._set(hps.item2, item2)
            for rep in range(n_rep):
                tf.reset_default_graph()
                hps.lr_final = hps.lr_init*hps.lr_final_frac

                # Initialize anneal parameters
                np_lr = anneal_linear(0, hps.n_epochs * hps.anneal_frac_lr, hps.lr_final, hps.lr_init)
                np_temp = anneal_linear(0, hps.n_epochs * hps.anneal_frac_tau, hps.tau_final, hps.tau_init)
                np_lamb = anneal_linear(0, hps.n_epochs * hps.anneal_frac_lamb, hps.lamb_final, hps.lamb_init)
                np_eps = anneal_linear(0, hps.n_epochs * hps.anneal_frac_eps, hps.eps_final, hps.eps_init)

                # Build network
                if hps.network == 1:
                    import networks.network_rl as net
                    model = net.Network(hps, Env.observation_shape)
                else:
                    raise ValueError('Unknown network: {}'.format(hps.network))

                # Check model size
                total_size = 0
                for v in tf.trainable_variables():
                    total_size += np.prod([int(s) for s in v.get_shape()])
                print("Total number of trainable variables: {}".format(total_size))

                # Session and initialization
                with tf.Session() as sess:
                    if hps.debug:
                        sess = tf.python.debug.LocalCLIDebugWrapperSession(sess)
                        sess.add_tensor_filter("has_inf_or_nan", tf.python.debug.has_inf_or_nan)

                    sess.run(model.init_op)
                    saver = tf.train.Saver()

                    # Some storage
                    t = []
                    elbo_keep = []
                    train_nats_keep = []
                    valid_nats_keep = []
                    test_nats_keep = []
                    min_valid_nats = 1e50
                    best_test_nats = 0.0
                    best_elbo = 0.0
                    best_iter = 0
                    epoch_reward = []

                    # Train
                    print('Initialized, starting to train')
                    s = Env.reset()
                    for i in range(hps.n_epochs):

                        if not hps.artificial_data:  # roll out in Env
                            s_batch = np.empty(np.append(hps.batch_size, Env.observation_shape), dtype='float32')
                            a_batch = np.empty([hps.batch_size, 1], dtype='float32')
                            r_batch = np.empty([hps.batch_size, 1], dtype='float32')
                            term_batch = np.empty([hps.batch_size, 1], dtype='float32')
                            s1_batch = np.empty(np.append(hps.batch_size, Env.observation_shape), dtype='float32')
                            for b in range(hps.batch_size):
                                Qsa = sess.run(model.Qsa, feed_dict={model.x: s[None, :],
                                                                     model.k: 1,
                                                                     })
                                a = egreedy(Qsa[0], np_eps)
                                s1, r, dead = Env.step([a])
                                s_batch[b, ], a_batch[b, ], r_batch[b, ], s1_batch[b, ], term_batch[b, ] = s, a, r, s1, dead
                                s = s1
                                if dead:
                                    # episode ended: keep filling the batch from a fresh episode
                                    s = Env.reset()

                        else:  # Sample some transitions across state-space
                            s_batch, a_batch, s1_batch, r_batch, term_batch = make_rl_data(Env, hps.batch_size)

                        # Calculate targets
                        if hps.use_target_net:
                            Qsa1 = sess.run(model.Qsa_t, feed_dict={model.x: s1_batch, model.k: 1})
                        else:
                            Qsa1 = sess.run(model.Qsa, feed_dict={model.x: s1_batch, model.k: 1})

                        # One-step Q-learning target; no bootstrapping from terminal transitions
                        Qmax = np.max(Qsa1, axis=1)[:, None]
                        Qmax *= (1. - term_batch)
                        Qtarget_batch = r_batch + hps.gamma * Qmax

                        # store stuff
                        epoch_reward.extend([np.mean(r_batch)])

                        # train on the batch (VAE and RL train ops in one call)
                        _, _, np_elbo = sess.run([model.train_op, model.train_op_rl, model.elbo], {model.x: s_batch,
                                                                                                  model.y: s1_batch,
                                                                                                  model.a: a_batch,
                                                                                                  model.Qtarget: Qtarget_batch,
                                                                                                  model.lr: np_lr,
                                                                                                  model.lamb: np_lamb,
                                                                                                  model.temp: np_temp,
                                                                                                  model.is_training: True,
                                                                                                  model.k: hps.k})

                        # Annealing
                        if i % 250 == 1:
                            np_lr = anneal_linear(i, hps.n_epochs * hps.anneal_frac_lr, hps.lr_final, hps.lr_init)
                            np_temp = anneal_linear(i, hps.n_epochs * hps.anneal_frac_tau, hps.tau_final, hps.tau_init)
                            np_lamb = anneal_linear(i, hps.n_epochs * hps.anneal_frac_lamb, hps.lamb_final, hps.lamb_init)
                            np_eps = anneal_linear(i, hps.n_epochs * hps.anneal_frac_eps, hps.eps_final, hps.eps_init)

                        # Evaluate
                        if i % hps.eval_freq == 1:
                            if hps.use_target_net:
                                sess.run([model.copy_op])

                            if (not hps.artificial_data) and hps.test_on_policy:
                                s_valid, a_valid, s1_valid, r_valid, term_valid = make_test_data(sess, model, Test_env, hps.test_size, epsilon=0.05)
                                s_test, a_test, s1_test, r_test, term_test = make_test_data(sess, model, Test_env, hps.test_size, epsilon=0.05)
                            else:
                                s_valid, a_valid, s1_valid, r_valid, term_valid = s_valid_pre, a_valid_pre, s1_valid_pre, r_valid_pre, term_valid_pre
                                s_test, a_test, s1_test, r_test, term_test = s_test_pre, a_test_pre, s1_test_pre, r_test_pre, term_test_pre

                            train_elbo, train_nats, train_kl, train_rl_cost = sess.run([model.elbo, model.nats, model.kl, model.rl_cost], {model.x: s_batch,
                                                                                                                                          model.y: s1_batch,
                                                                                                                                          model.a: a_batch,
                                                                                                                                          model.Qtarget: Qtarget_batch,
                                                                                                                                          model.lamb: np_lamb,
                                                                                                                                          model.temp: 0.0001,
                                                                                                                                          model.is_training: True,
                                                                                                                                          model.k: 40})

                            valid_nats = sess.run(model.nats, {model.x: s_valid,
                                                               model.y: s1_valid,
                                                               model.a: a_valid,
                                                               model.Qtarget: np.zeros(np.shape(a_valid)),
                                                               model.lamb: np_lamb,
                                                               model.temp: 0.0001,
                                                               model.is_training: True,
                                                               model.k: 40})
                            test_nats = sess.run(model.nats, {model.x: s_test,
                                                              model.y: s1_test,
                                                              model.a: a_test,
                                                              model.Qtarget: np.zeros(np.shape(a_test)),
                                                              model.lamb: np_lamb,
                                                              model.temp: 0.0001,
                                                              model.is_training: True,
                                                              model.k: 40})
                            if hps.verbose:
                                print('Step', i, 'ELBO: ', train_elbo, 'Training nats:', train_nats, 'Training KL:', train_kl, 'RL cost:', train_rl_cost,
                                      ' \n Valid nats', valid_nats, ' Test set nats', test_nats,
                                      ' \n Average reward in last 50 batches', np.mean(epoch_reward[-50:]), 'Learning rate', np_lr, 'Softmax Temp', np_temp, 'Epsilon:', np_eps)

                            t.extend([i])
                            train_nats_keep.extend([train_nats])
                            valid_nats_keep.extend([valid_nats])
                            test_nats_keep.extend([test_nats])
                            elbo_keep.extend([train_elbo])
                            if valid_nats < min_valid_nats:
                                min_valid_nats = valid_nats
                                best_test_nats = test_nats
                                best_elbo = train_nats
                                best_iter = i
                                saver.save(sess, FLAGS.check_dir)

                    # VAE storage
                    print('Best result in iteration', best_iter, 'with valid_nats', min_valid_nats, 'and test nats', best_test_nats)
                    saver.restore(sess, FLAGS.check_dir)
                    print('Restored best VAE model')

                    # nats
                    plt.figure()
                    plt.plot(t, train_nats_keep, label='train nats')
                    plt.plot(t, valid_nats_keep, label='valid nats')
                    plt.plot(t, test_nats_keep, label='test nats')
                    plt.plot(t, elbo_keep, label='ELBO')
                    plt.legend(loc=0)
                    if hps.loop_hyper:
                        save(os.path.join(hps.my_dir, 'nats_{}={}_{}={}_{}'.format(hps.item1, item1, hps.item2, item2, rep)), ext='png', close=True, verbose=False)
                    else:
                        save(os.path.join(hps.my_dir, 'nats{}'.format(rep)), ext='png', close=True, verbose=False)
                    results[j, l, rep] = best_test_nats
                    results_elbo[j, l, rep] = best_elbo

                    # Distances from true distribution
                    distances = kl_preds_v2(model, sess, s_test, a_test)
                    results_distances[j, l, :, rep] = distances

                    # Visualize some predictions
                    n_row = 2
                    n_col = 2
                    s_start = np.array([0, 0, 1, 3, 5, 3])
                    for extra_rep in range(3):
                        if hps.test_on_policy:
                            s_start = plot_predictions(model, sess, n_row, n_col, rep, hps, True, s_start)
                        else:
                            s_start = plot_predictions(model, sess, n_row, n_col, rep, hps, False)
                        if hps.loop_hyper:
                            # extra_rep in the name so the three plots do not overwrite each other
                            name = os.path.join(hps.my_dir, 'predictions_{}={}_{}={}_{}{}'.format(hps.item1, item1, hps.item2, item2, rep, extra_rep))
                        else:
                            name = os.path.join(hps.my_dir, 'predictions_{}{}'.format(rep, extra_rep))
                        save(name, ext='png', close=True, verbose=False)

                    ############# RL ################
                    if not hps.artificial_data:
                        window = 200
                        av_reward = np.convolve(epoch_reward, np.ones((window,))/window, mode='valid')
                        av_rewards[j][l][rep] = av_reward  # average rewards of RL agent

                    # Show learned behaviour
                    if (not hps.artificial_data) and hps.verbose:
                        print('Start evaluating policy')
                        # Separate env: rebinding `Env` would make every later
                        # repetition train in this visualisation environment.
                        Show_env = grid_env(True)
                        s = Show_env.reset()
                        for lll in range(100):
                            Qsa = sess.run(model.Qsa, feed_dict={model.x: s[None, :],
                                                                 model.k: 1,
                                                                 })
                            a = egreedy(Qsa[0], 0.01)
                            s, r, dead = Show_env.step([a])
                            Show_env.plot()
                            if dead:
                                print('Died in step', lll, ', restarting')
                                s = Show_env.reset()
                        plt.close()
    # Overall results
    results_raw = results
    results = np.mean(results, axis=2)
    results_raw_elbo = results_elbo
    results_elbo = np.mean(results_elbo, axis=2)
    results_raw_distances = results_distances
    results_distances = np.mean(results_distances, axis=3)

    logging.info('-------------------- Overall Results --------------------------')
    logging.info('vae' if hps.network == 1 else 'mlp_{}'.format('deterministic' if hps.deterministic else 'stochastic'))
    logging.info('Latent type %s of depth %s', hps.var_type[0], hps.depth)
    logging.info('(z_size,n_flow) %s %s and (n,k) %s %s', hps.z_size, hps.n_flow, hps.N, hps.K)
    logging.info('Results over %s runs', n_rep)
    logging.info('Test nats: %s', results[0, 0])
    logging.info('Elbo: %s', -1*results_elbo[0, 0])
    logging.info('KL with true distr: %s', results_distances[0, 0, :])
    logging.info('Raw data over repetitions \n %s \n %s \n %s', results_raw, results_raw_elbo, results_raw_distances)

    if hps.loop_hyper:
        plt.figure()
        for col in range(results.shape[1]):
            plt.plot(range(len(seq1)), results[:, col], label='{} = {}'.format(hps.item2, hps.seq2[col]))
        plt.xlabel(hps.item1)
        # one labelled tick per plotted index (set_xticklabels alone relabels
        # whatever automatic tick positions matplotlib chose)
        plt.xticks(range(len(seq1)), hps.seq1)
        plt.legend(loc=0)
        save(os.path.join(os.getcwd(), FLAGS.save_dir, 'run_{}/looped'.format(make_name(hps))), ext='png', close=True, verbose=False)

        if not hps.artificial_data:
            plt.figure()
            for ii in range(len(seq1)):
                for jj in range(len(seq2)):
                    signal = np.mean(np.array(av_rewards[ii][jj]), axis=0)
                    plt.plot(range(len(signal)), signal, label='{} = {},{} = {}'.format(hps.item1, hps.seq1[ii], hps.item2, hps.seq2[jj]))
            plt.xlabel('Steps')
            plt.legend(loc=0)
            save(os.path.join(os.getcwd(), FLAGS.save_dir, 'run_{}/looped_reward'.format(make_name(hps))), ext='png', close=True, verbose=False)


def init_logger(hps, my_dir=None):
    ''' Log INFO messages (bare message format) to my_dir/results.txt and the console. '''
    os.makedirs(my_dir, exist_ok=True)
    handlers = [logging.FileHandler(os.path.join(my_dir, 'results.txt'), mode='w'),
                logging.StreamHandler()]
    logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=handlers)


def main(_):
    ''' tf.app entry point: parse --hpconfig, prepare the output directory, then run. '''
    hps = get_hps().parse(FLAGS.hpconfig)
    # random suffix, presumably so separate runs do not share a checkpoint path
    FLAGS.check_dir = FLAGS.check_dir + str(np.random.randint(0, 1e7, 1)[0])

    if hps.depth > 1 and len(hps.var_type) == 1:
        hps.var_type = [hps.var_type[0] for i in range(hps.depth)]
    print(hps.var_type)

    if not hps.loop_hyper:
        hps._set('seq1', [hps._items[hps.item1]])
        hps._set('seq2', [hps._items[hps.item2]])

    # logging and saving
    hps.my_dir = os.path.join(os.getcwd(), FLAGS.save_dir, '{}'.format(make_name(hps)))
    init_logger(hps, hps.my_dir)
    with open(os.path.join(hps.my_dir, 'hps.txt'), 'w') as file:
        file.write(repr(hps._items))

    run(hps)

if __name__ == "__main__":
    tf.app.run()


# -----------------------------------------------------------------------------
# /vae_main.py
# -----------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Variational inference on Toy Domain
@author: thomas
"""
import tensorflow as tf
import numpy as np
import plotter.Plotter2 as Plotter
from tfutils.helpers import anneal_linear, HParams
import matplotlib.pyplot as plt
import os
from pythonutils.helpers import save, make_name
import logging

# settings
flags = tf.app.flags
flags.DEFINE_string("save_dir", "results/toy", "Logging directory.")
flags.DEFINE_string("hpconfig", "", "Overrides default hyper-parameters.")
FLAGS = flags.FLAGS


def get_hps():
    ''' Default hyperparameter settings (override via --hpconfig). '''
    return HParams(
        dataset = 2,                  # which dataset to run on
        network = 1,                  # which network to run
        #
        data_size = 2000,
        n_epochs = 30000,
        batch_size = 64,
        eval_freq = 500,              # Evaluate every .. steps
        debug = False,

        # Learning
        lr_init = 0.005,
        lr_final = 0.0005,
        anneal_frac_lr = 0.7,

        # Latent dimensions
        #var_type = ['continuous','discrete'],
        var_type = ['continuous'],
        depth = 1,                    # depth of stochastic layers
        h_size = 30,                  # representation size in stoch layers
        deterministic = False,        # used for the MLP only

        # discrete
        K = 3,                        # categories per discrete latent
        N = 3,                        # number of discrete latents
        tau_init = 2.0,
        tau_final = 0.001,
        anneal_frac_tau = 0.7,

        # cont
        z_size = 3,                   # number of cont latents
        n_flow = 5,                   # depth of flow (only for continuous latents, for now)
        ar = False,
        ignore_sigma_latent = False,
        ignore_sigma_flow = False,

        # KL divergence
        k = 3,
        alpha = 0.5,
        kl_min = 0.07,                # Number of "free bits/nats", only used if >0
        use_lamb = False,
        lamb_init = 0.1,              # Annealing KL contribution. Don't use together with kl_min>0
        lamb_final = 1.0,
        anneal_frac_lamb = 0.3,

        # Architecture
        ladder = True,                # For ladder = True, top-down inference
        resnet = False,               # Whether to use Resnet connections

        # Outcome
        out_lik = 'normal',
        ignore_sigma_outcome = False,

        # Specify how to loop over hyperparameter settings
        loop_hyper = False,
        item1 = 'lr_init',
        seq1 = [0.005],
        item2 = 'kl_min',
        seq2 = [0.07],
        n_rep = 10,
        verbose = True
        )


def run(hps):
    ''' Train and evaluate the toy-domain model for every hyperparameter setting.

    For every (item1, item2) setting in (seq1, seq2) and each of n_rep
    repetitions: builds a fresh graph, trains for n_epochs batches, evaluates
    train/valid/test nats every eval_freq steps, keeps the test nats and
    sample of the step with the lowest validation nats, saves sample and
    learning-curve plots to hps.my_dir, and finally logs averaged results.

    Raises:
        ValueError: hps.dataset or hps.network is not a known option.
    '''
    if hps.verbose:
        plt.ion()
    ## Generate data
    if hps.dataset == 1:
        from datasets import Dataset as Dataset
    elif hps.dataset == 2:
        from datasets import Dataset2 as Dataset
    else:
        raise ValueError('Unknown dataset: {}'.format(hps.dataset))

    Data = Dataset.Dataset(hps.data_size)
    Data_valid = Dataset.Dataset(int(hps.data_size/4))
    Data_test = Dataset.Dataset(int(hps.data_size/4))

    n_rep = hps.n_rep
    seq1 = hps.seq1
    seq2 = hps.seq2
    results = np.empty([len(seq1), len(seq2), n_rep])
    results_elbo = np.empty([len(seq1), len(seq2), n_rep])
    for j, item1 in enumerate(seq1):
        hps._set(hps.item1, item1)
        for l, item2 in enumerate(seq2):
            hps._set(hps.item2, item2)
            for rep in range(n_rep):
                tf.reset_default_graph()

                # Initialize anneal parameters
                np_lr = anneal_linear(0, hps.n_epochs * hps.anneal_frac_lr, hps.lr_final, hps.lr_init)
                np_temp = anneal_linear(0, hps.n_epochs * hps.anneal_frac_tau, hps.tau_final, hps.tau_init)
                np_lamb = anneal_linear(0, hps.n_epochs * hps.anneal_frac_lamb, hps.lamb_final, hps.lamb_init)

                Plot = Plotter.Plotter()
                Plot.plot_data(Data)
                # Build network
                if hps.network == 1:
                    import networks.toy_vae as net
                    model = net.Network(hps)
                elif hps.network == 2:
                    import networks.toy_mlp as net
                    model = net.Network(hps)
                else:
                    raise ValueError('Unknown network: {}'.format(hps.network))

                # Check model size
                total_size = 0
                for v in tf.trainable_variables():
                    total_size += np.prod([int(s) for s in v.get_shape()])
                print("Total number of trainable variables: {}".format(total_size))

                # Session and initialization; the context manager closes the
                # session after each repetition instead of leaking it.
                with tf.Session() as sess:
                    if hps.debug:
                        sess = tf.python.debug.LocalCLIDebugWrapperSession(sess)
                        sess.add_tensor_filter("has_inf_or_nan", tf.python.debug.has_inf_or_nan)

                    np_x, np_y = Data.next_batch_random(hps.batch_size)
                    sess.run(model.init_op, feed_dict={model.x: np_x[:, None],
                                                       model.y: np_y[:, None]})

                    # Some storage
                    t = []
                    neg_elbo_keep = []
                    train_nats_keep = []
                    valid_nats_keep = []
                    test_nats_keep = []
                    min_valid_nats = 1e50
                    best_sample = []
                    best_test_nats = 0.0
                    best_elbo = 0.0
                    best_iter = 0

                    # Train
                    print('Initialized, starting to train')
                    for i in range(hps.n_epochs):
                        # draw batch
                        np_x, np_y = Data.next_batch_epoch(hps.batch_size)
                        _, np_elbo = sess.run([model.train_op, model.elbo], {model.x: np_x[:, None],
                                                                            model.y: np_y[:, None],
                                                                            model.lr: np_lr,
                                                                            model.lamb: np_lamb,
                                                                            model.temp: np_temp,
                                                                            model.is_training: True,
                                                                            model.k: hps.k})
                        # Annealing
                        if i % 250 == 1:
                            np_lr = anneal_linear(i, hps.n_epochs * hps.anneal_frac_lr, hps.lr_final, hps.lr_init)
                            np_temp = anneal_linear(i, hps.n_epochs * hps.anneal_frac_tau, hps.tau_final, hps.tau_init)
                            np_lamb = anneal_linear(i, hps.n_epochs * hps.anneal_frac_lamb, hps.lamb_final, hps.lamb_init)

                        # Evaluate
                        if i % hps.eval_freq == 1:
                            train_elbo, train_nats, train_kl = sess.run([model.elbo, model.nats, model.kl], {model.x: Data.X[:, None],
                                                                                                           model.y: Data.Y[:, None],
                                                                                                           model.lamb: np_lamb,
                                                                                                           model.temp: 0.0001,
                                                                                                           model.is_training: True,  # to sample from q(z|y,x), we're not running the train_op anyway
                                                                                                           model.k: 40})

                            valid_nats = sess.run(model.nats, {model.x: Data_valid.X[:, None],
                                                               model.y: Data_valid.Y[:, None],
                                                               model.lamb: np_lamb,
                                                               model.temp: 0.0001,
                                                               model.is_training: True,  # to sample from q(z|y,x), we're not running the train_op anyway
                                                               model.k: 40})

                            test_nats = sess.run(model.nats, {model.x: Data_test.X[:, None],
                                                              model.y: Data_test.Y[:, None],
                                                              model.lamb: np_lamb,
                                                              model.temp: 0.0001,
                                                              model.is_training: True,  # to sample from q(z|y,x), we're not running the train_op anyway
                                                              model.k: 40})
                            if hps.verbose:
                                print('Step', i, 'ELBO: ', train_elbo, 'Training nats:', train_nats, 'Training KL:', train_kl, 'Valid nats', valid_nats,
                                      ' \n Test set nats', test_nats, 'Learning rate', np_lr, 'Softmax Temp', np_temp,)

                            # draw new data
                            # NOTE(review): model.y is fed Data_test.X here; presumably y is
                            # ignored when is_training=False (generation) — confirm in networks/.
                            y_samples = sess.run(model.y_sample, {model.x: Data_test.X[:, None],
                                                                  model.y: Data_test.X[:, None],
                                                                  model.lamb: np_lamb,
                                                                  model.temp: 0.0001,
                                                                  model.is_training: False,
                                                                  model.k: 1})
                            t.extend([i])
                            train_nats_keep.extend([train_nats])
                            valid_nats_keep.extend([valid_nats])
                            test_nats_keep.extend([test_nats])
                            neg_elbo_keep.extend([train_elbo])
                            if hps.verbose:
                                Plot.plot_samples(Data_test.X[:, None], y_samples)

                            if valid_nats < min_valid_nats:
                                min_valid_nats = valid_nats
                                best_sample = y_samples  # keep the sample
                                best_test_nats = test_nats
                                best_elbo = train_nats
                                best_iter = i

                Plot.plot_samples(Data_test.X[:, None], best_sample)
                save(os.path.join(hps.my_dir, 'sample{}'.format(rep)), ext='png', close=True, verbose=False)
                print('Best result in iteration', best_iter, 'with valid_nats', min_valid_nats, 'and test nats', best_test_nats)

                fig = plt.figure()
                plt.plot(t, train_nats_keep, label='train nats')
                plt.plot(t, valid_nats_keep, label='valid nats')
                plt.plot(t, test_nats_keep, label='test nats')
                plt.plot(t, neg_elbo_keep, label='negative ELBO')
                plt.legend(loc=0)
                fig.canvas.draw()
                save(os.path.join(hps.my_dir, 'nats{}'.format(rep)), ext='png', close=True, verbose=False)

                results[j, l, rep] = best_test_nats
                results_elbo[j, l, rep] = best_elbo

    results_raw = results
    results_raw_elbo = results_elbo
    results = np.mean(results, axis=2)
    results_elbo = np.mean(results_elbo, axis=2)

    logging.info('-------------------- Overall Results --------------------------')
    logging.info('vae' if hps.network == 1 else 'mlp_{}'.format('deterministic' if hps.deterministic else 'stochastic'))
    logging.info('Latent type %s of depth %s', hps.var_type[0], hps.depth)
    logging.info('(z_size,n_flow) %s %s and (n,k) %s %s', hps.z_size, hps.n_flow, hps.N, hps.K)
    logging.info('Results over %s runs', n_rep)
    logging.info('Test nats: %s', results[0][0])
    logging.info('Elbo: %s', -1*results_elbo[0][0])
    logging.info('Raw data over repetitions \n %s \n %s', results_raw, results_raw_elbo)

    if hps.loop_hyper:
        plt.figure()
        for col in range(results.shape[1]):
            plt.plot(range(len(seq1)), results[:, col], label='{} = {}'.format(hps.item2, hps.seq2[col]))
        plt.xlabel(hps.item1)
        # one labelled tick per plotted index (set_xticklabels alone relabels
        # whatever automatic tick positions matplotlib chose)
        plt.xticks(range(len(seq1)), hps.seq1)
        plt.legend(loc=0)
        save(os.path.join(os.getcwd(), FLAGS.save_dir, 'run_{}/looped'.format(make_name(hps))), ext='png', close=True, verbose=False)


def init_logger(hps, my_dir=None):
    ''' Log INFO messages (bare message format) to my_dir/results.txt and the console. '''
    os.makedirs(my_dir, exist_ok=True)
    handlers = [logging.FileHandler(os.path.join(my_dir, 'results.txt'), mode='w'),
                logging.StreamHandler()]
    logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=handlers)


def main(_):
    ''' tf.app entry point: parse --hpconfig, prepare the output directory, then run. '''
    hps = get_hps().parse(FLAGS.hpconfig)

    if hps.depth > 1 and len(hps.var_type) == 1:
        hps.var_type = [hps.var_type[0] for i in range(hps.depth)]
    print(hps.var_type)

    if not hps.loop_hyper:
        hps._set('seq1', [hps._items[hps.item1]])
        hps._set('seq2', [hps._items[hps.item2]])

    # logging and saving
    hps.my_dir = os.path.join(os.getcwd(), FLAGS.save_dir, 'run_{}'.format(make_name(hps)))
    init_logger(hps, hps.my_dir)
    with open(os.path.join(hps.my_dir, 'hps.txt'), 'w') as file:
        file.write(repr(hps._items))

    run(hps)

if __name__ == "__main__":
    tf.app.run()