├── .gitignore ├── AWS-S3-Dask-XGBoost.ipynb ├── LICENSE ├── README.md ├── brenowitz-data-loading.ipynb ├── djgagne_partial_dependence_plot.ipynb └── rasp-data-loading.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /AWS-S3-Dask-XGBoost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "toc": true 7 | }, 8 | "source": [ 9 | "

Table of Contents

\n", 10 | "
" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "# Overview" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "**Author**:\n", 25 | "**Zhonghua Zheng** (zzheng25@illinois.edu) \n", 26 | "GitHub: **zzheng93** \n", 27 | "\n", 28 | "**Motivation**:\n", 29 | "This notebook demonstrates a workflow for: \n", 30 | "* Loading Climate Data ([CESM-LE from AWS S3](https://medium.com/pangeo/cesm-lens-on-aws-4e2a996397a1))\n", 31 | "* Scalable [Dask-XGBoost](https://examples.dask.org/machine-learning/xgboost.html) *regression* (the official [Dask-XGBoost](https://examples.dask.org/machine-learning/xgboost.html) documentation only provides a *classification* example).\n", 32 | "\n", 33 | "**Task**:\n", 34 | "How can we predict the **maximum reference height temperature (TREFHTMX)** from related features?\n", 35 | "\n", 36 | "* Response (Y): \n", 37 | "\"TREFHTMX\": Maximum reference height temperature over output period \n", 38 | "\n", 39 | "* Features (X): \n", 40 | "\"PRECT\": Total (convective and large-scale) precipitation rate (liq + ice) \n", 41 | "\"WSPDSRFAV\": Horizontal total wind speed average at the surface \n", 42 | "\"TS\": Surface temperature (radiative) \n", 43 | "\"TREFHT\": Reference height temperature \n", 44 | "\n", 45 | "**Prerequisite** \n", 46 | "How do you create the environment for this workflow? 
\n", 47 | "Recommendation: \n", 48 | "```bash\n", 49 | "git clone https://github.com/dask/dask-tutorial\n", 50 | "cd dask-tutorial\n", 51 | "conda env create -f binder/environment.yml \n", 52 | "conda activate dask-tutorial\n", 53 | "conda install xarray cftime=1.0.3.4\n", 54 | "conda install -c conda-forge dask-ml dask-xgboost\n", 55 | "```\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Load libraries" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "Here we load the necessary libraries, connect to a Dask cluster, and set up anonymous access to S3." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 1, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "import dask\n", 79 | "import dask.array as da\n", 80 | "import dask.dataframe as dd\n", 81 | "import dask_xgboost\n", 82 | "import matplotlib.pyplot as plt\n", 83 | "import numpy as np\n", 84 | "import s3fs\n", 85 | "import xarray as xr\n", 86 | "import xgboost\n", 87 | "from dask.distributed import Client\n", 88 | "from dask_ml.model_selection import train_test_split\n", 89 | "from sklearn import metrics\n", 90 | "\n", 91 | "# handles used by later cells: `client` for training/prediction, `s3` for data loading\n", 92 | "client = Client() # connect to cluster\n", 93 | "s3 = s3fs.S3FileSystem(anon=True) # anonymous access to s3" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 2, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/html": [ 104 | "\n", 105 | "\n", 106 | "
\n", 107 | "

Client

\n", 108 | "\n", 112 | "
\n", 114 | "

Cluster

\n", 115 | "
    \n", 116 | "
  • Workers: 4
  • \n", 117 | "
  • Cores: 16
  • \n", 118 | "
  • Memory: 32.66 GB
  • \n", 119 | "
\n", 120 | "
" 123 | ], 124 | "text/plain": [ 125 | "" 126 | ] 127 | }, 128 | "execution_count": 2, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "client" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "# Load Data" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "Here we load the CESM-LE RCP8.5 data from [AWS S3](https://medium.com/pangeo/cesm-lens-on-aws-4e2a996397a1). \n", 149 | "The data are stored in [Zarr](https://zarr.readthedocs.io/en/stable/) format, but Dask-XGBoost only accepts Dask array or Dask DataFrame collections.\n", 150 | "\n", 151 | "So we use [xarray](http://xarray.pydata.org/en/stable/) to load the data and convert it to a [Dask DataFrame](https://docs.dask.org/en/latest/dataframe.html). \n", 152 | "\n", 153 | "As an illustration, this tutorial only uses `member_id=1` and `time=\"2010-01-01\"`.\n" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 3, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "CPU times: user 6.55 s, sys: 308 ms, total: 6.85 s\n", 166 | "Wall time: 1min 34s\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "%%time\n", 172 | "zarr_ls = [xr.open_zarr(s3fs.S3Map(root=f\"ncar-cesm-lens/atm/daily/cesmLE-RCP85-{var}.zarr\", s3=s3))[var]\n", 173 | " .sel(member_id=1, time=\"2010-01-01\")\n", 174 | " for var in (\"TREFHT\", \"TREFHTMX\", \"PRECT\", \"TS\", \"WSPDSRFAV\")]  # read one slice of each variable from s3" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 4, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "dddf = xr.merge(zarr_ls).to_dask_dataframe()  # merge arrays and convert to dask dataframe" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 5, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [
192 | "X = dddf[[\"TREFHT\", \"PRECT\", \"TS\", \"WSPDSRFAV\"]]  # features for Dask-XGBoost\n", 193 | "Y = dddf[\"TREFHTMX\"]  # response for Dask-XGBoost" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 6, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.15, random_state=0)  # fixed seed so the split (and the metrics below) are reproducible" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "# Train the XGB model" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "These parameters are described in [Learning Task Parameters](https://xgboost.readthedocs.io/en/latest/parameter.html).\n", 217 | "* **reg:squarederror**: regression with squared loss.\n", 218 | "* **eta (learning_rate)**: Step size shrinkage used in updates to prevent overfitting. After each boosting step, we can directly get the weights of new features, and eta shrinks the feature weights to make the boosting process more conservative. \n", 219 | "* **max_depth**: Maximum depth of a tree. Increasing this value will make the model more complex and more likely to overfit. 0 is only accepted in lossguided growing policy when tree_method is set as hist and it indicates no limit on depth. Beware that XGBoost aggressively consumes memory when training a deep tree. \n", 220 | "* **num_boost_round**: Number of boosting iterations (see the [XGBoost Python API](https://xgboost.readthedocs.io/en/latest/python/python_api.html))."
221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 7, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "CPU times: user 306 ms, sys: 57.8 ms, total: 364 ms\n", 233 | "Wall time: 13 s\n" 234 | ] 235 | } 236 | ], 237 | "source": [ 238 | "%%time\n", 239 | "params = dict(objective='reg:squarederror',\n", 240 | " max_depth=6, eta=0.01)\n", 241 | "\n", 242 | "bst = dask_xgboost.train(client, params, X_train, y_train, num_boost_round=500)  # train on the Dask cluster via `client`" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "## Determine feature importance" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "In general, importance is a score indicating how useful or valuable each feature was when building the boosted decision trees of the model. The more an attribute is used to make key decisions in the trees, the higher its relative importance (reference: [Feature Importance and Feature Selection With XGBoost in Python](https://machinelearningmastery.com/feature-importance-and-feature-selection-with-xgboost-in-python/))." 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 8, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "image/png": "\n", 267 | "text/plain": [ 268 | "
" 269 | ] 270 | }, 271 | "metadata": { 272 | "needs_background": "light" 273 | }, 274 | "output_type": "display_data" 275 | } 276 | ], 277 | "source": [ 278 | "%matplotlib inline\n", 279 | "ax = xgboost.plot_importance(bst)\n", 280 | "ax.grid(False, axis=\"y\")\n", 281 | "ax.set(title=\"Estimated feature importance\")\n", 282 | "plt.show()" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "# Test model performance" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "Here we evaluate the XGBoost emulator on the held-out test data" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 9, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "y_hat = dask_xgboost.predict(client, bst, X_test).persist()  # predictions for the held-out features\n", 306 | "y_test, y_hat = dask.compute(y_test, y_hat)  # materialize true and predicted values" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 10, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "data": { 316 | "image/png": "\n", 317 | "text/plain": [ 318 | "
" 319 | ] 320 | }, 321 | "metadata": { 322 | "needs_background": "light" 323 | }, 324 | "output_type": "display_data" 325 | }, 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "rmse: 2.5904884\n", 331 | "r2_score: 0.98544579149015\n" 332 | ] 333 | } 334 | ], 335 | "source": [ 336 | "fig, ax = plt.subplots(figsize=(5, 5))\n", 337 | "ax.scatter(y_hat, y_test, s=0.1)\n", 338 | "ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], c=\"black\")  # 1:1 line (perfect prediction)\n", 339 | "ax.set(title=\"TREFHTMX: predicted vs. true\",\n", 340 | " xlabel=\"Predicted\",\n", 341 | " ylabel=\"True\",\n", 342 | ")\n", 343 | "plt.show()\n", 344 | "print(\"rmse:\", np.sqrt(metrics.mean_squared_error(y_test, y_hat)))\n", 345 | "print(\"r2_score:\", metrics.r2_score(y_test, y_hat))" 346 | ] 347 | } 348 | ], 349 | "metadata": { 350 | "kernelspec": { 351 | "display_name": "Python 3", 352 | "language": "python", 353 | "name": "python3" 354 | }, 355 | "language_info": { 356 | "codemirror_mode": { 357 | "name": "ipython", 358 | "version": 3 359 | }, 360 | "file_extension": ".py", 361 | "mimetype": "text/x-python", 362 | "name": "python", 363 | "nbconvert_exporter": "python", 364 | "pygments_lexer": "ipython3", 365 | "version": "3.7.4" 366 | }, 367 | "toc": { 368 | "base_numbering": 1, 369 | "nav_menu": {}, 370 | "number_sections": false, 371 | "sideBar": true, 372 | "skip_h1_title": false, 373 | "title_cell": "Table of Contents", 374 | "title_sidebar": "Contents", 375 | "toc_cell": true, 376 | "toc_position": {}, 377 | "toc_section_display": true, 378 | "toc_window_display": false 379 | } 380 | }, 381 | "nbformat": 4, 382 | "nbformat_minor": 2 383 | } 384 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ml-workflow-examples 2 | Simple examples of data pipelines from xarray to ML training 3 | -------------------------------------------------------------------------------- /brenowitz-data-loading.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Here is a toy xarray dataset. It has a few 3D and 2D variables." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "data": { 17 | "text/plain": [ 18 | "\n", 19 | "Dimensions: (time: 4, x: 2, y: 5, z: 4)\n", 20 | "Dimensions without coordinates: time, x, y, z\n", 21 | "Data variables:\n", 22 | " a (time, z, y, x) float64 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0 1.0\n", 23 | " b (time, y, x) float64 1.0 1.0 1.0 1.0 1.0 ... 
1.0 1.0 1.0 1.0 1.0" 24 | ] 25 | }, 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "output_type": "execute_result" 29 | } 30 | ], 31 | "source": [ 32 | "from itertools import product\n", 33 | "import numpy as np\n", 34 | "import torch\n", 35 | "import xarray as xr\n", 36 | "from torch.utils.data import Dataset, DataLoader\n", 37 | "\n", 38 | "class XRTimeSeries(Dataset):\n", 39 | " \"\"\"A pytorch Dataset class for time series data in xarray format\n", 40 | "\n", 41 | " This class assumes the data has dimensions ['time', 'z', 'y', 'x'], and\n", 42 | " that the axes of the data arrays are all stored in that order.\n", 43 | "\n", 44 | " An individual \"sample\" is a ``time_length``-long time series from a single\n", 45 | " horizontal location. Variables with a z dimension have shape\n", 46 | " (time_length, z, 1, 1) in a sample; the remaining non-constant variables have shape (time_length, 1, 1, 1).\n", 47 | "\n", 48 | " Examples\n", 49 | " --------\n", 50 | " >>> ds = xr.open_dataset(\"in.nc\")\n", 51 | " >>> dataset = XRTimeSeries(ds)\n", 52 | " >>> dataset[0]\n", 53 | "\n", 54 | " \"\"\"\n", 55 | " dims = ['time', 'z', 'y', 'x'] # expected axis order; __init__ shadows this with a per-variable dict\n", 56 | "\n", 57 | " def __init__(self, data, time_length=None):\n", 58 | " \"\"\"\n", 59 | " Parameters\n", 60 | " ----------\n", 61 | " data : xr.Dataset\n", 62 | " An input dataset.
This dataset must contain at least some variables\n", 63 | " with all of the dimensions ['time', 'z', 'y', 'x'].\n", 64 | " time_length : int, optional\n", 65 | " The length of the time sequences to use, must evenly divide the\n", 66 | " total number of time points. Defaults to the full time axis.\n", 67 | " \"\"\"\n", 68 | " self.time_length = time_length or len(data.time)\n", 69 | " self.data = data\n", 70 | " self.numpy_data = {key: data[key].values for key in data.data_vars}\n", 71 | " self.data_vars = set(data.data_vars)\n", 72 | " self.dims = {key: data[key].dims for key in data.data_vars}\n", 73 | " self.constants = {\n", 74 | " key\n", 75 | " for key in data.data_vars\n", 76 | " if len({'x', 'y', 'time'} & set(data[key].dims)) == 0\n", 77 | " }\n", 78 | " self.setup_indices()\n", 79 | "\n", 80 | " def setup_indices(self):\n", 81 | " len_x = len(self.data['x'].values)\n", 82 | " len_y = len(self.data['y'].values)\n", 83 | " len_t = len(self.data['time'].values)\n", 84 | "\n", 85 | " x_iter = range(0, len_x, 1)\n", 86 | " y_iter = range(0, len_y, 1)\n", 87 | " t_iter = range(0, len_t, self.time_length)\n", 88 | " assert len_t % self.time_length == 0, \"time_length must evenly divide the number of time points\"\n", 89 | " self.indices = list(product(t_iter, y_iter, x_iter))\n", 90 | "\n", 91 | " def __len__(self):\n", 92 | " return len(self.indices)\n", 93 | "\n", 94 | " def __getitem__(self, i):\n", 95 | " t, y, x = self.indices[i]\n", 96 | " output_tensors = {}\n", 97 | " for key in self.data_vars:\n", 98 | " if key in self.constants:\n", 99 | " continue\n", 100 | "\n", 101 | " data_array = self.numpy_data[key]\n", 102 | " if 'z' in self.dims[key]:\n", 103 | " this_array_index = (slice(t, t + self.time_length),\n", 104 | " slice(None), y, x)\n", 105 | " else:\n", 106 | " this_array_index = (slice(t, t + self.time_length), None, y, x)  # None inserts a singleton z axis\n", 107 | "\n", 108 | " sample = data_array[this_array_index][:, :, np.newaxis, np.newaxis]\n", 109 | " output_tensors[key] = sample.astype(np.float32)\n", 110 | " return output_tensors\n", 111 | "\n", 112 | " 
@property\n", 113 | " def time_dim(self):\n", 114 | " return 'time' # __init__ and setup_indices hard-code 'time' as the time dimension; self.dims is a dict, not a list\n", 115 | "\n", 116 | " def torch_constants(self):\n", 117 | " return {\n", 118 | " key: torch.tensor(self.data[key].values, requires_grad=False)\n", 119 | " .float()\n", 120 | " for key in self.constants\n", 121 | " }\n", 122 | "\n", 123 | " @property\n", 124 | " def scale(self):\n", 125 | " std = self.std # NOTE(review): `self.std` is never assigned in this class -- TODO define it before using `scale`\n", 126 | " return {key: x.max() for key, x in std.items()}\n", 127 | "\n", 128 | "\n", 129 | "def get_xarray_dataset():\n", 130 | " \"\"\"Build a toy Dataset: 'a' with dims (time, z, y, x) and 'b' with dims (time, y, x), all ones.\"\"\"\n", 131 | " dims_3d = ['time', 'z', 'y', 'x']\n", 132 | " dims_2d = ['time', 'y', 'x']\n", 133 | "\n", 134 | " data_3d = np.ones((4, 4, 5, 2))\n", 135 | " data_2d = np.ones((4, 5, 2))\n", 136 | "\n", 137 | " return xr.Dataset({\n", 138 | " 'a': (dims_3d, data_3d),\n", 139 | " 'b': (dims_2d, data_2d)\n", 140 | " })\n", 141 | "\n", 142 | "ds = get_xarray_dataset()\n", 143 | "ds" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 2, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "torch_dataset = XRTimeSeries(ds, time_length=4)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 3, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "10" 164 | ] 165 | }, 166 | "execution_count": 3, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "len(torch_dataset)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "The length of the dataset is $n_t \\cdot n_y \\cdot n_x$, where $n_t$ is the number of `time_length` chunks (here $1 \\cdot 5 \\cdot 2 = 10$)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 4, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "data": { 189 | "text/plain": [ 190 | "{'a': array([[[[1.]],\n", 191 | " \n", 192 | " [[1.]],\n", 193 | " \n", 194 | " [[1.]],\n", 195 | " \n", 196 | " [[1.]]],\n", 197 | " \n", 198 | " \n", 199 | " [[[1.]],\n", 200 | " \n", 201 | " [[1.]],\n", 202 | " \n", 203 | " [[1.]],\n", 204 | "
\n", 205 | " [[1.]]],\n", 206 | " \n", 207 | " \n", 208 | " [[[1.]],\n", 209 | " \n", 210 | " [[1.]],\n", 211 | " \n", 212 | " [[1.]],\n", 213 | " \n", 214 | " [[1.]]],\n", 215 | " \n", 216 | " \n", 217 | " [[[1.]],\n", 218 | " \n", 219 | " [[1.]],\n", 220 | " \n", 221 | " [[1.]],\n", 222 | " \n", 223 | " [[1.]]]], dtype=float32), 'b': array([[[[1.]]],\n", 224 | " \n", 225 | " \n", 226 | " [[[1.]]],\n", 227 | " \n", 228 | " \n", 229 | " [[[1.]]],\n", 230 | " \n", 231 | " \n", 232 | " [[[1.]]]], dtype=float32)}" 233 | ] 234 | }, 235 | "execution_count": 4, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "sample = torch_dataset[0]\n", 242 | "sample" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "Two dimensional and three dimensional variables have broadcastable shapes" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 5, 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "text/plain": [ 260 | "(4, 1, 1, 1)" 261 | ] 262 | }, 263 | "execution_count": 5, 264 | "metadata": {}, 265 | "output_type": "execute_result" 266 | } 267 | ], 268 | "source": [ 269 | "sample['b'].shape" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 6, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "text/plain": [ 280 | "(4, 4, 1, 1)" 281 | ] 282 | }, 283 | "execution_count": 6, 284 | "metadata": {}, 285 | "output_type": "execute_result" 286 | } 287 | ], 288 | "source": [ 289 | "sample['a'].shape" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "Now that we have made the torch dataset object, we can pass it to pytorch's DataLoader class." 
297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 7, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "train_loader = DataLoader(torch_dataset, batch_size=4)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 8, 311 | "metadata": {}, 312 | "outputs": [ 313 | { 314 | "name": "stdout", 315 | "output_type": "stream", 316 | "text": [ 317 | "shape of b torch.Size([4, 4, 1, 1, 1])\n", 318 | "shape of b torch.Size([4, 4, 1, 1, 1])\n", 319 | "shape of b torch.Size([2, 4, 1, 1, 1])\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "for batch in train_loader:\n", 325 | " print(\"shape of b\", batch['b'].shape)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "The first dimension becomes the \"batch\" dimension. The other dimensions are the physical dimensions (time, z, y, x). My [model classes](https://github.com/nbren12/uwnet/blob/047a63b70985b12e17013355ecd25c908681ab76/uwnet/modules.py) accept data in this format." 
333 | ] 334 | } 335 | ], 336 | "metadata": { 337 | "kernelspec": { 338 | "display_name": "Python 3", 339 | "language": "python", 340 | "name": "python3" 341 | }, 342 | "language_info": { 343 | "codemirror_mode": { 344 | "name": "ipython", 345 | "version": 3 346 | }, 347 | "file_extension": ".py", 348 | "mimetype": "text/x-python", 349 | "name": "python", 350 | "nbconvert_exporter": "python", 351 | "pygments_lexer": "ipython3", 352 | "version": "3.6.6" 353 | } 354 | }, 355 | "nbformat": 4, 356 | "nbformat_minor": 2 357 | } 358 | -------------------------------------------------------------------------------- /djgagne_partial_dependence_plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Partial Dependence Plot Example\n", 8 | "David John Gagne\n", 9 | "\n", 10 | "The goal of this notebook is to show an example of a serial and potentially parallel partial dependence plot in order to figure out ways to scale this better with Pangeo tools." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 4, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "import numpy as np\n", 21 | "import pandas as pd\n", 22 | "from sklearn.ensemble import RandomForestRegressor\n", 23 | "from sklearn.linear_model import Ridge\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from dask.distributed import LocalCluster, Client\n", 26 | "import os\n", 27 | "from os.path import exists, join\n", 28 | "from urllib.request import urlretrieve\n", 29 | "import tarfile\n", 30 | "import glob" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "Download CSV Data." 
38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "Get csv files\n", 50 | "Extract csv tar file\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "csv_tar_file = \"https://storage.googleapis.com/track_data_ncar_ams_3km_csv_small/track_data_ncar_ams_3km_csv_small.tar.gz\"\n", 56 | "tar_path = join(\"tornado_data\", csv_tar_file.split(\"/\")[-1])\n", 57 | "os.makedirs(\"tornado_data\", exist_ok=True)\n", 58 | "if not exists(tar_path):  # download only once; re-runs reuse the local archive\n", 59 | "    print(\"Get csv files\")\n", 60 | "    urlretrieve(csv_tar_file, tar_path)\n", 61 | "print(\"Extract csv tar file\")\n", 62 | "with tarfile.open(tar_path) as csv_tar:  # closed even if extraction fails; NOTE: extractall trusts member paths\n", 63 | "    csv_tar.extractall(\"tornado_data/\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "Load CSV data" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 11, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "['tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20101024-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20101122-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110201-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110308-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110326-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110404-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110414-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110420-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110425-0000.csv', 
'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110509-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110522-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110527-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110605-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110610-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110615-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110620-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110625-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110704-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20110712-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20111116-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120218-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120315-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120323-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120401-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120409-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120426-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120503-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120510-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120529-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120606-0000.csv', 
'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120622-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120701-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120706-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20120715-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20121225-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130318-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130331-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130411-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130429-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130513-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130519-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130527-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130602-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130613-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130619-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130625-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130701-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130708-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20130715-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140220-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140328-0000.csv', 
'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140407-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140425-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140508-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140514-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140526-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140604-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140609-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140617-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140622-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140628-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140705-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20140710-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20141123-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150331-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150416-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150422-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150505-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150510-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150523-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150528-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150605-0000.csv', 
'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150612-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150620-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150625-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150630-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150706-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20150712-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20151031-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20151227-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160224-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160323-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160401-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160415-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160429-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160505-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160511-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160522-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160528-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160608-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160616-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160622-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160628-0000.csv', 
'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160707-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20160712-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20161129-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20170121-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20170228-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20170323-0000.csv', 'tornado_data/track_data_ncar_ams_3km_csv_small/track_step_NCARSTORM_d01_20170329-0000.csv']\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "path = \"tornado_data/track_data_ncar_ams_3km_csv_small/\"\n", 88 | "files = sorted(glob.glob(path+\"/*.csv\"))\n", 89 | "print(files)\n", 90 | "df = pd.concat([pd.read_csv(f, parse_dates=[\"Run_Date\", \"Valid_Date\"]) for f in files], ignore_index=True)\n" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 9, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Step_ID\n", 103 | "Track_ID\n", 104 | "Ensemble_Name\n", 105 | "Ensemble_Member\n", 106 | "Run_Date\n", 107 | "Valid_Date\n", 108 | "Forecast_Hour\n", 109 | "Valid_Hour_UTC\n", 110 | "Duration\n", 111 | "Centroid_Lon\n", 112 | "Centroid_Lat\n", 113 | "Centroid_X\n", 114 | "Centroid_Y\n", 115 | "Storm_Motion_U\n", 116 | "Storm_Motion_V\n", 117 | "REFL_COM_mean\n", 118 | "REFL_COM_max\n", 119 | "REFL_COM_min\n", 120 | "REFL_COM_std\n", 121 | "REFL_COM_percentile_10\n", 122 | "REFL_COM_percentile_25\n", 123 | "REFL_COM_percentile_50\n", 124 | "REFL_COM_percentile_75\n", 125 | "REFL_COM_percentile_90\n", 126 | "U10_mean\n", 127 | "U10_max\n", 128 | "U10_min\n", 129 | "U10_std\n", 130 | "U10_percentile_10\n", 131 | "U10_percentile_25\n", 132 | "U10_percentile_50\n", 133 | "U10_percentile_75\n", 134 | "U10_percentile_90\n", 135 | 
"V10_mean\n", 136 | "V10_max\n", 137 | "V10_min\n", 138 | "V10_std\n", 139 | "V10_percentile_10\n", 140 | "V10_percentile_25\n", 141 | "V10_percentile_50\n", 142 | "V10_percentile_75\n", 143 | "V10_percentile_90\n", 144 | "T2_mean\n", 145 | "T2_max\n", 146 | "T2_min\n", 147 | "T2_std\n", 148 | "T2_percentile_10\n", 149 | "T2_percentile_25\n", 150 | "T2_percentile_50\n", 151 | "T2_percentile_75\n", 152 | "T2_percentile_90\n", 153 | "RVORT1_MAX-future_mean\n", 154 | "RVORT1_MAX-future_max\n", 155 | "RVORT1_MAX-future_min\n", 156 | "RVORT1_MAX-future_std\n", 157 | "RVORT1_MAX-future_percentile_10\n", 158 | "RVORT1_MAX-future_percentile_25\n", 159 | "RVORT1_MAX-future_percentile_50\n", 160 | "RVORT1_MAX-future_percentile_75\n", 161 | "RVORT1_MAX-future_percentile_90\n", 162 | "HAIL_MAXK1-future_mean\n", 163 | "HAIL_MAXK1-future_max\n", 164 | "HAIL_MAXK1-future_min\n", 165 | "HAIL_MAXK1-future_std\n", 166 | "HAIL_MAXK1-future_percentile_10\n", 167 | "HAIL_MAXK1-future_percentile_25\n", 168 | "HAIL_MAXK1-future_percentile_50\n", 169 | "HAIL_MAXK1-future_percentile_75\n", 170 | "HAIL_MAXK1-future_percentile_90\n", 171 | "area\n", 172 | "eccentricity\n", 173 | "major_axis_length\n", 174 | "minor_axis_length\n", 175 | "orientation\n", 176 | "Matched\n", 177 | "Max_Hail_Size\n", 178 | "Num_Matches\n", 179 | "Shape\n", 180 | "Location\n", 181 | "Scale\n" 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | "for col in df.columns:\n", 187 | " print(col)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "We are using reflectivity, u-wind, v-wind, and 2 m temperature to predict vorticity for a given storm in a dataset of storms." 
195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 12, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "input_cols = [\"REFL_COM_mean\", \"U10_mean\", \"V10_mean\", \"T2_mean\"]\n", 204 | "output_col = [\"RVORT1_MAX-future_max\"]\n", 205 | "split_date = pd.Timestamp(\"2015-01-01\")\n", 206 | "train_in = df.loc[df[\"Run_Date\"] < split_date, input_cols]\n", 207 | "train_out = df.loc[df[\"Run_Date\"] < split_date, output_col]\n", 208 | "test_in = df.loc[df[\"Run_Date\"] >= split_date, input_cols]\n", 209 | "test_out = df.loc[df[\"Run_Date\"]>= split_date, output_col]\n" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 16, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "(76377, 1)" 221 | ] 222 | }, 223 | "execution_count": 16, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "train_out.shape" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "Train the random forest" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 17, 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=8,\n", 248 | " max_features='auto', max_leaf_nodes=None,\n", 249 | " min_impurity_decrease=0.0, min_impurity_split=None,\n", 250 | " min_samples_leaf=1, min_samples_split=2,\n", 251 | " min_weight_fraction_leaf=0.0, n_estimators=50, n_jobs=None,\n", 252 | " oob_score=False, random_state=None, verbose=0, warm_start=False)" 253 | ] 254 | }, 255 | "execution_count": 17, 256 | "metadata": {}, 257 | "output_type": "execute_result" 258 | } 259 | ], 260 | "source": [ 261 | "rf = RandomForestRegressor(n_estimators=50, max_depth=8)\n", 262 | "rf.fit(train_in, train_out.values.ravel())" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": 
{}, 268 | "source": [ 269 | "We want to interpret the input sensitivities of the random forest. We will use partial dependence plots: one input variable is fixed at each of a series of values across all examples, and the model's mean prediction is recorded at each value. This reveals which ranges of that input the model is sensitive to." 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 18, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "def partial_dependence_1d(x, model, var_index, var_vals):\n", 279 | "    \"\"\"\n", 280 | "    Calculate how the mean prediction of an ML model varies when one variable's value is fixed across all input examples.\n", 281 | "\n", 282 | "    Args:\n", 283 | "        x: 2D array-like (e.g. a DataFrame) of input variables; columns are variables\n", 284 | "        model: scikit-learn style model object with a predict method\n", 285 | "        var_index: column index of the variable being investigated\n", 286 | "        var_vals: 1D array-like (array, list, ...) of values at which the variable is fixed\n", 287 | "\n", 288 | "    Returns:\n", 289 | "        Array of partial dependence values, one per entry of var_vals.\n", 290 | "    \"\"\"\n", 291 | "    var_vals = np.asarray(var_vals)  # accept lists/tuples as well as arrays\n", 292 | "    partial_dependence = np.zeros(var_vals.shape)\n", 293 | "    x_copy = np.array(x, dtype=np.float64)  # float copy: input untouched, fractional values not truncated\n", 294 | "    for v, var_val in enumerate(var_vals):\n", 295 | "        x_copy[:, var_index] = var_val\n", 296 | "        partial_dependence[v] = model.predict(x_copy).mean()\n", 297 | "    return partial_dependence" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 19, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "text/html": [ 308 | "
\n", 309 | "\n", 322 | "\n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | "
REFL_COM_meanU10_meanV10_meanT2_mean
count76377.00000076377.00000076377.00000076377.000000
mean46.8465150.4497490.542538289.408528
std3.9840914.3476254.4619366.931555
min40.082820-18.653700-20.561280262.921600
25%43.755930-2.366580-2.301950285.114260
50%46.1706500.3968100.659590290.625850
75%49.1281903.2719803.496990294.398190
max68.74316019.68664018.616930312.149230
\n", 391 | "
" 392 | ], 393 | "text/plain": [ 394 | " REFL_COM_mean U10_mean V10_mean T2_mean\n", 395 | "count 76377.000000 76377.000000 76377.000000 76377.000000\n", 396 | "mean 46.846515 0.449749 0.542538 289.408528\n", 397 | "std 3.984091 4.347625 4.461936 6.931555\n", 398 | "min 40.082820 -18.653700 -20.561280 262.921600\n", 399 | "25% 43.755930 -2.366580 -2.301950 285.114260\n", 400 | "50% 46.170650 0.396810 0.659590 290.625850\n", 401 | "75% 49.128190 3.271980 3.496990 294.398190\n", 402 | "max 68.743160 19.686640 18.616930 312.149230" 403 | ] 404 | }, 405 | "execution_count": 19, 406 | "metadata": {}, 407 | "output_type": "execute_result" 408 | } 409 | ], 410 | "source": [ 411 | "train_in.describe()" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": {}, 417 | "source": [ 418 | "Here is an example for a single variable." 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 41, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "pd_count = 50\n", 428 | "index = 0\n", 429 | "var_vals = np.linspace(train_in.iloc[:, index].min(), train_in.iloc[:, index].max(), pd_count)  # grid over the observed range of column `index`\n", 430 | "pd_vals = partial_dependence_1d(train_in, rf, index, var_vals)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 43, 436 | "metadata": {}, 437 | "outputs": [ 438 | { 439 | "name": "stdout", 440 | "output_type": "stream", 441 | "text": [ 442 | "4.99 s ± 236 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" 443 | ] 444 | } 445 | ], 446 | "source": [ 447 | "%timeit pd_vals = partial_dependence_1d(train_in, rf, index, var_vals)" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 45, 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "data": { 457 | "text/plain": [ 458 | "Text(0.5, 0, 'REFL_COM_mean')" 459 | ] 460 | }, 461 | "execution_count": 45, 462 | "metadata": {}, 463 | "output_type": "execute_result" 464 | }, 465 | { 466 | "data": { 467 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAELCAYAAAA2mZrgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xt4VfWd7/H3dydcEgSKkToIJEFl6KCdemGY09I6M/qoKPPItMNUMB45DphjFQfrzOlRUnWGR7S251TpGUWDgFQyAmov6NQrerxwRjRMoQhKJ8UEIjpSQSwGEpN8zx977XRnZ+2wd0iyL/m8nicPe//Wb631+7ni+mat383cHREREYBIpgsgIiLZQ0FBREQ6KCiIiEgHBQUREemgoCAiIh0UFEREpIOCgoiIdFBQEBGRDgoKIiLSoTDTBUjHSSed5OXl5ZkuhohIztiyZctv3X10qvlzKiiUl5dTW1ub6WKIiOQMM2tIJ79eH4mISAcFBRER6aCgICIiHRQURESkg4KCiIh0UFAQEclSNTU1lJeXE4lEKC8vp6amps/PmVNdUkVEBoqamhoqKytpamoCoKGhgcrKSgAqKir67Lx6UhARyUJVVVUdASGmqamJqqqqPj2vgoKISBbas2dPWum9RUFBRCQLlZaWppXeW1IKCmY23cx2mVmdmd0csn2Ima0Ltm82s/K4bbcE6bvM7OK49M+Z2eNm9o6ZvW1mX+6NComI5IPbb78dM+uUVlxczJIlS/r0vMcMCmZWANwHXAJMBuaY2eSEbPOAg+5+OnAPcHew72RgNnAGMB24PzgewFLgGXf/AvAl4O3jr46ISH7Yt28f7s7JJ5+MmVFWVkZ1dXWfNjJDar2PpgJ17r4bwMzWAjOBnXF5ZgL/GHx+HPhni4a4mcBad28G3jWzOmCqme0AzgP+G4C7twAtx10bEZE88MEHH3DXXXfxjW98gyeeeKJfz53K66OxwN64741BWmged28FDgEl3ex7KrAfWGVmvzSzh8xsWI9qICKSZ2699VZaWlq4++67+/3cqQQFC0nzFPMkSy8EzgGWufvZwKdAl7YKADOrNLNaM6vdv39/CsUVEcld27ZtY8WKFdxwww2cfvrp/X7+VIJCIzA+7vs4YF+yPGZWCIwEDnSzbyPQ6O6bg/THiQaJLty92t2nuPuU0aNTXidCRCTnuDt///d/z6hRo/jud7+bkTKkEhTeBCaa2QQzG0y04XhDQp4NwNzg8yzgRXf3IH120DtpAjAReMPdPwD2mtmkYJ8L6NxGISIyYMSmsygoKGDjxo3MmDGDUaNGZaQsx2xodvdWM1sAPAsUACvdfYeZLQZq3X0DsAJ4JGhIPkA0cBDkW0/0ht8KXO/ubcGhbwBqgkCzG7i6l+smIpL1EqezAHjiiSe4+OKL+7ynURiL/kGfG6ZMmeJajlNE8kl5eTkNDV1XzCwrK6O+vv
64j29mW9x9Sqr5NaJZRCSDMjWdRTIKCiIiGTRu3LjQ9L6eziIZBQURkQxx99Cbf39MZ5GMgoKISIY88MADbNq0iVmzZlFWVtav01kko0V2REQy4I033mDhwoVceumlrFu3jkgkO/5Gz45SiIgMAPHLa37lK19h5MiRPPLII1kTEEBBQUQGmEysexw7b2VlJQ0NDbg7bW1tHD58mKeffrpfzp8qjVMQkQEjbKBYcXFxv7zD7+vxCMmkO05BQUFEBoxM3ZgBIpEIYfdbM6O9vb3PzqvBayIiSWRyoNiJJ54Ymp6p8QjJKCiIyIAxfvz40PS+vjE/9thjfPTRR10alDM5HiEZBQURGTDOP//8LmlFRUV9emN+8sknueKKK5g2bRrLly/PmvEIyahNQUQGhN/97necfvrpjBo1iiNHjrB3717cnRkzZvDUU0/16rlqamqoqqpiz549uDsTJkxg69atjBgxolfPkwq1KYiIhPj+97/Phx9+yI9//GMaGhpob29n3rx5PPPMM2zfvr3XzpPY9RSiay4/+eSTvXaOvqQnBRHJe++99x4TJ07ksssuY+3atR3pH330EZMmTeILX/gCr7zySq8MIstkD6cwelIQEUlw22230drayp133tkpvaSkhB/84Ads2rSJVatWHfd5WltbQwMCZG4q7HQpKIhIXtu+fTurVq1iwYIFnHrqqV22z507l6997WssXLiQ8ePH93ik8wcffMCFF16YdHu2dT1NRkFBRPLad77zHUaOHMl3v/vd0O2RSIQZM2bw6aef0tjYiLvT0NBAZWVlt4EhfrqMP/iDP2DSpEls3ryZa6+9luLi4k55s7HraVLunjM/5557rouIxKxZs8bLysrczLysrMzXrFnTJR3wOXPmdHucsrIyB7r8xI6ZeI41a9Z4cXFxp7xm5nfddVe35coEoNbTuM+qoVlEsl58F8/S0tKOv7rD5jGaO3cuq1evTmt+o2RTUAAUFhbS2tra6buZ8dlnn3XJm6nG5O5o7iMRyVmp3vyHDh3KkCFDOHToUMrH7u6GnazHULr6eh6jnlBQEJGcFDaD6ZAhQygoKOiU1lPd3bCTzZ6a7nnz4UlBDc0ikhWqqqq63ISbm5vTvjEXFBSEpnfX+6eiooLq6uouU1CUlZWF5i8pKcntxuRuKCiISMY1NTWl/fom2Y25srKyRzfsiooK6uvraW9vp76+noqKCpYsWRJ6rKVLl4YGkWybx6hH0mmVzvSPeh+J5If43jmjR4/2k046KbT3D+AlJSVdevoUFxd39AI6Vu+j4+39k009iXqCNHsfZfxGn86PgoJI9kn3xpysO+fMmTPTvvnLsSkoiEi/CbvBDx061OfMmeNDhw7tlD5o0CC/4IILuqRzjDEBcnzSDQrqfSQiPZZuV85IJJK0B1A2dufMB+p9JCL9Jt1J3tw9aY+eXJkbKN+lFBTMbLqZ7TKzOjO7OWT7EDNbF2zfbGblcdtuCdJ3mdnFcen1ZrbdzLaamf78F8lBJ510Umh6d91Ck/XoyYfunPngmEHBzAqA+4BLgMnAHDObnJBtHnDQ3U8H7gHuDvadDMwGzgCmA/cHx4v5C3c/K51HGxHJDnV1dRw+fBgz65R+rG6hycYE5EV3zjyQypPCVKDO3Xe7ewuwFpiZkGcmsDr4/DhwgUV/U2YCa9292d3fBeqC44lIDjt8+DBf//rXKSoq4oc//GGXG/z999/f7Y0/bEyAZIfCFPKMBfbGfW8E/jRZHndvNbNDQEmQ/nrCvmODzw48F8xi+KC7V4ed3MwqgUrQO0eRTIqfl6ioqIimpiaeffZZLrroIm688cYu+SsqKnSzz0GpBAULSUvsspQsT3f7TnP3fWb2eeB5M3vH3V/pkjkaLKoh2vsohfKKSC9LnBuoqamJQYMGsX///gyXTHpbKq+PGoHxcd/HAfuS5TGzQmAkcKC7fd099u+HwE/RayWRrBU2L9Fnn31GVVVVhkokfSWVoP
AmMNHMJpjZYKINxxsS8mwA5gafZwEvBoMmNgCzg95JE4CJwBtmNszMhgOY2TDgIuCt46+OiPSFZF1Pc2XdYUndMV8fBW0EC4BngQJgpbvvMLPFREfKbQBWAI+YWR3RJ4TZwb47zGw9sBNoBa539zYzOxn4adBroRD4F3d/pg/qJyK9YOzYsTQ2NnZJVztf/kmlTQF3/wXwi4S02+I+HwX+Jsm+S4AlCWm7gS+lW1gR6X+fffYZI0aM6JKusQX5SSOaRSQpd2fBggXs3LmTyspKjS0YABQURPJYTU0N5eXlRCIRysvLqampSWv/e+65h+rqam655RYefPBBjS0YAFJ6fSQiuSexG2lDQwOVlZUA3d7Q48cjuDtTpkzhjjvu6JcyS+bpSUEkT4V1I21qauroRhr2FBELJA0NDcRmUN6xYwePPvpov5dfMkNTZ4vkqUgkQrL/v2+88UYefPBBjhw50pE2ePBgIpEIR48e7ZI/Gxekl9Ro6mwRAbrvLnrvvfd2CggALS0toQEBNB5hIFFQEMlTF154YZe04uJiHn744S4zmx6LxiMMHAoKInnohRde4OGHH+bMM8+ktLS0UzfSuXPnJr3Jl5SUaK2DAU5BQSTP7Nixg7/+67/mj/7oj9i0aRMNDQ1dupEmW+hm6dKlWutggFOXVJE8EN+NNBKJMGzYMJ566qnQkcjw+y6psX1iK6LFr3cgA5OeFERyXGI30ra2NlpaWnj11Ve73U8L3UgYBQWRHBc2HuHo0aOa1lp6REFBJMdpWmvpTQoKIjlu/PjxoenqRio9oaAgkuOSjUdQN1LpCQUFkRx24MABfvaznzFp0qQu4xHUcCw9oS6pIjnstttu4+DBg2zcuJEvfUnrVsnx05OCSAYdz3oH27ZtY9myZXzrW99SQJBeoycFkQzp6XoHEF0R7YYbbmDUqFEsXry4z8sqA4eeFEQypLv1DpI9QcSnv/rqq8ycOZMTTzwxE8WXPKX1FEQypLv1DoqKijpNbV1cXMzcuXNZvXp1p0BSXFysRmXpVrrrKSgoiGRIWVlZWgPMzCw0iGgBHOmOFtkRyRFhf90nzlwaL9kfcBq5LL1JQUEkQ959912Ki4sZP358p/EFZWVlofkLCgpC0zVyWXqTgoJIBvz2t7/lJz/5CfPmzWPPnj2dZipNttZBZWWlFsCRPqegIJIBa9asoaWlhfnz53fZVlFREbrQzf33368FcKTPqaFZpJ+5O2eeeSYnnHACmzdvznRxJM+poVkky73++uvs3LmTa665JtNFEekipaBgZtPNbJeZ1ZnZzSHbh5jZumD7ZjMrj9t2S5C+y8wuTtivwMx+aWZPHW9FRHLF8uXLGTZsGJdffnmmiyLSxTGDgpkVAPcBlwCTgTlmNjkh2zzgoLufDtwD3B3sOxmYDZwBTAfuD44XsxB4+3grIZIrPvnkE9atW8ecOXMYPnx4posj0kUqTwpTgTp33+3uLcBaYGZCnpnA6uDz48AFZmZB+lp3b3b3d4G64HiY2ThgBvDQ8VdDJDc8+uijNDU1hTYwi2SDVILCWGBv3PfGIC00j7u3AoeAkmPsey/wHaA97VKL5KiHHnqIL37xi0ydOjXTRREJlUpQsJC0xC5LyfKEppvZXwIfuvuWY57crNLMas2sdv/+/ccurUiW2rp1K7W1tcyfP5/og7RI9kklKDQC8YvAjgP2JctjZoXASOBAN/tOAy4zs3qir6PON7M1YSd392p3n+LuU0aPHp1CcUWyS2xm07PPPhuAoUOHZrhEIsmlEhTeBCaa2QQzG0y04XhDQp4NwNzg8yzgRY8OgNgAzA56J00AJgJvuPst7j7O3cuD473o7lf2Qn1EskpszYSGhoaOtG9/+9tpLaYj0p+OGRSCNoIFwLNEewqtd/cdZrbYzC4Lsq0ASsysDrgJuDnYdwewHtgJPANc7+5tvV8NkezU3ZoJItlII5pF+lCyNRPMjPZ29bGQvqcRzSJZJNkMpp
rZVLKVgoJIH/rmN7/ZJU0zm0o2U1AQ6SOHDx9m/fr1jBkzhtLSUs1sKjmhMNMFEMlXt956Kw0NDbz22mtMmzYt08URSYmeFET6wObNm1m6dCnXXXedAoLkFAUFkV7W0tLCNddcwymnnMJdd92V6eKIpEVBQaSXxEYuDxkyhO3bt3P55ZczYsSITBdLJC0KCiK9IGzk8gMPPKCRy5JzFBREeoFGLku+UFAQOU779u3r9IQQb8+ePf1cGpHjo6AgkqZY20EkEqGkpITTTjstaV6NXJZco6Agkob4tgN358CBA7S0tHDFFVdQXFzcKa9GLksuUlAQSUNY20F7ezubNm2iurqasrIyjVyWnKZZUkXSoFlPJddollSRPvLaa68l3aa2A8kXCgoiKXjttdeYPn06J598MkVFRZ22qe1A8omCggxo8T2JysvLOw02i9923nnnMXz4cLZs2cLy5cvVdiB5S7OkyoAV60kUazhuaGigsrKyY3v8NoBDhw7x0ksvUVFRoSAgeUsNzTJglZeXJx10lkxZWRn19fV9UyCRPqCGZpEU9WS0sUYoS75TUJABK1mPobKyMsrKytLaRyRfKCjIgHXTTTd1SYv1JFqyZIlGKMuApIZmGbBeeeUVBg0axOjRo3n//fcpLS1lyZIlnRqRq6qq2LNnT+g2kXykhmYZkJ5++mkuvfRSlixZwqJFizJdHJE+k25Ds4KCDDhHjhzhzDPPZPDgwWzbto3BgwdnukgifSbdoKDXRzLgfO9732P37t1s3LhRAUEkgYKCDAg1NTUd7QPuzpe//GXOP//8TBdLJOuo95HkvcQ1EAC2bt2q9ZNFQigoSN4LWwPhyJEjWj9ZJERKQcHMppvZLjOrM7ObQ7YPMbN1wfbNZlYet+2WIH2XmV0cpA01szfMbJuZ7TCzf+qtCokkSjYKWaOTRbo6ZlAwswLgPuASYDIwx8wmJ2SbBxx099OBe4C7g30nA7OBM4DpwP3B8ZqB8939S8BZwHQz+y+9UyWRzj7/+c+Hpmt0skhXqTwpTAXq3H23u7cAa4GZCXlmAquDz48DF5iZBelr3b3Z3d8F6oCpHnU4yD8o+MmdvrGSM2pra/n444+J/jr+nkYni4RLJSiMBfbGfW8M0kLzuHsrcAgo6W5fMysws63Ah8Dz7r65JxUQSWbnzp1Mnz6dMWPG8KMf/UhrIIikIJUuqRaSlvhXfbI8Sfd19zbgLDP7HPBTMzvT3d/qcnKzSqAS9Lgvxxbf9dTMGD58OC+88AKnnXYaCxYsyHTxRLJeKk8KjcD4uO/jgH3J8phZITASOJDKvu7+MfB/ibY5dOHu1e4+xd2njB49OoXiykAQtmJaYtfT9vZ2mpubef311zNdXJGcccxpLoKb/K+BC4D3gDeBK9x9R1ye64Evuvu1ZjYb+Ia7f9PMzgD+hWi7xCnARmAicCLwmbt/bGZFwHPA3e7+VHdl0TQXAl1XTAMYPHgwkUiEo0ePdsmvhXFkIOv1aS7cvdXMFgDPAgXASnffYWaLgVp33wCsAB4xszqiTwizg313mNl6YCfQClzv7m1mNgZYHfREigDrjxUQRGLCxh20tLQkza+upyKp04R4knMikQjp/N7qSUEGMi3HKXnvlFNOCU0vKSnRwjgix0lBQXKKu1NSUtIlvbi4mKVLl1JdXa2upyLHQbOkSk5ZuXIlv/rVr7jqqqt4+eWXQ1dFUxAQ6TkFBckZe/fu5aabbuLP//zPWbVqFZGIHnRFepv+r5Kc4O7Mnz+ftrY2VqxYoYAg0kf0f5ZktfhBas899xyzZs3i1FNPzXSxRPKWgoJkrfgRyjGPPfaYFscR6UMKCpK1wgapNTU1aXEckT6koCBZS4vjiPQ/BQXJWlocR6T/KShIVtq/fz/Nzc1aHEeknykoSNZpa2vjyiuv5MiRI9xxxx0aoSzSjzR4TbLOnXfeyXPPPUd1dTXXXHMNixYtynSRRAYMPSlIVtm4cSO33347V1
55JfPnz890cUQGHAUFybj4AWoXXXQRY8aMYdmyZV3aE0Sk7ykoSEaFLaF54MABfv7zn2e6aCIDkoKCZFTYALWjR49qgJpIhigoSMa4e6cpLOJpgJpIZigoSL+JbzsYP34855xzTtK8GqAmkhkKCtIvEtsOGhsb2bp1K9OmTdMSmiJZREFB+kVY2wFAY2OjltAUySLm7pkuQ8qmTJnitbW1mS6G9EAkEiHsd83MaG9vz0CJRAYGM9vi7lNSza8nBel1iW0HX/3qV0MDAqjtQCTbaJoL6VWxtoPYq6LGxkYaGxuZMmUKO3fu7PQKSW0HItlHTwrSq5K1Hezfv19tByI5QG0K0mvcnUgk/O8MtR2IZIbaFKTfxLcdjB07lsmTJyfNq7YDkdygoCA9kjjuYN++fbzzzjv82Z/9mcYdiOSwlIKCmU03s11mVmdmN4dsH2Jm64Ltm82sPG7bLUH6LjO7OEgbb2YvmdnbZrbDzBb2VoWkfyRrO6ivr1fbgUgOO2abgpkVAL8GLgQagTeBOe6+My7PdcAfu/u1ZjYb+Lq7X25mk4FHganAKcALwB8CnwfGuPu/m9lwYAvwV/HHDKM2heyRbFprtR2IZJe+aFOYCtS5+253bwHWAjMT8swEVgefHwcusOhdYyaw1t2b3f1doA6Y6u7vu/u/A7j774C3gbGpFloyp7m5meuuuy7pdrUdiOS2VILCWGBv3PdGut7AO/K4eytwCChJZd/gVdPZwObUiy39Kb5BecSIESxbtowZM2ao7UAkD6USFMLeEyS+c0qWp9t9zewE4AngRnf/JPTkZpVmVmtmtfv370+huNKbEhuUW1paGDJkCHPmzFHbgUgeSmVEcyMwPu77OGBfkjyNZlYIjAQOdLevmQ0iGhBq3P0nyU7u7tVANUTbFFIor/SiRYsWdWlQbm5upqqqivr6egUBkTyTypPCm8BEM5tgZoOB2cCGhDwbgLnB51nAix5twd4AzA56J00AJgJvBO0NK4C33f2HvVER6X2tra1JF7vRIjgi+emYTwru3mpmC4BngQJgpbvvMLPFQK27byB6g3/EzOqIPiHMDvbdYWbrgZ1AK3C9u7eZ2VeB/wpsN7OtwakWufsveruC0jOHDx/m8ssvT7pdDcoi+UnTXAgQbTuoqqpiz549nHLKKRQWFrJ3717mzp3LunXrukxkp/YDkdygaS4kbYmNye+99x4NDQ3cdNNNrFy5Ug3KIgOInhSE8vJyGhoauqSXlZVRX1/f/wUSkV6jJwVJmxqTRSRGQWGAiR+INm7cOM477zytiiYiHRQU8lT8zb+8vJyamprQtoNXX32Vs846i6Kiok77a3SyyMCkoJCHEm/+DQ0NXH311fzt3/5t6MymBw8eZPny5WpMFhE1NOejZA3HyWhmU5H8pYZmSbuBWG0HIhKjoJBnYhPWhSkpKdHMpiLSLQWFPNLW1sZVV13F0aNHGTx4cKdtxcXFLF26VAPRRKRbqcySKlksfnqKYcOGcfjwYX7wgx8wZsyYjvTS0lKWLFnScfNXEBCRZBQUclisl1GsR9Hhw4cpLCxkzJgxVFRU6OYvImnT66McEDbmAKCqqqpLF9PW1laqqqoyUUwRyQN6UshyiU8DDQ0NzJ8/n3/9139N2u1U01OISE/pSSHLhT0NHD16lEcffZRBgwaF7qMupiLSUwoKWS7ZX/1mxqpVq9TFVER6lYJCFolvOygrK2PhwoVEIuGXqLS0lIqKCnUxFZFepWkuskRi20HMiSeeyKeffkpzc3NHmlY+E5FUaZqLHBXWdgBwwgknsGLFCj0NiEi/UFDIgMQuposXL07ak2jv3r1UVFRQX19Pe3s79fX1Cggi0mfUJbWfhXUxvf322zGz0MVu1JNIRPqTnhT62aJFi0JfE40aNUo9iUQk4xQU+lD8a6LS0lKuuuqqpF1MDx48qJ5EIpJx6n3UR5L1JiosLKS1tbVL/rKyMurr6/updCIyUKj3UQYkNh
wvX76chQsXhr4mGjlypF4TiUjWUlA4TmHrIVdWVvLRRx+F5j9w4IBeE4lI1tLro+OUbD3kSCQSuu6xXhOJSH/S66N+lqzhuL29Xa+JRCTnKCgch1//+teYWei22GshvSYSkVySUlAws+lmtsvM6szs5pDtQ8xsXbB9s5mVx227JUjfZWYXx6WvNLMPzeyt3qhIf9u7dy8XXnghw4YNY+jQoZ22xZ4INBJZRHLNMYOCmRUA9wGXAJOBOWY2OSHbPOCgu58O3APcHew7GZgNnAFMB+4PjgfwcJCWE+J7GI0fP56pU6fy8ccf8/LLL/PQQw/piUBE8kIq01xMBercfTeAma0FZgI74/LMBP4x+Pw48M8Wfa8yE1jr7s3Au2ZWFxzv39z9lfgnimyWOOagsbERgFtvvZWzzz6bs88+W0FARPJCKq+PxgJ74743Bmmhedy9FTgElKS4b9ZLNoPpj3/84wyURkSk76QSFMJaUhP7sSbLk8q+3Z/crNLMas2sdv/+/ens2iOJA9FqamqS9jDSWsgikm9SCQqNwPi47+OAfcnymFkhMBI4kOK+3XL3anef4u5TRo8enc6uQPhNPll62EC0uXPnhs5eCprBVETykLt3+0O03WE3MAEYDGwDzkjIcz3wQPB5NrA++HxGkH9IsP9uoCBuv3LgrWOVIfZz7rnnejrWrFnjxcXFTvTpxAEvLi72b33rW13ShwwZ4ieccEKntNhPUVGRFxUVdTnOmjVr0iqPiEh/A2o9xXusux/7ScGjbQQLgGeBt4Mb/g4zW2xmlwXZVgAlQUPyTcDNwb47gPVEG6WfAa539zYAM3sU+Ddgkpk1mtm8FONYysLaApqamli2bFmX9ObmZg4fPhx6nKNHj7J8+XL1MBKRvJfX01xEIpGkr37SoakpRCRXaZqLOMne+RcUFISml5SUaGoKERnQ8jooLFmyJPQmX1lZGZq+dOlSTU0hIgNaXq/RHLuZV1VVsWfPHkpLSzumn5g2bVpoevx+IiIDTV63KYiIDHRqUxARkR5TUBARkQ4KCiIi0kFBQUREOigoiIhIh5zqfWRm+4GGHux6EvDbXi5OpqlOuUF1yg35WCeI1muYu6c8m2hOBYWeMrPadLpk5QLVKTeoTrkhH+sEPauXXh+JiEgHBQUREekwUIJCdaYL0AdUp9ygOuWGfKwT9KBeA6JNQUREUjNQnhRERCQFeRkUzKzAzH5pZk8F3yeY2WYz+w8zW2dmgzNdxnSF1OlhM3vXzLYGP2dluozpMrN6M9selL82SDvRzJ4PrtXzZjYq0+VMR5I6/aOZvRd3rS7NdDnTYWafM7PHzewdM3vbzL6cB9cprE45e53MbFJcubea2SdmdmNPrlNeBgVgIdGlQ2PuBu5x94nAQaDXl/7sB4l1Avgf7n5W8LM1E4XqBX8RlD/Wbe5mYGNwrTYG33NNYp0g+vsXu1a/yFjJemYp8Iy7fwH4EtHfw1y/TmF1ghy9Tu6+K1Zu4FygCfgpPbhOeRcUzGwcMAN4KPhuwPnA40GW1cBfZaZ0PZNYpzw3k+g1ghy8VvnGzEYA5xFdhx13b3H3j8nh69RNnfLFBcBv3L2BHlynvAsKwL3Ad4D24HsJ8LG7twbfG4HMuNu4AAAFgElEQVSxmSjYcUisU8wSM/uVmd1jZkMyUK7j5cBzZrbFzCqDtJPd/X2A4N/PZ6x0PRNWJ4AFwbVamWOvWk4F9gOrgteXD5nZMHL7OiWrE+TudYo3G3g0+Jz2dcqroGBmfwl86O5b4pNDsuZMl6skdQK4BfgC8CfAicD/7O+y9YJp7n4OcAlwvZmdl+kC9YKwOi0DTgPOAt4H/ncGy5euQuAcYJm7nw18Su69KkqUrE65fJ0ACNpLLwMe6+kx8iooANOAy8ysHlhL9LXRvcDnzCy29Og4YF9mitcjXepkZmvc/X2PagZWAVMzWciecPd9wb8fEn3/ORX4TzMbAxD8+2HmSpi+sDq5+3+6e5
u7twPLya1r1Qg0uvvm4PvjRG+ouXydQuuU49cp5hLg3939P4PvaV+nvAoK7n6Lu49z93Kij1AvunsF8BIwK8g2F/h5hoqYtiR1ujLuQhvR94RvZbCYaTOzYWY2PPYZuIhoHTYQvUaQY9cqWZ1i1yrwdXLoWrn7B8BeM5sUJF0A7CSHr1OyOuXydYozh9+/OoIeXKfCY2XIE/8TWGtmdwC/JGhgynE1Zjaa6OuxrcC1GS5Puk4GfhqNaRQC/+Luz5jZm8B6M5sH7AH+JoNlTFeyOj0SdBl2oB7475krYo/cQPT3bTCwG7ia6B+UuXqdILxOP8rl62RmxcCFdC7390jzOmlEs4iIdMir10ciInJ8FBRERKSDgoKIiHRQUBARkQ4KCiIi0kFBQUREOigoSFYzs7ZgKuC3zOxJM/tckF5uZkcSpgu+KtgWP331VjP7SpA/5cFIZvYPwbTKb5nZtrhjDzaze83sN8F0xD8PJiyM7edm9kjc90Iz22/BlOci2W6gDF6T3HUkmA4YM1sNXA8sCbb9JrYtxF+4+29jX8ysPNUTmtm1RAcBTXX3T8xsJL+fXfJOYDjwh+7eZmZXAz8xsz/16KCfT4EzzazI3Y8Ex3kv1XOLZJqeFCSX/Bv9M8PtIuA6d/8EwN0PufvqYMTo1cC33b0t2LYKaCY6z1bM00SnOoeu0w50ESzustrMnguecr5hZt8PnnaeMbNBQb5zzezlYAbWZ+OmOrnGzN4MnmieCMoZW4jpR2b2/8xst5nN6q4cIqCgIDnCzAqIzlGzIS75tITXR1+L2/ZSkLaZNARzFw1399+EbD4d2BMLFnFqgTPivq8FZpvZUOCPgVTKcBrRQDITWAO85O5fBI4AM4LA8H+AWe5+LrCS3z8x/cTd/8TdY4vFxC8iNQb4KvCXRKc8EOmWXh9Jtisys61AObAFeD5uW8qvj9JgJJ9aPdm2Tunu/qvgddUcINXVu55298/MbDtQADwTpG8nWvdJwJnA88HcSgVEp3eG6OuqO4DPAScAz8Yd92fBrJ87zezkFMsiA5ieFCTbxdoUyoDBRNsU+kzwFPCpmZ0asrkOKIvNhBrnHKIzh8bbAPwvjvHqKE5zcP524DP//aRk7UT/eDNgR9xSkV9094uCPA8DC4Ini38ChiYeNxC2tohIJwoKkhPc/RDwd8A/xN6x96G7gPssumwjZjbCzCrd/VOiSxr+MHidRdArqRh4MeEYK4HF7r69l8q0CxhtZl8OzjvIzGKvrIYD7wf/XSp66XwyQCkoSM5w918C24iuKwFd2xT+7hiHmGRmjXE/yaYRXkZ0DY43g26sLxNdCB2iK94dBX5tZv9BdCrir8f9ZR8ra6O7L02/luHcvYXomiB3m9k2otOlfyXYfCvRdovngXd665wyMGnqbBER6aAnBRER6aDeRzJgmdl9RNfAjrc0GHvQF+e7GliYkLzJ3fu08VwkHXp9JCIiHfT6SEREOigoiIhIBwUFERHpoKAgIiIdFBRERKTD/wfrdbO4LUPpZwAAAABJRU5ErkJggg==\n", 468 | "text/plain": [ 469 | "
" 470 | ] 471 | }, 472 | "metadata": { 473 | "needs_background": "light" 474 | }, 475 | "output_type": "display_data" 476 | } 477 | ], 478 | "source": [ 479 | "plt.plot(var_vals, pd_vals, 'ko-')\n", 480 | "plt.xlabel(input_cols[index])" 481 | ] 482 | }, 483 | { 484 | "cell_type": "markdown", 485 | "metadata": {}, 486 | "source": [ 487 | "How well can Dask parallelize the task across all the input variables?" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 25, 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "cluster = LocalCluster(n_workers=4)\n", 497 | "client = Client(cluster)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": 26, 503 | "metadata": {}, 504 | "outputs": [ 505 | { 506 | "data": { 507 | "text/html": [ 508 | "\n", 509 | "\n", 510 | "\n", 517 | "\n", 525 | "\n", 526 | "
\n", 511 | "

Client

\n", 512 | "\n", 516 | "
\n", 518 | "

Cluster

\n", 519 | "
    \n", 520 | "
  • Workers: 4
  • \n", 521 | "
  • Cores: 8
  • \n", 522 | "
  • Memory: 17.18 GB
  • \n", 523 | "
\n", 524 | "
" 527 | ], 528 | "text/plain": [ 529 | "" 530 | ] 531 | }, 532 | "execution_count": 26, 533 | "metadata": {}, 534 | "output_type": "execute_result" 535 | } 536 | ], 537 | "source": [ 538 | "client" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 38, 544 | "metadata": {}, 545 | "outputs": [], 546 | "source": [ 547 | "pd_count = 50\n", 548 | "var_vals = np.zeros((len(input_cols), pd_count))\n", 549 | "futures = []\n", 550 | "train_future = client.scatter(train_in.values)\n", 551 | "for i, input_col in enumerate(input_cols):\n", 552 | " var_vals[i] = np.linspace(train_in[input_col].min(), train_in[input_col].max(), pd_count)\n", 553 | " futures.append(client.submit(partial_dependence_1d, train_future, rf, i, var_vals[i]))" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 39, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 | "results = client.gather(futures)" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": 40, 568 | "metadata": {}, 569 | "outputs": [ 570 | { 571 | "data": { 572 | "image/png": "\n", 573 | "text/plain": [ 574 | "
" 575 | ] 576 | }, 577 | "metadata": { 578 | "needs_background": "light" 579 | }, 580 | "output_type": "display_data" 581 | } 582 | ], 583 | "source": [ 584 | "fig, axes = plt.subplots(2, 2, figsize=(10, 8))\n", 585 | "axef = axes.ravel()\n", 586 | "for r, res in enumerate(results):\n", 587 | " axef[r].plot(var_vals[r], res, 'ko-')\n", 588 | " axef[r].set_xlabel(input_cols[r])" 589 | ] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "metadata": {}, 594 | "source": [ 595 | "It turns out that in this case, dask already does quite well with parallelizing the task.\n", 596 | "\n", 597 | "Is there a way we can make dask struggle or fail? At what dataset size do we see problems?\n", 598 | "\n", 599 | "Is there a more concise way to parallelize the task?" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [] 608 | } 609 | ], 610 | "metadata": { 611 | "kernelspec": { 612 | "display_name": "Python 3", 613 | "language": "python", 614 | "name": "python3" 615 | }, 616 | "language_info": { 617 | "codemirror_mode": { 618 | "name": "ipython", 619 | "version": 3 620 | }, 621 | "file_extension": ".py", 622 | "mimetype": "text/x-python", 623 | "name": "python", 624 | "nbconvert_exporter": "python", 625 | "pygments_lexer": "ipython3", 626 | "version": "3.6.7" 627 | } 628 | }, 629 | "nbformat": 4, 630 | "nbformat_minor": 2 631 | } 632 | -------------------------------------------------------------------------------- /rasp-data-loading.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# xarray use case: Neural network training\n", 8 | "\n", 9 | "\n", 10 | "**tl;dr**\n", 11 | "\n", 12 | "1. This notebook is an example of reading from a climate model netCDF file to train a neural network. 
Neural networks (for use in parameterization research) require random columns of several stacked variables at a time. \n", 13 | "\n", 14 | "2. Experiments in this notebook show:\n", 15 | " 1. Reading from raw climate model output files is super slow (1s per batch... need speeds on the order of ms)\n", 16 | " 2. open_mfdataset is half as fast as opening the same dataset with open_dataset\n", 17 | " 3. Pure h5py is much faster than reading the same dataset using xarray (even using the h5 backend)\n", 18 | "\n", 19 | "3. Currently, I revert to preformatting the dataset (flatten time, lat, lon). This gets the reading speed down to milliseconds per batch.\n", 20 | "\n", 21 | "**Conclusions**\n", 22 | "\n", 23 | "Reading straight from the raw netCDF files (with all dimensions intact) is handy and might be necessary for later applications (using continuous time slices or lat-lon regions for RNNs or CNNs).\n", 24 | "\n", 25 | "However, at the moment this is many orders of magnitude too slow. Preprocessing seems required.\n", 26 | "\n", 27 | "What would be a good way of speeding this up without too extensive post processing?\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import xarray as xr\n", 37 | "import numpy as np" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "'0.11.2'" 49 | ] 50 | }, 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "xr.__version__" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## Load an example dataset\n", 65 | "\n", 66 | "I uploaded a sample dataset here: http://doi.org/10.5281/zenodo.2559313\n", 67 | "\n", 68 | "The files are around 1GB large. 
Let's download it.\n", 69 | "\n", 70 | "NOTE: I have all my data on an SSD" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 47, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Modify this path!\n", 80 | "DATADIR = '/local/S.Rasp/tmp/'" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 48, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "name": "stdout", 90 | "output_type": "stream", 91 | "text": [ 92 | "--2019-02-07 13:08:52-- https://zenodo.org/record/2559183/files/sample_SPCAM_1.nc\n", 93 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 94 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 95 | "HTTP request sent, awaiting response... 200 OK\n", 96 | "Length: 923498891 (881M) [application/octet-stream]\n", 97 | "Saving to: ‘/local/S.Rasp/tmp/sample_SPCAM_1.nc’\n", 98 | "\n", 99 | "sample_SPCAM_1.nc 100%[===================>] 880.72M 6.59MB/s in 1m 49s \n", 100 | "\n", 101 | "2019-02-07 13:10:42 (8.09 MB/s) - ‘/local/S.Rasp/tmp/sample_SPCAM_1.nc’ saved [923498891/923498891]\n", 102 | "\n", 103 | "--2019-02-07 13:10:42-- https://zenodo.org/record/2559183/files/sample_SPCAM_2.nc\n", 104 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 105 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 106 | "HTTP request sent, awaiting response... 200 OK\n", 107 | "Length: 923498891 (881M) [application/octet-stream]\n", 108 | "Saving to: ‘/local/S.Rasp/tmp/sample_SPCAM_2.nc’\n", 109 | "\n", 110 | "sample_SPCAM_2.nc 100%[===================>] 880.72M 24.5MB/s in 87s \n", 111 | "\n", 112 | "2019-02-07 13:12:09 (10.1 MB/s) - ‘/local/S.Rasp/tmp/sample_SPCAM_2.nc’ saved [923498891/923498891]\n", 113 | "\n", 114 | "--2019-02-07 13:12:09-- https://zenodo.org/record/2559183/files/sample_SPCAM_concat.nc\n", 115 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 116 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... 
connected.\n", 117 | "HTTP request sent, awaiting response... 200 OK\n", 118 | "Length: 1846816429 (1.7G) [application/octet-stream]\n", 119 | "Saving to: ‘/local/S.Rasp/tmp/sample_SPCAM_concat.nc’\n", 120 | "\n", 121 | "sample_SPCAM_concat 100%[===================>] 1.72G 3.80MB/s in 3m 28s \n", 122 | "\n", 123 | "2019-02-07 13:15:38 (8.46 MB/s) - ‘/local/S.Rasp/tmp/sample_SPCAM_concat.nc’ saved [1846816429/1846816429]\n", 124 | "\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "!wget -P $DATADIR https://zenodo.org/record/2559313/files/sample_SPCAM_1.nc\n", 130 | "!wget -P $DATADIR https://zenodo.org/record/2559313/files/sample_SPCAM_2.nc\n", 131 | "!wget -P $DATADIR https://zenodo.org/record/2559313/files/sample_SPCAM_concat.nc" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 49, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "-rw-r--r-- 1 S.Rasp ls-craig 881M Feb 7 13:00 /local/S.Rasp/tmp//sample_SPCAM_1.nc\r\n", 144 | "-rw-r--r-- 1 S.Rasp ls-craig 881M Feb 7 13:00 /local/S.Rasp/tmp//sample_SPCAM_2.nc\r\n", 145 | "-rw-r--r-- 1 S.Rasp ls-craig 1.8G Feb 7 13:00 /local/S.Rasp/tmp//sample_SPCAM_concat.nc\r\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "!ls -lh $DATADIR/sample_SPCAM*" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "The files are typical climate model output files. `sample_SPCAM_1.nc` and `sample_SPCAM_2.nc` are two contiguous output files. `sample_SPCAM_concat.nc` is the concatenated version of the two files." 
158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 53, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "CPU times: user 56 ms, sys: 0 ns, total: 56 ms\n", 170 | "Wall time: 54.7 ms\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "%%time\n", 176 | "ds = xr.open_mfdataset(DATADIR + 'sample_SPCAM_1.nc')" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 54, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "\n", 188 | "Dimensions: (crm_x: 32, crm_y: 1, crm_z: 28, ilev: 31, isccp_prs: 7, isccp_prstau: 49, isccp_tau: 7, lat: 64, lev: 30, lon: 128, tbnd: 2, time: 48)\n", 189 | "Coordinates:\n", 190 | " * lat (lat) float64 -87.86 -85.1 -82.31 -79.53 ... 82.31 85.1 87.86\n", 191 | " * lon (lon) float64 0.0 2.812 5.625 8.438 ... 351.6 354.4 357.2\n", 192 | " * crm_x (crm_x) float64 0.0 4.0 8.0 12.0 ... 112.0 116.0 120.0 124.0\n", 193 | " * crm_y (crm_y) float64 0.0\n", 194 | " * crm_z (crm_z) float64 992.6 976.3 957.5 936.2 ... 38.27 24.61 14.36\n", 195 | " * lev (lev) float64 3.643 7.595 14.36 24.61 ... 957.5 976.3 992.6\n", 196 | " * ilev (ilev) float64 2.255 5.032 10.16 18.56 ... 967.5 985.1 1e+03\n", 197 | " * isccp_prs (isccp_prs) float64 90.0 245.0 375.0 500.0 620.0 740.0 900.0\n", 198 | " * isccp_tau (isccp_tau) float64 0.15 0.8 2.45 6.5 16.2 41.5 219.5\n", 199 | " * isccp_prstau (isccp_prstau) float64 90.0 90.0 90.0 ... 900.0 900.0 900.2\n", 200 | " * time (time) object 0000-01-01 00:00:00 ... 
0000-01-01 23:29:59\n", 201 | "Dimensions without coordinates: tbnd\n", 202 | "Data variables:\n", 203 | " P0 float64 ...\n", 204 | " time_bnds (time, tbnd) object dask.array\n", 205 | " date_written (time) |S8 dask.array\n", 206 | " time_written (time) |S8 dask.array\n", 207 | " ntrm int32 ...\n", 208 | " ntrn int32 ...\n", 209 | " ntrk int32 ...\n", 210 | " ndbase int32 ...\n", 211 | " nsbase int32 ...\n", 212 | " nbdate int32 ...\n", 213 | " nbsec int32 ...\n", 214 | " mdt int32 ...\n", 215 | " nlon (lat) int32 dask.array\n", 216 | " wnummax (lat) int32 dask.array\n", 217 | " hyai (ilev) float64 dask.array\n", 218 | " hybi (ilev) float64 dask.array\n", 219 | " hyam (lev) float64 dask.array\n", 220 | " hybm (lev) float64 dask.array\n", 221 | " gw (lat) float64 dask.array\n", 222 | " ndcur (time) int32 dask.array\n", 223 | " nscur (time) int32 dask.array\n", 224 | " date (time) int32 dask.array\n", 225 | " datesec (time) int32 dask.array\n", 226 | " nsteph (time) int32 dask.array\n", 227 | " DTV (time, lev, lat, lon) float32 dask.array\n", 228 | " DTVKE (time, lev, lat, lon) float32 dask.array\n", 229 | " FLNS (time, lat, lon) float32 dask.array\n", 230 | " FLNT (time, lat, lon) float32 dask.array\n", 231 | " FLUT (time, lat, lon) float32 dask.array\n", 232 | " FSNS (time, lat, lon) float32 dask.array\n", 233 | " FSNT (time, lat, lon) float32 dask.array\n", 234 | " LHFLX (time, lat, lon) float32 dask.array\n", 235 | " PHCLDICE (time, lev, lat, lon) float32 dask.array\n", 236 | " PHCLDLIQ (time, lev, lat, lon) float32 dask.array\n", 237 | " PHQ (time, lev, lat, lon) float32 dask.array\n", 238 | " PRECC (time, lat, lon) float32 dask.array\n", 239 | " PRECL (time, lat, lon) float32 dask.array\n", 240 | " PRECSC (time, lat, lon) float32 dask.array\n", 241 | " PRECSL (time, lat, lon) float32 dask.array\n", 242 | " PRECSTEN (time, lat, lon) float32 dask.array\n", 243 | " PRECT (time, lat, lon) float32 dask.array\n", 244 | " PRECTEND (time, lat, lon) float32 
dask.array\n", 245 | " PS (time, lat, lon) float32 dask.array\n", 246 | " QAP (time, lev, lat, lon) float32 dask.array\n", 247 | " QCAP (time, lev, lat, lon) float32 dask.array\n", 248 | " QIAP (time, lev, lat, lon) float32 dask.array\n", 249 | " QRL (time, lev, lat, lon) float32 dask.array\n", 250 | " QRS (time, lev, lat, lon) float32 dask.array\n", 251 | " SHFLX (time, lat, lon) float32 dask.array\n", 252 | " SOLIN (time, lat, lon) float32 dask.array\n", 253 | " SPDQ (time, lev, lat, lon) float32 dask.array\n", 254 | " SPDT (time, lev, lat, lon) float32 dask.array\n", 255 | " T (time, lev, lat, lon) float32 dask.array\n", 256 | " TAP (time, lev, lat, lon) float32 dask.array\n", 257 | " TPHYSTND (time, lev, lat, lon) float32 dask.array\n", 258 | " TS (time, lat, lon) float32 dask.array\n", 259 | " UAP (time, lev, lat, lon) float32 dask.array\n", 260 | " VAP (time, lev, lat, lon) float32 dask.array\n", 261 | " VD01 (time, lev, lat, lon) float32 dask.array\n", 262 | " VPHYSTND (time, lev, lat, lon) float32 dask.array\n", 263 | "Attributes:\n", 264 | " Conventions: CF-1.0\n", 265 | " source: CAM\n", 266 | " case: AndKua_aqua_SPCAM3.0_sp_fbp32\n", 267 | " title: \n", 268 | " logname: tg847872\n", 269 | " host: \n", 270 | " Version: $Name: $\n", 271 | " revision_Id: $Id: history.F90,v 1.26.2.38 2003/12/15 18:52:35 hender Exp $" 272 | ] 273 | }, 274 | "execution_count": 54, 275 | "metadata": {}, 276 | "output_type": "execute_result" 277 | } 278 | ], 279 | "source": [ 280 | "ds" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "## Random columns for machine learning parameterizations\n", 288 | "\n", 289 | "For the work on ML parameterizations that a few of us are doing now, we would like to work one column at a time. One simple example would be predicting the temperature and humidity tendencies (TPHYSTND and PHQ) from the temperature and humidity profiles (TAP and QAP). 
\n", 290 | "\n", 291 | "This means we would like to give the neural network a stacked vector containing the inputs (2 x 30 levels) and ask it to predict the outputs (also 2 x 30 levels).\n", 292 | "\n", 293 | "In NN training, we usually train on a batch of data at a time. Batches typically have a few hundred samples (columns in our case). It is really important that the samples in a batch are not correlated but rather represent a random sample of the entire dataset.\n", 294 | "\n", 295 | "To achieve this we will write a data generator that loads the batches by randomly selecting along the time, lat and lon dimensions." 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 57, 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "class DataGenerator(object):\n", 305 | " \"\"\"\n", 306 | " Data generator that randomly (if shuffle = True) picks columns from the dataset and returns them in \n", 307 | " batches. For each column the input variables and output variables will be stacked.\n", 308 | " \"\"\"\n", 309 | " def __init__(self, fn_or_ds, batch_size=128, input_vars=['TAP', 'QAP'], output_vars=['TPHYSTND', 'PHQ'], \n", 310 | " shuffle=True, engine='netcdf4'):\n", 311 | " self.ds = xr.open_mfdataset(fn_or_ds, engine=engine) if type(fn_or_ds) is str else fn_or_ds\n", 312 | " self.batch_size = batch_size\n", 313 | " self.input_vars = input_vars\n", 314 | " self.output_vars = output_vars\n", 315 | " self.ntime, self.nlat, self.nlon = self.ds.time.size, self.ds.lat.size, self.ds.lon.size\n", 316 | " self.ntot = self.ntime * self.nlat * self.nlon # one index per (time, lat, lon) column, matching unravel_index below\n", 317 | " self.n_batches = self.ntot // batch_size\n", 318 | " self.indices = np.arange(self.ntot)\n", 319 | " if shuffle:\n", 320 | " self.indices = np.random.permutation(self.indices)\n", 321 | " def __getitem__(self, index):\n", 322 | " time_indices, lat_indices, lon_indices = np.unravel_index(\n", 323 | " self.indices[index*self.batch_size:(index+1)*self.batch_size], (self.ntime, 
self.nlat, self.nlon)\n", 324 | " )\n", 325 | " \n", 326 | " X, Y = [], []\n", 327 | " for itime, ilat, ilon in zip(time_indices, lat_indices, lon_indices):\n", 328 | " X.append(\n", 329 | " np.concatenate(\n", 330 | " [self.ds[v].isel(time=itime, lat=ilat, lon=ilon).values for v in self.input_vars]\n", 331 | " )\n", 332 | " )\n", 333 | " Y.append(\n", 334 | " np.concatenate(\n", 335 | " [self.ds[v].isel(time=itime, lat=ilat, lon=ilon).values for v in self.output_vars]\n", 336 | " )\n", 337 | " )\n", 338 | "\n", 339 | " return np.array(X), np.array(Y)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "### Multi-file dataset\n", 347 | "\n", 348 | "Let's start by using the split dataset `sample_SPCAM_1.nc` and `sample_SPCAM_2.nc`." 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 58, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "gen = DataGenerator(DATADIR + 'sample_SPCAM_[1-2].nc')" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 59, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "# This is how we get one batch of inputs and corresponding outputs\n", 367 | "x, y = gen[0]" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 60, 373 | "metadata": {}, 374 | "outputs": [ 375 | { 376 | "data": { 377 | "text/plain": [ 378 | "((128, 60), (128, 60))" 379 | ] 380 | }, 381 | "execution_count": 60, 382 | "metadata": {}, 383 | "output_type": "execute_result" 384 | } 385 | ], 386 | "source": [ 387 | "x.shape, y.shape" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 61, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "# A little test function to check the timing.\n", 397 | "def test(g, n):\n", 398 | " for i in range(n):\n", 399 | " x, y = g[i]" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 64, 405 | "metadata": {}, 406 | 
"outputs": [ 407 | { 408 | "name": "stdout", 409 | "output_type": "stream", 410 | "text": [ 411 | "CPU times: user 13.3 s, sys: 1.34 s, total: 14.6 s\n", 412 | "Wall time: 14.3 s\n" 413 | ] 414 | } 415 | ], 416 | "source": [ 417 | "%%time\n", 418 | "test(gen, 10)" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 65, 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "CPU times: user 12.5 s, sys: 1.28 s, total: 13.8 s\n", 431 | "Wall time: 13.5 s\n" 432 | ] 433 | } 434 | ], 435 | "source": [ 436 | "# does shuffling make a big difference? The run above used the default shuffle=True, so compare against shuffle=False here\n", 437 | "gen = DataGenerator(DATADIR + 'sample_SPCAM_[1-2].nc', shuffle=False)\n", 438 | "%time test(gen, 10)" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "So it takes more than one second to read one batch. This is way too slow to train a neural network in a reasonable amount of time. Shuffling doesn't seem to be a huge problem, but even without shuffling I am probably accessing the data in a different order than saved on disc. \n", 446 | "\n", 447 | "Let's check what actually takes that long." 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 66, 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "name": "stdout", 457 | "output_type": "stream", 458 | "text": [ 459 | "The line_profiler extension is already loaded. 
To reload it, use:\n", 460 | " %reload_ext line_profiler\n" 461 | ] 462 | } 463 | ], 464 | "source": [ 465 | "%load_ext line_profiler" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 67, 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "%lprun -f gen.__getitem__ test(gen, 10)" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "metadata": {}, 480 | "source": [ 481 | "Output:\n", 482 | "\n", 483 | "```\n", 484 | "Timer unit: 1e-06 s\n", 485 | "\n", 486 | "Total time: 24.5229 s\n", 487 | "File: \n", 488 | "Function: __getitem__ at line 18\n", 489 | "\n", 490 | "Line # Hits Time Per Hit % Time Line Contents\n", 491 | "==============================================================\n", 492 | " 18 def __getitem__(self, index):\n", 493 | " 19 10 17.0 1.7 0.0 time_indices, lat_indices, lon_indices = np.unravel_index(\n", 494 | " 20 10 267.0 26.7 0.0 self.indices[index*self.batch_size:(index+1)*self.batch_size], (self.ntime, self.nlat, self.nlon)\n", 495 | " 21 )\n", 496 | " 22 \n", 497 | " 23 10 10.0 1.0 0.0 X, Y = [], []\n", 498 | " 24 1290 4642.0 3.6 0.0 for itime, ilat, ilon in zip(time_indices, lat_indices, lon_indices):\n", 499 | " 25 1280 1399.0 1.1 0.0 X.append(\n", 500 | " 26 1280 1721.0 1.3 0.0 np.concatenate(\n", 501 | " 27 1280 12256070.0 9575.1 50.0 [self.ds[v].isel(time=itime, lat=ilat, lon=ilon).values for v in self.input_vars]\n", 502 | " 28 )\n", 503 | " 29 )\n", 504 | " 30 1280 2393.0 1.9 0.0 Y.append(\n", 505 | " 31 1280 1750.0 1.4 0.0 np.concatenate(\n", 506 | " 32 1280 12253415.0 9573.0 50.0 [self.ds[v].isel(time=itime, lat=ilat, lon=ilon).values for v in self.output_vars]\n", 507 | " 33 )\n", 508 | " 34 )\n", 509 | " 35 \n", 510 | " 36 10 1218.0 121.8 0.0 return np.array(X), np.array(Y)\n", 511 | "```" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": {}, 517 | "source": [ 518 | "### Using the concatenated dataset\n", 519 | "\n", 520 | "Let's see whether it makes a 
difference to use the pre-concatenated dataset." 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 74, 526 | "metadata": {}, 527 | "outputs": [ 528 | { 529 | "name": "stdout", 530 | "output_type": "stream", 531 | "text": [ 532 | "CPU times: user 5.93 s, sys: 984 ms, total: 6.91 s\n", 533 | "Wall time: 6.91 s\n" 534 | ] 535 | } 536 | ], 537 | "source": [ 538 | "ds = xr.open_dataset(f'{DATADIR}sample_SPCAM_concat.nc')\n", 539 | "gen = DataGenerator(ds, shuffle=True)\n", 540 | "%time test(gen, 10)" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": 76, 546 | "metadata": {}, 547 | "outputs": [ 548 | { 549 | "name": "stdout", 550 | "output_type": "stream", 551 | "text": [ 552 | "CPU times: user 11.5 s, sys: 1.25 s, total: 12.8 s\n", 553 | "Wall time: 12.5 s\n" 554 | ] 555 | } 556 | ], 557 | "source": [ 558 | "ds = xr.open_mfdataset(f'{DATADIR}sample_SPCAM_concat.nc')\n", 559 | "gen = DataGenerator(ds, shuffle=True)\n", 560 | "%time test(gen, 10)" 561 | ] 562 | }, 563 | { 564 | "cell_type": "markdown", 565 | "metadata": {}, 566 | "source": [ 567 | "So yes, it approximately halves the time but only if the single dataset is NOT opened with `open_mfdataset`." 
568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "metadata": {}, 573 | "source": [ 574 | "### With h5py engine\n", 575 | "\n", 576 | "Let's see whether using the h5py backend makes a difference" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 77, 582 | "metadata": {}, 583 | "outputs": [], 584 | "source": [ 585 | "import h5netcdf" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 79, 591 | "metadata": {}, 592 | "outputs": [], 593 | "source": [ 594 | "ds.close()" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": 80, 600 | "metadata": {}, 601 | "outputs": [], 602 | "source": [ 603 | "ds = xr.open_dataset(f'{DATADIR}sample_SPCAM_concat.nc', engine='h5netcdf')\n", 604 | "gen = DataGenerator(ds)" 605 | ] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": 81, 610 | "metadata": {}, 611 | "outputs": [ 612 | { 613 | "name": "stdout", 614 | "output_type": "stream", 615 | "text": [ 616 | "CPU times: user 6.97 s, sys: 972 ms, total: 7.94 s\n", 617 | "Wall time: 7.8 s\n" 618 | ] 619 | } 620 | ], 621 | "source": [ 622 | "%%time\n", 623 | "test(gen, 10)" 624 | ] 625 | }, 626 | { 627 | "cell_type": "markdown", 628 | "metadata": {}, 629 | "source": [ 630 | "Doesn't seem to speed it up" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 83, 636 | "metadata": {}, 637 | "outputs": [], 638 | "source": [ 639 | "ds.close()" 640 | ] 641 | }, 642 | { 643 | "cell_type": "markdown", 644 | "metadata": {}, 645 | "source": [ 646 | "### Using plain h5py\n", 647 | "\n", 648 | "Let's write a version of the data generator that uses plain h5py for data loading." 
649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": 82, 654 | "metadata": {}, 655 | "outputs": [], 656 | "source": [ 657 | "class DataGeneratorH5(object): # version of the data generator that reads the columns with plain h5py\n", 658 | " def __init__(self, fn, batch_size=128, input_vars=['TAP', 'QAP'], output_vars=['TPHYSTND', 'PHQ'], shuffle=True):\n", 659 | " self.ds = xr.open_dataset(fn)\n", 660 | " self.batch_size = batch_size\n", 661 | " self.input_vars = input_vars\n", 662 | " self.output_vars = output_vars\n", 663 | " self.ntime, self.nlat, self.nlon = self.ds.time.size, self.ds.lat.size, self.ds.lon.size\n", 664 | " self.ntot = self.ntime * self.nlat * self.nlon # one sample per (time, lat, lon) column; must match the unravel_index shape below\n", 665 | " self.n_batches = self.ntot // batch_size\n", 666 | " self.indices = np.arange(self.ntot)\n", 667 | " if shuffle:\n", 668 | " self.indices = np.random.permutation(self.indices)\n", 669 | " \n", 670 | " # Close xarray dataset and open h5py object\n", 671 | " self.ds.close()\n", 672 | " self.ds = h5py.File(fn, 'r')\n", 673 | " \n", 674 | " def __getitem__(self, index):\n", 675 | " time_indices, lat_indices, lon_indices = np.unravel_index(\n", 676 | " self.indices[index*self.batch_size:(index+1)*self.batch_size], (self.ntime, self.nlat, self.nlon)\n", 677 | " )\n", 678 | " \n", 679 | " X, Y = [], []\n", 680 | " for itime, ilat, ilon in zip(time_indices, lat_indices, lon_indices):\n", 681 | " X.append(\n", 682 | " np.concatenate(\n", 683 | " [self.ds[v][itime, :, ilat, ilon] for v in self.input_vars]\n", 684 | " )\n", 685 | " )\n", 686 | " Y.append(\n", 687 | " np.concatenate(\n", 688 | " [self.ds[v][itime, :, ilat, ilon] for v in self.output_vars]\n", 689 | " )\n", 690 | " )\n", 691 | "\n", 692 | " return np.array(X), np.array(Y)" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": 84, 698 | "metadata": {}, 699 | "outputs": [], 700 | "source": [ 701 | "gen = DataGeneratorH5(f'{DATADIR}sample_SPCAM_concat.nc')" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": 85, 707 | "metadata": {},
708 | "outputs": [ 709 | { 710 | "name": "stdout", 711 | "output_type": "stream", 712 | "text": [ 713 | "CPU times: user 1.78 s, sys: 860 ms, total: 2.64 s\n", 714 | "Wall time: 2.61 s\n" 715 | ] 716 | } 717 | ], 718 | "source": [ 719 | "%%time\n", 720 | "test(gen, 10)" 721 | ] 722 | }, 723 | { 724 | "cell_type": "code", 725 | "execution_count": 96, 726 | "metadata": {}, 727 | "outputs": [], 728 | "source": [ 729 | "gen.ds.close()" 730 | ] 731 | }, 732 | { 733 | "cell_type": "markdown", 734 | "metadata": {}, 735 | "source": [ 736 | "So this is significantly faster than xarray." 737 | ] 738 | }, 739 | { 740 | "cell_type": "markdown", 741 | "metadata": {}, 742 | "source": [ 743 | "## Use in a simple neural network\n", 744 | "\n", 745 | "How would we actually use this data generator for network training...\n", 746 | "\n", 747 | "Note that this neural network will not actually learn much because we didn't normalize the input data. But we only care about computational performance here, right?" 
748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": 87, 753 | "metadata": {}, 754 | "outputs": [], 755 | "source": [ 756 | "import tensorflow as tf\n", 757 | "from tensorflow.keras.layers import *\n", 758 | "from tensorflow.keras.models import Sequential" 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": 88, 764 | "metadata": {}, 765 | "outputs": [ 766 | { 767 | "data": { 768 | "text/plain": [ 769 | "'2.1.6-tf'" 770 | ] 771 | }, 772 | "execution_count": 88, 773 | "metadata": {}, 774 | "output_type": "execute_result" 775 | } 776 | ], 777 | "source": [ 778 | "tf.keras.__version__" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 89, 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [ 787 | "model = Sequential([\n", 788 | " Dense(128, input_shape=(60,), activation='relu'),\n", 789 | " Dense(60),\n", 790 | "])" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": 90, 796 | "metadata": {}, 797 | "outputs": [ 798 | { 799 | "name": "stdout", 800 | "output_type": "stream", 801 | "text": [ 802 | "_________________________________________________________________\n", 803 | "Layer (type) Output Shape Param # \n", 804 | "=================================================================\n", 805 | "dense (Dense) (None, 128) 7808 \n", 806 | "_________________________________________________________________\n", 807 | "dense_1 (Dense) (None, 60) 7740 \n", 808 | "=================================================================\n", 809 | "Total params: 15,548\n", 810 | "Trainable params: 15,548\n", 811 | "Non-trainable params: 0\n", 812 | "_________________________________________________________________\n" 813 | ] 814 | } 815 | ], 816 | "source": [ 817 | "model.summary()" 818 | ] 819 | }, 820 | { 821 | "cell_type": "code", 822 | "execution_count": 91, 823 | "metadata": {}, 824 | "outputs": [], 825 | "source": [ 826 | "model.compile('adam', 'mse')" 827 | ] 828 | }, 829 | { 
830 | "cell_type": "code", 831 | "execution_count": 99, 832 | "metadata": {}, 833 | "outputs": [], 834 | "source": [ 835 | "# Load the xarray version using the concatenated dataset\n", 836 | "ds = xr.open_dataset(f'{DATADIR}sample_SPCAM_concat.nc')\n", 837 | "gen = DataGenerator(ds, shuffle=True)" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": 101, 843 | "metadata": {}, 844 | "outputs": [ 845 | { 846 | "name": "stdout", 847 | "output_type": "stream", 848 | "text": [ 849 | "Epoch 1/1\n", 850 | " 37/4608 [..............................] - ETA: 1:04:11 - loss: 1733.6299" 851 | ] 852 | }, 853 | { 854 | "ename": "KeyboardInterrupt", 855 | "evalue": "", 856 | "output_type": "error", 857 | "traceback": [ 858 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 859 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 860 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_generator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 861 | "\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py\u001b[0m in \u001b[0;36mfit_generator\u001b[0;34m(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 2175\u001b[0m 
\u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2176\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2177\u001b[0;31m initial_epoch=initial_epoch)\n\u001b[0m\u001b[1;32m 2178\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2179\u001b[0m def evaluate_generator(self,\n", 862 | "\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py\u001b[0m in \u001b[0;36mfit_generator\u001b[0;34m(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0mbatch_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0msteps_done\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m \u001b[0mgenerator_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_generator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator_output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'__len__'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 863 | 
"\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/keras/utils/data_utils.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 823\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 824\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 825\u001b[0;31m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msleep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait_time\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 826\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 827\u001b[0m \u001b[0;31m# Make sure to rethrow the first exception in the queue, if any\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 864 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 865 | ] 866 | } 867 | ], 868 | "source": [ 869 | "model.fit_generator(iter(gen), steps_per_epoch=gen.n_batches)" 870 | ] 871 | }, 872 | { 873 | "cell_type": "markdown", 874 | "metadata": {}, 875 | "source": [ 876 | "So as you can see, it would take around 1 hour to go through one epoch (i.e. the entire dataset once). This is crazy slow since we only used 2 days of data. The full dataset contains a year of data..." 
877 | ] 878 | }, 879 | { 880 | "cell_type": "markdown", 881 | "metadata": {}, 882 | "source": [ 883 | "## Pre-processing the dataset\n", 884 | "\n", 885 | "To work around this, I pre-stack and pre-shuffle the data and save it to disk in a convenient format.\n", 886 | "\n", 887 | "These files contain exactly the same information for the required input (features) and output (targets) variables.\n", 888 | "\n", 889 | "The files only have two dimensions: `sample`, which is the shuffled, flattened combination of the time, lat and lon dimensions, and a level dimension (`feature_lev` in the features file), which is the stacked vertical coordinate.\n", 890 | "\n", 891 | "Preprocessing these two files only takes a few seconds, but for an entire year of data the preprocessing alone can take around an hour.\n" 892 | ] 893 | }, 894 | { 895 | "cell_type": "code", 896 | "execution_count": 163, 897 | "metadata": {}, 898 | "outputs": [ 899 | { 900 | "name": "stdout", 901 | "output_type": "stream", 902 | "text": [ 903 | "--2019-02-07 15:42:32-- https://zenodo.org/record/2559313/files/preproc_features.nc\n", 904 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 905 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 906 | "HTTP request sent, awaiting response... 200 OK\n", 907 | "Length: 205465847 (196M) [application/octet-stream]\n", 908 | "Saving to: ‘/local/S.Rasp/tmp/preproc_features.nc.2’\n", 909 | "\n", 910 | "preproc_features.nc 100%[===================>] 195.95M 10.8MB/s in 15s \n", 911 | "\n", 912 | "2019-02-07 15:42:48 (13.0 MB/s) - ‘/local/S.Rasp/tmp/preproc_features.nc.2’ saved [205465847/205465847]\n", 913 | "\n", 914 | "--2019-02-07 15:42:48-- https://zenodo.org/record/2559313/files/preproc_targets.nc\n", 915 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 916 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 917 | "HTTP request sent, awaiting response...
200 OK\n", 918 | "Length: 205465846 (196M) [application/octet-stream]\n", 919 | "Saving to: ‘/local/S.Rasp/tmp/preproc_targets.nc.1’\n", 920 | "\n", 921 | "preproc_targets.nc. 100%[===================>] 195.95M 9.98MB/s in 9.5s \n", 922 | "\n", 923 | "2019-02-07 15:42:58 (20.6 MB/s) - ‘/local/S.Rasp/tmp/preproc_targets.nc.1’ saved [205465846/205465846]\n", 924 | "\n" 925 | ] 926 | } 927 | ], 928 | "source": [ 929 | "!wget -P $DATADIR https://zenodo.org/record/2559313/files/preproc_features.nc\n", 930 | "!wget -P $DATADIR https://zenodo.org/record/2559313/files/preproc_targets.nc" 931 | ] 932 | }, 933 | { 934 | "cell_type": "code", 935 | "execution_count": 104, 936 | "metadata": {}, 937 | "outputs": [ 938 | { 939 | "name": "stdout", 940 | "output_type": "stream", 941 | "text": [ 942 | "-rw-r--r-- 1 S.Rasp ls-craig 196M Feb 7 13:57 /local/S.Rasp/tmp//preproc_features.nc\r\n", 943 | "-rw-r--r-- 1 S.Rasp ls-craig 196M Feb 7 13:57 /local/S.Rasp/tmp//preproc_targets.nc\r\n" 944 | ] 945 | } 946 | ], 947 | "source": [ 948 | "!ls -lh $DATADIR/preproc*" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": 105, 954 | "metadata": {}, 955 | "outputs": [], 956 | "source": [ 957 | "ds = xr.open_dataset(f'{DATADIR}preproc_features.nc')" 958 | ] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "execution_count": 106, 963 | "metadata": {}, 964 | "outputs": [ 965 | { 966 | "data": { 967 | "text/plain": [ 968 | "\n", 969 | "Dimensions: (feature_lev: 60, sample: 778240)\n", 970 | "Coordinates:\n", 971 | " * feature_lev (feature_lev) int64 0 1 2 3 4 5 6 7 ... 
53 54 55 56 57 58 59\n", 972 | " time (sample) int64 ...\n", 973 | " lat (sample) float64 ...\n", 974 | " lon (sample) float64 ...\n", 975 | " feature_names (feature_lev) object ...\n", 976 | "Dimensions without coordinates: sample\n", 977 | "Data variables:\n", 978 | " features (sample, feature_lev) float32 ...\n", 979 | "Attributes:\n", 980 | " log: \\n Time: 2019-02-07T13:57:24\\n\\n Executed command:\\n\\n ..." 981 | ] 982 | }, 983 | "execution_count": 106, 984 | "metadata": {}, 985 | "output_type": "execute_result" 986 | } 987 | ], 988 | "source": [ 989 | "ds" 990 | ] 991 | }, 992 | { 993 | "cell_type": "code", 994 | "execution_count": 129, 995 | "metadata": {}, 996 | "outputs": [], 997 | "source": [ 998 | "# Write a new data generator\n", 999 | "class DataGeneratorPreproc(object):\n", 1000 | " \"\"\"\n", 1001 | " Data generator that randomly (if shuffle = True) picks samples from the preprocessed feature and target \n", 1002 | " files and returns them in batches as NumPy arrays (one row per sample).\n", 1003 | " \"\"\"\n", 1004 | " def __init__(self, feature_fn, target_fn, batch_size=128, shuffle=True, engine='netcdf4'):\n", 1005 | " self.feature_ds = xr.open_dataset(feature_fn, engine=engine)\n", 1006 | " self.target_ds = xr.open_dataset(target_fn, engine=engine)\n", 1007 | " self.batch_size = batch_size\n", 1008 | " self.ntot = self.feature_ds.sample.size\n", 1009 | " self.n_batches = self.ntot // batch_size\n", 1010 | " self.indices = np.arange(self.ntot)\n", 1011 | " if shuffle:\n", 1012 | " self.indices = np.random.permutation(self.indices)\n", 1013 | " def __getitem__(self, index):\n", 1014 | " batch_indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]\n", 1015 | " \n", 1016 | " X = self.feature_ds.features.isel(sample=batch_indices).values # .values forces the read; isel alone may stay lazy\n", 1017 | " Y = self.target_ds.targets.isel(sample=batch_indices).values # NumPy output, consistent with the h5py generators\n", 1018 | "\n", 1019 | " return X, Y" 1020 | ] 1021 | }, 1022 | { 1023 | "cell_type": "code", 1024 | "execution_count": 130, 1025
| "metadata": {}, 1026 | "outputs": [], 1027 | "source": [ 1028 | "gen = DataGeneratorPreproc(f'{DATADIR}preproc_features.nc', f'{DATADIR}preproc_targets.nc')" 1029 | ] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "execution_count": 131, 1034 | "metadata": {}, 1035 | "outputs": [], 1036 | "source": [ 1037 | "x, y = gen[0]" 1038 | ] 1039 | }, 1040 | { 1041 | "cell_type": "code", 1042 | "execution_count": 132, 1043 | "metadata": {}, 1044 | "outputs": [ 1045 | { 1046 | "data": { 1047 | "text/plain": [ 1048 | "((128, 60), (128, 60))" 1049 | ] 1050 | }, 1051 | "execution_count": 132, 1052 | "metadata": {}, 1053 | "output_type": "execute_result" 1054 | } 1055 | ], 1056 | "source": [ 1057 | "x.shape, y.shape" 1058 | ] 1059 | }, 1060 | { 1061 | "cell_type": "code", 1062 | "execution_count": 133, 1063 | "metadata": {}, 1064 | "outputs": [ 1065 | { 1066 | "name": "stdout", 1067 | "output_type": "stream", 1068 | "text": [ 1069 | "CPU times: user 84 ms, sys: 0 ns, total: 84 ms\n", 1070 | "Wall time: 81.6 ms\n" 1071 | ] 1072 | } 1073 | ], 1074 | "source": [ 1075 | "%%time\n", 1076 | "test(gen, 10)" 1077 | ] 1078 | }, 1079 | { 1080 | "cell_type": "code", 1081 | "execution_count": 134, 1082 | "metadata": {}, 1083 | "outputs": [], 1084 | "source": [ 1085 | "gen = DataGeneratorPreproc(f'{DATADIR}preproc_features.nc', f'{DATADIR}preproc_targets.nc', shuffle=False)" 1086 | ] 1087 | }, 1088 | { 1089 | "cell_type": "code", 1090 | "execution_count": 135, 1091 | "metadata": {}, 1092 | "outputs": [ 1093 | { 1094 | "name": "stdout", 1095 | "output_type": "stream", 1096 | "text": [ 1097 | "CPU times: user 84 ms, sys: 0 ns, total: 84 ms\n", 1098 | "Wall time: 83.9 ms\n" 1099 | ] 1100 | } 1101 | ], 1102 | "source": [ 1103 | "%%time\n", 1104 | "test(gen, 10)" 1105 | ] 1106 | }, 1107 | { 1108 | "cell_type": "code", 1109 | "execution_count": 152, 1110 | "metadata": {}, 1111 | "outputs": [], 1112 | "source": [ 1113 | "gen.feature_ds.close(); gen.target_ds.close()" 1114 | ] 1115 | }, 1116 
| { 1117 | "cell_type": "code", 1118 | "execution_count": 139, 1119 | "metadata": {}, 1120 | "outputs": [], 1121 | "source": [ 1122 | "gen = DataGeneratorPreproc(f'{DATADIR}preproc_features.nc', f'{DATADIR}preproc_targets.nc', engine='h5netcdf')" 1123 | ] 1124 | }, 1125 | { 1126 | "cell_type": "code", 1127 | "execution_count": 140, 1128 | "metadata": {}, 1129 | "outputs": [ 1130 | { 1131 | "name": "stdout", 1132 | "output_type": "stream", 1133 | "text": [ 1134 | "CPU times: user 84 ms, sys: 0 ns, total: 84 ms\n", 1135 | "Wall time: 80.6 ms\n" 1136 | ] 1137 | } 1138 | ], 1139 | "source": [ 1140 | "%%time\n", 1141 | "test(gen, 10)" 1142 | ] 1143 | }, 1144 | { 1145 | "cell_type": "markdown", 1146 | "metadata": {}, 1147 | "source": [ 1148 | "So these are the sort of times that are required for training a neural network." 1149 | ] 1150 | }, 1151 | { 1152 | "cell_type": "markdown", 1153 | "metadata": {}, 1154 | "source": [ 1155 | "### Pure h5py version" 1156 | ] 1157 | }, 1158 | { 1159 | "cell_type": "code", 1160 | "execution_count": 158, 1161 | "metadata": {}, 1162 | "outputs": [], 1163 | "source": [ 1164 | "class DataGeneratorPreprocH5(object):\n", 1165 | " \"\"\"\n", 1166 | " Data generator that returns consecutive batches of (features, targets) read with plain h5py. No shuffling is \n", 1167 | " done here:
the preprocessed files are already shuffled, so contiguous slices already form random batches.\n", 1168 | " \"\"\"\n", 1169 | " def __init__(self, feature_fn, target_fn, batch_size=128):\n", 1170 | " self.feature_ds = xr.open_dataset(feature_fn)\n", 1171 | " self.target_ds = xr.open_dataset(target_fn)\n", 1172 | " self.batch_size = batch_size\n", 1173 | " self.ntot = self.feature_ds.sample.size\n", 1174 | " self.n_batches = self.ntot // batch_size\n", 1175 | " \n", 1176 | " # Close xarray dataset and open h5py object\n", 1177 | " self.feature_ds.close()\n", 1178 | " self.feature_ds = h5py.File(feature_fn, 'r')\n", 1179 | " self.target_ds.close()\n", 1180 | " self.target_ds = h5py.File(target_fn, 'r')\n", 1181 | " \n", 1182 | " def __getitem__(self, index):\n", 1183 | " \n", 1184 | " X = self.feature_ds['features'][index*self.batch_size:(index+1)*self.batch_size, :] # contiguous read; relies on pre-shuffled files\n", 1185 | " Y = self.target_ds['targets'][index*self.batch_size:(index+1)*self.batch_size, :]\n", 1186 | "\n", 1187 | " return X, Y" 1188 | ] 1189 | }, 1190 | { 1191 | "cell_type": "code", 1192 | "execution_count": 159, 1193 | "metadata": {}, 1194 | "outputs": [], 1195 | "source": [ 1196 | "gen.feature_ds.close(); gen.target_ds.close()" 1197 | ] 1198 | }, 1199 | { 1200 | "cell_type": "code", 1201 | "execution_count": 160, 1202 | "metadata": {}, 1203 | "outputs": [], 1204 | "source": [ 1205 | "gen = DataGeneratorPreprocH5(f'{DATADIR}preproc_features.nc', f'{DATADIR}preproc_targets.nc')" 1206 | ] 1207 | }, 1208 | { 1209 | "cell_type": "code", 1210 | "execution_count": 161, 1211 | "metadata": {}, 1212 | "outputs": [ 1213 | { 1214 | "name": "stdout", 1215 | "output_type": "stream", 1216 | "text": [ 1217 | "CPU times: user 8 ms, sys: 0 ns, total: 8 ms\n", 1218 | "Wall time: 6.61 ms\n" 1219 | ] 1220 | } 1221 | ], 1222 | "source": [ 1223 | "%%time\n", 1224 | "test(gen, 10)" 1225 | ] 1226 | }, 1227 | { 1228 | "cell_type": "markdown", 1229 | "metadata": {}, 1230 | "source": [ 1231 | "So again, the pure h5py version is an order of
magnitude faster than the xarray version." 1232 | ] 1233 | }, 1234 | { 1235 | "cell_type": "markdown", 1236 | "metadata": {}, 1237 | "source": [ 1238 | "## End" 1239 | ] 1240 | }, 1241 | { 1242 | "cell_type": "code", 1243 | "execution_count": null, 1244 | "metadata": {}, 1245 | "outputs": [], 1246 | "source": [] 1247 | } 1248 | ], 1249 | "metadata": { 1250 | "kernelspec": { 1251 | "display_name": "Python 3", 1252 | "language": "python", 1253 | "name": "python3" 1254 | }, 1255 | "language_info": { 1256 | "codemirror_mode": { 1257 | "name": "ipython", 1258 | "version": 3 1259 | }, 1260 | "file_extension": ".py", 1261 | "mimetype": "text/x-python", 1262 | "name": "python", 1263 | "nbconvert_exporter": "python", 1264 | "pygments_lexer": "ipython3", 1265 | "version": "3.6.5" 1266 | }, 1267 | "toc": { 1268 | "base_numbering": 1, 1269 | "nav_menu": {}, 1270 | "number_sections": true, 1271 | "sideBar": false, 1272 | "skip_h1_title": true, 1273 | "title_cell": "Table of Contents", 1274 | "title_sidebar": "Contents", 1275 | "toc_cell": false, 1276 | "toc_position": {}, 1277 | "toc_section_display": true, 1278 | "toc_window_display": false 1279 | } 1280 | }, 1281 | "nbformat": 4, 1282 | "nbformat_minor": 2 1283 | } 1284 | --------------------------------------------------------------------------------