├── GRUI
│   ├── Run_GAN_imputed.py
│   ├── __init__.py
│   ├── gru_delta_forGAN.py
│   ├── gru_impute_lastValue.py
│   ├── gru_impute_zero.py
│   ├── max_auc
│   ├── mygru_cell.py
│   ├── tune_lastValue_imputed.py
│   ├── tune_mean_imputed.py
│   ├── tune_zero_imputed.py
│   └── untitled1.py
├── Gan_Imputation
│   ├── Physionet_main.py
│   ├── WGAN_GRUI.py
│   ├── __init__.py
│   ├── meanAndstd
│   ├── ops.py
│   ├── ops.pyc
│   ├── readMe
│   ├── utils.py
│   └── utils.pyc
├── KDD_dataset
│   ├── Beijing_AirQuality_Stations_en.xlsx
│   ├── beijing_17_18_aq.csv
│   ├── beijing_17_18_meo.csv
│   ├── bj_aq_online.csv
│   ├── holiday_bj.csv
│   ├── holiday_ld.csv
│   ├── ld_aq_online.csv
│   ├── station_beijing.txt
│   ├── station_london.txt
│   └── tmp
│       └── rate.pkl
├── Physionet2012Data
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── calculateMissingRate.py
│   ├── meanAndstd
│   ├── readData.py
│   ├── readData.pyc
│   ├── readTestData.py
│   └── readTestData.pyc
├── Physionet2012ImputedData
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── readImputed.py
│   └── readImputed.pyc
├── README.md
├── requirements.txt
└── set-a
    ├── data_loader.py
    ├── test.py
    ├── test.zip
    └── train.zip

