├── tests ├── __init__.py └── test_data.py ├── rosmo ├── __about__.py ├── __init__.py ├── agent │ ├── __init__.py │ ├── type.py │ ├── utils.py │ ├── actor.py │ ├── improvement_op.py │ ├── network.py │ └── learning.py ├── data │ ├── __init__.py │ ├── buffer.py │ ├── bsuite.py │ ├── rl_unplugged_atari_baselines.json │ └── rlu_atari.py ├── type.py ├── profiler.py ├── env_loop_observer.py └── loggers.py ├── .github ├── actions │ └── cache │ │ └── action.yml └── workflows │ └── check.yml ├── Makefile ├── .gitignore ├── INSTALL.md ├── experiment ├── bsuite │ ├── config.py │ └── main.py └── atari │ ├── config.py │ └── main.py ├── pyproject.toml ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /rosmo/__about__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """About.""" 16 | __version__ = "0.0.2" 17 | -------------------------------------------------------------------------------- /rosmo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ROSMO: A Regularized One-Step Model-based algorithm for Offline-RL.""" 16 | -------------------------------------------------------------------------------- /.github/actions/cache/action.yml: -------------------------------------------------------------------------------- 1 | name: Cache 2 | description: "cache for pip" 3 | outputs: 4 | cache-hit: 5 | value: ${{ steps.cache.outputs.cache-hit }} 6 | description: "cache hit" 7 | 8 | runs: 9 | using: "composite" 10 | steps: 11 | - name: Get pip cache dir 12 | id: pip-cache-dir 13 | shell: bash 14 | run: | 15 | python -m pip install --upgrade pip 16 | echo "::set-output name=dir::$(pip cache dir)" 17 | - name: Cache pip 18 | id: cache-pip 19 | uses: actions/cache@v2 20 | with: 21 | path: ${{ steps.pip-cache-dir.outputs.dir }} 22 | key: ${{ runner.os }}-pip-cache-${{ hashFiles('pyproject.toml') }} 23 | restore-keys: | 24 | ${{ runner.os }}-pip-cache- 25 | -------------------------------------------------------------------------------- /rosmo/agent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ROSMO Agent.""" 16 | import os 17 | 18 | import tensorflow as tf 19 | 20 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.85" 21 | tf.config.experimental.set_visible_devices([], "GPU") 22 | -------------------------------------------------------------------------------- /rosmo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Offline dataset.""" 16 | from rosmo.data.bsuite import env_loader as bsuite_env_loader 17 | from rosmo.data.rlu_atari import env_loader as atari_env_loader 18 | 19 | __all__ = [ 20 | "bsuite_env_loader", 21 | "atari_env_loader", 22 | ] 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL=/bin/bash 2 | PROJECT_NAME=rosmo 3 | PROJECT_PATH=rosmo/ 4 | LINT_PATHS=${PROJECT_PATH} tests/ experiment/ 5 | 6 | check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade 7 | check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade 8 | 9 | test: 10 | $(call check_install, pytest) 11 | pytest -s 12 | 13 | lint: 14 | $(call check_install, isort) 15 | $(call check_install, pylint) 16 | $(call check_install, mypy) 17 | isort --check --diff --project=${LINT_PATHS} 18 | pylint -j 8 --recursive=y ${LINT_PATHS} 19 | mypy ${PROJECT_PATH} 20 | 21 | format: 22 | # format using black 23 | $(call check_install, black) 24 | black ${LINT_PATHS} 25 | # sort imports 26 | $(call check_install, isort) 27 | isort ${LINT_PATHS} 28 | 29 | check-docstyle: 30 | $(call check_install, pydocstyle) 31 | pydocstyle ${PROJECT_PATH} --convention=google 32 | 33 | checks: lint check-docstyle 34 | 35 | .PHONY: format lint check-docstyle checks 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | 17 | # Project files 18 | .ropeproject 19 | .project 20 | .pydevproject 21 | .settings 22 | .idea 23 | .vscode 24 | tags 25 | 26 | # Package files 27 | *.egg 28 | *.eggs/ 29 | .installed.cfg 30 | *.egg-info 31 | 32 | # Unittest and coverage 33 | htmlcov/* 34 | .coverage 35 | .coverage.* 36 | .tox 37 | junit*.xml 38 | coverage.xml 39 | .pytest_cache/ 40 | 41 | # Build and docs folder/files 42 | build/* 43 | dist/* 44 | sdist/* 45 | docs/api/* 46 | docs/_rst/* 47 | docs/_build/* 48 | cover/* 49 | MANIFEST 50 | 51 | # Per-project virtualenvs 52 | .venv*/ 53 | .conda*/ 54 | .python-version 55 | 56 | # Local dev folder/files 57 | local-dev/ 58 | sota/ 59 | wandb/ 60 | .dockerignore 61 | checkpoint/ 62 | logs/ 63 | datasets/ 64 | internal/gke/.mujoco 65 | internal/gke/bsuite 66 | internal/gke/.boto 67 | 68 | sample_data 69 | wandb 70 | -------------------------------------------------------------------------------- /rosmo/type.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Type definitions.""" 16 | from typing import Callable, NamedTuple, Union 17 | 18 | import jax.numpy as jnp 19 | import numpy as np 20 | 21 | Array = Union[np.ndarray, jnp.ndarray] 22 | Forwardable = Callable 23 | 24 | 25 | class ActorOutput(NamedTuple): 26 | """Actor output parsed from the dataset.""" 27 | 28 | observation: Array 29 | reward: Array 30 | is_first: Array 31 | is_last: Array 32 | action: Array 33 | -------------------------------------------------------------------------------- /.github/workflows/check.yml: -------------------------------------------------------------------------------- 1 | name: CI checks 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - '.github/workflows/check.yml' 7 | - 'experiment/**' 8 | - 'rosmo/**' 9 | - 'tests/**' 10 | - 'pyproject.py' 11 | - 'Makefile' 12 | push: 13 | branches: 14 | - main 15 | paths: 16 | - '.github/workflows/check.yml' 17 | - 'experiment/**' 18 | - 'rosmo/**' 19 | - 'tests/**' 20 | - 'pyproject.py' 21 | - 'Makefile' 22 | 23 | concurrency: 24 | group: ${{ github.ref }}-${{ github.workflow }} 25 | cancel-in-progress: true 26 | 27 | jobs: 28 | lint: 29 | runs-on: ubuntu-latest 30 | timeout-minutes: 7 31 | steps: 32 | - uses: actions/checkout@v3 33 | with: 34 | fetch-depth: 0 35 | - uses: actions/setup-python@v4 36 | with: 37 | python-version: 3.8.13 38 | - uses: ./.github/actions/cache 39 | - name: Install 40 | run: | 41 | pip install -e . 42 | pip install dopamine-rl==3.1.2 43 | pip install chex==0.1.5 44 | - name: Lint 45 | run: make checks 46 | -------------------------------------------------------------------------------- /rosmo/agent/type.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Types.""" 16 | from typing import NamedTuple 17 | 18 | from acme.jax import networks as networks_lib 19 | 20 | from rosmo.type import Array 21 | 22 | 23 | class Params(NamedTuple): 24 | """Agent parameters.""" 25 | 26 | representation: networks_lib.Params 27 | transition: networks_lib.Params 28 | prediction: networks_lib.Params 29 | 30 | 31 | class AgentOutput(NamedTuple): 32 | """Agent prediction output.""" 33 | 34 | state: Array 35 | policy_logits: Array 36 | value_logits: Array 37 | value: Array 38 | reward_logits: Array 39 | reward: Array 40 | -------------------------------------------------------------------------------- /rosmo/profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Profiling utils.""" 16 | import os 17 | 18 | import jax 19 | from viztracer import VizTracer 20 | 21 | 22 | class Profiler: 23 | """Profiler for python and jax (optional).""" 24 | 25 | def __init__(self, folder: str, name: str, with_jax: bool = False) -> None: 26 | """Init.""" 27 | super().__init__() 28 | self._name = name 29 | self._folder = folder 30 | self._with_jax = with_jax 31 | self._vistracer = VizTracer( 32 | output_file=os.path.join(folder, "viztracer", name + ".html"), 33 | max_stack_depth=3, 34 | ) 35 | self._jax_folder = os.path.join(folder, "jax_profiler/" + name) 36 | 37 | def start(self) -> None: 38 | """Start to trace.""" 39 | if self._with_jax: 40 | jax.profiler.start_trace(self._jax_folder) 41 | self._vistracer.start() 42 | 43 | def stop(self) -> None: 44 | """Stop tracing.""" 45 | self._vistracer.stop() 46 | if self._with_jax: 47 | jax.profiler.stop_trace() 48 | 49 | def save(self) -> None: 50 | """Save the results.""" 51 | self._vistracer.save() 52 | 53 | def stop_and_save(self) -> None: 54 | """Combine stop and save.""" 55 | self.stop() 56 | self.save() 57 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
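# NOTE: these tests assume the offline datasets have already been downloaded to
# ./datasets (./datasets/atari and ./datasets/bsuite) as described in INSTALL.md;
# `make test` (which runs `pytest -s`, see the Makefile) then exercises both
# loaders below.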
14 | 15 | """Testing data loading.""" 16 | import os 17 | 18 | from absl import logging 19 | from absl.testing import absltest, parameterized 20 | 21 | from rosmo.data import atari_env_loader, bsuite_env_loader 22 | 23 | _DATASET_DIR = "./datasets" 24 | 25 | 26 | class RLUAtari(parameterized.TestCase): 27 | """Test RL Unplugged Atari data loader.""" 28 | 29 | @staticmethod 30 | def test_data_loader(): 31 | """Test data loader.""" 32 | dataset_dir = os.path.join(_DATASET_DIR, "atari") 33 | _, dataloader = atari_env_loader( 34 | env_name="Asterix", 35 | run_number=1, 36 | dataset_dir=dataset_dir, 37 | ) 38 | iterator = iter(dataloader) 39 | data = next(iterator) 40 | logging.info(data) 41 | 42 | 43 | class BSuite(parameterized.TestCase): 44 | """Test BSuite data loader.""" 45 | 46 | @staticmethod 47 | def test_data_loader(): 48 | """Test data loader.""" 49 | dataset_dir = os.path.join(_DATASET_DIR, "bsuite") 50 | _, dataloader = bsuite_env_loader( 51 | env_name="catch", 52 | dataset_dir=dataset_dir, 53 | ) 54 | iterator = iter(dataloader) 55 | _ = next(iterator) 56 | 57 | 58 | if __name__ == "__main__": 59 | absltest.main() 60 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | :wrench: 4 | 5 | **Table of Contents** 6 | 7 | - [Installation](#installation) 8 | - [General](#general) 9 | - [TPU](#tpu) 10 | - [GPU](#gpu) 11 | - [Test](#test) 12 | 13 | ### General 14 | 15 | 1. Prepare an environment with `python=3.8`. 16 | 2. Clone this repository and install it in develop mode: 17 | ```console 18 | pip install -e . 19 | 20 | # Install the following packages separately due to version conflicts. 21 | pip install dopamine-rl==3.1.2 22 | pip install chex==0.1.5 23 | ``` 24 | 3. [Install the ROM for Atari](https://github.com/openai/atari-py#roms). 25 | 4. Download dataset: 26 | 1. **BSuite** datasets ([drive](https://drive.google.com/file/d/1FWexoOphUgBaWTWtY9VR43N90z9A6FvP/view?usp=sharing)) if you are running BSuite experiments; 27 | 2. **Atari** datasets will be automatically downloaded from [TFDS](https://www.tensorflow.org/datasets/catalog/rlu_atari) when starting the experiment. The dataset path is defined in `experiment/*/config.py`. Or you could also download it using the following script: 28 | ``` 29 | from rosmo.data.rlu_atari import create_atari_ds_loader 30 | 31 | create_atari_ds_loader( 32 | env_name="Pong", # Change this. 33 | run_number=1, # Fix this. 34 | dataset_dir="/path/to/download", 35 | ) 36 | ``` 37 | 38 | ### TPU 39 | 40 | All of our Atari experiments reported in the paper were run on TPUv3-8 machines from Google Cloud. If you would like to run your experiments on TPUs as well, the following commands might help: 41 | ```console 42 | sudo apt-get update && sudo apt install -y libopencv-dev 43 | pip uninstall jax jaxlib libtpu-nightly libtpu -y 44 | pip install "jax[tpu]==0.3.6" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -i https://pypi.python.org/simple 45 | ``` 46 | 47 | ### GPU 48 | 49 | We also conducted verification experiments on 4 Tesla-A100 GPUs to ensure our algorithm's reproducibility on different platforms. 
To install the same version of Jax as ours: 50 | ```console 51 | pip uninstall jax jaxlib libtpu-nightly libtpu -y 52 | 53 | # jax-0.3.25 jaxlib-0.3.25+cuda11.cudnn82 54 | pip install --upgrade "jax[cuda]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 55 | ``` 56 | 57 | ### Test 58 | 59 | Finally, for TPU/GPU users, to validate you have installed Jax correctly, run `python -c "import jax; print(jax.devices())"` and expect a list of TPU/GPU devices printed. 60 | -------------------------------------------------------------------------------- /experiment/bsuite/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """BSuite experiment configs.""" 16 | from copy import deepcopy 17 | from typing import Dict 18 | 19 | from absl import flags, logging 20 | from ml_collections import ConfigDict 21 | from wandb.util import generate_id 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | 26 | # ===== Configurations ===== # 27 | def get_config(game_name: str) -> Dict: 28 | """Get experiment configurations.""" 29 | config = deepcopy(CONFIG) 30 | config["seed"] = FLAGS.seed 31 | config["benchmark"] = "bsuite" 32 | config["mcts"] = FLAGS.algo == "mzu" 33 | config["game_name"] = game_name 34 | config["batch_size"] = 16 if FLAGS.debug else config.batch_size 35 | exp_full_name = f"{FLAGS.exp_id}_{game_name}_" + generate_id() 36 | config["exp_full_name"] = exp_full_name 37 | logging.info(f"Configs: {config}") 38 | return config 39 | 40 | 41 | CONFIG = ConfigDict( 42 | { 43 | "data_dir": "./datasets/bsuite", 44 | "run_number": 1, 45 | "data_percentage": 100, 46 | "batch_size": 512, 47 | "unroll_steps": 3, 48 | "td_steps": 3, 49 | "num_bins": 20, 50 | "encoder_layers": [64, 64, 32], 51 | "dynamics_layers": [32, 32], 52 | "prediction_layers": [32], 53 | "output_init_scale": 0.0, 54 | "discount_factor": 0.997**4, 55 | "clipping_threshold": 1.0, 56 | "evaluate_episodes": 2, 57 | "log_interval": 400, 58 | "learning_rate": 7e-4, 59 | "warmup_steps": 1_000, 60 | "learning_rate_decay": 0.1, 61 | "weight_decay": 1e-4, 62 | "max_grad_norm": 5.0, 63 | "target_update_interval": 200, 64 | "value_coef": 0.25, 65 | "policy_coef": 1.0, 66 | "behavior_coef": 0.2, 67 | "save_period": 10_000, 68 | "eval_period": 1_000, 69 | "total_steps": 200_000, 70 | } 71 | ) 72 | -------------------------------------------------------------------------------- /experiment/atari/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Atari experiment configs.""" 16 | from copy import deepcopy 17 | from typing import Dict 18 | 19 | from absl import flags, logging 20 | from ml_collections import ConfigDict 21 | from wandb.util import generate_id 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | 26 | # ===== Configurations ===== # 27 | def get_config(game_name: str) -> Dict: 28 | """Get experiment configurations.""" 29 | config = deepcopy(CONFIG) 30 | config["seed"] = FLAGS.seed 31 | config["benchmark"] = "atari" 32 | config["sampling"] = FLAGS.sampling 33 | config["mcts"] = FLAGS.algo == "mzu" 34 | config["game_name"] = game_name 35 | config["num_simulations"] = FLAGS.num_simulations 36 | config["search_depth"] = FLAGS.search_depth or FLAGS.num_simulations 37 | config["batch_size"] = 16 if FLAGS.debug else config["batch_size"] 38 | exp_full_name = f"{FLAGS.exp_id}_{game_name}_" + generate_id() 39 | config["exp_full_name"] = exp_full_name 40 | logging.info(f"Configs: {config}") 41 | return config 42 | 43 | 44 | CONFIG = ConfigDict( 45 | { 46 | "data_dir": "./datasets/rl_unplugged/tensorflow_datasets", 47 | "run_number": 1, 48 | "data_percentage": 100, 49 | "batch_size": 512, 50 | "stack_size": 4, 51 | "unroll_steps": 5, 52 | "td_steps": 5, 53 | "num_bins": 601, 54 | "channels": 64, 55 | "blocks_representation": 6, 56 | "blocks_prediction": 2, 57 | "blocks_transition": 2, 58 | "reduced_channels_head": 128, 59 | "fc_layers_reward": [128, 128], 60 | "fc_layers_value": [128, 128], 61 | "fc_layers_policy": [128, 128], 62 | "output_init_scale": 0.0, 63 | "discount_factor": 0.997**4, 64 | "clipping_threshold": 1.0, 65 | "evaluate_episodes": 2, 66 | "log_interval": 400, 67 | "learning_rate": 7e-4, 68 | "warmup_steps": 1_000, 69 | "learning_rate_decay": 0.1, 70 | "weight_decay": 1e-4, 71 | "max_grad_norm": 5.0, 72 | "target_update_interval": 200, 73 | "value_coef": 0.25, 74 | "policy_coef": 1.0, 75 | "behavior_coef": 0.2, 76 | "save_period": 10_000, 77 | "eval_period": 1_000, 78 | "total_steps": 200_000, 79 | } 80 | ) 81 | -------------------------------------------------------------------------------- /rosmo/agent/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
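# Quick reference for the helpers below (the numbers are illustrative only):
#
#   probs = scalar_to_two_hot(jnp.array([1.3]), num_bins=21)
#   # -> weight 0.7 / 0.3 on the bins representing the scalars 1 and 2
#   logits_to_scalar(jnp.log(probs), num_bins=21)  # -> recovers ~1.3
#
# value_transform / inv_value_transform are each other's inverse and squash large
# scalars before they are represented with the bins above (see improvement_op.py,
# which decodes with logits_to_scalar followed by inv_value_transform).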
14 | 15 | """Agent utilities.""" 16 | # Codes adapted from: 17 | # https://github.com/Hwhitetooth/jax_muzero/blob/main/algorithms/utils.py 18 | import chex 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | from rosmo.type import Array 23 | 24 | 25 | def scale_gradient(g: Array, scale: float) -> Array: 26 | """Scale the gradient. 27 | 28 | Args: 29 | g (_type_): Parameters that contain gradients. 30 | scale (float): Scale. 31 | 32 | Returns: 33 | Array: Parameters with scaled gradients. 34 | """ 35 | return g * scale + jax.lax.stop_gradient(g) * (1.0 - scale) 36 | 37 | 38 | def scalar_to_two_hot(x: Array, num_bins: int) -> Array: 39 | """A categorical representation of real values. 40 | 41 | Ref: https://www.nature.com/articles/s41586-020-03051-4.pdf. 42 | 43 | Args: 44 | x (Array): Scalar data. 45 | num_bins (int): Number of bins. 46 | 47 | Returns: 48 | Array: Distributional data. 49 | """ 50 | max_val = (num_bins - 1) // 2 51 | x = jnp.clip(x, -max_val, max_val) 52 | x_low = jnp.floor(x).astype(jnp.int32) 53 | x_high = jnp.ceil(x).astype(jnp.int32) 54 | p_high = x - x_low 55 | p_low = 1.0 - p_high 56 | idx_low = x_low + max_val 57 | idx_high = x_high + max_val 58 | cat_low = jax.nn.one_hot(idx_low, num_bins) * p_low[..., None] 59 | cat_high = jax.nn.one_hot(idx_high, num_bins) * p_high[..., None] 60 | return cat_low + cat_high 61 | 62 | 63 | def logits_to_scalar(logits: Array, num_bins: int) -> Array: 64 | """The inverse of the scalar_to_two_hot function above. 65 | 66 | Args: 67 | logits (Array): Distributional logits. 68 | num_bins (int): Number of bins. 69 | 70 | Returns: 71 | Array: Scalar data. 72 | """ 73 | chex.assert_equal(num_bins, logits.shape[-1]) 74 | max_val = (num_bins - 1) // 2 75 | x = jnp.sum((jnp.arange(num_bins) - max_val) * jax.nn.softmax(logits), axis=-1) 76 | return x 77 | 78 | 79 | def value_transform(x: Array, epsilon: float = 1e-3) -> Array: 80 | """A non-linear value transformation for variance reduction. 81 | 82 | Ref: https://arxiv.org/abs/1805.11593. 83 | 84 | Args: 85 | x (Array): Data. 86 | epsilon (float, optional): Epsilon. Defaults to 1e-3. 87 | 88 | Returns: 89 | Array: Transformed data. 90 | """ 91 | return jnp.sign(x) * (jnp.sqrt(jnp.abs(x) + 1) - 1) + epsilon * x 92 | 93 | 94 | def inv_value_transform(x: Array, epsilon: float = 1e-3) -> Array: 95 | """The inverse of the non-linear value transformation above. 96 | 97 | Args: 98 | x (Array): Data. 99 | epsilon (float, optional): Epsilon. Defaults to 1e-3. 100 | 101 | Returns: 102 | Array: Inversely transformed data. 103 | """ 104 | return jnp.sign(x) * ( 105 | ((jnp.sqrt(1 + 4 * epsilon * (jnp.abs(x) + 1 + epsilon)) - 1) / (2 * epsilon)) 106 | ** 2 107 | - 1 108 | ) 109 | -------------------------------------------------------------------------------- /rosmo/data/buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
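# Usage sketch (mirrors how rosmo/data/bsuite.py drives this buffer): the offline
# dataset is loaded once into static storage and then sampled uniformly, e.g.
#
#   buffer = UniformBuffer(min_size=0, max_size=size, traj_len=T, batch_size=B)
#   buffer.init_storage(timesteps)  # `timesteps`: an ActorOutput of stacked arrays
#   batch = next(buffer)            # trajectory slices of shape [B, T + 1, ...]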
14 | 15 | """In-memory buffer.""" 16 | from typing import Any, Iterator 17 | 18 | import numpy as np 19 | import tree 20 | 21 | 22 | class UniformBuffer(Iterator): 23 | """Buffer that supports uniform sampling.""" 24 | 25 | def __init__( 26 | self, min_size: int, max_size: int, traj_len: int, batch_size: int = 2 27 | ) -> None: 28 | """Init the buffer.""" 29 | self._min_size = min_size 30 | self._max_size = max_size 31 | self._traj_len = traj_len 32 | self._timestep_storage = None 33 | self._n = 0 34 | self._idx = 0 35 | self._bs = batch_size 36 | self._static_buffer = False 37 | 38 | def __next__(self) -> Any: 39 | """Get the next sample. 40 | 41 | Returns: 42 | Any: Sampled data. 43 | """ 44 | return self.sample(self._bs) 45 | 46 | def init_storage(self, timesteps: Any) -> None: 47 | """Initialize the buffer. 48 | 49 | Args: 50 | timesteps (Any): Timesteps that contain the whole dataset. 51 | """ 52 | assert self._timestep_storage is None 53 | size = timesteps.observation.shape[0] 54 | assert self._min_size <= size <= self._max_size 55 | self._n = size 56 | self._timestep_storage = timesteps 57 | self._static_buffer = True 58 | 59 | def sample(self, batch_size: int) -> Any: 60 | """Sample a batch of data. 61 | 62 | Args: 63 | batch_size (int): Batch size to sample. 64 | 65 | Returns: 66 | Any: Sampled data. 67 | """ 68 | if batch_size + self._traj_len > self._n: 69 | return None 70 | start_indices = np.random.choice( 71 | self._n - self._traj_len, batch_size, replace=False 72 | ) 73 | all_indices = start_indices[:, None] + np.arange(self._traj_len + 1)[None] 74 | base_idx = 0 if self._n < self._max_size else self._idx 75 | all_indices = (all_indices + base_idx) % self._max_size 76 | trajectories = tree.map_structure( 77 | lambda a: a[all_indices], self._timestep_storage 78 | ) 79 | return trajectories 80 | 81 | def full(self) -> bool: 82 | """Test if the buffer is full. 83 | 84 | Returns: 85 | bool: True if the buffer is full. 86 | """ 87 | return self._n == self._max_size 88 | 89 | def ready(self) -> bool: 90 | """Test if the buffer has minimum size. 91 | 92 | Returns: 93 | bool: True if the buffer is ready. 94 | """ 95 | return self._n >= self._min_size 96 | 97 | @property 98 | def size(self) -> int: 99 | """Get the size of the buffer. 100 | 101 | Returns: 102 | int: Buffer size. 103 | """ 104 | return self._n 105 | 106 | def _preallocate(self, item: Any) -> Any: 107 | return tree.map_structure( 108 | lambda t: np.empty((self._max_size,) + t.shape, t.dtype), item 109 | ) 110 | 111 | 112 | def assign(array: Any, index: Any, data: Any) -> None: 113 | """Update array. 114 | 115 | Args: 116 | array (Any): Array to be updated. 117 | index (Any): Index of updates. 118 | data (Any): Update data. 
119 | """ 120 | array[index] = data 121 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "rosmo" 7 | description = '' 8 | readme = "README.md" 9 | requires-python = ">=3.7" 10 | license = "MIT" 11 | keywords = [] 12 | authors = [ 13 | { name = "Zichen", email = "liuzc@sea.com" }, 14 | ] 15 | classifiers = [ 16 | "Development Status :: 4 - Beta", 17 | "Programming Language :: Python", 18 | "Programming Language :: Python :: 3.7", 19 | "Programming Language :: Python :: 3.8", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: Implementation :: CPython", 24 | "Programming Language :: Python :: Implementation :: PyPy", 25 | ] 26 | dependencies = [ 27 | "dm-acme==0.4.0", 28 | "dm-launchpad-nightly==0.3.0.dev20220321", 29 | "dm-haiku==0.0.9", 30 | "gym==0.17.2", 31 | "gin-config==0.3.0", 32 | "rlax==0.1.4", 33 | "tensorflow==2.8.0", 34 | "tensorflow-probability==0.16.0", 35 | "optax==0.1.3", 36 | "tfds-nightly", 37 | "rlds[tensorflow]==0.1.4", 38 | "wandb==0.12.19", 39 | "ml-collections==0.1.1", 40 | "dm-sonnet==2.0.0", 41 | "mujoco-py<2.2,>=2.1", 42 | "bsuite==0.3.5", 43 | "viztracer==0.15.6", 44 | "mctx==0.0.2", 45 | ] 46 | dynamic = ["version"] 47 | 48 | [project.urls] 49 | Documentation = "https://github.com/unknown/rosmo#readme" 50 | Issues = "https://github.com/unknown/rosmo/issues" 51 | Source = "https://github.com/unknown/rosmo" 52 | 53 | [tool.hatch.version] 54 | path = "rosmo/__about__.py" 55 | 56 | [tool.hatch.envs.default] 57 | dependencies = [ 58 | "pytest", 59 | "pytest-cov", 60 | ] 61 | [tool.hatch.envs.default.scripts] 62 | cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=rosmo --cov=tests" 63 | no-cov = "cov --no-cov" 64 | 65 | [[tool.hatch.envs.test.matrix]] 66 | python = ["37", "38", "39", "310", "311"] 67 | 68 | [tool.coverage.run] 69 | branch = true 70 | parallel = true 71 | omit = [ 72 | "rosmo/__about__.py", 73 | ] 74 | 75 | [tool.coverage.report] 76 | exclude_lines = [ 77 | "no cov", 78 | "if __name__ == .__main__.:", 79 | "if TYPE_CHECKING:", 80 | ] 81 | 82 | [tool.pylint.master] 83 | load-plugins = "pylint.extensions.docparams,pylint.extensions.docstyle,pylint.extensions.no_self_use" 84 | default-docstring-type = "google" 85 | ignore-paths = ["rosmo/__about__.py"] 86 | 87 | [tool.pylint.format] 88 | max-line-length = 88 89 | indent-after-paren = 4 90 | indent-string = " " 91 | 92 | [tool.pylint.imports] 93 | known-third-party = "wandb" 94 | 95 | [tool.pylint.reports] 96 | output-format = "colorized" 97 | reports = "no" 98 | score = "yes" 99 | max-args = 7 100 | 101 | [tool.pylint.messages_control] 102 | disable = ["W0108", "W0212", "W1514", "R0902", "R0903", "R0913", "R0914", "R0915", "R1719", 103 | "R1732", "C0103", "C3001"] 104 | 105 | [tool.yapf] 106 | based_on_style = "yapf" 107 | spaces_before_comment = 4 108 | dedent_closing_brackets = true 109 | column_limit = 88 110 | continuation_indent_width = 4 111 | 112 | [tool.isort] 113 | profile = "black" 114 | multi_line_output = 3 115 | indent = 4 116 | line_length = 88 117 | known_third_party = "wandb" 118 | 119 | [tool.mypy] 120 | files = "rosmo/**/*.py" 121 | allow_redefinition = true 122 | check_untyped_defs = true 123 | 
disallow_incomplete_defs = true 124 | disallow_untyped_defs = true 125 | ignore_missing_imports = true 126 | no_implicit_optional = true 127 | pretty = true 128 | show_error_codes = true 129 | show_error_context = true 130 | show_traceback = true 131 | strict_equality = true 132 | strict_optional = true 133 | warn_no_return = true 134 | warn_redundant_casts = true 135 | warn_unreachable = true 136 | warn_unused_configs = true 137 | 138 | [tool.pydocstyle] 139 | ignore = ["D100", "D102", "D104", "D105", "D107", "D203", "D213", "D401", "D402"] 140 | 141 | 142 | [tool.pytest.ini_options] 143 | filterwarnings = [ 144 | "ignore::UserWarning", 145 | "ignore::DeprecationWarning", 146 | "ignore::FutureWarning", 147 | ] 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ROSMO 2 | 3 |
 4 | 
 5 | 
 6 | 
 7 | ----- 
 8 | 
 9 | [Badges: Check status · License · Arxiv]
 10 | 
24 | 25 | **Table of Contents** 26 | 27 | - [ROSMO](#rosmo) 28 | - [Introduction](#introduction) 29 | - [Installation](#installation) 30 | - [Usage](#usage) 31 | - [BSuite](#bsuite) 32 | - [Atari](#atari) 33 | - [Citation](#citation) 34 | - [License](#license) 35 | - [Acknowledgement](#acknowledgement) 36 | - [Disclaimer](#disclaimer) 37 | 38 | ## Introduction 39 | 40 | This repository contains the implementation of ROSMO, a **R**egularized **O**ne-**S**tep **M**odel-based algorithm for **O**ffline-RL, introduced in our paper "Efficient Offline Policy Optimization with a Learned Model". We provide the training codes for both Atari and BSuite experiments, and have made the reproduced results on `Atari MsPacman` publicly available at [W&B](https://wandb.ai/lkevinzc/rosmo-public). 41 | 42 | ## Installation 43 | Please follow the [installation guide](INSTALL.md). 44 | 45 | ## Usage 46 | ### BSuite 47 | 48 | To run the BSuite experiments, please ensure you have downloaded the [datasets](https://drive.google.com/file/d/1FWexoOphUgBaWTWtY9VR43N90z9A6FvP/view?usp=sharing) and placed them at the directory defined by `CONFIG.data_dir` in `experiment/bsuite/config.py`. 49 | 50 | 1. Debug run. 51 | ```console 52 | python experiment/bsuite/main.py -exp_id test -env cartpole 53 | ``` 54 | 2. Enable [W&B](https://wandb.ai/site) logger and start training. 55 | ```console 56 | python experiment/bsuite/main.py -exp_id test -env cartpole -nodebug -use_wb -user ${WB_USER} 57 | ``` 58 | 59 | ### Atari 60 | 61 | The following commands are examples to train 1) a ROSMO agent, 2) its sampling variant, and 3) a MZU agent on the game `MsPacman`. 62 | 63 | 1. Train ROSMO with exact policy target. 64 | ```console 65 | python experiment/atari/main.py -exp_id rosmo -env MsPacman -nodebug -use_wb -user ${WB_USER} 66 | ``` 67 | 2. Train ROSMO with sampled policy target (N=4). 68 | ```console 69 | python experiment/atari/main.py -exp_id rosmo-sample-4 -sampling -env MsPacman -nodebug -use_wb -user ${WB_USER} 70 | ``` 71 | 3. Train MuZero unplugged for benchmark (N=20). 72 | ```console 73 | python experiment/atari/main.py -exp_id mzu-sample-20 -algo mzu -num_simulations 20 -env MsPacman -nodebug -use_wb -user ${WB_USER} 74 | ``` 75 | 76 | ## Citation 77 | 78 | If you find this work useful for your research, please consider citing 79 | ``` 80 | @inproceedings{ 81 | liu2023rosmo, 82 | title={Efficient Offline Policy Optimization with a Learned Model}, 83 | author={Zichen Liu and Siyi Li and Wee Sun Lee and Shuicheng Yan and Zhongwen Xu}, 84 | booktitle={International Conference on Learning Representations}, 85 | year={2023}, 86 | url={https://arxiv.org/abs/2210.05980} 87 | } 88 | ``` 89 | 90 | ## License 91 | 92 | `ROSMO` is distributed under the terms of the [Apache2](https://www.apache.org/licenses/LICENSE-2.0) license. 93 | 94 | ## Acknowledgement 95 | 96 | We thank the following projects which provide great references: 97 | 98 | * [Jax Muzero](https://github.com/Hwhitetooth/jax_muzero) 99 | * [Efficient Zero](https://github.com/YeWR/EfficientZero) 100 | * [Acme](https://github.com/deepmind/acme) 101 | 102 | ## Disclaimer 103 | 104 | This is not an official Sea Limited or Garena Online Private Limited product. 105 | -------------------------------------------------------------------------------- /rosmo/env_loop_observer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | 
 15 | """Environment loop extensions."""
 16 | import collections
 17 | import time
 18 | from abc import abstractmethod
 19 | from typing import Dict, Optional, Union
 20 | 
 21 | import acme
 22 | import dm_env
 23 | import numpy as np
 24 | from acme.utils import observers as observers_lib
 25 | from acme.utils import signals
 26 | 
 27 | Number = Union[int, float]
 28 | 
 29 | 
 30 | class EvaluationLoop(acme.EnvironmentLoop):
 31 |     """Evaluation env-actor loop."""
 32 | 
 33 |     def run(
 34 |         self, num_episodes: Optional[int] = None, num_steps: Optional[int] = None
 35 |     ) -> None:
 36 |         """Run the evaluation loop."""
 37 |         if not (num_episodes is None or num_steps is None):
 38 |             raise ValueError('Either "num_episodes" or "num_steps" should be None.')
 39 | 
 40 |         def should_terminate(episode_count: int, step_count: int) -> bool:
 41 |             return (num_episodes is not None and episode_count >= num_episodes) or (
 42 |                 num_steps is not None and step_count >= num_steps
 43 |             )
 44 | 
 45 |         episode_count, step_count = 0, 0
 46 |         all_results: Dict[str, list] = collections.defaultdict(list)
 47 |         with signals.runtime_terminator():
 48 |             while not should_terminate(episode_count, step_count):
 49 |                 result = self.run_episode()
 50 |                 episode_count += 1
 51 |                 step_count += result["episode_length"]
 52 |                 for k, v in result.items():
 53 |                     all_results[k].append(v)
 54 |         # Log the averaged results from all episodes.
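        # Each entry is reduced with np.mean, so a single scalar per metric key
        # (e.g. "episode_length") is written for the whole evaluation run.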
55 | self._logger.write({k: np.mean(v) for k, v in all_results.items()}) 56 | 57 | 58 | class ExtendedEnvLoopObserver(observers_lib.EnvLoopObserver): 59 | """Extended env loop observer.""" 60 | 61 | @abstractmethod 62 | def step(self) -> None: 63 | """Steps the observer.""" 64 | 65 | @abstractmethod 66 | def restore(self, learning_step: int) -> None: 67 | """Restore the observer state.""" 68 | 69 | 70 | class LearningStepObserver(ExtendedEnvLoopObserver): 71 | """Observer to record the learning steps.""" 72 | 73 | def __init__(self) -> None: 74 | """Init observer.""" 75 | super().__init__() 76 | self._learning_step = 0 77 | self._eval_step = 0 78 | self._status = 1 # {0: train, 1: eval} 79 | self._train_elapsed = 0.0 80 | self._last_time: Optional[float] = None 81 | 82 | def step(self) -> None: 83 | """Steps the observer.""" 84 | self._learning_step += 1 85 | 86 | if self._status == 0 and self._last_time: 87 | self._train_elapsed += time.time() - self._last_time 88 | if self._status == 1: 89 | self._status = 0 90 | 91 | self._last_time = time.time() 92 | 93 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep) -> None: 94 | """Observes the initial state, setting states.""" 95 | self._status = 1 96 | self._eval_step += 1 97 | 98 | def observe( 99 | self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray 100 | ) -> None: 101 | """Records one environment step, dummy.""" 102 | 103 | def get_metrics(self) -> Dict[str, Number]: 104 | """Returns metrics collected for the current episode.""" 105 | return { 106 | "step": self._learning_step, 107 | "eval_step": self._eval_step, 108 | "learning_time": self._train_elapsed, 109 | } 110 | 111 | def restore(self, learning_step: int) -> None: 112 | """Restore.""" 113 | self._learning_step = learning_step 114 | -------------------------------------------------------------------------------- /experiment/bsuite/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """BSuite experiment entry.""" 16 | import random 17 | import time 18 | from typing import Dict, List, Optional, Tuple 19 | 20 | import jax 21 | import numpy as np 22 | import wandb 23 | from absl import app, flags, logging 24 | from acme import EnvironmentLoop 25 | from acme.specs import make_environment_spec 26 | from acme.utils.loggers import Logger 27 | 28 | from experiment.bsuite.config import get_config 29 | from rosmo.agent.actor import RosmoEvalActor 30 | from rosmo.agent.learning import RosmoLearner 31 | from rosmo.agent.network import get_bsuite_networks 32 | from rosmo.data import bsuite_env_loader 33 | from rosmo.env_loop_observer import ( 34 | EvaluationLoop, 35 | ExtendedEnvLoopObserver, 36 | LearningStepObserver, 37 | ) 38 | from rosmo.loggers import logger_fn 39 | 40 | # ===== Flags. 
===== # 41 | FLAGS = flags.FLAGS 42 | flags.DEFINE_boolean("debug", True, "Debug run.") 43 | flags.DEFINE_boolean("use_wb", False, "Use WB to log.") 44 | flags.DEFINE_string("user", "username", "Wandb user id.") 45 | flags.DEFINE_string("project", "rosmo", "Wandb project id.") 46 | 47 | flags.DEFINE_string("exp_id", None, "Experiment id.", required=True) 48 | flags.DEFINE_string("env", None, "Environment name to run.", required=True) 49 | flags.DEFINE_integer("seed", int(time.time()), "Random seed.") 50 | 51 | 52 | # ===== Learner. ===== # 53 | def get_learner(config, networks, data_iterator, logger) -> RosmoLearner: 54 | """Get ROSMO learner.""" 55 | learner = RosmoLearner( 56 | networks, 57 | demonstrations=data_iterator, 58 | config=config, 59 | logger=logger, 60 | ) 61 | return learner 62 | 63 | 64 | # ===== Eval Actor-Env Loop. ===== # 65 | def get_actor_env_eval_loop( 66 | config, networks, environment, observers, logger 67 | ) -> Tuple[RosmoEvalActor, EnvironmentLoop]: 68 | """Get actor, env and evaluation loop.""" 69 | actor = RosmoEvalActor( 70 | networks, 71 | config, 72 | ) 73 | eval_loop = EvaluationLoop( 74 | environment=environment, 75 | actor=actor, 76 | logger=logger, 77 | should_update=False, 78 | observers=observers, 79 | ) 80 | return actor, eval_loop 81 | 82 | 83 | def get_env_loop_observers() -> List[ExtendedEnvLoopObserver]: 84 | """Get environment loop observers.""" 85 | observers = [] 86 | learning_step_ob = LearningStepObserver() 87 | observers.append(learning_step_ob) 88 | return observers 89 | 90 | 91 | # ===== Misc. ===== # 92 | def get_logger_fn( 93 | exp_full_name: str, 94 | job_name: str, 95 | is_eval: bool = False, 96 | config: Optional[Dict] = None, 97 | ) -> Logger: 98 | """Get logger function.""" 99 | save_data = is_eval 100 | return logger_fn( 101 | exp_name=exp_full_name, 102 | label=job_name, 103 | save_data=save_data and not FLAGS.debug, 104 | use_tb=False, 105 | use_wb=FLAGS.use_wb and not FLAGS.debug, 106 | config=config, 107 | ) 108 | 109 | 110 | def main(_): 111 | """Main program.""" 112 | logging.info(f"Debug mode: {FLAGS.debug}") 113 | random.seed(FLAGS.seed) 114 | np.random.seed(FLAGS.seed) 115 | 116 | platform = jax.lib.xla_bridge.get_backend().platform 117 | num_devices = jax.device_count() 118 | logging.warn(f"Compute platform: {platform} with {num_devices} devices.") 119 | 120 | # ===== Setup. ===== # 121 | cfg = get_config(FLAGS.env) 122 | 123 | env, dataloader = bsuite_env_loader( 124 | env_name=FLAGS.env, 125 | dataset_dir=cfg["data_dir"], 126 | data_percentage=cfg["data_percentage"], 127 | batch_size=cfg["batch_size"], 128 | trajectory_length=cfg["td_steps"] + cfg["unroll_steps"] + 1, 129 | ) 130 | networks = get_bsuite_networks(make_environment_spec(env), cfg) 131 | 132 | # ===== Essentials. ===== # 133 | learner = get_learner( 134 | cfg, 135 | networks, 136 | dataloader, 137 | get_logger_fn( 138 | cfg["exp_full_name"], 139 | "learner", 140 | config=cfg, 141 | ), 142 | ) 143 | observers = get_env_loop_observers() 144 | actor, eval_loop = get_actor_env_eval_loop( 145 | cfg, 146 | networks, 147 | env, 148 | observers, 149 | get_logger_fn(cfg["exp_full_name"], "evaluator", is_eval=True, config=cfg), 150 | ) 151 | evaluate_episodes = 2 if FLAGS.debug else cfg["evaluate_episodes"] 152 | 153 | # ===== Restore. 
===== # 154 | init_step = 0 155 | if FLAGS.use_wb and not FLAGS.debug: 156 | wb_name = cfg["exp_full_name"] 157 | wb_cfg = cfg.to_dict() 158 | wandb.init( 159 | project=FLAGS.project, 160 | entity=FLAGS.user, 161 | name=wb_name, 162 | config=wb_cfg, 163 | sync_tensorboard=False, 164 | ) 165 | 166 | # ===== Training Loop. ===== # 167 | for i in range(init_step + 1, cfg["total_steps"]): 168 | learner.step() 169 | for ob in observers: 170 | ob.step() 171 | 172 | if FLAGS.debug or (i + 1) % cfg["eval_period"] == 0: 173 | actor.update_params(learner.save().params) 174 | eval_loop.run(evaluate_episodes) 175 | 176 | if FLAGS.debug: 177 | break 178 | 179 | # ===== Cleanup. ===== # 180 | learner._logger.close() 181 | eval_loop._logger.close() 182 | del env, networks, dataloader, learner, observers, actor, eval_loop 183 | 184 | 185 | if __name__ == "__main__": 186 | app.run(main) 187 | -------------------------------------------------------------------------------- /rosmo/agent/actor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Evaluating actor.""" 16 | from typing import Dict, Optional, Tuple 17 | 18 | import acme 19 | import chex 20 | import dm_env 21 | import jax 22 | import numpy as np 23 | import rlax 24 | from absl import logging 25 | from acme import types 26 | from acme.jax import networks as networks_lib 27 | 28 | from rosmo.agent.improvement_op import mcts_improve, one_step_improve 29 | from rosmo.agent.learning import root_unroll 30 | from rosmo.agent.network import Networks 31 | from rosmo.agent.type import AgentOutput, Params 32 | from rosmo.type import ActorOutput, Array 33 | 34 | 35 | class RosmoEvalActor(acme.core.Actor): 36 | """ROSMO evaluation actor.""" 37 | 38 | def __init__( 39 | self, 40 | networks: Networks, 41 | config: Dict, 42 | ) -> None: 43 | """Init ROSMO evaluation actor.""" 44 | self._networks = networks 45 | 46 | self._environment_specs = networks.environment_specs 47 | self._rng_key = jax.random.PRNGKey(config["seed"]) 48 | self._random_action = False 49 | self._params: Optional[Params] = None 50 | self._timestep: Optional[ActorOutput] = None 51 | 52 | num_bins = config["num_bins"] 53 | discount_factor = config["discount_factor"] 54 | use_mcts = config["mcts"] 55 | sampling = config.get("sampling", False) 56 | num_simulations = config.get("num_simulations", -1) 57 | search_depth = config.get("search_depth", num_simulations) 58 | 59 | def root_step( 60 | rng_key: chex.PRNGKey, 61 | params: Params, 62 | timesteps: ActorOutput, 63 | ) -> Tuple[Array, AgentOutput]: 64 | # Model one-step look-ahead for acting. 65 | trajectory = jax.tree_map( 66 | lambda t: t[None], timesteps 67 | ) # Add a dummy time dimension. 
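            # The acting path below encodes the observation stack into a latent root
            # state, unrolls the prediction head once via root_unroll, and then picks
            # an action either with MCTS (use_mcts, i.e. MuZero Unplugged) or with
            # ROSMO's one-step improvement / direct policy sampling.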
68 | state = networks.representation_network.apply( 69 | params.representation, trajectory.observation 70 | ) 71 | agent_out: AgentOutput = root_unroll( 72 | self._networks, params, num_bins, state 73 | ) 74 | improve_key, sample_key = jax.random.split(rng_key) 75 | 76 | if use_mcts: 77 | logging.info("[Actor] Using MCTS planning.") 78 | mcts_out = mcts_improve( 79 | networks, 80 | improve_key, 81 | params, 82 | agent_out, 83 | num_bins, 84 | discount_factor, 85 | num_simulations, 86 | search_depth, 87 | ) 88 | action = mcts_out.action 89 | else: 90 | agent_out = jax.tree_map( 91 | lambda t: t.squeeze(axis=0), agent_out 92 | ) # Squeeze the dummy time dimension. 93 | if not sampling: 94 | logging.info("[Actor] Using onestep improvement.") 95 | improved_policy, _ = one_step_improve( 96 | self._networks, 97 | improve_key, 98 | params, 99 | agent_out, 100 | num_bins, 101 | discount_factor, 102 | num_simulations, 103 | sampling, 104 | ) 105 | else: 106 | logging.info("[Actor] Using policy.") 107 | improved_policy = jax.nn.softmax(agent_out.policy_logits) 108 | action = rlax.categorical_sample(sample_key, improved_policy) 109 | return action, agent_out 110 | 111 | def batch_step( 112 | rng_key: chex.PRNGKey, 113 | params: Params, 114 | timesteps: ActorOutput, 115 | ) -> Tuple[networks_lib.PRNGKey, Array, AgentOutput]: 116 | batch_size = timesteps.reward.shape[0] 117 | rng_key, step_key = jax.random.split(rng_key) 118 | step_keys = jax.random.split(step_key, batch_size) 119 | batch_root_step = jax.vmap(root_step, (0, None, 0)) 120 | actions, agent_out = batch_root_step(step_keys, params, timesteps) 121 | return rng_key, actions, agent_out 122 | 123 | self._agent_step = jax.jit(batch_step) 124 | 125 | def select_action(self, observation: types.NestedArray) -> types.NestedArray: 126 | """Select action to execute.""" 127 | if self._random_action: 128 | return np.random.randint(0, self._environment_specs.actions.num_values, [1]) 129 | batched_timestep = jax.tree_map( 130 | lambda t: t[None], jax.device_put(self._timestep) 131 | ) 132 | self._rng_key, action, _ = self._agent_step( 133 | self._rng_key, self._params, batched_timestep 134 | ) 135 | action = jax.device_get(action).item() 136 | return action 137 | 138 | def observe_first(self, timestep: dm_env.TimeStep) -> None: 139 | """Observe and record the first timestep.""" 140 | self._timestep = ActorOutput( 141 | action=np.zeros((1,), dtype=np.int32), 142 | reward=np.zeros((1,), dtype=np.float32), 143 | observation=timestep.observation, 144 | is_first=np.ones((1,), dtype=np.float32), 145 | is_last=np.zeros((1,), dtype=np.float32), 146 | ) 147 | 148 | def observe( 149 | self, 150 | action: types.NestedArray, 151 | next_timestep: dm_env.TimeStep, 152 | ) -> None: 153 | """Observe and record a timestep.""" 154 | self._timestep = ActorOutput( 155 | action=action, 156 | reward=next_timestep.reward, 157 | observation=next_timestep.observation, 158 | is_first=next_timestep.first(), # previous last = this first. 159 | is_last=next_timestep.last(), 160 | ) 161 | 162 | def update(self, wait: bool = False) -> None: 163 | """Update.""" 164 | 165 | def update_params(self, params: Params) -> None: 166 | """Update parameters. 167 | 168 | Args: 169 | params (Params): Parameters. 170 | """ 171 | self._params = params 172 | -------------------------------------------------------------------------------- /rosmo/data/bsuite.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """BSuite datasets.""" 16 | import os 17 | from typing import Any, Dict, Tuple 18 | 19 | import dm_env 20 | import numpy as np 21 | import rlds 22 | import tensorflow as tf 23 | import tensorflow_datasets as tfds 24 | from absl import logging 25 | from bsuite.environments.cartpole import Cartpole as _Cartpole 26 | from bsuite.environments.catch import Catch as _Catch 27 | from bsuite.environments.mountain_car import MountainCar as _MountainCar 28 | 29 | from rosmo.data.buffer import UniformBuffer 30 | from rosmo.type import ActorOutput 31 | 32 | 33 | class Cartpole(_Cartpole): 34 | """Carpole environment.""" 35 | 36 | def __init__(self, *args: Any, **kwargs: Any) -> None: 37 | """Init env.""" 38 | super().__init__(*args, **kwargs) 39 | self.episode_id = 0 40 | self.episode_return = 0 41 | self.bsuite_id = "cartpole/0" 42 | 43 | def reset(self) -> dm_env.TimeStep: 44 | """Reset env.""" 45 | self.episode_id += 1 46 | self.episode_return = 0 47 | return super().reset() 48 | 49 | def step(self, action: int) -> dm_env.TimeStep: 50 | """Step env.""" 51 | timestep = super().step(action) 52 | if timestep.reward is not None: 53 | self.episode_return += timestep.reward 54 | return timestep 55 | 56 | 57 | class Catch(_Catch): 58 | """Catch environment.""" 59 | 60 | def __init__(self, *args: Any, **kwargs: Any) -> None: 61 | """Init env.""" 62 | super().__init__(*args, **kwargs) 63 | self.episode_id = 0 64 | self.episode_return = 0 65 | self.bsuite_id = "catch/0" 66 | 67 | def _reset(self) -> dm_env.TimeStep: 68 | self.episode_id += 1 69 | self.episode_return = 0 70 | return super()._reset() 71 | 72 | def _step(self, action: int) -> dm_env.TimeStep: 73 | timestep = super()._step(action) 74 | if timestep.reward is not None: 75 | self.episode_return += timestep.reward 76 | return timestep 77 | 78 | 79 | class MountainCar(_MountainCar): 80 | """Mountain Car environment.""" 81 | 82 | def __init__(self, *args: Any, **kwargs: Any) -> None: 83 | """Init env.""" 84 | super().__init__(*args, **kwargs) 85 | self.episode_id = 0 86 | self.episode_return = 0 87 | self.bsuite_id = "mountain_car/0" 88 | 89 | def _reset(self) -> dm_env.TimeStep: 90 | self.episode_id += 1 91 | self.episode_return = 0 92 | return super()._reset() 93 | 94 | def _step(self, action: int) -> dm_env.TimeStep: 95 | timestep = super()._step(action) 96 | if timestep.reward is not None: 97 | self.episode_return += timestep.reward 98 | return timestep 99 | 100 | 101 | _ENV_FACTORY: Dict[str, Tuple[dm_env.Environment, int]] = { 102 | "cartpole": (Cartpole, 1000), 103 | "catch": (Catch, 2000), 104 | "mountain_car": (MountainCar, 500), 105 | } 106 | 107 | _LOAD_SIZE = 1e7 108 | 109 | SCORES = { 110 | "cartpole": { 111 | "random": 64.833, 112 | "online_dqn": 1001.0, 113 | }, 114 | "catch": { 115 | "random": -0.667, 116 | "online_dqn": 1.0, 117 | }, 118 | "mountain_car": { 119 | "random": -1000.0, 120 | "online_dqn": -102.167, 121 | }, 122 | } 123 | 124 | 
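# Illustrative helper (not part of the original module): SCORES above holds the
# random / online-DQN reference returns, which can be used to report
# baseline-normalized results, e.g.:
def _normalized_return(env_name: str, raw_return: float) -> float:
    """Sketch: normalize a raw return against the random/online-DQN references."""
    ref = SCORES[env_name]
    return (raw_return - ref["random"]) / (ref["online_dqn"] - ref["random"])
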
125 | def create_bsuite_ds_loader( 126 | env_name: str, dataset_name: str, dataset_percentage: int 127 | ) -> tf.data.Dataset: 128 | """Create BSuite dataset loader. 129 | 130 | Args: 131 | env_name (str): Environment name. 132 | dataset_name (str): Dataset name. 133 | dataset_percentage (int): Fraction of data to be used 134 | 135 | Returns: 136 | tf.data.Dataset: Dataset. 137 | """ 138 | dataset = tfds.builder_from_directory(dataset_name).as_dataset(split="all") 139 | num_trajectory = _ENV_FACTORY[env_name][1] 140 | if dataset_percentage < 100: 141 | idx = np.arange(0, num_trajectory, (100 // dataset_percentage)) 142 | idx += np.random.randint(0, 100 // dataset_percentage, idx.shape) + 1 143 | idx = tf.convert_to_tensor(idx, "int32") 144 | filter_fn = lambda episode: tf.math.equal( 145 | tf.reduce_sum(tf.cast(episode["episode_id"] == idx, "int32")), 1 146 | ) 147 | dataset = dataset.filter(filter_fn) 148 | parse_fn = lambda episode: episode[rlds.STEPS] 149 | dataset = dataset.interleave( 150 | parse_fn, 151 | cycle_length=1, 152 | block_length=1, 153 | deterministic=False, 154 | num_parallel_calls=tf.data.AUTOTUNE, 155 | ) 156 | return dataset 157 | 158 | 159 | def env_loader( 160 | env_name: str, 161 | dataset_dir: str, 162 | data_percentage: int = 100, 163 | batch_size: int = 8, 164 | trajectory_length: int = 1, 165 | **_: Any, 166 | ) -> Tuple[dm_env.Environment, tf.data.Dataset]: 167 | """Get the environment and dataset. 168 | 169 | Args: 170 | env_name (str): Name of the environment. 171 | dataset_dir (str): Directory storing the dataset. 172 | data_percentage (int, optional): Fraction of data to be used. Defaults to 100. 173 | batch_size (int, optional): Batch size. Defaults to 8. 174 | trajectory_length (int, optional): Trajectory length. Defaults to 1. 175 | **_: Other keyword arguments. 176 | 177 | Returns: 178 | Tuple[dm_env.Environment, tf.data.Dataset]: Environment and dataset. 179 | """ 180 | data_name = env_name 181 | if env_name not in _ENV_FACTORY: 182 | _env_setting = env_name.split("_") 183 | if len(_env_setting) > 1: 184 | env_name = "_".join(_env_setting[:-1]) 185 | assert env_name in _ENV_FACTORY, f"env {env_name} not supported" 186 | 187 | dataset_name = os.path.join(dataset_dir, f"{data_name}") 188 | print(dataset_name) 189 | dataset = create_bsuite_ds_loader(env_name, dataset_name, data_percentage) 190 | dataloader = dataset.batch(int(_LOAD_SIZE)).as_numpy_iterator() 191 | data = next(dataloader) 192 | 193 | data_buffer = {} 194 | data_buffer["observation"] = data["observation"] 195 | data_buffer["reward"] = data["reward"] 196 | data_buffer["is_first"] = data["is_first"] 197 | data_buffer["is_last"] = data["is_last"] 198 | data_buffer["action"] = data["action"] 199 | 200 | timesteps = ActorOutput(**data_buffer) 201 | data_size = len(timesteps.reward) 202 | assert data_size < _LOAD_SIZE 203 | 204 | iterator = UniformBuffer( 205 | 0, 206 | data_size, 207 | trajectory_length, 208 | batch_size, 209 | ) 210 | logging.info(f"[Data] {data_size} transitions totally.") 211 | iterator.init_storage(timesteps) 212 | return _ENV_FACTORY[env_name][0](), iterator 213 | -------------------------------------------------------------------------------- /rosmo/agent/improvement_op.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Improvement operators.""" 16 | from typing import Tuple 17 | 18 | import chex 19 | import distrax 20 | import jax 21 | import jax.numpy as jnp 22 | import mctx 23 | from absl import logging 24 | from acme.jax import networks as networks_lib 25 | 26 | from rosmo.agent.network import Networks 27 | from rosmo.agent.type import AgentOutput, Params 28 | from rosmo.agent.utils import inv_value_transform, logits_to_scalar 29 | from rosmo.type import Array 30 | 31 | ## -------------------------------- ## 32 | ## One-step Look-ahead Improvement. ## 33 | ## -------------------------------- ## 34 | 35 | 36 | def model_simulate( 37 | networks: Networks, 38 | params: Params, 39 | num_bins: int, 40 | state: Array, 41 | actions_to_simulate: Array, 42 | ) -> AgentOutput: 43 | """Simulate the learned model using one-step look-ahead.""" 44 | 45 | def fn(state: Array, action: Array) -> Array: 46 | """Dynamics fun for vmap.""" 47 | next_state = networks.transition_network.apply( 48 | params.transition, action[None], state 49 | ) 50 | return next_state 51 | 52 | states_imagined = jax.vmap(fn, (None, 0))(state, actions_to_simulate) 53 | 54 | ( 55 | policy_logits, 56 | reward_logits, 57 | value_logits, 58 | ) = networks.prediction_network.apply(params.prediction, states_imagined) 59 | reward = logits_to_scalar(reward_logits, num_bins) 60 | reward = inv_value_transform(reward) 61 | value = logits_to_scalar(value_logits, num_bins) 62 | value = inv_value_transform(value) 63 | return AgentOutput( 64 | state=states_imagined, 65 | policy_logits=policy_logits, 66 | reward_logits=reward_logits, 67 | reward=reward, 68 | value_logits=value_logits, 69 | value=value, 70 | ) 71 | 72 | 73 | def one_step_improve( 74 | networks: Networks, 75 | rng_key: networks_lib.PRNGKey, 76 | params: Params, 77 | model_root: AgentOutput, 78 | num_bins: int, 79 | discount_factor: float, 80 | num_simulations: int = -1, 81 | sampling: bool = False, 82 | ) -> Tuple[Array, Array]: 83 | """Obtain the one-step look-ahead target policy.""" 84 | environment_specs = networks.environment_specs 85 | 86 | pi_prior = jax.nn.softmax(model_root.policy_logits) 87 | value_prior = model_root.value 88 | 89 | if sampling: 90 | assert num_simulations > 0 91 | logging.info( 92 | f"[Sample] Using {num_simulations} samples to estimate improvement." 
93 | ) 94 | pi_sample = distrax.Categorical(probs=pi_prior) 95 | sample_acts = pi_sample.sample( 96 | seed=rng_key, sample_shape=num_simulations) 97 | sample_one_step_out: AgentOutput = model_simulate( 98 | networks, params, num_bins, model_root.state, sample_acts 99 | ) 100 | sample_adv = ( 101 | sample_one_step_out.reward 102 | + discount_factor * sample_one_step_out.value 103 | - value_prior 104 | ) 105 | adv = sample_adv # for log 106 | sample_exp_adv = jnp.exp(sample_adv) 107 | normalizer_raw = (jnp.sum(sample_exp_adv) + 1) / num_simulations 108 | coeff = jnp.zeros_like(pi_prior) 109 | 110 | def body(i: int, val: jnp.ndarray) -> jnp.ndarray: 111 | """Body fun for the loop.""" 112 | normalizer_i = normalizer_raw - sample_exp_adv[i] / num_simulations 113 | delta = jnp.zeros_like(val) 114 | delta = delta.at[sample_acts[i]].set( 115 | sample_exp_adv[i] / normalizer_i) 116 | return val + delta 117 | 118 | coeff = jax.lax.fori_loop(0, num_simulations, body, coeff) 119 | pi_improved = coeff / num_simulations 120 | else: 121 | all_actions = jnp.arange(environment_specs.actions.num_values) 122 | model_one_step_out: AgentOutput = model_simulate( 123 | networks, params, num_bins, model_root.state, all_actions 124 | ) 125 | chex.assert_equal_shape([model_one_step_out.reward, pi_prior]) 126 | chex.assert_equal_shape([model_one_step_out.value, pi_prior]) 127 | adv = ( 128 | model_one_step_out.reward 129 | + discount_factor * model_one_step_out.value 130 | - value_prior 131 | ) 132 | pi_improved = pi_prior * jnp.exp(adv) 133 | pi_improved = pi_improved / jnp.sum(pi_improved) 134 | 135 | chex.assert_equal_shape([pi_improved, pi_prior]) 136 | # pi_improved here might not sum to 1, in which case we use CE 137 | # to conveniently calculate the policy gradients (Eq. 9) 138 | return pi_improved, adv 139 | 140 | 141 | ## ------------------------------------ ## 142 | ## Monte-Carlo Tree Search Improvement. 
## 143 | ## ------------------------------------ ## 144 | 145 | 146 | def mcts_improve( 147 | networks: Networks, 148 | rng_key: networks_lib.PRNGKey, 149 | params: Params, 150 | model_root: AgentOutput, 151 | num_bins: int, 152 | discount_factor: float, 153 | num_simulations: int, 154 | search_depth: int, 155 | ) -> mctx.PolicyOutput: 156 | """Obtain the Monte-Carlo Tree Search target policy.""" 157 | 158 | def recurrent_fn( 159 | params: Params, rng_key: networks_lib.PRNGKey, action: Array, state: Array 160 | ) -> Tuple[mctx.RecurrentFnOutput, Array]: 161 | del rng_key 162 | 163 | def fn(state: Array, action: Array) -> Array: 164 | next_state = networks.transition_network.apply( 165 | params.transition, action[None], state 166 | ) 167 | return next_state 168 | 169 | next_state = jax.vmap(fn, (0, 0))(state, action) 170 | 171 | ( 172 | policy_logits, 173 | reward_logits, 174 | value_logits, 175 | ) = networks.prediction_network.apply(params.prediction, next_state) 176 | reward = logits_to_scalar(reward_logits, num_bins) 177 | reward = inv_value_transform(reward) 178 | value = logits_to_scalar(value_logits, num_bins) 179 | value = inv_value_transform(value) 180 | recurrent_fn_output = mctx.RecurrentFnOutput( 181 | reward=reward, 182 | discount=jnp.full_like(value, fill_value=discount_factor), 183 | prior_logits=policy_logits, 184 | value=value, 185 | ) 186 | return recurrent_fn_output, next_state 187 | 188 | root = mctx.RootFnOutput( 189 | prior_logits=model_root.policy_logits, 190 | value=model_root.value, 191 | embedding=model_root.state, 192 | ) 193 | 194 | return mctx.muzero_policy( 195 | params, 196 | rng_key, 197 | root, 198 | recurrent_fn, 199 | num_simulations, 200 | max_depth=search_depth, 201 | ) 202 | -------------------------------------------------------------------------------- /rosmo/loggers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
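Of the two improvement operators defined above, the exact one-step target (the else branch of one_step_improve) is the core of ROSMO: the prior policy is reweighted by the exponentiated one-step advantage, adv(a) = reward(a) + discount * value(next_state_a) - value(root), and then renormalized. Below is a self-contained numeric sketch of that reweighting with made-up toy numbers for three actions; it is not taken from any dataset or checkpoint.

import jax.numpy as jnp

pi_prior = jnp.array([0.5, 0.3, 0.2])      # Softmax of the root policy logits.
reward = jnp.array([0.0, 1.0, 0.0])        # Model-predicted reward for each imagined action.
next_value = jnp.array([1.0, 1.2, 0.8])    # Model-predicted value of each imagined next state.
value_prior = 1.0                          # Value estimate at the root state.
discount_factor = 0.997

adv = reward + discount_factor * next_value - value_prior
pi_improved = pi_prior * jnp.exp(adv)
pi_improved = pi_improved / jnp.sum(pi_improved)
print(pi_improved)  # ~[0.30, 0.60, 0.10]: probability mass moves toward the highest-advantage action.

The sampling branch estimates the same target from num_simulations sampled actions instead of enumerating every action, while mcts_improve hands the same learned model to mctx.muzero_policy for deeper search.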
14 |
15 | """Logger utils."""
16 | import os
17 | import re
18 | import threading
19 | from typing import Any, Callable, Dict, Mapping, Optional
20 |
21 | import wandb
22 | from acme.utils.loggers import Logger, aggregators
23 | from acme.utils.loggers import asynchronous as async_logger
24 | from acme.utils.loggers import base, csv, filters, terminal
25 | from acme.utils.loggers.tf_summary import TFSummaryLogger
26 |
27 | from rosmo.data.bsuite import SCORES
28 | from rosmo.data.rlu_atari import BASELINES
29 |
30 |
31 | class WBLogger(base.Logger):
32 |     """Logger for W&B."""
33 |
34 |     def __init__(
35 |         self,
36 |         scope: Optional[str] = None,
37 |     ) -> None:
38 |         """Init WB logger."""
39 |         self._lock = threading.Lock()
40 |         self._scope = scope
41 |
42 |     def write(self, data: Dict[str, Any]) -> None:
43 |         """Log the data."""
44 |         step = data.pop("step", None)
45 |         if step is not None:
46 |             step = int(step)
47 |         with self._lock:
48 |             if self._scope is None:
49 |                 wandb.log(
50 |                     data,
51 |                     step=step,
52 |                 )
53 |             else:
54 |                 wandb.log(
55 |                     {f"{self._scope}/{k}": v for k, v in data.items()},
56 |                     step=step,
57 |                 )
58 |
59 |     def close(self) -> None:
60 |         """Close WB logger."""
61 |         with self._lock:
62 |             wandb.finish()
63 |
64 |
65 | class ResultFilter(base.Logger):
66 |     """Postprocessing for normalized score."""
67 |
68 |     def __init__(self, to: base.Logger, game_name: str):
69 |         """Init result filter."""
70 |         self._to = to
71 |         game_name = re.sub(r"(?<!^)(?=[A-Z])", "_", game_name).lower()
72 |
73 |         # Baseline scores: BSuite tasks live in SCORES, RL Unplugged Atari games in BASELINES.
74 |         if game_name in SCORES:
75 |             random_score = SCORES[game_name]["random"]
76 |             dqn_score = SCORES[game_name]["online_dqn"]
77 |         else:
78 |             random_score = BASELINES[game_name]["random"]
79 |             dqn_score = BASELINES[game_name]["online_dqn"]
80 |
81 |         def normalizer(score: float) -> float:
82 |             return (score - random_score) / (dqn_score - random_score)
83 |
84 |         self._normalizer = normalizer
85 |
86 |     def write(self, data: base.LoggingData) -> None:
87 |         """Write to logger."""
88 |         if "episode_return" in data:
89 |             data = {
90 |                 **data,
91 |                 "normalized_score": self._normalizer(data.get("episode_return", 0)),
92 |             }
93 |         self._to.write(data)
94 |
95 |     def close(self) -> None:
96 |         """Close logger."""
97 |         self._to.close()
98 |
99 |
100 | def make_sail_logger(
101 |     exp_name: str,
102 |     label: str,
103 |     save_data: bool = True,
104 |     save_dir: str = "./logs",
105 |     use_tb: bool = False,
106 |     tb_dir: Optional[str] = None,
107 |     use_wb: bool = False,
108 |     config: Optional[dict] = None,
109 |     time_delta: float = 1.0,
110 |     asynchronous: bool = False,
111 |     print_fn: Optional[Callable[[str], None]] = None,
112 |     serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy,
113 | ) -> base.Logger:
114 |     """Makes a logger for SAILors.
115 |
116 |     Args:
117 |         exp_name: Name of the experiment.
118 |         label: Name to give to the logger.
119 |         save_data: Whether to persist data.
120 |         save_dir: Directory to save log data.
121 |         use_tb: Whether to use TensorBoard.
122 |         tb_dir: Tensorboard directory.
123 |         use_wb: Whether to use Weights and Biases.
124 |         config: Experiment configurations.
125 |         time_delta: Time (in seconds) between logging events.
126 |         asynchronous: Whether the write function should block or not.
127 |         print_fn: How to print to terminal (defaults to print).
128 |         serialize_fn: An optional function to apply to the write inputs before
129 |             passing them to the various loggers.
130 |
131 |     Returns:
132 |         A logger object that responds to logger.write(some_dict).
133 | """ 134 | if not print_fn: 135 | print_fn = print 136 | terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn) 137 | 138 | loggers = [terminal_logger] 139 | 140 | if save_data: 141 | os.makedirs(save_dir, exist_ok=True) 142 | fd = open(os.path.join(save_dir, f"{exp_name}.csv"), "a") 143 | loggers.append( 144 | csv.CSVLogger( 145 | directory_or_file=fd, 146 | label=exp_name, 147 | add_uid=False, 148 | flush_every=2, 149 | ) 150 | ) 151 | 152 | if use_wb: 153 | wb_logger = WBLogger(scope=label) 154 | wb_logger = filters.TimeFilter(wb_logger, time_delta) 155 | loggers.append(wb_logger) 156 | 157 | if use_tb: 158 | if tb_dir is None: 159 | tb_dir = "./tblogs" 160 | loggers.append(TFSummaryLogger(tb_dir, label)) 161 | 162 | # Dispatch to all writers and filter Nones and by time. 163 | logger = aggregators.Dispatcher(loggers, serialize_fn) 164 | logger = filters.NoneFilter(logger) 165 | 166 | if config: 167 | logger = ResultFilter(logger, game_name=config["game_name"]) 168 | 169 | if asynchronous: 170 | logger = async_logger.AsyncLogger(logger) 171 | logger = filters.TimeFilter(logger, 5.0) 172 | 173 | return logger 174 | 175 | 176 | def logger_fn( 177 | exp_name: str, 178 | label: str, 179 | save_data: bool = False, 180 | use_tb: bool = True, 181 | use_wb: bool = True, 182 | config: Optional[dict] = None, 183 | time_delta: float = 15.0, 184 | ) -> Logger: 185 | """Get logger function. 186 | 187 | Args: 188 | exp_name (str): Experiment name. 189 | label (str): Experiment label. 190 | save_data (bool, optional): Whether to save data. Defaults to False. 191 | use_tb (bool, optional): Whether to use TB. Defaults to True. 192 | use_wb (bool, optional): Whether to use WB. Defaults to True. 193 | config (Optional[dict], optional): Experiment configurations. Defaults to None. 194 | time_delta (float, optional): Time delta to emit logs. Defaults to 15.0. 195 | 196 | Returns: 197 | Logger: Logger. 198 | """ 199 | tb_path = os.path.join("./tblogs", exp_name) 200 | return make_sail_logger( 201 | exp_name=exp_name, 202 | label=label, 203 | save_data=save_data, 204 | save_dir="./logs", 205 | use_tb=use_tb, 206 | tb_dir=tb_path, 207 | use_wb=use_wb, 208 | config=config, 209 | time_delta=time_delta, # Applied to W&B. 
210 | print_fn=print, 211 | ) 212 | -------------------------------------------------------------------------------- /rosmo/data/rl_unplugged_atari_baselines.json: -------------------------------------------------------------------------------- 1 | { 2 | "alien": { 3 | "BC": 2670.0, 4 | "BCQ": 2090.0, 5 | "DQN": 1690.0, 6 | "IQN": 2860.0, 7 | "REM": 1730.0, 8 | "online_dqn": 2766.81, 9 | "random": 199.8 10 | }, 11 | "amidar": { 12 | "BC": 256.0, 13 | "BCQ": 254.0, 14 | "DQN": 224.0, 15 | "IQN": 351.0, 16 | "REM": 214.0, 17 | "online_dqn": 1556.96, 18 | "random": 3.19 19 | }, 20 | "assault": { 21 | "BC": 1810.0, 22 | "BCQ": 2260.0, 23 | "DQN": 1940.0, 24 | "IQN": 2180.0, 25 | "REM": 3070.0, 26 | "online_dqn": 1946.1, 27 | "random": 235.15 28 | }, 29 | "asterix": { 30 | "BC": 2960.0, 31 | "BCQ": 1930.0, 32 | "DQN": 1520.0, 33 | "IQN": 5710.0, 34 | "REM": 4890.0, 35 | "online_dqn": 4131.77, 36 | "random": 279.06 37 | }, 38 | "atlantis": { 39 | "BC": 2390000.0, 40 | "BCQ": 3200000.0, 41 | "DQN": 3020000.0, 42 | "IQN": 2710000.0, 43 | "REM": 3360000.0, 44 | "online_dqn": 944228.0, 45 | "random": 16973.04 46 | }, 47 | "bank_heist": { 48 | "BC": 1050.0, 49 | "BCQ": 270.0, 50 | "DQN": 50.0, 51 | "IQN": 1110.0, 52 | "REM": 160.0, 53 | "online_dqn": 907.72, 54 | "random": 13.65 55 | }, 56 | "battle_zone": { 57 | "BC": 4800.0, 58 | "BCQ": 25400.0, 59 | "DQN": 25600.0, 60 | "IQN": 16500.0, 61 | "REM": 26200.0, 62 | "online_dqn": 26458.99, 63 | "random": 2786.8 64 | }, 65 | "beam_rider": { 66 | "BC": 1480.0, 67 | "BCQ": 1990.0, 68 | "DQN": 1810.0, 69 | "IQN": 3020.0, 70 | "REM": 2200.0, 71 | "online_dqn": 6453.26, 72 | "random": 362.05 73 | }, 74 | "boxing": { 75 | "BC": 83.9, 76 | "BCQ": 97.2, 77 | "DQN": 96.3, 78 | "IQN": 95.8, 79 | "REM": 97.3, 80 | "online_dqn": 84.11, 81 | "random": 0.79 82 | }, 83 | "breakout": { 84 | "BC": 235.0, 85 | "BCQ": 375.0, 86 | "DQN": 324.0, 87 | "IQN": 314.0, 88 | "REM": 362.0, 89 | "online_dqn": 157.86, 90 | "random": 1.33 91 | }, 92 | "carnival": { 93 | "BC": 3920.0, 94 | "BCQ": 4310.0, 95 | "DQN": 1450.0, 96 | "IQN": 4820.0, 97 | "REM": 2080.0, 98 | "online_dqn": 5339.46, 99 | "random": 669.55 100 | }, 101 | "centipede": { 102 | "BC": 1070.0, 103 | "BCQ": 1430.0, 104 | "DQN": 1250.0, 105 | "IQN": 1830.0, 106 | "REM": 810.0, 107 | "online_dqn": 3972.49, 108 | "random": 2181.66 109 | }, 110 | "chopper_command": { 111 | "BC": 660.0, 112 | "BCQ": 3950.0, 113 | "DQN": 2250.0, 114 | "IQN": 830.0, 115 | "REM": 3610.0, 116 | "online_dqn": 3678.15, 117 | "random": 823.14 118 | }, 119 | "crazy_climber": { 120 | "BC": 123000.0, 121 | "BCQ": 28000.0, 122 | "DQN": 23000.0, 123 | "IQN": 126000.0, 124 | "REM": 42000.0, 125 | "online_dqn": 118080.24, 126 | "random": 8173.59 127 | }, 128 | "demon_attack": { 129 | "BC": 7600.0, 130 | "BCQ": 19300.0, 131 | "DQN": 11000.0, 132 | "IQN": 15500.0, 133 | "REM": 17000.0, 134 | "online_dqn": 6517.02, 135 | "random": 166.02 136 | }, 137 | "double_dunk": { 138 | "BC": -16.4, 139 | "BCQ": -12.9, 140 | "DQN": -17.9, 141 | "IQN": -16.7, 142 | "REM": -17.9, 143 | "online_dqn": -1.22, 144 | "random": -18.42 145 | }, 146 | "enduro": { 147 | "BC": 720.0, 148 | "BCQ": 1390.0, 149 | "DQN": 1210.0, 150 | "IQN": 1700.0, 151 | "REM": 3650.0, 152 | "online_dqn": 1016.28, 153 | "random": 0.0 154 | }, 155 | "fishing_derby": { 156 | "BC": -7.4, 157 | "BCQ": 28.9, 158 | "DQN": 17.0, 159 | "IQN": 20.8, 160 | "REM": 29.3, 161 | "online_dqn": 18.57, 162 | "random": -93.25 163 | }, 164 | "freeway": { 165 | "BC": 21.8, 166 | "BCQ": 16.9, 167 | "DQN": 15.4, 168 | 
"IQN": 24.7, 169 | "REM": 7.2, 170 | "online_dqn": 26.76, 171 | "random": 0.0 172 | }, 173 | "frostbite": { 174 | "BC": 780.0, 175 | "BCQ": 3520.0, 176 | "DQN": 3230.0, 177 | "IQN": 2630.0, 178 | "REM": 3070.0, 179 | "online_dqn": 1643.65, 180 | "random": 71.99 181 | }, 182 | "gopher": { 183 | "BC": 4900.0, 184 | "BCQ": 8700.0, 185 | "DQN": 2400.0, 186 | "IQN": 11300.0, 187 | "REM": 3700.0, 188 | "online_dqn": 8241.0, 189 | "random": 282.57 190 | }, 191 | "gravitar": { 192 | "BC": 20.0, 193 | "BCQ": 580.0, 194 | "DQN": 500.0, 195 | "IQN": 235.0, 196 | "REM": 424.0, 197 | "online_dqn": 310.56, 198 | "random": 213.71 199 | }, 200 | "hero": { 201 | "BC": 13900.0, 202 | "BCQ": 13200.0, 203 | "DQN": 5200.0, 204 | "IQN": 16200.0, 205 | "REM": 14000.0, 206 | "online_dqn": 16233.54, 207 | "random": 719.15 208 | }, 209 | "ice_hockey": { 210 | "BC": -5.63, 211 | "BCQ": -2.51, 212 | "DQN": -2.88, 213 | "IQN": -4.65, 214 | "REM": -1.16, 215 | "online_dqn": -4.02, 216 | "random": -9.82 217 | }, 218 | "jamesbond": { 219 | "BC": 237.0, 220 | "BCQ": 438.0, 221 | "DQN": 490.0, 222 | "IQN": 699.0, 223 | "REM": 369.0, 224 | "online_dqn": 777.73, 225 | "random": 27.63 226 | }, 227 | "kangaroo": { 228 | "BC": 5690.0, 229 | "BCQ": 1300.0, 230 | "DQN": 820.0, 231 | "IQN": 9120.0, 232 | "REM": 1210.0, 233 | "online_dqn": 14125.11, 234 | "random": 41.3 235 | }, 236 | "krull": { 237 | "BC": 8500.0, 238 | "BCQ": 7780.0, 239 | "DQN": 7480.0, 240 | "IQN": 8470.0, 241 | "REM": 7980.0, 242 | "online_dqn": 7238.51, 243 | "random": 1556.92 244 | }, 245 | "kung_fu_master": { 246 | "BC": 5100.0, 247 | "BCQ": 16900.0, 248 | "DQN": 16100.0, 249 | "IQN": 19500.0, 250 | "REM": 19400.0, 251 | "online_dqn": 26637.88, 252 | "random": 556.31 253 | }, 254 | "ms_pacman": { 255 | "BC": 4040.0, 256 | "BCQ": 3080.0, 257 | "DQN": 2470.0, 258 | "IQN": 4390.0, 259 | "REM": 3150.0, 260 | "online_dqn": 4171.53, 261 | "random": 247.97 262 | }, 263 | "name_this_game": { 264 | "BC": 4100.0, 265 | "BCQ": 12600.0, 266 | "DQN": 11500.0, 267 | "IQN": 9900.0, 268 | "REM": 13000.0, 269 | "online_dqn": 8645.09, 270 | "random": 2401.05 271 | }, 272 | "phoenix": { 273 | "BC": 2940.0, 274 | "BCQ": 6620.0, 275 | "DQN": 6410.0, 276 | "IQN": 4940.0, 277 | "REM": 7480.0, 278 | "online_dqn": 5122.3, 279 | "random": 873.03 280 | }, 281 | "pong": { 282 | "BC": 18.9, 283 | "BCQ": 16.5, 284 | "DQN": 12.9, 285 | "IQN": 19.2, 286 | "REM": 16.5, 287 | "online_dqn": 18.25, 288 | "random": -20.3 289 | }, 290 | "pooyan": { 291 | "BC": 3850.0, 292 | "BCQ": 4200.0, 293 | "DQN": 3180.0, 294 | "IQN": 5000.0, 295 | "REM": 4470.0, 296 | "online_dqn": 4135.32, 297 | "random": 411.36 298 | }, 299 | "qbert": { 300 | "BC": 12600.0, 301 | "BCQ": 12600.0, 302 | "DQN": 10600.0, 303 | "IQN": 13400.0, 304 | "REM": 13100.0, 305 | "online_dqn": 12275.13, 306 | "random": 155.01 307 | }, 308 | "riverraid": { 309 | "BC": 6000.0, 310 | "BCQ": 14200.0, 311 | "DQN": 9100.0, 312 | "IQN": 13000.0, 313 | "REM": 14200.0, 314 | "online_dqn": 12798.88, 315 | "random": 1504.25 316 | }, 317 | "road_runner": { 318 | "BC": 19000.0, 319 | "BCQ": 57400.0, 320 | "DQN": 31700.0, 321 | "IQN": 44700.0, 322 | "REM": 56500.0, 323 | "online_dqn": 47880.48, 324 | "random": 15.46 325 | }, 326 | "robotank": { 327 | "BC": 15.7, 328 | "BCQ": 60.7, 329 | "DQN": 55.7, 330 | "IQN": 42.7, 331 | "REM": 60.5, 332 | "online_dqn": 63.44, 333 | "random": 2.01 334 | }, 335 | "seaquest": { 336 | "BC": 150.0, 337 | "BCQ": 5410.0, 338 | "DQN": 2870.0, 339 | "IQN": 1670.0, 340 | "REM": 5910.0, 341 | "online_dqn": 3233.47, 342 | 
"random": 81.78 343 | }, 344 | "space_invaders": { 345 | "BC": 790.0, 346 | "BCQ": 2920.0, 347 | "DQN": 2710.0, 348 | "IQN": 2840.0, 349 | "REM": 2810.0, 350 | "online_dqn": 2044.63, 351 | "random": 149.52 352 | }, 353 | "star_gunner": { 354 | "BC": 3000.0, 355 | "BCQ": 2500.0, 356 | "DQN": 1600.0, 357 | "IQN": 39400.0, 358 | "REM": 7500.0, 359 | "online_dqn": 55103.84, 360 | "random": 677.22 361 | }, 362 | "time_pilot": { 363 | "BC": 1950.0, 364 | "BCQ": 5180.0, 365 | "DQN": 5310.0, 366 | "IQN": 3140.0, 367 | "REM": 4490.0, 368 | "online_dqn": 4160.51, 369 | "random": 3450.95 370 | }, 371 | "up_n_down": { 372 | "BC": 16300.0, 373 | "BCQ": 32500.0, 374 | "DQN": 14600.0, 375 | "IQN": 32300.0, 376 | "REM": 27600.0, 377 | "online_dqn": 15677.92, 378 | "random": 513.94 379 | }, 380 | "video_pinball": { 381 | "BC": 27000.0, 382 | "BCQ": 103000.0, 383 | "DQN": 82000.0, 384 | "IQN": 102000.0, 385 | "REM": 313000.0, 386 | "online_dqn": 335055.69, 387 | "random": 26024.42 388 | }, 389 | "wizard_of_wor": { 390 | "BC": 730.0, 391 | "BCQ": 4680.0, 392 | "DQN": 2300.0, 393 | "IQN": 1400.0, 394 | "REM": 2730.0, 395 | "online_dqn": 1787.79, 396 | "random": 686.63 397 | }, 398 | "yars_revenge": { 399 | "BC": 19100.0, 400 | "BCQ": 29100.0, 401 | "DQN": 24900.0, 402 | "IQN": 28400.0, 403 | "REM": 23100.0, 404 | "online_dqn": 26762.98, 405 | "random": 3147.67 406 | }, 407 | "zaxxon": { 408 | "BC": 10.0, 409 | "BCQ": 9430.0, 410 | "DQN": 6050.0, 411 | "IQN": 870.0, 412 | "REM": 8300.0, 413 | "online_dqn": 4681.93, 414 | "random": 10.57 415 | } 416 | } 417 | -------------------------------------------------------------------------------- /experiment/atari/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Atari experiment entry.""" 16 | 17 | import os 18 | import pickle 19 | import random 20 | import time 21 | from typing import Dict, Iterator, List, Optional, Tuple 22 | 23 | import dm_env 24 | import jax 25 | import numpy as np 26 | import tensorflow as tf 27 | import wandb 28 | from absl import app, flags, logging 29 | from acme import EnvironmentLoop 30 | from acme.specs import make_environment_spec 31 | from acme.utils.loggers import Logger 32 | 33 | from experiment.atari.config import get_config 34 | from rosmo.agent.actor import RosmoEvalActor 35 | from rosmo.agent.learning import RosmoLearner 36 | from rosmo.agent.network import Networks, make_atari_networks 37 | from rosmo.data import atari_env_loader 38 | from rosmo.env_loop_observer import ( 39 | EvaluationLoop, 40 | ExtendedEnvLoopObserver, 41 | LearningStepObserver, 42 | ) 43 | from rosmo.loggers import logger_fn 44 | from rosmo.profiler import Profiler 45 | from rosmo.type import ActorOutput 46 | 47 | # ===== Flags. 
===== # 48 | FLAGS = flags.FLAGS 49 | flags.DEFINE_boolean("debug", True, "Debug run.") 50 | flags.DEFINE_boolean("profile", False, "Profile codes.") 51 | flags.DEFINE_boolean("use_wb", False, "Use WB to log.") 52 | flags.DEFINE_string("user", "username", "Wandb user id.") 53 | flags.DEFINE_string("project", "rosmo", "Wandb project id.") 54 | 55 | flags.DEFINE_string("exp_id", None, "Experiment id.", required=True) 56 | flags.DEFINE_string("env", None, "Environment name to run.", required=True) 57 | flags.DEFINE_integer("seed", int(time.time()), "Random seed.") 58 | 59 | flags.DEFINE_boolean("sampling", False, "Whether to sample policy target.") 60 | flags.DEFINE_integer("num_simulations", 4, "Simulation budget.") 61 | flags.DEFINE_enum("algo", "rosmo", ["rosmo", "mzu"], "Algorithm to use.") 62 | flags.DEFINE_integer( 63 | "search_depth", 64 | 0, 65 | "Depth of Monte-Carlo Tree Search (only for mzu), \ 66 | defaults to num_simulations.", 67 | ) 68 | 69 | # ===== Learner. ===== # 70 | def get_learner(config, networks, data_iterator, logger) -> RosmoLearner: 71 | """Get ROSMO learner.""" 72 | learner = RosmoLearner( 73 | networks, 74 | demonstrations=data_iterator, 75 | config=config, 76 | logger=logger, 77 | ) 78 | return learner 79 | 80 | 81 | # ===== Eval Actor-Env Loop & Observer. ===== # 82 | def get_actor_env_eval_loop( 83 | config, networks, environment, observers, logger 84 | ) -> Tuple[RosmoEvalActor, EnvironmentLoop]: 85 | """Get actor, env and evaluation loop.""" 86 | actor = RosmoEvalActor( 87 | networks, 88 | config, 89 | ) 90 | eval_loop = EvaluationLoop( 91 | environment=environment, 92 | actor=actor, 93 | logger=logger, 94 | should_update=False, 95 | observers=observers, 96 | ) 97 | return actor, eval_loop 98 | 99 | 100 | def get_env_loop_observers() -> List[ExtendedEnvLoopObserver]: 101 | """Get environment loop observers.""" 102 | observers = [] 103 | learning_step_ob = LearningStepObserver() 104 | observers.append(learning_step_ob) 105 | return observers 106 | 107 | 108 | # ===== Environment & Dataloader. ===== # 109 | def get_env_data_loader(config) -> Tuple[dm_env.Environment, Iterator]: 110 | """Get environment and trajectory data loader.""" 111 | trajectory_length = config["unroll_steps"] + config["td_steps"] + 1 112 | environment, dataset = atari_env_loader( 113 | env_name=config["game_name"], 114 | run_number=config["run_number"], 115 | dataset_dir=config["data_dir"], 116 | stack_size=config["stack_size"], 117 | data_percentage=config["data_percentage"], 118 | trajectory_length=trajectory_length, 119 | shuffle_num_steps=5000 if FLAGS.debug else 50000, 120 | ) 121 | 122 | def transform_timesteps(steps: Dict[str, np.ndarray]) -> ActorOutput: 123 | return ActorOutput( 124 | observation=steps["observation"], 125 | reward=steps["reward"], 126 | is_first=steps["is_first"], 127 | is_last=steps["is_last"], 128 | action=steps["action"], 129 | ) 130 | 131 | dataset = ( 132 | dataset.repeat() 133 | .batch(config["batch_size"]) 134 | .map(transform_timesteps) 135 | .prefetch(tf.data.AUTOTUNE) 136 | ) 137 | options = tf.data.Options() 138 | options.threading.max_intra_op_parallelism = 1 139 | dataset = dataset.with_options(options) 140 | iterator = dataset.as_numpy_iterator() 141 | return environment, iterator 142 | 143 | 144 | # ===== Network. 
===== # 145 | def get_networks(config, environment) -> Networks: 146 | """Get environment-specific networks.""" 147 | environment_spec = make_environment_spec(environment) 148 | logging.info(environment_spec) 149 | networks = make_atari_networks( 150 | env_spec=environment_spec, 151 | channels=config["channels"], 152 | num_bins=config["num_bins"], 153 | output_init_scale=config["output_init_scale"], 154 | blocks_representation=config["blocks_representation"], 155 | blocks_prediction=config["blocks_prediction"], 156 | blocks_transition=config["blocks_transition"], 157 | reduced_channels_head=config["reduced_channels_head"], 158 | fc_layers_reward=config["fc_layers_reward"], 159 | fc_layers_value=config["fc_layers_value"], 160 | fc_layers_policy=config["fc_layers_policy"], 161 | ) 162 | return networks 163 | 164 | 165 | # ===== Misc. ===== # 166 | def get_logger_fn( 167 | exp_full_name: str, 168 | job_name: str, 169 | is_eval: bool = False, 170 | config: Optional[Dict] = None, 171 | ) -> Logger: 172 | """Get logger function.""" 173 | save_data = is_eval 174 | return logger_fn( 175 | exp_name=exp_full_name, 176 | label=job_name, 177 | save_data=save_data and not FLAGS.debug, 178 | use_tb=False, 179 | use_wb=FLAGS.use_wb and not FLAGS.debug, 180 | config=config, 181 | ) 182 | 183 | 184 | def main(_): 185 | """Main program.""" 186 | platform = jax.lib.xla_bridge.get_backend().platform 187 | num_devices = jax.device_count() 188 | logging.warn(f"Compute platform: {platform} with {num_devices} devices.") 189 | logging.info(f"Debug mode: {FLAGS.debug}") 190 | random.seed(FLAGS.seed) 191 | np.random.seed(FLAGS.seed) 192 | 193 | # ===== Setup. ===== # 194 | cfg = get_config(FLAGS.env) 195 | env, dataloader = get_env_data_loader(cfg) 196 | networks = get_networks(cfg, env) 197 | 198 | # ===== Essentials. ===== # 199 | learner = get_learner( 200 | cfg, 201 | networks, 202 | dataloader, 203 | get_logger_fn( 204 | cfg["exp_full_name"], 205 | "learner", 206 | config=cfg, 207 | ), 208 | ) 209 | observers = get_env_loop_observers() 210 | actor, eval_loop = get_actor_env_eval_loop( 211 | cfg, 212 | networks, 213 | env, 214 | observers, 215 | get_logger_fn( 216 | cfg["exp_full_name"], 217 | "evaluator", 218 | is_eval=True, 219 | config=cfg, 220 | ), 221 | ) 222 | evaluate_episodes = 2 if FLAGS.debug else cfg["evaluate_episodes"] 223 | 224 | init_step = 0 225 | save_path = os.path.join("./checkpoint", cfg["exp_full_name"]) 226 | os.makedirs(save_path, exist_ok=True) 227 | if FLAGS.profile: 228 | profile_dir = "./profile" 229 | os.makedirs(profile_dir, exist_ok=True) 230 | profiler = Profiler(profile_dir, cfg["exp_full_name"], with_jax=True) 231 | 232 | if FLAGS.use_wb and not (FLAGS.debug or FLAGS.profile): 233 | wb_name = cfg["exp_full_name"] 234 | wb_cfg = cfg.to_dict() 235 | 236 | wandb.init( 237 | project=FLAGS.project, 238 | entity=FLAGS.user, 239 | name=wb_name, 240 | config=wb_cfg, 241 | sync_tensorboard=False, 242 | ) 243 | 244 | # ===== Training Loop. 
===== # 245 | for i in range(init_step + 1, cfg["total_steps"]): 246 | learner.step() 247 | for ob in observers: 248 | ob.step() 249 | 250 | if (i + 1) % cfg["save_period"] == 0: 251 | with open(os.path.join(save_path, f"ckpt_{i}.pkl"), "wb") as f: 252 | pickle.dump(learner.save(), f) 253 | if (i + 1) % cfg["eval_period"] == 0: 254 | actor.update_params(learner.save().params) 255 | eval_loop.run(evaluate_episodes) 256 | 257 | if FLAGS.profile: 258 | if i == 100: 259 | profiler.start() 260 | if i == 200: 261 | profiler.stop_and_save() 262 | break 263 | elif FLAGS.debug: 264 | actor.update_params(learner.save().params) 265 | eval_loop.run(evaluate_episodes) 266 | break 267 | 268 | # ===== Cleanup. ===== # 269 | learner._logger.close() 270 | eval_loop._logger.close() 271 | del env, networks, dataloader, learner, observers, actor, eval_loop 272 | 273 | 274 | if __name__ == "__main__": 275 | app.run(main) 276 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Garena Online Private Limited 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /rosmo/data/rlu_atari.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """RL Unplugged Atari datasets.""" 16 | import json 17 | import os 18 | from typing import Any, Callable, Dict, Optional, Tuple, Union 19 | 20 | import dm_env 21 | import gym 22 | import rlds 23 | import tensorflow as tf 24 | import tensorflow_datasets as tfds 25 | from absl import logging 26 | from acme import wrappers 27 | from dm_env import specs 28 | from dopamine.discrete_domains import atari_lib 29 | 30 | from rosmo.type import Array 31 | 32 | with open( 33 | os.path.join( 34 | os.path.dirname(os.path.abspath(__file__)), "rl_unplugged_atari_baselines.json" 35 | ), 36 | "r", 37 | ) as f: 38 | BASELINES = json.load(f) 39 | 40 | 41 | class _BatchToTransition: 42 | """Creates (s,a,r,f,l) transitions.""" 43 | 44 | @staticmethod 45 | def create_transitions(batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: 46 | """Create stacked transitions. 47 | 48 | Args: 49 | batch (Dict[str, tf.Tensor]): Data batch 50 | 51 | Returns: 52 | Dict[str, tf.Tensor]: Stacked data batch. 53 | """ 54 | observation = tf.squeeze(batch[rlds.OBSERVATION], axis=-1) 55 | observation = tf.transpose(observation, perm=[1, 2, 0]) 56 | action = batch[rlds.ACTION][-1] 57 | reward = batch[rlds.REWARD][-1] 58 | discount = batch[rlds.DISCOUNT][-1] 59 | return { 60 | "observation": observation, 61 | "action": action, 62 | "reward": reward, 63 | "discount": discount, 64 | "is_first": batch[rlds.IS_FIRST][0], 65 | "is_last": batch[rlds.IS_LAST][-1], 66 | } 67 | 68 | 69 | def _get_trajectory_dataset_fn( 70 | stack_size: int, 71 | trajectory_length: int = 1, 72 | ) -> Callable[[tf.data.Dataset], tf.data.Dataset]: 73 | batch_fn = _BatchToTransition().create_transitions 74 | 75 | def make_trajectory_dataset(episode: tf.data.Dataset) -> tf.data.Dataset: 76 | """Converts an episode of steps to a dataset of custom transitions. 77 | 78 | Episode spec: { 79 | 'checkpoint_id': , 80 | 'episode_id': , 81 | 'episode_return': , 82 | 'steps': <_VariantDataset element_spec={ 83 | 'action': TensorSpec(shape=(), dtype=tf.int64, name=None), 84 | 'discount': TensorSpec(shape=(), dtype=tf.float32, name=None), 85 | 'is_first': TensorSpec(shape=(), dtype=tf.bool, name=None), 86 | 'is_last': TensorSpec(shape=(), dtype=tf.bool, name=None), 87 | 'is_terminal': TensorSpec(shape=(), dtype=tf.bool, name=None), 88 | 'observation': TensorSpec(shape=(84, 84, 1), dtype=tf.uint8, 89 | name=None), 90 | 'reward': TensorSpec(shape=(), dtype=tf.float32, name=None) 91 | } 92 | >} 93 | """ 94 | # Create a dataset of 2-step sequences with overlap of 1. 95 | timesteps: tf.data.Dataset = episode[rlds.STEPS] 96 | batched_steps = rlds.transformations.batch( 97 | timesteps, 98 | size=stack_size, 99 | shift=1, 100 | drop_remainder=True, 101 | ) 102 | transitions = batched_steps.map(batch_fn) 103 | # Batch trajectory. 
104 | if trajectory_length > 1: 105 | transitions = transitions.repeat(2) 106 | transitions = transitions.skip( 107 | tf.random.uniform([], 0, trajectory_length, dtype=tf.int64) 108 | ) 109 | trajectory = transitions.batch(trajectory_length, drop_remainder=True) 110 | else: 111 | trajectory = transitions 112 | return trajectory 113 | 114 | return make_trajectory_dataset 115 | 116 | 117 | def _uniformly_subsampled_atari_data( 118 | dataset_name: str, 119 | data_percent: int, 120 | data_dir: str, 121 | ) -> tf.data.Dataset: 122 | ds_builder = tfds.builder(dataset_name) 123 | data_splits = [] 124 | total_num_episode = 0 125 | for split, info in ds_builder.info.splits.items(): 126 | # Convert `data_percent` to number of episodes to allow 127 | # for fractional percentages. 128 | num_episodes = int((data_percent / 100) * info.num_examples) 129 | total_num_episode += num_episodes 130 | if num_episodes == 0: 131 | raise ValueError(f"{data_percent}% leads to 0 episodes in {split}!") 132 | # Sample first `data_percent` episodes from each of the data split. 133 | data_splits.append(f"{split}[:{num_episodes}]") 134 | # Interleave episodes across different splits/checkpoints. 135 | # Set `shuffle_files=True` to shuffle episodes across files within splits. 136 | read_config = tfds.ReadConfig( 137 | interleave_cycle_length=len(data_splits), 138 | shuffle_reshuffle_each_iteration=True, 139 | enable_ordering_guard=False, 140 | ) 141 | logging.info(f"Total number of episode = {total_num_episode}") 142 | return tfds.load( 143 | dataset_name, 144 | data_dir=data_dir, 145 | split="+".join(data_splits), 146 | read_config=read_config, 147 | shuffle_files=True, 148 | ) 149 | 150 | 151 | def create_atari_ds_loader( 152 | env_name: str, 153 | run_number: int, 154 | dataset_dir: str, 155 | stack_size: int = 4, 156 | data_percentage: int = 10, 157 | trajectory_fn: Optional[Callable] = None, 158 | shuffle_num_episodes: int = 1000, 159 | shuffle_num_steps: int = 50000, 160 | trajectory_length: int = 10, 161 | **_: Any, 162 | ) -> tf.data.Dataset: 163 | """Create Atari dataset loader. 164 | 165 | Args: 166 | env_name (str): Environment name. 167 | run_number (int): Run number. 168 | dataset_dir (str): Directory to the dataset. 169 | stack_size (int, optional): Stack size. Defaults to 4. 170 | data_percentage (int, optional): Fraction of data to be used. Defaults to 10. 171 | trajectory_fn (Optional[Callable], optional): Function to form trajectory. 172 | Defaults to None. 173 | shuffle_num_episodes (int, optional): Number of episodes to shuffle. 174 | Defaults to 1000. 175 | shuffle_num_steps (int, optional): Number of steps to shuffle. 176 | Defaults to 50000. 177 | trajectory_length (int, optional): Trajectory length. Defaults to 10. 178 | **_: Other keyword arguments. 179 | 180 | Returns: 181 | tf.data.Dataset: Dataset. 182 | """ 183 | if trajectory_fn is None: 184 | trajectory_fn = _get_trajectory_dataset_fn(stack_size, trajectory_length) 185 | dataset_name = f"rlu_atari_checkpoints_ordered/{env_name}_run_{run_number}" 186 | # Create a dataset of episodes sampling `data_percent`% episodes 187 | # from each of the data split. 188 | dataset = _uniformly_subsampled_atari_data( 189 | dataset_name, data_percentage, dataset_dir 190 | ) 191 | # Shuffle the episodes to avoid consecutive episodes. 192 | dataset = dataset.shuffle(shuffle_num_episodes) 193 | # Interleave=1 keeps ordered sequential steps. 
194 | dataset = dataset.interleave( 195 | trajectory_fn, 196 | cycle_length=100, 197 | block_length=1, 198 | deterministic=False, 199 | num_parallel_calls=tf.data.AUTOTUNE, 200 | ) 201 | # Shuffle trajectories in the dataset. 202 | dataset = dataset.shuffle( 203 | shuffle_num_steps // trajectory_length, 204 | reshuffle_each_iteration=True, 205 | ) 206 | return dataset 207 | 208 | 209 | class _AtariDopamineWrapper(dm_env.Environment): 210 | """Wrapper for Atari Dopamine environmnet.""" 211 | 212 | def __init__(self, env: gym.Env, max_episode_steps: int = 108000): 213 | self._env = env 214 | self._max_episode_steps = max_episode_steps 215 | self._episode_steps = 0 216 | self._reset_next_episode = True 217 | self._reset_next_step = True 218 | 219 | def reset(self) -> dm_env.TimeStep: 220 | self._episode_steps = 0 221 | self._reset_next_step = False 222 | observation = self._env.reset() 223 | return dm_env.restart(observation.squeeze(-1)) # type: ignore 224 | 225 | def step(self, action: Union[int, Array]) -> dm_env.TimeStep: 226 | if self._reset_next_step: 227 | return self.reset() 228 | if not isinstance(action, int): 229 | action = action.item() 230 | observation, reward, terminal, _ = self._env.step(action) # type: ignore 231 | observation = observation.squeeze(-1) 232 | discount = 1 - float(terminal) 233 | self._episode_steps += 1 234 | if terminal: 235 | self._reset_next_episode = True 236 | return dm_env.termination(reward, observation) 237 | if self._episode_steps == self._max_episode_steps: 238 | self._reset_next_episode = True 239 | return dm_env.truncation(reward, observation, discount) 240 | return dm_env.transition(reward, observation, discount) 241 | 242 | def observation_spec(self) -> specs.Array: 243 | space = self._env.observation_space 244 | return specs.Array(space.shape[:-1], space.dtype) # type: ignore 245 | 246 | def action_spec(self) -> specs.DiscreteArray: 247 | return specs.DiscreteArray(self._env.action_space.n) # type: ignore 248 | 249 | def render(self, mode: str = "rgb_array") -> Any: 250 | """Render the environment. 251 | 252 | Args: 253 | mode (str, optional): Mode of rendering. Defaults to "rgb_array". 254 | 255 | Returns: 256 | Any: Rendered result. 257 | """ 258 | return self._env.render(mode) 259 | 260 | 261 | def environment(game: str, stack_size: int) -> dm_env.Environment: 262 | """Atari environment.""" 263 | env = atari_lib.create_atari_environment(game_name=game, sticky_actions=True) 264 | env = _AtariDopamineWrapper(env, max_episode_steps=20_000) 265 | env = wrappers.FrameStackingWrapper(env, num_frames=stack_size) 266 | return wrappers.SinglePrecisionWrapper(env) 267 | 268 | 269 | def env_loader( 270 | env_name: str, 271 | run_number: int, 272 | dataset_dir: str, 273 | stack_size: int = 4, 274 | data_percentage: int = 10, 275 | trajectory_fn: Optional[Callable] = None, 276 | shuffle_num_episodes: int = 1000, 277 | shuffle_num_steps: int = 50000, 278 | trajectory_length: int = 10, 279 | **_: Any, 280 | ) -> Tuple[dm_env.Environment, tf.data.Dataset]: 281 | """Get the environment and dataset. 282 | 283 | Args: 284 | env_name (str): Name of the environment. 285 | run_number (int): Run number of the dataset. 286 | dataset_dir (str): Directory storing the dataset. 287 | stack_size (int, optional): Number of frame stacking. Defaults to 4. 288 | data_percentage (int, optional): Fraction of data to be used. Defaults to 10. 289 | trajectory_fn (Optional[Callable], optional): Function to form trajectory. 290 | Defaults to None. 
291 | shuffle_num_episodes (int, optional): Number of episodes to shuffle. 292 | Defaults to 1000. 293 | shuffle_num_steps (int, optional): Number of steps to shuffle. 294 | Defaults to 50000. 295 | trajectory_length (int, optional): Trajectory length. Defaults to 10. 296 | **_: Other keyword arguments. 297 | 298 | Returns: 299 | Tuple[dm_env.Environment, tf.data.Dataset]: Environment and dataset. 300 | """ 301 | return environment(game=env_name, stack_size=stack_size), create_atari_ds_loader( 302 | env_name=env_name, 303 | run_number=run_number, 304 | dataset_dir=dataset_dir, 305 | stack_size=stack_size, 306 | data_percentage=data_percentage, 307 | trajectory_fn=trajectory_fn, 308 | shuffle_num_episodes=shuffle_num_episodes, 309 | shuffle_num_steps=shuffle_num_steps, 310 | trajectory_length=trajectory_length, 311 | ) 312 | -------------------------------------------------------------------------------- /rosmo/agent/network.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Haiku neural network modules.""" 16 | import dataclasses 17 | from typing import Any, Dict, List, Optional, Tuple 18 | 19 | import haiku as hk 20 | import jax 21 | import jax.numpy as jnp 22 | from acme import specs 23 | from acme.jax import networks as networks_lib 24 | from acme.jax import utils 25 | 26 | from rosmo.type import Forwardable 27 | 28 | 29 | def get_prediction_head_layers( 30 | reduced_channels_head: int, 31 | mlp_layers: List[int], 32 | num_predictions: int, 33 | w_init: Optional[hk.initializers.Initializer] = None, 34 | ) -> List[Forwardable]: 35 | """Get prediction head layers. 36 | 37 | Args: 38 | reduced_channels_head (int): Conv reduced channels. 39 | mlp_layers (List[int]): MLP layers' hidden units. 40 | num_predictions (int): Output size. 41 | w_init (Optional[hk.initializers.Initializer], optional): Weight 42 | initialization. Defaults to None. 43 | 44 | Returns: 45 | List[Forwardable]: List of layers. 46 | """ 47 | layers: List[Forwardable] = [ 48 | hk.Conv2D( 49 | reduced_channels_head, 50 | kernel_shape=1, 51 | stride=1, 52 | padding="SAME", 53 | with_bias=False, 54 | ), 55 | hk.LayerNorm(axis=(-3, -2, -1), create_scale=True, create_offset=True), 56 | jax.nn.relu, 57 | hk.Flatten(-3), 58 | ] 59 | for l in mlp_layers: 60 | layers.extend( 61 | [ 62 | hk.Linear(l, with_bias=False), 63 | hk.LayerNorm(axis=-1, create_scale=True, create_offset=True), 64 | jax.nn.relu, 65 | ] 66 | ) 67 | layers.append(hk.Linear(num_predictions, w_init=w_init)) 68 | return layers 69 | 70 | 71 | def get_ln_relu_layers() -> List[Forwardable]: 72 | """Get LN relu layers. 73 | 74 | Returns: 75 | List[Forwardable]: LayerNorm+relu. 
76 | """ 77 | return [ 78 | hk.LayerNorm(axis=(-3, -2, -1), create_scale=True, create_offset=True), 79 | jax.nn.relu, 80 | ] 81 | 82 | 83 | class ResConvBlock(hk.Module): 84 | """A residual convolutional block in pre-activation style.""" 85 | 86 | def __init__( 87 | self, 88 | channels: int, 89 | stride: int, 90 | use_projection: bool, 91 | name: str = "res_conv_block", 92 | ): 93 | """Init residual block.""" 94 | super().__init__(name=name) 95 | self._use_projection = use_projection 96 | if use_projection: 97 | self._proj_conv = hk.Conv2D( 98 | channels, kernel_shape=3, stride=stride, padding="SAME", with_bias=False 99 | ) 100 | self._conv_0 = hk.Conv2D( 101 | channels, kernel_shape=3, stride=stride, padding="SAME", with_bias=False 102 | ) 103 | self._ln_0 = hk.LayerNorm( 104 | axis=(-3, -2, -1), create_scale=True, create_offset=True 105 | ) 106 | self._conv_1 = hk.Conv2D( 107 | channels, kernel_shape=3, stride=1, padding="SAME", with_bias=False 108 | ) 109 | self._ln_1 = hk.LayerNorm( 110 | axis=(-3, -2, -1), create_scale=True, create_offset=True 111 | ) 112 | 113 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 114 | """Forward ResBlock.""" 115 | # NOTE: Using LayerNorm is fine 116 | # (https://arxiv.org/pdf/2104.06294.pdf Appendix A). 117 | shortcut = out = x 118 | out = self._ln_0(out) 119 | out = jax.nn.relu(out) 120 | if self._use_projection: 121 | shortcut = self._proj_conv(out) 122 | out = hk.Sequential( 123 | [ 124 | self._conv_0, 125 | self._ln_1, 126 | jax.nn.relu, 127 | self._conv_1, 128 | ] 129 | )(out) 130 | return shortcut + out 131 | 132 | 133 | class Representation(hk.Module): 134 | """Representation encoding module.""" 135 | 136 | def __init__( 137 | self, 138 | channels: int, 139 | num_blocks: int, 140 | name: str = "representation", 141 | ): 142 | """Init representatioin function.""" 143 | super().__init__(name=name) 144 | self._channels = channels 145 | self._num_blocks = num_blocks 146 | 147 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 148 | """Forward representation function.""" 149 | # 1. Downsampling. 150 | torso: List[Forwardable] = [ 151 | lambda x: x / 255.0, 152 | hk.Conv2D( 153 | self._channels // 2, 154 | kernel_shape=3, 155 | stride=2, 156 | padding="SAME", 157 | with_bias=False, 158 | ), 159 | ] 160 | torso.extend( 161 | [ 162 | ResConvBlock(self._channels // 2, stride=1, use_projection=False) 163 | for _ in range(1) 164 | ] 165 | ) 166 | torso.append(ResConvBlock(self._channels, stride=2, use_projection=True)) 167 | torso.extend( 168 | [ 169 | ResConvBlock(self._channels, stride=1, use_projection=False) 170 | for _ in range(1) 171 | ] 172 | ) 173 | torso.append( 174 | hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding="SAME") 175 | ) 176 | torso.extend( 177 | [ 178 | ResConvBlock(self._channels, stride=1, use_projection=False) 179 | for _ in range(1) 180 | ] 181 | ) 182 | torso.append( 183 | hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding="SAME") 184 | ) 185 | 186 | # 2. Encoding. 
187 | torso.extend( 188 | [ 189 | ResConvBlock(self._channels, stride=1, use_projection=False) 190 | for _ in range(self._num_blocks) 191 | ] 192 | ) 193 | return hk.Sequential(torso)(observations) 194 | 195 | 196 | class Transition(hk.Module): 197 | """Dynamics transition module.""" 198 | 199 | def __init__( 200 | self, 201 | channels: int, 202 | num_blocks: int, 203 | name: str = "transition", 204 | ): 205 | """Init transition function.""" 206 | super().__init__(name=name) 207 | self._channels = channels 208 | self._num_blocks = num_blocks 209 | 210 | def __call__( 211 | self, encoded_action: jnp.ndarray, prev_state: jnp.ndarray 212 | ) -> jnp.ndarray: 213 | """Forward transition function.""" 214 | channels = prev_state.shape[-1] 215 | shortcut = prev_state 216 | 217 | prev_state = hk.LayerNorm( 218 | axis=(-3, -2, -1), create_scale=True, create_offset=True 219 | )(prev_state) 220 | prev_state = jax.nn.relu(prev_state) 221 | 222 | x_and_h = jnp.concatenate([prev_state, encoded_action], axis=-1) 223 | out = hk.Conv2D( 224 | self._channels, 225 | kernel_shape=3, 226 | stride=1, 227 | padding="SAME", 228 | with_bias=False, 229 | )(x_and_h) 230 | out += shortcut # Residual link to maintain recurrent info flow. 231 | 232 | res_layers = [ 233 | ResConvBlock(channels, stride=1, use_projection=False) 234 | for _ in range(self._num_blocks) 235 | ] 236 | out = hk.Sequential(res_layers)(out) 237 | return out 238 | 239 | 240 | class Prediction(hk.Module): 241 | """Policy, value and reward prediction module.""" 242 | 243 | def __init__( 244 | self, 245 | num_blocks: int, 246 | num_actions: int, 247 | num_bins: int, 248 | channel: int, 249 | fc_layers_reward: List[int], 250 | fc_layers_value: List[int], 251 | fc_layers_policy: List[int], 252 | output_init_scale: float, 253 | name: str = "prediction", 254 | ) -> None: 255 | """Init prediction function.""" 256 | super().__init__(name=name) 257 | self._num_blocks = num_blocks 258 | self._num_actions = num_actions 259 | self._num_bins = num_bins 260 | self._channel = channel 261 | self._fc_layers_reward = fc_layers_reward 262 | self._fc_layers_value = fc_layers_value 263 | self._fc_layers_policy = fc_layers_policy 264 | self._output_init_scale = output_init_scale 265 | 266 | def __call__( 267 | self, states: jnp.ndarray 268 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 269 | """Forward prediction function.""" 270 | output_init = hk.initializers.VarianceScaling(scale=self._output_init_scale) 271 | reward_head, value_head, policy_head = [], [], [] 272 | 273 | # Add LN+Relu due to pre-activation. 
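# Because ResConvBlock is pre-activation, the incoming latent state has not
# been normalised/activated yet, so each head starts with its own
# LayerNorm + ReLU before the 1x1-conv/MLP stack built by
# get_prediction_head_layers.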
274 | reward_head.extend(get_ln_relu_layers()) 275 | value_head.extend(get_ln_relu_layers()) 276 | policy_head.extend(get_ln_relu_layers()) 277 | 278 | reward_head.extend( 279 | get_prediction_head_layers( 280 | self._channel, 281 | self._fc_layers_reward, 282 | self._num_bins, 283 | output_init, 284 | ) 285 | ) 286 | reward_logits = hk.Sequential(reward_head)(states) 287 | 288 | res_layers = [ 289 | ResConvBlock(states.shape[-1], stride=1, use_projection=False) 290 | for _ in range(self._num_blocks) 291 | ] 292 | out = hk.Sequential(res_layers)(states) 293 | 294 | value_head.extend( 295 | get_prediction_head_layers( 296 | self._channel, 297 | self._fc_layers_value, 298 | self._num_bins, 299 | output_init, 300 | ) 301 | ) 302 | value_logits = hk.Sequential(value_head)(out) 303 | 304 | policy_head.extend( 305 | get_prediction_head_layers( 306 | self._channel, 307 | self._fc_layers_policy, 308 | self._num_actions, 309 | output_init, 310 | ) 311 | ) 312 | policy_logits = hk.Sequential(policy_head)(out) 313 | return policy_logits, reward_logits, value_logits 314 | 315 | 316 | @dataclasses.dataclass 317 | class Networks: 318 | """ROSMO Networks.""" 319 | 320 | representation_network: networks_lib.FeedForwardNetwork 321 | transition_network: networks_lib.FeedForwardNetwork 322 | prediction_network: networks_lib.FeedForwardNetwork 323 | 324 | environment_specs: specs.EnvironmentSpec 325 | 326 | 327 | def make_atari_networks( 328 | env_spec: specs.EnvironmentSpec, 329 | channels: int, 330 | num_bins: int, 331 | output_init_scale: float, 332 | blocks_representation: int, 333 | blocks_prediction: int, 334 | blocks_transition: int, 335 | reduced_channels_head: int, 336 | fc_layers_reward: List[int], 337 | fc_layers_value: List[int], 338 | fc_layers_policy: List[int], 339 | ) -> Networks: 340 | """Make Atari networks. 341 | 342 | Args: 343 | env_spec (specs.EnvironmentSpec): Environment spec. 344 | channels (int): Convolution channels. 345 | num_bins (int): Number of bins. 346 | output_init_scale (float): Weight init scale. 347 | blocks_representation (int): Number of blocks for representation. 348 | blocks_prediction (int): Number of blocks for prediction. 349 | blocks_transition (int): Number of blocks for transition. 350 | reduced_channels_head (int): Reduced conv channels for prediction head. 351 | fc_layers_reward (List[int]): Fully connected layers for reward prediction. 352 | fc_layers_value (List[int]): Fully connected layers for value prediction. 353 | fc_layers_policy (List[int]): Fully connected layers for policy prediction. 354 | 355 | Returns: 356 | Networks: Constructed networks. 
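    Example:
        An illustrative call; the numbers below are placeholders rather than
        the repository's tuned defaults, and `env_spec` is assumed to be the
        acme EnvironmentSpec of the wrapped Atari environment:

            networks = make_atari_networks(
                env_spec=env_spec,
                channels=64,
                num_bins=601,
                output_init_scale=0.0,
                blocks_representation=6,
                blocks_prediction=2,
                blocks_transition=2,
                reduced_channels_head=128,
                fc_layers_reward=[128],
                fc_layers_value=[128],
                fc_layers_policy=[128],
            )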
357 | """ 358 | action_space_size = env_spec.actions.num_values 359 | 360 | def _representation_fun(observations: jnp.ndarray) -> jnp.ndarray: 361 | network = Representation(channels, blocks_representation) 362 | state = network(observations) 363 | return state 364 | 365 | representation = hk.without_apply_rng(hk.transform(_representation_fun)) 366 | hidden_channels = channels 367 | 368 | def _transition_fun(action: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray: 369 | action = hk.one_hot(action, action_space_size)[None, :] 370 | encoded_action = jnp.broadcast_to(action, state.shape[:-1] + action.shape[-1:]) 371 | 372 | network = Transition(hidden_channels, blocks_transition) 373 | next_state = network(encoded_action, state) 374 | return next_state 375 | 376 | transition = hk.without_apply_rng(hk.transform(_transition_fun)) 377 | prediction = hk.without_apply_rng( 378 | hk.transform( 379 | lambda states: Prediction( 380 | blocks_prediction, 381 | action_space_size, 382 | num_bins, 383 | reduced_channels_head, 384 | fc_layers_reward, 385 | fc_layers_value, 386 | fc_layers_policy, 387 | output_init_scale, 388 | )(states) 389 | ) 390 | ) 391 | 392 | dummy_action = jnp.array([env_spec.actions.generate_value()]) 393 | dummy_obs = utils.zeros_like(env_spec.observations) 394 | 395 | def _dummy_state(key: networks_lib.PRNGKey) -> jnp.ndarray: 396 | encoder_params = representation.init(key, dummy_obs) 397 | dummy_state = representation.apply(encoder_params, dummy_obs) 398 | return dummy_state 399 | 400 | return Networks( 401 | representation_network=networks_lib.FeedForwardNetwork( 402 | lambda key: representation.init(key, dummy_obs), representation.apply 403 | ), 404 | transition_network=networks_lib.FeedForwardNetwork( 405 | lambda key: transition.init(key, dummy_action, _dummy_state(key)), 406 | transition.apply, 407 | ), 408 | prediction_network=networks_lib.FeedForwardNetwork( 409 | lambda key: prediction.init(key, _dummy_state(key)), 410 | prediction.apply, 411 | ), 412 | environment_specs=env_spec, 413 | ) 414 | 415 | 416 | def get_bsuite_networks( 417 | env_spec: specs.EnvironmentSpec, config: Dict[str, Any] 418 | ) -> Networks: 419 | """Make BSuite networks. 420 | 421 | Args: 422 | env_spec (specs.EnvironmentSpec): Environment specifications. 423 | config (Dict[str, Any]): Configurations. 424 | 425 | Returns: 426 | Networks: Constructed networks. 
427 | """ 428 | action_space_size = env_spec.actions.num_values 429 | 430 | def _representation_fun(observations: jnp.ndarray) -> jnp.ndarray: 431 | network = hk.Sequential([hk.Flatten(), hk.nets.MLP(config["encoder_layers"])]) 432 | state = network(observations) 433 | return state 434 | 435 | representation = hk.without_apply_rng(hk.transform(_representation_fun)) 436 | 437 | def _transition_fun(action: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray: 438 | action = hk.one_hot(action, action_space_size) 439 | network = hk.nets.MLP(config["dynamics_layers"]) 440 | sa = jnp.concatenate( 441 | [jnp.reshape(state, (-1, state.shape[-1])), action], axis=-1 442 | ) 443 | next_state = network(sa).squeeze() 444 | return next_state 445 | 446 | transition = hk.without_apply_rng(hk.transform(_transition_fun)) 447 | 448 | def _prediction_fun(state: jnp.ndarray) -> Tuple[jnp.ndarray, ...]: 449 | network = hk.nets.MLP(config["prediction_layers"], activate_final=True) 450 | head_state = network(state) 451 | output_init = hk.initializers.VarianceScaling(scale=0.0) 452 | head_policy = hk.nets.MLP([action_space_size], w_init=output_init) 453 | head_value = hk.nets.MLP([config["num_bins"]], w_init=output_init) 454 | head_reward = hk.nets.MLP([config["num_bins"]], w_init=output_init) 455 | 456 | return ( 457 | head_policy(head_state), 458 | head_reward(head_state), 459 | head_value(head_state), 460 | ) 461 | 462 | prediction = hk.without_apply_rng(hk.transform(_prediction_fun)) 463 | 464 | dummy_action = utils.add_batch_dim(jnp.array(env_spec.actions.generate_value())) 465 | 466 | dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations)) 467 | 468 | def _dummy_state(key: networks_lib.PRNGKey) -> jnp.ndarray: 469 | encoder_params = representation.init(key, dummy_obs) 470 | dummy_state = representation.apply(encoder_params, dummy_obs) 471 | return dummy_state 472 | 473 | return Networks( 474 | representation_network=networks_lib.FeedForwardNetwork( 475 | lambda key: representation.init(key, dummy_obs), representation.apply 476 | ), 477 | transition_network=networks_lib.FeedForwardNetwork( 478 | lambda key: transition.init(key, dummy_action, _dummy_state(key)), 479 | transition.apply, 480 | ), 481 | prediction_network=networks_lib.FeedForwardNetwork( 482 | lambda key: prediction.init(key, _dummy_state(key)), 483 | prediction.apply, 484 | ), 485 | environment_specs=env_spec, 486 | ) 487 | -------------------------------------------------------------------------------- /rosmo/agent/learning.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Agent learner.""" 16 | import time 17 | from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple 18 | 19 | import acme 20 | import distrax 21 | import haiku as hk 22 | import jax 23 | import jax.numpy as jnp 24 | import mctx 25 | import optax 26 | import rlax 27 | import tree 28 | from absl import logging 29 | from acme.jax import networks as networks_lib 30 | from acme.jax import utils 31 | from acme.utils import loggers 32 | 33 | from rosmo.agent.improvement_op import mcts_improve, one_step_improve 34 | from rosmo.agent.network import Networks 35 | from rosmo.agent.type import AgentOutput, Params 36 | from rosmo.agent.utils import ( 37 | inv_value_transform, 38 | logits_to_scalar, 39 | scalar_to_two_hot, 40 | scale_gradient, 41 | value_transform, 42 | ) 43 | from rosmo.type import ActorOutput, Array 44 | 45 | 46 | class TrainingState(NamedTuple): 47 | """Training state.""" 48 | 49 | optimizer_state: optax.OptState 50 | params: Params 51 | target_params: Params 52 | 53 | step: int 54 | 55 | 56 | class RosmoLearner(acme.core.Learner): 57 | """ROSMO learner.""" 58 | 59 | def __init__( 60 | self, 61 | networks: Networks, 62 | demonstrations: Iterator[ActorOutput], 63 | config: Dict, 64 | logger: Optional[loggers.Logger] = None, 65 | ) -> None: 66 | """Init ROSMO learner. 67 | 68 | Args: 69 | networks (Networks): ROSMO networks. 70 | demonstrations (Iterator[ActorOutput]): Data loader. 71 | config (Dict): Configurations. 72 | logger (Optional[loggers.Logger], optional): Logger. Defaults to None. 73 | """ 74 | discount_factor = config["discount_factor"] 75 | weight_decay = config["weight_decay"] 76 | value_coef = config["value_coef"] 77 | behavior_coef = config["behavior_coef"] 78 | policy_coef = config["policy_coef"] 79 | unroll_steps = config["unroll_steps"] 80 | td_steps = config["td_steps"] 81 | target_update_interval = config["target_update_interval"] 82 | log_interval = config["log_interval"] 83 | batch_size = config["batch_size"] 84 | max_grad_norm = config["max_grad_norm"] 85 | num_bins = config["num_bins"] 86 | use_mcts = config["mcts"] 87 | sampling = config.get("sampling", False) 88 | num_simulations = config.get("num_simulations", -1) 89 | search_depth = config.get("search_depth", num_simulations) 90 | 91 | _batch_categorical_cross_entropy = jax.vmap(rlax.categorical_cross_entropy) 92 | 93 | def loss( 94 | params: Params, 95 | target_params: Params, 96 | trajectory: ActorOutput, 97 | rng_key: networks_lib.PRNGKey, 98 | ) -> Tuple[Array, Dict[str, Array]]: 99 | # Encode obs via learning and target networks, [T, S] 100 | state = networks.representation_network.apply( 101 | params.representation, trajectory.observation 102 | ) 103 | target_state = networks.representation_network.apply( 104 | target_params.representation, trajectory.observation 105 | ) 106 | 107 | # 1) Model unroll, sampling and estimation. 
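# The root latent state is encoded from the first observation of the
# trajectory; the dynamics model is then unrolled for `unroll_steps` using
# dataset actions (actions that cross an episode boundary are replaced by
# random ones), and the predictions along the unroll are trained against
# targets built from the target network below.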
108 | root_state = jax.tree_map(lambda t: t[:1], state) 109 | learner_root = root_unroll(networks, params, num_bins, root_state) 110 | learner_root: AgentOutput = jax.tree_map(lambda t: t[0], learner_root) 111 | 112 | unroll_trajectory: ActorOutput = jax.tree_map( 113 | lambda t: t[: unroll_steps + 1], trajectory 114 | ) 115 | random_action_mask = ( 116 | jnp.cumprod(1.0 - unroll_trajectory.is_first[1:]) == 0.0 117 | ) 118 | action_sequence = unroll_trajectory.action[:unroll_steps] 119 | num_actions = learner_root.policy_logits.shape[-1] 120 | rng_key, action_key = jax.random.split(rng_key) 121 | random_actions = jax.random.choice( 122 | action_key, num_actions, action_sequence.shape, replace=True 123 | ) 124 | simulate_action_sequence = jax.lax.select( 125 | random_action_mask, random_actions, action_sequence 126 | ) 127 | 128 | model_out: AgentOutput = model_unroll( 129 | networks, 130 | params, 131 | num_bins, 132 | learner_root.state, 133 | simulate_action_sequence, 134 | ) 135 | 136 | # Model predictions. 137 | policy_logits = jnp.concatenate( 138 | [ 139 | learner_root.policy_logits[None], 140 | model_out.policy_logits, 141 | ], 142 | axis=0, 143 | ) 144 | 145 | value_logits = jnp.concatenate( 146 | [ 147 | learner_root.value_logits[None], 148 | model_out.value_logits, 149 | ], 150 | axis=0, 151 | ) 152 | 153 | # 2) Model learning targets. 154 | # a) Reward. 155 | rewards = trajectory.reward 156 | reward_target = jax.lax.select( 157 | random_action_mask, 158 | jnp.zeros_like(rewards[:unroll_steps]), 159 | rewards[:unroll_steps], 160 | ) 161 | reward_target_transformed = value_transform(reward_target) 162 | reward_logits_target = scalar_to_two_hot( 163 | reward_target_transformed, num_bins 164 | ) 165 | 166 | # b) Policy. 167 | target_roots: AgentOutput = root_unroll( 168 | networks, target_params, num_bins, target_state 169 | ) 170 | search_roots: AgentOutput = jax.tree_map( 171 | lambda t: t[: unroll_steps + 1], target_roots 172 | ) 173 | rng_key, improve_key = jax.random.split(rng_key) 174 | 175 | if use_mcts: 176 | logging.info( 177 | f"[Learning] Using MuZero with simulation={num_simulations}" 178 | f" & depth={search_depth}." 179 | ) 180 | mcts_out = mcts_improve( 181 | networks, 182 | improve_key, 183 | target_params, 184 | target_roots, 185 | num_bins, 186 | discount_factor, 187 | num_simulations, 188 | search_depth, 189 | ) 190 | policy_target, improve_adv = ( 191 | mcts_out.action_weights[: unroll_steps + 1], 192 | 0.0, 193 | ) 194 | else: 195 | logging.info("[Learning] Using ROSMO.") 196 | improve_keys = jax.random.split( 197 | improve_key, search_roots.state.shape[0] 198 | ) 199 | policy_target, improve_adv = jax.vmap( # type: ignore 200 | one_step_improve, 201 | (None, 0, None, 0, None, None, None, None), 202 | )( 203 | networks, 204 | improve_keys, 205 | target_params, 206 | search_roots, 207 | num_bins, 208 | discount_factor, 209 | num_simulations, 210 | sampling, 211 | ) 212 | uniform_policy = jnp.ones_like(policy_target) / num_actions 213 | random_policy_mask = jnp.cumprod(1.0 - unroll_trajectory.is_last) == 0.0 214 | random_policy_mask = jnp.broadcast_to( 215 | random_policy_mask[:, None], policy_target.shape 216 | ) 217 | policy_target = jax.lax.select( 218 | random_policy_mask, uniform_policy, policy_target 219 | ) 220 | policy_target = jax.lax.stop_gradient(policy_target) 221 | 222 | # c) Value. 
223 | discounts = (1.0 - trajectory.is_last[1:]) * discount_factor 224 | if use_mcts: 225 | node_values = mcts_out.search_tree.node_values 226 | v_bootstrap = node_values[:, mctx.Tree.ROOT_INDEX] 227 | else: 228 | v_bootstrap = target_roots.value 229 | 230 | def n_step_return(i: int) -> jnp.ndarray: 231 | bootstrap_value = jax.tree_map(lambda t: t[i + td_steps], v_bootstrap) 232 | _rewards = jnp.concatenate( 233 | [rewards[i : i + td_steps], bootstrap_value[None]], axis=0 234 | ) 235 | _discounts = jnp.concatenate( 236 | [jnp.ones((1,)), jnp.cumprod(discounts[i : i + td_steps])], 237 | axis=0, 238 | ) 239 | return jnp.sum(_rewards * _discounts) 240 | 241 | returns = [] 242 | for i in range(unroll_steps + 1): 243 | returns.append(n_step_return(i)) 244 | returns = jnp.stack(returns) 245 | # Value targets for the absorbing state and the states after are 0. 246 | zero_return_mask = jnp.cumprod(1.0 - unroll_trajectory.is_last) == 0.0 247 | value_target = jax.lax.select( 248 | zero_return_mask, jnp.zeros_like(returns), returns 249 | ) 250 | value_target_transformed = value_transform(value_target) 251 | value_logits_target = scalar_to_two_hot(value_target_transformed, num_bins) 252 | value_logits_target = jax.lax.stop_gradient(value_logits_target) 253 | 254 | # 3) Behavior regularization. 255 | behavior_loss = jnp.array(0.0) 256 | if not use_mcts: 257 | in_sample_action = trajectory.action[: unroll_steps + 1] 258 | log_prob = jax.nn.log_softmax(policy_logits) 259 | action_log_prob = log_prob[ 260 | jnp.arange(unroll_steps + 1), in_sample_action 261 | ] 262 | 263 | _target_value = target_roots.value[: unroll_steps + 1] 264 | _target_reward = target_roots.reward[1 : unroll_steps + 1 + 1] 265 | _target_value_prime = target_roots.value[1 : unroll_steps + 1 + 1] 266 | _target_adv = ( 267 | _target_reward 268 | + discount_factor * _target_value_prime 269 | - _target_value 270 | ) 271 | _target_adv = jax.lax.stop_gradient(_target_adv) 272 | behavior_loss = -action_log_prob * jnp.heaviside(_target_adv, 0.0) 273 | # Deal with cross-episode trajectories. 274 | invalid_action_mask = jnp.cumprod(1.0 - trajectory.is_first[1:]) == 0.0 275 | behavior_loss = jax.lax.select( 276 | invalid_action_mask[: unroll_steps + 1], 277 | jnp.zeros_like(behavior_loss), 278 | behavior_loss, 279 | ) 280 | behavior_loss = jnp.mean(behavior_loss) * behavior_coef 281 | 282 | # 4) Compute the losses. 283 | reward_loss = jnp.mean( 284 | _batch_categorical_cross_entropy( 285 | reward_logits_target, model_out.reward_logits 286 | ) 287 | ) 288 | 289 | value_loss = ( 290 | jnp.mean( 291 | _batch_categorical_cross_entropy(value_logits_target, value_logits) 292 | ) 293 | * value_coef 294 | ) 295 | 296 | policy_loss = ( 297 | jnp.mean(_batch_categorical_cross_entropy(policy_target, policy_logits)) 298 | * policy_coef 299 | ) 300 | 301 | total_loss = reward_loss + value_loss + policy_loss + behavior_loss 302 | 303 | if sampling: 304 | # Unnormalized. 
305 | def entropy_fn(p: Array) -> Array: 306 | return distrax.Categorical(logits=p).entropy() 307 | 308 | else: 309 | 310 | def entropy_fn(p: Array) -> Array: 311 | return distrax.Categorical(probs=p).entropy() 312 | 313 | policy_target_entropy = jax.vmap(entropy_fn)(policy_target) 314 | policy_entropy = jax.vmap( 315 | lambda l: distrax.Categorical(logits=l).entropy() 316 | )(policy_logits) 317 | 318 | log = { 319 | "reward_target": reward_target, 320 | "reward_prediction": model_out.reward, 321 | "value_target": value_target, 322 | "value_prediction": model_out.value, 323 | "policy_entropy": policy_entropy, 324 | "policy_target_entropy": policy_target_entropy, 325 | "reward_loss": reward_loss, 326 | "value_loss": value_loss, 327 | "policy_loss": policy_loss, 328 | "behavior_loss": behavior_loss, 329 | "improve_advantage": improve_adv, 330 | "total_loss": total_loss, 331 | } 332 | return total_loss, log 333 | 334 | def batch_loss( 335 | params: Params, 336 | target_params: Params, 337 | trajectory: ActorOutput, 338 | rng_key: networks_lib.PRNGKey, 339 | ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]: 340 | bs = len(trajectory.reward) 341 | rng_keys = jax.random.split(rng_key, bs) 342 | losses, log = jax.vmap(loss, (None, None, 0, 0))( 343 | params, 344 | target_params, 345 | trajectory, 346 | rng_keys, 347 | ) 348 | log_mean = {f"{k}_mean": jnp.mean(v) for k, v in log.items()} 349 | std_keys = [ 350 | "reward_target", 351 | "reward_prediction", 352 | "q_val_target", 353 | "q_val_prediction", 354 | "value_target", 355 | "value_prediction", 356 | "improve_advantage", 357 | ] 358 | std_keys = [k for k in std_keys if k in log] 359 | log_std = {f"{k}_std": jnp.std(log[k]) for k in std_keys} 360 | log_mean.update(log_std) 361 | return jnp.mean(losses), log_mean 362 | 363 | def update_step( 364 | state: TrainingState, 365 | trajectory: ActorOutput, 366 | rng_key: networks_lib.PRNGKey, 367 | ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: 368 | params = state.params 369 | optimizer_state = state.optimizer_state 370 | 371 | grads, log = jax.grad(batch_loss, has_aux=True)( 372 | state.params, state.target_params, trajectory, rng_key 373 | ) 374 | grads = jax.lax.pmean(grads, axis_name="i") 375 | network_updates, optimizer_state = optimizer.update( 376 | grads, optimizer_state, params 377 | ) 378 | params = optax.apply_updates(params, network_updates) 379 | log.update( 380 | { 381 | "grad_norm": optax.global_norm(grads), 382 | "update_norm": optax.global_norm(network_updates), 383 | "param_norm": optax.global_norm(params), 384 | } 385 | ) 386 | new_state = TrainingState( 387 | optimizer_state=optimizer_state, 388 | params=params, 389 | target_params=state.target_params, 390 | step=state.step + 1, 391 | ) 392 | return new_state, log 393 | 394 | # Logger. 395 | self._logger = logger or loggers.make_default_logger( 396 | "learner", asynchronous=True, serialize_fn=utils.fetch_devicearray 397 | ) 398 | 399 | # Iterator on demonstration transitions. 400 | self._demonstrations = demonstrations 401 | 402 | # JIT compiler. 403 | self._batch_size = batch_size 404 | self._num_devices = jax.device_count() 405 | assert self._batch_size % self._num_devices == 0 406 | self._update_step = jax.pmap(update_step, axis_name="i") 407 | 408 | # Create initial state. 
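# The seed is split into one init key per sub-network (encoder, dynamics,
# prediction) plus the learner's running key, which is re-split on every
# update step.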
409 | random_key = jax.random.PRNGKey(config["seed"]) 410 | self._rng_key: networks_lib.PRNGKey 411 | key_r, key_d, key_p, self._rng_key = jax.random.split(random_key, 4) 412 | representation_params = networks.representation_network.init(key_r) 413 | transition_params = networks.transition_network.init(key_d) 414 | prediction_params = networks.prediction_network.init(key_p) 415 | 416 | # Create and initialize optimizer. 417 | params = Params( 418 | representation_params, 419 | transition_params, 420 | prediction_params, 421 | ) 422 | weight_decay_mask = Params( 423 | representation=hk.data_structures.map( 424 | lambda module_name, name, value: True if name == "w" else False, 425 | params.representation, 426 | ), 427 | transition=hk.data_structures.map( 428 | lambda module_name, name, value: True if name == "w" else False, 429 | params.transition, 430 | ), 431 | prediction=hk.data_structures.map( 432 | lambda module_name, name, value: True if name == "w" else False, 433 | params.prediction, 434 | ), 435 | ) 436 | learning_rate = optax.warmup_exponential_decay_schedule( 437 | init_value=0.0, 438 | peak_value=config["learning_rate"], 439 | warmup_steps=config["warmup_steps"], 440 | transition_steps=100_000, 441 | decay_rate=config["learning_rate_decay"], 442 | staircase=True, 443 | ) 444 | optimizer = optax.adamw( 445 | learning_rate=learning_rate, 446 | weight_decay=weight_decay, 447 | mask=weight_decay_mask, 448 | ) 449 | if max_grad_norm: 450 | optimizer = optax.chain(optax.clip_by_global_norm(max_grad_norm), optimizer) 451 | optimizer_state = optimizer.init(params) 452 | target_params = params 453 | 454 | # Learner state. 455 | self._state = TrainingState( 456 | optimizer_state=optimizer_state, 457 | params=params, 458 | target_params=target_params, 459 | step=0, 460 | ) 461 | self._target_update_interval = target_update_interval 462 | 463 | self._state = jax.device_put_replicated(self._state, jax.local_devices()) 464 | 465 | # Do not record timestamps until after the first learning step is done. 466 | # This is to avoid including the time it takes for actors to come online 467 | # and fill the replay buffer. 
468 | self._timestamp: Optional[float] = None 469 | self._elapsed = 0.0 470 | self._log_interval = log_interval 471 | self._unroll_steps = unroll_steps 472 | 473 | def step(self) -> None: 474 | """Train step.""" 475 | update_key, self._rng_key = jax.random.split(self._rng_key) 476 | update_keys = jax.random.split(update_key, self._num_devices) 477 | trajectory: ActorOutput = next(self._demonstrations) 478 | trajectory = tree.map_structure( 479 | lambda x: x.reshape( 480 | self._num_devices, self._batch_size // self._num_devices, *x.shape[1:] 481 | ), 482 | trajectory, 483 | ) 484 | 485 | self._state, metrics = self._update_step(self._state, trajectory, update_keys) 486 | 487 | _step = self._state.step[0] # type: ignore 488 | timestamp = time.time() 489 | elapsed: float = 0 490 | if self._timestamp: 491 | elapsed = timestamp - self._timestamp 492 | self._timestamp = timestamp 493 | self._elapsed += elapsed 494 | 495 | if _step % self._target_update_interval == 0: 496 | state: TrainingState = self._state 497 | self._state = TrainingState( 498 | optimizer_state=state.optimizer_state, 499 | params=state.params, 500 | target_params=state.params, 501 | step=state.step, 502 | ) 503 | if _step % self._log_interval == 0: 504 | metrics = jax.tree_util.tree_map(lambda t: t[0], metrics) 505 | metrics = jax.device_get(metrics) 506 | self._logger.write( 507 | { 508 | **metrics, 509 | **{ 510 | "step": _step, 511 | "elapsed_time": self._elapsed, 512 | }, 513 | } 514 | ) 515 | 516 | def get_variables(self, names: List[str]) -> List[Any]: 517 | """Get network parameters.""" 518 | state = self.save() 519 | variables = { 520 | "representation": state.params.representation, 521 | "dynamics": state.params.transition, 522 | "prediction": state.params.prediction, 523 | } 524 | return [variables[name] for name in names] 525 | 526 | def save(self) -> TrainingState: 527 | """Save the training state. 528 | 529 | Returns: 530 | TrainingState: State to be saved. 531 | """ 532 | _state = utils.fetch_devicearray(jax.tree_map(lambda t: t[0], self._state)) 533 | return _state 534 | 535 | def restore(self, state: TrainingState) -> None: 536 | """Restore the training state. 537 | 538 | Args: 539 | state (TrainingState): State to be resumed. 
540 | """ 541 | self._state = jax.device_put_replicated(state, jax.local_devices()) 542 | 543 | 544 | def root_unroll( 545 | networks: Networks, 546 | params: Params, 547 | num_bins: int, 548 | state: Array, 549 | ) -> AgentOutput: 550 | """Unroll the learned model from the root node.""" 551 | ( 552 | policy_logits, 553 | reward_logits, 554 | value_logits, 555 | ) = networks.prediction_network.apply(params.prediction, state) 556 | reward = logits_to_scalar(reward_logits, num_bins) 557 | reward = inv_value_transform(reward) 558 | value = logits_to_scalar(value_logits, num_bins) 559 | value = inv_value_transform(value) 560 | return AgentOutput( 561 | state=state, 562 | policy_logits=policy_logits, 563 | reward_logits=reward_logits, 564 | reward=reward, 565 | value_logits=value_logits, 566 | value=value, 567 | ) 568 | 569 | 570 | def model_unroll( 571 | networks: Networks, 572 | params: Params, 573 | num_bins: int, 574 | state: Array, 575 | action_sequence: Array, 576 | ) -> AgentOutput: 577 | """Unroll the learned model with a sequence of actions.""" 578 | 579 | def fn(state: Array, action: Array) -> Tuple[Array, Array]: 580 | """Dynamics fun for scan.""" 581 | next_state = networks.transition_network.apply( 582 | params.transition, action[None], state 583 | ) 584 | next_state = scale_gradient(next_state, 0.5) 585 | return next_state, next_state 586 | 587 | _, state_sequence = jax.lax.scan(fn, state, action_sequence) 588 | ( 589 | policy_logits, 590 | reward_logits, 591 | value_logits, 592 | ) = networks.prediction_network.apply(params.prediction, state_sequence) 593 | reward = logits_to_scalar(reward_logits, num_bins) 594 | reward = inv_value_transform(reward) 595 | value = logits_to_scalar(value_logits, num_bins) 596 | value = inv_value_transform(value) 597 | return AgentOutput( 598 | state=state_sequence, 599 | policy_logits=policy_logits, 600 | reward_logits=reward_logits, 601 | reward=reward, 602 | value_logits=value_logits, 603 | value=value, 604 | ) 605 | --------------------------------------------------------------------------------