├── __init__.py ├── tests ├── __init__.py ├── test_train_with_rllib.py └── rice_rllib_test.yaml ├── .dockerignore ├── White_Paper.pdf ├── example_data ├── example.pkl └── example_2.pkl ├── climate_economic_min_max_indices.txt ├── CHANGELOG.md ├── [only for test, pls use your own version if possible] climate_economic_min_max_indices.txt ├── scripts ├── __init__.py ├── fixed_paths.py ├── lint.sh ├── desired_outputs.py ├── run_cpu_gpu_env_consistency_checks.py ├── rice_warpdrive.yaml ├── rice_rllib.yaml ├── create_submission_zip.py ├── torch_models.py ├── train_with_warp_drive.py ├── run_unittests.py ├── train_with_rllib.py └── evaluate_submission.py ├── region_yamls ├── 12.yml ├── 14.yml ├── 16.yml ├── 18.yml ├── 2.yml ├── 5.yml ├── 7.yml ├── 9.yml ├── 11.yml ├── 13.yml ├── 15.yml ├── 17.yml ├── 19.yml ├── 20.yml ├── 4.yml ├── 6.yml ├── 22.yml ├── 25.yml ├── 28.yml ├── 3.yml ├── 21.yml ├── 26.yml ├── 23.yml ├── 24.yml ├── 29.yml ├── 30.yml ├── 27.yml └── default.yml ├── requirements.txt ├── rice_build.cu ├── Dockerfile ├── Dockerfile_CPU ├── README_CPU.md ├── .github └── workflows │ └── main.yaml ├── CITATION.cff ├── LICENSE.txt ├── Makefile ├── .gitignore ├── MARL_using_RLlib.ipynb ├── rice_cuda.py ├── rice_helpers.py ├── README.md ├── getting_started.ipynb ├── Visualization.ipynb └── MARL_training.ipynb /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | Submissions 3 | *.pdf 4 | core 5 | .tmp 6 | -------------------------------------------------------------------------------- /White_Paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mila-iqia/climate-cooperation-competition/HEAD/White_Paper.pdf -------------------------------------------------------------------------------- /example_data/example.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mila-iqia/climate-cooperation-competition/HEAD/example_data/example.pkl -------------------------------------------------------------------------------- /example_data/example_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mila-iqia/climate-cooperation-competition/HEAD/example_data/example_2.pkl -------------------------------------------------------------------------------- /climate_economic_min_max_indices.txt: -------------------------------------------------------------------------------- 1 | {"min_ci": 7.006738662719727, "max_ci": 1.2506777048110962, "min_ei": 606.57275390625, "max_ei": 9673.5322265625} -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | # Release 1.0 (2022-07-03) 4 | 5 | - Initial public release of RICE-N simulation code, training code, documentation, tutorials, and submission/evaluation scripts. 
-------------------------------------------------------------------------------- /[only for test, pls use your own version if possible] climate_economic_min_max_indices.txt: -------------------------------------------------------------------------------- 1 | {"min_ci": 7.006738662719727, "max_ci": 1.2506777048110962, "min_ei": 606.57275390625, "max_ei": 9673.5322265625} -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause -------------------------------------------------------------------------------- /scripts/fixed_paths.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from pathlib import Path 3 | _path = Path(os.path.abspath(__file__)) 4 | PUBLIC_REPO_DIR = str(_path.parent.parent.absolute()) 5 | sys.path.append(os.path.join(PUBLIC_REPO_DIR, "scripts")) 6 | #print("fixed_paths: Using PUBLIC_REPO_DIR = {}".format(PUBLIC_REPO_DIR)) 7 | -------------------------------------------------------------------------------- /scripts/lint.sh: -------------------------------------------------------------------------------- 1 | # stop the build if there are Python syntax errors or undefined names 2 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 3 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 4 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 5 | -------------------------------------------------------------------------------- /region_yamls/12.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 8.405493223457656 3 | xK_0: 3.30354611 4 | xL_0: 68.394527 5 | xL_a: 93.497311 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.1880269001436297 10 | xg_A: 0.10300420806704261 11 | xgamma: 0.3 12 | xl_g: 0.05753057218640376 13 | xsigma_0: 0.5289744017993728 14 | -------------------------------------------------------------------------------- /region_yamls/14.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 1.92663301826947 3 | xK_0: 1.423908312 4 | xL_0: 284.698846 5 | xL_a: 465.307807 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.24445012982169362 10 | xg_A: 0.1335428337437049 11 | xgamma: 0.3 12 | xl_g: 0.024422285778436918 13 | xsigma_0: 1.220638524516315 14 | -------------------------------------------------------------------------------- /region_yamls/16.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 4.217213133650901 3 | xK_0: 3.18362519 4 | xL_0: 548.75442 5 | xL_a: 560.054221 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.1703030497846267 10 | xg_A: 0.09485139864239062 11 | xgamma: 0.3 12 | xl_g: 0.08033413573292254 13 | xsigma_0: 0.3019631318655498 14 | -------------------------------------------------------------------------------- /region_yamls/18.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 2.5248787296824777 3 | xK_0: 1.080409098 4 | xL_0: 69.194146 5 | xL_a: 100.015768 6 | xa_1: 0 
7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.3464232696284064 10 | xg_A: 0.0785686384327884 11 | xgamma: 0.3 12 | xl_g: 0.028895835870575235 13 | xsigma_0: 1.0104732880546095 14 | -------------------------------------------------------------------------------- /region_yamls/2.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 12.157936179442062 3 | xK_0: 2.64167507 4 | xL_0: 38.101107 5 | xL_a: 56.990157 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.13084887390535965 10 | xg_A: 0.06274070897633105 11 | xgamma: 0.3 12 | xl_g: 0.020192884216840113 13 | xsigma_0: 0.35044418275452427 14 | -------------------------------------------------------------------------------- /region_yamls/5.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 2.480556451289923 3 | xK_0: 0.090493838 4 | xL_0: 94.484285 5 | xL_a: 102.997258 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.20277540450432627 10 | xg_A: 0.20063089079847785 11 | xgamma: 0.3 12 | xl_g: 0.036907187009908436 13 | xsigma_0: 1.6646404809736024 14 | -------------------------------------------------------------------------------- /region_yamls/7.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 4.135420683261054 3 | xK_0: 1.00243116 4 | xL_0: 103.2943 5 | xL_a: 87.417937 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.1577590297489783 10 | xg_A: 0.12252697231832654 11 | xgamma: 0.3 12 | xl_g: -0.06254962519160859 13 | xsigma_0: 0.6013249328720022 14 | -------------------------------------------------------------------------------- /region_yamls/9.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 2.7159049409990375 3 | xK_0: 1.0340369 4 | xL_0: 573.818276 5 | xL_a: 681.210099 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.09686897303646223 10 | xg_A: 0.10149076832484585 11 | xgamma: 0.3 12 | xl_g: 0.04313067107477931 13 | xsigma_0: 0.6378326935010085 14 | -------------------------------------------------------------------------------- /region_yamls/11.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 1.8724419952820714 3 | xK_0: 0.239419592 4 | xL_0: 476.878017 5 | xL_a: 669.593553 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.13880429539557834 10 | xg_A: 0.12202134941105497 11 | xgamma: 0.3 12 | xl_g: 0.034238352160625596 13 | xsigma_0: 0.4559257467059924 14 | -------------------------------------------------------------------------------- /region_yamls/13.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 3.5579000509140952 3 | xK_0: 0.109143954 4 | xL_0: 64.122372 5 | xL_a: 135.074132 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.16127452284439697 10 | xg_A: 0.12735655631209186 11 | xgamma: 0.3 12 | xl_g: 0.02623933488387354 13 | xsigma_0: 0.8162518983719008 14 | -------------------------------------------------------------------------------- /region_yamls/15.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 8.111280036435135 3 | xK_0: 0.268152174 4 | xL_0: 28.141422 5 | xL_a: 23.573851 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.16335430971807735 10 | xg_A: 0.10573757974990125 11 | xgamma: 0.3 12 | xl_g: 
-0.05715547594428186 13 | xsigma_0: 0.29029694003558093 14 | -------------------------------------------------------------------------------- /region_yamls/17.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 2.4913586566019945 3 | xK_0: 0.043635414 4 | xL_0: 46.488546 5 | xL_a: 59.987638 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.05834835573455604 10 | xg_A: 0.049004053769246436 11 | xgamma: 0.3 12 | xl_g: 0.03709027315241262 13 | xsigma_0: 0.4196283605267465 14 | -------------------------------------------------------------------------------- /region_yamls/19.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 2.4596628149703816 3 | xK_0: 0.183982308 4 | xL_0: 513.737375 5 | xL_a: 1867.771496 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 1.8390289471577375 10 | xg_A: 0.46217845237530203 11 | xgamma: 0.3 12 | xl_g: 0.017149514576045286 13 | xsigma_0: 0.3103140976545981 14 | -------------------------------------------------------------------------------- /region_yamls/20.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 0.9929511285910457 3 | xK_0: 0.160199062 4 | xL_0: 522.481879 5 | xL_a: 1830.325243 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.08560686741591728 10 | xg_A: 0.06506072277236097 11 | xgamma: 0.3 12 | xl_g: 0.01902705663391574 13 | xsigma_0: 0.23517024551671273 14 | -------------------------------------------------------------------------------- /region_yamls/4.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 6.386787299600672 3 | xK_0: 1.094110266 4 | xL_0: 317.880267 5 | xL_a: 287.533185 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.19401001361417486 10 | xg_A: 0.23666124530625898 11 | xgamma: 0.3 12 | xl_g: -0.052512175141607595 13 | xsigma_0: 0.8402859337043421 14 | -------------------------------------------------------------------------------- /region_yamls/6.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 10.852953800501595 3 | xK_0: 17.553847656 4 | xL_0: 222.891134 5 | xL_a: 168.350837 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.005000000000006631 10 | xg_A: -0.00046726965128841526 11 | xgamma: 0.3 12 | xl_g: -0.011976043898247184 13 | xsigma_0: 0.2851271547872655 14 | -------------------------------------------------------------------------------- /region_yamls/22.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 29.853559456004625 3 | xK_0: 2.019951041942154 4 | xL_0: 165.75054 5 | xL_a: 216.9269455 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.08802757142538145 10 | xg_A: 0.07541285925058157 11 | xgamma: 0.3 12 | xl_g: -0.0024986057450947508 13 | xsigma_0: 0.25439108584131914 14 | -------------------------------------------------------------------------------- /region_yamls/25.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 10.922004036104973 3 | xK_0: 0.6059142357084183 4 | xL_0: 705.464681 5 | xL_a: 532.496728 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.09603760623634239 10 | xg_A: 0.16817991622593043 11 | xgamma: 0.3 12 | xl_g: -0.015844877082389193 13 | xsigma_0: 0.7813181890031158 14 | 
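# Key to the per-region parameters used throughout these region_yamls files
# (meanings per the comments in region_yamls/default.yml): xA_0 = initial total
# factor productivity, xK_0 = initial capital, xL_0 = initial population,
# xL_a = population at convergence, xl_g = population convergence rate,
# xg_A / xdelta_A = TFP growth-rate parameters, xgamma = capital elasticity,
# xa_1 / xa_2 / xa_3 = damage-function coefficients, and xsigma_0 = initial
# emission intensity.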
-------------------------------------------------------------------------------- /region_yamls/28.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 3.1898536850408714 3 | xK_0: 0.1287514001006796 4 | xL_0: 690.0021925 5 | xL_a: 723.512806 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.05389348561714375 10 | xg_A: 0.06812795377170236 11 | xgamma: 0.3 12 | xl_g: -0.012597171762552104 13 | xsigma_0: 0.9487399403167854 14 | -------------------------------------------------------------------------------- /region_yamls/3.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 13.219587477586199 3 | xK_0: 16.295084052817813 4 | xL_0: 502.409662 5 | xL_a: 445.861101 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.25224119422016716 10 | xg_A: 0.07423569831381745 11 | xgamma: 0.3 12 | xl_g: -0.033398145012670695 13 | xsigma_0: 0.17048017530013193 14 | -------------------------------------------------------------------------------- /region_yamls/21.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 5.000360862831762 3 | xK_0: 2.289358084859004 4 | xL_0: 165.293239 5 | xL_a: 230.19114338372032 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.18278991377259912 10 | xg_A: 0.07108490122759262 11 | xgamma: 0.3 12 | xl_g: 0.026773049602328805 13 | xsigma_0: 0.4187771240034329 14 | -------------------------------------------------------------------------------- /region_yamls/26.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 9.633893693772771 3 | xK_0: 0.6076078389971926 4 | xL_0: 465.60668946000004 5 | xL_a: 351.44784048 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.09603760623634239 10 | xg_A: 0.16817991622593043 11 | xgamma: 0.3 12 | xl_g: -0.015844877094273714 13 | xsigma_0: 0.7813181890031158 14 | -------------------------------------------------------------------------------- /region_yamls/23.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 23.314991608844633 3 | xK_0: 3.0391651447451187 4 | xL_0: 109.39535640000001 5 | xL_a: 143.17178403 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.08802757821836926 10 | xg_A: 0.07541286104378697 11 | xgamma: 0.3 12 | xl_g: -0.002498605753950543 13 | xsigma_0: 0.25439108584131914 14 | -------------------------------------------------------------------------------- /region_yamls/24.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 29.853559456004625 3 | xK_0: 0.6867833542603324 4 | xL_0: 56.355183600000004 5 | xL_a: 73.75516147 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.08802757142538145 10 | xg_A: 0.07541285925058157 11 | xgamma: 0.3 12 | xl_g: -0.002498605778464897 13 | xsigma_0: 0.25439108584131914 14 | -------------------------------------------------------------------------------- /region_yamls/29.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 2.033527139192083 3 | xK_0: 0.3810937821808831 4 | xL_0: 455.40144705 5 | xL_a: 477.51845196000005 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.053893489919463196 10 | xg_A: 0.06812795559760368 11 | xgamma: 0.3 12 | xl_g: -0.012597171772754604 13 | xsigma_0: 0.9487399403167854 14 | 
-------------------------------------------------------------------------------- /region_yamls/30.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 3.1898536850408714 3 | xK_0: 0.04377547603423107 4 | xL_0: 234.60074545 5 | xL_a: 245.99435404000002 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.05389348561714375 10 | xg_A: 0.06812795377170236 11 | xgamma: 0.3 12 | xl_g: -0.012597171800996251 13 | xsigma_0: 0.9487399403167854 14 | -------------------------------------------------------------------------------- /region_yamls/27.yml: -------------------------------------------------------------------------------- 1 | _RICE_CONSTANT: 2 | xA_0: 8.620918323265558 3 | xK_0: 0.45330037729157585 4 | xL_0: 239.85799154000003 5 | xL_a: 181.04888752000002 6 | xa_1: 0 7 | xa_2: 0.00236 8 | xa_3: 2 9 | xdelta_A: 0.09603760623634239 10 | xg_A: 0.16817991622593043 11 | xgamma: 0.3 12 | xl_g: -0.015844877127171766 13 | xsigma_0: 0.7813181890031158 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | importlib-metadata==4.12.0 2 | flask==2.1.1 3 | gym==0.22.0 4 | pandas==1.3.0 5 | waitress==2.1.1 6 | jupyterlab>=3.4.0 7 | tensorflow==1.13.1 8 | torch==1.9.0 9 | # optional: RLlib and WarpDrive for training 10 | # For CPU 11 | # ray[rllib]==1.0.0 12 | # For GPU 13 | # rl-warp-drive==1.7.0 14 | matplotlib==3.5.3 15 | scikit-learn 16 | numpy==1.21.6 17 | deepdiff==5.8.1 18 | pyyaml==6.0 19 | pytest==7.1.3 20 | -------------------------------------------------------------------------------- /rice_build.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, salesforce.com, inc and MILA. 2 | // All rights reserved. 3 | // SPDX-License-Identifier: BSD-3-Clause 4 | // For full license text, see the LICENSE file in the repo root 5 | // or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | #ifndef CUDA_INCLUDES_RICE_CONST_H_ 9 | #define CUDA_INCLUDES_RICE_CONST_H_ 10 | 11 | #include "rice_step.cu" 12 | 13 | #endif // CUDA_INCLUDES_RICE_CONST_H_ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | FROM nvcr.io/nvidia/pytorch:21.08-py3 9 | LABEL description="warpdrive-env" 10 | WORKDIR /home/ 11 | RUN chmod a+rwx /home/ 12 | # Install other packages 13 | RUN pip3 install pycuda==2021.1 14 | -------------------------------------------------------------------------------- /Dockerfile_CPU: -------------------------------------------------------------------------------- 1 | #FROM toluclassics/transformers_notebook 2 | FROM jupyter/base-notebook:python-3.7.6 3 | #FROM jupyter/tensorflow-notebook 4 | 5 | ENV TRANSFORMERS_CACHE=/tmp/.cache 6 | ENV TOKENIZERS_PARALLELISM=true 7 | 8 | USER root 9 | RUN apt-get update && apt-get install -y libglib2.0-0 10 | RUN pip3 install --no-cache-dir protobuf==3.20.1 importlib-metadata==4.13.0 11 | RUN pip3 install --no-cache-dir tensorflow==1.13.2 gym==0.21 ray[rllib]==1.0.0 torch==1.9.0 12 | RUN pip3 install --no-cache-dir importlib-resources ale-py~=0.7.1 \ 13 | && pip3 install --no-cache-dir MarkupSafe==2.0.1 14 | RUN pip3 install --no-cache-dir scikit-learn matplotlib 15 | 16 | USER ${NB_UID} 17 | WORKDIR "${HOME}/work" 18 | -------------------------------------------------------------------------------- /README_CPU.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | You can train and evaluate your model in an isolated Docker container. 4 | This keeps your RL environment separate from your local environment, 5 | which makes reproducibility more robust. 6 | 7 | The `Dockerfile_CPU` builds a standalone local Docker image that both 8 | hosts the Jupyter notebook server and 9 | can train and evaluate your submission from the command line. 10 | 11 | # Quick Start 12 | 13 | Build the image and start the notebook server: 14 | 15 | ``` 16 | make 17 | ``` 18 | 19 | Train the model from the command line: 20 | 21 | ``` 22 | make train 23 | ``` 24 | 25 | Evaluate the most recent submission: 26 | 27 | ``` 28 | make evaluate 29 | ``` 30 | 31 | # Requirements 32 | 33 | You need to have Docker installed on your workstation. 
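# Under the Hood

For reference, `make train` and `make evaluate` are thin wrappers around the training and evaluation scripts (see the `Makefile` in the repo root). Inside the container they run, roughly:

```
# make train
python scripts/train_with_rllib.py

# make evaluate (the Makefile picks the newest .zip under Submissions/)
python scripts/evaluate_submission.py -r Submissions/<your_submission>.zip
```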
34 | 35 | 36 | -------------------------------------------------------------------------------- /tests/test_train_with_rllib.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | from pathlib import Path 5 | 6 | _path = Path(os.path.abspath(__file__)) 7 | PUBLIC_REPO_DIR = str(_path.parent.parent.absolute()) 8 | os.chdir(PUBLIC_REPO_DIR) 9 | 10 | 11 | def test_training_run(): 12 | env = os.environ.copy() 13 | env["CONFIG_FILE"] = "./tests/rice_rllib_test.yaml" 14 | command = subprocess.run(["python", "./scripts/train_with_rllib.py"], 15 | env=env, 16 | capture_output=True) 17 | assert command.returncode == 0 18 | output = command.stdout.decode("utf-8") 19 | 20 | file_reg = re.compile(r'is created at: (.*$)', flags=re.M) 21 | print(output) 22 | match = file_reg.search(output).group(1) 23 | assert os.path.exists(match) 24 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.7"] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements.txt 23 | pip install ray[rllib]==1.0.0 24 | # TODO fix flake8 errors and enable 25 | # - name: Lint with flake8 26 | # run: | 27 | # ./scripts/lint.sh 28 | - name: Test with pytest 29 | run: | 30 | pytest 31 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Zhang" 5 | given-names: "Tianyu" 6 | orcid: "https://orcid.org/0000-0002-4410-1343" 7 | - family-names: "Srinivasa" 8 | given-names: "Sunil" 9 | orcid: "https://orcid.org/0000-0000-0000-0000" 10 | - family-names: "Wiliams" 11 | given-names: "Andrew" 12 | orcid: "https://orcid.org/0000-0000-0000-0000" 13 | - family-names: "Phade" 14 | given-names: "Soham" 15 | orcid: "https://orcid.org/0000-0000-0000-0000" 16 | - family-names: "Zhang" 17 | given-names: "Yang" 18 | orcid: "https://orcid.org/0000-0000-0000-0000" 19 | - family-names: "Gupta" 20 | given-names: "Prateek" 21 | orcid: "https://orcid.org/0000-0002-0892-3518" 22 | - family-names: "Bengio" 23 | given-names: "Yoshua" 24 | orcid: "https://orcid.org/0000-0002-9322-3515" 25 | - family-names: "Zheng" 26 | given-names: "Stephan" 27 | orcid: "https://orcid.org/0000-0002-7271-1616" 28 | title: "RICE-N" 29 | version: 1.0.0 30 | date-released: 2022-07-03 31 | url: "https://github.com/mila-iqia/climate-cooperation-competition" 32 | -------------------------------------------------------------------------------- /scripts/desired_outputs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | desired_outputs = [ 9 | "global_temperature", 10 | "global_carbon_mass", 11 | "capital_all_regions", 12 | "labor_all_regions", 13 | "production_factor_all_regions", 14 | "intensity_all_regions", 15 | "global_exogenous_emissions", 16 | "global_land_emissions", 17 | "timestep", 18 | "activity_timestep", 19 | "capital_depreciation_all_regions", 20 | "savings_all_regions", 21 | "mitigation_rate_all_regions", 22 | "max_export_limit_all_regions", 23 | "mitigation_cost_all_regions", 24 | "damages_all_regions", 25 | "abatement_cost_all_regions", 26 | "utility_all_regions", 27 | "social_welfare_all_regions", 28 | "reward_all_regions", 29 | "consumption_all_regions", 30 | "current_balance_all_regions", 31 | "gross_output_all_regions", 32 | "investment_all_regions", 33 | "production_all_regions", 34 | "tariffs", 35 | "future_tariffs", 36 | "scaled_imports", 37 | "desired_imports", 38 | "tariffed_imports", 39 | "stage", 40 | "minimum_mitigation_rate_all_regions", 41 | "promised_mitigation_rate", 42 | "requested_mitigation_rate", 43 | "proposal_decisions", 44 | ] 45 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022, Salesforce.com, Inc and MILA. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com or MILA, nor the names of their contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /scripts/run_cpu_gpu_env_consistency_checks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | # Run this via 9 | # !python run_cpu_gpu_env_consistency_checks.py 10 | 11 | import logging 12 | import os 13 | import sys 14 | 15 | import torch 16 | 17 | from fixed_paths import PUBLIC_REPO_DIR 18 | sys.path.append(PUBLIC_REPO_DIR) 19 | 20 | from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU 21 | from warp_drive.utils.env_registrar import EnvironmentRegistrar 22 | 23 | from rice import Rice 24 | from rice_cuda import RiceCuda 25 | 26 | logging.getLogger().setLevel(logging.ERROR) 27 | 28 | _NUM_GPUS_AVAILABLE = torch.cuda.device_count() 29 | assert _NUM_GPUS_AVAILABLE > 0, "This script needs a GPU to run!" 30 | 31 | env_registrar = EnvironmentRegistrar() 32 | 33 | env_registrar.add_cuda_env_src_path(Rice.name, os.path.join(PUBLIC_REPO_DIR, "rice_build.cu")) 34 | env_configs = { 35 | "no_negotiation": { 36 | "num_discrete_action_levels": 100, 37 | "negotiation_on": False, 38 | }, 39 | "with_negotiation": { 40 | "num_discrete_action_levels": 100, 41 | "negotiation_on": True, 42 | }, 43 | } 44 | testing_class = EnvironmentCPUvsGPU( 45 | cpu_env_class=Rice, 46 | cuda_env_class=RiceCuda, 47 | env_configs=env_configs, 48 | num_envs=2, 49 | num_episodes=2, 50 | use_gpu_testing_mode=False, 51 | env_registrar=env_registrar, 52 | ) 53 | 54 | testing_class.test_env_reset_and_step(consistency_threshold_pct=1, seed=17) 55 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PROJECT = climate-coop/cpu 2 | #VERSION ?= $(shell git describe --abbrev=0 --tags) 3 | VERSION ?= latest 4 | IMAGE ?= $(PROJECT):$(VERSION) 5 | DOCKER_VARS = --user $(shell id -u):$(shell id -g) --group-add users -e GRANT_SUDO=yes 6 | MOUNT_VARS = -v "${PWD}":/home/jovyan/work -v ${PWD}/ray_results:/home/jovyan/ray_results 7 | X11_VARS = -e "DISPLAY" -v "/etc/group:/etc/group:ro" -v "/etc/passwd:/etc/passwd:ro" -v "/etc/shadow:/etc/shadow:ro" -v "/etc/sudoers.d:/etc/sudoers.d:ro" -v "/tmp/.X11-unix:/tmp/.X11-unix:rw" 8 | ALL_VARS = $(DOCKER_VARS) $(MOUNT_VARS) $(X11_VARS) 9 | 10 | PYTHONPATH ?= -e PYTHONPATH=scripts 11 | 12 | all: build run 13 | 14 | build: 15 | docker build -t $(IMAGE) -f Dockerfile_CPU . 16 | 17 | run: 18 | docker run -it --rm $(ALL_VARS) -p 8888:8888 $(IMAGE) 19 | 20 | train: 21 | mkdir -p ray_results > /dev/null 22 | docker run -it --rm $(ALL_VARS) $(IMAGE) python scripts/train_with_rllib.py 23 | 24 | evaluate: SUB=$(shell ls -r Submissions/*.zip | tr ' ' '\n' | head -1) 25 | evaluate: 26 | mkdir -p .tmp/_base 27 | docker run -it --rm $(ALL_VARS) $(IMAGE) python scripts/evaluate_submission.py -r $(SUB) 28 | 29 | bash: 30 | docker run -it --rm $(ALL_VARS) $(IMAGE) bash 31 | 32 | python: 33 | docker run -it --rm $(ALL_VARS) $(PYTHONPATH) $(IMAGE) python 34 | 35 | # View on http://localhost:6006 36 | tensorboard: 37 | docker run -it --rm $(ALL_VARS) -p 6006:6006 $(IMAGE) tensorboard --logdir '~/work/ray_results' 38 | 39 | diagnose: LOG=diagnostic.log 40 | diagnose: 41 | @echo "Build clean image (no cache)..." | tee $(LOG) 42 | docker build --no-cache -t $(IMAGE) -f Dockerfile_CPU . 
| tee -a $(LOG) 43 | @echo "Using Docker: $(shell docker --version)" | tee -a $(LOG) 44 | @echo "Working dir: $(shell pwd)" | tee -a $(LOG) 45 | @echo "User: $(shell id -un)" | tee -a $(LOG) 46 | @echo "In docker group: $(shell id -Gn | tr ' ' '\n' | grep docker)" | tee -a $(LOG) 47 | @echo "Inspect Docker container..." | tee -a $(LOG) 48 | docker run --rm $(ALL_VARS) $(IMAGE) whoami | tee -a $(LOG) 49 | docker run --rm $(ALL_VARS) $(IMAGE) ls -l | tee -a $(LOG) 50 | 51 | -------------------------------------------------------------------------------- /tests/rice_rllib_test.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | # This file does a quick run for tests 8 | 9 | # Checkpoint saving setting 10 | saving: 11 | metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics 12 | model_params_save_freq: 1000 # how often (in iterations) to save the model parameters 13 | basedir: "/tmp" # base folder used for saving 14 | name: "rice" # experiment name 15 | tag: "experiments" # experiment tag 16 | 17 | # Trainer settings 18 | trainer: 19 | num_envs: 1 20 | rollout_fragment_length: 100 21 | train_batch_size: 20 22 | num_episodes: 1 # This will cause one iteration only 23 | framework: torch # framework setting. 24 | # Note: RLlib supports TF as well, but our end-to-end pipeline is built for Pytorch only. 25 | # === Hardware Settings === 26 | num_workers: 1 # number of rollout worker actors to create for parallel sampling. 27 | # Note: Setting the num_workers to 0 will force rollouts to be done in the trainer actor. 28 | num_gpus: 0 # number of GPUs to allocate to the trainer process. This can also be fractional (e.g., 0.3 GPUs). 29 | 30 | # Environment configuration 31 | env: 32 | num_discrete_action_levels: 10 # number of discrete levels for the saving and mitigation actions 33 | negotiation_on: False # flag to indicate whether negotiation is allowed or not 34 | 35 | # Policy network settings 36 | policy: 37 | regions: 38 | vf_loss_coeff: 0.1 # loss coefficient schedule for the value function loss 39 | entropy_coeff_schedule: # loss coefficient schedule for the entropy loss 40 | # piecewise linear, specified as (timestep, value) 41 | - [0, 0.5] 42 | - [1000000, 0.1] 43 | - [5000000, 0.05] 44 | clip_grad_norm: True # flag indicating whether to clip the gradient norm or not 45 | max_grad_norm: 0.5 # when clip_grad_norm is True, the clip level 46 | gamma: 0.92 # discount factor 47 | lr: 0.0005 # learning rate 48 | model: 49 | custom_model: torch_linear 50 | custom_model_config: 51 | fc_dims: [16, 16] 52 | -------------------------------------------------------------------------------- /scripts/rice_warpdrive.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | # Checkpoint saving setting 9 | saving: 10 | metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics 11 | model_params_save_freq: 1000 # how often (in iterations) to save the model parameters 12 | basedir: "/tmp" # base folder used for saving 13 | name: "rice" # experiment name 14 | tag: "experiments" # experiment tag 15 | 16 | # Trainer settings 17 | trainer: 18 | num_envs: 100 # number of environment replicas 19 | num_episodes: 1000 # number of episodes to run the training for 20 | train_batch_size: 10000 # total batch size used for training per iteration (across all the environments) 21 | 22 | # Environment configuration 23 | env: 24 | num_discrete_action_levels: 10 # number of discrete levels for the saving and mitigation actions 25 | negotiation_on: False # flag to indicate whether negotiation is allowed or not 26 | 27 | # Policy network settings 28 | policy: 29 | regions: 30 | to_train: True # flag indicating whether the model needs to be trained 31 | algorithm: "A2C" # algorithm used to train the policy 32 | vf_loss_coeff: 0.1 # loss coefficient schedule for the value function loss 33 | entropy_coeff: # loss coefficient schedule for the entropy loss 34 | # piecewise linear, specified as (timestep, value) 35 | - [0, 0.5] 36 | - [1000000, 0.1] 37 | - [5000000, 0.05] 38 | clip_grad_norm: True # flag indicating whether to clip the gradient norm or not 39 | max_grad_norm: 0.5 # when clip_grad_norm is True, the clip level 40 | normalize_advantage: False # flag indicating whether to normalize advantage or not 41 | normalize_return: False # flag indicating whether to normalize return or not 42 | gamma: 0.92 # discount factor 43 | lr: 0.0005 # learning rate 44 | model: # policy model settings 45 | type: "fully_connected" # model type 46 | fc_dims: [256, 256] # dimension(s) of the fully connected layers as a list 47 | model_ckpt_filepath: "" # filepath (used to restore a previously saved model) 48 | -------------------------------------------------------------------------------- /scripts/rice_rllib.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | # Checkpoint saving setting 9 | saving: 10 | metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics 11 | model_params_save_freq: 1000 # how often (in iterations) to save the model parameters 12 | basedir: "/tmp" # base folder used for saving 13 | name: "rice" # experiment name 14 | tag: "experiments" # experiment tag 15 | 16 | # Trainer settings 17 | trainer: 18 | num_envs: 20 # number of environment replicas 19 | rollout_fragment_length: 100 # divide episodes into fragments of this many steps each during rollouts. 20 | train_batch_size: 2000 # total batch size used for training per iteration (across all the environments) 21 | num_episodes: 100 # number of episodes to run the training for 22 | framework: torch # framework setting. 23 | # Note: RLlib supports TF as well, but our end-to-end pipeline is built for Pytorch only. 
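    # Sanity check on the sizes above: train_batch_size (2000) = num_envs (20)
    # x rollout_fragment_length (100), so each training iteration consumes one
    # 100-step fragment from every environment replica.
    # The entropy_coeff_schedule under "policy" below lists (timestep, value)
    # breakpoints that are interpolated piecewise-linearly, i.e. the entropy
    # bonus anneals from 0.5 at step 0 down to 0.05 by step 5,000,000.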
24 | # === Hardware Settings === 25 | num_workers: 4 # number of rollout worker actors to create for parallel sampling. 26 | # Note: Setting the num_workers to 0 will force rollouts to be done in the trainer actor. 27 | num_gpus: 0 # number of GPUs to allocate to the trainer process. This can also be fractional (e.g., 0.3 GPUs). 28 | 29 | # Environment configuration 30 | env: 31 | num_discrete_action_levels: 10 # number of discrete levels for the saving and mitigation actions 32 | negotiation_on: False # flag to indicate whether negotiation is allowed or not 33 | 34 | # Policy network settings 35 | policy: 36 | regions: 37 | vf_loss_coeff: 0.1 # loss coefficient schedule for the value function loss 38 | entropy_coeff_schedule: # loss coefficient schedule for the entropy loss 39 | # piecewise linear, specified as (timestep, value) 40 | - [0, 0.5] 41 | - [1000000, 0.1] 42 | - [5000000, 0.05] 43 | clip_grad_norm: True # flag indicating whether to clip the gradient norm or not 44 | max_grad_norm: 0.5 # when clip_grad_norm is True, the clip level 45 | gamma: 0.92 # discount factor 46 | lr: 0.0005 # learning rate 47 | model: 48 | custom_model: torch_linear 49 | custom_model_config: 50 | fc_dims: [256, 256] 51 | -------------------------------------------------------------------------------- /region_yamls/default.yml: -------------------------------------------------------------------------------- 1 | _DICE_CONSTANT: 2 | xt_0: 2015 # starting year of the whole model 3 | xDelta: 5 # the time interval (year) 4 | xN: 20 # total time steps 5 | 6 | # Climate diffusion parameters 7 | xPhi_T: [[0.8718, 0.0088], [0.025, 0.975]] 8 | xB_T: [0.1005, 0] 9 | # xB_T: [0.03, 0] 10 | 11 | # Carbon cycle diffusion parameters (the zeta matrix in the paper) 12 | xPhi_M: [[0.88, 0.196, 0], [0.12, 0.797, 0.001465], [0, 0.007, 0.99853488]] 13 | # xB_M: [0.2727272727272727, 0, 0] # 12/44 14 | xB_M: [1.36388, 0, 0] # 12/44 15 | xeta: 3.6813 #?? I don't find where it's used 16 | 17 | xM_AT_1750: 588 # atmospheric mass of carbon in the year of 1750 18 | xf_0: 0.5 # in Eq 3 param to effect of greenhouse gases other than carbon dioxide 19 | xf_1: 1 # in Eq 3 param to effect of greenhouse gases other than carbon dioxide 20 | xt_f: 20 # in Eq 3 time step param to effect of greenhouse gases other than carbon dioxide 21 | xE_L0: 2.6 # 2.6 # in Eq 4 param to the emissions due to land use changes 22 | xdelta_EL: 0.001 # 0.115 # 0.115 # in Eq 4 param to the emissions due to land use changes 23 | 24 | xM_AT_0: 851 # in CAP the atmospheric mass of carbon in the year t 25 | xM_UP_0: 460 # in CAP the atmospheric upper bound of mass of carbon in the year t 26 | xM_LO_0: 1740 # in CAP the atmospheric lower bound of mass of carbon in the year t 27 | xe_0: 35.85 # in EI define the initial simga_0: e0/(q0(1-mu0)) 28 | xq_0: 105.5 # in EI define the initial simga_0: e0/(q0(1-mu0)) 29 | xmu_0: 0.03 # in EI define the initial simga_0: e0/(q0(1-mu0)) 30 | 31 | # From Python implementation PyDICE 32 | xF_2x: 3.6813 # 3.6813 # Forcing that doubles equilibrium carbon. 33 | xT_2x: 3.1 # 3.1 # Equilibrium temperature increase at double carbon eq. 
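# Consistency check on the emission-intensity default: per the xe_0/xq_0/xmu_0
# comments above, sigma_0 = e_0 / (q_0 * (1 - mu_0)) = 35.85 / (105.5 * 0.97)
# = 35.85 / 102.335 = 0.3503 (rounded), which matches the xsigma_0 default in
# _RICE_CONSTANT below.
# (Assumed, following the standard DICE formulation: xE_L0 and xdelta_EL give
# geometrically decaying land-use emissions, E_land(t) = xE_L0 * (1 - xdelta_EL)^t.)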
34 | 35 | _RICE_CONSTANT: 36 | xgamma: 0.3 # in CAP Eq 5 the capital elasticty 37 | xtheta_2: 2.6 # in CAP Eq 6 38 | xa_1: 0 39 | xa_2: 0.00236 # in CAP Eq 6 40 | xa_3: 2 # in CAP Eq 6 41 | xdelta_K: 0.1 # in CAP Eq 9 param discribe the depreciate of the capital 42 | xalpha: 1.45 # Utility function param 43 | xrho: 0.015 # discount factor of the utility 44 | xL_0: 7403 # in POP population at the staring point 45 | xL_a: 11500 # in POP the expected population at convergence 46 | xl_g: 0.134 # in POP control the rate to converge 47 | xA_0: 5.115 # in TFP technology at starting point 48 | xg_A: 0.076 # in TFP control the rate of increasing of tech larger->faster 49 | xdelta_A: 0.005 # in TFP control the rate of increasing of tech smaller->faster 50 | xsigma_0: 0.3503 # e0/(q0(1-mu0)) in EI emission intensity at the starting point 51 | xg_sigma: 0.0025 # 0.0152 # 0.0025 in EI control the rate of mitigation larger->reduce more emission 52 | xdelta_sigma: 0.1 # 0.01 in EI control the rate of mitigation larger->reduce less emission 53 | xp_b: 550 # 550 # in Eq 2 (estimate of the cost of mitigation) represents the price of a backstop technology that can remove carbon dioxide from the atmosphere 54 | xdelta_pb: 0.001 # 0.025 # in Eq 2 control the how the cost of mitigation change through time larger->cost less as time goes by 55 | xscale_1: 0.030245527 # in Eq 29 Nordhaus scaled cost function param 56 | xscale_2: 10993.704 # in Eq 29 Nordhaus scaled cost function param 57 | 58 | xT_AT_0: 0.85 # in CAP a part of damage function initial condition 59 | xT_LO_0: 0.0068 # in CAP a part of damage function initial condition 60 | xK_0: 223 # in CAP initial condition for capital -------------------------------------------------------------------------------- /scripts/create_submission_zip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | """ 9 | Script to create the zipped submission file from the results directory 10 | """ 11 | import os 12 | import shutil 13 | import sys 14 | import argparse 15 | from evaluate_submission import validate_dir 16 | import time 17 | from fixed_paths import PUBLIC_REPO_DIR 18 | 19 | sys.path.append(PUBLIC_REPO_DIR) 20 | 21 | 22 | def get_results_dir(): 23 | """ 24 | Obtain the 'results' directory from the system arguments. 25 | """ 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | "--results_dir", 29 | "-r", 30 | type=str, 31 | default=".", 32 | help="the directory where all the submission files are saved. Can also be " 33 | "the zipped file containing all the submission files.", 34 | ) 35 | args = parser.parse_args() 36 | 37 | if "results_dir" not in args: 38 | raise ValueError( 39 | "Please provide a results directory to evaluate with the argument -r" 40 | ) 41 | if not os.path.exists(args.results_dir): 42 | raise ValueError( 43 | "The results directory is missing. Please make sure the correct path " 44 | "is specified!" 
45 | ) 46 | try: 47 | results_dir = args.results_dir 48 | 49 | # Also handle a zipped file 50 | if results_dir.endswith(".zip"): 51 | unzipped_results_dir = os.path.join("/tmp", str(time.time())) 52 | shutil.unpack_archive(results_dir, unzipped_results_dir) 53 | results_dir = unzipped_results_dir 54 | return results_dir, parser 55 | except Exception as err: 56 | raise ValueError("Cannot obtain the results directory") from err 57 | 58 | 59 | def prepare_submission(results_dir=None): 60 | """ 61 | # Validate all the submission files and compress into a .zip. 62 | Note: This method is also invoked in the trainer script itself! 63 | So if you ran the training script, you may not need to re-run this. 64 | Args results_dir: the directory where all the training files were saved. 65 | """ 66 | assert results_dir is not None 67 | submission_filename = results_dir.split("/")[-1] 68 | submission_file = os.path.join(PUBLIC_REPO_DIR, "Submissions", submission_filename) 69 | 70 | validate_dir(results_dir) 71 | 72 | # Only copy the latest policy model file for submission 73 | results_dir_copy = os.path.join("/tmp", "_copies_", submission_filename) 74 | shutil.copytree(results_dir, results_dir_copy) 75 | 76 | policy_models = [ 77 | os.path.join(results_dir, file) 78 | for file in os.listdir(results_dir) 79 | if file.endswith(".state_dict") 80 | ] 81 | sorted_policy_models = sorted(policy_models, key=os.path.getmtime) 82 | # Delete all but the last policy model file 83 | for policy_model in sorted_policy_models[:-1]: 84 | os.remove(os.path.join(results_dir_copy, policy_model.split("/")[-1])) 85 | 86 | shutil.make_archive(submission_file, "zip", results_dir_copy) 87 | print("NOTE: The submission file is created at:", submission_file + ".zip") 88 | shutil.rmtree(results_dir_copy) 89 | 90 | 91 | if __name__ == "__main__": 92 | prepare_submission(results_dir=get_results_dir()[0]) 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # See http://help.github.com/ignore-files/ for more about ignoring files. 
2 | 3 | # compiled output 4 | /dist 5 | /tmp 6 | /out-tsc 7 | 8 | data 9 | 10 | # Runtime data 11 | pids 12 | *.pid 13 | *.seed 14 | *.pid.lock 15 | 16 | # Directory for instrumented libs generated by jscoverage/JSCover 17 | lib-cov 18 | 19 | # Coverage directory used by tools like istanbul 20 | coverage 21 | 22 | # nyc test coverage 23 | .nyc_output 24 | 25 | # Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) 26 | .grunt 27 | 28 | # Bower dependency directory (https://bower.io/) 29 | bower_components 30 | 31 | # node-waf configuration 32 | .lock-wscript 33 | 34 | # IDEs and editors 35 | .idea 36 | .project 37 | .classpath 38 | .c9/ 39 | *.launch 40 | .settings/ 41 | *.sublime-workspace 42 | 43 | # IDE - VSCode 44 | .vscode/* 45 | !.vscode/settings.json 46 | !.vscode/tasks.json 47 | !.vscode/launch.json 48 | !.vscode/extensions.json 49 | 50 | # misc 51 | .sass-cache 52 | connect.lock 53 | typings 54 | 55 | # Logs 56 | logs 57 | *.log 58 | npm-debug.log* 59 | yarn-debug.log* 60 | yarn-error.log* 61 | 62 | 63 | # Dependency directories 64 | node_modules/ 65 | jspm_packages/ 66 | 67 | # Optional npm cache directory 68 | .npm 69 | 70 | # Optional eslint cache 71 | .eslintcache 72 | 73 | # Optional REPL history 74 | .node_repl_history 75 | 76 | # Output of 'npm pack' 77 | *.tgz 78 | 79 | # Yarn Integrity file 80 | .yarn-integrity 81 | 82 | # dotenv environment variables file 83 | .env 84 | 85 | # next.js build output 86 | .next 87 | 88 | # Lerna 89 | lerna-debug.log 90 | 91 | # System Files 92 | .DS_Store 93 | Thumbs.db 94 | 95 | 96 | # Byte-compiled / optimized / DLL files 97 | __pycache__/ 98 | *.py[cod] 99 | *$py.class 100 | 101 | # C extensions 102 | *.so 103 | 104 | # Distribution / packaging 105 | .Python 106 | build/ 107 | develop-eggs/ 108 | dist/ 109 | downloads/ 110 | eggs/ 111 | .eggs/ 112 | lib/ 113 | lib64/ 114 | parts/ 115 | sdist/ 116 | var/ 117 | wheels/ 118 | pip-wheel-metadata/ 119 | share/python-wheels/ 120 | *.egg-info/ 121 | .installed.cfg 122 | *.egg 123 | MANIFEST 124 | 125 | # PyInstaller 126 | # Usually these files are written by a python script from a template 127 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 128 | *.manifest 129 | *.spec 130 | 131 | # Installer logs 132 | pip-log.txt 133 | pip-delete-this-directory.txt 134 | 135 | # Unit test / coverage reports 136 | htmlcov/ 137 | .tox/ 138 | .nox/ 139 | .coverage 140 | .coverage.* 141 | .cache 142 | nosetests.xml 143 | coverage.xml 144 | *.cover 145 | *.py,cover 146 | .hypothesis/ 147 | .pytest_cache/ 148 | 149 | # Translations 150 | *.mo 151 | *.pot 152 | 153 | # Django stuff: 154 | *.log 155 | local_settings.py 156 | db.sqlite3 157 | db.sqlite3-journal 158 | 159 | # Flask stuff: 160 | instance/ 161 | .webassets-cache 162 | 163 | # Scrapy stuff: 164 | .scrapy 165 | 166 | # Sphinx documentation 167 | docs/_build/ 168 | 169 | # PyBuilder 170 | target/ 171 | 172 | # Jupyter Notebook 173 | .ipynb_checkpoints 174 | 175 | # IPython 176 | profile_default/ 177 | ipython_config.py 178 | 179 | # pyenv 180 | .python-version 181 | 182 | # pipenv 183 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 184 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 185 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 186 | # install all needed dependencies. 187 | #Pipfile.lock 188 | 189 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 190 | __pypackages__/ 191 | 192 | # Celery stuff 193 | celerybeat-schedule 194 | celerybeat.pid 195 | 196 | # SageMath parsed files 197 | *.sage.py 198 | 199 | # Environments 200 | .env 201 | .venv 202 | env/ 203 | venv/ 204 | ENV/ 205 | env.bak/ 206 | venv.bak/ 207 | 208 | # Spyder project settings 209 | .spyderproject 210 | .spyproject 211 | 212 | # Rope project settings 213 | .ropeproject 214 | 215 | # mkdocs documentation 216 | /site 217 | 218 | # mypy 219 | .mypy_cache/ 220 | .dmypy.json 221 | dmypy.json 222 | 223 | # Pyre type checker 224 | .pyre/ 225 | Submissions/*.zip 226 | emissions.csv 227 | Colab_Tutorial-Copy1.ipynb 228 | Draw_graphs.ipynb 229 | 100_rounds.pkl 230 | *.pdf 231 | core 232 | .tmp 233 | Submissions/* 234 | .vscode/* 235 | *.pkl 236 | scripts/climate_economic_min_max_indices.txt 237 | scripts/dev.ipynb 238 | scripts/dev.py 239 | scripts/eval_by_logs.py 240 | -------------------------------------------------------------------------------- /MARL_using_RLlib.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "56cb26e3", 6 | "metadata": {}, 7 | "source": [ 8 | "Copyright (c) 2022, salesforce.com, inc and MILA. \n", 9 | "All rights reserved. \n", 10 | "SPDX-License-Identifier: BSD-3-Clause \n", 11 | "For full license text, see the LICENSE file in the repo root \n", 12 | "or https://opensource.org/licenses/BSD-3-Clause " 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "7635e695", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import logging\n", 23 | "import numpy as np\n", 24 | "import os\n", 25 | "import shutil\n", 26 | "import subprocess\n", 27 | "import sys\n", 28 | "import time\n", 29 | "import yaml\n", 30 | "\n", 31 | "from scripts.train_with_rllib import create_trainer, fetch_episode_states, load_model_checkpoints" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "a1e40c9c", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Set logger level e.g., DEBUG, INFO, WARNING, ERROR.\n", 42 | "logging.getLogger().setLevel(logging.ERROR)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "id": "4f636cf3", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "# Needed to perform this install when the system threw the lib.so file missing error\n", 53 | "# ! apt-get install libglib2.0-0 --yes" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "6b952acb", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "print(\"Training with RLlib...\")\n", 64 | "# Read the run configurations specific to the environment.\n", 65 | "# Note: The run config yaml(s) can be edited at warp_drive/training/run_configs\n", 66 | "# -----------------------------------------------------------------------------\n", 67 | "config_path = os.path.join(\"scripts\", \"rice_rllib.yaml\")\n", 68 | "if not os.path.exists(config_path):\n", 69 | " raise ValueError(\n", 70 | " \"The run configuration is missing. 
Please make sure the correct path\"\n", 71 | " \"is specified.\"\n", 72 | " )\n", 73 | "\n", 74 | "with open(config_path, \"r\", encoding=\"utf8\") as fp:\n", 75 | " run_config = yaml.safe_load(fp)\n", 76 | "\n", 77 | "# Create trainer\n", 78 | "# --------------\n", 79 | "trainer, save_dir = create_trainer(run_config)\n", 80 | "\n", 81 | "# Copy the source files into the results directory\n", 82 | "# ------------------------------------------------\n", 83 | "os.makedirs(save_dir)\n", 84 | "for file in [\n", 85 | " \"rice.py\",\n", 86 | "]:\n", 87 | " shutil.copyfile(\n", 88 | " os.path.join(file),\n", 89 | " os.path.join(save_dir, file),\n", 90 | " )\n", 91 | "# Add an identifier file\n", 92 | "with open(os.path.join(save_dir, \".rllib\"), \"x\", encoding=\"utf-8\") as fp:\n", 93 | " pass\n", 94 | "fp.close()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "id": "bf5ea093", 100 | "metadata": {}, 101 | "source": [ 102 | "### Invoke training" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "10564c91", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "NUM_ITERS = 5\n", 113 | "for iter in range(NUM_ITERS):\n", 114 | " result = trainer.train()\n", 115 | "print(result)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "d31b0cf0", 121 | "metadata": {}, 122 | "source": [ 123 | "### Fetch episode states" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "23cb91ee", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "outputs = fetch_episode_states(trainer, [\"T_i\", \"carbon_mass_i\", \"capital_i\"])" 134 | ] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "Python 3 (ipykernel)", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | "mimetype": "text/x-python", 150 | "name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.8.10" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 5 158 | } 159 | -------------------------------------------------------------------------------- /rice_cuda.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | """ 9 | CUDA version of the Regional Integrated model of Climate and the Economy (RICE). 
10 | This subclasses the python version of the model and also the CUDAEnvironmentContext 11 | for running with WarpDrive (https://github.com/salesforce/warp-drive) 12 | """ 13 | 14 | import os 15 | import sys 16 | import numpy as np 17 | from warp_drive.utils.constants import Constants 18 | from warp_drive.utils.data_feed import DataFeed 19 | from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext 20 | 21 | _PUBLIC_REPO_DIR = os.path.dirname(os.path.abspath(__file__)) 22 | sys.path = [_PUBLIC_REPO_DIR] + sys.path 23 | 24 | from rice import Rice 25 | 26 | _OBSERVATIONS = Constants.OBSERVATIONS 27 | _ACTIONS = Constants.ACTIONS 28 | _REWARDS = Constants.REWARDS 29 | 30 | 31 | class RiceCuda(Rice, CUDAEnvironmentContext): 32 | """ 33 | Rice env class that invokes the CUDA step function. 34 | """ 35 | 36 | name = "Rice" 37 | 38 | def get_data_dictionary(self): 39 | """ 40 | Create a dictionary of data to push to the device. 41 | """ 42 | data_feed = DataFeed() 43 | 44 | # Add constants 45 | for key, value in sorted(self.dice_constant.items()): 46 | data_feed.add_data(name=key, data=value) 47 | 48 | for key, value in sorted(self.rice_constant.items()): 49 | data_feed.add_data(name=key, data=value) 50 | 51 | # Add all the global states at timestep 0. 52 | timestep = 0 53 | for key in sorted(self.global_state.keys()): 54 | data_feed.add_data( 55 | name=key, 56 | data=self.global_state[key]["value"][timestep], 57 | save_copy_and_apply_at_reset=True, 58 | ) 59 | 60 | for key in sorted(self.global_state.keys()): 61 | data_feed.add_data( 62 | name=key + "_norm", 63 | data=self.global_state[key]["norm"], 64 | ) 65 | 66 | # Env config parameters 67 | data_feed.add_data( 68 | name="aux_ms", 69 | data=np.zeros(self.num_regions, dtype=np.float32), 70 | save_copy_and_apply_at_reset=True, 71 | ) 72 | 73 | # Env config parameters 74 | data_feed.add_data( 75 | name="num_discrete_action_levels", 76 | data=self.num_discrete_action_levels, 77 | ) 78 | 79 | data_feed.add_data( 80 | name="balance_interest_rate", 81 | data=self.balance_interest_rate, 82 | ) 83 | 84 | data_feed.add_data(name="negotiation_on", data=self.negotiation_on) 85 | 86 | # Armington agg. parameters 87 | data_feed.add_data_list( 88 | [ 89 | ("sub_rate", self.sub_rate), 90 | ("dom_pref", self.dom_pref), 91 | ("for_pref", self.for_pref), 92 | ] 93 | ) 94 | 95 | # Year parameters 96 | data_feed.add_data_list( 97 | [("current_year", self.current_year, True), ("end_year", self.end_year)] 98 | ) 99 | 100 | return data_feed 101 | 102 | @staticmethod 103 | def get_tensor_dictionary(): 104 | """ 105 | Create a dictionary of pytorch-accessible tensors to push to the device. 
106 | """ 107 | tensor_dict = DataFeed() 108 | return tensor_dict 109 | 110 | def step(self): 111 | constants_keys = list(sorted(self.dice_constant.keys())) + list( 112 | sorted(self.rice_constant.keys()) 113 | ) 114 | args = ( 115 | constants_keys 116 | + list(sorted(self.global_state.keys())) 117 | + [key + "_norm" for key in list(sorted(self.global_state.keys()))] 118 | + [ 119 | "num_discrete_action_levels", 120 | "balance_interest_rate", 121 | "negotiation_on", 122 | "aux_ms", 123 | "sub_rate", 124 | "dom_pref", 125 | "for_pref", 126 | "current_year", 127 | "end_year", 128 | _OBSERVATIONS + "_features", 129 | _OBSERVATIONS + "_action_mask", 130 | _ACTIONS, 131 | _REWARDS, 132 | "_done_", 133 | "_timestep_", 134 | ("n_agents", "meta"), 135 | ("episode_length", "meta"), 136 | ] 137 | ) 138 | 139 | self.cuda_step( 140 | *self.cuda_step_function_feed(args), 141 | block=self.cuda_function_manager.block, 142 | grid=self.cuda_function_manager.grid, 143 | ) 144 | -------------------------------------------------------------------------------- /scripts/torch_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | """ 9 | Custom Pytorch policy models to use with RLlib. 10 | """ 11 | 12 | # API reference: 13 | # https://docs.ray.io/en/latest/rllib/rllib-models.html#custom-pytorch-models 14 | 15 | import numpy as np 16 | from gym.spaces import Box, Dict 17 | from ray.rllib.models import ModelCatalog 18 | from ray.rllib.models.modelv2 import restore_original_dimensions 19 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 20 | from ray.rllib.utils import try_import_torch 21 | from ray.rllib.utils.annotations import override 22 | 23 | torch, nn = try_import_torch() 24 | 25 | _ACTION_MASK = "action_mask" 26 | 27 | 28 | class TorchLinear(TorchModelV2, nn.Module): 29 | """ 30 | Fully-connected Pytorch policy model. 31 | """ 32 | 33 | custom_name = "torch_linear" 34 | 35 | def __init__( 36 | self, obs_space, action_space, num_outputs, model_config, name, fc_dims=None 37 | ): 38 | super().__init__(obs_space, action_space, num_outputs, model_config, name) 39 | nn.Module.__init__(self) 40 | 41 | if fc_dims is None: 42 | fc_dims = [256, 256] 43 | 44 | # Check Observation spaces 45 | 46 | self.observation_space = obs_space.original_space 47 | 48 | if not isinstance(self.observation_space, Dict): 49 | if isinstance(self.observation_space, Box): 50 | raise TypeError( 51 | "({name}) Observation space should be a gym Dict. " 52 | "Is a Box of shape {self.observation_space.shape}" 53 | ) 54 | raise TypeError( 55 | f"({name}) Observation space should be a gym Dict. " 56 | "Is {type(self.observation_space))} instead." 
57 | ) 58 | 59 | flattened_obs_size = self.get_flattened_obs_size() 60 | 61 | # Model only outputs policy logits, 62 | # values are accessed via the self.value_function 63 | self.values = None 64 | 65 | num_fc_layers = len(fc_dims) 66 | 67 | input_dims = [flattened_obs_size] + fc_dims[:-1] 68 | output_dims = fc_dims 69 | 70 | self.fc_dict = nn.ModuleDict() 71 | for fc_layer in range(num_fc_layers): 72 | self.fc_dict[str(fc_layer)] = nn.Sequential( 73 | nn.Linear(input_dims[fc_layer], output_dims[fc_layer]), 74 | nn.ReLU(), 75 | ) 76 | 77 | # policy network (list of heads) 78 | policy_heads = [None for _ in range(len(action_space))] 79 | self.output_dims = [] # Network output dimension(s) 80 | 81 | for idx, act_space in enumerate(action_space): 82 | output_dim = act_space.n 83 | self.output_dims += [output_dim] 84 | policy_heads[idx] = nn.Linear(fc_dims[-1], output_dim) 85 | self.policy_head = nn.ModuleList(policy_heads) 86 | 87 | # value-function network head 88 | self.vf_head = nn.Linear(fc_dims[-1], 1) 89 | 90 | # used for action masking 91 | self.action_mask = None 92 | 93 | def get_flattened_obs_size(self): 94 | """Get the total size of the observation after flattening.""" 95 | if isinstance(self.observation_space, Box): 96 | obs_size = np.prod(self.observation_space.shape) 97 | elif isinstance(self.observation_space, Dict): 98 | obs_size = 0 99 | for key in sorted(self.observation_space): 100 | if key == _ACTION_MASK: 101 | pass 102 | else: 103 | obs_size += np.prod(self.observation_space[key].shape) 104 | else: 105 | raise NotImplementedError("Observation space must be of Box or Dict type") 106 | return int(obs_size) 107 | 108 | def get_flattened_obs(self, obs): 109 | """Get the flattened observation (ignore the action masks).""" 110 | if isinstance(self.observation_space, Box): 111 | return self.reshape_and_flatten(obs) 112 | if isinstance(self.observation_space, Dict): 113 | flattened_obs_dict = {} 114 | for key in sorted(self.observation_space): 115 | assert key in obs 116 | if key == _ACTION_MASK: 117 | self.action_mask = self.reshape_and_flatten_obs(obs[key]) 118 | else: 119 | flattened_obs_dict[key] = self.reshape_and_flatten_obs(obs[key]) 120 | flattened_obs = torch.cat(list(flattened_obs_dict.values()), dim=-1) 121 | return flattened_obs 122 | raise NotImplementedError("Observation space must be of Box or Dict type") 123 | 124 | @staticmethod 125 | def reshape_and_flatten_obs(obs): 126 | """Flatten observation.""" 127 | assert len(obs.shape) >= 2 128 | batch_dim = obs.shape[0] 129 | return obs.reshape(batch_dim, -1) 130 | 131 | @override(TorchModelV2) 132 | def value_function(self): 133 | """Returns the estimated value function.""" 134 | return self.values.reshape(-1) 135 | 136 | @staticmethod 137 | def apply_logit_mask(logits, mask): 138 | """ 139 | Mask values of 1 are valid actions. 140 | Add huge negative values to logits with 0 mask values. 141 | """ 142 | logit_mask = torch.ones_like(logits) * -10000000 143 | logit_mask = logit_mask * (1 - mask) 144 | return logits + logit_mask 145 | 146 | @override(TorchModelV2) 147 | def forward(self, input_dict, state, seq_lens): 148 | """You should implement forward() of forward_rnn() in your subclass.""" 149 | if isinstance(seq_lens, np.ndarray): 150 | seq_lens = torch.Tensor(seq_lens).int() 151 | 152 | # Note: restoring original obs 153 | # as RLlib does not seem to be doing it automatically! 
154 | original_obs = restore_original_dimensions( 155 | input_dict["obs"], self.obs_space.original_space, "torch" 156 | ) 157 | 158 | obs = self.get_flattened_obs(original_obs) 159 | 160 | # Feed through the FC layers 161 | for layer in range(len(self.fc_dict)): 162 | output = self.fc_dict[str(layer)](obs) 163 | obs = output 164 | logits = output 165 | 166 | # Compute the action probabilities and the value function estimate 167 | # Apply action mask to the logits as well. 168 | action_masks = [None for _ in range(len(self.output_dims))] 169 | if self.action_mask is not None: 170 | start = 0 171 | for idx, dim in enumerate(self.output_dims): 172 | action_masks[idx] = self.action_mask[..., start : start + dim] 173 | start = start + dim 174 | action_logits = [ 175 | self.apply_logit_mask(ph(logits), action_masks[idx]) 176 | for idx, ph in enumerate(self.policy_head) 177 | ] 178 | self.values = self.vf_head(logits)[..., 0] 179 | 180 | concatenated_action_logits = torch.cat(action_logits, dim=-1) 181 | return torch.reshape(concatenated_action_logits, [-1, self.num_outputs]), state 182 | 183 | 184 | ModelCatalog.register_custom_model("torch_linear", TorchLinear) 185 | -------------------------------------------------------------------------------- /rice_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | """ 9 | Helper functions for the rice simulation 10 | """ 11 | import os 12 | 13 | import numpy as np 14 | import yaml 15 | 16 | _SMALL_NUM = 1e-0 # small number added to handle consumption blow-up 17 | 18 | 19 | # Load calibration data from yaml files 20 | def read_yaml_data(yaml_file): 21 | """Helper function to read yaml configuration data.""" 22 | with open(yaml_file, "r", encoding="utf-8") as file_ptr: 23 | file_data = file_ptr.read() 24 | file_ptr.close() 25 | data = yaml.load(file_data, Loader=yaml.FullLoader) 26 | return data 27 | 28 | 29 | def set_rice_params(yamls_folder=None): 30 | """Helper function to read yaml data and set environment configs.""" 31 | assert yamls_folder is not None 32 | dice_params = read_yaml_data(os.path.join(yamls_folder, "default.yml")) 33 | file_list = sorted(os.listdir(yamls_folder)) # 34 | yaml_files = [] 35 | for file in file_list: 36 | if file[-4:] == ".yml" and file != "default.yml": 37 | yaml_files.append(file) 38 | 39 | rice_params = [] 40 | for file in yaml_files: 41 | rice_params.append(read_yaml_data(os.path.join(yamls_folder, file))) 42 | 43 | # Overwrite rice params 44 | num_regions = len(rice_params) 45 | for k in dice_params["_RICE_CONSTANT"].keys(): 46 | dice_params["_RICE_CONSTANT"][k] = [ 47 | dice_params["_RICE_CONSTANT"][k] 48 | ] * num_regions 49 | for idx, param in enumerate(rice_params): 50 | for k in param["_RICE_CONSTANT"].keys(): 51 | dice_params["_RICE_CONSTANT"][k][idx] = param["_RICE_CONSTANT"][k] 52 | 53 | return dice_params, num_regions 54 | 55 | 56 | # RICE dynamics 57 | def get_mitigation_cost(p_b, theta_2, delta_pb, timestep, intensity): 58 | """Obtain the cost for mitigation.""" 59 | return p_b / (1000 * theta_2) * pow(1 - delta_pb, timestep - 1) * intensity 60 | 61 | 62 | def get_exogenous_emissions(f_0, f_1, t_f, timestep): 63 | """Obtain the amount of exogeneous emissions.""" 64 | return f_0 + min(f_1 - f_0, (f_1 - f_0) / t_f * 
(timestep - 1)) 65 | 66 | 67 | def get_land_emissions(e_l0, delta_el, timestep, num_regions): 68 | """Obtain the amount of land emissions.""" 69 | return e_l0 * pow(1 - delta_el, timestep - 1)/num_regions 70 | 71 | 72 | def get_production(production_factor, capital, labor, gamma): 73 | """Obtain the amount of goods produced.""" 74 | return production_factor * pow(capital, gamma) * pow(labor / 1000, 1 - gamma) 75 | 76 | 77 | def get_damages(t_at, a_1, a_2, a_3): 78 | """Obtain damages.""" 79 | return 1 / (1 + a_1 * t_at + a_2 * pow(t_at, a_3)) 80 | 81 | 82 | def get_abatement_cost(mitigation_rate, mitigation_cost, theta_2): 83 | """Compute the abatement cost.""" 84 | return mitigation_cost * pow(mitigation_rate, theta_2) 85 | 86 | 87 | def get_gross_output(damages, abatement_cost, production): 88 | """Compute the gross production output, taking into account 89 | damages and abatement cost.""" 90 | return damages * (1 - abatement_cost) * production 91 | 92 | 93 | def get_investment(savings, gross_output): 94 | """Obtain the investment cost.""" 95 | return savings * gross_output 96 | 97 | 98 | def get_consumption(gross_output, investment, exports): 99 | """Obtain the consumption cost.""" 100 | total_exports = np.sum(exports) 101 | assert gross_output - investment - total_exports > -1e-5, "consumption cannot be negative!" 102 | return max(0.0, gross_output - investment - total_exports) 103 | 104 | 105 | def get_max_potential_exports(x_max, gross_output, investment): 106 | """Determine the maximum potential exports.""" 107 | if x_max * gross_output <= gross_output - investment: 108 | return x_max * gross_output 109 | return gross_output - investment 110 | 111 | 112 | def get_capital_depreciation(x_delta_k, x_delta): 113 | """Compute the global capital depreciation.""" 114 | return pow(1 - x_delta_k, x_delta) 115 | 116 | 117 | def get_global_temperature( 118 | phi_t, temperature, b_t, f_2x, m_at, m_at_1750, exogenous_emissions 119 | ): 120 | """Get the temperature levels.""" 121 | return np.dot(phi_t, temperature) + np.dot( 122 | b_t, f_2x * np.log(m_at / m_at_1750) / np.log(2) + exogenous_emissions 123 | ) 124 | 125 | 126 | def get_aux_m(intensity, mitigation_rate, production, land_emissions): 127 | """Auxiliary variable to denote carbon mass levels.""" 128 | return intensity * (1 - mitigation_rate) * production + land_emissions 129 | 130 | 131 | def get_global_carbon_mass(phi_m, carbon_mass, b_m, aux_m): 132 | """Get the carbon mass level.""" 133 | return np.dot(phi_m, carbon_mass) + np.dot(b_m, aux_m) 134 | 135 | 136 | def get_capital(capital_depreciation, capital, delta, investment): 137 | """Evaluate capital.""" 138 | return capital_depreciation * capital + delta * investment 139 | 140 | 141 | def get_labor(labor, l_a, l_g): 142 | """Compute total labor.""" 143 | return labor * pow((1 + l_a) / (1 + labor), l_g) 144 | 145 | 146 | def get_production_factor(production_factor, g_a, delta_a, delta, timestep): 147 | """Compute the production factor.""" 148 | return production_factor * ( 149 | np.exp(0.0033) + g_a * np.exp(-delta_a * delta * (timestep - 1)) 150 | ) 151 | 152 | 153 | def get_carbon_intensity(intensity, g_sigma, delta_sigma, delta, timestep): 154 | """Determine the carbon emission intensity.""" 155 | return intensity * np.exp( 156 | -g_sigma * pow(1 - delta_sigma, delta * (timestep - 1)) * delta 157 | ) 158 | 159 | 160 | def get_utility(labor, consumption, alpha): 161 | """Obtain the utility.""" 162 | return ( 163 | (labor / 1000.0) 164 | * (pow(consumption / (labor / 1000.0) + 
_SMALL_NUM, 1 - alpha) - 1) 165 | / (1 - alpha) 166 | ) 167 | 168 | 169 | def get_social_welfare(utility, rho, delta, timestep): 170 | """Compute social welfare""" 171 | return utility / pow(1 + rho, delta * timestep) 172 | 173 | 174 | def get_armington_agg( 175 | c_dom, 176 | c_for, # np.array 177 | sub_rate=0.5, # in (0,1) 178 | dom_pref=0.5, # in [0,1] 179 | for_pref=None, # np.array 180 | ): 181 | """ 182 | Armington aggregate from Lessmann,2009. 183 | Consumption goods from different regions act as imperfect substitutes. 184 | As such, consumption of domestic and foreign goods are scaled according to 185 | relative preferences, as well as a substitution rate, which are modeled 186 | by a CES functional form. 187 | Inputs : 188 | `C_dom` : A scalar representing domestic consumption. The value of 189 | C_dom is what is left over from initial production after 190 | investment and exports are deducted. 191 | `C_for` : An array reprensenting foreign consumption. Each element 192 | is the consumption imported from a given country. 193 | `sub_rate` : A substitution parameter in (0,1). The elasticity of 194 | substitution is 1 / (1 - sub_rate). 195 | `dom_pref` : A scalar in [0,1] representing the relative preference for 196 | domestic consumption over foreign consumption. 197 | `for_pref` : An array of the same size as `C_for`. Each element is the 198 | relative preference for foreign goods from that country. 199 | """ 200 | 201 | c_dom_pref = dom_pref * (c_dom ** sub_rate) 202 | c_for_pref = np.sum(for_pref * pow(c_for, sub_rate)) 203 | 204 | c_agg = (c_dom_pref + c_for_pref) ** (1 / sub_rate) # CES function 205 | return c_agg 206 | -------------------------------------------------------------------------------- /scripts/train_with_warp_drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | """ 9 | Training script for the rice environment using WarpDrive 10 | www.github.com/salesforce/warp-drive 11 | """ 12 | 13 | import logging 14 | import os 15 | import shutil 16 | import subprocess 17 | import sys 18 | import numpy as np 19 | import yaml 20 | from desired_outputs import desired_outputs 21 | 22 | sys.path.append("./") 23 | from opt_helper import get_mean_std 24 | from fixed_paths import PUBLIC_REPO_DIR 25 | 26 | sys.path.append(PUBLIC_REPO_DIR) 27 | 28 | from scripts.run_unittests import import_class_from_path 29 | 30 | # Set logger level e.g., DEBUG, INFO, WARNING, ERROR. 31 | logging.getLogger().setLevel(logging.ERROR) 32 | 33 | 34 | def perform_other_imports(): 35 | """ 36 | WarpDrive-related imports. 37 | """ 38 | import torch 39 | 40 | num_gpus_available = torch.cuda.device_count() 41 | assert num_gpus_available > 0, "This script needs a GPU to run!" 
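    # The WarpDrive imports below are placed after the GPU check, so the script fails fast with a clear message on CPU-only machines.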
42 | 43 | from warp_drive.env_wrapper import EnvWrapper 44 | from warp_drive.training.trainer import Trainer 45 | from warp_drive.utils.env_registrar import EnvironmentRegistrar 46 | 47 | return torch, EnvWrapper, Trainer, EnvironmentRegistrar 48 | 49 | 50 | try: 51 | other_imports = perform_other_imports() 52 | except ImportError: 53 | print("Installing requirements...") 54 | subprocess.call(["pip", "install", "rl-warp-drive>=1.6.5"]) 55 | 56 | other_imports = perform_other_imports() 57 | 58 | torch, EnvWrapper, Trainer, EnvironmentRegistrar = other_imports 59 | 60 | 61 | def create_trainer(run_config=None, source_dir=None, seed=None): 62 | """ 63 | Create the WarpDrive trainer. 64 | """ 65 | torch.cuda.FloatTensor(8) # add this line for successful cuda_init 66 | 67 | assert run_config is not None 68 | if source_dir is None: 69 | source_dir = PUBLIC_REPO_DIR 70 | if seed is not None: 71 | run_config["trainer"]["seed"] = seed 72 | 73 | # Create a wrapped environment object via the EnvWrapper 74 | # Ensure that use_cuda is set to True (in order to run on the GPU) 75 | 76 | # Register the environment 77 | env_registrar = EnvironmentRegistrar() 78 | 79 | rice_cuda_class = import_class_from_path( 80 | "RiceCuda", os.path.join(source_dir, "rice_cuda.py") 81 | ) 82 | 83 | env_registrar.add_cuda_env_src_path( 84 | rice_cuda_class.name, os.path.join(source_dir, "rice_build.cu") 85 | ) 86 | 87 | env_wrapper = EnvWrapper( 88 | rice_cuda_class(**run_config["env"]), 89 | num_envs=run_config["trainer"]["num_envs"], 90 | use_cuda=True, 91 | env_registrar=env_registrar, 92 | ) 93 | 94 | # Policy mapping to agent ids: agents can share models 95 | # The policy_tag_to_agent_id_map dictionary maps 96 | # policy model names to agent ids. 97 | # ---------------------------------------------------- 98 | policy_tag_to_agent_id_map = { 99 | "regions": list(range(env_wrapper.env.num_agents)), 100 | } 101 | 102 | # Create the Trainer object 103 | # ------------------------- 104 | trainer_obj = Trainer( 105 | env_wrapper=env_wrapper, 106 | config=run_config, 107 | policy_tag_to_agent_id_map=policy_tag_to_agent_id_map, 108 | ) 109 | return trainer_obj, trainer_obj.save_dir 110 | 111 | 112 | def load_model_checkpoints(trainer=None, save_directory=None, ckpt_idx=-1): 113 | """ 114 | Load trained model checkpoints. 115 | """ 116 | assert trainer is not None 117 | assert save_directory is not None 118 | assert os.path.exists(save_directory), ( 119 | "Invalid folder path. " 120 | "Please specify a valid directory to load the checkpoints from." 121 | ) 122 | files = [file for file in os.listdir(save_directory) if file.endswith("state_dict")] 123 | assert len(files) >= len(trainer.policies), "Missing policy checkpoints" 124 | 125 | ckpts_dict = {} 126 | for policy in trainer.policies_to_train: 127 | policy_models = [ 128 | os.path.join(save_directory, file) for file in files if policy in file 129 | ] 130 | # If there are multiple files, then use the ckpt_idx to specify the checkpoint 131 | assert ckpt_idx < len(policy_models) 132 | sorted_policy_models = sorted(policy_models, key=os.path.getmtime) 133 | policy_model_file = sorted_policy_models[ckpt_idx] 134 | logging.info(f"Loaded model checkpoints {policy_model_file}.") 135 | 136 | ckpts_dict.update({policy: policy_model_file}) 137 | trainer.load_model_checkpoint(ckpts_dict) 138 | 139 | 140 | def fetch_episode_states(trainer_obj=None, episode_states=None, env_id=None): 141 | """ 142 | Helper function to rollout the env and fetch env states for an episode. 
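    `episode_states` must be a non-empty list of global state names (see the assertions below); `env_id` selects which of the parallel envs to roll out.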
143 | """ 144 | assert trainer_obj is not None 145 | assert isinstance( 146 | episode_states, list 147 | ), "Please pass the 'episode states' args as a list." 148 | assert len(episode_states) > 0 149 | return trainer_obj.fetch_episode_states(episode_states, env_id) 150 | 151 | 152 | def copy_source_files(trainer): 153 | """ 154 | Copy source files to the saving directory. 155 | """ 156 | for file in [ 157 | "rice.py", 158 | "rice_helpers.py", 159 | "rice_cuda.py", 160 | "rice_step.cu", 161 | "rice_build.cu", 162 | ]: 163 | shutil.copyfile( 164 | os.path.join(PUBLIC_REPO_DIR, file), 165 | os.path.join(trainer.save_dir, file), 166 | ) 167 | 168 | for file in [ 169 | "rice_warpdrive.yaml", 170 | ]: 171 | shutil.copyfile( 172 | os.path.join(PUBLIC_REPO_DIR, "scripts", file), 173 | os.path.join(trainer.save_dir, file), 174 | ) 175 | 176 | # Add an identifier file 177 | with open( 178 | os.path.join(trainer.save_dir, ".warpdrive"), "x", encoding="utf-8" 179 | ) as file_pointer: 180 | pass 181 | file_pointer.close() 182 | 183 | 184 | def trainer( 185 | negotiation_on=0, 186 | num_envs=100, 187 | train_batch_size=1024, 188 | num_episodes=30000, 189 | lr=0.0005, 190 | model_params_save_freq=5000, 191 | desired_outputs=desired_outputs, 192 | output_all_envs=False, 193 | ): 194 | """ 195 | Main function to run the trainer. 196 | """ 197 | # Load the run_config 198 | print("Training with WarpDrive...") 199 | 200 | # Read the run configurations specific to the environment. 201 | # Note: The run config yaml(s) can be edited at warp_drive/training/run_configs 202 | # ----------------------------------------------------------------------------- 203 | config_path = os.path.join(PUBLIC_REPO_DIR, "scripts", "rice_warpdrive.yaml") 204 | if not os.path.exists(config_path): 205 | raise ValueError( 206 | "The run configuration is missing. Please make sure the correct path" 207 | "is specified." 208 | ) 209 | 210 | with open(config_path, "r", encoding="utf8") as fp: 211 | run_configuration = yaml.safe_load(fp) 212 | run_configuration["env"]["negotiation_on"] = negotiation_on 213 | run_configuration["trainer"]["num_envs"] = num_envs 214 | run_configuration["trainer"]["train_batch_size"] = train_batch_size 215 | run_configuration["trainer"]["num_episodes"] = num_episodes 216 | run_configuration["policy"]["regions"]["lr"] = lr 217 | run_configuration["saving"]["model_params_save_freq"] = model_params_save_freq 218 | # run_configuration trainer 219 | # -------------- 220 | trainer_object, _ = create_trainer(run_config=run_configuration) 221 | 222 | # Copy the source files into the results directory 223 | # ------------------------------------------------ 224 | copy_source_files(trainer_object) 225 | 226 | # Perform training! 
227 | # ----------------- 228 | trainer_object.train() 229 | 230 | # Create a (zipped) submission file 231 | # --------------------------------- 232 | subprocess.call( 233 | [ 234 | "python", 235 | os.path.join(PUBLIC_REPO_DIR, "scripts", "create_submission_zip.py"), 236 | "--results_dir", 237 | trainer_object.save_dir, 238 | ] 239 | ) 240 | outputs_ts = [ 241 | fetch_episode_states(trainer_object, desired_outputs, env_id=i) 242 | for i in range(num_envs) 243 | ] 244 | for i in range(len(outputs_ts)): 245 | outputs_ts[i]["global_consumption"] = np.sum( 246 | outputs_ts[i]["consumption_all_regions"], axis=-1 247 | ) 248 | outputs_ts[i]["global_production"] = np.sum( 249 | outputs_ts[i]["gross_output_all_regions"], axis=-1 250 | ) 251 | if not output_all_envs: 252 | outputs_ts, _ = get_mean_std(outputs_ts) 253 | # Shut off the trainer gracefully 254 | # ------------------------------- 255 | trainer_object.graceful_close() 256 | return trainer_object, outputs_ts 257 | 258 | 259 | if __name__ == "__main__": 260 | print("Training with WarpDrive...") 261 | 262 | # Read the run configurations specific to the environment. 263 | # Note: The run config yaml(s) can be edited at warp_drive/training/run_configs 264 | # ----------------------------------------------------------------------------- 265 | config_path = os.path.join(PUBLIC_REPO_DIR, "scripts", "rice_warpdrive.yaml") 266 | if not os.path.exists(config_path): 267 | raise ValueError( 268 | "The run configuration is missing. Please make sure the correct path" 269 | "is specified." 270 | ) 271 | 272 | with open(config_path, "r", encoding="utf8") as fp: 273 | run_configuration = yaml.safe_load(fp) 274 | 275 | # Create trainer 276 | # -------------- 277 | trainer_object, _ = create_trainer(run_config=run_configuration) 278 | 279 | # Copy the source files into the results directory 280 | # ------------------------------------------------ 281 | copy_source_files(trainer_object) 282 | 283 | # Perform training! 
284 | # ----------------- 285 | trainer_object.train() 286 | 287 | # Create a (zipped) submission file 288 | # --------------------------------- 289 | subprocess.call( 290 | [ 291 | "python", 292 | os.path.join(PUBLIC_REPO_DIR, "scripts", "create_submission_zip.py"), 293 | "--results_dir", 294 | trainer_object.save_dir, 295 | ] 296 | ) 297 | 298 | # Shut off the trainer gracefully 299 | # ------------------------------- 300 | trainer_object.graceful_close() 301 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Competition: Fostering Global Cooperation to Mitigate Climate Change 2 | 3 | [![PyTorch 1.9.0](https://img.shields.io/badge/PyTorch-1.9.0-ee4c2c?logo=pytorch&logoColor=white%22)](https://pytorch.org/docs/1.12/) 4 | [![Python 3.7](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/downloads/release/python-3713/) 5 | [![Warp drive 1.7.0](https://img.shields.io/badge/warp_drive-1.7.0-blue.svg)](https://github.com/salesforce/warp-drive/) 6 | [![Ray 1.0.0](https://img.shields.io/badge/ray[rllib]-1.0.0-blue.svg)](https://docs.ray.io/en/latest/index.html) 7 | [![Paper](http://img.shields.io/badge/paper-arxiv.2208.07004-B31B1B.svg)](https://arxiv.org/abs/2208.07004) 8 | [![Code Tutorial](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mila-iqia/climate-cooperation-competition/blob/main/Colab_Tutorial.ipynb) 9 | 10 | (Code Tutorial Notebook on Kaggle with free GPU available) 11 | 12 | 13 | This is the code respository for the competition on modeling global cooperation in the RICE-N Integrated Assessment Model. This competition is co-organized by MILA and Salesforce Research. 14 | 15 | The RICE-N IAM is an agent-based model that incorporates DICE climate-economic dynamics and multi-lateral negotiation protocols between several fictitious nations. 16 | 17 | In this competition, you will design negotiation protocols and contracts between nations. You will use the simulation and agents to evaluate their impact on the climate and the economy. 18 | 19 | We recommend that GPU users use ``warp_drive`` and CPU users use ``rllib``. 20 | 21 | ## Tutorial 22 | - For all information and the leaderboard, see [our official website](https://www.ai4climatecoop.org). 23 | - [How to Get Started](getting_started.ipynb) 24 | - [Code Tutorial Notebook with **free GPU**](https://colab.research.google.com/github/mila-iqia/climate-cooperation-competition/blob/main/Colab_Tutorial.ipynb) 25 | - [Code Kaggle Tutorial Notebook with **free GPU**](https://www.kaggle.com/kernels/fork-version/105300459) 26 | 27 | ## Resources 28 | - For the mathematical background and scientific references, please see [the white paper](https://deliverypdf.ssrn.com/delivery.php?ID=579098091025080122123095015088114126057046084059055038121023114094110112070107123088059057002107022006023123122016086089001013042072002040020075022078097115093071118048047053064022064117095120085074022123010099031092026025094015125094099080071097079070&EXT=pdf&INDEX=TRUE). 29 | - Other free GPU resources: [Baidu Paddle](https://aistudio.baidu.com/), [MegStudio](https://studio.brainpp.com/) 30 | 31 | 32 | ## Installation 33 | 34 | Notice: we recommend using `Linux` or `MacOS`. For Windows users, we recommend to use virtual machine running `Ubuntu 20.04` or [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/install). 
35 | 36 | You can get a copy of the code by cloning the repo using Git: 37 | 38 | ``` 39 | git clone https://github.com/mila-iqia/climate-cooperation-competition 40 | cd climate-cooperation-competition 41 | ``` 42 | 43 | As an alternative, one can also use: 44 | 45 | ``` 46 | git clone https://e.coding.net/ai4climatecoop/ai4climatecoop/climate-cooperation-competition.git 47 | cd climate-cooperation-competition 48 | ``` 49 | 50 | We recommend using a virtual environment (such as provided by ```virtualenv``` or Anaconda). 51 | 52 | You can install the dependencies using pip: 53 | 54 | ``` 55 | pip install -r requirements.txt 56 | ``` 57 | 58 | ## Get Started 59 | Then run the getting started Jupyter notebook, by starting Jupyter: 60 | 61 | ``` 62 | jupyter notebook 63 | ``` 64 | 65 | and then navigating to [getting_started.ipynb](getting_started.ipynb). 66 | 67 | It provides a quick walkthrough for registering for the competition and creating a valid submission. 68 | 69 | 70 | ## Training with reinforcement learning 71 | 72 | RL agents can be trained using the RICE-N simulation using one of these two frameworks: 73 | 74 | 1. [RLlib](https://docs.ray.io/en/latest/rllib/index.html#:~:text=RLlib%20is%20an%20open%2Dsource,large%20variety%20of%20industry%20applications): The pythonic environment can be trained on your local CPU machine using open-source RL framework, RLlib. 75 | 2. [WarpDrive](https://github.com/salesforce/warp-drive): WarpDrive is a GPU-based framework that allows for [over 10x faster training](https://arxiv.org/pdf/2108.13976.pdf) compared to CPU-based training. It requires the simulation to be written out in CUDA C, and we also provide a starter version of the simulation environment written in CUDA C ([rice_step.cu](rice_step.cu)) 76 | 77 | We also provide starter scripts to train the simulation you build with either of the above frameworks. 78 | 79 | Note that we only allow these two options, since our backend submission evaluation process only supports these at the moment. 80 | 81 | 82 | For training with RLlib, `rllib (1.0.0)`, `torch (1.9.0)` and `gym (0.21)` packages are required. 83 | 84 | 85 | 86 | For training with WarpDrive, the `rl-warp-drive (>=1.6.5)` package is needed. 87 | 88 | Note that these requirements are automatically installed (or updated) when you run the corresponding training scripts. 89 | 90 | 91 | ## Docker image (for GPU users) 92 | 93 | We have also provided a sample dockerfile for your reference. It mainly uses a Nvidia PyTorch base image, and installs the `pycuda` package as well. Note: `pycuda` is only required if you would like to train using WarpDrive. 94 | 95 | 96 | ## Docker image (for CPU users) 97 | 98 | Thanks for the contribution from @muxspace. We also have an end-to-end docker environment ready for CPU users. Please refer to [README_CPU.md](README_CPU.md) for more details. 99 | 100 | # Customizing and running the simulation 101 | 102 | See the [Colab_Tutorial.ipynb](Tutorial.ipynb). [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mila-iqia/climate-cooperation-competition/blob/main/Colab_Tutorial.ipynb) for details. 103 | 104 | It provides examples on modifying the code to implement different negotiation protocols. It describes ways of changing the agent observations and action spaces corresponding to the proposed negotiation protocols and implementing the negotiation logic in the provided code. 
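
If you would like to poke at the simulation interactively before opening the notebook, a minimal random-rollout sketch looks roughly like the following (illustrative only, not a file in this repository; it assumes the default `Rice()` constructor and is run from the repository root):

```python
import numpy as np

from rice import Rice  # rice.py lives in the repository root

env = Rice()  # default constructor arguments, as used by the unit tests
obs = env.reset()  # observations are a dict keyed by region id

done = {"__all__": False}
while not done["__all__"]:
    # Sample one random action vector per region from its MultiDiscrete space.
    actions = {
        region_id: np.random.randint(env.action_space[region_id].nvec)
        for region_id in range(env.num_agents)
    }
    obs, rew, done, info = env.step(actions)

print(rew)  # per-region rewards at the final step
```

The actions here are purely random; the training scripts described below replace them with learned policies.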
105 | 106 | The notebook has a walkthrough of how to train RL agents with the simulation code and how to visualize results from the simulation after running it with a set of agents. 107 | 108 | For those who have **limited access to Colab**, please try to use [**free GPUs on Kaggle**](https://www.kaggle.com/kernels/fork-version/105300459). Please note that the Kaggle platform requires mobile phone verification before the GPUs can be accessed. You can find the **settings** to enable the GPU and internet connection on the right-hand side after clicking the link above and logging in. 109 | 110 | # Training RL agents in your simulation 111 | 112 | Once you build your simulation, you can use either of the following scripts to perform training. 113 | 114 | - [train_with_rllib.py](/scripts/train_with_rllib.py): this script performs end-to-end training with RLlib. The experiment run configuration will be read in from [rice_rllib.yaml](/scripts/rice_rllib.yaml), which contains the environment configuration, the logging and saving settings, and the trainer and policy network parameters. The duration of training can be set via the `num_episodes` parameter. We have also provided an initial implementation of a linear PyTorch policy model in [torch_models.py](/scripts/torch_models.py). You can [add other policy models](https://docs.ray.io/en/latest/rllib/rllib-concepts.html) you wish to use into that file. 115 | 116 | USAGE: The training script (with RLlib) is invoked using (from the root directory) 117 | ```commandline 118 | python scripts/train_with_rllib.py 119 | ``` 120 | 121 | - [train_with_warp_drive.py](/scripts/train_with_warp_drive.py): this script performs end-to-end training with WarpDrive. The experiment run configuration will be read in from [rice_warpdrive.yaml](/scripts/rice_warpdrive.yaml). Currently, WarpDrive only supports the Advantage Actor-Critic (A2C) and Proximal Policy Optimization (PPO) algorithms, and the fully-connected policy model. 122 | 123 | USAGE: The training script (with WarpDrive) is invoked using 124 | ```commandline 125 | python scripts/train_with_warp_drive.py 126 | ``` 127 | 128 | As training progresses, some key metrics (such as the mean episode reward) are printed on screen for your reference. At the end of training, a zipped submission file is automatically created and saved. The zipped file essentially comprises the following: 129 | 130 | - An identifier file (`.rllib` or `.warpdrive`) indicating which framework was used for training. 131 | - The environment files - [rice.py](rice.py) and [rice_helpers.py](rice_helpers.py). 132 | - A copy of the yaml configuration file ([rice_rllib.yaml](/scripts/rice_rllib.yaml) or [rice_warpdrive.yaml](/scripts/rice_warpdrive.yaml)) used for training. 133 | - PyTorch policy model(s) (of type ".state_dict") containing the trained weights for the policy network(s). Only the trained policy model for the final timestep will be copied over into the submission zip. If you would like to instead submit the trained policy model at a different timestep, please see the section below on creating your submission file. 134 | - For submissions using WarpDrive, the submission will also contain the CUDA-specific files [rice_step.cu](rice_step.cu) and [rice_cuda.py](rice_cuda.py) that were used for training. 135 | 136 | 137 | # Contributing 138 | 139 | We are always looking for contributors from various domains to help us make this simulation more realistic.
140 | 141 | If there are bugs or corner cases, please open a PR detailing the issue and consider submitting to Track 3! 142 | 143 | 144 | # Citation 145 | 146 | To cite this code, please use the information in [CITATION.cff](CITATION.cff) and the following bibtex entry: 147 | 148 | ``` 149 | @software{Zhang_RICE-N_2022, 150 | author = {Zhang, Tianyu and Srinivasa, Sunil and Williams, Andrew and Phade, Soham and Zhang, Yang and Gupta, Prateek and Bengio, Yoshua and Zheng, Stephan}, 151 | month = {7}, 152 | title = {{RICE-N}}, 153 | url = {https://github.com/mila-iqia/climate-cooperation-competition}, 154 | version = {1.0.0}, 155 | year = {2022} 156 | } 157 | 158 | @misc{https://doi.org/10.48550/arxiv.2208.07004, 159 | doi = {10.48550/ARXIV.2208.07004}, 160 | url = {https://arxiv.org/abs/2208.07004}, 161 | author = {Zhang, Tianyu and Williams, Andrew and Phade, Soham and Srinivasa, Sunil and Zhang, Yang and Gupta, Prateek and Bengio, Yoshua and Zheng, Stephan}, 162 | title = {AI for Global Climate Cooperation: Modeling Global Climate Negotiations, Agreements, and Long-Term Cooperation in RICE-N}, 163 | publisher = {arXiv}, 164 | year = {2022} 165 | } 166 | 167 | ``` 168 | 169 | # License 170 | 171 | For license information, see ```LICENSE.txt```. 172 | -------------------------------------------------------------------------------- /scripts/run_unittests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | """ 9 | Unit tests for the rice simulation 10 | """ 11 | import argparse 12 | import importlib.util as iu 13 | import logging 14 | import os 15 | import shutil 16 | import subprocess 17 | import sys 18 | import time 19 | import unittest 20 | 21 | import numpy as np 22 | 23 | # from evaluate_submission import get_results_dir 24 | from fixed_paths import PUBLIC_REPO_DIR 25 | 26 | sys.path.append(PUBLIC_REPO_DIR) 27 | 28 | _REGION_YAMLS = "region_yamls" 29 | 30 | # Set logger level e.g., DEBUG, INFO, WARNING, ERROR. 31 | logging.getLogger().setLevel(logging.ERROR) 32 | 33 | _BASE_CODE_PATH = ( 34 | "https://raw.githubusercontent.com/mila-iqia/climate-cooperation-competition/main" 35 | ) 36 | _BASE_RICE_PATH = os.path.join(_BASE_CODE_PATH, "rice.py") 37 | _BASE_RICE_HELPERS_PATH = os.path.join(_BASE_CODE_PATH, "rice_helpers.py") 38 | _BASE_RICE_BUILD_PATH = os.path.join(_BASE_CODE_PATH, "rice_build.cu") 39 | _BASE_CONSISTENCY_CHECKER_PATH = os.path.join( 40 | _BASE_CODE_PATH, "scripts/run_cpu_gpu_env_consistency_checks.py" 41 | ) 42 | 43 | 44 | def import_class_from_path(class_name=None, path=None): 45 | """ 46 | Helper function to import class from a path. 47 | """ 48 | assert class_name is not None 49 | assert path is not None 50 | spec = iu.spec_from_file_location(class_name, path) 51 | module_from_spec = iu.module_from_spec(spec) 52 | spec.loader.exec_module(module_from_spec) 53 | return getattr(module_from_spec, class_name) 54 | 55 | 56 | def fetch_base_env(base_folder=".tmp/_base"): 57 | """ 58 | Download the base version of the code from GitHub. 
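    The downloaded copy is used to instantiate a reference `Rice` environment, which is returned after the temporary folder is removed.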
59 | """ 60 | if not base_folder.startswith("/"): 61 | base_folder = os.path.join(PUBLIC_REPO_DIR, base_folder) 62 | # print(f"Using tmp dir {base_folder}") 63 | if os.path.exists(base_folder): 64 | shutil.rmtree(base_folder) 65 | os.makedirs(base_folder, exist_ok=False) 66 | 67 | print( 68 | "\nDownloading a base version of the code from GitHub" 69 | " to run consistency checks..." 70 | ) 71 | prev_dir = os.getcwd() 72 | os.chdir(base_folder) 73 | subprocess.call(["wget", "-O", "rice.py", _BASE_RICE_PATH]) 74 | subprocess.call(["wget", "-O", "rice_helpers.py", _BASE_RICE_HELPERS_PATH]) 75 | if "region_yamls" not in os.listdir(base_folder): 76 | shutil.copytree( 77 | os.path.join(PUBLIC_REPO_DIR, "region_yamls"), 78 | os.path.join(base_folder, "region_yamls"), 79 | ) 80 | 81 | base_rice = import_class_from_path("Rice", os.path.join(base_folder, "rice.py"))() 82 | 83 | # Clean up base code 84 | os.chdir(prev_dir) 85 | shutil.rmtree(base_folder) 86 | return base_rice 87 | 88 | 89 | class TestEnv(unittest.TestCase): 90 | """ 91 | The env testing class. 92 | """ 93 | 94 | @classmethod 95 | def setUpClass(cls): 96 | """Set-up""" 97 | # Initialization 98 | cls.framework = "rllib" 99 | assert cls.results_dir is not None 100 | 101 | # Note: results_dir attributed set in __main__. 102 | if _REGION_YAMLS not in os.listdir(cls.results_dir): 103 | shutil.copytree( 104 | os.path.join(PUBLIC_REPO_DIR, "region_yamls"), 105 | os.path.join(cls.results_dir, "region_yamls"), 106 | ) 107 | 108 | if ".warpdrive" in os.listdir(cls.results_dir): 109 | cls.framework = "warpdrive" 110 | # Copy the consistency checker file into the results_dir 111 | prev_dir = os.getcwd() 112 | os.chdir(cls.results_dir) 113 | os.makedirs("scripts", exist_ok=True) 114 | subprocess.call( 115 | [ 116 | "wget", 117 | "-O", 118 | "scripts/run_cpu_gpu_env_consistency_checks.py", 119 | _BASE_CONSISTENCY_CHECKER_PATH, 120 | ] 121 | ) 122 | subprocess.call( 123 | [ 124 | "wget", 125 | "-O", 126 | "rice_build.cu", 127 | _BASE_RICE_BUILD_PATH, 128 | ] 129 | ) 130 | os.chdir(prev_dir) 131 | else: 132 | assert ".rllib" in os.listdir(cls.results_dir), ( 133 | f"Missing identifier file! " 134 | f"Either the .rllib or the .warpdrive " 135 | f"file must be present in the results directory: {cls.results_dir}" 136 | ) 137 | 138 | cls.base_env = fetch_base_env() # Fetch the base version from GitHub 139 | try: 140 | cls.env = import_class_from_path( 141 | "Rice", os.path.join(cls.results_dir, "rice.py") 142 | )() 143 | except Exception as err: 144 | raise ValueError( 145 | "The Rice environment could not be instantiated !" 146 | ) from err 147 | 148 | base_env_action_nvec = np.array(cls.base_env.action_space[0].nvec) 149 | cls.base_env_random_actions = { 150 | agent_id: np.random.randint( 151 | low=0 * base_env_action_nvec, high=base_env_action_nvec - 1 152 | ) 153 | for agent_id in range(cls.base_env.num_agents) 154 | } 155 | sample_agent_id = 0 156 | env_action_nvec = np.array(cls.env.action_space[sample_agent_id].nvec) 157 | len_negotiation_actions = len(env_action_nvec) - len(base_env_action_nvec) 158 | cls.env_random_actions = { 159 | agent_id: np.append( 160 | cls.base_env_random_actions[agent_id], 161 | np.zeros(len_negotiation_actions, dtype=np.int32), 162 | ) 163 | for agent_id in range(cls.env.num_agents) 164 | } 165 | 166 | def test_env_attributes(self): 167 | """ 168 | Test the env attributes are consistent with the base version. 
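        Also checks that the global-state features match the base environment after reset and after stepping through a full episode with fixed random actions.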
169 | """ 170 | for attribute in [ 171 | "all_constants", 172 | "num_regions", 173 | "num_agents", 174 | "start_year", 175 | "end_year", 176 | "num_discrete_action_levels", 177 | ]: 178 | np.testing.assert_array_equal( 179 | getattr(self.base_env, attribute), getattr(self.env, attribute) 180 | ) 181 | 182 | features = [ 183 | "activity_timestep", 184 | "global_temperature", 185 | "global_carbon_mass", 186 | "global_exogenous_emissions", 187 | "global_land_emissions", 188 | "capital_all_regions", 189 | "capital_depreciation_all_regions", 190 | "labor_all_regions", 191 | "gross_output_all_regions", 192 | "investment_all_regions", 193 | "consumption_all_regions", 194 | "savings_all_regions", 195 | "mitigation_rate_all_regions", 196 | "tariffs", 197 | "max_export_limit_all_regions", 198 | "current_balance_all_regions", 199 | "production_factor_all_regions", 200 | "intensity_all_regions", 201 | "mitigation_cost_all_regions", 202 | "damages_all_regions", 203 | "abatement_cost_all_regions", 204 | "production_all_regions", 205 | "utility_all_regions", 206 | "social_welfare_all_regions", 207 | "reward_all_regions", 208 | "scaled_imports", 209 | ] 210 | 211 | # Test equivalence after reset 212 | self.base_env.reset() 213 | self.env.reset() 214 | 215 | for feature in features: 216 | np.testing.assert_array_equal( 217 | getattr(self.base_env, "global_state")[feature]["value"][0], 218 | getattr(self.env, "global_state")[feature]["value"][0], 219 | ) 220 | 221 | # Test equivalence after stepping through the env 222 | for timestep in range(self.base_env.episode_length): 223 | self.base_env.timestep += 1 224 | self.base_env.climate_and_economy_simulation_step( 225 | self.base_env_random_actions 226 | ) 227 | 228 | self.env.timestep += 1 229 | self.env.climate_and_economy_simulation_step(self.env_random_actions) 230 | 231 | for feature in features: 232 | np.testing.assert_array_equal( 233 | getattr(self.base_env, "global_state")[feature]["value"][timestep], 234 | getattr(self.env, "global_state")[feature]["value"][timestep], 235 | ) 236 | 237 | def test_env_reset(self): 238 | """ 239 | Test the env reset output 240 | """ 241 | obs_at_reset = self.env.reset() 242 | self.assertEqual(len(obs_at_reset), self.env.num_agents) 243 | 244 | def test_env_step(self): 245 | """ 246 | Test the env step output 247 | """ 248 | assert isinstance( 249 | self.env.action_space, dict 250 | ), "Action space must be a dictionary keyed by agent ids." 251 | assert sorted(list(self.env.action_space.keys())) == list( 252 | range(self.env.num_agents) 253 | ) 254 | 255 | # Test with all random actions 256 | obs, rew, done, _ = self.env.step(self.env_random_actions) 257 | self.assertEqual(list(obs.keys()), list(rew.keys())) 258 | assert list(done.keys()) == ["__all__"] 259 | 260 | def test_cpu_gpu_consistency_checks(self): 261 | """ 262 | Run the CPU/GPU environment consistency checks 263 | (only if using the CUDA version of the env) 264 | """ 265 | if self.framework == "warpdrive": 266 | # Execute the CPU-GPU consistency checks 267 | os.chdir(self.results_dir) 268 | subprocess.check_output( 269 | ["python", "scripts/run_cpu_gpu_env_consistency_checks.py"] 270 | ) 271 | 272 | 273 | def get_results_dir(): 274 | """ 275 | Obtain the 'results' directory from the system arguments. 276 | """ 277 | parser = argparse.ArgumentParser() 278 | parser.add_argument( 279 | "--results_dir", 280 | "-r", 281 | type=str, 282 | default=".", 283 | help="the directory where all the submission files are saved. 
Can also be " 284 | "the zipped file containing all the submission files.", 285 | ) 286 | args = parser.parse_args() 287 | 288 | if "results_dir" not in args: 289 | raise ValueError( 290 | "Please provide a results directory to evaluate with the argument -r" 291 | ) 292 | if not os.path.exists(args.results_dir): 293 | raise ValueError( 294 | "The results directory is missing. Please make sure the correct path " 295 | "is specified!" 296 | ) 297 | try: 298 | results_dir = args.results_dir 299 | 300 | # Also handle a zipped file 301 | if results_dir.endswith(".zip"): 302 | unzipped_results_dir = os.path.join("/tmp", str(time.time())) 303 | shutil.unpack_archive(results_dir, unzipped_results_dir) 304 | results_dir = unzipped_results_dir 305 | return results_dir, parser 306 | except Exception as err: 307 | raise ValueError("Cannot obtain the results directory") from err 308 | 309 | # if __name__ == "__main__": 310 | # Skip all of this 311 | # logging.info("Running env unit tests...") 312 | # # Set the results directory 313 | # results_dir, parser = get_results_dir() 314 | # parser.add_argument("unittest_args", nargs="*") 315 | # args = parser.parse_args() 316 | # sys.argv[1:] = args.unittest_args 317 | # TestEnv.results_dir = results_dir 318 | 319 | # unittest.main() 320 | -------------------------------------------------------------------------------- /getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Copyright (c) 2022, salesforce.com, inc and MILA. \n", 8 | "All rights reserved. \n", 9 | "SPDX-License-Identifier: BSD-3-Clause \n", 10 | "For full license text, see the LICENSE file in the repo root \n", 11 | "or https://opensource.org/licenses/BSD-3-Clause " 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# How can I register for the competition?" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "Please fill out the [registration form](https://docs.google.com/forms/d/e/1FAIpQLSe2SWnhJaRpjcCa3idq7zIFubRoH0pATLOP7c1Y0kMXOV6U4w/viewform) in order to register for the competition. \n", 26 | "\n", 27 | "You will only need to provide an email address and a team name. You will also need to be willing to open-source your code after the competition.\n", 28 | "\n", 29 | "After you submit your registration form, we will register it internally. Please allow for upto 1-2 working days for your team name to be registered. You will be notified via email upon successful registration. You will need your team name in order to make submissions towards the competition." 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "# Quickly train agents with CPU using rllib and create a submission" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "The following command should install all the pre-requisites automatically.\n", 44 | "\n", 45 | "Please make sure that you are using Python 3.7 or older version. Our code does not support 3.8 or newer versions currently." 
46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "!python ./scripts/train_with_rllib.py" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "# Evaluate your submission locally" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "Before you actually upload your submission files, you can also evaluate and score your submission on your end using this script. The evaluation script essentially validates the submission files, performs unit testing and computes the metrics for evaluation. To compute the metrics, we first instantiate a trainer, load the policy model with the saved parameters, and then generate several episode rollouts to measure the impact of the policy on the environment." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "!python ./scripts/evaluate_submission.py -r Submissions/.zip" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "# Where can I submit my solution?" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "*NOTE: Please register for the competition (see the steps above), if you have not done so. Your team must be registered before you can submit your solutions.*\n", 92 | "\n", 93 | "The AI climate competition features 3 tracks.\n", 94 | "\n", 95 | "In Track 1, you will propose and implement multilateral agreements to augment the simulator, and train the AI agents in the simulator. We evaluate the learned policies and resulting economic and climate change metrics.\n", 96 | "\n", 97 | "- The submission form for Track 1 is [here](https://forms.gle/fuM4NZ5eX2rdckit6).\n", 98 | "- Or, as an alternative, submit [here](https://workspace.jianguoyun.com/inbox/collect/c7be3a1c61624a4498666095d8a51824/submit) if you have difficulty accessing the Google form\n", 99 | "\n", 100 | "\n", 101 | "Please select your registered team name from the drop-down menu, and upload a zip file containing the submission files - we will be providing scripts to help you create the zip file.\n", 102 | "\n", 103 | "In Track 2, you will argue why your solution is practically relevant and usable in the real world. 
We expect the entries in this track to contain a high-level summary for policymakers\n", 104 | "\n", 105 | "To submit the generated running result and your code (the **.zip** file):\n", 106 | "- The submission form for Track 2 is [here](https://forms.gle/1kTsFLUp6yVF3xQf9).\n", 107 | "- Or, as an alternative, submit [here](https://workspace.jianguoyun.com/inbox/collect/323b16a8697741348e3197ebccafea81/submit) if you have difficulty accessing the Google form\n", 108 | "\n", 109 | "To submit your essay: [OpenReview](https://openreview.net/group?id=AI4ClimateCoop.org/2022/Workshop)\n", 110 | "\n", 111 | "In Track 3, we invite you to point out potential simulation loopholes and improvements.\n", 112 | "\n", 113 | "To submit your code to support your suggestions (a **.zip** file, optional):\n", 114 | "- The submission form for Track 3 is [here](https://forms.gle/stBihEdaf58xi2Vr6).\n", 115 | "- Or, as an alternative, submit [here](https://workspace.jianguoyun.com/inbox/collect/e96127b108ed4172a3b79273688a883c/submit) if you have difficulty accessing the Google form\n", 116 | "\n", 117 | "To submit your essay: [OpenReview](https://openreview.net/group?id=AI4ClimateCoop.org/2022/Workshop)\n", 118 | "\n", 119 | "If you do not see your team name in the drop-down menu, please contact us on Slack or by e-mail, and we will resolve that for you." 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "# How do I create a submission using my modified negotiation protocol? " 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "We provide the base version of the RICE-N (regional integrated climate environment) simulation environment written in Python (`rice.py`).\n", 134 | "\n", 135 | "**For the mathematical background and scientific references, please see [the white paper](https://deliverypdf.ssrn.com/delivery.php?ID=428101121103108016095076093074095111015069058086095042085123117113111092124092117108004117037031126012054120125119115118069067102029022089006118121099082113093096121049050055084110110018003106083072011105122122123113102083083074084083085090104119080101&EXT=pdf&INDEX=TRUE).**\n", 136 | "\n", 137 | "You will need to mainly modify the `rice.py` to implement the proposed negotiatoin protocol. Additional details can be found below." 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "## Scripts for creating the zipped submission file" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "As mentioned above, the zipped file required for submission is automatically created post-training. However, for any reason (for example, for providing a trained policy model at a different timestep), you can create the zipped submission yourself using the `create_submizzion_zip.py` script. Accordingly, create a new directory (say `submission_dir`) with all the relevant files (see the section above), and you can then simply invoke\n", 152 | "```commandline\n", 153 | "python scripts/create_submission_zip.py -r \n", 154 | "```\n", 155 | "\n", 156 | "That will first validate that the submission directory contains all the required files, and then provide you a zipped file that can you use towards your submission." 
157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "!python ./scripts/create_submission_zip.py -r " 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "## Scripts for unit testing" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "In order to make sure that all the submissions are consistent in that they comply within the rules of the competition, we have also added unit tests. These are automatically run also when the evaluation is performed. The script currently performs the following tests\n", 180 | "\n", 181 | "- Test that the environment attributes (such as the RICE and DICE constants, the simulation period and the number of regions) are consistent with the base environment class that we also provide.\n", 182 | "- Test that the `climate_and_economy_simulation_step()` is consistent with the base class. As aforementioned, users are free to add different negotiation strategies such as multi-lateral negotiations or climate clubs, but should not modify the equations underlying the climate and economic dynamics in the world.\n", 183 | "- Test that the environment resetting and stepping yield outputs in the desired format (for instance, observations are a dictionary keyed by region id, and so are rewards.)\n", 184 | "- If the user used WarpDrive, we also perform consistency checks to verify that the CUDA implementation of the rice environment is consistent with the pythonic version.\n", 185 | "\n", 186 | "USAGE: You may invoke the unit tests on a submission file via\n", 187 | "```commandline\n", 188 | "python scripts/run_unittests.py -r \n", 189 | "```" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "!python ./scripts/run_unittests.py -r " 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "## Scripts for performance evaluation" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "USAGE: You may evaluate the submission file using\n", 213 | "```commandline\n", 214 | "python scripts/evaluate_submission.py -r \n", 215 | "```\n", 216 | "Please verify that you can indeed evaluate your submission, before actually uploading it." 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "# Evaluation process" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "After you submit your solution, we will be using the same evaluation script that is provided to you, to score your submissions, but using several rollout episodes to average the metrics such as the average rewards, the global temperature rise, capital, production, and many more. We will then rank the submissions based on the various metrics.The score computed by the evaluation process should be similar to the score computed on your end, since they use the same scripts." 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "## What happens when I make an invalid submission?" 
238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "An \"invalid submission\" is one in which some or all of the submission files are missing, or the submission files are inconsistent with the base version; in short, anything that fails the evaluation process. An invalid submission cannot be evaluated, and hence will not appear on the leaderboard. While we can let you know if your submission is invalid, the process is not automated, so we may not be able to do it promptly. To avoid any issues, please use the `create_submission_zip` script to create your zipped submission file." 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "# Leaderboard" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "The competition leaderboard is displayed on the [competition website](https://mila-iqia.github.io/climate-cooperation-competition). After you make a valid submission, please allow a few minutes for its evaluation to complete and for the leaderboard to refresh." 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "# How many submissions are allowed per team?" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "There is no limit on the number of submissions per team. Feel free to submit as many solutions as you would like. Only your submission with the highest evaluation score will count towards the leaderboard." 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": { 278 | "ExecuteTime": { 279 | "end_time": "2022-06-13T13:41:49.356590Z", 280 | "start_time": "2022-06-13T13:41:22.620490Z" 281 | }, 282 | "scrolled": true 283 | }, 284 | "source": [ 285 | "# Code overview" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "## File Structure" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "metadata": {}, 298 | "source": [ 299 | "Below is the detailed file tree, with file descriptions.\n", 300 | "```commandline\n", 301 | "ROOT_DIR\n", 302 | "├── rice.py\n", 303 | "├── rice_helpers.py\n", 304 | "├── region_yamls\n", 305 | "\n", 306 | "├── rice_step.cu\n", 307 | "├── rice_cuda.py\n", 308 | "├── rice_build.cu\n", 309 | "\n", 310 | "└── scripts\n", 311 | " ├── train_with_rllib.py\n", 312 | " ├── rice_rllib.yaml\n", 313 | " ├── torch_models.py\n", 314 | " \n", 315 | " ├── train_with_warp_drive.py\n", 316 | " ├── rice_warpdrive.yaml\n", 317 | " ├── run_cpu_gpu_env_consistency_checks.py\n", 318 | " \n", 319 | " ├── run_unittests.py \n", 320 | " ├── create_submission_zip.py\n", 321 | " └── evaluate_submission.py \n", 322 | "```" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "## Environment files" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "metadata": {}, 335 | "source": [ 336 | "- `rice.py`: This Python script contains the base Rice class, written in [OpenAI Gym](https://gym.openai.com/) style with the `reset()` and `step()` functionalities. The `step()` function includes an implementation of the `climate_and_economy_simulation_step()`, which dictates the dynamics of the climate and economy simulation and should not be altered by the user.
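For illustration, a minimal Gym-style interaction loop with this class might look as follows. This is only a sketch: it assumes the default constructor arguments and uses an all-zero placeholder policy (attribute names such as `num_agents`, `episode_length` and the per-region `MultiDiscrete` action spaces are taken from the training and evaluation scripts in this repository).

```python
import numpy as np
from rice import Rice

env = Rice()  # base environment with default settings
obs = env.reset()  # observations: a dict keyed by region id

for _ in range(env.episode_length):
    # Placeholder policy: pick the first level of every discrete sub-action.
    actions = {
        region_id: np.zeros(len(env.action_space[region_id].nvec), dtype=np.int32)
        for region_id in range(env.num_agents)
    }
    # obs and rewards are again dicts keyed by region id
    obs, rewards, done, info = env.step(actions)
    if done["__all__"]:
        break
```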
We have also provided a simple implementation of bilateral negotiation between regions via the `proposal_step()` and `evaluation_step()` methods. Users can extend the simulation by adding additional proposal strategies, for example, and incorporating them in the `step()` function. However, please do not modify any of the equations dictating the environment dynamics in the `climate_and_economy_simulation_step()`. All the helper functions related to modeling the climate and economic simulation are located in `rice_helpers.py`. Region-specific environment parameters are provided in the `region_yamls` directory.\n", 337 | "\n", 338 | "\n", 339 | "- `rice_step.cu`\n", 340 | "This is the CUDA C version of the `step()` function that is required for use with WarpDrive. To get started with WarpDrive, we recommend following these [tutorials](https://github.com/salesforce/warp-drive/tree/master/tutorials). While WarpDrive requires writing the simulation in CUDA C, it also offers orders-of-magnitude speedups for end-to-end training, since it performs rollouts and training all on the GPU. `rice_cuda.py` and `rice_build.cu` are necessary files for copying simulation data to the GPU and compiling the CUDA code.\n", 341 | "\n", 342 | "While implementing the simulation in CUDA C on the GPU offers significantly faster simulations, it requires careful memory management. To make sure that everything works properly, one approach is to first implement your simulation logic in Python. You can then implement the same logic in CUDA C and check that the simulation behaviors are the same. To help with this process, we provide an environment consistency checker that runs consistency tests between the Python and CUDA C simulations. Before training your CUDA C code, please run the consistency checker to ensure the Python and CUDA C implementations are consistent.\n", 343 | "```commandline\n", 344 | "python scripts/run_cpu_gpu_env_consistency_checks.py\n", 345 | "```\n", 346 | "\n", 347 | "See the [tutorial notebook](https://colab.research.google.com/drive/1ifcYaczxy4eHM986fyFeSXarGAKmPLt5#scrollTo=vrIIciaAlHSl) for additional details on modifying the code to implement proposed negotiation protocols." 348 | ] 349 | } 350 | ], 351 | "metadata": { 352 | "kernelspec": { 353 | "display_name": "Python 3.8.10 64-bit", 354 | "language": "python", 355 | "name": "python3" 356 | }, 357 | "language_info": { 358 | "codemirror_mode": { 359 | "name": "ipython", 360 | "version": 3 361 | }, 362 | "file_extension": ".py", 363 | "mimetype": "text/x-python", 364 | "name": "python", 365 | "nbconvert_exporter": "python", 366 | "pygments_lexer": "ipython3", 367 | "version": "3.8.10" 368 | }, 369 | "vscode": { 370 | "interpreter": { 371 | "hash": "570feb405e2e27c949193ac68f46852414290d515b0ba6e5d90d076ed2284471" 372 | } 373 | } 374 | }, 375 | "nbformat": 4, 376 | "nbformat_minor": 4 377 | } 378 | -------------------------------------------------------------------------------- /scripts/train_with_rllib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | """ 9 | Training script for the rice environment using RLlib 10 | https://docs.ray.io/en/latest/rllib-training.html 11 | """ 12 | 13 | import logging 14 | import os 15 | import shutil 16 | import subprocess 17 | import sys 18 | import time 19 | 20 | import numpy as np 21 | import yaml 22 | from desired_outputs import desired_outputs 23 | from fixed_paths import PUBLIC_REPO_DIR 24 | from run_unittests import import_class_from_path 25 | from opt_helper import save 26 | 27 | sys.path.append(PUBLIC_REPO_DIR) 28 | 29 | # Set logger level e.g., DEBUG, INFO, WARNING, ERROR. 30 | logging.getLogger().setLevel(logging.DEBUG) 31 | 32 | 33 | def perform_other_imports(): 34 | """ 35 | RLlib-related imports. 36 | """ 37 | import ray 38 | import torch 39 | from gym.spaces import Box, Dict 40 | from ray.rllib.agents.a3c import A2CTrainer 41 | from ray.rllib.env.multi_agent_env import MultiAgentEnv 42 | from ray.tune.logger import NoopLogger 43 | 44 | return ray, torch, Box, Dict, MultiAgentEnv, A2CTrainer, NoopLogger 45 | 46 | 47 | print("Do imports") 48 | 49 | try: 50 | other_imports = perform_other_imports() 51 | except ImportError: 52 | print("Installing requirements...") 53 | 54 | # Install gym 55 | subprocess.call(["pip", "install", "gym==0.21.0"]) 56 | # Install RLlib v1.0.0 57 | subprocess.call(["pip", "install", "ray[rllib]==1.0.0"]) 58 | # Install PyTorch 59 | subprocess.call(["pip", "install", "torch==1.9.0"]) 60 | 61 | other_imports = perform_other_imports() 62 | 63 | ray, torch, Box, Dict, MultiAgentEnv, A2CTrainer, NoopLogger = other_imports 64 | 65 | from torch_models import TorchLinear 66 | 67 | logging.info("Finished imports") 68 | 69 | 70 | _BIG_NUMBER = 1e20 71 | 72 | 73 | def recursive_obs_dict_to_spaces_dict(obs): 74 | """Recursively return the observation space dictionary 75 | for a dictionary of observations 76 | 77 | Args: 78 | obs (dict): A dictionary of observations keyed by agent index 79 | for a multi-agent environment 80 | 81 | Returns: 82 | spaces.Dict: A dictionary of observation spaces 83 | """ 84 | assert isinstance(obs, dict) 85 | dict_of_spaces = {} 86 | for key, val in obs.items(): 87 | 88 | # list of lists are 'listified' np arrays 89 | _val = val 90 | if isinstance(val, list): 91 | _val = np.array(val) 92 | elif isinstance(val, (int, np.integer, float, np.floating)): 93 | _val = np.array([val]) 94 | 95 | # assign Space 96 | if isinstance(_val, np.ndarray): 97 | large_num = float(_BIG_NUMBER) 98 | box = Box( 99 | low=-large_num, high=large_num, shape=_val.shape, dtype=_val.dtype 100 | ) 101 | low_high_valid = (box.low < 0).all() and (box.high > 0).all() 102 | 103 | # This loop avoids issues with overflow to make sure low/high are good. 104 | while not low_high_valid: 105 | large_num = large_num // 2 106 | box = Box( 107 | low=-large_num, high=large_num, shape=_val.shape, dtype=_val.dtype 108 | ) 109 | low_high_valid = (box.low < 0).all() and (box.high > 0).all() 110 | 111 | dict_of_spaces[key] = box 112 | 113 | elif isinstance(_val, dict): 114 | dict_of_spaces[key] = recursive_obs_dict_to_spaces_dict(_val) 115 | else: 116 | raise TypeError 117 | return Dict(dict_of_spaces) 118 | 119 | 120 | def recursive_list_to_np_array(dictionary): 121 | """ 122 | Numpy-ify dictionary object to be used with RLlib. 
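    Lists and numeric scalars are converted to numpy arrays, nested
    dictionaries are handled recursively, and existing numpy arrays are
    passed through unchanged; any other value type raises an AssertionError.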
123 | """ 124 | if isinstance(dictionary, dict): 125 | new_d = {} 126 | for key, val in dictionary.items(): 127 | if isinstance(val, list): 128 | new_d[key] = np.array(val) 129 | elif isinstance(val, dict): 130 | new_d[key] = recursive_list_to_np_array(val) 131 | elif isinstance(val, (int, np.integer, float, np.floating)): 132 | new_d[key] = np.array([val]) 133 | elif isinstance(val, np.ndarray): 134 | new_d[key] = val 135 | else: 136 | raise AssertionError 137 | return new_d 138 | raise AssertionError 139 | 140 | 141 | class EnvWrapper(MultiAgentEnv): 142 | """ 143 | The environment wrapper class. 144 | """ 145 | 146 | def __init__(self, env_config=None): 147 | 148 | super().__init__() 149 | 150 | env_config_copy = env_config.copy() 151 | if env_config_copy is None: 152 | env_config_copy = {} 153 | source_dir = env_config_copy.get("source_dir", None) 154 | # Remove source_dir key in env_config if it exists 155 | if "source_dir" in env_config_copy: 156 | del env_config_copy["source_dir"] 157 | if source_dir is None: 158 | source_dir = PUBLIC_REPO_DIR 159 | assert isinstance(env_config_copy, dict) 160 | self.env = import_class_from_path("Rice", os.path.join(source_dir, "rice.py"))( 161 | **env_config_copy 162 | ) 163 | 164 | self.action_space = self.env.action_space 165 | 166 | self.observation_space = recursive_obs_dict_to_spaces_dict(self.env.reset()) 167 | 168 | def reset(self): 169 | """Reset the env.""" 170 | obs = self.env.reset() 171 | return recursive_list_to_np_array(obs) 172 | 173 | def step(self, actions=None): 174 | """Step through the env.""" 175 | assert actions is not None 176 | assert isinstance(actions, dict) 177 | obs, rew, done, info = self.env.step(actions) 178 | return recursive_list_to_np_array(obs), rew, done, info 179 | 180 | 181 | def get_rllib_config(exp_run_config=None, env_class=None, seed=None): 182 | """ 183 | Reference: https://docs.ray.io/en/latest/rllib-training.html 184 | """ 185 | 186 | assert exp_run_config is not None 187 | assert env_class is not None 188 | 189 | env_config = exp_run_config["env"] 190 | assert isinstance(env_config, dict) 191 | env_object = env_class(env_config=env_config) 192 | 193 | # Define all the policies here 194 | policy_config = exp_run_config["policy"]["regions"] 195 | 196 | # Map of type MultiAgentPolicyConfigDict from policy ids to tuples 197 | # of (policy_cls, obs_space, act_space, config). This defines the 198 | # observation and action spaces of the policies and any extra config. 199 | policies = { 200 | "regions": ( 201 | None, # uses default policy 202 | env_object.observation_space[0], 203 | env_object.action_space[0], 204 | policy_config, 205 | ), 206 | } 207 | 208 | # Function mapping agent ids to policy ids. 209 | def policy_mapping_fn(agent_id=None): 210 | assert agent_id is not None 211 | return "regions" 212 | 213 | # Optional list of policies to train, or None for all policies. 214 | policies_to_train = None 215 | 216 | # Settings for Multi-Agent Environments 217 | multiagent_config = { 218 | "policies": policies, 219 | "policies_to_train": policies_to_train, 220 | "policy_mapping_fn": policy_mapping_fn, 221 | } 222 | 223 | train_config = exp_run_config["trainer"] 224 | rllib_config = { 225 | # Arguments dict passed to the env creator as an EnvContext object (which 226 | # is a dict plus the properties: num_workers, worker_index, vector_index, 227 | # and remote). 
228 | "env_config": exp_run_config["env"], 229 | "framework": train_config["framework"], 230 | "multiagent": multiagent_config, 231 | "num_workers": train_config["num_workers"], 232 | "num_gpus": train_config["num_gpus"], 233 | "num_envs_per_worker": train_config["num_envs"] // train_config["num_workers"], 234 | "train_batch_size": train_config["train_batch_size"], 235 | } 236 | if seed is not None: 237 | rllib_config["seed"] = seed 238 | 239 | return rllib_config 240 | 241 | 242 | def save_model_checkpoint(trainer_obj=None, save_directory=None, current_timestep=0): 243 | """ 244 | Save trained model checkpoints. 245 | """ 246 | assert trainer_obj is not None 247 | assert save_directory is not None 248 | assert os.path.exists(save_directory), ( 249 | "Invalid folder path. " 250 | "Please specify a valid directory to save the checkpoints." 251 | ) 252 | model_params = trainer_obj.get_weights() 253 | for policy in model_params: 254 | filepath = os.path.join( 255 | save_directory, 256 | f"{policy}_{current_timestep}.state_dict", 257 | ) 258 | logging.info( 259 | "Saving the model checkpoints for policy %s to %s.", (policy, filepath) 260 | ) 261 | torch.save(model_params[policy], filepath) 262 | 263 | 264 | def load_model_checkpoints(trainer_obj=None, save_directory=None, ckpt_idx=-1): 265 | """ 266 | Load trained model checkpoints. 267 | """ 268 | assert trainer_obj is not None 269 | assert save_directory is not None 270 | assert os.path.exists(save_directory), ( 271 | "Invalid folder path. " 272 | "Please specify a valid directory to load the checkpoints from." 273 | ) 274 | files = [f for f in os.listdir(save_directory) if f.endswith("state_dict")] 275 | 276 | assert len(files) == len(trainer_obj.config["multiagent"]["policies"]) 277 | 278 | model_params = trainer_obj.get_weights() 279 | for policy in model_params: 280 | policy_models = [ 281 | os.path.join(save_directory, file) for file in files if policy in file 282 | ] 283 | # If there are multiple files, then use the ckpt_idx to specify the checkpoint 284 | assert ckpt_idx < len(policy_models) 285 | sorted_policy_models = sorted(policy_models, key=os.path.getmtime) 286 | policy_model_file = sorted_policy_models[ckpt_idx] 287 | model_params[policy] = torch.load(policy_model_file) 288 | logging.info(f"Loaded model checkpoints {policy_model_file}.") 289 | 290 | trainer_obj.set_weights(model_params) 291 | 292 | 293 | def create_trainer(exp_run_config=None, source_dir=None, results_dir=None, seed=None): 294 | """ 295 | Create the RLlib trainer. 296 | """ 297 | assert exp_run_config is not None 298 | if results_dir is None: 299 | # Use the current time as the name for the results directory. 300 | results_dir = f"{time.time():10.0f}" 301 | 302 | # Directory to save model checkpoints and metrics 303 | 304 | save_config = exp_run_config["saving"] 305 | results_save_dir = os.path.join( 306 | save_config["basedir"], 307 | save_config["name"], 308 | save_config["tag"], 309 | results_dir, 310 | ) 311 | 312 | ray.init(ignore_reinit_error=True) 313 | 314 | # Create the A2C trainer. 315 | exp_run_config["env"]["source_dir"] = source_dir 316 | rllib_trainer = A2CTrainer( 317 | env=EnvWrapper, 318 | config=get_rllib_config( 319 | exp_run_config=exp_run_config, env_class=EnvWrapper, seed=seed 320 | ), 321 | ) 322 | return rllib_trainer, results_save_dir 323 | 324 | 325 | def fetch_episode_states(trainer_obj=None, episode_states=None): 326 | """ 327 | Helper function to rollout the env and fetch env states for an episode. 
328 | """ 329 | assert trainer_obj is not None 330 | assert episode_states is not None 331 | assert isinstance(episode_states, list) 332 | assert len(episode_states) > 0 333 | 334 | outputs = {} 335 | 336 | # Fetch the env object from the trainer 337 | env_object = trainer_obj.workers.local_worker().env 338 | obs = env_object.reset() 339 | 340 | env = env_object.env 341 | 342 | for state in episode_states: 343 | assert state in env.global_state, f"{state} is not in global state!" 344 | # Initialize the episode states 345 | array_shape = env.global_state[state]["value"].shape 346 | outputs[state] = np.nan * np.ones(array_shape) 347 | 348 | agent_states = {} 349 | policy_ids = {} 350 | policy_mapping_fn = trainer_obj.config["multiagent"]["policy_mapping_fn"] 351 | for region_id in range(env.num_agents): 352 | policy_ids[region_id] = policy_mapping_fn(region_id) 353 | agent_states[region_id] = trainer_obj.get_policy( 354 | policy_ids[region_id] 355 | ).get_initial_state() 356 | 357 | for timestep in range(env.episode_length): 358 | for state in episode_states: 359 | outputs[state][timestep] = env.global_state[state]["value"][timestep] 360 | 361 | actions = {} 362 | # TODO: Consider using the `compute_actions` (instead of `compute_action`) 363 | # API below for speed-up when there are many agents. 364 | for region_id in range(env.num_agents): 365 | if ( 366 | len(agent_states[region_id]) == 0 367 | ): # stateless, with a linear model, for example 368 | actions[region_id] = trainer_obj.compute_action( 369 | obs[region_id], 370 | agent_states[region_id], 371 | policy_id=policy_ids[region_id], 372 | ) 373 | else: # stateful 374 | ( 375 | actions[region_id], 376 | agent_states[region_id], 377 | _, 378 | ) = trainer_obj.compute_action( 379 | obs[region_id], 380 | agent_states[region_id], 381 | policy_id=policy_ids[region_id], 382 | ) 383 | obs, _, done, _ = env_object.step(actions) 384 | if done["__all__"]: 385 | for state in episode_states: 386 | outputs[state][timestep + 1] = env.global_state[state]["value"][ 387 | timestep + 1 388 | ] 389 | break 390 | 391 | return outputs 392 | 393 | 394 | def trainer( 395 | negotiation_on=0, 396 | num_envs=100, 397 | train_batch_size=1024, 398 | num_episodes=30000, 399 | lr=0.0005, 400 | model_params_save_freq=5000, 401 | desired_outputs=desired_outputs, 402 | num_workers=4, 403 | ): 404 | print("Training with RLlib...") 405 | 406 | # Read the run configurations specific to the environment. 407 | # Note: The run config yaml(s) can be edited at warp_drive/training/run_configs 408 | # ----------------------------------------------------------------------------- 409 | config_path = os.path.join(PUBLIC_REPO_DIR, "scripts", "rice_rllib.yaml") 410 | if not os.path.exists(config_path): 411 | raise ValueError( 412 | "The run configuration is missing. Please make sure the correct path " 413 | "is specified." 
414 | ) 415 | 416 | with open(config_path, "r", encoding="utf8") as fp: 417 | run_config = yaml.safe_load(fp) 418 | # replace the default setting 419 | run_config["env"]["negotiation_on"] = negotiation_on 420 | run_config["trainer"]["num_envs"] = num_envs 421 | run_config["trainer"]["train_batch_size"] = train_batch_size 422 | run_config["trainer"]["num_workers"] = num_workers 423 | run_config["trainer"]["num_episodes"] = num_episodes 424 | run_config["policy"]["regions"]["lr"] = lr 425 | run_config["saving"]["model_params_save_freq"] = model_params_save_freq 426 | 427 | # Create trainer 428 | # -------------- 429 | trainer, save_dir = create_trainer(run_config) 430 | # debug: print("trainer weghts: ", trainer.get_weights()["regions"]["policy_head.97.weight"]) 431 | # Copy the source files into the results directory 432 | # ------------------------------------------------ 433 | os.makedirs(save_dir) 434 | with open(os.path.join(save_dir, "rice_rllib.yaml"), "w") as yaml_file: 435 | yaml.dump(run_config, yaml_file) 436 | # Copy source files to the saving directory 437 | for file in ["rice.py", "rice_helpers.py"]: 438 | shutil.copyfile( 439 | os.path.join(PUBLIC_REPO_DIR, file), 440 | os.path.join(save_dir, file), 441 | ) 442 | 443 | # Add an identifier file 444 | with open(os.path.join(save_dir, ".rllib"), "x", encoding="utf-8") as fp: 445 | pass 446 | fp.close() 447 | 448 | # Perform training 449 | # ---------------- 450 | trainer_config = run_config["trainer"] 451 | # num_episodes = trainer_config["num_episodes"] 452 | # train_batch_size = trainer_config["train_batch_size"] 453 | # Fetch the env object from the trainer 454 | env_obj = trainer.workers.local_worker().env.env 455 | episode_length = env_obj.episode_length 456 | num_iters = (num_episodes * episode_length) // train_batch_size 457 | 458 | for iteration in range(num_iters): 459 | print(f"********** Iter : {iteration + 1:5d} / {num_iters:5d} **********") 460 | result = trainer.train() 461 | total_timesteps = result.get("timesteps_total") 462 | if ( 463 | iteration % run_config["saving"]["model_params_save_freq"] == 0 464 | or iteration == num_iters - 1 465 | ): 466 | save_model_checkpoint(trainer, save_dir, total_timesteps) 467 | logging.info(result) 468 | print(f"""episode_reward_mean: {result.get('episode_reward_mean')}""") 469 | 470 | outputs_ts = fetch_episode_states(trainer, desired_outputs) 471 | save( 472 | outputs_ts, 473 | os.path.join( 474 | save_dir, 475 | f"outputs_ts_{total_timesteps}.pkl", 476 | ), 477 | ) 478 | print(f"Saving logged outputs to {save_dir}") 479 | # Create a (zipped) submission file 480 | # --------------------------------- 481 | subprocess.call( 482 | [ 483 | "python", 484 | os.path.join(PUBLIC_REPO_DIR, "scripts", "create_submission_zip.py"), 485 | "--results_dir", 486 | save_dir, 487 | ] 488 | ) 489 | # Close Ray gracefully after completion 490 | ray.shutdown() 491 | return trainer, outputs_ts 492 | 493 | 494 | if __name__ == "__main__": 495 | print("Training with RLlib...") 496 | 497 | # Read the run configurations specific to the environment. 498 | # Note: The run config yaml(s) can be edited at warp_drive/training/run_configs 499 | # ----------------------------------------------------------------------------- 500 | config_path = os.getenv("CONFIG_FILE", os.path.join(PUBLIC_REPO_DIR, "scripts", "rice_rllib.yaml")) 501 | if not os.path.exists(config_path): 502 | raise ValueError( 503 | "The run configuration is missing. Please make sure the correct path " 504 | "is specified." 
505 | ) 506 | 507 | with open(config_path, "r", encoding="utf8") as fp: 508 | run_config = yaml.safe_load(fp) 509 | 510 | # Create trainer 511 | # -------------- 512 | trainer, save_dir = create_trainer(run_config) 513 | 514 | # Copy the source files into the results directory 515 | # ------------------------------------------------ 516 | os.makedirs(save_dir) 517 | # Copy source files to the saving directory 518 | for file in ["rice.py", "rice_helpers.py"]: 519 | shutil.copyfile( 520 | os.path.join(PUBLIC_REPO_DIR, file), 521 | os.path.join(save_dir, file), 522 | ) 523 | for file in ["rice_rllib.yaml"]: 524 | shutil.copyfile( 525 | os.path.join(PUBLIC_REPO_DIR, "scripts", file), 526 | os.path.join(save_dir, file), 527 | ) 528 | 529 | # Add an identifier file 530 | with open(os.path.join(save_dir, ".rllib"), "x", encoding="utf-8") as fp: 531 | pass 532 | fp.close() 533 | 534 | # Perform training 535 | # ---------------- 536 | trainer_config = run_config["trainer"] 537 | num_episodes = trainer_config["num_episodes"] 538 | train_batch_size = trainer_config["train_batch_size"] 539 | # Fetch the env object from the trainer 540 | env_obj = trainer.workers.local_worker().env.env 541 | episode_length = env_obj.episode_length 542 | num_iters = (num_episodes * episode_length) // train_batch_size 543 | 544 | for iteration in range(num_iters): 545 | print(f"********** Iter : {iteration + 1:5d} / {num_iters:5d} **********") 546 | result = trainer.train() 547 | total_timesteps = result.get("timesteps_total") 548 | if ( 549 | iteration % run_config["saving"]["model_params_save_freq"] == 0 550 | or iteration == num_iters - 1 551 | ): 552 | save_model_checkpoint(trainer, save_dir, total_timesteps) 553 | logging.info(result) 554 | print(f"""episode_reward_mean: {result.get('episode_reward_mean')}""") 555 | 556 | # Create a (zipped) submission file 557 | # --------------------------------- 558 | subprocess.call( 559 | [ 560 | "python", 561 | os.path.join(PUBLIC_REPO_DIR, "scripts", "create_submission_zip.py"), 562 | "--results_dir", 563 | save_dir, 564 | ] 565 | ) 566 | 567 | # Close Ray gracefully after completion 568 | ray.shutdown() 569 | -------------------------------------------------------------------------------- /Visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%autoreload 2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "U5SXMcYAarz5" 25 | }, 26 | "source": [ 27 | "# Visualization of RICE-N and RL training\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import os, glob, sys, numpy as np, scipy as sp, sklearn as skl" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "if \"_ROOT\" not in globals():\n", 46 | " _ROOT = os.getcwd()\n", 47 | " print(f\"Set _ROOT = {_ROOT}\")\n", 48 | "else: \n", 49 | " print(f\"Already set: _ROOT = {_ROOT}\")" 50 | ] 51 | }, 52 | { 53 | "attachments": {}, 54 | "cell_type": "markdown", 55 | "metadata": { 56 | "id": "OF4mXVOHlHS3" 57 | }, 58 | "source": [ 59 | "## Save or load from previous training 
results" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": { 65 | "id": "80g85HbZlHS4" 66 | }, 67 | "source": [ 68 | "This section is for saving and loading the results of training (not the trainer itself)." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "id": "K9zxtDGzlHS4" 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "from opt_helper import save, load" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": { 85 | "id": "03fOZB9fpHAs" 86 | }, 87 | "source": [ 88 | "To save the output timeseries: " 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": { 95 | "id": "lYF6UDHKlHS4" 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "# [uncomment below to save]\n", 100 | "# save({\"nego_off\":gpu_nego_off_ts, \"nego_on\":gpu_nego_on_ts}, \"filename.pkl\")" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": { 106 | "id": "0vG1JZ75pIa7" 107 | }, 108 | "source": [ 109 | "To load the output timeseries:" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "id": "4TEE7CvHlHS4" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "# [uncomment below to load]\n", 121 | "dict_ts = load(\"example_data/example.pkl\")\n", 122 | "nego_off_ts, nego_on_ts = dict_ts[\"nego_off\"], dict_ts[\"nego_on\"]" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "for key, value in nego_off_ts.items(): \n", 132 | " print(f\"{key:40} {value.shape}\")" 133 | ] 134 | }, 135 | { 136 | "attachments": {}, 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "## Plot training procedures" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "One may want to plot the some metrics such as `mean reward` which are logged during the training procedure.\n", 148 | "\n", 149 | "```python\n", 150 | "metrics = ['Iterations Completed',\n", 151 | " 'VF loss coefficient',\n", 152 | " 'Entropy coefficient',\n", 153 | " 'Total loss',\n", 154 | " 'Policy loss',\n", 155 | " 'Value function loss',\n", 156 | " 'Mean rewards',\n", 157 | " 'Max. rewards',\n", 158 | " 'Min. rewards',\n", 159 | " 'Mean value function',\n", 160 | " 'Mean advantages',\n", 161 | " 'Mean (norm.) advantages',\n", 162 | " 'Mean (discounted) returns',\n", 163 | " 'Mean normalized returns',\n", 164 | " 'Mean entropy',\n", 165 | " 'Variance explained by the value function',\n", 166 | " 'Gradient norm',\n", 167 | " 'Learning rate',\n", 168 | " 'Mean episodic reward',\n", 169 | " 'Mean policy eval time per iter (ms)',\n", 170 | " 'Mean action sample time per iter (ms)',\n", 171 | " 'Mean env. step time per iter (ms)',\n", 172 | " 'Mean training time per iter (ms)',\n", 173 | " 'Mean total time per iter (ms)',\n", 174 | " 'Mean steps per sec (policy eval)',\n", 175 | " 'Mean steps per sec (action sample)',\n", 176 | " 'Mean steps per sec (env. step)',\n", 177 | " 'Mean steps per sec (training time)',\n", 178 | " 'Mean steps per sec (total)'\n", 179 | " ]\n", 180 | "```" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "To check out the logged submissions, please run the following block." 
188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "from glob import glob\n", 197 | "submission_zip_files = glob(os.path.join(_ROOT,\"Submissions/*.zip\"))" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "If previous trainings are finished and logged properly, this should give a list of `*.zip` files where the logs are included. \n", 205 | "\n", 206 | "We picked one of the submissions and the metric `Mean episodic reward` as an example, please check the code below." 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "from opt_helper import get_training_curve, plot_training_curve" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "### WarpDrive version" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "# TBC" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "### RLLib (CPU) version" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "log_zip = submission_zip_files[1]\n", 248 | "plot_training_curve(None, 'Mean episodic reward', log_zip)\n", 249 | "\n", 250 | "# to check the raw logging dictionary, uncomment below\n", 251 | "# logs = get_training_curve(log_zip)\n", 252 | "# logs" 253 | ] 254 | }, 255 | { 256 | "attachments": {}, 257 | "cell_type": "markdown", 258 | "metadata": { 259 | "id": "0BCG5IYWlHS5" 260 | }, 261 | "source": [ 262 | "## Plot results" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": { 269 | "id": "oZW6-QJGlHS5" 270 | }, 271 | "outputs": [], 272 | "source": [ 273 | "from scripts.desired_outputs import desired_outputs" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": { 279 | "id": "AdxL0JanlHS5" 280 | }, 281 | "source": [ 282 | "One may want to check the performance of the agents by plotting graphs. Below, we list all the logged variables. 
One may change the ``desired_outputs.py`` to add more variables of interest.\n", 283 | "\n", 284 | "```python\n", 285 | "desired_outputs = ['global_temperature', \n", 286 | " 'global_carbon_mass', \n", 287 | " 'capital_all_regions', \n", 288 | " 'labor_all_regions', \n", 289 | " 'production_factor_all_regions', \n", 290 | " 'intensity_all_regions', \n", 291 | " 'global_exogenous_emissions', \n", 292 | " 'global_land_emissions', \n", 293 | " 'timestep', \n", 294 | " 'activity_timestep', \n", 295 | " 'capital_depreciation_all_regions', \n", 296 | " 'savings_all_regions', \n", 297 | " 'mitigation_rate_all_regions', \n", 298 | " 'max_export_limit_all_regions', \n", 299 | " 'mitigation_cost_all_regions', \n", 300 | " 'damages_all_regions', \n", 301 | " 'abatement_cost_all_regions', \n", 302 | " 'utility_all_regions', \n", 303 | " 'social_welfare_all_regions', \n", 304 | " 'reward_all_regions', \n", 305 | " 'consumption_all_regions', \n", 306 | " 'current_balance_all_regions', \n", 307 | " 'gross_output_all_regions', \n", 308 | " 'investment_all_regions', \n", 309 | " 'production_all_regions', \n", 310 | " 'tariffs', \n", 311 | " 'future_tariffs', \n", 312 | " 'scaled_imports', \n", 313 | " 'desired_imports', \n", 314 | " 'tariffed_imports', \n", 315 | " 'stage', \n", 316 | " 'minimum_mitigation_rate_all_regions', \n", 317 | " 'promised_mitigation_rate', \n", 318 | " 'requested_mitigation_rate', \n", 319 | " 'proposal_decisions',\n", 320 | " 'global_consumption',\n", 321 | " 'global_production']\n", 322 | "```" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": { 329 | "id": "rYLsNHRjlHS5" 330 | }, 331 | "outputs": [], 332 | "source": [ 333 | "from opt_helper import plot_result" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": { 339 | "id": "-Ab3Dd-zlHS5" 340 | }, 341 | "source": [ 342 | "`plot_result()` plots the time series of logged variables.\n", 343 | "\n", 344 | "```python\n", 345 | "plot_result(variables, nego_off, nego_on, k)\n", 346 | "```\n", 347 | "* ``variables`` can be either a single variable of interest or a list of variable names from the above list. \n", 348 | "* The ``nego_off_ts`` and ``nego_on_ts`` are the logged time series for these variables, with and without negotiation. \n", 349 | "* ``k`` represents the dimension of the variable of interest ( it should be ``0`` by default for most situations)." 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "id": "mrn13h2V2jBQ" 356 | }, 357 | "source": [ 358 | "Here's an example of plotting a single variable of interest." 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": { 365 | "id": "nDdreoz2lHS6" 366 | }, 367 | "outputs": [], 368 | "source": [ 369 | "plot_result(\"global_temperature\", \n", 370 | " nego_off=nego_off_ts, # change it to cpu_nego_off_ts if using CPU\n", 371 | " nego_on=nego_on_ts, \n", 372 | " k=0)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": { 378 | "id": "VNXPz9kk2meL" 379 | }, 380 | "source": [ 381 | "Here's an example of plotting a list of variables." 
382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "metadata": { 388 | "id": "Z76yOBp3lHS5", 389 | "scrolled": true 390 | }, 391 | "outputs": [], 392 | "source": [ 393 | "plot_result(desired_outputs[0:3], # truncated for demonstration purposes\n", 394 | " nego_off=nego_off_ts, \n", 395 | " nego_on=nego_on_ts, \n", 396 | " k=0)" 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "metadata": { 402 | "id": "6sJSQ5gdxCni" 403 | }, 404 | "source": [ 405 | "If one only want to plot negotiation-off plots, feel free to set `nego_on=None`. " 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": { 412 | "id": "FkuU8kV2xCnj" 413 | }, 414 | "outputs": [], 415 | "source": [ 416 | "plot_result(desired_outputs[0:3], # truncated for demonstration purposes\n", 417 | " nego_off=nego_off_ts, \n", 418 | " nego_on=None, \n", 419 | " k=0)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "markdown", 424 | "metadata": {}, 425 | "source": [ 426 | "## Plot region data using a grid plot" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "from opt_helper import make_grid_plot" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "feature_name = \"labor_all_regions\"\n", 445 | "feature_label = feature_name.replace(\"_\", \" \").title() + \" - atmosphere layer\"\n", 446 | "make_grid_plot(\n", 447 | " nego_off_ts[feature_name], \n", 448 | " xlabel=\"Year\", \n", 449 | " ylabel=feature_label, \n", 450 | " feature_label=feature_label, \n", 451 | " cols=3, \n", 452 | " fig_scale=3\n", 453 | ");" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 458 | "metadata": {}, 459 | "source": [ 460 | "## Plot multiple time series data with mean and spread around the mean" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": null, 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "from opt_helper import plot_fig_with_bounds" 470 | ] 471 | }, 472 | { 473 | "cell_type": "markdown", 474 | "metadata": {}, 475 | "source": [ 476 | "Generating some example perturbed data around the original data." 
477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [ 485 | "from copy import deepcopy \n", 486 | "perturbed_nego_on_ts = [\n", 487 | " deepcopy(nego_on_ts),\n", 488 | " deepcopy(nego_on_ts),\n", 489 | " deepcopy(nego_on_ts),\n", 490 | "]\n", 491 | "\n", 492 | "for key, ts in nego_on_ts.items():\n", 493 | " perturbed_nego_on_ts[0][key] = ts\n", 494 | " perturbed_nego_on_ts[1][key] = ts + 0.1 * ts\n", 495 | " perturbed_nego_on_ts[2][key] = ts - 0.1 * ts\n", 496 | "\n", 497 | "\n", 498 | "perturbed_nego_off_ts = [\n", 499 | " deepcopy(nego_off_ts),\n", 500 | " deepcopy(nego_off_ts),\n", 501 | " deepcopy(nego_off_ts),\n", 502 | "]\n", 503 | "\n", 504 | "for key, ts in nego_off_ts.items():\n", 505 | " perturbed_nego_off_ts[0][key] = ts\n", 506 | " perturbed_nego_off_ts[1][key] = ts + 0.1 * ts\n", 507 | " perturbed_nego_off_ts[2][key] = ts - 0.1 * ts" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [ 516 | "feature_name =\"labor_all_regions\"\n", 517 | "\n", 518 | "plot_fig_with_bounds(\n", 519 | " feature_name, # variable,\n", 520 | " \"y_label\", # y_label,\n", 521 | " list_of_dict_off=perturbed_nego_off_ts,\n", 522 | " list_of_dict_on=perturbed_nego_on_ts,\n", 523 | " title=None,\n", 524 | " idx=0,\n", 525 | " x_label=\"year\",\n", 526 | " skips=3,\n", 527 | " line_colors=[\"#0868ac\", \"#7e0018\"],\n", 528 | " region_colors=[\"#7bccc4\", \"#ffac3b\"],\n", 529 | " start=2020,\n", 530 | " alpha=0.5,\n", 531 | " is_grid=True,\n", 532 | " is_save=True,\n", 533 | " delta=5,\n", 534 | ")" 535 | ] 536 | }, 537 | { 538 | "attachments": {}, 539 | "cell_type": "markdown", 540 | "metadata": {}, 541 | "source": [ 542 | "## Aggregate simulation statistics to visualize" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": null, 548 | "metadata": {}, 549 | "outputs": [], 550 | "source": [ 551 | "for k, v in nego_off_ts.items(): \n", 552 | " print(f\"{k:40}{v.shape}\")" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [ 561 | "feature_name = \"mitigation_rate_all_regions\"\n", 562 | "feature_label = feature_name.replace(\"_\", \" \").title()\n", 563 | "make_grid_plot(\n", 564 | " nego_on_ts[feature_name],\n", 565 | " feature_label=feature_label, \n", 566 | " xlabel=\"Step\",\n", 567 | " ylabel=feature_label,\n", 568 | " cols=4,\n", 569 | " fig_scale=3,\n", 570 | ");" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "feature_name = \"savings_all_regions\"\n", 580 | "feature_label = feature_name.replace(\"_\", \" \").title()\n", 581 | "make_grid_plot(\n", 582 | " nego_on_ts[feature_name],\n", 583 | " feature_label=feature_label, \n", 584 | " xlabel=\"Step\",\n", 585 | " ylabel=feature_label,\n", 586 | " cols=4,\n", 587 | " fig_scale=3,\n", 588 | ");" 589 | ] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "metadata": {}, 594 | "source": [ 595 | "## Cluster regions" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": null, 601 | "metadata": {}, 602 | "outputs": [], 603 | "source": [ 604 | "feature_name = \"savings_all_regions\"\n", 605 | "lo_savings, med_savings, hi_savings = [], [], []\n", 606 | "steps, n_regions = nego_on_ts[feature_name].shape\n", 607 | "for region_j in 
range(n_regions):\n", 608 | " mean_region_j = np.mean(nego_on_ts[feature_name][:, region_j])\n", 609 | " print(f\"{region_j}: {mean_region_j:.2f}\")\n", 610 | " if mean_region_j < 0.25: \n", 611 | " lo_savings.append(region_j)\n", 612 | " elif 0.25 < mean_region_j < 0.35:\n", 613 | " med_savings.append(region_j)\n", 614 | " else:\n", 615 | " hi_savings.append(region_j)" 616 | ] 617 | }, 618 | { 619 | "cell_type": "code", 620 | "execution_count": null, 621 | "metadata": {}, 622 | "outputs": [], 623 | "source": [ 624 | "lo_savings, med_savings, hi_savings" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": null, 630 | "metadata": {}, 631 | "outputs": [], 632 | "source": [ 633 | "from opt_helper import make_aggregate_data_across_three_clusters" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": null, 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [ 642 | "aggregate_nego_on_ts = make_aggregate_data_across_three_clusters(\n", 643 | " nego_on_ts, \n", 644 | " lo_savings, \n", 645 | " med_savings, \n", 646 | " hi_savings\n", 647 | ")" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": null, 653 | "metadata": {}, 654 | "outputs": [], 655 | "source": [ 656 | "feature_name = \"savings_all_regions\"\n", 657 | "feature_label = feature_name.replace(\"_\", \" \").title() + \" - aggregate\"\n", 658 | "make_grid_plot(\n", 659 | " aggregate_nego_on_ts[feature_name],\n", 660 | " feature_label=feature_label, \n", 661 | " xlabel=\"Step\",\n", 662 | " ylabel=feature_label,\n", 663 | " cols=4,\n", 664 | " fig_scale=3,\n", 665 | ");" 666 | ] 667 | }, 668 | { 669 | "cell_type": "markdown", 670 | "metadata": {}, 671 | "source": [ 672 | "## Correlation between clusters and feature of those cluster members?" 
673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "execution_count": null, 678 | "metadata": {}, 679 | "outputs": [], 680 | "source": [ 681 | "from opt_helper import compute_correlation_across_groups" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": null, 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "# Define the groups of regions that are aggregated.\n", 691 | "# Note: these shoudl be the same groups that were used to generate the aggregate statistics.\n", 692 | "groups = (lo_savings, med_savings, hi_savings)\n", 693 | "\n", 694 | "# Defint the X variable\n", 695 | "x_feature_name = \"savings_all_regions\"\n", 696 | "aggregate_stats_across_groups = aggregate_nego_on_ts[x_feature_name]\n", 697 | "\n", 698 | "# Define the y data\n", 699 | "data_ts = nego_on_ts\n", 700 | "\n", 701 | "# Only compute correlations for the y variables that are region-specific time-series.\n", 702 | "y_features_names = [\n", 703 | " i for i in data_ts.keys() if \"all_regions\" in i\n", 704 | "]\n", 705 | "\n", 706 | "for y_feature_name in y_features_names: \n", 707 | " r2 = compute_correlation_across_groups(aggregate_stats_across_groups, \n", 708 | " data_ts, \n", 709 | " y_feature_name, \n", 710 | " do_plot=False\n", 711 | " )\n", 712 | "\n", 713 | " print(f\"{x_feature_name:35} vs {y_feature_name:35} r2 = {r2:5.2f}\")" 714 | ] 715 | } 716 | ], 717 | "metadata": { 718 | "accelerator": "GPU", 719 | "colab": { 720 | "collapsed_sections": [ 721 | "ytMyQ2OHlHSr", 722 | "ukp1MeR1Q0dG", 723 | "4BIL2upxlHSw", 724 | "hrqSp18wlHS2", 725 | "OF4mXVOHlHS3", 726 | "0BCG5IYWlHS5", 727 | "yDI4p7cqlHS6", 728 | "LuF76W4FlHS6" 729 | ], 730 | "name": "Copy of Colab_Tutorial.ipynb", 731 | "provenance": [] 732 | }, 733 | "gpuClass": "standard", 734 | "kernelspec": { 735 | "display_name": "Python 3 (ipykernel)", 736 | "language": "python", 737 | "name": "python3" 738 | }, 739 | "language_info": { 740 | "codemirror_mode": { 741 | "name": "ipython", 742 | "version": 3 743 | }, 744 | "file_extension": ".py", 745 | "mimetype": "text/x-python", 746 | "name": "python", 747 | "nbconvert_exporter": "python", 748 | "pygments_lexer": "ipython3", 749 | "version": "3.7.16" 750 | }, 751 | "vscode": { 752 | "interpreter": { 753 | "hash": "251d2b4a5597da75c3ffbd37b4ea77a636d0b2648b30b76d98939345ff58fd97" 754 | } 755 | } 756 | }, 757 | "nbformat": 4, 758 | "nbformat_minor": 1 759 | } 760 | -------------------------------------------------------------------------------- /scripts/evaluate_submission.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, salesforce.com, inc and MILA. 2 | # All rights reserved. 
3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root 5 | # or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | 8 | """ 9 | Evaluation script for the rice environment 10 | """ 11 | 12 | import argparse 13 | import logging 14 | import os 15 | import shutil 16 | import subprocess 17 | import sys 18 | import time 19 | from collections import OrderedDict 20 | from pathlib import Path 21 | import json 22 | import numpy as np 23 | import yaml 24 | 25 | _path = Path(os.path.abspath(__file__)) 26 | 27 | from fixed_paths import PUBLIC_REPO_DIR 28 | from run_unittests import fetch_base_env 29 | from gym.spaces import MultiDiscrete 30 | 31 | # climate-cooperation-competition 32 | sys.path.append(os.path.join(PUBLIC_REPO_DIR, "scripts")) 33 | logging.info("Using PUBLIC_REPO_DIR = {}".format(PUBLIC_REPO_DIR)) 34 | 35 | # mila-sfdc-... 36 | _PRIVATE_REPO_DIR = os.path.join( 37 | _path.parent.parent.parent.absolute(), "private-repo-clone" 38 | ) 39 | sys.path.append(os.path.join(_PRIVATE_REPO_DIR, "backend")) 40 | logging.info("Using _PRIVATE_REPO_DIR = {}".format(_PRIVATE_REPO_DIR)) 41 | 42 | 43 | # Set logger level e.g., DEBUG, INFO, WARNING, ERROR. 44 | logging.getLogger().setLevel(logging.ERROR) 45 | 46 | _EVAL_SEED = 1234567890 # seed used for evaluation 47 | 48 | _INDEXES_FILENAME = "climate_economic_min_max_indices.txt" 49 | 50 | _METRICS_TO_LABEL_DICT = OrderedDict() 51 | # Read the dict values below as 52 | # (label, decimal points used to round off value: 0 becomes an integer) 53 | _METRICS_TO_LABEL_DICT["reward_all_regions"] = ("Episode Reward", 2) 54 | _METRICS_TO_LABEL_DICT["global_temperature"] = ("Temperature Rise", 2) 55 | _METRICS_TO_LABEL_DICT["global_carbon_mass"] = ("Carbon Mass", 0) 56 | _METRICS_TO_LABEL_DICT["capital_all_regions"] = ("Capital", 0) 57 | _METRICS_TO_LABEL_DICT["production_all_regions"] = ("Production", 0) 58 | _METRICS_TO_LABEL_DICT["gross_output_all_regions"] = ("Gross Output", 0) 59 | _METRICS_TO_LABEL_DICT["investment_all_regions"] = ("Investment", 0) 60 | _METRICS_TO_LABEL_DICT["abatement_cost_all_regions"] = ("Abatement Cost", 2) 61 | 62 | 63 | def get_imports(framework=None): 64 | """ 65 | Fetch relevant imports. 66 | """ 67 | assert framework is not None 68 | if framework == "rllib": 69 | from train_with_rllib import ( 70 | create_trainer, 71 | fetch_episode_states, 72 | load_model_checkpoints, 73 | ) 74 | elif framework == "warpdrive": 75 | from train_with_warp_drive import ( 76 | create_trainer, 77 | fetch_episode_states, 78 | load_model_checkpoints, 79 | ) 80 | else: 81 | raise ValueError(f"Unknown framework {framework}!") 82 | return create_trainer, load_model_checkpoints, fetch_episode_states 83 | 84 | 85 | def try_to_unzip_file(path): 86 | """ 87 | Obtain the 'results' directory from the system arguments. 88 | """ 89 | try: 90 | _unzipped_dir = os.path.join("/tmp", str(time.time())) 91 | shutil.unpack_archive(path, _unzipped_dir) 92 | return _unzipped_dir 93 | except Exception as err: 94 | raise ValueError("Cannot obtain the results directory") from err 95 | 96 | 97 | def validate_dir(results_dir=None): 98 | """ 99 | Validate that all the required files are present in the 'results' directory. 
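    Returns a tuple (framework, success, comment): the training framework
    detected from the identifier file ("warpdrive" or "rllib", or None if
    no identifier file is found), whether the directory looks like a valid
    submission, and a human-readable comment.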
100 | """ 101 | assert results_dir is not None 102 | framework = None 103 | 104 | files = os.listdir(results_dir) 105 | if ".warpdrive" in files: 106 | framework = "warpdrive" 107 | # Warpdrive was used for training 108 | for file in [ 109 | "rice.py", 110 | "rice_helpers.py", 111 | "rice_cuda.py", 112 | "rice_step.cu", 113 | "rice_warpdrive.yaml", 114 | ]: 115 | if file not in files: 116 | success = False 117 | logging.error( 118 | "%s is not present in the results directory: %s!", file, results_dir 119 | ) 120 | comment = f"{file} is not present in the results directory!" 121 | break 122 | success = True 123 | comment = "Valid submission" 124 | elif ".rllib" in files: 125 | framework = "rllib" 126 | # RLlib was used for training 127 | for file in ["rice.py", "rice_helpers.py", "rice_rllib.yaml"]: 128 | if file not in files: 129 | success = False 130 | logging.error( 131 | "%s is not present in the results directory: %s!", file, results_dir 132 | ) 133 | comment = f"{file} is not present in the results directory!" 134 | break 135 | success = True 136 | comment = "Valid submission" 137 | else: 138 | success = False 139 | logging.error( 140 | "Missing identifier file! Either the .rllib or the .warpdrive " 141 | "file must be present in the results directory: %s", 142 | results_dir, 143 | ) 144 | comment = "Missing identifier file!" 145 | print("comment", comment) 146 | return framework, success, comment 147 | 148 | 149 | def compute_metrics( 150 | fetch_episode_states, trainer, framework, num_episodes=1, include_c_e_idx=True 151 | ): 152 | """ 153 | Generate episode rollouts and compute metrics. 154 | """ 155 | assert trainer is not None 156 | available_frameworks = ["rllib", "warpdrive"] 157 | assert ( 158 | framework in available_frameworks 159 | ), f"Invalid framework {framework}, should be in f{available_frameworks}." 160 | 161 | # Fetch all the desired outputs to compute various metrics. 
162 | desired_outputs = list(_METRICS_TO_LABEL_DICT.keys()) 163 | # Add auxiliary outputs required for processing 164 | required_outputs = desired_outputs + ["activity_timestep"] 165 | 166 | episode_states = {} 167 | eval_metrics = {} 168 | try: 169 | for episode_id in range(num_episodes): 170 | if fetch_episode_states is not None: 171 | episode_states[episode_id] = fetch_episode_states( 172 | trainer, required_outputs 173 | ) 174 | else: 175 | episode_states[episode_id] = trainer.fetch_episode_global_states( 176 | required_outputs 177 | ) 178 | 179 | for feature in desired_outputs: 180 | feature_values = [None for _ in range(num_episodes)] 181 | 182 | if feature == "global_temperature": 183 | # Get the temp rise for upper strata 184 | for episode_id in range(num_episodes): 185 | feature_values[episode_id] = ( 186 | episode_states[episode_id][feature][-1, 0] 187 | - episode_states[episode_id][feature][0, 0] 188 | ) 189 | 190 | elif feature == "global_carbon_mass": 191 | for episode_id in range(num_episodes): 192 | feature_values[episode_id] = episode_states[episode_id][feature][ 193 | -1, 0 194 | ] 195 | 196 | elif feature == "gross_output_all_regions": 197 | for episode_id in range(num_episodes): 198 | # collect gross output results based on activity timestep 199 | activity_timestep = episode_states[episode_id]["activity_timestep"] 200 | activity_index = np.append( 201 | 1.0, np.diff(activity_timestep.squeeze()) 202 | ) 203 | activity_index = [np.isclose(v, 1.0) for v in activity_index] 204 | feature_values[episode_id] = np.sum( 205 | episode_states[episode_id]["gross_output_all_regions"][ 206 | activity_index 207 | ] 208 | ) 209 | 210 | else: 211 | for episode_id in range(num_episodes): 212 | feature_values[episode_id] = np.sum( 213 | episode_states[episode_id][feature] 214 | ) 215 | 216 | # Compute mean feature value across episodes 217 | mean_feature_value = np.mean(feature_values) 218 | 219 | # Formatting the values 220 | metrics_to_label_dict = _METRICS_TO_LABEL_DICT[feature] 221 | 222 | eval_metrics[metrics_to_label_dict[0]] = perform_format( 223 | mean_feature_value, metrics_to_label_dict[1] 224 | ) 225 | if include_c_e_idx: 226 | if not os.path.exists(_INDEXES_FILENAME): 227 | # Write min, max climate and economic index values to a file 228 | # for use during evaluation. 229 | indices_dict = generate_min_max_climate_economic_indices() 230 | # Write indices to a file 231 | with open(_INDEXES_FILENAME, "w", encoding="utf-8") as file_ptr: 232 | file_ptr.write(json.dumps(indices_dict)) 233 | with open(_INDEXES_FILENAME, "r", encoding="utf-8") as file_ptr: 234 | index_dict = json.load(file_ptr) 235 | eval_metrics["climate_index"] = np.round( 236 | (eval_metrics["Temperature Rise"] - index_dict["min_ci"]) 237 | / (index_dict["max_ci"] - index_dict["min_ci"]), 238 | 2, 239 | ) 240 | eval_metrics["economic_index"] = np.round( 241 | (eval_metrics["Gross Output"] - index_dict["min_ei"]) 242 | / (index_dict["max_ei"] - index_dict["min_ei"]), 243 | 2, 244 | ) 245 | success = True 246 | comment = "Successful submission" 247 | except Exception as err: 248 | logging.error(err) 249 | success = False 250 | comment = "Could not obtain an episode rollout!" 251 | eval_metrics = {} 252 | 253 | return success, comment, eval_metrics 254 | 255 | 256 | def val_metrics(logged_ts, framework, num_episodes=1, include_c_e_idx=True): 257 | """ 258 | Generate episode rollouts and compute metrics. 
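    Unlike compute_metrics(), this variant does not generate new rollouts;
    it computes the same metrics from an already-logged timeseries
    dictionary (`logged_ts`).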
259 | """ 260 | available_frameworks = ["rllib", "warpdrive"] 261 | assert ( 262 | framework in available_frameworks 263 | ), f"Invalid framework {framework}, should be in f{available_frameworks}." 264 | 265 | # Fetch all the desired outputs to compute various metrics. 266 | desired_outputs = list(_METRICS_TO_LABEL_DICT.keys()) 267 | episode_states = {} 268 | eval_metrics = {} 269 | try: 270 | for episode_id in range(num_episodes): 271 | episode_states[episode_id] = logged_ts 272 | 273 | for feature in desired_outputs: 274 | feature_values = [None for _ in range(num_episodes)] 275 | 276 | if feature == "global_temperature": 277 | # Get the temp rise for upper strata 278 | for episode_id in range(num_episodes): 279 | feature_values[episode_id] = ( 280 | episode_states[episode_id][feature][-1, 0] 281 | - episode_states[episode_id][feature][0, 0] 282 | ) 283 | 284 | elif feature == "global_carbon_mass": 285 | for episode_id in range(num_episodes): 286 | feature_values[episode_id] = episode_states[episode_id][feature][ 287 | -1, 0 288 | ] 289 | 290 | elif feature == "gross_output_all_regions": 291 | for episode_id in range(num_episodes): 292 | # collect gross output results based on activity timestep 293 | activity_timestep = episode_states[episode_id]["activity_timestep"] 294 | activity_index = np.append( 295 | 1.0, np.diff(activity_timestep.squeeze()) 296 | ) 297 | activity_index = [np.isclose(v, 1.0) for v in activity_index] 298 | feature_values[episode_id] = np.sum( 299 | episode_states[episode_id]["gross_output_all_regions"][ 300 | activity_index 301 | ] 302 | ) 303 | else: 304 | for episode_id in range(num_episodes): 305 | feature_values[episode_id] = np.sum( 306 | episode_states[episode_id][feature] 307 | ) 308 | 309 | # Compute mean feature value across episodes 310 | mean_feature_value = np.mean(feature_values) 311 | 312 | # Formatting the values 313 | metrics_to_label_dict = _METRICS_TO_LABEL_DICT[feature] 314 | 315 | eval_metrics[metrics_to_label_dict[0]] = perform_format( 316 | mean_feature_value, metrics_to_label_dict[1] 317 | ) 318 | if include_c_e_idx: 319 | if not os.path.exists(_INDEXES_FILENAME): 320 | # Write min, max climate and economic index values to a file 321 | # for use during evaluation. 322 | indices_dict = generate_min_max_climate_economic_indices() 323 | # Write indices to a file 324 | with open(_INDEXES_FILENAME, "w", encoding="utf-8") as file_ptr: 325 | file_ptr.write(json.dumps(indices_dict)) 326 | with open(_INDEXES_FILENAME, "r", encoding="utf-8") as file_ptr: 327 | index_dict = json.load(file_ptr) 328 | eval_metrics["climate_index"] = np.round( 329 | (eval_metrics["Temperature Rise"] - index_dict["min_ci"]) 330 | / (index_dict["max_ci"] - index_dict["min_ci"]), 331 | 2, 332 | ) 333 | eval_metrics["economic_index"] = np.round( 334 | (eval_metrics["Gross Output"] - index_dict["min_ei"]) 335 | / (index_dict["max_ei"] - index_dict["min_ei"]), 336 | 2, 337 | ) 338 | success = True 339 | comment = "Successful submission" 340 | except Exception as err: 341 | logging.error(err) 342 | success = False 343 | comment = "Could not obtain an episode rollout!" 344 | eval_metrics = {} 345 | 346 | return success, comment, eval_metrics 347 | 348 | 349 | def perform_format(val, num_decimal_places): 350 | """ 351 | Format value to the number of desired decimal points. 
352 |     """
353 |     if np.isnan(val):
354 |         return val
355 |     assert num_decimal_places >= 0
356 |     rounded_val = np.round(val, num_decimal_places)
357 |     if num_decimal_places == 0:
358 |         return int(rounded_val)
359 |     return rounded_val
360 | 
361 | 
362 | def perform_evaluation(
363 |     results_directory,
364 |     framework,
365 |     num_episodes=1,
366 |     eval_seed=None,
367 | ):
368 |     """
369 |     Create the trainer and compute metrics.
370 |     """
371 |     assert results_directory is not None
372 |     assert num_episodes > 0
373 | 
374 |     (
375 |         create_trainer,
376 |         load_model_checkpoints,
377 |         fetch_episode_states,
378 |     ) = get_imports(framework=framework)
379 | 
380 |     # Load a run configuration
381 |     config_file = os.path.join(results_directory, f"rice_{framework}.yaml")
382 | 
383 |     try:
384 |         assert os.path.exists(config_file)
385 |     except Exception as err:
386 |         logging.error(f"The run configuration is missing in {results_directory}.")
387 |         raise err
388 | 
389 |     # Copy the PUBLIC region yamls and rice_build.cu to the results directory.
390 |     if not os.path.exists(os.path.join(results_directory, "region_yamls")):
391 |         shutil.copytree(
392 |             os.path.join(PUBLIC_REPO_DIR, "region_yamls"),
393 |             os.path.join(results_directory, "region_yamls"),
394 |         )
395 |     if not os.path.exists(os.path.join(results_directory, "rice_build.cu")):
396 |         shutil.copyfile(
397 |             os.path.join(PUBLIC_REPO_DIR, "rice_build.cu"),
398 |             os.path.join(results_directory, "rice_build.cu"),
399 |         )
400 | 
401 |     # Create Trainer object
402 |     try:
403 |         with open(config_file, "r", encoding="utf-8") as file_ptr:
404 |             run_config = yaml.safe_load(file_ptr)
405 | 
406 |         trainer, _ = create_trainer(
407 |             run_config, source_dir=results_directory, seed=eval_seed
408 |         )
409 | 
410 |     except Exception as err:
411 |         logging.error("Could not create Trainer with the run_config provided.")
412 |         raise err
413 | 
414 |     # Load model checkpoints
415 |     try:
416 |         load_model_checkpoints(trainer, results_directory)
417 |     except Exception as err:
418 |         logging.error("Could not load model checkpoints.")
419 |         raise err
420 | 
421 |     # Compute metrics
422 |     try:
423 |         success, comment, eval_metrics = compute_metrics(
424 |             fetch_episode_states,
425 |             trainer,
426 |             framework,
427 |             num_episodes=num_episodes,
428 |         )
429 | 
430 |         if framework == "warpdrive":
431 |             trainer.graceful_close()
432 | 
433 |         return success, eval_metrics, comment
434 | 
435 |     except Exception as err:
436 |         logging.error("Could not fetch episode and compute metrics.")
437 |         raise err
438 | 
439 | 
440 | def get_temp_rise_and_gross_output(env, actions):
441 |     env.reset()
442 |     for _ in range(env.episode_length):
443 |         env.step(actions)
444 |     temperature_array = env.global_state["global_temperature"]["value"]
445 |     temperature_rise = temperature_array[-1, 0] - temperature_array[0, 0]
446 | 
447 |     total_gross_production = np.sum(
448 |         env.global_state["gross_output_all_regions"]["value"]
449 |     )
450 |     return temperature_rise, total_gross_production
451 | 
452 | 
453 | def generate_min_max_climate_economic_indices():
454 |     """
455 |     Generate min and max climate and economic indices for the leaderboard.
456 |     0% savings, 100% mitigation => best climate index, worst economic index
457 |     100% savings, 0% mitigation => worst climate index, best economic index
458 |     """
459 |     env = fetch_base_env()  # base rice env
460 |     assert isinstance(
461 |         env.action_space[0], MultiDiscrete
462 |     ), "Unknown action space for env."
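    # In the action vectors built below, component 0 is the savings rate and component 1
    # is the mitigation rate (see savings_action_idx / mitigation_action_idx); each extreme
    # policy from the docstring is obtained by zero-initializing every agent's
    # MultiDiscrete action vector and then maxing out exactly one of those two components.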
463 | all_zero_actions = { 464 | agent_id: np.zeros( 465 | len(env.action_space[agent_id].nvec), 466 | dtype=np.int32, 467 | ) 468 | for agent_id in range(env.num_agents) 469 | } 470 | 471 | # 0% savings, 100% mitigation 472 | low_savings_high_mitigation_actions = {} 473 | savings_action_idx = 0 474 | mitigation_action_idx = 1 475 | for agent_id in range(env.num_agents): 476 | low_savings_high_mitigation_actions[agent_id] = all_zero_actions[ 477 | agent_id 478 | ].copy() 479 | low_savings_high_mitigation_actions[agent_id][ 480 | mitigation_action_idx 481 | ] = env.num_discrete_action_levels 482 | # Best climate index, worst economic index 483 | best_ci, worst_ei = get_temp_rise_and_gross_output( 484 | env, low_savings_high_mitigation_actions 485 | ) 486 | 487 | high_savings_low_mitigation_actions = {} 488 | for agent_id in range(env.num_agents): 489 | high_savings_low_mitigation_actions[agent_id] = all_zero_actions[ 490 | agent_id 491 | ].copy() 492 | high_savings_low_mitigation_actions[agent_id][ 493 | savings_action_idx 494 | ] = env.num_discrete_action_levels 495 | worst_ci, best_ei = get_temp_rise_and_gross_output( 496 | env, high_savings_low_mitigation_actions 497 | ) 498 | 499 | index_dict = { 500 | "min_ci": float(worst_ci), 501 | "max_ci": float(best_ci), 502 | "min_ei": float(worst_ei), 503 | "max_ei": float(best_ei), 504 | } 505 | return index_dict 506 | 507 | 508 | if __name__ == "__main__": 509 | logging.info("This script performs evaluation of your code.") 510 | 511 | # CLI arguments 512 | parser = argparse.ArgumentParser() 513 | parser.add_argument( 514 | "--results_dir", 515 | "-r", 516 | type=str, 517 | default="./Submissions/1680502535.zip", # an example of a submission file 518 | help="The directory where all the submission files are saved. Can also be " 519 | "a zip-file containing all the submission files.", 520 | ) 521 | args = parser.parse_args() 522 | 523 | # Check the submission zip file or directory. 524 | if "results_dir" not in args: 525 | raise ValueError( 526 | "Please provide a results directory to evaluate with the argument -r" 527 | ) 528 | 529 | if not os.path.exists(args.results_dir): 530 | raise ValueError( 531 | "The results directory is missing. Please make sure the correct path " 532 | "is specified!" 533 | ) 534 | 535 | results_dir = ( 536 | try_to_unzip_file(args.results_dir) 537 | if args.results_dir.endswith(".zip") 538 | else args.results_dir 539 | ) 540 | 541 | logging.info(f"Using submission files in {results_dir}") 542 | 543 | # Validate the submission directory 544 | framework, results_dir_is_valid, comment = validate_dir(results_dir) 545 | if not results_dir_is_valid: 546 | raise AssertionError(f"{results_dir} is not a valid submission directory.") 547 | 548 | # Run unit tests on the simulation files 549 | skip_unit_tests = True # = args.skip_unit_tests 550 | 551 | try: 552 | if skip_unit_tests: 553 | logging.info("Skipping check_output test") 554 | else: 555 | logging.info("Running unit tests...") 556 | subprocess.check_output( 557 | [ 558 | "python", 559 | "run_unittests.py", 560 | "--results_dir", 561 | results_dir, 562 | ], 563 | ) 564 | logging.info("run_unittests.py is done") 565 | except subprocess.CalledProcessError as err: 566 | logging.error(f"{results_dir}: unit tests were not successful.") 567 | raise err 568 | 569 | # Run evaluation with submitted simulation and trained agents. 
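    # Note: the climate_index and economic_index reported by compute_metrics are min-max
    # normalizations of "Temperature Rise" and "Gross Output" against the extremes stored
    # in _INDEXES_FILENAME, e.g.
    #   climate_index = round((temperature_rise - min_ci) / (max_ci - min_ci), 2)
    # and analogously for the economic index with min_ei / max_ei.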
570 | logging.info("Starting eval...") 571 | succeeded, metrics, comments = perform_evaluation( 572 | results_dir, framework, eval_seed=_EVAL_SEED 573 | ) 574 | 575 | # Report results. 576 | eval_result_str = "\n".join( 577 | [ 578 | f"Framework used: {framework}", 579 | f"Succeeded: {succeeded}", 580 | f"Metrics: {metrics}", 581 | f"Comments: {comments}", 582 | ] 583 | ) 584 | logging.info(eval_result_str) 585 | print(eval_result_str) 586 | -------------------------------------------------------------------------------- /MARL_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Copyright (c) 2022, salesforce.com, inc and MILA. \n", 8 | "All rights reserved. \n", 9 | "SPDX-License-Identifier: BSD-3-Clause \n", 10 | "For full license text, see the LICENSE file in the repo root \n", 11 | "or https://opensource.org/licenses/BSD-3-Clause " 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "Get started quickly with end-to-end multi-agent RL using WarpDrive! This shows a basic example to create a simple Rice environment and perform training." 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "**Try this notebook on [Colab](http://colab.research.google.com/github/salesforce/warp-drive/blob/master/tutorials/simple-end-to-end-example.ipynb)!**" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## ⚠️ PLEASE NOTE:\n", 33 | "This notebook runs on a GPU runtime.\\\n", 34 | "If running on Colab, choose Runtime > Change runtime type from the menu, then select `GPU` in the 'hardware accelerator' dropdown menu." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### Dependencies" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "First, install the WarpDrive package" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "%load_ext autoreload\n", 58 | "%autoreload 2\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "scrolled": true 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "# !pip install rl-warp-drive" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "import os\n", 79 | "import torch\n", 80 | "\n", 81 | "from rice_cuda import RiceCuda\n", 82 | "from warp_drive.env_wrapper import EnvWrapper\n", 83 | "from warp_drive.training.trainer import Trainer\n", 84 | "from warp_drive.utils.env_registrar import EnvironmentRegistrar\n", 85 | "\n", 86 | "pytorch_cuda_init_success = torch.cuda.FloatTensor(8)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "# Environment, Training, and Model Hyperparameters" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "run_config = dict(\n", 103 | " \n", 104 | " # Environment settings\n", 105 | " env = dict( \n", 106 | " negotiation_on=1,\n", 107 | " num_discretization_cells = 10,\n", 108 | " ),\n", 109 | "\n", 110 | " # Trainer settings\n", 111 | " trainer = dict(\n", 112 | " num_envs = 100, # Number of environment replicas (numbre of GPU blocks used)\n", 113 | " train_batch_size = 10000, # total batch size used for training per iteration (across all the environments)\n", 114 | " num_episodes = 100000, # Total number of episodes to run the training for (can be arbitrarily high!)\n", 115 | " ),\n", 116 | " \n", 117 | " # Policy network settings\n", 118 | " policy = dict(\n", 119 | " regions = dict(\n", 120 | " to_train = True,\n", 121 | " gamma = 0.92, # discount factor\n", 122 | " lr = 0.0005, # learning rate\n", 123 | " entropy_coeff = [[0,0.5], [1000000, 0.1], [5000000, 0.05]],\n", 124 | " vf_loss_coeff = [[0,0.0001], [1000000, 0.001], [5000000, 0.01], [10000000, 0.1]],\n", 125 | " model = dict( \n", 126 | " type = \"fully_connected\",\n", 127 | " fc_dims = [256,256], # dimension(s) of the fully connected layers as a list\n", 128 | " model_ckpt_filepath = \"\" # load model parameters from a saved checkpoint (if specified)\n", 129 | " )\n", 130 | " ),\n", 131 | " ),\n", 132 | " \n", 133 | " # Checkpoint saving setting\n", 134 | " saving = dict(\n", 135 | " metrics_log_freq = 10, # How often (in iterations) to print the metrics\n", 136 | " model_params_save_freq = 5000, # How often (in iterations) to save the model parameters\n", 137 | " basedir = \"/tmp\", # base folder used for saving\n", 138 | " name = \"rice\",\n", 139 | " tag = \"example\",\n", 140 | " )\n", 141 | ")" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "# End-to-End Training Loop" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "scrolled": true 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "# Register the environment\n", 160 | "env_registrar = EnvironmentRegistrar()\n", 161 | "this_file_dir = 
os.path.dirname(os.path.abspath(\"__file__\"))\n", 162 | "env_registrar.add_cuda_env_src_path(\n", 163 | " RiceCuda.name,\n", 164 | " os.path.join(this_file_dir, \"rice_build.cu\")\n", 165 | ")\n", 166 | "\n", 167 | "# cpu_env = EnvWrapper(Rice())\n", 168 | "\n", 169 | "# add_cpu_env = env_registrar.add(device=\"cpu\")\n", 170 | "# add_cpu_env(cpu_env)\n", 171 | "# add_gpu_env = env_registrar.add(device=\"gpu\")\n", 172 | "# add_gpu_env(cpu_env)\n", 173 | "\n", 174 | "# Create a wrapped environment object via the EnvWrapper\n", 175 | "# Ensure that use_cuda is set to True (in order to run on the GPU)\n", 176 | "env_wrapper = EnvWrapper(\n", 177 | " RiceCuda(**run_config[\"env\"]),\n", 178 | " num_envs=run_config[\"trainer\"][\"num_envs\"], \n", 179 | " use_cuda=True,\n", 180 | " env_registrar=env_registrar,\n", 181 | ")\n", 182 | "\n", 183 | "# Agents can share policy models: this dictionary maps policy model names to agent ids.\n", 184 | "policy_tag_to_agent_id_map = {\n", 185 | " \"regions\": [agent_id for agent_id in range(env_wrapper.env.num_agents)],\n", 186 | "}\n", 187 | "\n", 188 | "# Create the trainer object\n", 189 | "trainer = Trainer(\n", 190 | " env_wrapper=env_wrapper,\n", 191 | " config=run_config,\n", 192 | " policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,\n", 193 | ")\n", 194 | "\n", 195 | "# Perform training!\n", 196 | "trainer.train()\n", 197 | "\n", 198 | "# Shut off gracefully\n", 199 | "# trainer.graceful_close()" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "### Fetch episode states" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "# Please note that any variable registered in rice_cuda.py can be put here\n", 216 | "desired_outputs = [\n", 217 | " \"T_i\", # Temperature\n", 218 | " \"M_i\", # Carbon mass\n", 219 | " \"sampled_actions\",\n", 220 | " \"minMu\"\n", 221 | " ]\n", 222 | "\n", 223 | "episode_states = trainer.fetch_episode_states(\n", 224 | " desired_outputs\n", 225 | ")\n", 226 | "\n", 227 | "trainer.graceful_close()" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "import matplotlib.pyplot as plt" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "def get_episode_T_AT(episode_states, negotiation_on, plot = 0):\n", 246 | " state = 'T_i'\n", 247 | " if negotiation_on:\n", 248 | " values = episode_states[state][::3,0,0]\n", 249 | " else:\n", 250 | " values = episode_states[state][:,0,0]\n", 251 | "\n", 252 | " if plot:\n", 253 | " fig = plt.figure() \n", 254 | " plt.plot(values[:], label='Temperature - Atmosphere')\n", 255 | " fig.legend()\n", 256 | " # plt.yscale('log')\n", 257 | " fig.show()\n", 258 | "\n", 259 | " return values\n", 260 | "\n", 261 | "\n", 262 | "def get_episode_T_LO(episode_states, negotiation_on, plot = 0):\n", 263 | " state = 'T_i'\n", 264 | " if negotiation_on:\n", 265 | " values = episode_states[state][::3,0,1]\n", 266 | " else:\n", 267 | " values = episode_states[state][:,0,1]\n", 268 | "\n", 269 | " if plot:\n", 270 | " fig = plt.figure()\n", 271 | " plt.plot(values[:], label='Temperature - Lower Oceans')\n", 272 | " fig.legend()\n", 273 | " # plt.yscale('log')\n", 274 | " fig.show()\n", 275 | "\n", 276 | " return values\n", 277 | "\n", 278 | "def 
get_episode_M_AT(episode_states, negotiation_on, plot = 0):\n",
279 | "    state = 'M_i'\n",
280 | "    if negotiation_on:\n",
281 | "        values = episode_states[state][::3,0,0]\n",
282 | "    else:\n",
283 | "        values = episode_states[state][:,0,0]\n",
284 | "\n",
285 | "    if plot:\n",
286 | "        fig = plt.figure()\n",
287 | "        plt.plot(values[:], label='Carbon - Atmosphere')\n",
288 | "        fig.legend()\n",
289 | "        # plt.yscale('log')\n",
290 | "        fig.show()\n",
291 | "\n",
292 | "    return values\n",
293 | "\n",
294 | "def get_episode_M_UP(episode_states, negotiation_on, plot = 0):\n",
295 | "    state = 'M_i'\n",
296 | "    if negotiation_on:\n",
297 | "        values = episode_states[state][::3,0,1]\n",
298 | "    else:\n",
299 | "        values = episode_states[state][:,0,1]\n",
300 | "\n",
301 | "    if plot:\n",
302 | "        fig = plt.figure()\n",
303 | "        plt.plot(values[:], label='Carbon - Upper Strata')\n",
304 | "        fig.legend()\n",
305 | "        # plt.yscale('log')\n",
306 | "        fig.show()\n",
307 | "\n",
308 | "    return values\n",
309 | "\n",
310 | "def get_episode_M_LO(episode_states, negotiation_on, plot = 0):\n",
311 | "    state = 'M_i'\n",
312 | "    if negotiation_on:\n",
313 | "        values = episode_states[state][::3,0,2]\n",
314 | "    else:\n",
315 | "        values = episode_states[state][:,0,2]\n",
316 | "\n",
317 | "    if plot:\n",
318 | "        fig = plt.figure()\n",
319 | "        plt.plot(values[:], label='Carbon - Lower Oceans')\n",
320 | "        fig.legend()\n",
321 | "        # plt.yscale('log')\n",
322 | "        fig.show()\n",
323 | "\n",
324 | "    return values\n",
325 | "\n",
326 | "def get_episode_minMu(episode_states, negotiation_on, plot = 0):\n",
327 | "    state = 'minMu'\n",
328 | "    if negotiation_on:\n",
329 | "        values = episode_states[state][::3,:]\n",
330 | "    else:\n",
331 | "        values = episode_states[state][:,:]\n",
332 | "\n",
333 | "    if plot:\n",
334 | "        for agent in range(len(values[0])):\n",
335 | "            fig = plt.figure()\n",
336 | "            plt.plot(values[:,agent], label='minMu - Agent:' + str(agent))\n",
337 | "            fig.legend()\n",
338 | "            # plt.yscale('log')\n",
339 | "            fig.show()\n",
340 | "\n",
341 | "    return values"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "metadata": {},
348 | "outputs": [],
349 | "source": [
350 | "def get_episode_MuAction(episode_states, negotiation_on, plot = 0):\n",
351 | "    state = 'sampled_actions'\n",
352 | "    if negotiation_on:\n",
353 | "        values = episode_states[state][::3,:, -2]\n",
354 | "    else:\n",
355 | "        values = episode_states[state][:,:, -2]\n",
356 | "\n",
357 | "    if plot:\n",
358 | "        for agent in range(len(values[0])):\n",
359 | "            fig = plt.figure()\n",
360 | "            plt.plot(values[:,agent], label='Mu Action - Agent:' + str(agent))\n",
361 | "            fig.legend()\n",
362 | "            # plt.yscale('log')\n",
363 | "            fig.show()\n",
364 | "\n",
365 | "    return values\n",
366 | "\n",
367 | "def get_episode_SavingAction(episode_states, negotiation_on, plot = 0):\n",
368 | "    state = 'sampled_actions'\n",
369 | "    if negotiation_on:\n",
370 | "        values = episode_states[state][::3,:, -1]\n",
371 | "    else:\n",
372 | "        values = episode_states[state][:,:, -1]\n",
373 | "\n",
374 | "    if plot:\n",
375 | "        for agent in range(len(values[0])):\n",
376 | "            fig = plt.figure()\n",
377 | "            plt.plot(values[:,agent], label='Saving Action - Agent:' + str(agent))\n",
378 | "            fig.legend()\n",
379 | "            # plt.yscale('log')\n",
380 | "            fig.show()\n",
381 | "\n",
382 | "    return values"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {},
389 | "outputs": [],
390 | "source": []
391 | },
392 | {
393 |
"cell_type": "code", 394 | "execution_count": null, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "episode_states['sampled_actions']" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "get_episode_T_AT(episode_states, 1, 1)" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "get_episode_T_LO(episode_states, 1, 1)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "get_episode_minMu(episode_states, 1, 1)" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "get_episode_M_AT(episode_states, 1, 1)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "T_AT_reg_0_neg_on = episode_states['T_i'][:,0,0]\n", 451 | "#T_AT_reg_1_neg_off = episode_states['T_i'][:,1,0]\n", 452 | "\n", 453 | "T_LO_reg_0_neg_on = episode_states['T_i'][:,0,1]\n", 454 | "#T_LO_reg_1_neg_off = episode_states['T_i'][:,1,1]\n", 455 | "\n", 456 | "\n", 457 | "M_AT_reg_0_neg_on = episode_states['M_i'][:,0,0]\n", 458 | "#M_AT_reg_1_neg_off = episode_states['M_i'][:,1,0]\n", 459 | "\n", 460 | "M_UP_reg_0_neg_on = episode_states['M_i'][:,0,1]\n", 461 | "#M_UP_reg_1_neg_off = episode_states['M_i'][:,1,1]\n", 462 | "\n", 463 | "M_LO_reg_0_neg_on = episode_states['M_i'][:,0,2]\n", 464 | "#M_LO_reg_1_neg_off = episode_states['M_i'][:,1,2]\n", 465 | "# episode_states_neg_off = episode_states.copy()\n", 466 | "\n" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": null, 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "episode_states['minMu']" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "import matplotlib.pyplot as plt\n", 485 | "\n", 486 | "fig = plt.figure()\n", 487 | "plt.plot(T_AT_reg_0_neg_on, label='Temperature - Upper Strata - negotiation on')\n", 488 | "fig.legend()\n", 489 | "# plt.yscale('log')\n", 490 | "fig.show()\n", 491 | "\n", 492 | "fig = plt.figure()\n", 493 | "plt.plot(T_LO_reg_0_neg_on, label=\"Temperature - Lower Oceans - negotiation on\")\n", 494 | "fig.legend()\n", 495 | "# plt.yscale('log')\n", 496 | "fig.show()" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": null, 502 | "metadata": {}, 503 | "outputs": [], 504 | "source": [ 505 | "import matplotlib.pyplot as plt\n", 506 | "\n", 507 | "fig = plt.figure()\n", 508 | "plt.plot(T_AT_reg_0_neg_on, label='Temperature - Upper Strata - negotiation on')\n", 509 | "fig.legend()\n", 510 | "# plt.yscale('log')\n", 511 | "fig.show()\n", 512 | "\n", 513 | "fig = plt.figure()\n", 514 | "plt.plot(T_LO_reg_0_neg_on, label=\"Temperature - Lower Oceans - negotiation on\")\n", 515 | "fig.legend()\n", 516 | "# plt.yscale('log')\n", 517 | "fig.show()" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [ 526 | "run_config = dict(\n", 527 | " \n", 528 | " # Environment 
settings\n", 529 | " env = dict( \n", 530 | " negotiation_on=0,\n", 531 | " ),\n", 532 | "\n", 533 | " # Trainer settings\n", 534 | " trainer = dict(\n", 535 | " num_envs = 100, # Number of environment replicas (numbre of GPU blocks used)\n", 536 | " train_batch_size = 10000, # total batch size used for training per iteration (across all the environments)\n", 537 | " num_episodes = 30000, # Total number of episodes to run the training for (can be arbitrarily high!)\n", 538 | " ),\n", 539 | " \n", 540 | " # Policy network settings\n", 541 | " policy = dict(\n", 542 | " regions = dict(\n", 543 | " to_train = True,\n", 544 | " gamma = 0.98, # discount factor\n", 545 | " lr = 0.005, # learning rate\n", 546 | " model = dict( \n", 547 | " type = \"fully_connected\",\n", 548 | " fc_dims = [256, 256], # dimension(s) of the fully connected layers as a list\n", 549 | " model_ckpt_filepath = \"\" # load model parameters from a saved checkpoint (if specified)\n", 550 | " )\n", 551 | " ),\n", 552 | " ),\n", 553 | " \n", 554 | " # Checkpoint saving setting\n", 555 | " saving = dict(\n", 556 | " metrics_log_freq = 10, # How often (in iterations) to print the metrics\n", 557 | " model_params_save_freq = 5000, # How often (in iterations) to save the model parameters\n", 558 | " basedir = \"/tmp\", # base folder used for saving\n", 559 | " name = \"rice\",\n", 560 | " tag = \"example\",\n", 561 | " )\n", 562 | ")" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": null, 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "# Register the environment\n", 572 | "env_registrar = EnvironmentRegistrar()\n", 573 | "this_file_dir = os.path.dirname(os.path.abspath(\"__file__\"))\n", 574 | "env_registrar.add_cuda_env_src_path(\n", 575 | " RiceCuda.name,\n", 576 | " os.path.join(this_file_dir, \"rice_build.cu\")\n", 577 | ")\n", 578 | "\n", 579 | "# cpu_env = EnvWrapper(Rice())\n", 580 | "\n", 581 | "# add_cpu_env = env_registrar.add(device=\"cpu\")\n", 582 | "# add_cpu_env(cpu_env)\n", 583 | "# add_gpu_env = env_registrar.add(device=\"gpu\")\n", 584 | "# add_gpu_env(cpu_env)\n", 585 | "\n", 586 | "# Create a wrapped environment object via the EnvWrapper\n", 587 | "# Ensure that use_cuda is set to True (in order to run on the GPU)\n", 588 | "env_wrapper = EnvWrapper(\n", 589 | " RiceCuda(**run_config[\"env\"]),\n", 590 | " num_envs=run_config[\"trainer\"][\"num_envs\"], \n", 591 | " use_cuda=True,\n", 592 | " env_registrar=env_registrar,\n", 593 | ")\n", 594 | "\n", 595 | "# Agents can share policy models: this dictionary maps policy model names to agent ids.\n", 596 | "policy_tag_to_agent_id_map = {\n", 597 | " \"regions\": [agent_id for agent_id in range(env_wrapper.env.num_agents)],\n", 598 | "}\n", 599 | "\n", 600 | "# Create the trainer object\n", 601 | "trainer = Trainer(\n", 602 | " env_wrapper=env_wrapper,\n", 603 | " config=run_config,\n", 604 | " policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,\n", 605 | ")\n", 606 | "\n", 607 | "# Perform training!\n", 608 | "trainer.train()\n" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "# Please note that any variable registered in rice_cuda.py can be put here\n", 618 | "desired_outputs = [\n", 619 | " \"T_i\", # Temperature\n", 620 | " \"M_i\", # Carbon mass\n", 621 | " \"sampled_actions\",\n", 622 | " \"minMu\"\n", 623 | " ]\n", 624 | "\n", 625 | "episode_states_neg_off = trainer.fetch_episode_states(\n", 626 | " 
desired_outputs\n", 627 | ")\n", 628 | "\n", 629 | "trainer.graceful_close()" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": null, 635 | "metadata": {}, 636 | "outputs": [], 637 | "source": [ 638 | "get_episode_T_AT(episode_states_neg_off, 0, 1)" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": null, 644 | "metadata": {}, 645 | "outputs": [], 646 | "source": [ 647 | "get_episode_M_AT(episode_states_neg_off, 0, 1)" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": null, 653 | "metadata": {}, 654 | "outputs": [], 655 | "source": [ 656 | "T_AT_reg_0_neg_off = episode_states_neg_off['T_i'][:,0,0]\n", 657 | "#T_AT_reg_1_neg_off = episode_states['T_i'][:,1,0]\n", 658 | "\n", 659 | "T_LO_reg_0_neg_off = episode_states_neg_off['T_i'][:,0,1]\n", 660 | "#T_LO_reg_1_neg_off = episode_states['T_i'][:,1,1]\n", 661 | "\n", 662 | "\n", 663 | "M_AT_reg_0_neg_off = episode_states_neg_off['M_i'][:,0,0]\n", 664 | "#M_AT_reg_1_neg_off = episode_states['M_i'][:,1,0]\n", 665 | "\n", 666 | "M_UP_reg_0_neg_off = episode_states_neg_off['M_i'][:,0,1]\n", 667 | "#M_UP_reg_1_neg_off = episode_states['M_i'][:,1,1]\n", 668 | "\n", 669 | "M_LO_reg_0_neg_off = episode_states_neg_off['M_i'][:,0,2]\n", 670 | "#M_LO_reg_1_neg_off = episode_states['M_i'][:,1,2]\n", 671 | "# episode_states_neg_off = episode_states.copy()" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": null, 677 | "metadata": {}, 678 | "outputs": [], 679 | "source": [ 680 | "fig = plt.figure()\n", 681 | "plt.plot(T_AT_reg_0_neg_off, label='Temperature - Upper Strata - negotiation off')\n", 682 | "fig.legend()\n", 683 | "# plt.yscale('log')\n", 684 | "fig.show()\n", 685 | "\n", 686 | "fig = plt.figure()\n", 687 | "plt.plot(T_LO_reg_0_neg_off, label=\"Temperature - Lower Oceans - negotiation off\")\n", 688 | "fig.legend()\n", 689 | "# plt.yscale('log')\n", 690 | "fig.show()" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": null, 696 | "metadata": {}, 697 | "outputs": [], 698 | "source": [] 699 | } 700 | ], 701 | "metadata": { 702 | "kernelspec": { 703 | "display_name": "Python 3 (ipykernel)", 704 | "language": "python", 705 | "name": "python3" 706 | }, 707 | "language_info": { 708 | "codemirror_mode": { 709 | "name": "ipython", 710 | "version": 3 711 | }, 712 | "file_extension": ".py", 713 | "mimetype": "text/x-python", 714 | "name": "python", 715 | "nbconvert_exporter": "python", 716 | "pygments_lexer": "ipython3", 717 | "version": "3.8.10" 718 | } 719 | }, 720 | "nbformat": 4, 721 | "nbformat_minor": 4 722 | } 723 | --------------------------------------------------------------------------------