├── tests ├── __init__.py ├── test_cleaner.py └── test_rnn.py ├── data ├── preprocessing │ ├── __init__.py │ ├── gleam_cleaner.py │ ├── chirps_cleaner.py │ ├── collect_regrid_data_script.py │ ├── spei_cleaner.py │ ├── cleaner.py │ └── utils.py ├── README.md ├── raw │ └── README.md ├── processed │ └── README.md ├── nc_to_pandas.py ├── ee_extract_landcover_by_country.js ├── globmap_lookup.py ├── utils.py ├── drought_masking.py └── common_grid.py ├── predictor ├── models │ ├── neural_networks │ │ ├── __init__.py │ │ ├── feedforward.py │ │ ├── nn_base.py │ │ └── recurrent.py │ ├── __init__.py │ ├── baseline.py │ └── base.py ├── __init__.py ├── analysis │ ├── __init__.py │ ├── utils.py │ ├── plot_shap.py │ └── plot_results.py ├── engineer.py └── preprocessing.py ├── ndvi_results.png ├── figs ├── variables_histogram.png └── ndvi_results_logistic_regression.png ├── .gitignore ├── TODO.md ├── notebooks ├── 10_tl_growing_season.ipynb ├── 05_tl_explore_models.ipynb ├── 02_gt_linear_model.ipynb └── 01_tl_data_exploration.ipynb ├── run.py ├── environment.yml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /predictor/models/neural_networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | Placeholder for data 4 | -------------------------------------------------------------------------------- /data/raw/README.md: -------------------------------------------------------------------------------- 
1 | # Raw data 2 | 3 | Placeholder for the raw data file, `tabular_data.csv`. 4 | -------------------------------------------------------------------------------- /ndvi_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommylees112/vegetation_health/HEAD/ndvi_results.png -------------------------------------------------------------------------------- /predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessing import Cleaner 2 | from .engineer import Engineer 3 | -------------------------------------------------------------------------------- /predictor/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | from .plot_shap import plot_shap_values 2 | from .plot_results import plot_results 3 | -------------------------------------------------------------------------------- /figs/variables_histogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommylees112/vegetation_health/HEAD/figs/variables_histogram.png -------------------------------------------------------------------------------- /figs/ndvi_results_logistic_regression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommylees112/vegetation_health/HEAD/figs/ndvi_results_logistic_regression.png -------------------------------------------------------------------------------- /data/processed/README.md: -------------------------------------------------------------------------------- 1 | # Processed Data 2 | 3 | Placeholder for processed data. Steps in the pipeline will save data here for future steps to use. 
4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | *.npy 3 | .DS_Store 4 | notebooks/.ipynb_checkpoints 5 | .ipynb_checkpoints 6 | __pycache__ 7 | *.pyc 8 | .idea 9 | temp.py 10 | *.nc 11 | test.py 12 | *.json 13 | docs/ 14 | -------------------------------------------------------------------------------- /predictor/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .baseline import LinearModel 2 | from .neural_networks.feedforward import FeedForward as nn_FeedForward 3 | from .neural_networks.recurrent import Recurrent as nn_Recurrent 4 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | Tommy: 2 | - [x] Produce a tabular dataset of precip / temp / vegetation health indices 3 | - [x] upload to github (if < 50mb) (SENT BY WE TRANSFER) 4 | - [x] Mask out the sea values (not sure if already done - check!) 
5 | 6 | Gabriel: 7 | - [x] produce project skeleton 8 | - [x] Mask out the LST values == 200 9 | -------------------------------------------------------------------------------- /data/nc_to_pandas.py: -------------------------------------------------------------------------------- 1 | """ convert netcdf file to tabular dataformat (pandas) 2 | BOILERPLATE 3 | """ 4 | import xarray as xr 5 | import pandas as pd 6 | 7 | # change this path to point to the .nc file 8 | data_dir = '' 9 | 10 | ds = xr.open_dataset(data_dir) 11 | 12 | # NOTE 13 | df = ds.to_dataframe() 14 | df.to_csv('path/to/csv/file') 15 | -------------------------------------------------------------------------------- /predictor/analysis/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from collections import namedtuple 4 | 5 | 6 | 7 | def load_model_data(train_or_test, target='ndvi'): 8 | """ Return a named tuple with the following data attrs 9 | x, y, latlon, years. This is the data fed through the model. 
10 | """ 11 | if train_or_test == "test": 12 | data_dir = Path('.') / "data" / "processed" / target / "arrays" / "test" 13 | elif train_or_test == "train": 14 | data_dir = Path('.') / "data" / "processed" / target / "arrays" / "train" 15 | else: 16 | assert False, "train_or_test must be either ['train','test']" 17 | 18 | Data = namedtuple('Data',["x","y","latlon","years"]) 19 | data = Data( 20 | x=np.load(data_dir/"x.npy"), 21 | y=np.load(data_dir/"y.npy"), 22 | latlon=np.load(data_dir/"latlon.npy"), 23 | years=np.load(data_dir/"years.npy"), 24 | ) 25 | 26 | return data 27 | -------------------------------------------------------------------------------- /tests/test_cleaner.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datetime import datetime 3 | 4 | from predictor import Cleaner 5 | 6 | 7 | def test_year_month(): 8 | 9 | # lets predict June given a year's data 10 | months_2018 = [datetime(2018, x, 1) for x in range(1, 13)] 11 | months_2019 = [datetime(2019, x, 1) for x in range(1, 13)] 12 | 13 | g1, g2, g3 = [1] * 5, [2] * 12, [3] * 7 14 | test_data = { 15 | 'times': months_2018 + months_2019, 16 | 'group': g1 + g2 + g3, 17 | } 18 | test_df = pd.DataFrame(data=test_data) 19 | 20 | cleaner = Cleaner() 21 | test_df['gp_month'], test_df['gp_year'] = cleaner.update_year_month(test_df['times'], 22 | pred_month=5) 23 | 24 | # all of g2 should be in the same gp_year 25 | group_2 = test_df[test_df['gp_year'] == 2019] 26 | 27 | assert len(group_2) == 12, "Chopped out some 2018 data" 28 | assert (group_2['group'] == 2).all(), "Not all the correct months were grouped!" 
29 | -------------------------------------------------------------------------------- /data/ee_extract_landcover_by_country.js: -------------------------------------------------------------------------------- 1 | // ee_extract_landcover_by_country.js 2 | 3 | // import the landcover dataset & the country shapefiles 4 | var globcover = ee.FeatureCollection('ESA/GLOBCOVER_L4_200901_200912_V2_3') 5 | var world_region = ee.FeatureCollection('ft:1tdSwUL7MVpOauSgRzqVTOwdfy17KDbw-1d9omPw') 6 | var landcover = globcover.select('landcover'); 7 | 8 | // create Polygon for Ethiopia and Kenya 9 | var eth = world_region.filterMetadata('Country','equals','Ethiopia').select(['landcover']); 10 | var kenya = world_region.filterMetadata('Country','equals','Kenya'); 11 | 12 | // clip the landcover to the countries 13 | var eth_lc = landcover.clip(eth); 14 | var ken_lc = landcover.clip(kenya); 15 | 16 | // export the landcover map 17 | var scale = 500; 18 | var crs='EPSG:4326'; 19 | 20 | var task = Export.image.toDrive({ 21 | image: eth_lc, 22 | description: 'ethiopia_landcover', 23 | scale: 1000, 24 | region: eth, 25 | folder: 'landcover' 26 | }) 27 | 28 | var task_kenya = Export.image.toDrive({ 29 | image: ken_lc, 30 | description: 'kenya_landcover', 31 | scale: 1000, 32 | region: kenya, 33 | folder: 'landcover' 34 | }) 35 | -------------------------------------------------------------------------------- /predictor/models/baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import linear_model 3 | from sklearn.metrics import mean_squared_error 4 | 5 | from .base import ModelBase 6 | 7 | 8 | class LinearModel(ModelBase): 9 | """A logistic regression, to be used as a baseline 10 | against our more complex models 11 | """ 12 | 13 | model_name = 'linear' 14 | 15 | def train(self): 16 | 17 | train_data = self.load_arrays(mode='train') 18 | 19 | x = train_data.x.reshape(train_data.x.shape[0], -1) 20 | 21 | 
21 | self.model = linear_model.LinearRegression() 22 | self.model.fit(x, train_data.y) 23 | 24 | train_pred_y = self.model.predict(x) 25 | train_rmse = np.sqrt(mean_squared_error(train_data.y, train_pred_y)) 26 | 27 | print(f'Train set RMSE: {train_rmse}') 28 | 29 | def predict(self): 30 | 31 | test_data = self.load_arrays(mode='test') 32 | x = test_data.x.reshape(test_data.x.shape[0], -1) 33 | test_pred_y = self.model.predict(x) # self.model is None until train() has been called 34 | return test_data.y, test_pred_y 35 | 36 | def save_model(self): 37 | savedir = self.data_path / self.model_name 38 | if not savedir.exists(): savedir.mkdir() 39 | np.save(savedir / 'model.npy', self.model.coef_) # NOTE(review): only coef_ is saved; intercept_ is not persisted 40 | -------------------------------------------------------------------------------- /tests/test_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | from predictor.models.neural_networks.recurrent import UnrolledRNN 6 | 7 | 8 | def test_rnn(): 9 | """ 10 | We implement our own unrolled RNN, so that it can be explained with 11 | shap. This test makes sure it roughly mirrors the behaviour of the pytorch 12 | LSTM.
13 | """ 14 | 15 | batch_size, hidden_size, features_per_month = 32, 124, 6 16 | 17 | x = torch.ones(batch_size, 1, features_per_month) 18 | 19 | hidden_state = torch.zeros(1, x.shape[0], hidden_size) 20 | cell_state = torch.zeros(1, x.shape[0], hidden_size) 21 | 22 | torch_rnn = nn.LSTM(input_size=features_per_month, 23 | hidden_size=hidden_size, 24 | batch_first=True, 25 | num_layers=1) 26 | 27 | our_rnn = UnrolledRNN(input_size=features_per_month, 28 | hidden_size=hidden_size, 29 | batch_first=True) 30 | 31 | for parameters in torch_rnn.all_weights: 32 | for pam in parameters: 33 | nn.init.constant_(pam.data, 1) 34 | 35 | for parameters in our_rnn.parameters(): 36 | for pam in parameters: # iterating a Parameter yields views into it, so constant_ still fills every element 37 | nn.init.constant_(pam.data, 1) 38 | 39 | with torch.no_grad(): 40 | o_out, (o_cell, o_hidden) = our_rnn(x, (hidden_state, cell_state)) 41 | t_out, (t_cell, t_hidden) = torch_rnn(x, (hidden_state, cell_state)) # NOTE(review): nn.LSTM returns output, (h_n, c_n), so t_cell here is h_n - confirm UnrolledRNN returns the same order 42 | 43 | assert np.isclose(o_out.numpy(), t_out.numpy(), 0.01).all(), "Difference in hidden state" 44 | assert np.isclose(t_cell.numpy(), o_cell.numpy(), 0.01).all(), "Difference in cell state" 45 | -------------------------------------------------------------------------------- /notebooks/10_tl_growing_season.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# When is the growing season?
(Season with Maximum NDVI?)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import xarray as xr\n", 17 | "import pandas as pd\n", 18 | "import numpy as np\n", 19 | "from pathlib import Path\n", 20 | "from collections import namedtuple\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import seaborn as sns\n", 23 | "%matplotlib inline\n", 24 | "\n", 25 | "import os\n", 26 | "if os.getcwd().split('/')[-1] != \"vegetation_health\":\n", 27 | " os.chdir('..')\n", 28 | " \n", 29 | "assert os.getcwd().split('/')[-1] == \"vegetation_health\", f\"Working directory should be the root (), currently: {os.getcwd()}\"\n", 30 | "\n", 31 | "from predictor.analysis.plot_results import create_dataset_from_vars, plot_results\n", 32 | "from predictor.analysis.utils import load_model_data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [] 48 | } 49 | ], 50 | "metadata": { 51 | "kernelspec": { 52 | "display_name": "Python 3", 53 | "language": "python", 54 | "name": "python3" 55 | }, 56 | "language_info": { 57 | "codemirror_mode": { 58 | "name": "ipython", 59 | "version": 3 60 | }, 61 | "file_extension": ".py", 62 | "mimetype": "text/x-python", 63 | "name": "python", 64 | "nbconvert_exporter": "python", 65 | "pygments_lexer": "ipython3", 66 | "version": "3.7.2" 67 | } 68 | }, 69 | "nbformat": 4, 70 | "nbformat_minor": 2 71 | } 72 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import fire 2 | from pathlib import Path 3 | 4 | from predictor import Cleaner, Engineer 5 | from predictor.models import LinearModel, nn_FeedForward, nn_Recurrent 6 | 7 | 8 | class 
RunTask: 9 | 10 | @staticmethod 11 | def clean(raw_filepath='data/raw/predict_vegetation_health.nc', 12 | processed_folder='data/processed', 13 | target='ndvi', pred_month=6): 14 | 15 | raw_filepath, processed_folder = Path(raw_filepath), Path(processed_folder) 16 | processed_filepath = processed_folder / target / 'cleaned_data.csv' 17 | 18 | cleaner = Cleaner(raw_filepath, processed_filepath) 19 | cleaner.process(pred_month, target) 20 | 21 | @staticmethod 22 | def engineer(processed_folder='data/processed', target='ndvi', 23 | test_year=2016): 24 | 25 | processed_folder = Path(processed_folder) 26 | cleaned_data = processed_folder / target / 'cleaned_data.csv' 27 | arrays_folder = processed_folder / target / 'arrays' 28 | 29 | engineer = Engineer(cleaned_data, arrays_folder) 30 | engineer.process(test_year) 31 | 32 | @staticmethod 33 | def train_model(model_type='baseline', data_folder='data', 34 | target='ndvi', hide_vegetation=True, save_results=True): 35 | 36 | data_folder = Path(data_folder) 37 | arrays_folder = data_folder / 'processed' / target / 'arrays' 38 | 39 | string2model = { 40 | 'baseline': LinearModel(data_folder, arrays_folder, hide_vegetation), 41 | 'feedforward': nn_FeedForward(data_folder, arrays_folder, hide_vegetation), 42 | 'recurrent': nn_Recurrent(data_folder, arrays_folder, hide_vegetation), 43 | } 44 | 45 | model = string2model[model_type] 46 | model.train() 47 | model.evaluate(save_preds=save_results) 48 | if save_results: 49 | model.save_model() 50 | 51 | 52 | if __name__ == '__main__': 53 | fire.Fire(RunTask) 54 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vegetation_health 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - atomicwrites=1.3.0 9 | - attrs=19.1.0 10 | - blas=1.0 11 | - bzip2=1.0.6 12 | - cftime=1.0.3.4 13 | - curl=7.64.0 14 | - 
hdf4=4.2.13 15 | - hdf5=1.10.4 16 | - intel-openmp=2019.1 17 | - jpeg=9b 18 | - libcurl=7.64.0 19 | - libgfortran=3.0.1 20 | - libnetcdf=4.6.1 21 | - mkl=2019.1 22 | - mkl_fft=1.0.10 23 | - mkl_random=1.0.2 24 | - more-itertools=6.0.0 25 | - netcdf4=1.4.2 26 | - numpy=1.16.2 27 | - numpy-base=1.16.2 28 | - pandas=0.24.1 29 | - pluggy=0.9.0 30 | - py=1.8.0 31 | - pytest=4.3.0 32 | - python-dateutil=2.8.0 33 | - pytz=2018.9 34 | - scikit-learn=0.20.2 35 | - scipy=1.2.1 36 | - six=1.12.0 37 | - xarray=0.11.3 38 | - ca-certificates=2019.3.9 39 | - certifi=2019.3.9 40 | - cycler=0.10.0 41 | - fire=0.1.3 42 | - freetype=2.9.1 43 | - kiwisolver=1.0.1 44 | - libpng=1.6.36 45 | - libssh2=1.8.0 46 | - matplotlib=3.0.3 47 | - matplotlib-base=3.0.3 48 | - openssl=1.1.1b 49 | - pyparsing=2.3.1 50 | - tqdm=4.31.1 51 | - appnope=0.1.0 52 | - backcall=0.1.0 53 | - cffi=1.12.2 54 | - decorator=4.3.2 55 | - ipykernel=5.1.0 56 | - ipython=7.3.0 57 | - ipython_genutils=0.2.0 58 | - jedi=0.13.3 59 | - jupyter_client=5.2.4 60 | - jupyter_core=4.4.0 61 | - krb5=1.16.1 62 | - libcxx=4.0.1 63 | - libcxxabi=4.0.1 64 | - libedit=3.1.20181209 65 | - libffi=3.2.1 66 | - libsodium=1.0.16 67 | - ncurses=6.1 68 | - ninja=1.8.2 69 | - parso=0.3.4 70 | - pexpect=4.6.0 71 | - pickleshare=0.7.5 72 | - pip=19.0.3 73 | - prompt_toolkit=2.0.9 74 | - ptyprocess=0.6.0 75 | - pycparser=2.19 76 | - pygments=2.3.1 77 | - python=3.6.8 78 | - pyzmq=18.0.0 79 | - readline=7.0 80 | - setuptools=40.8.0 81 | - sqlite=3.26.0 82 | - tk=8.6.8 83 | - tornado=6.0.1 84 | - traitlets=4.3.2 85 | - wcwidth=0.1.7 86 | - wheel=0.33.1 87 | - xz=5.2.4 88 | - zeromq=4.3.1 89 | - zlib=1.2.11 90 | - pytorch=1.0.1 91 | - pip: 92 | - torch==1.0.1.post2 93 | -------------------------------------------------------------------------------- /predictor/analysis/plot_shap.py: -------------------------------------------------------------------------------- 1 | from mpl_toolkits.axes_grid1 import host_subplot 2 | import 
mpl_toolkits.axisartist as AA 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plot_shap_values(x, shap_values, val_list, normalizing_dict, value_to_plot, normalize_shap_plots=True): 7 | """Plots the denormalized values against their shap values, so that 8 | variations in the input features to the model can be compared to their effect 9 | on the model. For example plots, see notebooks/08_gt_recurrent_model.ipynb. 10 | 11 | Parameters: 12 | ---------- 13 | x: np.array 14 | The input to a model for a single data instance 15 | shap_values: np.array 16 | The corresponding shap values (to x) 17 | val_list: list 18 | A list of the variable names, for axis labels 19 | normalizing_dict: dict 20 | The normalizing dict saved by the `Cleaner`, so that the x array can be 21 | denormalized 22 | value_to_plot: str 23 | The specific input variable to plot. Must be in val_list 24 | normalize_shap_plots: boolean 25 | If True, then the scale of the shap plots will be uniform across all 26 | variable plots (on an instance specific basis). 27 | 28 | A plot of the variable `value_to_plot` against its shap values will be plotted. 
29 | """ 30 | # first, lets isolate the lists 31 | idx = val_list.index(value_to_plot) 32 | 33 | x_val = x[:, idx].cpu().numpy() 34 | 35 | # we also want to denormalize 36 | x_val = (x_val * normalizing_dict[value_to_plot]['std']) + \ 37 | normalizing_dict[value_to_plot]['mean'] 38 | 39 | shap_val = shap_values[:, idx] 40 | 41 | months = list(range(1, 12)) 42 | 43 | host = host_subplot(111, axes_class=AA.Axes) 44 | plt.subplots_adjust(right=0.75) 45 | 46 | par1 = host.twinx() 47 | par1.axis["right"].toggle(all=True) 48 | 49 | if normalize_shap_plots: 50 | par1.set_ylim(shap_values.min(), shap_values.max()) 51 | 52 | host.set_xlabel("Months") 53 | host.set_ylabel(value_to_plot) 54 | par1.set_ylabel("Shap value") 55 | 56 | p1, = host.plot(months, x_val, label=value_to_plot) 57 | p2, = par1.plot(months, shap_val, label="shap value") 58 | 59 | host.axis["left"].label.set_color(p1.get_color()) 60 | par1.axis["right"].label.set_color(p2.get_color()) 61 | 62 | host.legend() 63 | 64 | plt.draw() 65 | plt.show() 66 | -------------------------------------------------------------------------------- /predictor/analysis/plot_results.py: -------------------------------------------------------------------------------- 1 | import xarray as xr 2 | import pandas as pd 3 | import numpy as np 4 | from pathlib import Path 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def create_dataset_from_vars(vars, latlon, varname, to_xarray=True): 9 | """ Convert the variables from `np.array` to `pd.DataFrame` 10 | and optionally `xr.Dataset`. By default converts to `xr.Dataset` 11 | 12 | Arguments: 13 | --------- 14 | : vars (np.array) 15 | the values of the variable of interest (e.g. 
Predictions of NDVI from model) 16 | 17 | : latlon (np.array) 18 | the latlon location for each of the values in vars 19 | 20 | : varname (str) 21 | the name of the variable 22 | 23 | TODO: 24 | ---- 25 | * Implement a method that works with TIME so that the xarray objects 26 | have a time dimension too 27 | """ 28 | assert len(vars) == len(latlon), f"The length of the latlons array should be the same as the legnth of the vars array. Currently latlons: {len(latlon)} vars: {len(vars)}" 29 | 30 | 31 | df = pd.DataFrame(data={varname: vars, 'lat': latlon[:, 0], 32 | 'lon': latlon[:, 1]}).set_index(['lat', 'lon']) 33 | if to_xarray: 34 | return df.to_xarray() 35 | else: 36 | return df 37 | 38 | 39 | def plot_results(processed_data=Path('data/processed'), target='ndvi', 40 | plot_difference=False, savefig=True): 41 | """Plots a landscape of the results (and optionally, 42 | of the ground truth) 43 | """ 44 | 45 | preds = np.load(processed_data / target / 'arrays/preds.npy') 46 | true = np.load(processed_data / target / 'arrays/test/y.npy') 47 | latlon = np.load(processed_data / target / 'arrays/test/latlon.npy') 48 | 49 | preds_xr = create_dataset_from_vars(preds, latlon, "preds", to_xarray=True) 50 | true_xr = create_dataset_from_vars(true, latlon, "true", to_xarray=True) 51 | 52 | data_xr = xr.concat((preds_xr['preds'], true_xr['true']), 53 | pd.Index(['predictions', 'ground truth'], name='data')) 54 | 55 | if plot_difference: 56 | # compute the difference and create a difference plot 57 | data = data_xr.data[1] - data_xr.data[0] 58 | da = xr.DataArray(data, coords=[data_xr.lat, data_xr.lon], dims=['lat','lon']) 59 | ds = da.to_dataset('difference') 60 | 61 | ds.difference.plot(x='lon', y='lat', figsize=(12, 8)) 62 | else: 63 | data_xr.plot(x='lon', y='lat', col='data', figsize=(15, 6)) 64 | 65 | if savefig: 66 | plt.savefig(f'{target}_results.png', dpi=300, bbox_inches='tight') 67 | plt.show() 68 | 
-------------------------------------------------------------------------------- /predictor/models/neural_networks/feedforward.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from pathlib import Path 4 | 5 | from .nn_base import NNBase 6 | from ...preprocessing import VALUE_COLS, VEGETATION_LABELS 7 | 8 | 9 | class FeedForward(NNBase): 10 | """A simple feedforward neural network 11 | """ 12 | 13 | def __init__(self, data=Path('data'), arrays=Path('data/processed/arrays'), 14 | hide_vegetation=False): 15 | 16 | features_per_month = len(VALUE_COLS) 17 | if hide_vegetation: 18 | features_per_month -= len(VEGETATION_LABELS) 19 | 20 | num_features = features_per_month * 11 21 | 22 | super().__init__(LinearModel(num_features, [num_features], 0.25), 23 | data, arrays, hide_vegetation) 24 | self.model_name = "feedforward" 25 | 26 | class LinearModel(nn.Module): 27 | 28 | def __init__(self, input_size, layer_sizes, dropout): 29 | super().__init__() 30 | layer_sizes.insert(0, input_size) 31 | 32 | self.dense_layers = nn.ModuleList([ 33 | LinearBlock(in_features=layer_sizes[i - 1], 34 | out_features=layer_sizes[i], dropout=dropout) for 35 | i in range(1, len(layer_sizes)) 36 | ]) 37 | 38 | self.final_dense = nn.Linear(in_features=layer_sizes[-1], out_features=1) 39 | 40 | self.init_weights() 41 | 42 | def init_weights(self): 43 | for dense_layer in self.dense_layers: 44 | nn.init.kaiming_uniform_(dense_layer.linear.weight.data) 45 | 46 | nn.init.kaiming_uniform_(self.final_dense.weight.data) 47 | # http://cs231n.github.io/neural-networks-2/#init 48 | # see: Initializing the biases 49 | nn.init.constant_(self.final_dense.bias.data, 0) 50 | 51 | def forward(self, x): 52 | # flatten 53 | x = x.view(x.shape[0], -1) 54 | for layer in self.dense_layers: 55 | x = layer(x) 56 | 57 | return self.final_dense(x) 58 | 59 | 60 | class LinearBlock(nn.Module): 61 | """ 62 | A linear layer followed by batchnorm, a ReLU 
activation, and dropout 63 | """ 64 | 65 | def __init__(self, in_features, out_features, dropout=0.25): 66 | super().__init__() 67 | self.linear = nn.Linear(in_features=in_features, out_features=out_features, bias=False) 68 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 69 | self.batchnorm = nn.BatchNorm1d(num_features=out_features) 70 | self.dropout = nn.Dropout(dropout) 71 | 72 | def forward(self, x): 73 | x = self.relu(self.batchnorm(self.linear(x))) 74 | return self.dropout(x) 75 | -------------------------------------------------------------------------------- /data/preprocessing/gleam_cleaner.py: -------------------------------------------------------------------------------- 1 | """gleam_cleaner.py""" 2 | from pathlib import Path 3 | import xarray as xr 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ipdb 8 | import warnings 9 | import os 10 | 11 | from preprocessing.utils import ( 12 | gdal_reproject, 13 | bands_to_time, 14 | convert_to_same_grid, 15 | select_same_time_slice, 16 | save_netcdf, 17 | get_holaps_mask, 18 | merge_data_arrays, 19 | ) 20 | 21 | from preprocessing.cleaner import Cleaner 22 | 23 | class GleamCleaner(Cleaner): 24 | def __init__(self): 25 | self.base_data_path = Path("/soge-home/projects/crop_yield/EGU_compare/") 26 | reference_data_path = self.base_data_path / "holaps_EA_clean.nc" 27 | 28 | # CHANGE THIS PATH: 29 | # ---------------------------------------------------------------------- 30 | data_path = self.base_data_path / "EA_chirps_monthly.nc" 31 | # ---------------------------------------------------------------------- 32 | 33 | # open the reference dataset 34 | self.reference_data_path = Path(reference_data_path) 35 | self.reference_ds = xr.open_dataset(self.reference_data_path).holaps_evapotranspiration 36 | 37 | # initialise the object using methods from the parent class 38 | super(ChirpsCleaner, self).__init__(data_path=data_path) 39 | 40 | # extract the variable of interest (TO xr.DataArray) 41 
| self.update_clean_data( 42 | self.raw_data.precip, msg="Extract Precipitation from CHIRPS xr.Dataset" 43 | ) 44 | 45 | # make the mask (FROM REFERENCE_DS) to copy to this dataset too 46 | self.get_mask() 47 | # self.mask = self.mask.drop('units') 48 | 49 | def get_mask(self): 50 | self.mask = get_holaps_mask(self.reference_ds) 51 | 52 | def convert_units(self): 53 | # convert unit label to 'mm day-1' 54 | self.clean_data.attrs["units"] = "mm day-1" 55 | 56 | 57 | def preprocess(self): 58 | # Resample the timesteps to END OF MONTH 59 | self.resample_time(resample_str="M") 60 | # select the correct time slice 61 | self.correct_time_slice() 62 | # update the units 63 | self.convert_units() 64 | # regrid to same as reference data (holaps) 65 | self.regrid_to_reference() 66 | # ipdb.set_trace() 67 | # use the same mask as HOLAPS 68 | self.use_reference_mask() # THIS GOING WRONG 69 | # rename data 70 | self.rename_xr_object("gleam_evapotranspiration") 71 | # save data 72 | save_netcdf( 73 | self.clean_data, filepath=self.base_data_path / "gleam_EA_clean.nc" 74 | ) 75 | print("\n\n GLEAM Preprocessed \n\n") 76 | return 77 | -------------------------------------------------------------------------------- /data/preprocessing/chirps_cleaner.py: -------------------------------------------------------------------------------- 1 | """chirps_cleaner.py""" 2 | from pathlib import Path 3 | import xarray as xr 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ipdb 8 | import warnings 9 | import os 10 | 11 | from preprocessing.utils import ( 12 | gdal_reproject, 13 | bands_to_time, 14 | convert_to_same_grid, 15 | select_same_time_slice, 16 | save_netcdf, 17 | get_holaps_mask, 18 | merge_data_arrays, 19 | ) 20 | 21 | from preprocessing.cleaner import Cleaner 22 | 23 | class ChirpsCleaner(Cleaner): 24 | def __init__(self): 25 | self.base_data_path = Path("/soge-home/projects/crop_yield/EGU_compare/") 26 | reference_data_path = self.base_data_path / "holaps_EA_clean.nc" 
27 | 28 | # CHANGE THIS PATH: 29 | # ---------------------------------------------------------------------- 30 | data_path = self.base_data_path / "EA_chirps_monthly.nc" 31 | # ---------------------------------------------------------------------- 32 | 33 | # open the reference dataset 34 | self.reference_data_path = Path(reference_data_path) 35 | self.reference_ds = xr.open_dataset(self.reference_data_path).holaps_evapotranspiration 36 | 37 | # initialise the object using methods from the parent class 38 | super(ChirpsCleaner, self).__init__(data_path=data_path) 39 | 40 | # extract the variable of interest (TO xr.DataArray) 41 | self.update_clean_data( 42 | self.raw_data.precip, msg="Extract Precipitation from CHIRPS xr.Dataset" 43 | ) 44 | 45 | # make the mask (FROM REFERENCE_DS) to copy to this dataset too 46 | self.get_mask() 47 | # self.mask = self.mask.drop('units') 48 | 49 | 50 | def convert_units(self): 51 | daily_mm = self.clean_data / 30.417 52 | daily_mm.attrs.units = 'mm day-1' 53 | self.update_clean_data(daily_mm, msg="Change the mm month-1 values to mm day-1") 54 | 55 | 56 | def preprocess(self): 57 | # Resample the timesteps to END OF MONTH 58 | self.resample_time(resample_str="M") 59 | # select the correct time slice 60 | self.correct_time_slice() 61 | # update the units 62 | self.convert_units() 63 | # latitude,longitude => lat,lon 64 | self.rename_lat_lon() 65 | # regrid to same as reference data (holaps) 66 | self.regrid_to_reference(method="bilinear") 67 | # ipdb.set_trace() 68 | # use the same mask as HOLAPS 69 | self.use_reference_mask() # THIS GOING WRONG 70 | # rename data 71 | self.rename_xr_object("chirps_precipitation") 72 | # save data 73 | save_netcdf( 74 | self.clean_data, filepath=self.base_data_path / "chirps_EA_clean.nc" 75 | ) 76 | print("\n\n CHIRPS Preprocessed \n\n") 77 | return 78 | -------------------------------------------------------------------------------- /predictor/models/base.py: 
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | from collections import namedtuple 4 | 5 | from sklearn.metrics import mean_squared_error 6 | from ..preprocessing import VALUE_COLS, VEGETATION_LABELS 7 | 8 | DataTuple = namedtuple('Data', ['x', 'y', 'latlon', 'years']) 9 | 10 | 11 | class ModelBase: 12 | """Base for all machine learning models. 13 | 14 | Attributes: 15 | ---------- 16 | arrays: pathlib.Path 17 | The location where the arrays were saved by the `Engineer` class 18 | hide_vegetation: bool, default: False 19 | Whether to hide vegetation-specific information from the training 20 | data. This allows us to better understand how the other factors drive 21 | vegetation health. 22 | """ 23 | 24 | model_name = None # to be added by the model classes 25 | 26 | def __init__(self, data=Path('data'), arrays_path=Path('data/processed/arrays'), 27 | hide_vegetation=False): 28 | self.data_path = data 29 | self.arrays_path = arrays_path 30 | self.hide_vegetation = hide_vegetation 31 | self.model = None # to be added by the model classes 32 | 33 | def train(self): 34 | raise NotImplementedError 35 | 36 | def predict(self): 37 | # This method should return the predictions, and 38 | # the corresponding true values, read from the test 39 | # arrays 40 | raise NotImplementedError 41 | 42 | def save_model(self): 43 | # This method should save the model in data / model_name 44 | raise NotImplementedError 45 | 46 | def evaluate(self, return_eval=False, save_preds=False): 47 | """Evaluates the model using root mean squared error. 48 | This ensures evaluation is consistent across different models. 49 | 50 | Parameters: 51 | ---------- 52 | return_eval: bool, default: False 53 | Whether to return the calculated root mean squared error 54 | save_preds: bool, default: False 55 | Whether to save the predictions. 
If True, they will be saved 56 | in self.arrays_path / preds.npy 57 | 58 | Returns: 59 | ---------- 60 | (if return_eval) test_rmse: float 61 | The calculated root mean squared error for the test set 62 | """ 63 | y_true, y_pred = self.predict() 64 | 65 | test_rmse = np.sqrt(mean_squared_error(y_true, y_pred)) 66 | 67 | print(f'Test set RMSE: {test_rmse}') 68 | 69 | if save_preds: 70 | savedir = self.data_path / self.model_name 71 | if not savedir.exists(): savedir.mkdir() 72 | print(f'Saving predictions to {savedir / "preds.npy"}') 73 | np.save(savedir / 'preds.npy', y_pred) 74 | 75 | if return_eval: 76 | return test_rmse 77 | 78 | def load_arrays(self, mode='train'): 79 | 80 | arrays_path = self.arrays_path / mode 81 | 82 | x = np.load(arrays_path / 'x.npy') 83 | 84 | if self.hide_vegetation: 85 | if mode == 'train': 86 | print('Training model without vegetation features') 87 | indices_to_keep = [idx for idx, val in enumerate(VALUE_COLS) if val not in VEGETATION_LABELS] 88 | 89 | x = x[:, :, indices_to_keep] 90 | 91 | return DataTuple( 92 | latlon=np.load(arrays_path / 'latlon.npy'), 93 | years=np.load(arrays_path / 'years.npy'), 94 | x=x, 95 | y=np.load(arrays_path / 'y.npy')) 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vegetation Health 2 | 3 | Predicting vegetation health from precipitation and temperature 4 | 5 | ## Introduction 6 | 7 | This repository experiments with different machine learning models to predict drought indices in East Africa 8 | (specifically the Normalized Difference Vegetation Index) using temperature and precipitation data. 9 | 10 | ## Results 11 | 12 | Models are trained on data before 2016, and evaluated on 2016 data. Vegetation health in June is being predicted. 13 | 14 | In addition, vegetation health can be hidden from the model to better understand the effects of the other features. 
15 | 16 | | Model | RMSE | RMSE (no veg) | 17 | |:------------------------:|:----:|:-------------:| 18 | |Linear Regression |0.040 |0.084 | 19 | |Feedforward neural network|0.038 |0.070 | 20 | |Recurrent neural network |0.035 |0.060 | 21 | 22 | The results of the models can also be compared visually with the ground truth (the example below is from the baseline 23 | linear regression): 24 | 25 | Linear regression results 26 | 27 | In addition, the effects of the inputs on the models' predictions are investigated using [SHAP values](https://github.com/slundberg/shap) 28 | in Jupyter Notebooks, for both the [feedforward neural network](notebooks/04_gt_feedforward_model.ipynb) and the 29 | [recurrent neural network](notebooks/08_gt_recurrent_model.ipynb). 30 | 31 | ## Pipeline 32 | 33 | [Python Fire](https://github.com/google/python-fire) is used to generate a CLI. 34 | 35 | ### Data cleaning 36 | 37 | Normalize values from the original CSV file, remove null values, and add a year series. 38 | 39 | ```bash 40 | python run.py clean 41 | ``` 42 | A target can be selected by adding the flag `--target`, e.g. `--target=ndvi_anomaly`. 43 | By default, the target is `ndvi`. The selected target must be in 44 | [`predictor.preprocessing.VALUE_COLS`](predictor/preprocessing.py). 45 | 46 | The original data is currently generated from datasets on the Oxford University cluster, using the scripts 47 | in [`data`](data). 48 | 49 | ### Data Processing 50 | 51 | Turn the CSV file into `numpy` arrays that can be fed into the models. 52 | 53 | ```bash 54 | python run.py engineer 55 | ``` 56 | 57 | ### Models 58 | 59 | Three models have been implemented: a baseline linear regression, a feedforward neural network and a 60 | recurrent neural network. They can be selected using the `--model_type` flag. 61 | 62 | ```bash 63 | python run.py train_model 64 | ``` 65 | 66 | ## Setup 67 | 68 | [Anaconda](https://www.anaconda.com/download/#macos) running Python 3.7 is used as the package manager.
To get set up 69 | with an environment, install Anaconda from the link above, and (from this directory) run 70 | 71 | ```bash 72 | conda env create -f environment.yml 73 | ``` 74 | This will create an environment named `vegetation_health` with all the necessary packages to run the code. To 75 | activate this environment, run 76 | 77 | ```bash 78 | conda activate vegetation_health 79 | ``` 80 | 81 | ## Additional Notes 82 | 83 | - The following variables are used by the model: `['lst_night', 'lst_day', 'precip', 'sm', 'ndvi', 'evi', 'ndvi_anomaly']`. 84 | They all come from different sources. 85 | 86 | 87 | - East Africa is defined here as the area of the original `.nc` file (`spi_spei.nc`) 88 | 89 | lat min, lat max : `-4.9750023`, `15.174995` 90 | 91 | lon min, lon max : `32.524994`, `48.274994` 92 | 93 | This makes the following bounding box: (left, bottom, right, top): `(32.524994, -4.9750023, 48.274994, 15.174995)` 94 | -------------------------------------------------------------------------------- /data/globmap_lookup.py: -------------------------------------------------------------------------------- 1 | """ 2 | The lookup values for the GLOBMAP landcover classes 3 | (defined in ee_extract_landcover.js) 4 | """ 5 | globmap_lookup = { 6 | 11 : "Post-flooding or irrigated croplands", 7 | 14 : "Rainfed croplands", 8 | 20 : "Mosaic cropland (50-70%) / vegetation (grassland, shrubland, forest) (20-50%)", 9 | 30 : "Mosaic vegetation (grassland, shrubland, forest) (50-70%) / cropland (20-50%)", 10 | 40 : "Closed to open (>15%) broadleaved evergreen and/or semi-deciduous forest (>5m)", 11 | 50 : "Closed (>40%) broadleaved deciduous forest (>5m)", 12 | 60 : "Open (15-40%) broadleaved deciduous forest (>5m)", 13 | 70 : "Closed (>40%) needleleaved evergreen forest (>5m)", 14 | 90 : "Open (15-40%) needleleaved deciduous or evergreen forest (>5m)", 15 | 100 : "Closed to open (>15%) mixed broadleaved and needleleaved forest (>5m)", 16 | 110 : "Mosaic forest-shrubland
(50-70%) / grassland (20-50%)", 17 | 120 : "Mosaic grassland (50-70%) / forest-shrubland (20-50%)", 18 | 130 : "Closed to open (>15%) shrubland (<5m)", 19 | 140 : "Closed to open (>15%) grassland", 20 | 150 : "Sparse (>15%) vegetation (woody vegetation, shrubs, grassland)", 21 | 160 : "Closed (>40%) broadleaved forest regularly flooded - Fresh water", 22 | 170 : "Closed (>40%) broadleaved semi-deciduous and/or evergreen forest regularly flooded - saline water", 23 | 180 : "Closed to open (>15%) vegetation (grassland, shrubland, woody vegetation) on regularly flooded or waterlogged soil", 24 | 190 : "Artificial surfaces and associated areas (urban areas >50%) GLOBCOVER 2009", 25 | 200 : "Bare areas", 26 | 210 : "Water bodies", 27 | 220 : "Permanent snow and ice", 28 | 230 : "Unclassified", 29 | } 30 | 31 | globmap_lookup_tommy = { 32 | 11 : "irrigated croplands", 33 | 14 : "Rainfed croplands", 34 | 20 : "cropland/vegetation", 35 | 30 : "vegetation/cropland", 36 | 40 : "broad-leaved evergreen/semi-deciduous forest", 37 | 50 : "broad-leaved deciduous forest", 38 | 60 : "broad-leaved deciduous forest", 39 | 70 : "needle-leaved evergreen forest", 40 | 90 : "needle-leaved deciduous / evergreen forest", 41 | 100 : "broad-leaved / needleleaved forest", 42 | 110 : "forest-shrubland / grassland", 43 | 120 : "Mosaic grassland / forest-shrubland", 44 | 130 : "shrubland", 45 | 140 : "grassland", 46 | 150 : "Sparse vegetation", 47 | 160 : "forest regularly flooded - Fresh water", 48 | 170 : "forest regularly flooded - saline water", 49 | 180 : "vegetation on flooded soil", 50 | 190 : "Artificial", 51 | 200 : "Bare areas", 52 | 210 : "Water bodies", 53 | 220 : "Permanent snow and ice", 54 | 230 : "Unclassified", 55 | } 56 | 57 | 58 | globmap_lookup2 = { 59 | # np.nan : "NA", 60 | 11 : "cropland", 61 | 14 : "cropland", 62 | 20 : "cropland", 63 | 30 : "cropland", 64 | 40 : "forest", 65 | 50 : "forest", 66 | 60 : "forest", 67 | 70 : "forest", 68 | 90 : "forest", 69 | 100 : "forest", 
70 | 110 : "shrubland/grassland", 71 | 120 : "shrubland/grassland", 72 | 130 : "shrubland", 73 | 140 : "grassland", 74 | 150 : "Sparse vegetation", 75 | 160 : "flooded", 76 | 170 : "flooded", 77 | 180 : "flooded", 78 | 190 : "Artificial", 79 | 200 : "Bare areas", 80 | 210 : "Water bodies", 81 | 220 : "Permanent snow and ice", 82 | 230 : "Unclassified", 83 | } 84 | 85 | remap_to_int_dict = { 86 | # "NA": 1, 87 | "cropland": 2, 88 | "forest": 3, 89 | "shrubland/grassland": 4, 90 | "shrubland": 5, 91 | "grassland": 6, 92 | "Sparse vegetation": 7, 93 | "flooded": 8, 94 | "Artificial": 9, 95 | "Bare areas": 10, 96 | "Water bodies": 11, 97 | "Permanent snow and ice": 12, 98 | "Unclassified": 13, 99 | } 100 | 101 | globmap_lookup3 = dict(zip(remap_to_int_dict.values(), remap_to_int_dict.keys())) 102 | -------------------------------------------------------------------------------- /data/preprocessing/collect_regrid_data_script.py: -------------------------------------------------------------------------------- 1 | """ 2 | @tommylees112 3 | 4 | An awful script (sorry Gabi) the data is read in separately for each product. 5 | They are then preprocessed: 6 | resample_time 7 | select_time_slice 8 | regrid_to_reference 9 | 10 | using the precipitation data as reference. 
11 | """ 12 | # test.py 13 | import xarray as xr 14 | from pathlib import Path 15 | import matplotlib.pyplot as plt 16 | 17 | from preprocessing.utils import ( 18 | gdal_reproject, 19 | bands_to_time, 20 | convert_to_same_grid, 21 | select_same_time_slice, 22 | save_netcdf, 23 | get_holaps_mask, 24 | merge_data_arrays, 25 | ) 26 | 27 | 28 | def correct_time_slice(ds, reference_ds): 29 | """select the same time slice as the reference data""" 30 | correct_time_slice = select_same_time_slice(reference_ds, ds) 31 | 32 | return correct_time_slice 33 | 34 | 35 | def resample_time(ds, resample_str="M"): 36 | """ should resample to the given timestep """ 37 | resampled_time_data = ds.resample(time=resample_str).first() 38 | return resampled_time_data 39 | 40 | 41 | def regrid_to_reference(ds, reference_ds, method="nearest_s2d"): 42 | """ regrid data (spatially) onto the same grid as reference data """ 43 | 44 | regrid_data = convert_to_same_grid( 45 | reference_ds, ds, method=method 46 | ) 47 | 48 | return regrid_data 49 | 50 | 51 | def use_reference_mask(ds, mask, one_time=False): 52 | # if only one timestep (e.g. 
landcover) then convert to one time 53 | if one_time: 54 | self.mask = self.mask.isel(time=0) 55 | 56 | masked_d = ds.where(~mask.values) 57 | return masked_d 58 | 59 | 60 | def rename_lat_lon(ds): 61 | rename_latlon = ds.rename({"longitude": "lon", "latitude": "lat"}) 62 | return rename_latlon 63 | 64 | 65 | def select_time_slice(ds, timemin, timemax): 66 | return ds.sel(time=slice(timemin, timemax)) 67 | 68 | 69 | # ------------------------------------------------------------------------------ 70 | # Read the data 71 | # ------------------------------------------------------------------------------ 72 | DATA_DIR1 = Path("/soge-home/users/chri4118/EA_data") 73 | DATA_DIR2 = Path("/soge-home/projects/crop_yield/EGU_compare") 74 | 75 | et = xr.open_dataset(DATA_DIR1 / "ET_EastAfrica.nc") 76 | lst = xr.open_dataset(DATA_DIR1 / "LST_EastAfrica.nc")[["lst_day", "lst_night"]] 77 | sm = xr.open_dataset(DATA_DIR1 / "SM_EastAfrica.nc")[['sm','sm_uncertainty']] 78 | ndvi = xr.open_dataset(DATA_DIR1 / "NDVI_EastAfrica.nc")[["ndvi", "evi"]] 79 | precip = xr.open_dataset(DATA_DIR2 / "EA_chirps_monthly.nc") 80 | 81 | # ------------------------------------------------------------------------------ 82 | # Clean the data (same timesteps and same gridsizes) 83 | # ------------------------------------------------------------------------------ 84 | # RESAMPLE THE REFERENCE DATA 85 | precip = select_time_slice(precip, '2000-02-14','2016-12-01') 86 | precip = resample_time(precip) 87 | precip = rename_lat_lon(precip) 88 | reference_ds = precip 89 | 90 | all_vars = [et,lst,sm,ndvi,precip] 91 | names = ["et","lst","sm","ndvi","precip"] 92 | 93 | # RESAMPLE data (except 'precip') 94 | out = [] 95 | for ix, ds in enumerate(all_vars[:-1]): 96 | name = names[ix] 97 | print(f"\n*** working on ds: {name} ***") 98 | # select same time slice 99 | ds = resample_time(ds) 100 | ds = select_time_slice(ds, '2000-02-14','2016-12-01') 101 | # ds = correct_time_slice(ds, reference_ds) 102 | 
print("selected same time slice") 103 | # convert to same grid 104 | ds = regrid_to_reference(ds, reference_ds) 105 | print("converted to same grid") 106 | out.append(ds) 107 | 108 | # ------------------------------------------------------------------------------ 109 | # Merge all of the datasets 110 | # ------------------------------------------------------------------------------ 111 | alldata = out + [precip] 112 | OUT = xr.merge(alldata) 113 | 114 | # ------------------------------------------------------------------------------ 115 | # Save the data to netcdf format 116 | # ------------------------------------------------------------------------------ 117 | OUT.to_netcdf(DATA_DIR1 / "OUT2.nc") 118 | OUT.to_netcdf(DATA_DIR2 / "predict_vegetation_health.nc") 119 | -------------------------------------------------------------------------------- /data/preprocessing/spei_cleaner.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import xarray as xr 3 | import numpy as np 4 | 5 | import ipdb 6 | import warnings 7 | import os 8 | 9 | from preprocessing.utils import ( 10 | gdal_reproject, 11 | bands_to_time, 12 | convert_to_same_grid, 13 | select_same_time_slice, 14 | save_netcdf, 15 | get_holaps_mask, 16 | merge_data_arrays, 17 | ) 18 | 19 | from preprocessing.cleaner import Cleaner 20 | 21 | 22 | # ------------------------------------------------------------------------------ 23 | # HOLAPS cleaner 24 | # ------------------------------------------------------------------------------ 25 | 26 | 27 | class SpeiCleaner(Cleaner): 28 | """Preprocess the HOLAPS dataset""" 29 | assert False, "This is just boilerplate code from another" 30 | def __init__(self): 31 | # init data paths (should be arguments) 32 | self.base_data_path = Path("/soge-home/projects/crop_yield/EGU_compare/") 33 | data_path = self.base_data_path / "holaps_africa.nc" 34 | reproject_path = self.base_data_path / 
"holaps_africa_reproject.nc" 35 | 36 | super(HolapsCleaner, self).__init__(data_path=data_path) 37 | self.reproject_path = Path(reproject_path) 38 | 39 | 40 | def chop_EA_region(self, outfile_path): 41 | """ cheeky little bit of bash scripting with string interpolation (kids don't try this at home) """ 42 | in_file = self.base_data_path / "holaps_reprojected.nc" 43 | out_file = self.base_data_path / "holaps_EA.nc" 44 | lonmin = 32.6 45 | lonmax = 51.8 46 | latmin = -5.0 47 | latmax = 15.2 48 | 49 | cmd = ( 50 | f"cdo sellonlatbox,{lonmin},{lonmax},{latmin},{latmax} {in_file} {out_file}" 51 | ) 52 | print(f"Running command: {cmd}") 53 | os.system(cmd) 54 | print("Chopped East Africa from the Reprojected data") 55 | re_chopped_data = xr.open_dataset(out_file) 56 | self.update_clean_data( 57 | re_chopped_data, msg="Opened the reprojected & chopped data" 58 | ) 59 | return 60 | 61 | 62 | def reproject(self): 63 | """ reproject to WGS84 / geographic latlon """ 64 | if not self.reproject_path.is_file(): 65 | gdal_reproject(infile=self.data_path, outfile=self.reproject_path) 66 | 67 | repr_data = xr.open_dataset(self.reproject_path) 68 | 69 | # get the timestamps from the original holaps data 70 | h_times = self.clean_data.time 71 | # each BAND is a time (multiple raster images 1 per time) 72 | repr_data = bands_to_time(repr_data, h_times, var_name="LE_Mean") 73 | 74 | # TODO: ASSUMPTION / PROBLEM 75 | warnings.warn( 76 | "TODO: No idea why but the values appear to be 10* bigger than the pre-reprojected holaps data" 77 | ) 78 | repr_data /= 10 # WHY ARE THE VALUES 10* bigger? 
79 | 80 | self.update_clean_data(repr_data, "Data Reprojected to WGS84") 81 | 82 | save_netcdf( 83 | self.clean_data, filepath=self.base_data_path / "holaps_reprojected.nc" 84 | ) 85 | return 86 | 87 | 88 | def convert_units(self): 89 | # Convert from latent heat (w m-2) to evaporation (mm day-1) 90 | holaps_mm = self.clean_data / 28 91 | holaps_mm = holaps_mm.LE_Mean 92 | holaps_mm.name = "Evapotranspiration" 93 | holaps_mm.attrs["units"] = "mm day-1 [w m-2 / 28]" 94 | self.update_clean_data( 95 | holaps_mm, msg="Transform Latent Heat (w m-2) to Evaporation (mm day-1)" 96 | ) 97 | 98 | return 99 | 100 | def preprocess(self): 101 | # reproject the file from sinusoidal to WGS84 / 'ESPG:4326' 102 | self.reproject() 103 | #  chop out the correct lat/lon (changes when reprojected) 104 | self.chop_EA_region() 105 | # convert the units 106 | self.convert_units() 107 | # rename data 108 | self.rename_xr_object("holaps_evapotranspiration") 109 | # resample the time units 110 | self.resample_time() 111 | # save the netcdf file (used as reference data for MODIS and GLEAM) 112 | save_netcdf( 113 | self.clean_data, filepath=self.base_data_path / "holaps_EA_clean.nc", 114 | force=True 115 | ) 116 | print("\n\n HOLAPS Preprocessed \n\n") 117 | return 118 | -------------------------------------------------------------------------------- /predictor/engineer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pandas as pd 3 | import numpy as np 4 | 5 | from .preprocessing import VALUE_COLS 6 | 7 | 8 | class Engineer: 9 | """Take the clean csv file and turn it into numpy arrays ready 10 | for training 11 | 12 | Attributes: 13 | ---------- 14 | cleaned_data: pathlib.Path 15 | The location of cleaned data, as saved by the `Cleaner` class 16 | arrays: pathlib.Path 17 | The location where the arrays will be saved 18 | """ 19 | 20 | def __init__(self, cleaned_data=Path('data/processed/cleaned_data.csv'), 21 | 
arrays=Path('data/processed/arrays')): 22 | 23 | self.arrays_path = arrays 24 | if not self.arrays_path.exists(): 25 | self.arrays_path.mkdir() 26 | self.cleaned_data_path = cleaned_data 27 | 28 | def readfile(self): 29 | return pd.read_csv(self.cleaned_data_path) 30 | 31 | def process(self, test_year=2016): 32 | """Takes the processed data saved by the `preprocessing.Cleaner` class, 33 | and turns it into `np.array`s which can be ingested by the machine learning 34 | models 35 | 36 | Parameters 37 | ---------- 38 | test_year: int, default: 2016 39 | Data from this year will be used for testing, and so will be saved in 40 | seperate arrays 41 | 42 | The following are saved, for both the training and test sets: 43 | 44 | {train, test}/latlon.npy: 45 | The locations of each data instance (so that latlon[i] represents the latitude and 46 | longitude of the ith data point 47 | {train, test}/years.npy: 48 | The years of each data point. Specifically, this represents the prediction year, so if 49 | `pred_month` passed to the Cleaner = 6, then if years[i] = 2015, that means y[i] is the 50 | value of the target in June 2015. 51 | {train, test}/x.npy: 52 | The training data; the previous 11 months of data 53 | {train, test}/y.npy: 54 | The test data - the value of the target variable at `pred_month`. 
55 | """ 56 | data = self.readfile() 57 | 58 | # outputs 59 | latlons, years, vals, targets = [], [], [], [] 60 | 61 | skipped = 0 62 | # first, groupby lat, lon, so that we process the same place together 63 | for latlon, group in data.groupby(by=['lat', 'lon']): 64 | latlon_np = np.array(latlon) 65 | for year, subgroup in group.groupby(by='gb_year'): 66 | if len(subgroup) != 12: 67 | # print(f'Ignoring data from {year} at {latlon} due to missing rows') 68 | skipped += 1 69 | continue 70 | subgroup = subgroup.sort_values(by='gb_month', ascending=True) 71 | 72 | # create a np.array of the features (VALUE_COLS) and the target 73 | x = subgroup[:-1][VALUE_COLS].values 74 | y = subgroup.iloc[-1]['target'] 75 | 76 | # create lists of np.arrays 77 | latlons.append(latlon_np) 78 | years.append(year) 79 | vals.append(x) 80 | targets.append(y) 81 | 82 | if len(latlons) % 1000 == 0: 83 | print(f'Processed {len(latlons)} examples') 84 | 85 | print(f'Done processing {len(latlons)} pixel-years! Skipped {skipped} pixel-years due to missing rows') 86 | 87 | # turn everything into np arrays for manipulation 88 | latlons, years, vals, targets = np.vstack(latlons), np.array(years), np.stack(vals), np.array(targets) 89 | 90 | # split into train and test sets 91 | test_idx = np.where(years == test_year)[0] 92 | train_idx = np.where(years < test_year)[0] 93 | 94 | test_arrays = self.arrays_path / 'test' 95 | train_arrays = self.arrays_path / 'train' 96 | 97 | test_arrays.mkdir(parents=True, exist_ok=True) 98 | train_arrays.mkdir(exist_ok=True) 99 | 100 | print('Saving data') 101 | np.save(train_arrays / 'latlon.npy', latlons[train_idx]) 102 | np.save(train_arrays / 'years.npy', years[train_idx]) 103 | np.save(train_arrays / 'x.npy', vals[train_idx]) 104 | np.save(train_arrays / 'y.npy', targets[train_idx]) 105 | 106 | np.save(test_arrays / 'latlon.npy', latlons[test_idx]) 107 | np.save(test_arrays / 'years.npy', years[test_idx]) 108 | np.save(test_arrays / 'x.npy', vals[test_idx]) 109 
| np.save(test_arrays / 'y.npy', targets[test_idx]) 110 | -------------------------------------------------------------------------------- /data/preprocessing/cleaner.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import xarray as xr 3 | import numpy as np 4 | 5 | import ipdb 6 | import warnings 7 | import os 8 | 9 | from preprocessing.utils import ( 10 | gdal_reproject, 11 | bands_to_time, 12 | convert_to_same_grid, 13 | select_same_time_slice, 14 | save_netcdf, 15 | get_holaps_mask, 16 | merge_data_arrays, 17 | ) 18 | 19 | 20 | # ------------------------------------------------------------------------------ 21 | # Base cleaner 22 | # ------------------------------------------------------------------------------ 23 | 24 | 25 | class Cleaner: 26 | """Base class for preprocessing the input data. 27 | 28 | Tasks include: 29 | - Reprojecting 30 | - Putting datasets onto a consistent spatial grid (spatial resolution) 31 | - Converting to equivalent units 32 | - Converting to the same temporal resolution 33 | - Selecting the same time slice 34 | 35 | Design Considerations: 36 | - Have an attribute, clean_data', that is constantly updated 37 | - Keep a copy of the raw data for reference 38 | - Update the 'clean_data' each time a transformation is applied 39 | """ 40 | 41 | def __init__(self, data_path): 42 | self.data_path = Path(data_path) 43 | 44 | # open the datasets using xarray 45 | self.raw_data = xr.open_dataset(self.data_path) 46 | 47 | # start with clean data as a copy of the raw data 48 | self.clean_data = self.raw_data.copy() 49 | 50 | 51 | def get_mask(self): 52 | self.mask = get_holaps_mask(self.reference_ds) 53 | 54 | 55 | def update_clean_data(self, clean_data, msg=""): 56 | """ """ 57 | self.clean_data = clean_data 58 | print("***** self.clean_data Updated: ", msg, " *****") 59 | 60 | return 61 | 62 | def correct_time_slice(self): 63 | """select the same time slice as the reference 
data""" 64 | assert ( 65 | self.reference_ds is not None 66 | ), "self.reference_ds does not exist! Likely because you're not using the MODIS or GLEAM cleaners / correct data paths" 67 | correct_time_slice = select_same_time_slice(self.reference_ds, self.clean_data) 68 | 69 | self.update_clean_data( 70 | correct_time_slice, msg="Selected the same time slice as reference data" 71 | ) 72 | return 73 | 74 | def resample_time(self, resample_str="M"): 75 | """ should resample to the given timestep """ 76 | resampled_time_data = self.clean_data.resample(time=resample_str).first() 77 | self.update_clean_data(resampled_time_data, msg="Resampled time ") 78 | 79 | return 80 | 81 | def regrid_to_reference(self, method="nearest_s2d"): 82 | """ regrid data (spatially) onto the same grid as referebce data """ 83 | assert ( 84 | self.reference_ds is not None 85 | ), "self.reference_ds does not exist! Likely because you're not using the MODIS or GLEAM cleaners / correct data paths" 86 | 87 | regrid_data = convert_to_same_grid( 88 | self.reference_ds, self.clean_data, method=method 89 | ) 90 | # UPDATE THE SELF.CLEAN_DATA 91 | self.update_clean_data(regrid_data, msg="Data Regridded to same as HOLAPS") 92 | return 93 | 94 | def use_reference_mask(self, one_time=False): 95 | assert not 'units' in self.mask.coords, "MUST NOT HAVE EXTRA COORDS or you remove ALL values. self.mask has 'units' coord and needs to be dropped:\n self.mask = self.mask.drop('units')" 96 | assert ( 97 | self.reference_ds is not None 98 | ), "self.reference_ds does not exist! Likely because you're not using the MODIS or GLEAM cleaners / correct data paths" 99 | assert ( 100 | self.mask is not None 101 | ), "self.mask does not exist! Likely because you're not using the MODIS or GLEAM cleaners / correct data paths" 102 | 103 | # if only one timestep (e.g. 
landcover) then convert to one time 104 | if one_time: 105 | self.mask = self.mask.isel(time=0) 106 | 107 | masked_d = self.clean_data.where(~self.mask.values) 108 | self.update_clean_data(masked_d, msg="Copied the mask from reference to dataset!") 109 | return 110 | 111 | def mask_illegitimate_values(self): 112 | # mask out the missing values (coded as something else) 113 | return NotImplementedError 114 | 115 | def convert_units(self): 116 | """ convert to the equivalent units """ 117 | raise NotImplementedError 118 | 119 | def rename_xr_object(self, name): 120 | renamed_data = self.clean_data.rename(name) 121 | self.update_clean_data(renamed_data, msg=f"Data renamed {name}") 122 | return 123 | 124 | def rename_lat_lon(self): 125 | rename_latlon = self.clean_data.rename({"longitude": "lon", "latitude": "lat"}) 126 | self.update_clean_data(rename_latlon, msg="Renamed latitude,longitude => lat,lon") 127 | return 128 | 129 | def preprocess(self): 130 | """ The preprocessing steps (relatively unique for each dtype) """ 131 | raise NotImplementedError 132 | -------------------------------------------------------------------------------- /predictor/models/neural_networks/nn_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch.utils.data import TensorDataset, DataLoader, random_split 4 | 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | from ..base import ModelBase, DataTuple 10 | 11 | 12 | class NNBase(ModelBase): 13 | """The base for neural networks models 14 | """ 15 | 16 | def __init__(self, model, data=Path('data'), arrays=Path('data/processed/arrays'), 17 | hide_vegetation=False): 18 | super().__init__(data, arrays, hide_vegetation) 19 | 20 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 21 | # for reproducability 22 | torch.manual_seed(42) 23 | torch.cuda.manual_seed_all(42) 24 | 25 | if 
torch.cuda.is_available(): 26 | model = model.cuda() 27 | self.model = model 28 | 29 | def train(self, num_epochs=100, patience=3, batch_size=32, learning_rate=1e-3): 30 | """Train the neural network 31 | 32 | Parameters 33 | ---------- 34 | num_epochs: int 35 | The maximum number of epochs to train the model for 36 | patience: int 37 | If no improvement is seen in the validation set for `patience` epochs, 38 | training is stopped to prevent overfitting 39 | batch_size: int 40 | The batch size to use 41 | learning_rate: float 42 | The learning rate to use when updating model parameters 43 | """ 44 | train_data = self.load_tensors(mode='train') 45 | 46 | # split the data into a training and validation set 47 | total_size = train_data.x.shape[0] 48 | val_size = total_size // 10 # 10 % for validation 49 | train_size = total_size - val_size 50 | print(f'After split, training on {train_size} examples, ' 51 | f'validating on {val_size} examples') 52 | train_dataset, val_dataset = random_split(TensorDataset(train_data.x, train_data.y.unsqueeze(1)), 53 | (train_size, val_size)) 54 | 55 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 56 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size) 57 | 58 | optimizer = torch.optim.Adam([pam for pam in self.model.parameters()], 59 | lr=learning_rate) 60 | 61 | epochs_without_improvement = 0 62 | best_loss = np.inf 63 | 64 | for epoch in range(num_epochs): 65 | self.model.train() 66 | 67 | epoch_train_loss = 0 68 | num_train_batches = 0 69 | 70 | epoch_val_loss = 0 71 | num_val_batches = 0 72 | 73 | for train_x, train_y in tqdm(train_dataloader): 74 | optimizer.zero_grad() 75 | pred_y = self.model(train_x) 76 | 77 | loss = F.smooth_l1_loss(pred_y, train_y) 78 | loss.backward() 79 | optimizer.step() 80 | 81 | num_train_batches += 1 82 | epoch_train_loss += loss.sqrt().item() 83 | 84 | self.model.eval() 85 | with torch.no_grad(): 86 | for val_x, val_y in tqdm(val_dataloader): 87 | 
val_pred_y = self.model(val_x) 88 | val_loss = F.mse_loss(val_pred_y, val_y) 89 | 90 | num_val_batches += 1 91 | epoch_val_loss += val_loss.sqrt().item() 92 | 93 | epoch_train_loss /= num_train_batches 94 | epoch_val_loss /= num_val_batches 95 | 96 | print(f'Epoch {epoch} - Training RMSE: {epoch_train_loss}, ' 97 | f'Validation RMSE: {epoch_val_loss}') 98 | 99 | if epoch_val_loss < best_loss: 100 | best_state = self.model.state_dict() 101 | best_loss = epoch_val_loss 102 | 103 | epochs_without_improvement = 0 104 | else: 105 | epochs_without_improvement += 1 106 | if epochs_without_improvement == patience: 107 | self.model.load_state_dict(best_state) 108 | print('Early stopping!') 109 | return 110 | self.model.load_state_dict(best_state) 111 | 112 | def predict(self, batch_size=64): 113 | test_data = self.load_tensors(mode='test') 114 | 115 | test_dataloader = DataLoader(TensorDataset(test_data.x, test_data.y), 116 | batch_size=batch_size) 117 | 118 | output_preds, output_true = [], [] 119 | 120 | self.model.eval() 121 | with torch.no_grad(): 122 | for test_x, test_y in tqdm(test_dataloader): 123 | output_preds.append(self.model(test_x).squeeze(1).cpu().numpy()) 124 | output_true.append(test_y.cpu().numpy()) 125 | return np.concatenate(output_true), np.concatenate(output_preds) 126 | 127 | def load_tensors(self, mode='train'): 128 | data = self.load_arrays(mode) 129 | 130 | return DataTuple( 131 | latlon=data.latlon, 132 | years=data.years, 133 | x=torch.as_tensor(data.x, device=self.device).float(), 134 | y=torch.as_tensor(data.y, device=self.device).float()) 135 | 136 | def save_model(self, name="model.pt"): 137 | """save the model's state_dict""" 138 | savedir = self.data_path / self.model_name 139 | if not savedir.exists(): savedir.mkdir() 140 | 141 | # save with .pt extension 142 | if not '.pt' in name: name += '.pt' 143 | print(f'Saving predictions to {savedir / name}') 144 | torch.save(self.model, savedir / name) 145 | 
-------------------------------------------------------------------------------- /predictor/models/neural_networks/recurrent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import math 5 | from pathlib import Path 6 | 7 | from .nn_base import NNBase 8 | from ...preprocessing import VALUE_COLS, VEGETATION_LABELS 9 | 10 | 11 | class Recurrent(NNBase): 12 | """A simple feedforward neural network 13 | """ 14 | 15 | def __init__(self, data=Path('data'), arrays=Path('data/processed/arrays'), 16 | hide_vegetation=False): 17 | 18 | features_per_month = len(VALUE_COLS) 19 | if hide_vegetation: 20 | features_per_month -= len(VEGETATION_LABELS) 21 | 22 | super().__init__(RNN(features_per_month, [256]), 23 | data, arrays, hide_vegetation) 24 | self.model_name = "recurrent" 25 | 26 | class RNN(nn.Module): 27 | """ 28 | A crop yield conv net. 29 | For a description of the parameters, see the RNNModel class. 30 | """ 31 | def __init__(self, features_per_month, dense_features, hidden_size=128, 32 | rnn_dropout=0.25): 33 | super().__init__() 34 | 35 | dense_features.insert(0, hidden_size) 36 | if dense_features[-1] != 1: 37 | dense_features.append(1) 38 | 39 | self.dropout = nn.Dropout(rnn_dropout) 40 | self.rnn = UnrolledRNN(input_size=features_per_month, 41 | hidden_size=hidden_size, 42 | batch_first=True) 43 | self.hidden_size = hidden_size 44 | 45 | self.dense_layers = nn.ModuleList([ 46 | nn.Linear(in_features=dense_features[i-1], 47 | out_features=dense_features[i]) 48 | for i in range(1, len(dense_features)) 49 | ]) 50 | 51 | self.initialize_weights() 52 | 53 | def initialize_weights(self): 54 | 55 | sqrt_k = math.sqrt(1 / self.hidden_size) 56 | for parameters in self.rnn.parameters(): 57 | for pam in parameters: 58 | nn.init.uniform_(pam.data, -sqrt_k, sqrt_k) 59 | 60 | for dense_layer in self.dense_layers: 61 | nn.init.kaiming_uniform_(dense_layer.weight.data) 62 | 
nn.init.constant_(dense_layer.bias.data, 0) 63 | 64 | def forward(self, x): 65 | """ 66 | If return_last_dense is true, the feature vector generated by the second to last 67 | dense layer will also be returned. This is then used to train a Gaussian Process model. 68 | """ 69 | 70 | sequence_length = x.shape[1] 71 | 72 | hidden_state = torch.zeros(1, x.shape[0], self.hidden_size) 73 | cell_state = torch.zeros(1, x.shape[0], self.hidden_size) 74 | 75 | if x.is_cuda: 76 | hidden_state = hidden_state.cuda() 77 | cell_state = cell_state.cuda() 78 | 79 | for i in range(sequence_length): 80 | # The reason the RNN is unrolled here is to apply dropout to each timestep; 81 | # The rnn_dropout argument only applies it after each layer. 82 | # https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/DropoutWrapper 83 | input_x = x[:, i, :].unsqueeze(1) 84 | _, (hidden_state, cell_state) = self.rnn(input_x, 85 | (hidden_state, cell_state)) 86 | hidden_state = self.dropout(hidden_state) 87 | 88 | x = hidden_state.squeeze(0) 89 | for layer_number, dense_layer in enumerate(self.dense_layers): 90 | x = dense_layer(x) 91 | return x 92 | 93 | 94 | class UnrolledRNN(nn.Module): 95 | """An unrolled RNN. The motivation for this is mainly so that we can explain this model using 96 | the shap deep explainer, but also because we unroll the RNN anyway to apply dropout. 
97 | """ 98 | 99 | def __init__(self, input_size, hidden_size, batch_first=True): 100 | super().__init__() 101 | 102 | self.input_size = input_size 103 | self.hidden_size = hidden_size 104 | self.batch_first = batch_first 105 | 106 | self.forget_gate = nn.Sequential(*[ 107 | nn.Linear(in_features=input_size + hidden_size, out_features=hidden_size, 108 | bias=True), nn.Sigmoid()]) 109 | 110 | self.update_gate = nn.Sequential(*[ 111 | nn.Linear(in_features=input_size + hidden_size, out_features=hidden_size, 112 | bias=True), nn.Sigmoid() 113 | ]) 114 | 115 | self.update_candidates = nn.Sequential(*[ 116 | nn.Linear(in_features=input_size + hidden_size, out_features=hidden_size, 117 | bias=True), nn.Tanh() 118 | ]) 119 | 120 | self.output_gate = nn.Sequential(*[ 121 | nn.Linear(in_features=input_size + hidden_size, out_features=hidden_size, 122 | bias=True), nn.Sigmoid() 123 | ]) 124 | 125 | self.cell_state_activation = nn.Tanh() 126 | 127 | def forward(self, x, state): 128 | hidden, cell = state 129 | 130 | if self.batch_first: 131 | hidden, cell = torch.transpose(hidden, 0, 1), torch.transpose(cell, 0, 1) 132 | 133 | forget_state = self.forget_gate(torch.cat((x, hidden), dim=-1)) 134 | update_state = self.update_gate(torch.cat((x, hidden), dim=-1)) 135 | cell_candidates = self.update_candidates(torch.cat((x, hidden), dim=-1)) 136 | 137 | updated_cell = (forget_state * cell) + (update_state * cell_candidates) 138 | 139 | output_state = self.output_gate(torch.cat((x, hidden), dim=-1)) 140 | updated_hidden = output_state * self.cell_state_activation(updated_cell) 141 | 142 | if self.batch_first: 143 | updated_hidden = torch.transpose(updated_hidden, 0, 1) 144 | updated_cell = torch.transpose(updated_cell, 0, 1) 145 | 146 | return updated_hidden, (updated_hidden, updated_cell) 147 | -------------------------------------------------------------------------------- /predictor/preprocessing.py: -------------------------------------------------------------------------------- 
1 | import pandas as pd 2 | from pathlib import Path 3 | import xarray as xr 4 | import numpy as np 5 | import json 6 | 7 | KEY_COLS = ['lat', 'lon', 'time', 'gb_year', 'gb_month'] 8 | VALUE_COLS = ['lst_night', 'lst_day', 'precip', 'sm', 'ndvi', 'evi', 'ndvi_anomaly'] 9 | VEGETATION_LABELS = ['ndvi', 'evi', 'ndvi_anomaly'] 10 | 11 | 12 | class Cleaner: 13 | """Clean the input data (from the .nc file), by removing nan 14 | (or encoded-as-nan) values, and normalising the non-key values. 15 | 16 | Does some preprocessing on the .nc file using xarray and then converts 17 | it to a dataframe for the other methods 18 | 19 | Attributes: 20 | ---------- 21 | raw_filepath: pathlib.Path 22 | The location of the raw .NC data 23 | processed_filepath: pathlib.Path 24 | The location where the processed csv will be saved. The normalizing dict 25 | will be saved in the same directory. 26 | """ 27 | def __init__(self, raw_filepath=Path('data/raw/OUT.NC'), 28 | processed_filepath=Path('data/processed/cleaned_data.csv')): 29 | 30 | self.filepath = raw_filepath 31 | 32 | if not processed_filepath.parents[0].exists(): 33 | processed_filepath.parents[0].mkdir() 34 | 35 | self.processed_filepath = processed_filepath 36 | self.normalizing_dict = processed_filepath.parents[0] / 'normalizing_dict.json' 37 | 38 | def process(self, pred_month=6, target='ndvi_anomaly'): 39 | """Preprocesses the raw data, and saves it. Specifically, the following 40 | preprocessing happens: 41 | 1. `gb_year` and `gb_month`, which are the dates relative to 42 | `pred_month`, are added. 43 | 2. `ndvi_anomaly` is calculated 44 | 3. NaN values (and missing data) is removed from the dataframe. 45 | 4. Normalizes all values to have mean 0 and std 1 46 | 47 | Parameters 48 | ---------- 49 | pred_month: int 50 | The month for which the target value should be predicted. 
This value will be 51 | predicted using the preceding 11 months of data 52 | target: str 53 | The target variable being predicted 54 | 55 | A processed CSV and a .json object containing the values used to normalize 56 | each variable are saved. 57 | """ 58 | 59 | assert target in VALUE_COLS, f'f{target} not in {VALUE_COLS}' 60 | 61 | data = self._readfile(pred_month) 62 | 63 | data['target'] = data[target] 64 | 65 | normalizing_dict = {} 66 | for col in VALUE_COLS: 67 | print(f'Normalizing {col}') 68 | 69 | series = data[col] 70 | 71 | # calculate normalizing values 72 | mean, std = series.mean(), series.std() 73 | # add them to the dictionary 74 | normalizing_dict[col] = { 75 | 'mean': float(mean), 'std': float(std), 76 | } 77 | 78 | data[col] = (series - mean) / std 79 | 80 | print("Saving data") 81 | data.to_csv(self.processed_filepath, index=False) 82 | print(f'Saved {self.processed_filepath}') 83 | 84 | with open(self.normalizing_dict, 'w') as f: 85 | json.dump(normalizing_dict, f) 86 | 87 | print(f'Saved {self.normalizing_dict}') 88 | 89 | def _readfile(self, pred_month): 90 | # drop any Pixel-Times with missing values 91 | data = xr.open_dataset(self.filepath) 92 | 93 | data['gb_month'], data['gb_year'] = self._update_year_month( 94 | pd.to_datetime(data.time.to_series()), pred_month) 95 | 96 | # mask out the invalid temperature values 97 | lst_cols = ['lst_night', 'lst_day'] 98 | for var_ in lst_cols: 99 | # for the lst_cols, missing data is coded as 200 100 | valid = (data[var_] < 200) | (np.isnan(data[var_])) 101 | data[var_] = data[var_].fillna(np.nan).where(valid) 102 | 103 | return_cols = KEY_COLS + VALUE_COLS 104 | 105 | # compute the ndvi_anomaly 106 | data['ndvi_anomaly'] = self._compute_anomaly(data.ndvi) 107 | 108 | data = data.to_dataframe().reset_index() 109 | data.dropna(how='any', axis=0, inplace=True) 110 | 111 | print(f'Loaded {len(data)} rows!') 112 | return data[return_cols] 113 | 114 | @staticmethod 115 | def _update_year_month(times, 
pred_month): 116 | """Given a pred year (e.g. 6), this method will return two new series with 117 | updated years and months so that a "year" of data will be the 11 months preceding the 118 | pred_month, and the pred_month. This makes it easier for the engineer to then make the training 119 | data 120 | """ 121 | if pred_month == 12: 122 | return times.dt.month, times.dt.year 123 | 124 | relative_times = times - pd.DateOffset(months=pred_month) 125 | 126 | # we add one year so that the year column the engineer makes will be reflective 127 | # of the pred year, which is shifted because of the data offset we used 128 | return relative_times.dt.month, relative_times.dt.year + 1 129 | 130 | @staticmethod 131 | def _compute_anomaly(da, time_group='time.month'): 132 | """ Return a dataarray where values are an anomaly from the MEAN for that 133 | location at a given timestep. Defaults to finding monthly anomalies. 134 | Notes: http://xarray.pydata.org/en/stable/examples/weather-data.html#calculate-monthly-anomalies 135 | 136 | In addition, since 2016 is being used as the prediction year, data from that year 137 | is not being used to compute the mean. 138 | 139 | Arguments: 140 | --------- 141 | : da (xr.DataArray) 142 | : time_group (str) 143 | time string to group. 144 | """ 145 | print('Computing ndvi anomaly') 146 | assert isinstance(da, xr.DataArray), f"`da` should be of type `xr.DataArray`. 
Currently: {type(da)}" 147 | trimmed_da = da[da['time.year'] < 2016] 148 | mthly_vals = trimmed_da.groupby(time_group).mean('time') 149 | da = da.groupby(time_group) - mthly_vals 150 | 151 | return da 152 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import xarray as xr 2 | import pandas as pd 3 | import numpy as np 4 | from pathlib import Path 5 | import os 6 | import warnings 7 | 8 | # invert the upside down variables 9 | def invert(da): 10 | """ given a dataarray invert the latitude values so that the plots are the right way up """ 11 | da.values = da.values[:,::-1,:] 12 | assert da.name, f"the dataarray passed to function invert() must have a name!" 13 | print(f'inverted the values of {da.name}') 14 | return da 15 | 16 | 17 | def get_lc_mask(ds, mask_var): 18 | assert mask_var in ["spi", "spei", "ndvi"], f"Mask Variable should be one of spi spei ndvi (ADD MORE)" 19 | # create a land cover mask 20 | warnings.warn('Currently get_lc_mask() is working with hardcoded SPI values. May cause problems') 21 | lc_mask = ~ds['spi'].isel(time=1).isnull() 22 | lc_mask.name = "lc_mask" 23 | print("created lc_mask") 24 | return lc_mask 25 | 26 | 27 | def remove_excess_parameters(ds): 28 | """ select only some of the variables (NOTE: poorly programmed because hardcoded) 29 | TODO: remove hardcoded params 30 | """ 31 | key_parameters = ["lst_mean","lst_day","lst_night","sm","precip","evaporation","transpiration","spi","spei","ndvi","evi","surface_soil_moisture","rootzone_soil_moisture","baresoil_evaporation","potential_evaporation"] 32 | # check that all the parameters exist in the dataset 33 | for param in key_parameters: 34 | assert param in [var for var in ds.variables.keys()], f"Param {param} not found!" 
35 | ds = ds[key_parameters] 36 | print(f"Keeping parameters: {key_parameters}") 37 | return ds 38 | 39 | 40 | def mask_sea(ds, lc_mask): 41 | """ MASK THE SEA VALUES (select only where lc_mask == 1)""" 42 | ds = ds.where(lc_mask) 43 | print(f"Sea values masked for ds with vars {[var for var in ds.data_vars.keys()]}") 44 | return ds 45 | 46 | 47 | def get_boolean_drought_ds(ds): 48 | """ extract the drought indices from the data array """ 49 | assert "drought_ndvi" in [var for var in ds.variables.keys()], f"drought_ndvi should be in {[var for var in ds.variables.keys()]}" 50 | assert "drought_spei" in [var for var in ds.variables.keys()], f"drought_spei should be in {[var for var in ds.variables.keys()]}" 51 | assert "drought_spi" in [var for var in ds.variables.keys()], f"drought_spi should be in {[var for var in ds.variables.keys()]}" 52 | drought = ds[['drought_spei','drought_spi','drought_ndvi']] 53 | print("Created a drought index xr.Dataset") 54 | return drought 55 | 56 | 57 | def save_netcdf(output_ds, filename): 58 | """ save the dataset""" 59 | output_ds.to_netcdf(filename) 60 | print(f"{filename} saved!") 61 | return 62 | 63 | 64 | def clean_lst_variables(ds): 65 | """""" 66 | lst_vars = [var_ for var_ in ds.data_vars.keys() if "lst" in var_] 67 | for lst_var in lst_vars: 68 | # filter OUT the lst values of 200 69 | valid = ds[lst_var] < 200 70 | ds[lst_var] = ds[lst_var].fillna(np.nan).where(valid) 71 | return ds 72 | 73 | 74 | def clean_data(ds, lc_mask): 75 | # drop the final timestep 76 | ds = ds.isel(time=slice(0,-1)) 77 | 78 | # invert precip 79 | ds['precip'] = invert(ds.precip) 80 | 81 | # get only the important parameters 82 | ds = remove_excess_parameters(ds) 83 | 84 | # mask out the sea 85 | ds = mask_sea(ds, lc_mask) 86 | 87 | # clean lst variables 88 | ds = clean_lst_variables(ds) 89 | 90 | return ds 91 | 92 | 93 | def get_df_of_pixels_to_remove(lc_mask): 94 | """ return a dataframe of the pixels that correspond to SEA points (and therefore 
that we want to remove) 95 | """ 96 | mask_df = lc_mask.to_dataframe() 97 | mask_df = mask_df.reset_index().drop(columns=["time"]) 98 | indexes_to_remove = mask_df.where(~mask_df.lc_mask).dropna() 99 | 100 | return indexes_to_remove 101 | 102 | 103 | 104 | def shift_by_time(ds): 105 | """ shift dataset by number of timesteps (so if shift by 3 you get )""" 106 | return ds.shift(ts) 107 | 108 | 109 | 110 | def mask_ds(ds, drought_mask, drought=True): 111 | """set all of the pixels that are SEA or DROUGHT to NAN""" 112 | print("drought pixels masked") 113 | 114 | return ds 115 | 116 | 117 | def make_masks_boolean(mask_ds): 118 | """ convert masks from 0,1 to False,True """ 119 | try: 120 | print(f"Convert mask to bool for ds with vars {[var for var in mask_ds.data_vars.keys()]}") 121 | mask_ds = mask_ds.astype(bool) 122 | except: 123 | try: 124 | print(f"UNABLE to convert mask to bool for ds with vars {[var for var in mask_ds.data_vars.keys()]}") 125 | except: # is a data array and it doesn't have .data_vars() 126 | print(f"UNABLE to convert mask to bool for ds with vars {mask_ds.name}") 127 | return mask_ds 128 | 129 | 130 | def read_data(data_path='.', mask_var='spi'): 131 | """ """ 132 | assert os.path.isfile(data_path), f"The path provided to read data does not exist! Currently: {data_path}" 133 | # data_path = "/Volumes/Lees_Extend/data/ea_data/all_variables_LC2.nc" 134 | print(f"Reading from file: {data_path}") 135 | ds = xr.open_dataset(data_path) 136 | lc_mask = get_lc_mask(ds, mask_var) 137 | # -------------------------------------------------------------------------- 138 | # OFFENDING LINE 139 | warnings.warn('drought_ndvi hardcoded in here. 
Not at all okay.gst FIX ME') 140 | ds['drought_ndvi'] = ds.ndvi < (ds.ndvi.mean(dim='time') - ds.ndvi.std(dim='time')) 141 | # -------------------------------------------------------------------------- 142 | drought_mask = get_boolean_drought_ds(ds) 143 | 144 | # REMOVE THE SEA VALUES from drought mask 145 | drought_mask = mask_sea(drought_mask, lc_mask) 146 | drought_mask = make_masks_boolean(drought_mask) 147 | lc_mask = make_masks_boolean(lc_mask) 148 | 149 | ds = clean_data(ds, lc_mask) 150 | 151 | return ds, lc_mask, drought_mask 152 | 153 | 154 | def print_shift_explanation(ds): 155 | """ print statements to explain the differences with the shift operator """ 156 | print("*** UNSHIFTED TIME ***") 157 | print("time=0\n", ds.isel(lat=slice(0,2), lon=slice(0,2),time=0).precip.values) 158 | print("time=1\n", ds.isel(lat=slice(0,2), lon=slice(0,2),time=1).precip.values) 159 | print("time=2\n",ds.isel(lat=slice(0,2), lon=slice(0,2), time=2).precip.values) 160 | print("time=3\n",ds.isel(lat=slice(0,2), lon=slice(0,2), time=3).precip.values) 161 | print() 162 | print("*** SHIFTED TIME (+1) = moving the HISTORICAL TIMESTEPS FORWARD to the PRESENT ***") 163 | print("time=0\n",ds.isel(lat=slice(0,2), lon=slice(0,2)).shift(time=1).isel(time=0).precip.values) 164 | print("time=1\n",ds.isel(lat=slice(0,2), lon=slice(0,2)).shift(time=1).isel(time=1).precip.values) 165 | print("time=2\n",ds.isel(lat=slice(0,2), lon=slice(0,2)).shift(time=1).isel(time=2).precip.values) 166 | print("time=3\n",ds.isel(lat=slice(0,2), lon=slice(0,2)).shift(time=1).isel(time=3).precip.values) 167 | print() 168 | print("*** SHIFTED TIME (-1) = moving the PRESENT TIMESTEPS BACKWARD to the PAST ***") 169 | print("time=0\n",ds.isel(lat=slice(0,2), lon=slice(0,2)).shift(time=-1).isel(time=0).precip.values) 170 | print("time=1\n",ds.isel(lat=slice(0,2), lon=slice(0,2)).shift(time=-1).isel(time=1).precip.values) 171 | print("time=2\n",ds.isel(lat=slice(0,2), 
lon=slice(0,2)).shift(time=-1).isel(time=2).precip.values) 172 | print("time=3\n",ds.isel(lat=slice(0,2), lon=slice(0,2)).shift(time=-1).isel(time=3).precip.values) 173 | print() 174 | 175 | # ---------------------------------------------------------------------- 176 | def create_lc_mask(ds): 177 | """ from the dataset create a landcover mask """ 178 | # create a land cover mask 179 | lc_mask = ~ds.spi.isel(time=1).isnull() 180 | lc_mask.name = "lc_mask" 181 | 182 | # create df lc mask 183 | mask_df = lc_mask.to_dataframe() 184 | mask_df = mask_df.reset_index().drop(columns=["time"]) 185 | 186 | # get df of SEA pixels (pixels to remove) 187 | indexes_to_remove = mask_df.where(~mask_df.lc_mask).dropna() 188 | 189 | return lc_mask, mask_df, indexes_to_remove 190 | 191 | 192 | def drop_nans_and_flatten(dataArray): 193 | """flatten the array and drop nans from that array. Useful for plotting histograms. 194 | 195 | Arguments: 196 | --------- 197 | : dataArray (xr.DataArray) 198 | the DataArray of your value you want to flatten 199 | """ 200 | # drop NaNs and flatten 201 | return dataArray.values[~np.isnan(dataArray.values)] 202 | -------------------------------------------------------------------------------- /notebooks/05_tl_explore_models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "from itertools import chain\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "\n", 14 | "import json\n", 15 | "import sys\n", 16 | "sys.path.insert(0, '..')\n", 17 | "\n", 18 | "from predictor.models import LinearModel\n", 19 | "from predictor.preprocessing import VALUE_COLS, VEGETATION_LABELS\n", 20 | "from predictor.models import nn_FeedForward\n", 21 | "from predictor.analysis import plot_shap_values" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | 
"execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "with open('../data/processed/normalizing_dict.json', 'r') as f:\n", 31 | " normalizing_dict = json.load(f)\n", 32 | " \n", 33 | "path_to_arrays = Path('../data/processed/arrays')" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 9, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "" 45 | ] 46 | }, 47 | "execution_count": 9, 48 | "metadata": {}, 49 | "output_type": "execute_result" 50 | } 51 | ], 52 | "source": [ 53 | "model = nn_FeedForward(path_to_arrays, hide_vegetation=True)\n", 54 | "model" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 10, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Training model without vegetation features\n" 67 | ] 68 | }, 69 | { 70 | "name": "stderr", 71 | "output_type": "stream", 72 | "text": [ 73 | " 0%| | 0/8175 [00:00 3, "Need timestep to be greater than 3 because want the PREVIOUS 3 months. Otherwise all nans!" 84 | assert index in ['spei','spi', 'ndvi'], f"index must be either ['spei', 'spi', 'ndvi']. 
Currently: {index}" 85 | 86 | # get the drought mask at that timestep (NOTE have to invert BEFORE this) 87 | if index=='spi': 88 | msk = drought_mask.drought_spi.isel(time=ts) 89 | elif index=='spei': 90 | msk = drought_mask.drought_spei.isel(time=ts) 91 | elif index=='ndvi': 92 | msk = drought_mask.drought_ndvi.isel(time=ts) 93 | 94 | # if want pixels that were IN DROUGHT 95 | if drought: 96 | t0 = ds.where(msk).isel(time=ts) 97 | t1 = ds.where(msk).isel(time=ts-1) 98 | t2 = ds.where(msk).isel(time=ts-2) 99 | t3 = ds.where(msk).isel(time=ts-3) 100 | 101 | # if want pixels that are NOT IN DROUGHT 102 | else: 103 | t0 = ds.where(msk).isel(time=ts) 104 | t1 = ds.where(msk).isel(time=ts-1) 105 | t2 = ds.where(msk).isel(time=ts-2) 106 | t3 = ds.where(msk).isel(time=ts-3) 107 | 108 | return [t0, t1, t2, t3] 109 | 110 | 111 | def merge_all_ts_into_one_ds(ds_arr): 112 | """merge all of the timesteps (t=0, t-1, t-2, t-3) into one dataset object 113 | 114 | input: 115 | : array of all the 116 | returns: 117 | : ds_out (xr.Dataset): one dataset with all of the variables 118 | """ 119 | # assert that the minimum time is the first dataset in the array 120 | time = ds_arr[0].time 121 | ds_rnm = [] 122 | 123 | for ts, ds_ in enumerate(ds_arr): 124 | # create dict of variables to rename 125 | map_names = dict(zip([var for var in ds_.variables.keys() if var not in ['time','lat','lon']], 126 | [f"{var}_t{ts}" for var in ds_.variables.keys() if var not in ['time','lat','lon']]) 127 | ) 128 | # rename the variables 129 | ds_ = ds_.rename(map_names) 130 | ds_rnm.append(ds_) 131 | 132 | # drop the 'time' (TO ALLOW THE MERGE) 133 | ds_rnm = [ds_.drop('time') for ds_ in ds_rnm] 134 | 135 | # merge the variables 136 | ds_out = xr.merge(ds_rnm) 137 | 138 | # reassign 'Time' from the first dataset 139 | ds_out = ds_out.assign_coords(**{'time':time}) 140 | 141 | return ds_out 142 | 143 | 144 | def calculate_drought_masked_ds(ds, drought_mask, ts): 145 | """ run the above functions on EVERY 
timestep 146 | so output a dataset with each timestep a calculation of the previous months 147 | """ 148 | print(f"Extracting Drought vars from timestep {ts}") 149 | ds_arr = mask_drought_events(ds, drought_mask, ts=ts, index='ndvi', drought=True) 150 | ds_drought = merge_all_ts_into_one_ds(ds_arr) 151 | 152 | return ds_drought 153 | 154 | 155 | def calculate_Ndrought_masked_ds(ds, drought_mask, ts): 156 | """ run the above functions on EVERY timestep 157 | so output a dataset with each timestep a calculation of the previous months 158 | """ 159 | print(f"Extracting NON-Drought vars from timestep {ts}") 160 | ds_arr = mask_drought_events(ds, drought_mask, ts=ts, index='ndvi', drought=False) 161 | ds_Ndrought = merge_all_ts_into_one_ds(ds_arr) 162 | 163 | return ds_Ndrought 164 | 165 | 166 | def drought_across_all_timesteps(ds, drought_mask, start_ts=4): 167 | """ run the drought masking process for ALL TIMESTEPS 168 | 169 | Note: the output of Parallel(n_jobs=2)({FNCTN}) is a list of all of the variables 170 | because they are timestamped it doesn't matter if they come back in a different 171 | order. Therefore, we save the reordering to a later date. Future me can deal 172 | with that problem. 
173 | """ 174 | dr = [] 175 | Ndr = [] 176 | 177 | # RUN the drought first 178 | print("Extracting the pixels in DROUGHT") 179 | with Parallel(n_jobs=30, verbose=True) as parallel: 180 | dr = parallel( 181 | delayed(calculate_drought_masked_ds) 182 | (ds=ds,drought_mask=drought_mask,ts=ts) for ts in range(start_ts, ds.time.shape[0]) 183 | ) 184 | print("Pixels in DROUGHT extracted") 185 | 186 | # then run the Ndrought 187 | print("Extracting the pixels NOT in DROUGHT") 188 | # convert to boolean arrays in order to INVERT them 189 | Ndrought_mask = ~(drought_mask.astype(bool)) 190 | with Parallel(n_jobs=30, verbose=True) as parallel: 191 | Ndr = parallel( 192 | delayed(calculate_Ndrought_masked_ds) 193 | (ds=ds,drought_mask=Ndrought_mask,ts=ts) for ts in range(start_ts, ds.time.shape[0]) 194 | ) 195 | print("Pixels NOT in DROUGHT extracted") 196 | 197 | # concatenate the arrays by time into ONE dataset (agree there's lots of duplication of data here) 198 | # TODO: does this really make sense? we are duplicating SOOOOOOO much data and for what? 199 | # definitely a better way. 200 | ds_drought = xr.concat(dr, dim='time') 201 | ds_drought.to_netcdf('ds_drought.nc') 202 | print("DROUGHT Variables extracted. Saving to netcdf ds_drought.nc ...") 203 | 204 | ds_Ndrought = xr.concat(Ndr, dim='time') 205 | ds_Ndrought.to_netcdf('ds_Ndrought.nc') 206 | print("NOT DROUGHT Variables extracted. Saving to netcdf ds_Ndrought.nc ...") 207 | 208 | return ds_drought, ds_Ndrought 209 | 210 | 211 | def run_drought_processing(data_path, mask_var='spi'): 212 | """ 213 | 214 | """ 215 | print("** Running drought processing for all timesteps! 
**") 216 | ds, lc_mask, drought_mask = read_data(data_path, mask_var) 217 | print(f"** Data read in - using {mask_var} as the masking variable **") 218 | ds_drought, ds_Ndrought = drought_across_all_timesteps(ds, drought_mask) 219 | 220 | save_netcdf(ds_drought, "ds_drought.nc") 221 | save_netcdf(ds_Ndrought, "ds_Ndrought.nc") 222 | print("** Process finished **") 223 | 224 | return ds_drought, ds_Ndrought 225 | 226 | 227 | 228 | 229 | 230 | def fix_drought_var(): 231 | """if drought variable not in dataset then append it!""" 232 | pass 233 | 234 | 235 | if __name__=="__main__": 236 | """ TODO: set up as a CLI """ 237 | # data_path = "/soge-home/users/chri4118/EA_data/all_variables_LCMASK.nc" 238 | data_path = "/soge-home/users/chri4118/ea_exploration/OUT.nc" 239 | ds_drought, ds_Ndrought = run_drought_processing(data_path, mask_var='ndvi') 240 | -------------------------------------------------------------------------------- /data/preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | import xarray as xr 2 | import pandas as pd 3 | import xesmf as xe # for regridding 4 | import numpy as np 5 | 6 | import os 7 | from pathlib import Path 8 | import ipdb 9 | import warnings 10 | import datetime 11 | 12 | import geopandas as gpd 13 | from shapely import geometry 14 | 15 | 16 | def read_csv_point_data(df, lat_col='lat', lon_col='lon', crs='epsg:4326'): 17 | """Read in a csv file with lat,lon values in a column and turn those lat lon 18 | values into geometry.Point objects. 
19 | Arguments: 20 | --------- 21 | : df (pd.DataFrame) 22 | : lat_col (str) 23 | the column in the dataframe that has the point latitude information 24 | : lon_col (str) 25 | the column in the dataframe that has the point longitude information 26 | : crs (str) 27 | coordinate reference system (defaults to 'epsg:4326') 28 | Returns: 29 | ------- 30 | : gdf (gpd.GeoDataFrame) 31 | a geopandas.GeoDataFrame object 32 | """ 33 | df['geometry'] = [geometry.Point(y, x) \ 34 | for x, y in zip(df[lat_col], 35 | df[lon_col]) 36 | ] 37 | crs = {'init': crs} 38 | gdf = gpd.GeoDataFrame(df, crs=crs, geometry="geometry") 39 | return gdf 40 | 41 | 42 | # ------------------------------------------------------------------------------ 43 | # Functions for reprojecting using GDAL and reading resulting .nc file back 44 | # ------------------------------------------------------------------------------ 45 | 46 | 47 | def gdal_reproject(infile, outfile, **kwargs): 48 | """Use gdalwarp to reproject one file to another 49 | 50 | Help: 51 | ---- 52 | https://www.gdal.org/gdalwarp.html 53 | """ 54 | to_proj4_string = "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs" 55 | resample_method = "near" 56 | 57 | # check options 58 | valid_resample_methods = [ 59 | "average", 60 | "near", 61 | "bilinear", 62 | "cubic", 63 | "cubicspline", 64 | "lanczos", 65 | "mode", 66 | "max", 67 | "min", 68 | "med", 69 | "q1", 70 | "q3", 71 | ] 72 | assert ( 73 | resample_method in valid_resample_methods 74 | ), f"Resample method not Valid. 
Must be one of: {valid_resample_methods} Currently: {resample_method}" 75 | 76 | cmd = f'gdalwarp -t_srs "{to_proj4_string}" -of netCDF -r average -dstnodata -9999 -ot Float32 {infile} {outfile}' 77 | 78 | # run command 79 | print(f"\n\n#### Running command: {cmd} ####\n\n") 80 | os.system(cmd) 81 | print(f"\n\n#### Run command {cmd} \n FILE REPROJECTED ####\n\n") 82 | 83 | return 84 | 85 | 86 | def bands_to_time(da, times, var_name): 87 | """ For a dataArray with each timestep saved as a different band, create 88 | a time Coordinate 89 | """ 90 | # get a list of all the bands as dataarray objects (for concatenating later) 91 | band_strings = [key for key in da.variables.keys() if "Band" in key] 92 | bands = [da[key] for key in band_strings] 93 | bands = [band.rename(var_name) for band in bands] 94 | 95 | # check the number of bands matches n timesteps 96 | assert len(times) == len( 97 | bands 98 | ), f"The number of bands should match the number of timesteps. n bands: {len(times)} n times: {len(bands)}" 99 | # concatenate into one array 100 | timestamped_da = xr.concat(bands, dim=times) 101 | 102 | return timestamped_da 103 | 104 | 105 | # ------------------------------------------------------------------------------ 106 | # Functions for matching resolutions / gridsizes (time and space) 107 | # ------------------------------------------------------------------------------ 108 | 109 | 110 | def convert_to_same_grid(reference_ds, ds, method="nearest_s2d"): 111 | """ Use xEMSF package to regrid ds to the same grid as reference_ds """ 112 | assert ("lat" in reference_ds.dims) & ( 113 | "lon" in reference_ds.dims 114 | ), f"Need (lat,lon) in reference_ds dims Currently: {reference_ds.dims}" 115 | assert ("lat" in ds.dims) & ( 116 | "lon" in ds.dims 117 | ), f"Need (lat,lon) in ds dims Currently: {ds.dims}" 118 | 119 | # create the grid you want to convert TO (from reference_ds) 120 | ds_out = xr.Dataset( 121 | {"lat": (["lat"], reference_ds.lat), "lon": (["lon"], 
reference_ds.lon)} 122 | ) 123 | 124 | # create the regridder object 125 | # xe.Regridder(grid_in, grid_out, method='bilinear') 126 | regridder = xe.Regridder(ds, ds_out, method, reuse_weights=True) 127 | 128 | # IF it's a dataarray just do the original transformations 129 | if isinstance(ds, xr.core.dataarray.DataArray): 130 | ds = regridder(ds) 131 | # OTHERWISE loop through each of the variables, regrid the datarray then recombine into dataset 132 | elif isinstance(ds, xr.core.dataset.Dataset): 133 | vars = [i for i in ds.var().variables] 134 | if len(vars) == 1: 135 | ds = regridder(ds) 136 | else: 137 | output_dict = {} 138 | # LOOP over each variable and append to dict 139 | for var in vars: 140 | print(f"- regridding var {var} -") 141 | da = ds[var] 142 | da = regridder(da) 143 | output_dict[var] = da 144 | # REBUILD 145 | ds = xr.Dataset(output_dict) 146 | else: 147 | assert False, "This function only works with xarray dataset / dataarray objects" 148 | 149 | print( 150 | f"Regridded from {(regridder.Ny_in, regridder.Nx_in)} to {(regridder.Ny_out, regridder.Nx_out)}" 151 | ) 152 | 153 | return ds 154 | 155 | 156 | def select_same_time_slice(reference_ds, ds): 157 | """ Select the values for the same timestep as the reference ds""" 158 | # CHECK THEY ARE THE SAME FREQUENCY 159 | # get the frequency of the time series from reference_ds 160 | freq = pd.infer_freq(reference_ds.time.values) 161 | if freq == None: 162 | warnings.warn('HARDCODED FOR THIS PROBLEM BUT NO IDEA WHY NOT WORKING') 163 | freq = "M" 164 | # assert False, f"Unable to infer frequency from the reference_ds timestep" 165 | 166 | old_freq = pd.infer_freq(ds.time.values) 167 | warnings.warn( 168 | "Disabled the assert statement. ENSURE FREQUENCIES THE SAME (e.g. monthly)" 169 | ) 170 | # assert freq == old_freq, f"The frequencies should be the same! currenlty ref: {freq} vs. 
old: {old_freq}" 171 | 172 | # get the STARTING time point from the reference_ds 173 | min_time = reference_ds.time.min().values 174 | max_time = reference_ds.time.max().values 175 | orig_time_range = pd.date_range(min_time, max_time, freq=freq) 176 | # EXTEND the original time_range by 1 (so selecting the whole slice) 177 | # because python doesn't select the final in a range 178 | periods = len(orig_time_range) #+ 1 179 | # create new time series going ONE EXTRA PERIOD 180 | new_time_range = pd.date_range(min_time, freq=freq, periods=periods) 181 | new_max = new_time_range.max() 182 | 183 | # select using the NEW MAX as upper limit 184 | # -------------------------------------------------------------------------- 185 | # FOR SOME REASON slice is removing the minimum time ... 186 | # something to do with the fact that matplotlib / xarray is working oddly with numpy64datetime object 187 | warnings.warn("L153: HARDCODING THE MIN VALUE OTHERWISE IGNORED ...") 188 | min_time = datetime.datetime(2001, 1, 31) 189 | # -------------------------------------------------------------------------- 190 | ds = ds.sel(time=slice(min_time, new_max)) 191 | assert reference_ds.time.shape[0] == ds.time.shape[0],f"The time dimensions should match, currently reference_ds.time dims {reference_ds.time.shape[0]} != ds.time dims {ds.time.shape[0]}" 192 | 193 | print_time_min = pd.to_datetime(ds.time.min().values) 194 | print_time_max = pd.to_datetime(ds.time.max().values) 195 | try: 196 | vars = [i for i in ds.var().variables] 197 | except: 198 | vars = ds.name 199 | # ref_vars = [i for i in reference_ds.var().variables] 200 | print( 201 | f"Select same timeslice for ds with vars: {vars}. Min {print_time_min} Max {print_time_max}" 202 | ) 203 | 204 | return ds 205 | 206 | 207 | def get_holaps_mask(ds): 208 | """ 209 | NOTE: 210 | - assumes that all of the null values from the HOLAPS file are valid null values (e.g. water bodies). 
Could also be invalid nulls due to poor data processing / lack of satellite input data for a pixel! 211 | """ 212 | warnings.warn( 213 | "assumes that all of the null values from the HOLAPS file are valid null values (e.g. water bodies). Could also be invalid nulls due to poor data processing / lack of satellite input data for a pixel!" 214 | ) 215 | warnings.warn( 216 | "How to collapse the time dimension in the holaps mask? Here we just select the first time because all of the valid pixels are constant for first, last second last. Need to check this is true for all timesteps" 217 | ) 218 | 219 | mask = ds.isnull().isel(time=0).drop("time") 220 | mask.name = "holaps_mask" 221 | 222 | mask = xr.concat([mask for _ in range(len(ds.time))]) 223 | mask = mask.rename({"concat_dims": "time"}) 224 | mask["time"] = ds.time 225 | 226 | return mask 227 | 228 | 229 | 230 | def select_east_africa(ds): 231 | """ """ 232 | lonmin=32.6 233 | lonmax=51.8 234 | latmin=-5.0 235 | latmax=15.2 236 | 237 | if ('x' in ds.dims) & ('y' in ds.dims): 238 | ds = ds.sel(y=slice(latmax,latmin),x=slice(lonmin, lonmax)) 239 | elif ('lat' in ds.dims) & ('lon' in ds.dims): 240 | ds = ds.sel(lat=slice(latmax,latmin),lon=slice(lonmin, lonmax)) 241 | elif ('latitude' in ds.dims) & ('longitude' in ds.dims): 242 | ds = ds.sel(latitude=slice(latmax,latmin),longitude=slice(lonmin, lonmax)) 243 | else: 244 | assert False, "You need one of [(y, x), (lat, lon), (latitude, longitude)] in your dims" 245 | 246 | return 247 | 248 | 249 | # ------------------------------------------------------------------------------ 250 | # Functions for working with xarray objects 251 | # ------------------------------------------------------------------------------ 252 | 253 | 254 | def merge_data_arrays(*DataArrays): 255 | das = [da.name for da in DataArrays] 256 | print(f"Merging data: {das}") 257 | ds = xr.merge([*DataArrays]) 258 | return ds 259 | 260 | 261 | def save_netcdf(xr_obj, filepath, force=False): 262 | """""" 
263 | if not Path(filepath).is_file(): 264 | xr_obj.to_netcdf(filepath) 265 | print(f"File Saved: {filepath}") 266 | elif force: 267 | print(f"Filepath {filepath} already exists! Overwriting...") 268 | xr_obj.to_netcdf(filepath) 269 | print(f"File Saved: {filepath}") 270 | else: 271 | print(f"Filepath {filepath} already exists!") 272 | 273 | return 274 | 275 | 276 | def get_all_valid(ds, holaps_da, modis_da, gleam_da): 277 | """ Return only values for pixels/times where ALL PRODUCTS ARE VALID """ 278 | valid_mask = ( 279 | holaps_da.notnull() 280 | & modis_da.notnull() 281 | & gleam_da.notnull() 282 | ) 283 | ds_valid = ds.where(valid_mask) 284 | 285 | return ds_valid 286 | 287 | 288 | def drop_nans_and_flatten(dataArray): 289 | """flatten the array and drop nans from that array. Useful for plotting histograms. 290 | 291 | Arguments: 292 | --------- 293 | : dataArray (xr.DataArray) 294 | the DataArray of your value you want to flatten 295 | """ 296 | # drop NaNs and flatten 297 | return dataArray.values[~np.isnan(dataArray.values)] 298 | 299 | # 300 | 301 | # 302 | # def create_flattened_dataframe_of_values(h,g,m): 303 | # """ """ 304 | # h_ = drop_nans_and_flatten(h) 305 | # g_ = drop_nans_and_flatten(g) 306 | # m_ = drop_nans_and_flatten(m) 307 | # df = pd.DataFrame(dict( 308 | # holaps=h_, 309 | # gleam=g_, 310 | # modis=m_ 311 | # )) 312 | # return df 313 | -------------------------------------------------------------------------------- /notebooks/02_gt_linear_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Linear model\n", 8 | "\n", 9 | "In this notebook, we train a linear model and investigate the coefficients it learns. These coefficients can be interpreted as the global feature importances of the input features." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from pathlib import Path\n", 19 | "from itertools import chain\n", 20 | "import pandas as pd\n", 21 | "import numpy as np\n", 22 | "\n", 23 | "import sys\n", 24 | "sys.path.insert(0, '..')\n", 25 | "\n", 26 | "from predictor.models import LinearModel\n", 27 | "from predictor.preprocessing import VALUE_COLS, VEGETATION_LABELS" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "target = 'ndvi'" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "path_to_arrays = Path(f'../data/processed/{target}/arrays')" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "model_with_veg = LinearModel(path_to_arrays)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 5, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Train set RMSE: 0.03977352798000336\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "model_with_veg.train()" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "We can isolate the coefficients, as well as the values they correspond to. Note that each label in `value_labels` has the following format: `{value}_{month}` where `month` is relative to the `pred_month` (so if we are predicting June, then `month=11` corresponds to data in May)." 
79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 6, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "coefs = model_with_veg.model.coef_" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 7, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "value_labels = list(chain(*[[f'{val}_{month}' for val in VALUE_COLS] for month in range(1, 12)]))" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 8, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "feature_importances_veg = pd.DataFrame(data={\n", 106 | " 'feature': value_labels,\n", 107 | " 'value': coefs\n", 108 | "})" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "Lets investigate the most important features (by absolute value)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 9, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/html": [ 126 | "
\n", 127 | "\n", 140 | "\n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | "
featurevalue
74ndvi_110.118459
4ndvi_10.110067
67ndvi_10-0.054666
6ndvi_anomaly_1-0.036427
11ndvi_2-0.031033
75evi_110.020969
39ndvi_60.020805
71lst_day_11-0.016052
46ndvi_70.015323
40evi_6-0.014433
\n", 201 | "
" 202 | ], 203 | "text/plain": [ 204 | " feature value\n", 205 | "74 ndvi_11 0.118459\n", 206 | "4 ndvi_1 0.110067\n", 207 | "67 ndvi_10 -0.054666\n", 208 | "6 ndvi_anomaly_1 -0.036427\n", 209 | "11 ndvi_2 -0.031033\n", 210 | "75 evi_11 0.020969\n", 211 | "39 ndvi_6 0.020805\n", 212 | "71 lst_day_11 -0.016052\n", 213 | "46 ndvi_7 0.015323\n", 214 | "40 evi_6 -0.014433" 215 | ] 216 | }, 217 | "execution_count": 9, 218 | "metadata": {}, 219 | "output_type": "execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "feature_importances_veg.iloc[(-np.abs(feature_importances_veg['value'].values)).argsort()][:10]" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "## Without vegetation\n", 231 | "\n", 232 | "The model above tells us that the vegetation health in May is predictive of the vegetation health in June. What happens if we hide vegetation health from the model?" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 10, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "model_no_veg = LinearModel(path_to_arrays, hide_vegetation=True)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 11, 247 | "metadata": {}, 248 | "outputs": [ 249 | { 250 | "name": "stdout", 251 | "output_type": "stream", 252 | "text": [ 253 | "Training model without vegetation features\n", 254 | "Train set RMSE: 0.07832527470825593\n" 255 | ] 256 | } 257 | ], 258 | "source": [ 259 | "model_no_veg.train()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 12, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "coefs = model_no_veg.model.coef_\n", 269 | "\n", 270 | "veg_features = ['ndvi', 'evi']\n", 271 | "value_labels = list(chain(*[[f'{val}_{month}' for val in VALUE_COLS if val not in VEGETATION_LABELS] \n", 272 | " for month in range(1, 12)]))" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 13, 278 | 
"metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "feature_importances_no_veg = pd.DataFrame(data={\n", 282 | " 'feature': value_labels,\n", 283 | " 'value': coefs\n", 284 | "})" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 14, 290 | "metadata": {}, 291 | "outputs": [ 292 | { 293 | "data": { 294 | "text/html": [ 295 | "
\n", 296 | "\n", 309 | "\n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | "
featurevalue
41lst_day_11-0.068263
42precip_110.048880
0lst_night_10.035778
40lst_night_11-0.035483
30precip_80.030408
20lst_night_60.029277
21lst_day_6-0.026891
10precip_30.026077
33lst_day_90.025803
2precip_10.021489
\n", 370 | "
" 371 | ], 372 | "text/plain": [ 373 | " feature value\n", 374 | "41 lst_day_11 -0.068263\n", 375 | "42 precip_11 0.048880\n", 376 | "0 lst_night_1 0.035778\n", 377 | "40 lst_night_11 -0.035483\n", 378 | "30 precip_8 0.030408\n", 379 | "20 lst_night_6 0.029277\n", 380 | "21 lst_day_6 -0.026891\n", 381 | "10 precip_3 0.026077\n", 382 | "33 lst_day_9 0.025803\n", 383 | "2 precip_1 0.021489" 384 | ] 385 | }, 386 | "execution_count": 14, 387 | "metadata": {}, 388 | "output_type": "execute_result" 389 | } 390 | ], 391 | "source": [ 392 | "feature_importances_no_veg.iloc[(-np.abs(feature_importances_no_veg['value'].values)).argsort()][:10]" 393 | ] 394 | } 395 | ], 396 | "metadata": { 397 | "kernelspec": { 398 | "display_name": "Python [conda env:vegetation_health]", 399 | "language": "python", 400 | "name": "conda-env-vegetation_health-py" 401 | }, 402 | "language_info": { 403 | "codemirror_mode": { 404 | "name": "ipython", 405 | "version": 3 406 | }, 407 | "file_extension": ".py", 408 | "mimetype": "text/x-python", 409 | "name": "python", 410 | "nbconvert_exporter": "python", 411 | "pygments_lexer": "ipython3", 412 | "version": "3.6.8" 413 | } 414 | }, 415 | "nbformat": 4, 416 | "nbformat_minor": 2 417 | } 418 | -------------------------------------------------------------------------------- /data/common_grid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for putting netcdf (.nc) files onto a common grid. This means that the 3 | location, timesteps and resolution of the data is now the same. This makes 4 | working with the data a lot simpler! 
5 | """ 6 | 7 | import xarray as xr 8 | import xesmf as xe # for regridding 9 | import numpy as np 10 | import pickle 11 | import matplotlib.pyplot as plt 12 | import pandas as pd 13 | import tqdm # for progress-bars 14 | import click 15 | import warnings 16 | 17 | from drought_masking import create_drought_mask 18 | from utils import save_netcdf 19 | 20 | 21 | def pickle_obj(obj, filename): 22 | """ write to pickle 23 | """ 24 | assert (filename.split('.')[-1] == "pickle") or (filename.split('.')[-1] == "pkl"), f"filename should end with ('pickle','pkl') currently: {filename}" 25 | output = open(filename, 'wb') 26 | pickle.dump(obj, output) 27 | 28 | return 29 | 30 | 31 | def read_pickle_df(filepath): 32 | """ read pickled object 33 | """ 34 | obj = pd.read_pickle(filepath) 35 | return obj 36 | 37 | 38 | def normalise_coordinate_names(ds): 39 | """ rename latitude/longitude to lat/lon """ 40 | if "latitude" in ds.dims: 41 | ds = ds.rename({"latitude":"lat"}) 42 | print("Renamed `latitude` to `lat`") 43 | if "longitude" in ds.dims: 44 | ds = ds.rename({"longitude":"lon"}) 45 | print("Renamed `longitude` to `lon`") 46 | assert "time" in ds.dims, f"Must have `time` as dimension in object. Currently: {ds.dims}" 47 | 48 | return ds 49 | 50 | 51 | def select_same_time_slice(reference_ds, ds): 52 | """ Select the values for the same timestep as the """ 53 | # CHECK THEY ARE THE SAME FREQUENCY 54 | # get the frequency of the time series from reference_ds 55 | freq = pd.infer_freq(reference_ds.time.values) 56 | old_freq = pd.infer_freq(ds.time.values) 57 | assert freq == old_freq, f"The frequencies should be the same! currenlty ref: {freq} vs. 
old: {old_freq}" 58 | 59 | # get the STARTING time point from the reference_ds 60 | min_time = reference_ds.time.min().values 61 | max_time = reference_ds.time.max().values 62 | orig_time_range = pd.date_range(min_time, max_time, freq=freq) 63 | # EXTEND the original time_range by 1 (so selecting the whole slice) 64 | # because python doesn't select the final in a range 65 | periods = len(orig_time_range) + 1 66 | # create new time series going ONE EXTRA PERIOD 67 | new_time_range = pd.date_range(min_time, freq=freq, periods=periods) 68 | new_max = new_time_range.max() 69 | 70 | # select using the NEW MAX as upper limit 71 | ds = ds.sel(time=slice(min_time, new_max)) 72 | # assert reference_ds.time.shape[0] == ds.time.shape[0],"The time dimensions should match, currently reference_ds.time dims {reference_ds.time.shape[0]} != ds.time dims {ds.time.shape[0]}" 73 | 74 | print_time_min = pd.to_datetime(ds.time.min().values) 75 | print_time_max = pd.to_datetime(ds.time.max().values) 76 | try: 77 | vars = [i for i in ds.var().variables] 78 | except: 79 | vars = ds.name 80 | ref_vars = [i for i in reference_ds.var().variables] 81 | print(f"Select same timeslice for ds with vars: {vars}. Min {print_time_min} Max {print_time_max}") 82 | 83 | return ds 84 | 85 | 86 | def select_same_lat_lon_slice(reference_ds, ds): 87 | """ 88 | Take a slice of data from the `ds` according to the bounding box from 89 | the reference_ds. 90 | NOTE: - latitude has to be from max() to min() for some reason? 91 | - becuase it's crossing the equator? e.g. -14.:8. 
92 | Therefore, have to run an if statement to decide which way round to put the data 93 | """ 94 | # lat_bounds = [reference_ds.lat.min(),reference_ds.lat.max()] 95 | # lon_bounds = [reference_ds.lon.min(),reference_ds.lon.max()] 96 | if len(ds.sel(lat=slice(reference_ds.lat.min(), reference_ds.lat.max())).lat) == 0: 97 | ds = ds.sel(lat=slice(reference_ds.lat.max(), reference_ds.lat.min())) 98 | else: 99 | ds = ds.sel(lat=slice(reference_ds.lat.min(), reference_ds.lat.max())) 100 | ds = ds.sel(lon=slice(reference_ds.lon.min(), reference_ds.lon.max())) 101 | 102 | try: 103 | vars = [i for i in ds.var().variables] 104 | except: 105 | vars = ds.name 106 | ref_vars = [i for i in reference_ds.var().variables] 107 | print(f"Select the same bounding box for ds {vars} from reference_ds {ref_vars}") 108 | return ds 109 | 110 | 111 | def open_drought_ds(data_path = 'spei_spi.nc'): 112 | """ returns the raw dataset & 2x boolean masks (SPI/SPEI)""" 113 | # Data path on MONTHLY GRID 114 | 115 | # open the data () 116 | ds = xr.open_dataset(data_path) 117 | ds = ds.rename({"value":"spi"}) 118 | ds = normalise_coordinate_names(ds) 119 | 120 | # turn the values less than -1 into mask 121 | spei_drought = ds.spei.where(ds.spei < -1) 122 | spi_drought = ds.spi.where(ds.spi < -1) 123 | 124 | # turn into a boolean mask 125 | drought_spei = spei_drought.notnull() 126 | drought_spei = drought_spei.rename('drought_spei') 127 | drought_spi = spi_drought.notnull() 128 | drought_spi = drought_spi.rename('drought_spi') 129 | 130 | return ds, drought_spei, drought_spi 131 | 132 | 133 | def ensure_same_time(reference_ds, ds): 134 | """ convert the TIMESTEPS to the same values 135 | e.g. 
• if Monthly data sometimes its in the middle - 01-16-98 136 | • other times its the start 01-01-98, othertimes 01-31-98 137 | Set them all to the same frequency using the freq from reference_ds 138 | """ 139 | freq = pd.infer_freq(reference_ds.time.values) 140 | dr = pd.date_range(reference_ds.time.min().values , periods=len(ds.time.values), freq=freq) 141 | ds['time'] = dr 142 | 143 | return ds 144 | 145 | 146 | def convert_to_same_time_freq(reference_ds,ds): 147 | """ Upscale or downscale data so on the same time frequencies 148 | e.g. convert daily data to monthly ('MS' = month start) 149 | """ 150 | freq = pd.infer_freq(reference_ds.time.values) 151 | ds = ds.resample(time='MS').median(dim='time') 152 | 153 | try: 154 | vars = [i for i in ds.var().variables] 155 | except: 156 | vars = ds.name 157 | print(f"Resampled ds ({vars}) to {freq}") 158 | return ds 159 | 160 | 161 | def convert_to_same_grid(reference_ds, ds): 162 | """ Use xEMSF package to regrid ds to the same grid as reference_ds """ 163 | assert ("lat" in reference_ds.dims)&("lon" in reference_ds.dims), f"Need (lat,lon) in reference_ds dims Currently: {reference_ds.dims}" 164 | assert ("lat" in ds.dims)&("lon" in ds.dims), f"Need (lat,lon) in ds dims Currently: {ds.dims}" 165 | 166 | # create the grid you want to convert TO (from reference_ds) 167 | ds_out = xr.Dataset({ 168 | 'lat': (['lat'], reference_ds.lat), 169 | 'lon': (['lon'], reference_ds.lon), 170 | }) 171 | 172 | # create the regridder object 173 | # xe.Regridder(grid_in, grid_out, method='bilinear') 174 | regridder = xe.Regridder(ds, ds_out, 'bilinear', reuse_weights=True) 175 | 176 | # IF it's a dataarray just do the original transformations 177 | if isinstance(ds, xr.core.dataarray.DataArray): 178 | ds = regridder(ds) 179 | # OTHERWISE loop through each of the variables, regrid the datarray then recombine into dataset 180 | elif isinstance(ds, xr.core.dataset.Dataset): 181 | vars = [i for i in ds.var().variables] 182 | if len(vars) ==1 
: 183 | ds = regridder(ds) 184 | else: 185 | output_dict = {} 186 | # LOOP over each variable and append to dict 187 | for var in vars: 188 | print(f"- regridding var {var} -") 189 | da = ds[var] 190 | da = regridder(da) 191 | output_dict[var] = da 192 | # REBUILD 193 | ds = xr.Dataset(output_dict) 194 | else: 195 | assert False, "This function only works with xarray dataset / dataarray objects" 196 | 197 | print(f"Regridded from {(regridder.Ny_in, regridder.Nx_in)} to {(regridder.Ny_out, regridder.Nx_out)}") 198 | 199 | return ds 200 | 201 | 202 | def netcdf_to_same_dim_shapes(reference_ds, *other_netcdfs): 203 | """ loop over each of the other_netcdfs and reshape them in TIME, LAT/LON 204 | to be the same as the reference_ds. 205 | This allows them to be stored as a single netcdf file 206 | """ 207 | ds_to_merge = [] 208 | for ds_ in tqdm.tqdm([*other_netcdfs]): 209 | ds_ = normalise_coordinate_names(ds_) 210 | assert ("lat" in ds_.dims)&("lon" in ds_.dims),f"lat and lon should be in ds.dims. Currently: {ds_.dims}" 211 | 212 | # select the same SLICE of time (reference_ds.min() - reference_ds.min()) 213 | ds_ = convert_to_same_time_freq(reference_ds, ds_) 214 | ds_ = select_same_time_slice(reference_ds, ds_) 215 | 216 | # REGRID to same dimensions 217 | ds_ = convert_to_same_grid(reference_ds, ds_) 218 | # select the same lat lon slice 219 | ds_ = select_same_lat_lon_slice(reference_ds, ds_) 220 | 221 | # ENSURE DIMS MATCH 222 | first_var = [i for i in reference_ds.var().variables][0] 223 | np.testing.assert_allclose(reference_ds[first_var].shape[0], ds_.shape[0], atol=1), f"The TIME DIMENSION should be equal. Currently (time, lat, lon) {ds_.shape} should be {reference_ds[first_var].shape}" 224 | assert reference_ds[first_var].shape[1] == ds_.shape[1], f"The LAT DIMENSION should be equal. 
Currently (time, lat, lon) {ds_.shape} should be {reference_ds[first_var].shape}" 225 | assert reference_ds[first_var].shape[2] == ds_.shape[2], f"The LON DIMENSION should be equal. Currently (time, lat, lon) {ds_.shape} should be {reference_ds[first_var].shape}" 226 | 227 | ds_to_merge.append(ds_) 228 | 229 | return ds_to_merge 230 | 231 | 232 | def get_all_vars_from_ds(ds): 233 | """ return a list of strings for all variables in dataset """ 234 | assert isinstance(ds, xr.core.dataset.Dataset), f"Currently only works with xr.Dataset objects. ds = {type(ds)}" 235 | vars = [i for i in ds.var().variables] 236 | return vars 237 | 238 | 239 | def append_reference_ds_vars(reference_ds, ds_to_merge): 240 | """merge in the reference_ds variables to the ds_to_merge""" 241 | for var in reference_ds.var().variables: 242 | ds_to_merge.append(reference_ds[var]) 243 | 244 | return ds_to_merge 245 | 246 | 247 | def merge_netcdfs_to_one_file(reference_ds, *other_netcdfs, drought=False): 248 | """ merge the netcdf files with multiple variables into ONE netcdf file. 249 | Use the structure from the reference_ds to get the same TIME & LAT/LON GRID. 250 | """ 251 | ds_to_merge = netcdf_to_same_dim_shapes(reference_ds, *other_netcdfs) 252 | ds_to_merge = append_reference_ds_vars(reference_ds, ds_to_merge) 253 | 254 | # join in the drought mask 255 | # -------------------------------------------------------------------------- 256 | warnings.warn('This is done once here to put the ndvi mask into the nc file') 257 | output_ds = xr.merge(ds_to_merge) 258 | # -------------------------------------------------------------------------- 259 | if drought: 260 | warnings.warn('Should be less focused on hardcoding for the drought variables. need to have a look at this') 261 | # TODO: reference ds might not be the drought mask. 
this functionality should be elsewhere 262 | spei_mask, spi_mask, ndvi_mask = create_drought_mask(output_ds[['spi','spei','ndvi']]) 263 | ds_to_merge.append(spei_mask) 264 | ds_to_merge.append(spi_mask) 265 | ds_to_merge.append(ndvi_mask) 266 | 267 | output_ds = xr.merge(ds_to_merge) 268 | 269 | return output_ds 270 | 271 | 272 | def check_other_netcdfs(*other_netcdfs): 273 | """ if dataarray check that it's named! """ 274 | for i, xr_obj in enumerate([*other_netcdfs]): 275 | if isinstance(xr_obj, xr.core.dataarray.DataArray): 276 | assert xr_obj.name != None, f"All dataarrays must be named! Dataarray #{i+1} not named" 277 | return 278 | 279 | 280 | def read_files(files): 281 | """ read in the files to be """ 282 | xr_objs = [] 283 | for file in files: 284 | xr_obj = xr.open_dataset(file) 285 | xr_objs.append(xr_obj) 286 | 287 | return xr_objs 288 | 289 | 290 | # @click.command() 291 | # @click.argument('reference_ds_path', type=click.Path(exists=True), default="EA_data/spei_spi.nc") 292 | # @click.option('files', '--files', envvar='FILES', multiple=True, type=click.Path()) 293 | # @click.argument('output', type=click.File('wb')) 294 | # @click.option('--drought', default=False) 295 | if __name__ == "__main__": 296 | # TODO: IMPLEMENT THIS ALL IN PARALLEL 297 | 298 | reference_ds_path = '/soge-home/users/chri4118/EA_data/spei_spi.nc' 299 | # the reference ds 300 | reference_ds = xr.open_dataset('/soge-home/users/chri4118/EA_data/spei_spi.nc') 301 | reference_ds = reference_ds.rename({"value":'spi'}) 302 | reference_ds = normalise_coordinate_names(reference_ds) 303 | 304 | # the output filepath 305 | output = "OUT.nc" 306 | 307 | # -------------------------------------------------------------------------- 308 | # TODO: THIS ALL NEEDS TO BE MORE DYNAMICALLY SET UP 309 | # variables to join 310 | TEMP = xr.open_dataset("/soge-home/users/chri4118/EA_data/LST_EastAfrica.nc") 311 | lst_day = TEMP.lst_day 312 | lst_night = TEMP.lst_night 313 | lst_mean = (TEMP.lst_day + 
TEMP.lst_night) / 2 314 | lst_mean.name = "lst_mean" 315 | 316 | ET = xr.open_dataset("/soge-home/users/chri4118/EA_data/ET_EastAfrica.nc") 317 | evap = ET.evaporation 318 | baresoil_evap = ET.baresoil_evaporation 319 | pet = ET.potential_evaporation 320 | transp = ET.transpiration 321 | surface_sm = ET.surface_soil_moisture 322 | rootzone_sm = ET.rootzone_soil_moisture 323 | 324 | SM = xr.open_dataset('/soge-home/users/chri4118/EA_data/SM_EastAfrica.nc') 325 | sm = SM.sm 326 | 327 | PCP = xr.open_dataset('/soge-home/projects/crop_yield/chirps/EA_precip.nc') 328 | precip = PCP.precip 329 | 330 | VEG = xr.open_dataset('/soge-home/users/chri4118/EA_data/NDVI_EastAfrica.nc') 331 | ndvi = VEG.ndvi 332 | evi = VEG.evi 333 | 334 | # concatenate into one list to pass to the function 335 | vars_list = [lst_day, lst_night, lst_mean, lst_mean, evap, baresoil_evap, pet, transp, surface_sm, rootzone_sm, sm, precip, ndvi, evi] 336 | 337 | print(f"Reading Data from: \n{TEMP}\n{ET}\n{SM}\n{PCP}\n{VEG}") 338 | # -------------------------------------------------------------------------- 339 | 340 | check_other_netcdfs(*vars_list) 341 | output_ds = merge_netcdfs_to_one_file(reference_ds, *vars_list, drought=True) 342 | save_netcdf(output_ds, output) 343 | 344 | print("** Process Finished **") 345 | -------------------------------------------------------------------------------- /notebooks/01_tl_data_exploration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import xarray as xr\n", 10 | "import pandas as pd\n", 11 | "import numpy as np\n", 12 | "import seaborn as sns\n", 13 | "import matplotlib.pyplot as plt" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "data": { 23 | "text/plain": [ 24 | "\n", 25 | "Dimensions: (lat: 404, lon: 
316, time: 85)\n", 26 | "Coordinates:\n", 27 | " * time (time) datetime64[ns] 2010-01-01 ... 2017-01-01\n", 28 | " * lon (lon) float32 32.524994 32.574997 ... 48.274994\n", 29 | " * lat (lat) float32 -4.9750023 -4.925003 ... 15.174995\n", 30 | " month (time) int64 ...\n", 31 | "Data variables:\n", 32 | " lst_day (time, lat, lon) float64 ...\n", 33 | " lst_night (time, lat, lon) float64 ...\n", 34 | " lst_mean (time, lat, lon) float64 ...\n", 35 | " evaporation (time, lat, lon) float64 ...\n", 36 | " baresoil_evaporation (time, lat, lon) float64 ...\n", 37 | " potential_evaporation (time, lat, lon) float64 ...\n", 38 | " transpiration (time, lat, lon) float64 ...\n", 39 | " surface_soil_moisture (time, lat, lon) float64 ...\n", 40 | " rootzone_soil_moisture (time, lat, lon) float64 ...\n", 41 | " sm (time, lat, lon) float64 ...\n", 42 | " precip (time, lat, lon) float64 ...\n", 43 | " ndvi (time, lat, lon) float64 ...\n", 44 | " evi (time, lat, lon) float64 ...\n", 45 | " spei (time, lat, lon) float32 ...\n", 46 | " spi (time, lat, lon) float64 ...\n", 47 | " drought_spei (time, lat, lon) bool ...\n", 48 | " drought_spi (time, lat, lon) bool ...\n", 49 | " drought_ndvi (time, lat, lon) bool ..." 
50 | ] 51 | }, 52 | "execution_count": 2, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "ds = xr.open_dataset('/Volumes/Lees_Extend/data/ea_data/OUT.nc')\n", 59 | "ds" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 61, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "Done for Varible ndvi\n", 72 | "Done for Varible spi\n", 73 | "Done for Varible precip\n", 74 | "Done for Varible sm\n", 75 | "Done for Varible lst_day\n", 76 | "Done for Varible lst_night\n", 77 | "Done for Varible lst_mean\n" 78 | ] 79 | }, 80 | { 81 | "ename": "AttributeError", 82 | "evalue": "'Figure' object has no attribute 'title'", 83 | "output_type": "error", 84 | "traceback": [ 85 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 86 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 87 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;31m# variables = [\"ndvi\",\"spi\"]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;31m# TODO: why doesn't it work with 2 values?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplot_variables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvariables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'../figs/variable_distribution.png'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 88 | "\u001b[0;32m\u001b[0m in \u001b[0;36mplot_variables\u001b[0;34m(Dataset, 
variables)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Done for Varible {var}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtitle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Distribution of variable values'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 89 | "\u001b[0;31mAttributeError\u001b[0m: 'Figure' object has no attribute 'title'" 90 | ] 91 | }, 92 | { 93 | "data": { 94 | "image/png": "\n", 95 | "text/plain": [ 96 | "
" 97 | ] 98 | }, 99 | "metadata": { 100 | "needs_background": "light" 101 | }, 102 | "output_type": "display_data" 103 | } 104 | ], 105 | "source": [ 106 | "def drop_nans_and_flatten(dataArray):\n", 107 | "    \"\"\"flatten the array and drop nans from that array. Useful for plotting histograms.\n", 108 | "    \n", 109 | "    Arguments:\n", 110 | "    ---------\n", 111 | "    : dataArray (xr.DataArray)\n", 112 | "        the DataArray of your value you want to flatten\n", 113 | "    \"\"\"\n", 114 | "    # drop NaNs and flatten\n", 115 | "    return dataArray.values[~np.isnan(dataArray.values)]\n", 116 | "\n", 117 | "\n", 118 | "def set_no_axs_rows(variables):\n", 119 | "    \"\"\"Dynamically set the number of axes rows (3 axes per row)\n", 120 | "    \n", 121 | "    Arguments:\n", 122 | "    ---------\n", 123 | "    : variables (list)\n", 124 | "        the variables that we want to plot distributions of\n", 125 | "    \"\"\"\n", 126 | "    # Dynamically set the number of ROWS in the subplots (ceiling of len / 3)\n", 127 | "    if len(variables) % 3 > 0:\n", 128 | "        n_axs = (len(variables) // 3) + 1\n", 129 | "    else:\n", 130 | "        n_axs = len(variables) // 3\n", 131 | "    \n", 132 | "    return n_axs\n", 133 | "\n", 134 | "def plot_variables(Dataset, variables):\n", 135 | "    \"\"\"Plot a histogram of each variable in `variables`, three axes per row.\n", 136 | "    \n", 137 | "    Arguments:\n", 138 | "    ---------\n", 139 | "    : Dataset (xr.Dataset)\n", 140 | "        the dataset holding the variables of interest\n", 141 | "    : variables (list)\n", 142 | "        list of str for the labels of the values that you want to plot\n", 143 | "    \"\"\"\n", 144 | "    missing = [var_ for var_ in variables if var_ not in Dataset.variables]\n", 145 | "    assert not missing, f\"The variables supplied must be in the xr.Dataset variables. Currently looking for {variables} in dataset variables: {list(Dataset.variables)}\\n The variable missing is: {missing}\"\n", 146 | "    n_axs = set_no_axs_rows(variables)\n", 147 | "    fig, axs = plt.subplots(n_axs, 3, figsize=(15,8), squeeze=False)  # keep axs 2-D even for a single row\n", 148 | "    \n", 149 | "    # plot each of the variables \n", 150 | "    for ix, var in enumerate(variables):\n", 151 | "        # get the axes we are plotting on (row-major, starting at the first axes)\n", 152 | "        ax_ix = np.unravel_index(ix, (n_axs, 3))\n", 153 | "        ax = axs[ax_ix]\n", 154 | "        \n", 155 | "        # flatten array and drop nans from the variable\n", 156 | "        flat_array = drop_nans_and_flatten(Dataset[var])\n", 157 | "        \n", 158 | "        # plot the histogram for that variable\n", 159 | "        sns.distplot(flat_array, ax=ax)\n", 160 | "        ax.set_title(f'{var} Histogram')\n", 161 | "        print(f'Done for Variable {var}')\n", 162 | "    for unused_ax in axs.flat[len(variables):]:  # hide empty axes in the last row\n", 163 | "        unused_ax.set_visible(False)\n", 164 | "    fig.suptitle('Distribution of variable values')\n", 165 | "    plt.tight_layout()\n", 166 | "    return fig\n", 167 | "\n", 168 | "variables = [\"ndvi\",\"spi\",\"precip\",\"sm\",\"lst_day\",\"lst_night\",\"lst_mean\"]\n", 169 | "# variables = [\"ndvi\",\"spi\"]\n", 170 | "# fewer than 4 variables also works now (axs is kept 2-D via squeeze=False)\n", 171 | "fig = plot_variables(ds, variables)\n", 172 | "fig.savefig('../figs/variable_distribution.png')" 173 | ] 174 | },
175 | { 176 | "cell_type": "code", 177 | "execution_count": 49, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "data": { 182 | "text/plain": [ 183 | "(0, 1)" 184 | ] 185 | }, 186 | "execution_count": 49, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "# regression check: a single row of axes used to fail before plot_variables passed squeeze=False\n", 193 | "\n", 194 | "variables = [\"ndvi\",\"spi\"]\n", 195 | "plot_variables(ds, variables)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [] 204
| } 205 | ], 206 | "metadata": { 207 | "kernelspec": { 208 | "display_name": "Python 3", 209 | "language": "python", 210 | "name": "python3" 211 | }, 212 | "language_info": { 213 | "codemirror_mode": { 214 | "name": "ipython", 215 | "version": 3 216 | }, 217 | "file_extension": ".py", 218 | "mimetype": "text/x-python", 219 | "name": "python", 220 | "nbconvert_exporter": "python", 221 | "pygments_lexer": "ipython3", 222 | "version": "3.7.2" 223 | } 224 | }, 225 | "nbformat": 4, 226 | "nbformat_minor": 2 227 | } 228 | --------------------------------------------------------------------------------