├── tests
├── __init__.py
├── test_cleaner.py
└── test_rnn.py
├── data
├── preprocessing
│ ├── __init__.py
│ ├── gleam_cleaner.py
│ ├── chirps_cleaner.py
│ ├── collect_regrid_data_script.py
│ ├── spei_cleaner.py
│ ├── cleaner.py
│ └── utils.py
├── README.md
├── raw
│ └── README.md
├── processed
│ └── README.md
├── nc_to_pandas.py
├── ee_extract_landcover_by_country.js
├── globmap_lookup.py
├── utils.py
├── drought_masking.py
└── common_grid.py
├── predictor
├── models
│ ├── neural_networks
│ │ ├── __init__.py
│ │ ├── feedforward.py
│ │ ├── nn_base.py
│ │ └── recurrent.py
│ ├── __init__.py
│ ├── baseline.py
│ └── base.py
├── __init__.py
├── analysis
│ ├── __init__.py
│ ├── utils.py
│ ├── plot_shap.py
│ └── plot_results.py
├── engineer.py
└── preprocessing.py
├── ndvi_results.png
├── figs
├── variables_histogram.png
└── ndvi_results_logistic_regression.png
├── .gitignore
├── TODO.md
├── notebooks
├── 10_tl_growing_season.ipynb
├── 05_tl_explore_models.ipynb
├── 02_gt_linear_model.ipynb
└── 01_tl_data_exploration.ipynb
├── run.py
├── environment.yml
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/predictor/models/neural_networks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | Placeholder for data
4 |
--------------------------------------------------------------------------------
/data/raw/README.md:
--------------------------------------------------------------------------------
1 | # Raw data
2 |
3 | Placeholder for the raw data file, `tabular_data.csv`.
4 |
--------------------------------------------------------------------------------
/ndvi_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommylees112/vegetation_health/HEAD/ndvi_results.png
--------------------------------------------------------------------------------
/predictor/__init__.py:
--------------------------------------------------------------------------------
1 | from .preprocessing import Cleaner
2 | from .engineer import Engineer
3 |
--------------------------------------------------------------------------------
/predictor/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | from .plot_shap import plot_shap_values
2 | from .plot_results import plot_results
3 |
--------------------------------------------------------------------------------
/figs/variables_histogram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommylees112/vegetation_health/HEAD/figs/variables_histogram.png
--------------------------------------------------------------------------------
/figs/ndvi_results_logistic_regression.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommylees112/vegetation_health/HEAD/figs/ndvi_results_logistic_regression.png
--------------------------------------------------------------------------------
/data/processed/README.md:
--------------------------------------------------------------------------------
1 | # Processed Data
2 |
3 | Placeholder for processed data. Steps in the pipeline will save data here for future steps to use.
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.csv
2 | *.npy
3 | .DS_Store
4 | notebooks/.ipynb_checkpoints
5 | .ipynb_checkpoints
6 | __pycache__
7 | *.pyc
8 | .idea
9 | temp.py
10 | *.nc
11 | test.py
12 | *.json
13 | docs/
14 |
--------------------------------------------------------------------------------
/predictor/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .baseline import LinearModel
2 | from .neural_networks.feedforward import FeedForward as nn_FeedForward
3 | from .neural_networks.recurrent import Recurrent as nn_Recurrent
4 |
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | Tommy:
2 | - [x] Produce a tabular dataset of precip / temp / vegetation health indices
3 | - [x] upload to github (if < 50mb) (SENT BY WE TRANSFER)
4 | - [x] Mask out the sea values (not sure if already done - check!)
5 |
6 | Gabriel:
7 | - [x] produce project skeleton
8 | - [x] Mask out the LST values == 200
9 |
--------------------------------------------------------------------------------
/data/nc_to_pandas.py:
--------------------------------------------------------------------------------
1 | """ convert netcdf file to tabular dataformat (pandas)
2 | BOILERPLATE
3 | """
4 | import xarray as xr
5 | import pandas as pd
6 |
# Change this path to point to the .nc file that should be converted.
data_dir = ''

# Open the NetCDF file as an xarray Dataset.
ds = xr.open_dataset(data_dir)

# NOTE: the dataset is flattened into a pandas DataFrame and written to CSV.
# The output path below is a placeholder and must be edited before running.
df = ds.to_dataframe()
df.to_csv('path/to/csv/file')
15 |
--------------------------------------------------------------------------------
/predictor/analysis/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from collections import namedtuple
4 |
5 |
6 |
def load_model_data(train_or_test, target='ndvi'):
    """Load the arrays that were fed through the model for one data split.

    The arrays are read from ``data/processed/<target>/arrays/<train_or_test>``,
    resolved relative to the current working directory.

    Parameters:
    ----------
    train_or_test: str
        Which split to load. Must be either "train" or "test".
    target: str, default: 'ndvi'
        Name of the target sub-folder inside ``data/processed``.

    Returns:
    ----------
    data: namedtuple
        A ``Data`` named tuple with the attributes x, y, latlon and years,
        each loaded with ``np.load`` from the ``.npy`` file of the same name.

    Raises:
    ----------
    AssertionError
        If ``train_or_test`` is neither "train" nor "test".
    """
    if train_or_test not in ("train", "test"):
        # Raised explicitly (instead of `assert False`) so the check is still
        # enforced when Python runs with -O, which strips assert statements.
        raise AssertionError("train_or_test must be either ['train','test']")

    # both splits live in sibling folders named after the split itself
    data_dir = Path('.') / "data" / "processed" / target / "arrays" / train_or_test

    Data = namedtuple('Data', ["x", "y", "latlon", "years"])
    return Data(
        x=np.load(data_dir / "x.npy"),
        y=np.load(data_dir / "y.npy"),
        latlon=np.load(data_dir / "latlon.npy"),
        years=np.load(data_dir / "years.npy"),
    )
27 |
--------------------------------------------------------------------------------
/tests/test_cleaner.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from datetime import datetime
3 |
4 | from predictor import Cleaner
5 |
6 |
def test_year_month():
    """Check that `Cleaner.update_year_month` keeps a 12-month window in one `gp_year`."""
    # 24 consecutive month starts: Jan 2018 through Dec 2019
    timestamps = [datetime(year, month, 1)
                  for year in (2018, 2019)
                  for month in range(1, 13)]

    # expected grouping: 5 leading months, one full 12-month window, 7 trailing months
    labels = [1] * 5 + [2] * 12 + [3] * 7

    frame = pd.DataFrame(data={'times': timestamps, 'group': labels})

    gp_month, gp_year = Cleaner().update_year_month(frame['times'], pred_month=5)
    frame['gp_month'] = gp_month
    frame['gp_year'] = gp_year

    # every row of the middle group should share the same gp_year
    in_2019 = frame[frame['gp_year'] == 2019]

    assert len(in_2019) == 12, "Chopped out some 2018 data"
    assert (in_2019['group'] == 2).all(), "Not all the correct months were grouped!"
29 |
--------------------------------------------------------------------------------
/data/ee_extract_landcover_by_country.js:
--------------------------------------------------------------------------------
// ee_extract_landcover_by_country.js
//
// Earth Engine script: clip the GlobCover 2009 `landcover` band to Ethiopia
// and Kenya, then export each clipped image to the `landcover` Drive folder.

// import the landcover dataset & the country shapefiles.
// GlobCover is a raster asset, so it is loaded as an ee.Image: the band
// `select` and the `clip` calls below are ee.Image methods.
var globcover = ee.Image('ESA/GLOBCOVER_L4_200901_200912_V2_3');
// NOTE(review): Fusion Table ('ft:') assets may no longer be served by Earth
// Engine; confirm this table still loads or swap in a boundaries asset.
var world_region = ee.FeatureCollection('ft:1tdSwUL7MVpOauSgRzqVTOwdfy17KDbw-1d9omPw');
var landcover = globcover.select('landcover');

// country boundaries for Ethiopia and Kenya (both selected the same way)
var eth = world_region.filterMetadata('Country', 'equals', 'Ethiopia');
var kenya = world_region.filterMetadata('Country', 'equals', 'Kenya');

// clip the landcover to the countries
var eth_lc = landcover.clip(eth);
var ken_lc = landcover.clip(kenya);

// export the landcover map
// NOTE(review): `scale` and `crs` are declared but not used by the exports
// below, which hard-code scale 1000; confirm the intended resolution.
var scale = 500;
var crs = 'EPSG:4326';

var task = Export.image.toDrive({
  image: eth_lc,
  description: 'ethiopia_landcover',
  scale: 1000,
  // pass the boundary as a geometry, the type `region` is documented to accept
  region: eth.geometry(),
  folder: 'landcover'
});

var task_kenya = Export.image.toDrive({
  image: ken_lc,
  description: 'kenya_landcover',
  scale: 1000,
  region: kenya.geometry(),
  folder: 'landcover'
});
35 |
--------------------------------------------------------------------------------
/predictor/models/baseline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn import linear_model
3 | from sklearn.metrics import mean_squared_error
4 |
5 | from .base import ModelBase
6 |
7 |
class LinearModel(ModelBase):
    """A linear regression, to be used as a baseline
    against our more complex models
    """

    model_name = 'linear'

    def train(self):
        """Fit a `LinearRegression` on the training arrays and print the train RMSE.

        The fitted estimator is stored on `self.model`.
        """
        train_data = self.load_arrays(mode='train')

        # flatten everything after the first axis into one feature vector per instance
        x = train_data.x.reshape(train_data.x.shape[0], -1)

        self.model = linear_model.LinearRegression()
        self.model.fit(x, train_data.y)

        train_pred_y = self.model.predict(x)
        train_rmse = np.sqrt(mean_squared_error(train_data.y, train_pred_y))

        print(f'Train set RMSE: {train_rmse}')

    def predict(self):
        """Predict on the test arrays with the model fitted by `train`.

        Returns:
        ----------
        (y_true, y_pred): tuple
            The test targets and the corresponding predictions, in that order.
        """
        test_data = self.load_arrays(mode='test')
        x = test_data.x.reshape(test_data.x.shape[0], -1)
        test_pred_y = self.model.predict(x)
        return test_data.y, test_pred_y

    def save_model(self):
        """Save the fitted coefficients to `self.data_path / model_name / model.npy`.

        NOTE(review): only `coef_` is written; the fitted intercept is not
        saved - confirm that consumers of model.npy do not need it.
        """
        savedir = self.data_path / self.model_name
        # create missing parent folders too, and tolerate an existing folder
        savedir.mkdir(parents=True, exist_ok=True)
        np.save(savedir / 'model.npy', self.model.coef_)
40 |
--------------------------------------------------------------------------------
/tests/test_rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 | from predictor.models.neural_networks.recurrent import UnrolledRNN
6 |
7 |
def test_rnn():
    """
    Our own unrolled RNN exists so that it can be explained with shap.
    Here it is given the same constant weights and the same input as the
    pytorch LSTM, and the two sets of outputs are checked to roughly agree.
    """
    batch_size, hidden_size, features_per_month = 32, 124, 6

    # a single timestep of all-ones features for every batch element
    inputs = torch.ones(batch_size, 1, features_per_month)

    state_shape = (1, inputs.shape[0], hidden_size)
    hidden_state = torch.zeros(*state_shape)
    cell_state = torch.zeros(*state_shape)

    reference_rnn = nn.LSTM(input_size=features_per_month,
                            hidden_size=hidden_size,
                            batch_first=True,
                            num_layers=1)

    unrolled_rnn = UnrolledRNN(input_size=features_per_month,
                               hidden_size=hidden_size,
                               batch_first=True)

    # set every weight and bias of both networks to 1 so they are comparable
    for layer_weights in reference_rnn.all_weights:
        for weight in layer_weights:
            nn.init.constant_(weight.data, 1)

    for parameter in unrolled_rnn.parameters():
        for entry in parameter:
            nn.init.constant_(entry.data, 1)

    # NOTE(review): nn.LSTM returns (output, (h_n, c_n)), so the first state
    # below is the hidden state; the second state is unpacked but never
    # compared - confirm whether it should be asserted on as well.
    with torch.no_grad():
        ours_out, (ours_state, _ours_other) = unrolled_rnn(inputs, (hidden_state, cell_state))
        torch_out, (torch_state, _torch_other) = reference_rnn(inputs, (hidden_state, cell_state))

    outputs_match = np.isclose(ours_out.numpy(), torch_out.numpy(), 0.01).all()
    assert outputs_match, "Difference in hidden state"

    states_match = np.isclose(torch_state.numpy(), ours_state.numpy(), 0.01).all()
    assert states_match, "Difference in cell state"
45 |
--------------------------------------------------------------------------------
/notebooks/10_tl_growing_season.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# When is the growing season? (Season with Maximum NDVI?)"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 2,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import xarray as xr\n",
17 | "import pandas as pd\n",
18 | "import numpy as np\n",
19 | "from pathlib import Path\n",
20 | "from collections import namedtuple\n",
21 | "import matplotlib.pyplot as plt\n",
22 | "import seaborn as sns\n",
23 | "%matplotlib inline\n",
24 | "\n",
25 | "import os\n",
26 | "if os.getcwd().split('/')[-1] != \"vegetation_health\":\n",
27 | " os.chdir('..')\n",
28 | " \n",
29 | "assert os.getcwd().split('/')[-1] == \"vegetation_health\", f\"Working directory should be the root (), currently: {os.getcwd()}\"\n",
30 | "\n",
31 | "from predictor.analysis.plot_results import create_dataset_from_vars, plot_results\n",
32 | "from predictor.analysis.utils import load_model_data"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": []
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": []
48 | }
49 | ],
50 | "metadata": {
51 | "kernelspec": {
52 | "display_name": "Python 3",
53 | "language": "python",
54 | "name": "python3"
55 | },
56 | "language_info": {
57 | "codemirror_mode": {
58 | "name": "ipython",
59 | "version": 3
60 | },
61 | "file_extension": ".py",
62 | "mimetype": "text/x-python",
63 | "name": "python",
64 | "nbconvert_exporter": "python",
65 | "pygments_lexer": "ipython3",
66 | "version": "3.7.2"
67 | }
68 | },
69 | "nbformat": 4,
70 | "nbformat_minor": 2
71 | }
72 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import fire
2 | from pathlib import Path
3 |
4 | from predictor import Cleaner, Engineer
5 | from predictor.models import LinearModel, nn_FeedForward, nn_Recurrent
6 |
7 |
class RunTask:
    """Pipeline entry points (clean, engineer, train_model), run from the command line."""

    @staticmethod
    def clean(raw_filepath='data/raw/predict_vegetation_health.nc',
              processed_folder='data/processed',
              target='ndvi', pred_month=6):
        """Run the `Cleaner` on the raw file.

        The cleaner is given `<processed_folder>/<target>/cleaned_data.csv`
        as its output path and is processed with `pred_month` and `target`.
        """
        raw_filepath, processed_folder = Path(raw_filepath), Path(processed_folder)
        processed_filepath = processed_folder / target / 'cleaned_data.csv'

        cleaner = Cleaner(raw_filepath, processed_filepath)
        cleaner.process(pred_month, target)

    @staticmethod
    def engineer(processed_folder='data/processed', target='ndvi',
                 test_year=2016):
        """Run the `Engineer` on the cleaned csv.

        It reads `<processed_folder>/<target>/cleaned_data.csv`, is given
        `<processed_folder>/<target>/arrays` as its arrays folder, and is
        processed with `test_year`.
        """
        processed_folder = Path(processed_folder)
        cleaned_data = processed_folder / target / 'cleaned_data.csv'
        arrays_folder = processed_folder / target / 'arrays'

        engineer = Engineer(cleaned_data, arrays_folder)
        engineer.process(test_year)

    @staticmethod
    def train_model(model_type='baseline', data_folder='data',
                    target='ndvi', hide_vegetation=True, save_results=True):
        """Train and evaluate one model, optionally saving predictions and model.

        Parameters:
        ----------
        model_type: str, default: 'baseline'
            One of 'baseline', 'feedforward' or 'recurrent'.
        data_folder: str or pathlib.Path, default: 'data'
            Root data folder; arrays are read from
            `<data_folder>/processed/<target>/arrays`.
        target: str, default: 'ndvi'
            Name of the target sub-folder.
        hide_vegetation: bool, default: True
            Passed through to the model constructor.
        save_results: bool, default: True
            If True, predictions are saved during evaluation and
            `save_model` is called afterwards.

        Raises:
        ----------
        KeyError
            If `model_type` is not one of the supported names.
        """
        supported = ('baseline', 'feedforward', 'recurrent')
        if model_type not in supported:
            raise KeyError(
                f"Unknown model_type {model_type!r}; expected one of {supported}"
            )

        data_folder = Path(data_folder)
        arrays_folder = data_folder / 'processed' / target / 'arrays'

        # Map names to classes and instantiate only the requested model,
        # instead of constructing all three just to pick one.
        string2model = {
            'baseline': LinearModel,
            'feedforward': nn_FeedForward,
            'recurrent': nn_Recurrent,
        }

        model = string2model[model_type](data_folder, arrays_folder, hide_vegetation)
        model.train()
        model.evaluate(save_preds=save_results)
        if save_results:
            model.save_model()
50 |
51 |
if __name__ == '__main__':
    # Hand the RunTask class to python-fire, which builds the command line
    # interface from it.
    fire.Fire(RunTask)
54 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: vegetation_health
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - atomicwrites=1.3.0
9 | - attrs=19.1.0
10 | - blas=1.0
11 | - bzip2=1.0.6
12 | - cftime=1.0.3.4
13 | - curl=7.64.0
14 | - hdf4=4.2.13
15 | - hdf5=1.10.4
16 | - intel-openmp=2019.1
17 | - jpeg=9b
18 | - libcurl=7.64.0
19 | - libgfortran=3.0.1
20 | - libnetcdf=4.6.1
21 | - mkl=2019.1
22 | - mkl_fft=1.0.10
23 | - mkl_random=1.0.2
24 | - more-itertools=6.0.0
25 | - netcdf4=1.4.2
26 | - numpy=1.16.2
27 | - numpy-base=1.16.2
28 | - pandas=0.24.1
29 | - pluggy=0.9.0
30 | - py=1.8.0
31 | - pytest=4.3.0
32 | - python-dateutil=2.8.0
33 | - pytz=2018.9
34 | - scikit-learn=0.20.2
35 | - scipy=1.2.1
36 | - six=1.12.0
37 | - xarray=0.11.3
38 | - ca-certificates=2019.3.9
39 | - certifi=2019.3.9
40 | - cycler=0.10.0
41 | - fire=0.1.3
42 | - freetype=2.9.1
43 | - kiwisolver=1.0.1
44 | - libpng=1.6.36
45 | - libssh2=1.8.0
46 | - matplotlib=3.0.3
47 | - matplotlib-base=3.0.3
48 | - openssl=1.1.1b
49 | - pyparsing=2.3.1
50 | - tqdm=4.31.1
51 | - appnope=0.1.0
52 | - backcall=0.1.0
53 | - cffi=1.12.2
54 | - decorator=4.3.2
55 | - ipykernel=5.1.0
56 | - ipython=7.3.0
57 | - ipython_genutils=0.2.0
58 | - jedi=0.13.3
59 | - jupyter_client=5.2.4
60 | - jupyter_core=4.4.0
61 | - krb5=1.16.1
62 | - libcxx=4.0.1
63 | - libcxxabi=4.0.1
64 | - libedit=3.1.20181209
65 | - libffi=3.2.1
66 | - libsodium=1.0.16
67 | - ncurses=6.1
68 | - ninja=1.8.2
69 | - parso=0.3.4
70 | - pexpect=4.6.0
71 | - pickleshare=0.7.5
72 | - pip=19.0.3
73 | - prompt_toolkit=2.0.9
74 | - ptyprocess=0.6.0
75 | - pycparser=2.19
76 | - pygments=2.3.1
77 | - python=3.6.8
78 | - pyzmq=18.0.0
79 | - readline=7.0
80 | - setuptools=40.8.0
81 | - sqlite=3.26.0
82 | - tk=8.6.8
83 | - tornado=6.0.1
84 | - traitlets=4.3.2
85 | - wcwidth=0.1.7
86 | - wheel=0.33.1
87 | - xz=5.2.4
88 | - zeromq=4.3.1
89 | - zlib=1.2.11
90 | - pytorch=1.0.1
91 | - pip:
92 | - torch==1.0.1.post2
93 |
--------------------------------------------------------------------------------
/predictor/analysis/plot_shap.py:
--------------------------------------------------------------------------------
1 | from mpl_toolkits.axes_grid1 import host_subplot
2 | import mpl_toolkits.axisartist as AA
3 | import matplotlib.pyplot as plt
4 |
5 |
def plot_shap_values(x, shap_values, val_list, normalizing_dict, value_to_plot, normalize_shap_plots=True):
    """Plots the denormalized values against their shap values, so that
    variations in the input features to the model can be compared to their effect
    on the model. For example plots, see notebooks/08_gt_recurrent_model.ipynb.

    Parameters:
    ----------
    x: torch.Tensor
        The input to a model for a single data instance. `.cpu().numpy()` is
        called on the selected column, so a tensor (not a numpy array) is
        required. Indexed as x[:, feature].
    shap_values: np.array
        The corresponding shap values (to x), indexed the same way
    val_list: list
        A list of the variable names, for axis labels
    normalizing_dict: dict
        The normalizing dict saved by the `Cleaner`, so that the x array can be
        denormalized. Must hold 'mean' and 'std' for `value_to_plot`
    value_to_plot: str
        The specific input variable to plot. Must be in val_list
    normalize_shap_plots: boolean
        If True, then the scale of the shap plots will be uniform across all
        variable plots (on an instance specific basis).

    A plot of the variable `value_to_plot` against its shap values will be plotted.

    Raises:
    ----------
    ValueError
        If `value_to_plot` is not in `val_list`.
    """
    # first, lets isolate the lists
    idx = val_list.index(value_to_plot)

    x_val = x[:, idx].cpu().numpy()

    # we also want to denormalize
    x_val = (x_val * normalizing_dict[value_to_plot]['std']) + \
        normalizing_dict[value_to_plot]['mean']

    shap_val = shap_values[:, idx]

    # One 1-based x tick per row of the input. For an 11-row input this is
    # months 1..11, exactly what was previously hard-coded.
    months = list(range(1, len(x_val) + 1))

    host = host_subplot(111, axes_class=AA.Axes)
    plt.subplots_adjust(right=0.75)

    # second y axis on the right-hand side for the shap values
    par1 = host.twinx()
    par1.axis["right"].toggle(all=True)

    if normalize_shap_plots:
        # use the range over all variables so plots are comparable
        par1.set_ylim(shap_values.min(), shap_values.max())

    host.set_xlabel("Months")
    host.set_ylabel(value_to_plot)
    par1.set_ylabel("Shap value")

    p1, = host.plot(months, x_val, label=value_to_plot)
    p2, = par1.plot(months, shap_val, label="shap value")

    # colour each axis label like the line it belongs to
    host.axis["left"].label.set_color(p1.get_color())
    par1.axis["right"].label.set_color(p2.get_color())

    host.legend()

    plt.draw()
    plt.show()
