├── docs ├── build │ ├── html │ │ ├── _static │ │ │ ├── docsearch_config.js │ │ │ ├── awesome-sphinx-design.31d6cfe0d16ae931b73c.js │ │ │ ├── banner.jpg │ │ │ ├── file.png │ │ │ ├── minus.png │ │ │ ├── plus.png │ │ │ ├── 26400cae88e50682937d.woff │ │ │ ├── 3925889378745d0382d0.woff │ │ │ ├── 5b12b1b913a1d0348fc6.woff │ │ │ ├── 6c1a3008005254946aef.woff │ │ │ ├── b8546ea1646db8ea9c7f.woff │ │ │ ├── f509ddf49c74ded8c0ee.woff │ │ │ ├── 0ff19efc74e94c856af0.woff2 │ │ │ ├── 2a472f0334546ace60b3.woff2 │ │ │ ├── 4163112e566ed7697acf.woff2 │ │ │ ├── aef37e2fab43d03531cd.woff2 │ │ │ ├── c10c163dd1c289f11c49.woff2 │ │ │ ├── f4604891b5f1fc1bdbe5.woff2 │ │ │ ├── docsearch.8cecda6602174cdd6086.js.LICENSE.txt │ │ │ ├── theme.929b2cdd5fa757959a38.js.LICENSE.txt │ │ │ ├── documentation_options.js │ │ │ ├── awesome-sphinx-design.f3507627e0af8330cfa7.css │ │ │ ├── manifest.json │ │ │ ├── pygments.css │ │ │ ├── doctools.js │ │ │ ├── sphinx_highlight.js │ │ │ └── language_data.js │ │ ├── objects.inv │ │ ├── _images │ │ │ └── banner.jpg │ │ ├── _sources │ │ │ ├── model.rst.txt │ │ │ ├── make_train.rst.txt │ │ │ ├── make_train_ml.rst.txt │ │ │ ├── trainer.rst.txt │ │ │ ├── dataloader.rst.txt │ │ │ ├── make_submission.rst.txt │ │ │ ├── documentation.rst.txt │ │ │ ├── make_data.rst.txt │ │ │ ├── preprocessing.rst.txt │ │ │ ├── make_preprocessing.rst.txt │ │ │ ├── tutorial.rst.txt │ │ │ ├── install.rst.txt │ │ │ └── index.rst.txt │ │ ├── .buildinfo │ │ ├── search.html │ │ ├── documentation.html │ │ └── make_train.html │ └── doctrees │ │ ├── index.doctree │ │ ├── model.doctree │ │ ├── install.doctree │ │ ├── trainer.doctree │ │ ├── tutorial.doctree │ │ ├── dataloader.doctree │ │ ├── environment.pickle │ │ ├── make_data.doctree │ │ ├── make_train.doctree │ │ ├── documentation.doctree │ │ ├── make_train_ml.doctree │ │ ├── preprocessing.doctree │ │ ├── make_submission.doctree │ │ └── make_preprocessing.doctree ├── source │ ├── _static │ │ └── banner.jpg │ ├── model.rst │ ├── make_train.rst │ ├── trainer.rst │ ├── make_train_ml.rst │ ├── make_submission.rst │ ├── dataloader.rst │ ├── documentation.rst │ ├── make_data.rst │ ├── preprocessing.rst │ ├── make_preprocessing.rst │ ├── conf.py │ ├── tutorial.rst │ ├── install.rst │ └── index.rst ├── Makefile └── make.bat ├── utils.py ├── assets ├── model.png ├── smoothing.png └── crop-forecasting.png ├── .gitignore ├── CITATION.cff ├── environment.yml ├── src ├── constants.py ├── models │ ├── ml │ │ ├── sweep.yaml │ │ └── make_train.py │ ├── sweep.yaml │ ├── model.py │ ├── make_submission.py │ ├── make_train.py │ ├── trainer.py │ └── dataloader.py └── data │ ├── datascaler.py │ ├── preprocessing.py │ ├── make_data.py │ └── make_preprocessing.py ├── Makefile ├── LICENSE ├── README.md ├── notebooks └── data │ ├── concat_data.ipynb │ ├── enrich_xarray.ipynb │ └── enrich_csv.ipynb ├── data └── raw │ └── test.csv └── submissions └── toasty-sky-343.csv /docs/build/html/_static/docsearch_config.js: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/build/html/_static/awesome-sphinx-design.31d6cfe0d16ae931b73c.js: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 
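A minimal usage sketch for ROOT_DIR (not a file from the repository; the data layout below is assumed from notebooks/data/concat_data.ipynb and the FOLDER constant in src/constants.py): downstream modules and notebooks can build absolute paths from the project root instead of relying on the current working directory.

# Hypothetical usage sketch, not part of the repository.
# Assumes the data/processed/<FOLDER> layout used by the notebooks.
from os.path import join

from utils import ROOT_DIR  # works when the project root is on sys.path

FOLDER = 'augment_100_5'  # same value as FOLDER in src/constants.py
train_path = join(ROOT_DIR, 'data', 'processed', FOLDER, 'train_enriched.nc')
print(train_path)  # absolute path, independent of where the script is launched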
-------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/assets/model.png -------------------------------------------------------------------------------- /assets/smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/assets/smoothing.png -------------------------------------------------------------------------------- /assets/crop-forecasting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/assets/crop-forecasting.png -------------------------------------------------------------------------------- /docs/build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/objects.inv -------------------------------------------------------------------------------- /docs/source/_static/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/source/_static/banner.jpg -------------------------------------------------------------------------------- /docs/build/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/model.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/model.doctree -------------------------------------------------------------------------------- /docs/build/html/_images/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_images/banner.jpg -------------------------------------------------------------------------------- /docs/build/html/_static/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/banner.jpg -------------------------------------------------------------------------------- /docs/build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/file.png -------------------------------------------------------------------------------- /docs/build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/minus.png -------------------------------------------------------------------------------- /docs/build/html/_static/plus.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/build/doctrees/install.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/install.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/trainer.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/trainer.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/tutorial.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/tutorial.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/dataloader.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/dataloader.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/build/doctrees/make_data.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/make_data.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/make_train.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/make_train.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/documentation.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/documentation.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/make_train_ml.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/make_train_ml.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/preprocessing.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/preprocessing.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/make_submission.doctree: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/make_submission.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/make_preprocessing.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/doctrees/make_preprocessing.doctree -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | **/.idea/ 3 | **/.ipynb_checkpoints 4 | **/__pycache__ 5 | **/model 6 | **/wandb 7 | /models 8 | core 9 | **/lightning_logs 10 | -------------------------------------------------------------------------------- /docs/build/html/_static/26400cae88e50682937d.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/26400cae88e50682937d.woff -------------------------------------------------------------------------------- /docs/build/html/_static/3925889378745d0382d0.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/3925889378745d0382d0.woff -------------------------------------------------------------------------------- /docs/build/html/_static/5b12b1b913a1d0348fc6.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/5b12b1b913a1d0348fc6.woff -------------------------------------------------------------------------------- /docs/build/html/_static/6c1a3008005254946aef.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/6c1a3008005254946aef.woff -------------------------------------------------------------------------------- /docs/build/html/_static/b8546ea1646db8ea9c7f.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/b8546ea1646db8ea9c7f.woff -------------------------------------------------------------------------------- /docs/build/html/_static/f509ddf49c74ded8c0ee.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/f509ddf49c74ded8c0ee.woff -------------------------------------------------------------------------------- /docs/build/html/_static/0ff19efc74e94c856af0.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/0ff19efc74e94c856af0.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/2a472f0334546ace60b3.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/2a472f0334546ace60b3.woff2 
-------------------------------------------------------------------------------- /docs/build/html/_static/4163112e566ed7697acf.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/4163112e566ed7697acf.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/aef37e2fab43d03531cd.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/aef37e2fab43d03531cd.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/c10c163dd1c289f11c49.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/c10c163dd1c289f11c49.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/f4604891b5f1fc1bdbe5.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/association-rosia/crop-forecasting/HEAD/docs/build/html/_static/f4604891b5f1fc1bdbe5.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/docsearch.8cecda6602174cdd6086.js.LICENSE.txt: -------------------------------------------------------------------------------- 1 | /*! @docsearch/js 3.3.5 | MIT License | © Algolia, Inc. and contributors | https://docsearch.algolia.com */ 2 | -------------------------------------------------------------------------------- /docs/build/html/_static/theme.929b2cdd5fa757959a38.js.LICENSE.txt: -------------------------------------------------------------------------------- 1 | /*! 2 | * clipboard.js v2.0.11 3 | * https://clipboardjs.com/ 4 | * 5 | * Licensed MIT © Zeno Rocha 6 | */ 7 | -------------------------------------------------------------------------------- /docs/source/model.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Pytorch Model 3 | ============= 4 | 5 | .. currentmodule:: src.models.model 6 | 7 | LSTMModel 8 | --------- 9 | 10 | .. autoclass:: LSTMModel 11 | :members: -------------------------------------------------------------------------------- /docs/build/html/_sources/model.rst.txt: -------------------------------------------------------------------------------- 1 | ============= 2 | Pytorch Model 3 | ============= 4 | 5 | .. currentmodule:: src.models.model 6 | 7 | LSTMModel 8 | --------- 9 | 10 | .. autoclass:: LSTMModel 11 | :members: -------------------------------------------------------------------------------- /docs/source/make_train.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Train - Deep Learning 3 | ===================== 4 | 5 | .. currentmodule:: src.models.make_train 6 | 7 | init_wandb 8 | ---------- 9 | 10 | .. autofunction:: init_wandb -------------------------------------------------------------------------------- /docs/build/html/_sources/make_train.rst.txt: -------------------------------------------------------------------------------- 1 | ===================== 2 | Train - Deep Learning 3 | ===================== 4 | 5 | .. 
currentmodule:: src.models.make_train 6 | 7 | init_wandb 8 | ---------- 9 | 10 | .. autofunction:: init_wandb -------------------------------------------------------------------------------- /docs/build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: 6c63be8d011e773470cf148cdaaab1f0 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/source/trainer.rst: -------------------------------------------------------------------------------- 1 | ======================= 2 | Trainer - Deep Learning 3 | ======================= 4 | 5 | .. currentmodule:: src.models.trainer 6 | 7 | compute_r2_scores 8 | ----------------- 9 | 10 | .. autofunction:: compute_r2_scores 11 | 12 | Trainer 13 | ------- 14 | 15 | .. autoclass:: Trainer 16 | :members: -------------------------------------------------------------------------------- /docs/source/make_train_ml.rst: -------------------------------------------------------------------------------- 1 | ======================== 2 | Train - Machine Learning 3 | ======================== 4 | 5 | .. currentmodule:: src.models.ml.make_train 6 | 7 | init_pipeline 8 | ------------- 9 | 10 | .. autofunction:: init_pipeline 11 | 12 | preprocess_y 13 | ------------ 14 | 15 | .. autofunction:: preprocess_y -------------------------------------------------------------------------------- /docs/build/html/_sources/make_train_ml.rst.txt: -------------------------------------------------------------------------------- 1 | ======================== 2 | Train - Machine Learning 3 | ======================== 4 | 5 | .. currentmodule:: src.models.ml.make_train 6 | 7 | init_pipeline 8 | ------------- 9 | 10 | .. autofunction:: init_pipeline 11 | 12 | preprocess_y 13 | ------------ 14 | 15 | .. autofunction:: preprocess_y -------------------------------------------------------------------------------- /docs/build/html/_sources/trainer.rst.txt: -------------------------------------------------------------------------------- 1 | ======================= 2 | Trainer - Deep Learning 3 | ======================= 4 | 5 | .. currentmodule:: src.models.trainer 6 | 7 | compute_r2_scores 8 | ----------------- 9 | 10 | .. autofunction:: compute_r2_scores 11 | 12 | Trainer 13 | ------- 14 | 15 | .. autoclass:: Trainer 16 | :members: -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "URGELL" 5 | given-names: "Baptiste" 6 | - family-names: "REBERGA" 7 | given-names: "Louis" 8 | title: "GitHub repository" 9 | publisher: "Github" 10 | year: "2023" 11 | version: 1.0 12 | date-released: 2023-4-9 13 | url: "https://github.com/association-rosia/crop-forecasting" 14 | data: "Crop Yield Data - EY" 15 | -------------------------------------------------------------------------------- /docs/source/make_submission.rst: -------------------------------------------------------------------------------- 1 | =================== 2 | Create a Submission 3 | =================== 4 | 5 | .. currentmodule:: src.models.make_submission 6 | 7 | rounded_yield 8 | ------------- 9 | 10 | .. 
autofunction:: rounded_yield 11 | 12 | get_device 13 | ---------- 14 | 15 | .. autofunction:: get_device 16 | 17 | create_submission 18 | ----------------- 19 | 20 | .. autofunction:: create_submission 21 | 22 | Evaluator 23 | --------- 24 | 25 | .. autoclass:: Evaluator 26 | :members: -------------------------------------------------------------------------------- /docs/source/dataloader.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | Dataloader 3 | ========== 4 | 5 | .. currentmodule:: src.models.dataloader 6 | 7 | CustomDataset 8 | ------------- 9 | 10 | .. autoclass:: CustomDataset 11 | :members: 12 | 13 | create_train_val_idx 14 | -------------------- 15 | 16 | .. autofunction:: create_train_val_idx 17 | 18 | transform_data 19 | -------------- 20 | 21 | .. autofunction:: transform_data 22 | 23 | get_dataloaders 24 | --------------- 25 | 26 | .. autofunction:: get_dataloaders -------------------------------------------------------------------------------- /docs/source/documentation.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Documentation 3 | ============= 4 | 5 | .. toctree:: 6 | :caption: Data 7 | :titlesonly: 8 | 9 | make_data 10 | make_preprocessing 11 | preprocessing 12 | 13 | .. toctree:: 14 | :caption: Deep Learning 15 | :titlesonly: 16 | 17 | make_train 18 | trainer 19 | dataloader 20 | model 21 | make_submission 22 | 23 | .. toctree:: 24 | :caption: Machine Learning 25 | :titlesonly: 26 | 27 | make_train_ml -------------------------------------------------------------------------------- /docs/source/make_data.rst: -------------------------------------------------------------------------------- 1 | ================================== 2 | Download Data - Planetary Computer 3 | ================================== 4 | 5 | .. currentmodule:: src.data.make_data 6 | 7 | save_data 8 | --------- 9 | 10 | .. autofunction:: save_data 11 | 12 | save_data_app 13 | ------------- 14 | 15 | .. autofunction:: save_data_app 16 | 17 | init_df 18 | ------- 19 | 20 | .. autofunction:: init_df 21 | 22 | make_data 23 | --------- 24 | 25 | .. autofunction:: make_data 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/build/html/_sources/dataloader.rst.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Dataloader 3 | ========== 4 | 5 | .. currentmodule:: src.models.dataloader 6 | 7 | CustomDataset 8 | ------------- 9 | 10 | .. autoclass:: CustomDataset 11 | :members: 12 | 13 | create_train_val_idx 14 | -------------------- 15 | 16 | .. autofunction:: create_train_val_idx 17 | 18 | transform_data 19 | -------------- 20 | 21 | .. autofunction:: transform_data 22 | 23 | get_dataloaders 24 | --------------- 25 | 26 | .. autofunction:: get_dataloaders -------------------------------------------------------------------------------- /docs/build/html/_sources/make_submission.rst.txt: -------------------------------------------------------------------------------- 1 | =================== 2 | Create a Submission 3 | =================== 4 | 5 | .. currentmodule:: src.models.make_submission 6 | 7 | rounded_yield 8 | ------------- 9 | 10 | .. autofunction:: rounded_yield 11 | 12 | get_device 13 | ---------- 14 | 15 | .. autofunction:: get_device 16 | 17 | create_submission 18 | ----------------- 19 | 20 | .. autofunction:: create_submission 21 | 22 | Evaluator 23 | --------- 24 | 25 | .. 
autoclass:: Evaluator 26 | :members: -------------------------------------------------------------------------------- /docs/build/html/_sources/documentation.rst.txt: -------------------------------------------------------------------------------- 1 | ============= 2 | Documentation 3 | ============= 4 | 5 | .. toctree:: 6 | :caption: Data 7 | :titlesonly: 8 | 9 | make_data 10 | make_preprocessing 11 | preprocessing 12 | 13 | .. toctree:: 14 | :caption: Deep Learning 15 | :titlesonly: 16 | 17 | make_train 18 | trainer 19 | dataloader 20 | model 21 | make_submission 22 | 23 | .. toctree:: 24 | :caption: Machine Learning 25 | :titlesonly: 26 | 27 | make_train_ml -------------------------------------------------------------------------------- /docs/build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '1.0.0', 4 | LANGUAGE: 'en', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false, 12 | SHOW_SEARCH_SUMMARY: true, 13 | ENABLE_SEARCH_SHORTCUTS: true, 14 | }; -------------------------------------------------------------------------------- /docs/build/html/_sources/make_data.rst.txt: -------------------------------------------------------------------------------- 1 | ================================== 2 | Download Data - Planetary Computer 3 | ================================== 4 | 5 | .. currentmodule:: src.data.make_data 6 | 7 | save_data 8 | --------- 9 | 10 | .. autofunction:: save_data 11 | 12 | save_data_app 13 | ------------- 14 | 15 | .. autofunction:: save_data_app 16 | 17 | init_df 18 | ------- 19 | 20 | .. autofunction:: init_df 21 | 22 | make_data 23 | --------- 24 | 25 | .. autofunction:: make_data 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: crop-forecasting-env 2 | dependencies: 3 | - python=3.10.* 4 | - pip 5 | - pandas 6 | - ipykernel 7 | - netCDF4 8 | - matplotlib 9 | - Jinja2 10 | - python-dotenv 11 | - seaborn 12 | - statsmodels 13 | - scikit-learn 14 | - rich 15 | - ipywidgets 16 | - pytorch::pytorch 17 | - pip: 18 | - xarray-spatial 19 | - odc-stac 20 | - rioxarray 21 | - pystac 22 | - ipyleaflet 23 | - stackstac 24 | - pystac-client 25 | - planetary-computer 26 | - wandb -------------------------------------------------------------------------------- /docs/source/preprocessing.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Preprocessing - Utils 3 | ===================== 4 | 5 | .. currentmodule:: src.data.preprocessing 6 | 7 | Sorter 8 | ------ 9 | 10 | .. autoclass:: Sorter 11 | :members: 12 | 13 | Convertor 14 | --------- 15 | 16 | .. autoclass:: Convertor 17 | :members: 18 | 19 | Smoother 20 | -------- 21 | 22 | .. autoclass:: Smoother 23 | :members: 24 | 25 | Filler 26 | ------ 27 | 28 | .. autoclass:: Filler 29 | :members: 30 | 31 | .. currentmodule:: src.data.datascaler 32 | 33 | DatasetScaler 34 | ------------- 35 | 36 | .. 
autoclass:: DatasetScaler 37 | :members: -------------------------------------------------------------------------------- /docs/build/html/_sources/preprocessing.rst.txt: -------------------------------------------------------------------------------- 1 | ===================== 2 | Preprocessing - Utils 3 | ===================== 4 | 5 | .. currentmodule:: src.data.preprocessing 6 | 7 | Sorter 8 | ------ 9 | 10 | .. autoclass:: Sorter 11 | :members: 12 | 13 | Convertor 14 | --------- 15 | 16 | .. autoclass:: Convertor 17 | :members: 18 | 19 | Smoother 20 | -------- 21 | 22 | .. autoclass:: Smoother 23 | :members: 24 | 25 | Filler 26 | ------ 27 | 28 | .. autoclass:: Filler 29 | :members: 30 | 31 | .. currentmodule:: src.data.datascaler 32 | 33 | DatasetScaler 34 | ------------- 35 | 36 | .. autoclass:: DatasetScaler 37 | :members: -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | FOLDER = 'augment_100_5' 2 | 3 | BANDS = ['red', 'green', 'blue', 'rededge1', 'rededge2', 'rededge3', 'nir', 'swir'] 4 | VI = ['ndvi', 'savi', 'evi', 'rep', 'osavi', 'rdvi', 'mtvi1', 'lswi'] 5 | 6 | M_COLUMNS = ['tempmax', 'tempmin', 'temp', 'dew', 'humidity', 'precip', 'precipprob', 'precipcover', 'windspeed', 7 | 'winddir', 'sealevelpressure', 'cloudcover', 'solarradiation', 'solarenergy', 'uvindex', 'moonphase', 8 | 'solarexposure'] 9 | 10 | S_COLUMNS = ['ndvi', 'savi', 'evi', 'rep', 'osavi', 'rdvi', 'mtvi1', 'lswi'] 11 | 12 | G_COLUMNS = ['Other Rice Yield (kg/ha)', 'Field size (ha)', 'Rice Crop Intensity(D=Double, T=Triple)'] 13 | 14 | TARGET = 'Rice Yield (kg/ha)' 15 | 16 | TARGET_TEST = 'Predicted Rice Yield (kg/ha)' 17 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile 2 | 3 | # Default target 4 | all: help 5 | 6 | # Help target 7 | help: 8 | @echo "Usage: make [target]" 9 | @echo "Targets:" 10 | @echo " train Train the model" 11 | @echo " data Download data" 12 | @echo " preprocess Preprocess data" 13 | @echo " submission Generate submission" 14 | @echo " optimize Optimize with sweep ID" 15 | @echo "" 16 | @echo "Variables:" 17 | @echo " sweepid Sweep ID for optimization" 18 | 19 | # Train target 20 | train: 21 | python src/models/make_train.py -m 22 | 23 | # Data target 24 | data: 25 | python src/data/make_data.py 26 | 27 | # Submission target 28 | submission: 29 | python src/models/make_submission.py 30 | 31 | # Preprocess target 32 | preprocess: 33 | python 
src/data/make_preprocessing.py 34 | 35 | # Optimize target 36 | optimize: 37 | @wandb agent winged-bull/crop-forecasting/$(sweepid) 38 | -------------------------------------------------------------------------------- /docs/source/make_preprocessing.rst: -------------------------------------------------------------------------------- 1 | =============================== 2 | Preprocess Data - Deep Learning 3 | =============================== 4 | 5 | .. currentmodule:: src.data.make_preprocessing 6 | 7 | merge_satellite 8 | --------------- 9 | 10 | .. autofunction:: merge_satellite 11 | 12 | add_observation 13 | --------------- 14 | 15 | .. autofunction:: add_observation 16 | 17 | add_weather 18 | ----------- 19 | 20 | .. autofunction:: add_weather 21 | 22 | compute_vi 23 | ---------- 24 | 25 | .. autofunction:: compute_vi 26 | 27 | features_modification 28 | --------------------- 29 | 30 | .. autofunction:: features_modification 31 | 32 | scale_data 33 | ---------- 34 | 35 | .. autofunction:: scale_data 36 | 37 | create_id 38 | --------- 39 | 40 | .. autofunction:: create_id 41 | 42 | create_pb 43 | --------- 44 | 45 | .. autofunction:: create_pb 46 | 47 | process_data 48 | ------------ 49 | 50 | .. autofunction:: process_data -------------------------------------------------------------------------------- /docs/build/html/_sources/make_preprocessing.rst.txt: -------------------------------------------------------------------------------- 1 | =============================== 2 | Preprocess Data - Deep Learning 3 | =============================== 4 | 5 | .. currentmodule:: src.data.make_preprocessing 6 | 7 | merge_satellite 8 | --------------- 9 | 10 | .. autofunction:: merge_satellite 11 | 12 | add_observation 13 | --------------- 14 | 15 | .. autofunction:: add_observation 16 | 17 | add_weather 18 | ----------- 19 | 20 | .. autofunction:: add_weather 21 | 22 | compute_vi 23 | ---------- 24 | 25 | .. autofunction:: compute_vi 26 | 27 | features_modification 28 | --------------------- 29 | 30 | .. autofunction:: features_modification 31 | 32 | scale_data 33 | ---------- 34 | 35 | .. autofunction:: scale_data 36 | 37 | create_id 38 | --------- 39 | 40 | .. autofunction:: create_id 41 | 42 | create_pb 43 | --------- 44 | 45 | .. autofunction:: create_pb 46 | 47 | process_data 48 | ------------ 49 | 50 | .. autofunction:: process_data -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 RosIA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/models/ml/sweep.yaml: -------------------------------------------------------------------------------- 1 | method: bayes 2 | 3 | metric: 4 | name: val_r2_score 5 | goal: maximize 6 | 7 | program: src/models/ml/train.py 8 | 9 | parameters: 10 | 11 | n_splits: 12 | distribution: constant 13 | value: 5 14 | 15 | dim_reduction: 16 | distribution: categorical 17 | values: ['Aggregate', 'PCA', None] 18 | 19 | weather: 20 | distribution: categorical 21 | values: [True, False] 22 | 23 | vi: 24 | distribution: categorical 25 | values: ['savgol', True, False] 26 | 27 | n_estimators: 28 | distribution: int_uniform 29 | min: 100 30 | max: 1000 31 | 32 | colsample_bytree: 33 | distribution: q_uniform 34 | min: 0.1 35 | max: 0.5 36 | q: 0.01 37 | 38 | colsample_bylevel: 39 | distribution: q_uniform 40 | min: 0.1 41 | max: 0.5 42 | q: 0.01 43 | 44 | colsample_bynode: 45 | distribution: q_uniform 46 | min: 0.1 47 | max: 0.5 48 | q: 0.01 49 | 50 | subsample: 51 | distribution: q_uniform 52 | min: 0.1 53 | max: 0.5 54 | q: 0.01 55 | 56 | max_depth: 57 | distribution: int_uniform 58 | min: 2 59 | max: 10 60 | 61 | learning_rate: 62 | distribution: q_uniform 63 | min: 0.001 64 | max: 0.1 65 | q: 0.001 -------------------------------------------------------------------------------- /src/models/sweep.yaml: -------------------------------------------------------------------------------- 1 | method: bayes 2 | 3 | metric: 4 | name: val_r2_score 5 | goal: maximize 6 | 7 | program: src/models/make_train.py 8 | 9 | parameters: 10 | epochs: 11 | distribution: constant 12 | value: 20 13 | 14 | batch_size: 15 | distribution: categorical 16 | values: [16, 32, 64] 17 | 18 | learning_rate: 19 | distribution: uniform 20 | min: 0.00001 21 | max: 0.001 22 | 23 | scheduler_patience: 24 | 
distribution: int_uniform 25 | min: 3 26 | max: 9 27 | 28 | s_hidden_size: 29 | distribution: int_uniform 30 | min: 64 31 | max: 256 32 | 33 | m_hidden_size: 34 | distribution: int_uniform 35 | min: 64 36 | max: 256 37 | 38 | s_num_layers: 39 | distribution: int_uniform 40 | min: 1 41 | max: 2 42 | 43 | m_num_layers: 44 | distribution: int_uniform 45 | min: 1 46 | max: 2 47 | 48 | c_out_in_features_1: 49 | distribution: int_uniform 50 | min: 64 51 | max: 256 52 | 53 | c_out_in_features_2: 54 | distribution: int_uniform 55 | min: 64 56 | max: 256 57 | 58 | dropout: 59 | distribution: uniform 60 | min: 0. 61 | max: 0.8 -------------------------------------------------------------------------------- /docs/build/html/_static/awesome-sphinx-design.f3507627e0af8330cfa7.css: -------------------------------------------------------------------------------- 1 | :root{--sd-color-tabs-label-active:hsl(var(--foreground));--sd-color-tabs-underline-active:hsl(var(--accent-foreground));--sd-color-tabs-label-hover:hsl(var(--accent-foreground));--sd-color-tabs-overline:hsl(var(--border));--sd-color-tabs-underline:hsl(var(--border))}.sd-card{background-color:hsl(var(--card));border-color:hsl(var(--border));border-radius:var(--radius);border-width:1px;color:hsl(var(--card-foreground));margin-top:1.5rem}.sd-container-fluid{margin-bottom:1.5rem;margin-top:1.5rem}.sd-card-title{font-weight:600!important}.sd-summary-title{color:hsl(var(--muted-foreground));font-weight:500!important}.sd-card-footer,.sd-card-header{font-size:.875rem;line-height:1.25rem}.sd-tab-set{margin-top:1.5rem}.sd-tab-content>p{margin-bottom:1.5rem}.sd-tab-content pre:first-of-type{margin-top:0}.sd-tab-set>label{font-weight:500;letter-spacing:.05em}details.sd-dropdown{border-color:hsl(var(--border))}details.sd-dropdown summary:focus{outline-style:solid}.sd-cards-carousel{overflow-x:auto}.sd-shadow-sm{--tw-shadow:0 0 transparent!important;--tw-shadow-colored:0 0 transparent!important;box-shadow:0 0 transparent,0 0 transparent,0 0 transparent!important;box-shadow:var(--tw-ring-offset-shadow,0 0 #0000),var(--tw-ring-shadow,0 0 #0000),var(--tw-shadow)!important} 2 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | import sys 4 | 5 | curdir = os.path.abspath(os.curdir) 6 | sys.path.insert(0, join(curdir, os.pardir, os.pardir)) 7 | 8 | # Configuration file for the Sphinx documentation builder. 
9 | # 10 | # For the full list of built-in configuration values, see the documentation: 11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 12 | 13 | # -- Project information ----------------------------------------------------- 14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 15 | 16 | project = 'Crop forecasting' 17 | copyright = '2023, Baptiste URGELL, Louis REBERGA' 18 | author = 'Baptiste URGELL, Louis REBERGA' 19 | release = '1.0.0' 20 | 21 | # -- General configuration --------------------------------------------------- 22 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 23 | 24 | extensions = ['sphinx.ext.autodoc'] 25 | 26 | templates_path = ['_templates'] 27 | exclude_patterns = [] 28 | 29 | 30 | 31 | # -- Options for HTML output ------------------------------------------------- 32 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 33 | 34 | # html_theme = 'karma_sphinx_theme' 35 | html_theme = "sphinxawesome_theme" 36 | 37 | html_static_path = ['_static'] 38 | 39 | 40 | rst_prolog = """ 41 | .. |project_name| replace:: Crop forecasting 42 | .. image:: _static/banner.jpg 43 | """ 44 | -------------------------------------------------------------------------------- /docs/build/html/_static/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "_static/theme.css": "_static/theme.871137ef162fc5460fb0.css", 3 | "_static/theme.js": "_static/theme.929b2cdd5fa757959a38.js", 4 | "_static/docsearch.css": "_static/docsearch.82e38f6cd0ccc11d77d4.css", 5 | "_static/docsearch.js": "_static/docsearch.8cecda6602174cdd6086.js", 6 | "_static/awesome-sphinx-design.css": "_static/awesome-sphinx-design.f3507627e0af8330cfa7.css", 7 | "_static/awesome-sphinx-design.js": "_static/awesome-sphinx-design.31d6cfe0d16ae931b73c.js", 8 | "_static/jetbrains-mono-latin-500-italic.woff": "_static/5b12b1b913a1d0348fc6.woff", 9 | "_static/jetbrains-mono-latin-700-italic.woff": "_static/b8546ea1646db8ea9c7f.woff", 10 | "_static/jetbrains-mono-latin-400-italic.woff": "_static/f509ddf49c74ded8c0ee.woff", 11 | "_static/jetbrains-mono-latin-700-normal.woff": "_static/3925889378745d0382d0.woff", 12 | "_static/jetbrains-mono-latin-500-normal.woff": "_static/26400cae88e50682937d.woff", 13 | "_static/jetbrains-mono-latin-400-normal.woff": "_static/6c1a3008005254946aef.woff", 14 | "_static/jetbrains-mono-latin-700-italic.woff2": "_static/aef37e2fab43d03531cd.woff2", 15 | "_static/jetbrains-mono-latin-500-italic.woff2": "_static/c10c163dd1c289f11c49.woff2", 16 | "_static/jetbrains-mono-latin-400-italic.woff2": "_static/2a472f0334546ace60b3.woff2", 17 | "_static/jetbrains-mono-latin-500-normal.woff2": "_static/0ff19efc74e94c856af0.woff2", 18 | "_static/jetbrains-mono-latin-700-normal.woff2": "_static/f4604891b5f1fc1bdbe5.woff2", 19 | "_static/jetbrains-mono-latin-400-normal.woff2": "_static/4163112e566ed7697acf.woff2", 20 | "_static/docsearch_config.js_t": "_static/docsearch_config.js_t" 21 | } -------------------------------------------------------------------------------- /docs/source/tutorial.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Tutorial 3 | ======== 4 | 5 | Overview 6 | -------- 7 | 8 | Welcome to our data-driven project! 
In this tutorial, we will walk you through the steps to create and process a satellite dataset using the Cookiecutter architecture, and then run a parameter sweep using the Weights & Biases (wandb) platform. 9 | 10 | 11 | .. warning:: 12 | .. line-block:: 13 | Make sure you have installed all dependencies and have created a wandb account. 14 | See :doc:`Installation Guide ` for further information. 15 | 16 | 17 | Get Started |project_name| 18 | -------------------------- 19 | 20 | Navigate to the project directory in your terminal. 21 | 22 | Create satellite data 23 | ===================== 24 | 25 | - Run the `make_data.py` script to generate the satellite dataset (the datasets have already been generated). 26 | 27 | .. code-block:: bash 28 | 29 | make data 30 | 31 | Preprocess data 32 | =============== 33 | 34 | - After generating the dataset, run the `make_preprocessing.py` script to preprocess and analyze the data. 35 | 36 | .. code-block:: bash 37 | 38 | make preprocess 39 | 40 | Train the model 41 | =============== 42 | 43 | - You can now train the model with the configuration available at the end of the file `make_train.py`. 44 | 45 | .. code-block:: bash 46 | 47 | make train 48 | 49 | Creating a Wandb Sweep 50 | ====================== 51 | 52 | - Next, create a Weights & Biases (wandb) sweep to perform a parameter search for model training. A sweep is a set of hyperparameters and their possible values that wandb uses to launch multiple runs with different configurations. 53 | 54 | - Open the `sweep.yaml` file provided in the project repository, and specify the hyperparameters and their ranges or values for the sweep. 55 | 56 | - Save the `sweep.yaml` file. 57 | 58 | Running the Sweep 59 | ================= 60 | 61 | - Once the sweep configuration is set, you can run the sweep using the `wandb sweep` command in your terminal, specifying the path to the `sweep.yaml` file. 62 | 63 | .. code-block:: bash 64 | 65 | wandb sweep --project crop-forecasting sweep.yaml 66 | 67 | - This will generate a sweep ID, which you can use to launch the sweep runs. 68 | 69 | - Run the sweep using the `wandb agent` command, specifying the sweep ID. 70 | 71 | .. code-block:: bash 72 | 73 | wandb agent <entity>/crop-forecasting/<sweep_id> 74 | 75 | - The agent will start running the sweep runs with different hyperparameter configurations, and wandb will log the results for each run. 76 | 77 | Bonus - Machine Learning 78 | ======================== 79 | 80 | - Use the sweep.yaml located in src/models/ml. 81 | 82 | - Adapt the path to the configuration file: 83 | 84 | .. code-block:: bash 85 | 86 | wandb sweep --project crop_forecasting src/models/ml/sweep.yaml 87 | -------------------------------------------------------------------------------- /docs/build/html/_sources/tutorial.rst.txt: -------------------------------------------------------------------------------- 1 | ======== 2 | Tutorial 3 | ======== 4 | 5 | Overview 6 | -------- 7 | 8 | Welcome to our data-driven project! In this tutorial, we will walk you through the steps to create and process a satellite dataset using the Cookiecutter architecture, and then run a parameter sweep using the Weights & Biases (wandb) platform. 9 | 10 | 11 | .. warning:: 12 | .. line-block:: 13 | Make sure you have installed all dependencies and have created a wandb account. 14 | See :doc:`Installation Guide ` for futher informations. 15 | 16 | 17 | Get Started |project_name| 18 | -------------------------- 19 | 20 | Navigate to the project directory in your terminal. 
21 | 22 | Create satellite data 23 | ===================== 24 | 25 | - Run the `make_data.py` script to generate the satellite dataset (datasets has already been generated). 26 | 27 | .. code-block:: bash 28 | 29 | make data 30 | 31 | Preprocess data 32 | =============== 33 | 34 | - After generating the dataset, run the `make_preprocessing.py` script to preprocess and analyze the data. 35 | 36 | .. code-block:: bash 37 | 38 | make preprocess 39 | 40 | Train the model 41 | =============== 42 | 43 | - You can now train the model with the configuration available at the end of the file `make_train.py`. 44 | 45 | .. code-block:: bash 46 | 47 | make train 48 | 49 | Creating a Wandb Sweep 50 | ====================== 51 | 52 | - Next, create a Weights & Biases (wandb) sweep to perform a parameter search for model training. A sweep is a set of hyperparameters and their possible values that wandb uses to launch multiple runs with different configurations. 53 | 54 | - Open the `sweep.yml` file provided in the project repository, and specify the hyperparameters and their ranges or values for the sweep. 55 | 56 | - Save the `sweep.yml` file. 57 | 58 | Running the Sweep 59 | ================= 60 | 61 | - Once the sweep configuration is set, you can run the sweep using the `wandb sweep` command in your terminal, specifying the path to the `sweep.yml` file. 62 | 63 | .. code-block:: bash 64 | 65 | wandb sweep --project crop-forecasting sweep.yaml 66 | 67 | - This will generate a sweep ID, which you can use to launch the sweep runs. 68 | 69 | - Run the sweep using the `wandb agent` command, specifying the sweep ID. 70 | 71 | .. code-block:: bash 72 | 73 | wandb agent /crop-forecasting/ 74 | 75 | - The agent will start running the sweep runs with different hyperparameter configurations, and wandb will log the results for each run. 76 | 77 | Bonus - Machine Learning 78 | ======================== 79 | 80 | - Use the sweep.yml located in src/models/ml. 81 | 82 | - Adapt the path to the configuration file: 83 | 84 | .. 
code-block:: bash 85 | 86 | wandb sweep --project crop_forecasting src/models/ml/sweep.yaml 87 | -------------------------------------------------------------------------------- /docs/build/html/_static/pygments.css: -------------------------------------------------------------------------------- 1 | pre { line-height: 125%; } 2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 6 | .highlight .hll { background-color: #ffffcc } 7 | .highlight { background: #ffffff; } 8 | .highlight .c { font-style: italic } /* Comment */ 9 | .highlight .err { border: 1px solid #FF0000 } /* Error */ 10 | .highlight .k { font-weight: bold } /* Keyword */ 11 | .highlight .ch { font-style: italic } /* Comment.Hashbang */ 12 | .highlight .cm { font-style: italic } /* Comment.Multiline */ 13 | .highlight .cpf { font-style: italic } /* Comment.PreprocFile */ 14 | .highlight .c1 { font-style: italic } /* Comment.Single */ 15 | .highlight .cs { font-style: italic } /* Comment.Special */ 16 | .highlight .ge { font-style: italic } /* Generic.Emph */ 17 | .highlight .gh { font-weight: bold } /* Generic.Heading */ 18 | .highlight .gp { font-weight: bold } /* Generic.Prompt */ 19 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 20 | .highlight .gu { font-weight: bold } /* Generic.Subheading */ 21 | .highlight .kc { font-weight: bold } /* Keyword.Constant */ 22 | .highlight .kd { font-weight: bold } /* Keyword.Declaration */ 23 | .highlight .kn { font-weight: bold } /* Keyword.Namespace */ 24 | .highlight .kr { font-weight: bold } /* Keyword.Reserved */ 25 | .highlight .s { font-style: italic } /* Literal.String */ 26 | .highlight .nc { font-weight: bold } /* Name.Class */ 27 | .highlight .ni { font-weight: bold } /* Name.Entity */ 28 | .highlight .ne { font-weight: bold } /* Name.Exception */ 29 | .highlight .nn { font-weight: bold } /* Name.Namespace */ 30 | .highlight .nt { font-weight: bold } /* Name.Tag */ 31 | .highlight .ow { font-weight: bold } /* Operator.Word */ 32 | .highlight .sa { font-style: italic } /* Literal.String.Affix */ 33 | .highlight .sb { font-style: italic } /* Literal.String.Backtick */ 34 | .highlight .sc { font-style: italic } /* Literal.String.Char */ 35 | .highlight .dl { font-style: italic } /* Literal.String.Delimiter */ 36 | .highlight .sd { font-style: italic } /* Literal.String.Doc */ 37 | .highlight .s2 { font-style: italic } /* Literal.String.Double */ 38 | .highlight .se { font-weight: bold; font-style: italic } /* Literal.String.Escape */ 39 | .highlight .sh { font-style: italic } /* Literal.String.Heredoc */ 40 | .highlight .si { font-weight: bold; font-style: italic } /* Literal.String.Interpol */ 41 | .highlight .sx { font-style: italic } /* Literal.String.Other */ 42 | .highlight .sr { font-style: italic } /* Literal.String.Regex */ 43 | .highlight .s1 { font-style: italic } /* Literal.String.Single */ 44 | .highlight .ss { font-style: italic } /* Literal.String.Symbol */ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🍚 Crop Forecasting 2 | 3 | 4 | 5 | The 
2023 EY Open Science Data Challenge - Crop Forecasting project is a data science project conducted as part of the challenge organized by EY, Microsoft, and Cornell University. The objective of this project is to predict the yield of rice fields using satellite image data provided by the Microsoft Planetary Computer, meteorological data, and field data. 6 | 7 | ## 🏆 Challenge ranking 8 | The challenge metric was the R2 score. 9 | Our solution ranked 4th out of 185 teams, with an R2 score of 0.66 🎉. 10 | 11 | The podium: 12 | 🥇 Outatime - 0.68 13 | 🥈 Joshua Rexmond Nunoo Otoo - 0.68 14 | 🥉 Amma Simmons - 0.67 15 | 16 | ## 🛠️ Data processing 17 | 18 | 19 | ## 🏛️ Model architecture 20 | 21 | 22 | ## 📚 Documentation 23 | The project documentation, generated using Sphinx, can be found in the `docs/` directory. It provides detailed information about the project's setup, usage, and implementation, as well as a tutorial. 24 | 25 | ## 🔬 References 26 | 27 | Jeong, S., Ko, J., & Yeom, J. M. (2022). Predicting rice yield at pixel scale through synthetic use of crop and deep learning models with satellite data in South and North Korea. Science of The Total Environment, 802, 149726. 28 | 29 | Nazir, A., Ullah, S., Saqib, Z. A., Abbas, A., Ali, A., Iqbal, M. S., ... & Butt, M. U. (2021). Estimation and forecasting of rice yield using phenology-based algorithm and linear regression model on sentinel-ii satellite data. Agriculture, 11(10), 1026. 30 | 31 | ## 📝 Citing 32 | 33 | ``` 34 | @misc{UrgellReberga:2023, 35 | Author = {Baptiste Urgell and Louis Reberga}, 36 | Title = {Crop forecasting}, 37 | Year = {2023}, 38 | Publisher = {GitHub}, 39 | Journal = {GitHub repository}, 40 | Howpublished = {\url{https://github.com/association-rosia/crop-forecasting}} 41 | } 42 | ``` 43 | 44 | ## 🛡️ License 45 | 46 | The project is distributed under the [MIT License](https://github.com/association-rosia/crop-forecasting/blob/main/LICENSE) 47 | 48 | ## 👨🏻‍💻 Contributors 49 | 50 | Louis 51 | REBERGA 52 | 53 | Baptiste 54 | URGELL 55 | -------------------------------------------------------------------------------- /docs/source/install.rst: -------------------------------------------------------------------------------- 1 | ================== 2 | Installation Guide 3 | ================== 4 | 5 | Conda environment 6 | ================== 7 | 8 | This guide provides step-by-step instructions for installing the Conda environment. This allows you to recreate the environment to run our project with all its dependencies in a consistent and reproducible manner. 9 | 10 | Prerequisites 11 | ------------- 12 | 13 | Before installing the Conda environment, make sure you have the following prerequisites: 14 | 15 | - Conda: If you don't have Conda installed already, you can download and install it from the official `Miniconda `_ or `Anaconda `_ website. 16 | 17 | Installation Steps 18 | ------------------ 19 | 20 | Follow these steps to install the `crop-forecasting-env` Conda environment: 21 | 22 | - Open a terminal or command prompt. 23 | 24 | - Navigate to the main project directory named `crop-forecasting`. 25 | 26 | - Run the following command to create a Conda environment from the `environment.yml` file: 27 | 28 | .. code-block:: bash 29 | 30 | conda env create -f environment.yml 31 | 32 | - Activate the Conda environment: 33 | 34 | .. 
code-block:: bash 35 | 36 | conda activate crop-forecasting-env 37 | 38 | 39 | 40 | Weights & Biases Account 41 | ======================== 42 | 43 | Overview 44 | -------- 45 | 46 | Weights & Biases (wandb) is a powerful platform for tracking, visualizing, and analyzing machine learning experiments. To get started with wandb, you'll need to create an account on the wandb website. Here are the steps to create a wandb account: 47 | 48 | Go to the WandB website 49 | ----------------------- 50 | 51 | Open a web browser and navigate to the `wandb `_ website. 52 | 53 | Click on "Get Started for Free" 54 | ------------------------------- 55 | 56 | Click on the "Get Started for Free" button on the wandb homepage to begin the account creation process. 57 | 58 | Sign Up with a Provider or Email 59 | -------------------------------- 60 | 61 | You can sign up for a wandb account using your existing Google, GitHub, or email account. Choose the option that suits you best and follow the prompts to sign up. 62 | 63 | Complete the Sign Up Process 64 | ---------------------------- 65 | 66 | If you sign up with a provider like Google or GitHub, you'll be prompted to authorize wandb to access your account information. If you sign up with an email, you'll need to provide some additional information like your name and password. 67 | 68 | Verify Your Email (if applicable) 69 | --------------------------------- 70 | 71 | If you sign up with an email, you may need to verify your email address by clicking on a verification link sent to your email inbox. Follow the instructions in the email to complete the verification process. 72 | 73 | 74 | Congratulations! You can now start using our project. 75 | ====================================================== 76 | See :doc:`tutorial` for further information about it. 
77 | ===================================================== -------------------------------------------------------------------------------- /notebooks/data/concat_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "ca8ce2e2-68c9-4644-ab3b-7d120db8801d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import xarray as xr\n", 11 | "import numpy as np\n", 12 | "from os.path import join" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 7, 18 | "id": "c2bf647e-08a4-40f1-8555-ae863046038b", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "ROOT_DIR = '../../'\n", 23 | "\n", 24 | "FOLDER = 'augment_50_5'\n", 25 | "xdf_train_50 = xr.open_dataset(join(ROOT_DIR, 'data', 'processed', FOLDER, 'train_enriched.nc'), engine='scipy')\n", 26 | "xdf_test_50 = xr.open_dataset(join(ROOT_DIR, 'data', 'processed', FOLDER, 'test_enriched.nc'), engine='scipy')\n", 27 | "\n", 28 | "FOLDER = 'augment_40_5'\n", 29 | "xdf_train_40 = xr.open_dataset(join(ROOT_DIR, 'data', 'processed', FOLDER, 'train_enriched.nc'), engine='scipy')\n", 30 | "xdf_test_40 = xr.open_dataset(join(ROOT_DIR, 'data', 'processed', FOLDER, 'test_enriched.nc'), engine='scipy')\n", 31 | "xdf_train_40['ts_aug'] = np.arange(50, 90)\n", 32 | "xdf_test_40['ts_aug'] = np.arange(50, 90)\n", 33 | "\n", 34 | "FOLDER = 'augment_10_5'\n", 35 | "xdf_train_10 = xr.open_dataset(join(ROOT_DIR, 'data', 'processed', FOLDER, 'train_enriched.nc'), engine='scipy')\n", 36 | "xdf_test_10 = xr.open_dataset(join(ROOT_DIR, 'data', 'processed', FOLDER, 'test_enriched.nc'), engine='scipy')\n", 37 | "xdf_train_10['ts_aug'] = np.arange(90, 100)\n", 38 | "xdf_test_10['ts_aug'] = np.arange(90, 100)\n", 39 | "\n", 40 | "FOLDER = 'augment_100_5'\n", 41 | "xdf_train_100 = xr.merge([xdf_train_50, xdf_train_40, xdf_train_10], compat='no_conflicts')\n", 42 | "xdf_test_100 = xr.merge([xdf_test_50, xdf_test_40, xdf_test_10], compat='no_conflicts')\n", 43 | "\n", 44 | "id_shape = xdf_train_100['ts_id'].values.shape\n", 45 | "xdf_train_100['ts_id'] = (xdf_train_100['ts_id'].dims, np.reshape(np.arange(np.prod(id_shape)), id_shape))\n", 46 | "\n", 47 | "id_shape = xdf_test_100['ts_id'].values.shape\n", 48 | "xdf_test_100['ts_id'] = (xdf_test_100['ts_id'].dims, np.reshape(np.arange(np.prod(id_shape)), id_shape))\n", 49 | "\n", 50 | "xdf_train_100.to_netcdf(join(ROOT_DIR, 'data', 'processed', FOLDER, 'train_enriched.nc'), engine='scipy')\n", 51 | "xdf_test_100.to_netcdf(join(ROOT_DIR, 'data', 'processed', FOLDER, 'test_enriched.nc'), engine='scipy')" 52 | ] 53 | } 54 | ], 55 | "metadata": { 56 | "kernelspec": { 57 | "display_name": "ey-2023", 58 | "language": "python", 59 | "name": "python3" 60 | }, 61 | "language_info": { 62 | "codemirror_mode": { 63 | "name": "ipython", 64 | "version": 3 65 | }, 66 | "file_extension": ".py", 67 | "mimetype": "text/x-python", 68 | "name": "python", 69 | "nbconvert_exporter": "python", 70 | "pygments_lexer": "ipython3", 71 | "version": "3.9.16" 72 | } 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 5 76 | } 77 | -------------------------------------------------------------------------------- /docs/build/html/_sources/install.rst.txt: -------------------------------------------------------------------------------- 1 | ================== 2 | Installation Guide 3 | ================== 4 | 5 | Conda environement 6 | ================== 7 | 8 | This guide provides step-by-step 
instructions for installing the Conda environment. This allows you to recreate the environment to run our project with all its dependencies in a consistent and reproducible manner. 9 | 10 | Prerequisites 11 | ------------- 12 | 13 | Before installing the Conda environment, make sure you have the following prerequisites: 14 | 15 | - Conda: If you don't have Conda installed already, you can download and install it from the official `Miniconda `_ or `Anaconda `_ website. 16 | 17 | Installation Steps 18 | ------------------ 19 | 20 | Follow these steps to install the `crop-forecasting-env` Conda environment: 21 | 22 | - Open a terminal or command prompt. 23 | 24 | - Navigate to the main project directory named `crop-forecasting`. 25 | 26 | - Run the following command to create a Conda environment from the `environment.yml` file: 27 | 28 | .. code-block:: bash 29 | 30 | conda env create -f environment.yml 31 | 32 | - Activate the Conda environment: 33 | 34 | .. code-block:: bash 35 | 36 | conda activate crop-forecasting-env 37 | 38 | 39 | 40 | Weights & Biases Account 41 | ======================== 42 | 43 | Overview 44 | -------- 45 | 46 | Weights & Biases (wandb) is a powerful platform for tracking, visualizing, and analyzing machine learning experiments. To get started with wandb, you'll need to create an account on the wandb website. Here are the steps to create a wandb account: 47 | 48 | Go to the WandB website 49 | ----------------------- 50 | 51 | Open a web browser and navigate to the `wandb `_ website. 52 | 53 | Click on "Get Started for Free" 54 | ------------------------------- 55 | 56 | Click on the "Get Started for Free" button on the wandb homepage to begin the account creation process. 57 | 58 | Sign Up with a Provider or Email 59 | -------------------------------- 60 | 61 | You can sign up for a wandb account using your existing Google, GitHub, or email account. Choose the option that suits you best and follow the prompts to sign up. 62 | 63 | Complete the Sign Up Process 64 | ---------------------------- 65 | 66 | If you sign up with a provider like Google or GitHub, you'll be prompted to authorize wandb to access your account information. If you sign up with an email, you'll need to provide some additional information like your name and password. 67 | 68 | Verify Your Email (if applicable) 69 | --------------------------------- 70 | 71 | If you sign up with an email, you may need to verify your email address by clicking on a verification link sent to your email inbox. Follow the instructions in the email to complete the verification process. 72 | 73 | 74 | Congratulations ! You can now start using our project. 75 | ====================================================== 76 | See :doc:`tutorial` for futher informations about it. 77 | ===================================================== -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | =================================== 2 | 2023 EY Open Science Data Challenge 3 | =================================== 4 | 5 | Welcome to the Crop forecasting project ! 6 | ========================================= 7 | 8 | Overview 9 | -------- 10 | 11 | Welcome to our Climate Change and Crop Yield Prediction project! We are a team of two junior Data Scientists, Baptiste Urgell and Louis Reberga, recently graduated from `CY-Tech `_ and currently working at `Aqsone `_. 
In this project, we aim to build models that will help scientists better understand the impact of climate change on crop yields, with a focus on rice cultivation. 12 | 13 | Challenge Description 14 | --------------------- 15 | 16 | Climate change, along with man-made conflicts and economic downturns, is one of the main causes of acute hunger. By leveraging data from Microsoft's Planetary Computer, including Sentinel-1 (radar), Sentinel-2 (optical), and Landsat (optical) satellite data, we will develop predictive models to forecast crop yields in areas of rice cultivation. Our goal is to contribute to the scientific community's understanding of the effects of climate change on crop production and support efforts to mitigate its impact on food security. 17 | 18 | About Us 19 | -------- 20 | 21 | - Baptiste Urgell: I am a Data Scientist with a background in statistics and machine learning. I have a strong interest in applying data-driven solutions to real-world challenges, particularly in the field of environmental sustainability. I am excited to contribute to this project and apply my skills to better understand the impacts of climate change on crop yields. 22 | 23 | - Louis Reberga: I am a Data Scientist with a passion for leveraging data and analytics to generate insights and solve complex problems. My expertise lies in machine learning, data visualization, and statistical analysis. I am eager to work on this project and contribute to the development of models that can provide valuable insights into the effects of climate change on rice cultivation. 24 | 25 | Project Goals 26 | ------------- 27 | 28 | Our main objectives for this project are: 29 | 30 | - Build predictive models using machine/deep learning techniques to forecast crop yields for rice cultivation areas. 31 | 32 | - Utilize satellite data from Microsoft's Planetary Computer (Sentinel-2), meteorological data, and paddy information from EY to train and validate our models. 33 | 34 | - Analyze the impact of climate change on crop yields and provide insights to the scientific community to support efforts in addressing food security challenges. 35 | 36 | Conclusion 37 | ---------- 38 | 39 | We are excited to embark on this project and contribute to the understanding of the impacts of climate change on crop yields. By leveraging our skills in data science and machine/deep learning, along with the rich dataset provided by Microsoft's Planetary Computer, we aim to develop models that can provide valuable insights for scientists and stakeholders working on addressing the challenges of climate change and food security. 40 | 41 | 42 | 43 | .. toctree:: 44 | :hidden: 45 | :caption: Navigation 46 | :titlesonly: 47 | 48 | install 49 | tutorial 50 | documentation 51 | -------------------------------------------------------------------------------- /docs/build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | =================================== 2 | 2023 EY Open Science Data Challenge 3 | =================================== 4 | 5 | Welcome to the Crop forecasting project ! 6 | ========================================= 7 | 8 | Overview 9 | -------- 10 | 11 | Welcome to our Climate Change and Crop Yield Prediction project! We are a team of two junior Data Scientists, Baptiste Urgell and Louis Reberga, recently graduated from `CY-Tech `_ and currently working at `Aqsone `_.
In this project, we aim to build models that will help scientists better understand the impact of climate change on crop yields, with a focus on rice cultivation. 12 | 13 | Challenge Description 14 | --------------------- 15 | 16 | Climate change, along with man-made conflicts and economic downturns, is one of the main causes of acute hunger. By leveraging data from Microsoft's Planetary Computer, including Sentinel-1 (radar), Sentinel-2 (optical), and Landsat (optical) satellite data, we will develop predictive models to forecast crop yields in areas of rice cultivation. Our goal is to contribute to the scientific community's understanding of the effects of climate change on crop production and support efforts to mitigate its impact on food security. 17 | 18 | About Us 19 | -------- 20 | 21 | - Baptiste Urgell: I am a Data Scientist with a background in statistics and machine learning. I have a strong interest in applying data-driven solutions to real-world challenges, particularly in the field of environmental sustainability. I am excited to contribute to this project and apply my skills to better understand the impacts of climate change on crop yields. 22 | 23 | - Louis Reberga: I am a Data Scientist with a passion for leveraging data and analytics to generate insights and solve complex problems. My expertise lies in machine learning, data visualization, and statistical analysis. I am eager to work on this project and contribute to the development of models that can provide valuable insights into the effects of climate change on rice cultivation. 24 | 25 | Project Goals 26 | ------------- 27 | 28 | Our main objectives for this project are: 29 | 30 | - Build predictive models using machine/deep learning techniques to forecast crop yields for rice cultivation areas. 31 | 32 | - Utilize satellite data from Microsoft's Planetary Computer (Sentinel-2), Meteorological data and Paddies informations from EY, to train and validate our models. 33 | 34 | - Analyze the impact of climate change on crop yields and provide insights to the scientific community to support efforts in addressing food security challenges. 35 | 36 | Conclusion 37 | ---------- 38 | 39 | We are excited to embark on this project and contribute to the understanding of the impacts of climate change on crop yields. By leveraging our skills in data science and machine/deep learning, along with the rich dataset provided by Microsoft's Planetary Computer, we aim to develop models that can provide valuable insights for scientists and stakeholders working on addressing the challenges of climate change and food security. 40 | 41 | 42 | 43 | .. 
toctree:: 44 | :hidden: 45 | :caption: Navigation 46 | :titlesonly: 47 | 48 | install 49 | tutorial 50 | documentation 51 | -------------------------------------------------------------------------------- /notebooks/data/enrich_xarray.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "2faca195-623f-46f3-bb78-db480ad4e7bd", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import xarray as xr\n", 11 | "import pandas as pd\n", 12 | "from os.path import join\n", 13 | "from sklearn.preprocessing import MinMaxScaler" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "b4df6089-0d34-4a8c-a1b6-caa2596527da", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "ROOT_DIR = '/home/jovyan/crop-forecasting'\n", 24 | "FOLDER = 'augment_50_5'" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "id": "88babe3e-04d2-47a7-ae88-5466652c0f79", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "xdf_train = xr.open_dataset(join(ROOT_DIR, 'data', 'processed', FOLDER, 'train.nc'), engine='scipy')\n", 35 | "df_train = pd.read_csv(join(ROOT_DIR, 'data', 'interim', 'train_enriched.csv'))\n", 36 | "df_train.set_index(['Unnamed: 0'], inplace=True)\n", 37 | "dictio = {'Unnamed: 0': xdf_train['ts_obs']}\n", 38 | "\n", 39 | "scaler = MinMaxScaler()\n", 40 | "scaler.fit(df_train[['Other Rice Yield (kg/ha)']])\n", 41 | "df_train['Other Rice Yield (kg/ha)'] = scaler.transform(df_train[['Other Rice Yield (kg/ha)']])\n", 42 | "\n", 43 | "xdf_train = xr.merge([xdf_train, df_train[['Other Rice Yield (kg/ha)']].to_xarray().sel(**dictio)], compat='override')\n", 44 | "xdf_train = xdf_train.reset_coords('Unnamed: 0', drop=True)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 4, 50 | "id": "9f47da15-436e-46d2-b54b-65ec0960f788", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "xdf_test = xr.open_dataset(join(ROOT_DIR, 'data', 'processed', FOLDER, 'test.nc'), engine='scipy')\n", 55 | "df_test = pd.read_csv(join(ROOT_DIR, 'data', 'interim', 'test_enriched.csv'))\n", 56 | "df_test.set_index(['Unnamed: 0'], inplace=True)\n", 57 | "dictio = {'Unnamed: 0': xdf_test['ts_obs']}\n", 58 | "\n", 59 | "scaler.fit(df_test[['Other Rice Yield (kg/ha)']])\n", 60 | "df_test['Other Rice Yield (kg/ha)'] = scaler.transform(df_test[['Other Rice Yield (kg/ha)']])\n", 61 | "\n", 62 | "xdf_test = xr.merge([xdf_test, df_test[['Other Rice Yield (kg/ha)']].to_xarray().sel(**dictio)], compat='override')\n", 63 | "xdf_test = xdf_test.reset_coords('Unnamed: 0', drop=True)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 5, 69 | "id": "4d5c99d3-5984-4cbb-9ae0-b4853d3e808d", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "xdf_train.to_netcdf(join(ROOT_DIR, 'data', 'processed', FOLDER, 'train_enriched.nc'), engine='scipy')\n", 74 | "xdf_test.to_netcdf(join(ROOT_DIR, 'data', 'processed', FOLDER, 'test_enriched.nc'), engine='scipy')" 75 | ] 76 | } 77 | ], 78 | "metadata": { 79 | "kernelspec": { 80 | "display_name": "Python [conda env:notebook] *", 81 | "language": "python", 82 | "name": "conda-env-notebook-py" 83 | }, 84 | "language_info": { 85 | "codemirror_mode": { 86 | "name": "ipython", 87 | "version": 3 88 | }, 89 | "file_extension": ".py", 90 | "mimetype": "text/x-python", 91 | "name": "python", 92 | "nbconvert_exporter": "python", 93 | "pygments_lexer": 
"ipython3", 94 | "version": "3.9.13" 95 | } 96 | }, 97 | "nbformat": 4, 98 | "nbformat_minor": 5 99 | } 100 | -------------------------------------------------------------------------------- /notebooks/data/enrich_csv.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1fd5b4f7-f249-44e7-9ba0-d3b9c9cb32fe", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd\n", 11 | "from os.path import join" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "6c4f11fb-3a03-4a59-9531-7ffc875238be", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "ROOT_DIR = '/home/jovyan/crop-forecasting'\n", 22 | "\n", 23 | "train_df = pd.read_csv(join(ROOT_DIR, 'data', 'raw', 'train.csv'))\n", 24 | "test_df = pd.read_csv(join(ROOT_DIR, 'data', 'raw', 'test.csv'))" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "id": "5c1f22e3-9fbb-4fb7-8a96-a1aa47c847a0", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "crop_yields = train_df['Rice Yield (kg/ha)'].unique().tolist()\n", 35 | "\n", 36 | "def rounded_yield(x, crop_yields):\n", 37 | " diffs = [abs(x - crop_yield) for crop_yield in crop_yields]\n", 38 | " return crop_yields[diffs.index(min(diffs))]" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 4, 44 | "id": "323c1a93-9c0b-41ac-8c70-09eb71d6f165", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "def get_other_yield(row):\n", 49 | " Season = 'Season(SA = Summer Autumn, WS = Winter Spring)'\n", 50 | " other_yield = train_df[(train_df['District'] == row['District']) & \n", 51 | " (train_df['Latitude'] == row['Latitude']) & \n", 52 | " (train_df['Longitude'] == row['Longitude']) & \n", 53 | " (train_df[Season] != row[Season])]['Rice Yield (kg/ha)'].tolist()\n", 54 | "\n", 55 | " if len(other_yield) > 0:\n", 56 | " res = other_yield[0]\n", 57 | " else:\n", 58 | " res = train_df[(train_df['District'] == row['District']) & \n", 59 | " (train_df[Season] != row[Season])]['Rice Yield (kg/ha)'].mean()\n", 60 | " \n", 61 | " return res" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 5, 67 | "id": "05cd7053-6192-41f4-9cd4-b7eb1107f81d", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "train_df['Other Rice Yield (kg/ha)'] = train_df.apply(get_other_yield, axis='columns')\n", 72 | "train_df['Other Rice Yield (kg/ha)'] = train_df['Other Rice Yield (kg/ha)'].apply(lambda x: rounded_yield(x, crop_yields)).astype('int64')\n", 73 | "\n", 74 | "train_df.to_csv(join(ROOT_DIR, 'data', 'interim', 'train_enriched.csv'))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 6, 80 | "id": "9b7a5609-6bb0-4695-adbf-0ca83579c18c", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "test_df['Other Rice Yield (kg/ha)'] = test_df.apply(get_other_yield, axis='columns')\n", 85 | "test_df['Other Rice Yield (kg/ha)'] = test_df['Other Rice Yield (kg/ha)'].apply(lambda x: rounded_yield(x, crop_yields)).astype('int64')\n", 86 | "\n", 87 | "test_df.to_csv(join(ROOT_DIR, 'data', 'interim', 'test_enriched.csv'))" 88 | ] 89 | } 90 | ], 91 | "metadata": { 92 | "kernelspec": { 93 | "display_name": "Python 3 (ipykernel)", 94 | "language": "python", 95 | "name": "python3" 96 | }, 97 | "language_info": { 98 | "codemirror_mode": { 99 | "name": "ipython", 100 | "version": 3 101 | }, 102 | "file_extension": ".py", 103 
| "mimetype": "text/x-python", 104 | "name": "python", 105 | "nbconvert_exporter": "python", 106 | "pygments_lexer": "ipython3", 107 | "version": "3.9.13" 108 | } 109 | }, 110 | "nbformat": 4, 111 | "nbformat_minor": 5 112 | } 113 | -------------------------------------------------------------------------------- /src/models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LSTMModel(nn.Module): 6 | def __init__(self, config, device): 7 | super(LSTMModel, self).__init__() 8 | self.s_hidden_size = config['s_hidden_size'] 9 | self.m_hidden_size = config['m_hidden_size'] 10 | self.s_num_features = config['s_num_features'] 11 | 12 | self.s_num_layers = config['s_num_layers'] 13 | self.m_num_layers = config['m_num_layers'] 14 | self.m_num_features = config['m_num_features'] 15 | 16 | self.g_in_features = config['g_in_features'] 17 | 18 | self.c_in_features = config['c_in_features'] 19 | self.c_out_in_features_1 = config['c_out_in_features_1'] 20 | self.c_out_in_features_2 = config['c_out_in_features_2'] 21 | 22 | self.dropout = config['dropout'] 23 | 24 | self.device = device 25 | 26 | self.s_lstm = nn.LSTM(self.s_num_features, self.s_hidden_size, self.s_num_layers, batch_first=True) 27 | self.s_bn_lstm = nn.BatchNorm1d(self.s_hidden_size) 28 | self.s_cnn = nn.Conv1d(1, 1, kernel_size=3) 29 | self.s_bn_cnn = nn.BatchNorm1d(self.s_hidden_size - 2) # because kernel_size = 3 30 | 31 | self.m_lstm = nn.LSTM(self.m_num_features, self.m_hidden_size, self.m_num_layers, batch_first=True) 32 | self.m_bn_lstm = nn.BatchNorm1d(self.m_hidden_size) 33 | self.m_cnn = nn.Conv1d(1, 1, kernel_size=3) 34 | self.m_bn_cnn = nn.BatchNorm1d(self.m_hidden_size - 2) 35 | 36 | self.c_linear_1 = nn.Linear(self.c_in_features, self.c_out_in_features_1) 37 | self.c_bn_1 = nn.BatchNorm1d(self.c_out_in_features_1) 38 | self.c_linear_2 = nn.Linear(self.c_out_in_features_1, self.c_out_in_features_2) 39 | self.c_bn_2 = nn.BatchNorm1d(self.c_out_in_features_2) 40 | self.c_linear_3 = nn.Linear(self.c_out_in_features_2, 1) 41 | 42 | self.tanh = nn.Tanh() 43 | self.relu = nn.ReLU() 44 | self.dropout = nn.Dropout(self.dropout) 45 | 46 | def forward(self, x): 47 | s_input = x['s_input'] 48 | m_input = x['m_input'] 49 | g_input = x['g_input'] 50 | 51 | # Spectral LSTM 52 | s_h0 = torch.zeros(self.s_num_layers, s_input.size(0), self.s_hidden_size).requires_grad_().to(self.device) 53 | s_c0 = torch.zeros(self.s_num_layers, s_input.size(0), self.s_hidden_size).requires_grad_().to(self.device) 54 | s_output, _ = self.s_lstm(s_input, (s_h0, s_c0)) 55 | s_output = self.s_bn_lstm(s_output[:, -1, :]) 56 | s_output = self.tanh(s_output) 57 | s_output = self.dropout(s_output) 58 | 59 | # Spectral Conv1D 60 | s_output = torch.unsqueeze(s_output, 1) 61 | s_output = self.s_cnn(s_output) 62 | s_output = self.s_bn_cnn(torch.squeeze(s_output)) 63 | s_output = self.relu(s_output) 64 | s_output = self.dropout(s_output) 65 | 66 | # Meteorological LSTM 67 | m_h0 = torch.zeros(self.m_num_layers, m_input.size(0), self.m_hidden_size).requires_grad_().to(self.device) 68 | m_c0 = torch.zeros(self.m_num_layers, m_input.size(0), self.m_hidden_size).requires_grad_().to(self.device) 69 | m_output, _ = self.m_lstm(m_input, (m_h0, m_c0)) 70 | m_output = self.m_bn_lstm(m_output[:, -1, :]) 71 | m_output = self.tanh(m_output) 72 | m_output = self.dropout(m_output) 73 | 74 | # Meteorological Conv1D 75 | m_output = torch.unsqueeze(m_output, 1) 76 | m_output = 
self.m_cnn(m_output) 77 | m_output = self.m_bn_cnn(torch.squeeze(m_output)) 78 | m_output = self.relu(m_output) 79 | m_output = self.dropout(m_output) 80 | 81 | # Concatenate inputs 82 | c_input = torch.cat((s_output, m_output, g_input), 1) 83 | c_output = self.c_bn_1(self.c_linear_1(c_input)) 84 | c_output = self.relu(c_output) 85 | c_output = self.dropout(c_output) 86 | c_output = self.c_bn_2(self.c_linear_2(c_output)) 87 | c_output = self.relu(c_output) 88 | c_output = self.dropout(c_output) 89 | output = self.c_linear_3(c_output) 90 | 91 | return output 92 |
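For reference, a minimal, self-contained sketch of how `LSTMModel` can be instantiated and called outside the training script. The hyperparameter values are illustrative only; in practice they come from the W&B sweep configuration, and the project root is assumed to be on the Python path, as the training scripts arrange via `sys.path`. The one constraint, taken from `make_train.py`, is that `c_in_features` must equal `(s_hidden_size - 2) + (m_hidden_size - 2) + g_in_features`.

```python
import torch

from src.models.model import LSTMModel

# Illustrative hyperparameters only; the real values are set by the W&B sweep.
config = {
    's_hidden_size': 16, 'm_hidden_size': 16,
    's_num_features': 10, 'm_num_features': 5,   # spectral / meteorological features per time step
    's_num_layers': 1, 'm_num_layers': 1,
    'g_in_features': 3,                          # static (field-level) features
    'c_in_features': (16 - 2) + (16 - 2) + 3,    # must satisfy the constraint above
    'c_out_in_features_1': 8, 'c_out_in_features_2': 4,
    'dropout': 0.1,
}

device = 'cpu'
model = LSTMModel(config, device).to(device)
model.eval()  # use BatchNorm running statistics and disable dropout

batch_size, seq_len = 4, 24  # arbitrary sizes for the sketch
x = {
    's_input': torch.randn(batch_size, seq_len, config['s_num_features']),
    'm_input': torch.randn(batch_size, seq_len, config['m_num_features']),
    'g_input': torch.randn(batch_size, config['g_in_features']),
}

with torch.no_grad():
    y = model(x)

print(y.shape)  # torch.Size([4, 1]): one (scaled) yield prediction per sample
```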
-------------------------------------------------------------------------------- /src/models/ml/make_train.py: -------------------------------------------------------------------------------- 1 | # Native library 2 | import os, sys 3 | 4 | path = os.path.join(".") 5 | sys.path.insert(1, path) 6 | 7 | from utils import ROOT_DIR 8 | 9 | # Data management 10 | import numpy as np 11 | import xarray as xr 12 | 13 | from src.constants import TARGET, FOLDER, S_COLUMNS 14 | 15 | from tqdm import tqdm 16 | 17 | # Data preprocessing 18 | from src.data.preprocessing import Smoother, Convertor, Filler, Sorter 19 | from sklearn.preprocessing import StandardScaler 20 | from sklearn.decomposition import PCA 21 | 22 | # Hyperparameter Optimization 23 | import wandb 24 | 25 | # Regressor models 26 | from xgboost import XGBRegressor 27 | 28 | # Training 29 | from sklearn.model_selection import KFold 30 | from sklearn.pipeline import Pipeline 31 | 32 | 33 | def main(): 34 | data_path = os.path.join(ROOT_DIR, "data", "interim", FOLDER, "train.nc") 35 | xds = xr.open_dataset(data_path, engine="scipy") 36 | 37 | obs_idx = xds["ts_obs"].values 38 | obs_idx = obs_idx.reshape(-1, 1) 39 | 40 | wandb.init( 41 | project="winged-bull", 42 | group="Machine Learning", 43 | ) 44 | 45 | pipeline = init_pipeline() 46 | 47 | val_R2_score = 0 48 | n_splits = wandb.config.n_splits 49 | 50 | p_bar = tqdm(KFold(n_splits=n_splits).split(obs_idx), total=n_splits, leave=False) 51 | for i, (index_train, index_test) in enumerate(p_bar): 52 | xds_train = xds.sel(ts_obs=obs_idx[index_train].reshape(-1)) 53 | xds_test = xds.sel(ts_obs=obs_idx[index_test].reshape(-1)) 54 | 55 | y_train = preprocess_y(xds_train) 56 | y_test = preprocess_y(xds_test) 57 | 58 | pipeline.fit(X=xds_train, y=y_train) 59 | val_split_R2_score = pipeline.score(X=xds_test, y=y_test) 60 | val_R2_score += val_split_R2_score 61 | p_bar.write(f"Split {i + 1}/{n_splits}: R2 score = {val_split_R2_score:.5f}") 62 | wandb.log({"val_split_r2_score": val_split_R2_score}) 63 | 64 | print(f"Mean R2 score = {(val_R2_score / n_splits):.5f}") 65 | wandb.log({"val_r2_score": val_R2_score / n_splits}) 66 | 67 | 68 | def init_pipeline() -> Pipeline: 69 | """Initialise scikit-learn Pipeline with the wandb configuration. 70 | 71 | :return: Configured Pipeline. 72 | :rtype: Pipeline 73 | """ 74 | params_pipeline = { 75 | "smoother__mode": None if isinstance(wandb.config.vi, bool) else wandb.config.vi, 76 | "convertor__agg": wandb.config.dim_reduction == "Aggregate", 77 | "convertor__weather": wandb.config.weather, 78 | "convertor__vi": isinstance(wandb.config.vi, str) or wandb.config.vi, 79 | "estimator__n_estimators": wandb.config.n_estimators, 80 | "estimator__colsample_bytree": wandb.config.colsample_bytree, 81 | "estimator__colsample_bylevel": wandb.config.colsample_bylevel, 82 | "estimator__colsample_bynode": wandb.config.colsample_bynode, 83 | "estimator__subsample": wandb.config.subsample, 84 | "estimator__max_depth": wandb.config.max_depth, 85 | "estimator__learning_rate": wandb.config.learning_rate, 86 | } 87 | 88 | steps_pipeline = [ 89 | ("filler", Filler()), 90 | ("smoother", Smoother()), 91 | ("convertor", Convertor()), 92 | ("sorter", Sorter()), 93 | ] 94 | 95 | if wandb.config.dim_reduction == "PCA": 96 | steps_pipeline.append(("scaler", StandardScaler())) 97 | steps_pipeline.append(("dim_reductor", PCA(n_components="mle"))) 98 | 99 | steps_pipeline.append(("estimator", XGBRegressor())) 100 | 101 | pipeline = Pipeline(steps_pipeline) 102 | pipeline.set_params(**params_pipeline) 103 | 104 | return pipeline 105 | 106 | 107 | def preprocess_y(xds: xr.Dataset) -> np.ndarray: 108 | """Preprocess target to match processed samples. 109 | 110 | :param xds: Dataset containing the target. 111 | :type xds: xr.Dataset 112 | :return: Target processed. 113 | :rtype: np.ndarray 114 | """ 115 | df = xds[[TARGET] + S_COLUMNS].to_dataframe() 116 | y = df[[TARGET]].groupby(["ts_obs", "ts_aug"]).first() 117 | return y.reorder_levels(["ts_obs", "ts_aug"]).sort_index().to_numpy() 118 | 119 | 120 | if __name__ == "__main__": 121 | main() 122 |
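To make the target handling in `preprocess_y` concrete, here is a small, hypothetical illustration of the same groupby idea on a plain DataFrame. The values are made up and `TARGET` is assumed to be the `'Rice Yield (kg/ha)'` column; in the real function the DataFrame comes from the xarray Dataset, where the target is repeated along the time dimension, and `.first()` keeps a single value per `(ts_obs, ts_aug)` pair, matching the samples produced by the preprocessing steps.

```python
import pandas as pd

# Hypothetical mini-dataset: 2 observations x 2 augmentations x 2 time steps each.
df = pd.DataFrame({
    'ts_obs': [0, 0, 0, 0, 1, 1, 1, 1],
    'ts_aug': [0, 0, 1, 1, 0, 0, 1, 1],
    'Rice Yield (kg/ha)': [6000, 6000, 6000, 6000, 7200, 7200, 7200, 7200],
})

# Keep one target value per (ts_obs, ts_aug) pair, in index order.
y = (
    df.groupby(['ts_obs', 'ts_aug'])[['Rice Yield (kg/ha)']]
    .first()
    .sort_index()
    .to_numpy()
)

print(y.ravel())  # [6000 6000 7200 7200] -> one target per augmented time series
```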
-------------------------------------------------------------------------------- /src/models/make_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import pandas as pd 5 | import torch 6 | import torch.nn as nn 7 | from sklearn.preprocessing import MinMaxScaler 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | parent = os.path.abspath('.') 12 | sys.path.insert(1, parent) 13 | 14 | from os.path import join 15 | 16 | from src.models.dataloader import get_dataloaders 17 | 18 | from src.constants import TARGET, TARGET_TEST 19 | from utils import ROOT_DIR 20 | 21 | 22 | MODEL = 'toasty-sky-343.pt' 23 | 24 | 25 | def main() -> None: 26 | # get the device 27 | device = get_device() 28 | 29 | # get the test dataloader 30 | _, _, test_dataloader = get_dataloaders(batch_size=64, val_rate=0.2, device=device) 31 | 32 | # create the evaluator 33 | evaluator = Evaluator(test_dataloader, device) 34 | 35 | # load the model 36 | model_path = join(ROOT_DIR, 'models', MODEL) 37 | model = torch.load(model_path).to(device) 38 | 39 | # evaluate the model on the test set 40 | evaluator.evaluate(model) 41 | 42 | 43 | def rounded_yield(x: float, crop_yields: list) -> float: 44 | """ Round a prediction to the nearest labelled crop yield value. 45 | 46 | :param x: Current prediction 47 | :type x: float 48 | :param crop_yields: Labelled crop yield values 49 | :type crop_yields: list 50 | :return: Rounded prediction 51 | :rtype: float 52 | """ 53 | diffs = [abs(x - crop_yield) for crop_yield in crop_yields] 54 | return crop_yields[diffs.index(min(diffs))] 55 | 56 | 57 | def get_device() -> str: 58 | """ Get the compute device: 'cuda' or 'mps' if available, otherwise 'cpu'. 59 | 60 | :return: Device name 61 | :rtype: str 62 | """ 63 | if torch.cuda.is_available(): 64 | device = 'cuda' 65 | elif torch.backends.mps.is_available(): 66 | device = 'mps' 67 | else: 68 | device = 'cpu' 69 | 70 | return device 71 | 72 | 73 | def create_submission(observations: list, preds: list) -> None: 74 | """ Create the submission file using the predictions. 75 | 76 | :param observations: Observation indexes 77 | :type observations: list 78 | :param preds: Associated predictions 79 | :type preds: list 80 | """ 81 | df = pd.DataFrame() 82 | df['observations'] = observations 83 | df['preds'] = preds 84 | df = df.groupby(['observations']).mean() 85 | df = df.sort_values(by='observations') 86 | 87 | test_path = join(ROOT_DIR, 'data', 'raw', 'test.csv') 88 | test_df = pd.read_csv(test_path) 89 | 90 | # scale the data using MinMaxScaler 91 | scaler = MinMaxScaler() 92 | train_path = join(ROOT_DIR, 'data', 'raw', 'train.csv') 93 | train_df = pd.read_csv(train_path) 94 | scaler.fit(train_df[[TARGET]]) 95 | 96 | # transform back the predictions 97 | test_df[TARGET_TEST] = scaler.inverse_transform(df[['preds']]) 98 | 99 | crop_yields = train_df[TARGET].unique().tolist() 100 | test_df[TARGET_TEST] = test_df[TARGET_TEST].apply(lambda x: rounded_yield(x, crop_yields)) 101 | os.makedirs('submissions', exist_ok=True) 102 | test_df.to_csv(f'submissions/{MODEL.split(".")[0]}.csv', index=False) 103 | 104 | 105 | class Evaluator: 106 | """ Evaluate model performance on test set. 107 | 108 | :param test_dataloader: Test dataloader 109 | :type test_dataloader: DataLoader 110 | :param device: Compute device 111 | :type device: str 112 | """ 113 | def __init__(self, test_dataloader: DataLoader, device: str): 114 | self.test_dataloader = test_dataloader 115 | self.device = device 116 | 117 | def evaluate(self, model: nn.Module): 118 | """ Evaluate model performance on test set using the rounded predictions.
119 | 120 | :param model: Our PyTorch model 121 | :type model: nn.Module 122 | :return: 123 | """ 124 | observations = [] 125 | test_preds = [] 126 | 127 | pbar = tqdm(self.test_dataloader, leave=False) 128 | for i, data in enumerate(pbar): 129 | keys_input = ['s_input', 'm_input', 'g_input'] 130 | inputs = {key: data[key].to(self.device) for key in keys_input} 131 | outputs = model(inputs) 132 | 133 | observations += data['observation'].squeeze().tolist() 134 | test_preds += outputs.squeeze().tolist() 135 | 136 | pbar.set_description(f'TEST - Batch: {i + 1}/{len(self.test_dataloader)}') 137 | 138 | create_submission(observations, test_preds) 139 | 140 | 141 | if __name__ == '__main__': 142 | main() -------------------------------------------------------------------------------- /src/models/make_train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings('ignore') 4 | 5 | import os 6 | import sys 7 | 8 | parent = os.path.abspath('.') 9 | sys.path.insert(1, parent) 10 | 11 | import torch 12 | import torch.nn as nn 13 | import wandb 14 | from src.models.dataloader import get_dataloaders 15 | from src.models.model import LSTMModel 16 | from torch.utils.data import DataLoader 17 | from src.models.trainer import Trainer 18 | 19 | 20 | def main(): 21 | # empty the GPU cache 22 | torch.cuda.empty_cache() 23 | 24 | # get the device 25 | device = get_device() 26 | 27 | # init W&B logger and get the model config from W&B sweep config yaml file 28 | # + get the training and validation dataloaders 29 | config, train_dataloader, val_dataloader = init_wandb() 30 | 31 | # init the model 32 | model = LSTMModel(config, device) 33 | model.to(device) 34 | 35 | # init the loss, optimizer and learning rate scheduler 36 | criterion = nn.MSELoss() 37 | optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate']) 38 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, 39 | patience=config['scheduler_patience'], 40 | verbose=True) 41 | 42 | train_config = { 43 | 'model': model, 44 | 'train_dataloader': train_dataloader, 45 | 'val_dataloader': val_dataloader, 46 | 'epochs': config['epochs'], 47 | 'criterion': criterion, 48 | 'optimizer': optimizer, 49 | 'scheduler': scheduler, 50 | } 51 | 52 | # init the trainer 53 | trainer = Trainer(**train_config) 54 | 55 | # train the model 56 | trainer.train() 57 | 58 | 59 | def init_wandb() -> tuple[dict, DataLoader, DataLoader]: 60 | """ Init W&B logger and get the model config from W&B sweep config yaml file 61 | + get the training and validation dataloaders. 
62 | 63 | :return: the model config and the training and validation dataloaders 64 | :rtype: (dict, DataLoader, DataLoader) 65 | """ 66 | 67 | epochs = wandb.config.epochs 68 | batch_size = wandb.config.batch_size 69 | 70 | learning_rate = wandb.config.learning_rate 71 | scheduler_patience = wandb.config.scheduler_patience 72 | 73 | s_hidden_size = wandb.config.s_hidden_size 74 | m_hidden_size = wandb.config.m_hidden_size 75 | s_num_layers = wandb.config.s_num_layers 76 | m_num_layers = wandb.config.m_num_layers 77 | c_out_in_features_1 = wandb.config.c_out_in_features_1 78 | c_out_in_features_2 = wandb.config.c_out_in_features_2 79 | dropout = wandb.config.dropout 80 | 81 | train_dataloader, val_dataloader, _ = get_dataloaders(batch_size, 0.2, get_device()) 82 | first_row = train_dataloader.dataset[0] 83 | 84 | c_in_features = s_hidden_size - 2 + m_hidden_size - 2 + first_row['g_input'].shape[0] 85 | 86 | config = { 87 | 'epochs': epochs, 88 | 'batch_size': batch_size, 89 | 'train_size': len(train_dataloader), 90 | 'val_size': len(val_dataloader), 91 | 'learning_rate': learning_rate, 92 | 'scheduler_patience': scheduler_patience, 93 | 's_hidden_size': s_hidden_size, 94 | 'm_hidden_size': m_hidden_size, 95 | 's_num_features': first_row['s_input'].shape[1], 96 | 's_num_layers': s_num_layers, 97 | 'm_num_layers': m_num_layers, 98 | 'm_num_features': first_row['m_input'].shape[1], 99 | 'g_in_features': first_row['g_input'].shape[0], 100 | 'c_in_features': c_in_features, 101 | 'c_out_in_features_1': c_out_in_features_1, 102 | 'c_out_in_features_2': c_out_in_features_2, 103 | 'dropout': dropout, 104 | } 105 | 106 | return config, train_dataloader, val_dataloader 107 | 108 | 109 | def get_device(): 110 | if torch.cuda.is_available(): 111 | device = 'cuda' 112 | elif torch.backends.mps.is_available(): 113 | device = 'mps' 114 | else: 115 | raise Exception("None accelerator available") 116 | 117 | return device 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | if sys.argv[1] == '--manual' or sys.argv[1] == '-m': 123 | wandb.init( 124 | project="crop-forecasting", 125 | entity="winged-bull", 126 | group="test", 127 | config=dict( 128 | epochs=20, 129 | batch_size=16, 130 | learning_rate=0.0001, 131 | scheduler_patience=6, 132 | s_hidden_size=256, 133 | m_hidden_size=256, 134 | s_num_layers=2, 135 | m_num_layers=2, 136 | c_out_in_features_1=256, 137 | c_out_in_features_2=128, 138 | dropout=.4, 139 | ), 140 | ) 141 | else: 142 | wandb.init(project="crop-forecasting", entity="winged-bull", group='Deep Learning') 143 | 144 | main() 145 | -------------------------------------------------------------------------------- /docs/build/html/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Base JavaScript utilities for all Sphinx HTML documentation. 6 | * 7 | * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | "use strict"; 12 | 13 | const BLACKLISTED_KEY_CONTROL_ELEMENTS = new Set([ 14 | "TEXTAREA", 15 | "INPUT", 16 | "SELECT", 17 | "BUTTON", 18 | ]); 19 | 20 | const _ready = (callback) => { 21 | if (document.readyState !== "loading") { 22 | callback(); 23 | } else { 24 | document.addEventListener("DOMContentLoaded", callback); 25 | } 26 | }; 27 | 28 | /** 29 | * Small JavaScript module for the documentation. 
30 | */ 31 | const Documentation = { 32 | init: () => { 33 | Documentation.initDomainIndexTable(); 34 | Documentation.initOnKeyListeners(); 35 | }, 36 | 37 | /** 38 | * i18n support 39 | */ 40 | TRANSLATIONS: {}, 41 | PLURAL_EXPR: (n) => (n === 1 ? 0 : 1), 42 | LOCALE: "unknown", 43 | 44 | // gettext and ngettext don't access this so that the functions 45 | // can safely bound to a different name (_ = Documentation.gettext) 46 | gettext: (string) => { 47 | const translated = Documentation.TRANSLATIONS[string]; 48 | switch (typeof translated) { 49 | case "undefined": 50 | return string; // no translation 51 | case "string": 52 | return translated; // translation exists 53 | default: 54 | return translated[0]; // (singular, plural) translation tuple exists 55 | } 56 | }, 57 | 58 | ngettext: (singular, plural, n) => { 59 | const translated = Documentation.TRANSLATIONS[singular]; 60 | if (typeof translated !== "undefined") 61 | return translated[Documentation.PLURAL_EXPR(n)]; 62 | return n === 1 ? singular : plural; 63 | }, 64 | 65 | addTranslations: (catalog) => { 66 | Object.assign(Documentation.TRANSLATIONS, catalog.messages); 67 | Documentation.PLURAL_EXPR = new Function( 68 | "n", 69 | `return (${catalog.plural_expr})` 70 | ); 71 | Documentation.LOCALE = catalog.locale; 72 | }, 73 | 74 | /** 75 | * helper function to focus on search bar 76 | */ 77 | focusSearchBar: () => { 78 | document.querySelectorAll("input[name=q]")[0]?.focus(); 79 | }, 80 | 81 | /** 82 | * Initialise the domain index toggle buttons 83 | */ 84 | initDomainIndexTable: () => { 85 | const toggler = (el) => { 86 | const idNumber = el.id.substr(7); 87 | const toggledRows = document.querySelectorAll(`tr.cg-${idNumber}`); 88 | if (el.src.substr(-9) === "minus.png") { 89 | el.src = `${el.src.substr(0, el.src.length - 9)}plus.png`; 90 | toggledRows.forEach((el) => (el.style.display = "none")); 91 | } else { 92 | el.src = `${el.src.substr(0, el.src.length - 8)}minus.png`; 93 | toggledRows.forEach((el) => (el.style.display = "")); 94 | } 95 | }; 96 | 97 | const togglerElements = document.querySelectorAll("img.toggler"); 98 | togglerElements.forEach((el) => 99 | el.addEventListener("click", (event) => toggler(event.currentTarget)) 100 | ); 101 | togglerElements.forEach((el) => (el.style.display = "")); 102 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) togglerElements.forEach(toggler); 103 | }, 104 | 105 | initOnKeyListeners: () => { 106 | // only install a listener if it is really needed 107 | if ( 108 | !DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS && 109 | !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS 110 | ) 111 | return; 112 | 113 | document.addEventListener("keydown", (event) => { 114 | // bail for input elements 115 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; 116 | // bail with special keys 117 | if (event.altKey || event.ctrlKey || event.metaKey) return; 118 | 119 | if (!event.shiftKey) { 120 | switch (event.key) { 121 | case "ArrowLeft": 122 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 123 | 124 | const prevLink = document.querySelector('link[rel="prev"]'); 125 | if (prevLink && prevLink.href) { 126 | window.location.href = prevLink.href; 127 | event.preventDefault(); 128 | } 129 | break; 130 | case "ArrowRight": 131 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 132 | 133 | const nextLink = document.querySelector('link[rel="next"]'); 134 | if (nextLink && nextLink.href) { 135 | window.location.href = nextLink.href; 136 | event.preventDefault(); 137 | } 138 | 
break; 139 | } 140 | } 141 | 142 | // some keyboard layouts may need Shift to get / 143 | switch (event.key) { 144 | case "/": 145 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break; 146 | Documentation.focusSearchBar(); 147 | event.preventDefault(); 148 | } 149 | }); 150 | }, 151 | }; 152 | 153 | // quick alias for translations 154 | const _ = Documentation.gettext; 155 | 156 | _ready(Documentation.init); 157 | -------------------------------------------------------------------------------- /docs/build/html/_static/sphinx_highlight.js: -------------------------------------------------------------------------------- 1 | /* Highlighting utilities for Sphinx HTML documentation. */ 2 | "use strict"; 3 | 4 | const SPHINX_HIGHLIGHT_ENABLED = true 5 | 6 | /** 7 | * highlight a given string on a node by wrapping it in 8 | * span elements with the given class name. 9 | */ 10 | const _highlight = (node, addItems, text, className) => { 11 | if (node.nodeType === Node.TEXT_NODE) { 12 | const val = node.nodeValue; 13 | const parent = node.parentNode; 14 | const pos = val.toLowerCase().indexOf(text); 15 | if ( 16 | pos >= 0 && 17 | !parent.classList.contains(className) && 18 | !parent.classList.contains("nohighlight") 19 | ) { 20 | let span; 21 | 22 | const closestNode = parent.closest("body, svg, foreignObject"); 23 | const isInSVG = closestNode && closestNode.matches("svg"); 24 | if (isInSVG) { 25 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 26 | } else { 27 | span = document.createElement("span"); 28 | span.classList.add(className); 29 | } 30 | 31 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 32 | parent.insertBefore( 33 | span, 34 | parent.insertBefore( 35 | document.createTextNode(val.substr(pos + text.length)), 36 | node.nextSibling 37 | ) 38 | ); 39 | node.nodeValue = val.substr(0, pos); 40 | 41 | if (isInSVG) { 42 | const rect = document.createElementNS( 43 | "http://www.w3.org/2000/svg", 44 | "rect" 45 | ); 46 | const bbox = parent.getBBox(); 47 | rect.x.baseVal.value = bbox.x; 48 | rect.y.baseVal.value = bbox.y; 49 | rect.width.baseVal.value = bbox.width; 50 | rect.height.baseVal.value = bbox.height; 51 | rect.setAttribute("class", className); 52 | addItems.push({ parent: parent, target: rect }); 53 | } 54 | } 55 | } else if (node.matches && !node.matches("button, select, textarea")) { 56 | node.childNodes.forEach((el) => _highlight(el, addItems, text, className)); 57 | } 58 | }; 59 | const _highlightText = (thisNode, text, className) => { 60 | let addItems = []; 61 | _highlight(thisNode, addItems, text, className); 62 | addItems.forEach((obj) => 63 | obj.parent.insertAdjacentElement("beforebegin", obj.target) 64 | ); 65 | }; 66 | 67 | /** 68 | * Small JavaScript module for the documentation. 
69 | */ 70 | const SphinxHighlight = { 71 | 72 | /** 73 | * highlight the search words provided in localstorage in the text 74 | */ 75 | highlightSearchWords: () => { 76 | if (!SPHINX_HIGHLIGHT_ENABLED) return; // bail if no highlight 77 | 78 | // get and clear terms from localstorage 79 | const url = new URL(window.location); 80 | const highlight = 81 | localStorage.getItem("sphinx_highlight_terms") 82 | || url.searchParams.get("highlight") 83 | || ""; 84 | localStorage.removeItem("sphinx_highlight_terms") 85 | url.searchParams.delete("highlight"); 86 | window.history.replaceState({}, "", url); 87 | 88 | // get individual terms from highlight string 89 | const terms = highlight.toLowerCase().split(/\s+/).filter(x => x); 90 | if (terms.length === 0) return; // nothing to do 91 | 92 | // There should never be more than one element matching "div.body" 93 | const divBody = document.querySelectorAll("div.body"); 94 | const body = divBody.length ? divBody[0] : document.querySelector("body"); 95 | window.setTimeout(() => { 96 | terms.forEach((term) => _highlightText(body, term, "highlighted")); 97 | }, 10); 98 | 99 | const searchBox = document.getElementById("searchbox"); 100 | if (searchBox === null) return; 101 | searchBox.appendChild( 102 | document 103 | .createRange() 104 | .createContextualFragment( 105 | '" 109 | ) 110 | ); 111 | }, 112 | 113 | /** 114 | * helper function to hide the search marks again 115 | */ 116 | hideSearchWords: () => { 117 | document 118 | .querySelectorAll("#searchbox .highlight-link") 119 | .forEach((el) => el.remove()); 120 | document 121 | .querySelectorAll("span.highlighted") 122 | .forEach((el) => el.classList.remove("highlighted")); 123 | localStorage.removeItem("sphinx_highlight_terms") 124 | }, 125 | 126 | initEscapeListener: () => { 127 | // only install a listener if it is really needed 128 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return; 129 | 130 | document.addEventListener("keydown", (event) => { 131 | // bail for input elements 132 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; 133 | // bail with special keys 134 | if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return; 135 | if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) { 136 | SphinxHighlight.hideSearchWords(); 137 | event.preventDefault(); 138 | } 139 | }); 140 | }, 141 | }; 142 | 143 | _ready(SphinxHighlight.highlightSearchWords); 144 | _ready(SphinxHighlight.initEscapeListener); 145 | -------------------------------------------------------------------------------- /docs/build/html/_static/language_data.js: -------------------------------------------------------------------------------- 1 | /* 2 | * language_data.js 3 | * ~~~~~~~~~~~~~~~~ 4 | * 5 | * This script contains the language-specific data used by searchtools.js, 6 | * namely the list of stopwords, stemmer, scorer and splitter. 7 | * 8 | * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. 9 | * :license: BSD, see LICENSE for details. 
10 | * 11 | */ 12 | 13 | var stopwords = ["a", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it", "near", "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"]; 14 | 15 | 16 | /* Non-minified version is copied as a separate JS file, is available */ 17 | 18 | /** 19 | * Porter Stemmer 20 | */ 21 | var Stemmer = function() { 22 | 23 | var step2list = { 24 | ational: 'ate', 25 | tional: 'tion', 26 | enci: 'ence', 27 | anci: 'ance', 28 | izer: 'ize', 29 | bli: 'ble', 30 | alli: 'al', 31 | entli: 'ent', 32 | eli: 'e', 33 | ousli: 'ous', 34 | ization: 'ize', 35 | ation: 'ate', 36 | ator: 'ate', 37 | alism: 'al', 38 | iveness: 'ive', 39 | fulness: 'ful', 40 | ousness: 'ous', 41 | aliti: 'al', 42 | iviti: 'ive', 43 | biliti: 'ble', 44 | logi: 'log' 45 | }; 46 | 47 | var step3list = { 48 | icate: 'ic', 49 | ative: '', 50 | alize: 'al', 51 | iciti: 'ic', 52 | ical: 'ic', 53 | ful: '', 54 | ness: '' 55 | }; 56 | 57 | var c = "[^aeiou]"; // consonant 58 | var v = "[aeiouy]"; // vowel 59 | var C = c + "[^aeiouy]*"; // consonant sequence 60 | var V = v + "[aeiou]*"; // vowel sequence 61 | 62 | var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0 63 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 64 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 65 | var s_v = "^(" + C + ")?" + v; // vowel in stem 66 | 67 | this.stemWord = function (w) { 68 | var stem; 69 | var suffix; 70 | var firstch; 71 | var origword = w; 72 | 73 | if (w.length < 3) 74 | return w; 75 | 76 | var re; 77 | var re2; 78 | var re3; 79 | var re4; 80 | 81 | firstch = w.substr(0,1); 82 | if (firstch == "y") 83 | w = firstch.toUpperCase() + w.substr(1); 84 | 85 | // Step 1a 86 | re = /^(.+?)(ss|i)es$/; 87 | re2 = /^(.+?)([^s])s$/; 88 | 89 | if (re.test(w)) 90 | w = w.replace(re,"$1$2"); 91 | else if (re2.test(w)) 92 | w = w.replace(re2,"$1$2"); 93 | 94 | // Step 1b 95 | re = /^(.+?)eed$/; 96 | re2 = /^(.+?)(ed|ing)$/; 97 | if (re.test(w)) { 98 | var fp = re.exec(w); 99 | re = new RegExp(mgr0); 100 | if (re.test(fp[1])) { 101 | re = /.$/; 102 | w = w.replace(re,""); 103 | } 104 | } 105 | else if (re2.test(w)) { 106 | var fp = re2.exec(w); 107 | stem = fp[1]; 108 | re2 = new RegExp(s_v); 109 | if (re2.test(stem)) { 110 | w = stem; 111 | re2 = /(at|bl|iz)$/; 112 | re3 = new RegExp("([^aeiouylsz])\\1$"); 113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 114 | if (re2.test(w)) 115 | w = w + "e"; 116 | else if (re3.test(w)) { 117 | re = /.$/; 118 | w = w.replace(re,""); 119 | } 120 | else if (re4.test(w)) 121 | w = w + "e"; 122 | } 123 | } 124 | 125 | // Step 1c 126 | re = /^(.+?)y$/; 127 | if (re.test(w)) { 128 | var fp = re.exec(w); 129 | stem = fp[1]; 130 | re = new RegExp(s_v); 131 | if (re.test(stem)) 132 | w = stem + "i"; 133 | } 134 | 135 | // Step 2 136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 137 | if (re.test(w)) { 138 | var fp = re.exec(w); 139 | stem = fp[1]; 140 | suffix = fp[2]; 141 | re = new RegExp(mgr0); 142 | if (re.test(stem)) 143 | w = stem + step2list[suffix]; 144 | } 145 | 146 | // Step 3 147 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 148 | if (re.test(w)) { 149 | var fp = re.exec(w); 150 | stem = fp[1]; 151 | suffix = fp[2]; 152 | re = new RegExp(mgr0); 153 | if (re.test(stem)) 154 | w = stem + step3list[suffix]; 155 | } 156 | 157 | // Step 4 158 
| re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 159 | re2 = /^(.+?)(s|t)(ion)$/; 160 | if (re.test(w)) { 161 | var fp = re.exec(w); 162 | stem = fp[1]; 163 | re = new RegExp(mgr1); 164 | if (re.test(stem)) 165 | w = stem; 166 | } 167 | else if (re2.test(w)) { 168 | var fp = re2.exec(w); 169 | stem = fp[1] + fp[2]; 170 | re2 = new RegExp(mgr1); 171 | if (re2.test(stem)) 172 | w = stem; 173 | } 174 | 175 | // Step 5 176 | re = /^(.+?)e$/; 177 | if (re.test(w)) { 178 | var fp = re.exec(w); 179 | stem = fp[1]; 180 | re = new RegExp(mgr1); 181 | re2 = new RegExp(meq1); 182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 184 | w = stem; 185 | } 186 | re = /ll$/; 187 | re2 = new RegExp(mgr1); 188 | if (re.test(w) && re2.test(w)) { 189 | re = /.$/; 190 | w = w.replace(re,""); 191 | } 192 | 193 | // and turn initial Y back to y 194 | if (firstch == "y") 195 | w = firstch.toLowerCase() + w.substr(1); 196 | return w; 197 | } 198 | } 199 | 200 | -------------------------------------------------------------------------------- /data/raw/test.csv: -------------------------------------------------------------------------------- 1 | ID No,District,Latitude,Longitude,"Season(SA = Summer Autumn, WS = Winter Spring)","Rice Crop Intensity(D=Double, T=Triple)",Date of Harvest,Field size (ha),Predicted Rice Yield (kg/ha) 2 | 1,Chau_Phu,10.542192,105.18792,WS,T,10-04-2022,1.4, 3 | 2,Chau_Thanh,10.400189,105.331053,SA,T,15-07-2022,1.32, 4 | 3,Chau_Phu,10.505489,105.203926,SA,D,14-07-2022,1.4, 5 | 4,Chau_Phu,10.52352,105.138274,WS,D,10-04-2022,1.8, 6 | 5,Thoai_Son,10.29466,105.248528,SA,T,20-07-2022,2.2, 7 | 6,Chau_Phu,10.633572,105.172813,WS,D,12-04-2022,2.5, 8 | 7,Chau_Thanh,10.434116,105.27315,SA,T,15-07-2022,2.2, 9 | 8,Chau_Phu,10.61225,105.175364,SA,D,09-08-2022,4, 10 | 9,Thoai_Son,10.268095,105.344826,WS,T,10-04-2022,1.5, 11 | 10,Chau_Phu,10.523312,105.286299,WS,T,01-04-2022,3, 12 | 11,Chau_Phu,10.554634,105.1522,SA,D,04-07-2022,2.3, 13 | 12,Chau_Phu,10.542998,105.261159,SA,D,14-07-2022,4.9, 14 | 13,Chau_Phu,10.473786,105.190479,SA,T,14-07-2022,1.7, 15 | 14,Chau_Phu,10.52352,105.138274,SA,D,20-07-2022,1.8, 16 | 15,Chau_Thanh,10.371317,105.31009,SA,T,04-08-2022,2, 17 | 16,Chau_Thanh,10.456905,105.186106,WS,T,26-03-2022,1.529, 18 | 17,Chau_Phu,10.441423,105.115088,SA,D,20-07-2022,4, 19 | 18,Thoai_Son,10.227343,105.230821,SA,T,20-07-2022,1.8, 20 | 19,Thoai_Son,10.279477,105.288145,SA,T,20-07-2022,2, 21 | 20,Thoai_Son,10.326776,105.33819,WS,T,12-04-2022,1.7, 22 | 21,Thoai_Son,10.293434,105.256309,WS,T,13-04-2022,2.3, 23 | 22,Chau_Thanh,10.405702,105.208668,SA,T,20-07-2022,1.32, 24 | 23,Chau_Phu,10.56984,105.190938,WS,T,10-04-2022,2.3, 25 | 24,Chau_Thanh,10.41636,105.142144,SA,T,26-07-2022,1.43, 26 | 25,Chau_Thanh,10.371317,105.31009,WS,T,10-04-2022,2, 27 | 26,Chau_Thanh,10.40097,105.348815,WS,T,25-03-2022,2.53, 28 | 27,Chau_Phu,10.594603,105.175686,WS,D,12-04-2022,2.2, 29 | 28,Chau_Phu,10.514912,105.216308,SA,D,14-07-2022,2, 30 | 29,Thoai_Son,10.293434,105.256309,SA,T,20-07-2022,2.3, 31 | 30,Chau_Phu,10.48194,105.151543,SA,D,14-07-2022,2.9, 32 | 31,Thoai_Son,10.238171,105.198793,SA,T,20-07-2022,3.6, 33 | 32,Thoai_Son,10.326776,105.33819,SA,T,20-07-2022,1.7, 34 | 33,Chau_Thanh,10.421158,105.237197,WS,T,27-03-2022,1.76, 35 | 34,Chau_Thanh,10.369692,105.29684,WS,T,10-04-2022,4, 36 | 35,Chau_Phu,10.473786,105.190479,WS,T,03-04-2022,1.7, 37 | 36,Thoai_Son,10.365734,105.267634,WS,T,12-04-2022,2.5, 38 | 
37,Thoai_Son,10.281814,105.114918,SA,T,25-07-2022,5, 39 | 38,Chau_Thanh,10.411649,105.370659,SA,T,15-07-2022,1.43, 40 | 39,Chau_Phu,10.477062,105.168941,WS,D,03-04-2022,3.4, 41 | 40,Chau_Thanh,10.437524,105.202166,SA,D,20-07-2022,2.31, 42 | 41,Chau_Thanh,10.414613,105.125615,WS,T,27-03-2022,3.85, 43 | 42,Thoai_Son,10.28697,105.341571,SA,T,20-07-2022,1.78, 44 | 43,Chau_Phu,10.64394,105.165296,SA,T,05-08-2022,1.4, 45 | 44,Chau_Thanh,10.420981,105.295797,WS,T,25-03-2022,1.76, 46 | 45,Thoai_Son,10.340626,105.172116,SA,T,23-07-2022,3, 47 | 46,Chau_Phu,10.625193,105.181059,SA,D,09-08-2022,3, 48 | 47,Chau_Phu,10.479304,105.102943,SA,D,20-07-2022,2.9, 49 | 48,Chau_Thanh,10.392899,105.188514,WS,T,24-03-2022,1.43, 50 | 49,Chau_Thanh,10.426536,105.115181,SA,T,14-07-2022,1.43, 51 | 50,Chau_Phu,10.552114,105.091399,SA,D,05-08-2022,3.6, 52 | 51,Chau_Phu,10.541934,105.247538,WS,T,10-04-2022,2, 53 | 52,Thoai_Son,10.303437,105.381252,WS,T,28-03-2022,2.1, 54 | 53,Thoai_Son,10.368837,105.205763,SA,T,21-07-2022,2.3, 55 | 54,Chau_Thanh,10.440557,105.250671,WS,T,27-03-2022,1.76, 56 | 55,Chau_Thanh,10.430707,105.315671,SA,T,10-07-2022,1.54, 57 | 56,Chau_Thanh,10.378696,105.309113,WS,T,27-03-2022,1.32, 58 | 57,Thoai_Son,10.283467,105.267082,WS,T,20-04-2022,2.2, 59 | 58,Chau_Thanh,10.436907,105.236241,WS,T,27-03-2022,1.375, 60 | 59,Chau_Thanh,10.387619,105.243467,SA,T,19-07-2022,1.3, 61 | 60,Chau_Phu,10.482003,105.203866,SA,T,15-07-2022,3, 62 | 61,Chau_Phu,10.545098,105.112895,SA,D,20-07-2022,3.6, 63 | 62,Thoai_Son,10.279477,105.288145,WS,T,10-04-2022,2, 64 | 63,Chau_Thanh,10.40466,105.311554,SA,T,17-07-2022,1.87, 65 | 64,Thoai_Son,10.34489,105.243935,SA,T,23-07-2022,2.3, 66 | 65,Chau_Phu,10.623506,105.132794,WS,T,10-04-2022,5, 67 | 66,Thoai_Son,10.303766,105.203102,WS,T,16-04-2022,2.2, 68 | 67,Chau_Phu,10.656963,105.152679,WS,D,10-04-2022,3, 69 | 68,Thoai_Son,10.312202,105.330633,WS,T,10-04-2022,3, 70 | 69,Chau_Thanh,10.440335,105.22309,WS,D,27-03-2022,1.76, 71 | 70,Thoai_Son,10.306886,105.290958,SA,T,20-07-2022,1.85, 72 | 71,Chau_Thanh,10.392045,105.307085,SA,T,17-07-2022,1.43, 73 | 72,Chau_Thanh,10.375619,105.124248,WS,T,24-03-2022,1.87, 74 | 73,Thoai_Son,10.313955,105.243521,SA,T,20-07-2022,1.3, 75 | 74,Chau_Phu,10.592666,105.141217,SA,T,05-08-2022,2.8, 76 | 75,Thoai_Son,10.266311,105.23555,WS,T,20-04-2022,1.5, 77 | 76,Chau_Thanh,10.386546,105.19302,SA,T,14-07-2022,1.65, 78 | 77,Thoai_Son,10.282365,105.276189,WS,T,10-04-2022,1.91, 79 | 78,Thoai_Son,10.33791,105.357013,SA,T,20-07-2022,2, 80 | 79,Thoai_Son,10.31856,105.374468,WS,D,28-03-2022,1.4, 81 | 80,Chau_Thanh,10.429212,105.141436,WS,D,27-03-2022,1.21, 82 | 81,Chau_Phu,10.474439,105.216928,SA,T,15-07-2022,7, 83 | 82,Chau_Thanh,10.443488,105.236111,WS,T,27-03-2022,1.65, 84 | 83,Chau_Thanh,10.409367,105.355252,WS,T,25-03-2022,2.42, 85 | 84,Thoai_Son,10.292156,105.361294,SA,T,12-07-2022,2.3, 86 | 85,Chau_Phu,10.469839,105.211568,WS,T,01-04-2022,4, 87 | 86,Thoai_Son,10.257436,105.217205,SA,T,22-07-2022,2.3, 88 | 87,Chau_Phu,10.658291,105.127704,WS,T,10-04-2022,3, 89 | 88,Chau_Thanh,10.407982,105.123304,WS,D,27-03-2022,2.75, 90 | 89,Thoai_Son,10.320684,105.272431,SA,T,20-07-2022,3, 91 | 90,Chau_Phu,10.501648,105.096892,WS,T,10-04-2022,1.2, 92 | 91,Chau_Phu,10.490352,105.23065,WS,T,02-04-2022,4, 93 | 92,Thoai_Son,10.34489,105.243935,WS,T,12-04-2022,2.3, 94 | 93,Chau_Thanh,10.440557,105.250671,SA,T,28-07-2022,1.76, 95 | 94,Thoai_Son,10.320371,105.259016,WS,T,13-04-2022,1.52, 96 | 95,Thoai_Son,10.250745,105.24539,WS,T,20-04-2022,3, 97 | 
96,Chau_Thanh,10.435839,105.132981,SA,D,26-07-2022,1.21, 98 | 97,Chau_Phu,10.529357,105.147388,WS,T,10-04-2022,2, 99 | 98,Chau_Thanh,10.452537,105.205118,SA,T,20-07-2022,5.5, 100 | 99,Chau_Thanh,10.394341,105.126836,SA,T,14-07-2022,4.4, 101 | 100,Chau_Phu,10.48065,105.130089,WS,T,10-04-2022,2, 102 | -------------------------------------------------------------------------------- /submissions/toasty-sky-343.csv: -------------------------------------------------------------------------------- 1 | ID No,District,Latitude,Longitude,"Season(SA = Summer Autumn, WS = Winter Spring)","Rice Crop Intensity(D=Double, T=Triple)",Date of Harvest,Field size (ha),Predicted Rice Yield (kg/ha) 2 | 1,Chau_Phu,10.542192,105.18792,WS,T,10-04-2022,1.4,7200 3 | 2,Chau_Thanh,10.400189,105.331053,SA,T,15-07-2022,1.32,6000 4 | 3,Chau_Phu,10.505489,105.203926,SA,D,14-07-2022,1.4,6000 5 | 4,Chau_Phu,10.52352,105.138274,WS,D,10-04-2022,1.8,6960 6 | 5,Thoai_Son,10.29466,105.248528,SA,T,20-07-2022,2.2,6000 7 | 6,Chau_Phu,10.633572,105.172813,WS,D,12-04-2022,2.5,7200 8 | 7,Chau_Thanh,10.434116,105.27315,SA,T,15-07-2022,2.2,6000 9 | 8,Chau_Phu,10.61225,105.175364,SA,D,09-08-2022,4.0,6000 10 | 9,Thoai_Son,10.268095,105.344826,WS,T,10-04-2022,1.5,7200 11 | 10,Chau_Phu,10.523312,105.286299,WS,T,01-04-2022,3.0,7200 12 | 11,Chau_Phu,10.554634,105.1522,SA,D,04-07-2022,2.3,6000 13 | 12,Chau_Phu,10.542998,105.261159,SA,D,14-07-2022,4.9,6000 14 | 13,Chau_Phu,10.473786,105.190479,SA,T,14-07-2022,1.7,6000 15 | 14,Chau_Phu,10.52352,105.138274,SA,D,20-07-2022,1.8,6000 16 | 15,Chau_Thanh,10.371317,105.31009,SA,T,04-08-2022,2.0,6000 17 | 16,Chau_Thanh,10.456905,105.186106,WS,T,26-03-2022,1.529,7200 18 | 17,Chau_Phu,10.441423,105.115088,SA,D,20-07-2022,4.0,6000 19 | 18,Thoai_Son,10.227343,105.230821,SA,T,20-07-2022,1.8,6000 20 | 19,Thoai_Son,10.279477,105.288145,SA,T,20-07-2022,2.0,6000 21 | 20,Thoai_Son,10.326776,105.33819,WS,T,12-04-2022,1.7,7400 22 | 21,Thoai_Son,10.293434,105.256309,WS,T,13-04-2022,2.3,7200 23 | 22,Chau_Thanh,10.405702,105.208668,SA,T,20-07-2022,1.32,6000 24 | 23,Chau_Phu,10.56984,105.190938,WS,T,10-04-2022,2.3,7200 25 | 24,Chau_Thanh,10.41636,105.142144,SA,T,26-07-2022,1.43,6000 26 | 25,Chau_Thanh,10.371317,105.31009,WS,T,10-04-2022,2.0,7400 27 | 26,Chau_Thanh,10.40097,105.348815,WS,T,25-03-2022,2.53,7400 28 | 27,Chau_Phu,10.594603,105.175686,WS,D,12-04-2022,2.2,7200 29 | 28,Chau_Phu,10.514912,105.216308,SA,D,14-07-2022,2.0,6000 30 | 29,Thoai_Son,10.293434,105.256309,SA,T,20-07-2022,2.3,6000 31 | 30,Chau_Phu,10.48194,105.151543,SA,D,14-07-2022,2.9,6000 32 | 31,Thoai_Son,10.238171,105.198793,SA,T,20-07-2022,3.6,6000 33 | 32,Thoai_Son,10.326776,105.33819,SA,T,20-07-2022,1.7,6000 34 | 33,Chau_Thanh,10.421158,105.237197,WS,T,27-03-2022,1.76,7400 35 | 34,Chau_Thanh,10.369692,105.29684,WS,T,10-04-2022,4.0,7200 36 | 35,Chau_Phu,10.473786,105.190479,WS,T,03-04-2022,1.7,7200 37 | 36,Thoai_Son,10.365734,105.267634,WS,T,12-04-2022,2.5,7200 38 | 37,Thoai_Son,10.281814,105.114918,SA,T,25-07-2022,5.0,6000 39 | 38,Chau_Thanh,10.411649,105.370659,SA,T,15-07-2022,1.43,6000 40 | 39,Chau_Phu,10.477062,105.168941,WS,D,03-04-2022,3.4,6800 41 | 40,Chau_Thanh,10.437524,105.202166,SA,D,20-07-2022,2.31,6000 42 | 41,Chau_Thanh,10.414613,105.125615,WS,T,27-03-2022,3.85,7040 43 | 42,Thoai_Son,10.28697,105.341571,SA,T,20-07-2022,1.78,6000 44 | 43,Chau_Phu,10.64394,105.165296,SA,T,05-08-2022,1.4,6000 45 | 44,Chau_Thanh,10.420981,105.295797,WS,T,25-03-2022,1.76,7200 46 | 45,Thoai_Son,10.340626,105.172116,SA,T,23-07-2022,3.0,6000 47 | 
46,Chau_Phu,10.625193,105.181059,SA,D,09-08-2022,3.0,6000 48 | 47,Chau_Phu,10.479304,105.102943,SA,D,20-07-2022,2.9,6000 49 | 48,Chau_Thanh,10.392899,105.188514,WS,T,24-03-2022,1.43,7200 50 | 49,Chau_Thanh,10.426536,105.115181,SA,T,14-07-2022,1.43,6000 51 | 50,Chau_Phu,10.552114,105.091399,SA,D,05-08-2022,3.6,6000 52 | 51,Chau_Phu,10.541934,105.247538,WS,T,10-04-2022,2.0,7400 53 | 52,Thoai_Son,10.303437,105.381252,WS,T,28-03-2022,2.1,7400 54 | 53,Thoai_Son,10.368837,105.205763,SA,T,21-07-2022,2.3,6000 55 | 54,Chau_Thanh,10.440557,105.250671,WS,T,27-03-2022,1.76,7200 56 | 55,Chau_Thanh,10.430707,105.315671,SA,T,10-07-2022,1.54,6000 57 | 56,Chau_Thanh,10.378696,105.309113,WS,T,27-03-2022,1.32,7200 58 | 57,Thoai_Son,10.283467,105.267082,WS,T,20-04-2022,2.2,7400 59 | 58,Chau_Thanh,10.436907,105.236241,WS,T,27-03-2022,1.375,7400 60 | 59,Chau_Thanh,10.387619,105.243467,SA,T,19-07-2022,1.3,6000 61 | 60,Chau_Phu,10.482003,105.203866,SA,T,15-07-2022,3.0,6000 62 | 61,Chau_Phu,10.545098,105.112895,SA,D,20-07-2022,3.6,6000 63 | 62,Thoai_Son,10.279477,105.288145,WS,T,10-04-2022,2.0,7200 64 | 63,Chau_Thanh,10.40466,105.311554,SA,T,17-07-2022,1.87,6000 65 | 64,Thoai_Son,10.34489,105.243935,SA,T,23-07-2022,2.3,6000 66 | 65,Chau_Phu,10.623506,105.132794,WS,T,10-04-2022,5.0,7040 67 | 66,Thoai_Son,10.303766,105.203102,WS,T,16-04-2022,2.2,7400 68 | 67,Chau_Phu,10.656963,105.152679,WS,D,10-04-2022,3.0,7040 69 | 68,Thoai_Son,10.312202,105.330633,WS,T,10-04-2022,3.0,7200 70 | 69,Chau_Thanh,10.440335,105.22309,WS,D,27-03-2022,1.76,7040 71 | 70,Thoai_Son,10.306886,105.290958,SA,T,20-07-2022,1.85,6000 72 | 71,Chau_Thanh,10.392045,105.307085,SA,T,17-07-2022,1.43,6000 73 | 72,Chau_Thanh,10.375619,105.124248,WS,T,24-03-2022,1.87,7040 74 | 73,Thoai_Son,10.313955,105.243521,SA,T,20-07-2022,1.3,6000 75 | 74,Chau_Phu,10.592666,105.141217,SA,T,05-08-2022,2.8,6000 76 | 75,Thoai_Son,10.266311,105.23555,WS,T,20-04-2022,1.5,7400 77 | 76,Chau_Thanh,10.386546,105.19302,SA,T,14-07-2022,1.65,6000 78 | 77,Thoai_Son,10.282365,105.276189,WS,T,10-04-2022,1.91,7400 79 | 78,Thoai_Son,10.33791,105.357013,SA,T,20-07-2022,2.0,6000 80 | 79,Thoai_Son,10.31856,105.374468,WS,D,28-03-2022,1.4,7200 81 | 80,Chau_Thanh,10.429212,105.141436,WS,D,27-03-2022,1.21,7040 82 | 81,Chau_Phu,10.474439,105.216928,SA,T,15-07-2022,7.0,6000 83 | 82,Chau_Thanh,10.443488,105.236111,WS,T,27-03-2022,1.65,7200 84 | 83,Chau_Thanh,10.409367,105.355252,WS,T,25-03-2022,2.42,7400 85 | 84,Thoai_Son,10.292156,105.361294,SA,T,12-07-2022,2.3,6000 86 | 85,Chau_Phu,10.469839,105.211568,WS,T,01-04-2022,4.0,7200 87 | 86,Thoai_Son,10.257436,105.217205,SA,T,22-07-2022,2.3,6000 88 | 87,Chau_Phu,10.658291,105.127704,WS,T,10-04-2022,3.0,7040 89 | 88,Chau_Thanh,10.407982,105.123304,WS,D,27-03-2022,2.75,7040 90 | 89,Thoai_Son,10.320684,105.272431,SA,T,20-07-2022,3.0,6000 91 | 90,Chau_Phu,10.501648,105.096892,WS,T,10-04-2022,1.2,7200 92 | 91,Chau_Phu,10.490352,105.23065,WS,T,02-04-2022,4.0,7200 93 | 92,Thoai_Son,10.34489,105.243935,WS,T,12-04-2022,2.3,7200 94 | 93,Chau_Thanh,10.440557,105.250671,SA,T,28-07-2022,1.76,6000 95 | 94,Thoai_Son,10.320371,105.259016,WS,T,13-04-2022,1.52,7200 96 | 95,Thoai_Son,10.250745,105.24539,WS,T,20-04-2022,3.0,7200 97 | 96,Chau_Thanh,10.435839,105.132981,SA,D,26-07-2022,1.21,6000 98 | 97,Chau_Phu,10.529357,105.147388,WS,T,10-04-2022,2.0,7200 99 | 98,Chau_Thanh,10.452537,105.205118,SA,T,20-07-2022,5.5,6000 100 | 99,Chau_Thanh,10.394341,105.126836,SA,T,14-07-2022,4.4,6000 101 | 100,Chau_Phu,10.48065,105.130089,WS,T,10-04-2022,2.0,6960 102 | 
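Note that the two CSV files above differ only in their last column: data/raw/test.csv leaves "Predicted Rice Yield (kg/ha)" empty, while a submission such as submissions/toasty-sky-343.csv is the same table with that column filled in by the model. The following is a minimal sketch of producing a submission with pandas; the placeholder prediction values and the output file name are illustrative only and are not part of the project.

import pandas as pd

# Illustrative only: fill the empty target column of test.csv to obtain a
# submission shaped like submissions/toasty-sky-343.csv.
test = pd.read_csv("data/raw/test.csv")

# Placeholder predictions; in practice one yield estimate (kg/ha) per test row,
# in the same row order, coming from the trained model.
predictions = [6000.0] * len(test)

test["Predicted Rice Yield (kg/ha)"] = predictions
test.to_csv("submissions/example-submission.csv", index=False)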
-------------------------------------------------------------------------------- /src/data/datascaler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Union 4 | 5 | import xarray as xr 6 | from sklearn.preprocessing import (MinMaxScaler, PowerTransformer, 7 | QuantileTransformer, RobustScaler, 8 | StandardScaler) 9 | 10 | parent = os.path.abspath("../features") 11 | sys.path.insert(1, parent) 12 | 13 | 14 | class DatasetScaler: 15 | """Scaler for Vegetable Indice, Geographical, Meteorological and Target. 16 | 17 | :param scaler_s: Scikit-Learn scaler for Vegetable Indice data 18 | :type scaler_s: Union[StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer] 19 | :param columns_s: Vegetable Indice columns name 20 | :type columns_s: list[str] 21 | :param scaler_g: Scikit-Learn scaler for Geographical data 22 | :type scaler_g: Union[StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer] 23 | :param columns_g: Geographical columns name 24 | :type columns_g: list[str] 25 | :param scaler_m: Scikit-Learn scaler for Meteorological data 26 | :type scaler_m: Union[StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer] 27 | :param columns_m: Meteorological columns name 28 | :type columns_m: list[str] 29 | :param scaler_t: Scikit-Learn scaler for Target data 30 | :type scaler_t: MinMaxScaler 31 | """ 32 | def __init__( 33 | self, 34 | scaler_s: Union[ 35 | StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer 36 | ], 37 | columns_s: list[str], 38 | scaler_g: Union[ 39 | StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer 40 | ], 41 | columns_g: list[str], 42 | scaler_m: Union[ 43 | StandardScaler, RobustScaler, PowerTransformer, QuantileTransformer 44 | ], 45 | columns_m: list[str], 46 | scaler_t: MinMaxScaler, 47 | ) -> None: 48 | self.scaler_s = scaler_s 49 | self.columns_s = columns_s 50 | self.scaler_g = scaler_g 51 | self.columns_g = columns_g 52 | self.scaler_m = scaler_m 53 | self.columns_m = columns_m 54 | self.scaler_t = scaler_t 55 | 56 | def fit(self, xdf: xr.Dataset, target: str) -> object: 57 | """Fit all scalers to be used for later scaling. 58 | 59 | :param xdf: The data used to fit all scalers, used for later scaling along the features axis. 60 | :type xdf: xr.Dataset 61 | :param target: Column name to fit the target scaler, used for later scaling along the target axis. 62 | :type target: str 63 | :return: Fitted scaler. 64 | :rtype: object 65 | """ 66 | 67 | def fit_scaler( 68 | xdf: xr.Dataset, 69 | columns: list[str], 70 | scaler: Union[ 71 | StandardScaler, 72 | RobustScaler, 73 | PowerTransformer, 74 | QuantileTransformer, 75 | MinMaxScaler, 76 | ], 77 | ): 78 | df = xdf[columns].to_dataframe() 79 | 80 | return scaler.fit(df[columns]) 81 | 82 | # Fit S data scaler 83 | self.scaler_s = fit_scaler(xdf, self.columns_s, self.scaler_s) 84 | # Fit G data scaler 85 | self.scaler_g = fit_scaler(xdf, self.columns_g, self.scaler_g) 86 | # Fit M data scaler 87 | self.scaler_m = fit_scaler(xdf, self.columns_m, self.scaler_m) 88 | # Fit Target data scaler 89 | self.scaler_t = fit_scaler(xdf, [target], self.scaler_t) 90 | 91 | return self 92 | 93 | def transform(self, xdf: xr.Dataset, target: str = None) -> xr.Dataset: 94 | """Perform transform of each scaler. 95 | 96 | :param xdf: The Dataset used to scale along the features axis. 
97 | :type xdf: xr.Dataset 98 | :param target: Column name used to scale along the Target axis, defaults to None 99 | :type target: str, optional 100 | :return: Transformed Dataset. 101 | :rtype: xr.Dataset 102 | """ 103 | 104 | def transform_data( 105 | xdf: xr.Dataset, 106 | columns: str, 107 | scaler: Union[ 108 | StandardScaler, 109 | RobustScaler, 110 | PowerTransformer, 111 | QuantileTransformer, 112 | MinMaxScaler, 113 | ], 114 | ) -> xr.Dataset: 115 | df = xdf[columns].to_dataframe() 116 | df.loc[:, columns] = scaler.transform(df[columns]) 117 | xdf_scale = df[columns].to_xarray() 118 | xdf = xr.merge([xdf_scale, xdf], compat="override") 119 | return xdf 120 | 121 | # Scale S data 122 | xdf = transform_data(xdf, self.columns_s, self.scaler_s) 123 | # Scale G data 124 | xdf = transform_data(xdf, self.columns_g, self.scaler_g) 125 | # Scale M data 126 | xdf = transform_data(xdf, self.columns_m, self.scaler_m) 127 | 128 | if target: 129 | # Scale M data 130 | xdf = transform_data(xdf, [target], self.scaler_t) 131 | 132 | return xdf 133 | 134 | def fit_transform(self, xdf: xr.Dataset, target: str) -> xr.Dataset: 135 | """Fit to data, then transform it. 136 | 137 | :param xdf: The data used to perform fit and transform. 138 | :type xdf: xr.Dataset 139 | :param target: Column name used to scale along the Target axis 140 | :type target: str 141 | :return: Transformed Dataset. 142 | :rtype: xr.Dataset 143 | """ 144 | return self.fit(xdf, target).transform(xdf, target) 145 | 146 | def inverse_transform(self, xdf: xr.Dataset, target: str = None) -> xr.Dataset: 147 | """Scale back the data to the original representation. 148 | 149 | :param xdf: The data used to scale along the features axis. 150 | :type xdf: xr.Dataset 151 | :param target: Column name used to scale along the Target axis, defaults to None 152 | :type target: str, optional 153 | :return: Transformed Dataset. 
154 | :rtype: xr.Dataset 155 | """ 156 | 157 | def inverse_transform_data( 158 | xdf: xr.Dataset, 159 | columns: str, 160 | scaler: Union[ 161 | StandardScaler, 162 | RobustScaler, 163 | PowerTransformer, 164 | QuantileTransformer, 165 | MinMaxScaler, 166 | ], 167 | ) -> xr.Dataset: 168 | df = xdf[columns].to_dataframe() 169 | df.loc[:, columns] = scaler.inverse_transform(df[columns]) 170 | xdf_scale = df[columns].to_xarray() 171 | xdf = xr.merge([xdf_scale, xdf], compat="override") 172 | return xdf 173 | 174 | # Scale S data 175 | xdf = inverse_transform_data(xdf, self.columns_s, self.scaler_s) 176 | # Scale G data 177 | xdf = inverse_transform_data(xdf, self.columns_g, self.scaler_g) 178 | # Scale M data 179 | xdf = inverse_transform_data(xdf, self.columns_m, self.scaler_m) 180 | 181 | if target: 182 | # Scale M data 183 | xdf = inverse_transform_data(xdf, [target], self.scaler_t) 184 | 185 | return xdf 186 | -------------------------------------------------------------------------------- /src/models/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | from os.path import join 5 | 6 | import pandas as pd 7 | import torch 8 | import torch.nn as nn 9 | import wandb 10 | from sklearn.metrics import r2_score 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | 14 | parent = os.path.abspath('.') 15 | sys.path.insert(1, parent) 16 | from utils import ROOT_DIR 17 | 18 | 19 | def compute_r2_scores(observations: list, labels: list, preds: list) -> tuple[float, float]: 20 | """ Compute R^2 scores for a given set of observations and labels. 21 | 22 | :param observations: list of observations 23 | :type observations: list[int] 24 | :param labels: list of labels 25 | :type labels: list[float] 26 | :param preds: list of predictions 27 | :type preds: list[float] 28 | :return: R^2 scores (the full is the score for the all the rows, 29 | the mean is the aggregated score grouped by observations) 30 | :rtype: tuple[float, float] 31 | """ 32 | 33 | df = pd.DataFrame() 34 | df['observations'] = observations 35 | df['labels'] = labels 36 | df['preds'] = preds 37 | full_r2_score = r2_score(df.labels, df.preds) 38 | df = df.groupby(['observations']).mean() 39 | mean_r2_score = r2_score(df.labels, df.preds) 40 | 41 | return full_r2_score, mean_r2_score 42 | 43 | 44 | class Trainer: 45 | """ Define the Trainer class. 46 | 47 | :param model: our deep learning model 48 | :type model: nn.Module 49 | :param train_dataloader: training dataloader 50 | :type train_dataloader: DataLoader 51 | :param val_dataloader: validation dataloader 52 | :type val_dataloader: DataLoader 53 | :param epochs: max number of epochs 54 | :type epochs: int 55 | :param criterion: loss function 56 | :param optimizer: model optimizer 57 | :param scheduler: learning scheduler 58 | """ 59 | def __init__(self, model: nn.Module, train_dataloader: DataLoader, val_dataloader: DataLoader, 60 | epochs: int, criterion, optimizer, scheduler): 61 | self.model = model 62 | self.train_loader = train_dataloader 63 | self.val_loader = val_dataloader 64 | self.criterion = criterion 65 | self.epochs = epochs 66 | self.optimizer = optimizer 67 | self.scheduler = scheduler 68 | self.timestamp = int(datetime.now().timestamp()) 69 | self.val_best_r2_score = 0. 70 | 71 | def train_one_epoch(self) -> float: 72 | """ Train the model for one epoch. 73 | 74 | :return: the training loss 75 | :rtype: float 76 | """ 77 | train_loss = 0. 
78 | 79 | self.model.train() 80 | 81 | pbar = tqdm(self.train_loader, leave=False) 82 | for i, data in enumerate(pbar): 83 | keys_input = ['s_input', 'm_input', 'g_input'] 84 | inputs = {key: data[key] for key in keys_input} 85 | labels = data['target'] 86 | 87 | # Zero gradients for every batch 88 | self.optimizer.zero_grad() 89 | 90 | # Make predictions for this batch 91 | outputs = self.model(inputs) 92 | 93 | # Compute the loss and its gradients 94 | loss = self.criterion(outputs, labels) 95 | loss.backward() 96 | 97 | # Adjust learning weights 98 | self.optimizer.step() 99 | 100 | train_loss += loss.item() 101 | epoch_loss = train_loss / (i + 1) 102 | 103 | # Update the progress bar with new metrics values 104 | pbar.set_description(f'TRAIN - Batch: {i + 1}/{len(self.train_loader)} - ' 105 | f'Epoch Loss: {epoch_loss:.5f} - ' 106 | f'Batch Loss: {loss.item():.5f}') 107 | 108 | train_loss /= len(self.train_loader) 109 | 110 | return train_loss 111 | 112 | def val_one_epoch(self) -> tuple[float, float, float]: 113 | """ Validate the model for one epoch. 114 | 115 | :return: the validation loss, the R^2 score and the aggregated R^2 score 116 | :rtype: tuple[float, float, float] 117 | """ 118 | val_loss = 0. 119 | observations = [] 120 | val_labels = [] 121 | val_preds = [] 122 | 123 | self.model.eval() 124 | 125 | pbar = tqdm(self.val_loader, leave=False) 126 | for i, data in enumerate(pbar): 127 | keys_input = ['s_input', 'm_input', 'g_input'] 128 | inputs = {key: data[key] for key in keys_input} 129 | labels = data['target'] 130 | 131 | outputs = self.model(inputs) 132 | 133 | loss = self.criterion(outputs, labels) 134 | val_loss += loss.item() 135 | epoch_loss = val_loss / (i + 1) 136 | 137 | observations += data['observation'].squeeze().tolist() 138 | val_labels += labels.squeeze().tolist() 139 | val_preds += outputs.squeeze().tolist() 140 | 141 | # Update the progress bar with new metrics values 142 | pbar.set_description(f'VAL - Batch: {i + 1}/{len(self.val_loader)} - ' 143 | f'Epoch Loss: {epoch_loss:.5f} - ' 144 | f'Batch Loss: {loss.item():.5f}') 145 | 146 | val_loss /= len(self.val_loader) 147 | val_r2_score, val_mean_r2_score = compute_r2_scores(observations, val_labels, val_preds) 148 | 149 | return val_loss, val_r2_score, val_mean_r2_score 150 | 151 | def save(self, score: float): 152 | """ Save the model if it is better than the previously saved one. 153 | 154 | :param score: current model epoch score 155 | :type score: float 156 | """ 157 | save_folder = join(ROOT_DIR, 'models') 158 | 159 | if score > self.val_best_r2_score: 160 | self.val_best_r2_score = score 161 | os.makedirs(save_folder, exist_ok=True) 162 | 163 | # delete the former best model 164 | former_model = [f for f in os.listdir(save_folder) if f.split('_')[-1] == f'{self.timestamp}.pt'] 165 | if len(former_model) == 1: 166 | os.remove(join(save_folder, former_model[0])) 167 | 168 | # save the new model 169 | score = str(score)[:7].replace('.', '-') 170 | file_name = f'{score}_model_{self.timestamp}.pt' 171 | save_path = join(save_folder, file_name) 172 | torch.save(self.model, save_path) 173 | 174 | def train(self): 175 | """ Main function to train the model.
""" 176 | iter_epoch = tqdm(range(self.epochs), leave=False) 177 | 178 | for epoch in iter_epoch: 179 | iter_epoch.set_description(f'EPOCH {epoch + 1}/{self.epochs}') 180 | train_loss = self.train_one_epoch() 181 | 182 | val_loss, val_r2_score, val_mean_r2_score = self.val_one_epoch() 183 | self.scheduler.step(val_loss) 184 | self.save(val_mean_r2_score) 185 | 186 | # log the metrics to W&B 187 | wandb.log({ 188 | 'train_loss': train_loss, 189 | 'val_loss': val_loss, 190 | 'val_r2_score': val_r2_score, 191 | 'val_mean_r2_score': val_mean_r2_score, 192 | 'val_best_r2_score': self.val_best_r2_score 193 | }) 194 | 195 | # Write the finished epoch metrics values 196 | iter_epoch.write(f'EPOCH {epoch + 1}/{self.epochs}: ' 197 | f'Train = {train_loss:.5f} - ' 198 | f'Val = {val_loss:.5f} - ' 199 | f'Val R2 = {val_r2_score:.5f} - ' 200 | f'Val mean R2 = {val_mean_r2_score:.5f}') 201 | -------------------------------------------------------------------------------- /src/models/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Dict, List, Tuple 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import xarray as xr 9 | from scipy import stats 10 | from sklearn.model_selection import train_test_split 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | from src.constants import (FOLDER, G_COLUMNS, M_COLUMNS, S_COLUMNS, TARGET, 14 | TARGET_TEST) 15 | 16 | parent = os.path.abspath(".") 17 | sys.path.insert(1, parent) 18 | 19 | from os.path import join 20 | 21 | from utils import ROOT_DIR 22 | 23 | 24 | class CustomDataset(Dataset): 25 | def __init__( 26 | self, 27 | s_inputs: np.ndarray, 28 | g_inputs: np.ndarray, 29 | m_inputs: np.ndarray, 30 | obs_targets: np.ndarray, 31 | augment: int, 32 | device: str, 33 | ): 34 | """Dataset used for the dataloader. 35 | 36 | :param s_inputs: Satellite data. 37 | :type s_inputs: np.ndarray 38 | :param g_inputs: Raw data. 39 | :type g_inputs: np.ndarray 40 | :param m_inputs: Meteorological data. 41 | :type m_inputs: np.ndarray 42 | :param obs_targets: Yield data. 43 | :type obs_targets: np.ndarray 44 | :param augment: Number of data augmentation. 45 | :type augment: int 46 | :param device: Training device. 47 | :type device: str 48 | """ 49 | # Move data on the training device. 
50 | self.augment = augment 51 | self.device = device 52 | self.s_inputs = torch.tensor(s_inputs).to( 53 | device=self.device, dtype=torch.float32 54 | ) 55 | self.g_inputs = torch.tensor(g_inputs).to( 56 | device=self.device, dtype=torch.float32 57 | ) 58 | self.m_inputs = torch.tensor(m_inputs).to( 59 | device=self.device, dtype=torch.float32 60 | ) 61 | self.observations = torch.tensor(obs_targets[:, 0]).to( 62 | device=self.device, dtype=torch.float32 63 | ) 64 | self.targets = torch.tensor(obs_targets[:, 1]).to( 65 | device=self.device, dtype=torch.float32 66 | ) 67 | 68 | def __len__(self): 69 | return self.s_inputs.shape[0] 70 | 71 | def __getitem__(self, idx): 72 | # Return data for a particular index 73 | # The data depend only on the observation index 74 | # Only the satellite data depend on the augmentation index 75 | idx_obs = idx // self.augment 76 | item = { 77 | "observation": self.observations[[idx_obs]], 78 | "s_input": self.s_inputs[idx], 79 | "m_input": self.m_inputs[idx_obs], 80 | "g_input": self.g_inputs[idx_obs], 81 | "target": self.targets[[idx_obs]], 82 | } 83 | 84 | return item 85 | 86 | 87 | def create_train_val_idx(xds: xr.Dataset, val_rate: float) -> Tuple[List, List]: 88 | """Compute a stratified Train/Val split. 89 | 90 | :param xds: Dataset used for the split. 91 | :type xds: xr.Dataset 92 | :param val_rate: Percentage of data in the validation set. 93 | :type val_rate: float 94 | :return: list of train indexes & list of val indexes 95 | :rtype: tuple[list, list] 96 | """ 97 | yields = xds[TARGET].values 98 | yields_distribution = stats.norm(loc=yields.mean(), scale=yields.std()) 99 | bounds = yields_distribution.cdf([0, 1]) 100 | bins = np.linspace(*bounds, num=10) 101 | stratify = np.digitize(yields, bins) 102 | train_idx, val_idx = train_test_split( 103 | xds.ts_obs, test_size=val_rate, random_state=42, stratify=stratify 104 | ) 105 | 106 | return train_idx, val_idx 107 | 108 | 109 | def transform_data( 110 | xds: xr.Dataset, m_times: int = 120, test=False 111 | ) -> Dict[str, np.ndarray]: 112 | """Transform data from xr.Dataset to dict of np.ndarray 113 | sorted by observation and augmentation. 114 | 115 | :param xds: The Dataset to be transformed. 116 | :type xds: xr.Dataset 117 | :param m_times: Length of the time series for Weather data, defaults to 120. 118 | :type m_times: int, optional 119 | :param test: True if it is the Test dataset, defaults to False. 120 | :type test: bool, optional 121 | :return: Dictionary of all data used to construct the torch Dataset. 122 | :rtype: dict[str, np.ndarray] 123 | """ 124 | items = {} 125 | # Dataset sorting for compatibility with torch Dataset indexes 126 | xds = xds.sortby(["ts_obs", "ts_aug"]) 127 | 128 | # Create raw data 129 | g_arr = xds[G_COLUMNS].to_dataframe() 130 | items["g_inputs"] = g_arr.values 131 | 132 | # Create satellite data 133 | # Keep only useful values and convert into numpy array 134 | s_arr = xds[S_COLUMNS].to_dataframe()[S_COLUMNS] 135 | s_arr = s_arr.to_numpy() 136 | # Reshape axis to match index, date, features 137 | # TODO: set as variable the number of state_dev and features.
138 | s_arr = s_arr.reshape(s_arr.shape[0] // 24, 24, 8) 139 | items["s_inputs"] = s_arr 140 | 141 | # Create Meteorological data 142 | # time and District are the keys to link observations and meteorological data 143 | df_time = xds[["time", "District"]].to_dataframe() 144 | # Keep only useful data 145 | df_time.reset_index(inplace=True) 146 | df_time = df_time[["ts_obs", "state_dev", "time", "District"]] 147 | # Meteorological data only dependend of the observation 148 | df_time = df_time.groupby(["ts_obs", "state_dev", "District"]).first() 149 | # Take the min and max datetime of satellite data to create a daily time series of meteorological data 150 | df_time.reset_index("state_dev", inplace=True) 151 | # TODO: set as variable the number of state_dev. 152 | df_time = df_time[df_time["state_dev"].isin([0, 23])] 153 | df_time = df_time.pivot(columns="state_dev").droplevel(None, axis=1) 154 | df_time.reset_index("District", inplace=True) 155 | 156 | # For each observation take m_times daily date before the 157 | # harverest date and get data with the corresponding location 158 | list_weather = [] 159 | for _, series in df_time.iterrows(): 160 | all_dates = pd.date_range(series[0], series[23], freq="D") 161 | all_dates = all_dates[-m_times:] 162 | m_arr = ( 163 | xds.sel(datetime=all_dates, name=series["District"])[M_COLUMNS] 164 | .to_array() 165 | .values 166 | ) 167 | list_weather.append(m_arr.T) 168 | 169 | items["m_inputs"] = np.asarray(list_weather) 170 | 171 | # If test create the target array with 0 instead of np.nan 172 | if test: 173 | df = xds[TARGET_TEST].to_dataframe().reset_index() 174 | df[TARGET_TEST] = 0 175 | items["obs_targets"] = df.to_numpy() 176 | else: 177 | items["obs_targets"] = xds[TARGET].to_dataframe().reset_index().to_numpy() 178 | 179 | items["augment"] = xds["ts_aug"].values.shape[0] 180 | 181 | return items 182 | 183 | 184 | def get_dataloaders( 185 | batch_size: int, val_rate: float, device: str 186 | ) -> Tuple[DataLoader, DataLoader, DataLoader]: 187 | """Generate Train / Validation / Test Torch Dataloader. 188 | 189 | :param batch_size: Batch size of Dataloader. 190 | :type batch_size: int 191 | :param val_rate: Percentage of data on the Validation Dataset. 192 | :type val_rate: float 193 | :param device: Device where to put the data. 
194 | :type device: str 195 | :return: Train / Validation / Test Dataloader 196 | :rtype: tuple[DataLoader, DataLoader, DataLoader] 197 | """ 198 | # Read the dataset processed 199 | dataset_path = join(ROOT_DIR, "data", "processed", FOLDER, "train_enriched.nc") 200 | xdf_train = xr.open_dataset(dataset_path, engine="scipy") 201 | 202 | # Create a Train / Validation split 203 | train_idx, val_idx = create_train_val_idx(xdf_train, val_rate) 204 | train_array = xdf_train.sel(ts_obs=train_idx) 205 | # Prepare data for th Torch Dataset 206 | items = transform_data(train_array) 207 | train_dataset = CustomDataset(**items, device=device) 208 | # Create the Dataloader 209 | train_dataloader = DataLoader( 210 | train_dataset, batch_size=batch_size, drop_last=True, shuffle=True 211 | ) 212 | 213 | # ?: Make a function to create each dataloader 214 | val_array = xdf_train.sel(ts_obs=val_idx) 215 | items = transform_data(val_array) 216 | val_dataset = CustomDataset(**items, device=device) 217 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, drop_last=True) 218 | 219 | dataset_path = join(ROOT_DIR, "data", "processed", FOLDER, "test_enriched.nc") 220 | xdf_test = xr.open_dataset(dataset_path, engine="scipy") 221 | items = transform_data(xdf_test, test=True) 222 | test_dataset = CustomDataset(**items, device=device) 223 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True) 224 | 225 | return train_dataloader, val_dataloader, test_dataloader 226 | -------------------------------------------------------------------------------- /src/data/preprocessing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Union 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | import os 7 | import sys 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import xarray as xr 12 | from scipy.signal import savgol_filter 13 | from sklearn.preprocessing import ( 14 | MinMaxScaler, 15 | QuantileTransformer, 16 | RobustScaler, 17 | StandardScaler, 18 | ) 19 | 20 | parent = os.path.abspath(".") 21 | sys.path.insert(1, parent) 22 | 23 | from os.path import join 24 | 25 | from sklearn.base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin 26 | 27 | from src.constants import G_COLUMNS, M_COLUMNS, S_COLUMNS 28 | from utils import ROOT_DIR 29 | 30 | 31 | class Sorter(OneToOneFeatureMixin, TransformerMixin, BaseEstimator): 32 | """Sort dataset to align dataset samples with labels samples.""" 33 | 34 | def __init__(self) -> None: 35 | pass 36 | 37 | def fit(self, X=None, y=None) -> object: 38 | """Identity function 39 | 40 | :param X: Ignored 41 | :type X: None 42 | :param y: Ignored 43 | :type y: None 44 | :return: self 45 | :rtype: object 46 | """ 47 | return self 48 | 49 | def transform(self, X: pd.DataFrame) -> pd.DataFrame: 50 | """Reorder the indexes in an ascending way first by observation then by augmentation. 51 | 52 | :param X: Dataset that will be transformed. 53 | :type X: pd.DataFrame 54 | :return: Transformed Dataframe. 55 | :rtype: pd.DataFrame 56 | """ 57 | return X.reorder_levels(["ts_obs", "ts_aug"]).sort_index() 58 | 59 | 60 | # Convertor class used on ML exploration 61 | class Convertor(OneToOneFeatureMixin, TransformerMixin, BaseEstimator): 62 | """Used to transform the xarray.Dataset into pandas.DataFrame and reduce the dimention and/or tranform it. 
63 | 64 | :param agg: If True then replace features with their aggregations along the state_dev axis (agg = min, mean, max), defaults to None 65 | :type agg: bool, optional 66 | :param weather: If False then remove weather data from the Dataset, defaults to True 67 | :type weather: bool, optional 68 | :param vi: If False then remove vegetable indices from the Dataset, defaults to True 69 | :type vi: bool, optional 70 | """ 71 | 72 | def __init__(self, agg: bool = None, weather: bool = True, vi: bool = True) -> None: 73 | self.agg = agg 74 | self.weather = weather 75 | self.vi = vi 76 | 77 | def to_dataframe(self, X: xr.Dataset) -> pd.DataFrame: 78 | # Convert xarray.Dataset into usable pandas.DataFrame 79 | 80 | # Depending on whether aggregation was performed, pick the pivot column name 81 | col = "agg" if self.agg else "state_dev" 82 | # Convert xarray.Dataset into pandas.DataFrame 83 | df = X.to_dataframe() 84 | # set G_COLUMNS as index so they are not duplicated by the pivot operation 85 | df.set_index(G_COLUMNS, append=True, inplace=True) 86 | # reset the column used to apply the pivot and convert its values into strings 87 | df.reset_index(col, inplace=True) 88 | df[col] = df[col].astype(str) 89 | # Apply pivot to change state_dev or agg from samples to features 90 | df = df.pivot(columns=col) 91 | # Convert pandas.MultiIndex to a pandas.Index by merging names 92 | df.columns = df.columns.map("_".join).str.strip("_") 93 | # set G_COLUMNS as features 94 | df.reset_index(G_COLUMNS, inplace=True) 95 | # sort dataset for future compatibility 96 | df = df.reorder_levels(["ts_obs", "ts_aug"]).sort_index() 97 | return df 98 | 99 | def merge_dimensions(self, X: xr.Dataset) -> xr.Dataset: 100 | # Merge VI, Geographical and Meteorological data into the same dimension 101 | X = xr.merge( 102 | [ 103 | X[G_COLUMNS], 104 | X[M_COLUMNS].sel(datetime=X["time"], name=X["District"]), 105 | X[S_COLUMNS], 106 | ] 107 | ) 108 | # Drop useless columns 109 | X = X.drop(["name", "datetime", "time"]) 110 | return X 111 | 112 | def compute_agg(self, X: xr.Dataset) -> xr.Dataset: 113 | # Compute aggregation on the Dataset and set the new dimension values 114 | # with the name of each aggregation performed 115 | X = xr.concat( 116 | [X.mean(dim="state_dev"), X.max(dim="state_dev"), X.min(dim="state_dev")], 117 | dim="agg", 118 | ) 119 | X["agg"] = ["mean", "max", "min"] 120 | return X 121 | 122 | def fit(self, X=None, y=None) -> object: 123 | """Identity function. 124 | 125 | :param X: Ignored 126 | :type X: None 127 | :param y: Ignored 128 | :type y: None 129 | :return: Convertor. 130 | :rtype: object 131 | """ 132 | return self 133 | 134 | def transform(self, X: xr.Dataset) -> pd.DataFrame: 135 | """Transform the xarray.Dataset into a pandas.DataFrame, depending on the arguments of the class. 136 | 137 | :param X: Dataset that will be transformed. 138 | :type X: xr.Dataset 139 | :return: Transformed Dataset. 140 | :rtype: pd.DataFrame 141 | """ 142 | # Transform data so every variable depends on the same dimensions 143 | X = self.merge_dimensions(X) 144 | # If True, compute aggregations on the data 145 | if self.agg: 146 | X = self.compute_agg(X) 147 | # If False, remove weather data 148 | if not self.weather: 149 | X = X.drop(M_COLUMNS) 150 | # If False, remove vi data 151 | if not self.vi: 152 | X = X.drop(S_COLUMNS) 153 | # Convert the Dataset into a DataFrame 154 | X = self.to_dataframe(X) 155 | return X 156 | 157 | 158 | class Smoother(OneToOneFeatureMixin, TransformerMixin, BaseEstimator): 159 | """Smooth Vegetable Indice Data.
160 | 161 | :param mode: methode used to smooth vi data, None to not perform smoothing during , defaults to "savgol" 162 | :type mode: str, optional 163 | """ 164 | 165 | def __init__(self, mode: str = "savgol") -> None: 166 | self.mode = mode 167 | 168 | def smooth_savgol(self, ds: xr.Dataset) -> xr.Dataset: 169 | # apply savgol_filter to vegetable indice 170 | ds_s = xr.apply_ufunc( 171 | savgol_filter, 172 | ds[S_COLUMNS], 173 | kwargs={"axis": 2, "window_length": 12, "polyorder": 4, "mode": "mirror"}, 174 | ) 175 | # merge both dataset and override old vegetable indice and bands 176 | return xr.merge([ds_s, ds], compat="override") 177 | 178 | def fit(self, X: xr.Dataset = None, y=None) -> object: 179 | """Identity function. 180 | 181 | :param X: Ignored, defaults to None 182 | :type X: xr.Dataset, optional 183 | :param y: Ignored, defaults to None 184 | :type y: _type_, optional 185 | :return: Themself. 186 | :rtype: object 187 | """ 188 | return self 189 | 190 | def transform(self, X: xr.Dataset) -> xr.Dataset: 191 | """Smooth Vegetable Indice Data according to the mode used. 192 | 193 | :param X: Dataset that will be transformed. 194 | :type X: xr.Dataset 195 | :return: Transformed Dataset. 196 | :rtype: xr.Dataset 197 | """ 198 | # If mode not equal to savgol, transform correspond to identity function. 199 | if self.mode == "savgol": 200 | X = self.smooth_savgol(X) 201 | 202 | return X 203 | 204 | 205 | def replaceinf(arr: np.ndarray) -> np.ndarray: 206 | if np.issubdtype(arr.dtype, np.number): 207 | arr[np.isinf(arr)] = np.nan 208 | return arr 209 | 210 | 211 | class Filler(OneToOneFeatureMixin, TransformerMixin, BaseEstimator): 212 | """Fill dataset using the mean of each group of observation for a given date. 213 | For the reaining data use the mean of the dataset for a given developpment state. 214 | """ 215 | 216 | def __init__(self) -> None: 217 | self.values = None 218 | 219 | def fit(self, X: xr.Dataset, y=None) -> object: 220 | """Compute mean by developpement state to be used for later filling. 221 | 222 | :param X: The data used to compute mean by developpement state used for later filling. 223 | :type X: xr.Dataset 224 | :param y: Ignored 225 | :type y: None 226 | :return: self 227 | :rtype: object 228 | """ 229 | # replace infinite value by na 230 | xr.apply_ufunc(replaceinf, X[S_COLUMNS]) 231 | # compute mean of all stage of developpement for each cluster obsevation 232 | self.values = ( 233 | X[S_COLUMNS].mean(dim="ts_aug", skipna=True).mean(dim="ts_obs", skipna=True) 234 | ) 235 | 236 | return self 237 | 238 | def transform(self, X: xr.Dataset) -> xr.Dataset: 239 | """Performs the filling of missing values 240 | 241 | :param X: The dataset used to fill. 242 | :type X: xr.Dataset 243 | :return: Transformed Dataset. 244 | :rtype: xr.Dataset 245 | """ 246 | # replace infinite value by na 247 | xr.apply_ufunc(replaceinf, X[S_COLUMNS]) 248 | # compute mean of all stage of developpement and all obsevation 249 | X[S_COLUMNS] = X[S_COLUMNS].fillna(X[S_COLUMNS].mean(dim="ts_aug", skipna=True)) 250 | # fill na value with fited mean 251 | X[S_COLUMNS] = X[S_COLUMNS].fillna(self.values) 252 | 253 | return X 254 | -------------------------------------------------------------------------------- /docs/build/html/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Search | Crop forecasting 1.0.0 documentation 10 | 11 | 12 | 13 | 14 | 15 | 16 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 |
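Since Sorter, Convertor, Smoother and Filler in src/data/preprocessing.py follow the scikit-learn fit/transform API, they can be chained with a standard Pipeline for the ML exploration. Below is a minimal sketch under stated assumptions: the interim dataset path and the exact step order are illustrative, not the project's actual configuration.

import xarray as xr
from sklearn.pipeline import Pipeline

from src.data.preprocessing import Convertor, Filler, Smoother, Sorter

# Hypothetical interim dataset path; adjust to the folder actually produced
# by src/data/make_preprocessing.py.
xds = xr.open_dataset("data/interim/augment_100_5/train.nc", engine="scipy")

pipeline = Pipeline([
    ("fill", Filler()),                   # fill NaN/inf vegetation-index values
    ("smooth", Smoother(mode="savgol")),  # Savitzky-Golay smoothing of the VI time series
    ("convert", Convertor(agg=True)),     # aggregate over state_dev and flatten to a DataFrame
    ("sort", Sorter()),                   # align sample order with the labels
])

df_features = pipeline.fit_transform(xds)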
129 | 130 | -------------------------------------------------------------------------------- /src/data/make_data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing as mp 3 | import os 4 | import sys 5 | import time 6 | from datetime import datetime, timedelta 7 | from random import random, uniform 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import planetary_computer as pc 12 | import pystac_client 13 | import xarray as xr 14 | from odc.stac import stac_load 15 | from tqdm import tqdm 16 | 17 | parent = os.path.abspath(".") 18 | sys.path.insert(1, parent) 19 | 20 | from os.path import join 21 | 22 | from utils import ROOT_DIR 23 | 24 | # Make data constants 25 | SIZE = "adaptative" # 'fixed' 26 | FACTOR = 1 # for 'adaptative' 27 | NUM_AUGMENT = 40 28 | MAX_AUGMENT = 5 29 | DEGREE = 0.0014589825157734703 # = ha_to_degree(2.622685) # Field size (ha) mean = 2.622685 (train + test) 30 | 31 | 32 | def ha_to_degree(field_size: float) -> float: # Field_size (ha) 33 | """ Convert field size (ha) to degree. 34 | 35 | :param field_size: field size (ha) 36 | :type field_size: float 37 | :return: field width/length (degree) 38 | :rtype: float 39 | """ 40 | 41 | # 1° ~= 111km 42 | # 1ha = 0.01km2 43 | # then, side_size = sqrt(0.01 * field_size) (km) 44 | # so, degree = side_size / 111 (°) 45 | side_size = math.sqrt(0.01 * field_size) 46 | degree = side_size / 111 47 | return degree 48 | 49 | 50 | def create_folders() -> str: 51 | """ Create folders in function of the extraction type. 52 | 53 | :return: name of the folder created 54 | :rtype: str 55 | """ 56 | save_folder = None 57 | 58 | if NUM_AUGMENT > 1: 59 | save_folder = join( 60 | ROOT_DIR, 61 | "data", 62 | "external", 63 | "satellite", 64 | f"augment_{NUM_AUGMENT}_{MAX_AUGMENT}", 65 | ) 66 | elif SIZE == "fixed": 67 | degree = str(round(DEGREE, 5)).replace(".", "-") 68 | save_folder = join(ROOT_DIR, "data", "external", "satellite", f"fixed_{degree}") 69 | elif SIZE == "adaptative": 70 | save_folder = join( 71 | ROOT_DIR, "data", "external", "satellite", f"adaptative_factor_{FACTOR}" 72 | ) 73 | 74 | os.makedirs(save_folder, exist_ok=True) 75 | return save_folder 76 | 77 | 78 | def get_factors() -> list[float]: 79 | """ Randomly draw factors to create augmented windows 80 | to retrieve different satellite images. 81 | 82 | :return: four random factors (between 1/MAX_AUGMENT and MAX_AUGMENT) 83 | :rtype: list[float] 84 | """ 85 | factors = [] 86 | for _ in range(4): 87 | factor = uniform(1, MAX_AUGMENT) 88 | if random() < 0.5: 89 | factor = 1 / factor 90 | factors.append(factor) 91 | 92 | return factors 93 | 94 | 95 | def get_bbox(longitude: float, latitude: float, field_size: float) -> tuple[float, float, float, float]: 96 | """ Get the bounding box of the satellite image 97 | using augmented window factors. 
98 | 99 | :param longitude: longitude of the satellite image 100 | :type longitude: float 101 | :param latitude: latitude of the satellite image 102 | :type latitude: float 103 | :param field_size: field size (ha) 104 | :type field_size: float 105 | :return: max and min longitude, min and max latitude 106 | :rtype: tuple[float, float, float, float] 107 | """ 108 | 109 | if SIZE == "fixed": 110 | degree = DEGREE 111 | else: 112 | degree = ha_to_degree(field_size) * FACTOR 113 | 114 | length = degree / 2 115 | factors = get_factors() 116 | min_longitude = longitude - factors[0] * length 117 | min_latitude = latitude - factors[1] * length 118 | max_longitude = longitude + factors[2] * length 119 | max_latitude = latitude + factors[3] * length 120 | 121 | return min_longitude, min_latitude, max_longitude, max_latitude 122 | 123 | 124 | def get_time_period(harvest_date: str, history_days: int) -> str: 125 | """ Get the time period using the harvest date 126 | and the number history days defined. 127 | 128 | :param harvest_date: Date of the harvest 129 | :type harvest_date: str (%d-%m-%Y date format) 130 | :param history_days: Number of history days chosen 131 | :type history_days: int 132 | :return: string with the (calculated) sowing and harvest date (%d-%m-%Y date format) 133 | :rtype: str 134 | """ 135 | harvest_datetime = datetime.strptime(harvest_date, "%d-%m-%Y") 136 | sowing_datetime = harvest_datetime - timedelta(days=history_days) 137 | return f'{sowing_datetime.strftime("%Y-%m-%d")}/{harvest_datetime.strftime("%Y-%m-%d")}' 138 | 139 | 140 | def get_data(bbox: tuple[float, float, float, float], time_period: str, bands: list[str], scale: float) -> xr.Dataset: 141 | """ Get satellite data. 142 | 143 | :param bbox: Bounding box of the satellite image. 144 | :type bbox: tuple[float, float, float, float] 145 | :param time_period: Time period of the satellite image. 146 | :type time_period: str 147 | :param bands: List of bands to retrieve. 148 | :type bands: list[str] 149 | :param scale: Resolution of the satellite image, defaults to 10. 150 | :type scale: float 151 | :return: Dataset processed of an observation. 152 | :rtype: xr.Dataset 153 | """ 154 | catalog = pystac_client.Client.open( 155 | "https://planetarycomputer.microsoft.com/api/stac/v1", modifier=pc.sign_inplace 156 | ) 157 | search = catalog.search( 158 | collections=["sentinel-2-l2a"], bbox=bbox, datetime=time_period 159 | ) 160 | items = search.item_collection() 161 | data = stac_load(items, bands=bands, crs="EPSG:4326", resolution=scale, bbox=bbox) 162 | return data 163 | 164 | 165 | def save_data(row: pd.Series, history_days: int, history_dates: int, resolution: int) -> xr.Dataset: 166 | """ Get Satellite Dataset and process it to be used. 167 | 168 | :param row: Series representing an observation. 169 | :type row: pd.Series 170 | :param history_days: Number of day to take satellite data before the harvest. 171 | :type history_days: int 172 | :param history_dates: Number of satellite data to take before the harvest 173 | :type history_dates: int 174 | :param resolution: Resolution of the satellite image, defaults to 10. 175 | :type resolution: int 176 | :return: Dataset processed of an observation. 
177 | :rtype: xr.Dataset 178 | """ 179 | scale = resolution / 111320.0 180 | bands = ["red", "green", "blue", "B05", "B06", "B07", "nir", "B11", "SCL"] 181 | 182 | longitude = row["Longitude"] 183 | latitude = row["Latitude"] 184 | field_size = float(row["Field size (ha)"]) 185 | bbox = get_bbox(longitude, latitude, field_size) 186 | 187 | # Get the time periode to retrieve statellite data 188 | harvest_date = row["Date of Harvest"] 189 | time_period = get_time_period(harvest_date, history_days) 190 | 191 | # Get the satellite data 192 | xds = get_data(bbox, time_period, bands, scale) 193 | 194 | # Cloud mask on SCL value to only keep clear data 195 | cloud_mask = ( 196 | (xds.SCL != 0) 197 | & (xds.SCL != 1) 198 | & (xds.SCL != 3) 199 | & (xds.SCL != 6) 200 | & (xds.SCL != 8) 201 | & (xds.SCL != 9) 202 | & (xds.SCL != 10) 203 | ) 204 | xds = xds.where(cloud_mask) 205 | 206 | # Keep only useful data 207 | xds = xds.drop(["spatial_ref", "SCL"]) 208 | 209 | # Compute the mean of each image by localisation 210 | xds = xds.mean(dim=["latitude", "longitude"], skipna=True) 211 | 212 | # Sort data by time 213 | xds = xds.sortby("time", ascending=False) 214 | 215 | # Keep only the history_dates oldest data 216 | xds = xds.isel(time=slice(None, history_dates)) 217 | 218 | # Format data 219 | xds["time"] = xds["time"].dt.strftime("%Y-%m-%d") 220 | 221 | # Create a Variable named state_dev which reprensent 222 | # the number of development state keep and set it as dimension 223 | xds["state_dev"] = ("time", np.arange(history_dates)[::-1]) 224 | xds = xds.swap_dims({"time": "state_dev"}) 225 | 226 | # Rename bands api name by more readable name 227 | # Dictionnary for matching api bands name and natural bands name 228 | dict_band_name = { 229 | "B05": "rededge1", 230 | "B06": "rededge2", 231 | "B07": "rededge3", 232 | "B11": "swir", 233 | } 234 | xds = xds.rename_vars(dict_band_name) 235 | 236 | return xds 237 | 238 | 239 | def save_data_app(index_row: tuple[str, pd.Series], history_days: int = 130, history_dates: int = 24, 240 | resolution: int = 10,) -> xr.Dataset: 241 | """ Get Satellite Datasets from an observation and concat them to one. 242 | 243 | :param index_row: Tuple of index string and a Series representing an observation. 244 | :type index_row: tuple[str, pd.Series] 245 | :param history_days: Number of day to take satellite data before the harvest, defaults to 130. 246 | :type history_days: int, optional 247 | :param history_dates: Number of satellite data to take before the harvest, defaults to 24. 248 | :type history_dates: int, optional 249 | :param resolution: Resolution of the satellite image, defaults to 10. 250 | :type resolution: int, optional 251 | :return: Concatenate Dataset of one observation. 252 | :rtype: xr.Dataset 253 | """ 254 | list_xds = [] 255 | # For the number of data augmentation desired get satellite data 256 | # and append it into a list 257 | for i in range(NUM_AUGMENT): 258 | xds = save_data(index_row[1], history_days, history_dates, resolution) 259 | xds = xds.expand_dims({"ts_aug": [i]}) 260 | list_xds.append(xds) 261 | 262 | # Concat list of Dataset into a single one representing one observation * NUM_AUGMENT 263 | xds: xr.Dataset = xr.concat(list_xds, dim="ts_aug") 264 | 265 | # Create a new dimenstion called ts_obs representing the index of the observation. 
266 | xds = xds.expand_dims({"ts_obs": [index_row[0]]}) 267 | 268 | return xds 269 | 270 | 271 | def init_df(df: pd.DataFrame, path: str) -> tuple[pd.DataFrame, list]: 272 | """ Check for missing observations on the dataset and make 273 | and filter the dataframe to only keep the missing ones. 274 | 275 | :param df: DataFrame of all observations. 276 | :type df: pd.DataFrame 277 | :param path: Dataset path of already retrieves data. 278 | :type path: str 279 | :return: Dataframe filtered and list with one Dataset if some data was already retrieves. 280 | :rtype: tuple[pd.DataFrame, list] 281 | """ 282 | list_data = [] 283 | df.index.name = "ts_obs" 284 | 285 | if os.path.exists(path=path): 286 | xdf = xr.open_dataset(path, engine="scipy") 287 | unique = np.unique(xdf["ts_obs"].values) 288 | list_data.append(xdf) 289 | 290 | df = df.loc[~df.index.isin(unique)] 291 | 292 | return df, list_data 293 | 294 | 295 | class Checkpoint(Exception): 296 | """ Exception class to save data during the retrieval. """ 297 | 298 | def __init__(self): 299 | pass 300 | 301 | 302 | def make_data(path: str, save_file: str) -> bool: 303 | """ From a given csv at EY data format get satellite data 304 | corresponding to the localisation and date of each observation 305 | from microsoft api and save it into external directory. 306 | Implement an auto restart from the last observation saved. 307 | Save data as nc format using scipy engine. 308 | 309 | :param path: CSV path of EY data. 310 | :type path: str 311 | :param save_file: Directory to save the Dataset. 312 | :type save_file: str 313 | :raises Checkpoint: Auto save every hour. 314 | :return: True if a checkpoint is reached, False otherwise. 315 | :rtype: bool 316 | """ 317 | start = time.time() 318 | checkpoint = False 319 | 320 | df: pd.DataFrame = pd.read_csv(path) 321 | 322 | df, list_data = init_df(df, save_file) 323 | 324 | print(f'\nRetrieve SAR data from {path.split("/")[-1]}...') 325 | try: 326 | with mp.Pool(4) as pool: 327 | # Multiprocessing data retrieval. 328 | # Create a list of dataset with all dataset represent one obervation 329 | # multiply by the number of augmentation desired. 330 | for xds in tqdm(pool.imap(save_data_app, df.iterrows()), total=len(df)): 331 | list_data.append(xds) 332 | if time.time() - start > 3600: 333 | # Each houre raise a Checkpoint Exception to stop the process and save the data 334 | raise Checkpoint("Checkpoint.") 335 | except Checkpoint as c: 336 | # If the error is a checkpoint set the boolean to True 337 | checkpoint = True 338 | finally: 339 | # Concat all Dataset to one and save it. 340 | data = xr.concat(list_data, dim="ts_obs") 341 | print(f'\nSave SAR data from {path.split("/")[-1]}...') 342 | data.to_netcdf(save_file, engine="scipy") 343 | print(f'\nSAR data from {path.split("/")[-1]} saved!') 344 | return checkpoint 345 | 346 | 347 | if __name__ == "__main__": 348 | save_folder = create_folders() 349 | 350 | checkpoint = True 351 | while checkpoint: 352 | # While make data finish because of a checkpoint exception 353 | # Restarts satellite data retrieval. 354 | train_path = join(ROOT_DIR, "data", "raw", "train.csv") 355 | train_file = join(save_folder, "train.nc") 356 | checkpoint = make_data(train_path, train_file) 357 | 358 | checkpoint = True 359 | while checkpoint: 360 | # Same for Test data. 
361 | test_path = join(ROOT_DIR, "data", "raw", "test.csv") 362 | test_file = join(save_folder, "test.nc") 363 | checkpoint = make_data(test_path, test_file) 364 | -------------------------------------------------------------------------------- /docs/build/html/documentation.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Documentation | Crop forecasting 1.0.0 documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 31 | 32 | 33 | 34 | 35 | 36 | 37 |
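For reference, the DEGREE constant near the top of src/data/make_data.py is simply ha_to_degree applied to the mean field size (2.622685 ha) quoted in its comment, using the approximations 1 degree ~ 111 km and 1 ha = 0.01 km^2. A quick check of that arithmetic, as a standalone sketch:

import math

def ha_to_degree(field_size: float) -> float:
    # side of a square field in km, converted to degrees (1 degree ~ 111 km)
    side_size = math.sqrt(0.01 * field_size)
    return side_size / 111

print(ha_to_degree(2.622685))  # ~0.0014589825, i.e. the DEGREE constant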
153 | 154 | -------------------------------------------------------------------------------- /src/data/make_preprocessing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import joblib 4 | 5 | warnings.filterwarnings("ignore") 6 | 7 | import os 8 | import sys 9 | import glob 10 | from os.path import join 11 | from tqdm import tqdm 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import xarray as xr 16 | 17 | parent = os.path.abspath(".") 18 | sys.path.insert(1, parent) 19 | 20 | from src.data.datascaler import DatasetScaler 21 | from src.data.preprocessing import Smoother 22 | from sklearn.preprocessing import MinMaxScaler, StandardScaler 23 | 24 | 25 | from src.constants import (FOLDER, G_COLUMNS, M_COLUMNS, S_COLUMNS, TARGET, 26 | TARGET_TEST) 27 | from utils import ROOT_DIR 28 | 29 | 30 | def merge_satellite(file: str)->xr.Dataset: 31 | """Merge Augmented 10 / 40 / 50 to another one. 32 | 33 | :param file: File name of all datasets. 34 | :type file: str 35 | :return: Merged dataset. 36 | :rtype: xr.Dataset 37 | """ 38 | # Open dataset 39 | def open_dataset(folder): 40 | return xr.open_dataset( 41 | join(ROOT_DIR, "data", "external", "satellite", folder, file), 42 | engine="scipy", 43 | ) 44 | 45 | folder = "augment_50_5" 46 | xds_50 = open_dataset(folder) 47 | 48 | folder = "augment_40_5" 49 | xds_40 = open_dataset(folder) 50 | # Change number of ts_aug to not be overwrite during the merge 51 | xds_40["ts_aug"] = np.arange(50, 90) 52 | 53 | folder = "augment_10_5" 54 | xds_10 = open_dataset(folder) 55 | # Same 56 | xds_10["ts_aug"] = np.arange(90, 100) 57 | 58 | xds_100 = xr.merge([xds_50, xds_40, xds_10], compat="no_conflicts") 59 | 60 | return xds_100 61 | 62 | 63 | def add_observation(xds: xr.Dataset, test: bool) -> xr.Dataset: 64 | """Process and Merge EY data to Satellite Dataset. 65 | 66 | :param xds: Satellite Dataset that will be merged 67 | :type xds: xr.Dataset 68 | :param test: True if it is the test Dataset. 69 | :type test: bool 70 | :return: Merged Dataset. 71 | :rtype: xr.Dataset 72 | """ 73 | def categorical_encoding(xds: xr.Dataset) -> xr.Dataset: 74 | # Encode Rice Crop Intensity feature D = 2 and T = 3 75 | xds["Rice Crop Intensity(D=Double, T=Triple)"] = ( 76 | xds["Rice Crop Intensity(D=Double, T=Triple)"] 77 | .str.replace("D", "2") 78 | .str.replace("T", "3") 79 | .astype(np.int8) 80 | ) 81 | return xds 82 | 83 | file_name = "train_enriched.csv" 84 | if test: 85 | file_name = "test_enriched.csv" 86 | 87 | path = join(ROOT_DIR, "data", "interim", file_name) 88 | # Read csv EY data 89 | df = pd.read_csv(path) 90 | # Set index name as ts_obs for linked both Dataset 91 | df.index.name = "ts_obs" 92 | # Convert pandas.DataFrame into xarray.Dataset and merge on ts_obs 93 | xds = xr.merge([xds, df.to_xarray()], compat='override') 94 | # Encode categoricals data 95 | xds = categorical_encoding(xds) 96 | 97 | return xds 98 | 99 | 100 | def add_weather(xds: xr.Dataset) -> xr.Dataset: 101 | """Add meteorological data to the Dataset. 102 | 103 | :param xds: Dataset that will be merged. 104 | :type xds: xr.Dataset 105 | :return: Merged Dataset. 
106 | :rtype: xr.Dataset 107 | """ 108 | 109 | def features_modification(xds: xr.Dataset) -> xr.Dataset: 110 | # Crreate new features named solarexposure 111 | # It is the difference between sunset and sunrise 112 | xds["sunrise"] = xds["sunrise"].astype(np.datetime64) 113 | xds["sunset"] = xds["sunset"].astype(np.datetime64) 114 | 115 | xds["solarexposure"] = (xds["sunset"] - xds["sunrise"]).dt.seconds 116 | return xds 117 | 118 | # Read all weather csv and create a pandas.DataFrame of its 119 | weather = [] 120 | for path in glob.glob(join(ROOT_DIR, "data", "external", "weather", "*.csv")): 121 | weather.append(pd.read_csv(path)) 122 | df_weather = pd.concat(weather, axis="index") 123 | 124 | # Convert timestamp into datetime for future purpose 125 | df_weather["datetime"] = pd.to_datetime(df_weather["datetime"]) 126 | # Format name to match District features 127 | df_weather["name"] = df_weather["name"].str.replace(" ", "_") 128 | # Set as index datetime and name to became dimensions with the 129 | # xarray.Dataset conversion 130 | df_weather.set_index(["datetime", "name"], inplace=True) 131 | xds_weather = df_weather.to_xarray().set_coords(["datetime", "name"]) 132 | xds_weather["datetime"] = xds_weather["datetime"].dt.strftime("%Y-%m-%d") 133 | # Feature engineering on weather data 134 | xds_weather = features_modification(xds_weather) 135 | # Merge both Dataset 136 | xds = xr.merge([xds, xds_weather]) 137 | 138 | return xds 139 | 140 | 141 | def compute_vi(xds: xr.Dataset) -> xr.Dataset: 142 | """Compute vegetable indices. That include NDVI, SAVI, EVI, REP, OSAVI, RDVI, MTVI1, LSWI. 143 | 144 | :param xds: Dataset that include satellite band data, used to compute vegetable indice. 145 | :type xds: xr.Dataset 146 | :return: Merged Dataset. 147 | :rtype: xr.Dataset 148 | """ 149 | # Compute vegetable indices 150 | 151 | def compute_ndvi(xds: xr.Dataset) -> xr.Dataset: 152 | # Compute ndvi indice 153 | return (xds.nir - xds.red) / (xds.nir + xds.red) 154 | 155 | def compute_savi(xds, L=0.5) -> xr.Dataset: 156 | # Compute savi indice 157 | return 1 + L * (xds.nir - xds.red) / (xds.nir + xds.red + L) 158 | 159 | def compute_evi(xds, G=2.5, L=1, C1=6, C2=7.5) -> xr.Dataset: 160 | # Compute evi indice 161 | return G * (xds.nir - xds.red) / (xds.nir + C1 * xds.red - C2 * xds.blue + L) 162 | 163 | def compute_rep(xds: xr.Dataset) -> xr.Dataset: 164 | # Compute rep indice 165 | rededge = (xds.red + xds.rededge3) / 2 166 | return 704 + 35 * (rededge - xds.rededge1) / (xds.rededge2 - xds.rededge1) 167 | 168 | def compute_osavi(xds: xr.Dataset) -> xr.Dataset: 169 | # Compute osavi indice 170 | return (xds.nir - xds.red) / (xds.nir + xds.red + 0.16) 171 | 172 | def compute_rdvi(xds: xr.Dataset) -> xr.Dataset: 173 | # Compute rdvi indice 174 | return (xds.nir - xds.red) / np.sqrt(xds.nir + xds.red) 175 | 176 | def compute_mtvi1(xds: xr.Dataset) -> xr.Dataset: 177 | # Compute mtvi1 indice 178 | return 1.2 * (1.2 * (xds.nir - xds.green) - 2.5 * (xds.red - xds.green)) 179 | 180 | def compute_lswi(xds: xr.Dataset) -> xr.Dataset: 181 | # Compute lswi indice 182 | return (xds.nir - xds.swir) / (xds.nir + xds.swir) 183 | 184 | xds["ndvi"] = compute_ndvi(xds) 185 | xds["savi"] = compute_savi(xds) 186 | xds["evi"] = compute_evi(xds) 187 | xds["rep"] = compute_rep(xds) 188 | xds["osavi"] = compute_osavi(xds) 189 | xds["rdvi"] = compute_rdvi(xds) 190 | xds["mtvi1"] = compute_mtvi1(xds) 191 | xds["lswi"] = compute_lswi(xds) 192 | 193 | return xds 194 | 195 | 196 | def statedev_fill(xds: xr.Dataset) -> 
196 | def statedev_fill(xds: xr.Dataset) -> xr.Dataset:
197 |     # Fill missing vegetation index values and replace abnormal values
198 | 
199 |     def replaceinf(arr: np.ndarray) -> np.ndarray:
200 |         if np.issubdtype(arr.dtype, np.number):
201 |             arr[np.isinf(arr)] = np.nan
202 |         return arr
203 | 
204 |     # replace ± infinite values by NaN
205 |     xr.apply_ufunc(replaceinf, xds[S_COLUMNS])
206 |     # compute the mean over the augmentations, per stage of development and observation
207 |     xds_mean = xds[S_COLUMNS].mean(dim="ts_aug", skipna=True)
208 |     # fill NaN values with the computed mean
209 |     xds[S_COLUMNS] = xds[S_COLUMNS].fillna(xds_mean)
210 |     # compute the mean over all observations to fill the last remaining NaN values
211 |     xds_mean = xds_mean.mean(dim="ts_obs", skipna=True)
212 |     # fill NaN values with the computed mean
213 |     xds[S_COLUMNS] = xds[S_COLUMNS].fillna(xds_mean)
214 | 
215 |     return xds
216 | 
217 | 
218 | def features_modification(xds: xr.Dataset, test: bool) -> xr.Dataset:
219 |     """Reduce the Dataset dimensions to only keep the features useful for training.
220 |     Transform the features for training.
221 | 
222 |     :param xds: The Dataset used to perform the dimension reduction and whose timestamps are converted into numpy.datetime64.
223 |     :type xds: xr.Dataset
224 |     :param test: If True then the target name is 'Predicted Rice Yield (kg/ha)', else it is 'Rice Yield (kg/ha)'.
225 |     :type test: bool
226 |     :return: Transformed Dataset.
227 |     :rtype: xr.Dataset
228 |     """
229 |     xds["time"] = xds["time"].astype(np.datetime64)
230 |     xds["datetime"] = xds["datetime"].astype(np.datetime64)
231 |     xds = xds.reset_coords("time")
232 | 
233 |     # time and District are keys to link with the weather data
234 |     columns = S_COLUMNS + G_COLUMNS + M_COLUMNS + ["time", "District"]
235 |     if test:
236 |         columns.append(TARGET_TEST)
237 |     else:
238 |         columns.append(TARGET)
239 |     xds = xds[columns]
240 | 
241 |     return xds
242 | 
243 | 
244 | def scale_data(xds: xr.Dataset, dir: str, test: bool) -> xr.Dataset:
245 |     """Scale all features of the Dataset and save the scaler.
246 | 
247 |     :param xds: The Dataset used to perform the scaling.
248 |     :type xds: xr.Dataset
249 |     :param dir: Directory where the scaler is saved.
250 |     :type dir: str
251 |     :param test: If True then perform a transform, else perform a fit_transform.
252 |     :type test: bool
253 |     :return: Transformed Dataset.
254 |     :rtype: xr.Dataset
255 |     """
256 |     # Path for saving the scaler
257 |     path = join(dir, "scaler_dataset.joblib")
258 |     # Fit and transform on train data, transform only on test data
259 |     if not test:
260 |         # Initialise the scaler and all sub-scalers
261 |         scaler = DatasetScaler(
262 |             scaler_s=StandardScaler(),
263 |             columns_s=S_COLUMNS,
264 |             scaler_g=StandardScaler(),
265 |             columns_g=G_COLUMNS,
266 |             scaler_m=StandardScaler(),
267 |             columns_m=M_COLUMNS,
268 |             scaler_t=MinMaxScaler(),
269 |         )
270 |         # Fit the scaler and transform the data
271 |         xds = scaler.fit_transform(xds, TARGET)
272 |         # Save the scaler
273 |         joblib.dump(scaler, path)
274 |     else:
275 |         # Load the scaler and transform the data
276 |         scaler: DatasetScaler = joblib.load(path)
277 |         xds = scaler.transform(xds)
278 | 
279 |     return xds
280 | 
281 | 
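The train/test asymmetry in scale_data (fit_transform then joblib.dump on train, joblib.load then transform on test) is the usual scaler-persistence pattern. A minimal sketch with a plain sklearn StandardScaler standing in for the project's DatasetScaler (file name and arrays are made up; not part of make_preprocessing.py):

    import joblib
    import numpy as np
    from sklearn.preprocessing import StandardScaler

    train = np.array([[1.0], [2.0], [3.0]])
    test = np.array([[2.5]])

    scaler = StandardScaler()
    train_scaled = scaler.fit_transform(train)    # statistics come from train only
    joblib.dump(scaler, "scaler_example.joblib")  # persist, like scale_data does

    reloaded = joblib.load("scaler_example.joblib")
    test_scaled = reloaded.transform(test)        # test reuses the train statistics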
282 | def create_id(xds: xr.Dataset) -> xr.Dataset:
283 |     """Add the coordinate ts_id to be used as index in the PyTorch Dataset.
284 | 
285 |     :param xds: Dataset used to add IDs.
286 |     :type xds: xr.Dataset
287 |     :return: Transformed Dataset.
288 |     :rtype: xr.Dataset
289 |     """
290 |     # Create a np.ndarray of unique integers of length number of observations * number of augmentations
291 |     ts_id = np.arange(xds.dims["ts_obs"] * xds.dims["ts_aug"])
292 |     # Reshape it and assign it as a coordinate of the Dataset
293 |     ts_id = ts_id.reshape((xds.dims["ts_obs"], xds.dims["ts_aug"]))
294 |     xds = xds.assign_coords({"ts_id": (("ts_obs", "ts_aug"), ts_id)})
295 |     return xds
296 | 
297 | 
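For intuition, the ts_id coordinate built in create_id is just a row-major numbering of every (ts_obs, ts_aug) pair. With toy dimension sizes (not the project's real ones; not part of make_preprocessing.py):

    import numpy as np

    n_obs, n_aug = 3, 2
    ts_id = np.arange(n_obs * n_aug).reshape((n_obs, n_aug))
    # array([[0, 1],
    #        [2, 3],
    #        [4, 5]])  -> one unique integer per (ts_obs, ts_aug) time series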
298 | def create_pb(nb_action: int, test: bool) -> tuple:
299 |     """Initialise a tqdm progress bar for preprocessing verbosity.
300 | 
301 |     :param nb_action: Number of preprocessing steps.
302 |     :type nb_action: int
303 |     :param test: True if it is the test preprocessing.
304 |     :type test: bool
305 |     :return: Progress bar and beginning of the message for the progress bar.
306 |     :rtype: tuple
307 |     """
308 |     progress_bar = tqdm(range(nb_action), leave=False)
309 |     if test:
310 |         msg = "Test Dataset - "
311 |     else:
312 |         msg = "Train Dataset - "
313 |     return progress_bar, msg
314 | 
315 | 
316 | def process_data(folder: str, test: bool = False) -> None:
317 |     """Prepare data for Deep Learning and Machine Learning purposes and save it in the processed directory.
318 | 
319 |     :param folder: Directory from which to load the Satellite Dataset.
320 |     :type folder: str
321 |     :param test: True if it is the test preprocessing, defaults to False
322 |     :type test: bool, optional
323 |     """
324 |     # Create the progress bar
325 |     pb, msg = create_pb(9, test)
326 | 
327 |     # Determine the name of the processed / original dataset file
328 |     file_name = "train.nc"
329 |     if test:
330 |         file_name = "test.nc"
331 | 
332 |     # Create all useful directories
333 |     processed_dir = join(ROOT_DIR, "data", "processed", folder)
334 |     os.makedirs(processed_dir, exist_ok=True)
335 |     processed_path = join(processed_dir, file_name)
336 | 
337 |     interim_dir = join(ROOT_DIR, "data", "interim", folder)
338 |     os.makedirs(interim_dir, exist_ok=True)
339 |     interim_path = join(interim_dir, file_name)
340 | 
341 |     # Load the Satellite Dataset
342 |     pb.set_description(msg + "Read Data")
343 |     if folder == "augment_100_5":
344 |         xds = merge_satellite(file_name)
345 |     else:
346 |         path_sat = join(ROOT_DIR, "data", "external", "satellite", folder, file_name)
347 |         xds = xr.open_dataset(path_sat, engine="scipy")
348 | 
349 |     # Concatenate and create all features (tqdm.update() is incremental, so each step advances the bar by 1)
350 |     pb.update(1)
351 |     pb.refresh()
352 |     pb.set_description(msg + "Add Paddies Data")
353 |     # Process and merge the EY data into the Satellite Dataset
354 |     xds = add_observation(xds, test)
355 | 
356 |     pb.update(1)
357 |     pb.refresh()
358 |     pb.set_description(msg + "Add Meteorological Data")
359 |     # Process and merge the Weather data into the Satellite & EY Dataset
360 |     xds = add_weather(xds)
361 | 
362 |     pb.update(1)
363 |     pb.refresh()
364 |     pb.set_description(msg + "Compute Vegetation Indices")
365 |     # Compute vegetation indices
366 |     xds = compute_vi(xds)
367 | 
368 |     # Save for ML
369 |     xds.to_netcdf(interim_path, engine="scipy")
370 | 
371 |     pb.update(1)
372 |     pb.refresh()
373 |     pb.set_description(msg + "Fill NaN values")
374 |     # Fill missing vegetation index values and replace abnormal values
375 |     xds = statedev_fill(xds)
376 | 
377 |     # Smooth variables
378 |     pb.update(1)
379 |     pb.refresh()
380 |     pb.set_description(msg + "Smooth VI")
381 |     xds = Smoother(mode='savgol').transform(xds)
382 | 
383 |     # Create new features
384 |     pb.update(1)
385 |     pb.refresh()
386 |     pb.set_description(msg + "Modification of Features")
387 |     xds = features_modification(xds, test)
388 | 
389 |     # Scale data
390 |     pb.update(1)
391 |     pb.refresh()
392 |     pb.set_description(msg + "Data Scaling")
393 |     xds = scale_data(xds, processed_dir, test)
394 | 
395 |     # Add an id for each (ts_obs, ts_aug) time series
396 |     pb.update(1)
397 |     pb.refresh()
398 |     pb.set_description(msg + "Create an Index 1D")
399 |     xds = create_id(xds)
400 | 
401 |     # Save data for DL
402 |     pb.update(1)
403 |     pb.refresh()
404 |     pb.set_description(msg + "Saving Data")
405 |     xds.to_netcdf(processed_path, engine="scipy")
406 | 
407 | 
408 | if __name__ == "__main__":
409 |     process_data(FOLDER)
410 |     process_data(FOLDER, test=True)
411 | 
--------------------------------------------------------------------------------
/docs/build/html/make_train.html:
--------------------------------------------------------------------------------
(Rendered Sphinx page; navigation and layout markup stripped, only the documented content is kept.)

Train - Deep Learning | Crop forecasting 1.0.0 documentation

Train - Deep Learning
(banner image: _images/banner.jpg)

init_wandb

src.models.make_train.init_wandb() -> tuple[dict, torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader]

    Init the W&B logger, get the model config from the W&B sweep config yaml file,
    and get the training and validation dataloaders.

    Returns:
        the model config and the training and validation dataloaders

    Return type:
        (dict, DataLoader, DataLoader)
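For orientation only, a hedged sketch of how a function with this signature is commonly wired up; the real implementation lives in src/models/make_train.py and may differ. The project name, config key, and placeholder datasets below are assumptions made for illustration:

    import torch
    import wandb
    from torch.utils.data import DataLoader, TensorDataset

    def init_wandb_sketch() -> tuple:
        # A W&B sweep agent fills run.config with the sampled hyperparameters.
        run = wandb.init(project="crop-forecasting")  # project name: assumption
        config = dict(run.config)
        batch_size = config.get("batch_size", 4)      # config key: assumption

        # Placeholder tensors; the real code builds datasets from the processed netCDF files.
        train_ds = TensorDataset(torch.randn(16, 8), torch.randn(16, 1))
        val_ds = TensorDataset(torch.randn(4, 8), torch.randn(4, 1))

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=batch_size)
        return config, train_loader, val_loader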
--------------------------------------------------------------------------------