├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── .gitignore ├── data └── demo.csv ├── demo.py ├── doc └── results.png └── rnn.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | .idea/ 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jianfeng Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Water-Table-Depth-Prediction-Pytorch 2 | 3 | ### Introduction 4 | This is a PyTorch implementation of our work *Developing a Long Short-Term Memory (LSTM) based Model for Predicting Water Table Depth in Agricultural Areas*.[[Paper](https://www.sciencedirect.com/science/article/pii/S0022169418303184)] 5 | 6 | ### Requirements 7 | ``` 8 | Python3.x 9 | pytorch>=0.4.0 10 | numpy>=1.14.0 11 | pandas>=0.22.0 12 | scikit-learn>=0.14 13 | ``` 14 | ### Installation 15 | The code was tested with Python 3.5. To use this code, please do: 16 | 17 | 18 | 0. Clone the repo: 19 | ```Shell 20 | git clone https://github.com/jfzhang95/Water-Table-Depth-Prediction-PyTorch 21 | cd Water-Table-Depth-Prediction-PyTorch 22 | ``` 23 | 24 | 1. Install dependencies: 25 | ```Shell 26 | pip install matplotlib numpy pandas scikit-learn 27 | ``` 28 |    For pytorch installation, please see in [PyTorch.org](https://pytorch.org/). 29 | 30 | 2. To try the demo code, please run: 31 | ```Shell 32 | python demo.py 33 | ``` 34 | 35 | If installed correctly, the result should look like this: 36 | ![results](doc/results.png) 37 | 38 | Noted that the demo data ([demo.csv](https://github.com/jfzhang95/Water-Table-Depth-Prediction-PyTorch/blob/master/data/demo.csv)) are processed manually, so they are not real data, but they still can reflect the correlation between the original data. 
39 | 40 | ### Citation 41 | If you use this code, please consider citing our paper: 42 | 43 | @article{zjf18, 44 | journal = {Journal of Hydrology}, 45 | title = {Developing a Long Short-Term Memory (LSTM) based Model for Predicting Water Table Depth in Agricultural Areas}, 46 |  author         = {Jianfeng Zhang, Yan Zhu, Xiaoping Zhang, Ming Ye and Jinzhong Yang}, 47 | year = {2018}, 48 | volume = {561}, 49 | pages = {918-929} 50 | } 51 | 52 | 53 | ### License 54 | [MIT](https://github.com/jfzhang95/Water-Table-Depth-Prediction-PyTorch/blob/master/LICENSE) 55 | 56 | -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /data/demo.csv: -------------------------------------------------------------------------------- 1 | Year,Month,Irrigation,Rainfall,Tem,Evaporation,Depth 2 | 2000,1,1992.5083196719,314.1957169733,-19.295407758,12040.4476394335,8.972728964 3 | 2000,2,-2042.2517088609,-1633.7667748126,-38.4498364209,18045.5168256601,8.8372379517 4 | 2000,3,-1374.45069054,868.2637115245,-36.9623765139,90248.5986880446,10.7168385148 5 | 2000,4,29951.3265990933,309.8571424647,-13.1656454124,141672.950996956,8.904409599 6 | 2000,5,43748.3711548337,-291.3250047579,19.6021026569,188789.213347986,6.8581773986 7 | 2000,6,34502.6486603004,17799.0968336829,57.7370337131,149223.642080642,5.2845146329 8 | 2000,7,30296.9873863791,14179.7310546686,91.1360627559,173045.99634606,6.0969095283 9 | 2000,8,19459.0051582282,29671.9912945174,106.0619435675,125595.394087347,7.8860077628 10 | 2000,9,17758.6975818984,21215.1823079736,110.7978477205,93648.5653302581,7.9172426387 11 | 2000,10,62539.8414178833,4294.4213316051,86.9145269101,55423.3544036544,7.5450815888 12 | 2000,11,235.967332703,529.2638159455,59.7571213723,32169.4952479278,4.4325341631 13 | 
2000,12,2605.7856577036,738.6209819721,14.319977756,20927.4646969387,7.3245968878 14 | 2001,1,-1446.3754864288,457.590908923,-11.6944352737,16903.9008100765,7.5467988992 15 | 2001,2,1994.960901874,78.495377589,-32.9582290949,43792.8401055522,8.8359511074 16 | 2001,3,-267.3643645519,-577.8149148489,-44.3692361598,83874.316059848,10.2077398958 17 | 2001,4,26117.1536851372,156.3112109546,-5.7425824914,131993.693718587,8.4048502044 18 | 2001,5,44438.8272441019,7482.3624327083,34.8285539136,187135.718846405,6.6102531838 19 | 2001,6,40444.5063014989,1747.6364483683,68.8225044018,184055.273477082,6.0819109676 20 | 2001,7,40209.0405479997,8124.129399824,78.5143052622,188270.500152369,7.3103318927 21 | 2001,8,23856.4602065489,26413.0362767538,100.759263918,158303.293193421,6.9534469533 22 | 2001,9,15448.5993112436,41288.3846495247,90.3775469511,90812.4456162747,7.2820627349 23 | 2001,10,64213.7344128649,36381.0620477837,77.8568838742,65292.5713181995,7.7141083027 24 | 2001,11,-1613.0641490158,382.1082475746,48.0512566729,41973.584633257,7.3394374677 25 | 2001,12,-600.5262692763,165.7094367505,19.7482066789,9449.1667292834,7.2133820599 26 | 2002,1,1014.7620170268,-214.1443198758,-15.5832115562,38770.3662096595,8.7714648667 27 | 2002,2,-756.8050526192,769.3649758967,-24.0603398165,38865.7208288432,9.0372763717 28 | 2002,3,-2182.7253705371,3454.3186301167,-2.0376116995,80372.9199556001,9.3197329858 29 | 2002,4,25396.1458187972,11157.5283660833,14.6646376991,136223.773754898,6.0079325001 30 | 2002,5,37102.9616992109,24694.2961065592,13.4495850367,131044.917270516,4.2630808858 31 | 2002,6,40234.9691689304,53415.0948522383,38.3047052931,142136.464950778,4.6419067361 32 | 2002,7,42739.5324691317,16110.2434594719,95.6419438562,140852.810807609,5.0832563496 33 | 2002,8,24747.4193208903,3222.8981484282,102.8578298589,144341.865884577,6.0206484119 34 | 2002,9,14298.2763540136,6217.5868866187,98.773476424,93394.0088531071,8.4178573982 35 | 
2002,10,61894.0851518047,-190.0557246785,91.4465661896,70301.2762720597,6.7635269011 36 | 2002,11,-865.5967647891,-1050.5053853096,67.8969387849,29663.145500379,5.868084999 37 | 2002,12,3380.2264775865,2024.1566560343,26.4162820183,10924.8154850659,6.5219879718 38 | 2003,1,-386.4503653792,896.1226862132,-12.3658294534,24884.9389507601,9.100252177 39 | 2003,2,1557.7950405133,241.0719972437,-29.0865412199,30745.2287503802,9.1648896612 40 | 2003,3,650.3164758386,2658.4287558479,-28.8135959502,55239.448750046,9.718686127 41 | 2003,4,13394.0070559086,7004.2807028833,0.5013660454,137237.861044995,7.203485342 42 | 2003,5,33911.9167946654,10094.1076700016,43.8572289702,169825.01753652,6.3415744691 43 | 2003,6,21533.3530476802,1801.0880057285,72.3854747144,157660.469586514,6.1825390601 44 | 2003,7,37725.0376329477,41187.2889160595,84.1265373016,143046.516829244,7.5462799977 45 | 2003,8,13699.7352306785,10833.4909793159,89.5893857661,141965.598874551,8.2421210066 46 | 2003,9,9978.0099721629,71274.7476319909,93.952103771,107087.085814857,9.91492158 47 | 2003,10,57565.7322977106,2245.7232658376,84.7504765573,73721.7175061161,6.8986012184 48 | 2003,11,2104.3251052962,-688.0205005814,47.4993730858,42009.9453231635,6.1908231216 49 | 2003,12,-3800.045422625,-675.1317978651,3.8680187888,22007.6853877412,6.7358412773 50 | 2004,1,2086.2142577479,1014.0117373894,-18.9368128305,40137.6946586243,9.0026294949 51 | 2004,2,2602.0320557557,64.2927011031,-28.6506700652,45511.5986683673,8.9124447037 52 | 2004,3,2607.4575419277,2920.8152751269,4.3002725762,86881.0671414217,8.3773358986 53 | 2004,4,30247.7097798502,11588.2504945039,5.1650030987,126272.358305815,7.5391370342 54 | 2004,5,36500.2164145915,24567.0820912928,17.523411113,137497.901388769,5.5749822377 55 | 2004,6,40216.9582150654,53622.3914122528,31.8578154274,159777.477389197,4.7412483249 56 | 2004,7,44234.0442499885,16017.7407660594,87.9881978603,142250.302255822,5.7880415424 57 | 
2004,8,26429.4837362526,3032.1676788145,109.5552108576,150812.144558594,5.6830398478 58 | 2004,9,17653.4709589239,7921.8204937788,95.1165548783,88714.1885122467,8.4739868368 59 | 2004,10,61542.4921065865,-1956.0525356638,83.2737058279,69006.5725073486,7.4165250571 60 | 2004,11,36.6134202471,375.9965101976,70.5831069354,50624.7207739744,4.8386637862 61 | 2004,12,566.8345209991,2295.0177818814,24.7191836218,25473.7540558068,7.2194637112 62 | 2005,1,-3822.7516415472,-269.118333342,-11.1668030912,34991.0142469635,10.0489238052 63 | 2005,2,-478.5741823872,-740.4862052638,-41.8258276737,36191.346610205,9.1174074249 64 | 2005,3,2799.0550332015,476.4050785832,-44.2945982171,94786.7810525021,9.4844878526 65 | 2005,4,29287.1921667186,1585.6983177999,-6.1786769353,144571.282654031,8.9805315125 66 | 2005,5,43635.8172255169,5385.6761171638,27.4260701646,179526.719378807,4.9065506861 67 | 2005,6,38767.9532063676,3074.1477008705,67.4911123598,200064.14810238,5.7893406162 68 | 2005,7,36898.2934518845,7315.0385396565,75.3836548388,183654.375943701,7.2543564682 69 | 2005,8,27775.4765135801,27297.3044460976,89.5734226638,151187.565655289,8.9121817993 70 | 2005,9,14098.3139874008,40482.8112304931,96.8313331372,85210.057303944,7.3927618458 71 | 2005,10,58558.6405662779,36067.9426784126,77.5209698405,55957.1688862193,7.396345879 72 | 2005,11,-1526.4067351494,-500.7574263379,47.9846285858,27459.3562322097,4.8347087281 73 | 2005,12,-2384.5975872833,-955.6499553698,16.434862517,5253.6811570786,7.265792525 74 | 2006,1,1567.0140135536,1151.0120327973,-20.7764078078,15208.9585318538,9.213940087 75 | 2006,2,2913.6113992749,2237.2749740479,-34.8211671529,41474.5565895835,9.4822795584 76 | 2006,3,1364.2481336267,-539.7593045133,-4.7837468176,70721.6145288348,10.4681292384 77 | 2006,4,19901.1900571259,401.7360131679,-7.4301816306,134583.034034957,8.7084561594 78 | 2006,5,47418.4063191555,17712.058979923,16.8583771074,178743.129218957,5.8765688575 79 | 
2006,6,35118.3770144947,581.5465176179,71.890703121,166218.51424835,5.4374069985 80 | 2006,7,38639.1467694101,21044.8809130651,96.1165280213,147054.737379063,6.2034790388 81 | 2006,8,20824.3935742376,17345.5866955182,106.6081167304,118437.967687395,6.5392017222 82 | 2006,9,16592.3708190747,9395.8290645353,103.4509679917,100884.217353987,9.1829937693 83 | 2006,10,59893.6372421636,1075.9417619642,85.0284281428,88809.9542884777,8.1613393083 84 | 2006,11,450.6792621476,3993.0308642043,64.6761010142,54368.3063440379,4.5093902331 85 | 2006,12,1636.6951413217,-269.1058091842,21.7779008449,9427.0795085631,6.3615198295 86 | 2007,1,-1949.3823617708,-613.7695950016,-9.9274563559,11037.3165455179,9.7072859282 87 | 2007,2,-1102.7275340532,6915.6950115401,-41.5785918665,38921.433102614,9.3888818492 88 | 2007,3,-1328.6199157601,15140.5571259182,-26.1639986773,58800.46085095,8.7392711075 89 | 2007,4,22066.1764420817,4611.0302377563,-6.8326942472,131650.151883966,7.8428531611 90 | 2007,5,48294.8972788041,557.6491793652,14.0424328046,192546.753305473,7.6120575187 91 | 2007,6,38201.0941112011,61360.2621782378,54.4707773583,159112.097351731,6.5865266032 92 | 2007,7,43577.9169831785,23377.9328093385,79.0143582462,135487.718649246,5.4125839723 93 | 2007,8,22387.8184977697,17639.4606527731,95.2976035666,138453.329881849,8.0884594137 94 | 2007,9,16455.9120262078,8327.9144964874,103.6746016514,100811.637258969,7.2745707204 95 | 2007,10,69045.5356580916,12600.584343208,93.5948080771,71999.6992130972,6.7300876177 96 | 2007,11,1764.920909494,1407.5788180452,60.1161104802,39895.8722576806,3.9499619889 97 | 2007,12,-1328.6644638271,-210.5224779435,3.1001191414,30203.3060668208,6.2708627671 98 | 2008,1,-1350.0320982345,3249.250766653,-9.5890238558,21.3084869433,7.3096706004 99 | 2008,2,-1786.7360085808,273.5099143006,-32.9446339333,17702.7933414733,8.2673016384 100 | 2008,3,620.9043586546,-270.7739529231,-39.6899915103,73607.5905421309,9.5670900484 101 | 
2008,4,20856.1771359163,3146.4328390487,-11.1111667718,115815.870531152,8.5797313863 102 | 2008,5,48507.6564602097,4603.5040817133,38.2975783097,184503.383801004,5.4766284548 103 | 2008,6,32320.2718047226,31213.3180879014,66.7717880508,151502.754940086,5.273589638 104 | 2008,7,37320.3390553108,50614.6737924457,85.0058001322,144649.351474712,6.95911323 105 | 2008,8,18199.504489151,48386.5761062642,90.8480076578,129316.651920114,6.708450655 106 | 2008,9,20391.6375385559,10580.62437324,92.1797829335,108635.638152625,7.3945263069 107 | 2008,10,58444.2146693381,2099.3563299519,78.6830414204,94122.764504983,6.8888233428 108 | 2008,11,3185.2189538728,1121.2413835035,78.3066422298,60397.5914525808,5.6778711418 109 | 2008,12,1817.7912010928,-305.9038928113,34.4921687043,33321.8397371101,6.8578472569 110 | 2009,1,-1239.4607399782,447.1989880709,-4.2538746605,12771.0282184424,7.9649197838 111 | 2009,2,1384.5773495339,2285.6514062288,-35.3629320795,35150.6137938001,8.5362422017 112 | 2009,3,-3694.4681882967,3538.3493277784,-23.9588452957,52207.7113327517,8.3072500889 113 | 2009,4,20362.1631904287,3023.1583772773,1.6637355677,117521.497351571,7.0187079798 114 | 2009,5,50593.6620445755,10397.6296840017,44.4562936861,178431.410097219,4.5029332386 115 | 2009,6,39496.6966297038,1103.6091603384,68.4440664534,199772.892337299,6.2849465083 116 | 2009,7,43063.2203963758,16971.7448318511,89.502913534,188137.937088275,6.9081535976 117 | 2009,8,22179.077838634,28799.5500957659,108.6809709063,156495.206904778,6.5459771819 118 | 2009,9,20277.105811496,29843.8851236642,89.2721959082,100934.093582542,7.8622919891 119 | 2009,10,67470.522894461,554.6146432594,84.5529362875,117297.542083953,7.6988466176 120 | 2009,11,5358.255415246,5440.5919569377,69.1393389741,23380.7116971716,5.1541271415 121 | 2009,12,-210.6269713507,-360.3685495956,-4.8166265881,24665.9656176214,6.5616755275 122 | 2010,1,-3342.697254843,1284.8149484832,-37.4415721765,30241.1465269572,8.2024262882 123 | 
2010,2,-452.764277594,744.177282546,-57.7914762143,49211.2689722823,9.4745240103 124 | 2010,3,909.9920247126,3338.1093007488,-43.4529371602,66084.8092526247,8.7060000246 125 | 2010,4,17889.3688406509,1543.3575281812,-33.6275363251,83724.4603660708,8.0760718308 126 | 2010,5,45018.9569561923,34561.0464100109,21.8435753334,163531.079424784,4.6677119702 127 | 2010,6,29744.5416891046,9370.3123240047,49.9679099832,196893.20739719,4.7737786968 128 | 2010,7,37866.7682664827,6363.628313858,70.7218208922,190326.285888571,7.0881596334 129 | 2010,8,14851.0026421879,2103.2241861837,97.4570630129,150863.734618918,8.3356989488 130 | 2010,9,14391.0277938624,40134.8679810422,102.4820452202,89967.6839468569,8.8852705717 131 | 2010,10,62627.9895315413,1867.1358941553,83.3924029748,60433.5690272789,8.7453597732 132 | 2010,11,1415.8603595977,339.5991879025,54.6712121297,47464.8595064449,5.319094786 133 | 2010,12,-696.4824898257,-1566.5294139293,9.8216238936,36627.2931313552,6.308273918 134 | 2011,1,-983.416942065,2492.0253975684,-17.7596422889,14397.6728898769,7.6475898738 135 | 2011,2,-2008.67431919,526.2270719924,-24.4021423157,15744.9403465722,9.1615206376 136 | 2011,3,-2315.468761431,-705.1343620344,-52.9468240718,52551.880586609,10.0513610431 137 | 2011,4,16535.2059893341,387.8472438606,-15.3226199683,137014.759041677,7.8917952873 138 | 2011,5,46778.498956385,2735.0318816074,18.1755182623,189061.38480658,7.202117884 139 | 2011,6,33916.0645445093,9109.9666718717,68.5778968532,191397.405111407,7.3181654511 140 | 2011,7,40324.6320967434,12553.1054928233,94.705649525,195249.234866651,6.473298021 141 | 2011,8,15886.3752995587,12229.2728686363,108.2367538651,169191.907487504,7.3519418885 142 | 2011,9,17289.7121073133,3350.8047659506,92.9929231079,121092.696797974,7.7416124147 143 | 2011,10,64553.8805496607,5184.2156795966,73.9472321028,74807.3549096308,8.2582967962 144 | 2011,11,3099.6128648317,-693.7528249307,56.336108519,47013.9946411057,3.9749151445 145 | 
2011,12,-1482.5665107684,-1490.6920639721,27.298843012,18434.8071640145,6.2398553094 146 | 2012,1,2648.9282565448,977.2033118444,-14.9177599049,10557.8926792309,9.1600318335 147 | 2012,2,-1395.3402643954,-14.47669025,-48.1206666417,32169.1793952093,8.1842023323 148 | 2012,3,744.8559974669,86.2927021337,-36.4061958197,73041.489636775,9.7357233547 149 | 2012,4,20946.7633883001,-512.8763798724,-1.7524697111,152460.562126691,8.0061040867 150 | 2012,5,44585.2222314331,15626.91276034,41.2707165624,173275.839222574,5.589704837 151 | 2012,6,33854.5303935094,30466.5672038557,60.4637337002,178672.678043996,4.6394531333 152 | 2012,7,40119.822845255,74516.3708826216,77.479703859,164117.527661127,5.7473437607 153 | 2012,8,23447.5727963349,29570.8960473694,90.6072636124,145542.540005518,6.4308159126 154 | 2012,9,15075.6051431687,14204.7957822182,88.7135554338,88696.293001768,8.7294292494 155 | 2012,10,62843.3572745977,-404.3974302493,80.6597494508,98133.7890383517,7.9329949439 156 | 2012,11,1214.3119584021,-446.4124345215,39.9731689583,40253.0140927485,5.793181429 157 | 2012,12,891.0915264901,-374.2199526424,3.2885555308,15117.6546529998,5.420527625 158 | 2013,1,-2342.067247293,-407.3789982866,-6.6294774919,14439.2247424584,8.2856885456 159 | 2013,2,-1218.9978343601,-889.7145877889,-44.1710267156,26722.278465606,9.032417103 160 | 2013,3,-1183.3063329005,479.3456347608,-29.0699469078,79055.2834668528,9.1947298699 161 | 2013,4,27964.0020483829,1246.1037720473,-0.9600276419,118632.937151314,8.9618038092 162 | 2013,5,48357.5892469551,3891.6885537972,56.2357845017,205855.274427051,5.1941795566 163 | 2013,6,37359.7170253037,22072.5297461547,69.3827486111,152549.568057488,6.7913924531 164 | 2013,7,41337.5757739085,32804.2742369428,72.5544526249,161613.873657892,6.8242571638 165 | 2013,8,19170.8498326554,1711.2395646448,92.1502631409,133324.710105294,5.3346744476 166 | 2013,9,17390.1471218854,14636.8455336336,96.1909503257,99089.0072254957,8.9661012989 167 | 
2013,10,68788.714885654,-251.937762411,87.4205022764,81966.0912259647,8.8800850284 168 | 2013,11,-1052.8417731971,-243.5652595827,58.5260443976,31874.8378428709,4.6790693867 169 | 2013,12,510.8054866447,1294.5554983925,18.4621176651,11562.8268159789,5.9895643152 170 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from sklearn.preprocessing import StandardScaler 3 | from sklearn.metrics import r2_score, mean_squared_error 4 | import matplotlib.pyplot as plt 5 | from rnn import RNN 6 | import numpy as np 7 | from torch import nn 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | 12 | ss_X_dep = StandardScaler() 13 | ss_y_dep = StandardScaler() 14 | 15 | def rmse(y1, y2): 16 | return np.sqrt(mean_squared_error(y1, y2)) 17 | 18 | # Noted that the demo data are processed manually, so they are not real data, 19 | # but they still can reflect the correlation between the original data. 
20 | data = pd.read_csv('data/demo.csv') 21 | 22 | Inputs = data.drop('Year', axis=1).drop('Depth', axis=1) 23 | Outputs = data['Depth'] 24 | 25 | Inputs = Inputs.as_matrix() 26 | Outputs = Outputs.as_matrix().reshape(-1, 1) 27 | 28 | # First 12 years of data 29 | X_train_dep = Inputs[0:144] 30 | y_train_dep = Outputs[0:144] 31 | 32 | # Last 2 years of data 33 | X_test_dep = Inputs[144:] 34 | 35 | print("X_train_dep shape", X_train_dep.shape) 36 | print("y_train_dep shape", y_train_dep.shape) 37 | print("X_test_dep shape", X_test_dep.shape) 38 | 39 | X = np.concatenate([X_train_dep, X_test_dep], axis=0) 40 | 41 | # Standardization 42 | X = ss_X_dep.fit_transform(X) 43 | 44 | # First 12 years of data 45 | X_train_dep_std = X[0:144] 46 | y_train_dep_std = ss_y_dep.fit_transform(y_train_dep) 47 | 48 | # All 14 years of data 49 | X_test_dep_std = X 50 | X_train_dep_std = np.expand_dims(X_train_dep_std, axis=0) 51 | y_train_dep_std = np.expand_dims(y_train_dep_std, axis=0) 52 | X_test_dep_std = np.expand_dims(X_test_dep_std, axis=0) 53 | 54 | # Transfer to Pytorch Variable 55 | X_train_dep_std = Variable(torch.from_numpy(X_train_dep_std).float()) 56 | y_train_dep_std = Variable(torch.from_numpy(y_train_dep_std).float()) 57 | X_test_dep_std = Variable(torch.from_numpy(X_test_dep_std).float()) 58 | 59 | # Define rnn model 60 | # You can also choose rnn_type as 'rnn' or 'gru' 61 | model = RNN(input_size=5, hidden_size=40, num_layers=1, class_size=1, dropout=0.5, rnn_type='lstm') 62 | # Define optimization function 63 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # optimize all rnn parameters 64 | # Define loss function 65 | loss_func = nn.MSELoss() 66 | 67 | # Start training 68 | for iter in range(20000+1): 69 | model.train() 70 | prediction = model(X_train_dep_std) 71 | loss = loss_func(prediction, y_train_dep_std) 72 | optimizer.zero_grad() # clear gradients for this training step 73 | loss.backward() # back propagation, compute gradients 74 | 
optimizer.step() 75 | if iter % 100 == 0: 76 | print("iteration: %s, loss: %s" % (iter, loss.item())) 77 | 78 | # Save model 79 | save_filename = 'checkpoints/LSTM_FC.pth' 80 | torch.save(model, save_filename) 81 | print('Saved as %s' % save_filename) 82 | 83 | # Start evaluating model 84 | model.eval() 85 | 86 | y_pred_dep_ = model(X_test_dep_std).detach().numpy() 87 | y_pred_dep = ss_y_dep.inverse_transform(y_pred_dep_[0, 144:]) 88 | 89 | print('the value of R-squared of Evaporation is ', r2_score(Outputs[144:], y_pred_dep)) 90 | print('the value of Root mean squared error of Evaporation is ', rmse(Outputs[144:], y_pred_dep)) 91 | 92 | f, ax1 = plt.subplots(1, 1, sharex=True, figsize=(6, 4)) 93 | 94 | ax1.plot(Outputs[144:], color="blue", linestyle="-", linewidth=1.5, label="Measurements") 95 | ax1.plot(y_pred_dep, color="green", linestyle="--", linewidth=1.5, label="Proposed model") 96 | 97 | plt.legend(loc='upper right') 98 | plt.xticks(fontsize=8,fontweight='normal') 99 | plt.yticks(fontsize=8,fontweight='normal') 100 | plt.xlabel('Time (Month)', fontsize=10) 101 | plt.ylabel('Water table depth (m)', fontsize=10) 102 | plt.xlim(0, 25) 103 | plt.savefig('results.png', format='png') 104 | plt.show() 105 | 106 | 107 | ##### Loading Model ##### 108 | model = torch.load('checkpoints/LSTM_FC.pth') 109 | model.eval() 110 | y_pred_dep_ = model(X_test_dep_std).detach().numpy() 111 | y_pred_dep = ss_y_dep.inverse_transform(y_pred_dep_[0, 144:]) 112 | 113 | print('the value of R-squared of Evaporation is ', r2_score(Outputs[144:], y_pred_dep)) 114 | print('the value of Root mean squared error of Evaporation is ', rmse(Outputs[144:], y_pred_dep)) 115 | -------------------------------------------------------------------------------- /doc/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jfzhang95/Water-Table-Depth-Prediction-PyTorch/85932d33ae30b833ca5148764fed2430c6a07a7b/doc/results.png 
# --------------------------------------------------------------------------------
# /rnn.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn


