├── datasets ├── london_air_pollution │   ├── run.sh │   ├── readme.md │   ├── data_processing │   │   ├── rdata_to_csv.r │   │   └── rdata_to_csv_for_aq.r │   └── aq_downloader.py ├── london │   └── run.sh └── new_york_crime_large │   ├── run.sh │   └── clean.py ├── requirements.txt ├── LICENSE ├── .gitignore ├── README.md ├── Makefile └── experiments ├── nyc_crime ├── models │   ├── m_bayes_newt.py │   └── m_gpflow.py └── setup_data.py ├── air_quality ├── models │   ├── m_bayes_newt.py │   ├── m_ski.py │   └── m_gpflow.py └── setup_data.py └── utils.py /datasets/london_air_pollution/run.sh: -------------------------------------------------------------------------------- 1 | python aq_downloader.py 2 | -------------------------------------------------------------------------------- /datasets/london_air_pollution/readme.md: -------------------------------------------------------------------------------- 1 | # Downloading London Air Pollution 2 | 3 | 4 | 5 | To download, simply run `sh ./run.sh`. This will download air pollution data from http://www.londonair.org.uk and collect all sensor readings into a single CSV file, `downloaded_data/aq_data.csv`. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax==0.2.9 2 | jaxlib==0.1.60 3 | objax==1.3.1 4 | numpy==1.21.3 5 | matplotlib==3.4.3 6 | scipy==1.7.1 7 | sklearn 8 | pandas==1.3.4 9 | bayesnewton==1.1 10 | gpflow==2.3.0 11 | gpytorch==1.5.1 12 | loguru==0.5.3 13 | rtree==0.9.7 14 | geopandas==0.10.2 15 | tqdm==4.62.3 16 | numba==0.54.1 17 | -------------------------------------------------------------------------------- /datasets/london/run.sh: -------------------------------------------------------------------------------- 1 | wd=$(dirname "$0") 2 | curl -L https://data.london.gov.uk/download/statistical-gis-boundary-files-london/9ba8c833-6370-4b11-abdc-314aa020d5e0/statistical-gis-boundaries-london.zip --output london_shp.zip 3 | mkdir -p london_shp 4 | mv london_shp.zip london_shp 5 | cd london_shp 6 | unzip london_shp.zip 7 | -------------------------------------------------------------------------------- /datasets/london_air_pollution/data_processing/rdata_to_csv.r: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | 4 | args = commandArgs(trailingOnly=TRUE) 5 | 6 | rdata_input = args[1] 7 | csv_output = args[2] 8 | df = args[3] 9 | 10 | load(rdata_input) 11 | 12 | options(TZ="Europe/London") 13 | Sys.setenv (TZ="Europe/London") 14 | 15 | print(rdata_input) 16 | print(csv_output) 17 | 18 | a = subset(get(df), select=-c(SiteName, Address, Authority)) 19 | #a = get(df) 20 | write.table(a, file=csv_output, row.names = FALSE, sep=';', na='') 21 | -------------------------------------------------------------------------------- /datasets/london_air_pollution/data_processing/rdata_to_csv_for_aq.r: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | 4 | args = commandArgs(trailingOnly=TRUE) 5 | 6 | rdata_input = args[1] 7 | csv_output = args[2] 8 | df = args[3] 9 | 10 | load(rdata_input) 11 | 12 | options(TZ="Europe/London") 13 | Sys.setenv (TZ="Europe/London") 14 | 15 | print(rdata_input) 16 | print(csv_output) 17 | 18 | #a = subset(get(df), select=-c(SiteName, Address, Authority)) 19 | a = get(df) 20 | write.table(a, file=csv_output, row.names = FALSE, sep=';', na='') 21 |
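These two R helpers are driven from Python by `aq_downloader.py` (which appears later in this listing) via `Rscript`. A minimal sketch of that call pattern, using the same paths as the downloader:

```python
from subprocess import call

def rdata_to_csv(file_rdata, file_csv_output, rdata_df):
    # Run the R helper: it loads the .RData file, extracts the data frame named
    # `rdata_df`, and writes it to `file_csv_output` as a ';'-separated CSV.
    call(['Rscript', '--vanilla', 'data_processing/rdata_to_csv.r',
          file_rdata, file_csv_output, rdata_df])

# e.g. convert the LAQN site list (stored in the RData file as the data frame `sites`)
rdata_to_csv('downloaded_data/laqn_sites.RData',
             'downloaded_data/laqn_sites.csv', 'sites')
```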
-------------------------------------------------------------------------------- /datasets/new_york_crime_large/run.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | 4 | # Download NYPD complaint data 5 | curl -L https://data.cityofnewyork.us/api/views/qgea-i56i/rows.csv?accessType=DOWNLOAD --output NYPD_Complaint_Data_Historic.csv 6 | 7 | # Download NYC boundary 8 | curl -L "https://data.cityofnewyork.us/api/geospatial/tqmj-j8zm?method=export&format=Shapefile" -o nyc_borough.zip 9 | 10 | mkdir Borough_Boundaries 11 | mv nyc_borough.zip Borough_Boundaries 12 | cd Borough_Boundaries 13 | unzip nyc_borough.zip 14 | 15 | # shapefile has a unique id attached, rename for convenience 16 | mv *.dbf nyc.dbf 17 | mv *.shp nyc.shp 18 | mv *.shx nyc.shx 19 | mv *.prj nyc.prj 20 | 21 | cd ../ 22 | 23 | cd ../ 24 | 25 | 26 | #clean data 27 | python clean.py 28 | 29 | 30 | -------------------------------------------------------------------------------- /datasets/new_york_crime_large/clean.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import sys 4 | sys.path.append('../../experiments/') 5 | import utils 6 | 7 | df = pd.read_csv('data/NYPD_Complaint_Data_Historic.csv') 8 | 9 | df['datetime'] = df.apply( 10 | lambda row: str(row['CMPLNT_FR_DT'])+' '+str(row['CMPLNT_FR_TM']), 11 | axis=1) 12 | 13 | df['datetime'] = pd.to_datetime(df['datetime'], format='%m/%d/%Y %H:%M:%S', errors='coerce') 14 | df = df[df['datetime'].isnull()==False] #drop rows with invalid dates 15 | 16 | df['epoch'] = utils.datetime_to_epoch(df['datetime']) 17 | 18 | #df = df[(df['datetime'] >= '2014') & (df['datetime'] < '2015')] #only want in 2014 19 | 20 | df = df[df['Latitude'].notna()] #drop rows with nan lat-lons 21 | df = df[df['Longitude'].notna()] #drop rows with nan lat-lons 22 | 23 | df.to_csv('data/cleaned_nyc_crime.csv', index=False) 24 | 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 AaltoML 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
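`clean.py` above relies on `utils.datetime_to_epoch`, which is not shown in this listing. A minimal sketch of what such a helper might look like (the exact implementation and unit are assumptions, not the repository's actual code):

```python
import pandas as pd

def datetime_to_epoch(dt_series):
    # Hypothetical sketch: convert a pandas datetime Series to Unix epoch seconds.
    return (dt_series - pd.Timestamp("1970-01-01")) // pd.Timedelta("1s")

# e.g. datetime_to_epoch(pd.Series(pd.to_datetime(["2014-01-01 00:00:00"])))  # -> 1388534400
```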
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Excluded Data 2 | datasets/london/ 3 | datasets/london_air_pollution/downloaded_data 4 | datasets/new_york_crime_large/data 5 | 6 | experiments/*/data/* 7 | experiments/*/results/* 8 | 9 | # Exclude shape files 10 | *.dbf 11 | *.prj 12 | *.shp 13 | *.shx 14 | *.zip 15 | *.cpg 16 | *.xml 17 | 18 | # Exclude python cache 19 | **/__pycache__/** 20 | 21 | 22 | # This is a .gitignore file which aims to keep the git 23 | # repository tidy by preventing the inclusion of different 24 | # temporary or system files. 25 | # 26 | # Last modified by Arno Solin, 2012-11-22 27 | 28 | # Exclude TeX temporary working files 29 | *.acn 30 | *.acr 31 | *.alg 32 | *.aux 33 | *.bbl 34 | *.blg 35 | *.dvi 36 | *.fdb_latexmk 37 | *.glg 38 | *.glo 39 | *.gls 40 | *.idx 41 | *.ilg 42 | *.ind 43 | *.ist 44 | *.lof 45 | *.log 46 | *.lot 47 | *.maf 48 | *.mtc 49 | *.mtc0 50 | *.nav 51 | *.nlo 52 | *.out 53 | *.pdfsync 54 | *.ps 55 | *.snm 56 | *.synctex.gz 57 | *.toc 58 | *.vrb 59 | *.xdy 60 | *.tdo 61 | *.dpth 62 | *.auxlock 63 | *.dep 64 | *.brf 65 | 66 | # Exclude backup files created by e.g. Matlab and Emacs 67 | *~ 68 | 69 | # Exclude system specific thumbnail and other folder metadata 70 | .DS_Store 71 | .DS_Store? 72 | ._* 73 | .Spotlight-V100 74 | .Trashes 75 | Icon? 76 | ehthumbs.db 77 | Thumbs.db 78 | 79 | # Exclude externalisation results from tikz 80 | tikz*.pdf 81 | *.md5 82 | *.spl 83 | 84 | *.lock 85 | *.bin 86 | *.iml 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spatio-Temporal Variational GPs 2 | 3 | This repository is the official implementation of the methods in the publication: 4 | 5 | * O. Hamelijnck, W.J. Wilkinson, N.A. Loppi, A. Solin, and T. Damoulas (2021). **Spatio-temporal variational Gaussian processes**. In *Neural Information Processing Systems (NeurIPS)*. [[arXiv]](https://arxiv.org/abs/2111.01732) 6 | 7 | ## Citing this work: 8 | ```bibtex 9 | @inproceedings{hamelijnck2021spatio, 10 | title={Spatio-Temporal Variational {G}aussian Processes}, 11 | author={Hamelijnck, Oliver and Wilkinson, William and Loppi, Niki and Solin, Arno and Damoulas, Theodoros}, 12 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 13 | year={2021}, 14 | } 15 | ``` 16 | 17 | ## Experiment Setup 18 | 19 | The code has been tested on a MacBook Pro. All spatio-temporal VGP models have been implemented within the [Bayes-Newton package](https://github.com/AaltoML/BayesNewton). 20 | 21 | ### Environment Setup 22 | 23 | We recommend using conda: 24 | 25 | ```bash 26 | conda create -n spatio_gp python=3.7 27 | conda activate spatio_gp 28 | ``` 29 | 30 | Then install the required python packages: 31 | 32 | ```bash 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ### Data Download 37 | 38 | #### Pre-processed Data 39 | 40 | All data used in the paper, preprocessed and split into train and test sets, is provided at https://doi.org/10.5281/zenodo.4531304. Download the archive and place the corresponding datasets into the `experiments/*/data` folders.
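For reference, each model script loads its fold's pickles from the corresponding `experiments/*/data` folder (folds are numbered 0-4). A minimal sketch of how the experiments read them, with paths given relative to the repository root:

```python
import pickle

# e.g. fold 0 of the air quality experiment
train_data = pickle.load(open("experiments/air_quality/data/train_data_0.pickle", "rb"))
pred_data = pickle.load(open("experiments/air_quality/data/pred_data_0.pickle", "rb"))

X, Y = train_data['X'], train_data['Y']                      # training inputs / targets
X_t, Y_t = pred_data['test']['X'], pred_data['test']['Y']    # held-out test split
```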
41 | 42 | #### Manual Data Setup 43 | 44 | We also provide scripts to generate the data manually: 45 | 46 | ```bash 47 | make data 48 | ``` 49 | 50 | which will download the relevant London air quality and NYC data, clean them, and split into train-test splits. 51 | 52 | ### Running Experiments 53 | 54 | To run all experiments across all training folds run: 55 | 56 | ```bash 57 | make experiments 58 | ``` 59 | 60 | To run an individual experiment refer to the `Makefile`. 61 | 62 | #### Baselines used 63 | 64 | - `GPFlow2` : https://github.com/GPflow/GPflow 65 | - `GPYTorch`: https://github.com/cornellius-gp/gpytorch 66 | 67 | ## License 68 | 69 | This software is provided under the [MIT license](LICENSE). 70 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all data experiments 2 | 3 | data: 4 | @echo 'Generating Data' 5 | @echo 'Downloading london shape file' 6 | cd datasets/london && sh ./run.sh 7 | @echo 'Downloading laqn air pollution data' 8 | cd datasets/london_air_pollution && sh ./run.sh 9 | @echo 'Downloading new york crime data' 10 | cd datasets/new_york_crime_large && sh ./run.sh 11 | @echo 'Generating Air Quality train-test splits' 12 | cd experiments/air_quality && python setup_data.py 13 | @echo 'Generating NYC train-test splits' 14 | cd experiments/nyc_crime && python setup_data.py 15 | 16 | experiments: 17 | @echo 'Running Air Quality experiments' 18 | cd experiments/air_quality && mkdir -p results 19 | @echo 'Running bayesnewton model' 20 | cd experiments/air_quality/models && python m_bayes_newt.py 0 0 0 21 | cd experiments/air_quality/models && python m_bayes_newt.py 1 0 0 22 | cd experiments/air_quality/models && python m_bayes_newt.py 2 0 0 23 | cd experiments/air_quality/models && python m_bayes_newt.py 3 0 0 24 | cd experiments/air_quality/models && python m_bayes_newt.py 4 0 0 25 | @echo 'Running bayesnewton mean-field model' 26 | cd experiments/air_quality/models && python m_bayes_newt.py 0 1 0 27 | cd experiments/air_quality/models && python m_bayes_newt.py 1 1 0 28 | cd experiments/air_quality/models && python m_bayes_newt.py 2 1 0 29 | cd experiments/air_quality/models && python m_bayes_newt.py 3 1 0 30 | cd experiments/air_quality/models && python m_bayes_newt.py 4 1 0 31 | @echo 'Running bayesnewton parallel model' 32 | cd experiments/air_quality/models && python m_bayes_newt.py 0 0 1 33 | cd experiments/air_quality/models && python m_bayes_newt.py 1 0 1 34 | cd experiments/air_quality/models && python m_bayes_newt.py 2 0 1 35 | cd experiments/air_quality/models && python m_bayes_newt.py 3 0 1 36 | cd experiments/air_quality/models && python m_bayes_newt.py 4 0 1 37 | @echo 'Running gpflow model' 38 | cd experiments/air_quality/models && python m_gpflow.py 0 39 | cd experiments/air_quality/models && python m_gpflow.py 1 40 | cd experiments/air_quality/models && python m_gpflow.py 2 41 | cd experiments/air_quality/models && python m_gpflow.py 3 42 | cd experiments/air_quality/models && python m_gpflow.py 4 43 | @echo 'Running ski model' 44 | cd experiments/air_quality/models && python m_ski.py 0 45 | cd experiments/air_quality/models && python m_ski.py 1 46 | cd experiments/air_quality/models && python m_ski.py 2 47 | cd experiments/air_quality/models && python m_ski.py 3 48 | cd experiments/air_quality/models && python m_ski.py 4 49 | @echo 'Running NYC experiments' 50 | cd experiments/nyc_crime && mkdir -p results 51 | @echo 
'Running bayesnewton model' 52 | cd experiments/nyc_crime/models && python m_bayes_newt.py 0 0 0 53 | cd experiments/nyc_crime/models && python m_bayes_newt.py 1 0 0 54 | cd experiments/nyc_crime/models && python m_bayes_newt.py 2 0 0 55 | cd experiments/nyc_crime/models && python m_bayes_newt.py 3 0 0 56 | cd experiments/nyc_crime/models && python m_bayes_newt.py 4 0 0 57 | @echo 'Running bayesnewton mean-field model' 58 | cd experiments/nyc_crime/models && python m_bayes_newt.py 0 1 0 59 | cd experiments/nyc_crime/models && python m_bayes_newt.py 1 1 0 60 | cd experiments/nyc_crime/models && python m_bayes_newt.py 2 1 0 61 | cd experiments/nyc_crime/models && python m_bayes_newt.py 3 1 0 62 | cd experiments/nyc_crime/models && python m_bayes_newt.py 4 1 0 63 | @echo 'Running bayesnewton parallel model' 64 | cd experiments/nyc_crime/models && python m_bayes_newt.py 0 0 1 65 | cd experiments/nyc_crime/models && python m_bayes_newt.py 1 0 1 66 | cd experiments/nyc_crime/models && python m_bayes_newt.py 2 0 1 67 | cd experiments/nyc_crime/models && python m_bayes_newt.py 3 0 1 68 | cd experiments/nyc_crime/models && python m_bayes_newt.py 4 0 1 69 | @echo 'Running gpflow model' 70 | cd experiments/nyc_crime/models && python m_gpflow.py 0 71 | cd experiments/nyc_crime/models && python m_gpflow.py 1 72 | cd experiments/nyc_crime/models && python m_gpflow.py 2 73 | cd experiments/nyc_crime/models && python m_gpflow.py 3 74 | cd experiments/nyc_crime/models && python m_gpflow.py 4 75 | 76 | 77 | -------------------------------------------------------------------------------- /experiments/nyc_crime/models/m_bayes_newt.py: -------------------------------------------------------------------------------- 1 | import bayesnewton 2 | import objax 3 | import numpy as np 4 | import pickle 5 | import time 6 | import sys 7 | from scipy.cluster.vq import kmeans2 8 | from jax.lib import xla_bridge 9 | 10 | if len(sys.argv) > 1: 11 | ind = int(sys.argv[1]) 12 | else: 13 | ind = 0 14 | 15 | 16 | if len(sys.argv) > 2: 17 | mean_field = bool(int(sys.argv[2])) 18 | else: 19 | mean_field = False 20 | 21 | 22 | if len(sys.argv) > 3: 23 | parallel = bool(int(sys.argv[3])) 24 | else: 25 | parallel = None 26 | 27 | # ===========================Load Data=========================== 28 | train_data = pickle.load(open("../data/train_data_" + str(ind) + ".pickle", "rb")) 29 | pred_data = pickle.load(open("../data/pred_data_" + str(ind) + ".pickle", "rb")) 30 | 31 | X = train_data['X'] 32 | Y = train_data['Y'] 33 | 34 | X_t = pred_data['test']['X'] 35 | Y_t = pred_data['test']['Y'] 36 | 37 | bin_sizes = train_data['bin_sizes'] 38 | binsize = np.prod(bin_sizes) 39 | 40 | print('X: ', X.shape) 41 | 42 | num_z_space = 30 43 | 44 | grid = True 45 | print(Y.shape) 46 | print("num data points =", Y.shape[0]) 47 | 48 | if grid: 49 | # the gridded approach: 50 | t, R, Y = bayesnewton.utils.create_spatiotemporal_grid(X, Y) 51 | t_t, R_t, Y_t = bayesnewton.utils.create_spatiotemporal_grid(X_t, Y_t) 52 | else: 53 | # the sequential approach: 54 | t = X[:, :1] 55 | R = X[:, 1:] 56 | t_t = X_t[:, :1] 57 | R_t = X_t[:, 1:] 58 | Nt = t.shape[0] 59 | print("num time steps =", Nt) 60 | N = Y.shape[0] * Y.shape[1] * Y.shape[2] 61 | print("num data points =", N) 62 | 63 | var_f = 1. 64 | len_time = 0.001 65 | len_space = 0.1 66 | 67 | sparse = True 68 | opt_z = True # will be set to False if sparse=False 69 | 70 | if sparse: 71 | z = kmeans2(R[0, ...], num_z_space, minit="points")[0] 72 | else: 73 | z = R[0, ...] 
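# Note: with sparse=True the spatial inducing inputs z are initialised by k-means over the
# spatial locations of the first time slice (R[0, ...], one row per spatial bin), giving
# num_z_space = 30 shared spatial inducing points. The temporal dimension is handled by the
# Markov (filtering/smoothing) form of the model below, so inducing points are only placed
# over space.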
74 | 75 | kern_time = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_time) 76 | kern_space0 = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_space) 77 | kern_space1 = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_space) 78 | kern_space = bayesnewton.kernels.Separable([kern_space0, kern_space1]) 79 | 80 | kern = bayesnewton.kernels.SpatioTemporalKernel(temporal_kernel=kern_time, 81 | spatial_kernel=kern_space, 82 | z=z, 83 | sparse=sparse, 84 | opt_z=opt_z, 85 | conditional='Full') 86 | 87 | lik = bayesnewton.likelihoods.Poisson(binsize=binsize) 88 | 89 | if mean_field: 90 | model = bayesnewton.models.MarkovVariationalMeanFieldGP(kernel=kern, likelihood=lik, X=t, R=R, Y=Y, parallel=parallel) 91 | else: 92 | model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=t, R=R, Y=Y, parallel=parallel) 93 | 94 | lr_adam = 0.01 95 | lr_newton = 0.1 96 | iters = 500 97 | opt_hypers = objax.optimizer.Adam(model.vars()) 98 | energy = objax.GradValues(model.energy, model.vars()) 99 | 100 | 101 | @objax.Function.with_vars(model.vars() + opt_hypers.vars()) 102 | def train_op(): 103 | model.inference(lr=lr_newton) # perform inference and update variational params 104 | dE, E = energy() # compute energy and its gradients w.r.t. hypers 105 | opt_hypers(lr_adam, dE) 106 | return E 107 | 108 | 109 | train_op = objax.Jit(train_op) 110 | 111 | t0 = time.time() 112 | for i in range(1, iters): 113 | loss = train_op() 114 | print('iter %2d: energy: %1.4f' % (i, loss[0])) 115 | t1 = time.time() 116 | # print('optimisation time: %2.2f secs' % (t1-t0)) 117 | avg_time_taken = (t1-t0)/iters 118 | print('average iter time: %2.2f secs' % avg_time_taken) 119 | 120 | posterior_mean, posterior_var = model.predict_y(X=t_t, R=R_t) 121 | nlpd = model.negative_log_predictive_density(X=t_t, R=R_t, Y=Y_t) 122 | rmse = np.sqrt(np.nanmean((np.squeeze(Y_t) - np.squeeze(posterior_mean))**2)) 123 | print('nlpd: %2.3f' % nlpd) 124 | print('rmse: %2.3f' % rmse) 125 | 126 | cpugpu = xla_bridge.get_backend().platform 127 | 128 | with open("../results/" + str(int(mean_field)) + "_" + str(ind) + "_" + str(int(parallel)) + "_" + cpugpu + "_time.txt", "wb") as fp: 129 | pickle.dump(avg_time_taken, fp) 130 | with open("../results/" + str(int(mean_field)) + "_" + str(ind) + "_" + str(int(parallel)) + "_" + cpugpu + "_nlpd.txt", "wb") as fp: 131 | pickle.dump(nlpd, fp) 132 | with open("../results/" + str(int(mean_field)) + "_" + str(ind) + "_" + str(int(parallel)) + "_" + cpugpu + "_rmse.txt", "wb") as fp: 133 | pickle.dump(rmse, fp) 134 | -------------------------------------------------------------------------------- /datasets/london_air_pollution/aq_downloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Downloads air pollution data from http://www.londonair.org.uk adn converts to a csv file. 3 | Will need R and Rscript to run. 
4 | """ 5 | import urllib.request 6 | import requests 7 | import urllib.request 8 | from pathlib import Path 9 | from subprocess import call 10 | import os 11 | import pandas as pd 12 | #CONFIG 13 | #year to download data for 14 | YEAR = 2019 15 | DOWNLOAD_FLAG = True 16 | POLLUTANTS = ["nox","no2","o3","co","pm10_raw","pm10","pm25"] 17 | 18 | #Ensure folder structure exists 19 | Path("downloaded_data/").mkdir(exist_ok=True) 20 | Path("downloaded_data/aq_r_data/").mkdir(exist_ok=True) 21 | Path("downloaded_data/aq_csv/").mkdir(exist_ok=True) 22 | 23 | def rdata_to_csv(file_rdata, file_csv_output, rdata_df): 24 | call(['Rscript','--vanilla','data_processing/rdata_to_csv.r',file_rdata,file_csv_output,rdata_df]) 25 | 26 | def rdata_to_csv_for_aq(file_rdata, file_csv_output, rdata_df): 27 | call(['Rscript','--vanilla','data_processing/rdata_to_csv_for_aq.r',file_rdata,file_csv_output,rdata_df]) 28 | 29 | #========== DOWNLOAD LAQN SENSOR SITES ============ 30 | if not os.path.isfile('downloaded_data/laqn_sites.RData'): 31 | sites_url = 'http://www.londonair.org.uk/r_data/sites.RData' 32 | urllib.request.urlretrieve(sites_url, 'downloaded_data/laqn_sites.RData') 33 | 34 | #Convert Rdata to csv 35 | if not os.path.isfile('downloaded_data/laqn_sites.csv'): 36 | rdata_to_csv('downloaded_data/laqn_sites.RData', 'downloaded_data/laqn_sites.csv', 'sites') 37 | 38 | if not os.path.isfile('downloaded_data/laqn_sites.csv'): 39 | raise RuntimeError('Could not generate laqn_sites.csv') 40 | 41 | laqn_sites = pd.read_csv('downloaded_data/laqn_sites.csv', delimiter=';') 42 | 43 | laqn_site_codes = laqn_sites['SiteCode'] 44 | 45 | #========== DOWNLOAD LAQN SENSOR SITES ============ 46 | #download Rdata 47 | if DOWNLOAD_FLAG: 48 | count = 0 49 | for site in laqn_site_codes: 50 | filename = "{site}_{year}.RData".format(site=site,year=YEAR) 51 | 52 | #check if site already downloaded 53 | if not os.path.isfile('downloaded_data/aq_r_data/{filename}'.format(filename=filename)): 54 | 55 | url = "https://www.londonair.org.uk/r_data/{filename}".format(filename=filename) 56 | dir_to_save = 'downloaded_data/aq_r_data/{filename}'.format(filename=filename) 57 | 58 | #check that file exists on server 59 | resp = requests.head(url) 60 | if resp.status_code is 200: 61 | #file exists 62 | count = count + 1 63 | print('Downloading {site} - {year}'.format(site=site, year=YEAR)) 64 | while True: 65 | try: 66 | urllib.request.urlretrieve(url, dir_to_save) 67 | break 68 | except Exception as e: 69 | print('Exception on site %s'%site) 70 | print(e) 71 | sleep(10) 72 | print('trying again') 73 | continue 74 | except: 75 | print('problem on site %s'%site) 76 | sleep(10) 77 | 78 | print('Downloaded {site} - {year}'.format(site=site, year=YEAR)) 79 | else: 80 | print('File not found: ', url) 81 | 82 | #========== CONVERT RDATA TO CSV ============ 83 | for site in laqn_site_codes: 84 | filename = 'downloaded_data/aq_r_data/' + "{site}_{year}.RData".format(site=site,year=YEAR) 85 | filename_csv ='downloaded_data/aq_csv/' + '{site}_{year}.csv'.format(site=site, year=YEAR) 86 | if os.path.isfile(filename) and not os.path.isfile(filename_csv): 87 | rdata_to_csv_for_aq(filename, filename_csv, 'x') 88 | else: 89 | print('Does not exists: ', filename) 90 | 91 | 92 | #========== MERGE CSVs INTO ONE FILE ============ 93 | total_df = pd.DataFrame(columns=['site', 'date']+POLLUTANTS) 94 | for site in laqn_site_codes: 95 | filename = "{site}_{year}.RData".format(site=site,year=YEAR) 96 | filename_csv ='downloaded_data/aq_csv/' + 
'{site}_{year}.csv'.format(site=site, year=YEAR) 97 | if os.path.isfile(filename_csv) and os.path.getsize(filename_csv) > 0: 98 | print(filename_csv) 99 | df = pd.read_csv(filename_csv, sep=';') 100 | print(filename_csv, ' , ',df.shape) 101 | result_df = df[['site', 'date']] 102 | for p in POLLUTANTS: 103 | if p in list(df.columns): 104 | result_df[p] = df[p] 105 | else: 106 | result_df[p] = None 107 | total_df = total_df.append(result_df) 108 | 109 | total_df.to_csv('downloaded_data/aq_data.csv', index=False) 110 | print(total_df.shape) 111 | -------------------------------------------------------------------------------- /experiments/air_quality/models/m_bayes_newt.py: -------------------------------------------------------------------------------- 1 | import bayesnewton 2 | import objax 3 | import numpy as np 4 | import pickle 5 | import time 6 | import sys 7 | from scipy.cluster.vq import kmeans2 8 | from jax.lib import xla_bridge 9 | # import os 10 | 11 | # Limit ourselves to single-threaded jax/xla operations to avoid thrashing. See 12 | # https://github.com/google/jax/issues/743. 13 | # os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false " 14 | # "intra_op_parallelism_threads=1") 15 | 16 | 17 | if len(sys.argv) > 1: 18 | ind = int(sys.argv[1]) 19 | else: 20 | ind = 0 21 | 22 | 23 | if len(sys.argv) > 2: 24 | mean_field = bool(int(sys.argv[2])) 25 | else: 26 | mean_field = False 27 | 28 | 29 | if len(sys.argv) > 3: 30 | parallel = bool(int(sys.argv[3])) 31 | else: 32 | parallel = None 33 | 34 | # ===========================Load Data=========================== 35 | train_data = pickle.load(open("../data/train_data_" + str(ind) + ".pickle", "rb")) 36 | pred_data = pickle.load(open("../data/pred_data_" + str(ind) + ".pickle", "rb")) 37 | 38 | X = train_data['X'] 39 | Y = train_data['Y'] 40 | 41 | X_t = pred_data['test']['X'] 42 | Y_t = pred_data['test']['Y'] 43 | 44 | print('X: ', X.shape) 45 | 46 | num_z_space = 30 47 | 48 | grid = True 49 | print(Y.shape) 50 | print("num data points =", Y.shape[0]) 51 | 52 | if grid: 53 | # the gridded approach: 54 | t, R, Y = bayesnewton.utils.create_spatiotemporal_grid(X, Y) 55 | t_t, R_t, Y_t = bayesnewton.utils.create_spatiotemporal_grid(X_t, Y_t) 56 | else: 57 | # the sequential approach: 58 | t = X[:, :1] 59 | R = X[:, 1:] 60 | t_t = X_t[:, :1] 61 | R_t = X_t[:, 1:] 62 | Nt = t.shape[0] 63 | print("num time steps =", Nt) 64 | Nr = R.shape[1] 65 | print("num spatial points =", Nr) 66 | N = Y.shape[0] * Y.shape[1] * Y.shape[2] 67 | print("num data points =", N) 68 | 69 | var_y = 5. 70 | var_f = 1. 71 | len_time = 0.001 72 | len_space = 0.2 73 | 74 | sparse = True 75 | opt_z = True # will be set to False if sparse=False 76 | 77 | if sparse: 78 | z = kmeans2(R[0, ...], num_z_space, minit="points")[0] 79 | else: 80 | z = R[0, ...] 
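# Shape convention (inferred from how Nt, Nr and N are computed above):
# create_spatiotemporal_grid reshapes the flat (X, Y) arrays onto a full space-time grid,
# so that t is [Nt, 1] (unique time stamps), R is [Nt, Nr, 2] (spatial locations per time
# step) and Y is [Nt, Nr, 1] (observations, with NaN marking cells that are missing or
# held out for testing), which is why N = Y.shape[0] * Y.shape[1] * Y.shape[2] counts
# every cell of the grid.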
81 | 82 | # kern = bayesnewton.kernels.SpatioTemporalMatern52(variance=var_f, 83 | # lengthscale_time=len_time, 84 | # lengthscale_space=[len_space, len_space], 85 | # z=z, 86 | # sparse=sparse, 87 | # opt_z=opt_z, 88 | # conditional='Full') 89 | 90 | kern_time = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_time) 91 | kern_space0 = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_space) 92 | kern_space1 = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_space) 93 | kern_space = bayesnewton.kernels.Separable([kern_space0, kern_space1]) 94 | 95 | kern = bayesnewton.kernels.SpatioTemporalKernel(temporal_kernel=kern_time, 96 | spatial_kernel=kern_space, 97 | z=z, 98 | sparse=sparse, 99 | opt_z=opt_z, 100 | conditional='Full') 101 | 102 | lik = bayesnewton.likelihoods.Gaussian(variance=var_y) 103 | 104 | if mean_field: 105 | model = bayesnewton.models.MarkovVariationalMeanFieldGP(kernel=kern, likelihood=lik, X=t, R=R, Y=Y, parallel=parallel) 106 | else: 107 | model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=t, R=R, Y=Y, parallel=parallel) 108 | 109 | lr_adam = 0.01 110 | lr_newton = 1. 111 | iters = 300 112 | opt_hypers = objax.optimizer.Adam(model.vars()) 113 | energy = objax.GradValues(model.energy, model.vars()) 114 | 115 | 116 | @objax.Function.with_vars(model.vars() + opt_hypers.vars()) 117 | def train_op(): 118 | model.inference(lr=lr_newton) # perform inference and update variational params 119 | dE, E = energy() # compute energy and its gradients w.r.t. hypers 120 | opt_hypers(lr_adam, dE) 121 | return E 122 | 123 | 124 | train_op = objax.Jit(train_op) 125 | 126 | t0 = time.time() 127 | for i in range(1, iters + 1): 128 | loss = train_op() 129 | print('iter %2d: energy: %1.4f' % (i, loss[0])) 130 | t1 = time.time() 131 | # print('optimisation time: %2.2f secs' % (t1-t0)) 132 | avg_time_taken = (t1-t0)/iters 133 | print('average iter time: %2.2f secs' % avg_time_taken) 134 | 135 | posterior_mean, posterior_var = model.predict_y(X=t_t, R=R_t) 136 | nlpd = model.negative_log_predictive_density(X=t_t, R=R_t, Y=Y_t) 137 | rmse = np.sqrt(np.nanmean((np.squeeze(Y_t) - np.squeeze(posterior_mean))**2)) 138 | print('nlpd: %2.3f' % nlpd) 139 | print('rmse: %2.3f' % rmse) 140 | 141 | cpugpu = xla_bridge.get_backend().platform 142 | 143 | with open("../results/" + str(int(mean_field)) + "_" + str(ind) + "_" + str(int(parallel)) + "_" + cpugpu + "_time.txt", "wb") as fp: 144 | pickle.dump(avg_time_taken, fp) 145 | with open("../results/" + str(int(mean_field)) + "_" + str(ind) + "_" + str(int(parallel)) + "_" + cpugpu + "_nlpd.txt", "wb") as fp: 146 | pickle.dump(nlpd, fp) 147 | with open("../results/" + str(int(mean_field)) + "_" + str(ind) + "_" + str(int(parallel)) + "_" + cpugpu + "_rmse.txt", "wb") as fp: 148 | pickle.dump(rmse, fp) 149 | -------------------------------------------------------------------------------- /experiments/air_quality/models/m_ski.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gpytorch 3 | from gpytorch.means import ConstantMean 4 | from gpytorch.kernels import MaternKernel, ScaleKernel, GridInterpolationKernel 5 | from gpytorch.distributions import MultivariateNormal 6 | import numpy as np 7 | from loguru import logger 8 | import pickle 9 | from timeit import default_timer as timer 10 | import matplotlib.pyplot as plt 11 | import sys 12 | 13 | 14 | if len(sys.argv) > 1: 15 | ind = int(sys.argv[1]) 16 | plot_final = False 17 | else: 18 | 
ind = 0 19 | plot_final = True 20 | 21 | 22 | inducing_type = 'all_time' # 'default' 23 | num_z = 30 24 | likelihood_noise = 5. 25 | kernel_lengthscales = [0.001, 0.2, 0.2] 26 | step_size = 0.01 27 | iters = 300 28 | init_params = {} 29 | optimizer = torch.optim.Adam 30 | 31 | cpugpu = str(0) 32 | 33 | 34 | # ===========================Load Data=========================== 35 | train_data = pickle.load(open("../data/train_data_" + str(ind) + ".pickle", "rb")) 36 | pred_data = pickle.load(open("../data/pred_data_" + str(ind) + ".pickle", "rb")) 37 | 38 | X = train_data['X'] 39 | Y = np.squeeze(train_data['Y']) 40 | 41 | X_t = pred_data['test']['X'] 42 | Y_t = pred_data['test']['Y'] 43 | 44 | print('X: ', X.shape) 45 | 46 | non_nan_idx = np.squeeze(~np.isnan(Y)) 47 | 48 | X = torch.tensor(X[non_nan_idx]).float() 49 | Y = torch.tensor(Y[non_nan_idx]).float() 50 | 51 | D = X.shape[1] 52 | Nt = 2159 # number of time steps 53 | 54 | non_nan_idx_t = np.squeeze(~np.isnan(Y_t)) 55 | 56 | X_t = torch.tensor(X_t[non_nan_idx_t]).float() 57 | Y_t = np.squeeze(Y_t[non_nan_idx_t]) 58 | 59 | 60 | class GPRegressionModelSKI(gpytorch.models.ExactGP): 61 | 62 | def __init__(self, train_x, train_y, kernel, likelihood): 63 | super(GPRegressionModelSKI, self).__init__(train_x, train_y, likelihood) 64 | self.mean_module = ConstantMean() 65 | 66 | self.base_covar_module = kernel 67 | self.base_covar_module.lengthscale = torch.tensor(kernel_lengthscales) 68 | logger.info(f'kernel_lengthscales : {kernel_lengthscales}') 69 | 70 | if inducing_type == 'default': 71 | grid_size = gpytorch.utils.grid.choose_grid_size(train_x) 72 | init_params['grid_size'] = grid_size 73 | 74 | elif inducing_type == 'all_time': 75 | grid_size = np.array([Nt, np.ceil(np.sqrt(num_z)), np.ceil(np.sqrt(num_z))]).astype(int) 76 | init_params['grid_size'] = grid_size 77 | 78 | logger.info(f'grid_size : {grid_size}') 79 | 80 | self.covar_module = ScaleKernel( 81 | GridInterpolationKernel(self.base_covar_module, grid_size, num_dims=D) 82 | ) 83 | 84 | def forward(self, x): 85 | mean_x = self.mean_module(x) 86 | covar_x = self.covar_module(x) 87 | return MultivariateNormal(mean_x, covar_x) 88 | 89 | 90 | kern = MaternKernel(ard_num_dims=3, nu=1.5) 91 | # kern = MaternKernel(ard_num_dims=1, nu=1.5) 92 | lik = gpytorch.likelihoods.GaussianLikelihood() 93 | lik.noise = torch.tensor(likelihood_noise) 94 | 95 | model = GPRegressionModelSKI(X, Y, kern, lik) # SKI model 96 | 97 | # train 98 | model.train() 99 | lik.train() 100 | 101 | # Use the adam optimizer 102 | optimizer = optimizer(model.parameters(), lr=step_size) 103 | 104 | # "Loss" for GPs - the marginal log likelihood 105 | mll = gpytorch.mlls.ExactMarginalLogLikelihood(lik, model) 106 | 107 | loss_arr = [] 108 | 109 | 110 | def train(): 111 | 112 | for i in range(iters): 113 | # Zero backprop gradients 114 | optimizer.zero_grad() 115 | 116 | with gpytorch.settings.use_toeplitz(False), gpytorch.settings.max_root_decomposition_size(30): 117 | 118 | # Get output from model 119 | output = model(X) 120 | 121 | # Calc loss and backprop derivatives 122 | loss = -mll(output, torch.squeeze(Y)) 123 | loss.backward() 124 | print('Iter %d/%d - Loss: %.3f' % (i + 1, iters, loss.item())) 125 | 126 | loss_arr.append(loss.detach().numpy()) 127 | 128 | optimizer.step() 129 | torch.cuda.empty_cache() 130 | 131 | 132 | start = timer() 133 | 134 | with gpytorch.settings.use_toeplitz(True): 135 | train() 136 | 137 | end = timer() 138 | 139 | training_time = end - start 140 | 141 | 142 | # 
===========================Predict=========================== 143 | 144 | model.eval() 145 | lik.eval() 146 | 147 | print('noise var:', model.likelihood.noise.detach().numpy()) 148 | 149 | logger.info('Predicting') 150 | 151 | 152 | with gpytorch.settings.max_preconditioner_size(10), torch.no_grad(): 153 | with gpytorch.settings.use_toeplitz(False), gpytorch.settings.max_root_decomposition_size(30), gpytorch.settings.fast_pred_var(): 154 | preds = model(X_t) 155 | 156 | 157 | def negative_log_predictive_density(y, post_mean, post_cov, lik_cov): 158 | # logZₙ = log ∫ 𝓝(yₙ|fₙ,σ²) 𝓝(fₙ|mₙ,vₙ) dfₙ = log 𝓝(yₙ|mₙ,σ²+vₙ) 159 | cov = lik_cov + post_cov 160 | lZ = np.squeeze(-0.5 * np.log(2 * np.pi * cov) - 0.5 * (y - post_mean) ** 2 / cov) 161 | return -lZ 162 | 163 | 164 | posterior_mean, posterior_var = preds.mean.detach().numpy(), preds.variance.detach().numpy() 165 | 166 | noise_var = model.likelihood.noise.detach().numpy() 167 | print('noise var:', noise_var) 168 | 169 | nlpd = np.mean(negative_log_predictive_density(y=Y_t, 170 | post_mean=posterior_mean, 171 | post_cov=posterior_var, 172 | lik_cov=noise_var)) 173 | rmse = np.sqrt(np.nanmean((np.squeeze(Y_t) - np.squeeze(posterior_mean))**2)) 174 | print('nlpd: %2.3f' % nlpd) 175 | print('rmse: %2.3f' % rmse) 176 | 177 | avg_time_taken = training_time / iters 178 | print('avg iter time:', avg_time_taken) 179 | 180 | with open("../results/ski_" + str(ind) + "_" + cpugpu + "_time.txt", "wb") as fp: 181 | pickle.dump(avg_time_taken, fp) 182 | with open("../results/ski_" + str(ind) + "_" + cpugpu + "_nlpd.txt", "wb") as fp: 183 | pickle.dump(nlpd, fp) 184 | with open("../results/ski_" + str(ind) + "_" + cpugpu + "_rmse.txt", "wb") as fp: 185 | pickle.dump(rmse, fp) 186 | 187 | if plot_final: 188 | plt.plot(posterior_mean) 189 | plt.show() 190 | -------------------------------------------------------------------------------- /experiments/nyc_crime/models/m_gpflow.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import gpflow 3 | from gpflow.optimizers import NaturalGradient 4 | from gpflow.utilities import set_trainable, leaf_components 5 | import numpy as np 6 | import scipy as sp 7 | import time 8 | from scipy.cluster.vq import kmeans2 9 | from tqdm import tqdm 10 | import pickle 11 | import sys 12 | 13 | print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU'))) 14 | 15 | 16 | if len(sys.argv) > 1: 17 | ind = int(sys.argv[1]) 18 | else: 19 | ind = 0 20 | 21 | 22 | if len(sys.argv) > 2: 23 | num_z_ind = int(sys.argv[2]) 24 | else: 25 | num_z_ind = 0 26 | 27 | 28 | # ===========================Load Data=========================== 29 | train_data = pickle.load(open("../data/train_data_" + str(ind) + ".pickle", "rb")) 30 | pred_data = pickle.load(open("../data/pred_data_" + str(ind) + ".pickle", "rb")) 31 | 32 | X = train_data['X'] 33 | Y = train_data['Y'] 34 | 35 | X_t = pred_data['test']['X'] 36 | Y_t = pred_data['test']['Y'] 37 | 38 | bin_sizes = train_data['bin_sizes'] 39 | binsize = np.prod(bin_sizes) 40 | 41 | non_nan_idx = np.logical_not(np.isnan(np.squeeze(Y))) 42 | X = X[non_nan_idx, :] 43 | Y = Y[non_nan_idx, :] 44 | 45 | non_nan_idx_t = np.logical_not(np.isnan(np.squeeze(Y_t))) 46 | X_t = X_t[non_nan_idx_t, :] 47 | Y_t = Y_t[non_nan_idx_t, :] 48 | 49 | print('X: ', X.shape) 50 | 51 | kernel_lengthscales = [0.001, 0.1, 0.1] 52 | kernel_variances = 1.0 53 | train_z = True 54 | epochs = 500 55 | step_size = 0.01 56 | # jitter = 1e-4 57 | natgrad_step_size = 
0.1 58 | # enforce_psd = False 59 | minibatch_size = [1500, 3000] 60 | num_z = [1500, 3000] 61 | 62 | 63 | def get_gpflow_params(m): 64 | params = {} 65 | leafs = leaf_components(m) 66 | for key in leafs.keys(): 67 | tf_vars = leafs[key].trainable_variables 68 | 69 | # check if variable exists 70 | if len(tf_vars) == 1: 71 | tf_var = tf_vars[0] 72 | 73 | params[key] = tf_var.numpy() 74 | 75 | return params 76 | 77 | 78 | N, D = X.shape 79 | 80 | print('num_z: ', num_z[num_z_ind]) 81 | Z_all = kmeans2(X, num_z[num_z_ind], minit="points")[0] 82 | 83 | kernel = gpflow.kernels.Matern32 84 | 85 | k = None 86 | for d in range(D): 87 | # print(d, kernel_lengthscales) 88 | if type(kernel_lengthscales) is list: 89 | k_ls = kernel_lengthscales[d] 90 | else: 91 | k_ls = kernel_lengthscales 92 | 93 | if type(kernel_variances) is list: 94 | k_var = kernel_variances[d] 95 | else: 96 | k_var = kernel_variances 97 | 98 | k_d = kernel( 99 | lengthscales=[k_ls], 100 | variance=k_var, 101 | active_dims=[d] 102 | ) 103 | 104 | # print(k_d) 105 | if k is None: 106 | k = k_d 107 | else: 108 | k = k * k_d 109 | 110 | init_as_cvi = True 111 | 112 | if init_as_cvi: 113 | M = Z_all.shape[0] 114 | jit = 1e-6 115 | 116 | Kzz = k(Z_all, Z_all) 117 | 118 | def inv(K): 119 | K_chol = sp.linalg.cholesky(K + jit * np.eye(M), lower=True) 120 | return sp.linalg.cho_solve((K_chol, True), np.eye(K.shape[0])) 121 | 122 | # manual q(u) decompositin 123 | nat1 = np.zeros([M, 1]) 124 | nat2 = -0.5 * inv(Kzz) 125 | 126 | lam1 = 1e-5 * np.ones([M, 1]) 127 | lam2 = -0.5 * np.eye(M) 128 | 129 | S = inv(-2 * (nat2 + lam2)) 130 | m = S @ (lam1 + nat1) 131 | 132 | S_chol = sp.linalg.cholesky(S + jit * np.eye(M), lower=True) 133 | S_flattened = S_chol[np.tril_indices(M, 0)] 134 | 135 | q_mu = m 136 | q_sqrt = np.array([S_chol]) 137 | else: 138 | q_mu = 1e-5 * np.ones([Z_all.shape[0], 1]) # match gpjax init 139 | q_sqrt = None 140 | 141 | lik = gpflow.likelihoods.Poisson(binsize=binsize) 142 | 143 | data = (X, Y) 144 | 145 | m = gpflow.models.SVGP( 146 | inducing_variable=Z_all, 147 | whiten=True, 148 | kernel=k, 149 | mean_function=None, 150 | likelihood=lik, 151 | q_mu=q_mu, 152 | q_sqrt=q_sqrt 153 | ) 154 | 155 | set_trainable(m.inducing_variable, True) 156 | 157 | # ===========================Train=========================== 158 | 159 | if minibatch_size[num_z_ind] is None or minibatch_size[num_z_ind] is 'none': 160 | training_loss = m.training_loss_closure( 161 | data 162 | ) 163 | else: 164 | print(N, minibatch_size[num_z_ind]) 165 | train_dataset = (tf.data.Dataset.from_tensor_slices(data).repeat().shuffle(N).batch(minibatch_size[num_z_ind])) 166 | train_iter = iter(train_dataset) 167 | training_loss = m.training_loss_closure(train_iter) 168 | 169 | 170 | # make it so adam does not train these 171 | set_trainable(m.q_mu, False) 172 | set_trainable(m.q_sqrt, False) 173 | 174 | natgrad_opt = NaturalGradient(gamma=natgrad_step_size) 175 | variational_params = [(m.q_mu, m.q_sqrt)] 176 | 177 | optimizer = tf.optimizers.Adam 178 | 179 | adam_opt_for_vgp = optimizer(step_size) 180 | 181 | loss_arr = [] 182 | 183 | bar = tqdm(total=epochs) 184 | 185 | # MINIBATCHING TRAINING 186 | t0 = time.time() 187 | for i in range(epochs): 188 | # NAT GRAD STEP 189 | natgrad_opt.minimize(training_loss, var_list=variational_params) 190 | 191 | # elbo = -m.elbo(data).numpy() 192 | 193 | # loss_arr.append(elbo) 194 | 195 | # ADAM STEP 196 | adam_opt_for_vgp.minimize(training_loss, var_list=m.trainable_variables) 197 | 198 | bar.update(1) 199 | t1 = 
time.time() 200 | avg_time_taken = (t1-t0)/epochs 201 | print('average iter time: %2.2f secs' % avg_time_taken) 202 | 203 | 204 | def _prediction_fn(X_, Y_): 205 | mu, var = m.predict_y(X_) 206 | log_pred_density = m.predict_log_density((X_, Y_)) 207 | return mu.numpy(), var.numpy(), log_pred_density.numpy() 208 | 209 | 210 | print('predicting...') 211 | posterior_mean, posterior_var, lpd = _prediction_fn(X_t, Y_t) 212 | # print(lpd.shape) 213 | # print(lpd) 214 | nlpd = np.mean(-lpd) 215 | rmse = np.sqrt(np.nanmean((np.squeeze(Y_t) - np.squeeze(posterior_mean))**2)) 216 | print('nlpd: %2.3f' % nlpd) 217 | print('rmse: %2.3f' % rmse) 218 | 219 | # prediction_fn = lambda X: utils.batch_predict(X, _prediction_fn, verbose=True) 220 | 221 | if len(tf.config.list_physical_devices('GPU')) > 0: 222 | cpugpu = 'gpu' 223 | else: 224 | cpugpu = 'cpu' 225 | 226 | with open("../results/gpflow_" + str(ind) + "_" + str(num_z_ind) + "_" + cpugpu + "_time.txt", "wb") as fp: 227 | pickle.dump(avg_time_taken, fp) 228 | with open("../results/gpflow_" + str(ind) + "_" + str(num_z_ind) + "_" + cpugpu + "_nlpd.txt", "wb") as fp: 229 | pickle.dump(nlpd, fp) 230 | with open("../results/gpflow_" + str(ind) + "_" + str(num_z_ind) + "_" + cpugpu + "_rmse.txt", "wb") as fp: 231 | pickle.dump(rmse, fp) 232 | -------------------------------------------------------------------------------- /experiments/air_quality/models/m_gpflow.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import gpflow 3 | from gpflow.optimizers import NaturalGradient 4 | from gpflow.utilities import set_trainable, leaf_components 5 | import numpy as np 6 | import scipy as sp 7 | import time 8 | from scipy.cluster.vq import kmeans2 9 | from tqdm import tqdm 10 | import pickle 11 | import sys 12 | 13 | print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU'))) 14 | 15 | 16 | if len(sys.argv) > 1: 17 | ind = int(sys.argv[1]) 18 | else: 19 | ind = 0 20 | 21 | 22 | if len(sys.argv) > 2: 23 | num_z_ind = int(sys.argv[2]) 24 | else: 25 | num_z_ind = 0 26 | 27 | 28 | # ===========================Load Data=========================== 29 | train_data = pickle.load(open("../data/train_data_" + str(ind) + ".pickle", "rb")) 30 | pred_data = pickle.load(open("../data/pred_data_" + str(ind) + ".pickle", "rb")) 31 | 32 | X = train_data['X'] 33 | Y = train_data['Y'] 34 | 35 | X_t = pred_data['test']['X'] 36 | Y_t = pred_data['test']['Y'] 37 | 38 | non_nan_idx = np.logical_not(np.isnan(np.squeeze(Y))) 39 | X = X[non_nan_idx, :] 40 | Y = Y[non_nan_idx, :] 41 | 42 | non_nan_idx_t = np.logical_not(np.isnan(np.squeeze(Y_t))) 43 | X_t = X_t[non_nan_idx_t, :] 44 | Y_t = Y_t[non_nan_idx_t, :] 45 | 46 | print('X: ', X.shape) 47 | 48 | kernel_lengthscales = [0.01, 0.2, 0.2] 49 | kernel_variances = 1.0 50 | likelihood_noise = 5.0 51 | train_z = True 52 | epochs = 300 53 | step_size = 0.01 54 | # jitter = 1e-4 55 | natgrad_step_size = 1.0 56 | # enforce_psd = False 57 | # minibatch_size = [100, 500, 100] 58 | minibatch_size = [400, 600, 800, 2000, 3000] 59 | # num_z = [1000, 1500, 2000] 60 | num_z = [1500, 2000, 2500, 5000, 8000] 61 | 62 | 63 | def get_gpflow_params(m): 64 | params = {} 65 | leafs = leaf_components(m) 66 | for key in leafs.keys(): 67 | tf_vars = leafs[key].trainable_variables 68 | 69 | # check if variable exists 70 | if len(tf_vars) == 1: 71 | tf_var = tf_vars[0] 72 | 73 | params[key] = tf_var.numpy() 74 | 75 | return params 76 | 77 | 78 | N, D = X.shape 79 | 80 | 
print('num_z: ', num_z[num_z_ind]) 81 | Z_all = kmeans2(X, num_z[num_z_ind], minit="points")[0] 82 | 83 | kernel = gpflow.kernels.Matern32 84 | 85 | k = None 86 | for d in range(D): 87 | # print(d, kernel_lengthscales) 88 | if type(kernel_lengthscales) is list: 89 | k_ls = kernel_lengthscales[d] 90 | else: 91 | k_ls = kernel_lengthscales 92 | 93 | if type(kernel_variances) is list: 94 | k_var = kernel_variances[d] 95 | else: 96 | k_var = kernel_variances 97 | 98 | k_d = kernel( 99 | lengthscales=[k_ls], 100 | variance=k_var, 101 | active_dims=[d] 102 | ) 103 | 104 | # print(k_d) 105 | if k is None: 106 | k = k_d 107 | else: 108 | k = k * k_d 109 | 110 | init_as_cvi = True 111 | 112 | if init_as_cvi: 113 | M = Z_all.shape[0] 114 | jit = 1e-6 115 | 116 | Kzz = k(Z_all, Z_all) 117 | 118 | def inv(K): 119 | K_chol = sp.linalg.cholesky(K + jit * np.eye(M), lower=True) 120 | return sp.linalg.cho_solve((K_chol, True), np.eye(K.shape[0])) 121 | 122 | # manual q(u) decompositin 123 | nat1 = np.zeros([M, 1]) 124 | nat2 = -0.5 * inv(Kzz) 125 | 126 | lam1 = 1e-5 * np.ones([M, 1]) 127 | lam2 = -0.5 * np.eye(M) 128 | 129 | S = inv(-2 * (nat2 + lam2)) 130 | m = S @ (lam1 + nat1) 131 | 132 | S_chol = sp.linalg.cholesky(S + jit * np.eye(M), lower=True) 133 | S_flattened = S_chol[np.tril_indices(M, 0)] 134 | 135 | q_mu = m 136 | q_sqrt = np.array([S_chol]) 137 | else: 138 | q_mu = 1e-5 * np.ones([Z_all.shape[0], 1]) # match gpjax init 139 | q_sqrt = None 140 | 141 | lik = gpflow.likelihoods.Gaussian(variance=likelihood_noise) 142 | 143 | data = (X, Y) 144 | 145 | m = gpflow.models.SVGP( 146 | inducing_variable=Z_all, 147 | whiten=True, 148 | kernel=k, 149 | mean_function=None, 150 | likelihood=lik, 151 | q_mu=q_mu, 152 | q_sqrt=q_sqrt 153 | ) 154 | 155 | set_trainable(m.inducing_variable, True) 156 | 157 | # ===========================Train=========================== 158 | 159 | if minibatch_size[num_z_ind] is None or minibatch_size[num_z_ind] is 'none': 160 | training_loss = m.training_loss_closure( 161 | data 162 | ) 163 | else: 164 | print(N, minibatch_size[num_z_ind]) 165 | train_dataset = (tf.data.Dataset.from_tensor_slices(data).repeat().shuffle(N).batch(minibatch_size[num_z_ind])) 166 | train_iter = iter(train_dataset) 167 | training_loss = m.training_loss_closure(train_iter) 168 | 169 | 170 | # make it so adam does not train these 171 | set_trainable(m.q_mu, False) 172 | set_trainable(m.q_sqrt, False) 173 | 174 | natgrad_opt = NaturalGradient(gamma=natgrad_step_size) 175 | variational_params = [(m.q_mu, m.q_sqrt)] 176 | 177 | optimizer = tf.optimizers.Adam 178 | 179 | adam_opt_for_vgp = optimizer(step_size) 180 | 181 | loss_arr = [] 182 | 183 | bar = tqdm(total=epochs) 184 | 185 | # MINIBATCHING TRAINING 186 | t0 = time.time() 187 | for i in range(epochs): 188 | # NAT GRAD STEP 189 | natgrad_opt.minimize(training_loss, var_list=variational_params) 190 | 191 | # elbo = -m.elbo(data).numpy() 192 | 193 | # loss_arr.append(elbo) 194 | 195 | # ADAM STEP 196 | adam_opt_for_vgp.minimize(training_loss, var_list=m.trainable_variables) 197 | 198 | bar.update(1) 199 | t1 = time.time() 200 | avg_time_taken = (t1-t0)/epochs 201 | print('average iter time: %2.2f secs' % avg_time_taken) 202 | 203 | 204 | def _prediction_fn(X_, Y_): 205 | mu, var = m.predict_y(X_) 206 | log_pred_density = m.predict_log_density((X_, Y_)) 207 | return mu.numpy(), var.numpy(), log_pred_density.numpy() 208 | 209 | 210 | print('predicting...') 211 | posterior_mean, posterior_var, lpd = _prediction_fn(X_t, Y_t) 212 | # print(lpd.shape) 
213 | # print(lpd) 214 | nlpd = np.mean(-lpd) 215 | rmse = np.sqrt(np.nanmean((np.squeeze(Y_t) - np.squeeze(posterior_mean))**2)) 216 | print('nlpd: %2.3f' % nlpd) 217 | print('rmse: %2.3f' % rmse) 218 | 219 | # prediction_fn = lambda X: utils.batch_predict(X, _prediction_fn, verbose=True) 220 | 221 | if len(tf.config.list_physical_devices('GPU')) > 0: 222 | cpugpu = 'gpu' 223 | else: 224 | cpugpu = 'cpu' 225 | 226 | with open("../results/gpflow_" + str(ind) + "_" + str(num_z_ind) + "_" + cpugpu + "_time.txt", "wb") as fp: 227 | pickle.dump(avg_time_taken, fp) 228 | with open("../results/gpflow_" + str(ind) + "_" + str(num_z_ind) + "_" + cpugpu + "_nlpd.txt", "wb") as fp: 229 | pickle.dump(nlpd, fp) 230 | with open("../results/gpflow_" + str(ind) + "_" + str(num_z_ind) + "_" + cpugpu + "_rmse.txt", "wb") as fp: 231 | pickle.dump(rmse, fp) 232 | -------------------------------------------------------------------------------- /experiments/nyc_crime/setup_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | NYC-CRIME data setup. 3 | 4 | See datasets/new_york_crime for raw data installation. 5 | 6 | """ 7 | import pandas as pd 8 | import numpy as np 9 | import pickle 10 | from pathlib import Path 11 | from sklearn.model_selection import KFold, train_test_split 12 | 13 | import geopandas as gpd 14 | 15 | import json 16 | 17 | from loguru import logger 18 | 19 | import matplotlib.pyplot as plt 20 | from matplotlib import animation 21 | 22 | from matplotlib.colors import Normalize 23 | 24 | def weeks_between_dates(start, end): 25 | x = pd.to_datetime(end) - pd.to_datetime(start) 26 | return int(x / np.timedelta64(1, 'W')) 27 | 28 | #========================== Settings ========================== 29 | 30 | SEED = 2 31 | TRAIN_SPLIT=80 #80-20% train test split 32 | NUM_FOLDS=5 33 | FOLDS = list(range(NUM_FOLDS)) 34 | #BIN_SIZES=[int(365/2), 30, 30] 35 | date_start='2010-01-01' 36 | date_end='2013-01-1' 37 | BIN_SIZES=[weeks_between_dates(date_start, date_end), 40, 40] 38 | PD_DESC_FILTER = 'ASSAULT 3' 39 | dataset_folder_root = '../../datasets/' 40 | 41 | 42 | #========================== Helper functions ========================== 43 | 44 | def get_data_file_names(config): 45 | """ Used by the models to find the correct pickles for a specific fold. """ 46 | fold=config['fold'] 47 | return f'train_data_{fold}.pickle', f'pred_data_{fold}.pickle', f'raw_data_{fold}.pickle' 48 | 49 | def get_sub_dataframe(data_df, boundaries_gdf, date_start, date_end, bin_sizes): 50 | """ Returns count data inside the new york boroughs. 
""" 51 | 52 | df = data_df.copy() 53 | #get data related only to PD_DESC_FILTER 54 | if PD_DESC_FILTER is not None: 55 | df = df[df['PD_DESC'] == PD_DESC_FILTER].copy() 56 | 57 | 58 | #get data only within the date range 59 | df = df[(df['datetime'] >= date_start) & (df['datetime'] < date_end)] 60 | X_raw = np.array(df[['epoch', 'Longitude', 'Latitude']]) 61 | 62 | 63 | #calculate the count data using binning defined bin_sizes 64 | X, Y, actual_bin_sizes = utils.center_point_discretise_grid( 65 | X_raw, 66 | bin_sizes=bin_sizes 67 | 68 | ) 69 | print('time: ', X[X[:, 0] == X[0, 0], :].shape) 70 | print('unique: ', np.unique(X[:, 1:], axis=0).shape) 71 | 72 | #convert the count data to a geodataframe so it can be merged with the boundaries 73 | data = pd.DataFrame( 74 | np.hstack([X, Y]), 75 | columns=['epoch', 'Longitude', 'Latitude', 'Y'] 76 | ) 77 | 78 | data_gdf = gpd.GeoDataFrame( 79 | data, 80 | geometry=gpd.points_from_xy(data.Longitude, data.Latitude), 81 | crs='EPSG:4326' 82 | ) 83 | data_gdf.to_crs('EPSG:4326') 84 | 85 | sub_gdf = data_gdf 86 | col = sub_gdf.apply( 87 | lambda row: utils.lat_lon_to_polygon_buffer(row, actual_bin_sizes), 88 | axis=1 89 | ) 90 | sub_gdf['polygon'] = col 91 | sub_gdf = sub_gdf.set_geometry('polygon') 92 | 93 | 94 | #merge and remove all data that is not in the new york boundaries 95 | #this causes duplicates due to one cell belonging to multiple boroughs 96 | sub_gdf = gpd.sjoin( 97 | sub_gdf, 98 | boundaries_gdf, 99 | how='inner', 100 | op='intersects' 101 | ) 102 | 103 | sub_gdf.drop_duplicates(subset=['epoch', 'Longitude', 'Latitude'], inplace=True, ignore_index=True) 104 | 105 | print('Number of spatial bins: ', sub_gdf[sub_gdf['epoch']==sub_gdf['epoch'][0]].shape) 106 | 107 | 108 | return sub_gdf, actual_bin_sizes 109 | 110 | def load_data(): 111 | """ Load raw data locally. 
""" 112 | 113 | def get_dataframe(root, file_name): 114 | raw_data = data_root + file_name 115 | df = pd.read_csv(raw_data, low_memory=False) 116 | return df 117 | 118 | data_root = f'{dataset_folder_root}/new_york_crime_large/' 119 | boundary_data_root = f'{dataset_folder_root}/new_york_crime_large/' 120 | df = get_dataframe(data_root, 'data/cleaned_nyc_crime.csv') 121 | boundaries_gdf = gpd.read_file(f'{boundary_data_root}/data/Borough_Boundaries/nyc.shp') 122 | 123 | return data_root, df, boundaries_gdf 124 | 125 | 126 | if __name__ == "__main__": 127 | import sys 128 | sys.path.append('../') 129 | from utils import normalise_df, create_spatial_temporal_grid, datetime_to_epoch, normalise, ensure_timeseries_at_each_locations, un_normalise_df, epoch_to_datetime 130 | import utils 131 | 132 | logger.info('setting up data') 133 | 134 | print('BINNING: ', BIN_SIZES) 135 | 136 | #ensure correct data structure 137 | Path("data/").mkdir(exist_ok=True) 138 | Path("results/").mkdir(exist_ok=True) 139 | 140 | np.random.seed(SEED) 141 | 142 | data_root, df, boundaries_gdf = load_data() 143 | raw_df = df.copy() 144 | 145 | df, bin_sizes_raw = get_sub_dataframe(df, boundaries_gdf, date_start, date_end, BIN_SIZES) 146 | 147 | print(f'Filtered df: {df.shape} with binsizes: {bin_sizes_raw}') 148 | 149 | #get basic stats 150 | num_time_points = np.unique(np.array(df['epoch'])).shape[0] 151 | num_spatial_points = np.array(df['epoch']).shape[0]/num_time_points 152 | 153 | print('num_time_points: ', num_time_points) 154 | print('num_spatial_points: ', num_spatial_points) 155 | 156 | X_raw = np.array(df[['epoch', 'Longitude', 'Latitude']]) 157 | Y_raw = np.array(df[['Y']]) 158 | 159 | #although Y are integers we cast to float so that we can set some indices to np.nan 160 | Y_raw = Y_raw.astype(float) 161 | 162 | #raw x 163 | N = Y_raw.shape[0] 164 | 165 | for fold in range(NUM_FOLDS): 166 | _SEED = SEED + fold 167 | train_indices, test_indices = utils.train_test_split_indices(N, split=TRAIN_SPLIT/100, seed=_SEED) 168 | 169 | #Collect training and testing data 170 | X_train, Y_train = X_raw.copy(), Y_raw.copy() 171 | Y_train[test_indices, :] = np.nan #to keep grid structure in X we just mask the testing data in the training set 172 | 173 | X_test, Y_test = X_raw.copy(), Y_raw.copy() 174 | Y_test[train_indices, :] = np.nan 175 | 176 | X_all = X_raw 177 | Y_all = Y_raw 178 | 179 | #normalise all data with respect to training data 180 | X_train_std = np.std(X_train, axis=0) 181 | 182 | X_train_norm = normalise_df(X_train, wrt_to=X_train) 183 | X_test_norm = normalise_df(X_test, wrt_to=X_train) 184 | X_all_norm = normalise_df(X_all, wrt_to=X_train) 185 | 186 | # (x1-m)/s - (x2-m)/s = ((x-1)-(x2-m))/s = (x1-x2)/s 187 | bin_sizes_norm = bin_sizes_raw/X_train_std 188 | 189 | print('---') 190 | print('X_train: ', X_train_norm.shape) 191 | print(np.nanmean(Y_train), np.nanstd(Y_train)) 192 | print('Y_train: ', Y_train.shape, ' Non nans: ', np.sum(np.logical_not(np.isnan(Y_train)))) 193 | print('X_test: ', X_test_norm.shape) 194 | print('Y_test: ', Y_test.shape, ' Non nans: ', np.sum(np.logical_not(np.isnan(Y_test)))) 195 | print('X_all: ', X_all.shape) 196 | print('Y_all: ', X_all_norm.shape) 197 | print('bin_sizes_norm: ', bin_sizes_raw, bin_sizes_norm) 198 | 199 | training_data = { 200 | 'X': X_train_norm, 201 | 'Y': Y_train, 202 | 'bin_sizes': bin_sizes_norm 203 | } 204 | 205 | prediction_data = { 206 | 'test': { 207 | 'X': X_test_norm, 208 | 'Y': Y_test 209 | }, 210 | 'all': { 211 | 'X': X_all_norm, 212 | 'Y': 
Y_all 213 | } 214 | 215 | } 216 | 217 | raw_data_dict = { 218 | 'data': { 219 | 'train': { 220 | 'X': X_train, 221 | 'Y': Y_train 222 | }, 223 | 'all': { 224 | 'X': X_raw, 225 | 'Y': Y_raw, 226 | 'bin_sizes': bin_sizes_raw 227 | } 228 | }, 229 | } 230 | 231 | train_data_name, pred_data_name, raw_data_name = get_data_file_names({'fold': fold}) 232 | 233 | with open('data/{train_name}'.format(train_name=train_data_name), 'wb') as file: 234 | pickle.dump(training_data, file) 235 | 236 | with open('data/{pred_name}'.format(pred_name=pred_data_name), "wb") as file: 237 | pickle.dump(prediction_data, file) 238 | 239 | with open(f'data/{raw_data_name}', "wb") as file: 240 | pickle.dump(raw_data_dict, file) 241 | 242 | logger.info('finished') 243 | 244 | 245 | 246 | 247 | 248 | -------------------------------------------------------------------------------- /experiments/air_quality/setup_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file creates the required folder structure and setups the experiment data 3 | 4 | This reads the london air quality data from: 5 | ../../datasets/london_air_pollution 6 | 7 | and constructs train-test splits stored in pickles inside ./data/ 8 | """ 9 | import pandas as pd 10 | import numpy as np 11 | import pickle 12 | from pathlib import Path 13 | import matplotlib.pyplot as plt 14 | 15 | from dateutil import parser 16 | 17 | import json 18 | from loguru import logger 19 | 20 | try: 21 | import geopandas as gpd 22 | from geopandas.tools import sjoin 23 | except ModuleNotFoundError: 24 | print('Geopandas not found') 25 | pass 26 | 27 | 28 | #============================== Experiment Settings ============================== 29 | 30 | NUM_FOLDS = 5 31 | FOLDS = list(range(NUM_FOLDS)) 32 | 33 | #1 day, 2 days, 1 week 34 | SPECIES=['pm10'] 35 | PREDICTION_SITE = 'HK6' #HK6 = Hackney - Old Street 36 | DATE_START='2019/01/01' 37 | DATE_END='2019/04/01' 38 | TRAIN_SPLIT=0.8 #use 80% of data for training 39 | datasets_folder_root = '../../datasets/' 40 | 41 | SEED = 3 42 | 43 | def get_data_file_names(config): 44 | fold = config['fold'] 45 | 46 | return f'train_data_{fold}.pickle', f'pred_data_{fold}.pickle', f'raw_data_{fold}.pickle' 47 | 48 | #============================== Create Experiment Data ============================== 49 | # only create data when file run from terminal so that the experiment settings can be imported 50 | 51 | def load_data(): 52 | """ Locally load the raw air quality dataset. """ 53 | 54 | def get_data(root): 55 | raw_data = pd.read_csv(root+'london_air_pollution/downloaded_data/aq_data.csv') 56 | sites_df = pd.read_csv(root+'london_air_pollution/downloaded_data/laqn_sites.csv', sep=';') 57 | return raw_data, sites_df 58 | 59 | # Get local dataset 60 | raw_data, sites_df = get_data(datasets_folder_root) 61 | 62 | return raw_data, sites_df 63 | 64 | def get_boundary_gdf(): 65 | """ Load the london boundary locally. """ 66 | 67 | root = datasets_folder_root 68 | #try running locally 69 | data_root = 'london/london_shp/statistical-gis-boundaries-london/ESRI/' 70 | boundary_gdf = gpd.read_file(root+data_root+'LSOA_2011_London_gen_MHW.shp') 71 | 72 | boundary_gdf = boundary_gdf.to_crs({'init': 'epsg:4326'}) 73 | return boundary_gdf 74 | 75 | 76 | def filter_sites_not_in_london(sites_df): 77 | """ The LAQN datasets includes sensors that are not in London. Remove them. 
""" 78 | 79 | # Approximate bounding box of London 80 | london_box = [ 81 | [51.279, 51.684], #lat 82 | [-0.533, 0.208] #lon 83 | ] 84 | 85 | # Remove sites that do not lie within the bounding box 86 | sites_df = sites_df[(sites_df['Latitude'] > london_box[0][0]) & (sites_df['Latitude'] < london_box[0][1])] 87 | sites_df = sites_df[(sites_df['Longitude'] > london_box[1][0]) & (sites_df['Longitude'] < london_box[1][1])] 88 | 89 | boundary_gdf = get_boundary_gdf() 90 | 91 | sites_gdf = gpd.GeoDataFrame( 92 | sites_df, 93 | geometry=gpd.points_from_xy(sites_df.Longitude, sites_df.Latitude) 94 | ) 95 | sites_gdf.set_crs(epsg=4326, inplace=True) 96 | 97 | 98 | # Remove sites that do not lie within the boundary 99 | filtered_sites = sjoin(sites_gdf, boundary_gdf, how='inner', op='within') 100 | 101 | return filtered_sites 102 | 103 | def drop_locations_with_low_data(raw_data): 104 | """ Some sensors have been badly calibrates and have lot of missing data. We remove these with less that 40% data. """ 105 | raw_data['is_null'] = raw_data[SPECIES[0]].isnull().astype(int) 106 | raw_data['one'] = 1 107 | 108 | _df = raw_data.groupby('site').sum() 109 | 110 | #one is the total number of observations 111 | #is null is the number that were null 112 | _df = _df[_df['is_null'] < _df['one'] * 0.4] 113 | 114 | 115 | raw_data = raw_data.merge(_df, left_on='site', right_on='site', suffixes=[None, '_y']) 116 | 117 | return raw_data 118 | 119 | 120 | def clean_data(raw_data, sites_df): 121 | """ 122 | To clean the data we: 123 | 1) Remove sites not in London 124 | 2) Remove sensors with lots of missing values 125 | 3) Convert datetimes to unix epoch 126 | """ 127 | sites_df = filter_sites_not_in_london(sites_df) 128 | 129 | sites = sites_df['SiteCode'] 130 | 131 | #merge spatial infomation to data 132 | raw_data = raw_data.merge(sites_df, left_on='site', right_on='SiteCode') 133 | 134 | raw_data = drop_locations_with_low_data(raw_data) 135 | 136 | 137 | #convert to datetimes 138 | raw_data['date'] = pd.to_datetime(raw_data['date']) 139 | raw_data['epoch'] = datetime_to_epoch(raw_data['date']) 140 | 141 | 142 | return raw_data 143 | 144 | def get_grid(time_point=0): 145 | """ Create a spatial grid across London. Used for visualisations. """ 146 | 147 | london_box = [ 148 | [51.279, 51.684], #lat 149 | [-0.533, 0.332] #lon 150 | ] 151 | 152 | boundary_gdf = get_boundary_gdf() 153 | 154 | grid = utils.create_spatial_temporal_grid([time_point], london_box[0][0], london_box[0][1], london_box[1][0], london_box[1][1], 100, 100) 155 | 156 | grid_df = pd.DataFrame(grid, columns=['epoch', 'lat', 'lon']) 157 | grid_df = gpd.GeoDataFrame( 158 | grid_df, 159 | geometry=gpd.points_from_xy(grid_df.lon, grid_df.lat) 160 | ) 161 | grid_df.set_crs(epsg=4326, inplace=True) 162 | 163 | filtered_grid = sjoin(grid_df, boundary_gdf, how='inner', op='within') 164 | 165 | if False: 166 | fig = plt.figure() 167 | ax = plt.gca() 168 | #boundary_gdf.plot(ax=ax) 169 | filtered_grid.plot(ax=ax) 170 | print(filtered_grid.shape) 171 | plt.show() 172 | 173 | exit() 174 | 175 | return filtered_grid 176 | 177 | 178 | 179 | if __name__ == "__main__": 180 | logger.info('starting') 181 | 182 | # This file is also imported to extract the experiment information. To avoid import issues they are included here. 
if __name__ == "__main__":
    logger.info('starting')

    # This file is also imported to extract the experiment settings. To avoid import issues the utils imports are done here rather than at the top of the file.
    import sys
    sys.path.append('../')
    from utils import normalise_df, create_spatial_temporal_grid, datetime_to_epoch, normalise, ensure_timeseries_at_each_locations, un_normalise_df, epoch_to_datetime, pad_with_nan_to_make_grid
    import utils

    # ensure the expected folder structure exists
    Path("data/").mkdir(exist_ok=True)
    Path("results/").mkdir(exist_ok=True)

    # clean data and get the prediction site location
    raw_data, sites_df = load_data()
    prediction_site = sites_df[sites_df['SiteCode'] == PREDICTION_SITE].iloc[0]
    data_df = clean_data(raw_data, sites_df)

    print('Number of sites: ', pd.unique(data_df['SiteCode']).shape)

    print(
        'Amount of data in January: ',
        data_df[(data_df['date'] >= '2019/01/01') & (data_df['date'] < '2019/02/01')].shape
    )

    # get data in the date range
    data_df = data_df[(data_df['date'] >= DATE_START) & (data_df['date'] < DATE_END)]

    # data_df may have missing observations; for the state-space methods we require X, Y to lie on a grid, with missing values denoted by NaNs
    X = np.array(data_df[['epoch', 'Latitude', 'Longitude']])
    Y = np.array(data_df[SPECIES])

    # remove duplicated data
    u, unique_idx = np.unique(X, return_index=True, axis=0)
    X = X[unique_idx, :]
    Y = Y[unique_idx, :]

    # For the filtering methods to work we need a full spatio-temporal grid
    X_raw, Y_raw = pad_with_nan_to_make_grid(X.copy(), Y.copy())

    N = X.shape[0]

    print('Y: ', Y.shape, ' X_raw: ', X_raw.shape)
    print('stats: ', np.nanmean(Y_raw), np.nanmin(Y_raw), np.nanmax(Y_raw))

    # extract the prediction time series
    prediction_idx = (X_raw[:, 1] == prediction_site['Latitude']) & (X_raw[:, 2] == prediction_site['Longitude'])
    timeseries_x_raw = X_raw[prediction_idx, :]
    timeseries_y_raw = Y_raw[prediction_idx, :]

    slice_epoch = utils.datetime_str_to_epoch('2019/01/05 10:00:00')

    spatial_grid = get_grid(slice_epoch)
    spatial_grid_x = np.array(spatial_grid[['epoch', 'lat', 'lon']])
    spatial_grid_y = None
    print('spatial_grid: ', spatial_grid.shape)

    print('timeseries_x_raw: ', timeseries_x_raw.shape, ' timeseries_y_raw: ', timeseries_y_raw.shape)

    if False:
        print(Y_raw)
        print(timeseries_x_raw)
        plt.figure()
        plt.scatter(timeseries_x_raw[:, 0], timeseries_y_raw)
        plt.show()
        exit()

    print('number of timesteps: ', np.unique(X[:, 0]).shape)

    # construct train-test splits
    # train-test splits are constructed by constructing a random permutation (controlled through the seed)
    # and then splitting it into train-test data as specified by TRAIN_SPLIT
    for i, fold in enumerate(FOLDS):
        train_indices, test_indices = utils.train_test_split_indices(N, split=TRAIN_SPLIT, seed=(SEED+i))

        _config = {'fold': fold}

        train_name, pred_name, raw_name = get_data_file_names(_config)

        # Collect training and testing data
        X_train, Y_train = X_raw.copy(), Y_raw.copy()
        Y_train[test_indices, :] = np.nan  # to keep the grid structure in X we just mask the testing data in the training set

        X_test, Y_test = X_raw.copy(), Y_raw.copy()
        Y_test[train_indices, :] = np.nan

        X_all = X_raw
        Y_all = Y_raw

        # normalise all data with respect to the training data
        X_train_norm = normalise_df(X_train, wrt_to=X_train)
        X_test_norm = normalise_df(X_test, wrt_to=X_train)
        X_all_norm = normalise_df(X_all, wrt_to=X_train)
        timeseries_x_norm = normalise_df(timeseries_x_raw, wrt_to=X_train)
        spatial_grid_x_norm = normalise_df(spatial_grid_x, wrt_to=X_train)

        print('---')
        print('X_train: ', X_train_norm.shape)
        print('X_all: ', X_all.shape)
        print(np.nanmean(Y_train), np.nanstd(Y_train))
        print('Y_train: ', Y_train.shape, ' Non nans: ', np.sum(np.logical_not(np.isnan(Y_train))))
        print('X_test: ', X_test_norm.shape)
        print('Y_test: ', Y_test.shape, ' Non nans: ', np.sum(np.logical_not(np.isnan(Y_test))))

        training_data = {
            'X': X_train_norm,
            'Y': Y_train,
        }

        prediction_data = {
            'test': {
                'X': X_test_norm,
                'Y': Y_test
            },
            'all': {
                'X': X_all_norm,
                'Y': Y_all
            },
            'timeseries': {
                'X': timeseries_x_norm,
                'Y': timeseries_y_raw
            },
            'grid': {
                'X': spatial_grid_x_norm,
                'Y': None
            }
        }

        raw_data = {
            'all': {
                'X': X_all,
                'Y': Y_all,
            },
            'timeseries': {
                'X': timeseries_x_raw,
                'Y': timeseries_y_raw,
            },
            'grid': {
                'X': spatial_grid_x,
                'Y': None,
            }
        }

        with open(f'data/{train_name}', 'wb') as file:
            pickle.dump(training_data, file)

        with open(f'data/{pred_name}', "wb") as file:
            pickle.dump(prediction_data, file)

        with open(f'data/{raw_name}', "wb") as file:
            pickle.dump(raw_data, file)

    logger.info('setting up data finished')

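For reference, each fold's prediction pickle stores NaN-masked targets, so any downstream evaluation has to skip the masked rows. A minimal, hypothetical sketch of such an evaluation; 'pred_data_0.pickle' is the fold-0 name produced by get_data_file_names, and the mean predictor below is only a stand-in for a fitted model's predictions.

import pickle

import numpy as np

with open('data/pred_data_0.pickle', 'rb') as f:    # fold-0 name from get_data_file_names
    pred_data = pickle.load(f)

X_test, Y_test = pred_data['test']['X'], pred_data['test']['Y']
y = Y_test[:, 0]
mask = ~np.isnan(y)                                 # rows belonging to the training split are NaN-masked

# stand-in predictions: a real experiment would use the model's predictions at X_test here
mu = np.full(mask.sum(), np.nanmean(y))

rmse = np.sqrt(np.mean((mu - y[mask]) ** 2))
print('fold 0 test RMSE (mean predictor):', rmse)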
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import datetime
import itertools
import multiprocessing as mp

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches


def hours_to_seconds(hr):
    return hr*(60*60)

def minutes_to_seconds(m):
    return m*(60)

def slice_array(A, arr):
    _A = []
    for a in arr:
        _A.append(A[a, :])
    return np.vstack(_A)

def slice_array_insert(A, B, arr):
    for i, a in enumerate(arr):
        _A_len = A[a].shape[0]

        A[a] = A[a] + B[i*_A_len:(i+1)*_A_len, :]

    return A

def st_batch_predict(model, XS, prediction_fn=None, batch_size=5000, verbose=False):
    """
    With KF models the prediction data must be at all timesteps. This function
    breaks up the space-time data into space-time batches (each of which spans all time points).
    """
    # sort XS into grid/'kronecker' structure

    XS = np.roll(XS, -1, axis=1)
    # sort by time points
    grid_idx = np.lexsort(XS.T)
    # reset time axis
    XS = np.roll(XS, 1, axis=1)

    inv_grid_idx = np.argsort(grid_idx)

    XS = XS[grid_idx]

    time_points = np.unique(XS[:, 0])
    num_time_points = time_points.shape[0]
    num_spatial_points = int(XS.shape[0]/num_time_points)

    # number of spatial points that fit into a batch
    num_spatial_points_per_batch = int(np.floor(batch_size/num_time_points))

    num_steps = max(1, int(np.floor(num_spatial_points/num_spatial_points_per_batch)))

    if verbose:
        print('num_time_points: ', num_time_points)
        print('num_spatial_points: ', num_spatial_points)
        print('num_steps: ', num_steps)
        print('num_spatial_points_per_batch: ', num_spatial_points_per_batch)

    # empty prediction data
    mean = np.zeros([XS.shape[0], 1])
    var = np.zeros_like(mean)

    for i in range(num_steps):
        if verbose:
            print(f"{i}/{num_steps}")

        batch = num_spatial_points_per_batch
        if i == num_steps-1:
            # select the remaining spatial points
            batch = num_spatial_points - i*num_spatial_points_per_batch

        # j*num_spatial_points is the index of the j'th time slice
        # i*num_spatial_points_per_batch is the index of the current spatial batch
        start_idx = lambda j: j*num_spatial_points + i*num_spatial_points_per_batch
        end_idx = lambda j: start_idx(j) + batch

        step_idx = [
            slice(start_idx(j), end_idx(j)) for j in range(num_time_points)
        ]

        _XS = slice_array(XS, step_idx)

        if prediction_fn is not None:
            _mean, _var = prediction_fn(_XS)
        else:
            _mean, _var = model.predict_y(_XS, diagonal_var=True)

        _mean = np.squeeze(_mean).reshape([-1, 1])
        _var = np.squeeze(_var).reshape([-1, 1])

        mean = slice_array_insert(mean, _mean, step_idx)
        var = slice_array_insert(var, _var, step_idx)

    # unsort the grid/kronecker structure
    return [mean[inv_grid_idx]], [var[inv_grid_idx]]
    #return [mean], [var]

def batch_predict(XS, prediction_fn=None, batch_size=1000, verbose=False):
    # Ensure the batch size is not larger than the number of test points
    if XS.shape[0] < batch_size:
        batch_size = XS.shape[0]

    # Split up test points into equal batches
    num_batches = int(np.ceil(XS.shape[0] / batch_size))

    ys_arr = []
    ys_var_arr = []
    index = 0

    for count in range(num_batches):
        if verbose:
            print(f"{count}/{num_batches}")
        if count == num_batches - 1:
            # in the last batch just use the remaining test points
            batch = XS[index:, :]
        else:
            batch = XS[index : index + batch_size, :]

        index = index + batch_size

        # predict for the current batch
        y_mean, y_var = prediction_fn(batch)

        ys_arr.append(y_mean)
        ys_var_arr.append(y_var)

    y_mean = np.concatenate(ys_arr, axis=0)
    y_var = np.concatenate(ys_var_arr, axis=0)

    return y_mean, y_var

def key_that_starts_with(arr, s):
    for a in arr:
        if a.startswith(s):
            return a
    return False
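
# Illustrative sketch (not in the original file): batch_predict splits the test inputs
# into batches, calls prediction_fn on each batch and concatenates the per-batch means
# and variances. The helper below is hypothetical and only exercises that behaviour
# with a stand-in prediction function.
def _example_batch_predict():
    XS = np.linspace(0, 1, 2500).reshape(-1, 1)
    # stand-in prediction_fn returning a (mean, variance) pair for each batch
    prediction_fn = lambda batch: (np.sin(batch), np.ones_like(batch))
    mean, var = batch_predict(XS, prediction_fn=prediction_fn, batch_size=1000)
    assert mean.shape == (2500, 1) and var.shape == (2500, 1)
    return mean, var
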

#Normalise Data input
def normalise(x, wrt_to):
    return (x - np.mean(wrt_to))/np.std(wrt_to)

def normalise_df(x, wrt_to):
    return (x - np.mean(wrt_to, axis=0))/np.std(wrt_to, axis=0)

def un_normalise_df(x, wrt_to):
    return x*np.std(wrt_to, axis=0) + np.mean(wrt_to, axis=0)

def create_spatial_grid(x1, x2, y1, y2, n1, n2):
    x = np.linspace(x1, x2, n1)
    y = np.linspace(y1, y2, n2)
    grid = []
    for i in x:
        for j in y:
            grid.append([i, j])
    return np.array(grid)

def create_spatial_temporal_grid(time_points, x1, x2, y1, y2, n1, n2):
    x = np.linspace(x1, x2, n1)
    y = np.linspace(y1, y2, n2)
    grid = []
    for t in time_points:
        for i in x:
            for j in y:
                grid.append([t, i, j])
    return np.array(grid)

def numpy_to_list(a):
    return [a[i][:, None] for i in range(a.shape[0])]

def datetime_str_to_epoch(s):
    p = '%Y/%m/%d %H:%M:%S'
    epoch = datetime.datetime(1970, 1, 1)
    return int((datetime.datetime.strptime(s, p) - epoch).total_seconds())

def datetime_to_epoch(datetime):
    """
    Converts a pandas datetime column to unix epoch seconds
    args:
        datetime: a pandas datetime column
    """
    return datetime.astype('int64')//1e9

def epoch_to_datetime(epoch):
    return datetime.datetime.fromtimestamp(epoch)

def epochs_to_datetime_list(epochs):
    return [datetime.datetime.fromtimestamp(epoch) for epoch in epochs]

def ensure_timeseries_at_each_locations(X, Y):
    """
    This removes all spatial locations that do not have a full timeseries
    TODO: replace with Will's function
    """
    time_points = np.unique(X[:, 0])
    num_time_points = time_points.shape[0]

    X_space_only = X[:, 1:]
    new_arr, indices, counts = np.unique(X_space_only, return_index=True, return_counts=True, axis=0)
    spatial_to_remove = X_space_only[indices[counts != num_time_points]]
    for spatial_point in spatial_to_remove:
        idx = (X[:, 1:] != spatial_point)
        idx = np.all(idx, axis=1)

        X = X[idx, :]
        Y = Y[idx, :]

    return X, Y

#taken directly from https://github.com/AaltoML/spacetime-kalman/blob/master/kalmanjax/utils.py#L308
def discretegrid(xy, w, nt):
    """
    Convert spatial observations to a discrete intensity grid
    :param xy: observed spatial locations as a two-column vector
    :param w: observation window, i.e. discrete grid to be mapped to, [xmin xmax ymin ymax]
    :param nt: two-element vector defining number of bins in both directions
    """
    # Make grid
    x = np.linspace(w[0], w[1], nt[0] + 1)
    y = np.linspace(w[2], w[3], nt[1] + 1)
    X, Y = np.meshgrid(x, y)

    # Count points
    N = np.zeros([nt[1], nt[0]])
    for i in range(nt[0]):
        for j in range(nt[1]):
            ind = (xy[:, 0] >= x[i]) & (xy[:, 0] < x[i + 1]) & (xy[:, 1] >= y[j]) & (xy[:, 1] < y[j + 1])
            N[j, i] = np.sum(ind)
    return X[:-1, :-1].T, Y[:-1, :-1].T, N.T

class center_point_discretise_grid_inner_loop():
    def __init__(self, data, t_arr, centers_arr, input_dim):
        self.data = data
        self.t_arr = t_arr
        self.centers_arr = centers_arr
        self.input_dim = input_dim

    def __call__(self, i):
        in_range = lambda data, d, t, i: (data[:, d] >= t[i]) & (data[:, d] < t[i + 1])

        # create idx over each dim
        idx = [
            in_range(self.data, d, self.t_arr[d], i[d]) for d in range(self.input_dim)
        ]

        idx = np.logical_and.reduce(idx, axis=0)

        N_kij = np.sum(idx)

        _counts = N_kij
        _X = [self.centers_arr[d][i[d]] for d in range(self.input_dim)]

        return _X, _counts

def center_point_discretise_grid(data, bin_sizes=[5]):
    """
    data: [time] columns

    Returns:
        the binned counts, the center locations of the bins and the bin sizes
    """

    input_dim = len(bin_sizes)

    # helper functions
    lin = lambda i: np.linspace(
        np.min(data[:, i]), np.max(data[:, i]), bin_sizes[i]+1
    )

    centers = lambda A: np.array([np.mean([A[i], A[i+1]]) for i in range(A.shape[0] - 1)])

    # get grid cell positions
    t_arr = [lin(d) for d in range(input_dim)]

    # get size of a single cell
    binned_x_size_arr = [t_arr[d][1]-t_arr[d][0] for d in range(input_dim)]

    # get center locations of grid cells
    centers_arr = [centers(t_arr[d]) for d in range(input_dim)]

    #X.append([centers_arr[d][i[d]] for d in range(input_dim)])
    #counts.append(N_kij)

    bin_ranges = [range(0, bin_sizes[d]) for d in range(input_dim)]

    with mp.Pool(processes=mp.cpu_count()) as p:
        res = p.map(
            center_point_discretise_grid_inner_loop(
                data,
                t_arr,
                centers_arr,
                input_dim
            ),
            [i for i in itertools.product(*bin_ranges)]
        )
        X, counts = zip(*res)

    X = np.array(X)
    Y = np.array(counts)[:, None]

    return X, Y, binned_x_size_arr


def center_point_st_discretise_grid(data, bin_sizes=[5, 10, 20]):
    """
    data: [time, lat, lon] columns

    Returns:
        the binned counts, the center locations of the bins and the bin sizes
    """
    lin = lambda i: np.linspace(
        np.min(data[:, i]), np.max(data[:, i]), bin_sizes[i]+1
    )

    centers = lambda A: np.array([np.mean([A[i], A[i+1]]) for i in range(A.shape[0] - 1)])

    # get grid cell positions
    t = lin(0)
    x = lin(1)
    y = lin(2)

    # get size of a single cell
    binned_x_sizes = [
        t[1]-t[0],
        x[1]-x[0],
        y[1]-y[0]
    ]

    # get center locations of grid cells
    t_centers = centers(t)
    x_centers = centers(x)
    y_centers = centers(y)

    counts = []
    X = []

    # iterate over each grid cell and count how many occurrences fall within it
    for k in range(bin_sizes[0]):
        for i in range(bin_sizes[1]):
            for j in range(bin_sizes[2]):
                idx = (data[:, 0] >= t[k]) & (data[:, 0] < t[k + 1]) & \
                      (data[:, 1] >= x[i]) & (data[:, 1] < x[i + 1]) & \
                      (data[:, 2] >= y[j]) & (data[:, 2] < y[j + 1])

                N_kij = np.sum(idx)
                X.append([t_centers[k], x_centers[i], y_centers[j]])
                counts.append(N_kij)

    X = np.array(X)
    Y = np.array(counts)[:, None]

    return X, Y, binned_x_sizes
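
# Illustrative sketch (not in the original file): how the spatio-temporal binning above
# is typically used. Random [time, lat, lon] events are counted into a 5 x 10 x 20 grid;
# X holds the bin centres, Y the counts, and the returned bin sizes are the widths of a
# single cell. The helper name below is hypothetical.
def _example_center_point_st_discretise_grid():
    rng = np.random.default_rng(0)
    events = rng.uniform(size=(1000, 3))    # fake [time, lat, lon] events
    X, Y, bin_sizes = center_point_st_discretise_grid(events, bin_sizes=[5, 10, 20])
    assert X.shape == (5 * 10 * 20, 3)
    assert Y.shape == (5 * 10 * 20, 1)
    assert Y.sum() <= events.shape[0]       # events on the upper edges fall outside the half-open bins
    return X, Y, bin_sizes
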

def create_geopandas_spatial_grid(xmin, xmax, ymin, ymax, cell_size_x, cell_size_y, crs=None):
    """
    see https://james-brennan.github.io/posts/fast_gridding_geopandas/
    """
    import geopandas
    import shapely

    # create the cells in a loop
    grid_cells = []
    for x0 in np.arange(xmin, xmax-cell_size_x, cell_size_x):
        for y0 in np.arange(ymin, ymax-cell_size_y, cell_size_y):
            # bounds
            x1 = x0+cell_size_x
            y1 = y0+cell_size_y
            grid_cells.append(shapely.geometry.box(x0, y0, x1, y1))

    cell = geopandas.GeoDataFrame(grid_cells, columns=['geometry'], crs=crs)
    return cell

def pad_with_nan_to_make_grid(X, Y):
    #converts data into grid

    N = X.shape[0]

    #construct target grid
    unique_time = np.unique(X[:, 0])
    unique_space = np.unique(X[:, 1:], axis=0)

    Nt = unique_time.shape[0]
    Ns = unique_space.shape[0]

    print('grid size:', N, Nt, Ns, Nt*Ns)

    X_tmp = np.tile(np.expand_dims(unique_space, 0), [Nt, 1, 1])

    time_tmp = np.tile(unique_time, [Ns]).reshape([Nt, Ns], order='F')

    X_tmp = X_tmp.reshape([Nt*Ns, -1])

    time_tmp = time_tmp.reshape([Nt*Ns, 1])

    #X_tmp is the full grid
    X_tmp = np.hstack([time_tmp, X_tmp])

    #Find the indexes in X_tmp that we need to add to X to make a full grid
    _X = np.vstack([X, X_tmp])
    _Y = np.nan*np.zeros([_X.shape[0], 1])

    _, idx = np.unique(_X, return_index=True, axis=0)
    idx = idx[idx >= N]
    print('unique points: ', idx.shape)

    X_to_add = _X[idx, :]
    Y_to_add = _Y[idx, :]

    X_grid = np.vstack([X, X_to_add])
    Y_grid = np.vstack([Y, Y_to_add])

    #sort for good measure
    _X = np.roll(X_grid, -1, axis=1)
    #sort by time points first
    idx = np.lexsort(_X.T)

    return X_grid[idx], Y_grid[idx]


def train_test_split_indices(N, split=0.5, seed=0):
    np.random.seed(seed)
    rand_index = np.random.permutation(N)

    N_tr = int(N * split)

    return rand_index[:N_tr], rand_index[N_tr:]

def lat_lon_to_polygon_buffer(row, actual_bin_sizes):
    import shapely
    from shapely.geometry import Polygon

    lat = row['Latitude']
    lon = row['Longitude']

    w1 = actual_bin_sizes[1]/2
    w2 = actual_bin_sizes[2]/2

    p1 = [lon-w1, lat-w2]
    p2 = [lon-w1, lat+w2]
    p3 = [lon+w1, lat+w2]
    p4 = [lon+w1, lat-w2]
    return Polygon([p1, p2, p3, p4, p1])
--------------------------------------------------------------------------------
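For reference, the filtering-based models in these experiments need observations on a full time-by-space grid, which is what pad_with_nan_to_make_grid in experiments/utils.py constructs by appending NaN targets for every missing (time, location) pair. A small, self-contained illustration, assuming it is run from the repository root:

import sys

import numpy as np

sys.path.append('experiments')   # so that experiments/utils.py is importable
from utils import pad_with_nan_to_make_grid

# three observations over 2 time points and 2 sites; the (t=1, second site) pair is missing
X = np.array([[0.0, 51.5, -0.1],
              [0.0, 51.6, -0.2],
              [1.0, 51.5, -0.1]])
Y = np.array([[1.0], [2.0], [3.0]])

X_grid, Y_grid = pad_with_nan_to_make_grid(X, Y)
print(X_grid.shape)              # (4, 3): the full 2 x 2 time-by-site grid
print(np.isnan(Y_grid).sum())    # 1: the padded, unobserved combination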