├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── $CACHE_FILE$ ├── misc.xml ├── modules.xml └── context-aware-representation-crop-yield-prediction.iml ├── .DS_Store ├── data_preprocessing ├── preprocess │ ├── __init__.py │ ├── cdl.py │ ├── prism.py │ ├── subset.py │ ├── landsat.py │ ├── lai.py │ ├── landcover.py │ ├── lst.py │ └── county_locations.py ├── merge │ ├── __init__.py │ └── merge_various_days.py ├── plot │ ├── __init__.py │ ├── plot_local.py │ └── counties_plot.py ├── rescaling │ ├── __init__.py │ ├── rescale_utils.py │ ├── prism_upscale.py │ ├── soil_fraction.py │ ├── soil_moisture.py │ ├── prism_downscale.py │ ├── elevation.py │ ├── us_counties.py │ ├── lst.py │ ├── nws_precip.py │ └── cdl_upscale.py ├── sample_quadruplets │ └── __init__.py ├── postprocess │ └── __init__.py ├── utils │ ├── get_lat_lon_bins.py │ ├── get_closest_date.py │ ├── __init__.py │ ├── match_lat_lon.py │ ├── timing.py │ └── generate_doy.py └── __init__.py ├── crop_yield_prediction ├── models │ ├── __init__.py │ ├── c3d │ │ ├── __init__.py │ │ └── conv3d.py │ ├── cnn_lstm │ │ ├── __init__.py │ │ └── cnn_lstm.py │ ├── deep_gaussian_process │ │ ├── __init__.py │ │ ├── loss.py │ │ ├── feature_engineering.py │ │ ├── gp.py │ │ └── rnn.py │ └── semi_transformer │ │ ├── __init__.py │ │ ├── Layers.py │ │ ├── Modules.py │ │ ├── Optim.py │ │ ├── AttentionModels.py │ │ ├── SemiTransformer.py │ │ ├── SubLayers.py │ │ └── TileNet.py ├── plot │ ├── __init__.py │ ├── plot_crop_yield.py │ ├── plot_crop_yield_prediction_error.py │ └── plot_loss.py ├── __init__.py ├── utils │ ├── logger.py │ ├── __init__.py │ ├── timing.py │ └── train_utils.py └── dataloader │ ├── __init__.py │ ├── cross_location_dataloader.py │ ├── semi_cropyield_dataloader.py │ ├── cnn_lstm_dataloader.py │ └── c3d_dataloader.py ├── LICENSE ├── CONTRIBUTING.md ├── generate_feature_importance_data.py ├── crop_yield_no_spatial.py ├── crop_yield_deep_gaussian.py ├── CODE_OF_CONDUCT.md ├── generate_for_deep_gaussian.py ├── generate_experiment_data.py ├── README.md ├── crop_yield_train_c3d.py └── crop_yield_train_cnn_lstm.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/Context-Aware-Representation-Crop-Yield-Prediction/HEAD/.DS_Store -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /.idea/$CACHE_FILE$: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .cdl import cdl_convert_to_nc 8 | 9 | __all__ = ["cdl_convert_to_nc"] 10 | -------------------------------------------------------------------------------- /data_preprocessing/merge/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .merge_various_days import merge_various_days 8 | 9 | __all__ = ['merge_various_days'] 10 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .no_spatial import predict_no_spatial 8 | 9 | __all__ = ['predict_no_spatial'] 10 | 11 | -------------------------------------------------------------------------------- /data_preprocessing/plot/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .counties_plot import counties_plot, save_colorbar 8 | 9 | __all__ = ['counties_plot', 'save_colorbar'] 10 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .rescale_utils import search_kdtree 8 | from .rescale_utils import get_lat_lon_bins 9 | 10 | __all__ = ['search_kdtree', 'get_lat_lon_bins'] 11 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/c3d/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 7 | 8 | 9 | from crop_yield_prediction.models.c3d.conv3d import C3D 10 | 11 | __all__ = ['C3D'] 12 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/cnn_lstm/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 7 | 8 | 9 | from crop_yield_prediction.models.cnn_lstm.cnn_lstm import CnnLstm 10 | 11 | __all__ = ['CnnLstm'] 12 | -------------------------------------------------------------------------------- /crop_yield_prediction/plot/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .plot_crop_yield import crop_yield_plot 8 | from .plot_crop_yield_prediction_error import crop_yield_prediction_error_plot 9 | 10 | __all__ = ['crop_yield_plot', 11 | 'crop_yield_prediction_error_plot'] 12 | -------------------------------------------------------------------------------- /.idea/context-aware-representation-crop-yield-prediction.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /data_preprocessing/sample_quadruplets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sample_for_counties import generate_training_for_counties 8 | from .sample_for_pretrained import generate_training_for_pretrained 9 | 10 | __all__ = ["generate_training_for_counties", 11 | "generate_training_for_pretrained"] 12 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/deep_gaussian_process/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .feature_engineering import get_features_for_deep_gaussian 8 | from .convnet import ConvModel 9 | from .rnn import RNNModel 10 | 11 | __all__ = ['get_features_for_deep_gaussian', 12 | 'ConvModel', 13 | 'RNNModel'] 14 | -------------------------------------------------------------------------------- /crop_yield_prediction/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | __all__ = ['CLIMATE_VARS', 'STATIC_CLIMATE_VARS', 'DYNAMIC_CLIMATE_VARS'] 8 | 9 | CLIMATE_VARS = ['ppt', 'evi', 'ndvi', 'elevation', 'lst_day', 'lst_night', 'clay', 'sand', 'silt'] 10 | STATIC_CLIMATE_VARS = ['elevation', 'clay', 'sand', 'silt'] 11 | DYNAMIC_CLIMATE_VARS = [x for x in CLIMATE_VARS if x not in STATIC_CLIMATE_VARS] 12 | -------------------------------------------------------------------------------- /data_preprocessing/postprocess/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .combine_multi_vars import mask_non_major_states 8 | from .combine_multi_vars import generate_no_spatial_for_counties 9 | from .combine_multi_vars import obtain_channel_wise_mean_std 10 | 11 | __all__ = ['mask_non_major_states', 12 | 'generate_no_spatial_for_counties', 13 | 'obtain_channel_wise_mean_std'] 14 | -------------------------------------------------------------------------------- /crop_yield_prediction/utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | 9 | 10 | class Logger(object): 11 | def __init__(self, filename="Default.log"): 12 | self.terminal = sys.stdout 13 | self.log = open(filename, "a+") 14 | 15 | def write(self, message): 16 | self.terminal.write(message) 17 | self.log.write(message) 18 | 19 | def close(self): 20 | self.log.close() 21 | 22 | def flush(self): 23 | pass 24 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/semi_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.# Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 6 | 7 | from crop_yield_prediction.models.semi_transformer.SemiTransformer import SemiTransformer 8 | from crop_yield_prediction.models.semi_transformer.TileNet import make_tilenet 9 | from crop_yield_prediction.models.semi_transformer.Optim import ScheduledOptim 10 | 11 | __all__ = ['SemiTransformer', 'ScheduledOptim', 'make_tilenet'] 12 | -------------------------------------------------------------------------------- /data_preprocessing/utils/get_lat_lon_bins.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import numpy as np 8 | 9 | 10 | def get_lat_lon_bins(lats, lons): 11 | inter_lat = np.array([(x + y) / 2.0 for x, y in zip(lats[:-1], lats[1:])]) 12 | inter_lon = np.array([(x + y) / 2.0 for x, y in zip(lons[:-1], lons[1:])]) 13 | lat_bins = np.concatenate([[2 * inter_lat[0] - inter_lat[1]], inter_lat, [2 * inter_lat[-1] - inter_lat[-2]]]) 14 | lon_bins = np.concatenate([[2 * inter_lon[0] - inter_lon[1]], inter_lon, [2 * inter_lon[-1] - inter_lon[-2]]]) 15 | 16 | return lats, lons, lat_bins, lon_bins -------------------------------------------------------------------------------- /crop_yield_prediction/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from crop_yield_prediction.dataloader.c3d_dataloader import c3d_dataloader 8 | from crop_yield_prediction.dataloader.semi_cropyield_dataloader import semi_cropyield_dataloader 9 | from crop_yield_prediction.dataloader.cnn_lstm_dataloader import cnn_lstm_dataloader 10 | from crop_yield_prediction.dataloader.cross_location_dataloader import cross_location_dataloader 11 | 12 | 13 | __all__ = ['c3d_dataloader', 14 | 'semi_cropyield_dataloader', 15 | 'cnn_lstm_dataloader', 16 | 'cross_location_dataloader'] 17 | -------------------------------------------------------------------------------- /data_preprocessing/utils/get_closest_date.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from datetime import date 9 | 10 | 11 | def get_closet_date(query_date, folder): 12 | doys = [x[:-3] for x in os.listdir(folder) if x.endswith('.nc')] 13 | doys = [date(*map(int, [x[:4], x[4:6], x[6:]])) for x in doys] 14 | query_date = date(*map(int, [query_date[:4], query_date[4:6], query_date[6:]])) 15 | 16 | return str(min(doys, key=lambda x: abs(x - query_date))).replace('-', '') 17 | 18 | 19 | if __name__ == '__main__': 20 | print(get_closet_date('20170101', ['20161230', '20170503', '20170105'])) 21 | 22 | -------------------------------------------------------------------------------- /data_preprocessing/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .get_lat_lon_bins import get_lat_lon_bins 8 | from .timing import timeit, timenow 9 | from .generate_doy import generate_doy, generate_nearest_doys, generate_most_recent_doys, generate_doy_every_n 10 | from .generate_doy import generate_future_doys 11 | from .get_closest_date import get_closet_date 12 | from .match_lat_lon import match_lat_lon 13 | 14 | __all__ = ["get_lat_lon_bins", 15 | "timeit", "timenow", 16 | "generate_doy", "generate_most_recent_doys", "generate_nearest_doys", 17 | "generate_doy_every_n", "generate_future_doys", 18 | "get_closet_date", 19 | "match_lat_lon"] 20 | -------------------------------------------------------------------------------- /data_preprocessing/utils/match_lat_lon.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | def match_lat_lon(lats_from, lons_from, lats_to, lons_to, expand=0): 9 | i_lat_start = i_lat_end = i_lon_start = i_lon_end = 0 10 | 11 | for i in range(len(lats_from)): 12 | if abs(lats_from[i] - lats_to[0]) < 0.00001: 13 | i_lat_start = i - expand 14 | if abs(lats_from[i] - lats_to[-1]) < 0.00001: 15 | i_lat_end = i + expand 16 | for i in range(len(lons_from)): 17 | if abs(lons_from[i] - lons_to[0]) < 0.00001: 18 | i_lon_start = i - expand 19 | if abs(lons_from[i] - lons_to[-1]) < 0.00001: 20 | i_lon_end = i + expand 21 | 22 | return i_lat_start, i_lat_end, i_lon_start, i_lon_end 23 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/rescale_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from scipy.spatial import cKDTree 8 | import numpy as np 9 | 10 | 11 | def search_kdtree(lats, lons, points): 12 | mytree = cKDTree(np.dstack([lats.ravel(), lons.ravel()])[0]) 13 | print('Finish building KDTree') 14 | dist, indices = mytree.query(points) 15 | return indices 16 | 17 | 18 | def get_lat_lon_bins(lats, lons): 19 | inter_lat = np.array([(x + y) / 2.0 for x, y in zip(lats[:-1], lats[1:])]) 20 | inter_lon = np.array([(x + y) / 2.0 for x, y in zip(lons[:-1], lons[1:])]) 21 | lat_bins = np.concatenate([[2 * inter_lat[0] - inter_lat[1]], inter_lat, [2 * inter_lat[-1] - inter_lat[-2]]]) 22 | lon_bins = np.concatenate([[2 * inter_lon[0] - inter_lon[1]], inter_lon, [2 * inter_lon[-1] - inter_lon[-2]]]) 23 | 24 | return lat_bins, lon_bins 25 | -------------------------------------------------------------------------------- /crop_yield_prediction/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | from .timing import timeit, timenow 8 | from .logger import Logger 9 | from .train_utils import get_statistics 10 | from .train_utils import get_latest_model_dir 11 | from .train_utils import get_latest_model 12 | from .train_utils import get_latest_models_cvs 13 | from .train_utils import plot_predict 14 | from .train_utils import plot_predict_error 15 | from .train_utils import output_to_csv_no_spatial 16 | from .train_utils import output_to_csv_complex 17 | from .train_utils import output_to_csv_simple 18 | 19 | __all__ = ['timeit', 'timenow', 20 | 'Logger', 21 | 'get_statistics', 'get_latest_model_dir', 'get_latest_model', 'get_latest_models_cvs', 22 | 'plot_predict', 'plot_predict_error', 23 | 'output_to_csv_no_spatial', 'output_to_csv_complex', 'output_to_csv_simple'] 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/semi_transformer/Layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 7 | 8 | from crop_yield_prediction.models.semi_transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward 9 | 10 | import torch.nn as nn 11 | 12 | 13 | class EncoderLayer(nn.Module): 14 | ''' Compose with two layers ''' 15 | 16 | def __init__(self, n_tsteps, query_type, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 17 | super(EncoderLayer, self).__init__() 18 | self.slf_attn = MultiHeadAttention(n_tsteps, query_type, n_head, d_model, d_k, d_v, dropout=dropout) 19 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 20 | 21 | def forward(self, enc_input, slf_attn_mask=None): 22 | enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input) 23 | enc_output = self.pos_ffn(enc_output) 24 | return enc_output, enc_slf_attn 25 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/semi_transformer/Modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class ScaledDotProductAttention(nn.Module): 14 | ''' Scaled Dot-Product Attention ''' 15 | 16 | def __init__(self, temperature, attn_dropout=0.1): 17 | super().__init__() 18 | self.temperature = temperature 19 | self.dropout = nn.Dropout(attn_dropout) 20 | 21 | def forward(self, q, k, v, gq=None): 22 | 23 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 24 | 25 | attn = self.dropout(F.softmax(attn, dim=-1)) 26 | if gq is not None: 27 | attn_gq = torch.matmul(gq / self.temperature, k.transpose(2, 3)) 28 | attn_gq = self.dropout(F.softmax(attn_gq, dim=-1)) 29 | attn += attn_gq 30 | output = torch.matmul(attn, v) 31 | 32 | return output, attn 33 | -------------------------------------------------------------------------------- /crop_yield_prediction/utils/timing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # https://stackoverflow.com/questions/1557571/how-do-i-get-time-of-a-python-programs-execution 8 | import atexit 9 | from time import time, clock 10 | from time import strftime, localtime 11 | import functools 12 | 13 | 14 | def _secondsToStr(t): 15 | return "%d:%02d:%02d.%03d" % \ 16 | functools.reduce(lambda ll,b : divmod(ll[0],b) + ll[1:], [(t*1000,),1000,60,60]) 17 | 18 | 19 | def _log(s, elapsed=None): 20 | line = "=" * 40 21 | print(line) 22 | print(s) 23 | print(strftime("%Y-%m-%d %H:%M:%S", localtime())) 24 | if elapsed: 25 | print("Elapsed time:", elapsed) 26 | print(line) 27 | print() 28 | 29 | 30 | def _endlog(start): 31 | end = time() 32 | elapsed = end-start 33 | _log("End Program", _secondsToStr(elapsed)) 34 | 35 | 36 | def timenow(): 37 | print(strftime("%Y-%m-%d %H:%M:%S", localtime()), _secondsToStr(clock())) 38 | 39 | 40 | def timeit(): 41 | start = time() 42 | atexit.register(_endlog, start) 43 | _log("Start Program") 44 | -------------------------------------------------------------------------------- /data_preprocessing/utils/timing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # https://stackoverflow.com/questions/1557571/how-do-i-get-time-of-a-python-programs-execution 8 | import atexit 9 | from time import time, clock 10 | from time import strftime, localtime 11 | import functools 12 | 13 | 14 | def _secondsToStr(t): 15 | return "%d:%02d:%02d.%03d" % \ 16 | functools.reduce(lambda ll,b : divmod(ll[0],b) + ll[1:], [(t*1000,),1000,60,60]) 17 | 18 | 19 | def _log(s, elapsed=None): 20 | line = "=" * 40 21 | print(line) 22 | print(s) 23 | print(strftime("%Y-%m-%d %H:%M:%S", localtime())) 24 | if elapsed: 25 | print("Elapsed time:", elapsed) 26 | print(line) 27 | print() 28 | 29 | 30 | def _endlog(start): 31 | end = time() 32 | elapsed = end-start 33 | _log("End Program", _secondsToStr(elapsed)) 34 | 35 | 36 | def timenow(): 37 | print(strftime("%Y-%m-%d %H:%M:%S", localtime()), _secondsToStr(clock())) 38 | 39 | 40 | def timeit(): 41 | start = time() 42 | atexit.register(_endlog, start) 43 | _log("Start Program") 44 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/deep_gaussian_process/loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Adapt code from https://github.com/gabrieltseng/pycrop-yield-prediction 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | def l1_l2_loss(pred, true, l1_weight, scores_dict): 12 | """ 13 | Regularized MSE loss; l2 loss with l1 loss too. 14 | 15 | Parameters 16 | ---------- 17 | pred: torch.floatTensor 18 | The model predictions 19 | true: torch.floatTensor 20 | The true values 21 | l1_weight: int 22 | The value by which to weight the l1 loss 23 | scores_dict: defaultdict(list) 24 | A dict to which scores can be appended. 
25 | 26 | Returns 27 | ---------- 28 | loss: the regularized mse loss 29 | """ 30 | loss = F.mse_loss(pred, true) 31 | 32 | scores_dict['l2'].append(loss.item()) 33 | 34 | if l1_weight > 0: 35 | l1 = F.l1_loss(pred, true) 36 | loss += l1 37 | scores_dict['l1'].append(l1.item()) 38 | scores_dict['loss'].append(loss.item()) 39 | 40 | return loss, scores_dict 41 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to our project 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | ### Core contributions 9 | 10 | 1. Fork the repo and create your branch from `master`. 11 | 2. If you've added code that should be tested, add tests. 12 | 3. If you've changed APIs, update the documentation. 13 | 4. Ensure the test suite passes. 14 | 5. Make sure your code lints. 15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ### Task Contributions 18 | TODO TODO TODO 19 | 20 | ## Contributor License Agreement ("CLA") 21 | In order to accept your pull request, we need you to submit a CLA. You only need 22 | to do this once to work on any of Facebook's open source projects. 23 | 24 | Complete your CLA here: 25 | 26 | ## Issues 27 | We use GitHub issues to track public bugs. Please ensure your description is 28 | clear and has sufficient instructions to be able to reproduce the issue. 29 | 30 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 31 | disclosure of security bugs. In those cases, please go through the process 32 | outlined on that page and do not file a public issue. 33 | 34 | ## License 35 | By contributing to the project, you agree that your contributions will be licensed 36 | under the LICENSE file in the root directory of this source tree. 37 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/cnn_lstm/cnn_lstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from crop_yield_prediction.models.semi_transformer.TileNet import make_tilenet 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class CnnLstm(nn.Module): 14 | ''' A CNN (TileNet) encoder applied to each time step, followed by an LSTM and a linear layer that predicts crop yield.
''' 15 | 16 | def __init__(self, tn_in_channels, tn_z_dim, d_model=512, d_inner=2048): 17 | 18 | super().__init__() 19 | 20 | self.tilenet = make_tilenet(tn_in_channels, tn_z_dim) 21 | 22 | self.encoder = nn.LSTM(d_model, d_inner, batch_first=True) 23 | 24 | self.predict_proj = nn.Linear(d_inner, 1) 25 | 26 | for p in self.parameters(): 27 | if p.dim() > 1: 28 | nn.init.xavier_uniform_(p) 29 | 30 | def forward(self, x): 31 | """ 32 | Input x: (n_batches, n_tsteps, n_vars, img_height, img_width) 33 | """ 34 | n_batches, n_tsteps, n_vars, img_size = x.shape[:-1] 35 | 36 | x = x.view(n_batches * n_tsteps, n_vars, img_size, img_size) 37 | emb_x = self.tilenet(x) 38 | emb_x = emb_x.view(n_batches, n_tsteps, -1) 39 | 40 | enc_output, *_ = self.encoder(emb_x) 41 | enc_output = enc_output[:, -1, :] 42 | 43 | pred = torch.squeeze(self.predict_proj(enc_output)) 44 | 45 | return pred 46 | -------------------------------------------------------------------------------- /generate_feature_importance_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from crop_yield_prediction import CLIMATE_VARS 8 | 9 | import os 10 | import numpy as np 11 | 12 | 13 | def generate_feature_importance_data_exclude(in_dir, out_dir, exclude_group): 14 | os.makedirs(out_dir, exist_ok=True) 15 | 16 | include_indices = [i for i, x in enumerate(CLIMATE_VARS) if x not in exclude_group] 17 | print(exclude_group, include_indices) 18 | for f in os.listdir(in_dir): 19 | if f.endswith('.npy'): 20 | in_data = np.load('{}/{}'.format(in_dir, f)) 21 | out_data = in_data[:, :, :, include_indices, :, :] 22 | # print(out_data.shape) 23 | np.save('{}/{}'.format(out_dir, f), out_data) 24 | 25 | 26 | if __name__ == '__main__': 27 | # exclude_groups = [('ppt',), ('evi', 'ndvi'), ('elevation',), ('lst_day', 'lst_night'), 28 | # ('clay', 'sand', 'silt')] 29 | exclude_groups = [('ppt', 'elevation', 'lst_day', 'lst_night', 'clay', 'sand', 'silt')] 30 | for eg in exclude_groups: 31 | generate_feature_importance_data_exclude(in_dir='data/spatial_temporal/counties/nr_25_dr100', 32 | out_dir='data/spatial_temporal/counties/nr_25_dr100_{}'.format('_'.join(eg)), 33 | exclude_group=eg) 34 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/semi_transformer/Optim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 7 | 8 | '''A wrapper class for scheduled optimizer ''' 9 | import numpy as np 10 | 11 | class ScheduledOptim(): 12 | '''A simple wrapper class for learning rate scheduling''' 13 | 14 | def __init__(self, optimizer, init_lr, d_model, n_warmup_steps): 15 | self._optimizer = optimizer 16 | self.init_lr = init_lr 17 | self.d_model = d_model 18 | self.n_warmup_steps = n_warmup_steps 19 | self.n_steps = 0 20 | 21 | 22 | def step_and_update_lr(self): 23 | "Step with the inner optimizer" 24 | self._update_learning_rate() 25 | self._optimizer.step() 26 | 27 | 28 | def zero_grad(self): 29 | "Zero out the gradients with the inner optimizer" 30 | self._optimizer.zero_grad() 31 | 32 | 33 | def _get_lr_scale(self): 34 | d_model = self.d_model 35 | n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps 36 | return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5)) 37 | 38 | 39 | def _update_learning_rate(self): 40 | ''' Learning rate scheduling per step ''' 41 | 42 | self.n_steps += 1 43 | lr = self.init_lr * self._get_lr_scale() 44 | 45 | for param_group in self._optimizer.param_groups: 46 | param_group['lr'] = lr 47 | 48 | -------------------------------------------------------------------------------- /data_preprocessing/utils/generate_doy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from datetime import date, timedelta 8 | 9 | 10 | def generate_doy(s_doy, e_doy, delimiter): 11 | s_doy = map(int, [s_doy[:4], s_doy[4:6], s_doy[6:]]) 12 | e_doy = map(int, [e_doy[:4], e_doy[4:6], e_doy[6:]]) 13 | 14 | d1 = date(*s_doy) 15 | d2 = date(*e_doy) 16 | delta = d2 - d1 17 | 18 | for i in range(delta.days + 1): 19 | yield str(d1 + timedelta(days=i)).replace("-", delimiter) 20 | 21 | 22 | def generate_doy_every_n(s_doy, e_doy, n, delimiter): 23 | s_doy = map(int, [s_doy[:4], s_doy[4:6], s_doy[6:]]) 24 | e_doy = map(int, [e_doy[:4], e_doy[4:6], e_doy[6:]]) 25 | 26 | d1 = date(*s_doy) 27 | d2 = date(*e_doy) 28 | delta = d2 - d1 29 | 30 | for i in range(0, delta.days + 1, n): 31 | yield str(d1 + timedelta(days=i)).replace("-", delimiter) 32 | 33 | 34 | def generate_nearest_doys(doy, n, delimiter): 35 | doy = map(int, [doy[:4], doy[4:6], doy[6:]]) 36 | d1 = date(*doy) 37 | 38 | for i in range((n+1)//2-n, (n+1)//2): 39 | yield str(d1 + timedelta(days=i)).replace("-", delimiter) 40 | 41 | 42 | def generate_most_recent_doys(doy, n, delimiter): 43 | doy = map(int, [doy[:4], doy[4:6], doy[6:]]) 44 | d1 = date(*doy) 45 | 46 | for i in range(-1, -n-1, -1): 47 | yield str(d1 + timedelta(days=i)).replace("-", delimiter) 48 | 49 | 50 | def generate_future_doys(doy, n, delimiter): 51 | doy = map(int, [doy[:4], doy[4:6], doy[6:]]) 52 | d1 = date(*doy) 53 | 54 | for i in range(n): 55 | yield str(d1 + timedelta(days=i)).replace("-", delimiter) 56 | -------------------------------------------------------------------------------- /crop_yield_no_spatial.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from crop_yield_prediction.utils import Logger 8 | from crop_yield_prediction.models import * 9 | 10 | import os 11 | import sys 12 | import argparse 13 | 14 | 15 | def predict_for_no_spatial(train_years): 16 | log_folder = 'results/no_spatial/prediction_logs' 17 | if not os.path.exists(log_folder): 18 | os.makedirs(log_folder) 19 | sys.stdout = Logger('{}/nt{}_all_results_online_learning.txt'.format(log_folder, train_years)) 20 | 21 | predict_no_spatial('data/no_spatial/soybeans_3_9.csv', 2014, 2018, 9, train_years, 22 | 'crop_yield_no_spatial/results/all') 23 | predict_no_spatial('data/no_spatial/soybeans_3_9.csv', 2014, 2018, 8, train_years, 24 | 'crop_yield_no_spatial/results/all') 25 | predict_no_spatial('data/no_spatial/soybeans_3_9.csv', 2014, 2018, 7, train_years, 26 | 'crop_yield_no_spatial/results/all') 27 | predict_no_spatial('data/no_spatial/soybeans_3_9.csv', 2014, 2018, 6, train_years, 28 | 'crop_yield_no_spatial/results/all') 29 | predict_no_spatial('data/no_spatial/soybeans_3_9.csv', 2014, 2018, 5, train_years, 30 | 'crop_yield_no_spatial/results/all') 31 | 32 | sys.stdout.close() 33 | sys.stdout = sys.__stdout__ 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--predict', required=True) 39 | parser.add_argument('--train-years', type=int, default=None, metavar='TRAINYEAR', required=True) 40 | 41 | args = parser.parse_args() 42 | 43 | predict = args.predict 44 | train_years = args.train_years 45 | 46 | if predict == 'no_spatial': 47 | predict_for_no_spatial(train_years) 48 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/cdl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from netCDF4 import Dataset 8 | 9 | import os 10 | from osgeo import gdal, osr 11 | import numpy as np 12 | from pyproj import Proj, transform 13 | import numpy.ma as ma 14 | 15 | # gdalwarp -t_srs '+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs' 2018_30m_cdls.img 2018_30m_cdls.tif 16 | # gdal_translate -of netCDF PRISM_ppt_stable_4kmM3_201806_bil.bil PRISM_ppt_stable_4kmM3_201806.nc 17 | 18 | 19 | def cdl_convert_to_nc(in_dir, in_file, out_dir, out_file): 20 | if not os.path.exists(out_dir): 21 | os.makedirs(out_dir) 22 | 23 | raster = gdal.Open(os.path.join(in_dir, in_file)) 24 | cdl_values = raster.ReadAsArray() 25 | geo = raster.GetGeoTransform() 26 | projWKT = raster.GetProjection() 27 | proj = osr.SpatialReference() 28 | proj.ImportFromWkt(projWKT) 29 | # n_lat, n_lon = np.shape(cdl_values) 30 | # b = raster.GetGeoTransform() 31 | # lons = (np.arange(n_lon) * b[1] + b[0]) 32 | # lats = (np.arange(n_lat) * b[5] + b[3]) 33 | # 34 | # fh_out = Dataset(os.path.join(out_dir, out_file), "w") 35 | # fh_out.createDimension("lat", len(lats)) 36 | # fh_out.createDimension("lon", len(lons)) 37 | # 38 | # outVar = fh_out.createVariable('lat', float, ('lat')) 39 | # outVar.setncatts({"units": "degree_north"}) 40 | # outVar[:] = lats[:] 41 | # outVar = fh_out.createVariable('lon', float, ('lon')) 42 | # outVar.setncatts({"units": "degree_east"}) 43 | # outVar[:] = lons[:] 44 | # 45 | # outVar = fh_out.createVariable("cdl", float, ("lat", "lon")) 46 | # outVar[:] = ma.masked_less(cdl_values, 0) 47 | # 48 | # fh_out.close() 49 | 50 | 51 | if __name__ == "__main__": 52 | cdl_convert_to_nc("raw_data/cdl/2008_30m_cdls", "2008_30m_cdls.img", 53 | "processed_data/cdl/30m/") 54 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/prism_upscale.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from netCDF4 import Dataset 9 | import numpy as np 10 | 11 | 12 | def get_origi_lat_lon(): 13 | in_dir = '../../processed_data/prism/monthly' 14 | lats, lons = None, None 15 | 16 | for f in os.listdir(in_dir): 17 | if f.endswith('.nc'): 18 | fh = Dataset(os.path.join(in_dir, f), 'r') 19 | if lats is None and lons is None: 20 | lats, lons = fh.variables['lat'][:], fh.variables['lon'][:] 21 | else: 22 | assert np.allclose(lats, fh.variables['lat'][:]) 23 | assert np.allclose(lons, fh.variables['lon'][:]) 24 | 25 | out_dir = '../../processed_data/prism/latlon' 26 | if not os.path.exists(out_dir): 27 | os.makedirs(out_dir) 28 | np.save(os.path.join(out_dir, 'lat_4km.npy'), lats.compressed()) 29 | np.save(os.path.join(out_dir, 'lon_4km.npy'), lons.compressed()) 30 | 31 | 32 | def get_lat_lon_even(n=10): 33 | """ 34 | :param n: how many pixels in lat or lon constructs one cell, e.g. 
n = 10 means the cell will be ~40 km * 40 km 35 | """ 36 | origi_lats = np.load('../../processed_data/prism/latlon/lat_4km.npy') 37 | origi_lons = np.load('../../processed_data/prism/latlon/lon_4km.npy') 38 | # print('Lengths of origi: ', len(origi_lats), len(origi_lons)) 39 | 40 | n_cell_lat = len(origi_lats)//n 41 | n_cell_lon = len(origi_lons)//n 42 | 43 | new_lats = [] 44 | new_lons = [] 45 | 46 | for i in range(n_cell_lat): 47 | i1, i2 = (n//2-1) + n * i, n//2 + n * i 48 | new_lats.append((origi_lats[i1] + origi_lats[i2])/2) 49 | 50 | for i in range(n_cell_lon): 51 | i1, i2 = (n // 2 - 1) + n * i, n // 2 + n * i 52 | new_lons.append((origi_lons[i1] + origi_lons[i2])/2) 53 | 54 | out_dir = '../../processed_data/prism/latlon' 55 | if not os.path.exists(out_dir): 56 | os.makedirs(out_dir) 57 | np.save(os.path.join(out_dir, 'lat_{}km.npy'.format(4*n)), np.asarray(new_lats)) 58 | np.save(os.path.join(out_dir, 'lon_{}km.npy').format(4*n), np.asarray(new_lons)) 59 | 60 | 61 | if __name__ == "__main__": 62 | # get_origi_lat_lon() 63 | get_lat_lon_even() 64 | -------------------------------------------------------------------------------- /crop_yield_prediction/dataloader/cross_location_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch.utils.data import Dataset, DataLoader 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class CrossLocationDataset(Dataset): 13 | """ 14 | Case 0 n_triplets_per_file == (max_index + 1): load numpy file in __init__, retrieve idx in __getitem__ 15 | Case 1 n_triplets_per_file == 1: load numpy file for idx in __getitem__ 16 | Case 2 n_triplets_per_file > 1: load numpy file that stores idx (and others) in __getitem__ 17 | idx is the index in "current" train/validation/test set. global idx is the index in the whole data set. 18 | Indices in train/validation/test set need to be sequential. 19 | """ 20 | def __init__(self, data_dir, global_index_dic, y, n_tsteps, max_index, n_triplets_per_file): 21 | self.data_dir = data_dir 22 | self.global_index_dic = global_index_dic 23 | self.n_triplets = len(global_index_dic) 24 | self.y = y 25 | self.n_tsteps = n_tsteps 26 | self.max_index = max_index 27 | if n_triplets_per_file == (max_index + 1): 28 | self.X_data = np.load('{}/0_{}.npy'.format(data_dir, max_index)) 29 | assert n_triplets_per_file == 1 30 | 31 | def __len__(self): 32 | return self.n_triplets 33 | 34 | def __getitem__(self, idx): 35 | global_idx = self.global_index_dic[idx] 36 | X_idx = np.load('{}/{}.npy'.format(self.data_dir, global_idx))[0][:self.n_tsteps] 37 | y_idx = np.array(self.y[idx]) 38 | 39 | return torch.from_numpy(X_idx).float(), torch.from_numpy(y_idx).float() 40 | 41 | 42 | def cross_location_dataloader(data_dir, global_index_dic, y, n_tsteps, max_index, n_triplets_per_file, 43 | batch_size=50, shuffle=True, num_workers=4): 44 | """ 45 | img_type: 'landsat', 'rgb', or 'naip' 46 | augment: random flip and rotate for data augmentation 47 | shuffle: turn shuffle to False for producing embeddings that correspond to original tiles. 48 | Returns a DataLoader with either NAIP (RGB/IR), RGB, or Landsat tiles. 
49 | """ 50 | 51 | dataset = CrossLocationDataset(data_dir, global_index_dic, y, n_tsteps, max_index, 52 | n_triplets_per_file=n_triplets_per_file) 53 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 54 | return dataloader 55 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/semi_transformer/AttentionModels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 7 | 8 | from crop_yield_prediction.models.semi_transformer.Layers import EncoderLayer 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | import math 14 | 15 | 16 | class PositionalEncoding(nn.Module): 17 | 18 | def __init__(self, d_hid, n_position): 19 | super(PositionalEncoding, self).__init__() 20 | 21 | # Compute the positional encodings once in log space. 22 | pe = torch.zeros(n_position, d_hid) 23 | position = torch.arange(0.0, n_position).unsqueeze(1) 24 | div_term = torch.exp(torch.arange(0.0, d_hid, 2) * -(math.log(10000.0) / d_hid)) 25 | pe[:, 0::2] = torch.sin(position * div_term) 26 | pe[:, 1::2] = torch.cos(position * div_term) 27 | pe = pe.unsqueeze(0) 28 | self.register_buffer('pe', pe) 29 | 30 | def forward(self, x): 31 | return x + Variable(self.pe[:, :x.size(1)], requires_grad=False) 32 | 33 | 34 | class Encoder(nn.Module): 35 | ''' A encoder model with self attention mechanism. ''' 36 | 37 | def __init__( 38 | self, n_tsteps, query_type, d_word_vec, d_model, d_inner, n_layers, n_head, d_k, d_v, dropout=0.1, 39 | apply_position_enc=True): 40 | 41 | super().__init__() 42 | 43 | self.apply_position_enc = apply_position_enc 44 | 45 | self.position_enc = PositionalEncoding(d_word_vec, n_position=n_tsteps) 46 | self.dropout = nn.Dropout(p=dropout) 47 | 48 | self.layer_stack = nn.ModuleList([ 49 | EncoderLayer(n_tsteps, query_type, d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 50 | for _ in range(n_layers)]) 51 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 52 | 53 | def forward(self, x, return_attns=False): 54 | enc_slf_attn_list = [] 55 | 56 | # -- Forward 57 | if self.apply_position_enc: 58 | x = self.position_enc(x) 59 | enc_output = self.dropout(x) 60 | 61 | for enc_layer in self.layer_stack: 62 | enc_output, enc_slf_attn = enc_layer(enc_output) 63 | enc_slf_attn_list += [enc_slf_attn] if return_attns else [] 64 | 65 | enc_output = self.layer_norm(enc_output) 66 | 67 | if return_attns: 68 | return enc_output, enc_slf_attn_list 69 | return enc_output, 70 | -------------------------------------------------------------------------------- /data_preprocessing/merge/merge_various_days.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ..utils import generate_doy 8 | 9 | import os 10 | import numpy as np 11 | import datetime as dt 12 | from datetime import datetime 13 | from netCDF4 import Dataset 14 | 15 | FIRST_DATE = dt.date(2001, 1, 1) 16 | 17 | 18 | def merge_various_days(in_path, out_path, fout_name, doy_start=None, doy_end=None, select_vars=None): 19 | fh_out = Dataset(os.path.join(out_path, fout_name + '.nc'), 'w') 20 | 21 | num = 0 22 | var_list = [] 23 | 24 | if doy_start is None or doy_end is None: 25 | fnames = [fname[:-3] for fname in os.listdir(in_path) if fname.endswith(".nc")] 26 | fnames = sorted(fnames, key=lambda x: datetime.strptime("".join(c for c in x if c.isdigit()), '%Y%m%d')) 27 | else: 28 | fnames = list(generate_doy(doy_start, doy_end, "")) 29 | num_files = len(fnames) 30 | print("Number of files", num_files) 31 | 32 | for nc_file in fnames: 33 | nc_doy = "".join(c for c in nc_file if c.isdigit()) 34 | fh_in = Dataset(os.path.join(in_path, nc_file + ".nc"), 'r') 35 | n_dim = {} 36 | if num == 0: 37 | for name, dim in fh_in.dimensions.items(): 38 | n_dim[name] = len(dim) 39 | fh_out.createDimension(name, len(dim) if not dim.isunlimited() else None) 40 | 41 | fh_out.createDimension('time', num_files) 42 | outVar = fh_out.createVariable('time', 'int', ("time",)) 43 | outVar[:] = range(1, num_files + 1) 44 | 45 | select_vars = list(fh_in.variables.keys()) if select_vars is None else select_vars 46 | for v_name, varin in fh_in.variables.items(): 47 | if v_name == 'lat' or v_name == 'lon': 48 | outVar = fh_out.createVariable(v_name, varin.datatype, varin.dimensions) 49 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 50 | outVar[:] = varin[:] 51 | else: 52 | if v_name in select_vars: 53 | var_list.append(v_name) 54 | outVar = fh_out.createVariable(v_name, varin.datatype, ("time", "lat", "lon",)) 55 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 56 | outVar[:] = np.empty((num_files, n_dim['lat'], n_dim['lon'])) 57 | 58 | current_date = datetime.strptime(nc_doy, "%Y%m%d").date() 59 | fh_out.variables['time'][num] = (current_date - FIRST_DATE).days 60 | for vname in var_list: 61 | var_value = fh_in.variables[vname][:] 62 | fh_out.variables[vname][num, :, :] = var_value[:] 63 | 64 | num += 1 65 | fh_in.close() 66 | fh_out.close() 67 | 68 | print(num, num_files) 69 | assert (num == num_files) 70 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/prism.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | from netCDF4 import Dataset 9 | 10 | # gdal_translate -of netCDF PRISM_ppt_stable_4kmM3_201806_bil.bil PRISM_ppt_stable_4kmM3_201806.nc 11 | 12 | 13 | def prism_convert_to_nc(): 14 | fh_out = open(os.path.join("../..", "prism_convert_to_nc.sh"), "w") 15 | fh_out.write("#!/bin/bash\n") 16 | # m_dic = {"ppt": "M3", "tdmean": "M1", "tmax": "M2", "tmean": "M2", "tmin": "M2", "vpdmax": "M1", "vpdmin": "M1"} 17 | for climate_var in ["ppt", "tdmean", "tmax", "tmean", "tmin", "vpdmax", "vpdmin"]: 18 | for year in range(1999, 2019): 19 | for month in range(1, 13): 20 | fh_out.write("gdal_translate -of netCDF raw_data/prism/monthly/PRISM_{}_stable_4kmM3_198101_201904_bil/" 21 | "PRISM_{}_stable_4kmM3_{}{}_bil.bil processed_data/prism/monthly/{}_{}{}.nc\n" 22 | .format(climate_var, climate_var, year, "{0:02}".format(month), climate_var, year, 23 | "{0:02}".format(month))) 24 | 25 | 26 | def combine_multivar(): 27 | climate_vars = ["ppt", "tdmean", "tmax", "tmean", "tmin", "vpdmax", "vpdmin"] 28 | 29 | for year in range(1999, 2019): 30 | for month in range(1, 13): 31 | fh_out = Dataset('../../processed_data/prism/combined_monthly/{}{}.nc'.format(year, 32 | '{0:02}'.format(month)), 'w') 33 | first_flag = True 34 | for v in climate_vars: 35 | fh_in = Dataset('../../processed_data/prism/monthly/{}_{}{}.nc'.format(v, year, 36 | '{0:02}'.format(month), 'r')) 37 | if first_flag: 38 | for name, dim in fh_in.dimensions.items(): 39 | fh_out.createDimension(name, len(dim)) 40 | for v_name, varin in fh_in.variables.items(): 41 | if v_name in ['lat', 'lon']: 42 | outVar = fh_out.createVariable(v_name, varin.datatype, varin.dimensions) 43 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 44 | outVar[:] = varin[:] 45 | first_flag = False 46 | 47 | for v_name, varin in fh_in.variables.items(): 48 | if v_name == 'Band1': 49 | outVar = fh_out.createVariable(v, varin.datatype, varin.dimensions) 50 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 51 | outVar[:] = varin[:] 52 | 53 | fh_in.close() 54 | 55 | fh_out.close() 56 | 57 | 58 | if __name__ == "__main__": 59 | prism_convert_to_nc() 60 | -------------------------------------------------------------------------------- /data_preprocessing/plot/plot_local.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from netCDF4 import Dataset 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | 12 | # mean of lats: 40.614586, mean of lons: -121.24792 13 | def plot_local(in_file, x_axis, y_axis): 14 | fh = Dataset(in_file, 'r') 15 | lats = fh.variables['lat'][:] 16 | lons = fh.variables['lon'][:] 17 | x_indices = [(np.abs(lons-i)).argmin() for i in x_axis] 18 | y_indices = [(np.abs(lats-i)).argmin() for i in y_axis] 19 | for v in fh.variables.keys(): 20 | if v not in ['lat', 'lon']: 21 | values = fh.variables[v][:] 22 | plt.imshow(values, interpolation='none', cmap=plt.get_cmap("jet")) 23 | plt.title(v) 24 | plt.gca().set_xticks(x_indices) 25 | plt.gca().set_yticks(y_indices) 26 | plt.gca().set_xticklabels(x_axis) 27 | plt.gca().set_yticklabels(y_axis) 28 | plt.colorbar() 29 | plt.savefig('../../processed_data/local/ca_20190604/{}.jpg'.format(v)) 30 | plt.close() 31 | 32 | 33 | def plot_landsat(in_file, x_axis, y_axis): 34 | fh = Dataset(in_file, 'r') 35 | lats = fh.variables['lat'][:][::-1] 36 | lons = fh.variables['lon'][:] 37 | x_indices = [(np.abs(lons - i)).argmin() for i in x_axis] 38 | y_indices = [(np.abs(lats - i)).argmin() for i in y_axis] 39 | titles = ["Band 1 Ultra Blue", "Band 2 Blue", "Band 3 Green", 40 | "Band 4 Red", "Band 5 Near Infrared", 41 | "Band 6 Shortwave Infrared 1", "Band 7 Shortwave Infrared 2"] 42 | for title, v in zip(titles, range(1, 8)): 43 | values = np.flipud(fh.variables['band{}'.format(v)][:]) 44 | plt.imshow(values, interpolation='none', cmap=plt.get_cmap("jet"), vmin=0, vmax=10000) 45 | plt.title(title) 46 | plt.gca().set_xticks(x_indices) 47 | plt.gca().set_yticks(y_indices) 48 | plt.gca().set_xticklabels(x_axis) 49 | plt.gca().set_yticklabels(y_axis) 50 | plt.colorbar() 51 | plt.savefig('../../processed_data/local/ca_20190604/band{}.jpg'.format(v)) 52 | plt.close() 53 | 54 | 55 | if __name__ == '__main__': 56 | y_axis = [41.20, 40.95, 40.70, 40.45, 40.20] 57 | x_axis = [-122.0, -121.75, -121.5, -121.25, -121.0, -120.75, -120.5] 58 | plot_local('../../processed_data/local/ca_20190604/elevation.nc', x_axis, y_axis) 59 | plot_local('../../processed_data/local/ca_20190604/lai.nc', x_axis, y_axis) 60 | plot_local('../../processed_data/local/ca_20190604/lst.nc', x_axis, y_axis) 61 | plot_local('../../processed_data/local/ca_20190604/nws_precip.nc', x_axis, y_axis) 62 | plot_local('../../processed_data/local/ca_20190604/soil_fraction.nc', x_axis, y_axis) 63 | plot_local('../../processed_data/local/ca_20190604/soil_moisture.nc', x_axis, y_axis) 64 | plot_landsat('../../processed_data/local/ca_20190604/landsat.nc', x_axis, y_axis) 65 | 66 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/subset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from netCDF4 import Dataset 8 | import numpy as np 9 | import sys 10 | sys.path.append("..") 11 | 12 | from data_preprocessing.utils import get_closet_date 13 | 14 | 15 | def subset(in_file, out_file, lat1, lat2, lon1, lon2): 16 | fh_in = Dataset(in_file, 'r') 17 | fh_out = Dataset(out_file, 'w') 18 | 19 | lats, lons = fh_in.variables['lat'][:], fh_in.variables['lon'][:] 20 | lat_indices = lats.size - np.searchsorted(lats[::-1], [lat1, lat2], side="right") 21 | lon_indices = np.searchsorted(lons, [lon1, lon2]) 22 | lats = lats[lat_indices[0]: lat_indices[1]] 23 | lons = lons[lon_indices[0]: lon_indices[1]] 24 | 25 | fh_out.createDimension("lat", len(lats)) 26 | fh_out.createDimension("lon", len(lons)) 27 | 28 | for v_name, varin in fh_in.variables.items(): 29 | if v_name in ["lat", "lon"]: 30 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 31 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 32 | fh_out.variables["lat"][:] = lats[:] 33 | fh_out.variables["lon"][:] = lons[:] 34 | 35 | for v_name, varin in fh_in.variables.items(): 36 | if v_name not in ["lat", "lon"]: 37 | outVar = fh_out.createVariable(v_name, varin.datatype, ('lat', 'lon')) 38 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 39 | outVar[:] = varin[lat_indices[0]: lat_indices[1], lon_indices[0]: lon_indices[1]] 40 | 41 | fh_in.close() 42 | fh_out.close() 43 | 44 | 45 | if __name__ == '__main__': 46 | # subset('../../processed_data/nws_precip/500m/20190604.nc', 47 | # '../../processed_data/local/ca_20190604/nws_precip.nc', 48 | # 41.2047, 40.0268, -122.0304, -120.4676) 49 | # subset('../../processed_data/elevation/500m.nc', 50 | # '../../processed_data/local/ca_20190604/elevation.nc', 51 | # 41.2047, 40.0268, -122.0304, -120.4676) 52 | # subset('../../processed_data/soil_fraction/soil_fraction_usa_500m.nc', 53 | # '../../processed_data/local/ca_20190604/soil_fraction.nc', 54 | # 41.2047, 40.0268, -122.0304, -120.4676) 55 | 56 | lai_date = get_closet_date('20190604', '../../processed_data/lai/500m') 57 | print(lai_date) 58 | subset('../../processed_data/lai/500m/{}.nc'.format(lai_date), 59 | '../../processed_data/local/ca_20190604/lai.nc', 60 | 41.2047, 40.0268, -122.0304, -120.4676) 61 | lst_date = get_closet_date('20190604', '../../processed_data/lst/500m') 62 | print(lst_date) 63 | subset('../../processed_data/lst/500m/{}.nc'.format(lst_date), 64 | '../../processed_data/local/ca_20190604/lst.nc', 65 | 41.2047, 40.0268, -122.0304, -120.4676) 66 | # subset('../../processed_data/soil_moisture/9km_500m/20190604.nc', 67 | # '../../processed_data/local/ca_20190604/soil_moisture.nc', 68 | # 41.2047, 40.0268, -122.0304, -120.4676) 69 | 70 | -------------------------------------------------------------------------------- /crop_yield_deep_gaussian.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
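# Usage sketch (argument values are illustrative, not prescribed by the repo):
# the deep Gaussian process baselines are launched from the command line, with
# --type choosing the CNN or RNN feature extractor and --time / --train-years
# forwarded to the corresponding training routine, e.g.
#
#     python crop_yield_deep_gaussian.py --type cnn --time 6 --train-years 10
#     python crop_yield_deep_gaussian.py --type rnn --time 6 --train-years 10
#
# Note that get_features_for_deep_gaussian() runs before argument parsing, so
# data/deep_gaussian/data.npz is regenerated on every invocation.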
6 | 7 | from crop_yield_prediction.models.deep_gaussian_process import * 8 | 9 | from pathlib import Path 10 | import torch 11 | import argparse 12 | 13 | 14 | def train_cnn_gp(times, train_years, dropout=0.5, dense_features=None, 15 | pred_years=range(2014, 2019), num_runs=2, train_steps=25000, 16 | batch_size=32, starter_learning_rate=1e-3, weight_decay=0, l1_weight=0, 17 | patience=None, use_gp=True, sigma=1, r_loc=0.5, r_year=1.5, sigma_e=0.32, sigma_b=0.01, 18 | device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')): 19 | histogram_path = Path('data/deep_gaussian/data.npz') 20 | savedir = Path('results/deep_gaussian/nt{}_tyear{}_cnn'.format(times[0], train_years)) 21 | 22 | model = ConvModel(in_channels=9, dropout=dropout, dense_features=dense_features, 23 | savedir=savedir, use_gp=use_gp, sigma=sigma, r_loc=r_loc, 24 | r_year=r_year, sigma_e=sigma_e, sigma_b=sigma_b, device=device) 25 | model.run(times, train_years, histogram_path, pred_years, num_runs, train_steps, batch_size, 26 | starter_learning_rate, weight_decay, l1_weight, patience) 27 | 28 | 29 | def train_rnn_gp(times, train_years, num_bins=32, hidden_size=128, 30 | rnn_dropout=0.75, dense_features=None, pred_years=range(2014, 2019), 31 | num_runs=2, train_steps=10000, batch_size=32, starter_learning_rate=1e-3, weight_decay=0, 32 | l1_weight=0, patience=None, use_gp=True, sigma=1, r_loc=0.5, r_year=1.5, sigma_e=0.32, sigma_b=0.01, 33 | device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')): 34 | histogram_path = Path('data/deep_gaussian/data.npz') 35 | savedir = Path('results/deep_gaussian/nt{}_tyear{}_rnn'.format(times[0], train_years)) 36 | 37 | model = RNNModel(in_channels=9, num_bins=num_bins, hidden_size=hidden_size, 38 | rnn_dropout=rnn_dropout, dense_features=dense_features, 39 | savedir=savedir, use_gp=use_gp, sigma=sigma, r_loc=r_loc, r_year=r_year, 40 | sigma_e=sigma_e, sigma_b=sigma_b, device=device) 41 | model.run(times, train_years, histogram_path, pred_years, num_runs, train_steps, batch_size, 42 | starter_learning_rate, weight_decay, l1_weight, patience) 43 | 44 | 45 | if __name__ == '__main__': 46 | get_features_for_deep_gaussian() 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--type', type=str) 50 | parser.add_argument('--time', type=int, default=None, metavar='TIME', required=True) 51 | parser.add_argument('--train-years', type=int, default=None, metavar='TRAINYEAR', required=True) 52 | 53 | args = parser.parse_args() 54 | model_type = args.type 55 | times = [args.time] 56 | train_years = args.train_years 57 | 58 | if model_type == 'cnn': 59 | train_cnn_gp(times, train_years) 60 | elif model_type == 'rnn': 61 | train_rnn_gp(times, train_years) 62 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/soil_fraction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from netCDF4 import Dataset 9 | import numpy as np 10 | import numpy.ma as ma 11 | import pandas as pd 12 | import csv 13 | 14 | import sys 15 | sys.path.append("..") 16 | 17 | from data_preprocessing.rescaling.rescale_utils import search_kdtree 18 | 19 | 20 | def reproject_lat_lon(): 21 | fh_sf = Dataset('../../raw_data/soil_fraction/soil_fraction_usa.nc', 'r') 22 | lats, lons = fh_sf.variables['lat'][:], fh_sf.variables['lon'][:] 23 | lons, lats = np.meshgrid(lons, lats) 24 | 25 | fh_ref = Dataset('../../processed_data/lst/monthly_1km/201701.nc', 'r') 26 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 27 | 28 | xv, yv = np.meshgrid(ref_lons, ref_lats) 29 | points = np.dstack([yv.ravel(), xv.ravel()])[0] 30 | print('Finish building points') 31 | results = search_kdtree(lats, lons, points) 32 | np.save('../../raw_data/soil_fraction/projected_indices_lst_1km.npy', results) 33 | 34 | 35 | def reproject_sf(): 36 | fh_ref = Dataset('../../processed_data/lst/monthly_1km/201701.nc', 'r') 37 | fh_in = Dataset('../../raw_data/soil_fraction/soil_fraction_usa.nc', 'r') 38 | fh_out = Dataset('../../processed_data/soil_fraction/soil_fraction_usa_1km.nc', 'w') 39 | 40 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 41 | n_lat, n_lon = len(ref_lats), len(ref_lons) 42 | for name, dim in fh_ref.dimensions.items(): 43 | fh_out.createDimension(name, len(dim)) 44 | 45 | for v_name, varin in fh_ref.variables.items(): 46 | if v_name in ['lat', 'lon']: 47 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 48 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 49 | outVar[:] = varin[:] 50 | 51 | origi_values = {} 52 | projected_values = {} 53 | for v_name, varin in fh_in.variables.items(): 54 | if v_name not in ['lat', 'lon']: 55 | outVar = fh_out.createVariable(v_name, 'f4', ('lat', 'lon')) 56 | outVar.setncatts({'_FillValue': np.array([-9999.9]).astype('f')}) 57 | origi_values[v_name] = varin[:] 58 | projected_values[v_name] = np.full((n_lat, n_lon), -9999.9) 59 | 60 | projected_indices = np.load('../../raw_data/soil_fraction/projected_indices_lst_1km.npy') 61 | projected_i = 0 62 | for i in range(n_lat): 63 | for j in range(n_lon): 64 | for key in origi_values.keys(): 65 | proj_i, proj_j = projected_indices[projected_i] // 8724, projected_indices[projected_i] % 8724 66 | if not origi_values[key].mask[proj_i, proj_j]: 67 | projected_values[key][i, j] = origi_values[key][proj_i, proj_j] 68 | projected_i += 1 69 | 70 | for key in origi_values.keys(): 71 | fh_out.variables[key][:] = ma.masked_equal(projected_values[key], -9999.9) 72 | 73 | fh_in.close() 74 | fh_ref.close() 75 | fh_out.close() 76 | 77 | 78 | if __name__ == '__main__': 79 | # reproject_lat_lon() 80 | reproject_sf() 81 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/semi_transformer/SemiTransformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
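# Shape sketch (hyper-parameters below are placeholders, not values taken from
# the training scripts; the query_type string must be one of the options
# understood by AttentionModels.Encoder): the model consumes a 6-D tensor of
# shape (n_batches, n_tsteps, n_triplets, n_vars, img_size, img_size) and
# returns the triplet embeddings (None when unsup_weight == 0) together with
# one yield prediction per batch element.
#
#     model = SemiTransformer(
#         tn_in_channels=9, tn_z_dim=256, tn_warm_start_model=None,
#         sentence_embedding='simple_average', output_pred=True,
#         query_type=..., attn_n_tsteps=7,
#         d_word_vec=256, d_model=256, d_inner=1024, n_layers=2)
#     x = torch.randn(2, 7, 4, 9, 50, 50)     # 2 counties, 7 months, 4 tiles, 9 vars
#     emb_triplets, pred = model(x, unsup_weight=0.5)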
6 | # Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 7 | 8 | from crop_yield_prediction.models.semi_transformer.AttentionModels import Encoder 9 | from crop_yield_prediction.models.semi_transformer.TileNet import make_tilenet 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class SemiTransformer(nn.Module): 16 | ''' A sequence to sequence model with attention mechanism. ''' 17 | 18 | def __init__( 19 | self, tn_in_channels, tn_z_dim, tn_warm_start_model, 20 | sentence_embedding, output_pred, query_type, 21 | attn_n_tsteps, d_word_vec=512, d_model=512, d_inner=2048, 22 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, apply_position_enc=True): 23 | 24 | super().__init__() 25 | 26 | assert d_model == d_word_vec, \ 27 | 'To facilitate the residual connections, \ 28 | the dimensions of all module outputs shall be the same.' 29 | 30 | self.output_pred = output_pred 31 | 32 | self.tilenet = make_tilenet(tn_in_channels, tn_z_dim) 33 | 34 | self.encoder = Encoder( 35 | attn_n_tsteps, query_type=query_type, d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 36 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, dropout=dropout, 37 | apply_position_enc=apply_position_enc) 38 | 39 | self.sentence_embedding = sentence_embedding 40 | 41 | if self.output_pred: 42 | self.predict_proj = nn.Linear(d_model, 1) 43 | 44 | for p in self.parameters(): 45 | if p.dim() > 1: 46 | nn.init.xavier_uniform_(p) 47 | 48 | if tn_warm_start_model is not None: 49 | warm_start = torch.load(tn_warm_start_model) 50 | self.tilenet.load_state_dict(warm_start['model_state_dict']) 51 | 52 | def forward(self, x, unsup_weight): 53 | """ 54 | Input x: (n_batches, n_tsteps, n_triplets, n_var, img_height, img_width) 55 | """ 56 | n_batches, n_tsteps, n_triplets, n_vars, img_size = x.shape[:-1] 57 | 58 | emb_triplets = None 59 | if unsup_weight != 0: 60 | x = x.view(n_batches * n_tsteps * n_triplets, n_vars, img_size, img_size) 61 | 62 | emb_triplets = self.tilenet(x) 63 | 64 | emb_triplets = emb_triplets.view(n_batches, n_tsteps, n_triplets, -1) 65 | emb_x = emb_triplets[:, :, 0, :] 66 | # emb_triplets = emb_triplets.view(n_batches * n_tsteps, n_triplets, -1) 67 | else: 68 | x = x[:, :, 0, :, :, :] 69 | x = x.view(n_batches * n_tsteps, n_vars, img_size, img_size) 70 | emb_x = self.tilenet(x) 71 | emb_x = emb_x.view(n_batches, n_tsteps, -1) 72 | 73 | enc_output, *_ = self.encoder(emb_x) 74 | 75 | if self.sentence_embedding == 'simple_average': 76 | enc_output = enc_output.mean(1) 77 | 78 | pred = torch.squeeze(self.predict_proj(enc_output)) 79 | 80 | return emb_triplets, pred 81 | -------------------------------------------------------------------------------- /crop_yield_prediction/dataloader/semi_cropyield_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
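# Worked example of the chunked-file case below (numbers are hypothetical):
# with n_triplets_per_file=40, max_index=999 and start_index=100, dataset
# index idx=23 maps to global index 123 and is served from the fourth chunk:
#
#     file_idx  = 123 // 40                     # -> 3
#     local_idx = 123 % 40                      # -> 3
#     end_idx   = min((3 + 1) * 40 - 1, 999)    # -> 159
#     # __getitem__ therefore loads '<data_dir>/120_159.npy' and takes entry 3
#
# Case 0 (n_triplets_per_file == max_index + 1) instead loads the single
# '0_<max_index>.npy' once in __init__ and indexes it in memory.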
6 | 7 | from torch.utils.data import Dataset, DataLoader 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class SemiCropYieldDataset(Dataset): 13 | """ 14 | Case 0 n_triplets_per_file == (max_index + 1): load numpy file in __init__, retrieve idx in __getitem__ 15 | Case 1 n_triplets_per_file == 1: load numpy file for idx in __getitem__ 16 | Case 2 n_triplets_per_file > 1: load numpy file that stores idx (and others) in __getitem__ 17 | idx is the index in "current" train/validation/test set. global idx is the index in the whole data set. 18 | Indices in train/validation/test set need to be sequential. 19 | """ 20 | def __init__(self, data_dir, start_index, end_index, y, n_tsteps, max_index, n_triplets_per_file): 21 | self.data_dir = data_dir 22 | self.start_index = start_index 23 | self.end_index = end_index 24 | self.n_triplets = end_index - start_index + 1 25 | self.n_triplets_per_file = n_triplets_per_file 26 | self.y = y 27 | self.n_tsteps = n_tsteps 28 | self.max_index = max_index 29 | if n_triplets_per_file == (max_index + 1): 30 | self.X_data = np.load('{}/0_{}.npy'.format(data_dir, max_index)) 31 | 32 | def __len__(self): 33 | return self.n_triplets 34 | 35 | def __getitem__(self, idx): 36 | global_idx = idx + self.start_index 37 | 38 | if self.n_triplets_per_file == (self.max_index + 1): 39 | X_idx = self.X_data[global_idx][:self.n_tsteps] 40 | else: 41 | if self.n_triplets_per_file > 1: 42 | file_idx = global_idx // self.n_triplets_per_file 43 | local_idx = global_idx % self.n_triplets_per_file 44 | 45 | end_idx = min((file_idx+1)*self.n_triplets_per_file-1, self.max_index) 46 | X_idx = np.load('{}/{}_{}.npy'.format(self.data_dir, 47 | file_idx * self.n_triplets_per_file, 48 | end_idx))[local_idx][:self.n_tsteps] 49 | else: 50 | X_idx = np.load('{}/{}.npy'.format(self.data_dir, global_idx))[0][:self.n_tsteps] 51 | y_idx = np.array(self.y[idx]) 52 | 53 | return torch.from_numpy(X_idx).float(), torch.from_numpy(y_idx).float() 54 | 55 | 56 | def semi_cropyield_dataloader(data_dir, start_index, end_index, y, n_tsteps, max_index, n_triplets_per_file, 57 | batch_size=50, shuffle=True, num_workers=4): 58 | """ 59 | img_type: 'landsat', 'rgb', or 'naip' 60 | augment: random flip and rotate for data augmentation 61 | shuffle: turn shuffle to False for producing embeddings that correspond to original tiles. 62 | Returns a DataLoader with either NAIP (RGB/IR), RGB, or Landsat tiles. 63 | """ 64 | 65 | dataset = SemiCropYieldDataset(data_dir, start_index, end_index, y, n_tsteps, max_index, 66 | n_triplets_per_file=n_triplets_per_file) 67 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 68 | return dataloader 69 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/soil_moisture.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from netCDF4 import Dataset 9 | import numpy as np 10 | import numpy.ma as ma 11 | 12 | import sys 13 | sys.path.append("..") 14 | 15 | from data_preprocessing.utils import generate_doy 16 | from data_preprocessing.rescaling.rescale_utils import search_kdtree 17 | 18 | 19 | def reproject_lat_lon(): 20 | fh_sm = Dataset('../../raw_data/soil_moisture/9km/20170101.nc', 'r') 21 | lats, lons = fh_sm.variables['lat'][:], fh_sm.variables['lon'][:] 22 | lons, lats = np.meshgrid(lons, lats) 23 | 24 | fh_ref = Dataset('../../processed_data/lai/500m/20181028.nc', 'r') 25 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 26 | 27 | xv, yv = np.meshgrid(ref_lons, ref_lats) 28 | points = np.dstack([yv.ravel(), xv.ravel()])[0] 29 | print('Finish building points') 30 | results = search_kdtree(lats, lons, points) 31 | np.save('../../raw_data/soil_moisture/projected_indices_lai_500m.npy', results) 32 | 33 | 34 | def reproject_sm(doy): 35 | fh_ref = Dataset('../../processed_data/lai/500m/20181028.nc', 'r') 36 | fh_in = Dataset('../../raw_data/soil_moisture/9km/{}.nc'.format(doy), 'r') 37 | fh_out = Dataset('../../processed_data/soil_moisture/9km_500m/{}.nc'.format(doy), 'w') 38 | 39 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 40 | n_lat, n_lon = len(ref_lats), len(ref_lons) 41 | for name, dim in fh_ref.dimensions.items(): 42 | fh_out.createDimension(name, len(dim)) 43 | 44 | for v_name, varin in fh_ref.variables.items(): 45 | if v_name in ['lat', 'lon']: 46 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 47 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 48 | outVar[:] = varin[:] 49 | 50 | origi_values = {} 51 | projected_values = {} 52 | for v_name, varin in fh_in.variables.items(): 53 | if v_name in ['soil_moisture']: 54 | outVar = fh_out.createVariable(v_name, varin.datatype, ('lat', 'lon')) 55 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 56 | origi_values[v_name] = varin[:] 57 | projected_values[v_name] = np.full((n_lat, n_lon), -9999.9) 58 | 59 | projected_indices = np.load('../../raw_data/soil_moisture/projected_indices_lai_500m.npy') 60 | projected_i = 0 61 | 62 | for i in range(n_lat): 63 | for j in range(n_lon): 64 | for key in origi_values.keys(): 65 | proj_i, proj_j = projected_indices[projected_i] // 674, projected_indices[projected_i] % 674 66 | if not origi_values[key].mask[proj_i, proj_j]: 67 | projected_values[key][i, j] = origi_values[key][proj_i, proj_j] 68 | projected_i += 1 69 | 70 | for key in origi_values.keys(): 71 | fh_out.variables[key][:] = ma.masked_equal(projected_values[key], -9999.9) 72 | 73 | fh_in.close() 74 | fh_ref.close() 75 | fh_out.close() 76 | 77 | 78 | if __name__ == '__main__': 79 | # reproject_lat_lon() 80 | for doy in generate_doy('20181002', '20181231', ''): 81 | reproject_sm(doy) 82 | -------------------------------------------------------------------------------- /crop_yield_prediction/dataloader/cnn_lstm_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
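# Note on the slicing below: this dataset mirrors SemiCropYieldDataset but
# keeps only the anchor tile of each quadruplet (X_idx[:, 0, ...]), so every
# sample is a plain image time series for the CNN + LSTM baseline. With the
# sizes assumed elsewhere in the repo (7 monthly steps, 9 variables, 50x50
# tiles):
#
#     x, y = dataset[0]
#     # x.shape == torch.Size([7, 9, 50, 50]); y is a scalar yield tensor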
6 | 7 | from torch.utils.data import Dataset, DataLoader 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class CnnLSTMDataset(Dataset): 13 | """ 14 | Case 0 n_triplets_per_file == (max_index + 1): load numpy file in __init__, retrieve idx in __getitem__ 15 | Case 1 n_triplets_per_file == 1: load numpy file for idx in __getitem__ 16 | Case 2 n_triplets_per_file > 1: load numpy file that stores idx (and others) in __getitem__ 17 | idx is the index in "current" train/validation/test set. global idx is the index in the whole data set. 18 | Indices in train/validation/test set need to be sequential. 19 | """ 20 | def __init__(self, data_dir, start_index, end_index, y, n_tsteps, max_index, n_triplets_per_file): 21 | self.data_dir = data_dir 22 | self.start_index = start_index 23 | self.end_index = end_index 24 | self.n_triplets = end_index - start_index + 1 25 | self.n_triplets_per_file = n_triplets_per_file 26 | self.y = y 27 | self.n_tsteps = n_tsteps 28 | self.max_index = max_index 29 | if n_triplets_per_file == (max_index + 1): 30 | self.X_data = np.load('{}/0_{}.npy'.format(data_dir, max_index)) 31 | 32 | def __len__(self): 33 | return self.n_triplets 34 | 35 | def __getitem__(self, idx): 36 | global_idx = idx + self.start_index 37 | 38 | if self.n_triplets_per_file == (self.max_index + 1): 39 | X_idx = self.X_data[global_idx][:self.n_tsteps] 40 | else: 41 | if self.n_triplets_per_file > 1: 42 | file_idx = global_idx // self.n_triplets_per_file 43 | local_idx = global_idx % self.n_triplets_per_file 44 | 45 | end_idx = min((file_idx+1)*self.n_triplets_per_file-1, self.max_index) 46 | X_idx = np.load('{}/{}_{}.npy'.format(self.data_dir, 47 | file_idx * self.n_triplets_per_file, 48 | end_idx))[local_idx][:self.n_tsteps] 49 | else: 50 | X_idx = np.load('{}/{}.npy'.format(self.data_dir, global_idx))[0][:self.n_tsteps] 51 | y_idx = np.array(self.y[idx]) 52 | 53 | X_idx = X_idx[:, 0, :, :, :] 54 | 55 | return torch.from_numpy(X_idx).float(), torch.from_numpy(y_idx).float() 56 | 57 | 58 | def cnn_lstm_dataloader(data_dir, start_index, end_index, y, n_tsteps, max_index, n_triplets_per_file, 59 | batch_size=50, shuffle=True, num_workers=4): 60 | """ 61 | img_type: 'landsat', 'rgb', or 'naip' 62 | augment: random flip and rotate for data augmentation 63 | shuffle: turn shuffle to False for producing embeddings that correspond to original tiles. 64 | Returns a DataLoader with either NAIP (RGB/IR), RGB, or Landsat tiles. 65 | """ 66 | 67 | dataset = CnnLSTMDataset(data_dir, start_index, end_index, y, n_tsteps, max_index, 68 | n_triplets_per_file=n_triplets_per_file) 69 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 70 | return dataloader 71 | -------------------------------------------------------------------------------- /crop_yield_prediction/dataloader/c3d_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
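# Note on the extra np.swapaxes below: samples match the CNN-LSTM dataset
# except that variables are moved in front of time, giving the
# (channels, depth, height, width) layout expected by 3-D convolutions:
#
#     # (n_tsteps, n_vars, 50, 50) --swapaxes(0, 1)--> (n_vars, n_tsteps, 50, 50)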
6 | 7 | from torch.utils.data import Dataset, DataLoader 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class C3DDataset(Dataset): 13 | """ 14 | Case 0 n_triplets_per_file == (max_index + 1): load numpy file in __init__, retrieve idx in __getitem__ 15 | Case 1 n_triplets_per_file == 1: load numpy file for idx in __getitem__ 16 | Case 2 n_triplets_per_file > 1: load numpy file that stores idx (and others) in __getitem__ 17 | idx is the index in "current" train/validation/test set. global idx is the index in the whole data set. 18 | Indices in train/validation/test set need to be sequential. 19 | """ 20 | def __init__(self, data_dir, start_index, end_index, y, n_tsteps, max_index, n_triplets_per_file): 21 | self.data_dir = data_dir 22 | self.start_index = start_index 23 | self.end_index = end_index 24 | self.n_triplets = end_index - start_index + 1 25 | self.n_triplets_per_file = n_triplets_per_file 26 | self.y = y 27 | self.n_tsteps = n_tsteps 28 | self.max_index = max_index 29 | if n_triplets_per_file == (max_index + 1): 30 | self.X_data = np.load('{}/0_{}.npy'.format(data_dir, max_index)) 31 | 32 | def __len__(self): 33 | return self.n_triplets 34 | 35 | def __getitem__(self, idx): 36 | global_idx = idx + self.start_index 37 | 38 | if self.n_triplets_per_file == (self.max_index + 1): 39 | X_idx = self.X_data[global_idx][:self.n_tsteps] 40 | else: 41 | if self.n_triplets_per_file > 1: 42 | file_idx = global_idx // self.n_triplets_per_file 43 | local_idx = global_idx % self.n_triplets_per_file 44 | 45 | end_idx = min((file_idx+1)*self.n_triplets_per_file-1, self.max_index) 46 | X_idx = np.load('{}/{}_{}.npy'.format(self.data_dir, 47 | file_idx * self.n_triplets_per_file, 48 | end_idx))[local_idx][:self.n_tsteps] 49 | else: 50 | X_idx = np.load('{}/{}.npy'.format(self.data_dir, global_idx))[0][:self.n_tsteps] 51 | y_idx = np.array(self.y[idx]) 52 | 53 | X_idx = X_idx[:, 0, :, :, :] 54 | X_idx = np.swapaxes(X_idx, 0, 1) 55 | return torch.from_numpy(X_idx).float(), torch.from_numpy(y_idx).float() 56 | 57 | 58 | def c3d_dataloader(data_dir, start_index, end_index, y, n_tsteps, max_index, n_triplets_per_file, 59 | batch_size=50, shuffle=True, num_workers=4): 60 | """ 61 | img_type: 'landsat', 'rgb', or 'naip' 62 | augment: random flip and rotate for data augmentation 63 | shuffle: turn shuffle to False for producing embeddings that correspond to original tiles. 64 | Returns a DataLoader with either NAIP (RGB/IR), RGB, or Landsat tiles. 65 | """ 66 | 67 | dataset = C3DDataset(data_dir, start_index, end_index, y, n_tsteps, max_index, 68 | n_triplets_per_file=n_triplets_per_file) 69 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 70 | return dataloader 71 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/deep_gaussian_process/feature_engineering.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
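# Shape sketch for _calculate_histogram() (sizes assumed from the rest of the
# repo: 9 climate variables, 7 monthly steps, 50x50 tiles, 32 bins); the deep
# Gaussian process baselines see these histograms rather than raw pixels:
#
#     import numpy as np
#     image = np.random.rand(9, 7, 50, 50)          # (n_variables, n_timesteps, H, W)
#     hist = _calculate_histogram(image, num_bins=32)
#     # hist.shape == (9, 7, 32); each hist[v, t] sums to 1 (or 0 if no pixel
#     # falls inside that variable's bin range)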
6 | # Adapt code from https://github.com/gabrieltseng/pycrop-yield-prediction 7 | 8 | 9 | from crop_yield_prediction import CLIMATE_VARS 10 | 11 | import pandas as pd 12 | import numpy as np 13 | 14 | MAX_BIN_VAL = {'ppt': 179.812, 'evi': 0.631, 'ndvi': 0.850, 'elevation': 961.420, 'lst_day': 309.100, 15 | 'lst_night': 293.400, 'clay': 47.0, 'sand': 91.0, 'silt': 70.0} 16 | MIN_BIN_VAL = {'ppt': 11.045, 'evi': 0.084, 'ndvi': 0.138, 'elevation': 175.0, 'lst_day': 269.640, 17 | 'lst_night': 261.340, 'clay': 4.0, 'sand': 13.0, 'silt': 10.0} 18 | 19 | 20 | def _calculate_histogram(image, num_bins=32): 21 | """ 22 | Input image shape: (n_variables, n_timesteps, 50, 50) 23 | """ 24 | hist = [] 25 | n_variables, n_timesteps = image.shape[:2] 26 | for var_idx in range(n_variables): 27 | bin_seq = np.linspace(MIN_BIN_VAL[CLIMATE_VARS[var_idx]], MAX_BIN_VAL[CLIMATE_VARS[var_idx]], num_bins + 1) 28 | im = image[var_idx] 29 | imhist = [] 30 | for ts_idx in range(n_timesteps): 31 | density, _ = np.histogram(im[ts_idx, :, :], bin_seq, density=False) 32 | # max() prevents divide by 0 33 | imhist.append(density / max(1, density.sum())) 34 | hist.append(np.stack(imhist)) 35 | 36 | # [bands, times, bins] 37 | hist = np.stack(hist) 38 | 39 | return hist 40 | 41 | 42 | def get_features_for_deep_gaussian(): 43 | output_images = [] 44 | yields = [] 45 | years = [] 46 | locations = [] 47 | state_county_info = [] 48 | 49 | yield_data = pd.read_csv('data/deep_gaussian/deep_gaussian_dim_y.csv')[['state', 'county', 'year', 'value', 'lat', 'lon']] 50 | yield_data.columns = ['state', 'county', 'year', 'value', 'lat', 'lon'] 51 | 52 | for idx, yield_data in enumerate(yield_data.itertuples()): 53 | year, county, state = yield_data.year, yield_data.county, yield_data.state 54 | 55 | # [1, n_timesteps, 1+n_temporal_neighbor+n_spatial_neighbor+n_distant, n_variables, 50, 50] 56 | image = np.load('data/deep_gaussian/nr_25/{}.npy'.format(idx)) 57 | 58 | # get anchor image from shape (1, n_timesteps, 4, n_variables, 50, 50) 59 | # to shape (n_timestep, n_variables, 50, 50) 60 | image = image[0, :, 0, :, :, :] 61 | # shape (n_variables, n_timesteps, 50, 50) 62 | image = np.swapaxes(image, 0, 1) 63 | 64 | image = _calculate_histogram(image, num_bins=32) 65 | 66 | output_images.append(image) 67 | yields.append(yield_data.value) 68 | years.append(year) 69 | 70 | lat, lon = yield_data.lat, yield_data.lon 71 | locations.append(np.array([lon, lat])) 72 | 73 | state_county_info.append(np.array([int(state), int(county)])) 74 | 75 | # print(f'County: {int(county)}, State: {state}, Year: {year}, Output shape: {image.shape}') 76 | 77 | np.savez('data/deep_gaussian/data.npz', 78 | output_image=np.stack(output_images), output_yield=np.array(yields), 79 | output_year=np.array(years), output_locations=np.stack(locations), 80 | output_index=np.stack(state_county_info)) 81 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/deep_gaussian_process/gp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
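# Shape sketch for run() (sizes are hypothetical): the GP consumes the deep
# features of the train/test splits plus the weights and bias of the model's
# final linear layer, which seed the GP mean through the vector b.
#
#     gp = GaussianProcess(sigma=1, r_loc=0.5, r_year=1.5, sigma_e=0.32, sigma_b=0.01)
#     # feat_train: (n_train, d),  feat_test: (n_test, d)
#     # loc_*: (n_*, 2) as (lon, lat),  year_*: (n_*,)
#     # model_weights: (1, d),  model_bias: (1,)
#     # pred = gp.run(...) has shape (n_test, 1)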
6 | # Adapt code from https://github.com/gabrieltseng/pycrop-yield-prediction 7 | 8 | import numpy as np 9 | from scipy.spatial.distance import pdist, squareform 10 | 11 | 12 | class GaussianProcess: 13 | """ 14 | The crop yield Gaussian process 15 | """ 16 | def __init__(self, sigma=1, r_loc=0.5, r_year=1.5, sigma_e=0.32, sigma_b=0.01): 17 | self.sigma = sigma 18 | self.r_loc = r_loc 19 | self.r_year = r_year 20 | self.sigma_e = sigma_e 21 | self.sigma_b = sigma_b 22 | 23 | @staticmethod 24 | def _normalize(x): 25 | x_mean = np.mean(x, axis=0, keepdims=True) 26 | x_scale = np.ptp(x, axis=0, keepdims=True) 27 | 28 | return (x - x_mean) / x_scale 29 | 30 | def run(self, feat_train, feat_test, loc_train, loc_test, year_train, year_test, 31 | train_yield, model_weights, model_bias): 32 | 33 | # makes sure the features have an additional testue for the bias term 34 | # We call the features H since the features are used as the basis functions h(x) 35 | H_train = np.concatenate((feat_train, np.ones((feat_train.shape[0], 1))), axis=1) 36 | H_test = np.concatenate((feat_test, np.ones((feat_test.shape[0], 1))), axis=1) 37 | 38 | Y_train = np.expand_dims(train_yield, axis=1) 39 | 40 | n_train = feat_train.shape[0] 41 | n_test = feat_test.shape[0] 42 | 43 | locations = self._normalize(np.concatenate((loc_train, loc_test), axis=0)) 44 | years = self._normalize(np.concatenate((year_train, year_test), axis=0)) 45 | # to calculate the se_kernel, a dim=2 array must be passed 46 | years = np.expand_dims(years, axis=1) 47 | 48 | # These are the squared exponential kernel function we'll use for the covariance 49 | se_loc = squareform(pdist(locations, 'euclidean')) ** 2 / (self.r_loc ** 2) 50 | se_year = squareform(pdist(years, 'euclidean')) ** 2 / (self.r_year ** 2) 51 | 52 | # make the dirac matrix we'll add onto the kernel function 53 | noise = np.zeros([n_train + n_test, n_train + n_test]) 54 | noise[0: n_train, 0: n_train] += (self.sigma_e ** 2) * np.identity(n_train) 55 | 56 | kernel = ((self.sigma ** 2) * np.exp(-se_loc) * np.exp(-se_year)) + noise 57 | 58 | # since B is diagonal, and B = self.sigma_b * np.identity(feat_train.shape[1]), 59 | # its easy to calculate the inverse of B 60 | B_inv = np.identity(H_train.shape[1]) / self.sigma_b 61 | # "We choose b as the weight vector of the last layer of our deep models" 62 | b = np.concatenate((model_weights.transpose(1, 0), np.expand_dims(model_bias, 1))) 63 | 64 | K_inv = np.linalg.inv(kernel[0: n_train, 0: n_train]) 65 | 66 | # The definition of beta comes from equation 2.41 in Rasmussen (2006) 67 | beta = np.linalg.inv(B_inv + H_train.T.dot(K_inv).dot(H_train)).dot( 68 | H_train.T.dot(K_inv).dot(Y_train) + B_inv.dot(b)) 69 | 70 | # We take the mean of g(X*) as our prediction, also from equation 2.41 71 | pred = H_test.dot(beta) + \ 72 | kernel[n_train:, :n_train].dot(K_inv).dot(Y_train - H_train.dot(beta)) 73 | 74 | return pred 75 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/landsat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
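# Usage sketch (illustrative, with the band semantics taken from
# data_preprocessing/plot/plot_local.py: band4 = red, band5 = near infrared):
# once the per-band files are merged, vegetation indices can be computed from
# the combined NetCDF, e.g.
#
#     from netCDF4 import Dataset
#     fh = Dataset('../../processed_data/landsat/20180719.nc', 'r')
#     red, nir = fh.variables['band4'][:], fh.variables['band5'][:]
#     ndvi = (nir - red) / (nir + red)    # masked arrays propagate fill values
#     fh.close()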
6 | 7 | 8 | from netCDF4 import Dataset 9 | import numpy as np 10 | import numpy.ma as ma 11 | 12 | 13 | def combine_landsat(): 14 | fh_out = Dataset('../../processed_data/landsat/20180719.nc', 'w') 15 | 16 | flag = False 17 | for i in range(1, 8): 18 | fh_in = Dataset('../../raw_data/landsat/nebraska/SRB{}_20180719.nc'.format(i), 'r') 19 | if not flag: 20 | lats, lons = fh_in.variables['lat'][:], fh_in.variables['lon'][:] 21 | 22 | fh_out.createDimension("lat", len(lats)) 23 | fh_out.createDimension("lon", len(lons)) 24 | 25 | for v_name, varin in fh_in.variables.items(): 26 | if v_name in ["lat", "lon"]: 27 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 28 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 29 | fh_out.variables["lat"][:] = lats[:] 30 | fh_out.variables["lon"][:] = lons[:] 31 | flag = True 32 | 33 | for v_name, varin in fh_in.variables.items(): 34 | if v_name == 'Band1': 35 | outVar = fh_out.createVariable('band{}'.format(i), varin.datatype, ('lat', 'lon')) 36 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 37 | outVar[:] = ma.masked_less(varin[:], 0) 38 | 39 | fh_in.close() 40 | fh_out.close() 41 | 42 | 43 | # 20190604 44 | def subset_landsat(lat1, lat2, lon1, lon2): 45 | fh_out = Dataset('../../processed_data/landsat/2019155.nc', 'w') 46 | 47 | flag = False 48 | lat_indices, lon_indices = None, None 49 | for i in range(1, 8): 50 | fh_in = Dataset('../../raw_data/landsat/SRB{}_doy2019155.nc'.format(i), 'r') 51 | if not flag: 52 | lats, lons = fh_in.variables['lat'][:], fh_in.variables['lon'][:] 53 | lat_indices = np.searchsorted(lats, [lat2, lat1]) 54 | lon_indices = np.searchsorted(lons, [lon1, lon2]) 55 | lats = lats[lat_indices[0]: lat_indices[1]] 56 | lons = lons[lon_indices[0]: lon_indices[1]] 57 | 58 | fh_out.createDimension("lat", len(lats)) 59 | fh_out.createDimension("lon", len(lons)) 60 | 61 | for v_name, varin in fh_in.variables.items(): 62 | if v_name in ["lat", "lon"]: 63 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 64 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 65 | fh_out.variables["lat"][:] = lats[:] 66 | fh_out.variables["lon"][:] = lons[:] 67 | flag = True 68 | 69 | for v_name, varin in fh_in.variables.items(): 70 | if v_name == 'Band1': 71 | outVar = fh_out.createVariable('band{}'.format(i), varin.datatype, ('lat', 'lon')) 72 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 73 | outVar[:] = ma.masked_less(varin[lat_indices[0]: lat_indices[1], lon_indices[0]: lon_indices[1]], 0) 74 | 75 | fh_in.close() 76 | fh_out.close() 77 | 78 | 79 | if __name__ == '__main__': 80 | # combine_landsat() 81 | subset_landsat(41.2047, 40.0268, -122.0304, -120.4676) 82 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/prism_downscale.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from netCDF4 import Dataset 9 | import numpy as np 10 | import numpy.ma as ma 11 | 12 | import sys 13 | sys.path.append("..") 14 | 15 | from data_preprocessing.utils import generate_doy 16 | from data_preprocessing.rescaling.rescale_utils import search_kdtree 17 | 18 | 19 | def reproject_lat_lon(): 20 | fh_prism = Dataset('../../processed_data/prism/combined_monthly/201701.nc', 'r') 21 | lats, lons = fh_prism.variables['lat'][:], fh_prism.variables['lon'][:] 22 | lons, lats = np.meshgrid(lons, lats) 23 | 24 | fh_ref = Dataset('../../processed_data/lst/monthly_1km/201701.nc', 'r') 25 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 26 | 27 | xv, yv = np.meshgrid(ref_lons, ref_lats) 28 | points = np.dstack([yv.ravel(), xv.ravel()])[0] 29 | print('Finish building points') 30 | results = search_kdtree(lats, lons, points) 31 | np.save('../../processed_data/prism/projected_indices_lst_1km.npy', results) 32 | 33 | 34 | def reproject_prism(doy): 35 | fh_ref = Dataset('../../processed_data/lst/monthly_1km/201702.nc', 'r') 36 | fh_in = Dataset('../../processed_data/prism/combined_monthly/{}.nc'.format(doy), 'r') 37 | fh_out = Dataset('../../processed_data/prism/combined_monthly_1km/{}.nc'.format(doy), 'w') 38 | 39 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 40 | n_lat, n_lon = len(ref_lats), len(ref_lons) 41 | for name, dim in fh_ref.dimensions.items(): 42 | fh_out.createDimension(name, len(dim)) 43 | 44 | for v_name, varin in fh_ref.variables.items(): 45 | if v_name in ['lat', 'lon']: 46 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 47 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 48 | outVar[:] = varin[:] 49 | 50 | origi_values = {} 51 | projected_values = {} 52 | for v_name, varin in fh_in.variables.items(): 53 | if v_name not in ['lat', 'lon']: 54 | outVar = fh_out.createVariable(v_name, varin.datatype, ('lat', 'lon')) 55 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 56 | origi_values[v_name] = varin[:] 57 | projected_values[v_name] = np.full((n_lat, n_lon), -9999.9) 58 | 59 | projected_indices = np.load('../../processed_data/prism/projected_indices_lst_1km.npy') 60 | projected_i = 0 61 | 62 | for i in range(n_lat): 63 | for j in range(n_lon): 64 | for key in origi_values.keys(): 65 | proj_i, proj_j = projected_indices[projected_i] // 1405, projected_indices[projected_i] % 1405 66 | if not origi_values[key].mask[proj_i, proj_j]: 67 | projected_values[key][i, j] = origi_values[key][proj_i, proj_j] 68 | projected_i += 1 69 | 70 | for key in origi_values.keys(): 71 | fh_out.variables[key][:] = ma.masked_equal(projected_values[key], -9999.9) 72 | 73 | fh_in.close() 74 | fh_ref.close() 75 | fh_out.close() 76 | 77 | 78 | if __name__ == '__main__': 79 | # reproject_lat_lon() 80 | for year in range(2018, 2019): 81 | for month in range(2, 11): 82 | reproject_prism(doy='{}{}'.format(year, '{0:02}'.format(month))) 83 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/elevation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
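# Note on the resampling strategy below: unlike the KD-tree scripts in this
# package, elevation is aggregated by block averaging; each 1 km reference
# cell takes the mean of all 90 m cells whose centres fall inside its lat/lon
# bin, with get_lat_lon_bins() supplying the bin edges (one more edge than
# cell centres, running north to south for latitude). Roughly:
#
#     lat_bins, lon_bins = get_lat_lon_bins(ref_lats, ref_lons)
#     # cell (i, j) averages ele_var over
#     #   lat_bins[i + 1] <= lat <= lat_bins[i]
#     #   lon_bins[j]     <= lon <= lon_bins[j + 1]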
6 | 7 | import numpy as np 8 | import numpy.ma as ma 9 | from netCDF4 import Dataset 10 | 11 | import sys 12 | sys.path.append("..") 13 | 14 | from data_preprocessing.rescaling.rescale_utils import get_lat_lon_bins 15 | 16 | 17 | def reproject_elevation(): 18 | fh_in = Dataset('../../raw_data/elevation/90m.nc', 'r') 19 | fh_out = Dataset('../../processed_data/elevation/1km.nc', 'w') 20 | fh_ref = Dataset('../../processed_data/lst/monthly_1km/201508.nc', 'r') 21 | 22 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 23 | lat_bins, lon_bins = get_lat_lon_bins(ref_lats, ref_lons) 24 | 25 | ele_lats = fh_in.variables['lat'] 26 | ele_lats_value = ele_lats[:][::-1] 27 | ele_lons = fh_in.variables['lon'] 28 | ele_lons_value = ele_lons[:] 29 | ele_var = fh_in.variables['Band1'][0, :, :] 30 | ele_resampled = np.full((len(ref_lats), len(ref_lons)), -9999.9) 31 | # ele_std_resampled = np.full((len(ref_lats), len(ref_lons)), -9999.9) 32 | 33 | for id_lats in range(len(ref_lats)): 34 | for id_lons in range(len(ref_lons)): 35 | lats_index = np.searchsorted(ele_lats_value, [lat_bins[id_lats + 1], lat_bins[id_lats]]) 36 | lons_index = np.searchsorted(ele_lons_value, [lon_bins[id_lons], lon_bins[id_lons + 1]]) 37 | if lats_index[0] != lats_index[1] and lons_index[0] != lons_index[1]: 38 | ele_selected = ele_var[np.array(range(-lats_index[1], -lats_index[0]))[:, None], 39 | np.array(range(lons_index[0], lons_index[1]))] 40 | avg = ma.mean(ele_selected) 41 | # std = ma.std(ele_selected) 42 | ele_resampled[id_lats, id_lons] = (avg if avg is not ma.masked else -9999.9) 43 | # ele_std_resampled[id_lats, id_lons] = (std if std is not ma.masked else -9999.9) 44 | print(id_lats) 45 | 46 | ele_resampled = ma.masked_equal(ele_resampled, -9999.9) 47 | # ele_std_resampled = ma.masked_equal(ele_std_resampled, -9999.9) 48 | 49 | fh_out.createDimension('lat', len(ref_lats)) 50 | fh_out.createDimension('lon', len(ref_lons)) 51 | 52 | outVar = fh_out.createVariable('lat', 'f4', ('lat',)) 53 | outVar.setncatts({k: ele_lats.getncattr(k) for k in ele_lats.ncattrs()}) 54 | outVar[:] = ref_lats[:] 55 | 56 | outVar = fh_out.createVariable('lon', 'f4', ('lon',)) 57 | outVar.setncatts({k: ele_lons.getncattr(k) for k in ele_lons.ncattrs()}) 58 | outVar[:] = ref_lons[:] 59 | 60 | # outVar = fh_out.createVariable('elevation_mean', 'f4', ('lat', 'lon',)) 61 | outVar = fh_out.createVariable('elevation', 'f4', ('lat', 'lon',)) 62 | outVar.setncatts({'units': "m"}) 63 | outVar.setncatts({'long_name': "USGS_NED Elevation value"}) 64 | outVar.setncatts({'_FillValue': np.array([-9999.9]).astype('f')}) 65 | outVar[:] = ele_resampled[:] 66 | 67 | # outVar = fh_out.createVariable('elevation_std', 'f4', ('lat', 'lon',)) 68 | # outVar.setncatts({'units': "m"}) 69 | # outVar.setncatts({'long_name': "USGS_NED Elevation value"}) 70 | # outVar.setncatts({'_FillValue': np.array([-9999.9]).astype('f')}) 71 | # outVar[:] = ele_std_resampled[:] 72 | 73 | 74 | if __name__ == '__main__': 75 | reproject_elevation() 76 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex 
characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/us_counties.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import numpy.ma as ma 9 | from netCDF4 import Dataset 10 | 11 | import sys 12 | sys.path.append("..") 13 | 14 | from data_preprocessing.rescaling.rescale_utils import get_lat_lon_bins 15 | 16 | 17 | def reproject_us_counties(in_file, ref_file, out_file): 18 | fh_in = Dataset(in_file, 'r') 19 | fh_out = Dataset(out_file, 'w') 20 | fh_ref = Dataset(ref_file, 'r') 21 | 22 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 23 | lat_bins, lon_bins = get_lat_lon_bins(ref_lats, ref_lons) 24 | 25 | origi_lats = fh_in.variables['lat'] 26 | origi_lats_value = origi_lats[:] 27 | origi_lons = fh_in.variables['lon'] 28 | origi_lons_value = origi_lons[:] 29 | origi_values = {} 30 | sampled_values = {} 31 | selected_vars = [] 32 | for v in fh_in.variables: 33 | if v not in ['lat', 'lon']: 34 | selected_vars.append(v) 35 | origi_values[v] = fh_in.variables[v][:] 36 | sampled_values[v] = np.full((len(ref_lats), len(ref_lons)), 0) 37 | 38 | for id_lats in range(len(ref_lats)): 39 | for id_lons in range(len(ref_lons)): 40 | lats_index = np.searchsorted(origi_lats_value, [lat_bins[id_lats + 1], lat_bins[id_lats]]) 41 | lons_index = np.searchsorted(origi_lons_value, [lon_bins[id_lons], lon_bins[id_lons + 1]]) 42 | if lats_index[0] != lats_index[1] and lons_index[0] != lons_index[1]: 43 | for v in selected_vars: 44 | selected = origi_values[v][np.array(range(lats_index[0], lats_index[1])), 45 | np.array(range(lons_index[0], lons_index[1]))] 46 | if selected.count() > 0: 47 | sampled_values[v][id_lats, id_lons] = np.bincount(selected.compressed()).argmax() 48 | 49 | else: 50 | sampled_values[v][id_lats, id_lons] = 0 51 | print(id_lats) 52 | 53 | fh_out.createDimension('lat', len(ref_lats)) 54 | fh_out.createDimension('lon', len(ref_lons)) 55 | 56 | outVar = fh_out.createVariable('lat', 'f4', ('lat',)) 57 | outVar.setncatts({k: origi_lats.getncattr(k) for k in origi_lats.ncattrs()}) 58 | outVar[:] = ref_lats[:] 59 | 60 | outVar = fh_out.createVariable('lon', 'f4', ('lon',)) 61 | outVar.setncatts({k: origi_lons.getncattr(k) for k in origi_lons.ncattrs()}) 62 | outVar[:] = ref_lons[:] 63 | 64 | outVar = fh_out.createVariable('county_label', 'int', ('lat', 'lon')) 65 | outVar.setncatts({'_FillValue': np.array([0]).astype(int)}) 66 | outVar[:] = ma.masked_equal(sampled_values['county_label'], 0) 67 | 68 | outVar = fh_out.createVariable('state_code', 'int', ('lat', 'lon')) 69 | outVar.setncatts({'_FillValue': np.array([0]).astype(int)}) 70 | outVar[:] = ma.masked_equal(sampled_values['state_code'], 0) 71 | 72 | outVar = fh_out.createVariable('county_code', 'int', ('lat', 'lon')) 73 | outVar.setncatts({'_FillValue': np.array([0]).astype(int)}) 74 | 
outVar[:] = ma.masked_equal(sampled_values['county_code'], 0) 75 | 76 | fh_in.close() 77 | fh_ref.close() 78 | fh_out.close() 79 | 80 | 81 | if __name__ == '__main__': 82 | reproject_us_counties('../../processed_data/counties/us_counties.nc', 83 | '../../processed_data/lst/monthly_1km/201505.nc', 84 | '../../processed_data/counties/lst/us_counties.nc') 85 | -------------------------------------------------------------------------------- /crop_yield_prediction/plot/plot_crop_yield.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from bs4 import BeautifulSoup 8 | from pathlib import Path 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | import pandas as pd 12 | from collections import defaultdict 13 | import numpy as np 14 | import seaborn as sns 15 | 16 | 17 | # colors = sns.color_palette("RdYlBu", 10).as_hex() 18 | colors = ['#cdeaf3', '#9bcce2', '#fff1aa', '#fece7f', '#fa9b58', '#ee613e', '#d22b27'] 19 | 20 | SOYBEAN_QUANTILES = {0.05: 20.0, 0.2: 29.5, 0.4: 36.8, 0.6: 43.0, 0.8: 49.3, 0.95: 56.8, 0.0: 0.7, 1.0: 82.3} 21 | 22 | 23 | def crop_yield_plot(data_dict, savepath, quantiles=SOYBEAN_QUANTILES): 24 | """ 25 | For the most part, reformatting of 26 | https://github.com/JiaxuanYou/crop_yield_prediction/blob/master/6%20result_analysis/yield_map.py 27 | """ 28 | # load the svg file 29 | svg = Path('data/counties.svg').open('r').read() 30 | # Load into Beautiful Soup 31 | soup = BeautifulSoup(svg, features="html.parser") 32 | # Find counties 33 | paths = soup.findAll('path') 34 | 35 | path_style = 'font-size:12px;fill-rule:nonzero;stroke:#FFFFFF;stroke-opacity:1;stroke-width:0.1' \ 36 | ';stroke-miterlimit:4;stroke-dasharray:none;stroke-linecap:butt;marker-start' \ 37 | ':none;stroke-linejoin:bevel;fill:' 38 | 39 | for p in paths: 40 | if p['id'] not in ["State_Lines", "separator"]: 41 | try: 42 | rate = data_dict[p['id']] 43 | except KeyError: 44 | continue 45 | if rate > quantiles[0.95]: 46 | color_class = 6 47 | elif rate > quantiles[0.8]: 48 | color_class = 5 49 | elif rate > quantiles[0.6]: 50 | color_class = 4 51 | elif rate > quantiles[0.4]: 52 | color_class = 3 53 | elif rate > quantiles[0.2]: 54 | color_class = 2 55 | elif rate > quantiles[0.05]: 56 | color_class = 1 57 | else: 58 | color_class = 0 59 | 60 | color = colors[color_class] 61 | p['style'] = path_style + color 62 | soup = soup.prettify() 63 | with savepath.open('w') as f: 64 | f.write(soup) 65 | 66 | 67 | def save_colorbar(savedir, quantiles=SOYBEAN_QUANTILES): 68 | """ 69 | For the most part, reformatting of 70 | https://github.com/JiaxuanYou/crop_yield_prediction/blob/master/6%20result_analysis/yield_map.py 71 | """ 72 | fig = plt.figure() 73 | ax = fig.add_axes([0.1, 0.1, 0.02, 0.8]) 74 | 75 | cmap = mpl.colors.ListedColormap(colors[1:-1]) 76 | 77 | cmap.set_over(colors[-1]) 78 | cmap.set_under(colors[0]) 79 | 80 | bounds = [quantiles[x] for x in [0.05, 0.2, 0.4, 0.6, 0.8, 0.95]] 81 | 82 | norm = mpl.colors.BoundaryNorm(bounds, cmap.N) 83 | cb = mpl.colorbar.ColorbarBase(ax, cmap=cmap, 84 | norm=norm, 85 | # to use 'extend', you must 86 | # specify two extra boundaries: 87 | boundaries=[quantiles[0.0]] + bounds + [quantiles[1.0]], 88 | extend='both', 89 | ticks=bounds, # optional 90 | spacing='proportional', 91 | orientation='vertical') 
92 | plt.savefig('{}/colorbar.jpg'.format(savedir), dpi=300, bbox_inches='tight') 93 | -------------------------------------------------------------------------------- /generate_for_deep_gaussian.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from data_preprocessing import CLIMATE_VARS 8 | from data_preprocessing.sample_quadruplets import generate_training_for_counties 9 | 10 | from netCDF4 import Dataset 11 | import pandas as pd 12 | import numpy.ma as ma 13 | from collections import defaultdict 14 | import numpy as np 15 | 16 | 17 | def generate_dims_for_counties(croptype): 18 | yield_data = pd.read_csv('processed_data/crop_yield/{}_2000_2018.csv'.format(croptype))[[ 19 | 'Year', 'State ANSI', 'County ANSI', 'Value']] 20 | yield_data.columns = ['year', 'state', 'county', 'value'] 21 | ppt_fh = Dataset('experiment_data/spatial_temporal/nc_files/2014.nc', 'r') 22 | v_ppt = ppt_fh.variables['ppt'][0, :, :] 23 | if yield_data.value.dtype != float: 24 | yield_data['value'] = yield_data['value'].str.replace(',', '') 25 | yield_data = yield_data.astype({'year': int, 'state': int, 'county': int, 'value': float}) 26 | counties = pd.read_csv('processed_data/counties/lst/us_counties_cro_cvm_locations.csv') 27 | county_dic = {} 28 | for c in counties.itertuples(): 29 | state, county, lat, lon, lat0, lat1, lon0, lon1 = c.state, c.county, c.lat, c.lon, c.lat0, c.lat1, c.lon0, c.lon1 30 | county_dic[(state, county)] = [lat, lon, lat0, lat1, lon0, lon1] 31 | 32 | yield_dim_csv = [] 33 | for yd in yield_data.itertuples(): 34 | year, state, county, value = yd.year, yd.state, yd.county, yd.value 35 | 36 | if (state, county) not in county_dic: 37 | continue 38 | 39 | lat, lon, lat0, lat1, lon0, lon1 = county_dic[(state, county)] 40 | assert lat1 - lat0 == 49 41 | assert lon1 - lon0 == 49 42 | 43 | selected_ppt = v_ppt[lat0:lat1+1, lon0:lon1+1] 44 | if ma.count_masked(selected_ppt) != 0: 45 | continue 46 | 47 | yield_dim_csv.append([state, county, year, value, lat, lon]) 48 | 49 | yield_dim_csv = pd.DataFrame(yield_dim_csv, columns=['state', 'county', 'year', 'value', 'lat', 'lon']) 50 | yield_dim_csv.to_csv('experiment_data/spatial_temporal/counties/deep_gaussian_dim_y.csv') 51 | 52 | 53 | def get_max_min_val_for_climate_variable(img_dir): 54 | cv_dic = defaultdict(list) 55 | 56 | for year in range(2000, 2014): 57 | fh = Dataset('{}/{}.nc'.format(img_dir, year)) 58 | 59 | for v_name, varin in fh.variables.items(): 60 | if v_name in CLIMATE_VARS: 61 | cv_dic[v_name].append(varin[:].compressed()) 62 | fh.close() 63 | 64 | for cv in CLIMATE_VARS: 65 | values = np.asarray(cv_dic[cv]) 66 | print(cv, np.percentile(values, 95)) 67 | print(cv, np.percentile(values, 5)) 68 | 69 | 70 | if __name__ == '__main__': 71 | generate_dims_for_counties(croptype='soybeans') 72 | get_max_min_val_for_climate_variable('experiment_data/spatial_temporal/nc_files') 73 | for nr in [25]: 74 | # for nr in [10, 25, 50, 100, 500, None]: 75 | generate_training_for_counties(out_dir='experiment_data/deep_gaussian/counties', 76 | img_dir='experiment_data/spatial_temporal/nc_files', 77 | start_month=3, end_month=9, start_month_index=1, n_spatial_neighbor=1, n_distant=1, 78 | img_timestep_quadruplets= 79 | 'experiment_data/spatial_temporal/counties/img_timestep_quadruplets_hard.csv', 
80 | img_size=50, neighborhood_radius=nr, prenorm=False) 81 | -------------------------------------------------------------------------------- /generate_experiment_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from data_preprocessing.sample_quadruplets import generate_training_for_counties 8 | from data_preprocessing.postprocess import mask_non_major_states 9 | from data_preprocessing.postprocess import generate_no_spatial_for_counties 10 | from data_preprocessing.postprocess import obtain_channel_wise_mean_std 11 | from data_preprocessing.sample_quadruplets import generate_training_for_pretrained 12 | 13 | if __name__ == '__main__': 14 | # MAJOR_STATES = [17, 18, 19, 20, 27, 29, 31, 38, 39, 46, 21, 55, 26] 15 | # ['Illinois', 'Indiana', 'Iowa', 'Kansas', 'Minnesota', 'Missouri', 'Nebraska', 'North Dakota', 'Ohio', 16 | # 'South Dakota', 'Kentucky', 'Wisconsin', 'Michigan'] 17 | mask_non_major_states('experiment_data/spatial_temporal/nc_files_unmasked', 18 | 'experiment_data/spatial_temporal/nc_files', 19 | 'processed_data/counties/lst/us_counties.nc', 20 | MAJOR_STATES) 21 | 22 | generate_no_spatial_for_counties(yield_data_dir='processed_data/crop_yield', 23 | ppt_file='experiment_data/spatial_temporal/nc_files/2003.nc', 24 | county_location_file='processed_data/counties/lst/us_counties_cro_cvm_locations.csv', 25 | out_dir='experiment_data/no_spatial', 26 | img_dir='experiment_data/spatial_temporal/nc_files', 27 | croptype='soybeans', 28 | start_month=3, 29 | end_month=9, 30 | start_index=1) 31 | 32 | obtain_channel_wise_mean_std('experiment_data/spatial_temporal/nc_files') 33 | 34 | for nr in [5, 50, 100, 1000, None]: 35 | generate_training_for_counties(out_dir='experiment_data/spatial_temporal/counties', 36 | img_dir='experiment_data/spatial_temporal/nc_files', 37 | start_month=3, end_month=9, start_month_index=1, n_spatial_neighbor=1, n_distant=1, 38 | img_timestep_quadruplets= 39 | 'experiment_data/spatial_temporal/counties/img_timestep_quadruplets_hard.csv', 40 | img_size=50, neighborhood_radius=nr, distant_radius=None, prenorm=True) 41 | 42 | generate_training_for_pretrained(out_dir='experiment_data/spatial_temporal/counties', 43 | img_dir='experiment_data/spatial_temporal/nc_files', 44 | n_quadruplets=100000, 45 | start_year=2003, end_year=2012, start_month=3, end_month=9, start_month_index=1, 46 | n_spatial_neighbor=1, n_distant=1, 47 | img_size=50, neighborhood_radius=10, distant_radius=50, prenorm=True) 48 | generate_training_for_pretrained(out_dir='experiment_data/spatial_temporal/counties', 49 | img_dir='experiment_data/spatial_temporal/nc_files', 50 | n_quadruplets=100000, 51 | start_year=2003, end_year=2012, start_month=3, end_month=9, start_month_index=1, 52 | n_spatial_neighbor=1, n_distant=1, 53 | img_size=50, neighborhood_radius=25, distant_radius=100, prenorm=True) 54 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/lst.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
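# Note on the __main__ block below: it scans every daily file in
# raw_data/lst/1km/ but only reprojects dates falling inside the
# [date_start, date_end] window, which is currently pinned to the single day
# 2019-06-02. Widening the window only requires changing those two strings,
# e.g. (hypothetical range):
#
#     date_start = datetime.datetime.strptime('20190401', "%Y%m%d").date()
#     date_end = datetime.datetime.strptime('20190930', "%Y%m%d").date()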
6 | 7 | 8 | from netCDF4 import Dataset 9 | import numpy as np 10 | import numpy.ma as ma 11 | import datetime 12 | import os 13 | 14 | import sys 15 | sys.path.append("..") 16 | 17 | from data_preprocessing.rescaling.rescale_utils import search_kdtree 18 | 19 | 20 | def reproject_lat_lon(): 21 | fh_lst = Dataset('../../raw_data/lst/1km/20170101.nc', 'r') 22 | lats, lons = fh_lst.variables['lat'][:], fh_lst.variables['lon'][:] 23 | lons, lats = np.meshgrid(lons, lats) 24 | 25 | fh_ref = Dataset('../../processed_data/lai/500m/20181028.nc', 'r') 26 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 27 | 28 | xv, yv = np.meshgrid(ref_lons, ref_lats) 29 | points = np.dstack([yv.ravel(), xv.ravel()])[0] 30 | print('Finish building points') 31 | results = search_kdtree(lats, lons, points) 32 | np.save('../../raw_data/lst/projected_indices_lai_500m.npy', results) 33 | 34 | 35 | def reproject_lst(doy): 36 | print(doy) 37 | fh_ref = Dataset('../../processed_data/lai/500m/20181028.nc', 'r') 38 | fh_in = Dataset('../../raw_data/lst/1km/{}.nc'.format(doy), 'r') 39 | fh_out = Dataset('../../processed_data/lst/500m/{}.nc'.format(doy), 'w') 40 | 41 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 42 | n_lat, n_lon = len(ref_lats), len(ref_lons) 43 | for name, dim in fh_ref.dimensions.items(): 44 | fh_out.createDimension(name, len(dim)) 45 | 46 | for v_name, varin in fh_ref.variables.items(): 47 | if v_name in ['lat', 'lon']: 48 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 49 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 50 | outVar[:] = varin[:] 51 | 52 | origi_values = {} 53 | projected_values = {} 54 | for v_name, varin in fh_in.variables.items(): 55 | if v_name not in ['lat', 'lon']: 56 | new_name = '_'.join(v_name.lower().split('_')[:-1]) 57 | outVar = fh_out.createVariable(new_name, varin.datatype, ('lat', 'lon')) 58 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 59 | origi_values[new_name] = varin[:] 60 | projected_values[new_name] = np.full((n_lat, n_lon), -9999.9) 61 | 62 | projected_indices = np.load('../../raw_data/lst/projected_indices_lai_500m.npy') 63 | projected_i = 0 64 | 65 | for i in range(n_lat): 66 | for j in range(n_lon): 67 | for key in origi_values.keys(): 68 | proj_i, proj_j = projected_indices[projected_i] // 7797, projected_indices[projected_i] % 7797 69 | if not origi_values[key].mask[proj_i, proj_j]: 70 | projected_values[key][i, j] = origi_values[key][proj_i, proj_j] 71 | projected_i += 1 72 | 73 | for key in origi_values.keys(): 74 | fh_out.variables[key][:] = ma.masked_equal(projected_values[key], -9999.9) 75 | 76 | fh_in.close() 77 | fh_ref.close() 78 | fh_out.close() 79 | 80 | 81 | if __name__ == '__main__': 82 | # reproject_lat_lon() 83 | for doy in os.listdir('../../raw_data/lst/1km/'): 84 | if doy.endswith('.nc'): 85 | doy = doy[:-3] 86 | date = datetime.datetime.strptime(doy, "%Y%m%d").date() 87 | date_start = datetime.datetime.strptime('20190602', "%Y%m%d").date() 88 | date_end = datetime.datetime.strptime('20190602', "%Y%m%d").date() 89 | if date_start <= date <= date_end: 90 | reproject_lst(doy) 91 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/lai.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from netCDF4 import Dataset 8 | import numpy.ma as ma 9 | import datetime 10 | 11 | import sys 12 | sys.path.append("..") 13 | 14 | 15 | def extract_lai(nc_file): 16 | fh_in = Dataset('../../raw_data/lai/' + nc_file, 'r') 17 | 18 | for index, n_days in enumerate(fh_in.variables['time'][:]): 19 | date = (datetime.datetime(2000, 1, 1, 0, 0) + datetime.timedelta(int(n_days))).strftime('%Y%m%d') 20 | print(date) 21 | fh_out = Dataset('../../processed_data/lai/500m/{}.nc'.format(date), 'w') 22 | 23 | for name, dim in fh_in.dimensions.items(): 24 | if name != 'time': 25 | fh_out.createDimension(name, len(dim) if not dim.isunlimited() else None) 26 | 27 | ignore_features = ["time", "crs", "FparExtra_QC", "FparLai_QC"] 28 | mask_value_dic = {'Lai_500m': 10, 'LaiStdDev_500m': 10, 'Fpar_500m': 1, 'FparStdDev_500m': 1} 29 | for v_name, varin in fh_in.variables.items(): 30 | if v_name not in ignore_features: 31 | dimensions = varin.dimensions if v_name in ['lat', 'lon'] else ('lat', 'lon') 32 | outVar = fh_out.createVariable(v_name, varin.datatype, dimensions) 33 | if v_name == "lat": 34 | outVar.setncatts({"units": "degree_north"}) 35 | outVar[:] = varin[:] 36 | elif v_name == "lon": 37 | outVar.setncatts({"units": "degree_east"}) 38 | outVar[:] = varin[:] 39 | else: 40 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 41 | vin = varin[index, :, :] 42 | vin = ma.masked_greater(vin, mask_value_dic[v_name]) 43 | vin = ma.masked_less(vin, 0) 44 | outVar[:] = vin[:] 45 | fh_out.close() 46 | fh_in.close() 47 | 48 | 49 | def extract_ndvi(nc_file): 50 | fh_in = Dataset('../../raw_data/ndvi/' + nc_file, 'r') 51 | 52 | for index, n_days in enumerate(fh_in.variables['time'][:]): 53 | date = (datetime.datetime(2000, 1, 1, 0, 0) + datetime.timedelta(int(n_days))).strftime('%Y%m%d') 54 | print(date) 55 | fh_out = Dataset('../../processed_data/ndvi/1km/{}.nc'.format(date[:-2]), 'w') 56 | 57 | for name, dim in fh_in.dimensions.items(): 58 | if name != 'time': 59 | fh_out.createDimension(name, len(dim) if not dim.isunlimited() else None) 60 | 61 | ignore_features = ["time", "crs", "_1_km_monthly_VI_Quality"] 62 | for v_name, varin in fh_in.variables.items(): 63 | if v_name not in ignore_features: 64 | dimensions = varin.dimensions if v_name in ['lat', 'lon'] else ('lat', 'lon') 65 | v_name = v_name if v_name in ['lat', 'lon'] else v_name.split('_')[-1].lower() 66 | outVar = fh_out.createVariable(v_name, varin.datatype, dimensions) 67 | if v_name == "lat": 68 | outVar.setncatts({"units": "degree_north"}) 69 | outVar[:] = varin[:] 70 | elif v_name == "lon": 71 | outVar.setncatts({"units": "degree_east"}) 72 | outVar[:] = varin[:] 73 | else: 74 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 75 | vin = varin[index, :, :] 76 | vin = ma.masked_greater(vin, 1.0) 77 | vin = ma.masked_less(vin, -0.2) 78 | outVar[:] = vin[:] 79 | fh_out.close() 80 | fh_in.close() 81 | 82 | 83 | if __name__ == '__main__': 84 | # extract_lai('20190604.nc') 85 | extract_ndvi('MOD13A3_20000201_20181231.nc') 86 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/landcover.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from netCDF4 import Dataset 9 | import numpy as np 10 | import numpy.ma as ma 11 | import sys 12 | sys.path.append("..") 13 | 14 | from data_preprocessing.merge import merge_various_days 15 | 16 | 17 | def generate_convert_to_nc_script(): 18 | fh_out = open('../../processed_data/landcover/convert_to_nc.sh', 'w') 19 | fh_out.write('#!/bin/bash\n') 20 | 21 | for tif_file in os.listdir('../../processed_data/landcover/'): 22 | if tif_file.endswith('.tif'): 23 | fh_out.write('gdal_translate -of netCDF {} {}.nc\n'.format(tif_file, tif_file[:-4])) 24 | 25 | 26 | def mask_with_landcover(out_folder, kept_ldcs): 27 | for nc_file in os.listdir('../../processed_data/landcover/origi/'): 28 | if nc_file.endswith('.nc'): 29 | fh_in = Dataset('../../processed_data/landcover/origi/{}'.format(nc_file), 'r') 30 | fh_out = Dataset('../../processed_data/landcover/{}/{}'.format(out_folder, nc_file), 'w') 31 | 32 | for name, dim in fh_in.dimensions.items(): 33 | fh_out.createDimension(name, len(dim)) 34 | 35 | for v_name, varin in fh_in.variables.items(): 36 | outVar = fh_out.createVariable(v_name, varin.datatype, varin.dimensions) 37 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 38 | if v_name in ['lat', 'lon']: 39 | outVar[:] = varin[:] 40 | else: 41 | landcovers = varin[:] 42 | lc_mask = np.in1d(landcovers, kept_ldcs).reshape(landcovers.shape) 43 | outVar[:] = ma.array(varin[:], mask=~lc_mask) 44 | 45 | fh_in.close() 46 | fh_out.close() 47 | 48 | 49 | def generate_cropland(in_file, out_file): 50 | fh_in = Dataset(in_file, 'r') 51 | fh_out = Dataset(out_file, 'w') 52 | 53 | lats, lons = fh_in.variables['lat'][:], fh_in.variables['lon'][:] 54 | for name, dim in fh_in.dimensions.items(): 55 | fh_out.createDimension(name, len(dim)) 56 | 57 | for v_name, varin in fh_in.variables.items(): 58 | if v_name in ['lat', 'lon']: 59 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 60 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 61 | outVar[:] = varin[:] 62 | 63 | outVar = fh_out.createVariable('cropland', 'f4', ('lat', 'lon')) 64 | outVar.setncatts({'_FillValue': np.array([0.0]).astype('f')}) 65 | cropland = np.full((len(lats), len(lons)), 1.0) 66 | mask_value = ma.getmaskarray(fh_in.variables['Band1'][:]) 67 | mask_value = np.logical_and.reduce(mask_value) 68 | outVar[:] = ma.array(cropland, mask=mask_value) 69 | 70 | fh_in.close() 71 | fh_out.close() 72 | 73 | 74 | if __name__ == '__main__': 75 | generate_convert_to_nc_script() 76 | mask_with_landcover('cro', [12]) 77 | mask_with_landcover('cro_cvm', [12, 14]) 78 | merge_various_days('../../processed_data/landcover/origi/', '../../processed_data/landcover/', 'ts_merged', 79 | select_vars=['Band1']) 80 | merge_various_days('../../processed_data/landcover/cro/', '../../processed_data/landcover/', 'ts_merged_cro', 81 | select_vars=['Band1']) 82 | merge_various_days('../../processed_data/landcover/cro_cvm/', '../../processed_data/landcover/', 83 | 'ts_merged_cro_cvm', select_vars=['Band1']) 84 | generate_cropland('../../processed_data/landcover/ts_merged_cro.nc', 85 | '../../processed_data/landcover/cropland_cro.nc') 86 | generate_cropland('../../processed_data/landcover/ts_merged_cro_cvm.nc', 87 | '../../processed_data/landcover/cropland_cro_cvm.nc') 88 | -------------------------------------------------------------------------------- 
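A note on the masking idiom used by `mask_with_landcover` above: in the IGBP classification used by the MODIS MCD12Q1 product, class 12 corresponds to croplands and class 14 to cropland/natural vegetation mosaics, so the `[12]` and `[12, 14]` calls keep cropland-only and cropland-plus-mosaic pixels respectively. Below is a minimal, self-contained sketch of the same `np.in1d` masking on made-up toy values (not data from this project):

```python
import numpy as np
import numpy.ma as ma

# Toy land-cover grid; the class codes here are illustrative only.
landcovers = np.array([[12, 14, 16],
                       [12, 13, 14]])
kept_ldcs = [12, 14]  # classes to keep; every other cell gets masked

# Same idiom as mask_with_landcover: flag cells whose class is in kept_ldcs,
# then mask out the rest.
lc_mask = np.in1d(landcovers, kept_ldcs).reshape(landcovers.shape)
masked = ma.array(landcovers, mask=~lc_mask)

print(masked)
# [[12 14 --]
#  [12 -- 14]]
```

The real functions apply this per netCDF variable and write the masked arrays back to new files.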
/data_preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | __all__ = ['cdl_values_to_crops', 'crops_to_cdl_values', 8 | 'CLIMATE_VARS', 'STATIC_CLIMATE_VARS', 'DYNAMIC_CLIMATE_VARS'] 9 | 10 | CLIMATE_VARS = ['ppt', 'evi', 'ndvi', 'elevation', 'lst_day', 'lst_night', 'clay', 'sand', 'silt'] 11 | STATIC_CLIMATE_VARS = ['elevation', 'clay', 'sand', 'silt'] 12 | DYNAMIC_CLIMATE_VARS = [x for x in CLIMATE_VARS if x not in STATIC_CLIMATE_VARS] 13 | 14 | cdl_values_to_crops = {1: 'Corn', 2: 'Cotton', 3: 'Rice', 4: 'Sorghum', 5: 'Soybeans', 6: 'Sunflower', 15 | 10: 'Peanuts', 11: 'Tobacco', 12: 'Sweet Corn', 13: 'Pop or Orn Corn', 14: 'Mint', 21: 'Barley', 16 | 22: 'Durum Wheat', 23: 'Spring Wheat', 24: 'Winter Wheat', 25: 'Other Small Grains', 17 | 26: 'Dbl Crop WinWht/Soybeans', 27: 'Rye', 28: 'Oats', 29: 'Millet', 30: 'Speltz', 31: 'Canola', 18 | 32: 'Flaxseed', 33: 'Safflower', 34: 'Rape Seed', 35: 'Mustard', 36: 'Alfalfa', 19 | 37: 'Other Hay/Non Alfalfa', 38: 'Camelina', 39: 'Buckwheat', 41: 'Sugarbeets', 42: 'Dry Beans', 20 | 43: 'Potatoes', 44: 'Other Crops', 45: 'Sugarcane', 46: 'Sweet Potatoes', 21 | 47: 'Misc Vegs & Fruits', 48: 'Watermelons', 49: 'Onions', 50: 'Cucumbers', 51: 'Chick Peas', 22 | 52: 'Lentils', 53: 'Peas', 54: 'Tomatoes', 55: 'Caneberries', 56: 'Hops', 57: 'Herbs', 23 | 58: 'Clover/Wildflowers', 59: 'Sod/Grass Seed', 60: 'Switchgrass', 61: 'Fallow/Idle Cropland', 24 | 63: 'Forest', 64: 'Shrubland1', 65: 'Barren1', 66: 'Cherries', 67: 'Peaches', 68: 'Apples', 25 | 69: 'Grapes', 70: 'Christmas Trees', 71: 'Other Tree Crops', 72: 'Citrus', 74: 'Pecans', 26 | 75: 'Almonds', 76: 'Walnuts', 77: 'Pears', 81: 'Clouds/No Data', 82: 'Developed', 83: 'Water', 27 | 87: 'Wetlands', 88: 'Nonag/Undefined', 92: 'Aquaculture', 111: 'Open Water', 28 | 112: 'Perennial Ice/Snow ', 121: 'Developed/Open Space', 122: 'Developed/Low Intensity', 29 | 123: 'Developed/Med Intensity', 124: 'Developed/High Intensity', 131: 'Barren2', 30 | 141: 'Deciduous Forest', 142: 'Evergreen Forest', 143: 'Mixed Forest', 152: 'Shrubland2', 31 | 176: 'Grassland/Pasture', 190: 'Woody Wetlands', 195: 'Herbaceous Wetlands', 204: 'Pistachios', 32 | 205: 'Triticale', 206: 'Carrots', 207: 'Asparagus', 208: 'Garlic', 209: 'Cantaloupes', 33 | 210: 'Prunes', 211: 'Olives', 212: 'Oranges', 213: 'Honeydew Melons', 214: 'Broccoli', 34 | 215: 'Avocados', 216: 'Peppers', 217: 'Pomegranates', 218: 'Nectarines', 219: 'Greens', 35 | 220: 'Plums', 221: 'Strawberries', 222: 'Squash', 223: 'Apricots', 224: 'Vetch', 36 | 225: 'Dbl Crop WinWht/Corn', 226: 'Dbl Crop Oats/Corn', 227: 'Lettuce', 229: 'Pumpkins', 37 | 230: 'Dbl Crop Lettuce/Durum Wht', 38 | 231: 'Dbl Crop Lettuce/Cantaloupe', 232: 'Dbl Crop Lettuce/Cotton', 39 | 233: 'Dbl Crop Lettuce/Barley', 234: 'Dbl Crop Durum Wht/Sorghum', 40 | 235: 'Dbl Crop Barley/Sorghum', 236: 'Dbl Crop WinWht/Sorghum', 237: 'Dbl Crop Barley/Corn', 41 | 238: 'Dbl Crop WinWht/Cotton', 239: 'Dbl Crop Soybeans/Cotton', 240: 'Dbl Crop Soybeans/Oats', 42 | 241: 'Dbl Crop Corn/Soybeans', 242: 'Blueberries', 243: 'Cabbage', 244: 'Cauliflower', 43 | 245: 'Celery', 246: 'Radishes', 247: 'Turnips', 248: 'Eggplants', 249: 'Gourds', 44 | 250: 'Cranberries', 254: 'Dbl Crop Barley/Soybeans'} 45 | 46 | # A reverse map of above, 
allowing you to look up CDL values from category names. 47 | crops_to_cdl_values = {v: k for k, v in cdl_values_to_crops.items()} 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Context-aware Deep Representation Learning for Geo-spatiotemporal Analysis 2 | Code for ICDM 2020 paper Context-aware Deep Representation Learning for Geo-spatiotemporal Analysis. 3 | 4 | ## Data Preprocessing 5 | 6 | ### Data Sources 7 | 1. County-level soybean yields (years 2003 to 2018) are downloaded from the [USDA NASS Quick Stats Database](https://quickstats.nass.usda.gov/). 8 | 2. Landcover classes are from the [MODIS product MCD12Q1](https://lpdaac.usgs.gov/products/mcd12q1v006/) and downloaded from Google Earth Engine. **gee_county_lc.py** and **gee_landcover.py** call Google Earth Engine to download the data. County boundaries from [Google's fusion table](https://fusiontables.google.com/data?docid=1S4EB6319wWW2sWQDPhDvmSBIVrD3iEmCLYB7nMM#rows:id=1) are used to download the landcover class data for each county separately. 9 | 10 | Input: 11 | 12 | 1. Vegetation indices including NDVI and EVI are from [MODIS product MOD13A3](https://lpdaac.usgs.gov/products/mod13a3v006/) and downloaded from [AρρEEARS](https://lpdaacsvc.cr.usgs.gov/appeears/). 13 | 2. Precipitation is from the [PRISM dataset](http://www.prism.oregonstate.edu/). 14 | 3. Land surface temperature is from [MODIS product MOD11A1]() and downloaded from [AρρEEARS](https://lpdaacsvc.cr.usgs.gov/appeears/). 15 | 4. Elevation is from the NASA Shuttle Radar Topography Mission Global 30 m product and downloaded from [AρρEEARS](https://lpdaacsvc.cr.usgs.gov/appeears/). 16 | 5. Soil properties including soil sand, silt, and clay fractions are from the [STATSGO](https://catalog.data.gov/dataset/statsgo-soil-polygons) database. 17 | 18 | ### Preprocessing 19 | Data from various sources are first converted to a unified format, [netCDF4](https://unidata.github.io/netcdf4-python/netCDF4/index.html), keeping their original resolutions. They are then rescaled to the MODIS product grid at 1 km resolution. 20 | 21 | ## Experiment Data Generation 22 | Quadruplet sampling code is in the folder **data_preprocessing/sample_quadruplets**. Its functions are called by **generate_experiment_data.py** to generate the experiment data. 23 | 24 | ## Modeling 25 | We provide code here for the context-aware representation learning model and all baselines mentioned in the paper, including traditional models for scalar inputs, deep Gaussian process models, CNN-LSTM, and C3D. 26 | 27 | A few examples of commands to train the models: 28 | 1. attention model - semi-supervised: 29 | ```console 30 | python ./crop_yield_train_semi_transformer.py --neighborhood-radius 25 --distant-radius 100 --weight-decay 0.0 --tilenet-margin 50 --tilenet-l2 0.2 --tilenet-ltn 0.001 --tilenet-zdim 256 --attention-layer 1 --attention-dff 512 --sentence-embedding simple_average --dropout 0.2 --unsup-weight 0.2 --patience 9999 --feature all --feature-len 9 --year 2018 --ntsteps 7 --train-years 10 --query-type combine 31 | ``` 32 | 2. 
attention model - supervised: 33 | ```console 34 | python ./crop_yield_train_semi_transformer.py --neighborhood-radius 25 --distant-radius 100 --weight-decay 0.0 --tilenet-margin 50 --tilenet-l2 0.2 --tilenet-ltn 0.001 --tilenet-zdim 256 --attention-layer 1 --attention-dff 512 --sentence-embedding simple_average --dropout 0.2 --unsup-weight 0.0 --patience 9999 --feature all --feature-len 9 --year 2018 --ntsteps 7 --train-years 10 --query-type combine 35 | ``` 36 |   When the query type is set to "combine", the hybrid attention mechanism introduced in the ICDM 2020 paper is adopted. You can also test the other query types ("global", "fixed", "separate") on your data. 37 | 38 | 3. C3D 39 | ```console 40 | python ./crop_yield_train_c3d.py --patience 9999 --feature all --feature-len 9 --year 2018 --ntsteps 7 --train-years 10 41 | ``` 42 | 4. CNN-LSTM 43 | ```console 44 | python ./crop_yield_train_cnn_lstm.py --patience 9999 --feature all --feature-len 9 --year 2018 --ntsteps 7 --train-years 10 --tilenet-zdim 256 --lstm-inner 512 45 | ``` 46 | 5. deep Gaussian process 47 | ```console 48 | python ./crop_yield_deep_gaussian.py --type cnn --time 7 --train-years 10 49 | ``` 50 | 6. traditional models 51 | ```console 52 | python ./crop_yield_no_spatial.py --predict no_spatial --train-years 10 53 | ``` 54 | 55 | ## Cite this work 56 | 57 | ## License 58 | MIT licensed. See the LICENSE file for details. 59 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/semi_transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
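# Query types supported by the MultiHeadAttention module below (selected through the
# --query-type option shown in the README):
#   'fixed'    - queries are projected from the per-timestep inputs only
#   'global'   - queries come from a learned, input-independent parameter
#   'combine'  - sum of the projected and learned queries (the hybrid attention
#                mechanism described in the paper)
#   'separate' - the learned global query is passed to the attention operation
#                as a separate argument alongside the projected queries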
6 | # Based on transformer code from https://github.com/jadore801120/attention-is-all-you-need-pytorch 7 | 8 | ''' Define the sublayers in encoder/decoder layer ''' 9 | from crop_yield_prediction.models.semi_transformer.Modules import ScaledDotProductAttention 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class MultiHeadAttention(nn.Module): 18 | ''' Multi-Head Attention module ''' 19 | 20 | def __init__(self, n_tsteps, query_type, n_head, d_model, d_k, d_v, dropout=0.1): 21 | super().__init__() 22 | 23 | self.n_head = n_head 24 | self.d_k = d_k 25 | self.d_v = d_v 26 | 27 | self.global_query = nn.Parameter(torch.randn(n_head, d_k, n_tsteps), requires_grad=True) 28 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 29 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 30 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 31 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 32 | 33 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) 34 | 35 | self.dropout = nn.Dropout(dropout) 36 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 37 | 38 | self.query_type = query_type 39 | def forward(self, q, k, v): 40 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 41 | # sz_b: batch size, len_q, len_k, len_v: number of time steps 42 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 43 | 44 | residual = q 45 | 46 | # Pass through the pre-attention projection: b x lq x (n*dv) 47 | # Separate different heads: b x lq x n x dv 48 | if self.query_type == 'global': 49 | q = self.global_query 50 | q = q.transpose(1, 2) # transpose to n * lq * dk 51 | elif self.query_type == 'fixed': 52 | q = self.layer_norm(q) 53 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 54 | q = q.transpose(1, 2) # transpose to b x n x lq x dk 55 | elif self.query_type == 'combine': 56 | lq = self.layer_norm(q) 57 | lq = self.w_qs(lq).view(sz_b, len_q, n_head, d_k) 58 | lq = lq.transpose(1, 2) 59 | gq = self.global_query 60 | gq = gq.transpose(1, 2) 61 | q = lq + gq 62 | elif self.query_type == 'separate': 63 | lq = self.layer_norm(q) 64 | lq = self.w_qs(lq).view(sz_b, len_q, n_head, d_k) 65 | lq = lq.transpose(1, 2) 66 | gq = self.global_query 67 | gq = gq.transpose(1, 2) 68 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 69 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 70 | k, v = k.transpose(1, 2), v.transpose(1, 2) # Transpose for attention dot product: b x n x lq x dv 71 | # Transpose for attention dot product: b x n x lq x dv 72 | if self.query_type == 'separate': 73 | q, attn = self.attention(lq, k, v, gq) 74 | else: 75 | q, attn = self.attention(q, k, v) 76 | 77 | # Transpose to move the head dimension back: b x lq x n x dv 78 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 79 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 80 | q = self.dropout(self.fc(q)) 81 | q += residual 82 | 83 | return q, attn 84 | 85 | 86 | class PositionwiseFeedForward(nn.Module): 87 | ''' A two-feed-forward-layer module ''' 88 | 89 | def __init__(self, d_in, d_hid, dropout=0.1): 90 | super().__init__() 91 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 92 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 93 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 94 | self.dropout = nn.Dropout(dropout) 95 | 96 | def forward(self, x): 97 | residual = x 98 | x = self.layer_norm(x) 99 | 100 | x = self.w_2(F.relu(self.w_1(x))) 101 | x = 
self.dropout(x) 102 | x += residual 103 | 104 | return x 105 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/nws_precip.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from netCDF4 import Dataset 8 | import numpy as np 9 | import numpy.ma as ma 10 | import fiona 11 | 12 | import sys 13 | sys.path.append("..") 14 | 15 | from data_preprocessing.utils import generate_doy 16 | from data_preprocessing.preprocess import search_kdtree 17 | 18 | 19 | def extract_shapefile(): 20 | shapefile = fiona.open('../../raw_data/nws_precip/nws_precip_allpoint_conversion/nws_precip_allpoint_conversion.shp') 21 | 22 | lats = np.full((881, 1121), np.inf) 23 | lons = np.full((881, 1121), np.inf) 24 | max_hrapx, max_hrapy = -float('inf'), -float('inf') 25 | for feature in shapefile: 26 | hrapx, hrapy = feature['properties']['Hrapx'], feature['properties']['Hrapy'] 27 | max_hrapx = max(max_hrapx, hrapx) 28 | max_hrapy = max(max_hrapy, hrapy) 29 | lon, lat = feature['geometry']['coordinates'] 30 | if 0 <= hrapx < 1121 and 0 <= hrapy < 881: 31 | lats[hrapy, hrapx] = lat 32 | lons[hrapy, hrapx] = lon 33 | print(max_hrapx, max_hrapy) 34 | np.save('../../raw_data/nws_precip/nws_precip_allpoint_conversion/lats.npy', lats) 35 | np.save('../../raw_data/nws_precip/nws_precip_allpoint_conversion/lons.npy', lons) 36 | 37 | 38 | def compute_closest_grid_point(lats, lons, lat, lon): 39 | d_lats = lats - float(lat) 40 | d_lons = lons - float(lon) 41 | d = np.multiply(d_lats, d_lats) + np.multiply(d_lons, d_lons) 42 | i, j = np.unravel_index(d.argmin(), d.shape) 43 | return i, j, np.sqrt(d.min()) 44 | 45 | 46 | def reproject_lat_lon(): 47 | lats = np.load('../../raw_data/nws_precip/nws_precip_allpoint_conversion/lats.npy') 48 | lons = np.load('../../raw_data/nws_precip/nws_precip_allpoint_conversion/lons.npy') 49 | 50 | fh_ref = Dataset('../../processed_data/lai/500m/20181028.nc', 'r') 51 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 52 | 53 | xv, yv = np.meshgrid(ref_lons, ref_lats) 54 | points = np.dstack([yv.ravel(), xv.ravel()])[0] 55 | print('Finish building points') 56 | results = search_kdtree(lats, lons, points) 57 | np.save('../../raw_data/nws_precip/nws_precip_allpoint_conversion/projected_indices_lai_500m.npy', results) 58 | 59 | 60 | def reproject_nws_precip(doy): 61 | print(doy) 62 | fh_ref = Dataset('../../processed_data/lai/500m/20181028.nc', 'r') 63 | fh_in = Dataset('../../raw_data/nws_precip/{}/nws_precip_1day_{}_conus.nc'.format(doy, doy), 'r') 64 | fh_out = Dataset('../../processed_data/nws_precip/500m/{}.nc'.format(doy), 'w') 65 | 66 | ref_lats, ref_lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 67 | n_lat, n_lon = len(ref_lats), len(ref_lons) 68 | for name, dim in fh_ref.dimensions.items(): 69 | fh_out.createDimension(name, len(dim)) 70 | 71 | for v_name, varin in fh_ref.variables.items(): 72 | if v_name in ['lat', 'lon']: 73 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 74 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 75 | outVar[:] = varin[:] 76 | 77 | observed_values = fh_in.variables['observation'][:] 78 | projected_values = np.full((n_lat, n_lon), -9999.9) 79 | projected_indices = \ 80 | 
np.load('../../raw_data/nws_precip/nws_precip_allpoint_conversion/projected_indices_lai_500m.npy') 81 | projected_i = 0 82 | for i in range(n_lat): 83 | for j in range(n_lon): 84 | proj_i, proj_j = 881 - projected_indices[projected_i] // 1121, projected_indices[projected_i] % 1121 85 | if not observed_values.mask[proj_i, proj_j]: 86 | projected_values[i, j] = observed_values[proj_i, proj_j] 87 | projected_i += 1 88 | 89 | outVar = fh_out.createVariable('precip', 'f4', ('lat', 'lon')) 90 | outVar[:] = ma.masked_equal(projected_values, -9999.9) 91 | 92 | fh_in.close() 93 | fh_ref.close() 94 | fh_out.close() 95 | 96 | 97 | if __name__ == '__main__': 98 | # extract_shapefile() 99 | # reproject_lat_lon() 100 | for doy in generate_doy('20171227', '20171231', ''): 101 | reproject_nws_precip(doy) 102 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/semi_transformer/TileNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Based on tile2vec code from https://github.com/ermongroup/tile2vec 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class ResidualBlock(nn.Module): 14 | def __init__(self, in_planes, planes, stride=1): 15 | super(ResidualBlock, self).__init__() 16 | 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(planes)) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(self.conv1(x))) 30 | out = self.bn2(self.conv2(out)) 31 | out += self.shortcut(x) 32 | out = F.relu(out) 33 | return out 34 | 35 | 36 | class TileNet(nn.Module): 37 | def __init__(self, num_blocks, in_channels=4, z_dim=512): 38 | super(TileNet, self).__init__() 39 | self.in_channels = in_channels 40 | self.z_dim = z_dim 41 | self.in_planes = 64 42 | 43 | self.conv1 = nn.Conv2d(self.in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False) 44 | self.bn1 = nn.BatchNorm2d(64) 45 | self.layer1 = self._make_layer(64, num_blocks[0], stride=1) 46 | self.layer2 = self._make_layer(128, num_blocks[1], stride=2) 47 | self.layer3 = self._make_layer(256, num_blocks[2], stride=2) 48 | self.layer4 = self._make_layer(512, num_blocks[3], stride=2) 49 | self.layer5 = self._make_layer(self.z_dim, num_blocks[4], stride=2) 50 | 51 | def _make_layer(self, planes, num_blocks, stride, no_relu=False): 52 | strides = [stride] + [1]*(num_blocks-1) 53 | layers = [] 54 | for stride in strides: 55 | layers.append(ResidualBlock(self.in_planes, planes, stride=stride)) 56 | self.in_planes = planes 57 | return nn.Sequential(*layers) 58 | 59 | def encode(self, x): 60 | x = F.relu(self.bn1(self.conv1(x))) 61 | x = self.layer1(x) 62 | x = self.layer2(x) 63 | x = self.layer3(x) 64 | x = self.layer4(x) 65 | x = self.layer5(x) 66 | x = F.avg_pool2d(x, 4) 67 | z = x.view(x.size(0), -1) 68 | return z 69 | 70 | def forward(self, x): 71 | return 
self.encode(x) 72 | 73 | def loss(self, anchor, temporal_neighbor, spatial_neighbor, spatial_distant, margin, l2, ltn): 74 | """ 75 | Computes loss for each batch. 76 | """ 77 | z_a, z_tn, z_sn, z_d = (self.encode(anchor), self.encode(temporal_neighbor), self.encode(spatial_neighbor), 78 | self.encode(spatial_distant)) 79 | 80 | return triplet_loss(z_a, z_tn, z_sn, z_d, margin, l2, ltn) 81 | 82 | 83 | def triplet_loss(z_a, z_tn, z_sn, z_d, margin, l2, ltn): 84 | dim = z_a.shape[-1] 85 | 86 | l_n = torch.sqrt(((z_a - z_sn) ** 2).sum(dim=1)) 87 | l_d = - torch.sqrt(((z_a - z_d) ** 2).sum(dim=1)) 88 | sn_loss = F.relu(l_n + l_d + margin) 89 | tn_loss = torch.sqrt(((z_a - z_tn) ** 2).sum(dim=1)) 90 | 91 | # average by #samples in mini-batch 92 | l_n = torch.mean(l_n) 93 | l_d = torch.mean(l_d) 94 | l_nd = torch.mean(l_n + l_d) 95 | sn_loss = torch.mean(sn_loss) 96 | tn_loss = torch.mean(tn_loss) 97 | 98 | loss = (1 - ltn) * sn_loss + ltn * tn_loss 99 | 100 | norm_loss = 0 101 | if l2 != 0: 102 | z_a_norm = torch.sqrt((z_a ** 2).sum(dim=1)) 103 | z_sn_norm = torch.sqrt((z_sn ** 2).sum(dim=1)) 104 | z_d_norm = torch.sqrt((z_d ** 2).sum(dim=1)) 105 | z_tn_norm = torch.sqrt((z_tn ** 2).sum(dim=1)) 106 | norm_loss = torch.mean(z_a_norm + z_sn_norm + z_d_norm + z_tn_norm) / (dim ** 0.5) 107 | loss += l2 * norm_loss 108 | 109 | return loss, l_n, l_d, l_nd, sn_loss, tn_loss, norm_loss 110 | 111 | 112 | def make_tilenet(in_channels, z_dim=512): 113 | """ 114 | Returns a TileNet for unsupervised Tile2Vec with the specified number of 115 | input channels and feature dimension. 116 | """ 117 | num_blocks = [2, 2, 2, 2, 2] 118 | return TileNet(num_blocks, in_channels=in_channels, z_dim=z_dim) 119 | 120 | -------------------------------------------------------------------------------- /crop_yield_prediction/plot/plot_crop_yield_prediction_error.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
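# Module overview: this script renders county-level prediction errors on a US county
# SVG map. crop_yield_prediction_error_plot() colors each county path according to the
# error bin its value falls into (steps of 5, from below -15 to above +15) and writes
# the recolored SVG, while save_colorbar() saves the matching color bar as a JPEG.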
6 | 7 | from bs4 import BeautifulSoup 8 | from pathlib import Path 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | import pandas as pd 12 | from collections import defaultdict 13 | import numpy as np 14 | import seaborn as sns 15 | 16 | 17 | # colors = sns.color_palette("RdYlBu", 10).as_hex() 18 | colors = ["#b2182b", "#d6604d", "#f4a582", "#fddbc7", "#d1e5f0", "#92c5de", "#4393c3", "#2166ac"] 19 | 20 | 21 | def crop_yield_prediction_error_plot(data_dict, savepath): 22 | """ 23 | For the most part, reformatting of 24 | https://github.com/JiaxuanYou/crop_yield_prediction/blob/master/6%20result_analysis/yield_map.py 25 | """ 26 | # load the svg file 27 | svg = Path('data/counties.svg').open('r').read() 28 | # Load into Beautiful Soup 29 | soup = BeautifulSoup(svg, features="html.parser") 30 | # Find counties 31 | paths = soup.findAll('path') 32 | 33 | path_style = 'font-size:12px;fill-rule:nonzero;stroke:#FFFFFF;stroke-opacity:1;stroke-width:0.1' \ 34 | ';stroke-miterlimit:4;stroke-dasharray:none;stroke-linecap:butt;marker-start' \ 35 | ':none;stroke-linejoin:bevel;fill:' 36 | 37 | for p in paths: 38 | if p['id'] not in ["State_Lines", "separator"]: 39 | try: 40 | rate = data_dict[p['id']] 41 | except KeyError: 42 | continue 43 | if rate > 15: 44 | color_class = 7 45 | elif rate > 10: 46 | color_class = 6 47 | elif rate > 5: 48 | color_class = 5 49 | elif rate > 0: 50 | color_class = 4 51 | elif rate > -5: 52 | color_class = 3 53 | elif rate > -10: 54 | color_class = 2 55 | elif rate > -15: 56 | color_class = 1 57 | else: 58 | color_class = 0 59 | 60 | color = colors[color_class] 61 | p['style'] = path_style + color 62 | soup = soup.prettify() 63 | with savepath.open('w') as f: 64 | f.write(soup) 65 | 66 | 67 | def save_colorbar(savedir): 68 | """ 69 | For the most part, reformatting of 70 | https://github.com/JiaxuanYou/crop_yield_prediction/blob/master/6%20result_analysis/yield_map.py 71 | """ 72 | fig = plt.figure() 73 | ax = fig.add_axes([0.1, 0.1, 0.02, 0.8]) 74 | 75 | cmap = mpl.colors.ListedColormap(colors[1:-1]) 76 | 77 | cmap.set_over(colors[-1]) 78 | cmap.set_under(colors[0]) 79 | 80 | bounds = [-15, -10, -5, 0, 5, 10, 15] 81 | 82 | norm = mpl.colors.BoundaryNorm(bounds, cmap.N) 83 | cb = mpl.colorbar.ColorbarBase(ax, cmap=cmap, 84 | norm=norm, 85 | # to use 'extend', you must 86 | # specify two extra boundaries: 87 | boundaries=[-20] + bounds + [20], 88 | extend='both', 89 | ticks=bounds, # optional 90 | spacing='proportional', 91 | orientation='vertical') 92 | plt.savefig('{}/colorbar.jpg'.format(savedir), dpi=300, bbox_inches='tight') 93 | 94 | 95 | def process_yield_data(): 96 | important_columns = ['Year', 'State ANSI', 'County ANSI', 'Value'] 97 | yield_data = pd.read_csv('../../processed_data/crop_yield/yield_data.csv').dropna( 98 | subset=important_columns, how='any')[['Year', 'State ANSI', 'County ANSI', 'Value']] 99 | yield_data.columns = ['Year', 'State', 'County', 'Value'] 100 | yield_per_year_dic = defaultdict(dict) 101 | 102 | for yd in yield_data.itertuples(): 103 | year, state, county, value = yd.Year, yd.State, int(yd.County), yd.Value 104 | state = str(state).zfill(2) 105 | county = str(county).zfill(3) 106 | 107 | yield_per_year_dic[year][state+county] = value 108 | 109 | return yield_per_year_dic 110 | 111 | 112 | if __name__ == '__main__': 113 | yield_data = process_yield_data() 114 | for year in range(2003, 2017): 115 | crop_yield_prediction_error_plot(yield_data[year], 
Path('../../processed_data/crop_yield/plots/{}_yield.html'.format(year))) 116 | values = np.array(list(yield_data[year].values())) 117 | print(year, np.percentile(values, 0), np.percentile(values, 25), np.percentile(values, 50), 118 | np.percentile(values, 75), np.percentile(values, 100)) 119 | save_colorbar('../../processed_data/crop_yield/plots') 120 | -------------------------------------------------------------------------------- /data_preprocessing/plot/counties_plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from bs4 import BeautifulSoup 8 | from pathlib import Path 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | import pandas as pd 12 | from collections import defaultdict 13 | import numpy as np 14 | import seaborn as sns 15 | 16 | 17 | # colors = sns.color_palette("RdYlBu", 10).as_hex() 18 | colors = ['#cdeaf3', '#9bcce2', '#fff1aa', '#fece7f', '#fa9b58', '#ee613e', '#d22b27'] 19 | 20 | 21 | def counties_plot(data_dict, savepath, quantiles): 22 | """ 23 | For the most part, reformatting of 24 | https://github.com/JiaxuanYou/crop_yield_prediction/blob/master/6%20result_analysis/yield_map.py 25 | """ 26 | # load the svg file 27 | svg = Path('../../processed_data/counties/counties.svg').open('r').read() 28 | # Load into Beautiful Soup 29 | soup = BeautifulSoup(svg, features="html.parser") 30 | # Find counties 31 | paths = soup.findAll('path') 32 | 33 | path_style = 'font-size:12px;fill-rule:nonzero;stroke:#FFFFFF;stroke-opacity:1;stroke-width:0.1' \ 34 | ';stroke-miterlimit:4;stroke-dasharray:none;stroke-linecap:butt;marker-start' \ 35 | ':none;stroke-linejoin:bevel;fill:' 36 | 37 | for p in paths: 38 | if p['id'] not in ["State_Lines", "separator"]: 39 | try: 40 | rate = data_dict[p['id']] 41 | except KeyError: 42 | continue 43 | if rate > quantiles[0.95]: 44 | color_class = 6 45 | elif rate > quantiles[0.8]: 46 | color_class = 5 47 | elif rate > quantiles[0.6]: 48 | color_class = 4 49 | elif rate > quantiles[0.4]: 50 | color_class = 3 51 | elif rate > quantiles[0.2]: 52 | color_class = 2 53 | elif rate > quantiles[0.05]: 54 | color_class = 1 55 | else: 56 | color_class = 0 57 | 58 | color = colors[color_class] 59 | p['style'] = path_style + color 60 | soup = soup.prettify() 61 | with savepath.open('w') as f: 62 | f.write(soup) 63 | 64 | 65 | def save_colorbar(savedir, quantiles): 66 | """ 67 | For the most part, reformatting of 68 | https://github.com/JiaxuanYou/crop_yield_prediction/blob/master/6%20result_analysis/yield_map.py 69 | """ 70 | fig = plt.figure() 71 | ax = fig.add_axes([0.1, 0.1, 0.02, 0.8]) 72 | 73 | cmap = mpl.colors.ListedColormap(colors[1:-1]) 74 | 75 | cmap.set_over(colors[-1]) 76 | cmap.set_under(colors[0]) 77 | 78 | bounds = [quantiles[x] for x in [0.05, 0.2, 0.4, 0.6, 0.8, 0.95]] 79 | 80 | norm = mpl.colors.BoundaryNorm(bounds, cmap.N) 81 | cb = mpl.colorbar.ColorbarBase(ax, cmap=cmap, 82 | norm=norm, 83 | # to use 'extend', you must 84 | # specify two extra boundaries: 85 | boundaries=[quantiles[0.0]] + bounds + [quantiles[1.0]], 86 | extend='both', 87 | ticks=bounds, # optional 88 | spacing='proportional', 89 | orientation='vertical') 90 | plt.savefig('{}/colorbar.jpg'.format(savedir), dpi=300, bbox_inches='tight') 91 | 92 | 93 | def process_yield_data(): 94 | 
important_columns = ['Year', 'State ANSI', 'County ANSI', 'Value'] 95 | yield_data = pd.read_csv('../../processed_data/crop_yield/yield_data.csv').dropna( 96 | subset=important_columns, how='any')[['Year', 'State ANSI', 'County ANSI', 'Value']] 97 | yield_data.columns = ['Year', 'State', 'County', 'Value'] 98 | yield_per_year_dic = defaultdict(dict) 99 | 100 | for yd in yield_data.itertuples(): 101 | year, state, county, value = yd.Year, yd.State, int(yd.County), yd.Value 102 | state = str(state).zfill(2) 103 | county = str(county).zfill(3) 104 | 105 | yield_per_year_dic[year][state+county] = value 106 | 107 | return yield_per_year_dic 108 | 109 | 110 | if __name__ == '__main__': 111 | yield_data = process_yield_data() 112 | for year in range(2003, 2017): 113 | counties_plot(yield_data[year], Path('../../processed_data/crop_yield/plots/{}_yield.html'.format(year))) 114 | values = np.array(list(yield_data[year].values())) 115 | print(year, np.percentile(values, 0), np.percentile(values, 25), np.percentile(values, 50), 116 | np.percentile(values, 75), np.percentile(values, 100)) 117 | save_colorbar('../../processed_data/crop_yield/plots') 118 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/lst.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from netCDF4 import Dataset 9 | import datetime 10 | import calendar 11 | from collections import defaultdict 12 | import numpy.ma as ma 13 | import os 14 | 15 | import sys 16 | sys.path.append("..") 17 | 18 | 19 | def extract_lst(nc_file): 20 | fh_in = Dataset('../../raw_data/lst/' + nc_file, 'r') 21 | 22 | for index, n_days in enumerate(fh_in.variables['time'][:]): 23 | date = (datetime.datetime(2000, 1, 1, 0, 0) + datetime.timedelta(int(n_days))).strftime('%Y%m%d') 24 | print(date) 25 | fh_out = Dataset('../../raw_data/lst/1km/{}.nc'.format(date), 'w') 26 | 27 | for name, dim in fh_in.dimensions.items(): 28 | if name != 'time': 29 | fh_out.createDimension(name, len(dim) if not dim.isunlimited() else None) 30 | 31 | ignore_features = ['time', 'crs', 'Clear_day_cov', 'Clear_night_cov', 'Day_view_angl', 'Day_view_time', 32 | 'Night_view_angl', 'Night_view_time', 'Emis_31', 'Emis_32', "QC_Day", "QC_Night"] 33 | for v_name, varin in fh_in.variables.items(): 34 | if v_name not in ignore_features: 35 | dimensions = varin.dimensions if v_name in ['lat', 'lon'] else ('lat', 'lon') 36 | outVar = fh_out.createVariable(v_name, varin.datatype, dimensions) 37 | if v_name == "lat": 38 | outVar.setncatts({"units": "degree_north"}) 39 | outVar[:] = varin[:] 40 | elif v_name == "lon": 41 | outVar.setncatts({"units": "degree_east"}) 42 | outVar[:] = varin[:] 43 | else: 44 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 45 | outVar[:] = varin[index, :, :] 46 | fh_out.close() 47 | fh_in.close() 48 | 49 | 50 | def generate_monthly_average(start_year, end_year, start_month, end_month): 51 | in_dir = '../../raw_data/lst/1km' 52 | out_dir = '../../processed_data/lst/monthly_1km' 53 | os.makedirs(out_dir, exist_ok=True) 54 | for year in range(start_year, end_year): 55 | for month in range(start_month, end_month): 56 | fh_out = Dataset('{}/{}{}.nc'.format(out_dir, year, '{0:02}'.format(month)), 'w') 57 | print(year, month) 58 | 59 
| var_lis = defaultdict(list) 60 | first = True 61 | num_days = calendar.monthrange(year, month)[1] 62 | days = map(lambda x: x.strftime('%Y%m%d'), [datetime.date(year, month, day) for day in range(1, num_days+1)]) 63 | for day in days: 64 | if '{}.nc'.format(day) not in os.listdir(in_dir): 65 | print('Missing {}'.format(day)) 66 | continue 67 | fh_in = Dataset('{}/{}.nc'.format(in_dir, day), 'r') 68 | 69 | len_lat, len_lon = len(fh_in.variables['lat'][:]), len(fh_in.variables['lon'][:]) 70 | assert len_lat == 3578 or len_lat == 3579 71 | assert len_lon == 7797 72 | 73 | for v_name, varin in fh_in.variables.items(): 74 | if v_name in ['LST_Day_1km', 'LST_Night_1km']: 75 | if len_lat == 3578: 76 | var_lis[v_name[:-4].lower()].append(fh_in.variables[v_name][:]) 77 | else: 78 | var_lis[v_name[:-4].lower()].append(fh_in.variables[v_name][:-1, :]) 79 | 80 | if first: 81 | for name, dim in fh_in.dimensions.items(): 82 | if name == 'lat': 83 | fh_out.createDimension(name, 3578) 84 | else: 85 | fh_out.createDimension(name, len(dim)) 86 | for v_name, varin in fh_in.variables.items(): 87 | if v_name in ['LST_Day_1km', 'LST_Night_1km'] or v_name in ["lat", "lon"]: 88 | new_name = v_name[:-4].lower() if v_name in ['LST_Day_1km', 'LST_Night_1km'] else v_name 89 | outVar = fh_out.createVariable(new_name, varin.datatype, varin.dimensions) 90 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 91 | if v_name == 'lat': 92 | outVar[:] = varin[:3578] 93 | elif v_name == 'lon': 94 | outVar[:] = varin[:] 95 | 96 | first = False 97 | 98 | fh_in.close() 99 | 100 | for var in fh_out.variables: 101 | if var != "lat" and var != "lon": 102 | print(ma.array(var_lis[var]).shape) 103 | fh_out.variables[var][:] = ma.array(var_lis[var]).mean(axis=0) 104 | 105 | fh_out.close() 106 | 107 | 108 | if __name__ == '__main__': 109 | extract_lst('MOD11A1_20140201_20140930.nc') 110 | generate_monthly_average(2014, 2015, 2, 10) 111 | -------------------------------------------------------------------------------- /crop_yield_prediction/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
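# Module overview: shared training/evaluation helpers. get_statistics() reports Pearson
# correlation, R2 and RMSE for a set of predictions; the get_latest_model* helpers
# locate the newest "log*" run directory and its most recent .tar checkpoint;
# plot_predict() and plot_predict_error() draw county-level maps of predictions and
# prediction errors; the output_to_csv_* functions dump per-year result dictionaries
# to CSV files.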
6 | 7 | import os 8 | from math import sqrt 9 | from sklearn.metrics import r2_score, mean_squared_error 10 | from scipy.stats.stats import pearsonr 11 | import numpy as np 12 | import pandas as pd 13 | 14 | from crop_yield_prediction.plot import crop_yield_plot 15 | from crop_yield_prediction.plot import crop_yield_prediction_error_plot 16 | 17 | 18 | def get_statistics(y, prediction, valid): 19 | corr = tuple(map(lambda x: np.around(x, 3), pearsonr(y, prediction))) 20 | r2 = np.around(r2_score(y, prediction), 3) 21 | rmse = np.around(sqrt(mean_squared_error(y, prediction)), 3) 22 | 23 | if valid: 24 | print('Validation - Pearson correlation: {}, R2: {}, RMSE: {}'.format(corr, r2, rmse)) 25 | else: 26 | print('Test - Pearson correlation: {}, R2: {}, RMSE: {}'.format(corr, r2, rmse)) 27 | 28 | return corr, r2, rmse 29 | 30 | 31 | def get_latest_model_dir(model_dir): 32 | latest_folder = sorted([x for x in os.listdir(model_dir) if x.startswith('log')], key=lambda x: int(x[3:]))[-1] 33 | 34 | return os.path.join(model_dir, latest_folder) 35 | 36 | 37 | def get_latest_model(model_dir, cv=None): 38 | log_folders = sorted([x for x in os.listdir(model_dir) if x.startswith('log')], key=lambda x: int(x[3:]))[-1] 39 | check_dir = os.path.join(model_dir, log_folders) if cv is None else os.path.join(model_dir, log_folders, cv) 40 | latest_model = sorted([x for x in os.listdir(check_dir) if x.endswith('.tar')], 41 | key=lambda x: int(x.split('.')[0][13:]))[-1] 42 | return os.path.join(check_dir, latest_model) 43 | 44 | 45 | def get_latest_models_cvs(model_dir, cvs): 46 | log_folders = sorted([x for x in os.listdir(model_dir) if x.startswith('log')], key=lambda x: int(x[3:]))[-1] 47 | latest_models = [] 48 | for cv in cvs: 49 | check_dir = os.path.join(model_dir, log_folders, cv) 50 | latest_model = sorted([x for x in os.listdir(check_dir) if x.endswith('.tar')], 51 | key=lambda x: int(x.split('.')[0][13:]))[-1] 52 | latest_models.append(os.path.join(check_dir, latest_model)) 53 | return latest_models 54 | 55 | 56 | def plot_predict(prediction, dim, savepath): 57 | pred_dict = {} 58 | for idx, pred in zip(dim, prediction): 59 | state, county = idx 60 | state = str(int(state)).zfill(2) 61 | county = str(int(county)).zfill(3) 62 | 63 | pred_dict[state + county] = pred 64 | 65 | crop_yield_plot(pred_dict, savepath) 66 | 67 | 68 | def plot_predict_error(prediction, real_values, dim, savepath): 69 | test_pred_error = prediction - real_values 70 | pred_dict = {} 71 | for idx, err in zip(dim, test_pred_error): 72 | state, county = idx 73 | state = str(int(state)).zfill(2) 74 | county = str(int(county)).zfill(3) 75 | 76 | pred_dict[state + county] = err 77 | 78 | crop_yield_prediction_error_plot(pred_dict, savepath) 79 | 80 | 81 | def output_to_csv_no_spatial(results_dic, out_dir): 82 | if not os.path.exists(out_dir): 83 | os.makedirs(out_dir) 84 | 85 | years = sorted(results_dic.keys()) 86 | model_types = sorted(results_dic[years[0]].keys()) 87 | 88 | for dt in ['valid', 'test']: 89 | data = [] 90 | for year in years: 91 | year_data, columns = [], [] 92 | for st in ['corr', 'r2', 'rmse']: 93 | for mt in model_types: 94 | year_data.append(results_dic[year][mt]['{}_{}'.format(dt, st)]) 95 | columns.append('{}_{}'.format(mt, '{}_{}'.format(dt, st))) 96 | data.append(year_data) 97 | 98 | data = pd.DataFrame(data, columns=columns, index=years) 99 | data.to_csv('{}/{}.csv'.format(out_dir, dt)) 100 | 101 | 102 | def output_to_csv_complex(results_dic, out_dir): 103 | years = sorted(results_dic.keys()) 104 | 
model_types = sorted(results_dic[years[0]].keys()) 105 | 106 | for dt in ['train', 'test']: 107 | data = [] 108 | for year in years: 109 | year_data, columns = [], [] 110 | for st in ['corr', 'r2', 'rmse']: 111 | for mt in model_types: 112 | year_data.append(results_dic[year][mt]['{}_{}'.format(dt, st)]) 113 | columns.append('{}_{}'.format(mt, '{}_{}'.format(dt, st))) 114 | data.append(year_data) 115 | 116 | data = pd.DataFrame(data, columns=columns, index=years) 117 | data.to_csv('{}/{}.csv'.format(out_dir, dt)) 118 | 119 | 120 | def output_to_csv_simple(results_dic, out_dir): 121 | years = sorted(results_dic.keys()) 122 | 123 | data = [] 124 | for year in years: 125 | year_data, columns = [], [] 126 | for st in ['corr', 'r2', 'rmse']: 127 | year_data.append(results_dic[year]['test_'+st]) 128 | columns.append(st) 129 | data.append(year_data) 130 | 131 | data = pd.DataFrame(data, columns=columns, index=years) 132 | data.to_csv('{}/test.csv'.format(out_dir)) 133 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/c3d/conv3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Based on code from https://github.com/jfzhang95/pytorch-video-recognition/blob/master/network/C3D_model.py 7 | # Architecture is taken from https://esc.fnwi.uva.nl/thesis/centraal/files/f1570224447.pdf 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class C3D(nn.Module): 14 | """ 15 | The C3D network. 16 | """ 17 | 18 | def __init__(self, in_channels, n_tsteps): 19 | super(C3D, self).__init__() 20 | 21 | self.n_tsteps = n_tsteps 22 | 23 | # input (9, 7, 50, 50), output (9, 7, 50, 50) 24 | self.dimr1 = nn.Conv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)) 25 | self.bn_dimr1 = nn.BatchNorm3d(in_channels, eps=1e-6, momentum=0.1) 26 | # output (3, 7, 50, 50) 27 | self.dimr2 = nn.Conv3d(in_channels, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0)) 28 | self.bn_dimr2 = nn.BatchNorm3d(3, eps=1e-6, momentum=0.1) 29 | 30 | # output (64, 7, 50, 50) 31 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 32 | self.bn1 = nn.BatchNorm3d(64, eps=1e-6, momentum=0.1) 33 | # output (64, 7, 25, 25) 34 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 35 | 36 | # output (128, 7, 25, 25) 37 | self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 38 | self.bn2 = nn.BatchNorm3d(128, eps=1e-6, momentum=0.1) 39 | # output (128, 7, 12, 12) 40 | self.pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 41 | 42 | # output (256, 7, 12, 12) 43 | self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 44 | self.bn3a = nn.BatchNorm3d(256, eps=1e-6, momentum=0.1) 45 | # output (256, 7, 12, 12) 46 | self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 47 | self.bn3b = nn.BatchNorm3d(256, eps=1e-6, momentum=0.1) 48 | # output (256, 3, 6, 6) 49 | self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 50 | 51 | # output (512, 3, 6, 6) 52 | self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 53 | self.bn4a = nn.BatchNorm3d(512, eps=1e-6, momentum=0.1) 54 | # output (512, 3, 6, 6) 55 | 
self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 56 | self.bn4b = nn.BatchNorm3d(512, eps=1e-6, momentum=0.1) 57 | # output (512, 1, 3, 3) 58 | self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 59 | self.pool4_keept = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 60 | 61 | self.fc5 = nn.Linear(4608, 1024) 62 | self.fc6 = nn.Linear(1024, 1) 63 | 64 | self.dropout = nn.Dropout(p=0.5) 65 | 66 | self.relu = nn.ReLU() 67 | 68 | self.__init_weight() 69 | 70 | def forward(self, x): 71 | 72 | x = self.relu(self.bn_dimr1(self.dimr1(x))) 73 | x = self.relu(self.bn_dimr2(self.dimr2(x))) 74 | 75 | x = self.relu(self.bn1(self.conv1(x))) 76 | x = self.pool1(x) 77 | 78 | x = self.relu(self.bn2(self.conv2(x))) 79 | x = self.pool2(x) 80 | 81 | x = self.relu(self.bn3a(self.conv3a(x))) 82 | x = self.relu(self.bn3b(self.conv3b(x))) 83 | x = self.pool3(x) 84 | 85 | x = self.relu(self.bn4a(self.conv4a(x))) 86 | x = self.relu(self.bn4b(self.conv4b(x))) 87 | if self.n_tsteps > 3: 88 | x = self.pool4(x) 89 | else: 90 | x = self.pool4_keept(x) 91 | 92 | # output (512, 1, 3, 3) 93 | x = x.view(-1, 4608) 94 | x = self.relu(self.fc5(x)) 95 | x = self.dropout(x) 96 | 97 | pred = torch.squeeze(self.fc6(x)) 98 | 99 | return pred 100 | 101 | def __init_weight(self): 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv3d): 104 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 105 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 106 | torch.nn.init.kaiming_normal_(m.weight) 107 | elif isinstance(m, nn.BatchNorm3d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | # def get_1x_lr_params(model): 112 | # """ 113 | # This generator returns all the parameters for conv and two fc layers of the net. 114 | # """ 115 | # b = [model.conv1, model.conv2, model.conv3a, model.conv3b, model.conv4a, model.conv4b, 116 | # model.conv5a, model.conv5b, model.fc6, model.fc7] 117 | # for i in range(len(b)): 118 | # for k in b[i].parameters(): 119 | # if k.requires_grad: 120 | # yield k 121 | # 122 | # def get_10x_lr_params(model): 123 | # """ 124 | # This generator returns all the parameters for the last fc layer of the net. 125 | # """ 126 | # b = [model.fc8] 127 | # for j in range(len(b)): 128 | # for k in b[j].parameters(): 129 | # if k.requires_grad: 130 | # yield k 131 | 132 | # if __name__ == "__main__": 133 | # inputs = torch.rand(1, 3, 16, 112, 112) 134 | # net = C3D(num_classes=101, pretrained=True) 135 | # 136 | # outputs = net.forward(inputs) 137 | # print(outputs.size()) 138 | -------------------------------------------------------------------------------- /crop_yield_prediction/models/deep_gaussian_process/rnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Adapt code from https://github.com/gabrieltseng/pycrop-yield-prediction 7 | 8 | from torch import nn 9 | import torch 10 | 11 | import math 12 | from pathlib import Path 13 | 14 | from .base import ModelBase 15 | 16 | 17 | class RNNModel(ModelBase): 18 | """ 19 | A PyTorch replica of the RNN structured model from the original paper. Note that 20 | this class assumes feature_engineering was run with channels_first=True 21 | 22 | Parameters 23 | ---------- 24 | in_channels: int, default=9 25 | Number of channels in the input data. 
Default taken from the number of bands in the 26 | MOD09A1 + the number of bands in the MYD11A2 datasets 27 | num_bins: int, default=32 28 | Number of bins in the histogram 29 | hidden_size: int, default=128 30 | The size of the hidden state. Default taken from the original repository 31 | rnn_dropout: float, default=0.75 32 | Default taken from the original paper. Note that this dropout is applied to the 33 | hidden state after each timestep, not after each layer (since there is only one layer) 34 | dense_features: list, or None, default=None. 35 | output feature size of the Linear layers. If None, default values will be taken from the paper. 36 | The length of the list defines how many linear layers are used. 37 | savedir: pathlib Path, default=Path('data/models') 38 | The directory into which the models should be saved. 39 | device: torch.device 40 | Device to run model on. By default, checks for a GPU. If none exists, uses 41 | the CPU 42 | """ 43 | 44 | def __init__(self, in_channels=9, num_bins=32, hidden_size=128, rnn_dropout=0.75, 45 | dense_features=None, savedir=Path('data/models'), use_gp=True, 46 | sigma=1, r_loc=0.5, r_year=1.5, sigma_e=0.01, sigma_b=0.01, 47 | device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')): 48 | 49 | model = RNNet(in_channels=in_channels, num_bins=num_bins, hidden_size=hidden_size, 50 | num_rnn_layers=1, rnn_dropout=rnn_dropout, 51 | dense_features=dense_features) 52 | 53 | if dense_features is None: 54 | num_dense_layers = 2 55 | else: 56 | num_dense_layers = len(dense_features) 57 | model_weight = f'dense_layers.{num_dense_layers - 1}.weight' 58 | model_bias = f'dense_layers.{num_dense_layers - 1}.bias' 59 | 60 | super().__init__(model, model_weight, model_bias, 'rnn', savedir, use_gp, sigma, r_loc, r_year, 61 | sigma_e, sigma_b, device) 62 | 63 | def reinitialize_model(self, time=None): 64 | self.model.initialize_weights() 65 | 66 | 67 | class RNNet(nn.Module): 68 | """ 69 | A crop yield conv net. 70 | 71 | For a description of the parameters, see the RNNModel class. 72 | """ 73 | def __init__(self, in_channels=9, num_bins=32, hidden_size=128, num_rnn_layers=1, 74 | rnn_dropout=0.25, dense_features=None): 75 | super().__init__() 76 | 77 | if dense_features is None: 78 | dense_features = [256, 1] 79 | dense_features.insert(0, hidden_size) 80 | 81 | self.dropout = nn.Dropout(rnn_dropout) 82 | self.rnn = nn.LSTM(input_size=in_channels * num_bins, 83 | hidden_size=hidden_size, 84 | num_layers=num_rnn_layers, 85 | batch_first=True) 86 | self.hidden_size = hidden_size 87 | 88 | self.dense_layers = nn.ModuleList([ 89 | nn.Linear(in_features=dense_features[i-1], 90 | out_features=dense_features[i]) 91 | for i in range(1, len(dense_features)) 92 | ]) 93 | 94 | self.initialize_weights() 95 | 96 | def initialize_weights(self): 97 | 98 | sqrt_k = math.sqrt(1 / self.hidden_size) 99 | for parameters in self.rnn.all_weights: 100 | for pam in parameters: 101 | nn.init.uniform_(pam.data, -sqrt_k, sqrt_k) 102 | 103 | for dense_layer in self.dense_layers: 104 | nn.init.kaiming_uniform_(dense_layer.weight.data) 105 | nn.init.constant_(dense_layer.bias.data, 0) 106 | 107 | def forward(self, x, return_last_dense=False): 108 | """ 109 | If return_last_dense is true, the feature vector generated by the second to last 110 | dense layer will also be returned. This is then used to train a Gaussian Process model. 
111 | """ 112 | # the model expects feature_engineer to have been run with channels_first=True, which means 113 | # the input is [batch, bands, times, bins]. 114 | # Reshape to [batch, times, bands * bins] 115 | x = x.permute(0, 2, 1, 3).contiguous() 116 | x = x.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]) 117 | 118 | sequence_length = x.shape[1] 119 | 120 | hidden_state = torch.zeros(1, x.shape[0], self.hidden_size) 121 | cell_state = torch.zeros(1, x.shape[0], self.hidden_size) 122 | 123 | if x.is_cuda: 124 | hidden_state = hidden_state.cuda() 125 | cell_state = cell_state.cuda() 126 | 127 | for i in range(sequence_length): 128 | # The reason the RNN is unrolled here is to apply dropout to each timestep; 129 | # The rnn_dropout argument only applies it after each layer. This better mirrors 130 | # the behaviour of the Dropout Wrapper used in the original repository 131 | # https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/DropoutWrapper 132 | input_x = x[:, i, :].unsqueeze(1) 133 | _, (hidden_state, cell_state) = self.rnn(input_x, 134 | (hidden_state, cell_state)) 135 | hidden_state = self.dropout(hidden_state) 136 | 137 | x = hidden_state.squeeze(0) 138 | for layer_number, dense_layer in enumerate(self.dense_layers): 139 | x = dense_layer(x) 140 | if return_last_dense and (layer_number == len(self.dense_layers) - 2): 141 | output = x 142 | if return_last_dense: 143 | return x, output 144 | return x 145 | -------------------------------------------------------------------------------- /data_preprocessing/rescaling/cdl_upscale.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from data_preprocessing.utils import get_lat_lon_bins 8 | from data_preprocessing import cdl_values_to_crops, crops_to_cdl_values 9 | from data_preprocessing.utils import timeit 10 | 11 | import os 12 | import numpy as np 13 | import numpy.ma as ma 14 | from netCDF4 import Dataset 15 | 16 | 17 | # water: Water Wetlands Aquaculture Open Water Perennial Ice/Snow 18 | # urban: Developed Developed/Open Space Developed/Low Intensity Developed/Med Intensity Developed/High Intensity 19 | # native: Clover/Wildflowers Forest Shrubland1 Deciduous Forest Evergreen Forest Mixed Forest Shrubland2 Woody Wetlands 20 | # Herbaceous Wetlands 21 | # idle/fallow: Sod/Grass Seed Fallow/Idle Cropland 22 | # hay/pasture: Other Hay/Non Alfalfa Switchgrass Grassland/Pasture 23 | # barren/missing: Barren1 Clouds/No Data Nonag/Undefined Barren2 24 | ignored_labels = {"water": [83, 87, 92, 111, 112], 25 | "urban": [82, 121, 122, 123, 124], 26 | "native": [58, 63, 64, 141, 142, 143, 152, 190, 195], 27 | "idle/fallow": [59, 61], 28 | "hay/pasture": [37, 60, 176], 29 | "barren/missing": [65, 81, 88, 131]} 30 | 31 | 32 | def cdl_upscale(in_dir, in_file, out_dir, out_file, reso='40km', ignore=False): 33 | if not os.path.exists(out_dir): 34 | os.makedirs(out_dir) 35 | 36 | ignored_lis = [x for lis in ignored_labels.values() for x in lis] 37 | kept_lis = [x for x in cdl_values_to_crops.keys() if x not in ignored_lis] 38 | 39 | # increasing 40 | lats = np.load('../../processed_data/prism/latlon/lat_{}.npy'.format(reso)) 41 | lons = np.load('../../processed_data/prism/latlon/lon_{}.npy'.format(reso)) 42 | _, _, lat_bins, lon_bins = get_lat_lon_bins(lats, lons) 43 | 44 | fh_in = Dataset(os.path.join(in_dir, in_file), 'r') 45 | fh_out = Dataset(os.path.join(out_dir, out_file), 'w') 46 | 47 | dic_var = {} 48 | for var in ['lat', 'lon']: 49 | dic_var[var] = fh_in.variables[var] 50 | # increasing 51 | dic_var['lat_value'] = dic_var['lat'][:] 52 | dic_var['lon_value'] = dic_var['lon'][:] 53 | 54 | fh_out.createDimension('lat', len(lats)) 55 | fh_out.createDimension('lon', len(lons)) 56 | 57 | for var in ['lat', 'lon']: 58 | outVar = fh_out.createVariable(var, 'f4', (var,)) 59 | outVar.setncatts({k: dic_var[var].getncattr(k) for k in dic_var[var].ncattrs()}) 60 | outVar[:] = lats if var == "lat" else lons 61 | 62 | cdl_value = fh_in.variables['Band1'][:] 63 | cdl_resampled_dic = {} 64 | for v in cdl_values_to_crops.values(): 65 | if (ignore and crops_to_cdl_values[v] in kept_lis) or not ignore: 66 | cdl_resampled_dic[v] = np.full((len(lats), len(lons)), -1.0) 67 | 68 | for s in ["1", "2", "3"]: 69 | cdl_resampled_dic["cdl_" + s] = np.full((len(lats), len(lons)), -1.0) 70 | cdl_resampled_dic["cdl_fraction_" + s] = np.full((len(lats), len(lons)), -1.0) 71 | 72 | for id_lats in range(len(lats)): 73 | for id_lons in range(len(lons)): 74 | lats_index = np.searchsorted(dic_var['lat_value'], 75 | [lat_bins[id_lats], lat_bins[id_lats + 1]]) 76 | lons_index = np.searchsorted(dic_var['lon_value'], 77 | [lon_bins[id_lons], lon_bins[id_lons + 1]]) 78 | 79 | if lats_index[0] != lats_index[1] and lons_index[0] != lons_index[1]: 80 | selected = cdl_value[np.array(range(lats_index[0], lats_index[1]))[:, None], 81 | np.array(range(lons_index[0], lons_index[1]))] 82 | # selected_size = selected.shape[0] * selected.shape[1] 83 | selected_compressed = selected.compressed() 84 | selected_size = len(selected_compressed) 85 | cdl_id, cdl_count = np.unique(selected_compressed, return_counts=True) 86 | 87 | # filter ignored_label after 
selected_size has been calculated 88 | if ignore: 89 | new_cdl_id, new_cdl_count = [], [] 90 | for i, c in zip(cdl_id, cdl_count): 91 | if i in kept_lis: 92 | new_cdl_id.append(i) 93 | new_cdl_count.append(c) 94 | cdl_id, cdl_count = np.asarray(new_cdl_id), np.asarray(new_cdl_count) 95 | 96 | for i, c in zip(cdl_id, cdl_count): 97 | cdl_resampled_dic[cdl_values_to_crops[i]][id_lats, id_lons] = c / selected_size 98 | cdl_count_sort_ind = np.argsort(-cdl_count) 99 | for i in range(3): 100 | if len(cdl_id) > i: 101 | cdl_resampled_dic["cdl_" + str(i+1)][id_lats, id_lons] = \ 102 | cdl_id[cdl_count_sort_ind[i]] 103 | cdl_resampled_dic["cdl_fraction_" + str(i+1)][id_lats, id_lons] = \ 104 | cdl_count[cdl_count_sort_ind[i]] / selected_size 105 | else: 106 | cdl_resampled_dic["cdl_" + str(i + 1)][id_lats, id_lons] = -1 107 | cdl_resampled_dic["cdl_fraction_" + str(i + 1)][id_lats, id_lons] = -1 108 | 109 | for v in cdl_values_to_crops.values(): 110 | if (ignore and crops_to_cdl_values[v] in kept_lis) or not ignore: 111 | outVar = fh_out.createVariable("cdl_" + v.lower().replace(' ', '_').replace(' & ', '_').replace('/', '_'), 112 | 'f4', ('lat', 'lon',)) 113 | outVar[:] = cdl_resampled_dic[v][:] 114 | outVar[:] = ma.masked_equal(outVar, -1.0) 115 | for s in ["1", "2", "3"]: 116 | for t in ["cdl_", "cdl_fraction_"]: 117 | outVar = fh_out.createVariable(t + s, 'f4', ('lat', 'lon',)) 118 | outVar[:] = cdl_resampled_dic[t + s][:] 119 | outVar[:] = ma.masked_equal(outVar, -1.0) 120 | 121 | fh_in.close() 122 | fh_out.close() 123 | 124 | 125 | def upscaled_cdl_postprocess(in_file, out_dir, out_file, threshold=0.0): 126 | fh_in = Dataset(in_file, 'r') 127 | 128 | cdl_fraction_1 = fh_in.variables['cdl_fraction_1'][:] 129 | kept_cdls = ma.masked_where(cdl_fraction_1 < threshold, fh_in.variables['cdl_1'][:]) 130 | cdl_id, cdl_count = np.unique(kept_cdls.compressed(), return_counts=True) 131 | for i, c in zip(cdl_id, cdl_count): 132 | print(cdl_values_to_crops[int(i)], c) 133 | 134 | 135 | if __name__ == "__main__": 136 | timeit() 137 | # cdl_upscale('../../raw_data/cdl/2018_30m_cdls', '2018_30m_cdls.nc', 138 | # '../../processed_data/cdl/40km', '2018_40km_cdls_crop_only.nc', reso='40km', ignore=True) 139 | upscaled_cdl_postprocess('../../processed_data/cdl/40km/2018_40km_cdls_crop_only.nc', 140 | '', '') 141 | # print([x for lis in ignored_labels.values() for x in lis]) 142 | -------------------------------------------------------------------------------- /crop_yield_prediction/plot/plot_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
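# plot_loss() below parses the training prediction logs by whitespace-token position (ws[4], ws[7], ...),
# so it assumes the exact line layout written by the training scripts' log_file output. Parsed values
# are collected per prediction year and per experiment index; a minimal sketch of that
# year -> experiment -> series bookkeeping (hypothetical loss values, helper name illustrative only):

def _nested_series_sketch():
    from collections import defaultdict
    series = defaultdict(lambda: defaultdict(list))   # same pattern as the *_dic containers below
    series[2014][0].append(1.25)                      # year 2014, experiment 0, first epoch's loss
    series[2014][0].append(1.10)                      # second epoch's loss
    return {year: dict(exps) for year, exps in series.items()}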
6 | 7 | import os 8 | import matplotlib.pyplot as plt 9 | from collections import defaultdict 10 | 11 | 12 | def plot_loss(params): 13 | out_dir = '../../results/spatial_temporal/plots/{}'.format(params[:-4]) 14 | os.makedirs(out_dir, exist_ok=True) 15 | prediction_log = '../../results/spatial_temporal/prediction_logs/{}'.format(params) 16 | train_epochs_dic = defaultdict(lambda: defaultdict(list)) 17 | train_loss_dic, train_super_loss_dic, train_unsuper_loss_dic = (defaultdict(lambda: defaultdict(list)) for _ in range(3)) 18 | valid_loss_dic, valid_super_loss_dic, valid_unsuper_loss_dic = (defaultdict(lambda: defaultdict(list)) for _ in range(3)) 19 | valid_l_n_loss_dic, valid_l_d_loss_dic, valid_l_nd_loss_dic, valid_sn_loss_dic, valid_tn_loss_dic, valid_norm_loss_dic = \ 20 | (defaultdict(lambda: defaultdict(list)) for _ in range(6)) 21 | valid_rmse_dic, valid_r2_dic, valid_corr_dic = (defaultdict(lambda: defaultdict(list)) for _ in range(3)) 22 | test_epochs_dic = defaultdict(lambda: defaultdict(list)) 23 | test_rmse_dic, test_r2_dic, test_corr_dic = (defaultdict(lambda: defaultdict(list)) for _ in range(3)) 24 | 25 | exp = 0 26 | year = 0 27 | with open(prediction_log) as f: 28 | content = f.readlines() 29 | for line in content: 30 | line = line.strip() 31 | if line.startswith('Predict'): 32 | year = int(line.split()[2][:4]) 33 | if line.startswith('Experiment'): 34 | exp = int(line.split()[1]) 35 | if 'Epoch' in line: 36 | train_epochs_dic[year][exp].append(int(line.split()[2])) 37 | if 'Training' in line: 38 | ws = line.split() 39 | train_loss_dic[year][exp].append(float(ws[4][:-1])) 40 | train_super_loss_dic[year][exp].append(float(ws[7][:-1])) 41 | train_unsuper_loss_dic[year][exp].append(float(ws[10][:-1])) 42 | if 'Validation' in line: 43 | ws = line.split() 44 | valid_loss_dic[year][exp].append(float(ws[3][:-1])) 45 | valid_super_loss_dic[year][exp].append(float(ws[6][:-1])) 46 | valid_unsuper_loss_dic[year][exp].append(float(ws[9][:-1])) 47 | valid_l_n_loss_dic[year][exp].append(float(ws[12][:-1])) 48 | valid_l_d_loss_dic[year][exp].append(float(ws[15][:-1])) 49 | valid_l_nd_loss_dic[year][exp].append(float(ws[18][:-1])) 50 | valid_sn_loss_dic[year][exp].append(float(ws[20][:-1])) 51 | valid_tn_loss_dic[year][exp].append(float(ws[22][:-1])) 52 | valid_norm_loss_dic[year][exp].append(float(ws[24][:-1])) 53 | valid_rmse_dic[year][exp].append(float(ws[26][:-1])) 54 | valid_r2_dic[year][exp].append(float(ws[28][:-1])) 55 | valid_corr_dic[year][exp].append(float(ws[30][:-1])) 56 | if '(Test)' in line and 'epoch' in line: 57 | ws = line.split() 58 | test_epochs_dic[year][exp].append(int(ws[3][:-1])) 59 | test_rmse_dic[year][exp].append(float(ws[5][:-1])) 60 | test_r2_dic[year][exp].append(float(ws[7][:-1])) 61 | test_corr_dic[year][exp].append(float(ws[9])) 62 | 63 | for year in train_epochs_dic.keys(): 64 | n_exps = len(train_epochs_dic[year]) 65 | for i in range(n_exps): 66 | # assert train_epochs_dic[year][i] == test_epochs_dic[year][i], params 67 | 68 | plt.plot(train_epochs_dic[year][i], train_loss_dic[year][i], label='Training') 69 | plt.plot(train_epochs_dic[year][i], valid_loss_dic[year][i], label='Validation') 70 | plt.title(params, fontsize=8) 71 | plt.grid(True) 72 | plt.legend() 73 | plt.savefig('{}/{}_{}_total_loss.jpg'.format(out_dir, year, i), dpi=300) 74 | plt.close() 75 | 76 | plt.plot(train_epochs_dic[year][i], train_super_loss_dic[year][i], label='Training') 77 | plt.plot(train_epochs_dic[year][i], valid_super_loss_dic[year][i], label='Validation') 78 | 
plt.title(params, fontsize=8) 79 | plt.grid(True) 80 | plt.legend() 81 | plt.savefig('{}/{}_{}_supervised_loss.jpg'.format(out_dir, year, i), dpi=300) 82 | plt.close() 83 | 84 | plt.plot(train_epochs_dic[year][i], train_unsuper_loss_dic[year][i], label='Training') 85 | plt.plot(train_epochs_dic[year][i], valid_unsuper_loss_dic[year][i], label='Validation') 86 | plt.title(params, fontsize=8) 87 | plt.grid(True) 88 | plt.legend() 89 | plt.savefig('{}/{}_{}_unsupervised_loss.jpg'.format(out_dir, year, i), dpi=300) 90 | plt.close() 91 | 92 | # valid_l_n_loss, valid_l_d_loss, valid_l_nd_loss, valid_sn_loss, valid_tn_loss, valid_norm_loss 93 | plt.plot(train_epochs_dic[year][i], valid_l_n_loss_dic[year][i], label='l_n_loss') 94 | plt.plot(train_epochs_dic[year][i], valid_l_d_loss_dic[year][i], label='l_d_loss') 95 | plt.plot(train_epochs_dic[year][i], valid_l_nd_loss_dic[year][i], label='l_nd_loss') 96 | plt.plot(train_epochs_dic[year][i], valid_sn_loss_dic[year][i], label='spatial_neighbor_loss') 97 | plt.plot(train_epochs_dic[year][i], valid_tn_loss_dic[year][i], label='temporal_neighbor_loss') 98 | plt.plot(train_epochs_dic[year][i], valid_norm_loss_dic[year][i], label='l2_norm_loss') 99 | plt.title(params, fontsize=8) 100 | plt.grid(True) 101 | plt.legend() 102 | plt.savefig('{}/{}_{}_validation_various_losses.jpg'.format(out_dir, year, i), dpi=300) 103 | plt.close() 104 | 105 | plt.plot(train_epochs_dic[year][i], valid_rmse_dic[year][i], label='Validation') 106 | plt.plot(test_epochs_dic[year][i], test_rmse_dic[year][i], label='Test') 107 | plt.title(params, fontsize=8) 108 | plt.grid(True) 109 | plt.legend() 110 | plt.savefig('{}/{}_{}_rmse.jpg'.format(out_dir, year, i), dpi=300) 111 | plt.close() 112 | 113 | plt.plot(train_epochs_dic[year][i], valid_r2_dic[year][i], label='Validation') 114 | plt.plot(test_epochs_dic[year][i], test_r2_dic[year][i], label='Test') 115 | plt.title(params, fontsize=8) 116 | plt.grid(True) 117 | plt.legend() 118 | plt.savefig('{}/{}_{}_r2.jpg'.format(out_dir, year, i), dpi=300) 119 | plt.close() 120 | 121 | plt.plot(train_epochs_dic[year][i], valid_corr_dic[year][i], label='Validation') 122 | plt.plot(test_epochs_dic[year][i], test_corr_dic[year][i], label='Test') 123 | plt.title(params, fontsize=8) 124 | plt.grid(True) 125 | plt.legend() 126 | plt.savefig('{}/{}_{}_corr.jpg'.format(out_dir, year, i), dpi=300) 127 | plt.close() 128 | 129 | 130 | if __name__ == '__main__': 131 | for prediction_log in os.listdir('../../results/spatial_temporal/prediction_logs'): 132 | if prediction_log.endswith('.txt'): 133 | print(prediction_log) 134 | plot_loss(prediction_log) 135 | 136 | -------------------------------------------------------------------------------- /crop_yield_train_c3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
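# crop_yield_train_c3d() below evaluates one target year at a time with a strictly temporal split:
# the target year is the test set, the year before it is the validation set, and all earlier years
# (optionally capped at the most recent `train_years` years) form the training set. A minimal sketch
# of that split on a toy year column (pandas assumed, as in the script; helper name illustrative only):

def _temporal_split_sketch(years, test_year, train_years=None):
    import pandas as pd
    years = pd.Series(years)
    test_idx = years == test_year
    valid_idx = years == (test_year - 1)
    if train_years is None:
        train_idx = years < (test_year - 1)
    else:
        train_idx = (years < (test_year - 1)) & (years >= (test_year - 1 - train_years))
    return train_idx, valid_idx, test_idx

# Example: _temporal_split_sketch([2010, 2011, 2012, 2013, 2014], test_year=2014)
# marks 2014 as test, 2013 as validation, and 2010-2012 as training.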
6 | 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import argparse 11 | import torch.optim as optim 12 | from pathlib import Path 13 | import sys 14 | 15 | sys.path.append("..") 16 | 17 | from crop_yield_prediction.models.c3d import C3D 18 | from crop_yield_prediction.train_c3d import train_c3d 19 | from crop_yield_prediction.utils import plot_predict 20 | from crop_yield_prediction.utils import plot_predict_error 21 | from crop_yield_prediction.utils import output_to_csv_simple 22 | from crop_yield_prediction.train_c3d import eval_test 23 | 24 | 25 | def crop_yield_train_c3d(args, data_dir, model_out_dir, result_out_dir, log_out_dir, start_year, end_year, 26 | n_tsteps, train_years=None): 27 | batch_size = 30 28 | test_batch_size = 128 29 | n_triplets_per_file = 1 30 | epochs = 50 31 | n_experiment = 2 32 | 33 | patience = args.patience if args.patience != 9999 else None 34 | feature = args.feature 35 | feature_len = args.feature_len 36 | 37 | params = '{}_nt{}_es{}_{}_tyear{}'.format(start_year, n_tsteps, patience, feature, train_years) 38 | 39 | os.makedirs(log_out_dir, exist_ok=True) 40 | param_model_out_dir = '{}/{}'.format(model_out_dir, params) 41 | os.makedirs(param_model_out_dir, exist_ok=True) 42 | param_result_out_dir = '{}/{}'.format(result_out_dir, params) 43 | os.makedirs(param_result_out_dir, exist_ok=True) 44 | 45 | if feature == 'all': 46 | X_dir = '{}/nr_25_dr100'.format(data_dir) 47 | else: 48 | X_dir = '{}/nr_25_dr100_{}'.format(data_dir, feature) 49 | 50 | dim_y = pd.read_csv('{}/dim_y.csv'.format(data_dir)) 51 | dim_y = dim_y.astype({'state': int, 'county': int, 'year': int, 'value': float, 'lat': float, 'lon': float}) 52 | max_index = len(dim_y) - 1 53 | 54 | results = dict() 55 | for year in range(start_year, end_year + 1): 56 | print('Predict year {}......'.format(year)) 57 | 58 | test_idx = (dim_y['year'] == year) 59 | valid_idx = (dim_y['year'] == (year - 1)) 60 | if train_years is None: 61 | train_idx = (dim_y['year'] < (year - 1)) 62 | else: 63 | train_idx = (dim_y['year'] < (year - 1)) & (dim_y['year'] >= (year - 1 - train_years)) 64 | 65 | y_valid, y_train = np.array(dim_y.loc[valid_idx]['value']), np.array(dim_y.loc[train_idx]['value']) 66 | y_test, dim_test = np.array(dim_y.loc[test_idx]['value']), np.array(dim_y.loc[test_idx][['state', 'county']]) 67 | 68 | test_indices = [i for i, x in enumerate(test_idx) if x] 69 | valid_indices = [i for i, x in enumerate(valid_idx) if x] 70 | train_indices = [i for i, x in enumerate(train_idx) if x] 71 | 72 | # check if the indices are sequential 73 | assert all(elem == 1 for elem in [y - x for x, y in zip(test_indices[:-1], test_indices[1:])]) 74 | assert all(elem == 1 for elem in [y - x for x, y in zip(valid_indices[:-1], valid_indices[1:])]) 75 | assert all(elem == 1 for elem in [y - x for x, y in zip(train_indices[:-1], train_indices[1:])]) 76 | print('Train size {}, valid size {}, test size {}'.format(y_train.shape[0], y_valid.shape[0], y_test.shape[0])) 77 | 78 | test_corr_lis, test_r2_lis, test_rmse_lis = [], [], [] 79 | test_prediction_lis = [] 80 | for i in range(n_experiment): 81 | print('Experiment {}'.format(i)) 82 | 83 | c3d = C3D(in_channels=feature_len, n_tsteps=n_tsteps) 84 | 85 | optimizer = optim.Adam(c3d.parameters(), lr=0.001) 86 | 87 | trained_epochs = train_c3d(model=c3d, 88 | X_dir=X_dir, 89 | X_train_indices=(train_indices[0], train_indices[-1]), 90 | y_train=y_train, 91 | X_valid_indices=(valid_indices[0], valid_indices[-1]), 92 | y_valid=y_valid, 93 | 
X_test_indices=(test_indices[0], test_indices[-1]), 94 | y_test=y_test, 95 | n_tsteps=n_tsteps, 96 | max_index=max_index, 97 | n_triplets_per_file=n_triplets_per_file, 98 | patience=patience, 99 | optimizer=optimizer, 100 | batch_size=batch_size, 101 | test_batch_size=test_batch_size, 102 | n_epochs=epochs, 103 | out_dir=param_model_out_dir, 104 | year=year, 105 | exp_idx=i, 106 | log_file='{}/{}.txt'.format(log_out_dir, params)) 107 | 108 | test_prediction, rmse, r2, corr = eval_test(X_dir, 109 | X_test_indices=(test_indices[0], test_indices[-1]), 110 | y_test=y_test, 111 | n_tsteps=n_tsteps, 112 | max_index=max_index, 113 | n_triplets_per_file=n_triplets_per_file, 114 | batch_size=test_batch_size, 115 | model_dir=param_model_out_dir, 116 | model=c3d, 117 | epochs=trained_epochs, 118 | year=year, 119 | exp_idx=i, 120 | log_file='{}/{}.txt'.format(log_out_dir, params)) 121 | test_corr_lis.append(corr) 122 | test_r2_lis.append(r2) 123 | test_rmse_lis.append(rmse) 124 | 125 | test_prediction_lis.append(test_prediction) 126 | 127 | test_prediction = np.mean(np.asarray(test_prediction_lis), axis=0) 128 | np.save('{}/{}.npy'.format(param_result_out_dir, year), test_prediction) 129 | plot_predict(test_prediction, dim_test, Path('{}/pred_{}.html'.format(param_result_out_dir, year))) 130 | plot_predict_error(test_prediction, y_test, dim_test, Path('{}/err_{}.html'.format(param_result_out_dir, year))) 131 | 132 | results[year] = {'test_rmse': np.around(np.mean(test_rmse_lis), 3), 133 | 'test_r2': np.around(np.mean(test_r2_lis), 3), 134 | 'test_corr': np.around(np.mean(test_corr_lis), 3)} 135 | 136 | output_to_csv_simple(results, param_result_out_dir) 137 | 138 | 139 | if __name__ == '__main__': 140 | parser = argparse.ArgumentParser(description='Crop Yield Train C3D') 141 | parser.add_argument('--patience', type=int, default=9999, metavar='PATIENCE') 142 | parser.add_argument('--feature', type=str, default='all', metavar='FEATURE') 143 | parser.add_argument('--feature-len', type=int, default=9, metavar='FEATURE_LEN') 144 | parser.add_argument('--year', type=int, default=2014, metavar='YEAR') 145 | parser.add_argument('--ntsteps', type=int, default=7, metavar='NTSTEPS', required=True) 146 | parser.add_argument('--train-years', type=int, default=None, metavar='TRAINYEAR', required=True) 147 | 148 | args = parser.parse_args() 149 | 150 | crop_yield_train_c3d(args, 151 | data_dir='data/spatial_temporal/counties', 152 | model_out_dir='results/c3d/models', 153 | result_out_dir='results/c3d/results', 154 | log_out_dir='results/c3d/prediction_logs', 155 | start_year=args.year, 156 | end_year=args.year, 157 | n_tsteps=args.ntsteps, 158 | train_years=args.train_years) 159 | -------------------------------------------------------------------------------- /data_preprocessing/preprocess/county_locations.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
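# combine_ncs() below merges the per-county rasters into one grid and stores a single integer label
# per pixel: the 2-digit state FIPS and 3-digit county FIPS are zero-padded and concatenated
# (state 6, county 37 -> 6037, i.e. "06037"), and later split back with zfill(5) for plotting.
# A minimal sketch of that encoding/decoding round trip (helper name illustrative only):

def _fips_label_sketch(state, county):
    label = int(str(state).zfill(2) + str(county).zfill(3))   # e.g. (6, 37) -> 6037
    padded = str(label).zfill(5)                              # restore the leading zero: "06037"
    return label, padded[:2], padded[2:]                      # (6037, "06", "037")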
6 | 7 | import os 8 | from netCDF4 import Dataset 9 | import numpy as np 10 | import numpy.ma as ma 11 | from pathlib import Path 12 | import pandas as pd 13 | import matplotlib.pyplot as plt 14 | import csv 15 | import sys 16 | sys.path.append("..") 17 | 18 | from data_preprocessing.utils import match_lat_lon 19 | from data_preprocessing.plot import counties_plot 20 | 21 | 22 | def generate_convert_to_nc_script(): 23 | fh_out = open('../../processed_data/counties/all/convert_to_nc.sh', 'w') 24 | fh_out.write('#!/bin/bash\n') 25 | 26 | for tif_file in os.listdir('../../processed_data/counties/all/tif/'): 27 | if tif_file.endswith('.tif'): 28 | fh_out.write('gdal_translate -of netCDF tif/{} nc/{}.nc\n'.format(tif_file, tif_file[:-4])) 29 | 30 | 31 | def combine_ncs(): 32 | fh_out = Dataset('../../processed_data/counties/us_counties.nc', 'w') 33 | fh_ref = Dataset('../../processed_data/landcover/cropland_cro.nc', 'r') 34 | 35 | lats, lons = fh_ref.variables['lat'][:], fh_ref.variables['lon'][:] 36 | 37 | for name, dim in fh_ref.dimensions.items(): 38 | fh_out.createDimension(name, len(dim)) 39 | 40 | for v_name, varin in fh_ref.variables.items(): 41 | if v_name in ['lat', 'lon']: 42 | outVar = fh_out.createVariable(v_name, varin.datatype, (v_name,)) 43 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 44 | outVar[:] = varin[:] 45 | 46 | outVar = fh_out.createVariable('county_label', 'int', ('lat', 'lon')) 47 | outVar.setncatts({'_FillValue': np.array([0]).astype(int)}) 48 | counties_labels = np.full((len(lats), len(lons)), 0) 49 | 50 | outVar = fh_out.createVariable('state_code', 'int', ('lat', 'lon')) 51 | outVar.setncatts({'_FillValue': np.array([0]).astype(int)}) 52 | state_code = np.full((len(lats), len(lons)), 0) 53 | 54 | outVar = fh_out.createVariable('county_code', 'int', ('lat', 'lon')) 55 | outVar.setncatts({'_FillValue': np.array([0]).astype(int)}) 56 | county_code = np.full((len(lats), len(lons)), 0) 57 | 58 | for nc_file in os.listdir('../../processed_data/counties/all/nc/'): 59 | if nc_file.endswith('.nc'): 60 | # ignore Alaska 61 | if nc_file.split('_')[0] == '2': 62 | continue 63 | print(nc_file) 64 | fh_in = Dataset('../../processed_data/counties/all/nc/{}'.format(nc_file), 'r') 65 | local_lats, local_lons = fh_in.variables['lat'][:], fh_in.variables['lon'][:] 66 | 67 | i_lat_start, i_lat_end, i_lon_start, i_lon_end = match_lat_lon(lats, lons, local_lats, local_lons) 68 | 69 | local_values = ma.masked_equal(fh_in.variables['Band1'][:], 0.0) 70 | for i, j in zip(*local_values.nonzero()): 71 | state, county = nc_file[:-3].split('_') 72 | state = str(state).zfill(2) 73 | county = str(county).zfill(3) 74 | counties_labels[i+i_lat_start, j+i_lon_start] = int(state+county) 75 | state_code[i+i_lat_start, j+i_lon_start] = int(state) 76 | county_code[i+i_lat_start, j+i_lon_start] = int(county) 77 | 78 | fh_in.close() 79 | 80 | fh_out.variables['county_label'][:] = ma.masked_equal(counties_labels, 0) 81 | fh_out.variables['state_code'][:] = ma.masked_equal(state_code, 0) 82 | fh_out.variables['county_code'][:] = ma.masked_equal(county_code, 0) 83 | 84 | fh_ref.close() 85 | fh_out.close() 86 | 87 | 88 | def mask_with_landcover(out_file, ref_file): 89 | fh_in = Dataset('../../processed_data/counties/us_counties.nc', 'r') 90 | fh_out = Dataset(out_file, 'w') 91 | fh_ref = Dataset(ref_file, 'r') 92 | 93 | for name, dim in fh_in.dimensions.items(): 94 | fh_out.createDimension(name, len(dim)) 95 | 96 | for v_name, varin in fh_in.variables.items(): 97 | outVar = 
fh_out.createVariable(v_name, varin.datatype, varin.dimensions) 98 | outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()}) 99 | if v_name in ['lat', 'lon']: 100 | outVar[:] = varin[:] 101 | else: 102 | cropland_mask = ma.getmaskarray(fh_ref.variables['cropland'][:]) 103 | outVar[:] = ma.array(varin[:], mask=cropland_mask) 104 | 105 | fh_in.close() 106 | fh_out.close() 107 | fh_ref.close() 108 | 109 | 110 | def plot_counties(in_file): 111 | fh = Dataset(in_file, 'r') 112 | county_labels = fh.variables['county_label'][:] 113 | print(len(np.unique(county_labels.compressed()))) 114 | fh.close() 115 | 116 | county_labels = np.unique(county_labels.compressed()) 117 | county_labels = [[str(x).zfill(5)[:2], str(x).zfill(5)[2:]] for x in county_labels] 118 | 119 | data_dic = {} 120 | for state, county in county_labels: 121 | data_dic[state+county] = 100 122 | fake_quantiles = {x: 1 for x in [0.05, 0.2, 0.4, 0.6, 0.8, 0.95]} 123 | counties_plot(data_dic, Path('../../processed_data/counties/{}.html'.format(in_file[:-3])), fake_quantiles) 124 | 125 | return data_dic.keys() 126 | 127 | 128 | def plot_counties_data(in_file): 129 | county_data = pd.read_csv(in_file)[['StateFips', 'CntyFips']] 130 | county_data.columns = ['State', 'County'] 131 | 132 | data_dic = {} 133 | for row in county_data.itertuples(): 134 | state, county = int(row.State), int(row.County) 135 | state = str(state).zfill(2) 136 | county = str(county).zfill(3) 137 | data_dic[state + county] = 100 138 | 139 | fake_quantiles = {x: 1 for x in [0.05, 0.2, 0.4, 0.6, 0.8, 0.95]} 140 | counties_plot(data_dic, Path('../../processed_data/counties/county_data.html'), fake_quantiles) 141 | 142 | return data_dic.keys() 143 | 144 | 145 | def analyze_counties(in_file): 146 | fh = Dataset(in_file, 'r') 147 | counties, sizes = np.unique(fh.variables['county_label'][:].compressed(), return_counts=True) 148 | for county, size in zip(counties, sizes): 149 | print(county, size) 150 | plt.hist(sizes) 151 | plt.show() 152 | 153 | 154 | def get_county_locations(in_file): 155 | fh = Dataset(in_file, 'r') 156 | lats, lons = fh.variables['lat'][:], fh.variables['lon'][:] 157 | 158 | county_labels = fh.variables['county_label'][:] 159 | counties = np.unique(county_labels.compressed()) 160 | 161 | with open('{}_locations.csv'.format(in_file[:-3]), 'w') as f: 162 | writer = csv.writer(f, delimiter=',') 163 | writer.writerow(['state', 'county', 'lat', 'lon']) 164 | for county in counties: 165 | selected_rows, selected_cols = np.where(county_labels == county) 166 | lat_mean, lon_mean = np.mean(lats[selected_rows]), np.mean(lons[selected_cols]) 167 | line = [str(county).zfill(5)[:2], str(county).zfill(5)[2:], lat_mean, lon_mean] 168 | writer.writerow(line) 169 | 170 | 171 | if __name__ == '__main__': 172 | # generate_convert_to_nc_script() 173 | combine_ncs() 174 | # mask_with_landcover('../../processed_data/counties/us_counties_cro.nc', 175 | # '../../processed_data/landcover/cropland_cro.nc') 176 | # mask_with_landcover('../../processed_data/counties/us_counties_cro_cvm.nc', 177 | # '../../processed_data/landcover/cropland_cro_cvm.nc') 178 | # 179 | # county_key = plot_counties_data('../../processed_data/counties/county_data.csv') 180 | # us_county_key = plot_counties('../../processed_data/counties/us_counties.nc') 181 | # print([x for x in us_county_key if x not in county_key]) 182 | # print([x for x in county_key if x not in us_county_key and not x.startswith('02')]) 183 | # plot_counties('../../processed_data/counties/us_counties_cro.nc') 184 | # 
plot_counties('../../processed_data/counties/us_counties_cro_cvm.nc') 185 | # 186 | # analyze_counties('../../processed_data/counties/us_counties.nc') 187 | # analyze_counties('../../processed_data/counties/us_counties_cro.nc') 188 | # analyze_counties('../../processed_data/counties/us_counties_cro_cvm.nc') 189 | 190 | # get_county_locations('../../processed_data/counties/us_counties.nc') 191 | # get_county_locations('../../processed_data/counties/us_counties_cro.nc') 192 | # get_county_locations('../../processed_data/counties/us_counties_cro_cvm.nc') 193 | -------------------------------------------------------------------------------- /crop_yield_train_cnn_lstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import argparse 11 | import torch.optim as optim 12 | from pathlib import Path 13 | import sys 14 | 15 | sys.path.append("..") 16 | 17 | from crop_yield_prediction.models.cnn_lstm import CnnLstm 18 | from crop_yield_prediction.train_cnn_lstm import train_cnn_lstm 19 | from crop_yield_prediction.utils import plot_predict 20 | from crop_yield_prediction.utils import plot_predict_error 21 | from crop_yield_prediction.utils import output_to_csv_simple 22 | from crop_yield_prediction.train_cnn_lstm import eval_test 23 | 24 | 25 | def crop_yield_train_cnn_lstm(args, data_dir, model_out_dir, result_out_dir, log_out_dir, start_year, end_year, 26 | n_tsteps, train_years=None): 27 | batch_size = 64 28 | test_batch_size = 128 29 | n_triplets_per_file = 1 30 | epochs = 50 31 | n_experiment = 2 32 | 33 | patience = args.patience if args.patience != 9999 else None 34 | feature = args.feature 35 | feature_len = args.feature_len 36 | tilenet_zdim = args.tilenet_zdim 37 | lstm_inner = args.lstm_inner 38 | 39 | params = '{}_nt{}_es{}_{}_tyear{}_zdim{}_din{}'.format(start_year, n_tsteps, patience, feature, train_years, tilenet_zdim, lstm_inner) 40 | 41 | os.makedirs(log_out_dir, exist_ok=True) 42 | param_model_out_dir = '{}/{}'.format(model_out_dir, params) 43 | os.makedirs(param_model_out_dir, exist_ok=True) 44 | param_result_out_dir = '{}/{}'.format(result_out_dir, params) 45 | os.makedirs(param_result_out_dir, exist_ok=True) 46 | 47 | if feature == 'all': 48 | X_dir = '{}/nr_25_dr100'.format(data_dir) 49 | else: 50 | X_dir = '{}/nr_25_dr100_{}'.format(data_dir, feature) 51 | 52 | dim_y = pd.read_csv('{}/dim_y.csv'.format(data_dir)) 53 | dim_y = dim_y.astype({'state': int, 'county': int, 'year': int, 'value': float, 'lat': float, 'lon': float}) 54 | max_index = len(dim_y) - 1 55 | 56 | results = dict() 57 | for year in range(start_year, end_year + 1): 58 | print('Predict year {}......'.format(year)) 59 | 60 | test_idx = (dim_y['year'] == year) 61 | valid_idx = (dim_y['year'] == (year - 1)) 62 | if train_years is None: 63 | train_idx = (dim_y['year'] < (year - 1)) 64 | else: 65 | train_idx = (dim_y['year'] < (year - 1)) & (dim_y['year'] >= (year - 1 - train_years)) 66 | 67 | y_valid, y_train = np.array(dim_y.loc[valid_idx]['value']), np.array(dim_y.loc[train_idx]['value']) 68 | y_test, dim_test = np.array(dim_y.loc[test_idx]['value']), np.array(dim_y.loc[test_idx][['state', 'county']]) 69 | 70 | test_indices = [i for i, x in enumerate(test_idx) if x] 71 | valid_indices = [i for i, x 
in enumerate(valid_idx) if x] 72 | train_indices = [i for i, x in enumerate(train_idx) if x] 73 | 74 | # check if the indices are sequential 75 | assert all(elem == 1 for elem in [y - x for x, y in zip(test_indices[:-1], test_indices[1:])]) 76 | assert all(elem == 1 for elem in [y - x for x, y in zip(valid_indices[:-1], valid_indices[1:])]) 77 | assert all(elem == 1 for elem in [y - x for x, y in zip(train_indices[:-1], train_indices[1:])]) 78 | print('Train size {}, valid size {}, test size {}'.format(y_train.shape[0], y_valid.shape[0], y_test.shape[0])) 79 | 80 | test_corr_lis, test_r2_lis, test_rmse_lis = [], [], [] 81 | test_prediction_lis = [] 82 | for i in range(n_experiment): 83 | print('Experiment {}'.format(i)) 84 | 85 | cnn_lstm = CnnLstm(tn_in_channels=feature_len, 86 | tn_z_dim=tilenet_zdim, 87 | d_model=tilenet_zdim, 88 | d_inner=lstm_inner) 89 | 90 | optimizer = optim.Adam(cnn_lstm.parameters(), lr=0.001) 91 | 92 | trained_epochs = train_cnn_lstm(model=cnn_lstm, 93 | X_dir=X_dir, 94 | X_train_indices=(train_indices[0], train_indices[-1]), 95 | y_train=y_train, 96 | X_valid_indices=(valid_indices[0], valid_indices[-1]), 97 | y_valid=y_valid, 98 | X_test_indices=(test_indices[0], test_indices[-1]), 99 | y_test=y_test, 100 | n_tsteps=n_tsteps, 101 | max_index=max_index, 102 | n_triplets_per_file=n_triplets_per_file, 103 | patience=patience, 104 | optimizer=optimizer, 105 | batch_size=batch_size, 106 | test_batch_size=test_batch_size, 107 | n_epochs=epochs, 108 | out_dir=param_model_out_dir, 109 | year=year, 110 | exp_idx=i, 111 | log_file='{}/{}.txt'.format(log_out_dir, params)) 112 | 113 | test_prediction, rmse, r2, corr = eval_test(X_dir, 114 | X_test_indices=(test_indices[0], test_indices[-1]), 115 | y_test=y_test, 116 | n_tsteps=n_tsteps, 117 | max_index=max_index, 118 | n_triplets_per_file=n_triplets_per_file, 119 | batch_size=test_batch_size, 120 | model_dir=param_model_out_dir, 121 | model=cnn_lstm, 122 | epochs=trained_epochs, 123 | year=year, 124 | exp_idx=i, 125 | log_file='{}/{}.txt'.format(log_out_dir, params)) 126 | test_corr_lis.append(corr) 127 | test_r2_lis.append(r2) 128 | test_rmse_lis.append(rmse) 129 | 130 | test_prediction_lis.append(test_prediction) 131 | 132 | test_prediction = np.mean(np.asarray(test_prediction_lis), axis=0) 133 | np.save('{}/{}.npy'.format(param_result_out_dir, year), test_prediction) 134 | plot_predict(test_prediction, dim_test, Path('{}/pred_{}.html'.format(param_result_out_dir, year))) 135 | plot_predict_error(test_prediction, y_test, dim_test, Path('{}/err_{}.html'.format(param_result_out_dir, year))) 136 | 137 | results[year] = {'test_rmse': np.around(np.mean(test_rmse_lis), 3), 138 | 'test_r2': np.around(np.mean(test_r2_lis), 3), 139 | 'test_corr': np.around(np.mean(test_corr_lis), 3)} 140 | 141 | output_to_csv_simple(results, param_result_out_dir) 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser(description='Crop Yield Train CNN_LSTM') 146 | parser.add_argument('--patience', type=int, default=9999, metavar='PATIENCE') 147 | parser.add_argument('--feature', type=str, default='all', metavar='FEATURE') 148 | parser.add_argument('--feature-len', type=int, default=9, metavar='FEATURE_LEN') 149 | parser.add_argument('--year', type=int, default=2014, metavar='YEAR') 150 | parser.add_argument('--ntsteps', type=int, default=7, metavar='NTSTEPS', required=True) 151 | parser.add_argument('--train-years', type=int, default=None, metavar='TRAINYEAR', required=True) 152 | parser.add_argument('--tilenet-zdim', 
type=int, default=256, metavar='ZDIM') 153 | parser.add_argument('--lstm-inner', type=int, default=512, metavar='LSTM_INNER') 154 | 155 | args = parser.parse_args() 156 | 157 | crop_yield_train_cnn_lstm(args, 158 | data_dir='data/spatial_temporal/counties', 159 | model_out_dir='results/cnn_lstm/models', 160 | result_out_dir='results/cnn_lstm/results', 161 | log_out_dir='results/cnn_lstm/prediction_logs', 162 | start_year=args.year, 163 | end_year=args.year, 164 | n_tsteps=args.ntsteps, 165 | train_years=args.train_years) 166 | --------------------------------------------------------------------------------
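Both training scripts above repeat each year's run n_experiment times and then combine the runs: the per-county predictions are averaged elementwise across experiments before saving and plotting, while RMSE, R2 and correlation are averaged as scalars. A minimal sketch of that combination step (toy predictions for two experiments, mirroring the np.mean calls in the scripts):

import numpy as np

test_prediction_lis = [np.array([150.0, 160.0, 170.0]),   # hypothetical per-county predictions, experiment 0
                       np.array([154.0, 158.0, 172.0])]   # experiment 1
test_prediction = np.mean(np.asarray(test_prediction_lis), axis=0)   # elementwise mean -> [152. 159. 171.]
test_rmse = np.around(np.mean([5.1, 4.9]), 3)                        # scalar metrics averaged separately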