├── CAMB-algorithms.ipynb
├── ElasticNet+SARIMA.ipynb
├── README.md
├── SVM+SARIMA.ipynb
├── classic-machine-learning-algorithms.ipynb
├── data.xlsx
├── neural-network-algorithms.ipynb
└── time-series-analysis-algorithms.ipynb
/CAMB-algorithms.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "import pandas as pd\n",
12 | "from sklearn.metrics import mean_squared_error\n",
13 | "from sklearn.metrics import mean_absolute_error\n",
14 | "import statsmodels.api as sm\n",
15 | "%matplotlib inline\n",
16 | "plt.style.use('fivethirtyeight')"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [
24 | {
25 | "data": {
26 | "text/html": [
27 | "
\n",
28 | "\n",
41 | "
\n",
42 | " \n",
43 | " \n",
44 | " | \n",
45 | " 年 | \n",
46 | " 月 | \n",
47 | " ASPL | \n",
48 | " BME | \n",
49 | " BASPL | \n",
50 | " BLD | \n",
51 | " BLW | \n",
52 | " BJ | \n",
53 | " JN | \n",
54 | " LBL | \n",
55 | " SN | \n",
56 | " TJ | \n",
57 | " ST | \n",
58 | " SX | \n",
59 | " XMD | \n",
60 | " YXX | \n",
61 | "
\n",
62 | " \n",
63 | " \n",
64 | " \n",
65 | " 55 | \n",
66 | " 2019 | \n",
67 | " 8 | \n",
68 | " 3.763248e+06 | \n",
69 | " 62854.619304 | \n",
70 | " 8.599843e+07 | \n",
71 | " 1.650015e+08 | \n",
72 | " 3.481067e+08 | \n",
73 | " 1.060927e+06 | \n",
74 | " 2.359482e+06 | \n",
75 | " 13219.096853 | \n",
76 | " 6.456187e+05 | \n",
77 | " 1.758874e+08 | \n",
78 | " 2.961155e+07 | \n",
79 | " 1.949292e+07 | \n",
80 | " 56728.894684 | \n",
81 | " 107392.856432 | \n",
82 | "
\n",
83 | " \n",
84 | " 56 | \n",
85 | " 2019 | \n",
86 | " 9 | \n",
87 | " 4.294952e+06 | \n",
88 | " 51559.590738 | \n",
89 | " 1.052790e+08 | \n",
90 | " 1.811433e+08 | \n",
91 | " 4.621942e+08 | \n",
92 | " 1.135383e+06 | \n",
93 | " 2.416740e+06 | \n",
94 | " 224849.311140 | \n",
95 | " 1.233509e+06 | \n",
96 | " 2.395380e+08 | \n",
97 | " 3.489288e+07 | \n",
98 | " 2.116995e+07 | \n",
99 | " 26868.589363 | \n",
100 | " 227671.344996 | \n",
101 | "
\n",
102 | " \n",
103 | " 57 | \n",
104 | " 2019 | \n",
105 | " 10 | \n",
106 | " 2.835714e+06 | \n",
107 | " 31933.698695 | \n",
108 | " 8.023482e+07 | \n",
109 | " 1.674704e+08 | \n",
110 | " 3.022591e+08 | \n",
111 | " 1.023727e+06 | \n",
112 | " 2.074879e+06 | \n",
113 | " 32356.283055 | \n",
114 | " 4.381713e+05 | \n",
115 | " 1.561968e+08 | \n",
116 | " 2.178650e+07 | \n",
117 | " 1.655320e+07 | \n",
118 | " 32469.492085 | \n",
119 | " 134291.930366 | \n",
120 | "
\n",
121 | " \n",
122 | " 58 | \n",
123 | " 2019 | \n",
124 | " 11 | \n",
125 | " 3.188248e+06 | \n",
126 | " 42248.757526 | \n",
127 | " 9.053573e+07 | \n",
128 | " 1.624495e+08 | \n",
129 | " 2.918764e+08 | \n",
130 | " 1.297557e+06 | \n",
131 | " 2.163609e+06 | \n",
132 | " 136809.701451 | \n",
133 | " 8.436577e+05 | \n",
134 | " 1.604672e+08 | \n",
135 | " 2.843713e+07 | \n",
136 | " 1.799922e+07 | \n",
137 | " 30197.836641 | \n",
138 | " 201416.711500 | \n",
139 | "
\n",
140 | " \n",
141 | " 59 | \n",
142 | " 2019 | \n",
143 | " 12 | \n",
144 | " 4.437618e+06 | \n",
145 | " 67250.974962 | \n",
146 | " 1.034565e+08 | \n",
147 | " 1.545455e+08 | \n",
148 | " 2.355498e+08 | \n",
149 | " 1.336709e+06 | \n",
150 | " 2.369512e+06 | \n",
151 | " 246160.303870 | \n",
152 | " 6.326834e+05 | \n",
153 | " 1.424411e+08 | \n",
154 | " 2.360613e+07 | \n",
155 | " 1.968310e+07 | \n",
156 | " 20921.125742 | \n",
157 | " 157548.075168 | \n",
158 | "
\n",
159 | " \n",
160 | "
\n",
161 | "
"
162 | ],
163 | "text/plain": [
164 | " 年 月 ASPL BME BASPL BLD \\\n",
165 | "55 2019 8 3.763248e+06 62854.619304 8.599843e+07 1.650015e+08 \n",
166 | "56 2019 9 4.294952e+06 51559.590738 1.052790e+08 1.811433e+08 \n",
167 | "57 2019 10 2.835714e+06 31933.698695 8.023482e+07 1.674704e+08 \n",
168 | "58 2019 11 3.188248e+06 42248.757526 9.053573e+07 1.624495e+08 \n",
169 | "59 2019 12 4.437618e+06 67250.974962 1.034565e+08 1.545455e+08 \n",
170 | "\n",
171 | " BLW BJ JN LBL SN \\\n",
172 | "55 3.481067e+08 1.060927e+06 2.359482e+06 13219.096853 6.456187e+05 \n",
173 | "56 4.621942e+08 1.135383e+06 2.416740e+06 224849.311140 1.233509e+06 \n",
174 | "57 3.022591e+08 1.023727e+06 2.074879e+06 32356.283055 4.381713e+05 \n",
175 | "58 2.918764e+08 1.297557e+06 2.163609e+06 136809.701451 8.436577e+05 \n",
176 | "59 2.355498e+08 1.336709e+06 2.369512e+06 246160.303870 6.326834e+05 \n",
177 | "\n",
178 | " TJ ST SX XMD YXX \n",
179 | "55 1.758874e+08 2.961155e+07 1.949292e+07 56728.894684 107392.856432 \n",
180 | "56 2.395380e+08 3.489288e+07 2.116995e+07 26868.589363 227671.344996 \n",
181 | "57 1.561968e+08 2.178650e+07 1.655320e+07 32469.492085 134291.930366 \n",
182 | "58 1.604672e+08 2.843713e+07 1.799922e+07 30197.836641 201416.711500 \n",
183 | "59 1.424411e+08 2.360613e+07 1.968310e+07 20921.125742 157548.075168 "
184 | ]
185 | },
186 | "execution_count": 2,
187 | "metadata": {},
188 | "output_type": "execute_result"
189 | }
190 | ],
191 | "source": [
192 | "# Load the raw table (columns: year 年, month 月, then one column per series) \n",
193 | "data=pd.read_excel('data.xlsx') \n",
194 | "data=data.iloc[:,:] # NOTE(review): selects every row and column, i.e. effectively a no-op copy\n",
195 | "choose='BASPL'# name of the series (column) to forecast\n",
196 | "data.tail()"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 3,
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "data": {
206 | "text/plain": [
207 | "Text(0.5, 1.0, 'BASPL')"
208 | ]
209 | },
210 | "execution_count": 3,
211 | "metadata": {},
212 | "output_type": "execute_result"
213 | },
214 | {
215 | "data": {
216 | "image/png": "\n",
217 | "text/plain": [
218 | ""
219 | ]
220 | },
221 | "metadata": {
222 | "needs_background": "light"
223 | },
224 | "output_type": "display_data"
225 | }
226 | ],
227 | "source": [
228 | "#展示一下看看\n",
229 | "data=data[choose]\n",
230 | "plt.figure(figsize=(12,6))\n",
231 | "plt.plot(data)\n",
232 | "#plt.grid()\n",
233 | "plt.title(choose,fontsize='15') #添加标题"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": 4,
239 | "metadata": {},
240 | "outputs": [
241 | {
242 | "data": {
243 | "text/html": [
244 | "\n",
245 | "\n",
258 | "
\n",
259 | " \n",
260 | " \n",
261 | " | \n",
262 | " 0 | \n",
263 | " 1 | \n",
264 | " 2 | \n",
265 | " 3 | \n",
266 | " 4 | \n",
267 | " 5 | \n",
268 | " 6 | \n",
269 | " 7 | \n",
270 | " 8 | \n",
271 | " 9 | \n",
272 | " 10 | \n",
273 | " 11 | \n",
274 | " 12 | \n",
275 | "
\n",
276 | " \n",
277 | " \n",
278 | " \n",
279 | " 43 | \n",
280 | " 8.246982 | \n",
281 | " 9.620531 | \n",
282 | " 6.602126 | \n",
283 | " 8.713054 | \n",
284 | " 8.821096 | \n",
285 | " 11.614596 | \n",
286 | " 6.550721 | \n",
287 | " 9.362264 | \n",
288 | " 9.868245 | \n",
289 | " 8.736202 | \n",
290 | " 9.008205 | \n",
291 | " 9.361410 | \n",
292 | " 8.599843 | \n",
293 | "
\n",
294 | " \n",
295 | " 44 | \n",
296 | " 9.620531 | \n",
297 | " 6.602126 | \n",
298 | " 8.713054 | \n",
299 | " 8.821096 | \n",
300 | " 11.614596 | \n",
301 | " 6.550721 | \n",
302 | " 9.362264 | \n",
303 | " 9.868245 | \n",
304 | " 8.736202 | \n",
305 | " 9.008205 | \n",
306 | " 9.361410 | \n",
307 | " 8.599843 | \n",
308 | " 10.527895 | \n",
309 | "
\n",
310 | " \n",
311 | " 45 | \n",
312 | " 6.602126 | \n",
313 | " 8.713054 | \n",
314 | " 8.821096 | \n",
315 | " 11.614596 | \n",
316 | " 6.550721 | \n",
317 | " 9.362264 | \n",
318 | " 9.868245 | \n",
319 | " 8.736202 | \n",
320 | " 9.008205 | \n",
321 | " 9.361410 | \n",
322 | " 8.599843 | \n",
323 | " 10.527895 | \n",
324 | " 8.023482 | \n",
325 | "
\n",
326 | " \n",
327 | " 46 | \n",
328 | " 8.713054 | \n",
329 | " 8.821096 | \n",
330 | " 11.614596 | \n",
331 | " 6.550721 | \n",
332 | " 9.362264 | \n",
333 | " 9.868245 | \n",
334 | " 8.736202 | \n",
335 | " 9.008205 | \n",
336 | " 9.361410 | \n",
337 | " 8.599843 | \n",
338 | " 10.527895 | \n",
339 | " 8.023482 | \n",
340 | " 9.053573 | \n",
341 | "
\n",
342 | " \n",
343 | " 47 | \n",
344 | " 8.821096 | \n",
345 | " 11.614596 | \n",
346 | " 6.550721 | \n",
347 | " 9.362264 | \n",
348 | " 9.868245 | \n",
349 | " 8.736202 | \n",
350 | " 9.008205 | \n",
351 | " 9.361410 | \n",
352 | " 8.599843 | \n",
353 | " 10.527895 | \n",
354 | " 8.023482 | \n",
355 | " 9.053573 | \n",
356 | " 10.345651 | \n",
357 | "
\n",
358 | " \n",
359 | "
\n",
360 | "
"
361 | ],
362 | "text/plain": [
363 | " 0 1 2 3 4 5 6 \\\n",
364 | "43 8.246982 9.620531 6.602126 8.713054 8.821096 11.614596 6.550721 \n",
365 | "44 9.620531 6.602126 8.713054 8.821096 11.614596 6.550721 9.362264 \n",
366 | "45 6.602126 8.713054 8.821096 11.614596 6.550721 9.362264 9.868245 \n",
367 | "46 8.713054 8.821096 11.614596 6.550721 9.362264 9.868245 8.736202 \n",
368 | "47 8.821096 11.614596 6.550721 9.362264 9.868245 8.736202 9.008205 \n",
369 | "\n",
370 | " 7 8 9 10 11 12 \n",
371 | "43 9.362264 9.868245 8.736202 9.008205 9.361410 8.599843 \n",
372 | "44 9.868245 8.736202 9.008205 9.361410 8.599843 10.527895 \n",
373 | "45 8.736202 9.008205 9.361410 8.599843 10.527895 8.023482 \n",
374 | "46 9.008205 9.361410 8.599843 10.527895 8.023482 9.053573 \n",
375 | "47 9.361410 8.599843 10.527895 8.023482 9.053573 10.345651 "
376 | ]
377 | },
378 | "execution_count": 4,
379 | "metadata": {},
380 | "output_type": "execute_result"
381 | }
382 | ],
383 | "source": [
384 | "#构造针对机器学习模型的数据集\n",
385 | "window=12#时间窗为12\n",
386 | "data=data.values \n",
387 | "dataset=data\n",
388 | "for i in range(window):\n",
389 | " zero=np.zeros(i+1)\n",
390 | " temp=np.append(data[i+1:],zero)\n",
391 | " dataset=np.row_stack((dataset,temp))\n",
392 | "dataset=pd.DataFrame(dataset).T\n",
393 | "dataset=dataset.iloc[:-window]\n",
394 | "dataset=dataset/10000000#进行伪归一化\n",
395 | "dataset.tail()#构造好的数据集如下"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 5,
401 | "metadata": {},
402 | "outputs": [
403 | {
404 | "name": "stdout",
405 | "output_type": "stream",
406 | "text": [
407 | "(48, 12)\n",
408 | "(48,)\n"
409 | ]
410 | }
411 | ],
412 | "source": [
413 | "#划分特征与标签\n",
414 | "x=dataset.iloc[:,:-1]\n",
415 | "y=dataset.iloc[:,-1]\n",
416 | "print(x.shape)\n",
417 | "print(y.shape)"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": 6,
423 | "metadata": {},
424 | "outputs": [
425 | {
426 | "name": "stdout",
427 | "output_type": "stream",
428 | "text": [
429 | "504\n",
430 | "72\n",
431 | "42\n",
432 | "6\n"
433 | ]
434 | }
435 | ],
436 | "source": [
437 | "#构造训练集测试集\n",
438 | "cut=6#取最后cut天为测试集\n",
439 | "X_train, X_test=x.iloc[:-cut],x.iloc[-cut:]#列表的切片操作\n",
440 | "y_train, y_test=y.iloc[:-cut],y.iloc[-cut:]\n",
441 | "X_train,X_test,y_train,y_test=X_train.values,X_test.values,y_train.values,y_test.values\n",
442 | "print(X_train.size)#通过输出训练集测试集的大小来判断数据格式正确。\n",
443 | "print(X_test.size)\n",
444 | "print(y_train.size)\n",
445 | "print(y_test.size)"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": 7,
451 | "metadata": {},
452 | "outputs": [],
453 | "source": [
454 | "#以下函数均为对模型进行检验所用到的 精简代码的作用\n",
455 | "def mape(y_true, y_pred):\n",
456 | " return np.mean(np.abs((y_pred - y_true) / y_true)) * 100\n",
457 | "def up_down_accuracy(y_true, y_pred):\n",
458 | " y_var_test=y_true[1:]-y_true[:len(y_true)-1]#实际涨跌\n",
459 | " y_var_predict=y_pred[1:]-y_pred[:len(y_pred)-1]#原始涨跌\n",
460 | " txt=np.zeros(len(y_var_test))\n",
461 | " for i in range(len(y_var_test-1)):#计算数量\n",
462 | " txt[i]=np.sign(y_var_test[i])==np.sign(y_var_predict[i])\n",
463 | " result=sum(txt)/len(txt)\n",
464 | " return result\n",
465 | "def output():\n",
466 | " #展示在测试集上的表现 \n",
467 | " draw=pd.concat([pd.DataFrame(y_test),pd.DataFrame(y_test_predict)],axis=1);\n",
468 | " draw.iloc[:,0].plot(figsize=(12,6))\n",
469 | " draw.iloc[:,1].plot(figsize=(12,6))\n",
470 | " plt.legend(('real', 'predict'),loc='upper right',fontsize='15')\n",
471 | " plt.title(\"Test Data\",fontsize='30') #添加标题\n",
472 | " plt.show()\n",
473 | " #输出结果\n",
474 | " print('测试集上的MAE/MSE/MAPE/涨跌准确率')\n",
475 | " print(mean_absolute_error(y_test_predict, y_test))\n",
476 | " print(mean_squared_error(y_test_predict, y_test) )\n",
477 | " print(mape(y_test_predict, y_test) )\n",
478 | " print(up_down_accuracy(y_test_predict,y_test))"
479 | ]
480 | },
481 | {
482 | "cell_type": "markdown",
483 | "metadata": {},
484 | "source": [
485 | "# SVM+Elastic Net+Random Forest"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 8,
491 | "metadata": {},
492 | "outputs": [],
493 | "source": [
494 | "#支持向量机\n",
495 | "from sklearn.svm import LinearSVR \n",
496 | "def svm_model():\n",
497 | " svr = LinearSVR() \n",
498 | " model = svr.fit(X_train, y_train)\n",
499 | " #在训练集上的拟合结果\n",
500 | " y_train_predict=model.predict(X_train)\n",
501 | " y_test_predict=model.predict(X_test)\n",
502 | " return [y_train_predict,y_test_predict]"
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "execution_count": 9,
508 | "metadata": {},
509 | "outputs": [],
510 | "source": [
511 | "#弹性网回归\n",
512 | "from sklearn.linear_model import ElasticNet\n",
513 | "def elasticnet_model():\n",
514 | " elasticnet=ElasticNet(alpha=.1, l1_ratio=.9, random_state=3)\n",
515 | " model = elasticnet.fit(X_train, y_train)\n",
516 | " #在训练集上的拟合结果\n",
517 | " y_train_predict=model.predict(X_train)\n",
518 | " y_test_predict=model.predict(X_test)\n",
519 | " return [y_train_predict,y_test_predict]"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": 10,
525 | "metadata": {},
526 | "outputs": [],
527 | "source": [
528 | "# 建立随机森林模型 预测\n",
529 | "from sklearn.ensemble import RandomForestRegressor\n",
530 | "def rf_model():\n",
531 | " rf=RandomForestRegressor()\n",
532 | " model = rf.fit(X_train, y_train) \n",
533 | " #在训练集上的拟合结果\n",
534 | " y_train_predict=model.predict(X_train)\n",
535 | " y_test_predict=model.predict(X_test)\n",
536 | " return [y_train_predict,y_test_predict]"
537 | ]
538 | },
539 | {
540 | "cell_type": "code",
541 | "execution_count": 11,
542 | "metadata": {},
543 | "outputs": [
544 | {
545 | "name": "stderr",
546 | "output_type": "stream",
547 | "text": [
548 | "E:\\anoconda\\lib\\site-packages\\sklearn\\svm\\base.py:929: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
549 | " \"the number of iterations.\", ConvergenceWarning)\n",
550 | "E:\\anoconda\\lib\\site-packages\\sklearn\\ensemble\\forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.\n",
551 | " \"10 in version 0.20 to 100 in 0.22.\", FutureWarning)\n"
552 | ]
553 | }
554 | ],
555 | "source": [
556 | "y_train_predict1,y_test_predict1=svm_model()\n",
557 | "y_train_predict2,y_test_predict2=elasticnet_model()\n",
558 | "y_train_predict3,y_test_predict3=rf_model()"
559 | ]
560 | },
561 | {
562 | "cell_type": "markdown",
563 | "metadata": {},
564 | "source": [
565 | "# SARIMA"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": 12,
571 | "metadata": {},
572 | "outputs": [
573 | {
574 | "name": "stderr",
575 | "output_type": "stream",
576 | "text": [
577 | "E:\\anoconda\\lib\\site-packages\\ipykernel_launcher.py:9: FutureWarning: Creating a DatetimeIndex by passing range endpoints is deprecated. Use `pandas.date_range` instead.\n",
578 | " if __name__ == '__main__':\n"
579 | ]
580 | }
581 | ],
582 | "source": [
583 | "#sarima\n",
584 | "#展示一下看看\n",
585 | "data=pd.read_excel('data.xlsx') \n",
586 | "data=data[choose]\n",
587 | "data=data/10000000#进行伪归一化\n",
588 | "#构建训练集测试集\n",
589 | "cut=6#取最后cut天为测试集\n",
590 | "data=pd.Series(data.values,\\\n",
591 | " index=pd.DatetimeIndex(start='2015-01-01',end='2019-12-01',freq='MS'))\n",
592 | "y=data[:-cut]\n",
593 | "true=data[-cut:]"
594 | ]
595 | },
596 | {
597 | "cell_type": "code",
598 | "execution_count": 13,
599 | "metadata": {},
600 | "outputs": [],
601 | "source": [
602 | "#SARIMA模型 \n",
603 | "def sarima_model():\n",
604 | " mod = sm.tsa.statespace.SARIMAX(y,\n",
605 | " order=(1, 1, 1),\n",
606 | " seasonal_order=(1, 1, 1, 12),\n",
607 | " enforce_stationarity=False,\n",
608 | " enforce_invertibility=False)\n",
609 | " results = mod.fit()\n",
610 | " #在训练集上进行预测\n",
611 | " pred = results.get_prediction(start=pd.to_datetime('2015-01-01'), dynamic=False)\n",
612 | " y_train_predict = pred.predicted_mean.values\n",
613 | " y_train = y.values\n",
614 | " y_train_predict = y_train_predict[1:]\n",
615 | " y_train_predict = y_train_predict[-42:]\n",
616 | " y_train = y_train[1:]\n",
617 | " #在测试集上进行预测\n",
618 | " pred_uc = results.get_forecast(steps=cut) # retun out-of-sample forecast \n",
619 | " y_test_predict=pred_uc.predicted_mean\n",
620 | " y_test_predict=y_test_predict.values\n",
621 | " y_test=true.values\n",
622 | " return [y_train_predict,y_test_predict]"
623 | ]
624 | },
625 | {
626 | "cell_type": "code",
627 | "execution_count": 14,
628 | "metadata": {},
629 | "outputs": [],
630 | "source": [
631 | "#构建神经网络数据集\n",
632 | "y_train_predict4,y_test_predict4=sarima_model()\n",
633 | "x_train_coef=np.array([y_train_predict1,y_train_predict2,y_train_predict3,y_train_predict4])\n",
634 | "x_train_coef=x_train_coef.T\n",
635 | "y_train_coef=y_train\n",
636 | "x_test_coef=np.array([y_test_predict1,y_test_predict2,y_test_predict3,y_test_predict4])\n",
637 | "x_test_coef=x_test_coef.T"
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": 15,
643 | "metadata": {},
644 | "outputs": [],
645 | "source": [
646 | "from sklearn.neural_network import MLPRegressor\n",
647 | "# alpha:L2的参数:MLP是可以支持正则化的,默认为L2,具体参数需要调整\n",
648 | "# hidden_layer_sizes=(5, 2) hidden层2层,第一层5个神经元,第二层2个神经元),2层隐藏层,也就有3层神经网络\n",
649 | "#clf = MLPRegressor(solver='lbfgs', alpha=1e-5,hidden_layer_sizes=(5, 2), random_state=1)\n",
650 | "#'identity',无操作激活,对实现线性瓶颈很有用,返回f(x)= x\n",
651 | "#'logistic',logistic sigmoid函数,返回f(x)= 1 /(1 + exp(-x))。\n",
652 | "#'tanh',双曲tan函数,返回f(x)= tanh(x)。\n",
653 | "#'relu',整流后的线性单位函数,返回f(x)= max(0,x)\n",
654 | "NN= MLPRegressor(\n",
655 | " hidden_layer_sizes=(12,12,1), activation='relu', solver='adam', alpha=0.0001, batch_size='auto',\n",
656 | " learning_rate='constant', learning_rate_init=0.001, power_t=0.5, max_iter=5000, shuffle=True,\n",
657 | " random_state=1, tol=0.0001, verbose=False, warm_start=False, momentum=0.9, nesterovs_momentum=True,\n",
658 | " early_stopping=False,beta_1=0.9, beta_2=0.999, epsilon=1e-08)\n",
659 | "model = NN.fit(x_train_coef, y_train_coef)\n",
660 | "y_test_predict=model.predict(x_test_coef)"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": 16,
666 | "metadata": {},
667 | "outputs": [
668 | {
669 | "data": {
670 | "image/png": "\n",
671 | "text/plain": [
672 | ""
673 | ]
674 | },
675 | "metadata": {
676 | "needs_background": "light"
677 | },
678 | "output_type": "display_data"
679 | },
680 | {
681 | "name": "stdout",
682 | "output_type": "stream",
683 | "text": [
684 | "测试集上的MAE/MSE/MAPE/涨跌准确率\n",
685 | "0.8584564079619827\n",
686 | "0.8972802516676929\n",
687 | "8.378834394349372\n",
688 | "1.0\n"
689 | ]
690 | }
691 | ],
692 | "source": [
693 | "#最终组合的输出结果\n",
694 | "output()"
695 | ]
696 | },
697 | {
698 | "cell_type": "code",
699 | "execution_count": null,
700 | "metadata": {},
701 | "outputs": [],
702 | "source": []
703 | },
704 | {
705 | "cell_type": "code",
706 | "execution_count": null,
707 | "metadata": {},
708 | "outputs": [],
709 | "source": []
710 | }
711 | ],
712 | "metadata": {
713 | "anaconda-cloud": {},
714 | "kernelspec": {
715 | "display_name": "Python 3",
716 | "language": "python",
717 | "name": "python3"
718 | },
719 | "language_info": {
720 | "codemirror_mode": {
721 | "name": "ipython",
722 | "version": 3
723 | },
724 | "file_extension": ".py",
725 | "mimetype": "text/x-python",
726 | "name": "python",
727 | "nbconvert_exporter": "python",
728 | "pygments_lexer": "ipython3",
729 | "version": "3.7.3"
730 | }
731 | },
732 | "nbformat": 4,
733 | "nbformat_minor": 1
734 | }
735 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # regression-prediction-algorithms
2 | 使用支持向量机、弹性网络、随机森林、LSTM、SARIMA等多种算法进行时间序列的回归预测,除此以外还采取了多种组合方法对以上算法输出的结果进行组合预测。Support vector machines, elastic net, random forest, LSTM, SARIMA and other algorithms are used for regression-based time-series forecasting. In addition, several combination (ensemble) methods are used to combine the outputs of these algorithms into a single forecast.
3 |
--------------------------------------------------------------------------------
/data.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/momodagithub/regression-prediction-algorithms/430519b7b3f8d11732a288711c5dade571cf3a64/data.xlsx
--------------------------------------------------------------------------------