├── B0006.csv ├── B18_test.csv ├── B5_test1.csv ├── B5_test2.csv ├── B6_test.csv ├── B6_test1.csv ├── B6_test2.csv ├── B7_test.csv ├── B7_test1.csv ├── B7_test2.csv ├── B18_test1.csv ├── README.md ├── B5_test.csv ├── B0018.csv ├── B0005.csv ├── B0007.csv ├── LSTM20230601.ipynb └── rvm20230618.ipynb /B0006.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B0006.csv -------------------------------------------------------------------------------- /B18_test.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B18_test.csv -------------------------------------------------------------------------------- /B5_test1.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B5_test1.csv -------------------------------------------------------------------------------- /B5_test2.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B5_test2.csv -------------------------------------------------------------------------------- /B6_test.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B6_test.csv -------------------------------------------------------------------------------- /B6_test1.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B6_test1.csv -------------------------------------------------------------------------------- /B6_test2.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B6_test2.csv -------------------------------------------------------------------------------- /B7_test.csv: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B7_test.csv -------------------------------------------------------------------------------- /B7_test1.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B7_test1.csv -------------------------------------------------------------------------------- /B7_test2.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B7_test2.csv -------------------------------------------------------------------------------- /B18_test1.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fcan1997/RUL/HEAD/B18_test1.csv -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RUL 2 | Remaining Useful Life Prediction of Lithium-ion Batteries 3 | 4 | 这是本人在锂离子剩余使用寿命预测领域研究的模型,目前实验数据源来自于NASA,后续会补充CALCE,并且丰富更多内容,欢迎小伙伴批评指正,相互学习,一起进步! 
5 | -------------------------------------------------------------------------------- /B5_test.csv: -------------------------------------------------------------------------------- 1 | 循环次数,平均放电电压,平均放电温度,等压降放电时间,容量 2 | 81,3.514773138,32.21997627,1180.563,1.559765947 3 | 82,3.513299298,32.71202254,1170.906,1.559481567 4 | 83,3.514484831,32.94502215,1162.032,1.554689354 5 | 84,3.510967735,32.8575496,1152.125,1.548874108 6 | 85,3.510935875,32.62966333,1142.672,1.538236599 7 | 86,3.509393098,32.37985901,1133.531,1.527914258 8 | 87,3.506678241,32.68215852,1124.141,1.528525263 9 | 88,3.510988562,32.99387424,1124.047,1.522647325 10 | 89,3.510305463,33.04539707,1114.562,1.517485994 11 | 90,3.522221531,32.85629564,1152.235,1.605818899 12 | 91,3.516349215,33.33950238,1179.984,1.563849145 13 | 92,3.513439762,33.35156485,1152.141,1.548091525 14 | 93,3.515137522,33.14748636,1142.766,1.532375563 15 | 94,3.508303906,33.00727586,1124.125,1.526952827 16 | 95,3.509190808,32.91990813,1105.297,1.516957302 17 | 96,3.506809758,32.99903329,1105.469,1.511897596 18 | 97,3.506090907,33.03772282,1095.86,1.506563896 19 | 98,3.503715773,33.07740307,1096.281,1.501545353 20 | 99,3.506747996,33.05467982,1076.938,1.490844405 21 | 100,3.505881303,33.11570361,1067.922,1.485868385 22 | 101,3.504168334,33.12942149,1068.031,1.480413678 23 | 102,3.503056893,33.13068131,1049.438,1.475209587 24 | 103,3.498232488,33.2006921,1068,1.485903599 25 | 104,3.506585937,33.78697528,1077.766,1.49609183 26 | 105,3.501457141,33.67336509,1058.469,1.480800595 27 | 106,3.501658387,33.53581812,1049.031,1.469754327 28 | 107,3.496618641,33.32269982,1030.36,1.453901227 29 | 108,3.499164572,33.34623033,1029.969,1.454146207 30 | 109,3.496080777,33.649489,1021.531,1.455023752 31 | 110,3.495901453,33.79269514,1021.484,1.449042164 32 | 111,3.49758698,33.52446664,1002.359,1.438670937 33 | 112,3.495698962,33.39042252,992.719,1.433445432 34 | 113,3.492875381,33.67370213,992.625,1.43339589 35 | 
114,3.495510764,33.91080565,993,1.428064873 36 | 115,3.494233359,33.78506806,983.891,1.422920318 37 | 116,3.491197414,33.61126807,974.343,1.417429033 38 | 117,3.488909408,33.45598581,965.187,1.412409229 39 | 118,3.487376296,33.72326745,974.297,1.412578793 40 | 119,3.491994244,34.05007637,965.328,1.407598373 41 | 120,3.491585879,33.88934691,993.359,1.433392049 42 | 121,3.495576227,34.06764855,1002.406,1.438255018 43 | 122,3.493898859,33.75168101,974.312,1.417354717 44 | 123,3.49159163,33.464823,965.188,1.406981664 45 | 124,3.48857293,33.40210698,945.828,1.401203778 46 | 125,3.486902115,33.41068892,946.469,1.396700823 47 | 126,3.485548521,33.38984948,937,1.391284789 48 | 127,3.483891255,33.37004334,936.828,1.386228768 49 | 128,3.483198941,33.27215987,917.718,1.380436676 50 | 129,3.483627445,33.22244422,917.735,1.375236415 51 | 130,3.480172443,33.26207748,908.953,1.370512802 52 | 131,3.478452149,33.31202991,908.703,1.370508557 53 | 132,3.478578256,33.25287943,899.265,1.364735511 54 | 133,3.482519558,33.74505405,927.235,1.375392017 55 | 134,3.488682364,33.88841975,927.5,1.386111515 56 | 135,3.485850238,33.7353757,917.563,1.369850004 57 | 136,3.478794177,33.59131053,899.156,1.364782795 58 | 137,3.481282602,33.32366689,899.172,1.354322047 59 | 138,3.480352042,33.69196151,890.109,1.354641702 60 | 139,3.478112782,33.88763968,890.14,1.354703921 61 | 140,3.476586713,33.67308061,880.812,1.349314998 62 | 141,3.473376929,33.4119882,871.204,1.344189194 63 | 142,3.47263664,33.29926525,861.594,1.338991379 64 | 143,3.473439373,33.59980731,871.11,1.338914523 65 | 144,3.47547771,33.61961196,861.937,1.334006676 66 | 145,3.472288482,33.39452017,852.5,1.328644431 67 | 146,3.471373665,33.16734508,842.984,1.323170899 68 | 147,3.472616016,33.10962162,843.172,1.318169159 69 | 148,3.473814687,33.46501751,843.078,1.318466444 70 | 149,3.470142649,33.4426191,833.75,1.31829304 71 | 150,3.467523652,33.03792817,843.188,1.323872422 72 | 151,3.476752584,33.49494936,880.578,1.360121677 73 | 
152,3.475895683,33.49136303,861.516,1.339530687 74 | 153,3.47581218,33.3298086,852.438,1.329028657 75 | 154,3.47465544,33.27302769,852.515,1.323674127 76 | 155,3.472301781,33.22987913,833.984,1.318633897 77 | 156,3.472930577,33.20860834,834.109,1.313475132 78 | 157,3.469619203,33.23611506,833.703,1.313202063 79 | 158,3.470018471,33.18660088,824.375,1.307795995 80 | 159,3.470762692,33.1740581,815.031,1.303032919 81 | 160,3.468111467,33.25398452,815.125,1.303357356 82 | 161,3.464862578,33.31980932,814.766,1.303410044 83 | 162,3.467864134,33.33765795,805.687,1.297887081 84 | 163,3.464021405,33.29076744,805.406,1.298073506 85 | 164,3.466462386,33.27568842,805.656,1.293463614 86 | 165,3.468509067,33.32067807,796.281,1.288003393 87 | 166,3.466806213,33.37314973,795.937,1.287452522 88 | 167,3.471070722,33.7135189,824.375,1.309015364 89 | 168,3.475472204,33.86531768,843.109,1.325079329 90 | -------------------------------------------------------------------------------- /B0018.csv: -------------------------------------------------------------------------------- 1 | 循环次数,平均放电电压,平均放电温度,等压降放电时间(3.8V-3.5V),容量 2 | 1,3.527546226,31.77328528,1559.375,1.855004521 3 | 2,3.534835576,31.92489159,1581.11,1.843195532 4 | 3,3.538246596,31.31691283,1582.438,1.839601842 5 | 4,3.537697945,30.82119075,1572.984,1.830673604 6 | 5,3.542667605,31.95346972,1592.5,1.832700207 7 | 6,3.545103733,32.02311236,1594.313,1.828528885 8 | 7,3.547514244,31.63453316,1572.641,1.82120119 9 | 8,3.544812434,31.28819492,1561.937,1.815170011 10 | 9,3.543478027,31.11122471,1560.36,1.804298052 11 | 10,3.543133207,31.13166984,1578.359,1.82310023 12 | 11,3.545155146,31.51305776,1568.156,1.812125352 13 | 12,3.545032367,31.31102569,1556.093,1.804691638 14 | 13,3.544089217,30.89396471,1533.828,1.79084435 15 | 14,3.541317109,30.54059897,1520.812,1.783470723 16 | 15,3.54119798,30.82566019,1527.5,1.780938613 17 | 16,3.541255461,31.45140068,1519.391,1.77120904 18 | 17,3.542747895,31.51884775,1512.047,1.7686304 19 | 
18,3.541251094,31.11237289,1487.406,1.75363048 20 | 19,3.535026056,30.58645259,1464.657,1.746219742 21 | 20,3.532662762,30.24291302,1440.859,1.737664734 22 | 21,3.533310944,30.10270376,1434.578,1.731516674 23 | 22,3.527387058,30.25118485,1429.515,1.708594989 24 | 23,3.529602597,30.2124077,1415.203,1.711469957 25 | 24,3.528970769,30.12501101,1401.078,1.707502137 26 | 25,3.532274185,30.2242863,1438.079,1.749238206 27 | 26,3.532438323,30.97637625,1432.656,1.732769715 28 | 27,3.536321131,31.41570291,1428.031,1.722231333 29 | 28,3.533535274,31.38329984,1413.032,1.711846299 30 | 29,3.527012842,31.16793783,1385.422,1.699267835 31 | 30,3.5292573,30.99469044,1379.375,1.694036644 32 | 31,3.527821498,31.05531798,1352.297,1.681902708 33 | 32,3.525692957,31.37394554,1346.844,1.676977062 34 | 33,3.524396248,31.56128411,1342.375,1.665522931 35 | 34,3.520797576,31.31847472,1322.75,1.657192458 36 | 35,3.519457281,31.10401233,1309.078,1.648224156 37 | 36,3.5131656,30.9897304,1287.969,1.63877015 38 | 37,3.517217159,30.96785197,1272.062,1.627648691 39 | 38,3.512009111,31.28534205,1265.828,1.622153702 40 | 39,3.509110527,31.41096207,1258.156,1.614006651 41 | 40,3.5161067,30.75513713,1319.578,1.676051615 42 | 41,3.519010068,30.97744663,1300.343,1.649300643 43 | 42,3.512421274,30.80051219,1273.125,1.632382992 44 | 43,3.508152153,30.6161928,1263.922,1.616415972 45 | 44,3.508648569,30.3998393,1235.219,1.610902838 46 | 45,3.505199344,30.38386825,1215.265,1.595463869 47 | 46,3.52316528,30.5762395,1413.5,1.72670744 48 | 47,3.527269876,30.70590628,1376.531,1.716567392 49 | 48,3.526762655,30.62583595,1344.906,1.695823557 50 | 49,3.523020179,30.59735707,1317.391,1.677778319 51 | 50,3.519792346,30.5284939,1289.531,1.66065869 52 | 51,3.525101458,31.20375425,1305.437,1.666410927 53 | 52,3.5185192,31.17263945,1286.985,1.646833617 54 | 53,3.519080179,30.96825923,1267.062,1.625763993 55 | 54,3.511416727,30.82170431,1237.657,1.612157741 56 | 55,3.510803868,30.76521972,1219.891,1.605736682 57 | 
56,3.514326118,30.80998843,1269.672,1.673645315 58 | 57,3.519666405,31.03903359,1252.984,1.640434857 59 | 58,3.512403849,31.04176453,1223.781,1.613265332 60 | 59,3.507786223,30.9597842,1203.672,1.592096432 61 | 60,3.50765961,30.84351328,1174.234,1.586601182 62 | 61,3.507483357,30.7971574,1169.125,1.580076649 63 | 62,3.502386593,30.98443458,1158.828,1.564173295 64 | 63,3.503693494,30.96789928,1139.172,1.555618272 65 | 64,3.498994846,30.97661635,1120.093,1.540204528 66 | 65,3.495405577,30.89768962,1099.375,1.532161302 67 | 66,3.4945283,30.81568822,1104.015,1.53162324 68 | 67,3.496255483,30.72618808,1082.891,1.522259595 69 | 68,3.487696462,30.76818713,1073.36,1.506526713 70 | 69,3.486547581,30.77541517,1056.219,1.501421956 71 | 70,3.482298184,30.76182198,1045.625,1.496353412 72 | 71,3.491819482,30.91372193,1095.343,1.5334262 73 | 72,3.493995853,31.22821565,1086.391,1.523845268 74 | 73,3.484377016,31.26167978,1064.766,1.501192099 75 | 74,3.485911679,31.18408594,1045.062,1.492417671 76 | 75,3.484278754,31.00143235,1035.375,1.483323852 77 | 76,3.482988652,30.98137403,1027.079,1.480737679 78 | 77,3.483799877,31.08868793,1019.672,1.473507781 79 | 78,3.478754076,31.33642013,1007.875,1.468019285 80 | 79,3.475219435,31.22599987,997.907,1.458092356 81 | 80,3.470739536,31.03732349,988.625,1.447866099 82 | 81,3.47460091,30.81134923,979.656,1.452780878 83 | 82,3.474052157,30.93222318,970.125,1.442535636 84 | 83,3.475075729,31.21552707,972.157,1.439271288 85 | 84,3.472424308,31.40079708,961.812,1.428647568 86 | 85,3.471751102,31.25921327,951.484,1.424902675 87 | 86,3.47454073,31.13470129,1003.266,1.469811532 88 | 87,3.47172006,31.2468797,981.391,1.452581965 89 | 88,3.47139841,31.02755867,959.093,1.442834199 90 | 89,3.462649109,30.93191739,950.875,1.428376059 91 | 90,3.466072726,30.66708224,939.328,1.415454784 92 | 91,3.473067053,31.22678911,966.984,1.45457667 93 | 92,3.466854816,31.38773312,955.328,1.428317955 94 | 93,3.465273675,31.50566684,946.969,1.419702688 95 | 
94,3.463887822,31.54622584,936.641,1.415560663 96 | 95,3.457807923,31.38997788,912.812,1.405680035 97 | 96,3.459813554,31.38543375,927.86,1.408446107 98 | 97,3.458027728,31.4237072,904.625,1.396854778 99 | 98,3.461399672,31.43536891,907.047,1.393567037 100 | 99,3.461128548,31.54237529,888.687,1.389364429 101 | 100,3.455556823,31.52669677,881.796,1.378565142 102 | 101,3.459598363,31.43592191,876.406,1.38490111 103 | 102,3.450656989,31.47523811,876.156,1.370325181 104 | 103,3.455515243,31.58174456,878.719,1.373346439 105 | 104,3.455840339,31.64795854,867.375,1.367948729 106 | 105,3.448678294,31.66203228,856.234,1.356714109 107 | 106,3.474503179,30.43450088,986.156,1.460431021 108 | 107,3.474031882,31.01799706,976.453,1.450099894 109 | 108,3.47440421,31.39483761,966.094,1.43814069 110 | 109,3.471767044,31.24511886,939.016,1.428368198 111 | 110,3.468530527,31.05306925,931.171,1.414787452 112 | 111,3.465121365,30.88589263,917.297,1.413889163 113 | 112,3.465287986,31.05399238,908.75,1.398933166 114 | 113,3.461799526,31.49284007,911.078,1.395332258 115 | 114,3.46326014,31.93832685,912.531,1.39004613 116 | 115,3.462839429,31.74562722,888.344,1.386027145 117 | 116,3.456915224,31.31018499,876.609,1.38821526 118 | 117,3.457469018,31.12736035,878.954,1.376182656 119 | 118,3.450219178,30.98555081,853.859,1.36471689 120 | 119,3.448672805,30.89965328,854.875,1.358985474 121 | 120,3.44463272,30.70612364,835.266,1.346230876 122 | 121,3.467272639,30.47122187,926.797,1.426842782 123 | 122,3.464490491,30.93922178,914.735,1.406448148 124 | 123,3.459667012,31.11164063,889.781,1.393490749 125 | 124,3.462241957,31.03791022,891.781,1.388249351 126 | 125,3.45556617,30.96795084,880.578,1.370187834 127 | 126,3.4586962,30.93996204,882.11,1.379695167 128 | 127,3.452528261,31.2950193,870.844,1.368658632 129 | 128,3.455229008,31.81846533,872.5,1.362737178 130 | 129,3.461974422,32.00067163,858.765,1.363405113 131 | 130,3.452221906,31.73992229,847.656,1.351864604 132 | 
131,3.447650049,31.5261365,835.906,1.354796973 133 | 132,3.447928508,31.27624494,836.829,1.341051441 134 | -------------------------------------------------------------------------------- /B0005.csv: -------------------------------------------------------------------------------- 1 | 循环次数,平均放电电压,平均放电温度,等压降放电时间,容量 2 | 1,3.529828669,32.57232811,1622.625,1.856487421 3 | 2,3.537320128,32.72523523,1661.078,1.84632725 4 | 3,3.543736728,32.64286194,1661.922,1.835349194 5 | 4,3.543666107,32.51487646,1662.906,1.835262528 6 | 5,3.542343253,32.38234902,1661.938,1.834645508 7 | 6,3.541334735,32.43418218,1662.296,1.83566166 8 | 7,3.54102505,32.48041644,1662.219,1.835146143 9 | 8,3.554133297,32.41046186,1664.61,1.825756791 10 | 9,3.552935599,32.34614083,1662.828,1.824773853 11 | 10,3.551205732,32.27679751,1663.421,1.824613268 12 | 11,3.54929232,32.1570935,1663.672,1.824619553 13 | 12,3.559759221,32.3640713,1645.328,1.814201936 14 | 13,3.558081969,32.31819544,1645.907,1.813752158 15 | 14,3.555692244,32.15709029,1626.656,1.813440491 16 | 15,3.562746818,32.04407,1626.657,1.802598004 17 | 16,3.561452176,31.93159367,1625.157,1.8021069 18 | 17,3.560342047,31.94165536,1607.688,1.802579501 19 | 18,3.559118904,31.96500865,1607.64,1.803068314 20 | 19,3.557886772,31.96624168,1607.188,1.802777625 21 | 20,3.552037524,32.37143342,1683.25,1.847025995 22 | 21,3.552566844,32.7002499,1702.735,1.847417311 23 | 22,3.556794038,32.32266412,1684.125,1.836177421 24 | 23,3.561410004,31.98795811,1664.61,1.825780748 25 | 24,3.556610795,31.89445652,1645.016,1.825113644 26 | 25,3.558674217,31.78039682,1645.609,1.825581504 27 | 26,3.564103183,31.68874715,1626.563,1.814031128 28 | 27,3.559588064,31.59674799,1645.609,1.814769194 29 | 28,3.556833231,31.5655865,1625.141,1.813969389 30 | 29,3.562441085,31.59124667,1625.063,1.802765665 31 | 30,3.559140711,31.6034854,1608.266,1.80407704 32 | 31,3.557316133,31.91556385,1620.875,1.851802552 33 | 32,3.561870783,32.08112812,1630.218,1.830703846 34 | 
33,3.561758634,31.84979263,1611.313,1.819904109 35 | 34,3.559730765,31.46569043,1592.172,1.809307964 36 | 35,3.557593755,31.40934896,1574.25,1.804609905 37 | 36,3.559723788,31.90224748,1583.688,1.799377065 38 | 37,3.559061771,31.64902429,1554.906,1.788443234 39 | 38,3.557015528,31.48874626,1545.75,1.782923048 40 | 39,3.556248193,31.32999561,1536.453,1.773033716 41 | 40,3.553275653,31.5695512,1527.125,1.773037755 42 | 41,3.556266432,32.11142626,1526.969,1.767872111 43 | 42,3.556309542,32.27167601,1517.688,1.76231507 44 | 43,3.558950851,31.84229166,1492.594,1.767617292 45 | 44,3.553441821,31.75981545,1508.703,1.76266836 46 | 45,3.551012154,31.41911167,1479.953,1.751730487 47 | 46,3.551317608,31.33162657,1470.656,1.741849605 48 | 47,3.549003181,31.35934131,1461.11,1.736091351 49 | 48,3.559764834,32.31088484,1564.578,1.793624015 50 | 49,3.562345434,32.4358564,1555,1.783189022 51 | 50,3.558336658,32.66698125,1517.5,1.767364208 52 | 51,3.554013139,32.28423409,1489.516,1.757017785 53 | 52,3.552376096,32.00739179,1471.203,1.746870618 54 | 53,3.549618661,32.06387072,1470.375,1.741717251 55 | 54,3.549381031,32.4375456,1451.86,1.736422506 56 | 55,3.547211868,32.16809877,1433.406,1.726321724 57 | 56,3.54670521,32.05752646,1424.109,1.715806539 58 | 57,3.544629654,31.97592093,1413.985,1.710533351 59 | 58,3.543629161,32.21689335,1414.781,1.7060145 60 | 59,3.544494615,32.55632994,1395.656,1.700311027 61 | 60,3.540920185,32.40048085,1386.125,1.69457986 62 | 61,3.541195669,32.29371861,1377.266,1.684902909 63 | 62,3.540671045,32.13607145,1358.031,1.674474159 64 | 63,3.537604879,32.3826855,1358.188,1.674569248 65 | 64,3.538422485,32.30899951,1339.453,1.663716376 66 | 65,3.53928627,32.56636929,1340.234,1.659013869 67 | 66,3.53369017,32.46542952,1321.282,1.653854057 68 | 67,3.534367735,32.34705645,1311.406,1.642653782 69 | 68,3.531294193,32.37153894,1302.266,1.637857843 70 | 69,3.530913477,32.45549348,1292.609,1.632735041 71 | 70,3.528722689,32.50017195,1283.297,1.627752892 72 | 
71,3.526254936,32.50100298,1273.468,1.622125486 73 | 72,3.528684324,32.47289274,1264.406,1.61132566 74 | 73,3.525556939,32.50480155,1255.297,1.60656314 75 | 74,3.522680565,32.51218336,1236.812,1.601514223 76 | 75,3.524673459,32.45692822,1226.859,1.590369231 77 | 76,3.521315575,32.44705919,1217.609,1.585788997 78 | 77,3.518033865,32.32889655,1198.578,1.584943071 79 | 78,3.523684196,32.84488475,1227.078,1.595526389 80 | 79,3.519146638,32.31731555,1189.719,1.574730175 81 | 80,3.516233649,32.0637613,1180.766,1.564901995 82 | 81,3.514773138,32.21997627,1180.563,1.559765947 83 | 82,3.513299298,32.71202254,1170.906,1.559481567 84 | 83,3.514484831,32.94502215,1162.032,1.554689354 85 | 84,3.510967735,32.8575496,1152.125,1.548874108 86 | 85,3.510935875,32.62966333,1142.672,1.538236599 87 | 86,3.509393098,32.37985901,1133.531,1.527914258 88 | 87,3.506678241,32.68215852,1124.141,1.528525263 89 | 88,3.510988562,32.99387424,1124.047,1.522647325 90 | 89,3.510305463,33.04539707,1114.562,1.517485994 91 | 90,3.522221531,32.85629564,1152.235,1.605818899 92 | 91,3.516349215,33.33950238,1179.984,1.563849145 93 | 92,3.513439762,33.35156485,1152.141,1.548091525 94 | 93,3.515137522,33.14748636,1142.766,1.532375563 95 | 94,3.508303906,33.00727586,1124.125,1.526952827 96 | 95,3.509190808,32.91990813,1105.297,1.516957302 97 | 96,3.506809758,32.99903329,1105.469,1.511897596 98 | 97,3.506090907,33.03772282,1095.86,1.506563896 99 | 98,3.503715773,33.07740307,1096.281,1.501545353 100 | 99,3.506747996,33.05467982,1076.938,1.490844405 101 | 100,3.505881303,33.11570361,1067.922,1.485868385 102 | 101,3.504168334,33.12942149,1068.031,1.480413678 103 | 102,3.503056893,33.13068131,1049.438,1.475209587 104 | 103,3.498232488,33.2006921,1068,1.485903599 105 | 104,3.506585937,33.78697528,1077.766,1.49609183 106 | 105,3.501457141,33.67336509,1058.469,1.480800595 107 | 106,3.501658387,33.53581812,1049.031,1.469754327 108 | 107,3.496618641,33.32269982,1030.36,1.453901227 109 | 
108,3.499164572,33.34623033,1029.969,1.454146207 110 | 109,3.496080777,33.649489,1021.531,1.455023752 111 | 110,3.495901453,33.79269514,1021.484,1.449042164 112 | 111,3.49758698,33.52446664,1002.359,1.438670937 113 | 112,3.495698962,33.39042252,992.719,1.433445432 114 | 113,3.492875381,33.67370213,992.625,1.43339589 115 | 114,3.495510764,33.91080565,993,1.428064873 116 | 115,3.494233359,33.78506806,983.891,1.422920318 117 | 116,3.491197414,33.61126807,974.343,1.417429033 118 | 117,3.488909408,33.45598581,965.187,1.412409229 119 | 118,3.487376296,33.72326745,974.297,1.412578793 120 | 119,3.491994244,34.05007637,965.328,1.407598373 121 | 120,3.491585879,33.88934691,993.359,1.433392049 122 | 121,3.495576227,34.06764855,1002.406,1.438255018 123 | 122,3.493898859,33.75168101,974.312,1.417354717 124 | 123,3.49159163,33.464823,965.188,1.406981664 125 | 124,3.48857293,33.40210698,945.828,1.401203778 126 | 125,3.486902115,33.41068892,946.469,1.396700823 127 | 126,3.485548521,33.38984948,937,1.391284789 128 | 127,3.483891255,33.37004334,936.828,1.386228768 129 | 128,3.483198941,33.27215987,917.718,1.380436676 130 | 129,3.483627445,33.22244422,917.735,1.375236415 131 | 130,3.480172443,33.26207748,908.953,1.370512802 132 | 131,3.478452149,33.31202991,908.703,1.370508557 133 | 132,3.478578256,33.25287943,899.265,1.364735511 134 | 133,3.482519558,33.74505405,927.235,1.375392017 135 | 134,3.488682364,33.88841975,927.5,1.386111515 136 | 135,3.485850238,33.7353757,917.563,1.369850004 137 | 136,3.478794177,33.59131053,899.156,1.364782795 138 | 137,3.481282602,33.32366689,899.172,1.354322047 139 | 138,3.480352042,33.69196151,890.109,1.354641702 140 | 139,3.478112782,33.88763968,890.14,1.354703921 141 | 140,3.476586713,33.67308061,880.812,1.349314998 142 | 141,3.473376929,33.4119882,871.204,1.344189194 143 | 142,3.47263664,33.29926525,861.594,1.338991379 144 | 143,3.473439373,33.59980731,871.11,1.338914523 145 | 144,3.47547771,33.61961196,861.937,1.334006676 146 | 
145,3.472288482,33.39452017,852.5,1.328644431 147 | 146,3.471373665,33.16734508,842.984,1.323170899 148 | 147,3.472616016,33.10962162,843.172,1.318169159 149 | 148,3.473814687,33.46501751,843.078,1.318466444 150 | 149,3.470142649,33.4426191,833.75,1.31829304 151 | 150,3.467523652,33.03792817,843.188,1.323872422 152 | 151,3.476752584,33.49494936,880.578,1.360121677 153 | 152,3.475895683,33.49136303,861.516,1.339530687 154 | 153,3.47581218,33.3298086,852.438,1.329028657 155 | 154,3.47465544,33.27302769,852.515,1.323674127 156 | 155,3.472301781,33.22987913,833.984,1.318633897 157 | 156,3.472930577,33.20860834,834.109,1.313475132 158 | 157,3.469619203,33.23611506,833.703,1.313202063 159 | 158,3.470018471,33.18660088,824.375,1.307795995 160 | 159,3.470762692,33.1740581,815.031,1.303032919 161 | 160,3.468111467,33.25398452,815.125,1.303357356 162 | 161,3.464862578,33.31980932,814.766,1.303410044 163 | 162,3.467864134,33.33765795,805.687,1.297887081 164 | 163,3.464021405,33.29076744,805.406,1.298073506 165 | 164,3.466462386,33.27568842,805.656,1.293463614 166 | 165,3.468509067,33.32067807,796.281,1.288003393 167 | 166,3.466806213,33.37314973,795.937,1.287452522 168 | 167,3.471070722,33.7135189,824.375,1.309015364 169 | 168,3.475472204,33.86531768,843.109,1.325079329 170 | -------------------------------------------------------------------------------- /B0007.csv: -------------------------------------------------------------------------------- 1 | 循环次数,平均放电电压,平均放电温度,等压降放电时间(3.8V-3.5V),容量 2 | 1,3.522006703,32.62485875,1643.235,1.891052295 3 | 2,3.531949043,32.73820959,1682.079,1.880637028 4 | 3,3.534402522,32.68511457,1682.61,1.880662672 5 | 4,3.534925049,32.5547485,1702.015,1.880770901 6 | 5,3.53453664,32.46787845,1701.078,1.879450873 7 | 6,3.541162667,32.48624734,1701.578,1.880700352 8 | 7,3.541165612,32.55779966,1701.468,1.879935252 9 | 8,3.549458118,32.45374436,1703.406,1.881508812 10 | 9,3.550280546,32.37324644,1701.984,1.869690787 11 | 
10,3.548134767,32.30534276,1702.735,1.87005238 12 | 11,3.548427215,32.18331948,1702.687,1.870044239 13 | 12,3.55604336,32.37885318,1702.86,1.859651899 14 | 13,3.560484964,32.29132495,1703.219,1.859074656 15 | 14,3.556287254,32.14782797,1702.906,1.859008458 16 | 15,3.564774622,32.01380948,1702.969,1.859362259 17 | 16,3.565438929,31.8813101,1682.453,1.858735545 18 | 17,3.565961679,31.8543043,1683.719,1.847817291 19 | 18,3.572496559,31.8565554,1683.719,1.848525292 20 | 19,3.569949202,31.86718261,1683.594,1.848378952 21 | 20,3.546433845,32.35216872,1722.562,1.880780544 22 | 21,3.553196809,32.60219299,1742.062,1.88147216 23 | 22,3.559934021,32.1957182,1723.656,1.881095431 24 | 23,3.559560497,31.80598816,1703.875,1.871008965 25 | 24,3.552385162,31.73700198,1721.656,1.870199515 26 | 25,3.57076703,31.51707668,1703.047,1.870671536 27 | 26,3.565722485,31.44286078,1703.062,1.859612464 28 | 27,3.571283507,31.32326275,1684.547,1.86003391 29 | 28,3.567971627,31.28571571,1683.031,1.859165311 30 | 29,3.564631404,31.33185087,1682.843,1.848139818 31 | 30,3.568983889,31.32285734,1665.984,1.849201702 32 | 31,3.564620689,31.60843645,1695.828,1.883467744 33 | 32,3.569438814,31.7821015,1695.625,1.862821112 34 | 33,3.568742147,31.51504199,1686.125,1.85208504 35 | 34,3.564242407,31.13866351,1667.406,1.846950465 36 | 35,3.56530381,31.06871052,1649.094,1.836828431 37 | 36,3.565432168,31.61824484,1658.703,1.837161962 38 | 37,3.562230316,31.33152806,1639.219,1.831672258 39 | 38,3.564481027,31.13256919,1630.14,1.821147095 40 | 39,3.561614557,30.97296968,1620.766,1.816506092 41 | 40,3.563058746,31.19906485,1611.406,1.81138156 42 | 41,3.562953437,31.7222676,1620.594,1.811605633 43 | 42,3.563446012,31.90082758,1601.953,1.806054896 44 | 43,3.563998342,31.53724025,1587.562,1.813204183 45 | 44,3.560490228,31.24662583,1592.938,1.806264194 46 | 45,3.55670021,30.93939936,1583.031,1.795831025 47 | 46,3.558412726,30.884857,1564.25,1.78588526 48 | 47,3.556136821,30.92068232,1545.453,1.78032686 49 | 
48,3.568999942,33.88057782,1648.766,1.81507569 50 | 49,3.571059291,33.03833932,1639.219,1.815702515 51 | 50,3.566692128,32.2732917,1601.797,1.800243218 52 | 51,3.563327431,31.93791978,1583.187,1.790447621 53 | 52,3.559014621,31.6886866,1555.469,1.780343032 54 | 53,3.556694584,31.77912069,1554.813,1.77532877 55 | 54,3.555013985,32.13387337,1545.516,1.770262973 56 | 55,3.552637375,31.88943153,1527,1.760081062 57 | 56,3.547186025,31.81736328,1517.875,1.749647594 58 | 57,3.546814755,31.72316206,1498.484,1.749650075 59 | 58,3.544785017,31.95715065,1489.765,1.745028739 60 | 59,3.546521542,32.29917704,1489.39,1.739645746 61 | 60,3.540832548,32.12884295,1470.203,1.728564228 62 | 61,3.542992511,32.02113436,1461.531,1.723900325 63 | 62,3.536081117,31.85904886,1442.328,1.713978878 64 | 63,3.539243694,32.16100334,1442.547,1.7139702 65 | 64,3.534823706,32.10276049,1423.797,1.703652117 66 | 65,3.53664774,32.33880662,1433.828,1.704013747 67 | 66,3.535206284,32.21373015,1414.953,1.693571616 68 | 67,3.532985224,32.09949527,1386.297,1.687657736 69 | 68,3.532158561,32.11305439,1386.625,1.683074436 70 | 69,3.530837638,32.20371989,1376.953,1.677927014 71 | 70,3.529695844,32.23340826,1367.578,1.673206028 72 | 71,3.528181937,32.21216972,1367.204,1.667436721 73 | 72,3.526573018,32.21746402,1348.61,1.662266263 74 | 73,3.524502255,32.23410177,1339.5,1.657424307 75 | 74,3.522854841,32.22834398,1330.328,1.652538685 76 | 75,3.522282912,32.18331061,1320.531,1.647217463 77 | 76,3.518992382,32.17200142,1302.141,1.636413681 78 | 77,3.517784428,32.03316224,1292.312,1.636255431 79 | 78,3.521222606,32.50667189,1311.469,1.641119179 80 | 79,3.517222788,31.96103949,1283.703,1.631570079 81 | 80,3.512501086,31.79808267,1265.078,1.621213283 82 | 81,3.512961464,31.95802834,1264.875,1.616427321 83 | 82,3.51252041,32.4181317,1273.843,1.61582351 84 | 83,3.515145456,32.64044973,1265.078,1.61672418 85 | 84,3.513443179,32.56399373,1255.141,1.610865586 86 | 85,3.50970894,32.28305351,1245.422,1.600660337 87 | 
86,3.504277147,32.06065237,1227.203,1.595917449 88 | 87,3.509207943,32.36526113,1227.078,1.596217952 89 | 88,3.509506492,32.68851871,1227,1.595806511 90 | 89,3.505203266,32.74773364,1217.656,1.590650782 91 | 90,3.528118235,32.50500517,1255.313,1.688821116 92 | 91,3.519590928,32.91993879,1273.656,1.625993828 93 | 92,3.514626037,32.94830586,1255.156,1.615743386 94 | 93,3.511549482,32.76906652,1245.797,1.605662676 95 | 94,3.510769869,32.59734425,1236.579,1.594974692 96 | 95,3.508044943,32.52349894,1217.688,1.590410244 97 | 96,3.507748081,32.57822883,1208.484,1.585351706 98 | 97,3.506667352,32.62728298,1208.141,1.579973715 99 | 98,3.506184392,32.65211571,1199.343,1.575218547 100 | 99,3.505100434,32.64767456,1189.422,1.569798485 101 | 100,3.505447919,32.69938792,1189.703,1.570256538 102 | 101,3.500074247,32.72457677,1180.485,1.565249874 103 | 102,3.501164716,32.7043403,1171.297,1.559633934 104 | 103,3.500407285,32.74051962,1170.844,1.564982053 105 | 104,3.507816455,33.38712825,1199.468,1.5749795 106 | 105,3.502524399,33.27062734,1180.265,1.565281117 107 | 106,3.505155823,33.11507414,1161.407,1.559755497 108 | 107,3.498364648,32.90572814,1161.437,1.549897686 109 | 108,3.500140844,32.95311265,1151.718,1.543877153 110 | 109,3.49909547,33.22821092,1152.766,1.545018749 111 | 110,3.501035761,33.36134259,1143.109,1.544594015 112 | 111,3.497212316,33.12431679,1133.797,1.534308672 113 | 112,3.496089958,32.99155359,1123.813,1.528760394 114 | 113,3.495877119,33.23183188,1133.016,1.528926025 115 | 114,3.493116071,33.50981239,1123.75,1.52909165 116 | 115,3.492655946,33.38240285,1114.89,1.523905932 117 | 116,3.492088245,33.17829959,1114.813,1.518667542 118 | 117,3.490216477,33.01591345,1096.344,1.513648227 119 | 118,3.490593196,33.27100036,1105.453,1.513605019 120 | 119,3.491141937,33.61966968,1096.375,1.513947061 121 | 120,3.492836452,33.29697163,1124.422,1.533959919 122 | 121,3.502515752,33.39313087,1133.547,1.53913728 123 | 122,3.492953566,33.1382071,1114.781,1.523624797 124 | 
123,3.491746264,32.84069779,1096.453,1.513581304 125 | 124,3.490432264,32.75739656,1095.515,1.507785075 126 | 125,3.488861417,32.77234087,1087.078,1.503196068 127 | 126,3.48919321,32.72041756,1077.485,1.49782206 128 | 127,3.487878571,32.7100228,1077.359,1.492848535 129 | 128,3.486430546,32.6207676,1058.203,1.49241439 130 | 129,3.481257278,32.57651826,1048.766,1.487482777 131 | 130,3.482355079,32.58396495,1058.578,1.482535149 132 | 131,3.483507429,32.62374267,1049.063,1.482787432 133 | 132,3.478976542,32.58599202,1039.704,1.47697118 134 | 133,3.483501191,32.86454472,1058.422,1.481901555 135 | 134,3.489970199,33.0796511,1068.047,1.49813662 136 | 135,3.484801139,32.93214165,1058.063,1.482077521 137 | 136,3.47835438,32.755748,1039.532,1.477115531 138 | 137,3.475694934,32.52965475,1030.328,1.472077278 139 | 138,3.481431873,32.8480652,1040.047,1.472248014 140 | 139,3.479606034,33.07232706,1030.75,1.467010667 141 | 140,3.478225897,32.85096452,1021.078,1.46176725 142 | 141,3.474323387,32.57230233,1011.891,1.456695267 143 | 142,3.473730545,32.45068444,1002.265,1.451354973 144 | 143,3.474812723,32.73943594,1002.234,1.451615926 145 | 144,3.476096427,32.77589824,1002.375,1.446816133 146 | 145,3.468682253,32.5817324,993.219,1.446757061 147 | 146,3.46619647,32.35021153,992.672,1.441379748 148 | 147,3.468183424,32.28754241,983.578,1.436245625 149 | 148,3.467860663,32.60323918,983.813,1.436815415 150 | 149,3.466249955,32.61208344,983.672,1.436294515 151 | 150,3.462469269,32.26742657,993.125,1.441857532 152 | 151,3.480518309,32.6377337,1011.625,1.467206348 153 | 152,3.472313799,32.66525872,1002.235,1.451864037 154 | 153,3.474362109,32.50265263,993.078,1.447137975 155 | 154,3.472282321,32.44684209,983.625,1.441790587 156 | 155,3.47017619,32.39361893,974.672,1.431630948 157 | 156,3.463501112,32.41722178,965.234,1.431808425 158 | 157,3.4674531,32.4009404,965.078,1.426256259 159 | 158,3.467586473,32.36314242,964.875,1.426087753 160 | 159,3.462182536,32.38260071,955.484,1.421263351 161 
| 160,3.465962999,32.4323818,955.532,1.416327275 162 | 161,3.460172751,32.4886183,946.094,1.416578183 163 | 162,3.464226115,32.52406238,946.219,1.410840173 164 | 163,3.458799193,32.48261597,945.781,1.410994963 165 | 164,3.46208055,32.46389583,946.094,1.406171429 166 | 165,3.457878295,32.55681405,936.735,1.406335848 167 | 166,3.461590937,32.57629643,927.063,1.40045524 168 | 167,3.462906231,32.8180488,964.828,1.421786505 169 | 168,3.475358043,32.92941193,965.11,1.432455272 170 | -------------------------------------------------------------------------------- /LSTM20230601.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "执行多元时间序列分析时,需要使用多个特征预测当前的目标在训练时,如果使用 5 列 [feature1, feature2, feature3, feature4, target] 来训练模型,我们需要提供 4 列 [feature1, feature2, feature3, feature4]。\n", 8 | "\n", 9 | "导入预测所需要的库\n", 10 | "在Keras中有两种深度学习的模型:序列模型(Sequential)和通用模型(Model)。差异在于不同的拓扑结构。" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 46, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import pandas as pd\n", 21 | "from matplotlib import pyplot as plt\n", 22 | "from tensorflow.keras.models import Sequential #按顺序建立\n", 23 | "from tensorflow.keras.layers import LSTM\n", 24 | "from tensorflow.keras.layers import Dense,Dropout #全连接层\n", 25 | "from sklearn.preprocessing import MinMaxScaler #数据归一化\n", 26 | "from keras.wrappers.scikit_learn import KerasRegressor #回归\n", 27 | "from sklearn.model_selection import GridSearchCV #自动调参" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "读取数据" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 47, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "df = pd.read_csv(\"B0005.csv\")\n", 44 | "df.head() #前面五行" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 
48, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "capacity_original_half = np.array(df)[:20,4]#1到20行的第四列的容量数据 切片" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "训练测试拆分" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 49, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "df_for_training=df[:120]\n", 70 | "df_for_testing=df[:]\n", 71 | "print(df_for_training.shape)\n", 72 | "print(df_for_testing.shape)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "MinMax归一化预处理" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 50, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "scaler = MinMaxScaler(feature_range=(0,1))\n", 89 | "df_for_training_scaled = scaler.fit_transform(df_for_training)\n", 90 | "df_for_testing_scaled = scaler.transform(df_for_testing)\n", 91 | "df_for_training_scaled" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "将数据拆分为X和Y\n" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 51, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "def createXY(dataset,n_past):\n", 108 | " dataX = []\n", 109 | " dataY = []\n", 110 | " for i in range(n_past,len(dataset)):\n", 111 | " \n", 112 | " dataX.append(dataset[i-n_past:i,0:dataset.shape[1]]) \n", 113 | " dataY.append(dataset[i,4])\n", 114 | " return np.array(dataX),np.array(dataY)\n", 115 | "\n", 116 | "trainX, trainY = createXY(df_for_training_scaled,20) \n", 117 | "testX, testY = createXY(df_for_testing_scaled,20) " 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "n_past是预测下一个目标值时将在过去查看的步骤数,为20的话,就是使用过去20个值(包括目标列在内的所有特性)来预测第21个目标值\n", 125 | "所以trainX有所有的特征值,而trainY中只有目标值" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 52, 131 | "metadata": {}, 132 | 
"outputs": [], 133 | "source": [ 134 | "print(trainX.shape)\n", 135 | "print(trainY.shape)\n", 136 | "\n", 137 | "print(testX.shape)\n", 138 | "print(testY.shape)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "如果查看 trainX[1] 值,会发现到它与 trainX[0] 中的数据相同(第一列除外),因为我们将看到前 20 个来预测第 21 列,在第一次预测之后它会自动移动 到第 2 列并取下一个 20 值来预测下一个目标值。\n", 146 | "\n", 147 | "每个数据都将保存在 trainX 和 trainY 中" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数。但是这个方法适合于小数据集,一旦数据的量级上去了,很难得出结果。\n", 155 | "\n", 156 | "训练模型,使用gridsearchcv网格搜索进行超参数(需要人工选择的参数)调整找到基础模型\n", 157 | "\n", 158 | "GridSearchCV的名字其实可以拆分为两部分,GridSearch和CV,即网格搜索和交叉验证。网格搜索,搜索的是参数,即在指定的参数范围内,按步长依次调整参数,利用调整的参数训练学习器,从所有的参数中找到在验证集上精度最高的参数,是一个训练和比较的过程。\n", 159 | "\n", 160 | "1.选择并构建训练模型model\n", 161 | "\n", 162 | "2.将训练模型model投入到GridSearchCV中,得到GridSearchCV模型grid_model\n", 163 | "\n", 164 | "3.用grid_model拟合训练集数据,选择在validation_dataset上效果最好的参数的模型best_estimator\n", 165 | "\n", 166 | "4.1.用best_estimator拟合训练集(得到的结果应该与之前不同,因为之前用交叉验证等方法对训练集进行了分割)\n", 167 | "\n", 168 | "4.2.用best_estimator拟合测试集\n", 169 | "\n" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 53, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "learning_rate = 0.01\n", 179 | "def build_model(optimizer):\n", 180 | " grid_model = Sequential()\n", 181 | " grid_model.add(LSTM(50,return_sequences=True,input_shape=(20,5)))\n", 182 | " grid_model.add(LSTM(50))\n", 183 | " grid_model.add(Dropout(0.2))\n", 184 | " grid_model.add(Dense(1))\n", 185 | " \n", 186 | " grid_model.compile(loss = 'mse',optimizer=optimizer)\n", 187 | " return grid_model\n", 188 | "grid_model = KerasRegressor(build_fn=build_model,verbose=1)\n", 189 | "\n", 190 | "parameters = {'batch_size':[16,24,28,32,40],\n", 191 | " 'epochs':[300,500,800],\n", 192 | " 'optimizer':['adam']}\n", 193 | "\n", 194 | "grid_search = 
GridSearchCV(estimator = grid_model,\n", 195 | " param_grid = parameters,cv = 2)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "将模型拟合到trainX和trainY数据中" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 54, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "grid_search = grid_search.fit(trainX,trainY)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "找到最佳的模型参数\n" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 55, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "grid_search.best_params_" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "将最佳模型保存在在my_model变量中" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 56, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "my_model=grid_search.best_estimator_.model" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 57, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "prediction = my_model.predict(testX)\n", 253 | "print(\"prediction is\\n\",prediction)\n", 254 | "print(\"\\nprediction Shape-\",prediction.shape)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 58, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "prediction_copies_array = np.repeat(prediction,5,axis=-1)#在缩放的时候一行有五列,现在是目标列一列,所以将预测列复制四次得到五列相同的值" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 59, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "pred=scaler.inverse_transform(np.reshape(prediction_copies_array,(len(prediction),5)))[:,4]#只需要最后一列 切片" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "将这个pred的值与testY进行比较,testY也是按比例缩放,同样要逆变换" 280 | ] 281 | }, 282 | { 283 | "cell_type": 
"code", 284 | "execution_count": 60, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "original_copies_array = np.repeat(testY,5,axis=-1)\n", 289 | "original=scaler.inverse_transform(np.reshape(original_copies_array,(len(testY),5)))[:,4]" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 61, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "capacity_original_complete = np.append(capacity_original_half,original)\n", 299 | "pred_complete = np.append(capacity_original_half,pred)" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 62, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "plt.plot(pred_complete,color = 'blue',label = 'Prediccted Capacity')\n", 309 | "plt.plot(capacity_original_complete,color = 'red',label = 'Real Capacity')\n", 310 | "plt.title('B0005 Battery')\n", 311 | "plt.xlabel('Cycle')\n", 312 | "plt.ylabel('Capacity')\n", 313 | "plt.legend()\n", 314 | "plt.show()" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 63, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "from math import sqrt\n", 324 | "from sklearn.metrics import mean_absolute_error\n", 325 | "from sklearn.metrics import mean_squared_error\n", 326 | "from sklearn.metrics import r2_score\n", 327 | "print(\"mean_absolute_error:\",mean_absolute_error(original,pred))\n", 328 | "print(\"mean_squared_error:\",mean_squared_error(original,pred))\n", 329 | "print(\"rmse:\",sqrt(mean_squared_error(original,pred)))\n", 330 | "print(\"r2 score:\",r2_score(original,pred))" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "预测未来的值\n", 338 | "\n", 339 | "df.loc[]:是按标签或者布尔数组进行行/列索引\n", 340 | "df.iloc[]:是按标签位置(from 0 to length - 1)或者布尔数组进行索引" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 64, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "df_cycle_past = 
df.iloc[79:99,:] \n", 350 | "df_cycle_past\n" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 65, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "df_cycle_future=pd.read_csv(\"B18_test1.csv\",encoding=\"gbk\")\n" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 66, 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "df_cycle_future[\"容量\"] = 0\n" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 67, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "#剔除预测数据中容量列,进行归一化缩放,拼接20个预测输入和88test点\n", 378 | "#df_cycle_future = df_cycle_future[[\"循环次数\",\"平均放电电压\",\"平均放电温度\",\"等压降放电时间\",\"容量\"]]\n", 379 | "old_scaled_array = scaler.transform(df_cycle_past)\n", 380 | "new_scaled_array = scaler.transform(df_cycle_future)\n", 381 | "new_scaled_df = pd.DataFrame(new_scaled_array)\n", 382 | "new_scaled_df.iloc[:,4] = np.nan\n", 383 | "full_df = pd.concat([pd.DataFrame(old_scaled_array),new_scaled_df]).reset_index().drop([\"index\"],axis=1)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": [ 392 | "#滚动填充容量数据预测\n", 393 | "full_df_scaled_array = full_df.values\n", 394 | "all_data = [] #预测值\n", 395 | "time_step = 20\n", 396 | "for i in range(time_step,len(full_df_scaled_array)):\n", 397 | " data_x = []\n", 398 | " data_x.append(full_df_scaled_array[i-time_step:i,0:full_df_scaled_array.shape[1]])\n", 399 | " data_x = np.array(data_x)\n", 400 | " prediction = my_model.predict(data_x)\n", 401 | " print(prediction)\n", 402 | " all_data.append(prediction)\n", 403 | " full_df.iloc[i,4] = prediction " 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "full_df_scaled_array[0:,0:full_df_scaled_array.shape[1]]\n", 413 | "full_df" 414 | ] 415 | }, 416 | { 417 | 
"cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "#逆缩放\n", 423 | "new_array=np.array(all_data)\n", 424 | "new_array=new_array.reshape(-1,1)\n", 425 | "prediction_copies_array = np.repeat(new_array,5,axis=-1)\n", 426 | "y_pred_future_cycle = scaler.inverse_transform(np.reshape(prediction_copies_array,(len(new_array),5)))[:,4]\n", 427 | "print(y_pred_future_cycle)" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": {}, 433 | "source": [ 434 | "起始点为80预测末尾88个容量" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "capacity_original_half = np.array(df)[:100,4]\n", 444 | "capacity_original_complete = np.array(df)[:,4]\n", 445 | "len(capacity_original_half)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [ 454 | "pred_complete = np.append(capacity_original_half,y_pred_future_cycle)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [ 463 | "len(pred_complete)" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "metadata": {}, 470 | "outputs": [], 471 | "source": [ 472 | "len(capacity_original_complete)" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [ 481 | "plt.plot(pred_complete,color = 'blue',label = 'Predicted Capacity')\n", 482 | "plt.plot(capacity_original_complete,color = 'red',label = 'Real Capacity')\n", 483 | "plt.title('B0018 Battery')\n", 484 | "plt.xlabel('Cycle')\n", 485 | "plt.ylabel('Capacity')\n", 486 | "plt.legend()\n", 487 | "plt.show()" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": null, 493 | "metadata": {}, 494 | 
"outputs": [], 495 | "source": [ 496 | "from math import sqrt\n", 497 | "from sklearn.metrics import mean_absolute_error\n", 498 | "from sklearn.metrics import mean_squared_error\n", 499 | "from sklearn.metrics import r2_score\n", 500 | "print(\"mean_absolute_error MAE:\", mean_absolute_error(capacity_original_complete, pred_complete))\n", 501 | "print(\"mean_squared_error MSE:\", mean_squared_error(capacity_original_complete, pred_complete))\n", 502 | "print(\"rmse:\", sqrt(mean_squared_error(capacity_original_complete, pred_complete)))\n", 503 | "print(\"r2 score:\", r2_score(capacity_original_complete, pred_complete))" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "metadata": {}, 510 | "outputs": [], 511 | "source": [] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": {}, 517 | "outputs": [], 518 | "source": [] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "metadata": {}, 531 | "outputs": [], 532 | "source": [] 533 | } 534 | ], 535 | "metadata": { 536 | "kernelspec": { 537 | "display_name": "Python 3", 538 | "language": "python", 539 | "name": "python3" 540 | }, 541 | "language_info": { 542 | "codemirror_mode": { 543 | "name": "ipython", 544 | "version": 3 545 | }, 546 | "file_extension": ".py", 547 | "mimetype": "text/x-python", 548 | "name": "python", 549 | "nbconvert_exporter": "python", 550 | "pygments_lexer": "ipython3", 551 | "version": "3.6.5" 552 | } 553 | }, 554 | "nbformat": 4, 555 | "nbformat_minor": 2 556 | } 557 | -------------------------------------------------------------------------------- /rvm20230618.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | 
"定义rvm的类,回归和分类都有" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "\"\"\"Relevance Vector Machine classes for regression and classification.\"\"\"\n", 17 | "import numpy as np\n", 18 | " \n", 19 | "from scipy.optimize import minimize\n", 20 | "from scipy.special import expit\n", 21 | " \n", 22 | "from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin\n", 23 | "from sklearn.metrics.pairwise import (\n", 24 | " linear_kernel,\n", 25 | " rbf_kernel,\n", 26 | " polynomial_kernel\n", 27 | ")\n", 28 | "from sklearn.multiclass import OneVsOneClassifier\n", 29 | "from sklearn.utils.validation import check_X_y\n", 30 | " \n", 31 | " \n", 32 | "class BaseRVM(BaseEstimator):\n", 33 | " \n", 34 | " \"\"\"Base Relevance Vector Machine class.\n", 35 | " Implementation of Mike Tipping's Relevance Vector Machine using the\n", 36 | " scikit-learn API. Add a posterior over weights method and a predict\n", 37 | " in subclass to use for classification or regression.\n", 38 | " \"\"\"\n", 39 | " \n", 40 | " def __init__(\n", 41 | " self,\n", 42 | " kernel='rbf',\n", 43 | " degree=3,\n", 44 | " coef1=None,\n", 45 | " coef0=0.0,\n", 46 | " n_iter=3000,\n", 47 | " tol=1e-3,\n", 48 | " alpha=1e-6,\n", 49 | " threshold_alpha=1e9,\n", 50 | " beta=1.e-6,\n", 51 | " beta_fixed=False,\n", 52 | " bias_used=True,\n", 53 | " verbose=False\n", 54 | " ):\n", 55 | " \"\"\"Copy params to object properties, no validation.\"\"\"\n", 56 | " self.kernel = kernel\n", 57 | " self.degree = degree\n", 58 | " self.coef1 = coef1\n", 59 | " self.coef0 = coef0\n", 60 | " self.n_iter = n_iter\n", 61 | " self.tol = tol\n", 62 | " self.alpha = alpha\n", 63 | " self.threshold_alpha = threshold_alpha\n", 64 | " self.beta = beta\n", 65 | " self.beta_fixed = beta_fixed\n", 66 | " self.bias_used = bias_used\n", 67 | " self.verbose = verbose\n", 68 | " \n", 69 | " def get_params(self, deep=True):\n", 70 | " 
\"\"\"Return parameters as a dictionary.\"\"\"\n", 71 | " params = {\n", 72 | " 'kernel': self.kernel,\n", 73 | " 'degree': self.degree,\n", 74 | " 'coef1': self.coef1,\n", 75 | " 'coef0': self.coef0,\n", 76 | " 'n_iter': self.n_iter,\n", 77 | " 'tol': self.tol,\n", 78 | " 'alpha': self.alpha,\n", 79 | " 'threshold_alpha': self.threshold_alpha,\n", 80 | " 'beta': self.beta,\n", 81 | " 'beta_fixed': self.beta_fixed,\n", 82 | " 'bias_used': self.bias_used,\n", 83 | " 'verbose': self.verbose\n", 84 | " }\n", 85 | " return params\n", 86 | " \n", 87 | " def set_params(self, **parameters):\n", 88 | " \"\"\"Set parameters using kwargs.\"\"\"\n", 89 | " for parameter, value in parameters.items():\n", 90 | " setattr(self, parameter, value)\n", 91 | " return self\n", 92 | " \n", 93 | " def _apply_kernel(self, x, y):\n", 94 | " \"\"\"Apply the selected kernel function to the data.\"\"\"\n", 95 | " if self.kernel == 'linear':\n", 96 | " phi = linear_kernel(x, y)\n", 97 | " elif self.kernel == 'rbf':\n", 98 | " phi = rbf_kernel(x, y, self.coef1)\n", 99 | " elif self.kernel == 'poly':\n", 100 | " phi = polynomial_kernel(x, y, self.degree, self.coef1, self.coef0)\n", 101 | " elif callable(self.kernel):\n", 102 | " phi = self.kernel(x, y)\n", 103 | " if len(phi.shape) != 2:\n", 104 | " raise ValueError(\n", 105 | " \"Custom kernel function did not return 2D matrix\"\n", 106 | " )\n", 107 | " if phi.shape[0] != x.shape[0]:\n", 108 | " raise ValueError(\n", 109 | " \"Custom kernel function did not return matrix with rows\"\n", 110 | " \" equal to number of data points.\"\"\"\n", 111 | " )\n", 112 | " else:\n", 113 | " raise ValueError(\"Kernel selection is invalid.\")\n", 114 | " \n", 115 | " if self.bias_used:\n", 116 | " phi = np.append(phi, np.ones((phi.shape[0], 1)), axis=1)\n", 117 | " \n", 118 | " return phi\n", 119 | " \n", 120 | " def _prune(self):\n", 121 | " \"\"\"Remove basis functions based on alpha values.\"\"\"\n", 122 | " keep_alpha = self.alpha_ < 
self.threshold_alpha\n", 123 | " \n", 124 | " if not np.any(keep_alpha):\n", 125 | " keep_alpha[0] = True\n", 126 | " if self.bias_used:\n", 127 | " keep_alpha[-1] = True\n", 128 | " \n", 129 | " if self.bias_used:\n", 130 | " if not keep_alpha[-1]:\n", 131 | " self.bias_used = False\n", 132 | " self.relevance_ = self.relevance_[keep_alpha[:-1]]\n", 133 | " else:\n", 134 | " self.relevance_ = self.relevance_[keep_alpha]\n", 135 | " \n", 136 | " self.alpha_ = self.alpha_[keep_alpha]\n", 137 | " self.alpha_old = self.alpha_old[keep_alpha]\n", 138 | " self.gamma = self.gamma[keep_alpha]\n", 139 | " self.phi = self.phi[:, keep_alpha]\n", 140 | " self.sigma_ = self.sigma_[np.ix_(keep_alpha, keep_alpha)]\n", 141 | " self.m_ = self.m_[keep_alpha]\n", 142 | " \n", 143 | " def fit(self, X, y):\n", 144 | " \"\"\"Fit the RVR to the training data.\"\"\"\n", 145 | " X, y = check_X_y(X, y)\n", 146 | " \n", 147 | " n_samples, n_features = X.shape\n", 148 | " \n", 149 | " self.phi = self._apply_kernel(X, X)\n", 150 | " \n", 151 | " n_basis_functions = self.phi.shape[1]\n", 152 | " \n", 153 | " self.relevance_ = X\n", 154 | " self.y = y\n", 155 | " \n", 156 | " self.alpha_ = self.alpha * np.ones(n_basis_functions)\n", 157 | " self.beta_ = self.beta\n", 158 | " \n", 159 | " self.m_ = np.zeros(n_basis_functions)\n", 160 | " \n", 161 | " self.alpha_old = self.alpha_\n", 162 | " \n", 163 | " for i in range(self.n_iter):\n", 164 | " self._posterior()\n", 165 | " \n", 166 | " self.gamma = 1 - self.alpha_*np.diag(self.sigma_)\n", 167 | " self.alpha_ = self.gamma/(self.m_ ** 2)\n", 168 | " \n", 169 | " if not self.beta_fixed:\n", 170 | " self.beta_ = (n_samples - np.sum(self.gamma))/(\n", 171 | " np.sum((y - np.dot(self.phi, self.m_)) ** 2))\n", 172 | " \n", 173 | " self._prune()\n", 174 | " \n", 175 | " if self.verbose:\n", 176 | " print(\"Iteration: {}\".format(i))\n", 177 | " print(\"Alpha: {}\".format(self.alpha_))\n", 178 | " print(\"Beta: {}\".format(self.beta_))\n", 179 | " 
print(\"Gamma: {}\".format(self.gamma))\n", 180 | " print(\"m: {}\".format(self.m_))\n", 181 | " print(\"Relevance Vectors: {}\".format(self.relevance_.shape[0]))\n", 182 | " print()\n", 183 | " \n", 184 | " delta = np.amax(np.absolute(self.alpha_ - self.alpha_old))\n", 185 | " \n", 186 | " if delta < self.tol and i > 1:\n", 187 | " break\n", 188 | " \n", 189 | " self.alpha_old = self.alpha_\n", 190 | " \n", 191 | " if self.bias_used:\n", 192 | " self.bias = self.m_[-1]\n", 193 | " else:\n", 194 | " self.bias = None\n", 195 | " \n", 196 | " return self\n", 197 | " \n", 198 | " \n", 199 | "class RVR(BaseRVM, RegressorMixin):\n", 200 | " \n", 201 | " \"\"\"Relevance Vector Machine Regression.\n", 202 | " Implementation of Mike Tipping's Relevance Vector Machine for regression\n", 203 | " using the scikit-learn API.\n", 204 | " \"\"\"\n", 205 | " \n", 206 | " def _posterior(self):\n", 207 | " \"\"\"Compute the posterior distriubtion over weights.\"\"\"\n", 208 | " i_s = np.diag(self.alpha_) + self.beta_ * np.dot(self.phi.T, self.phi)\n", 209 | " self.sigma_ = np.linalg.inv(i_s)\n", 210 | " self.m_ = self.beta_ * np.dot(self.sigma_, np.dot(self.phi.T, self.y))\n", 211 | " \n", 212 | " def predict(self, X, eval_MSE=False):\n", 213 | " \"\"\"Evaluate the RVR model at x.\"\"\"\n", 214 | " phi = self._apply_kernel(X, self.relevance_)\n", 215 | " \n", 216 | " y = np.dot(phi, self.m_)\n", 217 | " \n", 218 | " if eval_MSE:\n", 219 | " MSE = (1/self.beta_) + np.dot(phi, np.dot(self.sigma_, phi.T))\n", 220 | " return y, MSE[:, 0]\n", 221 | " else:\n", 222 | " return y\n", 223 | " \n", 224 | " \n", 225 | "class RVC(BaseRVM, ClassifierMixin):\n", 226 | " \n", 227 | " \"\"\"Relevance Vector Machine Classification.\n", 228 | " Implementation of Mike Tipping's Relevance Vector Machine for\n", 229 | " classification using the scikit-learn API.\n", 230 | " \"\"\"\n", 231 | " \n", 232 | " def __init__(self, n_iter_posterior=50, **kwargs):\n", 233 | " \"\"\"Copy params to object 
properties, no validation.\"\"\"\n", 234 | " self.n_iter_posterior = n_iter_posterior\n", 235 | " super(RVC, self).__init__(**kwargs)\n", 236 | " \n", 237 | " def get_params(self, deep=True):\n", 238 | " \"\"\"Return parameters as a dictionary.\"\"\"\n", 239 | " params = super(RVC, self).get_params(deep=deep)\n", 240 | " params['n_iter_posterior'] = self.n_iter_posterior\n", 241 | " return params\n", 242 | " \n", 243 | " def _classify(self, m, phi):\n", 244 | " return expit(np.dot(phi, m))\n", 245 | " \n", 246 | " def _log_posterior(self, m, alpha, phi, t):\n", 247 | " \n", 248 | " y = self._classify(m, phi)\n", 249 | " \n", 250 | " log_p = -1 * (np.sum(np.log(y[t == 1]), 0) +\n", 251 | " np.sum(np.log(1-y[t == 0]), 0))\n", 252 | " log_p = log_p + 0.5*np.dot(m.T, np.dot(np.diag(alpha), m))\n", 253 | " \n", 254 | " jacobian = np.dot(np.diag(alpha), m) - np.dot(phi.T, (t-y))\n", 255 | " \n", 256 | " return log_p, jacobian\n", 257 | " \n", 258 | " def _hessian(self, m, alpha, phi, t):\n", 259 | " y = self._classify(m, phi)\n", 260 | " B = np.diag(y*(1-y))\n", 261 | " return np.diag(alpha) + np.dot(phi.T, np.dot(B, phi))\n", 262 | " \n", 263 | " def _posterior(self):\n", 264 | " result = minimize(\n", 265 | " fun=self._log_posterior,\n", 266 | " hess=self._hessian,\n", 267 | " x0=self.m_,\n", 268 | " args=(self.alpha_, self.phi, self.t),\n", 269 | " method='Newton-CG',\n", 270 | " jac=True,\n", 271 | " options={\n", 272 | " 'maxiter': self.n_iter_posterior\n", 273 | " }\n", 274 | " )\n", 275 | " \n", 276 | " self.m_ = result.x\n", 277 | " self.sigma_ = np.linalg.inv(\n", 278 | " self._hessian(self.m_, self.alpha_, self.phi, self.t)\n", 279 | " )\n", 280 | " \n", 281 | " def fit(self, X, y):\n", 282 | " \"\"\"Check target values and fit model.\"\"\"\n", 283 | " self.classes_ = np.unique(y)\n", 284 | " n_classes = len(self.classes_)\n", 285 | " \n", 286 | " if n_classes < 2:\n", 287 | " raise ValueError(\"Need 2 or more classes.\")\n", 288 | " elif n_classes == 2:\n", 
289 | " self.t = np.zeros(y.shape)\n", 290 | " self.t[y == self.classes_[1]] = 1\n", 291 | " return super(RVC, self).fit(X, self.t)\n", 292 | " else:\n", 293 | " self.multi_ = None\n", 294 | " self.multi_ = OneVsOneClassifier(self)\n", 295 | " self.multi_.fit(X, y)\n", 296 | " return self\n", 297 | " \n", 298 | " def predict_proba(self, X):\n", 299 | " \"\"\"Return an array of class probabilities.\"\"\"\n", 300 | " phi = self._apply_kernel(X, self.relevance_)\n", 301 | " y = self._classify(self.m_, phi)\n", 302 | " return np.column_stack((1-y, y))\n", 303 | " \n", 304 | " def predict(self, X):\n", 305 | " \"\"\"Return an array of classes for each input.\"\"\"\n", 306 | " if len(self.classes_) == 2:\n", 307 | " y = self.predict_proba(X)\n", 308 | " res = np.empty(y.shape[0], dtype=self.classes_.dtype)\n", 309 | " res[y[:, 1] <= 0.5] = self.classes_[0]\n", 310 | " res[y[:, 1] >= 0.5] = self.classes_[1]\n", 311 | " return res\n", 312 | " else:\n", 313 | " return self.multi_.predict(X)" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "代码测试" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "测试回归和分类问题,并和支持向量机作对比" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 14, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "import numpy as np\n", 337 | "import pandas as pd\n", 338 | "import matplotlib.pyplot as plt\n", 339 | "import seaborn as sns\n", 340 | "from sklearn.preprocessing import StandardScaler\n", 341 | "from sklearn.model_selection import train_test_split\n", 342 | "from sklearn.model_selection import KFold, StratifiedKFold\n", 343 | "from sklearn.model_selection import GridSearchCV\n", 344 | "#from sklearn.metrics import plot_confusion_matrix #矩阵可视化\n", 345 | " \n", 346 | "from sklearn.svm import SVC\n", 347 | "from sklearn.svm import SVR\n", 348 | "from sklearn.datasets import load_boston\n", 349 | "from 
sklearn.datasets import load_breast_cancer" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "分类测试\n", 357 | "\n", 358 | "分类使用鸢尾花数据集" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 15, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "iris = load_breast_cancer() #加载数据\n", 368 | "X = iris.data\n", 369 | "y = iris.target\n", 370 | " \n", 371 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=0)\n", 372 | " \n", 373 | "scaler = StandardScaler()\n", 374 | "scaler.fit(X_train)\n", 375 | "X_train_s = scaler.transform(X_train)\n", 376 | "X_test_s = scaler.transform(X_test)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "支持向量机,不同核函数的效果" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 16, 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "name": "stdout", 393 | "output_type": "stream", 394 | "text": [ 395 | "0.9824561403508771\n", 396 | "0.7719298245614035\n", 397 | "0.9210526315789473\n", 398 | "0.9649122807017544\n", 399 | "0.9649122807017544\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "#线性核函数\n", 405 | "model = SVC(kernel=\"linear\", random_state=123)\n", 406 | "model.fit(X_train_s, y_train)\n", 407 | "print(model.score(X_test_s, y_test))\n", 408 | "#二次多项式核\n", 409 | "model = SVC(kernel=\"poly\", degree=2, random_state=123)\n", 410 | "model.fit(X_train_s, y_train)\n", 411 | "print(model.score(X_test_s, y_test))\n", 412 | "#三次多项式\n", 413 | "model = SVC(kernel=\"poly\", degree=3, random_state=123)\n", 414 | "model.fit(X_train_s, y_train)\n", 415 | "print(model.score(X_test_s, y_test))\n", 416 | "#径向核\n", 417 | "model = SVC(kernel=\"rbf\", random_state=123)\n", 418 | "model.fit(X_train_s, y_train)\n", 419 | "print(model.score(X_test_s, y_test))\n", 420 | "#S核\n", 421 | "model = SVC(kernel=\"sigmoid\",random_state=123)\n", 422 | 
# (Continuation of a cell that begins before this excerpt.)
model.fit(X_train_s, y_train)
print(model.score(X_test_s, y_test))

# --- Relevance Vector Machine (RVM) classification results ---
# One RVC per kernel, each reporting its score on the held-out split.
# Recorded run: 0.9649122807017544, 0.956140350877193, 0.9473684210526315;
# the third (poly) fit emitted "divide by zero encountered in log" warnings.
for rvc_kernel in ("linear", "rbf", "poly"):
    model = RVC(kernel=rvc_kernel)
    model.fit(X_train_s, y_train)
    print(model.score(X_test_s, y_test))

# --- Regression tests ---
# Support Vector Regression with Boston Housing Data.
# NOTE(review): load_boston was removed in scikit-learn 1.2, so this cell only
# runs against an older scikit-learn.
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1
)

# Standardize using statistics of the training split only.
scaler = StandardScaler()
X_train_s = scaler.fit(X_train).transform(X_train)
X_test_s = scaler.transform(X_test)

# --- SVM regression with different kernels ---
# Order: linear, quadratic polynomial, cubic polynomial, radial (RBF), sigmoid.
# Recorded run: 0.7621581456480089, 0.4759966424636959, 0.6620498346854784,
# 0.6471361547244447, 0.6352431807543495
svr_settings = [
    ("linear", {}),
    ("poly", {"degree": 2}),
    ("poly", {"degree": 3}),
    ("rbf", {}),
    ("sigmoid", {}),
]
for svr_kernel, svr_extra in svr_settings:
    model = SVR(kernel=svr_kernel, **svr_extra)
    model.fit(X_train_s, y_train)
    print(model.score(X_test_s, y_test))

# --- RVM regression results ---
# Recorded run: 0.7651187959253563, 0.9231470349980935, 0.8250885733161272
for rvr_kernel in ("linear", "rbf", "poly"):
    model = RVR(kernel=rvr_kernel)
    model.fit(X_train_s, y_train)
    print(model.score(X_test_s, y_test))

# Study notes: differences between RVM and SVM
# 1. SVM builds the learner on the structural-risk-minimization principle;
#    RVM builds it within a Bayesian framework.
# 2. Unlike SVM, RVM gives not only a binary output but also a probabilistic
#    output.
# 3. The kernel is not restricted by Mercer's theorem, so arbitrary kernel
#    functions can be constructed.
# 4. No penalty factor has to be set. In SVM the penalty factor is a constant
#    that balances empirical risk against the confidence interval; results are
#    very sensitive to it and a poor choice causes problems such as
#    over-fitting. In RVM the parameters are assigned automatically.
# 5. RVM is sparser than SVM, so prediction is faster and better suited to
#    online detection. The number of SVM support vectors grows linearly with
#    the training-set size, which is impractical for large training sets; the
#    number of RVM relevance vectors also grows, but much more slowly.
# 6. Generalization (performance on samples not seen in training) is a key
#    property of a learner; the literature reports RVM generalizes better than
#    SVM.
# 7. For both regression and classification, RVM accuracy is no worse than
#    SVM's.
# 8. However, RVM takes longer to train.

# --- Import packages and define the RVM classes ---
import os
import math
import time
import datetime
import random as rn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Equivalent of the `%matplotlib inline` notebook magic.
get_ipython().run_line_magic("matplotlib", "inline")
plt.rcParams.update({
    'font.sans-serif': 'SimHei',  # render Chinese characters
    'axes.unicode_minus': False,  # render the minus sign
})

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

import tensorflow as tf
import keras
from keras.models import Model, Sequential
from keras.layers import (
    GRU, Dense, Conv1D, MaxPooling1D, GlobalMaxPooling1D, Embedding,
    Dropout, Flatten, SimpleRNN, LSTM,
)
from keras.callbacks import EarlyStopping
# from tensorflow.keras import regularizers
# from keras.utils.np_utils import to_categorical
from tensorflow.keras import optimizers

# Relevance Vector Machine classes for regression and classification.
from scipy.optimize import minimize
from scipy.special import expit

from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin
from sklearn.metrics.pairwise import (
    linear_kernel,
    rbf_kernel,
    polynomial_kernel
)
from sklearn.multiclass import OneVsOneClassifier
from sklearn.utils.validation import check_X_y


class BaseRVM(BaseEstimator):

    """Base Relevance Vector Machine class.

    Implementation of Mike Tipping's Relevance Vector Machine using the
    scikit-learn API. Subclasses add a ``_posterior`` method (posterior over
    the weights) and a ``predict`` method for classification or regression.

    Parameters
    ----------
    kernel : {'linear', 'rbf', 'poly'} or callable
        Kernel used to build the design matrix. A callable takes ``(x, y)``
        and must return a 2-D matrix with one row per row of ``x``.
    degree : int
        Degree of the polynomial kernel.
    coef1 : float or None
        ``gamma`` argument passed to the rbf and polynomial kernels.
    coef0 : float
        Independent term of the polynomial kernel.
    n_iter : int
        Maximum number of re-estimation iterations.
    tol : float
        Iteration stops once the largest change in alpha is below ``tol``
        (checked from the third iteration on).
    alpha : float
        Initial value of every weight precision.
    threshold_alpha : float
        Basis functions whose alpha is not below this value are pruned.
    beta : float
        Initial noise precision.
    beta_fixed : bool
        If True, beta is not re-estimated.
    bias_used : bool
        If True, a column of ones is appended to the design matrix.
    verbose : bool
        Print the model state after every iteration.

    Fitted attributes include ``relevance_``, ``alpha_``, ``beta_``, ``m_``,
    ``sigma_``, ``bias`` and ``bias_used_``. ``bias_used_`` records whether
    the bias column survived pruning; the ``bias_used`` hyper-parameter is
    never modified, so ``get_params``/``clone`` and refits see the value the
    estimator was constructed with.
    """

    def __init__(
        self,
        kernel='rbf',
        degree=3,
        coef1=None,
        coef0=0.0,
        n_iter=3000,
        tol=1e-3,
        alpha=1e-6,
        threshold_alpha=1e9,
        beta=1.e-6,
        beta_fixed=False,
        bias_used=True,
        verbose=False
    ):
        """Copy params to object properties, no validation."""
        self.kernel = kernel
        self.degree = degree
        self.coef1 = coef1
        self.coef0 = coef0
        self.n_iter = n_iter
        self.tol = tol
        self.alpha = alpha
        self.threshold_alpha = threshold_alpha
        self.beta = beta
        self.beta_fixed = beta_fixed
        self.bias_used = bias_used
        self.verbose = verbose

    def get_params(self, deep=True):
        """Return parameters as a dictionary."""
        params = {
            'kernel': self.kernel,
            'degree': self.degree,
            'coef1': self.coef1,
            'coef0': self.coef0,
            'n_iter': self.n_iter,
            'tol': self.tol,
            'alpha': self.alpha,
            'threshold_alpha': self.threshold_alpha,
            'beta': self.beta,
            'beta_fixed': self.beta_fixed,
            'bias_used': self.bias_used,
            'verbose': self.verbose
        }
        return params

    def set_params(self, **parameters):
        """Set parameters using kwargs."""
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

    def _apply_kernel(self, x, y):
        """Apply the selected kernel function to the data.

        Returns the kernel matrix between the rows of ``x`` and ``y``, with a
        trailing column of ones when a bias term is active.

        Raises
        ------
        ValueError
            If the kernel name is unknown or a custom kernel returns a matrix
            of the wrong shape.
        """
        if self.kernel == 'linear':
            phi = linear_kernel(x, y)
        elif self.kernel == 'rbf':
            phi = rbf_kernel(x, y, self.coef1)
        elif self.kernel == 'poly':
            phi = polynomial_kernel(x, y, self.degree, self.coef1, self.coef0)
        elif callable(self.kernel):
            phi = self.kernel(x, y)
            if len(phi.shape) != 2:
                raise ValueError(
                    "Custom kernel function did not return 2D matrix"
                )
            if phi.shape[0] != x.shape[0]:
                raise ValueError(
                    "Custom kernel function did not return matrix with rows"
                    " equal to number of data points."
                )
        else:
            raise ValueError("Kernel selection is invalid.")

        # After fitting, follow the fitted flag: pruning may have removed the
        # bias column. Before any fit, fall back to the hyper-parameter.
        if getattr(self, 'bias_used_', self.bias_used):
            phi = np.append(phi, np.ones((phi.shape[0], 1)), axis=1)

        return phi

    def _prune(self):
        """Remove basis functions whose alpha reached ``threshold_alpha``."""
        keep_alpha = self.alpha_ < self.threshold_alpha

        # Never prune everything: keep the first basis function (and the bias).
        if not np.any(keep_alpha):
            keep_alpha[0] = True
            if self.bias_used_:
                keep_alpha[-1] = True

        if self.bias_used_:
            if not keep_alpha[-1]:
                self.bias_used_ = False
            # The bias column has no matching relevance vector.
            self.relevance_ = self.relevance_[keep_alpha[:-1]]
        else:
            self.relevance_ = self.relevance_[keep_alpha]

        self.alpha_ = self.alpha_[keep_alpha]
        self.alpha_old = self.alpha_old[keep_alpha]
        self.gamma = self.gamma[keep_alpha]
        self.phi = self.phi[:, keep_alpha]
        self.sigma_ = self.sigma_[np.ix_(keep_alpha, keep_alpha)]
        self.m_ = self.m_[keep_alpha]

    def fit(self, X, y):
        """Fit the model to the training data.

        Alternates the subclass posterior update with re-estimation of the
        weight precisions (alpha) and, unless ``beta_fixed``, the noise
        precision (beta), pruning basis functions after every iteration.

        Returns
        -------
        self
        """
        X, y = check_X_y(X, y)

        n_samples = X.shape[0]

        # Fitted copy of the flag; the hyper-parameter itself stays untouched.
        self.bias_used_ = self.bias_used

        self.phi = self._apply_kernel(X, X)

        n_basis_functions = self.phi.shape[1]

        self.relevance_ = X
        self.y = y

        self.alpha_ = self.alpha * np.ones(n_basis_functions)
        self.beta_ = self.beta

        self.m_ = np.zeros(n_basis_functions)

        self.alpha_old = self.alpha_

        for i in range(self.n_iter):
            self._posterior()

            self.gamma = 1 - self.alpha_*np.diag(self.sigma_)
            # A weight mean of exactly 0 yields an infinite (or NaN) alpha;
            # _prune removes such basis functions, so the warning is expected.
            with np.errstate(divide='ignore', invalid='ignore'):
                self.alpha_ = self.gamma/(self.m_ ** 2)

            if not self.beta_fixed:
                self.beta_ = (n_samples - np.sum(self.gamma))/(
                    np.sum((y - np.dot(self.phi, self.m_)) ** 2))

            self._prune()

            if self.verbose:
                print("Iteration: {}".format(i))
                print("Alpha: {}".format(self.alpha_))
                print("Beta: {}".format(self.beta_))
                print("Gamma: {}".format(self.gamma))
                print("m: {}".format(self.m_))
                print("Relevance Vectors: {}".format(self.relevance_.shape[0]))
                print()

            delta = np.amax(np.absolute(self.alpha_ - self.alpha_old))

            if delta < self.tol and i > 1:
                break

            self.alpha_old = self.alpha_

        if self.bias_used_:
            self.bias = self.m_[-1]
        else:
            self.bias = None

        return self


class RVR(BaseRVM, RegressorMixin):

    """Relevance Vector Machine Regression.

    Implementation of Mike Tipping's Relevance Vector Machine for regression
    using the scikit-learn API.
    """

    def _posterior(self):
        """Compute the posterior distribution over weights.

        Sigma = (diag(alpha) + beta * Phi^T Phi)^-1 and m = beta * Sigma Phi^T y.
        """
        i_s = np.diag(self.alpha_) + self.beta_ * np.dot(self.phi.T, self.phi)
        self.sigma_ = np.linalg.inv(i_s)
        self.m_ = self.beta_ * np.dot(self.sigma_, np.dot(self.phi.T, self.y))

    def predict(self, X, eval_MSE=False):
        """Evaluate the RVR model at X.

        Returns the predictive mean ``phi(X) @ m``. With ``eval_MSE=True``
        also returns, per sample, the predictive variance
        ``1/beta + phi(x)^T Sigma phi(x)``.
        """
        phi = self._apply_kernel(X, self.relevance_)

        y = np.dot(phi, self.m_)

        if eval_MSE:
            # Row-wise diagonal of phi @ Sigma @ phi.T, without building the
            # (n_samples, n_samples) matrix. Taking column 0 of that matrix
            # would give each point's covariance with the first sample instead.
            MSE = (1/self.beta_) + np.sum(np.dot(phi, self.sigma_) * phi, axis=1)
            return y, MSE
        return y
\n", 857 | "\n", 870 | "\n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | "
循环次数平均放电电压平均放电温度等压降放电时间容量
013.52982932.5723281622.6251.856487
123.53732032.7252351661.0781.846327
233.54373732.6428621661.9221.835349
343.54366632.5148761662.9061.835263
453.54234332.3823491661.9381.834646
\n", 924 | "
# Load the B0005 battery data set and display the first five rows.
# Columns (from the recorded df.head()): 循环次数 (cycle number),
# 平均放电电压 (mean discharge voltage), 平均放电温度 (mean discharge
# temperature), 等压降放电时间 (discharge time for an equal voltage drop),
# 容量 (capacity).
df = pd.read_csv("B0005.csv")
df.head()  # first five rows

# Disabled experiment: capacity (column index 4) of the first 20 rows.
# capacity_original_half = np.array(df)[:20, 4]

# Earlier windowing approach, kept in the notebook as non-executed markdown:
#
#     df_for_training=df[:120]
#     df_for_testing=df[:]
#     print(df_for_training.shape)
#     print(df_for_testing.shape)
#
#     scaler = MinMaxScaler(feature_range=(0,1))
#     df_for_training_scaled = scaler.fit_transform(df_for_training)
#     df_for_testing_scaled = scaler.transform(df_for_testing)
#     df_for_training_scaled
#
#     def createXY(dataset,n_past):
#         dataX = []
#         dataY = []
#         for i in range(n_past,len(dataset)):
#             dataX.append(dataset[i-n_past:i,0:dataset.shape[1]])
#             dataY.append(dataset[i,4])
#         return np.array(dataX),np.array(dataY)
#
#     trainX, trainY = createXY(df_for_training_scaled,20)
#     testX, testY = createXY(df_for_testing_scaled,20)

# --- Train / test split ---
# NOTE(review): test_split (30% of the rows) is computed but unused; the
# training slice is hard-coded to the first 120 rows, and df_for_testing is the
# whole frame, so it overlaps the training rows.
test_split = round(len(df) * 0.30)
df_for_training = df.iloc[:120]
df_for_testing = df.iloc[0:]

df_for_testing.shape   # recorded: (168, 5)
df_for_training.shape  # recorded: (120, 5)

# Feature matrix: every column except the last one (capacity), all rows.
df_for_training_1 = df.iloc[:, :-1]
df_for_training_1.shape  # recorded: (168, 4)
df_for_training_1.shape  # recorded: (168, 4)
# Regression target: the last column (容量 / capacity) for every row.
# Despite its name, df_for_testing_1 is the target vector, not a test split.
df_for_testing_1 = df.iloc[:, -1]
df_for_testing_1  # recorded: Series "容量", length 168, float64

# Standardize the features. The scaler is fitted on every row of
# df_for_training_1, so there is no hold-out at this point.
scaler_s = StandardScaler()
# Disabled lines, kept verbatim:
# scaler_s.fit(df_for_testing_1)
# df_for_testing_1_s = scaler_s.transform(df_for_trainging_1)
df_for_training_1_s = scaler_s.fit(df_for_training_1).transform(df_for_training_1)
df_for_training_1_s  # recorded: standardized array with 4 columns
# --- Build the RVM models, one per kernel ---
# Each score is R^2 (RegressorMixin.score) on the same rows the model was
# fitted on, i.e. in-sample fit quality rather than generalization.
# Recorded output of an earlier run: 0.9966711789757259, 0.9980897747654969,
# 0.9398968240706602
rvm_model_linear = RVR(kernel="linear")
rvm_model_linear.fit(df_for_training_1_s, df_for_testing_1)
print(rvm_model_linear.score(df_for_training_1_s, df_for_testing_1))

rvm_model_rbf = RVR(kernel="rbf")
# Same features/target as the other kernels. The Boston target `y` has a
# different length, and `X_s` is not defined anywhere.
rvm_model_rbf.fit(df_for_training_1_s, df_for_testing_1)
print(rvm_model_rbf.score(df_for_training_1_s, df_for_testing_1))

rvm_model_poly = RVR(kernel="poly")
rvm_model_poly.fit(df_for_training_1_s, df_for_testing_1)
print(rvm_model_poly.score(df_for_training_1_s, df_for_testing_1))

# The RBF kernel scores best here, so the RBF RVM is used from now on
# (this comparison is in-sample; see the note above).

# Helper functions: fix the random seeds and compute regression metrics.


def set_my_seed():
    """Seed Python, NumPy and TensorFlow random generators."""
    # PYTHONHASHSEED only affects interpreters started after this point; the
    # running process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = '0'
    np.random.seed(1)
    rn.seed(12345)
    tf.random.set_seed(123)


def evaluation(y_test, y_predict):
    """Return (MAE, RMSE, MAPE, R^2) for regression predictions.

    MAPE is a fraction (not multiplied by 100); it divides by ``y_test``, so
    true values must be non-zero.
    """
    mae = mean_absolute_error(y_test, y_predict)
    mse = mean_squared_error(y_test, y_predict)
    rmse = np.sqrt(mse)
    mape = (abs(y_predict - y_test) / y_test).mean()
    r_2 = r2_score(y_test, y_predict)
    return mae, rmse, mape, r_2  # mse is not part of the returned tuple


# Functions that build the training and test sets.


def build_sequences(text, window_size=20):
    """Turn a 1-D sequence (e.g. capacity per cycle) into sliding windows.

    Sample ``i`` is ``text[i:i+window_size]`` and its target is
    ``text[i+window_size]``. Indexing is positional, so a pandas Series whose
    index does not start at 0 (such as a tail slice) behaves like a list.

    Returns
    -------
    (x, y) : tuple of numpy arrays
        ``x`` has one row per window; both are empty when
        ``len(text) <= window_size``.
    """
    values = np.asarray(text)
    x, y = [], []
    for i in range(len(values) - window_size):
        x.append(values[i:i + window_size])
        y.append(values[i + window_size])
    return np.array(x), np.array(y)


def get_traintest(data, train_size=None, window_size=20):
    """Split ``data`` into windowed train and test sets.

    The test part starts ``window_size`` points before ``train_size`` so that
    its first window ends where the training data ends.

    ``train_size`` defaults to ``len(data)``, which equals the old default
    ``len(df)`` when ``data`` is a column of ``df``. With that default the
    test part holds only ``window_size`` points and yields no test samples.
    """
    if train_size is None:
        train_size = len(data)
    train = data[:train_size]
    test = data[train_size - window_size:]
    X_train, y_train = build_sequences(train, window_size=window_size)
    X_test, y_test = build_sequences(test, window_size=window_size)
    return X_train, y_train, X_test, y_test


# Shape checks of the global X_train / y_train / X_test / y_test. The recorded
# values (404, 13), (404,), (102, 13), (102,) are consistent with the 80/20
# Boston split earlier in the notebook, not with get_traintest output.
X_train.shape
y_train.shape
X_test.shape
y_test.shape
| }, 1697 | { 1698 | "cell_type": "code", 1699 | "execution_count": null, 1700 | "metadata": {}, 1701 | "outputs": [], 1702 | "source": [] 1703 | }, 1704 | { 1705 | "cell_type": "code", 1706 | "execution_count": null, 1707 | "metadata": {}, 1708 | "outputs": [], 1709 | "source": [] 1710 | }, 1711 | { 1712 | "cell_type": "markdown", 1713 | "metadata": {}, 1714 | "source": [ 1715 | "# n_past是在预测下一个目标值时将在过去查看的步骤数\n", 1716 | "def createXY(dataset,n_past):\n", 1717 | " dataX = []\n", 1718 | " dataY = []\n", 1719 | " for i in range(n_past,len(dataset)):\n", 1720 | " dataX.append(dataset[i - n_past:i,0:dataset.shape[1]])\n", 1721 | " dataY.append(dataset[i,4])\n", 1722 | " return np.array(dataX),np.array(dataY)\n", 1723 | "\n", 1724 | "trainX , trainY = createXY(df_for_training_scaled,20)\n", 1725 | "testX , testY = createXY(df_for_testing_scaled,20)" 1726 | ] 1727 | }, 1728 | { 1729 | "cell_type": "markdown", 1730 | "metadata": {}, 1731 | "source": [ 1732 | "定义模型函数(5种神经网络模型),还有画损失图和拟合图对比的函数。" 1733 | ] 1734 | }, 1735 | { 1736 | "cell_type": "code", 1737 | "execution_count": 127, 1738 | "metadata": {}, 1739 | "outputs": [], 1740 | "source": [ 1741 | "def build_model(X_train,mode='LSTM',hidden_dim=[32,16]):\n", 1742 | " set_my_seed()\n", 1743 | " model = Sequential()\n", 1744 | " if mode=='RNN':\n", 1745 | " #RNN\n", 1746 | " model.add(SimpleRNN(hidden_dim[0],return_sequences=True, input_shape=(X_train.shape[-2],X_train.shape[-1])))\n", 1747 | " model.add(SimpleRNN(hidden_dim[1])) \n", 1748 | " \n", 1749 | " elif mode=='MLP':\n", 1750 | " model.add(Dense(hidden_dim[0],activation='relu',input_shape=(X_train.shape[-1],)))\n", 1751 | " model.add(Dense(hidden_dim[1],activation='relu'))\n", 1752 | " \n", 1753 | " elif mode=='LSTM':\n", 1754 | " # LSTM\n", 1755 | " model.add(LSTM(hidden_dim[0],return_sequences=True, input_shape=(X_train.shape[-2],X_train.shape[-1])))\n", 1756 | " model.add(LSTM(hidden_dim[1]))\n", 1757 | " elif mode=='GRU':\n", 1758 | " #GRU\n", 1759 | " 
model.add(GRU(hidden_dim[0],return_sequences=True, input_shape=(X_train.shape[-2],X_train.shape[-1])))\n", 1760 | " model.add(GRU(hidden_dim[1]))\n", 1761 | " elif mode=='CNN':\n", 1762 | " #一维卷积\n", 1763 | " model.add(Conv1D(hidden_dim[0], kernel_size=3, padding='causal', strides=1, activation='relu', dilation_rate=1, input_shape=(X_train.shape[-2],X_train.shape[-1])))\n", 1764 | " #model.add(MaxPooling1D())\n", 1765 | " model.add(Conv1D(hidden_dim[0], kernel_size=3, padding='causal', strides=1, activation='relu', dilation_rate=2))\n", 1766 | " #model.add(MaxPooling1D())\n", 1767 | " model.add(Conv1D(hidden_dim[0], kernel_size=3, padding='causal', strides=1, activation='relu', dilation_rate=4))\n", 1768 | " #GlobalMaxPooling1D()\n", 1769 | " model.add(Flatten())\n", 1770 | " \n", 1771 | " model.add(Dense(1))\n", 1772 | " model.compile(optimizer='Adam', loss='mse',metrics=[tf.keras.metrics.RootMeanSquaredError(),\"mape\",\"mae\"])\n", 1773 | " return model\n", 1774 | " \n", 1775 | "def plot_loss(hist,imfname):\n", 1776 | " plt.subplots(1,4,figsize=(16,2))\n", 1777 | " for i,key in enumerate(hist.history.keys()):\n", 1778 | " n=int(str('14')+str(i+1))\n", 1779 | " plt.subplot(n)\n", 1780 | " plt.plot(hist.history[key], 'k', label=f'Training {key}')\n", 1781 | " plt.title(f'{imfname} Training {key}')\n", 1782 | " plt.xlabel('Epochs')\n", 1783 | " plt.ylabel(key)\n", 1784 | " plt.legend()\n", 1785 | " plt.tight_layout()\n", 1786 | " plt.show()\n", 1787 | "def plot_fit(y_test, y_pred,name):\n", 1788 | " plt.figure(figsize=(4,2))\n", 1789 | " plt.plot(y_test, color=\"red\", label=\"actual\")\n", 1790 | " plt.plot(y_pred, color=\"blue\", label=\"predict\")\n", 1791 | " plt.title(f\"{name}拟合值和真实值对比\")\n", 1792 | " plt.xlabel(\"Time\")\n", 1793 | " plt.ylabel(name)\n", 1794 | " plt.legend()\n", 1795 | " plt.show()" 1796 | ] 1797 | }, 1798 | { 1799 | "cell_type": "code", 1800 | "execution_count": null, 1801 | "metadata": {}, 1802 | "outputs": [], 1803 | "source": [] 1804 | 
}, 1805 | { 1806 | "cell_type": "code", 1807 | "execution_count": null, 1808 | "metadata": {}, 1809 | "outputs": [], 1810 | "source": [] 1811 | }, 1812 | { 1813 | "cell_type": "code", 1814 | "execution_count": null, 1815 | "metadata": {}, 1816 | "outputs": [], 1817 | "source": [] 1818 | }, 1819 | { 1820 | "cell_type": "code", 1821 | "execution_count": null, 1822 | "metadata": {}, 1823 | "outputs": [], 1824 | "source": [] 1825 | }, 1826 | { 1827 | "cell_type": "code", 1828 | "execution_count": null, 1829 | "metadata": {}, 1830 | "outputs": [], 1831 | "source": [] 1832 | }, 1833 | { 1834 | "cell_type": "code", 1835 | "execution_count": null, 1836 | "metadata": {}, 1837 | "outputs": [], 1838 | "source": [] 1839 | }, 1840 | { 1841 | "cell_type": "code", 1842 | "execution_count": null, 1843 | "metadata": {}, 1844 | "outputs": [], 1845 | "source": [] 1846 | }, 1847 | { 1848 | "cell_type": "code", 1849 | "execution_count": null, 1850 | "metadata": {}, 1851 | "outputs": [], 1852 | "source": [] 1853 | }, 1854 | { 1855 | "cell_type": "code", 1856 | "execution_count": null, 1857 | "metadata": {}, 1858 | "outputs": [], 1859 | "source": [] 1860 | }, 1861 | { 1862 | "cell_type": "markdown", 1863 | "metadata": {}, 1864 | "source": [ 1865 | "定义五种神经网络模型,以及画损失曲线图和拟合对比图的函数" 1866 | ] 1867 | }, 1868 | { 1869 | "cell_type": "code", 1870 | "execution_count": 101, 1871 | "metadata": {}, 1872 | "outputs": [], 1873 | "source": [ 1874 | "def build_model(X_train,mode='LSTM',hidden_dim=[32,16]):\n", 1875 | " set_my_seed()\n", 1876 | " model = Sequential()\n", 1877 | " if mode=='RNN':\n", 1878 | " #RNN\n", 1879 | " model.add(SimpleRNN(hidden_dim[0],return_sequences=True, input_shape=(X_train.shape[-2],X_train.shape[-1])))\n", 1880 | " model.add(SimpleRNN(hidden_dim[1])) \n", 1881 | " \n", 1882 | " elif mode=='MLP':\n", 1883 | " model.add(Dense(hidden_dim[0],activation='relu',input_shape=(X_train.shape[-1],)))\n", 1884 | " model.add(Dense(hidden_dim[1],activation='relu'))\n", 1885 | " \n", 1886 | " 
elif mode=='LSTM':\n", 1887 | " # LSTM\n", 1888 | " model.add(LSTM(hidden_dim[0],return_sequences=True, input_shape=(X_train.shape[-2],X_train.shape[-1])))\n", 1889 | " model.add(LSTM(hidden_dim[1]))\n", 1890 | " elif mode=='GRU':\n", 1891 | " #GRU\n", 1892 | " model.add(GRU(hidden_dim[0],return_sequences=True, input_shape=(X_train.shape[-2],X_train.shape[-1])))\n", 1893 | " model.add(GRU(hidden_dim[1]))\n", 1894 | " elif mode=='CNN':\n", 1895 | " #一维卷积\n", 1896 | " model.add(Conv1D(hidden_dim[0], kernel_size=3, padding='causal', strides=1, activation='relu', dilation_rate=1, input_shape=(X_train.shape[-2],X_train.shape[-1])))\n", 1897 | " #model.add(MaxPooling1D())\n", 1898 | " model.add(Conv1D(hidden_dim[0], kernel_size=3, padding='causal', strides=1, activation='relu', dilation_rate=2))\n", 1899 | " #model.add(MaxPooling1D())\n", 1900 | " model.add(Conv1D(hidden_dim[0], kernel_size=3, padding='causal', strides=1, activation='relu', dilation_rate=4))\n", 1901 | " #GlobalMaxPooling1D()\n", 1902 | " model.add(Flatten())\n", 1903 | " \n", 1904 | " model.add(Dense(1))\n", 1905 | " model.compile(optimizer='Adam', loss='mse',metrics=[tf.keras.metrics.RootMeanSquaredError(),\"mape\",\"mae\"])\n", 1906 | " return model\n", 1907 | " \n", 1908 | "def plot_loss(hist,imfname):\n", 1909 | " plt.subplots(1,4,figsize=(16,2))\n", 1910 | " for i,key in enumerate(hist.history.keys()):\n", 1911 | " n=int(str('14')+str(i+1))\n", 1912 | " plt.subplot(n)\n", 1913 | " plt.plot(hist.history[key], 'k', label=f'Training {key}')\n", 1914 | " plt.title(f'{imfname} Training {key}')\n", 1915 | " plt.xlabel('Epochs')\n", 1916 | " plt.ylabel(key)\n", 1917 | " plt.legend()\n", 1918 | " plt.tight_layout()\n", 1919 | " plt.show()\n", 1920 | "def plot_fit(y_test, y_pred,name):\n", 1921 | " plt.figure(figsize=(4,2))\n", 1922 | " plt.plot(y_test, color=\"red\", label=\"actual\")\n", 1923 | " plt.plot(y_pred, color=\"blue\", label=\"predict\")\n", 1924 | " plt.title(f\"{name}拟合值和真实值对比\")\n", 1925 | " 
plt.xlabel(\"Time\")\n", 1926 | " plt.ylabel(name)\n", 1927 | " plt.legend()\n", 1928 | " plt.show()" 1929 | ] 1930 | }, 1931 | { 1932 | "cell_type": "markdown", 1933 | "metadata": {}, 1934 | "source": [ 1935 | "定义训练函数" 1936 | ] 1937 | }, 1938 | { 1939 | "cell_type": "code", 1940 | "execution_count": 102, 1941 | "metadata": {}, 1942 | "outputs": [], 1943 | "source": [ 1944 | "df_eval_all=pd.DataFrame(columns=['MAE','RMSE','MAPE','R2'])\n", 1945 | "df_preds_all=pd.DataFrame()\n", 1946 | "def train_fuc(mode='LSTM',window_size=64,batch_size=32,epochs=50,hidden_dim=[32,16],train_ratio=0.8,kernel=\"rbf\",show_loss=True,show_fit=True):\n", 1947 | " df_preds=pd.DataFrame()\n", 1948 | " #预测每一列\n", 1949 | " for i,col_name in enumerate(df.columns):\n", 1950 | " print(f'正在处理变量:{col_name}')\n", 1951 | " #准备数据\n", 1952 | " data=df[col_name]\n", 1953 | " train_size=int(len(data)*train_ratio)\n", 1954 | " X_train,y_train,X_test,y_test=get_traintest(data.values,window_size=window_size,train_size=train_size)\n", 1955 | " #print(X_train.shape,y_train.shape,X_test.shape,y_test.shape)\n", 1956 | " #归一化\n", 1957 | " scaler = MinMaxScaler() \n", 1958 | " scaler = scaler.fit(X_train)\n", 1959 | " X_train=scaler.transform(X_train) ; X_test=scaler.transform(X_test)\n", 1960 | " \n", 1961 | " y_train_orage=y_train.copy() ; y_scaler = MinMaxScaler() \n", 1962 | " y_scaler = y_scaler.fit(y_train.reshape(-1,1))\n", 1963 | " y_train=y_scaler.transform(y_train.reshape(-1,1))\n", 1964 | " \n", 1965 | " if mode!='MLP':\n", 1966 | " X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 1))\n", 1967 | " #构建模型\n", 1968 | " s = time.time()\n", 1969 | " set_my_seed()\n", 1970 | " model=build_model(X_train=X_train,mode=mode,hidden_dim=hidden_dim)\n", 1971 | " earlystop = EarlyStopping(monitor='loss', min_delta=0, patience=5)\n", 1972 | " hist=model.fit(X_train, y_train,batch_size=batch_size,epochs=epochs,callbacks=[earlystop],verbose=0)\n", 1973 | " if show_loss:\n", 1974 | " 
plot_loss(hist,col_name)\n", 1975 | " \n", 1976 | " #预测\n", 1977 | " y_pred = model.predict(X_test)\n", 1978 | " y_pred = y_scaler.inverse_transform(y_pred)\n", 1979 | " #print(f'真实y的形状:{y_test.shape},预测y的形状:{y_pred.shape}')\n", 1980 | " if show_fit:\n", 1981 | " plot_fit(y_test, y_pred,name=col_name)\n", 1982 | " e=time.time()\n", 1983 | " print(f\"运行时间为{round(e-s,3)}\")\n", 1984 | " df_preds[col_name]=y_pred.reshape(-1,)\n", 1985 | " \n", 1986 | " s=list(evaluation(y_test, y_pred))\n", 1987 | " s=[round(i,3) for i in s]\n", 1988 | " print(f'{col_name}变量的预测效果为:MAE:{s[0]},RMSE:{s[1]},MAPE:{s[2]},R2:{s[3]}')\n", 1989 | " print(\"=================================================================================\")\n", 1990 | " \n", 1991 | " \n", 1992 | " X_pred=df_preds.iloc[:,:-1]\n", 1993 | " X_pred_s = scaler_s.transform(X_pred)\n", 1994 | " y_direct=df_preds.iloc[:,-1] \n", 1995 | " \n", 1996 | " if kernel==\"rbf\":\n", 1997 | " y_nodirect=rvm_model_rbf.predict(X_pred_s)\n", 1998 | " if kernel==\"linear\":\n", 1999 | " y_nodirect=rvm_model_linear.predict(X_pred_s) \n", 2000 | " if kernel==\"ploy\":\n", 2001 | " y_nodirect=rvm_model_ploy.predict(X_pred_s) \n", 2002 | " \n", 2003 | " score1=list(evaluation(y_test, y_direct))\n", 2004 | " score2=list(evaluation(y_test, y_nodirect))\n", 2005 | " df_preds_all[mode]=y_direct\n", 2006 | " df_preds_all[f'{mode}+RVM']=y_nodirect\n", 2007 | " df_eval_all.loc[f'{mode}',:]=score1\n", 2008 | " df_eval_all.loc[f'{mode}+RVM',:]=score2\n", 2009 | " print(score2)" 2010 | ] 2011 | }, 2012 | { 2013 | "cell_type": "markdown", 2014 | "metadata": {}, 2015 | "source": [ 2016 | "开始训练\n", 2017 | "\n", 2018 | "初始化超参数" 2019 | ] 2020 | }, 2021 | { 2022 | "cell_type": "code", 2023 | "execution_count": 103, 2024 | "metadata": {}, 2025 | "outputs": [], 2026 | "source": [ 2027 | "window_size=64\n", 2028 | "batch_size=32\n", 2029 | "epochs=50\n", 2030 | "hidden_dim=[32,16]\n", 2031 | "train_ratio=0.8\n", 2032 | "kernel=\"rbf\"\n", 2033 | 
"show_fit=True\n", 2034 | "show_loss=True\n", 2035 | "mode='LSTM' #RNN,GRU,CNN" 2036 | ] 2037 | }, 2038 | { 2039 | "cell_type": "markdown", 2040 | "metadata": {}, 2041 | "source": [ 2042 | "LSTM预测" 2043 | ] 2044 | }, 2045 | { 2046 | "cell_type": "code", 2047 | "execution_count": 104, 2048 | "metadata": {}, 2049 | "outputs": [ 2050 | { 2051 | "name": "stdout", 2052 | "output_type": "stream", 2053 | "text": [ 2054 | "正在处理变量:循环次数\n" 2055 | ] 2056 | }, 2057 | { 2058 | "name": "stderr", 2059 | "output_type": "stream", 2060 | "text": [ 2061 | "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py:475: DataConversionWarning: Data with input dtype int64 was converted to float64 by MinMaxScaler.\n", 2062 | " warnings.warn(msg, DataConversionWarning)\n", 2063 | "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py:475: DataConversionWarning: Data with input dtype int64 was converted to float64 by MinMaxScaler.\n", 2064 | " warnings.warn(msg, DataConversionWarning)\n" 2065 | ] 2066 | }, 2067 | { 2068 | "data": { 2069 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABHgAAACKCAYAAADR5/1CAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzs3Xd4FFX3wPHvSaFXAUFAelGKdJESSBCRRIr6ExCRKoqgoqBIUwhVREFfUbHQxAKooFioyguiECmvUkRREJAuBgTpJDm/P2bBEBJIwiazuzmf55kHdvbOnbObnbN378y9I6qKMcYYY4wxxhhjjPFfQW4HYIwxxhhjjDHGGGOujnXwGGOMMcYYY4wxxvg56+AxxhhjjDHGGGOM8XPWwWOMMcYYY4wxxhjj56yDxxhjjDHGGGOMMcbPWQePMcYYY4wxxhhjjJ+zDh5jjDHGGGOMMcYYP2cdPMYYY4wxxhhjjDF+zjp4rpKIdBCR2iKSU0TyiEgxEakuIq1FZISI3OMpFyQiIZ7/i4iEikgOEZEk9YWKSPD5ckmeCxeRly4Ty6OeOpd5YhggIvlE5E0RaeL9V2+M/7Jj1xhj3GM52BjjDyxXGX9jHTxXQUSKAOOBo8DnwDfAx8BM4CUgGFBP8TrAchE5APwPGASsA9aJyGYROSki64D1wK2eA36FiNyUaJf/B/xymZBCgGeAOCA3cK+qHgOaAXtS+ZqeEpEDInJKRI56/h+Rmm0vU99TGVU+HfF0E5EZGVW/yVgiMtVL9dixm7r6fObYNb5BRHaKSBm34zD+zXJwquuzHGyMiyxXpbo+y1U+RFT1yqXMJUQkCPgM+FRVp4jIAqATTgJoDUSq6sOenlxV1XjPdjOAN1Q1JlFd9wB3qWqnJPu4E3gBJ2GcA/YCB4CziYqtUNXHRSQH0Aa4DrgfmAsUBT4B5qjqdZ46s6vqmVS8vhlAjKq+kbZ3xreJSDcgXFW7uRyKSQcRmQZMVtW1V1GHHbtZlIgUALqp6stux+KvRGQnTg7d6XIoxk9ZDjbG+APLVcZfhbgdgB+7GaendKeIzMHpvX0FKA0UB0JEZDkQCoz1JIVQz7ZBIpJNVc8fvLcBXyXdgap+KiK3Ax2BvMBUVR1w/nkReQ3Y6nmYDWjiiasu8DNwGOjt2d86oBTwj4jUVNV/vPM2GJOpsgFfichi4ATOF2qPNNZhx27WVQB4ArAOHmPcYznYGOMPLFcZ/6SqtqRzwTkQfwUaAYuA7EAZ4HlgrOf/hTxlSwCbgD89/473rM8FxALdgKeT2UcIUBj4DSgIdME5+AH+AIp7/h8K9Mfpxf0GJwENB9YAz3rKzAbqpfK1zQAeTrJuOXAP8CmwLNH6hz2x7AUGJtkmGohO9DjcU88E4C9PrDnTUx64Fyfx/s/z2qal4nV1A2YkWdcC53LI3cCQROt7edb9CYy+0npbMuWYK510SWc9duyq/xy7ntfUG5gG/JZofUrH7v3ADmAnzhU7AB8Ah4B4nLNji66wz52ebfYDz3n+/l08zw3xvG+7gNaedUHAW8A+YBvQItH78hLO5ywW+M8V9nu+nv2efdzlWZ8HmOdZPxmnczO5930nzuc32XgSlWkG/Bd4J9H67sDvnn08mCieyZ5173s+K2Wu8BouqecKn8O0/n2TLW+L/yxYDj6/3udzsKeOufybC/cBQz3PjfDE/gfQOVEcn+P8+PwDT470PJdsbrDFFl9dLFddWB9wucqzPiBzks3Bk04iEorT+I9V1e9wDrpzOGMtD+H0qPbHOfhR1b2qWh1YgPMBetpT1VNAfuAI0F5EOiTZVSGcg76uqh7BSRLxnuduUtV9nv+XxDkonsb5AA/H+RCH4lz2B06v7varfOljcRrVd3nehxxAZ+AWoCLwlIjkuUIdDXB+fF0H5ANaprP8f4BbgTeBOE37lRyISCGccbTtgWpABxGJ8jz9AhCJ875VFpG8V1hvMpi
q7gIqA3cCFT2P08SOXb89dgcD3wH1PfEne+yKyA3AOJyzXA2BkSJSXVXvA+oBu1W1mKpeKXZw/ubfAUWAkUC4iER6Yr8RpwPiLc9nqh5wDc7f6n5gdKJ6egDPAjcBD4vINZfZZ02gFU4H5u04Z/0A+uJ8fkpw+fH5510uHoAXPTE9CiAiVXE+93U9MUSLSFGchl5toCxOB1Pxy+30MvWcd9HnMJFU/X1TKm/8h+Vgv8zBP+B08NYDHsKZP6QUEAZU8ryGFxKVr4qTf1sD00QkVypygzE+xXJVYOeqQM5JNkQr/W4FDuL0tAIEqWqCJxm08qzLB6xObmMRyYfTYO6C09t4BrgPZ7Kt71T1/ERZQ4AziZJEFZwPJ6r6d6IqrwH64fRQBuGcPTmH8wW7XEQKAnlV9fBVvWqn5/Sz8w9U9bSIdME58MM8cRQGjl+mjoPAa6qqIrIBJ+ldTkrlT+NcrpiN9E8Y3hD4UVU3woXxqFE4yflbYAxOL3Zv/fdSx5TWmwwmIhNxhtl8D3QSkShV7Z/GauzYxS+P3QWqmniS7ZSO3a3AF6q627P+E5yOkk2p3E9iq4Hmnn/jPbE2x2k4nG9A5cI5u/a95/M5EqdTpkiiej5Xz7xR4ky+mA+nYZic7UACTgNkGc6QMnA6Mt71fFank/IwMwG4QjzgnFn8NtHjZkA5YIvncU6cztSGwMeqehqYKyJ/c3kp1XPQ8/iiz2Eiqf37LkihvPEfloPxuxy82hPjak98Qar6h4g8ATwJRODMBXLep54fqkdE5CDOD6swLp8bjPE1lqsI6Fx1pfaK37IreNJJVRcBQ+FCD+/5D+JfQE/PMhbPzOoiEiwiNXE+SG/hXLI/G+gD/OOp81ecmdqreLYpjZNARnke58I5Y7xTRMKTxLMeaIwz8VclTz1bVfUUzlnXT3F+LFytmMQPRKQ8zqV0h3EOnN2pqGOHqp6f3Ts1s3ynVH4dTsLsjvMjJr00yf/PP26Dc/ljZWCzODPpX269yXh1VLWHqr6pqt1xet3TxI5dhx8euzHJrEvp2E1pfVrFJ/kXnA6UMZ6rgIrhnC3bKyKdgNdwzh49mqSexGfTLhuLqh7F+RytxGkILk203fnbqSb73S3ORI9FPf+/XDxw6fspwMxEr6ukp4wkiTnhcvFfpp6U9nu59Zf7O6ZUj/FxloMdfpaDL8mFIhKGM1Tkd5yhJ4klvvVzEE7euFJuMManWK5yBHCuCticZB08V+f8h6cOzqVy4IzNrOtZbkhUNgyYiHPJWxOcs8nPqeqSxBWq6kuqukScmdunAMNV9R/P47c86/oCU0WkdpJ4iuMcCItwLmVb71k/z7PPz6/u5SarFs58DtNwOjxKpmKbtP7YuqS853K7ckAVVa2tqqkZspCcVUBNEanm6WnvCiz0JNjNOGM+h+H0AldIaX06923S7qiIdBSRciJyH86XXHrYsRugxy5Oh0grESkhItfhXGJ8/m8VCxTyDBfIJSI507Hfr3Ausc4nIsVxOm8K4FxevBiYD7RNsk2q3zcRuRXnb/IpzjCkm0VEcM5G3ev5PHVPtMkx4HrP/x/C+RxzhXiSswyIFJFi4gw73YDTAF0D3C0i2UWkLc6Zu/TUk1Yp/X1NYLAc7P85uD5OfpiFc3VdYneKSEERqYUz/OQ3vJcbjMlMlqsCN1cFbE6yDp6rE4zzHvbk38vz8uOcLX0UaOd5HlVdrqrNVHWGqv6tqi+r6sxEdSU+2wFO720u4H0RqYbzo+K4qr6ozu1pewJLRKSvp1cZnIk0Z+MknqFAVRG5G5iDM/7zLU8vpjednxH+IM5EWDtwepUz2m6c93afiOwUkS9EJDUJ5yKqGovzw+Fj4CfgI1VdoKongddxkvNunLPpa1Ja740XZFKlK858IJOAGjiXvaaHHbuBe+z+gtMx8i3O33a4qm7ybPMPzsSI2/l3rHda97sAp/NlM84cMI+p6l84kxW299SbDSgiIle6JDk5K3D
O9O3FyS9Pe85qvep5fh/OJd/nzQbqisgSnKt3zs9LlaZ4VHUzzhnE1TiXK7+mqj966j8/WWoPnAmqU3SZetIkpb9vWusxPstysJ/nYJxjsxpOTqoCHBeR8/H/gJOD5wM9VPWUt3KDMZnMclWA5qqAzknqAzM9++uCc9Z0Hc7s3zd41v2e6PlbgLdTUc9M4M5k1ufGuXPKT0CvZJ6/EecLtAZOgvgvzuWAxTzP349zUFbyPA7HObud2+33zgvvfRtgiuf/wThDER53Oy5b/GOxY9fV996OXe+8j+p2DLbYkt7FcrCr732G5mCS3CHHFlv8ebFc5ep7b+3FdC7iedNMOohIMFBAnTON59cVUdVDXt6PqP2hLiIi1wPv4fQgx+NcUdNDVfe7GpjJUCIyVVUf8EI9duy6xNeOXXEmPU7qT1W9KQP3eT2wNpmn1qvqHamsQ1U16dnATOPG+2YCh+Vg92R0DhaRaABVjfZGfca4yXKVe3ytvehPrIPHGOM3RGQaMFk9dyQyxhhjjDHGGOOw26QbY/xJNmCpZ76REzjDVHq4HJMxxhhjjDHGuM6u4DHG+A3P7SQvoqq7kitrjDHGGGOMMVmJX3TwFC5cWMuUKeN2GMb4jfXr1/+lqkXcjiMrsPxkTNpYfso8lp+MSRvLT5nH8pMxaZPa/OQXQ7TKlCnDunXr3A7DGL8hIgF5VYu3Jln2JstPxqRNoOYnX2T5yZi0sfyUeSw/GZM2qc1PQRkdiDHGeJGISD23g0grf7hS0hhjjDHGGOPfAqKD53//+x+NGzdm69atbodijMlY2YCvRORDEZnuuauWT1u9ejV169blwIHk7iptjDHmcj777DPatWvHK6+8wubNm63D3Jgs4osvvqB+/fqcPHnS7VCM8SsB0cFTqFAhvvvuOxYsWOB2KMaYjDUUuAkYAEQDI1yNJhUKFy7Mhg0bGDt2rNuhGGOMX9m/fz9du3bliy++4PHHH6d69eoUK1aMjh078vbbb3Po0CG3QzTGZJD8+fOzZs0a3n77bbdDMcav+MUcPFdSunRpqlatyoIFC+jXr5/b4QS0c+fOsWfPHk6fPu12KAbIkSMHJUuWJDQ01O1QMoWq7hKRgkBx4DBw0OWQLpHcMbJ8+XKOHz/Opk2bCAkJiLRrzBVltfxkvEtV6d27N6dPn2bDhg1ky5aNZcuWXVhmz57NuHHj2LBhA3ny5HE73Cuy9pNvsfzk+8LCwggPD+f555+nV69e5MiRw+2QshTLWe652vwUML80oqKiePnllzl+/LhffNH7qz179pA3b17KlCmDiLgdTpamqsTGxrJnzx7Kli3rdjiZQkQGAncBuYDngduBLq4GlURyx8jZs2fZtGkTefLkyTJ/K5O1ZcX8ZLzrww8/ZP78+bzwwgtUqlQJgB49etCjRw9UlcWLFxMVFcWAAQOYPHmyy9FembWffIflJ/8xfPhwIiIimDJlCo8++qjb4WQplrPc4Y38FBBDtMDp4Dl37hxff/2126EEtNOnT1OoUCE70H2AiFCoUKGs1rPeWlVvAWJV9X2gnNsBJZXcMZItWzauvfZaYmNjOXXqlIvRGZM5smh+Ml5y6NAhHn30UW6++eZkr8wWEVq2bEn//v154403WLx4sQtRpo21n3yH5Sf/0bRpU8LCwhg3bhxnzpxxO5wsxXKWO7yRnzKkg0dEporIahF55grliorID97YZ6NGjcibN6/Nw5MJ7ED3HVnwb3FMRLoAOUSkKfC32wElJ7m/S7FixQgKCmLfvn0uRGRM5suC+cl4yWOPPcbRo0eZNm0awcHBKZYbPXo0N954Iw888AB//+2TXwcXsWPCd9jfwj+ICMOHD2fv3r1MnTrV7XCyHDtO3HG177vXO3hE5G4gWFUbAOVEpOJlir8I5PTGfkNDQ2nRogULFiywOywYE7i6AbWAI0Bb4AFXo0mD0NBQihUrxpEjRzhx4oTb4RhjjE/65JNPmDNnDsOGDaNq1aq
XLZsjRw7eeecdDhw4wBNPPJFJERpjMlOzZs1o2LAhzz33nF3FY0wqZMQVPOHAh57/LwEaJ1dIRJoBJwCv3Ts4KiqKPXv2sGnTJm9VaXzMq6++Snh4ODlz5iQ8PJxPPvkkzXWkthHo7cZieHi4V+vLilT1T1Xtp6pRqtpfVQ8CiIjvT8AAFC1alJCQEPbu3et2KMYY43MOHz5Mnz59qFmzJgMHDkzVNvXq1WPIkCG88847zJ8/P4Mj9F/+3H4yWdv5q3j27NnDjBkz3A7HZBLLWemXER08uYHzv14OA0WTFhCRbMCzwKCUKhGRh0RknYisS+1tMCMjIwFsmFYAe/TRR1m+fDklSpRg+fLl3HXXXWmu4+WXX/ZqOeMTKrsdQGoEBwdTrFgxjh07xj///JMh+8iKX4g//vgjP/74o9th+IRu3bqxc+dOt8MwJl369+/PX3/9xfTp09N095BnnnmGmjVr8tBDD/HXX39lYIT+y9pPJrWuNNWGiISIyB8istyzVPesHyEia0XkNW/HdNttt1G/fn3Gjh3L2bNnvV298UGWs9IvI+6idZx/h13lIflOpEHA66r6d0pjzFT1LeAtgLp166ZqzNV1111H7dq1WbBgAYMGpdh3ZLzkiSee8PqPqpo1a6brIAsPD6devXps3LiRxYsXc/z4ce655x5OnDhBhQoVmD59+kVlly9fDkB0dDTnzp1j5cqVHDt2jEWLFlGsWLFUlcufPz933303hw8fpnz58lSrVo0hQ4akKt4zZ87QrVs39u3bR8mSJZk+fTrx8fG0a9eOY8eOUahQIT766CPOnTt3yTq7zbb/SOkYOX78OEFBQeTKlSvNdV7pGHn00Ud59NFHqVChwoXPb1r52xfi+fe4Zs2aLkdijEmv5cuX884771zorEmLbNmyMXPmTOrUqUPv3r358MMPfXruCGs/pa79VKdOHa699lqyZcvGgQMH6N69O23atKF9+/aICE2aNGHMmDF069aNY8eOcfDgQWrVqsWrr77KyZMn6dKlC3/++SfVq1fntde83ucQkBJPtSEi00Skoqr+lqTYTcAsVR2YaLs6OKM2bgaGiUhzVf3Ki3ExfPhwoqKimDlzJj179vRW1SYVLGd5N2cdPHiQbt26cfToUVq3bs3gwYPT/gZeRkZcwbOef4dl1QB2JlOmOfCIiCwHaorIFG/tPCoqilWrVnHkyBFvVWn8QExMDA0aNLhwJ439+/fz2GOP8dVXX7Fz504OHjyY4rbbtm3jm2++4e6772bZsmWpLvfLL79QsmRJvv32W7Zt25bqzh2At99+m2rVqrFixQoqVqzItGnT2LJlC0FBQXzzzTd0796d48ePJ7vO+L/s2bMTHx9PXFxcpu0zPDycAQMGcPvttwNOJ1PLli0JCwuje/ful5Q9Lzo6mqFDh9KkSRNq1qzJgQMHUl3u1KlTREZGUr9+fe677z7Gjh2b6vjOnDlDx44dadq0KZ06deLs2bPJrhs8eDDjxo1j3Lhx3HrrrSnWX6dOHSIjI2nbti3169fnjTfe4ODBg0RGRl4Y2w+wb98+GjduTFhYGEOHDgWcq2JGjhxJWFgYDRs2TPFOaH/++ScRERE0btyYXr16AU7eaNCgAc2aNaNOnTrs3LnzoqtsoqOjWb58ebL73blzJ506daJ79+4X/kbJxbxjxw4aNmxI8+bN2bJlS4rvQXLbJreP5NYdPnyY1q1bExYWduHKreTKGZNeU6ZM4Zprrrnw+U+r6tWrM2LECD7++GPef/99L0cXuHy5/XTy5Ek++ugjNm7cyAcffMD333/P3r17GTduHAsXLuTzzz+/UPaee+7hu+++Y8eOHaxfv5633nqLatWq8c0337B//342btyYzncoywnnylNt3AK0EpE1nqt9QoCmwFx1JkJdDIQlV3l6Rmic17JlS+rVq8eYMWM4d+5cmrY1gSMQctZzzz1Hhw4dWLVqFZ9++imxsbFX8Y5cKiMuBfgUWCk
ixYFI4F4RGa2qFy7zU9Um5/8vIstV1WvdsFFRUYwePZqlS5fSvn17b1VrkuErZ+8BqlWrxt13333hcWhoKFOmTGH69OkcPnz4srem7tKlCwClSpW67GWfScuVKFGC9evX06RJEx5//PE0xbtly5YL8d5yyy0sXLiQXr16Ua1aNVq0aEHFihVp2bIltWvXvmSdSZZPnqpN6RhJSEjgp59+IigoiCpVqmTKmeaYmBj69u3LCy+8APz7hdi8eXNatmzJwYMHKVr0khG1wL9fdCNHjmTZsmXcd999qSp34403UrJkST777DMaNWrEBx98kOr4zneCzpo1i+joaKZNm0ZcXNwl65577jkqV3ZG6HXr1i3F+s9/6VavXp2vvvqK0aNH88svv9ChQwe6detG/fr1eeihhy58EdesWZOGDRsyZswYwOkQW7lyJT169OCHH36gYcOGl+xj5cqVVK9enVdeeYX333+fhIQExo8fz9ChQ4mMjKRGjRopxpfSfj///HOWLFnCLbfcAvzbKEgc8/jx43n66adp1aoV1atXT3EfyW2b3D6SWzd27FjuvffeCx06ixYt4oYbbkh2W2PS6uTJk8yfP5+OHTuSI0eOdNczYMAAFi5cyAMPPECJEiWIiIjwYpTeY+2n1LWfihYtSp48eShdujTBwcGoKiEhIYwYMYI8efJcNNS5Tp06ANx0003s3LmTrVu3smrVKpYvX87ff//N3r17uemmm1L3pmRtSafaqJ1MmbVAc1XdLyIzgSjPdtsTbZdsgyI9IzTOExGGDRtG69atee+99+zEQiaynOXdnLV161ZWr17NjBkzOHHiBPv27aNQoUJpeh8ux+tX8KjqMZze3xggQlU3JO7cSaZ8uDf3f/PNN3PNNdfYPDxZTJ48eS56PHXqVO655x5mzZpF7ty5L7vtlZ5PqdyiRYt49tlnWb16NZ06dUpTvFWrViUmJgZwfthWrVqVDRs20KhRI5YsWcKRI0dYuXJlsuuyIhEpldLiKXKvqwGmUVBQECVKlODUqVOZNl9ESl+InTp1cv0LMbn4tmzZQv369QGnE/Tnn39Odl1qJfelu3XrViZPnkx4ePiFL9iQkBDGjRtHz549L/rx0LVr1yu+B5GRkcTHx3Pbbbfxyy+/EBQUxB9//EHVqlUJDg5OtvPl/Pue0n5btGhxUedJcjHv2LGDGjVqEBISctmhLcltm9w+kluX+L2vX7/+hfc+uW2NSasvv/yS48eP07Fjx6uqJyQkhE8//ZSKFSvStm1bfvjhBy9FGLj8rf00ceJEBg8ezJQpUy46ObJmzRrAGbJbvnx5KleuzBNPPMHy5csZPXo0pUqVSqlKc7HUTLWxUVX3e/6/DqiYyu2u2h133EHt2rUZPXq0zcWTRQVCzqpcuTLjxo1j+fLlDBo0iGuuuSZNdV5Jhhx8qnpEVT9UVa/dISu1goODadmyJQsXLiQhISGzd298xG233cZzzz1Hs2bNADLkrkW1atXiscceo1mzZtx7771s3rw51dv27NmTn376iSZNmvDbb7/RrVs3ypQpwyuvvELDhg05cOAAdevWTXZdFjXCs3wILATGAV8C8wDO303LnxQsWJA8efKwd+/eTBmq5etfiEnjS64TNLl1ADlz5uTkyZMAOFeHp05yX7Ap/XhIzXuwevVqOnfuzNKlS1m2bBnbt2+nXLlybNq0ifj4eDZs2AA484UcOnSI+Ph4li5dCqT8oyXp+5JczKVKleKnn34iPj7+sneRTKlBkXQfya1L6b1Pbltj0mr27NkUK1aMJk2aXLnwFVxzzTUsWrSIAgUK0LJlS7Zt2+aFCLMOX28/tWrViocffpg2bdqQK1euC/F98cUXNGrUiBtuuIGaNWvy4IMPsnDhQpo0acIbb7zB9ddf7/XXEaBSM9XGuyJSQ0SCgTuBDanc7qqJCKNGjeL333/nrbfeyohdGD/jjzlr0KBBvPjiizRq1IhFixaleAV9uqmqzy9
16tTRtHjvvfcU0LVr16ZpO3NlW7ZscTsEn/HWW29pRESEtmjRQu+44w7973//60ocyf1NgHXqA8eutxdgKRDk+X8w8LXbMSXNT2k5Rk6cOKFr167VXbt2pXqb1CpfvvxFj5s2bXrR4xUrVmjVqlW1cePG2qBBA/3222+TLTt8+PALn+3p06fr9OnTU11u48aNWqJECY2IiNAOHTropk2bUow3aXynT5/We++9V8PCwvS+++7TM2fOJLtOVTU2NlZvvfVWbdiwoa5YseKy9Tdt2lR37NihXbt21f3792tUVJQ2bNhQO3XqpOfOndPZs2drtWrVNCIiQqtWrap79uzRrl276o4dOy55nUn9/vvvGhYWprfccoveeeedeu7cOf3999+1QYMGeuutt2q9evV0x44d+vXXX2tERIQ++OCD2q5dO/3vf/+b7H7Px5lYcjH/9ttvWr9+fY2IiNBatWpdiDWp5LZNbh/JrYuNjdU77rhDGzVqpI8//niK5ZLKSvnJF5e0tp/ccPToUc2ePbv27dvXq/X+/PPPWqhQIS1Xrpzu37/fq3Wnh7Wf/uXt9lPiHJ0Wlp8uXYB8OB02E4GfcTprRicpUw3YCGwCxnjWBQHfAf8BtgJlr7Sv9OanhIQEDQ8P1yJFiujRo0fTVYe5MstZ/3LjN9/V5CfXE0lqlrQmgEOHDqmIaHR0dJq2M1dmB7vvyUoNFE/joTVQGmfM92q3Y7qaDh5V1V27dunatWv1xIkTadrOH/hKJ6ivSO+PEH+WlfKTLy7+0MEzc+ZMBXTVqlVer/v777/X3Llza40aNfTvv//2ev1pYe0n32P5KfkFKAi0B4qlcbucwD1AudSUv5r8tGbNGgX0mWeeSXcd5vIsZ7nravJTQN5vuXDhwtSvX58FCxYwfPhwt8MJOKrq07cfzUqcYz1L6QQ8DfQBdgH3uxvO1StevDiHDx/mjz/+oHLlygF1bD344IM8+OCDmb7fxHf3AsifPz/z58/3Wv0HDhzg3nsvnvapcuXKvPnmm5fdbsaMGV6L4XLSG58xbpg9ezalS5fOkLmcbr75ZubNm0erVq1o06YNixcvvqpJnK+WtZ98RxZsP6Waqh7h3ztppWW+BOXAAAAgAElEQVS7U8DH3o/oUvXq1aNDhw5MnDiR3r17U7x48czYbZZjOcsdV5ufArKDB5xJuIYNG8aff/7Jtdde63Y4ASNHjhzExsZSqFAhO+BdpqrExsa62ljNbKq6U0SGAiVw7tKQ4jxfIjIVqAJ8qaqjk3k+BPjdswA8pqqbRGQEztVBa1T1kXTGmerjIyQkhBIlSrBr1y4OHz7s1Vn0s6rly5dnaP3FihXL8H1cDV+Jz35AmSuJjY1lyZIl9O/fP8PaFC1atOCdd97hvvvuIzo6mnHjxmXIfq7E2k++Iyu2nwLRmDFjmDdvHtHR0TYfTwawnOUOb+SngO3giYqK4tlnn2Xx4sV07tzZ7XACRsmSJdmzZw+HDh1yOxSDk3xLlizpdhiZRkQGAncBuXAmWm4JdEmm3N1AsKo2EJFpIlJRVX9LUuwmYJaqDky0XR2cSQJvBoaJSHNV/SotMabnC7Fw4cIcOnSIPXv2UKBAAYKDg9OyS2N8jv2AMqkxd+5c4uLiLrnizNs6duzI4sWLeemll+jVqxdly5bN0P0lx9pPviWrtZ8CUfny5enduzevvvoq/fr148Ybb3Q7pIBiOcs9V5ufxB/OsNWtW1fXrVuXpm0SEhIoUaIE4eHhzJo1K4MiM8Y3ich6VQ24W26JyLeq2lhE/quqEecfJ1PuFWCRqi4QkXuBnKo6PUmZPsAjwAmciQJ7AX2B06r6uojcAkSq6iXjPEXkIeAhgFKlStXZtWvXhefOnTvHnj17OH36dJpe25kzZzhw4AD58uWjYMGCadrWGF90voESGhp60fpAzU++KD3tp8zUrFkz9u3bx88//5zhZ4j37t1LpUqVaNWqFXPmzMnQfRn/Zfkp83g
jPx06dIjy5csTERHh1aHYxvii1OangL2CJygoiMjISD755BPi4uIICQnYl2pMVnJMRLoAOUSkKfB3CuVyA+fvk3gYqJ1MmbVAc1XdLyIzcYZl5Qa2J9ou2fsWqupbwFvgNFASPxcaGprus8Ndu3Zl1qxZ/PTTT1SsWDFddRhjjD/Yv38/y5cvZ9iwYZly+X+JEiV4+umniY6Opm/fvjRq1CjD92mMyVhFihRh0KBBDB06lJUrVxIWFuZ2SMa4LsjtADJSmzZt+Pvvv1myZInboRhjvKMbUAs4ArQFHkih3HGcuzkA5CH5XLdRVfd7/r8OqJjK7TLM888/T44cOXjiiScyc7fGGJPpPvroI1SVDh06ZNo+n3rqKYoXL07//v1JSEjItP0aYzLOE088QfHixXn66adt7jdjCPAOnqioKIoUKcLUqVPdDsUY4wWq+qeq9lPVKFXtr6oHUyi6HmcuHYAawM5kyrwrIjVEJBi4E9iQyu0yTLFixYiOjmbBggV8/vnnmblrY4zJVLNnz6ZGjRqZOm9G7ty5GTt2LGvWrLHh+8YEiFy5cjFy5EhiYmKYN2+e2+EY47qA7uDJli0bXbp04bPPPrMJoowJAJ47Y6XGp0BnEZkItAd+EpGkd9IaCbwL/Ais9kym/C1QS0T+AwwCMv0XwGOPPUaVKlV4/PHH0zyPjzHG+IOdO3eyevXqDJ9cOTmdO3emdu3aDBo0iJMnT2b6/o0x3te1a1eqVKnCwIEDOXPmjNvhGOOqgO7gAejevTtxcXG89957bodijLl6IiL1rlRIVY8B4UAMEKGqG1T1mSRlNqvqTapaXVWHetYlAM2BlTgTLO/w+iu4gtDQUCZNmsSOHTsYP358Zu/eGGMy3IcffgiQqcOzzgsKCmLixIns2bOHiRMnZvr+jTHeFxISwoQJE9i+fTuvvvqq2+EY46qA7+CpWrUq9evXZ+rUqTYu0xj/lw1YKiIfish0EZmWUkFVPaKqH6rqgbTsQFVPqerHqvr7VUebTs2aNaN9+/Y899xz7Ny5060wjDEmQ8yaNYv69eu7crtygKZNm3LXXXcxbtw49u/ff+UNjDE+r2XLlrRs2ZJRo0bZyA2TpQV8Bw/AAw88wE8//cTatWvdDsUYc3WG4syNMwCIBka4Gk0GmjBhAkFBQfTv39/tUIwxxmu+//57fvzxRzp27OhqHOPHj+fs2bM888wzVy5sjPELEyZM4Pjx4wwfPtztUIxxTZbo4OnQoQM5c+a0yZaN8XOquivxApx1O6aMUrJkSZ599lk++eQTFi9e7HY4xphERCS/iCwUkSUi8omIZBORqSKyWkSeSVQuw9f5k6NHj3Lfffdx/fXX06VLF1djqVChAo899hjTp0/nlVdeIT4+3tV4jDFXr0qVKvTu3Zs333yTzZs3ux2OMa7IEh08+fLlo127dsyaNcsm1DPGj4nIaBHZICLbRWQ7ENA9H/369aNixYo89thjNmmgMb6lEzBRVVsAB4B7gWBVbQCUE5GKInJ3Rq9z4XWnm6rSs2dPdu3axezZsylYsKDbITF8+HBuv/12Hn/8ccLCwtiyZYvbIRljrlJ0dDT58uWjf//+Nj2HyZKyRAcPOMO0/vnnH+bOnet2KMaY9AsDGgJrgJuAgB5knT17diZNmsRvv/3GSy+95HY4xhgPVX1dVZd6HhYB7gc+9DxeAjTGmeg9o9ddQkQeEpF1IrLOl+ahmDx5Mh9//DFjx46lYcOGbocDOCcAFyxYwMyZM9m6dSu1atVi9OjRnD0bsBeHGhPwChUqxPDhw1m6dCkLFixwOxxjMl2W6eAJCwujQoUKNkzLGP8WhDMHTx6cDp4i7oaT8W6//XbuvPNORo0axZ49e9wOx5iAJiLVROR2EblRRPKkonwDoCCwG9jrWX0YKArkzoR1l1DVt1S1rqrWLVLEN1LkDz/8QL9+/YiMjOSpp55yO5yLiAidO3fm559/5q677uLZZ5+lbt26Nm+
jMX6sT58+VKpUiSeffJJz5865HY4xmSrLdPCICN27d2fFihVs27bN7XCMMenTHmfenWeB3sAod8PJHC+99BIJCQkMGDDA7VCMCVgiMgln4vbngHLAB1cofw0wCegBHAdyep7Kg9O+yox1Pu/YsWO0b9+eIkWKMHPmTIKCfDPsa6+9ltmzZzN//nxiY2O55ZZbeOaZZ+xqHmP8ULZs2XjxxRfZunUrkydPdjscYzJVqr9lRSRIRPKJSIiIRIhI3owMLCN07dqVoKAgpk+f7nYoxpj0qQjkAvIBU4CD7oaTOcqUKcOAAQOYPXu2nVU2JuNUV9X/A/5W1S+B/CkVFJFswEfAYM+E7+v5d8hUDWBnJq3zaapKr1692LFjB7NmzaJw4cJuh3RFbdq0YcuWLXTt2pUxY8ZQv359m6zV+D0RKSgiVUXkOhHxzV5WL2vVqhXNmzcnOjqa2NhYt8MxJtOk5QD/CGgCvAT0BD7JkIgyUIkSJWjZsiUzZsywuyUY458iPEsUTi7q4244mWfAgAEUKVKEAQMG2KSBxmSMQyIyDCgoIl1xJk9OyQNAbWCoiCwHBOgsIhNxrjT8Evg0E9b5tLfffpvZs2czcuRIwsLC3A4n1fLnz8+0adOYP38++/bto06dOrzwwgvWdjR+SUQGAguBWUAzYIarAWUSEWHixIkcPXqUESNGuB2OMZkmLR08hVT1C6Ciqnbi38uE/UqPHj3Yt28fS5YscTsUY0waqeoIzzIIuBn42+2YMkvevHmJjo5mxYoVfPmlz/+uM8YfdQGOAqtxrt7pnlJBVZ2sqgVVNdyzvIMzCXIMEKGqR1X1WEav8/Yb4E2HDh3iiSee4LbbbmPQoEFuh5Mubdq0YfPmzbRq1Yqnn36a8PBwfv/9d7fDMiatWqvqLUCsqr6PMwQ1S6hevTo9e/Zk8uTJ/Prrr26HY0ymSEsHzz8i8imwXkSigH8yKKYM1bp1awoXLmyTLRvjh0Sk1PkFqIkzZCvLePDBB6lUqRIDBw4kLi7O7XCMCSiqegr4BlgKLFfVk2nc/oiqfqiqBzJzna+aOnUqp06d4j//+Y/PzruTGkWKFOHjjz/m3XffZdOmTVSpUoUnn3ySv/76y+3QjEmtYyLSBcghIk3JQifHAEaOHEmOHDkYOHCg26EYkynS8o3bDhipqkNx7uLQIWNCyljZsmWja9euzJ8/n127drkdjjEmbUYkWh4ERrsbTuYKDQ1l3LhxbNmyhRkzZrgdjjEBRUReAUYC9YBxIjLB5ZD8Vnx8PG+88QbNmjXjxhtvdDucqyYi3H///WzevJmOHTvy8ssvU65cOUaNGsXx48fdDs+YK+kG1AKOAG1xhphmGUWLFmXQoEF8+umnrFixwu1wjMlwaengOQtsE5EQ4BogIWNCynh9+/YFYOLEiS5HYoxJo2hguGcZC2z3XM2TZdx55500atSIYcOGceLECbfDMSaQ1FbV1qo6RFWjgPpuB+SvFixYwK5du+jTJ7CmSStZsiTTp09n06ZNNG/enGHDhlGuXDleeeUVTp8+7XZ4xiRLVf9U1X6qGqWq/fGTO/B5U79+/ShZsiRPPvkkCQl++xPWmFTJkEmWRWSqiKwWkWdSeD6/iCwUkSUi8onnbhSZplSpUnTq1IkpU6bYJbbG+Je5wBc4tzFeCHyI0+mTZYgIL7zwAvv377dOamO866CIdBCRiiLSCdid1TqQveX111+nePHitGnTxu1QMkSVKlWYN28eMTExVKtWjccff5y8efNSvXp1OnbsyNixY/n888/ZuXOnTYpvXCcio0Vkg4hsF5FtwGK3Y8psuXLlYuzYsaxfv54PPvjA7XCMyVBen2RZRO4GglW1AVBORJKbI6MTMFFVW+DcpaJlGuO+ak8//TQnT57k1VdfzexdG2PS7yhQw5ODbgL+UdUeLseU6Ro0aMD//d//MX78eA4ezBJ3ijcmMxzDaY8MAZoDp8liHcjesG3bNhYtWkSvXr0
IDQ11O5wMVb9+fb7++muWLVvGgAEDKFOmDKtXr2bo0KG0adOGsmXLUqxYMdq1a8ekSZPYuHGjXT1g3BAGNATWADWAQ+6G445OnTpRu3ZthgwZwqlTp9wOx5gME5KGsqmdZDkc56w6wBKgMfBb4gKq+nqih0WAP5NWIiIPAQ+Bc8WNt1WpUoU2bdowadIknnrqKfLkyeP1fRhjvC4n0FJEtgDVgVwux+OasWPHMn/+fEaOHMlrr73mdjjG+D1VveiuWSJynarudysef/XGG28QEhJCz5493Q4lU4gIERERREREXFh37NgxNm/ezMaNG1m1ahUrVqzg448/BqBgwYI0btyYGjVqUKVKFW688UYqV65Mzpx+eXNa4x+CcDp2cuOcHCt8ucIiMhWoAnypqpfMdSgi+YHZQDBwAmde1gTgd88C8JiqbvLWC/CGoKAgJkyYQEREBC+//DKDBw92OyRjMkRGTLKc2/M8wGGgaEoVikgDoKCqxiR9TlXfUtW6qlq3SJEiaQgz9QYNGsThw4eZMmVKhtRvjPG6zsCdwBs4Z9o7uRuOeypVqkSvXr1488032bp1q9vhGOP3RGRUkmEMi9yOyd+cPHmSadOmcdddd1G8eHG3w3FNvnz5aNiwIQ8//DAzZ85k165d7Ny5k3feeYe7776bX3/9lbFjx3LfffdRq1YtcufOTfny5bn77ruJibmkSWzM1VqAc1OKBKA3l7kD6VWMxLgJmKWq4Z7Fpzp3zgsPD6dNmzY899xz/PnnJdcXGBMQ0tLBEwfUFZGXcO4wkdLsnsf5d/hWnpT2ISLXAJMA14ZXNGjQgCZNmjBx4kTOnj3rVhjGmFRS1e3Aw6oaCUwA/nA5JFcNGzaMXLlyMWjQILdDMSYQNOHiYQw2SV8azZkzhyNHjvDII4+4HYrPKV26NF26dGHKlCn88ssvnDhxgo0bNzJnzhyGDRtG3bp1+e6772jQoAGdOnVi9+7dbodsAkcbnDtn9QWeAS53a7twLh2JcRFVfV1Vl3oenh+JcQvQSkTWeOZiTXaUiIg8JCLrRGTdoUPujBQbP348p06dYvjw4a7s35iMlpYOnunAdThntEp4HidnPf8mgxrAzqQFPJMqfwQMVlVX71U+cOBAdu/ezaxZs9wMwxiTCiIyGbhXREYA7/JvIyRLuvbaaxk4cCCffvop3377rdvhGOPv0jSMwVxMVXnttdeoUqUKTZo0cTscn5cjRw6qV69O+/btiY6OZs6cOWzbto2hQ4cyb948KleuzPDhw+1uicYbDgJfATOAd0j5NxykfyTGWqC5qt4MhAJRyW2TGSM0rqRy5co8/PDDvPXWW2zZssWVGIzJSGnp4CmpqiNUdbGqjgCuT6Hcp0BnEZkItAd+EpGk4zcfAGoDQ0VkuYikNNwrw0VGRlK9enWef/55m/jOGN9XVVVnAbeoamMgxTEA6bmbn4iEiMgfnry0XESqZ9QL8ZZ+/fpRvHhxBgwYYHdrMebqpHoYg7nU2rVrWb9+PX369EFE3A7HL+XNm5fRo0fzyy+/0LZtW0aOHEmlSpWYOXOmtVHN1QgFqqtqM1WNUNVmlymb3pEYGxPNWbYOH8+fw4cPJ2/evDz55JNuh2KM16Wlg2e/iAwWkWYiMgTYl1whVT2Gc3lfDBChqhtU9ZkkZSarasFE4zTnpPcFXC0RYdCgQfz88898/vnnboVhjEmdOBF5GfhNRG4GziVXKNDHkCeWK1cuRo0aRUxMDHPnznU7HGP8WVqGMZgkXn/9dfLkyUPnzp3dDsXvlS5dmlmzZvHdd99RsmRJunbtSoMGDWx+HpNeRYG1IrLs/HKZsukdifGuiNQQkWCcuRI3eC36DFC4cGGGDx/OokWLWLhwodvhGONVaeng6YZzC9H/A/72PE6Wqh5R1Q9V9cBVRZdJ2rdvT5kyZRg3bpydATfGt3UAvgEG4JxZ6ppCuXAycAy5r+natSvVqlVj0KBBNp+YMemXlmEMJpG
//vqL2bNn07lzZ/Lly+d2OAGjYcOGrF69mpkzZ7J7924aNGhA586d2bt375U3NsbDMySqqucKnmZXuIInvSMxRuIMnf8RWK2qX2XAS/GqRx55hIoVK9K/f3/OnUv2fKExfinVHTyqelZVX1PVRzw/jALmV0RISAhPPfUUMTExrFy50u1wjDEpUNVDqjpPVU+p6jJV3QkX5uZJLEPHkPvCJIGJBQcHM378eLZv386bb77pdjjG+Ku0DGMwiUyfPp0zZ87Qp08ft0MJOEFBQXTu3Jlff/2VIUOG8NFHH1GpUiXGjBnDqVOn3A7PBJj0jsRQ1c2qepOqVvfccdnnZcuWjQkTJvDLL78weXLSZqQx/uuKHTwi8t/El/R5lv9e4fI+v9O9e3eKFCnC6NFJO6mNMX6gcpLHGTqG3BcmCUyqZcuW3HrrrYwYMYKjR4+6HY4x/igtwxiMx4kTJ3jllVcICwujWrVqbocTsPLkycOYMWPYsmULLVu25JlnnqFcuXIMGjSIX3/91e3wTADxt5EYV6NVq1bcdtttREdHExsb63Y4xnjFFTt4zp/FSrIE3Jmt87caXrp0KUuXLr3yBsYYX5YlxpAnJiKMHz+e2NhYxo0b53Y4xvidNA5jMB4jRoxgz549jBkzxu1QsoRy5coxd+5cli1bRr169XjxxRepXLkyTZo04Z133rG7bhmTBiLCxIkTOXr0KNHR0W6HY4xXpGUOnoD3yCOPUKZMGZ5++mm7W4Ex/i3LjCFPrHbt2tx///28/PLL7N692+1wjDEBbuPGjUycOJEHHniAsLAwt8PJUiIiIvjss8/YvXs348aN48CBA3Tr1o3rrruOfv36cfDgQbdDNMYvVKtWjYcffpjJkyfbbdNNQLAOnkSyZ8/OmDFj+PHHH/nggw/cDscYk3oX3ZM3K40hT2r06NGoKs8++6zboRhjAlhCQgK9evWiYMGCPP/8826Hk2Vdd911DBw4kK1bt/LNN9/Qtm1bJk2aRLly5Rg8eDCHDx92O0RjfN6IESPImzcv/fv3txvuGL9nHTxJ3HvvvdSuXZuhQ4dy+vRpt8MxxiQiItlF5BYRaXJ+8Tx1b9KyWWkMeWKlS5emb9++zJw5k++//97tcIwxAertt98mJiaGCRMmUKhQIbfDyfJEhLCwMN599122bNnCnXfeyfPPP0/ZsmUZMWIEx44dcztEY3zW+dumL1682G6bbvyedfAkERQUxAsvvMAff/zBpEmT3A7HGHOxr4GeQIRnCQdQVbsWPZEhQ4ZQqlQp2rZty44dO9wOxxgTYA4cOMCgQYOIiIigc+fObodjkqhUqRLvv/8+GzdupHnz5kRHR1O2bFnGjRvH8ePH3Q7PGJ/Up08fKlWqRP/+/Tl79upuFr1nzx67es64xjp4ktGsWTMiIyMZM2aMzahujG9JUNWeqjrCs4x0OyBfVKBAARYuXMjZs2eJjIxMdx47c+bMVTdyjDGBp3///pw8eZLJkycjIlfewLiiWrVqzJ07l3Xr1lG/fn0GDx5M2bJlGT9+vE3GbEwS2bJl4+WXX2br1q20adMmzcdIXFwc8+fPp2XLllx//fWULVuWadOm2ZAvk+msgycFzz//PMeOHWPs2LFuh2KM+ddSERknIjeKSCkRKeV2QL7qxhtvZP78+ezYsYO2bdumesjpyZMnmTt3Lh07dqRw4cIUL16c6dOnWwPFGAPAkiVLmDVrFoMHD6Zy5cpuh2NSoU6dOixYsIDVq1dTp04dBg4cSNmyZXnxxRc5efKk2+EZ4zMiIyOZOnUqS5cu5dZbb03VCbL9+/czatQoypYty5133snmzZsZNmwYtWrV4oEHHiAqKoo9e/ZkQvTGOKyDJwXVq1enW7duvPrqqzbEwRjfUQ4oCjwNjACiXY3Gx52fj+G7776jc+fOKd4d8MSJE8yZM4d27dpRpEgR7rnnHr766is6duzIDTfcQI8ePWjatCk//fRTJr8CY4wvOXXqFH3
69KFixYoMGjTI7XBMGt1yyy0sWrSIVatWUbNmTQYMGEDZsmXp2bMnH3zwAfv373c7RGNc16NHD+bNm8ePP/5IWFhYincl3bx5M506daJUqVIMGzaMKlWqMG/ePHbu3MmIESNYtmwZkyZN4ptvvqFq1ap2NY/JPKrq80udOnXUDbt379acOXPqfffd58r+jUkvYJ36wLGb0QtwndsxuJWf0uLFF19UQPv163dh3blz53ThwoXaqVMnzZUrlwJarFgx7dOnjy5btkzPnTunqqrx8fE6depUveaaazQkJEQHDRqkJ06ccOulmACQVfKTLyzezk9PPvmkAvr11197tV7jjm+//VbvvvtuLVCggAIKaOXKlfXhhx/W2bNn686dOzUhIcHtMDOV5Sf/zU/etnz5cs2XL5+WLFlSt2zZcmH9mjVrtG3btgpo7ty5tV+/fvrrr7+mWM/27du1adOmCmjLli11z549mRG+CUCpzU+uH9ypWdxMAEOGDFFA165d61oMxqRVoDZQgFHABmA7sA3Y4HZMvt5AUVVNSEjQvn37KqBDhgzRvn376rXXXquAFixYUHv16qUrVqzQuLi4FOs4dOiQdu/eXQEtU6aMLly4MBNfgQkkgZqffHHxZn56/vnnFdDevXt7rU7jG+Li4nTdunX6wgsvaFRUlObJk+dCh0+RIkX0jjvu0OjoaF2wYIHu378/oDt9LD/5Z37KKD/88IMWLVpUr7nmGp06daredtttF9pOw4cP19jY2FTVEx8fr6+++qrmypVLCxcurIsWLcrgyE0gSm1+Eqesb6tbt66uW7fOlX0fPXqUihUrUqxYMWJiYsiVK5crcRiTFiKyXlXruh2Ht4nICiAKmIJzN63PVPVWN2NyMz+lRXx8PO3ateOTTz4he/bstGrVivvvv5/IyEiyZ8+e6nq++eYbHn74YX7++WceeughJkyYQJ48eTIwchNoAjU/+SJv5adJkybRt29fOnbsyLvvvktwcLAXojO+Ki4ujh9//JG1a9eyZs0a1q5dy5YtWzj/myFPnjyUL1+e8uXLU6FCBcqXL0/NmjWpW7cuQUH+PfuD5afM4y/tp+3bt9OiRQt+//13ihYtSv/+/enduzd58+ZNc11bt26lXbt2bN68maFDhxIdHW351KRaavOTdfCkwqJFi4iKiqJLly5Mnz7d7hhhfF6gNlBEZCUwEBgMjAXeVNWb3IzJ7fyUFqdPn2bp0qWEhYVRoECBdNdz5swZhg0bxgsvvECZMmV45513CAsL82KkJpAFan7yRd7IT2+//TYPPfQQd911F3PmzCE0NNRL0Rl/8s8///C///2PjRs3sn37drZt28a2bdvYsWPHhbstFilShMjISO644w5uv/128ufP73LUaWf5KfP4U/vpzz//ZNmyZbRt25acOXNeVV0nT57kscceY9q0aURERPDBBx9QrFgxL0VqApl18HjZ8OHDGTlyJFOmTOGBBx5wNRZjriRQGygich1QAogD+gNfqOqHbsbkC/nJLd9++y1du3Zlx44dPPnkk4waNYocOXK4HZbxcYGan3zR1ean9957jy5duhAZGcknn3xCtmzZvBidCQTx8fHs3r2bVatW8cUXX7Bo0SKOHDlCSEgIjRs3JiIignr16lGvXj0KFy7sdrhXZPkp82Tl9hPAjBkz6NOnD/ny5WPWrFlERES4HZLxcdbB42Xx8fG0bNmSlStXEhMTQ82aNV2Nx5jLCeQGiogUBIoDR4ADqpr8raEyiS/kJzcdP36cAQMG8MYbb1C1alUGDBhAs2bNuP76690OzfioQM5PvuZq8tPHH39Mhw4dCA8P54svvrjqs9Yma4iLiyMmJoYvv/ySBQsWsGnTpgtDu8qUKXOhs6datWpUrlyZ0qVL+9QQFctPmSert5/AuRPXPffcw2+//catt95KmzZtaNOmDaVKlXI7NOODrIMnAxw6dIhatWqRI0cO1q1bd1VDHIzJSIHaQBGRgcB
dQC5gHNBSVbu4GZOv5Ce3LVq0iF69evHHH38AUKFCBZo1a0azZs1o1KgRp/ScfSkAABTbSURBVE6dYt++fRctsbGxZMuWjZw5c5IjR44LS+HChalTpw7VqlWzKwYCUKDmJ1+U3vy0dOlSoqKiLtxWO3fu3BkQnckKjh07xvr161m7du2FZdeuXReez549OxUqVKBy5cpUqlSJ4sWLU7Ro0YuWAgUKZNr0CJafMo+1nxzHjx/nueeeY+7cuWzduhWAWrVq0aZNG1q0aEFwcDAnTpzg5MmTF5a4uDjy5ctH/vz5L1oKFy5s7aYAZh08GeS7774jPDycVq1aMW/ePJuPx/ikQG2giMi3qtpYRP6rqhHnH7sZky/lJ7clJCTw008/8fXXX7Ns2TJWrFjBsWPHki2bM2dOChcuzLlz5zh9+jSnTp3izJkzF5XJnj07NWrUoG7dutSrV4/GjRtToUKFzHgpJgMFan7yRenJT8eOHaNKlSoUKFCAVatWkS9fvgyKzmRVf/31Fz///DNbt269aNm+fTvx8fGXlA8ODiZHjhxkz579oiU0NBRVJSEh4aI7yCQkJBAXF0dcXBzx8fEX/t+jRw9efPHFy8Zm+SnzWPvpUlu3bmX+/PnMnz+f1atXk9bf6aGhodx0003UrVuXOnXqULduXapVq2ZzpwUI6+DJQC+99BL9+/dnwoQJ9O/f3+1wjLlEoDZQRGQBMBvoDQwCBqhqKzdj8rX85Evi4uL43//+x5o1a8ifPz/Fixe/sOTLl++SDvKEhATOnDnDvn37WLduHevWrWPt2rWsX7+e48ePA1C5cmVat25N69atadiwISEhIW68NHMVAjU/+aL05KdHH32U119/nZiYGG6++eYMisyYS8XHxxMbG8vBgwcvWv766y/OnDlzyXLu3DmCgoIQkYuWoKAgQkNDCQ4OJiQkhJCQEIKDg2nSpAnt2rW7bAyWnzKPtZ8u788//yQmJobQ0FBy5cp1YcmdOzfBwcEcO3aMo0ePXrTs2LHjQvvp6NGjgHOy7Oabb+b222/n9v9v795jozrPPI5/H3zD2NhgEmzSJLAUGtw1t5C0ZMHBRSFAtMoql6bVslslEY262k0q9Z+olz+aqs1GqbTqatWo6SVkuytFanc3rViWcGlrME0IGyChuIDiUFgoLhdjLrZxbM88+8dc4ju28XjOnPP7SEfneDhz5n3mzDy885xz3rN2LXfeeWfO3+0uqlTgySB359FHH+WXv/wl3//+99m4cWOgrh8WCWsHxcxmkriD1h3AUeAFdz+XzTYFLT+FUTwe59ixY+zcuZPNmzdTX19Pd3c3FRUVrF+/nvvvv5/a2lrmzJmjsypzQFjzUxCNNj+99dZbrFixgmeeeYbvfe97GWyZSDApP00c9Z8yx905fvx4+kDZr3/9aw4ePAjAjBkzWLNmDWvXrmXFihV8/OMfV8EnR6jAk2GXL1/mwQcfZPfu3dTU1PDiiy+ybt26Ef+4cHfOnj3L0aNHAbjnnnsoKirKZJMlQsLaQTGzp4DlQOqL5u7+ZBabFMj8FHZXrlxhx44dbN68mS1btnDhwgUAbr31Vmpra6mtrWXlypXMmjWL4uJiiouL1XkJkLDmpyAaTX7q6upi6dKltLW10djYSGlpaYZbJxI8yk8TR/2niXXu3Dl27NjBtm3b2L59O2fPngWgrKyMpUuXsmzZMpYtW8bixYspKiqip6eH7u7u9CWO8XicgoICioqKKCwsTM/Ly8s1CP8EGWl+0rntY1ReXk59fT2vv/46zz77LA888AD33Xcf3/3ud/vcYau7u5sPPvggfX3xkSNHOHr0KEePHuXSpUvp9aZMmcLq1atZv34969atY+7cudkISyTongD+Ghh4kb5ERllZGY888giPPPJIetyfhoYGGhoa2LVrF6+99tqA5xQWFjJlyhSKi4spKSmhpKSE0tLS9HJFRQULFy5k6dKlLF68mPLy8ixEJpI9L7zwAr/
//e/ZsmWLijsiIiEzc+ZMNmzYwIYNG3B3Dh8+zL59+zhw4AD79+/npZdeorOzc9TbnTRpEjU1NXz6059OT9XV1bq6JYt0Bs846Orq4gc/+AHPPfccra2tPPTQQ3R1dXHs2DGOHz/eZ8C4qqoqqqurWbBgAQsWLKC6uprOzk62bdvG1q1bOX78OADz58/nvvvuY9WqVaxatYqqqqpshSc5KKxHoMzsF8BC4CSJs3jc3Vdns01Bz09R4+784Q9/4M033+TixYtcu3atz5S6A0VbWxvt7e20t7fT1tbGuXPnOH/+fHo7c+fOZenSpSxZsoRFixaxaNEiZs+erUvAxkFY81MQjTQ/HTlyhCVLlvDwww8PWiAViQrlp4mj/lOw9PT0cOTIEQ4fPkwsFkuPX1VQUEB+fj5mRnd3N11dXXz44YfpeXNzM/v27WPfvn20trYCUFpayvLly6mrq6Ouro67775bd/caB7pEKwsuXbrE888/z6ZNm6iqquKOO+4YMA13a3V3p6mpiTfeeIOtW7fS0NCQHlj0E5/4BKtWreLee+9l3rx5zJgxgxkzZjBt2jRdeiADhLWDYmZbgM+6e0e225KSK/lJrq+5uZl3332XgwcPpudNTU3pfy8rK2PhwoUsWrSIOXPmUF5eTllZWZ9pypQpTJ48OX1p2OTJk3UUq5+w5qcgGkl+isfjrFq1isbGRo4cOUJlZeUEtU4keJSfJo76T+Hi7rz//vu8/fbb7N27lz179nDo0CEgcefUFStWUFdXR01NDbfccguzZs2isrJSd/gahawWeMzsJ8AngS3u/u2xrpMS1QSQugPN7t272bVrFw0NDekR0VPMjOnTpzN9+vT0HWj632EgdUeBvLw8Jk2axKRJkygsLGTq1KlMnTqVsrKy9Ly8vJxp06alt5mapk6dmr6UITVNmTIlXdHVUe1gCWsHxczeAYqBs6nHdAaPZFJbWxuHDx/m0KFDfab+uXg4hYWFVFVVcdtttw2YKisr01NJSUkGIwmOsOanIBpJfnr55Zf50pe+xCuvvMITTzwxQS0TCSblp4mj/lP4tbS0sHv3burr66mvr08XfFLMjJkzZ1JVVUVhYSGxWIx4PJ6ezIzbb7+defPm9Zlmz54dycJQ1go8ZvYw8KC7P25mrwD/6O7vj3ad3pQAEmKxGI2NjZw+fZqWlhYuXrxIS0sLLS0ttLa2kpeXR1FRUZ+psLAQd09/UVJfnK6uLq5evcqVK1f6zC9dukRraysdHaM7QaL3rSlT8/7LqcJSql2p5YKCggFT6rTAwW59mSom9Z/3fp3eU/9tjKUgNdRrDradsXynRrOdmpoaNm7ceL3tRaqDYmavu/tDgzw+pmKzCtAyHHeno6ODK1euDJg6Ojq4du0anZ2d6cvC2tvbaW5u5tSpU5w6dYrTp0/z4YcfDthuSUkJlZWVTJs2je7u7vQp0KnToN29T+5MLRcXFw+4hWrvgaXdvc8Ui8UGTPF4nLy8vD5Tfn7+kGeIxmKx9MCLvQdirK2t5ctf/vKw71/U8lM2XS8/nTlzhurqau666y527typgzUSecpPE9N3AvWfoujixYscP36c5uZmzpw502cei8UGnJDQ09PDyZMnaWpqor29Pb2d1AkON910U5+puLiYtrY2rl69ytWrV9PLPT09A/o4eXl5A/pUqXmq79P7N6OZpfs+vefxeDzd3t5tBwbtb8VisT6DV6eWH3/8cZ5++ulh379sDrJcB/wsubwdWAn0L95cd53k3XKeArj99tsz0Mzck5eXlx4LItO6urpobW1NT1evXqWjoyM9ZkVqisVifQpIqeXBHovFYn1+rPReTn3A29vb+3zo+/8wicfjwEfFj97z3q/Zu6DVfxupaTR3PBtuPpjRdJJHu50HHnjgugWeCBpw7WOykJzn7veY2StmNn+IYnOfdUiM8TPs8yTazCx9JuOsWbNG/Xx35/z585w6dYpz585x9uzZPtPly5cHLeQ
Mdf17Z2cnHR0dXLhwgfb29vQ4Q6mjX6k2p6bBOjmTJk0iHo+nOyy9p8HykJkNKMjn5+czf/78G35/ZeLs2bMHM+Pll19WcUdE1HeSjKqoqKCiomLUz0vdfbqpqYmmpiZOnDjBhQsX0tPJkyc5cOAAHR0dlJaWpq9SKS0tpaqqivz8/CELLdeuXePy5ct9+lW9fy+mXt/dyc/PTx8A691/Sh086/0b1N0H7W/17jNNnjyZ0tJSCgoKKCsrG7f3ORMFnhLgj8nli8CdY1nH3X8I/BASFd7xb6YMp7CwMH3ZgEgOGCxH1DG2YvPSETxPZMxSpyTPnDkz202RHDTao+TDeeyxx1i7dq3uGiciKXVksO+kA/gyFmZGVVUVVVVVrFy5MtvNCbxMjM7bRmKMDIDSIV5jJOuIiNyI/oXkwaqVg61z3eeZ2VNm9o6ZvdP7zksiIpnU+8g5MDd55PyGqLgjIr1krO8EiQP47n6Xu9918803j1ujReQjmSis7CdRtQVYDJwY4zoiIiM12LUFYy02X/d56qCISJbUMfAoeR8qQIvIDchY30lEJkYmLtH6BdBgZrcA64HPm9m33f0bw6yzfLgN7t+//4KZnRzBa98EXBhju3OJ4gyXTMQ5e5y3F0hmttLd97j7Zwb551QheS+JQvKxEa5zegTP+2gDI8tPUfksQ3RiVZxjF4n8lCGjusTdzM4rP/URlVgV59hFPT9NSN8J1H8aRFRiVZxjN6L8lKnbpE8H1gC73f1PY11nDK/7ThRGvlec4RKVOMeDme1w9zW9/m5w99oh1i0DGoBfkSw2A5/tXWweZJ3lJMbz6fOYu4/8ntiDtyUy+zgqsSpOyQYz+2fgNXffm7xca4G7P3+D24zMPo5KrIpTxipIfafka0VmH0clVsWZeZk4gwd3b+WjU4jHvI6ICICZLSIxgN/HzOwLyYdLgM6hnuPuV8ysjkQh+cVkIfm966xzOfl6Ax4TEQmAkRxdFxEZE/WdRHJfRgo8IiLjzAaZtwCPDfeksRabVYAWkYAa1SXuIiKjpb6TSG4LW4Hnh9luwARRnOESlTjHzN3fA94zszvc/afZbs8YRGkfRyVWxSkTbqgj5zcoSvs4KrEqTgmLKO3jqMSqODMsI2PwiIhkgpkVAE8A1UAj8Kq792S3VSIiIiIiItmnW9iJSC55BZgFvAF8DNiU3eaIiIiIiIgEQ9gu0RLJCWZWASwDDrp7FG4VOF5uc/e/TS5vM7P6bDZGJIyUn0QkqJSfRCSogpKfQnEGj5n9xMzeMrNvXH/t3GRmlWbWkFwuMLPNZvZbM3sy220bD2ZWbmZbzWy7mb1uZoVh3a9mNh34b+BTwG/M7OawxpoBZ8zsq2a22sy+DpzJdoOuJwr7VvkpPJSfoiUK+1b5KTyUn6IlCvtW+Sk8gpSfcr7AY2YPA3nufg8w18zmZ7tN4y35gflXEreFBnga2O/uK4BHzWxq1ho3fjYA/+Tu9wN/Aj5PePfrIuAr7v4dYBuwmvDGOt4eB64ADwMXk38HlvKT8lMOUn6KCOUn5accpPwUEcpPyk85KDD5KecLPEAdH92SbzuwMntNyZgY8DkSP2yhb8y7gbuy0KZx5e4vufuO5J83A39DSPeru+9y971mdi+JKu9aQhprBsSBnuTUnfw7yOoI/75VfgrRflV+ipQ6wr9vlZ9CtF+VnyKljvDvW+WnEO3XIOWnMBR4SoA/JpcvApVZbEtGuPuVfrdCDW3MZnYPMB04RUhjBDAzI5HUWwEnxLGOs00k3p+t5MYgy6H9rqYoP4UrRlB+ipDQfldTlJ/CFSMoP0VIaL+rKcpP4YoRgpOfwlDgaQOKk8ulhCOm6wllzJYYmOpfgCcJaYwpnvD3wCHgLwhxrOPsVnf/lrtvc/fngNuy3aDrCPXneAihjFn5CQhhrBEX6s/xEEIZs/ITEMJYIy7Un+MhhDJm5SdggmMNw5u6n49OeVoMnMheUyZM6GI2s0Lg58B
X3f0kIYwxxcyeNbMvJP+cBrxASGPNgOZegyx/jeAPshzaz/EwQhez8lM4Y5Xwfo6HEbqYlZ/CGauE93M8jNDFrPyUnVjN3SfqtTLCzMqABuBXwHpgeb/T3ULDzOrdvc7MZgP/A+wkUR1c7u6x7LbuxpjZ3wHPA+8lH9oEfIUQ7tfkoGo/A4qAw8BXSVxrG7pYx1vyP4ovAp8EGoEfu3tXdls1NOUn5adco/wUHcpPyk+5RvkpOpSflJ9yTZDyU84XeCD9hq4Bdrv7n7LdnolgZreQqApuC8sXo78o7dcoxRo1Udy3yk/hEqVYoyaK+1b5KVyiFGvURHHfKj+FS7ZiDUWBR0Siwcy2uvv6bLdDREREREQkaMIwBo+IRMfvzOyvst0IERERERGRoNEZPCKSM8zsN8By4HdAO4kB61dnt1UiIiIiIiLZpzN4BDP7ppkdMbP65LTkBrdVN47NE0lz98+4e7G7fyq5rOJOyCk/iUhQKT+JSFApP0VXfrYbIIHxHXf/92w3QkRkEMpPIhJUyk8iElTKTxGkAo8MYGavAmVAJXDQ3f/BzIqAV4FbgNPAEyTOAHsVuBW4BDyW3MQaM/tWchvrgMvAz5N/twCfdfeeCQpHREJE+UlEgkr5SUSCSvkpOnSJlqR8PXUKH5AH/Ie7rwD+zMyWAV8EDrv7KuB94EngKeA9d18J/CdQk9zWPHe/F/gvYDXwSSCefGwTUDqBcYlI7lN+EpGgUn4SkaBSfoogFXgk5TvuXufudUAM2J98/BAwh8SX+O3kY3uBamABsC/52KvA/yaXf5qc/x9QCBwADpvZdmAt0JGpIEQklJSfRCSolJ9EJKiUnyJIBR4ZyqeS8yXAB0AjibsXkZw3AkeBu5OPfQ3YmFxu77etxcBv3f1+YDpQm6E2i0g0KD+JSFApP4lIUCk/RYAKPJLS+xS+zwF/aWa/BY66+7vAj4E/N7PdwHwSFd0fAXcmn3Mn8G9DbPsE8IyZvQlUAe9kMA4RCR/lJxEJKuUnEQkq5acIMnfPdhskYJKDcH3T3U9kuSkiIn0oP4lIUCk/iUhQKT9Fhwo8IiIiIiIiIiI5TpdoiYiIiIiIiIjkOBV4RERERERERERynAo8IiIiIiIiIiI5TgUeEREREREREZEcpwKPiIiIiIiIiEiOU4FHRERERERERCTH/T/A4MimUYfVJgAAAABJRU5ErkJggg==\n", 2070 | "text/plain": [ 2071 | "
" 2072 | ] 2073 | }, 2074 | "metadata": {}, 2075 | "output_type": "display_data" 2076 | }, 2077 | { 2078 | "ename": "ValueError", 2079 | "evalue": "in user code:\n\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py:1586 predict_function *\n return step_function(self, iterator)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py:1576 step_function **\n outputs = model.distribute_strategy.run(run_step, args=(data,))\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\distribute\\distribute_lib.py:1286 run\n return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\distribute\\distribute_lib.py:2849 call_for_each_replica\n return self._call_for_each_replica(fn, args, kwargs)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\distribute\\distribute_lib.py:3632 _call_for_each_replica\n return fn(*args, **kwargs)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py:1569 run_step **\n outputs = model.predict_step(data)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py:1537 predict_step\n return self(x, training=False)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\base_layer.py:1020 __call__\n input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\input_spec.py:218 assert_input_compatibility\n str(tuple(shape)))\n\n ValueError: Input 0 of layer sequential is incompatible with the layer: expected ndim=3, found ndim=2. 
Full shape received: (None, 64)\n", 2080 | "output_type": "error", 2081 | "traceback": [ 2082 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 2083 | "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", 2084 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[0mmode\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'LSTM'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mset_my_seed\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mtrain_fuc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mepochs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mhidden_dim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mhidden_dim\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", 2085 | "\u001b[1;32m\u001b[0m in \u001b[0;36mtrain_fuc\u001b[1;34m(mode, window_size, batch_size, epochs, hidden_dim, train_ratio, kernel, show_loss, show_fit)\u001b[0m\n\u001b[0;32m 32\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 33\u001b[0m \u001b[1;31m#预测\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 34\u001b[1;33m \u001b[0my_pred\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 35\u001b[0m \u001b[0my_pred\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0my_scaler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minverse_transform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 36\u001b[0m \u001b[1;31m#print(f'真实y的形状:{y_test.shape},预测y的形状:{y_pred.shape}')\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2086 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mpredict\u001b[1;34m(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[0;32m 1749\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msteps\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1750\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mon_predict_batch_begin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1751\u001b[1;33m \u001b[0mtmp_batch_outputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1752\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshould_sync\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1753\u001b[0m \u001b[0mcontext\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0masync_wait\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2087 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m 883\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 884\u001b[0m 
\u001b[1;32mwith\u001b[0m \u001b[0mOptionalXlaContext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_jit_compile\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 885\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 886\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 887\u001b[0m \u001b[0mnew_tracing_count\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2088 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m_call\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m 931\u001b[0m \u001b[1;31m# This is the first call of __call__, so we have to initialize.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 932\u001b[0m \u001b[0minitializers\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 933\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_initialize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0madd_initializers_to\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minitializers\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 934\u001b[0m \u001b[1;32mfinally\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 935\u001b[0m \u001b[1;31m# At this point we know that the initialization is complete (or 
less\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2089 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m_initialize\u001b[1;34m(self, args, kwds, add_initializers_to)\u001b[0m\n\u001b[0;32m 758\u001b[0m self._concrete_stateful_fn = (\n\u001b[0;32m 759\u001b[0m self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access\n\u001b[1;32m--> 760\u001b[1;33m *args, **kwds))\n\u001b[0m\u001b[0;32m 761\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 762\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minvalid_creator_scope\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0munused_args\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0munused_kwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2090 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_get_concrete_function_internal_garbage_collected\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 3064\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3065\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 3066\u001b[1;33m \u001b[0mgraph_function\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_maybe_define_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3067\u001b[0m \u001b[1;32mreturn\u001b[0m 
\u001b[0mgraph_function\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3068\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 2091 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_maybe_define_function\u001b[1;34m(self, args, kwargs)\u001b[0m\n\u001b[0;32m 3461\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3462\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_function_cache\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmissed\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcall_context_key\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 3463\u001b[1;33m \u001b[0mgraph_function\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_create_graph_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3464\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_function_cache\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprimary\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mcache_key\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgraph_function\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3465\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 2092 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_create_graph_function\u001b[1;34m(self, args, kwargs, override_flat_arg_shapes)\u001b[0m\n\u001b[0;32m 3306\u001b[0m \u001b[0marg_names\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0marg_names\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3307\u001b[0m \u001b[0moverride_flat_arg_shapes\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moverride_flat_arg_shapes\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 3308\u001b[1;33m 
capture_by_value=self._capture_by_value),\n\u001b[0m\u001b[0;32m 3309\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_function_attributes\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3310\u001b[0m \u001b[0mfunction_spec\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfunction_spec\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2093 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\framework\\func_graph.py\u001b[0m in \u001b[0;36mfunc_graph_from_py_func\u001b[1;34m(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)\u001b[0m\n\u001b[0;32m 1005\u001b[0m \u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moriginal_func\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_decorator\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munwrap\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpython_func\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1006\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1007\u001b[1;33m \u001b[0mfunc_outputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpython_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mfunc_args\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mfunc_kwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1008\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1009\u001b[0m \u001b[1;31m# invariant: `func_outputs` contains only Tensors, CompositeTensors,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2094 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[1;34m(*args, **kwds)\u001b[0m\n\u001b[0;32m 666\u001b[0m \u001b[1;31m# the 
function a weak reference to itself to avoid a reference cycle.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 667\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mOptionalXlaContext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcompile_with_xla\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 668\u001b[1;33m \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mweak_wrapped_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__wrapped__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 669\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 670\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 2095 | "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\framework\\func_graph.py\u001b[0m in \u001b[0;36mwrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 992\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# pylint:disable=broad-except\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 993\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"ag_error_metadata\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 994\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mag_error_metadata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto_exception\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 995\u001b[0m 
\u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 996\u001b[0m \u001b[1;32mraise\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2096 | "\u001b[1;31mValueError\u001b[0m: in user code:\n\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py:1586 predict_function *\n return step_function(self, iterator)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py:1576 step_function **\n outputs = model.distribute_strategy.run(run_step, args=(data,))\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\distribute\\distribute_lib.py:1286 run\n return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\distribute\\distribute_lib.py:2849 call_for_each_replica\n return self._call_for_each_replica(fn, args, kwargs)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\tensorflow\\python\\distribute\\distribute_lib.py:3632 _call_for_each_replica\n return fn(*args, **kwargs)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py:1569 run_step **\n outputs = model.predict_step(data)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\training.py:1537 predict_step\n return self(x, training=False)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\base_layer.py:1020 __call__\n input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)\n C:\\Users\\HUAWEI\\AppData\\Roaming\\Python\\Python36\\site-packages\\keras\\engine\\input_spec.py:218 assert_input_compatibility\n str(tuple(shape)))\n\n ValueError: Input 0 of layer sequential is incompatible with the layer: expected ndim=3, found ndim=2. 
Full shape received: (None, 64)\n" 2097 | ] 2098 | } 2099 | ], 2100 | "source": [ 2101 | "mode='LSTM' \n", 2102 | "set_my_seed()\n", 2103 | "train_fuc(mode=mode,window_size=window_size,batch_size=batch_size,epochs=epochs,hidden_dim=hidden_dim)" 2104 | ] 2105 | }, 2106 | { 2107 | "cell_type": "code", 2108 | "execution_count": null, 2109 | "metadata": {}, 2110 | "outputs": [], 2111 | "source": [] 2112 | }, 2113 | { 2114 | "cell_type": "code", 2115 | "execution_count": null, 2116 | "metadata": {}, 2117 | "outputs": [], 2118 | "source": [] 2119 | }, 2120 | { 2121 | "cell_type": "code", 2122 | "execution_count": null, 2123 | "metadata": {}, 2124 | "outputs": [], 2125 | "source": [] 2126 | }, 2127 | { 2128 | "cell_type": "code", 2129 | "execution_count": null, 2130 | "metadata": {}, 2131 | "outputs": [], 2132 | "source": [] 2133 | }, 2134 | { 2135 | "cell_type": "code", 2136 | "execution_count": null, 2137 | "metadata": {}, 2138 | "outputs": [], 2139 | "source": [] 2140 | }, 2141 | { 2142 | "cell_type": "code", 2143 | "execution_count": null, 2144 | "metadata": {}, 2145 | "outputs": [], 2146 | "source": [] 2147 | }, 2148 | { 2149 | "cell_type": "code", 2150 | "execution_count": null, 2151 | "metadata": {}, 2152 | "outputs": [], 2153 | "source": [] 2154 | }, 2155 | { 2156 | "cell_type": "code", 2157 | "execution_count": null, 2158 | "metadata": {}, 2159 | "outputs": [], 2160 | "source": [] 2161 | }, 2162 | { 2163 | "cell_type": "code", 2164 | "execution_count": null, 2165 | "metadata": {}, 2166 | "outputs": [], 2167 | "source": [ 2168 | "scaler = MinMaxScaler(feature_range=(0,1))\n", 2169 | "df_for_training_scaled = scaler.fit_transform(df_for_training)\n", 2170 | "df_for_testing_scaled = scaler.fit_transform(df_for_testing)\n", 2171 | "df_for_training_scaled[1]" 2172 | ] 2173 | }, 2174 | { 2175 | "cell_type": "code", 2176 | "execution_count": null, 2177 | "metadata": {}, 2178 | "outputs": [], 2179 | "source": [] 2180 | }, 2181 | { 2182 | "cell_type": "code", 2183 | 
"execution_count": 45, 2184 | "metadata": {}, 2185 | "outputs": [], 2186 | "source": [] 2187 | }, 2188 | { 2189 | "cell_type": "code", 2190 | "execution_count": 47, 2191 | "metadata": {}, 2192 | "outputs": [], 2193 | "source": [] 2194 | }, 2195 | { 2196 | "cell_type": "code", 2197 | "execution_count": 49, 2198 | "metadata": {}, 2199 | "outputs": [ 2200 | { 2201 | "data": { 2202 | "text/plain": [ 2203 | "array([[-1.72177148, 0.39968477, -0.28679889, 1.33915253],\n", 2204 | " [-1.70115147, 0.62927986, -0.06876642, 1.46712657],\n", 2205 | " [-1.68053145, 0.82593311, -0.18622368, 1.46993546],\n", 2206 | " [-1.65991143, 0.82376875, -0.36872002, 1.47321027],\n", 2207 | " [-1.63929141, 0.78322648, -0.55769281, 1.46998871],\n", 2208 | " [-1.61867139, 0.75231784, -0.48378316, 1.47118015],\n", 2209 | " [-1.59805138, 0.74282674, -0.41785706, 1.47092389],\n", 2210 | " [-1.57743136, 1.14456278, -0.51760631, 1.47888129],\n", 2211 | " [-1.55681134, 1.10785623, -0.60932259, 1.47295068],\n", 2212 | " [-1.53619132, 1.05484 , -0.70820023, 1.47492423],\n", 2213 | " [-1.51557131, 0.99619855, -0.8788879 , 1.47575957],\n", 2214 | " [-1.49495129, 1.31698373, -0.58375527, 1.41470957],\n", 2215 | " [-1.47433127, 1.26558002, -0.64917032, 1.41663652],\n", 2216 | " [-1.45371125, 1.19234073, -0.87889248, 1.35256796],\n", 2217 | " [-1.43309123, 1.40854634, -1.04004974, 1.35257129],\n", 2218 | " [-1.41247122, 1.3688687 , -1.20043135, 1.34757919],\n", 2219 | " [-1.3918512 , 1.33484593, -1.18608425, 1.28944124],\n", 2220 | " [-1.37123118, 1.29735956, -1.15278445, 1.28928149],\n", 2221 | " [-1.35061116, 1.2595977 , -1.15102626, 1.28777721],\n", 2222 | " [-1.32999115, 1.08033241, -0.57325752, 1.54091641],\n", 2223 | " [-1.30937113, 1.09655479, -0.10439337, 1.60576373],\n", 2224 | " [-1.28875111, 1.22610805, -0.64279837, 1.54382846],\n", 2225 | " [-1.26813109, 1.36757623, -1.12006049, 1.47888129],\n", 2226 | " [-1.24751107, 1.22049209, -1.25338575, 1.41367121],\n", 2227 | " [-1.22689106, 1.28373098, 
-1.41602512, 1.41564475],\n", 2228 | " [-1.20627104, 1.45011565, -1.5467097 , 1.35225845],\n", 2229 | " [-1.18565102, 1.31173818, -1.67789263, 1.41564475],\n", 2230 | " [-1.165031 , 1.22730922, -1.72232625, 1.34752594],\n", 2231 | " [-1.14441099, 1.39917637, -1.68573704, 1.34726635],\n", 2232 | " [-1.12379097, 1.29802789, -1.66828566, 1.29136486]])" 2233 | ] 2234 | }, 2235 | "execution_count": 49, 2236 | "metadata": {}, 2237 | "output_type": "execute_result" 2238 | } 2239 | ], 2240 | "source": [] 2241 | }, 2242 | { 2243 | "cell_type": "code", 2244 | "execution_count": null, 2245 | "metadata": {}, 2246 | "outputs": [], 2247 | "source": [] 2248 | }, 2249 | { 2250 | "cell_type": "code", 2251 | "execution_count": null, 2252 | "metadata": {}, 2253 | "outputs": [], 2254 | "source": [] 2255 | }, 2256 | { 2257 | "cell_type": "code", 2258 | "execution_count": null, 2259 | "metadata": {}, 2260 | "outputs": [], 2261 | "source": [] 2262 | }, 2263 | { 2264 | "cell_type": "code", 2265 | "execution_count": null, 2266 | "metadata": {}, 2267 | "outputs": [], 2268 | "source": [] 2269 | }, 2270 | { 2271 | "cell_type": "code", 2272 | "execution_count": null, 2273 | "metadata": {}, 2274 | "outputs": [], 2275 | "source": [] 2276 | }, 2277 | { 2278 | "cell_type": "code", 2279 | "execution_count": null, 2280 | "metadata": {}, 2281 | "outputs": [], 2282 | "source": [] 2283 | }, 2284 | { 2285 | "cell_type": "code", 2286 | "execution_count": null, 2287 | "metadata": {}, 2288 | "outputs": [], 2289 | "source": [] 2290 | }, 2291 | { 2292 | "cell_type": "markdown", 2293 | "metadata": {}, 2294 | "source": [ 2295 | "定义随机种子和评估函数" 2296 | ] 2297 | }, 2298 | { 2299 | "cell_type": "code", 2300 | "execution_count": 32, 2301 | "metadata": {}, 2302 | "outputs": [], 2303 | "source": [ 2304 | "def set_my_seed():\n", "    # Seed the Python, NumPy and TensorFlow RNGs so training runs are reproducible.\n", 2305 | "    # NOTE: setting PYTHONHASHSEED here only affects interpreters started afterwards (e.g. 
subprocesses);\n", "    # the running kernel's hash seed was already fixed at startup.\n", "    os.environ['PYTHONHASHSEED'] = '0'  # disable hash randomization for reproducibility\n", 2306 | "    np.random.seed(1)  # seed NumPy's global RNG\n", 2307 | "    rn.seed(12345)  # seed Python's built-in RNG (rn is presumably the `random` module)\n", 
2308 | "    tf.random.set_seed(123)  # seed TensorFlow's global RNG\n", 2309 | "\n", 2310 | "def evaluation(y_test, y_predict):\n", "    # Returns (MAE, RMSE, MAPE, R^2); MSE is computed only to derive RMSE and is not returned.\n", 2311 | "    mae = mean_absolute_error(y_test, y_predict)\n", 2312 | "    mse = mean_squared_error(y_test, y_predict)\n", 2313 | "    rmse = np.sqrt(mse)  # reuse mse instead of computing it a second time\n", 2314 | "    # Flatten both inputs so an (n, 1) prediction and an (n,) target are compared element-wise\n", "    # instead of broadcasting to an (n, n) matrix. MAPE is inf/nan if any y_test value is 0.\n", "    y_true_flat = np.ravel(y_test)\n", "    y_pred_flat = np.ravel(y_predict)\n", "    mape = np.mean(np.abs((y_pred_flat - y_true_flat) / y_true_flat))\n", 2315 | "    r_2 = r2_score(y_test, y_predict)\n", 2316 | "    return mae, rmse, mape, r_2  # mse" 2317 | ] 2318 | }, 2319 | { 2320 | "cell_type": "markdown", 2321 | "metadata": {}, 2322 | "source": [ 2323 | "构建序列数据的训练集和测试集" 2324 | ] 2325 | }, 2326 | { 2327 | "cell_type": "code", 2328 | "execution_count": 33, 2329 | "metadata": {}, 2330 | "outputs": [], 2331 | "source": [ 2332 | "def build_sequences(text, window_size=24):\n", 2333 | "    # text: indexable sequence (e.g. capacity readings or feature rows).\n", "    # Each window of window_size consecutive items is paired with the item right after it;\n", "    # returns empty arrays when len(text) <= window_size.\n", 2334 | "    x, y = [], []\n", 2335 | "    for i in range(len(text) - window_size):\n", 2336 | "        sequence = text[i:i+window_size]\n", 2337 | "        target = text[i+window_size]\n", 2338 | "        x.append(sequence)\n", 2339 | "        y.append(target)\n", 2340 | "    return np.array(x), np.array(y)\n" 2341 | ] 2342 | }, 2343 | { 2344 | "cell_type": "code", 2345 | "execution_count": 34, 2346 | "metadata": {}, 2347 | "outputs": [ 2348 | { 2349 | "ename": "NameError", 2350 | "evalue": "name 'data0' is not defined", 2351 | 
"output_type": "error", 2352 | "traceback": [ 2353 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 2354 | "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", 2355 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mdef\u001b[0m 
\u001b[0mget_traintest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtrain_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m24\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mtrain\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mtrain_size\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mtest\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtrain_size\u001b[0m\u001b[1;33m-\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my_train\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbuild_sequences\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mX_test\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my_test\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbuild_sequences\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 2356 | "\u001b[1;31mNameError\u001b[0m: name 'data0' is not defined" 2357 | ] 2358 | } 2359 | ], 2360 | "source": [ 2361 | "def get_traintest(data, train_size=None, window_size=24):\n", "    # Split data at train_size (train first, test after) and window both parts with build_sequences.\n", "    # train_size=None 
means len(data0), looked up at call time; the previous default len(data0)\n", "    # was evaluated when the def statement ran and raised NameError because data0 did not exist yet.\n", "    if train_size is None:\n", "        train_size = len(data0)\n", "    if train_size <= window_size:\n", "        # Otherwise the training set has no windows and the test slice start below goes negative/wraps.\n", "        raise ValueError('train_size must be greater than window_size')\n", 2362 | "    train = data[:train_size]\n", 2363 | "    # Start window_size rows early so the first test target is data[train_size].\n", "    test = data[train_size - window_size:]\n", 2364 | "    X_train, y_train = build_sequences(train, window_size=window_size)\n", 2365 | "    X_test, y_test = build_sequences(test, window_size=window_size)\n",
2366 | "    # Keep only the last column of each target row (data is presumably 2-D with the label last; confirm).\n", "    return X_train, y_train[:, -1], X_test, y_test[:, -1]\n" 2367 | ] 2368 | }, 2369 | { 2370 | "cell_type": "code", 2371 | "execution_count": null, 2372 | "metadata": {}, 2373 | "outputs": [], 2374 | "source": [] 2375 | } 2376 | ], 2377 | "metadata": { 2378 | "kernelspec": { 2379 | "display_name": "Python 3", 2380 | "language": "python", 2381 | "name": "python3" 2382 | }, 2383 | "language_info": { 2384 | "codemirror_mode": { 2385 | "name": "ipython", 2386 | "version": 3 2387 | }, 2388 | "file_extension": ".py", 2389 | "mimetype": "text/x-python", 2390 | "name": "python", 2391 | "nbconvert_exporter": "python", 2392 | "pygments_lexer": "ipython3", 2393 | "version": "3.6.5" 2394 | } 2395 | }, 2396 | "nbformat": 4, 2397 | "nbformat_minor": 2 2398 | } 2399 | --------------------------------------------------------------------------------