├── tests
│   ├── __init__.py
│   └── test_data.py
├── rosmo
│   ├── __about__.py
│   ├── __init__.py
│   ├── agent
│   │   ├── __init__.py
│   │   ├── type.py
│   │   ├── utils.py
│   │   ├── actor.py
│   │   ├── improvement_op.py
│   │   ├── network.py
│   │   └── learning.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── buffer.py
│   │   ├── bsuite.py
│   │   ├── rl_unplugged_atari_baselines.json
│   │   └── rlu_atari.py
│   ├── type.py
│   ├── profiler.py
│   ├── env_loop_observer.py
│   └── loggers.py
├── .github
│   ├── actions
│   │   └── cache
│   │       └── action.yml
│   └── workflows
│       └── check.yml
├── Makefile
├── .gitignore
├── INSTALL.md
├── experiment
│   ├── bsuite
│   │   ├── config.py
│   │   └── main.py
│   └── atari
│       ├── config.py
│       └── main.py
├── pyproject.toml
├── README.md
└── LICENSE
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/rosmo/__about__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """About."""
16 | __version__ = "0.0.2"
17 |
--------------------------------------------------------------------------------
/rosmo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ROSMO: A Regularized One-Step Model-based algorithm for Offline-RL."""
16 |
--------------------------------------------------------------------------------
/.github/actions/cache/action.yml:
--------------------------------------------------------------------------------
1 | name: Cache
2 | description: "cache for pip"
3 | outputs:
4 | cache-hit:
5 | value: ${{ steps.cache.outputs.cache-hit }}
6 | description: "cache hit"
7 |
8 | runs:
9 | using: "composite"
10 | steps:
11 | - name: Get pip cache dir
12 | id: pip-cache-dir
13 | shell: bash
14 | run: |
15 | python -m pip install --upgrade pip
16 | echo "::set-output name=dir::$(pip cache dir)"
17 | - name: Cache pip
18 | id: cache-pip
19 | uses: actions/cache@v2
20 | with:
21 | path: ${{ steps.pip-cache-dir.outputs.dir }}
22 | key: ${{ runner.os }}-pip-cache-${{ hashFiles('pyproject.toml') }}
23 | restore-keys: |
24 | ${{ runner.os }}-pip-cache-
25 |
--------------------------------------------------------------------------------
/rosmo/agent/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ROSMO Agent."""
16 | import os
17 |
18 | import tensorflow as tf
19 |
20 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.85"
21 | tf.config.experimental.set_visible_devices([], "GPU")
22 |
--------------------------------------------------------------------------------
/rosmo/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Offline dataset."""
16 | from rosmo.data.bsuite import env_loader as bsuite_env_loader
17 | from rosmo.data.rlu_atari import env_loader as atari_env_loader
18 |
19 | __all__ = [
20 | "bsuite_env_loader",
21 | "atari_env_loader",
22 | ]
23 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | SHELL=/bin/bash
2 | PROJECT_NAME=rosmo
3 | PROJECT_PATH=rosmo/
4 | LINT_PATHS=${PROJECT_PATH} tests/ experiment/
5 |
6 | check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade
7 | check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade
8 |
9 | test:
10 | $(call check_install, pytest)
11 | pytest -s
12 |
13 | lint:
14 | $(call check_install, isort)
15 | $(call check_install, pylint)
16 | $(call check_install, mypy)
17 | isort --check --diff --project=${PROJECT_NAME} ${LINT_PATHS}
18 | pylint -j 8 --recursive=y ${LINT_PATHS}
19 | mypy ${PROJECT_PATH}
20 |
21 | format:
22 | # format using black
23 | $(call check_install, black)
24 | black ${LINT_PATHS}
25 | # sort imports
26 | $(call check_install, isort)
27 | isort ${LINT_PATHS}
28 |
29 | check-docstyle:
30 | $(call check_install, pydocstyle)
31 | pydocstyle ${PROJECT_PATH} --convention=google
32 |
33 | checks: lint check-docstyle
34 |
35 | .PHONY: format lint check-docstyle checks
36 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Temporary and binary files
2 | *~
3 | *.py[cod]
4 | *.so
5 | *.cfg
6 | !.isort.cfg
7 | !setup.cfg
8 | *.orig
9 | *.log
10 | *.pot
11 | __pycache__/*
12 | .cache/*
13 | .*.swp
14 | */.ipynb_checkpoints/*
15 | .DS_Store
16 |
17 | # Project files
18 | .ropeproject
19 | .project
20 | .pydevproject
21 | .settings
22 | .idea
23 | .vscode
24 | tags
25 |
26 | # Package files
27 | *.egg
28 | *.eggs/
29 | .installed.cfg
30 | *.egg-info
31 |
32 | # Unittest and coverage
33 | htmlcov/*
34 | .coverage
35 | .coverage.*
36 | .tox
37 | junit*.xml
38 | coverage.xml
39 | .pytest_cache/
40 |
41 | # Build and docs folder/files
42 | build/*
43 | dist/*
44 | sdist/*
45 | docs/api/*
46 | docs/_rst/*
47 | docs/_build/*
48 | cover/*
49 | MANIFEST
50 |
51 | # Per-project virtualenvs
52 | .venv*/
53 | .conda*/
54 | .python-version
55 |
56 | # Local dev folder/files
57 | local-dev/
58 | sota/
59 | wandb/
60 | .dockerignore
61 | checkpoint/
62 | logs/
63 | datasets/
64 | internal/gke/.mujoco
65 | internal/gke/bsuite
66 | internal/gke/.boto
67 |
68 | sample_data
69 | wandb
70 |
--------------------------------------------------------------------------------
/rosmo/type.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Type definitions."""
16 | from typing import Callable, NamedTuple, Union
17 |
18 | import jax.numpy as jnp
19 | import numpy as np
20 |
21 | Array = Union[np.ndarray, jnp.ndarray]
22 | Forwardable = Callable
23 |
24 |
25 | class ActorOutput(NamedTuple):
26 | """Actor output parsed from the dataset."""
27 |
28 | observation: Array
29 | reward: Array
30 | is_first: Array
31 | is_last: Array
32 | action: Array
33 |
--------------------------------------------------------------------------------
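
For reference, here is a minimal sketch of packing a single first timestep into `ActorOutput` (mirroring what `RosmoEvalActor.observe_first` does later in this dump); the observation shape is purely illustrative:

```python
import numpy as np

from rosmo.type import ActorOutput

first_timestep = ActorOutput(
    observation=np.zeros((10, 5), dtype=np.float32),  # hypothetical observation
    reward=np.zeros((1,), dtype=np.float32),   # no reward before the first action
    is_first=np.ones((1,), dtype=np.float32),  # episode-start flag
    is_last=np.zeros((1,), dtype=np.float32),  # episode-end flag
    action=np.zeros((1,), dtype=np.int32),     # dummy previous action
)
print(first_timestep.is_first)
```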
/.github/workflows/check.yml:
--------------------------------------------------------------------------------
1 | name: CI checks
2 |
3 | on:
4 | pull_request:
5 | paths:
6 | - '.github/workflows/check.yml'
7 | - 'experiment/**'
8 | - 'rosmo/**'
9 | - 'tests/**'
10 | - 'pyproject.toml'
11 | - 'Makefile'
12 | push:
13 | branches:
14 | - main
15 | paths:
16 | - '.github/workflows/check.yml'
17 | - 'experiment/**'
18 | - 'rosmo/**'
19 | - 'tests/**'
20 | - 'pyproject.toml'
21 | - 'Makefile'
22 |
23 | concurrency:
24 | group: ${{ github.ref }}-${{ github.workflow }}
25 | cancel-in-progress: true
26 |
27 | jobs:
28 | lint:
29 | runs-on: ubuntu-latest
30 | timeout-minutes: 7
31 | steps:
32 | - uses: actions/checkout@v3
33 | with:
34 | fetch-depth: 0
35 | - uses: actions/setup-python@v4
36 | with:
37 | python-version: 3.8.13
38 | - uses: ./.github/actions/cache
39 | - name: Install
40 | run: |
41 | pip install -e .
42 | pip install dopamine-rl==3.1.2
43 | pip install chex==0.1.5
44 | - name: Lint
45 | run: make checks
46 |
--------------------------------------------------------------------------------
/rosmo/agent/type.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Types."""
16 | from typing import NamedTuple
17 |
18 | from acme.jax import networks as networks_lib
19 |
20 | from rosmo.type import Array
21 |
22 |
23 | class Params(NamedTuple):
24 | """Agent parameters."""
25 |
26 | representation: networks_lib.Params
27 | transition: networks_lib.Params
28 | prediction: networks_lib.Params
29 |
30 |
31 | class AgentOutput(NamedTuple):
32 | """Agent prediction output."""
33 |
34 | state: Array
35 | policy_logits: Array
36 | value_logits: Array
37 | value: Array
38 | reward_logits: Array
39 | reward: Array
40 |
--------------------------------------------------------------------------------
/rosmo/profiler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Profiling utils."""
16 | import os
17 |
18 | import jax
19 | from viztracer import VizTracer
20 |
21 |
22 | class Profiler:
23 | """Profiler for python and jax (optional)."""
24 |
25 | def __init__(self, folder: str, name: str, with_jax: bool = False) -> None:
26 | """Init."""
27 | super().__init__()
28 | self._name = name
29 | self._folder = folder
30 | self._with_jax = with_jax
31 | self._vistracer = VizTracer(
32 | output_file=os.path.join(folder, "viztracer", name + ".html"),
33 | max_stack_depth=3,
34 | )
35 | self._jax_folder = os.path.join(folder, "jax_profiler/" + name)
36 |
37 | def start(self) -> None:
38 | """Start to trace."""
39 | if self._with_jax:
40 | jax.profiler.start_trace(self._jax_folder)
41 | self._vistracer.start()
42 |
43 | def stop(self) -> None:
44 | """Stop tracing."""
45 | self._vistracer.stop()
46 | if self._with_jax:
47 | jax.profiler.stop_trace()
48 |
49 | def save(self) -> None:
50 | """Save the results."""
51 | self._vistracer.save()
52 |
53 | def stop_and_save(self) -> None:
54 | """Combine stop and save."""
55 | self.stop()
56 | self.save()
57 |
--------------------------------------------------------------------------------
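
A minimal usage sketch for `Profiler`, assuming the output folder already exists; the folder and name below are placeholders:

```python
from rosmo.profiler import Profiler

# Traces go to <folder>/viztracer/<name>.html and, if with_jax=True,
# to <folder>/jax_profiler/<name>. Create the folders beforehand if needed.
profiler = Profiler(folder="./logs", name="learner_step", with_jax=False)
profiler.start()
for _ in range(10):
    sum(range(100_000))  # stand-in for the code being profiled
profiler.stop_and_save()
```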
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Testing data loading."""
16 | import os
17 |
18 | from absl import logging
19 | from absl.testing import absltest, parameterized
20 |
21 | from rosmo.data import atari_env_loader, bsuite_env_loader
22 |
23 | _DATASET_DIR = "./datasets"
24 |
25 |
26 | class RLUAtari(parameterized.TestCase):
27 | """Test RL Unplugged Atari data loader."""
28 |
29 | @staticmethod
30 | def test_data_loader():
31 | """Test data loader."""
32 | dataset_dir = os.path.join(_DATASET_DIR, "atari")
33 | _, dataloader = atari_env_loader(
34 | env_name="Asterix",
35 | run_number=1,
36 | dataset_dir=dataset_dir,
37 | )
38 | iterator = iter(dataloader)
39 | data = next(iterator)
40 | logging.info(data)
41 |
42 |
43 | class BSuite(parameterized.TestCase):
44 | """Test BSuite data loader."""
45 |
46 | @staticmethod
47 | def test_data_loader():
48 | """Test data loader."""
49 | dataset_dir = os.path.join(_DATASET_DIR, "bsuite")
50 | _, dataloader = bsuite_env_loader(
51 | env_name="catch",
52 | dataset_dir=dataset_dir,
53 | )
54 | iterator = iter(dataloader)
55 | _ = next(iterator)
56 |
57 |
58 | if __name__ == "__main__":
59 | absltest.main()
60 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | :wrench:
4 |
5 | **Table of Contents**
6 |
7 | - [Installation](#installation)
8 | - [General](#general)
9 | - [TPU](#tpu)
10 | - [GPU](#gpu)
11 | - [Test](#test)
12 |
13 | ### General
14 |
15 | 1. Prepare an environment with `python=3.8`.
16 | 2. Clone this repository and install it in develop mode:
17 | ```console
18 | pip install -e .
19 |
20 | # Install the following packages separately due to version conflicts.
21 | pip install dopamine-rl==3.1.2
22 | pip install chex==0.1.5
23 | ```
24 | 3. [Install the ROMs for Atari](https://github.com/openai/atari-py#roms).
25 | 4. Download the datasets:
26 | 1. **BSuite** datasets ([drive](https://drive.google.com/file/d/1FWexoOphUgBaWTWtY9VR43N90z9A6FvP/view?usp=sharing)) are required if you are running BSuite experiments;
27 | 2. **Atari** datasets will be downloaded automatically from [TFDS](https://www.tensorflow.org/datasets/catalog/rlu_atari) when the experiment starts. The dataset path is defined in `experiment/*/config.py`. Alternatively, you can download them beforehand using the following script:
28 | ```python
29 | from rosmo.data.rlu_atari import create_atari_ds_loader
30 |
31 | create_atari_ds_loader(
32 | env_name="Pong", # Change this.
33 | run_number=1, # Fix this.
34 | dataset_dir="/path/to/download",
35 | )
36 | ```
37 |
38 | ### TPU
39 |
40 | All of our Atari experiments reported in the paper were run on TPUv3-8 machines from Google Cloud. If you would like to run your experiments on TPUs as well, the following commands might help:
41 | ```console
42 | sudo apt-get update && sudo apt install -y libopencv-dev
43 | pip uninstall jax jaxlib libtpu-nightly libtpu -y
44 | pip install "jax[tpu]==0.3.6" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -i https://pypi.python.org/simple
45 | ```
46 |
47 | ### GPU
48 |
49 | We also conducted verification experiments on 4 Tesla A100 GPUs to ensure our algorithm's reproducibility across platforms. To install the same JAX version as ours:
50 | ```console
51 | pip uninstall jax jaxlib libtpu-nightly libtpu -y
52 |
53 | # jax-0.3.25 jaxlib-0.3.25+cuda11.cudnn82
54 | pip install --upgrade "jax[cuda]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
55 | ```
56 |
57 | ### Test
58 |
59 | Finally, for TPU/GPU users, to validate that JAX is installed correctly, run `python -c "import jax; print(jax.devices())"` and expect a list of TPU/GPU devices to be printed.
60 |
--------------------------------------------------------------------------------
/experiment/bsuite/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """BSuite experiment configs."""
16 | from copy import deepcopy
17 | from typing import Dict
18 |
19 | from absl import flags, logging
20 | from ml_collections import ConfigDict
21 | from wandb.util import generate_id
22 |
23 | FLAGS = flags.FLAGS
24 |
25 |
26 | # ===== Configurations ===== #
27 | def get_config(game_name: str) -> Dict:
28 | """Get experiment configurations."""
29 | config = deepcopy(CONFIG)
30 | config["seed"] = FLAGS.seed
31 | config["benchmark"] = "bsuite"
32 | config["mcts"] = FLAGS.algo == "mzu"
33 | config["game_name"] = game_name
34 | config["batch_size"] = 16 if FLAGS.debug else config.batch_size
35 | exp_full_name = f"{FLAGS.exp_id}_{game_name}_" + generate_id()
36 | config["exp_full_name"] = exp_full_name
37 | logging.info(f"Configs: {config}")
38 | return config
39 |
40 |
41 | CONFIG = ConfigDict(
42 | {
43 | "data_dir": "./datasets/bsuite",
44 | "run_number": 1,
45 | "data_percentage": 100,
46 | "batch_size": 512,
47 | "unroll_steps": 3,
48 | "td_steps": 3,
49 | "num_bins": 20,
50 | "encoder_layers": [64, 64, 32],
51 | "dynamics_layers": [32, 32],
52 | "prediction_layers": [32],
53 | "output_init_scale": 0.0,
54 | "discount_factor": 0.997**4,
55 | "clipping_threshold": 1.0,
56 | "evaluate_episodes": 2,
57 | "log_interval": 400,
58 | "learning_rate": 7e-4,
59 | "warmup_steps": 1_000,
60 | "learning_rate_decay": 0.1,
61 | "weight_decay": 1e-4,
62 | "max_grad_norm": 5.0,
63 | "target_update_interval": 200,
64 | "value_coef": 0.25,
65 | "policy_coef": 1.0,
66 | "behavior_coef": 0.2,
67 | "save_period": 10_000,
68 | "eval_period": 1_000,
69 | "total_steps": 200_000,
70 | }
71 | )
72 |
--------------------------------------------------------------------------------
/experiment/atari/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Atari experiment configs."""
16 | from copy import deepcopy
17 | from typing import Dict
18 |
19 | from absl import flags, logging
20 | from ml_collections import ConfigDict
21 | from wandb.util import generate_id
22 |
23 | FLAGS = flags.FLAGS
24 |
25 |
26 | # ===== Configurations ===== #
27 | def get_config(game_name: str) -> Dict:
28 | """Get experiment configurations."""
29 | config = deepcopy(CONFIG)
30 | config["seed"] = FLAGS.seed
31 | config["benchmark"] = "atari"
32 | config["sampling"] = FLAGS.sampling
33 | config["mcts"] = FLAGS.algo == "mzu"
34 | config["game_name"] = game_name
35 | config["num_simulations"] = FLAGS.num_simulations
36 | config["search_depth"] = FLAGS.search_depth or FLAGS.num_simulations
37 | config["batch_size"] = 16 if FLAGS.debug else config["batch_size"]
38 | exp_full_name = f"{FLAGS.exp_id}_{game_name}_" + generate_id()
39 | config["exp_full_name"] = exp_full_name
40 | logging.info(f"Configs: {config}")
41 | return config
42 |
43 |
44 | CONFIG = ConfigDict(
45 | {
46 | "data_dir": "./datasets/rl_unplugged/tensorflow_datasets",
47 | "run_number": 1,
48 | "data_percentage": 100,
49 | "batch_size": 512,
50 | "stack_size": 4,
51 | "unroll_steps": 5,
52 | "td_steps": 5,
53 | "num_bins": 601,
54 | "channels": 64,
55 | "blocks_representation": 6,
56 | "blocks_prediction": 2,
57 | "blocks_transition": 2,
58 | "reduced_channels_head": 128,
59 | "fc_layers_reward": [128, 128],
60 | "fc_layers_value": [128, 128],
61 | "fc_layers_policy": [128, 128],
62 | "output_init_scale": 0.0,
63 | "discount_factor": 0.997**4,
64 | "clipping_threshold": 1.0,
65 | "evaluate_episodes": 2,
66 | "log_interval": 400,
67 | "learning_rate": 7e-4,
68 | "warmup_steps": 1_000,
69 | "learning_rate_decay": 0.1,
70 | "weight_decay": 1e-4,
71 | "max_grad_norm": 5.0,
72 | "target_update_interval": 200,
73 | "value_coef": 0.25,
74 | "policy_coef": 1.0,
75 | "behavior_coef": 0.2,
76 | "save_period": 10_000,
77 | "eval_period": 1_000,
78 | "total_steps": 200_000,
79 | }
80 | )
81 |
--------------------------------------------------------------------------------
/rosmo/agent/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Agent utilities."""
16 | # Codes adapted from:
17 | # https://github.com/Hwhitetooth/jax_muzero/blob/main/algorithms/utils.py
18 | import chex
19 | import jax
20 | import jax.numpy as jnp
21 |
22 | from rosmo.type import Array
23 |
24 |
25 | def scale_gradient(g: Array, scale: float) -> Array:
26 | """Scale the gradient.
27 |
28 | Args:
29 | g (Array): Parameters that contain gradients.
30 | scale (float): Scale.
31 |
32 | Returns:
33 | Array: Parameters with scaled gradients.
34 | """
35 | return g * scale + jax.lax.stop_gradient(g) * (1.0 - scale)
36 |
37 |
38 | def scalar_to_two_hot(x: Array, num_bins: int) -> Array:
39 | """A categorical representation of real values.
40 |
41 | Ref: https://www.nature.com/articles/s41586-020-03051-4.pdf.
42 |
43 | Args:
44 | x (Array): Scalar data.
45 | num_bins (int): Number of bins.
46 |
47 | Returns:
48 | Array: Distributional data.
49 | """
50 | max_val = (num_bins - 1) // 2
51 | x = jnp.clip(x, -max_val, max_val)
52 | x_low = jnp.floor(x).astype(jnp.int32)
53 | x_high = jnp.ceil(x).astype(jnp.int32)
54 | p_high = x - x_low
55 | p_low = 1.0 - p_high
56 | idx_low = x_low + max_val
57 | idx_high = x_high + max_val
58 | cat_low = jax.nn.one_hot(idx_low, num_bins) * p_low[..., None]
59 | cat_high = jax.nn.one_hot(idx_high, num_bins) * p_high[..., None]
60 | return cat_low + cat_high
61 |
62 |
63 | def logits_to_scalar(logits: Array, num_bins: int) -> Array:
64 | """The inverse of the scalar_to_two_hot function above.
65 |
66 | Args:
67 | logits (Array): Distributional logits.
68 | num_bins (int): Number of bins.
69 |
70 | Returns:
71 | Array: Scalar data.
72 | """
73 | chex.assert_equal(num_bins, logits.shape[-1])
74 | max_val = (num_bins - 1) // 2
75 | x = jnp.sum((jnp.arange(num_bins) - max_val) * jax.nn.softmax(logits), axis=-1)
76 | return x
77 |
78 |
79 | def value_transform(x: Array, epsilon: float = 1e-3) -> Array:
80 | """A non-linear value transformation for variance reduction.
81 |
82 | Ref: https://arxiv.org/abs/1805.11593.
83 |
84 | Args:
85 | x (Array): Data.
86 | epsilon (float, optional): Epsilon. Defaults to 1e-3.
87 |
88 | Returns:
89 | Array: Transformed data.
90 | """
91 | return jnp.sign(x) * (jnp.sqrt(jnp.abs(x) + 1) - 1) + epsilon * x
92 |
93 |
94 | def inv_value_transform(x: Array, epsilon: float = 1e-3) -> Array:
95 | """The inverse of the non-linear value transformation above.
96 |
97 | Args:
98 | x (Array): Data.
99 | epsilon (float, optional): Epsilon. Defaults to 1e-3.
100 |
101 | Returns:
102 | Array: Inversely transformed data.
103 | """
104 | return jnp.sign(x) * (
105 | ((jnp.sqrt(1 + 4 * epsilon * (jnp.abs(x) + 1 + epsilon)) - 1) / (2 * epsilon))
106 | ** 2
107 | - 1
108 | )
109 |
--------------------------------------------------------------------------------
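
A small sanity-check sketch for the helpers above (example values are arbitrary):

```python
import jax.numpy as jnp

from rosmo.agent.utils import (
    inv_value_transform,
    scalar_to_two_hot,
    value_transform,
)

x = jnp.array([3.7, -1.2, 0.0])
# The value transform and its inverse round-trip up to float error.
print(inv_value_transform(value_transform(x)))  # ~[3.7, -1.2, 0.0]

# A scalar of 2.5 with 21 bins places half the mass on the bins for 2 and 3.
probs = scalar_to_two_hot(jnp.array(2.5), num_bins=21)
print(probs.shape, probs[12], probs[13])  # (21,) 0.5 0.5
```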
/rosmo/data/buffer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """In-memory buffer."""
16 | from typing import Any, Iterator
17 |
18 | import numpy as np
19 | import tree
20 |
21 |
22 | class UniformBuffer(Iterator):
23 | """Buffer that supports uniform sampling."""
24 |
25 | def __init__(
26 | self, min_size: int, max_size: int, traj_len: int, batch_size: int = 2
27 | ) -> None:
28 | """Init the buffer."""
29 | self._min_size = min_size
30 | self._max_size = max_size
31 | self._traj_len = traj_len
32 | self._timestep_storage = None
33 | self._n = 0
34 | self._idx = 0
35 | self._bs = batch_size
36 | self._static_buffer = False
37 |
38 | def __next__(self) -> Any:
39 | """Get the next sample.
40 |
41 | Returns:
42 | Any: Sampled data.
43 | """
44 | return self.sample(self._bs)
45 |
46 | def init_storage(self, timesteps: Any) -> None:
47 | """Initialize the buffer.
48 |
49 | Args:
50 | timesteps (Any): Timesteps that contain the whole dataset.
51 | """
52 | assert self._timestep_storage is None
53 | size = timesteps.observation.shape[0]
54 | assert self._min_size <= size <= self._max_size
55 | self._n = size
56 | self._timestep_storage = timesteps
57 | self._static_buffer = True
58 |
59 | def sample(self, batch_size: int) -> Any:
60 | """Sample a batch of data.
61 |
62 | Args:
63 | batch_size (int): Batch size to sample.
64 |
65 | Returns:
66 | Any: Sampled data.
67 | """
68 | if batch_size + self._traj_len > self._n:
69 | return None
70 | start_indices = np.random.choice(
71 | self._n - self._traj_len, batch_size, replace=False
72 | )
73 | all_indices = start_indices[:, None] + np.arange(self._traj_len + 1)[None]
74 | base_idx = 0 if self._n < self._max_size else self._idx
75 | all_indices = (all_indices + base_idx) % self._max_size
76 | trajectories = tree.map_structure(
77 | lambda a: a[all_indices], self._timestep_storage
78 | )
79 | return trajectories
80 |
81 | def full(self) -> bool:
82 | """Test if the buffer is full.
83 |
84 | Returns:
85 | bool: True if the buffer is full.
86 | """
87 | return self._n == self._max_size
88 |
89 | def ready(self) -> bool:
90 | """Test if the buffer has minimum size.
91 |
92 | Returns:
93 | bool: True if the buffer is ready.
94 | """
95 | return self._n >= self._min_size
96 |
97 | @property
98 | def size(self) -> int:
99 | """Get the size of the buffer.
100 |
101 | Returns:
102 | int: Buffer size.
103 | """
104 | return self._n
105 |
106 | def _preallocate(self, item: Any) -> Any:
107 | return tree.map_structure(
108 | lambda t: np.empty((self._max_size,) + t.shape, t.dtype), item
109 | )
110 |
111 |
112 | def assign(array: Any, index: Any, data: Any) -> None:
113 | """Update array.
114 |
115 | Args:
116 | array (Any): Array to be updated.
117 | index (Any): Index of updates.
118 | data (Any): Update data.
119 | """
120 | array[index] = data
121 |
--------------------------------------------------------------------------------
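
A minimal sketch of feeding `UniformBuffer` an in-memory offline dataset; the shapes and sizes below are illustrative only:

```python
import numpy as np

from rosmo.data.buffer import UniformBuffer
from rosmo.type import ActorOutput

num_steps = 100  # toy dataset length
timesteps = ActorOutput(
    observation=np.zeros((num_steps, 4), dtype=np.float32),
    reward=np.zeros((num_steps,), dtype=np.float32),
    is_first=np.zeros((num_steps,), dtype=np.float32),
    is_last=np.zeros((num_steps,), dtype=np.float32),
    action=np.zeros((num_steps,), dtype=np.int32),
)

buffer = UniformBuffer(min_size=10, max_size=1000, traj_len=5, batch_size=2)
buffer.init_storage(timesteps)  # load the whole offline dataset at once
batch = next(buffer)            # uniformly sampled sub-trajectories
print(batch.observation.shape)  # (2, 6, 4): batch x (traj_len + 1) x obs_dim
```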
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "rosmo"
7 | description = 'ROSMO: A Regularized One-Step Model-based algorithm for Offline-RL'
8 | readme = "README.md"
9 | requires-python = ">=3.7"
10 | license = "MIT"
11 | keywords = []
12 | authors = [
13 | { name = "Zichen", email = "liuzc@sea.com" },
14 | ]
15 | classifiers = [
16 | "Development Status :: 4 - Beta",
17 | "Programming Language :: Python",
18 | "Programming Language :: Python :: 3.7",
19 | "Programming Language :: Python :: 3.8",
20 | "Programming Language :: Python :: 3.9",
21 | "Programming Language :: Python :: 3.10",
22 | "Programming Language :: Python :: 3.11",
23 | "Programming Language :: Python :: Implementation :: CPython",
24 | "Programming Language :: Python :: Implementation :: PyPy",
25 | ]
26 | dependencies = [
27 | "dm-acme==0.4.0",
28 | "dm-launchpad-nightly==0.3.0.dev20220321",
29 | "dm-haiku==0.0.9",
30 | "gym==0.17.2",
31 | "gin-config==0.3.0",
32 | "rlax==0.1.4",
33 | "tensorflow==2.8.0",
34 | "tensorflow-probability==0.16.0",
35 | "optax==0.1.3",
36 | "tfds-nightly",
37 | "rlds[tensorflow]==0.1.4",
38 | "wandb==0.12.19",
39 | "ml-collections==0.1.1",
40 | "dm-sonnet==2.0.0",
41 | "mujoco-py<2.2,>=2.1",
42 | "bsuite==0.3.5",
43 | "viztracer==0.15.6",
44 | "mctx==0.0.2",
45 | ]
46 | dynamic = ["version"]
47 |
48 | [project.urls]
49 | Documentation = "https://github.com/unknown/rosmo#readme"
50 | Issues = "https://github.com/unknown/rosmo/issues"
51 | Source = "https://github.com/unknown/rosmo"
52 |
53 | [tool.hatch.version]
54 | path = "rosmo/__about__.py"
55 |
56 | [tool.hatch.envs.default]
57 | dependencies = [
58 | "pytest",
59 | "pytest-cov",
60 | ]
61 | [tool.hatch.envs.default.scripts]
62 | cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=rosmo --cov=tests"
63 | no-cov = "cov --no-cov"
64 |
65 | [[tool.hatch.envs.test.matrix]]
66 | python = ["37", "38", "39", "310", "311"]
67 |
68 | [tool.coverage.run]
69 | branch = true
70 | parallel = true
71 | omit = [
72 | "rosmo/__about__.py",
73 | ]
74 |
75 | [tool.coverage.report]
76 | exclude_lines = [
77 | "no cov",
78 | "if __name__ == .__main__.:",
79 | "if TYPE_CHECKING:",
80 | ]
81 |
82 | [tool.pylint.master]
83 | load-plugins = "pylint.extensions.docparams,pylint.extensions.docstyle,pylint.extensions.no_self_use"
84 | default-docstring-type = "google"
85 | ignore-paths = ["rosmo/__about__.py"]
86 |
87 | [tool.pylint.format]
88 | max-line-length = 88
89 | indent-after-paren = 4
90 | indent-string = " "
91 |
92 | [tool.pylint.imports]
93 | known-third-party = "wandb"
94 |
95 | [tool.pylint.reports]
96 | output-format = "colorized"
97 | reports = "no"
98 | score = "yes"
99 | max-args = 7
100 |
101 | [tool.pylint.messages_control]
102 | disable = ["W0108", "W0212", "W1514", "R0902", "R0903", "R0913", "R0914", "R0915", "R1719",
103 | "R1732", "C0103", "C3001"]
104 |
105 | [tool.yapf]
106 | based_on_style = "yapf"
107 | spaces_before_comment = 4
108 | dedent_closing_brackets = true
109 | column_limit = 88
110 | continuation_indent_width = 4
111 |
112 | [tool.isort]
113 | profile = "black"
114 | multi_line_output = 3
115 | indent = 4
116 | line_length = 88
117 | known_third_party = "wandb"
118 |
119 | [tool.mypy]
120 | files = "rosmo/**/*.py"
121 | allow_redefinition = true
122 | check_untyped_defs = true
123 | disallow_incomplete_defs = true
124 | disallow_untyped_defs = true
125 | ignore_missing_imports = true
126 | no_implicit_optional = true
127 | pretty = true
128 | show_error_codes = true
129 | show_error_context = true
130 | show_traceback = true
131 | strict_equality = true
132 | strict_optional = true
133 | warn_no_return = true
134 | warn_redundant_casts = true
135 | warn_unreachable = true
136 | warn_unused_configs = true
137 |
138 | [tool.pydocstyle]
139 | ignore = ["D100", "D102", "D104", "D105", "D107", "D203", "D213", "D401", "D402"]
140 |
141 |
142 | [tool.pytest.ini_options]
143 | filterwarnings = [
144 | "ignore::UserWarning",
145 | "ignore::DeprecationWarning",
146 | "ignore::FutureWarning",
147 | ]
148 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ROSMO
2 |
3 |
4 |

5 |
6 |
7 | -----
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | **Table of Contents**
26 |
27 | - [ROSMO](#rosmo)
28 | - [Introduction](#introduction)
29 | - [Installation](#installation)
30 | - [Usage](#usage)
31 | - [BSuite](#bsuite)
32 | - [Atari](#atari)
33 | - [Citation](#citation)
34 | - [License](#license)
35 | - [Acknowledgement](#acknowledgement)
36 | - [Disclaimer](#disclaimer)
37 |
38 | ## Introduction
39 |
40 | This repository contains the implementation of ROSMO, a **R**egularized **O**ne-**S**tep **M**odel-based algorithm for **O**ffline-RL, introduced in our paper "Efficient Offline Policy Optimization with a Learned Model". We provide the training code for both the Atari and BSuite experiments, and have made the reproduced results on `Atari MsPacman` publicly available at [W&B](https://wandb.ai/lkevinzc/rosmo-public).
41 |
42 | ## Installation
43 | Please follow the [installation guide](INSTALL.md).
44 |
45 | ## Usage
46 | ### BSuite
47 |
48 | To run the BSuite experiments, please ensure you have downloaded the [datasets](https://drive.google.com/file/d/1FWexoOphUgBaWTWtY9VR43N90z9A6FvP/view?usp=sharing) and placed them in the directory defined by `CONFIG.data_dir` in `experiment/bsuite/config.py`.
49 |
50 | 1. Debug run.
51 | ```console
52 | python experiment/bsuite/main.py -exp_id test -env cartpole
53 | ```
54 | 2. Enable [W&B](https://wandb.ai/site) logger and start training.
55 | ```console
56 | python experiment/bsuite/main.py -exp_id test -env cartpole -nodebug -use_wb -user ${WB_USER}
57 | ```
58 |
59 | ### Atari
60 |
61 | The following commands are examples of training 1) a ROSMO agent, 2) its sampling variant, and 3) an MZU agent on the game `MsPacman`.
62 |
63 | 1. Train ROSMO with exact policy target.
64 | ```console
65 | python experiment/atari/main.py -exp_id rosmo -env MsPacman -nodebug -use_wb -user ${WB_USER}
66 | ```
67 | 2. Train ROSMO with sampled policy target (N=4).
68 | ```console
69 | python experiment/atari/main.py -exp_id rosmo-sample-4 -sampling -env MsPacman -nodebug -use_wb -user ${WB_USER}
70 | ```
71 | 3. Train MuZero Unplugged as a baseline (N=20).
72 | ```console
73 | python experiment/atari/main.py -exp_id mzu-sample-20 -algo mzu -num_simulations 20 -env MsPacman -nodebug -use_wb -user ${WB_USER}
74 | ```
75 |
76 | ## Citation
77 |
78 | If you find this work useful for your research, please consider citing:
79 | ```
80 | @inproceedings{
81 | liu2023rosmo,
82 | title={Efficient Offline Policy Optimization with a Learned Model},
83 | author={Zichen Liu and Siyi Li and Wee Sun Lee and Shuicheng Yan and Zhongwen Xu},
84 | booktitle={International Conference on Learning Representations},
85 | year={2023},
86 | url={https://arxiv.org/abs/2210.05980}
87 | }
88 | ```
89 |
90 | ## License
91 |
92 | `ROSMO` is distributed under the terms of the [Apache2](https://www.apache.org/licenses/LICENSE-2.0) license.
93 |
94 | ## Acknowledgement
95 |
96 | We thank the following projects, which provided great references:
97 |
98 | * [Jax Muzero](https://github.com/Hwhitetooth/jax_muzero)
99 | * [Efficient Zero](https://github.com/YeWR/EfficientZero)
100 | * [Acme](https://github.com/deepmind/acme)
101 |
102 | ## Disclaimer
103 |
104 | This is not an official Sea Limited or Garena Online Private Limited product.
105 |
--------------------------------------------------------------------------------
/rosmo/env_loop_observer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Environment loop extensions."""
16 | import collections
17 | import time
18 | from abc import abstractmethod
19 | from typing import Dict, Optional, Union
20 |
21 | import acme
22 | import dm_env
23 | import numpy as np
24 | from acme.utils import observers as observers_lib
25 | from acme.utils import signals
26 |
27 | Number = Union[int, float]
28 |
29 |
30 | class EvaluationLoop(acme.EnvironmentLoop):
31 | """Evaluation env-actor loop."""
32 |
33 | def run(
34 | self, num_episodes: Optional[int] = None, num_steps: Optional[int] = None
35 | ) -> None:
36 | """Run the evaluation loop."""
37 | if not (num_episodes is None or num_steps is None):
38 | raise ValueError('Either "num_episodes" or "num_steps" should be None.')
39 |
40 | def should_terminate(episode_count: int, step_count: int) -> bool:
41 | return (num_episodes is not None and episode_count >= num_episodes) or (
42 | num_steps is not None and step_count >= num_steps
43 | )
44 |
45 | episode_count, step_count = 0, 0
46 | all_results: Dict[str, list] = collections.defaultdict(list)
47 | with signals.runtime_terminator():
48 | while not should_terminate(episode_count, step_count):
49 | result = self.run_episode()
50 | episode_count += 1
51 | step_count += result["episode_length"]
52 | for k, v in result.items():
53 | all_results[k].append(v)
54 | # Log the averaged results from all episodes.
55 | self._logger.write({k: np.mean(v) for k, v in all_results.items()})
56 |
57 |
58 | class ExtendedEnvLoopObserver(observers_lib.EnvLoopObserver):
59 | """Extended env loop observer."""
60 |
61 | @abstractmethod
62 | def step(self) -> None:
63 | """Steps the observer."""
64 |
65 | @abstractmethod
66 | def restore(self, learning_step: int) -> None:
67 | """Restore the observer state."""
68 |
69 |
70 | class LearningStepObserver(ExtendedEnvLoopObserver):
71 | """Observer to record the learning steps."""
72 |
73 | def __init__(self) -> None:
74 | """Init observer."""
75 | super().__init__()
76 | self._learning_step = 0
77 | self._eval_step = 0
78 | self._status = 1 # {0: train, 1: eval}
79 | self._train_elapsed = 0.0
80 | self._last_time: Optional[float] = None
81 |
82 | def step(self) -> None:
83 | """Steps the observer."""
84 | self._learning_step += 1
85 |
86 | if self._status == 0 and self._last_time:
87 | self._train_elapsed += time.time() - self._last_time
88 | if self._status == 1:
89 | self._status = 0
90 |
91 | self._last_time = time.time()
92 |
93 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep) -> None:
94 | """Observes the initial state, setting states."""
95 | self._status = 1
96 | self._eval_step += 1
97 |
98 | def observe(
99 | self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray
100 | ) -> None:
101 | """Records one environment step, dummy."""
102 |
103 | def get_metrics(self) -> Dict[str, Number]:
104 | """Returns metrics collected for the current episode."""
105 | return {
106 | "step": self._learning_step,
107 | "eval_step": self._eval_step,
108 | "learning_time": self._train_elapsed,
109 | }
110 |
111 | def restore(self, learning_step: int) -> None:
112 | """Restore."""
113 | self._learning_step = learning_step
114 |
--------------------------------------------------------------------------------
/experiment/bsuite/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """BSuite experiment entry."""
16 | import random
17 | import time
18 | from typing import Dict, List, Optional, Tuple
19 |
20 | import jax
21 | import numpy as np
22 | import wandb
23 | from absl import app, flags, logging
24 | from acme import EnvironmentLoop
25 | from acme.specs import make_environment_spec
26 | from acme.utils.loggers import Logger
27 |
28 | from experiment.bsuite.config import get_config
29 | from rosmo.agent.actor import RosmoEvalActor
30 | from rosmo.agent.learning import RosmoLearner
31 | from rosmo.agent.network import get_bsuite_networks
32 | from rosmo.data import bsuite_env_loader
33 | from rosmo.env_loop_observer import (
34 | EvaluationLoop,
35 | ExtendedEnvLoopObserver,
36 | LearningStepObserver,
37 | )
38 | from rosmo.loggers import logger_fn
39 |
40 | # ===== Flags. ===== #
41 | FLAGS = flags.FLAGS
42 | flags.DEFINE_boolean("debug", True, "Debug run.")
43 | flags.DEFINE_boolean("use_wb", False, "Use WB to log.")
44 | flags.DEFINE_string("user", "username", "Wandb user id.")
45 | flags.DEFINE_string("project", "rosmo", "Wandb project id.")
46 |
47 | flags.DEFINE_string("exp_id", None, "Experiment id.", required=True)
48 | flags.DEFINE_string("env", None, "Environment name to run.", required=True)
49 | flags.DEFINE_integer("seed", int(time.time()), "Random seed.")
50 |
51 |
52 | # ===== Learner. ===== #
53 | def get_learner(config, networks, data_iterator, logger) -> RosmoLearner:
54 | """Get ROSMO learner."""
55 | learner = RosmoLearner(
56 | networks,
57 | demonstrations=data_iterator,
58 | config=config,
59 | logger=logger,
60 | )
61 | return learner
62 |
63 |
64 | # ===== Eval Actor-Env Loop. ===== #
65 | def get_actor_env_eval_loop(
66 | config, networks, environment, observers, logger
67 | ) -> Tuple[RosmoEvalActor, EnvironmentLoop]:
68 | """Get actor, env and evaluation loop."""
69 | actor = RosmoEvalActor(
70 | networks,
71 | config,
72 | )
73 | eval_loop = EvaluationLoop(
74 | environment=environment,
75 | actor=actor,
76 | logger=logger,
77 | should_update=False,
78 | observers=observers,
79 | )
80 | return actor, eval_loop
81 |
82 |
83 | def get_env_loop_observers() -> List[ExtendedEnvLoopObserver]:
84 | """Get environment loop observers."""
85 | observers = []
86 | learning_step_ob = LearningStepObserver()
87 | observers.append(learning_step_ob)
88 | return observers
89 |
90 |
91 | # ===== Misc. ===== #
92 | def get_logger_fn(
93 | exp_full_name: str,
94 | job_name: str,
95 | is_eval: bool = False,
96 | config: Optional[Dict] = None,
97 | ) -> Logger:
98 | """Get logger function."""
99 | save_data = is_eval
100 | return logger_fn(
101 | exp_name=exp_full_name,
102 | label=job_name,
103 | save_data=save_data and not FLAGS.debug,
104 | use_tb=False,
105 | use_wb=FLAGS.use_wb and not FLAGS.debug,
106 | config=config,
107 | )
108 |
109 |
110 | def main(_):
111 | """Main program."""
112 | logging.info(f"Debug mode: {FLAGS.debug}")
113 | random.seed(FLAGS.seed)
114 | np.random.seed(FLAGS.seed)
115 |
116 | platform = jax.lib.xla_bridge.get_backend().platform
117 | num_devices = jax.device_count()
118 | logging.warning(f"Compute platform: {platform} with {num_devices} devices.")
119 |
120 | # ===== Setup. ===== #
121 | cfg = get_config(FLAGS.env)
122 |
123 | env, dataloader = bsuite_env_loader(
124 | env_name=FLAGS.env,
125 | dataset_dir=cfg["data_dir"],
126 | data_percentage=cfg["data_percentage"],
127 | batch_size=cfg["batch_size"],
128 | trajectory_length=cfg["td_steps"] + cfg["unroll_steps"] + 1,
129 | )
130 | networks = get_bsuite_networks(make_environment_spec(env), cfg)
131 |
132 | # ===== Essentials. ===== #
133 | learner = get_learner(
134 | cfg,
135 | networks,
136 | dataloader,
137 | get_logger_fn(
138 | cfg["exp_full_name"],
139 | "learner",
140 | config=cfg,
141 | ),
142 | )
143 | observers = get_env_loop_observers()
144 | actor, eval_loop = get_actor_env_eval_loop(
145 | cfg,
146 | networks,
147 | env,
148 | observers,
149 | get_logger_fn(cfg["exp_full_name"], "evaluator", is_eval=True, config=cfg),
150 | )
151 | evaluate_episodes = 2 if FLAGS.debug else cfg["evaluate_episodes"]
152 |
153 | # ===== Restore. ===== #
154 | init_step = 0
155 | if FLAGS.use_wb and not FLAGS.debug:
156 | wb_name = cfg["exp_full_name"]
157 | wb_cfg = cfg.to_dict()
158 | wandb.init(
159 | project=FLAGS.project,
160 | entity=FLAGS.user,
161 | name=wb_name,
162 | config=wb_cfg,
163 | sync_tensorboard=False,
164 | )
165 |
166 | # ===== Training Loop. ===== #
167 | for i in range(init_step + 1, cfg["total_steps"]):
168 | learner.step()
169 | for ob in observers:
170 | ob.step()
171 |
172 | if FLAGS.debug or (i + 1) % cfg["eval_period"] == 0:
173 | actor.update_params(learner.save().params)
174 | eval_loop.run(evaluate_episodes)
175 |
176 | if FLAGS.debug:
177 | break
178 |
179 | # ===== Cleanup. ===== #
180 | learner._logger.close()
181 | eval_loop._logger.close()
182 | del env, networks, dataloader, learner, observers, actor, eval_loop
183 |
184 |
185 | if __name__ == "__main__":
186 | app.run(main)
187 |
--------------------------------------------------------------------------------
/rosmo/agent/actor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Evaluating actor."""
16 | from typing import Dict, Optional, Tuple
17 |
18 | import acme
19 | import chex
20 | import dm_env
21 | import jax
22 | import numpy as np
23 | import rlax
24 | from absl import logging
25 | from acme import types
26 | from acme.jax import networks as networks_lib
27 |
28 | from rosmo.agent.improvement_op import mcts_improve, one_step_improve
29 | from rosmo.agent.learning import root_unroll
30 | from rosmo.agent.network import Networks
31 | from rosmo.agent.type import AgentOutput, Params
32 | from rosmo.type import ActorOutput, Array
33 |
34 |
35 | class RosmoEvalActor(acme.core.Actor):
36 | """ROSMO evaluation actor."""
37 |
38 | def __init__(
39 | self,
40 | networks: Networks,
41 | config: Dict,
42 | ) -> None:
43 | """Init ROSMO evaluation actor."""
44 | self._networks = networks
45 |
46 | self._environment_specs = networks.environment_specs
47 | self._rng_key = jax.random.PRNGKey(config["seed"])
48 | self._random_action = False
49 | self._params: Optional[Params] = None
50 | self._timestep: Optional[ActorOutput] = None
51 |
52 | num_bins = config["num_bins"]
53 | discount_factor = config["discount_factor"]
54 | use_mcts = config["mcts"]
55 | sampling = config.get("sampling", False)
56 | num_simulations = config.get("num_simulations", -1)
57 | search_depth = config.get("search_depth", num_simulations)
58 |
59 | def root_step(
60 | rng_key: chex.PRNGKey,
61 | params: Params,
62 | timesteps: ActorOutput,
63 | ) -> Tuple[Array, AgentOutput]:
64 | # Model one-step look-ahead for acting.
65 | trajectory = jax.tree_map(
66 | lambda t: t[None], timesteps
67 | ) # Add a dummy time dimension.
68 | state = networks.representation_network.apply(
69 | params.representation, trajectory.observation
70 | )
71 | agent_out: AgentOutput = root_unroll(
72 | self._networks, params, num_bins, state
73 | )
74 | improve_key, sample_key = jax.random.split(rng_key)
75 |
76 | if use_mcts:
77 | logging.info("[Actor] Using MCTS planning.")
78 | mcts_out = mcts_improve(
79 | networks,
80 | improve_key,
81 | params,
82 | agent_out,
83 | num_bins,
84 | discount_factor,
85 | num_simulations,
86 | search_depth,
87 | )
88 | action = mcts_out.action
89 | else:
90 | agent_out = jax.tree_map(
91 | lambda t: t.squeeze(axis=0), agent_out
92 | ) # Squeeze the dummy time dimension.
93 | if not sampling:
94 | logging.info("[Actor] Using onestep improvement.")
95 | improved_policy, _ = one_step_improve(
96 | self._networks,
97 | improve_key,
98 | params,
99 | agent_out,
100 | num_bins,
101 | discount_factor,
102 | num_simulations,
103 | sampling,
104 | )
105 | else:
106 | logging.info("[Actor] Using policy.")
107 | improved_policy = jax.nn.softmax(agent_out.policy_logits)
108 | action = rlax.categorical_sample(sample_key, improved_policy)
109 | return action, agent_out
110 |
111 | def batch_step(
112 | rng_key: chex.PRNGKey,
113 | params: Params,
114 | timesteps: ActorOutput,
115 | ) -> Tuple[networks_lib.PRNGKey, Array, AgentOutput]:
116 | batch_size = timesteps.reward.shape[0]
117 | rng_key, step_key = jax.random.split(rng_key)
118 | step_keys = jax.random.split(step_key, batch_size)
119 | batch_root_step = jax.vmap(root_step, (0, None, 0))
120 | actions, agent_out = batch_root_step(step_keys, params, timesteps)
121 | return rng_key, actions, agent_out
122 |
123 | self._agent_step = jax.jit(batch_step)
124 |
125 | def select_action(self, observation: types.NestedArray) -> types.NestedArray:
126 | """Select action to execute."""
127 | if self._random_action:
128 | return np.random.randint(0, self._environment_specs.actions.num_values, [1])
129 | batched_timestep = jax.tree_map(
130 | lambda t: t[None], jax.device_put(self._timestep)
131 | )
132 | self._rng_key, action, _ = self._agent_step(
133 | self._rng_key, self._params, batched_timestep
134 | )
135 | action = jax.device_get(action).item()
136 | return action
137 |
138 | def observe_first(self, timestep: dm_env.TimeStep) -> None:
139 | """Observe and record the first timestep."""
140 | self._timestep = ActorOutput(
141 | action=np.zeros((1,), dtype=np.int32),
142 | reward=np.zeros((1,), dtype=np.float32),
143 | observation=timestep.observation,
144 | is_first=np.ones((1,), dtype=np.float32),
145 | is_last=np.zeros((1,), dtype=np.float32),
146 | )
147 |
148 | def observe(
149 | self,
150 | action: types.NestedArray,
151 | next_timestep: dm_env.TimeStep,
152 | ) -> None:
153 | """Observe and record a timestep."""
154 | self._timestep = ActorOutput(
155 | action=action,
156 | reward=next_timestep.reward,
157 | observation=next_timestep.observation,
158 | is_first=next_timestep.first(), # previous last = this first.
159 | is_last=next_timestep.last(),
160 | )
161 |
162 | def update(self, wait: bool = False) -> None:
163 | """Update."""
164 |
165 | def update_params(self, params: Params) -> None:
166 | """Update parameters.
167 |
168 | Args:
169 | params (Params): Parameters.
170 | """
171 | self._params = params
172 |
--------------------------------------------------------------------------------
/rosmo/data/bsuite.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """BSuite datasets."""
16 | import os
17 | from typing import Any, Dict, Tuple
18 |
19 | import dm_env
20 | import numpy as np
21 | import rlds
22 | import tensorflow as tf
23 | import tensorflow_datasets as tfds
24 | from absl import logging
25 | from bsuite.environments.cartpole import Cartpole as _Cartpole
26 | from bsuite.environments.catch import Catch as _Catch
27 | from bsuite.environments.mountain_car import MountainCar as _MountainCar
28 |
29 | from rosmo.data.buffer import UniformBuffer
30 | from rosmo.type import ActorOutput
31 |
32 |
33 | class Cartpole(_Cartpole):
34 | """Carpole environment."""
35 |
36 | def __init__(self, *args: Any, **kwargs: Any) -> None:
37 | """Init env."""
38 | super().__init__(*args, **kwargs)
39 | self.episode_id = 0
40 | self.episode_return = 0
41 | self.bsuite_id = "cartpole/0"
42 |
43 | def reset(self) -> dm_env.TimeStep:
44 | """Reset env."""
45 | self.episode_id += 1
46 | self.episode_return = 0
47 | return super().reset()
48 |
49 | def step(self, action: int) -> dm_env.TimeStep:
50 | """Step env."""
51 | timestep = super().step(action)
52 | if timestep.reward is not None:
53 | self.episode_return += timestep.reward
54 | return timestep
55 |
56 |
57 | class Catch(_Catch):
58 | """Catch environment."""
59 |
60 | def __init__(self, *args: Any, **kwargs: Any) -> None:
61 | """Init env."""
62 | super().__init__(*args, **kwargs)
63 | self.episode_id = 0
64 | self.episode_return = 0
65 | self.bsuite_id = "catch/0"
66 |
67 | def _reset(self) -> dm_env.TimeStep:
68 | self.episode_id += 1
69 | self.episode_return = 0
70 | return super()._reset()
71 |
72 | def _step(self, action: int) -> dm_env.TimeStep:
73 | timestep = super()._step(action)
74 | if timestep.reward is not None:
75 | self.episode_return += timestep.reward
76 | return timestep
77 |
78 |
79 | class MountainCar(_MountainCar):
80 | """Mountain Car environment."""
81 |
82 | def __init__(self, *args: Any, **kwargs: Any) -> None:
83 | """Init env."""
84 | super().__init__(*args, **kwargs)
85 | self.episode_id = 0
86 | self.episode_return = 0
87 | self.bsuite_id = "mountain_car/0"
88 |
89 | def _reset(self) -> dm_env.TimeStep:
90 | self.episode_id += 1
91 | self.episode_return = 0
92 | return super()._reset()
93 |
94 | def _step(self, action: int) -> dm_env.TimeStep:
95 | timestep = super()._step(action)
96 | if timestep.reward is not None:
97 | self.episode_return += timestep.reward
98 | return timestep
99 |
100 |
101 | _ENV_FACTORY: Dict[str, Tuple[Type[dm_env.Environment], int]] = {
102 | "cartpole": (Cartpole, 1000),
103 | "catch": (Catch, 2000),
104 | "mountain_car": (MountainCar, 500),
105 | }
106 |
107 | _LOAD_SIZE = 1e7
108 |
109 | SCORES = {
110 | "cartpole": {
111 | "random": 64.833,
112 | "online_dqn": 1001.0,
113 | },
114 | "catch": {
115 | "random": -0.667,
116 | "online_dqn": 1.0,
117 | },
118 | "mountain_car": {
119 | "random": -1000.0,
120 | "online_dqn": -102.167,
121 | },
122 | }
123 |
124 |
125 | def create_bsuite_ds_loader(
126 | env_name: str, dataset_name: str, dataset_percentage: int
127 | ) -> tf.data.Dataset:
128 | """Create BSuite dataset loader.
129 |
130 | Args:
131 | env_name (str): Environment name.
132 | dataset_name (str): Dataset name.
133 | dataset_percentage (int): Fraction of data to be used.
134 |
135 | Returns:
136 | tf.data.Dataset: Dataset.
137 | """
138 | dataset = tfds.builder_from_directory(dataset_name).as_dataset(split="all")
139 | num_trajectory = _ENV_FACTORY[env_name][1]
140 | if dataset_percentage < 100:
141 | idx = np.arange(0, num_trajectory, (100 // dataset_percentage))
142 | idx += np.random.randint(0, 100 // dataset_percentage, idx.shape) + 1
143 | idx = tf.convert_to_tensor(idx, "int32")
144 | filter_fn = lambda episode: tf.math.equal(
145 | tf.reduce_sum(tf.cast(episode["episode_id"] == idx, "int32")), 1
146 | )
147 | dataset = dataset.filter(filter_fn)
148 | parse_fn = lambda episode: episode[rlds.STEPS]
149 | dataset = dataset.interleave(
150 | parse_fn,
151 | cycle_length=1,
152 | block_length=1,
153 | deterministic=False,
154 | num_parallel_calls=tf.data.AUTOTUNE,
155 | )
156 | return dataset
157 |
158 |
159 | def env_loader(
160 | env_name: str,
161 | dataset_dir: str,
162 | data_percentage: int = 100,
163 | batch_size: int = 8,
164 | trajectory_length: int = 1,
165 | **_: Any,
166 | ) -> Tuple[dm_env.Environment, tf.data.Dataset]:
167 | """Get the environment and dataset.
168 |
169 | Args:
170 | env_name (str): Name of the environment.
171 | dataset_dir (str): Directory storing the dataset.
172 | data_percentage (int, optional): Fraction of data to be used. Defaults to 100.
173 | batch_size (int, optional): Batch size. Defaults to 8.
174 | trajectory_length (int, optional): Trajectory length. Defaults to 1.
175 | **_: Other keyword arguments.
176 |
177 | Returns:
178 | Tuple[dm_env.Environment, tf.data.Dataset]: Environment and dataset.
179 | """
180 | data_name = env_name
181 | if env_name not in _ENV_FACTORY:
182 | _env_setting = env_name.split("_")
183 | if len(_env_setting) > 1:
184 | env_name = "_".join(_env_setting[:-1])
185 | assert env_name in _ENV_FACTORY, f"env {env_name} not supported"
186 |
187 | dataset_name = os.path.join(dataset_dir, f"{data_name}")
188 | logging.info(f"[Data] Loading dataset from {dataset_name}.")
189 | dataset = create_bsuite_ds_loader(env_name, dataset_name, data_percentage)
190 | dataloader = dataset.batch(int(_LOAD_SIZE)).as_numpy_iterator()
191 | data = next(dataloader)
192 |
193 | data_buffer = {}
194 | data_buffer["observation"] = data["observation"]
195 | data_buffer["reward"] = data["reward"]
196 | data_buffer["is_first"] = data["is_first"]
197 | data_buffer["is_last"] = data["is_last"]
198 | data_buffer["action"] = data["action"]
199 |
200 | timesteps = ActorOutput(**data_buffer)
201 | data_size = len(timesteps.reward)
202 | assert data_size < _LOAD_SIZE
203 |
204 | iterator = UniformBuffer(
205 | 0,
206 | data_size,
207 | trajectory_length,
208 | batch_size,
209 | )
210 | logging.info(f"[Data] {data_size} transitions in total.")
211 | iterator.init_storage(timesteps)
212 | return _ENV_FACTORY[env_name][0](), iterator
213 |
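An illustrative call of `env_loader`; the dataset directory is a placeholder, the other values are arbitrary, and it is assumed that the returned `UniformBuffer` yields batched `ActorOutput` samples via `next()`:

    env, iterator = env_loader(
        env_name="catch",
        dataset_dir="/path/to/bsuite_datasets",   # placeholder TFDS directory
        data_percentage=100,
        batch_size=8,
        trajectory_length=6,
    )
    batch = next(iterator)   # ActorOutput with batch and trajectory leading dimensions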
--------------------------------------------------------------------------------
/rosmo/agent/improvement_op.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Improvement operators."""
16 | from typing import Tuple
17 |
18 | import chex
19 | import distrax
20 | import jax
21 | import jax.numpy as jnp
22 | import mctx
23 | from absl import logging
24 | from acme.jax import networks as networks_lib
25 |
26 | from rosmo.agent.network import Networks
27 | from rosmo.agent.type import AgentOutput, Params
28 | from rosmo.agent.utils import inv_value_transform, logits_to_scalar
29 | from rosmo.type import Array
30 |
31 | ## -------------------------------- ##
32 | ## One-step Look-ahead Improvement. ##
33 | ## -------------------------------- ##
34 |
35 |
36 | def model_simulate(
37 | networks: Networks,
38 | params: Params,
39 | num_bins: int,
40 | state: Array,
41 | actions_to_simulate: Array,
42 | ) -> AgentOutput:
43 | """Simulate the learned model using one-step look-ahead."""
44 |
45 | def fn(state: Array, action: Array) -> Array:
46 | """Dynamics fun for vmap."""
47 | next_state = networks.transition_network.apply(
48 | params.transition, action[None], state
49 | )
50 | return next_state
51 |
52 | states_imagined = jax.vmap(fn, (None, 0))(state, actions_to_simulate)
53 |
54 | (
55 | policy_logits,
56 | reward_logits,
57 | value_logits,
58 | ) = networks.prediction_network.apply(params.prediction, states_imagined)
59 | reward = logits_to_scalar(reward_logits, num_bins)
60 | reward = inv_value_transform(reward)
61 | value = logits_to_scalar(value_logits, num_bins)
62 | value = inv_value_transform(value)
63 | return AgentOutput(
64 | state=states_imagined,
65 | policy_logits=policy_logits,
66 | reward_logits=reward_logits,
67 | reward=reward,
68 | value_logits=value_logits,
69 | value=value,
70 | )
71 |
72 |
73 | def one_step_improve(
74 | networks: Networks,
75 | rng_key: networks_lib.PRNGKey,
76 | params: Params,
77 | model_root: AgentOutput,
78 | num_bins: int,
79 | discount_factor: float,
80 | num_simulations: int = -1,
81 | sampling: bool = False,
82 | ) -> Tuple[Array, Array]:
83 | """Obtain the one-step look-ahead target policy."""
84 | environment_specs = networks.environment_specs
85 |
86 | pi_prior = jax.nn.softmax(model_root.policy_logits)
87 | value_prior = model_root.value
88 |
89 | if sampling:
90 | assert num_simulations > 0
91 | logging.info(
92 | f"[Sample] Using {num_simulations} samples to estimate improvement."
93 | )
94 | pi_sample = distrax.Categorical(probs=pi_prior)
95 | sample_acts = pi_sample.sample(
96 | seed=rng_key, sample_shape=num_simulations)
97 | sample_one_step_out: AgentOutput = model_simulate(
98 | networks, params, num_bins, model_root.state, sample_acts
99 | )
100 | sample_adv = (
101 | sample_one_step_out.reward
102 | + discount_factor * sample_one_step_out.value
103 | - value_prior
104 | )
105 | adv = sample_adv # for log
106 | sample_exp_adv = jnp.exp(sample_adv)
107 | normalizer_raw = (jnp.sum(sample_exp_adv) + 1) / num_simulations
108 | coeff = jnp.zeros_like(pi_prior)
109 |
110 | def body(i: int, val: jnp.ndarray) -> jnp.ndarray:
111 | """Body fun for the loop."""
112 | normalizer_i = normalizer_raw - sample_exp_adv[i] / num_simulations
113 | delta = jnp.zeros_like(val)
114 | delta = delta.at[sample_acts[i]].set(
115 | sample_exp_adv[i] / normalizer_i)
116 | return val + delta
117 |
118 | coeff = jax.lax.fori_loop(0, num_simulations, body, coeff)
119 | pi_improved = coeff / num_simulations
120 | else:
121 | all_actions = jnp.arange(environment_specs.actions.num_values)
122 | model_one_step_out: AgentOutput = model_simulate(
123 | networks, params, num_bins, model_root.state, all_actions
124 | )
125 | chex.assert_equal_shape([model_one_step_out.reward, pi_prior])
126 | chex.assert_equal_shape([model_one_step_out.value, pi_prior])
127 | adv = (
128 | model_one_step_out.reward
129 | + discount_factor * model_one_step_out.value
130 | - value_prior
131 | )
132 | pi_improved = pi_prior * jnp.exp(adv)
133 | pi_improved = pi_improved / jnp.sum(pi_improved)
134 |
135 | chex.assert_equal_shape([pi_improved, pi_prior])
136 | # pi_improved may not sum to 1 here; in that case we use cross-entropy (CE)
137 | # to conveniently compute the policy gradients (Eq. 9).
138 | return pi_improved, adv
139 |
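For reference, the exact (non-sampled) branch above implements the closed-form one-step target, written here in the code's notation:

$$\pi_{\text{improved}}(a \mid s) \;=\; \frac{\pi_{\text{prior}}(a \mid s)\,\exp\big(\mathrm{adv}(s,a)\big)}{\sum_{b}\pi_{\text{prior}}(b \mid s)\,\exp\big(\mathrm{adv}(s,b)\big)}, \qquad \mathrm{adv}(s,a) \;=\; \hat r(s,a) + \gamma\,\hat v(\hat s') - \hat v(s),$$

where $\gamma$ is `discount_factor` and $\hat r$, $\hat v$ come from the learned model. The sampled branch estimates the same target from `num_simulations` actions drawn from the prior, which is why the resulting `pi_improved` may not sum exactly to one.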
140 |
141 | ## ------------------------------------ ##
142 | ## Monte-Carlo Tree Search Improvement. ##
143 | ## ------------------------------------ ##
144 |
145 |
146 | def mcts_improve(
147 | networks: Networks,
148 | rng_key: networks_lib.PRNGKey,
149 | params: Params,
150 | model_root: AgentOutput,
151 | num_bins: int,
152 | discount_factor: float,
153 | num_simulations: int,
154 | search_depth: int,
155 | ) -> mctx.PolicyOutput:
156 | """Obtain the Monte-Carlo Tree Search target policy."""
157 |
158 | def recurrent_fn(
159 | params: Params, rng_key: networks_lib.PRNGKey, action: Array, state: Array
160 | ) -> Tuple[mctx.RecurrentFnOutput, Array]:
161 | del rng_key
162 |
163 | def fn(state: Array, action: Array) -> Array:
164 | next_state = networks.transition_network.apply(
165 | params.transition, action[None], state
166 | )
167 | return next_state
168 |
169 | next_state = jax.vmap(fn, (0, 0))(state, action)
170 |
171 | (
172 | policy_logits,
173 | reward_logits,
174 | value_logits,
175 | ) = networks.prediction_network.apply(params.prediction, next_state)
176 | reward = logits_to_scalar(reward_logits, num_bins)
177 | reward = inv_value_transform(reward)
178 | value = logits_to_scalar(value_logits, num_bins)
179 | value = inv_value_transform(value)
180 | recurrent_fn_output = mctx.RecurrentFnOutput(
181 | reward=reward,
182 | discount=jnp.full_like(value, fill_value=discount_factor),
183 | prior_logits=policy_logits,
184 | value=value,
185 | )
186 | return recurrent_fn_output, next_state
187 |
188 | root = mctx.RootFnOutput(
189 | prior_logits=model_root.policy_logits,
190 | value=model_root.value,
191 | embedding=model_root.state,
192 | )
193 |
194 | return mctx.muzero_policy(
195 | params,
196 | rng_key,
197 | root,
198 | recurrent_fn,
199 | num_simulations,
200 | max_depth=search_depth,
201 | )
202 |
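A hedged sketch of consuming the returned `mctx.PolicyOutput` (field names follow the public mctx API; the hyperparameter values below are placeholders rather than the repository's defaults):

    policy_output = mcts_improve(
        networks, rng_key, params, model_root,
        num_bins=num_bins,          # assumed to come from the agent config
        discount_factor=0.997,      # placeholder value
        num_simulations=16,         # placeholder value
        search_depth=16,            # placeholder value
    )
    target_policy = policy_output.action_weights   # visit-count distribution over actions
    selected_action = policy_output.action         # action chosen by the search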
--------------------------------------------------------------------------------
/rosmo/loggers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Logger utils."""
16 | import os
17 | import re
18 | import threading
19 | from typing import Any, Callable, Dict, Mapping, Optional
20 |
21 | import wandb
22 | from acme.utils.loggers import Logger, aggregators
23 | from acme.utils.loggers import asynchronous as async_logger
24 | from acme.utils.loggers import base, csv, filters, terminal
25 | from acme.utils.loggers.tf_summary import TFSummaryLogger
26 |
27 | from rosmo.data.bsuite import SCORES
28 | from rosmo.data.rlu_atari import BASELINES
29 |
30 |
31 | class WBLogger(base.Logger):
32 | """Logger for W&B."""
33 |
34 | def __init__(
35 | self,
36 | scope: Optional[str] = None,
37 | ) -> None:
38 | """Init WB logger."""
39 | self._lock = threading.Lock()
40 | self._scope = scope
41 |
42 | def write(self, data: Dict[str, Any]) -> None:
43 | """Log the data."""
44 | step = data.pop("step", None)
45 | if step is not None:
46 | step = int(step)
47 | with self._lock:
48 | if self._scope is None:
49 | wandb.log(
50 | data,
51 | step=step,
52 | )
53 | else:
54 | wandb.log(
55 | {f"{self._scope}/{k}": v for k, v in data.items()},
56 | step=step,
57 | )
58 |
59 | def close(self) -> None:
60 | """Close WB logger."""
61 | with self._lock:
62 | wandb.finish()
63 |
64 |
65 | class ResultFilter(base.Logger):
66 | """Postprocessing for normalized score."""
67 |
68 | def __init__(self, to: base.Logger, game_name: str):
69 | """Init result filter."""
70 | self._to = to
 71 | game_name = re.sub(r"(?<!^)(?=[A-Z])", "_", game_name).lower()
 72 | # Look up reference scores: bsuite games come from SCORES,
 73 | # RL Unplugged Atari games come from BASELINES.
 74 | if game_name in SCORES:
 75 | baseline = SCORES[game_name]
 76 | else:
 77 | baseline = BASELINES[game_name]
 78 | random_score = baseline["random"]
 79 | dqn_score = baseline["online_dqn"]
 80 |
 81 | def normalizer(score: float) -> float:
82 | return (score - random_score) / (dqn_score - random_score)
83 |
84 | self._normalizer = normalizer
85 |
86 | def write(self, data: base.LoggingData) -> None:
87 | """Write to logger."""
88 | if "episode_return" in data:
89 | data = {
90 | **data,
91 | "normalized_score": self._normalizer(data.get("episode_return", 0)),
92 | }
93 | self._to.write(data)
94 |
95 | def close(self) -> None:
96 | """Close logger."""
97 | self._to.close()
98 |
99 |
100 | def make_sail_logger(
101 | exp_name: str,
102 | label: str,
103 | save_data: bool = True,
104 | save_dir: str = "./logs",
105 | use_tb: bool = False,
106 | tb_dir: Optional[str] = None,
107 | use_wb: bool = False,
108 | config: Optional[dict] = None,
109 | time_delta: float = 1.0,
110 | asynchronous: bool = False,
111 | print_fn: Optional[Callable[[str], None]] = None,
112 | serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy,
113 | ) -> base.Logger:
114 | """Makes a logger for SAILors.
115 |
116 | Args:
117 | exp_name: Name of the experiment.
118 | label: Name to give to the logger.
119 | save_data: Whether to persist data.
120 | save_dir: Directory to save log data.
121 | use_tb: Whether to use TensorBoard.
122 | tb_dir: Tensorboard directory.
123 | use_wb: Whether to use Weights and Biases.
124 | config: Experiment configurations.
125 | time_delta: Time (in seconds) between logging events.
126 | asynchronous: Whether the write function should block or not.
127 | print_fn: How to print to terminal (defaults to print).
128 | serialize_fn: An optional function to apply to the write inputs before
129 | passing them to the various loggers.
130 |
131 | Returns:
132 | A logger object that responds to logger.write(some_dict).
133 | """
134 | if not print_fn:
135 | print_fn = print
136 | terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn)
137 |
138 | loggers = [terminal_logger]
139 |
140 | if save_data:
141 | os.makedirs(save_dir, exist_ok=True)
142 | fd = open(os.path.join(save_dir, f"{exp_name}.csv"), "a")
143 | loggers.append(
144 | csv.CSVLogger(
145 | directory_or_file=fd,
146 | label=exp_name,
147 | add_uid=False,
148 | flush_every=2,
149 | )
150 | )
151 |
152 | if use_wb:
153 | wb_logger = WBLogger(scope=label)
154 | wb_logger = filters.TimeFilter(wb_logger, time_delta)
155 | loggers.append(wb_logger)
156 |
157 | if use_tb:
158 | if tb_dir is None:
159 | tb_dir = "./tblogs"
160 | loggers.append(TFSummaryLogger(tb_dir, label))
161 |
162 | # Dispatch to all writers and filter Nones and by time.
163 | logger = aggregators.Dispatcher(loggers, serialize_fn)
164 | logger = filters.NoneFilter(logger)
165 |
166 | if config:
167 | logger = ResultFilter(logger, game_name=config["game_name"])
168 |
169 | if asynchronous:
170 | logger = async_logger.AsyncLogger(logger)
171 | logger = filters.TimeFilter(logger, 5.0)
172 |
173 | return logger
174 |
175 |
176 | def logger_fn(
177 | exp_name: str,
178 | label: str,
179 | save_data: bool = False,
180 | use_tb: bool = True,
181 | use_wb: bool = True,
182 | config: Optional[dict] = None,
183 | time_delta: float = 15.0,
184 | ) -> Logger:
185 | """Get logger function.
186 |
187 | Args:
188 | exp_name (str): Experiment name.
189 | label (str): Experiment label.
190 | save_data (bool, optional): Whether to save data. Defaults to False.
191 | use_tb (bool, optional): Whether to use TB. Defaults to True.
192 | use_wb (bool, optional): Whether to use WB. Defaults to True.
193 | config (Optional[dict], optional): Experiment configurations. Defaults to None.
194 | time_delta (float, optional): Time delta to emit logs. Defaults to 15.0.
195 |
196 | Returns:
197 | Logger: Logger.
198 | """
199 | tb_path = os.path.join("./tblogs", exp_name)
200 | return make_sail_logger(
201 | exp_name=exp_name,
202 | label=label,
203 | save_data=save_data,
204 | save_dir="./logs",
205 | use_tb=use_tb,
206 | tb_dir=tb_path,
207 | use_wb=use_wb,
208 | config=config,
209 | time_delta=time_delta, # Applied to W&B.
210 | print_fn=print,
211 | )
212 |
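An illustrative call of `logger_fn`; the experiment name is made up, and `config` only needs a `game_name` entry for the `ResultFilter` normalization above:

    logger = logger_fn(
        exp_name="rosmo_ms_pacman_demo",     # hypothetical experiment name
        label="evaluator",
        save_data=True,
        use_tb=False,
        use_wb=False,
        config={"game_name": "ms_pacman"},
    )
    logger.write({"step": 100, "episode_return": 1234.0})  # ResultFilter adds normalized_score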
--------------------------------------------------------------------------------
/rosmo/data/rl_unplugged_atari_baselines.json:
--------------------------------------------------------------------------------
1 | {
2 | "alien": {
3 | "BC": 2670.0,
4 | "BCQ": 2090.0,
5 | "DQN": 1690.0,
6 | "IQN": 2860.0,
7 | "REM": 1730.0,
8 | "online_dqn": 2766.81,
9 | "random": 199.8
10 | },
11 | "amidar": {
12 | "BC": 256.0,
13 | "BCQ": 254.0,
14 | "DQN": 224.0,
15 | "IQN": 351.0,
16 | "REM": 214.0,
17 | "online_dqn": 1556.96,
18 | "random": 3.19
19 | },
20 | "assault": {
21 | "BC": 1810.0,
22 | "BCQ": 2260.0,
23 | "DQN": 1940.0,
24 | "IQN": 2180.0,
25 | "REM": 3070.0,
26 | "online_dqn": 1946.1,
27 | "random": 235.15
28 | },
29 | "asterix": {
30 | "BC": 2960.0,
31 | "BCQ": 1930.0,
32 | "DQN": 1520.0,
33 | "IQN": 5710.0,
34 | "REM": 4890.0,
35 | "online_dqn": 4131.77,
36 | "random": 279.06
37 | },
38 | "atlantis": {
39 | "BC": 2390000.0,
40 | "BCQ": 3200000.0,
41 | "DQN": 3020000.0,
42 | "IQN": 2710000.0,
43 | "REM": 3360000.0,
44 | "online_dqn": 944228.0,
45 | "random": 16973.04
46 | },
47 | "bank_heist": {
48 | "BC": 1050.0,
49 | "BCQ": 270.0,
50 | "DQN": 50.0,
51 | "IQN": 1110.0,
52 | "REM": 160.0,
53 | "online_dqn": 907.72,
54 | "random": 13.65
55 | },
56 | "battle_zone": {
57 | "BC": 4800.0,
58 | "BCQ": 25400.0,
59 | "DQN": 25600.0,
60 | "IQN": 16500.0,
61 | "REM": 26200.0,
62 | "online_dqn": 26458.99,
63 | "random": 2786.8
64 | },
65 | "beam_rider": {
66 | "BC": 1480.0,
67 | "BCQ": 1990.0,
68 | "DQN": 1810.0,
69 | "IQN": 3020.0,
70 | "REM": 2200.0,
71 | "online_dqn": 6453.26,
72 | "random": 362.05
73 | },
74 | "boxing": {
75 | "BC": 83.9,
76 | "BCQ": 97.2,
77 | "DQN": 96.3,
78 | "IQN": 95.8,
79 | "REM": 97.3,
80 | "online_dqn": 84.11,
81 | "random": 0.79
82 | },
83 | "breakout": {
84 | "BC": 235.0,
85 | "BCQ": 375.0,
86 | "DQN": 324.0,
87 | "IQN": 314.0,
88 | "REM": 362.0,
89 | "online_dqn": 157.86,
90 | "random": 1.33
91 | },
92 | "carnival": {
93 | "BC": 3920.0,
94 | "BCQ": 4310.0,
95 | "DQN": 1450.0,
96 | "IQN": 4820.0,
97 | "REM": 2080.0,
98 | "online_dqn": 5339.46,
99 | "random": 669.55
100 | },
101 | "centipede": {
102 | "BC": 1070.0,
103 | "BCQ": 1430.0,
104 | "DQN": 1250.0,
105 | "IQN": 1830.0,
106 | "REM": 810.0,
107 | "online_dqn": 3972.49,
108 | "random": 2181.66
109 | },
110 | "chopper_command": {
111 | "BC": 660.0,
112 | "BCQ": 3950.0,
113 | "DQN": 2250.0,
114 | "IQN": 830.0,
115 | "REM": 3610.0,
116 | "online_dqn": 3678.15,
117 | "random": 823.14
118 | },
119 | "crazy_climber": {
120 | "BC": 123000.0,
121 | "BCQ": 28000.0,
122 | "DQN": 23000.0,
123 | "IQN": 126000.0,
124 | "REM": 42000.0,
125 | "online_dqn": 118080.24,
126 | "random": 8173.59
127 | },
128 | "demon_attack": {
129 | "BC": 7600.0,
130 | "BCQ": 19300.0,
131 | "DQN": 11000.0,
132 | "IQN": 15500.0,
133 | "REM": 17000.0,
134 | "online_dqn": 6517.02,
135 | "random": 166.02
136 | },
137 | "double_dunk": {
138 | "BC": -16.4,
139 | "BCQ": -12.9,
140 | "DQN": -17.9,
141 | "IQN": -16.7,
142 | "REM": -17.9,
143 | "online_dqn": -1.22,
144 | "random": -18.42
145 | },
146 | "enduro": {
147 | "BC": 720.0,
148 | "BCQ": 1390.0,
149 | "DQN": 1210.0,
150 | "IQN": 1700.0,
151 | "REM": 3650.0,
152 | "online_dqn": 1016.28,
153 | "random": 0.0
154 | },
155 | "fishing_derby": {
156 | "BC": -7.4,
157 | "BCQ": 28.9,
158 | "DQN": 17.0,
159 | "IQN": 20.8,
160 | "REM": 29.3,
161 | "online_dqn": 18.57,
162 | "random": -93.25
163 | },
164 | "freeway": {
165 | "BC": 21.8,
166 | "BCQ": 16.9,
167 | "DQN": 15.4,
168 | "IQN": 24.7,
169 | "REM": 7.2,
170 | "online_dqn": 26.76,
171 | "random": 0.0
172 | },
173 | "frostbite": {
174 | "BC": 780.0,
175 | "BCQ": 3520.0,
176 | "DQN": 3230.0,
177 | "IQN": 2630.0,
178 | "REM": 3070.0,
179 | "online_dqn": 1643.65,
180 | "random": 71.99
181 | },
182 | "gopher": {
183 | "BC": 4900.0,
184 | "BCQ": 8700.0,
185 | "DQN": 2400.0,
186 | "IQN": 11300.0,
187 | "REM": 3700.0,
188 | "online_dqn": 8241.0,
189 | "random": 282.57
190 | },
191 | "gravitar": {
192 | "BC": 20.0,
193 | "BCQ": 580.0,
194 | "DQN": 500.0,
195 | "IQN": 235.0,
196 | "REM": 424.0,
197 | "online_dqn": 310.56,
198 | "random": 213.71
199 | },
200 | "hero": {
201 | "BC": 13900.0,
202 | "BCQ": 13200.0,
203 | "DQN": 5200.0,
204 | "IQN": 16200.0,
205 | "REM": 14000.0,
206 | "online_dqn": 16233.54,
207 | "random": 719.15
208 | },
209 | "ice_hockey": {
210 | "BC": -5.63,
211 | "BCQ": -2.51,
212 | "DQN": -2.88,
213 | "IQN": -4.65,
214 | "REM": -1.16,
215 | "online_dqn": -4.02,
216 | "random": -9.82
217 | },
218 | "jamesbond": {
219 | "BC": 237.0,
220 | "BCQ": 438.0,
221 | "DQN": 490.0,
222 | "IQN": 699.0,
223 | "REM": 369.0,
224 | "online_dqn": 777.73,
225 | "random": 27.63
226 | },
227 | "kangaroo": {
228 | "BC": 5690.0,
229 | "BCQ": 1300.0,
230 | "DQN": 820.0,
231 | "IQN": 9120.0,
232 | "REM": 1210.0,
233 | "online_dqn": 14125.11,
234 | "random": 41.3
235 | },
236 | "krull": {
237 | "BC": 8500.0,
238 | "BCQ": 7780.0,
239 | "DQN": 7480.0,
240 | "IQN": 8470.0,
241 | "REM": 7980.0,
242 | "online_dqn": 7238.51,
243 | "random": 1556.92
244 | },
245 | "kung_fu_master": {
246 | "BC": 5100.0,
247 | "BCQ": 16900.0,
248 | "DQN": 16100.0,
249 | "IQN": 19500.0,
250 | "REM": 19400.0,
251 | "online_dqn": 26637.88,
252 | "random": 556.31
253 | },
254 | "ms_pacman": {
255 | "BC": 4040.0,
256 | "BCQ": 3080.0,
257 | "DQN": 2470.0,
258 | "IQN": 4390.0,
259 | "REM": 3150.0,
260 | "online_dqn": 4171.53,
261 | "random": 247.97
262 | },
263 | "name_this_game": {
264 | "BC": 4100.0,
265 | "BCQ": 12600.0,
266 | "DQN": 11500.0,
267 | "IQN": 9900.0,
268 | "REM": 13000.0,
269 | "online_dqn": 8645.09,
270 | "random": 2401.05
271 | },
272 | "phoenix": {
273 | "BC": 2940.0,
274 | "BCQ": 6620.0,
275 | "DQN": 6410.0,
276 | "IQN": 4940.0,
277 | "REM": 7480.0,
278 | "online_dqn": 5122.3,
279 | "random": 873.03
280 | },
281 | "pong": {
282 | "BC": 18.9,
283 | "BCQ": 16.5,
284 | "DQN": 12.9,
285 | "IQN": 19.2,
286 | "REM": 16.5,
287 | "online_dqn": 18.25,
288 | "random": -20.3
289 | },
290 | "pooyan": {
291 | "BC": 3850.0,
292 | "BCQ": 4200.0,
293 | "DQN": 3180.0,
294 | "IQN": 5000.0,
295 | "REM": 4470.0,
296 | "online_dqn": 4135.32,
297 | "random": 411.36
298 | },
299 | "qbert": {
300 | "BC": 12600.0,
301 | "BCQ": 12600.0,
302 | "DQN": 10600.0,
303 | "IQN": 13400.0,
304 | "REM": 13100.0,
305 | "online_dqn": 12275.13,
306 | "random": 155.01
307 | },
308 | "riverraid": {
309 | "BC": 6000.0,
310 | "BCQ": 14200.0,
311 | "DQN": 9100.0,
312 | "IQN": 13000.0,
313 | "REM": 14200.0,
314 | "online_dqn": 12798.88,
315 | "random": 1504.25
316 | },
317 | "road_runner": {
318 | "BC": 19000.0,
319 | "BCQ": 57400.0,
320 | "DQN": 31700.0,
321 | "IQN": 44700.0,
322 | "REM": 56500.0,
323 | "online_dqn": 47880.48,
324 | "random": 15.46
325 | },
326 | "robotank": {
327 | "BC": 15.7,
328 | "BCQ": 60.7,
329 | "DQN": 55.7,
330 | "IQN": 42.7,
331 | "REM": 60.5,
332 | "online_dqn": 63.44,
333 | "random": 2.01
334 | },
335 | "seaquest": {
336 | "BC": 150.0,
337 | "BCQ": 5410.0,
338 | "DQN": 2870.0,
339 | "IQN": 1670.0,
340 | "REM": 5910.0,
341 | "online_dqn": 3233.47,
342 | "random": 81.78
343 | },
344 | "space_invaders": {
345 | "BC": 790.0,
346 | "BCQ": 2920.0,
347 | "DQN": 2710.0,
348 | "IQN": 2840.0,
349 | "REM": 2810.0,
350 | "online_dqn": 2044.63,
351 | "random": 149.52
352 | },
353 | "star_gunner": {
354 | "BC": 3000.0,
355 | "BCQ": 2500.0,
356 | "DQN": 1600.0,
357 | "IQN": 39400.0,
358 | "REM": 7500.0,
359 | "online_dqn": 55103.84,
360 | "random": 677.22
361 | },
362 | "time_pilot": {
363 | "BC": 1950.0,
364 | "BCQ": 5180.0,
365 | "DQN": 5310.0,
366 | "IQN": 3140.0,
367 | "REM": 4490.0,
368 | "online_dqn": 4160.51,
369 | "random": 3450.95
370 | },
371 | "up_n_down": {
372 | "BC": 16300.0,
373 | "BCQ": 32500.0,
374 | "DQN": 14600.0,
375 | "IQN": 32300.0,
376 | "REM": 27600.0,
377 | "online_dqn": 15677.92,
378 | "random": 513.94
379 | },
380 | "video_pinball": {
381 | "BC": 27000.0,
382 | "BCQ": 103000.0,
383 | "DQN": 82000.0,
384 | "IQN": 102000.0,
385 | "REM": 313000.0,
386 | "online_dqn": 335055.69,
387 | "random": 26024.42
388 | },
389 | "wizard_of_wor": {
390 | "BC": 730.0,
391 | "BCQ": 4680.0,
392 | "DQN": 2300.0,
393 | "IQN": 1400.0,
394 | "REM": 2730.0,
395 | "online_dqn": 1787.79,
396 | "random": 686.63
397 | },
398 | "yars_revenge": {
399 | "BC": 19100.0,
400 | "BCQ": 29100.0,
401 | "DQN": 24900.0,
402 | "IQN": 28400.0,
403 | "REM": 23100.0,
404 | "online_dqn": 26762.98,
405 | "random": 3147.67
406 | },
407 | "zaxxon": {
408 | "BC": 10.0,
409 | "BCQ": 9430.0,
410 | "DQN": 6050.0,
411 | "IQN": 870.0,
412 | "REM": 8300.0,
413 | "online_dqn": 4681.93,
414 | "random": 10.57
415 | }
416 | }
417 |
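These per-game references are consumed by `ResultFilter` in rosmo/loggers.py to normalize evaluation returns; the stand-alone snippet below illustrates the same computation with a made-up raw return:

    import json

    # Assumes the current directory contains the JSON file above.
    with open("rl_unplugged_atari_baselines.json") as f:
        baselines = json.load(f)

    game = baselines["pong"]
    raw_return = 10.0  # hypothetical evaluation return
    normalized = (raw_return - game["random"]) / (game["online_dqn"] - game["random"])
    # (10.0 - (-20.3)) / (18.25 - (-20.3)) ≈ 0.786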
--------------------------------------------------------------------------------
/experiment/atari/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Atari experiment entry."""
16 |
17 | import os
18 | import pickle
19 | import random
20 | import time
21 | from typing import Dict, Iterator, List, Optional, Tuple
22 |
23 | import dm_env
24 | import jax
25 | import numpy as np
26 | import tensorflow as tf
27 | import wandb
28 | from absl import app, flags, logging
29 | from acme import EnvironmentLoop
30 | from acme.specs import make_environment_spec
31 | from acme.utils.loggers import Logger
32 |
33 | from experiment.atari.config import get_config
34 | from rosmo.agent.actor import RosmoEvalActor
35 | from rosmo.agent.learning import RosmoLearner
36 | from rosmo.agent.network import Networks, make_atari_networks
37 | from rosmo.data import atari_env_loader
38 | from rosmo.env_loop_observer import (
39 | EvaluationLoop,
40 | ExtendedEnvLoopObserver,
41 | LearningStepObserver,
42 | )
43 | from rosmo.loggers import logger_fn
44 | from rosmo.profiler import Profiler
45 | from rosmo.type import ActorOutput
46 |
47 | # ===== Flags. ===== #
48 | FLAGS = flags.FLAGS
49 | flags.DEFINE_boolean("debug", True, "Debug run.")
50 | flags.DEFINE_boolean("profile", False, "Profile codes.")
51 | flags.DEFINE_boolean("use_wb", False, "Use WB to log.")
52 | flags.DEFINE_string("user", "username", "Wandb user id.")
53 | flags.DEFINE_string("project", "rosmo", "Wandb project id.")
54 |
55 | flags.DEFINE_string("exp_id", None, "Experiment id.", required=True)
56 | flags.DEFINE_string("env", None, "Environment name to run.", required=True)
57 | flags.DEFINE_integer("seed", int(time.time()), "Random seed.")
58 |
59 | flags.DEFINE_boolean("sampling", False, "Whether to sample policy target.")
60 | flags.DEFINE_integer("num_simulations", 4, "Simulation budget.")
61 | flags.DEFINE_enum("algo", "rosmo", ["rosmo", "mzu"], "Algorithm to use.")
62 | flags.DEFINE_integer(
63 | "search_depth",
64 | 0,
65 | "Depth of Monte-Carlo Tree Search (only for mzu), \
66 | defaults to num_simulations.",
67 | )
68 |
69 | # ===== Learner. ===== #
70 | def get_learner(config, networks, data_iterator, logger) -> RosmoLearner:
71 | """Get ROSMO learner."""
72 | learner = RosmoLearner(
73 | networks,
74 | demonstrations=data_iterator,
75 | config=config,
76 | logger=logger,
77 | )
78 | return learner
79 |
80 |
81 | # ===== Eval Actor-Env Loop & Observer. ===== #
82 | def get_actor_env_eval_loop(
83 | config, networks, environment, observers, logger
84 | ) -> Tuple[RosmoEvalActor, EnvironmentLoop]:
85 | """Get actor, env and evaluation loop."""
86 | actor = RosmoEvalActor(
87 | networks,
88 | config,
89 | )
90 | eval_loop = EvaluationLoop(
91 | environment=environment,
92 | actor=actor,
93 | logger=logger,
94 | should_update=False,
95 | observers=observers,
96 | )
97 | return actor, eval_loop
98 |
99 |
100 | def get_env_loop_observers() -> List[ExtendedEnvLoopObserver]:
101 | """Get environment loop observers."""
102 | observers = []
103 | learning_step_ob = LearningStepObserver()
104 | observers.append(learning_step_ob)
105 | return observers
106 |
107 |
108 | # ===== Environment & Dataloader. ===== #
109 | def get_env_data_loader(config) -> Tuple[dm_env.Environment, Iterator]:
110 | """Get environment and trajectory data loader."""
111 | trajectory_length = config["unroll_steps"] + config["td_steps"] + 1
112 | environment, dataset = atari_env_loader(
113 | env_name=config["game_name"],
114 | run_number=config["run_number"],
115 | dataset_dir=config["data_dir"],
116 | stack_size=config["stack_size"],
117 | data_percentage=config["data_percentage"],
118 | trajectory_length=trajectory_length,
119 | shuffle_num_steps=5000 if FLAGS.debug else 50000,
120 | )
121 |
122 | def transform_timesteps(steps: Dict[str, np.ndarray]) -> ActorOutput:
123 | return ActorOutput(
124 | observation=steps["observation"],
125 | reward=steps["reward"],
126 | is_first=steps["is_first"],
127 | is_last=steps["is_last"],
128 | action=steps["action"],
129 | )
130 |
131 | dataset = (
132 | dataset.repeat()
133 | .batch(config["batch_size"])
134 | .map(transform_timesteps)
135 | .prefetch(tf.data.AUTOTUNE)
136 | )
137 | options = tf.data.Options()
138 | options.threading.max_intra_op_parallelism = 1
139 | dataset = dataset.with_options(options)
140 | iterator = dataset.as_numpy_iterator()
141 | return environment, iterator
142 |
143 |
144 | # ===== Network. ===== #
145 | def get_networks(config, environment) -> Networks:
146 | """Get environment-specific networks."""
147 | environment_spec = make_environment_spec(environment)
148 | logging.info(environment_spec)
149 | networks = make_atari_networks(
150 | env_spec=environment_spec,
151 | channels=config["channels"],
152 | num_bins=config["num_bins"],
153 | output_init_scale=config["output_init_scale"],
154 | blocks_representation=config["blocks_representation"],
155 | blocks_prediction=config["blocks_prediction"],
156 | blocks_transition=config["blocks_transition"],
157 | reduced_channels_head=config["reduced_channels_head"],
158 | fc_layers_reward=config["fc_layers_reward"],
159 | fc_layers_value=config["fc_layers_value"],
160 | fc_layers_policy=config["fc_layers_policy"],
161 | )
162 | return networks
163 |
164 |
165 | # ===== Misc. ===== #
166 | def get_logger_fn(
167 | exp_full_name: str,
168 | job_name: str,
169 | is_eval: bool = False,
170 | config: Optional[Dict] = None,
171 | ) -> Logger:
172 | """Get logger function."""
173 | save_data = is_eval
174 | return logger_fn(
175 | exp_name=exp_full_name,
176 | label=job_name,
177 | save_data=save_data and not FLAGS.debug,
178 | use_tb=False,
179 | use_wb=FLAGS.use_wb and not FLAGS.debug,
180 | config=config,
181 | )
182 |
183 |
184 | def main(_):
185 | """Main program."""
186 | platform = jax.lib.xla_bridge.get_backend().platform
187 | num_devices = jax.device_count()
188 | logging.warning(f"Compute platform: {platform} with {num_devices} devices.")
189 | logging.info(f"Debug mode: {FLAGS.debug}")
190 | random.seed(FLAGS.seed)
191 | np.random.seed(FLAGS.seed)
192 |
193 | # ===== Setup. ===== #
194 | cfg = get_config(FLAGS.env)
195 | env, dataloader = get_env_data_loader(cfg)
196 | networks = get_networks(cfg, env)
197 |
198 | # ===== Essentials. ===== #
199 | learner = get_learner(
200 | cfg,
201 | networks,
202 | dataloader,
203 | get_logger_fn(
204 | cfg["exp_full_name"],
205 | "learner",
206 | config=cfg,
207 | ),
208 | )
209 | observers = get_env_loop_observers()
210 | actor, eval_loop = get_actor_env_eval_loop(
211 | cfg,
212 | networks,
213 | env,
214 | observers,
215 | get_logger_fn(
216 | cfg["exp_full_name"],
217 | "evaluator",
218 | is_eval=True,
219 | config=cfg,
220 | ),
221 | )
222 | evaluate_episodes = 2 if FLAGS.debug else cfg["evaluate_episodes"]
223 |
224 | init_step = 0
225 | save_path = os.path.join("./checkpoint", cfg["exp_full_name"])
226 | os.makedirs(save_path, exist_ok=True)
227 | if FLAGS.profile:
228 | profile_dir = "./profile"
229 | os.makedirs(profile_dir, exist_ok=True)
230 | profiler = Profiler(profile_dir, cfg["exp_full_name"], with_jax=True)
231 |
232 | if FLAGS.use_wb and not (FLAGS.debug or FLAGS.profile):
233 | wb_name = cfg["exp_full_name"]
234 | wb_cfg = cfg.to_dict()
235 |
236 | wandb.init(
237 | project=FLAGS.project,
238 | entity=FLAGS.user,
239 | name=wb_name,
240 | config=wb_cfg,
241 | sync_tensorboard=False,
242 | )
243 |
244 | # ===== Training Loop. ===== #
245 | for i in range(init_step + 1, cfg["total_steps"]):
246 | learner.step()
247 | for ob in observers:
248 | ob.step()
249 |
250 | if (i + 1) % cfg["save_period"] == 0:
251 | with open(os.path.join(save_path, f"ckpt_{i}.pkl"), "wb") as f:
252 | pickle.dump(learner.save(), f)
253 | if (i + 1) % cfg["eval_period"] == 0:
254 | actor.update_params(learner.save().params)
255 | eval_loop.run(evaluate_episodes)
256 |
257 | if FLAGS.profile:
258 | if i == 100:
259 | profiler.start()
260 | if i == 200:
261 | profiler.stop_and_save()
262 | break
263 | elif FLAGS.debug:
264 | actor.update_params(learner.save().params)
265 | eval_loop.run(evaluate_episodes)
266 | break
267 |
268 | # ===== Cleanup. ===== #
269 | learner._logger.close()
270 | eval_loop._logger.close()
271 | del env, networks, dataloader, learner, observers, actor, eval_loop
272 |
273 |
274 | if __name__ == "__main__":
275 | app.run(main)
276 |
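Based on the flags defined at the top of this file, a hypothetical invocation (run from the repository root) could look as follows; the exact environment-name format expected by `get_config` is defined in experiment/atari/config.py and is left as a placeholder here:

    python -m experiment.atari.main \
        --exp_id=demo_run \
        --env=<game_name> \
        --algo=rosmo \
        --num_simulations=4 \
        --nodebug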
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 Garena Online Private Limited
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/rosmo/data/rlu_atari.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """RL Unplugged Atari datasets."""
16 | import json
17 | import os
18 | from typing import Any, Callable, Dict, Optional, Tuple, Union
19 |
20 | import dm_env
21 | import gym
22 | import rlds
23 | import tensorflow as tf
24 | import tensorflow_datasets as tfds
25 | from absl import logging
26 | from acme import wrappers
27 | from dm_env import specs
28 | from dopamine.discrete_domains import atari_lib
29 |
30 | from rosmo.type import Array
31 |
32 | with open(
33 | os.path.join(
34 | os.path.dirname(os.path.abspath(__file__)), "rl_unplugged_atari_baselines.json"
35 | ),
36 | "r",
37 | ) as f:
38 | BASELINES = json.load(f)
39 |
40 |
41 | class _BatchToTransition:
42 | """Creates (s,a,r,f,l) transitions."""
43 |
44 | @staticmethod
45 | def create_transitions(batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
46 | """Create stacked transitions.
47 |
48 | Args:
49 | batch (Dict[str, tf.Tensor]): Data batch
50 |
51 | Returns:
52 | Dict[str, tf.Tensor]: Stacked data batch.
53 | """
54 | observation = tf.squeeze(batch[rlds.OBSERVATION], axis=-1)
55 | observation = tf.transpose(observation, perm=[1, 2, 0])
56 | action = batch[rlds.ACTION][-1]
57 | reward = batch[rlds.REWARD][-1]
58 | discount = batch[rlds.DISCOUNT][-1]
59 | return {
60 | "observation": observation,
61 | "action": action,
62 | "reward": reward,
63 | "discount": discount,
64 | "is_first": batch[rlds.IS_FIRST][0],
65 | "is_last": batch[rlds.IS_LAST][-1],
66 | }
67 |
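A shape-level illustration of the transform above, assuming the RL Unplugged Atari step spec documented later in this file (observations of shape (84, 84, 1)) and `stack_size=4`:

    # Input: a batch of 4 consecutive steps.
    #   batch["observation"]: (4, 84, 84, 1) uint8 frames
    #   batch["action"], batch["reward"], batch["discount"]: (4,)
    # Output: one stacked transition.
    #   "observation": (84, 84, 4) -- frames stacked along the channel axis
    #   "action", "reward", "discount": scalars taken from the last step
    #   "is_first": flag of the first step; "is_last": flag of the last step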
68 |
69 | def _get_trajectory_dataset_fn(
70 | stack_size: int,
71 | trajectory_length: int = 1,
72 | ) -> Callable[[tf.data.Dataset], tf.data.Dataset]:
73 | batch_fn = _BatchToTransition().create_transitions
74 |
75 | def make_trajectory_dataset(episode: tf.data.Dataset) -> tf.data.Dataset:
76 | """Converts an episode of steps to a dataset of custom transitions.
77 |
78 | Episode spec: {
79 | 'checkpoint_id': ,
80 | 'episode_id': ,
81 | 'episode_return': ,
82 | 'steps': <_VariantDataset element_spec={
83 | 'action': TensorSpec(shape=(), dtype=tf.int64, name=None),
84 | 'discount': TensorSpec(shape=(), dtype=tf.float32, name=None),
85 | 'is_first': TensorSpec(shape=(), dtype=tf.bool, name=None),
86 | 'is_last': TensorSpec(shape=(), dtype=tf.bool, name=None),
87 | 'is_terminal': TensorSpec(shape=(), dtype=tf.bool, name=None),
88 | 'observation': TensorSpec(shape=(84, 84, 1), dtype=tf.uint8,
89 | name=None),
90 | 'reward': TensorSpec(shape=(), dtype=tf.float32, name=None)
91 | }
92 | >}
93 | """
 94 |         # Batch steps into overlapping windows of `stack_size` frames (shift 1).
95 | timesteps: tf.data.Dataset = episode[rlds.STEPS]
96 | batched_steps = rlds.transformations.batch(
97 | timesteps,
98 | size=stack_size,
99 | shift=1,
100 | drop_remainder=True,
101 | )
102 | transitions = batched_steps.map(batch_fn)
103 | # Batch trajectory.
104 | if trajectory_length > 1:
105 | transitions = transitions.repeat(2)
106 | transitions = transitions.skip(
107 | tf.random.uniform([], 0, trajectory_length, dtype=tf.int64)
108 | )
109 | trajectory = transitions.batch(trajectory_length, drop_remainder=True)
110 | else:
111 | trajectory = transitions
112 | return trajectory
113 |
114 | return make_trajectory_dataset
115 |
116 |
117 | def _uniformly_subsampled_atari_data(
118 | dataset_name: str,
119 | data_percent: int,
120 | data_dir: str,
121 | ) -> tf.data.Dataset:
122 | ds_builder = tfds.builder(dataset_name)
123 | data_splits = []
124 | total_num_episode = 0
125 | for split, info in ds_builder.info.splits.items():
126 | # Convert `data_percent` to number of episodes to allow
127 | # for fractional percentages.
128 | num_episodes = int((data_percent / 100) * info.num_examples)
129 | total_num_episode += num_episodes
130 | if num_episodes == 0:
131 | raise ValueError(f"{data_percent}% leads to 0 episodes in {split}!")
132 | # Sample first `data_percent` episodes from each of the data split.
133 | data_splits.append(f"{split}[:{num_episodes}]")
134 | # Interleave episodes across different splits/checkpoints.
135 | # Set `shuffle_files=True` to shuffle episodes across files within splits.
136 | read_config = tfds.ReadConfig(
137 | interleave_cycle_length=len(data_splits),
138 | shuffle_reshuffle_each_iteration=True,
139 | enable_ordering_guard=False,
140 | )
141 | logging.info(f"Total number of episode = {total_num_episode}")
142 | return tfds.load(
143 | dataset_name,
144 | data_dir=data_dir,
145 | split="+".join(data_splits),
146 | read_config=read_config,
147 | shuffle_files=True,
148 | )
149 |
150 |
151 | def create_atari_ds_loader(
152 | env_name: str,
153 | run_number: int,
154 | dataset_dir: str,
155 | stack_size: int = 4,
156 | data_percentage: int = 10,
157 | trajectory_fn: Optional[Callable] = None,
158 | shuffle_num_episodes: int = 1000,
159 | shuffle_num_steps: int = 50000,
160 | trajectory_length: int = 10,
161 | **_: Any,
162 | ) -> tf.data.Dataset:
163 | """Create Atari dataset loader.
164 |
165 | Args:
166 | env_name (str): Environment name.
167 | run_number (int): Run number.
168 | dataset_dir (str): Directory to the dataset.
169 | stack_size (int, optional): Stack size. Defaults to 4.
170 | data_percentage (int, optional): Fraction of data to be used. Defaults to 10.
171 | trajectory_fn (Optional[Callable], optional): Function to form trajectory.
172 | Defaults to None.
173 | shuffle_num_episodes (int, optional): Number of episodes to shuffle.
174 | Defaults to 1000.
175 | shuffle_num_steps (int, optional): Number of steps to shuffle.
176 | Defaults to 50000.
177 | trajectory_length (int, optional): Trajectory length. Defaults to 10.
178 | **_: Other keyword arguments.
179 |
180 | Returns:
181 | tf.data.Dataset: Dataset.
182 | """
183 | if trajectory_fn is None:
184 | trajectory_fn = _get_trajectory_dataset_fn(stack_size, trajectory_length)
185 | dataset_name = f"rlu_atari_checkpoints_ordered/{env_name}_run_{run_number}"
186 | # Create a dataset of episodes sampling `data_percent`% episodes
187 | # from each of the data split.
188 | dataset = _uniformly_subsampled_atari_data(
189 | dataset_name, data_percentage, dataset_dir
190 | )
191 | # Shuffle the episodes to avoid consecutive episodes.
192 | dataset = dataset.shuffle(shuffle_num_episodes)
193 | # Interleave trajectories across up to 100 episodes to avoid sequential correlation.
194 | dataset = dataset.interleave(
195 | trajectory_fn,
196 | cycle_length=100,
197 | block_length=1,
198 | deterministic=False,
199 | num_parallel_calls=tf.data.AUTOTUNE,
200 | )
201 | # Shuffle trajectories in the dataset.
202 | dataset = dataset.shuffle(
203 | shuffle_num_steps // trajectory_length,
204 | reshuffle_each_iteration=True,
205 | )
206 | return dataset
207 |
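A hedged usage sketch of `create_atari_ds_loader`; the game name, run number, and directory are placeholders (the TFDS builder name is assembled as `rlu_atari_checkpoints_ordered/<env_name>_run_<run_number>`):

    ds = create_atari_ds_loader(
        env_name="MsPacman",                  # placeholder game name
        run_number=1,                         # placeholder run number
        dataset_dir="/path/to/rl_unplugged",  # placeholder TFDS data_dir
        stack_size=4,
        data_percentage=10,
        trajectory_length=10,
    )
    sample = next(iter(ds))
    # sample["observation"]: (trajectory_length, 84, 84, stack_size) stacked frames
    # sample["action"], sample["reward"], ...: (trajectory_length,) per-step values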
208 |
209 | class _AtariDopamineWrapper(dm_env.Environment):
210 | """Wrapper for Atari Dopamine environmnet."""
211 |
212 | def __init__(self, env: gym.Env, max_episode_steps: int = 108000):
213 | self._env = env
214 | self._max_episode_steps = max_episode_steps
215 | self._episode_steps = 0
216 | self._reset_next_episode = True
217 | self._reset_next_step = True
218 |
219 | def reset(self) -> dm_env.TimeStep:
220 | self._episode_steps = 0
221 | self._reset_next_step = False
222 | observation = self._env.reset()
223 | return dm_env.restart(observation.squeeze(-1)) # type: ignore
224 |
225 | def step(self, action: Union[int, Array]) -> dm_env.TimeStep:
226 | if self._reset_next_step:
227 | return self.reset()
228 | if not isinstance(action, int):
229 | action = action.item()
230 | observation, reward, terminal, _ = self._env.step(action) # type: ignore
231 | observation = observation.squeeze(-1)
232 | discount = 1 - float(terminal)
233 | self._episode_steps += 1
234 | if terminal:
235 | self._reset_next_episode = True
236 | return dm_env.termination(reward, observation)
237 | if self._episode_steps == self._max_episode_steps:
238 | self._reset_next_episode = True
239 | return dm_env.truncation(reward, observation, discount)
240 | return dm_env.transition(reward, observation, discount)
241 |
242 | def observation_spec(self) -> specs.Array:
243 | space = self._env.observation_space
244 | return specs.Array(space.shape[:-1], space.dtype) # type: ignore
245 |
246 | def action_spec(self) -> specs.DiscreteArray:
247 | return specs.DiscreteArray(self._env.action_space.n) # type: ignore
248 |
249 | def render(self, mode: str = "rgb_array") -> Any:
250 | """Render the environment.
251 |
252 | Args:
253 | mode (str, optional): Mode of rendering. Defaults to "rgb_array".
254 |
255 | Returns:
256 | Any: Rendered result.
257 | """
258 | return self._env.render(mode)
259 |
260 |
261 | def environment(game: str, stack_size: int) -> dm_env.Environment:
262 | """Atari environment."""
263 | env = atari_lib.create_atari_environment(game_name=game, sticky_actions=True)
264 | env = _AtariDopamineWrapper(env, max_episode_steps=20_000)
265 | env = wrappers.FrameStackingWrapper(env, num_frames=stack_size)
266 | return wrappers.SinglePrecisionWrapper(env)
267 |
268 |
269 | def env_loader(
270 | env_name: str,
271 | run_number: int,
272 | dataset_dir: str,
273 | stack_size: int = 4,
274 | data_percentage: int = 10,
275 | trajectory_fn: Optional[Callable] = None,
276 | shuffle_num_episodes: int = 1000,
277 | shuffle_num_steps: int = 50000,
278 | trajectory_length: int = 10,
279 | **_: Any,
280 | ) -> Tuple[dm_env.Environment, tf.data.Dataset]:
281 | """Get the environment and dataset.
282 |
283 | Args:
284 | env_name (str): Name of the environment.
285 | run_number (int): Run number of the dataset.
286 | dataset_dir (str): Directory storing the dataset.
287 | stack_size (int, optional): Number of frame stacking. Defaults to 4.
288 | data_percentage (int, optional): Percentage of data to be used. Defaults to 10.
289 | trajectory_fn (Optional[Callable], optional): Function to form trajectory.
290 | Defaults to None.
291 | shuffle_num_episodes (int, optional): Episode shuffle-buffer size.
292 | Defaults to 1000.
293 | shuffle_num_steps (int, optional): Trajectory shuffle-buffer size,
294 | expressed in environment steps. Defaults to 50000.
295 | trajectory_length (int, optional): Trajectory length. Defaults to 10.
296 | **_: Other keyword arguments.
297 |
298 | Returns:
299 | Tuple[dm_env.Environment, tf.data.Dataset]: Environment and dataset.
300 | """
301 | return environment(game=env_name, stack_size=stack_size), create_atari_ds_loader(
302 | env_name=env_name,
303 | run_number=run_number,
304 | dataset_dir=dataset_dir,
305 | stack_size=stack_size,
306 | data_percentage=data_percentage,
307 | trajectory_fn=trajectory_fn,
308 | shuffle_num_episodes=shuffle_num_episodes,
309 | shuffle_num_steps=shuffle_num_steps,
310 | trajectory_length=trajectory_length,
311 | )
312 |
--------------------------------------------------------------------------------
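
A minimal usage sketch of the loader above, assuming `env_loader` is exported from rosmo/data/rlu_atari.py. The game name, run number and dataset directory are placeholders; the experiment configs under experiment/atari/ would normally supply them.

# Hypothetical usage of the RL Unplugged Atari loader; all values are placeholders.
import tensorflow as tf

from rosmo.data.rlu_atari import env_loader  # assumed module path

env, dataset = env_loader(
    env_name="Pong",                      # placeholder game
    run_number=1,                         # placeholder RL Unplugged run index
    dataset_dir="~/tensorflow_datasets",  # placeholder dataset directory
    stack_size=4,
    data_percentage=10,
    trajectory_length=10,
)

timestep = env.reset()  # dm_env.TimeStep with frame-stacked observations
batched = dataset.batch(8).prefetch(tf.data.AUTOTUNE)
for trajectory_batch in batched.take(1):
    pass  # each element is a batch of 8 length-10 trajectories
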
/rosmo/agent/network.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Haiku neural network modules."""
16 | import dataclasses
17 | from typing import Any, Dict, List, Optional, Tuple
18 |
19 | import haiku as hk
20 | import jax
21 | import jax.numpy as jnp
22 | from acme import specs
23 | from acme.jax import networks as networks_lib
24 | from acme.jax import utils
25 |
26 | from rosmo.type import Forwardable
27 |
28 |
29 | def get_prediction_head_layers(
30 | reduced_channels_head: int,
31 | mlp_layers: List[int],
32 | num_predictions: int,
33 | w_init: Optional[hk.initializers.Initializer] = None,
34 | ) -> List[Forwardable]:
35 | """Get prediction head layers.
36 |
37 | Args:
38 | reduced_channels_head (int): Conv reduced channels.
39 | mlp_layers (List[int]): MLP layers' hidden units.
40 | num_predictions (int): Output size.
41 | w_init (Optional[hk.initializers.Initializer], optional): Weight
42 | initialization. Defaults to None.
43 |
44 | Returns:
45 | List[Forwardable]: List of layers.
46 | """
47 | layers: List[Forwardable] = [
48 | hk.Conv2D(
49 | reduced_channels_head,
50 | kernel_shape=1,
51 | stride=1,
52 | padding="SAME",
53 | with_bias=False,
54 | ),
55 | hk.LayerNorm(axis=(-3, -2, -1), create_scale=True, create_offset=True),
56 | jax.nn.relu,
57 | hk.Flatten(-3),
58 | ]
59 | for l in mlp_layers:
60 | layers.extend(
61 | [
62 | hk.Linear(l, with_bias=False),
63 | hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
64 | jax.nn.relu,
65 | ]
66 | )
67 | layers.append(hk.Linear(num_predictions, w_init=w_init))
68 | return layers
69 |
70 |
71 | def get_ln_relu_layers() -> List[Forwardable]:
72 | """Get LN relu layers.
73 |
74 | Returns:
75 | List[Forwardable]: LayerNorm+relu.
76 | """
77 | return [
78 | hk.LayerNorm(axis=(-3, -2, -1), create_scale=True, create_offset=True),
79 | jax.nn.relu,
80 | ]
81 |
82 |
83 | class ResConvBlock(hk.Module):
84 | """A residual convolutional block in pre-activation style."""
85 |
86 | def __init__(
87 | self,
88 | channels: int,
89 | stride: int,
90 | use_projection: bool,
91 | name: str = "res_conv_block",
92 | ):
93 | """Init residual block."""
94 | super().__init__(name=name)
95 | self._use_projection = use_projection
96 | if use_projection:
97 | self._proj_conv = hk.Conv2D(
98 | channels, kernel_shape=3, stride=stride, padding="SAME", with_bias=False
99 | )
100 | self._conv_0 = hk.Conv2D(
101 | channels, kernel_shape=3, stride=stride, padding="SAME", with_bias=False
102 | )
103 | self._ln_0 = hk.LayerNorm(
104 | axis=(-3, -2, -1), create_scale=True, create_offset=True
105 | )
106 | self._conv_1 = hk.Conv2D(
107 | channels, kernel_shape=3, stride=1, padding="SAME", with_bias=False
108 | )
109 | self._ln_1 = hk.LayerNorm(
110 | axis=(-3, -2, -1), create_scale=True, create_offset=True
111 | )
112 |
113 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
114 | """Forward ResBlock."""
115 | # NOTE: Using LayerNorm is fine
116 | # (https://arxiv.org/pdf/2104.06294.pdf Appendix A).
117 | shortcut = out = x
118 | out = self._ln_0(out)
119 | out = jax.nn.relu(out)
120 | if self._use_projection:
121 | shortcut = self._proj_conv(out)
122 | out = hk.Sequential(
123 | [
124 | self._conv_0,
125 | self._ln_1,
126 | jax.nn.relu,
127 | self._conv_1,
128 | ]
129 | )(out)
130 | return shortcut + out
131 |
132 |
133 | class Representation(hk.Module):
134 | """Representation encoding module."""
135 |
136 | def __init__(
137 | self,
138 | channels: int,
139 | num_blocks: int,
140 | name: str = "representation",
141 | ):
142 | """Init representatioin function."""
143 | super().__init__(name=name)
144 | self._channels = channels
145 | self._num_blocks = num_blocks
146 |
147 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
148 | """Forward representation function."""
149 | # 1. Downsampling.
150 | torso: List[Forwardable] = [
151 | lambda x: x / 255.0,
152 | hk.Conv2D(
153 | self._channels // 2,
154 | kernel_shape=3,
155 | stride=2,
156 | padding="SAME",
157 | with_bias=False,
158 | ),
159 | ]
160 | torso.extend(
161 | [
162 | ResConvBlock(self._channels // 2, stride=1, use_projection=False)
163 | for _ in range(1)
164 | ]
165 | )
166 | torso.append(ResConvBlock(self._channels, stride=2, use_projection=True))
167 | torso.extend(
168 | [
169 | ResConvBlock(self._channels, stride=1, use_projection=False)
170 | for _ in range(1)
171 | ]
172 | )
173 | torso.append(
174 | hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding="SAME")
175 | )
176 | torso.extend(
177 | [
178 | ResConvBlock(self._channels, stride=1, use_projection=False)
179 | for _ in range(1)
180 | ]
181 | )
182 | torso.append(
183 | hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding="SAME")
184 | )
185 |
186 | # 2. Encoding.
187 | torso.extend(
188 | [
189 | ResConvBlock(self._channels, stride=1, use_projection=False)
190 | for _ in range(self._num_blocks)
191 | ]
192 | )
193 | return hk.Sequential(torso)(observations)
194 |
195 |
196 | class Transition(hk.Module):
197 | """Dynamics transition module."""
198 |
199 | def __init__(
200 | self,
201 | channels: int,
202 | num_blocks: int,
203 | name: str = "transition",
204 | ):
205 | """Init transition function."""
206 | super().__init__(name=name)
207 | self._channels = channels
208 | self._num_blocks = num_blocks
209 |
210 | def __call__(
211 | self, encoded_action: jnp.ndarray, prev_state: jnp.ndarray
212 | ) -> jnp.ndarray:
213 | """Forward transition function."""
214 | channels = prev_state.shape[-1]
215 | shortcut = prev_state
216 |
217 | prev_state = hk.LayerNorm(
218 | axis=(-3, -2, -1), create_scale=True, create_offset=True
219 | )(prev_state)
220 | prev_state = jax.nn.relu(prev_state)
221 |
222 | x_and_h = jnp.concatenate([prev_state, encoded_action], axis=-1)
223 | out = hk.Conv2D(
224 | self._channels,
225 | kernel_shape=3,
226 | stride=1,
227 | padding="SAME",
228 | with_bias=False,
229 | )(x_and_h)
230 | out += shortcut # Residual link to maintain recurrent info flow.
231 |
232 | res_layers = [
233 | ResConvBlock(channels, stride=1, use_projection=False)
234 | for _ in range(self._num_blocks)
235 | ]
236 | out = hk.Sequential(res_layers)(out)
237 | return out
238 |
239 |
240 | class Prediction(hk.Module):
241 | """Policy, value and reward prediction module."""
242 |
243 | def __init__(
244 | self,
245 | num_blocks: int,
246 | num_actions: int,
247 | num_bins: int,
248 | channel: int,
249 | fc_layers_reward: List[int],
250 | fc_layers_value: List[int],
251 | fc_layers_policy: List[int],
252 | output_init_scale: float,
253 | name: str = "prediction",
254 | ) -> None:
255 | """Init prediction function."""
256 | super().__init__(name=name)
257 | self._num_blocks = num_blocks
258 | self._num_actions = num_actions
259 | self._num_bins = num_bins
260 | self._channel = channel
261 | self._fc_layers_reward = fc_layers_reward
262 | self._fc_layers_value = fc_layers_value
263 | self._fc_layers_policy = fc_layers_policy
264 | self._output_init_scale = output_init_scale
265 |
266 | def __call__(
267 | self, states: jnp.ndarray
268 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
269 | """Forward prediction function."""
270 | output_init = hk.initializers.VarianceScaling(scale=self._output_init_scale)
271 | reward_head, value_head, policy_head = [], [], []
272 |
273 | # Add LN+Relu due to pre-activation.
274 | reward_head.extend(get_ln_relu_layers())
275 | value_head.extend(get_ln_relu_layers())
276 | policy_head.extend(get_ln_relu_layers())
277 |
278 | reward_head.extend(
279 | get_prediction_head_layers(
280 | self._channel,
281 | self._fc_layers_reward,
282 | self._num_bins,
283 | output_init,
284 | )
285 | )
286 | reward_logits = hk.Sequential(reward_head)(states)
287 |
288 | res_layers = [
289 | ResConvBlock(states.shape[-1], stride=1, use_projection=False)
290 | for _ in range(self._num_blocks)
291 | ]
292 | out = hk.Sequential(res_layers)(states)
293 |
294 | value_head.extend(
295 | get_prediction_head_layers(
296 | self._channel,
297 | self._fc_layers_value,
298 | self._num_bins,
299 | output_init,
300 | )
301 | )
302 | value_logits = hk.Sequential(value_head)(out)
303 |
304 | policy_head.extend(
305 | get_prediction_head_layers(
306 | self._channel,
307 | self._fc_layers_policy,
308 | self._num_actions,
309 | output_init,
310 | )
311 | )
312 | policy_logits = hk.Sequential(policy_head)(out)
313 | return policy_logits, reward_logits, value_logits
314 |
315 |
316 | @dataclasses.dataclass
317 | class Networks:
318 | """ROSMO Networks."""
319 |
320 | representation_network: networks_lib.FeedForwardNetwork
321 | transition_network: networks_lib.FeedForwardNetwork
322 | prediction_network: networks_lib.FeedForwardNetwork
323 |
324 | environment_specs: specs.EnvironmentSpec
325 |
326 |
327 | def make_atari_networks(
328 | env_spec: specs.EnvironmentSpec,
329 | channels: int,
330 | num_bins: int,
331 | output_init_scale: float,
332 | blocks_representation: int,
333 | blocks_prediction: int,
334 | blocks_transition: int,
335 | reduced_channels_head: int,
336 | fc_layers_reward: List[int],
337 | fc_layers_value: List[int],
338 | fc_layers_policy: List[int],
339 | ) -> Networks:
340 | """Make Atari networks.
341 |
342 | Args:
343 | env_spec (specs.EnvironmentSpec): Environment spec.
344 | channels (int): Convolution channels.
345 | num_bins (int): Number of bins.
346 | output_init_scale (float): Weight init scale.
347 | blocks_representation (int): Number of blocks for representation.
348 | blocks_prediction (int): Number of blocks for prediction.
349 | blocks_transition (int): Number of blocks for transition.
350 | reduced_channels_head (int): Reduced conv channels for prediction head.
351 | fc_layers_reward (List[int]): Fully connected layers for reward prediction.
352 | fc_layers_value (List[int]): Fully connected layers for value prediction.
353 | fc_layers_policy (List[int]): Fully connected layers for policy prediction.
354 |
355 | Returns:
356 | Networks: Constructed networks.
357 | """
358 | action_space_size = env_spec.actions.num_values
359 |
360 | def _representation_fun(observations: jnp.ndarray) -> jnp.ndarray:
361 | network = Representation(channels, blocks_representation)
362 | state = network(observations)
363 | return state
364 |
365 | representation = hk.without_apply_rng(hk.transform(_representation_fun))
366 | hidden_channels = channels
367 |
368 | def _transition_fun(action: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray:
369 | action = hk.one_hot(action, action_space_size)[None, :]
370 | encoded_action = jnp.broadcast_to(action, state.shape[:-1] + action.shape[-1:])
371 |
372 | network = Transition(hidden_channels, blocks_transition)
373 | next_state = network(encoded_action, state)
374 | return next_state
375 |
376 | transition = hk.without_apply_rng(hk.transform(_transition_fun))
377 | prediction = hk.without_apply_rng(
378 | hk.transform(
379 | lambda states: Prediction(
380 | blocks_prediction,
381 | action_space_size,
382 | num_bins,
383 | reduced_channels_head,
384 | fc_layers_reward,
385 | fc_layers_value,
386 | fc_layers_policy,
387 | output_init_scale,
388 | )(states)
389 | )
390 | )
391 |
392 | dummy_action = jnp.array([env_spec.actions.generate_value()])
393 | dummy_obs = utils.zeros_like(env_spec.observations)
394 |
395 | def _dummy_state(key: networks_lib.PRNGKey) -> jnp.ndarray:
396 | encoder_params = representation.init(key, dummy_obs)
397 | dummy_state = representation.apply(encoder_params, dummy_obs)
398 | return dummy_state
399 |
400 | return Networks(
401 | representation_network=networks_lib.FeedForwardNetwork(
402 | lambda key: representation.init(key, dummy_obs), representation.apply
403 | ),
404 | transition_network=networks_lib.FeedForwardNetwork(
405 | lambda key: transition.init(key, dummy_action, _dummy_state(key)),
406 | transition.apply,
407 | ),
408 | prediction_network=networks_lib.FeedForwardNetwork(
409 | lambda key: prediction.init(key, _dummy_state(key)),
410 | prediction.apply,
411 | ),
412 | environment_specs=env_spec,
413 | )
414 |
415 |
416 | def get_bsuite_networks(
417 | env_spec: specs.EnvironmentSpec, config: Dict[str, Any]
418 | ) -> Networks:
419 | """Make BSuite networks.
420 |
421 | Args:
422 | env_spec (specs.EnvironmentSpec): Environment specifications.
423 | config (Dict[str, Any]): Configurations.
424 |
425 | Returns:
426 | Networks: Constructed networks.
427 | """
428 | action_space_size = env_spec.actions.num_values
429 |
430 | def _representation_fun(observations: jnp.ndarray) -> jnp.ndarray:
431 | network = hk.Sequential([hk.Flatten(), hk.nets.MLP(config["encoder_layers"])])
432 | state = network(observations)
433 | return state
434 |
435 | representation = hk.without_apply_rng(hk.transform(_representation_fun))
436 |
437 | def _transition_fun(action: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray:
438 | action = hk.one_hot(action, action_space_size)
439 | network = hk.nets.MLP(config["dynamics_layers"])
440 | sa = jnp.concatenate(
441 | [jnp.reshape(state, (-1, state.shape[-1])), action], axis=-1
442 | )
443 | next_state = network(sa).squeeze()
444 | return next_state
445 |
446 | transition = hk.without_apply_rng(hk.transform(_transition_fun))
447 |
448 | def _prediction_fun(state: jnp.ndarray) -> Tuple[jnp.ndarray, ...]:
449 | network = hk.nets.MLP(config["prediction_layers"], activate_final=True)
450 | head_state = network(state)
451 | output_init = hk.initializers.VarianceScaling(scale=0.0)
452 | head_policy = hk.nets.MLP([action_space_size], w_init=output_init)
453 | head_value = hk.nets.MLP([config["num_bins"]], w_init=output_init)
454 | head_reward = hk.nets.MLP([config["num_bins"]], w_init=output_init)
455 |
456 | return (
457 | head_policy(head_state),
458 | head_reward(head_state),
459 | head_value(head_state),
460 | )
461 |
462 | prediction = hk.without_apply_rng(hk.transform(_prediction_fun))
463 |
464 | dummy_action = utils.add_batch_dim(jnp.array(env_spec.actions.generate_value()))
465 |
466 | dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
467 |
468 | def _dummy_state(key: networks_lib.PRNGKey) -> jnp.ndarray:
469 | encoder_params = representation.init(key, dummy_obs)
470 | dummy_state = representation.apply(encoder_params, dummy_obs)
471 | return dummy_state
472 |
473 | return Networks(
474 | representation_network=networks_lib.FeedForwardNetwork(
475 | lambda key: representation.init(key, dummy_obs), representation.apply
476 | ),
477 | transition_network=networks_lib.FeedForwardNetwork(
478 | lambda key: transition.init(key, dummy_action, _dummy_state(key)),
479 | transition.apply,
480 | ),
481 | prediction_network=networks_lib.FeedForwardNetwork(
482 | lambda key: prediction.init(key, _dummy_state(key)),
483 | prediction.apply,
484 | ),
485 | environment_specs=env_spec,
486 | )
487 |
--------------------------------------------------------------------------------
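
A minimal sketch of constructing the Atari networks above against a synthetic environment spec. The observation/action shapes and every hyper-parameter value are illustrative placeholders, not the settings used in the experiments.

# Hypothetical construction of the ROSMO Atari networks; shapes and
# hyper-parameters are placeholders.
import jax
import jax.numpy as jnp
import numpy as np
from acme import specs

from rosmo.agent.network import make_atari_networks  # assumed module path

# Placeholder spec standing in for a frame-stacked Atari environment.
env_spec = specs.EnvironmentSpec(
    observations=specs.Array((84, 84, 4), np.uint8),
    actions=specs.DiscreteArray(6),
    rewards=specs.Array((), np.float32),
    discounts=specs.BoundedArray((), np.float32, 0.0, 1.0),
)

networks = make_atari_networks(
    env_spec=env_spec,
    channels=64,
    num_bins=601,
    output_init_scale=0.0,
    blocks_representation=6,
    blocks_prediction=2,
    blocks_transition=2,
    reduced_channels_head=128,
    fc_layers_reward=[128],
    fc_layers_value=[128],
    fc_layers_policy=[128],
)

key = jax.random.PRNGKey(0)
repr_params = networks.representation_network.init(key)
state = networks.representation_network.apply(
    repr_params, jnp.zeros((84, 84, 4), jnp.float32)
)
pred_params = networks.prediction_network.init(key)
policy_logits, reward_logits, value_logits = networks.prediction_network.apply(
    pred_params, state
)
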
/rosmo/agent/learning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Garena Online Private Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Agent learner."""
16 | import time
17 | from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple
18 |
19 | import acme
20 | import distrax
21 | import haiku as hk
22 | import jax
23 | import jax.numpy as jnp
24 | import mctx
25 | import optax
26 | import rlax
27 | import tree
28 | from absl import logging
29 | from acme.jax import networks as networks_lib
30 | from acme.jax import utils
31 | from acme.utils import loggers
32 |
33 | from rosmo.agent.improvement_op import mcts_improve, one_step_improve
34 | from rosmo.agent.network import Networks
35 | from rosmo.agent.type import AgentOutput, Params
36 | from rosmo.agent.utils import (
37 | inv_value_transform,
38 | logits_to_scalar,
39 | scalar_to_two_hot,
40 | scale_gradient,
41 | value_transform,
42 | )
43 | from rosmo.type import ActorOutput, Array
44 |
45 |
46 | class TrainingState(NamedTuple):
47 | """Training state."""
48 |
49 | optimizer_state: optax.OptState
50 | params: Params
51 | target_params: Params
52 |
53 | step: int
54 |
55 |
56 | class RosmoLearner(acme.core.Learner):
57 | """ROSMO learner."""
58 |
59 | def __init__(
60 | self,
61 | networks: Networks,
62 | demonstrations: Iterator[ActorOutput],
63 | config: Dict,
64 | logger: Optional[loggers.Logger] = None,
65 | ) -> None:
66 | """Init ROSMO learner.
67 |
68 | Args:
69 | networks (Networks): ROSMO networks.
70 | demonstrations (Iterator[ActorOutput]): Data loader.
71 | config (Dict): Configurations.
72 | logger (Optional[loggers.Logger], optional): Logger. Defaults to None.
73 | """
74 | discount_factor = config["discount_factor"]
75 | weight_decay = config["weight_decay"]
76 | value_coef = config["value_coef"]
77 | behavior_coef = config["behavior_coef"]
78 | policy_coef = config["policy_coef"]
79 | unroll_steps = config["unroll_steps"]
80 | td_steps = config["td_steps"]
81 | target_update_interval = config["target_update_interval"]
82 | log_interval = config["log_interval"]
83 | batch_size = config["batch_size"]
84 | max_grad_norm = config["max_grad_norm"]
85 | num_bins = config["num_bins"]
86 | use_mcts = config["mcts"]
87 | sampling = config.get("sampling", False)
88 | num_simulations = config.get("num_simulations", -1)
89 | search_depth = config.get("search_depth", num_simulations)
90 |
91 | _batch_categorical_cross_entropy = jax.vmap(rlax.categorical_cross_entropy)
92 |
93 | def loss(
94 | params: Params,
95 | target_params: Params,
96 | trajectory: ActorOutput,
97 | rng_key: networks_lib.PRNGKey,
98 | ) -> Tuple[Array, Dict[str, Array]]:
99 | # Encode obs via learning and target networks, [T, S]
100 | state = networks.representation_network.apply(
101 | params.representation, trajectory.observation
102 | )
103 | target_state = networks.representation_network.apply(
104 | target_params.representation, trajectory.observation
105 | )
106 |
107 | # 1) Model unroll, sampling and estimation.
108 | root_state = jax.tree_map(lambda t: t[:1], state)
109 | learner_root = root_unroll(networks, params, num_bins, root_state)
110 | learner_root: AgentOutput = jax.tree_map(lambda t: t[0], learner_root)
111 |
112 | unroll_trajectory: ActorOutput = jax.tree_map(
113 | lambda t: t[: unroll_steps + 1], trajectory
114 | )
115 | random_action_mask = (
116 | jnp.cumprod(1.0 - unroll_trajectory.is_first[1:]) == 0.0
117 | )
118 | action_sequence = unroll_trajectory.action[:unroll_steps]
119 | num_actions = learner_root.policy_logits.shape[-1]
120 | rng_key, action_key = jax.random.split(rng_key)
121 | random_actions = jax.random.choice(
122 | action_key, num_actions, action_sequence.shape, replace=True
123 | )
124 | simulate_action_sequence = jax.lax.select(
125 | random_action_mask, random_actions, action_sequence
126 | )
127 |
128 | model_out: AgentOutput = model_unroll(
129 | networks,
130 | params,
131 | num_bins,
132 | learner_root.state,
133 | simulate_action_sequence,
134 | )
135 |
136 | # Model predictions.
137 | policy_logits = jnp.concatenate(
138 | [
139 | learner_root.policy_logits[None],
140 | model_out.policy_logits,
141 | ],
142 | axis=0,
143 | )
144 |
145 | value_logits = jnp.concatenate(
146 | [
147 | learner_root.value_logits[None],
148 | model_out.value_logits,
149 | ],
150 | axis=0,
151 | )
152 |
153 | # 2) Model learning targets.
154 | # a) Reward.
155 | rewards = trajectory.reward
156 | reward_target = jax.lax.select(
157 | random_action_mask,
158 | jnp.zeros_like(rewards[:unroll_steps]),
159 | rewards[:unroll_steps],
160 | )
161 | reward_target_transformed = value_transform(reward_target)
162 | reward_logits_target = scalar_to_two_hot(
163 | reward_target_transformed, num_bins
164 | )
165 |
166 | # b) Policy.
167 | target_roots: AgentOutput = root_unroll(
168 | networks, target_params, num_bins, target_state
169 | )
170 | search_roots: AgentOutput = jax.tree_map(
171 | lambda t: t[: unroll_steps + 1], target_roots
172 | )
173 | rng_key, improve_key = jax.random.split(rng_key)
174 |
175 | if use_mcts:
176 | logging.info(
177 | f"[Learning] Using MuZero with simulation={num_simulations}"
178 | f" & depth={search_depth}."
179 | )
180 | mcts_out = mcts_improve(
181 | networks,
182 | improve_key,
183 | target_params,
184 | target_roots,
185 | num_bins,
186 | discount_factor,
187 | num_simulations,
188 | search_depth,
189 | )
190 | policy_target, improve_adv = (
191 | mcts_out.action_weights[: unroll_steps + 1],
192 | 0.0,
193 | )
194 | else:
195 | logging.info("[Learning] Using ROSMO.")
196 | improve_keys = jax.random.split(
197 | improve_key, search_roots.state.shape[0]
198 | )
199 | policy_target, improve_adv = jax.vmap( # type: ignore
200 | one_step_improve,
201 | (None, 0, None, 0, None, None, None, None),
202 | )(
203 | networks,
204 | improve_keys,
205 | target_params,
206 | search_roots,
207 | num_bins,
208 | discount_factor,
209 | num_simulations,
210 | sampling,
211 | )
212 | uniform_policy = jnp.ones_like(policy_target) / num_actions
213 | random_policy_mask = jnp.cumprod(1.0 - unroll_trajectory.is_last) == 0.0
214 | random_policy_mask = jnp.broadcast_to(
215 | random_policy_mask[:, None], policy_target.shape
216 | )
217 | policy_target = jax.lax.select(
218 | random_policy_mask, uniform_policy, policy_target
219 | )
220 | policy_target = jax.lax.stop_gradient(policy_target)
221 |
222 | # c) Value.
223 | discounts = (1.0 - trajectory.is_last[1:]) * discount_factor
224 | if use_mcts:
225 | node_values = mcts_out.search_tree.node_values
226 | v_bootstrap = node_values[:, mctx.Tree.ROOT_INDEX]
227 | else:
228 | v_bootstrap = target_roots.value
229 |
230 | def n_step_return(i: int) -> jnp.ndarray:
231 | bootstrap_value = jax.tree_map(lambda t: t[i + td_steps], v_bootstrap)
232 | _rewards = jnp.concatenate(
233 | [rewards[i : i + td_steps], bootstrap_value[None]], axis=0
234 | )
235 | _discounts = jnp.concatenate(
236 | [jnp.ones((1,)), jnp.cumprod(discounts[i : i + td_steps])],
237 | axis=0,
238 | )
239 | return jnp.sum(_rewards * _discounts)
240 |
241 | returns = []
242 | for i in range(unroll_steps + 1):
243 | returns.append(n_step_return(i))
244 | returns = jnp.stack(returns)
245 | # Value targets for the absorbing state and the states after are 0.
246 | zero_return_mask = jnp.cumprod(1.0 - unroll_trajectory.is_last) == 0.0
247 | value_target = jax.lax.select(
248 | zero_return_mask, jnp.zeros_like(returns), returns
249 | )
250 | value_target_transformed = value_transform(value_target)
251 | value_logits_target = scalar_to_two_hot(value_target_transformed, num_bins)
252 | value_logits_target = jax.lax.stop_gradient(value_logits_target)
253 |
254 | # 3) Behavior regularization.
255 | behavior_loss = jnp.array(0.0)
256 | if not use_mcts:
257 | in_sample_action = trajectory.action[: unroll_steps + 1]
258 | log_prob = jax.nn.log_softmax(policy_logits)
259 | action_log_prob = log_prob[
260 | jnp.arange(unroll_steps + 1), in_sample_action
261 | ]
262 |
263 | _target_value = target_roots.value[: unroll_steps + 1]
264 | _target_reward = target_roots.reward[1 : unroll_steps + 1 + 1]
265 | _target_value_prime = target_roots.value[1 : unroll_steps + 1 + 1]
266 | _target_adv = (
267 | _target_reward
268 | + discount_factor * _target_value_prime
269 | - _target_value
270 | )
271 | _target_adv = jax.lax.stop_gradient(_target_adv)
272 | behavior_loss = -action_log_prob * jnp.heaviside(_target_adv, 0.0)
273 | # Deal with cross-episode trajectories.
274 | invalid_action_mask = jnp.cumprod(1.0 - trajectory.is_first[1:]) == 0.0
275 | behavior_loss = jax.lax.select(
276 | invalid_action_mask[: unroll_steps + 1],
277 | jnp.zeros_like(behavior_loss),
278 | behavior_loss,
279 | )
280 | behavior_loss = jnp.mean(behavior_loss) * behavior_coef
281 |
282 | # 4) Compute the losses.
283 | reward_loss = jnp.mean(
284 | _batch_categorical_cross_entropy(
285 | reward_logits_target, model_out.reward_logits
286 | )
287 | )
288 |
289 | value_loss = (
290 | jnp.mean(
291 | _batch_categorical_cross_entropy(value_logits_target, value_logits)
292 | )
293 | * value_coef
294 | )
295 |
296 | policy_loss = (
297 | jnp.mean(_batch_categorical_cross_entropy(policy_target, policy_logits))
298 | * policy_coef
299 | )
300 |
301 | total_loss = reward_loss + value_loss + policy_loss + behavior_loss
302 |
303 | if sampling:
304 | # Unnormalized.
305 | def entropy_fn(p: Array) -> Array:
306 | return distrax.Categorical(logits=p).entropy()
307 |
308 | else:
309 |
310 | def entropy_fn(p: Array) -> Array:
311 | return distrax.Categorical(probs=p).entropy()
312 |
313 | policy_target_entropy = jax.vmap(entropy_fn)(policy_target)
314 | policy_entropy = jax.vmap(
315 | lambda l: distrax.Categorical(logits=l).entropy()
316 | )(policy_logits)
317 |
318 | log = {
319 | "reward_target": reward_target,
320 | "reward_prediction": model_out.reward,
321 | "value_target": value_target,
322 | "value_prediction": model_out.value,
323 | "policy_entropy": policy_entropy,
324 | "policy_target_entropy": policy_target_entropy,
325 | "reward_loss": reward_loss,
326 | "value_loss": value_loss,
327 | "policy_loss": policy_loss,
328 | "behavior_loss": behavior_loss,
329 | "improve_advantage": improve_adv,
330 | "total_loss": total_loss,
331 | }
332 | return total_loss, log
333 |
334 | def batch_loss(
335 | params: Params,
336 | target_params: Params,
337 | trajectory: ActorOutput,
338 | rng_key: networks_lib.PRNGKey,
339 | ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
340 | bs = len(trajectory.reward)
341 | rng_keys = jax.random.split(rng_key, bs)
342 | losses, log = jax.vmap(loss, (None, None, 0, 0))(
343 | params,
344 | target_params,
345 | trajectory,
346 | rng_keys,
347 | )
348 | log_mean = {f"{k}_mean": jnp.mean(v) for k, v in log.items()}
349 | std_keys = [
350 | "reward_target",
351 | "reward_prediction",
352 | "q_val_target",
353 | "q_val_prediction",
354 | "value_target",
355 | "value_prediction",
356 | "improve_advantage",
357 | ]
358 | std_keys = [k for k in std_keys if k in log]
359 | log_std = {f"{k}_std": jnp.std(log[k]) for k in std_keys}
360 | log_mean.update(log_std)
361 | return jnp.mean(losses), log_mean
362 |
363 | def update_step(
364 | state: TrainingState,
365 | trajectory: ActorOutput,
366 | rng_key: networks_lib.PRNGKey,
367 | ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
368 | params = state.params
369 | optimizer_state = state.optimizer_state
370 |
371 | grads, log = jax.grad(batch_loss, has_aux=True)(
372 | state.params, state.target_params, trajectory, rng_key
373 | )
374 | grads = jax.lax.pmean(grads, axis_name="i")
375 | network_updates, optimizer_state = optimizer.update(
376 | grads, optimizer_state, params
377 | )
378 | params = optax.apply_updates(params, network_updates)
379 | log.update(
380 | {
381 | "grad_norm": optax.global_norm(grads),
382 | "update_norm": optax.global_norm(network_updates),
383 | "param_norm": optax.global_norm(params),
384 | }
385 | )
386 | new_state = TrainingState(
387 | optimizer_state=optimizer_state,
388 | params=params,
389 | target_params=state.target_params,
390 | step=state.step + 1,
391 | )
392 | return new_state, log
393 |
394 | # Logger.
395 | self._logger = logger or loggers.make_default_logger(
396 | "learner", asynchronous=True, serialize_fn=utils.fetch_devicearray
397 | )
398 |
399 | # Iterator on demonstration transitions.
400 | self._demonstrations = demonstrations
401 |
402 | # JIT compiler.
403 | self._batch_size = batch_size
404 | self._num_devices = jax.device_count()
405 | assert self._batch_size % self._num_devices == 0
406 | self._update_step = jax.pmap(update_step, axis_name="i")
407 |
408 | # Create initial state.
409 | random_key = jax.random.PRNGKey(config["seed"])
410 | self._rng_key: networks_lib.PRNGKey
411 | key_r, key_d, key_p, self._rng_key = jax.random.split(random_key, 4)
412 | representation_params = networks.representation_network.init(key_r)
413 | transition_params = networks.transition_network.init(key_d)
414 | prediction_params = networks.prediction_network.init(key_p)
415 |
416 | # Create and initialize optimizer.
417 | params = Params(
418 | representation_params,
419 | transition_params,
420 | prediction_params,
421 | )
422 | weight_decay_mask = Params(
423 | representation=hk.data_structures.map(
424 | lambda module_name, name, value: name == "w",
425 | params.representation,
426 | ),
427 | transition=hk.data_structures.map(
428 | lambda module_name, name, value: name == "w",
429 | params.transition,
430 | ),
431 | prediction=hk.data_structures.map(
432 | lambda module_name, name, value: name == "w",
433 | params.prediction,
434 | ),
435 | )
436 | learning_rate = optax.warmup_exponential_decay_schedule(
437 | init_value=0.0,
438 | peak_value=config["learning_rate"],
439 | warmup_steps=config["warmup_steps"],
440 | transition_steps=100_000,
441 | decay_rate=config["learning_rate_decay"],
442 | staircase=True,
443 | )
444 | optimizer = optax.adamw(
445 | learning_rate=learning_rate,
446 | weight_decay=weight_decay,
447 | mask=weight_decay_mask,
448 | )
449 | if max_grad_norm:
450 | optimizer = optax.chain(optax.clip_by_global_norm(max_grad_norm), optimizer)
451 | optimizer_state = optimizer.init(params)
452 | target_params = params
453 |
454 | # Learner state.
455 | self._state = TrainingState(
456 | optimizer_state=optimizer_state,
457 | params=params,
458 | target_params=target_params,
459 | step=0,
460 | )
461 | self._target_update_interval = target_update_interval
462 |
463 | self._state = jax.device_put_replicated(self._state, jax.local_devices())
464 |
465 | # Do not record timestamps until after the first learning step is done.
466 | # This avoids counting the time it takes for the offline data iterator
467 | # to fill its shuffle buffers and prefetch the first batch.
468 | self._timestamp: Optional[float] = None
469 | self._elapsed = 0.0
470 | self._log_interval = log_interval
471 | self._unroll_steps = unroll_steps
472 |
473 | def step(self) -> None:
474 | """Train step."""
475 | update_key, self._rng_key = jax.random.split(self._rng_key)
476 | update_keys = jax.random.split(update_key, self._num_devices)
477 | trajectory: ActorOutput = next(self._demonstrations)
478 | trajectory = tree.map_structure(
479 | lambda x: x.reshape(
480 | self._num_devices, self._batch_size // self._num_devices, *x.shape[1:]
481 | ),
482 | trajectory,
483 | )
484 |
485 | self._state, metrics = self._update_step(self._state, trajectory, update_keys)
486 |
487 | _step = self._state.step[0] # type: ignore
488 | timestamp = time.time()
489 | elapsed: float = 0
490 | if self._timestamp:
491 | elapsed = timestamp - self._timestamp
492 | self._timestamp = timestamp
493 | self._elapsed += elapsed
494 |
495 | if _step % self._target_update_interval == 0:
496 | state: TrainingState = self._state
497 | self._state = TrainingState(
498 | optimizer_state=state.optimizer_state,
499 | params=state.params,
500 | target_params=state.params,
501 | step=state.step,
502 | )
503 | if _step % self._log_interval == 0:
504 | metrics = jax.tree_util.tree_map(lambda t: t[0], metrics)
505 | metrics = jax.device_get(metrics)
506 | self._logger.write(
507 | {
508 | **metrics,
509 | **{
510 | "step": _step,
511 | "elapsed_time": self._elapsed,
512 | },
513 | }
514 | )
515 |
516 | def get_variables(self, names: List[str]) -> List[Any]:
517 | """Get network parameters."""
518 | state = self.save()
519 | variables = {
520 | "representation": state.params.representation,
521 | "dynamics": state.params.transition,
522 | "prediction": state.params.prediction,
523 | }
524 | return [variables[name] for name in names]
525 |
526 | def save(self) -> TrainingState:
527 | """Save the training state.
528 |
529 | Returns:
530 | TrainingState: State to be saved.
531 | """
532 | _state = utils.fetch_devicearray(jax.tree_map(lambda t: t[0], self._state))
533 | return _state
534 |
535 | def restore(self, state: TrainingState) -> None:
536 | """Restore the training state.
537 |
538 | Args:
539 | state (TrainingState): State to be resumed.
540 | """
541 | self._state = jax.device_put_replicated(state, jax.local_devices())
542 |
543 |
544 | def root_unroll(
545 | networks: Networks,
546 | params: Params,
547 | num_bins: int,
548 | state: Array,
549 | ) -> AgentOutput:
550 | """Unroll the learned model from the root node."""
551 | (
552 | policy_logits,
553 | reward_logits,
554 | value_logits,
555 | ) = networks.prediction_network.apply(params.prediction, state)
556 | reward = logits_to_scalar(reward_logits, num_bins)
557 | reward = inv_value_transform(reward)
558 | value = logits_to_scalar(value_logits, num_bins)
559 | value = inv_value_transform(value)
560 | return AgentOutput(
561 | state=state,
562 | policy_logits=policy_logits,
563 | reward_logits=reward_logits,
564 | reward=reward,
565 | value_logits=value_logits,
566 | value=value,
567 | )
568 |
569 |
570 | def model_unroll(
571 | networks: Networks,
572 | params: Params,
573 | num_bins: int,
574 | state: Array,
575 | action_sequence: Array,
576 | ) -> AgentOutput:
577 | """Unroll the learned model with a sequence of actions."""
578 |
579 | def fn(state: Array, action: Array) -> Tuple[Array, Array]:
580 | """Dynamics fun for scan."""
581 | next_state = networks.transition_network.apply(
582 | params.transition, action[None], state
583 | )
584 | next_state = scale_gradient(next_state, 0.5)
585 | return next_state, next_state
586 |
587 | _, state_sequence = jax.lax.scan(fn, state, action_sequence)
588 | (
589 | policy_logits,
590 | reward_logits,
591 | value_logits,
592 | ) = networks.prediction_network.apply(params.prediction, state_sequence)
593 | reward = logits_to_scalar(reward_logits, num_bins)
594 | reward = inv_value_transform(reward)
595 | value = logits_to_scalar(value_logits, num_bins)
596 | value = inv_value_transform(value)
597 | return AgentOutput(
598 | state=state_sequence,
599 | policy_logits=policy_logits,
600 | reward_logits=reward_logits,
601 | reward=reward,
602 | value_logits=value_logits,
603 | value=value,
604 | )
605 |
--------------------------------------------------------------------------------
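
A rough driver sketch for RosmoLearner. Every config value below is a placeholder (the actual experiment settings would come from the configs under experiment/), and `networks` and `demonstrations` are assumed to come from the network and data-loader sketches above; `demonstrations` must be an iterator yielding batched ActorOutput trajectories.

# Hypothetical training loop around RosmoLearner; values are placeholders and
# `networks` / `demonstrations` are assumed to be built elsewhere.
from rosmo.agent.learning import RosmoLearner  # assumed module path

config = {
    "discount_factor": 0.997,
    "weight_decay": 1e-4,
    "value_coef": 0.25,
    "policy_coef": 1.0,
    "behavior_coef": 0.2,
    "unroll_steps": 5,
    "td_steps": 5,
    "target_update_interval": 200,
    "log_interval": 100,
    "batch_size": 16,       # must be divisible by jax.device_count()
    "max_grad_norm": 5.0,
    "num_bins": 601,
    "mcts": False,          # False selects the ROSMO one-step improvement
    "sampling": False,
    "num_simulations": 16,  # forwarded to the improvement operator
    "seed": 0,
    "learning_rate": 7e-4,
    "learning_rate_decay": 0.1,
    "warmup_steps": 1_000,
}

learner = RosmoLearner(networks, demonstrations, config)
for _ in range(1_000):
    learner.step()           # one pmapped gradient update per call
checkpoint = learner.save()  # TrainingState with params, target params and step
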