├── __init__.py
├── tests
│   ├── __init__.py
│   ├── test_train_with_rllib.py
│   └── rice_rllib_test.yaml
├── .dockerignore
├── White_Paper.pdf
├── example_data
│   ├── example.pkl
│   └── example_2.pkl
├── climate_economic_min_max_indices.txt
├── CHANGELOG.md
├── [only for test, pls use your own version if possible] climate_economic_min_max_indices.txt
├── scripts
│   ├── __init__.py
│   ├── fixed_paths.py
│   ├── lint.sh
│   ├── desired_outputs.py
│   ├── run_cpu_gpu_env_consistency_checks.py
│   ├── rice_warpdrive.yaml
│   ├── rice_rllib.yaml
│   ├── create_submission_zip.py
│   ├── torch_models.py
│   ├── train_with_warp_drive.py
│   ├── run_unittests.py
│   ├── train_with_rllib.py
│   └── evaluate_submission.py
├── region_yamls
│   ├── 12.yml
│   ├── 14.yml
│   ├── 16.yml
│   ├── 18.yml
│   ├── 2.yml
│   ├── 5.yml
│   ├── 7.yml
│   ├── 9.yml
│   ├── 11.yml
│   ├── 13.yml
│   ├── 15.yml
│   ├── 17.yml
│   ├── 19.yml
│   ├── 20.yml
│   ├── 4.yml
│   ├── 6.yml
│   ├── 22.yml
│   ├── 25.yml
│   ├── 28.yml
│   ├── 3.yml
│   ├── 21.yml
│   ├── 26.yml
│   ├── 23.yml
│   ├── 24.yml
│   ├── 29.yml
│   ├── 30.yml
│   ├── 27.yml
│   └── default.yml
├── requirements.txt
├── rice_build.cu
├── Dockerfile
├── Dockerfile_CPU
├── README_CPU.md
├── .github
│   └── workflows
│       └── main.yaml
├── CITATION.cff
├── LICENSE.txt
├── Makefile
├── .gitignore
├── MARL_using_RLlib.ipynb
├── rice_cuda.py
├── rice_helpers.py
├── README.md
├── getting_started.ipynb
├── Visualization.ipynb
└── MARL_training.ipynb
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .git
2 | Submissions
3 | *.pdf
4 | core
5 | .tmp
6 |
--------------------------------------------------------------------------------
/White_Paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mila-iqia/climate-cooperation-competition/HEAD/White_Paper.pdf
--------------------------------------------------------------------------------
/example_data/example.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mila-iqia/climate-cooperation-competition/HEAD/example_data/example.pkl
--------------------------------------------------------------------------------
/example_data/example_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mila-iqia/climate-cooperation-competition/HEAD/example_data/example_2.pkl
--------------------------------------------------------------------------------
/climate_economic_min_max_indices.txt:
--------------------------------------------------------------------------------
1 | {"min_ci": 7.006738662719727, "max_ci": 1.2506777048110962, "min_ei": 606.57275390625, "max_ei": 9673.5322265625}
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## Release 1.0 (2022-07-03)
4 |
5 | - Initial public release of RICE-N simulation code, training code, documentation, tutorials, and submission/evaluation scripts.
--------------------------------------------------------------------------------
/[only for test, pls use your own version if possible] climate_economic_min_max_indices.txt:
--------------------------------------------------------------------------------
1 | {"min_ci": 7.006738662719727, "max_ci": 1.2506777048110962, "min_ei": 606.57275390625, "max_ei": 9673.5322265625}
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
--------------------------------------------------------------------------------
/scripts/fixed_paths.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from pathlib import Path
4 | 
5 | # Absolute path to the repository root (the parent of this scripts/ directory).
6 | _path = Path(os.path.abspath(__file__))
7 | PUBLIC_REPO_DIR = str(_path.parent.parent.absolute())
8 | # Make the scripts/ directory importable from anywhere.
9 | sys.path.append(os.path.join(PUBLIC_REPO_DIR, "scripts"))
10 | # print("fixed_paths: Using PUBLIC_REPO_DIR = {}".format(PUBLIC_REPO_DIR))
11 | 
--------------------------------------------------------------------------------
/scripts/lint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # stop the build if there are Python syntax errors or undefined names
3 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
4 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
5 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
6 | 
--------------------------------------------------------------------------------
/region_yamls/12.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 8.405493223457656
3 | xK_0: 3.30354611
4 | xL_0: 68.394527
5 | xL_a: 93.497311
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.1880269001436297
10 | xg_A: 0.10300420806704261
11 | xgamma: 0.3
12 | xl_g: 0.05753057218640376
13 | xsigma_0: 0.5289744017993728
14 |
--------------------------------------------------------------------------------
/region_yamls/14.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 1.92663301826947
3 | xK_0: 1.423908312
4 | xL_0: 284.698846
5 | xL_a: 465.307807
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.24445012982169362
10 | xg_A: 0.1335428337437049
11 | xgamma: 0.3
12 | xl_g: 0.024422285778436918
13 | xsigma_0: 1.220638524516315
14 |
--------------------------------------------------------------------------------
/region_yamls/16.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 4.217213133650901
3 | xK_0: 3.18362519
4 | xL_0: 548.75442
5 | xL_a: 560.054221
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.1703030497846267
10 | xg_A: 0.09485139864239062
11 | xgamma: 0.3
12 | xl_g: 0.08033413573292254
13 | xsigma_0: 0.3019631318655498
14 |
--------------------------------------------------------------------------------
/region_yamls/18.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 2.5248787296824777
3 | xK_0: 1.080409098
4 | xL_0: 69.194146
5 | xL_a: 100.015768
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.3464232696284064
10 | xg_A: 0.0785686384327884
11 | xgamma: 0.3
12 | xl_g: 0.028895835870575235
13 | xsigma_0: 1.0104732880546095
14 |
--------------------------------------------------------------------------------
/region_yamls/2.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 12.157936179442062
3 | xK_0: 2.64167507
4 | xL_0: 38.101107
5 | xL_a: 56.990157
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.13084887390535965
10 | xg_A: 0.06274070897633105
11 | xgamma: 0.3
12 | xl_g: 0.020192884216840113
13 | xsigma_0: 0.35044418275452427
14 |
--------------------------------------------------------------------------------
/region_yamls/5.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 2.480556451289923
3 | xK_0: 0.090493838
4 | xL_0: 94.484285
5 | xL_a: 102.997258
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.20277540450432627
10 | xg_A: 0.20063089079847785
11 | xgamma: 0.3
12 | xl_g: 0.036907187009908436
13 | xsigma_0: 1.6646404809736024
14 |
--------------------------------------------------------------------------------
/region_yamls/7.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 4.135420683261054
3 | xK_0: 1.00243116
4 | xL_0: 103.2943
5 | xL_a: 87.417937
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.1577590297489783
10 | xg_A: 0.12252697231832654
11 | xgamma: 0.3
12 | xl_g: -0.06254962519160859
13 | xsigma_0: 0.6013249328720022
14 |
--------------------------------------------------------------------------------
/region_yamls/9.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 2.7159049409990375
3 | xK_0: 1.0340369
4 | xL_0: 573.818276
5 | xL_a: 681.210099
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.09686897303646223
10 | xg_A: 0.10149076832484585
11 | xgamma: 0.3
12 | xl_g: 0.04313067107477931
13 | xsigma_0: 0.6378326935010085
14 |
--------------------------------------------------------------------------------
/region_yamls/11.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 1.8724419952820714
3 | xK_0: 0.239419592
4 | xL_0: 476.878017
5 | xL_a: 669.593553
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.13880429539557834
10 | xg_A: 0.12202134941105497
11 | xgamma: 0.3
12 | xl_g: 0.034238352160625596
13 | xsigma_0: 0.4559257467059924
14 |
--------------------------------------------------------------------------------
/region_yamls/13.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 3.5579000509140952
3 | xK_0: 0.109143954
4 | xL_0: 64.122372
5 | xL_a: 135.074132
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.16127452284439697
10 | xg_A: 0.12735655631209186
11 | xgamma: 0.3
12 | xl_g: 0.02623933488387354
13 | xsigma_0: 0.8162518983719008
14 |
--------------------------------------------------------------------------------
/region_yamls/15.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 8.111280036435135
3 | xK_0: 0.268152174
4 | xL_0: 28.141422
5 | xL_a: 23.573851
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.16335430971807735
10 | xg_A: 0.10573757974990125
11 | xgamma: 0.3
12 | xl_g: -0.05715547594428186
13 | xsigma_0: 0.29029694003558093
14 |
--------------------------------------------------------------------------------
/region_yamls/17.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 2.4913586566019945
3 | xK_0: 0.043635414
4 | xL_0: 46.488546
5 | xL_a: 59.987638
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.05834835573455604
10 | xg_A: 0.049004053769246436
11 | xgamma: 0.3
12 | xl_g: 0.03709027315241262
13 | xsigma_0: 0.4196283605267465
14 |
--------------------------------------------------------------------------------
/region_yamls/19.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 2.4596628149703816
3 | xK_0: 0.183982308
4 | xL_0: 513.737375
5 | xL_a: 1867.771496
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 1.8390289471577375
10 | xg_A: 0.46217845237530203
11 | xgamma: 0.3
12 | xl_g: 0.017149514576045286
13 | xsigma_0: 0.3103140976545981
14 |
--------------------------------------------------------------------------------
/region_yamls/20.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 0.9929511285910457
3 | xK_0: 0.160199062
4 | xL_0: 522.481879
5 | xL_a: 1830.325243
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.08560686741591728
10 | xg_A: 0.06506072277236097
11 | xgamma: 0.3
12 | xl_g: 0.01902705663391574
13 | xsigma_0: 0.23517024551671273
14 |
--------------------------------------------------------------------------------
/region_yamls/4.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 6.386787299600672
3 | xK_0: 1.094110266
4 | xL_0: 317.880267
5 | xL_a: 287.533185
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.19401001361417486
10 | xg_A: 0.23666124530625898
11 | xgamma: 0.3
12 | xl_g: -0.052512175141607595
13 | xsigma_0: 0.8402859337043421
14 |
--------------------------------------------------------------------------------
/region_yamls/6.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 10.852953800501595
3 | xK_0: 17.553847656
4 | xL_0: 222.891134
5 | xL_a: 168.350837
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.005000000000006631
10 | xg_A: -0.00046726965128841526
11 | xgamma: 0.3
12 | xl_g: -0.011976043898247184
13 | xsigma_0: 0.2851271547872655
14 |
--------------------------------------------------------------------------------
/region_yamls/22.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 29.853559456004625
3 | xK_0: 2.019951041942154
4 | xL_0: 165.75054
5 | xL_a: 216.9269455
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.08802757142538145
10 | xg_A: 0.07541285925058157
11 | xgamma: 0.3
12 | xl_g: -0.0024986057450947508
13 | xsigma_0: 0.25439108584131914
14 |
--------------------------------------------------------------------------------
/region_yamls/25.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 10.922004036104973
3 | xK_0: 0.6059142357084183
4 | xL_0: 705.464681
5 | xL_a: 532.496728
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.09603760623634239
10 | xg_A: 0.16817991622593043
11 | xgamma: 0.3
12 | xl_g: -0.015844877082389193
13 | xsigma_0: 0.7813181890031158
14 |
--------------------------------------------------------------------------------
/region_yamls/28.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 3.1898536850408714
3 | xK_0: 0.1287514001006796
4 | xL_0: 690.0021925
5 | xL_a: 723.512806
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.05389348561714375
10 | xg_A: 0.06812795377170236
11 | xgamma: 0.3
12 | xl_g: -0.012597171762552104
13 | xsigma_0: 0.9487399403167854
14 |
--------------------------------------------------------------------------------
/region_yamls/3.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 13.219587477586199
3 | xK_0: 16.295084052817813
4 | xL_0: 502.409662
5 | xL_a: 445.861101
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.25224119422016716
10 | xg_A: 0.07423569831381745
11 | xgamma: 0.3
12 | xl_g: -0.033398145012670695
13 | xsigma_0: 0.17048017530013193
14 |
--------------------------------------------------------------------------------
/region_yamls/21.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 5.000360862831762
3 | xK_0: 2.289358084859004
4 | xL_0: 165.293239
5 | xL_a: 230.19114338372032
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.18278991377259912
10 | xg_A: 0.07108490122759262
11 | xgamma: 0.3
12 | xl_g: 0.026773049602328805
13 | xsigma_0: 0.4187771240034329
14 |
--------------------------------------------------------------------------------
/region_yamls/26.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 9.633893693772771
3 | xK_0: 0.6076078389971926
4 | xL_0: 465.60668946000004
5 | xL_a: 351.44784048
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.09603760623634239
10 | xg_A: 0.16817991622593043
11 | xgamma: 0.3
12 | xl_g: -0.015844877094273714
13 | xsigma_0: 0.7813181890031158
14 |
--------------------------------------------------------------------------------
/region_yamls/23.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 23.314991608844633
3 | xK_0: 3.0391651447451187
4 | xL_0: 109.39535640000001
5 | xL_a: 143.17178403
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.08802757821836926
10 | xg_A: 0.07541286104378697
11 | xgamma: 0.3
12 | xl_g: -0.002498605753950543
13 | xsigma_0: 0.25439108584131914
14 |
--------------------------------------------------------------------------------
/region_yamls/24.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 29.853559456004625
3 | xK_0: 0.6867833542603324
4 | xL_0: 56.355183600000004
5 | xL_a: 73.75516147
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.08802757142538145
10 | xg_A: 0.07541285925058157
11 | xgamma: 0.3
12 | xl_g: -0.002498605778464897
13 | xsigma_0: 0.25439108584131914
14 |
--------------------------------------------------------------------------------
/region_yamls/29.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 2.033527139192083
3 | xK_0: 0.3810937821808831
4 | xL_0: 455.40144705
5 | xL_a: 477.51845196000005
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.053893489919463196
10 | xg_A: 0.06812795559760368
11 | xgamma: 0.3
12 | xl_g: -0.012597171772754604
13 | xsigma_0: 0.9487399403167854
14 |
--------------------------------------------------------------------------------
/region_yamls/30.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 3.1898536850408714
3 | xK_0: 0.04377547603423107
4 | xL_0: 234.60074545
5 | xL_a: 245.99435404000002
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.05389348561714375
10 | xg_A: 0.06812795377170236
11 | xgamma: 0.3
12 | xl_g: -0.012597171800996251
13 | xsigma_0: 0.9487399403167854
14 |
--------------------------------------------------------------------------------
/region_yamls/27.yml:
--------------------------------------------------------------------------------
1 | _RICE_CONSTANT:
2 | xA_0: 8.620918323265558
3 | xK_0: 0.45330037729157585
4 | xL_0: 239.85799154000003
5 | xL_a: 181.04888752000002
6 | xa_1: 0
7 | xa_2: 0.00236
8 | xa_3: 2
9 | xdelta_A: 0.09603760623634239
10 | xg_A: 0.16817991622593043
11 | xgamma: 0.3
12 | xl_g: -0.015844877127171766
13 | xsigma_0: 0.7813181890031158
14 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | importlib-metadata==4.12.0
2 | flask==2.1.1
3 | gym==0.22.0
4 | pandas==1.3.0
5 | waitress==2.1.1
6 | jupyterlab>=3.4.0
7 | tensorflow==1.13.1
8 | torch==1.9.0
9 | # optional: RLlib and WarpDrive for training
10 | # For CPU
11 | # ray[rllib]==1.0.0
12 | # For GPU
13 | # rl-warp-drive==1.7.0
14 | matplotlib==3.5.3
15 | scikit-learn
16 | numpy==1.21.6
17 | deepdiff==5.8.1
18 | pyyaml==6.0
19 | pytest==7.1.3
20 |
--------------------------------------------------------------------------------
/rice_build.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, salesforce.com, inc and MILA.
2 | // All rights reserved.
3 | // SPDX-License-Identifier: BSD-3-Clause
4 | // For full license text, see the LICENSE file in the repo root
5 | // or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | #ifndef CUDA_INCLUDES_RICE_CONST_H_
9 | #define CUDA_INCLUDES_RICE_CONST_H_
10 |
11 | #include "rice_step.cu"
12 |
13 | #endif // CUDA_INCLUDES_RICE_CONST_H_
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | FROM nvcr.io/nvidia/pytorch:21.08-py3
9 | LABEL description="warpdrive-env"
10 | WORKDIR /home/
11 | RUN chmod a+rwx /home/
12 | # Install other packages
13 | RUN pip3 install pycuda==2021.1
14 |
--------------------------------------------------------------------------------
/Dockerfile_CPU:
--------------------------------------------------------------------------------
1 | #FROM toluclassics/transformers_notebook
2 | FROM jupyter/base-notebook:python-3.7.6
3 | #FROM jupyter/tensorflow-notebook
4 |
5 | ENV TRANSFORMERS_CACHE=/tmp/.cache
6 | ENV TOKENIZERS_PARALLELISM=true
7 |
8 | USER root
9 | RUN apt-get update && apt-get install -y libglib2.0-0
10 | RUN pip3 install --no-cache-dir protobuf==3.20.1 importlib-metadata==4.13.0
11 | RUN pip3 install --no-cache-dir tensorflow==1.13.2 gym==0.21 ray[rllib]==1.0.0 torch==1.9.0
12 | RUN pip3 install --no-cache-dir importlib-resources ale-py~=0.7.1 \
13 | && pip3 install --no-cache-dir MarkupSafe==2.0.1
14 | RUN pip3 install --no-cache-dir scikit-learn matplotlib
15 |
16 | USER ${NB_UID}
17 | WORKDIR "${HOME}/work"
18 |
--------------------------------------------------------------------------------
/README_CPU.md:
--------------------------------------------------------------------------------
1 | # Overview
2 |
3 | You can train and evaluate your model in an isolated Docker container.
4 | This keeps your RL environment separate from your local environment,
5 | which makes your results easier to reproduce.
6 | 
7 | The `Dockerfile_CPU` builds a standalone local Docker image that hosts the
8 | Jupyter notebook server and can also train and evaluate your submission
9 | from the command line.
10 |
11 | # Quick Start
12 |
13 | Build the image and start the notebook server:
14 |
15 | ```
16 | make
17 | ```
18 |
19 | Train a model from the command line:
20 |
21 | ```
22 | make train
23 | ```
24 |
25 | Evaluate the most recent submission:
26 |
27 | ```
28 | make evaluate
29 | ```
30 |
31 | # Requirements
32 |
33 | You need to have Docker installed on your workstation.
34 |
35 |
36 |
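37 | # Evaluating a specific submission
38 | 
39 | The `evaluate` target picks the most recent zip file under `Submissions/` by default.
40 | Assuming GNU make's usual command-line variable precedence, you can point it at a
41 | particular submission instead (the path below is only a placeholder):
42 | 
43 | ```
44 | make evaluate SUB=Submissions/<your_submission>.zip
45 | ```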
--------------------------------------------------------------------------------
/tests/test_train_with_rllib.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import subprocess
4 | from pathlib import Path
5 |
6 | _path = Path(os.path.abspath(__file__))
7 | PUBLIC_REPO_DIR = str(_path.parent.parent.absolute())
8 | os.chdir(PUBLIC_REPO_DIR)
9 |
10 |
11 | def test_training_run():
12 | env = os.environ.copy()
13 | env["CONFIG_FILE"] = "./tests/rice_rllib_test.yaml"
14 | command = subprocess.run(["python", "./scripts/train_with_rllib.py"],
15 | env=env,
16 | capture_output=True)
17 | assert command.returncode == 0
18 | output = command.stdout.decode("utf-8")
19 |
20 | file_reg = re.compile(r'is created at: (.*$)', flags=re.M)
21 | print(output)
22 | match = file_reg.search(output).group(1)
23 | assert os.path.exists(match)
24 |
--------------------------------------------------------------------------------
/.github/workflows/main.yaml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: ["3.7"]
12 |
13 | steps:
14 | - uses: actions/checkout@v3
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v4
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install -r requirements.txt
23 | pip install ray[rllib]==1.0.0
24 | # TODO fix flake8 errors and enable
25 | # - name: Lint with flake8
26 | # run: |
27 | # ./scripts/lint.sh
28 | - name: Test with pytest
29 | run: |
30 | pytest
31 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Zhang"
5 | given-names: "Tianyu"
6 | orcid: "https://orcid.org/0000-0002-4410-1343"
7 | - family-names: "Srinivasa"
8 | given-names: "Sunil"
9 | orcid: "https://orcid.org/0000-0000-0000-0000"
10 | - family-names: "Wiliams"
11 | given-names: "Andrew"
12 | orcid: "https://orcid.org/0000-0000-0000-0000"
13 | - family-names: "Phade"
14 | given-names: "Soham"
15 | orcid: "https://orcid.org/0000-0000-0000-0000"
16 | - family-names: "Zhang"
17 | given-names: "Yang"
18 | orcid: "https://orcid.org/0000-0000-0000-0000"
19 | - family-names: "Gupta"
20 | given-names: "Prateek"
21 | orcid: "https://orcid.org/0000-0002-0892-3518"
22 | - family-names: "Bengio"
23 | given-names: "Yoshua"
24 | orcid: "https://orcid.org/0000-0002-9322-3515"
25 | - family-names: "Zheng"
26 | given-names: "Stephan"
27 | orcid: "https://orcid.org/0000-0002-7271-1616"
28 | title: "RICE-N"
29 | version: 1.0.0
30 | date-released: 2022-07-03
31 | url: "https://github.com/mila-iqia/climate-cooperation-competition"
32 |
--------------------------------------------------------------------------------
/scripts/desired_outputs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | desired_outputs = [
9 | "global_temperature",
10 | "global_carbon_mass",
11 | "capital_all_regions",
12 | "labor_all_regions",
13 | "production_factor_all_regions",
14 | "intensity_all_regions",
15 | "global_exogenous_emissions",
16 | "global_land_emissions",
17 | "timestep",
18 | "activity_timestep",
19 | "capital_depreciation_all_regions",
20 | "savings_all_regions",
21 | "mitigation_rate_all_regions",
22 | "max_export_limit_all_regions",
23 | "mitigation_cost_all_regions",
24 | "damages_all_regions",
25 | "abatement_cost_all_regions",
26 | "utility_all_regions",
27 | "social_welfare_all_regions",
28 | "reward_all_regions",
29 | "consumption_all_regions",
30 | "current_balance_all_regions",
31 | "gross_output_all_regions",
32 | "investment_all_regions",
33 | "production_all_regions",
34 | "tariffs",
35 | "future_tariffs",
36 | "scaled_imports",
37 | "desired_imports",
38 | "tariffed_imports",
39 | "stage",
40 | "minimum_mitigation_rate_all_regions",
41 | "promised_mitigation_rate",
42 | "requested_mitigation_rate",
43 | "proposal_decisions",
44 | ]
45 |
--------------------------------------------------------------------------------
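
The `desired_outputs` list above names the simulation states of interest. A minimal sketch of fetching them from a trained RLlib trainer, assuming the `create_trainer` and `fetch_episode_states` helpers imported in `MARL_using_RLlib.ipynb` (and assuming these names are valid state keys for `fetch_episode_states`):

```
import yaml

from scripts.desired_outputs import desired_outputs
from scripts.train_with_rllib import create_trainer, fetch_episode_states

# Load the RLlib run configuration shipped with the repository.
with open("scripts/rice_rllib.yaml", "r", encoding="utf8") as fp:
    run_config = yaml.safe_load(fp)

# Build the trainer, run a single training iteration, then pull the episode states.
trainer, save_dir = create_trainer(run_config)
trainer.train()
episode_states = fetch_episode_states(trainer, desired_outputs)
```
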
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022, Salesforce.com, Inc and MILA.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | * Neither the name of Salesforce.com or MILA, nor the names of their contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/scripts/run_cpu_gpu_env_consistency_checks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | # Run this via
9 | # !python run_cpu_gpu_env_consistency_checks.py
10 |
11 | import logging
12 | import os
13 | import sys
14 |
15 | import torch
16 |
17 | from fixed_paths import PUBLIC_REPO_DIR
18 | sys.path.append(PUBLIC_REPO_DIR)
19 |
20 | from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU
21 | from warp_drive.utils.env_registrar import EnvironmentRegistrar
22 |
23 | from rice import Rice
24 | from rice_cuda import RiceCuda
25 |
26 | logging.getLogger().setLevel(logging.ERROR)
27 |
28 | _NUM_GPUS_AVAILABLE = torch.cuda.device_count()
29 | assert _NUM_GPUS_AVAILABLE > 0, "This script needs a GPU to run!"
30 |
31 | env_registrar = EnvironmentRegistrar()
32 |
33 | env_registrar.add_cuda_env_src_path(Rice.name, os.path.join(PUBLIC_REPO_DIR, "rice_build.cu"))
34 | env_configs = {
35 | "no_negotiation": {
36 | "num_discrete_action_levels": 100,
37 | "negotiation_on": False,
38 | },
39 | "with_negotiation": {
40 | "num_discrete_action_levels": 100,
41 | "negotiation_on": True,
42 | },
43 | }
44 | testing_class = EnvironmentCPUvsGPU(
45 | cpu_env_class=Rice,
46 | cuda_env_class=RiceCuda,
47 | env_configs=env_configs,
48 | num_envs=2,
49 | num_episodes=2,
50 | use_gpu_testing_mode=False,
51 | env_registrar=env_registrar,
52 | )
53 |
54 | testing_class.test_env_reset_and_step(consistency_threshold_pct=1, seed=17)
55 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PROJECT = climate-coop/cpu
2 | #VERSION ?= $(shell git describe --abbrev=0 --tags)
3 | VERSION ?= latest
4 | IMAGE ?= $(PROJECT):$(VERSION)
5 | DOCKER_VARS = --user $(shell id -u):$(shell id -g) --group-add users -e GRANT_SUDO=yes
6 | MOUNT_VARS = -v "${PWD}":/home/jovyan/work -v ${PWD}/ray_results:/home/jovyan/ray_results
7 | X11_VARS = -e "DISPLAY" -v "/etc/group:/etc/group:ro" -v "/etc/passwd:/etc/passwd:ro" -v "/etc/shadow:/etc/shadow:ro" -v "/etc/sudoers.d:/etc/sudoers.d:ro" -v "/tmp/.X11-unix:/tmp/.X11-unix:rw"
8 | ALL_VARS = $(DOCKER_VARS) $(MOUNT_VARS) $(X11_VARS)
9 |
10 | PYTHONPATH ?= -e PYTHONPATH=scripts
11 |
12 | all: build run
13 |
14 | build:
15 | docker build -t $(IMAGE) -f Dockerfile_CPU .
16 |
17 | run:
18 | docker run -it --rm $(ALL_VARS) -p 8888:8888 $(IMAGE)
19 |
20 | train:
21 | mkdir -p ray_results > /dev/null
22 | docker run -it --rm $(ALL_VARS) $(IMAGE) python scripts/train_with_rllib.py
23 |
24 | evaluate: SUB=$(shell ls -r Submissions/*.zip | tr ' ' '\n' | head -1)
25 | evaluate:
26 | mkdir -p .tmp/_base
27 | docker run -it --rm $(ALL_VARS) $(IMAGE) python scripts/evaluate_submission.py -r $(SUB)
28 |
29 | bash:
30 | docker run -it --rm $(ALL_VARS) $(IMAGE) bash
31 |
32 | python:
33 | docker run -it --rm $(ALL_VARS) $(PYTHONPATH) $(IMAGE) python
34 |
35 | # View on http://localhost:6006
36 | tensorboard:
37 | docker run -it --rm $(ALL_VARS) -p 6006:6006 $(IMAGE) tensorboard --logdir '~/work/ray_results'
38 |
39 | diagnose: LOG=diagnostic.log
40 | diagnose:
41 | @echo "Build clean image (no cache)..." | tee $(LOG)
42 | docker build --no-cache -t $(IMAGE) -f Dockerfile_CPU . | tee -a $(LOG)
43 | @echo "Using Docker: $(shell docker --version)" | tee -a $(LOG)
44 | @echo "Working dir: $(shell pwd)" | tee -a $(LOG)
45 | @echo "User: $(shell id -un)" | tee -a $(LOG)
46 | @echo "In docker group: $(shell id -Gn | tr ' ' '\n' | grep docker)" | tee -a $(LOG)
47 | @echo "Inspect Docker container..." | tee -a $(LOG)
48 | docker run --rm $(ALL_VARS) $(IMAGE) whoami | tee -a $(LOG)
49 | docker run --rm $(ALL_VARS) $(IMAGE) ls -l | tee -a $(LOG)
50 |
51 |
--------------------------------------------------------------------------------
/tests/rice_rllib_test.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 | # This file does a quick run for tests
8 |
9 | # Checkpoint saving setting
10 | saving:
11 | metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics
12 | model_params_save_freq: 1000 # how often (in iterations) to save the model parameters
13 | basedir: "/tmp" # base folder used for saving
14 | name: "rice" # experiment name
15 | tag: "experiments" # experiment tag
16 |
17 | # Trainer settings
18 | trainer:
19 | num_envs: 1
20 | rollout_fragment_length: 100
21 | train_batch_size: 20
22 | num_episodes: 1 # This will cause one iteration only
23 | framework: torch # framework setting.
24 | # Note: RLlib supports TF as well, but our end-to-end pipeline is built for Pytorch only.
25 | # === Hardware Settings ===
26 | num_workers: 1 # number of rollout worker actors to create for parallel sampling.
27 | # Note: Setting the num_workers to 0 will force rollouts to be done in the trainer actor.
28 | num_gpus: 0 # number of GPUs to allocate to the trainer process. This can also be fractional (e.g., 0.3 GPUs).
29 |
30 | # Environment configuration
31 | env:
32 | num_discrete_action_levels: 10 # number of discrete levels for the saving and mitigation actions
33 | negotiation_on: False # flag to indicate whether negotiation is allowed or not
34 |
35 | # Policy network settings
36 | policy:
37 | regions:
38 | vf_loss_coeff: 0.1 # loss coefficient schedule for the value function loss
39 | entropy_coeff_schedule: # loss coefficient schedule for the entropy loss
40 | # piecewise linear, specified as (timestep, value)
41 | - [0, 0.5]
42 | - [1000000, 0.1]
43 | - [5000000, 0.05]
44 | clip_grad_norm: True # flag indicating whether to clip the gradient norm or not
45 | max_grad_norm: 0.5 # when clip_grad_norm is True, the clip level
46 | gamma: 0.92 # discount factor
47 | lr: 0.0005 # learning rate
48 | model:
49 | custom_model: torch_linear
50 | custom_model_config:
51 | fc_dims: [16, 16]
52 |
--------------------------------------------------------------------------------
/scripts/rice_warpdrive.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | # Checkpoint saving setting
9 | saving:
10 | metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics
11 | model_params_save_freq: 1000 # how often (in iterations) to save the model parameters
12 | basedir: "/tmp" # base folder used for saving
13 | name: "rice" # experiment name
14 | tag: "experiments" # experiment tag
15 |
16 | # Trainer settings
17 | trainer:
18 | num_envs: 100 # number of environment replicas
19 | num_episodes: 1000 # number of episodes to run the training for
20 | train_batch_size: 10000 # total batch size used for training per iteration (across all the environments)
21 |
22 | # Environment configuration
23 | env:
24 | num_discrete_action_levels: 10 # number of discrete levels for the saving and mitigation actions
25 | negotiation_on: False # flag to indicate whether negotiation is allowed or not
26 |
27 | # Policy network settings
28 | policy:
29 | regions:
30 | to_train: True # flag indicating whether the model needs to be trained
31 | algorithm: "A2C" # algorithm used to train the policy
32 | vf_loss_coeff: 0.1 # loss coefficient schedule for the value function loss
33 | entropy_coeff: # loss coefficient schedule for the entropy loss
34 | # piecewise linear, specified as (timestep, value)
35 | - [0, 0.5]
36 | - [1000000, 0.1]
37 | - [5000000, 0.05]
38 | clip_grad_norm: True # flag indicating whether to clip the gradient norm or not
39 | max_grad_norm: 0.5 # when clip_grad_norm is True, the clip level
40 | normalize_advantage: False # flag indicating whether to normalize advantage or not
41 | normalize_return: False # flag indicating whether to normalize return or not
42 | gamma: 0.92 # discount factor
43 | lr: 0.0005 # learning rate
44 | model: # policy model settings
45 | type: "fully_connected" # model type
46 | fc_dims: [256, 256] # dimension(s) of the fully connected layers as a list
47 | model_ckpt_filepath: "" # filepath (used to restore a previously saved model)
48 |
--------------------------------------------------------------------------------
/scripts/rice_rllib.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | # Checkpoint saving setting
9 | saving:
10 | metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics
11 | model_params_save_freq: 1000 # how often (in iterations) to save the model parameters
12 | basedir: "/tmp" # base folder used for saving
13 | name: "rice" # experiment name
14 | tag: "experiments" # experiment tag
15 |
16 | # Trainer settings
17 | trainer:
18 | num_envs: 20 # number of environment replicas
19 | rollout_fragment_length: 100 # divide episodes into fragments of this many steps each during rollouts.
20 | train_batch_size: 2000 # total batch size used for training per iteration (across all the environments)
21 | num_episodes: 100 # number of episodes to run the training for
22 | framework: torch # framework setting.
23 | # Note: RLlib supports TF as well, but our end-to-end pipeline is built for Pytorch only.
24 | # === Hardware Settings ===
25 | num_workers: 4 # number of rollout worker actors to create for parallel sampling.
26 | # Note: Setting the num_workers to 0 will force rollouts to be done in the trainer actor.
27 | num_gpus: 0 # number of GPUs to allocate to the trainer process. This can also be fractional (e.g., 0.3 GPUs).
28 |
29 | # Environment configuration
30 | env:
31 | num_discrete_action_levels: 10 # number of discrete levels for the saving and mitigation actions
32 | negotiation_on: False # flag to indicate whether negotiation is allowed or not
33 |
34 | # Policy network settings
35 | policy:
36 | regions:
37 | vf_loss_coeff: 0.1 # loss coefficient schedule for the value function loss
38 | entropy_coeff_schedule: # loss coefficient schedule for the entropy loss
39 | # piecewise linear, specified as (timestep, value)
40 | - [0, 0.5]
41 | - [1000000, 0.1]
42 | - [5000000, 0.05]
43 | clip_grad_norm: True # flag indicating whether to clip the gradient norm or not
44 | max_grad_norm: 0.5 # when clip_grad_norm is True, the clip level
45 | gamma: 0.92 # discount factor
46 | lr: 0.0005 # learning rate
47 | model:
48 | custom_model: torch_linear
49 | custom_model_config:
50 | fc_dims: [256, 256]
51 |
--------------------------------------------------------------------------------
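
The `entropy_coeff_schedule` above is a list of (timestep, value) breakpoints for a piecewise-linear schedule. A minimal, self-contained sketch of how such breakpoints interpolate; illustrative only, not RLlib's internal schedule implementation:

```
# Piecewise-linear interpolation over (timestep, value) breakpoints,
# clamped to the first/last value outside the breakpoint range.
from typing import List, Tuple


def piecewise_linear(schedule: List[Tuple[int, float]], timestep: int) -> float:
    if timestep <= schedule[0][0]:
        return schedule[0][1]
    for (t0, v0), (t1, v1) in zip(schedule[:-1], schedule[1:]):
        if t0 <= timestep <= t1:
            frac = (timestep - t0) / (t1 - t0)
            return v0 + frac * (v1 - v0)
    return schedule[-1][1]


entropy_schedule = [(0, 0.5), (1_000_000, 0.1), (5_000_000, 0.05)]
print(piecewise_linear(entropy_schedule, 500_000))     # 0.3, halfway between 0.5 and 0.1
print(piecewise_linear(entropy_schedule, 10_000_000))  # 0.05, clamped past the last breakpoint
```
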
/region_yamls/default.yml:
--------------------------------------------------------------------------------
1 | _DICE_CONSTANT:
2 | xt_0: 2015 # starting year of the whole model
3 | xDelta: 5 # the time interval (year)
4 | xN: 20 # total time steps
5 |
6 | # Climate diffusion parameters
7 | xPhi_T: [[0.8718, 0.0088], [0.025, 0.975]]
8 | xB_T: [0.1005, 0]
9 | # xB_T: [0.03, 0]
10 |
11 | # Carbon cycle diffusion parameters (the zeta matrix in the paper)
12 | xPhi_M: [[0.88, 0.196, 0], [0.12, 0.797, 0.001465], [0, 0.007, 0.99853488]]
13 | # xB_M: [0.2727272727272727, 0, 0] # 12/44
14 | xB_M: [1.36388, 0, 0] # 12/44
15 |     xeta: 3.6813 # ?? unclear where this is used
16 |
17 | xM_AT_1750: 588 # atmospheric mass of carbon in the year of 1750
18 |     xf_0: 0.5 # in Eq 3, param for the effect of greenhouse gases other than carbon dioxide
19 |     xf_1: 1 # in Eq 3, param for the effect of greenhouse gases other than carbon dioxide
20 |     xt_f: 20 # in Eq 3, time step param for the effect of greenhouse gases other than carbon dioxide
21 |     xE_L0: 2.6 # 2.6 # in Eq 4, param for the emissions due to land-use changes
22 |     xdelta_EL: 0.001 # 0.115 # 0.115 # in Eq 4, param for the emissions due to land-use changes
23 |
24 | xM_AT_0: 851 # in CAP the atmospheric mass of carbon in the year t
25 | xM_UP_0: 460 # in CAP the atmospheric upper bound of mass of carbon in the year t
26 | xM_LO_0: 1740 # in CAP the atmospheric lower bound of mass of carbon in the year t
27 |     xe_0: 35.85 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
28 |     xq_0: 105.5 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
29 |     xmu_0: 0.03 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
30 |
31 | # From Python implementation PyDICE
32 | xF_2x: 3.6813 # 3.6813 # Forcing that doubles equilibrium carbon.
33 | xT_2x: 3.1 # 3.1 # Equilibrium temperature increase at double carbon eq.
34 |
35 | _RICE_CONSTANT:
36 |     xgamma: 0.3 # in CAP Eq 5, the capital elasticity
37 | xtheta_2: 2.6 # in CAP Eq 6
38 | xa_1: 0
39 | xa_2: 0.00236 # in CAP Eq 6
40 | xa_3: 2 # in CAP Eq 6
41 |     xdelta_K: 0.1 # in CAP Eq 9, param describing the depreciation of capital
42 | xalpha: 1.45 # Utility function param
43 | xrho: 0.015 # discount factor of the utility
44 |     xL_0: 7403 # in POP, population at the starting point
45 | xL_a: 11500 # in POP the expected population at convergence
46 |     xl_g: 0.134 # in POP, controls the rate of convergence
47 |     xA_0: 5.115 # in TFP, technology at the starting point
48 |     xg_A: 0.076 # in TFP, controls the rate of technology growth; larger -> faster
49 |     xdelta_A: 0.005 # in TFP, controls the rate of technology growth; smaller -> faster
50 | xsigma_0: 0.3503 # e0/(q0(1-mu0)) in EI emission intensity at the starting point
51 |     xg_sigma: 0.0025 # 0.0152 # 0.0025 in EI, controls the rate of mitigation; larger -> reduces more emissions
52 |     xdelta_sigma: 0.1 # 0.01 in EI, controls the rate of mitigation; larger -> reduces fewer emissions
53 | xp_b: 550 # 550 # in Eq 2 (estimate of the cost of mitigation) represents the price of a backstop technology that can remove carbon dioxide from the atmosphere
54 |     xdelta_pb: 0.001 # 0.025 # in Eq 2, controls how the cost of mitigation changes through time; larger -> costs less as time goes by
55 | xscale_1: 0.030245527 # in Eq 29 Nordhaus scaled cost function param
56 | xscale_2: 10993.704 # in Eq 29 Nordhaus scaled cost function param
57 |
58 | xT_AT_0: 0.85 # in CAP a part of damage function initial condition
59 | xT_LO_0: 0.0068 # in CAP a part of damage function initial condition
60 | xK_0: 223 # in CAP initial condition for capital
--------------------------------------------------------------------------------
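
Each file in `region_yamls/` overrides a subset of the `_RICE_CONSTANT` defaults above with region-specific values. A minimal sketch of that overlay, assuming the data layout shown in these yaml files (an illustration, not the repository's actual loader in `rice_helpers.py`):

```
import yaml

# Read the shared defaults and one region's constants.
with open("region_yamls/default.yml", "r", encoding="utf8") as fp:
    defaults = yaml.safe_load(fp)
with open("region_yamls/2.yml", "r", encoding="utf8") as fp:
    region = yaml.safe_load(fp)

# Overlay: region-specific values take precedence over the defaults.
rice_constants = {**defaults["_RICE_CONSTANT"], **region["_RICE_CONSTANT"]}
print(rice_constants["xA_0"], rice_constants["xgamma"])
```
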
/scripts/create_submission_zip.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | """
9 | Script to create the zipped submission file from the results directory
10 | """
11 | import os
12 | import shutil
13 | import sys
14 | import argparse
15 | from evaluate_submission import validate_dir
16 | import time
17 | from fixed_paths import PUBLIC_REPO_DIR
18 |
19 | sys.path.append(PUBLIC_REPO_DIR)
20 |
21 |
22 | def get_results_dir():
23 | """
24 | Obtain the 'results' directory from the system arguments.
25 | """
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument(
28 | "--results_dir",
29 | "-r",
30 | type=str,
31 | default=".",
32 | help="the directory where all the submission files are saved. Can also be "
33 | "the zipped file containing all the submission files.",
34 | )
35 | args = parser.parse_args()
36 |
37 | if "results_dir" not in args:
38 | raise ValueError(
39 | "Please provide a results directory to evaluate with the argument -r"
40 | )
41 | if not os.path.exists(args.results_dir):
42 | raise ValueError(
43 | "The results directory is missing. Please make sure the correct path "
44 | "is specified!"
45 | )
46 | try:
47 | results_dir = args.results_dir
48 |
49 | # Also handle a zipped file
50 | if results_dir.endswith(".zip"):
51 | unzipped_results_dir = os.path.join("/tmp", str(time.time()))
52 | shutil.unpack_archive(results_dir, unzipped_results_dir)
53 | results_dir = unzipped_results_dir
54 | return results_dir, parser
55 | except Exception as err:
56 | raise ValueError("Cannot obtain the results directory") from err
57 |
58 |
59 | def prepare_submission(results_dir=None):
60 | """
61 |     Validate all the submission files and compress them into a .zip.
62 |     Note: This method is also invoked in the trainer script itself!
63 |     So if you ran the training script, you may not need to re-run this.
64 |     Args: results_dir: the directory where all the training files were saved.
65 | """
66 | assert results_dir is not None
67 | submission_filename = results_dir.split("/")[-1]
68 | submission_file = os.path.join(PUBLIC_REPO_DIR, "Submissions", submission_filename)
69 |
70 | validate_dir(results_dir)
71 |
72 | # Only copy the latest policy model file for submission
73 | results_dir_copy = os.path.join("/tmp", "_copies_", submission_filename)
74 | shutil.copytree(results_dir, results_dir_copy)
75 |
76 | policy_models = [
77 | os.path.join(results_dir, file)
78 | for file in os.listdir(results_dir)
79 | if file.endswith(".state_dict")
80 | ]
81 | sorted_policy_models = sorted(policy_models, key=os.path.getmtime)
82 | # Delete all but the last policy model file
83 | for policy_model in sorted_policy_models[:-1]:
84 | os.remove(os.path.join(results_dir_copy, policy_model.split("/")[-1]))
85 |
86 | shutil.make_archive(submission_file, "zip", results_dir_copy)
87 | print("NOTE: The submission file is created at:", submission_file + ".zip")
88 | shutil.rmtree(results_dir_copy)
89 |
90 |
91 | if __name__ == "__main__":
92 | prepare_submission(results_dir=get_results_dir()[0])
93 |
--------------------------------------------------------------------------------
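
A typical invocation of the script above from the repository root, pointing `-r` at a results directory (or a zip of one) produced by a training run; the path is only a placeholder:

```
python scripts/create_submission_zip.py -r <path/to/results_dir>
```
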
/.gitignore:
--------------------------------------------------------------------------------
1 | # See http://help.github.com/ignore-files/ for more about ignoring files.
2 |
3 | # compiled output
4 | /dist
5 | /tmp
6 | /out-tsc
7 |
8 | data
9 |
10 | # Runtime data
11 | pids
12 | *.pid
13 | *.seed
14 | *.pid.lock
15 |
16 | # Directory for instrumented libs generated by jscoverage/JSCover
17 | lib-cov
18 |
19 | # Coverage directory used by tools like istanbul
20 | coverage
21 |
22 | # nyc test coverage
23 | .nyc_output
24 |
25 | # Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
26 | .grunt
27 |
28 | # Bower dependency directory (https://bower.io/)
29 | bower_components
30 |
31 | # node-waf configuration
32 | .lock-wscript
33 |
34 | # IDEs and editors
35 | .idea
36 | .project
37 | .classpath
38 | .c9/
39 | *.launch
40 | .settings/
41 | *.sublime-workspace
42 |
43 | # IDE - VSCode
44 | .vscode/*
45 | !.vscode/settings.json
46 | !.vscode/tasks.json
47 | !.vscode/launch.json
48 | !.vscode/extensions.json
49 |
50 | # misc
51 | .sass-cache
52 | connect.lock
53 | typings
54 |
55 | # Logs
56 | logs
57 | *.log
58 | npm-debug.log*
59 | yarn-debug.log*
60 | yarn-error.log*
61 |
62 |
63 | # Dependency directories
64 | node_modules/
65 | jspm_packages/
66 |
67 | # Optional npm cache directory
68 | .npm
69 |
70 | # Optional eslint cache
71 | .eslintcache
72 |
73 | # Optional REPL history
74 | .node_repl_history
75 |
76 | # Output of 'npm pack'
77 | *.tgz
78 |
79 | # Yarn Integrity file
80 | .yarn-integrity
81 |
82 | # dotenv environment variables file
83 | .env
84 |
85 | # next.js build output
86 | .next
87 |
88 | # Lerna
89 | lerna-debug.log
90 |
91 | # System Files
92 | .DS_Store
93 | Thumbs.db
94 |
95 |
96 | # Byte-compiled / optimized / DLL files
97 | __pycache__/
98 | *.py[cod]
99 | *$py.class
100 |
101 | # C extensions
102 | *.so
103 |
104 | # Distribution / packaging
105 | .Python
106 | build/
107 | develop-eggs/
108 | dist/
109 | downloads/
110 | eggs/
111 | .eggs/
112 | lib/
113 | lib64/
114 | parts/
115 | sdist/
116 | var/
117 | wheels/
118 | pip-wheel-metadata/
119 | share/python-wheels/
120 | *.egg-info/
121 | .installed.cfg
122 | *.egg
123 | MANIFEST
124 |
125 | # PyInstaller
126 | # Usually these files are written by a python script from a template
127 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
128 | *.manifest
129 | *.spec
130 |
131 | # Installer logs
132 | pip-log.txt
133 | pip-delete-this-directory.txt
134 |
135 | # Unit test / coverage reports
136 | htmlcov/
137 | .tox/
138 | .nox/
139 | .coverage
140 | .coverage.*
141 | .cache
142 | nosetests.xml
143 | coverage.xml
144 | *.cover
145 | *.py,cover
146 | .hypothesis/
147 | .pytest_cache/
148 |
149 | # Translations
150 | *.mo
151 | *.pot
152 |
153 | # Django stuff:
154 | *.log
155 | local_settings.py
156 | db.sqlite3
157 | db.sqlite3-journal
158 |
159 | # Flask stuff:
160 | instance/
161 | .webassets-cache
162 |
163 | # Scrapy stuff:
164 | .scrapy
165 |
166 | # Sphinx documentation
167 | docs/_build/
168 |
169 | # PyBuilder
170 | target/
171 |
172 | # Jupyter Notebook
173 | .ipynb_checkpoints
174 |
175 | # IPython
176 | profile_default/
177 | ipython_config.py
178 |
179 | # pyenv
180 | .python-version
181 |
182 | # pipenv
183 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
184 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
185 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
186 | # install all needed dependencies.
187 | #Pipfile.lock
188 |
189 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
190 | __pypackages__/
191 |
192 | # Celery stuff
193 | celerybeat-schedule
194 | celerybeat.pid
195 |
196 | # SageMath parsed files
197 | *.sage.py
198 |
199 | # Environments
200 | .env
201 | .venv
202 | env/
203 | venv/
204 | ENV/
205 | env.bak/
206 | venv.bak/
207 |
208 | # Spyder project settings
209 | .spyderproject
210 | .spyproject
211 |
212 | # Rope project settings
213 | .ropeproject
214 |
215 | # mkdocs documentation
216 | /site
217 |
218 | # mypy
219 | .mypy_cache/
220 | .dmypy.json
221 | dmypy.json
222 |
223 | # Pyre type checker
224 | .pyre/
225 | Submissions/*.zip
226 | emissions.csv
227 | Colab_Tutorial-Copy1.ipynb
228 | Draw_graphs.ipynb
229 | 100_rounds.pkl
230 | *.pdf
231 | core
232 | .tmp
233 | Submissions/*
234 | .vscode/*
235 | *.pkl
236 | scripts/climate_economic_min_max_indices.txt
237 | scripts/dev.ipynb
238 | scripts/dev.py
239 | scripts/eval_by_logs.py
240 |
--------------------------------------------------------------------------------
/MARL_using_RLlib.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "56cb26e3",
6 | "metadata": {},
7 | "source": [
8 | "Copyright (c) 2022, salesforce.com, inc and MILA. \n",
9 | "All rights reserved. \n",
10 | "SPDX-License-Identifier: BSD-3-Clause \n",
11 | "For full license text, see the LICENSE file in the repo root \n",
12 | "or https://opensource.org/licenses/BSD-3-Clause "
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "7635e695",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import logging\n",
23 | "import numpy as np\n",
24 | "import os\n",
25 | "import shutil\n",
26 | "import subprocess\n",
27 | "import sys\n",
28 | "import time\n",
29 | "import yaml\n",
30 | "\n",
31 | "from scripts.train_with_rllib import create_trainer, fetch_episode_states, load_model_checkpoints"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "id": "a1e40c9c",
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "# Set logger level e.g., DEBUG, INFO, WARNING, ERROR.\n",
42 | "logging.getLogger().setLevel(logging.ERROR)"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "id": "4f636cf3",
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "# Needed to perform this install when the system threw the lib.so file missing error\n",
53 | "# ! apt-get install libglib2.0-0 --yes"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "id": "6b952acb",
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "print(\"Training with RLlib...\")\n",
64 | "# Read the run configurations specific to the environment.\n",
65 | "# Note: The run config yaml(s) can be edited at warp_drive/training/run_configs\n",
66 | "# -----------------------------------------------------------------------------\n",
67 | "config_path = os.path.join(\"scripts\", \"rice_rllib.yaml\")\n",
68 | "if not os.path.exists(config_path):\n",
69 | " raise ValueError(\n",
70 | " \"The run configuration is missing. Please make sure the correct path\"\n",
71 | " \"is specified.\"\n",
72 | " )\n",
73 | "\n",
74 | "with open(config_path, \"r\", encoding=\"utf8\") as fp:\n",
75 | " run_config = yaml.safe_load(fp)\n",
76 | "\n",
77 | "# Create trainer\n",
78 | "# --------------\n",
79 | "trainer, save_dir = create_trainer(run_config)\n",
80 | "\n",
81 | "# Copy the source files into the results directory\n",
82 | "# ------------------------------------------------\n",
83 | "os.makedirs(save_dir)\n",
84 | "for file in [\n",
85 | " \"rice.py\",\n",
86 | "]:\n",
87 | " shutil.copyfile(\n",
88 | " os.path.join(file),\n",
89 | " os.path.join(save_dir, file),\n",
90 | " )\n",
91 | "# Add an identifier file\n",
92 | "with open(os.path.join(save_dir, \".rllib\"), \"x\", encoding=\"utf-8\") as fp:\n",
93 | " pass\n",
94 | "fp.close()"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "id": "bf5ea093",
100 | "metadata": {},
101 | "source": [
102 | "### Invoke training"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "id": "10564c91",
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "NUM_ITERS = 5\n",
113 | "for iter in range(NUM_ITERS):\n",
114 | " result = trainer.train()\n",
115 | "print(result)"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "id": "d31b0cf0",
121 | "metadata": {},
122 | "source": [
123 | "### Fetch episode states"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "id": "23cb91ee",
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "outputs = fetch_episode_states(trainer, [\"T_i\", \"carbon_mass_i\", \"capital_i\"])"
134 | ]
135 | }
136 | ],
137 | "metadata": {
138 | "kernelspec": {
139 | "display_name": "Python 3 (ipykernel)",
140 | "language": "python",
141 | "name": "python3"
142 | },
143 | "language_info": {
144 | "codemirror_mode": {
145 | "name": "ipython",
146 | "version": 3
147 | },
148 | "file_extension": ".py",
149 | "mimetype": "text/x-python",
150 | "name": "python",
151 | "nbconvert_exporter": "python",
152 | "pygments_lexer": "ipython3",
153 | "version": "3.8.10"
154 | }
155 | },
156 | "nbformat": 4,
157 | "nbformat_minor": 5
158 | }
159 |
--------------------------------------------------------------------------------
/rice_cuda.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | """
9 | CUDA version of the Regional Integrated model of Climate and the Economy (RICE).
10 | This subclasses the python version of the model and also the CUDAEnvironmentContext
11 | for running with WarpDrive (https://github.com/salesforce/warp-drive)
12 | """
13 |
14 | import os
15 | import sys
16 | import numpy as np
17 | from warp_drive.utils.constants import Constants
18 | from warp_drive.utils.data_feed import DataFeed
19 | from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext
20 |
21 | _PUBLIC_REPO_DIR = os.path.dirname(os.path.abspath(__file__))
22 | sys.path = [_PUBLIC_REPO_DIR] + sys.path
23 |
24 | from rice import Rice
25 |
26 | _OBSERVATIONS = Constants.OBSERVATIONS
27 | _ACTIONS = Constants.ACTIONS
28 | _REWARDS = Constants.REWARDS
29 |
30 |
31 | class RiceCuda(Rice, CUDAEnvironmentContext):
32 | """
33 | Rice env class that invokes the CUDA step function.
34 | """
35 |
36 | name = "Rice"
37 |
38 | def get_data_dictionary(self):
39 | """
40 | Create a dictionary of data to push to the device.
41 | """
42 | data_feed = DataFeed()
43 |
44 | # Add constants
45 | for key, value in sorted(self.dice_constant.items()):
46 | data_feed.add_data(name=key, data=value)
47 |
48 | for key, value in sorted(self.rice_constant.items()):
49 | data_feed.add_data(name=key, data=value)
50 |
51 | # Add all the global states at timestep 0.
52 | timestep = 0
53 | for key in sorted(self.global_state.keys()):
54 | data_feed.add_data(
55 | name=key,
56 | data=self.global_state[key]["value"][timestep],
57 | save_copy_and_apply_at_reset=True,
58 | )
59 |
60 | for key in sorted(self.global_state.keys()):
61 | data_feed.add_data(
62 | name=key + "_norm",
63 | data=self.global_state[key]["norm"],
64 | )
65 |
66 |         # Auxiliary carbon-mass variable (aux_m), one value per region
67 | data_feed.add_data(
68 | name="aux_ms",
69 | data=np.zeros(self.num_regions, dtype=np.float32),
70 | save_copy_and_apply_at_reset=True,
71 | )
72 |
73 | # Env config parameters
74 | data_feed.add_data(
75 | name="num_discrete_action_levels",
76 | data=self.num_discrete_action_levels,
77 | )
78 |
79 | data_feed.add_data(
80 | name="balance_interest_rate",
81 | data=self.balance_interest_rate,
82 | )
83 |
84 | data_feed.add_data(name="negotiation_on", data=self.negotiation_on)
85 |
86 | # Armington agg. parameters
87 | data_feed.add_data_list(
88 | [
89 | ("sub_rate", self.sub_rate),
90 | ("dom_pref", self.dom_pref),
91 | ("for_pref", self.for_pref),
92 | ]
93 | )
94 |
95 | # Year parameters
96 | data_feed.add_data_list(
97 | [("current_year", self.current_year, True), ("end_year", self.end_year)]
98 | )
99 |
100 | return data_feed
101 |
102 | @staticmethod
103 | def get_tensor_dictionary():
104 | """
105 | Create a dictionary of pytorch-accessible tensors to push to the device.
106 | """
107 | tensor_dict = DataFeed()
108 | return tensor_dict
109 |
110 | def step(self):
111 | constants_keys = list(sorted(self.dice_constant.keys())) + list(
112 | sorted(self.rice_constant.keys())
113 | )
114 | args = (
115 | constants_keys
116 | + list(sorted(self.global_state.keys()))
117 | + [key + "_norm" for key in list(sorted(self.global_state.keys()))]
118 | + [
119 | "num_discrete_action_levels",
120 | "balance_interest_rate",
121 | "negotiation_on",
122 | "aux_ms",
123 | "sub_rate",
124 | "dom_pref",
125 | "for_pref",
126 | "current_year",
127 | "end_year",
128 | _OBSERVATIONS + "_features",
129 | _OBSERVATIONS + "_action_mask",
130 | _ACTIONS,
131 | _REWARDS,
132 | "_done_",
133 | "_timestep_",
134 | ("n_agents", "meta"),
135 | ("episode_length", "meta"),
136 | ]
137 | )
138 |
139 | self.cuda_step(
140 | *self.cuda_step_function_feed(args),
141 | block=self.cuda_function_manager.block,
142 | grid=self.cuda_function_manager.grid,
143 | )
144 |
--------------------------------------------------------------------------------
/scripts/torch_models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | """
9 | Custom Pytorch policy models to use with RLlib.
10 | """
11 |
12 | # API reference:
13 | # https://docs.ray.io/en/latest/rllib/rllib-models.html#custom-pytorch-models
14 |
15 | import numpy as np
16 | from gym.spaces import Box, Dict
17 | from ray.rllib.models import ModelCatalog
18 | from ray.rllib.models.modelv2 import restore_original_dimensions
19 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
20 | from ray.rllib.utils import try_import_torch
21 | from ray.rllib.utils.annotations import override
22 |
23 | torch, nn = try_import_torch()
24 |
25 | _ACTION_MASK = "action_mask"
26 |
27 |
28 | class TorchLinear(TorchModelV2, nn.Module):
29 | """
30 | Fully-connected Pytorch policy model.
31 | """
32 |
33 | custom_name = "torch_linear"
34 |
35 | def __init__(
36 | self, obs_space, action_space, num_outputs, model_config, name, fc_dims=None
37 | ):
38 | super().__init__(obs_space, action_space, num_outputs, model_config, name)
39 | nn.Module.__init__(self)
40 |
41 | if fc_dims is None:
42 | fc_dims = [256, 256]
43 |
44 | # Check Observation spaces
45 |
46 | self.observation_space = obs_space.original_space
47 |
48 | if not isinstance(self.observation_space, Dict):
49 | if isinstance(self.observation_space, Box):
50 | raise TypeError(
51 |                     f"({name}) Observation space should be a gym Dict. "
52 |                     f"Is a Box of shape {self.observation_space.shape} instead."
53 | )
54 | raise TypeError(
55 | f"({name}) Observation space should be a gym Dict. "
56 |                 f"Is {type(self.observation_space)} instead."
57 | )
58 |
59 | flattened_obs_size = self.get_flattened_obs_size()
60 |
61 | # Model only outputs policy logits,
62 | # values are accessed via the self.value_function
63 | self.values = None
64 |
65 | num_fc_layers = len(fc_dims)
66 |
67 | input_dims = [flattened_obs_size] + fc_dims[:-1]
68 | output_dims = fc_dims
69 |
70 | self.fc_dict = nn.ModuleDict()
71 | for fc_layer in range(num_fc_layers):
72 | self.fc_dict[str(fc_layer)] = nn.Sequential(
73 | nn.Linear(input_dims[fc_layer], output_dims[fc_layer]),
74 | nn.ReLU(),
75 | )
76 |
77 | # policy network (list of heads)
78 | policy_heads = [None for _ in range(len(action_space))]
79 | self.output_dims = [] # Network output dimension(s)
80 |
81 | for idx, act_space in enumerate(action_space):
82 | output_dim = act_space.n
83 | self.output_dims += [output_dim]
84 | policy_heads[idx] = nn.Linear(fc_dims[-1], output_dim)
85 | self.policy_head = nn.ModuleList(policy_heads)
86 |
87 | # value-function network head
88 | self.vf_head = nn.Linear(fc_dims[-1], 1)
89 |
90 | # used for action masking
91 | self.action_mask = None
92 |
93 | def get_flattened_obs_size(self):
94 | """Get the total size of the observation after flattening."""
95 | if isinstance(self.observation_space, Box):
96 | obs_size = np.prod(self.observation_space.shape)
97 | elif isinstance(self.observation_space, Dict):
98 | obs_size = 0
99 | for key in sorted(self.observation_space):
100 | if key == _ACTION_MASK:
101 | pass
102 | else:
103 | obs_size += np.prod(self.observation_space[key].shape)
104 | else:
105 | raise NotImplementedError("Observation space must be of Box or Dict type")
106 | return int(obs_size)
107 |
108 | def get_flattened_obs(self, obs):
109 | """Get the flattened observation (ignore the action masks)."""
110 | if isinstance(self.observation_space, Box):
111 |             return self.reshape_and_flatten_obs(obs)
112 | if isinstance(self.observation_space, Dict):
113 | flattened_obs_dict = {}
114 | for key in sorted(self.observation_space):
115 | assert key in obs
116 | if key == _ACTION_MASK:
117 | self.action_mask = self.reshape_and_flatten_obs(obs[key])
118 | else:
119 | flattened_obs_dict[key] = self.reshape_and_flatten_obs(obs[key])
120 | flattened_obs = torch.cat(list(flattened_obs_dict.values()), dim=-1)
121 | return flattened_obs
122 | raise NotImplementedError("Observation space must be of Box or Dict type")
123 |
124 | @staticmethod
125 | def reshape_and_flatten_obs(obs):
126 | """Flatten observation."""
127 | assert len(obs.shape) >= 2
128 | batch_dim = obs.shape[0]
129 | return obs.reshape(batch_dim, -1)
130 |
131 | @override(TorchModelV2)
132 | def value_function(self):
133 | """Returns the estimated value function."""
134 | return self.values.reshape(-1)
135 |
136 | @staticmethod
137 | def apply_logit_mask(logits, mask):
138 | """
139 | Mask values of 1 are valid actions.
140 | Add huge negative values to logits with 0 mask values.
141 | """
142 | logit_mask = torch.ones_like(logits) * -10000000
143 | logit_mask = logit_mask * (1 - mask)
144 | return logits + logit_mask
145 |
146 | @override(TorchModelV2)
147 | def forward(self, input_dict, state, seq_lens):
148 |         """Forward pass: compute masked action logits and the value function estimate."""
149 | if isinstance(seq_lens, np.ndarray):
150 | seq_lens = torch.Tensor(seq_lens).int()
151 |
152 | # Note: restoring original obs
153 | # as RLlib does not seem to be doing it automatically!
154 | original_obs = restore_original_dimensions(
155 | input_dict["obs"], self.obs_space.original_space, "torch"
156 | )
157 |
158 | obs = self.get_flattened_obs(original_obs)
159 |
160 | # Feed through the FC layers
161 | for layer in range(len(self.fc_dict)):
162 | output = self.fc_dict[str(layer)](obs)
163 | obs = output
164 | logits = output
165 |
166 | # Compute the action probabilities and the value function estimate
167 | # Apply action mask to the logits as well.
168 | action_masks = [None for _ in range(len(self.output_dims))]
169 | if self.action_mask is not None:
170 | start = 0
171 | for idx, dim in enumerate(self.output_dims):
172 | action_masks[idx] = self.action_mask[..., start : start + dim]
173 | start = start + dim
174 | action_logits = [
175 | self.apply_logit_mask(ph(logits), action_masks[idx])
176 | for idx, ph in enumerate(self.policy_head)
177 | ]
178 | self.values = self.vf_head(logits)[..., 0]
179 |
180 | concatenated_action_logits = torch.cat(action_logits, dim=-1)
181 | return torch.reshape(concatenated_action_logits, [-1, self.num_outputs]), state
182 |
183 |
184 | ModelCatalog.register_custom_model("torch_linear", TorchLinear)
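# Usage note (illustrative): once registered above, this model can be selected
# in an RLlib policy configuration via the standard custom-model hook, e.g.
#     "model": {"custom_model": TorchLinear.custom_name}
# Where exactly this is specified depends on the training script and yaml config.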
185 |
--------------------------------------------------------------------------------
/rice_helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | """
9 | Helper functions for the rice simulation
10 | """
11 | import os
12 |
13 | import numpy as np
14 | import yaml
15 |
16 | _SMALL_NUM = 1e-0 # small number added to handle consumption blow-up
17 |
18 |
19 | # Load calibration data from yaml files
20 | def read_yaml_data(yaml_file):
21 | """Helper function to read yaml configuration data."""
22 | with open(yaml_file, "r", encoding="utf-8") as file_ptr:
23 | file_data = file_ptr.read()
24 | file_ptr.close()
25 | data = yaml.load(file_data, Loader=yaml.FullLoader)
26 | return data
27 |
28 |
29 | def set_rice_params(yamls_folder=None):
30 | """Helper function to read yaml data and set environment configs."""
31 | assert yamls_folder is not None
32 | dice_params = read_yaml_data(os.path.join(yamls_folder, "default.yml"))
33 |     file_list = sorted(os.listdir(yamls_folder))
34 | yaml_files = []
35 | for file in file_list:
36 | if file[-4:] == ".yml" and file != "default.yml":
37 | yaml_files.append(file)
38 |
39 | rice_params = []
40 | for file in yaml_files:
41 | rice_params.append(read_yaml_data(os.path.join(yamls_folder, file)))
42 |
43 | # Overwrite rice params
44 | num_regions = len(rice_params)
45 | for k in dice_params["_RICE_CONSTANT"].keys():
46 | dice_params["_RICE_CONSTANT"][k] = [
47 | dice_params["_RICE_CONSTANT"][k]
48 | ] * num_regions
49 | for idx, param in enumerate(rice_params):
50 | for k in param["_RICE_CONSTANT"].keys():
51 | dice_params["_RICE_CONSTANT"][k][idx] = param["_RICE_CONSTANT"][k]
52 |
53 | return dice_params, num_regions
54 |
55 |
56 | # RICE dynamics
57 | def get_mitigation_cost(p_b, theta_2, delta_pb, timestep, intensity):
58 | """Obtain the cost for mitigation."""
59 | return p_b / (1000 * theta_2) * pow(1 - delta_pb, timestep - 1) * intensity
60 |
61 |
62 | def get_exogenous_emissions(f_0, f_1, t_f, timestep):
63 | """Obtain the amount of exogeneous emissions."""
64 | return f_0 + min(f_1 - f_0, (f_1 - f_0) / t_f * (timestep - 1))
65 |
66 |
67 | def get_land_emissions(e_l0, delta_el, timestep, num_regions):
68 | """Obtain the amount of land emissions."""
69 | return e_l0 * pow(1 - delta_el, timestep - 1)/num_regions
70 |
71 |
72 | def get_production(production_factor, capital, labor, gamma):
73 | """Obtain the amount of goods produced."""
74 | return production_factor * pow(capital, gamma) * pow(labor / 1000, 1 - gamma)
75 |
76 |
77 | def get_damages(t_at, a_1, a_2, a_3):
78 | """Obtain damages."""
79 | return 1 / (1 + a_1 * t_at + a_2 * pow(t_at, a_3))
80 |
81 |
82 | def get_abatement_cost(mitigation_rate, mitigation_cost, theta_2):
83 | """Compute the abatement cost."""
84 | return mitigation_cost * pow(mitigation_rate, theta_2)
85 |
86 |
87 | def get_gross_output(damages, abatement_cost, production):
88 | """Compute the gross production output, taking into account
89 | damages and abatement cost."""
90 | return damages * (1 - abatement_cost) * production
91 |
92 |
93 | def get_investment(savings, gross_output):
94 | """Obtain the investment cost."""
95 | return savings * gross_output
96 |
97 |
98 | def get_consumption(gross_output, investment, exports):
99 | """Obtain the consumption cost."""
100 | total_exports = np.sum(exports)
101 | assert gross_output - investment - total_exports > -1e-5, "consumption cannot be negative!"
102 | return max(0.0, gross_output - investment - total_exports)
103 |
104 |
105 | def get_max_potential_exports(x_max, gross_output, investment):
106 | """Determine the maximum potential exports."""
107 | if x_max * gross_output <= gross_output - investment:
108 | return x_max * gross_output
109 | return gross_output - investment
110 |
111 |
112 | def get_capital_depreciation(x_delta_k, x_delta):
113 | """Compute the global capital depreciation."""
114 | return pow(1 - x_delta_k, x_delta)
115 |
116 |
117 | def get_global_temperature(
118 | phi_t, temperature, b_t, f_2x, m_at, m_at_1750, exogenous_emissions
119 | ):
120 | """Get the temperature levels."""
121 | return np.dot(phi_t, temperature) + np.dot(
122 | b_t, f_2x * np.log(m_at / m_at_1750) / np.log(2) + exogenous_emissions
123 | )
124 |
125 |
126 | def get_aux_m(intensity, mitigation_rate, production, land_emissions):
127 | """Auxiliary variable to denote carbon mass levels."""
128 | return intensity * (1 - mitigation_rate) * production + land_emissions
129 |
130 |
131 | def get_global_carbon_mass(phi_m, carbon_mass, b_m, aux_m):
132 | """Get the carbon mass level."""
133 | return np.dot(phi_m, carbon_mass) + np.dot(b_m, aux_m)
134 |
135 |
136 | def get_capital(capital_depreciation, capital, delta, investment):
137 | """Evaluate capital."""
138 | return capital_depreciation * capital + delta * investment
139 |
140 |
141 | def get_labor(labor, l_a, l_g):
142 | """Compute total labor."""
143 | return labor * pow((1 + l_a) / (1 + labor), l_g)
144 |
145 |
146 | def get_production_factor(production_factor, g_a, delta_a, delta, timestep):
147 | """Compute the production factor."""
148 | return production_factor * (
149 | np.exp(0.0033) + g_a * np.exp(-delta_a * delta * (timestep - 1))
150 | )
151 |
152 |
153 | def get_carbon_intensity(intensity, g_sigma, delta_sigma, delta, timestep):
154 | """Determine the carbon emission intensity."""
155 | return intensity * np.exp(
156 | -g_sigma * pow(1 - delta_sigma, delta * (timestep - 1)) * delta
157 | )
158 |
159 |
160 | def get_utility(labor, consumption, alpha):
161 | """Obtain the utility."""
162 | return (
163 | (labor / 1000.0)
164 | * (pow(consumption / (labor / 1000.0) + _SMALL_NUM, 1 - alpha) - 1)
165 | / (1 - alpha)
166 | )
167 |
168 |
169 | def get_social_welfare(utility, rho, delta, timestep):
170 | """Compute social welfare"""
171 | return utility / pow(1 + rho, delta * timestep)
172 |
173 |
174 | def get_armington_agg(
175 | c_dom,
176 | c_for, # np.array
177 | sub_rate=0.5, # in (0,1)
178 | dom_pref=0.5, # in [0,1]
179 | for_pref=None, # np.array
180 | ):
181 | """
182 |     Armington aggregate from Lessmann, 2009.
183 | Consumption goods from different regions act as imperfect substitutes.
184 | As such, consumption of domestic and foreign goods are scaled according to
185 | relative preferences, as well as a substitution rate, which are modeled
186 | by a CES functional form.
187 | Inputs :
188 | `C_dom` : A scalar representing domestic consumption. The value of
189 | C_dom is what is left over from initial production after
190 | investment and exports are deducted.
191 |         `C_for` : An array representing foreign consumption. Each element
192 | is the consumption imported from a given country.
193 | `sub_rate` : A substitution parameter in (0,1). The elasticity of
194 | substitution is 1 / (1 - sub_rate).
195 | `dom_pref` : A scalar in [0,1] representing the relative preference for
196 | domestic consumption over foreign consumption.
197 | `for_pref` : An array of the same size as `C_for`. Each element is the
198 | relative preference for foreign goods from that country.
199 | """
200 |
201 | c_dom_pref = dom_pref * (c_dom ** sub_rate)
202 | c_for_pref = np.sum(for_pref * pow(c_for, sub_rate))
203 |
204 | c_agg = (c_dom_pref + c_for_pref) ** (1 / sub_rate) # CES function
205 | return c_agg
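# Worked example (illustrative): with c_dom=16.0, c_for=np.array([4.0, 9.0]),
# sub_rate=0.5, dom_pref=0.5 and for_pref=np.array([0.25, 0.25]), the domestic
# term is 0.5 * 16**0.5 = 2.0, the foreign term is 0.25 * 2 + 0.25 * 3 = 1.25,
# and the aggregate consumption is (2.0 + 1.25) ** 2 = 10.5625.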
206 |
--------------------------------------------------------------------------------
/scripts/train_with_warp_drive.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | """
9 | Training script for the rice environment using WarpDrive
10 | www.github.com/salesforce/warp-drive
11 | """
12 |
13 | import logging
14 | import os
15 | import shutil
16 | import subprocess
17 | import sys
18 | import numpy as np
19 | import yaml
20 | from desired_outputs import desired_outputs
21 |
22 | sys.path.append("./")
23 | from opt_helper import get_mean_std
24 | from fixed_paths import PUBLIC_REPO_DIR
25 |
26 | sys.path.append(PUBLIC_REPO_DIR)
27 |
28 | from scripts.run_unittests import import_class_from_path
29 |
30 | # Set logger level e.g., DEBUG, INFO, WARNING, ERROR.
31 | logging.getLogger().setLevel(logging.ERROR)
32 |
33 |
34 | def perform_other_imports():
35 | """
36 | WarpDrive-related imports.
37 | """
38 | import torch
39 |
40 | num_gpus_available = torch.cuda.device_count()
41 | assert num_gpus_available > 0, "This script needs a GPU to run!"
42 |
43 | from warp_drive.env_wrapper import EnvWrapper
44 | from warp_drive.training.trainer import Trainer
45 | from warp_drive.utils.env_registrar import EnvironmentRegistrar
46 |
47 | return torch, EnvWrapper, Trainer, EnvironmentRegistrar
48 |
49 |
50 | try:
51 | other_imports = perform_other_imports()
52 | except ImportError:
53 | print("Installing requirements...")
54 | subprocess.call(["pip", "install", "rl-warp-drive>=1.6.5"])
55 |
56 | other_imports = perform_other_imports()
57 |
58 | torch, EnvWrapper, Trainer, EnvironmentRegistrar = other_imports
59 |
60 |
61 | def create_trainer(run_config=None, source_dir=None, seed=None):
62 | """
63 | Create the WarpDrive trainer.
64 | """
65 | torch.cuda.FloatTensor(8) # add this line for successful cuda_init
66 |
67 | assert run_config is not None
68 | if source_dir is None:
69 | source_dir = PUBLIC_REPO_DIR
70 | if seed is not None:
71 | run_config["trainer"]["seed"] = seed
72 |
73 | # Create a wrapped environment object via the EnvWrapper
74 | # Ensure that use_cuda is set to True (in order to run on the GPU)
75 |
76 | # Register the environment
77 | env_registrar = EnvironmentRegistrar()
78 |
79 | rice_cuda_class = import_class_from_path(
80 | "RiceCuda", os.path.join(source_dir, "rice_cuda.py")
81 | )
82 |
83 | env_registrar.add_cuda_env_src_path(
84 | rice_cuda_class.name, os.path.join(source_dir, "rice_build.cu")
85 | )
86 |
87 | env_wrapper = EnvWrapper(
88 | rice_cuda_class(**run_config["env"]),
89 | num_envs=run_config["trainer"]["num_envs"],
90 | use_cuda=True,
91 | env_registrar=env_registrar,
92 | )
93 |
94 | # Policy mapping to agent ids: agents can share models
95 | # The policy_tag_to_agent_id_map dictionary maps
96 | # policy model names to agent ids.
97 | # ----------------------------------------------------
98 | policy_tag_to_agent_id_map = {
99 | "regions": list(range(env_wrapper.env.num_agents)),
100 | }
101 |
102 | # Create the Trainer object
103 | # -------------------------
104 | trainer_obj = Trainer(
105 | env_wrapper=env_wrapper,
106 | config=run_config,
107 | policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,
108 | )
109 | return trainer_obj, trainer_obj.save_dir
110 |
111 |
112 | def load_model_checkpoints(trainer=None, save_directory=None, ckpt_idx=-1):
113 | """
114 | Load trained model checkpoints.
115 | """
116 | assert trainer is not None
117 | assert save_directory is not None
118 | assert os.path.exists(save_directory), (
119 | "Invalid folder path. "
120 | "Please specify a valid directory to load the checkpoints from."
121 | )
122 | files = [file for file in os.listdir(save_directory) if file.endswith("state_dict")]
123 | assert len(files) >= len(trainer.policies), "Missing policy checkpoints"
124 |
125 | ckpts_dict = {}
126 | for policy in trainer.policies_to_train:
127 | policy_models = [
128 | os.path.join(save_directory, file) for file in files if policy in file
129 | ]
130 | # If there are multiple files, then use the ckpt_idx to specify the checkpoint
131 | assert ckpt_idx < len(policy_models)
132 | sorted_policy_models = sorted(policy_models, key=os.path.getmtime)
133 | policy_model_file = sorted_policy_models[ckpt_idx]
134 | logging.info(f"Loaded model checkpoints {policy_model_file}.")
135 |
136 | ckpts_dict.update({policy: policy_model_file})
137 | trainer.load_model_checkpoint(ckpts_dict)
138 |
139 |
140 | def fetch_episode_states(trainer_obj=None, episode_states=None, env_id=None):
141 | """
142 | Helper function to rollout the env and fetch env states for an episode.
143 | """
144 | assert trainer_obj is not None
145 | assert isinstance(
146 | episode_states, list
147 | ), "Please pass the 'episode states' args as a list."
148 | assert len(episode_states) > 0
149 | return trainer_obj.fetch_episode_states(episode_states, env_id)
150 |
151 |
152 | def copy_source_files(trainer):
153 | """
154 | Copy source files to the saving directory.
155 | """
156 | for file in [
157 | "rice.py",
158 | "rice_helpers.py",
159 | "rice_cuda.py",
160 | "rice_step.cu",
161 | "rice_build.cu",
162 | ]:
163 | shutil.copyfile(
164 | os.path.join(PUBLIC_REPO_DIR, file),
165 | os.path.join(trainer.save_dir, file),
166 | )
167 |
168 | for file in [
169 | "rice_warpdrive.yaml",
170 | ]:
171 | shutil.copyfile(
172 | os.path.join(PUBLIC_REPO_DIR, "scripts", file),
173 | os.path.join(trainer.save_dir, file),
174 | )
175 |
176 | # Add an identifier file
177 | with open(
178 | os.path.join(trainer.save_dir, ".warpdrive"), "x", encoding="utf-8"
179 | ) as file_pointer:
180 | pass
181 | file_pointer.close()
182 |
183 |
184 | def trainer(
185 | negotiation_on=0,
186 | num_envs=100,
187 | train_batch_size=1024,
188 | num_episodes=30000,
189 | lr=0.0005,
190 | model_params_save_freq=5000,
191 | desired_outputs=desired_outputs,
192 | output_all_envs=False,
193 | ):
194 | """
195 | Main function to run the trainer.
196 | """
197 | # Load the run_config
198 | print("Training with WarpDrive...")
199 |
200 | # Read the run configurations specific to the environment.
201 | # Note: The run config yaml(s) can be edited at warp_drive/training/run_configs
202 | # -----------------------------------------------------------------------------
203 | config_path = os.path.join(PUBLIC_REPO_DIR, "scripts", "rice_warpdrive.yaml")
204 | if not os.path.exists(config_path):
205 | raise ValueError(
206 |             "The run configuration is missing. Please make sure the correct path "
207 | "is specified."
208 | )
209 |
210 | with open(config_path, "r", encoding="utf8") as fp:
211 | run_configuration = yaml.safe_load(fp)
212 | run_configuration["env"]["negotiation_on"] = negotiation_on
213 | run_configuration["trainer"]["num_envs"] = num_envs
214 | run_configuration["trainer"]["train_batch_size"] = train_batch_size
215 | run_configuration["trainer"]["num_episodes"] = num_episodes
216 | run_configuration["policy"]["regions"]["lr"] = lr
217 | run_configuration["saving"]["model_params_save_freq"] = model_params_save_freq
218 |     # Create the trainer from the run configuration
219 | # --------------
220 | trainer_object, _ = create_trainer(run_config=run_configuration)
221 |
222 | # Copy the source files into the results directory
223 | # ------------------------------------------------
224 | copy_source_files(trainer_object)
225 |
226 | # Perform training!
227 | # -----------------
228 | trainer_object.train()
229 |
230 | # Create a (zipped) submission file
231 | # ---------------------------------
232 | subprocess.call(
233 | [
234 | "python",
235 | os.path.join(PUBLIC_REPO_DIR, "scripts", "create_submission_zip.py"),
236 | "--results_dir",
237 | trainer_object.save_dir,
238 | ]
239 | )
240 | outputs_ts = [
241 | fetch_episode_states(trainer_object, desired_outputs, env_id=i)
242 | for i in range(num_envs)
243 | ]
244 | for i in range(len(outputs_ts)):
245 | outputs_ts[i]["global_consumption"] = np.sum(
246 | outputs_ts[i]["consumption_all_regions"], axis=-1
247 | )
248 | outputs_ts[i]["global_production"] = np.sum(
249 | outputs_ts[i]["gross_output_all_regions"], axis=-1
250 | )
251 | if not output_all_envs:
252 | outputs_ts, _ = get_mean_std(outputs_ts)
253 | # Shut off the trainer gracefully
254 | # -------------------------------
255 | trainer_object.graceful_close()
256 | return trainer_object, outputs_ts
257 |
258 |
259 | if __name__ == "__main__":
260 | print("Training with WarpDrive...")
261 |
262 | # Read the run configurations specific to the environment.
263 | # Note: The run config yaml(s) can be edited at warp_drive/training/run_configs
264 | # -----------------------------------------------------------------------------
265 | config_path = os.path.join(PUBLIC_REPO_DIR, "scripts", "rice_warpdrive.yaml")
266 | if not os.path.exists(config_path):
267 | raise ValueError(
268 |             "The run configuration is missing. Please make sure the correct path "
269 | "is specified."
270 | )
271 |
272 | with open(config_path, "r", encoding="utf8") as fp:
273 | run_configuration = yaml.safe_load(fp)
274 |
275 | # Create trainer
276 | # --------------
277 | trainer_object, _ = create_trainer(run_config=run_configuration)
278 |
279 | # Copy the source files into the results directory
280 | # ------------------------------------------------
281 | copy_source_files(trainer_object)
282 |
283 | # Perform training!
284 | # -----------------
285 | trainer_object.train()
286 |
287 | # Create a (zipped) submission file
288 | # ---------------------------------
289 | subprocess.call(
290 | [
291 | "python",
292 | os.path.join(PUBLIC_REPO_DIR, "scripts", "create_submission_zip.py"),
293 | "--results_dir",
294 | trainer_object.save_dir,
295 | ]
296 | )
297 |
298 | # Shut off the trainer gracefully
299 | # -------------------------------
300 | trainer_object.graceful_close()
301 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Competition: Fostering Global Cooperation to Mitigate Climate Change
2 |
3 | [](https://pytorch.org/docs/1.12/)
4 | [](https://www.python.org/downloads/release/python-3713/)
5 | [](https://github.com/salesforce/warp-drive/)
6 | [](https://docs.ray.io/en/latest/index.html)
7 | [](https://arxiv.org/abs/2208.07004)
8 | [](https://colab.research.google.com/github/mila-iqia/climate-cooperation-competition/blob/main/Colab_Tutorial.ipynb)
9 |
10 | [](https://www.kaggle.com/kernels/fork-version/105300459) (Code Tutorial Notebook on Kaggle with free GPU available)
11 |
12 |
13 | This is the code repository for the competition on modeling global cooperation in the RICE-N Integrated Assessment Model. This competition is co-organized by MILA and Salesforce Research.
14 |
15 | The RICE-N IAM is an agent-based model that incorporates DICE climate-economic dynamics and multi-lateral negotiation protocols between several fictitious nations.
16 |
17 | In this competition, you will design negotiation protocols and contracts between nations. You will use the simulation and agents to evaluate their impact on the climate and the economy.
18 |
19 | We recommend that GPU users use ``warp_drive`` and CPU users use ``rllib``.
20 |
21 | ## Tutorial
22 | - For all information and the leaderboard, see [our official website](https://www.ai4climatecoop.org).
23 | - [How to Get Started](getting_started.ipynb)
24 | - [Code Tutorial Notebook with **free GPU**](https://colab.research.google.com/github/mila-iqia/climate-cooperation-competition/blob/main/Colab_Tutorial.ipynb)
25 | - [Code Kaggle Tutorial Notebook with **free GPU**](https://www.kaggle.com/kernels/fork-version/105300459)
26 |
27 | ## Resources
28 | - For the mathematical background and scientific references, please see [the white paper](https://deliverypdf.ssrn.com/delivery.php?ID=579098091025080122123095015088114126057046084059055038121023114094110112070107123088059057002107022006023123122016086089001013042072002040020075022078097115093071118048047053064022064117095120085074022123010099031092026025094015125094099080071097079070&EXT=pdf&INDEX=TRUE).
29 | - Other free GPU resources: [Baidu Paddle](https://aistudio.baidu.com/), [MegStudio](https://studio.brainpp.com/)
30 |
31 |
32 | ## Installation
33 |
34 | Note: we recommend using `Linux` or `MacOS`. For Windows users, we recommend using a virtual machine running `Ubuntu 20.04` or [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/install).
35 |
36 | You can get a copy of the code by cloning the repo using Git:
37 |
38 | ```
39 | git clone https://github.com/mila-iqia/climate-cooperation-competition
40 | cd climate-cooperation-competition
41 | ```
42 |
43 | As an alternative, one can also use:
44 |
45 | ```
46 | git clone https://e.coding.net/ai4climatecoop/ai4climatecoop/climate-cooperation-competition.git
47 | cd climate-cooperation-competition
48 | ```
49 |
50 | We recommend using a virtual environment (such as provided by ```virtualenv``` or Anaconda).
51 |
52 | You can install the dependencies using pip:
53 |
54 | ```
55 | pip install -r requirements.txt
56 | ```
57 |
58 | ## Get Started
59 | Then run the getting-started Jupyter notebook by starting Jupyter:
60 |
61 | ```
62 | jupyter notebook
63 | ```
64 |
65 | and then navigating to [getting_started.ipynb](getting_started.ipynb).
66 |
67 | It provides a quick walkthrough for registering for the competition and creating a valid submission.
68 |
69 |
70 | ## Training with reinforcement learning
71 |
72 | RL agents can be trained using the RICE-N simulation using one of these two frameworks:
73 |
74 | 1. [RLlib](https://docs.ray.io/en/latest/rllib/index.html#:~:text=RLlib%20is%20an%20open%2Dsource,large%20variety%20of%20industry%20applications): The pythonic environment can be trained on your local CPU machine using the open-source RL framework, RLlib.
75 | 2. [WarpDrive](https://github.com/salesforce/warp-drive): WarpDrive is a GPU-based framework that allows for [over 10x faster training](https://arxiv.org/pdf/2108.13976.pdf) compared to CPU-based training. It requires the simulation to be written out in CUDA C, and we also provide a starter version of the simulation environment written in CUDA C ([rice_step.cu](rice_step.cu)).
76 |
77 | We also provide starter scripts to train the simulation you build with either of the above frameworks.
78 |
79 | Note that we only allow these two options, since our backend submission evaluation process only supports these at the moment.
80 |
81 |
82 | For training with RLlib, `rllib (1.0.0)`, `torch (1.9.0)` and `gym (0.21)` packages are required.
83 |
84 |
85 |
86 | For training with WarpDrive, the `rl-warp-drive (>=1.6.5)` package is needed.
87 |
88 | Note that these requirements are automatically installed (or updated) when you run the corresponding training scripts.
89 |
90 |
91 | ## Docker image (for GPU users)
92 |
93 | We have also provided a sample dockerfile for your reference. It mainly uses an Nvidia PyTorch base image, and installs the `pycuda` package as well. Note: `pycuda` is only required if you would like to train using WarpDrive.
94 |
95 |
96 | ## Docker image (for CPU users)
97 |
98 | Thanks for the contribution from @muxspace. We also have an end-to-end docker environment ready for CPU users. Please refer to [README_CPU.md](README_CPU.md) for more details.
99 |
100 | # Customizing and running the simulation
101 |
102 | See [Colab_Tutorial.ipynb](Colab_Tutorial.ipynb) [](https://colab.research.google.com/github/mila-iqia/climate-cooperation-competition/blob/main/Colab_Tutorial.ipynb) for details.
103 |
104 | It provides examples on modifying the code to implement different negotiation protocols. It describes ways of changing the agent observations and action spaces corresponding to the proposed negotiation protocols and implementing the negotiation logic in the provided code.
105 |
106 | The notebook has a walkthrough of how to train RL agents with the simulation code and how to visualize results from the simulation after running it with a set of agents.
107 |
108 | For those who have **limited access to Colab**, please try the [**free GPUs on Kaggle**](https://www.kaggle.com/kernels/fork-version/105300459). Please note that the Kaggle platform requires mobile phone verification to access the GPUs. You can find the **settings** to enable the GPU and internet connection on the right-hand side after clicking the link above and logging in.
109 |
110 | # Training RL agents in your simulation
111 |
112 | Once you build your simulation, you can use either of the following scripts to perform training.
113 |
114 | - [train_with_rllib.py](/scripts/train_with_rllib.py): this script performs end-to-end training with RLlib. The experiment run configuration will be read in from [rice_rllib.yaml](/scripts/rice_rllib.yaml), which contains the environment configuration, logging and saving settings and the trainer and policy network parameters. The duration of training can be set via the `num_episodes` parameter. We have also provided an initial implementation of a linear PyTorch policy model in [torch_models.py](/scripts/torch_models.py). You can [add other policy models](https://docs.ray.io/en/latest/rllib/rllib-concepts.html) you wish to use into that file.
115 |
116 | USAGE: The training script (with RLlib) is invoked using (from the root directory)
117 | ```commandline
118 | python scripts/train_with_rllib.py
119 | ```
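
The run configuration can also be inspected or adjusted from Python before launching training. Below is a minimal sketch; the exact key path for `num_episodes` is an assumption, so check the yaml for the actual nesting:
```python
# Minimal sketch: inspect (and optionally edit) the RLlib run configuration.
import yaml

with open("scripts/rice_rllib.yaml", "r", encoding="utf8") as fp:
    run_config = yaml.safe_load(fp)

print(run_config)  # view the configured sections and parameters
# run_config["trainer"]["num_episodes"] = 100  # hypothetical key path

with open("scripts/rice_rllib.yaml", "w", encoding="utf8") as fp:
    yaml.safe_dump(run_config, fp)
```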
120 |
121 | - [train_with_warp_drive.py](/scripts/train_with_warp_drive.py): this script performs end-to-end training with WarpDrive. The experiment run configuration will be read in from [rice_warpdrive.yaml](/scripts/rice_warpdrive.yaml). Currently, WarpDrive just supports the Advantage Actor-Critic (A2C) and the Proximal Policy Optimization (PPO) algorithms, and the fully-connected policy model.
122 |
123 | USAGE: The training script (with WarpDrive) is invoked using
124 | ```commandline
125 | python scripts/train_with_warp_drive.py
126 | ```
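
If you prefer to override a few settings programmatically rather than editing the yaml, you can load the configuration and modify the same keys that `train_with_warp_drive.py` sets internally before creating the trainer; a minimal sketch:
```python
# Minimal sketch: load and tweak the WarpDrive run configuration.
# These key paths mirror the overrides used in scripts/train_with_warp_drive.py.
import yaml

with open("scripts/rice_warpdrive.yaml", "r", encoding="utf8") as fp:
    run_configuration = yaml.safe_load(fp)

run_configuration["env"]["negotiation_on"] = 1         # toggle negotiation
run_configuration["trainer"]["num_episodes"] = 1000    # shorten training
run_configuration["policy"]["regions"]["lr"] = 0.0005  # policy learning rate
# Pass `run_configuration` to create_trainer(...) as done in the script.
```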
127 |
128 | As training progresses, some key metrics (such as the mean episode reward) are printed on screen for your reference. At the end of training, a zipped submission file is automatically created and saved. The zipped file comprises the following:
129 |
130 | - An identifier file (`.rllib` or `.warpdrive`) indicating which framework was used towards training.
131 | - The environment files - [rice.py](rice.py) and [rice_helpers.py](rice_helpers.py).
132 | - A copy of the yaml configuration file ([rice_rllib.yaml](/scripts/rice_rllib.yaml) or [rice_warpdrive.yaml](/scripts/rice_warpdrive.yaml)) used for training.
133 | - PyTorch policy model(s) (of type ".state_dict") containing the trained weights for the policy network(s). Only the trained policy model for the final timestep will be copied over into the submission zip. If you would like to instead submit the trained policy model at a different timestep, please see the section below on creating your submission file.
134 | - For submissions using WarpDrive, the submission will also contain CUDA-specific files [rice_step.cu](rice_step.cu) and [rice_cuda.py](rice_cuda.py) that were used for training.
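
Before uploading, you can quickly verify that the generated zip actually contains the files listed above. A small sketch (the archive name below is a placeholder for the zip created by your training run):
```python
# Illustrative check of a generated submission archive.
import zipfile

with zipfile.ZipFile("submission.zip") as archive:  # placeholder path
    for name in archive.namelist():
        print(name)
```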
135 |
136 |
137 | # Contributing
138 |
139 | We are always looking for contributors from various domains to help us make this simulation more realistic.
140 |
141 | If there are bugs or corner cases, please open a PR detailing the issue and consider submitting to Track 3!
142 |
143 |
144 | # Citation
145 |
146 | To cite this code, please use the information in [CITATION.cff](CITATION.cff) and the following bibtex entry:
147 |
148 | ```
149 | @software{Zhang_RICE-N_2022,
150 | author = {Zhang, Tianyu and Srinivasa, Sunil and Williams, Andrew and Phade, Soham and Zhang, Yang and Gupta, Prateek and Bengio, Yoshua and Zheng, Stephan},
151 | month = {7},
152 | title = {{RICE-N}},
153 | url = {https://github.com/mila-iqia/climate-cooperation-competition},
154 | version = {1.0.0},
155 | year = {2022}
156 | }
157 |
158 | @misc{https://doi.org/10.48550/arxiv.2208.07004,
159 | doi = {10.48550/ARXIV.2208.07004},
160 | url = {https://arxiv.org/abs/2208.07004},
161 | author = {Zhang, Tianyu and Williams, Andrew and Phade, Soham and Srinivasa, Sunil and Zhang, Yang and Gupta, Prateek and Bengio, Yoshua and Zheng, Stephan},
162 | title = {AI for Global Climate Cooperation: Modeling Global Climate Negotiations, Agreements, and Long-Term Cooperation in RICE-N},
163 | publisher = {arXiv},
164 | year = {2022}
165 | }
166 |
167 | ```
168 |
169 | # License
170 |
171 | For license information, see ```LICENSE.txt```.
172 |
--------------------------------------------------------------------------------
/scripts/run_unittests.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | """
9 | Unit tests for the rice simulation
10 | """
11 | import argparse
12 | import importlib.util as iu
13 | import logging
14 | import os
15 | import shutil
16 | import subprocess
17 | import sys
18 | import time
19 | import unittest
20 |
21 | import numpy as np
22 |
23 | # from evaluate_submission import get_results_dir
24 | from fixed_paths import PUBLIC_REPO_DIR
25 |
26 | sys.path.append(PUBLIC_REPO_DIR)
27 |
28 | _REGION_YAMLS = "region_yamls"
29 |
30 | # Set logger level e.g., DEBUG, INFO, WARNING, ERROR.
31 | logging.getLogger().setLevel(logging.ERROR)
32 |
33 | _BASE_CODE_PATH = (
34 | "https://raw.githubusercontent.com/mila-iqia/climate-cooperation-competition/main"
35 | )
36 | _BASE_RICE_PATH = os.path.join(_BASE_CODE_PATH, "rice.py")
37 | _BASE_RICE_HELPERS_PATH = os.path.join(_BASE_CODE_PATH, "rice_helpers.py")
38 | _BASE_RICE_BUILD_PATH = os.path.join(_BASE_CODE_PATH, "rice_build.cu")
39 | _BASE_CONSISTENCY_CHECKER_PATH = os.path.join(
40 | _BASE_CODE_PATH, "scripts/run_cpu_gpu_env_consistency_checks.py"
41 | )
42 |
43 |
44 | def import_class_from_path(class_name=None, path=None):
45 | """
46 | Helper function to import class from a path.
47 | """
48 | assert class_name is not None
49 | assert path is not None
50 | spec = iu.spec_from_file_location(class_name, path)
51 | module_from_spec = iu.module_from_spec(spec)
52 | spec.loader.exec_module(module_from_spec)
53 | return getattr(module_from_spec, class_name)
54 |
55 |
56 | def fetch_base_env(base_folder=".tmp/_base"):
57 | """
58 | Download the base version of the code from GitHub.
59 | """
60 | if not base_folder.startswith("/"):
61 | base_folder = os.path.join(PUBLIC_REPO_DIR, base_folder)
62 | # print(f"Using tmp dir {base_folder}")
63 | if os.path.exists(base_folder):
64 | shutil.rmtree(base_folder)
65 | os.makedirs(base_folder, exist_ok=False)
66 |
67 | print(
68 | "\nDownloading a base version of the code from GitHub"
69 | " to run consistency checks..."
70 | )
71 | prev_dir = os.getcwd()
72 | os.chdir(base_folder)
73 | subprocess.call(["wget", "-O", "rice.py", _BASE_RICE_PATH])
74 | subprocess.call(["wget", "-O", "rice_helpers.py", _BASE_RICE_HELPERS_PATH])
75 | if "region_yamls" not in os.listdir(base_folder):
76 | shutil.copytree(
77 | os.path.join(PUBLIC_REPO_DIR, "region_yamls"),
78 | os.path.join(base_folder, "region_yamls"),
79 | )
80 |
81 | base_rice = import_class_from_path("Rice", os.path.join(base_folder, "rice.py"))()
82 |
83 | # Clean up base code
84 | os.chdir(prev_dir)
85 | shutil.rmtree(base_folder)
86 | return base_rice
87 |
88 |
89 | class TestEnv(unittest.TestCase):
90 | """
91 | The env testing class.
92 | """
93 |
94 | @classmethod
95 | def setUpClass(cls):
96 | """Set-up"""
97 | # Initialization
98 | cls.framework = "rllib"
99 | assert cls.results_dir is not None
100 |
101 |         # Note: the results_dir attribute is set in __main__.
102 | if _REGION_YAMLS not in os.listdir(cls.results_dir):
103 | shutil.copytree(
104 | os.path.join(PUBLIC_REPO_DIR, "region_yamls"),
105 | os.path.join(cls.results_dir, "region_yamls"),
106 | )
107 |
108 | if ".warpdrive" in os.listdir(cls.results_dir):
109 | cls.framework = "warpdrive"
110 | # Copy the consistency checker file into the results_dir
111 | prev_dir = os.getcwd()
112 | os.chdir(cls.results_dir)
113 | os.makedirs("scripts", exist_ok=True)
114 | subprocess.call(
115 | [
116 | "wget",
117 | "-O",
118 | "scripts/run_cpu_gpu_env_consistency_checks.py",
119 | _BASE_CONSISTENCY_CHECKER_PATH,
120 | ]
121 | )
122 | subprocess.call(
123 | [
124 | "wget",
125 | "-O",
126 | "rice_build.cu",
127 | _BASE_RICE_BUILD_PATH,
128 | ]
129 | )
130 | os.chdir(prev_dir)
131 | else:
132 | assert ".rllib" in os.listdir(cls.results_dir), (
133 | f"Missing identifier file! "
134 | f"Either the .rllib or the .warpdrive "
135 | f"file must be present in the results directory: {cls.results_dir}"
136 | )
137 |
138 | cls.base_env = fetch_base_env() # Fetch the base version from GitHub
139 | try:
140 | cls.env = import_class_from_path(
141 | "Rice", os.path.join(cls.results_dir, "rice.py")
142 | )()
143 | except Exception as err:
144 | raise ValueError(
145 | "The Rice environment could not be instantiated !"
146 | ) from err
147 |
148 | base_env_action_nvec = np.array(cls.base_env.action_space[0].nvec)
149 | cls.base_env_random_actions = {
150 | agent_id: np.random.randint(
151 | low=0 * base_env_action_nvec, high=base_env_action_nvec - 1
152 | )
153 | for agent_id in range(cls.base_env.num_agents)
154 | }
155 | sample_agent_id = 0
156 | env_action_nvec = np.array(cls.env.action_space[sample_agent_id].nvec)
157 | len_negotiation_actions = len(env_action_nvec) - len(base_env_action_nvec)
158 | cls.env_random_actions = {
159 | agent_id: np.append(
160 | cls.base_env_random_actions[agent_id],
161 | np.zeros(len_negotiation_actions, dtype=np.int32),
162 | )
163 | for agent_id in range(cls.env.num_agents)
164 | }
165 |
166 | def test_env_attributes(self):
167 | """
168 | Test the env attributes are consistent with the base version.
169 | """
170 | for attribute in [
171 | "all_constants",
172 | "num_regions",
173 | "num_agents",
174 | "start_year",
175 | "end_year",
176 | "num_discrete_action_levels",
177 | ]:
178 | np.testing.assert_array_equal(
179 | getattr(self.base_env, attribute), getattr(self.env, attribute)
180 | )
181 |
182 | features = [
183 | "activity_timestep",
184 | "global_temperature",
185 | "global_carbon_mass",
186 | "global_exogenous_emissions",
187 | "global_land_emissions",
188 | "capital_all_regions",
189 | "capital_depreciation_all_regions",
190 | "labor_all_regions",
191 | "gross_output_all_regions",
192 | "investment_all_regions",
193 | "consumption_all_regions",
194 | "savings_all_regions",
195 | "mitigation_rate_all_regions",
196 | "tariffs",
197 | "max_export_limit_all_regions",
198 | "current_balance_all_regions",
199 | "production_factor_all_regions",
200 | "intensity_all_regions",
201 | "mitigation_cost_all_regions",
202 | "damages_all_regions",
203 | "abatement_cost_all_regions",
204 | "production_all_regions",
205 | "utility_all_regions",
206 | "social_welfare_all_regions",
207 | "reward_all_regions",
208 | "scaled_imports",
209 | ]
210 |
211 | # Test equivalence after reset
212 | self.base_env.reset()
213 | self.env.reset()
214 |
215 | for feature in features:
216 | np.testing.assert_array_equal(
217 | getattr(self.base_env, "global_state")[feature]["value"][0],
218 | getattr(self.env, "global_state")[feature]["value"][0],
219 | )
220 |
221 | # Test equivalence after stepping through the env
222 | for timestep in range(self.base_env.episode_length):
223 | self.base_env.timestep += 1
224 | self.base_env.climate_and_economy_simulation_step(
225 | self.base_env_random_actions
226 | )
227 |
228 | self.env.timestep += 1
229 | self.env.climate_and_economy_simulation_step(self.env_random_actions)
230 |
231 | for feature in features:
232 | np.testing.assert_array_equal(
233 | getattr(self.base_env, "global_state")[feature]["value"][timestep],
234 | getattr(self.env, "global_state")[feature]["value"][timestep],
235 | )
236 |
237 | def test_env_reset(self):
238 | """
239 | Test the env reset output
240 | """
241 | obs_at_reset = self.env.reset()
242 | self.assertEqual(len(obs_at_reset), self.env.num_agents)
243 |
244 | def test_env_step(self):
245 | """
246 | Test the env step output
247 | """
248 | assert isinstance(
249 | self.env.action_space, dict
250 | ), "Action space must be a dictionary keyed by agent ids."
251 | assert sorted(list(self.env.action_space.keys())) == list(
252 | range(self.env.num_agents)
253 | )
254 |
255 | # Test with all random actions
256 | obs, rew, done, _ = self.env.step(self.env_random_actions)
257 | self.assertEqual(list(obs.keys()), list(rew.keys()))
258 | assert list(done.keys()) == ["__all__"]
259 |
260 | def test_cpu_gpu_consistency_checks(self):
261 | """
262 | Run the CPU/GPU environment consistency checks
263 | (only if using the CUDA version of the env)
264 | """
265 | if self.framework == "warpdrive":
266 | # Execute the CPU-GPU consistency checks
267 | os.chdir(self.results_dir)
268 | subprocess.check_output(
269 | ["python", "scripts/run_cpu_gpu_env_consistency_checks.py"]
270 | )
271 |
272 |
273 | def get_results_dir():
274 | """
275 | Obtain the 'results' directory from the system arguments.
276 | """
277 | parser = argparse.ArgumentParser()
278 | parser.add_argument(
279 | "--results_dir",
280 | "-r",
281 | type=str,
282 | default=".",
283 | help="the directory where all the submission files are saved. Can also be "
284 | "the zipped file containing all the submission files.",
285 | )
286 | args = parser.parse_args()
287 |
288 | if "results_dir" not in args:
289 | raise ValueError(
290 | "Please provide a results directory to evaluate with the argument -r"
291 | )
292 | if not os.path.exists(args.results_dir):
293 | raise ValueError(
294 | "The results directory is missing. Please make sure the correct path "
295 | "is specified!"
296 | )
297 | try:
298 | results_dir = args.results_dir
299 |
300 | # Also handle a zipped file
301 | if results_dir.endswith(".zip"):
302 | unzipped_results_dir = os.path.join("/tmp", str(time.time()))
303 | shutil.unpack_archive(results_dir, unzipped_results_dir)
304 | results_dir = unzipped_results_dir
305 | return results_dir, parser
306 | except Exception as err:
307 | raise ValueError("Cannot obtain the results directory") from err
308 |
309 | # if __name__ == "__main__":
310 | # Skip all of this
311 | # logging.info("Running env unit tests...")
312 | # # Set the results directory
313 | # results_dir, parser = get_results_dir()
314 | # parser.add_argument("unittest_args", nargs="*")
315 | # args = parser.parse_args()
316 | # sys.argv[1:] = args.unittest_args
317 | # TestEnv.results_dir = results_dir
318 |
319 | # unittest.main()
320 |
--------------------------------------------------------------------------------
/getting_started.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Copyright (c) 2022, salesforce.com, inc and MILA. \n",
8 | "All rights reserved. \n",
9 | "SPDX-License-Identifier: BSD-3-Clause \n",
10 | "For full license text, see the LICENSE file in the repo root \n",
11 | "or https://opensource.org/licenses/BSD-3-Clause "
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "# How can I register for the competition?"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "Please fill out the [registration form](https://docs.google.com/forms/d/e/1FAIpQLSe2SWnhJaRpjcCa3idq7zIFubRoH0pATLOP7c1Y0kMXOV6U4w/viewform) in order to register for the competition. \n",
26 | "\n",
27 | "You will only need to provide an email address and a team name. You will also need to be willing to open-source your code after the competition.\n",
28 | "\n",
29 | "After you submit your registration form, we will register it internally. Please allow for upto 1-2 working days for your team name to be registered. You will be notified via email upon successful registration. You will need your team name in order to make submissions towards the competition."
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "# Quickly train agents with CPU using rllib and create a submission"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "The following command should install all the pre-requisites automatically.\n",
44 | "\n",
45 | "Please make sure that you are using Python 3.7 or older version. Our code does not support 3.8 or newer versions currently."
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "!python ./scripts/train_with_rllib.py"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "# Evaluate your submission locally"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "Before you actually upload your submission files, you can also evaluate and score your submission on your end using this script. The evaluation script essentially validates the submission files, performs unit testing and computes the metrics for evaluation. To compute the metrics, we first instantiate a trainer, load the policy model with the saved parameters, and then generate several episode rollouts to measure the impact of the policy on the environment."
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "!python ./scripts/evaluate_submission.py -r Submissions/.zip"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "# Where can I submit my solution?"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "*NOTE: Please register for the competition (see the steps above), if you have not done so. Your team must be registered before you can submit your solutions.*\n",
92 | "\n",
93 | "The AI climate competition features 3 tracks.\n",
94 | "\n",
95 | "In Track 1, you will propose and implement multilateral agreements to augment the simulator, and train the AI agents in the simulator. We evaluate the learned policies and resulting economic and climate change metrics.\n",
96 | "\n",
97 | "- The submission form for Track 1 is [here](https://forms.gle/fuM4NZ5eX2rdckit6).\n",
98 | "- Or, as an alternative, submit [here](https://workspace.jianguoyun.com/inbox/collect/c7be3a1c61624a4498666095d8a51824/submit) if you have difficulty accessing the Google form\n",
99 | "\n",
100 | "\n",
101 | "Please select your registered team name from the drop-down menu, and upload a zip file containing the submission files - we will be providing scripts to help you create the zip file.\n",
102 | "\n",
103 | "In Track 2, you will argue why your solution is practically relevant and usable in the real world. We expect the entries in this track to contain a high-level summary for policymakers\n",
104 | "\n",
105 | "To submit the generated running result and your code (the **.zip** file):\n",
106 | "- The submission form for Track 2 is [here](https://forms.gle/1kTsFLUp6yVF3xQf9).\n",
107 | "- Or, as an alternative, submit [here](https://workspace.jianguoyun.com/inbox/collect/323b16a8697741348e3197ebccafea81/submit) if you have difficulty accessing the Google form\n",
108 | "\n",
109 | "To submit your essay: [OpenReview](https://openreview.net/group?id=AI4ClimateCoop.org/2022/Workshop)\n",
110 | "\n",
111 | "In Track 3, we invite you to point out potential simulation loopholes and improvements.\n",
112 | "\n",
113 | "To submit your code to support your suggestions (a **.zip** file, optional):\n",
114 | "- The submission form for Track 3 is [here](https://forms.gle/stBihEdaf58xi2Vr6).\n",
115 | "- Or, as an alternative, submit [here](https://workspace.jianguoyun.com/inbox/collect/e96127b108ed4172a3b79273688a883c/submit) if you have difficulty accessing the Google form\n",
116 | "\n",
117 | "To submit your essay: [OpenReview](https://openreview.net/group?id=AI4ClimateCoop.org/2022/Workshop)\n",
118 | "\n",
119 | "If you do not see your team name in the drop-down menu, please contact us on Slack or by e-mail, and we will resolve that for you."
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "# How do I create a submission using my modified negotiation protocol? "
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {},
132 | "source": [
133 | "We provide the base version of the RICE-N (regional integrated climate environment) simulation environment written in Python (`rice.py`).\n",
134 | "\n",
135 | "**For the mathematical background and scientific references, please see [the white paper](https://deliverypdf.ssrn.com/delivery.php?ID=428101121103108016095076093074095111015069058086095042085123117113111092124092117108004117037031126012054120125119115118069067102029022089006118121099082113093096121049050055084110110018003106083072011105122122123113102083083074084083085090104119080101&EXT=pdf&INDEX=TRUE).**\n",
136 | "\n",
137 | "You will need to mainly modify the `rice.py` to implement the proposed negotiatoin protocol. Additional details can be found below."
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {},
143 | "source": [
144 | "## Scripts for creating the zipped submission file"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "As mentioned above, the zipped file required for submission is automatically created post-training. However, for any reason (for example, for providing a trained policy model at a different timestep), you can create the zipped submission yourself using the `create_submizzion_zip.py` script. Accordingly, create a new directory (say `submission_dir`) with all the relevant files (see the section above), and you can then simply invoke\n",
152 | "```commandline\n",
153 | "python scripts/create_submission_zip.py -r \n",
154 | "```\n",
155 | "\n",
156 | "That will first validate that the submission directory contains all the required files, and then provide you a zipped file that can you use towards your submission."
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "!python ./scripts/create_submission_zip.py -r "
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "metadata": {},
171 | "source": [
172 | "## Scripts for unit testing"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {},
178 | "source": [
179 | "In order to make sure that all the submissions are consistent in that they comply within the rules of the competition, we have also added unit tests. These are automatically run also when the evaluation is performed. The script currently performs the following tests\n",
180 | "\n",
181 | "- Test that the environment attributes (such as the RICE and DICE constants, the simulation period and the number of regions) are consistent with the base environment class that we also provide.\n",
182 |     "- Test that the `climate_and_economy_simulation_step()` is consistent with the base class. As mentioned above, users are free to add different negotiation strategies, such as multilateral negotiations or climate clubs, but should not modify the equations underlying the climate and economic dynamics.\n",
183 |     "- Test that the environment resetting and stepping yield outputs in the desired format (for instance, observations are a dictionary keyed by region id, and so are rewards).\n",
184 |     "- If WarpDrive was used, we also perform consistency checks to verify that the CUDA implementation of the RICE-N environment is consistent with the Python version.\n",
185 | "\n",
186 | "USAGE: You may invoke the unit tests on a submission file via\n",
187 | "```commandline\n",
188 | "python scripts/run_unittests.py -r \n",
189 | "```"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "!python ./scripts/run_unittests.py -r "
199 | ]
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "metadata": {},
204 | "source": [
205 | "## Scripts for performance evaluation"
206 | ]
207 | },
208 | {
209 | "cell_type": "markdown",
210 | "metadata": {},
211 | "source": [
212 | "USAGE: You may evaluate the submission file using\n",
213 | "```commandline\n",
214 | "python scripts/evaluate_submission.py -r \n",
215 | "```\n",
216 |     "Please verify that you can indeed evaluate your submission before actually uploading it. You can also run the evaluation from this notebook, as shown below."
217 | ]
218 | },
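219 |   {
220 |    "cell_type": "markdown",
221 |    "metadata": {},
222 |    "source": [
223 |     "The cell below mirrors the unit-test and submission-zip cells above: it simply invokes the provided evaluation script (append your own results directory or zipped submission after `-r`)."
224 |    ]
225 |   },
226 |   {
227 |    "cell_type": "code",
228 |    "execution_count": null,
229 |    "metadata": {},
230 |    "outputs": [],
231 |    "source": [
232 |     "!python ./scripts/evaluate_submission.py -r "
233 |    ]
234 |   },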
219 | {
220 | "cell_type": "markdown",
221 | "metadata": {},
222 | "source": [
223 | "# Evaluation process"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {},
229 | "source": [
230 |     "After you submit your solution, we will use the same evaluation script that is provided to you to score your submission, but with several rollout episodes to average metrics such as the average rewards, the global temperature rise, capital, production, and many more. We will then rank the submissions based on these metrics. The score computed by the evaluation process should be similar to the score computed on your end, since both use the same scripts."
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "metadata": {},
236 | "source": [
237 | "## What happens when I make an invalid submission?"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "metadata": {},
243 | "source": [
244 |     "An \"invalid submission\" is one in which some or all of the submission files are missing, the files are inconsistent with the base version, or anything else causes the evaluation process to fail. An invalid submission cannot be evaluated and hence will not appear on the leaderboard. While we can let you know if your submission is invalid, the process is not automated, so we may not be able to do so promptly. To avoid any issues, please use the `create_submission_zip` script to create your zipped submission file."
245 | ]
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "metadata": {},
250 | "source": [
251 | "# Leaderboard"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {},
257 | "source": [
258 |     "The competition leaderboard is displayed on the [competition website](https://mila-iqia.github.io/climate-cooperation-competition). After you make a valid submission, please allow a few minutes for it to be evaluated and for the leaderboard to refresh."
259 | ]
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "metadata": {},
264 | "source": [
265 | "# How many submissions are allowed per team?"
266 | ]
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "metadata": {},
271 | "source": [
272 |     "There is no limit on the number of submissions per team; feel free to submit as many solutions as you would like. Only your submission with the highest evaluation score will count towards the leaderboard."
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "metadata": {
278 | "ExecuteTime": {
279 | "end_time": "2022-06-13T13:41:49.356590Z",
280 | "start_time": "2022-06-13T13:41:22.620490Z"
281 | },
282 | "scrolled": true
283 | },
284 | "source": [
285 | "# Code overview"
286 | ]
287 | },
288 | {
289 | "cell_type": "markdown",
290 | "metadata": {},
291 | "source": [
292 | "## File Structure"
293 | ]
294 | },
295 | {
296 | "cell_type": "markdown",
297 | "metadata": {},
298 | "source": [
299 |     "Below are the detailed file tree and the file descriptions.\n",
300 | "```commandline\n",
301 | "ROOT_DIR\n",
302 | "├── rice.py\n",
303 | "├── rice_helpers.py\n",
304 | "├── region_yamls\n",
305 | "\n",
306 | "├── rice_step.cu\n",
307 | "├── rice_cuda.py\n",
308 | "├── rice_build.cu\n",
309 | "\n",
310 | "└── scripts\n",
311 | " ├── train_with_rllib.py\n",
312 | " ├── rice_rllib.yaml\n",
313 | " ├── torch_models.py\n",
314 | " \n",
315 | " ├── train_with_warp_drive.py\n",
316 | " ├── rice_warpdrive.yaml\n",
317 | " ├── run_cpu_gpu_env_consistency_checks.py\n",
318 | " \n",
319 | " ├── run_unittests.py \n",
320 | " ├── create_submission_zip.py\n",
321 | " └── evaluate_submission.py \n",
322 | "```"
323 | ]
324 | },
325 | {
326 | "cell_type": "markdown",
327 | "metadata": {},
328 | "source": [
329 | "## Environment files"
330 | ]
331 | },
332 | {
333 | "cell_type": "markdown",
334 | "metadata": {},
335 | "source": [
336 |     "- `rice.py`: This Python script contains the base Rice class. It is written in [OpenAI Gym](https://gym.openai.com/) style with the `reset()` and `step()` functionalities. The `step()` function comprises an implementation of `climate_and_economy_simulation_step()`, which dictates the dynamics of the climate and economy simulation and should not be altered by the user. We have also provided a simple implementation of bilateral negotiation between regions via the `proposal_step()` and `evaluation_step()` methods. Users can extend the simulation, for example, by adding additional proposal strategies and incorporating them in the `step()` function. However, please do not modify any of the equations dictating the environment dynamics in `climate_and_economy_simulation_step()`. All the helper functions related to modeling the climate and economic simulation are located in `rice_helpers.py`. Region-specific environment parameters are provided in the `region_yamls` directory.\n",
337 | "\n",
338 | "\n",
339 | "- `rice_step.cu`\n",
340 |     "This is the CUDA C version of the `step()` function that is required for use with WarpDrive. To get started with WarpDrive, we recommend following these [tutorials](https://github.com/salesforce/warp-drive/tree/master/tutorials). While WarpDrive requires writing the simulation in CUDA C, it also offers orders-of-magnitude speedups for end-to-end training, since it performs rollouts and training all on the GPU. `rice_cuda.py` and `rice_build.cu` are necessary files for copying simulation data to the GPU and compiling the CUDA code.\n",
341 | "\n",
342 |     "While implementing the simulation in CUDA C on the GPU offers significantly faster simulations, it requires careful memory management. To make sure that everything works properly, one approach is to first implement your simulation logic in Python. You can then implement the same logic in CUDA C and check that the simulation behaviors are the same. To help with this process, we provide an environment consistency checker that runs consistency tests between the Python and CUDA C simulations. Before training with your CUDA C code, please run the consistency checker to ensure that the Python and CUDA C implementations are consistent.\n",
343 | "```commandline\n",
344 |     "python scripts/run_cpu_gpu_env_consistency_checks.py\n",
345 | "```\n",
346 | "\n",
347 | "See the [tutorial notebook](https://colab.research.google.com/drive/1ifcYaczxy4eHM986fyFeSXarGAKmPLt5#scrollTo=vrIIciaAlHSl) for additional details on modifying the code to implement proposed negotiation protocols."
348 | ]
349 | }
350 | ],
351 | "metadata": {
352 | "kernelspec": {
353 | "display_name": "Python 3.8.10 64-bit",
354 | "language": "python",
355 | "name": "python3"
356 | },
357 | "language_info": {
358 | "codemirror_mode": {
359 | "name": "ipython",
360 | "version": 3
361 | },
362 | "file_extension": ".py",
363 | "mimetype": "text/x-python",
364 | "name": "python",
365 | "nbconvert_exporter": "python",
366 | "pygments_lexer": "ipython3",
367 | "version": "3.8.10"
368 | },
369 | "vscode": {
370 | "interpreter": {
371 | "hash": "570feb405e2e27c949193ac68f46852414290d515b0ba6e5d90d076ed2284471"
372 | }
373 | }
374 | },
375 | "nbformat": 4,
376 | "nbformat_minor": 4
377 | }
378 |
--------------------------------------------------------------------------------
/scripts/train_with_rllib.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | """
9 | Training script for the rice environment using RLlib
10 | https://docs.ray.io/en/latest/rllib-training.html
11 | """
12 |
13 | import logging
14 | import os
15 | import shutil
16 | import subprocess
17 | import sys
18 | import time
19 |
20 | import numpy as np
21 | import yaml
22 | from desired_outputs import desired_outputs
23 | from fixed_paths import PUBLIC_REPO_DIR
24 | from run_unittests import import_class_from_path
25 | from opt_helper import save
26 |
27 | sys.path.append(PUBLIC_REPO_DIR)
28 |
29 | # Set logger level e.g., DEBUG, INFO, WARNING, ERROR.
30 | logging.getLogger().setLevel(logging.DEBUG)
31 |
32 |
33 | def perform_other_imports():
34 | """
35 | RLlib-related imports.
36 | """
37 | import ray
38 | import torch
39 | from gym.spaces import Box, Dict
40 | from ray.rllib.agents.a3c import A2CTrainer
41 | from ray.rllib.env.multi_agent_env import MultiAgentEnv
42 | from ray.tune.logger import NoopLogger
43 |
44 | return ray, torch, Box, Dict, MultiAgentEnv, A2CTrainer, NoopLogger
45 |
46 |
47 | print("Do imports")
48 |
49 | try:
50 | other_imports = perform_other_imports()
51 | except ImportError:
52 | print("Installing requirements...")
53 |
54 | # Install gym
55 | subprocess.call(["pip", "install", "gym==0.21.0"])
56 | # Install RLlib v1.0.0
57 | subprocess.call(["pip", "install", "ray[rllib]==1.0.0"])
58 | # Install PyTorch
59 | subprocess.call(["pip", "install", "torch==1.9.0"])
60 |
61 | other_imports = perform_other_imports()
62 |
63 | ray, torch, Box, Dict, MultiAgentEnv, A2CTrainer, NoopLogger = other_imports
64 |
65 | from torch_models import TorchLinear
66 |
67 | logging.info("Finished imports")
68 |
69 |
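70 | # Large finite bound used for the low/high limits of the Box observation spaces.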
70 | _BIG_NUMBER = 1e20
71 |
72 |
73 | def recursive_obs_dict_to_spaces_dict(obs):
74 | """Recursively return the observation space dictionary
75 | for a dictionary of observations
76 |
77 | Args:
78 | obs (dict): A dictionary of observations keyed by agent index
79 | for a multi-agent environment
80 |
81 | Returns:
82 | spaces.Dict: A dictionary of observation spaces
83 | """
84 | assert isinstance(obs, dict)
85 | dict_of_spaces = {}
86 | for key, val in obs.items():
87 |
88 | # list of lists are 'listified' np arrays
89 | _val = val
90 | if isinstance(val, list):
91 | _val = np.array(val)
92 | elif isinstance(val, (int, np.integer, float, np.floating)):
93 | _val = np.array([val])
94 |
95 | # assign Space
96 | if isinstance(_val, np.ndarray):
97 | large_num = float(_BIG_NUMBER)
98 | box = Box(
99 | low=-large_num, high=large_num, shape=_val.shape, dtype=_val.dtype
100 | )
101 | low_high_valid = (box.low < 0).all() and (box.high > 0).all()
102 |
103 | # This loop avoids issues with overflow to make sure low/high are good.
104 | while not low_high_valid:
105 | large_num = large_num // 2
106 | box = Box(
107 | low=-large_num, high=large_num, shape=_val.shape, dtype=_val.dtype
108 | )
109 | low_high_valid = (box.low < 0).all() and (box.high > 0).all()
110 |
111 | dict_of_spaces[key] = box
112 |
113 | elif isinstance(_val, dict):
114 | dict_of_spaces[key] = recursive_obs_dict_to_spaces_dict(_val)
115 | else:
116 | raise TypeError
117 | return Dict(dict_of_spaces)
118 |
119 |
120 | def recursive_list_to_np_array(dictionary):
121 | """
122 | Numpy-ify dictionary object to be used with RLlib.
123 | """
124 | if isinstance(dictionary, dict):
125 | new_d = {}
126 | for key, val in dictionary.items():
127 | if isinstance(val, list):
128 | new_d[key] = np.array(val)
129 | elif isinstance(val, dict):
130 | new_d[key] = recursive_list_to_np_array(val)
131 | elif isinstance(val, (int, np.integer, float, np.floating)):
132 | new_d[key] = np.array([val])
133 | elif isinstance(val, np.ndarray):
134 | new_d[key] = val
135 | else:
136 | raise AssertionError
137 | return new_d
138 | raise AssertionError
139 |
140 |
141 | class EnvWrapper(MultiAgentEnv):
142 | """
143 | The environment wrapper class.
144 | """
145 |
146 | def __init__(self, env_config=None):
147 |
148 | super().__init__()
149 |
150 |         if env_config is None:
151 |             env_config = {}
152 |         env_config_copy = env_config.copy()
153 | source_dir = env_config_copy.get("source_dir", None)
154 | # Remove source_dir key in env_config if it exists
155 | if "source_dir" in env_config_copy:
156 | del env_config_copy["source_dir"]
157 | if source_dir is None:
158 | source_dir = PUBLIC_REPO_DIR
159 | assert isinstance(env_config_copy, dict)
160 | self.env = import_class_from_path("Rice", os.path.join(source_dir, "rice.py"))(
161 | **env_config_copy
162 | )
163 |
164 | self.action_space = self.env.action_space
165 |
166 | self.observation_space = recursive_obs_dict_to_spaces_dict(self.env.reset())
167 |
168 | def reset(self):
169 | """Reset the env."""
170 | obs = self.env.reset()
171 | return recursive_list_to_np_array(obs)
172 |
173 | def step(self, actions=None):
174 | """Step through the env."""
175 | assert actions is not None
176 | assert isinstance(actions, dict)
177 | obs, rew, done, info = self.env.step(actions)
178 | return recursive_list_to_np_array(obs), rew, done, info
179 |
180 |
181 | def get_rllib_config(exp_run_config=None, env_class=None, seed=None):
182 | """
183 | Reference: https://docs.ray.io/en/latest/rllib-training.html
184 | """
185 |
186 | assert exp_run_config is not None
187 | assert env_class is not None
188 |
189 | env_config = exp_run_config["env"]
190 | assert isinstance(env_config, dict)
191 | env_object = env_class(env_config=env_config)
192 |
193 | # Define all the policies here
194 | policy_config = exp_run_config["policy"]["regions"]
195 |
196 | # Map of type MultiAgentPolicyConfigDict from policy ids to tuples
197 | # of (policy_cls, obs_space, act_space, config). This defines the
198 | # observation and action spaces of the policies and any extra config.
199 | policies = {
200 | "regions": (
201 | None, # uses default policy
202 | env_object.observation_space[0],
203 | env_object.action_space[0],
204 | policy_config,
205 | ),
206 | }
207 |
208 | # Function mapping agent ids to policy ids.
209 | def policy_mapping_fn(agent_id=None):
210 | assert agent_id is not None
211 | return "regions"
212 |
213 | # Optional list of policies to train, or None for all policies.
214 | policies_to_train = None
215 |
216 | # Settings for Multi-Agent Environments
217 | multiagent_config = {
218 | "policies": policies,
219 | "policies_to_train": policies_to_train,
220 | "policy_mapping_fn": policy_mapping_fn,
221 | }
222 |
223 | train_config = exp_run_config["trainer"]
224 | rllib_config = {
225 | # Arguments dict passed to the env creator as an EnvContext object (which
226 | # is a dict plus the properties: num_workers, worker_index, vector_index,
227 | # and remote).
228 | "env_config": exp_run_config["env"],
229 | "framework": train_config["framework"],
230 | "multiagent": multiagent_config,
231 | "num_workers": train_config["num_workers"],
232 | "num_gpus": train_config["num_gpus"],
233 | "num_envs_per_worker": train_config["num_envs"] // train_config["num_workers"],
234 | "train_batch_size": train_config["train_batch_size"],
235 | }
236 | if seed is not None:
237 | rllib_config["seed"] = seed
238 |
239 | return rllib_config
240 |
241 |
242 | def save_model_checkpoint(trainer_obj=None, save_directory=None, current_timestep=0):
243 | """
244 | Save trained model checkpoints.
245 | """
246 | assert trainer_obj is not None
247 | assert save_directory is not None
248 | assert os.path.exists(save_directory), (
249 | "Invalid folder path. "
250 | "Please specify a valid directory to save the checkpoints."
251 | )
252 | model_params = trainer_obj.get_weights()
253 | for policy in model_params:
254 | filepath = os.path.join(
255 | save_directory,
256 | f"{policy}_{current_timestep}.state_dict",
257 | )
258 | logging.info(
259 |             "Saving the model checkpoints for policy %s to %s.", policy, filepath
260 | )
261 | torch.save(model_params[policy], filepath)
262 |
263 |
264 | def load_model_checkpoints(trainer_obj=None, save_directory=None, ckpt_idx=-1):
265 | """
266 | Load trained model checkpoints.
267 | """
268 | assert trainer_obj is not None
269 | assert save_directory is not None
270 | assert os.path.exists(save_directory), (
271 | "Invalid folder path. "
272 | "Please specify a valid directory to load the checkpoints from."
273 | )
274 | files = [f for f in os.listdir(save_directory) if f.endswith("state_dict")]
275 |
276 | assert len(files) == len(trainer_obj.config["multiagent"]["policies"])
277 |
278 | model_params = trainer_obj.get_weights()
279 | for policy in model_params:
280 | policy_models = [
281 | os.path.join(save_directory, file) for file in files if policy in file
282 | ]
283 | # If there are multiple files, then use the ckpt_idx to specify the checkpoint
284 | assert ckpt_idx < len(policy_models)
285 | sorted_policy_models = sorted(policy_models, key=os.path.getmtime)
286 | policy_model_file = sorted_policy_models[ckpt_idx]
287 | model_params[policy] = torch.load(policy_model_file)
288 |         logging.info("Loaded model checkpoints %s.", policy_model_file)
289 |
290 | trainer_obj.set_weights(model_params)
291 |
292 |
293 | def create_trainer(exp_run_config=None, source_dir=None, results_dir=None, seed=None):
294 | """
295 | Create the RLlib trainer.
296 | """
297 | assert exp_run_config is not None
298 | if results_dir is None:
299 | # Use the current time as the name for the results directory.
300 | results_dir = f"{time.time():10.0f}"
301 |
302 | # Directory to save model checkpoints and metrics
303 |
304 | save_config = exp_run_config["saving"]
305 | results_save_dir = os.path.join(
306 | save_config["basedir"],
307 | save_config["name"],
308 | save_config["tag"],
309 | results_dir,
310 | )
311 |
312 | ray.init(ignore_reinit_error=True)
313 |
314 | # Create the A2C trainer.
315 | exp_run_config["env"]["source_dir"] = source_dir
316 | rllib_trainer = A2CTrainer(
317 | env=EnvWrapper,
318 | config=get_rllib_config(
319 | exp_run_config=exp_run_config, env_class=EnvWrapper, seed=seed
320 | ),
321 | )
322 | return rllib_trainer, results_save_dir
323 |
324 |
325 | def fetch_episode_states(trainer_obj=None, episode_states=None):
326 | """
327 | Helper function to rollout the env and fetch env states for an episode.
328 | """
329 | assert trainer_obj is not None
330 | assert episode_states is not None
331 | assert isinstance(episode_states, list)
332 | assert len(episode_states) > 0
333 |
334 | outputs = {}
335 |
336 | # Fetch the env object from the trainer
337 | env_object = trainer_obj.workers.local_worker().env
338 | obs = env_object.reset()
339 |
340 | env = env_object.env
341 |
342 | for state in episode_states:
343 | assert state in env.global_state, f"{state} is not in global state!"
344 | # Initialize the episode states
345 | array_shape = env.global_state[state]["value"].shape
346 | outputs[state] = np.nan * np.ones(array_shape)
347 |
348 | agent_states = {}
349 | policy_ids = {}
350 | policy_mapping_fn = trainer_obj.config["multiagent"]["policy_mapping_fn"]
351 | for region_id in range(env.num_agents):
352 | policy_ids[region_id] = policy_mapping_fn(region_id)
353 | agent_states[region_id] = trainer_obj.get_policy(
354 | policy_ids[region_id]
355 | ).get_initial_state()
356 |
357 | for timestep in range(env.episode_length):
358 | for state in episode_states:
359 | outputs[state][timestep] = env.global_state[state]["value"][timestep]
360 |
361 | actions = {}
362 | # TODO: Consider using the `compute_actions` (instead of `compute_action`)
363 | # API below for speed-up when there are many agents.
364 | for region_id in range(env.num_agents):
365 | if (
366 | len(agent_states[region_id]) == 0
367 | ): # stateless, with a linear model, for example
368 | actions[region_id] = trainer_obj.compute_action(
369 | obs[region_id],
370 | agent_states[region_id],
371 | policy_id=policy_ids[region_id],
372 | )
373 | else: # stateful
374 | (
375 | actions[region_id],
376 | agent_states[region_id],
377 | _,
378 | ) = trainer_obj.compute_action(
379 | obs[region_id],
380 | agent_states[region_id],
381 | policy_id=policy_ids[region_id],
382 | )
383 | obs, _, done, _ = env_object.step(actions)
384 | if done["__all__"]:
385 | for state in episode_states:
386 | outputs[state][timestep + 1] = env.global_state[state]["value"][
387 | timestep + 1
388 | ]
389 | break
390 |
391 | return outputs
392 |
393 |
394 | def trainer(
395 | negotiation_on=0,
396 | num_envs=100,
397 | train_batch_size=1024,
398 | num_episodes=30000,
399 | lr=0.0005,
400 | model_params_save_freq=5000,
401 | desired_outputs=desired_outputs,
402 | num_workers=4,
403 | ):
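404 |     """
405 |     Train the rice environment with RLlib, overriding the defaults in
406 |     scripts/rice_rllib.yaml with the arguments passed in; returns the
407 |     trainer object and the logged episode output time series.
408 |     """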
404 | print("Training with RLlib...")
405 |
406 | # Read the run configurations specific to the environment.
407 |     # Note: The run config yaml can be edited at scripts/rice_rllib.yaml.
408 | # -----------------------------------------------------------------------------
409 | config_path = os.path.join(PUBLIC_REPO_DIR, "scripts", "rice_rllib.yaml")
410 | if not os.path.exists(config_path):
411 | raise ValueError(
412 | "The run configuration is missing. Please make sure the correct path "
413 | "is specified."
414 | )
415 |
416 | with open(config_path, "r", encoding="utf8") as fp:
417 | run_config = yaml.safe_load(fp)
418 | # replace the default setting
419 | run_config["env"]["negotiation_on"] = negotiation_on
420 | run_config["trainer"]["num_envs"] = num_envs
421 | run_config["trainer"]["train_batch_size"] = train_batch_size
422 | run_config["trainer"]["num_workers"] = num_workers
423 | run_config["trainer"]["num_episodes"] = num_episodes
424 | run_config["policy"]["regions"]["lr"] = lr
425 | run_config["saving"]["model_params_save_freq"] = model_params_save_freq
426 |
427 | # Create trainer
428 | # --------------
429 | trainer, save_dir = create_trainer(run_config)
430 | # debug: print("trainer weghts: ", trainer.get_weights()["regions"]["policy_head.97.weight"])
431 | # Copy the source files into the results directory
432 | # ------------------------------------------------
433 | os.makedirs(save_dir)
434 | with open(os.path.join(save_dir, "rice_rllib.yaml"), "w") as yaml_file:
435 | yaml.dump(run_config, yaml_file)
436 | # Copy source files to the saving directory
437 | for file in ["rice.py", "rice_helpers.py"]:
438 | shutil.copyfile(
439 | os.path.join(PUBLIC_REPO_DIR, file),
440 | os.path.join(save_dir, file),
441 | )
442 |
443 | # Add an identifier file
444 | with open(os.path.join(save_dir, ".rllib"), "x", encoding="utf-8") as fp:
445 | pass
446 | fp.close()
447 |
448 | # Perform training
449 | # ----------------
450 | trainer_config = run_config["trainer"]
451 | # num_episodes = trainer_config["num_episodes"]
452 | # train_batch_size = trainer_config["train_batch_size"]
453 | # Fetch the env object from the trainer
454 | env_obj = trainer.workers.local_worker().env.env
455 | episode_length = env_obj.episode_length
456 | num_iters = (num_episodes * episode_length) // train_batch_size
457 |
458 | for iteration in range(num_iters):
459 | print(f"********** Iter : {iteration + 1:5d} / {num_iters:5d} **********")
460 | result = trainer.train()
461 | total_timesteps = result.get("timesteps_total")
462 | if (
463 | iteration % run_config["saving"]["model_params_save_freq"] == 0
464 | or iteration == num_iters - 1
465 | ):
466 | save_model_checkpoint(trainer, save_dir, total_timesteps)
467 | logging.info(result)
468 | print(f"""episode_reward_mean: {result.get('episode_reward_mean')}""")
469 |
470 | outputs_ts = fetch_episode_states(trainer, desired_outputs)
471 | save(
472 | outputs_ts,
473 | os.path.join(
474 | save_dir,
475 | f"outputs_ts_{total_timesteps}.pkl",
476 | ),
477 | )
478 | print(f"Saving logged outputs to {save_dir}")
479 | # Create a (zipped) submission file
480 | # ---------------------------------
481 | subprocess.call(
482 | [
483 | "python",
484 | os.path.join(PUBLIC_REPO_DIR, "scripts", "create_submission_zip.py"),
485 | "--results_dir",
486 | save_dir,
487 | ]
488 | )
489 | # Close Ray gracefully after completion
490 | ray.shutdown()
491 | return trainer, outputs_ts
492 |
493 |
494 | if __name__ == "__main__":
495 | print("Training with RLlib...")
496 |
497 | # Read the run configurations specific to the environment.
498 |     # Note: The run config yaml can be edited at scripts/rice_rllib.yaml.
499 | # -----------------------------------------------------------------------------
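500 |     # An alternative config path may be supplied via the CONFIG_FILE environment variable.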
500 | config_path = os.getenv("CONFIG_FILE", os.path.join(PUBLIC_REPO_DIR, "scripts", "rice_rllib.yaml"))
501 | if not os.path.exists(config_path):
502 | raise ValueError(
503 | "The run configuration is missing. Please make sure the correct path "
504 | "is specified."
505 | )
506 |
507 | with open(config_path, "r", encoding="utf8") as fp:
508 | run_config = yaml.safe_load(fp)
509 |
510 | # Create trainer
511 | # --------------
512 | trainer, save_dir = create_trainer(run_config)
513 |
514 | # Copy the source files into the results directory
515 | # ------------------------------------------------
516 | os.makedirs(save_dir)
517 | # Copy source files to the saving directory
518 | for file in ["rice.py", "rice_helpers.py"]:
519 | shutil.copyfile(
520 | os.path.join(PUBLIC_REPO_DIR, file),
521 | os.path.join(save_dir, file),
522 | )
523 | for file in ["rice_rllib.yaml"]:
524 | shutil.copyfile(
525 | os.path.join(PUBLIC_REPO_DIR, "scripts", file),
526 | os.path.join(save_dir, file),
527 | )
528 |
529 | # Add an identifier file
530 | with open(os.path.join(save_dir, ".rllib"), "x", encoding="utf-8") as fp:
531 | pass
532 | fp.close()
533 |
534 | # Perform training
535 | # ----------------
536 | trainer_config = run_config["trainer"]
537 | num_episodes = trainer_config["num_episodes"]
538 | train_batch_size = trainer_config["train_batch_size"]
539 | # Fetch the env object from the trainer
540 | env_obj = trainer.workers.local_worker().env.env
541 | episode_length = env_obj.episode_length
542 | num_iters = (num_episodes * episode_length) // train_batch_size
543 |
544 | for iteration in range(num_iters):
545 | print(f"********** Iter : {iteration + 1:5d} / {num_iters:5d} **********")
546 | result = trainer.train()
547 | total_timesteps = result.get("timesteps_total")
548 | if (
549 | iteration % run_config["saving"]["model_params_save_freq"] == 0
550 | or iteration == num_iters - 1
551 | ):
552 | save_model_checkpoint(trainer, save_dir, total_timesteps)
553 | logging.info(result)
554 | print(f"""episode_reward_mean: {result.get('episode_reward_mean')}""")
555 |
556 | # Create a (zipped) submission file
557 | # ---------------------------------
558 | subprocess.call(
559 | [
560 | "python",
561 | os.path.join(PUBLIC_REPO_DIR, "scripts", "create_submission_zip.py"),
562 | "--results_dir",
563 | save_dir,
564 | ]
565 | )
566 |
567 | # Close Ray gracefully after completion
568 | ray.shutdown()
569 |
--------------------------------------------------------------------------------
/Visualization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%autoreload 2"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "U5SXMcYAarz5"
25 | },
26 | "source": [
27 | "# Visualization of RICE-N and RL training\n"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "import os, glob, sys, numpy as np, scipy as sp, sklearn as skl"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "if \"_ROOT\" not in globals():\n",
46 | " _ROOT = os.getcwd()\n",
47 | " print(f\"Set _ROOT = {_ROOT}\")\n",
48 | "else: \n",
49 | " print(f\"Already set: _ROOT = {_ROOT}\")"
50 | ]
51 | },
52 | {
53 | "attachments": {},
54 | "cell_type": "markdown",
55 | "metadata": {
56 | "id": "OF4mXVOHlHS3"
57 | },
58 | "source": [
59 | "## Save or load from previous training results"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {
65 | "id": "80g85HbZlHS4"
66 | },
67 | "source": [
68 | "This section is for saving and loading the results of training (not the trainer itself)."
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {
75 | "id": "K9zxtDGzlHS4"
76 | },
77 | "outputs": [],
78 | "source": [
79 | "from opt_helper import save, load"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {
85 | "id": "03fOZB9fpHAs"
86 | },
87 | "source": [
88 | "To save the output timeseries: "
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {
95 | "id": "lYF6UDHKlHS4"
96 | },
97 | "outputs": [],
98 | "source": [
99 | "# [uncomment below to save]\n",
100 | "# save({\"nego_off\":gpu_nego_off_ts, \"nego_on\":gpu_nego_on_ts}, \"filename.pkl\")"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {
106 | "id": "0vG1JZ75pIa7"
107 | },
108 | "source": [
109 | "To load the output timeseries:"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "id": "4TEE7CvHlHS4"
117 | },
118 | "outputs": [],
119 | "source": [
120 |     "# Load the example output time series shipped with the repo\n",
121 | "dict_ts = load(\"example_data/example.pkl\")\n",
122 | "nego_off_ts, nego_on_ts = dict_ts[\"nego_off\"], dict_ts[\"nego_on\"]"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "for key, value in nego_off_ts.items(): \n",
132 | " print(f\"{key:40} {value.shape}\")"
133 | ]
134 | },
135 | {
136 | "attachments": {},
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | "## Plot training procedures"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 |     "One may want to plot some of the metrics, such as `mean reward`, which are logged during the training procedure.\n",
148 | "\n",
149 | "```python\n",
150 | "metrics = ['Iterations Completed',\n",
151 | " 'VF loss coefficient',\n",
152 | " 'Entropy coefficient',\n",
153 | " 'Total loss',\n",
154 | " 'Policy loss',\n",
155 | " 'Value function loss',\n",
156 | " 'Mean rewards',\n",
157 | " 'Max. rewards',\n",
158 | " 'Min. rewards',\n",
159 | " 'Mean value function',\n",
160 | " 'Mean advantages',\n",
161 | " 'Mean (norm.) advantages',\n",
162 | " 'Mean (discounted) returns',\n",
163 | " 'Mean normalized returns',\n",
164 | " 'Mean entropy',\n",
165 | " 'Variance explained by the value function',\n",
166 | " 'Gradient norm',\n",
167 | " 'Learning rate',\n",
168 | " 'Mean episodic reward',\n",
169 | " 'Mean policy eval time per iter (ms)',\n",
170 | " 'Mean action sample time per iter (ms)',\n",
171 | " 'Mean env. step time per iter (ms)',\n",
172 | " 'Mean training time per iter (ms)',\n",
173 | " 'Mean total time per iter (ms)',\n",
174 | " 'Mean steps per sec (policy eval)',\n",
175 | " 'Mean steps per sec (action sample)',\n",
176 | " 'Mean steps per sec (env. step)',\n",
177 | " 'Mean steps per sec (training time)',\n",
178 | " 'Mean steps per sec (total)'\n",
179 | " ]\n",
180 | "```"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 | "To check out the logged submissions, please run the following block."
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "from glob import glob\n",
197 | "submission_zip_files = glob(os.path.join(_ROOT,\"Submissions/*.zip\"))"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 |     "If previous training runs have finished and were logged properly, this should give a list of `*.zip` files containing the logs. \n",
205 |     "\n",
206 |     "We picked one of the submissions and the metric `Mean episodic reward` as an example; please check the code below."
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": null,
212 | "metadata": {},
213 | "outputs": [],
214 | "source": [
215 | "from opt_helper import get_training_curve, plot_training_curve"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {},
221 | "source": [
222 | "### WarpDrive version"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {},
229 | "outputs": [],
230 | "source": [
231 | "# TBC"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "### RLLib (CPU) version"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": null,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "log_zip = submission_zip_files[1]\n",
248 | "plot_training_curve(None, 'Mean episodic reward', log_zip)\n",
249 | "\n",
250 | "# to check the raw logging dictionary, uncomment below\n",
251 | "# logs = get_training_curve(log_zip)\n",
252 | "# logs"
253 | ]
254 | },
255 | {
256 | "attachments": {},
257 | "cell_type": "markdown",
258 | "metadata": {
259 | "id": "0BCG5IYWlHS5"
260 | },
261 | "source": [
262 | "## Plot results"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": null,
268 | "metadata": {
269 | "id": "oZW6-QJGlHS5"
270 | },
271 | "outputs": [],
272 | "source": [
273 | "from scripts.desired_outputs import desired_outputs"
274 | ]
275 | },
276 | {
277 | "cell_type": "markdown",
278 | "metadata": {
279 | "id": "AdxL0JanlHS5"
280 | },
281 | "source": [
282 |     "One may want to check the performance of the agents by plotting graphs. Below, we list all the logged variables. One may edit ``desired_outputs.py`` to add more variables of interest.\n",
283 | "\n",
284 | "```python\n",
285 | "desired_outputs = ['global_temperature', \n",
286 | " 'global_carbon_mass', \n",
287 | " 'capital_all_regions', \n",
288 | " 'labor_all_regions', \n",
289 | " 'production_factor_all_regions', \n",
290 | " 'intensity_all_regions', \n",
291 | " 'global_exogenous_emissions', \n",
292 | " 'global_land_emissions', \n",
293 | " 'timestep', \n",
294 | " 'activity_timestep', \n",
295 | " 'capital_depreciation_all_regions', \n",
296 | " 'savings_all_regions', \n",
297 | " 'mitigation_rate_all_regions', \n",
298 | " 'max_export_limit_all_regions', \n",
299 | " 'mitigation_cost_all_regions', \n",
300 | " 'damages_all_regions', \n",
301 | " 'abatement_cost_all_regions', \n",
302 | " 'utility_all_regions', \n",
303 | " 'social_welfare_all_regions', \n",
304 | " 'reward_all_regions', \n",
305 | " 'consumption_all_regions', \n",
306 | " 'current_balance_all_regions', \n",
307 | " 'gross_output_all_regions', \n",
308 | " 'investment_all_regions', \n",
309 | " 'production_all_regions', \n",
310 | " 'tariffs', \n",
311 | " 'future_tariffs', \n",
312 | " 'scaled_imports', \n",
313 | " 'desired_imports', \n",
314 | " 'tariffed_imports', \n",
315 | " 'stage', \n",
316 | " 'minimum_mitigation_rate_all_regions', \n",
317 | " 'promised_mitigation_rate', \n",
318 | " 'requested_mitigation_rate', \n",
319 | " 'proposal_decisions',\n",
320 | " 'global_consumption',\n",
321 | " 'global_production']\n",
322 | "```"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": null,
328 | "metadata": {
329 | "id": "rYLsNHRjlHS5"
330 | },
331 | "outputs": [],
332 | "source": [
333 | "from opt_helper import plot_result"
334 | ]
335 | },
336 | {
337 | "cell_type": "markdown",
338 | "metadata": {
339 | "id": "-Ab3Dd-zlHS5"
340 | },
341 | "source": [
342 | "`plot_result()` plots the time series of logged variables.\n",
343 | "\n",
344 | "```python\n",
345 | "plot_result(variables, nego_off, nego_on, k)\n",
346 | "```\n",
347 |     "* ``variables`` can be either a single variable of interest or a list of variable names from the list above. \n",
348 |     "* ``nego_off`` and ``nego_on`` take the logged time series for these variables (e.g., ``nego_off_ts`` and ``nego_on_ts``), with and without negotiation. \n",
349 |     "* ``k`` is the dimension of the variable of interest (it should be ``0`` for most situations)."
350 | ]
351 | },
352 | {
353 | "cell_type": "markdown",
354 | "metadata": {
355 | "id": "mrn13h2V2jBQ"
356 | },
357 | "source": [
358 | "Here's an example of plotting a single variable of interest."
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": null,
364 | "metadata": {
365 | "id": "nDdreoz2lHS6"
366 | },
367 | "outputs": [],
368 | "source": [
369 | "plot_result(\"global_temperature\", \n",
370 | " nego_off=nego_off_ts, # change it to cpu_nego_off_ts if using CPU\n",
371 | " nego_on=nego_on_ts, \n",
372 | " k=0)"
373 | ]
374 | },
375 | {
376 | "cell_type": "markdown",
377 | "metadata": {
378 | "id": "VNXPz9kk2meL"
379 | },
380 | "source": [
381 | "Here's an example of plotting a list of variables."
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "metadata": {
388 | "id": "Z76yOBp3lHS5",
389 | "scrolled": true
390 | },
391 | "outputs": [],
392 | "source": [
393 | "plot_result(desired_outputs[0:3], # truncated for demonstration purposes\n",
394 | " nego_off=nego_off_ts, \n",
395 | " nego_on=nego_on_ts, \n",
396 | " k=0)"
397 | ]
398 | },
399 | {
400 | "cell_type": "markdown",
401 | "metadata": {
402 | "id": "6sJSQ5gdxCni"
403 | },
404 | "source": [
405 |     "If one only wants to plot the negotiation-off results, feel free to set `nego_on=None`. "
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": null,
411 | "metadata": {
412 | "id": "FkuU8kV2xCnj"
413 | },
414 | "outputs": [],
415 | "source": [
416 | "plot_result(desired_outputs[0:3], # truncated for demonstration purposes\n",
417 | " nego_off=nego_off_ts, \n",
418 | " nego_on=None, \n",
419 | " k=0)"
420 | ]
421 | },
422 | {
423 | "cell_type": "markdown",
424 | "metadata": {},
425 | "source": [
426 | "## Plot region data using a grid plot"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": null,
432 | "metadata": {},
433 | "outputs": [],
434 | "source": [
435 | "from opt_helper import make_grid_plot"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": null,
441 | "metadata": {},
442 | "outputs": [],
443 | "source": [
444 | "feature_name = \"labor_all_regions\"\n",
445 |     "feature_label = feature_name.replace(\"_\", \" \").title()\n",
446 | "make_grid_plot(\n",
447 | " nego_off_ts[feature_name], \n",
448 | " xlabel=\"Year\", \n",
449 | " ylabel=feature_label, \n",
450 | " feature_label=feature_label, \n",
451 | " cols=3, \n",
452 | " fig_scale=3\n",
453 | ");"
454 | ]
455 | },
456 | {
457 | "cell_type": "markdown",
458 | "metadata": {},
459 | "source": [
460 | "## Plot multiple time series data with mean and spread around the mean"
461 | ]
462 | },
463 | {
464 | "cell_type": "code",
465 | "execution_count": null,
466 | "metadata": {},
467 | "outputs": [],
468 | "source": [
469 | "from opt_helper import plot_fig_with_bounds"
470 | ]
471 | },
472 | {
473 | "cell_type": "markdown",
474 | "metadata": {},
475 | "source": [
476 | "Generating some example perturbed data around the original data."
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": null,
482 | "metadata": {},
483 | "outputs": [],
484 | "source": [
485 | "from copy import deepcopy \n",
486 | "perturbed_nego_on_ts = [\n",
487 | " deepcopy(nego_on_ts),\n",
488 | " deepcopy(nego_on_ts),\n",
489 | " deepcopy(nego_on_ts),\n",
490 | "]\n",
491 | "\n",
492 | "for key, ts in nego_on_ts.items():\n",
493 | " perturbed_nego_on_ts[0][key] = ts\n",
494 | " perturbed_nego_on_ts[1][key] = ts + 0.1 * ts\n",
495 | " perturbed_nego_on_ts[2][key] = ts - 0.1 * ts\n",
496 | "\n",
497 | "\n",
498 | "perturbed_nego_off_ts = [\n",
499 | " deepcopy(nego_off_ts),\n",
500 | " deepcopy(nego_off_ts),\n",
501 | " deepcopy(nego_off_ts),\n",
502 | "]\n",
503 | "\n",
504 | "for key, ts in nego_off_ts.items():\n",
505 | " perturbed_nego_off_ts[0][key] = ts\n",
506 | " perturbed_nego_off_ts[1][key] = ts + 0.1 * ts\n",
507 | " perturbed_nego_off_ts[2][key] = ts - 0.1 * ts"
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "execution_count": null,
513 | "metadata": {},
514 | "outputs": [],
515 | "source": [
516 | "feature_name =\"labor_all_regions\"\n",
517 | "\n",
518 | "plot_fig_with_bounds(\n",
519 | " feature_name, # variable,\n",
520 | " \"y_label\", # y_label,\n",
521 | " list_of_dict_off=perturbed_nego_off_ts,\n",
522 | " list_of_dict_on=perturbed_nego_on_ts,\n",
523 | " title=None,\n",
524 | " idx=0,\n",
525 | " x_label=\"year\",\n",
526 | " skips=3,\n",
527 | " line_colors=[\"#0868ac\", \"#7e0018\"],\n",
528 | " region_colors=[\"#7bccc4\", \"#ffac3b\"],\n",
529 | " start=2020,\n",
530 | " alpha=0.5,\n",
531 | " is_grid=True,\n",
532 | " is_save=True,\n",
533 | " delta=5,\n",
534 | ")"
535 | ]
536 | },
537 | {
538 | "attachments": {},
539 | "cell_type": "markdown",
540 | "metadata": {},
541 | "source": [
542 | "## Aggregate simulation statistics to visualize"
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": null,
548 | "metadata": {},
549 | "outputs": [],
550 | "source": [
551 | "for k, v in nego_off_ts.items(): \n",
552 | " print(f\"{k:40}{v.shape}\")"
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": null,
558 | "metadata": {},
559 | "outputs": [],
560 | "source": [
561 | "feature_name = \"mitigation_rate_all_regions\"\n",
562 | "feature_label = feature_name.replace(\"_\", \" \").title()\n",
563 | "make_grid_plot(\n",
564 | " nego_on_ts[feature_name],\n",
565 | " feature_label=feature_label, \n",
566 | " xlabel=\"Step\",\n",
567 | " ylabel=feature_label,\n",
568 | " cols=4,\n",
569 | " fig_scale=3,\n",
570 | ");"
571 | ]
572 | },
573 | {
574 | "cell_type": "code",
575 | "execution_count": null,
576 | "metadata": {},
577 | "outputs": [],
578 | "source": [
579 | "feature_name = \"savings_all_regions\"\n",
580 | "feature_label = feature_name.replace(\"_\", \" \").title()\n",
581 | "make_grid_plot(\n",
582 | " nego_on_ts[feature_name],\n",
583 | " feature_label=feature_label, \n",
584 | " xlabel=\"Step\",\n",
585 | " ylabel=feature_label,\n",
586 | " cols=4,\n",
587 | " fig_scale=3,\n",
588 | ");"
589 | ]
590 | },
591 | {
592 | "cell_type": "markdown",
593 | "metadata": {},
594 | "source": [
595 | "## Cluster regions"
596 | ]
597 | },
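598 |   {
599 |    "cell_type": "markdown",
600 |    "metadata": {},
601 |    "source": [
602 |     "As a simple illustrative heuristic, the cell below groups regions into low / medium / high savings clusters using fixed thresholds (0.25 and 0.35) on each region's mean savings rate; adjust these thresholds as needed for your own runs."
603 |    ]
604 |   },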
598 | {
599 | "cell_type": "code",
600 | "execution_count": null,
601 | "metadata": {},
602 | "outputs": [],
603 | "source": [
604 | "feature_name = \"savings_all_regions\"\n",
605 | "lo_savings, med_savings, hi_savings = [], [], []\n",
606 | "steps, n_regions = nego_on_ts[feature_name].shape\n",
607 | "for region_j in range(n_regions):\n",
608 | " mean_region_j = np.mean(nego_on_ts[feature_name][:, region_j])\n",
609 | " print(f\"{region_j}: {mean_region_j:.2f}\")\n",
610 | " if mean_region_j < 0.25: \n",
611 | " lo_savings.append(region_j)\n",
612 | " elif 0.25 < mean_region_j < 0.35:\n",
613 | " med_savings.append(region_j)\n",
614 | " else:\n",
615 | " hi_savings.append(region_j)"
616 | ]
617 | },
618 | {
619 | "cell_type": "code",
620 | "execution_count": null,
621 | "metadata": {},
622 | "outputs": [],
623 | "source": [
624 | "lo_savings, med_savings, hi_savings"
625 | ]
626 | },
627 | {
628 | "cell_type": "code",
629 | "execution_count": null,
630 | "metadata": {},
631 | "outputs": [],
632 | "source": [
633 | "from opt_helper import make_aggregate_data_across_three_clusters"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": null,
639 | "metadata": {},
640 | "outputs": [],
641 | "source": [
642 | "aggregate_nego_on_ts = make_aggregate_data_across_three_clusters(\n",
643 | " nego_on_ts, \n",
644 | " lo_savings, \n",
645 | " med_savings, \n",
646 | " hi_savings\n",
647 | ")"
648 | ]
649 | },
650 | {
651 | "cell_type": "code",
652 | "execution_count": null,
653 | "metadata": {},
654 | "outputs": [],
655 | "source": [
656 | "feature_name = \"savings_all_regions\"\n",
657 | "feature_label = feature_name.replace(\"_\", \" \").title() + \" - aggregate\"\n",
658 | "make_grid_plot(\n",
659 | " aggregate_nego_on_ts[feature_name],\n",
660 | " feature_label=feature_label, \n",
661 | " xlabel=\"Step\",\n",
662 | " ylabel=feature_label,\n",
663 | " cols=4,\n",
664 | " fig_scale=3,\n",
665 | ");"
666 | ]
667 | },
668 | {
669 | "cell_type": "markdown",
670 | "metadata": {},
671 | "source": [
672 | "## Correlation between clusters and feature of those cluster members?"
673 | ]
674 | },
675 | {
676 | "cell_type": "code",
677 | "execution_count": null,
678 | "metadata": {},
679 | "outputs": [],
680 | "source": [
681 | "from opt_helper import compute_correlation_across_groups"
682 | ]
683 | },
684 | {
685 | "cell_type": "code",
686 | "execution_count": null,
687 | "metadata": {},
688 | "outputs": [],
689 | "source": [
690 | "# Define the groups of regions that are aggregated.\n",
691 |     "# Note: these should be the same groups that were used to generate the aggregate statistics.\n",
692 | "groups = (lo_savings, med_savings, hi_savings)\n",
693 | "\n",
694 |     "# Define the X variable\n",
695 | "x_feature_name = \"savings_all_regions\"\n",
696 | "aggregate_stats_across_groups = aggregate_nego_on_ts[x_feature_name]\n",
697 | "\n",
698 | "# Define the y data\n",
699 | "data_ts = nego_on_ts\n",
700 | "\n",
701 | "# Only compute correlations for the y variables that are region-specific time-series.\n",
702 | "y_features_names = [\n",
703 | " i for i in data_ts.keys() if \"all_regions\" in i\n",
704 | "]\n",
705 | "\n",
706 | "for y_feature_name in y_features_names: \n",
707 | " r2 = compute_correlation_across_groups(aggregate_stats_across_groups, \n",
708 | " data_ts, \n",
709 | " y_feature_name, \n",
710 | " do_plot=False\n",
711 | " )\n",
712 | "\n",
713 | " print(f\"{x_feature_name:35} vs {y_feature_name:35} r2 = {r2:5.2f}\")"
714 | ]
715 | }
716 | ],
717 | "metadata": {
718 | "accelerator": "GPU",
719 | "colab": {
720 | "collapsed_sections": [
721 | "ytMyQ2OHlHSr",
722 | "ukp1MeR1Q0dG",
723 | "4BIL2upxlHSw",
724 | "hrqSp18wlHS2",
725 | "OF4mXVOHlHS3",
726 | "0BCG5IYWlHS5",
727 | "yDI4p7cqlHS6",
728 | "LuF76W4FlHS6"
729 | ],
730 | "name": "Copy of Colab_Tutorial.ipynb",
731 | "provenance": []
732 | },
733 | "gpuClass": "standard",
734 | "kernelspec": {
735 | "display_name": "Python 3 (ipykernel)",
736 | "language": "python",
737 | "name": "python3"
738 | },
739 | "language_info": {
740 | "codemirror_mode": {
741 | "name": "ipython",
742 | "version": 3
743 | },
744 | "file_extension": ".py",
745 | "mimetype": "text/x-python",
746 | "name": "python",
747 | "nbconvert_exporter": "python",
748 | "pygments_lexer": "ipython3",
749 | "version": "3.7.16"
750 | },
751 | "vscode": {
752 | "interpreter": {
753 | "hash": "251d2b4a5597da75c3ffbd37b4ea77a636d0b2648b30b76d98939345ff58fd97"
754 | }
755 | }
756 | },
757 | "nbformat": 4,
758 | "nbformat_minor": 1
759 | }
760 |
--------------------------------------------------------------------------------
/scripts/evaluate_submission.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, salesforce.com, inc and MILA.
2 | # All rights reserved.
3 | # SPDX-License-Identifier: BSD-3-Clause
4 | # For full license text, see the LICENSE file in the repo root
5 | # or https://opensource.org/licenses/BSD-3-Clause
6 |
7 |
8 | """
9 | Evaluation script for the rice environment
10 | """
11 |
12 | import argparse
13 | import logging
14 | import os
15 | import shutil
16 | import subprocess
17 | import sys
18 | import time
19 | from collections import OrderedDict
20 | from pathlib import Path
21 | import json
22 | import numpy as np
23 | import yaml
24 |
25 | _path = Path(os.path.abspath(__file__))
26 |
27 | from fixed_paths import PUBLIC_REPO_DIR
28 | from run_unittests import fetch_base_env
29 | from gym.spaces import MultiDiscrete
30 |
31 | # climate-cooperation-competition
32 | sys.path.append(os.path.join(PUBLIC_REPO_DIR, "scripts"))
33 | logging.info("Using PUBLIC_REPO_DIR = {}".format(PUBLIC_REPO_DIR))
34 |
35 | # mila-sfdc-...
36 | _PRIVATE_REPO_DIR = os.path.join(
37 | _path.parent.parent.parent.absolute(), "private-repo-clone"
38 | )
39 | sys.path.append(os.path.join(_PRIVATE_REPO_DIR, "backend"))
40 | logging.info("Using _PRIVATE_REPO_DIR = {}".format(_PRIVATE_REPO_DIR))
41 |
42 |
43 | # Set logger level e.g., DEBUG, INFO, WARNING, ERROR.
44 | logging.getLogger().setLevel(logging.ERROR)
45 |
46 | _EVAL_SEED = 1234567890 # seed used for evaluation
47 |
48 | _INDEXES_FILENAME = "climate_economic_min_max_indices.txt"
49 |
50 | _METRICS_TO_LABEL_DICT = OrderedDict()
51 | # Read the dict values below as
52 | # (label, decimal points used to round off value: 0 becomes an integer)
53 | _METRICS_TO_LABEL_DICT["reward_all_regions"] = ("Episode Reward", 2)
54 | _METRICS_TO_LABEL_DICT["global_temperature"] = ("Temperature Rise", 2)
55 | _METRICS_TO_LABEL_DICT["global_carbon_mass"] = ("Carbon Mass", 0)
56 | _METRICS_TO_LABEL_DICT["capital_all_regions"] = ("Capital", 0)
57 | _METRICS_TO_LABEL_DICT["production_all_regions"] = ("Production", 0)
58 | _METRICS_TO_LABEL_DICT["gross_output_all_regions"] = ("Gross Output", 0)
59 | _METRICS_TO_LABEL_DICT["investment_all_regions"] = ("Investment", 0)
60 | _METRICS_TO_LABEL_DICT["abatement_cost_all_regions"] = ("Abatement Cost", 2)
61 |
62 |
63 | def get_imports(framework=None):
64 | """
65 | Fetch relevant imports.
66 | """
67 | assert framework is not None
68 | if framework == "rllib":
69 | from train_with_rllib import (
70 | create_trainer,
71 | fetch_episode_states,
72 | load_model_checkpoints,
73 | )
74 | elif framework == "warpdrive":
75 | from train_with_warp_drive import (
76 | create_trainer,
77 | fetch_episode_states,
78 | load_model_checkpoints,
79 | )
80 | else:
81 | raise ValueError(f"Unknown framework {framework}!")
82 | return create_trainer, load_model_checkpoints, fetch_episode_states
83 |
84 |
85 | def try_to_unzip_file(path):
86 | """
87 |     Unzip the submission archive into a temporary directory and return its path.
88 | """
89 | try:
90 | _unzipped_dir = os.path.join("/tmp", str(time.time()))
91 | shutil.unpack_archive(path, _unzipped_dir)
92 | return _unzipped_dir
93 | except Exception as err:
94 | raise ValueError("Cannot obtain the results directory") from err
95 |
96 |
97 | def validate_dir(results_dir=None):
98 | """
99 | Validate that all the required files are present in the 'results' directory.
100 | """
101 | assert results_dir is not None
102 | framework = None
103 |
104 | files = os.listdir(results_dir)
105 | if ".warpdrive" in files:
106 | framework = "warpdrive"
107 | # Warpdrive was used for training
108 | for file in [
109 | "rice.py",
110 | "rice_helpers.py",
111 | "rice_cuda.py",
112 | "rice_step.cu",
113 | "rice_warpdrive.yaml",
114 | ]:
115 | if file not in files:
116 | success = False
117 | logging.error(
118 | "%s is not present in the results directory: %s!", file, results_dir
119 | )
120 | comment = f"{file} is not present in the results directory!"
121 | break
122 | success = True
123 | comment = "Valid submission"
124 | elif ".rllib" in files:
125 | framework = "rllib"
126 | # RLlib was used for training
127 | for file in ["rice.py", "rice_helpers.py", "rice_rllib.yaml"]:
128 | if file not in files:
129 | success = False
130 | logging.error(
131 | "%s is not present in the results directory: %s!", file, results_dir
132 | )
133 | comment = f"{file} is not present in the results directory!"
134 | break
135 | success = True
136 | comment = "Valid submission"
137 | else:
138 | success = False
139 | logging.error(
140 | "Missing identifier file! Either the .rllib or the .warpdrive "
141 | "file must be present in the results directory: %s",
142 | results_dir,
143 | )
144 | comment = "Missing identifier file!"
145 | print("comment", comment)
146 | return framework, success, comment
147 |
148 |
149 | def compute_metrics(
150 | fetch_episode_states, trainer, framework, num_episodes=1, include_c_e_idx=True
151 | ):
152 | """
153 | Generate episode rollouts and compute metrics.
154 | """
155 | assert trainer is not None
156 | available_frameworks = ["rllib", "warpdrive"]
157 | assert (
158 | framework in available_frameworks
159 |     ), f"Invalid framework {framework}, should be in {available_frameworks}."
160 |
161 | # Fetch all the desired outputs to compute various metrics.
162 | desired_outputs = list(_METRICS_TO_LABEL_DICT.keys())
163 | # Add auxiliary outputs required for processing
164 | required_outputs = desired_outputs + ["activity_timestep"]
165 |
166 | episode_states = {}
167 | eval_metrics = {}
168 | try:
169 | for episode_id in range(num_episodes):
170 | if fetch_episode_states is not None:
171 | episode_states[episode_id] = fetch_episode_states(
172 | trainer, required_outputs
173 | )
174 | else:
175 | episode_states[episode_id] = trainer.fetch_episode_global_states(
176 | required_outputs
177 | )
178 |
179 | for feature in desired_outputs:
180 | feature_values = [None for _ in range(num_episodes)]
181 |
182 | if feature == "global_temperature":
183 | # Get the temp rise for upper strata
184 | for episode_id in range(num_episodes):
185 | feature_values[episode_id] = (
186 | episode_states[episode_id][feature][-1, 0]
187 | - episode_states[episode_id][feature][0, 0]
188 | )
189 |
190 | elif feature == "global_carbon_mass":
191 | for episode_id in range(num_episodes):
192 | feature_values[episode_id] = episode_states[episode_id][feature][
193 | -1, 0
194 | ]
195 |
196 | elif feature == "gross_output_all_regions":
197 | for episode_id in range(num_episodes):
198 | # collect gross output results based on activity timestep
199 | activity_timestep = episode_states[episode_id]["activity_timestep"]
200 | activity_index = np.append(
201 | 1.0, np.diff(activity_timestep.squeeze())
202 | )
203 | activity_index = [np.isclose(v, 1.0) for v in activity_index]
204 | feature_values[episode_id] = np.sum(
205 | episode_states[episode_id]["gross_output_all_regions"][
206 | activity_index
207 | ]
208 | )
209 |
210 | else:
211 | for episode_id in range(num_episodes):
212 | feature_values[episode_id] = np.sum(
213 | episode_states[episode_id][feature]
214 | )
215 |
216 | # Compute mean feature value across episodes
217 | mean_feature_value = np.mean(feature_values)
218 |
219 | # Formatting the values
220 | metrics_to_label_dict = _METRICS_TO_LABEL_DICT[feature]
221 |
222 | eval_metrics[metrics_to_label_dict[0]] = perform_format(
223 | mean_feature_value, metrics_to_label_dict[1]
224 | )
225 | if include_c_e_idx:
226 | if not os.path.exists(_INDEXES_FILENAME):
227 | # Write min, max climate and economic index values to a file
228 | # for use during evaluation.
229 | indices_dict = generate_min_max_climate_economic_indices()
230 | # Write indices to a file
231 | with open(_INDEXES_FILENAME, "w", encoding="utf-8") as file_ptr:
232 | file_ptr.write(json.dumps(indices_dict))
233 | with open(_INDEXES_FILENAME, "r", encoding="utf-8") as file_ptr:
234 | index_dict = json.load(file_ptr)
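235 |             # Normalize the temperature rise and the gross output against the
236 |             # stored min/max climate and economic index bounds.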
235 | eval_metrics["climate_index"] = np.round(
236 | (eval_metrics["Temperature Rise"] - index_dict["min_ci"])
237 | / (index_dict["max_ci"] - index_dict["min_ci"]),
238 | 2,
239 | )
240 | eval_metrics["economic_index"] = np.round(
241 | (eval_metrics["Gross Output"] - index_dict["min_ei"])
242 | / (index_dict["max_ei"] - index_dict["min_ei"]),
243 | 2,
244 | )
245 | success = True
246 | comment = "Successful submission"
247 | except Exception as err:
248 | logging.error(err)
249 | success = False
250 | comment = "Could not obtain an episode rollout!"
251 | eval_metrics = {}
252 |
253 | return success, comment, eval_metrics
254 |
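# A minimal sketch, with made-up values, of the activity-timestep masking used
# in compute_metrics() above: it keeps exactly one gross-output entry per
# activity step when negotiation sub-steps are interleaved. The helper name
# `_activity_mask_sketch` is hypothetical and is never called; `np` is the
# numpy import at the top of this script.
def _activity_mask_sketch():
    """Illustrative only: mirrors the masking logic in compute_metrics()."""
    # Hypothetical log: each activity step is followed by one negotiation sub-step.
    activity_timestep = np.array([1.0, 1.0, 2.0, 2.0, 3.0, 3.0])
    gross_output = np.array([10.0, 10.0, 12.0, 12.0, 15.0, 15.0])
    # 1.0 wherever the activity timestep advanced, 0.0 elsewhere.
    mask = np.append(1.0, np.diff(activity_timestep))
    activity_index = [np.isclose(v, 1.0) for v in mask]
    # Each activity step is counted once: 10 + 12 + 15 = 37 (not 74).
    return np.sum(gross_output[activity_index])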
255 |
256 | def val_metrics(logged_ts, framework, num_episodes=1, include_c_e_idx=True):
257 | """
258 |     Compute metrics from already-logged episode states instead of generating new rollouts.
259 | """
260 | available_frameworks = ["rllib", "warpdrive"]
261 | assert (
262 | framework in available_frameworks
263 |     ), f"Invalid framework {framework}, should be in {available_frameworks}."
264 |
265 | # Fetch all the desired outputs to compute various metrics.
266 | desired_outputs = list(_METRICS_TO_LABEL_DICT.keys())
267 | episode_states = {}
268 | eval_metrics = {}
269 | try:
270 | for episode_id in range(num_episodes):
271 | episode_states[episode_id] = logged_ts
272 |
273 | for feature in desired_outputs:
274 | feature_values = [None for _ in range(num_episodes)]
275 |
276 | if feature == "global_temperature":
277 | # Get the temp rise for upper strata
278 | for episode_id in range(num_episodes):
279 | feature_values[episode_id] = (
280 | episode_states[episode_id][feature][-1, 0]
281 | - episode_states[episode_id][feature][0, 0]
282 | )
283 |
284 | elif feature == "global_carbon_mass":
285 | for episode_id in range(num_episodes):
286 | feature_values[episode_id] = episode_states[episode_id][feature][
287 | -1, 0
288 | ]
289 |
290 | elif feature == "gross_output_all_regions":
291 | for episode_id in range(num_episodes):
292 | # collect gross output results based on activity timestep
293 | activity_timestep = episode_states[episode_id]["activity_timestep"]
294 | activity_index = np.append(
295 | 1.0, np.diff(activity_timestep.squeeze())
296 | )
297 | activity_index = [np.isclose(v, 1.0) for v in activity_index]
298 | feature_values[episode_id] = np.sum(
299 | episode_states[episode_id]["gross_output_all_regions"][
300 | activity_index
301 | ]
302 | )
303 | else:
304 | for episode_id in range(num_episodes):
305 | feature_values[episode_id] = np.sum(
306 | episode_states[episode_id][feature]
307 | )
308 |
309 | # Compute mean feature value across episodes
310 | mean_feature_value = np.mean(feature_values)
311 |
312 | # Formatting the values
313 | metrics_to_label_dict = _METRICS_TO_LABEL_DICT[feature]
314 |
315 | eval_metrics[metrics_to_label_dict[0]] = perform_format(
316 | mean_feature_value, metrics_to_label_dict[1]
317 | )
318 | if include_c_e_idx:
319 | if not os.path.exists(_INDEXES_FILENAME):
320 | # Write min, max climate and economic index values to a file
321 | # for use during evaluation.
322 | indices_dict = generate_min_max_climate_economic_indices()
323 | # Write indices to a file
324 | with open(_INDEXES_FILENAME, "w", encoding="utf-8") as file_ptr:
325 | file_ptr.write(json.dumps(indices_dict))
326 | with open(_INDEXES_FILENAME, "r", encoding="utf-8") as file_ptr:
327 | index_dict = json.load(file_ptr)
328 | eval_metrics["climate_index"] = np.round(
329 | (eval_metrics["Temperature Rise"] - index_dict["min_ci"])
330 | / (index_dict["max_ci"] - index_dict["min_ci"]),
331 | 2,
332 | )
333 | eval_metrics["economic_index"] = np.round(
334 | (eval_metrics["Gross Output"] - index_dict["min_ei"])
335 | / (index_dict["max_ei"] - index_dict["min_ei"]),
336 | 2,
337 | )
338 | success = True
339 | comment = "Successful submission"
340 | except Exception as err:
341 | logging.error(err)
342 | success = False
343 | comment = "Could not obtain an episode rollout!"
344 | eval_metrics = {}
345 |
346 | return success, comment, eval_metrics
347 |
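# A minimal usage sketch for val_metrics() (variable names are hypothetical):
# `logged_ts` is expected to be a dict mapping feature names, including
# "activity_timestep", to the per-timestep arrays logged from a rollout, e.g.
#
#     success, comment, eval_metrics = val_metrics(
#         logged_ts=my_logged_timeseries, framework="rllib", num_episodes=1
#     )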
348 |
349 | def perform_format(val, num_decimal_places):
350 | """
351 | Format value to the number of desired decimal points.
352 | """
353 | if np.isnan(val):
354 | return val
355 | assert num_decimal_places >= 0
356 | rounded_val = np.round(val, num_decimal_places)
357 | if num_decimal_places == 0:
358 | return int(rounded_val)
359 | return rounded_val
360 |
361 |
362 | def perform_evaluation(
363 | results_directory,
364 | framework,
365 | num_episodes=1,
366 | eval_seed=None,
367 | ):
368 | """
369 | Create the trainer and compute metrics.
370 | """
371 | assert results_directory is not None
372 | assert num_episodes > 0
373 |
374 | (
375 | create_trainer,
376 | load_model_checkpoints,
377 | fetch_episode_states,
378 | ) = get_imports(framework=framework)
379 |
380 | # Load a run configuration
381 | config_file = os.path.join(results_directory, f"rice_{framework}.yaml")
382 |
383 | try:
384 | assert os.path.exists(config_file)
385 | except Exception as err:
386 | logging.error(f"The run configuration is missing in {results_directory}.")
387 | raise err
388 |
389 | # Copy the PUBLIC region yamls and rice_build.cu to the results directory.
390 | if not os.path.exists(os.path.join(results_directory, "region_yamls")):
391 | shutil.copytree(
392 | os.path.join(PUBLIC_REPO_DIR, "region_yamls"),
393 | os.path.join(results_directory, "region_yamls"),
394 | )
395 | if not os.path.exists(os.path.join(results_directory, "rice_build.cu")):
396 | shutil.copyfile(
397 | os.path.join(PUBLIC_REPO_DIR, "rice_build.cu"),
398 | os.path.join(results_directory, "rice_build.cu"),
399 | )
400 |
401 | # Create Trainer object
402 | try:
403 | with open(config_file, "r", encoding="utf-8") as file_ptr:
404 | run_config = yaml.safe_load(file_ptr)
405 |
406 | trainer, _ = create_trainer(
407 | run_config, source_dir=results_directory, seed=eval_seed
408 | )
409 |
410 | except Exception as err:
411 |         logging.error("Could not create Trainer with the run_config provided.")
412 | raise err
413 |
414 | # Load model checkpoints
415 | try:
416 | load_model_checkpoints(trainer, results_directory)
417 | except Exception as err:
418 |         logging.error("Could not load model checkpoints.")
419 | raise err
420 |
421 | # Compute metrics
422 | try:
423 | success, comment, eval_metrics = compute_metrics(
424 | fetch_episode_states,
425 | trainer,
426 | framework,
427 | num_episodes=num_episodes,
428 | )
429 |
430 | if framework == "warpdrive":
431 | trainer.graceful_close()
432 |
433 | return success, eval_metrics, comment
434 |
435 | except Exception as err:
436 |         logging.error("Could not fetch episode rollouts and compute metrics.")
437 | raise err
438 |
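# A minimal usage sketch (the path below is a placeholder); this mirrors what
# the __main__ block at the bottom of this script does after validating the
# submission directory:
#
#     success, eval_metrics, comment = perform_evaluation(
#         "/path/to/unzipped_submission", framework="rllib", eval_seed=_EVAL_SEED
#     )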
439 |
440 | def get_temp_rise_and_gross_output(env, actions):
441 | env.reset()
442 | for _ in range(env.episode_length):
443 | env.step(actions)
444 | temperature_array = env.global_state["global_temperature"]["value"]
445 | temperature_rise = temperature_array[-1, 0] - temperature_array[0, 0]
446 |
447 | total_gross_production = np.sum(
448 | env.global_state["gross_output_all_regions"]["value"]
449 | )
450 | return temperature_rise, total_gross_production
451 |
452 |
453 | def generate_min_max_climate_economic_indices():
454 | """
455 | Generate min and max climate and economic indices for the leaderboard.
456 | 0% savings, 100% mitigation => best climate index, worst economic index
457 | 100% savings, 0% mitigation => worst climate index, best economic index
458 | """
459 | env = fetch_base_env() # base rice env
460 | assert isinstance(
461 | env.action_space[0], MultiDiscrete
462 | ), "Unknown action space for env."
463 | all_zero_actions = {
464 | agent_id: np.zeros(
465 | len(env.action_space[agent_id].nvec),
466 | dtype=np.int32,
467 | )
468 | for agent_id in range(env.num_agents)
469 | }
470 |
471 | # 0% savings, 100% mitigation
472 | low_savings_high_mitigation_actions = {}
473 | savings_action_idx = 0
474 | mitigation_action_idx = 1
475 | for agent_id in range(env.num_agents):
476 | low_savings_high_mitigation_actions[agent_id] = all_zero_actions[
477 | agent_id
478 | ].copy()
479 | low_savings_high_mitigation_actions[agent_id][
480 | mitigation_action_idx
481 | ] = env.num_discrete_action_levels
482 | # Best climate index, worst economic index
483 | best_ci, worst_ei = get_temp_rise_and_gross_output(
484 | env, low_savings_high_mitigation_actions
485 | )
486 |
487 | high_savings_low_mitigation_actions = {}
488 | for agent_id in range(env.num_agents):
489 | high_savings_low_mitigation_actions[agent_id] = all_zero_actions[
490 | agent_id
491 | ].copy()
492 | high_savings_low_mitigation_actions[agent_id][
493 | savings_action_idx
494 | ] = env.num_discrete_action_levels
495 | worst_ci, best_ei = get_temp_rise_and_gross_output(
496 | env, high_savings_low_mitigation_actions
497 | )
498 |
499 | index_dict = {
500 | "min_ci": float(worst_ci),
501 | "max_ci": float(best_ci),
502 | "min_ei": float(worst_ei),
503 | "max_ei": float(best_ei),
504 | }
505 | return index_dict
506 |
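# A minimal sketch with made-up numbers of how the indices returned above are
# used: they are written to _INDEXES_FILENAME and read back in compute_metrics()
# and val_metrics() to min-max normalise the leaderboard metrics, e.g.
#
#     index_dict = {"min_ci": 6.0, "max_ci": 1.0, "min_ei": 500.0, "max_ei": 9000.0}
#     climate_index = round((temperature_rise - 6.0) / (1.0 - 6.0), 2)
#     economic_index = round((gross_output - 500.0) / (9000.0 - 500.0), 2)
#
# so a smaller temperature rise and a larger gross output both map to index
# values closer to 1.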
507 |
508 | if __name__ == "__main__":
509 | logging.info("This script performs evaluation of your code.")
510 |
511 | # CLI arguments
512 | parser = argparse.ArgumentParser()
513 | parser.add_argument(
514 | "--results_dir",
515 | "-r",
516 | type=str,
517 | default="./Submissions/1680502535.zip", # an example of a submission file
518 | help="The directory where all the submission files are saved. Can also be "
519 | "a zip-file containing all the submission files.",
520 | )
521 | args = parser.parse_args()
522 |
523 | # Check the submission zip file or directory.
524 | if "results_dir" not in args:
525 | raise ValueError(
526 | "Please provide a results directory to evaluate with the argument -r"
527 | )
528 |
529 | if not os.path.exists(args.results_dir):
530 | raise ValueError(
531 | "The results directory is missing. Please make sure the correct path "
532 | "is specified!"
533 | )
534 |
535 | results_dir = (
536 | try_to_unzip_file(args.results_dir)
537 | if args.results_dir.endswith(".zip")
538 | else args.results_dir
539 | )
540 |
541 | logging.info(f"Using submission files in {results_dir}")
542 |
543 | # Validate the submission directory
544 | framework, results_dir_is_valid, comment = validate_dir(results_dir)
545 | if not results_dir_is_valid:
546 | raise AssertionError(f"{results_dir} is not a valid submission directory.")
547 |
548 | # Run unit tests on the simulation files
549 |     skip_unit_tests = True  # unit tests are skipped by default; set to False to run them
550 |
551 | try:
552 | if skip_unit_tests:
553 | logging.info("Skipping check_output test")
554 | else:
555 | logging.info("Running unit tests...")
556 | subprocess.check_output(
557 | [
558 | "python",
559 | "run_unittests.py",
560 | "--results_dir",
561 | results_dir,
562 | ],
563 | )
564 | logging.info("run_unittests.py is done")
565 | except subprocess.CalledProcessError as err:
566 | logging.error(f"{results_dir}: unit tests were not successful.")
567 | raise err
568 |
569 | # Run evaluation with submitted simulation and trained agents.
570 | logging.info("Starting eval...")
571 | succeeded, metrics, comments = perform_evaluation(
572 | results_dir, framework, eval_seed=_EVAL_SEED
573 | )
574 |
575 | # Report results.
576 | eval_result_str = "\n".join(
577 | [
578 | f"Framework used: {framework}",
579 | f"Succeeded: {succeeded}",
580 | f"Metrics: {metrics}",
581 | f"Comments: {comments}",
582 | ]
583 | )
584 | logging.info(eval_result_str)
585 | print(eval_result_str)
586 |
--------------------------------------------------------------------------------
/MARL_training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Copyright (c) 2022, salesforce.com, inc and MILA. \n",
8 | "All rights reserved. \n",
9 | "SPDX-License-Identifier: BSD-3-Clause \n",
10 | "For full license text, see the LICENSE file in the repo root \n",
11 | "or https://opensource.org/licenses/BSD-3-Clause "
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 |     "Get started quickly with end-to-end multi-agent RL using WarpDrive! This notebook shows a basic example of creating the RICE-N environment and training agents in it."
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "**Try this notebook on [Colab](http://colab.research.google.com/github/salesforce/warp-drive/blob/master/tutorials/simple-end-to-end-example.ipynb)!**"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "## ⚠️ PLEASE NOTE:\n",
33 | "This notebook runs on a GPU runtime.\\\n",
34 | "If running on Colab, choose Runtime > Change runtime type from the menu, then select `GPU` in the 'hardware accelerator' dropdown menu."
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "### Dependencies"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 |     "First, install the WarpDrive package (uncomment the pip install cell below if it is not already installed)."
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "%load_ext autoreload\n",
58 | "%autoreload 2\n"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {
65 | "scrolled": true
66 | },
67 | "outputs": [],
68 | "source": [
69 | "# !pip install rl-warp-drive"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "import os\n",
79 | "import torch\n",
80 | "\n",
81 | "from rice_cuda import RiceCuda\n",
82 | "from warp_drive.env_wrapper import EnvWrapper\n",
83 | "from warp_drive.training.trainer import Trainer\n",
84 | "from warp_drive.utils.env_registrar import EnvironmentRegistrar\n",
85 | "\n",
86 |     "pytorch_cuda_init_success = torch.cuda.FloatTensor(8)  # force CUDA initialization; fails early if no GPU is available"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {},
92 | "source": [
93 | "# Environment, Training, and Model Hyperparameters"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "run_config = dict(\n",
103 | " \n",
104 | " # Environment settings\n",
105 | " env = dict( \n",
106 | " negotiation_on=1,\n",
107 | " num_discretization_cells = 10,\n",
108 | " ),\n",
109 | "\n",
110 | " # Trainer settings\n",
111 | " trainer = dict(\n",
112 |     "        num_envs = 100, # Number of environment replicas (number of GPU blocks used)\n",
113 | " train_batch_size = 10000, # total batch size used for training per iteration (across all the environments)\n",
114 | " num_episodes = 100000, # Total number of episodes to run the training for (can be arbitrarily high!)\n",
115 | " ),\n",
116 | " \n",
117 | " # Policy network settings\n",
118 | " policy = dict(\n",
119 | " regions = dict(\n",
120 | " to_train = True,\n",
121 | " gamma = 0.92, # discount factor\n",
122 | " lr = 0.0005, # learning rate\n",
123 |     "        entropy_coeff = [[0,0.5], [1000000, 0.1], [5000000, 0.05]], # entropy coefficient schedule: [timestep, value] pairs\n",
124 |     "        vf_loss_coeff = [[0,0.0001], [1000000, 0.001], [5000000, 0.01], [10000000, 0.1]], # value-function loss coefficient schedule: [timestep, value] pairs\n",
125 | " model = dict( \n",
126 | " type = \"fully_connected\",\n",
127 | " fc_dims = [256,256], # dimension(s) of the fully connected layers as a list\n",
128 | " model_ckpt_filepath = \"\" # load model parameters from a saved checkpoint (if specified)\n",
129 | " )\n",
130 | " ),\n",
131 | " ),\n",
132 | " \n",
133 | " # Checkpoint saving setting\n",
134 | " saving = dict(\n",
135 | " metrics_log_freq = 10, # How often (in iterations) to print the metrics\n",
136 | " model_params_save_freq = 5000, # How often (in iterations) to save the model parameters\n",
137 | " basedir = \"/tmp\", # base folder used for saving\n",
138 | " name = \"rice\",\n",
139 | " tag = \"example\",\n",
140 | " )\n",
141 | ")"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {},
147 | "source": [
148 | "# End-to-End Training Loop"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {
155 | "scrolled": true
156 | },
157 | "outputs": [],
158 | "source": [
159 | "# Register the environment\n",
160 | "env_registrar = EnvironmentRegistrar()\n",
161 | "this_file_dir = os.path.dirname(os.path.abspath(\"__file__\"))\n",
162 | "env_registrar.add_cuda_env_src_path(\n",
163 | " RiceCuda.name,\n",
164 | " os.path.join(this_file_dir, \"rice_build.cu\")\n",
165 | ")\n",
166 | "\n",
167 | "# cpu_env = EnvWrapper(Rice())\n",
168 | "\n",
169 | "# add_cpu_env = env_registrar.add(device=\"cpu\")\n",
170 | "# add_cpu_env(cpu_env)\n",
171 | "# add_gpu_env = env_registrar.add(device=\"gpu\")\n",
172 | "# add_gpu_env(cpu_env)\n",
173 | "\n",
174 | "# Create a wrapped environment object via the EnvWrapper\n",
175 | "# Ensure that use_cuda is set to True (in order to run on the GPU)\n",
176 | "env_wrapper = EnvWrapper(\n",
177 | " RiceCuda(**run_config[\"env\"]),\n",
178 | " num_envs=run_config[\"trainer\"][\"num_envs\"], \n",
179 | " use_cuda=True,\n",
180 | " env_registrar=env_registrar,\n",
181 | ")\n",
182 | "\n",
183 | "# Agents can share policy models: this dictionary maps policy model names to agent ids.\n",
184 | "policy_tag_to_agent_id_map = {\n",
185 | " \"regions\": [agent_id for agent_id in range(env_wrapper.env.num_agents)],\n",
186 | "}\n",
187 | "\n",
188 | "# Create the trainer object\n",
189 | "trainer = Trainer(\n",
190 | " env_wrapper=env_wrapper,\n",
191 | " config=run_config,\n",
192 | " policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,\n",
193 | ")\n",
194 | "\n",
195 | "# Perform training!\n",
196 | "trainer.train()\n",
197 | "\n",
198 | "# Shut off gracefully\n",
199 | "# trainer.graceful_close()"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "metadata": {},
205 | "source": [
206 | "### Fetch episode states"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": null,
212 | "metadata": {},
213 | "outputs": [],
214 | "source": [
215 | "# Please note that any variable registered in rice_cuda.py can be put here\n",
216 | "desired_outputs = [\n",
217 | " \"T_i\", # Temperature\n",
218 | " \"M_i\", # Carbon mass\n",
219 | " \"sampled_actions\",\n",
220 | " \"minMu\"\n",
221 | " ]\n",
222 | "\n",
223 | "episode_states = trainer.fetch_episode_states(\n",
224 | " desired_outputs\n",
225 | ")\n",
226 | "\n",
227 | "trainer.graceful_close()"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": null,
233 | "metadata": {},
234 | "outputs": [],
235 | "source": [
236 | "import matplotlib.pyplot as plt"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "def get_episode_T_AT(episode_states, negotiation_on, plot = 0):\n",
246 | " state = 'T_i'\n",
247 | " if negotiation_on:\n",
248 | " values = episode_states[state][::3,0,0]\n",
249 | " else:\n",
250 | " values = episode_states[state][:,0,0]\n",
251 | "\n",
252 | " if plot:\n",
253 | " fig = plt.figure() \n",
254 | " plt.plot(values[:], label='Temperature - Atmosphere')\n",
255 | " fig.legend()\n",
256 | " # plt.yscale('log')\n",
257 | " fig.show()\n",
258 | "\n",
259 | " return values\n",
260 | "\n",
261 | "\n",
262 | "def get_episode_T_LO(episode_states, negotiation_on, plot = 0):\n",
263 | " state = 'T_i'\n",
264 | " if negotiation_on:\n",
265 | " values = episode_states[state][::3,0,1]\n",
266 | " else:\n",
267 | " values = episode_states[state][:,0,1]\n",
268 | "\n",
269 | " if plot:\n",
270 | " fig = plt.figure()\n",
271 | " plt.plot(values[:], label='Temperature - Lower Oceans')\n",
272 | " fig.legend()\n",
273 | " # plt.yscale('log')\n",
274 | " fig.show()\n",
275 | "\n",
276 | " return values\n",
277 | "\n",
278 | "def get_episode_M_AT(episode_states, negotiation_on, plot = 0):\n",
279 | " state = 'M_i'\n",
280 | " if negotiation_on:\n",
281 | " values = episode_states[state][::3,0,0]\n",
282 | " else:\n",
283 | " values = episode_states[state][:,0,0]\n",
284 | "\n",
285 | " if plot:\n",
286 | " fig = plt.figure()\n",
287 | " plt.plot(values[:], label='Carbon - Atmosphere')\n",
288 | " fig.legend()\n",
289 | " # plt.yscale('log')\n",
290 | " fig.show()\n",
291 | "\n",
292 | " return values\n",
293 | "\n",
294 | "def get_episode_M_UP(episode_states, negotiation_on, plot = 0):\n",
295 | " state = 'M_i'\n",
296 | " if negotiation_on:\n",
297 | " values = episode_states[state][::3,0,1]\n",
298 | " else:\n",
299 |     "        values = episode_states[state][:,0,1]\n",
300 | "\n",
301 | " if plot:\n",
302 | " fig = plt.figure()\n",
303 | " plt.plot(values[:], label='Carbon - Upper Strata')\n",
304 | " fig.legend()\n",
305 | " # plt.yscale('log')\n",
306 | " fig.show()\n",
307 | "\n",
308 | " return values\n",
309 | "\n",
310 |     "def get_episode_M_LO(episode_states, negotiation_on, plot = 0):\n",
311 | " state = 'M_i'\n",
312 | " if negotiation_on:\n",
313 | " values = episode_states[state][::3,0,2]\n",
314 | " else:\n",
315 | " values = episode_states[state][:,0,2]\n",
316 | "\n",
317 | " if plot:\n",
318 | " fig = plt.figure()\n",
319 | " plt.plot(values[:], label='Carbon - Lower Oceans')\n",
320 | " fig.legend()\n",
321 | " # plt.yscale('log')\n",
322 | " fig.show()\n",
323 | "\n",
324 | " return values\n",
325 | "\n",
326 | "def get_episode_minMu(episode_states, negotiation_on, plot = 0):\n",
327 | " state = 'minMu'\n",
328 | " if negotiation_on:\n",
329 | " values = episode_states[state][::3,:]\n",
330 | " else:\n",
331 | " values = episode_states[state][:,:]\n",
332 | "\n",
333 | " if plot:\n",
334 | " for agent in range(len(values[0])):\n",
335 | " fig = plt.figure()\n",
336 | " plt.plot(values[:,agent], label='minMu - Agent:' + str(agent))\n",
337 | " fig.legend()\n",
338 | " # plt.yscale('log')\n",
339 | " fig.show()\n",
340 | "\n",
341 | " return values"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "metadata": {},
348 | "outputs": [],
349 | "source": [
350 | "def get_episode_MuAction(episode_states, negotiation_on, plot = 0):\n",
351 |     "    state = 'sampled_actions'\n",
352 | " if negotiation_on:\n",
353 | " values = episode_states[state][::3,:, -2]\n",
354 | " else:\n",
355 | " values = episode_states[state][:,:, -2]\n",
356 | "\n",
357 | " if plot:\n",
358 | " for agent in range(len(values[0])):\n",
359 | " fig = plt.figure()\n",
360 | " plt.plot(values[:,agent], label='Mu Action - Agent:' + str(agent))\n",
361 | " fig.legend()\n",
362 | " # plt.yscale('log')\n",
363 | " fig.show()\n",
364 | "\n",
365 | " return values\n",
366 | "\n",
367 | "def get_episode_SavingAction(episode_states, negotiation_on, plot = 0):\n",
368 |     "    state = 'sampled_actions'\n",
369 | " if negotiation_on:\n",
370 | " values = episode_states[state][::3,:, -1]\n",
371 | " else:\n",
372 | " values = episode_states[state][:,:, -1]\n",
373 | "\n",
374 | " if plot:\n",
375 | " for agent in range(len(values[0])):\n",
376 | " fig = plt.figure()\n",
377 |     "            plt.plot(values[:,agent], label='Saving Action - Agent:' + str(agent))\n",
378 | " fig.legend()\n",
379 | " # plt.yscale('log')\n",
380 | " fig.show()\n",
381 | "\n",
382 | " return values"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {},
389 | "outputs": [],
390 | "source": []
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": null,
395 | "metadata": {},
396 | "outputs": [],
397 | "source": [
398 | "episode_states['sampled_actions']"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "metadata": {},
405 | "outputs": [],
406 | "source": [
407 | "get_episode_T_AT(episode_states, 1, 1)"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": null,
413 | "metadata": {},
414 | "outputs": [],
415 | "source": [
416 | "get_episode_T_LO(episode_states, 1, 1)"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": null,
422 | "metadata": {},
423 | "outputs": [],
424 | "source": [
425 | "get_episode_minMu(episode_states, 1, 1)"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": null,
431 | "metadata": {},
432 | "outputs": [],
433 | "source": [
434 | "get_episode_M_AT(episode_states, 1, 1)"
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": null,
440 | "metadata": {},
441 | "outputs": [],
442 | "source": []
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": null,
447 | "metadata": {},
448 | "outputs": [],
449 | "source": [
450 | "T_AT_reg_0_neg_on = episode_states['T_i'][:,0,0]\n",
451 | "#T_AT_reg_1_neg_off = episode_states['T_i'][:,1,0]\n",
452 | "\n",
453 | "T_LO_reg_0_neg_on = episode_states['T_i'][:,0,1]\n",
454 | "#T_LO_reg_1_neg_off = episode_states['T_i'][:,1,1]\n",
455 | "\n",
456 | "\n",
457 | "M_AT_reg_0_neg_on = episode_states['M_i'][:,0,0]\n",
458 | "#M_AT_reg_1_neg_off = episode_states['M_i'][:,1,0]\n",
459 | "\n",
460 | "M_UP_reg_0_neg_on = episode_states['M_i'][:,0,1]\n",
461 | "#M_UP_reg_1_neg_off = episode_states['M_i'][:,1,1]\n",
462 | "\n",
463 | "M_LO_reg_0_neg_on = episode_states['M_i'][:,0,2]\n",
464 | "#M_LO_reg_1_neg_off = episode_states['M_i'][:,1,2]\n",
465 | "# episode_states_neg_off = episode_states.copy()\n",
466 | "\n"
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "execution_count": null,
472 | "metadata": {},
473 | "outputs": [],
474 | "source": [
475 | "episode_states['minMu']"
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": null,
481 | "metadata": {},
482 | "outputs": [],
483 | "source": [
484 | "import matplotlib.pyplot as plt\n",
485 | "\n",
486 | "fig = plt.figure()\n",
487 | "plt.plot(T_AT_reg_0_neg_on, label='Temperature - Upper Strata - negotiation on')\n",
488 | "fig.legend()\n",
489 | "# plt.yscale('log')\n",
490 | "fig.show()\n",
491 | "\n",
492 | "fig = plt.figure()\n",
493 | "plt.plot(T_LO_reg_0_neg_on, label=\"Temperature - Lower Oceans - negotiation on\")\n",
494 | "fig.legend()\n",
495 | "# plt.yscale('log')\n",
496 | "fig.show()"
497 | ]
498 | },
520 | {
521 | "cell_type": "code",
522 | "execution_count": null,
523 | "metadata": {},
524 | "outputs": [],
525 | "source": [
526 | "run_config = dict(\n",
527 | " \n",
528 | " # Environment settings\n",
529 | " env = dict( \n",
530 | " negotiation_on=0,\n",
531 | " ),\n",
532 | "\n",
533 | " # Trainer settings\n",
534 | " trainer = dict(\n",
535 |     "        num_envs = 100, # Number of environment replicas (number of GPU blocks used)\n",
536 | " train_batch_size = 10000, # total batch size used for training per iteration (across all the environments)\n",
537 | " num_episodes = 30000, # Total number of episodes to run the training for (can be arbitrarily high!)\n",
538 | " ),\n",
539 | " \n",
540 | " # Policy network settings\n",
541 | " policy = dict(\n",
542 | " regions = dict(\n",
543 | " to_train = True,\n",
544 | " gamma = 0.98, # discount factor\n",
545 | " lr = 0.005, # learning rate\n",
546 | " model = dict( \n",
547 | " type = \"fully_connected\",\n",
548 | " fc_dims = [256, 256], # dimension(s) of the fully connected layers as a list\n",
549 | " model_ckpt_filepath = \"\" # load model parameters from a saved checkpoint (if specified)\n",
550 | " )\n",
551 | " ),\n",
552 | " ),\n",
553 | " \n",
554 | " # Checkpoint saving setting\n",
555 | " saving = dict(\n",
556 | " metrics_log_freq = 10, # How often (in iterations) to print the metrics\n",
557 | " model_params_save_freq = 5000, # How often (in iterations) to save the model parameters\n",
558 | " basedir = \"/tmp\", # base folder used for saving\n",
559 | " name = \"rice\",\n",
560 | " tag = \"example\",\n",
561 | " )\n",
562 | ")"
563 | ]
564 | },
565 | {
566 | "cell_type": "code",
567 | "execution_count": null,
568 | "metadata": {},
569 | "outputs": [],
570 | "source": [
571 | "# Register the environment\n",
572 | "env_registrar = EnvironmentRegistrar()\n",
573 | "this_file_dir = os.path.dirname(os.path.abspath(\"__file__\"))\n",
574 | "env_registrar.add_cuda_env_src_path(\n",
575 | " RiceCuda.name,\n",
576 | " os.path.join(this_file_dir, \"rice_build.cu\")\n",
577 | ")\n",
578 | "\n",
579 | "# cpu_env = EnvWrapper(Rice())\n",
580 | "\n",
581 | "# add_cpu_env = env_registrar.add(device=\"cpu\")\n",
582 | "# add_cpu_env(cpu_env)\n",
583 | "# add_gpu_env = env_registrar.add(device=\"gpu\")\n",
584 | "# add_gpu_env(cpu_env)\n",
585 | "\n",
586 | "# Create a wrapped environment object via the EnvWrapper\n",
587 | "# Ensure that use_cuda is set to True (in order to run on the GPU)\n",
588 | "env_wrapper = EnvWrapper(\n",
589 | " RiceCuda(**run_config[\"env\"]),\n",
590 | " num_envs=run_config[\"trainer\"][\"num_envs\"], \n",
591 | " use_cuda=True,\n",
592 | " env_registrar=env_registrar,\n",
593 | ")\n",
594 | "\n",
595 | "# Agents can share policy models: this dictionary maps policy model names to agent ids.\n",
596 | "policy_tag_to_agent_id_map = {\n",
597 | " \"regions\": [agent_id for agent_id in range(env_wrapper.env.num_agents)],\n",
598 | "}\n",
599 | "\n",
600 | "# Create the trainer object\n",
601 | "trainer = Trainer(\n",
602 | " env_wrapper=env_wrapper,\n",
603 | " config=run_config,\n",
604 | " policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,\n",
605 | ")\n",
606 | "\n",
607 | "# Perform training!\n",
608 | "trainer.train()\n"
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": null,
614 | "metadata": {},
615 | "outputs": [],
616 | "source": [
617 | "# Please note that any variable registered in rice_cuda.py can be put here\n",
618 | "desired_outputs = [\n",
619 | " \"T_i\", # Temperature\n",
620 | " \"M_i\", # Carbon mass\n",
621 | " \"sampled_actions\",\n",
622 | " \"minMu\"\n",
623 | " ]\n",
624 | "\n",
625 | "episode_states_neg_off = trainer.fetch_episode_states(\n",
626 | " desired_outputs\n",
627 | ")\n",
628 | "\n",
629 | "trainer.graceful_close()"
630 | ]
631 | },
632 | {
633 | "cell_type": "code",
634 | "execution_count": null,
635 | "metadata": {},
636 | "outputs": [],
637 | "source": [
638 | "get_episode_T_AT(episode_states_neg_off, 0, 1)"
639 | ]
640 | },
641 | {
642 | "cell_type": "code",
643 | "execution_count": null,
644 | "metadata": {},
645 | "outputs": [],
646 | "source": [
647 | "get_episode_M_AT(episode_states_neg_off, 0, 1)"
648 | ]
649 | },
650 | {
651 | "cell_type": "code",
652 | "execution_count": null,
653 | "metadata": {},
654 | "outputs": [],
655 | "source": [
656 | "T_AT_reg_0_neg_off = episode_states_neg_off['T_i'][:,0,0]\n",
657 | "#T_AT_reg_1_neg_off = episode_states['T_i'][:,1,0]\n",
658 | "\n",
659 | "T_LO_reg_0_neg_off = episode_states_neg_off['T_i'][:,0,1]\n",
660 | "#T_LO_reg_1_neg_off = episode_states['T_i'][:,1,1]\n",
661 | "\n",
662 | "\n",
663 | "M_AT_reg_0_neg_off = episode_states_neg_off['M_i'][:,0,0]\n",
664 | "#M_AT_reg_1_neg_off = episode_states['M_i'][:,1,0]\n",
665 | "\n",
666 | "M_UP_reg_0_neg_off = episode_states_neg_off['M_i'][:,0,1]\n",
667 | "#M_UP_reg_1_neg_off = episode_states['M_i'][:,1,1]\n",
668 | "\n",
669 | "M_LO_reg_0_neg_off = episode_states_neg_off['M_i'][:,0,2]\n",
670 | "#M_LO_reg_1_neg_off = episode_states['M_i'][:,1,2]\n",
671 | "# episode_states_neg_off = episode_states.copy()"
672 | ]
673 | },
674 | {
675 | "cell_type": "code",
676 | "execution_count": null,
677 | "metadata": {},
678 | "outputs": [],
679 | "source": [
680 | "fig = plt.figure()\n",
681 | "plt.plot(T_AT_reg_0_neg_off, label='Temperature - Upper Strata - negotiation off')\n",
682 | "fig.legend()\n",
683 | "# plt.yscale('log')\n",
684 | "fig.show()\n",
685 | "\n",
686 | "fig = plt.figure()\n",
687 | "plt.plot(T_LO_reg_0_neg_off, label=\"Temperature - Lower Oceans - negotiation off\")\n",
688 | "fig.legend()\n",
689 | "# plt.yscale('log')\n",
690 | "fig.show()"
691 | ]
692 | },
693 | {
694 | "cell_type": "code",
695 | "execution_count": null,
696 | "metadata": {},
697 | "outputs": [],
698 | "source": []
699 | }
700 | ],
701 | "metadata": {
702 | "kernelspec": {
703 | "display_name": "Python 3 (ipykernel)",
704 | "language": "python",
705 | "name": "python3"
706 | },
707 | "language_info": {
708 | "codemirror_mode": {
709 | "name": "ipython",
710 | "version": 3
711 | },
712 | "file_extension": ".py",
713 | "mimetype": "text/x-python",
714 | "name": "python",
715 | "nbconvert_exporter": "python",
716 | "pygments_lexer": "ipython3",
717 | "version": "3.8.10"
718 | }
719 | },
720 | "nbformat": 4,
721 | "nbformat_minor": 4
722 | }
723 |
--------------------------------------------------------------------------------