├── imgs ├── Fig1.png ├── Fig2.png ├── Fig3.jpg ├── Fig4.jpg ├── Fig5.png └── Fig6.jpg ├── temp_data ├── cost.npy ├── cost.png └── npy.txt ├── model ├── final_state.npy ├── model.ckpt.index ├── model.ckpt.meta ├── model.ckpt.data-00000-of-00001 └── checkpoint ├── __pycache__ ├── config.cpython-36.pyc ├── LSTM_model.cpython-36.pyc └── _5predict_by_model.cpython-36.pyc ├── .idea ├── vcs.xml ├── misc.xml ├── modules.xml ├── AnomalyDetection.iml └── workspace.xml ├── config.py ├── Untitled Diagram.drawio ├── _5draw_train_loss.py ├── _2npy_filter.py ├── _3generate_data_csv.py ├── _3draw_the_first_month.py ├── _1in_out_merge.py ├── _6predict_the_5th_week.py ├── _5predict_by_model.py ├── README.md ├── _4train_model.py └── LSTM_model.py /imgs/Fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/imgs/Fig1.png -------------------------------------------------------------------------------- /imgs/Fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/imgs/Fig2.png -------------------------------------------------------------------------------- /imgs/Fig3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/imgs/Fig3.jpg -------------------------------------------------------------------------------- /imgs/Fig4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/imgs/Fig4.jpg -------------------------------------------------------------------------------- /imgs/Fig5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/imgs/Fig5.png 
-------------------------------------------------------------------------------- /imgs/Fig6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/imgs/Fig6.jpg -------------------------------------------------------------------------------- /temp_data/cost.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/temp_data/cost.npy -------------------------------------------------------------------------------- /temp_data/cost.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/temp_data/cost.png -------------------------------------------------------------------------------- /model/final_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/model/final_state.npy -------------------------------------------------------------------------------- /model/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/model/model.ckpt.index -------------------------------------------------------------------------------- /model/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/model/model.ckpt.meta -------------------------------------------------------------------------------- /__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/__pycache__/config.cpython-36.pyc 
-------------------------------------------------------------------------------- /__pycache__/LSTM_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/__pycache__/LSTM_model.cpython-36.pyc -------------------------------------------------------------------------------- /model/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/model/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /__pycache__/_5predict_by_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSQianDong/AnomalyDetection/HEAD/__pycache__/_5predict_by_model.cpython-36.pyc -------------------------------------------------------------------------------- /model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "D:\\work\\python_project\\AnomalyDetection\\model\\model.ckpt" 2 | all_model_checkpoint_paths: "D:\\work\\python_project\\AnomalyDetection\\model\\model.ckpt" 3 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/16 9:18 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | BATCH_START = 0 # 建立 batch data 时候的 index 8 | TIME_STEPS = 168 # 时序数据的size 9 | BATCH_SIZE = 256 10 | INPUT_SIZE = 1 11 | OUTPUT_SIZE = 1 12 | CELL_SIZE = 64 # RNN 的 hidden unit size 13 | LR = 1e-5 # learning rate -------------------------------------------------------------------------------- /Untitled Diagram.drawio: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /.idea/AnomalyDetection.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /_5draw_train_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/16 11:15 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def main(): 12 | costs = np.load('cost.npy') 13 | plt.plot(costs) 14 | plt.xlabel("step") 15 | plt.ylabel("mean squre error") 16 | plt.title("loss-step graph") 17 | plt.show() 18 | costs = costs.reshape(10, -1) 19 | f, axs = plt.subplots(1, 1, figsize=(15, 5)) 20 | axs.boxplot([cost for cost in costs]) 21 | plt.ylabel("mean square error") 22 | plt.xlabel("epoch") 23 | plt.title("loss-epoch box graph") 24 | plt.show() 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /_2npy_filter.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/15 20:04 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | import os 8 | import numpy as np 9 | 10 | 11 | def main(): 12 | root_path = '../data/in_out/' 13 | save_list = [] 14 | for root, dirs, files in os.walk(root_path): 15 | for npy in files: 16 | npy_ = os.path.join(root, npy) 17 | npy_ = np.load(npy_) 18 | pos_rate = 1 - sum(npy_ == 0) / len(npy_) # 非0值的比例 19 | if pos_rate > 0.667: 20 | print(os.path.join(root, npy)) 21 | save_list.append(os.path.join(root, npy)) 22 | with open('npy.txt', 'w')as f: 23 | for name in save_list: 24 | f.write(name + "\n") 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /_3generate_data_csv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/16 11:35 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | import numpy as np 8 | import pandas as pd 9 | 10 | 11 | def main(): 12 | f = open('./temp_data/npy.txt') 13 | lines = f.readlines() 14 | lines = [line.strip() for line in lines] 15 | f.close() 16 | data = np.array([[0] * 169] * 15063) 17 | cnt = 0 18 | for idx, line in enumerate(lines): 19 | npy = np.load(line) 20 | offset = idx % (24 * 7) 21 | for i in range((2088 - offset - 1) // (24 * 7)): 22 | data[cnt] = (npy[i * 24 * 7 + offset:(i + 1) * 24 * 7 + offset + 1]) 23 | cnt += 1 24 | data = pd.DataFrame(data) 25 | data.to_csv('./temp_data/data.csv', index=False) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /_3draw_the_first_month.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/16 12:02 4 | 
@author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | import matplotlib 8 | matplotlib.use("Agg") 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import os 13 | 14 | 15 | def main(): 16 | f = open("./temp_data/npy.txt") 17 | lines = f.readlines() 18 | lines = [line.strip() for line in lines] 19 | for line in lines: 20 | print(line) 21 | f, axs = plt.subplots(2, 1, figsize=(15, 5)) 22 | axs[0].plot(np.load(line)[0:24 * 7]) 23 | axs[1].plot(np.load(line)[0:24 * 7]) 24 | axs[1].plot(np.load(line)[24 * 7 * 1:24 * 7 * 2]) 25 | axs[1].plot(np.load(line)[24 * 7 * 2:24 * 7 * 3]) 26 | axs[1].plot(np.load(line)[24 * 7 * 3:24 * 7 * 4]) 27 | plt.xlabel("hour") 28 | plt.ylabel("number of in/out") 29 | if not os.path.exists("../data/first_month_figures"): 30 | os.makedirs("../data/first_month_figures") 31 | plt.savefig("../data/first_month_figures/%d.png" % (int(os.path.basename(line).split('.')[0]))) 32 | plt.show() 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /_1in_out_merge.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/15 20:04 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | import os 8 | import numpy as np 9 | 10 | 11 | def main(): 12 | if not os.path.exists("./in"): 13 | os.makedirs("./in") 14 | if not os.path.exists("./out"): 15 | os.makedirs("./out") 16 | gridID_in_hourly = {} 17 | gridID_out_hourly = {} 18 | for i in range(104000): 19 | gridID_in_hourly[i] = [0] * 2088 20 | gridID_out_hourly[i] = [0] * 2088 21 | root_path = "/datahouse/tripflow/ano_detect/200/bj-byhour-io/" 22 | for io_file in os.listdir(root_path): 23 | hour = int(io_file.split('-')[-1]) 24 | io_file = os.path.join(root_path, io_file) 25 | print(io_file) 26 | with open(io_file) as f: 27 | lines = f.readlines() 28 | lines = [line.strip().split(',') for line in 
lines] 29 | lines = np.array(lines) 30 | lines = lines.astype(int) 31 | for line in lines: 32 | gridID_out_hourly[line[0]][hour] = line[1] 33 | gridID_in_hourly[line[0]][hour] = line[2] 34 | for ID in gridID_in_hourly: 35 | np.save("./in/%d.npy" % ID, gridID_in_hourly[ID]) 36 | np.save("./out/%d.npy" % ID, gridID_out_hourly[ID]) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /_6predict_the_5th_week.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/16 12:21 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | from _5predict_by_model import predict_by_model 8 | import matplotlib 9 | matplotlib.use("Agg") 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import os 13 | 14 | 15 | def predict_5th_week(npy): 16 | test = np.load(npy) 17 | predicts = [] 18 | for i in range(24 * 7): 19 | print(i, npy) 20 | seq = test[24 * 7 * 3 + i:24 * 7 * 4 + i] 21 | predicts.append(predict_by_model(seq)) 22 | return predicts 23 | 24 | 25 | def draw_predict(predicts, npy): 26 | test = np.load(npy) 27 | f, axs = plt.subplots(2, 1, figsize=(15, 5)) 28 | axs[0].plot(test[24 * 7 * 4:24 * 7 * 5] - 3, 'r') 29 | axs[0].plot(test[24 * 7 * 4:24 * 7 * 5], 'b') 30 | axs[0].plot(test[24 * 7 * 4:24 * 7 * 5] + 3, 'r') 31 | axs[0].plot(predicts, 'g') 32 | axs[1].plot(test[24 * 7 * 4:24 * 7 * 5], 'b') 33 | axs[1].plot(predicts, 'g') 34 | plt.xlabel("hour") 35 | plt.ylabel("number of in/out") 36 | plt.savefig("../data/predict_result/%d.png"%(int(os.path.basename(npy).split('.')[0]))) 37 | plt.show() 38 | 39 | 40 | def main(): 41 | f = open("./temp_data/npy.txt") 42 | lines = f.readlines() 43 | npys = [line.strip() for line in lines] 44 | for npy in npys[0:10]: 45 | predict = predict_5th_week(npy) 46 | draw_predict(predict, npy) 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | 
-------------------------------------------------------------------------------- /_5predict_by_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/16 9:14 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | import numpy as np 8 | import pandas as pd 9 | import tensorflow as tf 10 | from config import * 11 | 12 | Data = pd.read_csv('./temp_data/data.csv') 13 | Indices = np.arange(Data.shape[0]) # 随机打乱索引并切分训练集与测试集 14 | np.random.shuffle(Indices) 15 | Data = np.array(Data) 16 | 17 | 18 | def get_batch(): 19 | global BATCH_START, BATCH_SIZE 20 | # xs shape (50batch, 20steps) 21 | data = Data[Indices[BATCH_START:BATCH_START + BATCH_SIZE]] 22 | seq = np.array([d[0:TIME_STEPS].tolist() for d in data]) 23 | res = np.array([d[-1] for d in data]).reshape(-1, 1) 24 | BATCH_START += BATCH_SIZE 25 | return [seq[:, :, np.newaxis], res] 26 | 27 | 28 | def get_data(num): 29 | data = Data[num] 30 | seq = np.array(data[0:TIME_STEPS]) 31 | res = data[-1] 32 | return [seq.reshape(-1, 1), res] 33 | 34 | 35 | def predict_by_model(seq): # seq shape为 [TIME_STEPS, 1] [24*7, 1] 36 | seq = np.array(seq) 37 | # reshape为[BATCH_SIZE, 24*7, 1] 38 | seq = np.array([seq.tolist() for _ in range(BATCH_SIZE)]).reshape(BATCH_SIZE, 24*7, 1) 39 | with tf.Session() as sess: 40 | new_saver = tf.train.import_meta_graph('./model/model.ckpt.meta') 41 | new_saver.restore(sess, tf.train.latest_checkpoint('./model')) 42 | graph = tf.get_default_graph() 43 | predict = graph.get_tensor_by_name("out_hidden/Wx_plus_b/Maximum:0") 44 | xs = graph.get_tensor_by_name("inputs/xs:0") 45 | cell_init_state = graph.get_tensor_by_name("LSTM_cell/initial_state/BasicLSTMCellZeroState/zeros:0") 46 | state = np.load("./model/final_state.npy") 47 | pred = sess.run(predict, feed_dict={xs: seq, cell_init_state: state}) 48 | return pred[0][0] 49 | 50 | 51 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AnomalyDetection 2 | ## LSTM对时间序列的数据做预测,通过偏离值的大小达到异常检测的目的 3 | ### 1.思路 4 |   用一个网格7x24个小时的数据预测该网格第7x24+1个小时的in out,如果这个预测in out与真实值的差距大于一个阈值(该阈值最终取3,见Fig3),就把这小时对应的网格当做一个异常的网格,可以达到实时检测的效果 5 |
6 | 7 | ####
Fig1.某网格7*24小时(一周)的时序图
8 | 9 |
10 | 11 | ####
Fig2.某网格相邻四周的时序图
12 | 13 |
14 | 15 | ####
Fig3_1.训练loss与epoch的箱型图
16 | 17 |
18 | 19 | ####
Fig3_2.最后一个epoch的loss箱型图
20 | 21 | ####
(异常的阈值取3,因为最后一个epoch的均方误差(MSE)loss基本小于9,即预测值与真实值的均方根误差基本小于3)
22 | 23 |
24 | 25 | ####
Fig4.训练loss与每个step的曲线图
26 | 27 |
28 | 29 | ####
Fig5.某网格第五周的预测图
30 | 31 | ####
Fig5中红线是真实值±3,代表异常的阈值曲线,绿线是预测值曲线,蓝线是真实值的曲线,绿线超出红线即视为异常
32 | 33 | ### 2.流程 34 | 35 | >1)过滤掉稀疏的网格,一个网格用一个向量存储2088个小时的in/out,如果2088个值里边有超过三分之一的值为0,则过滤掉,不用该网格的数据进行模型训练。原因:一天24小时,可以有三分之一的时间比如0a.m.-8a.m. 没有in/out,其他时间段有in/out,比较符合该数据合理的直觉,共1317个符合条件的向量,原始数据有104000*2个向量 36 | 37 | >2)最终训练数据生成:将符合条件的向量用一个24*7的窗口划分 38 | 39 | >>A.为了减少模型过拟合,窗口之间不重合 40 | 41 | >>B.为了增加模型的适用性,划分窗口时加一个offset,使得模型从一周的随便一个小时起开始预测都可以, 42 | 最终有大约1317*(2088/(24*7))=16368个训练样本 43 | 44 | 45 | [in out and some temp images zip file](https://pan.baidu.com/s/1xtjGd2vXjHrTTpBo_UntvQ) 46 | -------------------------------------------------------------------------------- /_4train_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/15 17:33 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import tensorflow as tf 11 | from LSTM_model import LSTMRNN 12 | import matplotlib.pyplot as plt 13 | from config import * 14 | 15 | Data = pd.read_csv('./temp_data/data.csv') 16 | Indices = np.arange(Data.shape[0]) 17 | np.random.shuffle(Indices) 18 | Data = np.array(Data) 19 | 20 | 21 | def get_batch(): 22 | global BATCH_START, BATCH_SIZE 23 | data = Data[Indices[BATCH_START:BATCH_START + BATCH_SIZE]] 24 | seq = np.array([d[0:TIME_STEPS].tolist() for d in data]) 25 | res = np.array([d[-1] for d in data]).reshape(-1, 1) 26 | BATCH_START += BATCH_SIZE 27 | return [seq[:, :, np.newaxis], res] 28 | 29 | 30 | if __name__ == '__main__': 31 | with tf.Session() as sess: 32 | model = LSTMRNN(TIME_STEPS, INPUT_SIZE, OUTPUT_SIZE, CELL_SIZE, BATCH_SIZE) 33 | sess.run(tf.global_variables_initializer()) 34 | saver = tf.train.Saver() 35 | costs = [] 36 | begin = True 37 | # 训练 10 次 38 | for epoch in range(100): 39 | BATCH_START = 0 40 | np.random.shuffle(Indices) 41 | for i in range(len(Indices) // BATCH_SIZE): 42 | seq, res = get_batch() # 提取 batch data 43 | if begin: 44 | begin = False 45 | # 初始化 data 46 | feed_dict = 
{ 47 | model.xs: seq, 48 | model.ys: res, 49 | } 50 | else: 51 | feed_dict = { 52 | model.xs: seq, 53 | model.ys: res, 54 | model.cell_init_state: state # 保持 state 的连续性 55 | } 56 | # 训练 57 | _, cost, state, pred = sess.run( 58 | [model.train_op, model.cost, model.cell_final_state, model.pred], 59 | feed_dict=feed_dict) 60 | print('epoch:%d step:%d cost: ' % (epoch, i), round(cost, 4)) 61 | costs.append(cost) 62 | # print(pred) 63 | np.save("./model/final_state.npy", state) 64 | np.save("./temp_data/cost.npy", costs) 65 | plt.plot(costs) 66 | plt.show() 67 | plt.savefig("./temp_data/cost.png") 68 | if not os.path.exists("model"): 69 | os.makedirs("model") 70 | saver.save(sess=sess, save_path=os.path.join(os.getcwd(), 'model', 'model.ckpt')) 71 | -------------------------------------------------------------------------------- /LSTM_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/11/16 9:58 4 | @author: Eric 5 | @email: qian.dong.2018@gmail.com 6 | """ 7 | import tensorflow as tf 8 | from config import * 9 | class LSTMRNN(object): 10 | def __init__(self, n_steps, input_size, output_size, cell_size, batch_size): 11 | self.n_steps = n_steps 12 | self.input_size = input_size 13 | self.output_size = output_size 14 | self.cell_size = cell_size 15 | self.batch_size = batch_size 16 | with tf.name_scope('inputs'): 17 | self.xs = tf.placeholder(tf.float32, [None, n_steps, input_size], name='xs') 18 | self.ys = tf.placeholder(tf.float32, [None, output_size], name='ys') 19 | print("xs, ys:", self.xs,self.ys) 20 | with tf.variable_scope('in_hidden'): 21 | self.add_input_layer() 22 | with tf.variable_scope('LSTM_cell'): 23 | self.add_cell() 24 | with tf.variable_scope('out_hidden'): 25 | self.add_output_layer() 26 | with tf.name_scope('cost'): 27 | self.compute_cost() 28 | with tf.name_scope('train'): 29 | self.train_op = tf.train.AdamOptimizer(LR).minimize(self.cost) 30 | 
print("train_op:", self.train_op) 31 | 32 | def add_input_layer(self, ): 33 | l_in_x = tf.reshape(self.xs, [-1, self.input_size], name='2_2D') # (batch*n_step, in_size) 34 | # Ws (in_size, cell_size) 35 | Ws_in = self._weight_variable([self.input_size, self.cell_size]) 36 | # bs (cell_size, ) 37 | bs_in = self._bias_variable([self.cell_size, ]) 38 | # l_in_y = (batch * n_steps, cell_size) 39 | with tf.name_scope('Wx_plus_b'): 40 | l_in_y = tf.matmul(l_in_x, Ws_in) + bs_in 41 | # reshape l_in_y ==> (batch, n_steps, cell_size) 42 | self.l_in_y = tf.reshape(l_in_y, [-1, self.n_steps, self.cell_size], name='2_3D') 43 | 44 | def add_cell(self): 45 | lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.cell_size, forget_bias=0.3, state_is_tuple=False) 46 | with tf.name_scope('initial_state'): 47 | self.cell_init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32) 48 | print("cell_init_state:", self.cell_init_state) 49 | self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn( 50 | lstm_cell, self.l_in_y, initial_state=self.cell_init_state, time_major=False) 51 | print("cell_final_state:",self.cell_final_state) 52 | 53 | def add_output_layer(self): 54 | # shape = (batch * steps, cell_size) 55 | print("cell_outputs:", self.cell_outputs) # 50 20 10 56 | l_out_x = tf.reshape(self.cell_outputs, [self.batch_size, -1], name='2_2D') 57 | Ws_out = self._weight_variable([self.n_steps*self.cell_size, self.output_size]) 58 | bs_out = self._bias_variable([self.output_size, ]) 59 | # shape = (batch * steps, output_size) 60 | with tf.name_scope('Wx_plus_b'): 61 | self.pred = tf.matmul(l_out_x, Ws_out) + bs_out 62 | self.predict = tf.maximum(tf.rint(tf.matmul(l_out_x, Ws_out) + bs_out), tf.zeros([self.batch_size,1], tf.float32)) 63 | print("predict:",self.predict) 64 | 65 | def compute_cost(self): 66 | print("pred:", self.pred) 67 | losses = self.ms_error(self.pred, self.ys) 68 | with tf.name_scope('average_cost'): 69 | self.cost = tf.div( 70 | tf.reduce_sum(losses, 
name='losses_sum'), 71 | self.batch_size, 72 | name='average_cost') 73 | print("cost:", self.cost) 74 | # tf.summary.scalar('cost', self.cost) 75 | 76 | def ms_error(self, labels, logits): 77 | return tf.square(tf.subtract(labels, logits)) 78 | 79 | def _weight_variable(self, shape, name='weights'): 80 | initializer = tf.random_normal_initializer(mean=0., stddev=1., ) 81 | return tf.get_variable(shape=shape, initializer=initializer, name=name) 82 | 83 | def _bias_variable(self, shape, name='biases'): 84 | initializer = tf.constant_initializer(0.1) 85 | return tf.get_variable(name=name, shape=shape, initializer=initializer) -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 114 | 115 | 116 | 117 | train_op 118 | cost 119 | cell_init_state 120 | cell_final_state 121 | print 122 | cell_fin 123 | init 124 | zero 125 | pred 126 | LSTMRNN 127 | 150 128 | 129 | 130 | 300 131 | 132 | 133 | 134 | 136 | 137 | 162 | 163 | 164 | 165 | 166 | true 167 | DEFINITION_ORDER 168 | 169 | 170 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 |