├── .gitignore ├── images ├── 2Dlatentspace_with_hit-miss.png ├── 2Dlatentspace_with_x_coordinate.png └── 2Dlatentspace_with_y_coordinate.png ├── VAE_rec_main.py ├── VAE_rec_model.py ├── VAE_rec_model_reverse.py ├── VAE_util.py └── output.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /images/2Dlatentspace_with_hit-miss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobRomijnders/VAE_rec/HEAD/images/2Dlatentspace_with_hit-miss.png -------------------------------------------------------------------------------- /images/2Dlatentspace_with_x_coordinate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobRomijnders/VAE_rec/HEAD/images/2Dlatentspace_with_x_coordinate.png -------------------------------------------------------------------------------- /images/2Dlatentspace_with_y_coordinate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RobRomijnders/VAE_rec/HEAD/images/2Dlatentspace_with_y_coordinate.png -------------------------------------------------------------------------------- /VAE_rec_main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 22 10:43:29 2016 4 | 5 | @author: Rob Romijnders 6 | 7 | TODO 8 | - Cross validate over different learning-rates 9 | """ 10 | import sys 11 | import socket 12 | 13 | if 'rob-laptop' in socket.gethostname(): 14 | sys.path.append('/home/rob/Dropbox/ml_projects/basket_local/') 15 | sys.path.append('/home/rob/Dropbox/ml_projects/basket_local/SportVU-seq') 16 | #The folder where your dataset is. Note that is must end with a '/' 17 | direc = '/home/rob/Dropbox/ml_projects/basket_local/SportVU-seq/' 18 | elif 'rob-com' in socket.gethostname(): 19 | sys.path.append('/home/rob/Documents/nn_sportvu') 20 | direc = '/home/rob/Documents/nn_sportvu/SportVU-seq/' 21 | 22 | #Rajiv: you can add your computer name here 23 | 24 | 25 | import numpy as np 26 | import tensorflow as tf 27 | import matplotlib.pyplot as plt 28 | from tensorflow.python.framework import ops 29 | from tensorflow.python.ops import clip_ops 30 | from basket_util import * 31 | from VAE_util import * 32 | import sklearn as sk 33 | 34 | from sklearn.metrics import roc_auc_score,roc_curve 35 | 36 | from data_loader_class import * 37 | from mpl_toolkits.mplot3d import axes3d 38 | import matplotlib.mlab as mlab 39 | from VAE_rec_model_reverse import * 40 | 41 | 42 | 43 | """Hyperparameters""" 44 | config = {} 45 | config['num_layers'] = 2 46 | config['hidden_size'] = 60 47 | config['max_grad_norm'] = 1 48 | config['batch_size'] = batch_size = 64 49 | config['sl'] = sl = 18 #sequence length 50 | config['mixtures'] = 1 51 | config['learning_rate'] = .005 52 | config['num_l'] = num_l = 2 53 | 54 | 55 | 56 | ratio = 0.8 #Ratio for train-val split 57 | plot_every = 100 #How often do you want terminal output for the performances 58 | max_iterations = 50000 59 | 60 | 61 | 62 | """Load the data""" 63 | #The name of the dataset. Note that it must end with '.csv' 64 | csv_file = 'seq_all_9feet.csv' 65 | #Load an instance 66 | center = np.array([5.25, 25.0, 10.0]) 67 | dl = DataLoad(direc,csv_file, center) 68 | #Munge the data. Arguments see the class 69 | db = 4 #distance to basket 70 | dl.munge_data(11,sl,db) 71 | #Center the data 72 | dl.center_data(center) 73 | dl.entropy_offset() 74 | dl.split_train_test(ratio = 0.8) 75 | data_dict = dl.data 76 | dl.plot_traj_2d(20,'at %.0f feet from basket'%db) 77 | 78 | X_train = np.transpose(data_dict['X_train'],[0,2,1]) 79 | #y_train = data_dict['y_train'] 80 | X_val = np.transpose(data_dict['X_val'],[0,2,1]) 81 | y_val = data_dict['y_val'] 82 | 83 | N,crd,_ = X_train.shape 84 | Nval = X_val.shape[0] 85 | 86 | config['crd'] = crd 87 | 88 | #Proclaim the epochs 89 | epochs = np.floor(batch_size*max_iterations / N) 90 | print('Train with approximately %d epochs' %(epochs)) 91 | 92 | model = Model(config) 93 | 94 | # For now, we collect performances in a Numpy array. 95 | # In future releases, I hope TensorBoard allows for more 96 | # flexibility in plotting 97 | perf_collect = np.zeros((7,int(np.floor(max_iterations /plot_every)))) 98 | 99 | sess = tf.Session() 100 | 101 | #with tf.Session() as sess: 102 | if True: 103 | writer = tf.train.SummaryWriter("/home/rob/Dropbox/ml_projects/basket_local/nn_sportvu/log_tb", sess.graph) 104 | 105 | sess.run(tf.initialize_all_variables()) 106 | 107 | step = 0 # Step is a counter for filling the numpy array perf_collect 108 | for i in range(max_iterations): 109 | batch_ind = np.random.choice(N,batch_size,replace=False) 110 | # debug = sess.run(model.sl_t,feed_dict={model.x:X_train[batch_ind], model.y_: y_train[batch_ind], model.keep_prob: dropout}) 111 | # print(np.max(debug[0])) 112 | # print(np.max(debug[1])) 113 | if i%plot_every == 0: 114 | #Check training performance 115 | fetch = [model.cost_seq, model.cost_kld, model.cost_xstart] 116 | 117 | result = sess.run(fetch,feed_dict = { model.x: X_train[batch_ind]}) 118 | perf_collect[0,step] = cost_train_seq = result[0] 119 | perf_collect[1,step] = cost_train_kld = result[1] 120 | perf_collect[4,step] = cost_train_xstart = result[2] 121 | 122 | #Check validation performance 123 | batch_ind_val = np.random.choice(Nval,batch_size,replace=False) 124 | fetch = [model.cost_seq, model.cost_kld, model.cost_xstart] #, model.merged 125 | 126 | result = sess.run(fetch, feed_dict={ model.x: X_val[batch_ind_val]}) 127 | 128 | perf_collect[2,step] = cost_val_seq = result[0] 129 | perf_collect[3,step] = cost_val_kld = result[1] 130 | perf_collect[5,step] = cost_val_xstart = result[2] 131 | 132 | # #Write information to TensorBoard 133 | # summary_str = result[3] 134 | # writer.add_summary(summary_str, i) 135 | # writer.flush() #Don't forget this command! It makes sure Python writes the summaries to the log-file 136 | print("At %6s / %6s train (%6.3f,%6.3f,%6.3f) val (%6.3f,%6.3f,%6.3f)" % (i,max_iterations,cost_train_seq,cost_train_kld,cost_train_xstart,cost_val_seq,cost_val_kld,cost_val_xstart )) 137 | step +=1 138 | sess.run(model.train_step,feed_dict={model.x:X_train[batch_ind]}) 139 | #In the next line we also fetch the softmax outputs 140 | batch_ind_val = np.random.choice(Nval,batch_size,replace=False) 141 | result = sess.run([model.numel], feed_dict={ model.x: X_val[batch_ind_val]}) 142 | print('The network has %s trainable parameters'%(result[0])) 143 | 144 | # debug = sess.run(model.b_xend) 145 | z_feed = np.random.randn(batch_size,num_l) 146 | result = sess.run(model.x_col, feed_dict={ model.z: z_feed}) 147 | 148 | X_vae = np.transpose(result,[1,2,0]) 149 | labels_dummy = np.random.randint(0,1,size=(batch_size,1)) 150 | plot_basket(X_vae,labels_dummy) 151 | 152 | """Visualize the 2D latent space""" 153 | label_type = 'class' #Color scatter plot according to hit/miss 154 | label_type = 'x' #Color the scatter plot according to x coordinate 155 | label_type = 'y' #Color scatter plot according to y coordinate 156 | 157 | 158 | if num_l == 2: 159 | ##Extract the latent space coordinates of the validation set 160 | start = 0 161 | label = [] #The label to save to visualize the latent space 162 | z_run = [] 163 | 164 | while start + batch_size < Nval: 165 | run_ind = range(start,start+batch_size) 166 | z_mu_fetch = sess.run(model.z_mu, feed_dict = {model.x:X_val[run_ind]}) 167 | z_run.append(z_mu_fetch) 168 | if label_type == 'y': 169 | label.append(X_val[run_ind,1,0]) #The y coordinate of x_start 170 | if label_type == 'x': 171 | label.append(X_val[run_ind,0,0]) #The y coordinate of x_start 172 | if label_type == 'class': 173 | label.append(y_val[run_ind]) 174 | start += batch_size 175 | 176 | z_run = np.concatenate(z_run,axis=0) 177 | label = np.concatenate(label,axis=0) 178 | 179 | plt.figure() 180 | plt.scatter(z_run[:,0],z_run[:,1],c = label,linewidths=0.0) -------------------------------------------------------------------------------- /VAE_rec_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 22 10:43:29 2016 4 | 5 | @author: Rob Romijnders 6 | 7 | TODO 8 | - Cross validate over different learning-rates 9 | """ 10 | import sys 11 | sys.path.append('/home/rob/Dropbox/ml_projects/VAE_rec') 12 | import numpy as np 13 | import tensorflow as tf 14 | import matplotlib.pyplot as plt 15 | from tensorflow.python.framework import ops 16 | from tensorflow.python.ops import clip_ops 17 | from VAE_util import * 18 | 19 | 20 | 21 | class Model(): 22 | def __init__(self,config): 23 | """Hyperparameters""" 24 | num_layers = config['num_layers'] 25 | hidden_size = config['hidden_size'] 26 | max_grad_norm = config['max_grad_norm'] 27 | batch_size = config['batch_size'] 28 | sl = config['sl'] 29 | mixtures = config['mixtures'] 30 | crd = config['crd'] 31 | learning_rate = config['learning_rate'] 32 | num_l = config['num_l'] 33 | self.sl = sl 34 | self.crd = crd 35 | self.batch_size = batch_size 36 | 37 | 38 | #Function for initialization 39 | def xv_init(arg_in, arg_out,shape=None): 40 | low = -np.sqrt(6.0/(arg_in + arg_out)) 41 | high = np.sqrt(6.0/(arg_in + arg_out)) 42 | if shape is None: 43 | tensor_shape = (arg_in, arg_out) 44 | return tf.random_uniform(tensor_shape, minval=low, maxval=high, dtype=tf.float32) 45 | 46 | 47 | 48 | 49 | 50 | # Nodes for the input variables 51 | self.x = tf.placeholder("float", shape=[batch_size, crd,sl], name = 'Input_data') 52 | x_next = tf.sub(self.x[:,:3,1:], self.x[:,:3,:sl-1]) 53 | xn1,xn2,xn3 = tf.split(1,3,x_next) #Now tensors in [batch_size,1,seq_len-1] 54 | 55 | 56 | with tf.variable_scope("Enc") as scope: 57 | cell_enc = tf.nn.rnn_cell.LSTMCell(hidden_size) 58 | cell_enc = tf.nn.rnn_cell.MultiRNNCell([cell_enc] * num_layers) 59 | 60 | #Initial state 61 | initial_state_enc = cell_enc.zero_state(batch_size, tf.float32) 62 | 63 | 64 | outputs_enc = [] 65 | self.states_enc = [] 66 | state = initial_state_enc 67 | for time_step in range(sl): 68 | if time_step > 0: tf.get_variable_scope().reuse_variables() 69 | (cell_output, state) = cell_enc(self.x[:, :, time_step], state) 70 | outputs_enc.append(cell_output) 71 | 72 | with tf.name_scope("Enc_2_lat") as scope: 73 | #m_enc,h_enc = tf.split(1,2,self.final_state) 74 | #layer for mean of z 75 | W_mu = tf.Variable(xv_init(hidden_size,num_l)) 76 | b_mu = tf.Variable(tf.constant(0.1,shape=[num_l],dtype=tf.float32)) 77 | self.z_mu = tf.nn.xw_plus_b(cell_output,W_mu,b_mu) #mu, mean, of latent space 78 | 79 | #layer for sigma of z 80 | W_sig = tf.Variable(xv_init(hidden_size,num_l)) 81 | b_sig = tf.Variable(tf.constant(0.1,shape=[num_l],dtype=tf.float32)) 82 | z_sig_log_sq = tf.nn.xw_plus_b(cell_output,W_sig,b_sig) #sigma of latent space, in log-scale and squared. 83 | # This log_sq will save computation later on. log(sig^2) is a real number, so no sigmoid is necessary 84 | 85 | with tf.name_scope("Latent_space") as scope: 86 | self.eps = tf.random_normal(tf.shape(self.z_mu),0,1,dtype=tf.float32) 87 | self.z = self.z_mu + tf.mul(tf.sqrt(tf.exp(z_sig_log_sq)),self.eps) #Z is the vector in latent space 88 | 89 | with tf.variable_scope("Lat_2_dec") as scope: 90 | #Create initial vector 91 | params_xstart = 3 + 4 # 3 (X,Y,Z) plus 4 (sx,sx,sz,rho) 92 | W_xstart = tf.Variable(xv_init(num_l,params_xstart)) 93 | self.b_xstart = tf.Variable(tf.constant([5.0,2.0,0.1,8.0,12.0,2.0,0.2],dtype=tf.float32)) 94 | self.parameters_xstart = tf.nn.xw_plus_b(self.z,W_xstart,self.b_xstart) 95 | 96 | mu1x,mu2x,mu3x,s1x,s2x,s3x,rhox = tf.split(1,params_xstart,self.parameters_xstart) #Individual vectors in [batch_size,1] 97 | s1x = tf.exp(s1x) 98 | s2x = tf.exp(s2x) 99 | s3x = tf.exp(s3x) 100 | rhox = tf.tanh(rhox) 101 | x_start = tf.concat(1,[mu1x,mu2x,mu3x]) 102 | 103 | 104 | #Reconstruction loss for x_start 105 | xs1,xs2,xs3 = tf.split(1,3,self.x[:,:3,0]) 106 | pxstart12 = tf_2d_normal(xs1, xs2, mu1x, mu2x, s1x, s2x, rhox) #probability in x1x2 plane 107 | pxstart3 = tf_1d_normal(xs3,mu3x,s3x) 108 | pxstart = tf.mul(pxstart12,pxstart3) 109 | loss_xstart = -tf.log(tf.maximum(pxstart, 1e-20)) # at the beginning, some errors are exactly zero. 110 | self.cost_xstart = tf.reduce_mean(loss_xstart)###tf.constant(0.0)# 111 | #Create initial hidden state and memory state 112 | W_hstart = tf.Variable(xv_init(num_l,hidden_size)) 113 | b_hstart = tf.Variable(tf.constant(0.01,shape=[hidden_size],dtype=tf.float32)) 114 | h_start = tf.nn.xw_plus_b(self.z,W_hstart,b_hstart) 115 | 116 | with tf.variable_scope("Out_layer") as scope: 117 | params = 7 # x,y,z,sx,sy,sz,rho 118 | output_units = mixtures*params #Two for distribution over hit&miss, params for distribution parameters 119 | W_o = tf.Variable(tf.random_normal([hidden_size,output_units], stddev=0.1)) 120 | b_o = tf.Variable(tf.constant(0.5, shape=[output_units])) 121 | 122 | 123 | with tf.variable_scope("Dec") as scope: 124 | cell_dec = tf.nn.rnn_cell.LSTMCell(hidden_size) 125 | cell_dec = tf.nn.rnn_cell.MultiRNNCell([cell_dec] * num_layers) 126 | 127 | #Initial state 128 | initial_state_dec = tf.tile(h_start,[1,2*num_layers]) 129 | PARAMS = [] 130 | self.states = [] 131 | state = initial_state_dec 132 | x_in = x_start 133 | x_collect = [] 134 | x_collect.append(x_in) 135 | for time_step in range(sl): 136 | if time_step > 0: tf.get_variable_scope().reuse_variables() 137 | (cell_output, state) = cell_dec(x_in, state) 138 | self.states.append(state) 139 | #Convert hidden state to offset for the next 140 | params_MDN = tf.nn.xw_plus_b(cell_output,W_o,b_o) # Now in [batch_size,output_units] 141 | PARAMS.append(params_MDN) 142 | x_in = x_in + params_MDN[:,:3] #First three columns are the new x_in 143 | x_collect.append(x_in) 144 | 145 | #Prepare x_collect for extraction 146 | self.x_col = tf.pack(x_collect) #in [seq_len, batch_size,crd] 147 | 148 | 149 | with tf.variable_scope("Loss_calc") as scope: 150 | ### Reconstruction loss 151 | PARAMS = tf.pack(PARAMS[:-1]) 152 | PARAMS = tf.transpose(PARAMS,[1,2,0]) # Now in [batch_size, output_units,seq_len-1] 153 | mu1,mu2,mu3,s1,s2,s3,rho = tf.split(1,7,PARAMS) #Each Tensor in [batch_size,seq_len-1] 154 | s1 = tf.exp(s1) 155 | s2 = tf.exp(s2) 156 | s3 = tf.exp(s3) 157 | rho = tf.tanh(rho) 158 | px1x2 = tf_2d_normal(xn1, xn2, mu1, mu2, s1, s2, rho) #probability in x1x2 plane 159 | px3 = tf_1d_normal(xn3,mu3,s3) 160 | px1x2x3 = tf.mul(px1x2,px3) #Now in [batch_size,1,seq_len-1] 161 | loss_seq = -tf.log(tf.maximum(px1x2x3, 1e-20)) # at the beginning, some errors are exactly zero. 162 | self.cost_seq = tf.reduce_mean(loss_seq) 163 | 164 | ### KL divergence between posterior on encoder and prior on z 165 | self.cost_kld = tf.reduce_mean(-0.5*tf.reduce_sum((1+z_sig_log_sq-tf.square(self.z_mu)-tf.exp(z_sig_log_sq)),1)) #KL divergence 166 | 167 | self.cost = self.cost_seq + self.cost_kld + self.cost_xstart 168 | 169 | with tf.name_scope("train") as scope: 170 | tvars = tf.trainable_variables() 171 | #We clip the gradients to prevent explosion 172 | grads = tf.gradients(self.cost, tvars) 173 | grads, _ = tf.clip_by_global_norm(grads,max_grad_norm) 174 | 175 | #Some decay on the learning rate 176 | global_step = tf.Variable(0,trainable=False) 177 | lr = tf.train.exponential_decay(learning_rate,global_step,1000,0.90,staircase=False) 178 | optimizer = tf.train.AdamOptimizer(lr) 179 | gradients = zip(grads, tvars) 180 | self.train_step = optimizer.apply_gradients(gradients,global_step=global_step) 181 | # The following block plots for every trainable variable 182 | # - Histogram of the entries of the Tensor 183 | # - Histogram of the gradient over the Tensor 184 | # - Histogram of the grradient-norm over the Tensor 185 | self.numel = tf.constant([[0]]) 186 | for gradient, variable in gradients: 187 | if isinstance(gradient, ops.IndexedSlices): 188 | grad_values = gradient.values 189 | else: 190 | grad_values = gradient 191 | 192 | self.numel +=tf.reduce_sum(tf.size(variable)) 193 | 194 | h1 = tf.histogram_summary(variable.name, variable) 195 | h2 = tf.histogram_summary(variable.name + "/gradients", grad_values) 196 | h3 = tf.histogram_summary(variable.name + "/gradient_norm", clip_ops.global_norm([grad_values])) 197 | #Define one op to call all summaries 198 | self.merged = tf.merge_all_summaries() 199 | 200 | -------------------------------------------------------------------------------- /VAE_rec_model_reverse.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 22 10:43:29 2016 4 | 5 | @author: Rob Romijnders 6 | 7 | TODO 8 | - Cross validate over different learning-rates 9 | """ 10 | import sys 11 | sys.path.append('/home/rob/Dropbox/ml_projects/VAE_rec') 12 | import numpy as np 13 | import tensorflow as tf 14 | import matplotlib.pyplot as plt 15 | from tensorflow.python.framework import ops 16 | from tensorflow.python.ops import clip_ops 17 | from VAE_util import * 18 | 19 | 20 | 21 | class Model(): 22 | def __init__(self,config): 23 | """Hyperparameters""" 24 | num_layers = config['num_layers'] 25 | hidden_size = config['hidden_size'] 26 | max_grad_norm = config['max_grad_norm'] 27 | batch_size = config['batch_size'] 28 | sl = config['sl'] 29 | mixtures = config['mixtures'] 30 | crd = config['crd'] 31 | learning_rate = config['learning_rate'] 32 | num_l = config['num_l'] 33 | self.sl = sl 34 | self.crd = crd 35 | self.batch_size = batch_size 36 | 37 | 38 | #Function for initialization 39 | def xv_init(arg_in, arg_out,shape=None): 40 | low = -np.sqrt(6.0/(arg_in + arg_out)) 41 | high = np.sqrt(6.0/(arg_in + arg_out)) 42 | if shape is None: 43 | tensor_shape = (arg_in, arg_out) 44 | return tf.random_uniform(tensor_shape, minval=low, maxval=high, dtype=tf.float32) 45 | 46 | # Nodes for the input variables 47 | self.x = tf.placeholder("float", shape=[batch_size, crd,sl], name = 'Input_data') 48 | x_next = tf.sub(self.x[:,:3,1:], self.x[:,:3,:sl-1]) 49 | xn1,xn2,xn3 = tf.split(1,3,x_next) #Now tensors in [batch_size,1,seq_len-1] 50 | rev_dims = [False, False, True] 51 | xn1 = tf.reverse(xn1,rev_dims) 52 | xn2 = tf.reverse(xn2,rev_dims) 53 | xn3 = tf.reverse(xn3,rev_dims) 54 | 55 | 56 | with tf.variable_scope("Enc") as scope: 57 | cell_enc = tf.nn.rnn_cell.LSTMCell(hidden_size) 58 | cell_enc = tf.nn.rnn_cell.MultiRNNCell([cell_enc] * num_layers) 59 | 60 | #Initial state 61 | initial_state_enc = cell_enc.zero_state(batch_size, tf.float32) 62 | 63 | 64 | outputs_enc = [] 65 | self.states_enc = [] 66 | state = initial_state_enc 67 | for time_step in range(sl): 68 | if time_step > 0: tf.get_variable_scope().reuse_variables() 69 | (cell_output, state) = cell_enc(self.x[:, :, time_step], state) 70 | outputs_enc.append(cell_output) 71 | 72 | with tf.name_scope("Enc_2_lat") as scope: 73 | #m_enc,h_enc = tf.split(1,2,self.final_state) 74 | #layer for mean of z 75 | W_mu = tf.Variable(xv_init(hidden_size,num_l)) 76 | b_mu = tf.Variable(tf.constant(0.1,shape=[num_l],dtype=tf.float32)) 77 | self.z_mu = tf.nn.xw_plus_b(cell_output,W_mu,b_mu) #mu, mean, of latent space 78 | 79 | #layer for sigma of z 80 | W_sig = tf.Variable(xv_init(hidden_size,num_l)) 81 | b_sig = tf.Variable(tf.constant(0.1,shape=[num_l],dtype=tf.float32)) 82 | z_sig_log_sq = tf.nn.xw_plus_b(cell_output,W_sig,b_sig) #sigma of latent space, in log-scale and squared. 83 | # This log_sq will save computation later on. log(sig^2) is a real number, so no sigmoid is necessary 84 | 85 | with tf.name_scope("Latent_space") as scope: 86 | self.eps = tf.random_normal(tf.shape(self.z_mu),0,1,dtype=tf.float32) 87 | self.z = self.z_mu + tf.mul(tf.sqrt(tf.exp(z_sig_log_sq)),self.eps) #Z is the vector in latent space 88 | 89 | with tf.variable_scope("Lat_2_dec") as scope: 90 | #Create initial vector 91 | params_xend = 3 + 4 # 3 (X,Y,Z) plus 4 (sx,sx,sz,rho) 92 | W_xend = tf.Variable(xv_init(num_l,params_xend)) 93 | self.b_xend = tf.Variable(tf.constant(0.1,shape=[params_xend],dtype=tf.float32)) 94 | self.parameters_xend = tf.nn.xw_plus_b(self.z,W_xend,self.b_xend) 95 | 96 | mu1x,mu2x,mu3x,s1x,s2x,s3x,rhox = tf.split(1,params_xend,self.parameters_xend) #Individual vectors in [batch_size,1] 97 | s1x = tf.exp(s1x) 98 | s2x = tf.exp(s2x) 99 | s3x = tf.exp(s3x) 100 | rhox = tf.tanh(rhox) 101 | x_end = tf.concat(1,[mu1x,mu2x,mu3x]) 102 | 103 | 104 | #Reconstruction loss for x_end 105 | xs1,xs2,xs3 = tf.split(1,3,self.x[:,:3,sl-1]) 106 | pxend12 = tf_2d_normal(xs1, xs2, mu1x, mu2x, s1x, s2x, rhox) #probability in x1x2 plane 107 | pxend3 = tf_1d_normal(xs3,mu3x,s3x) 108 | pxend = tf.mul(pxend12,pxend3) 109 | loss_xend = -tf.log(tf.maximum(pxend, 1e-20)) # at the beginning, some errors are exactly zero. 110 | self.cost_xstart = tf.reduce_mean(loss_xend)###tf.constant(0.0)# 111 | #Create initial hidden state and memory state 112 | W_hstart = tf.Variable(xv_init(num_l,hidden_size)) 113 | b_hstart = tf.Variable(tf.constant(0.01,shape=[hidden_size],dtype=tf.float32)) 114 | h_start = tf.nn.xw_plus_b(self.z,W_hstart,b_hstart) 115 | 116 | with tf.variable_scope("Out_layer") as scope: 117 | params = 7 # x,y,z,sx,sy,sz,rho 118 | output_units = mixtures*params #Two for distribution over hit&miss, params for distribution parameters 119 | W_o = tf.Variable(tf.random_normal([hidden_size,output_units], stddev=0.1)) 120 | b_o = tf.Variable(tf.constant(0.5, shape=[output_units])) 121 | 122 | 123 | with tf.variable_scope("Dec") as scope: 124 | cell_dec = tf.nn.rnn_cell.LSTMCell(hidden_size) 125 | cell_dec = tf.nn.rnn_cell.MultiRNNCell([cell_dec] * num_layers) 126 | 127 | #Initial state 128 | initial_state_dec = tf.tile(h_start,[1,2*num_layers]) 129 | PARAMS = [] 130 | self.states = [] 131 | state = initial_state_dec 132 | x_in = x_end 133 | x_collect = [] 134 | x_collect.append(x_in) 135 | for time_step in range(sl): 136 | if time_step > 0: tf.get_variable_scope().reuse_variables() 137 | (cell_output, state) = cell_dec(x_in, state) 138 | self.states.append(state) 139 | #Convert hidden state to offset for the next 140 | params_MDN = tf.nn.xw_plus_b(cell_output,W_o,b_o) # Now in [batch_size,output_units] 141 | PARAMS.append(params_MDN) 142 | x_in = x_in - params_MDN[:,:3] #First three columns are the new x_in 143 | x_collect.append(x_in) 144 | 145 | #Prepare x_collect for extraction 146 | self.x_col = tf.pack(x_collect) #in [seq_len, batch_size,crd] 147 | 148 | 149 | with tf.variable_scope("Loss_calc") as scope: 150 | ### Reconstruction loss 151 | PARAMS = tf.pack(PARAMS[:-1]) 152 | PARAMS = tf.transpose(PARAMS,[1,2,0]) # Now in [batch_size, output_units,seq_len-1] 153 | mu1,mu2,mu3,s1,s2,s3,rho = tf.split(1,7,PARAMS) #Each Tensor in [batch_size,seq_len-1] 154 | s1 = tf.exp(s1) 155 | s2 = tf.exp(s2) 156 | s3 = tf.exp(s3) 157 | rho = tf.tanh(rho) 158 | px1x2 = tf_2d_normal(xn1, xn2, mu1, mu2, s1, s2, rho) #probability in x1x2 plane 159 | px3 = tf_1d_normal(xn3,mu3,s3) 160 | px1x2x3 = tf.mul(px1x2,px3) #Now in [batch_size,1,seq_len-1] 161 | loss_seq = -tf.log(tf.maximum(px1x2x3, 1e-20)) # at the beginning, some errors are exactly zero. 162 | self.cost_seq = tf.reduce_mean(loss_seq) 163 | 164 | ### KL divergence between posterior on encoder and prior on z 165 | self.cost_kld = tf.reduce_mean(-0.5*tf.reduce_sum((1+z_sig_log_sq-tf.square(self.z_mu)-tf.exp(z_sig_log_sq)),1)) #KL divergence 166 | 167 | self.cost = self.cost_seq + self.cost_kld + self.cost_xstart 168 | 169 | with tf.name_scope("train") as scope: 170 | tvars = tf.trainable_variables() 171 | #We clip the gradients to prevent explosion 172 | grads = tf.gradients(self.cost, tvars) 173 | grads, _ = tf.clip_by_global_norm(grads,max_grad_norm) 174 | 175 | #Some decay on the learning rate 176 | global_step = tf.Variable(0,trainable=False) 177 | lr = tf.train.exponential_decay(learning_rate,global_step,1000,0.90,staircase=False) 178 | optimizer = tf.train.AdamOptimizer(lr) 179 | gradients = zip(grads, tvars) 180 | self.train_step = optimizer.apply_gradients(gradients,global_step=global_step) 181 | # The following block plots for every trainable variable 182 | # - Histogram of the entries of the Tensor 183 | # - Histogram of the gradient over the Tensor 184 | # - Histogram of the grradient-norm over the Tensor 185 | self.numel = tf.constant([[0]]) 186 | for gradient, variable in gradients: 187 | if isinstance(gradient, ops.IndexedSlices): 188 | grad_values = gradient.values 189 | else: 190 | grad_values = gradient 191 | 192 | self.numel +=tf.reduce_sum(tf.size(variable)) 193 | 194 | h1 = tf.histogram_summary(variable.name, variable) 195 | h2 = tf.histogram_summary(variable.name + "/gradients", grad_values) 196 | h3 = tf.histogram_summary(variable.name + "/gradient_norm", clip_ops.global_norm([grad_values])) 197 | #Define one op to call all summaries 198 | self.merged = tf.merge_all_summaries() 199 | 200 | -------------------------------------------------------------------------------- /VAE_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jun 9 12:12:57 2016 4 | 5 | @author: rob 6 | """ 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from scipy.stats import multivariate_normal 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | # Extracts form the implementation by https://github.com/hardmaru/write-rnn-tensorflow 15 | 16 | def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho): 17 | """ 2D normal distribution 18 | input 19 | - x,mu: input vectors 20 | - s1,s2: standard deviances over x1 and x2 21 | - rho: correlation coefficient in x1-x2 plane 22 | """ 23 | # eq # 24 and 25 of http://arxiv.org/abs/1308.0850 24 | norm1 = tf.sub(x1, mu1) 25 | norm2 = tf.sub(x2, mu2) 26 | s1s2 = tf.mul(s1, s2) 27 | z = tf.square(tf.div(norm1, s1))+tf.square(tf.div(norm2, s2))-2.0*tf.div(tf.mul(rho, tf.mul(norm1, norm2)), s1s2) 28 | negRho = 1-tf.square(rho) 29 | result = tf.exp(tf.div(-1.0*z,2.0*negRho)) 30 | denom = 2*np.pi*tf.mul(s1s2, tf.sqrt(negRho)) 31 | px1x2 = tf.div(result, denom) 32 | return px1x2 33 | 34 | def tf_1d_normal(x3,mu3,s3): 35 | """ 3D normal distribution Under assumption that x3 is uncorrelated with x1 and x2 36 | input 37 | - x,mu: input vectors 38 | - s1,s2,s3: standard deviances over x1 and x2 and x3 39 | - rho: correlation coefficient in x1-x2 plane 40 | """ 41 | norm3 = tf.sub(x3, mu3) 42 | z = tf.square(tf.div(norm3, s3)) 43 | result = tf.exp(tf.div(-z,2)) 44 | denom = 2.0*np.pi*s3 45 | px3 = tf.div(result, denom) #probability in x3 dimension 46 | return px3 47 | 48 | #def plot_traj_MDN(sess,val_dict,batch,sl_plot = 5, ind = -1): 49 | # """Plots the trajectory. At given time-stamp, it plots the probability distributions 50 | # of where the next point will be 51 | # input: 52 | # - sess: the TF session 53 | # - val_dict: a dictionary with which to evaluate the model 54 | # - batch: the batch X_val[some_indices] that you feed into val_dict. 55 | # we could also pick this from val-dict, but this workflow is cleaner 56 | # - sl_plot: the time-stamp where you'd like to visualize 57 | # - ind: some index into the batch. if -1, we'll pick a random one""" 58 | # try: 59 | # result = sess.run([model.mu1,model.mu2,model.mu3,model.s1,model.s2,model.s3,model.rho],feed_dict=val_dict) 60 | # except: 61 | # print('We cannot fetch all variables for the MDN') 62 | # batch_size,crd,seq_len = batch.shape 63 | # assert ind < batch_size, 'Your index is outside batch' 64 | # assert sl_plot < seq_len, 'Your sequence index is outside sequence' 65 | # if ind == -1: ind = np.random.randint(0,batch_size) 66 | # delta = 0.025 #Grid size to evaluate the PDF 67 | # width = 1.0 # how far to evaluate the pdf? 68 | # 69 | # fig = plt.figure() 70 | # ax = fig.add_subplot(2,2,1,projection='3d') 71 | # ax.plot(batch[ind,0,:], batch[ind,1,:], batch[ind,2,:],'r') 72 | # ax.scatter(batch[ind,0,sl_plot], batch[ind,1,sl_plot], batch[ind,2,sl_plot]) 73 | # ax.set_xlabel('x coordinate') 74 | # ax.set_ylabel('y coordinate') 75 | # ax.set_zlabel('z coordinate') 76 | # 77 | # mean1 = result[0][ind,0,sl_plot] 78 | # mean2 = result[1][ind,0,sl_plot] 79 | # mean3 = result[2][ind,0,sl_plot] 80 | # sigma1 = result[3][ind,0,sl_plot] 81 | # sigma2 = result[4][ind,0,sl_plot] 82 | # sigma3 = result[5][ind,0,sl_plot] 83 | # sigma12 = result[6][ind,0,sl_plot]*sigma1*sigma2 84 | # 85 | # ax = fig.add_subplot(2,2,2) 86 | # 87 | # x1 = np.arange(-width, width, delta) 88 | # x2 = np.arange(-width, width, delta) 89 | # X1, X2 = np.meshgrid(x1, x2) 90 | # Z = mlab.bivariate_normal(X1, X2, sigma1, sigma2, mean1, mean2,sigma12) 91 | # CS = ax.contour(X1, X2, Z) 92 | # plt.clabel(CS, inline=1, fontsize=10) 93 | # ax.set_xlabel('x coordinate') 94 | # ax.set_ylabel('y coordinate') 95 | # 96 | # ax = fig.add_subplot(2,2,3) 97 | # x3 = np.arange(-width, width, delta) 98 | # X1, X3 = np.meshgrid(x1, x3) 99 | # Z = mlab.bivariate_normal(X1, X3, sigma1, sigma3, mean1, mean3) 100 | # CS = ax.contour(X1, X3, Z) 101 | # plt.clabel(CS, inline=1, fontsize=10) 102 | # ax.set_xlabel('x coordinate') 103 | # ax.set_ylabel('Z coordinate') 104 | # 105 | # ax = fig.add_subplot(2,2,4) 106 | # X2, X3 = np.meshgrid(x2, x3) 107 | # Z = mlab.bivariate_normal(X2, X3, sigma2, sigma3, mean2, mean3) 108 | # CS = ax.contour(X2, X3, Z) 109 | # plt.clabel(CS, inline=1, fontsize=10) 110 | # ax.set_xlabel('y coordinate') 111 | # ax.set_ylabel('Z coordinate') 112 | 113 | def plot_traj_MDN_mult(model,sess,val_dict,batch,sl_plot = 5, ind = -1): 114 | """Plots the trajectory. At given time-stamp, it plots the probability distributions 115 | of where the next point will be 116 | THIS IS FOR MULTIPLE MIXTURES 117 | input: 118 | - sess: the TF session 119 | - val_dict: a dictionary with which to evaluate the model 120 | - batch: the batch X_val[some_indices] that you feed into val_dict. 121 | we could also pick this from val-dict, but this workflow is cleaner 122 | - sl_plot: the time-stamp where you'd like to visualize 123 | - ind: some index into the batch. if -1, we'll pick a random one""" 124 | result = sess.run([model.mu1,model.mu2,model.mu3,model.s1,model.s2,model.s3,model.rho,model.theta],feed_dict=val_dict) 125 | batch_size,crd,seq_len = batch.shape 126 | assert ind < batch_size, 'Your index is outside batch' 127 | assert sl_plot < seq_len, 'Your sequence index is outside sequence' 128 | if ind == -1: ind = np.random.randint(0,batch_size) 129 | delta = 0.025 #Grid size to evaluate the PDF 130 | width = 1.0 # how far to evaluate the pdf? 131 | 132 | fig = plt.figure() 133 | ax = fig.add_subplot(2,2,1,projection='3d') 134 | ax.plot(batch[ind,0,:], batch[ind,1,:], batch[ind,2,:],'r') 135 | ax.scatter(batch[ind,0,sl_plot], batch[ind,1,sl_plot], batch[ind,2,sl_plot]) 136 | ax.set_xlabel('x coordinate') 137 | ax.set_ylabel('y coordinate') 138 | ax.set_zlabel('z coordinate') 139 | 140 | 141 | # lower-case x1,x2,x3 are indezing the grid 142 | # upper-case X1,X2,X3 are coordinates in the mesh 143 | x1 = np.arange(-width, width+0.1, delta) 144 | x2 = np.arange(-width, width+0.2, delta) 145 | x3 = np.arange(-width, width+0.3, delta) 146 | X1,X2,X3 = np.meshgrid(x1,x2,x3,indexing='ij') 147 | XX = np.stack((X1,X2,X3),axis=3) 148 | 149 | PP = [] 150 | 151 | mixtures = result[0].shape[1] 152 | for m in range(mixtures): 153 | mean = np.zeros((3)) 154 | mean[0] = result[0][ind,m,sl_plot] 155 | mean[1] = result[1][ind,m,sl_plot] 156 | mean[2] = result[2][ind,m,sl_plot] 157 | cov = np.zeros((3,3)) 158 | sigma1 = result[3][ind,m,sl_plot] 159 | sigma2 = result[4][ind,m,sl_plot] 160 | sigma3 = result[5][ind,m,sl_plot] 161 | sigma12 = result[6][ind,m,sl_plot]*sigma1*sigma2 162 | cov[0,0] = np.square(sigma1) 163 | cov[1,1] = np.square(sigma2) 164 | cov[2,2] = np.square(sigma3) 165 | cov[1,2] = sigma12 166 | cov[2,1] = sigma12 167 | rv = multivariate_normal(mean,cov) 168 | P = rv.pdf(XX) #P is now in [x1,x2,x3] 169 | PP.append(P) 170 | # PP is now a list 171 | PP = np.stack(PP,axis=3) 172 | # PP is now in [x1,x2,x3,mixtures] 173 | #Multiply with the mixture 174 | theta_local = result[7][ind,:,sl_plot] 175 | ZZ = np.dot(PP,theta_local) 176 | #ZZ is now in [x1,x2,x3] 177 | 178 | print('The theta variables %s'%theta_local) 179 | 180 | 181 | #Every Z is a marginalization of ZZ. 182 | # summing over axis 2, gives the pdf over x1,x2 183 | # summing over axis 1, gives the pdf over x1,x3 184 | # summing over axis 0, gives the pdf over x2,x3 185 | ax = fig.add_subplot(2,2,2) 186 | X1, X2 = np.meshgrid(x1, x2) 187 | Z = np.sum(ZZ,axis=2) 188 | CS = ax.contour(X1, X2, Z.T) 189 | plt.clabel(CS, inline=1, fontsize=10) 190 | ax.set_xlabel('x coordinate') 191 | ax.set_ylabel('y coordinate') 192 | 193 | ax = fig.add_subplot(2,2,3) 194 | X1, X3 = np.meshgrid(x1, x3) 195 | Z = np.sum(ZZ,axis=1) 196 | CS = ax.contour(X1, X3, Z.T) 197 | plt.clabel(CS, inline=1, fontsize=10) 198 | ax.set_xlabel('x coordinate') 199 | ax.set_ylabel('Z coordinate') 200 | 201 | ax = fig.add_subplot(2,2,4) 202 | X2, X3 = np.meshgrid(x2, x3) 203 | Z = np.sum(ZZ,axis=0) 204 | CS = ax.contour(X2, X3, Z.T) 205 | plt.clabel(CS, inline=1, fontsize=10) 206 | ax.set_xlabel('y coordinate') 207 | ax.set_ylabel('Z coordinate') 208 | 209 | 210 | 211 | 212 | 213 | 214 | # Piece of code doesn;t work yet 215 | #def tf_3d_normal(x_in, mu_in, s1, s2, s3, rho12, rho13, rho23): 216 | # 217 | # x = tf.sub(x_in,mu_in) 218 | # V11 = tf.pow(s1,2) 219 | # V12 = tf.mul(tf.mul(rho12,s1),s2) 220 | # V13 = tf.mul(tf.mul(rho13,s1),s3) 221 | # V22 = tf.pow(s2,2) 222 | # V23 = tf.mul(tf.mul(rho23,s2),s3) 223 | # V22 = tf.pow(s3,2) 224 | # 225 | # cov = tf.pack([tf.pack([V11,V12.V13]),tf.pack([V12,V22,V23]),tf.pack([V13,V23,V33])]) 226 | # quad = tf.matmul(x,tf.matmul(tf.inv(cov),tf.transpose(x))) 227 | # expo = tf.exp(tf.mul(tf.constant([-0.5]),quad)) 228 | # den = tf.mul(tf.pow((tf.constant([2*3.1415])),tf.constant([3.0/2.0])), 229 | # determinant = tf.mul(V11,tf.mul(V22,V33)) - tf.mul(V11,tf.pow(V23,2)) + tf.mul(V12,tf.sub(tf.mul(tf.mul(tf.constant([2]),V13),V23),tf.mul(V12,V33))) - tf.mul(V22,tf.pow(V13,2)) 230 | # 231 | # 232 | # w = tf.pow(x1,2) 233 | # 234 | ## s1s2 = tf.mul(s1, s2) 235 | ## z = tf.square(tf.div(norm1, s1))+tf.square(tf.div(norm2, s2))-2*tf.div(tf.mul(rho, tf.mul(norm1, norm2)), s1s2) 236 | ## negRho = 1-tf.square(rho) 237 | ## result = tf.exp(tf.div(-z,2*negRho)) 238 | ## denom = 2*np.pi*tf.mul(s1s2, tf.sqrt(negRho)) 239 | ## result = tf.div(result, denom) 240 | # return result -------------------------------------------------------------------------------- /output.txt: -------------------------------------------------------------------------------- 1 | predict x_start and work through 2 | 3 | At 0 / 50000 train ( 5.601, 0.673,29.206) val ( 5.550, 0.639,29.505) 4 | At 100 / 50000 train ( 1.250, 3.564,15.702) val ( 1.178, 3.600,15.483) 5 | At 200 / 50000 train ( 1.239, 2.078,15.037) val ( 1.214, 2.093,15.042) 6 | At 300 / 50000 train ( 1.073, 2.605,13.742) val ( 1.110, 2.669,14.310) 7 | At 400 / 50000 train ( 0.951, 3.222,14.502) val ( 0.681, 3.221,13.842) 8 | At 500 / 50000 train ( 0.656, 3.333,12.708) val ( 0.469, 3.309,12.658) 9 | At 600 / 50000 train ( 0.327, 2.925,12.917) val ( 0.228, 2.971,12.622) 10 | At 700 / 50000 train ( 0.013, 3.049,12.133) val ( 0.411, 3.060,13.853) 11 | At 800 / 50000 train (-0.041, 3.190,11.835) val ( 0.436, 3.189,13.077) 12 | At 900 / 50000 train ( 0.145, 3.080,12.613) val ( 0.009, 3.102,11.706) 13 | At 1000 / 50000 train (-0.194, 3.352,11.734) val ( 0.124, 3.261,11.765) 14 | At 1100 / 50000 train (-0.218, 3.495,11.265) val (-0.111, 3.394,10.957) 15 | At 1200 / 50000 train (-0.239, 2.969,10.631) val ( 0.109, 2.999,11.315) 16 | At 1300 / 50000 train (-0.218, 3.022,10.509) val (-0.306, 3.025,10.324) 17 | At 1400 / 50000 train (-0.504, 3.318,10.138) val ( 0.144, 3.344, 9.937) 18 | At 1500 / 50000 train (-0.295, 3.005, 9.692) val (-0.448, 3.042, 8.931) 19 | At 1600 / 50000 train (-0.385, 3.167, 9.457) val (-0.687, 3.123, 9.352) 20 | At 1700 / 50000 train (-0.145, 3.113, 9.318) val (-0.225, 3.218, 9.009) 21 | At 1800 / 50000 train (-0.437, 2.919, 9.053) val (-0.367, 3.005, 8.947) 22 | At 1900 / 50000 train (-0.683, 2.976, 9.183) val (-0.460, 3.068, 8.806) 23 | At 2000 / 50000 train (-0.640, 3.026, 8.477) val (-0.757, 2.984, 8.477) 24 | At 2100 / 50000 train (-0.801, 2.981, 8.353) val (-0.593, 3.030, 8.838) 25 | At 2200 / 50000 train (-0.389, 2.821, 8.571) val (-0.646, 2.826, 8.847) 26 | At 2300 / 50000 train (-0.732, 2.997, 8.101) val (-0.788, 3.021, 8.631) 27 | At 2400 / 50000 train (-0.540, 2.740, 8.954) val (-0.708, 2.900, 8.590) 28 | At 2500 / 50000 train (-0.581, 2.903, 8.388) val (-0.589, 2.855, 8.244) 29 | At 2600 / 50000 train (-0.630, 2.959, 8.468) val (-0.601, 2.900, 8.740) 30 | At 2700 / 50000 train (-0.393, 2.485, 8.741) val (-0.697, 2.633, 8.056) 31 | At 2800 / 50000 train (-0.768, 2.785, 7.788) val (-0.558, 2.909, 7.822) 32 | At 2900 / 50000 train (-0.749, 2.814, 7.809) val (-0.525, 2.685, 8.448) 33 | At 3000 / 50000 train (-0.499, 2.737, 8.003) val (-0.668, 2.729, 7.688) 34 | At 3100 / 50000 train (-0.721, 2.659, 7.838) val (-0.494, 2.568, 8.715) 35 | At 3200 / 50000 train (-0.580, 2.678, 8.001) val (-0.614, 2.664, 7.930) 36 | At 3300 / 50000 train (-0.712, 2.685, 7.790) val (-0.896, 2.530, 7.863) 37 | At 3400 / 50000 train (-0.532, 2.707, 7.351) val (-0.787, 2.671, 7.533) 38 | At 3500 / 50000 train (-0.527, 2.693, 7.550) val (-0.368, 2.738, 7.997) 39 | At 3600 / 50000 train (-0.719, 2.664, 7.463) val (-0.554, 2.720, 7.557) 40 | At 3700 / 50000 train (-0.734, 2.827, 7.640) val (-1.035, 2.713, 7.441) 41 | At 3800 / 50000 train (-1.123, 2.722, 7.377) val (-0.764, 2.663, 7.647) 42 | At 3900 / 50000 train (-0.327, 2.656, 7.586) val (-1.077, 2.540, 7.499) 43 | At 4000 / 50000 train (-0.624, 2.731, 7.676) val (-0.591, 2.715, 8.011) 44 | At 4100 / 50000 train (-0.574, 2.755, 7.241) val (-0.803, 2.760, 6.967) 45 | At 4200 / 50000 train (-0.664, 2.673, 7.575) val (-0.806, 2.761, 7.069) 46 | At 4300 / 50000 train (-0.806, 2.743, 7.119) val (-0.955, 2.757, 7.205) 47 | At 4400 / 50000 train (-0.707, 2.633, 7.424) val (-0.932, 2.624, 6.983) 48 | At 4500 / 50000 train (-0.609, 2.940, 7.043) val (-1.097, 2.878, 7.088) 49 | At 4600 / 50000 train (-0.658, 2.672, 6.908) val (-0.367, 2.683, 7.302) 50 | At 4700 / 50000 train (-0.863, 2.738, 6.847) val (-0.077, 2.841, 7.163) 51 | 52 | With reversed targets 53 | With num_l = 20 54 | At 0 / 50000 train ( 5.395, 0.864,26.276) val ( 5.375, 0.842,25.953) 55 | At 100 / 50000 train ( 1.369, 3.004,14.831) val ( 1.339, 3.018,14.524) 56 | At 200 / 50000 train ( 1.338, 3.513,10.279) val ( 1.113, 3.549,11.123) 57 | At 300 / 50000 train ( 1.185, 3.347,11.429) val ( 1.008, 3.375,11.426) 58 | At 400 / 50000 train ( 0.523, 4.065,11.051) val ( 0.984, 4.071, 9.127) 59 | At 500 / 50000 train ( 0.296, 3.952, 9.190) val ( 0.259, 4.025,10.005) 60 | At 600 / 50000 train ( 0.303, 3.946, 8.875) val ( 0.420, 3.971,10.041) 61 | At 700 / 50000 train ( 0.014, 3.858, 9.622) val ( 0.377, 3.849, 9.576) 62 | At 800 / 50000 train ( 0.055, 3.807, 8.364) val ( 0.134, 3.843, 7.843) 63 | At 900 / 50000 train (-0.079, 4.210, 7.926) val ( 0.248, 4.139, 9.748) 64 | At 1000 / 50000 train (-0.073, 4.428, 8.718) val (-0.004, 4.440, 7.407) 65 | At 1100 / 50000 train (-0.068, 4.334, 8.134) val (-0.129, 4.403, 7.360) 66 | At 1200 / 50000 train (-0.160, 3.975, 7.621) val (-0.299, 3.936, 8.087) 67 | At 1300 / 50000 train (-0.163, 3.885, 8.589) val (-0.382, 3.901, 7.018) 68 | At 1400 / 50000 train (-0.534, 4.021, 7.504) val (-0.413, 4.056, 6.720) 69 | At 1500 / 50000 train (-0.616, 4.425, 5.844) val (-0.457, 4.466, 5.421) 70 | At 1600 / 50000 train (-0.294, 4.650, 5.926) val (-0.410, 4.669, 5.696) 71 | At 1700 / 50000 train (-0.598, 4.556, 4.621) val (-0.400, 4.595, 5.495) 72 | At 1800 / 50000 train (-0.546, 4.426, 4.815) val (-0.951, 4.394, 5.474) 73 | At 1900 / 50000 train (-0.543, 4.651, 5.049) val (-0.455, 4.740, 5.027) 74 | At 2000 / 50000 train (-0.257, 4.633, 4.240) val (-0.844, 4.597, 4.408) 75 | At 2100 / 50000 train (-0.297, 4.537, 4.035) val (-0.778, 4.566, 4.162) 76 | At 2200 / 50000 train (-0.865, 4.547, 3.713) val (-0.627, 4.513, 4.115) 77 | At 2300 / 50000 train (-0.531, 4.562, 3.845) val (-0.797, 4.561, 3.703) 78 | At 2400 / 50000 train (-0.489, 4.390, 3.700) val (-0.669, 4.368, 3.994) 79 | At 2500 / 50000 train (-0.644, 4.627, 3.624) val (-0.606, 4.631, 3.781) 80 | At 2600 / 50000 train (-0.616, 4.602, 3.446) val (-0.930, 4.619, 4.231) 81 | At 2700 / 50000 train (-0.619, 4.597, 3.511) val (-0.736, 4.598, 4.404) 82 | At 2800 / 50000 train (-0.577, 4.584, 3.455) val (-0.520, 4.736, 3.570) 83 | At 2900 / 50000 train (-0.217, 4.629, 4.382) val (-0.904, 4.697, 3.545) 84 | At 3000 / 50000 train (-0.810, 4.344, 3.676) val (-0.816, 4.249, 3.836) 85 | At 3100 / 50000 train (-0.604, 4.486, 4.014) val (-0.625, 4.495, 3.687) 86 | At 3200 / 50000 train (-0.786, 4.574, 3.238) val (-0.708, 4.539, 3.169) 87 | At 3300 / 50000 train (-1.036, 4.620, 3.055) val (-0.429, 4.545, 3.341) 88 | At 3400 / 50000 train (-0.542, 4.390, 4.292) val (-0.484, 4.456, 3.268) 89 | At 3500 / 50000 train (-0.702, 4.432, 3.490) val (-1.104, 4.505, 3.641) 90 | At 3600 / 50000 train (-0.584, 4.043, 3.740) val (-0.592, 4.281, 3.324) 91 | At 3700 / 50000 train (-0.630, 4.437, 4.340) val (-0.519, 4.455, 3.948) 92 | At 3800 / 50000 train (-0.594, 4.398, 3.202) val (-0.587, 4.368, 3.352) 93 | At 3900 / 50000 train (-0.721, 4.258, 3.565) val (-0.555, 4.268, 3.536) 94 | At 4000 / 50000 train (-0.685, 4.503, 2.856) val (-0.931, 4.446, 3.162) 95 | At 4100 / 50000 train (-0.823, 4.103, 3.666) val (-0.951, 4.142, 2.964) 96 | At 4200 / 50000 train (-1.039, 4.298, 3.584) val (-0.456, 4.211, 3.074) 97 | At 4300 / 50000 train (-0.439, 4.091, 3.741) val (-0.781, 4.061, 3.466) 98 | At 4400 / 50000 train (-0.904, 4.072, 3.106) val (-0.770, 4.065, 2.677) 99 | At 4500 / 50000 train (-0.779, 4.011, 3.074) val (-0.642, 4.170, 3.349) 100 | At 4600 / 50000 train (-0.976, 4.158, 2.787) val (-0.811, 4.255, 2.660) 101 | At 4700 / 50000 train (-0.524, 3.940, 3.577) val (-0.672, 4.016, 3.427) 102 | At 4800 / 50000 train (-0.030, 3.785, 3.362) val (-0.588, 3.862, 3.614) 103 | At 4900 / 50000 train (-0.793, 3.961, 3.114) val (-0.574, 3.933, 3.197) 104 | At 5000 / 50000 train (-0.811, 3.667, 3.517) val (-0.533, 3.839, 3.721) 105 | At 5100 / 50000 train (-0.728, 3.731, 3.179) val (-0.593, 3.626, 3.597) 106 | At 5200 / 50000 train (-0.889, 3.747, 3.230) val (-0.673, 3.676, 3.546) 107 | At 5300 / 50000 train (-0.726, 3.747, 3.348) val (-0.901, 3.764, 3.324) 108 | At 5400 / 50000 train (-0.152, 3.550, 3.284) val (-0.869, 3.586, 3.581) 109 | At 5500 / 50000 train (-0.236, 3.552, 3.456) val (-0.738, 3.475, 3.105) 110 | At 5600 / 50000 train (-0.184, 3.283, 3.426) val (-0.751, 3.295, 2.921) 111 | At 5700 / 50000 train (-0.727, 3.203, 3.292) val (-0.580, 3.224, 3.417) 112 | At 5800 / 50000 train (-0.828, 3.383, 3.161) val (-0.258, 3.387, 3.658) 113 | At 5900 / 50000 train (-0.679, 3.358, 2.810) val (-0.696, 3.165, 3.423) 114 | At 6000 / 50000 train (-0.610, 3.096, 3.452) val (-0.870, 3.226, 2.746) 115 | At 6100 / 50000 train (-0.522, 3.006, 3.296) val (-0.723, 3.100, 3.442) 116 | At 6200 / 50000 train (-0.601, 3.158, 3.181) val (-0.738, 3.182, 2.771) 117 | At 6300 / 50000 train (-0.387, 3.025, 3.051) val (-0.755, 2.960, 3.435) 118 | At 6400 / 50000 train (-0.761, 2.979, 3.176) val (-0.818, 3.089, 3.034) 119 | At 6500 / 50000 train (-0.685, 3.018, 3.271) val (-0.338, 3.025, 3.388) 120 | At 6600 / 50000 train (-0.526, 2.728, 2.635) val (-0.827, 2.694, 3.027) 121 | At 6700 / 50000 train (-0.487, 2.831, 3.382) val (-0.345, 2.871, 3.396) 122 | At 6800 / 50000 train (-0.588, 2.766, 3.262) val (-0.746, 2.664, 2.875) 123 | At 6900 / 50000 train (-0.342, 2.806, 2.780) val (-0.582, 2.863, 2.556) 124 | At 7000 / 50000 train (-0.419, 2.645, 3.313) val (-0.134, 2.561, 3.504) 125 | At 7100 / 50000 train (-0.549, 2.576, 2.940) val (-0.517, 2.615, 2.697) 126 | 127 | 128 | with better init b_xend 129 | At 0 / 50000 train ( 5.549, 1.104,33.353) val ( 5.536, 1.110,36.488) 130 | At 100 / 50000 train ( 0.691, 1.847, 6.703) val ( 0.428, 1.749, 6.688) 131 | At 200 / 50000 train ( 0.023, 1.448, 5.844) val ( 0.561, 1.308, 6.240) 132 | At 300 / 50000 train ( 0.070, 1.514, 5.381) val ( 0.608, 1.659, 5.278) 133 | At 400 / 50000 train (-0.220, 1.692, 4.119) val ( 0.186, 1.777, 5.310) 134 | At 500 / 50000 train (-0.637, 1.869, 3.920) val (-0.197, 1.858, 4.124) 135 | At 600 / 50000 train (-0.205, 2.150, 3.271) val (-0.395, 2.104, 3.504) 136 | At 700 / 50000 train (-0.264, 2.064, 3.497) val (-0.541, 1.992, 3.253) 137 | At 800 / 50000 train (-0.482, 2.179, 3.518) val (-0.182, 1.930, 3.085) 138 | At 900 / 50000 train (-0.502, 2.027, 2.897) val (-0.500, 2.150, 2.974) 139 | At 1000 / 50000 train (-0.399, 2.415, 3.227) val (-0.214, 2.496, 3.250) 140 | At 1100 / 50000 train (-0.507, 2.130, 3.587) val (-0.397, 2.170, 2.748) 141 | At 1200 / 50000 train (-0.728, 2.175, 3.103) val (-0.636, 2.408, 3.102) 142 | At 1300 / 50000 train (-0.732, 2.162, 3.011) val (-0.755, 2.403, 3.174) 143 | At 1400 / 50000 train (-0.497, 2.116, 2.833) val (-0.609, 2.223, 2.806) 144 | At 1500 / 50000 train (-0.344, 2.285, 2.894) val (-0.437, 2.163, 2.793) 145 | 146 | --------------------------------------------------------------------------------