class RNN(nn.Module):
    """A recurrent network (LSTM, GRU or vanilla RNN) followed by a
    fully-connected layer applied at every time step.

    Args:
        input_size: number of input features per time step.
        hidden_size: number of recurrent hidden units.
        num_layers: number of stacked recurrent layers.
        class_size: output dimension of the final FC layer.
        dropout: dropout probability applied between the recurrent output
            and the FC layer (active in train mode only).
        rnn_type: one of 'lstm', 'rnn' or 'gru'.

    Raises:
        NotImplementedError: if rnn_type is not one of the three choices.
    """

    def __init__(self, input_size, hidden_size, num_layers, class_size, dropout=0.5, rnn_type='lstm'):
        super(RNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.class_size = class_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type

        # All three cell types take identical constructor arguments, so a
        # dispatch table replaces the original if/elif chain.
        rnn_classes = {'lstm': nn.LSTM, 'rnn': nn.RNN, 'gru': nn.GRU}
        if self.rnn_type not in rnn_classes:
            raise NotImplementedError
        self.rnn = rnn_classes[self.rnn_type](
            input_size=self.input_size,
            hidden_size=self.hidden_size,  # rnn hidden unit
            num_layers=self.num_layers,    # number of rnn layers
            batch_first=True,              # input & output shaped (batch, time_step, input_size)
        )

        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(self.hidden_size, self.class_size)  # FC layer in our paper

    def forward(self, x):
        """Run x of shape (batch, time_step, input_size) through the
        recurrent layers and the FC head.

        Returns:
            Tensor of shape (batch, time_step, class_size).
        """
        # FIX: build the initial states with new_zeros so they inherit the
        # device and dtype of the input; the original torch.zeros(...)
        # always allocated CPU float32 tensors and broke as soon as the
        # model was moved to a GPU or cast to another dtype.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        if self.rnn_type == 'lstm':
            c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
            r_out, _ = self.rnn(x, (h0, c0))
        else:
            r_out, _ = self.rnn(x, h0)

        # nn.Dropout and nn.Linear both operate on the trailing dimension,
        # so the whole (batch, time, hidden) output is projected in one
        # call instead of looping over time steps and stacking. Identical
        # in eval mode; same dropout distribution in train mode.
        return self.out(self.dropout(r_out))
--------------------------------------------------------------------------------