/GRUI/Run_GAN_imputed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Mar 26 10:47:41 2018
5 | 
6 | @author: yonghong
7 | """
8 | 
9 | from __future__ import print_function
10 | import sys
11 | sys.path.append("..")
12 | import argparse
13 | import os
14 | import tensorflow as tf
15 | from Physionet2012ImputedData import readImputed
16 | import gru_delta_forGAN
17 | if __name__ == '__main__':
18 |     parser = argparse.ArgumentParser(description='manual to this script')
19 |     parser.add_argument('--gpus', type=str, default = None)
20 |     parser.add_argument('--batch-size', type=int, default=128)
21 |     parser.add_argument('--run-type', type=str, default='test')
22 |     parser.add_argument('--data-path', type=str, default="../Gan_Imputation/imputation_train_results/WGAN_no_mask/")
23 |     # full path of the imputed training dataset, e.g. Gan_Imputation/imputation_train_results/WGAN_no_mask/30_8_128_64_0.001_400_True_True_True_0.15_0.5
24 |     parser.add_argument('--model-path', type=str, default=None)
25 |     parser.add_argument('--result-path', type=str, default=None)
26 |     parser.add_argument('--lr', type=float, default=0.01)
27 |     parser.add_argument('--epoch', type=int, default=30)
28 |     parser.add_argument('--n-inputs', type=int, default=41)
29 |     parser.add_argument('--n-hidden-units', type=int, default=64)
30 |     parser.add_argument('--n-classes', type=int, default=2)
31 |     parser.add_argument('--checkpoint-dir', type=str, default='checkpoint_physionet_imputed',
32 |                         help='Directory name to save the checkpoints')
33 |     parser.add_argument('--log-dir', type=str, default='logs_physionet_imputed',
34 |                         help='Directory name to save training logs')
35 |     parser.add_argument('--isNormal',type=int,default=1)
36 |     parser.add_argument('--isSlicing',type=int,default=1)
37 |     # 0: false, 1: true
38 |     parser.add_argument('--isBatch-normal',type=int,default=1)
39 |     args = parser.parse_args()
40 | 
41 | 
42 |     if args.isBatch_normal==0:
43 |         args.isBatch_normal=False
44 |     if args.isBatch_normal==1:
45 |         args.isBatch_normal=True
46 |     if args.isNormal==0:
47 |         args.isNormal=False
48 |     if args.isNormal==1:
49 |         args.isNormal=True
50 |     if args.isSlicing==0:
51 |         args.isSlicing=False
52 |     if args.isSlicing==1:
53 |         args.isSlicing=True
54 | 
55 | 
56 |     checkdir=args.checkpoint_dir
57 |     logdir=args.log_dir
58 |     base=args.data_path
59 |     data_paths=["30_8_128_64_0.001_400_True_True_True_0.15_0.5"]
60 |     max_auc = 0.0
61 |     for d in data_paths:
62 |         args.data_path=os.path.join(base,d)
63 |         path_splits=args.data_path.split("/")
64 |         if len(path_splits[-1])==0:
65 |             datasetName=path_splits[-2]
66 |         else:
67 |             datasetName=path_splits[-1]
68 |         args.checkpoint_dir=checkdir+"/"+datasetName
69 |         args.log_dir=logdir+"/"+datasetName
70 | 
71 |         dt_train=readImputed.ReadImputedPhysionetData(args.data_path)
72 |         dt_train.load()
73 | 
74 |         dt_test=readImputed.ReadImputedPhysionetData(args.data_path.replace("imputation_train_results","imputation_test_results"))
75 |         dt_test.load()
76 | 
77 |         lrs=[0.004,0.003,0.005,0.006,0.007,0.008,0.009,0.01,0.012,0.015]
78 |         #lrs = [0.0075,0.0085]
79 |         for lr in lrs:
80 |             args.lr=lr
81 |             epoch= args.epoch
82 |             #epoch=30
83 |             args.epoch=epoch
84 |             print("epoch: %2d"%(epoch))
85 |             tf.reset_default_graph()
86 |             config = tf.ConfigProto()
87 |             config.gpu_options.allow_growth = True
88 |             with tf.Session(config=config) as sess:
89 |                 model = gru_delta_forGAN.grui(sess,
90 |                                               args=args,
91 |                                               dataset=dt_train,
92 |                                               test_set = dt_test
93 |                                               )
94 | 
95 |                 # build graph
96 |                 model.build()
97 | 
98 |                 auc = model.train()
99 |                 if auc > max_auc:
100 |                     max_auc = auc
101 | 
102 |     print("")
103 |     print("max auc is: " + str(max_auc))
104 |     f2 = open("max_auc","w")
105 |     f2.write(str(max_auc))
106 |     f2.close()
107 | 
108 | 
109 | 
--------------------------------------------------------------------------------
/GRUI/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Jan 29 22:55:25 2018
5 | 
6 | @author: lyh
7 | """
8 | 
9 | 
--------------------------------------------------------------------------------
/GRUI/gru_delta_forGAN.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Thu Jan 25 21:52:13 2018
5 | gru for imputed data
6 | @author: lyh
7 | """
8 | 
9 | 
10 | from __future__ import print_function
11 | import os
12 | import numpy as np
13 | from sklearn import metrics
14 | import time
15 | import mygru_cell
16 | import tensorflow as tf
17 | from tensorflow.python.ops import math_ops
18 | tf.set_random_seed(1)   # set random seed
19 | 
20 | class grui(object):
21 |     model_name = "GRU_I"
22 |     def __init__(self, sess, args, dataset, test_set):
23 |         self.lr = args.lr
24 |         self.sess=sess
25 |         self.isbatch_normal=args.isBatch_normal
26 |         self.isNormal=args.isNormal
27 |         self.isSlicing=args.isSlicing
28 |         self.dataset=dataset
29 |         self.test_set = test_set
30 |         self.epoch = args.epoch
31 |         self.batch_size = args.batch_size
32 |         self.n_inputs = args.n_inputs              # number of input features per time step
33 |         self.n_steps = dataset.maxLength           # time steps
34 |         self.n_hidden_units = args.n_hidden_units  # neurons in hidden layer
35 |         self.n_classes = args.n_classes            # number of output classes
36 |         self.run_type=args.run_type
37 |         self.result_path=args.result_path
38 |         self.model_path=args.model_path
39 |         self.log_dir=args.log_dir
40 |         self.checkpoint_dir=args.checkpoint_dir
41 |         self.num_batches = len(dataset.x) // self.batch_size
42 |         # x, y placeholders
43 |         self.keep_prob = tf.placeholder(tf.float32)
44 |         self.x = tf.placeholder(tf.float32, [None, self.n_steps, self.n_inputs])
45 |         self.y = tf.placeholder(tf.float32, [None, self.n_classes])
46 |         self.m = tf.placeholder(tf.float32, [None, self.n_steps, self.n_inputs])
47 |         self.delta = tf.placeholder(tf.float32, [None, self.n_steps, self.n_inputs])
48 |         self.mean = tf.placeholder(tf.float32, [self.n_inputs,])
49 |         self.lastvalues = tf.placeholder(tf.float32, [None, self.n_steps, self.n_inputs])
50 |         self.x_lengths = tf.placeholder(tf.int32, shape=[self.batch_size,])
51 |         # definition of the initial values of weights and biases
52 | 
53 | 
54 |     #concatenate x and m
55 |     #rth should also be concatenated after x, then decays the older state
56 |     #rth's length is n_hidden_units
57 | 
58 | 
59 | 
60 |     def RNN(self,X, M, Delta, Mean, Lastvalues, X_lengths,Keep_prob, reuse=False):
61 |         # 2*3*2
62 |         # X: batches * steps, n_inputs
63 |         # m: batches * steps, n_inputs
64 |         # delta: batches * steps, n_inputs
65 |         # mean: n_inputs, mean of all observations, does not contain the imputations
66 |         # lastvalues: batches * steps, n_inputs, last observed value of x if x is missing;
67 |         #   if lastvalues is zero, take the mean instead
68 | 
69 |         with tf.variable_scope("grui", reuse=reuse):
70 | 
71 |             # then wr_x should be transformed into a diag matrix: tf.matrix_diag(wr_x)
72 |             wr_h=tf.get_variable('wr_h',shape=[self.n_inputs,self.n_hidden_units],initializer=tf.random_normal_initializer())
73 |             w_out=tf.get_variable('w_out', shape=[self.n_hidden_units, self.n_classes],initializer=tf.random_normal_initializer())
74 | 
75 |             br_h=tf.get_variable('br_h', shape=[self.n_hidden_units, ],initializer=tf.constant_initializer(0.001))
76 |             b_out=tf.get_variable('b_out', shape=[self.n_classes, ],initializer=tf.constant_initializer(0.001))
77 | 
78 | 
79 | 
80 |             Lastvalues=tf.reshape(Lastvalues,[-1,self.n_inputs])
81 |             #M=tf.reshape(M,[-1,self.n_inputs])
82 |             X = tf.reshape(X, [-1, self.n_inputs])
83 |             Delta=tf.reshape(Delta,[-1,self.n_inputs])
84 | 
85 | 
86 |             rth= tf.matmul( Delta, wr_h)+br_h
87 |             rth=math_ops.exp(-tf.maximum(0.0,rth))
88 | 
89 |             #X = tf.reshape(X, [-1, n_inputs])
90 |             #print(X.get_shape(),M.get_shape(),rth.get_shape())
91 |             X=tf.concat([X,rth],1)
92 | 
93 |             X_in = tf.reshape(X, [-1, self.n_steps, self.n_inputs+self.n_hidden_units])
94 | 
95 |             #print(X_in.get_shape())
96 |             # X_in = W*X + b
97 |             #X_in = tf.matmul(X, weights['in']) + biases['in']
98 |             # X_in ==> (128 batches, 28 steps, 128 hidden), reshape back to 3-D
99 |             #X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])
100 | 
101 |             if "1.5" in tf.__version__ or "1.7" in tf.__version__ :
102 |                 grud_cell = mygru_cell.MyGRUCell15(self.n_hidden_units)
103 |             elif "1.4" in tf.__version__:
104 |                 grud_cell = mygru_cell.MyGRUCell4(self.n_hidden_units)
105 |             elif "1.2" in tf.__version__:
106 |                 grud_cell = mygru_cell.MyGRUCell2(self.n_hidden_units)
107 |             init_state = grud_cell.zero_state(self.batch_size, dtype=tf.float32)  # initialize an all-zero state
108 |             outputs, final_state = tf.nn.dynamic_rnn(grud_cell, X_in, \
109 |                 initial_state=init_state,\
110 |                 sequence_length=X_lengths,
111 |                 time_major=False)
112 | 
113 |             factor=tf.matrix_diag([1.0/9,1])
114 |             tempout=tf.matmul(tf.nn.dropout(final_state,Keep_prob), w_out) + b_out
115 |             results =tf.nn.softmax(tf.matmul(tempout,factor))  # take the last output
116 |             #todo: dropout of 0.5 and batch normalization
117 |             return results
118 |     def build(self):
119 | 
120 |         self.pred = self.RNN(self.x, self.m, self.delta, self.mean, self.lastvalues, self.x_lengths, self.keep_prob)
121 |         self.cross_entropy = -tf.reduce_sum(self.y*tf.log(self.pred))
122 |         self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cross_entropy)
123 | 
124 | 
125 |         self.correct_pred = tf.equal(tf.argmax(self.pred, 1), tf.argmax(self.y, 1))
126 |         self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
127 |         self.saver = tf.train.Saver(max_to_keep=None)
128 | 
129 |         loss_sum = tf.summary.scalar("loss", self.cross_entropy)
130 |         acc_sum = tf.summary.scalar("acc", self.accuracy)
131 | 
132 |         self.sum=tf.summary.merge([loss_sum, acc_sum])
133 | 
134 | 
135 |     def model_dir(self,epoch):
136 |         return "{}_{}_{}_{}_{}_{}/epoch{}".format(
137 |             self.model_name, self.lr,
138 |             self.batch_size, self.isNormal,
139 |             self.isbatch_normal,self.isSlicing,
140 |             epoch
141 |         )
142 | 
143 |     def save(self, checkpoint_dir, step, epoch):
144 |         checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir(epoch), self.model_name)
145 | 
146 |         if not os.path.exists(checkpoint_dir):
147 |             os.makedirs(checkpoint_dir)
148 | 
149 |         self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)
150 | 
151 |     def load(self, checkpoint_dir, epoch):
152 |         import re
153 |         checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir(epoch), self.model_name)
154 | 
155 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
156 |         if ckpt and ckpt.model_checkpoint_path:
157 |             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
158 |             self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
159 |             counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
160 |             print(" [*] Success to read {}".format(ckpt_name))
161 |             return True, counter
162 |         else:
163 |             #print(" [*] Failed to find a checkpoint")
164 |             return False, 0
165 | 
166 |     def train(self):
167 | 
168 |         max_auc = 0.5
169 |         model_dir2= "{}_{}_{}_{}_{}_{}".format(
170 |             self.model_name, self.lr,
171 |             self.batch_size, self.isNormal,
172 |             self.isbatch_normal,self.isSlicing
173 |         )
174 |         if not os.path.exists(os.path.join(self.checkpoint_dir, model_dir2)):
175 |             os.makedirs(os.path.join(self.checkpoint_dir, model_dir2))
176 |         result_file=open(os.path.join(self.checkpoint_dir, model_dir2, "result"),"a+")
177 | 
178 |         if os.path.exists(os.path.join(self.checkpoint_dir, self.model_dir(self.epoch), self.model_name)):
179 |             for nowepoch in range(1,self.epoch+1):
180 |                 print(" [*] Load SUCCESS")
181 |                 print("epoch: "+str(nowepoch))
182 |                 self.load(self.checkpoint_dir,nowepoch)
183 |                 acc,auc,model_name=self.test(self.test_set,nowepoch)
184 |                 if auc > max_auc :
185 |                     max_auc = auc
186 |                 result_file.write("epoch: "+str(nowepoch)+","+str(acc)+","+str(auc)+"\r\n")
187 |                 print("")
188 |             result_file.close()
189 |             return max_auc
190 |         else:
191 |             # initialize all variables
192 |             tf.global_variables_initializer().run()
193 |             counter = 1
194 |             print(" [!] Load failed...")
195 | 
196 |         start_time=time.time()
197 |         idx = 0
198 |         epochcount=0
199 |         dataset=self.dataset
200 |         while epochcount<self.epoch:
221 |                 if auc > max_auc :
222 |                     max_auc = auc
223 |                 result_file.write("epoch: "+str(epochcount)+","+str(acc)+","+str(auc)+"\r\n")
224 |                 print("")
225 | 
226 |         result_file.close()
227 |         return max_auc
228 | 
229 |     def test(self,dataset, epoch):
230 |         start_time=time.time()
231 |         counter=0
232 |         dataset.shuffle(self.batch_size,False)
233 |         totalacc=0.0
234 |         totalauc=0.0
235 |         auccounter=0
236 |         for data_x,data_y,data_mean,data_m,data_delta,data_x_lengths,data_lastvalues,_,_,_ in dataset.nextBatch():
237 |             summary_str,acc,pred = self.sess.run([self.sum, self.accuracy,self.pred], feed_dict={\
238 |                 self.x: data_x,\
239 |                 self.y: data_y,\
240 |                 self.m: data_m,\
241 |                 self.delta: data_delta,\
242 |                 self.mean: data_mean,\
243 |                 self.x_lengths: data_x_lengths,\
244 |                 self.lastvalues: data_lastvalues,\
245 |                 self.keep_prob: 1.0})
246 | 
247 |             try:
248 |                 auc = metrics.roc_auc_score(np.array(data_y),np.array(pred))
249 |                 totalauc+=auc
250 |                 auccounter+=1
251 |                 print("Batch: %4d time: %4.4f, acc: %.8f, auc: %.8f" \
252 |                     % ( counter, time.time() - start_time, acc, auc))
253 |             except ValueError:
254 |                 print("Batch: %4d time: %4.4f, acc: %.8f " \
255 |                     % ( counter, time.time() - start_time, acc))
256 |                 pass
257 |             totalacc+=acc
258 |             counter += 1
259 |         totalacc=totalacc/counter
260 |         try:
261 |             totalauc=totalauc/auccounter
262 |         except:
263 |             pass
264 |         print("epoch is : %2.2f, Total acc: %.8f, Total auc: %.8f , counter is : %.2f , auccounter is %.2f" % (epoch, totalacc,totalauc,counter,auccounter))
265 |         f=open(os.path.join(self.checkpoint_dir, self.model_dir(epoch), self.model_name,"final_acc_and_auc"),"w")
266 |         f.write(str(totalacc)+","+str(totalauc))
267 |         f.close()
268 |         return totalacc,totalauc,self.model_name
269 | 
--------------------------------------------------------------------------------
/GRUI/gru_impute_lastValue.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Thu Jan 25 21:52:13 2018
5 | 
6 | @author: lyh
7 | """
8 | from __future__ import print_function
9 | import os
10 | import numpy as np
11 | from sklearn import metrics
12 | import time
13 | import mygru_cell
14 | import tensorflow as tf
15 | from tensorflow.python.ops import math_ops
16 | tf.set_random_seed(1)   # set random seed
17 | 
18 | class grud(object):
19 |     model_name = "GRU_ImputeLastValue"
20 |     def __init__(self, sess, args, dataset):
21 |         self.lr = args.lr
22 |         self.sess=sess
23 |         self.isbatch_normal=args.isBatch_normal
24 |         self.isNormal=args.isNormal
25 |         self.isSlicing=args.isSlicing
26 |         self.dataset=dataset
27 |         self.epoch = args.epoch
28 |         self.batch_size = args.batch_size
29 |         self.n_inputs = args.n_inputs              # number of input features per time step
30 |         self.n_steps = dataset.maxLength           # time steps
31 |         self.n_hidden_units = args.n_hidden_units  # neurons in hidden layer
32 |         self.n_classes = args.n_classes            # number of output classes
33 |         self.run_type=args.run_type
34 |         self.result_path=args.result_path
35 |         self.model_path=args.model_path
36 |         self.log_dir=args.log_dir
37 |         self.checkpoint_dir=args.checkpoint_dir
38 |         self.num_batches = len(dataset.x) // self.batch_size
39 |         # x, y placeholders
40 |         self.keep_prob = tf.placeholder(tf.float32)
41 |         self.x = tf.placeholder(tf.float32, [None, self.n_steps, self.n_inputs])
42 |         self.y = tf.placeholder(tf.float32, [None, self.n_classes])
43 |         self.m = tf.placeholder(tf.float32, [None, self.n_steps, self.n_inputs])
44 |         self.delta = tf.placeholder(tf.float32, [None, self.n_steps, self.n_inputs])
45 |         self.mean = tf.placeholder(tf.float32, [self.n_inputs,])
46 |         self.lastvalues = tf.placeholder(tf.float32, [None, self.n_steps, self.n_inputs])
47 |         self.x_lengths = tf.placeholder(tf.int32, shape=[self.batch_size,])
48 |         # definition of the initial values of weights and biases
49 | 
50 | 
51 |     #concatenate x and m
52 |     #rth should also be concatenated after x, then decays the older state
53 |     #rth's length is n_hidden_units
54 | 
55 | 
56 | 
57 |     def RNN(self,X, M, Delta, Mean, Lastvalues, X_lengths,Keep_prob, reuse=False):
58 |         # 2*3*2
59 |         # X: batches * steps, n_inputs
60 |         # m: batches * steps, n_inputs
61 |         # delta: batches * steps, n_inputs
62 |         # mean: n_inputs, mean of all observations, does not contain the imputations
63 |         # lastvalues: batches * steps, n_inputs, last observed value of x if x is missing;
64 |         #   if lastvalues is zero, take the mean instead
65 | 
66 |         with tf.variable_scope("grud", reuse=reuse):
67 | 
68 |             # then wr_x should be transformed into a diag matrix: tf.matrix_diag(wr_x)
69 |             wr_h=tf.get_variable('wr_h',shape=[self.n_inputs,self.n_hidden_units],initializer=tf.random_normal_initializer())
70 |             w_out=tf.get_variable('w_out', shape=[self.n_hidden_units, self.n_classes],initializer=tf.random_normal_initializer())
71 | 
72 |             br_h=tf.get_variable('br_h', shape=[self.n_hidden_units, ],initializer=tf.constant_initializer(0.001))
73 |             b_out=tf.get_variable('b_out', shape=[self.n_classes, ],initializer=tf.constant_initializer(0.001))
74 | 
75 | 
76 | 
77 |             Lastvalues=tf.reshape(Lastvalues,[-1,self.n_inputs])
78 |             M=tf.reshape(M,[-1,self.n_inputs])
79 |             X = tf.reshape(X, [-1, self.n_inputs])
80 |             Delta=tf.reshape(Delta,[-1,self.n_inputs])
81 | 
82 |             X=math_ops.multiply(X,M)+math_ops.multiply((1-M),Lastvalues)
83 | 
84 |             rth= tf.matmul( Delta, wr_h)+br_h
85 |             rth=math_ops.exp(-tf.maximum(0.0,rth))
86 | 
87 |             #X = tf.reshape(X, [-1, n_inputs])
88 |             #print(X.get_shape(),M.get_shape(),rth.get_shape())
89 |             X=tf.concat([X,rth],1)
90 | 
91 |             X_in = tf.reshape(X, [-1, self.n_steps, self.n_inputs+self.n_hidden_units])
92 | 
93 |             #print(X_in.get_shape())
94 |             # X_in = W*X + b
95 |             #X_in = tf.matmul(X, weights['in']) + biases['in']
96 |             # X_in ==> (128 batches, 28 steps, 128 hidden), reshape back to 3-D
97 |             #X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])
98 | 
99 |             if "1.5" in tf.__version__ or "1.7" in tf.__version__ :
100 |                 grud_cell = mygru_cell.MyGRUCell15(self.n_hidden_units)
101 |             elif "1.4" in tf.__version__:
102 |                 grud_cell = mygru_cell.MyGRUCell4(self.n_hidden_units)
103 |             elif "1.2" in tf.__version__:
104 |                 grud_cell = mygru_cell.MyGRUCell2(self.n_hidden_units)
105 |             init_state = grud_cell.zero_state(self.batch_size, dtype=tf.float32)  # initialize an all-zero state
106 |             outputs, final_state = tf.nn.dynamic_rnn(grud_cell, X_in, \
107 |                 initial_state=init_state,\
108 |                 sequence_length=X_lengths,
109 |                 time_major=False)
110 | 
111 |             factor=tf.matrix_diag([1.0/9,1])
112 |             tempout=tf.matmul(tf.nn.dropout(final_state,Keep_prob), w_out) + b_out
113 |             results =tf.nn.softmax(tf.matmul(tempout,factor))  # take the last output
114 |             #todo: dropout of 0.5 and batch normalization
115 |             return results
116 |     def build(self):
117 | 
118 |         self.pred = self.RNN(self.x, self.m, self.delta, self.mean, self.lastvalues, self.x_lengths, self.keep_prob)
119 |         self.cross_entropy = -tf.reduce_sum(self.y*tf.log(self.pred))
120 |         self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cross_entropy)
121 | 
122 | 
123 |         self.correct_pred = tf.equal(tf.argmax(self.pred, 1), tf.argmax(self.y, 1))
124 |         self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
125 |         self.saver = tf.train.Saver()
126 | 
127 |         loss_sum = tf.summary.scalar("loss", self.cross_entropy)
128 |         acc_sum = tf.summary.scalar("acc", self.accuracy)
129 | 
130 |         self.sum=tf.summary.merge([loss_sum, acc_sum])
131 |         self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
132 | 
133 | 
134 |     @property
135 |     def model_dir(self):
136 |         return "{}_{}_{}_{}_{}_{}/epoch{}".format(
137 |             self.model_name, self.lr,
138 |             self.batch_size, self.isNormal,
139 |             self.isbatch_normal,self.isSlicing,
140 |             self.epoch
141 |         )
142 | 
143 |     def save(self, checkpoint_dir, step):
144 |         checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
145 | 
146 |         if not os.path.exists(checkpoint_dir):
147 |             os.makedirs(checkpoint_dir)
148 | 
149 |         self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)
150 | 
151 |     def load(self, checkpoint_dir):
152 |         import re
153 |         checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
154 | 
155 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
156 |         if ckpt and ckpt.model_checkpoint_path:
157 |             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
158 |             self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
159 |             counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
160 |             print(" [*] Success to read {}".format(ckpt_name))
161 |             return True, counter
162 |         else:
163 |             #print(" [*] Failed to find a checkpoint")
164 |             return False, 0
165 | 
166 |     def train(self):
167 | 
168 |         could_load, checkpoint_counter = self.load(self.checkpoint_dir)
169 |         if could_load:
170 |             start_epoch = (int)(checkpoint_counter / self.num_batches)
171 |             #start_batch_id = checkpoint_counter - start_epoch * self.num_batches
172 |             start_batch_id=0
173 |             #counter = checkpoint_counter
174 |             counter=start_epoch*self.num_batches
175 |             print(" [*] Load SUCCESS")
176 |             return
177 |         else:
178 |             # initialize all variables
179 |             tf.global_variables_initializer().run()
180 |             counter = 1
181 |             print(" [!] Load failed...")
182 |         start_time=time.time()
183 |         idx = 0
184 |         epochcount=0
185 |         # X: batches * steps, n_inputs   2*3*2
186 |         # m: batches * steps, n_inputs
187 |         # delta: batches * steps, n_inputs
188 |         # mean: n_inputs, mean of all observations, does not contain the imputations
189 |         # lastvalues: batches * steps, n_inputs, last observed value of x if x is missing;
190 |         #   if lastvalues is zero, take the mean instead
191 |         # assume series1's timestamps: 0, 0.8, 2; series2's timestamps: 0, 1
192 |         #data_x=[[[1,0],[3,2],[2,0]],[[0,2],[1,1],[0,0]]]
193 |         #data_y=[[1,0],[0,1]]
194 |         #data_m=[[[1,0],[1,1],[1,0]],[[0,1],[1,1],[0,0]]]
195 |         #data_delta=[[[0,0],[0.8,0.8],[1.2,1.2]],[[0,0],[1,1],[0,0]]]
196 |         #data_mean=[1.75,1.66667]
197 |         #data_lastvalues=[[[1,1.66667],[3,2],[2,2]],[[1.75,2],[1,1],[0,0]]]
198 |         #data_x_lengths=[3,2]
199 |         dataset=self.dataset
200 |         while epochcount