66 |
--------------------------------------------------------------------------------
/predictor/analysis/plot_results.py:
--------------------------------------------------------------------------------
1 | import xarray as xr
2 | import pandas as pd
3 | import numpy as np
4 | from pathlib import Path
5 | import matplotlib.pyplot as plt
6 |
7 |
def create_dataset_from_vars(vars, latlon, varname, to_xarray=True):
    """ Convert the variables from `np.array` to `pd.DataFrame`
    and optionally `xr.Dataset`. By default converts to `xr.Dataset`

    Arguments:
    ---------
    : vars (np.array)
        the values of the variable of interest (e.g. Predictions of NDVI from model)

    : latlon (np.array)
        the latlon location for each of the values in vars; column 0 is read
        as latitude and column 1 as longitude

    : varname (str)
        the name of the variable

    : to_xarray (bool)
        if True (default) the result of `DataFrame.to_xarray()` is returned,
        otherwise the `pd.DataFrame` indexed by (lat, lon)

    Returns:
    -------
    The values as an xarray object or as a `pd.DataFrame` with a single
    column named `varname` and a ('lat', 'lon') index.

    Raises:
    ------
    AssertionError
        if `vars` and `latlon` do not have the same length

    TODO:
    ----
    * Implement a method that works with TIME so that the xarray objects
        have a time dimension too
    """
    if len(vars) != len(latlon):
        # raised explicitly so the check also runs under `python -O`
        raise AssertionError(
            "The length of the latlons array should be the same as the length "
            f"of the vars array. Currently latlons: {len(latlon)} vars: {len(vars)}"
        )

    df = pd.DataFrame(data={varname: vars, 'lat': latlon[:, 0],
                            'lon': latlon[:, 1]}).set_index(['lat', 'lon'])
    if to_xarray:
        return df.to_xarray()
    return df
37 |
38 |
def plot_results(processed_data=Path('data/processed'), target='ndvi',
                 plot_difference=False, savefig=True):
    """Plots a landscape of the results (and optionally,
    of the ground truth)

    Parameters:
    ----------
    processed_data: pathlib.Path, default: Path('data/processed')
        Folder holding the per-target sub-folders
    target: str, default: 'ndvi'
        Target sub-folder to read the arrays from; also used in the file
        name of the saved figure
    plot_difference: bool, default: False
        If True, a single map of (ground truth - predictions) is plotted.
        Otherwise predictions and ground truth are plotted side by side
    savefig: bool, default: True
        If True, the figure is saved to `<target>_results.png`

    NOTE(review): predictions are read from `<target>/arrays/preds.npy`,
    whereas `ModelBase.evaluate` writes them to `data_path / model_name /
    preds.npy` - confirm which location is intended.
    """

    preds = np.load(processed_data / target / 'arrays/preds.npy')
    true = np.load(processed_data / target / 'arrays/test/y.npy')
    latlon = np.load(processed_data / target / 'arrays/test/latlon.npy')

    preds_xr = create_dataset_from_vars(preds, latlon, "preds", to_xarray=True)
    true_xr = create_dataset_from_vars(true, latlon, "true", to_xarray=True)

    # stack both maps along a new 'data' dimension: 0 = predictions, 1 = ground truth
    data_xr = xr.concat((preds_xr['preds'], true_xr['true']),
                        pd.Index(['predictions', 'ground truth'], name='data'))

    if plot_difference:
        # compute the difference (ground truth - predictions) and create a difference plot
        data = data_xr.data[1] - data_xr.data[0]
        da = xr.DataArray(data, coords=[data_xr.lat, data_xr.lon], dims=['lat', 'lon'])
        # `name` has to be passed by keyword: the first positional argument
        # of DataArray.to_dataset is `dim`, not the variable name
        ds = da.to_dataset(name='difference')

        ds['difference'].plot(x='lon', y='lat', figsize=(12, 8))
    else:
        data_xr.plot(x='lon', y='lat', col='data', figsize=(15, 6))

    if savefig:
        plt.savefig(f'{target}_results.png', dpi=300, bbox_inches='tight')
    plt.show()
68 |
--------------------------------------------------------------------------------
/predictor/models/neural_networks/feedforward.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from pathlib import Path
4 |
5 | from .nn_base import NNBase
6 | from ...preprocessing import VALUE_COLS, VEGETATION_LABELS
7 |
8 |
class FeedForward(NNBase):
    """Feedforward neural network: an `NNBase` wrapping a `LinearModel`
    with one hidden block as wide as its input.
    """

    def __init__(self, data=Path('data'), arrays=Path('data/processed/arrays'),
                 hide_vegetation=False):
        # the vegetation columns are not counted when they are hidden
        hidden_columns = len(VEGETATION_LABELS) if hide_vegetation else 0
        # 11 feature groups per instance - presumably months; TODO confirm
        input_width = (len(VALUE_COLS) - hidden_columns) * 11

        network = LinearModel(input_width, [input_width], 0.25)
        super().__init__(network, data, arrays, hide_vegetation)
        self.model_name = "feedforward"
25 |
class LinearModel(nn.Module):
    """A stack of `LinearBlock`s followed by a linear layer with a single output.

    Parameters:
    ----------
    input_size: int
        Number of input features after flattening
    layer_sizes: sequence of int
        Output size of each `LinearBlock`, in order. It is not modified.
    dropout: float
        Dropout probability used in every `LinearBlock`
    """

    def __init__(self, input_size, layer_sizes, dropout):
        super().__init__()
        # Build a new list rather than inserting into the argument, so the
        # caller's list is left untouched.
        layer_sizes = [input_size, *layer_sizes]

        self.dense_layers = nn.ModuleList([
            LinearBlock(in_features=in_size, out_features=out_size, dropout=dropout)
            for in_size, out_size in zip(layer_sizes, layer_sizes[1:])
        ])

        self.final_dense = nn.Linear(in_features=layer_sizes[-1], out_features=1)

        self.init_weights()

    def init_weights(self):
        """Kaiming-uniform initialise all linear weights; zero the final bias."""
        for dense_layer in self.dense_layers:
            nn.init.kaiming_uniform_(dense_layer.linear.weight.data)

        nn.init.kaiming_uniform_(self.final_dense.weight.data)
        # http://cs231n.github.io/neural-networks-2/#init
        # see: Initializing the biases
        nn.init.constant_(self.final_dense.bias.data, 0)

    def forward(self, x):
        """Flatten `x` to (batch, features) and return the (batch, 1) output."""
        # flatten
        x = x.view(x.shape[0], -1)
        for layer in self.dense_layers:
            x = layer(x)

        return self.final_dense(x)


