├── torch_code ├── __init__.py ├── utils │ ├── __init__.py │ ├── error_metrics.py │ ├── loss.py │ ├── utils.py │ ├── transforms.py │ ├── plots.py │ └── parser.py ├── models │ ├── __init__.py │ ├── simplecnn.py │ ├── models.py │ └── resnet.py ├── train.py ├── predict.py ├── collect_ensemble_preds.py ├── dataset.py └── trainer.py ├── .gitignore ├── requirements.txt ├── DEMO_GEDI_orbit_prediction.sh ├── DEMO_GEDI_regression_crossval_ensemble.sh ├── cluster └── job_array_regression_crossval_ensemble.sh └── README.md /torch_code/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_code/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_code/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | demo_data 2 | *output_demo* 3 | 4 | *.pyc 5 | *.ipynb_checkpoints 6 | *.png 7 | *.pdf 8 | *.jpg 9 | *.DS_Store 10 | 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==3.1.0 2 | matplotlib==3.3.4 3 | torchvision==0.8.2 4 | numpy==1.20.1 5 | scipy==1.6.1 6 | tqdm==4.57.0 7 | pandas==1.2.2 8 | torch_summary==1.4.5 9 | opencv_python==4.5.1.48 10 | torch==1.7.1 11 | torchsummary==1.5.1 12 | tensorboard -------------------------------------------------------------------------------- /torch_code/utils/error_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def mse(x, y): 4 | 
return np.nanmean(np.square(y - x)) 5 | 6 | 7 | def rmse(x, y): 8 | return float(np.sqrt(np.nanmean(np.square(y - x)))) 9 | 10 | 11 | def mae(x, y): 12 | return float(np.nanmean(np.abs(y - x))) 13 | 14 | 15 | def me(x, y): 16 | return float(np.nanmean(y - x)) 17 | 18 | 19 | def rmspe(x, y): 20 | return float(np.sqrt(np.nanmean(np.square((y - x) / x))) * 100) 21 | 22 | 23 | def mape(x, y): 24 | return float(np.nanmean(np.abs((y - x) / x)) * 100) 25 | 26 | 27 | def mpe(x, y): 28 | return float(np.nanmean((y - x) / x) * 100) 29 | 30 | 31 | def get_metrics_dict(): 32 | metrics_dict_fun = {'RMSE': rmse, 33 | 'MAE': mae, 34 | 'ME': me, 35 | 'RMSPE': rmspe, 36 | 'MAPE': mape, 37 | 'MPE': mpe} 38 | return metrics_dict_fun 39 | 40 | -------------------------------------------------------------------------------- /DEMO_GEDI_orbit_prediction.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ----- CONFIGURATION ----- 4 | 5 | # path to base directory with model_XX subdirectories 6 | ensemble_dir=demo_data/GEDI_BDL_demo/output_demo/testfold_0/ 7 | 8 | model_name='SimpleResNet_8blocks' 9 | n_models=10 10 | batch_size=8192 # Note: Reduce the batch size if the GPU is out of memory. 
11 | 12 | # set output directory for predictions 13 | prediction_dir='output_demo_orbit_prediction' 14 | # NOTE that in predict.py: 15 | # if prediction_dir is None: 16 | # args.prediction_dir = os.path.dirname(args.file_path_L1B).replace('/L1B', '/pred_RH98') 17 | 18 | # set L1B file 19 | file_path_L1B=demo_data/GEDI_BDL_demo/DEMO_orbit_files/L1B/GEDI01_B_2019224233051_O03775_T03020_02_003_01.h5 20 | # set corresponding L2A file (used for quality filtering) 21 | file_path_L2A=demo_data/GEDI_BDL_demo/DEMO_orbit_files/L2A/processed_GEDI02_A_2019224233051_O03775_T03020_02_001_01.h5 22 | 23 | echo L1B_path: 24 | echo ${file_path_L1B} 25 | 26 | echo L2A_path: 27 | echo ${file_path_L2A} 28 | 29 | echo output directory: 30 | echo ${prediction_dir} 31 | 32 | python3 torch_code/predict.py --ensemble_dir=${ensemble_dir} \ 33 | --n_models=${n_models} \ 34 | --batch_size=${batch_size} \ 35 | --file_path_L1B=${file_path_L1B} \ 36 | --file_path_L2A=${file_path_L2A} \ 37 | --prediction_dir=${prediction_dir} \ 38 | --model_name=${model_name} 39 | -------------------------------------------------------------------------------- /torch_code/models/simplecnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SimpleCNN(nn.Module): 5 | def __init__(self, 6 | in_features=1, 7 | out_features=(16, 32, 64, 128, 128, 128, 128, 128), 8 | num_outputs=1, 9 | global_pool=nn.AdaptiveAvgPool1d, 10 | max_pool=True, 11 | dilation_rate=1): 12 | """ 13 | The constructor. 14 | :param in_features: input channel dimension (we say waveforms have channel dimension 1). 15 | :param out_features: list of channel feature dimensions. 16 | :param num_outputs: number of the output dimension . 17 | :param global_pool: the global pooling to use before the fully connected layer. 18 | :param max_pool: whether to use max pooling or not. 19 | :param dilation_rate: must be >= 1. Defaults to 1 (no dilation). 
20 | """ 21 | super(SimpleCNN, self).__init__() 22 | self.relu = nn.ReLU(inplace=True) 23 | layers = list() 24 | for i in range(len(out_features)): 25 | in_channels = in_features if i == 0 else out_features[i-1] 26 | layers.append(nn.Conv1d(in_channels=in_channels, out_channels=out_features[i], kernel_size=3, 27 | padding=1, dilation=dilation_rate)) 28 | layers.append(nn.BatchNorm1d(num_features=out_features[i])) 29 | layers.append(self.relu) 30 | if max_pool: 31 | layers.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1)) 32 | self.conv_layers = nn.Sequential(*layers) 33 | self.global_pool = global_pool(output_size=1) 34 | self.dropout = nn.Dropout(p=0.5) 35 | self.fc = nn.Linear(in_features=out_features[-1], out_features=num_outputs) 36 | 37 | def forward(self, x): 38 | x = self.conv_layers(x) 39 | x = self.global_pool(x) 40 | x = self.dropout(x) 41 | x = x.flatten(start_dim=1) 42 | x = self.fc(x) 43 | return x 44 | 45 | 46 | if __name__ == "__main__": 47 | from torchsummary import summary 48 | import torch 49 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 50 | 51 | net = SimpleCNN() 52 | net.to(device) 53 | summary(net, input_size=(1, 1420)) 54 | -------------------------------------------------------------------------------- /torch_code/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from models.models import Models 4 | from utils.parser import setup_parser 5 | from pathlib import Path 6 | from torchsummary import summary 7 | import json 8 | from trainer import Trainer 9 | from utils.plots import plot_hist2d 10 | import os 11 | 12 | 13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | if __name__ == '__main__': 17 | # set parameters / parse arguments 18 | parser = setup_parser() 19 | args, unknown = parser.parse_known_args() 20 | 21 | if args.loss_key in ['gaussian_nll', 'laplacian_nll']: 22 | 
args.num_outputs = 2 23 | 24 | # log train/val metrics in tensorboard 25 | tensorboard_log_dir = Path(args.out_dir)/'log' 26 | 27 | # Setup model 28 | model_chooser = Models() 29 | model = model_chooser(args.model_name)(num_outputs=args.num_outputs) 30 | model.to(device) 31 | summary(model, input_size=(1, args.sample_length)) 32 | 33 | # Setup Trainer 34 | trainer = Trainer(model=model, log_dir=tensorboard_log_dir, args=args) 35 | 36 | print('TRAIN: ', len(trainer.ds_train)) 37 | print('VAL: ', len(trainer.ds_val)) 38 | print('TEST: ', len(trainer.ds_test)) 39 | 40 | # train 41 | if os.path.exists(Path(args.out_dir) / 'weights_last_epoch.pt') or args.skip_training: 42 | print("MODEL WAS ALREADY TRAINED. SKIP TRAINING! (because the file 'weights_last_epoch.pt' exists already)") 43 | else: 44 | trainer.train() 45 | 46 | # --- test --- 47 | if os.path.exists(Path(args.out_dir) / 'confusion.png'): 48 | print("MODEL WAS ALREADY TESTED. SKIP TESTING! (because the file 'confusion.png' exists already)") 49 | else: 50 | test_metrics, test_dict, test_metric_string = trainer.test() 51 | 52 | # save results 53 | with open(Path(args.out_dir) / 'results.txt', 'w') as f: 54 | f.write(test_metric_string) 55 | 56 | with open(Path(args.out_dir) / 'test_results.json', 'w') as f: 57 | json.dump(test_metrics, f) 58 | 59 | for key in test_dict.keys(): 60 | np.save(file=Path(args.out_dir) / '{}.npy'.format(key), arr=test_dict[key]) 61 | np.save(file=Path(args.out_dir) / 'test_indices.npy', arr=trainer.test_indices) 62 | 63 | # plot confusion ground truth vs. 
prediction 64 | plot_hist2d(x=test_dict['targets'], y=test_dict['predictions'], ma=args.max_gt, step=args.max_gt / 10, 65 | out_dir=args.out_dir, figsize=(8, 6), xlabel='Ground truth [m]', ylabel='Prediction [m]') 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /torch_code/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MELoss(nn.Module): 6 | def __init__(self): 7 | super(MELoss, self).__init__() 8 | 9 | def __call__(self, prediction, target): 10 | return torch.mean(prediction - target) 11 | 12 | 13 | class RMSELoss(nn.Module): 14 | def __init__(self): 15 | super(RMSELoss, self).__init__() 16 | 17 | def __call__(self, prediction, target): 18 | return torch.sqrt(torch.mean((prediction - target)**2)) 19 | 20 | 21 | class GaussianNLL(nn.Module): 22 | """ 23 | Gaussian negative log likelihood to fit the mean and variance to p(y|x) 24 | Note: We estimate the heteroscedastic variance. Hence, we include the var_i of sample i in the sum over all samples N. 25 | Furthermore, the constant log term is discarded. 26 | """ 27 | def __init__(self): 28 | super(GaussianNLL, self).__init__() 29 | self.eps = 1e-8 30 | 31 | def __call__(self, prediction, log_variance, target): 32 | """ 33 | This function expects the log(var) to guarantee a positive variance with var = exp(log(var)). 
34 | :param prediction: Predicted mean values 35 | :param log_variance: Predicted log(variance) 36 | :param target: Ground truth labels 37 | :return: gaussian negative log likelihood 38 | """ 39 | # add a small constant to the variance for numeric stability 40 | variance = torch.exp(log_variance) + self.eps 41 | return torch.mean(0.5 / variance * (prediction - target)**2 + 0.5 * torch.log(variance)) 42 | 43 | 44 | class LaplacianNLL(nn.Module): 45 | """ 46 | Laplacian negative log likelihood to fit the mean and the scale b to p(y|x). NOTE: the tensor named 'variance' below is the Laplace scale b (the loss is |y - mu| / b + log(b)), not the variance of the distribution. 47 | Note: We estimate the heteroscedastic variance. Hence, we include the var_i of sample i in the sum over all samples N. 48 | Furthermore, the constant log term is discarded. 49 | """ 50 | def __init__(self): 51 | super(LaplacianNLL, self).__init__() 52 | self.eps = 1e-8 53 | 54 | def __call__(self, prediction, log_variance, target): 55 | """ 56 | This function expects the log(var) to guarantee a positive variance with var = exp(log(var)). 57 | :param prediction: Predicted mean values 58 | :param log_variance: Predicted log(variance) 59 | :param target: Ground truth labels 60 | :return: laplacian negative log likelihood (mean over all samples) 61 | """ 62 | # add a small constant to the variance for numeric stability 63 | variance = torch.exp(log_variance) + self.eps 64 | return torch.mean(1 / variance * torch.abs(prediction - target) + torch.log(variance)) 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /DEMO_GEDI_regression_crossval_ensemble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # job index (set this to your system job variable e.g. for parallel job arrays) 4 | # used to set model_idx and test_fold_idx below.
5 | index=0 # index=0 --> model_idx=0, test_fold_idx=0 6 | 7 | inputs_path=demo_data/GEDI_BDL_demo/GEDI_BDL_demo_subset_neon.npy 8 | 9 | target_key='als_rh098' 10 | min_gt=0 11 | max_gt=100 12 | 13 | input_key='rxwaveform' 14 | sample_length=1420 15 | noise_mean_key='noise_mean_corrected' 16 | 17 | model_name='SimpleResNet_8blocks' 18 | loss_key='gaussian_nll' 19 | n_models=10 20 | n_folds=10 21 | normalize_targets=true 22 | 23 | batch_size=16 # the batch size was reduced to obtain a stable optimization with the small demo dataset 24 | nb_epoch=200 25 | base_lr=0.0001 26 | 27 | # data augmentation 28 | shift_left=0.2 29 | shift_right=0.2 30 | 31 | # quality flags to filter different expected noise levels 32 | setting_idx=3 # 0: power-night, 1: power-night + power-day, 2: power-night + power-day + coverage-night, 3: all 33 | # filtering for complete crossover data including waveform matching information, otherwise all data is used 34 | use_quality_flag=true 35 | pearson_thresh=0.95 36 | 37 | # select the model index for the model ensemble 38 | model_idx=$(( $index % ${n_models} )) 39 | 40 | # select the test fold index 41 | test_fold_idx=$(( $index / ${n_models} )) 42 | 43 | out_dir=output_demo/testfold_${test_fold_idx}/model_${model_idx} 44 | 45 | echo job index: $index 46 | echo model_idx: $model_idx 47 | echo test_fold_idx: ${test_fold_idx} 48 | echo output directory: ${out_dir} 49 | 50 | # train and test 51 | python3 torch_code/train.py --out_dir=${out_dir} \ 52 | --n_folds=${n_folds} \ 53 | --test_fold_idx=${test_fold_idx} \ 54 | --min_gt=${min_gt} \ 55 | --max_gt=${max_gt} \ 56 | --batch_size=${batch_size} \ 57 | --nb_epoch=${nb_epoch} \ 58 | --base_learning_rate=${base_lr} \ 59 | --loss_key=${loss_key} \ 60 | --sample_length=${sample_length} \ 61 | --inputs_path=${inputs_path} \ 62 | --input_key=${input_key} \ 63 | --target_key=${target_key} \ 64 | --shift_left=${shift_left} \ 65 | --shift_right=${shift_right} \ 66 | --model_name=${model_name}\ 67 | 
--setting_idx=${setting_idx} \ 68 | --normalize_targets=${normalize_targets} \ 69 | --pearson_thresh=${pearson_thresh} \ 70 | --noise_mean_key=${noise_mean_key} 71 | 72 | -------------------------------------------------------------------------------- /cluster/job_array_regression_crossval_ensemble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #BSUB -W 2:00 4 | #BSUB -o /cluster/home/nlang/HCS_project/output/gedi_crossval.%J.%I.txt 5 | #BSUB -e /cluster/home/nlang/HCS_project/output/gedi_crossval.%J.%I.txt 6 | #BSUB -R "rusage[mem=6000,ngpus_excl_p=1]" 7 | #BSUB -R "select[gpu_model0==GeForceGTX1080Ti]" 8 | #BSUB -n 1 9 | #BSUB -J "GEDI_crossval[1-100]" 10 | ##BSUB -u nlang@ethz.ch 11 | 12 | # load modules on cluster 13 | module load python_gpu/3.7.1 14 | module load hdf5/1.10.1 15 | 16 | 17 | # job index (set this to your system job variable e.g. for parallel job arrays) 18 | # used to set model_idx and test_fold_idx below. 
19 | #index=0 # index=0 --> model_idx=0, test_fold_idx=0 20 | index=$((LSB_JOBINDEX - 1)) 21 | 22 | inputs_path=demo_data/GEDI_BDL_demo/GEDI_BDL_demo_subset_neon.npy 23 | 24 | target_key='als_rh098' 25 | min_gt=0 26 | max_gt=100 27 | 28 | input_key='rxwaveform' 29 | sample_length=1420 30 | noise_mean_key='noise_mean_corrected' 31 | 32 | model_name='SimpleResNet_8blocks' 33 | loss_key='gaussian_nll' 34 | n_models=10 35 | n_folds=10 36 | normalize_targets=true 37 | 38 | batch_size=16 # the batch size was reduced to obtain a stable optimization with the small demo dataset 39 | nb_epoch=200 40 | base_lr=0.0001 41 | 42 | # data augmentation 43 | shift_left=0.2 44 | shift_right=0.2 45 | 46 | # quality flags to filter different expected noise levels 47 | setting_idx=3 # 0: power-night, 1: power-night + power-day, 2: power-night + power-day + coverage-night, 3: all 48 | # filtering for complete crossover data including waveform matching information, otherwise all data is used 49 | use_quality_flag=true 50 | pearson_thresh=0.95 51 | 52 | # select the model index for the model ensemble 53 | model_idx=$(( $index % ${n_models} )) 54 | 55 | # select the test fold index 56 | test_fold_idx=$(( $index / ${n_models} )) 57 | 58 | out_dir=output_demo/testfold_${test_fold_idx}/model_${model_idx} 59 | 60 | echo job index: $index 61 | echo model_idx: $model_idx 62 | echo test_fold_idx: ${test_fold_idx} 63 | echo output directory: ${out_dir} 64 | 65 | # train and test 66 | python3 torch_code/train.py --out_dir=${out_dir} \ 67 | --n_folds=${n_folds} \ 68 | --test_fold_idx=${test_fold_idx} \ 69 | --min_gt=${min_gt} \ 70 | --max_gt=${max_gt} \ 71 | --batch_size=${batch_size} \ 72 | --nb_epoch=${nb_epoch} \ 73 | --base_learning_rate=${base_lr} \ 74 | --loss_key=${loss_key} \ 75 | --sample_length=${sample_length} \ 76 | --inputs_path=${inputs_path} \ 77 | --input_key=${input_key} \ 78 | --target_key=${target_key} \ 79 | --shift_left=${shift_left} \ 80 | --shift_right=${shift_right} \ 81 | 
--model_name=${model_name}\ 82 | --setting_idx=${setting_idx} \ 83 | --normalize_targets=${normalize_targets} \ 84 | --pearson_thresh=${pearson_thresh} \ 85 | --noise_mean_key=${noise_mean_key} 86 | 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Global canopy height regression and uncertainty estimation from GEDI LIDAR waveforms with deep ensembles 2 | 3 | This repository provides the code used to create the results presented in [Global canopy height regression and uncertainty estimation from GEDI LIDAR waveforms with deep ensembles](https://doi.org/10.1016/j.rse.2021.112760). 4 | 5 | ## Installation 6 | The code has been tested with `Python 3.8.5`. 7 | 8 | ### Setup a virtual environment 9 | See details in the [venv documentation](https://docs.python.org/3/library/venv.html). 10 | 11 | **Example on linux:** 12 | 13 | Create a new virtual environment called `GEDI_BDL_env`. 14 | ``` 15 | python3 -m venv /path/to/new/virtual/environment/GEDI_BDL_env 16 | ``` 17 | 18 | Activate the new environment: 19 | ``` 20 | source /path/to/new/virtual/environment/GEDI_BDL_env/bin/activate 21 | ``` 22 | 23 | ### Install python packages 24 | After activating the venv, install the python packages with: 25 | ``` 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ## Download data for the DEMO scripts 30 | Please download the zip file `GEDI_BDL_demo.zip` from [here](https://share.phys.ethz.ch/~pf/nlangdata/GEDI_BDL_demo.zip). 31 | 32 | Extract and save it in this repository such that path reads like this: `GEDI-BDL/demo_data/GEDI_BDL_demo/`. 33 | The demo scripts will refer to this relative path. 34 | 35 | This demo dataset contains a subset of the ALS crossover training data used in the paper. 
It consists of 6,868 samples and is based on the publicly available ALS data from the [National Ecological Observatory Network (NEON)](https://data.neonscience.org/data-products/explore) in the United States. 36 | 37 | Note: The purpose of this demo dataset is to set up the code. Models trained on this subset may *not* generalize as described in the paper. 38 | More information on the demo dataset can be found in the readme file: `demo_data/GEDI_BDL_demo/README.txt`. 39 | 40 | ## Running the code 41 | 42 | ### Train and test a single CNN (or an ensemble) 43 | This example runs the regression of RH98 (proxy for the canopy top height) from the input L1B waveform. It runs the first model of the first random cross-validation fold. 44 | 45 | Running this script multiple times with job indices from 0-9 (set via the `index` variable at the top of the script) will train and test a full ensemble of 10 models for the first cross-validation fold. Job indices 10-19 will run the ensemble for the second fold and so on. 46 | ``` 47 | bash DEMO_GEDI_regression_crossval_ensemble.sh 48 | ``` 49 | Alternatively, run a parallel job array on an IBM LSF batch system: 50 | ``` 51 | bsub < cluster/job_array_regression_crossval_ensemble.sh 52 | ``` 53 | 54 | Launch tensorboard to look at the training and validation loss curves: 55 | 56 | ```tensorboard --logdir output_demo --port 7777``` 57 | 58 | #### Collect ensemble predictions from all cross-validation folds 59 | Here we run it for the ensemble demo output that was already included in the .zip file. 60 | 61 | ``` 62 | python torch_code/collect_ensemble_preds.py demo_data/GEDI_BDL_demo/output_demo/ 63 | ``` 64 | 65 | ### Predict for all (quality) waveforms in an L1B orbit file 66 | This example demonstrates how a trained model can be deployed to a full orbit file of the GEDI Version 1 data. This script loads the ensemble trained on the demo dataset from here: `demo_data/GEDI_BDL_demo/output_demo/testfold_0/`. 67 | 68 | The orbit files are loaded from `demo_data/GEDI_BDL_demo/DEMO_orbit_files`.
The quality flag from the corresponding L2A file is used to filter the predictions. 69 | ``` 70 | bash DEMO_GEDI_orbit_prediction.sh 71 | ``` 72 | 73 | ## Citation 74 | 75 | If you use this code please cite our paper: 76 | 77 | *Lang, N., Kalischek, N., Armston, J., Schindler, K., Dubayah, R., & Wegner, J. D. (2022). Global canopy height regression and uncertainty estimation from GEDI LIDAR waveforms with deep ensembles. Remote Sensing of Environment, 268, 112760.* 78 | 79 | BibTex: 80 | ``` 81 | @article{lang2022global, 82 | title={Global canopy height regression and uncertainty estimation from GEDI LIDAR waveforms with deep ensembles}, 83 | author={Lang, Nico and Kalischek, Nikolai and Armston, John and Schindler, Konrad and Dubayah, Ralph and Wegner, Jan Dirk}, 84 | journal={Remote Sensing of Environment}, 85 | volume={268}, 86 | pages={112760}, 87 | year={2022}, 88 | publisher={Elsevier} 89 | } 90 | ``` -------------------------------------------------------------------------------- /torch_code/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | np.random.seed(2401) 4 | 5 | 6 | def get_quality_indices(data_crossover, pearson_thresh=0.95): 7 | """ 8 | 9 | Args: 10 | data_crossover: dict with arrays 11 | pearson_thresh: minimum threshold on pearson correlation 12 | 13 | Returns: bool array flagging quality samples 14 | 15 | """ 16 | quality = ~np.isnan(data_crossover['ground_elev_cog']) & \ 17 | (data_crossover['dz_pearson'] > 0.9) & \ 18 | (data_crossover['dz_count'] > 5) & \ 19 | (data_crossover['pearson'] > pearson_thresh) 20 | return quality 21 | 22 | 23 | def filter_shots(data, night_strong=True, day_strong=True, night_coverage=True, day_coverage=True): 24 | """ 25 | Get indices to create subsets of samples from specific noise level groups. 26 | Day vs. night and strong (full power beams) vs. coverage beams. 
27 | 28 | Args: 29 | data: dict with data arrays 30 | night_strong: bool, return samples from this group if True 31 | day_strong: bool, return samples from this group if True 32 | night_coverage: bool, return samples from this group if True 33 | day_coverage: bool, return samples from this group if True 34 | 35 | Returns: 36 | valid_indices: bool array 37 | out_str: string to identify the settings 38 | """ 39 | night_indices = data['solar_elevation'] < 0 40 | day_indices = ~night_indices 41 | coverage_indices = data['coverage_flag'] == 1 42 | strong_indices = ~coverage_indices 43 | 44 | out_str = '' 45 | 46 | # init: all points are invalid 47 | valid_indices = np.repeat(0, repeats=len(data['shot_number'])) 48 | 49 | if night_strong: 50 | indices = np.logical_and(night_indices, strong_indices) 51 | valid_indices = np.logical_or(valid_indices, indices) 52 | out_str += 'night-strong' 53 | if day_strong: 54 | indices = np.logical_and(day_indices, strong_indices) 55 | valid_indices = np.logical_or(valid_indices, indices) 56 | out_str += '_day-strong' 57 | if night_coverage: 58 | indices = np.logical_and(night_indices, coverage_indices) 59 | valid_indices = np.logical_or(valid_indices, indices) 60 | out_str += '_night-coverage' 61 | if day_coverage: 62 | indices = np.logical_and(day_indices, coverage_indices) 63 | valid_indices = np.logical_or(valid_indices, indices) 64 | out_str += '_day-coverage' 65 | 66 | return valid_indices, out_str 67 | 68 | 69 | # attention: this function returns bool arrays to filter the data (not indices) 70 | def filter_subset_indices_by_target_range(subset_attribute, range_to_remove): 71 | """ 72 | Returns bool arrays to create subsets that are within and outside the specified target range 73 | 74 | Note: While this function returns bool arrays, the pendant 75 | Trainer.filter_subset_indices_by_attribute_range (in run.py) returns indices. 
76 | 77 | Args: 78 | subset_attribute: array of attribute which is used for splitting 79 | range_to_remove: tuple (min_value, max_value) 80 | 81 | Returns: 82 | in_dist_indices: bool array 83 | out_dist_indices: bool array 84 | """ 85 | range_indices = (subset_attribute > range_to_remove[0]) & (subset_attribute < range_to_remove[1]) 86 | print(len(range_indices)) 87 | in_dist_indices = ~range_indices 88 | out_dist_indices = range_indices 89 | return in_dist_indices, out_dist_indices 90 | 91 | 92 | def filter_subset_indices_by_attribute(subset_attribute, out_dist_value): 93 | """ 94 | Returns bool arrays to create subsets that have a specific attribute value as out_dist_indices and 95 | the remaining samples as in_dist_indices. 96 | 97 | Note: While this function returns bool arrays, the pendant 98 | Trainer.filter_subset_indices_by_attribute (in run.py) returns indices. 99 | 100 | Args: 101 | subset_attribute: array of attribute which is used for splitting 102 | out_dist_value: attribute value considered to be out of distribution 103 | 104 | Returns: 105 | in_dist_indices: bool array 106 | out_dist_indices: bool array 107 | """ 108 | in_dist_indices = subset_attribute != out_dist_value 109 | out_dist_indices = subset_attribute == out_dist_value 110 | return in_dist_indices, out_dist_indices 111 | 112 | -------------------------------------------------------------------------------- /torch_code/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy import ndimage 4 | 5 | 6 | def pad_waveforms(waveforms, out_size=1420, axis=1, value=0): 7 | """ 8 | Pads the waveforms on the right side to out_size (this preserves the georeference of the first returns) 9 | :param waveforms: array (N, waveform_length) with N samples 10 | :param out_size: expected waveform length 11 | :param axis: int, axis along to pad 12 | :param value: constant value used for padding 13 | :return: 
array (N, out_size) padded waveforms 14 | """ 15 | 16 | if waveforms.shape[axis] == out_size: 17 | return waveforms 18 | else: 19 | waveform_length = waveforms.shape[axis] 20 | pad_after = out_size-waveform_length 21 | pad_widths = [(0, 0) for i in range(len(waveforms.shape))] 22 | pad_widths[axis] = (0, pad_after) 23 | padded_waveforms = np.pad(waveforms, pad_width=pad_widths, mode='constant', constant_values=value) 24 | return padded_waveforms 25 | 26 | 27 | def denormalize(x, mean, std): 28 | x = x * std 29 | x = x + mean 30 | return x 31 | 32 | 33 | class ToTensor(object): 34 | """ 35 | Turn numpy array into torch tensor with CxW 36 | """ 37 | 38 | def __call__(self, x): 39 | x = torch.from_numpy(x) 40 | x = x.permute((1, 0)).contiguous() 41 | return x 42 | 43 | 44 | class Normalize(object): 45 | """normalize the input tensor with training mean and std. 46 | """ 47 | 48 | def __init__(self, mean, std): 49 | """ 50 | :param mean: scalar (float) 51 | :param std: scalar (float) 52 | """ 53 | self.mean = mean 54 | self.std = std 55 | 56 | def __call__(self, x): 57 | x = x - self.mean 58 | x = x / self.std 59 | return x 60 | 61 | 62 | # -------- DATA AUGMENTATION -------- 63 | class RandomShift(object): 64 | """ 65 | Randomly shift the waveform by a fraction uniformly sampled from [-shift, shift]. 66 | """ 67 | 68 | def __init__(self, shift_interval, do_shift_target): 69 | """ 70 | :param shift_interval: tuple (float, float) i.e. (-shift_left, shift_right) each in the range of [0, 1] 71 | :param do_shift_target: bool True: adjust the target elevation (e.g. 
delta = Z0 - ground) False: keep the target 72 | """ 73 | self.shift_interval = shift_interval 74 | self.elev_resolution = 0.15 # GEDI waveforms have a resolution of 0.15 m between two returns 75 | self.do_shift_target = do_shift_target 76 | 77 | def __call__(self, sample): 78 | x, y = sample 79 | if self.shift_interval: 80 | # TODO: change to torch.rand 81 | rel_shift = np.random.uniform(self.shift_interval[0], self.shift_interval[1]) 82 | abs_shift = int(rel_shift * len(x)) 83 | x = ndimage.interpolation.shift(x, abs_shift, mode='nearest') 84 | 85 | # this is used to update the elevation target (delta = Z0 - ground) 86 | if self.do_shift_target: 87 | elev_shift = abs_shift * self.elev_resolution 88 | y = y + elev_shift 89 | return x, y 90 | 91 | 92 | class RandomLabelNoise(object): 93 | """ 94 | Add random noise to the target variable. 95 | """ 96 | 97 | def __init__(self, rel_label_noise, distribution='uniform'): 98 | """ 99 | :param rel_label_noise: scalar (float) fraction of label [0, 1]. 100 | :param distribution: string choices= ['uniform', 'normal']. If normal, then rel_label_noise defines the std 101 | """ 102 | self.rel_label_noise = rel_label_noise 103 | self.distribution = distribution 104 | 105 | def __call__(self, sample): 106 | x, y = sample 107 | if self.rel_label_noise: 108 | if self.distribution == 'uniform': 109 | # TODO: change to torch.rand 110 | rel_noise = np.random.uniform(-self.rel_label_noise, self.rel_label_noise) 111 | elif self.distribution == 'normal': 112 | # TODO: change to torch.rand 113 | rel_noise = np.random.normal(loc=0, scale=self.rel_label_noise) 114 | else: 115 | raise ValueError("This distribution is not impelmented. 
Use 'uniform' or 'normal'.") 116 | abs_noise = rel_noise * y 117 | y = y + abs_noise 118 | return x, y 119 | -------------------------------------------------------------------------------- /torch_code/utils/plots.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | 6 | CMAP = matplotlib.cm.get_cmap('cividis') 7 | HIST_COLOR = CMAP(0.6) 8 | 9 | 10 | def plot_hist2d(x, y, ma, step, out_dir=None, figsize=(8, 6), xlabel='Ground truth [m]', ylabel='Prediction [m]', 11 | usetex=False, fontsize=18, bins=None, vmax=None, cmap='cividis'): 12 | """ 13 | Confusion plot ground truth values vs. predicted values. 14 | """ 15 | if bins is None: 16 | edges = np.linspace(0, ma, 100) 17 | bins = (edges, edges) 18 | 19 | fig = plt.figure(figsize=figsize) 20 | h = plt.hist2d(x.squeeze(), y.squeeze(), bins=bins, cmap=cmap, 21 | norm=matplotlib.colors.LogNorm(), vmax=vmax, rasterized=True) 22 | plt.colorbar(h[3], label='Number of samples') 23 | plt.xlabel(xlabel) 24 | plt.ylabel(ylabel) 25 | plt.grid() 26 | plt.axis('equal') 27 | plt.plot([0, ma], [0, ma], 'k--') 28 | plt.xlim((0, ma)) 29 | plt.ylim((0, ma)) 30 | ticks = np.arange(0, ma + step, step) 31 | plt.xticks(ticks) 32 | plt.yticks(ticks) 33 | plt.tight_layout() 34 | if out_dir: 35 | fig.savefig(fname=os.path.join(out_dir, 'confusion.png'), dpi=300) 36 | return fig 37 | 38 | 39 | def plot_precision_recall(predictions, targets, uncertainties, metric='RMSE', 40 | ax=None, figsize=(8, 6), ylabel='RMSE', label=None, style=None, out_dir=None): 41 | # compute errors 42 | errors = predictions - targets 43 | 44 | # sort data 45 | sorted_inds = uncertainties.argsort() 46 | uncertainties = uncertainties[sorted_inds] 47 | errors = errors[sorted_inds] 48 | 49 | precision, recall = [], [] 50 | 51 | num_total = len(errors) 52 | for i in range(num_total): 53 | if metric == 'RMSE': 54 | 
precision.append(np.sqrt(np.mean(errors[0:i] ** 2))) 55 | elif metric == 'MSE': 56 | precision.append(np.mean(errors[0:i] ** 2)) 57 | 58 | recall.append(i / num_total) 59 | 60 | if ax is None: 61 | fig = plt.figure(figsize=figsize) 62 | ax = fig.gca() 63 | if style is not None: 64 | ax.plot(recall, precision, style, label=label) 65 | else: 66 | ax.plot(recall, precision, label=label) 67 | 68 | ax.set_xlabel('Recall') 69 | ax.set_ylabel(ylabel) 70 | 71 | if out_dir: 72 | fig.savefig(fname=os.path.join(out_dir, 'precision_recall_curve.png'), dpi=300) 73 | 74 | return ax 75 | 76 | 77 | def plot_calibration(predictions, targets, uncertainties, min_bin_count=10, metric='RMSE', bins=None, step=None, 78 | style='k-o', ax=None, figsize=(8, 6), xlabel='STD', ylabel='RMSE', out_dir=None): 79 | """ 80 | uncertainties: must be standard deviations -> will be squared for MSE. 81 | """ 82 | color_ax = CMAP(0.05) 83 | color_ax2 = HIST_COLOR 84 | 85 | if bins is None: 86 | ma = np.max(uncertainties) 87 | step = ma / 10 88 | bins = np.arange(0, ma + step, step) 89 | print(bins) 90 | else: 91 | ma = np.max(bins) 92 | 93 | # bin data 94 | bin_indices = np.digitize(x=uncertainties, bins=bins, right=True) 95 | 96 | errors = predictions - targets 97 | 98 | error_binned, uncertainty_binned, num_binned = [], [], [] 99 | 100 | for idx in np.arange(len(bins)) + 1: 101 | bin_i_indices = bin_indices == idx 102 | 103 | if metric == 'RMSE': 104 | # average the estimated uncertainty per bin (var or std) 105 | uncertainty_binned.append(np.sqrt(np.mean(uncertainties[bin_i_indices] ** 2))) 106 | # average the respective error metric per bin 107 | error_binned.append(np.sqrt(np.mean(errors[bin_i_indices] ** 2))) 108 | 109 | elif metric == 'MSE': 110 | # average the estimated uncertainty per bin (var or std) 111 | uncertainty_binned.append(np.mean(uncertainties[bin_i_indices] ** 2)) 112 | # average the respective error metric per bin 113 | error_binned.append(np.mean(errors[bin_i_indices] ** 2)) 114 | 
115 | num_binned.append(np.sum(bin_i_indices)) 116 | 117 | # convert to numpy 118 | error_binned = np.array(error_binned) 119 | uncertainty_binned = np.array(uncertainty_binned) 120 | num_binned = np.array(num_binned) 121 | 122 | # remove estimates where the number of samples per bin is to small 123 | error_binned[num_binned < min_bin_count] = np.nan 124 | uncertainty_binned[num_binned < min_bin_count] = np.nan 125 | 126 | if ax is None: 127 | # fig, ax = plt.subplots(figsize=figsize) 128 | fig = plt.figure(figsize=figsize) 129 | ax = fig.gca() 130 | ax.plot(uncertainty_binned, error_binned, style, zorder=2, color=color_ax) 131 | print('x: ', uncertainty_binned) 132 | print('y: ', error_binned) 133 | print('count: ', num_binned) 134 | 135 | ax.set_xlim(0, ma) 136 | ax.set_ylim(0, ma) 137 | 138 | # perfect calibration line 139 | ax.plot(ax.get_xlim(), ax.get_xlim(), "k--", zorder=1, alpha=0.5) 140 | 141 | ax2 = ax.twinx() # instantiate a second axes that shares the same x-axis 142 | ax2.bar(bins, num_binned, align='edge', width=step, zorder=0, edgecolor='white', color=color_ax2) 143 | 144 | ax.set_xlabel(xlabel) 145 | ax.set_ylabel(ylabel, color=color_ax) 146 | ax2.set_ylabel('Number of samples', color=color_ax2) 147 | ax2.tick_params(axis='y', labelcolor=color_ax2) 148 | 149 | ax.set_zorder(ax2.get_zorder() + 1) 150 | ax.patch.set_visible(False) 151 | 152 | ax.tick_params(axis='y', labelcolor=color_ax) 153 | 154 | # ax.set_aspect("equal") 155 | fig.tight_layout() 156 | 157 | if out_dir: 158 | fig.savefig(fname=os.path.join(out_dir, 'calibration.png'), dpi=300) 159 | 160 | return fig, ax 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /torch_code/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import time 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torchvision.transforms import 
Compose 8 | from utils.parser import setup_parser 9 | from dataset import GediDataOrbitMem 10 | from tqdm import tqdm 11 | 12 | from utils.transforms import Normalize, ToTensor, denormalize 13 | from models.models import Models 14 | 15 | 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | def predict(model, dl_pred, mean_target_train, std_target_train): 20 | """ 21 | Predict with trained model. 22 | Args: 23 | model: torch model (trained weights must be loaded) 24 | dl_pred: torch dataloader with (unlabeled) data to predict. 25 | mean_target_train: train mean of target variable (to denormalize predictions) 26 | std_target_train: train standard deviation (std) of target variable (to denormalize predictions) 27 | 28 | Returns: Dict with torch tensors 'predictions', 'variances' in original scale (denormalized). 29 | 30 | """ 31 | 32 | # init validation results for current epoch 33 | out_dict = {'predictions': [], 'log_variances': []} 34 | 35 | with torch.no_grad(): 36 | for step, (inputs) in enumerate(tqdm(dl_pred, ncols=100, desc='pred')): 37 | 38 | inputs = inputs.to(device) 39 | 40 | predictions = model.forward(inputs).squeeze(dim=-1) 41 | predictions, log_variances = predictions[:, 0], predictions[:, 1] 42 | 43 | out_dict['predictions'] += list(predictions) 44 | out_dict['log_variances'] += list(log_variances) 45 | 46 | for key in out_dict.keys(): 47 | if out_dict[key]: 48 | out_dict[key] = torch.stack(out_dict[key], dim=0) 49 | print("out_dict['{}'].shape: ".format(key), out_dict[key].shape) 50 | 51 | # compute variance from log_variance 52 | out_dict['variances'] = torch.exp(out_dict['log_variances']) 53 | del out_dict['log_variances'] 54 | 55 | # convert torch tensor to numpy 56 | for key in out_dict.keys(): 57 | out_dict[key] = out_dict[key].data.cpu().numpy() 58 | 59 | # denormalize model outputs 60 | out_dict['predictions'] = denormalize(out_dict['predictions'], mean_target_train, std_target_train) 61 | 
out_dict['variances'] = out_dict['variances'] * std_target_train ** 2 62 | 63 | return out_dict 64 | 65 | 66 | if __name__ == "__main__": 67 | 68 | # set parameters 69 | parser = setup_parser() 70 | args, unknown = parser.parse_known_args() 71 | 72 | if args.file_path_L2A is not None: 73 | if not os.path.exists(args.file_path_L2A): 74 | raise ValueError('L2A file does not exists: {}'.format(args.file_path_L2A)) 75 | 76 | start_time_loading = time.time() 77 | 78 | # load input mean and std 79 | mean_input_train = np.load(os.path.join(args.ensemble_dir, 'model_0', 'mean_input_train.npy')) 80 | std_input_train = np.load(os.path.join(args.ensemble_dir, 'model_0', 'std_input_train.npy')) 81 | 82 | # load target mean and std (for denormalizing the predictions) 83 | mean_target_train = np.load(os.path.join(args.ensemble_dir, 'model_0', 'mean_target_train.npy')) 84 | std_target_train = np.load(os.path.join(args.ensemble_dir, 'model_0', 'std_target_train.npy')) 85 | 86 | # setupt preprocessing input_transforms 87 | input_transforms = Compose([Normalize(mean=mean_input_train, std=std_input_train), ToTensor()]) 88 | 89 | # create dataset 90 | ds_orbit = GediDataOrbitMem(args.file_path_L1B, args.file_path_L2A, sample_length=1420, 91 | input_transforms=input_transforms, noise_mean_key=args.noise_mean_key) 92 | 93 | print('orbit data loaded.') 94 | end_time_loading = time.time() 95 | duration_loading = end_time_loading - start_time_loading 96 | 97 | start_time_predicting = time.time() 98 | 99 | # create dataloader 100 | dl_pred = DataLoader(ds_orbit, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 101 | 102 | # load model architecture 103 | # Setup model 104 | model_chooser = Models() 105 | model = model_chooser(args.model_name)(num_outputs=args.num_outputs) 106 | model.to(device) 107 | 108 | # set model to eval model 109 | model.eval() 110 | 111 | # initialize predictions 112 | pred_means = [] 113 | pred_vars = [] 114 | 115 | # Loop through all models 
116 | for i in range(args.n_models): 117 | print('predicting with model_{}'.format(i)) 118 | # get model_dir_i 119 | model_dir_i = os.path.join(args.ensemble_dir, 'model_{}'.format(i)) 120 | 121 | # load weights 122 | model_weights_path = os.path.join(model_dir_i, 'best_weights.pt') 123 | model.load_state_dict(torch.load(model_weights_path)) 124 | 125 | # predict 126 | out_dict_i = predict(model, dl_pred, mean_target_train, std_target_train) 127 | 128 | pred_means.append(out_dict_i['predictions']) 129 | pred_vars.append(out_dict_i['variances']) 130 | 131 | # convert to np array 132 | pred_means = np.array(pred_means, dtype=np.float32) 133 | pred_vars = np.array(pred_vars, dtype=np.float32) 134 | print('pred_means.shape', pred_means.shape) 135 | 136 | out_dict = {} 137 | 138 | # compute mean across ensemble mean predictions 139 | out_dict['pred_ensemble'] = np.mean(pred_means, axis=0) 140 | 141 | # compute variances 142 | epistemic_var = np.var(pred_means, axis=0) 143 | aleatoric_var = np.mean(pred_vars, axis=0) 144 | predictive_var = epistemic_var + aleatoric_var 145 | 146 | out_dict['predictive_std'] = np.sqrt(predictive_var) 147 | out_dict['aleatoric_std'] = np.sqrt(aleatoric_var) 148 | out_dict['epistemic_std']= np.sqrt(epistemic_var) 149 | 150 | end_time_predicting = time.time() 151 | duration_predicting = end_time_predicting - start_time_predicting 152 | 153 | # copy gedi keys to out_dict 154 | if ds_orbit.data_L2A: 155 | gedi_keys = ['shot_number', 'lat_lowestmode', 'lon_lowestmode', 'modis_nonvegetated'] 156 | for key in gedi_keys: 157 | out_dict[key] = ds_orbit.data_L2A[key] 158 | else: 159 | out_dict['shot_number'] = ds_orbit.data_L1B['shot_number'] 160 | 161 | # output directory to save predictions 162 | if args.prediction_dir is None: 163 | args.prediction_dir = os.path.dirname(args.file_path_L1B).replace('/L1B', '/pred_RH98') 164 | 165 | if not os.path.exists(args.prediction_dir): 166 | os.makedirs(args.prediction_dir) 167 | 168 | # save as h5 file 169 
| out_path = os.path.join(args.prediction_dir, os.path.basename(args.file_path_L1B).replace('GEDI01_B', 'GEDI02_RH98')) 170 | print('writing to hdf5 file:') 171 | print(out_path) 172 | with h5py.File(out_path, 'w') as f: 173 | for key in out_dict.keys(): 174 | print(key, out_dict[key].shape) 175 | f.create_dataset(key, data=out_dict[key]) 176 | 177 | print('time loading data: ', time.strftime("%Hh%Mm%Ss", time.gmtime(duration_loading)) ) 178 | print('time predicting: ', time.strftime("%Hh%Mm%Ss", time.gmtime(duration_predicting)) ) 179 | print('DONE!') 180 | 181 | -------------------------------------------------------------------------------- /torch_code/utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | 5 | def setup_parser(): 6 | """ 7 | Setup parser with default settings. 8 | Returns: argparse.ArgumentParser() object 9 | """ 10 | 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument("--out_dir", default='./tmp/', help="output directory for the experiment") 14 | 15 | # dataset file 16 | parser.add_argument("--dataset", help="Dataset type.", default='CROSSOVER_GEDI', choices=['CROSSOVER_GEDI']) 17 | parser.add_argument("--inputs_path", help="path to h5 file with input and target arrays and additional attributes") 18 | parser.add_argument("--input_key", default='rxwaveform', help="input waveform", choices=['rxwaveform']) 19 | parser.add_argument("--target_key", default='als_rh098', help="target variable that is estimated.") 20 | 21 | # for simulated data (args.dataset == 'SIMULATED_GEDI') with numpy files e.g. input_key.npy target_key.npy 22 | parser.add_argument("--data_dir", help="path to directory with npy files") 23 | 24 | # dataset preprocessing and filtering 25 | parser.add_argument("--sample_length", default=1420, help="Waveform length. 
GEDI waveform max length is 1420.", type=int) 26 | parser.add_argument("--setting_idx", default=3, type=int, help="0: power-night, 1: power-night + power-day, 2: power-night + power-day + coverage-night, 3: all") 27 | parser.add_argument("--use_quality_flag", type=str2bool, nargs='?', const=True, default=True, help="True: only use samples with quality_flag==1") 28 | parser.add_argument("--min_gt", default=-np.inf, help="Filter target range: Keep samples >= min_gt", type=float) 29 | parser.add_argument("--max_gt", default=np.inf, help="Filter target range: Keep samples <= max_gt", type=float) 30 | parser.add_argument("--pearson_thresh", default=0.95, help="scalar (float) [0,1], quality criteria to filter data ", type=float) 31 | parser.add_argument("--noise_mean_key", default='noise_mean_corrected', help="noise mean key") 32 | parser.add_argument("--normalize_targets", type=str2bool, nargs='?', const=True, default=True, help="normalize target labels with mean and std") 33 | 34 | # model architecture 35 | parser.add_argument("--model_name", help="model names (functions) defined in models.py", default='SimpleResNet_8blocks') 36 | parser.add_argument("--num_outputs", default=2, help="Number of outputs. Set to 2 for regressing mean and variance.") 37 | 38 | # training params 39 | parser.add_argument("--skip_training", type=str2bool, nargs='?', const=True, default=False, help="do not optimize parameters (i.e. run test only)") 40 | parser.add_argument("--num_workers", default=8, help="Number of workers for pytorch Dataloader") 41 | parser.add_argument("--loss_key", default='gaussian_nll', help="Loss keys", choices=['MSE', 'MAE', 'gaussian_nll', 'laplacian_nll']) 42 | parser.add_argument("--batch_size", default=64, help="batch size at train/val time. 
(number of samples per iteration)", type=int) 43 | parser.add_argument("--nb_epoch", default=200, help="number of epochs to train", type=int) 44 | parser.add_argument("--base_learning_rate", default=0.0001, help="initial learning rate", type=float) 45 | parser.add_argument("--l2_lambda", default=0.0, help="L2 regularizer on weights hyperparameter", type=float) 46 | parser.add_argument("--optimizer", default='ADAM', help="optimizer name", choices=['ADAM', 'SGD']) 47 | parser.add_argument("--momentum", default=0.0, help="momentum for SGD ", type=float) 48 | 49 | # data augmentation for training 50 | parser.add_argument("--shift_left", default=None, help="Augmentation: scalar (float) [0,1], relative shift w.r.t waveform length", type=str2none) 51 | parser.add_argument("--shift_right", default=None, help="Augmentation: scalar (float) [0,1], relative shift w.r.t waveform length", type=str2none) 52 | 53 | # additional augmentation at train time for robustness tests. NOTE: currently not implemented see todos run.py Trainer class 54 | parser.add_argument("--label_noise", default=0.0, help="scalar (float) [0,1], relative label noise ", type=float) 55 | parser.add_argument("--label_noise_distribution", default='uniform', help="label noise distribution", choices=['uniform', 'normal']) 56 | 57 | # data splits and generalization experiment settings 58 | parser.add_argument("--data_split", default='randCV', help="attribute_name used to split data into train/test", choices=['randCV', 'attrCV']) 59 | parser.add_argument("--n_folds", default=10, help="Number of folds. ", type=int) 60 | parser.add_argument("--test_fold_idx", default=0, help="Fold split index that is used for testing. 
", type=int) 61 | parser.add_argument("--range_to_remove", default=None, nargs='+', type=float, help="if data_split=='randCV': removes a target range (min max) to evaluate the epistemic uncertainty") 62 | parser.add_argument("--ood_attribute", default=None, help="if data_split=='randCV': attribute_name used to remove OOD data e.g. 'continental_region_1km' ") 63 | parser.add_argument("--ood_value", default=None, help="if data_split=='randCV': attribute value float defined as OOD (will be only in test data)", type=float) 64 | parser.add_argument("--ood_value_string", default=None, help="if data_split=='randCV': attribute value string defined as OOD (will be only in test data)") 65 | parser.add_argument("--split_attribute", default=None, help="if data_split=='attrCV': attribute_name used to split data into train/test") 66 | parser.add_argument("--test_attribute_value", default=None, nargs='+', type=float, help="if data_split=='attrCV': attribute value to hold-out for testing") 67 | parser.add_argument("--model_weights_path", help="Pre-trained model weights path (e.g. weights_best.h5) ") 68 | 69 | # test params 70 | parser.add_argument("--model_dir", help="Model directory with weights_best.h5, train_mean.npy, train_std.npy.") 71 | 72 | # prediction params 73 | parser.add_argument("--ensemble_dir", help="path to directory with subdirectories called model_i") 74 | parser.add_argument("--file_path_L1B", help="path to L1B h5 file for prediction.") 75 | parser.add_argument("--file_path_L2A", help="path to L2A h5 file for prediction.", type=str_or_none) 76 | parser.add_argument("--n_models", default=10, help="Number of models in the ensemble. 
", type=int) 77 | parser.add_argument("--prediction_dir", default=None, help="output directory for the predictions") 78 | 79 | return parser 80 | 81 | 82 | # --- Helper functions to parse arguments --- 83 | 84 | def str2bool(v): 85 | if isinstance(v, bool): 86 | return v 87 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 88 | return True 89 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 90 | return False 91 | else: 92 | raise argparse.ArgumentTypeError('Boolean value expected.') 93 | 94 | 95 | def str2none(v): 96 | if v.lower() in ('none', '', 'nan', '0', '0.0'): 97 | return None 98 | else: 99 | return float(v) 100 | 101 | 102 | def str_or_none(v): 103 | if v.lower() in ('none', '', 'nan', '0', '0.0'): 104 | return None 105 | else: 106 | return str(v) 107 | 108 | class StoreAsArray(argparse._StoreAction): 109 | def __call__(self, parser, namespace, values, option_string=None): 110 | values = np.array(values) 111 | return super(StoreAsArray, self).__call__(parser, namespace, values, option_string) 112 | -------------------------------------------------------------------------------- /torch_code/models/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .resnet import ResNet, BasicBlock, Bottleneck, SimpleResNet 3 | from .simplecnn import SimpleCNN 4 | 5 | 6 | class Models: 7 | """ 8 | This is a wrapper class that defines several model functions. 9 | The model functions are called by passing the function name as a string. 10 | All model functions take the number of outputs (num_outputs) as an argument (int). 
11 | 12 | Example: 13 | model_chooser = Models() 14 | model = model_chooser('SimpleResNet_8blocks')(num_outputs=1) 15 | 16 | """ 17 | def __init__(self): 18 | pass 19 | 20 | def __call__(self, func): 21 | """ 22 | Args: 23 | func: function name as string 24 | 25 | Returns: corresponding model function 26 | """ 27 | return getattr(self, func) 28 | 29 | # SIMPLE CNN 30 | 31 | def SimpleCNN_8layers(self, num_outputs): 32 | return SimpleCNN(out_features=(16, 32, 64, 128, 128, 128, 128, 128), 33 | num_outputs=num_outputs, 34 | global_pool=nn.AdaptiveAvgPool1d) 35 | 36 | # ---- WIDER VERSIONS ---- 37 | 38 | # double the number of features in each layer w.r.t. SimpleCNN_8layers 39 | def SimpleCNN_8layers_width2(self, num_outputs): 40 | return SimpleCNN(out_features=(32, 64, 128, 256, 256, 256, 256, 256), 41 | num_outputs=num_outputs, 42 | global_pool=nn.AdaptiveAvgPool1d) 43 | 44 | # The number of features in each layer w.r.t. SimpleCNN_8layers multiplied by factor 4 45 | def SimpleCNN_8layers_width4(self, num_outputs): 46 | return SimpleCNN(out_features=(64, 128, 256, 512, 512, 512, 512, 512), 47 | num_outputs=num_outputs, 48 | global_pool=nn.AdaptiveAvgPool1d) 49 | 50 | # The number of features in each layer w.r.t. 
SimpleCNN_8layers multiplied by factor 8 51 | def SimpleCNN_8layers_width8(self, num_outputs): 52 | return SimpleCNN(out_features=(128, 256, 512, 1024, 1024, 1024, 1024, 1024), 53 | num_outputs=num_outputs, 54 | global_pool=nn.AdaptiveAvgPool1d) 55 | 56 | # ---- DEEPER VERSIONS ---- 57 | 58 | def SimpleCNN_12layers(self, num_outputs): 59 | return SimpleCNN(out_features=(16, 32, 64, 128, 128, 128, 128, 128, 128, 128, 128, 128), 60 | num_outputs=num_outputs, 61 | global_pool=nn.AdaptiveAvgPool1d) 62 | 63 | def SimpleCNN_16layers(self, num_outputs): 64 | return SimpleCNN(out_features=(16, 32, 64, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128), 65 | num_outputs=num_outputs, 66 | global_pool=nn.AdaptiveAvgPool1d) 67 | 68 | # SIMPLE RESNET 69 | 70 | def SimpleResNet_8blocks(self, num_outputs): 71 | return SimpleResNet(block=BasicBlock, 72 | in_features=1, 73 | out_features=(16, 32, 64, 128, 128, 128, 128, 128), 74 | num_outputs=num_outputs, 75 | global_pool=nn.AdaptiveAvgPool1d, 76 | max_pool=True) 77 | 78 | def SimpleResNet_8blocks_width2(self, num_outputs): 79 | return SimpleResNet(block=BasicBlock, 80 | in_features=1, 81 | out_features=(32, 64, 128, 256, 256, 256, 256, 256), 82 | num_outputs=num_outputs, 83 | global_pool=nn.AdaptiveAvgPool1d, 84 | max_pool=True) 85 | 86 | def SimpleResNet_8blocks_width4(self, num_outputs): 87 | return SimpleResNet(block=BasicBlock, 88 | in_features=1, 89 | out_features=(64, 128, 256, 512, 512, 512, 512, 512), 90 | num_outputs=num_outputs, 91 | global_pool=nn.AdaptiveAvgPool1d, 92 | max_pool=True) 93 | 94 | # ---- DEEPER VERSIONS ---- 95 | 96 | def SimpleResNet_4blocks(self, num_outputs): 97 | return SimpleResNet(block=BasicBlock, 98 | in_features=1, 99 | out_features=(16, 32, 64, 128), 100 | num_outputs=num_outputs, 101 | global_pool=nn.AdaptiveAvgPool1d, 102 | max_pool=True) 103 | 104 | def SimpleResNet_6blocks(self, num_outputs): 105 | return SimpleResNet(block=BasicBlock, 106 | in_features=1, 107 | 
out_features=(16, 32, 64, 128, 128, 128), 108 | num_outputs=num_outputs, 109 | global_pool=nn.AdaptiveAvgPool1d, 110 | max_pool=True) 111 | 112 | def SimpleResNet_12blocks(self, num_outputs): 113 | return SimpleResNet(block=BasicBlock, 114 | in_features=1, 115 | out_features=(16, 32, 64, 128, 128, 128, 128, 128, 128, 128, 128, 128), 116 | num_outputs=num_outputs, 117 | global_pool=nn.AdaptiveAvgPool1d, 118 | max_pool=True) 119 | 120 | def SimpleResNet_16blocks(self, num_outputs): 121 | return SimpleResNet(block=BasicBlock, 122 | in_features=1, 123 | out_features=(16, 32, 64, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128), 124 | num_outputs=num_outputs, 125 | global_pool=nn.AdaptiveAvgPool1d, 126 | max_pool=True) 127 | 128 | # 1D-RESNET 129 | def _resnet(self, block, layers, **kwargs): 130 | model = ResNet(block, layers, **kwargs) 131 | return model 132 | 133 | def resnet10(self, **kwargs): 134 | r"""ResNet-10 model, a smaller adaptation of the ResNet-18 model from 135 | `"Deep Residual Learning for Image Recognition" `_ 136 | """ 137 | return self._resnet(BasicBlock, [1, 1, 1, 1], **kwargs) 138 | 139 | def resnet18(self, **kwargs): 140 | r"""ResNet-18 model from 141 | `"Deep Residual Learning for Image Recognition" `_ 142 | """ 143 | return self._resnet(BasicBlock, [2, 2, 2, 2], **kwargs) 144 | 145 | def resnet34(self, **kwargs): 146 | r"""ResNet-34 model from 147 | `"Deep Residual Learning for Image Recognition" `_ 148 | """ 149 | return self._resnet(BasicBlock, [3, 4, 6, 3], **kwargs) 150 | 151 | def resnet50(self, **kwargs): 152 | r"""ResNet-50 model from 153 | `"Deep Residual Learning for Image Recognition" `_ 154 | """ 155 | return self._resnet(Bottleneck, [3, 4, 6, 3], **kwargs) 156 | 157 | def resnet101(self, **kwargs): 158 | r"""ResNet-101 model from 159 | `"Deep Residual Learning for Image Recognition" `_ 160 | """ 161 | return self._resnet(Bottleneck, [3, 4, 23, 3], **kwargs) 162 | 163 | def resnet152(self, **kwargs): 164 | 
r"""ResNet-152 model from 165 | `"Deep Residual Learning for Image Recognition" `_ 166 | """ 167 | return self._resnet(Bottleneck, [3, 8, 36, 3], **kwargs) 168 | 169 | def resnext50_32x4d(self, **kwargs): 170 | r"""ResNeXt-50 32x4d model from 171 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 172 | """ 173 | kwargs['groups'] = 32 174 | kwargs['width_per_group'] = 4 175 | return self._resnet(Bottleneck, [3, 4, 6, 3], **kwargs) 176 | 177 | def resnext101_32x8d(self, **kwargs): 178 | r"""ResNeXt-101 32x8d model from 179 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 180 | """ 181 | kwargs['groups'] = 32 182 | kwargs['width_per_group'] = 8 183 | return self._resnet(Bottleneck, [3, 4, 23, 3], **kwargs) 184 | 185 | def wide_resnet50_2(self, **kwargs): 186 | r"""Wide ResNet-50-2 model from 187 | `"Wide Residual Networks" `_ 188 | The model is the same as ResNet except for the bottleneck number of channels 189 | which is twice larger in every block. The number of channels in outer 1x1 190 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 191 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 192 | """ 193 | kwargs['width_per_group'] = 64 * 2 194 | return self._resnet(Bottleneck, [3, 4, 6, 3], **kwargs) 195 | 196 | def wide_resnet101_2(self, **kwargs): 197 | r"""Wide ResNet-101-2 model from 198 | `"Wide Residual Networks" `_ 199 | The model is the same as ResNet except for the bottleneck number of channels 200 | which is twice larger in every block. The number of channels in outer 1x1 201 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 202 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
203 | Args: 204 | pretrained (bool): If True, returns a model pre-trained on ImageNet 205 | progress (bool): If True, displays a progress bar of the download to stderr 206 | """ 207 | kwargs['width_per_group'] = 64 * 2 208 | return self._resnet(Bottleneck, [3, 4, 23, 3], **kwargs) 209 | -------------------------------------------------------------------------------- /torch_code/collect_ensemble_preds.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collect ensemble predictions and results from all cross-validation folds. 3 | 4 | This script creates: 5 | 6 | i) a new subdirectory in each testfold directory e.g. "/experiment_dir/testfold_XX/ensemble", 7 | containing the ensemble results for the particular test fold. 8 | 9 | ii) a new subdirectory in the experiment base directory e.g. "/experiment_dir/ensemble_collected", 10 | containing the collected ensemble predictions and results from all disjoint test folds. 11 | 12 | """ 13 | 14 | import numpy as np 15 | import json 16 | import os 17 | import sys 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | from utils.error_metrics import get_metrics_dict 22 | from utils.plots import plot_hist2d, plot_calibration, plot_precision_recall 23 | 24 | 25 | def collect_CV_ensemble_predictions(experiment_dir_base, collected_dir, n_folds, n_models, 26 | metrics_dict_fun, model_indices_to_exclude=None, filter_negative_preds=True): 27 | 28 | # Add testfold subdir template 29 | experiment_dir_base = os.path.join(experiment_dir_base, 'testfold_{}') 30 | 31 | # ## Collect results from cross-validation 32 | results = {'pred_ensemble': [], 33 | 'targets': [], 34 | 'epistemic_std': [], 35 | 'aleatoric_std': [], 36 | 'predictive_std': [], 37 | 'test_indices': []} 38 | 39 | error_metrics_folds = {} 40 | for metric in metrics_dict_fun.keys(): 41 | error_metrics_folds[metric] = [] 42 | 43 | for fold_i in np.arange(n_folds): 44 | print('*****************************************************') 45 | 
print('fold: ', fold_i) 46 | experiment_dir = experiment_dir_base.format(fold_i) 47 | 48 | out_dir = os.path.join(experiment_dir, 'ensemble') 49 | if not os.path.exists(out_dir): 50 | os.makedirs(out_dir) 51 | 52 | #  init lists 53 | if model_indices_to_exclude is None: 54 | model_indices_to_exclude = [] 55 | 56 | pred_means, pred_var = [], [] 57 | 58 | for i in range(n_models): 59 | if i in model_indices_to_exclude: 60 | continue 61 | model_i_dir = os.path.join(experiment_dir, 'model_{}'.format(i)) 62 | pred_means.append(np.load(os.path.join(model_i_dir, 'predictions.npy'))) 63 | pred_var.append(np.load(os.path.join(model_i_dir, 'variances.npy'))) 64 | 65 | if i == 0: 66 | targets = np.load(os.path.join(model_i_dir, 'targets.npy')) 67 | test_indices = np.load(os.path.join(model_i_dir, 'test_indices.npy')) 68 | print('mean_target_train: ', np.load(os.path.join(model_i_dir, 'mean_target_train.npy'))) 69 | print('std_target_train: ', np.load(os.path.join(model_i_dir, 'std_target_train.npy'))) 70 | 71 | # convert the ensemble list to numpy 72 | pred_means = np.array(pred_means) 73 | pred_var = np.array(pred_var) 74 | 75 | # final predictions (average over model ensemble) 76 | pred_ensemble = np.mean(pred_means, axis=0) 77 | 78 | print(pred_means.shape) 79 | print(pred_var.shape) 80 | print(targets.shape) 81 | print(pred_ensemble.shape) 82 | 83 | print('pred_var: min: {}, max: {}'.format(np.min(pred_var), np.max(pred_var))) 84 | print('pred_ensemble: min: {}, max: {}'.format(np.min(pred_ensemble), np.max(pred_ensemble))) 85 | 86 | # save as npy 87 | np.save(os.path.join(out_dir, 'pred_means.npy'), pred_means) 88 | np.save(os.path.join(out_dir, 'pred_var.npy'), pred_var) 89 | np.save(os.path.join(out_dir, 'targets.npy'), targets) 90 | np.save(os.path.join(out_dir, 'pred_ensemble.npy'), pred_ensemble) 91 | 92 | # compute performance of ensemble 93 | if filter_negative_preds: 94 | # remove samples with negative pred_ensemble 95 | valid_indices = pred_ensemble >= 0 96 | 
else: 97 | # all samples are valid 98 | valid_indices = pred_ensemble == pred_ensemble 99 | 100 | x = targets[valid_indices] # ground truth 101 | y = pred_ensemble[valid_indices] # prediction 102 | 103 | for metric in metrics_dict_fun.keys(): 104 | print('{}: {:.1f}'.format(metric, metrics_dict_fun[metric](x, y))) 105 | 106 | error_metrics = {} 107 | for metric in metrics_dict_fun.keys(): 108 | error_metrics[metric] = metrics_dict_fun[metric](x, y) 109 | print(metric, error_metrics[metric]) 110 | 111 | with open(os.path.join(out_dir, 'error_metrics.json'), 'w') as f: 112 | json.dump(error_metrics, f) 113 | 114 | # collect error metrics for all folds 115 | for metric in error_metrics.keys(): 116 | error_metrics_folds[metric].append(error_metrics[metric]) 117 | 118 | # compute epistemic and aleatoric over model ensemble 119 | epistemic_var = np.var(pred_means, axis=0) 120 | aleatoric_var = np.mean(pred_var, axis=0) 121 | predictive_var = epistemic_var + aleatoric_var 122 | 123 | aleatoric_std = np.sqrt(aleatoric_var) 124 | epistemic_std = np.sqrt(epistemic_var) 125 | predictive_std = np.sqrt(predictive_var) 126 | 127 | print('epistemic_std.shape', epistemic_std.shape) 128 | print('aleatoric_std.shape', aleatoric_std.shape) 129 | print('predictive_std.shape', predictive_std.shape) 130 | 131 | # save as npy 132 | np.save(os.path.join(out_dir, 'epistemic_std.npy'), epistemic_std) 133 | np.save(os.path.join(out_dir, 'aleatoric_std.npy'), aleatoric_std) 134 | np.save(os.path.join(out_dir, 'predictive_std.npy'), predictive_std) 135 | 136 | results['pred_ensemble'].append(pred_ensemble) 137 | results['targets'].append(targets) 138 | results['epistemic_std'].append(epistemic_std) 139 | results['aleatoric_std'].append(aleatoric_std) 140 | results['predictive_std'].append(predictive_std) 141 | results['test_indices'].append(test_indices) 142 | 143 | #  concatenate to numpy array 144 | for key in results: 145 | results[key] = np.concatenate(results[key]) 146 | 147 | # errors 
folds to numpy 148 | for key in error_metrics_folds: 149 | error_metrics_folds[key] = np.array(error_metrics_folds[key]) 150 | 151 | for key in results: 152 | print(results[key].shape, results[key].dtype) 153 | 154 | for key in error_metrics_folds: 155 | print(key) 156 | print(error_metrics_folds[key]) 157 | 158 | # Save the collected folds (JOIN THE FOLDS) 159 | for key in results: 160 | np.save(os.path.join(collected_dir, '{}.npy'.format(key)), results[key]) 161 | 162 | np.save(os.path.join(collected_dir, 'error_metrics_folds.npy'), error_metrics_folds) 163 | 164 | # compute mean and std of error metrics over all 10 folds 165 | error_metrics = {} 166 | for key in error_metrics_folds: 167 | error_metrics[key + '_mean'] = np.mean(error_metrics_folds[key]) 168 | error_metrics[key + '_std'] = np.std(error_metrics_folds[key]) 169 | print(error_metrics) 170 | 171 | with open(os.path.join(collected_dir, 'error_metrics.json'), 'w') as f: 172 | json.dump(error_metrics, f) 173 | 174 | return results 175 | 176 | 177 | if __name__ == "__main__": 178 | 179 | # Set path to experiment base directory containing all test fold subdirectories 180 | experiment_dir_base = sys.argv[1] 181 | # experiment_dir_base = "demo_data/GEDI_BDL_demo/output_demo/" 182 | 183 | filter_negative_preds=True # remove negative height predictions for evaluation (negative predictions are still included in the collected data) 184 | 185 | n_folds = 10 186 | n_models = 10 187 | 188 | # make new subdirectory for collected predictions 189 | collected_dir = os.path.join(experiment_dir_base, 'ensemble_collected') 190 | if not os.path.exists(collected_dir): 191 | os.makedirs(collected_dir) 192 | 193 | print('collected_dir:') 194 | print(collected_dir) 195 | 196 | # get all metrics functions 197 | metrics_dict_fun = get_metrics_dict() 198 | 199 | # collect results 200 | results = collect_CV_ensemble_predictions(experiment_dir_base=experiment_dir_base, 201 | collected_dir=collected_dir, 202 | n_folds=n_folds, 203 | 
n_models=n_models, 204 | metrics_dict_fun=metrics_dict_fun, 205 | filter_negative_preds=filter_negative_preds) 206 | 207 | # ------ PLOT RESULTS ------ 208 | 209 | # -- plot confusion ground truth vs. prediction -- 210 | ma = 100 211 | plot_hist2d(x=results['targets'], y=results['pred_ensemble'], ma=ma, step=ma/10, 212 | out_dir=collected_dir, figsize=(8, 6), xlabel='Ground truth [m]', ylabel='Prediction [m]') 213 | 214 | # -- plot precision recall curve -- 215 | # (Note that here we do not use the adaptive thresholding described in the paper) 216 | metric = 'RMSE' 217 | fig = plt.figure(figsize=(8, 6)) 218 | ax = fig.gca() 219 | plot_precision_recall(predictions=results['pred_ensemble'], targets=results['targets'], 220 | uncertainties=results['predictive_std'], metric=metric, 221 | ax=ax, label=None, ylabel='RMSE [m]', style='k-') 222 | plt.grid() 223 | plt.legend(loc='upper left') 224 | plt.xlim(0.5, 1.0) 225 | plt.tight_layout() 226 | fig.savefig(fname=os.path.join(collected_dir, 'PR_curves_RMSE_predictive_std.png'), dpi=300) 227 | 228 | # -- plot calibration -- 229 | step = 1 230 | bins = np.arange(0, 16, step) 231 | xticks = np.arange(0, 16, 2) 232 | min_bin_count = 200 233 | fig, ax = plot_calibration(predictions=results['pred_ensemble'], targets=results['targets'], 234 | uncertainties=results['predictive_std'], metric='RMSE', bins=bins, step=step, 235 | min_bin_count=min_bin_count, xlabel='Predictive STD [m]', ylabel='Empirical RMSE [m]') 236 | ax.set_xticks(xticks) 237 | plt.tight_layout() 238 | fig.savefig(fname=os.path.join(collected_dir, 'calibration.png'), dpi=300) 239 | 240 | print('Collected results and plots are saved in:') 241 | print(collected_dir) 242 | 243 | -------------------------------------------------------------------------------- /torch_code/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is an adaptation of the ResNet architecture from 2D to 1D based on a hard copy from 
3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py (accessed 2020-07-23) 4 | Every 2D layer is replaced with its 1D pendant. Logically, pretraining has been removed. 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=dilation, groups=groups, bias=False, dilation=dilation) 16 | 17 | 18 | def conv1x1(in_planes, out_planes, stride=1): 19 | """1x1 convolution""" 20 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 27 | base_width=64, dilation=1, norm_layer=None): 28 | super(BasicBlock, self).__init__() 29 | if norm_layer is None: 30 | norm_layer = nn.BatchNorm1d 31 | if groups != 1 or base_width != 64: 32 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 33 | if dilation > 1: 34 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 35 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = norm_layer(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = norm_layer(planes) 41 | self.downsample = downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | 57 | out += identity 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | 63 | class Bottleneck(nn.Module): 64 | # Bottleneck in torchvision places the stride for 
downsampling at 3x3 convolution(self.conv2) 65 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 66 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 67 | # This variant is also known as ResNet V1.5 and improves accuracy according to 68 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 69 | 70 | expansion = 4 71 | 72 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 73 | base_width=64, dilation=1, norm_layer=None): 74 | super(Bottleneck, self).__init__() 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm1d 77 | width = int(planes * (base_width / 64.)) * groups 78 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 79 | self.conv1 = conv1x1(inplanes, width) 80 | self.bn1 = norm_layer(width) 81 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 82 | self.bn2 = norm_layer(width) 83 | self.conv3 = conv1x1(width, planes * self.expansion) 84 | self.bn3 = norm_layer(planes * self.expansion) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | identity = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | 106 | out += identity 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, block, layers, in_features=1, num_outputs=1000, zero_init_residual=False, 115 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 116 | norm_layer=None): 117 | super(ResNet, self).__init__() 118 | if norm_layer is None: 119 | norm_layer = nn.BatchNorm1d 120 | self._norm_layer = 
norm_layer 121 | 122 | self.inplanes = 64 123 | self.dilation = 1 124 | if replace_stride_with_dilation is None: 125 | # each element in the tuple indicates if we should replace 126 | # the 2x2 stride with a dilated convolution instead 127 | replace_stride_with_dilation = [False, False, False] 128 | if len(replace_stride_with_dilation) != 3: 129 | raise ValueError("replace_stride_with_dilation should be None " 130 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 131 | self.groups = groups 132 | self.base_width = width_per_group 133 | self.conv1 = nn.Conv1d(in_features, self.inplanes, kernel_size=7, stride=2, padding=3, 134 | bias=False) 135 | self.bn1 = norm_layer(self.inplanes) 136 | self.relu = nn.ReLU(inplace=True) 137 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 138 | self.layer1 = self._make_layer(block, 64, layers[0]) 139 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 140 | dilate=replace_stride_with_dilation[0]) 141 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 142 | dilate=replace_stride_with_dilation[1]) 143 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 144 | dilate=replace_stride_with_dilation[2]) 145 | self.avgpool = nn.AdaptiveAvgPool1d(1) 146 | self.fc = nn.Linear(512 * block.expansion, num_outputs) 147 | 148 | for m in self.modules(): 149 | if isinstance(m, nn.Conv1d): 150 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 151 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)): 152 | nn.init.constant_(m.weight, 1) 153 | nn.init.constant_(m.bias, 0) 154 | 155 | # Zero-initialize the last BN in each residual branch, 156 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
157 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 158 | if zero_init_residual: 159 | for m in self.modules(): 160 | if isinstance(m, Bottleneck): 161 | nn.init.constant_(m.bn3.weight, 0) 162 | elif isinstance(m, BasicBlock): 163 | nn.init.constant_(m.bn2.weight, 0) 164 | 165 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 166 | norm_layer = self._norm_layer 167 | downsample = None 168 | previous_dilation = self.dilation 169 | if dilate: 170 | self.dilation *= stride 171 | stride = 1 172 | if stride != 1 or self.inplanes != planes * block.expansion: 173 | downsample = nn.Sequential( 174 | conv1x1(self.inplanes, planes * block.expansion, stride), 175 | norm_layer(planes * block.expansion), 176 | ) 177 | 178 | layers = [] 179 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 180 | self.base_width, previous_dilation, norm_layer)) 181 | self.inplanes = planes * block.expansion 182 | for _ in range(1, blocks): 183 | layers.append(block(self.inplanes, planes, groups=self.groups, 184 | base_width=self.base_width, dilation=self.dilation, 185 | norm_layer=norm_layer)) 186 | 187 | return nn.Sequential(*layers) 188 | 189 | def _forward_impl(self, x): 190 | # See note [TorchScript super()] 191 | x = self.conv1(x) 192 | x = self.bn1(x) 193 | x = self.relu(x) 194 | x = self.maxpool(x) 195 | 196 | x = self.layer1(x) 197 | x = self.layer2(x) 198 | x = self.layer3(x) 199 | x = self.layer4(x) 200 | 201 | x = self.avgpool(x) 202 | x = torch.flatten(x, 1) 203 | x = self.fc(x) 204 | 205 | return x 206 | 207 | def forward(self, x): 208 | return self._forward_impl(x) 209 | 210 | 211 | class SimpleResNet(nn.Module): 212 | """ 213 | This is the basic model architecture used in the experiments. 
214 | """ 215 | def __init__(self, 216 | block=BasicBlock, 217 | in_features=1, 218 | out_features=(16, 32, 64, 128, 128, 128, 128, 128), 219 | num_outputs=1, 220 | global_pool=nn.AdaptiveAvgPool1d, 221 | max_pool=True, 222 | dilation_rate=1): 223 | """ 224 | The constructor. 225 | :param in_features: input channel dimension (we say waveforms have channel dimension 1). 226 | :param out_features: list of channel feature dimensions. 227 | :param num_outputs: number of the output dimension . 228 | :param global_pool: the global pooling to use before the fully connected layer. 229 | :param max_pool: whether to use max pooling or not. 230 | :param dilation_rate: must be >= 1. Defaults to 1 (no dilation). 231 | """ 232 | super(SimpleResNet, self).__init__() 233 | self.relu = nn.ReLU(inplace=True) 234 | norm_layer = nn.BatchNorm1d 235 | layers = list() 236 | for i in range(len(out_features)): 237 | in_channels = in_features if i == 0 else out_features[i-1] 238 | 239 | # Check if the residual shortcut can be the identity (with a constant number of features). 240 | # Otherwise use a conv1x1 as a trainable shortcut to adapt to new number of features. 
241 | if in_channels != out_features[i]: 242 | shortcut = nn.Sequential(conv1x1(in_channels, out_features[i]), 243 | norm_layer(out_features[i])) 244 | else: 245 | # this will build an identity shortcut 246 | shortcut = None 247 | 248 | # append a residual block 249 | layers.append(block(inplanes=in_channels, planes=out_features[i], stride=1, downsample=shortcut, 250 | norm_layer=norm_layer)) 251 | 252 | if max_pool: 253 | layers.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1)) 254 | 255 | self.conv_layers = nn.Sequential(*layers) 256 | self.global_pool = global_pool(output_size=1) 257 | self.dropout = nn.Dropout(p=0.5) 258 | self.fc = nn.Linear(in_features=out_features[-1], out_features=num_outputs) 259 | 260 | def forward(self, x): 261 | x = self.conv_layers(x) 262 | x = self.global_pool(x) 263 | x = self.dropout(x) 264 | x = x.flatten(start_dim=1) 265 | x = self.fc(x) 266 | return x 267 | -------------------------------------------------------------------------------- /torch_code/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | import h5py 4 | from abc import ABC, abstractmethod 5 | from torch.utils.data import Dataset 6 | from utils.transforms import pad_waveforms 7 | from utils.utils import get_quality_indices 8 | 9 | 10 | # parsing the geolocated waveforms 11 | def parse_L1B_hdf5(filepath, 12 | quality_dict, 13 | use_coverage=True, 14 | keys_beam=('shot_number',), 15 | noise_mean_key='noise_mean_corrected'): 16 | 17 | # initialize output dictionary 18 | keys_beam = list(keys_beam) 19 | keys_beam.append(noise_mean_key) 20 | 21 | out_dict = {} 22 | keys = keys_beam + ['rxwaveform'] 23 | for key in keys: 24 | out_dict[key] = [] 25 | 26 | # select beam names 27 | coverage_beams = ['BEAM0000', 'BEAM0001', 'BEAM0010', 'BEAM0011'] 28 | power_beams = ['BEAM0101', 'BEAM0110', 'BEAM1000', 'BEAM1011'] 29 | if use_coverage: 30 | beam_names = coverage_beams + 
power_beams 31 | else: 32 | beam_names = power_beams 33 | 34 | with h5py.File(filepath, 'r') as f: 35 | 36 | # if quality_dict=None, it is not given by the L2A file: set quality=1 for all samples 37 | if not quality_dict: 38 | quality_dict = {} 39 | for beam in beam_names: 40 | quality_dict[beam] = np.ones_like(f[beam]['shot_number'][:]).astype(np.bool) 41 | 42 | for beam in beam_names: 43 | print(beam) 44 | if not beam in f.keys(): 45 | print('Beam: {} does not exist. continue...'.format(beam)) 46 | continue 47 | 48 | for k in keys_beam: 49 | out_dict[k] = out_dict[k] + list(f[beam][k][:][quality_dict[beam]]) 50 | 51 | # parse the waveforms (rxwaveform is a single array) 52 | print('parsing full rxwaveform array...') 53 | rxwaveform_all = np.array(f[beam]['rxwaveform'][:]) 54 | print('rxwaveform_all.shape', rxwaveform_all.shape) 55 | print('extracting waveform starts and ends') 56 | # we cut it into parts belonging to each of the valid laser shot 57 | # index starts at 1 correct for python indexing starting at 0 58 | start = f[beam]['rx_sample_start_index'][:][quality_dict[beam]] - 1 59 | count = f[beam]['rx_sample_count'][:][quality_dict[beam]] 60 | end = start + count 61 | 62 | print('max count', np.max(count)) 63 | 64 | print('cutting into single waveforms...') 65 | size_rxwaveform = len(rxwaveform_all) 66 | print('size_rxwaveform', size_rxwaveform) 67 | 68 | rxwaveform = [] 69 | for i in range(len(start)): 70 | wave = rxwaveform_all[start[i]:end[i]] # this is a view, not a copy 71 | rxwaveform.append(wave) 72 | 73 | out_dict['rxwaveform'] += rxwaveform 74 | 75 | # convert to numpy arrays 76 | for k in out_dict.keys(): 77 | if k not in ['shot_number', 'rxwaveform']: 78 | out_dict[k] = np.array(out_dict[k], dtype=np.float32) 79 | out_dict['shot_number'] = np.array(out_dict['shot_number'], dtype=np.uint64) 80 | return out_dict 81 | 82 | 83 | def parse_L2A_hdf5(filepath, use_coverage=True, 84 | keys_beam=('shot_number', 'quality_flag', 'lat_lowestmode', 
'lon_lowestmode'), 85 | keys_land_cover=('modis_nonvegetated',)): 86 | 87 | keys_beam = list(keys_beam) 88 | keys_land_cover = list(keys_land_cover) 89 | 90 | # initialize output dictionary 91 | out_dict = {} 92 | keys = keys_beam + keys_land_cover 93 | for key in keys: 94 | out_dict[key] = [] 95 | 96 | # select beam names 97 | coverage_beams = ['BEAM0000', 'BEAM0001', 'BEAM0010', 'BEAM0011'] 98 | power_beams = ['BEAM0101', 'BEAM0110', 'BEAM1000', 'BEAM1011'] 99 | 100 | if use_coverage: 101 | beam_names = coverage_beams + power_beams 102 | else: 103 | beam_names = power_beams 104 | 105 | # init quality dictionary (for each beam) 106 | quality_dict = {} 107 | 108 | with h5py.File(filepath, 'r') as f: 109 | for beam in beam_names: 110 | if not beam in f.keys(): 111 | print('Beam: {} does not exist. continue...'.format(beam)) 112 | continue 113 | 114 | quality_dict[beam] = np.array(f[beam]['quality_flag'][:], dtype=np.bool) 115 | 116 | for k in keys_beam: 117 | out_dict[k] = out_dict[k] + list(f[beam][k][:][quality_dict[beam]]) 118 | 119 | for k in keys_land_cover: 120 | out_dict[k] = out_dict[k] + list(f[beam]['land_cover_data'][k][:][quality_dict[beam]]) 121 | 122 | # convert to numpy arrays 123 | for k in out_dict.keys(): 124 | if k != 'shot_number': 125 | out_dict[k] = np.array(out_dict[k]) 126 | out_dict['shot_number'] = np.array(out_dict['shot_number'], dtype=np.uint64) 127 | out_dict['quality_flag'] = np.array(out_dict['quality_flag'], dtype=np.bool) 128 | return out_dict, quality_dict 129 | 130 | 131 | def filter_quality(data_dict, quality_indices): 132 | for key in data_dict.keys(): 133 | data_dict[key] = data_dict[key][quality_indices] 134 | return data_dict 135 | 136 | 137 | def pad_waveform(waveform, pad_constant, out_size=1420): 138 | pad_after = out_size-len(waveform) 139 | padded_waveform = np.pad(waveform, pad_width=((0, pad_after),), mode='constant', constant_values=(0, pad_constant)) 140 | return padded_waveform 141 | 142 | 143 | class 
GediDataOrbitMem(Dataset): 144 | """ 145 | This dataset class loads all waveforms from an L1B orbit h5 file and filters based on the corresponding L2A quality_flag. 146 | """ 147 | 148 | def __init__(self, file_path_L1B, file_path_L2A=None, sample_length=1420, input_transforms=None, 149 | noise_mean_key='noise_mean_corrected'): 150 | 151 | super(GediDataOrbitMem, self).__init__() 152 | 153 | self.file_path_L1B = file_path_L1B 154 | self.file_path_L2A = file_path_L2A 155 | self.input_transforms = input_transforms 156 | self.sample_length = sample_length 157 | 158 | if self.file_path_L2A: 159 | 160 | # parse L2A data 161 | print('Loading L2A...') 162 | self.data_L2A, self.quality_dict = parse_L2A_hdf5(filepath=self.file_path_L2A) 163 | 164 | for key in self.data_L2A.keys(): 165 | print(key, self.data_L2A[key].shape, self.data_L2A[key].dtype) 166 | else: 167 | self.data_L2A, self.quality_dict = None, None 168 | 169 | # parse L1B data 170 | print('Loading L1B...') 171 | self.data_L1B = parse_L1B_hdf5(filepath=self.file_path_L1B, quality_dict=self.quality_dict, 172 | noise_mean_key=noise_mean_key) 173 | 174 | self.inputs = self.data_L1B['rxwaveform'] 175 | self.mean_noise_lvls = self.data_L1B[noise_mean_key] 176 | 177 | print('preprocessing samples...') 178 | self._preprocess_samples() 179 | 180 | # expand dimension of waveforms 181 | self.inputs = np.array(self.inputs, dtype=np.float32) 182 | self.inputs = self.inputs[..., None] 183 | 184 | self.data_L1B['rxwaveform'] = self.inputs 185 | 186 | for key in self.data_L1B.keys(): 187 | print(key, self.data_L1B[key].shape, self.data_L1B[key].dtype) 188 | 189 | print('inputs.shape:', self.inputs.shape) 190 | print('mean_noise_lvls.shape:', self.mean_noise_lvls.shape) 191 | 192 | if self.file_path_L2A: 193 | assert np.array_equal(self.data_L1B['shot_number'], self.data_L2A['shot_number']), 'shot_number in L1B and L2A do not have the same order.' 
194 | 195 | 196 | def _preprocess_samples(self): 197 | """ 198 | Applies three different transformations to the waveforms 199 | """ 200 | 201 | for i in range(len(self.inputs)): 202 | 203 | # pad first with mean: 204 | self.inputs[i] = pad_waveform(self.inputs[i], pad_constant=self.mean_noise_lvls[i], out_size=1420) 205 | # subtract noise level 206 | self.inputs[i] = self.inputs[i] - self.mean_noise_lvls[i] 207 | # normalize integral to 1 (total energy return) 208 | self.inputs[i] = self.inputs[i] / np.sum(self.inputs[i]) 209 | 210 | def __getitem__(self, index): 211 | """ 212 | Returns the index-th waveform and the corresponding target. 213 | :param index: index within 0 and len(self) 214 | :return: index-th waveform 215 | """ 216 | sample = self.inputs[index] 217 | 218 | if self.input_transforms: 219 | sample = self.input_transforms(sample) 220 | return sample 221 | 222 | def __len__(self): 223 | return len(self.inputs) 224 | 225 | 226 | class AbstractDataMem(Dataset, ABC): 227 | """ 228 | This dataset class loads all data into memory. 229 | """ 230 | 231 | def __init__(self, input_path, target_path='', min_gt=-np.inf, max_gt=np.inf, sample_length=1420, 232 | input_transforms=None, target_transforms=None): 233 | """ 234 | Constructor, inherits from torch.utils.data.Dataset. 
235 | :param input_path: numpy file of waveforms 236 | :param target_path: numpy file of targets 237 | :param max_gt: upper bound of ground truth height 238 | :param sample_length: maximum sample length 239 | :param input_transforms: torchvision.transforms.Compose or single transform 240 | :param target_transforms: torchvision.transforms.Compose or single transform 241 | """ 242 | super(AbstractDataMem, self).__init__() 243 | self.input_path = Path(input_path) 244 | self.target_path = Path(target_path) 245 | self.input_transforms = input_transforms 246 | self.target_transforms = target_transforms 247 | self.min_gt = min_gt 248 | self.max_gt = max_gt 249 | self.sample_length = sample_length 250 | 251 | self.inputs, self.targets, self.mean_noise_lvls, self.quality_indices_train, self.quality_indices_valtest, self.split_attribute = self._get_data() 252 | self._preprocess_samples() 253 | 254 | @abstractmethod 255 | def _get_data(self): 256 | pass 257 | 258 | def _preprocess_samples(self): 259 | """ 260 | Applies three different transformations to the waveforms 261 | """ 262 | assert np.ndim(self.inputs) == 3, 'waveforms should have the shape (num_samples, sample_length, num_featuers=1) e.g. (N, 1420, 1)' 263 | # subtract noise level 264 | self.inputs = self.inputs - self.mean_noise_lvls[..., None, None] 265 | # normalize integral to 1 (total energy return) 266 | self.inputs = self.inputs / np.sum(self.inputs, axis=1)[..., None] 267 | # pad waveforms to a fixed length with zeros at the end 268 | self.inputs = pad_waveforms(waveforms=self.inputs, out_size=self.sample_length) 269 | 270 | def __getitem__(self, index): 271 | """ 272 | Returns the index-th waveform and the corresponding target. 
273 | :param index: index within 0 and len(self) 274 | :return: index-th waveform 275 | """ 276 | sample, target, mean_noise_lvl = self.inputs[index], self.targets[index], self.mean_noise_lvls[index] 277 | 278 | # input: Normalize_in, ToTensor 279 | if self.input_transforms: 280 | sample = self.input_transforms(sample) 281 | # target: Normalize_target 282 | if self.target_transforms: 283 | target = self.target_transforms(target) 284 | return sample, target 285 | 286 | def __len__(self): 287 | return len(self.inputs) 288 | 289 | 290 | class CrossOverDataMem(AbstractDataMem): 291 | """ 292 | This dataset class loads all crossover waveforms into memory. 293 | This class differs from the main class as it loads only one big numpy dictionary including all inputs AND targets. 294 | """ 295 | def __init__(self, input_path, target_path='', min_gt=-np.inf, max_gt=np.inf, sample_length=1420, 296 | input_transforms=None, target_transforms=None, 297 | input_key='rxwaveform', target_key='als_rh098', settings_index=3, pearson_thresh=0.95, 298 | split_attribute_name=None, noise_mean_key='noise_mean'): 299 | """ 300 | :param input_key: input key to use (see args.input_key) 301 | :param target_key: target key to use (see args.target_key) 302 | :param settings_index: what setting to use (see self._get_data) 303 | """ 304 | self.input_key = input_key 305 | self.target_key = target_key 306 | self.noise_mean_key = noise_mean_key 307 | self.settings_index = settings_index 308 | self.pearson_thresh = pearson_thresh 309 | self.split_attribute_name = split_attribute_name 310 | super(CrossOverDataMem, self).__init__(input_path, target_path, min_gt, max_gt, sample_length, input_transforms, 311 | target_transforms) 312 | 313 | def _get_data(self): 314 | """ 315 | Load numpy files into memory as np.float32 arrays. 316 | For cross over data, all samples are within one file including targets. 317 | :return: loaded input and target without preprocessing. 
318 | """ 319 | if not self.input_path.exists(): 320 | raise FileNotFoundError('The file {} does not exist.'.format(self.input_path)) 321 | if not self.target_path.exists(): 322 | raise FileNotFoundError('The file {} does not exist.'.format(self.target_path)) 323 | 324 | # Load samples and ground truth 325 | print('Start loading dataset.') 326 | data = np.load(self.input_path, allow_pickle=True).item() 327 | inputs = np.array(data[self.input_key], dtype=np.float32, copy=False)[..., None] 328 | targets = np.array(data[self.target_key], dtype=np.float32, copy=False) 329 | mean_noise_lvls = np.array(data[self.noise_mean_key], dtype=np.float32, copy=False) 330 | 331 | if not self.split_attribute_name is None: 332 | split_attribute = np.array(data[self.split_attribute_name], copy=False) 333 | else: 334 | split_attribute = None 335 | 336 | settings = [{'night_strong': True, 'day_strong': False, 'night_coverage': False, 'day_coverage': False}, 337 | {'night_strong': True, 'day_strong': True, 'night_coverage': False, 'day_coverage': False}, 338 | {'night_strong': True, 'day_strong': True, 'night_coverage': True, 'day_coverage': False}, 339 | {'night_strong': True, 'day_strong': True, 'night_coverage': True, 'day_coverage': True}] 340 | 341 | setting = settings[self.settings_index] 342 | 343 | # Filter samples within valid range 344 | print('Start filtering dataset.') 345 | valid_indices = np.logical_and(targets >= self.min_gt, targets <= self.max_gt) 346 | # filter night and strong beams 347 | night_coverage_indices, filter_string = self.filter_shots(data=data, 348 | night_strong=setting['night_strong'], 349 | day_strong=setting['day_strong'], 350 | night_coverage=setting['night_coverage'], 351 | day_coverage=setting['day_coverage']) 352 | valid_indices = np.logical_and(valid_indices, night_coverage_indices) 353 | 354 | # filter by quality criteria 355 | do_filter_quality = True 356 | # check if all quality criteria keys exists (difference between crossover v1 and v2) 357 
| for key in ['ground_elev_cog', 'dz_pearson', 'dz_count', 'pearson']: 358 | do_filter_quality = do_filter_quality & (key in data) 359 | print('do_filter_quality:', do_filter_quality) 360 | 361 | if do_filter_quality: 362 | # to allow a different pearson threshold between training and testing 363 | quality_indices_train = get_quality_indices(data_crossover=data, pearson_thresh=self.pearson_thresh) 364 | quality_indices_valtest = get_quality_indices(data_crossover=data, pearson_thresh=0.95) 365 | 366 | quality_indices_train = quality_indices_train[valid_indices] 367 | quality_indices_valtest = quality_indices_valtest[valid_indices] 368 | else: 369 | quality_indices_train = None 370 | quality_indices_valtest = None 371 | 372 | inputs = inputs[valid_indices] 373 | targets = targets[valid_indices] 374 | mean_noise_lvls = mean_noise_lvls[valid_indices] 375 | if split_attribute is not None: 376 | split_attribute = split_attribute[valid_indices] 377 | 378 | print('inputs.shape', inputs.shape) 379 | print('targets.shape', targets.shape) 380 | print('mean_noise_lvls.shape', mean_noise_lvls.shape) 381 | if split_attribute is not None: 382 | print('split_attribute.shape', split_attribute.shape) 383 | 384 | print('Done loading dataset.') 385 | return inputs, targets, mean_noise_lvls, quality_indices_train, quality_indices_valtest, split_attribute 386 | 387 | @staticmethod 388 | def filter_shots(data, night_strong=True, day_strong=True, night_coverage=True, day_coverage=True): 389 | night_indices = data['solar_elevation'] < 0 390 | day_indices = ~night_indices 391 | coverage_indices = data['coverage_flag'] == 1 392 | strong_indices = ~coverage_indices 393 | 394 | out_str = '' 395 | 396 | # init: all points are invalid 397 | valid_indices = np.repeat(0, repeats=len(data['shot_number'])) 398 | 399 | if night_strong: 400 | indices = np.logical_and(night_indices, strong_indices) 401 | valid_indices = np.logical_or(valid_indices, indices) 402 | out_str += 'night-strong' 403 | if 
day_strong: 404 | indices = np.logical_and(day_indices, strong_indices) 405 | valid_indices = np.logical_or(valid_indices, indices) 406 | out_str += '_day-strong' 407 | if night_coverage: 408 | indices = np.logical_and(night_indices, coverage_indices) 409 | valid_indices = np.logical_or(valid_indices, indices) 410 | out_str += '_night-coverage' 411 | if day_coverage: 412 | indices = np.logical_and(day_indices, coverage_indices) 413 | valid_indices = np.logical_or(valid_indices, indices) 414 | out_str += '_day-coverage' 415 | 416 | return valid_indices, out_str 417 | 418 | 419 | class CustomSubset(Dataset): 420 | """ 421 | Subset of a dataset at specified indices. 422 | 423 | Arguments: 424 | dataset (Dataset): The whole Dataset 425 | indices (sequence): Indices in the whole set selected for subset 426 | """ 427 | def __init__(self, dataset, indices, input_transforms=None, target_transforms=None, augmentation_transforms=None): 428 | """ 429 | Constructor, inherits from torch.utils.data.Dataset. 430 | :param dataset: pytorch dataset object (e.g. SimulatedDataMem) 431 | :param indices: numpy file of sample indices defining the subset 432 | :param input_transforms: torchvision.transforms.Compose or single transform 433 | :param target_transforms: torchvision.transforms.Compose or single transform 434 | :param augmentation_transforms: torchvision.transforms.Compose or single transform. Expects a tuple (input_, target) 435 | """ 436 | super(CustomSubset, self).__init__() 437 | self.dataset = dataset 438 | self.indices = indices 439 | self.input_transforms = input_transforms 440 | self.target_transforms = target_transforms 441 | self.augmentation_transforms = augmentation_transforms 442 | 443 | def __getitem__(self, idx): 444 | input_, target = self.dataset[self.indices[idx]] 445 | 446 | # augmentation 447 | # Note: since we adjust the target in its original scale, augmentation is applied before normalization. 
448 | if self.augmentation_transforms: 449 | input_, target = self.augmentation_transforms((input_, target)) 450 | 451 | if self.input_transforms is not None: 452 | input_ = self.input_transforms(input_) 453 | if self.target_transforms is not None: 454 | target = self.target_transforms(target) 455 | return input_, target 456 | 457 | def __len__(self): 458 | return len(self.indices) 459 | 460 | -------------------------------------------------------------------------------- /torch_code/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.tensorboard import SummaryWriter 4 | from torch.utils.data import DataLoader 5 | from torchvision.transforms import Compose 6 | from dataset import CrossOverDataMem, CustomSubset 7 | from utils.transforms import Normalize, RandomShift, ToTensor, denormalize, RandomLabelNoise 8 | from utils.loss import RMSELoss, MELoss, GaussianNLL, LaplacianNLL 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | 12 | 13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | class Trainer: 17 | 18 | def __init__(self, model, log_dir, args): 19 | """ 20 | Initialize a Trainer object to train and test the model. 
21 | Args: 22 | model: pytorch model 23 | log_dir: path to directory to save tensorboard logs 24 | args: argparse object (see setup_parser() in utils.parser) 25 | """ 26 | self.model = model 27 | self.args = args 28 | self.out_dir = args.out_dir 29 | self.writer = SummaryWriter(log_dir=log_dir) 30 | self.do_shift_target = self.args.input_key != 'delta_Z0_ground' 31 | self.shift_interval = (-self.args.shift_left, self.args.shift_right) \ 32 | if self.args.shift_left and self.args.shift_right else None 33 | self.train_indices, self.val_indices, self.test_indices = None, None, None 34 | 35 | self.ds_train, self.ds_val, self.ds_test, \ 36 | self.mean_input_train, self.std_input_train, \ 37 | self.mean_target_train, self.std_target_train= self._setup_dataset() 38 | 39 | self.optimizer = self._setup_optimizer() 40 | self.error_metrics = self._setup_metrics() 41 | 42 | print('self.mean_target_train', self.mean_target_train) 43 | print('self.std_target_train', self.std_target_train) 44 | 45 | def _setup_metrics(self): 46 | error_metrics = {'MSE': torch.nn.MSELoss(), 47 | 'RMSE': RMSELoss(), 48 | 'MAE': torch.nn.L1Loss(), 49 | 'ME': MELoss()} 50 | 51 | if self.args.num_outputs == 2: 52 | error_metrics['gaussian_nll'] = GaussianNLL() 53 | error_metrics['laplacian_nll'] = LaplacianNLL() 54 | 55 | print('error_metrics.keys():', error_metrics.keys()) 56 | error_metrics['loss'] = error_metrics[self.args.loss_key] 57 | return error_metrics 58 | 59 | def _setup_optimizer(self): 60 | if self.args.optimizer == 'ADAM': 61 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.base_learning_rate, 62 | weight_decay=self.args.l2_lambda) 63 | elif self.args.optimizer == 'SGD': 64 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.base_learning_rate, 65 | weight_decay=self.args.l2_lambda) 66 | else: 67 | raise ValueError("Solver '{}' is not defined.".format(self.args.optimizer)) 68 | return optimizer 69 | 70 | def _setup_transforms(self, mean_input_train, 
std_input_train, mean_target_train, std_target_train): 71 | input_transforms = Compose([Normalize(mean=mean_input_train, std=std_input_train), ToTensor()]) 72 | 73 | if self.args.normalize_targets: 74 | target_transforms = Compose([Normalize(mean=mean_target_train, std=std_target_train)]) 75 | else: 76 | target_transforms = None 77 | 78 | # data augmentation for training 79 | augmentation_transforms = Compose([RandomShift(shift_interval=self.shift_interval, 80 | do_shift_target=self.do_shift_target)]) 81 | 82 | # TODO Option to add random label noise as an augmentation to carry out robustness experiments 83 | # RandomLabelNoise(rel_label_noise=self.args.label_noise, 84 | # distribution=self.args.label_noise_distribution)]) 85 | 86 | return input_transforms, target_transforms, augmentation_transforms 87 | 88 | def _setup_dataset(self): 89 | if self.args.dataset == 'CROSSOVER_GEDI': 90 | dataset = CrossOverDataMem(input_path=self.args.inputs_path, 91 | target_key=self.args.target_key, 92 | input_key=self.args.input_key, 93 | min_gt=self.args.min_gt, 94 | max_gt=self.args.max_gt, 95 | sample_length=self.args.sample_length, 96 | settings_index=self.args.setting_idx, 97 | pearson_thresh=self.args.pearson_thresh, 98 | split_attribute_name=self.args.ood_attribute, 99 | noise_mean_key=self.args.noise_mean_key) 100 | 101 | else: 102 | raise ValueError('Dataset {} is undefined.'.format(self.args.dataset)) 103 | 104 | # print target range: 105 | print('overall target min: {}, max: {}'.format(np.min(dataset.targets), np.max(dataset.targets))) 106 | 107 | # Split dataset into three subsets 108 | 109 | if self.args.data_split == 'randCV': 110 | # Random n-fold cross validation split 111 | self.train_indices, self.val_indices, self.test_indices = self._split_dataset_random_nfold_CV(len_dataset=len(dataset), 112 | n_folds=self.args.n_folds, 113 | test_fold_idx=self.args.test_fold_idx, 114 | quality_indices_train=dataset.quality_indices_train, 115 | 
quality_indices_valtest=dataset.quality_indices_valtest) 116 | 117 | if self.args.range_to_remove is not None: 118 | # Remove specific target range from training and validation indices. 119 | # We keep that range in the test data to evaluate the epistemic uncertainty 120 | print('TRAIN & VAL data is filtered in the range: ', self.args.range_to_remove) 121 | self.train_indices, out_dist_indices_train = self.filter_subset_indices_by_attribute_range(self.train_indices, 122 | dataset_attribute=dataset.split_attribute, 123 | range_to_remove=self.args.range_to_remove) 124 | self.val_indices, out_dist_indices_val = self.filter_subset_indices_by_attribute_range(self.val_indices, 125 | dataset_attribute=dataset.split_attribute, 126 | range_to_remove=self.args.range_to_remove) 127 | 128 | # append OOD indices to test_indices 129 | print('num samples OOD: ', len(out_dist_indices_train) + len(out_dist_indices_val)) 130 | print('test_indices.shape before', self.test_indices.shape) 131 | self.test_indices = np.concatenate((self.test_indices, out_dist_indices_train, out_dist_indices_val)) 132 | print('test_indices.shape with ood', self.test_indices.shape) 133 | 134 | # Check test indices 135 | test_indices_in_dist, test_indices_out_dist = self.filter_subset_indices_by_attribute_range(self.test_indices, 136 | dataset_attribute=dataset.split_attribute, 137 | range_to_remove=self.args.range_to_remove) 138 | print('test_indices_in_dist.shape', test_indices_in_dist.shape) 139 | print('test_indices_out_dist.shape', test_indices_out_dist.shape) 140 | 141 | elif self.args.ood_attribute is not None: 142 | # Remove a specific attribute value from train and val data. 143 | # We keep that attribute in the test data to evaluate the epistemic uncertainty for in- and out of distribution. 
144 | 145 | if self.args.ood_value_string is not None: 146 | self.args.ood_value = self.args.ood_value_string 147 | 148 | print('TRAIN & VAL data is filtered by attribute: {} and value: {}'.format(self.args.ood_attribute, self.args.ood_value)) 149 | self.train_indices, out_dist_indices_train = self.filter_subset_indices_by_attribute(self.train_indices, 150 | dataset_attribute=dataset.split_attribute, 151 | out_dist_value=self.args.ood_value) 152 | self.val_indices, out_dist_indices_val = self.filter_subset_indices_by_attribute(self.val_indices, 153 | dataset_attribute=dataset.split_attribute, 154 | out_dist_value=self.args.ood_value) 155 | 156 | # append OOD indices to test_indices 157 | print('num samples OOD: ', len(out_dist_indices_train) + len(out_dist_indices_val)) 158 | print('test_indices.shape before', self.test_indices.shape) 159 | self.test_indices = np.concatenate((self.test_indices, out_dist_indices_train, out_dist_indices_val)) 160 | print('test_indices.shape with ood', self.test_indices.shape) 161 | 162 | # Check test indices 163 | test_indices_in_dist, test_indices_out_dist = self.filter_subset_indices_by_attribute(self.test_indices, 164 | dataset_attribute=dataset.split_attribute, 165 | out_dist_value=self.args.ood_value) 166 | print('test_indices_in_dist.shape', test_indices_in_dist.shape) 167 | print('test_indices_out_dist.shape', test_indices_out_dist.shape) 168 | 169 | elif self.args.data_split == 'attrCV': 170 | # Split by attribute: Hold-out a specific attribute value 171 | self.train_indices, self.val_indices, self.test_indices = self._split_dataset_by_attribute(attribute=dataset.split_attribute, 172 | test_attribute=self.args.test_attribute_value, 173 | quality_indices_train=dataset.quality_indices_train, 174 | quality_indices_valtest=dataset.quality_indices_valtest) 175 | else: 176 | raise ValueError("self.args.data_split = '{}' is not defined".format(self.args.data_split)) 177 | 178 | # check that data splits do not overlap 179 | if not 
set(self.train_indices).isdisjoint(set(self.val_indices)): 180 | raise ValueError("TRAIN indices overlap with VAL indices.") 181 | if not set(self.train_indices).isdisjoint(set(self.test_indices)): 182 | raise ValueError("TRAIN indices overlap with TEST indices.") 183 | if not set(self.test_indices).isdisjoint(set(self.val_indices)): 184 | raise ValueError("TEST indices overlap with VAL indices.") 185 | 186 | # Calculate mean and std of training set and setup transforms. 187 | mean_input_train = np.mean(dataset.inputs[self.train_indices]) 188 | std_input_train = np.std(dataset.inputs[self.train_indices]) 189 | 190 | mean_target_train = np.mean(dataset.targets[self.train_indices]) 191 | std_target_train = np.std(dataset.targets[self.train_indices]) 192 | 193 | # save training mean and std 194 | np.save(Path(self.args.out_dir) / 'mean_input_train.npy', mean_input_train) 195 | np.save(Path(self.args.out_dir) / 'std_input_train.npy', std_input_train) 196 | 197 | np.save(Path(self.args.out_dir) / 'mean_target_train.npy', mean_target_train) 198 | np.save(Path(self.args.out_dir) / 'std_target_train.npy', std_target_train) 199 | 200 | input_transforms, target_transforms, augmentation_transforms = self._setup_transforms(mean_input_train=mean_input_train, 201 | std_input_train=std_input_train, 202 | mean_target_train=mean_target_train, 203 | std_target_train=std_target_train) 204 | 205 | ds_train = CustomSubset(dataset, self.train_indices, 206 | input_transforms=input_transforms, 207 | target_transforms=target_transforms, 208 | augmentation_transforms=augmentation_transforms) 209 | ds_val = CustomSubset(dataset, self.val_indices, 210 | input_transforms=input_transforms, 211 | target_transforms=target_transforms) 212 | ds_test = CustomSubset(dataset, self.test_indices, 213 | input_transforms=input_transforms, 214 | target_transforms=target_transforms) 215 | return ds_train, ds_val, ds_test, mean_input_train, std_input_train, mean_target_train, std_target_train 216 | 217 | def 
train(self): 218 | """ 219 | A routine to train and validated the model for several epochs. 220 | """ 221 | # Initialize train and validation loader 222 | dl_train = DataLoader(self.ds_train, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.num_workers) 223 | dl_val = DataLoader(self.ds_val, batch_size=self.args.batch_size, shuffle=False, num_workers=self.args.num_workers) 224 | 225 | # Init best losses for weights saving. 226 | loss_val_best = np.inf 227 | best_epoch = None 228 | 229 | if self.args.model_weights_path is not None: 230 | # load best model weights 231 | print('ATTENTION: loading pretrained model weights from:') 232 | print(self.args.model_weights_path) 233 | self.model.load_state_dict(torch.load(self.args.model_weights_path)) 234 | 235 | # Start training 236 | for epoch in range(self.args.nb_epoch): 237 | epoch += 1 238 | print('Epoch: {} / {} '.format(epoch, self.args.nb_epoch)) 239 | 240 | # optimize parameters 241 | training_metrics = self.optimize_epoch(dl_train) 242 | # validated performance 243 | val_dict, val_metrics = self.validate(dl_val) 244 | 245 | # -------- LOG TRAINING METRICS -------- 246 | metric_string = 'TRAIN: ' 247 | for metric in self.error_metrics.keys(): 248 | # tensorboard logs 249 | self.writer.add_scalar('{}/train'.format(metric), training_metrics[metric], epoch) 250 | metric_string += ' {}: {:.3f},'.format(metric, training_metrics[metric]) 251 | print(metric_string) 252 | 253 | # -------- LOG VALIDATION METRICS -------- 254 | metric_string = 'VAL: ' 255 | for metric in self.error_metrics: 256 | # tensorboard logs 257 | self.writer.add_scalar('{}/val'.format(metric), val_metrics[metric], epoch) 258 | metric_string += ' {}: {:.3f},'.format(metric, val_metrics[metric]) 259 | print(metric_string) 260 | 261 | # logging the estimated variance 262 | if 'log_variances' in val_dict: 263 | val_dict['variances'] = torch.exp(val_dict['log_variances']) 264 | 265 | if self.args.normalize_targets: 266 | # denormalize 
the variance 267 | val_dict['variances'] = val_dict['variances'] * self.std_target_train**2 268 | 269 | self.writer.add_scalar('var_mean/val', torch.mean(val_dict['variances']), epoch) 270 | self.writer.add_scalar('std_mean/val', torch.mean(torch.sqrt(val_dict['variances'])), epoch) 271 | self.writer.add_scalar('std_min/val', torch.min(torch.sqrt(val_dict['variances'])), epoch) 272 | self.writer.add_scalar('std_max/val', torch.max(torch.sqrt(val_dict['variances'])), epoch) 273 | self.writer.add_scalar('var_count_infinite_elements/val', self.count_infinite_elements(val_dict['variances']), epoch) 274 | 275 | print('VAL: Number of infinite elements in variances: ', self.count_infinite_elements(val_dict['variances'])) 276 | 277 | if val_metrics['loss'] < loss_val_best: 278 | loss_val_best = val_metrics['loss'] 279 | best_epoch = epoch 280 | # save and overwrite the best model weights: 281 | path = Path(self.out_dir) / 'best_weights.pt' 282 | torch.save(self.model.state_dict(), path) 283 | print('Saved weights at {}'.format(path)) 284 | 285 | # stop training if loss is nan 286 | if np.isnan(training_metrics['loss']) or np.isnan(val_metrics['loss']): 287 | raise ValueError("Training loss is nan. Stop training.") 288 | 289 | # TODO: Currently we save only the best and last epoch weights --> maybe want to save every nth epoch. 290 | print('Best val loss: {} at epoch: {}'.format(loss_val_best, best_epoch)) 291 | # save model weights after last epoch: 292 | path = Path(self.out_dir) / 'weights_last_epoch.pt' 293 | torch.save(self.model.state_dict(), path) 294 | print('Saved weights at {}'.format(path)) 295 | 296 | def optimize_epoch(self, dl_train): 297 | """ 298 | Run the optimization for one epoch. 299 | 300 | Args: 301 | dl_train: torch dataloader with training data. 302 | 303 | Returns: Dict with error metrics on training data (including the loss). Used for tensorboard logs. 
304 | """ 305 | # init running error 306 | training_metrics = {} 307 | for metric in self.error_metrics: 308 | training_metrics[metric] = 0 309 | 310 | total_count_infinite_var = 0 311 | 312 | # set model to training mode 313 | self.model.train() 314 | for step, (inputs, labels) in enumerate(tqdm(dl_train, ncols=100, desc='train')): 315 | inputs, labels = inputs.to(device), labels.to(device) 316 | # Run forward pass 317 | predictions = self.model.forward(inputs).squeeze(dim=-1) 318 | 319 | if self.args.num_outputs == 2: 320 | predictions, log_variances = predictions[:, 0], predictions[:, 1] 321 | # pass predicted mean and log_variance to e.g. gaussian_nll 322 | loss = self.error_metrics['loss'](predictions, log_variances, labels) 323 | 324 | # debug 325 | variances = torch.exp(log_variances) 326 | count_infinite = self.count_infinite_elements(variances) 327 | total_count_infinite_var += count_infinite 328 | 329 | else: 330 | loss = self.error_metrics['loss'](predictions, labels) 331 | 332 | # Run backward pass 333 | self.optimizer.zero_grad() 334 | loss.backward() 335 | self.optimizer.step() 336 | 337 | # compute metrics on every batch and add to running sum 338 | for metric in self.error_metrics: 339 | if self.args.num_outputs == 2 and metric in ['gaussian_nll', 'laplacian_nll', 'loss']: 340 | training_metrics[metric] += self.error_metrics[metric](predictions, log_variances, labels).item() 341 | else: 342 | if self.args.normalize_targets: 343 | # denormalize labels and predictions 344 | predictions_ = denormalize(predictions, self.mean_target_train, self.std_target_train) 345 | labels_ = denormalize(labels, self.mean_target_train, self.std_target_train) 346 | training_metrics[metric] += self.error_metrics[metric](predictions_, labels_).item() 347 | else: 348 | training_metrics[metric] += self.error_metrics[metric](predictions, labels).item() 349 | 350 | # debug 351 | if total_count_infinite_var > 0: 352 | print('TRAIN DEBUG: ATTENTION: count infinite elements in 
variances is: {}'.format(total_count_infinite_var)) 353 | 354 | # average over number of batches 355 | for metric in self.error_metrics.keys(): 356 | training_metrics[metric] /= len(dl_train) 357 | return training_metrics 358 | 359 | def validate(self, dl_val): 360 | """ 361 | Validate the model on validation data. 362 | 363 | Args: 364 | dl_val: torch dataloader with validation data 365 | 366 | Returns: 367 | val_dict: Dict with torch tensors for 'predictions', 'targets', 'log_variances'. 368 | val_metrics: Dict with error metrics on validation data (including the loss). Used for tensorboard logs. 369 | """ 370 | # set model to eval model 371 | self.model.eval() 372 | 373 | # init validation results for current epoch 374 | val_dict = {'predictions': [], 'targets': []} 375 | 376 | if self.args.num_outputs == 2: 377 | val_dict['log_variances'] = [] 378 | 379 | with torch.no_grad(): 380 | for step, (inputs, labels) in enumerate(dl_val): # for each training step 381 | 382 | inputs = inputs.to(device) 383 | labels = labels.to(device) 384 | 385 | predictions = self.model.forward(inputs).squeeze(dim=-1) 386 | if self.args.num_outputs == 2: 387 | predictions, log_variances = predictions[:, 0], predictions[:, 1] 388 | val_dict['log_variances'] += list(log_variances) 389 | 390 | val_dict['predictions'] += list(predictions) 391 | val_dict['targets'] += list(labels) 392 | 393 | for key in val_dict.keys(): 394 | if val_dict[key]: 395 | val_dict[key] = torch.stack(val_dict[key], dim=0) 396 | print("val_dict['{}'].shape: ".format(key), val_dict[key].shape) 397 | 398 | val_metrics = {} 399 | 400 | for metric in self.error_metrics: 401 | if self.args.num_outputs == 2 and metric in ['gaussian_nll', 'laplacian_nll', 'loss']: 402 | val_metrics[metric] = self.error_metrics[metric](val_dict['predictions'], 403 | val_dict['log_variances'], 404 | val_dict['targets']).item() 405 | else: 406 | # denormalize labels and predictions 407 | if self.args.normalize_targets: 408 | predictions_ = 
denormalize(val_dict['predictions'], self.mean_target_train, self.std_target_train) 409 | targets_ = denormalize(val_dict['targets'], self.mean_target_train, self.std_target_train) 410 | val_metrics[metric] = self.error_metrics[metric](predictions_, targets_).item() 411 | else: 412 | val_metrics[metric] = self.error_metrics[metric](val_dict['predictions'], 413 | val_dict['targets']).item() 414 | return val_dict, val_metrics 415 | 416 | def test(self, model_weights_path=None, dl_test=None): 417 | """ 418 | Test trained model on test data. 419 | 420 | Args: 421 | model_weights_path: path to trained model weights. Default: "best_weights.pt" 422 | dl_test: torch dataloader with test data. Default: self.ds_test is loaded. 423 | 424 | Returns: 425 | test_metrics: Dict with error metrics on test data (including the loss). Used for tensorboard logs. 426 | test_dict: Dict with torch tensors for 'predictions', 'targets', 'variances'. 427 | metric_string: formatted string to print test metrics. 428 | """ 429 | if dl_test is None: 430 | dl_test = DataLoader(self.ds_test, batch_size=self.args.batch_size, shuffle=False, num_workers=self.args.num_workers) 431 | # test performance 432 | 433 | if model_weights_path is None: 434 | model_weights_path = Path(self.out_dir) / 'best_weights.pt' 435 | 436 | # load best model weights 437 | self.model.load_state_dict(torch.load(model_weights_path)) 438 | 439 | test_dict, test_metrics = self.validate(dl_test) 440 | 441 | # convert log(var) to var 442 | if self.args.num_outputs == 2: 443 | test_dict['variances'] = torch.exp(test_dict['log_variances']) 444 | del test_dict['log_variances'] 445 | 446 | # denormalize predictions and targets 447 | if self.args.normalize_targets: 448 | test_dict['predictions'] = denormalize(test_dict['predictions'], self.mean_target_train, self.std_target_train) 449 | test_dict['targets'] = denormalize(test_dict['targets'], self.mean_target_train, self.std_target_train) 450 | if self.args.num_outputs == 2: 451 | # 
denormalize the variances by multiplying with the target variance 452 | test_dict['variances'] = test_dict['variances'] * self.std_target_train**2 453 | 454 | if self.args.num_outputs == 2: 455 | print('TEST: Number infinite elements in variances: ', self.count_infinite_elements(test_dict['variances'])) 456 | 457 | # convert torch tensor to numpy 458 | for key in test_dict.keys(): 459 | test_dict[key] = test_dict[key].data.cpu().numpy() 460 | 461 | metric_string = 'TEST: ' 462 | for metric in self.error_metrics: 463 | metric_string += ' {}: {:.3f},'.format(metric, test_metrics[metric]) 464 | print(metric_string) 465 | return test_metrics, test_dict, metric_string 466 | 467 | def count_infinite_elements(self, x): 468 | return torch.sum(torch.logical_not(torch.isfinite(x))).item() 469 | 470 | @staticmethod 471 | def _split_dataset_random_nfold_CV(len_dataset, n_folds=10, test_fold_idx=0, quality_indices_train=None, quality_indices_valtest=None): 472 | """ 473 | Split data into n folds for cross-validation. 
474 | """ 475 | # shuffle the samples randomly to create train, val and test sets 476 | indices = np.arange(len_dataset) 477 | # always generate the same random numbers with random seed for testing the code 478 | np.random.seed(2401) 479 | np.random.shuffle(indices) 480 | 481 | # split indices int n folds: 482 | indices_list = np.array_split(indices, n_folds) 483 | 484 | test_indices = indices_list[test_fold_idx] 485 | 486 | # all except indices from fold test fold i 487 | trainval_indices = indices_list[:test_fold_idx] + indices_list[test_fold_idx + 1:] 488 | trainval_indices = np.concatenate(trainval_indices) 489 | 490 | # split training into 90% train and 10% val 491 | train_indices = trainval_indices[:int(0.9 * len(trainval_indices))] 492 | val_indices = trainval_indices[int(0.9 * len(trainval_indices)):] 493 | 494 | def filter_subset_indices(subset_indices, quality_indices): 495 | return subset_indices[quality_indices[subset_indices]] 496 | 497 | if quality_indices_train is not None: 498 | train_indices = filter_subset_indices(train_indices, quality_indices_train) 499 | val_indices = filter_subset_indices(val_indices, quality_indices_valtest) 500 | test_indices = filter_subset_indices(test_indices, quality_indices_valtest) 501 | 502 | return train_indices, val_indices, test_indices 503 | 504 | 505 | @staticmethod 506 | def _split_dataset_by_attribute(attribute, test_attribute, quality_indices_train=None, quality_indices_valtest=None): 507 | """ 508 | Split data using all samples with a specific attribute as test data. 
509 | """ 510 | 511 | test_indices = np.argwhere(attribute == test_attribute) 512 | trainval_indices = np.argwhere(attribute != test_attribute) 513 | 514 | # split training into 80% train and 20% val 515 | train_indices = trainval_indices[:int(0.9 * len(trainval_indices))] 516 | val_indices = trainval_indices[int(0.9 * len(trainval_indices)):] 517 | 518 | def filter_subset_indices(subset_indices, quality_indices): 519 | return subset_indices[quality_indices[subset_indices]] 520 | 521 | if quality_indices_train is not None: 522 | train_indices = filter_subset_indices(train_indices, quality_indices_train) 523 | val_indices = filter_subset_indices(val_indices, quality_indices_valtest) 524 | test_indices = filter_subset_indices(test_indices, quality_indices_valtest) 525 | 526 | return train_indices, val_indices, test_indices 527 | 528 | def filter_subset_indices_by_attribute_range(self, subset_indices, dataset_attribute, range_to_remove): 529 | """ 530 | Returns subset_indices that are within and outside the specified target range 531 | """ 532 | subset_targets = dataset_attribute[subset_indices] 533 | range_indices = (subset_targets > range_to_remove[0]) & (subset_targets < range_to_remove[1]) 534 | 535 | in_dist_indices = subset_indices[~range_indices] 536 | out_dist_indices = subset_indices[range_indices] 537 | return in_dist_indices, out_dist_indices 538 | 539 | def filter_subset_indices_by_attribute(self, subset_indices, dataset_attribute, out_dist_value): 540 | """ 541 | Returns subset_indices that have a specific attribute value as out_dist_indices and 542 | the remaining samples as in_dist_indices. 543 | """ 544 | subset_attribute = dataset_attribute[subset_indices] 545 | in_dist_indices = subset_indices[subset_attribute != out_dist_value] 546 | out_dist_indices = subset_indices[subset_attribute == out_dist_value] 547 | return in_dist_indices, out_dist_indices 548 | 549 | --------------------------------------------------------------------------------