├── run ├── data.csv ├── simple.py ├── nyc.py ├── nyc-function.py └── pytorch.py ├── images └── sea-level-rise.png ├── .pre-commit-config.yaml ├── .dockerignore ├── batch ├── hello.py ├── hello.sh ├── accelerate-example.sh ├── gdal.sh └── nlp_example.py ├── README.md ├── .devcontainer ├── noop.txt ├── dev.Dockerfile └── devcontainer.json ├── national-water-model ├── environment.yaml ├── xarray-water-model.py ├── make-timelapse-video.ipynb └── xarray-water-model.ipynb ├── datashader.yml ├── spark.yml ├── pytorch.yml ├── environment.yml ├── .github └── workflows │ ├── linting.yml │ └── docker.yml ├── Dockerfile ├── reproject.sh ├── satellite-imagery.ipynb ├── spark.ipynb ├── xarray.ipynb ├── sea-level-rise.ipynb ├── datashader.ipynb ├── futures.ipynb ├── xgboost-optuna.ipynb ├── xgboost.ipynb ├── uber-lyft.ipynb ├── pytorch.ipynb ├── arxiv-matplotlib.ipynb └── pytorch-optuna.ipynb /run/data.csv: -------------------------------------------------------------------------------- 1 | name,x,y 2 | Alice,1,100 3 | Bob,2,200 4 | Charlie,3,300 5 | -------------------------------------------------------------------------------- /images/sea-level-rise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coiled/examples/HEAD/images/sea-level-rise.png -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/kynan/nbstripout 3 | rev: 0.6.1 4 | hooks: 5 | - id: nbstripout -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .dockerignore 2 | .git 3 | .github 4 | .gitignore 5 | .ipynb_checkpoints 6 | .pre-commit-config.yaml 7 | .devcontainer 8 | Dockerfile -------------------------------------------------------------------------------- /batch/hello.py: -------------------------------------------------------------------------------- 1 | #COILED memory 8 GiB 2 | #COILED ntasks 10 3 | 4 | #COILED region us-east-2 5 | 6 | import os 7 | print("Hello from", os.environ["COILED_ARRAY_TASK_ID"]) 8 | -------------------------------------------------------------------------------- /batch/hello.sh: -------------------------------------------------------------------------------- 1 | #COILED memory 8 GiB 2 | #COILED ntasks 100 3 | #COILED container ubuntu:latest 4 | #COILED region us-east-2 5 | 6 | 7 | echo Hello from $COILED_ARRAY_TASK_ID ! 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Coiled Examples 2 | =============== 3 | 4 | Some simple example notebooks to show how to use Dask and Coiled. 5 | 6 | ```bash 7 | pip install coiled 8 | coiled setup 9 | ``` 10 | -------------------------------------------------------------------------------- /.devcontainer/noop.txt: -------------------------------------------------------------------------------- 1 | This file is copied into the container along with environment.yml* from the 2 | parent folder. This is done to prevent the Dockerfile COPY instruction from 3 | failing if no environment.yml is found. -------------------------------------------------------------------------------- /run/simple.py: -------------------------------------------------------------------------------- 1 | import coiled 2 | import pandas as pd 3 | 4 | df = pd.read_csv("data.csv") 5 | 6 | @coiled.function( 7 | region="us-west-1", 8 | ) 9 | def process(df): 10 | df["z"] = df.x + df.y 11 | return df 12 | 13 | print(process(df)) 14 | -------------------------------------------------------------------------------- /national-water-model/environment.yaml: -------------------------------------------------------------------------------- 1 | name: nwm 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.11 6 | - xarray 7 | - dask 8 | - flox 9 | - rioxarray 10 | - coiled 11 | - zarr 12 | - s3fs 13 | - hvplot 14 | - geopandas 15 | - geoviews 16 | -------------------------------------------------------------------------------- /datashader.yml: -------------------------------------------------------------------------------- 1 | name: datashader 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python==3.10 6 | - dask 7 | - coiled 8 | - pyarrow 9 | - ipykernel 10 | - s3fs 11 | - datashader 12 | - holoviews 13 | - hvplot 14 | - jupyterlab 15 | - dask-labextension 16 | -------------------------------------------------------------------------------- /spark.yml: -------------------------------------------------------------------------------- 1 | name: spark 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.11 6 | - dask 7 | - coiled 8 | - ipykernel 9 | - pyspark==3.4.1 10 | - pyarrow 11 | - grpcio 12 | - grpcio-status 13 | - openjdk~=11.0 14 | - protobuf 15 | - jupyterlab 16 | - s3fs 17 | -------------------------------------------------------------------------------- /pytorch.yml: -------------------------------------------------------------------------------- 1 | name: pytorch 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.11 6 | - dask 7 | - coiled 8 | - ipykernel 9 | - ipython 10 | - dask-labextension 11 | - jupyterlab 12 | - optuna 13 | - pytorch 14 | - torchvision 15 | - matplotlib 16 | - pip 17 | - optuna-integration 18 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: coiled-examples 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - python=3.10 7 | - dask 8 | - dask-ml 9 | - coiled 10 | - pyarrow 11 | - s3fs 12 | - matplotlib 13 | - ipykernel 14 | - dask-labextension 15 | - xgboost 16 | - pandas=2 17 | - optuna 18 | - xarray 19 | - geogif 20 | - zarr 21 | - h5netcdf 22 | - pre-commit 23 | - earthaccess 24 | -------------------------------------------------------------------------------- /.github/workflows/linting.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | jobs: 10 | checks: 11 | name: pre-commit hooks 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3.5.3 15 | - uses: actions/setup-python@v4 16 | with: 17 | python-version: '3.9' 18 | - uses: pre-commit/action@v3.0.0 19 | -------------------------------------------------------------------------------- /batch/accelerate-example.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | #COILED n-tasks 20 4 | #COILED vm-type g6.xlarge 5 | #COILED task-on-scheduler True 6 | 7 | accelerate launch \ 8 | --multi_gpu \ 9 | --machine_rank $COILED_BATCH_TASK_ID \ 10 | --main_process_ip $COILED_BATCH_SCHEDULER_ADDRESS \ 11 | --main_process_port 12345 \ 12 | --num_machines $COILED_BATCH_TASK_COUNT \ 13 | --num_processes $COILED_BATCH_TASK_COUNT \ 14 | nlp_example.py 15 | -------------------------------------------------------------------------------- /run/nyc.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | df = pd.read_parquet( 4 | "s3://nyc-tlc/trip data/yellow_tripdata_2023-04.parquet" 5 | ) 6 | 7 | print("Head") 8 | print("====") 9 | print(df.head()) 10 | 11 | print("Columns") 12 | print("=======") 13 | print(df.columns) 14 | 15 | print("Tip Percentage") 16 | print("==============") 17 | print((df.tip_amount != 0).mean()) 18 | 19 | print("Uploading tipped rides to S3") 20 | 21 | 22 | df = df[df.tip_amount != 0] 23 | df.to_parquet( 24 | "s3://oss-shared-scratch/mrocklin/nyc-tipped-2023-04.parquet" 25 | ) 26 | 27 | 28 | print("Done") 29 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CUDA_VER=11.8.0 2 | ARG PYTHON_VER=3.10 3 | ARG LINUX_VER=ubuntu20.04 4 | 5 | FROM rapidsai/miniforge-cuda:cuda${CUDA_VER}-base-${LINUX_VER}-py${PYTHON_VER} as base 6 | ARG CUDA_VER 7 | ARG PYTHON_VER 8 | 9 | WORKDIR /home 10 | COPY pytorch.yml pytorch.yml 11 | RUN mamba env update -n base --file pytorch.yml \ 12 | && conda clean -afy \ 13 | && mamba uninstall -y pytorch torchvision \ 14 | && mamba install -y -n base -c pytorch -c nvidia -c conda-forge \ 15 | "cudatoolkit=${CUDA_VER%.*}.*" \ 16 | "cuda-version=${CUDA_VER%.*}.*" \ 17 | pytorch \ 18 | torchvision \ 19 | "pytorch-cuda=${CUDA_VER%.*}.*" \ 20 | && conda clean -afy \ 21 | && rm pytorch.yml 22 | 23 | 24 | FROM base as examples 25 | COPY . . 26 | -------------------------------------------------------------------------------- /run/nyc-function.py: -------------------------------------------------------------------------------- 1 | import coiled 2 | from dask.distributed import print 3 | import pandas as pd 4 | import s3fs 5 | 6 | s3 = s3fs.S3FileSystem() 7 | filenames = s3.ls("s3://nyc-tlc/trip data/") 8 | filenames = [ 9 | "s3://" + fn 10 | for fn in filenames 11 | if "yellow_tripdata_2022" in fn 12 | ] 13 | 14 | @coiled.function( 15 | region="us-east-1", 16 | memory="8 GiB", 17 | ) 18 | def process(filename): 19 | df = pd.read_parquet(filename) 20 | df = df[df.tip_amount != 0] 21 | 22 | outfile = "s3://oss-shared-scratch/mrocklin/" + filename.split("/")[-1] 23 | df.to_parquet(outfile) 24 | print("Finished", outfile) 25 | 26 | 27 | print(f"\nProcessing {len(filenames)} files") 28 | for filename in filenames: 29 | print("Processing", filename) 30 | process(filename) 31 | -------------------------------------------------------------------------------- /batch/gdal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | #COILED n-tasks 3111 4 | #COILED max-workers 200 5 | #COILED region us-west-2 6 | #COILED memory 8 GiB 7 | #COILED container ghcr.io/osgeo/gdal 8 | #COILED forward-aws-credentials 9 | 10 | # Install aws CLI 11 | if [ ! "$(which aws)" ]; then 12 | curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" 13 | unzip -qq awscliv2.zip 14 | ./aws/install 15 | fi 16 | 17 | # Download file to be processed 18 | filename=$(aws s3 ls --no-sign-request --recursive s3://sentinel-cogs/sentinel-s2-l2a-cogs/54/E/XR/ | \ 19 | grep ".tif" | \ 20 | awk '{print $4}' | \ 21 | awk "NR==$(($COILED_ARRAY_TASK_ID + 1))") 22 | aws s3 cp --no-sign-request s3://sentinel-cogs/$filename in.tif 23 | 24 | # Reproject GeoTIFF 25 | gdalwarp -t_srs EPSG:4326 in.tif out.tif 26 | 27 | # Move result to processed bucket 28 | aws s3 mv out.tif s3://oss-scratch-space/sentinel-reprojected/$filename 29 | -------------------------------------------------------------------------------- /reproject.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | #COILED n-tasks 3111 4 | #COILED max-workers 100 5 | #COILED region us-west-2 6 | #COILED memory 8 GiB 7 | #COILED container ghcr.io/osgeo/gdal 8 | #COILED forward-aws-credentials True 9 | 10 | # Install aws CLI 11 | if [ ! "$(which aws)" ]; then 12 | curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" 13 | unzip -qq awscliv2.zip 14 | ./aws/install 15 | fi 16 | 17 | # Download file to be processed 18 | filename=$(aws s3 ls --no-sign-request --recursive s3://sentinel-cogs/sentinel-s2-l2a-cogs/54/E/XR/ | \ 19 | grep ".tif" | \ 20 | awk '{print $4}' | \ 21 | awk "NR==$(($COILED_BATCH_TASK_ID + 1))") 22 | aws s3 cp --no-sign-request s3://sentinel-cogs/$filename in.tif 23 | 24 | # Reproject GeoTIFF 25 | gdalwarp -t_srs EPSG:4326 in.tif out.tif 26 | 27 | # Move result to processed bucket 28 | aws s3 mv out.tif s3://oss-scratch-space/sentinel-reprojected/$filename -------------------------------------------------------------------------------- /.devcontainer/dev.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/mamba-org/micromamba-devcontainer:git-d175103 2 | 3 | # Ensure that all users have read-write access to all files created in the subsequent commands. 4 | ARG DOCKERFILE_UMASK=0000 5 | 6 | # Install the Conda packages. 7 | COPY --chown=$MAMBA_USER:$MAMBA_USER environment.yml /tmp/environment.yml 8 | RUN : \ 9 | # Configure Conda to use the conda-forge channel 10 | && micromamba config append channels conda-forge \ 11 | # Install and clean up 12 | && micromamba install --yes --name base \ 13 | --category dev --category main --file /tmp/environment.yml \ 14 | && micromamba clean --all --yes \ 15 | ; 16 | 17 | # Activate the conda environment for the Dockerfile. 18 | # 19 | ARG MAMBA_DOCKERFILE_ACTIVATE=1 20 | # Create and set the workspace folder 21 | ARG CONTAINER_WORKSPACE_FOLDER=/workspaces/default-workspace-folder 22 | RUN mkdir -p "${CONTAINER_WORKSPACE_FOLDER}" 23 | WORKDIR "${CONTAINER_WORKSPACE_FOLDER}" 24 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Docker build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | jobs: 11 | docker: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | include: 17 | - repo: "coiled/gpu" 18 | target: "base" 19 | - repo: "coiled/gpu-examples" 20 | target: "examples" 21 | steps: 22 | - name: Set up QEMU 23 | uses: docker/setup-qemu-action@v2 24 | 25 | - name: Set up Docker Buildx 26 | uses: docker/setup-buildx-action@v2 27 | 28 | - name: Login to Docker Hub 29 | uses: docker/login-action@v2 30 | with: 31 | username: ${{ secrets.DOCKERHUB_USERNAME }} 32 | password: ${{ secrets.DOCKERHUB_TOKEN }} 33 | 34 | - name: Build ${{ matrix.repo }} image and push 35 | uses: docker/build-push-action@v4 36 | with: 37 | push: ${{ github.repository == 'coiled/examples' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch') && github.ref == 'refs/heads/main'}} 38 | platforms: linux/amd64 39 | target: ${{ matrix.target }} 40 | tags: ${{ matrix.repo }}:latest,${{ matrix.repo }}:${{ github.sha }} -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/miniconda 3 | { 4 | "name": "Miniconda (Python 3)", 5 | "containerEnv": { 6 | "LOCAL_WORKSPACE_FOLDER": "${localWorkspaceFolder}", 7 | "CONTAINER_WORKSPACE_FOLDER": "${containerWorkspaceFolder}", 8 | }, 9 | "build": { 10 | "context": "..", 11 | "dockerfile": "dev.Dockerfile", 12 | "args": { 13 | "CONTAINER_WORKSPACE_FOLDER": "${containerWorkspaceFolder}", 14 | } 15 | }, 16 | "overrideCommand": false, 17 | "customizations": { 18 | "vscode": { 19 | "extension": [ 20 | "ms-python.python", 21 | "ms-python.vscode-pylance", 22 | "ms-python.black", 23 | "ms-python.isort", 24 | "ms-toolsai.jupyter", 25 | "charliermarsh.ruff" 26 | ] 27 | } 28 | }, 29 | "remoteUser": "mambauser", 30 | // Features to add to the dev container. More info: https://containers.dev/features. 31 | // "features": {}, 32 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 33 | // "forwardPorts": [], 34 | // Use 'postCreateCommand' to run commands after the container is created. 35 | // "postCreateCommand": "python --version", 36 | // Configure tool-specific properties. 37 | // "customizations": {}, 38 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 39 | // "remoteUser": "root" 40 | } -------------------------------------------------------------------------------- /national-water-model/xarray-water-model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example was adapted from https://github.com/dcherian/dask-demo/blob/main/nwm-aws.ipynb 3 | """ 4 | 5 | import coiled 6 | import dask 7 | import flox.xarray 8 | import fsspec 9 | import numpy as np 10 | import rioxarray 11 | import xarray as xr 12 | 13 | 14 | # optionally run with coiled run 15 | # coiled run --region us-east-1 --vm-type m6g.xlarge python national-water-model/xarray-water-model.py 16 | 17 | cluster = coiled.Cluster( 18 | name="nwm-1979-2020", 19 | region="us-east-1", # close to data 20 | n_workers=10, 21 | scheduler_vm_types="r7g.xlarge", # ARM instance 22 | worker_vm_types="r7g.2xlarge", 23 | spot_policy="spot_with_fallback", # use spot, replace with on-demand 24 | ) 25 | 26 | client = cluster.get_client() 27 | cluster.adapt(minimum=10, maximum=200) 28 | 29 | ds = xr.open_zarr( 30 | fsspec.get_mapper("s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", anon=True), 31 | consolidated=True, 32 | chunks={"time": 896, "x": 350, "y": 350}, 33 | ) 34 | 35 | subset = ds.zwattablrt.sel(time=slice("1979-02-01", "2020-12-31")) 36 | 37 | fs = fsspec.filesystem("s3", requester_pays=True) 38 | 39 | with dask.annotate(retries=3): 40 | counties = rioxarray.open_rasterio( 41 | fs.open("s3://nwm-250m-us-counties/Counties_on_250m_grid.tif"), chunks="auto" 42 | ).squeeze() 43 | 44 | # remove any small floating point error in coordinate locations 45 | _, counties_aligned = xr.align(subset, counties, join="override") 46 | 47 | counties_aligned = counties_aligned.persist() 48 | 49 | county_id = np.unique(counties_aligned.data).compute() 50 | county_id = county_id[county_id != 0] 51 | print(f"There are {len(county_id)} counties!") 52 | 53 | county_mean = flox.xarray.xarray_reduce( 54 | subset, 55 | counties_aligned.rename("county"), 56 | func="mean", 57 | expected_groups=(county_id,), 58 | ) 59 | 60 | county_mean.load() 61 | yearly_mean = county_mean.mean("time") 62 | # optionally, save dataset for further analysis 63 | # print("Saving") 64 | # yearly_mean.to_netcdf("mean_zwattablrt_nwm_1979_2020.nc") 65 | cluster.shutdown() 66 | -------------------------------------------------------------------------------- /satellite-imagery.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e6a0105b-46d5-4514-ab80-f1be4d930106", 6 | "metadata": {}, 7 | "source": [ 8 | "# Satellite Image Processing" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "f2f207b1-e454-4cbf-b846-b7e0d815c784", 14 | "metadata": {}, 15 | "source": [ 16 | "## List file paths in the cloud" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "f3221ff3-24bd-4e38-a974-efde525e88b1", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import s3fs\n", 27 | "\n", 28 | "s3 = s3fs.S3FileSystem()\n", 29 | "\n", 30 | "urls = s3.glob(\"sentinel-cogs/sentinel-s2-l2a-cogs/1/C/CV/20*/*/*/*.tif\")\n", 31 | "urls[:5]" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "f2cd4092-358a-4a69-8013-e0b688b9583a", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "len(urls)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "91f84137-28c0-4aea-9698-760cf6d93665", 47 | "metadata": {}, 48 | "source": [ 49 | "## Process each file sequentially" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "6be2be6b-c673-4caf-aa9d-aa01c0262316", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "import rasterio, rioxarray\n", 60 | "\n", 61 | "def process(url):\n", 62 | " data = rioxarray.open_rasterio(\"s3://\" + url)\n", 63 | " \n", 64 | " # TODO: do real work with data\n", 65 | "\n", 66 | " return ...\n", 67 | "\n", 68 | "for url in urls:\n", 69 | " process(url)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "7d46d6ac-7b83-44a6-9b18-9e10f8b34e0a", 75 | "metadata": {}, 76 | "source": [ 77 | "## Process each file in parallel on the cloud" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "7cb4013d-f2f9-497b-a5f1-9fb14bc50627", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "import coiled # New!\n", 88 | "import rasterio, rioxarray\n", 89 | "\n", 90 | "@coiled.function( # New!\n", 91 | " region=\"us-west-2\",\n", 92 | ")\n", 93 | "def process(url):\n", 94 | " data = rioxarray.open_rasterio(\"s3://\" + url)\n", 95 | " \n", 96 | " # TODO: do real work with data\n", 97 | "\n", 98 | " return ...\n", 99 | "\n", 100 | "results = process.map(urls)" 101 | ] 102 | } 103 | ], 104 | "metadata": { 105 | "kernelspec": { 106 | "display_name": "Python [conda env:coiled-examples]", 107 | "language": "python", 108 | "name": "conda-env-coiled-examples-py" 109 | }, 110 | "language_info": { 111 | "codemirror_mode": { 112 | "name": "ipython", 113 | "version": 3 114 | }, 115 | "file_extension": ".py", 116 | "mimetype": "text/x-python", 117 | "name": "python", 118 | "nbconvert_exporter": "python", 119 | "pygments_lexer": "ipython3", 120 | "version": "3.10.15" 121 | } 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 5 125 | } 126 | -------------------------------------------------------------------------------- /spark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f8eafed2-77a1-4691-8e8b-aeb1187ce8f5", 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "source": [ 10 | "\n", 11 | "Spark on Coiled\n", 12 | "===============\n", 13 | "\n", 14 | "\n", 17 | "\n", 18 | "Coiled can run Spark Jobs.\n", 19 | "\n", 20 | "You get all the same Coiled ease of use features:\n", 21 | "\n", 22 | "1. Quick startup\n", 23 | "2. Copies all of your local packages and code\n", 24 | "3. Runs in any region on any hardware\n", 25 | "4. Runs from your local notebook\n", 26 | "\n", 27 | "But now rather than just Dask you can run Spark too." 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "0d130128-ac72-4ce6-87b1-b7a20337fd2a", 33 | "metadata": {}, 34 | "source": [ 35 | "### Read a little bit of data with pandas" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "09728a96-0c84-4198-ab52-4dcdfd704606", 42 | "metadata": { 43 | "tags": [] 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "import pandas as pd\n", 48 | "\n", 49 | "df = pd.read_parquet(\n", 50 | " \"s3://coiled-data/uber/part.0.parquet\",\n", 51 | ")\n", 52 | "df" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "3148ad8d-3de7-47b6-91a3-1d1f5a393f64", 58 | "metadata": {}, 59 | "source": [ 60 | "## Start Spark cluster to read lots of data" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "9a9e6076-c8b3-4282-90a4-0fe3ab49440d", 67 | "metadata": { 68 | "tags": [] 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "import coiled\n", 73 | "\n", 74 | "cluster = coiled.Cluster(\n", 75 | " n_workers=10,\n", 76 | " worker_memory=\"16 GiB\",\n", 77 | " region=\"us-east-2\",\n", 78 | ")\n", 79 | "\n", 80 | "spark = cluster.get_spark()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "33b598a4-fe0a-43c5-8007-0e955ac193f9", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "df = spark.read.parquet(\"s3a://coiled-data/uber\")\n", 91 | "df.show()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "56e2b982-af1b-4140-8cc1-414343ba1f0a", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "df.count()" 102 | ] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python [conda env:spark]", 108 | "language": "python", 109 | "name": "conda-env-spark-py" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.11.9" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /xarray.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "75688ac6-879d-4449-b73e-74f03a5f991f", 6 | "metadata": { 7 | "tags": [], 8 | "user_expressions": [] 9 | }, 10 | "source": [ 11 | "\n", 14 | "\n", 15 | "# Geospatial Large\n", 16 | "\n", 17 | "This is a national water model: https://registry.opendata.aws/nwm-archive/" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "id": "8185966d-6659-482b-bcbb-826b8f30b1e3", 23 | "metadata": { 24 | "tags": [] 25 | }, 26 | "source": [ 27 | "## Load NWM data" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "e8b1749a-0d64-4278-823c-892120bf1a5b", 34 | "metadata": { 35 | "tags": [] 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "import xarray as xr\n", 40 | "\n", 41 | "ds = xr.open_zarr(\n", 42 | " \"s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr\",\n", 43 | " consolidated=True,\n", 44 | ").drop_encoding()\n", 45 | "ds" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "5dd71599-465f-4c97-baaa-19d900d2a070", 51 | "metadata": { 52 | "user_expressions": [] 53 | }, 54 | "source": [ 55 | "## Set up cluster" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "60b08a1c-d042-40f2-aaaa-e7665ca85d64", 62 | "metadata": { 63 | "tags": [] 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "import coiled\n", 68 | "\n", 69 | "cluster = coiled.Cluster(\n", 70 | " n_workers=100,\n", 71 | " region=\"us-east-1\",\n", 72 | ")\n", 73 | "client = cluster.get_client()" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "id": "0911fb96-7c08-4ca6-a35a-22e2a5a908cd", 79 | "metadata": { 80 | "tags": [] 81 | }, 82 | "source": [ 83 | "## Compute average over space" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "2a6fb91d-6a02-4afc-8d8a-ec3529f805f4", 90 | "metadata": { 91 | "tags": [] 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "subset = ds.zwattablrt.sel(time=slice(\"2001-01-01\", \"2001-03-31\"))\n", 96 | "subset" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "8ae07b31-383c-4cc9-b94a-cbbb68369746", 103 | "metadata": { 104 | "tags": [] 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "avg = subset.mean(dim=[\"x\", \"y\"]).compute()\n", 109 | "avg.plot()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "b237d221-a2db-44fb-924d-6003cd73f933", 115 | "metadata": {}, 116 | "source": [ 117 | "## Rechunk" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "8ac6e24d-24d6-445d-a532-438b9d3a13f9", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "result = subset.chunk({\"time\": \"auto\", \"x\": -1, \"y\": \"auto\"})\n", 128 | "result" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "657ce639-b644-42ea-b98b-b70c2cb3170a", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "%%time\n", 139 | "\n", 140 | "result.to_zarr(\"s3://oss-scratch-space/nwm-x-optimized.zarr\", mode=\"w\")" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "268ccde2-7e83-4e97-9fb7-4887a52adbe6", 146 | "metadata": {}, 147 | "source": [ 148 | "## Cleanup if you like\n", 149 | "\n", 150 | "(but we'll clean up automatically eventually)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "1f71eff1-986f-4a7c-9ebe-92437effff8b", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "cluster.shutdown()" 161 | ] 162 | } 163 | ], 164 | "metadata": { 165 | "kernelspec": { 166 | "display_name": "Python 3 (ipykernel)", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.10.17" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 5 185 | } 186 | -------------------------------------------------------------------------------- /sea-level-rise.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8f552380-cc0a-4fb8-b4b3-276a6564b239", 6 | "metadata": {}, 7 | "source": [ 8 | "# Analyzing Sea Level Rise in the Cloud with Coiled and Earthaccess\n", 9 | "\n", 10 | "_This notebook was adapted from [this NASA Earthdata Cloud Cookbook example](https://nasa-openscapes.github.io/earthdata-cloud-cookbook/tutorials/Sea_Level_Rise.html)_" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "4a30aadb-41d4-4c08-a6fa-5ea60c6bd695", 16 | "metadata": {}, 17 | "source": [ 18 | "## Get data files with `earthaccess`" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "899b10a6-51fc-4d97-96d2-54e81a8b8f7d", 25 | "metadata": { 26 | "tags": [] 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# Authenticate my machine with `earthaccess`\n", 31 | "import earthaccess\n", 32 | "\n", 33 | "earthaccess.login();" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "d2d48b80-6e72-4ad0-a71d-915d96f1ce03", 40 | "metadata": { 41 | "tags": [] 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "%%time\n", 46 | "\n", 47 | "# Retrieve data files for the dataset I'm interested in\n", 48 | "granules = earthaccess.search_data(\n", 49 | " short_name=\"SEA_SURFACE_HEIGHT_ALT_GRIDS_L4_2SATS_5DAY_6THDEG_V_JPL2205\",\n", 50 | " temporal=(\"2000\", \"2019\"),\n", 51 | ")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "64df6beb-e1e9-4691-b6ec-18dfa6803f53", 57 | "metadata": {}, 58 | "source": [ 59 | "## Define processing function" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "f5cb0793-5273-4ede-9b8d-080a79ab8228", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "import coiled\n", 70 | "import xarray as xr\n", 71 | "\n", 72 | "@coiled.function(region=\"us-west-2\") # Same region as data\n", 73 | "def process(granule, fs):\n", 74 | " results = []\n", 75 | " for file in granule.data_links(\"direct\"):\n", 76 | " ds = xr.open_dataset(fs.open(file))\n", 77 | " ds = ds.sel(Latitude=slice(23, 50), Longitude=slice(270, 330))\n", 78 | " ds = ds.SLA.where((ds.SLA >= 0) & (ds.SLA < 10))\n", 79 | " results.append(ds)\n", 80 | " result = xr.concat(results, dim=\"Time\")\n", 81 | " return result" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "id": "fd853bd7-244e-4fbb-9cee-5e7921d082b1", 87 | "metadata": {}, 88 | "source": [ 89 | "## Process Granules" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "f1d3fb8a-8f86-406c-860d-7ee9c8eb45ba", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "fs = earthaccess.get_s3fs_session(results=granules)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "86988643-b344-43c5-9ad7-51c17c24b4d7", 105 | "metadata": {}, 106 | "source": [ 107 | "### Process single file" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "b268d9e1-b056-444e-9b99-1171e2f5b075", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "chunk = process(granules[0], fs=fs)\n", 118 | "chunk.plot(x=\"Longitude\", y=\"Latitude\", figsize=(14, 6));" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "cbe32079-17f9-45a8-9601-c0fe21980183", 124 | "metadata": {}, 125 | "source": [ 126 | "### Process all files in parallel" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "d187d258-4caf-4430-b4f1-d34d920a83a1", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "%%time\n", 137 | "\n", 138 | "chunks = process.map(granules, fs=fs) # This runs on the cloud in parallel\n", 139 | "ds = xr.concat(chunks, dim=\"Time\")\n", 140 | "ds.std(\"Time\").plot(x=\"Longitude\", y=\"Latitude\", figsize=(14, 6));" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "41e5f464-5504-41d6-a170-96c57a4c6808", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [] 150 | } 151 | ], 152 | "metadata": { 153 | "kernelspec": { 154 | "display_name": "Python 3 (ipykernel)", 155 | "language": "python", 156 | "name": "python3" 157 | }, 158 | "language_info": { 159 | "codemirror_mode": { 160 | "name": "ipython", 161 | "version": 3 162 | }, 163 | "file_extension": ".py", 164 | "mimetype": "text/x-python", 165 | "name": "python", 166 | "nbconvert_exporter": "python", 167 | "pygments_lexer": "ipython3", 168 | "version": "3.9.16" 169 | } 170 | }, 171 | "nbformat": 4, 172 | "nbformat_minor": 5 173 | } 174 | -------------------------------------------------------------------------------- /national-water-model/make-timelapse-video.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Data Prep" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import geopandas as gpd\n", 17 | "import xarray as xr\n", 18 | "import numpy as np\n", 19 | "\n", 20 | "# Read county shapefile, combo of state FIPS code and county FIPS code as multi-index\n", 21 | "counties = gpd.read_file(\n", 22 | " \"https://www2.census.gov/geo/tiger/GENZ2020/shp/cb_2020_us_county_20m.zip\"\n", 23 | ").to_crs(\"EPSG:3395\")\n", 24 | "counties[\"STATEFP\"] = counties.STATEFP.astype(int)\n", 25 | "counties[\"COUNTYFP\"] = counties.COUNTYFP.astype(int)\n", 26 | "continental = counties[\n", 27 | " ~counties[\"STATEFP\"].isin([2, 15, 72])\n", 28 | "] # drop Alaska, Hawaii, Puerto Rico\n", 29 | "\n", 30 | "# Read in saved data from xarray-water-model.py\n", 31 | "ds = xr.open_dataset(\"mean_zwattablrt_nwm_1979_2020.nc\")\n", 32 | "ds[\"week\"] = ds.time.dt.strftime(\"%Y-%U\")\n", 33 | "ds = ds.groupby(\"week\").mean()\n", 34 | "# Interpret county as combo of state FIPS code and county FIPS code\n", 35 | "ds.coords[\"STATEFP\"] = (ds.county // 1000).astype(int)\n", 36 | "ds.coords[\"COUNTYFP\"] = np.mod(ds.county, 1000).astype(int)\n", 37 | "df = ds.to_dataframe().reset_index()\n", 38 | "\n", 39 | "# Join\n", 40 | "merge_df = continental.merge(df, on=[\"STATEFP\", \"COUNTYFP\"])" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "Make all the plots" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import matplotlib.pyplot as plt\n", 57 | "from mpl_toolkits.axes_grid1 import make_axes_locatable\n", 58 | "from datetime import datetime\n", 59 | "\n", 60 | "day_0 = datetime.strptime(weeks[0] + \"-0\", \"%Y-%U-%w\")\n", 61 | "weeks = merge_df.week.unique()\n", 62 | "\n", 63 | "for week in weeks:\n", 64 | " fig, ax = plt.subplots(1, 1, figsize=(7.68, 4.32)) # for 3840x2160 resolution\n", 65 | "\n", 66 | " ax.set_axis_off()\n", 67 | "\n", 68 | " divider = make_axes_locatable(ax)\n", 69 | " cax = divider.append_axes(\"bottom\", size=\"5%\", pad=0.1)\n", 70 | "\n", 71 | " cax.tick_params(labelsize=8)\n", 72 | " cax.set_title(\"Depth (in meters) of the water table\", fontsize=8)\n", 73 | "\n", 74 | " merge_df[merge_df[\"week\"] == f\"{week}\"].plot(\n", 75 | " column=\"zwattablrt\",\n", 76 | " cmap=\"BrBG_r\",\n", 77 | " vmin=0,\n", 78 | " vmax=2,\n", 79 | " legend=True,\n", 80 | " ax=ax,\n", 81 | " cax=cax,\n", 82 | " legend_kwds={\n", 83 | " \"orientation\": \"horizontal\",\n", 84 | " \"ticks\": [0, 0.5, 1, 1.5, 2],\n", 85 | " },\n", 86 | " )\n", 87 | "\n", 88 | " # Add legends for memory, time, and cost\n", 89 | " current_day = datetime.strptime(week + \"-0\", \"%Y-%U-%w\")\n", 90 | " n_days = (current_day - day_0).days\n", 91 | " memory = n_days * (16.88 * 1.07374) # daily memory converted to GB\n", 92 | " cost = (n_days / 7) * 0.01124606 # weekly cost\n", 93 | "\n", 94 | " if memory >= 1000:\n", 95 | " memory_string = f\"{memory/1000:.1f} TB processed, ~${cost:.2f} in cloud costs\"\n", 96 | " else:\n", 97 | " memory_string = f\"{memory:.0f} GB processed, ~${cost:.2f} in cloud costs\"\n", 98 | "\n", 99 | " plt.text(0, 1, memory_string, transform=ax.transAxes, size=9)\n", 100 | " # convert Year - Week Number to Month - Year\n", 101 | " date = datetime.strptime(week + \"-0\", \"%Y-%U-%w\").strftime(\"%Y %b\")\n", 102 | " plt.text(0.85, 1, f\"{date}\", transform=ax.transAxes, size=10)\n", 103 | " plt.savefig(f\"../../nwm-animation/3840x2160/{week}.png\", transparent=True, dpi=500)\n", 104 | " plt.close()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "Use [ffmpeg](https://ffmpeg.org/) to stitch the images together" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# ffmpeg -pattern_type glob -i '3840x2160/*.png' -r 60 -crf 18 -pix_fmt yuv420p nwm-video.mp4" 121 | ] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "nwm", 127 | "language": "python", 128 | "name": "python3" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.10.13" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 2 145 | } 146 | -------------------------------------------------------------------------------- /datashader.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2d6196c9-ba8e-4fa6-bb98-56d2a1679631", 6 | "metadata": {}, 7 | "source": [ 8 | "\n", 11 | "\n", 12 | "Visualize 1,000,000,000 Points\n", 13 | "==============================\n", 14 | "\n", 15 | "In this notebook we process roughly one billion points and set them up for interactive visualization." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "id": "04e68108-29fe-4bec-8800-48a488caffc6", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import dask.dataframe as dd\n", 26 | "import datashader\n", 27 | "import hvplot.dask\n", 28 | "import coiled\n", 29 | "from dask.distributed import Client, wait" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "7649d478-bce0-4b6f-a39d-3611859a81cf", 35 | "metadata": {}, 36 | "source": [ 37 | "## Create Cluster" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "147f0f8f-79f9-4133-81a9-e533c35a2a24", 44 | "metadata": { 45 | "tags": [] 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "%%time \n", 50 | "\n", 51 | "cluster = coiled.Cluster(\n", 52 | " n_workers=20,\n", 53 | " name=\"datashader\",\n", 54 | " region=\"us-east-2\", # start workers close to data to minimize costs\n", 55 | ") \n", 56 | "\n", 57 | "client = cluster.get_client()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "485c960e-e2e3-4eab-869b-e70002beb2dc", 63 | "metadata": {}, 64 | "source": [ 65 | "## Load data" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "10bb9925-a3c3-458d-aa6f-eb9821c15087", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "%%time\n", 76 | "\n", 77 | "df = dd.read_parquet(\n", 78 | " \"s3://coiled-datasets/dask-book/nyc-tlc/2009-2013/\",\n", 79 | " columns=[\"dropoff_longitude\", \"dropoff_latitude\", \"pickup_longitude\", \"pickup_latitude\"]\n", 80 | ")\n", 81 | "\n", 82 | "# clean data to limit to lat-longs near nyc\n", 83 | "df = df.loc[\n", 84 | " (df.dropoff_longitude > -74.1) & (df.dropoff_longitude < -73.7) & \n", 85 | " (df.dropoff_latitude > 40.6) & (df.dropoff_latitude < 40.9) &\n", 86 | " (df.pickup_longitude > -74.1) & (df.pickup_longitude < -73.7) &\n", 87 | " (df.pickup_latitude > 40.6) & (df.pickup_latitude < 40.9)\n", 88 | "]\n", 89 | "\n", 90 | "# now we have to get a DataFrame with just dropoff locations\n", 91 | "df_drop = df[[\"dropoff_longitude\", \"dropoff_latitude\"]]\n", 92 | "df_drop[\"journey_type\"] = \"dropoff\"\n", 93 | "df_drop = df_drop.rename(columns={'dropoff_longitude': 'long', 'dropoff_latitude': 'lat'})\n", 94 | "\n", 95 | "\n", 96 | "# now do the same for pickups\n", 97 | "df_pick = df[[\"pickup_longitude\", \"pickup_latitude\"]]\n", 98 | "df_pick[\"journey_type\"] = \"pickup\"\n", 99 | "df_pick = df_pick.rename(columns={'pickup_longitude': 'long', 'pickup_latitude': 'lat'})\n", 100 | "\n", 101 | "# concatenate two dask dataframes\n", 102 | "df_plot = dd.concat([df_drop, df_pick])\n", 103 | "\n", 104 | "df_plot = df_plot.astype({\"journey_type\": \"category\"})\n", 105 | "df_plot[\"journey_type\"] = df_plot[\"journey_type\"].cat.set_categories([\"dropoff\", \"pickup\"])\n", 106 | "\n", 107 | "#partitions are small - better to repartition\n", 108 | "df_plot = df_plot.persist()\n", 109 | "df_plot = df_plot.repartition(partition_size=\"256MiB\").persist()\n", 110 | "\n", 111 | "print(\"Number of records:\", len(df_plot))" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "90cbceaf-97b2-4b83-b384-633037ee2544", 117 | "metadata": { 118 | "tags": [] 119 | }, 120 | "source": [ 121 | "## Visualize" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "5118e81d-4d1f-4c3d-817a-689dd1bd1f06", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "import holoviews as hv\n", 132 | "hv.extension('bokeh')\n", 133 | "\n", 134 | "color_key = {\"pickup\": \"#EF1561\", \"dropoff\": \"#1F5AFF\"}\n", 135 | "\n", 136 | "df_plot.hvplot.scatter(\n", 137 | " x=\"long\", \n", 138 | " y=\"lat\", \n", 139 | " aggregator=datashader.by(\"journey_type\"), \n", 140 | " datashade=True, \n", 141 | " cnorm=\"eq_hist\",\n", 142 | " frame_width=700, \n", 143 | " aspect=1.33, \n", 144 | " color_key=color_key\n", 145 | ")" 146 | ] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "Python 3 (ipykernel)", 152 | "language": "python", 153 | "name": "python3" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.9.13" 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 5 170 | } 171 | -------------------------------------------------------------------------------- /run/pytorch.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example was adapted from the following PyTorch tutorial 3 | https://pytorch.org/tutorials/beginner/introyt/trainingyt.html 4 | """ 5 | 6 | import os 7 | import sys 8 | import dask 9 | import coiled 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.optim import SGD 14 | from torchvision import datasets, transforms 15 | 16 | 17 | coiled.create_software_environment( 18 | name="pytorch", 19 | conda={ 20 | "channels": ["pytorch", "nvidia", "conda-forge", "defaults"], 21 | "dependencies": [ 22 | "python=" + sys.version.split(" ")[0], 23 | "dask=" + dask.__version__, 24 | "coiled", 25 | "pytorch", 26 | "cudatoolkit", 27 | "pynvml", 28 | "torchvision", 29 | ], 30 | }, 31 | gpu_enabled=True, 32 | ) 33 | 34 | 35 | def load_data(): 36 | transform = transforms.Compose( 37 | [transforms.ToTensor(), 38 | transforms.Normalize((0.5,), (0.5,))]) 39 | 40 | # Create datasets for training & validation, download if necessary 41 | training_set = datasets.FashionMNIST(os.getcwd(), train=True, transform=transform, download=True) 42 | validation_set = datasets.FashionMNIST(os.getcwd(), train=False, transform=transform, download=True) 43 | 44 | # Create data loaders for our datasets; shuffle for training, not for validation 45 | training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True) 46 | validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False) 47 | 48 | # Report split sizes 49 | print('Training set has {} instances'.format(len(training_set))) 50 | print('Validation set has {} instances'.format(len(validation_set))) 51 | 52 | return training_loader, validation_loader 53 | 54 | 55 | class GarmentClassifier(nn.Module): 56 | def __init__(self): 57 | super(GarmentClassifier, self).__init__() 58 | self.conv1 = nn.Conv2d(1, 6, 5) 59 | self.pool = nn.MaxPool2d(2, 2) 60 | self.conv2 = nn.Conv2d(6, 16, 5) 61 | self.fc1 = nn.Linear(16 * 4 * 4, 120) 62 | self.fc2 = nn.Linear(120, 84) 63 | self.fc3 = nn.Linear(84, 10) 64 | 65 | def forward(self, x): 66 | x = self.pool(F.relu(self.conv1(x))) 67 | x = self.pool(F.relu(self.conv2(x))) 68 | x = x.view(-1, 16 * 4 * 4) 69 | x = F.relu(self.fc1(x)) 70 | x = F.relu(self.fc2(x)) 71 | x = self.fc3(x) 72 | return x 73 | 74 | 75 | def train_one_epoch(model, loss_fn, optimizer, training_loader, device): 76 | running_loss = 0. 77 | last_loss = 0. 78 | 79 | # Here, we use enumerate(training_loader) instead of 80 | # iter(training_loader) so that we can track the batch 81 | # index and do some intra-epoch reporting 82 | for i, data in enumerate(training_loader): 83 | # Every data instance is an input + label pair 84 | inputs, labels = data 85 | 86 | # Move to GPU 87 | inputs, labels = inputs.to(device), labels.to(device) 88 | 89 | # Zero your gradients for every batch! 90 | optimizer.zero_grad() 91 | 92 | # Make predictions for this batch 93 | outputs = model(inputs) 94 | 95 | # Compute the loss and its gradients 96 | loss = loss_fn(outputs, labels) 97 | loss.backward() 98 | 99 | # Adjust learning weights 100 | optimizer.step() 101 | 102 | # Gather data 103 | running_loss += loss.item() 104 | if i % 1000 == 999: 105 | last_loss = running_loss / 1000 # loss per batch 106 | print(' batch {} loss: {}'.format(i + 1, last_loss)) 107 | running_loss = 0. 108 | 109 | return last_loss 110 | 111 | @coiled.function( 112 | vm_type="g5.xlarge", # A GPU Instance Type 113 | software="pytorch", # Our software environment defined above 114 | region="us-west-2", # We find GPUs are easier to get here 115 | ) 116 | def train_all_epochs(): 117 | # Confirm that GPU shows up 118 | print( 119 | "Available GPU is " 120 | f"{torch.cuda.get_device_name(torch.cuda.current_device()) if torch.cuda.is_available() else ''}" 121 | ) 122 | device = ( 123 | "cuda" 124 | if torch.cuda.is_available() 125 | else "mps" 126 | if torch.backends.mps.is_available() 127 | else "cpu" 128 | ) 129 | print(f"Using {device} device") 130 | 131 | training_loader, validation_loader = load_data() 132 | model = GarmentClassifier().to(device) 133 | loss_fn = nn.CrossEntropyLoss() 134 | optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9) 135 | 136 | epochs = 5 137 | best_vloss = 1_000_000. 138 | 139 | for epoch in range(epochs): 140 | print(f'EPOCH {epoch + 1}:') 141 | 142 | # Make sure gradient tracking is on, and do a pass over the data 143 | model.train(True) 144 | avg_loss = train_one_epoch(model, loss_fn, optimizer, training_loader, device) 145 | 146 | running_vloss = 0.0 147 | # Set the model to evaluation mode, disabling dropout and using population 148 | # statistics for batch normalization. 149 | model.eval() 150 | 151 | # Disable gradient computation and reduce memory consumption. 152 | with torch.no_grad(): 153 | for i, vdata in enumerate(validation_loader): 154 | vinputs, vlabels = vdata 155 | 156 | # Move to GPU 157 | vinputs, vlabels = vinputs.to(device), vlabels.to(device) 158 | 159 | voutputs = model(vinputs) 160 | vloss = loss_fn(voutputs, vlabels) 161 | running_vloss += vloss 162 | 163 | avg_vloss = running_vloss / (i + 1) 164 | print('LOSS train {} valid {}'.format(avg_loss, avg_vloss)) 165 | 166 | # Return the best model 167 | if avg_vloss < best_vloss: 168 | best_vloss = avg_vloss 169 | best_model = model 170 | 171 | print(f"Model on CUDA device: {next(best_model.parameters()).is_cuda}") 172 | 173 | # Move model to CPU so it can be serialized and returned to local machine 174 | best_model = best_model.to("cpu") 175 | 176 | return best_model 177 | 178 | model = train_all_epochs() 179 | 180 | # Save model locally 181 | torch.save(model.state_dict(), "model.pt") 182 | 183 | # Load model back to your machine for more training, inference, or analysis 184 | # device = torch.device('cpu') 185 | # saved_model = GarmentClassifier() 186 | # saved_model.load_state_dict(torch.load('model.pt', map_location=device)) -------------------------------------------------------------------------------- /futures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "bcbbc249-ad44-41e7-acd1-d38dc5ae70cd", 6 | "metadata": {}, 7 | "source": [ 8 | "\n", 11 | "\n", 12 | "# Dask Futures for simple parallelism\n", 13 | "\n", 14 | "Dask futures are the foundation of all Dask APIs. They are easy to use and flexible. \n", 15 | "Dask futures work with any Python function on any Python object." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "42499a51-5130-4cae-8ce5-0d9234133d26", 21 | "metadata": {}, 22 | "source": [ 23 | "## Create a few processes locally" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "id": "cf01af2d-07ce-46dc-bb2e-17149678d1a6", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "from dask.distributed import Client\n", 34 | "client = Client()\n", 35 | "\n", 36 | "client" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "3e84a69c-e9a6-473f-806d-5e46f0b67bee", 42 | "metadata": {}, 43 | "source": [ 44 | "## Some basic Python code\n", 45 | "\n", 46 | "These functions pretend to do some work, but are very simple. Dask doesn't care what code it runs. You should imagine replacing these with your own code." 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "acaf1116-6802-4561-adc4-2f3858d7a21f", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import time, random\n", 57 | "\n", 58 | "def inc(x):\n", 59 | " time.sleep(random.random())\n", 60 | " return x + 1\n", 61 | "\n", 62 | "def dec(x):\n", 63 | " time.sleep(random.random())\n", 64 | " return x - 1\n", 65 | "\n", 66 | "def add(x, y):\n", 67 | " time.sleep(random.random())\n", 68 | " return x + y" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "0265fa74-1086-4218-b9f9-7af773b4b897", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "%%time\n", 79 | "inc(10)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "ef34469f-2aef-4e60-99af-5e8090cdd3bc", 85 | "metadata": {}, 86 | "source": [ 87 | "## Sequential Code\n", 88 | "\n", 89 | "This very simple code just calls these function ten times in a loop.\n", 90 | "\n", 91 | "Dask makes it easy to parallelize simple code like this on your computer." 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "2f618490-e13e-4b9e-b8ca-a62f7f4c614b", 98 | "metadata": { 99 | "tags": [] 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "%%time\n", 104 | "\n", 105 | "results = []\n", 106 | "for x in range(20):\n", 107 | " result = inc(x)\n", 108 | " result = dec(result)\n", 109 | " results.append(result)\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "38b4d6e9-f511-42b5-8708-c3350fc82894", 115 | "metadata": { 116 | "tags": [] 117 | }, 118 | "source": [ 119 | "## Parallel code\n" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "0948cc01-d3a5-455d-9b05-040dd525bc98", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "%%time\n", 130 | "\n", 131 | "results = []\n", 132 | "for x in range(20):\n", 133 | " result = client.submit(inc, x)\n", 134 | " result = client.submit(dec, result)\n", 135 | " results.append(result)\n", 136 | "\n", 137 | "results = client.gather(results)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "f7b680d4-1fe5-4f0e-9844-e9f3345dab7e", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "66c87e7c-630f-4e8c-99d3-802483bb921f", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "id": "6a7504e7-0c8c-4b00-91e9-daa5bd94c2d3", 159 | "metadata": { 160 | "tags": [] 161 | }, 162 | "source": [ 163 | "## More complex code with tree reduction\n", 164 | "\n", 165 | "The code above is very simple. Let's show off that Dask can do more complex things. \n", 166 | "Here we add all of our elements pair-wise until there is only one left. This looks especially fun if you bring up the \"Graph\" dashboard plot." 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "8cbcc7ef-7739-4b3f-a17d-f663270cbcd7", 173 | "metadata": { 174 | "tags": [] 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "%%time\n", 179 | "\n", 180 | "results = []\n", 181 | "for x in range(128):\n", 182 | " result = client.submit(inc, x)\n", 183 | " result = client.submit(dec, result)\n", 184 | " results.append(result)\n", 185 | "\n", 186 | "# Add up all of the results, pairwise\n", 187 | "while len(results) > 1:\n", 188 | " results = [\n", 189 | " client.submit(add, results[i], results[i + 1]) \n", 190 | " for i in range(0, len(results), 2)\n", 191 | " ]\n", 192 | " \n", 193 | "results = client.gather(results)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "id": "664ec9c2-edfd-4120-98a7-a77937ead759", 199 | "metadata": {}, 200 | "source": [ 201 | "## Scale Out" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "e5a67007-7af7-490c-8f4d-0c1a5f30d5c5", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "import coiled\n", 212 | "\n", 213 | "cluster = coiled.Cluster(\n", 214 | " n_workers=20,\n", 215 | ")\n", 216 | "client = cluster.get_client()\n", 217 | "\n", 218 | "# Then rerun the cells above" 219 | ] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python [conda env:coiled]", 225 | "language": "python", 226 | "name": "conda-env-coiled-py" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.10.0" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 5 243 | } 244 | -------------------------------------------------------------------------------- /xgboost-optuna.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7653595e-196c-4d7f-b2c3-5e6442599404", 6 | "metadata": {}, 7 | "source": [ 8 | "\n", 11 | "\n", 12 | "# Hyper-Parameter Optimization with Optuna\n", 13 | "\n", 14 | "This trains an XGBoost model and does hyperparameter optimization using Optuna to search and scikit-learn for cross validation." 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "d8e7cefa-5e11-4cd5-908f-5b3b054ed844", 20 | "metadata": {}, 21 | "source": [ 22 | "## Launch Cluster\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "2386ff41-7fc3-435a-bc8c-a38771073174", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "%%time\n", 33 | "\n", 34 | "import coiled\n", 35 | "\n", 36 | "cluster = coiled.Cluster(\n", 37 | " n_workers=20,\n", 38 | " name=\"hpo\",\n", 39 | ")\n", 40 | "\n", 41 | "client = cluster.get_client()" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "59064aae-8d76-4296-9404-b00619c5036c", 47 | "metadata": {}, 48 | "source": [ 49 | "## Optuna Study\n", 50 | "\n", 51 | "We use the Dask scheduler to track work between the different experiments." 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "9ae2c1a8-14aa-4ebd-86a7-4cf58373d83c", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import optuna\n", 62 | "from optuna.integration.dask import DaskStorage\n", 63 | "\n", 64 | "study = optuna.create_study(\n", 65 | " direction=\"maximize\",\n", 66 | " storage=DaskStorage(),\n", 67 | ")" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "2c72a7cb-dbbe-4630-8a4c-a59ddb4bfd30", 73 | "metadata": { 74 | "tags": [] 75 | }, 76 | "source": [ 77 | "## Objective function\n", 78 | "\n", 79 | "We ...\n", 80 | "\n", 81 | "- Load data\n", 82 | "- Get recommended hyper-parameters from Optuna\n", 83 | "- Train\n", 84 | "- Report Score" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "41eaa3e7-9b99-442d-8672-32b085909625", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "from sklearn.datasets import load_breast_cancer\n", 95 | "from sklearn.model_selection import cross_val_score, KFold\n", 96 | "import xgboost as xgb\n", 97 | "from optuna.samplers import RandomSampler\n", 98 | "\n", 99 | "def objective(trial):\n", 100 | " X, y = load_breast_cancer(return_X_y=True)\n", 101 | " params = {\n", 102 | " \"n_estimators\": 10,\n", 103 | " \"verbosity\": 0,\n", 104 | " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", 105 | " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", 106 | " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", 107 | " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 10, step=1),\n", 108 | " \"min_child_weight\": trial.suggest_float(\"min_child_weight\", 1e-8, 100, log=True),\n", 109 | " \"learning_rate\": trial.suggest_float(\"learning_rate\", 1e-8, 1.0, log=True),\n", 110 | " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", 111 | " \"grow_policy\": \"depthwise\",\n", 112 | " \"eval_metric\": \"logloss\"\n", 113 | " }\n", 114 | " clf = xgb.XGBClassifier(**params)\n", 115 | " fold = KFold(n_splits=5, shuffle=True, random_state=0)\n", 116 | " score = cross_val_score(clf, X, y, cv=fold, scoring='neg_log_loss')\n", 117 | " return score.mean()\n" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "0519f886-c142-4c2c-954c-9eb2ba87bfcb", 123 | "metadata": {}, 124 | "source": [ 125 | "## Execute at Scale\n", 126 | "\n", 127 | "All of the actual coordination happens within Optuna. It's Dask's job just to provide a lot of firepower, which we do by submitting the optimize method many times." 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "9a977efd-6301-49b8-8abc-2f5177c03a8e", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "from dask.distributed import wait\n", 138 | "\n", 139 | "futures = [\n", 140 | " client.submit(study.optimize, objective, n_trials=1, pure=False)\n", 141 | " for _ in range(500)\n", 142 | "]\n", 143 | "\n", 144 | "_ = wait(futures)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "ba99faa0-6cf5-44ec-b359-0d9007d622cd", 150 | "metadata": {}, 151 | "source": [ 152 | "## Results" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "396e798c-8c4c-47ef-b5bf-73af7767589b", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "study.best_params" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "1b4098f5-ea7c-47a6-81ba-3b27e1078b20", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "study.best_value" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "f7ebfcd6-d9e9-48f0-b256-25072f07ab48", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "from optuna.visualization.matplotlib import plot_optimization_history, plot_param_importances\n", 183 | "\n", 184 | "plot_optimization_history(study)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "9940ad3a-e757-4703-b641-b5ebefdc8304", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "plot_param_importances(study)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "id": "a41ad2a6-cdfe-48fc-bca8-b0b801977380", 200 | "metadata": {}, 201 | "source": [ 202 | "## Clean up\n", 203 | "\n", 204 | "This cost us about $0.08" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "3610787e-0710-406e-a3dd-588b42b42ee9", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "cluster.shutdown()" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": "Python [conda env:coiled]", 221 | "language": "python", 222 | "name": "conda-env-coiled-py" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.10.0" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 5 239 | } 240 | -------------------------------------------------------------------------------- /xgboost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f6031837-1dc6-4773-9fa4-b892b6f0e968", 6 | "metadata": {}, 7 | "source": [ 8 | "\n", 11 | "\n", 12 | "# XGBoost for Gradient Boosted Trees\n", 13 | "\n", 14 | "[XGBoost](https://xgboost.readthedocs.io/en/latest/) is a library used for training gradient boosted supervised machine learning models, and it has a [Dask integration](https://xgboost.readthedocs.io/en/latest/tutorials/dask.html) for distributed training. In this guide, you'll learn how to train an XGBoost model in parallel using Dask and Coiled. Download {download}`this jupyter notebook ` to follow along." 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "9fe7754e-88e1-4a6a-bb16-527f64de3d2d", 20 | "metadata": {}, 21 | "source": [ 22 | "## About the Data\n", 23 | "\n", 24 | "In this example we will use a dataset that joins the\n", 25 | "Uber/Lyft dataset from the [High-Volume For-Hire Services](https://www.nyc.gov/site/tlc/businesses/high-volume-for-hire-services.page), with the [NYC Taxi Zone Lookup Table](https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page). \n", 26 | "\n", 27 | "This results in a dataset with ~1.4 billion rows. " 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "1357a947-9ccf-43cf-928c-4e072ea88839", 33 | "metadata": {}, 34 | "source": [ 35 | "## Get a Coiled Cluster\n", 36 | "\n", 37 | "To start we need to spin up a Dask cluster." 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "a588327d-b81a-4397-8ba4-55fce39126f7", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "%%time\n", 48 | "\n", 49 | "import coiled\n", 50 | "\n", 51 | "cluster = coiled.Cluster(\n", 52 | " n_workers=50,\n", 53 | " name=\"xgboost\",\n", 54 | " worker_vm_types=[\"r6i.large\"],\n", 55 | " scheduler_vm_types=[\"m6i.large\"],\n", 56 | " region=\"us-east-2\",\n", 57 | ")\n", 58 | "\n", 59 | "client = cluster.get_client()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "c2137fe6-a09b-46a2-bd79-f2765a5a58b1", 65 | "metadata": {}, 66 | "source": [ 67 | "## Load and Engineer Data" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "b3a2dab5-6f93-46e1-b52b-fa97c2c5ce92", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "import dask.dataframe as dd\n", 78 | "\n", 79 | "df = dd.read_parquet(\n", 80 | " \"s3://coiled-datasets/dask-xgboost-example/feature_table.parquet/\"\n", 81 | ")\n", 82 | "\n", 83 | "# Convert dtypes\n", 84 | "df = df.astype({\n", 85 | " c: \"float32\" \n", 86 | " for c in df.select_dtypes(include=\"float\").columns.tolist()\n", 87 | "}).persist()\n", 88 | "\n", 89 | "# Categorize\n", 90 | "df = df.categorize(columns=df.select_dtypes(include=\"category\").columns.tolist())\n", 91 | "\n", 92 | "df = df.persist()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "id": "fcaf580c-4f7c-4564-a35b-9f1ce94cc3bd", 98 | "metadata": {}, 99 | "source": [ 100 | "## Custom Cross-validation" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "2cf4eff6-0174-45b2-9059-3e061e88c2e8", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "def make_cv_splits(n_folds = 5):\n", 111 | " frac = [1 / n_folds] * n_folds\n", 112 | " splits = df.random_split(frac, shuffle=True)\n", 113 | " for i in range(n_folds):\n", 114 | " train = [splits[j] for j in range(n_folds) if j != i]\n", 115 | " test = splits[i]\n", 116 | " yield dd.concat(train), test" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "id": "687b2949-2252-4ad4-ad34-a2e80c255dc3", 122 | "metadata": {}, 123 | "source": [ 124 | "## Train Model\n", 125 | "\n", 126 | "When using XGBoost with Dask, we need to call the XGBoost Dask interface from the client side. The main difference with XGBoost’s Dask interface is that we pass our Dask client as an additional argument for carrying out the computation. Note that if the `client` parameter below is set to `None`, XGBoost will use the default client returned by Dask.\n" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "55a95362-92ab-4296-9bb1-9874a9eec4ab", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "from datetime import datetime\n", 137 | "\n", 138 | "import dask.array as da\n", 139 | "import xgboost.dask\n", 140 | "from dask_ml.metrics import mean_squared_error\n", 141 | "\n", 142 | "start = datetime.now()\n", 143 | "scores = []\n", 144 | "\n", 145 | "for i, (train, test) in enumerate(make_cv_splits(5)):\n", 146 | " print(f\"Train/Test split #{i + 1} / 5\")\n", 147 | " y_train = train[\"trip_time\"]\n", 148 | " X_train = train.drop(columns=[\"trip_time\"])\n", 149 | " y_test = test[\"trip_time\"]\n", 150 | " X_test = test.drop(columns=[\"trip_time\"])\n", 151 | "\n", 152 | " d_train = xgboost.dask.DaskDMatrix(None, X_train, y_train, enable_categorical=True)\n", 153 | "\n", 154 | " print(\"Training ...\")\n", 155 | " model = xgboost.dask.train(\n", 156 | " None,\n", 157 | " {\"tree_method\": \"hist\"},\n", 158 | " d_train,\n", 159 | " num_boost_round=4,\n", 160 | " evals=[(d_train, \"train\")],\n", 161 | " )\n", 162 | "\n", 163 | " print(\"Scoring ...\")\n", 164 | " predictions = xgboost.dask.predict(None, model, X_test)\n", 165 | "\n", 166 | " score = mean_squared_error(\n", 167 | " y_test.to_dask_array(),\n", 168 | " predictions.to_dask_array(),\n", 169 | " squared=False,\n", 170 | " compute=False,\n", 171 | " )\n", 172 | " scores.append(score.reshape(1).persist())\n", 173 | " print()\n", 174 | " print(\"-\" * 80)\n", 175 | " print()\n", 176 | "\n", 177 | "scores = da.concatenate(scores).compute()\n", 178 | "print(f\"RSME={scores.mean()} +/- {scores.std()}\")\n", 179 | "print(f\"Total time: {datetime.now() - start}\")" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "id": "64ff006f-1c8b-479f-9cc0-a9d128a1ee46", 185 | "metadata": {}, 186 | "source": [ 187 | "## Inspect Model" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "id": "4906d315-2e93-4c24-8b23-98fcebd3659a", 194 | "metadata": { 195 | "tags": [] 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "model" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "id": "48292f1a-09dd-443c-ad9b-1393aee0a510", 205 | "metadata": {}, 206 | "source": [ 207 | "## Clean up" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "eb9ee713-fab6-4660-b034-34023fb4db9c", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "cluster.close()" 218 | ] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python 3 (ipykernel)", 224 | "language": "python", 225 | "name": "python3" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.10.17" 238 | } 239 | }, 240 | "nbformat": 4, 241 | "nbformat_minor": 5 242 | } 243 | -------------------------------------------------------------------------------- /uber-lyft.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f8eafed2-77a1-4691-8e8b-aeb1187ce8f5", 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "source": [ 10 | "\n", 11 | "NYC Uber/Lyft Rides\n", 12 | "===================\n", 13 | "\n", 14 | "\n", 17 | "\n", 18 | "The NYC Taxi dataset is a timeless classic. \n", 19 | "\n", 20 | "Interestingly there is a new variant. The NYC Taxi and Livery Commission requires data from all ride-share services in the city of New York. This includes private limosine services, van services, and a new category \"High Volume For Hire Vehicle\" services, those that dispatch 10,000 rides per day or more. This is a special category defined for Uber and Lyft. \n", 21 | "\n", 22 | "This data is available here:" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "09728a96-0c84-4198-ab52-4dcdfd704606", 29 | "metadata": { 30 | "tags": [] 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "# We can read a small piece of data with pandas\n", 35 | "# but this is slow and not scalable\n", 36 | "\n", 37 | "import pandas as pd\n", 38 | "\n", 39 | "df = pd.read_parquet(\n", 40 | " \"s3://coiled-data/uber/part.0.parquet\",\n", 41 | ")\n", 42 | "df" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "id": "9a9e6076-c8b3-4282-90a4-0fe3ab49440d", 49 | "metadata": { 50 | "tags": [] 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "import coiled\n", 55 | "\n", 56 | "cluster = coiled.Cluster(\n", 57 | " n_workers=30,\n", 58 | " worker_memory=\"16 GiB\",\n", 59 | " region=\"us-east-2\",\n", 60 | ")\n", 61 | "\n", 62 | "client = cluster.get_client()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "33b598a4-fe0a-43c5-8007-0e955ac193f9", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "import dask\n", 73 | "import dask.dataframe as dd\n", 74 | "\n", 75 | "df = dd.read_parquet(\n", 76 | " \"s3://coiled-data/uber/\",\n", 77 | ")\n", 78 | "df" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "f9d012d8-6055-4cdc-a37e-d6a01b36c7db", 84 | "metadata": {}, 85 | "source": [ 86 | "Play time\n", 87 | "---------\n", 88 | "\n", 89 | "We actually don't know what to expect from this dataset. No one in our team has spent much time inspecting it. We'd like to solicit help from you, new Dask user, to uncover some interesting insights. \n", 90 | "\n", 91 | "Care to explore and report your findings?" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "8c589e9f-f3e9-41d3-b34c-ca42fee44729", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "df = df.persist()\n", 102 | "\n", 103 | "df.columns" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "062fe739-a3f1-4a9a-8e62-34da95b63982", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "d4d906ce-605b-4128-890d-830334986974", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "e622064f-0e1c-48bc-a639-8d181fc92dea", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "d264b39e-2dee-49a1-bbbf-2537265a3630", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "id": "98d43c82-9a03-4311-8c17-2ef3b3519aa7", 141 | "metadata": {}, 142 | "source": [ 143 | "## Tipping Practices" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "23af7df8-cf98-4e04-8cd0-829a33e65840", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "# How often do New Yorkers tip?\n", 154 | "\n", 155 | "(df.tips != 0).mean().compute()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "6ac9e09e-1b8f-4e79-9929-fc26a753e23e", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "id": "18236b33-c27d-4d96-8a61-98d73c224744", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "id": "10f285be-6839-4866-9535-43adbcc965d0", 177 | "metadata": {}, 178 | "source": [ 179 | "## Broken down by carrier" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "b71729eb-f841-433a-b020-0f2b1c425355", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# Uber / Lyft / Via / ... different carriers\n", 190 | "df.hvfhs_license_num.value_counts().compute()" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "da79527d-8db9-4f7b-96fe-c87159241103", 197 | "metadata": { 198 | "tags": [] 199 | }, 200 | "outputs": [], 201 | "source": [ 202 | "df[\"tipped\"] = df.tips != 0\n", 203 | "\n", 204 | "df.groupby(\"hvfhs_license_num\").tipped.mean().compute()" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "5a6e07be-4694-496c-a323-161518ebed74", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "74cb092c-11ea-40ef-8bcb-127b540681f3", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "0cac5bd4-8f32-43b4-bfe0-e20767cf1db2", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "id": "6ae5e64d-4bf3-4f99-b54f-aebedfdfd98a", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "7c962f9f-7308-465a-98fb-ac622a967a34", 242 | "metadata": { 243 | "jp-MarkdownHeadingCollapsed": true, 244 | "tags": [] 245 | }, 246 | "source": [ 247 | "## Dask TV\n", 248 | "\n", 249 | "We use this in conference events just to make the dashboard go and bring in a crowd. Colloquially we call this \"Dask TV\". Enjoy!" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "829de2bc-ed09-4e10-b06f-268aa79ead59", 256 | "metadata": { 257 | "tags": [] 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "import dask\n", 262 | "import dask.dataframe as dd\n", 263 | "dask.config.set({\"dataframe.convert-string\": True}) # use PyArrow strings by default\n", 264 | "\n", 265 | "while True:\n", 266 | " client.restart()\n", 267 | "\n", 268 | " df = dd.read_parquet(\n", 269 | " \"s3://coiled-datasets/uber-lyft-tlc/\",\n", 270 | " storage_options={\"anon\": True},\n", 271 | " ).persist()\n", 272 | "\n", 273 | " for _ in range(10):\n", 274 | " df[\"tipped\"] = df.tips != 0\n", 275 | "\n", 276 | " df.groupby(\"hvfhs_license_num\").tipped.mean().compute()" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "id": "f4b6b9f9-7ef3-4ca0-b769-0dd7e4ce6b0b", 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [] 286 | } 287 | ], 288 | "metadata": { 289 | "kernelspec": { 290 | "display_name": "Python [conda env:coiled]", 291 | "language": "python", 292 | "name": "conda-env-coiled-py" 293 | }, 294 | "language_info": { 295 | "codemirror_mode": { 296 | "name": "ipython", 297 | "version": 3 298 | }, 299 | "file_extension": ".py", 300 | "mimetype": "text/x-python", 301 | "name": "python", 302 | "nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.10.14" 305 | } 306 | }, 307 | "nbformat": 4, 308 | "nbformat_minor": 5 309 | } 310 | -------------------------------------------------------------------------------- /batch/nlp_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | 16 | import evaluate 17 | import torch 18 | from datasets import load_dataset 19 | from torch.optim import AdamW 20 | from torch.utils.data import DataLoader 21 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed 22 | 23 | from accelerate import Accelerator, DistributedType 24 | 25 | 26 | ######################################################################## 27 | # This is a fully working simple example to use Accelerate 28 | # 29 | # This example trains a Bert base model on GLUE MRPC 30 | # in any of the following settings (with the same script): 31 | # - single CPU or single GPU 32 | # - multi GPUS (using PyTorch distributed mode) 33 | # - (multi) TPUs 34 | # - fp16 (mixed-precision) or fp32 (normal precision) 35 | # 36 | # To run it in each of these various modes, follow the instructions 37 | # in the readme for examples: 38 | # https://github.com/huggingface/accelerate/tree/main/examples 39 | # 40 | ######################################################################## 41 | 42 | 43 | MAX_GPU_BATCH_SIZE = 16 44 | EVAL_BATCH_SIZE = 32 45 | 46 | 47 | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): 48 | """ 49 | Creates a set of `DataLoader`s for the `glue` dataset, 50 | using "bert-base-cased" as the tokenizer. 51 | 52 | Args: 53 | accelerator (`Accelerator`): 54 | An `Accelerator` object 55 | batch_size (`int`, *optional*): 56 | The batch size for the train and validation DataLoaders. 57 | """ 58 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 59 | datasets = load_dataset("glue", "mrpc") 60 | 61 | def tokenize_function(examples): 62 | # max_length=None => use the model max length (it's actually the default) 63 | outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) 64 | return outputs 65 | 66 | # Apply the method we just defined to all the examples in all the splits of the dataset 67 | # starting with the main process first: 68 | with accelerator.main_process_first(): 69 | tokenized_datasets = datasets.map( 70 | tokenize_function, 71 | batched=True, 72 | remove_columns=["idx", "sentence1", "sentence2"], 73 | ) 74 | 75 | # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the 76 | # transformers library 77 | tokenized_datasets = tokenized_datasets.rename_column("label", "labels") 78 | 79 | def collate_fn(examples): 80 | # For Torchxla, it's best to pad everything to the same length or training will be very slow. 81 | max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None 82 | # When using mixed precision we want round multiples of 8/16 83 | if accelerator.mixed_precision == "fp8": 84 | pad_to_multiple_of = 16 85 | elif accelerator.mixed_precision != "no": 86 | pad_to_multiple_of = 8 87 | else: 88 | pad_to_multiple_of = None 89 | 90 | return tokenizer.pad( 91 | examples, 92 | padding="longest", 93 | max_length=max_length, 94 | pad_to_multiple_of=pad_to_multiple_of, 95 | return_tensors="pt", 96 | ) 97 | 98 | # Instantiate dataloaders. 99 | train_dataloader = DataLoader( 100 | tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True 101 | ) 102 | eval_dataloader = DataLoader( 103 | tokenized_datasets["validation"], 104 | shuffle=False, 105 | collate_fn=collate_fn, 106 | batch_size=EVAL_BATCH_SIZE, 107 | drop_last=(accelerator.mixed_precision == "fp8"), 108 | ) 109 | 110 | return train_dataloader, eval_dataloader 111 | 112 | 113 | def training_function(config, args): 114 | # Initialize accelerator 115 | accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision) 116 | # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs 117 | lr = config["lr"] 118 | num_epochs = int(config["num_epochs"]) 119 | seed = int(config["seed"]) 120 | batch_size = int(config["batch_size"]) 121 | 122 | metric = evaluate.load("glue", "mrpc") 123 | 124 | # If the batch size is too big we use gradient accumulation 125 | gradient_accumulation_steps = 1 126 | if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA: 127 | gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE 128 | batch_size = MAX_GPU_BATCH_SIZE 129 | 130 | set_seed(seed) 131 | train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size) 132 | # Instantiate the model (we build the model here so that the seed also control new weights initialization) 133 | model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True) 134 | 135 | # We could avoid this line since the accelerator is set with `device_placement=True` (default value). 136 | # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer 137 | # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that). 138 | model = model.to(accelerator.device) 139 | # Instantiate optimizer 140 | optimizer = AdamW(params=model.parameters(), lr=lr) 141 | 142 | # Instantiate scheduler 143 | lr_scheduler = get_linear_schedule_with_warmup( 144 | optimizer=optimizer, 145 | num_warmup_steps=100, 146 | num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, 147 | ) 148 | 149 | # Prepare everything 150 | # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the 151 | # prepare method. 152 | 153 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 154 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 155 | ) 156 | 157 | # Now we train the model 158 | for epoch in range(num_epochs): 159 | model.train() 160 | for step, batch in enumerate(train_dataloader): 161 | # We could avoid this line since we set the accelerator with `device_placement=True`. 162 | batch.to(accelerator.device) 163 | outputs = model(**batch) 164 | loss = outputs.loss 165 | loss = loss / gradient_accumulation_steps 166 | accelerator.backward(loss) 167 | if step % gradient_accumulation_steps == 0: 168 | optimizer.step() 169 | lr_scheduler.step() 170 | optimizer.zero_grad() 171 | 172 | model.eval() 173 | for step, batch in enumerate(eval_dataloader): 174 | # We could avoid this line since we set the accelerator with `device_placement=True`. 175 | batch.to(accelerator.device) 176 | with torch.no_grad(): 177 | outputs = model(**batch) 178 | predictions = outputs.logits.argmax(dim=-1) 179 | predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"])) 180 | metric.add_batch( 181 | predictions=predictions, 182 | references=references, 183 | ) 184 | 185 | eval_metric = metric.compute() 186 | # Use accelerator.print to print only on the main process. 187 | accelerator.print(f"epoch {epoch}:", eval_metric) 188 | accelerator.end_training() 189 | 190 | 191 | def main(): 192 | parser = argparse.ArgumentParser(description="Simple example of training script.") 193 | parser.add_argument( 194 | "--mixed_precision", 195 | type=str, 196 | default=None, 197 | choices=["no", "fp16", "bf16", "fp8"], 198 | help="Whether to use mixed precision. Choose" 199 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 200 | "and an Nvidia Ampere GPU.", 201 | ) 202 | parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.") 203 | args = parser.parse_args() 204 | config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16} 205 | training_function(config, args) 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "07e34e5f-9aa3-4d9c-8dab-cc04b3b5022e", 6 | "metadata": {}, 7 | "source": [ 8 | "# PyTorch GPUs\n", 9 | "\n", 10 | "Optuna example that optimizes multi-layer perceptrons using PyTorch. \n", 11 | "\n", 12 | "Modified from https://github.com/optuna/optuna-examples/blob/main/pytorch/pytorch_simple.py\n", 13 | "\n", 14 | "In this example, we optimize the validation accuracy of fashion product recognition using\n", 15 | "PyTorch and FashionMNIST. We optimize the neural network architecture as well as the optimizer\n", 16 | "configuration. As it is too time consuming to use the whole FashionMNIST dataset,\n", 17 | "we here use a small subset of it.\n" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "4a8811e7-af13-436e-85ec-1a8ef651f6bf", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "\"\"\"\n", 28 | "This example was adapted from the following PyTorch tutorial\n", 29 | "https://pytorch.org/tutorials/beginner/introyt/trainingyt.html\n", 30 | "\"\"\"\n", 31 | "\n", 32 | "import os\n", 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "import torch.nn.functional as F\n", 36 | "from torch.optim import SGD\n", 37 | "from torchvision import datasets, transforms\n", 38 | "from dask.distributed import print\n", 39 | "\n", 40 | "def load_data():\n", 41 | " transform = transforms.Compose(\n", 42 | " [transforms.ToTensor(),\n", 43 | " transforms.Normalize((0.5,), (0.5,))])\n", 44 | "\n", 45 | " # Create datasets for training & validation, download if necessary\n", 46 | " training_set = datasets.FashionMNIST(os.getcwd(), train=True, transform=transform, download=True)\n", 47 | " validation_set = datasets.FashionMNIST(os.getcwd(), train=False, transform=transform, download=True)\n", 48 | "\n", 49 | " # Create data loaders for our datasets; shuffle for training, not for validation\n", 50 | " training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)\n", 51 | " validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)\n", 52 | "\n", 53 | " # Report split sizes\n", 54 | " print('Training set has {} instances'.format(len(training_set)))\n", 55 | " print('Validation set has {} instances'.format(len(validation_set)))\n", 56 | "\n", 57 | " return training_loader, validation_loader\n", 58 | "\n", 59 | "\n", 60 | "class GarmentClassifier(nn.Module):\n", 61 | " def __init__(self):\n", 62 | " super(GarmentClassifier, self).__init__()\n", 63 | " self.conv1 = nn.Conv2d(1, 6, 5)\n", 64 | " self.pool = nn.MaxPool2d(2, 2)\n", 65 | " self.conv2 = nn.Conv2d(6, 16, 5)\n", 66 | " self.fc1 = nn.Linear(16 * 4 * 4, 120)\n", 67 | " self.fc2 = nn.Linear(120, 84)\n", 68 | " self.fc3 = nn.Linear(84, 10)\n", 69 | "\n", 70 | " def forward(self, x):\n", 71 | " x = self.pool(F.relu(self.conv1(x)))\n", 72 | " x = self.pool(F.relu(self.conv2(x)))\n", 73 | " x = x.view(-1, 16 * 4 * 4)\n", 74 | " x = F.relu(self.fc1(x))\n", 75 | " x = F.relu(self.fc2(x))\n", 76 | " x = self.fc3(x)\n", 77 | " return x\n", 78 | "\n", 79 | "\n", 80 | "def train_one_epoch(model, loss_fn, optimizer, training_loader, device):\n", 81 | " running_loss = 0.\n", 82 | " last_loss = 0.\n", 83 | "\n", 84 | " # Here, we use enumerate(training_loader) instead of\n", 85 | " # iter(training_loader) so that we can track the batch\n", 86 | " # index and do some intra-epoch reporting\n", 87 | " for i, data in enumerate(training_loader):\n", 88 | " # Every data instance is an input + label pair\n", 89 | " inputs, labels = data\n", 90 | "\n", 91 | " # Move to GPU\n", 92 | " inputs, labels = inputs.to(device), labels.to(device)\n", 93 | "\n", 94 | " # Zero your gradients for every batch!\n", 95 | " optimizer.zero_grad()\n", 96 | "\n", 97 | " # Make predictions for this batch\n", 98 | " outputs = model(inputs)\n", 99 | "\n", 100 | " # Compute the loss and its gradients\n", 101 | " loss = loss_fn(outputs, labels)\n", 102 | " loss.backward()\n", 103 | "\n", 104 | " # Adjust learning weights\n", 105 | " optimizer.step()\n", 106 | "\n", 107 | " # Gather data\n", 108 | " running_loss += loss.item()\n", 109 | " if i % 1000 == 999:\n", 110 | " last_loss = running_loss / 1000 # loss per batch\n", 111 | " print(' batch {} loss: {}'.format(i + 1, last_loss))\n", 112 | " running_loss = 0.\n", 113 | "\n", 114 | " return last_loss\n", 115 | "\n", 116 | "\n", 117 | "def train_all_epochs():\n", 118 | " # Confirm that GPU shows up\n", 119 | " if torch.cuda.is_available():\n", 120 | " device = \"cuda\"\n", 121 | " print(f\"Using GPU {torch.cuda.get_device_name(torch.cuda.current_device())} 😎\\n\")\n", 122 | " else:\n", 123 | " device = \"cpu\"\n", 124 | " print(\"Using CPU 😔\\n\")\n", 125 | "\n", 126 | " training_loader, validation_loader = load_data()\n", 127 | " model = GarmentClassifier().to(device)\n", 128 | " loss_fn = nn.CrossEntropyLoss()\n", 129 | " optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)\n", 130 | "\n", 131 | " epochs = 5\n", 132 | " best_vloss = 1_000_000.\n", 133 | "\n", 134 | " for epoch in range(epochs):\n", 135 | " print(f'EPOCH {epoch + 1}:')\n", 136 | "\n", 137 | " # Make sure gradient tracking is on, and do a pass over the data\n", 138 | " model.train(True)\n", 139 | " avg_loss = train_one_epoch(model, loss_fn, optimizer, training_loader, device)\n", 140 | "\n", 141 | " running_vloss = 0.0\n", 142 | " # Set the model to evaluation mode, disabling dropout and using population\n", 143 | " # statistics for batch normalization.\n", 144 | " model.eval()\n", 145 | "\n", 146 | " # Disable gradient computation and reduce memory consumption.\n", 147 | " with torch.no_grad():\n", 148 | " for i, vdata in enumerate(validation_loader):\n", 149 | " vinputs, vlabels = vdata\n", 150 | "\n", 151 | " # Move to GPU\n", 152 | " vinputs, vlabels = vinputs.to(device), vlabels.to(device)\n", 153 | "\n", 154 | " voutputs = model(vinputs)\n", 155 | " vloss = loss_fn(voutputs, vlabels)\n", 156 | " running_vloss += vloss\n", 157 | "\n", 158 | " avg_vloss = running_vloss / (i + 1)\n", 159 | " print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))\n", 160 | "\n", 161 | " # Return the best model\n", 162 | " if avg_vloss < best_vloss:\n", 163 | " best_vloss = avg_vloss\n", 164 | " best_model = model\n", 165 | "\n", 166 | " print(f\"Model on CUDA device: {next(best_model.parameters()).is_cuda}\")\n", 167 | "\n", 168 | " # Move model to CPU so it can be serialized and returned to local machine\n", 169 | "\n", 170 | " return best_model\n", 171 | "\n" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "id": "95510b24-123d-4bb4-8249-a0edb003582b", 177 | "metadata": {}, 178 | "source": [ 179 | "## Run on CPU" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "9ade8f94-9008-443c-93e1-6da90d80efbc", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "model = train_all_epochs()" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "id": "6ebb91d8-9c2d-408d-9c3c-a0e9535106f1", 195 | "metadata": {}, 196 | "source": [ 197 | "## Run on GPU\n" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "2b59f138-f84f-4319-9a54-821dd0525d65", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "import coiled\n", 208 | "\n", 209 | "@coiled.function(\n", 210 | " vm_type=\"g5.xlarge\",\n", 211 | " region=\"us-east-2\",\n", 212 | " keepalive=\"1 hour\",\n", 213 | ")\n", 214 | "def train_on_gpu():\n", 215 | " model = train_all_epochs()\n", 216 | " return model.to(\"cpu\")\n", 217 | "\n", 218 | "model = train_on_gpu()" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "id": "b6e7ae53-ce21-49f3-856a-87774a19bf37", 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3 (ipykernel)", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.11.9" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 5 251 | } 252 | -------------------------------------------------------------------------------- /national-water-model/xarray-water-model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "75688ac6-879d-4449-b73e-74f03a5f991f", 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "source": [ 10 | "# Analyzing the National Water Model with Xarray, Dask, and Coiled\n", 11 | "\n", 12 | "_This example was adapted from [this notebook](https://github.com/dcherian/dask-demo/blob/main/nwm-aws.ipynb) by Deepak Cherian, Kevin Sampson, and Matthew Rocklin._\n", 13 | "\n", 14 | "\n", 15 | "\n", 16 | "## The National Water Model Dataset\n", 17 | "\n", 18 | "In this example, we'll perform a county-wise aggregation of output from the National Water Model (NWM) available on the [AWS Open Data Registry](https://registry.opendata.aws/nwm-archive/). You can read more on the NWM from the [Office of Water Prediction](https://water.noaa.gov/about/nwm).\n", 19 | "\n", 20 | "## Problem description\n", 21 | "\n", 22 | "Datasets with high spatio-temporal resolution can get large quickly, vastly exceeding the resources you may have on your laptop. Dask integrates with Xarray to support parallel computing and you can use Coiled to scale to the cloud.\n", 23 | "\n", 24 | "We'll calculate the mean depth to soil saturation for each US county:\n", 25 | "\n", 26 | "- Years: 2020\n", 27 | "- Temporal resolution: 3-hourly land surface output\n", 28 | "- Spatial resolution: 250 m grid\n", 29 | "- 6 TB\n", 30 | "\n", 31 | "This example relies on a few tools:\n", 32 | "- `dask` + `coiled` process the dataset in parallel in the cloud\n", 33 | "- `xarray` + `flox` to work with the multi-dimensional Zarr datset and aggregate to county-level means from the 250m grid." 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "5dd71599-465f-4c97-baaa-19d900d2a070", 39 | "metadata": { 40 | "tags": [] 41 | }, 42 | "source": [ 43 | "## Start a Coiled cluster\n", 44 | "\n", 45 | "To demonstrate calculation on a cloud-available dataset, we will use Coiled to set up a dask cluster in AWS `us-east-1`." 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "60b08a1c-d042-40f2-aaaa-e7665ca85d64", 52 | "metadata": { 53 | "tags": [] 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "import coiled\n", 58 | "\n", 59 | "cluster = coiled.Cluster(\n", 60 | " region=\"us-east-1\", # close to dataset, avoid egress charges\n", 61 | " n_workers=10,\n", 62 | " scheduler_vm_types=\"r7g.xlarge\", # memory optimized AWS EC2 instances\n", 63 | " worker_vm_types=\"r7g.2xlarge\"\n", 64 | ")\n", 65 | "\n", 66 | "client = cluster.get_client()\n", 67 | "\n", 68 | "cluster.adapt(minimum=10, maximum=50)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "8185966d-6659-482b-bcbb-826b8f30b1e3", 74 | "metadata": {}, 75 | "source": [ 76 | "### Load NWM data" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "e8b1749a-0d64-4278-823c-892120bf1a5b", 83 | "metadata": { 84 | "tags": [] 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "import fsspec\n", 89 | "import xarray as xr\n", 90 | "\n", 91 | "ds = xr.open_zarr(\n", 92 | " fsspec.get_mapper(\"s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr\", anon=True),\n", 93 | " consolidated=True,\n", 94 | " chunks={\"time\": 896, \"x\": 350, \"y\": 350}\n", 95 | ")\n", 96 | "ds" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "id": "c7d5631a-0974-48fe-a2dc-4cbeb5654838", 102 | "metadata": {}, 103 | "source": [ 104 | "Each field in this dataset is big!" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "41fd51a5-858f-4c43-926f-e212e6d3dd7b", 111 | "metadata": { 112 | "tags": [] 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "ds.zwattablrt" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "id": "9d70429a-ad29-46c1-80e5-5b15b8012b47", 122 | "metadata": { 123 | "tags": [] 124 | }, 125 | "source": [ 126 | "Subset to a single year subset for demo purposes" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "2a6fb91d-6a02-4afc-8d8a-ec3529f805f4", 133 | "metadata": { 134 | "tags": [] 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "subset = ds.zwattablrt.sel(time=slice(\"2020-01-01\", \"2020-12-31\"))\n", 139 | "subset" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "10701b9e-3607-4734-9a14-094cebc3c26e", 145 | "metadata": {}, 146 | "source": [ 147 | "### Load county raster for grouping\n", 148 | "\n", 149 | "Load a raster TIFF file identifying counties by unique integer with [rioxarray](https://corteva.github.io/rioxarray/html/rioxarray.html)." 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "id": "50c13d54-7bad-4864-92da-aa7c5b2b35d6", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "import fsspec\n", 160 | "import rioxarray\n", 161 | "\n", 162 | "fs = fsspec.filesystem(\"s3\", requester_pays=True)\n", 163 | "\n", 164 | "counties = rioxarray.open_rasterio(\n", 165 | " fs.open(\"s3://nwm-250m-us-counties/Counties_on_250m_grid.tif\"), chunks=\"auto\"\n", 166 | ").squeeze()\n", 167 | "\n", 168 | "# remove any small floating point error in coordinate locations\n", 169 | "_, counties_aligned = xr.align(subset, counties, join=\"override\")\n", 170 | "counties_aligned = counties_aligned.persist()\n", 171 | "\n", 172 | "counties_aligned" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "2b3f28f8-8548-4bbc-ab64-e5efaf21bf3c", 178 | "metadata": {}, 179 | "source": [ 180 | "We'll need the unique county IDs later, calculate that now." 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "ae3b2fab-b576-41a7-b0ec-2f37aba924bd", 187 | "metadata": { 188 | "tags": [] 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "import numpy as np\n", 193 | "\n", 194 | "county_id = np.unique(counties_aligned.data).compute()\n", 195 | "county_id = county_id[county_id != 0]\n", 196 | "print(f\"There are {len(county_id)} counties!\")" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "id": "0a2b5ebe-d01a-46b4-b0c5-ce1052af4a4c", 202 | "metadata": {}, 203 | "source": [ 204 | "### GroupBy with flox\n", 205 | "\n", 206 | "We could run the computation as:\n", 207 | "\n", 208 | "```python\n", 209 | "subset.groupby(counties_aligned).mean()\n", 210 | "```\n", 211 | "\n", 212 | "This would use flox in the background, however, it would also load `counties_aligned` into memory. To avoid egress charges, you can use `flox.xarray` which allows you to lazily groupby a Dask array (here `counties_aligned`) as long as you pass in the expected group labels in `expected_groups`. See the [flox documentation](https://flox.readthedocs.io/en/latest/intro.html#with-dask)." 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "a2c98fc8", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "import flox.xarray\n", 223 | "\n", 224 | "county_mean = flox.xarray.xarray_reduce(\n", 225 | " subset,\n", 226 | " counties_aligned.rename(\"county\"),\n", 227 | " func=\"mean\",\n", 228 | " expected_groups=(county_id,),\n", 229 | ")\n", 230 | "\n", 231 | "county_mean" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "e35efab0-9783-43e6-ab70-bb50a27629ec", 238 | "metadata": { 239 | "tags": [] 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "county_mean.load()" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "id": "f0e9079e-5aac-4ac2-9440-0b520f0cac76", 249 | "metadata": {}, 250 | "source": [ 251 | "### Cleanup" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "id": "294cc83e-ad8e-4451-90a9-877f39097c63", 258 | "metadata": { 259 | "tags": [] 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "# since our dataset is much smaller now, we no longer need cloud resources\n", 264 | "cluster.shutdown()\n", 265 | "client.close()" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "id": "a3a84a45-7d68-4be0-8b13-28aca2a0a122", 271 | "metadata": {}, 272 | "source": [ 273 | "## Visualize" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "id": "6f358457", 279 | "metadata": {}, 280 | "source": [ 281 | "Data prep" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "id": "186edb79-06ac-4207-b3d8-251b39254211", 288 | "metadata": { 289 | "tags": [] 290 | }, 291 | "outputs": [], 292 | "source": [ 293 | "# Read county shapefile, combo of state FIPS code and county FIPS code as multi-index\n", 294 | "import geopandas as gpd\n", 295 | "\n", 296 | "counties = gpd.read_file(\n", 297 | " \"https://www2.census.gov/geo/tiger/GENZ2020/shp/cb_2020_us_county_20m.zip\"\n", 298 | ").to_crs(\"EPSG:3395\")\n", 299 | "counties[\"STATEFP\"] = counties.STATEFP.astype(int)\n", 300 | "counties[\"COUNTYFP\"] = counties.COUNTYFP.astype(int)\n", 301 | "continental = counties[~counties[\"STATEFP\"].isin([2, 15, 72])].set_index([\"STATEFP\", \"COUNTYFP\"]) # drop Alaska, Hawaii, Puerto Rico\n", 302 | "\n", 303 | "# Interpret `county` as combo of state FIPS code and county FIPS code. Set multi-index:\n", 304 | "yearly_mean = county_mean.mean(\"time\")\n", 305 | "yearly_mean.coords[\"STATEFP\"] = (yearly_mean.county // 1000).astype(int)\n", 306 | "yearly_mean.coords[\"COUNTYFP\"] = np.mod(yearly_mean.county, 1000).astype(int)\n", 307 | "yearly_mean = yearly_mean.drop_vars(\"county\").set_index(county=[\"STATEFP\", \"COUNTYFP\"])\n", 308 | "\n", 309 | "# join\n", 310 | "continental[\"zwattablrt\"] = yearly_mean.to_dataframe()[\"zwattablrt\"]" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "id": "a07d98d2-a207-4dc4-bf22-da732d41d445", 316 | "metadata": {}, 317 | "source": [ 318 | "Plot" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "id": "c3f52f16-1a11-448b-bc01-e278e093d0d4", 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "import matplotlib.pyplot as plt\n", 329 | "from mpl_toolkits.axes_grid1 import make_axes_locatable\n", 330 | "\n", 331 | "fig, ax = plt.subplots(1, 1, figsize=(7.68, 4.32))\n", 332 | "\n", 333 | "ax.set_axis_off()\n", 334 | "\n", 335 | "divider = make_axes_locatable(ax)\n", 336 | "cax = divider.append_axes(\"bottom\", size='5%', pad=0.1)\n", 337 | "\n", 338 | "cax.tick_params(labelsize=8)\n", 339 | "cax.set_title(\"Average depth (in meters) of the water table in 2020\", fontsize=8)\n", 340 | "\n", 341 | "continental.plot(\n", 342 | " column=\"zwattablrt\",\n", 343 | " cmap=\"BrBG_r\",\n", 344 | " vmin=0,\n", 345 | " vmax=2,\n", 346 | " legend=True,\n", 347 | " ax=ax,\n", 348 | " cax=cax,\n", 349 | " legend_kwds={\n", 350 | " \"orientation\": \"horizontal\",\n", 351 | " \"ticks\": [0,0.5,1,1.5,2],\n", 352 | " }\n", 353 | ")\n", 354 | "\n", 355 | "plt.text(0, 1, \"6 TB processed, ~$1 in cloud costs\", transform=ax.transAxes, size=9)\n", 356 | "plt.show()" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "id": "c2ba2756-10e4-49be-a900-e25fb80b2818", 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [] 366 | } 367 | ], 368 | "metadata": { 369 | "kernelspec": { 370 | "display_name": "Python 3 (ipykernel)", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.10.0" 385 | } 386 | }, 387 | "nbformat": 4, 388 | "nbformat_minor": 5 389 | } 390 | -------------------------------------------------------------------------------- /arxiv-matplotlib.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ae977bc3-d3cf-4492-a359-a95f8156fb52", 6 | "metadata": {}, 7 | "source": [ 8 | "\n", 11 | "\n", 12 | "# How Popular is Matplotlib?\n", 13 | "\n", 14 | "Anecdotally the Matplotlib maintainers were told \n", 15 | "\n", 16 | "*\"About 15% of arXiv papers use Matplotlib\"*\n", 17 | "\n", 18 | "arXiv is the preeminent repository for scholarly preprint articles. It stores millions of journal articles used across science. It's also public access, and so we can just scrape the entire thing given enough compute power.\n", 19 | "\n", 20 | "## Watermark\n", 21 | "\n", 22 | "Starting in the early 2010s, Matplotlib started including the bytes `b\"Matplotlib\"` in every PNG and PDF that they produce. These bytes persist in PDFs that contain Matplotlib plots, including the PDFs stored on arXiv. As a result, it's pretty simple to check if a PDF contains a Matplotlib image. All we have to do is scan through every PDF and look for these bytes; no parsing required.\n", 23 | "\n", 24 | "## Data\n", 25 | "\n", 26 | "The data is stored in a requester pays bucket at s3://arxiv (more information at https://arxiv.org/help/bulk_data_s3 ) and also on GCS hosted by Kaggle (more information at https://www.kaggle.com/datasets/Cornell-University/arxiv). \n", 27 | "\n", 28 | "The data is about 1TB in size. We're going to use Dask for this.\n", 29 | "\n", 30 | "This is a good example of writing plain vanilla Python code to solve a problem, running into issues of scale, and then using Dask to easily jump over those problems." 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "6f0965a0-fa87-470b-bd3d-0b5b7ecaca99", 36 | "metadata": {}, 37 | "source": [ 38 | "### Get all filenames\n", 39 | "\n", 40 | "Our data is stored in a requester pays S3 bucket in the `us-east-1` region. Each file is a tar file which contains a directory of papers." 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "e62539ef-5e91-43c5-afa8-0c3fa51b8f11", 47 | "metadata": { 48 | "tags": [] 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "import s3fs\n", 53 | "s3 = s3fs.S3FileSystem(requester_pays=True)\n", 54 | "\n", 55 | "directories = s3.ls(\"s3://arxiv/pdf\")\n", 56 | "directories[:10]" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "9e5cb2b5-1ad5-4a21-b98d-4f0f615dacd6", 63 | "metadata": { 64 | "tags": [] 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "len(directories)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "92438ddf-7b02-462d-8a5d-2b2e760dd1a4", 74 | "metadata": {}, 75 | "source": [ 76 | "There are lots of these" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "b1646a64-e4b2-4965-98d8-d0d5322e4368", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "s3.du(\"s3://arxiv/pdf\") / 1e12" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "id": "a4d3219f-a6ce-487f-b5ae-522a7415c014", 92 | "metadata": {}, 93 | "source": [ 94 | "## Process one file with plain Python\n", 95 | "\n", 96 | "Mostly we have to muck about with tar files. This wasn't hard. The `tarfile` library is in the stardard library. It's not beautiful, but it's also not hard to use." 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "85146f13-5e5a-40e3-8d56-a79064f35ce4", 103 | "metadata": { 104 | "tags": [] 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "import tarfile\n", 109 | "import io\n", 110 | "\n", 111 | "def extract(filename: str):\n", 112 | " \"\"\" Extract and process one directory of arXiv data\n", 113 | " \n", 114 | " Returns\n", 115 | " -------\n", 116 | " filename: str\n", 117 | " contains_matplotlib: boolean\n", 118 | " \"\"\"\n", 119 | " out = []\n", 120 | " with s3.open(filename) as f:\n", 121 | " bytes = f.read()\n", 122 | " with io.BytesIO() as bio:\n", 123 | " bio.write(bytes)\n", 124 | " bio.seek(0)\n", 125 | " try:\n", 126 | " with tarfile.TarFile(fileobj=bio) as tf:\n", 127 | " for member in tf.getmembers():\n", 128 | " if member.isfile() and member.name.endswith(\".pdf\"):\n", 129 | " data = tf.extractfile(member).read()\n", 130 | " out.append((\n", 131 | " member.name, \n", 132 | " b\"matplotlib\" in data.lower()\n", 133 | " ))\n", 134 | " except tarfile.ReadError:\n", 135 | " pass\n", 136 | " return out" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "de243ca8-bcd2-47b4-8574-bce3f0bda790", 143 | "metadata": { 144 | "tags": [] 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "%%time\n", 149 | "\n", 150 | "# See an example of its use\n", 151 | "extract(directories[20])[:20]" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "id": "9feae86b-4c46-455d-8e5b-09eb80ec3400", 157 | "metadata": {}, 158 | "source": [ 159 | "# Scale processing to full dataset\n", 160 | "\n", 161 | "Great, we can get a record of each file and whether or not it used Matplotlib. Each of these takes about a minute to run on my local machine. Processing all 5000 files would take 5000 minutes, or around 100 hours. \n", 162 | "\n", 163 | "We can accelerate this in two ways:\n", 164 | "\n", 165 | "1. **Process closer to the data** by spinning up resources in the same region on the cloud (this also reduces data transfer costs)\n", 166 | "2. **Use hundreds of workers** in parallel\n", 167 | "\n", 168 | "We can do this easily with [Coiled Functions](https://docs.coiled.io/user_guide/usage/functions/index.html)." 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "id": "34748fc3-7f2f-442e-9829-626d88878234", 174 | "metadata": {}, 175 | "source": [ 176 | "## Run function on the cloud in parallel\n", 177 | "\n", 178 | "We annotate our `extract` function with the `@coiled.function` decorator to have it run on AWS in the same region where the data is stored." 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "id": "245c269a-c432-4352-aa01-0d9110b63304", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "import coiled\n", 189 | "\n", 190 | "@coiled.function(\n", 191 | " region=\"us-east-1\", # Local to data. Faster and cheaper.\n", 192 | " vm_type=\"m6i.xlarge\",\n", 193 | " threads_per_worker=4,\n", 194 | ")\n", 195 | "def extract(filename: str):\n", 196 | " \"\"\" Extract and process one directory of arXiv data\n", 197 | " \n", 198 | " Returns\n", 199 | " -------\n", 200 | " filename: str\n", 201 | " contains_matplotlib: boolean\n", 202 | " \"\"\"\n", 203 | " out = []\n", 204 | " with s3.open(filename) as f:\n", 205 | " bytes = f.read()\n", 206 | " with io.BytesIO() as bio:\n", 207 | " bio.write(bytes)\n", 208 | " bio.seek(0)\n", 209 | " try:\n", 210 | " with tarfile.TarFile(fileobj=bio) as tf:\n", 211 | " for member in tf.getmembers():\n", 212 | " if member.isfile() and member.name.endswith(\".pdf\"):\n", 213 | " data = tf.extractfile(member).read()\n", 214 | " out.append((\n", 215 | " member.name, \n", 216 | " b\"matplotlib\" in data.lower()\n", 217 | " ))\n", 218 | " except tarfile.ReadError:\n", 219 | " pass\n", 220 | " return out" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "id": "f9c6042c-917f-46ca-9722-c99e02fb97cb", 226 | "metadata": { 227 | "tags": [] 228 | }, 229 | "source": [ 230 | "### Map function across every directory\n", 231 | "\n", 232 | "Let's scale up this work across all of the directories in our dataset.\n", 233 | "\n", 234 | "Hopefully it will also be faster because the cloud VMs are in the same region as the dataset itself." 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "id": "89a6ab9f-3451-48dc-abb2-f4d8f5e0b038", 241 | "metadata": { 242 | "tags": [] 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "%%time\n", 247 | "\n", 248 | "results = extract.map(directories)\n", 249 | "lists = list(results)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "id": "84af922e-4f6b-4f4b-97d5-a7097084aa1b", 255 | "metadata": {}, 256 | "source": [ 257 | "Now that we're done with the large data problem we can turn off Coiled and proceed with pure Pandas. There's no reason to deal with scalable tools if we don't have to." 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "id": "299b5097-ff28-4036-b569-a18449cca0d9", 263 | "metadata": {}, 264 | "source": [ 265 | "## Enrich Data\n", 266 | "\n", 267 | "Let's enhance our data a bit. The filenames of each file include the year and month when they were published. After extracting this data we'll be able to see a timeseries of Matplotlib adoption." 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "id": "85d74a56-e7c1-4614-86bd-6342e16d58fd", 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "# Convert to Pandas\n", 278 | "\n", 279 | "import pandas as pd\n", 280 | "\n", 281 | "dfs = [\n", 282 | " pd.DataFrame(list, columns=[\"filename\", \"has_matplotlib\"]) \n", 283 | " for list in lists\n", 284 | "]\n", 285 | "\n", 286 | "df = pd.concat(dfs)\n", 287 | "\n", 288 | "df" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "id": "f98c54e7-fc46-4180-9586-c06eac6432e6", 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "def date(filename):\n", 299 | " year = int(filename.split(\"/\")[0][:2])\n", 300 | " month = int(filename.split(\"/\")[0][2:4])\n", 301 | " if year > 80:\n", 302 | " year = 1900 + year\n", 303 | " else:\n", 304 | " year = 2000 + year\n", 305 | " \n", 306 | " return pd.Timestamp(year=year, month=month, day=1)\n", 307 | "\n", 308 | "date(\"0005/astro-ph0001322.pdf\")" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "id": "033451dc-344f-40e0-bd05-f2f6fd80d0c2", 314 | "metadata": {}, 315 | "source": [ 316 | "Yup. That seems to work. Let's map this function over our dataset." 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "id": "315f8de5-c53a-49d9-8f8d-3ba62fcee727", 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "df[\"date\"] = df.filename.map(date)\n", 327 | "df.head()" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "id": "e64781f1-0486-4eda-81b4-68911383be7a", 333 | "metadata": {}, 334 | "source": [ 335 | "## Plot\n", 336 | "\n", 337 | "Now we can just fool around with Pandas and Matplotlib." 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "id": "270ae46c-ac65-48f1-a7ca-c6f742591c7c", 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "df.groupby(\"date\").has_matplotlib.mean().plot(\n", 348 | " title=\"Matplotlib Usage in arXiv\", \n", 349 | " ylabel=\"Fraction of papers\"\n", 350 | ").get_figure().savefig(\"results.png\")" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "id": "bd3a2592-d153-4139-a4e2-a80f4a466c24", 356 | "metadata": {}, 357 | "source": [ 358 | "I did the plot above. Then Thomas Caswell (matplotlib maintainer) came by and, in true form, made something much better 🙂" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "id": "cf83b666-05c1-47f3-a2e3-55bbe8f5eeb8", 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "import datetime\n", 369 | "import matplotlib.pyplot as plt\n", 370 | "from matplotlib.ticker import PercentFormatter\n", 371 | "\n", 372 | "import pandas as pd\n", 373 | "\n", 374 | "by_month = df.groupby(\"date\").has_matplotlib.mean()\n", 375 | "\n", 376 | "# get figure\n", 377 | "fig, ax = plt.subplots(layout=\"constrained\")\n", 378 | "# plot the data\n", 379 | "ax.plot(by_month, \"o\", color=\"k\", ms=3)\n", 380 | "\n", 381 | "# over-ride the default auto limits\n", 382 | "ax.set_xlim(left=datetime.date(2004, 1, 1))\n", 383 | "ax.set_ylim(bottom=0)\n", 384 | "\n", 385 | "# turn on a horizontal grid\n", 386 | "ax.grid(axis=\"y\")\n", 387 | "\n", 388 | "# remove the top and right spines\n", 389 | "ax.spines.right.set_visible(False)\n", 390 | "ax.spines.top.set_visible(False)\n", 391 | "\n", 392 | "# format y-ticks a percent\n", 393 | "ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))\n", 394 | "\n", 395 | "# add title and labels\n", 396 | "ax.set_xlabel(\"date\")\n", 397 | "ax.set_ylabel(\"% of all papers\")\n", 398 | "ax.set_title(\"Matplotlib usage on arXiv\");" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "81198952-f6ba-4e1b-9792-36e54a5fe491", 404 | "metadata": {}, 405 | "source": [ 406 | "Yup. Matplotlib is used pretty commonly on arXiv. Go team." 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "id": "a59d3a1d-289c-405c-8330-fc249e376b70", 412 | "metadata": {}, 413 | "source": [ 414 | "## Save results\n", 415 | "\n", 416 | "This data was slighly painful to procure. Let's save the results locally for future analysis. That way other researchers can further analyze the results without having to muck about with parallelism or cloud stuff." 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "id": "3ae0d6ff-4471-4a45-bbe1-6bcd8a8d72a3", 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [ 426 | "df.to_csv(\"arxiv-matplotlib.csv\")" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "id": "623becf0-a84d-4cf7-b719-883bfe60eef4", 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "!du -hs arxiv-matplotlib.csv" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "id": "1755b7a4-8b92-441c-9083-2bc0b51e2f81", 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [ 446 | "df.to_parquet(\"arxiv-matplotlib.parquet\", compression=\"snappy\")" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "id": "afc412d6-38b3-424e-8ca5-b4141d1b776f", 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "!du -hs arxiv-matplotlib.parquet" 457 | ] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "id": "f637b60b-7051-4898-9c99-b3b436acd1ae", 462 | "metadata": {}, 463 | "source": [ 464 | "## Conclusion\n", 465 | "\n", 466 | "### Matplotlib + arXiv\n", 467 | "\n", 468 | "It's incredible to see the steady growth of Matplotlib across arXiv. It's worth noting that this is *all* papers, even from fields like theoretical mathematics that are unlikely to include computer generated plots. Is this matplotlib growing in popularity? Is it Python generally?\n", 469 | "\n", 470 | "For future work, we should break this down by subfield. The filenames actually contained the name of the field for a while, like \"hep-ex\" for \"high energy physics, experimental\", but it looks like arXiv stopped doing this at some point. My guess is that there is a list mapping filenames to fields somewhere though. The filenames are all in the Pandas dataframe / parquet dataset, so doing this analysis shouldn't require any scalable computing.\n", 471 | "\n", 472 | "### Coiled\n", 473 | "\n", 474 | "Coiled was built to make it easy to answer large questions. \n", 475 | "\n", 476 | "We started this notebook with some generic Python code. When we wanted to scale up we invoked Coiled, did some work, and then tore things down, all in about ten minutes. The problem of scale or \"big data\" didn't get in the way of us analyzing data and making a delightful discovery. \n", 477 | "\n", 478 | "This is exactly why these projects exist." 479 | ] 480 | } 481 | ], 482 | "metadata": { 483 | "kernelspec": { 484 | "display_name": "Python 3 (ipykernel)", 485 | "language": "python", 486 | "name": "python3" 487 | }, 488 | "language_info": { 489 | "codemirror_mode": { 490 | "name": "ipython", 491 | "version": 3 492 | }, 493 | "file_extension": ".py", 494 | "mimetype": "text/x-python", 495 | "name": "python", 496 | "nbconvert_exporter": "python", 497 | "pygments_lexer": "ipython3", 498 | "version": "3.10.0" 499 | } 500 | }, 501 | "nbformat": 4, 502 | "nbformat_minor": 5 503 | } 504 | -------------------------------------------------------------------------------- /pytorch-optuna.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b70220fa", 6 | "metadata": {}, 7 | "source": [ 8 | "\n", 11 | "\n", 12 | "## PyTorch, GPUs, Optuna, and Dask\n", 13 | "\n", 14 | "A derived PyTorch example of a Generative Adversarial Network (GAN) being optimized with Optuna and traied on Dask.\n", 15 | "\n", 16 | "Note: To focus on GPU load and demontration purposes, this notebook generates fake data in-memory. Likely to be easily swapped out with your intended dataset." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "df7d9367-df3b-4f27-8a71-9f6bb96ce0fc", 22 | "metadata": {}, 23 | "source": [ 24 | "## Define Model and Training Job" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "44900f9e", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "# Derived partially from\n", 35 | "# - https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html\n", 36 | "import os\n", 37 | "\n", 38 | "import coiled\n", 39 | "\n", 40 | "import optuna\n", 41 | "from optuna.integration.dask import DaskStorage\n", 42 | "from optuna.trial import TrialState\n", 43 | "\n", 44 | "import random\n", 45 | "import torch\n", 46 | "import torch.nn as nn\n", 47 | "import torch.nn.parallel\n", 48 | "import torch.backends.cudnn as cudnn\n", 49 | "import torch.optim as optim\n", 50 | "import torch.utils.data\n", 51 | "import torchvision.datasets as dset\n", 52 | "import torchvision.transforms as transforms\n", 53 | "import torchvision.utils as vutils\n", 54 | "import numpy as np\n", 55 | "from IPython.display import HTML\n", 56 | "from distributed import wait\n", 57 | "\n", 58 | "print(f\"Cuda is available: {torch.cuda.is_available()}\")\n", 59 | "\n", 60 | "# Root directory for dataset\n", 61 | "dataroot = \"data/celeba\"\n", 62 | "\n", 63 | "# Number of workers for dataloader\n", 64 | "workers = 0 # IMPORTANT w/ optuna; it launches a daemonic process, so PyTorch can't itself use it then.\n", 65 | "\n", 66 | "# Batch size during training\n", 67 | "batch_size = 64\n", 68 | "\n", 69 | "# Spatial size of training images. All images will be resized to this\n", 70 | "# size using a transformer.\n", 71 | "image_size = 64\n", 72 | "\n", 73 | "# Number of channels in the training images. For color images this is 3\n", 74 | "nc = 3\n", 75 | "\n", 76 | "# Size of z latent vector (i.e. size of generator input)\n", 77 | "nz = 100\n", 78 | "\n", 79 | "# Size of feature maps in generator\n", 80 | "ngf = 64\n", 81 | "\n", 82 | "# Size of feature maps in discriminator\n", 83 | "ndf = 64\n", 84 | "\n", 85 | "# Number of training epochs\n", 86 | "num_epochs = 5\n", 87 | "\n", 88 | "# Beta1 hyperparameter for Adam optimizers\n", 89 | "beta1 = 0.5\n", 90 | "\n", 91 | "# Number of GPUs available. Use 0 for CPU mode.\n", 92 | "ngpu = 1\n", 93 | "\n", 94 | "\n", 95 | "# We create a fake image dataset. This ought to be replaced by\n", 96 | "# your actual dataset or pytorch's example datasets. We do this here\n", 97 | "# to focus more on computation than actual convergence.\n", 98 | "class FakeImageDataset(torch.utils.data.Dataset):\n", 99 | " def __init__(self, count=1_000):\n", 100 | " self.labels = np.random.randint(0, 5, size=count)\n", 101 | " \n", 102 | " def __len__(self):\n", 103 | " return len(self.labels)\n", 104 | " \n", 105 | " def __getitem__(self, idx):\n", 106 | " label = self.labels[idx]\n", 107 | " torch.random.seed = label\n", 108 | " img = torch.rand(3, image_size, image_size)\n", 109 | " return img, label\n", 110 | "\n", 111 | "\n", 112 | "\n", 113 | "def weights_init(m):\n", 114 | " classname = m.__class__.__name__\n", 115 | " if classname.find('Conv') != -1:\n", 116 | " nn.init.normal_(m.weight.data, 0.0, 0.02)\n", 117 | " elif classname.find('BatchNorm') != -1:\n", 118 | " nn.init.normal_(m.weight.data, 1.0, 0.02)\n", 119 | " nn.init.constant_(m.bias.data, 0)\n", 120 | " \n", 121 | " \n", 122 | "class Generator(nn.Module):\n", 123 | " def __init__(self, ngpu, activation=None):\n", 124 | " super(Generator, self).__init__()\n", 125 | " self.ngpu = ngpu\n", 126 | " self.main = nn.Sequential(\n", 127 | " # input is Z, going into a convolution\n", 128 | " nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),\n", 129 | " nn.BatchNorm2d(ngf * 8),\n", 130 | " nn.ReLU(True),\n", 131 | " # state size. ``(ngf*8) x 4 x 4``\n", 132 | " nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),\n", 133 | " nn.BatchNorm2d(ngf * 4),\n", 134 | " nn.ReLU(True),\n", 135 | " # state size. ``(ngf*4) x 8 x 8``\n", 136 | " nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),\n", 137 | " nn.BatchNorm2d(ngf * 2),\n", 138 | " nn.ReLU(True),\n", 139 | " # state size. ``(ngf*2) x 16 x 16``\n", 140 | " nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),\n", 141 | " nn.BatchNorm2d(ngf),\n", 142 | " nn.ReLU(True),\n", 143 | " # state size. ``(ngf) x 32 x 32``\n", 144 | " nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),\n", 145 | " activation or nn.Sigmoid()\n", 146 | " # state size. ``(nc) x 64 x 64``\n", 147 | " )\n", 148 | "\n", 149 | " def forward(self, input):\n", 150 | " return self.main(input)\n", 151 | "\n", 152 | " @classmethod\n", 153 | " def from_trial(cls, trial):\n", 154 | " activations = [nn.Sigmoid]\n", 155 | " idx = trial.suggest_categorical(\"generator_activation\", list(range(len(activations))))\n", 156 | " activation = activations[idx]()\n", 157 | " return cls(ngpu, activation)\n", 158 | "\n", 159 | " \n", 160 | "class Discriminator(nn.Module):\n", 161 | " def __init__(self, ngpu, leaky_relu_slope=0.2, activation=None):\n", 162 | " super(Discriminator, self).__init__()\n", 163 | " self.ngpu = ngpu\n", 164 | " self.main = nn.Sequential(\n", 165 | " # input is ``(nc) x 64 x 64``\n", 166 | " nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),\n", 167 | " nn.LeakyReLU(leaky_relu_slope, inplace=True),\n", 168 | " # state size. ``(ndf) x 32 x 32``\n", 169 | " nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),\n", 170 | " nn.BatchNorm2d(ndf * 2),\n", 171 | " nn.LeakyReLU(leaky_relu_slope, inplace=True),\n", 172 | " # state size. ``(ndf*2) x 16 x 16``\n", 173 | " nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),\n", 174 | " nn.BatchNorm2d(ndf * 4),\n", 175 | " nn.LeakyReLU(leaky_relu_slope, inplace=True),\n", 176 | " # state size. ``(ndf*4) x 8 x 8``\n", 177 | " nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),\n", 178 | " nn.BatchNorm2d(ndf * 8),\n", 179 | " nn.LeakyReLU(leaky_relu_slope, inplace=True),\n", 180 | " # state size. ``(ndf*8) x 4 x 4``\n", 181 | " nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),\n", 182 | " activation or nn.Sigmoid()\n", 183 | " )\n", 184 | "\n", 185 | " def forward(self, input):\n", 186 | " return self.main(input)\n", 187 | " \n", 188 | " @classmethod\n", 189 | " def from_trial(cls, trial):\n", 190 | " activations = [nn.Sigmoid]\n", 191 | " idx = trial.suggest_categorical(\"descriminator_activation\", list(range(len(activations))))\n", 192 | " activation = activations[idx]()\n", 193 | " \n", 194 | " slopes = np.arange(0.1, 0.4, 0.1)\n", 195 | " idx = trial.suggest_categorical(\"descriminator_leaky_relu_slope\", list(range(len(slopes))))\n", 196 | " leaky_relu_slope = slopes[idx]\n", 197 | " return cls(ngpu, leaky_relu_slope, activation)\n", 198 | " \n", 199 | "\n", 200 | "def objective(trial):\n", 201 | " dataset = FakeImageDataset()\n", 202 | "\n", 203 | " # Create the dataloader\n", 204 | " dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,\n", 205 | " shuffle=True, num_workers=workers)\n", 206 | " \n", 207 | " # Decide which device we want to run on\n", 208 | " device = torch.device(\"cuda:0\" if (torch.cuda.is_available() and ngpu > 0) else \"cpu\")\n", 209 | "\n", 210 | "\n", 211 | " ############################\n", 212 | " ### Create the generator ###\n", 213 | " netG = Generator.from_trial(trial).to(device)\n", 214 | "\n", 215 | " # Handle multi-GPU if desired\n", 216 | " if (device.type == 'cuda') and (ngpu > 1):\n", 217 | " netG = nn.DataParallel(netG, list(range(ngpu)))\n", 218 | "\n", 219 | " # Apply the ``weights_init`` function to randomly initialize all weights\n", 220 | " # to ``mean=0``, ``stdev=0.02``.\n", 221 | " netG.apply(weights_init)\n", 222 | "\n", 223 | " ################################\n", 224 | " ### Create the Discriminator ###\n", 225 | " netD = Discriminator.from_trial(trial).to(device)\n", 226 | "\n", 227 | " # Handle multi-GPU if desired\n", 228 | " if (device.type == 'cuda') and (ngpu > 1):\n", 229 | " netD = nn.DataParallel(netD, list(range(ngpu)))\n", 230 | "\n", 231 | " # Apply the ``weights_init`` function to randomly initialize all weights\n", 232 | " # like this: ``to mean=0, stdev=0.2``.\n", 233 | " netD.apply(weights_init)\n", 234 | "\n", 235 | " ############################################\n", 236 | " ### Remaining crierion, optimizers, etc. ###\n", 237 | " # Initialize the ``BCELoss`` function\n", 238 | " criterion = nn.BCELoss()\n", 239 | "\n", 240 | " # Create batch of latent vectors that we will use to visualize\n", 241 | " # the progression of the generator\n", 242 | " fixed_noise = torch.randn(64, nz, 1, 1, device=device)\n", 243 | "\n", 244 | " # Establish convention for real and fake labels during training\n", 245 | " real_label = 1.\n", 246 | " fake_label = 0.\n", 247 | " \n", 248 | " # Learning rate for optimizers\n", 249 | " lr = 0.00001\n", 250 | "\n", 251 | " # Setup Adam optimizers for both G and D\n", 252 | " optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))\n", 253 | " optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))\n", 254 | " \n", 255 | " #####################\n", 256 | " ### Training Loop ###\n", 257 | "\n", 258 | " # Lists to keep track of progress\n", 259 | " img_list = []\n", 260 | " G_losses = []\n", 261 | " D_losses = []\n", 262 | " iters = 0\n", 263 | "\n", 264 | " print(\"Starting Training Loop...\")\n", 265 | " # For each epoch\n", 266 | " for epoch in range(num_epochs):\n", 267 | " # For each batch in the dataloader\n", 268 | " for i, data in enumerate(dataloader, 0):\n", 269 | "\n", 270 | " ############################\n", 271 | " # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n", 272 | " ###########################\n", 273 | " ## Train with all-real batch\n", 274 | " netD.zero_grad()\n", 275 | " # Format batch\n", 276 | " real_cpu = data[0].to(device)\n", 277 | " b_size = real_cpu.size(0)\n", 278 | " label = torch.full((b_size,), real_label, dtype=torch.float, device=device)\n", 279 | " # Forward pass real batch through D\n", 280 | " output = netD(real_cpu).view(-1)\n", 281 | " # Calculate loss on all-real batch\n", 282 | " errD_real = criterion(output, label)\n", 283 | " # Calculate gradients for D in backward pass\n", 284 | " errD_real.backward()\n", 285 | " D_x = output.mean().item()\n", 286 | "\n", 287 | " ## Train with all-fake batch\n", 288 | " # Generate batch of latent vectors\n", 289 | " noise = torch.randn(b_size, nz, 1, 1, device=device)\n", 290 | " # Generate fake image batch with G\n", 291 | " fake = netG(noise)\n", 292 | " label.fill_(fake_label)\n", 293 | " # Classify all fake batch with D\n", 294 | " output = netD(fake.detach()).view(-1)\n", 295 | " # Calculate D's loss on the all-fake batch\n", 296 | " errD_fake = criterion(output, label)\n", 297 | " # Calculate the gradients for this batch, accumulated (summed) with previous gradients\n", 298 | " errD_fake.backward()\n", 299 | " D_G_z1 = output.mean().item()\n", 300 | " # Compute error of D as sum over the fake and the real batches\n", 301 | " errD = errD_real + errD_fake\n", 302 | " # Update D\n", 303 | " optimizerD.step()\n", 304 | "\n", 305 | " ############################\n", 306 | " # (2) Update G network: maximize log(D(G(z)))\n", 307 | " ###########################\n", 308 | " netG.zero_grad()\n", 309 | " label.fill_(real_label) # fake labels are real for generator cost\n", 310 | " # Since we just updated D, perform another forward pass of all-fake batch through D\n", 311 | " output = netD(fake).view(-1)\n", 312 | " # Calculate G's loss based on this output\n", 313 | " errG = criterion(output, label)\n", 314 | " # Calculate gradients for G\n", 315 | " errG.backward()\n", 316 | " D_G_z2 = output.mean().item()\n", 317 | " # Update G\n", 318 | " optimizerG.step()\n", 319 | "\n", 320 | " # Output training stats\n", 321 | " if i % 50 == 0:\n", 322 | " print('[%d/%d][%d/%d]\\tLoss_D: %.4f\\tLoss_G: %.4f\\tD(x): %.4f\\tD(G(z)): %.4f / %.4f'\n", 323 | " % (epoch, num_epochs, i, len(dataloader),\n", 324 | " errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))\n", 325 | "\n", 326 | " # Save Losses for plotting later\n", 327 | " G_losses.append(errG.item())\n", 328 | " D_losses.append(errD.item())\n", 329 | "\n", 330 | " # Check how the generator is doing by saving G's output on fixed_noise\n", 331 | " if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):\n", 332 | " with torch.no_grad():\n", 333 | " fake = netG(fixed_noise).detach().cpu()\n", 334 | " img_list.append(vutils.make_grid(fake, padding=2, normalize=True))\n", 335 | "\n", 336 | " iters += 1\n", 337 | "\n", 338 | " # Report to Optuna\n", 339 | " trial.report(errD.item(), epoch)\n", 340 | " if trial.should_prune():\n", 341 | " raise optuna.exceptions.TrialPruned()\n", 342 | " return errD.item()" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "id": "fa3f3952", 348 | "metadata": {}, 349 | "source": [ 350 | "## Start Dask cluster with GPU workers" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "id": "75516f5a", 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "import coiled\n", 361 | "\n", 362 | "cluster = coiled.Cluster(\n", 363 | " n_workers=20,\n", 364 | " worker_vm_type=\"g5.xlarge\", # single A10 GPU\n", 365 | " worker_options={\"nthreads\": 1},\n", 366 | ")\n", 367 | "\n", 368 | "client = cluster.get_client()" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "id": "d9d03fb7", 374 | "metadata": {}, 375 | "source": [ 376 | "## Perform HyperParameter Optimization / Tuning" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "id": "d3122981", 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [ 386 | "# Set to your heart's desire, patience and cluster size. :)\n", 387 | "n_trials = 500\n", 388 | "\n", 389 | "study = optuna.create_study(\n", 390 | " direction='minimize', \n", 391 | " storage=DaskStorage(client=client)\n", 392 | ")\n", 393 | "jobs = [\n", 394 | " client.submit(study.optimize, objective, n_trials=1, pure=False)\n", 395 | " for _ in range(n_trials)\n", 396 | "]" 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "id": "77538d36", 402 | "metadata": {}, 403 | "source": [ 404 | "## Analyze Results" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "id": "cb0e6fef", 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "_ = wait(jobs)\n", 415 | "\n", 416 | "pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])\n", 417 | "complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])\n", 418 | "\n", 419 | "print(\"Study statistics: \")\n", 420 | "print(\" Number of finished trials: \", len(study.trials))\n", 421 | "print(\" Number of pruned trials: \", len(pruned_trials))\n", 422 | "print(\" Number of complete trials: \", len(complete_trials))\n", 423 | "\n", 424 | "print(\"Best trial:\")\n", 425 | "trial = study.best_trial\n", 426 | "\n", 427 | "print(\" Value: \", trial.value)\n", 428 | "\n", 429 | "print(\" Params: \")\n", 430 | "for key, value in trial.params.items():\n", 431 | " print(\" {}: {}\".format(key, value))" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "id": "72416cbc-cdab-4f8a-a19e-746e4b4b4280", 438 | "metadata": {}, 439 | "outputs": [], 440 | "source": [ 441 | "from optuna.visualization.matplotlib import plot_optimization_history, plot_param_importances\n", 442 | "\n", 443 | "plot_optimization_history(study)" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "id": "13d5b8e2", 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [ 453 | "client.shutdown()" 454 | ] 455 | } 456 | ], 457 | "metadata": { 458 | "kernelspec": { 459 | "display_name": "Python [conda env:pytorch]", 460 | "language": "python", 461 | "name": "conda-env-pytorch-py" 462 | }, 463 | "language_info": { 464 | "codemirror_mode": { 465 | "name": "ipython", 466 | "version": 3 467 | }, 468 | "file_extension": ".py", 469 | "mimetype": "text/x-python", 470 | "name": "python", 471 | "nbconvert_exporter": "python", 472 | "pygments_lexer": "ipython3", 473 | "version": "3.10.14" 474 | } 475 | }, 476 | "nbformat": 4, 477 | "nbformat_minor": 5 478 | } 479 | --------------------------------------------------------------------------------