class LinearBlock(nn.Module):
    """
    A linear layer (without bias) followed by batchnorm, a LeakyReLU
    activation, and dropout
    """

    def __init__(self, in_features, out_features, dropout=0.25):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
        # named `relu` but is a LeakyReLU with slope 0.1 on negative inputs
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.batchnorm = nn.BatchNorm1d(num_features=out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply linear -> batchnorm -> LeakyReLU -> dropout."""
        x = self.relu(self.batchnorm(self.linear(x)))
        return self.dropout(x)
75 |
--------------------------------------------------------------------------------
/data/preprocessing/gleam_cleaner.py:
--------------------------------------------------------------------------------
1 | """gleam_cleaner.py"""
2 | from pathlib import Path
3 | import xarray as xr
4 | import numpy as np
5 | import pandas as pd
6 |
7 | import ipdb
8 | import warnings
9 | import os
10 |
11 | from preprocessing.utils import (
12 | gdal_reproject,
13 | bands_to_time,
14 | convert_to_same_grid,
15 | select_same_time_slice,
16 | save_netcdf,
17 | get_holaps_mask,
18 | merge_data_arrays,
19 | )
20 |
21 | from preprocessing.cleaner import Cleaner
22 |
class GleamCleaner(Cleaner):
    """Cleaner that writes `gleam_EA_clean.nc`, regridded and masked like the
    HOLAPS reference dataset.
    """

    def __init__(self):
        """Open the reference data, initialise the parent `Cleaner` and build the mask."""
        self.base_data_path = Path("/soge-home/projects/crop_yield/EGU_compare/")
        reference_data_path = self.base_data_path / "holaps_EA_clean.nc"

        # CHANGE THIS PATH:
        # ----------------------------------------------------------------------
        # NOTE(review): the file name, the `precip` variable and the log
        # message below are identical to ChirpsCleaner's and look like
        # copy-paste leftovers - confirm the real GLEAM file and variable.
        data_path = self.base_data_path / "EA_chirps_monthly.nc"
        # ----------------------------------------------------------------------

        # open the reference dataset
        self.reference_data_path = Path(reference_data_path)
        self.reference_ds = xr.open_dataset(self.reference_data_path).holaps_evapotranspiration

        # initialise the object using methods from the parent class.
        # Zero-argument super(): `ChirpsCleaner` is not defined in this
        # module, so naming it here raised a NameError.
        super().__init__(data_path=data_path)

        # extract the variable of interest (TO xr.DataArray)
        self.update_clean_data(
            self.raw_data.precip, msg="Extract Precipitation from CHIRPS xr.Dataset"
        )

        # make the mask (FROM REFERENCE_DS) to copy to this dataset too
        self.get_mask()

    def get_mask(self):
        """Store the mask derived from the reference dataset on `self.mask`."""
        self.mask = get_holaps_mask(self.reference_ds)

    def convert_units(self):
        """Label the cleaned data's units as 'mm day-1' (values are unchanged)."""
        # convert unit label to 'mm day-1'
        self.clean_data.attrs["units"] = "mm day-1"

    def preprocess(self):
        """Run the cleaning steps in order and save the result as netcdf."""
        # Resample the timesteps to END OF MONTH
        self.resample_time(resample_str="M")
        # select the correct time slice
        self.correct_time_slice()
        # update the units
        self.convert_units()
        # regrid to same as reference data (holaps)
        self.regrid_to_reference()
        # use the same mask as HOLAPS
        self.use_reference_mask()  # THIS GOING WRONG
        # rename data
        self.rename_xr_object("gleam_evapotranspiration")
        # save data
        save_netcdf(
            self.clean_data, filepath=self.base_data_path / "gleam_EA_clean.nc"
        )
        print("\n\n GLEAM Preprocessed \n\n")
        return
77 |
--------------------------------------------------------------------------------
/data/preprocessing/chirps_cleaner.py:
--------------------------------------------------------------------------------
1 | """chirps_cleaner.py"""
2 | from pathlib import Path
3 | import xarray as xr
4 | import numpy as np
5 | import pandas as pd
6 |
7 | import ipdb
8 | import warnings
9 | import os
10 |
11 | from preprocessing.utils import (
12 | gdal_reproject,
13 | bands_to_time,
14 | convert_to_same_grid,
15 | select_same_time_slice,
16 | save_netcdf,
17 | get_holaps_mask,
18 | merge_data_arrays,
19 | )
20 |
21 | from preprocessing.cleaner import Cleaner
22 |
class ChirpsCleaner(Cleaner):
    """Cleaner that writes `chirps_EA_clean.nc`, regridded and masked like the
    HOLAPS reference dataset.
    """

    def __init__(self):
        """Open the reference data, initialise the parent `Cleaner` and build the mask."""
        self.base_data_path = Path("/soge-home/projects/crop_yield/EGU_compare/")
        reference_data_path = self.base_data_path / "holaps_EA_clean.nc"

        # CHANGE THIS PATH:
        # ----------------------------------------------------------------------
        data_path = self.base_data_path / "EA_chirps_monthly.nc"
        # ----------------------------------------------------------------------

        # open the reference dataset
        self.reference_data_path = Path(reference_data_path)
        self.reference_ds = xr.open_dataset(self.reference_data_path).holaps_evapotranspiration

        # initialise the object using methods from the parent class
        super().__init__(data_path=data_path)

        # extract the variable of interest (TO xr.DataArray)
        self.update_clean_data(
            self.raw_data.precip, msg="Extract Precipitation from CHIRPS xr.Dataset"
        )

        # make the mask (FROM REFERENCE_DS) to copy to this dataset too.
        # NOTE(review): `get_mask` is not defined in this class; presumably
        # it is provided by `Cleaner` - confirm.
        self.get_mask()

    def convert_units(self):
        """Convert the cleaned data from mm month-1 to mm day-1."""
        # 30.417 is approximately 365 / 12 - presumably the mean number of
        # days per month; confirm.
        daily_mm = self.clean_data / 30.417
        # `attrs` is a mapping, so the unit label has to be set as a key;
        # attribute-style assignment does not store it in the attributes.
        daily_mm.attrs["units"] = 'mm day-1'
        self.update_clean_data(daily_mm, msg="Change the mm month-1 values to mm day-1")

    def preprocess(self):
        """Run the cleaning steps in order and save the result as netcdf."""
        # Resample the timesteps to END OF MONTH
        self.resample_time(resample_str="M")
        # select the correct time slice
        self.correct_time_slice()
        # update the units
        self.convert_units()
        # latitude,longitude => lat,lon
        self.rename_lat_lon()
        # regrid to same as reference data (holaps)
        self.regrid_to_reference(method="bilinear")
        # use the same mask as HOLAPS
        self.use_reference_mask()  # THIS GOING WRONG
        # rename data
        self.rename_xr_object("chirps_precipitation")
        # save data
        save_netcdf(
            self.clean_data, filepath=self.base_data_path / "chirps_EA_clean.nc"
        )
        print("\n\n CHIRPS Preprocessed \n\n")
        return
78 |
--------------------------------------------------------------------------------
/predictor/models/base.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | from collections import namedtuple
4 |
5 | from sklearn.metrics import mean_squared_error
6 | from ..preprocessing import VALUE_COLS, VEGETATION_LABELS
7 |
# Container for the arrays of one data split; `ModelBase.load_arrays` returns
# an instance of it. Note the tuple's type name is 'Data', not 'DataTuple'.
DataTuple = namedtuple('Data', ['x', 'y', 'latlon', 'years'])
9 |
10 |
11 | class ModelBase:
12 | """Base for all machine learning models.
13 |
14 | Attributes:
15 | ----------
16 | arrays: pathlib.Path
17 | The location where the arrays were saved by the `Engineer` class
18 | hide_vegetation: bool, default: False
19 | Whether to hide vegetation-specific information from the training
20 | data. This allows us to better understand how the other factors drive
21 | vegetation health.
22 | """
23 |
24 | model_name = None # to be added by the model classes
25 |
    def __init__(self, data=Path('data'), arrays_path=Path('data/processed/arrays'),
                 hide_vegetation=False):
        """Store the configuration shared by all models.

        Parameters:
        ----------
        data: pathlib.Path, default: Path('data')
            Stored as `self.data_path`; `evaluate` saves predictions in a
            `model_name` sub-folder of it
        arrays_path: pathlib.Path, default: Path('data/processed/arrays')
            Folder whose `train` / `test` sub-folders hold the saved arrays
        hide_vegetation: bool, default: False
            Whether `load_arrays` drops the vegetation columns from x
        """
        self.data_path = data
        self.arrays_path = arrays_path
        self.hide_vegetation = hide_vegetation
        self.model = None  # to be added by the model classes
32 |
33 | def train(self):
34 | raise NotImplementedError
35 |
36 | def predict(self):
37 | # This method should return the predictions, and
38 | # the corresponding true values, read from the test
39 | # arrays
40 | raise NotImplementedError
41 |
42 | def save_model(self):
43 | # This method should save the model in data / model_name
44 | raise NotImplementedError
45 |
46 | def evaluate(self, return_eval=False, save_preds=False):
47 | """Evaluates the model using root mean squared error.
48 | This ensures evaluation is consistent across different models.
49 |
50 | Parameters:
51 | ----------
52 | return_eval: bool, default: False
53 | Whether to return the calculated root mean squared error
54 | save_preds: bool, default: False
55 | Whether to save the predictions. If True, they will be saved
56 | in self.arrays_path / preds.npy
57 |
58 | Returns:
59 | ----------
60 | (if return_eval) test_rmse: float
61 | The calculated root mean squared error for the test set
62 | """
63 | y_true, y_pred = self.predict()
64 |
65 | test_rmse = np.sqrt(mean_squared_error(y_true, y_pred))
66 |
67 | print(f'Test set RMSE: {test_rmse}')
68 |
69 | if save_preds:
70 | savedir = self.data_path / self.model_name
71 | if not savedir.exists(): savedir.mkdir()
72 | print(f'Saving predictions to {savedir / "preds.npy"}')
73 | np.save(savedir / 'preds.npy', y_pred)
74 |
75 | if return_eval:
76 | return test_rmse
77 |
78 | def load_arrays(self, mode='train'):
79 |
80 | arrays_path = self.arrays_path / mode
81 |
82 | x = np.load(arrays_path / 'x.npy')
83 |
84 | if self.hide_vegetation:
85 | if mode == 'train':
86 | print('Training model without vegetation features')
87 | indices_to_keep = [idx for idx, val in enumerate(VALUE_COLS) if val not in VEGETATION_LABELS]
88 |
89 | x = x[:, :, indices_to_keep]
90 |
91 | return DataTuple(
92 | latlon=np.load(arrays_path / 'latlon.npy'),
93 | years=np.load(arrays_path / 'years.npy'),
94 | x=x,
95 | y=np.load(arrays_path / 'y.npy'))
96 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vegetation Health
2 |
3 | Predicting vegetation health from precipitation and temperature
4 |
5 | ## Introduction
6 |
7 | This repository experiments with different machine learning models to predict drought indices in East Africa
8 | (specifically the Normalized Difference Vegetation Index) using temperature and precipitation data.
9 |
10 | ## Results
11 |
12 | Models are trained on data before 2016, and evaluated on 2016 data. Vegetation health in June is being predicted.
13 |
14 | In addition, vegetation health can be hidden from the model to better understand the effects of the other features.
15 |
16 | | Model | RMSE | RMSE (no veg) |
17 | |:------------------------:|:----:|:-------------:|
18 | |Linear Regression |0.040 |0.084 |
19 | |Feedforward neural network|0.038 |0.070 |
20 | |Recurrent neural network |0.035 |0.060 |
21 |
22 | The results of the models can also be compared visually with the ground truths (the example below is from the baseline
23 | linear regression):
24 |
25 |
26 |
27 | In addition, the effects of the inputs on the models' predictions are investigated using [shap values](https://github.com/slundberg/shap)
28 | in Jupyter Notebooks, for both the [feedforward neural network](notebooks/04_gt_feedforward_model.ipynb) and the
29 | [recurrent neural network](notebooks/08_gt_recurrent_model.ipynb).
30 |
31 | ## Pipeline
32 |
33 | [Python Fire](https://github.com/google/python-fire) is used to generate a CLI.
34 |
35 | ### Data cleaning
36 |
37 | Normalize values from the original csv file, remove null values, add a year series.
38 |
39 | ```bash
40 | python run.py clean
41 | ```
42 | A target can be selected by adding the flag `--target`, e.g. `--target=ndvi_anomaly`.
43 | By default, the target is `ndvi`. The selected target must be in
44 | [`predictor.preprocessing.VALUE_COLS`](predictor/preprocessing.py).
45 |
46 | The original data is currently generated using datasets on the Oxford University cluster, using the scripts
47 | in [`data`](data).
48 |
49 | ### Data Processing
50 |
51 | Turn the CSV into `numpy` arrays which can be input into the model.
52 |
53 | ```bash
54 | python run.py engineer
55 | ```
56 |
57 | ### Models
58 |
59 | 3 models have been implemented: a baseline linear regression, a feedforward neural network and a
60 | recurrent neural network. They can be selected using the `--model_type` flag.
61 |
62 | ```bash
63 | python run.py train_model
64 | ```
65 |
66 | ## Setup
67 |
68 | [Anaconda](https://www.anaconda.com/download/#macos) running python 3.7 is used as the package manager. To get set up
69 | with an environment, install Anaconda from the link above, and (from this directory) run
70 |
71 | ```bash
72 | conda env create -f environment.yml
73 | ```
74 | This will create an environment named `vegetation_health` with all the necessary packages to run the code. To
75 | activate this environment, run
76 |
77 | ```bash
78 | conda activate vegetation_health
79 | ```
80 |
81 | ## Additional Notes
82 |
83 | - The following variables are used by the model: `['lst_night', 'lst_day', 'precip', 'sm', 'ndvi', 'evi', 'ndvi_anomaly']`.
84 | They are all from different sources
85 |
86 |
87 | - East Africa is defined here as the area of the original `.nc` file (`spi_spei.nc`)
88 |
89 | lat min, lat max : `-4.9750023`, `15.174995`
90 |
91 | lon min, lon max : `32.524994`, `48.274994`
92 |
93 | This makes the following bounding box: (left, bottom, right, top): `(32.524994, -4.9750023, 48.274994, 15.174995)`
94 |
--------------------------------------------------------------------------------
/data/globmap_lookup.py:
--------------------------------------------------------------------------------
1 | """
2 | The lookup values for the GLOBMAP landcover classes
3 | (defined in ee_extract_landcover.js)
4 | """
5 | globmap_lookup = {
6 | 11 : "Post-flooding or irrigated croplands",
7 | 14 : "Rainfed croplands",
8 | 20 : "Mosaic cropland (50-70%) / vegetation (grassland, shrubland, forest) (20-50%)",
9 | 30 : "Mosaic vegetation (grassland, shrubland, forest) (50-70%) / cropland (20-50%)",
10 | 40 : "Closed to open (>15%) broadleaved evergreen and/or semi-deciduous forest (>5m)",
11 | 50 : "Closed (>40%) broadleaved deciduous forest (>5m)",
12 | 60 : "Open (15-40%) broadleaved deciduous forest (>5m)",
13 | 70 : "Closed (>40%) needleleaved evergreen forest (>5m)",
14 | 90 : "Open (15-40%) needleleaved deciduous or evergreen forest (>5m)",
15 | 100 : "Closed to open (>15%) mixed broadleaved and needleleaved forest (>5m)",
16 | 110 : "Mosaic forest-shrubland (50-70%) / grassland (20-50%)",
17 | 120 : "Mosaic grassland (50-70%) / forest-shrubland (20-50%)",
18 | 130 : "Closed to open (>15%) shrubland (<5m)",
19 | 140 : "Closed to open (>15%) grassland",
20 | 150 : "Sparse (>15%) vegetation (woody vegetation, shrubs, grassland)",
21 | 160 : "Closed (>40%) broadleaved forest regularly flooded - Fresh water",
22 | 170 : "Closed (>40%) broadleaved semi-deciduous and/or evergreen forest regularly flooded - saline water",
23 | 180 : "Closed to open (>15%) vegetation (grassland, shrubland, woody vegetation) on regularly flooded or waterlogged soil",
24 | 190 : "Artificial surfaces and associated areas (urban areas >50%) GLOBCOVER 2009",
25 | 200 : "Bare areas",
26 | 210 : "Water bodies",
27 | 220 : "Permanent snow and ice",
28 | 230 : "Unclassified",
29 | }
30 |
31 | globmap_lookup_tommy = {
32 | 11 : "irrigated croplands",
33 | 14 : "Rainfed croplands",
34 | 20 : "cropland/vegetation",
35 | 30 : "vegetation/cropland",
36 | 40 : "broad-leaved evergreen/semi-deciduous forest",
37 | 50 : "broad-leaved deciduous forest",
38 | 60 : "broad-leaved deciduous forest",
39 | 70 : "needle-leaved evergreen forest",
40 | 90 : "needle-leaved deciduous / evergreen forest",
41 | 100 : "broad-leaved / needleleaved forest",
42 | 110 : "forest-shrubland / grassland",
43 | 120 : "Mosaic grassland / forest-shrubland",
44 | 130 : "shrubland",
45 | 140 : "grassland",
46 | 150 : "Sparse vegetation",
47 | 160 : "forest regularly flooded - Fresh water",
48 | 170 : "forest regularly flooded - saline water",
49 | 180 : "vegetation on flooded soil",
50 | 190 : "Artificial",
51 | 200 : "Bare areas",
52 | 210 : "Water bodies",
53 | 220 : "Permanent snow and ice",
54 | 230 : "Unclassified",
55 | }
56 |
57 |
58 | globmap_lookup2 = {
59 | # np.nan : "NA",
60 | 11 : "cropland",
61 | 14 : "cropland",
62 | 20 : "cropland",
63 | 30 : "cropland",
64 | 40 : "forest",
65 | 50 : "forest",
66 | 60 : "forest",
67 | 70 : "forest",
68 | 90 : "forest",
69 | 100 : "forest",
70 | 110 : "shrubland/grassland",
71 | 120 : "shrubland/grassland",
72 | 130 : "shrubland",
73 | 140 : "grassland",
74 | 150 : "Sparse vegetation",
75 | 160 : "flooded",
76 | 170 : "flooded",
77 | 180 : "flooded",
78 | 190 : "Artificial",
79 | 200 : "Bare areas",
80 | 210 : "Water bodies",
81 | 220 : "Permanent snow and ice",
82 | 230 : "Unclassified",
83 | }
84 |
# Integer code for each of the coarse land-cover group names that appear as
# values in `globmap_lookup2`.
remap_to_int_dict = {
    # "NA": 1,
    "cropland": 2,
    "forest": 3,
    "shrubland/grassland": 4,
    "shrubland": 5,
    "grassland": 6,
    "Sparse vegetation": 7,
    "flooded": 8,
    "Artificial": 9,
    "Bare areas": 10,
    "Water bodies": 11,
    "Permanent snow and ice": 12,
    "Unclassified": 13,
}

# Inverse of `remap_to_int_dict`: integer code -> land-cover group name.
globmap_lookup3 = {code: label for label, code in remap_to_int_dict.items()}
102 |
--------------------------------------------------------------------------------
/data/preprocessing/collect_regrid_data_script.py:
--------------------------------------------------------------------------------
1 | """
2 | @tommylees112
3 |
4 | An awful script (sorry Gabi) the data is read in separately for each product.
5 | They are then preprocessed:
6 | resample_time
7 | select_time_slice
8 | regrid_to_reference
9 |
10 | using the precipitation data as reference.
11 | """
12 | # test.py
13 | import xarray as xr
14 | from pathlib import Path
15 | import matplotlib.pyplot as plt
16 |
17 | from preprocessing.utils import (
18 | gdal_reproject,
19 | bands_to_time,
20 | convert_to_same_grid,
21 | select_same_time_slice,
22 | save_netcdf,
23 | get_holaps_mask,
24 | merge_data_arrays,
25 | )
26 |
27 |
def correct_time_slice(ds, reference_ds):
    """Select from `ds` the same time slice as the reference data.

    Thin wrapper around `select_same_time_slice(reference_ds, ds)`.
    """
    return select_same_time_slice(reference_ds, ds)
33 |
34 |
def resample_time(ds, resample_str="M"):
    """Resample `ds` along `time` to `resample_str`, taking the first value of each period."""
    by_period = ds.resample(time=resample_str)
    return by_period.first()
39 |
40 |
def regrid_to_reference(ds, reference_ds, method="nearest_s2d"):
    """Regrid `ds` (spatially) onto the same grid as `reference_ds`.

    `method` is passed straight through to `convert_to_same_grid`.
    """
    return convert_to_same_grid(reference_ds, ds, method=method)
49 |
50 |
def use_reference_mask(ds, mask, one_time=False):
    """Apply the reference mask to `ds`.

    Parameters
    ----------
    ds:
        The data to mask; must provide a `.where(condition)` method.
    mask:
        Boolean mask exposing `.values` (True marks cells to hide) and,
        when `one_time` is set, `.isel(time=...)`.
    one_time: bool, default: False
        If True, only the first timestep of the mask is used, for data with
        a single timestep (e.g. landcover).

    Returns
    -------
    The result of `ds.where(~mask.values)`: values are kept only where the
    mask is False.
    """
    # if only one timestep (e.g. landcover) then convert to one time.
    # This is a plain function, so it must use the `mask` argument
    # (the original referenced a non-existent `self`).
    if one_time:
        mask = mask.isel(time=0)

    return ds.where(~mask.values)
58 |
59 |
def rename_lat_lon(ds):
    """Rename `longitude` -> `lon` and `latitude` -> `lat` on `ds`."""
    short_names = {"longitude": "lon", "latitude": "lat"}
    return ds.rename(short_names)
63 |
64 |
def select_time_slice(ds, timemin, timemax):
    """Select the `time` range `timemin`..`timemax` from `ds` via `ds.sel`."""
    time_window = slice(timemin, timemax)
    return ds.sel(time=time_window)
67 |
68 |
69 | # ------------------------------------------------------------------------------
70 | # Read the data
71 | # ------------------------------------------------------------------------------
# Input locations on the cluster filesystem (hard-coded).
DATA_DIR1 = Path("/soge-home/users/chri4118/EA_data")
DATA_DIR2 = Path("/soge-home/projects/crop_yield/EGU_compare")

# One dataset per product; only the listed variables are kept where a
# variable selection is given.
et = xr.open_dataset(DATA_DIR1 / "ET_EastAfrica.nc")
lst = xr.open_dataset(DATA_DIR1 / "LST_EastAfrica.nc")[["lst_day", "lst_night"]]
sm = xr.open_dataset(DATA_DIR1 / "SM_EastAfrica.nc")[['sm','sm_uncertainty']]
ndvi = xr.open_dataset(DATA_DIR1 / "NDVI_EastAfrica.nc")[["ndvi", "evi"]]
precip = xr.open_dataset(DATA_DIR2 / "EA_chirps_monthly.nc")

# ------------------------------------------------------------------------------
# Clean the data (same timesteps and same gridsizes)
# ------------------------------------------------------------------------------
# PREPARE THE REFERENCE DATA (precipitation): select the time window,
# resample it, and rename latitude/longitude to lat/lon
precip = select_time_slice(precip, '2000-02-14','2016-12-01')
precip = resample_time(precip)
precip = rename_lat_lon(precip)
reference_ds = precip

# `names` must stay in the same order as `all_vars`
all_vars = [et,lst,sm,ndvi,precip]
names = ["et","lst","sm","ndvi","precip"]

# RESAMPLE data (except 'precip', the last entry, which is already done)
out = []
for ix, ds in enumerate(all_vars[:-1]):
    name = names[ix]
    print(f"\n*** working on ds: {name} ***")
    # resample first, then select the same time window as the reference
    ds = resample_time(ds)
    ds = select_time_slice(ds, '2000-02-14','2016-12-01')
    print("selected same time slice")
    # convert to same grid
    ds = regrid_to_reference(ds, reference_ds)
    print("converted to same grid")
    out.append(ds)

# ------------------------------------------------------------------------------
# Merge all of the datasets
# ------------------------------------------------------------------------------
alldata = out + [precip]
OUT = xr.merge(alldata)

# ------------------------------------------------------------------------------
# Save the data to netcdf format
# ------------------------------------------------------------------------------
# the same merged dataset is written to both locations
OUT.to_netcdf(DATA_DIR1 / "OUT2.nc")
OUT.to_netcdf(DATA_DIR2 / "predict_vegetation_health.nc")
119 |
--------------------------------------------------------------------------------
/data/preprocessing/spei_cleaner.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import xarray as xr
3 | import numpy as np
4 |
5 | import ipdb
6 | import warnings
7 | import os
8 |
9 | from preprocessing.utils import (
10 | gdal_reproject,
11 | bands_to_time,
12 | convert_to_same_grid,
13 | select_same_time_slice,
14 | save_netcdf,
15 | get_holaps_mask,
16 | merge_data_arrays,
17 | )
18 |
19 | from preprocessing.cleaner import Cleaner
20 |
21 |
22 | # ------------------------------------------------------------------------------
23 | # HOLAPS cleaner
24 | # ------------------------------------------------------------------------------
25 |
26 |
class SpeiCleaner(Cleaner):
    """Placeholder for the SPEI cleaner (NOT yet implemented).

    The body still operates on HOLAPS files and variables and appears to be
    boilerplate copied from the HOLAPS cleaner. The `assert False` below is a
    deliberate guard: it fires while the class body is being executed, so this
    module cannot be imported until the class has been adapted to SPEI data.
    """
    assert False, "This is just boilerplate code from another"

    def __init__(self):
        # init data paths (should be arguments)
        self.base_data_path = Path("/soge-home/projects/crop_yield/EGU_compare/")
        data_path = self.base_data_path / "holaps_africa.nc"
        reproject_path = self.base_data_path / "holaps_africa_reproject.nc"

        # zero-argument super(): the original named `HolapsCleaner`, which is
        # not defined in this module
        super().__init__(data_path=data_path)
        self.reproject_path = Path(reproject_path)

    def chop_EA_region(self, outfile_path=None):
        """Cut the East Africa box out of the reprojected file using `cdo`.

        Reads `holaps_reprojected.nc` and writes `holaps_EA.nc`, both under
        `self.base_data_path`, then loads the result into `self.clean_data`.

        Parameters
        ----------
        outfile_path: optional
            Currently unused (the output path is fixed). It is accepted only
            so that existing calls that pass it keep working; it now defaults
            to None because `preprocess` calls this method without arguments.
        """
        in_file = self.base_data_path / "holaps_reprojected.nc"
        out_file = self.base_data_path / "holaps_EA.nc"
        lonmin = 32.6
        lonmax = 51.8
        latmin = -5.0
        latmax = 15.2

        cmd = (
            f"cdo sellonlatbox,{lonmin},{lonmax},{latmin},{latmax} {in_file} {out_file}"
        )
        print(f"Running command: {cmd}")
        # NOTE(review): the exit status of the command is not checked
        os.system(cmd)
        print("Chopped East Africa from the Reprojected data")
        re_chopped_data = xr.open_dataset(out_file)
        self.update_clean_data(
            re_chopped_data, msg="Opened the reprojected & chopped data"
        )
        return

    def reproject(self):
        """ reproject to WGS84 / geographic latlon """
        # only run the reprojection when its output file does not exist yet
        if not self.reproject_path.is_file():
            gdal_reproject(infile=self.data_path, outfile=self.reproject_path)

        repr_data = xr.open_dataset(self.reproject_path)

        # get the timestamps from the original holaps data
        h_times = self.clean_data.time
        # each BAND is a time (multiple raster images 1 per time)
        repr_data = bands_to_time(repr_data, h_times, var_name="LE_Mean")

        # TODO: ASSUMPTION / PROBLEM
        warnings.warn(
            "TODO: No idea why but the values appear to be 10* bigger than the pre-reprojected holaps data"
        )
        repr_data /= 10  # WHY ARE THE VALUES 10* bigger?

        self.update_clean_data(repr_data, "Data Reprojected to WGS84")

        save_netcdf(
            self.clean_data, filepath=self.base_data_path / "holaps_reprojected.nc"
        )
        return

    def convert_units(self):
        """Convert from latent heat (w m-2) to evaporation (mm day-1)."""
        holaps_mm = self.clean_data / 28
        # keep only the LE_Mean variable
        holaps_mm = holaps_mm.LE_Mean
        holaps_mm.name = "Evapotranspiration"
        holaps_mm.attrs["units"] = "mm day-1 [w m-2 / 28]"
        self.update_clean_data(
            holaps_mm, msg="Transform Latent Heat (w m-2) to Evaporation (mm day-1)"
        )

        return

    def preprocess(self):
        """Run all the preprocessing steps in order and save the result."""
        # reproject the file from sinusoidal to WGS84 / 'ESPG:4326'
        self.reproject()
        # chop out the correct lat/lon (changes when reprojected)
        self.chop_EA_region()
        # convert the units
        self.convert_units()
        # rename data
        self.rename_xr_object("holaps_evapotranspiration")
        # resample the time units
        self.resample_time()
        # save the netcdf file (used as reference data for MODIS and GLEAM)
        save_netcdf(
            self.clean_data, filepath=self.base_data_path / "holaps_EA_clean.nc",
            force=True
        )
        print("\n\n HOLAPS Preprocessed \n\n")
        return
118 |
--------------------------------------------------------------------------------
/predictor/engineer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import pandas as pd
3 | import numpy as np
4 |
5 | from .preprocessing import VALUE_COLS
6 |
7 |
class Engineer:
    """Take the clean csv file and turn it into numpy arrays ready
    for training

    Attributes:
    ----------
    cleaned_data_path: pathlib.Path
        The location of cleaned data, as saved by the `Cleaner` class
    arrays_path: pathlib.Path
        The location where the arrays will be saved
    """

    def __init__(self, cleaned_data=Path('data/processed/cleaned_data.csv'),
                 arrays=Path('data/processed/arrays')):

        self.arrays_path = arrays
        # parents/exist_ok: works for a fresh checkout (missing parent
        # directories) as well as when the directory is already there
        self.arrays_path.mkdir(parents=True, exist_ok=True)
        self.cleaned_data_path = cleaned_data

    def readfile(self):
        """Read the cleaned csv file into a `pandas.DataFrame`."""
        return pd.read_csv(self.cleaned_data_path)

    def process(self, test_year=2016):
        """Takes the processed data saved by the `preprocessing.Cleaner` class,
        and turns it into `np.array`s which can be ingested by the machine learning
        models

        Parameters
        ----------
        test_year: int, default: 2016
            Data from this year will be used for testing, and so will be saved in
            separate arrays. Earlier years are used for training; later years
            are not saved.

        Raises
        ------
        ValueError
            If no (lat, lon, year) group has the 12 rows required to build an example.

        The following are saved, for both the training and test sets:

        {train, test}/latlon.npy:
            The locations of each data instance (so that latlon[i] represents the latitude and
            longitude of the ith data point)
        {train, test}/years.npy:
            The years of each data point. Specifically, this represents the prediction year, so if
            `pred_month` passed to the Cleaner = 6, then if years[i] = 2015, that means y[i] is the
            value of the target in June 2015.
        {train, test}/x.npy:
            The input features; the previous 11 months of data
        {train, test}/y.npy:
            The targets - the value of the target variable at `pred_month`.
        """
        data = self.readfile()

        # outputs
        latlons, years, vals, targets = [], [], [], []

        skipped = 0
        # first, groupby lat, lon, so that we process the same place together
        for latlon, group in data.groupby(by=['lat', 'lon']):
            latlon_np = np.array(latlon)
            for year, subgroup in group.groupby(by='gb_year'):
                if len(subgroup) != 12:
                    # incomplete pixel-year: cannot build 11 inputs + 1 target
                    skipped += 1
                    continue
                subgroup = subgroup.sort_values(by='gb_month', ascending=True)

                # features (VALUE_COLS) come from the first 11 rows,
                # the target from the last row
                x = subgroup[:-1][VALUE_COLS].values
                y = subgroup.iloc[-1]['target']

                # create lists of np.arrays
                latlons.append(latlon_np)
                years.append(year)
                vals.append(x)
                targets.append(y)

                if len(latlons) % 1000 == 0:
                    print(f'Processed {len(latlons)} examples')

        print(f'Done processing {len(latlons)} pixel-years! Skipped {skipped} pixel-years due to missing rows')

        if not latlons:
            # np.vstack / np.stack would fail on empty lists with a much less
            # helpful message
            raise ValueError(
                f'No complete pixel-years found in {self.cleaned_data_path}')

        # turn everything into np arrays for manipulation
        latlons, years = np.vstack(latlons), np.array(years)
        vals, targets = np.stack(vals), np.array(targets)

        # split into train and test sets
        test_idx = np.where(years == test_year)[0]
        train_idx = np.where(years < test_year)[0]

        test_arrays = self.arrays_path / 'test'
        train_arrays = self.arrays_path / 'train'

        test_arrays.mkdir(parents=True, exist_ok=True)
        train_arrays.mkdir(parents=True, exist_ok=True)

        print('Saving data')
        for directory, idx in ((train_arrays, train_idx), (test_arrays, test_idx)):
            np.save(directory / 'latlon.npy', latlons[idx])
            np.save(directory / 'years.npy', years[idx])
            np.save(directory / 'x.npy', vals[idx])
            np.save(directory / 'y.npy', targets[idx])
110 |
--------------------------------------------------------------------------------
/data/preprocessing/cleaner.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import xarray as xr
3 | import numpy as np
4 |
5 | import ipdb
6 | import warnings
7 | import os
8 |
9 | from preprocessing.utils import (
10 | gdal_reproject,
11 | bands_to_time,
12 | convert_to_same_grid,
13 | select_same_time_slice,
14 | save_netcdf,
15 | get_holaps_mask,
16 | merge_data_arrays,
17 | )
18 |
19 |
20 | # ------------------------------------------------------------------------------
21 | # Base cleaner
22 | # ------------------------------------------------------------------------------
23 |
24 |
class Cleaner:
    """Base class for preprocessing the input data.

    Tasks include:
    - Reprojecting
    - Putting datasets onto a consistent spatial grid (spatial resolution)
    - Converting to equivalent units
    - Converting to the same temporal resolution
    - Selecting the same time slice

    Design Considerations:
    - Have an attribute, 'clean_data', that is constantly updated
    - Keep a copy of the raw data for reference
    - Update the 'clean_data' each time a transformation is applied

    Several methods need `self.reference_ds` (and `self.mask`). Neither is
    set in this class's `__init__`; they must be provided by the subclass.
    """

    def __init__(self, data_path):
        self.data_path = Path(data_path)

        # open the datasets using xarray
        self.raw_data = xr.open_dataset(self.data_path)

        # start with clean data as a copy of the raw data
        self.clean_data = self.raw_data.copy()

    def get_mask(self):
        """Build `self.mask` from the reference data via `get_holaps_mask`."""
        self.mask = get_holaps_mask(self.reference_ds)

    def update_clean_data(self, clean_data, msg=""):
        """Replace `self.clean_data` with `clean_data` and print `msg`."""
        self.clean_data = clean_data
        print("***** self.clean_data Updated: ", msg, " *****")

        return

    def _require_reference(self):
        """Assert that `self.reference_ds` has been set.

        `getattr` is used so that an attribute that was never assigned gives
        the explanatory assertion message instead of an AttributeError.
        """
        assert (
            getattr(self, "reference_ds", None) is not None
        ), "self.reference_ds does not exist! Likely because you're not using the MODIS or GLEAM cleaners / correct data paths"

    def correct_time_slice(self):
        """select the same time slice as the reference data"""
        self._require_reference()
        correct_time_slice = select_same_time_slice(self.reference_ds, self.clean_data)

        self.update_clean_data(
            correct_time_slice, msg="Selected the same time slice as reference data"
        )
        return

    def resample_time(self, resample_str="M"):
        """Resample `clean_data` along `time` to `resample_str`, taking the
        first value of each period."""
        resampled_time_data = self.clean_data.resample(time=resample_str).first()
        self.update_clean_data(resampled_time_data, msg="Resampled time ")

        return

    def regrid_to_reference(self, method="nearest_s2d"):
        """ regrid data (spatially) onto the same grid as reference data """
        self._require_reference()

        regrid_data = convert_to_same_grid(
            self.reference_ds, self.clean_data, method=method
        )
        # UPDATE THE SELF.CLEAN_DATA
        self.update_clean_data(regrid_data, msg="Data Regridded to same as HOLAPS")
        return

    def use_reference_mask(self, one_time=False):
        """Apply `self.mask` to `clean_data`, keeping values where it is False.

        Parameters
        ----------
        one_time: bool, default: False
            If True, `self.mask` is first reduced (in place) to its first
            timestep, for data with a single timestep (e.g. landcover).
        """
        self._require_reference()
        # check that the mask exists BEFORE inspecting its coordinates
        assert (
            getattr(self, "mask", None) is not None
        ), "self.mask does not exist! Likely because you're not using the MODIS or GLEAM cleaners / correct data paths"
        assert 'units' not in self.mask.coords, "MUST NOT HAVE EXTRA COORDS or you remove ALL values. self.mask has 'units' coord and needs to be dropped:\n self.mask = self.mask.drop('units')"

        # if only one timestep (e.g. landcover) then convert to one time
        if one_time:
            self.mask = self.mask.isel(time=0)

        masked_d = self.clean_data.where(~self.mask.values)
        self.update_clean_data(masked_d, msg="Copied the mask from reference to dataset!")
        return

    def mask_illegitimate_values(self):
        """Mask out the missing values (coded as something else).

        Not implemented in the base class.
        """
        # raise (the original *returned* the exception class, so calls
        # silently succeeded)
        raise NotImplementedError

    def convert_units(self):
        """ convert to the equivalent units """
        raise NotImplementedError

    def rename_xr_object(self, name):
        """Rename `clean_data` to `name`."""
        renamed_data = self.clean_data.rename(name)
        self.update_clean_data(renamed_data, msg=f"Data renamed {name}")
        return

    def rename_lat_lon(self):
        """Rename `longitude`/`latitude` to `lon`/`lat` on `clean_data`."""
        rename_latlon = self.clean_data.rename({"longitude": "lon", "latitude": "lat"})
        self.update_clean_data(rename_latlon, msg="Renamed latitude,longitude => lat,lon")
        return

    def preprocess(self):
        """ The preprocessing steps (relatively unique for each dtype) """
        raise NotImplementedError
132 |
--------------------------------------------------------------------------------
/predictor/models/neural_networks/nn_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from torch.utils.data import TensorDataset, DataLoader, random_split
4 |
5 | from pathlib import Path
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | from ..base import ModelBase, DataTuple
10 |
11 |
class NNBase(ModelBase):
    """The base for neural networks models.

    Wraps a `torch.nn.Module` and provides the shared training loop (with
    early stopping on a held-out validation split), prediction and saving.
    """

    def __init__(self, model, data=Path('data'), arrays=Path('data/processed/arrays'),
                 hide_vegetation=False):
        super().__init__(data, arrays, hide_vegetation)

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # for reproducibility
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)

        if torch.cuda.is_available():
            model = model.cuda()
        self.model = model

    def _snapshot_state(self):
        """Return a copy of the model's current state_dict.

        `state_dict()` hands back tensors that share storage with the live
        parameters, so they keep changing as training continues; cloning
        freezes the values at this point in time.
        """
        return {name: tensor.detach().clone()
                for name, tensor in self.model.state_dict().items()}

    def train(self, num_epochs=100, patience=3, batch_size=32, learning_rate=1e-3):
        """Train the neural network

        The model is left with the weights from the epoch that had the lowest
        validation loss.

        Parameters
        ----------
        num_epochs: int
            The maximum number of epochs to train the model for
        patience: int
            If no improvement is seen in the validation set for `patience` epochs,
            training is stopped to prevent overfitting
        batch_size: int
            The batch size to use
        learning_rate: float
            The learning rate to use when updating model parameters
        """
        train_data = self.load_tensors(mode='train')

        # split the data into a training and validation set
        total_size = train_data.x.shape[0]
        val_size = total_size // 10  # 10 % for validation
        train_size = total_size - val_size
        print(f'After split, training on {train_size} examples, '
              f'validating on {val_size} examples')
        train_dataset, val_dataset = random_split(
            TensorDataset(train_data.x, train_data.y.unsqueeze(1)),
            (train_size, val_size))

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

        optimizer = torch.optim.Adam(list(self.model.parameters()),
                                     lr=learning_rate)

        epochs_without_improvement = 0
        best_loss = np.inf
        # start from the initial weights so `best_state` is always defined,
        # even if the validation loss never improves (or num_epochs == 0)
        best_state = self._snapshot_state()

        for epoch in range(num_epochs):
            self.model.train()

            epoch_train_loss = 0
            num_train_batches = 0

            epoch_val_loss = 0
            num_val_batches = 0

            for train_x, train_y in tqdm(train_dataloader):
                optimizer.zero_grad()
                pred_y = self.model(train_x)

                loss = F.smooth_l1_loss(pred_y, train_y)
                loss.backward()
                optimizer.step()

                num_train_batches += 1
                # NOTE: this is the square root of the smooth L1 loss, so the
                # value printed as "Training RMSE" is not a true RMSE
                epoch_train_loss += loss.sqrt().item()

            self.model.eval()
            with torch.no_grad():
                for val_x, val_y in tqdm(val_dataloader):
                    val_pred_y = self.model(val_x)
                    val_loss = F.mse_loss(val_pred_y, val_y)

                    num_val_batches += 1
                    epoch_val_loss += val_loss.sqrt().item()

            # average of the per-batch values
            epoch_train_loss /= num_train_batches
            epoch_val_loss /= num_val_batches

            print(f'Epoch {epoch} - Training RMSE: {epoch_train_loss}, '
                  f'Validation RMSE: {epoch_val_loss}')

            if epoch_val_loss < best_loss:
                # copy, don't alias: the live state_dict would follow the
                # parameters through later (worse) epochs
                best_state = self._snapshot_state()
                best_loss = epoch_val_loss

                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement == patience:
                    self.model.load_state_dict(best_state)
                    print('Early stopping!')
                    return
        self.model.load_state_dict(best_state)

    def predict(self, batch_size=64):
        """Run the model on the test arrays.

        Parameters
        ----------
        batch_size: int
            The batch size to use

        Returns
        ----------
        (y_true, y_pred): tuple of np.array
            The true test targets and the corresponding predictions
        """
        test_data = self.load_tensors(mode='test')

        test_dataloader = DataLoader(TensorDataset(test_data.x, test_data.y),
                                     batch_size=batch_size)

        output_preds, output_true = [], []

        self.model.eval()
        with torch.no_grad():
            for test_x, test_y in tqdm(test_dataloader):
                output_preds.append(self.model(test_x).squeeze(1).cpu().numpy())
                output_true.append(test_y.cpu().numpy())
        return np.concatenate(output_true), np.concatenate(output_preds)

    def load_tensors(self, mode='train'):
        """Load the arrays for `mode`, with x and y as float tensors on `self.device`.

        latlon and years are passed through unchanged from `load_arrays`.
        """
        data = self.load_arrays(mode)

        return DataTuple(
            latlon=data.latlon,
            years=data.years,
            x=torch.as_tensor(data.x, device=self.device).float(),
            y=torch.as_tensor(data.y, device=self.device).float())

    def save_model(self, name="model.pt"):
        """Save the model in `data_path / model_name / name`.

        The whole model object is saved with `torch.save` (not only its
        state_dict). '.pt' is appended to `name` if it does not contain it.
        """
        savedir = self.data_path / self.model_name
        savedir.mkdir(parents=True, exist_ok=True)

        # save with .pt extension
        if '.pt' not in name:
            name += '.pt'
        print(f'Saving model to {savedir / name}')
        torch.save(self.model, savedir / name)
145 |
--------------------------------------------------------------------------------
/predictor/models/neural_networks/recurrent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | import math
5 | from pathlib import Path
6 |
7 | from .nn_base import NNBase
8 | from ...preprocessing import VALUE_COLS, VEGETATION_LABELS
9 |
10 |
class Recurrent(NNBase):
    """A recurrent neural network model.

    Wraps an ``RNN`` module (one dense layer of 256 units after the
    recurrent cell) in the shared ``NNBase`` interface.
    """

    def __init__(self, data=Path('data'), arrays=Path('data/processed/arrays'),
                 hide_vegetation=False):

        # the vegetation columns are not counted when they are hidden
        hidden_columns = len(VEGETATION_LABELS) if hide_vegetation else 0
        features_per_month = len(VALUE_COLS) - hidden_columns

        network = RNN(features_per_month, [256])
        super().__init__(network, data, arrays, hide_vegetation)
        self.model_name = "recurrent"
25 |
class RNN(nn.Module):
    """An unrolled recurrent cell followed by a stack of dense layers.

    Parameters
    ----------
    features_per_month: int
        Size of the feature vector at each timestep.
    dense_features: list of int
        Output sizes of the dense layers applied to the final hidden state.
        ``hidden_size`` is prepended as the first input size and a final
        size of 1 is appended when missing. The list passed in is not
        modified.
    hidden_size: int
        Size of the hidden and cell states.
    rnn_dropout: float
        Dropout probability applied to the hidden state after every timestep.
    """
    def __init__(self, features_per_month, dense_features, hidden_size=128,
                 rnn_dropout=0.25):
        super().__init__()

        # Build the layer sizes on a copy so the caller's list is left intact
        layer_sizes = [hidden_size, *dense_features]
        if layer_sizes[-1] != 1:
            layer_sizes.append(1)

        self.dropout = nn.Dropout(rnn_dropout)
        self.rnn = UnrolledRNN(input_size=features_per_month,
                               hidden_size=hidden_size,
                               batch_first=True)
        self.hidden_size = hidden_size

        self.dense_layers = nn.ModuleList([
            nn.Linear(in_features=layer_sizes[i - 1],
                      out_features=layer_sizes[i])
            for i in range(1, len(layer_sizes))
        ])

        self.initialize_weights()

    def initialize_weights(self):
        """Initialise the recurrent weights uniformly in
        [-sqrt(1 / hidden_size), sqrt(1 / hidden_size)], the dense weights
        with Kaiming-uniform and the dense biases with 0.
        """
        sqrt_k = math.sqrt(1 / self.hidden_size)
        for parameters in self.rnn.parameters():
            # Filled slice by slice (rows of a weight, elements of a bias),
            # kept that way so seeded initialisations stay reproducible.
            for pam in parameters:
                nn.init.uniform_(pam.data, -sqrt_k, sqrt_k)

        for dense_layer in self.dense_layers:
            nn.init.kaiming_uniform_(dense_layer.weight.data)
            nn.init.constant_(dense_layer.bias.data, 0)

    def forward(self, x):
        """Run the recurrent cell over every timestep of ``x`` and pass the
        final hidden state through the dense layers.

        ``x`` is indexed as ``x[:, i, :]`` per timestep, i.e. the first axis
        is the batch and the second axis is the sequence.
        """
        batch_size, sequence_length = x.shape[0], x.shape[1]

        # Allocate the initial state directly on the input's device
        hidden_state = torch.zeros(1, batch_size, self.hidden_size,
                                   device=x.device)
        cell_state = torch.zeros(1, batch_size, self.hidden_size,
                                 device=x.device)

        for i in range(sequence_length):
            # The reason the RNN is unrolled here is to apply dropout to each timestep;
            # The rnn_dropout argument only applies it after each layer.
            # https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/DropoutWrapper
            input_x = x[:, i, :].unsqueeze(1)
            _, (hidden_state, cell_state) = self.rnn(input_x,
                                                     (hidden_state, cell_state))
            hidden_state = self.dropout(hidden_state)

        x = hidden_state.squeeze(0)
        for dense_layer in self.dense_layers:
            x = dense_layer(x)
        return x
92 |
93 |
class UnrolledRNN(nn.Module):
    """A single hand-written recurrent step with forget, update and output
    gates. The motivation for writing it out is mainly so that the model can
    be explained with the shap deep explainer, but also because the RNN is
    unrolled anyway to apply dropout.
    """

    def __init__(self, input_size, hidden_size, batch_first=True):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first

        def make_gate(activation):
            # every gate is a linear map of [input, hidden] plus an activation
            linear = nn.Linear(in_features=input_size + hidden_size,
                               out_features=hidden_size, bias=True)
            return nn.Sequential(linear, activation)

        self.forget_gate = make_gate(nn.Sigmoid())
        self.update_gate = make_gate(nn.Sigmoid())
        self.update_candidates = make_gate(nn.Tanh())
        self.output_gate = make_gate(nn.Sigmoid())

        self.cell_state_activation = nn.Tanh()

    def forward(self, x, state):
        """Advance the cell by one step.

        ``state`` is a ``(hidden, cell)`` pair. Returns
        ``(new_hidden, (new_hidden, new_cell))``.
        """
        hidden, cell = state

        if self.batch_first:
            # swap the first two axes so the state lines up with x
            hidden = hidden.transpose(0, 1)
            cell = cell.transpose(0, 1)

        # every gate reads the same concatenated [input, hidden] vector
        joint = torch.cat((x, hidden), dim=-1)

        keep = self.forget_gate(joint)
        write = self.update_gate(joint)
        candidates = self.update_candidates(joint)

        new_cell = (keep * cell) + (write * candidates)
        new_hidden = self.output_gate(joint) * self.cell_state_activation(new_cell)

        if self.batch_first:
            new_hidden = new_hidden.transpose(0, 1)
            new_cell = new_cell.transpose(0, 1)

        return new_hidden, (new_hidden, new_cell)
147 |
--------------------------------------------------------------------------------
/predictor/preprocessing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pathlib import Path
3 | import xarray as xr
4 | import numpy as np
5 | import json
6 |
7 | KEY_COLS = ['lat', 'lon', 'time', 'gb_year', 'gb_month']
8 | VALUE_COLS = ['lst_night', 'lst_day', 'precip', 'sm', 'ndvi', 'evi', 'ndvi_anomaly']
9 | VEGETATION_LABELS = ['ndvi', 'evi', 'ndvi_anomaly']
10 |
11 |
class Cleaner:
    """Clean the input data (from the .nc file), by removing nan
    (or encoded-as-nan) values, and normalising the non-key values.

    Does some preprocessing on the .nc file using xarray and then converts
    it to a dataframe for the other methods

    Attributes:
    ----------
    raw_filepath: pathlib.Path
        The location of the raw .NC data
    processed_filepath: pathlib.Path
        The location where the processed csv will be saved. The normalizing dict
        will be saved in the same directory.
    """
    def __init__(self, raw_filepath=Path('data/raw/OUT.NC'),
                 processed_filepath=Path('data/processed/cleaned_data.csv')):

        self.filepath = raw_filepath

        # Create the full output directory chain; a plain mkdir() raises when
        # an intermediate directory does not exist yet.
        output_dir = processed_filepath.parents[0]
        output_dir.mkdir(parents=True, exist_ok=True)

        self.processed_filepath = processed_filepath
        self.normalizing_dict = output_dir / 'normalizing_dict.json'

    def process(self, pred_month=6, target='ndvi_anomaly'):
        """Preprocesses the raw data, and saves it. Specifically, the following
        preprocessing happens:
            1. `gb_year` and `gb_month`, which are the dates relative to
                `pred_month`, are added.
            2. `ndvi_anomaly` is calculated
            3. NaN values (and missing data) is removed from the dataframe.
            4. Normalizes all values to have mean 0 and std 1

        Parameters
        ----------
        pred_month: int
            The month for which the target value should be predicted. This value will be
            predicted using the preceding 11 months of data
        target: str
            The target variable being predicted. Must be one of VALUE_COLS.

        A processed CSV and a .json object containing the values used to normalize
        each variable are saved.
        """

        assert target in VALUE_COLS, f'{target} not in {VALUE_COLS}'

        data = self._readfile(pred_month)

        # copied before the loop below, so `target` keeps the raw values
        data['target'] = data[target]

        normalizing_dict = {}
        for col in VALUE_COLS:
            print(f'Normalizing {col}')

            series = data[col]

            # calculate normalizing values
            mean, std = series.mean(), series.std()
            # add them to the dictionary
            normalizing_dict[col] = {
                'mean': float(mean), 'std': float(std),
            }

            data[col] = (series - mean) / std

        print("Saving data")
        data.to_csv(self.processed_filepath, index=False)
        print(f'Saved {self.processed_filepath}')

        with open(self.normalizing_dict, 'w') as f:
            json.dump(normalizing_dict, f)

        print(f'Saved {self.normalizing_dict}')

    def _readfile(self, pred_month):
        """Open the raw dataset, add the relative month / year columns and the
        ndvi anomaly, and return a dataframe restricted to
        KEY_COLS + VALUE_COLS with every row containing a NaN dropped.
        """
        data = xr.open_dataset(self.filepath)

        data['gb_month'], data['gb_year'] = self._update_year_month(
            pd.to_datetime(data.time.to_series()), pred_month)

        # mask out the invalid temperature values
        lst_cols = ['lst_night', 'lst_day']
        for var_ in lst_cols:
            # for the lst_cols, missing data is coded as 200
            valid = (data[var_] < 200) | (np.isnan(data[var_]))
            data[var_] = data[var_].fillna(np.nan).where(valid)

        return_cols = KEY_COLS + VALUE_COLS

        # compute the ndvi_anomaly
        data['ndvi_anomaly'] = self._compute_anomaly(data.ndvi)

        # drop any Pixel-Times with missing values
        data = data.to_dataframe().reset_index()
        data.dropna(how='any', axis=0, inplace=True)

        print(f'Loaded {len(data)} rows!')
        return data[return_cols]

    @staticmethod
    def _update_year_month(times, pred_month):
        """Given a pred month (e.g. 6), this method will return two new series with
        updated months and years so that a "year" of data will be the 11 months preceding the
        pred_month, and the pred_month. This makes it easier for the engineer to then make the training
        data

        Returns a ``(months, years)`` tuple. With ``pred_month == 12`` the
        calendar month and year are returned unchanged.
        """
        if pred_month == 12:
            return times.dt.month, times.dt.year

        relative_times = times - pd.DateOffset(months=pred_month)

        # we add one year so that the year column the engineer makes will be reflective
        # of the pred year, which is shifted because of the data offset we used
        return relative_times.dt.month, relative_times.dt.year + 1

    @staticmethod
    def _compute_anomaly(da, time_group='time.month'):
        """ Return a dataarray where values are an anomaly from the MEAN for that
         location at a given timestep. Defaults to finding monthly anomalies.
         Notes: http://xarray.pydata.org/en/stable/examples/weather-data.html#calculate-monthly-anomalies

        In addition, since 2016 is being used as the prediction year, data from that year
        (and later) is not being used to compute the mean.

        Arguments:
        ---------
        : da (xr.DataArray)
        : time_group (str)
            time string to group.
        """
        print('Computing ndvi anomaly')
        assert isinstance(da, xr.DataArray), f"`da` should be of type `xr.DataArray`. Currently: {type(da)}"
        # NOTE(review): boolean indexing on the first axis - assumes `time`
        # is the leading dimension of `da`; confirm against the raw file.
        trimmed_da = da[da['time.year'] < 2016]
        mthly_vals = trimmed_da.groupby(time_group).mean('time')
        da = da.groupby(time_group) - mthly_vals

        return da
152 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import xarray as xr
2 | import pandas as pd
3 | import numpy as np
4 | from pathlib import Path
5 | import os
6 | import warnings
7 |
8 | # invert the upside down variables
def invert(da):
    """Reverse the second axis of ``da.values`` in place (the latitude axis,
    so that plots are the right way up) and return ``da``.

    The array must be 3-dimensional and must have a name. The name is
    checked before anything is modified, so a failed call leaves ``da``
    untouched.
    """
    assert da.name, "the dataarray passed to function invert() must have a name!"
    da.values = da.values[:, ::-1, :]
    print(f'inverted the values of {da.name}')
    return da
15 |
16 |
def get_lc_mask(ds, mask_var):
    """Build a boolean mask named "lc_mask" that is True wherever ``spi``
    is not null at time index 1.

    ``mask_var`` is validated but not otherwise used: the mask is always
    derived from ``spi`` (a warning is emitted to flag this).
    """
    allowed = ["spi", "spei", "ndvi"]
    assert mask_var in allowed, "Mask Variable should be one of spi spei ndvi (ADD MORE)"

    warnings.warn('Currently get_lc_mask() is working with hardcoded SPI values. May cause problems')

    # create a land cover mask
    spi_snapshot = ds['spi'].isel(time=1)
    lc_mask = ~spi_snapshot.isnull()
    lc_mask.name = "lc_mask"

    print("created lc_mask")
    return lc_mask
25 |
26 |
def remove_excess_parameters(ds):
    """Return ``ds`` restricted to a fixed list of variables.

    Every variable in the list must be present in ``ds``; otherwise an
    AssertionError naming the first missing one is raised.
    TODO: remove hardcoded params
    """
    key_parameters = [
        "lst_mean", "lst_day", "lst_night", "sm", "precip",
        "evaporation", "transpiration", "spi", "spei", "ndvi", "evi",
        "surface_soil_moisture", "rootzone_soil_moisture",
        "baresoil_evaporation", "potential_evaporation",
    ]

    # check that all the parameters exist in the dataset
    available = list(ds.variables.keys())
    for param in key_parameters:
        assert param in available, f"Param {param} not found!"

    subset = ds[key_parameters]
    print(f"Keeping parameters: {key_parameters}")
    return subset
38 |
39 |
def mask_sea(ds, lc_mask):
    """Keep the values of ``ds`` only where ``lc_mask`` is true
    (``ds.where(lc_mask)``) and return the masked dataset.
    """
    masked = ds.where(lc_mask)
    var_names = [var for var in masked.data_vars.keys()]
    print(f"Sea values masked for ds with vars {var_names}")
    return masked
45 |
46 |
def get_boolean_drought_ds(ds):
    """Select the three drought variables (``drought_spei``, ``drought_spi``,
    ``drought_ndvi``) from ``ds``. All three must be present.
    """
    for required in ("drought_ndvi", "drought_spei", "drought_spi"):
        present = [var for var in ds.variables.keys()]
        assert required in present, f"{required} should be in {present}"

    drought = ds[['drought_spei', 'drought_spi', 'drought_ndvi']]
    print("Created a drought index xr.Dataset")
    return drought
55 |
56 |
def save_netcdf(output_ds, filename):
    """Write ``output_ds`` to ``filename`` via ``to_netcdf`` and report it."""
    output_ds.to_netcdf(filename)
    message = f"{filename} saved!"
    print(message)
62 |
63 |
def clean_lst_variables(ds):
    """Blank out invalid values in every data variable whose name contains
    "lst": only values below 200 are kept, everything else becomes NaN.
    """
    # collect the names first, because ds is assigned to inside the loop
    lst_names = [name for name in ds.data_vars.keys() if "lst" in name]

    for name in lst_names:
        # filter OUT the lst values of 200
        below_limit = ds[name] < 200
        ds[name] = ds[name].fillna(np.nan).where(below_limit)

    return ds
72 |
73 |
def clean_data(ds, lc_mask):
    """Apply the cleaning steps in order: drop the final timestep, invert
    precip, keep only the key parameters, mask out the sea and clean the
    lst variables.
    """
    # drop the final timestep
    trimmed = ds.isel(time=slice(0, -1))

    # invert precip
    trimmed['precip'] = invert(trimmed.precip)

    # key parameters only -> sea masked -> lst values cleaned
    selected = remove_excess_parameters(trimmed)
    land_only = mask_sea(selected, lc_mask)
    return clean_lst_variables(land_only)
91 |
92 |
def get_df_of_pixels_to_remove(lc_mask):
    """Return a dataframe of the pixels where ``lc_mask`` is False, i.e. the
    SEA points that should be removed. The ``time`` column is dropped.
    """
    frame = lc_mask.to_dataframe().reset_index()
    frame = frame.drop(columns=["time"])

    # rows where the mask is True become NaN and are then dropped
    sea_only = frame.where(~frame.lc_mask)
    return sea_only.dropna()
101 |
102 |
103 |
def shift_by_time(ds, ts=1):
    """Shift ``ds`` along its ``time`` dimension by ``ts`` timesteps.

    Parameters
    ----------
    ds:
        Object exposing ``shift(time=...)`` (e.g. an xarray Dataset).
    ts: int
        Number of timesteps to shift by, passed as ``ds.shift(time=ts)``.
        Defaults to 1. Previously the function referenced an undefined
        ``ts`` and always raised NameError.
    """
    return ds.shift(time=ts)
107 |
108 |
109 |
def mask_ds(ds, drought_mask, drought=True):
    """Placeholder for masking SEA / DROUGHT pixels.

    Currently no masking is applied: ``drought_mask`` and ``drought`` are
    unused and ``ds`` is returned unchanged after printing a message.
    """
    message = "drought pixels masked"
    print(message)
    return ds
115 |
116 |
def make_masks_boolean(mask_ds):
    """Convert masks from 0,1 to False,True.

    Works for objects with ``data_vars`` (datasets) as well as objects that
    only have a ``name`` (data arrays) or neither. When the conversion
    fails, a message is printed and the input is returned unchanged.
    """
    # Work out a label for the log messages without risking an exception:
    # data arrays have no .data_vars, and plain arrays have no .name either.
    data_vars = getattr(mask_ds, "data_vars", None)
    if data_vars is not None:
        label = [var for var in data_vars.keys()]
    else:
        label = getattr(mask_ds, "name", None)

    print(f"Convert mask to bool for ds with vars {label}")
    try:
        return mask_ds.astype(bool)
    except (AttributeError, TypeError, ValueError):
        print(f"UNABLE to convert mask to bool for ds with vars {label}")
        return mask_ds
128 |
129 |
def read_data(data_path='.', mask_var='spi'):
    """Open the netcdf file at ``data_path`` and return the cleaned dataset
    together with its land-cover and drought masks.

    Parameters
    ----------
    data_path: str
        Path to an existing file; an AssertionError is raised otherwise.
    mask_var: str
        Passed to ``get_lc_mask``.

    Returns
    -------
    ``(ds, lc_mask, drought_mask)``
    """
    assert os.path.isfile(data_path), f"The path provided to read data does not exist! Currently: {data_path}"
    # data_path = "/Volumes/Lees_Extend/data/ea_data/all_variables_LC2.nc"
    print(f"Reading from file: {data_path}")
    ds = xr.open_dataset(data_path)
    lc_mask = get_lc_mask(ds, mask_var)
    # --------------------------------------------------------------------------
    # OFFENDING LINE
    # drought_ndvi is defined here as ndvi more than one standard deviation
    # below its mean, both computed over the whole time dimension.
    warnings.warn('drought_ndvi hardcoded in here. Not at all okay.gst FIX ME')
    ds['drought_ndvi'] = ds.ndvi < (ds.ndvi.mean(dim='time') - ds.ndvi.std(dim='time'))
    # --------------------------------------------------------------------------
    # NOTE(review): get_boolean_drought_ds also requires drought_spei and
    # drought_spi, which are not created here - presumably they already
    # exist in the file; confirm.
    drought_mask = get_boolean_drought_ds(ds)

    # REMOVE THE SEA VALUES from drought mask
    drought_mask = mask_sea(drought_mask, lc_mask)
    drought_mask = make_masks_boolean(drought_mask)
    lc_mask = make_masks_boolean(lc_mask)

    # cleaning happens last, after the masks were derived from the raw dataset
    ds = clean_data(ds, lc_mask)

    return ds, lc_mask, drought_mask
152 |
153 |
def print_shift_explanation(ds):
    """Print the ``precip`` values of the first 2x2 lat/lon pixels for the
    first four timesteps, unshifted and shifted by +1 and -1 along time, to
    illustrate what the shift operator does.
    """
    corner = ds.isel(lat=slice(0, 2), lon=slice(0, 2))

    sections = [
        ("*** UNSHIFTED TIME ***", None),
        ("*** SHIFTED TIME (+1) = moving the HISTORICAL TIMESTEPS FORWARD to the PRESENT ***", 1),
        ("*** SHIFTED TIME (-1) = moving the PRESENT TIMESTEPS BACKWARD to the PAST ***", -1),
    ]

    for banner, offset in sections:
        print(banner)
        for step in range(4):
            view = corner if offset is None else corner.shift(time=offset)
            print(f"time={step}\n", view.isel(time=step).precip.values)
        print()
174 |
175 | # ----------------------------------------------------------------------
def create_lc_mask(ds):
    """Create a landcover mask from the dataset.

    Returns ``(lc_mask, mask_df, indexes_to_remove)``: the boolean mask
    (True where ``spi`` is not null at time index 1), the same mask as a
    dataframe without the ``time`` column, and the rows of that dataframe
    where the mask is False (SEA pixels).
    """
    # create a land cover mask
    spi_snapshot = ds.spi.isel(time=1)
    lc_mask = ~spi_snapshot.isnull()
    lc_mask.name = "lc_mask"

    # create df lc mask
    mask_df = lc_mask.to_dataframe().reset_index().drop(columns=["time"])

    # get df of SEA pixels (pixels to remove)
    sea_rows = mask_df.where(~mask_df.lc_mask).dropna()

    return lc_mask, mask_df, sea_rows
190 |
191 |
def drop_nans_and_flatten(dataArray):
    """Return the non-NaN values of ``dataArray`` as a flat array. Useful for
    plotting histograms.

    Arguments:
    ---------
    : dataArray (xr.DataArray)
        the DataArray of your value you want to flatten
    """
    # boolean indexing keeps the non-NaN entries and flattens in one step
    not_nan = np.logical_not(np.isnan(dataArray.values))
    return dataArray.values[not_nan]
202 |
--------------------------------------------------------------------------------
/notebooks/05_tl_explore_models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from pathlib import Path\n",
10 | "from itertools import chain\n",
11 | "import pandas as pd\n",
12 | "import numpy as np\n",
13 | "\n",
14 | "import json\n",
15 | "import sys\n",
16 | "sys.path.insert(0, '..')\n",
17 | "\n",
18 | "from predictor.models import LinearModel\n",
19 | "from predictor.preprocessing import VALUE_COLS, VEGETATION_LABELS\n",
20 | "from predictor.models import nn_FeedForward\n",
21 | "from predictor.analysis import plot_shap_values"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "with open('../data/processed/normalizing_dict.json', 'r') as f:\n",
31 | " normalizing_dict = json.load(f)\n",
32 | " \n",
33 | "path_to_arrays = Path('../data/processed/arrays')"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 9,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "data": {
43 | "text/plain": [
44 | ""
45 | ]
46 | },
47 | "execution_count": 9,
48 | "metadata": {},
49 | "output_type": "execute_result"
50 | }
51 | ],
52 | "source": [
53 | "model = nn_FeedForward(path_to_arrays, hide_vegetation=True)\n",
54 | "model"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 10,
60 | "metadata": {},
61 | "outputs": [
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "Training model without vegetation features\n"
67 | ]
68 | },
69 | {
70 | "name": "stderr",
71 | "output_type": "stream",
72 | "text": [
73 | " 0%| | 0/8175 [00:00, ?it/s]"
74 | ]
75 | },
76 | {
77 | "name": "stdout",
78 | "output_type": "stream",
79 | "text": [
80 | "After split, training on 261593 examples, validating on 29065 examples\n"
81 | ]
82 | },
83 | {
84 | "name": "stderr",
85 | "output_type": "stream",
86 | "text": [
87 | "100%|██████████| 8175/8175 [00:26<00:00, 305.35it/s]\n",
88 | "100%|██████████| 909/909 [00:01<00:00, 826.49it/s]\n",
89 | " 0%| | 27/8175 [00:00<00:30, 266.30it/s]"
90 | ]
91 | },
92 | {
93 | "name": "stdout",
94 | "output_type": "stream",
95 | "text": [
96 | "Epoch 0 - Training RMSE: 0.06997232270623566, Validation RMSE: 0.06983803885162074\n"
97 | ]
98 | },
99 | {
100 | "name": "stderr",
101 | "output_type": "stream",
102 | "text": [
103 | "100%|██████████| 8175/8175 [00:24<00:00, 331.95it/s]\n",
104 | "100%|██████████| 909/909 [00:01<00:00, 806.22it/s]\n",
105 | " 0%| | 26/8175 [00:00<00:32, 252.17it/s]"
106 | ]
107 | },
108 | {
109 | "name": "stdout",
110 | "output_type": "stream",
111 | "text": [
112 | "Epoch 1 - Training RMSE: 0.05651973409048461, Validation RMSE: 0.06605148216252542\n"
113 | ]
114 | },
115 | {
116 | "name": "stderr",
117 | "output_type": "stream",
118 | "text": [
119 | "100%|██████████| 8175/8175 [00:25<00:00, 319.95it/s]\n",
120 | "100%|██████████| 909/909 [00:01<00:00, 792.90it/s]\n",
121 | " 0%| | 28/8175 [00:00<00:29, 278.07it/s]"
122 | ]
123 | },
124 | {
125 | "name": "stdout",
126 | "output_type": "stream",
127 | "text": [
128 | "Epoch 2 - Training RMSE: 0.05542360727917346, Validation RMSE: 0.0645962596365852\n"
129 | ]
130 | },
131 | {
132 | "name": "stderr",
133 | "output_type": "stream",
134 | "text": [
135 | "100%|██████████| 8175/8175 [00:26<00:00, 304.98it/s]\n",
136 | "100%|██████████| 909/909 [00:01<00:00, 763.35it/s]\n",
137 | " 0%| | 26/8175 [00:00<00:31, 256.86it/s]"
138 | ]
139 | },
140 | {
141 | "name": "stdout",
142 | "output_type": "stream",
143 | "text": [
144 | "Epoch 3 - Training RMSE: 0.054586497671013576, Validation RMSE: 0.06387223563145752\n"
145 | ]
146 | },
147 | {
148 | "name": "stderr",
149 | "output_type": "stream",
150 | "text": [
151 | "100%|██████████| 8175/8175 [00:25<00:00, 326.63it/s]\n",
152 | "100%|██████████| 909/909 [00:01<00:00, 803.46it/s]"
153 | ]
154 | },
155 | {
156 | "name": "stdout",
157 | "output_type": "stream",
158 | "text": [
159 | "Epoch 4 - Training RMSE: 0.054150071823314425, Validation RMSE: 0.06281156092472333\n"
160 | ]
161 | },
162 | {
163 | "name": "stderr",
164 | "output_type": "stream",
165 | "text": [
166 | "\n"
167 | ]
168 | }
169 | ],
170 | "source": [
171 | "model.train(num_epochs=5)"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 11,
177 | "metadata": {},
178 | "outputs": [
179 | {
180 | "name": "stderr",
181 | "output_type": "stream",
182 | "text": [
183 | "100%|██████████| 916/916 [00:01<00:00, 619.96it/s]\n"
184 | ]
185 | },
186 | {
187 | "name": "stdout",
188 | "output_type": "stream",
189 | "text": [
190 | "Test set RMSE: 0.08204460144042969\n"
191 | ]
192 | }
193 | ],
194 | "source": [
195 | "model.evaluate()"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 12,
201 | "metadata": {},
202 | "outputs": [
203 | {
204 | "name": "stdout",
205 | "output_type": "stream",
206 | "text": [
207 | "Training model without vegetation features\n"
208 | ]
209 | }
210 | ],
211 | "source": [
212 | "background_data = model.load_tensors(mode='train')\n",
213 | "test_data = model.load_tensors(mode='test')"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 8,
219 | "metadata": {},
220 | "outputs": [
221 | {
222 | "data": {
223 | "text/plain": [
224 | "[PosixPath('../data/processed/arrays/train/x.npy'),\n",
225 | " PosixPath('../data/processed/arrays/train/y.npy'),\n",
226 | " PosixPath('../data/processed/arrays/train/latlon.npy'),\n",
227 | " PosixPath('../data/processed/arrays/train/years.npy')]"
228 | ]
229 | },
230 | "execution_count": 8,
231 | "metadata": {},
232 | "output_type": "execute_result"
233 | }
234 | ],
235 | "source": [
236 | "[arr for arr in path_to_arrays.glob('test/*')]\n",
237 | "[arr for arr in path_to_arrays.glob('train/*')]"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": 11,
243 | "metadata": {},
244 | "outputs": [
245 | {
246 | "data": {
247 | "text/plain": [
248 | "(58604, 11, 8)"
249 | ]
250 | },
251 | "execution_count": 11,
252 | "metadata": {},
253 | "output_type": "execute_result"
254 | }
255 | ],
256 | "source": [
257 | "x = np.load(path_to_arrays/'test/x.npy')\n",
258 | "x.shape"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 12,
264 | "metadata": {},
265 | "outputs": [
266 | {
267 | "data": {
268 | "text/plain": [
269 | "(58604,)"
270 | ]
271 | },
272 | "execution_count": 12,
273 | "metadata": {},
274 | "output_type": "execute_result"
275 | }
276 | ],
277 | "source": [
278 | "y = np.load(path_to_arrays/'test/y.npy')\n",
279 | "y.shape"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": 14,
285 | "metadata": {},
286 | "outputs": [
287 | {
288 | "data": {
289 | "text/plain": [
290 | "(58604,)"
291 | ]
292 | },
293 | "execution_count": 14,
294 | "metadata": {},
295 | "output_type": "execute_result"
296 | }
297 | ],
298 | "source": [
299 | "latlon = np.load(path_to_arrays/'test/latlon.npy')\n",
300 | "latlon.shape\n",
301 | "\n",
302 | "years = np.load(path_to_arrays/'test/years.npy')\n",
303 | "years.shape"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": null,
309 | "metadata": {},
310 | "outputs": [],
311 | "source": []
312 | }
313 | ],
314 | "metadata": {
315 | "kernelspec": {
316 | "display_name": "Python 3",
317 | "language": "python",
318 | "name": "python3"
319 | },
320 | "language_info": {
321 | "codemirror_mode": {
322 | "name": "ipython",
323 | "version": 3
324 | },
325 | "file_extension": ".py",
326 | "mimetype": "text/x-python",
327 | "name": "python",
328 | "nbconvert_exporter": "python",
329 | "pygments_lexer": "ipython3",
330 | "version": "3.7.2"
331 | }
332 | },
333 | "nbformat": 4,
334 | "nbformat_minor": 2
335 | }
336 |
--------------------------------------------------------------------------------
/data/drought_masking.py:
--------------------------------------------------------------------------------
1 | """
2 | @tommylees112
3 |
4 | Code for getting land surface variables for previous timesteps that are IN
5 | drought or OUT of drought.
6 |
7 | For the pixels that are in drought, return variables for that pixel for t=0,
8 | t=-1,t=-2,t=-3.
9 |
10 | This way you have the previous 3 months worth of data
11 | """
12 | import xarray as xr
13 | from joblib import Parallel, delayed
14 | import pandas as pd
15 | import os
16 | from utils import shift_by_time, read_data, print_shift_explanation
17 | from utils import save_netcdf
18 | import warnings
19 | import numpy as np
20 |
21 |
def get_monthly_location_specific_drought_mask(ds, variable):
    """ from a dataset, use a variable and calculate 1 STD below
    the mean conditioned on:
        a) LOCATION (lat/lon)
        b) TIME (month)

    Returns a boolean array that is True where ``variable`` is below
    (mean - std) of its own calendar month at that location. The result is
    sorted by ``time``, so it can be indexed positionally alongside a
    time-sorted dataset.
    """
    # get month labels
    ds = ds.assign_coords(**{'month': ds.time.dt.month})

    outs = []
    # calendar months run from 1 to 12 (there is no month 0)
    for month in range(1, 13):
        in_month = (ds.time.dt.month == month)
        if not in_month.values.any():
            # no timesteps for this month: nothing to threshold
            continue

        # select ONE MONTH for your variable of interest from the dataset
        d = ds[variable].isel(time=in_month)

        # calculate threshold from the avg/std of that month
        threshold = d.mean('time') - d.std('time')

        # create month specific mask for each timestep
        outs.append(d < threshold)

    # concatenating month by month leaves the timesteps grouped by month;
    # sort them back into chronological order
    drought_mask = xr.concat(outs, dim="time").sortby("time")
    return drought_mask
50 |
51 |
52 |
def create_drought_mask(drought_ds):
    """ at the moment just use an SPI / SPEI of < -1, and for NDVI the
    monthly, location specific threshold of one standard deviation below
    the mean.
    TODO: implement more flexible functionality

    Returns
    -------
    ``(drought_spei, drought_spi, drought_ndvi)``: three boolean arrays,
    named after the tuple entries.
    """
    # list the variable names directly instead of reducing the whole dataset
    available = list(drought_ds.data_vars)
    assert ("spi" in available) and ("spei" in available), f"spi and spei are expected to be in the drought_ds netcdf file. Currently {available}"
    assert "ndvi" in available, f"ndvi should be a variable. currently: {available}"

    # turn the values less than -1 SD into mask (SAME DIMS AS RAW DS)
    warnings.warn('NOTE: here the code could do with being more flexible in the determination of drought thresholds and then applying them.')
    warnings.warn('Use Ellens functionality already built in Akkadia')
    spei_drought = drought_ds.spei.where(drought_ds.spei < -1)
    spi_drought = drought_ds.spi.where(drought_ds.spi < -1)

    # already a boolean mask (value < monthly threshold)
    ndvi_drought = get_monthly_location_specific_drought_mask(drought_ds, 'ndvi')

    # turn into a boolean mask
    drought_spei = spei_drought.notnull().rename('drought_spei')
    drought_spi = spi_drought.notnull().rename('drought_spi')
    # notnull() on a boolean mask is True everywhere, so the NDVI mask is
    # used as it is rather than passed through notnull()
    drought_ndvi = ndvi_drought.astype(bool).rename('drought_ndvi')

    return drought_spei, drought_spi, drought_ndvi
77 |
78 |
def mask_drought_events(ds, drought_mask, ts, index, drought=True):
    """ for a given drought mask, at a given time, return the previous 3 months of data
    for pixels that were IN (or out of) drought

    Parameters
    ----------
    ds:
        The dataset to extract values from.
    drought_mask:
        Dataset holding ``drought_spi``, ``drought_spei`` and ``drought_ndvi``.
    ts: int
        Positional time index of the reference timestep (must be > 3).
    index: str
        Which drought mask to use: 'spei', 'spi' or 'ndvi'.
    drought: bool
        True keeps the pixels that are in drought at ``ts``; False keeps the
        pixels that are NOT in drought at ``ts``.

    Returns
    -------
    ``[t0, t1, t2, t3]``: the masked dataset at ``ts``, ``ts-1``, ``ts-2``
    and ``ts-3``.
    """
    assert ts > 3, "Need timestep to be greater than 3 because want the PREVIOUS 3 months. Otherwise all nans!"
    assert index in ['spei', 'spi', 'ndvi'], f"index must be either ['spei', 'spi', 'ndvi']. Currently: {index}"

    # get the drought mask at that timestep (NOTE have to invert BEFORE this)
    msk = drought_mask[f"drought_{index}"].isel(time=ts)

    if not drought:
        # pixels NOT in drought: invert the mask (both branches used to
        # apply the drought mask unchanged)
        msk = ~msk.astype(bool)

    # mask once, then pick the reference timestep and the 3 before it
    masked = ds.where(msk)
    return [masked.isel(time=ts - lag) for lag in range(4)]
109 |
110 |
def merge_all_ts_into_one_ds(ds_arr):
    """merge all of the timesteps (t=0, t-1, t-2, t-3) into one dataset object

    Every variable other than time / lat / lon is renamed to
    ``<name>_t<position>``, where position is the dataset's place in
    ``ds_arr``, so the datasets can be merged without clashing.

    input:
    : ds_arr: sequence of datasets, one per timestep
    returns:
    : ds_out (xr.Dataset): one dataset with all of the variables, carrying
      the ``time`` of the first dataset in ``ds_arr``
    """
    reference_time = ds_arr[0].time
    untouched = ('time', 'lat', 'lon')

    renamed = []
    for position, snapshot in enumerate(ds_arr):
        # suffix each variable with the position of its timestep
        mapping = {
            var: f"{var}_t{position}"
            for var in snapshot.variables.keys()
            if var not in untouched
        }
        renamed.append(snapshot.rename(mapping))

    # drop the 'time' (TO ALLOW THE MERGE)
    without_time = [snapshot.drop('time') for snapshot in renamed]

    merged = xr.merge(without_time)

    # reassign 'Time' from the first dataset
    return merged.assign_coords(**{'time': reference_time})
142 |
143 |
def calculate_drought_masked_ds(ds, drought_mask, ts):
    """Build the lagged dataset for ONE timestep `ts`, for pixels IN drought.

    Masks `ds` with the 'ndvi' drought mask at `ts` (`mask_drought_events`) and
    merges the returned t0..t3 datasets into one (`merge_all_ts_into_one_ds`).
    The index is fixed to 'ndvi' here.
    """
    print(f"Extracting Drought vars from timestep {ts}")
    lagged = mask_drought_events(ds, drought_mask, ts=ts, index='ndvi', drought=True)
    return merge_all_ts_into_one_ds(lagged)
153 |
154 |
def calculate_Ndrought_masked_ds(ds, drought_mask, ts):
    """Build the lagged dataset for ONE timestep `ts`, for pixels NOT in drought.

    Same steps as the drought version (mask with the 'ndvi' mask at `ts`, then
    merge the t0..t3 datasets). `mask_drought_events` applies the mask exactly
    as given, so `drought_mask` must ALREADY be inverted by the caller.
    """
    print(f"Extracting NON-Drought vars from timestep {ts}")
    lagged = mask_drought_events(ds, drought_mask, ts=ts, index='ndvi', drought=False)
    return merge_all_ts_into_one_ds(lagged)
164 |
165 |
def drought_across_all_timesteps(ds, drought_mask, start_ts=4, n_jobs=30):
    """Run the drought masking process for ALL TIMESTEPS.

    For every integer timestep from `start_ts` to the end of `ds.time` the
    lagged dataset is built twice: once with `drought_mask` (pixels in drought)
    and once with the inverted mask (pixels not in drought). Each set is
    concatenated along `time` and written to the working directory.

    Arguments:
    ---------
    : ds
        the data to mask (must expose `ds.time.shape[0]`)
    : drought_mask
        the masks marking drought pixels; cast to bool and inverted with `~`
        for the non-drought run
    : start_ts (int)
        first timestep to process (the per-timestep helper needs ts > 3)
    : n_jobs (int)
        number of joblib workers (previously hard-coded to 30)

    Returns:
    -------
    : (ds_drought, ds_Ndrought)

    Side effects:
    ------------
    Writes `ds_drought.nc` and `ds_Ndrought.nc` to the current directory.
    """
    timesteps = range(start_ts, ds.time.shape[0])

    # RUN the drought first
    print("Extracting the pixels in DROUGHT")
    with Parallel(n_jobs=n_jobs, verbose=True) as parallel:
        dr = parallel(
            delayed(calculate_drought_masked_ds)(ds=ds, drought_mask=drought_mask, ts=ts)
            for ts in timesteps
        )
    print("Pixels in DROUGHT extracted")

    # then run the Ndrought
    print("Extracting the pixels NOT in DROUGHT")
    # convert to boolean arrays in order to INVERT them
    Ndrought_mask = ~(drought_mask.astype(bool))
    with Parallel(n_jobs=n_jobs, verbose=True) as parallel:
        Ndr = parallel(
            delayed(calculate_Ndrought_masked_ds)(ds=ds, drought_mask=Ndrought_mask, ts=ts)
            for ts in timesteps
        )
    print("Pixels NOT in DROUGHT extracted")

    # concatenate the arrays by time into ONE dataset (agree there's lots of duplication of data here)
    # TODO: does this really make sense? we are duplicating SOOOOOOO much data and for what?
    # definitely a better way.
    ds_drought = xr.concat(dr, dim='time')
    # announce the save BEFORE doing it so the log reflects what is happening
    print("DROUGHT Variables extracted. Saving to netcdf ds_drought.nc ...")
    ds_drought.to_netcdf('ds_drought.nc')

    ds_Ndrought = xr.concat(Ndr, dim='time')
    print("NOT DROUGHT Variables extracted. Saving to netcdf ds_Ndrought.nc ...")
    ds_Ndrought.to_netcdf('ds_Ndrought.nc')

    return ds_drought, ds_Ndrought
209 |
210 |
def run_drought_processing(data_path, mask_var='spi'):
    """Read the data, extract the drought / non-drought datasets and save them.

    Arguments:
    ---------
    : data_path (str)
        passed to `read_data`
    : mask_var (str)
        passed to `read_data` as the masking variable

    Returns:
    -------
    : (ds_drought, ds_Ndrought)

    NOTE: `save_netcdf` is called without `force`, so files that already exist
    are left untouched.
    """
    print("** Running drought processing for all timesteps! **")
    # the land-cover mask returned by read_data is not used here
    ds, _lc_mask, drought_mask = read_data(data_path, mask_var)
    print(f"** Data read in - using {mask_var} as the masking variable **")

    ds_drought, ds_Ndrought = drought_across_all_timesteps(ds, drought_mask)

    for xr_obj, filename in ((ds_drought, "ds_drought.nc"), (ds_Ndrought, "ds_Ndrought.nc")):
        save_netcdf(xr_obj, filename)
    print("** Process finished **")

    return ds_drought, ds_Ndrought
225 |
226 |
227 |
228 |
229 |
def fix_drought_var():
    """Placeholder: if the drought variable is not in the dataset, append it.

    Not implemented yet - does nothing and returns None.
    """
    return None
233 |
234 |
if __name__ == "__main__":
    # TODO: set up as a CLI
    # alternative input:
    # data_path = "/soge-home/users/chri4118/EA_data/all_variables_LCMASK.nc"
    data_path = "/soge-home/users/chri4118/ea_exploration/OUT.nc"
    ds_drought, ds_Ndrought = run_drought_processing(data_path, mask_var='ndvi')
240 |
--------------------------------------------------------------------------------
/data/preprocessing/utils.py:
--------------------------------------------------------------------------------
1 | import xarray as xr
2 | import pandas as pd
3 | import xesmf as xe # for regridding
4 | import numpy as np
5 |
6 | import os
7 | from pathlib import Path
8 | import ipdb
9 | import warnings
10 | import datetime
11 |
12 | import geopandas as gpd
13 | from shapely import geometry
14 |
15 |
def read_csv_point_data(df, lat_col='lat', lon_col='lon', crs='epsg:4326'):
    """Turn a dataframe holding lat / lon columns into a GeoDataFrame of points.

    NOTE: a `geometry` column is added to `df` itself (the input is modified).

    Arguments:
    ---------
    : df (pd.DataFrame)
    : lat_col (str)
        the column in the dataframe that has the point latitude information
    : lon_col (str)
        the column in the dataframe that has the point longitude information
    : crs (str)
        coordinate reference system (defaults to 'epsg:4326')
    Returns:
    -------
    : gdf (gpd.GeoDataFrame)
        a geopandas.GeoDataFrame object
    """
    points = []
    for lat, lon in zip(df[lat_col], df[lon_col]):
        # Point takes (x, y), i.e. (longitude, latitude)
        points.append(geometry.Point(lon, lat))
    df['geometry'] = points

    return gpd.GeoDataFrame(df, crs={'init': crs}, geometry="geometry")
40 |
41 |
42 | # ------------------------------------------------------------------------------
43 | # Functions for reprojecting using GDAL and reading resulting .nc file back
44 | # ------------------------------------------------------------------------------
45 |
46 |
def gdal_reproject(infile, outfile, **kwargs):
    """Use gdalwarp to reproject one file to another (written as netCDF).

    Arguments:
    ---------
    : infile (str)
        path of the file to reproject
    : outfile (str)
        path of the netCDF file to write
    : **kwargs
        resample_method (str): the gdalwarp `-r` option. Defaults to
            "average", which is what the command has always used.
        to_proj4_string (str): the target projection (`-t_srs`). Defaults to
            WGS84 longitude/latitude.
        Any other keyword is ignored, as before.

    Raises:
    ------
    : ValueError
        if `resample_method` is not a valid gdalwarp resampling method
    : subprocess.CalledProcessError
        if gdalwarp exits with a non-zero status (previously the failure was
        ignored and "FILE REPROJECTED" was printed anyway)
    : FileNotFoundError
        if the gdalwarp executable cannot be found

    Help:
    ----
    https://www.gdal.org/gdalwarp.html
    """
    import shlex
    import subprocess

    to_proj4_string = kwargs.get(
        "to_proj4_string", "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs"
    )
    resample_method = kwargs.get("resample_method", "average")

    # check options
    valid_resample_methods = [
        "average",
        "near",
        "bilinear",
        "cubic",
        "cubicspline",
        "lanczos",
        "mode",
        "max",
        "min",
        "med",
        "q1",
        "q3",
    ]
    if resample_method not in valid_resample_methods:
        raise ValueError(
            f"Resample method not Valid. Must be one of: {valid_resample_methods} Currently: {resample_method}"
        )

    # argument list (no shell) so paths with spaces / shell characters are safe
    args = [
        "gdalwarp",
        "-t_srs", to_proj4_string,
        "-of", "netCDF",
        "-r", resample_method,
        "-dstnodata", "-9999",
        "-ot", "Float32",
        str(infile),
        str(outfile),
    ]
    # shell-quoted form, only used for the log messages
    cmd = " ".join(shlex.quote(arg) for arg in args)

    # run command
    print(f"\n\n#### Running command: {cmd} ####\n\n")
    subprocess.run(args, check=True)
    print(f"\n\n#### Run command {cmd} \n FILE REPROJECTED ####\n\n")

    return
84 |
85 |
def bands_to_time(da, times, var_name):
    """For an object with each timestep saved as a different band, create
    a time Coordinate by concatenating the bands.

    Arguments:
    ---------
    : da
        object whose variables named "...Band..." hold one timestep each
        (presumably an xr.Dataset read from a gdal netCDF - confirm)
    : times
        one entry per band; passed as `dim` to `xr.concat`
        (presumably a pandas Index / DataArray of timestamps - confirm)
    : var_name (str)
        the name given to every band before concatenating

    Returns:
    -------
    : timestamped_da
        the bands concatenated along `times`

    Raises:
    ------
    : AssertionError if the number of bands differs from the number of times
    """
    # get a list of all the bands as dataarray objects (for concatenating later)
    band_names = [key for key in da.variables.keys() if "Band" in key]
    bands = [da[key].rename(var_name) for key in band_names]

    # check the number of bands matches n timesteps
    # (the two counts were swapped in the message before)
    assert len(times) == len(
        bands
    ), f"The number of bands should match the number of timesteps. n bands: {len(bands)} n times: {len(times)}"

    # concatenate into one array
    timestamped_da = xr.concat(bands, dim=times)

    return timestamped_da
103 |
104 |
105 | # ------------------------------------------------------------------------------
106 | # Functions for matching resolutions / gridsizes (time and space)
107 | # ------------------------------------------------------------------------------
108 |
109 |
def convert_to_same_grid(reference_ds, ds, method="nearest_s2d"):
    """Use the xESMF package to regrid ds to the same grid as reference_ds.

    Arguments:
    ---------
    : reference_ds
        object with `lat` and `lon` dims defining the target grid
    : ds (xr.DataArray or xr.Dataset)
        the data to regrid; must have `lat` and `lon` dims
    : method (str)
        regridding method passed to `xe.Regridder`

    Returns:
    -------
    : the regridded DataArray / Dataset. A Dataset with more than one data
      variable is regridded variable by variable and rebuilt.

    Raises:
    ------
    : AssertionError if lat/lon dims are missing or `ds` is not an xarray
      DataArray / Dataset

    NOTE: the variable names now come from `ds.data_vars` instead of
    `ds.var().variables`, which computed the variance of the whole dataset
    only to list names. Non-numeric data variables are therefore no longer
    silently dropped.
    """
    assert ("lat" in reference_ds.dims) & (
        "lon" in reference_ds.dims
    ), f"Need (lat,lon) in reference_ds dims Currently: {reference_ds.dims}"
    assert ("lat" in ds.dims) & (
        "lon" in ds.dims
    ), f"Need (lat,lon) in ds dims Currently: {ds.dims}"

    # create the grid you want to convert TO (from reference_ds)
    ds_out = xr.Dataset(
        {"lat": (["lat"], reference_ds.lat), "lon": (["lon"], reference_ds.lon)}
    )

    # create the regridder object
    # xe.Regridder(grid_in, grid_out, method='bilinear')
    regridder = xe.Regridder(ds, ds_out, method, reuse_weights=True)

    # IF it's a dataarray just do the original transformations
    if isinstance(ds, xr.DataArray):
        ds = regridder(ds)
    # OTHERWISE loop through each of the variables, regrid the datarray then recombine into dataset
    elif isinstance(ds, xr.Dataset):
        var_names = list(ds.data_vars)
        if len(var_names) == 1:
            ds = regridder(ds)
        else:
            regridded = {}
            for name in var_names:
                print(f"- regridding var {name} -")
                regridded[name] = regridder(ds[name])
            # REBUILD
            ds = xr.Dataset(regridded)
    else:
        # raised explicitly so the check survives `python -O`
        raise AssertionError(
            "This function only works with xarray dataset / dataarray objects"
        )

    print(
        f"Regridded from {(regridder.Ny_in, regridder.Nx_in)} to {(regridder.Ny_out, regridder.Nx_out)}"
    )

    return ds
154 |
155 |
def select_same_time_slice(reference_ds, ds):
    """Select from `ds` the time window covered by `reference_ds`.

    Arguments:
    ---------
    : reference_ds
        object whose `time` coordinate defines the window
    : ds
        object to subset with `ds.sel(time=slice(...))`

    Returns:
    -------
    : ds restricted to the window

    Raises:
    ------
    : AssertionError if the result does not have as many timesteps as
      `reference_ds`

    NOTE: the lower bound of the window is HARDCODED to 2001-01-31 (a warning
    is emitted) and the frequency falls back to "M" when it cannot be inferred
    from `reference_ds`. The frequencies of the two inputs are NOT compared.
    """
    # get the frequency of the time series from reference_ds
    freq = pd.infer_freq(reference_ds.time.values)
    if freq is None:
        warnings.warn('HARDCODED FOR THIS PROBLEM BUT NO IDEA WHY NOT WORKING')
        freq = "M"
        # assert False, f"Unable to infer frequency from the reference_ds timestep"

    warnings.warn(
        "Disabled the assert statement. ENSURE FREQUENCIES THE SAME (e.g. monthly)"
    )

    # upper limit: the last timestamp of a regular series that starts at the
    # reference minimum and has as many periods as the reference window
    min_time = reference_ds.time.min().values
    max_time = reference_ds.time.max().values
    periods = len(pd.date_range(min_time, max_time, freq=freq))
    new_max = pd.date_range(min_time, freq=freq, periods=periods).max()

    # --------------------------------------------------------------------------
    # FOR SOME REASON slice is removing the minimum time ...
    # something to do with the fact that matplotlib / xarray is working oddly with numpy64datetime object
    warnings.warn("L153: HARDCODING THE MIN VALUE OTHERWISE IGNORED ...")
    min_time = datetime.datetime(2001, 1, 31)
    # --------------------------------------------------------------------------
    ds = ds.sel(time=slice(min_time, new_max))
    assert reference_ds.time.shape[0] == ds.time.shape[0],f"The time dimensions should match, currently reference_ds.time dims {reference_ds.time.shape[0]} != ds.time dims {ds.time.shape[0]}"

    print_time_min = pd.to_datetime(ds.time.min().values)
    print_time_max = pd.to_datetime(ds.time.max().values)
    # names for the log line only: a Dataset lists its data variables, anything
    # else (e.g. a DataArray) reports its name
    try:
        var_names = list(ds.data_vars)
    except AttributeError:
        var_names = ds.name
    print(
        f"Select same timeslice for ds with vars: {var_names}. Min {print_time_min} Max {print_time_max}"
    )

    return ds
205 |
206 |
def get_holaps_mask(ds):
    """Build a mask (named "holaps_mask") that is True where the FIRST timestep
    of `ds` is null, repeated for every timestep of `ds`.

    NOTE:
    - assumes that all of the null values from the HOLAPS file are valid null values (e.g. water bodies). Could also be invalid nulls due to poor data processing / lack of satellite input data for a pixel!
    - only the first timestep is inspected; a warning is emitted about this.

    Arguments:
    ---------
    : ds
        the HOLAPS data with a `time` dimension
        (presumably an xr.DataArray, as `.name` is assigned - confirm)

    Returns:
    -------
    : mask
        the first-timestep null mask stacked along `time`, carrying the `time`
        coordinate of `ds`
    """
    warnings.warn(
        "assumes that all of the null values from the HOLAPS file are valid null values (e.g. water bodies). Could also be invalid nulls due to poor data processing / lack of satellite input data for a pixel!"
    )
    warnings.warn(
        "How to collapse the time dimension in the holaps mask? Here we just select the first time because all of the valid pixels are constant for first, last second last. Need to check this is true for all timesteps"
    )

    mask = ds.isnull().isel(time=0).drop("time")
    mask.name = "holaps_mask"

    # stack one copy per timestep along a new `time` dimension. `dim` is given
    # explicitly: the old implicit default ("concat_dims") is not available in
    # current xarray, where `dim` is a required argument.
    mask = xr.concat([mask for _ in range(len(ds.time))], dim="time")
    mask["time"] = ds.time

    return mask
227 |
228 |
229 |
def select_east_africa(ds):
    """Select the East Africa bounding box from `ds` and return the selection.

    The box is longitude 32.6 to 51.8 and latitude -5.0 to 15.2. The spatial
    dims may be called (y, x), (lat, lon) or (latitude, longitude).

    NOTE: the latitude slice runs from the maximum to the minimum, which
    presumably expects latitude stored north-to-south - confirm for new inputs.

    Returns:
    -------
    : the subset of `ds` (previously the selection was computed but the
      function returned None)

    Raises:
    ------
    : AssertionError if none of the supported dim pairs is present
    """
    lonmin = 32.6
    lonmax = 51.8
    latmin = -5.0
    latmax = 15.2

    dim_pairs = (('y', 'x'), ('lat', 'lon'), ('latitude', 'longitude'))
    for lat_name, lon_name in dim_pairs:
        if (lat_name in ds.dims) and (lon_name in ds.dims):
            return ds.sel(**{
                lat_name: slice(latmax, latmin),
                lon_name: slice(lonmin, lonmax),
            })

    # raised explicitly so the check survives `python -O`
    raise AssertionError(
        "You need one of [(y, x), (lat, lon), (latitude, longitude)] in your dims"
    )
247 |
248 |
249 | # ------------------------------------------------------------------------------
250 | # Functions for working with xarray objects
251 | # ------------------------------------------------------------------------------
252 |
253 |
def merge_data_arrays(*DataArrays):
    """Merge the given DataArrays into one Dataset with `xr.merge`, printing
    their names first.
    """
    names = []
    for data_array in DataArrays:
        names.append(data_array.name)
    print(f"Merging data: {names}")

    return xr.merge(list(DataArrays))
259 |
260 |
def save_netcdf(xr_obj, filepath, force=False):
    """Write `xr_obj` to `filepath` with `xr_obj.to_netcdf`.

    An existing file is only overwritten when `force` is True; otherwise a
    message is printed and nothing is written. Always returns None.
    """
    already_exists = Path(filepath).is_file()

    if already_exists and not force:
        print(f"Filepath {filepath} already exists!")
        return

    if already_exists:
        print(f"Filepath {filepath} already exists! Overwriting...")

    xr_obj.to_netcdf(filepath)
    print(f"File Saved: {filepath}")
    return
274 |
275 |
def get_all_valid(ds, holaps_da, modis_da, gleam_da):
    """Return `ds` with values kept only where ALL three products (holaps,
    modis, gleam) are non-null; everything else is masked by `ds.where`.
    """
    valid_mask = holaps_da.notnull()
    for product in (modis_da, gleam_da):
        valid_mask = valid_mask & product.notnull()

    return ds.where(valid_mask)
286 |
287 |
def drop_nans_and_flatten(dataArray):
    """Return the non-NaN values of `dataArray` as a flat (1D) array.
    Useful for plotting histograms.

    Arguments:
    ---------
    : dataArray (xr.DataArray)
        the DataArray of your value you want to flatten
    """
    values = dataArray.values
    # boolean indexing both removes the NaNs and flattens the result
    not_nan = ~np.isnan(values)
    return values[not_nan]
298 |
299 | #
300 |
301 | #
302 | # def create_flattened_dataframe_of_values(h,g,m):
303 | # """ """
304 | # h_ = drop_nans_and_flatten(h)
305 | # g_ = drop_nans_and_flatten(g)
306 | # m_ = drop_nans_and_flatten(m)
307 | # df = pd.DataFrame(dict(
308 | # holaps=h_,
309 | # gleam=g_,
310 | # modis=m_
311 | # ))
312 | # return df
313 |
--------------------------------------------------------------------------------
/notebooks/02_gt_linear_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Linear model\n",
8 | "\n",
9 | "In this notebook, we train a linear model and investigate the coefficients it learns. These coefficients can be interpreted as the global feature importances of the input features."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from pathlib import Path\n",
19 | "from itertools import chain\n",
20 | "import pandas as pd\n",
21 | "import numpy as np\n",
22 | "\n",
23 | "import sys\n",
24 | "sys.path.insert(0, '..')\n",
25 | "\n",
26 | "from predictor.models import LinearModel\n",
27 | "from predictor.preprocessing import VALUE_COLS, VEGETATION_LABELS"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "target = 'ndvi'"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 3,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "path_to_arrays = Path(f'../data/processed/{target}/arrays')"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 4,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "model_with_veg = LinearModel(path_to_arrays)"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 5,
60 | "metadata": {},
61 | "outputs": [
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "Train set RMSE: 0.03977352798000336\n"
67 | ]
68 | }
69 | ],
70 | "source": [
71 | "model_with_veg.train()"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "We can isolate the coefficients, as well as the values they correspond to. Note that each label in `value_labels` has the following format: `{value}_{month}` where `month` is relative to the `pred_month` (so if we are predicting June, then `month=11` corresponds to data in May)."
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 6,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "coefs = model_with_veg.model.coef_"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 7,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "value_labels = list(chain(*[[f'{val}_{month}' for val in VALUE_COLS] for month in range(1, 12)]))"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 8,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "feature_importances_veg = pd.DataFrame(data={\n",
106 | " 'feature': value_labels,\n",
107 | " 'value': coefs\n",
108 | "})"
109 | ]
110 | },
111 | {
112 | "cell_type": "markdown",
113 | "metadata": {},
114 | "source": [
115 | "Let's investigate the most important features (by absolute value)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 9,
121 | "metadata": {},
122 | "outputs": [
123 | {
124 | "data": {
125 | "text/html": [
126 | "\n",
127 | "\n",
140 | "
\n",
141 | " \n",
142 | " \n",
143 | " | \n",
144 | " feature | \n",
145 | " value | \n",
146 | "
\n",
147 | " \n",
148 | " \n",
149 | " \n",
150 | " | 74 | \n",
151 | " ndvi_11 | \n",
152 | " 0.118459 | \n",
153 | "
\n",
154 | " \n",
155 | " | 4 | \n",
156 | " ndvi_1 | \n",
157 | " 0.110067 | \n",
158 | "
\n",
159 | " \n",
160 | " | 67 | \n",
161 | " ndvi_10 | \n",
162 | " -0.054666 | \n",
163 | "
\n",
164 | " \n",
165 | " | 6 | \n",
166 | " ndvi_anomaly_1 | \n",
167 | " -0.036427 | \n",
168 | "
\n",
169 | " \n",
170 | " | 11 | \n",
171 | " ndvi_2 | \n",
172 | " -0.031033 | \n",
173 | "
\n",
174 | " \n",
175 | " | 75 | \n",
176 | " evi_11 | \n",
177 | " 0.020969 | \n",
178 | "
\n",
179 | " \n",
180 | " | 39 | \n",
181 | " ndvi_6 | \n",
182 | " 0.020805 | \n",
183 | "
\n",
184 | " \n",
185 | " | 71 | \n",
186 | " lst_day_11 | \n",
187 | " -0.016052 | \n",
188 | "
\n",
189 | " \n",
190 | " | 46 | \n",
191 | " ndvi_7 | \n",
192 | " 0.015323 | \n",
193 | "
\n",
194 | " \n",
195 | " | 40 | \n",
196 | " evi_6 | \n",
197 | " -0.014433 | \n",
198 | "
\n",
199 | " \n",
200 | "
\n",
201 | "
"
202 | ],
203 | "text/plain": [
204 | " feature value\n",
205 | "74 ndvi_11 0.118459\n",
206 | "4 ndvi_1 0.110067\n",
207 | "67 ndvi_10 -0.054666\n",
208 | "6 ndvi_anomaly_1 -0.036427\n",
209 | "11 ndvi_2 -0.031033\n",
210 | "75 evi_11 0.020969\n",
211 | "39 ndvi_6 0.020805\n",
212 | "71 lst_day_11 -0.016052\n",
213 | "46 ndvi_7 0.015323\n",
214 | "40 evi_6 -0.014433"
215 | ]
216 | },
217 | "execution_count": 9,
218 | "metadata": {},
219 | "output_type": "execute_result"
220 | }
221 | ],
222 | "source": [
223 | "feature_importances_veg.iloc[(-np.abs(feature_importances_veg['value'].values)).argsort()][:10]"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {},
229 | "source": [
230 | "## Without vegetation\n",
231 | "\n",
232 | "The model above tells us that the vegetation health in May is predictive of the vegetation health in June. What happens if we hide vegetation health from the model?"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": 10,
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "model_no_veg = LinearModel(path_to_arrays, hide_vegetation=True)"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 11,
247 | "metadata": {},
248 | "outputs": [
249 | {
250 | "name": "stdout",
251 | "output_type": "stream",
252 | "text": [
253 | "Training model without vegetation features\n",
254 | "Train set RMSE: 0.07832527470825593\n"
255 | ]
256 | }
257 | ],
258 | "source": [
259 | "model_no_veg.train()"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 12,
265 | "metadata": {},
266 | "outputs": [],
267 | "source": [
268 | "coefs = model_no_veg.model.coef_\n",
269 | "\n",
270 | "veg_features = ['ndvi', 'evi']\n",
271 | "value_labels = list(chain(*[[f'{val}_{month}' for val in VALUE_COLS if val not in VEGETATION_LABELS] \n",
272 | " for month in range(1, 12)]))"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 13,
278 | "metadata": {},
279 | "outputs": [],
280 | "source": [
281 | "feature_importances_no_veg = pd.DataFrame(data={\n",
282 | " 'feature': value_labels,\n",
283 | " 'value': coefs\n",
284 | "})"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": 14,
290 | "metadata": {},
291 | "outputs": [
292 | {
293 | "data": {
294 | "text/html": [
295 | "\n",
296 | "\n",
309 | "
\n",
310 | " \n",
311 | " \n",
312 | " | \n",
313 | " feature | \n",
314 | " value | \n",
315 | "
\n",
316 | " \n",
317 | " \n",
318 | " \n",
319 | " | 41 | \n",
320 | " lst_day_11 | \n",
321 | " -0.068263 | \n",
322 | "
\n",
323 | " \n",
324 | " | 42 | \n",
325 | " precip_11 | \n",
326 | " 0.048880 | \n",
327 | "
\n",
328 | " \n",
329 | " | 0 | \n",
330 | " lst_night_1 | \n",
331 | " 0.035778 | \n",
332 | "
\n",
333 | " \n",
334 | " | 40 | \n",
335 | " lst_night_11 | \n",
336 | " -0.035483 | \n",
337 | "
\n",
338 | " \n",
339 | " | 30 | \n",
340 | " precip_8 | \n",
341 | " 0.030408 | \n",
342 | "
\n",
343 | " \n",
344 | " | 20 | \n",
345 | " lst_night_6 | \n",
346 | " 0.029277 | \n",
347 | "
\n",
348 | " \n",
349 | " | 21 | \n",
350 | " lst_day_6 | \n",
351 | " -0.026891 | \n",
352 | "
\n",
353 | " \n",
354 | " | 10 | \n",
355 | " precip_3 | \n",
356 | " 0.026077 | \n",
357 | "
\n",
358 | " \n",
359 | " | 33 | \n",
360 | " lst_day_9 | \n",
361 | " 0.025803 | \n",
362 | "
\n",
363 | " \n",
364 | " | 2 | \n",
365 | " precip_1 | \n",
366 | " 0.021489 | \n",
367 | "
\n",
368 | " \n",
369 | "
\n",
370 | "
"
371 | ],
372 | "text/plain": [
373 | " feature value\n",
374 | "41 lst_day_11 -0.068263\n",
375 | "42 precip_11 0.048880\n",
376 | "0 lst_night_1 0.035778\n",
377 | "40 lst_night_11 -0.035483\n",
378 | "30 precip_8 0.030408\n",
379 | "20 lst_night_6 0.029277\n",
380 | "21 lst_day_6 -0.026891\n",
381 | "10 precip_3 0.026077\n",
382 | "33 lst_day_9 0.025803\n",
383 | "2 precip_1 0.021489"
384 | ]
385 | },
386 | "execution_count": 14,
387 | "metadata": {},
388 | "output_type": "execute_result"
389 | }
390 | ],
391 | "source": [
392 | "feature_importances_no_veg.iloc[(-np.abs(feature_importances_no_veg['value'].values)).argsort()][:10]"
393 | ]
394 | }
395 | ],
396 | "metadata": {
397 | "kernelspec": {
398 | "display_name": "Python [conda env:vegetation_health]",
399 | "language": "python",
400 | "name": "conda-env-vegetation_health-py"
401 | },
402 | "language_info": {
403 | "codemirror_mode": {
404 | "name": "ipython",
405 | "version": 3
406 | },
407 | "file_extension": ".py",
408 | "mimetype": "text/x-python",
409 | "name": "python",
410 | "nbconvert_exporter": "python",
411 | "pygments_lexer": "ipython3",
412 | "version": "3.6.8"
413 | }
414 | },
415 | "nbformat": 4,
416 | "nbformat_minor": 2
417 | }
418 |
--------------------------------------------------------------------------------
/data/common_grid.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for putting netcdf (.nc) files onto a common grid. This means that the
3 | location, timesteps and resolution of the data is now the same. This makes
4 | working with the data a lot simpler!
5 | """
6 |
7 | import xarray as xr
8 | import xesmf as xe # for regridding
9 | import numpy as np
10 | import pickle
11 | import matplotlib.pyplot as plt
12 | import pandas as pd
13 | import tqdm # for progress-bars
14 | import click
15 | import warnings
16 |
17 | from drought_masking import create_drought_mask
18 | from utils import save_netcdf
19 |
20 |
def pickle_obj(obj, filename):
    """Write `obj` to `filename` as a pickle.

    Arguments:
    ---------
    : obj
        any picklable object
    : filename (str)
        output path; its extension must be 'pickle' or 'pkl'

    Raises:
    ------
    : AssertionError if the filename has another extension
    """
    extension = filename.split('.')[-1]
    assert extension in ("pickle", "pkl"), f"filename should end with ('pickle','pkl') currently: {filename}"

    # context manager so the handle is flushed and closed (it was left open)
    with open(filename, 'wb') as output:
        pickle.dump(obj, output)

    return
29 |
30 |
def read_pickle_df(filepath):
    """Read a pickled object from `filepath` with `pd.read_pickle` and return it.

    NOTE: unpickling can execute arbitrary code - only read trusted files.
    """
    return pd.read_pickle(filepath)
36 |
37 |
def normalise_coordinate_names(ds):
    """Rename the `latitude` / `longitude` dims to `lat` / `lon` (when present)
    and check that `time` is a dimension.

    Raises:
    ------
    : AssertionError if `time` is not in `ds.dims`
    """
    for long_name, short_name in (("latitude", "lat"), ("longitude", "lon")):
        if long_name not in ds.dims:
            continue
        ds = ds.rename({long_name: short_name})
        print(f"Renamed `{long_name}` to `{short_name}`")

    assert "time" in ds.dims, f"Must have `time` as dimension in object. Currently: {ds.dims}"

    return ds
49 |
50 |
def select_same_time_slice(reference_ds, ds):
    """Select from `ds` the time window covered by `reference_ds`.

    The window starts at the first reference timestamp and ends ONE PERIOD
    after the last reference timestamp.

    Arguments:
    ---------
    : reference_ds
        object whose `time` coordinate defines the window
    : ds
        object to subset with `ds.sel(time=slice(...))`

    Returns:
    -------
    : ds restricted to the window

    Raises:
    ------
    : AssertionError if the inferred time frequencies of the two inputs differ
    """
    # CHECK THEY ARE THE SAME FREQUENCY
    freq = pd.infer_freq(reference_ds.time.values)
    old_freq = pd.infer_freq(ds.time.values)
    assert freq == old_freq, f"The frequencies should be the same! currenlty ref: {freq} vs. old: {old_freq}"

    # get the window limits from the reference_ds
    min_time = reference_ds.time.min().values
    max_time = reference_ds.time.max().values

    # EXTEND the original time_range by 1 period for the upper limit.
    # NOTE(review): label slices in xarray include the end label, so this may
    # select one timestep more than the reference has - confirm it is intended.
    periods = len(pd.date_range(min_time, max_time, freq=freq)) + 1
    new_max = pd.date_range(min_time, freq=freq, periods=periods).max()

    # select using the NEW MAX as upper limit
    ds = ds.sel(time=slice(min_time, new_max))

    print_time_min = pd.to_datetime(ds.time.min().values)
    print_time_max = pd.to_datetime(ds.time.max().values)
    # names for the log line only: a Dataset lists its data variables, anything
    # else (e.g. a DataArray) reports its name
    try:
        var_names = list(ds.data_vars)
    except AttributeError:
        var_names = ds.name
    print(f"Select same timeslice for ds with vars: {var_names}. Min {print_time_min} Max {print_time_max}")

    return ds
84 |
85 |
def select_same_lat_lon_slice(reference_ds, ds):
    """
    Take a slice of data from the `ds` according to the bounding box from
    the reference_ds.

    NOTE: the latitude slice is tried from min() to max() first. When that
        selects nothing (e.g. latitude is stored from max to min) the bounds
        are swapped.

    Arguments:
    ---------
    : reference_ds
        object whose `lat` / `lon` min and max define the bounding box
    : ds
        object to subset with `ds.sel(lat=..., lon=...)`

    Returns:
    -------
    : ds restricted to the bounding box
    """
    lat_min, lat_max = reference_ds.lat.min(), reference_ds.lat.max()
    lon_min, lon_max = reference_ds.lon.min(), reference_ds.lon.max()

    # select once and only redo it (with swapped bounds) when nothing matched
    subset = ds.sel(lat=slice(lat_min, lat_max))
    if len(subset.lat) == 0:
        subset = ds.sel(lat=slice(lat_max, lat_min))
    ds = subset.sel(lon=slice(lon_min, lon_max))

    # names for the log line only: a Dataset lists its data variables, anything
    # else (e.g. a DataArray) reports its name
    try:
        var_names = list(ds.data_vars)
    except AttributeError:
        var_names = ds.name
    try:
        ref_var_names = list(reference_ds.data_vars)
    except AttributeError:
        ref_var_names = reference_ds.name
    print(f"Select the same bounding box for ds {var_names} from reference_ds {ref_var_names}")
    return ds
109 |
110 |
def open_drought_ds(data_path = 'spei_spi.nc'):
    """Open the SPEI / SPI file and derive one boolean drought mask per index.

    The `value` variable of the file is renamed to `spi`, and the
    latitude / longitude names are normalised to lat / lon. A pixel is
    flagged as drought where the index is below -1.

    Returns:
    -------
    : (ds, drought_spei, drought_spi)
        the raw dataset and the two boolean masks, named 'drought_spei' and
        'drought_spi'
    """
    # open the data (Data path on MONTHLY GRID)
    raw = xr.open_dataset(data_path)
    ds = normalise_coordinate_names(raw.rename({"value":"spi"}))

    # keep only the values less than -1, then turn into a boolean mask
    drought_spei = ds.spei.where(ds.spei < -1).notnull().rename('drought_spei')
    drought_spi = ds.spi.where(ds.spi < -1).notnull().rename('drought_spi')

    return ds, drought_spei, drought_spi
131 |
132 |
def ensure_same_time(reference_ds, ds):
    """Overwrite the `time` values of `ds` with a regular series.

        e.g. monthly data is sometimes stamped mid-month (01-16-98), sometimes
        at the start (01-01-98) and sometimes at the end (01-31-98).

    The new series starts at the first timestamp of reference_ds, uses the
    frequency inferred from reference_ds and has as many periods as `ds` has
    timesteps. `ds` is modified in place and also returned.
    """
    reference_times = reference_ds.time
    ds['time'] = pd.date_range(
        reference_times.min().values,
        periods=len(ds.time.values),
        freq=pd.infer_freq(reference_times.values),
    )

    return ds
144 |
145 |
def convert_to_same_time_freq(reference_ds,ds):
    """Resample `ds` to month-start ('MS') steps, taking the median of each
    month, e.g. to convert daily data to monthly.

    NOTE(review): the frequency inferred from reference_ds is ONLY used in the
    log message; the resampling itself is always 'MS'. Confirm whether the
    inferred frequency was meant to drive the resampling.

    Returns:
    -------
    : the resampled ds
    """
    freq = pd.infer_freq(reference_ds.time.values)
    ds = ds.resample(time='MS').median(dim='time')

    # names for the log line only: a Dataset lists its data variables, anything
    # else (e.g. a DataArray) reports its name
    try:
        var_names = list(ds.data_vars)
    except AttributeError:
        var_names = ds.name
    print(f"Resampled ds ({var_names}) to {freq}")
    return ds
159 |
160 |
def convert_to_same_grid(reference_ds, ds, method='bilinear'):
    """Use the xESMF package to regrid ds to the same grid as reference_ds.

    Arguments:
    ---------
    : reference_ds
        object with `lat` and `lon` dims defining the target grid
    : ds (xr.DataArray or xr.Dataset)
        the data to regrid; must have `lat` and `lon` dims
    : method (str)
        regridding method passed to `xe.Regridder`. Defaults to 'bilinear',
        the value that used to be hard-coded.

    Returns:
    -------
    : the regridded DataArray / Dataset. A Dataset with more than one data
      variable is regridded variable by variable and rebuilt.

    Raises:
    ------
    : AssertionError if lat/lon dims are missing or `ds` is not an xarray
      DataArray / Dataset

    NOTE: the variable names now come from `ds.data_vars` instead of
    `ds.var().variables`, which computed the variance of the whole dataset
    only to list names. Non-numeric data variables are therefore no longer
    silently dropped.
    """
    assert ("lat" in reference_ds.dims)&("lon" in reference_ds.dims), f"Need (lat,lon) in reference_ds dims Currently: {reference_ds.dims}"
    assert ("lat" in ds.dims)&("lon" in ds.dims), f"Need (lat,lon) in ds dims Currently: {ds.dims}"

    # create the grid you want to convert TO (from reference_ds)
    ds_out = xr.Dataset({
        'lat': (['lat'], reference_ds.lat),
        'lon': (['lon'], reference_ds.lon),
    })

    # create the regridder object
    # xe.Regridder(grid_in, grid_out, method='bilinear')
    regridder = xe.Regridder(ds, ds_out, method, reuse_weights=True)

    # IF it's a dataarray just do the original transformations
    if isinstance(ds, xr.DataArray):
        ds = regridder(ds)
    # OTHERWISE loop through each of the variables, regrid the datarray then recombine into dataset
    elif isinstance(ds, xr.Dataset):
        var_names = list(ds.data_vars)
        if len(var_names) == 1:
            ds = regridder(ds)
        else:
            regridded = {}
            for name in var_names:
                print(f"- regridding var {name} -")
                regridded[name] = regridder(ds[name])
            # REBUILD
            ds = xr.Dataset(regridded)
    else:
        # raised explicitly so the check survives `python -O`
        raise AssertionError("This function only works with xarray dataset / dataarray objects")

    print(f"Regridded from {(regridder.Ny_in, regridder.Nx_in)} to {(regridder.Ny_out, regridder.Nx_out)}")

    return ds
200 |
201 |
def netcdf_to_same_dim_shapes(reference_ds, *other_netcdfs):
    """Reshape every object in `other_netcdfs` in TIME and LAT/LON so that it
    matches `reference_ds`, allowing them all to be stored in one netcdf file.

    Each object goes through: coordinate-name normalisation, time
    resampling, time slicing, regridding and lat/lon slicing. Its shape is
    then compared with the shape of the first variable of `reference_ds`.

    Arguments:
    ---------
    : reference_ds
        the object defining the target time axis and lat/lon grid
    : other_netcdfs
        the objects to convert. The final checks read `.shape`, so each one
        is expected to expose a 3D shape (the messages assume the order
        time, lat, lon - TODO confirm against the callers)

    Returns:
    -------
    : ds_to_merge (list)
        the converted objects, in the order they were given

    Raises:
    ------
    AssertionError
        if `lat` / `lon` are missing from an object's dims, or if a shape
        does not match the reference (the time dimension may differ by 1)
    """
    ds_to_merge = []
    # The reference shape is loop invariant but needs a reduction over the
    # reference to find its first variable: compute it once, lazily, so a
    # call without any `other_netcdfs` still does no work.
    reference_shape = None

    for ds_ in tqdm.tqdm([*other_netcdfs]):
        ds_ = normalise_coordinate_names(ds_)
        assert ("lat" in ds_.dims) and ("lon" in ds_.dims), f"lat and lon should be in ds.dims. Currently: {ds_.dims}"

        # same time frequency, then the same SLICE of time
        # (reference_ds.time.min() - reference_ds.time.max())
        ds_ = convert_to_same_time_freq(reference_ds, ds_)
        ds_ = select_same_time_slice(reference_ds, ds_)

        # REGRID to the same dimensions, then select the same lat lon slice
        ds_ = convert_to_same_grid(reference_ds, ds_)
        ds_ = select_same_lat_lon_slice(reference_ds, ds_)

        # ENSURE DIMS MATCH
        if reference_shape is None:
            first_var = [i for i in reference_ds.var().variables][0]
            reference_shape = reference_ds[first_var].shape

        # The message used to be written after the call as a tuple element,
        # which silently discarded it; `err_msg` actually reports it.
        np.testing.assert_allclose(
            reference_shape[0],
            ds_.shape[0],
            atol=1,
            err_msg=f"The TIME DIMENSION should be equal. Currently (time, lat, lon) {ds_.shape} should be {reference_shape}",
        )
        assert reference_shape[1] == ds_.shape[1], f"The LAT DIMENSION should be equal. Currently (time, lat, lon) {ds_.shape} should be {reference_shape}"
        assert reference_shape[2] == ds_.shape[2], f"The LON DIMENSION should be equal. Currently (time, lat, lon) {ds_.shape} should be {reference_shape}"

        ds_to_merge.append(ds_)

    return ds_to_merge
230 |
231 |
def get_all_vars_from_ds(ds):
    """List the variable names found in an xr.Dataset.

    Arguments:
    ---------
    : ds (xr.Dataset)
        the dataset to inspect

    Returns:
    -------
    : (list)
        the keys of `ds.var().variables`
    """
    is_dataset = isinstance(ds, xr.core.dataset.Dataset)
    assert is_dataset, f"Currently only works with xr.Dataset objects. ds = {type(ds)}"
    # NOTE(review): the names come from a variance reduction of the dataset
    return list(ds.var().variables)
237 |
238 |
def append_reference_ds_vars(reference_ds, ds_to_merge):
    """Add every variable of `reference_ds` to the `ds_to_merge` list.

    Arguments:
    ---------
    : reference_ds
        the object whose variables (keys of `reference_ds.var().variables`)
        are appended
    : ds_to_merge (list)
        the list to extend; it is modified in place

    Returns:
    -------
    : ds_to_merge (list)
        the same list object, now including the reference variables
    """
    reference_names = reference_ds.var().variables
    ds_to_merge.extend(reference_ds[name] for name in reference_names)
    return ds_to_merge
245 |
246 |
def merge_netcdfs_to_one_file(reference_ds, *other_netcdfs, drought=False):
    """Merge netcdf objects with multiple variables into ONE dataset.

    The structure of `reference_ds` is used to put every object in
    `other_netcdfs` on the same TIME axis and LAT/LON grid. The reference
    variables are included in the merged output.

    Arguments:
    ---------
    : reference_ds
        the object defining the time axis / grid; its variables are merged in
    : other_netcdfs
        the objects to convert and merge
    : drought (bool)
        if True, drought masks are computed from the merged 'spi', 'spei'
        and 'ndvi' variables with `create_drought_mask` and merged in too

    Returns:
    -------
    : output_ds
        the result of `xr.merge` over all converted objects (plus the masks
        when `drought` is True)
    """
    ds_to_merge = netcdf_to_same_dim_shapes(reference_ds, *other_netcdfs)
    ds_to_merge = append_reference_ds_vars(reference_ds, ds_to_merge)

    # A first merge is required so the drought masks can be derived from the
    # combined variables.
    warnings.warn('This is done once here to put the ndvi mask into the nc file')
    output_ds = xr.merge(ds_to_merge)

    if not drought:
        # Nothing was added since the merge above: return it rather than
        # merging the same list a second time.
        return output_ds

    warnings.warn('Should be less focused on hardcoding for the drought variables. need to have a look at this')
    # TODO: reference ds might not be the drought mask. this functionality should be elsewhere
    # NOTE(review): the input is selected as (spi, spei, ndvi) but the result
    # is unpacked as (spei, spi, ndvi) - confirm the order returned by
    # create_drought_mask.
    spei_mask, spi_mask, ndvi_mask = create_drought_mask(output_ds[['spi','spei','ndvi']])
    ds_to_merge.append(spei_mask)
    ds_to_merge.append(spi_mask)
    ds_to_merge.append(ndvi_mask)

    return xr.merge(ds_to_merge)
270 |
271 |
def check_other_netcdfs(*other_netcdfs):
    """Check that every xr.DataArray passed in has a name.

    Objects that are not DataArrays are skipped.

    Arguments:
    ---------
    : other_netcdfs
        the objects to check

    Raises:
    ------
    AssertionError
        for the first DataArray whose `name` is None; the message gives its
        1-based position in the arguments
    """
    for position, xr_obj in enumerate(other_netcdfs, start=1):
        if isinstance(xr_obj, xr.core.dataarray.DataArray):
            # identity check: `None` is a singleton
            assert xr_obj.name is not None, f"All dataarrays must be named! Dataarray #{position} not named"
278 |
279 |
def read_files(files):
    """Open each path in `files` with `xr.open_dataset`.

    Arguments:
    ---------
    : files (iterable)
        the paths of the files to open

    Returns:
    -------
    : (list)
        the opened objects, in the same order as `files`
    """
    return [xr.open_dataset(path) for path in files]
288 |
289 |
290 | # @click.command()
291 | # @click.argument('reference_ds_path', type=click.Path(exists=True), default="EA_data/spei_spi.nc")
292 | # @click.option('files', '--files', envvar='FILES', multiple=True, type=click.Path())
293 | # @click.argument('output', type=click.File('wb'))
294 | # @click.option('--drought', default=False)
if __name__ == "__main__":
    # Script entry point: open the reference dataset and the individual
    # products, put them all on the reference time axis / grid, merge them
    # (including the drought masks) and save the result to `output`.
    # TODO: IMPLEMENT THIS ALL IN PARALLEL

    # the reference ds (the path is defined once and reused below)
    reference_ds_path = '/soge-home/users/chri4118/EA_data/spei_spi.nc'
    reference_ds = xr.open_dataset(reference_ds_path)
    reference_ds = reference_ds.rename({"value": 'spi'})
    reference_ds = normalise_coordinate_names(reference_ds)

    # the output filepath
    output = "OUT.nc"

    # --------------------------------------------------------------------------
    # TODO: THIS ALL NEEDS TO BE MORE DYNAMICALLY SET UP
    # variables to join
    TEMP = xr.open_dataset("/soge-home/users/chri4118/EA_data/LST_EastAfrica.nc")
    lst_day = TEMP.lst_day
    lst_night = TEMP.lst_night
    # the mean of the day and night values; named explicitly because the
    # result of the arithmetic is checked for a name below
    lst_mean = (TEMP.lst_day + TEMP.lst_night) / 2
    lst_mean.name = "lst_mean"

    ET = xr.open_dataset("/soge-home/users/chri4118/EA_data/ET_EastAfrica.nc")
    evap = ET.evaporation
    baresoil_evap = ET.baresoil_evaporation
    pet = ET.potential_evaporation
    transp = ET.transpiration
    surface_sm = ET.surface_soil_moisture
    rootzone_sm = ET.rootzone_soil_moisture

    SM = xr.open_dataset('/soge-home/users/chri4118/EA_data/SM_EastAfrica.nc')
    sm = SM.sm

    PCP = xr.open_dataset('/soge-home/projects/crop_yield/chirps/EA_precip.nc')
    precip = PCP.precip

    VEG = xr.open_dataset('/soge-home/users/chri4118/EA_data/NDVI_EastAfrica.nc')
    ndvi = VEG.ndvi
    evi = VEG.evi

    # concatenate into one list to pass to the function
    # (lst_mean used to be listed twice, which resampled and regridded the
    # same array a second time without changing the merged output)
    vars_list = [
        lst_day, lst_night, lst_mean,
        evap, baresoil_evap, pet, transp, surface_sm, rootzone_sm,
        sm, precip, ndvi, evi,
    ]

    print(f"Reading Data from: \n{TEMP}\n{ET}\n{SM}\n{PCP}\n{VEG}")
    # --------------------------------------------------------------------------

    check_other_netcdfs(*vars_list)
    output_ds = merge_netcdfs_to_one_file(reference_ds, *vars_list, drought=True)
    save_netcdf(output_ds, output)

    print("** Process Finished **")
345 |
--------------------------------------------------------------------------------
/notebooks/01_tl_data_exploration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import xarray as xr\n",
10 | "import pandas as pd\n",
11 | "import numpy as np\n",
12 | "import seaborn as sns\n",
13 | "import matplotlib.pyplot as plt"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "data": {
23 | "text/plain": [
24 | "\n",
25 | "Dimensions: (lat: 404, lon: 316, time: 85)\n",
26 | "Coordinates:\n",
27 | " * time (time) datetime64[ns] 2010-01-01 ... 2017-01-01\n",
28 | " * lon (lon) float32 32.524994 32.574997 ... 48.274994\n",
29 | " * lat (lat) float32 -4.9750023 -4.925003 ... 15.174995\n",
30 | " month (time) int64 ...\n",
31 | "Data variables:\n",
32 | " lst_day (time, lat, lon) float64 ...\n",
33 | " lst_night (time, lat, lon) float64 ...\n",
34 | " lst_mean (time, lat, lon) float64 ...\n",
35 | " evaporation (time, lat, lon) float64 ...\n",
36 | " baresoil_evaporation (time, lat, lon) float64 ...\n",
37 | " potential_evaporation (time, lat, lon) float64 ...\n",
38 | " transpiration (time, lat, lon) float64 ...\n",
39 | " surface_soil_moisture (time, lat, lon) float64 ...\n",
40 | " rootzone_soil_moisture (time, lat, lon) float64 ...\n",
41 | " sm (time, lat, lon) float64 ...\n",
42 | " precip (time, lat, lon) float64 ...\n",
43 | " ndvi (time, lat, lon) float64 ...\n",
44 | " evi (time, lat, lon) float64 ...\n",
45 | " spei (time, lat, lon) float32 ...\n",
46 | " spi (time, lat, lon) float64 ...\n",
47 | " drought_spei (time, lat, lon) bool ...\n",
48 | " drought_spi (time, lat, lon) bool ...\n",
49 | " drought_ndvi (time, lat, lon) bool ..."
50 | ]
51 | },
52 | "execution_count": 2,
53 | "metadata": {},
54 | "output_type": "execute_result"
55 | }
56 | ],
57 | "source": [
58 | "ds = xr.open_dataset('/Volumes/Lees_Extend/data/ea_data/OUT.nc')\n",
59 | "ds"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 61,
65 | "metadata": {},
66 | "outputs": [
67 | {
68 | "name": "stdout",
69 | "output_type": "stream",
70 | "text": [
71 | "Done for Varible ndvi\n",
72 | "Done for Varible spi\n",
73 | "Done for Varible precip\n",
74 | "Done for Varible sm\n",
75 | "Done for Varible lst_day\n",
76 | "Done for Varible lst_night\n",
77 | "Done for Varible lst_mean\n"
78 | ]
79 | },
80 | {
81 | "ename": "AttributeError",
82 | "evalue": "'Figure' object has no attribute 'title'",
83 | "output_type": "error",
84 | "traceback": [
85 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
86 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
87 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;31m# variables = [\"ndvi\",\"spi\"]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;31m# TODO: why doesn't it work with 2 values?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplot_variables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvariables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'../figs/variable_distribution.png'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
88 | "\u001b[0;32m\u001b[0m in \u001b[0;36mplot_variables\u001b[0;34m(Dataset, variables)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Done for Varible {var}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtitle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Distribution of variable values'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
89 | "\u001b[0;31mAttributeError\u001b[0m: 'Figure' object has no attribute 'title'"
90 | ]
91 | },
92 | {
93 | "data": {
94 | "image/png": "\n",
95 | "text/plain": [
96 | ""
97 | ]
98 | },
99 | "metadata": {
100 | "needs_background": "light"
101 | },
102 | "output_type": "display_data"
103 | }
104 | ],
105 | "source": [
106 | "def drop_nans_and_flatten(dataArray):\n",
107 | " \"\"\"flatten the array and drop nans from that array. Useful for plotting histograms.\n",
108 | " \n",
109 | " Arguments:\n",
110 | " ---------\n",
111 | " : dataArray (xr.DataArray)\n",
112 | " the DataArray of your value you want to flatten\n",
113 | " \"\"\"\n",
114 | " # drop NaNs and flatten\n",
115 | " return dataArray.values[~np.isnan(dataArray.values)]\n",
116 | "\n",
117 | "\n",
118 | "def set_no_axs_rows(variables):\n",
119 | " \"\"\"Dynamically set the number of axes rows\n",
120 | " \n",
121 | " Arguments:\n",
122 | " ---------\n",
123 | " : variables (list)\n",
124 | " the variables that we want to plot distributions of\n",
125 | " \"\"\"\n",
126 | " # Dynamically set the number of ROWS in the subplots\n",
127 | " if len(variables) % 3 > 0:\n",
128 | " n_axs = (len(variables) // 3) + 1\n",
129 | " else:\n",
130 | " n_axs = len(variables) // 3\n",
131 | " \n",
132 | " return n_axs\n",
133 | "\n",
134 | "def plot_variables(Dataset, variables):\n",
135 | " \"\"\" plot a histogram for each of the requested variables\n",
136 | " \n",
137 | " Arguments:\n",
138 | " ---------\n",
139 | " : Dataset (xr.Dataset)\n",
140 | " the dataset holding the variables of interest\n",
141 | " : variables (list)\n",
142 | " list of str for the labels of the values that you want to plot\n",
143 | " \"\"\"\n",
144 | " assert all([var_ in [var for var in ds.variables.keys()] for var_ in variables]), f\"The variables supplied must be in the xr.Dataset variables. Currently looking for {variables} in dataset variables: {[var for var in ds.variables.keys()]}\\n The variable missing is: {[var_ in [var for var in ds.variables.keys()] for var_ in variables]}\"\n",
145 | " \n",
146 | " n_axs = set_no_axs_rows(variables)\n",
147 | " fig, axs = plt.subplots(n_axs, 3, figsize=(15,8))\n",
148 | " \n",
149 | " # plot each of the variables \n",
150 | " for ix, var in enumerate(variables):\n",
151 | " # get the axes we are plotting on\n",
152 | " ax_ix = np.unravel_index(ix+1,(n_axs,3))\n",
153 | " ax = axs[ax_ix]\n",
154 | " \n",
155 | " # flatten array and drop nans from the variable\n",
156 | " flat_array = drop_nans_and_flatten(ds[var])\n",
157 | " \n",
158 | " # plot the histogram for that variable\n",
159 | " sns.distplot(flat_array, ax=ax)\n",
160 | " ax.set_title(f'{var} Histogram')\n",
161 | " print(f'Done for Varible {var}')\n",
162 | " \n",
163 | " fig.suptitle('Distribution of variable values')\n",
164 | " plt.tight_layout()\n",
165 | " \n",
166 | " return fig\n",
167 | " \n",
168 | "variables = [\"ndvi\",\"spi\",\"precip\",\"sm\",\"lst_day\",\"lst_night\",\"lst_mean\"]\n",
169 | "# variables = [\"ndvi\",\"spi\"]\n",
170 | "# TODO: why doesn't it work with 2 values?\n",
171 | "fig = plot_variables(ds, variables)\n",
172 | "fig.savefig('../figs/variable_distribution.png')"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 49,
178 | "metadata": {},
179 | "outputs": [
180 | {
181 | "data": {
182 | "text/plain": [
183 | "(0, 1)"
184 | ]
185 | },
186 | "execution_count": 49,
187 | "metadata": {},
188 | "output_type": "execute_result"
189 | }
190 | ],
191 | "source": [
192 | "assert False, \"The following does not work because the unravel index doesn't work for \"\n",
193 | "\n",
194 | "variables = [\"ndvi\",\"spi\"]\n",
195 | "plot_variables(ds, variables)"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": []
204 | }
205 | ],
206 | "metadata": {
207 | "kernelspec": {
208 | "display_name": "Python 3",
209 | "language": "python",
210 | "name": "python3"
211 | },
212 | "language_info": {
213 | "codemirror_mode": {
214 | "name": "ipython",
215 | "version": 3
216 | },
217 | "file_extension": ".py",
218 | "mimetype": "text/x-python",
219 | "name": "python",
220 | "nbconvert_exporter": "python",
221 | "pygments_lexer": "ipython3",
222 | "version": "3.7.2"
223 | }
224 | },
225 | "nbformat": 4,
226 | "nbformat_minor": 2
227 | }
228 |
--------------------------------------------------------------------------------