├── test.csv ├── test.py ├── README.md ├── lstm_airline_predict.py └── international-airline-passengers.csv /test.csv: -------------------------------------------------------------------------------- 1 | 1,2 2 | 2,3 3 | 3,4 4 | 5,6 5 | 7,8 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import matplotlib.pyplot as plt 3 | import matplotlib.dates as mdates 4 | import numpy as np 5 | import pandas as pd 6 | 7 | df = pd.read_csv('test.csv', sep=',') 8 | print(df.head(5)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSTM_learn 2 | a implement of LSTM using Keras for time series prediction regression problem 3 | ### Data 4 | the data were from internet, this data was using for predict the number of people in a airline company, we use LSTM network to solve this problem 5 | ### Denpensies 6 | for the implemention of code, we using Keras to establish LSTM network, as well as using numpy, pandas, so before you runing this tutorial, it is strongly recommended you install Anaconda which is a package inclueded them all. 7 | ### open source protocol 8 | MIT 9 | ### contact 10 | author: jinfagang19@163.com 11 | Central South University, Mr. Jin 12 | -------------------------------------------------------------------------------- /lstm_airline_predict.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from sklearn.preprocessing import MinMaxScaler 5 | from keras.models import Sequential 6 | from keras.layers import LSTM, Dense, Activation 7 | 8 | 9 | def load_data(file_name, sequence_length=10, split=0.8): 10 | df = pd.read_csv(file_name, sep=',', usecols=[1]) 11 | data_all = np.array(df).astype(float) 12 | scaler = MinMaxScaler() 13 | data_all = scaler.fit_transform(data_all) 14 | data = [] 15 | for i in range(len(data_all) - sequence_length - 1): 16 | data.append(data_all[i: i + sequence_length + 1]) 17 | reshaped_data = np.array(data).astype('float64') 18 | np.random.shuffle(reshaped_data) 19 | # 对x进行统一归一化,而y则不归一化 20 | x = reshaped_data[:, :-1] 21 | y = reshaped_data[:, -1] 22 | split_boundary = int(reshaped_data.shape[0] * split) 23 | train_x = x[: split_boundary] 24 | test_x = x[split_boundary:] 25 | 26 | train_y = y[: split_boundary] 27 | test_y = y[split_boundary:] 28 | 29 | return train_x, train_y, test_x, test_y, scaler 30 | 31 | 32 | def build_model(): 33 | # input_dim是输入的train_x的最后一个维度,train_x的维度为(n_samples, time_steps, input_dim) 34 | model = Sequential() 35 | model.add(LSTM(input_dim=1, output_dim=50, return_sequences=True)) 36 | print(model.layers) 37 | model.add(LSTM(100, return_sequences=False)) 38 | model.add(Dense(output_dim=1)) 39 | model.add(Activation('linear')) 40 | 41 | model.compile(loss='mse', optimizer='rmsprop') 42 | return model 43 | 44 | 45 | def train_model(train_x, train_y, test_x, test_y): 46 | model = build_model() 47 | 48 | try: 49 | model.fit(train_x, train_y, batch_size=512, nb_epoch=30, validation_split=0.1) 50 | predict = model.predict(test_x) 51 | predict = np.reshape(predict, (predict.size, )) 52 | except KeyboardInterrupt: 53 | print(predict) 54 | print(test_y) 55 | print(predict) 56 | print(test_y) 57 | try: 58 | fig = plt.figure(1) 59 | plt.plot(predict, 'r:') 60 | plt.plot(test_y, 'g-') 61 | plt.legend(['predict', 'true']) 62 | except Exception as e: 63 | print(e) 64 | return predict, test_y 65 | 66 | 67 | if __name__ == '__main__': 68 | train_x, train_y, test_x, test_y, scaler = load_data('international-airline-passengers.csv') 69 | train_x = np.reshape(train_x, (train_x.shape[0], train_x.shape[1], 1)) 70 | test_x = np.reshape(test_x, (test_x.shape[0], test_x.shape[1], 1)) 71 | predict_y, test_y = train_model(train_x, train_y, test_x, test_y) 72 | predict_y = scaler.inverse_transform([[i] for i in predict_y]) 73 | test_y = scaler.inverse_transform(test_y) 74 | fig2 = plt.figure(2) 75 | plt.plot(predict_y, 'g:') 76 | plt.plot(test_y, 'r-') 77 | plt.show() 78 | 79 | -------------------------------------------------------------------------------- /international-airline-passengers.csv: -------------------------------------------------------------------------------- 1 | time,passengers 2 | "1949-01",112 3 | "1949-02",118 4 | "1949-03",132 5 | "1949-04",129 6 | "1949-05",121 7 | "1949-06",135 8 | "1949-07",148 9 | "1949-08",148 10 | "1949-09",136 11 | "1949-10",119 12 | "1949-11",104 13 | "1949-12",118 14 | "1950-01",115 15 | "1950-02",126 16 | "1950-03",141 17 | "1950-04",135 18 | "1950-05",125 19 | "1950-06",149 20 | "1950-07",170 21 | "1950-08",170 22 | "1950-09",158 23 | "1950-10",133 24 | "1950-11",114 25 | "1950-12",140 26 | "1951-01",145 27 | "1951-02",150 28 | "1951-03",178 29 | "1951-04",163 30 | "1951-05",172 31 | "1951-06",178 32 | "1951-07",199 33 | "1951-08",199 34 | "1951-09",184 35 | "1951-10",162 36 | "1951-11",146 37 | "1951-12",166 38 | "1952-01",171 39 | "1952-02",180 40 | "1952-03",193 41 | "1952-04",181 42 | "1952-05",183 43 | "1952-06",218 44 | "1952-07",230 45 | "1952-08",242 46 | "1952-09",209 47 | "1952-10",191 48 | "1952-11",172 49 | "1952-12",194 50 | "1953-01",196 51 | "1953-02",196 52 | "1953-03",236 53 | "1953-04",235 54 | "1953-05",229 55 | "1953-06",243 56 | "1953-07",264 57 | "1953-08",272 58 | "1953-09",237 59 | "1953-10",211 60 | "1953-11",180 61 | "1953-12",201 62 | "1954-01",204 63 | "1954-02",188 64 | "1954-03",235 65 | "1954-04",227 66 | "1954-05",234 67 | "1954-06",264 68 | "1954-07",302 69 | "1954-08",293 70 | "1954-09",259 71 | "1954-10",229 72 | "1954-11",203 73 | "1954-12",229 74 | "1955-01",242 75 | "1955-02",233 76 | "1955-03",267 77 | "1955-04",269 78 | "1955-05",270 79 | "1955-06",315 80 | "1955-07",364 81 | "1955-08",347 82 | "1955-09",312 83 | "1955-10",274 84 | "1955-11",237 85 | "1955-12",278 86 | "1956-01",284 87 | "1956-02",277 88 | "1956-03",317 89 | "1956-04",313 90 | "1956-05",318 91 | "1956-06",374 92 | "1956-07",413 93 | "1956-08",405 94 | "1956-09",355 95 | "1956-10",306 96 | "1956-11",271 97 | "1956-12",306 98 | "1957-01",315 99 | "1957-02",301 100 | "1957-03",356 101 | "1957-04",348 102 | "1957-05",355 103 | "1957-06",422 104 | "1957-07",465 105 | "1957-08",467 106 | "1957-09",404 107 | "1957-10",347 108 | "1957-11",305 109 | "1957-12",336 110 | "1958-01",340 111 | "1958-02",318 112 | "1958-03",362 113 | "1958-04",348 114 | "1958-05",363 115 | "1958-06",435 116 | "1958-07",491 117 | "1958-08",505 118 | "1958-09",404 119 | "1958-10",359 120 | "1958-11",310 121 | "1958-12",337 122 | "1959-01",360 123 | "1959-02",342 124 | "1959-03",406 125 | "1959-04",396 126 | "1959-05",420 127 | "1959-06",472 128 | "1959-07",548 129 | "1959-08",559 130 | "1959-09",463 131 | "1959-10",407 132 | "1959-11",362 133 | "1959-12",405 134 | "1960-01",417 135 | "1960-02",391 136 | "1960-03",419 137 | "1960-04",461 138 | "1960-05",472 139 | "1960-06",535 140 | "1960-07",622 141 | "1960-08",606 142 | "1960-09",508 143 | "1960-10",461 144 | "1960-11",390 145 | "1960-12",432 --------------------------------------------------------------------------------