4 |
5 | The 2023 EY Open Science Data Challenge - Crop Forecasting project was carried out as part of the data science challenge organized by EY, Microsoft, and Cornell University. Its objective is to predict the yield of rice fields from satellite imagery provided through the Microsoft Planetary Computer, meteorological data, and field data.
6 |
7 | ## 🏆 Challenge ranking
8 | The challenge was scored using the R² (coefficient of determination) metric.
9 | Our solution ranked 4th out of 185 teams, with an R² score of 0.66 🎉.
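For reference, the metric can be reproduced with scikit-learn's `r2_score` (the same function our training code relies on); a minimal sketch on made-up yield values:

```python
from sklearn.metrics import r2_score

# Hypothetical ground-truth and predicted rice yields (kg/ha), for illustration only
y_true = [6400, 7100, 5800, 6900]
y_pred = [6300, 7000, 6100, 6800]

print(r2_score(y_true, y_pred))  # 1.0 would be a perfect prediction
```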
10 |
11 | The podium:
12 | 🥇 Outatime - 0.68
13 | 🥈 Joshua Rexmond Nunoo Otoo - 0.68
14 | 🥉 Amma Simmons - 0.67
15 |
16 | ## 🛠️ Data processing
17 |
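The processing code lives in `src/data/` (see `preprocessing.py` and `datascaler.py` below). As a rough, hedged sketch of how those pieces fit together — the exact pipeline wiring and the file path are assumptions, not the official processing script:

```python
import xarray as xr
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from src.constants import G_COLUMNS, M_COLUMNS, S_COLUMNS, TARGET
from src.data.datascaler import DatasetScaler
from src.data.preprocessing import Filler, Smoother

# Illustrative path; the real files sit under data/processed/<folder>/
xds = xr.open_dataset("data/processed/train_enriched.nc", engine="scipy")

xds = Smoother(mode="savgol").transform(xds)  # smooth the vegetation-index time series
xds = Filler().fit(xds).transform(xds)        # fill missing / infinite values

scaler = DatasetScaler(
    scaler_s=StandardScaler(), columns_s=S_COLUMNS,  # satellite / vegetation indices
    scaler_g=StandardScaler(), columns_g=G_COLUMNS,  # geographical field data
    scaler_m=StandardScaler(), columns_m=M_COLUMNS,  # meteorological data
    scaler_t=MinMaxScaler(),                         # target (yield)
)
xds = scaler.fit_transform(xds, target=TARGET)
```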
18 |
19 | ## 🏛️ Model architecture
20 |
21 |
22 | ## 📚 Documentation
23 | The project documentation, generated with Sphinx, is located in the `docs/` directory. It provides detailed information about the project's setup, usage, and implementation, along with a tutorial.
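To rebuild the HTML documentation locally, something along these lines should work — the `docs/source` / `docs/build/html` layout is assumed from the committed build output, so adjust the paths if your checkout differs:

```python
# Equivalent to running `sphinx-build -b html docs/source docs/build/html`
from sphinx.cmd.build import main as sphinx_build

sphinx_build(["-b", "html", "docs/source", "docs/build/html"])
```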
24 |
25 | ## 🔬 References
26 |
27 | Jeong, S., Ko, J., & Yeom, J. M. (2022). Predicting rice yield at pixel scale through synthetic use of crop and deep learning models with satellite data in South and North Korea. Science of The Total Environment, 802, 149726.
28 |
29 | Nazir, A., Ullah, S., Saqib, Z. A., Abbas, A., Ali, A., Iqbal, M. S., ... & Butt, M. U. (2021). Estimation and forecasting of rice yield using phenology-based algorithm and linear regression model on sentinel-ii satellite data. Agriculture, 11(10), 1026.
30 |
31 | ## 📝 Citing
32 |
33 | ```
34 | @misc{UrgellReberga:2023,
35 | Author = {Baptiste Urgell and Louis Reberga},
36 | Title = {Crop forecasting},
37 | Year = {2023},
38 | Publisher = {GitHub},
39 | Journal = {GitHub repository},
40 | Howpublished = {\url{https://github.com/association-rosia/crop-forecasting}}
41 | }
42 | ```
43 |
44 | ## 🛡️ License
45 |
46 | The project is distributed under the [MIT License](https://github.com/association-rosia/crop-forecasting/blob/main/LICENSE).
47 |
48 | ## 👨🏻‍💻 Contributors
49 |
50 | Louis REBERGA
51 | Baptiste URGELL
-------------------------------------------------------------------------------- /data/raw/test.csv: -------------------------------------------------------------------------------- 1 | ID No,District,Latitude,Longitude,"Season(SA = Summer Autumn, WS = Winter Spring)","Rice Crop Intensity(D=Double, T=Triple)",Date of Harvest,Field size (ha),Predicted Rice Yield (kg/ha) 2 | 
1,Chau_Phu,10.542192,105.18792,WS,T,10-04-2022,1.4, 3 | 2,Chau_Thanh,10.400189,105.331053,SA,T,15-07-2022,1.32, 4 | 3,Chau_Phu,10.505489,105.203926,SA,D,14-07-2022,1.4, 5 | 4,Chau_Phu,10.52352,105.138274,WS,D,10-04-2022,1.8, 6 | 5,Thoai_Son,10.29466,105.248528,SA,T,20-07-2022,2.2, 7 | 6,Chau_Phu,10.633572,105.172813,WS,D,12-04-2022,2.5, 8 | 7,Chau_Thanh,10.434116,105.27315,SA,T,15-07-2022,2.2, 9 | 8,Chau_Phu,10.61225,105.175364,SA,D,09-08-2022,4, 10 | 9,Thoai_Son,10.268095,105.344826,WS,T,10-04-2022,1.5, 11 | 10,Chau_Phu,10.523312,105.286299,WS,T,01-04-2022,3, 12 | 11,Chau_Phu,10.554634,105.1522,SA,D,04-07-2022,2.3, 13 | 12,Chau_Phu,10.542998,105.261159,SA,D,14-07-2022,4.9, 14 | 13,Chau_Phu,10.473786,105.190479,SA,T,14-07-2022,1.7, 15 | 14,Chau_Phu,10.52352,105.138274,SA,D,20-07-2022,1.8, 16 | 15,Chau_Thanh,10.371317,105.31009,SA,T,04-08-2022,2, 17 | 16,Chau_Thanh,10.456905,105.186106,WS,T,26-03-2022,1.529, 18 | 17,Chau_Phu,10.441423,105.115088,SA,D,20-07-2022,4, 19 | 18,Thoai_Son,10.227343,105.230821,SA,T,20-07-2022,1.8, 20 | 19,Thoai_Son,10.279477,105.288145,SA,T,20-07-2022,2, 21 | 20,Thoai_Son,10.326776,105.33819,WS,T,12-04-2022,1.7, 22 | 21,Thoai_Son,10.293434,105.256309,WS,T,13-04-2022,2.3, 23 | 22,Chau_Thanh,10.405702,105.208668,SA,T,20-07-2022,1.32, 24 | 23,Chau_Phu,10.56984,105.190938,WS,T,10-04-2022,2.3, 25 | 24,Chau_Thanh,10.41636,105.142144,SA,T,26-07-2022,1.43, 26 | 25,Chau_Thanh,10.371317,105.31009,WS,T,10-04-2022,2, 27 | 26,Chau_Thanh,10.40097,105.348815,WS,T,25-03-2022,2.53, 28 | 27,Chau_Phu,10.594603,105.175686,WS,D,12-04-2022,2.2, 29 | 28,Chau_Phu,10.514912,105.216308,SA,D,14-07-2022,2, 30 | 29,Thoai_Son,10.293434,105.256309,SA,T,20-07-2022,2.3, 31 | 30,Chau_Phu,10.48194,105.151543,SA,D,14-07-2022,2.9, 32 | 31,Thoai_Son,10.238171,105.198793,SA,T,20-07-2022,3.6, 33 | 32,Thoai_Son,10.326776,105.33819,SA,T,20-07-2022,1.7, 34 | 33,Chau_Thanh,10.421158,105.237197,WS,T,27-03-2022,1.76, 35 | 34,Chau_Thanh,10.369692,105.29684,WS,T,10-04-2022,4, 36 | 35,Chau_Phu,10.473786,105.190479,WS,T,03-04-2022,1.7, 37 | 36,Thoai_Son,10.365734,105.267634,WS,T,12-04-2022,2.5, 38 | 37,Thoai_Son,10.281814,105.114918,SA,T,25-07-2022,5, 39 | 38,Chau_Thanh,10.411649,105.370659,SA,T,15-07-2022,1.43, 40 | 39,Chau_Phu,10.477062,105.168941,WS,D,03-04-2022,3.4, 41 | 40,Chau_Thanh,10.437524,105.202166,SA,D,20-07-2022,2.31, 42 | 41,Chau_Thanh,10.414613,105.125615,WS,T,27-03-2022,3.85, 43 | 42,Thoai_Son,10.28697,105.341571,SA,T,20-07-2022,1.78, 44 | 43,Chau_Phu,10.64394,105.165296,SA,T,05-08-2022,1.4, 45 | 44,Chau_Thanh,10.420981,105.295797,WS,T,25-03-2022,1.76, 46 | 45,Thoai_Son,10.340626,105.172116,SA,T,23-07-2022,3, 47 | 46,Chau_Phu,10.625193,105.181059,SA,D,09-08-2022,3, 48 | 47,Chau_Phu,10.479304,105.102943,SA,D,20-07-2022,2.9, 49 | 48,Chau_Thanh,10.392899,105.188514,WS,T,24-03-2022,1.43, 50 | 49,Chau_Thanh,10.426536,105.115181,SA,T,14-07-2022,1.43, 51 | 50,Chau_Phu,10.552114,105.091399,SA,D,05-08-2022,3.6, 52 | 51,Chau_Phu,10.541934,105.247538,WS,T,10-04-2022,2, 53 | 52,Thoai_Son,10.303437,105.381252,WS,T,28-03-2022,2.1, 54 | 53,Thoai_Son,10.368837,105.205763,SA,T,21-07-2022,2.3, 55 | 54,Chau_Thanh,10.440557,105.250671,WS,T,27-03-2022,1.76, 56 | 55,Chau_Thanh,10.430707,105.315671,SA,T,10-07-2022,1.54, 57 | 56,Chau_Thanh,10.378696,105.309113,WS,T,27-03-2022,1.32, 58 | 57,Thoai_Son,10.283467,105.267082,WS,T,20-04-2022,2.2, 59 | 58,Chau_Thanh,10.436907,105.236241,WS,T,27-03-2022,1.375, 60 | 59,Chau_Thanh,10.387619,105.243467,SA,T,19-07-2022,1.3, 61 | 60,Chau_Phu,10.482003,105.203866,SA,T,15-07-2022,3, 62 
| 61,Chau_Phu,10.545098,105.112895,SA,D,20-07-2022,3.6, 63 | 62,Thoai_Son,10.279477,105.288145,WS,T,10-04-2022,2, 64 | 63,Chau_Thanh,10.40466,105.311554,SA,T,17-07-2022,1.87, 65 | 64,Thoai_Son,10.34489,105.243935,SA,T,23-07-2022,2.3, 66 | 65,Chau_Phu,10.623506,105.132794,WS,T,10-04-2022,5, 67 | 66,Thoai_Son,10.303766,105.203102,WS,T,16-04-2022,2.2, 68 | 67,Chau_Phu,10.656963,105.152679,WS,D,10-04-2022,3, 69 | 68,Thoai_Son,10.312202,105.330633,WS,T,10-04-2022,3, 70 | 69,Chau_Thanh,10.440335,105.22309,WS,D,27-03-2022,1.76, 71 | 70,Thoai_Son,10.306886,105.290958,SA,T,20-07-2022,1.85, 72 | 71,Chau_Thanh,10.392045,105.307085,SA,T,17-07-2022,1.43, 73 | 72,Chau_Thanh,10.375619,105.124248,WS,T,24-03-2022,1.87, 74 | 73,Thoai_Son,10.313955,105.243521,SA,T,20-07-2022,1.3, 75 | 74,Chau_Phu,10.592666,105.141217,SA,T,05-08-2022,2.8, 76 | 75,Thoai_Son,10.266311,105.23555,WS,T,20-04-2022,1.5, 77 | 76,Chau_Thanh,10.386546,105.19302,SA,T,14-07-2022,1.65, 78 | 77,Thoai_Son,10.282365,105.276189,WS,T,10-04-2022,1.91, 79 | 78,Thoai_Son,10.33791,105.357013,SA,T,20-07-2022,2, 80 | 79,Thoai_Son,10.31856,105.374468,WS,D,28-03-2022,1.4, 81 | 80,Chau_Thanh,10.429212,105.141436,WS,D,27-03-2022,1.21, 82 | 81,Chau_Phu,10.474439,105.216928,SA,T,15-07-2022,7, 83 | 82,Chau_Thanh,10.443488,105.236111,WS,T,27-03-2022,1.65, 84 | 83,Chau_Thanh,10.409367,105.355252,WS,T,25-03-2022,2.42, 85 | 84,Thoai_Son,10.292156,105.361294,SA,T,12-07-2022,2.3, 86 | 85,Chau_Phu,10.469839,105.211568,WS,T,01-04-2022,4, 87 | 86,Thoai_Son,10.257436,105.217205,SA,T,22-07-2022,2.3, 88 | 87,Chau_Phu,10.658291,105.127704,WS,T,10-04-2022,3, 89 | 88,Chau_Thanh,10.407982,105.123304,WS,D,27-03-2022,2.75, 90 | 89,Thoai_Son,10.320684,105.272431,SA,T,20-07-2022,3, 91 | 90,Chau_Phu,10.501648,105.096892,WS,T,10-04-2022,1.2, 92 | 91,Chau_Phu,10.490352,105.23065,WS,T,02-04-2022,4, 93 | 92,Thoai_Son,10.34489,105.243935,WS,T,12-04-2022,2.3, 94 | 93,Chau_Thanh,10.440557,105.250671,SA,T,28-07-2022,1.76, 95 | 94,Thoai_Son,10.320371,105.259016,WS,T,13-04-2022,1.52, 96 | 95,Thoai_Son,10.250745,105.24539,WS,T,20-04-2022,3, 97 | 96,Chau_Thanh,10.435839,105.132981,SA,D,26-07-2022,1.21, 98 | 97,Chau_Phu,10.529357,105.147388,WS,T,10-04-2022,2, 99 | 98,Chau_Thanh,10.452537,105.205118,SA,T,20-07-2022,5.5, 100 | 99,Chau_Thanh,10.394341,105.126836,SA,T,14-07-2022,4.4, 101 | 100,Chau_Phu,10.48065,105.130089,WS,T,10-04-2022,2, 102 | -------------------------------------------------------------------------------- /submissions/toasty-sky-343.csv: -------------------------------------------------------------------------------- 1 | ID No,District,Latitude,Longitude,"Season(SA = Summer Autumn, WS = Winter Spring)","Rice Crop Intensity(D=Double, T=Triple)",Date of Harvest,Field size (ha),Predicted Rice Yield (kg/ha) 2 | 1,Chau_Phu,10.542192,105.18792,WS,T,10-04-2022,1.4,7200 3 | 2,Chau_Thanh,10.400189,105.331053,SA,T,15-07-2022,1.32,6000 4 | 3,Chau_Phu,10.505489,105.203926,SA,D,14-07-2022,1.4,6000 5 | 4,Chau_Phu,10.52352,105.138274,WS,D,10-04-2022,1.8,6960 6 | 5,Thoai_Son,10.29466,105.248528,SA,T,20-07-2022,2.2,6000 7 | 6,Chau_Phu,10.633572,105.172813,WS,D,12-04-2022,2.5,7200 8 | 7,Chau_Thanh,10.434116,105.27315,SA,T,15-07-2022,2.2,6000 9 | 8,Chau_Phu,10.61225,105.175364,SA,D,09-08-2022,4.0,6000 10 | 9,Thoai_Son,10.268095,105.344826,WS,T,10-04-2022,1.5,7200 11 | 10,Chau_Phu,10.523312,105.286299,WS,T,01-04-2022,3.0,7200 12 | 11,Chau_Phu,10.554634,105.1522,SA,D,04-07-2022,2.3,6000 13 | 12,Chau_Phu,10.542998,105.261159,SA,D,14-07-2022,4.9,6000 14 | 
13,Chau_Phu,10.473786,105.190479,SA,T,14-07-2022,1.7,6000 15 | 14,Chau_Phu,10.52352,105.138274,SA,D,20-07-2022,1.8,6000 16 | 15,Chau_Thanh,10.371317,105.31009,SA,T,04-08-2022,2.0,6000 17 | 16,Chau_Thanh,10.456905,105.186106,WS,T,26-03-2022,1.529,7200 18 | 17,Chau_Phu,10.441423,105.115088,SA,D,20-07-2022,4.0,6000 19 | 18,Thoai_Son,10.227343,105.230821,SA,T,20-07-2022,1.8,6000 20 | 19,Thoai_Son,10.279477,105.288145,SA,T,20-07-2022,2.0,6000 21 | 20,Thoai_Son,10.326776,105.33819,WS,T,12-04-2022,1.7,7400 22 | 21,Thoai_Son,10.293434,105.256309,WS,T,13-04-2022,2.3,7200 23 | 22,Chau_Thanh,10.405702,105.208668,SA,T,20-07-2022,1.32,6000 24 | 23,Chau_Phu,10.56984,105.190938,WS,T,10-04-2022,2.3,7200 25 | 24,Chau_Thanh,10.41636,105.142144,SA,T,26-07-2022,1.43,6000 26 | 25,Chau_Thanh,10.371317,105.31009,WS,T,10-04-2022,2.0,7400 27 | 26,Chau_Thanh,10.40097,105.348815,WS,T,25-03-2022,2.53,7400 28 | 27,Chau_Phu,10.594603,105.175686,WS,D,12-04-2022,2.2,7200 29 | 28,Chau_Phu,10.514912,105.216308,SA,D,14-07-2022,2.0,6000 30 | 29,Thoai_Son,10.293434,105.256309,SA,T,20-07-2022,2.3,6000 31 | 30,Chau_Phu,10.48194,105.151543,SA,D,14-07-2022,2.9,6000 32 | 31,Thoai_Son,10.238171,105.198793,SA,T,20-07-2022,3.6,6000 33 | 32,Thoai_Son,10.326776,105.33819,SA,T,20-07-2022,1.7,6000 34 | 33,Chau_Thanh,10.421158,105.237197,WS,T,27-03-2022,1.76,7400 35 | 34,Chau_Thanh,10.369692,105.29684,WS,T,10-04-2022,4.0,7200 36 | 35,Chau_Phu,10.473786,105.190479,WS,T,03-04-2022,1.7,7200 37 | 36,Thoai_Son,10.365734,105.267634,WS,T,12-04-2022,2.5,7200 38 | 37,Thoai_Son,10.281814,105.114918,SA,T,25-07-2022,5.0,6000 39 | 38,Chau_Thanh,10.411649,105.370659,SA,T,15-07-2022,1.43,6000 40 | 39,Chau_Phu,10.477062,105.168941,WS,D,03-04-2022,3.4,6800 41 | 40,Chau_Thanh,10.437524,105.202166,SA,D,20-07-2022,2.31,6000 42 | 41,Chau_Thanh,10.414613,105.125615,WS,T,27-03-2022,3.85,7040 43 | 42,Thoai_Son,10.28697,105.341571,SA,T,20-07-2022,1.78,6000 44 | 43,Chau_Phu,10.64394,105.165296,SA,T,05-08-2022,1.4,6000 45 | 44,Chau_Thanh,10.420981,105.295797,WS,T,25-03-2022,1.76,7200 46 | 45,Thoai_Son,10.340626,105.172116,SA,T,23-07-2022,3.0,6000 47 | 46,Chau_Phu,10.625193,105.181059,SA,D,09-08-2022,3.0,6000 48 | 47,Chau_Phu,10.479304,105.102943,SA,D,20-07-2022,2.9,6000 49 | 48,Chau_Thanh,10.392899,105.188514,WS,T,24-03-2022,1.43,7200 50 | 49,Chau_Thanh,10.426536,105.115181,SA,T,14-07-2022,1.43,6000 51 | 50,Chau_Phu,10.552114,105.091399,SA,D,05-08-2022,3.6,6000 52 | 51,Chau_Phu,10.541934,105.247538,WS,T,10-04-2022,2.0,7400 53 | 52,Thoai_Son,10.303437,105.381252,WS,T,28-03-2022,2.1,7400 54 | 53,Thoai_Son,10.368837,105.205763,SA,T,21-07-2022,2.3,6000 55 | 54,Chau_Thanh,10.440557,105.250671,WS,T,27-03-2022,1.76,7200 56 | 55,Chau_Thanh,10.430707,105.315671,SA,T,10-07-2022,1.54,6000 57 | 56,Chau_Thanh,10.378696,105.309113,WS,T,27-03-2022,1.32,7200 58 | 57,Thoai_Son,10.283467,105.267082,WS,T,20-04-2022,2.2,7400 59 | 58,Chau_Thanh,10.436907,105.236241,WS,T,27-03-2022,1.375,7400 60 | 59,Chau_Thanh,10.387619,105.243467,SA,T,19-07-2022,1.3,6000 61 | 60,Chau_Phu,10.482003,105.203866,SA,T,15-07-2022,3.0,6000 62 | 61,Chau_Phu,10.545098,105.112895,SA,D,20-07-2022,3.6,6000 63 | 62,Thoai_Son,10.279477,105.288145,WS,T,10-04-2022,2.0,7200 64 | 63,Chau_Thanh,10.40466,105.311554,SA,T,17-07-2022,1.87,6000 65 | 64,Thoai_Son,10.34489,105.243935,SA,T,23-07-2022,2.3,6000 66 | 65,Chau_Phu,10.623506,105.132794,WS,T,10-04-2022,5.0,7040 67 | 66,Thoai_Son,10.303766,105.203102,WS,T,16-04-2022,2.2,7400 68 | 67,Chau_Phu,10.656963,105.152679,WS,D,10-04-2022,3.0,7040 69 | 
68,Thoai_Son,10.312202,105.330633,WS,T,10-04-2022,3.0,7200 70 | 69,Chau_Thanh,10.440335,105.22309,WS,D,27-03-2022,1.76,7040 71 | 70,Thoai_Son,10.306886,105.290958,SA,T,20-07-2022,1.85,6000 72 | 71,Chau_Thanh,10.392045,105.307085,SA,T,17-07-2022,1.43,6000 73 | 72,Chau_Thanh,10.375619,105.124248,WS,T,24-03-2022,1.87,7040 74 | 73,Thoai_Son,10.313955,105.243521,SA,T,20-07-2022,1.3,6000 75 | 74,Chau_Phu,10.592666,105.141217,SA,T,05-08-2022,2.8,6000 76 | 75,Thoai_Son,10.266311,105.23555,WS,T,20-04-2022,1.5,7400 77 | 76,Chau_Thanh,10.386546,105.19302,SA,T,14-07-2022,1.65,6000 78 | 77,Thoai_Son,10.282365,105.276189,WS,T,10-04-2022,1.91,7400 79 | 78,Thoai_Son,10.33791,105.357013,SA,T,20-07-2022,2.0,6000 80 | 79,Thoai_Son,10.31856,105.374468,WS,D,28-03-2022,1.4,7200 81 | 80,Chau_Thanh,10.429212,105.141436,WS,D,27-03-2022,1.21,7040 82 | 81,Chau_Phu,10.474439,105.216928,SA,T,15-07-2022,7.0,6000 83 | 82,Chau_Thanh,10.443488,105.236111,WS,T,27-03-2022,1.65,7200 84 | 83,Chau_Thanh,10.409367,105.355252,WS,T,25-03-2022,2.42,7400 85 | 84,Thoai_Son,10.292156,105.361294,SA,T,12-07-2022,2.3,6000 86 | 85,Chau_Phu,10.469839,105.211568,WS,T,01-04-2022,4.0,7200 87 | 86,Thoai_Son,10.257436,105.217205,SA,T,22-07-2022,2.3,6000 88 | 87,Chau_Phu,10.658291,105.127704,WS,T,10-04-2022,3.0,7040 89 | 88,Chau_Thanh,10.407982,105.123304,WS,D,27-03-2022,2.75,7040 90 | 89,Thoai_Son,10.320684,105.272431,SA,T,20-07-2022,3.0,6000 91 | 90,Chau_Phu,10.501648,105.096892,WS,T,10-04-2022,1.2,7200 92 | 91,Chau_Phu,10.490352,105.23065,WS,T,02-04-2022,4.0,7200 93 | 92,Thoai_Son,10.34489,105.243935,WS,T,12-04-2022,2.3,7200 94 | 93,Chau_Thanh,10.440557,105.250671,SA,T,28-07-2022,1.76,6000 95 | 94,Thoai_Son,10.320371,105.259016,WS,T,13-04-2022,1.52,7200 96 | 95,Thoai_Son,10.250745,105.24539,WS,T,20-04-2022,3.0,7200 97 | 96,Chau_Thanh,10.435839,105.132981,SA,D,26-07-2022,1.21,6000 98 | 97,Chau_Phu,10.529357,105.147388,WS,T,10-04-2022,2.0,7200 99 | 98,Chau_Thanh,10.452537,105.205118,SA,T,20-07-2022,5.5,6000 100 | 99,Chau_Thanh,10.394341,105.126836,SA,T,14-07-2022,4.4,6000 101 | 100,Chau_Phu,10.48065,105.130089,WS,T,10-04-2022,2.0,6960 102 | -------------------------------------------------------------------------------- /src/data/datascaler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Union 4 | 5 | import xarray as xr 6 | from sklearn.preprocessing import (MinMaxScaler, PowerTransformer, 7 | QuantileTransformer, RobustScaler, 8 | StandardScaler) 9 | 10 | parent = os.path.abspath("../features") 11 | sys.path.insert(1, parent) 12 | 13 | 14 | class DatasetScaler: 15 | """Scaler for Vegetable Indice, Geographical, Meteorological and Target. 
16 | 17 | :param scaler_s: Scikit-Learn scaler for Vegetable Indice data 18 | :type scaler_s: Union[StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer] 19 | :param columns_s: Vegetable Indice columns name 20 | :type columns_s: list[str] 21 | :param scaler_g: Scikit-Learn scaler for Geographical data 22 | :type scaler_g: Union[StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer] 23 | :param columns_g: Geographical columns name 24 | :type columns_g: list[str] 25 | :param scaler_m: Scikit-Learn scaler for Meteorological data 26 | :type scaler_m: Union[StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer] 27 | :param columns_m: Meteorological columns name 28 | :type columns_m: list[str] 29 | :param scaler_t: Scikit-Learn scaler for Target data 30 | :type scaler_t: MinMaxScaler 31 | """ 32 | def __init__( 33 | self, 34 | scaler_s: Union[ 35 | StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer 36 | ], 37 | columns_s: list[str], 38 | scaler_g: Union[ 39 | StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer 40 | ], 41 | columns_g: list[str], 42 | scaler_m: Union[ 43 | StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer 44 | ], 45 | columns_m: list[str], 46 | scaler_t: MinMaxScaler, 47 | ) -> None: 48 | self.scaler_s = scaler_s 49 | self.columns_s = columns_s 50 | self.scaler_g = scaler_g 51 | self.columns_g = columns_g 52 | self.scaler_m = scaler_m 53 | self.columns_m = columns_m 54 | self.scaler_t = scaler_t 55 | 56 | def fit(self, xdf: xr.Dataset, target: str) -> object: 57 | """Fit all scalers to be used for later scaling. 58 | 59 | :param xdf: The data used to fit all scalers, used for later scaling along the features axis. 60 | :type xdf: xr.Dataset 61 | :param target: Column name to fit the target scaler, used for later scaling along the target axis. 62 | :type target: str 63 | :return: Fitted scaler. 64 | :rtype: object 65 | """ 66 | 67 | def fit_scaler( 68 | xdf: xr.Dataset, 69 | columns: list[str], 70 | scaler: Union[ 71 | StandardScaler, 72 | RobustScaler, 73 | PowerTransformer, 74 | QuantileTransformer, 75 | MinMaxScaler, 76 | ], 77 | ): 78 | df = xdf[columns].to_dataframe() 79 | 80 | return scaler.fit(df[columns]) 81 | 82 | # Fit S data scaler 83 | self.scaler_s = fit_scaler(xdf, self.columns_s, self.scaler_s) 84 | # Fit G data scaler 85 | self.scaler_g = fit_scaler(xdf, self.columns_g, self.scaler_g) 86 | # Fit M data scaler 87 | self.scaler_m = fit_scaler(xdf, self.columns_m, self.scaler_m) 88 | # Fit Target data scaler 89 | self.scaler_t = fit_scaler(xdf, [target], self.scaler_t) 90 | 91 | return self 92 | 93 | def transform(self, xdf: xr.Dataset, target: str = None) -> xr.Dataset: 94 | """Perform transform of each scaler. 95 | 96 | :param xdf: The Dataset used to scale along the features axis. 97 | :type xdf: xr.Dataset 98 | :param target: Column name used to scale along the Target axis, defaults to None 99 | :type target: str, optional 100 | :return: Transformed Dataset. 
101 | :rtype: xr.Dataset 102 | """ 103 | 104 | def transform_data( 105 | xdf: xr.Dataset, 106 | columns: str, 107 | scaler: Union[ 108 | StandardScaler, 109 | RobustScaler, 110 | PowerTransformer, 111 | QuantileTransformer, 112 | MinMaxScaler, 113 | ], 114 | ) -> xr.Dataset: 115 | df = xdf[columns].to_dataframe() 116 | df.loc[:, columns] = scaler.transform(df[columns]) 117 | xdf_scale = df[columns].to_xarray() 118 | xdf = xr.merge([xdf_scale, xdf], compat="override") 119 | return xdf 120 | 121 | # Scale S data 122 | xdf = transform_data(xdf, self.columns_s, self.scaler_s) 123 | # Scale G data 124 | xdf = transform_data(xdf, self.columns_g, self.scaler_g) 125 | # Scale M data 126 | xdf = transform_data(xdf, self.columns_m, self.scaler_m) 127 | 128 | if target: 129 | # Scale M data 130 | xdf = transform_data(xdf, [target], self.scaler_t) 131 | 132 | return xdf 133 | 134 | def fit_transform(self, xdf: xr.Dataset, target: str) -> xr.Dataset: 135 | """Fit to data, then transform it. 136 | 137 | :param xdf: The data used to perform fit and transform. 138 | :type xdf: xr.Dataset 139 | :param target: Column name used to scale along the Target axis 140 | :type target: str 141 | :return: Transformed Dataset. 142 | :rtype: xr.Dataset 143 | """ 144 | return self.fit(xdf, target).transform(xdf, target) 145 | 146 | def inverse_transform(self, xdf: xr.Dataset, target: str = None) -> xr.Dataset: 147 | """Scale back the data to the original representation. 148 | 149 | :param xdf: The data used to scale along the features axis. 150 | :type xdf: xr.Dataset 151 | :param target: Column name used to scale along the Target axis, defaults to None 152 | :type target: str, optional 153 | :return: Transformed Dataset. 154 | :rtype: xr.Dataset 155 | """ 156 | 157 | def inverse_transform_data( 158 | xdf: xr.Dataset, 159 | columns: str, 160 | scaler: Union[ 161 | StandardScaler, 162 | RobustScaler, 163 | PowerTransformer, 164 | QuantileTransformer, 165 | MinMaxScaler, 166 | ], 167 | ) -> xr.Dataset: 168 | df = xdf[columns].to_dataframe() 169 | df.loc[:, columns] = scaler.inverse_transform(df[columns]) 170 | xdf_scale = df[columns].to_xarray() 171 | xdf = xr.merge([xdf_scale, xdf], compat="override") 172 | return xdf 173 | 174 | # Scale S data 175 | xdf = inverse_transform_data(xdf, self.columns_s, self.scaler_s) 176 | # Scale G data 177 | xdf = inverse_transform_data(xdf, self.columns_g, self.scaler_g) 178 | # Scale M data 179 | xdf = inverse_transform_data(xdf, self.columns_m, self.scaler_m) 180 | 181 | if target: 182 | # Scale M data 183 | xdf = inverse_transform_data(xdf, [target], self.scaler_t) 184 | 185 | return xdf 186 | -------------------------------------------------------------------------------- /src/models/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | from os.path import join 5 | 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import wandb 10 | from sklearn.metrics import r2_score 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | 14 | parent = os.path.abspath('.') 15 | sys.path.insert(1, parent) 16 | from utils import ROOT_DIR 17 | 18 | 19 | def compute_r2_scores(observations: list, labels: list, preds: list) -> tuple[float, float]: 20 | """ Compute R^2 scores for a given set of observations and labels. 
21 | 22 | :param observations: list of observations 23 | :type observations: list[int] 24 | :param labels: list of labels 25 | :type labels: list[float] 26 | :param preds: list of predictions 27 | :type preds: list[float] 28 | :return: R^2 scores (the full is the score for the all the rows, 29 | the mean is the aggregated score grouped by observations) 30 | :rtype: tuple[float, float] 31 | """ 32 | 33 | df = pd.DataFrame() 34 | df['observations'] = observations 35 | df['labels'] = labels 36 | df['preds'] = preds 37 | full_r2_score = r2_score(df.labels, df.preds) 38 | df = df.groupby(['observations']).mean() 39 | mean_r2_score = r2_score(df.labels, df.preds) 40 | 41 | return full_r2_score, mean_r2_score 42 | 43 | 44 | class Trainer: 45 | """ Define the Trainer class. 46 | 47 | :param model: our deep learning model 48 | :type model: nn.Module 49 | :param train_dataloader: training dataloader 50 | :type train_dataloader: DataLoader 51 | :param val_dataloader: validation dataloader 52 | :type val_dataloader: DataLoader 53 | :param epochs: max number of epochs 54 | :type epochs: int 55 | :param criterion: loss function 56 | :param optimizer: model optimizer 57 | :param scheduler: learning scheduler 58 | """ 59 | def __init__(self, model: nn.Module, train_dataloader: DataLoader, val_dataloader: DataLoader, 60 | epochs: int, criterion, optimizer, scheduler): 61 | self.model = model 62 | self.train_loader = train_dataloader 63 | self.val_loader = val_dataloader 64 | self.criterion = criterion 65 | self.epochs = epochs 66 | self.optimizer = optimizer 67 | self.scheduler = scheduler 68 | self.timestamp = int(datetime.now().timestamp()) 69 | self.val_best_r2_score = 0. 70 | 71 | def train_one_epoch(self) -> float: 72 | """ Train the model for one epoch. 73 | 74 | :return: the training loss 75 | :rtype: float 76 | """ 77 | train_loss = 0. 78 | 79 | self.model.train() 80 | 81 | pbar = tqdm(self.train_loader, leave=False) 82 | for i, data in enumerate(pbar): 83 | keys_input = ['s_input', 'm_input', 'g_input'] 84 | inputs = {key: data[key] for key in keys_input} 85 | labels = data['target'] 86 | 87 | # Zero gradients for every batch 88 | self.optimizer.zero_grad() 89 | 90 | # Make predictions for this batch 91 | outputs = self.model(inputs) 92 | 93 | # Compute the loss and its gradients 94 | loss = self.criterion(outputs, labels) 95 | loss.backward() 96 | 97 | # Adjust learning weights 98 | self.optimizer.step() 99 | 100 | train_loss += loss.item() 101 | epoch_loss = train_loss / (i + 1) 102 | 103 | # Update the progress bar with new metrics values 104 | pbar.set_description(f'TRAIN - Batch: {i + 1}/{len(self.train_loader)} - ' 105 | f'Epoch Loss: {epoch_loss:.5f} - ' 106 | f'Batch Loss: {loss.item():.5f}') 107 | 108 | train_loss /= len(self.train_loader) 109 | 110 | return train_loss 111 | 112 | def val_one_epoch(self) -> tuple[float, float, float]: 113 | """ Validate the model for one epoch. 114 | 115 | :return: the validation loss, the R^2 score and the aggregated R^2 score 116 | :rtype: tuple[float, float, float] 117 | """ 118 | val_loss = 0. 
119 | observations = [] 120 | val_labels = [] 121 | val_preds = [] 122 | 123 | self.model.eval() 124 | 125 | pbar = tqdm(self.val_loader, leave=False) 126 | for i, data in enumerate(pbar): 127 | keys_input = ['s_input', 'm_input', 'g_input'] 128 | inputs = {key: data[key] for key in keys_input} 129 | labels = data['target'] 130 | 131 | outputs = self.model(inputs) 132 | 133 | loss = self.criterion(outputs, labels) 134 | val_loss += loss.item() 135 | epoch_loss = val_loss / (i + 1) 136 | 137 | observations += data['observation'].squeeze().tolist() 138 | val_labels += labels.squeeze().tolist() 139 | val_preds += outputs.squeeze().tolist() 140 | 141 | # Update the progress bar with new metrics values 142 | pbar.set_description(f'VAL - Batch: {i + 1}/{len(self.val_loader)} - ' 143 | f'Epoch Loss: {epoch_loss:.5f} - ' 144 | f'Batch Loss: {loss.item():.5f}') 145 | 146 | val_loss /= len(self.val_loader) 147 | val_r2_score, val_mean_r2_score = compute_r2_scores(observations, val_labels, val_preds) 148 | 149 | return val_loss, val_r2_score, val_mean_r2_score 150 | 151 | def save(self, score: float): 152 | """ Save the model if it is the better than the previous sevaed one. 153 | 154 | :param score: current model epoch score 155 | :type score: float 156 | """ 157 | save_folder = join(ROOT_DIR, 'models') 158 | 159 | if score > self.val_best_r2_score: 160 | self.val_best_r2_score = score 161 | os.makedirs(save_folder, exist_ok=True) 162 | 163 | # delete the former best model 164 | former_model = [f for f in os.listdir(save_folder) if f.split('_')[-1] == f'{self.timestamp}.pt'] 165 | if len(former_model) == 1: 166 | os.remove(join(save_folder, former_model[0])) 167 | 168 | # save the new model 169 | score = str(score)[:7].replace('.', '-') 170 | file_name = f'{score}_model_{self.timestamp}.pt' 171 | save_path = join(save_folder, file_name) 172 | torch.save(self.model, save_path) 173 | 174 | def train(self): 175 | """ Main function to train the model. 
""" 176 | iter_epoch = tqdm(range(self.epochs), leave=False) 177 | 178 | for epoch in iter_epoch: 179 | iter_epoch.set_description(f'EPOCH {epoch + 1}/{self.epochs}') 180 | train_loss = self.train_one_epoch() 181 | 182 | val_loss, val_r2_score, val_mean_r2_score = self.val_one_epoch() 183 | self.scheduler.step(val_loss) 184 | self.save(val_mean_r2_score) 185 | 186 | # log the metrics to W&B 187 | wandb.log({ 188 | 'train_loss': train_loss, 189 | 'val_loss': val_loss, 190 | 'val_r2_score': val_r2_score, 191 | 'val_mean_r2_score': val_mean_r2_score, 192 | 'val_best_r2_score': self.val_best_r2_score 193 | }) 194 | 195 | # Write the finished epoch metrics values 196 | iter_epoch.write(f'EPOCH {epoch + 1}/{self.epochs}: ' 197 | f'Train = {train_loss:.5f} - ' 198 | f'Val = {val_loss:.5f} - ' 199 | f'Val R2 = {val_r2_score:.5f} - ' 200 | f'Val mean R2 = {val_mean_r2_score:.5f}') 201 | -------------------------------------------------------------------------------- /src/models/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Dict, List, Tuple 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import xarray as xr 9 | from scipy import stats 10 | from sklearn.model_selection import train_test_split 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | from src.constants import (FOLDER, G_COLUMNS, M_COLUMNS, S_COLUMNS, TARGET, 14 | TARGET_TEST) 15 | 16 | parent = os.path.abspath(".") 17 | sys.path.insert(1, parent) 18 | 19 | from os.path import join 20 | 21 | from utils import ROOT_DIR 22 | 23 | 24 | class CustomDataset(Dataset): 25 | def __init__( 26 | self, 27 | s_inputs: np.ndarray, 28 | g_inputs: np.ndarray, 29 | m_inputs: np.ndarray, 30 | obs_targets: np.ndarray, 31 | augment: int, 32 | device: str, 33 | ): 34 | """Dataset used for the dataloader. 35 | 36 | :param s_inputs: Satellite data. 37 | :type s_inputs: np.ndarray 38 | :param g_inputs: Raw data. 39 | :type g_inputs: np.ndarray 40 | :param m_inputs: Meteorological data. 41 | :type m_inputs: np.ndarray 42 | :param obs_targets: Yield data. 43 | :type obs_targets: np.ndarray 44 | :param augment: Number of data augmentation. 45 | :type augment: int 46 | :param device: Training device. 47 | :type device: str 48 | """ 49 | # Move data on the training device. 
50 | self.augment = augment 51 | self.device = device 52 | self.s_inputs = torch.tensor(s_inputs).to( 53 | device=self.device, dtype=torch.float32 54 | ) 55 | self.g_inputs = torch.tensor(g_inputs).to( 56 | device=self.device, dtype=torch.float32 57 | ) 58 | self.m_inputs = torch.tensor(m_inputs).to( 59 | device=self.device, dtype=torch.float32 60 | ) 61 | self.observations = torch.tensor(obs_targets[:, 0]).to( 62 | device=self.device, dtype=torch.float32 63 | ) 64 | self.targets = torch.tensor(obs_targets[:, 1]).to( 65 | device=self.device, dtype=torch.float32 66 | ) 67 | 68 | def __len__(self): 69 | return self.s_inputs.shape[0] 70 | 71 | def __getitem__(self, idx): 72 | # Return data for a particular indexe 73 | # The data depend only on the observation indexe 74 | # Only the satellite data depend on the augmentation indexe 75 | idx_obs = idx // self.augment 76 | item = { 77 | "observation": self.observations[[idx_obs]], 78 | "s_input": self.s_inputs[idx], 79 | "m_input": self.m_inputs[idx_obs], 80 | "g_input": self.g_inputs[idx_obs], 81 | "target": self.targets[[idx_obs]], 82 | } 83 | 84 | return item 85 | 86 | 87 | def create_train_val_idx(xds: xr.Dataset, val_rate: float) -> Tuple[List, List]: 88 | """Compute a stratifate Train/Val split. 89 | 90 | :param xds: Dataset used for the split. 91 | :type xds: xr.Dataset 92 | :param val_rate: Percentage of data in the validation set. 93 | :type val_rate: float 94 | :return: return list of train index & list of val index 95 | :rtype: tuple[list, list] 96 | """ 97 | yields = xds[TARGET].values 98 | yields_distribution = stats.norm(loc=yields.mean(), scale=yields.std()) 99 | bounds = yields_distribution.cdf([0, 1]) 100 | bins = np.linspace(*bounds, num=10) 101 | stratify = np.digitize(yields, bins) 102 | train_idx, val_idx = train_test_split( 103 | xds.ts_obs, test_size=val_rate, random_state=42, stratify=stratify 104 | ) 105 | 106 | return train_idx, val_idx 107 | 108 | 109 | def transform_data( 110 | xds: xr.Dataset, m_times: int = 120, test=False 111 | ) -> Dict[str, np.ndarray]: 112 | """Transform data from xr.Dataset to dict of np.ndarray 113 | sorted by observation and augmentation. 114 | 115 | :param xds: The Dataset to be transformed. 116 | :type xds: xr.Dataset 117 | :param m_times: Length of the time series for Weather data, defaults to 120. 118 | :type m_times: int, optional 119 | :param test: True if it is the Test dataset, defaults to False. 120 | :type test: bool, optional 121 | :return: Dictionnary of all data used to construct the torch Dataset. 122 | :rtype: dict[str, np.ndarray] 123 | """ 124 | items = {} 125 | # Dataset sorting for compatibility with torch Dataset indexes 126 | xds = xds.sortby(["ts_obs", "ts_aug"]) 127 | 128 | # Create raw data 129 | g_arr = xds[G_COLUMNS].to_dataframe() 130 | items["g_inputs"] = g_arr.values 131 | 132 | # Create satellite data 133 | # Keep only useful values and convert into numpy array 134 | s_arr = xds[S_COLUMNS].to_dataframe()[S_COLUMNS] 135 | s_arr = s_arr.to_numpy() 136 | # Reshape axis to match index, date, features 137 | # TODO: set as variable the number of state_dev and features. 
138 | s_arr = s_arr.reshape(s_arr.shape[0] // 24, 24, 8) 139 | items["s_inputs"] = s_arr 140 | 141 | # Create Meteorological data 142 | # time and District are the keys to link observations and meteorological data 143 | df_time = xds[["time", "District"]].to_dataframe() 144 | # Keep only useful data 145 | df_time.reset_index(inplace=True) 146 | df_time = df_time[["ts_obs", "state_dev", "time", "District"]] 147 | # Meteorological data only dependend of the observation 148 | df_time = df_time.groupby(["ts_obs", "state_dev", "District"]).first() 149 | # Take the min and max datetime of satellite data to create a daily time series of meteorological data 150 | df_time.reset_index("state_dev", inplace=True) 151 | # TODO: set as variable the number of state_dev. 152 | df_time = df_time[df_time["state_dev"].isin([0, 23])] 153 | df_time = df_time.pivot(columns="state_dev").droplevel(None, axis=1) 154 | df_time.reset_index("District", inplace=True) 155 | 156 | # For each observation take m_times daily date before the 157 | # harverest date and get data with the corresponding location 158 | list_weather = [] 159 | for _, series in df_time.iterrows(): 160 | all_dates = pd.date_range(series[0], series[23], freq="D") 161 | all_dates = all_dates[-m_times:] 162 | m_arr = ( 163 | xds.sel(datetime=all_dates, name=series["District"])[M_COLUMNS] 164 | .to_array() 165 | .values 166 | ) 167 | list_weather.append(m_arr.T) 168 | 169 | items["m_inputs"] = np.asarray(list_weather) 170 | 171 | # If test create the target array with 0 instead of np.nan 172 | if test: 173 | df = xds[TARGET_TEST].to_dataframe().reset_index() 174 | df[TARGET_TEST] = 0 175 | items["obs_targets"] = df.to_numpy() 176 | else: 177 | items["obs_targets"] = xds[TARGET].to_dataframe().reset_index().to_numpy() 178 | 179 | items["augment"] = xds["ts_aug"].values.shape[0] 180 | 181 | return items 182 | 183 | 184 | def get_dataloaders( 185 | batch_size: int, val_rate: float, device: str 186 | ) -> Tuple[DataLoader, DataLoader, DataLoader]: 187 | """Generate Train / Validation / Test Torch Dataloader. 188 | 189 | :param batch_size: Batch size of Dataloader. 190 | :type batch_size: int 191 | :param val_rate: Percentage of data on the Validation Dataset. 192 | :type val_rate: float 193 | :param device: Device where to put the data. 
194 | :type device: str 195 | :return: Train / Validation / Test Dataloader 196 | :rtype: tuple[DataLoader, DataLoader, DataLoader] 197 | """ 198 | # Read the dataset processed 199 | dataset_path = join(ROOT_DIR, "data", "processed", FOLDER, "train_enriched.nc") 200 | xdf_train = xr.open_dataset(dataset_path, engine="scipy") 201 | 202 | # Create a Train / Validation split 203 | train_idx, val_idx = create_train_val_idx(xdf_train, val_rate) 204 | train_array = xdf_train.sel(ts_obs=train_idx) 205 | # Prepare data for th Torch Dataset 206 | items = transform_data(train_array) 207 | train_dataset = CustomDataset(**items, device=device) 208 | # Create the Dataloader 209 | train_dataloader = DataLoader( 210 | train_dataset, batch_size=batch_size, drop_last=True, shuffle=True 211 | ) 212 | 213 | # ?: Make a function to create each dataloader 214 | val_array = xdf_train.sel(ts_obs=val_idx) 215 | items = transform_data(val_array) 216 | val_dataset = CustomDataset(**items, device=device) 217 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, drop_last=True) 218 | 219 | dataset_path = join(ROOT_DIR, "data", "processed", FOLDER, "test_enriched.nc") 220 | xdf_test = xr.open_dataset(dataset_path, engine="scipy") 221 | items = transform_data(xdf_test, test=True) 222 | test_dataset = CustomDataset(**items, device=device) 223 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True) 224 | 225 | return train_dataloader, val_dataloader, test_dataloader 226 | -------------------------------------------------------------------------------- /src/data/preprocessing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Union 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | import os 7 | import sys 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import xarray as xr 12 | from scipy.signal import savgol_filter 13 | from sklearn.preprocessing import ( 14 | MinMaxScaler, 15 | QuantileTransformer, 16 | RobustScaler, 17 | StandardScaler, 18 | ) 19 | 20 | parent = os.path.abspath(".") 21 | sys.path.insert(1, parent) 22 | 23 | from os.path import join 24 | 25 | from sklearn.base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin 26 | 27 | from src.constants import G_COLUMNS, M_COLUMNS, S_COLUMNS 28 | from utils import ROOT_DIR 29 | 30 | 31 | class Sorter(OneToOneFeatureMixin, TransformerMixin, BaseEstimator): 32 | """Sort dataset to align dataset samples with labels samples.""" 33 | 34 | def __init__(self) -> None: 35 | pass 36 | 37 | def fit(self, X=None, y=None) -> object: 38 | """Identity function 39 | 40 | :param X: Ignored 41 | :type X: None 42 | :param y: Ignored 43 | :type y: None 44 | :return: self 45 | :rtype: object 46 | """ 47 | return self 48 | 49 | def transform(self, X: pd.DataFrame) -> pd.DataFrame: 50 | """Reorder the indexes in an ascending way first by observation then by augmentation. 51 | 52 | :param X: Dataset that will be transformed. 53 | :type X: pd.DataFrame 54 | :return: Transformed Dataframe. 55 | :rtype: pd.DataFrame 56 | """ 57 | return X.reorder_levels(["ts_obs", "ts_aug"]).sort_index() 58 | 59 | 60 | # Convertor class used on ML exploration 61 | class Convertor(OneToOneFeatureMixin, TransformerMixin, BaseEstimator): 62 | """Used to transform the xarray.Dataset into pandas.DataFrame and reduce the dimention and/or tranform it. 
63 | 64 | :param agg: If True then replace features with their aggregations along the state_dev axis (agg = min, mean, max), defaults to None 65 | :type agg: bool, optional 66 | :param weather: If False then remove weather data from the Dataset, defaults to True 67 | :type weather: bool, optional 68 | :param vi: If False then remove vegetable indices from the Dataset, defaults to True 69 | :type vi: bool, optional 70 | """ 71 | 72 | def __init__(self, agg: bool = None, weather: bool = True, vi: bool = True) -> None: 73 | self.agg = agg 74 | self.weather = weather 75 | self.vi = vi 76 | 77 | def to_dataframe(self, X: xr.Dataset) -> pd.DataFrame: 78 | # Convert xarray.Dataset into usable pandas.DataFrame 79 | 80 | # Depend of aggregations was performed, change the columns name 81 | col = "agg" if self.agg else "state_dev" 82 | # Convert xarray.Dataset into pandas.DataFrame 83 | df = X.to_dataframe() 84 | # set G_COLUMNS as index to not be duplicate by the pivot operation 85 | df.set_index(G_COLUMNS, append=True, inplace=True) 86 | # reset the columns use to apply the pivot and convert its values into string 87 | df.reset_index(col, inplace=True) 88 | df[col] = df[col].astype(str) 89 | # Apply pivot to change state_dev or agg from samples to features 90 | df = df.pivot(columns=col) 91 | # Convert pandas.MultiIndex to a pandas.Index by merging names 92 | df.columns = df.columns.map("_".join).str.strip("_") 93 | # set G_COLUMNS as features 94 | df.reset_index(G_COLUMNS, inplace=True) 95 | # sort dataset for future compability 96 | df = df.reorder_levels(["ts_obs", "ts_aug"]).sort_index() 97 | return df 98 | 99 | def merge_dimensions(self, X: xr.Dataset) -> xr.Dataset: 100 | # Merge VI, Geographical and Meteorological data into the same dimension 101 | X = xr.merge( 102 | [ 103 | X[G_COLUMNS], 104 | X[M_COLUMNS].sel(datetime=X["time"], name=X["District"]), 105 | X[S_COLUMNS], 106 | ] 107 | ) 108 | # Drop useless columns 109 | X = X.drop(["name", "datetime", "time"]) 110 | return X 111 | 112 | def compute_agg(self, X: xr.Dataset) -> xr.Dataset: 113 | # Compute aggregation on the Dataset and set the new dimension values 114 | # with the name of each aggregation performed 115 | X = xr.concat( 116 | [X.mean(dim="state_dev"), X.max(dim="state_dev"), X.min(dim="state_dev")], 117 | dim="agg", 118 | ) 119 | X["agg"] = ["mean", "max", "min"] 120 | return X 121 | 122 | def fit(self, X=None, y=None) -> object: 123 | """Identity function. 124 | 125 | :param X: Ignored 126 | :type X: None 127 | :param y: Ignored 128 | :type y: None 129 | :return: Convertor. 130 | :rtype: object 131 | """ 132 | return self 133 | 134 | def transform(self, X: xr.Dataset) -> pd.DataFrame: 135 | """Transform the xarray.Dataset to pandas.Dataframe depends on the argument of the class. 136 | 137 | :param X: Dataset that will be transformed. 138 | :type X: xr.Dataset 139 | :return: Transformed Dataset. 140 | :rtype: pd.DataFrame 141 | """ 142 | # Transform data to depends of the sames dimentions 143 | X = self.merge_dimensions(X) 144 | # If True, compute aggregation to the data 145 | if self.agg: 146 | X = self.compute_agg(X) 147 | # If False, remove weather data 148 | if not self.weather: 149 | X = X.drop(M_COLUMNS) 150 | # If False, remove vi data 151 | if not self.vi: 152 | X = X.drop(S_COLUMNS) 153 | # Convert the Dataset into a DataFrame 154 | X = self.to_dataframe(X) 155 | return X 156 | 157 | 158 | class Smoother(OneToOneFeatureMixin, TransformerMixin, BaseEstimator): 159 | """Smooth Vegetable Indice Data. 
160 | 161 | :param mode: methode used to smooth vi data, None to not perform smoothing during , defaults to "savgol" 162 | :type mode: str, optional 163 | """ 164 | 165 | def __init__(self, mode: str = "savgol") -> None: 166 | self.mode = mode 167 | 168 | def smooth_savgol(self, ds: xr.Dataset) -> xr.Dataset: 169 | # apply savgol_filter to vegetable indice 170 | ds_s = xr.apply_ufunc( 171 | savgol_filter, 172 | ds[S_COLUMNS], 173 | kwargs={"axis": 2, "window_length": 12, "polyorder": 4, "mode": "mirror"}, 174 | ) 175 | # merge both dataset and override old vegetable indice and bands 176 | return xr.merge([ds_s, ds], compat="override") 177 | 178 | def fit(self, X: xr.Dataset = None, y=None) -> object: 179 | """Identity function. 180 | 181 | :param X: Ignored, defaults to None 182 | :type X: xr.Dataset, optional 183 | :param y: Ignored, defaults to None 184 | :type y: _type_, optional 185 | :return: Themself. 186 | :rtype: object 187 | """ 188 | return self 189 | 190 | def transform(self, X: xr.Dataset) -> xr.Dataset: 191 | """Smooth Vegetable Indice Data according to the mode used. 192 | 193 | :param X: Dataset that will be transformed. 194 | :type X: xr.Dataset 195 | :return: Transformed Dataset. 196 | :rtype: xr.Dataset 197 | """ 198 | # If mode not equal to savgol, transform correspond to identity function. 199 | if self.mode == "savgol": 200 | X = self.smooth_savgol(X) 201 | 202 | return X 203 | 204 | 205 | def replaceinf(arr: np.ndarray) -> np.ndarray: 206 | if np.issubdtype(arr.dtype, np.number): 207 | arr[np.isinf(arr)] = np.nan 208 | return arr 209 | 210 | 211 | class Filler(OneToOneFeatureMixin, TransformerMixin, BaseEstimator): 212 | """Fill dataset using the mean of each group of observation for a given date. 213 | For the reaining data use the mean of the dataset for a given developpment state. 214 | """ 215 | 216 | def __init__(self) -> None: 217 | self.values = None 218 | 219 | def fit(self, X: xr.Dataset, y=None) -> object: 220 | """Compute mean by developpement state to be used for later filling. 221 | 222 | :param X: The data used to compute mean by developpement state used for later filling. 223 | :type X: xr.Dataset 224 | :param y: Ignored 225 | :type y: None 226 | :return: self 227 | :rtype: object 228 | """ 229 | # replace infinite value by na 230 | xr.apply_ufunc(replaceinf, X[S_COLUMNS]) 231 | # compute mean of all stage of developpement for each cluster obsevation 232 | self.values = ( 233 | X[S_COLUMNS].mean(dim="ts_aug", skipna=True).mean(dim="ts_obs", skipna=True) 234 | ) 235 | 236 | return self 237 | 238 | def transform(self, X: xr.Dataset) -> xr.Dataset: 239 | """Performs the filling of missing values 240 | 241 | :param X: The dataset used to fill. 242 | :type X: xr.Dataset 243 | :return: Transformed Dataset. 244 | :rtype: xr.Dataset 245 | """ 246 | # replace infinite value by na 247 | xr.apply_ufunc(replaceinf, X[S_COLUMNS]) 248 | # compute mean of all stage of developpement and all obsevation 249 | X[S_COLUMNS] = X[S_COLUMNS].fillna(X[S_COLUMNS].mean(dim="ts_aug", skipna=True)) 250 | # fill na value with fited mean 251 | X[S_COLUMNS] = X[S_COLUMNS].fillna(self.values) 252 | 253 | return X 254 | -------------------------------------------------------------------------------- /docs/build/html/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |