├── MLproject ├── README.md ├── experiment_scripts ├── create_test_data.sh ├── examples │ ├── baselines.sh │ ├── test_destination.sh │ ├── test_hour_day.sh │ ├── test_length_of_history.sh │ ├── test_loss.sh │ ├── test_model_type.sh │ ├── test_time_gaps.sh │ └── test_weather.sh ├── run_script_1.sh ├── run_script_2.sh ├── run_script_3.sh ├── run_script_4.sh └── run_script_5.sh ├── model_fitting_environment.yml ├── processing ├── cleaner.py ├── config │ ├── config.py │ ├── dataset_config.py │ ├── navigation_statuses.csv │ └── vessel_type_codes.csv ├── current_aggregator.py ├── current_appender.py ├── current_downloader.py ├── destination_appender.py ├── downloader.py ├── formatter.py ├── interpolator.py ├── process.sh ├── processing_step.py ├── sliding_window.py └── utils.py ├── processing_environment.yml ├── resources_and_information ├── ais_data_faq_from_marine_cadastre.pdf ├── coast_guard_mmsi_document.pdf ├── vessel_type_codes_2018.pdf └── vessel_type_guide.pdf └── tests ├── config ├── config.py └── dataset_config.py ├── create_data.py ├── fit_and_evaluate_model.py ├── loading ├── __init__.py ├── data_loader.py ├── disk_array.py ├── generator.py ├── loading.py └── normalizer.py ├── models ├── __init__.py ├── direct_fusion_runner.py ├── direct_rnn_runner.py ├── iterative_rnn_runner.py ├── losses.py ├── median_stopping.py ├── model_runner.py ├── seq2seq_model_pieces.py └── seq2seq_runner.py └── utils ├── __init__.py ├── arg_validation.py ├── processor_manager.py ├── test_arg_parser.py └── utils.py /MLproject: -------------------------------------------------------------------------------- 1 | name: Ships 2 | 3 | conda_env: model_fitting_environment.yml 4 | 5 | entry_points: 6 | test_time_gaps: 7 | parameters: 8 | time_gap: float 9 | dataset_name: str 10 | command: 'python tests/fit_and_evaluate_model.py {time_gap} {dataset_name}' 11 | create_data: 12 | parameters: 13 | time_gap: float 14 | dataset_name: str 15 | command: 'python tests/create_data.py {time_gap} {dataset_name}' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vessel trajectory prediction with recurrent neural networks: An evaluation of datasets, features, and architectures 2 | 3 | Official implementation of the Journal of Ocean Engineering and Science paper. Please create an issue and tag me (@isaacOnline) if you have any questions. 4 | 5 | ## Repository Overview 6 | This repository contains code for downloading and processing AIS and weather data, as well as for fitting 7 | vessel trajectory prediction models. 8 | 9 | ### Organization 10 | * The [processing](processing) directory contains all Python files used for creating a dataset of AIS messages, as well as 11 | a script for doing so. See [Creating a New Dataset](#creating-a-new-dataset) below for instructions. 12 | * The [tests](tests) directory contains all Python files used for fitting models. See [Fitting and Evaluating 13 | Models](#fitting-and-evaluating-models) below for more details. 14 | * The [experiment_scripts](experiment_scripts) directory contains shell scripts that can serve as examples for fitting models. 15 | * The [resources_and_information](resources_and_information) directory contains documents with extra information about the datasets we used.
16 | 17 | ### Python Environments 18 | There are two distinct Python environments to use with this repository, [one for processing](processing_environment.yml), 19 | and [one for fitting models](model_fitting_environment.yml). 20 | The yaml specification files for these environments are provided at the root level. It will be necessary to create the 21 | processing environment from the yaml file yourself if you want to run anything in the processing directory, but if you 22 | use the below instructions to train models, mlflow will create the model fitting environment for you. 23 | 24 | ## Fitting and Evaluating Models 25 | Before fitting any models, make sure that you have set the paths in the [config file](tests/config/config.py) to point 26 | to where the data is stored on your machine. 27 | 28 | This project uses [mlflow](https://www.mlflow.org) for model training, saving model artifacts, and storing data 29 | about their performance. The mlflow library can be installed using conda, and can then be run from the command line. 30 | The first time you train a model, mlflow will recreate the conda environment that we have been using for model fitting, 31 | which it will then reload when you run models in the future. (If you'd like to recreate this environment 32 | without mlflow, you can use the model_fitting_environment.yml specifications file). 33 | 34 | The script for training and evaluating models can be kicked off using the 'test_time_gaps' mlflow entry point. This is 35 | done by running the following from the command line. **Other command line arguments will need to be passed in for the 36 | entry point to work properly. Please use the shell scripts in [experiment_scripts/examples]( 37 | experiment_scripts/examples) as templates.** 38 | 39 | > mlflow run . -e test_time_gaps 40 | 41 | Mlflow must be kicked off from the project's root directory. Running the command will have mlflow run the 42 | fit_and_evaluate_model.py script. Provided that the correct preprocessing has been run beforehand, all of the features of 43 | the experiment can be controlled using command line arguments. Mlflow uses the syntax -P argname=argvalue, for example 44 | '-P weather=ignore'. For a full list of possible arguments and values, see the 45 | [TestArgParser](tests/utils/test_arg_parser.py) object in tests/utils, or run the following command from the root 46 | directory: 47 | > python tests/fit_and_evaluate_model.py -h 48 | 49 | Some arguments are dependent on one another, but the script should warn you if you enter an invalid combination. 50 | 51 | ## Creating a New Dataset 52 | The first step in creating a new dataset is to add specifications to the 53 | configuration files found in [processing/config/dataset_config.py](processing/config/dataset_config.py) and 54 | [tests/config/dataset_config.py](tests/config/dataset_config.py). You will need 55 | to specify a name for the dataset, the latitude/longitude boundaries, the desired amount of time to move 56 | the window forward when performing sliding window calculations, and an initial set of min_pts values to try when using 57 | DBSCAN to create destination clusters. See current examples in the *dataset_config.py* files 58 | for reference. Make sure to include the same information in both of the *dataset_config.py* files.
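For illustration only, a new entry might look like the sketch below. The region name and bounding-box coordinates here are made up, and the keyword arguments follow the `DatasetConfig` constructor in [processing/config/dataset_config.py](processing/config/dataset_config.py); the same entry would be added to the `datasets` dictionary in both files.

```python
datasets = {
    # ... existing entries ('florida_gulf', 'california_coast', 'new_york') ...
    'puget_sound':  # hypothetical region name -- replace with your own
        DatasetConfig(
            dataset_name='puget_sound',
            # Bounding box of the region (two opposite corners); placeholder values
            lat_1=47.00, lon_1=-123.20, lat_2=48.50, lon_2=-122.20,
            # Move the sliding window forward 15 minutes at a time
            sliding_window_movement=15 * 60,
            # Initial DBSCAN min_pts values to evaluate when clustering destinations
            min_pts_to_try=[4, 10, 20, 50, 100, 250, 500],
            # eps_to_try, min_pts_to_use, and eps_to_use can be filled in later,
            # once the destination appender has produced its diagnostic plots
        ),
}
```

The remaining DBSCAN arguments are chosen after inspecting the plots described in [Using the Destination Appender](#using-the-destination-appender).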
Also make sure to 59 | update the paths in [processing/config/config.py](processing/config/config.py) and 60 | [tests/config/config.py](tests/config/config.py). 61 | 62 | The second step is to change the dataset names in [process.sh](processing/process.sh), then to run the first section of 63 | [process.sh](processing/process.sh). Make sure that you have created and activated the processing conda environment before doing so, and that you 64 | run [process.sh](processing/process.sh) from the processing directory. Once the first section is done, you will then need to specify parameters 65 | for DBSCAN to use when creating destination clusters (see [Using the Destination 66 | Appender](#using-the-destination-appender) 67 | below for details). The final step is to run the second section of [process.sh](processing/process.sh), which will 68 | calculate the destination clusters and perform the rest of the preprocessing steps. 69 | 70 | If you are interested in changing other features of the experiment that aren't defined in dataset_config.py, most 71 | relevant constants have been defined in [processing/config/config.py](processing/config/config.py) and 72 | [tests/config/config.py](tests/config/config.py). For example, if you 73 | want to change the vessel types to also include military vessels or the maximum SOG value that messages can 74 | have without getting filtered out, you can do so here. Make sure that these two config files are kept in sync. 75 | 76 | ## Using the Destination Appender 77 | Because we did not have access to reliable destination data, we elected to calculate our own destination clusters for 78 | trajectories using DBSCAN. DBSCAN requires the selection of two parameters: min_pts and eps. I would recommend reviewing 79 | [these](https://www.aaai.org/Papers/KDD/1996/KDD96-037.pdf?source=post_page) 80 | [papers](https://dl.acm.org/doi/pdf/10.1145/3068335) for information on how to best choose these parameters. The 81 | destination appender is set up to create plots to help with the choice, if parameter values have not already 82 | been selected. 83 | 84 | If you are not interested in using destination data, you can edit the from_dir in 85 | [current_appender.py](processing/current_appender.py#L22) from "interpolated_with_destination" to "interpolated". You 86 | can then remove the "destination_appender.py" lines in [process.sh](processing/process.sh). 87 | This will simply skip over the destination appending step. Otherwise, feel free to follow the steps below for choosing 88 | the min_pts/eps values for DBSCAN, or to set them on your own. 89 | 90 | 1. Before doing anything else, set a range of min_pts values to try in 91 | [processing/config/dataset_config.py](processing/config/dataset_config.py). I tried the values: 4, 10, 20, 50, 100, 92 | 250, and 500 as defaults, but you may want to switch these if your dataset is of a significantly different size. 93 | 2. Run the destination appender script with these values. This will create an elbow plot (in the *artifacts/ 94 | destination_appending* subdirectory of your data directory), which will help you set a range of eps values to try. 95 | The range of eps values should also be set in [dataset_config.py](processing/config/dataset_config.py). 96 | 3. Once you've selected both parameter ranges, run the destination appender again. This will create plots of a number of 97 | cluster quality metrics for each of the min_pts/eps values, which you can use to select a min_pts value.
(Again, 98 | set this in [dataset_config.py](processing/config/dataset_config.py).) 99 | 4. Run the destination appender again. This will create plots showing the clusters. Use these plots to 100 | select a final value of eps to use. 101 | 5. Run the destination appender a final time - this will actually calculate and append the destination values. 102 | 103 | ## Memory Issues 104 | I encountered a number of memory issues when creating datasets, so the repository is currently optimized to conserve 105 | memory when possible. (The data processing is fairly slow for this reason.) Despite this, if the desired region or time 106 | period is large enough, users may still run out of memory and have their scripts killed when creating their own 107 | datasets. The below summarizes some of the changes you can make to get around these problems. 108 | 109 | * If the dataset is too large at the cleaning step, you can try changing the shuffle method that dask uses to 110 | 'disk' [here](processing/cleaner.py#L301). The dataset needs to be sorted by MMSI and timestamp during this step, which 111 | is what creates the bottleneck. 112 | * If the dataset is too large at the formatting step, you can try changing the 113 | [partition size](processing/formatter.py#L254) to a smaller value. This will make the partitions smaller. 114 | * If the dataset is small enough at the formatting step, you will be able to change the 115 | [conserve_memory flag](processing/formatter.py#L66) in *formatter.py* to False, which will significantly speed up the 116 | final processing step. After the windowing step, the dataset is stored in chunks, and if the conserve_memory flag is set 117 | to True (the default) these chunks will be processed iteratively, saving the processed versions to disk separately. 118 | Otherwise, the chunks can be processed in parallel, then later combined together. 119 | 120 | 121 | -------------------------------------------------------------------------------- /experiment_scripts/examples/test_model_type.sh: -------------------------------------------------------------------------------- 1 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=florida_gulf --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 2 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=florida_gulf --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 3 | mlflow run . 
-e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=florida_gulf --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 4 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=florida_gulf --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=2048 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 5 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=florida_gulf --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 6 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=florida_gulf --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=128 -P dense_layer_size=260 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 7 | 8 | 9 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 10 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 11 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 12 | mlflow run . 
-e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=2048 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 13 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 14 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=128 -P dense_layer_size=260 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 15 | 16 | 17 | 18 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 19 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 20 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=iterative -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 21 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=2048 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 22 | mlflow run . 
-e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 23 | mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=attention_seq2seq -P sog_cog=raw -P batch_size=128 -P dense_layer_size=260 -P direction=forward_only -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 24 | -------------------------------------------------------------------------------- /experiment_scripts/run_script_3.sh: -------------------------------------------------------------------------------- 1 | # Hour/Day NY 2 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=ignore -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=3 3 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=ignore -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 4 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=ignore -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 5 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=2 -P loss=haversine -P time_of_day=ignore -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 6 | 7 | #mlflow run . 
-e test_time_gaps -P time_gap=15 -P length_of_history=3 -P hours_out=2 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 8 | 9 | #mlflow run . -e test_time_gaps -P time_gap=15 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=3 10 | 11 | #mlflow run . -e test_time_gaps -P time_gap=15 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 12 | #mlflow run . -e test_time_gaps -P time_gap=15 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 13 | #mlflow run . -e test_time_gaps -P time_gap=15 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 14 | #mlflow run . 
-e test_time_gaps -P time_gap=15 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 15 | mlflow run . -e test_time_gaps -P time_gap=15 -P length_of_history=3 -P hours_out=2 -P loss=haversine -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 16 | -------------------------------------------------------------------------------- /experiment_scripts/run_script_4.sh: -------------------------------------------------------------------------------- 1 | # FL LOSS + WEATHER 2 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=mse -P time_of_day=hour_day -P weather=ignore -P dataset_name=florida_gulf --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 3 | 4 | 5 | # NY WEATHER 6 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=3 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=3 7 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=3 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 8 | #mlflow run . 
-e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=3 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 9 | 10 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=2 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=3 11 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=2 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 12 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=2 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 13 | 14 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=3 15 | #mlflow run . 
-e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 16 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=haversine -P time_of_day=hour_day -P weather=ignore -P dataset_name=new_york --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 17 | 18 | 19 | # Loss CA 20 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 21 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=3 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 22 | 23 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=2 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=3 24 | #mlflow run . 
-e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=2 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 25 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=2 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 26 | 27 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=all_nodes -P batch_size=2048 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.007503556066789443 -P number_of_dense_layers=2 -P number_of_rnn_layers=5 -P rnn_layer_size=91 -P dense_layer_size=263 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=3 28 | #mlflow run . -e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=1024 -P dense_layer_size=65 -P direction=forward_only -P distance_traveled=ignore -P layer_type=lstm -P learning_rate=0.0007636935263538555 -P number_of_dense_layers=0 -P number_of_rnn_layers=4 -P rnn_layer_size=141 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=1 29 | #mlflow run . 
-e test_time_gaps -P time_gap=30 -P length_of_history=3 -P hours_out=1 -P loss=mse -P time_of_day=hour_day -P weather=currents -P dataset_name=california_coast --experiment-name 'Final' -P model_type=long_term_fusion -P sog_cog=raw -P rnn_to_dense_connection=final_node -P batch_size=128 -P dense_layer_size=260 -P direction=bidirectional -P distance_traveled=ignore -P layer_type=gru -P learning_rate=4.3377454427327665e-05 -P number_of_dense_layers=0 -P number_of_rnn_layers=5 -P rnn_layer_size=280 -P extended_recurrent_idxs=vt_dst_and_time -P number_of_fusion_weather_layers=2 30 | -------------------------------------------------------------------------------- /model_fitting_environment.yml: -------------------------------------------------------------------------------- 1 | name: ships_fitting 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=1_gnu 8 | - abseil-cpp=20210324.2=h9c3ff4c_0 9 | - absl-py=1.0.0=pyhd8ed1ab_0 10 | - aiohttp=3.8.1=py310h6acc77f_0 11 | - aiosignal=1.2.0=pyhd8ed1ab_0 12 | - astunparse=1.6.3=pyhd8ed1ab_0 13 | - async-timeout=4.0.2=pyhd8ed1ab_0 14 | - attrs=21.4.0=pyhd8ed1ab_0 15 | - blas=1.0=openblas 16 | - blinker=1.4=py_1 17 | - bottleneck=1.3.2=py310h9102076_1 18 | - brotlipy=0.7.0=py310h6acc77f_1003 19 | - bzip2=1.0.8=h7f98852_4 20 | - c-ares=1.18.1=h7f98852_0 21 | - ca-certificates=2021.10.26=h06a4308_2 22 | - cached-property=1.5.2=hd8ed1ab_1 23 | - cached_property=1.5.2=pyha770c72_1 24 | - cachetools=4.2.4=pyhd8ed1ab_0 25 | - certifi=2020.6.20=pyhd3eb1b0_3 26 | - cffi=1.15.0=py310h0fdd8cc_0 27 | - charset-normalizer=2.0.12=pyhd8ed1ab_0 28 | - click=8.0.3=py310hff52083_1 29 | - cryptography=36.0.1=py310h685ca39_0 30 | - cudatoolkit=11.6.0=habf752d_10 31 | - cudnn=8.2.1.32=h86fa8c9_0 32 | - frozenlist=1.3.0=py310h6acc77f_0 33 | - gast=0.4.0=pyh9f0ad1d_0 34 | - giflib=5.2.1=h36c2ea0_2 35 | - google-auth=1.35.0=pyh6c4a22f_0 36 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 37 | - google-pasta=0.2.0=pyh8c360ce_0 38 | - grpc-cpp=1.42.0=ha1441d3_1 39 | - grpcio=1.42.0=py310h94ab34a_0 40 | - h5py=3.6.0=nompi_py310he751f51_100 41 | - hdf5=1.12.1=nompi_h2750804_103 42 | - icu=69.1=h9c3ff4c_0 43 | - idna=3.3=pyhd8ed1ab_0 44 | - importlib-metadata=4.11.1=py310hff52083_0 45 | - jpeg=9e=h7f98852_0 46 | - keras=2.7.0=pyhd8ed1ab_0 47 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0 48 | - krb5=1.19.2=hcc1bbae_3 49 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 50 | - libblas=3.9.0=13_linux64_openblas 51 | - libcblas=3.9.0=13_linux64_openblas 52 | - libcurl=7.81.0=h2574ce0_0 53 | - libedit=3.1.20191231=he28a2e2_2 54 | - libev=4.33=h516909a_1 55 | - libffi=3.4.2=h7f98852_5 56 | - libgcc-ng=11.2.0=h1d223b6_12 57 | - libgfortran-ng=11.2.0=h69a702a_12 58 | - libgfortran5=11.2.0=h5c6108e_12 59 | - libgomp=11.2.0=h1d223b6_12 60 | - liblapack=3.9.0=13_linux64_openblas 61 | - libnghttp2=1.46.0=h812cca2_0 62 | - libnsl=2.0.0=h7f98852_0 63 | - libopenblas=0.3.18=pthreads_h8fe5266_0 64 | - libpng=1.6.37=h21135ba_2 65 | - libprotobuf=3.19.4=h780b84a_0 66 | - libssh2=1.10.0=ha56f1ee_2 67 | - libstdcxx-ng=11.2.0=he4da1e4_12 68 | - libuuid=2.32.1=h7f98852_1000 69 | - libzlib=1.2.11=h36c2ea0_1013 70 | - markdown=3.3.6=pyhd8ed1ab_0 71 | - multidict=6.0.2=py310h6acc77f_0 72 | - nccl=2.11.4.1=h5c60f85_2 73 | - ncurses=6.3=h9c3ff4c_0 74 | - numexpr=2.7.3=py310hfd7a2a2_1 75 | - numpy=1.22.2=py310h454958d_0 76 | - oauthlib=3.2.0=pyhd8ed1ab_0 77 | - openssl=1.1.1m=h7f8727e_0 78 | - opt_einsum=3.3.0=pyhd8ed1ab_1 79 | - 
pandas=1.4.1=py310h295c915_0 80 | - pip=22.0.3=pyhd8ed1ab_0 81 | - protobuf=3.19.4=py310h122e73d_0 82 | - pyasn1=0.4.8=py_0 83 | - pyasn1-modules=0.2.7=py_0 84 | - pycparser=2.21=pyhd8ed1ab_0 85 | - pyjwt=2.3.0=pyhd8ed1ab_1 86 | - pyopenssl=22.0.0=pyhd8ed1ab_0 87 | - pysocks=1.7.1=py310hff52083_4 88 | - python=3.10.2=h85951f9_3_cpython 89 | - python-dateutil=2.8.2=pyhd3eb1b0_0 90 | - python-flatbuffers=2.0=pyhd8ed1ab_0 91 | - python_abi=3.10=2_cp310 92 | - pytz=2021.3=pyhd3eb1b0_0 93 | - pyu2f=0.1.5=pyhd8ed1ab_0 94 | - re2=2021.11.01=h9c3ff4c_0 95 | - readline=8.1=h46c0cb4_0 96 | - requests=2.27.1=pyhd8ed1ab_0 97 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 98 | - rsa=4.8=pyhd8ed1ab_0 99 | - scipy=1.8.0=py310hea5193d_1 100 | - setuptools=60.9.1=py310hff52083_0 101 | - six=1.16.0=pyh6c4a22f_0 102 | - snappy=1.1.8=he1b5a44_3 103 | - sqlite=3.37.0=h9cd32fc_0 104 | - tensorboard=2.6.0=pyhd8ed1ab_1 105 | - tensorboard-data-server=0.6.0=py310h685ca39_1 106 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 107 | - tensorflow=2.7.0=cuda112py310he87a039_0 108 | - tensorflow-base=2.7.0=cuda112py310h2bd284a_0 109 | - tensorflow-estimator=2.7.0=cuda112py310h922d117_0 110 | - termcolor=1.1.0=py_2 111 | - tk=8.6.11=h27826a3_1 112 | - typing-extensions=4.1.1=hd8ed1ab_0 113 | - typing_extensions=4.1.1=pyha770c72_0 114 | - tzdata=2021e=he74cb21_0 115 | - urllib3=1.26.8=pyhd8ed1ab_1 116 | - werkzeug=2.0.3=pyhd8ed1ab_1 117 | - wheel=0.37.1=pyhd8ed1ab_0 118 | - wrapt=1.13.3=py310h6acc77f_1 119 | - xz=5.2.5=h516909a_1 120 | - yarl=1.7.2=py310h6acc77f_1 121 | - zipp=3.7.0=pyhd8ed1ab_1 122 | - zlib=1.2.11=h36c2ea0_1013 123 | - pip: 124 | - alembic==1.7.6 125 | - cloudpickle==2.0.0 126 | - databricks-cli==0.16.4 127 | - docker==5.0.3 128 | - entrypoints==0.4 129 | - flask==2.0.3 130 | - gitdb==4.0.9 131 | - gitpython==3.1.27 132 | - greenlet==1.1.2 133 | - gunicorn==20.1.0 134 | - haversine==2.5.1 135 | - itsdangerous==2.1.0 136 | - jinja2==3.0.3 137 | - mako==1.1.6 138 | - markupsafe==2.1.0 139 | - mlflow==1.23.1 140 | - packaging==21.3 141 | - prometheus-client==0.13.1 142 | - prometheus-flask-exporter==0.18.7 143 | - pyparsing==3.0.7 144 | - pyyaml==6.0 145 | - querystring-parser==1.2.4 146 | - smmap==5.0.0 147 | - sqlalchemy==1.4.31 148 | - sqlparse==0.4.2 149 | - tabulate==0.8.9 150 | - websocket-client==1.2.3 151 | -------------------------------------------------------------------------------- /processing/config/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | 5 | # Set a base directory to use for data storage. You should change this value. 6 | global data_directory 7 | data_directory = '/home/isaac/data/' 8 | 9 | global dataset_config 10 | 11 | # Define start/end years to look at. Currently, the earliest supported year is 2015. Data prior to 2015 uses a different 12 | # url and will also need to be preprocessed slightly differently - check the ais_data_faq_from_marine_cadastre.pdf in 13 | # resources_and_information for details. 14 | global start_year 15 | start_year = 2015 16 | assert start_year >= 2015 17 | 18 | global end_year 19 | end_year = 2019 20 | 21 | global years 22 | years = range(start_year, end_year + 1) 23 | 24 | # Url to download data from. This should not be changed. 25 | global base_url 26 | base_url = 'https://coast.noaa.gov/htdata/CMSP/AISDataHandler/' 27 | 28 | # The length of time between AIS messages for the trajectory to be considered a new one, in seconds. 
Currently 29 | # set to two hours 30 | global new_trajectory_time_gap 31 | new_trajectory_time_gap = 120 * 60 32 | 33 | # The maximum sog that a message can have without being removed from the dataset 34 | global sog_cutoff 35 | sog_cutoff = 30 36 | 37 | # The maximum empirical speed that a message can have without being removed from the dataset, in knots 38 | global empirical_speed_cutoff 39 | empirical_speed_cutoff = 40 40 | 41 | # The number of seconds between timestamps when interpolating. Currently set to 5 minutes 42 | global interpolation_time_gap 43 | interpolation_time_gap = 5 * 60 44 | 45 | # Number of *timestamps* to use for prediction, and to predict into the future. The current setting uses three hours of 46 | # history and predicts three hours into the future 47 | global length_of_history 48 | length_of_history = int(3 * 60 * 60 / interpolation_time_gap) + 1 49 | 50 | global length_into_the_future 51 | length_into_the_future = int(3 * 60 * 60 / interpolation_time_gap) - 1 52 | 53 | # The length of time needed for a track to be kept in the dataset, in seconds. Should not be edited directly 54 | global min_track_length 55 | min_track_length = (length_of_history + length_into_the_future) * interpolation_time_gap 56 | 57 | # The vessel groups to keep in the analysis. 58 | # Other valid vessel types that may be used are 'other' and 'military' 59 | global vessel_types 60 | vessel_types = [ 61 | 'cargo', 62 | 'passenger', 63 | 'fishing', 64 | 'tug tow', 65 | 'tanker', 66 | 'pleasure craft or sailing', 67 | ] 68 | 69 | # Statuses to be kept in the analysis. 70 | # See columns in navigation_statuses.csv for other possible values 71 | global desired_statuses 72 | desired_statuses = [ 73 | 'under way sailing', 74 | 'under way using engine', 75 | 'undefined' 76 | ] 77 | 78 | # The different time gaps to create datasets for. Currently only testing 15 and 30 minute time gaps 79 | global time_gaps 80 | time_gaps = [min * 60 for min in [15, 30]] 81 | 82 | # Variables to take from the ocean currents dataset. 83 | # Other possible values are 'salinity' and 'water_temp'. See NOAA website for more. 84 | global currents_variables 85 | currents_variables = ['water_u', 'water_v'] 86 | 87 | 88 | 89 | # Categorical columns that need to be one hot encoded. Only change this if you change preprocessing to add in other 90 | # columns. 91 | global categorical_columns 92 | categorical_columns = ['vessel_group','destination_cluster'] 93 | 94 | # Used for preprocessing of currents dataset. All other values have been deprecated 95 | global currents_window 96 | currents_window = 'stable' 97 | 98 | # Files to use for data cleaning. Do not change. 
99 | global statuses 100 | statuses = pd.read_csv('config/navigation_statuses.csv') 101 | global types 102 | types = pd.read_csv('config/vessel_type_codes.csv') 103 | 104 | # Function for setting the log level 105 | def set_log_level(level): 106 | global log_level 107 | if level == 0: 108 | log_level = logging.CRITICAL 109 | elif level == 1: 110 | log_level = logging.ERROR 111 | elif level == 2: 112 | log_level = logging.WARNING 113 | elif level == 3: 114 | log_level = logging.INFO 115 | elif level == 4: 116 | log_level = logging.DEBUG 117 | -------------------------------------------------------------------------------- /processing/config/dataset_config.py: -------------------------------------------------------------------------------- 1 | class DatasetConfig(): 2 | def __init__(self, dataset_name, 3 | lat_1, lat_2, lon_1, lon_2, 4 | sliding_window_movement, 5 | depth_1=0, depth_2=0, 6 | min_pts_to_try = None, eps_to_try = None, 7 | min_pts_to_use = None, eps_to_use = None): 8 | self.dataset_name = dataset_name 9 | self.lat_1 = min(lat_1, lat_2) 10 | self.lat_2 = max(lat_1, lat_2) 11 | self.lon_1 = min(lon_1, lon_2) 12 | self.lon_2 = max(lon_1, lon_2) 13 | self.corner_1 = (lat_1, lon_1) 14 | self.corner_2 = (lat_2, lon_2) 15 | # When expanding data using a sliding window, the sliding_window_movement is the length of time between windows. 16 | # E.g. if sliding_window_length is 10 * 60, tracks are supposed to be made up of three timestamps, and the 17 | # interpolated trajectory has timestamps at [0, 5, 10, 15, 20], then the trajectories output 18 | # will be [0, 5, 10], and [10, 15, 20] 19 | self.sliding_window_movement = sliding_window_movement 20 | self.min_pts_to_try = min_pts_to_try 21 | self.eps_to_try = eps_to_try 22 | self.min_pts_to_use = min_pts_to_use 23 | self.eps_to_use = eps_to_use 24 | self.depth_1 = depth_1 25 | self.depth_2 = depth_2 26 | 27 | 28 | 29 | datasets = { 30 | 'florida_gulf': 31 | DatasetConfig( 32 | dataset_name='florida_gulf', 33 | lat_1=26.00, lon_1=-85.50, lat_2=29.00, lon_2=-81.50, 34 | min_pts_to_try=[4, 10, 20, 50, 100, 250, 500], 35 | eps_to_try=[0.0001, 0.00025, 0.0005, 0.00075, 36 | 0.001, 0.0025, 0.005, 0.0075, 37 | 0.01, 0.025, 0.05, 0.075, 38 | 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 39 | 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 40 | 6, 7, 8, 9, 10], 41 | min_pts_to_use=50, 42 | eps_to_use=5, 43 | sliding_window_movement=15 * 60 44 | ), 45 | 'california_coast': 46 | DatasetConfig( 47 | dataset_name='california_coast', 48 | lat_1=33.40, lon_1=-122.00, lat_2=36.40, lon_2=-118.50, 49 | sliding_window_movement=15 * 60, 50 | min_pts_to_try = [4, 10, 20, 50, 100, 250, 500], 51 | eps_to_try=[1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5], 52 | min_pts_to_use = 50, 53 | eps_to_use = 3 54 | ), 55 | 'new_york': 56 | DatasetConfig( 57 | dataset_name='new_york', 58 | lat_1=39.50, lon_1=-74.50, lat_2=41.50, lon_2=-71.50, 59 | sliding_window_movement=60 * 60, 60 | min_pts_to_try=[4, 10, 20, 50, 100, 250, 500], 61 | eps_to_try=[0.0001, 0.00025, 0.0005, 0.00075, 62 | 0.001, 0.0025, 0.005, 0.0075, 63 | 0.01, 0.025, 0.05, 0.075, 64 | 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 65 | 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 66 | 6, 7, 8, 9, 10], 67 | min_pts_to_use=50, eps_to_use=3 68 | ) 69 | } 70 | -------------------------------------------------------------------------------- /processing/config/navigation_statuses.csv: -------------------------------------------------------------------------------- 1 | Navigation Status,Description 2 | -1,Undefined 3 | -1,Power-driven vessel 
pushing ahead or towing alongside 4 | -1,Power-driven vessel towing astern 5 | 0,Under way using engine 6 | 1,At anchor 7 | 2,Not under command 8 | 3,Restricted maneuverability 9 | 4,Constrained by her draught 10 | 5,Moored 11 | 6,Aground 12 | 7,Engaged in Fishing 13 | 8,Under way sailing 14 | 9,Reserved for future use (9) 15 | 10,Reserved for future use (10) 16 | 11,Reserved for future use (11) 17 | 12,Reserved for future use (12) 18 | 13,Reserved for future use (13) 19 | 14,AIS-SART (active); MOB-AIS; EPIRB-AIS 20 | 15,Not defined (default) -------------------------------------------------------------------------------- /processing/config/vessel_type_codes.csv: -------------------------------------------------------------------------------- 1 | Vessel Group,VesselType,Description 2 | Null,-1,Null 3 | Not Available,0,"Not available or no ship, default" 4 | Other,1 to 19,Reserved for future use 5 | Other,20,"Wing in ground (WIG), all ships of this type" 6 | Tug Tow,21,"Wing in ground (WIG), hazardous category A" 7 | Tug Tow,22,"Wing in ground (WIG), hazardous category B" 8 | Other,23,"Wing in ground (WIG), hazardous category C" 9 | Other,24,"Wing in ground (WIG), hazardous category D" 10 | Other,25,"Wing in ground (WIG), reserved for future use" 11 | Other,26,"Wing in ground (WIG), reserved for future use" 12 | Other,27,"Wing in ground (WIG), reserved for future use" 13 | Other,28,"Wing in ground (WIG), reserved for future use" 14 | Other,29,"Wing in ground (WIG), reserved for future use" 15 | Fishing,30,Fishing 16 | Tug Tow,31,Towing 17 | Tug Tow,32,Towing: length exceeds 200m or breadth exceeds 25m 18 | Other,33,Dredging or underwater operations 19 | Other,34,Diving operations 20 | Military,35,Military operations 21 | Pleasure Craft/Sailing,36,Sailing 22 | Pleasure Craft/Sailing,37,Pleasure Craft 23 | Other,38,Reserved 24 | Other,39,Reserved 25 | Other,40,"High speed craft (HSC), all ships of this type" 26 | Other,41,"High speed craft (HSC), hazardous category A" 27 | Other,42,"High speed craft (HSC), hazardous category B" 28 | Other,43,"High speed craft (HSC), hazardous category C" 29 | Other,44,"High speed craft (HSC), hazardous category D" 30 | Other,45,"High speed craft (HSC), reserved for future use" 31 | Other,46,"High speed craft (HSC), reserved for future use" 32 | Other,47,"High speed craft (HSC), reserved for future use" 33 | Other,48,"High speed craft (HSC), reserved for future use" 34 | Other,49,"High speed craft (HSC), no additional information" 35 | Other,50,Pilot Vessel 36 | Other,51,Search and Rescue vessel 37 | Tug Tow,52,Tug 38 | Other,53,Port Tender 39 | Other,54,Anti-pollution equipment 40 | Other,55,Law Enforcement 41 | Other,56,Spare - for assignment to local vessel 42 | Other,57,Spare - for assignment to local vessel 43 | Other,58,Medical Transport 44 | Other,59,Ship according to RR Resolution No. 
18 45 | Passenger,60,"Passenger, all ships of this type" 46 | Passenger,61,"Passenger, hazardous category A" 47 | Passenger,62,"Passenger, hazardous category B" 48 | Passenger,63,"Passenger, hazardous category C" 49 | Passenger,64,"Passenger, hazardous category D" 50 | Passenger,65,"Passenger, reserved for future use" 51 | Passenger,66,"Passenger, reserved for future use" 52 | Passenger,67,"Passenger, reserved for future use" 53 | Passenger,68,"Passenger, reserved for future use" 54 | Passenger,69,"Passenger, no additional information" 55 | Cargo,70,"Cargo, all ships of this type" 56 | Cargo,71,"Cargo, hazardous category A" 57 | Cargo,72,"Cargo, hazardous category B" 58 | Cargo,73,"Cargo, hazardous category C" 59 | Cargo,74,"Cargo, hazardous category D" 60 | Cargo,75,"Cargo, reserved for future use" 61 | Cargo,76,"Cargo, reserved for future use" 62 | Cargo,77,"Cargo, reserved for future use" 63 | Cargo,78,"Cargo, reserved for future use" 64 | Cargo,79,"Cargo, no additional information" 65 | Tanker,80,"Tanker, all ships of this type" 66 | Tanker,81,"Tanker, hazardous category A" 67 | Tanker,82,"Tanker, hazardous category B" 68 | Tanker,83,"Tanker, hazardous category C" 69 | Tanker,84,"Tanker, hazardous category D" 70 | Tanker,85,"Tanker, reserved for future use" 71 | Tanker,86,"Tanker, reserved for future use" 72 | Tanker,87,"Tanker, reserved for future use" 73 | Tanker,88,"Tanker, reserved for future use" 74 | Tanker,89,"Tanker, no additional information" 75 | Other,90,"Other Type, all ships of this type" 76 | Other,91,"Other Type, hazardous category A" 77 | Other,92,"Other Type, hazardous category B" 78 | Other,93,"Other Type, hazardous category C" 79 | Other,94,"Other Type, hazardous category D" 80 | Other,95,"Other Type, reserved for future use" 81 | Other,96,"Other Type, reserved for future use" 82 | Other,97,"Other Type, reserved for future use" 83 | Other,98,"Other Type, reserved for future use" 84 | Other,99,"Other Type, no additional information" 85 | Other,100 to 199,Reserved for regional use 86 | Other,200 to 255,Reserved for future use 87 | Other,256 to 999,No designation 88 | Fishing,1001,Commercial Fishing Vessel 89 | Fishing,1002,Fish Processing Vessel 90 | Cargo,1003,Freight Barge 91 | Cargo,1004,Freight Ship 92 | Other,1005,Industrial Vessel 93 | Other,1006,Miscellaneous Vessel 94 | Other,1007,Mobile Offshore Drilling Unit 95 | Other,1008,Non-vessel 96 | Other,1009,NON-VESSEL 97 | Other,1010,Offshore Supply Vessel 98 | Other,1011,Oil Recovery 99 | Passenger,1012,Passenger (Inspected) 100 | Passenger,1013,Passenger (Uninspected) 101 | Passenger,1014,Passenger Barge (Inspected) 102 | Passenger,1015,Passenger Barge (Uninspected) 103 | Cargo,1016,Public Freight 104 | Tanker,1017,Public Tankship/Barge 105 | Other,1018,"Public Vessel, Unclassified" 106 | Pleasure Craft/Sailing,1019,Recreational 107 | Other,1020,Research Vessel 108 | Military,1021,SAR Aircraft 109 | Other,1022,School Ship 110 | Tug Tow,1023,Tank Barge 111 | Tanker,1024,Tank Ship 112 | Tug Tow,1025,Towing Vessel -------------------------------------------------------------------------------- /processing/current_aggregator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import dask 5 | 6 | import dask.dataframe as dd 7 | import pandas as pd 8 | import numpy as np 9 | 10 | from config import config 11 | from config.dataset_config import datasets 12 | from processing_step import ProcessingStep 13 | 14 | 15 | class 
WeatherAggregator(ProcessingStep): 16 | """ 17 | Class for preprocessing the ocean currents dataset 18 | 19 | """ 20 | def __init__(self): 21 | super().__init__() 22 | self._define_directories( 23 | from_name='ocean_current_downloads', 24 | to_name='ocean_currents_aggregated' 25 | ) 26 | self._initialize_logging(args.save_log, 'aggregate_currents') 27 | 28 | def load(self): 29 | """ 30 | Load the downloaded ocean currents datasets from disk 31 | 32 | The ocean currents datasets are split by time, but this loads them in all at once (one dataset for each weather 33 | variable) 34 | 35 | :return: 36 | """ 37 | for var in config.currents_variables: 38 | self.datasets[var] = dd.read_csv(os.path.join(self.from_dir, f'{var}_*.csv')) 39 | 40 | 41 | def calculate_and_save(self): 42 | """ 43 | Perform the aggregation and save to disk 44 | 45 | The ocean currents are forecasted at every three hours. This reformats the 46 | dataset so that there is a row for each of these three hour timestamps, and the 47 | u/v observations are stored as columns. This performs the downsampling of the u/v vectors 48 | so that we only keep those observed at every 0.16, 0.24, or 0.32 degrees, in order to keep 49 | the size of the dataset manageable. It also filters out any coordinates that are over land, 50 | so that they do not need to be recorded/take up space. 51 | 52 | 53 | :return: 54 | """ 55 | if config.currents_window != 'stable': 56 | raise ValueError('Moving currents windows have been deprecated') 57 | else: 58 | for var in config.currents_variables: 59 | # Make sure that weather variables are filtered to correct region. 60 | self.datasets[var] = self.datasets[var][ 61 | (self.datasets[var]['latitude'] >= config.dataset_config.lat_1) 62 | & (self.datasets[var]['latitude'] <= config.dataset_config.lat_2) 63 | & (self.datasets[var]['longitude'] >= config.dataset_config.lon_1) 64 | & (self.datasets[var]['longitude'] <= config.dataset_config.lon_2) 65 | ] 66 | 67 | partition = self.datasets[var]._partitions(0).compute() 68 | 69 | # Even though the weather data was originally recorded at every 0.08 degrees, 70 | # we were unable to use data of this size on our machines, so instead created 71 | # a grid of approximate dimension 14 x 14 (no matter the size of the region). 
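# Illustrative arithmetic only (hypothetical region, not taken from any particular run): a box spanning 10 degrees of latitude at the native 0.08-degree spacing has ~126 unique latitude values, so the stride computed below comes out to round(126 / 13) = 10, keeping 13 latitudes roughly 0.8 degrees apart; the longitudes are thinned the same way.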
72 | approx_grid_dim = 14 73 | latitudes = partition['latitude'].unique() 74 | longitudes = partition['longitude'].unique() 75 | 76 | # Calculate the stride for latitude 77 | lat_stride = np.round(len(latitudes) /(approx_grid_dim -1)) 78 | latitude_idx = np.arange(0,len(latitudes), lat_stride, dtype=int) 79 | latitudes = latitudes[latitude_idx] 80 | 81 | # Calculate the stride for longitude 82 | lon_stride = np.round(len(longitudes) / (approx_grid_dim - 1)) 83 | longitude_idx = np.arange(0,len(longitudes), lon_stride, dtype=int) 84 | longitudes = longitudes[longitude_idx] 85 | 86 | # Select the lat/lon coordinates we're interested in 87 | self.datasets[var] = self.datasets[var][self.datasets[var]['latitude'].isin(latitudes) 88 | & self.datasets[var]['longitude'].isin(longitudes) 89 | & ~self.datasets[var]['speed'].isna()] 90 | 91 | self.datasets[var]['coord'] = ( 92 | self.datasets[var]['latitude'].round(2).astype(str) 93 | + '_' 94 | + self.datasets[var]['longitude'].round(2).astype(str) 95 | ) 96 | 97 | self.datasets[var] = self.datasets[var].compute() 98 | 99 | # Reshape dataset 100 | self.datasets[var] = self.datasets[var].pivot(index=['year', 'month', 'day', 'hour'], 101 | columns='coord', values='speed') 102 | 103 | # In my dataset there were a few timestamps where currents were reported over land, which this filters 104 | # out. This also filters out locations that were only sporadically reported to. 105 | likely_land_coords = self.datasets[var].columns[self.datasets[var].isna().mean() > 0.5] 106 | self.datasets[var] = self.datasets[var].drop(columns=likely_land_coords) 107 | self.datasets[var] = (self.datasets[var] * 1000).astype(np.int16) 108 | 109 | 110 | self.datasets[var].columns = self.datasets[var].columns + '_' + var 111 | 112 | # Join weather variables together 113 | self.datasets['complete'] = self.datasets[config.currents_variables[0]].copy() 114 | for i in range(len(config.currents_variables) - 1): 115 | new_var = config.currents_variables[i+1] 116 | self.datasets['complete'] = pd.merge(self.datasets['complete'], self.datasets[new_var], 117 | left_index=True,right_index=True) 118 | self.datasets['complete'].to_csv(os.path.join(self.to_dir, 'weather_aggregated.csv')) 119 | logging.info('Weather Aggregation complete') 120 | 121 | 122 | 123 | if __name__ == '__main__': 124 | parser = argparse.ArgumentParser() 125 | 126 | parser.add_argument('dataset_name', choices=datasets.keys()) 127 | # Tool for debugging 128 | parser.add_argument('-d', '--debug', action='store_true') 129 | parser.add_argument('-l', '--log_level', type=int, 130 | default=2, choices=[0, 1, 2, 3, 4], 131 | help='Level of logging to use') 132 | parser.add_argument('-s', '--save_log', action='store_true') 133 | 134 | args = parser.parse_args() 135 | config.dataset_config = datasets[args.dataset_name] 136 | config.set_log_level(args.log_level) 137 | 138 | 139 | dask.config.set(scheduler='single-threaded') 140 | 141 | aggregator = WeatherAggregator() 142 | aggregator.load() 143 | aggregator.calculate_and_save() 144 | -------------------------------------------------------------------------------- /processing/current_appender.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import dask 5 | 6 | import dask.dataframe as dd 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from config import config 11 | from config.dataset_config import datasets 12 | from processing_step import ProcessingStep 13 | from utils import 
clear_path 14 | 15 | class WeatherAppender(ProcessingStep): 16 | """ 17 | This class is used for joining the interpolated AIS messages to the weather dataset 18 | """ 19 | def __init__(self): 20 | super().__init__() 21 | self._define_directories( 22 | from_name='interpolated_with_destination' + ('_debug' if args.debug else ''), 23 | to_name='interpolated_with_currents_stride_3' + ('_debug' if args.debug else '') 24 | ) 25 | self.weather_dir = os.path.join(self.box_and_year_dir, 'ocean_currents_aggregated') 26 | self._initialize_logging(args.save_log, 'add_currents') 27 | 28 | self.unneeded_columns = [ 29 | 'mmsi', 30 | 'heading', 31 | ] 32 | 33 | logging.info(f'Not using columns {self.unneeded_columns}') 34 | 35 | def load(self): 36 | """ 37 | Load the test, train, and validation sets, as well as the weather data 38 | 39 | This function just specifies the test/train/valid paths for dask. (Dask uses lazy evaluation so the full sets aren't read in 40 | here.) The weather data is loaded with pandas. 41 | 42 | :return: 43 | """ 44 | for dataset_name in ['test', 'valid', 'train']: 45 | dataset_path = os.path.join(self.from_dir, f'{dataset_name}.parquet') 46 | self.datasets[dataset_name] = dd.read_parquet(dataset_path) 47 | for col in self.unneeded_columns: 48 | if col in self.datasets[dataset_name].columns: 49 | self.datasets[dataset_name] = self.datasets[dataset_name].drop(columns=col, axis=1) 50 | if args.debug: 51 | self.datasets[dataset_name] = self.datasets[dataset_name].partitions[:1] 52 | 53 | if config.currents_window != 'stable': 54 | raise ValueError('Moving currents windows have been deprecated') 55 | else: 56 | self.datasets['weather'] = pd.read_csv(os.path.join(self.weather_dir, 57 | 'weather_aggregated.csv')) 58 | self.datasets['weather'] = self.datasets['weather'].set_index(['year']) 59 | 60 | logging.info('File paths have been specified for dask') 61 | 62 | def _change_data_sizes(self, partition, features): 63 | """ 64 | Go through and change the sizes of columns. 65 | 66 | :param partition: Dataset to change sizes of 67 | :param features: Second dataset specifying column names and desired datatypes for dataset 68 | :return: 69 | """ 70 | for col, dtype in features.iteritems(): 71 | if str(partition[col].dtype) != dtype: 72 | partition[col] = partition[col].astype(dtype) 73 | return partition 74 | 75 | 76 | def save(self): 77 | """ 78 | Save the joined dataset to disk 79 | 80 | Because Dask uses lazy evaluation, the processing will actually happen only when this method is called. 81 | 82 | Also changes data types so as to save space. 
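For example, under the dtype rules below a float64 position column such as 'lat' is stored as float32 and int64 columns as int32, while the ocean current 'water_u'/'water_v' columns (already scaled by 1000 upstream) are stored as int16; only 'base_datetime' remains float64.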
83 | 84 | :return: 85 | """ 86 | 87 | for dataset_name in ['test','valid','train']: 88 | out_path = os.path.join(self.to_dir, f'{dataset_name}.parquet') 89 | clear_path(out_path) 90 | 91 | features = self.datasets[dataset_name].dtypes.astype(str) 92 | size = 32 93 | for col in features.index: 94 | original_type = features[col] 95 | if col == 'base_datetime': 96 | new_type = 'float64' 97 | elif 'water_u' in col or 'water_v' in col: 98 | new_type = 'int16' 99 | elif features[col] == 'float64': 100 | new_type = f'float{size}' 101 | elif features[col] == 'int64': 102 | new_type = f'int{size}' 103 | elif features[col] == 'Sparse[bool, False]': 104 | new_type = 'bool' 105 | else: 106 | new_type = original_type 107 | 108 | features[col] = new_type 109 | 110 | partition = self.datasets[dataset_name]._partitions(0).compute() 111 | output_meta = self._change_data_sizes(partition, features) 112 | self.datasets[dataset_name] = self.datasets[dataset_name].map_partitions(self._change_data_sizes, features, 113 | meta=output_meta) 114 | dd.to_parquet(self.datasets[dataset_name], out_path, schema='infer') 115 | dataset_len = len(dd.read_parquet(out_path)) 116 | logging.info( 117 | f'{dataset_name} dataset has {dataset_len:,} records after joining to weather data') 118 | 119 | 120 | logging.info(f'{dataset_name} set saved to {out_path}') 121 | 122 | 123 | def calculate(self): 124 | """ 125 | Join each AIS message to the most recent forecasted ocean currents 126 | 127 | :return: 128 | """ 129 | if config.currents_window != 'stable': 130 | raise ValueError('Moving currents windows have been deprecated') 131 | else: 132 | # Iterate through train/test/valid 133 | for dataset_name in ['train','test','valid']: 134 | dataset_len = len(self.datasets[dataset_name]) 135 | logging.info(f'Length of {dataset_name} set is {dataset_len:,} before merging with weather') 136 | 137 | # Get hour/day where message occurred (year and month are already calculated) 138 | # (hour is rounded down to every 3rd hour, as the ocean current forecasts only occur every 3 hours, on 139 | # the hour) 140 | self.datasets[dataset_name]['hour'] = dd.to_datetime(self.datasets[dataset_name]['base_datetime'], unit='s').dt.hour // 3 * 3 141 | self.datasets[dataset_name]['day'] = dd.to_datetime(self.datasets[dataset_name]['base_datetime'], unit='s').dt.day 142 | idx = self.datasets[dataset_name].index 143 | 144 | # Join to weather data 145 | self.datasets[dataset_name] = dd.merge(self.datasets[dataset_name], 146 | self.datasets['weather'], 147 | left_on =['year','month','day','hour'], 148 | right_on=['year','month','day','hour'], 149 | how='left') 150 | self.datasets[dataset_name].index = idx 151 | self.datasets[dataset_name] = self.datasets[dataset_name].drop(columns=['hour','day']) 152 | 153 | # If this is the training set, calculate the mean u/v values for each lat/lon location, so we can use 154 | # these values for imputation 155 | if dataset_name == 'train': 156 | means = self.datasets[dataset_name].mean(axis=0).round().astype(np.int16).compute() 157 | 158 | # Perform the mean imputation 159 | nas = self.datasets[dataset_name].isna().sum(axis=0).compute() 160 | na_cols = [] 161 | for col in nas.index: 162 | na_count = nas[col] 163 | if na_count != 0: 164 | na_cols.append(col) 165 | logging.info(f'{dataset_name} has {na_count:,} NA values in column {col} ' 166 | f'({na_count / dataset_len * 100:0.3}%). 
These values have been imputed with ' 167 | f'{means[col]}.') 168 | else: 169 | logging.info(f'{dataset_name} has {na_count:,} NA values in column {col} ' 170 | f'({na_count / dataset_len * 100:0.3}%).') 171 | 172 | means_dict = {c: means[c] for c in means.index if c in na_cols} 173 | 174 | # We only need a single "weather_is_imputed" col here, as imputation happens for all the columns at once 175 | self.datasets[dataset_name]['weather_is_imputed'] = self.datasets[dataset_name][na_cols].isna().any(axis=1) 176 | 177 | self.datasets[dataset_name] = self.datasets[dataset_name].fillna(means_dict) 178 | 179 | 180 | 181 | if len(self.datasets['train'].columns) != len(self.datasets['test'].columns): 182 | raise ValueError( 183 | 'There was an error in preprocessing and the train and test sets have differing numbers of' 184 | 'columns. This is likely due to the NA filling code in buoy_appender.py, which should ' 185 | 'be edited to account for your use case.') 186 | if len(self.datasets['train'].columns) != len(self.datasets['valid'].columns): 187 | raise ValueError( 188 | 'There was an error in preprocessing and the train and valid sets have differing numbers of' 189 | 'columns. This is likely due to the NA filling code in buoy_appender.py, which should ' 190 | 'be edited to account for your use case.') 191 | 192 | 193 | 194 | 195 | if len(self.datasets['train'].columns) != len(self.datasets['test'].columns): 196 | raise ValueError('There was an error in preprocessing and the train and test sets have differing numbers of' 197 | 'columns.') 198 | if len(self.datasets['train'].columns) != len(self.datasets['valid'].columns): 199 | raise ValueError('There was an error in preprocessing and the train and valid sets have differing numbers of' 200 | 'columns.') 201 | 202 | 203 | if __name__ == '__main__': 204 | parser = argparse.ArgumentParser() 205 | 206 | parser.add_argument('dataset_name', choices=datasets.keys()) 207 | 208 | # Logging and debugging 209 | parser.add_argument('-l', '--log_level', type=int, 210 | default=2, choices=[0, 1, 2, 3, 4], 211 | help='Level of logging to use') 212 | parser.add_argument('-s', '--save_log', action='store_true') 213 | parser.add_argument('--debug', action='store_true') 214 | 215 | args = parser.parse_args() 216 | 217 | config.dataset_config = datasets[args.dataset_name] 218 | config.set_log_level(args.log_level) 219 | 220 | if args.debug: 221 | dask.config.set(scheduler='single-threaded') 222 | else: 223 | dask.config.set(scheduler='single-threaded') 224 | 225 | appender = WeatherAppender() 226 | appender.load() 227 | appender.calculate() 228 | appender.save() 229 | -------------------------------------------------------------------------------- /processing/current_downloader.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import numpy as np 5 | 6 | import pandas as pd 7 | from pydap.client import open_url, open_dods 8 | 9 | from config import config 10 | from config.dataset_config import datasets 11 | from processing_step import ProcessingStep 12 | 13 | 14 | class Downloader(ProcessingStep): 15 | """ 16 | Class for downloading ocean current data from NOAA 17 | """ 18 | def __init__(self): 19 | super().__init__() 20 | self._define_directories( 21 | from_name=None, 22 | to_name='ocean_current_downloads' 23 | ) 24 | self._initialize_logging(args.save_log, 'ocean_current_download') 25 | 26 | 27 | def _get_hycom_region(self): 28 | """ 29 | Find which weather region to 
download data from 30 | 31 | The coordinates for the weather regions are given through the link below 32 | https://www.ncei.noaa.gov/products/weather-climate-models/frnmoc-navy-global-hybrid-ocean 33 | 34 | :return: 35 | """ 36 | ranges = { 37 | # RegionNum: [[Lat_min, lat_max], [lon_min, lon_max]] 38 | 1: [(0.0, 70.0), (-99.99996948242188, -50.0)], 39 | 6: [(10.0, 70.0), (-150.00001525878906, -210.0)], 40 | 7: [(10.0, 60.0), (-149.99996948242188,-100.0)], 41 | 17: [(60.0, 80.0), (-179.99996948242188, -120.0)] 42 | } 43 | 44 | for region_num, [(lat_min, lat_max), (lon_min, lon_max)] in ranges.items(): 45 | if config.dataset_config.lat_1 >= lat_min and config.dataset_config.lat_2 <= lat_max: 46 | if config.dataset_config.lon_1 >= lon_min and config.dataset_config.lon_2 <= lon_max: 47 | return region_num 48 | raise ValueError("Regional weather data not available for the lat/lon coordinates chosen. They may be " 49 | "available in HYCOM's global surface currents dataset, which uses a slightly different" 50 | "url format. See the link in the docstring to amend the code for that dataset. ") 51 | 52 | 53 | 54 | def _define_directories(self, from_name, to_name): 55 | """ 56 | Save file paths to directory as member variable 57 | 58 | Override of the ProcessingStep's _define_directories, as the downloader does not have a from_dir. 59 | 60 | :param from_name: Should always be None, but included to keep signature in line with ProcessingStep's method 61 | :param to_name: Should always be 'ocean_current_downloads', but included to keep signature in line with 62 | ProcessingStep's method 63 | :return: 64 | """ 65 | self.box_and_year_dir = os.path.join( 66 | config.data_directory, 67 | f'{config.dataset_config.lat_1}_{config.dataset_config.lat_2}_' 68 | f'{config.dataset_config.lon_1}_{config.dataset_config.lon_2}_' 69 | f'{config.start_year}_{config.end_year}' 70 | ) 71 | self.from_dir = from_name 72 | self.to_dir = os.path.join(self.box_and_year_dir, to_name) 73 | self.artifact_directory = os.path.join(self.box_and_year_dir, 'artifacts') 74 | 75 | self._create_directories() 76 | 77 | def _get_map_idxs(self, dataset, variable, map): 78 | if map == 'time': 79 | time_offset = pd.to_datetime(dataset['time'].attributes['units'].replace('hours since ', '')) 80 | min = pd.to_datetime(f'{config.start_year}-01-01').tz_localize(time_offset.tzname()) 81 | max = pd.to_datetime(f'{config.end_year + 1}-01-01').tz_localize(time_offset.tzname()) 82 | else: 83 | max = getattr(config.dataset_config, f'{map}_2') 84 | min = getattr(config.dataset_config, f'{map}_1') 85 | 86 | if map in ['lat', 'lon']: 87 | lat_lon_margin = 1 88 | max += lat_lon_margin 89 | min -= lat_lon_margin 90 | if hasattr(dataset[map], 'modulo'): 91 | if dataset[map].modulo == '360 degrees': 92 | max %= 360 93 | min %= 360 94 | else: 95 | raise ValueError(f'Unknown module: {dataset[variable][map].modulo} for dataset with id {dataset.id}') 96 | 97 | map_vals = np.array(dataset[map][:]) 98 | if map == 'time': 99 | map_vals = pd.to_datetime([time_offset + pd.Timedelta(hours=h) for h in map_vals]) 100 | 101 | idxs = np.where((map_vals <= max) & (map_vals >= min))[0] 102 | 103 | 104 | if len(idxs) == 0: 105 | raise ValueError('No surface current observations are in target range') 106 | 107 | if len(idxs) > 1: 108 | continuous = len(np.unique(idxs[1:] - idxs[:-1])) == 1 109 | if not continuous: 110 | raise ValueError(f'Slice for map {map} with dataset {dataset.id} is not continuous') 111 | 112 | min_idx = idxs.min() 113 | max_idx = idxs.max() 114 | 115 | 
map_vals = map_vals[(map_vals <= max) & (map_vals >= min)] 116 | return min_idx, max_idx, map_vals 117 | 118 | 119 | def download(self): 120 | """ 121 | Download relevant dataset from NOAA 122 | 123 | This accesses the aggregated NetCDF using OPENDAP. It only downloads the ocean current U/V values for the water 124 | surface (i.e. it doesn't download the currents below the surface). It downloads time chunks so as to not 125 | overload the THREDDS server, e.g. requesting the first two months, then the next two, and so on. 126 | 127 | :return: 128 | """ 129 | logging.info('Starting downloads') 130 | 131 | # Get the url to query 132 | region = self._get_hycom_region() 133 | aggregated_url = ('https://www.ncei.noaa.gov/' 134 | f'thredds-coastal/dodsC/hycom/hycom_reg{region}_agg/' 135 | f'HYCOM_Region_{region}_Aggregation_best.ncd') 136 | 137 | # Open connection to url 138 | sample_ds = open_url(aggregated_url) 139 | 140 | # Find the correct indexes that we want to filter down to 141 | time_min, time_max, time_vals = self._get_map_idxs(sample_ds, 'water_u', 'time') 142 | depth_min, depth_max, depth_vals = self._get_map_idxs(sample_ds, 'water_u', 'depth') 143 | lat_min, lat_max, lat_vals = self._get_map_idxs(sample_ds, 'water_u', 'lat') 144 | lon_min, lon_max, lon_vals = self._get_map_idxs(sample_ds, 'water_u', 'lon') 145 | 146 | # Only download this many time points at once. This number can be changed if there are timeout issues. 147 | time_points_to_download_at_once = 1000 148 | time_slices = np.arange(time_min, time_max, time_points_to_download_at_once) 149 | 150 | for var in config.currents_variables: 151 | for j, min_time in enumerate(time_slices): 152 | max_time = min(time_max, min_time + time_points_to_download_at_once - 1) 153 | 154 | # Add a filter to the url so we just download for the desired coordinates/time period/depth 155 | filtered_url = aggregated_url + ( 156 | f'.dods?{var}.{var}' 157 | f'[{min_time}:1:{max_time}]' 158 | f'[{depth_min}:1:{depth_max}]' 159 | f'[{lat_min}:1:{lat_max}]' 160 | f'[{lon_min}:1:{lon_max}]' 161 | ) 162 | # Open connection to filtered url 163 | data = open_dods(filtered_url) 164 | 165 | # Download data 166 | data = (np.array(data.data))[0,0,:,0] 167 | # I was getting an error about little endian/big endian mismatch that the byteswapping fixes 168 | data = [pd.DataFrame(data[i].byteswap().newbyteorder(), index=lat_vals, columns=lon_vals) for i in range(len(data))] 169 | 170 | times = time_vals[min_time - time_min: max_time-time_min + 1] 171 | # Reshape downloaded data 172 | for i in range(len(data)): 173 | data[i]['time'] = times[i] 174 | data[i].index.name = 'latitude' 175 | data[i] = data[i].reset_index() 176 | data[i] = pd.melt( 177 | data[i], 178 | id_vars=['latitude','time'], 179 | value_vars=lon_vals 180 | ) 181 | data[i] = data[i].rename(columns={'variable':'longitude','value':'speed'}) 182 | if (data[i]['longitude'] > 180).all(): 183 | data[i]['longitude'] -= 360 184 | data = pd.concat(data) 185 | for col in ['year','month','day','hour']: 186 | data[col] = getattr(data['time'].dt, col) 187 | del data['time'] 188 | 189 | # Save this downloaded dataset to csv 190 | data.to_csv(os.path.join(self.to_dir, f'{var}_{j}.csv'),index=False) 191 | logging.info(f'{var} data downloaded for times from {times[0].strftime("%Y-%m-%d %H:%M")} to ' 192 | f'{times[-1].strftime("%Y-%m-%d %H:%M")}') 193 | 194 | logging.info(f'All downloads complete for coordinates {config.dataset_config.corner_1},' 195 | f' {config.dataset_config.corner_2}') 196 | 197 | 198 | if 
__name__ == '__main__': 199 | parser = argparse.ArgumentParser() 200 | 201 | parser.add_argument('dataset_name', choices=datasets.keys()) 202 | # Tool for debugging 203 | parser.add_argument('-l', '--log_level', type=int, 204 | default=2, choices=[0, 1, 2, 3, 4], 205 | help='Level of logging to use') 206 | parser.add_argument('-s', '--save_log', action='store_true') 207 | 208 | args = parser.parse_args() 209 | 210 | config.dataset_config = datasets[args.dataset_name] 211 | config.set_log_level(args.log_level) 212 | 213 | downloader = Downloader() 214 | downloader.download() 215 | -------------------------------------------------------------------------------- /processing/interpolator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import dask 5 | 6 | import dask.dataframe as dd 7 | import numpy as np 8 | import pandas as pd 9 | from scipy.interpolate import interp1d 10 | 11 | from config import config 12 | from config.dataset_config import datasets 13 | from processing_step import ProcessingStep 14 | from utils import clear_path 15 | 16 | 17 | class Interpolator(ProcessingStep): 18 | """ 19 | Class for performing the interpolation preprocessing step, by linearly interpolating between messages so that 20 | they are at a regular interval. 21 | """ 22 | def __init__(self, method): 23 | super().__init__() 24 | self._define_directories( 25 | from_name='cleaned' + ('_debug' if args.debug else ''), 26 | to_name='interpolated' + ('_debug' if args.debug else '') 27 | ) 28 | self._initialize_logging(args.save_log, 'interpolate') 29 | 30 | self.method = method 31 | if self.method == '1d': 32 | self._interpolator = np.interp 33 | else: 34 | raise ValueError('Only 1d interpolation is currently supported') 35 | 36 | # Specify interpolation methods for different columns 37 | self.timestamp_column = ['base_datetime'] 38 | self.columns_to_interpolate = [ 39 | 'lat', 40 | 'lon', 41 | 'sog', 42 | 'cog', 43 | 'heading' 44 | ] 45 | self.columns_to_use_most_recent = [ 46 | # 'status', 47 | # 'draft', 48 | # 'cargo', 49 | # 'transceiver_class', 50 | ] 51 | self.stable_columns = [ 52 | 'mmsi', 53 | # 'vessel_type', 54 | # 'length', 55 | # 'width', 56 | 'vessel_group', 57 | ] 58 | self.columns_to_calculate = { 59 | 'year': 'int16', 60 | 'month': 'byte' 61 | } 62 | 63 | def load(self): 64 | """ 65 | Load the test, train, and validation sets 66 | 67 | This function just specifies the paths for dask. (Dask uses lazy evaluation so the full sets aren't read in 68 | here.) 69 | 70 | :return: 71 | """ 72 | for dataset_name in ['test', 'train', 'valid']: 73 | dataset_path = os.path.join(self.from_dir, f'{dataset_name}.parquet') 74 | self.datasets[dataset_name] = dd.read_parquet(dataset_path) 75 | logging.info(f'{dataset_name} set starting with {self.datasets[dataset_name].shape[0].compute():,} messages') 76 | 77 | logging.debug('File paths have been specified for dask') 78 | 79 | def save(self): 80 | """ 81 | Save the interpolated datasets to disk 82 | 83 | Because Dask uses lazy evaluation, the processing will actually happen only when this method is called. 
84 | 85 | :return: 86 | """ 87 | clear_path(self.to_dir) 88 | os.mkdir(self.to_dir) 89 | 90 | for dataset_name in ['test', 'train', 'valid']: 91 | out_path = os.path.join(self.to_dir, f'{dataset_name}.parquet') 92 | self.current_file = out_path 93 | dd.to_parquet(self.datasets[dataset_name + '_interpolated'], out_path, schema='infer') 94 | logging.info(f'{dataset_name} contains {self.datasets[dataset_name+ "_interpolated"].shape[0].compute():,} messages after interpolation') 95 | logging.debug(f'Interpolation complete for {dataset_name} set and dataset saved to {out_path}') 96 | self.current_file = None 97 | 98 | def interpolate(self): 99 | """ 100 | Interpolate each of the datasets 101 | 102 | Also converts the timestamps to seconds 103 | 104 | Does not actually do the interpolation, as Dask uses lazy evaluation 105 | 106 | :return: 107 | """ 108 | for dataset_name in ['test', 'train', 'valid']: 109 | self.datasets[dataset_name]['base_datetime'] = self.datasets[dataset_name]['base_datetime'].astype( 110 | int) / 10 ** 9 111 | 112 | out_meta = self.datasets[dataset_name].dtypes.append(pd.Series(self.columns_to_calculate.values(), 113 | index=self.columns_to_calculate.keys())) 114 | out_meta = [(i, z) for i, z in out_meta.items()] 115 | 116 | self.datasets[dataset_name + '_interpolated'] = self.datasets[dataset_name].map_partitions( 117 | self.interpolate_partition, 118 | meta=out_meta) 119 | 120 | def interpolate_partition(self, partition: pd.DataFrame): 121 | """ 122 | Interpolate the messages in a single partition. 123 | 124 | Dask works by splitting up a DataFrame into multiple partitions, then spreading the partitions across multiple 125 | processes (or threads, if you configure it that way). The map_partitions dask method can be used to 126 | have each processor do a transformation of its partition. This is the transformation that we are applying. This 127 | method uses the pandas groupby().apply() method to interpolate each track using the interpolate_track method 128 | below. 129 | 130 | :param partition: Partition to interpolate 131 | :return: Interpolated partition 132 | """ 133 | interpolated = partition.groupby('track').apply(self.interpolate_track) 134 | interpolated = interpolated.reset_index() 135 | interpolated = interpolated.drop('level_1', axis=1) 136 | interpolated = interpolated.set_index('track') 137 | return interpolated 138 | 139 | def interpolate_track(self, track: pd.DataFrame): 140 | """ 141 | Interpolate the messages in a single track 142 | 143 | This is a function that is applied to each track, using Pandas' groupby().apply() method. The input is a 144 | Pandas DataFrame that has all the messages for a single track, and the output is a Pandas DataFrame for 145 | this track that has been interpolated. 146 | 147 | It keeps the first message in the trajectory, then resamples at every config.interpolation_time_gap seconds. 148 | 149 | :param track: Track to interpolate 150 | :return: Interpolated track 151 | """ 152 | 153 | # Find the times that we want to sample at 154 | first_ts = track['base_datetime'].iloc[0] 155 | last_ts = track['base_datetime'].iloc[-1] 156 | times_to_sample = np.arange(first_ts, last_ts + 1, config.interpolation_time_gap) 157 | 158 | # Because categorical variables can't be interpolated linearly, we are instead taking the value from the most 159 | # recent timestamp. 
The categorical_interpolator object just finds the index of the most recent timestamp (in 160 | # the true data) for each time in times_to_sample, so that the categorical variables for this index can be found 161 | # more quickly 162 | categorical_interpolator = interp1d(track['base_datetime'], 163 | range(len(track['base_datetime'])), 164 | kind='previous', assume_sorted=True) 165 | most_recent_idx = categorical_interpolator(times_to_sample) 166 | 167 | interpolated = {} 168 | # iterate through columns and interpolate each one 169 | for col in track.columns: 170 | if col == 'base_datetime': 171 | interpolated[col] = times_to_sample 172 | elif col in self.columns_to_interpolate: 173 | # If this should be interpolated, do so 174 | interpolated[col] = self._interpolator(times_to_sample, track['base_datetime'], track[col]) 175 | elif col in self.columns_to_use_most_recent: 176 | # If this column is categorical and can change, then use the most recent value 177 | interpolated[col] = [track[col].iloc[int(i)] for i in most_recent_idx] 178 | elif col in self.stable_columns: 179 | # If this column is categorical but should be stable over the whole dataset, just use the first value 180 | interpolated[col] = track[col].iloc[0] 181 | else: 182 | raise ValueError(f'Please specify how to handle column {col}') 183 | 184 | # Add year and month variables 185 | for col in self.columns_to_calculate.keys(): 186 | if col == 'year': 187 | interpolated[col] = pd.to_datetime(interpolated['base_datetime'] * 10 ** 9).year 188 | elif col == 'month': 189 | interpolated[col] = pd.to_datetime(interpolated['base_datetime'] * 10 ** 9).month 190 | else: 191 | raise ValueError(f'Please specify how to interpolate column {col}') 192 | 193 | interpolated = pd.DataFrame(interpolated) 194 | return interpolated 195 | 196 | 197 | if __name__ == '__main__': 198 | # Because interpolation is at five minutes, and most of the original timestamps are *more frequent* than five 199 | # minutes, this is really more of a downsampling than an interpolation 200 | parser = argparse.ArgumentParser() 201 | 202 | parser.add_argument('dataset_name', choices=datasets.keys()) 203 | # Interpolation Method 204 | parser.add_argument('--method', type=str, default='1d', choices=['1d'], 205 | help='Interpolation method') 206 | 207 | # Logging and debugging 208 | parser.add_argument('-l', '--log_level', type=int, 209 | default=2, choices=[0, 1, 2, 3, 4], 210 | help='Level of logging to use') 211 | parser.add_argument('-s', '--save_log', action='store_true') 212 | parser.add_argument('--debug', action='store_true') 213 | 214 | args = parser.parse_args() 215 | config.set_log_level(args.log_level) 216 | 217 | config.dataset_config = datasets[args.dataset_name] 218 | 219 | if args.debug: 220 | dask.config.set(scheduler='single-threaded') 221 | else: 222 | dask.config.set(scheduler='processes') 223 | 224 | 225 | interpolator = Interpolator(args.method) 226 | interpolator.load() 227 | interpolator.interpolate() 228 | interpolator.save() 229 | -------------------------------------------------------------------------------- /processing/process.sh: -------------------------------------------------------------------------------- 1 | ## RUN AFTER dataset_config.py HAS BEEN UPDATED: 2 | python downloader.py california_coast -l 4 -s && 3 | python cleaner.py california_coast -l 4 -s --seed 47033218 --memory conserve && 4 | python interpolator.py california_coast -l 4 -s && 5 | python current_downloader.py california_coast -l 4 -s && 6 | python current_aggregator.py 
california_coast -l 4 -s && 7 | python destination_appender.py california_coast -l 4 -s 8 | 9 | ## RUN ONLY AFTER SETTING DBSCAN PARAMETER VALUES 10 | #python destination_appender.py california_coast -l 4 -s && 11 | #python current_appender.py california_coast -l 4 -s && 12 | #python sliding_window.py california_coast -l 4 -s && 13 | #python formatter.py california_coast -l 4 -s -------------------------------------------------------------------------------- /processing/processing_step.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import signal 4 | 5 | from abc import ABC 6 | 7 | from config import config 8 | from utils import get_zones_from_coordinates 9 | 10 | 11 | class ProcessingStep(ABC): 12 | """ 13 | Base class for members in the processing directory. Includes member functions for saving data passed in during 14 | initialization, defining/creating the necessary directories where data will be located, and initializing logging. 15 | """ 16 | def __init__(self): 17 | self.lat_min = config.dataset_config.lat_1 18 | self.lat_max = config.dataset_config.lat_2 19 | self.lon_min = config.dataset_config.lon_1 20 | self.lon_max = config.dataset_config.lon_2 21 | self.current_file = None 22 | self.zones = get_zones_from_coordinates(config.dataset_config.corner_1, config.dataset_config.corner_2) 23 | self.datasets = {} 24 | 25 | def _define_directories(self, from_name, to_name): 26 | """ 27 | Save file paths for from/to directories as member variables. 28 | 29 | Most processing steps take a dataset saved in one location, process them, then save them to another 30 | location. This function simply saves the paths for the from directory and to directory to the object. (The 31 | downloaders do not have from directories, as they source the data from NOAA or MarineCadastre.gov.) 32 | 33 | :param from_name: Name of directory that data is coming from 34 | :param to_name: Name of directory that data is going to 35 | :return: 36 | """ 37 | self.box_and_year_dir = os.path.join( 38 | config.data_directory, 39 | f'{config.dataset_config.lat_1}_{config.dataset_config.lat_2}_' 40 | f'{config.dataset_config.lon_1}_{config.dataset_config.lon_2}_' 41 | f'{config.start_year}_{config.end_year}' 42 | ) 43 | self.from_dir = os.path.join(self.box_and_year_dir, from_name) 44 | self.to_dir = os.path.join(self.box_and_year_dir, to_name) 45 | self.artifact_directory = os.path.join(self.box_and_year_dir, 'artifacts') 46 | 47 | self._create_directories() 48 | 49 | def _create_directories(self): 50 | """ 51 | Create directory to move data into, if it doesn't already exist. 52 | 53 | Must call _define_directories before _create_directories. 54 | 55 | :return: 56 | """ 57 | if not os.path.exists(config.data_directory): 58 | os.mkdir(config.data_directory) 59 | if not os.path.exists(self.box_and_year_dir): 60 | os.mkdir(self.box_and_year_dir) 61 | if not os.path.exists(self.to_dir): 62 | os.mkdir(self.to_dir) 63 | if hasattr(self, 'artifact_directory'): 64 | if not os.path.exists(self.artifact_directory): 65 | os.mkdir(self.artifact_directory) 66 | 67 | def _initialize_logging(self, save_log=False, log_file_name=None): 68 | """ 69 | Kick off the logging process 70 | 71 | This adds a logging handler, which will write logging messages to stdout. (If save_log=True, it will also 72 | write them to disk.) Once this has been run, any other module can import logging, then write messages using 73 | functions like logging.info, logging.warning, etc. 
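(For example, a later processing step can simply call logging.info('Interpolation complete') and the message will be routed through the handlers configured here.)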
74 | 75 | If the save_log option is specified, this will to save logs to 76 | f'{self.box_and_year_directory}/logs/{log_file_name}.log' 77 | 78 | :param save_log: Whether or not the log should be saved to disk 79 | :param log_file_name: The name of the log file, if the log should be saved to disk 80 | :return: 81 | """ 82 | logging_directory = os.path.join( 83 | self.box_and_year_dir, 84 | 'logs' 85 | ) 86 | if not os.path.exists(logging_directory): 87 | os.mkdir(logging_directory) 88 | 89 | # remove any other handlers that have been added by imported python libraries 90 | for handler in logging.root.handlers: 91 | logging.root.removeHandler(handler) 92 | 93 | # Reset format 94 | format = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 95 | c_handler = logging.StreamHandler() 96 | c_handler.setLevel(config.log_level) 97 | c_handler.setFormatter(format) 98 | 99 | if save_log: 100 | log_file = os.path.join(logging_directory, f'{log_file_name}.log') 101 | f_handler = logging.FileHandler(log_file) 102 | f_handler.setLevel(logging.DEBUG) 103 | f_handler.setFormatter(format) 104 | logging.basicConfig(handlers=[c_handler, f_handler], datefmt='%m/%d/%Y %I:%M:%S', level=logging.DEBUG) 105 | logging.info(f'Logs being saved to {log_file}') 106 | logging.info('New run beginning') 107 | else: 108 | logging.basicConfig(handlers=[c_handler], datefmt='%m/%d/%Y %I:%M:%S', level=logging.DEBUG) 109 | 110 | self._add_exit_handling() 111 | 112 | def _add_exit_handling(self): 113 | """ 114 | Initialize exit handling. 115 | 116 | If signal is interrupted and run does not complete, then write this to log. Also delete a file that was in the 117 | middle of being unzipped, if there was one. 118 | """ 119 | 120 | def log_sigint(a, b): 121 | # this gets fed two inputs when called, neither of which are needed 122 | if self.current_file is not None: 123 | if os.path.exists(self.current_file): 124 | os.remove(self.current_file) 125 | logging.error(f'Process ended prematurely by signal interruption. File {self.current_file} ' 126 | 'was being processed when interruption occurred and an attempt to remove ' 127 | 'the incomplete file was made. You may still need to remove it manually.') 128 | else: 129 | logging.error(f'Process ended prematurely by signal interruption.') 130 | 131 | signal.signal(signal.SIGINT, log_sigint) 132 | 133 | -------------------------------------------------------------------------------- /processing/sliding_window.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import dask 5 | 6 | import dask.dataframe as dd 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from config import config 11 | from config.dataset_config import datasets 12 | from processing_step import ProcessingStep 13 | from utils import clear_path 14 | 15 | class SlidingWindow(ProcessingStep): 16 | """ 17 | Class for performing the sliding window processing step, where long tracks are split up into multiple 18 | shorter tracks that can be used for fitting the model. 
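As an illustration only (the real values come from config): with length_of_history = 12 interpolated steps and length_into_the_future = 48 steps, each window produced below holds 12 + 48 + 1 = 61 consecutive messages; the train/valid windows are allowed to overlap, while the test windows advance by a full length_of_history so that their input portions do not.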
19 | """ 20 | def __init__(self): 21 | super().__init__() 22 | self._define_directories( 23 | from_name='interpolated_with_currents_stride_3' + ('_debug' if args.debug else ''), 24 | to_name='windowed_with_currents_stride_3' + ('_debug' if args.debug else '') 25 | ) 26 | self._initialize_logging(args.save_log, 'sliding_window_with_weather') 27 | 28 | 29 | def load(self): 30 | """ 31 | Load the test, train, and validation sets 32 | 33 | This function just specifies the paths for dask. (Dask uses lazy evaluation so the full sets aren't read in 34 | here.) 35 | 36 | :return: 37 | """ 38 | for dataset_name in ['test', 'valid', 'train']: 39 | dataset_path = os.path.join(self.from_dir, f'{dataset_name}.parquet') 40 | self.datasets[dataset_name] = dd.read_parquet(dataset_path) 41 | 42 | if self.datasets[dataset_name].index.name is None: 43 | def rename_index(partition): 44 | partition.index.name = 'track' 45 | return partition 46 | self.datasets[dataset_name] = self.datasets[dataset_name].map_partitions(rename_index) 47 | logging.info('File paths have been specified for dask') 48 | 49 | def save(self): 50 | """ 51 | Save the windowed datasets to disk 52 | 53 | Because Dask uses lazy evaluation, the processing will actually happen only when this method is called. 54 | 55 | :return: 56 | """ 57 | 58 | for dataset_name in self.datasets.keys(): 59 | self.datasets[dataset_name] = self.datasets[dataset_name] 60 | out_path = os.path.join(self.to_dir, f'{dataset_name}.parquet') 61 | clear_path(out_path) 62 | logging.info( 63 | f'Number of messages in {dataset_name} set is {len(self.datasets[dataset_name]):,}') 64 | dd.to_parquet(self.datasets[dataset_name], out_path, schema='infer') 65 | logging.info(f'{dataset_name} set saved to {out_path}') 66 | 67 | def calculate(self): 68 | """ 69 | Iterate through and perform the sliding window calculations. 70 | 71 | This creates two types of windowed datasets, which differ in the size of their sliding window movements. The 72 | long_term_train datasets are used for model training and validation - they use a shorter sliding window 73 | movement so that there will be a significant overlap between successive sequences in these datasets. The 74 | long_term_test datasets are used for performance evaluation, and for this reason their sliding window movement 75 | is set so that the input portions of successive sequences do not overlap at all. 
76 | 77 | :return: 78 | """ 79 | # Only the train/validation sets need long_term_train versions 80 | for dataset_name in ['train', 'valid']: 81 | out_meta = self.datasets[dataset_name].dtypes 82 | out_meta = [(i, z) for i, z in out_meta.items()] 83 | self.datasets[dataset_name + '_long_term_train'] = self.datasets[dataset_name].map_partitions( 84 | self.window_partition, 85 | self.window_track_long_term_train, 86 | meta=out_meta) 87 | # Only the test/validation sets need long_term_test versions 88 | for dataset_name in ['test', 'valid']: 89 | out_meta = self.datasets[dataset_name].dtypes 90 | out_meta = [(i, z) for i, z in out_meta.items()] 91 | self.datasets[dataset_name + '_long_term_test'] = self.datasets[dataset_name].map_partitions( 92 | self.window_partition, 93 | self.window_track_long_term_test, 94 | meta=out_meta) 95 | del self.datasets['test'], self.datasets['train'], self.datasets['valid'] 96 | logging.info(f'Calculation methods have been defined for Dask') 97 | 98 | def window_partition(self, partition: pd.DataFrame, track_fn) -> pd.DataFrame: 99 | """ 100 | Perform sliding window calculations on the messages in a single partition 101 | 102 | Dask works by splitting up a DataFrame into multiple partitions, then spreading the partitions across multiple 103 | processes (or threads, if you were to configure it that way). The map_partitions dask method can be used to 104 | have each processor do a transformation of its partition. This is the transformation that we are applying. This 105 | method uses the pandas groupby().apply() method to window each track using the track_fn function that is 106 | specified. 107 | 108 | :param partition: Partition to window 109 | :param track_fn: Function for windowing 110 | :return: Windowed partition 111 | """ 112 | partition = partition.groupby('track').apply(track_fn).reset_index() 113 | partition = partition.drop('level_1', axis=1) 114 | partition = partition.set_index('track') 115 | return partition 116 | 117 | def window_track_long_term_test(self, track: pd.DataFrame) -> pd.DataFrame: 118 | """ 119 | Create the long term dataset for a single track, for testing purposes 120 | 121 | Windows a long trajectory into multiple shorter ones. Even though we're still breaking up the trajectories using 122 | a sliding window, the input portions of the new trajectories should not overlap. e.g. if there's 6 hours of 123 | data, the first set will use (hour 1) to predict (hours 2, 3, 4, 5), and the second set will use (hour 2) to 124 | predict (hours 3, 4, 5, 6). This is different than 'window_track_long_term_train', as the training data is 125 | allowed to overlap 126 | 127 | :param track: Track to window 128 | :return: Windowed track 129 | """ 130 | number_of_subtracks = len(track) - (config.length_of_history + config.length_into_the_future) 131 | subtrack_idxs = [np.arange(i, i + config.length_of_history + config.length_into_the_future + 1) for i in 132 | range(0, number_of_subtracks, config.length_of_history)] 133 | windowed_track = [track.iloc[idxs] for idxs in subtrack_idxs] 134 | windowed_track = pd.concat(windowed_track).reset_index(drop=True) 135 | return windowed_track 136 | 137 | def window_track_long_term_train(self, track: pd.DataFrame) -> pd.DataFrame: 138 | """ 139 | Create the long term dataset for a single track, for training purposes 140 | 141 | Windows a long trajectory into multiple shorter ones. The difference between this and 142 | window_track_long_term_test, is the sliding window size. 
The 'test' version makes sure that input sequences do 143 | not overlap, whereas this version does include overlaps. 144 | 145 | :param track: Track to window 146 | :return: Windowed track 147 | """ 148 | number_of_gaps_into_the_future = config.length_into_the_future 149 | number_of_subtracks = len(track) - (config.length_of_history + number_of_gaps_into_the_future) 150 | window_movement_in_ts = int(config.dataset_config.sliding_window_movement / config.interpolation_time_gap) 151 | subtrack_idxs = [np.arange(i, i + config.length_of_history + number_of_gaps_into_the_future + 1) for i in 152 | range(0, number_of_subtracks, window_movement_in_ts)] 153 | windowed_track = [track.iloc[idxs] for idxs in subtrack_idxs] 154 | windowed_track = pd.concat(windowed_track).reset_index(drop=True) 155 | return windowed_track 156 | 157 | 158 | 159 | if __name__ == '__main__': 160 | parser = argparse.ArgumentParser() 161 | 162 | parser.add_argument('dataset_name', choices=datasets.keys()) 163 | # Logging and debugging 164 | parser.add_argument('-l', '--log_level', type=int, 165 | default=2, choices=[0, 1, 2, 3, 4], 166 | help='Level of logging to use') 167 | parser.add_argument('-s', '--save_log', action='store_true') 168 | parser.add_argument('--debug', action='store_true') 169 | 170 | args = parser.parse_args() 171 | config.dataset_config = datasets[args.dataset_name] 172 | config.set_log_level(args.log_level) 173 | 174 | if args.debug: 175 | dask.config.set(scheduler='single-threaded') 176 | else: 177 | dask.config.set(scheduler='single-threaded') 178 | 179 | window = SlidingWindow() 180 | window.load() 181 | window.calculate() 182 | window.save() 183 | -------------------------------------------------------------------------------- /processing/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | import urllib.parse 5 | 6 | import pandas as pd 7 | import utm 8 | 9 | from calendar import monthrange 10 | from dateutil import rrule 11 | from datetime import datetime 12 | 13 | def get_zones_from_coordinates(corner_1, corner_2): 14 | """ 15 | Get UTM zones to download, based on lat/lon coordinates 16 | 17 | :param corner_1: Lat/lon pair 18 | :param corner_2: Lat/lon pair 19 | :return: range of zones to download 20 | """ 21 | _, _, zone_1, _ = utm.from_latlon(*corner_1) 22 | _, _, zone_2, _ = utm.from_latlon(*corner_2) 23 | if zone_1 > 19: 24 | raise ValueError(f'Corner 1 {corner_1} is outside data available on MarineCadastre.gov') 25 | if zone_2 > 19: 26 | raise ValueError(f'Corner 2 {corner_2} is outside data available on MarineCadastre.gov') 27 | zones_to_download = range(min(zone_1, zone_2), max(zone_1, zone_2) + 1) 28 | return zones_to_download 29 | 30 | 31 | def get_file_specifier(year, month, zone_or_day, extension): 32 | """ 33 | Get the specific file name for this year, month, and zone or day 34 | 35 | Files from 2017 and prior are split by utm zone, while files from 2018 and on are split by day. 36 | 37 | This unfortunately means that all 2018 AIS messages need to be downloaded.
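For example (with hypothetical month/zone/day values), the format strings below produce names like 'AIS_2017_01_Zone10.zip' for 2015-2017 and 'AIS_2018_01_02.zip' for 2018 onward.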
38 | 39 | :param year: 40 | :param month: 41 | :param zone_or_day: 42 | :param extension: 43 | :return: 44 | """ 45 | if year in (2015, 2016, 2017): 46 | specifier = f'AIS_{year}_{month:02d}_Zone{zone_or_day:02d}.{extension}' 47 | elif year in (2018, 2019, 2020, 2021): 48 | specifier = f'AIS_{year}_{month:02d}_{zone_or_day:02d}.{extension}' 49 | else: 50 | raise ValueError(f"I'm not sure how to format the specifier for year {year}; " 51 | f"Check https://coast.noaa.gov/htdata/CMSP/AISDataHandler/{year}/index.html to edit " 52 | f"me to do so, and make sure to also edit get_info_from_specifier()") 53 | 54 | specifier = urllib.parse.urljoin(f'{year}/', specifier) 55 | return specifier 56 | 57 | 58 | def get_info_from_specifier(file_name): 59 | """ 60 | Split the file specifier into its constituent info 61 | 62 | The file specifier contains the year, month, zone/day, and file extension for the file in question. This splits up 63 | a file specifier into these parts. Whether the third piece of information is the zone or day depends on what year 64 | the file is from (2015-2017 will contain the zone, while 2018+ will contain the day, as this is how the files are 65 | organized on MarineCadastre.gov). 66 | 67 | :return: year, month, zone or day, extension 68 | """ 69 | split = re.search('[0-9]{4}.+AIS_([0-9]{4})_([0-9]{2})_(Zone)?([0-9]{2}|\*)\.(.+)', file_name) 70 | if split: 71 | year = split.group(1) 72 | month = split.group(2) 73 | zone_or_day = split.group(4) 74 | extension = split.group(5) 75 | else: 76 | raise ValueError('This file does not have a known specifier format; the year, month, zone/day, and extension ' 77 | 'cannot be found') 78 | 79 | return year, month, zone_or_day, extension 80 | 81 | 82 | def all_specifiers(zones, years, extension, dir=None): 83 | """ 84 | Get all file specifiers for the relevant zones and years 85 | 86 | A specifier is a string formatted something like '2017/AIS_2017_01_Zone01.zip' 87 | 88 | :param zones: The UTM zones to look at 89 | :param years: The years to consider 90 | :param extension: The file extension to use for the specifier 91 | :param dir: The directory, if one is desired at the start of the specifier 92 | :return: Dict containing the specifiers (and the full paths, if dir was given) 93 | """ 94 | specifiers = [] 95 | if dir is not None: 96 | paths = [] 97 | for year in years: 98 | if year in (2015, 2016, 2017): 99 | for month in range(1, 13): 100 | for zone in zones: 101 | specifier = get_file_specifier(year, month, zone, extension) 102 | specifiers.append(specifier) 103 | 104 | if dir is not None: 105 | path = os.path.join(dir, specifier) 106 | paths.append(path) 107 | elif year in (2018, 2019, 2020, 2021): 108 | for dt in rrule.rrule(rrule.DAILY, 109 | dtstart=datetime.strptime(f'{year}-01-01', '%Y-%m-%d'), 110 | until=datetime.strptime(f'{year}-12-31', '%Y-%m-%d')): 111 | specifier = get_file_specifier(dt.year, dt.month, dt.day, extension) 112 | specifiers.append(specifier) 113 | 114 | if dir is not None: 115 | path = os.path.join(dir, specifier) 116 | paths.append(path) 117 | 118 | if dir is not None: 119 | all_zym = {'paths': paths, 'specifiers': specifiers} 120 | else: 121 | all_zym = {'specifiers': specifiers} 122 | 123 | return all_zym 124 | 125 | 126 | def pd_append(values): 127 | """ 128 | Append values together into a pandas series 129 | 130 | values should be a list of different things to append together, e.g. the first item might be an integer, the second 131 | a pd.Series of integers.
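For example (illustrative call): pd_append([1, pd.Series([2, 3])]) returns a pd.Series containing 1, 2 and 3 with a fresh 0-based index.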
132 | 133 | :param values: A list of different things to append together 134 | :return: 135 | """ 136 | v1 = values[0] 137 | if len(values) > 2: 138 | if type(v1) == pd.Series: 139 | series = pd.concat([ 140 | v1, 141 | pd_append(values[1:]) 142 | ]).reset_index(drop=True) 143 | else: 144 | series = pd.concat([ 145 | pd.Series([v1]), 146 | pd_append(values[1:]) 147 | ]).reset_index(drop=True) 148 | elif len(values) == 2: 149 | v2 = values[1] 150 | if type(v1) == pd.Series: 151 | series = pd.concat([ 152 | v1, 153 | pd.Series([v2]) 154 | ]).reset_index(drop=True) 155 | elif type(v2) == pd.Series: 156 | series = pd.concat([ 157 | pd.Series([v1]), 158 | v2 159 | ]).reset_index(drop=True) 160 | else: 161 | series = pd.concat([ 162 | pd.Series([v1]), 163 | pd.Series([v2]) 164 | ]).reset_index(drop=True) 165 | return series 166 | 167 | 168 | def to_snake_case(name): 169 | """ 170 | Convert a string to snake case 171 | 172 | :param name: Name to convert 173 | :return: Converted name 174 | """ 175 | name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) 176 | name = re.sub('__([A-Z])', r'_\1', name) 177 | name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name) 178 | return name.lower() 179 | 180 | 181 | def clear_path(path): 182 | """ 183 | Delete any files or directories from a path 184 | 185 | :param path: Path to remove 186 | :return: 187 | """ 188 | if os.path.exists(path): 189 | if os.path.isfile(path): 190 | os.remove(path) 191 | else: 192 | shutil.rmtree(path) 193 | 194 | def get_min_max_times(specifier): 195 | """ 196 | Get the first/last possible time for AIS messages contained in a file 197 | 198 | :param specifier: file information 199 | :return: 200 | """ 201 | year, month, zone_or_day, extension = get_info_from_specifier(specifier) 202 | year = int(year) 203 | month = int(month) 204 | 205 | if year in (2015, 2016, 2017): 206 | min_time = pd.to_datetime(f'{year}-{month}-01 00:00:00') 207 | _, last_day = monthrange(year, month) 208 | max_time = pd.to_datetime(f'{year}-{month}-{last_day} 23:59:59') 209 | 210 | 211 | elif year in (2018, 2019, 2020, 2021): 212 | day = zone_or_day 213 | min_time = pd.to_datetime(f'{year}-{month}-{day} 00:00:00') 214 | max_time = pd.to_datetime(f'{year}-{month}-{day} 23:59:59') 215 | 216 | else: 217 | raise ValueError('Year unaccounted for') 218 | 219 | return min_time, max_time 220 | -------------------------------------------------------------------------------- /processing_environment.yml: -------------------------------------------------------------------------------- 1 | name: ships_processing 2 | channels: 3 | - conda-forge 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - _tflow_select=2.1.0=gpu 10 | - absl-py=0.7.1=py37_0 11 | - alembic=1.3.1=py_0 12 | - altair=3.1.0=py37_0 13 | - appdirs=1.4.4=pyhd3eb1b0_0 14 | - asn1crypto=0.24.0=py37_1003 15 | - astor=0.7.1=py_0 16 | - attrs=19.1.0=py37_1 17 | - backcall=0.1.0=py37_0 18 | - basemap=1.2.0=py37h705c2d8_0 19 | - beautifulsoup4=4.11.1=pyha770c72_0 20 | - blas=2.14=openblas 21 | - bleach=3.1.0=py37_0 22 | - bokeh=2.1.1=py37_0 23 | - branca=0.3.1=py_0 24 | - bzip2=1.0.6=h14c3975_1002 25 | - c-ares=1.15.0=h516909a_1001 26 | - ca-certificates=2021.10.8=ha878542_0 27 | - cairo=1.14.12=h8948797_3 28 | - certifi=2021.10.8=py37h89c1867_2 29 | - cffi=1.12.3=py37h8022711_0 30 | - chardet=3.0.4=py37_1003 31 | - click=8.0.3=pyhd3eb1b0_0 32 | - cloudpickle=2.0.0=pyhd3eb1b0_0 33 | - colorama=0.4.4=pyh9f0ad1d_0 34 | - configparser=5.0.2=pyhd3eb1b0_0 35 | - 
cryptography=2.7=py37h72c5cf5_0 36 | - cudatoolkit=10.1.168=0 37 | - cudnn=7.6.0=cuda10.1_0 38 | - cupti=10.1.168=0 39 | - cycler=0.10.0=py_1 40 | - cytoolz=0.9.0.1=py37h14c3975_1 41 | - dask=2.30.0=py_0 42 | - dask-core=2.30.0=py_0 43 | - dask-glm=0.2.0=py37_0 44 | - dask-ml=1.9.0=pyhd3eb1b0_0 45 | - databricks-cli=0.12.1=pyhd8ed1ab_0 46 | - dbus=1.13.12=h746ee38_0 47 | - decorator=4.4.0=py37_1 48 | - defusedxml=0.6.0=py_0 49 | - distributed=2.30.1=py37h06a4308_0 50 | - docker-py=4.4.1=py37h06a4308_5 51 | - docker-pycreds=0.4.0=pyhd3eb1b0_0 52 | - docopt=0.6.2=py_1 53 | - entrypoints=0.3=py37_0 54 | - expat=2.2.6=he6710b0_0 55 | - ffmpeg=4.1.3=h167e202_0 56 | - flask=1.1.2=pyhd3eb1b0_0 57 | - folium=0.9.1=py_0 58 | - fontconfig=2.13.0=h9420a91_0 59 | - freetype=2.9.1=h8a8886c_1 60 | - fribidi=1.0.10=h7b6447c_0 61 | - fsspec=2021.10.1=pyhd3eb1b0_0 62 | - gast=0.2.2=py_0 63 | - geos=3.6.2=heeff764_2 64 | - gettext=0.19.8.1=hc5be6a0_1002 65 | - giflib=5.1.9=h516909a_0 66 | - gitdb=4.0.7=pyhd3eb1b0_0 67 | - gitpython=3.1.18=pyhd3eb1b0_1 68 | - glib=2.56.2=hd408876_0 69 | - gmp=6.1.2=hb3b607b_0 70 | - gnutls=3.6.5=hd3a4fd2_1002 71 | - google-pasta=0.1.7=py_0 72 | - graphite2=1.3.13=hf484d3e_1000 73 | - graphviz=2.40.1=h21bd128_2 74 | - greenlet=1.1.1=py37h295c915_0 75 | - grpcio=1.16.1=py37hf8bcb03_1 76 | - gst-plugins-base=1.14.0=hbbd80ab_1 77 | - gstreamer=1.14.0=hb453b48_1 78 | - gunicorn=20.1.0=py37h06a4308_0 79 | - h5py=2.10.0=nompi_py37h513d04c_100 80 | - harfbuzz=1.8.8=hffaf4a1_0 81 | - haversine=2.5.1=pyhd8ed1ab_0 82 | - hdf5=1.10.5=nompi_h3c11f04_1104 83 | - heapdict=1.0.1=pyhd3eb1b0_0 84 | - icu=58.2=h9c2bf20_1 85 | - idna=2.8=py37_1000 86 | - imageio=2.5.0=py37_0 87 | - importlib-metadata=4.8.1=py37h06a4308_0 88 | - intel-openmp=2019.4=243 89 | - ipykernel=5.1.1=py37h39e3cac_0 90 | - ipython=7.5.0=py37h39e3cac_0 91 | - ipython_genutils=0.2.0=py37_0 92 | - ipywidgets=7.4.2=py37_0 93 | - itsdangerous=2.0.1=pyhd3eb1b0_0 94 | - jasper=1.900.1=h07fcdf6_1006 95 | - jedi=0.13.3=py37_0 96 | - jinja2=2.10.1=py37_0 97 | - joblib=0.13.2=py37_0 98 | - jpeg=9c=h14c3975_1001 99 | - jsonschema=3.0.1=py37_0 100 | - jupyter=1.0.0=py37_7 101 | - jupyter_client=5.2.4=py37_0 102 | - jupyter_console=6.0.0=py37_0 103 | - jupyter_core=4.4.0=py37_0 104 | - keras=2.2.4=py37_1 105 | - keras-applications=1.0.7=py_1 106 | - keras-preprocessing=1.0.9=py_1 107 | - kiwisolver=1.1.0=py37hc9558a2_0 108 | - lame=3.100=h14c3975_1001 109 | - libblas=3.8.0=14_openblas 110 | - libcblas=3.8.0=14_openblas 111 | - libclang=9.0.0=hc9558a2_1 112 | - libedit=3.1.20181209=hc058e9b_0 113 | - libffi=3.2.1=h4deb6c0_3 114 | - libgcc-ng=9.1.0=hdf63c60_0 115 | - libgfortran-ng=7.3.0=hdf63c60_0 116 | - libgomp=9.3.0=h5101ec6_17 117 | - libgpuarray=0.7.6=h14c3975_1003 118 | - libiconv=1.15=h516909a_1005 119 | - liblapack=3.8.0=14_openblas 120 | - liblapacke=3.8.0=14_openblas 121 | - libllvm11=11.1.0=h3826bc1_0 122 | - libllvm9=9.0.0=hc9558a2_2 123 | - libopenblas=0.3.7=h6e990d7_2 124 | - libpng=1.6.37=hbc83047_0 125 | - libprotobuf=3.9.1=h8b12597_0 126 | - libsodium=1.0.16=h1bed415_0 127 | - libstdcxx-ng=9.1.0=hdf63c60_0 128 | - libtiff=4.0.10=h2733197_2 129 | - libuuid=1.0.3=h1bed415_2 130 | - libwebp=1.0.2=h576950b_1 131 | - libxcb=1.13=h1bed415_1 132 | - libxkbcommon=0.9.1=hebb1f50_0 133 | - libxml2=2.9.9=hea5a465_1 134 | - libxslt=1.1.33=h7d1a2b0_0 135 | - llvmlite=0.37.0=py37h295c915_1 136 | - locket=0.2.1=py37h06a4308_1 137 | - lxml=4.4.1=py37hefd8a0e_0 138 | - mako=1.1.0=py_0 139 | - markdown=3.1.1=py_0 140 | - 
markupsafe=1.1.1=py37h7b6447c_0 141 | - matplotlib=3.1.1=py37h5429711_0 142 | - memory_profiler=0.55.0=py37_0 143 | - mistune=0.8.4=py37h7b6447c_0 144 | - mkl=2019.4=243 145 | - mkl-service=2.0.2=py37h7b6447c_0 146 | - mkl_fft=1.0.15=py37h516909a_1 147 | - mlflow=1.20.2=py37h02d9ccd_1 148 | - msgpack-python=1.0.2=py37hff7bd54_1 149 | - multipledispatch=0.6.0=py37_0 150 | - nbconvert=5.5.0=py_0 151 | - nbformat=4.4.0=py37_0 152 | - nbstripout=0.3.6=py_0 153 | - ncurses=6.1=he6710b0_1 154 | - nettle=3.4.1=h1bed415_1002 155 | - notebook=5.7.8=py37_0 156 | - nspr=4.23=he1b5a44_0 157 | - nss=3.47=he751ad9_0 158 | - numba=0.54.1=py37h51133e4_0 159 | - numpy=1.17.3=py37h95a1406_0 160 | - numpy-base=1.17.2=py37h2f8d375_0 161 | - olefile=0.46=py37_0 162 | - openh264=1.8.0=hdbcaa40_1000 163 | - openssl=1.1.1n=h7f8727e_0 164 | - owslib=0.18.0=py_0 165 | - packaging=21.0=pyhd3eb1b0_0 166 | - pandas=1.2.4=py37ha9443f7_0 167 | - pandoc=2.2.3.2=0 168 | - pandocfilters=1.4.2=py37_1 169 | - pango=1.42.4=h049681c_0 170 | - parso=0.4.0=py_0 171 | - partd=1.2.0=pyhd3eb1b0_0 172 | - patsy=0.5.1=py37_0 173 | - pcre=8.43=he6710b0_0 174 | - pexpect=4.7.0=py37_0 175 | - pickleshare=0.7.5=py37_0 176 | - pillow=6.0.0=py37h34e0f95_0 177 | - pip=19.1.1=py37_0 178 | - pixman=0.38.0=h516909a_1003 179 | - proj4=5.0.1=h14c3975_0 180 | - prometheus_client=0.6.0=py37_0 181 | - prometheus_flask_exporter=0.18.5=pyhd8ed1ab_0 182 | - prompt_toolkit=2.0.9=py37_0 183 | - protobuf=3.9.1=py37he1b5a44_0 184 | - psutil=5.6.3=py37h7b6447c_0 185 | - ptyprocess=0.6.0=py37_0 186 | - pycparser=2.19=py37_1 187 | - pydap=3.3.0=pyhd8ed1ab_0 188 | - pydot=1.4.1=py37h06a4308_0 189 | - pyepsg=0.4.0=py37_0 190 | - pygments=2.4.2=py_0 191 | - pygpu=0.7.6=py37h3010b51_1000 192 | - pyopenssl=19.0.0=py37_0 193 | - pyparsing=2.4.0=py_0 194 | - pyproj=1.9.5.1=py37h7b21b82_1 195 | - pyqt=5.9.2=py37h05f1152_2 196 | - pyrsistent=0.14.11=py37h7b6447c_0 197 | - pyshp=2.1.0=py_0 198 | - pysocks=1.7.0=py37_0 199 | - python=3.7.3=h0371630_0 200 | - python-dateutil=2.8.0=py37_0 201 | - python-editor=1.0.4=pyhd3eb1b0_0 202 | - python_abi=3.7=2_cp37m 203 | - pytz=2019.1=py_0 204 | - pyyaml=5.1.2=py37h516909a_0 205 | - pyzmq=18.0.0=py37he6710b0_0 206 | - qt=5.9.7=h5867ecd_1 207 | - qtconsole=4.5.1=py_0 208 | - querystring_parser=1.2.4=py_0 209 | - readline=7.0=h7b6447c_5 210 | - requests=2.22.0=py37_0 211 | - scikit-learn=1.0.1=py37h51133e4_0 212 | - scipy=1.3.1=py37h921218d_2 213 | - seaborn=0.11.2=pyhd3eb1b0_0 214 | - send2trash=1.5.0=py37_0 215 | - setuptools=41.0.1=py37_0 216 | - shapely=1.6.4=py37h7ef4460_0 217 | - sip=4.19.8=py37hf484d3e_0 218 | - six=1.12.0=py37_0 219 | - smmap=4.0.0=pyhd3eb1b0_0 220 | - sortedcontainers=2.4.0=pyhd3eb1b0_0 221 | - soupsieve=2.3.1=pyhd8ed1ab_0 222 | - sqlalchemy=1.4.22=py37h7f8727e_0 223 | - sqlite=3.30.1=h7b6447c_0 224 | - sqlparse=0.4.1=py_0 225 | - statsmodels=0.10.1=py37hdd07704_0 226 | - tabulate=0.8.9=py37h06a4308_0 227 | - tbb=2021.4.0=hd09550d_0 228 | - tblib=1.7.0=pyhd3eb1b0_0 229 | - tenacity=8.0.1=py37h06a4308_0 230 | - tensorboard=1.14.0=py37_0 231 | - tensorflow=1.14.0=gpu_py37h74c33d7_0 232 | - tensorflow-base=1.14.0=gpu_py37he45bfe2_0 233 | - tensorflow-estimator=1.14.0=py37h5ca1d4c_0 234 | - tensorflow-gpu=1.14.0=h0d30ee6_0 235 | - termcolor=1.1.0=py_2 236 | - terminado=0.8.2=py37_0 237 | - testpath=0.4.2=py37_0 238 | - theano=1.0.3=py37hfc679d8_1 239 | - threadpoolctl=2.2.0=pyh0d69192_0 240 | - tk=8.6.9=hed695b0_1002 241 | - toolchain=2.4.0=0 242 | - toolchain_c_linux-64=2.4.0=0 243 | - 
toolchain_cxx_linux-64=2.4.0=0 244 | - toolz=0.9.0=py_1 245 | - tornado=6.0.2=py37h7b6447c_0 246 | - tqdm=4.62.3=pyhd8ed1ab_0 247 | - traitlets=4.3.2=py37_0 248 | - typing-extensions=3.10.0.2=hd3eb1b0_0 249 | - typing_extensions=3.10.0.2=pyh06a4308_0 250 | - urllib3=1.24.3=py37_0 251 | - utm=0.7.0=pyhd8ed1ab_0 252 | - vincent=0.4.4=py_1 253 | - wcwidth=0.1.7=py37_0 254 | - webencodings=0.5.1=py37_1 255 | - webob=1.8.7=pyhd8ed1ab_0 256 | - websocket-client=0.58.0=py37h06a4308_4 257 | - werkzeug=0.15.5=py_0 258 | - wheel=0.33.4=py37_0 259 | - widgetsnbextension=3.4.2=py37_0 260 | - wrapt=1.11.2=py37h516909a_0 261 | - x264=1!152.20180806=h14c3975_0 262 | - xlrd=1.2.0=py37_0 263 | - xorg-kbproto=1.0.7=h14c3975_1002 264 | - xorg-libice=1.0.10=h516909a_0 265 | - xorg-libx11=1.6.9=h516909a_0 266 | - xorg-libxext=1.3.4=h516909a_0 267 | - xorg-libxrender=0.9.10=h516909a_1002 268 | - xorg-renderproto=0.11.1=h14c3975_1002 269 | - xorg-xextproto=7.3.0=h14c3975_1002 270 | - xorg-xproto=7.0.31=h14c3975_1007 271 | - xz=5.2.4=h14c3975_4 272 | - yaml=0.1.7=h14c3975_1001 273 | - zeromq=4.3.1=he6710b0_3 274 | - zict=2.0.0=pyhd3eb1b0_0 275 | - zipp=3.6.0=pyhd3eb1b0_0 276 | - zlib=1.2.11=h7b6447c_3 277 | - zstd=1.3.7=h0b5b093_0 278 | - pip: 279 | - pyarrow==6.0.0 280 | 281 | -------------------------------------------------------------------------------- /resources_and_information/ais_data_faq_from_marine_cadastre.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaacOnline/ships/ca82410a122267d7ed4ac26356b084de88c2095b/resources_and_information/ais_data_faq_from_marine_cadastre.pdf -------------------------------------------------------------------------------- /resources_and_information/coast_guard_mmsi_document.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaacOnline/ships/ca82410a122267d7ed4ac26356b084de88c2095b/resources_and_information/coast_guard_mmsi_document.pdf -------------------------------------------------------------------------------- /resources_and_information/vessel_type_codes_2018.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaacOnline/ships/ca82410a122267d7ed4ac26356b084de88c2095b/resources_and_information/vessel_type_codes_2018.pdf -------------------------------------------------------------------------------- /resources_and_information/vessel_type_guide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaacOnline/ships/ca82410a122267d7ed4ac26356b084de88c2095b/resources_and_information/vessel_type_guide.pdf -------------------------------------------------------------------------------- /tests/config/config.py: -------------------------------------------------------------------------------- 1 | from socket import gethostname 2 | import os 3 | 4 | global dataset_config 5 | 6 | # Define start/end years to look at. Currently, the earliest supported year is 2015. Data prior to 2015 uses a different 7 | # url and will also need to be preprocessed slightly differently - check the ais_data_faq_from_marine_cadastre.pdf in 8 | # resources_and_information for details. 9 | global start_year 10 | start_year = 2015 11 | 12 | global end_year 13 | end_year = 2019 14 | 15 | global years 16 | years = range(start_year, end_year + 1) 17 | 18 | # The number of seconds between timestamps when interpolating. 
Currently set to 5 minutes 19 | global interpolation_time_gap 20 | interpolation_time_gap = 5 * 60 21 | 22 | # Number of *timestamps* to use for prediction, and to predict into the future. The current setting uses three hours of 23 | # history and predicts three hours into the future 24 | global length_of_history 25 | length_of_history = int(3 * 60 * 60 / interpolation_time_gap) + 1 26 | 27 | global length_into_the_future 28 | length_into_the_future = int(3 * 60 * 60 / interpolation_time_gap) - 1 29 | 30 | # Set a base directory to use for data storage. You should change this value. 31 | global data_directory 32 | data_directory = '/home/isaac/data/' 33 | 34 | global box_and_year_dir 35 | 36 | # Name of dataset being used 37 | global dataset_name 38 | dataset_name = 'formatted_with_currents_stride_3' 39 | 40 | 41 | # Whether logging should be used 42 | global logging 43 | logging = True 44 | 45 | # The host machine 46 | global machine 47 | host = gethostname() 48 | 49 | # Categorical columns that are one hot encoded. Only change this if you change preprocessing to add in other 50 | # columns. 51 | global categorical_columns 52 | categorical_columns = ['vessel_group','destination_cluster'] -------------------------------------------------------------------------------- /tests/config/dataset_config.py: -------------------------------------------------------------------------------- 1 | class DatasetConfig(): 2 | def __init__(self, dataset_name, 3 | lat_1, lat_2, lon_1, lon_2, 4 | sliding_window_movement, 5 | depth_1=0, depth_2=0, 6 | min_pts_to_try = None, eps_to_try = None, 7 | min_pts_to_use = None, eps_to_use = None): 8 | self.dataset_name = dataset_name 9 | self.lat_1 = min(lat_1, lat_2) 10 | self.lat_2 = max(lat_1, lat_2) 11 | self.lon_1 = min(lon_1, lon_2) 12 | self.lon_2 = max(lon_1, lon_2) 13 | self.corner_1 = (lat_1, lon_1) 14 | self.corner_2 = (lat_2, lon_2) 15 | # When expanding data using a sliding window, the sliding_window_movement is the length of time between windows. 16 | # E.g. 
if sliding_window_length is 10 * 60, tracks are supposed to be made up of three timestamps, and the 17 | # interpolated trajectory has timestamps at [0, 5, 10, 15, 20], then the trajectories output 18 | # will be [0, 5, 10], and [10, 15, 20] 19 | self.sliding_window_movement = sliding_window_movement 20 | self.min_pts_to_try = min_pts_to_try 21 | self.eps_to_try = eps_to_try 22 | self.min_pts_to_use = min_pts_to_use 23 | self.eps_to_use = eps_to_use 24 | self.depth_1 = depth_1 25 | self.depth_2 = depth_2 26 | 27 | 28 | 29 | datasets = { 30 | 'florida_gulf': 31 | DatasetConfig( 32 | dataset_name='florida_gulf', 33 | lat_1=26.00, lon_1=-85.50, lat_2=29.00, lon_2=-81.50, 34 | min_pts_to_try=[4, 10, 20, 50, 100, 250, 500], 35 | eps_to_try=[0.0001, 0.00025, 0.0005, 0.00075, 36 | 0.001, 0.0025, 0.005, 0.0075, 37 | 0.01, 0.025, 0.05, 0.075, 38 | 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 39 | 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 40 | 6, 7, 8, 9, 10], 41 | min_pts_to_use=50, 42 | eps_to_use=5, 43 | sliding_window_movement=15 * 60 44 | ), 45 | 'california_coast': 46 | DatasetConfig( 47 | dataset_name='california_coast', 48 | lat_1=33.40, lon_1=-122.00, lat_2=36.40, lon_2=-118.50, 49 | sliding_window_movement=15 * 60, 50 | min_pts_to_try = [4, 10, 20, 50, 100, 250, 500], 51 | eps_to_try=[1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5], 52 | min_pts_to_use = 50, 53 | eps_to_use = 3 54 | ), 55 | 'new_york': 56 | DatasetConfig( 57 | dataset_name='new_york', 58 | lat_1=39.50, lon_1=-74.50, lat_2=41.50, lon_2=-71.50, 59 | sliding_window_movement=60 * 60, 60 | min_pts_to_try=[4, 10, 20, 50, 100, 250, 500], 61 | eps_to_try=[0.0001, 0.00025, 0.0005, 0.00075, 62 | 0.001, 0.0025, 0.005, 0.0075, 63 | 0.01, 0.025, 0.05, 0.075, 64 | 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 65 | 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 66 | 6, 7, 8, 9, 10], 67 | min_pts_to_use=50, eps_to_use=3 68 | ) 69 | } 70 | -------------------------------------------------------------------------------- /tests/create_data.py: -------------------------------------------------------------------------------- 1 | # Script for performing the final preprocessing steps for a model, and saving the dataset to disk 2 | # Useful if you want to perform all the preprocessing at once - which may be necessary as some of the 3 | # datasets have large memory requirements, so shouldn't be processed in unison, which can happen 4 | # if the datasets are calculated on the fly. The downside is that many of the datasets are huge, so 5 | # it's difficult to store them on the same disk at once, if you don't have multiple free TB for storage 6 | 7 | import sys 8 | import os 9 | import time 10 | from config import config 11 | import atexit 12 | import json 13 | import gc 14 | 15 | os.environ['PYTHONHASHSEED'] = '0' 16 | 17 | import numpy as np 18 | import pandas as pd 19 | 20 | 21 | from utils import ProcessorManager, TestArgParser 22 | from utils.utils import total_system_ram 23 | 24 | # These need to come before tensorflow is imported so that if we're using CPU we can unregister the GPUs before tf 25 | # imports them. 
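# (For reference, a minimal sketch of the usual way to hide GPUs from TensorFlow before
#  it is imported; the ProcessorManager used below may do this differently, and
#  `cpu_only` is only an illustrative flag, not an argument this script actually defines.)
#
#     import os
#     if cpu_only:
#         os.environ['CUDA_VISIBLE_DEVICES'] = ''   # TensorFlow will then see no GPU devices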
26 | parser = TestArgParser() 27 | args = parser.parse() 28 | manager = ProcessorManager(debug=True) 29 | manager.open() 30 | 31 | import mlflow 32 | 33 | if args.debug: 34 | mlflow.set_experiment(experiment_name='Ships Debugging') 35 | 36 | 37 | from loading.data_loader import DataLoader 38 | import utils 39 | 40 | 41 | 42 | if __name__ == '__main__': 43 | # Parse command line arguments 44 | 45 | start_ts = time.time() 46 | 47 | utils.set_seed(args.seed) 48 | 49 | loader = DataLoader(config, args, conserve_memory=True) 50 | a = loader.load_set('train', 'train', 'y') 51 | if type(a) == list: 52 | nbytes = 0 53 | for ds in a: 54 | nbytes += ds.nbytes 55 | print(nbytes) 56 | print(nbytes / total_system_ram()) 57 | else: 58 | print(a.nbytes) 59 | print(a.nbytes / total_system_ram()) 60 | del a 61 | gc.collect() 62 | 63 | a = loader.load_set('train', 'train', 'x') 64 | if type(a) == list: 65 | nbytes = 0 66 | for ds in a: 67 | nbytes += ds.nbytes 68 | print(nbytes) 69 | print(nbytes / total_system_ram()) 70 | else: 71 | print(a.nbytes) 72 | print(a.nbytes / total_system_ram()) 73 | del a 74 | 75 | gc.collect() 76 | 77 | loader.load_set('valid', 'train', 'y') 78 | gc.collect() 79 | loader.load_set('valid', 'train', 'x') 80 | gc.collect() 81 | 82 | loader.load_set('valid', 'test', 'y') 83 | gc.collect() 84 | loader.load_set('valid', 'test', 'x') 85 | gc.collect() 86 | 87 | loader.load_set('test', 'test', 'y') 88 | gc.collect() 89 | loader.load_set('test', 'test', 'x') 90 | gc.collect() 91 | -------------------------------------------------------------------------------- /tests/loading/__init__.py: -------------------------------------------------------------------------------- 1 | from loading.loading import * 2 | from loading.generator import DataGenerator 3 | from loading.normalizer import Normalizer -------------------------------------------------------------------------------- /tests/loading/disk_array.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import numpy as np 3 | import gc 4 | import os 5 | import atexit 6 | from loading import loading 7 | from utils.utils import total_system_ram 8 | 9 | class DiskArray(): 10 | """ 11 | Light weight class for processing large data sets in chunks 12 | 13 | Saves chunks as tmp files on disk, so it only has to have one chunk in memory at a time. 14 | 15 | Chunks can contain multiple arrays, meaning a DiskArray can be equivalent to 2+ numpy arrays (although these 16 | arrays will need to have the same number of rows). 17 | 18 | Keeps track of the complete size of all chunks, and if the chunks are small enough to fit in memory all together, 19 | can convert itself to a numpy array or list of numpy arrays 20 | """ 21 | def __init__(self): 22 | self.temp_paths = [] 23 | self.shape = None 24 | self.nbytes = 0 25 | self.axis = 0 26 | atexit.register(self.exit_handler, self) 27 | 28 | def close(self): 29 | """ 30 | Clear all data 31 | 32 | Deletes any temp files stored on disk 33 | 34 | :return: 35 | """ 36 | for t in self.temp_paths: 37 | t.cleanup() 38 | self.temp_paths = [] 39 | self.shape = None 40 | self.nbytes = 0 41 | self.axis = 0 42 | 43 | def exit_handler(self, _): 44 | """ 45 | Wrapper for 'close' method, which can be called by atexit 46 | 47 | Used for handling a SIGINT 48 | 49 | :param _: 50 | :return: 51 | """ 52 | self.close() 53 | 54 | def add_array(self, array): 55 | """ 56 | Add numpy array to disk array. 
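        A rough usage sketch (the shapes below are made up purely for illustration):

            da = DiskArray()
            da.add_array(np.zeros((1000, 8)))   # first chunk, saved to its own temp dir
            da.add_array(np.ones((500, 8)))     # second chunk
            da.shape                            # -> (1500, 8)
            full = da.compute()                 # concatenated numpy array, if it fits in RAM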
57 | 58 | Currently, appended arrays are only added to the 0th axis of the DiskArray 59 | 60 | :param array: Array to append 61 | :return: 62 | """ 63 | self._update_basic_info(array) 64 | 65 | self.temp_paths.append(tempfile.TemporaryDirectory()) 66 | 67 | self._save_partition_to_path(self.temp_paths[-1].name, array) 68 | del array 69 | gc.collect() 70 | 71 | 72 | def _save_partition_to_path(self, path, data): 73 | """ 74 | Save one of the DiskArray's chunks to a specific path 75 | 76 | The data object can be a numpy array or list of numpy arrays. 77 | 78 | :param path: The path to save to 79 | :param data: The chunk to be saved 80 | :return: 81 | """ 82 | if type(data) != list: 83 | data = [data] 84 | for i, d in enumerate(data): 85 | if not os.path.exists(path): 86 | os.mkdir(path) 87 | save_path = os.path.join(path, f'{i}.npy') 88 | np.save(save_path, d) 89 | 90 | def _load_partition_from_path(self, path): 91 | """ 92 | Load a chunk from disk 93 | 94 | :param path: The path to save to 95 | :return: 96 | """ 97 | files = os.listdir(path) 98 | files = np.array(files)[np.argsort([int(n.split('.')[0]) for n in files])].tolist() 99 | data = [] 100 | for f in files: 101 | data += [np.load(os.path.join(path, f))] 102 | if len(data) == 1: 103 | data = data[0] 104 | return data 105 | 106 | 107 | def __getitem__(self, item): 108 | """ 109 | Load a chunk by index 110 | 111 | :param item: index of chunk to load 112 | :return: 113 | """ 114 | return self._load_partition_from_path(self.temp_paths[item].name) 115 | 116 | def compute(self): 117 | """ 118 | Convert self to numpy array or list of numpy arrays 119 | 120 | Requires that the size of arrays is less than 95% of total system ram 121 | 122 | :return: 123 | """ 124 | if self.nbytes > (total_system_ram() * 0.95): 125 | raise MemoryError('Not enough memory to load the dataset') 126 | data = [self._load_partition_from_path(t.name) for t in self.temp_paths] 127 | joined_data = [] 128 | if type(data[0]) == list: 129 | for i in range(len(data[0])): 130 | joined_data += [np.concatenate([d[i] for d in data], axis=self.axis)] 131 | else: 132 | joined_data = np.concatenate(data, axis=self.axis) 133 | 134 | return joined_data 135 | 136 | def head(self, n): 137 | """ 138 | Return the first n rows of DiskArray 139 | 140 | Will return two numpy arrays if DiskArray contains two arrays 141 | 142 | :param n: Number of rows 143 | :return: 144 | """ 145 | sampled = 0 146 | head_data = None 147 | for t in self.temp_paths: 148 | data = self._load_partition_from_path(t.name) 149 | if type(data) == list: 150 | t_len = len(data[0]) 151 | to_sample = min(t_len, n - sampled) 152 | data = [d[:to_sample] for d in data] 153 | if head_data is None: 154 | head_data = data 155 | else: 156 | head_data = [np.concatenate([hd, d], axis=self.axis) for hd, d in zip(head_data, data)] 157 | else: 158 | t_len = len(data) 159 | to_sample = min(t_len, n - sampled) 160 | data = data[:to_sample] 161 | if head_data is None: 162 | head_data = data 163 | else: 164 | head_data = np.concatenate([head_data, data], axis=self.axis) 165 | sampled += to_sample 166 | if sampled == n: 167 | break 168 | assert sampled < n 169 | 170 | if sampled < n: 171 | raise UserWarning('Dataset does not contain the desired number of records. 
Entire dataset returned') 172 | 173 | return head_data 174 | 175 | 176 | def save_to_disk(self, dir): 177 | """ 178 | Save DiskArray to a path 179 | 180 | If the size of the disk array is less than 35% of total system ram, the array will be converted to 181 | numpy arrays beforehand. Otherwise the chunks will be saved individually 182 | 183 | :param dir: 184 | :return: 185 | """ 186 | if self.nbytes / total_system_ram() < 0.35: 187 | data = self.compute() 188 | if type(data) == list: 189 | os.mkdir(dir) 190 | for i, set in enumerate(data): 191 | path = os.path.join(dir, f'{i}.npy') 192 | np.save(path, set) 193 | else: 194 | path = dir + '.npy' 195 | np.save(path, data) 196 | else: 197 | dir += '_disk_array' 198 | if not os.path.exists(dir): 199 | os.mkdir(dir) 200 | for i, t in enumerate(self.temp_paths): 201 | data = self._load_partition_from_path(t.name) 202 | t_dir = os.path.join(dir, str(i)) 203 | self._save_partition_to_path(t_dir, data) 204 | 205 | def load_from_disk(self, dir): 206 | """ 207 | Load a disk array from a save path 208 | 209 | :param dir: 210 | :return: 211 | """ 212 | partitions = os.listdir(dir) 213 | for p in partitions: 214 | path = os.path.join(dir, p) 215 | data = self._load_partition_from_path(path) 216 | self.add_array(data) 217 | 218 | 219 | def __del__(self): 220 | """ 221 | Delete self 222 | 223 | :return: 224 | """ 225 | self.close() 226 | 227 | def __len__(self): 228 | """ 229 | Number of rows 230 | 231 | :return: 232 | """ 233 | return self.shape[0] 234 | 235 | def _calculate_min_max(self): 236 | """ 237 | Calculate the minimum/maximum value for each column in DiskArray 238 | 239 | Used for calculating normalization factors 240 | 241 | :return: 242 | """ 243 | num_columns = self.shape[-1] 244 | mins = np.array(np.ones((num_columns,)) * np.inf) 245 | maxes = np.array(np.ones((num_columns,)) * -np.inf) 246 | 247 | for t in self.temp_paths: 248 | data = self._load_partition_from_path(t.name) 249 | partition_mins = data.min(axis=0) 250 | partition_maxes = data.max(axis=0) 251 | while len(partition_mins.shape) > 1: 252 | partition_mins = partition_mins.min(axis=0) 253 | partition_maxes = partition_maxes.max(axis=0) 254 | mins = np.min([mins, partition_mins], axis=0) 255 | maxes = np.max([maxes, partition_maxes], axis=0) 256 | del data 257 | gc.collect() 258 | 259 | return mins, maxes 260 | 261 | def _update_basic_info(self, data): 262 | """ 263 | Update summary information for DiskArray 264 | 265 | Keeps track of the total number of bytes of all chunks, as well as the total shape 266 | 267 | :param data: New chunk 268 | :return: 269 | """ 270 | if type(data) == list: 271 | self.nbytes += np.sum([a.nbytes for a in data]).astype(int) 272 | if self.shape == None: 273 | self.shape = [a.shape for a in data] 274 | else: 275 | for i, (s, a) in enumerate(zip(self.shape, data)): 276 | s = list(s) 277 | s[self.axis] += a.shape[self.axis] 278 | self.shape[i] = tuple(s) 279 | else: 280 | self.nbytes += data.nbytes 281 | if self.shape == None: 282 | self.shape = data.shape 283 | else: 284 | self.shape = list(self.shape) 285 | self.shape[self.axis] += data.shape[self.axis] 286 | self.shape = tuple(self.shape) 287 | 288 | def _apply_transformations(self, x_or_y, transformations, normalizer, normalization_factors): 289 | """ 290 | Apply a list of transformations to a DiskArray 291 | 292 | Transformations should be a list of specifications generated by a DataLoader, which can be handled by 293 | loading.apply_transformations 294 | 295 | :param x_or_y: Whether this is an input 
or output dataset 296 | :param transformations: List 297 | :param normalizer: 298 | :param normalization_factors: 299 | :return: 300 | """ 301 | self.nbytes = 0 302 | self.shape = None 303 | for t in self.temp_paths: 304 | array = self._load_partition_from_path(t.name) 305 | array = loading.apply_transformations(array, x_or_y, transformations, normalizer, normalization_factors) 306 | self._update_basic_info(array) 307 | self._save_partition_to_path(t.name, array) 308 | 309 | del array 310 | gc.collect() 311 | -------------------------------------------------------------------------------- /tests/loading/generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import gc 4 | from loading.disk_array import DiskArray 5 | 6 | 7 | class DataGenerator(tf.keras.utils.Sequence): 8 | """ 9 | Data generator for keras modeling 10 | """ 11 | def __init__(self, X, Y, batch_size=512, shuffle=True): 12 | """ 13 | If data length is not divisible by batch size, will keep out a random set of rows each round (who make up the 14 | remainder that don't fit into a full batch) 15 | 16 | 17 | :param X: input sequences. Can contain multiple arrays 18 | :param Y: output sequences 19 | :param batch_size: 20 | :param shuffle: Whether to shuffle data after each epoch 21 | """ 22 | self.batch_size = batch_size 23 | if isinstance(X, list): 24 | self.num_X_sets = len(X) 25 | self.X_unified = [xset.copy() for xset in X] 26 | else: 27 | self.num_X_sets = 1 28 | self.X_unified = [X] 29 | self.Y_unified = Y.copy() 30 | self.X_split = None 31 | self.Y_split = None 32 | self.shuffle = shuffle 33 | self.data_is_split = False 34 | if isinstance(X, list): 35 | self.complete_len = len(X[0]) 36 | else: 37 | self.complete_len = len(X) 38 | self.batch_len = int(np.floor(self.complete_len / self.batch_size)) 39 | self.split_indexes = np.arange(0, self.complete_len, self.batch_size)[1:] 40 | self.on_epoch_end() 41 | 42 | def __len__(self): 43 | """ 44 | :return: Number of batches per epoch 45 | """ 46 | return self.batch_len 47 | 48 | def __getitem__(self, index): 49 | """ 50 | Return one batch of data 51 | 52 | :param index: Index of batch to retrieve 53 | :return: 54 | """ 55 | return self.__data_generation(index) 56 | 57 | def on_epoch_begin(self): 58 | """ 59 | Split data into batches at the beginning of epoch 60 | 61 | :return: 62 | """ 63 | if not self.data_is_split: 64 | self.X_split = [np.split(xset, self.split_indexes, axis=0) for xset in self.X_unified] 65 | self.Y_split = np.split(self.Y_unified, self.split_indexes, axis=0) 66 | self.X_unified = None 67 | self.Y_unified = None 68 | self.data_is_split = True 69 | 70 | def on_epoch_end(self): 71 | """ 72 | Shuffle dataset at the end of an epoch 73 | 74 | :return: 75 | """ 76 | if self.data_is_split: 77 | X_info = [[list(xset[0].shape), xset[0].dtype] for xset in self.X_split] 78 | for i in range(len(X_info)): 79 | X_info[i][0][0] = self.complete_len 80 | self.X_unified = [np.empty(shape, dtype=dtype) for shape, dtype in X_info] 81 | for i in range(self.num_X_sets): 82 | self.X_unified[i][:] = np.nan 83 | for xset_idx, start_index in zip(range(len(self.X_split[0])), range(0, self.complete_len, self.batch_size)): 84 | for i in range(self.num_X_sets): 85 | self.X_unified[i][start_index:start_index+self.batch_size] = self.X_split[i][xset_idx] 86 | self.X_split[i][xset_idx] = None 87 | 88 | 89 | 90 | self.Y_unified = np.concatenate(self.Y_split, axis=0) 91 | self.Y_split = None 92 | 
self.X_split = None 93 | self.data_is_split = False 94 | self.indexes = np.arange(self.complete_len) 95 | if self.shuffle: 96 | np.random.shuffle(self.indexes) 97 | 98 | self.X_unified = [xset[self.indexes] for xset in self.X_unified] 99 | self.Y_unified = self.Y_unified[self.indexes] 100 | 101 | 102 | self.X_split = [ 103 | np.split(xset, self.split_indexes, axis=0) for xset in self.X_unified 104 | ] 105 | 106 | self.Y_split = np.split(self.Y_unified, self.split_indexes, axis=0) 107 | self.X_unified = None 108 | self.Y_unified = None 109 | self.data_is_split = True 110 | gc.collect() 111 | 112 | 113 | def __data_generation(self, index): 114 | """ 115 | Retrieve a batch by index 116 | 117 | :param index: Index of batch 118 | :return: 119 | """ 120 | output = self.Y_split[index] 121 | if self.num_X_sets == 1: 122 | input = self.X_split[0][index] 123 | else: 124 | input = {f'input_{i+1}':xset[index] for i, xset in enumerate(self.X_split)} 125 | 126 | return input, output -------------------------------------------------------------------------------- /tests/loading/normalizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from loading.disk_array import DiskArray 4 | from loading.loading import _find_current_col_idx 5 | from config import config 6 | 7 | 8 | class Normalizer(): 9 | def __init__(self): 10 | pass 11 | 12 | @staticmethod 13 | def get_normalization_factors(X, columns): 14 | """ 15 | Get the normalization factors for a specific dataset 16 | 17 | Find the min and max values for different columns, which are used when normalizing the dataset to a 0-1 range 18 | 19 | :param X: Dataset 20 | :param columns: The names and types of columns in X 21 | :return: Normalization factors 22 | """ 23 | year_idx = _find_current_col_idx('year', columns) 24 | month_idx = _find_current_col_idx('month', columns) 25 | 26 | normalization_factors = { 27 | 'lat': { 28 | 'idx': _find_current_col_idx('lat', columns), 29 | 'min': float(config.dataset_config.lat_1), 30 | 'max': float(config.dataset_config.lat_2) 31 | }, 32 | 'lon': { 33 | 'idx': _find_current_col_idx('lon', columns), 34 | 'min': float(config.dataset_config.lon_1), 35 | 'max': float(config.dataset_config.lon_2) 36 | }, 37 | 'year': { 38 | 'idx': year_idx, 39 | 'min': min(config.years), 40 | 'max': max(config.years) 41 | }, 42 | 'month': { 43 | 'idx': month_idx, 44 | 'min': 1, 45 | 'max': 12 46 | } 47 | 48 | } 49 | cols_to_use = columns['column'][columns['being_used']] 50 | non_bools = cols_to_use[columns.dtype != 'bool'] 51 | multi_cols = ['speed','water','mmsi_neighbor','lat_neighbor','lon_neighbor','time_since_neighbor'] 52 | ranges = {k: [np.Inf, -np.Inf] for k in multi_cols} 53 | if type(X) == DiskArray: 54 | mins, maxes = X._calculate_min_max() 55 | for col in non_bools: 56 | if col not in normalization_factors.keys(): 57 | idx = _find_current_col_idx(col, columns) 58 | normalization_factors[col] = { 59 | 'idx': idx, 60 | 'min': mins[idx] if type(X) == DiskArray else float(X[:, :, idx].min()), 61 | 'max': maxes[idx] if type(X) == DiskArray else float(X[:, :, idx].max()) 62 | } 63 | for mc in multi_cols: 64 | if mc in col: 65 | ranges[mc][0] = min(normalization_factors[col]['min'], ranges[mc][0]) 66 | ranges[mc][1] = max(normalization_factors[col]['max'], ranges[mc][1]) 67 | 68 | for col in non_bools: 69 | for mc in multi_cols: 70 | if mc in col: 71 | normalization_factors[col]['min'] = ranges[mc][0] 72 | normalization_factors[col]['max'] = ranges[mc][1] 73 | 74 | return 
normalization_factors 75 | 76 | @staticmethod 77 | def normalize_data(data, normalization_factors): 78 | """ 79 | Apply the normalization factors to a dataset 80 | 81 | Uses min/max normalization (and the mins/maxes specified in normalization_factors) to put variables on a 0 to 1 range 82 | 83 | Assumes data is 3D array of shape (# of timestamps, # of trajectories, # of columns) 84 | 85 | :param data: Dataset to normalize 86 | :param normalization_factors: Normalization factors 87 | :return: Normalized data 88 | """ 89 | 90 | if len(data.shape) == 3: 91 | for col in normalization_factors.values(): 92 | if col['idx'] < data.shape[-1]: 93 | range = (col['max'] - col['min']) 94 | if range != 0: 95 | dist_above_min = (data[:, :, col['idx']] - col['min']) 96 | data[:, :, col['idx']] = dist_above_min / range 97 | else: # if the variable doesn't vary at all (which can happen when debugging), just set it to 0 98 | data[:,:, col['idx']] = 0 99 | 100 | else: 101 | for col in normalization_factors.values(): 102 | if col['idx'] < data.shape[-1]: 103 | range = (col['max'] - col['min']) 104 | if range != 0: 105 | dist_above_min = (data[:, col['idx']] - col['min']) 106 | data[:, col['idx']] = dist_above_min / range 107 | else: 108 | data[:, col['idx']] = 0 109 | return data 110 | 111 | @staticmethod 112 | def unnormalize(data, normalization_factors, idxs=None): 113 | """ 114 | Move data from 0-1 scale back to original scale 115 | 116 | :param data: Data to unnormalize 117 | :param normalization_factors: Normalization factors 118 | :param idxs: Indexes to unnormalize 119 | :return: Unnormalized data 120 | """ 121 | data = data.copy() 122 | if idxs is None: 123 | idxs = np.arange(0, data.shape[-1]) 124 | 125 | if len(data.shape) == 3: 126 | for col in normalization_factors.values(): 127 | if np.any(col['idx'] == idxs): 128 | range = (col['max'] - col['min']) 129 | data[:, :, col['idx']] = data[:, :, col['idx']] * range + col['min'] 130 | else: 131 | for col in normalization_factors.values(): 132 | if np.any(col['idx'] == idxs): 133 | range = (col['max'] - col['min']) 134 | data[:, col['idx']] = data[:, col['idx']] * range + col['min'] 135 | 136 | return data -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.direct_fusion_runner import FusionModelRunner 2 | from models.direct_rnn_runner import RNNLongTermModelRunner 3 | from models.seq2seq_runner import Seq2SeqRNNAttentionRunner 4 | from models.iterative_rnn_runner import RNNModelRunner 5 | from models.median_stopping import MedianStopper 6 | -------------------------------------------------------------------------------- /tests/models/direct_fusion_runner.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from haversine import haversine_vector, Unit 3 | from keras import Input, Model 4 | from keras.layers import GRU as GRUKeras, LSTM as LSTMKeras, Bidirectional, Flatten, Conv2D, MaxPooling2D, Dense, \ 5 | Dropout 6 | from keras.optimizer_v2.adam import Adam as AdamKeras 7 | from keras.regularizers import L1, L2 8 | 9 | from loading import Normalizer 10 | from loading.loading import _find_current_col_idx 11 | from models.losses import HaversineLoss 12 | from models.model_runner import ModelRunner 13 | 14 | 15 | class FusionModelRunner(ModelRunner): 16 | """ 17 | Class for creating the desired type of Tensorflow model. 
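Fuses two inputs: a recurrent branch over the trajectory sequence and a weather branch (a dense stack, or a convolutional stack when fusion_layer_structure == 'convolutions'), whose outputs are concatenated and passed through the final dense layers built in _init_model.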
Creates the model object, and also provides a wrapper 18 | function for making predictions 19 | """ 20 | def __init__(self, node_type, number_of_rnn_layers, rnn_layer_size, number_of_final_dense_layers, 21 | number_of_fusion_weather_layers, dense_layer_size, 22 | direction, input_ts_length, input_num_recurrent_features, weather_shape, 23 | output_num_features, normalization_factors, 24 | y_idxs, columns, learning_rate, rnn_to_dense_connection, recurrent_idxs, loss='mse', 25 | regularization = None, regularization_application=None, regularization_coefficient=None, 26 | fusion_layer_structure = 'dense', output_feature_size=None, conv_kernel_size=None, 27 | conv_stride_size=None, pool_size=None): 28 | if node_type.lower() == 'gru': 29 | self.rnn_layer = GRUKeras 30 | elif node_type.lower() == 'lstm': 31 | self.rnn_layer = LSTMKeras 32 | else: 33 | raise ValueError('node_type must either be "gru" or "lstm"') 34 | 35 | if direction in ['forward_only','bidirectional']: 36 | self.direction = direction 37 | else: 38 | raise ValueError('direction must be either "forward_only" or "bidirectional"') 39 | 40 | self.number_of_rnn_layers = number_of_rnn_layers 41 | self.rnn_layer_size = rnn_layer_size 42 | self.number_of_fusion_weather_layers = number_of_fusion_weather_layers 43 | self.number_of_final_dense_layers = number_of_final_dense_layers 44 | self.dense_layer_size = dense_layer_size 45 | self.ts_length = input_ts_length 46 | self.input_num_recurrent_features = input_num_recurrent_features 47 | self.output_feature_size = output_feature_size 48 | self.conv_kernel_size = conv_kernel_size 49 | self.conv_stride_size = conv_stride_size 50 | self.pool_size = pool_size 51 | if fusion_layer_structure == 'dense': 52 | self.input_num_dense_features = weather_shape[-1] 53 | else: 54 | self.weather_shape = list(weather_shape) 55 | self.weather_shape[0] = None 56 | self.output_size = output_num_features 57 | self.rnn_to_dense_connection = rnn_to_dense_connection 58 | self.fusion_layer_structure = fusion_layer_structure 59 | if regularization == 'dropout': 60 | if regularization_application == 'recurrent': 61 | self.rnn_regularization = {'recurrent_dropout':regularization_coefficient} 62 | self.dense_dropout = 0.0 63 | self.dense_regularization = {} 64 | elif regularization_application is None: 65 | self.rnn_regularization = {'dropout':regularization_coefficient} 66 | self.dense_dropout = regularization_coefficient 67 | self.dense_regularization = {} 68 | elif regularization in ['l1','l2']: 69 | self.regularizer = L1 if regularization == 'l1' else L2 70 | if regularization_application in ['bias','activity']: 71 | self.rnn_regularization = {f'{regularization_application}_regularizer': 72 | self.regularizer(regularization_coefficient)} 73 | self.dense_dropout = 0.0 74 | self.dense_regularization = {f'{regularization_application}_regularizer': 75 | self.regularizer(regularization_coefficient)} 76 | elif regularization_application == 'recurrent': 77 | self.rnn_regularization = {f'{regularization_application}_regularizer': 78 | self.regularizer(regularization_coefficient)} 79 | self.dense_dropout = 0.0 80 | self.dense_regularization = {} 81 | else: 82 | self.rnn_regularization = {} 83 | self.dense_dropout = 0.0 84 | self.dense_regularization = {} 85 | 86 | 87 | self._init_model() 88 | self.normalization_factors = normalization_factors 89 | self.y_idxs = y_idxs 90 | self.recurrent_idxs = recurrent_idxs 91 | self.columns = columns 92 | self.optimizer = AdamKeras(learning_rate=learning_rate) 93 | self.loss = 'mse' if 
loss=='mse' else HaversineLoss(normalization_factors).haversine_loss 94 | 95 | def _init_model(self): 96 | """ 97 | Create model as specified during initialization 98 | 99 | :return: 100 | """ 101 | # Recurrent section 102 | recurrent_input = Input(shape=(self.ts_length, self.input_num_recurrent_features)) 103 | 104 | recurrent_layers = [recurrent_input] 105 | if self.rnn_to_dense_connection == 'all_nodes': 106 | num_full_sequence_layers = self.number_of_rnn_layers 107 | else: 108 | num_full_sequence_layers = self.number_of_rnn_layers - 1 109 | 110 | if num_full_sequence_layers > 0: 111 | if self.direction == 'forward_only': 112 | recurrent_layers.append(self.rnn_layer(self.rnn_layer_size, return_sequences=True, 113 | **self.rnn_regularization)(recurrent_layers[-1])) 114 | num_full_sequence_layers -= 1 115 | elif self.direction == 'bidirectional': 116 | recurrent_layers.append(Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=True, 117 | **self.rnn_regularization))(recurrent_layers[-1])) 118 | num_full_sequence_layers -= 1 119 | 120 | for layer in range(num_full_sequence_layers): 121 | if self.direction == 'forward_only': 122 | recurrent_layers.append(self.rnn_layer(self.rnn_layer_size, return_sequences=True, 123 | **self.rnn_regularization)(recurrent_layers[-1])) 124 | elif self.direction == 'bidirectional': 125 | recurrent_layers.append(Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=True, 126 | **self.rnn_regularization))(recurrent_layers[-1])) 127 | 128 | if self.rnn_to_dense_connection == 'all_nodes': 129 | recurrent_layers.append(Flatten('channels_first')(recurrent_layers[-1])) 130 | else: 131 | if self.direction == 'forward_only': 132 | recurrent_layers.append(self.rnn_layer(self.rnn_layer_size, return_sequences=False, 133 | **self.rnn_regularization)(recurrent_layers[-1])) 134 | elif self.direction == 'bidirectional': 135 | recurrent_layers.append(Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=False, 136 | **self.rnn_regularization))(recurrent_layers[-1])) 137 | # Weather section 138 | if self.fusion_layer_structure == 'convolutions': 139 | weather_input = Input(shape=self.weather_shape[1:]) 140 | initial_weather_layers = [weather_input] 141 | for layer in range(self.number_of_fusion_weather_layers): 142 | initial_weather_layers.append(Conv2D(self.output_feature_size / (2 ** (self.number_of_fusion_weather_layers - layer)), 143 | (self.conv_kernel_size, self.conv_kernel_size), 144 | padding='same', 145 | strides=self.conv_stride_size, 146 | activation='relu', 147 | data_format='channels_last')(initial_weather_layers[-1])) 148 | initial_weather_layers.append(Conv2D(self.output_feature_size / (2 ** (self.number_of_fusion_weather_layers - layer)), 149 | (self.conv_kernel_size, self.conv_kernel_size), 150 | padding='same', 151 | strides=self.conv_stride_size, 152 | activation='relu', 153 | data_format='channels_last')(initial_weather_layers[-1])) 154 | initial_weather_layers.append(Conv2D(self.output_feature_size / (2 ** (self.number_of_fusion_weather_layers - layer)), 155 | (self.conv_kernel_size, self.conv_kernel_size), 156 | padding='same', 157 | strides=self.conv_stride_size, 158 | activation='relu', 159 | data_format='channels_last')(initial_weather_layers[-1])) 160 | initial_weather_layers.append(MaxPooling2D((self.pool_size, self.pool_size), 161 | padding='same')(initial_weather_layers[-1])) 162 | # initial_weather_layers.append(Dropout(0.3)(initial_weather_layers[-1])) 163 | 
initial_weather_layers.append(Flatten('channels_last')(initial_weather_layers[-1])) 164 | 165 | elif self.fusion_layer_structure == 'dense': 166 | weather_input = Input(shape=(self.input_num_dense_features,)) 167 | initial_weather_layers = [weather_input] 168 | for layer in range(self.number_of_fusion_weather_layers): 169 | initial_weather_layers.append(Dense(self.dense_layer_size, activation='relu', 170 | **self.dense_regularization)(initial_weather_layers[-1])) 171 | initial_weather_layers.append(Dropout(self.dense_dropout)(initial_weather_layers[-1])) 172 | 173 | if self.fusion_layer_structure is None: 174 | final_layers = [recurrent_layers[-1]] 175 | else: 176 | final_layers = [tf.concat([recurrent_layers[-1], initial_weather_layers[-1]], axis=1)] 177 | 178 | # Define Final Dense section 179 | for layer in range(self.number_of_final_dense_layers): 180 | final_layers.append(Dense(self.dense_layer_size, activation='relu', 181 | **self.dense_regularization)(final_layers[-1])) 182 | final_layers.append(Dropout(self.dense_dropout)(final_layers[-1])) 183 | 184 | # Define output 185 | output = Dense(self.output_size, activation='linear')(final_layers[-1]) 186 | 187 | if self.fusion_layer_structure is None: 188 | self.model = Model(inputs=[recurrent_input], outputs=output) 189 | else: 190 | self.model = Model(inputs=[recurrent_input, weather_input], outputs=output) 191 | 192 | 193 | def save(self, *pos_args, **named_args): 194 | """ 195 | Save model to disk 196 | 197 | :param pos_args: Positional args, passed down to model object's save method 198 | :param named_args: Named args, passed down to model object's save method 199 | :return: 200 | """ 201 | self.model.save(*pos_args, **named_args) 202 | 203 | def predict(self, valid_X_long_term, valid_Y_long_term, args): 204 | """ 205 | Make predictions for an evaluation dataset, returning both the predictions and errors 206 | 207 | :param valid_X_long_term: Dataset to make predictions for 208 | :param valid_Y_long_term: Ground truth 209 | :param args: argparse.Namespace specifying model 210 | :return: 211 | """ 212 | Y_hat = self.model.predict(valid_X_long_term) 213 | Y_hat = Normalizer().unnormalize(Y_hat, self.normalization_factors) 214 | 215 | valid_Y_long_term = Normalizer().unnormalize(valid_Y_long_term, self.normalization_factors) 216 | 217 | haversine_distances = haversine_vector(valid_Y_long_term, Y_hat, Unit.KILOMETERS) 218 | mean_haversine_distance = haversine_distances.mean() 219 | return [Y_hat], [haversine_distances], [mean_haversine_distance] -------------------------------------------------------------------------------- /tests/models/direct_rnn_runner.py: -------------------------------------------------------------------------------- 1 | from haversine import haversine_vector, Unit 2 | from keras import Input, Model 3 | from keras.layers import GRU as GRUKeras, LSTM as LSTMKeras, Bidirectional, Flatten, Dense, Dropout 4 | from keras.optimizer_v2.adam import Adam as AdamKeras 5 | from keras.regularizers import L1, L2 6 | 7 | from loading import add_distance_traveled, Normalizer 8 | from models.losses import HaversineLoss 9 | from models.model_runner import ModelRunner 10 | 11 | 12 | class RNNLongTermModelRunner(ModelRunner): 13 | """ 14 | Class for creating the desired type of Tensorflow model. 
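This is the direct variant: predict() maps each input sequence to the target position(s) in a single forward pass, rather than feeding its own predictions back in (compare RNNModelRunner in iterative_rnn_runner.py).
A rough construction sketch, with purely illustrative values (normalization_factors, y_idxs and columns normally come from the data loading code):

    runner = RNNLongTermModelRunner(
        node_type='gru', number_of_rnn_layers=2, rnn_layer_size=64,
        number_of_dense_layers=1, dense_layer_size=32, direction='forward_only',
        input_ts_length=37, input_num_features=8, output_num_features=2,
        normalization_factors=factors, y_idxs=[0, 1], columns=columns,
        learning_rate=1e-3, rnn_to_dense_connection='all_nodes')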
Creates the model object, and also provides a wrapper 15 | function for making predictions 16 | """ 17 | def __init__(self, node_type, number_of_rnn_layers, rnn_layer_size, number_of_dense_layers, dense_layer_size, 18 | direction, input_ts_length, input_num_features, output_num_features, normalization_factors, 19 | y_idxs, columns, learning_rate, rnn_to_dense_connection, loss='mse', 20 | regularization = None, regularization_application=None, regularization_coefficient=None): 21 | if node_type.lower() == 'gru': 22 | self.rnn_layer = GRUKeras 23 | elif node_type.lower() == 'lstm': 24 | self.rnn_layer = LSTMKeras 25 | else: 26 | raise ValueError('node_type must either be "gru" or "lstm"') 27 | 28 | if direction in ['forward_only','bidirectional']: 29 | self.direction = direction 30 | else: 31 | raise ValueError('direction must be either "forward_only" or "bidirectional"') 32 | 33 | self.number_of_rnn_layers = number_of_rnn_layers 34 | self.rnn_layer_size = rnn_layer_size 35 | self.number_of_dense_layers = number_of_dense_layers 36 | self.dense_layer_size = dense_layer_size 37 | self.ts_length = input_ts_length 38 | self.input_num_features = input_num_features 39 | self.output_size = output_num_features 40 | self.rnn_to_dense_connection = rnn_to_dense_connection 41 | if regularization == 'dropout': 42 | if regularization_application == 'recurrent': 43 | self.rnn_regularization = {'recurrent_dropout':regularization_coefficient} 44 | self.dense_dropout = 0.0 45 | self.dense_regularization = {} 46 | elif regularization_application is None: 47 | self.rnn_regularization = {'dropout':regularization_coefficient} 48 | self.dense_dropout = regularization_coefficient 49 | self.dense_regularization = {} 50 | elif regularization in ['l1','l2']: 51 | self.regularizer = L1 if regularization == 'l1' else L2 52 | if regularization_application in ['bias','activity']: 53 | self.rnn_regularization = {f'{regularization_application}_regularizer': 54 | self.regularizer(regularization_coefficient)} 55 | self.dense_dropout = 0.0 56 | self.dense_regularization = {f'{regularization_application}_regularizer': 57 | self.regularizer(regularization_coefficient)} 58 | elif regularization_application == 'recurrent': 59 | self.rnn_regularization = {f'{regularization_application}_regularizer': 60 | self.regularizer(regularization_coefficient)} 61 | self.dense_dropout = 0.0 62 | self.dense_regularization = {} 63 | else: 64 | self.rnn_regularization = {} 65 | self.dense_dropout = 0.0 66 | self.dense_regularization = {} 67 | 68 | 69 | self._init_model() 70 | self.normalization_factors = normalization_factors 71 | self.y_idxs = y_idxs 72 | self.columns = columns 73 | self.optimizer = AdamKeras(learning_rate=learning_rate) 74 | self.loss = 'mse' if loss=='mse' else HaversineLoss(normalization_factors).haversine_loss 75 | 76 | def _init_model(self): 77 | """ 78 | Create model as specified during initialization 79 | 80 | :return: 81 | """ 82 | recurrent_input = Input(shape=(self.ts_length, self.input_num_features)) 83 | 84 | if self.rnn_to_dense_connection == 'all_nodes': 85 | num_full_sequence_layers = self.number_of_rnn_layers 86 | else: 87 | num_full_sequence_layers = self.number_of_rnn_layers - 1 88 | 89 | # Define RNN section 90 | if self.direction == 'forward_only': 91 | hidden = [self.rnn_layer(self.rnn_layer_size, return_sequences=True, 92 | **self.rnn_regularization)(recurrent_input)] 93 | num_full_sequence_layers -= 1 94 | elif self.direction == 'bidirectional': 95 | hidden = 
[Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=True, 96 | **self.rnn_regularization))(recurrent_input)] 97 | num_full_sequence_layers -= 1 98 | 99 | for layer in range(num_full_sequence_layers): 100 | if self.direction == 'forward_only': 101 | hidden.append(self.rnn_layer(self.rnn_layer_size, return_sequences=True, 102 | **self.rnn_regularization)(hidden[-1])) 103 | elif self.direction == 'bidirectional': 104 | hidden.append(Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=True, 105 | **self.rnn_regularization))(hidden[-1])) 106 | 107 | if self.rnn_to_dense_connection == 'all_nodes': 108 | hidden.append(Flatten('channels_first')(hidden[-1])) 109 | else: 110 | if self.direction == 'forward_only': 111 | hidden.append(self.rnn_layer(self.rnn_layer_size, return_sequences=False, 112 | **self.rnn_regularization)(hidden[-1])) 113 | elif self.direction == 'bidirectional': 114 | hidden.append(Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=False, 115 | **self.rnn_regularization))(hidden[-1])) 116 | 117 | # Define Dense section 118 | for layer in range(self.number_of_dense_layers): 119 | hidden.append(Dense(self.dense_layer_size, activation='relu', 120 | **self.dense_regularization)(hidden[-1])) 121 | hidden.append(Dropout(self.dense_dropout)(hidden[-1])) 122 | 123 | # Define output 124 | output = Dense(self.output_size, activation='linear')(hidden[-1]) 125 | 126 | self.model = Model(inputs=recurrent_input, outputs=output) 127 | 128 | def save(self, *pos_args, **named_args): 129 | """ 130 | Save model to disk 131 | 132 | :param pos_args: Positional args, passed down to model object's save method 133 | :param named_args: Named args, passed down to model object's save method 134 | :return: 135 | """ 136 | self.model.save(*pos_args, **named_args) 137 | 138 | def predict(self, valid_X_long_term, valid_Y_long_term, args): 139 | """ 140 | Make predictions for an evaluation dataset, returning both the predictions and errors 141 | 142 | :param valid_X_long_term: Dataset to make predictions for 143 | :param valid_Y_long_term: Ground truth 144 | :param args: argparse.Namespace specifying model 145 | :return: 146 | """ 147 | Y_hat = self.model.predict(valid_X_long_term) 148 | Y_hat = Normalizer().unnormalize(Y_hat, self.normalization_factors) 149 | valid_Y_long_term = Normalizer().unnormalize(valid_Y_long_term, self.normalization_factors) 150 | 151 | haversine_distances = haversine_vector(valid_Y_long_term, Y_hat, Unit.KILOMETERS) 152 | mean_haversine_distance = haversine_distances.mean() 153 | return [Y_hat], [haversine_distances], [mean_haversine_distance] -------------------------------------------------------------------------------- /tests/models/iterative_rnn_runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from haversine import haversine_vector, Unit 4 | from keras import Input, Model 5 | from keras.layers import GRU as GRUKeras, LSTM as LSTMKeras, Bidirectional, Dense, Flatten 6 | from keras.optimizer_v2.adam import Adam as AdamKeras 7 | 8 | from loading import Normalizer 9 | from loading.loading import _find_current_col_idx 10 | from models.model_runner import ModelRunner 11 | from models.losses import HaversineLoss 12 | 13 | class RNNModelRunner(ModelRunner): 14 | """ 15 | Class for creating the desired type of Tensorflow model. 
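This is the iterative variant: predict() forecasts one timestamp ahead, feeds that prediction back into the input window via insert_predictions, and repeats until the full prediction horizon has been covered.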
Creates the model object, and also provides a wrapper 16 | function for making predictions 17 | """ 18 | def __init__(self, node_type, number_of_rnn_layers, rnn_layer_size, number_of_dense_layers, dense_layer_size, 19 | direction, input_ts_length, input_num_features, output_num_features, normalization_factors, 20 | y_idxs, columns, learning_rate, rnn_to_dense_connection, loss='mse'): 21 | if node_type.lower() == 'gru': 22 | self.rnn_layer = GRUKeras 23 | elif node_type.lower() == 'lstm': 24 | self.rnn_layer = LSTMKeras 25 | else: 26 | raise ValueError('node_type must either be "gru" or "lstm"') 27 | 28 | if direction in ['forward_only','bidirectional']: 29 | self.direction = direction 30 | else: 31 | raise ValueError('direction must be either "forward_only" or "bidirectional"') 32 | 33 | if not np.all([i in range(len(y_idxs)) for i in y_idxs]): 34 | raise ValueError('Y indexes must be the first columns in the dataset in order for normalization to work ' 35 | 'properly') 36 | 37 | 38 | self.number_of_rnn_layers = number_of_rnn_layers 39 | self.rnn_layer_size = rnn_layer_size 40 | self.number_of_dense_layers = number_of_dense_layers 41 | self.dense_layer_size = dense_layer_size 42 | self.ts_length = input_ts_length 43 | self.input_num_features = input_num_features 44 | self.output_size = output_num_features 45 | self.rnn_to_dense_connection = rnn_to_dense_connection 46 | self._init_model() 47 | self.normalization_factors = normalization_factors 48 | self.y_idxs = y_idxs 49 | self.columns = columns 50 | self.optimizer = AdamKeras(learning_rate=learning_rate) 51 | self.loss = 'mse' if loss=='mse' else HaversineLoss(normalization_factors).haversine_loss 52 | 53 | 54 | def _init_model(self): 55 | """ 56 | Create model as specified during initialization 57 | 58 | :return: 59 | """ 60 | input = Input(shape=(self.ts_length, self.input_num_features)) 61 | 62 | hidden = [input] 63 | if self.rnn_to_dense_connection == 'all_nodes': 64 | num_full_sequence_layers = self.number_of_rnn_layers 65 | else: 66 | num_full_sequence_layers = self.number_of_rnn_layers - 1 67 | 68 | if num_full_sequence_layers > 0: 69 | if self.direction == 'forward_only': 70 | hidden.append(self.rnn_layer(self.rnn_layer_size, return_sequences=True)(hidden[-1])) 71 | num_full_sequence_layers -= 1 72 | elif self.direction == 'bidirectional': 73 | hidden.append(Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=True))(hidden[-1])) 74 | num_full_sequence_layers -= 1 75 | 76 | for layer in range(num_full_sequence_layers): 77 | if self.direction == 'forward_only': 78 | hidden.append(self.rnn_layer(self.rnn_layer_size, return_sequences=True)(hidden[-1])) 79 | elif self.direction == 'bidirectional': 80 | hidden.append(Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=True))(hidden[-1])) 81 | 82 | if self.rnn_to_dense_connection == 'all_nodes': 83 | hidden.append(Flatten('channels_first')(hidden[-1])) 84 | else: 85 | if self.direction == 'forward_only': 86 | hidden.append(self.rnn_layer(self.rnn_layer_size, return_sequences=False)(hidden[-1])) 87 | elif self.direction == 'bidirectional': 88 | hidden.append(Bidirectional(self.rnn_layer(self.rnn_layer_size, return_sequences=False))(hidden[-1])) 89 | 90 | 91 | 92 | # Define Dense section 93 | for layer in range(self.number_of_dense_layers): 94 | hidden.append(Dense(self.dense_layer_size, activation='relu')(hidden[-1])) 95 | 96 | # Define output 97 | output = Dense(self.output_size, activation='linear')(hidden[-1]) 98 | 99 | self.model = Model(inputs=input, 
outputs=output) 100 | 101 | 102 | def save(self, *pos_args, **named_args): 103 | """ 104 | Save model to disk 105 | 106 | :param pos_args: Positional args, passed down to model object's save method 107 | :param named_args: Named args, passed down to model object's save method 108 | :return: 109 | """ 110 | self.model.save(*pos_args, **named_args) 111 | 112 | def insert_predictions(self, X, predictions, time): 113 | """ 114 | Insert predicted values into X dataset 115 | 116 | The iterative model makes predictions for the next timestamp, then uses these as input data 117 | to predict the following timestamp. 118 | 119 | This method is used for amending the input dataset so that new predictions can be made. It appends a set of 120 | predicted values to the end of the previous input data, and cuts off that input data's first timestamp 121 | 122 | :param X: Original input data 123 | :param predictions: Predictions to append 124 | :param time: The time gap being used 125 | :return: 126 | """ 127 | # Copy over the static info 128 | Y_hat_i_full = Normalizer().unnormalize(X[:, -1,:].copy(), self.normalization_factors) 129 | 130 | predictions = Normalizer().unnormalize(predictions.copy(), self.normalization_factors) 131 | 132 | # Input the new predictions 133 | Y_hat_i_full[:, self.y_idxs] = predictions 134 | 135 | # Distance traveled 136 | if 'distance_traveled' in self.columns['column'].values: 137 | lat_lon_idx = [_find_current_col_idx(c,self.columns) for c in ['lat','lon']] 138 | first_lat_lon = Normalizer().unnormalize(X[:, 1, lat_lon_idx], self.normalization_factors).copy() 139 | predicted_lat_lon = Y_hat_i_full[:,lat_lon_idx] 140 | distance_traveled = haversine_vector(first_lat_lon, predicted_lat_lon, Unit.KILOMETERS) 141 | dt_idx = _find_current_col_idx('distance_traveled',self.columns) 142 | Y_hat_i_full[:, dt_idx] = distance_traveled 143 | 144 | # Hour/day 145 | if 'day_of_week' in self.columns['column'].values: 146 | hour_col = _find_current_col_idx('hour',self.columns) 147 | dow_col = _find_current_col_idx('day_of_week',self.columns) 148 | incremented = (Normalizer().unnormalize(X, self.normalization_factors)[:,1:, hour_col] 149 | - Normalizer().unnormalize(X, self.normalization_factors)[:,:-1, hour_col]) != 0 150 | incremented = pd.DataFrame({'r': np.where(incremented)[0], 'c': np.where(incremented)[1]}) 151 | last_incremented = incremented.groupby('r').max().to_numpy().squeeze() 152 | 153 | how_often_to_increment = 60 / time 154 | 155 | Y_hat_i_full[:, hour_col] = np.where( 156 | last_incremented + how_often_to_increment < X.shape[1], 157 | Y_hat_i_full[:, hour_col] + 1, 158 | Y_hat_i_full[:, hour_col] 159 | ) 160 | Y_hat_i_full[:, dow_col] = np.where( 161 | Y_hat_i_full[:, hour_col] > 23, 162 | Y_hat_i_full[:, dow_col] + 1, 163 | Y_hat_i_full[:, dow_col] 164 | ) 165 | Y_hat_i_full[:, hour_col] %= 24 166 | Y_hat_i_full[:, dow_col] %= 7 167 | 168 | # Drop off the first timestamp 169 | layers = [X[:, j, :] for j in range(1, X.shape[1])] 170 | 171 | # Add in the prediction from last round as the last timestamp 172 | Y_hat_i_full = Normalizer().normalize_data(Y_hat_i_full, self.normalization_factors) 173 | layers.append(Y_hat_i_full) 174 | 175 | # Stack everything together and append to the list 176 | new_X = np.stack(layers, axis=1) 177 | return new_X 178 | 179 | def predict(self, valid_X_long_term, valid_Y_long_term, args): 180 | """ 181 | Make predictions for an evaluation dataset, returning both the predictions and errors 182 | 183 | :param valid_X_long_term: Dataset to make 
predictions for 184 | :param valid_Y_long_term: Ground truth 185 | :param args: argparse.Namespace specifying model 186 | :return: 187 | """ 188 | valid_Xs = [valid_X_long_term] 189 | valid_Y_hats = [] 190 | mean_haversine_distances = [] 191 | haversine_distances = [] 192 | # Iteratively carry predictions forward 193 | for i in range(valid_Y_long_term.shape[1]): 194 | # Valid Xs is a list containing the datasets used for prediction. It is appended to as we make new predictions 195 | # and use those for prediction 196 | X = valid_Xs[i] 197 | 198 | # Make prediction 199 | Y_hat_i_normalized = self.model.predict(X) 200 | Y_hat_i_unnormalized = Normalizer().unnormalize(Y_hat_i_normalized, self.normalization_factors) 201 | valid_Y_hats.append(Y_hat_i_unnormalized) 202 | 203 | 204 | # Get haversine distance 205 | lat_lon_idxs = [_find_current_col_idx(c, self.columns) for c in ['lat','lon']] 206 | ground_truth_Y_unnormalized = Normalizer().unnormalize(valid_Y_long_term[:, i, lat_lon_idxs], self.normalization_factors) 207 | 208 | haversine_distance = haversine_vector(ground_truth_Y_unnormalized, 209 | Y_hat_i_unnormalized[:, lat_lon_idxs], 210 | Unit.KILOMETERS) 211 | haversine_distances.append(haversine_distance) 212 | mean_haversine_distance = haversine_distance.mean() 213 | mean_haversine_distances.append(mean_haversine_distance) 214 | 215 | # If this isn't the last prediction we needed to make, append to the list of X sets 216 | if i != valid_Y_long_term.shape[1] - 1: 217 | valid_Xs.append(self.insert_predictions(X, Y_hat_i_normalized, args.time)) 218 | 219 | # TODO: Make more adjustable 220 | common_prediction_time = 60 221 | hour_idxs = [int((h - common_prediction_time) / args.time) - 1 for h in [120, 180, 240]] 222 | 223 | valid_Y_hats = [valid_Y_hats[i] for i in hour_idxs] 224 | hour_haversine_distances = [haversine_distances[i] for i in hour_idxs] 225 | mean_hour_haversine_distances = np.array(mean_haversine_distances)[hour_idxs] 226 | return valid_Y_hats, hour_haversine_distances, mean_hour_haversine_distances -------------------------------------------------------------------------------- /tests/models/losses.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | 3 | import tensorflow as tf 4 | from config import config 5 | 6 | 7 | class HaversineLoss(): 8 | """ 9 | Class implementing haversine loss 10 | """ 11 | def __init__(self, normalization_factors): 12 | self.lat_min = normalization_factors['lat']['min'] 13 | self.lat_range = normalization_factors['lat']['max'] - self.lat_min 14 | 15 | self.lon_min = normalization_factors['lon']['min'] 16 | self.lon_range = normalization_factors['lon']['max'] - self.lon_min 17 | 18 | @tf.function 19 | def haversine_loss(self, y_true, y_pred): 20 | """ 21 | Calculate haversine distance between two sets of points 22 | 23 | Coordinates should be in degrees, scaled to be between 0 and 1 based on the lat/lon mins in the config file 24 | 25 | Latitudes should be in the 0th column in both sets, and longitudes should be in the first columns 26 | 27 | :param y_true: Ground truth points 28 | :param y_pred: Predicted points 29 | :return: 30 | """ 31 | # If this is just being traced, the batch_size should be None, which 32 | y_true = tf.cast(y_true, y_pred.dtype) 33 | if len(y_true.shape) == 2: 34 | batch_axis = 0 35 | lat_lon_axis = 1 36 | num_predictions = tf.shape(y_true)[batch_axis] 37 | elif len(y_true.shape) == 3: 38 | batch_axis = 0 39 | time_axis = 1 40 | lat_lon_axis = 2 41 | num_predictions = 
tf.shape(y_true)[batch_axis] * tf.shape(y_true)[time_axis] 42 | else: 43 | raise ValueError('Unknown input shape') 44 | 45 | 46 | lat_min = config.dataset_config.lat_1 47 | lat_range = config.dataset_config.lat_2 - lat_min 48 | 49 | lon_min = config.dataset_config.lon_1 50 | lon_range = config.dataset_config.lon_2 - lon_min 51 | 52 | y_true_deg = (y_true * tf.constant([lat_range, lon_range], dtype=y_true.dtype) 53 | + tf.constant([lat_min, lon_min],dtype=y_true.dtype)) * tf.constant(pi / 180., dtype=y_true.dtype) 54 | lat_true = tf.gather(y_true_deg, 0, axis=lat_lon_axis) 55 | lon_true = tf.gather(y_true_deg, 1, axis=lat_lon_axis) 56 | 57 | y_pred_deg = (y_pred * tf.constant([lat_range, lon_range], dtype = y_pred.dtype) 58 | + tf.constant([lat_min, lon_min], dtype=y_pred.dtype)) * tf.constant(pi / 180., dtype=y_pred.dtype) 59 | lat_pred = tf.gather(y_pred_deg, 0, axis=lat_lon_axis) 60 | lon_pred = tf.gather(y_pred_deg, 1, axis=lat_lon_axis) 61 | 62 | EARTH_RADIUS = 6371 63 | 64 | interior = (tf.math.sin((lat_true - lat_pred)/2) ** 2 65 | + tf.math.cos(lat_true) * tf.math.cos(lat_pred) * (tf.math.sin((lon_true - lon_pred)/2) ** 2)) 66 | 67 | # Clip, to make sure there aren't any floating point issues 68 | interior = tf.clip_by_value(interior, 0., 1.) 69 | 70 | interior = tf.math.sqrt(interior) 71 | 72 | distance = 2 * EARTH_RADIUS * tf.math.asin( 73 | interior 74 | ) 75 | 76 | # Only keep places where distance is not 0, as due to the sqrt above, these will not be differentiable and will 77 | # result in a NaN. These are still still ~kept in~ the loss by the fact that we are averaging using the batch 78 | # size below, (so they still contribute to the mean), it's just that we aren't using them for optimization, 79 | # which is fine as they are already perfect. 80 | distance = tf.reshape(distance,[-1]) 81 | distance = tf.squeeze(tf.gather(distance, tf.where(distance != 0))) 82 | 83 | mean_distance = tf.reduce_sum(distance) / tf.cast(num_predictions, distance.dtype) 84 | 85 | return mean_distance 86 | 87 | 88 | def get_config(self): 89 | """ 90 | Return the lat/lon min/max, which can be used for saving a loss object to disk 91 | 92 | :return: 93 | """ 94 | return {'lat_min': self.lat_min, 95 | 'lat_range': self.lat_range, 96 | 'lon_min': self.lon_min, 97 | 'lon_range': self.lon_range} 98 | 99 | def from_config(self, config): 100 | """ 101 | Save the lat/lon to this object, based on a config loaded from somewhere else 102 | 103 | :param config: 104 | :return: 105 | """ 106 | for k, v in config.items(): 107 | setattr(self, k, v) -------------------------------------------------------------------------------- /tests/models/median_stopping.py: -------------------------------------------------------------------------------- 1 | # This class has been deprecated, and no longer works with all configurations 2 | from mlflow.tracking.client import MlflowClient 3 | from mlflow.entities import ViewType 4 | import pandas as pd 5 | import numpy as np 6 | import os 7 | from config import config 8 | import urllib 9 | import keras 10 | from mlflow import log_metric 11 | 12 | class MedianStopper(keras.callbacks.Callback): 13 | """ 14 | Based on description in section 3.2.2 of the below paper: 15 | https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46180.pdf 16 | 17 | Calculates, for each model that has been fit, at each step, what its average loss was. 18 | If for a given step, the current model's best loss is worse than the median average loss, the model 19 | is ended early. 
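(Illustrative example with made-up numbers: if the running average validation losses of previously completed runs at epoch 30 were 0.21, 0.25, and 0.33, the cutoff for epoch 30 would be their median, 0.25, and a current run whose best validation loss so far is 0.31 would be stopped at that epoch.)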
20 | """ 21 | def __init__(self, run_id, experiment_id, iterations = 30): 22 | self.cutoffs_path = os.path.join(config.box_and_year_dir, '.median_losses.csv') 23 | self.run_id = run_id 24 | self.experiment_id = experiment_id 25 | self.early_stopping_occured = False 26 | self.stopped_epoch = -1 27 | self.best_loss = np.Inf 28 | self.iterations = iterations 29 | 30 | def on_train_begin(self, logs=None): 31 | """ 32 | Load the median average loss from previous training runs when training begins 33 | 34 | These will be used as the cutoffs for early stopping 35 | 36 | :param logs: Information on current run 37 | :return: None 38 | """ 39 | self._load_cutoffs() 40 | 41 | def on_train_end(self, logs=None): 42 | """ 43 | Record whether early stopping occurred or not in the console and logs 44 | 45 | :param logs: Information on current run 46 | :return: None 47 | """ 48 | if self.early_stopping_occured: 49 | print(f'Epoch {self.stopped_epoch + 1}: early stopping occurred based on median loss') 50 | log_metric('early_median_stopping_occured', 1.) 51 | else: 52 | print(f'Early stopping did not occur based on median loss') 53 | log_metric('early_median_stopping_occured', 0.) 54 | 55 | def on_epoch_end(self, epoch, logs=None): 56 | """ 57 | Log the best loss after each epoch, and if this is an iteration to check for early stopping, do so 58 | 59 | :param epoch: The epoch number 60 | :param logs: Information on current run 61 | :return: None 62 | """ 63 | current_loss = logs.get('val_loss') 64 | if current_loss < self.best_loss: 65 | self.best_loss = current_loss 66 | if ((epoch + 1) / self.iterations) % 1 == 0: 67 | if epoch + 1 in self.cutoffs['steps'].values: 68 | cutoff = self.cutoffs[self.cutoffs['steps'] == epoch + 1]['average_loss'].iloc[0] 69 | else: 70 | cutoff = np.Inf 71 | if np.greater(self.best_loss, cutoff): 72 | self.stopped_epoch = epoch 73 | self.model.stop_training = True 74 | self.early_stopping_occured = True 75 | 76 | def _load_cutoffs(self): 77 | """ 78 | Load the median average loss from previous training runs 79 | 80 | 81 | :return: None 82 | """ 83 | self.cutoffs = pd.read_csv(self.cutoffs_path) 84 | self.cutoffs = self.cutoffs[self.cutoffs['experiment_id'] == int(self.experiment_id)] 85 | 86 | def recalculate_cutoffs(self): 87 | """ 88 | After training ends, recalculate the median mean losses across the models 89 | 90 | Writes the medians so they can be easily loaded by the next run 91 | 92 | :return: None 93 | """ 94 | client = MlflowClient() 95 | experiments = [exp.experiment_id for exp in client.list_experiments()] 96 | runs = client.search_runs( 97 | experiment_ids=experiments, 98 | run_view_type=ViewType.ACTIVE_ONLY 99 | ) 100 | 101 | run_info = pd.DataFrame(columns=['steps','average_loss','experiment_id']) 102 | for run in runs: 103 | if run.info.status == 'FINISHED' or run.info.run_id == self.run_id: 104 | artifact_dir = run.info._artifact_uri 105 | overall_dir = os.path.dirname(artifact_dir) 106 | try: 107 | loss_path = os.path.join(overall_dir, 'metrics', 'epoch_val_loss') 108 | loss_history = pd.read_csv(loss_path,names=['time','loss','step'], sep='\s+') 109 | except urllib.error.URLError: 110 | loss_path = os.path.join(overall_dir, 'metrics', 'val_loss') 111 | loss_history = pd.read_csv(loss_path,names=['time','loss','step'], sep='\s+') 112 | loss_history['steps'] = 1 113 | loss_history = loss_history[['loss','steps']].cumsum() 114 | loss_history['average_loss'] = loss_history['loss'] / loss_history['steps'] 115 | 116 | loss_history = 
loss_history.drop(['loss'],axis=1) 117 | loss_history['experiment_id'] = run.info.experiment_id 118 | 119 | run_info = pd.concat([run_info, loss_history]) 120 | self.cutoffs = run_info.groupby(['experiment_id','steps']).median() 121 | 122 | self.cutoffs.to_csv(self.cutoffs_path) 123 | 124 | def on_epoch_begin(self, epoch, logs=None): 125 | pass 126 | 127 | def on_test_begin(self, logs=None): 128 | pass 129 | 130 | def on_test_end(self, logs=None): 131 | pass 132 | 133 | def on_predict_begin(self, logs=None): 134 | pass 135 | 136 | def on_predict_end(self, logs=None): 137 | pass 138 | 139 | def on_train_batch_begin(self, batch, logs=None): 140 | pass 141 | 142 | def on_train_batch_end(self, batch, logs=None): 143 | pass 144 | 145 | def on_test_batch_begin(self, batch, logs=None): 146 | pass 147 | 148 | def on_test_batch_end(self, batch, logs=None): 149 | pass 150 | 151 | def on_predict_batch_begin(self, batch, logs=None): 152 | pass 153 | 154 | def on_predict_batch_end(self, batch, logs=None): 155 | pass 156 | 157 | if __name__ == '__main__': 158 | stopper = MedianStopper(run_id=None, experiment_id='6') 159 | stopper.recalculate_cutoffs() 160 | 161 | -------------------------------------------------------------------------------- /tests/models/model_runner.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class ModelRunner(ABC): 5 | @abstractmethod 6 | def __init__(self): 7 | pass 8 | 9 | def fit(self, *pos_args, **named_args): 10 | return self.model.fit(*pos_args, **named_args) 11 | 12 | @abstractmethod 13 | def predict(self): 14 | pass 15 | 16 | @abstractmethod 17 | def save(self): 18 | pass 19 | 20 | def compile(self): 21 | self.model.compile(optimizer=self.optimizer, loss=self.loss) 22 | -------------------------------------------------------------------------------- /tests/models/seq2seq_runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from haversine import haversine_vector, Unit 4 | from tensorflow.keras.layers import GRU as GRUTF, LSTM as LSTMTF 5 | from tensorflow.keras.optimizers import Adam as AdamTF 6 | 7 | from loading.loading import _find_current_col_idx 8 | from loading import Normalizer 9 | from models.seq2seq_model_pieces import TrainTranslator 10 | from models.model_runner import ModelRunner 11 | from models.losses import HaversineLoss 12 | 13 | 14 | class Seq2SeqRNNAttentionRunner(ModelRunner): 15 | """ 16 | Class for creating the desired type of Tensorflow model. 
Creates the model object, and also provides a wrapper 17 | function for making predictions 18 | """ 19 | def __init__(self, node_type, number_of_rnn_layers, rnn_layer_size, direction, input_ts_length, output_ts_length, 20 | input_num_features, output_num_features, normalization_factors, y_idxs, columns, learning_rate, loss): 21 | if node_type.lower() == 'gru': 22 | self.rnn_layer = GRUTF 23 | elif node_type.lower() == 'lstm': 24 | self.rnn_layer = LSTMTF 25 | else: 26 | raise ValueError('node_type must either be "gru" or "lstm"') 27 | 28 | if direction in ['forward_only', 'bidirectional']: 29 | self.direction = direction 30 | else: 31 | raise ValueError('direction must be either "forward_only" or "bidirectional"') 32 | 33 | self.number_of_rnn_layers = number_of_rnn_layers 34 | 35 | self.rnn_layer_size = rnn_layer_size 36 | self.input_ts_length = input_ts_length 37 | self.output_ts_length = output_ts_length 38 | self.input_num_features = input_num_features 39 | self.output_num_features = output_num_features 40 | 41 | self._init_model() 42 | self.normalization_factors = normalization_factors 43 | self.y_idxs = y_idxs 44 | self.columns = columns 45 | self.optimizer = AdamTF(learning_rate=learning_rate) 46 | self.loss = tf.keras.losses.MeanSquaredError() if loss == 'mse' else HaversineLoss(normalization_factors).haversine_loss 47 | 48 | 49 | 50 | def _init_model(self): 51 | """ 52 | Create model as specified during initialization 53 | 54 | :return: 55 | """ 56 | self.model = TrainTranslator(units = self.rnn_layer_size, 57 | num_input_variables=self.input_num_features, 58 | num_output_variables=self.output_num_features, 59 | input_series_length=self.input_ts_length, 60 | output_series_length=self.output_ts_length) 61 | 62 | def predict(self, input_text, output_text, args): 63 | """ 64 | Make predictions for an evaluation dataset, returning both the predictions and errors 65 | 66 | Because of a constraint by the model object,the data has to be predicted in batches. 
This is handled within the 67 | predict method 68 | 69 | :param input_text: Dataset to make predictions for 70 | :param output_text: Ground truth 71 | :param args: argparse.Namespace specifying model 72 | :return: 73 | """ 74 | input_text = input_text.astype(np.float32) 75 | chunks = np.split(input_text,np.arange(args.batch_size,input_text.shape[0], args.batch_size)) 76 | preds = [self.model.predict(t).numpy() for t in chunks[:-1]] 77 | preds += [self.model.predict(input_text[-args.batch_size:]).numpy()[-chunks[-1].shape[0]:]]# get last chunk (which is not correct size) 78 | result_tokens = np.concatenate(preds) 79 | predicted_lat_long = Normalizer().unnormalize(result_tokens, self.normalization_factors) 80 | 81 | output_text = Normalizer().unnormalize(output_text, self.normalization_factors) 82 | lat_lon_idxs = [_find_current_col_idx(c, self.columns) for c in ['lat','lon']] 83 | mean_haversine_distances = [] 84 | haversine_distances = [] 85 | for i in range(self.output_ts_length): 86 | haversine_distance = haversine_vector(output_text[:, i, lat_lon_idxs], 87 | predicted_lat_long[:, i, lat_lon_idxs], 88 | Unit.KILOMETERS) 89 | haversine_distances.append(haversine_distance) 90 | mean_haversine_distance = haversine_distance.mean() 91 | mean_haversine_distances.append(mean_haversine_distance) 92 | 93 | common_prediction_time = 60 94 | hour_idxs = [int((h - common_prediction_time) / args.time) - 1 for h in [120, 180, 240]] 95 | 96 | Y_hats = [predicted_lat_long[:, i, lat_lon_idxs] for i in hour_idxs] 97 | hour_haversine_distances = [haversine_distances[i] for i in hour_idxs] 98 | mean_hour_haversine_distances = [mean_haversine_distances[i] for i in hour_idxs] 99 | return Y_hats, hour_haversine_distances, mean_hour_haversine_distances 100 | 101 | 102 | def save(self, path): 103 | """ 104 | Save model to disk 105 | 106 | :param pos_args: Positional args, passed down to model object's save method 107 | :param named_args: Named args, passed down to model object's save method 108 | :return: 109 | """ 110 | tf.saved_model.save(self.model, path) -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.processor_manager import ProcessorManager 2 | from utils.test_arg_parser import TestArgParser 3 | from utils.utils import set_seed -------------------------------------------------------------------------------- /tests/utils/arg_validation.py: -------------------------------------------------------------------------------- 1 | # Set of classes for validating that command line arguments do not conflict with one another 2 | # 3 | # The Req class is used for checking conditional and unconditional requirements. The other classes are the 4 | # specific requirements that can be made - e.g. that an argument is given/not given, or that an argument takes on 5 | # a specific value or range of values. 
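#
# A hypothetical illustration of how these pieces combine (the argument names below are
# invented for the example and are not necessarily the project's real flags):
#
#     Req(Given('rnn_layer_size'), Values('model_type', ['rnn'])).validate(args)
#
# raises a ValueError whenever args.model_type == 'rnn' but args.rnn_layer_size was left
# unset (None or 'ignore'); with a=None the requirement is checked unconditionally instead.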
6 | import numpy as np 7 | 8 | 9 | class NotGiven(): 10 | """ 11 | Requirement that an argument have value None 12 | """ 13 | def __init__(self, arg_name): 14 | if type(arg_name) != list: 15 | arg_name = [arg_name] 16 | self.arg_name = arg_name 17 | 18 | def is_true(self, args): 19 | """ 20 | Check that this requirement is true 21 | 22 | :param args: The args namespace which needs to be validated 23 | :return: 24 | """ 25 | return not np.any([getattr(args, a) is not None and getattr(args, a) != 'ignore' for a in self.arg_name]) 26 | 27 | def message(self, args): 28 | """ 29 | Message used for creating error message when this is apart of a requirement that 30 | has not been met 31 | 32 | :param args: The args namespace which needs to be validated 33 | :return: 34 | """ 35 | issues = np.array(self.arg_name)[[hasattr(args, a) for a in self.arg_name]][0] 36 | return f'specifying {issues}' 37 | 38 | 39 | class Given(): 40 | """ 41 | Requirement that an argument have a value other than None or 'ignore' 42 | """ 43 | def __init__(self, arg_name): 44 | if type(arg_name) != list: 45 | arg_name = [arg_name] 46 | self.arg_name = arg_name 47 | 48 | def is_true(self, args): 49 | """ 50 | Check that this requirement is true 51 | 52 | :param args: The args namespace which needs to be validated 53 | :return: 54 | """ 55 | return np.all([getattr(args, a) is not None and getattr(args, a) != 'ignore' for a in self.arg_name]) 56 | 57 | def message(self, args): 58 | """ 59 | Message used for creating error message when this is apart of a requirement that 60 | has not been met 61 | 62 | :param args: The args namespace which needs to be validated 63 | :return: 64 | """ 65 | issues = np.array(self.arg_name)[[not hasattr(args, a) or getattr(args, a) == 'ignore' for a in self.arg_name]] 66 | return f'not specifying {issues}' 67 | 68 | 69 | class Values(): 70 | """ 71 | Reqiurement that an argument be an element of a list 72 | """ 73 | def __init__(self, arg_name, values): 74 | self.arg_name = arg_name 75 | assert type(values) == list 76 | self.values = values 77 | 78 | def is_true(self, args): 79 | """ 80 | Check that this requirement is true 81 | 82 | :param args: The args namespace which needs to be validated 83 | :return: 84 | """ 85 | return getattr(args, self.arg_name) in self.values 86 | 87 | def message(self, args): 88 | """ 89 | Message used for creating error message when this is apart of a requirement that 90 | has not been met 91 | 92 | :param args: The args namespace which needs to be validated 93 | :return: 94 | """ 95 | return f'{self.arg_name} value of {getattr(args, self.arg_name)}' 96 | 97 | 98 | class ValueRange(): 99 | """ 100 | Reguire that an argument have value in a specific range 101 | """ 102 | def __init__(self, arg_name, value_range): 103 | self.arg_name =arg_name 104 | self.value_range = value_range 105 | assert len(value_range) == 2 106 | 107 | def is_true(self, args): 108 | """ 109 | Check that this requirement is true 110 | 111 | :param args: The args namespace which needs to be validated 112 | :return: 113 | """ 114 | val = getattr(args, self.arg_name) 115 | lower_lim = self.value_range[0] 116 | upper_lim = self.value_range[1] 117 | if lower_lim is not None and val < lower_lim: 118 | return False 119 | elif upper_lim is not None and val > upper_lim: 120 | return False 121 | else: 122 | return True 123 | 124 | def message(self, args): 125 | """ 126 | Message used for creating error message when this is apart of a requirement that 127 | has not been met 128 | 129 | :param args: The args 
namespace which needs to be validated 130 | :return: 131 | """ 132 | return f'{self.arg_name} value outside range of {self.value_range}' 133 | 134 | 135 | class Req(): 136 | """ 137 | Class for checking conditional or unconditional requirements. 138 | 139 | If two requirements are specified, the first one only needs to be true if the second one is 140 | 141 | If only one requirement is specified, it must always be true 142 | 143 | """ 144 | def __init__(self, b, a=None): 145 | assert type(b) in [ValueRange, Given, NotGiven, Values] 146 | self.b = b 147 | if type(a) != list: 148 | a = [a] 149 | self.a = a 150 | 151 | def validate(self, args): 152 | """ 153 | Check that the requirement has been met 154 | 155 | :param args: The args namespace which needs to be validated 156 | :return: 157 | """ 158 | self.args = args 159 | a_is_true = np.all([a is None or a.is_true(args) for a in self.a]) 160 | if a_is_true and not self.b.is_true(args): 161 | raise ValueError(self) 162 | else: 163 | pass 164 | 165 | def __str__(self): 166 | """ 167 | Convert to string. Used for error messages. 168 | 169 | :return: 170 | """ 171 | if self.a[0] is not None: 172 | a_message = [a.message(self.args) for a in self.a] 173 | return f'{a_message} cannot be combined with {self.b.message(self.args)}' 174 | else: 175 | return f'Error: {self.b.message(self.args)}' -------------------------------------------------------------------------------- /tests/utils/processor_manager.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import signal 4 | import atexit 5 | from importlib.util import find_spec 6 | 7 | 8 | class NoProcessorAvailableError(Exception): 9 | """ 10 | All processors on the host are in use 11 | """ 12 | pass 13 | 14 | class ProcessorManager(): 15 | """ 16 | Class used for running multiple machine learning models at once, using separate processors. 17 | 18 | The manager is a class that should be run with each model being fit. It keeps a list of available processors on 19 | disk, and checks one out for a model (again, writing to disk). When the model is done running, the processor is 20 | checked back in. 
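A minimal usage sketch (illustrative only; the save_dir path is a placeholder):

    manager = ProcessorManager(save_dir='/path/to/shared/state')
    manager.open()              # checks out a free GPU if one is available, otherwise a free CPU
    device = manager.device()   # e.g. '/device:GPU:0'
    # ... fit the model on `device` ...
    manager.close()             # checks the processor back in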
21 | 22 | Treats each GPU as a unique processor, but all CPUs as a single processor 23 | """ 24 | def __init__(self, save_dir='/home/isaac/data/', debug=False): 25 | self.processor_list_fp = os.path.join(save_dir, '.processor_list.json') 26 | self.my_processor = None 27 | self.debug = debug 28 | 29 | self.version = 'torch' if find_spec('torch') is not None else 'tensorflow' 30 | 31 | def __getitem__(self, item): 32 | return self.data[item] 33 | 34 | def __setitem__(self, key, value): 35 | self.data[key] = value 36 | 37 | def log_sigint(self, a=None, b=None): 38 | """ 39 | Close the manager if the model is terminated early by signal interuption 40 | 41 | Gets fed two inputs when called, neither of which are needed 42 | 43 | :param a: Not used 44 | :param b: Not used 45 | :return: None 46 | """ 47 | self.close() 48 | 49 | def _add_exit_handling(self): 50 | """ 51 | Add the log_sigint function to the exit handling process 52 | 53 | :return: None 54 | """ 55 | signal.signal(signal.SIGINT, self.log_sigint) 56 | atexit.register(self.log_sigint) 57 | 58 | def _load(self): 59 | """ 60 | Load the list of processors from disk 61 | 62 | :return: None 63 | """ 64 | if os.path.exists(self.processor_list_fp): 65 | with open(self.processor_list_fp, "r") as outfile: 66 | self.data = json.load(outfile) 67 | else: 68 | self._create_processor_list() 69 | 70 | def _remove_most_recent(self): 71 | """ 72 | Check the most recently checked-out processor back in 73 | 74 | Useful when the most recent run was terminated without a proper exit handle 75 | 76 | :return: None 77 | """ 78 | self._load() 79 | if len(self['in_use']) > 0: 80 | self['in_use'] = self['in_use'][:-1] 81 | self._save() 82 | 83 | def _remove_all(self): 84 | """ 85 | Check back in all processors that have been checked out 86 | 87 | Useful if a bunch of processors were terminated without a proper exit handle (e.g. if the host lost power 88 | suddenly) 89 | 90 | :return: None 91 | """ 92 | self._create_processor_list(overwrite=True) 93 | self._save() 94 | 95 | def _save(self): 96 | """ 97 | Save the current list of processors (including those that are checked out and in) to disk 98 | 99 | :return: None 100 | """ 101 | with open(self.processor_list_fp, "w") as outfile: 102 | json.dump(self.data, outfile) 103 | 104 | def _create_processor_list(self, overwrite=False): 105 | """ 106 | Create list of processors that can be used 107 | 108 | If tensorflow is available, uses tensorflow's list of devices to create list. Otherwise, uses pytorch. Treats 109 | each GPU as it's own processor and all CPUs as a single processor. 110 | """ 111 | if not overwrite: 112 | if os.path.exists(self.processor_list_fp): 113 | raise FileExistsError('Processor list already exists. 
It should not be recreated as it may contain' 114 | 'information on which processors are currently in use') 115 | if self.version == 'tensorflow': 116 | import tensorflow as tf 117 | self.data = {} 118 | processors = [dv.name for dv in tf.config.list_logical_devices()] 119 | processors = [p for p in processors if 'XLA' not in p] 120 | elif self.version == 'torch': 121 | from torch.cuda import device_count 122 | num_gpus = torch.cuda.device_count() 123 | processors = [f'/device:GPU:{i}' for i in range(num_gpus)] 124 | processors.append(f'/device:CPU:0') 125 | 126 | self.data['processors'] = processors 127 | self.data['in_use'] = [] 128 | 129 | def _choose_device(self): 130 | """ 131 | Select the device that should be used by the model 132 | 133 | :return: Processor to use 134 | :rtype: str 135 | """ 136 | # If we're debugging, use CPU, since even if CPU is already in use, room can still be made available 137 | if self.debug: 138 | all_cpus = [p for p in self.data['processors'] if 'CPU' in p] 139 | return all_cpus[0] 140 | else: 141 | available_gpus = [p for p in self.data['processors'] if 'GPU' in p and p not in self.data['in_use']] 142 | available_cpus = [p for p in self.data['processors'] if 'CPU' in p and p not in self.data['in_use']] 143 | if len(available_gpus) > 0: 144 | return available_gpus[0] 145 | elif len(available_cpus) > 0: 146 | return available_cpus[0] 147 | else: 148 | raise NoProcessorAvailableError(f'All processors currently in use. These are {self.data["processors"]}') 149 | 150 | def open(self): 151 | """ 152 | Check out a processor 153 | 154 | :return: None 155 | """ 156 | if self.my_processor is not None: 157 | print(f'Processor already open. The processor being used is {self.my_processor}') 158 | else: 159 | self._load() 160 | self.my_processor = self._choose_device() 161 | if not self.debug: 162 | self.data['in_use'].append(self.my_processor) 163 | self._save() 164 | self._add_exit_handling() 165 | if 'cpu' in self.my_processor.lower() and self.version == 'tensorflow': 166 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 167 | elif self.version == 'tensorflow': 168 | os.environ['CUDA_VISIBLE_DEVICES'] = self.device()[-1] 169 | 170 | def close(self): 171 | """ 172 | Check the processor back in 173 | 174 | :return: None 175 | """ 176 | if self.my_processor is None: 177 | print(f'Processor is not open.') 178 | else: 179 | self._load() 180 | if not self.debug: 181 | self.data['in_use'] = [p for p in self.data['in_use'] if p != self.my_processor] 182 | self._save() 183 | atexit.unregister(self.log_sigint) 184 | self.my_processor = None 185 | 186 | def device(self): 187 | """ 188 | Return the processor that this manager has checked out 189 | 190 | :return: The processor that this manager has checked out 191 | :rtype: str 192 | """ 193 | return self.my_processor 194 | 195 | 196 | if __name__ == '__main__': 197 | mgr = ProcessorManager() 198 | mgr._remove_all() 199 | -------------------------------------------------------------------------------- /tests/utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | def set_seed(seed): 9 | """ 10 | Set relevant random seeds 11 | 12 | :param seed: 13 | :return: 14 | """ 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | tf.random.set_seed(seed) 18 | 19 | 20 | def clear_path(path): 21 | """ 22 | Delete any files or directories from a path 23 | 24 | :param path: Path to remove 25 | 
:type path: str 26 | :return: None 27 | """ 28 | if os.path.exists(path): 29 | if os.path.isfile(path): 30 | os.remove(path) 31 | else: 32 | shutil.rmtree(path) 33 | 34 | def total_system_ram(): 35 | """ 36 | Get total system RAM, in bytes 37 | 38 | :return: Total system RAM, in bytes 39 | """ 40 | return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') --------------------------------------------------------------------------------